# 2025-12-24 03:14:37 +03:00
#
# 50 lines
# 1.9 KiB
# Python

import hashlib
from typing import Optional
from uuid import UUID
from src.infrastructure.external.redis_client import RedisClient
class CacheService:
    """Cache RAG answers through a Redis client, keyed by collection and question.

    Keys have the form ``rag:answer:{collection_id}:{question_hash}`` where the
    hash is the first 16 hex characters of the SHA-256 digest of the question.
    Each stored value keeps the original question next to the answer so that a
    collision of the truncated hash is detected on read instead of being served.
    """

    # Shared prefix of every key written by this service.
    _KEY_PREFIX = "rag:answer"

    def __init__(self, redis_client: "RedisClient", default_ttl: int = 3600 * 24):
        """Store the client and the fallback TTL.

        Args:
            redis_client: Client exposing the awaitable methods ``get_json``,
                ``set_json``, ``keys`` and ``delete`` used below.
            default_ttl: TTL passed to ``set_json`` when the caller gives none
                (presumably seconds, i.e. one day by default -- confirm
                against ``RedisClient.set_json``).
        """
        self.redis_client = redis_client
        self.default_ttl = default_ttl

    def _make_key(self, collection_id: UUID, question: str) -> str:
        """Return the cache key for *question* within *collection_id*."""
        # Truncated digest keeps keys short; collisions are caught on read by
        # comparing the stored question text.
        question_hash = hashlib.sha256(question.encode()).hexdigest()[:16]
        return f"{self._KEY_PREFIX}:{collection_id}:{question_hash}"

    async def get_cached_answer(self, collection_id: UUID, question: str) -> Optional[dict]:
        """Return the cached answer for *question*, or ``None`` on a miss.

        A hit requires both that the key exists and that the stored question
        equals *question* exactly.
        """
        cached = await self.redis_client.get_json(self._make_key(collection_id, question))
        if not cached:
            return None
        if cached.get("question") != question:
            # Same truncated hash, different question: treat as a miss.
            return None
        return cached.get("answer")

    async def cache_answer(
        self,
        collection_id: UUID,
        question: str,
        answer: dict,
        ttl: Optional[int] = None,
    ) -> None:
        """Store *answer* for *question* under the collection's key.

        Args:
            collection_id: Collection the question was asked against.
            question: Exact question text; stored alongside the answer.
            answer: Payload to cache.
            ttl: Expiry passed to ``set_json``. ``None`` (or ``0``) falls back
                to ``default_ttl``.
        """
        key = self._make_key(collection_id, question)
        value = {
            "question": question,
            "answer": answer,
        }
        await self.redis_client.set_json(key, value, ttl or self.default_ttl)

    async def invalidate_collection_cache(self, collection_id: UUID) -> None:
        """Delete every cached answer belonging to *collection_id*."""
        keys = await self.redis_client.keys(f"{self._KEY_PREFIX}:{collection_id}:*")
        for key in keys or ():
            await self.redis_client.delete(key)

    async def invalidate_document_cache(self, document_id: UUID) -> None:
        """Delete every cached answer whose sources reference *document_id*.

        Lists all answer keys across collections, reads each value, and deletes
        those where some entry of ``answer["sources"]`` has a ``document_id``
        equal to ``str(document_id)``. Entries with a missing or malformed
        answer/sources structure are skipped rather than raising.

        NOTE(review): this issues one ``keys`` call plus one ``get_json`` per
        cached answer; if ``keys`` maps to the Redis KEYS command it may be
        costly on a large keyspace -- confirm against ``RedisClient``.
        """
        target = str(document_id)
        keys = await self.redis_client.keys(f"{self._KEY_PREFIX}:*")
        for key in keys or ():
            cached = await self.redis_client.get_json(key)
            if not isinstance(cached, dict):
                continue
            answer = cached.get("answer")
            sources = answer.get("sources") if isinstance(answer, dict) else None
            if not isinstance(sources, (list, tuple)):
                continue
            references_document = any(
                isinstance(source, dict) and source.get("document_id") == target
                for source in sources
            )
            if references_document:
                await self.redis_client.delete(key)