2025-12-24 03:14:37 +03:00

199 lines
8.1 KiB
Python

from collections.abc import AsyncIterable
from contextlib import asynccontextmanager

from dishka import AsyncContainer, Container, Provider, Scope, make_async_container, provide
from fastapi import Request
from qdrant_client import QdrantClient
from sqlalchemy.ext.asyncio import AsyncSession

from src.infrastructure.database.base import AsyncSessionLocal
from src.infrastructure.repositories.postgresql.user_repository import PostgreSQLUserRepository
from src.infrastructure.repositories.postgresql.collection_repository import PostgreSQLCollectionRepository
from src.infrastructure.repositories.postgresql.document_repository import PostgreSQLDocumentRepository
from src.infrastructure.repositories.postgresql.conversation_repository import PostgreSQLConversationRepository
from src.infrastructure.repositories.postgresql.message_repository import PostgreSQLMessageRepository
from src.infrastructure.repositories.postgresql.collection_access_repository import PostgreSQLCollectionAccessRepository
from src.domain.repositories.user_repository import IUserRepository
from src.domain.repositories.collection_repository import ICollectionRepository
from src.domain.repositories.document_repository import IDocumentRepository
from src.domain.repositories.conversation_repository import IConversationRepository
from src.domain.repositories.message_repository import IMessageRepository
from src.domain.repositories.collection_access_repository import ICollectionAccessRepository
from src.domain.repositories.vector_repository import IVectorRepository
from src.infrastructure.external.yandex_ocr import YandexOCRService
from src.infrastructure.external.deepseek_client import DeepSeekClient
from src.infrastructure.external.redis_client import RedisClient
from src.application.services.document_parser_service import DocumentParserService
from src.application.services.cache_service import CacheService
from src.application.use_cases.user_use_cases import UserUseCases
from src.application.use_cases.collection_use_cases import CollectionUseCases
from src.application.use_cases.document_use_cases import DocumentUseCases
from src.application.use_cases.conversation_use_cases import ConversationUseCases
from src.application.use_cases.message_use_cases import MessageUseCases
from src.domain.entities.user import User
from src.shared.config import settings
from src.infrastructure.repositories.qdrant.vector_repository import QdrantVectorRepository
from src.application.services.embedding_service import EmbeddingService
from src.application.services.reranker_service import RerankerService
from src.application.services.rag_service import RAGService
from src.application.services.text_splitter import TextSplitter
from src.application.use_cases.rag_use_cases import RAGUseCases
class DatabaseProvider(Provider):
    """Provide a database session for each request scope."""

    @provide(scope=Scope.REQUEST)
    async def get_db(self) -> AsyncIterable[AsyncSession]:
        """Yield an ``AsyncSession`` for the current request scope.

        This is a plain async generator on purpose. Dishka recognizes
        async-generator factories annotated ``AsyncIterable[T]`` as
        providing ``T``. It runs the code after ``yield`` when the REQUEST
        scope is closed, which is where the session gets closed.

        Yields:
            The session created by ``AsyncSessionLocal()``.
        """
        async with AsyncSessionLocal() as session:
            try:
                yield session
            finally:
                # Explicit close kept from the original as a safeguard;
                # AsyncSession.close() may be called more than once.
                await session.close()
class RepositoryProvider(Provider):
    """Bind each domain repository interface to its PostgreSQL implementation.

    Every factory is REQUEST-scoped and receives the request-scoped
    ``AsyncSession``. With dishka's default per-scope caching, repositories
    resolved within the same request are built on the same session object.
    """

    @provide(scope=Scope.REQUEST)
    def get_user_repository(self, session: AsyncSession) -> IUserRepository:
        """Provide ``IUserRepository`` as a ``PostgreSQLUserRepository``."""
        return PostgreSQLUserRepository(session)

    @provide(scope=Scope.REQUEST)
    def get_collection_repository(self, session: AsyncSession) -> ICollectionRepository:
        """Provide ``ICollectionRepository`` as a ``PostgreSQLCollectionRepository``."""
        return PostgreSQLCollectionRepository(session)

    @provide(scope=Scope.REQUEST)
    def get_document_repository(self, session: AsyncSession) -> IDocumentRepository:
        """Provide ``IDocumentRepository`` as a ``PostgreSQLDocumentRepository``."""
        return PostgreSQLDocumentRepository(session)

    @provide(scope=Scope.REQUEST)
    def get_conversation_repository(self, session: AsyncSession) -> IConversationRepository:
        """Provide ``IConversationRepository`` as a ``PostgreSQLConversationRepository``."""
        return PostgreSQLConversationRepository(session)

    @provide(scope=Scope.REQUEST)
    def get_message_repository(self, session: AsyncSession) -> IMessageRepository:
        """Provide ``IMessageRepository`` as a ``PostgreSQLMessageRepository``."""
        return PostgreSQLMessageRepository(session)

    @provide(scope=Scope.REQUEST)
    def get_collection_access_repository(
        self,
        session: AsyncSession,
    ) -> ICollectionAccessRepository:
        """Provide ``ICollectionAccessRepository`` as a ``PostgreSQLCollectionAccessRepository``."""
        return PostgreSQLCollectionAccessRepository(session)
class ServiceProvider(Provider):
    """Application-wide (APP scope) clients and services.

    Each factory here runs at most once per container. The resulting object
    is shared by every request.
    """

    @provide(scope=Scope.APP)
    def get_redis_client(self) -> RedisClient:
        """Create the Redis client from ``settings.REDIS_HOST`` / ``REDIS_PORT``."""
        # NOTE(review): nothing closes this client when the container shuts
        # down; consider a generator factory if RedisClient has a close method.
        host, port = settings.REDIS_HOST, settings.REDIS_PORT
        return RedisClient(host=host, port=port)

    @provide(scope=Scope.APP)
    def get_cache_service(self, redis_client: RedisClient) -> CacheService:
        """Wrap the shared Redis client in a ``CacheService``."""
        return CacheService(redis_client)

    @provide(scope=Scope.APP)
    def get_ocr_service(self) -> YandexOCRService:
        """Create the Yandex OCR service with its default constructor."""
        return YandexOCRService()

    @provide(scope=Scope.APP)
    def get_deepseek_client(self) -> DeepSeekClient:
        """Create the DeepSeek client with its default constructor."""
        return DeepSeekClient()

    @provide(scope=Scope.APP)
    def get_parser_service(self, ocr_service: YandexOCRService) -> DocumentParserService:
        """Build the document parser on top of the shared OCR service."""
        return DocumentParserService(ocr_service)

    @provide(scope=Scope.APP)
    def get_qdrant_client(self) -> QdrantClient:
        """Create the Qdrant client from ``settings.QDRANT_HOST`` / ``QDRANT_PORT``."""
        # NOTE(review): QdrantClient is the synchronous client and is never
        # closed here; confirm that is intended for this async app.
        host, port = settings.QDRANT_HOST, settings.QDRANT_PORT
        return QdrantClient(host=host, port=port)

    @provide(scope=Scope.APP)
    def get_vector_repository(self, client: QdrantClient) -> IVectorRepository:
        """Expose the Qdrant-backed implementation of ``IVectorRepository``."""
        # NOTE(review): presumably must equal the embedding dimension produced
        # by EmbeddingService; confirm before changing either side.
        vector_size = 768
        return QdrantVectorRepository(client=client, vector_size=vector_size)

    @provide(scope=Scope.APP)
    def get_embedding_service(self) -> EmbeddingService:
        """Create the embedding service with its default constructor."""
        return EmbeddingService()

    @provide(scope=Scope.APP)
    def get_reranker_service(self, embedding_service: EmbeddingService) -> RerankerService:
        """Build the reranker, reusing the embedding model as its fallback encoder."""
        encoder = embedding_service.model
        return RerankerService(fallback_encoder=encoder)

    @provide(scope=Scope.APP)
    def get_text_splitter(self) -> TextSplitter:
        """Create the text splitter with its default constructor."""
        return TextSplitter()

    @provide(scope=Scope.APP)
    def get_rag_service(
        self,
        vector_repo: IVectorRepository,
        embedding_service: EmbeddingService,
        reranker_service: RerankerService,
        deepseek_client: DeepSeekClient,
        text_splitter: TextSplitter,
    ) -> RAGService:
        """Assemble ``RAGService`` from its retrieval, ranking and generation parts."""
        return RAGService(
            vector_repository=vector_repo,
            embedding_service=embedding_service,
            reranker_service=reranker_service,
            deepseek_client=deepseek_client,
            splitter=text_splitter,
        )
class UseCaseProvider(Provider):
    """Assemble application use-case objects for each request.

    Use cases are REQUEST-scoped because they depend on the REQUEST-scoped
    repositories. ``RAGUseCases`` additionally receives the APP-scoped
    ``RAGService`` and ``CacheService``.
    """

    @provide(scope=Scope.REQUEST)
    def get_user_use_cases(self, user_repo: IUserRepository) -> UserUseCases:
        """Provide ``UserUseCases`` over the user repository."""
        return UserUseCases(user_repo)

    @provide(scope=Scope.REQUEST)
    def get_collection_use_cases(
        self,
        collection_repo: ICollectionRepository,
        access_repo: ICollectionAccessRepository,
        user_repo: IUserRepository,
    ) -> CollectionUseCases:
        """Provide ``CollectionUseCases`` over collection, access and user repositories."""
        return CollectionUseCases(
            collection_repo,
            access_repo,
            user_repo,
        )

    @provide(scope=Scope.REQUEST)
    def get_document_use_cases(
        self,
        document_repo: IDocumentRepository,
        collection_repo: ICollectionRepository,
        parser_service: DocumentParserService,
    ) -> DocumentUseCases:
        """Provide ``DocumentUseCases`` over documents, collections and the parser."""
        return DocumentUseCases(
            document_repo,
            collection_repo,
            parser_service,
        )

    @provide(scope=Scope.REQUEST)
    def get_conversation_use_cases(
        self,
        conversation_repo: IConversationRepository,
        collection_repo: ICollectionRepository,
        access_repo: ICollectionAccessRepository,
    ) -> ConversationUseCases:
        """Provide ``ConversationUseCases`` over conversations, collections and access."""
        return ConversationUseCases(
            conversation_repo,
            collection_repo,
            access_repo,
        )

    @provide(scope=Scope.REQUEST)
    def get_message_use_cases(
        self,
        message_repo: IMessageRepository,
        conversation_repo: IConversationRepository,
    ) -> MessageUseCases:
        """Provide ``MessageUseCases`` over message and conversation repositories."""
        return MessageUseCases(
            message_repo,
            conversation_repo,
        )

    @provide(scope=Scope.REQUEST)
    def get_rag_use_cases(
        self,
        rag_service: RAGService,
        document_repo: IDocumentRepository,
        conversation_repo: IConversationRepository,
        message_repo: IMessageRepository,
        cache_service: CacheService,
    ) -> RAGUseCases:
        """Provide ``RAGUseCases`` combining the RAG service, repositories and cache."""
        return RAGUseCases(
            rag_service,
            document_repo,
            conversation_repo,
            message_repo,
            cache_service,
        )
def create_container() -> AsyncContainer:
    """Build the application's async dependency-injection container.

    ``make_async_container`` returns an ``AsyncContainer``, not the
    synchronous ``Container``, so the annotation reflects that.

    Returns:
        An ``AsyncContainer`` wired with the database, repository, service
        and use-case providers. Resolve dependencies with
        ``await container.get(...)``. The owner should call
        ``await container.close()`` at shutdown so scope finalizers run.
    """
    return make_async_container(
        DatabaseProvider(),
        RepositoryProvider(),
        ServiceProvider(),
        UseCaseProvider(),
    )