Compare commits

..

46 Commits

Author SHA1 Message Date
b0bbc739f3 update swagger 2025-12-24 18:45:38 +03:00
42fcc0eb16 Забыл батчить, теперь ок 2025-12-24 16:17:50 +03:00
683f779c31 UTF 8 вместо абракадабры 2025-12-24 15:55:49 +03:00
ef71c67683 styling add 2025-12-24 15:49:39 +03:00
570f0b7ea7 Messages for indexing 2025-12-24 15:45:35 +03:00
1b550e6503 drone 3 2025-12-24 15:31:04 +03:00
66392765b9 fix drone v2 2025-12-24 15:27:57 +03:00
8cbc318c33 fix build 2025-12-24 15:27:02 +03:00
908e8fc435 fix drone 2025-12-24 15:21:30 +03:00
5a46194d41 fix import 2025-12-24 15:16:27 +03:00
0b04ffefd2 fix 2025-12-24 14:26:15 +03:00
5264b3b64c fix 2025-12-24 13:53:55 +03:00
5809ac5688 fix build 2025-12-24 13:50:53 +03:00
8bdacb4f7a * Add migration
* Delete legacy from bot
* Clear old models
* Единый http клиент
* РАГ полечен
2025-12-24 13:44:52 +03:00
1ce1c23d10 delete --no-cache-dir for libs 2025-12-24 11:09:25 +03:00
5da6c32722 fix name 2025-12-24 11:04:44 +03:00
6b768261e2 test 4 2025-12-24 11:01:35 +03:00
6934220b52 test 3 2025-12-24 10:59:57 +03:00
79980eb313 fuck drone 2025-12-24 10:58:55 +03:00
9f111ad2c2 test drone 2025-12-24 10:57:27 +03:00
7b7165a44b test2 2025-12-24 10:41:16 +03:00
193deb7a8c test 2025-12-24 10:37:13 +03:00
49c3d1b0fd Merge pull request 'andrewbokh' (#6) from andrewbokh into main
Reviewed-on: HSE_team/BetterCallPraskovia#6
2025-12-24 10:36:02 +03:00
c4b3521257 added admin panel 2025-12-24 06:28:01 +03:00
169d874dad fixed bot and server connectivity issues 2025-12-24 04:38:38 +03:00
dfc188e179 fixed DI 2025-12-24 03:14:37 +03:00
493c385cb1 temp 2025-12-23 22:20:42 +03:00
Arxip222
71e8d1079e Merge branch 'luluka' 2025-12-23 12:42:05 +03:00
Arxip222
a7fc2487e9 micro fix dishka 2025-12-23 12:25:46 +03:00
Arxip222
b504bb26c8 hot fix: add type subscription 2025-12-23 12:23:22 +03:00
Arxip222
1f0a5e5159 bot fix 2025-12-23 12:20:09 +03:00
Arxip222
09dfe46a5b hot fx 2 2025-12-23 12:16:19 +03:00
cd08f88434 tests 2025-12-23 12:08:28 +03:00
Arxip222
0bc47a9e7f hot fix 2025-12-23 11:52:02 +03:00
Arxip222
93cf04a1cf move requirements to bot 2025-12-23 11:47:34 +03:00
5c8e07e7f1 Merge pull request 'fix secrets in config' (#4) from arkhip into main
Reviewed-on: HSE_team/BetterCallPraskovia#4
2025-12-23 00:56:57 +03:00
Arxip222
3800f9b554 fix secrets in config 2025-12-23 00:11:57 +03:00
d25feb8d2d Merge pull request 'arkhip' (#3) from arkhip into main
Reviewed-on: HSE_team/BetterCallPraskovia#3
2025-12-22 23:40:02 +03:00
a9099335a3 Merge pull request 'luluka' (#2) from luluka into main
Reviewed-on: HSE_team/BetterCallPraskovia#2
2025-12-22 23:39:26 +03:00
c210c4a3c5 docker + docker compose + ignore + drone 2025-12-22 23:33:30 +03:00
Arxip222
74510ce406 redis include + client with bis logic 2025-12-22 22:57:14 +03:00
6c730918ec Merge remote-tracking branch 'origin/main' into luluka 2025-12-22 21:52:12 +03:00
4a9fab7fba Merge pull request 'polina_tg' (#1) from polina_tg into main
Reviewed-on: HSE_team/BetterCallPraskovia#1
2025-12-22 21:41:08 +03:00
Arxip222
d18cc1fb76 to 22 2025-12-22 13:41:09 +03:00
84556dd220 docker compose 2025-12-21 23:23:50 +03:00
4649882c27 docker backend 2025-12-21 22:52:49 +03:00
87 changed files with 7026 additions and 1852 deletions

35
.dockerignore Normal file
View File

@ -0,0 +1,35 @@
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
venv/
env/
ENV/
.venv/
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store
.git/
.gitignore
.gitattributes
Dockerfile*
docker-compose*.yml
.dockerignore
drone.yml
tmp/
temp/
*.tmp
Thumbs.db
.DS_Store

34
.drone.yml Normal file
View File

@ -0,0 +1,34 @@
kind: pipeline
type: docker
name: deploy-backend
trigger:
event:
- push
- pull_request
- tag
branch:
- main
steps:
- name: deploy-backend
image: appleboy/drone-ssh
timeout: 30m
settings:
host:
from_secret: server_host
username:
from_secret: server_username
password:
from_secret: server_password
port: 22
command_timeout: 30m
script:
- cd BetterCallPraskovia
- git pull origin main
- docker-compose stop backend tg_bot
- docker-compose rm -f backend tg_bot
- docker-compose build backend tg_bot
- docker-compose up -d --no-deps backend tg_bot
- docker image prune -f

1
.gitignore vendored
View File

@ -1,4 +1,3 @@
# Python
__pycache__/ __pycache__/
*.pyc *.pyc
*.pyo *.pyo

1509
AI_api.yaml Normal file

File diff suppressed because it is too large Load Diff

248
README.md Normal file
View File

@ -0,0 +1,248 @@
# Техническая документация проекта: Юридический AI-ассистент (RAG Система)
## 1. Концепция и методология RAG
**Что такое RAG в данном проекте?**
Мы не просто отправляем вопрос в ChatGPT. Мы реализуем архитектуру **Retrieval-Augmented Generation**. Это необходимо для того, чтобы LLM не галлюцинировала законы, а отвечала строго по тексту загруженных документов.
**Принцип работы:**
1. **Поиск:** Запрос пользователя превращается в вектор. В базе данных находятся фрагменты законов, семантически близкие к запросу
2. **Дополнение:** Найденные фрагменты подклеиваются к промпту пользователя в качестве контекста
3. **Генерация:** LLM генерирует ответ, используя только предоставленный контекст, и проставляет ссылки
## 2. Функциональность
Пользователь (юрист или сотрудник компании) взаимодействует с системой через Telegram-бот. Примеры интерфейса:
- **Отправка запроса**: Пользователь пишет сообщение, например: "Найди все договоры поставки с компанией ООО "Ромашка" или "Какая ответственность за неуплату налогов ИП?". Бот отвечает на естественном языке, с точными фактами и обязательными ссылками на источники (например, "Согласно статье 122 Налогового кодекса РФ")
- **Выбор коллекции**: Пользователь может выбрать доступную его компании коллекцию документов (список доступных коллекций отображается) и далее задавать вопросы в рамках нее
- **Создание коллекций**: Любой пользователь может создать собственную коллекцию через админпанель и выбрать логины пользователей, которые имеют право ей пользоваться
- **Добавление данных**: Админы коллекций через панель могут добавлять документы в свою коллекцию
- **Подписка**: После 10 бесплатных запросов бот предлагает оплатить подписку через встроенный платежный сервис
- **Источники**: Каждый факт в ответе сопровождается ссылкой, например, "Документ: Налоговый кодекс РФ, статья 122"
- **Просмотр коллекций**: Пользователь может просмотреть список с именами и описаниями доступных коллекций
## 3. Стек технологий и обоснование
| Компонент | Технология | Обоснование выбора |
| :--- | :--- | :--- |
| **Backend API** | **FastAPI** | Асинхронность для работы с LLM стримами, легкая интеграция DI, валидация аднных |
| **Telegram Bot** | **Aiogram** | Асинхронный, нативная поддержка машин состояний, мощный механизм мидлвара для логирования и авторизации. |
| **Vector DB** | **Qdrant** | Высокопроизводительная БД на Rust. Поддерживает гибридный поиск и фильтрацию по метаданным из коробки. |
| **Cache & Bus** | **Redis** | 1. **Кэширование RAG:** Хранение пар вопрос-ответ для экономии денег на LLM<br>2. **FSM Storage:** Хранение состояний диалога бота<br>3. **Rate Limiting:** Подсчет количества запросов пользователя |
| **Main DB** | **PostgreSQL** | Хранение профилей пользователей, логов чатов, конфигураций коллекций и данных для биллинга |
| **LLM** | **DeepSeek** | Использование API внешних провайдеров |
## 4. API Интерфейс
Бэкенд предоставляет REST API, документированное по стандарту OpenAPI.
Полная swagger спецификация представлена в файлах вместе с данным документом (*AI_api.yaml*)
## 5. Какие задачи необходимо выполнить
- **Разработка backend (FastAPI)**: Создать API для обработки запросов, интеграции с БД, эмбеддингом и LLM. Включает DDD, Dependency Injection (dishka), асинхронность. Также сервис авторизации.
- **Telegram-бот (aiogram)**: Интеграция с API, авторизация, обработка сообщений, команд для коллекций и платежей.
- **Векторная БД (Qdrant)**: Настройка для хранения embeddings, коллекций (в сущностях DocumentCollection, Embeddings).
- **Индексатор и Поисковик**: Реализация предобработки, векторизации и семантического поиска.
- **Реранкер и Генератор**: Модули для реранкинга топ-k и генерации ответов с источниками.
- **Админ панель**: Простой интерфейс для добавления/удаления данных.
- **Тестирование**: Создание тестового датасета с hit@5 > 50%.
- **Кэширование**: Для ускорения похожих запросов.
- **Платежный сервис**: Простая подписка после 10 запросов.
- **Деплой**: Docker-контейнеры для FastAPI и бота.
## 6. Сколько времени потратится на каждую задачу
Оценки времени (в рабочих днях):
| Задача | Время (дни) |
| -------------------------------------- | ---------------- |
| Backend (FastAPI) | 5 |
| Telegram-бот | 4 |
| Векторная БД | 2 |
| Индексатор и Поисковик | 3 |
| Реранкер и Генератор | 4 |
| Админ панель | 2 |
| Тестирование | 2 |
| Датасет проверки качества | 2 |
| Кэширование | 2 |
| Платежный сервис | 2 |
| Деплой и интеграция | 3 |
| Эмбеддинг | 2 |
| **Итого** | **33** |
Общий срок — 4 недели (с учетом параллельной работы команды из 4 человек).
## 7. Кто что делает в команде
Распределение ролей в команде из 4 человек:
- **Developer 1 (Lead/Backend)**: Backend (FastAPI, DDD, DI), Векторная БД, Индексатор/Поисковик.
- **Developer 2 (AI/ML)**: Эмбеддинг, Реранкер, Генератор, Тестирование (датасет, метрики).
- **Developer 3 (Frontend/Bot)**: Telegram-бот, Админ панель, Платежный сервис.
- **Developer 4 (DevOps/Tester)**: Кэширование, Деплой (Docker), тестирование функционала.
Каждый использует DTO и интерфейсы для контрактов, что обеспечивает изоляцию компонентов и легкость изменений.
## Диаграмма компонентов системы
``` mermaid
graph TD
User((Пользователь))
Admin((Админ))
subgraph "Интерфейсы"
TGBot[Telegram Bot
Aiogram]
API_Endpoint[API Gateway
FastAPI]
end
subgraph "Бэкенд RAG"
Logic[Бизнес-логика
Services & RAG Pipeline]
end
subgraph "Базы данных"
PG[(PostgreSQL
Users, Logs)]
VDB[(Vector DB
Embeddings)]
Redis[(Redis
Cache)]
end
subgraph "Внешние API"
LLM[LLM API
DeepSeek]
Pay[Payment API
Yookassa]
end
User -->|Команды/Вопросы| TGBot
TGBot -->|REST| API_Endpoint
Admin -->|Загрузка документов| API_Endpoint
API_Endpoint --> Logic
Logic -->|Кэширование| Redis
Logic -->|Метаданные| PG
Logic -->|Поиск похожих| VDB
Logic -->|Генерация ответа| LLM
Logic -->|Транзакции| Pay
```
## Пайплайн RAG
``` mermaid
sequenceDiagram
autonumber
actor User as Юрист
participant System as Бот + RAG Сервис
participant VectorDB as Векторная БД
participant AI as LLM
Note over User, System: Шаг 1: Запрос
User->>System: "Статья 5.5, что это?"
Note over System, VectorDB: Шаг 2: Поиск фактов
System->>VectorDB: Ищет похожие статьи законов
VectorDB-->>System: Возвращает Топ-5 документов
Note over System, AI: Шаг 3: Генерация
System->>AI: Отправляет промпт
AI-->>System: Формирует ответ, опираясь на законы
Note over User, System: Шаг 4: Ответ
System-->>User: Текст консультации + Ссылки на статьи
```
## Схема индексации
``` mermaid
graph TD
Actor[Администратор / Юрист] -->|Загрузка документа| API[FastAPI]
API --> Input
subgraph "1. Предобработка"
Input[Файл PDF/DOCX] --> Extract(Извлечение текста)
Extract --> Clean{Очистка}
Clean -- Удаление шума --> Norm[Нормализация текста]
end
subgraph "2. Чанкирование"
Norm --> Splitter[Разбиение на фрагменты]
--> Chunks[Список чанков]
end
subgraph "3. Векторизация"
Chunks --> Model[Эмбеддинг модель
FRIDA]
Model --> Vectors[Векторные представления]
end
subgraph "4. Сохранение"
Vectors --> VDB[(Векторная БД
Qdrant)]
Chunks -->|Заголовок, дата, автор| MetaDB[(Postgres
Метаданные)]
end
```
## Схема БД
``` mermaid
erDiagram
USERS {
uuid user_id PK
string telegram_id
datetime created_at
string role "user / admin"
}
COLLECTIONS {
uuid collection_id PK
string name
text description
uuid owner_id FK
boolean is_public
datetime created_at
}
DOCUMENTS {
uuid document_id PK
uuid collection_id FK
string title
text content
json metadata
vector vector_embedding "Вектор всего документа"
datetime created_at
}
EMBEDDINGS {
uuid embedding_id PK
uuid document_id FK
vector embedding
string model_version
datetime created_at
}
CONVERSATIONS {
uuid conversation_id PK
uuid user_id FK
uuid collection_id FK
datetime created_at
datetime updated_at
}
MESSAGES {
uuid message_id PK
uuid conversation_id FK
text content
string role "user / ai"
json sources "Ссылки на использованные документы"
datetime created_at
}
USERS ||--o{ COLLECTIONS : "owner_id"
USERS ||--o{ CONVERSATIONS : "user_id"
COLLECTIONS ||--o{ DOCUMENTS : "collection_id"
COLLECTIONS ||--o{ CONVERSATIONS : "collection_id"
DOCUMENTS ||--o{ EMBEDDINGS : "document_id"
CONVERSATIONS ||--o{ MESSAGES : "conversation_id"
```

34
backend/.dockerignore Normal file
View File

@ -0,0 +1,34 @@
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
venv/
env/
ENV/
.venv/
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store
.git/
.gitignore
.gitattributes
Dockerfile*
docker-compose*.yml
.dockerignore
drone.yml
tmp/
temp/
*.tmp
Thumbs.db
.DS_Store

21
backend/Dockerfile Normal file
View File

@ -0,0 +1,21 @@
FROM python:3.11-slim
WORKDIR /app
RUN apt-get update && apt-get install -y \
gcc \
postgresql-client \
&& rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY . .
ENV PYTHONPATH=/app
ENV PYTHONUNBUFFERED=1
EXPOSE 8000
CMD ["python", "-m", "uvicorn", "src.presentation.main:app", "--host", "0.0.0.0", "--port", "8000", "--log-level", "info"]

View File

@ -0,0 +1,28 @@
"""Add premium fields to users
Revision ID: 002
Revises: 001
Create Date: 2024-01-02 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = '002'
down_revision = '001'
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column('users', sa.Column('is_premium', sa.Boolean(), nullable=False, server_default='false'))
op.add_column('users', sa.Column('premium_until', sa.DateTime(), nullable=True))
op.add_column('users', sa.Column('questions_used', sa.Integer(), nullable=False, server_default='0'))
def downgrade() -> None:
op.drop_column('users', 'questions_used')
op.drop_column('users', 'premium_until')
op.drop_column('users', 'is_premium')

View File

@ -0,0 +1,33 @@
"""Remove unused embeddings table
Revision ID: 003
Revises: 002
Create Date: 2024-12-24 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = '003'
down_revision = '002'
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_table('embeddings')
def downgrade() -> None:
op.create_table(
'embeddings',
sa.Column('embedding_id', sa.dialects.postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('document_id', sa.dialects.postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('embedding', sa.dialects.postgresql.JSON(astext_type=sa.Text()), nullable=True),
sa.Column('model_version', sa.String(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['document_id'], ['documents.document_id'], ),
sa.PrimaryKeyConstraint('embedding_id')
)

View File

@ -1,4 +1,4 @@
fastapi==0.104.1 fastapi==0.100.1
uvicorn[standard]==0.24.0 uvicorn[standard]==0.24.0
sqlalchemy[asyncio]==2.0.23 sqlalchemy[asyncio]==2.0.23
asyncpg==0.29.0 asyncpg==0.29.0
@ -10,4 +10,7 @@ httpx==0.25.2
PyMuPDF==1.23.8 PyMuPDF==1.23.8
Pillow==10.2.0 Pillow==10.2.0
dishka==0.7.0 dishka==0.7.0
numpy==1.26.4
sentence-transformers==2.7.0
qdrant-client==1.9.0
redis==5.0.1

View File

@ -0,0 +1,23 @@
import sys
import os
from pathlib import Path
backend_dir = Path(__file__).parent
sys.path.insert(0, str(backend_dir))
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"src.presentation.main:app",
host="0.0.0.0",
port=8000,
reload=True,
log_level="info"
)

View File

@ -0,0 +1,49 @@
import hashlib
from typing import Optional
from uuid import UUID
from src.infrastructure.external.redis_client import RedisClient
class CacheService:
def __init__(self, redis_client: RedisClient, default_ttl: int = 3600 * 24):
self.redis_client = redis_client
self.default_ttl = default_ttl
def _make_key(self, collection_id: UUID, question: str) -> str:
question_hash = hashlib.sha256(question.encode()).hexdigest()[:16]
return f"rag:answer:{collection_id}:{question_hash}"
async def get_cached_answer(self, collection_id: UUID, question: str) -> Optional[dict]:
key = self._make_key(collection_id, question)
cached = await self.redis_client.get_json(key)
if cached and cached.get("question") == question:
return cached.get("answer")
return None
async def cache_answer(self, collection_id: UUID, question: str, answer: dict, ttl: Optional[int] = None):
key = self._make_key(collection_id, question)
value = {
"question": question,
"answer": answer
}
await self.redis_client.set_json(key, value, ttl or self.default_ttl)
async def invalidate_collection_cache(self, collection_id: UUID):
pattern = f"rag:answer:{collection_id}:*"
keys = await self.redis_client.keys(pattern)
if keys:
for key in keys:
await self.redis_client.delete(key)
async def invalidate_document_cache(self, document_id: UUID):
pattern = f"rag:answer:*"
keys = await self.redis_client.keys(pattern)
if keys:
for key in keys:
cached = await self.redis_client.get_json(key)
if cached:
sources = cached.get("answer", {}).get("sources", [])
doc_ids = [s.get("document_id") for s in sources]
if str(document_id) in doc_ids:
await self.redis_client.delete(key)

View File

@ -67,6 +67,8 @@ class DocumentParserService:
return title, content return title, content
except YandexOCRError: except YandexOCRError:
raise raise
except Exception as e: except Exception as e:
raise YandexOCRError(f"Ошибка при парсинге изображения: {str(e)}") from e raise YandexOCRError(f"Ошибка при парсинге изображения: {str(e)}") from e

View File

@ -0,0 +1,37 @@
"""
Сервис для вычисления эмбеддингов текстов
"""
from functools import lru_cache
from typing import Iterable
import numpy as np
from sentence_transformers import SentenceTransformer
class EmbeddingService:
def __init__(self, model_name: str | None = None):
self.model_name = model_name or "intfloat/multilingual-e5-base"
self._model = None
@property
def model(self) -> SentenceTransformer:
if self._model is None:
self._model = SentenceTransformer(self.model_name)
return self._model
def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
embeddings = self.model.encode(
list(texts),
batch_size=8,
show_progress_bar=False,
normalize_embeddings=True,
)
return [np.array(v, dtype=np.float32).tolist() for v in embeddings]
def embed_query(self, text: str) -> list[float]:
return self.embed_texts([text])[0]
@lru_cache(maxsize=1)
def model_version(self) -> str:
return self.model_name

View File

@ -0,0 +1,118 @@
"""
Сервис RAG: индексация, поиск, генерация ответа
"""
from typing import Sequence
from uuid import UUID
from src.application.services.text_splitter import TextSplitter
from src.application.services.embedding_service import EmbeddingService
from src.application.services.reranker_service import RerankerService
from src.domain.entities.document import Document
from src.domain.entities.chunk import DocumentChunk
from src.domain.repositories.vector_repository import IVectorRepository
from src.infrastructure.external.deepseek_client import DeepSeekClient
class RAGService:
def __init__(
self,
vector_repository: IVectorRepository,
embedding_service: EmbeddingService,
reranker_service: RerankerService,
deepseek_client: DeepSeekClient,
splitter: TextSplitter | None = None,
):
self.vector_repository = vector_repository
self.embedding_service = embedding_service
self.reranker_service = reranker_service
self.deepseek_client = deepseek_client
self.splitter = splitter or TextSplitter()
async def index_document(self, document: Document) -> list[DocumentChunk]:
chunks_text = self.splitter.split(document.content)
chunks: list[DocumentChunk] = []
for idx, text in enumerate(chunks_text):
chunks.append(
DocumentChunk(
document_id=document.document_id,
collection_id=document.collection_id,
content=text,
order=idx,
metadata={"title": document.title},
)
)
EMBEDDING_BATCH_SIZE = 50
all_embeddings: list[list[float]] = []
for i in range(0, len(chunks), EMBEDDING_BATCH_SIZE):
batch_chunks = chunks[i:i + EMBEDDING_BATCH_SIZE]
batch_texts = [c.content for c in batch_chunks]
batch_embeddings = self.embedding_service.embed_texts(batch_texts)
all_embeddings.extend(batch_embeddings)
print(f"Created {len(all_embeddings)} embeddings, upserting to Qdrant...")
await self.vector_repository.upsert_chunks(
chunks, all_embeddings, model_version=self.embedding_service.model_version()
)
return chunks
async def retrieve(
self, query: str, collection_id: UUID, limit: int = 20, rerank_top_n: int = 5
) -> list[tuple[DocumentChunk, float]]:
query_embedding = self.embedding_service.embed_query(query)
candidates = await self.vector_repository.search(
query_embedding, collection_id=collection_id, limit=limit
)
if not candidates:
return []
passages = [c.content for c, _ in candidates]
order = self.reranker_service.rerank(query, passages, top_n=rerank_top_n)
return [candidates[i] for i in order if i < len(candidates)]
async def generate_answer(
self,
query: str,
context_chunks: Sequence[DocumentChunk],
max_tokens: int | None = 400,
temperature: float = 0.2,
) -> dict:
context_blocks = [
f"[{idx+1}] {c.content}\nИсточник: документ {c.metadata.get('title','')} (chunk {c.order})"
for idx, c in enumerate(context_chunks)
]
context = "\n\n".join(context_blocks)
system_prompt = (
"Ты юридический ассистент. Отвечай только на основе переданного контекста. "
"Обязательно добавляй ссылки на источники в формате [номер]. "
"Если ответа нет в контексте, скажи, что данных недостаточно."
)
messages = [
{"role": "system", "content": system_prompt},
{
"role": "user",
"content": f"Вопрос: {query}\n\nКонтекст:\n{context}",
},
]
resp = await self.deepseek_client.chat_completion(
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=False,
)
return {
"content": resp.get("content", ""),
"usage": resp.get("usage", {}),
"sources": [
{
"index": idx + 1,
"document_id": str(chunk.document_id),
"chunk_id": str(chunk.chunk_id),
"title": chunk.metadata.get("title", ""),
}
for idx, chunk in enumerate(context_chunks)
],
}

View File

@ -0,0 +1,44 @@
"""
Сервис реранкинга результатов поиска
"""
from typing import Sequence
import numpy as np
from sentence_transformers import CrossEncoder, SentenceTransformer
class RerankerService:
def __init__(
self,
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
fallback_encoder: SentenceTransformer | None = None,
):
self.model_name = model_name
self._model: CrossEncoder | None = None
self.fallback_encoder = fallback_encoder
@property
def model(self) -> CrossEncoder:
if self._model is None:
self._model = CrossEncoder(self.model_name)
return self._model
def rerank(self, query: str, passages: Sequence[str], top_n: int = 5) -> list[int]:
if not passages:
return []
try:
scores = self.model.predict([[query, p] for p in passages])
order = np.argsort(scores)[::-1]
return order[:top_n].tolist()
except Exception:
if self.fallback_encoder:
q_emb = self.fallback_encoder.encode(query, normalize_embeddings=True)
p_emb = self.fallback_encoder.encode(
list(passages), normalize_embeddings=True
)
sims = np.dot(p_emb, q_emb)
order = np.argsort(sims)[::-1]
return order[:top_n].tolist()
return list(range(min(top_n, len(passages))))

View File

@ -0,0 +1,48 @@
"""
Простой текстовый сплиттер для подготовки чанков
"""
import re
from typing import Iterable
class TextSplitter:
def __init__(self, chunk_size: int = 800, chunk_overlap: int = 200):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
def split(self, text: str) -> list[str]:
normalized = self._normalize(text)
if not normalized:
return []
sentences = self._split_sentences(normalized)
chunks: list[str] = []
current: list[str] = []
current_len = 0
for sent in sentences:
if current_len + len(sent) > self.chunk_size and current:
chunks.append(" ".join(current).strip())
while current and current_len > self.chunk_overlap:
popped = current.pop(0)
current_len -= len(popped)
current.append(sent)
current_len += len(sent)
if current:
chunks.append(" ".join(current).strip())
return [c for c in chunks if c]
def _normalize(self, text: str) -> str:
return re.sub(r"\s+", " ", text).strip()
def _split_sentences(self, text: str) -> Iterable[str]:
parts = re.split(r"(?<=[\.\?\!])\s+", text)
if len(parts) == 1 and len(text) > self.chunk_size * 2:
chunk_text = []
for i in range(0, len(text), self.chunk_size):
chunk_text.append(text[i:i + self.chunk_size])
return chunk_text
return [p.strip() for p in parts if p.strip()]

View File

@ -139,3 +139,67 @@ class CollectionUseCases:
all_collections = {c.collection_id: c for c in owned + public + accessed_collections} all_collections = {c.collection_id: c for c in owned + public + accessed_collections}
return list(all_collections.values())[skip:skip+limit] return list(all_collections.values())[skip:skip+limit]
async def list_collection_access(self, collection_id: UUID, user_id: UUID) -> list[CollectionAccess]:
"""Получить список доступа к коллекции"""
collection = await self.get_collection(collection_id)
has_access = await self.check_access(collection_id, user_id)
if not has_access:
raise ForbiddenError("У вас нет доступа к этой коллекции")
return await self.access_repository.list_by_collection(collection_id)
async def grant_access_by_telegram_id(
self,
collection_id: UUID,
telegram_id: str,
owner_id: UUID
) -> CollectionAccess:
"""Предоставить доступ пользователю к коллекции по Telegram ID"""
collection = await self.get_collection(collection_id)
if collection.owner_id != owner_id:
raise ForbiddenError("Только владелец может предоставлять доступ")
user = await self.user_repository.get_by_telegram_id(telegram_id)
if not user:
from src.domain.entities.user import User, UserRole
import logging
logger = logging.getLogger(__name__)
logger.info(f"Creating new user with telegram_id: {telegram_id}")
user = User(telegram_id=telegram_id, role=UserRole.USER)
try:
user = await self.user_repository.create(user)
logger.info(f"User created successfully: user_id={user.user_id}, telegram_id={user.telegram_id}")
except Exception as e:
logger.error(f"Error creating user: {e}")
raise
if user.user_id == owner_id:
raise ForbiddenError("Владелец уже имеет доступ к коллекции")
existing_access = await self.access_repository.get_by_user_and_collection(user.user_id, collection_id)
if existing_access:
return existing_access
access = CollectionAccess(user_id=user.user_id, collection_id=collection_id)
return await self.access_repository.create(access)
async def revoke_access_by_telegram_id(
self,
collection_id: UUID,
telegram_id: str,
owner_id: UUID
) -> bool:
"""Отозвать доступ пользователя к коллекции по Telegram ID"""
collection = await self.get_collection(collection_id)
if collection.owner_id != owner_id:
raise ForbiddenError("Только владелец может отзывать доступ")
user = await self.user_repository.get_by_telegram_id(telegram_id)
if not user:
raise NotFoundError(f"Пользователь с telegram_id {telegram_id} не найден")
return await self.access_repository.delete_by_user_and_collection(user.user_id, collection_id)

View File

@ -3,11 +3,15 @@ Use cases для работы с документами
""" """
from uuid import UUID from uuid import UUID
from typing import BinaryIO, Optional from typing import BinaryIO, Optional
import httpx
from src.domain.entities.document import Document from src.domain.entities.document import Document
from src.domain.repositories.document_repository import IDocumentRepository from src.domain.repositories.document_repository import IDocumentRepository
from src.domain.repositories.collection_repository import ICollectionRepository from src.domain.repositories.collection_repository import ICollectionRepository
from src.domain.repositories.collection_access_repository import ICollectionAccessRepository
from src.application.services.document_parser_service import DocumentParserService from src.application.services.document_parser_service import DocumentParserService
from src.application.services.rag_service import RAGService
from src.shared.exceptions import NotFoundError, ForbiddenError from src.shared.exceptions import NotFoundError, ForbiddenError
from src.shared.config import settings
class DocumentUseCases: class DocumentUseCases:
@ -17,11 +21,26 @@ class DocumentUseCases:
self, self,
document_repository: IDocumentRepository, document_repository: IDocumentRepository,
collection_repository: ICollectionRepository, collection_repository: ICollectionRepository,
parser_service: DocumentParserService access_repository: ICollectionAccessRepository,
parser_service: DocumentParserService,
rag_service: Optional[RAGService] = None
): ):
self.document_repository = document_repository self.document_repository = document_repository
self.collection_repository = collection_repository self.collection_repository = collection_repository
self.access_repository = access_repository
self.parser_service = parser_service self.parser_service = parser_service
self.rag_service = rag_service
async def _check_collection_access(self, user_id: UUID, collection) -> bool:
"""Проверить доступ пользователя к коллекции"""
if collection.owner_id == user_id:
return True
if collection.is_public:
return True
access = await self.access_repository.get_by_user_and_collection(user_id, collection.collection_id)
return access is not None
async def create_document( async def create_document(
self, self,
@ -43,20 +62,43 @@ class DocumentUseCases:
) )
return await self.document_repository.create(document) return await self.document_repository.create(document)
async def _send_telegram_notification(self, telegram_id: str, message: str):
"""Отправить уведомление пользователю через Telegram Bot API"""
if not settings.TELEGRAM_BOT_TOKEN:
return
try:
url = f"https://api.telegram.org/bot{settings.TELEGRAM_BOT_TOKEN}/sendMessage"
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.post(
url,
json={
"chat_id": telegram_id,
"text": message,
"parse_mode": "HTML"
}
)
if response.status_code != 200:
print(f"Failed to send Telegram notification: {response.status_code}")
except Exception as e:
print(f"Error sending Telegram notification: {e}")
async def upload_and_parse_document( async def upload_and_parse_document(
self, self,
collection_id: UUID, collection_id: UUID,
file: BinaryIO, file: BinaryIO,
filename: str, filename: str,
user_id: UUID user_id: UUID,
telegram_id: Optional[str] = None
) -> Document: ) -> Document:
"""Загрузить и распарсить документ""" """Загрузить и распарсить документ, затем автоматически проиндексировать"""
collection = await self.collection_repository.get_by_id(collection_id) collection = await self.collection_repository.get_by_id(collection_id)
if not collection: if not collection:
raise NotFoundError(f"Коллекция {collection_id} не найдена") raise NotFoundError(f"Коллекция {collection_id} не найдена")
if collection.owner_id != user_id: has_access = await self._check_collection_access(user_id, collection)
raise ForbiddenError("Только владелец может добавлять документы") if not has_access:
raise ForbiddenError("У вас нет доступа к этой коллекции")
title, content = await self.parser_service.parse_pdf(file, filename) title, content = await self.parser_service.parse_pdf(file, filename)
@ -66,7 +108,41 @@ class DocumentUseCases:
content=content, content=content,
metadata={"filename": filename} metadata={"filename": filename}
) )
return await self.document_repository.create(document) document = await self.document_repository.create(document)
if self.rag_service and telegram_id:
try:
await self._send_telegram_notification(
telegram_id,
"🔄 <b>Начинаю индексацию документа...</b>\n\n"
f"📄 <b>Документ:</b> {title}\n\n"
f"Это может занять некоторое время.\n"
f"Вы получите уведомление по завершении."
)
chunks = await self.rag_service.index_document(document)
await self._send_telegram_notification(
telegram_id,
"✅ <b>Индексация завершена!</b>\n\n"
f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
f"📄 <b>Документ:</b> {title}\n"
f"📊 <b>Проиндексировано чанков:</b> {len(chunks)}\n\n"
f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
f"💡 <b>Теперь вы можете задавать вопросы по этому документу!</b>\n"
f"Просто напишите ваш вопрос, и я найду ответ на основе загруженного документа."
)
except Exception as e:
print(f"Ошибка при автоматической индексации документа {document.document_id}: {e}")
if telegram_id:
await self._send_telegram_notification(
telegram_id,
"⚠️ <b>Ошибка при индексации</b>\n\n"
f"Документ загружен, но индексация не завершена.\n"
f"Ошибка: {str(e)[:200]}"
)
return document
async def get_document(self, document_id: UUID) -> Document: async def get_document(self, document_id: UUID) -> Document:
"""Получить документ по ID""" """Получить документ по ID"""
@ -87,8 +163,11 @@ class DocumentUseCases:
document = await self.get_document(document_id) document = await self.get_document(document_id)
collection = await self.collection_repository.get_by_id(document.collection_id) collection = await self.collection_repository.get_by_id(document.collection_id)
if not collection or collection.owner_id != user_id: if not collection:
raise ForbiddenError("Только владелец коллекции может изменять документы") raise NotFoundError(f"Коллекция {document.collection_id} не найдена")
has_access = await self._check_collection_access(user_id, collection)
if not has_access:
raise ForbiddenError("У вас нет доступа к этой коллекции")
if title is not None: if title is not None:
document.title = title document.title = title

View File

@ -0,0 +1,92 @@
"""
Use cases для RAG: индексация документов и ответы на вопросы
"""
from uuid import UUID
from src.application.services.rag_service import RAGService
from src.application.services.cache_service import CacheService
from src.domain.repositories.document_repository import IDocumentRepository
from src.domain.repositories.conversation_repository import IConversationRepository
from src.domain.repositories.message_repository import IMessageRepository
from src.domain.entities.message import Message, MessageRole
from src.shared.exceptions import NotFoundError, ForbiddenError
class RAGUseCases:
def __init__(
self,
rag_service: RAGService,
document_repo: IDocumentRepository,
conversation_repo: IConversationRepository,
message_repo: IMessageRepository,
cache_service: CacheService,
):
self.rag_service = rag_service
self.document_repo = document_repo
self.conversation_repo = conversation_repo
self.message_repo = message_repo
self.cache_service = cache_service
async def index_document(self, document_id: UUID) -> dict:
document = await self.document_repo.get_by_id(document_id)
if not document:
raise NotFoundError(f"Документ {document_id} не найден")
chunks = await self.rag_service.index_document(document)
return {"chunks_indexed": len(chunks)}
async def ask_question(
self,
conversation_id: UUID,
user_id: UUID,
question: str,
top_k: int = 20,
rerank_top_n: int = 5,
) -> dict:
conversation = await self.conversation_repo.get_by_id(conversation_id)
if not conversation:
raise NotFoundError(f"Беседа {conversation_id} не найдена")
if conversation.user_id != user_id:
raise ForbiddenError("Нет доступа к этой беседе")
user_message = Message(
conversation_id=conversation_id, content=question, role=MessageRole.USER
)
await self.message_repo.create(user_message)
cached_answer = None
if self.cache_service:
cached_answer = await self.cache_service.get_cached_answer(conversation.collection_id, question)
if cached_answer:
generation = cached_answer
else:
retrieved = await self.rag_service.retrieve(
query=question,
collection_id=conversation.collection_id,
limit=top_k,
rerank_top_n=rerank_top_n,
)
chunks = [c for c, _ in retrieved]
generation = await self.rag_service.generate_answer(question, chunks)
if self.cache_service:
await self.cache_service.cache_answer(
conversation.collection_id,
question,
generation
)
assistant_message = Message(
conversation_id=conversation_id,
content=generation["content"],
role=MessageRole.ASSISTANT,
sources={"chunks": generation.get("sources", [])},
)
await self.message_repo.create(assistant_message)
return {
"answer": generation["content"],
"sources": generation.get("sources", []),
"usage": generation.get("usage", {}),
}

View File

@ -3,6 +3,7 @@ Use cases для работы с пользователями
""" """
from uuid import UUID from uuid import UUID
from typing import Optional from typing import Optional
from datetime import datetime, timedelta
from src.domain.entities.user import User, UserRole from src.domain.entities.user import User, UserRole
from src.domain.repositories.user_repository import IUserRepository from src.domain.repositories.user_repository import IUserRepository
from src.shared.exceptions import NotFoundError, ValidationError from src.shared.exceptions import NotFoundError, ValidationError
@ -53,3 +54,26 @@ class UserUseCases:
"""Получить список пользователей""" """Получить список пользователей"""
return await self.user_repository.list_all(skip=skip, limit=limit) return await self.user_repository.list_all(skip=skip, limit=limit)
async def increment_questions_used(self, telegram_id: str) -> User:
"""Увеличить счетчик использованных вопросов"""
user = await self.user_repository.get_by_telegram_id(telegram_id)
if not user:
raise NotFoundError(f"Пользователь с telegram_id {telegram_id} не найден")
user.questions_used += 1
return await self.user_repository.update(user)
async def activate_premium(self, telegram_id: str, days: int = 30) -> User:
"""Активировать premium статус"""
user = await self.user_repository.get_by_telegram_id(telegram_id)
if not user:
raise NotFoundError(f"Пользователь с telegram_id {telegram_id} не найден")
user.is_premium = True
if user.premium_until and user.premium_until > datetime.utcnow():
user.premium_until = user.premium_until + timedelta(days=days)
else:
user.premium_until = datetime.utcnow() + timedelta(days=days)
return await self.user_repository.update(user)

View File

@ -0,0 +1,28 @@
"""
Доменная сущность чанка
"""
from datetime import datetime
from uuid import UUID, uuid4
from typing import Any
class DocumentChunk:
def __init__(
self,
document_id: UUID,
collection_id: UUID,
content: str,
chunk_id: UUID | None = None,
order: int = 0,
metadata: dict[str, Any] | None = None,
created_at: datetime | None = None,
):
self.chunk_id = chunk_id or uuid4()
self.document_id = document_id
self.collection_id = collection_id
self.content = content
self.order = order
self.metadata = metadata or {}
self.created_at = created_at or datetime.utcnow()

View File

@ -1,25 +0,0 @@
"""
Доменная сущность Embedding
"""
from datetime import datetime
from uuid import UUID, uuid4
from typing import Any
class Embedding:
"""Эмбеддинг документа"""
def __init__(
self,
document_id: UUID,
embedding: list[float] | None = None,
model_version: str = "",
embedding_id: UUID | None = None,
created_at: datetime | None = None
):
self.embedding_id = embedding_id or uuid4()
self.document_id = document_id
self.embedding = embedding or []
self.model_version = model_version
self.created_at = created_at or datetime.utcnow()

View File

@ -20,12 +20,18 @@ class User:
telegram_id: str, telegram_id: str,
role: UserRole = UserRole.USER, role: UserRole = UserRole.USER,
user_id: UUID | None = None, user_id: UUID | None = None,
created_at: datetime | None = None created_at: datetime | None = None,
is_premium: bool = False,
premium_until: datetime | None = None,
questions_used: int = 0
): ):
self.user_id = user_id or uuid4() self.user_id = user_id or uuid4()
self.telegram_id = telegram_id self.telegram_id = telegram_id
self.role = role self.role = role
self.created_at = created_at or datetime.utcnow() self.created_at = created_at or datetime.utcnow()
self.is_premium = is_premium
self.premium_until = premium_until
self.questions_used = questions_used
def is_admin(self) -> bool: def is_admin(self) -> bool:
"""проверка, является ли пользователь администратором""" """проверка, является ли пользователь администратором"""

View File

@ -0,0 +1,31 @@
"""
Интерфейс репозитория/хранилища векторов
"""
from abc import ABC, abstractmethod
from typing import Sequence
from uuid import UUID
from src.domain.entities.chunk import DocumentChunk
class IVectorRepository(ABC):
@abstractmethod
async def upsert_chunks(
self,
chunks: Sequence[DocumentChunk],
embeddings: Sequence[list[float]],
model_version: str,
) -> None:
"""Сохранить или обновить вектора чанков"""
raise NotImplementedError
@abstractmethod
async def search(
self,
query_embedding: list[float],
collection_id: UUID,
limit: int = 20,
) -> list[tuple[DocumentChunk, float]]:
"""Поиск ближайших чанков по коллекции с расстоянием"""
raise NotImplementedError

View File

@ -17,6 +17,10 @@ class UserModel(Base):
telegram_id = Column(String, unique=True, nullable=False, index=True) telegram_id = Column(String, unique=True, nullable=False, index=True)
role = Column(String, nullable=False, default="user") role = Column(String, nullable=False, default="user")
created_at = Column(DateTime, nullable=False, default=datetime.utcnow) created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
is_premium = Column(Boolean, default=False, nullable=False)
premium_until = Column(DateTime, nullable=True)
questions_used = Column(Integer, default=0, nullable=False)
collections = relationship("CollectionModel", back_populates="owner", cascade="all, delete-orphan") collections = relationship("CollectionModel", back_populates="owner", cascade="all, delete-orphan")
conversations = relationship("ConversationModel", back_populates="user", cascade="all, delete-orphan") conversations = relationship("ConversationModel", back_populates="user", cascade="all, delete-orphan")
collection_accesses = relationship("CollectionAccessModel", back_populates="user", cascade="all, delete-orphan") collection_accesses = relationship("CollectionAccessModel", back_populates="user", cascade="all, delete-orphan")
@ -49,19 +53,6 @@ class DocumentModel(Base):
document_metadata = Column("metadata", JSON, nullable=True, default={}) document_metadata = Column("metadata", JSON, nullable=True, default={})
created_at = Column(DateTime, nullable=False, default=datetime.utcnow) created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
collection = relationship("CollectionModel", back_populates="documents") collection = relationship("CollectionModel", back_populates="documents")
embeddings = relationship("EmbeddingModel", back_populates="document", cascade="all, delete-orphan")
class EmbeddingModel(Base):
"""Модель эмбеддинга (заглушка)"""
__tablename__ = "embeddings"
embedding_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
document_id = Column(UUID(as_uuid=True), ForeignKey("documents.document_id"), nullable=False)
embedding = Column(JSON, nullable=True)
model_version = Column(String, nullable=True)
created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
document = relationship("DocumentModel", back_populates="embeddings")
class ConversationModel(Base): class ConversationModel(Base):

View File

@ -0,0 +1,76 @@
import json
from typing import Optional, Any
import redis.asyncio as aioredis
from src.shared.config import settings
class RedisClient:
def __init__(self, host: str, port: int):
self.host = host or settings.REDIS_HOST
self.port = port or settings.REDIS_PORT
self._client: Optional[aioredis.Redis] = None
async def connect(self):
if self._client is None:
self._client = await aioredis.from_url(
f"redis://{self.host}:{self.port}",
encoding="utf-8",
decode_responses=True
)
async def disconnect(self):
if self._client:
await self._client.aclose()
self._client = None
async def get(self, key: str) -> Optional[str]:
if self._client is None:
await self.connect()
return await self._client.get(key)
async def set(self, key: str, value: str, ttl: Optional[int] = None):
if self._client is None:
await self.connect()
if ttl:
await self._client.setex(key, ttl, value)
else:
await self._client.set(key, value)
async def get_json(self, key: str) -> Optional[dict[str, Any]]:
value = await self.get(key)
if value is None:
return None
try:
return json.loads(value)
except json.JSONDecodeError:
return None
async def set_json(self, key: str, value: dict[str, Any], ttl: Optional[int] = None):
json_str = json.dumps(value)
await self.set(key, json_str, ttl)
async def delete(self, key: str):
if self._client is None:
await self.connect()
await self._client.delete(key)
async def exists(self, key: str) -> bool:
if self._client is None:
await self.connect()
return bool(await self._client.exists(key))
async def incr(self, key: str) -> int:
if self._client is None:
await self.connect()
return await self._client.incr(key)
async def expire(self, key: str, seconds: int):
if self._client is None:
await self.connect()
await self._client.expire(key, seconds)
async def keys(self, pattern: str) -> list[str]:
if self._client is None:
await self.connect()
return await self._client.keys(pattern)

View File

@ -23,7 +23,10 @@ class PostgreSQLUserRepository(IUserRepository):
user_id=user.user_id, user_id=user.user_id,
telegram_id=user.telegram_id, telegram_id=user.telegram_id,
role=user.role.value, role=user.role.value,
created_at=user.created_at created_at=user.created_at,
is_premium=user.is_premium,
premium_until=user.premium_until,
questions_used=user.questions_used
) )
self.session.add(db_user) self.session.add(db_user)
await self.session.commit() await self.session.commit()
@ -57,6 +60,9 @@ class PostgreSQLUserRepository(IUserRepository):
db_user.telegram_id = user.telegram_id db_user.telegram_id = user.telegram_id
db_user.role = user.role.value db_user.role = user.role.value
db_user.is_premium = user.is_premium
db_user.premium_until = user.premium_until
db_user.questions_used = user.questions_used
await self.session.commit() await self.session.commit()
await self.session.refresh(db_user) await self.session.refresh(db_user)
return self._to_entity(db_user) return self._to_entity(db_user)
@ -90,6 +96,9 @@ class PostgreSQLUserRepository(IUserRepository):
user_id=db_user.user_id, user_id=db_user.user_id,
telegram_id=db_user.telegram_id, telegram_id=db_user.telegram_id,
role=UserRole(db_user.role), role=UserRole(db_user.role),
created_at=db_user.created_at created_at=db_user.created_at,
is_premium=db_user.is_premium,
premium_until=db_user.premium_until,
questions_used=db_user.questions_used
) )

View File

@ -0,0 +1,4 @@
"""
Qdrant repositories
"""

View File

@ -0,0 +1,92 @@
"""
Qdrant реализация векторного хранилища
"""
from typing import Sequence
from uuid import UUID
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
from src.domain.entities.chunk import DocumentChunk
from src.domain.repositories.vector_repository import IVectorRepository
class QdrantVectorRepository(IVectorRepository):
def __init__(
self,
client: QdrantClient,
collection_name: str = "documents",
vector_size: int = 768,
):
self.client = client
self.collection_name = collection_name
self.vector_size = vector_size
self._ensure_collection()
def _ensure_collection(self) -> None:
"""Создает коллекцию при отсутствии"""
if self.collection_name in [c.name for c in self.client.get_collections().collections]:
return
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
)
async def upsert_chunks(
self,
chunks: Sequence[DocumentChunk],
embeddings: Sequence[list[float]],
model_version: str,
) -> None:
BATCH_SIZE = 100
points = []
for chunk, vector in zip(chunks, embeddings):
points.append(
PointStruct(
id=str(chunk.chunk_id),
vector=vector,
payload={
"document_id": str(chunk.document_id),
"collection_id": str(chunk.collection_id),
"content": chunk.content,
"order": chunk.order,
"model_version": model_version,
"title": chunk.metadata.get("title", ""),
},
)
)
if len(points) >= BATCH_SIZE:
self.client.upsert(collection_name=self.collection_name, points=points)
points = []
if points:
self.client.upsert(collection_name=self.collection_name, points=points)
async def search(
self,
query_embedding: list[float],
collection_id: UUID,
limit: int = 20,
) -> list[tuple[DocumentChunk, float]]:
res = self.client.search(
collection_name=self.collection_name,
query_vector=query_embedding,
query_filter=Filter(
must=[FieldCondition(key="collection_id", match=MatchValue(value=str(collection_id)))]
),
limit=limit,
)
results: list[tuple[DocumentChunk, float]] = []
for hit in res:
payload = hit.payload or {}
chunk = DocumentChunk(
document_id=UUID(payload["document_id"]),
collection_id=UUID(payload["collection_id"]),
content=payload.get("content", ""),
chunk_id=UUID(hit.id),
order=payload.get("order", 0),
metadata={"title": payload.get("title", ""), "model_version": payload.get("model_version", "")},
)
results.append((chunk, hit.score))
return results

View File

@ -2,10 +2,12 @@
Админ-панель - упрощенная версия через API эндпоинты Админ-панель - упрощенная версия через API эндпоинты
В будущем можно интегрировать полноценную админ-панель В будущем можно интегрировать полноценную админ-панель
""" """
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Request
from typing import List from typing import List, Annotated
from uuid import UUID from uuid import UUID
from dishka.integrations.fastapi import FromDishka from dishka.integrations.fastapi import FromDishka, inject
from src.domain.repositories.user_repository import IUserRepository
from src.presentation.middleware.auth_middleware import get_current_user
from src.presentation.schemas.user_schemas import UserResponse from src.presentation.schemas.user_schemas import UserResponse
from src.presentation.schemas.collection_schemas import CollectionResponse from src.presentation.schemas.collection_schemas import CollectionResponse
from src.presentation.schemas.document_schemas import DocumentResponse from src.presentation.schemas.document_schemas import DocumentResponse
@ -19,13 +21,16 @@ router = APIRouter(prefix="/admin", tags=["admin"])
@router.get("/users", response_model=List[UserResponse]) @router.get("/users", response_model=List[UserResponse])
@inject
async def admin_list_users( async def admin_list_users(
request: Request,
user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[UserUseCases, FromDishka()],
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100
current_user: FromDishka[User] = FromDishka(),
use_cases: FromDishka[UserUseCases] = FromDishka()
): ):
"""Получить список всех пользователей (только для админов)""" """Получить список всех пользователей (только для админов)"""
current_user = await get_current_user(request, user_repo)
if not current_user.is_admin(): if not current_user.is_admin():
raise HTTPException(status_code=403, detail="Требуются права администратора") raise HTTPException(status_code=403, detail="Требуются права администратора")
users = await use_cases.list_users(skip=skip, limit=limit) users = await use_cases.list_users(skip=skip, limit=limit)
@ -33,13 +38,16 @@ async def admin_list_users(
@router.get("/collections", response_model=List[CollectionResponse]) @router.get("/collections", response_model=List[CollectionResponse])
@inject
async def admin_list_collections( async def admin_list_collections(
request: Request,
user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[CollectionUseCases, FromDishka()],
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100
current_user: FromDishka[User] = FromDishka(),
use_cases: FromDishka[CollectionUseCases] = FromDishka()
): ):
"""Получить список всех коллекций (только для админов)""" """Получить список всех коллекций (только для админов)"""
current_user = await get_current_user(request, user_repo)
from src.infrastructure.database.base import AsyncSessionLocal from src.infrastructure.database.base import AsyncSessionLocal
from src.infrastructure.repositories.postgresql.collection_repository import PostgreSQLCollectionRepository from src.infrastructure.repositories.postgresql.collection_repository import PostgreSQLCollectionRepository
from sqlalchemy import select from sqlalchemy import select

View File

@ -2,31 +2,37 @@
API роутеры для работы с коллекциями API роутеры для работы с коллекциями
""" """
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, status from fastapi import APIRouter, status, Depends, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from typing import List from typing import List, Annotated
from dishka.integrations.fastapi import FromDishka from dishka.integrations.fastapi import FromDishka, inject
from src.domain.repositories.user_repository import IUserRepository
from src.presentation.middleware.auth_middleware import get_current_user
from src.presentation.schemas.collection_schemas import ( from src.presentation.schemas.collection_schemas import (
CollectionCreate, CollectionCreate,
CollectionUpdate, CollectionUpdate,
CollectionResponse, CollectionResponse,
CollectionAccessGrant, CollectionAccessGrant,
CollectionAccessResponse CollectionAccessResponse,
CollectionAccessListResponse,
CollectionAccessUserInfo
) )
from src.application.use_cases.collection_use_cases import CollectionUseCases from src.application.use_cases.collection_use_cases import CollectionUseCases
from src.domain.entities.user import User from src.domain.entities.user import User
from src.presentation.middleware.auth_middleware import get_current_user
router = APIRouter(prefix="/collections", tags=["collections"]) router = APIRouter(prefix="/collections", tags=["collections"])
@router.post("", response_model=CollectionResponse, status_code=status.HTTP_201_CREATED) @router.post("", response_model=CollectionResponse, status_code=status.HTTP_201_CREATED)
@inject
async def create_collection( async def create_collection(
collection_data: CollectionCreate, collection_data: CollectionCreate,
current_user: User = FromDishka(), request: Request,
use_cases: FromDishka[CollectionUseCases] = FromDishka() user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[CollectionUseCases, FromDishka()]
): ):
"""Создать коллекцию""" """Создать коллекцию"""
current_user = await get_current_user(request, user_repo)
collection = await use_cases.create_collection( collection = await use_cases.create_collection(
name=collection_data.name, name=collection_data.name,
owner_id=current_user.user_id, owner_id=current_user.user_id,
@ -37,23 +43,36 @@ async def create_collection(
@router.get("/{collection_id}", response_model=CollectionResponse) @router.get("/{collection_id}", response_model=CollectionResponse)
@inject
async def get_collection( async def get_collection(
collection_id: UUID, collection_id: UUID,
use_cases: FromDishka[CollectionUseCases] = FromDishka() request: Request,
user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[CollectionUseCases, FromDishka()]
): ):
"""Получить коллекцию по ID""" """Получить коллекцию по ID"""
current_user = await get_current_user(request, user_repo)
collection = await use_cases.get_collection(collection_id) collection = await use_cases.get_collection(collection_id)
has_access = await use_cases.check_access(collection_id, current_user.user_id)
if not has_access:
from fastapi import HTTPException
raise HTTPException(status_code=403, detail="У вас нет доступа к этой коллекции")
return CollectionResponse.from_entity(collection) return CollectionResponse.from_entity(collection)
@router.put("/{collection_id}", response_model=CollectionResponse) @router.put("/{collection_id}", response_model=CollectionResponse)
@inject
async def update_collection( async def update_collection(
collection_id: UUID, collection_id: UUID,
collection_data: CollectionUpdate, collection_data: CollectionUpdate,
current_user: User = FromDishka(), request: Request,
use_cases: FromDishka[CollectionUseCases] = FromDishka() user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[CollectionUseCases, FromDishka()]
): ):
"""Обновить коллекцию""" """Обновить коллекцию"""
current_user = await get_current_user(request, user_repo)
collection = await use_cases.update_collection( collection = await use_cases.update_collection(
collection_id=collection_id, collection_id=collection_id,
user_id=current_user.user_id, user_id=current_user.user_id,
@ -65,24 +84,30 @@ async def update_collection(
@router.delete("/{collection_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{collection_id}", status_code=status.HTTP_204_NO_CONTENT)
@inject
async def delete_collection( async def delete_collection(
collection_id: UUID, collection_id: UUID,
current_user: User = FromDishka(), request: Request,
use_cases: FromDishka[CollectionUseCases] = FromDishka() user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[CollectionUseCases, FromDishka()]
): ):
"""Удалить коллекцию""" """Удалить коллекцию"""
current_user = await get_current_user(request, user_repo)
await use_cases.delete_collection(collection_id, current_user.user_id) await use_cases.delete_collection(collection_id, current_user.user_id)
return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None) return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None)
@router.get("", response_model=List[CollectionResponse]) @router.get("", response_model=List[CollectionResponse])
@inject
async def list_collections( async def list_collections(
request: Request,
user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[CollectionUseCases, FromDishka()],
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100
current_user: User = FromDishka(),
use_cases: FromDishka[CollectionUseCases] = FromDishka()
): ):
"""Получить список коллекций, доступных пользователю""" """Получить список коллекций, доступных пользователю"""
current_user = await get_current_user(request, user_repo)
collections = await use_cases.list_user_collections( collections = await use_cases.list_user_collections(
user_id=current_user.user_id, user_id=current_user.user_id,
skip=skip, skip=skip,
@ -92,13 +117,16 @@ async def list_collections(
@router.post("/{collection_id}/access", response_model=CollectionAccessResponse, status_code=status.HTTP_201_CREATED) @router.post("/{collection_id}/access", response_model=CollectionAccessResponse, status_code=status.HTTP_201_CREATED)
@inject
async def grant_access( async def grant_access(
collection_id: UUID, collection_id: UUID,
access_data: CollectionAccessGrant, access_data: CollectionAccessGrant,
current_user: User = FromDishka(), request: Request,
use_cases: FromDishka[CollectionUseCases] = FromDishka() user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[CollectionUseCases, FromDishka()]
): ):
"""Предоставить доступ пользователю к коллекции""" """Предоставить доступ пользователю к коллекции"""
current_user = await get_current_user(request, user_repo)
access = await use_cases.grant_access( access = await use_cases.grant_access(
collection_id=collection_id, collection_id=collection_id,
user_id=access_data.user_id, user_id=access_data.user_id,
@ -108,13 +136,91 @@ async def grant_access(
@router.delete("/{collection_id}/access/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{collection_id}/access/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
@inject
async def revoke_access( async def revoke_access(
collection_id: UUID, collection_id: UUID,
user_id: UUID, user_id: UUID,
current_user: User = FromDishka(), request: Request,
use_cases: FromDishka[CollectionUseCases] = FromDishka() user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[CollectionUseCases, FromDishka()]
): ):
"""Отозвать доступ пользователя к коллекции""" """Отозвать доступ пользователя к коллекции"""
current_user = await get_current_user(request, user_repo)
await use_cases.revoke_access(collection_id, user_id, current_user.user_id) await use_cases.revoke_access(collection_id, user_id, current_user.user_id)
return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None) return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None)
@router.get("/{collection_id}/access", response_model=List[CollectionAccessListResponse])
@inject
async def list_collection_access(
collection_id: UUID,
request: Request,
user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[CollectionUseCases, FromDishka()]
):
"""Получить список пользователей с доступом к коллекции"""
current_user = await get_current_user(request, user_repo)
accesses = await use_cases.list_collection_access(collection_id, current_user.user_id)
result = []
for access in accesses:
user = await user_repo.get_by_id(access.user_id)
if user:
user_info = CollectionAccessUserInfo(
user_id=user.user_id,
telegram_id=user.telegram_id,
role=user.role.value,
created_at=user.created_at
)
result.append(CollectionAccessListResponse(
access_id=access.access_id,
user=user_info,
collection_id=access.collection_id,
created_at=access.created_at
))
return result
@router.post("/{collection_id}/access/telegram/{telegram_id}", response_model=CollectionAccessResponse, status_code=status.HTTP_201_CREATED)
@inject
async def grant_access_by_telegram_id(
collection_id: UUID,
telegram_id: str,
request: Request,
user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[CollectionUseCases, FromDishka()]
):
"""Предоставить доступ пользователю к коллекции по Telegram ID"""
import logging
logger = logging.getLogger(__name__)
current_user = await get_current_user(request, user_repo)
logger.info(f"Granting access: collection_id={collection_id}, target_telegram_id={telegram_id}, owner_id={current_user.user_id}")
try:
access = await use_cases.grant_access_by_telegram_id(
collection_id=collection_id,
telegram_id=telegram_id,
owner_id=current_user.user_id
)
logger.info(f"Access granted successfully: access_id={access.access_id}")
return CollectionAccessResponse.from_entity(access)
except Exception as e:
logger.error(f"Error granting access: {e}", exc_info=True)
raise
@router.delete("/{collection_id}/access/telegram/{telegram_id}", status_code=status.HTTP_204_NO_CONTENT)
@inject
async def revoke_access_by_telegram_id(
collection_id: UUID,
telegram_id: str,
request: Request,
user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[CollectionUseCases, FromDishka()]
):
"""Отозвать доступ пользователя к коллекции по Telegram ID"""
current_user = await get_current_user(request, user_repo)
await use_cases.revoke_access_by_telegram_id(collection_id, telegram_id, current_user.user_id)
return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None)

View File

@ -2,10 +2,12 @@
API роутеры для работы с беседами API роутеры для работы с беседами
""" """
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, status from fastapi import APIRouter, status, Depends, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from typing import List from typing import List, Annotated
from dishka.integrations.fastapi import FromDishka from dishka.integrations.fastapi import FromDishka, inject
from src.domain.repositories.user_repository import IUserRepository
from src.presentation.middleware.auth_middleware import get_current_user
from src.presentation.schemas.conversation_schemas import ( from src.presentation.schemas.conversation_schemas import (
ConversationCreate, ConversationCreate,
ConversationResponse ConversationResponse
@ -17,12 +19,15 @@ router = APIRouter(prefix="/conversations", tags=["conversations"])
@router.post("", response_model=ConversationResponse, status_code=status.HTTP_201_CREATED) @router.post("", response_model=ConversationResponse, status_code=status.HTTP_201_CREATED)
@inject
async def create_conversation( async def create_conversation(
conversation_data: ConversationCreate, conversation_data: ConversationCreate,
current_user: FromDishka[User] = FromDishka(), request: Request,
use_cases: FromDishka[ConversationUseCases] = FromDishka() user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[ConversationUseCases, FromDishka()]
): ):
"""Создать беседу""" """Создать беседу"""
current_user = await get_current_user(request, user_repo)
conversation = await use_cases.create_conversation( conversation = await use_cases.create_conversation(
user_id=current_user.user_id, user_id=current_user.user_id,
collection_id=conversation_data.collection_id collection_id=conversation_data.collection_id
@ -31,35 +36,44 @@ async def create_conversation(
@router.get("/{conversation_id}", response_model=ConversationResponse) @router.get("/{conversation_id}", response_model=ConversationResponse)
@inject
async def get_conversation( async def get_conversation(
conversation_id: UUID, conversation_id: UUID,
current_user: FromDishka[User] = FromDishka(), request: Request,
use_cases: FromDishka[ConversationUseCases] = FromDishka() user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[ConversationUseCases, FromDishka()]
): ):
"""Получить беседу по ID""" """Получить беседу по ID"""
current_user = await get_current_user(request, user_repo)
conversation = await use_cases.get_conversation(conversation_id, current_user.user_id) conversation = await use_cases.get_conversation(conversation_id, current_user.user_id)
return ConversationResponse.from_entity(conversation) return ConversationResponse.from_entity(conversation)
@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
@inject
async def delete_conversation( async def delete_conversation(
conversation_id: UUID, conversation_id: UUID,
current_user: FromDishka[User] = FromDishka(), request: Request,
use_cases: FromDishka[ConversationUseCases] = FromDishka() user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[ConversationUseCases, FromDishka()]
): ):
"""Удалить беседу""" """Удалить беседу"""
current_user = await get_current_user(request, user_repo)
await use_cases.delete_conversation(conversation_id, current_user.user_id) await use_cases.delete_conversation(conversation_id, current_user.user_id)
return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None) return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None)
@router.get("", response_model=List[ConversationResponse]) @router.get("", response_model=List[ConversationResponse])
@inject
async def list_conversations( async def list_conversations(
request: Request,
user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[ConversationUseCases, FromDishka()],
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100
current_user: FromDishka[User] = FromDishka(),
use_cases: FromDishka[ConversationUseCases] = FromDishka()
): ):
"""Получить список бесед пользователя""" """Получить список бесед пользователя"""
current_user = await get_current_user(request, user_repo)
conversations = await use_cases.list_user_conversations( conversations = await use_cases.list_user_conversations(
user_id=current_user.user_id, user_id=current_user.user_id,
skip=skip, skip=skip,

View File

@ -2,28 +2,34 @@
API роутеры для работы с документами API роутеры для работы с документами
""" """
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, status, UploadFile, File from fastapi import APIRouter, status, UploadFile, File, Depends, Request, Query
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from typing import List from typing import List, Annotated
from dishka.integrations.fastapi import FromDishka from dishka.integrations.fastapi import FromDishka, inject
from src.domain.repositories.user_repository import IUserRepository
from src.presentation.middleware.auth_middleware import get_current_user
from src.presentation.schemas.document_schemas import ( from src.presentation.schemas.document_schemas import (
DocumentCreate, DocumentCreate,
DocumentUpdate, DocumentUpdate,
DocumentResponse DocumentResponse
) )
from src.application.use_cases.document_use_cases import DocumentUseCases from src.application.use_cases.document_use_cases import DocumentUseCases
from src.application.use_cases.collection_use_cases import CollectionUseCases
from src.domain.entities.user import User from src.domain.entities.user import User
router = APIRouter(prefix="/documents", tags=["documents"]) router = APIRouter(prefix="/documents", tags=["documents"])
@router.post("", response_model=DocumentResponse, status_code=status.HTTP_201_CREATED) @router.post("", response_model=DocumentResponse, status_code=status.HTTP_201_CREATED)
@inject
async def create_document( async def create_document(
document_data: DocumentCreate, document_data: DocumentCreate,
current_user: FromDishka[User] = FromDishka(), request: Request,
use_cases: FromDishka[DocumentUseCases] = FromDishka() user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[DocumentUseCases, FromDishka()]
): ):
"""Создать документ""" """Создать документ"""
current_user = await get_current_user(request, user_repo)
document = await use_cases.create_document( document = await use_cases.create_document(
collection_id=document_data.collection_id, collection_id=document_data.collection_id,
title=document_data.title, title=document_data.title,
@ -34,13 +40,16 @@ async def create_document(
@router.post("/upload", response_model=DocumentResponse, status_code=status.HTTP_201_CREATED) @router.post("/upload", response_model=DocumentResponse, status_code=status.HTTP_201_CREATED)
@inject
async def upload_document( async def upload_document(
collection_id: UUID, collection_id: UUID = Query(...),
file: UploadFile = File(...), request: Request = None,
current_user: FromDishka[User] = FromDishka(), user_repo: Annotated[IUserRepository, FromDishka()] = None,
use_cases: FromDishka[DocumentUseCases] = FromDishka() use_cases: Annotated[DocumentUseCases, FromDishka()] = None,
file: UploadFile = File(...)
): ):
"""Загрузить и распарсить PDF документ или изображение""" """Загрузить и распарсить PDF документ или изображение"""
current_user = await get_current_user(request, user_repo)
if not file.filename: if not file.filename:
raise JSONResponse( raise JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -60,15 +69,17 @@ async def upload_document(
collection_id=collection_id, collection_id=collection_id,
file=file.file, file=file.file,
filename=file.filename, filename=file.filename,
user_id=current_user.user_id user_id=current_user.user_id,
telegram_id=current_user.telegram_id
) )
return DocumentResponse.from_entity(document) return DocumentResponse.from_entity(document)
@router.get("/{document_id}", response_model=DocumentResponse) @router.get("/{document_id}", response_model=DocumentResponse)
@inject
async def get_document( async def get_document(
document_id: UUID, document_id: UUID,
use_cases: FromDishka[DocumentUseCases] = FromDishka() use_cases: Annotated[DocumentUseCases, FromDishka()]
): ):
"""Получить документ по ID""" """Получить документ по ID"""
document = await use_cases.get_document(document_id) document = await use_cases.get_document(document_id)
@ -76,13 +87,16 @@ async def get_document(
@router.put("/{document_id}", response_model=DocumentResponse) @router.put("/{document_id}", response_model=DocumentResponse)
@inject
async def update_document( async def update_document(
document_id: UUID, document_id: UUID,
document_data: DocumentUpdate, document_data: DocumentUpdate,
current_user: FromDishka[User] = FromDishka(), request: Request,
use_cases: FromDishka[DocumentUseCases] = FromDishka() user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[DocumentUseCases, FromDishka()]
): ):
"""Обновить документ""" """Обновить документ"""
current_user = await get_current_user(request, user_repo)
document = await use_cases.update_document( document = await use_cases.update_document(
document_id=document_id, document_id=document_id,
user_id=current_user.user_id, user_id=current_user.user_id,
@ -94,24 +108,39 @@ async def update_document(
@router.delete("/{document_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{document_id}", status_code=status.HTTP_204_NO_CONTENT)
@inject
async def delete_document( async def delete_document(
document_id: UUID, document_id: UUID,
current_user: FromDishka[User] = FromDishka(), request: Request,
use_cases: FromDishka[DocumentUseCases] = FromDishka() user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[DocumentUseCases, FromDishka()]
): ):
"""Удалить документ""" """Удалить документ"""
current_user = await get_current_user(request, user_repo)
await use_cases.delete_document(document_id, current_user.user_id) await use_cases.delete_document(document_id, current_user.user_id)
return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None) return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None)
@router.get("/collection/{collection_id}", response_model=List[DocumentResponse]) @router.get("/collection/{collection_id}", response_model=List[DocumentResponse])
@inject
async def list_collection_documents( async def list_collection_documents(
collection_id: UUID, collection_id: UUID,
request: Request,
user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[DocumentUseCases, FromDishka()],
collection_use_cases: Annotated[CollectionUseCases, FromDishka()],
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100
use_cases: FromDishka[DocumentUseCases] = FromDishka()
): ):
"""Получить документы коллекции""" """Получить документы коллекции"""
current_user = await get_current_user(request, user_repo)
has_access = await collection_use_cases.check_access(collection_id, current_user.user_id)
if not has_access:
from fastapi import HTTPException
raise HTTPException(status_code=403, detail="У вас нет доступа к этой коллекции")
documents = await use_cases.list_collection_documents( documents = await use_cases.list_collection_documents(
collection_id=collection_id, collection_id=collection_id,
skip=skip, skip=skip,

View File

@ -2,10 +2,12 @@
API роутеры для работы с сообщениями API роутеры для работы с сообщениями
""" """
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, status from fastapi import APIRouter, status, Depends, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from typing import List from typing import List, Annotated
from dishka.integrations.fastapi import FromDishka from dishka.integrations.fastapi import FromDishka, inject
from src.domain.repositories.user_repository import IUserRepository
from src.presentation.middleware.auth_middleware import get_current_user
from src.presentation.schemas.message_schemas import ( from src.presentation.schemas.message_schemas import (
MessageCreate, MessageCreate,
MessageUpdate, MessageUpdate,
@ -18,12 +20,15 @@ router = APIRouter(prefix="/messages", tags=["messages"])
@router.post("", response_model=MessageResponse, status_code=status.HTTP_201_CREATED) @router.post("", response_model=MessageResponse, status_code=status.HTTP_201_CREATED)
@inject
async def create_message( async def create_message(
message_data: MessageCreate, message_data: MessageCreate,
current_user: FromDishka[User] = FromDishka(), request: Request,
use_cases: FromDishka[MessageUseCases] = FromDishka() user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[MessageUseCases, FromDishka()]
): ):
"""Создать сообщение""" """Создать сообщение"""
current_user = await get_current_user(request, user_repo)
message = await use_cases.create_message( message = await use_cases.create_message(
conversation_id=message_data.conversation_id, conversation_id=message_data.conversation_id,
content=message_data.content, content=message_data.content,
@ -35,9 +40,10 @@ async def create_message(
@router.get("/{message_id}", response_model=MessageResponse) @router.get("/{message_id}", response_model=MessageResponse)
@inject
async def get_message( async def get_message(
message_id: UUID, message_id: UUID,
use_cases: FromDishka[MessageUseCases] = FromDishka() use_cases: Annotated[MessageUseCases, FromDishka()]
): ):
"""Получить сообщение по ID""" """Получить сообщение по ID"""
message = await use_cases.get_message(message_id) message = await use_cases.get_message(message_id)
@ -45,10 +51,11 @@ async def get_message(
@router.put("/{message_id}", response_model=MessageResponse) @router.put("/{message_id}", response_model=MessageResponse)
@inject
async def update_message( async def update_message(
message_id: UUID, message_id: UUID,
message_data: MessageUpdate, message_data: MessageUpdate,
use_cases: FromDishka[MessageUseCases] = FromDishka() use_cases: Annotated[MessageUseCases, FromDishka()]
): ):
"""Обновить сообщение""" """Обновить сообщение"""
message = await use_cases.update_message( message = await use_cases.update_message(
@ -60,9 +67,10 @@ async def update_message(
@router.delete("/{message_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{message_id}", status_code=status.HTTP_204_NO_CONTENT)
@inject
async def delete_message( async def delete_message(
message_id: UUID, message_id: UUID,
use_cases: FromDishka[MessageUseCases] = FromDishka() use_cases: Annotated[MessageUseCases, FromDishka()]
): ):
"""Удалить сообщение""" """Удалить сообщение"""
await use_cases.delete_message(message_id) await use_cases.delete_message(message_id)
@ -70,14 +78,17 @@ async def delete_message(
@router.get("/conversation/{conversation_id}", response_model=List[MessageResponse]) @router.get("/conversation/{conversation_id}", response_model=List[MessageResponse])
@inject
async def list_conversation_messages( async def list_conversation_messages(
conversation_id: UUID, conversation_id: UUID,
request: Request,
user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[MessageUseCases, FromDishka()],
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100
current_user: FromDishka[User] = FromDishka(),
use_cases: FromDishka[MessageUseCases] = FromDishka()
): ):
"""Получить сообщения беседы""" """Получить сообщения беседы"""
current_user = await get_current_user(request, user_repo)
messages = await use_cases.list_conversation_messages( messages = await use_cases.list_conversation_messages(
conversation_id=conversation_id, conversation_id=conversation_id,
user_id=current_user.user_id, user_id=current_user.user_id,

View File

@ -0,0 +1,37 @@
"""
API для RAG: ответы на вопросы
"""
from fastapi import APIRouter, status, Request
from typing import Annotated
from dishka.integrations.fastapi import FromDishka, inject
from src.domain.repositories.user_repository import IUserRepository
from src.presentation.middleware.auth_middleware import get_current_user
from src.presentation.schemas.rag_schemas import (
QuestionRequest,
RAGAnswer,
)
from src.application.use_cases.rag_use_cases import RAGUseCases
router = APIRouter(prefix="/rag", tags=["rag"])
@router.post("/question", response_model=RAGAnswer, status_code=status.HTTP_200_OK)
@inject
async def ask_question(
body: QuestionRequest,
request: Request,
user_repo: Annotated[IUserRepository, FromDishka()],
use_cases: Annotated[RAGUseCases, FromDishka()],
):
"""Отвечает на вопрос, используя RAG в рамках беседы"""
current_user = await get_current_user(request, user_repo)
result = await use_cases.ask_question(
conversation_id=body.conversation_id,
user_id=current_user.user_id,
question=body.question,
top_k=body.top_k,
rerank_top_n=body.rerank_top_n,
)
return RAGAnswer(**result)

View File

@ -2,10 +2,12 @@
API роутеры для работы с пользователями API роутеры для работы с пользователями
""" """
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, status from fastapi import APIRouter, status, Depends, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from typing import List from typing import List, Annotated
from dishka.integrations.fastapi import FromDishka from dishka.integrations.fastapi import FromDishka, inject
from src.domain.repositories.user_repository import IUserRepository
from src.presentation.middleware.auth_middleware import get_current_user
from src.presentation.schemas.user_schemas import UserCreate, UserUpdate, UserResponse from src.presentation.schemas.user_schemas import UserCreate, UserUpdate, UserResponse
from src.application.use_cases.user_use_cases import UserUseCases from src.application.use_cases.user_use_cases import UserUseCases
from src.domain.entities.user import User from src.domain.entities.user import User
@ -14,9 +16,10 @@ router = APIRouter(prefix="/users", tags=["users"])
@router.post("", response_model=UserResponse, status_code=status.HTTP_201_CREATED) @router.post("", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@inject
async def create_user( async def create_user(
user_data: UserCreate, user_data: UserCreate,
use_cases: FromDishka[UserUseCases] = FromDishka() use_cases: Annotated[UserUseCases, FromDishka()]
): ):
"""Создать пользователя""" """Создать пользователя"""
user = await use_cases.create_user( user = await use_cases.create_user(
@ -27,17 +30,59 @@ async def create_user(
@router.get("/me", response_model=UserResponse) @router.get("/me", response_model=UserResponse)
@inject
async def get_current_user_info( async def get_current_user_info(
current_user: FromDishka[User] = FromDishka() request,
user_repo: Annotated[IUserRepository, FromDishka()]
): ):
"""Получить информацию о текущем пользователе""" """Получить информацию о текущем пользователе"""
current_user = await get_current_user(request, user_repo)
return UserResponse.from_entity(current_user) return UserResponse.from_entity(current_user)
@router.get("/telegram/{telegram_id}", response_model=UserResponse)
@inject
async def get_user_by_telegram_id(
telegram_id: str,
use_cases: Annotated[UserUseCases, FromDishka()]
):
"""Получить пользователя по Telegram ID"""
user = await use_cases.get_user_by_telegram_id(telegram_id)
if not user:
from fastapi import HTTPException
raise HTTPException(status_code=404, detail=f"Пользователь с telegram_id {telegram_id} не найден")
return UserResponse.from_entity(user)
@router.post("/telegram/{telegram_id}/increment-questions", response_model=UserResponse)
@inject
async def increment_questions(
telegram_id: str,
use_cases: Annotated[UserUseCases, FromDishka()]
):
"""Увеличить счетчик использованных вопросов"""
user = await use_cases.increment_questions_used(telegram_id)
return UserResponse.from_entity(user)
@router.post("/telegram/{telegram_id}/activate-premium", response_model=UserResponse)
@inject
async def activate_premium(
use_cases: Annotated[UserUseCases, FromDishka()],
telegram_id: str,
days: int = 30,
):
"""Активировать premium статус"""
user = await use_cases.activate_premium(telegram_id, days=days)
return UserResponse.from_entity(user)
@router.get("/{user_id}", response_model=UserResponse) @router.get("/{user_id}", response_model=UserResponse)
@inject
async def get_user( async def get_user(
user_id: UUID, user_id: UUID,
use_cases: FromDishka[UserUseCases] = FromDishka() use_cases: Annotated[UserUseCases, FromDishka()]
): ):
"""Получить пользователя по ID""" """Получить пользователя по ID"""
user = await use_cases.get_user(user_id) user = await use_cases.get_user(user_id)
@ -45,10 +90,11 @@ async def get_user(
@router.put("/{user_id}", response_model=UserResponse) @router.put("/{user_id}", response_model=UserResponse)
@inject
async def update_user( async def update_user(
user_id: UUID, user_id: UUID,
user_data: UserUpdate, user_data: UserUpdate,
use_cases: FromDishka[UserUseCases] = FromDishka() use_cases: Annotated[UserUseCases, FromDishka()]
): ):
"""Обновить пользователя""" """Обновить пользователя"""
user = await use_cases.update_user( user = await use_cases.update_user(
@ -60,9 +106,10 @@ async def update_user(
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
@inject
async def delete_user( async def delete_user(
user_id: UUID, user_id: UUID,
use_cases: FromDishka[UserUseCases] = FromDishka() use_cases: Annotated[UserUseCases, FromDishka()]
): ):
"""Удалить пользователя""" """Удалить пользователя"""
await use_cases.delete_user(user_id) await use_cases.delete_user(user_id)
@ -70,10 +117,11 @@ async def delete_user(
@router.get("", response_model=List[UserResponse]) @router.get("", response_model=List[UserResponse])
@inject
async def list_users( async def list_users(
use_cases: Annotated[UserUseCases, FromDishka()],
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100
use_cases: FromDishka[UserUseCases] = FromDishka()
): ):
"""Получить список пользователей""" """Получить список пользователей"""
users = await use_cases.list_users(skip=skip, limit=limit) users = await use_cases.list_users(skip=skip, limit=limit)

View File

@ -1,11 +1,11 @@
"""
Главный файл FastAPI приложения
"""
import sys import sys
import os import os
import asyncio
from pathlib import Path
if '/app' not in sys.path: backend_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, '/app') if str(backend_dir) not in sys.path:
sys.path.insert(0, str(backend_dir))
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -16,22 +16,24 @@ from src.shared.config import settings
from src.shared.exceptions import LawyerAIException from src.shared.exceptions import LawyerAIException
from src.shared.di_container import create_container from src.shared.di_container import create_container
from src.presentation.middleware.error_handler import exception_handler from src.presentation.middleware.error_handler import exception_handler
from src.presentation.api.v1 import users, collections, documents, conversations, messages from src.presentation.api.v1 import users, collections, documents, conversations, messages, rag
from src.infrastructure.database.base import engine, Base from src.infrastructure.database.base import engine, Base
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Управление жизненным циклом приложения""" """Управление жизненным циклом приложения"""
container = create_container()
setup_dishka(container, app)
try: try:
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
except Exception as e: except Exception as e:
print(f"Примечание при создании таблиц: {e}") print(f"Примечание при создании таблиц: {e}")
yield yield
await container.close() if hasattr(app.state, 'container') and hasattr(app.state.container, 'close'):
if asyncio.iscoroutinefunction(app.state.container.close):
await app.state.container.close()
else:
app.state.container.close()
await engine.dispose() await engine.dispose()
@ -42,6 +44,10 @@ app = FastAPI(
lifespan=lifespan lifespan=lifespan
) )
container = create_container()
setup_dishka(container, app)
app.state.container = container
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=settings.CORS_ORIGINS, allow_origins=settings.CORS_ORIGINS,
@ -57,6 +63,7 @@ app.include_router(collections.router, prefix="/api/v1")
app.include_router(documents.router, prefix="/api/v1") app.include_router(documents.router, prefix="/api/v1")
app.include_router(conversations.router, prefix="/api/v1") app.include_router(conversations.router, prefix="/api/v1")
app.include_router(messages.router, prefix="/api/v1") app.include_router(messages.router, prefix="/api/v1")
app.include_router(rag.router, prefix="/api/v1")
try: try:
from src.presentation.api.v1 import admin from src.presentation.api.v1 import admin

View File

@ -75,3 +75,22 @@ class CollectionAccessResponse(BaseModel):
class Config: class Config:
from_attributes = True from_attributes = True
class CollectionAccessUserInfo(BaseModel):
"""Информация о пользователе с доступом"""
user_id: UUID
telegram_id: str
role: str
created_at: datetime
class CollectionAccessListResponse(BaseModel):
"""Схема ответа со списком доступа"""
access_id: UUID
user: CollectionAccessUserInfo
collection_id: UUID
created_at: datetime
class Config:
from_attributes = True

View File

@ -0,0 +1,28 @@
"""
Схемы для RAG
"""
from uuid import UUID
from pydantic import BaseModel, Field
from typing import List, Any
class QuestionRequest(BaseModel):
conversation_id: UUID
question: str = Field(..., min_length=3)
top_k: int = 20
rerank_top_n: int = 5
class RAGSource(BaseModel):
index: int
document_id: str
chunk_id: str
title: str | None = None
class RAGAnswer(BaseModel):
answer: str
sources: List[RAGSource] = []
usage: dict[str, Any] = {}

View File

@ -30,6 +30,9 @@ class UserResponse(BaseModel):
telegram_id: str telegram_id: str
role: UserRole role: UserRole
created_at: datetime created_at: datetime
is_premium: bool = False
premium_until: datetime | None = None
questions_used: int = 0
@classmethod @classmethod
def from_entity(cls, user: "User") -> "UserResponse": def from_entity(cls, user: "User") -> "UserResponse":
@ -38,7 +41,10 @@ class UserResponse(BaseModel):
user_id=user.user_id, user_id=user.user_id,
telegram_id=user.telegram_id, telegram_id=user.telegram_id,
role=user.role, role=user.role,
created_at=user.created_at created_at=user.created_at,
is_premium=user.is_premium,
premium_until=user.premium_until,
questions_used=user.questions_used
) )
class Config: class Config:

View File

@ -1,7 +1,4 @@
""" from dishka import Container, Provider, Scope, provide, make_async_container
DI контейнер на основе dishka
"""
from dishka import Container, Provider, Scope, provide
from fastapi import Request from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -19,26 +16,32 @@ from src.domain.repositories.document_repository import IDocumentRepository
from src.domain.repositories.conversation_repository import IConversationRepository from src.domain.repositories.conversation_repository import IConversationRepository
from src.domain.repositories.message_repository import IMessageRepository from src.domain.repositories.message_repository import IMessageRepository
from src.domain.repositories.collection_access_repository import ICollectionAccessRepository from src.domain.repositories.collection_access_repository import ICollectionAccessRepository
from src.domain.repositories.vector_repository import IVectorRepository
from src.infrastructure.external.yandex_ocr import YandexOCRService from src.infrastructure.external.yandex_ocr import YandexOCRService
from src.infrastructure.external.deepseek_client import DeepSeekClient from src.infrastructure.external.deepseek_client import DeepSeekClient
from src.infrastructure.external.redis_client import RedisClient
from src.application.services.document_parser_service import DocumentParserService from src.application.services.document_parser_service import DocumentParserService
from src.application.services.cache_service import CacheService
from src.application.use_cases.user_use_cases import UserUseCases from src.application.use_cases.user_use_cases import UserUseCases
from src.application.use_cases.collection_use_cases import CollectionUseCases from src.application.use_cases.collection_use_cases import CollectionUseCases
from src.application.use_cases.document_use_cases import DocumentUseCases from src.application.use_cases.document_use_cases import DocumentUseCases
from src.application.use_cases.conversation_use_cases import ConversationUseCases from src.application.use_cases.conversation_use_cases import ConversationUseCases
from src.application.use_cases.message_use_cases import MessageUseCases from src.application.use_cases.message_use_cases import MessageUseCases
from src.domain.entities.user import User from src.domain.entities.user import User
from src.shared.config import settings
from qdrant_client import QdrantClient
from src.infrastructure.repositories.qdrant.vector_repository import QdrantVectorRepository
from src.application.services.embedding_service import EmbeddingService
from src.application.services.reranker_service import RerankerService
from src.application.services.rag_service import RAGService
from src.application.services.text_splitter import TextSplitter
from src.application.use_cases.rag_use_cases import RAGUseCases
class DatabaseProvider(Provider): class DatabaseProvider(Provider):
@provide(scope=Scope.REQUEST) @provide(scope=Scope.REQUEST)
@asynccontextmanager
async def get_db(self) -> AsyncSession: async def get_db(self) -> AsyncSession:
async with AsyncSessionLocal() as session: session = AsyncSessionLocal()
try: return session
yield session
finally:
await session.close()
class RepositoryProvider(Provider): class RepositoryProvider(Provider):
@ -68,6 +71,14 @@ class RepositoryProvider(Provider):
class ServiceProvider(Provider): class ServiceProvider(Provider):
@provide(scope=Scope.APP)
def get_redis_client(self) -> RedisClient:
return RedisClient(host=settings.REDIS_HOST, port=settings.REDIS_PORT)
@provide(scope=Scope.APP)
def get_cache_service(self, redis_client: RedisClient) -> CacheService:
return CacheService(redis_client)
@provide(scope=Scope.APP) @provide(scope=Scope.APP)
def get_ocr_service(self) -> YandexOCRService: def get_ocr_service(self) -> YandexOCRService:
return YandexOCRService() return YandexOCRService()
@ -80,12 +91,42 @@ class ServiceProvider(Provider):
def get_parser_service(self, ocr_service: YandexOCRService) -> DocumentParserService: def get_parser_service(self, ocr_service: YandexOCRService) -> DocumentParserService:
return DocumentParserService(ocr_service) return DocumentParserService(ocr_service)
@provide(scope=Scope.APP)
def get_qdrant_client(self) -> QdrantClient:
return QdrantClient(host=settings.QDRANT_HOST, port=settings.QDRANT_PORT)
class AuthProvider(Provider): @provide(scope=Scope.APP)
@provide(scope=Scope.REQUEST) def get_vector_repository(self, client: QdrantClient) -> IVectorRepository:
async def get_current_user(self, request: Request, user_repo: IUserRepository) -> User: return QdrantVectorRepository(client=client, vector_size=768)
from src.presentation.middleware.auth_middleware import get_current_user
return await get_current_user(request, user_repo) @provide(scope=Scope.APP)
def get_embedding_service(self) -> EmbeddingService:
return EmbeddingService()
@provide(scope=Scope.APP)
def get_reranker_service(self, embedding_service: EmbeddingService) -> RerankerService:
return RerankerService(fallback_encoder=embedding_service.model)
@provide(scope=Scope.APP)
def get_text_splitter(self) -> TextSplitter:
return TextSplitter()
@provide(scope=Scope.APP)
def get_rag_service(
self,
vector_repo: IVectorRepository,
embedding_service: EmbeddingService,
reranker_service: RerankerService,
deepseek_client: DeepSeekClient,
text_splitter: TextSplitter
) -> RAGService:
return RAGService(
vector_repository=vector_repo,
embedding_service=embedding_service,
reranker_service=reranker_service,
deepseek_client=deepseek_client,
splitter=text_splitter,
)
class UseCaseProvider(Provider): class UseCaseProvider(Provider):
@ -110,9 +151,11 @@ class UseCaseProvider(Provider):
self, self,
document_repo: IDocumentRepository, document_repo: IDocumentRepository,
collection_repo: ICollectionRepository, collection_repo: ICollectionRepository,
parser_service: DocumentParserService access_repo: ICollectionAccessRepository,
parser_service: DocumentParserService,
rag_service: RAGService
) -> DocumentUseCases: ) -> DocumentUseCases:
return DocumentUseCases(document_repo, collection_repo, parser_service) return DocumentUseCases(document_repo, collection_repo, access_repo, parser_service, rag_service)
@provide(scope=Scope.REQUEST) @provide(scope=Scope.REQUEST)
def get_conversation_use_cases( def get_conversation_use_cases(
@ -131,13 +174,23 @@ class UseCaseProvider(Provider):
) -> MessageUseCases: ) -> MessageUseCases:
return MessageUseCases(message_repo, conversation_repo) return MessageUseCases(message_repo, conversation_repo)
@provide(scope=Scope.REQUEST)
def get_rag_use_cases(
self,
rag_service: RAGService,
document_repo: IDocumentRepository,
conversation_repo: IConversationRepository,
message_repo: IMessageRepository,
cache_service: CacheService
) -> RAGUseCases:
return RAGUseCases(rag_service, document_repo, conversation_repo, message_repo, cache_service)
def create_container() -> Container: def create_container() -> Container:
container = Container() return make_async_container(
container.add_provider(DatabaseProvider()) DatabaseProvider(),
container.add_provider(RepositoryProvider()) RepositoryProvider(),
container.add_provider(ServiceProvider()) ServiceProvider(),
container.add_provider(AuthProvider()) UseCaseProvider()
container.add_provider(UseCaseProvider()) )
return container

View File

@ -1,93 +0,0 @@
import os
import sys
from sqlalchemy import create_engine, inspect
from sqlalchemy.orm import declarative_base, Session
from sqlalchemy import Column, String, DateTime, Boolean, Integer, Text
import uuid
from datetime import datetime
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DB_PATH = os.path.join(BASE_DIR, 'data', 'bot.db')
DATABASE_URL = f"sqlite:///{DB_PATH}"
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
if os.path.exists(DB_PATH):
try:
temp_engine = create_engine(DATABASE_URL)
inspector = inspect(temp_engine)
tables = inspector.get_table_names()
if tables:
sys.exit(0)
except:
pass
choice = input("Перезаписать БД? (y/N): ")
if choice.lower() != 'y':
sys.exit(0)
engine = create_engine(DATABASE_URL, echo=False)
Base = declarative_base()
class UserModel(Base):
__tablename__ = "users"
user_id = Column("user_id", String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
telegram_id = Column("telegram_id", String(100), nullable=False, unique=True)
created_at = Column("created_at", DateTime, default=datetime.utcnow, nullable=False)
role = Column("role", String(20), default="user", nullable=False)
is_premium = Column(Boolean, default=False, nullable=False)
premium_until = Column(DateTime, nullable=True)
questions_used = Column(Integer, default=0, nullable=False)
username = Column(String(100), nullable=True)
first_name = Column(String(100), nullable=True)
last_name = Column(String(100), nullable=True)
class PaymentModel(Base):
__tablename__ = "payments"
id = Column(Integer, primary_key=True, autoincrement=True)
payment_id = Column(String(36), default=lambda: str(uuid.uuid4()), nullable=False, unique=True)
user_id = Column(Integer, nullable=False)
amount = Column(String(20), nullable=False)
currency = Column(String(3), default="RUB", nullable=False)
status = Column(String(20), default="pending", nullable=False)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
yookassa_payment_id = Column(String(100), unique=True, nullable=True)
description = Column(Text, nullable=True)
try:
Base.metadata.create_all(bind=engine)
session = Session(bind=engine)
existing = session.query(UserModel).filter_by(telegram_id="123456789").first()
if not existing:
test_user = UserModel(
telegram_id="123456789",
username="test_user",
first_name="Test",
last_name="User",
is_premium=True
)
session.add(test_user)
existing_payment = session.query(PaymentModel).filter_by(yookassa_payment_id="test_yoo_001").first()
if not existing_payment:
test_payment = PaymentModel(
user_id=123456789,
amount="500.00",
status="succeeded",
description="Test payment",
yookassa_payment_id="test_yoo_001"
)
session.add(test_payment)
session.commit()
session.close()
except Exception as e:
print(f"Ошибка: {e}")
import traceback
traceback.print_exc()

View File

@ -1,31 +0,0 @@
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
try:
from tg_bot.infrastructure.database.database import engine, Base
from tg_bot.infrastructure.database import models
print("СОЗДАНИЕ ТАБЛИЦ БАЗЫ ДАННЫХ")
Base.metadata.create_all(bind=engine)
print("Таблицы успешно созданы!")
print(" • users")
print(" • payments")
print()
print(f"База данных: {engine.url}")
db_path = "data/bot.db"
if os.path.exists(db_path):
size = os.path.getsize(db_path)
print(f"Размер файла: {size} байт")
else:
print("Файл БД не найден, но таблицы созданы")
except Exception as e:
print(f"Ошибка: {e}")
import traceback
traceback.print_exc()
print("=" * 50)

82
docker-compose.yml Normal file
View File

@ -0,0 +1,82 @@
services:
postgres:
image: postgres:15-alpine
restart: unless-stopped
env_file:
- .env
environment:
POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
POSTGRES_DB: ${POSTGRES_DB}
ports:
- "5432:5432"
volumes:
- postgres_data:/var/lib/postgresql/data
qdrant:
image: qdrant/qdrant:latest
restart: unless-stopped
ports:
- "6333:6333"
volumes:
- qdrant_data:/qdrant/storage
redis:
image: redis:7-alpine
restart: unless-stopped
ports:
- "6379:6379"
volumes:
- redis_data:/data
backend:
build: ./backend
restart: unless-stopped
env_file:
- .env
environment:
POSTGRES_HOST: postgres
POSTGRES_PORT: ${POSTGRES_PORT}
POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
POSTGRES_DB: ${POSTGRES_DB}
QDRANT_HOST: qdrant
QDRANT_PORT: ${QDRANT_PORT}
REDIS_HOST: redis
REDIS_PORT: ${REDIS_PORT}
DEBUG: "true"
SECRET_KEY: ${SECRET_KEY}
APP_NAME: ${APP_NAME}
CORS_ORIGINS: ${CORS_ORIGINS}
ports:
- "8000:8000"
depends_on:
- postgres
- qdrant
- redis
tg_bot:
build: ./tg_bot
restart: unless-stopped
env_file:
- .env
environment:
POSTGRES_HOST: postgres
POSTGRES_PORT: ${POSTGRES_PORT}
POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
POSTGRES_DB: ${POSTGRES_DB}
TELEGRAM_BOT_TOKEN: ${TELEGRAM_BOT_TOKEN}
DEEPSEEK_API_KEY: ${DEEPSEEK_API_KEY}
DEEPSEEK_API_URL: ${DEEPSEEK_API_URL:-https://api.deepseek.com/v1/chat/completions}
YANDEX_OCR_API_KEY: ${YANDEX_OCR_API_KEY}
BACKEND_URL: ${BACKEND_URL:-http://backend:8000/api/v1}
DEBUG: "true"
depends_on:
- postgres
- backend
volumes:
postgres_data:
qdrant_data:
redis_data:

23
pytest.ini Normal file
View File

@ -0,0 +1,23 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
asyncio_mode = auto
addopts =
-v
--strict-markers
--tb=short
--cov=backend/src
--cov=tg_bot
--cov-report=term-missing
--cov-report=xml
--cov-fail-under=65
--ignore=venv
--ignore=.venv
markers =
unit: Unit tests
integration: Integration tests
metrics: Metrics tests
slow: Slow running tests

View File

@ -1,11 +0,0 @@
pydantic>=2.5.0
pydantic-settings>=2.1.0
python-dotenv>=1.0.0
aiogram>=3.10.0
sqlalchemy>=2.0.0
aiosqlite>=0.19.0
httpx>=0.25.0
yookassa>=2.4.0
fastapi>=0.104.0
uvicorn>=0.24.0
python-multipart>=0.0.6

94
tests/README.md Normal file
View File

@ -0,0 +1,94 @@
# Тесты для проекта BetterCallPraskovia
## Структура тестов
```
tests/
├── conftest.py
├── unit/
│ ├── test_rag_service.py
│ ├── test_user_service.py
│ ├── test_deepseek_client.py
│ ├── test_document_use_cases.py
│ └── test_collection_use_cases.py
├── integration/
│ └── test_rag_integration.py
└── metrics/
└── test_hit_at_5.py
```
## Установка зависимостей
```bash
pip install -r tests/requirements.txt
```
## Запуск тестов
### Все тесты
```bash
pytest
```
### Только юнит-тесты
```bash
pytest tests/unit/
```
### Только интеграционные тесты
```bash
pytest tests/integration/
```
### Только тесты метрик
```bash
pytest tests/metrics/
```
### Только тесты tg_bot
```bash
pytest tests/unit/test_rag_service.py tests/unit/test_user_service.py tests/unit/test_deepseek_client.py
```
### С покрытием кода
```bash
pytest --cov=backend/src --cov=tg_bot --cov-report=html
```
### С минимальным покрытием 65%
```bash
pytest --cov-fail-under=65
```
## Метрика hit@5
Проверка что в топ-5 релевантных документов есть хотя бы 1 нужный документ.
- **hit@5 = 1**, если есть хотя бы 1 релевантный документ в топ-5
- **hit@5 = 0**, если нет релевантных документов в топ-5
Среднее значение hit@5 для всех запросов должно быть **> 50%**
## Покрытие кода
**coverage ≥ 65%**
Проверка покрытия:
```bash
pytest --cov=backend/src --cov=tg_bot --cov-report=term-missing --cov-fail-under=65
```
## Маркеры тестов
- `@pytest.mark.unit` - юнит-тесты
- `@pytest.mark.integration` - интеграционные тесты
- `@pytest.mark.metrics` - тесты метрик
- `@pytest.mark.slow` - медленные тесты
Запуск по маркерам:
```bash
pytest -m unit
pytest -m integration
pytest -m metrics
```

170
tests/conftest.py Normal file
View File

@ -0,0 +1,170 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
from datetime import datetime
from typing import List, Dict
@pytest.fixture
def mock_document():
try:
from backend.src.domain.entities.document import Document
except ImportError:
from src.domain.entities.document import Document
return Document(
collection_id=uuid4(),
title="Тестовый документ",
content="Содержание документа",
metadata={"type": "law", "article": "123"}
)
@pytest.fixture
def mock_documents_list():
try:
from backend.src.domain.entities.document import Document
except ImportError:
from src.domain.entities.document import Document
collection_id = uuid4()
return [
Document(
collection_id=collection_id,
title=f"Документ {i}",
content=f"Содержание документа {i} ",
metadata={"relevance_score": 0.9 - i * 0.1}
)
for i in range(10)
]
@pytest.fixture
def mock_relevant_documents():
return [
{"document_id": str(uuid4()), "title": "Гражданский кодекс РФ", "relevance": True},
{"document_id": str(uuid4()), "title": "Трудовой кодекс РФ", "relevance": True},
{"document_id": str(uuid4()), "title": "Налоговый кодекс РФ", "relevance": True},
]
@pytest.fixture
def mock_rag_response():
return {
"answer": "Тестовый ответ на вопрос",
"sources": [
{"title": "Документ 1", "collection": "Коллекция 1", "document_id": str(uuid4())},
{"title": "Документ 2", "collection": "Коллекция 1", "document_id": str(uuid4())},
{"title": "Документ 3", "collection": "Коллекция 2", "document_id": str(uuid4())},
],
"usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}
}
@pytest.fixture
def mock_collection():
try:
from backend.src.domain.entities.collection import Collection
except ImportError:
from src.domain.entities.collection import Collection
return Collection(
name="Коллекция",
owner_id=uuid4(),
description="Описание коллекции",
is_public=False
)
@pytest.fixture
def mock_user():
try:
from backend.src.domain.entities.user import User, UserRole
except ImportError:
from src.domain.entities.user import User, UserRole
return User(
telegram_id="123456789",
role=UserRole.USER
)
@pytest.fixture
def mock_document_repository():
repository = AsyncMock()
repository.get_by_id = AsyncMock()
repository.create = AsyncMock()
repository.update = AsyncMock()
repository.delete = AsyncMock(return_value=True)
repository.list_by_collection = AsyncMock(return_value=[])
return repository
@pytest.fixture
def mock_collection_repository():
repository = AsyncMock()
repository.get_by_id = AsyncMock()
repository.create = AsyncMock()
repository.update = AsyncMock()
repository.delete = AsyncMock(return_value=True)
repository.list_by_owner = AsyncMock(return_value=[])
repository.list_public = AsyncMock(return_value=[])
return repository
@pytest.fixture
def mock_user_repository():
repository = AsyncMock()
repository.get_by_id = AsyncMock()
repository.get_by_telegram_id = AsyncMock()
repository.create = AsyncMock()
repository.update = AsyncMock()
repository.delete = AsyncMock(return_value=True)
return repository
@pytest.fixture
def mock_deepseek_client():
client = AsyncMock()
client.chat_completion = AsyncMock(return_value={
"content": "Ответ от DeepSeek",
"usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}
})
return client
@pytest.fixture
def sample_queries_with_ground_truth():
return [
{
"query": "Какие права имеет работник при увольнении?",
"relevant_document_ids": [str(uuid4()), str(uuid4())],
"expected_top5_contains": True
},
{
"query": "Как оформить договор купли-продажи?",
"relevant_document_ids": [str(uuid4())],
"expected_top5_contains": True
},
{
"query": "Какие налоги платит ИП?",
"relevant_document_ids": [str(uuid4()), str(uuid4()), str(uuid4())],
"expected_top5_contains": True
},
{
"query": "Права потребителя при возврате товара",
"relevant_document_ids": [str(uuid4())],
"expected_top5_contains": True
},
{
"query": "Как расторгнуть брак?",
"relevant_document_ids": [str(uuid4())],
"expected_top5_contains": True
},
]
@pytest.fixture
def mock_aiohttp_session():
session = AsyncMock()
session.get = AsyncMock()
session.post = AsyncMock()
session.__aenter__ = AsyncMock(return_value=session)
session.__aexit__ = AsyncMock(return_value=None)
return session

View File

@ -0,0 +1,146 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import uuid4
import aiohttp
from tg_bot.application.services.rag_service import RAGService
from tests.metrics.test_hit_at_5 import calculate_hit_at_5, calculate_average_hit_at_5
class TestRAGIntegration:
@pytest.fixture
def rag_service(self):
service = RAGService()
from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient
service.deepseek_client = DeepSeekClient()
return service
@pytest.mark.asyncio
async def test_rag_service_search_documents_real_implementation(self, rag_service):
user_telegram_id = "123456789"
query = "трудовой договор"
user_uuid = str(uuid4())
collection_id = str(uuid4())
mock_user_response = AsyncMock()
mock_user_response.status = 200
mock_user_response.json = AsyncMock(return_value={"user_id": user_uuid})
mock_collections_response = AsyncMock()
mock_collections_response.status = 200
mock_collections_response.json = AsyncMock(return_value=[
{"collection_id": collection_id, "name": "Законы"}
])
mock_documents_response = AsyncMock()
mock_documents_response.status = 200
mock_documents_response.json = AsyncMock(return_value=[
{
"document_id": str(uuid4()),
"title": "Трудовой кодекс",
"content": "Содержание о трудовых договорах"
}
])
with patch('aiohttp.ClientSession') as mock_session_class:
mock_session = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=None)
mock_session.get = AsyncMock(side_effect=[
mock_user_response,
mock_collections_response,
mock_documents_response
])
mock_session_class.return_value = mock_session
result = await rag_service.search_documents_in_collections(user_telegram_id, query)
assert isinstance(result, list)
if result:
assert "document_id" in result[0]
assert "title" in result[0]
@pytest.mark.asyncio
async def test_rag_service_generate_answer_real_implementation(self, rag_service):
question = "Какие права имеет работник?"
user_telegram_id = "123456789"
mock_documents = [
{
"document_id": str(uuid4()),
"title": "Трудовой кодекс",
"content": "Работник имеет право на...",
"collection_name": "Законы"
}
for _ in range(3)
]
with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \
patch.object(rag_service.deepseek_client, 'chat_completion') as mock_deepseek:
mock_search.return_value = mock_documents
mock_deepseek.return_value = {
"content": "Работник имеет следующие права...",
"usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}
}
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
assert "answer" in result
assert "sources" in result
assert "usage" in result
assert len(result["sources"]) <= 5
assert result["answer"] != ""
for source in result["sources"]:
assert "title" in source
assert "collection" in source
assert "document_id" in source
@pytest.mark.asyncio
async def test_rag_service_limits_to_top5_documents(self, rag_service):
question = "Тестовый вопрос"
user_telegram_id = "123456789"
many_documents = [
{
"document_id": str(uuid4()),
"title": f"Документ {i}",
"content": f"Содержание {i}",
"collection_name": "Коллекция"
}
for i in range(10)
]
with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \
patch.object(rag_service.deepseek_client, 'chat_completion') as mock_deepseek:
mock_search.return_value = many_documents
mock_deepseek.return_value = {
"content": "Ответ",
"usage": {}
}
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
assert len(result["sources"]) == 5
@pytest.mark.asyncio
async def test_rag_service_handles_empty_search_results(self, rag_service):
question = "Вопрос без документов"
user_telegram_id = "123456789"
with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \
patch.object(rag_service.deepseek_client, 'chat_completion') as mock_deepseek:
mock_search.return_value = []
mock_deepseek.return_value = {
"content": "Ответ",
"usage": {}
}
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
assert result["sources"] == []
assert "Релевантные документы не найдены" in result.get("answer", "") or \
result["answer"] == "No relevant documents found"

View File

@ -0,0 +1,133 @@
import pytest
from uuid import uuid4
from typing import List, Dict
def calculate_hit_at_5(retrieved_document_ids: List[str], relevant_document_ids: List[str]) -> int:
if not retrieved_document_ids or not relevant_document_ids:
return 0
top5_ids = set(retrieved_document_ids[:5])
relevant_ids = set(relevant_document_ids)
return 1 if top5_ids.intersection(relevant_ids) else 0
def calculate_average_hit_at_5(results: List[int]) -> float:
if not results:
return 0.0
return sum(results) / len(results)
class TestHitAt5Metric:
def test_hit_at_5_returns_1_when_relevant_document_in_top5(self):
relevant_ids = [str(uuid4()), str(uuid4())]
retrieved_ids = [
str(uuid4()),
relevant_ids[0],
str(uuid4()),
str(uuid4()),
str(uuid4())
]
result = calculate_hit_at_5(retrieved_ids, relevant_ids)
assert result == 1
def test_hit_at_5_returns_0_when_no_relevant_document_in_top5(self):
relevant_ids = [str(uuid4()), str(uuid4())]
retrieved_ids = [
str(uuid4()),
str(uuid4()),
str(uuid4()),
str(uuid4()),
str(uuid4())
]
result = calculate_hit_at_5(retrieved_ids, relevant_ids)
assert result == 0
def test_hit_at_5_returns_1_when_multiple_relevant_documents(self):
relevant_ids = [str(uuid4()), str(uuid4()), str(uuid4())]
retrieved_ids = [
relevant_ids[0],
str(uuid4()),
relevant_ids[1],
str(uuid4()),
relevant_ids[2]
]
result = calculate_hit_at_5(retrieved_ids, relevant_ids)
assert result == 1
def test_hit_at_5_handles_empty_lists(self):
result = calculate_hit_at_5([], [str(uuid4())])
assert result == 0
result = calculate_hit_at_5([str(uuid4())], [])
assert result == 0
result = calculate_hit_at_5([], [])
assert result == 0
def test_hit_at_5_only_checks_top5(self):
relevant_ids = [str(uuid4())]
retrieved_ids = [
str(uuid4()),
str(uuid4()),
str(uuid4()),
str(uuid4()),
str(uuid4()),
relevant_ids[0]
]
result = calculate_hit_at_5(retrieved_ids, relevant_ids)
assert result == 0
def test_calculate_average_hit_at_5(self):
results = [1, 1, 0, 1, 0, 1, 0, 1, 1, 1]
average = calculate_average_hit_at_5(results)
assert average == 0.7
def test_calculate_average_hit_at_5_all_ones(self):
results = [1, 1, 1, 1, 1]
average = calculate_average_hit_at_5(results)
assert average == 1.0
def test_calculate_average_hit_at_5_all_zeros(self):
results = [0, 0, 0, 0, 0]
average = calculate_average_hit_at_5(results)
assert average == 0.0
def test_calculate_average_hit_at_5_empty_list(self):
average = calculate_average_hit_at_5([])
assert average == 0.0
def test_hit_at_5_quality_threshold(self):
results = [1] * 60 + [0] * 40
average = calculate_average_hit_at_5(results)
assert average > 0.5, f"Качество {average} должно быть > 0.5"
assert average == 0.6
def test_hit_at_5_quality_below_threshold(self):
results = [1] * 40 + [0] * 60
average = calculate_average_hit_at_5(results)
assert average < 0.5, f"Качество {average} должно быть < 0.5"
assert average == 0.4
@pytest.mark.parametrize("hit_count,total,expected_quality", [
(51, 100, 0.51),
(50, 100, 0.50),
(60, 100, 0.60),
(75, 100, 0.75),
(100, 100, 1.0),
])
def test_hit_at_5_various_qualities(self, hit_count, total, expected_quality):
results = [1] * hit_count + [0] * (total - hit_count)
average = calculate_average_hit_at_5(results)
assert average == expected_quality

8
tests/requirements.txt Normal file
View File

@ -0,0 +1,8 @@
pytest==7.4.3
pytest-asyncio==0.21.1
pytest-cov==4.1.0
pytest-mock==3.12.0
pytest-timeout==2.2.0
httpx>=0.25.2
aiohttp>=3.9.1

View File

@ -0,0 +1,132 @@
import pytest
from uuid import uuid4
from unittest.mock import AsyncMock
try:
from backend.src.application.use_cases.collection_use_cases import CollectionUseCases
from backend.src.shared.exceptions import NotFoundError, ForbiddenError
from backend.src.domain.repositories.collection_access_repository import ICollectionAccessRepository
except ImportError:
from src.application.use_cases.collection_use_cases import CollectionUseCases
from src.shared.exceptions import NotFoundError, ForbiddenError
from src.domain.repositories.collection_access_repository import ICollectionAccessRepository
class TestCollectionUseCases:
@pytest.fixture
def collection_use_cases(self, mock_collection_repository, mock_user_repository):
mock_access_repository = AsyncMock()
mock_access_repository.get_by_user_and_collection = AsyncMock(return_value=None)
mock_access_repository.create = AsyncMock()
mock_access_repository.delete_by_user_and_collection = AsyncMock(return_value=True)
mock_access_repository.list_by_user = AsyncMock(return_value=[])
return CollectionUseCases(
collection_repository=mock_collection_repository,
access_repository=mock_access_repository,
user_repository=mock_user_repository
)
@pytest.mark.asyncio
async def test_create_collection_success(self, collection_use_cases, mock_user,
mock_collection_repository, mock_user_repository):
owner_id = uuid4()
mock_user_repository.get_by_id = AsyncMock(return_value=mock_user)
mock_collection_repository.create = AsyncMock(return_value=mock_user)
result = await collection_use_cases.create_collection(
name="Тестовая коллекция",
owner_id=owner_id,
description="Описание",
is_public=False
)
assert result is not None
mock_user_repository.get_by_id.assert_called_once_with(owner_id)
mock_collection_repository.create.assert_called_once()
@pytest.mark.asyncio
async def test_create_collection_user_not_found(self, collection_use_cases, mock_user_repository):
owner_id = uuid4()
mock_user_repository.get_by_id = AsyncMock(return_value=None)
with pytest.raises(NotFoundError):
await collection_use_cases.create_collection(
name="Коллекция",
owner_id=owner_id
)
@pytest.mark.asyncio
async def test_get_collection_success(self, collection_use_cases, mock_collection, mock_collection_repository):
collection_id = uuid4()
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
result = await collection_use_cases.get_collection(collection_id)
assert result == mock_collection
mock_collection_repository.get_by_id.assert_called_once_with(collection_id)
@pytest.mark.asyncio
async def test_get_collection_not_found(self, collection_use_cases, mock_collection_repository):
collection_id = uuid4()
mock_collection_repository.get_by_id = AsyncMock(return_value=None)
with pytest.raises(NotFoundError):
await collection_use_cases.get_collection(collection_id)
@pytest.mark.asyncio
async def test_update_collection_success(self, collection_use_cases, mock_collection, mock_collection_repository):
collection_id = uuid4()
user_id = uuid4()
mock_collection.owner_id = user_id
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
mock_collection_repository.update = AsyncMock(return_value=mock_collection)
result = await collection_use_cases.update_collection(
collection_id=collection_id,
user_id=user_id,
name="Обновленное название"
)
assert result is not None
assert mock_collection.name == "Обновленное название"
@pytest.mark.asyncio
async def test_update_collection_forbidden(self, collection_use_cases, mock_collection, mock_collection_repository):
collection_id = uuid4()
user_id = uuid4()
owner_id = uuid4()
mock_collection.owner_id = owner_id
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
with pytest.raises(ForbiddenError):
await collection_use_cases.update_collection(
collection_id=collection_id,
user_id=user_id,
name="Название"
)
@pytest.mark.asyncio
async def test_check_access_owner(self, collection_use_cases, mock_collection, mock_collection_repository):
collection_id = uuid4()
user_id = uuid4()
mock_collection.owner_id = user_id
mock_collection.is_public = False
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
result = await collection_use_cases.check_access(collection_id, user_id)
assert result is True
@pytest.mark.asyncio
async def test_check_access_public(self, collection_use_cases, mock_collection, mock_collection_repository):
collection_id = uuid4()
user_id = uuid4()
owner_id = uuid4()
mock_collection.owner_id = owner_id
mock_collection.is_public = True
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
result = await collection_use_cases.check_access(collection_id, user_id)
assert result is True

View File

@ -0,0 +1,114 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
import httpx
from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient
from tg_bot.infrastructure.external.deepseek_client import DeepSeekAPIError
class TestDeepSeekClient:
@pytest.fixture
def deepseek_client(self):
return DeepSeekClient(api_key="test_key", api_url="https://api.test.com/v1/chat/completions")
@pytest.mark.asyncio
async def test_chat_completion_success(self, deepseek_client):
messages = [
{"role": "user", "content": "Тестовый вопрос"}
]
mock_response_data = {
"choices": [{
"message": {
"content": "Тестовый ответ от DeepSeek"
}
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
}
}
with patch('httpx.AsyncClient') as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json.return_value = mock_response_data
mock_response.raise_for_status = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
mock_client_instance.__aexit__ = AsyncMock(return_value=None)
mock_client_instance.post = AsyncMock(return_value=mock_response)
mock_client.return_value = mock_client_instance
result = await deepseek_client.chat_completion(messages)
assert "content" in result
assert result["content"] == "Тестовый ответ от DeepSeek"
assert "usage" in result
assert result["usage"]["total_tokens"] == 30
@pytest.mark.asyncio
async def test_chat_completion_no_api_key(self):
client = DeepSeekClient(api_key=None)
messages = [{"role": "user", "content": "Вопрос"}]
result = await client.chat_completion(messages)
assert "content" in result
assert "DEEPSEEK_API_KEY" in result["content"] or "не установлен" in result["content"]
assert result["usage"]["total_tokens"] == 0
@pytest.mark.asyncio
async def test_chat_completion_api_error(self, deepseek_client):
import httpx
messages = [{"role": "user", "content": "Вопрос"}]
with patch('httpx.AsyncClient') as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 401
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"Unauthorized", request=MagicMock(), response=mock_response
)
mock_client_instance = AsyncMock()
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
mock_client_instance.__aexit__ = AsyncMock(return_value=None)
mock_client_instance.post = AsyncMock(return_value=mock_response)
mock_client.return_value = mock_client_instance
with pytest.raises(DeepSeekAPIError):
await deepseek_client.chat_completion(messages)
@pytest.mark.asyncio
async def test_chat_completion_with_parameters(self, deepseek_client):
messages = [{"role": "user", "content": "Вопрос"}]
mock_response_data = {
"choices": [{"message": {"content": "Ответ"}}],
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
}
with patch('httpx.AsyncClient') as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json.return_value = mock_response_data
mock_response.raise_for_status = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
mock_client_instance.__aexit__ = AsyncMock(return_value=None)
mock_client_instance.post = AsyncMock(return_value=mock_response)
mock_client.return_value = mock_client_instance
result = await deepseek_client.chat_completion(
messages,
model="deepseek-chat",
temperature=0.7,
max_tokens=100
)
assert result["content"] == "Ответ"
call_args = mock_client_instance.post.call_args
assert call_args is not None

View File

@ -0,0 +1,141 @@
import pytest
from uuid import uuid4
from unittest.mock import AsyncMock
try:
from backend.src.application.use_cases.document_use_cases import DocumentUseCases
from backend.src.shared.exceptions import NotFoundError, ForbiddenError
from backend.src.application.services.document_parser_service import DocumentParserService
except ImportError:
from src.application.use_cases.document_use_cases import DocumentUseCases
from src.shared.exceptions import NotFoundError, ForbiddenError
from src.application.services.document_parser_service import DocumentParserService
class TestDocumentUseCases:
@pytest.fixture
def document_use_cases(self, mock_document_repository, mock_collection_repository):
mock_parser = AsyncMock()
mock_parser.parse_pdf = AsyncMock(return_value=("Парсенный документ", "Содержание"))
return DocumentUseCases(
document_repository=mock_document_repository,
collection_repository=mock_collection_repository,
parser_service=mock_parser
)
@pytest.mark.asyncio
async def test_create_document_success(self, document_use_cases, mock_collection, mock_document_repository, mock_collection_repository):
collection_id = uuid4()
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
mock_document_repository.create = AsyncMock(return_value=mock_collection)
result = await document_use_cases.create_document(
collection_id=collection_id,
title="Тестовый документ",
content="Содержание",
metadata={"type": "law"}
)
assert result is not None
mock_collection_repository.get_by_id.assert_called_once_with(collection_id)
mock_document_repository.create.assert_called_once()
@pytest.mark.asyncio
async def test_create_document_collection_not_found(self, document_use_cases, mock_collection_repository):
collection_id = uuid4()
mock_collection_repository.get_by_id = AsyncMock(return_value=None)
with pytest.raises(NotFoundError):
await document_use_cases.create_document(
collection_id=collection_id,
title="Документ",
content="Содержание"
)
@pytest.mark.asyncio
async def test_get_document_success(self, document_use_cases, mock_document, mock_document_repository):
document_id = uuid4()
mock_document_repository.get_by_id = AsyncMock(return_value=mock_document)
result = await document_use_cases.get_document(document_id)
assert result == mock_document
mock_document_repository.get_by_id.assert_called_once_with(document_id)
@pytest.mark.asyncio
async def test_get_document_not_found(self, document_use_cases, mock_document_repository):
document_id = uuid4()
mock_document_repository.get_by_id = AsyncMock(return_value=None)
with pytest.raises(NotFoundError):
await document_use_cases.get_document(document_id)
@pytest.mark.asyncio
async def test_update_document_success(self, document_use_cases, mock_document, mock_collection,
mock_document_repository, mock_collection_repository):
document_id = uuid4()
user_id = uuid4()
mock_document.collection_id = uuid4()
mock_collection.owner_id = user_id
mock_document_repository.get_by_id = AsyncMock(return_value=mock_document)
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
mock_document_repository.update = AsyncMock(return_value=mock_document)
result = await document_use_cases.update_document(
document_id=document_id,
user_id=user_id,
title="Обновленное название"
)
assert result is not None
assert mock_document.title == "Обновленное название"
@pytest.mark.asyncio
async def test_update_document_forbidden(self, document_use_cases, mock_document, mock_collection,
mock_document_repository, mock_collection_repository):
document_id = uuid4()
user_id = uuid4()
owner_id = uuid4()
mock_document.collection_id = uuid4()
mock_collection.owner_id = owner_id
mock_document_repository.get_by_id = AsyncMock(return_value=mock_document)
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
with pytest.raises(ForbiddenError):
await document_use_cases.update_document(
document_id=document_id,
user_id=user_id,
title="Название"
)
@pytest.mark.asyncio
async def test_delete_document_success(self, document_use_cases, mock_document, mock_collection,
mock_document_repository, mock_collection_repository):
document_id = uuid4()
user_id = uuid4()
mock_document.collection_id = uuid4()
mock_collection.owner_id = user_id
mock_document_repository.get_by_id = AsyncMock(return_value=mock_document)
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
mock_document_repository.delete = AsyncMock(return_value=True)
result = await document_use_cases.delete_document(document_id, user_id)
assert result is True
mock_document_repository.delete.assert_called_once_with(document_id)
@pytest.mark.asyncio
async def test_list_collection_documents(self, document_use_cases, mock_collection, mock_documents_list,
mock_collection_repository, mock_document_repository):
collection_id = uuid4()
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
mock_document_repository.list_by_collection = AsyncMock(return_value=mock_documents_list)
result = await document_use_cases.list_collection_documents(collection_id, skip=0, limit=10)
assert len(result) == len(mock_documents_list)
mock_document_repository.list_by_collection.assert_called_once_with(collection_id, skip=0, limit=10)

View File

@ -0,0 +1,171 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import uuid4
from tg_bot.application.services.rag_service import RAGService
class TestRAGService:
@pytest.fixture
def rag_service(self):
service = RAGService()
from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient
service.deepseek_client = DeepSeekClient()
return service
@pytest.mark.asyncio
async def test_search_documents_in_collections_success(self, rag_service):
user_telegram_id = "123456789"
query = "трудовой договор"
mock_documents = [
{
"document_id": str(uuid4()),
"title": "Трудовой кодекс РФ",
"content": "Содержание о трудовых договорах",
"collection_name": "Законы"
},
{
"document_id": str(uuid4()),
"title": "Правила оформления",
"content": "Как оформить трудовой договор",
"collection_name": "Инструкции"
}
]
with patch('aiohttp.ClientSession') as mock_session:
mock_response = AsyncMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
"user_id": str(uuid4())
})
mock_collections_response = AsyncMock()
mock_collections_response.status = 200
mock_collections_response.json = AsyncMock(return_value=[
{"collection_id": str(uuid4()), "name": "Законы"}
])
mock_search_response = AsyncMock()
mock_search_response.status = 200
mock_search_response.json = AsyncMock(return_value=mock_documents)
mock_session_instance = MagicMock()
mock_session_instance.__aenter__ = AsyncMock(return_value=mock_session_instance)
mock_session_instance.__aexit__ = AsyncMock(return_value=None)
mock_session_instance.get = AsyncMock(side_effect=[
mock_response,
mock_collections_response,
mock_search_response
])
mock_session.return_value = mock_session_instance
result = await rag_service.search_documents_in_collections(
user_telegram_id, query, limit_per_collection=5
)
assert len(result) > 0
assert result[0]["title"] == "Трудовой кодекс РФ"
@pytest.mark.asyncio
async def test_search_documents_empty_result(self, rag_service):
user_telegram_id = "123456789"
query = "несуществующий запрос"
with patch('aiohttp.ClientSession') as mock_session:
mock_response = AsyncMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
"user_id": str(uuid4())
})
mock_collections_response = AsyncMock()
mock_collections_response.status = 200
mock_collections_response.json = AsyncMock(return_value=[])
mock_session_instance = MagicMock()
mock_session_instance.__aenter__ = AsyncMock(return_value=mock_session_instance)
mock_session_instance.__aexit__ = AsyncMock(return_value=None)
mock_session_instance.get = AsyncMock(side_effect=[
mock_response,
mock_collections_response
])
mock_session.return_value = mock_session_instance
result = await rag_service.search_documents_in_collections(
user_telegram_id, query
)
assert result == []
@pytest.mark.asyncio
async def test_generate_answer_with_rag_success(self, rag_service, mock_rag_response):
question = "Какие права имеет работник?"
user_telegram_id = "123456789"
with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \
patch.object(rag_service, 'deepseek_client') as mock_client:
mock_search.return_value = [
{
"document_id": str(uuid4()),
"title": "Трудовой кодекс",
"content": "Работник имеет право на...",
"collection_name": "Законы"
}
]
mock_client.chat_completion = AsyncMock(return_value={
"content": "Работник имеет следующие права...",
"usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}
})
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
assert "answer" in result
assert "sources" in result
assert "usage" in result
assert len(result["sources"]) <= 5
assert result["answer"] != ""
@pytest.mark.asyncio
async def test_generate_answer_limits_to_top5(self, rag_service):
question = "Тестовый вопрос"
user_telegram_id = "123456789"
many_documents = [
{
"document_id": str(uuid4()),
"title": f"Документ {i}",
"content": f"Содержание {i}",
"collection_name": "Коллекция"
}
for i in range(20)
]
with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \
patch.object(rag_service, 'deepseek_client') as mock_client:
mock_search.return_value = many_documents
mock_client.chat_completion = AsyncMock(return_value={
"content": "Ответ",
"usage": {}
})
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
assert len(result["sources"]) == 5
@pytest.mark.asyncio
async def test_generate_answer_no_documents(self, rag_service):
question = "Вопрос без документов"
user_telegram_id = "123456789"
with patch.object(rag_service, 'search_documents_in_collections') as mock_search:
mock_search.return_value = []
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
assert result["sources"] == []
assert "Релевантные документы не найдены" in result.get("answer", "") or \
result["answer"] == "No relevant documents found"

View File

@ -0,0 +1,193 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
from datetime import datetime, timedelta
from tg_bot.domain.services.user_service import UserService
from tg_bot.infrastructure.database.models import UserModel
class TestUserService:
@pytest.fixture
def mock_session(self):
session = AsyncMock()
session.execute = AsyncMock()
session.add = MagicMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
return session
@pytest.fixture
def user_service(self, mock_session):
return UserService(mock_session)
@pytest.mark.asyncio
async def test_get_user_by_telegram_id_success(self, user_service, mock_session):
telegram_id = 123456789
mock_user = UserModel(
telegram_id=str(telegram_id),
username="test_user",
first_name="Test",
last_name="User"
)
mock_result = MagicMock()
mock_result.scalar_one_or_none = MagicMock(return_value=mock_user)
mock_session.execute.return_value = mock_result
result = await user_service.get_user_by_telegram_id(telegram_id)
assert result == mock_user
assert result.telegram_id == str(telegram_id)
mock_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_user_by_telegram_id_not_found(self, user_service, mock_session):
telegram_id = 999999999
mock_result = MagicMock()
mock_result.scalar_one_or_none = MagicMock(return_value=None)
mock_session.execute.return_value = mock_result
result = await user_service.get_user_by_telegram_id(telegram_id)
assert result is None
mock_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_or_create_user_new_user(self, user_service, mock_session):
telegram_id = 123456789
username = "new_user"
first_name = "New"
last_name = "User"
mock_result_not_found = MagicMock()
mock_result_not_found.scalar_one_or_none = MagicMock(return_value=None)
mock_result_found = MagicMock()
created_user = UserModel(
telegram_id=str(telegram_id),
username=username,
first_name=first_name,
last_name=last_name
)
mock_result_found.scalar_one_or_none = MagicMock(return_value=created_user)
mock_session.execute.side_effect = [mock_result_not_found, mock_result_found]
result = await user_service.get_or_create_user(telegram_id, username, first_name, last_name)
assert result is not None
assert result.telegram_id == str(telegram_id)
assert result.username == username
mock_session.add.assert_called_once()
mock_session.commit.assert_called()
@pytest.mark.asyncio
async def test_get_or_create_user_existing_user(self, user_service, mock_session):
telegram_id = 123456789
existing_user = UserModel(
telegram_id=str(telegram_id),
username="old_username",
first_name="Old",
last_name="Name"
)
mock_result = MagicMock()
mock_result.scalar_one_or_none = MagicMock(return_value=existing_user)
mock_session.execute.return_value = mock_result
result = await user_service.get_or_create_user(
telegram_id, "new_username", "New", "Name"
)
assert result == existing_user
assert result.username == "new_username"
assert result.first_name == "New"
assert result.last_name == "Name"
mock_session.commit.assert_called()
@pytest.mark.asyncio
async def test_update_user_questions_success(self, user_service, mock_session):
telegram_id = 123456789
user = UserModel(
telegram_id=str(telegram_id),
questions_used=5
)
mock_result = MagicMock()
mock_result.scalar_one_or_none = MagicMock(return_value=user)
mock_session.execute.return_value = mock_result
result = await user_service.update_user_questions(telegram_id)
assert result is True
assert user.questions_used == 6
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_update_user_questions_user_not_found(self, user_service, mock_session):
telegram_id = 999999999
mock_result = MagicMock()
mock_result.scalar_one_or_none = MagicMock(return_value=None)
mock_session.execute.return_value = mock_result
result = await user_service.update_user_questions(telegram_id)
assert result is False
mock_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_activate_premium_success(self, user_service, mock_session):
telegram_id = 123456789
user = UserModel(
telegram_id=str(telegram_id),
is_premium=False,
premium_until=None
)
mock_result = MagicMock()
mock_result.scalar_one_or_none = MagicMock(return_value=user)
mock_session.execute.return_value = mock_result
result = await user_service.activate_premium(telegram_id)
assert result is True
assert user.is_premium is True
assert user.premium_until is not None
assert user.premium_until > datetime.now()
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_activate_premium_extend_existing(self, user_service, mock_session):
telegram_id = 123456789
existing_premium_until = datetime.now() + timedelta(days=10)
user = UserModel(
telegram_id=str(telegram_id),
is_premium=True,
premium_until=existing_premium_until
)
mock_result = MagicMock()
mock_result.scalar_one_or_none = MagicMock(return_value=user)
mock_session.execute.return_value = mock_result
result = await user_service.activate_premium(telegram_id)
assert result is True
assert user.is_premium is True
assert user.premium_until > existing_premium_until
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_activate_premium_user_not_found(self, user_service, mock_session):
telegram_id = 999999999
mock_result = MagicMock()
mock_result.scalar_one_or_none = MagicMock(return_value=None)
mock_session.execute.return_value = mock_result
result = await user_service.activate_premium(telegram_id)
assert result is False
mock_session.commit.assert_not_called()

35
tg_bot/.dockerignore Normal file
View File

@ -0,0 +1,35 @@
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
venv/
env/
ENV/
.venv/
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store
.git/
.gitignore
.gitattributes
Dockerfile*
docker-compose*.yml
.dockerignore
drone.yml
tmp/
temp/
*.tmp
Thumbs.db
.DS_Store

18
tg_bot/Dockerfile Normal file
View File

@ -0,0 +1,18 @@
FROM python:3.11-slim
WORKDIR /app
RUN apt-get update && apt-get install -y \
gcc \
postgresql-client \
&& rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY . ./tg_bot/
ENV PYTHONPATH=/app
ENV PYTHONUNBUFFERED=1
CMD ["python", "tg_bot/main.py"]

View File

@ -1,139 +1,130 @@
import aiohttp """
from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient RAG сервис для бота - вызывает API бэкенда
"""
from tg_bot.config.settings import settings from tg_bot.config.settings import settings
from tg_bot.infrastructure.http_client import create_http_session
BACKEND_URL = "http://localhost:8001/api/v1"
class RAGService: class RAGService:
"""Сервис для работы с RAG через API бэкенда"""
def __init__(self): async def get_or_create_conversation(
self.deepseek_client = DeepSeekClient()
async def search_documents_in_collections(
self, self,
user_telegram_id: str, user_telegram_id: str,
query: str, collection_id: str = None
limit_per_collection: int = 5 ) -> str | None:
) -> list[dict]: """Получить или создать беседу для пользователя"""
try: try:
async with aiohttp.ClientSession() as session: async with create_http_session() as session:
async with session.get( async with session.get(
f"{BACKEND_URL}/users/telegram/{user_telegram_id}" f"{settings.BACKEND_URL}/collections/",
) as user_response:
if user_response.status != 200:
return []
user_data = await user_response.json()
user_uuid = str(user_data.get("user_id"))
if not user_uuid:
return []
async with session.get(
f"{BACKEND_URL}/collections/",
headers={"X-Telegram-ID": user_telegram_id} headers={"X-Telegram-ID": user_telegram_id}
) as collections_response: ) as collections_response:
if collections_response.status != 200: if collections_response.status != 200:
return [] return None
collections = await collections_response.json() collections = await collections_response.json()
all_documents = [] if not collections:
for collection in collections:
collection_id = collection.get("collection_id")
if not collection_id: if not collection_id:
continue async with session.post(
f"{settings.BACKEND_URL}/collections",
try: json={
async with aiohttp.ClientSession() as search_session: "name": "Основная коллекция",
async with search_session.get( "description": "Коллекция по умолчанию",
f"{BACKEND_URL}/documents/collection/{collection_id}", "is_public": False
params={"search": query, "limit": limit_per_collection}, },
headers={"X-Telegram-ID": user_telegram_id} headers={"X-Telegram-ID": user_telegram_id}
) as search_response: ) as create_collection_response:
if search_response.status == 200: if create_collection_response.status in [200, 201]:
documents = await search_response.json() collection_data = await create_collection_response.json()
for doc in documents: collection_id = collection_data.get("collection_id")
doc["collection_name"] = collection.get("name", "Unknown") else:
all_documents.append(doc) collection_id = collection_id
except Exception as e: else:
print(f"Error searching collection {collection_id}: {e}") collection_id = collections[0].get("collection_id")
continue
return all_documents[:20] if not collection_id:
return None
async with session.get(
f"{settings.BACKEND_URL}/conversations",
headers={"X-Telegram-ID": user_telegram_id}
) as conversations_response:
if conversations_response.status == 200:
conversations = await conversations_response.json()
for conv in conversations:
if conv.get("collection_id") == str(collection_id):
return conv.get("conversation_id")
async with session.post(
f"{settings.BACKEND_URL}/conversations",
json={"collection_id": str(collection_id)},
headers={"X-Telegram-ID": user_telegram_id}
) as create_conversation_response:
if create_conversation_response.status in [200, 201]:
conversation_data = await create_conversation_response.json()
return conversation_data.get("conversation_id")
return None
except Exception as e: except Exception as e:
print(f"Error searching documents: {e}") print(f"Error getting/creating conversation: {e}")
return [] return None
async def generate_answer_with_rag( async def generate_answer_with_rag(
self, self,
question: str, question: str,
user_telegram_id: str user_telegram_id: str
) -> dict: ) -> dict:
documents = await self.search_documents_in_collections( """Генерирует ответ используя RAG через API бэкенда"""
user_telegram_id,
question
)
context_parts = []
sources = []
for doc in documents[:5]:
title = doc.get("title", "Без названия")
content = doc.get("content", "")[:1000]
collection_name = doc.get("collection_name", "Unknown")
context_parts.append(f"Документ: {title}\nКоллекция: {collection_name}\nСодержание: {content[:500]}...")
sources.append({
"title": title,
"collection": collection_name,
"document_id": doc.get("document_id")
})
context = "\n\n".join(context_parts) if context_parts else "Релевантные документы не найдены."
system_prompt = """Ты - помощник-юрист, который отвечает на вопросы на основе предоставленных документов.
Используй информацию из документов для формирования точного и полезного ответа.
Если в документах нет информации для ответа, честно скажи об этом."""
user_prompt = f"""Контекст из документов:
{context}
Вопрос пользователя: {question}
Ответь на вопрос, используя информацию из предоставленных документов. Если информации недостаточно, укажи это."""
try: try:
messages = [ conversation_id = await self.get_or_create_conversation(user_telegram_id)
{"role": "system", "content": system_prompt}, if not conversation_id:
{"role": "user", "content": user_prompt}
]
response = await self.deepseek_client.chat_completion(
messages=messages,
temperature=0.7,
max_tokens=2000
)
return { return {
"answer": response.get("content", "Failed to generate answer"), "answer": "Не удалось создать беседу. Попробуйте позже.",
"sources": sources, "sources": [],
"usage": response.get("usage", {}) "usage": {}
} }
except Exception as e: async with create_http_session() as session:
print(f"Error generating answer: {e}") async with session.post(
if documents: f"{settings.BACKEND_URL}/rag/question",
return { json={
"answer": f"Found {len(documents)} documents but failed to generate answer", "conversation_id": str(conversation_id),
"sources": sources[:3], "question": question,
"usage": {} "top_k": 20,
} "rerank_top_n": 5
else: },
return { headers={"X-Telegram-ID": user_telegram_id}
"answer": "No relevant documents found", ) as response:
if response.status == 200:
result = await response.json()
sources = []
for source in result.get("sources", []):
sources.append({
"title": source.get("title", "Без названия"),
"document_id": source.get("document_id"),
"chunk_id": source.get("chunk_id"),
"index": source.get("index", 0)
})
return {
"answer": result.get("answer", "Не удалось сгенерировать ответ."),
"sources": sources,
"usage": result.get("usage", {}),
"conversation_id": str(conversation_id)
}
else:
error_text = await response.text()
print(f"RAG API error: {response.status} - {error_text}")
return {
"answer": "Ошибка при генерации ответа. Попробуйте позже.",
"sources": [],
"usage": {}
}
except Exception as e:
print(f"Error generating answer with RAG: {e}")
return {
"answer": "Произошла ошибка при генерации ответа. Попробуйте позже.",
"sources": [], "sources": [],
"usage": {} "usage": {}
} }

View File

@ -1,9 +1,10 @@
import os
from typing import List, Optional from typing import List, Optional
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings): class Settings(BaseSettings):
"""Настройки приложения получаеи из env файла, тут не ищи, мы спрятали:)"""
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
env_file=".env", env_file=".env",
env_file_encoding="utf-8", env_file_encoding="utf-8",
@ -13,26 +14,35 @@ class Settings(BaseSettings):
APP_NAME: str = "VibeLawyerBot" APP_NAME: str = "VibeLawyerBot"
VERSION: str = "0.1.0" VERSION: str = "0.1.0"
DEBUG: bool = True DEBUG: bool = False
TELEGRAM_BOT_TOKEN: str = ""
TELEGRAM_BOT_TOKEN: str
FREE_QUESTIONS_LIMIT: int = 5 FREE_QUESTIONS_LIMIT: int = 5
PAYMENT_AMOUNT: float = 500.0 PAYMENT_AMOUNT: float = 500.0
DATABASE_URL: str = "sqlite:///data/bot.db"
LOG_LEVEL: str = "INFO" LOG_LEVEL: str = "INFO"
LOG_FILE: str = "logs/bot.log" LOG_FILE: str = "logs/bot.log"
YOOKASSA_SHOP_ID: str = "1230200"
YOOKASSA_SECRET_KEY: str = "test_GVoixmlp0FqohXcyFzFHbRlAUoA3B1I2aMtAkAE_ubw" YOOKASSA_SHOP_ID: str
YOOKASSA_SECRET_KEY: str
YOOKASSA_RETURN_URL: str = "https://t.me/vibelawyer_bot" YOOKASSA_RETURN_URL: str = "https://t.me/vibelawyer_bot"
YOOKASSA_WEBHOOK_SECRET: Optional[str] = None YOOKASSA_WEBHOOK_SECRET: Optional[str] = None
DEEPSEEK_API_KEY: Optional[str] = None DEEPSEEK_API_KEY: Optional[str] = None
DEEPSEEK_API_URL: str = "https://api.deepseek.com/v1/chat/completions" DEEPSEEK_API_URL: str = "https://api.deepseek.com/v1/chat/completions"
BACKEND_URL: str
ADMIN_IDS_STR: str = "" ADMIN_IDS_STR: str = ""
@property @property
def ADMIN_IDS(self) -> List[int]: def ADMIN_IDS(self) -> List[int]:
"""Список ID администраторов из строки через запятую"""
if not self.ADMIN_IDS_STR: if not self.ADMIN_IDS_STR:
return [] return []
try: try:

View File

@ -1,67 +0,0 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from datetime import datetime, timedelta
from typing import Optional
from tg_bot.infrastructure.database.models import UserModel
class UserService:
def __init__(self, session: AsyncSession):
self.session = session
async def get_user_by_telegram_id(self, telegram_id: int) -> Optional[UserModel]:
result = await self.session.execute(
select(UserModel).filter_by(telegram_id=str(telegram_id))
)
return result.scalar_one_or_none()
async def get_or_create_user(
self,
telegram_id: int,
username: str = "",
first_name: str = "",
last_name: str = ""
) -> UserModel:
user = await self.get_user_by_telegram_id(telegram_id)
if not user:
user = UserModel(
telegram_id=str(telegram_id),
username=username,
first_name=first_name,
last_name=last_name
)
self.session.add(user)
await self.session.commit()
else:
user.username = username
user.first_name = first_name
user.last_name = last_name
await self.session.commit()
return user
async def update_user_questions(self, telegram_id: int) -> bool:
user = await self.get_user_by_telegram_id(telegram_id)
if user:
user.questions_used += 1
await self.session.commit()
return True
return False
async def activate_premium(self, telegram_id: int) -> bool:
try:
user = await self.get_user_by_telegram_id(telegram_id)
if user:
user.is_premium = True
if user.premium_until and user.premium_until > datetime.now():
user.premium_until = user.premium_until + timedelta(days=30)
else:
user.premium_until = datetime.now() + timedelta(days=30)
await self.session.commit()
return True
else:
return False
except Exception as e:
print(f"Error activating premium: {e}")
await self.session.rollback()
return False

View File

@ -0,0 +1,126 @@
import aiohttp
from datetime import datetime
from typing import Optional
from tg_bot.config.settings import settings
from tg_bot.infrastructure.http_client import create_http_session
class User:
"""Модель пользователя для телеграм-бота"""
def __init__(self, data: dict):
self.user_id = data.get("user_id")
self.telegram_id = data.get("telegram_id")
self.role = data.get("role")
created_at_str = data.get("created_at")
if created_at_str:
try:
created_at_str = created_at_str.replace("Z", "+00:00")
self.created_at = datetime.fromisoformat(created_at_str)
except (ValueError, AttributeError):
self.created_at = None
else:
self.created_at = None
premium_until_str = data.get("premium_until")
if premium_until_str:
try:
premium_until_str = premium_until_str.replace("Z", "+00:00")
self.premium_until = datetime.fromisoformat(premium_until_str)
except (ValueError, AttributeError):
self.premium_until = None
else:
self.premium_until = None
self.is_premium = data.get("is_premium", False)
self.questions_used = data.get("questions_used", 0)
class UserService:
"""Сервис для работы с пользователями через API бэкенда"""
def __init__(self):
self.backend_url = settings.BACKEND_URL
print(f"UserService initialized with BACKEND_URL: {self.backend_url}")
async def get_user_by_telegram_id(self, telegram_id: int) -> Optional[User]:
"""Получить пользователя по Telegram ID"""
try:
url = f"{self.backend_url}/users/telegram/{telegram_id}"
async with create_http_session() as session:
async with session.get(url) as response:
if response.status == 200:
data = await response.json()
return User(data)
return None
except aiohttp.ClientConnectorError as e:
print(f"Backend not available at {self.backend_url}: {e}")
return None
except Exception as e:
print(f"Error getting user: {e}")
return None
async def get_or_create_user(
self,
telegram_id: int,
username: str = "",
first_name: str = "",
last_name: str = ""
) -> User:
"""Получить или создать пользователя"""
user = await self.get_user_by_telegram_id(telegram_id)
if not user:
try:
async with create_http_session() as session:
async with session.post(
f"{self.backend_url}/users",
json={"telegram_id": str(telegram_id), "role": "user"}
) as response:
if response.status in [200, 201]:
data = await response.json()
return User(data)
else:
error_text = await response.text()
raise Exception(
f"Backend API returned status {response.status}: {error_text}. "
f"Make sure the backend server is running at {self.backend_url}"
)
except aiohttp.ClientConnectorError as e:
error_msg = (
f"Cannot connect to backend API at {self.backend_url}. "
f"Please ensure the backend server is running on port 8000. "
f"Start it with: cd project/backend && python run.py"
)
print(f"Error creating user: {error_msg}")
print(f"Original error: {e}")
raise ConnectionError(error_msg) from e
except Exception as e:
error_msg = f"Error creating user: {e}. Backend URL: {self.backend_url}"
print(error_msg)
raise
return user
async def update_user_questions(self, telegram_id: int) -> bool:
"""Увеличить счетчик использованных вопросов"""
try:
async with create_http_session() as session:
async with session.post(
f"{self.backend_url}/users/telegram/{telegram_id}/increment-questions"
) as response:
return response.status == 200
except Exception as e:
print(f"Error updating questions: {e}")
return False
async def activate_premium(self, telegram_id: int, days: int = 30) -> bool:
"""Активировать premium статус"""
try:
async with create_http_session() as session:
async with session.post(
f"{self.backend_url}/users/telegram/{telegram_id}/activate-premium",
params={"days": days}
) as response:
return response.status == 200
except Exception as e:
print(f"Error activating premium: {e}")
return False

View File

@ -0,0 +1,2 @@
"""Infrastructure layer for the Telegram bot"""

View File

@ -1,19 +0,0 @@
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from tg_bot.config.settings import settings
database_url = settings.DATABASE_URL
if database_url.startswith("sqlite:///"):
database_url = database_url.replace("sqlite:///", "sqlite+aiosqlite:///")
engine = create_async_engine(
database_url,
echo=settings.DEBUG
)
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
async def create_tables():
from .models import Base
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
print(f"Таблицы созданы: {settings.DATABASE_URL}")

View File

@ -1,39 +0,0 @@
import uuid
from datetime import datetime
from sqlalchemy import Column, String, DateTime, Boolean, Integer, Text
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
class UserModel(Base):
__tablename__ = "users"
user_id = Column("user_id", String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
telegram_id = Column("telegram_id", String(100), nullable=False, unique=True)
created_at = Column("created_at", DateTime, default=datetime.utcnow, nullable=False)
role = Column("role", String(20), default="user", nullable=False)
is_premium = Column(Boolean, default=False, nullable=False)
premium_until = Column(DateTime, nullable=True)
questions_used = Column(Integer, default=0, nullable=False)
username = Column(String(100), nullable=True)
first_name = Column(String(100), nullable=True)
last_name = Column(String(100), nullable=True)
class PaymentModel(Base):
__tablename__ = "payments"
id = Column(Integer, primary_key=True, autoincrement=True)
payment_id = Column(String(36), default=lambda: str(uuid.uuid4()), nullable=False, unique=True)
user_id = Column(Integer, nullable=False)
amount = Column(String(20), nullable=False)
currency = Column(String(3), default="RUB", nullable=False)
status = Column(String(20), default="pending", nullable=False)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
yookassa_payment_id = Column(String(100), unique=True, nullable=True)
description = Column(Text, nullable=True)
def __repr__(self):
return f"<Payment(user_id={self.user_id}, amount={self.amount}, status={self.status})>"

View File

@ -1,172 +0,0 @@
import json
from typing import Optional, AsyncIterator
import httpx
from tg_bot.config.settings import settings
class DeepSeekAPIError(Exception):
pass
class DeepSeekClient:
def __init__(self, api_key: str | None = None, api_url: str | None = None):
self.api_key = api_key or settings.DEEPSEEK_API_KEY
self.api_url = api_url or settings.DEEPSEEK_API_URL
self.timeout = 60.0
def _get_headers(self) -> dict[str, str]:
if not self.api_key:
raise DeepSeekAPIError("API key not set")
return {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
async def chat_completion(
self,
messages: list[dict[str, str]],
model: str = "deepseek-chat",
temperature: float = 0.7,
max_tokens: Optional[int] = None,
stream: bool = False
) -> dict:
if not self.api_key:
return {
"content": "API key not configured",
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
payload = {
"model": model,
"messages": messages,
"temperature": temperature,
"stream": stream
}
if max_tokens is not None:
payload["max_tokens"] = max_tokens
try:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
self.api_url,
headers=self._get_headers(),
json=payload
)
response.raise_for_status()
data = response.json()
if "choices" in data and len(data["choices"]) > 0:
content = data["choices"][0]["message"]["content"]
else:
raise DeepSeekAPIError("Invalid response format")
usage = data.get("usage", {})
return {
"content": content,
"usage": {
"prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0),
"total_tokens": usage.get("total_tokens", 0)
}
}
except httpx.HTTPStatusError as e:
error_msg = f"API error: {e.response.status_code}"
try:
error_data = e.response.json()
if "error" in error_data:
error_msg = error_data['error'].get('message', error_msg)
except:
pass
raise DeepSeekAPIError(error_msg) from e
except httpx.RequestError as e:
raise DeepSeekAPIError(f"Connection error: {str(e)}") from e
except Exception as e:
raise DeepSeekAPIError(str(e)) from e
async def stream_chat_completion(
self,
messages: list[dict[str, str]],
model: str = "deepseek-chat",
temperature: float = 0.7,
max_tokens: Optional[int] = None
) -> AsyncIterator[str]:
if not self.api_key:
yield "API key not configured"
return
payload = {
"model": model,
"messages": messages,
"temperature": temperature,
"stream": True
}
if max_tokens is not None:
payload["max_tokens"] = max_tokens
try:
async with httpx.AsyncClient(timeout=self.timeout) as client:
async with client.stream(
"POST",
self.api_url,
headers=self._get_headers(),
json=payload
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if not line.strip():
continue
if line.startswith("data: "):
line = line[6:]
if line.strip() == "[DONE]":
break
try:
data = json.loads(line)
if "choices" in data and len(data["choices"]) > 0:
delta = data["choices"][0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
except json.JSONDecodeError:
continue
except httpx.HTTPStatusError as e:
error_msg = f"API error: {e.response.status_code}"
try:
error_data = e.response.json()
if "error" in error_data:
error_msg = error_data['error'].get('message', error_msg)
except:
pass
raise DeepSeekAPIError(error_msg) from e
except httpx.RequestError as e:
raise DeepSeekAPIError(f"Connection error: {str(e)}") from e
except Exception as e:
raise DeepSeekAPIError(str(e)) from e
async def health_check(self) -> bool:
if not self.api_key:
return False
try:
test_messages = [{"role": "user", "content": "test"}]
await self.chat_completion(test_messages, max_tokens=1)
return True
except Exception:
return False

View File

@ -0,0 +1,24 @@
import aiohttp
from typing import Optional
def create_http_session(timeout: Optional[aiohttp.ClientTimeout] = None) -> aiohttp.ClientSession:
"""
Создаем сессию для запросов к бэку
"""
if timeout is None:
timeout = aiohttp.ClientTimeout(total=30, connect=10)
connector = aiohttp.TCPConnector(
limit=100,
limit_per_host=30
)
return aiohttp.ClientSession(
connector=connector,
timeout=timeout,
headers={
"Accept": "application/json"
}
)

View File

@ -10,7 +10,8 @@ from tg_bot.infrastructure.telegram.handlers import (
stats_handler, stats_handler,
question_handler, question_handler,
buy_handler, buy_handler,
collection_handler collection_handler,
document_handler
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -25,17 +26,28 @@ async def create_bot() -> tuple[Bot, Dispatcher]:
dp.include_router(start_handler.router) dp.include_router(start_handler.router)
dp.include_router(help_handler.router) dp.include_router(help_handler.router)
dp.include_router(stats_handler.router) dp.include_router(stats_handler.router)
dp.include_router(question_handler.router)
dp.include_router(buy_handler.router) dp.include_router(buy_handler.router)
dp.include_router(collection_handler.router) dp.include_router(collection_handler.router)
dp.include_router(document_handler.router)
dp.include_router(question_handler.router)
return bot, dp return bot, dp
async def start_bot(): async def start_bot():
bot = None bot = None
try: try:
if not settings.TELEGRAM_BOT_TOKEN or not settings.TELEGRAM_BOT_TOKEN.strip():
raise ValueError("TELEGRAM_BOT_TOKEN не установлен в переменных окружения или файле .env")
bot, dp = await create_bot() bot, dp = await create_bot()
try:
bot_info = await bot.get_me()
username = bot_info.username if bot_info.username else f"ID: {bot_info.id}"
logger.info(f"Бот успешно подключен: @{username}")
except Exception as e:
raise ValueError(f"Неверный токен Telegram бота: {e}")
try: try:
webhook_info = await bot.get_webhook_info() webhook_info = await bot.get_webhook_info()
if webhook_info.url: if webhook_info.url:

View File

@ -2,16 +2,14 @@ from aiogram import Router, types
from aiogram.filters import Command from aiogram.filters import Command
from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton
from decimal import Decimal from decimal import Decimal
import aiohttp
from tg_bot.config.settings import settings from tg_bot.config.settings import settings
from tg_bot.payment.yookassa.client import yookassa_client from tg_bot.payment.yookassa.client import yookassa_client
from tg_bot.infrastructure.database.database import AsyncSessionLocal from tg_bot.domain.user_service import UserService
from tg_bot.infrastructure.database.models import PaymentModel from datetime import datetime
from tg_bot.domain.services.user_service import UserService
from sqlalchemy import select
import uuid
from datetime import datetime, timedelta
router = Router() router = Router()
user_service = UserService()
@router.message(Command("buy")) @router.message(Command("buy"))
@ -19,9 +17,7 @@ async def cmd_buy(message: Message):
user_id = message.from_user.id user_id = message.from_user.id
username = message.from_user.username or f"user_{user_id}" username = message.from_user.username or f"user_{user_id}"
async with AsyncSessionLocal() as session:
try: try:
user_service = UserService(session)
user = await user_service.get_user_by_telegram_id(user_id) user = await user_service.get_user_by_telegram_id(user_id)
if user and user.is_premium and user.premium_until and user.premium_until > datetime.now(): if user and user.is_premium and user.premium_until and user.premium_until > datetime.now():
@ -34,8 +30,10 @@ async def cmd_buy(message: Message):
f"Новая подписка будет добавлена к текущей.", f"Новая подписка будет добавлена к текущей.",
parse_mode="HTML" parse_mode="HTML"
) )
except Exception: except aiohttp.ClientError as e:
pass print(f"Не удалось подключиться к backend при проверке подписки: {e}")
except Exception as e:
print(f"Ошибка при проверке подписки: {e}")
await message.answer( await message.answer(
"*Создаю ссылку для оплаты...*\n\n" "*Создаю ссылку для оплаты...*\n\n"
@ -50,23 +48,7 @@ async def cmd_buy(message: Message):
user_id=user_id user_id=user_id
) )
async with AsyncSessionLocal() as session: print(f"Платёж создан в ЮKассе: {payment_data['id']}")
try:
payment = PaymentModel(
payment_id=str(uuid.uuid4()),
user_id=user_id,
amount=str(settings.PAYMENT_AMOUNT),
currency="RUB",
status="pending",
yookassa_payment_id=payment_data["id"],
description="Оплата подписки VibeLawyerBot"
)
session.add(payment)
await session.commit()
print(f"Платёж сохранён в БД: {payment.payment_id}")
except Exception as e:
print(f"Ошибка сохранения платежа в БД: {e}")
await session.rollback()
keyboard = InlineKeyboardMarkup( keyboard = InlineKeyboardMarkup(
inline_keyboard=[ inline_keyboard=[
@ -139,27 +121,15 @@ async def check_payment_status(callback_query: types.CallbackQuery):
payment = YooPayment.find_one(yookassa_id) payment = YooPayment.find_one(yookassa_id)
if payment.status == "succeeded": if payment.status == "succeeded":
async with AsyncSessionLocal() as session:
try: try:
result = await session.execute(
select(PaymentModel).filter_by(yookassa_payment_id=yookassa_id)
)
db_payment = result.scalar_one_or_none()
if db_payment:
db_payment.status = "succeeded"
user_service = UserService(session)
success = await user_service.activate_premium(user_id) success = await user_service.activate_premium(user_id)
if success: if success:
user = await user_service.get_user_by_telegram_id(user_id) user = await user_service.get_user_by_telegram_id(user_id)
await session.commit() if user:
if not user:
user = await user_service.get_user_by_telegram_id(user_id)
await callback_query.message.answer( await callback_query.message.answer(
"<b>Оплата подтверждена!</b>\n\n" "<b>Оплата подтверждена!</b>\n\n"
f"Ваш premium-доступ активирован до: " f"Ваш premium-доступ активирован до: "
f"<b>{user.premium_until.strftime('%d.%m.%Y')}</b>\n\n" f"<b>{user.premium_until.strftime('%d.%m.%Y') if user.premium_until else 'Не указано'}</b>\n\n"
"Теперь вы можете:\n" "Теперь вы можете:\n"
"• Задавать неограниченное количество вопросов\n" "• Задавать неограниченное количество вопросов\n"
"• Получать приоритетные ответы\n" "• Получать приоритетные ответы\n"
@ -169,12 +139,23 @@ async def check_payment_status(callback_query: types.CallbackQuery):
) )
else: else:
await callback_query.message.answer( await callback_query.message.answer(
"<b>Платёж найден в ЮKассе, но не в нашей БД</b>\n\n" "<b>Оплата подтверждена, но не удалось активировать premium</b>\n\n"
"Пожалуйста, обратитесь к администратору.",
parse_mode="HTML"
)
else:
await callback_query.message.answer(
"<b>Оплата подтверждена, но не удалось активировать premium</b>\n\n"
"Пожалуйста, обратитесь к администратору.", "Пожалуйста, обратитесь к администратору.",
parse_mode="HTML" parse_mode="HTML"
) )
except Exception as e: except Exception as e:
print(f"Ошибка обработки платежа: {e}") print(f"Ошибка обработки платежа: {e}")
await callback_query.message.answer(
"<b>Ошибка активации premium</b>\n\n"
"Пожалуйста, обратитесь к администратору.",
parse_mode="HTML"
)
elif payment.status == "pending": elif payment.status == "pending":
await callback_query.message.answer( await callback_query.message.answer(
@ -206,42 +187,13 @@ async def check_payment_status(callback_query: types.CallbackQuery):
@router.message(Command("mypayments")) @router.message(Command("mypayments"))
async def cmd_my_payments(message: Message): async def cmd_my_payments(message: Message):
user_id = message.from_user.id
async with AsyncSessionLocal() as session:
try:
result = await session.execute(
select(PaymentModel).filter_by(user_id=user_id).order_by(PaymentModel.created_at.desc()).limit(10)
)
payments = result.scalars().all()
if not payments:
await message.answer( await message.answer(
"<b>У вас пока нет платежей</b>\n\n" "<b>История платежей</b>\n\n"
"Используйте команду /buy чтобы оформить подписку.", "История платежей хранится в системе оплаты ЮKassa.\n"
"Для проверки статуса подписки используйте команду /stats.\n\n"
"Для оформления новой подписки используйте команду /buy",
parse_mode="HTML" parse_mode="HTML"
) )
return
response = ["<b>Ваши последние платежи:</b>\n"]
for i, payment in enumerate(payments, 1):
status_text = "Успешно" if payment.status == "succeeded" else "Ожидание" if payment.status == "pending" else "Ошибка"
response.append(
f"\n<b>{i}. {payment.amount} руб. ({status_text})</b>\n"
f"Статус: {payment.status}\n"
f"Дата: {payment.created_at.strftime('%d.%m.%Y %H:%M')}\n"
f"ID: <code>{payment.payment_id[:8]}...</code>"
)
response.append("\n\n<i>Полный доступ открывается после успешной оплаты</i>")
await message.answer(
"\n".join(response),
parse_mode="HTML"
)
except Exception as e:
print(f"Ошибка получения платежей: {e}")
@router.message(Command("testcards")) @router.message(Command("testcards"))

View File

@ -1,18 +1,36 @@
from aiogram import Router from aiogram import Router, F
from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton, CallbackQuery from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton, CallbackQuery
from aiogram.filters import Command from aiogram.filters import Command, StateFilter
from aiogram.fsm.context import FSMContext
import aiohttp import aiohttp
from urllib.parse import unquote
from tg_bot.config.settings import settings
from tg_bot.infrastructure.http_client import create_http_session
from tg_bot.infrastructure.telegram.states.collection_states import (
CollectionAccessStates,
CollectionEditStates
)
def decode_title(title: str) -> str:
if not title:
return "Без названия"
try:
decoded = unquote(title)
if decoded != title or '%' not in title:
return decoded
return title
except Exception:
return title
router = Router() router = Router()
BACKEND_URL = "http://localhost:8001/api/v1"
async def get_user_collections(telegram_id: str): async def get_user_collections(telegram_id: str):
try: try:
async with aiohttp.ClientSession() as session: async with create_http_session() as session:
async with session.get( async with session.get(
f"{BACKEND_URL}/collections/", f"{settings.BACKEND_URL}/collections/",
headers={"X-Telegram-ID": telegram_id} headers={"X-Telegram-ID": telegram_id}
) as response: ) as response:
if response.status == 200: if response.status == 200:
@ -25,24 +43,37 @@ async def get_user_collections(telegram_id: str):
async def get_collection_documents(collection_id: str, telegram_id: str): async def get_collection_documents(collection_id: str, telegram_id: str):
try: try:
async with aiohttp.ClientSession() as session: collection_id = str(collection_id).strip()
url = f"{settings.BACKEND_URL}/documents/collection/{collection_id}"
print(f"DEBUG get_collection_documents: URL={url}, collection_id={collection_id}, telegram_id={telegram_id}")
async with create_http_session() as session:
async with session.get( async with session.get(
f"{BACKEND_URL}/documents/collection/{collection_id}", url,
headers={"X-Telegram-ID": telegram_id} headers={"X-Telegram-ID": telegram_id}
) as response: ) as response:
if response.status == 200: if response.status == 200:
return await response.json() return await response.json()
elif response.status == 422:
error_text = await response.text()
print(f"Validation error getting documents: {response.status} - {error_text}, collection_id: {collection_id}, URL: {url}")
return []
else:
error_text = await response.text()
print(f"Error getting documents: {response.status} - {error_text}, collection_id: {collection_id}, URL: {url}")
return [] return []
except Exception as e: except Exception as e:
print(f"Error getting documents: {e}") print(f"Exception getting documents: {e}, collection_id: {collection_id}, type: {type(collection_id)}")
import traceback
traceback.print_exc()
return [] return []
async def search_in_collection(collection_id: str, query: str, telegram_id: str): async def search_in_collection(collection_id: str, query: str, telegram_id: str):
try: try:
async with aiohttp.ClientSession() as session: async with create_http_session() as session:
async with session.get( async with session.get(
f"{BACKEND_URL}/documents/collection/{collection_id}", f"{settings.BACKEND_URL}/documents/collection/{collection_id}",
params={"search": query}, params={"search": query},
headers={"X-Telegram-ID": telegram_id} headers={"X-Telegram-ID": telegram_id}
) as response: ) as response:
@ -54,6 +85,91 @@ async def search_in_collection(collection_id: str, query: str, telegram_id: str)
return [] return []
async def get_collection_info(collection_id: str, telegram_id: str):
"""Получить информацию о коллекции"""
try:
collection_id = str(collection_id).strip()
url = f"{settings.BACKEND_URL}/collections/{collection_id}"
print(f"DEBUG get_collection_info: URL={url}, collection_id={collection_id}, telegram_id={telegram_id}")
async with create_http_session() as session:
async with session.get(
url,
headers={"X-Telegram-ID": telegram_id}
) as response:
if response.status == 200:
return await response.json()
elif response.status == 422:
error_text = await response.text()
print(f"Validation error getting collection info: {response.status} - {error_text}, collection_id: {collection_id}, URL: {url}")
return None
else:
error_text = await response.text()
print(f"Error getting collection info: {response.status} - {error_text}, collection_id: {collection_id}, URL: {url}")
return None
except Exception as e:
print(f"Exception getting collection info: {e}, collection_id: {collection_id}, type: {type(collection_id)}")
import traceback
traceback.print_exc()
return None
async def get_collection_access_list(collection_id: str, telegram_id: str):
"""Получить список пользователей с доступом к коллекции"""
try:
async with create_http_session() as session:
async with session.get(
f"{settings.BACKEND_URL}/collections/{collection_id}/access",
headers={"X-Telegram-ID": telegram_id}
) as response:
if response.status == 200:
return await response.json()
return []
except Exception as e:
print(f"Error getting access list: {e}")
return []
async def grant_collection_access(collection_id: str, telegram_id: str, owner_telegram_id: str):
"""Предоставить доступ к коллекции"""
try:
url = f"{settings.BACKEND_URL}/collections/{collection_id}/access/telegram/{telegram_id}"
print(f"DEBUG grant_collection_access: URL={url}, target_telegram_id={telegram_id}, owner_telegram_id={owner_telegram_id}")
async with create_http_session() as session:
async with session.post(
url,
headers={"X-Telegram-ID": owner_telegram_id}
) as response:
if response.status == 201:
result = await response.json()
print(f"DEBUG: Access granted successfully: {result}")
return result
else:
error_text = await response.text()
print(f"ERROR granting access: status={response.status}, error={error_text}, target_telegram_id={telegram_id}")
return None
except Exception as e:
print(f"Exception granting access: {e}, target_telegram_id={telegram_id}")
import traceback
traceback.print_exc()
return None
async def revoke_collection_access(collection_id: str, telegram_id: str, owner_telegram_id: str):
"""Отозвать доступ к коллекции"""
try:
async with create_http_session() as session:
async with session.delete(
f"{settings.BACKEND_URL}/collections/{collection_id}/access/telegram/{telegram_id}",
headers={"X-Telegram-ID": owner_telegram_id}
) as response:
return response.status == 204
except Exception as e:
print(f"Error revoking access: {e}")
return False
@router.message(Command("mycollections")) @router.message(Command("mycollections"))
async def cmd_mycollections(message: Message): async def cmd_mycollections(message: Message):
telegram_id = str(message.from_user.id) telegram_id = str(message.from_user.id)
@ -140,7 +256,7 @@ async def cmd_search(message: Message):
response = f"<b>Результаты поиска:</b> \"{query}\"\n\n" response = f"<b>Результаты поиска:</b> \"{query}\"\n\n"
for i, doc in enumerate(results[:5], 1): for i, doc in enumerate(results[:5], 1):
title = doc.get("title", "Без названия") title = decode_title(doc.get("title", "Без названия"))
content = doc.get("content", "")[:200] content = doc.get("content", "")[:200]
response += f"{i}. <b>{title}</b>\n" response += f"{i}. <b>{title}</b>\n"
response += f" <i>{content}...</i>\n\n" response += f" <i>{content}...</i>\n\n"
@ -148,36 +264,495 @@ async def cmd_search(message: Message):
await message.answer(response, parse_mode="HTML") await message.answer(response, parse_mode="HTML")
@router.callback_query(lambda c: c.data.startswith("collection:")) @router.callback_query(lambda c: c.data.startswith("collection:") and not c.data.startswith("collection:documents:") and not c.data.startswith("collection:edit:") and not c.data.startswith("collection:access:") and not c.data.startswith("collection:view_access:"))
async def show_collection_documents(callback: CallbackQuery): async def show_collection_menu(callback: CallbackQuery):
collection_id = callback.data.split(":")[1] """Показать меню коллекции с опциями в зависимости от прав"""
parts = callback.data.split(":", 1)
if len(parts) < 2:
await callback.message.answer(
"<b>Ошибка</b>\n\nНеверный формат данных.",
parse_mode="HTML"
)
await callback.answer()
return
collection_id = parts[1]
telegram_id = str(callback.from_user.id) telegram_id = str(callback.from_user.id)
print(f"DEBUG: collection_id from callback (menu): {collection_id}, callback_data: {callback.data}")
await callback.answer("Загружаю информацию...")
collection_info = await get_collection_info(collection_id, telegram_id)
if not collection_info:
await callback.message.answer(
"<b>Ошибка</b>\n\nНе удалось загрузить информацию о коллекции.",
parse_mode="HTML"
)
return
owner_id = collection_info.get("owner_id")
collection_name = collection_info.get("name", "Коллекция")
try:
async with create_http_session() as session:
async with session.get(
f"{settings.BACKEND_URL}/users/telegram/{telegram_id}"
) as response:
if response.status == 200:
user_info = await response.json()
current_user_id = user_info.get("user_id")
is_owner = str(owner_id) == str(current_user_id)
else:
is_owner = False
except:
is_owner = False
keyboard_buttons = []
collection_id_str = str(collection_id)
if is_owner:
keyboard_buttons = [
[InlineKeyboardButton(text="Просмотр документов", callback_data=f"collection:documents:{collection_id_str}")],
[InlineKeyboardButton(text="Редактировать коллекцию", callback_data=f"collection:edit:{collection_id_str}")],
[InlineKeyboardButton(text="Управление доступом", callback_data=f"collection:access:{collection_id_str}")],
[InlineKeyboardButton(text="Загрузить документ", callback_data=f"document:upload:{collection_id_str}")],
[InlineKeyboardButton(text="Назад к коллекциям", callback_data="collections:list")]
]
else:
keyboard_buttons = [
[InlineKeyboardButton(text="Просмотр документов", callback_data=f"collection:documents:{collection_id_str}")],
[InlineKeyboardButton(text="Просмотр доступа", callback_data=f"collection:view_access:{collection_id_str}")],
[InlineKeyboardButton(text="Назад к коллекциям", callback_data="collections:list")]
]
keyboard = InlineKeyboardMarkup(inline_keyboard=keyboard_buttons)
role_text = "<b>Владелец</b>" if is_owner else "<b>Доступ</b>"
response = f"<b>{collection_name}</b>\n\n"
response += f"{role_text}\n\n"
response += f"ID: <code>{collection_id}</code>\n\n"
response += "Выберите действие:"
await callback.message.answer(response, parse_mode="HTML", reply_markup=keyboard)
@router.callback_query(lambda c: c.data.startswith("collection:documents:"))
async def show_collection_documents(callback: CallbackQuery):
"""Показать документы коллекции"""
try:
parts = callback.data.split(":", 2)
if len(parts) < 3:
raise ValueError("Неверный формат callback_data")
collection_id = parts[2]
telegram_id = str(callback.from_user.id)
print(f"DEBUG: collection_id from callback: {collection_id}, callback_data: {callback.data}")
await callback.answer("Загружаю документы...") await callback.answer("Загружаю документы...")
collection_info = await get_collection_info(collection_id, telegram_id)
if not collection_info:
await callback.message.answer(
"<b>Ошибка</b>\n\nНе удалось загрузить информацию о коллекции. Проверьте, что у вас есть доступ к этой коллекции.",
parse_mode="HTML"
)
return
documents = await get_collection_documents(collection_id, telegram_id) documents = await get_collection_documents(collection_id, telegram_id)
if not documents: if not documents:
await callback.message.answer( await callback.message.answer(
f"<b>Коллекция пуста</b>\n\n" f"<b>Коллекция пуста</b>\n\n"
f"В этой коллекции пока нет документов.\n" f"В этой коллекции пока нет документов.",
f"Обратитесь к администратору для добавления документов.",
parse_mode="HTML" parse_mode="HTML"
) )
return return
except IndexError:
await callback.message.answer(
"<b>Ошибка</b>\n\nНеверный формат данных.",
parse_mode="HTML"
)
await callback.answer()
return
except Exception as e:
print(f"Error in show_collection_documents: {e}")
await callback.message.answer(
f"<b>Ошибка</b>\n\nПроизошла ошибка при загрузке документов: {str(e)}",
parse_mode="HTML"
)
await callback.answer()
return
response = f"<b>Документы в коллекции:</b>\n\n" response = f"<b>Документы в коллекции:</b>\n\n"
keyboard_buttons = []
for i, doc in enumerate(documents[:10], 1): for i, doc in enumerate(documents[:10], 1):
title = doc.get("title", "Без названия") doc_id = doc.get("document_id")
title = decode_title(doc.get("title", "Без названия"))
content_preview = doc.get("content", "")[:100] content_preview = doc.get("content", "")[:100]
response += f"{i}. <b>{title}</b>\n" response += f"{i}. <b>{title}</b>\n"
if content_preview: if content_preview:
response += f" <i>{content_preview}...</i>\n" response += f" <i>{content_preview}...</i>\n"
response += "\n" response += "\n"
keyboard_buttons.append([
InlineKeyboardButton(
text=f"{title[:30]}",
callback_data=f"document:view:{doc_id}"
)
])
if len(documents) > 10: if len(documents) > 10:
response += f"\n<i>Показано 10 из {len(documents)} документов</i>" response += f"\n<i>Показано 10 из {len(documents)} документов</i>"
await callback.message.answer(response, parse_mode="HTML")
collection_id_for_back = str(collection_info.get("collection_id", collection_id))
keyboard_buttons.append([
InlineKeyboardButton(text="Назад", callback_data=f"collection:{collection_id_for_back}")
])
keyboard = InlineKeyboardMarkup(inline_keyboard=keyboard_buttons)
await callback.message.answer(response, parse_mode="HTML", reply_markup=keyboard)
@router.callback_query(lambda c: c.data.startswith("collection:access:"))
async def show_access_management(callback: CallbackQuery):
"""Показать меню управления доступом (только для владельца)"""
collection_id = callback.data.split(":")[2]
telegram_id = str(callback.from_user.id)
await callback.answer("Загружаю список доступа...")
access_list = await get_collection_access_list(collection_id, telegram_id)
response = "<b>Управление доступом</b>\n\n"
response += "<b>Пользователи с доступом:</b>\n\n"
keyboard_buttons = []
if access_list:
for i, access in enumerate(access_list[:10], 1):
user = access.get("user", {})
user_telegram_id = user.get("telegram_id", "N/A")
role = user.get("role", "user")
response += f"{i}. <code>{user_telegram_id}</code> ({role})\n"
keyboard_buttons.append([
InlineKeyboardButton(
text=f" Удалить {user_telegram_id}",
callback_data=f"access:remove:{collection_id}:{user_telegram_id}"
)
])
else:
response += "<i>Нет пользователей с доступом</i>\n\n"
keyboard_buttons.extend([
[InlineKeyboardButton(text="Добавить доступ", callback_data=f"access:add:{collection_id}")],
[InlineKeyboardButton(text="Назад", callback_data=f"collection:{collection_id}")]
])
keyboard = InlineKeyboardMarkup(inline_keyboard=keyboard_buttons)
await callback.message.answer(response, parse_mode="HTML", reply_markup=keyboard)
@router.callback_query(lambda c: c.data.startswith("collection:view_access:"))
async def show_access_list(callback: CallbackQuery):
"""Показать список пользователей с доступом (read-only для пользователей с доступом)"""
collection_id = callback.data.split(":")[2]
telegram_id = str(callback.from_user.id)
await callback.answer("Загружаю список доступа...")
access_list = await get_collection_access_list(collection_id, telegram_id)
response = "<b>Пользователи с доступом</b>\n\n"
if access_list:
for i, access in enumerate(access_list[:20], 1):
user = access.get("user", {})
user_telegram_id = user.get("telegram_id", "N/A")
role = user.get("role", "user")
response += f"{i}. <code>{user_telegram_id}</code> ({role})\n"
else:
response += "<i>Нет пользователей с доступом</i>\n"
keyboard = InlineKeyboardMarkup(inline_keyboard=[[
InlineKeyboardButton(text="Назад", callback_data=f"collection:{collection_id}")
]])
await callback.message.answer(response, parse_mode="HTML", reply_markup=keyboard)
@router.callback_query(lambda c: c.data.startswith("access:add:"))
async def add_access_prompt(callback: CallbackQuery, state: FSMContext):
"""Запросить пересылку сообщения для добавления доступа"""
collection_id = callback.data.split(":")[2]
telegram_id = str(callback.from_user.id)
await state.update_data(collection_id=collection_id)
await state.set_state(CollectionAccessStates.waiting_for_username)
await callback.message.answer(
"<b>Добавить доступ</b>\n\n"
"Перешлите любое сообщение от пользователя, которому нужно предоставить доступ.\n\n"
"<i>Просто перешлите сообщение от нужного пользователя.</i>",
parse_mode="HTML"
)
await callback.answer()
@router.message(StateFilter(CollectionAccessStates.waiting_for_username))
async def process_add_access(message: Message, state: FSMContext):
"""Обработать добавление доступа через пересылку сообщения"""
telegram_id = str(message.from_user.id)
data = await state.get_data()
collection_id = data.get("collection_id")
if not collection_id:
await message.answer("Ошибка: не указана коллекция")
await state.clear()
return
target_telegram_id = None
if message.forward_from:
target_telegram_id = str(message.forward_from.id)
elif message.forward_from_chat:
await message.answer(
"<b>Ошибка</b>\n\n"
"Пожалуйста, перешлите сообщение от пользователя, а не из группы или канала.",
parse_mode="HTML"
)
await state.clear()
return
elif message.forward_date:
await message.answer(
"<b>Информация о пересылке скрыта</b>\n\n"
"Пользователь скрыл информацию о пересылке в настройках приватности Telegram.\n\n"
"Попросите пользователя временно разрешить пересылку сообщений.",
parse_mode="HTML"
)
await state.clear()
return
else:
await message.answer(
"<b>Ошибка</b>\n\n"
"Пожалуйста, перешлите сообщение от пользователя, которому нужно предоставить доступ.\n\n"
"<i>Просто перешлите любое сообщение от нужного пользователя.</i>",
parse_mode="HTML"
)
await state.clear()
return
if not target_telegram_id:
await message.answer(
"<b>Ошибка</b>\n\n"
"Не удалось определить Telegram ID пользователя.",
parse_mode="HTML"
)
await state.clear()
return
print(f"DEBUG: Attempting to grant access: collection_id={collection_id}, target_telegram_id={target_telegram_id}, owner_telegram_id={telegram_id}")
result = await grant_collection_access(collection_id, target_telegram_id, telegram_id)
if result:
user_info = ""
if message.forward_from:
user_name = message.forward_from.first_name or ""
user_username = f"@{message.forward_from.username}" if message.forward_from.username else ""
user_info = f"{user_name} {user_username}".strip() or target_telegram_id
else:
user_info = target_telegram_id
await message.answer(
f"<b>Доступ предоставлен</b>\n\n"
f"Пользователю <code>{target_telegram_id}</code> предоставлен доступ к коллекции.\n\n"
f"Пользователь: {user_info}\n\n"
f"<i>Примечание: Если пользователь еще не взаимодействовал с ботом, он был автоматически создан в системе.</i>",
parse_mode="HTML"
)
else:
await message.answer(
"<b>Ошибка</b>\n\n"
"Не удалось предоставить доступ. Возможно:\n"
"• Доступ уже предоставлен\n"
"• Произошла ошибка на сервере\n"
"• Вы не являетесь владельцем коллекции\n\n"
"Проверьте логи сервера для получения подробной информации.",
parse_mode="HTML"
)
await state.clear()
@router.callback_query(lambda c: c.data.startswith("access:remove:"))
async def remove_access(callback: CallbackQuery):
"""Удалить доступ пользователя"""
parts = callback.data.split(":")
collection_id = parts[2]
target_telegram_id = parts[3]
owner_telegram_id = str(callback.from_user.id)
await callback.answer("Удаляю доступ...")
result = await revoke_collection_access(collection_id, target_telegram_id, owner_telegram_id)
if result:
await callback.message.answer(
f"<b>Доступ отозван</b>\n\n"
f"Доступ пользователя <code>{target_telegram_id}</code> отозван.",
parse_mode="HTML"
)
else:
await callback.message.answer(
"<b>Ошибка</b>\n\n"
"Не удалось отозвать доступ.",
parse_mode="HTML"
)
@router.callback_query(lambda c: c.data.startswith("collection:edit:"))
async def edit_collection_prompt(callback: CallbackQuery, state: FSMContext):
"""Запросить данные для редактирования коллекции"""
collection_id = callback.data.split(":")[2]
telegram_id = str(callback.from_user.id)
collection_info = await get_collection_info(collection_id, telegram_id)
if not collection_info:
await callback.message.answer(
"<b>Ошибка</b>\n\nНе удалось загрузить информацию о коллекции.",
parse_mode="HTML"
)
await callback.answer()
return
await state.update_data(collection_id=collection_id)
await state.set_state(CollectionEditStates.waiting_for_name)
await callback.message.answer(
"<b>Редактирование коллекции</b>\n\n"
"Отправьте новое название коллекции или /skip чтобы оставить текущее.\n\n"
f"Текущее название: <b>{collection_info.get('name', 'Без названия')}</b>",
parse_mode="HTML"
)
await callback.answer()
@router.message(StateFilter(CollectionEditStates.waiting_for_name))
async def process_edit_collection_name(message: Message, state: FSMContext):
"""Обработать новое название коллекции"""
telegram_id = str(message.from_user.id)
data = await state.get_data()
collection_id = data.get("collection_id")
if message.text and message.text.strip() == "/skip":
new_name = None
else:
new_name = message.text.strip() if message.text else None
await state.update_data(name=new_name)
await state.set_state(CollectionEditStates.waiting_for_description)
collection_info = await get_collection_info(collection_id, telegram_id)
current_description = collection_info.get("description", "") if collection_info else ""
await message.answer(
"<b>Описание коллекции</b>\n\n"
"Отправьте новое описание коллекции или /skip чтобы оставить текущее.\n\n"
f"Текущее описание: <i>{current_description[:100] if current_description else 'Нет описания'}...</i>",
parse_mode="HTML"
)
@router.message(StateFilter(CollectionEditStates.waiting_for_description))
async def process_edit_collection_description(message: Message, state: FSMContext):
"""Обработать новое описание коллекции"""
telegram_id = str(message.from_user.id)
data = await state.get_data()
collection_id = data.get("collection_id")
name = data.get("name")
if message.text and message.text.strip() == "/skip":
new_description = None
else:
new_description = message.text.strip() if message.text else None
try:
update_data = {}
if name:
update_data["name"] = name
if new_description:
update_data["description"] = new_description
async with create_http_session() as session:
async with session.put(
f"{settings.BACKEND_URL}/collections/{collection_id}",
json=update_data,
headers={"X-Telegram-ID": telegram_id}
) as response:
if response.status == 200:
await message.answer(
"<b>Коллекция обновлена</b>\n\n"
"Изменения сохранены.",
parse_mode="HTML"
)
else:
error_text = await response.text()
await message.answer(
f"<b>Ошибка</b>\n\n"
f"Не удалось обновить коллекцию: {error_text}",
parse_mode="HTML"
)
except Exception as e:
await message.answer(
f"<b>Ошибка</b>\n\n"
f"Произошла ошибка: {str(e)}",
parse_mode="HTML"
)
await state.clear()
@router.callback_query(lambda c: c.data == "collections:list")
async def back_to_collections(callback: CallbackQuery):
"""Вернуться к списку коллекций"""
telegram_id = str(callback.from_user.id)
collections = await get_user_collections(telegram_id)
if not collections:
await callback.message.answer(
"<b>У вас пока нет коллекций</b>\n\n"
"Обратитесь к администратору для создания коллекций и добавления документов.",
parse_mode="HTML"
)
return
response = "<b>Ваши коллекции документов:</b>\n\n"
keyboard_buttons = []
for i, collection in enumerate(collections[:10], 1):
name = collection.get("name", "Без названия")
description = collection.get("description", "")
collection_id = collection.get("collection_id")
response += f"{i}. <b>{name}</b>\n"
if description:
response += f" <i>{description[:50]}...</i>\n"
response += f" ID: <code>{collection_id}</code>\n\n"
keyboard_buttons.append([
InlineKeyboardButton(
text=f"{name}",
callback_data=f"collection:{collection_id}"
)
])
keyboard = InlineKeyboardMarkup(inline_keyboard=keyboard_buttons)
response += "<i>Нажмите на коллекцию, чтобы посмотреть документы</i>"
await callback.message.answer(response, parse_mode="HTML", reply_markup=keyboard)

View File

@ -0,0 +1,396 @@
"""
Обработчики для работы с документами
"""
from aiogram import Router, F
from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton, CallbackQuery
from aiogram.filters import StateFilter
from aiogram.fsm.context import FSMContext
import aiohttp
from urllib.parse import unquote
from tg_bot.config.settings import settings
from tg_bot.infrastructure.http_client import create_http_session
from tg_bot.infrastructure.telegram.states.collection_states import (
DocumentEditStates,
DocumentUploadStates
)
def decode_title(title: str) -> str:
"""Декодирует URL-encoded название документа"""
if not title:
return "Без названия"
try:
decoded = unquote(title)
if decoded != title or '%' not in title:
return decoded
return title
except Exception:
return title
router = Router()
async def get_document_info(document_id: str, telegram_id: str):
"""Получить информацию о документе"""
try:
async with create_http_session() as session:
async with session.get(
f"{settings.BACKEND_URL}/documents/{document_id}",
headers={"X-Telegram-ID": telegram_id}
) as response:
if response.status == 200:
return await response.json()
return None
except Exception as e:
print(f"Error getting document info: {e}")
return None
async def delete_document(document_id: str, telegram_id: str):
"""Удалить документ"""
try:
async with create_http_session() as session:
async with session.delete(
f"{settings.BACKEND_URL}/documents/{document_id}",
headers={"X-Telegram-ID": telegram_id}
) as response:
return response.status == 204
except Exception as e:
print(f"Error deleting document: {e}")
return False
async def update_document(document_id: str, telegram_id: str, title: str = None, content: str = None):
"""Обновить документ"""
try:
update_data = {}
if title:
update_data["title"] = title
if content:
update_data["content"] = content
async with create_http_session() as session:
async with session.put(
f"{settings.BACKEND_URL}/documents/{document_id}",
json=update_data,
headers={"X-Telegram-ID": telegram_id}
) as response:
if response.status == 200:
return await response.json()
return None
except Exception as e:
print(f"Error updating document: {e}")
return None
async def upload_document_to_collection(collection_id: str, file_data: bytes, filename: str, telegram_id: str):
"""Загрузить документ в коллекцию"""
try:
async with create_http_session() as session:
form_data = aiohttp.FormData()
form_data.add_field('file', file_data, filename=filename, content_type='application/octet-stream')
async with session.post(
f"{settings.BACKEND_URL}/documents/upload?collection_id={collection_id}",
data=form_data,
headers={"X-Telegram-ID": telegram_id}
) as response:
if response.status == 201:
return await response.json()
else:
error_text = await response.text()
print(f"Upload error: {response.status} - {error_text}")
return None
except Exception as e:
print(f"Error uploading document: {e}")
return None
@router.callback_query(lambda c: c.data.startswith("document:view:"))
async def view_document(callback: CallbackQuery):
"""Просмотр документа"""
document_id = callback.data.split(":")[2]
telegram_id = str(callback.from_user.id)
await callback.answer("Загружаю документ...")
document = await get_document_info(document_id, telegram_id)
if not document:
await callback.message.answer(
"<b>Ошибка</b>\n\nНе удалось загрузить документ.",
parse_mode="HTML"
)
return
title = decode_title(document.get("title", "Без названия"))
content = document.get("content", "")
collection_id = document.get("collection_id")
content_preview = content[:2000] if len(content) > 2000 else content
has_more = len(content) > 2000
response = f"<b>{title}</b>\n\n"
response += f"<i>{content_preview}</i>"
if has_more:
response += "\n\n<i>...</i>"
try:
async with create_http_session() as session:
async with session.get(
f"{settings.BACKEND_URL}/collections/{collection_id}",
headers={"X-Telegram-ID": telegram_id}
) as response_collection:
if response_collection.status == 200:
collection_info = await response_collection.json()
owner_id = collection_info.get("owner_id")
async with session.get(
f"{settings.BACKEND_URL}/users/telegram/{telegram_id}"
) as response_user:
if response_user.status == 200:
user_info = await response_user.json()
current_user_id = user_info.get("user_id")
is_owner = str(owner_id) == str(current_user_id)
keyboard_buttons = []
if is_owner:
keyboard_buttons = [
[InlineKeyboardButton(text="Редактировать", callback_data=f"document:edit:{document_id}")],
[InlineKeyboardButton(text="Удалить", callback_data=f"document:delete:{document_id}")],
[InlineKeyboardButton(text="Назад", callback_data=f"collection:documents:{collection_id}")]
]
else:
keyboard_buttons = [
[InlineKeyboardButton(text="Редактировать", callback_data=f"document:edit:{document_id}")],
[InlineKeyboardButton(text="Назад", callback_data=f"collection:documents:{collection_id}")]
]
keyboard = InlineKeyboardMarkup(inline_keyboard=keyboard_buttons)
await callback.message.answer(response, parse_mode="HTML", reply_markup=keyboard)
return
except:
pass
keyboard = InlineKeyboardMarkup(inline_keyboard=[[
InlineKeyboardButton(text="Назад", callback_data=f"collection:documents:{collection_id}")
]])
await callback.message.answer(response, parse_mode="HTML", reply_markup=keyboard)
@router.callback_query(lambda c: c.data.startswith("document:edit:"))
async def edit_document_prompt(callback: CallbackQuery, state: FSMContext):
"""Запросить данные для редактирования документа"""
document_id = callback.data.split(":")[2]
telegram_id = str(callback.from_user.id)
document = await get_document_info(document_id, telegram_id)
if not document:
await callback.message.answer(
"<b>Ошибка</b>\n\nНе удалось загрузить документ.",
parse_mode="HTML"
)
await callback.answer()
return
await state.update_data(document_id=document_id)
await state.set_state(DocumentEditStates.waiting_for_title)
await callback.message.answer(
"<b>Редактирование документа</b>\n\n"
"Отправьте новое название документа или /skip чтобы оставить текущее.\n\n"
f"Текущее название: <b>{decode_title(document.get('title', 'Без названия'))}</b>",
parse_mode="HTML"
)
await callback.answer()
@router.message(StateFilter(DocumentEditStates.waiting_for_title))
async def process_edit_title(message: Message, state: FSMContext):
"""Обработать новое название документа"""
telegram_id = str(message.from_user.id)
data = await state.get_data()
document_id = data.get("document_id")
if message.text and message.text.strip() == "/skip":
new_title = None
else:
new_title = message.text.strip() if message.text else None
await state.update_data(title=new_title)
await state.set_state(DocumentEditStates.waiting_for_content)
await message.answer(
"<b>Содержимое документа</b>\n\n"
"Отправьте новое содержимое документа или /skip чтобы оставить текущее.",
parse_mode="HTML"
)
@router.message(StateFilter(DocumentEditStates.waiting_for_content))
async def process_edit_content(message: Message, state: FSMContext):
"""Обработать новое содержимое документа"""
telegram_id = str(message.from_user.id)
data = await state.get_data()
document_id = data.get("document_id")
title = data.get("title")
if message.text and message.text.strip() == "/skip":
new_content = None
else:
new_content = message.text.strip() if message.text else None
result = await update_document(document_id, telegram_id, title=title, content=new_content)
if result:
await message.answer(
"<b>Документ обновлен</b>\n\n"
"Изменения сохранены.",
parse_mode="HTML"
)
else:
await message.answer(
"<b>Ошибка</b>\n\n"
"Не удалось обновить документ.",
parse_mode="HTML"
)
await state.clear()
@router.callback_query(lambda c: c.data.startswith("document:delete:"))
async def delete_document_confirm(callback: CallbackQuery):
"""Подтверждение удаления документа"""
document_id = callback.data.split(":")[2]
keyboard = InlineKeyboardMarkup(inline_keyboard=[
[InlineKeyboardButton(text="Да, удалить", callback_data=f"document:delete_confirm:{document_id}")],
[InlineKeyboardButton(text="Отмена", callback_data=f"document:view:{document_id}")]
])
await callback.message.answer(
"<b>Подтверждение удаления</b>\n\n"
"Вы уверены, что хотите удалить этот документ?",
parse_mode="HTML",
reply_markup=keyboard
)
await callback.answer()
@router.callback_query(lambda c: c.data.startswith("document:delete_confirm:"))
async def delete_document_execute(callback: CallbackQuery):
"""Выполнить удаление документа"""
document_id = callback.data.split(":")[2]
telegram_id = str(callback.from_user.id)
await callback.answer("Удаляю документ...")
# Получаем информацию о документе для возврата к коллекции
document = await get_document_info(document_id, telegram_id)
collection_id = document.get("collection_id") if document else None
result = await delete_document(document_id, telegram_id)
if result:
await callback.message.answer(
"<b>Документ удален</b>",
parse_mode="HTML"
)
else:
await callback.message.answer(
"<b>Ошибка</b>\n\n"
"Не удалось удалить документ.",
parse_mode="HTML"
)
@router.callback_query(lambda c: c.data.startswith("document:upload:"))
async def upload_document_prompt(callback: CallbackQuery, state: FSMContext):
"""Запросить файл для загрузки"""
collection_id = callback.data.split(":")[2]
telegram_id = str(callback.from_user.id)
await state.update_data(collection_id=collection_id)
await state.set_state(DocumentUploadStates.waiting_for_file)
await callback.message.answer(
"<b>Загрузка документа</b>\n\n"
"Отправьте файл (PDF, PNG, JPG, JPEG, TIFF, BMP).\n\n"
"Поддерживаемые форматы:\n"
"• PDF\n"
"• Изображения: PNG, JPG, JPEG, TIFF, BMP",
parse_mode="HTML"
)
await callback.answer()
@router.message(StateFilter(DocumentUploadStates.waiting_for_file), F.document | F.photo)
async def process_upload_document(message: Message, state: FSMContext):
"""Обработать загрузку документа"""
telegram_id = str(message.from_user.id)
data = await state.get_data()
collection_id = data.get("collection_id")
if not collection_id:
await message.answer("Ошибка: не указана коллекция")
await state.clear()
return
file_id = None
filename = None
if message.document:
file_id = message.document.file_id
filename = message.document.file_name or "document.pdf"
supported_extensions = ['.pdf', '.png', '.jpg', '.jpeg', '.tiff', '.bmp']
file_ext = filename.lower().rsplit('.', 1)[-1] if '.' in filename else ''
if f'.{file_ext}' not in supported_extensions:
await message.answer(
"<b>Ошибка</b>\n\n"
f"Неподдерживаемый формат файла: {file_ext}\n\n"
"Поддерживаются: PDF, PNG, JPG, JPEG, TIFF, BMP",
parse_mode="HTML"
)
await state.clear()
return
elif message.photo:
file_id = message.photo[-1].file_id
filename = "photo.jpg"
else:
await message.answer(
"<b>Ошибка</b>\n\n"
"Пожалуйста, отправьте файл (PDF или изображение).",
parse_mode="HTML"
)
await state.clear()
return
try:
file = await message.bot.get_file(file_id)
file_data = await message.bot.download_file(file.file_path)
file_bytes = file_data.read()
result = await upload_document_to_collection(collection_id, file_bytes, filename, telegram_id)
if result:
await message.answer(
f"<b>✅ Документ загружен и добавлен в коллекцию</b>\n\n"
f"<b>Название:</b> {decode_title(result.get('title', filename))}\n\n"
f"📄 Документ сейчас индексируется. Вы получите уведомление, когда индексация завершится.\n\n",
parse_mode="HTML"
)
else:
await message.answer(
"<b>Ошибка</b>\n\n"
"Не удалось загрузить документ.",
parse_mode="HTML"
)
except Exception as e:
print(f"Error uploading document: {e}")
await message.answer(
"<b>Ошибка</b>\n\n"
f"Произошла ошибка при загрузке: {str(e)}",
parse_mode="HTML"
)
await state.clear()

View File

@ -1,16 +1,14 @@
from aiogram import Router, types from aiogram import Router, types
from aiogram.types import Message from aiogram.types import Message
from datetime import datetime
import aiohttp
from tg_bot.config.settings import settings from tg_bot.config.settings import settings
from tg_bot.infrastructure.database.database import AsyncSessionLocal from tg_bot.domain.user_service import UserService, User
from tg_bot.infrastructure.database.models import UserModel
from tg_bot.domain.services.user_service import UserService
from tg_bot.application.services.rag_service import RAGService from tg_bot.application.services.rag_service import RAGService
import re
router = Router() router = Router()
BACKEND_URL = "http://localhost:8001/api/v1"
rag_service = RAGService() rag_service = RAGService()
user_service = UserService()
@router.message() @router.message()
async def handle_question(message: Message): async def handle_question(message: Message):
@ -19,9 +17,7 @@ async def handle_question(message: Message):
if question_text.startswith('/'): if question_text.startswith('/'):
return return
async with AsyncSessionLocal() as session:
try: try:
user_service = UserService(session)
user = await user_service.get_user_by_telegram_id(user_id) user = await user_service.get_user_by_telegram_id(user_id)
if not user: if not user:
@ -31,13 +27,12 @@ async def handle_question(message: Message):
message.from_user.first_name or "", message.from_user.first_name or "",
message.from_user.last_name or "" message.from_user.last_name or ""
) )
await ensure_user_in_backend(str(user_id), message.from_user)
if user.is_premium: if user.is_premium:
await process_premium_question(message, user, question_text, user_service) await process_premium_question(message, user, question_text)
elif user.questions_used < settings.FREE_QUESTIONS_LIMIT: elif user.questions_used < settings.FREE_QUESTIONS_LIMIT:
await process_free_question(message, user, question_text, user_service) await process_free_question(message, user, question_text)
else: else:
await handle_limit_exceeded(message, user) await handle_limit_exceeded(message, user)
@ -50,27 +45,9 @@ async def handle_question(message: Message):
) )
async def ensure_user_in_backend(telegram_id: str, telegram_user): async def process_premium_question(message: Message, user: User, question_text: str):
try: await user_service.update_user_questions(int(user.telegram_id))
async with aiohttp.ClientSession() as session: user = await user_service.get_user_by_telegram_id(int(user.telegram_id))
async with session.get(
f"{BACKEND_URL}/users/telegram/{telegram_id}"
) as response:
if response.status == 200:
return
async with session.post(
f"{BACKEND_URL}/users",
json={"telegram_id": telegram_id, "role": "user"}
) as create_response:
if create_response.status in [200, 201]:
print(f"Пользователь {telegram_id} создан в backend")
except Exception as e:
print(f"Error creating user in backend: {e}")
async def process_premium_question(message: Message, user: UserModel, question_text: str, user_service: UserService):
await user_service.update_user_questions(user.telegram_id)
await message.bot.send_chat_action(message.chat.id, "typing") await message.bot.send_chat_action(message.chat.id, "typing")
@ -83,37 +60,41 @@ async def process_premium_question(message: Message, user: UserModel, question_t
answer = rag_result.get("answer", "Извините, не удалось сгенерировать ответ.") answer = rag_result.get("answer", "Извините, не удалось сгенерировать ответ.")
sources = rag_result.get("sources", []) sources = rag_result.get("sources", [])
await save_conversation_to_backend( # Беседа уже сохранена в бэкенде через API /rag/question
str(message.from_user.id),
question_text, import re
answer, formatted_answer = answer
sources formatted_answer = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', formatted_answer)
) formatted_answer = re.sub(r'^(\d+)\.\s+', r'\1. ', formatted_answer, flags=re.MULTILINE)
formatted_answer = formatted_answer.replace("- ", "")
response = ( response = (
f"<b>Ваш вопрос:</b>\n" f"<b>Ваш вопрос:</b>\n"
f"<i>{question_text[:200]}</i>\n\n" f"<i>{question_text[:200]}</i>\n\n"
f"<b>Ответ:</b>\n{answer}\n\n" f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
f"💬 <b>Ответ:</b>\n\n"
f"{formatted_answer}\n\n"
) )
if sources: if sources:
response += f"<b>Источники из коллекций:</b>\n" response += f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
collections_used = {} response += f"📚 <b>Источники:</b>\n"
for source in sources[:5]: for idx, source in enumerate(sources[:5], 1):
collection_name = source.get('collection', 'Неизвестно') title = source.get('title', 'Без названия')
if collection_name not in collections_used: try:
collections_used[collection_name] = [] from urllib.parse import unquote
collections_used[collection_name].append(source.get('title', 'Без названия')) decoded = unquote(title)
if decoded != title or '%' in title:
for i, (collection_name, titles) in enumerate(collections_used.items(), 1): title = decoded
response += f"{i}. <b>Коллекция:</b> {collection_name}\n" except:
for title in titles[:2]: pass
response += f" {title}\n" response += f" {idx}. {title}\n"
response += "\n<i>Используйте /mycollections для просмотра всех коллекций</i>\n\n" response += "\n<i>💡 Используйте /mycollections для просмотра всех коллекций</i>\n\n"
response += ( response += (
f"<b>Статус:</b> Premium (вопросов безлимитно)\n" f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
f"<b>Всего вопросов:</b> {user.questions_used}" f"✨ <b>Статус:</b> Premium (вопросов безлимитно)\n"
f"📊 <b>Всего вопросов:</b> {user.questions_used}"
) )
except Exception as e: except Exception as e:
@ -121,17 +102,20 @@ async def process_premium_question(message: Message, user: UserModel, question_t
response = ( response = (
f"<b>Ваш вопрос:</b>\n" f"<b>Ваш вопрос:</b>\n"
f"<i>{question_text[:200]}</i>\n\n" f"<i>{question_text[:200]}</i>\n\n"
f"Ошибка при генерации ответа. Попробуйте позже.\n\n" f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
f"<b>Статус:</b> Premium\n" f"❌ <b>Ошибка при генерации ответа.</b>\n"
f"<b>Всего вопросов:</b> {user.questions_used}" f"Попробуйте позже.\n\n"
f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
f"✨ <b>Статус:</b> Premium\n"
f"📊 <b>Всего вопросов:</b> {user.questions_used}"
) )
await message.answer(response, parse_mode="HTML") await message.answer(response, parse_mode="HTML")
async def process_free_question(message: Message, user: UserModel, question_text: str, user_service: UserService): async def process_free_question(message: Message, user: User, question_text: str):
await user_service.update_user_questions(user.telegram_id) await user_service.update_user_questions(int(user.telegram_id))
user = await user_service.get_user_by_telegram_id(user.telegram_id) user = await user_service.get_user_by_telegram_id(int(user.telegram_id))
remaining = settings.FREE_QUESTIONS_LIMIT - user.questions_used remaining = settings.FREE_QUESTIONS_LIMIT - user.questions_used
await message.bot.send_chat_action(message.chat.id, "typing") await message.bot.send_chat_action(message.chat.id, "typing")
@ -145,138 +129,69 @@ async def process_free_question(message: Message, user: UserModel, question_text
answer = rag_result.get("answer", "Извините, не удалось сгенерировать ответ.") answer = rag_result.get("answer", "Извините, не удалось сгенерировать ответ.")
sources = rag_result.get("sources", []) sources = rag_result.get("sources", [])
await save_conversation_to_backend( # Уже все сохранили через /rag/question
str(message.from_user.id),
question_text,
answer,
sources
)
formatted_answer = answer
formatted_answer = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', formatted_answer)
formatted_answer = re.sub(r'^(\d+)\.\s+', r'\1. ', formatted_answer, flags=re.MULTILINE)
formatted_answer = formatted_answer.replace("- ", "")
response = ( response = (
f"<b>Ваш вопрос:</b>\n" f"<b>Ваш вопрос:</b>\n"
f"<i>{question_text[:200]}</i>\n\n" f"<i>{question_text[:200]}</i>\n\n"
f"<b>Ответ:</b>\n{answer}\n\n" f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
f"💬 <b>Ответ:</b>\n\n"
f"{formatted_answer}\n\n"
) )
if sources: if sources:
response += f"<b>Источники из коллекций:</b>\n" response += f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
collections_used = {} response += f"📚 <b>Источники:</b>\n"
for source in sources[:5]: for idx, source in enumerate(sources[:5], 1):
collection_name = source.get('collection', 'Неизвестно') title = source.get('title', 'Без названия')
if collection_name not in collections_used: try:
collections_used[collection_name] = [] from urllib.parse import unquote
collections_used[collection_name].append(source.get('title', 'Без названия')) decoded = unquote(title)
if decoded != title or '%' in title:
for i, (collection_name, titles) in enumerate(collections_used.items(), 1): title = decoded
response += f"{i}. <b>Коллекция:</b> {collection_name}\n" except:
for title in titles[:2]: pass
response += f" {title}\n" response += f" {idx}. {title}\n"
response += "\n<i>Используйте /mycollections для просмотра всех коллекций</i>\n\n" response += "\n<i>💡 Используйте /mycollections для просмотра всех коллекций</i>\n\n"
response += ( response += (
f"<b>Статус:</b> Бесплатный доступ\n" f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
f"<b>Использовано вопросов:</b> {user.questions_used}/{settings.FREE_QUESTIONS_LIMIT}\n" f"📊 <b>Статус:</b> Бесплатный доступ\n"
f"<b>Осталось бесплатных:</b> {remaining}\n\n" f"📈 <b>Использовано вопросов:</b> {user.questions_used}/{settings.FREE_QUESTIONS_LIMIT}\n"
f"🎯 <b>Осталось бесплатных:</b> {remaining}\n\n"
) )
if remaining <= 3 and remaining > 0: if remaining <= 3 and remaining > 0:
response += f"<i>Осталось мало вопросов! Для продолжения используйте /buy</i>\n\n" response += f"⚠️ <i>Осталось мало вопросов! Для продолжения используйте /buy</i>\n\n"
response += f"<i>Для безлимитного доступа: /buy</i>" response += f"💎 <i>Для безлимитного доступа: /buy</i>"
except Exception as e: except Exception as e:
print(f"Error generating answer: {e}") print(f"Error generating answer: {e}")
response = ( response = (
f"<b>Ваш вопрос:</b>\n" f"<b>Ваш вопрос:</b>\n"
f"<i>{question_text[:200]}</i>\n\n" f"<i>{question_text[:200]}</i>\n\n"
f"Ошибка при генерации ответа. Попробуйте позже.\n\n" f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
f"<b>Статус:</b> Бесплатный доступ\n" f"❌ <b>Ошибка при генерации ответа.</b>\n"
f"<b>Использовано вопросов:</b> {user.questions_used}/{settings.FREE_QUESTIONS_LIMIT}\n" f"Попробуйте позже.\n\n"
f"<b>Осталось бесплатных:</b> {remaining}\n\n" f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
f"<i>Для безлимитного доступа: /buy</i>" f"📊 <b>Статус:</b> Бесплатный доступ\n"
f"📈 <b>Использовано вопросов:</b> {user.questions_used}/{settings.FREE_QUESTIONS_LIMIT}\n"
f"🎯 <b>Осталось бесплатных:</b> {remaining}\n\n"
f"💎 <i>Для безлимитного доступа: /buy</i>"
) )
await message.answer(response, parse_mode="HTML") await message.answer(response, parse_mode="HTML")
async def save_conversation_to_backend(telegram_id: str, question: str, answer: str, sources: list): #Сново сохраняется в /rag/question
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{BACKEND_URL}/users/telegram/{telegram_id}"
) as user_response:
if user_response.status != 200:
return
user_data = await user_response.json()
user_uuid = user_data.get("user_id")
async with session.get(
f"{BACKEND_URL}/collections/",
headers={"X-Telegram-ID": telegram_id}
) as collections_response:
collections = []
if collections_response.status == 200:
collections = await collections_response.json()
collection_id = None
if collections:
collection_id = collections[0].get("collection_id")
else:
async with session.post(
f"{BACKEND_URL}/collections",
json={
"name": "Основная коллекция",
"description": "Коллекция по умолчанию",
"is_public": False
},
headers={"X-Telegram-ID": telegram_id}
) as create_collection_response:
if create_collection_response.status in [200, 201]:
collection_data = await create_collection_response.json()
collection_id = collection_data.get("collection_id")
if not collection_id:
return
async with session.post(
f"{BACKEND_URL}/conversations",
json={"collection_id": str(collection_id)},
headers={"X-Telegram-ID": telegram_id}
) as conversation_response:
if conversation_response.status not in [200, 201]:
return
conversation_data = await conversation_response.json()
conversation_id = conversation_data.get("conversation_id")
if not conversation_id:
return
await session.post(
f"{BACKEND_URL}/messages",
json={
"conversation_id": str(conversation_id),
"content": question,
"role": "user"
},
headers={"X-Telegram-ID": telegram_id}
)
await session.post(
f"{BACKEND_URL}/messages",
json={
"conversation_id": str(conversation_id),
"content": answer,
"role": "assistant",
"sources": {"documents": sources}
},
headers={"X-Telegram-ID": telegram_id}
)
except Exception as e:
print(f"Error saving conversation: {e}")
async def handle_limit_exceeded(message: Message, user: UserModel): async def handle_limit_exceeded(message: Message, user: User):
response = ( response = (
f"<b>Лимит бесплатных вопросов исчерпан!</b>\n\n" f"<b>Лимит бесплатных вопросов исчерпан!</b>\n\n"

View File

@ -4,10 +4,10 @@ from aiogram.types import Message
from datetime import datetime from datetime import datetime
from tg_bot.config.settings import settings from tg_bot.config.settings import settings
from tg_bot.infrastructure.database.database import AsyncSessionLocal from tg_bot.domain.user_service import UserService
from tg_bot.domain.services.user_service import UserService
router = Router() router = Router()
user_service = UserService()
@router.message(Command("start")) @router.message(Command("start"))
async def cmd_start(message: Message): async def cmd_start(message: Message):
@ -16,9 +16,7 @@ async def cmd_start(message: Message):
username = message.from_user.username or "" username = message.from_user.username or ""
first_name = message.from_user.first_name or "" first_name = message.from_user.first_name or ""
last_name = message.from_user.last_name or "" last_name = message.from_user.last_name or ""
async with AsyncSessionLocal() as session:
try: try:
user_service = UserService(session)
existing_user = await user_service.get_user_by_telegram_id(user_id) existing_user = await user_service.get_user_by_telegram_id(user_id)
user = await user_service.get_or_create_user( user = await user_service.get_or_create_user(
user_id, user_id,
@ -31,7 +29,6 @@ async def cmd_start(message: Message):
except Exception as e: except Exception as e:
print(f"Ошибка сохранения пользователя: {e}") print(f"Ошибка сохранения пользователя: {e}")
await session.rollback()
welcome_text = ( welcome_text = (
f"<b>Привет, {first_name}!</b>\n\n" f"<b>Привет, {first_name}!</b>\n\n"
f"Я <b>VibeLawyerBot</b> - ваш помощник в юридических вопросах.\n\n" f"Я <b>VibeLawyerBot</b> - ваш помощник в юридических вопросах.\n\n"

View File

@ -4,19 +4,17 @@ from aiogram.filters import Command
from aiogram.types import Message from aiogram.types import Message
from tg_bot.config.settings import settings from tg_bot.config.settings import settings
from tg_bot.infrastructure.database.database import AsyncSessionLocal from tg_bot.domain.user_service import UserService
from tg_bot.domain.services.user_service import UserService
router = Router() router = Router()
user_service = UserService()
@router.message(Command("stats")) @router.message(Command("stats"))
async def cmd_stats(message: Message): async def cmd_stats(message: Message):
user_id = message.from_user.id user_id = message.from_user.id
async with AsyncSessionLocal() as session:
try: try:
user_service = UserService(session)
user = await user_service.get_user_by_telegram_id(user_id) user = await user_service.get_user_by_telegram_id(user_id)
if user: if user:

View File

@ -0,0 +1,27 @@
"""
FSM состояния для работы с коллекциями
"""
from aiogram.fsm.state import State, StatesGroup
class CollectionAccessStates(StatesGroup):
"""Состояния для управления доступом к коллекции"""
waiting_for_username = State()
class CollectionEditStates(StatesGroup):
"""Состояния для редактирования коллекции"""
waiting_for_name = State()
waiting_for_description = State()
class DocumentEditStates(StatesGroup):
"""Состояния для редактирования документа"""
waiting_for_title = State()
waiting_for_content = State()
class DocumentUploadStates(StatesGroup):
"""Состояния для загрузки документа"""
waiting_for_file = State()

View File

@ -1,19 +1,17 @@
import asyncio import asyncio
import logging import logging
import sys
import os import os
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)
from tg_bot.config.settings import settings from tg_bot.config.settings import settings
log_file_path = settings.LOG_FILE
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[ handlers=[
logging.FileHandler(settings.LOG_FILE), logging.FileHandler(log_file_path),
logging.StreamHandler() logging.StreamHandler()
] ]
) )

View File

@ -18,10 +18,7 @@ async def handle_yookassa_webhook(request: Request):
print(f"Webhook received: {event_type}") print(f"Webhook received: {event_type}")
try: try:
from tg_bot.config.settings import settings from tg_bot.config.settings import settings
from tg_bot.domain.services.user_service import UserService from tg_bot.domain.user_service import UserService
from tg_bot.infrastructure.database.database import AsyncSessionLocal
from tg_bot.infrastructure.database.models import UserModel
from sqlalchemy import select
from aiogram import Bot from aiogram import Bot
if event_type == "payment.succeeded": if event_type == "payment.succeeded":
@ -29,16 +26,12 @@ async def handle_yookassa_webhook(request: Request):
user_id = payment.get("metadata", {}).get("user_id") user_id = payment.get("metadata", {}).get("user_id")
if user_id: if user_id:
async with AsyncSessionLocal() as session: user_service = UserService()
user_service = UserService(session)
success = await user_service.activate_premium(int(user_id)) success = await user_service.activate_premium(int(user_id))
if success: if success:
print(f"Premium activated for user {user_id}") print(f"Premium activated for user {user_id}")
result = await session.execute( user = await user_service.get_user_by_telegram_id(int(user_id))
select(UserModel).filter_by(telegram_id=str(user_id))
)
user = result.scalar_one_or_none()
if user and settings.TELEGRAM_BOT_TOKEN: if user and settings.TELEGRAM_BOT_TOKEN:
try: try:
@ -60,7 +53,7 @@ async def handle_yookassa_webhook(request: Request):
except Exception as e: except Exception as e:
print(f"Error sending notification: {e}") print(f"Error sending notification: {e}")
else: else:
print(f"User {user_id} not found") print(f"User {user_id} not found or failed to activate premium")
except ImportError as e: except ImportError as e:
print(f"Import error: {e}") print(f"Import error: {e}")

8
tg_bot/requirements.txt Normal file
View File

@ -0,0 +1,8 @@
pydantic>=2.5.0
pydantic-settings>=2.1.0
python-dotenv>=1.0.0
aiogram>=3.10.0
httpx>=0.25.2
yookassa>=2.4.0
aiohttp>=3.9.1

20
tg_bot/run.py Normal file
View File

@ -0,0 +1,20 @@
"""
Скрипт для запуска Telegram бота без Docker
"""
import sys
import os
from pathlib import Path
tg_bot_dir = Path(__file__).parent
sys.path.insert(0, str(tg_bot_dir))
if __name__ == "__main__":
from main import main
import asyncio
asyncio.run(main())