add redis dependency + async client, with answer caching wired into the business logic

This commit is contained in:
Arxip222 2025-12-22 22:57:14 +03:00
parent d18cc1fb76
commit 74510ce406
6 changed files with 165 additions and 11 deletions

View File

@@ -13,3 +13,4 @@ dishka==0.7.0
numpy==1.26.4
sentence-transformers==2.7.0
qdrant-client==1.9.0
redis==5.0.1
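
Note: redis-py 5.x already ships the asyncio client (redis.asyncio) that the new RedisClient below imports, so no separate aioredis package is needed. A quick self-contained sanity check, assuming a local Redis listening on the default port:

import asyncio
import redis.asyncio as aioredis

async def main() -> None:
    client = aioredis.from_url("redis://localhost:6379", decode_responses=True)
    print(await client.ping())  # True when the server is reachable
    await client.aclose()

asyncio.run(main())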

View File

@@ -0,0 +1,49 @@
import hashlib
from typing import Optional
from uuid import UUID

from src.infrastructure.external.redis_client import RedisClient


class CacheService:
    def __init__(self, redis_client: RedisClient, default_ttl: int = 3600 * 24):
        self.redis_client = redis_client
        self.default_ttl = default_ttl

    def _make_key(self, collection_id: UUID, question: str) -> str:
        # Hash the question so keys stay short and safe for arbitrary text
        question_hash = hashlib.sha256(question.encode()).hexdigest()[:16]
        return f"rag:answer:{collection_id}:{question_hash}"

    async def get_cached_answer(self, collection_id: UUID, question: str) -> Optional[dict]:
        key = self._make_key(collection_id, question)
        cached = await self.redis_client.get_json(key)
        # Compare the stored question to guard against hash collisions
        if cached and cached.get("question") == question:
            return cached.get("answer")
        return None

    async def cache_answer(self, collection_id: UUID, question: str, answer: dict, ttl: Optional[int] = None):
        key = self._make_key(collection_id, question)
        value = {
            "question": question,
            "answer": answer
        }
        await self.redis_client.set_json(key, value, ttl or self.default_ttl)

    async def invalidate_collection_cache(self, collection_id: UUID):
        pattern = f"rag:answer:{collection_id}:*"
        keys = await self.redis_client.keys(pattern)
        if keys:
            for key in keys:
                await self.redis_client.delete(key)

    async def invalidate_document_cache(self, document_id: UUID):
        # No per-document index is kept, so every cached answer is inspected
        # and dropped when its sources reference the deleted document
        pattern = "rag:answer:*"
        keys = await self.redis_client.keys(pattern)
        if keys:
            for key in keys:
                cached = await self.redis_client.get_json(key)
                if cached:
                    sources = cached.get("answer", {}).get("sources", [])
                    doc_ids = [s.get("document_id") for s in sources]
                    if str(document_id) in doc_ids:
                        await self.redis_client.delete(key)
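
Note: a minimal usage sketch of CacheService, assuming a reachable Redis configured through settings.REDIS_HOST / settings.REDIS_PORT; the question text and answer payload are made up for illustration:

import asyncio
from uuid import uuid4

from src.infrastructure.external.redis_client import RedisClient
from src.application.services.cache_service import CacheService

async def main() -> None:
    cache = CacheService(RedisClient())
    collection_id = uuid4()
    answer = {"content": "Sample answer", "sources": [{"document_id": "doc-1"}]}

    await cache.cache_answer(collection_id, "What does clause 5 mean?", answer)
    # Returns the stored answer dict, or None on a miss or question mismatch
    print(await cache.get_cached_answer(collection_id, "What does clause 5 mean?"))

asyncio.run(main())

One design note: both invalidate_* methods rely on KEYS, which walks the whole keyspace and blocks Redis while it runs; SCAN-based iteration is the usual alternative once the cache grows.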

View File

@@ -3,6 +3,7 @@ Use cases for RAG: document indexing and answering
"""
from uuid import UUID
from src.application.services.rag_service import RAGService
from src.application.services.cache_service import CacheService
from src.domain.repositories.document_repository import IDocumentRepository
from src.domain.repositories.conversation_repository import IConversationRepository
from src.domain.repositories.message_repository import IMessageRepository
@@ -18,11 +19,13 @@ class RAGUseCases:
        document_repo: IDocumentRepository,
        conversation_repo: IConversationRepository,
        message_repo: IMessageRepository,
        cache_service: CacheService,
    ):
        self.rag_service = rag_service
        self.document_repo = document_repo
        self.conversation_repo = conversation_repo
        self.message_repo = message_repo
        self.cache_service = cache_service

    async def index_document(self, document_id: UUID) -> dict:
        document = await self.document_repo.get_by_id(document_id)
@@ -50,6 +53,13 @@
        )
        await self.message_repo.create(user_message)

        cached_answer = None
        if self.cache_service:
            cached_answer = await self.cache_service.get_cached_answer(conversation.collection_id, question)

        if cached_answer:
            generation = cached_answer
        else:
            retrieved = await self.rag_service.retrieve(
                query=question,
                collection_id=conversation.collection_id,
@@ -59,6 +69,13 @@
            chunks = [c for c, _ in retrieved]
            generation = await self.rag_service.generate_answer(question, chunks)

            if self.cache_service:
                await self.cache_service.cache_answer(
                    conversation.collection_id,
                    question,
                    generation
                )

        assistant_message = Message(
            conversation_id=conversation_id,
            content=generation["content"],
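
Note: the block above is a read-through cache: a previously generated answer for the same collection and question is reused, otherwise the answer is generated and then stored. No caller of the invalidation helpers is visible in this diff; a hypothetical wiring from a document-deletion use case (the delete_document method, document_repo.delete call, and cache_service attribute are assumptions, not part of this commit) could look like:

from uuid import UUID

class DocumentUseCases:  # hypothetical placement, mirroring the existing use-case classes
    ...

    async def delete_document(self, document_id: UUID) -> None:
        await self.document_repo.delete(document_id)
        if self.cache_service:
            # Drop cached answers whose sources cite the deleted document
            await self.cache_service.invalidate_document_cache(document_id)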

View File

@@ -0,0 +1,76 @@
import json
from typing import Optional, Any

import redis.asyncio as aioredis

from src.shared.config import settings


class RedisClient:
    def __init__(self, host: Optional[str] = None, port: Optional[int] = None):
        # Fall back to settings so the client can be constructed with no arguments
        self.host = host or settings.REDIS_HOST
        self.port = port or settings.REDIS_PORT
        self._client: Optional[aioredis.Redis] = None

    async def connect(self):
        if self._client is None:
            self._client = await aioredis.from_url(
                f"redis://{self.host}:{self.port}",
                encoding="utf-8",
                decode_responses=True
            )

    async def disconnect(self):
        if self._client:
            await self._client.aclose()
            self._client = None

    async def get(self, key: str) -> Optional[str]:
        if self._client is None:
            await self.connect()
        return await self._client.get(key)

    async def set(self, key: str, value: str, ttl: Optional[int] = None):
        if self._client is None:
            await self.connect()
        if ttl:
            await self._client.setex(key, ttl, value)
        else:
            await self._client.set(key, value)

    async def get_json(self, key: str) -> Optional[dict[str, Any]]:
        value = await self.get(key)
        if value is None:
            return None
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return None

    async def set_json(self, key: str, value: dict[str, Any], ttl: Optional[int] = None):
        json_str = json.dumps(value)
        await self.set(key, json_str, ttl)

    async def delete(self, key: str):
        if self._client is None:
            await self.connect()
        await self._client.delete(key)

    async def exists(self, key: str) -> bool:
        if self._client is None:
            await self.connect()
        return bool(await self._client.exists(key))

    async def incr(self, key: str) -> int:
        if self._client is None:
            await self.connect()
        return await self._client.incr(key)

    async def expire(self, key: str, seconds: int):
        if self._client is None:
            await self.connect()
        await self._client.expire(key, seconds)

    async def keys(self, pattern: str) -> list[str]:
        if self._client is None:
            await self.connect()
        return await self._client.keys(pattern)
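
Note: every read/write method connects lazily, so only shutdown needs explicit wiring. A sketch of one way to close the client when the application stops, assuming a FastAPI app object as in main.py (this hook is not part of the diff):

from fastapi import FastAPI
from src.infrastructure.external.redis_client import RedisClient

app = FastAPI()
redis_client = RedisClient()  # host/port fall back to settings.REDIS_HOST / settings.REDIS_PORT

@app.on_event("shutdown")
async def close_redis() -> None:
    # aclose() the underlying redis.asyncio connection
    await redis_client.disconnect()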

View File

@@ -13,7 +13,7 @@ from src.shared.config import settings
from src.shared.exceptions import LawyerAIException
from src.shared.di_container import create_container
from src.presentation.middleware.error_handler import exception_handler
from src.presentation.api.v1 import users, collections, documents, conversations, messages
from src.presentation.api.v1 import users, collections, documents, conversations, messages, rag
from src.infrastructure.database.base import engine, Base
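
Note: only the import changes in this hunk; the rag module is presumably registered further down main.py the same way as the existing routers, along the lines of the following (the prefix, tag, and router attribute name are assumptions, not visible in this diff):

app.include_router(rag.router, prefix="/api/v1", tags=["rag"])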

View File

@@ -19,7 +19,9 @@ from src.domain.repositories.collection_access_repository import ICollectionAcce
from src.domain.repositories.vector_repository import IVectorRepository
from src.infrastructure.external.yandex_ocr import YandexOCRService
from src.infrastructure.external.deepseek_client import DeepSeekClient
from src.infrastructure.external.redis_client import RedisClient
from src.application.services.document_parser_service import DocumentParserService
from src.application.services.cache_service import CacheService
from src.application.use_cases.user_use_cases import UserUseCases
from src.application.use_cases.collection_use_cases import CollectionUseCases
from src.application.use_cases.document_use_cases import DocumentUseCases
@@ -73,6 +75,14 @@ class RepositoryProvider(Provider):
class ServiceProvider(Provider):
    @provide(scope=Scope.APP)
    def get_redis_client(self) -> RedisClient:
        return RedisClient()

    @provide(scope=Scope.APP)
    def get_cache_service(self, redis_client: RedisClient) -> CacheService:
        return CacheService(redis_client)

    @provide(scope=Scope.APP)
    def get_ocr_service(self) -> YandexOCRService:
        return YandexOCRService()
@@ -180,9 +190,10 @@ class UseCaseProvider(Provider):
        rag_service: RAGService,
        document_repo: IDocumentRepository,
        conversation_repo: IConversationRepository,
        message_repo: IMessageRepository
        message_repo: IMessageRepository,
        cache_service: CacheService
    ) -> RAGUseCases:
        return RAGUseCases(rag_service, document_repo, conversation_repo, message_repo)
        return RAGUseCases(rag_service, document_repo, conversation_repo, message_repo, cache_service)


def create_container() -> Container: