2025-12-22 22:57:14 +03:00

93 lines
3.4 KiB
Python

"""
Use cases для RAG: индексация документов и ответы на вопросы
"""
import logging
from typing import Optional
from uuid import UUID

from src.application.services.rag_service import RAGService
from src.application.services.cache_service import CacheService
from src.domain.repositories.document_repository import IDocumentRepository
from src.domain.repositories.conversation_repository import IConversationRepository
from src.domain.repositories.message_repository import IMessageRepository
from src.domain.entities.message import Message, MessageRole
from src.shared.exceptions import NotFoundError, ForbiddenError


class RAGUseCases:
    """RAG use cases: document indexing and question answering.

    Orchestrates the RAG service, the document/conversation/message
    repositories and an optional answer cache.
    """

    # Used to report cache failures, which are logged rather than propagated.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        rag_service: RAGService,
        document_repo: IDocumentRepository,
        conversation_repo: IConversationRepository,
        message_repo: IMessageRepository,
        cache_service: Optional[CacheService] = None,
    ):
        """Store the collaborators used by the use cases.

        Args:
            rag_service: Service that indexes documents, retrieves chunks and
                generates answers.
            document_repo: Repository used to look documents up by id.
            conversation_repo: Repository used to look conversations up by id.
            message_repo: Repository used to persist user/assistant messages.
            cache_service: Optional answer cache, queried by
                ``(collection_id, question)``. When ``None`` (or otherwise
                falsy) caching is skipped entirely.
        """
        self.rag_service = rag_service
        self.document_repo = document_repo
        self.conversation_repo = conversation_repo
        self.message_repo = message_repo
        self.cache_service = cache_service

    async def index_document(self, document_id: UUID) -> dict:
        """Run RAG indexing for a stored document.

        Args:
            document_id: Id of the document to load and index.

        Returns:
            ``{"chunks_indexed": n}`` where ``n`` is the number of chunks
            returned by ``rag_service.index_document``.

        Raises:
            NotFoundError: If no document with ``document_id`` exists.
        """
        document = await self.document_repo.get_by_id(document_id)
        if not document:
            raise NotFoundError(f"Документ {document_id} не найден")
        chunks = await self.rag_service.index_document(document)
        return {"chunks_indexed": len(chunks)}

    async def ask_question(
        self,
        conversation_id: UUID,
        user_id: UUID,
        question: str,
        top_k: int = 20,
        rerank_top_n: int = 5,
    ) -> dict:
        """Answer ``question`` inside a conversation and record both messages.

        The question is persisted as a USER message first. The answer is
        then taken from the cache when one is configured and returns a
        truthy value; otherwise it is produced by retrieval + generation and
        written back to the cache. The answer is persisted as an ASSISTANT
        message before returning.

        Args:
            conversation_id: Conversation the question belongs to.
            user_id: Caller; must own the conversation.
            question: The user's question text.
            top_k: Number of chunks requested from retrieval.
            rerank_top_n: Passed to retrieval as ``rerank_top_n``.

        Returns:
            A dict with ``"answer"`` (generation content), ``"sources"``
            (defaults to ``[]``) and ``"usage"`` (defaults to ``{}``).

        Raises:
            NotFoundError: If the conversation does not exist.
            ForbiddenError: If the conversation belongs to another user.
        """
        conversation = await self.conversation_repo.get_by_id(conversation_id)
        if not conversation:
            raise NotFoundError(f"Беседа {conversation_id} не найдена")
        if conversation.user_id != user_id:
            raise ForbiddenError("Нет доступа к этой беседе")

        await self.message_repo.create(
            Message(
                conversation_id=conversation_id,
                content=question,
                role=MessageRole.USER,
            )
        )

        collection_id = conversation.collection_id
        # NOTE(review): the cache key is (collection_id, question) only, so a
        # cached answer is reused regardless of top_k / rerank_top_n — confirm
        # this is intended.
        generation = await self._get_cached_answer(collection_id, question)
        if not generation:
            retrieved = await self.rag_service.retrieve(
                query=question,
                collection_id=collection_id,
                limit=top_k,
                rerank_top_n=rerank_top_n,
            )
            # retrieve() yields pairs; only the first element (the chunk) is
            # passed on to generation.
            chunks = [chunk for chunk, _ in retrieved]
            generation = await self.rag_service.generate_answer(question, chunks)
            await self._cache_answer(collection_id, question, generation)

        sources = generation.get("sources", [])
        await self.message_repo.create(
            Message(
                conversation_id=conversation_id,
                content=generation["content"],
                role=MessageRole.ASSISTANT,
                sources={"chunks": sources},
            )
        )
        return {
            "answer": generation["content"],
            "sources": sources,
            "usage": generation.get("usage", {}),
        }

    async def _get_cached_answer(self, collection_id, question: str):
        """Return the cached generation for ``question``, or ``None``.

        A failing cache read is logged and treated as a miss, so a cache
        outage degrades to regeneration instead of failing the request.
        """
        if not self.cache_service:
            return None
        try:
            return await self.cache_service.get_cached_answer(collection_id, question)
        except Exception:
            self._logger.warning(
                "Answer cache read failed; falling back to generation",
                exc_info=True,
            )
            return None

    async def _cache_answer(self, collection_id, question: str, generation) -> None:
        """Store ``generation`` in the cache, if one is configured.

        A failing cache write is logged and ignored so the already generated
        answer is still persisted and returned to the caller.
        """
        if not self.cache_service:
            return
        try:
            await self.cache_service.cache_answer(collection_id, question, generation)
        except Exception:
            self._logger.warning("Answer cache write failed; skipping", exc_info=True)