forked from HSE_team/BetterCallPraskovia
redis include + client with biz logic
This commit is contained in: parent d18cc1fb76, commit 74510ce406
@@ -13,3 +13,4 @@ dishka==0.7.0
 numpy==1.26.4
 sentence-transformers==2.7.0
 qdrant-client==1.9.0
+redis==5.0.1

backend/src/application/services/cache_service.py (new file, 49 lines)
@@ -0,0 +1,49 @@
import hashlib
from typing import Optional
from uuid import UUID
from src.infrastructure.external.redis_client import RedisClient


class CacheService:
    def __init__(self, redis_client: RedisClient, default_ttl: int = 3600 * 24):
        self.redis_client = redis_client
        self.default_ttl = default_ttl

    def _make_key(self, collection_id: UUID, question: str) -> str:
        question_hash = hashlib.sha256(question.encode()).hexdigest()[:16]
        return f"rag:answer:{collection_id}:{question_hash}"

    async def get_cached_answer(self, collection_id: UUID, question: str) -> Optional[dict]:
        key = self._make_key(collection_id, question)
        cached = await self.redis_client.get_json(key)
        if cached and cached.get("question") == question:
            return cached.get("answer")
        return None

    async def cache_answer(self, collection_id: UUID, question: str, answer: dict, ttl: Optional[int] = None):
        key = self._make_key(collection_id, question)
        value = {
            "question": question,
            "answer": answer
        }
        await self.redis_client.set_json(key, value, ttl or self.default_ttl)

    async def invalidate_collection_cache(self, collection_id: UUID):
        pattern = f"rag:answer:{collection_id}:*"
        keys = await self.redis_client.keys(pattern)
        if keys:
            for key in keys:
                await self.redis_client.delete(key)

    async def invalidate_document_cache(self, document_id: UUID):
        pattern = "rag:answer:*"
        keys = await self.redis_client.keys(pattern)
        if keys:
            for key in keys:
                cached = await self.redis_client.get_json(key)
                if cached:
                    sources = cached.get("answer", {}).get("sources", [])
                    doc_ids = [s.get("document_id") for s in sources]
                    if str(document_id) in doc_ids:
                        await self.redis_client.delete(key)

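As a side note on the key scheme above: keys are namespaced per collection and carry only the first 16 hex characters of the question's SHA-256, which is why get_cached_answer re-checks the stored question before trusting a hit. A minimal sketch of what a key looks like (the collection id and question below are made up for illustration):

import hashlib
from uuid import uuid4

collection_id = uuid4()                      # hypothetical collection
question = "What does clause 4.2 require?"   # hypothetical question
digest = hashlib.sha256(question.encode()).hexdigest()[:16]
print(f"rag:answer:{collection_id}:{digest}")
# -> rag:answer:<uuid>:<16 hex chars>, matching CacheService._make_key
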
@@ -3,6 +3,7 @@ Use cases for RAG: document indexing and answering
 """
 from uuid import UUID
 from src.application.services.rag_service import RAGService
+from src.application.services.cache_service import CacheService
 from src.domain.repositories.document_repository import IDocumentRepository
 from src.domain.repositories.conversation_repository import IConversationRepository
 from src.domain.repositories.message_repository import IMessageRepository
@@ -18,11 +19,13 @@ class RAGUseCases:
         document_repo: IDocumentRepository,
         conversation_repo: IConversationRepository,
         message_repo: IMessageRepository,
+        cache_service: CacheService,
     ):
         self.rag_service = rag_service
         self.document_repo = document_repo
         self.conversation_repo = conversation_repo
         self.message_repo = message_repo
+        self.cache_service = cache_service
 
     async def index_document(self, document_id: UUID) -> dict:
         document = await self.document_repo.get_by_id(document_id)
@@ -50,14 +53,28 @@ class RAGUseCases:
         )
         await self.message_repo.create(user_message)
 
-        retrieved = await self.rag_service.retrieve(
-            query=question,
-            collection_id=conversation.collection_id,
-            limit=top_k,
-            rerank_top_n=rerank_top_n,
-        )
-        chunks = [c for c, _ in retrieved]
-        generation = await self.rag_service.generate_answer(question, chunks)
+        cached_answer = None
+        if self.cache_service:
+            cached_answer = await self.cache_service.get_cached_answer(conversation.collection_id, question)
+
+        if cached_answer:
+            generation = cached_answer
+        else:
+            retrieved = await self.rag_service.retrieve(
+                query=question,
+                collection_id=conversation.collection_id,
+                limit=top_k,
+                rerank_top_n=rerank_top_n,
+            )
+            chunks = [c for c, _ in retrieved]
+            generation = await self.rag_service.generate_answer(question, chunks)
+
+            if self.cache_service:
+                await self.cache_service.cache_answer(
+                    conversation.collection_id,
+                    question,
+                    generation
+                )
 
         assistant_message = Message(
             conversation_id=conversation_id,
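The hunk above follows a cache-aside pattern: look the (collection, question) pair up first, and only on a miss run retrieval and generation, then store the result for the next identical question. Stripped of the surrounding use case, the order of calls is roughly as follows (variable names are illustrative):

# cache-aside: read, else compute and write
generation = await cache_service.get_cached_answer(collection_id, question)
if generation is None:
    retrieved = await rag_service.retrieve(query=question, collection_id=collection_id,
                                           limit=top_k, rerank_top_n=rerank_top_n)
    chunks = [c for c, _ in retrieved]
    generation = await rag_service.generate_answer(question, chunks)
    await cache_service.cache_answer(collection_id, question, generation)
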
backend/src/infrastructure/external/redis_client.py (new file, 76 lines, vendored)
@@ -0,0 +1,76 @@
import json
from typing import Optional, Any
import redis.asyncio as aioredis
from src.shared.config import settings


class RedisClient:
    def __init__(self, host: Optional[str] = None, port: Optional[int] = None):
        self.host = host or settings.REDIS_HOST
        self.port = port or settings.REDIS_PORT
        self._client: Optional[aioredis.Redis] = None

    async def connect(self):
        if self._client is None:
            self._client = await aioredis.from_url(
                f"redis://{self.host}:{self.port}",
                encoding="utf-8",
                decode_responses=True
            )

    async def disconnect(self):
        if self._client:
            await self._client.aclose()
            self._client = None

    async def get(self, key: str) -> Optional[str]:
        if self._client is None:
            await self.connect()
        return await self._client.get(key)

    async def set(self, key: str, value: str, ttl: Optional[int] = None):
        if self._client is None:
            await self.connect()
        if ttl:
            await self._client.setex(key, ttl, value)
        else:
            await self._client.set(key, value)

    async def get_json(self, key: str) -> Optional[dict[str, Any]]:
        value = await self.get(key)
        if value is None:
            return None
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return None

    async def set_json(self, key: str, value: dict[str, Any], ttl: Optional[int] = None):
        json_str = json.dumps(value)
        await self.set(key, json_str, ttl)

    async def delete(self, key: str):
        if self._client is None:
            await self.connect()
        await self._client.delete(key)

    async def exists(self, key: str) -> bool:
        if self._client is None:
            await self.connect()
        return bool(await self._client.exists(key))

    async def incr(self, key: str) -> int:
        if self._client is None:
            await self.connect()
        return await self._client.incr(key)

    async def expire(self, key: str, seconds: int):
        if self._client is None:
            await self.connect()
        await self._client.expire(key, seconds)

    async def keys(self, pattern: str) -> list[str]:
        if self._client is None:
            await self.connect()
        return await self._client.keys(pattern)

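A minimal usage sketch for the client above; the host and port values are placeholders rather than project configuration, and a reachable Redis instance is assumed:

import asyncio

from src.infrastructure.external.redis_client import RedisClient


async def main():
    client = RedisClient(host="localhost", port=6379)  # placeholder connection details
    await client.connect()
    # set_json/get_json round-trip a dict through JSON, with an optional TTL in seconds
    await client.set_json("rag:answer:demo", {"question": "q", "answer": {"text": "a"}}, ttl=60)
    print(await client.get_json("rag:answer:demo"))
    await client.disconnect()


asyncio.run(main())
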
@@ -13,7 +13,7 @@ from src.shared.config import settings
 from src.shared.exceptions import LawyerAIException
 from src.shared.di_container import create_container
 from src.presentation.middleware.error_handler import exception_handler
-from src.presentation.api.v1 import users, collections, documents, conversations, messages
+from src.presentation.api.v1 import users, collections, documents, conversations, messages, rag
 from src.infrastructure.database.base import engine, Base
 
 
@@ -19,7 +19,9 @@ from src.domain.repositories.collection_access_repository import ICollectionAcce
 from src.domain.repositories.vector_repository import IVectorRepository
 from src.infrastructure.external.yandex_ocr import YandexOCRService
 from src.infrastructure.external.deepseek_client import DeepSeekClient
+from src.infrastructure.external.redis_client import RedisClient
 from src.application.services.document_parser_service import DocumentParserService
+from src.application.services.cache_service import CacheService
 from src.application.use_cases.user_use_cases import UserUseCases
 from src.application.use_cases.collection_use_cases import CollectionUseCases
 from src.application.use_cases.document_use_cases import DocumentUseCases
@@ -73,6 +75,14 @@ class RepositoryProvider(Provider):
 
 
 class ServiceProvider(Provider):
+    @provide(scope=Scope.APP)
+    def get_redis_client(self) -> RedisClient:
+        return RedisClient()
+
+    @provide(scope=Scope.APP)
+    def get_cache_service(self, redis_client: RedisClient) -> CacheService:
+        return CacheService(redis_client)
+
     @provide(scope=Scope.APP)
     def get_ocr_service(self) -> YandexOCRService:
         return YandexOCRService()
@@ -180,9 +190,10 @@ class UseCaseProvider(Provider):
         rag_service: RAGService,
         document_repo: IDocumentRepository,
         conversation_repo: IConversationRepository,
-        message_repo: IMessageRepository
+        message_repo: IMessageRepository,
+        cache_service: CacheService
     ) -> RAGUseCases:
-        return RAGUseCases(rag_service, document_repo, conversation_repo, message_repo)
+        return RAGUseCases(rag_service, document_repo, conversation_repo, message_repo, cache_service)
 
 
 def create_container() -> Container:
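Both new providers are APP-scoped, so the container hands every consumer the same RedisClient and CacheService instances. Built by hand, the equivalent wiring of that pair (plus the invalidation hook CacheService exposes) looks roughly like the sketch below; the collection id is a placeholder:

import asyncio
from uuid import uuid4

from src.infrastructure.external.redis_client import RedisClient
from src.application.services.cache_service import CacheService


async def main():
    cache_service = CacheService(RedisClient())  # host/port fall back to settings
    # after a collection's documents change, drop its cached answers
    await cache_service.invalidate_collection_cache(uuid4())  # placeholder id


asyncio.run(main())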