forked from HSE_team/BetterCallPraskovia
50 lines
1.9 KiB
Python
import hashlib
|
|
from typing import Optional
|
|
from uuid import UUID
|
|
from src.infrastructure.external.redis_client import RedisClient
|
|
|
|
|
|
class CacheService:
    """Redis-backed cache for RAG answers, keyed by collection and question hash."""

    def __init__(self, redis_client: RedisClient, default_ttl: int = 3600 * 24):
        """
        Args:
            redis_client: Async Redis wrapper exposing get_json/set_json/keys/delete.
            default_ttl: Entry lifetime in seconds (default: 24 hours).
        """
        self.redis_client = redis_client
        self.default_ttl = default_ttl

    def _make_key(self, collection_id: UUID, question: str) -> str:
        """Build the Redis key for a (collection, question) pair.

        Only the first 16 hex chars of the SHA-256 digest are used, so the
        full question text is stored with the entry and re-checked on read
        to guard against (unlikely) truncated-hash collisions.
        """
        question_hash = hashlib.sha256(question.encode()).hexdigest()[:16]
        return f"rag:answer:{collection_id}:{question_hash}"

    async def get_cached_answer(self, collection_id: UUID, question: str) -> Optional[dict]:
        """Return the cached answer for *question*, or None on a cache miss.

        The stored question is compared to the requested one to rule out
        false positives from the truncated key hash.
        """
        key = self._make_key(collection_id, question)
        cached = await self.redis_client.get_json(key)
        if cached and cached.get("question") == question:
            return cached.get("answer")
        return None

    async def cache_answer(self, collection_id: UUID, question: str, answer: dict, ttl: Optional[int] = None):
        """Store *answer* for *question*, expiring after *ttl* seconds.

        NOTE(review): `ttl or self.default_ttl` means a caller passing ttl=0
        (falsy) gets the default TTL, not an immediate expiry — confirm this
        is intended before tightening to `ttl is None`.
        """
        key = self._make_key(collection_id, question)
        value = {
            "question": question,
            "answer": answer,
        }
        await self.redis_client.set_json(key, value, ttl or self.default_ttl)

    async def invalidate_collection_cache(self, collection_id: UUID):
        """Delete every cached answer belonging to *collection_id*.

        NOTE(review): pattern-matching via the client's `keys()` is
        O(keyspace) and blocking on a large Redis — prefer SCAN if the
        client supports it.
        """
        pattern = f"rag:answer:{collection_id}:*"
        keys = await self.redis_client.keys(pattern)
        # An empty result iterates zero times; no truthiness guard needed.
        for key in keys:
            await self.redis_client.delete(key)

    async def invalidate_document_cache(self, document_id: UUID):
        """Delete every cached answer whose sources reference *document_id*.

        Scans all cached answers (across all collections) and inspects each
        entry's answer["sources"][*]["document_id"] field.
        """
        pattern = "rag:answer:*"  # plain literal: nothing to interpolate
        keys = await self.redis_client.keys(pattern)
        target = str(document_id)
        for key in keys:
            cached = await self.redis_client.get_json(key)
            if not cached:
                continue
            sources = cached.get("answer", {}).get("sources", [])
            if any(s.get("document_id") == target for s in sources):
                await self.redis_client.delete(key)