forked from HSE_team/BetterCallPraskovia
93 lines
3.4 KiB
Python
93 lines
3.4 KiB
Python
"""
|
|
Use cases для RAG: индексация документов и ответы на вопросы
|
|
"""
|
|
from typing import Optional
from uuid import UUID

from src.application.services.rag_service import RAGService
from src.application.services.cache_service import CacheService
from src.domain.repositories.document_repository import IDocumentRepository
from src.domain.repositories.conversation_repository import IConversationRepository
from src.domain.repositories.message_repository import IMessageRepository
from src.domain.entities.message import Message, MessageRole
from src.shared.exceptions import NotFoundError, ForbiddenError
|
|
|
|
|
|
class RAGUseCases:
    """Application-level RAG use cases: document indexing and question answering.

    Orchestrates the RAG service, the document / conversation / message
    repositories and an optional answer cache.
    """

    def __init__(
        self,
        rag_service: RAGService,
        document_repo: IDocumentRepository,
        conversation_repo: IConversationRepository,
        message_repo: IMessageRepository,
        cache_service: Optional[CacheService] = None,
    ):
        """Store the collaborators used by the use cases.

        Args:
            rag_service: Service used to index documents, retrieve chunks and
                generate answers.
            document_repo: Repository used to load documents by id.
            conversation_repo: Repository used to load conversations by id.
            message_repo: Repository used to persist user and assistant messages.
            cache_service: Optional answer cache. When ``None``, every question
                goes through retrieval and generation and nothing is cached.
        """
        self.rag_service = rag_service
        self.document_repo = document_repo
        self.conversation_repo = conversation_repo
        self.message_repo = message_repo
        self.cache_service = cache_service

    async def index_document(self, document_id: UUID) -> dict:
        """Index a single document through the RAG service.

        Args:
            document_id: Id of the document to index.

        Returns:
            ``{"chunks_indexed": n}``, where ``n`` is the length of the chunk
            collection returned by ``rag_service.index_document``.

        Raises:
            NotFoundError: If the document repository returns nothing for the id.
        """
        document = await self.document_repo.get_by_id(document_id)
        if not document:
            raise NotFoundError(f"Документ {document_id} не найден")
        chunks = await self.rag_service.index_document(document)
        return {"chunks_indexed": len(chunks)}

    async def ask_question(
        self,
        conversation_id: UUID,
        user_id: UUID,
        question: str,
        top_k: int = 20,
        rerank_top_n: int = 5,
    ) -> dict:
        """Answer a question inside a conversation and store both messages.

        Args:
            conversation_id: Conversation the question belongs to.
            user_id: Caller; must own the conversation.
            question: The user's question text.
            top_k: Retrieval limit passed to ``rag_service.retrieve``.
            rerank_top_n: Rerank limit passed to ``rag_service.retrieve``.

        Returns:
            A dict with ``"answer"`` (the generation's ``"content"``),
            ``"sources"`` (defaults to ``[]``) and ``"usage"`` (defaults to ``{}``).

        Raises:
            NotFoundError: If the conversation does not exist.
            ForbiddenError: If the conversation belongs to another user.
        """
        conversation = await self._get_owned_conversation(conversation_id, user_id)

        # The user's message is saved before any cache lookup or generation.
        user_message = Message(
            conversation_id=conversation_id, content=question, role=MessageRole.USER
        )
        await self.message_repo.create(user_message)

        generation = await self._get_or_generate_answer(
            conversation.collection_id, question, top_k, rerank_top_n
        )
        sources = generation.get("sources", [])

        assistant_message = Message(
            conversation_id=conversation_id,
            content=generation["content"],
            role=MessageRole.ASSISTANT,
            sources={"chunks": sources},
        )
        await self.message_repo.create(assistant_message)

        return {
            "answer": generation["content"],
            "sources": sources,
            "usage": generation.get("usage", {}),
        }

    async def _get_owned_conversation(self, conversation_id: UUID, user_id: UUID):
        """Load a conversation and verify that ``user_id`` owns it.

        Raises:
            NotFoundError: If the conversation does not exist.
            ForbiddenError: If ``conversation.user_id`` differs from ``user_id``.
        """
        conversation = await self.conversation_repo.get_by_id(conversation_id)
        if not conversation:
            raise NotFoundError(f"Беседа {conversation_id} не найдена")
        if conversation.user_id != user_id:
            raise ForbiddenError("Нет доступа к этой беседе")
        return conversation

    async def _get_or_generate_answer(
        self, collection_id, question: str, top_k: int, rerank_top_n: int
    ) -> dict:
        """Return a cached answer, or retrieve + generate one and cache it.

        A cache hit is returned as-is and is not written back to the cache.
        """
        cache = self.cache_service

        if cache is not None:
            cached_answer = await cache.get_cached_answer(collection_id, question)
            # Truthiness (not `is not None`) is deliberate: an empty cached
            # payload cannot supply "content", so it is treated as a miss.
            if cached_answer:
                return cached_answer

        retrieved = await self.rag_service.retrieve(
            query=question,
            collection_id=collection_id,
            limit=top_k,
            rerank_top_n=rerank_top_n,
        )
        # retrieve() returns 2-tuples; only the first item (the chunk) is used.
        # The second item is presumably a relevance score — confirm in RAGService.
        chunks = [chunk for chunk, _ in retrieved]
        generation = await self.rag_service.generate_answer(question, chunks)

        if cache is not None:
            # NOTE(review): the cache key is (collection_id, question) only, so
            # an answer produced with one top_k / rerank_top_n is reused for
            # other values — confirm this is intended.
            await cache.cache_answer(collection_id, question, generation)

        return generation
|
|
|