Merge pull request 'arkhip' (#3) from arkhip into main

Reviewed-on: HSE_team/BetterCallPraskovia#3
This commit is contained in:
Arxip222 2025-12-22 23:40:02 +03:00
commit d25feb8d2d
19 changed files with 1318 additions and 9 deletions

318
AI_api.yaml Normal file
View File

@ -0,0 +1,318 @@
openapi: 3.0.3
info:
title: Legal RAG AI API
description: API для юридического AI-ассистента. Обеспечивает работу RAG-пайплайна (поиск + генерация), управление коллекциями документов и биллинг.
version: 1.0.0
servers:
- url: http://localhost:8000/api/v1
description: Local Development Server
tags:
- name: Auth & Users
description: Управление пользователями и проверка лимитов
- name: RAG & Chat
description: Основной функционал (чат, поиск)
- name: Collections
description: Управление базами знаний
- name: Billing
description: Платежи
paths:
/users/register:
post:
tags:
- Auth & Users
summary: Регистрация/Вход пользователя через Telegram
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/UserAuthRequest'
responses:
'200':
description: Успешная авторизация
content:
application/json:
schema:
$ref: '#/components/schemas/UserResponse'
/users/me:
get:
tags:
- Auth & Users
summary: Получить профиль пользователя
parameters:
- in: header
name: X-Telegram-ID
schema:
type: string
required: true
description: Telegram ID пользователя
responses:
'200':
description: Профиль пользователя
content:
application/json:
schema:
$ref: '#/components/schemas/UserResponse'
'403':
description: Доступ запрещен
/chat/ask:
post:
tags:
- RAG & Chat
summary: Задать вопрос юристу
parameters:
- in: header
name: X-Telegram-ID
schema:
type: string
required: true
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/ChatRequest'
responses:
'200':
description: Ответ сгенерирован
content:
application/json:
schema:
$ref: '#/components/schemas/ChatResponse'
'402':
description: Лимит бесплатных запросов исчерпан
/chat/history:
get:
tags:
- RAG & Chat
summary: Получить историю диалога
parameters:
- in: query
name: conversation_id
schema:
type: string
format: uuid
required: true
responses:
'200':
description: Список сообщений
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/MessageDTO'
/collections:
get:
tags:
- Collections
summary: Список доступных коллекций
parameters:
- in: header
name: X-Telegram-ID
schema:
type: string
required: true
responses:
'200':
description: Список коллекций
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/CollectionDTO'
post:
tags:
- Collections
summary: Создать новую коллекцию
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/CreateCollectionRequest'
responses:
'201':
description: Коллекция создана
content:
application/json:
schema:
$ref: '#/components/schemas/CollectionDTO'
/collections/{collection_id}/upload:
post:
tags:
- Collections
summary: Загрузка документа в коллекцию
parameters:
- in: path
name: collection_id
required: true
schema:
type: string
format: uuid
requestBody:
content:
multipart/form-data:
schema:
type: object
properties:
file:
type: string
format: binary
responses:
'202':
description: Файл принят в обработку
content:
application/json:
schema:
type: object
properties:
task_id:
type: string
status:
type: string
/billing/create-payment:
post:
tags:
- Billing
summary: Создать ссылку на оплату
parameters:
- in: header
name: X-Telegram-ID
required: true
schema:
type: string
responses:
'200':
description: Ссылка сформирована
content:
application/json:
schema:
type: object
properties:
payment_url:
type: string
payment_id:
type: string
/billing/webhook:
post:
tags:
- Billing
summary: Вебхук от платежной системы
responses:
'200':
description: OK
components:
schemas:
UserAuthRequest:
type: object
required:
- telegram_id
properties:
telegram_id:
type: string
example: "123456789"
username:
type: string
first_name:
type: string
UserResponse:
type: object
properties:
id:
type: string
format: uuid
telegram_id:
type: string
role:
type: string
enum:
- user
- admin
request_count:
type: integer
is_premium:
type: boolean
ChatRequest:
type: object
required:
- query
- collection_id
properties:
query:
type: string
example: "Какая ответственность за неуплату НДС?"
collection_id:
type: string
format: uuid
conversation_id:
type: string
format: uuid
nullable: true
ChatResponse:
type: object
properties:
answer:
type: string
sources:
type: array
items:
$ref: '#/components/schemas/Source'
conversation_id:
type: string
format: uuid
request_count_left:
type: integer
Source:
type: object
properties:
title:
type: string
page:
type: integer
relevance_score:
type: number
format: float
snippet:
type: string
MessageDTO:
type: object
properties:
role:
type: string
enum:
- user
- ai
content:
type: string
created_at:
type: string
format: date-time
CollectionDTO:
type: object
properties:
id:
type: string
format: uuid
name:
type: string
description:
type: string
is_public:
type: boolean
owner_id:
type: string
format: uuid
CreateCollectionRequest:
type: object
required:
- name
properties:
name:
type: string
description:
type: string
is_public:
type: boolean
default: false

248
README.md Normal file
View File

@ -0,0 +1,248 @@
# Техническая документация проекта: Юридический AI-ассистент (RAG Система)
## 1. Концепция и методология RAG
**Что такое RAG в данном проекте?**
Мы не просто отправляем вопрос в ChatGPT. Мы реализуем архитектуру **Retrieval-Augmented Generation**. Это необходимо для того, чтобы LLM не галлюцинировала законы, а отвечала строго по тексту загруженных документов.
**Принцип работы:**
1. **Поиск:** Запрос пользователя превращается в вектор. В базе данных находятся фрагменты законов, семантически близкие к запросу
2. **Дополнение:** Найденные фрагменты подклеиваются к промпту пользователя в качестве контекста
3. **Генерация:** LLM генерирует ответ, используя только предоставленный контекст, и проставляет ссылки
## 2. Функциональность
Пользователь (юрист или сотрудник компании) взаимодействует с системой через Telegram-бот. Примеры интерфейса:
- **Отправка запроса**: Пользователь пишет сообщение, например: "Найди все договоры поставки с компанией ООО "Ромашка" или "Какая ответственность за неуплату налогов ИП?". Бот отвечает на естественном языке, с точными фактами и обязательными ссылками на источники (например, "Согласно статье 122 Налогового кодекса РФ")
- **Выбор коллекции**: Пользователь может выбрать доступную его компании коллекцию документов (список доступных коллекций отображается) и далее задавать вопросы в рамках нее
- **Создание коллекций**: Любой пользователь может создать собственную коллекцию через админпанель и выбрать логины пользователей, которые имеют право ей пользоваться
- **Добавление данных**: Админы коллекций через панель могут добавлять документы в свою коллекцию
- **Подписка**: После 10 бесплатных запросов бот предлагает оплатить подписку через встроенный платежный сервис
- **Источники**: Каждый факт в ответе сопровождается ссылкой, например, "Документ: Налоговый кодекс РФ, статья 122"
- **Просмотр коллекций**: Пользователь может просмотреть список с именами и описаниями доступных коллекций
## 3. Стек технологий и обоснование
| Компонент | Технология | Обоснование выбора |
| :--- | :--- | :--- |
| **Backend API** | **FastAPI** | Асинхронность для работы с LLM стримами, легкая интеграция DI, валидация аднных |
| **Telegram Bot** | **Aiogram** | Асинхронный, нативная поддержка машин состояний, мощный механизм мидлвара для логирования и авторизации. |
| **Vector DB** | **Qdrant** | Высокопроизводительная БД на Rust. Поддерживает гибридный поиск и фильтрацию по метаданным из коробки. |
| **Cache & Bus** | **Redis** | 1. **Кэширование RAG:** Хранение пар вопрос-ответ для экономии денег на LLM<br>2. **FSM Storage:** Хранение состояний диалога бота<br>3. **Rate Limiting:** Подсчет количества запросов пользователя |
| **Main DB** | **PostgreSQL** | Хранение профилей пользователей, логов чатов, конфигураций коллекций и данных для биллинга |
| **LLM** | **DeepSeek** | Использование API внешних провайдеров |
## 4. API Интерфейс
Бэкенд предоставляет REST API, документированное по стандарту OpenAPI.
Полная swagger спецификация представлена в файлах вместе с данным документом (*AI_api.yaml*)
## 5. Какие задачи необходимо выполнить
- **Разработка backend (FastAPI)**: Создать API для обработки запросов, интеграции с БД, эмбеддингом и LLM. Включает DDD, Dependency Injection (dishka), асинхронность. Также сервис авторизации.
- **Telegram-бот (aiogram)**: Интеграция с API, авторизация, обработка сообщений, команд для коллекций и платежей.
- **Векторная БД (Qdrant)**: Настройка для хранения embeddings, коллекций (в сущностях DocumentCollection, Embeddings).
- **Индексатор и Поисковик**: Реализация предобработки, векторизации и семантического поиска.
- **Реранкер и Генератор**: Модули для реранкинга топ-k и генерации ответов с источниками.
- **Админ панель**: Простой интерфейс для добавления/удаления данных.
- **Тестирование**: Создание тестового датасета с hit@5 > 50%.
- **Кэширование**: Для ускорения похожих запросов.
- **Платежный сервис**: Простая подписка после 10 запросов.
- **Деплой**: Docker-контейнеры для FastAPI и бота.
## 6. Сколько времени потратится на каждую задачу
Оценки времени (в рабочих днях):
| Задача | Время (дни) |
| -------------------------------------- | ---------------- |
| Backend (FastAPI) | 5 |
| Telegram-бот | 4 |
| Векторная БД | 2 |
| Индексатор и Поисковик | 3 |
| Реранкер и Генератор | 4 |
| Админ панель | 2 |
| Тестирование | 2 |
| Датасет проверки качества | 2 |
| Кэширование | 2 |
| Платежный сервис | 2 |
| Деплой и интеграция | 3 |
| Эмбеддинг | 2 |
| **Итого** | **33** |
Общий срок — 4 недели (с учетом параллельной работы команды из 4 человек).
## 7. Кто что делает в команде
Распределение ролей в команде из 4 человек:
- **Developer 1 (Lead/Backend)**: Backend (FastAPI, DDD, DI), Векторная БД, Индексатор/Поисковик.
- **Developer 2 (AI/ML)**: Эмбеддинг, Реранкер, Генератор, Тестирование (датасет, метрики).
- **Developer 3 (Frontend/Bot)**: Telegram-бот, Админ панель, Платежный сервис.
- **Developer 4 (DevOps/Tester)**: Кэширование, Деплой (Docker), тестирование функционала.
Каждый использует DTO и интерфейсы для контрактов, что обеспечивает изоляцию компонентов и легкость изменений.
## Диаграмма компонентов системы
``` mermaid
graph TD
User((Пользователь))
Admin((Админ))
subgraph "Интерфейсы"
TGBot[Telegram Bot
Aiogram]
API_Endpoint[API Gateway
FastAPI]
end
subgraph "Бэкенд RAG"
Logic[Бизнес-логика
Services & RAG Pipeline]
end
subgraph "Базы данных"
PG[(PostgreSQL
Users, Logs)]
VDB[(Vector DB
Embeddings)]
Redis[(Redis
Cache)]
end
subgraph "Внешние API"
LLM[LLM API
DeepSeek]
Pay[Payment API
Yookassa]
end
User -->|Команды/Вопросы| TGBot
TGBot -->|REST| API_Endpoint
Admin -->|Загрузка документов| API_Endpoint
API_Endpoint --> Logic
Logic -->|Кэширование| Redis
Logic -->|Метаданные| PG
Logic -->|Поиск похожих| VDB
Logic -->|Генерация ответа| LLM
Logic -->|Транзакции| Pay
```
## Пайплайн RAG
``` mermaid
sequenceDiagram
autonumber
actor User as Юрист
participant System as Бот + RAG Сервис
participant VectorDB as Векторная БД
participant AI as LLM
Note over User, System: Шаг 1: Запрос
User->>System: "Статья 5.5, что это?"
Note over System, VectorDB: Шаг 2: Поиск фактов
System->>VectorDB: Ищет похожие статьи законов
VectorDB-->>System: Возвращает Топ-5 документов
Note over System, AI: Шаг 3: Генерация
System->>AI: Отправляет промпт
AI-->>System: Формирует ответ, опираясь на законы
Note over User, System: Шаг 4: Ответ
System-->>User: Текст консультации + Ссылки на статьи
```
## Схема индексации
``` mermaid
graph TD
Actor[Администратор / Юрист] -->|Загрузка документа| API[FastAPI]
API --> Input
subgraph "1. Предобработка"
Input[Файл PDF/DOCX] --> Extract(Извлечение текста)
Extract --> Clean{Очистка}
Clean -- Удаление шума --> Norm[Нормализация текста]
end
subgraph "2. Чанкирование"
Norm --> Splitter[Разбиение на фрагменты]
--> Chunks[Список чанков]
end
subgraph "3. Векторизация"
Chunks --> Model[Эмбеддинг модель
FRIDA]
Model --> Vectors[Векторные представления]
end
subgraph "4. Сохранение"
Vectors --> VDB[(Векторная БД
Qdrant)]
Chunks -->|Заголовок, дата, автор| MetaDB[(Postgres
Метаданные)]
end
```
## Схема БД
``` mermaid
erDiagram
USERS {
uuid user_id PK
string telegram_id
datetime created_at
string role "user / admin"
}
COLLECTIONS {
uuid collection_id PK
string name
text description
uuid owner_id FK
boolean is_public
datetime created_at
}
DOCUMENTS {
uuid document_id PK
uuid collection_id FK
string title
text content
json metadata
vector vector_embedding "Вектор всего документа"
datetime created_at
}
EMBEDDINGS {
uuid embedding_id PK
uuid document_id FK
vector embedding
string model_version
datetime created_at
}
CONVERSATIONS {
uuid conversation_id PK
uuid user_id FK
uuid collection_id FK
datetime created_at
datetime updated_at
}
MESSAGES {
uuid message_id PK
uuid conversation_id FK
text content
string role "user / ai"
json sources "Ссылки на использованные документы"
datetime created_at
}
USERS ||--o{ COLLECTIONS : "owner_id"
USERS ||--o{ CONVERSATIONS : "user_id"
COLLECTIONS ||--o{ DOCUMENTS : "collection_id"
COLLECTIONS ||--o{ CONVERSATIONS : "collection_id"
DOCUMENTS ||--o{ EMBEDDINGS : "document_id"
CONVERSATIONS ||--o{ MESSAGES : "conversation_id"
```

View File

@ -10,4 +10,7 @@ httpx==0.25.2
PyMuPDF==1.23.8 PyMuPDF==1.23.8
Pillow==10.2.0 Pillow==10.2.0
dishka==0.7.0 dishka==0.7.0
numpy==1.26.4
sentence-transformers==2.7.0
qdrant-client==1.9.0
redis==5.0.1

View File

View File

@ -0,0 +1,49 @@
import hashlib
from typing import Optional
from uuid import UUID
from src.infrastructure.external.redis_client import RedisClient
class CacheService:
def __init__(self, redis_client: RedisClient, default_ttl: int = 3600 * 24):
self.redis_client = redis_client
self.default_ttl = default_ttl
def _make_key(self, collection_id: UUID, question: str) -> str:
question_hash = hashlib.sha256(question.encode()).hexdigest()[:16]
return f"rag:answer:{collection_id}:{question_hash}"
async def get_cached_answer(self, collection_id: UUID, question: str) -> Optional[dict]:
key = self._make_key(collection_id, question)
cached = await self.redis_client.get_json(key)
if cached and cached.get("question") == question:
return cached.get("answer")
return None
async def cache_answer(self, collection_id: UUID, question: str, answer: dict, ttl: Optional[int] = None):
key = self._make_key(collection_id, question)
value = {
"question": question,
"answer": answer
}
await self.reids_client.set_json(key, value, ttl or self.default_ttl)
async def invalidate_collection_cache(self, collection_id: UUID):
pattern = f"rag:answer:{collection_id}:*"
keys = await self.redis_client.keys(pattern)
if keys:
for key in keys:
await self.redis_client.delete(key)
async def invalidate_document_cache(self, document_id: UUID):
pattern = f"rag:answer:*"
keys = await self.redis_client.keys(pattern)
if keys:
for key in keys:
cached = await self.redis_client.get_json(key)
if cached:
sources = cached.get("answer", {}).get("sources", [])
doc_ids = [s.get("document_id") for s in sources]
if str(document_id) in doc_ids:
await self.redis_client.delete(key)

View File

@ -0,0 +1,37 @@
"""
Сервис для вычисления эмбеддингов текстов
"""
from functools import lru_cache
from typing import Iterable
import numpy as np
from sentence_transformers import SentenceTransformer
class EmbeddingService:
def __init__(self, model_name: str | None = None):
self.model_name = model_name or "intfloat/multilingual-e5-base"
self._model = None
@property
def model(self) -> SentenceTransformer:
if self._model is None:
self._model = SentenceTransformer(self.model_name)
return self._model
def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
embeddings = self.model.encode(
list(texts),
batch_size=8,
show_progress_bar=False,
normalize_embeddings=True,
)
return [np.array(v, dtype=np.float32).tolist() for v in embeddings]
def embed_query(self, text: str) -> list[float]:
return self.embed_texts([text])[0]
@lru_cache(maxsize=1)
def model_version(self) -> str:
return self.model_name

View File

@ -0,0 +1,109 @@
"""
Сервис RAG: индексация, поиск, генерация ответа
"""
from typing import Sequence
from uuid import UUID
from src.application.services.text_splitter import TextSplitter
from src.application.services.embedding_service import EmbeddingService
from src.application.services.reranker_service import RerankerService
from src.domain.entities.document import Document
from src.domain.entities.chunk import DocumentChunk
from src.domain.repositories.vector_repository import IVectorRepository
from src.infrastructure.external.deepseek_client import DeepSeekClient
class RAGService:
def __init__(
self,
vector_repository: IVectorRepository,
embedding_service: EmbeddingService,
reranker_service: RerankerService,
deepseek_client: DeepSeekClient,
splitter: TextSplitter | None = None,
):
self.vector_repository = vector_repository
self.embedding_service = embedding_service
self.reranker_service = reranker_service
self.deepseek_client = deepseek_client
self.splitter = splitter or TextSplitter()
async def index_document(self, document: Document) -> list[DocumentChunk]:
chunks_text = self.splitter.split(document.content)
chunks: list[DocumentChunk] = []
for idx, text in enumerate(chunks_text):
chunks.append(
DocumentChunk(
document_id=document.document_id,
collection_id=document.collection_id,
content=text,
order=idx,
metadata={"title": document.title},
)
)
embeddings = self.embedding_service.embed_texts([c.content for c in chunks])
await self.vector_repository.upsert_chunks(
chunks, embeddings, model_version=self.embedding_service.model_version()
)
return chunks
async def retrieve(
self, query: str, collection_id: UUID, limit: int = 20, rerank_top_n: int = 5
) -> list[tuple[DocumentChunk, float]]:
query_embedding = self.embedding_service.embed_query(query)
candidates = await self.vector_repository.search(
query_embedding, collection_id=collection_id, limit=limit
)
if not candidates:
return []
passages = [c.content for c, _ in candidates]
order = self.reranker_service.rerank(query, passages, top_n=rerank_top_n)
return [candidates[i] for i in order if i < len(candidates)]
async def generate_answer(
self,
query: str,
context_chunks: Sequence[DocumentChunk],
max_tokens: int | None = 400,
temperature: float = 0.2,
) -> dict:
context_blocks = [
f"[{idx+1}] {c.content}\nИсточник: документ {c.metadata.get('title','')} (chunk {c.order})"
for idx, c in enumerate(context_chunks)
]
context = "\n\n".join(context_blocks)
system_prompt = (
"Ты юридический ассистент. Отвечай только на основе переданного контекста. "
"Обязательно добавляй ссылки на источники в формате [номер]. "
"Если ответа нет в контексте, скажи, что данных недостаточно."
)
messages = [
{"role": "system", "content": system_prompt},
{
"role": "user",
"content": f"Вопрос: {query}\n\nКонтекст:\n{context}",
},
]
resp = await self.deepseek_client.chat_completion(
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=False,
)
return {
"content": resp.get("content", ""),
"usage": resp.get("usage", {}),
"sources": [
{
"index": idx + 1,
"document_id": str(chunk.document_id),
"chunk_id": str(chunk.chunk_id),
"title": chunk.metadata.get("title", ""),
}
for idx, chunk in enumerate(context_chunks)
],
}

View File

@ -0,0 +1,44 @@
"""
Сервис реранкинга результатов поиска
"""
from typing import Sequence
import numpy as np
from sentence_transformers import CrossEncoder, SentenceTransformer
class RerankerService:
def __init__(
self,
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
fallback_encoder: SentenceTransformer | None = None,
):
self.model_name = model_name
self._model: CrossEncoder | None = None
self.fallback_encoder = fallback_encoder
@property
def model(self) -> CrossEncoder:
if self._model is None:
self._model = CrossEncoder(self.model_name)
return self._model
def rerank(self, query: str, passages: Sequence[str], top_n: int = 5) -> list[int]:
if not passages:
return []
try:
scores = self.model.predict([[query, p] for p in passages])
order = np.argsort(scores)[::-1]
return order[:top_n].tolist()
except Exception:
if self.fallback_encoder:
q_emb = self.fallback_encoder.encode(query, normalize_embeddings=True)
p_emb = self.fallback_encoder.encode(
list(passages), normalize_embeddings=True
)
sims = np.dot(p_emb, q_emb)
order = np.argsort(sims)[::-1]
return order[:top_n].tolist()
return list(range(min(top_n, len(passages))))

View File

@ -0,0 +1,43 @@
"""
Простой текстовый сплиттер для подготовки чанков
"""
import re
from typing import Iterable
class TextSplitter:
def __init__(self, chunk_size: int = 800, chunk_overlap: int = 200):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
def split(self, text: str) -> list[str]:
normalized = self._normalize(text)
if not normalized:
return []
sentences = self._split_sentences(normalized)
chunks: list[str] = []
current: list[str] = []
current_len = 0
for sent in sentences:
if current_len + len(sent) > self.chunk_size and current:
chunks.append(" ".join(current).strip())
while current and current_len > self.chunk_overlap:
popped = current.pop(0)
current_len -= len(popped)
current.append(sent)
current_len += len(sent)
if current:
chunks.append(" ".join(current).strip())
return [c for c in chunks if c]
def _normalize(self, text: str) -> str:
return re.sub(r"\s+", " ", text).strip()
def _split_sentences(self, text: str) -> Iterable[str]:
parts = re.split(r"(?<=[\.\?\!])\s+", text)
return [p.strip() for p in parts if p.strip()]

View File

@ -0,0 +1,92 @@
"""
Use cases для RAG: индексация документов и ответы на вопросы
"""
from uuid import UUID
from src.application.services.rag_service import RAGService
from src.application.services.cache_service import CacheService
from src.domain.repositories.document_repository import IDocumentRepository
from src.domain.repositories.conversation_repository import IConversationRepository
from src.domain.repositories.message_repository import IMessageRepository
from src.domain.entities.message import Message, MessageRole
from src.shared.exceptions import NotFoundError, ForbiddenError
class RAGUseCases:
def __init__(
self,
rag_service: RAGService,
document_repo: IDocumentRepository,
conversation_repo: IConversationRepository,
message_repo: IMessageRepository,
cache_service: CacheService,
):
self.rag_service = rag_service
self.document_repo = document_repo
self.conversation_repo = conversation_repo
self.message_repo = message_repo
self.cache_service = cache_service
async def index_document(self, document_id: UUID) -> dict:
document = await self.document_repo.get_by_id(document_id)
if not document:
raise NotFoundError(f"Документ {document_id} не найден")
chunks = await self.rag_service.index_document(document)
return {"chunks_indexed": len(chunks)}
async def ask_question(
self,
conversation_id: UUID,
user_id: UUID,
question: str,
top_k: int = 20,
rerank_top_n: int = 5,
) -> dict:
conversation = await self.conversation_repo.get_by_id(conversation_id)
if not conversation:
raise NotFoundError(f"Беседа {conversation_id} не найдена")
if conversation.user_id != user_id:
raise ForbiddenError("Нет доступа к этой беседе")
user_message = Message(
conversation_id=conversation_id, content=question, role=MessageRole.USER
)
await self.message_repo.create(user_message)
cached_answer = None
if self.cache_service:
cached_answer = await self.cache_service.get_cached_answer(conversation.collection_id, question)
if cached_answer:
generation = cached_answer
else:
retrieved = await self.rag_service.retrieve(
query=question,
collection_id=conversation.collection_id,
limit=top_k,
rerank_top_n=rerank_top_n,
)
chunks = [c for c, _ in retrieved]
generation = await self.rag_service.generate_answer(question, chunks)
if self.cache_service:
await self.cache_service.cache_answer(
conversation.collection_id,
question,
generation
)
assistant_message = Message(
conversation_id=conversation_id,
content=generation["content"],
role=MessageRole.ASSISTANT,
sources={"chunks": generation.get("sources", [])},
)
await self.message_repo.create(assistant_message)
return {
"answer": generation["content"],
"sources": generation.get("sources", []),
"usage": generation.get("usage", {}),
}

View File

@ -0,0 +1,28 @@
"""
Доменная сущность чанка
"""
from datetime import datetime
from uuid import UUID, uuid4
from typing import Any
class DocumentChunk:
def __init__(
self,
document_id: UUID,
collection_id: UUID,
content: str,
chunk_id: UUID | None = None,
order: int = 0,
metadata: dict[str, Any] | None = None,
created_at: datetime | None = None,
):
self.chunk_id = chunk_id or uuid4()
self.document_id = document_id
self.collection_id = collection_id
self.content = content
self.order = order
self.metadata = metadata or {}
self.created_at = created_at or datetime.utcnow()

View File

@ -0,0 +1,31 @@
"""
Интерфейс репозитория/хранилища векторов
"""
from abc import ABC, abstractmethod
from typing import Sequence
from uuid import UUID
from src.domain.entities.chunk import DocumentChunk
class IVectorRepository(ABC):
@abstractmethod
async def upsert_chunks(
self,
chunks: Sequence[DocumentChunk],
embeddings: Sequence[list[float]],
model_version: str,
) -> None:
"""Сохранить или обновить вектора чанков"""
raise NotImplementedError
@abstractmethod
async def search(
self,
query_embedding: list[float],
collection_id: UUID,
limit: int = 20,
) -> list[tuple[DocumentChunk, float]]:
"""Поиск ближайших чанков по коллекции с расстоянием"""
raise NotImplementedError

View File

@ -0,0 +1,76 @@
import json
from typing import Optional, Any
import redis.asyncio as aioredis
from src.shared.config import settings
class RedisClient:
def __init__(self, host: str, port: int):
self.host = host or settings.REDIS_HOST
self.port = port or settings.REDIS_PORT
self._client: Optional[aioredis.Redis] = None
async def connect(self):
if self._client is None:
self._client = await aioredis.from_url(
f"redis://{self.host}:{self.port}",
encoding="utf-8",
decode_responses=True
)
async def disconnect(self):
if self._client:
await self._client.aclose()
self._client = None
async def get(self, key: str) -> Optional[str]:
if self._client is None:
await self.connect()
return await self._client.get(key)
async def set(self, key: str, value: str, ttl: Optional[int] = None):
if self._client is None:
await self.connect()
if ttl:
await self._client.setex(key, ttl, value)
else:
await self._client.set(key, value)
async def get_json(self, key: str) -> Optional[dict[str, Any]]:
value = await self.get(key)
if value is None:
return None
try:
return json.loads(value)
except json.JSONDecodeError:
return None
async def set_json(self, key: str, value: dict[str, Any], ttl: Optional[int] = None):
json_str = json.dumps(value)
await self.set(key, json_str, ttl)
async def delete(self, key: str):
if self._client is None:
await self.connect()
await self._client.delete(key)
async def exists(self, key: str) -> bool:
if self._client is None:
await self.connect()
return bool(await self._client.exists(key))
async def incr(self, key: str) -> int:
if self._client is None:
await self.connect()
return await self._client.incr(key)
async def expire(self, key: str, seconds: int):
if self._client is None:
await self.connect()
await self._client.expire(key, seconds)
async def keys(self, pattern: str) -> list[str]:
if self._client is None:
await self.connect()
return await self._client.keys(pattern)

View File

@ -0,0 +1,4 @@
"""
Qdrant repositories
"""

View File

@ -0,0 +1,84 @@
"""
Qdrant реализация векторного хранилища
"""
from typing import Sequence
from uuid import UUID
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
from src.domain.entities.chunk import DocumentChunk
from src.domain.repositories.vector_repository import IVectorRepository
class QdrantVectorRepository(IVectorRepository):
def __init__(
self,
client: QdrantClient,
collection_name: str = "documents",
vector_size: int = 768,
):
self.client = client
self.collection_name = collection_name
self.vector_size = vector_size
self._ensure_collection()
def _ensure_collection(self) -> None:
"""Создает коллекцию при отсутствии"""
if self.collection_name in [c.name for c in self.client.get_collections().collections]:
return
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
)
async def upsert_chunks(
self,
chunks: Sequence[DocumentChunk],
embeddings: Sequence[list[float]],
model_version: str,
) -> None:
points = []
for chunk, vector in zip(chunks, embeddings):
points.append(
PointStruct(
id=str(chunk.chunk_id),
vector=vector,
payload={
"document_id": str(chunk.document_id),
"collection_id": str(chunk.collection_id),
"content": chunk.content,
"order": chunk.order,
"model_version": model_version,
"title": chunk.metadata.get("title", ""),
},
)
)
self.client.upsert(collection_name=self.collection_name, points=points)
async def search(
self,
query_embedding: list[float],
collection_id: UUID,
limit: int = 20,
) -> list[tuple[DocumentChunk, float]]:
res = self.client.search(
collection_name=self.collection_name,
query_vector=query_embedding,
query_filter=Filter(
must=[FieldCondition(key="collection_id", match=MatchValue(value=str(collection_id)))]
),
limit=limit,
)
results: list[tuple[DocumentChunk, float]] = []
for hit in res:
payload = hit.payload or {}
chunk = DocumentChunk(
document_id=UUID(payload["document_id"]),
collection_id=UUID(payload["collection_id"]),
content=payload.get("content", ""),
chunk_id=UUID(hit.id),
order=payload.get("order", 0),
metadata={"title": payload.get("title", ""), "model_version": payload.get("model_version", "")},
)
results.append((chunk, hit.score))
return results

View File

@ -0,0 +1,45 @@
"""
API для RAG: индексация документов и ответы на вопросы
"""
from fastapi import APIRouter, status
from dishka.integrations.fastapi import FromDishka
from src.presentation.schemas.rag_schemas import (
QuestionRequest,
RAGAnswer,
IndexDocumentRequest,
IndexDocumentResponse,
)
from src.application.use_cases.rag_use_cases import RAGUseCases
from src.domain.entities.user import User
router = APIRouter(prefix="/rag", tags=["rag"])
@router.post("/index", response_model=IndexDocumentResponse, status_code=status.HTTP_200_OK)
async def index_document(
body: IndexDocumentRequest,
use_cases: FromDishka[RAGUseCases] = FromDishka(),
current_user: FromDishka[User] = FromDishka(),
):
"""Индексирование идет через чанкирование, далее эмбеддинг и загрузка в векторную бд"""
result = await use_cases.index_document(body.document_id)
return IndexDocumentResponse(**result)
@router.post("/question", response_model=RAGAnswer, status_code=status.HTTP_200_OK)
async def ask_question(
body: QuestionRequest,
use_cases: FromDishka[RAGUseCases] = FromDishka(),
current_user: FromDishka[User] = FromDishka(),
):
"""Отвечает на вопрос, используя RAG в рамках беседы"""
result = await use_cases.ask_question(
conversation_id=body.conversation_id,
user_id=current_user.user_id,
question=body.question,
top_k=body.top_k,
rerank_top_n=body.rerank_top_n,
)
return RAGAnswer(**result)

View File

@ -1,6 +1,3 @@
"""
Главный файл FastAPI приложения
"""
import sys import sys
import os import os
@ -16,7 +13,7 @@ from src.shared.config import settings
from src.shared.exceptions import LawyerAIException from src.shared.exceptions import LawyerAIException
from src.shared.di_container import create_container from src.shared.di_container import create_container
from src.presentation.middleware.error_handler import exception_handler from src.presentation.middleware.error_handler import exception_handler
from src.presentation.api.v1 import users, collections, documents, conversations, messages from src.presentation.api.v1 import users, collections, documents, conversations, messages, rag
from src.infrastructure.database.base import engine, Base from src.infrastructure.database.base import engine, Base
@ -57,6 +54,7 @@ app.include_router(collections.router, prefix="/api/v1")
app.include_router(documents.router, prefix="/api/v1") app.include_router(documents.router, prefix="/api/v1")
app.include_router(conversations.router, prefix="/api/v1") app.include_router(conversations.router, prefix="/api/v1")
app.include_router(messages.router, prefix="/api/v1") app.include_router(messages.router, prefix="/api/v1")
app.include_router(rag.router, prefix="/api/v1")
try: try:
from src.presentation.api.v1 import admin from src.presentation.api.v1 import admin

View File

@ -0,0 +1,35 @@
"""
Схемы для RAG
"""
from uuid import UUID
from pydantic import BaseModel, Field
from typing import List, Any
class QuestionRequest(BaseModel):
conversation_id: UUID
question: str = Field(..., min_length=3)
top_k: int = 20
rerank_top_n: int = 5
class RAGSource(BaseModel):
index: int
document_id: str
chunk_id: str
title: str | None = None
class RAGAnswer(BaseModel):
answer: str
sources: List[RAGSource] = []
usage: dict[str, Any] = {}
class IndexDocumentRequest(BaseModel):
document_id: UUID
class IndexDocumentResponse(BaseModel):
chunks_indexed: int

View File

@ -1,6 +1,3 @@
"""
DI контейнер на основе dishka
"""
from dishka import Container, Provider, Scope, provide from dishka import Container, Provider, Scope, provide
from fastapi import Request from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -19,16 +16,26 @@ from src.domain.repositories.document_repository import IDocumentRepository
from src.domain.repositories.conversation_repository import IConversationRepository from src.domain.repositories.conversation_repository import IConversationRepository
from src.domain.repositories.message_repository import IMessageRepository from src.domain.repositories.message_repository import IMessageRepository
from src.domain.repositories.collection_access_repository import ICollectionAccessRepository from src.domain.repositories.collection_access_repository import ICollectionAccessRepository
from src.domain.repositories.vector_repository import IVectorRepository
from src.infrastructure.external.yandex_ocr import YandexOCRService from src.infrastructure.external.yandex_ocr import YandexOCRService
from src.infrastructure.external.deepseek_client import DeepSeekClient from src.infrastructure.external.deepseek_client import DeepSeekClient
from src.infrastructure.external.redis_client import RedisClient
from src.application.services.document_parser_service import DocumentParserService from src.application.services.document_parser_service import DocumentParserService
from src.application.services.cache_service import CacheService
from src.application.use_cases.user_use_cases import UserUseCases from src.application.use_cases.user_use_cases import UserUseCases
from src.application.use_cases.collection_use_cases import CollectionUseCases from src.application.use_cases.collection_use_cases import CollectionUseCases
from src.application.use_cases.document_use_cases import DocumentUseCases from src.application.use_cases.document_use_cases import DocumentUseCases
from src.application.use_cases.conversation_use_cases import ConversationUseCases from src.application.use_cases.conversation_use_cases import ConversationUseCases
from src.application.use_cases.message_use_cases import MessageUseCases from src.application.use_cases.message_use_cases import MessageUseCases
from src.domain.entities.user import User from src.domain.entities.user import User
from src.shared.config import settings
from qdrant_client import QdrantClient
from src.infrastructure.repositories.qdrant.vector_repository import QdrantVectorRepository
from src.application.services.embedding_service import EmbeddingService
from src.application.services.reranker_service import RerankerService
from src.application.services.rag_service import RAGService
from src.application.services.text_splitter import TextSplitter
from src.application.use_cases.rag_use_cases import RAGUseCases
class DatabaseProvider(Provider): class DatabaseProvider(Provider):
@provide(scope=Scope.REQUEST) @provide(scope=Scope.REQUEST)
@ -68,6 +75,14 @@ class RepositoryProvider(Provider):
class ServiceProvider(Provider): class ServiceProvider(Provider):
@provide(scope=Scope.APP)
def get_redis_client(self) -> RedisClient:
return RedisClient()
@provide(scope=Scope.APP)
def get_cache_service(self, redis_client: RedisClient) -> CacheService:
return CacheService(redis_client)
@provide(scope=Scope.APP) @provide(scope=Scope.APP)
def get_ocr_service(self) -> YandexOCRService: def get_ocr_service(self) -> YandexOCRService:
return YandexOCRService() return YandexOCRService()
@ -81,6 +96,44 @@ class ServiceProvider(Provider):
return DocumentParserService(ocr_service) return DocumentParserService(ocr_service)
class VectorServiceProvider(Provider):
@provide(scope=Scope.APP)
def get_qdrant_client(self) -> QdrantClient:
return QdrantClient(host=settings.QDRANT_HOST, port=settings.QDRANT_PORT)
@provide(scope=Scope.APP)
def get_vector_repository(self, client: QdrantClient) -> IVectorRepository:
return QdrantVectorRepository(client=client, vector_size=768)
@provide(scope=Scope.APP)
def get_embedding_service(self) -> EmbeddingService:
return EmbeddingService()
@provide(scope=Scope.APP)
def get_reranker_service(self, embedding_service: EmbeddingService) -> RerankerService:
return RerankerService(fallback_encoder=embedding_service.model)
@provide(scope=Scope.APP)
def get_text_splitter(self) -> TextSplitter:
return TextSplitter()
@provide(scope=Scope.APP)
def get_rag_service(
self,
vector_repo: IVectorRepository,
embedding_service: EmbeddingService,
reranker_service: RerankerService,
deepseek_client: DeepSeekClient,
text_splitter: TextSplitter
) -> RAGService:
return RAGService(
vector_repository=vector_repo,
embedding_service=embedding_service,
reranker_service=reranker_service,
deepseek_client=deepseek_client,
splitter=text_splitter,
)
class AuthProvider(Provider): class AuthProvider(Provider):
@provide(scope=Scope.REQUEST) @provide(scope=Scope.REQUEST)
async def get_current_user(self, request: Request, user_repo: IUserRepository) -> User: async def get_current_user(self, request: Request, user_repo: IUserRepository) -> User:
@ -131,6 +184,17 @@ class UseCaseProvider(Provider):
) -> MessageUseCases: ) -> MessageUseCases:
return MessageUseCases(message_repo, conversation_repo) return MessageUseCases(message_repo, conversation_repo)
@provide(scope=Scope.REQUEST)
def get_rag_use_cases(
self,
rag_service: RAGService,
document_repo: IDocumentRepository,
conversation_repo: IConversationRepository,
message_repo: IMessageRepository,
cache_service: CacheService
) -> RAGUseCases:
return RAGUseCases(rag_service, document_repo, conversation_repo, message_repo, cache_service)
def create_container() -> Container: def create_container() -> Container:
container = Container() container = Container()
@ -139,5 +203,6 @@ def create_container() -> Container:
container.add_provider(ServiceProvider()) container.add_provider(ServiceProvider())
container.add_provider(AuthProvider()) container.add_provider(AuthProvider())
container.add_provider(UseCaseProvider()) container.add_provider(UseCaseProvider())
container.add_provider(VectorServiceProvider())
return container return container