import hashlib
from typing import Optional
from uuid import UUID

from src.infrastructure.external.redis_client import RedisClient


class CacheService:
    """Cache RAG answers in Redis, keyed by collection and question.

    Entries are stored under ``rag:answer:{collection_id}:{question_hash}``.
    Each value holds both the original question and the answer, so a
    collision on the truncated question hash is detected on read instead of
    returning an answer to a different question.
    """

    # Prefix shared by every answer key this service writes, reads, or deletes.
    _KEY_PREFIX = "rag:answer"

    def __init__(self, redis_client: RedisClient, default_ttl: int = 3600 * 24):
        """Create the service.

        Args:
            redis_client: Async Redis wrapper used for all reads, writes and
                deletes.
            default_ttl: Expiry used by ``cache_answer`` when no ttl is given.
                Presumably seconds (default 3600 * 24, i.e. one day) — confirm
                against ``RedisClient.set_json``.
        """
        self.redis_client = redis_client
        self.default_ttl = default_ttl

    def _make_key(self, collection_id: UUID, question: str) -> str:
        """Build the Redis key for ``question`` within ``collection_id``."""
        # A 16-hex-char (64-bit) SHA-256 prefix keeps keys short. Collisions
        # are caught in get_cached_answer by comparing the stored question.
        question_hash = hashlib.sha256(question.encode("utf-8")).hexdigest()[:16]
        return f"{self._KEY_PREFIX}:{collection_id}:{question_hash}"

    async def get_cached_answer(self, collection_id: UUID, question: str) -> Optional[dict]:
        """Return the cached answer for ``question``, or ``None`` on a miss.

        A hit requires the stored question to match ``question`` exactly. This
        guards against hash-prefix collisions. A missing entry or a malformed
        (non-dict) entry counts as a miss.
        """
        key = self._make_key(collection_id, question)
        cached = await self.redis_client.get_json(key)
        if isinstance(cached, dict) and cached and cached.get("question") == question:
            return cached.get("answer")
        return None

    async def cache_answer(self, collection_id: UUID, question: str, answer: dict, ttl: Optional[int] = None):
        """Store ``answer`` for ``question`` in ``collection_id``.

        Args:
            collection_id: Collection the question was asked against.
            question: The exact question text. It is stored alongside the
                answer for collision checking.
            answer: Answer payload. ``invalidate_document_cache`` looks for an
                optional ``"sources"`` list of dicts carrying ``"document_id"``.
            ttl: Expiry override. Any falsy value (``None`` or ``0``) falls
                back to ``default_ttl``.
        """
        key = self._make_key(collection_id, question)
        value = {
            "question": question,
            "answer": answer,
        }
        await self.redis_client.set_json(key, value, ttl or self.default_ttl)

    async def invalidate_collection_cache(self, collection_id: UUID):
        """Delete every cached answer belonging to ``collection_id``."""
        # NOTE(review): Redis KEYS is O(N) and blocks the server. If
        # RedisClient exposes SCAN, prefer it for large keyspaces.
        keys = await self.redis_client.keys(f"{self._KEY_PREFIX}:{collection_id}:*")
        for key in keys or []:
            await self.redis_client.delete(key)

    async def invalidate_document_cache(self, document_id: UUID):
        """Delete every cached answer, in any collection, that cites ``document_id``.

        Entries whose payload is missing or malformed are skipped rather than
        aborting the sweep.
        """
        target = str(document_id)
        # NOTE(review): this scans every answer key across all collections via
        # KEYS; consider SCAN or a document->keys index if this gets slow.
        keys = await self.redis_client.keys(f"{self._KEY_PREFIX}:*")
        for key in keys or []:
            cached = await self.redis_client.get_json(key)
            if self._answer_cites_document(cached, target):
                await self.redis_client.delete(key)

    @staticmethod
    def _answer_cites_document(cached: object, document_id: str) -> bool:
        """Return True if a cached entry's answer lists ``document_id`` as a source."""
        if not isinstance(cached, dict):
            return False
        answer = cached.get("answer")
        if not isinstance(answer, dict):
            return False
        # "sources" may be absent or explicitly null; treat both as no sources.
        sources = answer.get("sources") or []
        if not isinstance(sources, (list, tuple)):
            return False
        # Compare as strings so both str and UUID-like stored ids match.
        return any(
            isinstance(source, dict)
            and source.get("document_id") is not None
            and str(source.get("document_id")) == document_id
            for source in sources
        )