diff --git a/backend/requirements.txt b/backend/requirements.txt index 4ec21aa..69742a0 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -13,3 +13,4 @@ dishka==0.7.0 numpy==1.26.4 sentence-transformers==2.7.0 qdrant-client==1.9.0 +redis==5.0.1 diff --git a/backend/src/application/services/cache_service.py b/backend/src/application/services/cache_service.py new file mode 100644 index 0000000..5d3c561 --- /dev/null +++ b/backend/src/application/services/cache_service.py @@ -0,0 +1,49 @@ +import hashlib +from typing import Optional +from uuid import UUID +from src.infrastructure.external.redis_client import RedisClient + + +class CacheService: + def __init__(self, redis_client: RedisClient, default_ttl: int = 3600 * 24): + self.redis_client = redis_client + self.default_ttl = default_ttl + + def _make_key(self, collection_id: UUID, question: str) -> str: + question_hash = hashlib.sha256(question.encode()).hexdigest()[:16] + return f"rag:answer:{collection_id}:{question_hash}" + + async def get_cached_answer(self, collection_id: UUID, question: str) -> Optional[dict]: + key = self._make_key(collection_id, question) + cached = await self.redis_client.get_json(key) + if cached and cached.get("question") == question: + return cached.get("answer") + return None + + async def cache_answer(self, collection_id: UUID, question: str, answer: dict, ttl: Optional[int] = None): + key = self._make_key(collection_id, question) + value = { + "question": question, + "answer": answer + } + await self.reids_client.set_json(key, value, ttl or self.default_ttl) + + async def invalidate_collection_cache(self, collection_id: UUID): + pattern = f"rag:answer:{collection_id}:*" + keys = await self.redis_client.keys(pattern) + if keys: + for key in keys: + await self.redis_client.delete(key) + + async def invalidate_document_cache(self, document_id: UUID): + pattern = f"rag:answer:*" + keys = await self.redis_client.keys(pattern) + if keys: + for key in keys: + cached = await self.redis_client.get_json(key) + if cached: + sources = cached.get("answer", {}).get("sources", []) + doc_ids = [s.get("document_id") for s in sources] + if str(document_id) in doc_ids: + await self.redis_client.delete(key) + diff --git a/backend/src/application/use_cases/rag_use_cases.py b/backend/src/application/use_cases/rag_use_cases.py index 427c146..e0a2ce5 100644 --- a/backend/src/application/use_cases/rag_use_cases.py +++ b/backend/src/application/use_cases/rag_use_cases.py @@ -3,6 +3,7 @@ Use cases для RAG: индексация документов и ответы """ from uuid import UUID from src.application.services.rag_service import RAGService +from src.application.services.cache_service import CacheService from src.domain.repositories.document_repository import IDocumentRepository from src.domain.repositories.conversation_repository import IConversationRepository from src.domain.repositories.message_repository import IMessageRepository @@ -18,11 +19,13 @@ class RAGUseCases: document_repo: IDocumentRepository, conversation_repo: IConversationRepository, message_repo: IMessageRepository, + cache_service: CacheService, ): self.rag_service = rag_service self.document_repo = document_repo self.conversation_repo = conversation_repo self.message_repo = message_repo + self.cache_service = cache_service async def index_document(self, document_id: UUID) -> dict: document = await self.document_repo.get_by_id(document_id) @@ -50,14 +53,28 @@ class RAGUseCases: ) await self.message_repo.create(user_message) - retrieved = await self.rag_service.retrieve( - query=question, - collection_id=conversation.collection_id, - limit=top_k, - rerank_top_n=rerank_top_n, - ) - chunks = [c for c, _ in retrieved] - generation = await self.rag_service.generate_answer(question, chunks) + cached_answer = None + if self.cache_service: + cached_answer = await self.cache_service.get_cached_answer(conversation.collection_id, question) + + if cached_answer: + generation = cached_answer + else: + retrieved = await self.rag_service.retrieve( + query=question, + collection_id=conversation.collection_id, + limit=top_k, + rerank_top_n=rerank_top_n, + ) + chunks = [c for c, _ in retrieved] + generation = await self.rag_service.generate_answer(question, chunks) + + if self.cache_service: + await self.cache_service.cache_answer( + conversation.collection_id, + question, + generation + ) assistant_message = Message( conversation_id=conversation_id, diff --git a/backend/src/infrastructure/external/redis_client.py b/backend/src/infrastructure/external/redis_client.py new file mode 100644 index 0000000..db7839b --- /dev/null +++ b/backend/src/infrastructure/external/redis_client.py @@ -0,0 +1,76 @@ +import json +from typing import Optional, Any +import redis.asyncio as aioredis +from src.shared.config import settings + + +class RedisClient: + def __init__(self, host: str, port: int): + self.host = host or settings.REDIS_HOST + self.port = port or settings.REDIS_PORT + self._client: Optional[aioredis.Redis] = None + + async def connect(self): + if self._client is None: + self._client = await aioredis.from_url( + f"redis://{self.host}:{self.port}", + encoding="utf-8", + decode_responses=True + ) + + async def disconnect(self): + if self._client: + await self._client.aclose() + self._client = None + + async def get(self, key: str) -> Optional[str]: + if self._client is None: + await self.connect() + return await self._client.get(key) + + async def set(self, key: str, value: str, ttl: Optional[int] = None): + if self._client is None: + await self.connect() + if ttl: + await self._client.setex(key, ttl, value) + else: + await self._client.set(key, value) + + async def get_json(self, key: str) -> Optional[dict[str, Any]]: + value = await self.get(key) + if value is None: + return None + try: + return json.loads(value) + except json.JSONDecodeError: + return None + + async def set_json(self, key: str, value: dict[str, Any], ttl: Optional[int] = None): + json_str = json.dumps(value) + await self.set(key, json_str, ttl) + + async def delete(self, key: str): + if self._client is None: + await self.connect() + await self._client.delete(key) + + async def exists(self, key: str) -> bool: + if self._client is None: + await self.connect() + return bool(await self._client.exists(key)) + + async def incr(self, key: str) -> int: + if self._client is None: + await self.connect() + return await self._client.incr(key) + + async def expire(self, key: str, seconds: int): + if self._client is None: + await self.connect() + await self._client.expire(key, seconds) + + async def keys(self, pattern: str) -> list[str]: + if self._client is None: + await self.connect() + return await self._client.keys(pattern) + diff --git a/backend/src/presentation/main.py b/backend/src/presentation/main.py index b84c6e6..05122c5 100644 --- a/backend/src/presentation/main.py +++ b/backend/src/presentation/main.py @@ -13,7 +13,7 @@ from src.shared.config import settings from src.shared.exceptions import LawyerAIException from src.shared.di_container import create_container from src.presentation.middleware.error_handler import exception_handler -from src.presentation.api.v1 import users, collections, documents, conversations, messages +from src.presentation.api.v1 import users, collections, documents, conversations, messages, rag from src.infrastructure.database.base import engine, Base diff --git a/backend/src/shared/di_container.py b/backend/src/shared/di_container.py index 6f7db9e..a5589e0 100644 --- a/backend/src/shared/di_container.py +++ b/backend/src/shared/di_container.py @@ -19,7 +19,9 @@ from src.domain.repositories.collection_access_repository import ICollectionAcce from src.domain.repositories.vector_repository import IVectorRepository from src.infrastructure.external.yandex_ocr import YandexOCRService from src.infrastructure.external.deepseek_client import DeepSeekClient +from src.infrastructure.external.redis_client import RedisClient from src.application.services.document_parser_service import DocumentParserService +from src.application.services.cache_service import CacheService from src.application.use_cases.user_use_cases import UserUseCases from src.application.use_cases.collection_use_cases import CollectionUseCases from src.application.use_cases.document_use_cases import DocumentUseCases @@ -73,6 +75,14 @@ class RepositoryProvider(Provider): class ServiceProvider(Provider): + @provide(scope=Scope.APP) + def get_redis_client(self) -> RedisClient: + return RedisClient() + + @provide(scope=Scope.APP) + def get_cache_service(self, redis_client: RedisClient) -> CacheService: + return CacheService(redis_client) + @provide(scope=Scope.APP) def get_ocr_service(self) -> YandexOCRService: return YandexOCRService() @@ -180,9 +190,10 @@ class UseCaseProvider(Provider): rag_service: RAGService, document_repo: IDocumentRepository, conversation_repo: IConversationRepository, - message_repo: IMessageRepository + message_repo: IMessageRepository, + cache_service: CacheService ) -> RAGUseCases: - return RAGUseCases(rag_service, document_repo, conversation_repo, message_repo) + return RAGUseCases(rag_service, document_repo, conversation_repo, message_repo, cache_service) def create_container() -> Container: