forked from HSE_team/BetterCallPraskovia
* Add migration
* Delete legacy from bot * Clear old models * Единый http клиент * РАГ полечен
This commit is contained in:
parent
1ce1c23d10
commit
8bdacb4f7a
33
backend/alembic/versions/003_remove_embeddings_table.py
Normal file
33
backend/alembic/versions/003_remove_embeddings_table.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""Remove unused embeddings table
|
||||
|
||||
Revision ID: 003
|
||||
Revises: 002
|
||||
Create Date: 2024-12-24 12:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision = '003'
|
||||
down_revision = '002'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Apply revision 003 by dropping the ``embeddings`` table."""
    table_name = 'embeddings'
    op.drop_table(table_name)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert revision 003 by recreating the ``embeddings`` table.

    Only the table definition is restored (columns, primary key and the
    foreign key to ``documents.document_id``); this function inserts no rows.
    """
    # Import the dialect module explicitly. ``sa.dialects.postgresql`` is only
    # reachable as an attribute after something else has imported that
    # submodule, so relying on it makes the migration depend on import order.
    from sqlalchemy.dialects import postgresql

    op.create_table(
        'embeddings',
        sa.Column('embedding_id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('document_id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('embedding', postgresql.JSON(astext_type=sa.Text()), nullable=True),
        sa.Column('model_version', sa.String(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(['document_id'], ['documents.document_id'], ),
        sa.PrimaryKeyConstraint('embedding_id')
    )
|
||||
|
||||
@ -8,6 +8,7 @@ from src.domain.repositories.document_repository import IDocumentRepository
|
||||
from src.domain.repositories.collection_repository import ICollectionRepository
|
||||
from src.domain.repositories.collection_access_repository import ICollectionAccessRepository
|
||||
from src.application.services.document_parser_service import DocumentParserService
|
||||
from src.application.services.rag_service import RAGService
|
||||
from src.shared.exceptions import NotFoundError, ForbiddenError
|
||||
|
||||
|
||||
@ -19,12 +20,14 @@ class DocumentUseCases:
|
||||
document_repository: IDocumentRepository,
|
||||
collection_repository: ICollectionRepository,
|
||||
access_repository: ICollectionAccessRepository,
|
||||
parser_service: DocumentParserService
|
||||
parser_service: DocumentParserService,
|
||||
rag_service: Optional[RAGService] = None
|
||||
):
|
||||
self.document_repository = document_repository
|
||||
self.collection_repository = collection_repository
|
||||
self.access_repository = access_repository
|
||||
self.parser_service = parser_service
|
||||
self.rag_service = rag_service
|
||||
|
||||
async def _check_collection_access(self, user_id: UUID, collection) -> bool:
|
||||
"""Проверить доступ пользователя к коллекции"""
|
||||
@ -64,7 +67,7 @@ class DocumentUseCases:
|
||||
filename: str,
|
||||
user_id: UUID
|
||||
) -> Document:
|
||||
"""Загрузить и распарсить документ"""
|
||||
"""Загрузить и распарсить документ, затем автоматически проиндексировать"""
|
||||
collection = await self.collection_repository.get_by_id(collection_id)
|
||||
if not collection:
|
||||
raise NotFoundError(f"Коллекция {collection_id} не найдена")
|
||||
@ -81,7 +84,15 @@ class DocumentUseCases:
|
||||
content=content,
|
||||
metadata={"filename": filename}
|
||||
)
|
||||
return await self.document_repository.create(document)
|
||||
document = await self.document_repository.create(document)
|
||||
|
||||
if self.rag_service:
|
||||
try:
|
||||
await self.rag_service.index_document(document)
|
||||
except Exception as e:
|
||||
print(f"Ошибка при автоматической индексации документа {document.document_id}: {e}")
|
||||
|
||||
return document
|
||||
|
||||
async def get_document(self, document_id: UUID) -> Document:
|
||||
"""Получить документ по ID"""
|
||||
|
||||
@ -1,25 +0,0 @@
|
||||
"""
|
||||
Доменная сущность Embedding
|
||||
"""
|
||||
from datetime import datetime
|
||||
from uuid import UUID, uuid4
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Embedding:
|
||||
"""Эмбеддинг документа"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
document_id: UUID,
|
||||
embedding: list[float] | None = None,
|
||||
model_version: str = "",
|
||||
embedding_id: UUID | None = None,
|
||||
created_at: datetime | None = None
|
||||
):
|
||||
self.embedding_id = embedding_id or uuid4()
|
||||
self.document_id = document_id
|
||||
self.embedding = embedding or []
|
||||
self.model_version = model_version
|
||||
self.created_at = created_at or datetime.utcnow()
|
||||
|
||||
@ -53,19 +53,6 @@ class DocumentModel(Base):
|
||||
document_metadata = Column("metadata", JSON, nullable=True, default={})
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
|
||||
collection = relationship("CollectionModel", back_populates="documents")
|
||||
embeddings = relationship("EmbeddingModel", back_populates="document", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class EmbeddingModel(Base):
    """Embedding model (stub)"""
    __tablename__ = "embeddings"

    # Primary key; defaults to a freshly generated uuid4.
    embedding_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Required foreign key to documents.document_id.
    document_id = Column(UUID(as_uuid=True), ForeignKey("documents.document_id"), nullable=False)
    # Embedding payload stored in a JSON column; may be NULL.
    embedding = Column(JSON, nullable=True)
    # Optional version label of the embedding model.
    model_version = Column(String, nullable=True)
    # Creation timestamp; defaults to datetime.utcnow.
    created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
    # Counterpart of the "embeddings" relationship on DocumentModel.
    document = relationship("DocumentModel", back_populates="embeddings")
|
||||
|
||||
|
||||
class ConversationModel(Base):
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
API для RAG: индексация документов и ответы на вопросы
|
||||
API для RAG: ответы на вопросы
|
||||
"""
|
||||
from fastapi import APIRouter, status, Request
|
||||
from typing import Annotated
|
||||
@ -9,30 +9,13 @@ from src.presentation.middleware.auth_middleware import get_current_user
|
||||
from src.presentation.schemas.rag_schemas import (
|
||||
QuestionRequest,
|
||||
RAGAnswer,
|
||||
IndexDocumentRequest,
|
||||
IndexDocumentResponse,
|
||||
)
|
||||
from src.application.use_cases.rag_use_cases import RAGUseCases
|
||||
from src.domain.entities.user import User
|
||||
|
||||
|
||||
router = APIRouter(prefix="/rag", tags=["rag"])
|
||||
|
||||
|
||||
@router.post("/index", response_model=IndexDocumentResponse, status_code=status.HTTP_200_OK)
@inject
async def index_document(
    body: IndexDocumentRequest,
    request: Request,
    user_repo: Annotated[IUserRepository, FromDishka()],
    use_cases: Annotated[RAGUseCases, FromDishka()],
):
    """Indexing goes through chunking, then embedding and loading into the vector DB"""
    # Resolve the current user before indexing; the returned object is not
    # used afterwards, so it is not bound to a name.
    await get_current_user(request, user_repo)
    indexing_summary = await use_cases.index_document(body.document_id)
    return IndexDocumentResponse(**indexing_summary)
|
||||
|
||||
|
||||
@router.post("/question", response_model=RAGAnswer, status_code=status.HTTP_200_OK)
|
||||
@inject
|
||||
async def ask_question(
|
||||
|
||||
@ -26,10 +26,3 @@ class RAGAnswer(BaseModel):
|
||||
usage: dict[str, Any] = {}
|
||||
|
||||
|
||||
# Request body of the indexing endpoint: identifies the document to index.
class IndexDocumentRequest(BaseModel):
    document_id: UUID
|
||||
|
||||
|
||||
# Response body of the indexing endpoint: carries the count of indexed chunks.
class IndexDocumentResponse(BaseModel):
    chunks_indexed: int
|
||||
|
||||
|
||||
@ -152,9 +152,10 @@ class UseCaseProvider(Provider):
|
||||
document_repo: IDocumentRepository,
|
||||
collection_repo: ICollectionRepository,
|
||||
access_repo: ICollectionAccessRepository,
|
||||
parser_service: DocumentParserService
|
||||
parser_service: DocumentParserService,
|
||||
rag_service: RAGService
|
||||
) -> DocumentUseCases:
|
||||
return DocumentUseCases(document_repo, collection_repo, access_repo, parser_service)
|
||||
return DocumentUseCases(document_repo, collection_repo, access_repo, parser_service, rag_service)
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_conversation_use_cases(
|
||||
|
||||
@ -1,137 +1,130 @@
|
||||
import aiohttp
|
||||
from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient
|
||||
"""
|
||||
RAG сервис для бота - вызывает API бэкенда
|
||||
"""
|
||||
from tg_bot.config.settings import settings
|
||||
from tg_bot.infrastructure.http_client import create_http_session
|
||||
|
||||
|
||||
class RAGService:
|
||||
"""Сервис для работы с RAG через API бэкенда"""
|
||||
|
||||
def __init__(self):
|
||||
self.deepseek_client = DeepSeekClient()
|
||||
|
||||
async def search_documents_in_collections(
|
||||
self,
|
||||
user_telegram_id: str,
|
||||
query: str,
|
||||
limit_per_collection: int = 5
|
||||
) -> list[dict]:
|
||||
async def get_or_create_conversation(
|
||||
self,
|
||||
user_telegram_id: str,
|
||||
collection_id: str = None
|
||||
) -> str | None:
|
||||
"""Получить или создать беседу для пользователя"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with create_http_session() as session:
|
||||
async with session.get(
|
||||
f"{settings.BACKEND_URL}/users/telegram/{user_telegram_id}"
|
||||
) as user_response:
|
||||
if user_response.status != 200:
|
||||
return []
|
||||
|
||||
user_data = await user_response.json()
|
||||
user_uuid = str(user_data.get("user_id"))
|
||||
|
||||
if not user_uuid:
|
||||
return []
|
||||
|
||||
async with session.get(
|
||||
f"{settings.BACKEND_URL}/collections/",
|
||||
headers={"X-Telegram-ID": user_telegram_id}
|
||||
f"{settings.BACKEND_URL}/collections/",
|
||||
headers={"X-Telegram-ID": user_telegram_id}
|
||||
) as collections_response:
|
||||
if collections_response.status != 200:
|
||||
return []
|
||||
|
||||
return None
|
||||
collections = await collections_response.json()
|
||||
|
||||
all_documents = []
|
||||
for collection in collections:
|
||||
collection_id = collection.get("collection_id")
|
||||
if not collections:
|
||||
if not collection_id:
|
||||
async with session.post(
|
||||
f"{settings.BACKEND_URL}/collections",
|
||||
json={
|
||||
"name": "Основная коллекция",
|
||||
"description": "Коллекция по умолчанию",
|
||||
"is_public": False
|
||||
},
|
||||
headers={"X-Telegram-ID": user_telegram_id}
|
||||
) as create_collection_response:
|
||||
if create_collection_response.status in [200, 201]:
|
||||
collection_data = await create_collection_response.json()
|
||||
collection_id = collection_data.get("collection_id")
|
||||
else:
|
||||
collection_id = collection_id
|
||||
else:
|
||||
collection_id = collections[0].get("collection_id")
|
||||
|
||||
if not collection_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as search_session:
|
||||
async with search_session.get(
|
||||
f"{settings.BACKEND_URL}/documents/collection/{collection_id}",
|
||||
params={"search": query, "limit": limit_per_collection},
|
||||
headers={"X-Telegram-ID": user_telegram_id}
|
||||
) as search_response:
|
||||
if search_response.status == 200:
|
||||
documents = await search_response.json()
|
||||
for doc in documents:
|
||||
doc["collection_name"] = collection.get("name", "Unknown")
|
||||
all_documents.append(doc)
|
||||
except Exception as e:
|
||||
print(f"Error searching collection {collection_id}: {e}")
|
||||
continue
|
||||
|
||||
return all_documents[:20]
|
||||
|
||||
return None
|
||||
|
||||
async with session.get(
|
||||
f"{settings.BACKEND_URL}/conversations",
|
||||
headers={"X-Telegram-ID": user_telegram_id}
|
||||
) as conversations_response:
|
||||
if conversations_response.status == 200:
|
||||
conversations = await conversations_response.json()
|
||||
for conv in conversations:
|
||||
if conv.get("collection_id") == str(collection_id):
|
||||
return conv.get("conversation_id")
|
||||
|
||||
async with session.post(
|
||||
f"{settings.BACKEND_URL}/conversations",
|
||||
json={"collection_id": str(collection_id)},
|
||||
headers={"X-Telegram-ID": user_telegram_id}
|
||||
) as create_conversation_response:
|
||||
if create_conversation_response.status in [200, 201]:
|
||||
conversation_data = await create_conversation_response.json()
|
||||
return conversation_data.get("conversation_id")
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error searching documents: {e}")
|
||||
return []
|
||||
print(f"Error getting/creating conversation: {e}")
|
||||
return None
|
||||
|
||||
async def generate_answer_with_rag(
|
||||
self,
|
||||
question: str,
|
||||
user_telegram_id: str
|
||||
self,
|
||||
question: str,
|
||||
user_telegram_id: str
|
||||
) -> dict:
|
||||
documents = await self.search_documents_in_collections(
|
||||
user_telegram_id,
|
||||
question
|
||||
)
|
||||
|
||||
context_parts = []
|
||||
sources = []
|
||||
|
||||
for doc in documents[:5]:
|
||||
title = doc.get("title", "Без названия")
|
||||
content = doc.get("content", "")[:1000]
|
||||
collection_name = doc.get("collection_name", "Unknown")
|
||||
|
||||
context_parts.append(f"Документ: {title}\nКоллекция: {collection_name}\nСодержание: {content[:500]}...")
|
||||
sources.append({
|
||||
"title": title,
|
||||
"collection": collection_name,
|
||||
"document_id": doc.get("document_id")
|
||||
})
|
||||
|
||||
context = "\n\n".join(context_parts) if context_parts else "Релевантные документы не найдены."
|
||||
|
||||
system_prompt = """Ты - помощник-юрист, который отвечает на вопросы на основе предоставленных документов.
|
||||
Используй информацию из документов для формирования точного и полезного ответа.
|
||||
Если в документах нет информации для ответа, честно скажи об этом."""
|
||||
|
||||
user_prompt = f"""Контекст из документов:
|
||||
{context}
|
||||
|
||||
Вопрос пользователя: {question}
|
||||
|
||||
Ответь на вопрос, используя информацию из предоставленных документов. Если информации недостаточно, укажи это."""
|
||||
|
||||
"""Генерирует ответ используя RAG через API бэкенда"""
|
||||
try:
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
response = await self.deepseek_client.chat_completion(
|
||||
messages=messages,
|
||||
temperature=0.7,
|
||||
max_tokens=2000
|
||||
)
|
||||
|
||||
return {
|
||||
"answer": response.get("content", "Failed to generate answer"),
|
||||
"sources": sources,
|
||||
"usage": response.get("usage", {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error generating answer: {e}")
|
||||
if documents:
|
||||
conversation_id = await self.get_or_create_conversation(user_telegram_id)
|
||||
if not conversation_id:
|
||||
return {
|
||||
"answer": f"Found {len(documents)} documents but failed to generate answer",
|
||||
"sources": sources[:3],
|
||||
"usage": {}
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"answer": "No relevant documents found",
|
||||
"answer": "Не удалось создать беседу. Попробуйте позже.",
|
||||
"sources": [],
|
||||
"usage": {}
|
||||
}
|
||||
|
||||
async with create_http_session() as session:
|
||||
async with session.post(
|
||||
f"{settings.BACKEND_URL}/rag/question",
|
||||
json={
|
||||
"conversation_id": str(conversation_id),
|
||||
"question": question,
|
||||
"top_k": 20,
|
||||
"rerank_top_n": 5
|
||||
},
|
||||
headers={"X-Telegram-ID": user_telegram_id}
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
sources = []
|
||||
for source in result.get("sources", []):
|
||||
sources.append({
|
||||
"title": source.get("title", "Без названия"),
|
||||
"document_id": source.get("document_id"),
|
||||
"chunk_id": source.get("chunk_id"),
|
||||
"index": source.get("index", 0)
|
||||
})
|
||||
|
||||
return {
|
||||
"answer": result.get("answer", "Не удалось сгенерировать ответ."),
|
||||
"sources": sources,
|
||||
"usage": result.get("usage", {}),
|
||||
"conversation_id": str(conversation_id)
|
||||
}
|
||||
else:
|
||||
error_text = await response.text()
|
||||
print(f"RAG API error: {response.status} - {error_text}")
|
||||
return {
|
||||
"answer": "Ошибка при генерации ответа. Попробуйте позже.",
|
||||
"sources": [],
|
||||
"usage": {}
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Error generating answer with RAG: {e}")
|
||||
return {
|
||||
"answer": "Произошла ошибка при генерации ответа. Попробуйте позже.",
|
||||
"sources": [],
|
||||
"usage": {}
|
||||
}
|
||||
|
||||
@ -3,7 +3,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Настройки приложения (загружаются из .env файла в корне проекта)"""
|
||||
"""Настройки приложения получаеи из env файла, тут не ищи, мы спрятали:)"""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
@ -16,7 +16,7 @@ class Settings(BaseSettings):
|
||||
VERSION: str = "0.1.0"
|
||||
DEBUG: bool = False
|
||||
|
||||
TELEGRAM_BOT_TOKEN: str = ""
|
||||
TELEGRAM_BOT_TOKEN: str
|
||||
|
||||
FREE_QUESTIONS_LIMIT: int = 5
|
||||
PAYMENT_AMOUNT: float = 500.0
|
||||
@ -25,8 +25,8 @@ class Settings(BaseSettings):
|
||||
LOG_FILE: str = "logs/bot.log"
|
||||
|
||||
|
||||
YOOKASSA_SHOP_ID: str = ""
|
||||
YOOKASSA_SECRET_KEY: str = ""
|
||||
YOOKASSA_SHOP_ID: str
|
||||
YOOKASSA_SECRET_KEY: str
|
||||
YOOKASSA_RETURN_URL: str = "https://t.me/vibelawyer_bot"
|
||||
YOOKASSA_WEBHOOK_SECRET: Optional[str] = None
|
||||
|
||||
@ -35,7 +35,7 @@ class Settings(BaseSettings):
|
||||
DEEPSEEK_API_URL: str = "https://api.deepseek.com/v1/chat/completions"
|
||||
|
||||
|
||||
BACKEND_URL: str = "http://localhost:8000/api/v1"
|
||||
BACKEND_URL: str
|
||||
|
||||
|
||||
ADMIN_IDS_STR: str = ""
|
||||
|
||||
@ -3,7 +3,7 @@ import aiohttp
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from tg_bot.config.settings import settings
|
||||
from tg_bot.infrastructure.http_client import create_http_session, normalize_backend_url
|
||||
from tg_bot.infrastructure.http_client import create_http_session
|
||||
|
||||
|
||||
class User:
|
||||
@ -40,7 +40,7 @@ class UserService:
|
||||
"""Сервис для работы с пользователями через API бэкенда"""
|
||||
|
||||
def __init__(self):
|
||||
self.backend_url = normalize_backend_url(settings.BACKEND_URL)
|
||||
self.backend_url = settings.BACKEND_URL
|
||||
print(f"UserService initialized with BACKEND_URL: {self.backend_url}")
|
||||
|
||||
async def get_user_by_telegram_id(self, telegram_id: int) -> Optional[User]:
|
||||
@ -48,7 +48,7 @@ class UserService:
|
||||
try:
|
||||
url = f"{self.backend_url}/users/telegram/{telegram_id}"
|
||||
async with create_http_session() as session:
|
||||
async with session.get(url, ssl=False) as response:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return User(data)
|
||||
@ -74,8 +74,7 @@ class UserService:
|
||||
async with create_http_session() as session:
|
||||
async with session.post(
|
||||
f"{self.backend_url}/users",
|
||||
json={"telegram_id": str(telegram_id), "role": "user"},
|
||||
ssl=False
|
||||
json={"telegram_id": str(telegram_id), "role": "user"}
|
||||
) as response:
|
||||
if response.status in [200, 201]:
|
||||
data = await response.json()
|
||||
@ -106,8 +105,7 @@ class UserService:
|
||||
try:
|
||||
async with create_http_session() as session:
|
||||
async with session.post(
|
||||
f"{self.backend_url}/users/telegram/{telegram_id}/increment-questions",
|
||||
ssl=False
|
||||
f"{self.backend_url}/users/telegram/{telegram_id}/increment-questions"
|
||||
) as response:
|
||||
return response.status == 200
|
||||
except Exception as e:
|
||||
@ -120,8 +118,7 @@ class UserService:
|
||||
async with create_http_session() as session:
|
||||
async with session.post(
|
||||
f"{self.backend_url}/users/telegram/{telegram_id}/activate-premium",
|
||||
params={"days": days},
|
||||
ssl=False
|
||||
params={"days": days}
|
||||
) as response:
|
||||
return response.status == 200
|
||||
except Exception as e:
|
||||
172
tg_bot/infrastructure/external/deepseek_client.py
vendored
172
tg_bot/infrastructure/external/deepseek_client.py
vendored
@ -1,172 +0,0 @@
|
||||
import json
|
||||
from typing import Optional, AsyncIterator
|
||||
import httpx
|
||||
from tg_bot.config.settings import settings
|
||||
|
||||
|
||||
class DeepSeekAPIError(Exception):
    """Error raised by the DeepSeek client when an API call cannot succeed."""
|
||||
|
||||
|
||||
class DeepSeekClient:
|
||||
|
||||
def __init__(self, api_key: str | None = None, api_url: str | None = None):
|
||||
self.api_key = api_key or settings.DEEPSEEK_API_KEY
|
||||
self.api_url = api_url or settings.DEEPSEEK_API_URL
|
||||
self.timeout = 60.0
|
||||
|
||||
def _get_headers(self) -> dict[str, str]:
|
||||
if not self.api_key:
|
||||
raise DeepSeekAPIError("API key not set")
|
||||
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str = "deepseek-chat",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
stream: bool = False
|
||||
) -> dict:
|
||||
if not self.api_key:
|
||||
return {
|
||||
"content": "API key not configured",
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"stream": stream
|
||||
}
|
||||
|
||||
if max_tokens is not None:
|
||||
payload["max_tokens"] = max_tokens
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
self.api_url,
|
||||
headers=self._get_headers(),
|
||||
json=payload
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
|
||||
if "choices" in data and len(data["choices"]) > 0:
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
else:
|
||||
raise DeepSeekAPIError("Invalid response format")
|
||||
|
||||
usage = data.get("usage", {})
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"usage": {
|
||||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage.get("completion_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0)
|
||||
}
|
||||
}
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_msg = f"API error: {e.response.status_code}"
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
if "error" in error_data:
|
||||
error_msg = error_data['error'].get('message', error_msg)
|
||||
except:
|
||||
pass
|
||||
raise DeepSeekAPIError(error_msg) from e
|
||||
except httpx.RequestError as e:
|
||||
raise DeepSeekAPIError(f"Connection error: {str(e)}") from e
|
||||
except Exception as e:
|
||||
raise DeepSeekAPIError(str(e)) from e
|
||||
|
||||
async def stream_chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str = "deepseek-chat",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> AsyncIterator[str]:
|
||||
if not self.api_key:
|
||||
yield "API key not configured"
|
||||
return
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"stream": True
|
||||
}
|
||||
|
||||
if max_tokens is not None:
|
||||
payload["max_tokens"] = max_tokens
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
self.api_url,
|
||||
headers=self._get_headers(),
|
||||
json=payload
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
if line.startswith("data: "):
|
||||
line = line[6:]
|
||||
|
||||
if line.strip() == "[DONE]":
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(line)
|
||||
|
||||
if "choices" in data and len(data["choices"]) > 0:
|
||||
delta = data["choices"][0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_msg = f"API error: {e.response.status_code}"
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
if "error" in error_data:
|
||||
error_msg = error_data['error'].get('message', error_msg)
|
||||
except:
|
||||
pass
|
||||
raise DeepSeekAPIError(error_msg) from e
|
||||
except httpx.RequestError as e:
|
||||
raise DeepSeekAPIError(f"Connection error: {str(e)}") from e
|
||||
except Exception as e:
|
||||
raise DeepSeekAPIError(str(e)) from e
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
if not self.api_key:
|
||||
return False
|
||||
|
||||
try:
|
||||
test_messages = [{"role": "user", "content": "test"}]
|
||||
await self.chat_completion(test_messages, max_tokens=1)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@ -1,87 +1,23 @@
|
||||
"""HTTP client utilities for making requests to the backend API"""
|
||||
import aiohttp
|
||||
from typing import Optional
|
||||
import ssl
|
||||
import os
|
||||
|
||||
|
||||
def get_windows_host_ip() -> Optional[str]:
    """
    Get the Windows host IP address when running in WSL.
    In WSL2, the Windows host IP is typically the first nameserver in /etc/resolv.conf.

    Returns:
        The address of the first ``nameserver`` entry that is neither
        ``127.0.0.1`` / ``127.0.0.53`` nor starts with ``fe80``, or ``None``
        when the file is missing, unreadable, or has no such entry.
    """
    try:
        # Open directly instead of checking for existence first: a missing
        # file surfaces as OSError, with no gap between check and use.
        with open("/etc/resolv.conf", "r") as resolv_conf:
            for line in resolv_conf:
                if not line.startswith("nameserver"):
                    continue
                fields = line.split()
                if len(fields) < 2:
                    # Malformed entry without an address: skip it and keep
                    # scanning rather than giving up on the whole file.
                    continue
                ip = fields[1]
                if ip not in ("127.0.0.1", "127.0.0.53") and not ip.startswith("fe80"):
                    return ip
    except (OSError, UnicodeDecodeError):
        # Best effort: an unreadable or undecodable file means "unknown".
        return None
    return None
|
||||
|
||||
|
||||
def normalize_backend_url(url: str) -> str:
    """
    Rewrite a loopback backend URL for the detected runtime environment.

    URLs containing neither "localhost" nor "127.0.0.1" are returned as-is.
    Otherwise:
      * if ``/.dockerenv`` exists, a warning is printed and "localhost" is
        replaced with "127.0.0.1";
      * if ``/proc/version`` mentions "microsoft" and a Windows host IP is
        found, both "localhost" and "127.0.0.1" are replaced with that IP;
      * otherwise "localhost" is replaced with "127.0.0.1" when the URL
        starts with ``http://localhost`` or ``https://localhost``.
    """
    targets_loopback = "localhost" in url or "127.0.0.1" in url
    if not targets_loopback:
        return url

    if os.path.exists("/.dockerenv"):
        print(f"Warning: Running in Docker but URL contains localhost: {url}")
        print("Please set BACKEND_URL environment variable in docker-compose.yml to use Docker service name (e.g., http://backend:8000/api/v1)")
        return url.replace("localhost", "127.0.0.1")

    try:
        if os.path.exists("/proc/version"):
            with open("/proc/version", "r") as version_file:
                kernel_info = version_file.read().lower()
                host_ip = get_windows_host_ip() if "microsoft" in kernel_info else None
                if host_ip:
                    # The URL is known to contain a loopback host at this point.
                    url = url.replace("localhost", host_ip).replace("127.0.0.1", host_ip)
                    print(f"WSL detected: Using Windows host IP {host_ip} for backend connection")
                    return url
    except Exception as e:
        print(f"Warning: Could not detect WSL environment: {e}")

    if url.startswith(("http://localhost", "https://localhost")):
        return url.replace("localhost", "127.0.0.1")
    return url
|
||||
|
||||
|
||||
def create_http_session(timeout: Optional[aiohttp.ClientTimeout] = None) -> aiohttp.ClientSession:
|
||||
"""
|
||||
Create a configured aiohttp ClientSession for backend API requests.
|
||||
|
||||
Args:
|
||||
timeout: Optional timeout configuration. Defaults to 30 seconds total timeout.
|
||||
|
||||
Returns:
|
||||
Configured aiohttp.ClientSession
|
||||
Создаем сессию для запросов к бэку
|
||||
"""
|
||||
if timeout is None:
|
||||
timeout = aiohttp.ClientTimeout(total=30, connect=10)
|
||||
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=False,
|
||||
limit=100,
|
||||
limit_per_host=30,
|
||||
force_close=True,
|
||||
enable_cleanup_closed=True
|
||||
limit_per_host=30
|
||||
)
|
||||
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
return aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
timeout=timeout,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json"
|
||||
}
|
||||
)
|
||||
|
||||
@ -4,6 +4,7 @@ from aiogram.filters import Command, StateFilter
|
||||
from aiogram.fsm.context import FSMContext
|
||||
import aiohttp
|
||||
from tg_bot.config.settings import settings
|
||||
from tg_bot.infrastructure.http_client import create_http_session
|
||||
from tg_bot.infrastructure.telegram.states.collection_states import (
|
||||
CollectionAccessStates,
|
||||
CollectionEditStates
|
||||
@ -14,7 +15,7 @@ router = Router()
|
||||
|
||||
async def get_user_collections(telegram_id: str):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with create_http_session() as session:
|
||||
async with session.get(
|
||||
f"{settings.BACKEND_URL}/collections/",
|
||||
headers={"X-Telegram-ID": telegram_id}
|
||||
@ -33,7 +34,7 @@ async def get_collection_documents(collection_id: str, telegram_id: str):
|
||||
url = f"{settings.BACKEND_URL}/documents/collection/{collection_id}"
|
||||
print(f"DEBUG get_collection_documents: URL={url}, collection_id={collection_id}, telegram_id={telegram_id}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with create_http_session() as session:
|
||||
async with session.get(
|
||||
url,
|
||||
headers={"X-Telegram-ID": telegram_id}
|
||||
@ -57,7 +58,7 @@ async def get_collection_documents(collection_id: str, telegram_id: str):
|
||||
|
||||
async def search_in_collection(collection_id: str, query: str, telegram_id: str):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with create_http_session() as session:
|
||||
async with session.get(
|
||||
f"{settings.BACKEND_URL}/documents/collection/{collection_id}",
|
||||
params={"search": query},
|
||||
@ -78,7 +79,7 @@ async def get_collection_info(collection_id: str, telegram_id: str):
|
||||
url = f"{settings.BACKEND_URL}/collections/{collection_id}"
|
||||
print(f"DEBUG get_collection_info: URL={url}, collection_id={collection_id}, telegram_id={telegram_id}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with create_http_session() as session:
|
||||
async with session.get(
|
||||
url,
|
||||
headers={"X-Telegram-ID": telegram_id}
|
||||
@ -103,7 +104,7 @@ async def get_collection_info(collection_id: str, telegram_id: str):
|
||||
async def get_collection_access_list(collection_id: str, telegram_id: str):
|
||||
"""Получить список пользователей с доступом к коллекции"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with create_http_session() as session:
|
||||
async with session.get(
|
||||
f"{settings.BACKEND_URL}/collections/{collection_id}/access",
|
||||
headers={"X-Telegram-ID": telegram_id}
|
||||
@ -122,7 +123,7 @@ async def grant_collection_access(collection_id: str, telegram_id: str, owner_te
|
||||
url = f"{settings.BACKEND_URL}/collections/{collection_id}/access/telegram/{telegram_id}"
|
||||
print(f"DEBUG grant_collection_access: URL={url}, target_telegram_id={telegram_id}, owner_telegram_id={owner_telegram_id}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with create_http_session() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
headers={"X-Telegram-ID": owner_telegram_id}
|
||||
@ -145,7 +146,7 @@ async def grant_collection_access(collection_id: str, telegram_id: str, owner_te
|
||||
async def revoke_collection_access(collection_id: str, telegram_id: str, owner_telegram_id: str):
|
||||
"""Отозвать доступ к коллекции"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with create_http_session() as session:
|
||||
async with session.delete(
|
||||
f"{settings.BACKEND_URL}/collections/{collection_id}/access/telegram/{telegram_id}",
|
||||
headers={"X-Telegram-ID": owner_telegram_id}
|
||||
@ -281,7 +282,7 @@ async def show_collection_menu(callback: CallbackQuery):
|
||||
collection_name = collection_info.get("name", "Коллекция")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with create_http_session() as session:
|
||||
async with session.get(
|
||||
f"{settings.BACKEND_URL}/users/telegram/{telegram_id}"
|
||||
) as response:
|
||||
@ -673,7 +674,7 @@ async def process_edit_collection_description(message: Message, state: FSMContex
|
||||
if new_description:
|
||||
update_data["description"] = new_description
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with create_http_session() as session:
|
||||
async with session.put(
|
||||
f"{settings.BACKEND_URL}/collections/{collection_id}",
|
||||
json=update_data,
|
||||
|
||||
@ -7,6 +7,7 @@ from aiogram.filters import StateFilter
|
||||
from aiogram.fsm.context import FSMContext
|
||||
import aiohttp
|
||||
from tg_bot.config.settings import settings
|
||||
from tg_bot.infrastructure.http_client import create_http_session
|
||||
from tg_bot.infrastructure.telegram.states.collection_states import (
|
||||
DocumentEditStates,
|
||||
DocumentUploadStates
|
||||
@ -18,7 +19,7 @@ router = Router()
|
||||
async def get_document_info(document_id: str, telegram_id: str):
|
||||
"""Получить информацию о документе"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with create_http_session() as session:
|
||||
async with session.get(
|
||||
f"{settings.BACKEND_URL}/documents/{document_id}",
|
||||
headers={"X-Telegram-ID": telegram_id}
|
||||
@ -34,7 +35,7 @@ async def get_document_info(document_id: str, telegram_id: str):
|
||||
async def delete_document(document_id: str, telegram_id: str):
|
||||
"""Удалить документ"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with create_http_session() as session:
|
||||
async with session.delete(
|
||||
f"{settings.BACKEND_URL}/documents/{document_id}",
|
||||
headers={"X-Telegram-ID": telegram_id}
|
||||
@ -54,7 +55,7 @@ async def update_document(document_id: str, telegram_id: str, title: str = None,
|
||||
if content:
|
||||
update_data["content"] = content
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with create_http_session() as session:
|
||||
async with session.put(
|
||||
f"{settings.BACKEND_URL}/documents/{document_id}",
|
||||
json=update_data,
|
||||
@ -71,7 +72,7 @@ async def update_document(document_id: str, telegram_id: str, title: str = None,
|
||||
async def upload_document_to_collection(collection_id: str, file_data: bytes, filename: str, telegram_id: str):
|
||||
"""Загрузить документ в коллекцию"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with create_http_session() as session:
|
||||
form_data = aiohttp.FormData()
|
||||
form_data.add_field('file', file_data, filename=filename, content_type='application/octet-stream')
|
||||
|
||||
@ -120,7 +121,7 @@ async def view_document(callback: CallbackQuery):
|
||||
response += "\n\n<i>...</i>"
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with create_http_session() as session:
|
||||
async with session.get(
|
||||
f"{settings.BACKEND_URL}/collections/{collection_id}",
|
||||
headers={"X-Telegram-ID": telegram_id}
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
from aiogram import Router, types
|
||||
from aiogram.types import Message
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
from tg_bot.config.settings import settings
|
||||
from tg_bot.domain.services.user_service import UserService, User
|
||||
from tg_bot.application.services.rag_service import RAGService
|
||||
@ -60,12 +58,7 @@ async def process_premium_question(message: Message, user: User, question_text:
|
||||
answer = rag_result.get("answer", "Извините, не удалось сгенерировать ответ.")
|
||||
sources = rag_result.get("sources", [])
|
||||
|
||||
await save_conversation_to_backend(
|
||||
str(message.from_user.id),
|
||||
question_text,
|
||||
answer,
|
||||
sources
|
||||
)
|
||||
# Беседа уже сохранена в бэкенде через API /rag/question
|
||||
|
||||
response = (
|
||||
f"<b>Ваш вопрос:</b>\n"
|
||||
@ -74,18 +67,10 @@ async def process_premium_question(message: Message, user: User, question_text:
|
||||
)
|
||||
|
||||
if sources:
|
||||
response += f"<b>Источники из коллекций:</b>\n"
|
||||
collections_used = {}
|
||||
for source in sources[:5]:
|
||||
collection_name = source.get('collection', 'Неизвестно')
|
||||
if collection_name not in collections_used:
|
||||
collections_used[collection_name] = []
|
||||
collections_used[collection_name].append(source.get('title', 'Без названия'))
|
||||
|
||||
for i, (collection_name, titles) in enumerate(collections_used.items(), 1):
|
||||
response += f"{i}. <b>Коллекция:</b> {collection_name}\n"
|
||||
for title in titles[:2]:
|
||||
response += f" • {title}\n"
|
||||
response += f"<b>Источники:</b>\n"
|
||||
for idx, source in enumerate(sources[:5], 1):
|
||||
title = source.get('title', 'Без названия')
|
||||
response += f"{idx}. {title}\n"
|
||||
response += "\n<i>Используйте /mycollections для просмотра всех коллекций</i>\n\n"
|
||||
|
||||
response += (
|
||||
@ -122,12 +107,7 @@ async def process_free_question(message: Message, user: User, question_text: str
|
||||
answer = rag_result.get("answer", "Извините, не удалось сгенерировать ответ.")
|
||||
sources = rag_result.get("sources", [])
|
||||
|
||||
await save_conversation_to_backend(
|
||||
str(message.from_user.id),
|
||||
question_text,
|
||||
answer,
|
||||
sources
|
||||
)
|
||||
# Уже все сохранили через /rag/question
|
||||
|
||||
response = (
|
||||
f"<b>Ваш вопрос:</b>\n"
|
||||
@ -136,18 +116,10 @@ async def process_free_question(message: Message, user: User, question_text: str
|
||||
)
|
||||
|
||||
if sources:
|
||||
response += f"<b>Источники из коллекций:</b>\n"
|
||||
collections_used = {}
|
||||
for source in sources[:5]:
|
||||
collection_name = source.get('collection', 'Неизвестно')
|
||||
if collection_name not in collections_used:
|
||||
collections_used[collection_name] = []
|
||||
collections_used[collection_name].append(source.get('title', 'Без названия'))
|
||||
|
||||
for i, (collection_name, titles) in enumerate(collections_used.items(), 1):
|
||||
response += f"{i}. <b>Коллекция:</b> {collection_name}\n"
|
||||
for title in titles[:2]:
|
||||
response += f" • {title}\n"
|
||||
response += f"<b>Источники:</b>\n"
|
||||
for idx, source in enumerate(sources[:5], 1):
|
||||
title = source.get('title', 'Без названия')
|
||||
response += f"{idx}. {title}\n"
|
||||
response += "\n<i>Используйте /mycollections для просмотра всех коллекций</i>\n\n"
|
||||
|
||||
response += (
|
||||
@ -176,83 +148,7 @@ async def process_free_question(message: Message, user: User, question_text: str
|
||||
await message.answer(response, parse_mode="HTML")
|
||||
|
||||
|
||||
async def save_conversation_to_backend(telegram_id: str, question: str, answer: str, sources: list):
    """Persist one question/answer exchange to the backend (best effort).

    Steps, all against ``settings.BACKEND_URL``:

    1. Check that the user exists (``GET /users/telegram/{telegram_id}``).
    2. Take the first collection returned by ``GET /collections/``; when the
       list is empty, create a default private collection.
    3. Open a new conversation bound to that collection.
    4. Post the question as a ``user`` message and the answer as an
       ``assistant`` message with the sources attached.

    Any unexpected status makes the function return early without saving;
    any exception is printed and swallowed, so this never raises to the
    caller.

    Args:
        telegram_id: Telegram user ID, sent as the ``X-Telegram-ID`` header.
        question: Text stored as the ``user`` message.
        answer: Text stored as the ``assistant`` message.
        sources: Source entries stored under ``sources.documents`` of the
            assistant message.
    """
    try:
        from tg_bot.config.settings import settings

        backend_url = settings.BACKEND_URL
        # The same identification header is required by every authenticated call.
        auth_headers = {"X-Telegram-ID": telegram_id}

        async with aiohttp.ClientSession() as session:
            # Only the status matters here: an unknown user means nothing to save.
            async with session.get(
                f"{backend_url}/users/telegram/{telegram_id}"
            ) as user_response:
                if user_response.status != 200:
                    return

            collections = []
            async with session.get(
                f"{backend_url}/collections/",
                headers=auth_headers
            ) as collections_response:
                if collections_response.status == 200:
                    collections = await collections_response.json()

            collection_id = None
            if collections:
                collection_id = collections[0].get("collection_id")
            else:
                # No collections yet: create the default one to hold the conversation.
                async with session.post(
                    f"{backend_url}/collections",
                    json={
                        "name": "Основная коллекция",
                        "description": "Коллекция по умолчанию",
                        "is_public": False
                    },
                    headers=auth_headers
                ) as create_collection_response:
                    if create_collection_response.status in [200, 201]:
                        collection_data = await create_collection_response.json()
                        collection_id = collection_data.get("collection_id")

            if not collection_id:
                return

            async with session.post(
                f"{backend_url}/conversations",
                json={"collection_id": str(collection_id)},
                headers=auth_headers
            ) as conversation_response:
                if conversation_response.status not in [200, 201]:
                    return
                conversation_data = await conversation_response.json()
                conversation_id = conversation_data.get("conversation_id")

            if not conversation_id:
                return

            message_payloads = (
                {
                    "conversation_id": str(conversation_id),
                    "content": question,
                    "role": "user"
                },
                {
                    "conversation_id": str(conversation_id),
                    "content": answer,
                    "role": "assistant",
                    "sources": {"documents": sources}
                },
            )
            for payload in message_payloads:
                # Enter the response as a context manager so it is released
                # back to the pool; a bare `await session.post(...)` leaves
                # the response unreleased.
                async with session.post(
                    f"{backend_url}/messages",
                    json=payload,
                    headers=auth_headers
                ):
                    pass

    except Exception as e:
        # Saving history is optional; never let it break the reply to the user.
        print(f"Error saving conversation: {e}")
|
||||
# Снова сохраняется в /rag/question
|
||||
|
||||
|
||||
async def handle_limit_exceeded(message: Message, user: User):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user