forked from HSE_team/BetterCallPraskovia
redis include + client with bis logic
This commit is contained in:
parent
d18cc1fb76
commit
74510ce406
@ -13,3 +13,4 @@ dishka==0.7.0
|
|||||||
numpy==1.26.4
|
numpy==1.26.4
|
||||||
sentence-transformers==2.7.0
|
sentence-transformers==2.7.0
|
||||||
qdrant-client==1.9.0
|
qdrant-client==1.9.0
|
||||||
|
redis==5.0.1
|
||||||
|
|||||||
49
backend/src/application/services/cache_service.py
Normal file
49
backend/src/application/services/cache_service.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import hashlib
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
from src.infrastructure.external.redis_client import RedisClient
|
||||||
|
|
||||||
|
|
||||||
|
class CacheService:
    """Cache of RAG answers stored through a Redis client.

    Entries live under ``rag:answer:<collection_id>:<question_hash>`` and hold
    a JSON object with the original ``question`` and the generated ``answer``.
    """

    def __init__(self, redis_client: "RedisClient", default_ttl: int = 3600 * 24):
        """Store the client and the TTL (in seconds) used when none is given.

        Args:
            redis_client: Client exposing ``get_json``, ``set_json``, ``keys``
                and ``delete`` coroutines.
            default_ttl: Expiry applied by ``cache_answer`` when its ``ttl``
                argument is omitted; defaults to 24 hours.
        """
        self.redis_client = redis_client
        self.default_ttl = default_ttl

    def _make_key(self, collection_id: UUID, question: str) -> str:
        """Build the cache key for a question inside a collection."""
        # Only the first 16 hex digits of the SHA-256 digest go into the key.
        question_hash = hashlib.sha256(question.encode()).hexdigest()[:16]
        return f"rag:answer:{collection_id}:{question_hash}"

    async def get_cached_answer(self, collection_id: UUID, question: str) -> Optional[dict]:
        """Return the cached answer for ``question`` or ``None`` on a miss.

        The stored question is compared with the requested one, so an entry
        that merely shares the truncated hash is treated as a miss.
        """
        key = self._make_key(collection_id, question)
        cached = await self.redis_client.get_json(key)
        if isinstance(cached, dict) and cached.get("question") == question:
            return cached.get("answer")
        return None

    async def cache_answer(self, collection_id: UUID, question: str, answer: dict, ttl: Optional[int] = None):
        """Store ``answer`` for ``question``.

        Args:
            collection_id: Collection the question was asked against.
            question: Exact question text; saved alongside the answer.
            answer: Answer payload; must be JSON-serialisable.
            ttl: Expiry in seconds; a falsy value selects ``default_ttl``.
        """
        key = self._make_key(collection_id, question)
        value = {
            "question": question,
            "answer": answer,
        }
        # Fixed: the original wrote through the misspelled ``self.reids_client``,
        # which raised AttributeError on every cache write.
        await self.redis_client.set_json(key, value, ttl or self.default_ttl)

    async def invalidate_collection_cache(self, collection_id: UUID):
        """Delete every cached answer that belongs to ``collection_id``."""
        pattern = f"rag:answer:{collection_id}:*"
        for key in await self.redis_client.keys(pattern):
            await self.redis_client.delete(key)

    async def invalidate_document_cache(self, document_id: UUID):
        """Delete cached answers whose sources reference ``document_id``.

        Every ``rag:answer:*`` entry is loaded and inspected, because the key
        itself does not encode which documents contributed to the answer.
        Entries without a dict ``answer`` or without ``sources`` are skipped.
        """
        target = str(document_id)
        for key in await self.redis_client.keys("rag:answer:*"):
            cached = await self.redis_client.get_json(key)
            if not isinstance(cached, dict):
                continue
            answer = cached.get("answer")
            if not isinstance(answer, dict):
                # A missing or null answer has no sources to match against.
                continue
            sources = answer.get("sources") or []
            referenced = any(
                isinstance(source, dict) and source.get("document_id") == target
                for source in sources
            )
            if referenced:
                await self.redis_client.delete(key)
|
||||||
|
|
||||||
@ -3,6 +3,7 @@ Use cases для RAG: индексация документов и ответы
|
|||||||
"""
|
"""
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from src.application.services.rag_service import RAGService
|
from src.application.services.rag_service import RAGService
|
||||||
|
from src.application.services.cache_service import CacheService
|
||||||
from src.domain.repositories.document_repository import IDocumentRepository
|
from src.domain.repositories.document_repository import IDocumentRepository
|
||||||
from src.domain.repositories.conversation_repository import IConversationRepository
|
from src.domain.repositories.conversation_repository import IConversationRepository
|
||||||
from src.domain.repositories.message_repository import IMessageRepository
|
from src.domain.repositories.message_repository import IMessageRepository
|
||||||
@ -18,11 +19,13 @@ class RAGUseCases:
|
|||||||
document_repo: IDocumentRepository,
|
document_repo: IDocumentRepository,
|
||||||
conversation_repo: IConversationRepository,
|
conversation_repo: IConversationRepository,
|
||||||
message_repo: IMessageRepository,
|
message_repo: IMessageRepository,
|
||||||
|
cache_service: CacheService,
|
||||||
):
|
):
|
||||||
self.rag_service = rag_service
|
self.rag_service = rag_service
|
||||||
self.document_repo = document_repo
|
self.document_repo = document_repo
|
||||||
self.conversation_repo = conversation_repo
|
self.conversation_repo = conversation_repo
|
||||||
self.message_repo = message_repo
|
self.message_repo = message_repo
|
||||||
|
self.cache_service = cache_service
|
||||||
|
|
||||||
async def index_document(self, document_id: UUID) -> dict:
|
async def index_document(self, document_id: UUID) -> dict:
|
||||||
document = await self.document_repo.get_by_id(document_id)
|
document = await self.document_repo.get_by_id(document_id)
|
||||||
@ -50,14 +53,28 @@ class RAGUseCases:
|
|||||||
)
|
)
|
||||||
await self.message_repo.create(user_message)
|
await self.message_repo.create(user_message)
|
||||||
|
|
||||||
retrieved = await self.rag_service.retrieve(
|
cached_answer = None
|
||||||
query=question,
|
if self.cache_service:
|
||||||
collection_id=conversation.collection_id,
|
cached_answer = await self.cache_service.get_cached_answer(conversation.collection_id, question)
|
||||||
limit=top_k,
|
|
||||||
rerank_top_n=rerank_top_n,
|
if cached_answer:
|
||||||
)
|
generation = cached_answer
|
||||||
chunks = [c for c, _ in retrieved]
|
else:
|
||||||
generation = await self.rag_service.generate_answer(question, chunks)
|
retrieved = await self.rag_service.retrieve(
|
||||||
|
query=question,
|
||||||
|
collection_id=conversation.collection_id,
|
||||||
|
limit=top_k,
|
||||||
|
rerank_top_n=rerank_top_n,
|
||||||
|
)
|
||||||
|
chunks = [c for c, _ in retrieved]
|
||||||
|
generation = await self.rag_service.generate_answer(question, chunks)
|
||||||
|
|
||||||
|
if self.cache_service:
|
||||||
|
await self.cache_service.cache_answer(
|
||||||
|
conversation.collection_id,
|
||||||
|
question,
|
||||||
|
generation
|
||||||
|
)
|
||||||
|
|
||||||
assistant_message = Message(
|
assistant_message = Message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
|
|||||||
76
backend/src/infrastructure/external/redis_client.py
vendored
Normal file
76
backend/src/infrastructure/external/redis_client.py
vendored
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
import json
|
||||||
|
from typing import Optional, Any
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
from src.shared.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class RedisClient:
    """Small async wrapper around ``redis.asyncio`` with lazy connection.

    The underlying client is created on first use by any command method, so
    callers do not have to call ``connect`` explicitly.
    """

    def __init__(self, host: Optional[str] = None, port: Optional[int] = None):
        """Remember the connection target without opening a connection.

        Args:
            host: Redis host; a falsy value selects ``settings.REDIS_HOST``.
            port: Redis port; a falsy value selects ``settings.REDIS_PORT``.
        """
        # Both arguments are optional so that ``RedisClient()`` (as used by the
        # DI provider) works and relies on the configured defaults.
        self.host = host or settings.REDIS_HOST
        self.port = port or settings.REDIS_PORT
        self._client: Optional[aioredis.Redis] = None

    async def connect(self):
        """Create the underlying client if it does not exist yet."""
        if self._client is None:
            self._client = await aioredis.from_url(
                f"redis://{self.host}:{self.port}",
                encoding="utf-8",
                decode_responses=True,
            )

    async def disconnect(self):
        """Close the underlying client, if any, and forget it."""
        if self._client:
            await self._client.aclose()
            self._client = None

    async def _connected(self):
        """Return the underlying client, connecting first when needed."""
        if self._client is None:
            await self.connect()
        return self._client

    async def get(self, key: str) -> Optional[str]:
        """Return the value stored at ``key`` as given by the client's ``get``."""
        client = await self._connected()
        return await client.get(key)

    async def set(self, key: str, value: str, ttl: Optional[int] = None):
        """Store ``value`` at ``key``.

        A truthy ``ttl`` (seconds) stores the value with an expiry via
        ``setex``; otherwise the value is stored without one.
        """
        client = await self._connected()
        if ttl:
            await client.setex(key, ttl, value)
        else:
            await client.set(key, value)

    async def get_json(self, key: str) -> Optional[dict[str, Any]]:
        """Return the JSON-decoded value at ``key``.

        Returns ``None`` when the key is absent or its value is not valid JSON.
        """
        value = await self.get(key)
        if value is None:
            return None
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return None

    async def set_json(self, key: str, value: dict[str, Any], ttl: Optional[int] = None):
        """Serialise ``value`` to JSON and store it, optionally with a TTL."""
        await self.set(key, json.dumps(value), ttl)

    async def delete(self, key: str):
        """Delete ``key``."""
        client = await self._connected()
        await client.delete(key)

    async def exists(self, key: str) -> bool:
        """Return whether ``key`` exists."""
        client = await self._connected()
        return bool(await client.exists(key))

    async def incr(self, key: str) -> int:
        """Increment the counter at ``key`` and return the client's result."""
        client = await self._connected()
        return await client.incr(key)

    async def expire(self, key: str, seconds: int):
        """Set the expiry of ``key`` to ``seconds``."""
        client = await self._connected()
        await client.expire(key, seconds)

    async def keys(self, pattern: str) -> list[str]:
        """Return the keys matching ``pattern``.

        NOTE(review): this issues the client's ``keys`` call, which walks the
        whole keyspace; consider an incremental SCAN if the dataset grows.
        """
        client = await self._connected()
        return await client.keys(pattern)
|
||||||
|
|
||||||
@ -13,7 +13,7 @@ from src.shared.config import settings
|
|||||||
from src.shared.exceptions import LawyerAIException
|
from src.shared.exceptions import LawyerAIException
|
||||||
from src.shared.di_container import create_container
|
from src.shared.di_container import create_container
|
||||||
from src.presentation.middleware.error_handler import exception_handler
|
from src.presentation.middleware.error_handler import exception_handler
|
||||||
from src.presentation.api.v1 import users, collections, documents, conversations, messages
|
from src.presentation.api.v1 import users, collections, documents, conversations, messages, rag
|
||||||
from src.infrastructure.database.base import engine, Base
|
from src.infrastructure.database.base import engine, Base
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,9 @@ from src.domain.repositories.collection_access_repository import ICollectionAcce
|
|||||||
from src.domain.repositories.vector_repository import IVectorRepository
|
from src.domain.repositories.vector_repository import IVectorRepository
|
||||||
from src.infrastructure.external.yandex_ocr import YandexOCRService
|
from src.infrastructure.external.yandex_ocr import YandexOCRService
|
||||||
from src.infrastructure.external.deepseek_client import DeepSeekClient
|
from src.infrastructure.external.deepseek_client import DeepSeekClient
|
||||||
|
from src.infrastructure.external.redis_client import RedisClient
|
||||||
from src.application.services.document_parser_service import DocumentParserService
|
from src.application.services.document_parser_service import DocumentParserService
|
||||||
|
from src.application.services.cache_service import CacheService
|
||||||
from src.application.use_cases.user_use_cases import UserUseCases
|
from src.application.use_cases.user_use_cases import UserUseCases
|
||||||
from src.application.use_cases.collection_use_cases import CollectionUseCases
|
from src.application.use_cases.collection_use_cases import CollectionUseCases
|
||||||
from src.application.use_cases.document_use_cases import DocumentUseCases
|
from src.application.use_cases.document_use_cases import DocumentUseCases
|
||||||
@ -73,6 +75,14 @@ class RepositoryProvider(Provider):
|
|||||||
|
|
||||||
|
|
||||||
class ServiceProvider(Provider):
|
class ServiceProvider(Provider):
|
||||||
|
@provide(scope=Scope.APP)
def get_redis_client(self) -> RedisClient:
    """Provide the application-wide Redis client."""
    # Pass None explicitly so the client falls back to the Redis host and
    # port from settings; a bare ``RedisClient()`` call fails when the
    # constructor declares these parameters as required.
    return RedisClient(host=None, port=None)
|
||||||
|
|
||||||
|
@provide(scope=Scope.APP)
def get_cache_service(self, redis_client: RedisClient) -> CacheService:
    """Provide the application-wide answer cache built on the Redis client."""
    cache_service = CacheService(redis_client)
    return cache_service
|
||||||
|
|
||||||
@provide(scope=Scope.APP)
|
@provide(scope=Scope.APP)
|
||||||
def get_ocr_service(self) -> YandexOCRService:
|
def get_ocr_service(self) -> YandexOCRService:
|
||||||
return YandexOCRService()
|
return YandexOCRService()
|
||||||
@ -180,9 +190,10 @@ class UseCaseProvider(Provider):
|
|||||||
rag_service: RAGService,
|
rag_service: RAGService,
|
||||||
document_repo: IDocumentRepository,
|
document_repo: IDocumentRepository,
|
||||||
conversation_repo: IConversationRepository,
|
conversation_repo: IConversationRepository,
|
||||||
message_repo: IMessageRepository
|
message_repo: IMessageRepository,
|
||||||
|
cache_service: CacheService
|
||||||
) -> RAGUseCases:
|
) -> RAGUseCases:
|
||||||
return RAGUseCases(rag_service, document_repo, conversation_repo, message_repo)
|
return RAGUseCases(rag_service, document_repo, conversation_repo, message_repo, cache_service)
|
||||||
|
|
||||||
|
|
||||||
def create_container() -> Container:
|
def create_container() -> Container:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user