forked from HSE_team/BetterCallPraskovia
* Add migration
* Delete legacy from bot * Clear old models * Единый http клиент * РАГ полечен
This commit is contained in:
parent
1ce1c23d10
commit
8bdacb4f7a
33
backend/alembic/versions/003_remove_embeddings_table.py
Normal file
33
backend/alembic/versions/003_remove_embeddings_table.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
"""Remove unused embeddings table
|
||||||
|
|
||||||
|
Revision ID: 003
|
||||||
|
Revises: 002
|
||||||
|
Create Date: 2024-12-24 12:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
revision = '003'
|
||||||
|
down_revision = '002'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Apply revision 003: drop the ``embeddings`` table."""
    table_name = 'embeddings'
    op.drop_table(table_name)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Revert revision 003: re-create the ``embeddings`` table.

    Only the schema is restored; the table is created empty, so rows removed
    by ``upgrade`` are not brought back.
    """
    # Import the dialect explicitly. ``import sqlalchemy as sa`` alone does not
    # guarantee that the ``sa.dialects.postgresql`` submodule attribute exists;
    # it is only present if something else has already imported it.
    from sqlalchemy.dialects import postgresql

    op.create_table(
        'embeddings',
        sa.Column('embedding_id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('document_id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('embedding', postgresql.JSON(astext_type=sa.Text()), nullable=True),
        sa.Column('model_version', sa.String(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(['document_id'], ['documents.document_id'], ),
        sa.PrimaryKeyConstraint('embedding_id')
    )
|
||||||
|
|
||||||
@ -8,6 +8,7 @@ from src.domain.repositories.document_repository import IDocumentRepository
|
|||||||
from src.domain.repositories.collection_repository import ICollectionRepository
|
from src.domain.repositories.collection_repository import ICollectionRepository
|
||||||
from src.domain.repositories.collection_access_repository import ICollectionAccessRepository
|
from src.domain.repositories.collection_access_repository import ICollectionAccessRepository
|
||||||
from src.application.services.document_parser_service import DocumentParserService
|
from src.application.services.document_parser_service import DocumentParserService
|
||||||
|
from src.application.services.rag_service import RAGService
|
||||||
from src.shared.exceptions import NotFoundError, ForbiddenError
|
from src.shared.exceptions import NotFoundError, ForbiddenError
|
||||||
|
|
||||||
|
|
||||||
@ -19,12 +20,14 @@ class DocumentUseCases:
|
|||||||
document_repository: IDocumentRepository,
|
document_repository: IDocumentRepository,
|
||||||
collection_repository: ICollectionRepository,
|
collection_repository: ICollectionRepository,
|
||||||
access_repository: ICollectionAccessRepository,
|
access_repository: ICollectionAccessRepository,
|
||||||
parser_service: DocumentParserService
|
parser_service: DocumentParserService,
|
||||||
|
rag_service: Optional[RAGService] = None
|
||||||
):
|
):
|
||||||
self.document_repository = document_repository
|
self.document_repository = document_repository
|
||||||
self.collection_repository = collection_repository
|
self.collection_repository = collection_repository
|
||||||
self.access_repository = access_repository
|
self.access_repository = access_repository
|
||||||
self.parser_service = parser_service
|
self.parser_service = parser_service
|
||||||
|
self.rag_service = rag_service
|
||||||
|
|
||||||
async def _check_collection_access(self, user_id: UUID, collection) -> bool:
|
async def _check_collection_access(self, user_id: UUID, collection) -> bool:
|
||||||
"""Проверить доступ пользователя к коллекции"""
|
"""Проверить доступ пользователя к коллекции"""
|
||||||
@ -64,7 +67,7 @@ class DocumentUseCases:
|
|||||||
filename: str,
|
filename: str,
|
||||||
user_id: UUID
|
user_id: UUID
|
||||||
) -> Document:
|
) -> Document:
|
||||||
"""Загрузить и распарсить документ"""
|
"""Загрузить и распарсить документ, затем автоматически проиндексировать"""
|
||||||
collection = await self.collection_repository.get_by_id(collection_id)
|
collection = await self.collection_repository.get_by_id(collection_id)
|
||||||
if not collection:
|
if not collection:
|
||||||
raise NotFoundError(f"Коллекция {collection_id} не найдена")
|
raise NotFoundError(f"Коллекция {collection_id} не найдена")
|
||||||
@ -81,7 +84,15 @@ class DocumentUseCases:
|
|||||||
content=content,
|
content=content,
|
||||||
metadata={"filename": filename}
|
metadata={"filename": filename}
|
||||||
)
|
)
|
||||||
return await self.document_repository.create(document)
|
document = await self.document_repository.create(document)
|
||||||
|
|
||||||
|
if self.rag_service:
|
||||||
|
try:
|
||||||
|
await self.rag_service.index_document(document)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Ошибка при автоматической индексации документа {document.document_id}: {e}")
|
||||||
|
|
||||||
|
return document
|
||||||
|
|
||||||
async def get_document(self, document_id: UUID) -> Document:
|
async def get_document(self, document_id: UUID) -> Document:
|
||||||
"""Получить документ по ID"""
|
"""Получить документ по ID"""
|
||||||
|
|||||||
@ -1,25 +0,0 @@
|
|||||||
"""
|
|
||||||
Доменная сущность Embedding
|
|
||||||
"""
|
|
||||||
from datetime import datetime
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
class Embedding:
|
|
||||||
"""Эмбеддинг документа"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
document_id: UUID,
|
|
||||||
embedding: list[float] | None = None,
|
|
||||||
model_version: str = "",
|
|
||||||
embedding_id: UUID | None = None,
|
|
||||||
created_at: datetime | None = None
|
|
||||||
):
|
|
||||||
self.embedding_id = embedding_id or uuid4()
|
|
||||||
self.document_id = document_id
|
|
||||||
self.embedding = embedding or []
|
|
||||||
self.model_version = model_version
|
|
||||||
self.created_at = created_at or datetime.utcnow()
|
|
||||||
|
|
||||||
@ -53,19 +53,6 @@ class DocumentModel(Base):
|
|||||||
document_metadata = Column("metadata", JSON, nullable=True, default={})
|
document_metadata = Column("metadata", JSON, nullable=True, default={})
|
||||||
created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
|
created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
|
||||||
collection = relationship("CollectionModel", back_populates="documents")
|
collection = relationship("CollectionModel", back_populates="documents")
|
||||||
embeddings = relationship("EmbeddingModel", back_populates="document", cascade="all, delete-orphan")
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingModel(Base):
    """Embedding model (stub)."""

    __tablename__ = "embeddings"

    # Primary key; the default value comes from uuid.uuid4.
    embedding_id = Column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    # Required foreign key to documents.document_id.
    document_id = Column(
        UUID(as_uuid=True),
        ForeignKey("documents.document_id"),
        nullable=False,
    )
    embedding = Column(JSON, nullable=True)
    model_version = Column(String, nullable=True)
    created_at = Column(DateTime, nullable=False, default=datetime.utcnow)

    # Counterpart of the "embeddings" relationship declared on DocumentModel.
    document = relationship("DocumentModel", back_populates="embeddings")
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationModel(Base):
|
class ConversationModel(Base):
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
API для RAG: индексация документов и ответы на вопросы
|
API для RAG: ответы на вопросы
|
||||||
"""
|
"""
|
||||||
from fastapi import APIRouter, status, Request
|
from fastapi import APIRouter, status, Request
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
@ -9,30 +9,13 @@ from src.presentation.middleware.auth_middleware import get_current_user
|
|||||||
from src.presentation.schemas.rag_schemas import (
|
from src.presentation.schemas.rag_schemas import (
|
||||||
QuestionRequest,
|
QuestionRequest,
|
||||||
RAGAnswer,
|
RAGAnswer,
|
||||||
IndexDocumentRequest,
|
|
||||||
IndexDocumentResponse,
|
|
||||||
)
|
)
|
||||||
from src.application.use_cases.rag_use_cases import RAGUseCases
|
from src.application.use_cases.rag_use_cases import RAGUseCases
|
||||||
from src.domain.entities.user import User
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/rag", tags=["rag"])
|
router = APIRouter(prefix="/rag", tags=["rag"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("/index", response_model=IndexDocumentResponse, status_code=status.HTTP_200_OK)
@inject
async def index_document(
    body: IndexDocumentRequest,
    request: Request,
    user_repo: Annotated[IUserRepository, FromDishka()],
    use_cases: Annotated[RAGUseCases, FromDishka()],
):
    """Indexing goes through chunking, then embedding, then loading into the vector DB."""
    # The resolved user is not used below; the lookup is still awaited
    # (presumably it acts as an authentication gate - confirm).
    await get_current_user(request, user_repo)
    indexing_result = await use_cases.index_document(body.document_id)
    return IndexDocumentResponse(**indexing_result)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/question", response_model=RAGAnswer, status_code=status.HTTP_200_OK)
|
@router.post("/question", response_model=RAGAnswer, status_code=status.HTTP_200_OK)
|
||||||
@inject
|
@inject
|
||||||
async def ask_question(
|
async def ask_question(
|
||||||
|
|||||||
@ -26,10 +26,3 @@ class RAGAnswer(BaseModel):
|
|||||||
usage: dict[str, Any] = {}
|
usage: dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
class IndexDocumentRequest(BaseModel):
    # Request body of POST /rag/index: the ID of the document to index.
    document_id: UUID
|
|
||||||
|
|
||||||
|
|
||||||
class IndexDocumentResponse(BaseModel):
    # Response body of POST /rag/index.
    # Presumably the number of chunks written to the index - confirm against RAGUseCases.
    chunks_indexed: int
|
|
||||||
|
|
||||||
|
|||||||
@ -152,9 +152,10 @@ class UseCaseProvider(Provider):
|
|||||||
document_repo: IDocumentRepository,
|
document_repo: IDocumentRepository,
|
||||||
collection_repo: ICollectionRepository,
|
collection_repo: ICollectionRepository,
|
||||||
access_repo: ICollectionAccessRepository,
|
access_repo: ICollectionAccessRepository,
|
||||||
parser_service: DocumentParserService
|
parser_service: DocumentParserService,
|
||||||
|
rag_service: RAGService
|
||||||
) -> DocumentUseCases:
|
) -> DocumentUseCases:
|
||||||
return DocumentUseCases(document_repo, collection_repo, access_repo, parser_service)
|
return DocumentUseCases(document_repo, collection_repo, access_repo, parser_service, rag_service)
|
||||||
|
|
||||||
@provide(scope=Scope.REQUEST)
|
@provide(scope=Scope.REQUEST)
|
||||||
def get_conversation_use_cases(
|
def get_conversation_use_cases(
|
||||||
|
|||||||
@ -1,137 +1,130 @@
|
|||||||
import aiohttp
|
"""
|
||||||
from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient
|
RAG сервис для бота - вызывает API бэкенда
|
||||||
|
"""
|
||||||
from tg_bot.config.settings import settings
|
from tg_bot.config.settings import settings
|
||||||
|
from tg_bot.infrastructure.http_client import create_http_session
|
||||||
|
|
||||||
|
|
||||||
class RAGService:
|
class RAGService:
|
||||||
|
"""Сервис для работы с RAG через API бэкенда"""
|
||||||
|
|
||||||
def __init__(self):
|
async def get_or_create_conversation(
|
||||||
self.deepseek_client = DeepSeekClient()
|
self,
|
||||||
|
user_telegram_id: str,
|
||||||
async def search_documents_in_collections(
|
collection_id: str = None
|
||||||
self,
|
) -> str | None:
|
||||||
user_telegram_id: str,
|
"""Получить или создать беседу для пользователя"""
|
||||||
query: str,
|
|
||||||
limit_per_collection: int = 5
|
|
||||||
) -> list[dict]:
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f"{settings.BACKEND_URL}/users/telegram/{user_telegram_id}"
|
f"{settings.BACKEND_URL}/collections/",
|
||||||
) as user_response:
|
headers={"X-Telegram-ID": user_telegram_id}
|
||||||
if user_response.status != 200:
|
|
||||||
return []
|
|
||||||
|
|
||||||
user_data = await user_response.json()
|
|
||||||
user_uuid = str(user_data.get("user_id"))
|
|
||||||
|
|
||||||
if not user_uuid:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async with session.get(
|
|
||||||
f"{settings.BACKEND_URL}/collections/",
|
|
||||||
headers={"X-Telegram-ID": user_telegram_id}
|
|
||||||
) as collections_response:
|
) as collections_response:
|
||||||
if collections_response.status != 200:
|
if collections_response.status != 200:
|
||||||
return []
|
return None
|
||||||
|
|
||||||
collections = await collections_response.json()
|
collections = await collections_response.json()
|
||||||
|
|
||||||
all_documents = []
|
if not collections:
|
||||||
for collection in collections:
|
if not collection_id:
|
||||||
collection_id = collection.get("collection_id")
|
async with session.post(
|
||||||
|
f"{settings.BACKEND_URL}/collections",
|
||||||
|
json={
|
||||||
|
"name": "Основная коллекция",
|
||||||
|
"description": "Коллекция по умолчанию",
|
||||||
|
"is_public": False
|
||||||
|
},
|
||||||
|
headers={"X-Telegram-ID": user_telegram_id}
|
||||||
|
) as create_collection_response:
|
||||||
|
if create_collection_response.status in [200, 201]:
|
||||||
|
collection_data = await create_collection_response.json()
|
||||||
|
collection_id = collection_data.get("collection_id")
|
||||||
|
else:
|
||||||
|
collection_id = collection_id
|
||||||
|
else:
|
||||||
|
collection_id = collections[0].get("collection_id")
|
||||||
|
|
||||||
if not collection_id:
|
if not collection_id:
|
||||||
continue
|
return None
|
||||||
|
|
||||||
try:
|
async with session.get(
|
||||||
async with aiohttp.ClientSession() as search_session:
|
f"{settings.BACKEND_URL}/conversations",
|
||||||
async with search_session.get(
|
headers={"X-Telegram-ID": user_telegram_id}
|
||||||
f"{settings.BACKEND_URL}/documents/collection/{collection_id}",
|
) as conversations_response:
|
||||||
params={"search": query, "limit": limit_per_collection},
|
if conversations_response.status == 200:
|
||||||
headers={"X-Telegram-ID": user_telegram_id}
|
conversations = await conversations_response.json()
|
||||||
) as search_response:
|
for conv in conversations:
|
||||||
if search_response.status == 200:
|
if conv.get("collection_id") == str(collection_id):
|
||||||
documents = await search_response.json()
|
return conv.get("conversation_id")
|
||||||
for doc in documents:
|
|
||||||
doc["collection_name"] = collection.get("name", "Unknown")
|
async with session.post(
|
||||||
all_documents.append(doc)
|
f"{settings.BACKEND_URL}/conversations",
|
||||||
except Exception as e:
|
json={"collection_id": str(collection_id)},
|
||||||
print(f"Error searching collection {collection_id}: {e}")
|
headers={"X-Telegram-ID": user_telegram_id}
|
||||||
continue
|
) as create_conversation_response:
|
||||||
|
if create_conversation_response.status in [200, 201]:
|
||||||
return all_documents[:20]
|
conversation_data = await create_conversation_response.json()
|
||||||
|
return conversation_data.get("conversation_id")
|
||||||
|
|
||||||
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error searching documents: {e}")
|
print(f"Error getting/creating conversation: {e}")
|
||||||
return []
|
return None
|
||||||
|
|
||||||
async def generate_answer_with_rag(
|
async def generate_answer_with_rag(
|
||||||
self,
|
self,
|
||||||
question: str,
|
question: str,
|
||||||
user_telegram_id: str
|
user_telegram_id: str
|
||||||
) -> dict:
|
) -> dict:
|
||||||
documents = await self.search_documents_in_collections(
|
"""Генерирует ответ используя RAG через API бэкенда"""
|
||||||
user_telegram_id,
|
|
||||||
question
|
|
||||||
)
|
|
||||||
|
|
||||||
context_parts = []
|
|
||||||
sources = []
|
|
||||||
|
|
||||||
for doc in documents[:5]:
|
|
||||||
title = doc.get("title", "Без названия")
|
|
||||||
content = doc.get("content", "")[:1000]
|
|
||||||
collection_name = doc.get("collection_name", "Unknown")
|
|
||||||
|
|
||||||
context_parts.append(f"Документ: {title}\nКоллекция: {collection_name}\nСодержание: {content[:500]}...")
|
|
||||||
sources.append({
|
|
||||||
"title": title,
|
|
||||||
"collection": collection_name,
|
|
||||||
"document_id": doc.get("document_id")
|
|
||||||
})
|
|
||||||
|
|
||||||
context = "\n\n".join(context_parts) if context_parts else "Релевантные документы не найдены."
|
|
||||||
|
|
||||||
system_prompt = """Ты - помощник-юрист, который отвечает на вопросы на основе предоставленных документов.
|
|
||||||
Используй информацию из документов для формирования точного и полезного ответа.
|
|
||||||
Если в документах нет информации для ответа, честно скажи об этом."""
|
|
||||||
|
|
||||||
user_prompt = f"""Контекст из документов:
|
|
||||||
{context}
|
|
||||||
|
|
||||||
Вопрос пользователя: {question}
|
|
||||||
|
|
||||||
Ответь на вопрос, используя информацию из предоставленных документов. Если информации недостаточно, укажи это."""
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
messages = [
|
conversation_id = await self.get_or_create_conversation(user_telegram_id)
|
||||||
{"role": "system", "content": system_prompt},
|
if not conversation_id:
|
||||||
{"role": "user", "content": user_prompt}
|
|
||||||
]
|
|
||||||
|
|
||||||
response = await self.deepseek_client.chat_completion(
|
|
||||||
messages=messages,
|
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=2000
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"answer": response.get("content", "Failed to generate answer"),
|
|
||||||
"sources": sources,
|
|
||||||
"usage": response.get("usage", {})
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error generating answer: {e}")
|
|
||||||
if documents:
|
|
||||||
return {
|
return {
|
||||||
"answer": f"Found {len(documents)} documents but failed to generate answer",
|
"answer": "Не удалось создать беседу. Попробуйте позже.",
|
||||||
"sources": sources[:3],
|
|
||||||
"usage": {}
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"answer": "No relevant documents found",
|
|
||||||
"sources": [],
|
"sources": [],
|
||||||
"usage": {}
|
"usage": {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.BACKEND_URL}/rag/question",
|
||||||
|
json={
|
||||||
|
"conversation_id": str(conversation_id),
|
||||||
|
"question": question,
|
||||||
|
"top_k": 20,
|
||||||
|
"rerank_top_n": 5
|
||||||
|
},
|
||||||
|
headers={"X-Telegram-ID": user_telegram_id}
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
result = await response.json()
|
||||||
|
sources = []
|
||||||
|
for source in result.get("sources", []):
|
||||||
|
sources.append({
|
||||||
|
"title": source.get("title", "Без названия"),
|
||||||
|
"document_id": source.get("document_id"),
|
||||||
|
"chunk_id": source.get("chunk_id"),
|
||||||
|
"index": source.get("index", 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"answer": result.get("answer", "Не удалось сгенерировать ответ."),
|
||||||
|
"sources": sources,
|
||||||
|
"usage": result.get("usage", {}),
|
||||||
|
"conversation_id": str(conversation_id)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
error_text = await response.text()
|
||||||
|
print(f"RAG API error: {response.status} - {error_text}")
|
||||||
|
return {
|
||||||
|
"answer": "Ошибка при генерации ответа. Попробуйте позже.",
|
||||||
|
"sources": [],
|
||||||
|
"usage": {}
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error generating answer with RAG: {e}")
|
||||||
|
return {
|
||||||
|
"answer": "Произошла ошибка при генерации ответа. Попробуйте позже.",
|
||||||
|
"sources": [],
|
||||||
|
"usage": {}
|
||||||
|
}
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
"""Настройки приложения (загружаются из .env файла в корне проекта)"""
|
"""Настройки приложения получаеи из env файла, тут не ищи, мы спрятали:)"""
|
||||||
|
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
env_file=".env",
|
env_file=".env",
|
||||||
@ -16,7 +16,7 @@ class Settings(BaseSettings):
|
|||||||
VERSION: str = "0.1.0"
|
VERSION: str = "0.1.0"
|
||||||
DEBUG: bool = False
|
DEBUG: bool = False
|
||||||
|
|
||||||
TELEGRAM_BOT_TOKEN: str = ""
|
TELEGRAM_BOT_TOKEN: str
|
||||||
|
|
||||||
FREE_QUESTIONS_LIMIT: int = 5
|
FREE_QUESTIONS_LIMIT: int = 5
|
||||||
PAYMENT_AMOUNT: float = 500.0
|
PAYMENT_AMOUNT: float = 500.0
|
||||||
@ -25,8 +25,8 @@ class Settings(BaseSettings):
|
|||||||
LOG_FILE: str = "logs/bot.log"
|
LOG_FILE: str = "logs/bot.log"
|
||||||
|
|
||||||
|
|
||||||
YOOKASSA_SHOP_ID: str = ""
|
YOOKASSA_SHOP_ID: str
|
||||||
YOOKASSA_SECRET_KEY: str = ""
|
YOOKASSA_SECRET_KEY: str
|
||||||
YOOKASSA_RETURN_URL: str = "https://t.me/vibelawyer_bot"
|
YOOKASSA_RETURN_URL: str = "https://t.me/vibelawyer_bot"
|
||||||
YOOKASSA_WEBHOOK_SECRET: Optional[str] = None
|
YOOKASSA_WEBHOOK_SECRET: Optional[str] = None
|
||||||
|
|
||||||
@ -35,7 +35,7 @@ class Settings(BaseSettings):
|
|||||||
DEEPSEEK_API_URL: str = "https://api.deepseek.com/v1/chat/completions"
|
DEEPSEEK_API_URL: str = "https://api.deepseek.com/v1/chat/completions"
|
||||||
|
|
||||||
|
|
||||||
BACKEND_URL: str = "http://localhost:8000/api/v1"
|
BACKEND_URL: str
|
||||||
|
|
||||||
|
|
||||||
ADMIN_IDS_STR: str = ""
|
ADMIN_IDS_STR: str = ""
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import aiohttp
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from tg_bot.config.settings import settings
|
from tg_bot.config.settings import settings
|
||||||
from tg_bot.infrastructure.http_client import create_http_session, normalize_backend_url
|
from tg_bot.infrastructure.http_client import create_http_session
|
||||||
|
|
||||||
|
|
||||||
class User:
|
class User:
|
||||||
@ -40,7 +40,7 @@ class UserService:
|
|||||||
"""Сервис для работы с пользователями через API бэкенда"""
|
"""Сервис для работы с пользователями через API бэкенда"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.backend_url = normalize_backend_url(settings.BACKEND_URL)
|
self.backend_url = settings.BACKEND_URL
|
||||||
print(f"UserService initialized with BACKEND_URL: {self.backend_url}")
|
print(f"UserService initialized with BACKEND_URL: {self.backend_url}")
|
||||||
|
|
||||||
async def get_user_by_telegram_id(self, telegram_id: int) -> Optional[User]:
|
async def get_user_by_telegram_id(self, telegram_id: int) -> Optional[User]:
|
||||||
@ -48,7 +48,7 @@ class UserService:
|
|||||||
try:
|
try:
|
||||||
url = f"{self.backend_url}/users/telegram/{telegram_id}"
|
url = f"{self.backend_url}/users/telegram/{telegram_id}"
|
||||||
async with create_http_session() as session:
|
async with create_http_session() as session:
|
||||||
async with session.get(url, ssl=False) as response:
|
async with session.get(url) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
return User(data)
|
return User(data)
|
||||||
@ -74,8 +74,7 @@ class UserService:
|
|||||||
async with create_http_session() as session:
|
async with create_http_session() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.backend_url}/users",
|
f"{self.backend_url}/users",
|
||||||
json={"telegram_id": str(telegram_id), "role": "user"},
|
json={"telegram_id": str(telegram_id), "role": "user"}
|
||||||
ssl=False
|
|
||||||
) as response:
|
) as response:
|
||||||
if response.status in [200, 201]:
|
if response.status in [200, 201]:
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
@ -106,8 +105,7 @@ class UserService:
|
|||||||
try:
|
try:
|
||||||
async with create_http_session() as session:
|
async with create_http_session() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.backend_url}/users/telegram/{telegram_id}/increment-questions",
|
f"{self.backend_url}/users/telegram/{telegram_id}/increment-questions"
|
||||||
ssl=False
|
|
||||||
) as response:
|
) as response:
|
||||||
return response.status == 200
|
return response.status == 200
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -120,8 +118,7 @@ class UserService:
|
|||||||
async with create_http_session() as session:
|
async with create_http_session() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.backend_url}/users/telegram/{telegram_id}/activate-premium",
|
f"{self.backend_url}/users/telegram/{telegram_id}/activate-premium",
|
||||||
params={"days": days},
|
params={"days": days}
|
||||||
ssl=False
|
|
||||||
) as response:
|
) as response:
|
||||||
return response.status == 200
|
return response.status == 200
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
172
tg_bot/infrastructure/external/deepseek_client.py
vendored
172
tg_bot/infrastructure/external/deepseek_client.py
vendored
@ -1,172 +0,0 @@
|
|||||||
import json
|
|
||||||
from typing import Optional, AsyncIterator
|
|
||||||
import httpx
|
|
||||||
from tg_bot.config.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
class DeepSeekAPIError(Exception):
    """Error raised by DeepSeekClient when a DeepSeek API call fails."""
|
|
||||||
|
|
||||||
|
|
||||||
class DeepSeekClient:
|
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None, api_url: str | None = None):
|
|
||||||
self.api_key = api_key or settings.DEEPSEEK_API_KEY
|
|
||||||
self.api_url = api_url or settings.DEEPSEEK_API_URL
|
|
||||||
self.timeout = 60.0
|
|
||||||
|
|
||||||
def _get_headers(self) -> dict[str, str]:
|
|
||||||
if not self.api_key:
|
|
||||||
raise DeepSeekAPIError("API key not set")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {self.api_key}"
|
|
||||||
}
|
|
||||||
|
|
||||||
async def chat_completion(
|
|
||||||
self,
|
|
||||||
messages: list[dict[str, str]],
|
|
||||||
model: str = "deepseek-chat",
|
|
||||||
temperature: float = 0.7,
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
stream: bool = False
|
|
||||||
) -> dict:
|
|
||||||
if not self.api_key:
|
|
||||||
return {
|
|
||||||
"content": "API key not configured",
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"model": model,
|
|
||||||
"messages": messages,
|
|
||||||
"temperature": temperature,
|
|
||||||
"stream": stream
|
|
||||||
}
|
|
||||||
|
|
||||||
if max_tokens is not None:
|
|
||||||
payload["max_tokens"] = max_tokens
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
||||||
response = await client.post(
|
|
||||||
self.api_url,
|
|
||||||
headers=self._get_headers(),
|
|
||||||
json=payload
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
if "choices" in data and len(data["choices"]) > 0:
|
|
||||||
content = data["choices"][0]["message"]["content"]
|
|
||||||
else:
|
|
||||||
raise DeepSeekAPIError("Invalid response format")
|
|
||||||
|
|
||||||
usage = data.get("usage", {})
|
|
||||||
|
|
||||||
return {
|
|
||||||
"content": content,
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
|
||||||
"completion_tokens": usage.get("completion_tokens", 0),
|
|
||||||
"total_tokens": usage.get("total_tokens", 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
error_msg = f"API error: {e.response.status_code}"
|
|
||||||
try:
|
|
||||||
error_data = e.response.json()
|
|
||||||
if "error" in error_data:
|
|
||||||
error_msg = error_data['error'].get('message', error_msg)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
raise DeepSeekAPIError(error_msg) from e
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
raise DeepSeekAPIError(f"Connection error: {str(e)}") from e
|
|
||||||
except Exception as e:
|
|
||||||
raise DeepSeekAPIError(str(e)) from e
|
|
||||||
|
|
||||||
async def stream_chat_completion(
|
|
||||||
self,
|
|
||||||
messages: list[dict[str, str]],
|
|
||||||
model: str = "deepseek-chat",
|
|
||||||
temperature: float = 0.7,
|
|
||||||
max_tokens: Optional[int] = None
|
|
||||||
) -> AsyncIterator[str]:
|
|
||||||
if not self.api_key:
|
|
||||||
yield "API key not configured"
|
|
||||||
return
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"model": model,
|
|
||||||
"messages": messages,
|
|
||||||
"temperature": temperature,
|
|
||||||
"stream": True
|
|
||||||
}
|
|
||||||
|
|
||||||
if max_tokens is not None:
|
|
||||||
payload["max_tokens"] = max_tokens
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
||||||
async with client.stream(
|
|
||||||
"POST",
|
|
||||||
self.api_url,
|
|
||||||
headers=self._get_headers(),
|
|
||||||
json=payload
|
|
||||||
) as response:
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
async for line in response.aiter_lines():
|
|
||||||
if not line.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
if line.startswith("data: "):
|
|
||||||
line = line[6:]
|
|
||||||
|
|
||||||
if line.strip() == "[DONE]":
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = json.loads(line)
|
|
||||||
|
|
||||||
if "choices" in data and len(data["choices"]) > 0:
|
|
||||||
delta = data["choices"][0].get("delta", {})
|
|
||||||
content = delta.get("content", "")
|
|
||||||
if content:
|
|
||||||
yield content
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
error_msg = f"API error: {e.response.status_code}"
|
|
||||||
try:
|
|
||||||
error_data = e.response.json()
|
|
||||||
if "error" in error_data:
|
|
||||||
error_msg = error_data['error'].get('message', error_msg)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
raise DeepSeekAPIError(error_msg) from e
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
raise DeepSeekAPIError(f"Connection error: {str(e)}") from e
|
|
||||||
except Exception as e:
|
|
||||||
raise DeepSeekAPIError(str(e)) from e
|
|
||||||
|
|
||||||
async def health_check(self) -> bool:
|
|
||||||
if not self.api_key:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
test_messages = [{"role": "user", "content": "test"}]
|
|
||||||
await self.chat_completion(test_messages, max_tokens=1)
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@ -1,87 +1,23 @@
|
|||||||
"""HTTP client utilities for making requests to the backend API"""
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import ssl
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def get_windows_host_ip(resolv_conf_path: str = "/etc/resolv.conf") -> Optional[str]:
    """
    Get the Windows host IP address when running in WSL.
    In WSL2, the Windows host IP is typically the first nameserver in /etc/resolv.conf.

    Args:
        resolv_conf_path: File to scan for ``nameserver`` lines.

    Returns:
        The first nameserver address that is not ``127.0.0.1``/``127.0.0.53``
        and does not start with ``fe80``, or ``None`` if there is no such
        entry or the file cannot be read.
    """
    try:
        with open(resolv_conf_path, "r") as f:
            for line in f:
                if not line.startswith("nameserver"):
                    continue
                fields = line.split()
                if len(fields) < 2:
                    # Malformed entry without an address: skip it instead of
                    # abandoning the rest of the file.
                    continue
                ip = fields[1]
                if ip not in ["127.0.0.1", "127.0.0.53"] and not ip.startswith("fe80"):
                    return ip
    except (OSError, ValueError):
        # Best-effort lookup: a missing, unreadable or undecodable file
        # simply means no host IP could be determined.
        pass
    return None
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_backend_url(url: str) -> str:
    """Rewrite a loopback backend URL so it is reachable from Docker or WSL.

    Only the host component is inspected and replaced, so "localhost"
    appearing inside another hostname or in the path is left alone.

    Args:
        url: Backend base URL.

    Returns:
        ``url`` unchanged when its host is neither ``localhost`` nor
        ``127.0.0.1`` (or when it cannot be parsed). Otherwise:
        inside Docker, ``localhost`` becomes ``127.0.0.1`` and a warning
        is printed; under WSL, the host becomes the address returned by
        ``get_windows_host_ip()`` when one is found; in every other case
        ``localhost`` becomes ``127.0.0.1``.
    """
    from urllib.parse import urlsplit  # local import: file-level imports are left untouched

    try:
        parts = urlsplit(url)
    except ValueError:
        # Not parseable as a URL; hand it back as-is.
        return url

    host = parts.hostname
    if host not in ("localhost", "127.0.0.1"):
        return url

    def with_host(new_host: str) -> str:
        """Return the URL with only its host replaced (userinfo and port kept)."""
        userinfo, at, hostport = parts.netloc.rpartition("@")
        _, colon, port = hostport.partition(":")
        return parts._replace(netloc=f"{userinfo}{at}{new_host}{colon}{port}").geturl()

    if os.path.exists("/.dockerenv"):
        print(f"Warning: Running in Docker but URL contains localhost: {url}")
        print("Please set BACKEND_URL environment variable in docker-compose.yml to use Docker service name (e.g., http://backend:8000/api/v1)")
    else:
        is_wsl = False
        try:
            if os.path.exists("/proc/version"):
                with open("/proc/version", "r") as f:
                    is_wsl = "microsoft" in f.read().lower()
        except (OSError, UnicodeError) as e:
            print(f"Warning: Could not detect WSL environment: {e}")

        if is_wsl:
            windows_ip = get_windows_host_ip()
            if windows_ip:
                print(f"WSL detected: Using Windows host IP {windows_ip} for backend connection")
                return with_host(windows_ip)

    # Fallback shared by Docker and plain hosts: use the numeric loopback address.
    if host == "localhost":
        return with_host("127.0.0.1")
    return url
|
|
||||||
|
|
||||||
|
|
||||||
def create_http_session(timeout: Optional[aiohttp.ClientTimeout] = None) -> aiohttp.ClientSession:
    """Create an aiohttp ClientSession for requests to the backend API.

    Args:
        timeout: Timeout configuration. When omitted, a 30-second total /
            10-second connect timeout is used.

    Returns:
        A new ``aiohttp.ClientSession``. The caller owns it and is expected
        to close it, e.g. by using it in ``async with``.
    """
    if timeout is None:
        timeout = aiohttp.ClientTimeout(total=30, connect=10)

    # No `ssl` override is passed, so the connector keeps aiohttp's default
    # TLS handling.
    connector = aiohttp.TCPConnector(
        limit=100,
        limit_per_host=30,
    )

    # Only `Accept` is set session-wide; no default Content-Type is imposed,
    # leaving each request to declare the type of its own body.
    return aiohttp.ClientSession(
        connector=connector,
        timeout=timeout,
        headers={
            "Accept": "application/json",
        },
    )
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from aiogram.filters import Command, StateFilter
|
|||||||
from aiogram.fsm.context import FSMContext
|
from aiogram.fsm.context import FSMContext
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from tg_bot.config.settings import settings
|
from tg_bot.config.settings import settings
|
||||||
|
from tg_bot.infrastructure.http_client import create_http_session
|
||||||
from tg_bot.infrastructure.telegram.states.collection_states import (
|
from tg_bot.infrastructure.telegram.states.collection_states import (
|
||||||
CollectionAccessStates,
|
CollectionAccessStates,
|
||||||
CollectionEditStates
|
CollectionEditStates
|
||||||
@ -14,7 +15,7 @@ router = Router()
|
|||||||
|
|
||||||
async def get_user_collections(telegram_id: str):
|
async def get_user_collections(telegram_id: str):
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f"{settings.BACKEND_URL}/collections/",
|
f"{settings.BACKEND_URL}/collections/",
|
||||||
headers={"X-Telegram-ID": telegram_id}
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
@ -33,7 +34,7 @@ async def get_collection_documents(collection_id: str, telegram_id: str):
|
|||||||
url = f"{settings.BACKEND_URL}/documents/collection/{collection_id}"
|
url = f"{settings.BACKEND_URL}/documents/collection/{collection_id}"
|
||||||
print(f"DEBUG get_collection_documents: URL={url}, collection_id={collection_id}, telegram_id={telegram_id}")
|
print(f"DEBUG get_collection_documents: URL={url}, collection_id={collection_id}, telegram_id={telegram_id}")
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
url,
|
url,
|
||||||
headers={"X-Telegram-ID": telegram_id}
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
@ -57,7 +58,7 @@ async def get_collection_documents(collection_id: str, telegram_id: str):
|
|||||||
|
|
||||||
async def search_in_collection(collection_id: str, query: str, telegram_id: str):
|
async def search_in_collection(collection_id: str, query: str, telegram_id: str):
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f"{settings.BACKEND_URL}/documents/collection/{collection_id}",
|
f"{settings.BACKEND_URL}/documents/collection/{collection_id}",
|
||||||
params={"search": query},
|
params={"search": query},
|
||||||
@ -78,7 +79,7 @@ async def get_collection_info(collection_id: str, telegram_id: str):
|
|||||||
url = f"{settings.BACKEND_URL}/collections/{collection_id}"
|
url = f"{settings.BACKEND_URL}/collections/{collection_id}"
|
||||||
print(f"DEBUG get_collection_info: URL={url}, collection_id={collection_id}, telegram_id={telegram_id}")
|
print(f"DEBUG get_collection_info: URL={url}, collection_id={collection_id}, telegram_id={telegram_id}")
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
url,
|
url,
|
||||||
headers={"X-Telegram-ID": telegram_id}
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
@ -103,7 +104,7 @@ async def get_collection_info(collection_id: str, telegram_id: str):
|
|||||||
async def get_collection_access_list(collection_id: str, telegram_id: str):
|
async def get_collection_access_list(collection_id: str, telegram_id: str):
|
||||||
"""Получить список пользователей с доступом к коллекции"""
|
"""Получить список пользователей с доступом к коллекции"""
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f"{settings.BACKEND_URL}/collections/{collection_id}/access",
|
f"{settings.BACKEND_URL}/collections/{collection_id}/access",
|
||||||
headers={"X-Telegram-ID": telegram_id}
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
@ -122,7 +123,7 @@ async def grant_collection_access(collection_id: str, telegram_id: str, owner_te
|
|||||||
url = f"{settings.BACKEND_URL}/collections/{collection_id}/access/telegram/{telegram_id}"
|
url = f"{settings.BACKEND_URL}/collections/{collection_id}/access/telegram/{telegram_id}"
|
||||||
print(f"DEBUG grant_collection_access: URL={url}, target_telegram_id={telegram_id}, owner_telegram_id={owner_telegram_id}")
|
print(f"DEBUG grant_collection_access: URL={url}, target_telegram_id={telegram_id}, owner_telegram_id={owner_telegram_id}")
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
url,
|
url,
|
||||||
headers={"X-Telegram-ID": owner_telegram_id}
|
headers={"X-Telegram-ID": owner_telegram_id}
|
||||||
@ -145,7 +146,7 @@ async def grant_collection_access(collection_id: str, telegram_id: str, owner_te
|
|||||||
async def revoke_collection_access(collection_id: str, telegram_id: str, owner_telegram_id: str):
|
async def revoke_collection_access(collection_id: str, telegram_id: str, owner_telegram_id: str):
|
||||||
"""Отозвать доступ к коллекции"""
|
"""Отозвать доступ к коллекции"""
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.delete(
|
async with session.delete(
|
||||||
f"{settings.BACKEND_URL}/collections/{collection_id}/access/telegram/{telegram_id}",
|
f"{settings.BACKEND_URL}/collections/{collection_id}/access/telegram/{telegram_id}",
|
||||||
headers={"X-Telegram-ID": owner_telegram_id}
|
headers={"X-Telegram-ID": owner_telegram_id}
|
||||||
@ -281,7 +282,7 @@ async def show_collection_menu(callback: CallbackQuery):
|
|||||||
collection_name = collection_info.get("name", "Коллекция")
|
collection_name = collection_info.get("name", "Коллекция")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f"{settings.BACKEND_URL}/users/telegram/{telegram_id}"
|
f"{settings.BACKEND_URL}/users/telegram/{telegram_id}"
|
||||||
) as response:
|
) as response:
|
||||||
@ -673,7 +674,7 @@ async def process_edit_collection_description(message: Message, state: FSMContex
|
|||||||
if new_description:
|
if new_description:
|
||||||
update_data["description"] = new_description
|
update_data["description"] = new_description
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.put(
|
async with session.put(
|
||||||
f"{settings.BACKEND_URL}/collections/{collection_id}",
|
f"{settings.BACKEND_URL}/collections/{collection_id}",
|
||||||
json=update_data,
|
json=update_data,
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from aiogram.filters import StateFilter
|
|||||||
from aiogram.fsm.context import FSMContext
|
from aiogram.fsm.context import FSMContext
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from tg_bot.config.settings import settings
|
from tg_bot.config.settings import settings
|
||||||
|
from tg_bot.infrastructure.http_client import create_http_session
|
||||||
from tg_bot.infrastructure.telegram.states.collection_states import (
|
from tg_bot.infrastructure.telegram.states.collection_states import (
|
||||||
DocumentEditStates,
|
DocumentEditStates,
|
||||||
DocumentUploadStates
|
DocumentUploadStates
|
||||||
@ -18,7 +19,7 @@ router = Router()
|
|||||||
async def get_document_info(document_id: str, telegram_id: str):
|
async def get_document_info(document_id: str, telegram_id: str):
|
||||||
"""Получить информацию о документе"""
|
"""Получить информацию о документе"""
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f"{settings.BACKEND_URL}/documents/{document_id}",
|
f"{settings.BACKEND_URL}/documents/{document_id}",
|
||||||
headers={"X-Telegram-ID": telegram_id}
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
@ -34,7 +35,7 @@ async def get_document_info(document_id: str, telegram_id: str):
|
|||||||
async def delete_document(document_id: str, telegram_id: str):
|
async def delete_document(document_id: str, telegram_id: str):
|
||||||
"""Удалить документ"""
|
"""Удалить документ"""
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.delete(
|
async with session.delete(
|
||||||
f"{settings.BACKEND_URL}/documents/{document_id}",
|
f"{settings.BACKEND_URL}/documents/{document_id}",
|
||||||
headers={"X-Telegram-ID": telegram_id}
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
@ -54,7 +55,7 @@ async def update_document(document_id: str, telegram_id: str, title: str = None,
|
|||||||
if content:
|
if content:
|
||||||
update_data["content"] = content
|
update_data["content"] = content
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.put(
|
async with session.put(
|
||||||
f"{settings.BACKEND_URL}/documents/{document_id}",
|
f"{settings.BACKEND_URL}/documents/{document_id}",
|
||||||
json=update_data,
|
json=update_data,
|
||||||
@ -71,7 +72,7 @@ async def update_document(document_id: str, telegram_id: str, title: str = None,
|
|||||||
async def upload_document_to_collection(collection_id: str, file_data: bytes, filename: str, telegram_id: str):
|
async def upload_document_to_collection(collection_id: str, file_data: bytes, filename: str, telegram_id: str):
|
||||||
"""Загрузить документ в коллекцию"""
|
"""Загрузить документ в коллекцию"""
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
form_data = aiohttp.FormData()
|
form_data = aiohttp.FormData()
|
||||||
form_data.add_field('file', file_data, filename=filename, content_type='application/octet-stream')
|
form_data.add_field('file', file_data, filename=filename, content_type='application/octet-stream')
|
||||||
|
|
||||||
@ -120,7 +121,7 @@ async def view_document(callback: CallbackQuery):
|
|||||||
response += "\n\n<i>...</i>"
|
response += "\n\n<i>...</i>"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f"{settings.BACKEND_URL}/collections/{collection_id}",
|
f"{settings.BACKEND_URL}/collections/{collection_id}",
|
||||||
headers={"X-Telegram-ID": telegram_id}
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
from aiogram import Router, types
|
from aiogram import Router, types
|
||||||
from aiogram.types import Message
|
from aiogram.types import Message
|
||||||
from datetime import datetime
|
|
||||||
import aiohttp
|
|
||||||
from tg_bot.config.settings import settings
|
from tg_bot.config.settings import settings
|
||||||
from tg_bot.domain.services.user_service import UserService, User
|
from tg_bot.domain.services.user_service import UserService, User
|
||||||
from tg_bot.application.services.rag_service import RAGService
|
from tg_bot.application.services.rag_service import RAGService
|
||||||
@ -60,12 +58,7 @@ async def process_premium_question(message: Message, user: User, question_text:
|
|||||||
answer = rag_result.get("answer", "Извините, не удалось сгенерировать ответ.")
|
answer = rag_result.get("answer", "Извините, не удалось сгенерировать ответ.")
|
||||||
sources = rag_result.get("sources", [])
|
sources = rag_result.get("sources", [])
|
||||||
|
|
||||||
await save_conversation_to_backend(
|
# Беседа уже сохранена в бэкенде через API /rag/question
|
||||||
str(message.from_user.id),
|
|
||||||
question_text,
|
|
||||||
answer,
|
|
||||||
sources
|
|
||||||
)
|
|
||||||
|
|
||||||
response = (
|
response = (
|
||||||
f"<b>Ваш вопрос:</b>\n"
|
f"<b>Ваш вопрос:</b>\n"
|
||||||
@ -74,18 +67,10 @@ async def process_premium_question(message: Message, user: User, question_text:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if sources:
|
if sources:
|
||||||
response += f"<b>Источники из коллекций:</b>\n"
|
response += f"<b>Источники:</b>\n"
|
||||||
collections_used = {}
|
for idx, source in enumerate(sources[:5], 1):
|
||||||
for source in sources[:5]:
|
title = source.get('title', 'Без названия')
|
||||||
collection_name = source.get('collection', 'Неизвестно')
|
response += f"{idx}. {title}\n"
|
||||||
if collection_name not in collections_used:
|
|
||||||
collections_used[collection_name] = []
|
|
||||||
collections_used[collection_name].append(source.get('title', 'Без названия'))
|
|
||||||
|
|
||||||
for i, (collection_name, titles) in enumerate(collections_used.items(), 1):
|
|
||||||
response += f"{i}. <b>Коллекция:</b> {collection_name}\n"
|
|
||||||
for title in titles[:2]:
|
|
||||||
response += f" • {title}\n"
|
|
||||||
response += "\n<i>Используйте /mycollections для просмотра всех коллекций</i>\n\n"
|
response += "\n<i>Используйте /mycollections для просмотра всех коллекций</i>\n\n"
|
||||||
|
|
||||||
response += (
|
response += (
|
||||||
@ -122,12 +107,7 @@ async def process_free_question(message: Message, user: User, question_text: str
|
|||||||
answer = rag_result.get("answer", "Извините, не удалось сгенерировать ответ.")
|
answer = rag_result.get("answer", "Извините, не удалось сгенерировать ответ.")
|
||||||
sources = rag_result.get("sources", [])
|
sources = rag_result.get("sources", [])
|
||||||
|
|
||||||
await save_conversation_to_backend(
|
# Уже все сохранили через /rag/question
|
||||||
str(message.from_user.id),
|
|
||||||
question_text,
|
|
||||||
answer,
|
|
||||||
sources
|
|
||||||
)
|
|
||||||
|
|
||||||
response = (
|
response = (
|
||||||
f"<b>Ваш вопрос:</b>\n"
|
f"<b>Ваш вопрос:</b>\n"
|
||||||
@ -136,18 +116,10 @@ async def process_free_question(message: Message, user: User, question_text: str
|
|||||||
)
|
)
|
||||||
|
|
||||||
if sources:
|
if sources:
|
||||||
response += f"<b>Источники из коллекций:</b>\n"
|
response += f"<b>Источники:</b>\n"
|
||||||
collections_used = {}
|
for idx, source in enumerate(sources[:5], 1):
|
||||||
for source in sources[:5]:
|
title = source.get('title', 'Без названия')
|
||||||
collection_name = source.get('collection', 'Неизвестно')
|
response += f"{idx}. {title}\n"
|
||||||
if collection_name not in collections_used:
|
|
||||||
collections_used[collection_name] = []
|
|
||||||
collections_used[collection_name].append(source.get('title', 'Без названия'))
|
|
||||||
|
|
||||||
for i, (collection_name, titles) in enumerate(collections_used.items(), 1):
|
|
||||||
response += f"{i}. <b>Коллекция:</b> {collection_name}\n"
|
|
||||||
for title in titles[:2]:
|
|
||||||
response += f" • {title}\n"
|
|
||||||
response += "\n<i>Используйте /mycollections для просмотра всех коллекций</i>\n\n"
|
response += "\n<i>Используйте /mycollections для просмотра всех коллекций</i>\n\n"
|
||||||
|
|
||||||
response += (
|
response += (
|
||||||
@ -176,83 +148,7 @@ async def process_free_question(message: Message, user: User, question_text: str
|
|||||||
await message.answer(response, parse_mode="HTML")
|
await message.answer(response, parse_mode="HTML")
|
||||||
|
|
||||||
|
|
||||||
async def save_conversation_to_backend(telegram_id: str, question: str, answer: str, sources: list):
    """Store one question/answer exchange in the backend, best effort.

    Steps: confirm the user exists, pick the user's first collection (or
    create a default one), open a conversation in it, then post the
    question and the answer as two messages.

    Args:
        telegram_id: Telegram user id, sent as the ``X-Telegram-ID`` header.
        question: Text stored as the ``user`` message.
        answer: Text stored as the ``assistant`` message.
        sources: Source documents attached to the assistant message.

    Returns:
        None. The function returns early when any step does not succeed;
        exceptions are caught and printed rather than propagated.
    """
    try:
        from tg_bot.config.settings import settings

        backend_url = settings.BACKEND_URL
        auth_headers = {"X-Telegram-ID": telegram_id}

        async with aiohttp.ClientSession() as session:
            # Only the status matters here: an unknown user means nothing to save.
            async with session.get(
                f"{backend_url}/users/telegram/{telegram_id}"
            ) as user_response:
                if user_response.status != 200:
                    return

            collections = []
            async with session.get(
                f"{backend_url}/collections/",
                headers=auth_headers
            ) as collections_response:
                if collections_response.status == 200:
                    collections = await collections_response.json()

            collection_id = None
            if collections:
                collection_id = collections[0].get("collection_id")
            else:
                # The user has no collection yet: create the default one.
                async with session.post(
                    f"{backend_url}/collections",
                    json={
                        "name": "Основная коллекция",
                        "description": "Коллекция по умолчанию",
                        "is_public": False
                    },
                    headers=auth_headers
                ) as create_collection_response:
                    if create_collection_response.status in [200, 201]:
                        collection_data = await create_collection_response.json()
                        collection_id = collection_data.get("collection_id")

            if not collection_id:
                return

            async with session.post(
                f"{backend_url}/conversations",
                json={"collection_id": str(collection_id)},
                headers=auth_headers
            ) as conversation_response:
                if conversation_response.status not in [200, 201]:
                    return
                conversation_data = await conversation_response.json()

            conversation_id = conversation_data.get("conversation_id")
            if not conversation_id:
                return

            message_payloads = (
                {
                    "conversation_id": str(conversation_id),
                    "content": question,
                    "role": "user"
                },
                {
                    "conversation_id": str(conversation_id),
                    "content": answer,
                    "role": "assistant",
                    "sources": {"documents": sources}
                },
            )
            for payload in message_payloads:
                # `async with` releases each response once the request finishes.
                async with session.post(
                    f"{backend_url}/messages",
                    json=payload,
                    headers=auth_headers
                ):
                    pass

    except Exception as e:
        # Saving history is optional; a failure must not break the reply flow.
        print(f"Error saving conversation: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_limit_exceeded(message: Message, user: User):
|
async def handle_limit_exceeded(message: Message, user: User):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user