196 lines
8.0 KiB
Python
196 lines
8.0 KiB
Python
from collections.abc import AsyncIterable
from contextlib import asynccontextmanager

from dishka import (
    AsyncContainer,
    Container,
    Provider,
    Scope,
    make_async_container,
    provide,
)
from fastapi import Request
from qdrant_client import QdrantClient
from sqlalchemy.ext.asyncio import AsyncSession

from src.application.services.cache_service import CacheService
from src.application.services.document_parser_service import DocumentParserService
from src.application.services.embedding_service import EmbeddingService
from src.application.services.rag_service import RAGService
from src.application.services.reranker_service import RerankerService
from src.application.services.text_splitter import TextSplitter
from src.application.use_cases.collection_use_cases import CollectionUseCases
from src.application.use_cases.conversation_use_cases import ConversationUseCases
from src.application.use_cases.document_use_cases import DocumentUseCases
from src.application.use_cases.message_use_cases import MessageUseCases
from src.application.use_cases.rag_use_cases import RAGUseCases
from src.application.use_cases.user_use_cases import UserUseCases
from src.domain.entities.user import User
from src.domain.repositories.collection_access_repository import ICollectionAccessRepository
from src.domain.repositories.collection_repository import ICollectionRepository
from src.domain.repositories.conversation_repository import IConversationRepository
from src.domain.repositories.document_repository import IDocumentRepository
from src.domain.repositories.message_repository import IMessageRepository
from src.domain.repositories.user_repository import IUserRepository
from src.domain.repositories.vector_repository import IVectorRepository
from src.infrastructure.database.base import AsyncSessionLocal
from src.infrastructure.external.deepseek_client import DeepSeekClient
from src.infrastructure.external.redis_client import RedisClient
from src.infrastructure.external.yandex_ocr import YandexOCRService
from src.infrastructure.repositories.postgresql.collection_access_repository import PostgreSQLCollectionAccessRepository
from src.infrastructure.repositories.postgresql.collection_repository import PostgreSQLCollectionRepository
from src.infrastructure.repositories.postgresql.conversation_repository import PostgreSQLConversationRepository
from src.infrastructure.repositories.postgresql.document_repository import PostgreSQLDocumentRepository
from src.infrastructure.repositories.postgresql.message_repository import PostgreSQLMessageRepository
from src.infrastructure.repositories.postgresql.user_repository import PostgreSQLUserRepository
from src.infrastructure.repositories.qdrant.vector_repository import QdrantVectorRepository
from src.shared.config import settings
|
|
|
|
class DatabaseProvider(Provider):
    """Provides database sessions for the dependency-injection container."""

    @provide(scope=Scope.REQUEST)
    async def get_db(self) -> AsyncIterable[AsyncSession]:
        """Yield a request-scoped ``AsyncSession`` and close it afterwards.

        Written as an async generator so the container can run the code
        after ``yield`` when the request scope is exited. A plain
        ``return session`` left nobody responsible for closing the session.

        Yields:
            AsyncSession: a session created by ``AsyncSessionLocal()``.
        """
        session = AsyncSessionLocal()
        try:
            yield session
        finally:
            # Always release the session, including when the request
            # handler raised and the scope is being torn down.
            await session.close()
|
|
|
|
|
|
class RepositoryProvider(Provider):
    """Binds domain repository interfaces to their PostgreSQL implementations.

    Every factory is registered with ``Scope.REQUEST`` and receives the
    ``AsyncSession`` dependency as its only argument.
    """

    @provide(scope=Scope.REQUEST)
    def get_user_repository(
        self,
        session: AsyncSession,
    ) -> IUserRepository:
        """Return a ``PostgreSQLUserRepository`` built on ``session``."""
        user_repository = PostgreSQLUserRepository(session)
        return user_repository

    @provide(scope=Scope.REQUEST)
    def get_collection_repository(
        self,
        session: AsyncSession,
    ) -> ICollectionRepository:
        """Return a ``PostgreSQLCollectionRepository`` built on ``session``."""
        collection_repository = PostgreSQLCollectionRepository(session)
        return collection_repository

    @provide(scope=Scope.REQUEST)
    def get_document_repository(
        self,
        session: AsyncSession,
    ) -> IDocumentRepository:
        """Return a ``PostgreSQLDocumentRepository`` built on ``session``."""
        document_repository = PostgreSQLDocumentRepository(session)
        return document_repository

    @provide(scope=Scope.REQUEST)
    def get_conversation_repository(
        self,
        session: AsyncSession,
    ) -> IConversationRepository:
        """Return a ``PostgreSQLConversationRepository`` built on ``session``."""
        conversation_repository = PostgreSQLConversationRepository(session)
        return conversation_repository

    @provide(scope=Scope.REQUEST)
    def get_message_repository(
        self,
        session: AsyncSession,
    ) -> IMessageRepository:
        """Return a ``PostgreSQLMessageRepository`` built on ``session``."""
        message_repository = PostgreSQLMessageRepository(session)
        return message_repository

    @provide(scope=Scope.REQUEST)
    def get_collection_access_repository(
        self,
        session: AsyncSession,
    ) -> ICollectionAccessRepository:
        """Return a ``PostgreSQLCollectionAccessRepository`` built on ``session``."""
        access_repository = PostgreSQLCollectionAccessRepository(session)
        return access_repository
|
|
|
|
|
|
class ServiceProvider(Provider):
    """Factories for external clients and application services.

    Every factory here is registered with ``Scope.APP``.
    """

    @provide(scope=Scope.APP)
    def get_redis_client(self) -> RedisClient:
        """Create the Redis client from ``REDIS_HOST`` / ``REDIS_PORT`` settings."""
        redis_host = settings.REDIS_HOST
        redis_port = settings.REDIS_PORT
        return RedisClient(host=redis_host, port=redis_port)

    @provide(scope=Scope.APP)
    def get_cache_service(self, redis_client: RedisClient) -> CacheService:
        """Wrap the Redis client in a ``CacheService``."""
        cache = CacheService(redis_client)
        return cache

    @provide(scope=Scope.APP)
    def get_ocr_service(self) -> YandexOCRService:
        """Create a ``YandexOCRService`` with its default constructor."""
        ocr = YandexOCRService()
        return ocr

    @provide(scope=Scope.APP)
    def get_deepseek_client(self) -> DeepSeekClient:
        """Create a ``DeepSeekClient`` with its default constructor."""
        llm_client = DeepSeekClient()
        return llm_client

    @provide(scope=Scope.APP)
    def get_parser_service(
        self,
        ocr_service: YandexOCRService,
    ) -> DocumentParserService:
        """Create the document parser, handing it the OCR service."""
        parser = DocumentParserService(ocr_service)
        return parser

    @provide(scope=Scope.APP)
    def get_qdrant_client(self) -> QdrantClient:
        """Create the Qdrant client from ``QDRANT_HOST`` / ``QDRANT_PORT`` settings."""
        qdrant_host = settings.QDRANT_HOST
        qdrant_port = settings.QDRANT_PORT
        return QdrantClient(host=qdrant_host, port=qdrant_port)

    @provide(scope=Scope.APP)
    def get_vector_repository(self, client: QdrantClient) -> IVectorRepository:
        """Create the Qdrant-backed vector repository.

        ``vector_size`` is fixed at 768 here.
        NOTE(review): presumably this must match the embedding model's
        output dimension -- confirm against ``EmbeddingService``.
        """
        vector_repository = QdrantVectorRepository(client=client, vector_size=768)
        return vector_repository

    @provide(scope=Scope.APP)
    def get_embedding_service(self) -> EmbeddingService:
        """Create an ``EmbeddingService`` with its default constructor."""
        embedder = EmbeddingService()
        return embedder

    @provide(scope=Scope.APP)
    def get_reranker_service(
        self,
        embedding_service: EmbeddingService,
    ) -> RerankerService:
        """Create the reranker, passing the embedding service's ``model`` as its fallback encoder."""
        fallback_model = embedding_service.model
        return RerankerService(fallback_encoder=fallback_model)

    @provide(scope=Scope.APP)
    def get_text_splitter(self) -> TextSplitter:
        """Create a ``TextSplitter`` with its default constructor."""
        splitter = TextSplitter()
        return splitter

    @provide(scope=Scope.APP)
    def get_rag_service(
        self,
        vector_repo: IVectorRepository,
        embedding_service: EmbeddingService,
        reranker_service: RerankerService,
        deepseek_client: DeepSeekClient,
        text_splitter: TextSplitter,
    ) -> RAGService:
        """Assemble the ``RAGService`` from its five collaborators."""
        rag_service = RAGService(
            vector_repository=vector_repo,
            embedding_service=embedding_service,
            reranker_service=reranker_service,
            deepseek_client=deepseek_client,
            splitter=text_splitter,
        )
        return rag_service
|
|
|
|
|
|
class UseCaseProvider(Provider):
    """Factories for the application use-case objects.

    Every factory is registered with ``Scope.REQUEST`` and passes its
    injected dependencies positionally to the use-case constructor.
    """

    @provide(scope=Scope.REQUEST)
    def get_user_use_cases(self, user_repo: IUserRepository) -> UserUseCases:
        """Build ``UserUseCases`` from the user repository."""
        use_cases = UserUseCases(user_repo)
        return use_cases

    @provide(scope=Scope.REQUEST)
    def get_collection_use_cases(
        self,
        collection_repo: ICollectionRepository,
        access_repo: ICollectionAccessRepository,
        user_repo: IUserRepository,
    ) -> CollectionUseCases:
        """Build ``CollectionUseCases`` from the collection, access and user repositories."""
        use_cases = CollectionUseCases(collection_repo, access_repo, user_repo)
        return use_cases

    @provide(scope=Scope.REQUEST)
    def get_document_use_cases(
        self,
        document_repo: IDocumentRepository,
        collection_repo: ICollectionRepository,
        access_repo: ICollectionAccessRepository,
        parser_service: DocumentParserService,
    ) -> DocumentUseCases:
        """Build ``DocumentUseCases`` from three repositories and the parser service."""
        use_cases = DocumentUseCases(
            document_repo,
            collection_repo,
            access_repo,
            parser_service,
        )
        return use_cases

    @provide(scope=Scope.REQUEST)
    def get_conversation_use_cases(
        self,
        conversation_repo: IConversationRepository,
        collection_repo: ICollectionRepository,
        access_repo: ICollectionAccessRepository,
    ) -> ConversationUseCases:
        """Build ``ConversationUseCases`` from the conversation, collection and access repositories."""
        use_cases = ConversationUseCases(
            conversation_repo,
            collection_repo,
            access_repo,
        )
        return use_cases

    @provide(scope=Scope.REQUEST)
    def get_message_use_cases(
        self,
        message_repo: IMessageRepository,
        conversation_repo: IConversationRepository,
    ) -> MessageUseCases:
        """Build ``MessageUseCases`` from the message and conversation repositories."""
        use_cases = MessageUseCases(message_repo, conversation_repo)
        return use_cases

    @provide(scope=Scope.REQUEST)
    def get_rag_use_cases(
        self,
        rag_service: RAGService,
        document_repo: IDocumentRepository,
        conversation_repo: IConversationRepository,
        message_repo: IMessageRepository,
        cache_service: CacheService,
    ) -> RAGUseCases:
        """Build ``RAGUseCases`` from the RAG service, three repositories and the cache service."""
        use_cases = RAGUseCases(
            rag_service,
            document_repo,
            conversation_repo,
            message_repo,
            cache_service,
        )
        return use_cases
|
|
|
|
|
|
def create_container() -> AsyncContainer:
    """Build the application's async dependency-injection container.

    The container is created with ``make_async_container``, so the result
    is an ``AsyncContainer``. The previous ``Container`` annotation named
    the synchronous container type.

    Returns:
        AsyncContainer: container assembled from the database, repository,
        service and use-case providers.
    """
    providers = (
        DatabaseProvider(),
        RepositoryProvider(),
        ServiceProvider(),
        UseCaseProvider(),
    )
    return make_async_container(*providers)
|
|
|