Compare commits
36 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b0bbc739f3 | |||
| 42fcc0eb16 | |||
| 683f779c31 | |||
| ef71c67683 | |||
| 570f0b7ea7 | |||
| 1b550e6503 | |||
| 66392765b9 | |||
| 8cbc318c33 | |||
| 908e8fc435 | |||
| 5a46194d41 | |||
| 0b04ffefd2 | |||
| 5264b3b64c | |||
| 5809ac5688 | |||
| 8bdacb4f7a | |||
| 1ce1c23d10 | |||
| 5da6c32722 | |||
| 6b768261e2 | |||
| 6934220b52 | |||
| 79980eb313 | |||
| 9f111ad2c2 | |||
| 7b7165a44b | |||
| 193deb7a8c | |||
| 49c3d1b0fd | |||
| c4b3521257 | |||
| 169d874dad | |||
| dfc188e179 | |||
| 493c385cb1 | |||
|
|
71e8d1079e | ||
|
|
a7fc2487e9 | ||
|
|
b504bb26c8 | ||
|
|
1f0a5e5159 | ||
|
|
09dfe46a5b | ||
| cd08f88434 | |||
|
|
0bc47a9e7f | ||
|
|
93cf04a1cf | ||
| 5c8e07e7f1 |
@ -13,6 +13,7 @@ trigger:
|
|||||||
steps:
|
steps:
|
||||||
- name: deploy-backend
|
- name: deploy-backend
|
||||||
image: appleboy/drone-ssh
|
image: appleboy/drone-ssh
|
||||||
|
timeout: 30m
|
||||||
settings:
|
settings:
|
||||||
host:
|
host:
|
||||||
from_secret: server_host
|
from_secret: server_host
|
||||||
@ -21,10 +22,13 @@ steps:
|
|||||||
password:
|
password:
|
||||||
from_secret: server_password
|
from_secret: server_password
|
||||||
port: 22
|
port: 22
|
||||||
|
command_timeout: 30m
|
||||||
script:
|
script:
|
||||||
- cd BETTERCALLPRASKOVIA
|
- cd BetterCallPraskovia
|
||||||
- git pull origin main
|
- git pull origin main
|
||||||
- docker-compose stop backend tg_bot
|
- docker-compose stop backend tg_bot
|
||||||
- docker-compose up --build -d backend tg_bot
|
- docker-compose rm -f backend tg_bot
|
||||||
- docker system prune -f
|
- docker-compose build backend tg_bot
|
||||||
|
- docker-compose up -d --no-deps backend tg_bot
|
||||||
|
- docker image prune -f
|
||||||
|
|
||||||
1613
AI_api.yaml
1613
AI_api.yaml
File diff suppressed because it is too large
Load Diff
34
backend/.dockerignore
Normal file
34
backend/.dockerignore
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
ENV/
|
||||||
|
.venv/
|
||||||
|
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
.git/
|
||||||
|
.gitignore
|
||||||
|
.gitattributes
|
||||||
|
|
||||||
|
Dockerfile*
|
||||||
|
docker-compose*.yml
|
||||||
|
.dockerignore
|
||||||
|
|
||||||
|
drone.yml
|
||||||
|
|
||||||
|
tmp/
|
||||||
|
temp/
|
||||||
|
*.tmp
|
||||||
|
|
||||||
|
Thumbs.db
|
||||||
|
.DS_Store
|
||||||
@ -8,7 +8,7 @@ RUN apt-get update && apt-get install -y \
|
|||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
RUN pip install -r requirements.txt
|
||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
|
|||||||
28
backend/alembic/versions/002_add_premium_fields.py
Normal file
28
backend/alembic/versions/002_add_premium_fields.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
"""Add premium fields to users
|
||||||
|
|
||||||
|
Revision ID: 002
|
||||||
|
Revises: 001
|
||||||
|
Create Date: 2024-01-02 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
revision = '002'
|
||||||
|
down_revision = '001'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column('users', sa.Column('is_premium', sa.Boolean(), nullable=False, server_default='false'))
|
||||||
|
op.add_column('users', sa.Column('premium_until', sa.DateTime(), nullable=True))
|
||||||
|
op.add_column('users', sa.Column('questions_used', sa.Integer(), nullable=False, server_default='0'))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column('users', 'questions_used')
|
||||||
|
op.drop_column('users', 'premium_until')
|
||||||
|
op.drop_column('users', 'is_premium')
|
||||||
|
|
||||||
33
backend/alembic/versions/003_remove_embeddings_table.py
Normal file
33
backend/alembic/versions/003_remove_embeddings_table.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
"""Remove unused embeddings table
|
||||||
|
|
||||||
|
Revision ID: 003
|
||||||
|
Revises: 002
|
||||||
|
Create Date: 2024-12-24 12:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
revision = '003'
|
||||||
|
down_revision = '002'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.drop_table('embeddings')
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
'embeddings',
|
||||||
|
sa.Column('embedding_id', sa.dialects.postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column('document_id', sa.dialects.postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column('embedding', sa.dialects.postgresql.JSON(astext_type=sa.Text()), nullable=True),
|
||||||
|
sa.Column('model_version', sa.String(), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['document_id'], ['documents.document_id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('embedding_id')
|
||||||
|
)
|
||||||
|
|
||||||
@ -1,4 +1,4 @@
|
|||||||
fastapi==0.104.1
|
fastapi==0.100.1
|
||||||
uvicorn[standard]==0.24.0
|
uvicorn[standard]==0.24.0
|
||||||
sqlalchemy[asyncio]==2.0.23
|
sqlalchemy[asyncio]==2.0.23
|
||||||
asyncpg==0.29.0
|
asyncpg==0.29.0
|
||||||
|
|||||||
23
backend/run.py
Normal file
23
backend/run.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
backend_dir = Path(__file__).parent
|
||||||
|
sys.path.insert(0, str(backend_dir))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
|
||||||
|
uvicorn.run(
|
||||||
|
"src.presentation.main:app",
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=8000,
|
||||||
|
reload=True,
|
||||||
|
log_level="info"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -26,7 +26,7 @@ class CacheService:
|
|||||||
"question": question,
|
"question": question,
|
||||||
"answer": answer
|
"answer": answer
|
||||||
}
|
}
|
||||||
await self.reids_client.set_json(key, value, ttl or self.default_ttl)
|
await self.redis_client.set_json(key, value, ttl or self.default_ttl)
|
||||||
|
|
||||||
async def invalidate_collection_cache(self, collection_id: UUID):
|
async def invalidate_collection_cache(self, collection_id: UUID):
|
||||||
pattern = f"rag:answer:{collection_id}:*"
|
pattern = f"rag:answer:{collection_id}:*"
|
||||||
|
|||||||
@ -67,6 +67,8 @@ class DocumentParserService:
|
|||||||
return title, content
|
return title, content
|
||||||
except YandexOCRError:
|
except YandexOCRError:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise YandexOCRError(f"Ошибка при парсинге изображения: {str(e)}") from e
|
raise YandexOCRError(f"Ошибка при парсинге изображения: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -42,9 +42,18 @@ class RAGService:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
embeddings = self.embedding_service.embed_texts([c.content for c in chunks])
|
EMBEDDING_BATCH_SIZE = 50
|
||||||
|
all_embeddings: list[list[float]] = []
|
||||||
|
|
||||||
|
for i in range(0, len(chunks), EMBEDDING_BATCH_SIZE):
|
||||||
|
batch_chunks = chunks[i:i + EMBEDDING_BATCH_SIZE]
|
||||||
|
batch_texts = [c.content for c in batch_chunks]
|
||||||
|
batch_embeddings = self.embedding_service.embed_texts(batch_texts)
|
||||||
|
all_embeddings.extend(batch_embeddings)
|
||||||
|
|
||||||
|
print(f"Created {len(all_embeddings)} embeddings, upserting to Qdrant...")
|
||||||
await self.vector_repository.upsert_chunks(
|
await self.vector_repository.upsert_chunks(
|
||||||
chunks, embeddings, model_version=self.embedding_service.model_version()
|
chunks, all_embeddings, model_version=self.embedding_service.model_version()
|
||||||
)
|
)
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|||||||
@ -39,5 +39,10 @@ class TextSplitter:
|
|||||||
|
|
||||||
def _split_sentences(self, text: str) -> Iterable[str]:
|
def _split_sentences(self, text: str) -> Iterable[str]:
|
||||||
parts = re.split(r"(?<=[\.\?\!])\s+", text)
|
parts = re.split(r"(?<=[\.\?\!])\s+", text)
|
||||||
|
if len(parts) == 1 and len(text) > self.chunk_size * 2:
|
||||||
|
chunk_text = []
|
||||||
|
for i in range(0, len(text), self.chunk_size):
|
||||||
|
chunk_text.append(text[i:i + self.chunk_size])
|
||||||
|
return chunk_text
|
||||||
return [p.strip() for p in parts if p.strip()]
|
return [p.strip() for p in parts if p.strip()]
|
||||||
|
|
||||||
|
|||||||
@ -139,3 +139,67 @@ class CollectionUseCases:
|
|||||||
all_collections = {c.collection_id: c for c in owned + public + accessed_collections}
|
all_collections = {c.collection_id: c for c in owned + public + accessed_collections}
|
||||||
return list(all_collections.values())[skip:skip+limit]
|
return list(all_collections.values())[skip:skip+limit]
|
||||||
|
|
||||||
|
async def list_collection_access(self, collection_id: UUID, user_id: UUID) -> list[CollectionAccess]:
|
||||||
|
"""Получить список доступа к коллекции"""
|
||||||
|
collection = await self.get_collection(collection_id)
|
||||||
|
|
||||||
|
has_access = await self.check_access(collection_id, user_id)
|
||||||
|
if not has_access:
|
||||||
|
raise ForbiddenError("У вас нет доступа к этой коллекции")
|
||||||
|
|
||||||
|
return await self.access_repository.list_by_collection(collection_id)
|
||||||
|
|
||||||
|
async def grant_access_by_telegram_id(
|
||||||
|
self,
|
||||||
|
collection_id: UUID,
|
||||||
|
telegram_id: str,
|
||||||
|
owner_id: UUID
|
||||||
|
) -> CollectionAccess:
|
||||||
|
"""Предоставить доступ пользователю к коллекции по Telegram ID"""
|
||||||
|
collection = await self.get_collection(collection_id)
|
||||||
|
|
||||||
|
if collection.owner_id != owner_id:
|
||||||
|
raise ForbiddenError("Только владелец может предоставлять доступ")
|
||||||
|
|
||||||
|
user = await self.user_repository.get_by_telegram_id(telegram_id)
|
||||||
|
if not user:
|
||||||
|
from src.domain.entities.user import User, UserRole
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.info(f"Creating new user with telegram_id: {telegram_id}")
|
||||||
|
user = User(telegram_id=telegram_id, role=UserRole.USER)
|
||||||
|
try:
|
||||||
|
user = await self.user_repository.create(user)
|
||||||
|
logger.info(f"User created successfully: user_id={user.user_id}, telegram_id={user.telegram_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating user: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
if user.user_id == owner_id:
|
||||||
|
raise ForbiddenError("Владелец уже имеет доступ к коллекции")
|
||||||
|
|
||||||
|
existing_access = await self.access_repository.get_by_user_and_collection(user.user_id, collection_id)
|
||||||
|
if existing_access:
|
||||||
|
return existing_access
|
||||||
|
|
||||||
|
access = CollectionAccess(user_id=user.user_id, collection_id=collection_id)
|
||||||
|
return await self.access_repository.create(access)
|
||||||
|
|
||||||
|
async def revoke_access_by_telegram_id(
|
||||||
|
self,
|
||||||
|
collection_id: UUID,
|
||||||
|
telegram_id: str,
|
||||||
|
owner_id: UUID
|
||||||
|
) -> bool:
|
||||||
|
"""Отозвать доступ пользователя к коллекции по Telegram ID"""
|
||||||
|
collection = await self.get_collection(collection_id)
|
||||||
|
|
||||||
|
if collection.owner_id != owner_id:
|
||||||
|
raise ForbiddenError("Только владелец может отзывать доступ")
|
||||||
|
|
||||||
|
user = await self.user_repository.get_by_telegram_id(telegram_id)
|
||||||
|
if not user:
|
||||||
|
raise NotFoundError(f"Пользователь с telegram_id {telegram_id} не найден")
|
||||||
|
|
||||||
|
return await self.access_repository.delete_by_user_and_collection(user.user_id, collection_id)
|
||||||
|
|
||||||
|
|||||||
@ -3,11 +3,15 @@ Use cases для работы с документами
|
|||||||
"""
|
"""
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import BinaryIO, Optional
|
from typing import BinaryIO, Optional
|
||||||
|
import httpx
|
||||||
from src.domain.entities.document import Document
|
from src.domain.entities.document import Document
|
||||||
from src.domain.repositories.document_repository import IDocumentRepository
|
from src.domain.repositories.document_repository import IDocumentRepository
|
||||||
from src.domain.repositories.collection_repository import ICollectionRepository
|
from src.domain.repositories.collection_repository import ICollectionRepository
|
||||||
|
from src.domain.repositories.collection_access_repository import ICollectionAccessRepository
|
||||||
from src.application.services.document_parser_service import DocumentParserService
|
from src.application.services.document_parser_service import DocumentParserService
|
||||||
|
from src.application.services.rag_service import RAGService
|
||||||
from src.shared.exceptions import NotFoundError, ForbiddenError
|
from src.shared.exceptions import NotFoundError, ForbiddenError
|
||||||
|
from src.shared.config import settings
|
||||||
|
|
||||||
|
|
||||||
class DocumentUseCases:
|
class DocumentUseCases:
|
||||||
@ -17,11 +21,26 @@ class DocumentUseCases:
|
|||||||
self,
|
self,
|
||||||
document_repository: IDocumentRepository,
|
document_repository: IDocumentRepository,
|
||||||
collection_repository: ICollectionRepository,
|
collection_repository: ICollectionRepository,
|
||||||
parser_service: DocumentParserService
|
access_repository: ICollectionAccessRepository,
|
||||||
|
parser_service: DocumentParserService,
|
||||||
|
rag_service: Optional[RAGService] = None
|
||||||
):
|
):
|
||||||
self.document_repository = document_repository
|
self.document_repository = document_repository
|
||||||
self.collection_repository = collection_repository
|
self.collection_repository = collection_repository
|
||||||
|
self.access_repository = access_repository
|
||||||
self.parser_service = parser_service
|
self.parser_service = parser_service
|
||||||
|
self.rag_service = rag_service
|
||||||
|
|
||||||
|
async def _check_collection_access(self, user_id: UUID, collection) -> bool:
|
||||||
|
"""Проверить доступ пользователя к коллекции"""
|
||||||
|
if collection.owner_id == user_id:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if collection.is_public:
|
||||||
|
return True
|
||||||
|
|
||||||
|
access = await self.access_repository.get_by_user_and_collection(user_id, collection.collection_id)
|
||||||
|
return access is not None
|
||||||
|
|
||||||
async def create_document(
|
async def create_document(
|
||||||
self,
|
self,
|
||||||
@ -43,20 +62,43 @@ class DocumentUseCases:
|
|||||||
)
|
)
|
||||||
return await self.document_repository.create(document)
|
return await self.document_repository.create(document)
|
||||||
|
|
||||||
|
async def _send_telegram_notification(self, telegram_id: str, message: str):
|
||||||
|
"""Отправить уведомление пользователю через Telegram Bot API"""
|
||||||
|
if not settings.TELEGRAM_BOT_TOKEN:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
url = f"https://api.telegram.org/bot{settings.TELEGRAM_BOT_TOKEN}/sendMessage"
|
||||||
|
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||||
|
response = await client.post(
|
||||||
|
url,
|
||||||
|
json={
|
||||||
|
"chat_id": telegram_id,
|
||||||
|
"text": message,
|
||||||
|
"parse_mode": "HTML"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(f"Failed to send Telegram notification: {response.status_code}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error sending Telegram notification: {e}")
|
||||||
|
|
||||||
async def upload_and_parse_document(
|
async def upload_and_parse_document(
|
||||||
self,
|
self,
|
||||||
collection_id: UUID,
|
collection_id: UUID,
|
||||||
file: BinaryIO,
|
file: BinaryIO,
|
||||||
filename: str,
|
filename: str,
|
||||||
user_id: UUID
|
user_id: UUID,
|
||||||
|
telegram_id: Optional[str] = None
|
||||||
) -> Document:
|
) -> Document:
|
||||||
"""Загрузить и распарсить документ"""
|
"""Загрузить и распарсить документ, затем автоматически проиндексировать"""
|
||||||
collection = await self.collection_repository.get_by_id(collection_id)
|
collection = await self.collection_repository.get_by_id(collection_id)
|
||||||
if not collection:
|
if not collection:
|
||||||
raise NotFoundError(f"Коллекция {collection_id} не найдена")
|
raise NotFoundError(f"Коллекция {collection_id} не найдена")
|
||||||
|
|
||||||
if collection.owner_id != user_id:
|
has_access = await self._check_collection_access(user_id, collection)
|
||||||
raise ForbiddenError("Только владелец может добавлять документы")
|
if not has_access:
|
||||||
|
raise ForbiddenError("У вас нет доступа к этой коллекции")
|
||||||
|
|
||||||
title, content = await self.parser_service.parse_pdf(file, filename)
|
title, content = await self.parser_service.parse_pdf(file, filename)
|
||||||
|
|
||||||
@ -66,7 +108,41 @@ class DocumentUseCases:
|
|||||||
content=content,
|
content=content,
|
||||||
metadata={"filename": filename}
|
metadata={"filename": filename}
|
||||||
)
|
)
|
||||||
return await self.document_repository.create(document)
|
document = await self.document_repository.create(document)
|
||||||
|
|
||||||
|
if self.rag_service and telegram_id:
|
||||||
|
try:
|
||||||
|
await self._send_telegram_notification(
|
||||||
|
telegram_id,
|
||||||
|
"🔄 <b>Начинаю индексацию документа...</b>\n\n"
|
||||||
|
f"📄 <b>Документ:</b> {title}\n\n"
|
||||||
|
f"Это может занять некоторое время.\n"
|
||||||
|
f"Вы получите уведомление по завершении."
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = await self.rag_service.index_document(document)
|
||||||
|
|
||||||
|
await self._send_telegram_notification(
|
||||||
|
telegram_id,
|
||||||
|
"✅ <b>Индексация завершена!</b>\n\n"
|
||||||
|
f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||||
|
f"📄 <b>Документ:</b> {title}\n"
|
||||||
|
f"📊 <b>Проиндексировано чанков:</b> {len(chunks)}\n\n"
|
||||||
|
f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||||
|
f"💡 <b>Теперь вы можете задавать вопросы по этому документу!</b>\n"
|
||||||
|
f"Просто напишите ваш вопрос, и я найду ответ на основе загруженного документа."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Ошибка при автоматической индексации документа {document.document_id}: {e}")
|
||||||
|
if telegram_id:
|
||||||
|
await self._send_telegram_notification(
|
||||||
|
telegram_id,
|
||||||
|
"⚠️ <b>Ошибка при индексации</b>\n\n"
|
||||||
|
f"Документ загружен, но индексация не завершена.\n"
|
||||||
|
f"Ошибка: {str(e)[:200]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return document
|
||||||
|
|
||||||
async def get_document(self, document_id: UUID) -> Document:
|
async def get_document(self, document_id: UUID) -> Document:
|
||||||
"""Получить документ по ID"""
|
"""Получить документ по ID"""
|
||||||
@ -87,8 +163,11 @@ class DocumentUseCases:
|
|||||||
document = await self.get_document(document_id)
|
document = await self.get_document(document_id)
|
||||||
|
|
||||||
collection = await self.collection_repository.get_by_id(document.collection_id)
|
collection = await self.collection_repository.get_by_id(document.collection_id)
|
||||||
if not collection or collection.owner_id != user_id:
|
if not collection:
|
||||||
raise ForbiddenError("Только владелец коллекции может изменять документы")
|
raise NotFoundError(f"Коллекция {document.collection_id} не найдена")
|
||||||
|
has_access = await self._check_collection_access(user_id, collection)
|
||||||
|
if not has_access:
|
||||||
|
raise ForbiddenError("У вас нет доступа к этой коллекции")
|
||||||
|
|
||||||
if title is not None:
|
if title is not None:
|
||||||
document.title = title
|
document.title = title
|
||||||
|
|||||||
@ -3,6 +3,7 @@ Use cases для работы с пользователями
|
|||||||
"""
|
"""
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from datetime import datetime, timedelta
|
||||||
from src.domain.entities.user import User, UserRole
|
from src.domain.entities.user import User, UserRole
|
||||||
from src.domain.repositories.user_repository import IUserRepository
|
from src.domain.repositories.user_repository import IUserRepository
|
||||||
from src.shared.exceptions import NotFoundError, ValidationError
|
from src.shared.exceptions import NotFoundError, ValidationError
|
||||||
@ -53,3 +54,26 @@ class UserUseCases:
|
|||||||
"""Получить список пользователей"""
|
"""Получить список пользователей"""
|
||||||
return await self.user_repository.list_all(skip=skip, limit=limit)
|
return await self.user_repository.list_all(skip=skip, limit=limit)
|
||||||
|
|
||||||
|
async def increment_questions_used(self, telegram_id: str) -> User:
|
||||||
|
"""Увеличить счетчик использованных вопросов"""
|
||||||
|
user = await self.user_repository.get_by_telegram_id(telegram_id)
|
||||||
|
if not user:
|
||||||
|
raise NotFoundError(f"Пользователь с telegram_id {telegram_id} не найден")
|
||||||
|
|
||||||
|
user.questions_used += 1
|
||||||
|
return await self.user_repository.update(user)
|
||||||
|
|
||||||
|
async def activate_premium(self, telegram_id: str, days: int = 30) -> User:
|
||||||
|
"""Активировать premium статус"""
|
||||||
|
user = await self.user_repository.get_by_telegram_id(telegram_id)
|
||||||
|
if not user:
|
||||||
|
raise NotFoundError(f"Пользователь с telegram_id {telegram_id} не найден")
|
||||||
|
|
||||||
|
user.is_premium = True
|
||||||
|
if user.premium_until and user.premium_until > datetime.utcnow():
|
||||||
|
user.premium_until = user.premium_until + timedelta(days=days)
|
||||||
|
else:
|
||||||
|
user.premium_until = datetime.utcnow() + timedelta(days=days)
|
||||||
|
|
||||||
|
return await self.user_repository.update(user)
|
||||||
|
|
||||||
|
|||||||
@ -1,25 +0,0 @@
|
|||||||
"""
|
|
||||||
Доменная сущность Embedding
|
|
||||||
"""
|
|
||||||
from datetime import datetime
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
class Embedding:
|
|
||||||
"""Эмбеддинг документа"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
document_id: UUID,
|
|
||||||
embedding: list[float] | None = None,
|
|
||||||
model_version: str = "",
|
|
||||||
embedding_id: UUID | None = None,
|
|
||||||
created_at: datetime | None = None
|
|
||||||
):
|
|
||||||
self.embedding_id = embedding_id or uuid4()
|
|
||||||
self.document_id = document_id
|
|
||||||
self.embedding = embedding or []
|
|
||||||
self.model_version = model_version
|
|
||||||
self.created_at = created_at or datetime.utcnow()
|
|
||||||
|
|
||||||
@ -20,12 +20,18 @@ class User:
|
|||||||
telegram_id: str,
|
telegram_id: str,
|
||||||
role: UserRole = UserRole.USER,
|
role: UserRole = UserRole.USER,
|
||||||
user_id: UUID | None = None,
|
user_id: UUID | None = None,
|
||||||
created_at: datetime | None = None
|
created_at: datetime | None = None,
|
||||||
|
is_premium: bool = False,
|
||||||
|
premium_until: datetime | None = None,
|
||||||
|
questions_used: int = 0
|
||||||
):
|
):
|
||||||
self.user_id = user_id or uuid4()
|
self.user_id = user_id or uuid4()
|
||||||
self.telegram_id = telegram_id
|
self.telegram_id = telegram_id
|
||||||
self.role = role
|
self.role = role
|
||||||
self.created_at = created_at or datetime.utcnow()
|
self.created_at = created_at or datetime.utcnow()
|
||||||
|
self.is_premium = is_premium
|
||||||
|
self.premium_until = premium_until
|
||||||
|
self.questions_used = questions_used
|
||||||
|
|
||||||
def is_admin(self) -> bool:
|
def is_admin(self) -> bool:
|
||||||
"""проверка, является ли пользователь администратором"""
|
"""проверка, является ли пользователь администратором"""
|
||||||
|
|||||||
@ -17,6 +17,10 @@ class UserModel(Base):
|
|||||||
telegram_id = Column(String, unique=True, nullable=False, index=True)
|
telegram_id = Column(String, unique=True, nullable=False, index=True)
|
||||||
role = Column(String, nullable=False, default="user")
|
role = Column(String, nullable=False, default="user")
|
||||||
created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
|
created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
|
||||||
|
is_premium = Column(Boolean, default=False, nullable=False)
|
||||||
|
premium_until = Column(DateTime, nullable=True)
|
||||||
|
questions_used = Column(Integer, default=0, nullable=False)
|
||||||
|
|
||||||
collections = relationship("CollectionModel", back_populates="owner", cascade="all, delete-orphan")
|
collections = relationship("CollectionModel", back_populates="owner", cascade="all, delete-orphan")
|
||||||
conversations = relationship("ConversationModel", back_populates="user", cascade="all, delete-orphan")
|
conversations = relationship("ConversationModel", back_populates="user", cascade="all, delete-orphan")
|
||||||
collection_accesses = relationship("CollectionAccessModel", back_populates="user", cascade="all, delete-orphan")
|
collection_accesses = relationship("CollectionAccessModel", back_populates="user", cascade="all, delete-orphan")
|
||||||
@ -49,19 +53,6 @@ class DocumentModel(Base):
|
|||||||
document_metadata = Column("metadata", JSON, nullable=True, default={})
|
document_metadata = Column("metadata", JSON, nullable=True, default={})
|
||||||
created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
|
created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
|
||||||
collection = relationship("CollectionModel", back_populates="documents")
|
collection = relationship("CollectionModel", back_populates="documents")
|
||||||
embeddings = relationship("EmbeddingModel", back_populates="document", cascade="all, delete-orphan")
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingModel(Base):
|
|
||||||
"""Модель эмбеддинга (заглушка)"""
|
|
||||||
__tablename__ = "embeddings"
|
|
||||||
|
|
||||||
embedding_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
|
||||||
document_id = Column(UUID(as_uuid=True), ForeignKey("documents.document_id"), nullable=False)
|
|
||||||
embedding = Column(JSON, nullable=True)
|
|
||||||
model_version = Column(String, nullable=True)
|
|
||||||
created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
|
|
||||||
document = relationship("DocumentModel", back_populates="embeddings")
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationModel(Base):
|
class ConversationModel(Base):
|
||||||
|
|||||||
@ -23,7 +23,10 @@ class PostgreSQLUserRepository(IUserRepository):
|
|||||||
user_id=user.user_id,
|
user_id=user.user_id,
|
||||||
telegram_id=user.telegram_id,
|
telegram_id=user.telegram_id,
|
||||||
role=user.role.value,
|
role=user.role.value,
|
||||||
created_at=user.created_at
|
created_at=user.created_at,
|
||||||
|
is_premium=user.is_premium,
|
||||||
|
premium_until=user.premium_until,
|
||||||
|
questions_used=user.questions_used
|
||||||
)
|
)
|
||||||
self.session.add(db_user)
|
self.session.add(db_user)
|
||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
@ -57,6 +60,9 @@ class PostgreSQLUserRepository(IUserRepository):
|
|||||||
|
|
||||||
db_user.telegram_id = user.telegram_id
|
db_user.telegram_id = user.telegram_id
|
||||||
db_user.role = user.role.value
|
db_user.role = user.role.value
|
||||||
|
db_user.is_premium = user.is_premium
|
||||||
|
db_user.premium_until = user.premium_until
|
||||||
|
db_user.questions_used = user.questions_used
|
||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
await self.session.refresh(db_user)
|
await self.session.refresh(db_user)
|
||||||
return self._to_entity(db_user)
|
return self._to_entity(db_user)
|
||||||
@ -90,6 +96,9 @@ class PostgreSQLUserRepository(IUserRepository):
|
|||||||
user_id=db_user.user_id,
|
user_id=db_user.user_id,
|
||||||
telegram_id=db_user.telegram_id,
|
telegram_id=db_user.telegram_id,
|
||||||
role=UserRole(db_user.role),
|
role=UserRole(db_user.role),
|
||||||
created_at=db_user.created_at
|
created_at=db_user.created_at,
|
||||||
|
is_premium=db_user.is_premium,
|
||||||
|
premium_until=db_user.premium_until,
|
||||||
|
questions_used=db_user.questions_used
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -36,6 +36,8 @@ class QdrantVectorRepository(IVectorRepository):
|
|||||||
embeddings: Sequence[list[float]],
|
embeddings: Sequence[list[float]],
|
||||||
model_version: str,
|
model_version: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
BATCH_SIZE = 100
|
||||||
|
|
||||||
points = []
|
points = []
|
||||||
for chunk, vector in zip(chunks, embeddings):
|
for chunk, vector in zip(chunks, embeddings):
|
||||||
points.append(
|
points.append(
|
||||||
@ -52,7 +54,13 @@ class QdrantVectorRepository(IVectorRepository):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.client.upsert(collection_name=self.collection_name, points=points)
|
|
||||||
|
if len(points) >= BATCH_SIZE:
|
||||||
|
self.client.upsert(collection_name=self.collection_name, points=points)
|
||||||
|
points = []
|
||||||
|
|
||||||
|
if points:
|
||||||
|
self.client.upsert(collection_name=self.collection_name, points=points)
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -2,10 +2,12 @@
|
|||||||
Админ-панель - упрощенная версия через API эндпоинты
|
Админ-панель - упрощенная версия через API эндпоинты
|
||||||
В будущем можно интегрировать полноценную админ-панель
|
В будущем можно интегрировать полноценную админ-панель
|
||||||
"""
|
"""
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from typing import List
|
from typing import List, Annotated
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from dishka.integrations.fastapi import FromDishka
|
from dishka.integrations.fastapi import FromDishka, inject
|
||||||
|
from src.domain.repositories.user_repository import IUserRepository
|
||||||
|
from src.presentation.middleware.auth_middleware import get_current_user
|
||||||
from src.presentation.schemas.user_schemas import UserResponse
|
from src.presentation.schemas.user_schemas import UserResponse
|
||||||
from src.presentation.schemas.collection_schemas import CollectionResponse
|
from src.presentation.schemas.collection_schemas import CollectionResponse
|
||||||
from src.presentation.schemas.document_schemas import DocumentResponse
|
from src.presentation.schemas.document_schemas import DocumentResponse
|
||||||
@ -19,13 +21,16 @@ router = APIRouter(prefix="/admin", tags=["admin"])
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/users", response_model=List[UserResponse])
|
@router.get("/users", response_model=List[UserResponse])
|
||||||
|
@inject
|
||||||
async def admin_list_users(
|
async def admin_list_users(
|
||||||
|
request: Request,
|
||||||
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[UserUseCases, FromDishka()],
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100
|
||||||
current_user: FromDishka[User] = FromDishka(),
|
|
||||||
use_cases: FromDishka[UserUseCases] = FromDishka()
|
|
||||||
):
|
):
|
||||||
"""Получить список всех пользователей (только для админов)"""
|
"""Получить список всех пользователей (только для админов)"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
if not current_user.is_admin():
|
if not current_user.is_admin():
|
||||||
raise HTTPException(status_code=403, detail="Требуются права администратора")
|
raise HTTPException(status_code=403, detail="Требуются права администратора")
|
||||||
users = await use_cases.list_users(skip=skip, limit=limit)
|
users = await use_cases.list_users(skip=skip, limit=limit)
|
||||||
@ -33,13 +38,16 @@ async def admin_list_users(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/collections", response_model=List[CollectionResponse])
|
@router.get("/collections", response_model=List[CollectionResponse])
|
||||||
|
@inject
|
||||||
async def admin_list_collections(
|
async def admin_list_collections(
|
||||||
|
request: Request,
|
||||||
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[CollectionUseCases, FromDishka()],
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100
|
||||||
current_user: FromDishka[User] = FromDishka(),
|
|
||||||
use_cases: FromDishka[CollectionUseCases] = FromDishka()
|
|
||||||
):
|
):
|
||||||
"""Получить список всех коллекций (только для админов)"""
|
"""Получить список всех коллекций (только для админов)"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
from src.infrastructure.database.base import AsyncSessionLocal
|
from src.infrastructure.database.base import AsyncSessionLocal
|
||||||
from src.infrastructure.repositories.postgresql.collection_repository import PostgreSQLCollectionRepository
|
from src.infrastructure.repositories.postgresql.collection_repository import PostgreSQLCollectionRepository
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|||||||
@ -2,31 +2,37 @@
|
|||||||
API роутеры для работы с коллекциями
|
API роутеры для работы с коллекциями
|
||||||
"""
|
"""
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from fastapi import APIRouter, status
|
from fastapi import APIRouter, status, Depends, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from typing import List
|
from typing import List, Annotated
|
||||||
from dishka.integrations.fastapi import FromDishka
|
from dishka.integrations.fastapi import FromDishka, inject
|
||||||
|
from src.domain.repositories.user_repository import IUserRepository
|
||||||
|
from src.presentation.middleware.auth_middleware import get_current_user
|
||||||
from src.presentation.schemas.collection_schemas import (
|
from src.presentation.schemas.collection_schemas import (
|
||||||
CollectionCreate,
|
CollectionCreate,
|
||||||
CollectionUpdate,
|
CollectionUpdate,
|
||||||
CollectionResponse,
|
CollectionResponse,
|
||||||
CollectionAccessGrant,
|
CollectionAccessGrant,
|
||||||
CollectionAccessResponse
|
CollectionAccessResponse,
|
||||||
|
CollectionAccessListResponse,
|
||||||
|
CollectionAccessUserInfo
|
||||||
)
|
)
|
||||||
from src.application.use_cases.collection_use_cases import CollectionUseCases
|
from src.application.use_cases.collection_use_cases import CollectionUseCases
|
||||||
from src.domain.entities.user import User
|
from src.domain.entities.user import User
|
||||||
from src.presentation.middleware.auth_middleware import get_current_user
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/collections", tags=["collections"])
|
router = APIRouter(prefix="/collections", tags=["collections"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=CollectionResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("", response_model=CollectionResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
@inject
|
||||||
async def create_collection(
|
async def create_collection(
|
||||||
collection_data: CollectionCreate,
|
collection_data: CollectionCreate,
|
||||||
current_user: User = FromDishka(),
|
request: Request,
|
||||||
use_cases: FromDishka[CollectionUseCases] = FromDishka()
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[CollectionUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Создать коллекцию"""
|
"""Создать коллекцию"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
collection = await use_cases.create_collection(
|
collection = await use_cases.create_collection(
|
||||||
name=collection_data.name,
|
name=collection_data.name,
|
||||||
owner_id=current_user.user_id,
|
owner_id=current_user.user_id,
|
||||||
@ -37,23 +43,36 @@ async def create_collection(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{collection_id}", response_model=CollectionResponse)
|
@router.get("/{collection_id}", response_model=CollectionResponse)
|
||||||
|
@inject
|
||||||
async def get_collection(
|
async def get_collection(
|
||||||
collection_id: UUID,
|
collection_id: UUID,
|
||||||
use_cases: FromDishka[CollectionUseCases] = FromDishka()
|
request: Request,
|
||||||
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[CollectionUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Получить коллекцию по ID"""
|
"""Получить коллекцию по ID"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
collection = await use_cases.get_collection(collection_id)
|
collection = await use_cases.get_collection(collection_id)
|
||||||
|
|
||||||
|
has_access = await use_cases.check_access(collection_id, current_user.user_id)
|
||||||
|
if not has_access:
|
||||||
|
from fastapi import HTTPException
|
||||||
|
raise HTTPException(status_code=403, detail="У вас нет доступа к этой коллекции")
|
||||||
|
|
||||||
return CollectionResponse.from_entity(collection)
|
return CollectionResponse.from_entity(collection)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{collection_id}", response_model=CollectionResponse)
|
@router.put("/{collection_id}", response_model=CollectionResponse)
|
||||||
|
@inject
|
||||||
async def update_collection(
|
async def update_collection(
|
||||||
collection_id: UUID,
|
collection_id: UUID,
|
||||||
collection_data: CollectionUpdate,
|
collection_data: CollectionUpdate,
|
||||||
current_user: User = FromDishka(),
|
request: Request,
|
||||||
use_cases: FromDishka[CollectionUseCases] = FromDishka()
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[CollectionUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Обновить коллекцию"""
|
"""Обновить коллекцию"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
collection = await use_cases.update_collection(
|
collection = await use_cases.update_collection(
|
||||||
collection_id=collection_id,
|
collection_id=collection_id,
|
||||||
user_id=current_user.user_id,
|
user_id=current_user.user_id,
|
||||||
@ -65,24 +84,30 @@ async def update_collection(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{collection_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/{collection_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
@inject
|
||||||
async def delete_collection(
|
async def delete_collection(
|
||||||
collection_id: UUID,
|
collection_id: UUID,
|
||||||
current_user: User = FromDishka(),
|
request: Request,
|
||||||
use_cases: FromDishka[CollectionUseCases] = FromDishka()
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[CollectionUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Удалить коллекцию"""
|
"""Удалить коллекцию"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
await use_cases.delete_collection(collection_id, current_user.user_id)
|
await use_cases.delete_collection(collection_id, current_user.user_id)
|
||||||
return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None)
|
return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None)
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=List[CollectionResponse])
|
@router.get("", response_model=List[CollectionResponse])
|
||||||
|
@inject
|
||||||
async def list_collections(
|
async def list_collections(
|
||||||
|
request: Request,
|
||||||
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[CollectionUseCases, FromDishka()],
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100
|
||||||
current_user: User = FromDishka(),
|
|
||||||
use_cases: FromDishka[CollectionUseCases] = FromDishka()
|
|
||||||
):
|
):
|
||||||
"""Получить список коллекций, доступных пользователю"""
|
"""Получить список коллекций, доступных пользователю"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
collections = await use_cases.list_user_collections(
|
collections = await use_cases.list_user_collections(
|
||||||
user_id=current_user.user_id,
|
user_id=current_user.user_id,
|
||||||
skip=skip,
|
skip=skip,
|
||||||
@ -92,13 +117,16 @@ async def list_collections(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/{collection_id}/access", response_model=CollectionAccessResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("/{collection_id}/access", response_model=CollectionAccessResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
@inject
|
||||||
async def grant_access(
|
async def grant_access(
|
||||||
collection_id: UUID,
|
collection_id: UUID,
|
||||||
access_data: CollectionAccessGrant,
|
access_data: CollectionAccessGrant,
|
||||||
current_user: User = FromDishka(),
|
request: Request,
|
||||||
use_cases: FromDishka[CollectionUseCases] = FromDishka()
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[CollectionUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Предоставить доступ пользователю к коллекции"""
|
"""Предоставить доступ пользователю к коллекции"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
access = await use_cases.grant_access(
|
access = await use_cases.grant_access(
|
||||||
collection_id=collection_id,
|
collection_id=collection_id,
|
||||||
user_id=access_data.user_id,
|
user_id=access_data.user_id,
|
||||||
@ -108,13 +136,91 @@ async def grant_access(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{collection_id}/access/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/{collection_id}/access/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
@inject
|
||||||
async def revoke_access(
|
async def revoke_access(
|
||||||
collection_id: UUID,
|
collection_id: UUID,
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
current_user: User = FromDishka(),
|
request: Request,
|
||||||
use_cases: FromDishka[CollectionUseCases] = FromDishka()
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[CollectionUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Отозвать доступ пользователя к коллекции"""
|
"""Отозвать доступ пользователя к коллекции"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
await use_cases.revoke_access(collection_id, user_id, current_user.user_id)
|
await use_cases.revoke_access(collection_id, user_id, current_user.user_id)
|
||||||
return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None)
|
return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{collection_id}/access", response_model=List[CollectionAccessListResponse])
|
||||||
|
@inject
|
||||||
|
async def list_collection_access(
|
||||||
|
collection_id: UUID,
|
||||||
|
request: Request,
|
||||||
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[CollectionUseCases, FromDishka()]
|
||||||
|
):
|
||||||
|
"""Получить список пользователей с доступом к коллекции"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
|
accesses = await use_cases.list_collection_access(collection_id, current_user.user_id)
|
||||||
|
result = []
|
||||||
|
for access in accesses:
|
||||||
|
user = await user_repo.get_by_id(access.user_id)
|
||||||
|
if user:
|
||||||
|
user_info = CollectionAccessUserInfo(
|
||||||
|
user_id=user.user_id,
|
||||||
|
telegram_id=user.telegram_id,
|
||||||
|
role=user.role.value,
|
||||||
|
created_at=user.created_at
|
||||||
|
)
|
||||||
|
result.append(CollectionAccessListResponse(
|
||||||
|
access_id=access.access_id,
|
||||||
|
user=user_info,
|
||||||
|
collection_id=access.collection_id,
|
||||||
|
created_at=access.created_at
|
||||||
|
))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{collection_id}/access/telegram/{telegram_id}", response_model=CollectionAccessResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
@inject
|
||||||
|
async def grant_access_by_telegram_id(
|
||||||
|
collection_id: UUID,
|
||||||
|
telegram_id: str,
|
||||||
|
request: Request,
|
||||||
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[CollectionUseCases, FromDishka()]
|
||||||
|
):
|
||||||
|
"""Предоставить доступ пользователю к коллекции по Telegram ID"""
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
|
logger.info(f"Granting access: collection_id={collection_id}, target_telegram_id={telegram_id}, owner_id={current_user.user_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
access = await use_cases.grant_access_by_telegram_id(
|
||||||
|
collection_id=collection_id,
|
||||||
|
telegram_id=telegram_id,
|
||||||
|
owner_id=current_user.user_id
|
||||||
|
)
|
||||||
|
logger.info(f"Access granted successfully: access_id={access.access_id}")
|
||||||
|
return CollectionAccessResponse.from_entity(access)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error granting access: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{collection_id}/access/telegram/{telegram_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
@inject
|
||||||
|
async def revoke_access_by_telegram_id(
|
||||||
|
collection_id: UUID,
|
||||||
|
telegram_id: str,
|
||||||
|
request: Request,
|
||||||
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[CollectionUseCases, FromDishka()]
|
||||||
|
):
|
||||||
|
"""Отозвать доступ пользователя к коллекции по Telegram ID"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
|
await use_cases.revoke_access_by_telegram_id(collection_id, telegram_id, current_user.user_id)
|
||||||
|
return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None)
|
||||||
|
|
||||||
|
|||||||
@ -2,10 +2,12 @@
|
|||||||
API роутеры для работы с беседами
|
API роутеры для работы с беседами
|
||||||
"""
|
"""
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from fastapi import APIRouter, status
|
from fastapi import APIRouter, status, Depends, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from typing import List
|
from typing import List, Annotated
|
||||||
from dishka.integrations.fastapi import FromDishka
|
from dishka.integrations.fastapi import FromDishka, inject
|
||||||
|
from src.domain.repositories.user_repository import IUserRepository
|
||||||
|
from src.presentation.middleware.auth_middleware import get_current_user
|
||||||
from src.presentation.schemas.conversation_schemas import (
|
from src.presentation.schemas.conversation_schemas import (
|
||||||
ConversationCreate,
|
ConversationCreate,
|
||||||
ConversationResponse
|
ConversationResponse
|
||||||
@ -17,12 +19,15 @@ router = APIRouter(prefix="/conversations", tags=["conversations"])
|
|||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=ConversationResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("", response_model=ConversationResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
@inject
|
||||||
async def create_conversation(
|
async def create_conversation(
|
||||||
conversation_data: ConversationCreate,
|
conversation_data: ConversationCreate,
|
||||||
current_user: FromDishka[User] = FromDishka(),
|
request: Request,
|
||||||
use_cases: FromDishka[ConversationUseCases] = FromDishka()
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[ConversationUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Создать беседу"""
|
"""Создать беседу"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
conversation = await use_cases.create_conversation(
|
conversation = await use_cases.create_conversation(
|
||||||
user_id=current_user.user_id,
|
user_id=current_user.user_id,
|
||||||
collection_id=conversation_data.collection_id
|
collection_id=conversation_data.collection_id
|
||||||
@ -31,35 +36,44 @@ async def create_conversation(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{conversation_id}", response_model=ConversationResponse)
|
@router.get("/{conversation_id}", response_model=ConversationResponse)
|
||||||
|
@inject
|
||||||
async def get_conversation(
|
async def get_conversation(
|
||||||
conversation_id: UUID,
|
conversation_id: UUID,
|
||||||
current_user: FromDishka[User] = FromDishka(),
|
request: Request,
|
||||||
use_cases: FromDishka[ConversationUseCases] = FromDishka()
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[ConversationUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Получить беседу по ID"""
|
"""Получить беседу по ID"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
conversation = await use_cases.get_conversation(conversation_id, current_user.user_id)
|
conversation = await use_cases.get_conversation(conversation_id, current_user.user_id)
|
||||||
return ConversationResponse.from_entity(conversation)
|
return ConversationResponse.from_entity(conversation)
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
@inject
|
||||||
async def delete_conversation(
|
async def delete_conversation(
|
||||||
conversation_id: UUID,
|
conversation_id: UUID,
|
||||||
current_user: FromDishka[User] = FromDishka(),
|
request: Request,
|
||||||
use_cases: FromDishka[ConversationUseCases] = FromDishka()
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[ConversationUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Удалить беседу"""
|
"""Удалить беседу"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
await use_cases.delete_conversation(conversation_id, current_user.user_id)
|
await use_cases.delete_conversation(conversation_id, current_user.user_id)
|
||||||
return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None)
|
return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None)
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=List[ConversationResponse])
|
@router.get("", response_model=List[ConversationResponse])
|
||||||
|
@inject
|
||||||
async def list_conversations(
|
async def list_conversations(
|
||||||
|
request: Request,
|
||||||
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[ConversationUseCases, FromDishka()],
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100
|
||||||
current_user: FromDishka[User] = FromDishka(),
|
|
||||||
use_cases: FromDishka[ConversationUseCases] = FromDishka()
|
|
||||||
):
|
):
|
||||||
"""Получить список бесед пользователя"""
|
"""Получить список бесед пользователя"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
conversations = await use_cases.list_user_conversations(
|
conversations = await use_cases.list_user_conversations(
|
||||||
user_id=current_user.user_id,
|
user_id=current_user.user_id,
|
||||||
skip=skip,
|
skip=skip,
|
||||||
|
|||||||
@ -2,28 +2,34 @@
|
|||||||
API роутеры для работы с документами
|
API роутеры для работы с документами
|
||||||
"""
|
"""
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from fastapi import APIRouter, status, UploadFile, File
|
from fastapi import APIRouter, status, UploadFile, File, Depends, Request, Query
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from typing import List
|
from typing import List, Annotated
|
||||||
from dishka.integrations.fastapi import FromDishka
|
from dishka.integrations.fastapi import FromDishka, inject
|
||||||
|
from src.domain.repositories.user_repository import IUserRepository
|
||||||
|
from src.presentation.middleware.auth_middleware import get_current_user
|
||||||
from src.presentation.schemas.document_schemas import (
|
from src.presentation.schemas.document_schemas import (
|
||||||
DocumentCreate,
|
DocumentCreate,
|
||||||
DocumentUpdate,
|
DocumentUpdate,
|
||||||
DocumentResponse
|
DocumentResponse
|
||||||
)
|
)
|
||||||
from src.application.use_cases.document_use_cases import DocumentUseCases
|
from src.application.use_cases.document_use_cases import DocumentUseCases
|
||||||
|
from src.application.use_cases.collection_use_cases import CollectionUseCases
|
||||||
from src.domain.entities.user import User
|
from src.domain.entities.user import User
|
||||||
|
|
||||||
router = APIRouter(prefix="/documents", tags=["documents"])
|
router = APIRouter(prefix="/documents", tags=["documents"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=DocumentResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("", response_model=DocumentResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
@inject
|
||||||
async def create_document(
|
async def create_document(
|
||||||
document_data: DocumentCreate,
|
document_data: DocumentCreate,
|
||||||
current_user: FromDishka[User] = FromDishka(),
|
request: Request,
|
||||||
use_cases: FromDishka[DocumentUseCases] = FromDishka()
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[DocumentUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Создать документ"""
|
"""Создать документ"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
document = await use_cases.create_document(
|
document = await use_cases.create_document(
|
||||||
collection_id=document_data.collection_id,
|
collection_id=document_data.collection_id,
|
||||||
title=document_data.title,
|
title=document_data.title,
|
||||||
@ -34,13 +40,16 @@ async def create_document(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/upload", response_model=DocumentResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("/upload", response_model=DocumentResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
@inject
|
||||||
async def upload_document(
|
async def upload_document(
|
||||||
collection_id: UUID,
|
collection_id: UUID = Query(...),
|
||||||
file: UploadFile = File(...),
|
request: Request = None,
|
||||||
current_user: FromDishka[User] = FromDishka(),
|
user_repo: Annotated[IUserRepository, FromDishka()] = None,
|
||||||
use_cases: FromDishka[DocumentUseCases] = FromDishka()
|
use_cases: Annotated[DocumentUseCases, FromDishka()] = None,
|
||||||
|
file: UploadFile = File(...)
|
||||||
):
|
):
|
||||||
"""Загрузить и распарсить PDF документ или изображение"""
|
"""Загрузить и распарсить PDF документ или изображение"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
if not file.filename:
|
if not file.filename:
|
||||||
raise JSONResponse(
|
raise JSONResponse(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
@ -60,15 +69,17 @@ async def upload_document(
|
|||||||
collection_id=collection_id,
|
collection_id=collection_id,
|
||||||
file=file.file,
|
file=file.file,
|
||||||
filename=file.filename,
|
filename=file.filename,
|
||||||
user_id=current_user.user_id
|
user_id=current_user.user_id,
|
||||||
|
telegram_id=current_user.telegram_id
|
||||||
)
|
)
|
||||||
return DocumentResponse.from_entity(document)
|
return DocumentResponse.from_entity(document)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{document_id}", response_model=DocumentResponse)
|
@router.get("/{document_id}", response_model=DocumentResponse)
|
||||||
|
@inject
|
||||||
async def get_document(
|
async def get_document(
|
||||||
document_id: UUID,
|
document_id: UUID,
|
||||||
use_cases: FromDishka[DocumentUseCases] = FromDishka()
|
use_cases: Annotated[DocumentUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Получить документ по ID"""
|
"""Получить документ по ID"""
|
||||||
document = await use_cases.get_document(document_id)
|
document = await use_cases.get_document(document_id)
|
||||||
@ -76,13 +87,16 @@ async def get_document(
|
|||||||
|
|
||||||
|
|
||||||
@router.put("/{document_id}", response_model=DocumentResponse)
|
@router.put("/{document_id}", response_model=DocumentResponse)
|
||||||
|
@inject
|
||||||
async def update_document(
|
async def update_document(
|
||||||
document_id: UUID,
|
document_id: UUID,
|
||||||
document_data: DocumentUpdate,
|
document_data: DocumentUpdate,
|
||||||
current_user: FromDishka[User] = FromDishka(),
|
request: Request,
|
||||||
use_cases: FromDishka[DocumentUseCases] = FromDishka()
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[DocumentUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Обновить документ"""
|
"""Обновить документ"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
document = await use_cases.update_document(
|
document = await use_cases.update_document(
|
||||||
document_id=document_id,
|
document_id=document_id,
|
||||||
user_id=current_user.user_id,
|
user_id=current_user.user_id,
|
||||||
@ -94,24 +108,39 @@ async def update_document(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{document_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/{document_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
@inject
|
||||||
async def delete_document(
|
async def delete_document(
|
||||||
document_id: UUID,
|
document_id: UUID,
|
||||||
current_user: FromDishka[User] = FromDishka(),
|
request: Request,
|
||||||
use_cases: FromDishka[DocumentUseCases] = FromDishka()
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[DocumentUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Удалить документ"""
|
"""Удалить документ"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
await use_cases.delete_document(document_id, current_user.user_id)
|
await use_cases.delete_document(document_id, current_user.user_id)
|
||||||
return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None)
|
return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=None)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/collection/{collection_id}", response_model=List[DocumentResponse])
|
@router.get("/collection/{collection_id}", response_model=List[DocumentResponse])
|
||||||
|
@inject
|
||||||
async def list_collection_documents(
|
async def list_collection_documents(
|
||||||
collection_id: UUID,
|
collection_id: UUID,
|
||||||
|
request: Request,
|
||||||
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[DocumentUseCases, FromDishka()],
|
||||||
|
collection_use_cases: Annotated[CollectionUseCases, FromDishka()],
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100
|
||||||
use_cases: FromDishka[DocumentUseCases] = FromDishka()
|
|
||||||
):
|
):
|
||||||
"""Получить документы коллекции"""
|
"""Получить документы коллекции"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
|
|
||||||
|
|
||||||
|
has_access = await collection_use_cases.check_access(collection_id, current_user.user_id)
|
||||||
|
if not has_access:
|
||||||
|
from fastapi import HTTPException
|
||||||
|
raise HTTPException(status_code=403, detail="У вас нет доступа к этой коллекции")
|
||||||
|
|
||||||
documents = await use_cases.list_collection_documents(
|
documents = await use_cases.list_collection_documents(
|
||||||
collection_id=collection_id,
|
collection_id=collection_id,
|
||||||
skip=skip,
|
skip=skip,
|
||||||
|
|||||||
@ -2,10 +2,12 @@
|
|||||||
API роутеры для работы с сообщениями
|
API роутеры для работы с сообщениями
|
||||||
"""
|
"""
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from fastapi import APIRouter, status
|
from fastapi import APIRouter, status, Depends, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from typing import List
|
from typing import List, Annotated
|
||||||
from dishka.integrations.fastapi import FromDishka
|
from dishka.integrations.fastapi import FromDishka, inject
|
||||||
|
from src.domain.repositories.user_repository import IUserRepository
|
||||||
|
from src.presentation.middleware.auth_middleware import get_current_user
|
||||||
from src.presentation.schemas.message_schemas import (
|
from src.presentation.schemas.message_schemas import (
|
||||||
MessageCreate,
|
MessageCreate,
|
||||||
MessageUpdate,
|
MessageUpdate,
|
||||||
@ -18,12 +20,15 @@ router = APIRouter(prefix="/messages", tags=["messages"])
|
|||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=MessageResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("", response_model=MessageResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
@inject
|
||||||
async def create_message(
|
async def create_message(
|
||||||
message_data: MessageCreate,
|
message_data: MessageCreate,
|
||||||
current_user: FromDishka[User] = FromDishka(),
|
request: Request,
|
||||||
use_cases: FromDishka[MessageUseCases] = FromDishka()
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[MessageUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Создать сообщение"""
|
"""Создать сообщение"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
message = await use_cases.create_message(
|
message = await use_cases.create_message(
|
||||||
conversation_id=message_data.conversation_id,
|
conversation_id=message_data.conversation_id,
|
||||||
content=message_data.content,
|
content=message_data.content,
|
||||||
@ -35,9 +40,10 @@ async def create_message(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{message_id}", response_model=MessageResponse)
|
@router.get("/{message_id}", response_model=MessageResponse)
|
||||||
|
@inject
|
||||||
async def get_message(
|
async def get_message(
|
||||||
message_id: UUID,
|
message_id: UUID,
|
||||||
use_cases: FromDishka[MessageUseCases] = FromDishka()
|
use_cases: Annotated[MessageUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Получить сообщение по ID"""
|
"""Получить сообщение по ID"""
|
||||||
message = await use_cases.get_message(message_id)
|
message = await use_cases.get_message(message_id)
|
||||||
@ -45,10 +51,11 @@ async def get_message(
|
|||||||
|
|
||||||
|
|
||||||
@router.put("/{message_id}", response_model=MessageResponse)
|
@router.put("/{message_id}", response_model=MessageResponse)
|
||||||
|
@inject
|
||||||
async def update_message(
|
async def update_message(
|
||||||
message_id: UUID,
|
message_id: UUID,
|
||||||
message_data: MessageUpdate,
|
message_data: MessageUpdate,
|
||||||
use_cases: FromDishka[MessageUseCases] = FromDishka()
|
use_cases: Annotated[MessageUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Обновить сообщение"""
|
"""Обновить сообщение"""
|
||||||
message = await use_cases.update_message(
|
message = await use_cases.update_message(
|
||||||
@ -60,9 +67,10 @@ async def update_message(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{message_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/{message_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
@inject
|
||||||
async def delete_message(
|
async def delete_message(
|
||||||
message_id: UUID,
|
message_id: UUID,
|
||||||
use_cases: FromDishka[MessageUseCases] = FromDishka()
|
use_cases: Annotated[MessageUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Удалить сообщение"""
|
"""Удалить сообщение"""
|
||||||
await use_cases.delete_message(message_id)
|
await use_cases.delete_message(message_id)
|
||||||
@ -70,14 +78,17 @@ async def delete_message(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/conversation/{conversation_id}", response_model=List[MessageResponse])
|
@router.get("/conversation/{conversation_id}", response_model=List[MessageResponse])
|
||||||
|
@inject
|
||||||
async def list_conversation_messages(
|
async def list_conversation_messages(
|
||||||
conversation_id: UUID,
|
conversation_id: UUID,
|
||||||
|
request: Request,
|
||||||
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[MessageUseCases, FromDishka()],
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100
|
||||||
current_user: FromDishka[User] = FromDishka(),
|
|
||||||
use_cases: FromDishka[MessageUseCases] = FromDishka()
|
|
||||||
):
|
):
|
||||||
"""Получить сообщения беседы"""
|
"""Получить сообщения беседы"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
messages = await use_cases.list_conversation_messages(
|
messages = await use_cases.list_conversation_messages(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
user_id=current_user.user_id,
|
user_id=current_user.user_id,
|
||||||
|
|||||||
@ -1,39 +1,31 @@
|
|||||||
"""
|
"""
|
||||||
API для RAG: индексация документов и ответы на вопросы
|
API для RAG: ответы на вопросы
|
||||||
"""
|
"""
|
||||||
from fastapi import APIRouter, status
|
from fastapi import APIRouter, status, Request
|
||||||
from dishka.integrations.fastapi import FromDishka
|
from typing import Annotated
|
||||||
|
from dishka.integrations.fastapi import FromDishka, inject
|
||||||
|
from src.domain.repositories.user_repository import IUserRepository
|
||||||
|
from src.presentation.middleware.auth_middleware import get_current_user
|
||||||
from src.presentation.schemas.rag_schemas import (
|
from src.presentation.schemas.rag_schemas import (
|
||||||
QuestionRequest,
|
QuestionRequest,
|
||||||
RAGAnswer,
|
RAGAnswer,
|
||||||
IndexDocumentRequest,
|
|
||||||
IndexDocumentResponse,
|
|
||||||
)
|
)
|
||||||
from src.application.use_cases.rag_use_cases import RAGUseCases
|
from src.application.use_cases.rag_use_cases import RAGUseCases
|
||||||
from src.domain.entities.user import User
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/rag", tags=["rag"])
|
router = APIRouter(prefix="/rag", tags=["rag"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("/index", response_model=IndexDocumentResponse, status_code=status.HTTP_200_OK)
|
|
||||||
async def index_document(
|
|
||||||
body: IndexDocumentRequest,
|
|
||||||
use_cases: FromDishka[RAGUseCases] = FromDishka(),
|
|
||||||
current_user: FromDishka[User] = FromDishka(),
|
|
||||||
):
|
|
||||||
"""Индексирование идет через чанкирование, далее эмбеддинг и загрузка в векторную бд"""
|
|
||||||
result = await use_cases.index_document(body.document_id)
|
|
||||||
return IndexDocumentResponse(**result)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/question", response_model=RAGAnswer, status_code=status.HTTP_200_OK)
|
@router.post("/question", response_model=RAGAnswer, status_code=status.HTTP_200_OK)
|
||||||
|
@inject
|
||||||
async def ask_question(
|
async def ask_question(
|
||||||
body: QuestionRequest,
|
body: QuestionRequest,
|
||||||
use_cases: FromDishka[RAGUseCases] = FromDishka(),
|
request: Request,
|
||||||
current_user: FromDishka[User] = FromDishka(),
|
user_repo: Annotated[IUserRepository, FromDishka()],
|
||||||
|
use_cases: Annotated[RAGUseCases, FromDishka()],
|
||||||
):
|
):
|
||||||
"""Отвечает на вопрос, используя RAG в рамках беседы"""
|
"""Отвечает на вопрос, используя RAG в рамках беседы"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
result = await use_cases.ask_question(
|
result = await use_cases.ask_question(
|
||||||
conversation_id=body.conversation_id,
|
conversation_id=body.conversation_id,
|
||||||
user_id=current_user.user_id,
|
user_id=current_user.user_id,
|
||||||
|
|||||||
@ -2,10 +2,12 @@
|
|||||||
API роутеры для работы с пользователями
|
API роутеры для работы с пользователями
|
||||||
"""
|
"""
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from fastapi import APIRouter, status
|
from fastapi import APIRouter, status, Depends, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from typing import List
|
from typing import List, Annotated
|
||||||
from dishka.integrations.fastapi import FromDishka
|
from dishka.integrations.fastapi import FromDishka, inject
|
||||||
|
from src.domain.repositories.user_repository import IUserRepository
|
||||||
|
from src.presentation.middleware.auth_middleware import get_current_user
|
||||||
from src.presentation.schemas.user_schemas import UserCreate, UserUpdate, UserResponse
|
from src.presentation.schemas.user_schemas import UserCreate, UserUpdate, UserResponse
|
||||||
from src.application.use_cases.user_use_cases import UserUseCases
|
from src.application.use_cases.user_use_cases import UserUseCases
|
||||||
from src.domain.entities.user import User
|
from src.domain.entities.user import User
|
||||||
@ -14,9 +16,10 @@ router = APIRouter(prefix="/users", tags=["users"])
|
|||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
@inject
|
||||||
async def create_user(
|
async def create_user(
|
||||||
user_data: UserCreate,
|
user_data: UserCreate,
|
||||||
use_cases: FromDishka[UserUseCases] = FromDishka()
|
use_cases: Annotated[UserUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Создать пользователя"""
|
"""Создать пользователя"""
|
||||||
user = await use_cases.create_user(
|
user = await use_cases.create_user(
|
||||||
@ -27,17 +30,59 @@ async def create_user(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserResponse)
|
@router.get("/me", response_model=UserResponse)
|
||||||
|
@inject
|
||||||
async def get_current_user_info(
|
async def get_current_user_info(
|
||||||
current_user: FromDishka[User] = FromDishka()
|
request,
|
||||||
|
user_repo: Annotated[IUserRepository, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Получить информацию о текущем пользователе"""
|
"""Получить информацию о текущем пользователе"""
|
||||||
|
current_user = await get_current_user(request, user_repo)
|
||||||
return UserResponse.from_entity(current_user)
|
return UserResponse.from_entity(current_user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/telegram/{telegram_id}", response_model=UserResponse)
|
||||||
|
@inject
|
||||||
|
async def get_user_by_telegram_id(
|
||||||
|
telegram_id: str,
|
||||||
|
use_cases: Annotated[UserUseCases, FromDishka()]
|
||||||
|
):
|
||||||
|
"""Получить пользователя по Telegram ID"""
|
||||||
|
user = await use_cases.get_user_by_telegram_id(telegram_id)
|
||||||
|
if not user:
|
||||||
|
from fastapi import HTTPException
|
||||||
|
raise HTTPException(status_code=404, detail=f"Пользователь с telegram_id {telegram_id} не найден")
|
||||||
|
return UserResponse.from_entity(user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/telegram/{telegram_id}/increment-questions", response_model=UserResponse)
|
||||||
|
@inject
|
||||||
|
async def increment_questions(
|
||||||
|
telegram_id: str,
|
||||||
|
use_cases: Annotated[UserUseCases, FromDishka()]
|
||||||
|
):
|
||||||
|
"""Увеличить счетчик использованных вопросов"""
|
||||||
|
user = await use_cases.increment_questions_used(telegram_id)
|
||||||
|
return UserResponse.from_entity(user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/telegram/{telegram_id}/activate-premium", response_model=UserResponse)
|
||||||
|
@inject
|
||||||
|
async def activate_premium(
|
||||||
|
|
||||||
|
use_cases: Annotated[UserUseCases, FromDishka()],
|
||||||
|
telegram_id: str,
|
||||||
|
days: int = 30,
|
||||||
|
):
|
||||||
|
"""Активировать premium статус"""
|
||||||
|
user = await use_cases.activate_premium(telegram_id, days=days)
|
||||||
|
return UserResponse.from_entity(user)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}", response_model=UserResponse)
|
@router.get("/{user_id}", response_model=UserResponse)
|
||||||
|
@inject
|
||||||
async def get_user(
|
async def get_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
use_cases: FromDishka[UserUseCases] = FromDishka()
|
use_cases: Annotated[UserUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Получить пользователя по ID"""
|
"""Получить пользователя по ID"""
|
||||||
user = await use_cases.get_user(user_id)
|
user = await use_cases.get_user(user_id)
|
||||||
@ -45,10 +90,11 @@ async def get_user(
|
|||||||
|
|
||||||
|
|
||||||
@router.put("/{user_id}", response_model=UserResponse)
|
@router.put("/{user_id}", response_model=UserResponse)
|
||||||
|
@inject
|
||||||
async def update_user(
|
async def update_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
user_data: UserUpdate,
|
user_data: UserUpdate,
|
||||||
use_cases: FromDishka[UserUseCases] = FromDishka()
|
use_cases: Annotated[UserUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Обновить пользователя"""
|
"""Обновить пользователя"""
|
||||||
user = await use_cases.update_user(
|
user = await use_cases.update_user(
|
||||||
@ -60,9 +106,10 @@ async def update_user(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
@inject
|
||||||
async def delete_user(
|
async def delete_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
use_cases: FromDishka[UserUseCases] = FromDishka()
|
use_cases: Annotated[UserUseCases, FromDishka()]
|
||||||
):
|
):
|
||||||
"""Удалить пользователя"""
|
"""Удалить пользователя"""
|
||||||
await use_cases.delete_user(user_id)
|
await use_cases.delete_user(user_id)
|
||||||
@ -70,10 +117,11 @@ async def delete_user(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=List[UserResponse])
|
@router.get("", response_model=List[UserResponse])
|
||||||
|
@inject
|
||||||
async def list_users(
|
async def list_users(
|
||||||
|
use_cases: Annotated[UserUseCases, FromDishka()],
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100
|
||||||
use_cases: FromDishka[UserUseCases] = FromDishka()
|
|
||||||
):
|
):
|
||||||
"""Получить список пользователей"""
|
"""Получить список пользователей"""
|
||||||
users = await use_cases.list_users(skip=skip, limit=limit)
|
users = await use_cases.list_users(skip=skip, limit=limit)
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
if '/app' not in sys.path:
|
backend_dir = Path(__file__).parent.parent.parent
|
||||||
sys.path.insert(0, '/app')
|
if str(backend_dir) not in sys.path:
|
||||||
|
sys.path.insert(0, str(backend_dir))
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@ -20,15 +23,17 @@ from src.infrastructure.database.base import engine, Base
|
|||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""Управление жизненным циклом приложения"""
|
"""Управление жизненным циклом приложения"""
|
||||||
container = create_container()
|
|
||||||
setup_dishka(container, app)
|
|
||||||
try:
|
try:
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Примечание при создании таблиц: {e}")
|
print(f"Примечание при создании таблиц: {e}")
|
||||||
yield
|
yield
|
||||||
await container.close()
|
if hasattr(app.state, 'container') and hasattr(app.state.container, 'close'):
|
||||||
|
if asyncio.iscoroutinefunction(app.state.container.close):
|
||||||
|
await app.state.container.close()
|
||||||
|
else:
|
||||||
|
app.state.container.close()
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
@ -39,6 +44,10 @@ app = FastAPI(
|
|||||||
lifespan=lifespan
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
|
container = create_container()
|
||||||
|
setup_dishka(container, app)
|
||||||
|
app.state.container = container
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=settings.CORS_ORIGINS,
|
allow_origins=settings.CORS_ORIGINS,
|
||||||
|
|||||||
@ -75,3 +75,22 @@ class CollectionAccessResponse(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class CollectionAccessUserInfo(BaseModel):
|
||||||
|
"""Информация о пользователе с доступом"""
|
||||||
|
user_id: UUID
|
||||||
|
telegram_id: str
|
||||||
|
role: str
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class CollectionAccessListResponse(BaseModel):
|
||||||
|
"""Схема ответа со списком доступа"""
|
||||||
|
access_id: UUID
|
||||||
|
user: CollectionAccessUserInfo
|
||||||
|
collection_id: UUID
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|||||||
@ -26,10 +26,3 @@ class RAGAnswer(BaseModel):
|
|||||||
usage: dict[str, Any] = {}
|
usage: dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
class IndexDocumentRequest(BaseModel):
|
|
||||||
document_id: UUID
|
|
||||||
|
|
||||||
|
|
||||||
class IndexDocumentResponse(BaseModel):
|
|
||||||
chunks_indexed: int
|
|
||||||
|
|
||||||
|
|||||||
@ -30,6 +30,9 @@ class UserResponse(BaseModel):
|
|||||||
telegram_id: str
|
telegram_id: str
|
||||||
role: UserRole
|
role: UserRole
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
is_premium: bool = False
|
||||||
|
premium_until: datetime | None = None
|
||||||
|
questions_used: int = 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_entity(cls, user: "User") -> "UserResponse":
|
def from_entity(cls, user: "User") -> "UserResponse":
|
||||||
@ -38,7 +41,10 @@ class UserResponse(BaseModel):
|
|||||||
user_id=user.user_id,
|
user_id=user.user_id,
|
||||||
telegram_id=user.telegram_id,
|
telegram_id=user.telegram_id,
|
||||||
role=user.role,
|
role=user.role,
|
||||||
created_at=user.created_at
|
created_at=user.created_at,
|
||||||
|
is_premium=user.is_premium,
|
||||||
|
premium_until=user.premium_until,
|
||||||
|
questions_used=user.questions_used
|
||||||
)
|
)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|||||||
@ -7,18 +7,19 @@ from typing import Optional
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
|
"""Настройки (загружаются из .env автоматически)"""
|
||||||
|
|
||||||
POSTGRES_HOST: str
|
POSTGRES_HOST: str = "localhost"
|
||||||
POSTGRES_PORT: int
|
POSTGRES_PORT: int = 5432
|
||||||
POSTGRES_USER: str
|
POSTGRES_USER: str = "postgres"
|
||||||
POSTGRES_PASSWORD: str
|
POSTGRES_PASSWORD: str = "postgres"
|
||||||
POSTGRES_DB: str
|
POSTGRES_DB: str = "lawyer_ai"
|
||||||
|
|
||||||
QDRANT_HOST: str
|
QDRANT_HOST: str = "localhost"
|
||||||
QDRANT_PORT: int
|
QDRANT_PORT: int = 6333
|
||||||
|
|
||||||
REDIS_HOST: str
|
REDIS_HOST: str = "localhost"
|
||||||
REDIS_PORT: int
|
REDIS_PORT: int = 6379
|
||||||
|
|
||||||
TELEGRAM_BOT_TOKEN: Optional[str] = None
|
TELEGRAM_BOT_TOKEN: Optional[str] = None
|
||||||
YANDEX_OCR_API_KEY: Optional[str] = None
|
YANDEX_OCR_API_KEY: Optional[str] = None
|
||||||
@ -29,11 +30,12 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
APP_NAME: str = "ИИ-юрист"
|
APP_NAME: str = "ИИ-юрист"
|
||||||
DEBUG: bool = False
|
DEBUG: bool = False
|
||||||
SECRET_KEY: str
|
SECRET_KEY: str = "your-secret-key-change-in-production"
|
||||||
CORS_ORIGINS: list[str] = ["*"]
|
CORS_ORIGINS: list[str] = ["*"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def database_url(self) -> str:
|
def database_url(self) -> str:
|
||||||
|
"""Вычисляемый URL подключения"""
|
||||||
return f"postgresql://{self.POSTGRES_USER}:{self.POSTGRES_PASSWORD}@{self.POSTGRES_HOST}:{self.POSTGRES_PORT}/{self.POSTGRES_DB}"
|
return f"postgresql://{self.POSTGRES_USER}:{self.POSTGRES_PASSWORD}@{self.POSTGRES_HOST}:{self.POSTGRES_PORT}/{self.POSTGRES_DB}"
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from dishka import Container, Provider, Scope, provide
|
from dishka import Container, Provider, Scope, provide, make_async_container
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
@ -39,13 +39,9 @@ from src.application.use_cases.rag_use_cases import RAGUseCases
|
|||||||
|
|
||||||
class DatabaseProvider(Provider):
|
class DatabaseProvider(Provider):
|
||||||
@provide(scope=Scope.REQUEST)
|
@provide(scope=Scope.REQUEST)
|
||||||
@asynccontextmanager
|
|
||||||
async def get_db(self) -> AsyncSession:
|
async def get_db(self) -> AsyncSession:
|
||||||
async with AsyncSessionLocal() as session:
|
session = AsyncSessionLocal()
|
||||||
try:
|
return session
|
||||||
yield session
|
|
||||||
finally:
|
|
||||||
await session.close()
|
|
||||||
|
|
||||||
|
|
||||||
class RepositoryProvider(Provider):
|
class RepositoryProvider(Provider):
|
||||||
@ -77,7 +73,7 @@ class RepositoryProvider(Provider):
|
|||||||
class ServiceProvider(Provider):
|
class ServiceProvider(Provider):
|
||||||
@provide(scope=Scope.APP)
|
@provide(scope=Scope.APP)
|
||||||
def get_redis_client(self) -> RedisClient:
|
def get_redis_client(self) -> RedisClient:
|
||||||
return RedisClient()
|
return RedisClient(host=settings.REDIS_HOST, port=settings.REDIS_PORT)
|
||||||
|
|
||||||
@provide(scope=Scope.APP)
|
@provide(scope=Scope.APP)
|
||||||
def get_cache_service(self, redis_client: RedisClient) -> CacheService:
|
def get_cache_service(self, redis_client: RedisClient) -> CacheService:
|
||||||
@ -95,8 +91,6 @@ class ServiceProvider(Provider):
|
|||||||
def get_parser_service(self, ocr_service: YandexOCRService) -> DocumentParserService:
|
def get_parser_service(self, ocr_service: YandexOCRService) -> DocumentParserService:
|
||||||
return DocumentParserService(ocr_service)
|
return DocumentParserService(ocr_service)
|
||||||
|
|
||||||
|
|
||||||
class VectorServiceProvider(Provider):
|
|
||||||
@provide(scope=Scope.APP)
|
@provide(scope=Scope.APP)
|
||||||
def get_qdrant_client(self) -> QdrantClient:
|
def get_qdrant_client(self) -> QdrantClient:
|
||||||
return QdrantClient(host=settings.QDRANT_HOST, port=settings.QDRANT_PORT)
|
return QdrantClient(host=settings.QDRANT_HOST, port=settings.QDRANT_PORT)
|
||||||
@ -134,12 +128,6 @@ class VectorServiceProvider(Provider):
|
|||||||
splitter=text_splitter,
|
splitter=text_splitter,
|
||||||
)
|
)
|
||||||
|
|
||||||
class AuthProvider(Provider):
|
|
||||||
@provide(scope=Scope.REQUEST)
|
|
||||||
async def get_current_user(self, request: Request, user_repo: IUserRepository) -> User:
|
|
||||||
from src.presentation.middleware.auth_middleware import get_current_user
|
|
||||||
return await get_current_user(request, user_repo)
|
|
||||||
|
|
||||||
|
|
||||||
class UseCaseProvider(Provider):
|
class UseCaseProvider(Provider):
|
||||||
@provide(scope=Scope.REQUEST)
|
@provide(scope=Scope.REQUEST)
|
||||||
@ -163,9 +151,11 @@ class UseCaseProvider(Provider):
|
|||||||
self,
|
self,
|
||||||
document_repo: IDocumentRepository,
|
document_repo: IDocumentRepository,
|
||||||
collection_repo: ICollectionRepository,
|
collection_repo: ICollectionRepository,
|
||||||
parser_service: DocumentParserService
|
access_repo: ICollectionAccessRepository,
|
||||||
|
parser_service: DocumentParserService,
|
||||||
|
rag_service: RAGService
|
||||||
) -> DocumentUseCases:
|
) -> DocumentUseCases:
|
||||||
return DocumentUseCases(document_repo, collection_repo, parser_service)
|
return DocumentUseCases(document_repo, collection_repo, access_repo, parser_service, rag_service)
|
||||||
|
|
||||||
@provide(scope=Scope.REQUEST)
|
@provide(scope=Scope.REQUEST)
|
||||||
def get_conversation_use_cases(
|
def get_conversation_use_cases(
|
||||||
@ -197,12 +187,10 @@ class UseCaseProvider(Provider):
|
|||||||
|
|
||||||
|
|
||||||
def create_container() -> Container:
|
def create_container() -> Container:
|
||||||
container = Container()
|
return make_async_container(
|
||||||
container.add_provider(DatabaseProvider())
|
DatabaseProvider(),
|
||||||
container.add_provider(RepositoryProvider())
|
RepositoryProvider(),
|
||||||
container.add_provider(ServiceProvider())
|
ServiceProvider(),
|
||||||
container.add_provider(AuthProvider())
|
UseCaseProvider()
|
||||||
container.add_provider(UseCaseProvider())
|
)
|
||||||
container.add_provider(VectorServiceProvider())
|
|
||||||
return container
|
|
||||||
|
|
||||||
|
|||||||
@ -1,93 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
from sqlalchemy import create_engine, inspect
|
|
||||||
from sqlalchemy.orm import declarative_base, Session
|
|
||||||
from sqlalchemy import Column, String, DateTime, Boolean, Integer, Text
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
DB_PATH = os.path.join(BASE_DIR, 'data', 'bot.db')
|
|
||||||
DATABASE_URL = f"sqlite:///{DB_PATH}"
|
|
||||||
|
|
||||||
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
|
|
||||||
|
|
||||||
if os.path.exists(DB_PATH):
|
|
||||||
try:
|
|
||||||
temp_engine = create_engine(DATABASE_URL)
|
|
||||||
inspector = inspect(temp_engine)
|
|
||||||
tables = inspector.get_table_names()
|
|
||||||
if tables:
|
|
||||||
sys.exit(0)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
choice = input("Перезаписать БД? (y/N): ")
|
|
||||||
if choice.lower() != 'y':
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
engine = create_engine(DATABASE_URL, echo=False)
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
class UserModel(Base):
|
|
||||||
__tablename__ = "users"
|
|
||||||
|
|
||||||
user_id = Column("user_id", String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
|
||||||
telegram_id = Column("telegram_id", String(100), nullable=False, unique=True)
|
|
||||||
created_at = Column("created_at", DateTime, default=datetime.utcnow, nullable=False)
|
|
||||||
role = Column("role", String(20), default="user", nullable=False)
|
|
||||||
is_premium = Column(Boolean, default=False, nullable=False)
|
|
||||||
premium_until = Column(DateTime, nullable=True)
|
|
||||||
questions_used = Column(Integer, default=0, nullable=False)
|
|
||||||
username = Column(String(100), nullable=True)
|
|
||||||
first_name = Column(String(100), nullable=True)
|
|
||||||
last_name = Column(String(100), nullable=True)
|
|
||||||
|
|
||||||
class PaymentModel(Base):
|
|
||||||
__tablename__ = "payments"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
payment_id = Column(String(36), default=lambda: str(uuid.uuid4()), nullable=False, unique=True)
|
|
||||||
user_id = Column(Integer, nullable=False)
|
|
||||||
amount = Column(String(20), nullable=False)
|
|
||||||
currency = Column(String(3), default="RUB", nullable=False)
|
|
||||||
status = Column(String(20), default="pending", nullable=False)
|
|
||||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
|
||||||
yookassa_payment_id = Column(String(100), unique=True, nullable=True)
|
|
||||||
description = Column(Text, nullable=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
Base.metadata.create_all(bind=engine)
|
|
||||||
|
|
||||||
session = Session(bind=engine)
|
|
||||||
|
|
||||||
existing = session.query(UserModel).filter_by(telegram_id="123456789").first()
|
|
||||||
if not existing:
|
|
||||||
test_user = UserModel(
|
|
||||||
telegram_id="123456789",
|
|
||||||
username="test_user",
|
|
||||||
first_name="Test",
|
|
||||||
last_name="User",
|
|
||||||
is_premium=True
|
|
||||||
)
|
|
||||||
session.add(test_user)
|
|
||||||
|
|
||||||
existing_payment = session.query(PaymentModel).filter_by(yookassa_payment_id="test_yoo_001").first()
|
|
||||||
if not existing_payment:
|
|
||||||
test_payment = PaymentModel(
|
|
||||||
user_id=123456789,
|
|
||||||
amount="500.00",
|
|
||||||
status="succeeded",
|
|
||||||
description="Test payment",
|
|
||||||
yookassa_payment_id="test_yoo_001"
|
|
||||||
)
|
|
||||||
session.add(test_payment)
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
session.close()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Ошибка: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
@ -1,31 +0,0 @@
|
|||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
try:
|
|
||||||
from tg_bot.infrastructure.database.database import engine, Base
|
|
||||||
from tg_bot.infrastructure.database import models
|
|
||||||
|
|
||||||
print("СОЗДАНИЕ ТАБЛИЦ БАЗЫ ДАННЫХ")
|
|
||||||
Base.metadata.create_all(bind=engine)
|
|
||||||
|
|
||||||
print("Таблицы успешно созданы!")
|
|
||||||
print(" • users")
|
|
||||||
print(" • payments")
|
|
||||||
print()
|
|
||||||
print(f"База данных: {engine.url}")
|
|
||||||
|
|
||||||
db_path = "data/bot.db"
|
|
||||||
if os.path.exists(db_path):
|
|
||||||
size = os.path.getsize(db_path)
|
|
||||||
print(f"Размер файла: {size} байт")
|
|
||||||
else:
|
|
||||||
print("Файл БД не найден, но таблицы созданы")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Ошибка: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
print("=" * 50)
|
|
||||||
@ -70,6 +70,7 @@ services:
|
|||||||
DEEPSEEK_API_KEY: ${DEEPSEEK_API_KEY}
|
DEEPSEEK_API_KEY: ${DEEPSEEK_API_KEY}
|
||||||
DEEPSEEK_API_URL: ${DEEPSEEK_API_URL:-https://api.deepseek.com/v1/chat/completions}
|
DEEPSEEK_API_URL: ${DEEPSEEK_API_URL:-https://api.deepseek.com/v1/chat/completions}
|
||||||
YANDEX_OCR_API_KEY: ${YANDEX_OCR_API_KEY}
|
YANDEX_OCR_API_KEY: ${YANDEX_OCR_API_KEY}
|
||||||
|
BACKEND_URL: ${BACKEND_URL:-http://backend:8000/api/v1}
|
||||||
DEBUG: "true"
|
DEBUG: "true"
|
||||||
depends_on:
|
depends_on:
|
||||||
- postgres
|
- postgres
|
||||||
|
|||||||
23
pytest.ini
Normal file
23
pytest.ini
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
[pytest]
|
||||||
|
testpaths = tests
|
||||||
|
python_files = test_*.py
|
||||||
|
python_classes = Test*
|
||||||
|
python_functions = test_*
|
||||||
|
asyncio_mode = auto
|
||||||
|
addopts =
|
||||||
|
-v
|
||||||
|
--strict-markers
|
||||||
|
--tb=short
|
||||||
|
--cov=backend/src
|
||||||
|
--cov=tg_bot
|
||||||
|
--cov-report=term-missing
|
||||||
|
--cov-report=xml
|
||||||
|
--cov-fail-under=65
|
||||||
|
--ignore=venv
|
||||||
|
--ignore=.venv
|
||||||
|
markers =
|
||||||
|
unit: Unit tests
|
||||||
|
integration: Integration tests
|
||||||
|
metrics: Metrics tests
|
||||||
|
slow: Slow running tests
|
||||||
|
|
||||||
@ -1,11 +0,0 @@
|
|||||||
pydantic>=2.5.0
|
|
||||||
pydantic-settings>=2.1.0
|
|
||||||
python-dotenv>=1.0.0
|
|
||||||
aiogram>=3.10.0
|
|
||||||
sqlalchemy>=2.0.0
|
|
||||||
aiosqlite>=0.19.0
|
|
||||||
httpx>=0.25.0
|
|
||||||
yookassa>=2.4.0
|
|
||||||
fastapi>=0.104.0
|
|
||||||
uvicorn>=0.24.0
|
|
||||||
python-multipart>=0.0.6
|
|
||||||
94
tests/README.md
Normal file
94
tests/README.md
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
# Тесты для проекта BetterCallPraskovia
|
||||||
|
|
||||||
|
## Структура тестов
|
||||||
|
|
||||||
|
```
|
||||||
|
tests/
|
||||||
|
├── conftest.py
|
||||||
|
├── unit/
|
||||||
|
│ ├── test_rag_service.py
|
||||||
|
│ ├── test_user_service.py
|
||||||
|
│ ├── test_deepseek_client.py
|
||||||
|
│ ├── test_document_use_cases.py
|
||||||
|
│ └── test_collection_use_cases.py
|
||||||
|
├── integration/
|
||||||
|
│ └── test_rag_integration.py
|
||||||
|
└── metrics/
|
||||||
|
└── test_hit_at_5.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Установка зависимостей
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r tests/requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## Запуск тестов
|
||||||
|
|
||||||
|
### Все тесты
|
||||||
|
```bash
|
||||||
|
pytest
|
||||||
|
```
|
||||||
|
|
||||||
|
### Только юнит-тесты
|
||||||
|
```bash
|
||||||
|
pytest tests/unit/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Только интеграционные тесты
|
||||||
|
```bash
|
||||||
|
pytest tests/integration/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Только тесты метрик
|
||||||
|
```bash
|
||||||
|
pytest tests/metrics/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Только тесты tg_bot
|
||||||
|
```bash
|
||||||
|
pytest tests/unit/test_rag_service.py tests/unit/test_user_service.py tests/unit/test_deepseek_client.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### С покрытием кода
|
||||||
|
```bash
|
||||||
|
pytest --cov=backend/src --cov=tg_bot --cov-report=html
|
||||||
|
```
|
||||||
|
|
||||||
|
### С минимальным покрытием 65%
|
||||||
|
```bash
|
||||||
|
pytest --cov-fail-under=65
|
||||||
|
```
|
||||||
|
|
||||||
|
## Метрика hit@5
|
||||||
|
|
||||||
|
Проверка что в топ-5 релевантных документов есть хотя бы 1 нужный документ.
|
||||||
|
|
||||||
|
- **hit@5 = 1**, если есть хотя бы 1 релевантный документ в топ-5
|
||||||
|
- **hit@5 = 0**, если нет релевантных документов в топ-5
|
||||||
|
|
||||||
|
Среднее значение hit@5 для всех запросов должно быть **> 50%**
|
||||||
|
|
||||||
|
## Покрытие кода
|
||||||
|
|
||||||
|
**coverage ≥ 65%**
|
||||||
|
|
||||||
|
Проверка покрытия:
|
||||||
|
```bash
|
||||||
|
pytest --cov=backend/src --cov=tg_bot --cov-report=term-missing --cov-fail-under=65
|
||||||
|
```
|
||||||
|
|
||||||
|
## Маркеры тестов
|
||||||
|
|
||||||
|
- `@pytest.mark.unit` - юнит-тесты
|
||||||
|
- `@pytest.mark.integration` - интеграционные тесты
|
||||||
|
- `@pytest.mark.metrics` - тесты метрик
|
||||||
|
- `@pytest.mark.slow` - медленные тесты
|
||||||
|
|
||||||
|
Запуск по маркерам:
|
||||||
|
```bash
|
||||||
|
pytest -m unit
|
||||||
|
pytest -m integration
|
||||||
|
pytest -m metrics
|
||||||
|
```
|
||||||
|
|
||||||
170
tests/conftest.py
Normal file
170
tests/conftest.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from uuid import uuid4
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_document():
|
||||||
|
try:
|
||||||
|
from backend.src.domain.entities.document import Document
|
||||||
|
except ImportError:
|
||||||
|
from src.domain.entities.document import Document
|
||||||
|
return Document(
|
||||||
|
collection_id=uuid4(),
|
||||||
|
title="Тестовый документ",
|
||||||
|
content="Содержание документа",
|
||||||
|
metadata={"type": "law", "article": "123"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_documents_list():
|
||||||
|
try:
|
||||||
|
from backend.src.domain.entities.document import Document
|
||||||
|
except ImportError:
|
||||||
|
from src.domain.entities.document import Document
|
||||||
|
collection_id = uuid4()
|
||||||
|
return [
|
||||||
|
Document(
|
||||||
|
collection_id=collection_id,
|
||||||
|
title=f"Документ {i}",
|
||||||
|
content=f"Содержание документа {i} ",
|
||||||
|
metadata={"relevance_score": 0.9 - i * 0.1}
|
||||||
|
)
|
||||||
|
for i in range(10)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_relevant_documents():
|
||||||
|
return [
|
||||||
|
{"document_id": str(uuid4()), "title": "Гражданский кодекс РФ", "relevance": True},
|
||||||
|
{"document_id": str(uuid4()), "title": "Трудовой кодекс РФ", "relevance": True},
|
||||||
|
{"document_id": str(uuid4()), "title": "Налоговый кодекс РФ", "relevance": True},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_rag_response():
|
||||||
|
return {
|
||||||
|
"answer": "Тестовый ответ на вопрос",
|
||||||
|
"sources": [
|
||||||
|
{"title": "Документ 1", "collection": "Коллекция 1", "document_id": str(uuid4())},
|
||||||
|
{"title": "Документ 2", "collection": "Коллекция 1", "document_id": str(uuid4())},
|
||||||
|
{"title": "Документ 3", "collection": "Коллекция 2", "document_id": str(uuid4())},
|
||||||
|
],
|
||||||
|
"usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_collection():
|
||||||
|
try:
|
||||||
|
from backend.src.domain.entities.collection import Collection
|
||||||
|
except ImportError:
|
||||||
|
from src.domain.entities.collection import Collection
|
||||||
|
return Collection(
|
||||||
|
name="Коллекция",
|
||||||
|
owner_id=uuid4(),
|
||||||
|
description="Описание коллекции",
|
||||||
|
is_public=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_user():
|
||||||
|
try:
|
||||||
|
from backend.src.domain.entities.user import User, UserRole
|
||||||
|
except ImportError:
|
||||||
|
from src.domain.entities.user import User, UserRole
|
||||||
|
return User(
|
||||||
|
telegram_id="123456789",
|
||||||
|
role=UserRole.USER
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_document_repository():
|
||||||
|
repository = AsyncMock()
|
||||||
|
repository.get_by_id = AsyncMock()
|
||||||
|
repository.create = AsyncMock()
|
||||||
|
repository.update = AsyncMock()
|
||||||
|
repository.delete = AsyncMock(return_value=True)
|
||||||
|
repository.list_by_collection = AsyncMock(return_value=[])
|
||||||
|
return repository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_collection_repository():
|
||||||
|
repository = AsyncMock()
|
||||||
|
repository.get_by_id = AsyncMock()
|
||||||
|
repository.create = AsyncMock()
|
||||||
|
repository.update = AsyncMock()
|
||||||
|
repository.delete = AsyncMock(return_value=True)
|
||||||
|
repository.list_by_owner = AsyncMock(return_value=[])
|
||||||
|
repository.list_public = AsyncMock(return_value=[])
|
||||||
|
return repository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_user_repository():
|
||||||
|
repository = AsyncMock()
|
||||||
|
repository.get_by_id = AsyncMock()
|
||||||
|
repository.get_by_telegram_id = AsyncMock()
|
||||||
|
repository.create = AsyncMock()
|
||||||
|
repository.update = AsyncMock()
|
||||||
|
repository.delete = AsyncMock(return_value=True)
|
||||||
|
return repository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_deepseek_client():
|
||||||
|
client = AsyncMock()
|
||||||
|
client.chat_completion = AsyncMock(return_value={
|
||||||
|
"content": "Ответ от DeepSeek",
|
||||||
|
"usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}
|
||||||
|
})
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_queries_with_ground_truth():
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"query": "Какие права имеет работник при увольнении?",
|
||||||
|
"relevant_document_ids": [str(uuid4()), str(uuid4())],
|
||||||
|
"expected_top5_contains": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"query": "Как оформить договор купли-продажи?",
|
||||||
|
"relevant_document_ids": [str(uuid4())],
|
||||||
|
"expected_top5_contains": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"query": "Какие налоги платит ИП?",
|
||||||
|
"relevant_document_ids": [str(uuid4()), str(uuid4()), str(uuid4())],
|
||||||
|
"expected_top5_contains": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"query": "Права потребителя при возврате товара",
|
||||||
|
"relevant_document_ids": [str(uuid4())],
|
||||||
|
"expected_top5_contains": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"query": "Как расторгнуть брак?",
|
||||||
|
"relevant_document_ids": [str(uuid4())],
|
||||||
|
"expected_top5_contains": True
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_aiohttp_session():
|
||||||
|
session = AsyncMock()
|
||||||
|
session.get = AsyncMock()
|
||||||
|
session.post = AsyncMock()
|
||||||
|
session.__aenter__ = AsyncMock(return_value=session)
|
||||||
|
session.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
return session
|
||||||
146
tests/integration/test_rag_integration.py
Normal file
146
tests/integration/test_rag_integration.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
from uuid import uuid4
|
||||||
|
import aiohttp
|
||||||
|
from tg_bot.application.services.rag_service import RAGService
|
||||||
|
from tests.metrics.test_hit_at_5 import calculate_hit_at_5, calculate_average_hit_at_5
|
||||||
|
|
||||||
|
|
||||||
|
class TestRAGIntegration:
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def rag_service(self):
|
||||||
|
service = RAGService()
|
||||||
|
from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient
|
||||||
|
service.deepseek_client = DeepSeekClient()
|
||||||
|
return service
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rag_service_search_documents_real_implementation(self, rag_service):
|
||||||
|
user_telegram_id = "123456789"
|
||||||
|
query = "трудовой договор"
|
||||||
|
user_uuid = str(uuid4())
|
||||||
|
collection_id = str(uuid4())
|
||||||
|
|
||||||
|
mock_user_response = AsyncMock()
|
||||||
|
mock_user_response.status = 200
|
||||||
|
mock_user_response.json = AsyncMock(return_value={"user_id": user_uuid})
|
||||||
|
|
||||||
|
mock_collections_response = AsyncMock()
|
||||||
|
mock_collections_response.status = 200
|
||||||
|
mock_collections_response.json = AsyncMock(return_value=[
|
||||||
|
{"collection_id": collection_id, "name": "Законы"}
|
||||||
|
])
|
||||||
|
|
||||||
|
mock_documents_response = AsyncMock()
|
||||||
|
mock_documents_response.status = 200
|
||||||
|
mock_documents_response.json = AsyncMock(return_value=[
|
||||||
|
{
|
||||||
|
"document_id": str(uuid4()),
|
||||||
|
"title": "Трудовой кодекс",
|
||||||
|
"content": "Содержание о трудовых договорах"
|
||||||
|
}
|
||||||
|
])
|
||||||
|
|
||||||
|
with patch('aiohttp.ClientSession') as mock_session_class:
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||||
|
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_session.get = AsyncMock(side_effect=[
|
||||||
|
mock_user_response,
|
||||||
|
mock_collections_response,
|
||||||
|
mock_documents_response
|
||||||
|
])
|
||||||
|
mock_session_class.return_value = mock_session
|
||||||
|
|
||||||
|
result = await rag_service.search_documents_in_collections(user_telegram_id, query)
|
||||||
|
|
||||||
|
assert isinstance(result, list)
|
||||||
|
if result:
|
||||||
|
assert "document_id" in result[0]
|
||||||
|
assert "title" in result[0]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rag_service_generate_answer_real_implementation(self, rag_service):
|
||||||
|
question = "Какие права имеет работник?"
|
||||||
|
user_telegram_id = "123456789"
|
||||||
|
|
||||||
|
mock_documents = [
|
||||||
|
{
|
||||||
|
"document_id": str(uuid4()),
|
||||||
|
"title": "Трудовой кодекс",
|
||||||
|
"content": "Работник имеет право на...",
|
||||||
|
"collection_name": "Законы"
|
||||||
|
}
|
||||||
|
for _ in range(3)
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \
|
||||||
|
patch.object(rag_service.deepseek_client, 'chat_completion') as mock_deepseek:
|
||||||
|
|
||||||
|
mock_search.return_value = mock_documents
|
||||||
|
mock_deepseek.return_value = {
|
||||||
|
"content": "Работник имеет следующие права...",
|
||||||
|
"usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
|
||||||
|
|
||||||
|
assert "answer" in result
|
||||||
|
assert "sources" in result
|
||||||
|
assert "usage" in result
|
||||||
|
assert len(result["sources"]) <= 5
|
||||||
|
assert result["answer"] != ""
|
||||||
|
|
||||||
|
for source in result["sources"]:
|
||||||
|
assert "title" in source
|
||||||
|
assert "collection" in source
|
||||||
|
assert "document_id" in source
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rag_service_limits_to_top5_documents(self, rag_service):
|
||||||
|
question = "Тестовый вопрос"
|
||||||
|
user_telegram_id = "123456789"
|
||||||
|
|
||||||
|
many_documents = [
|
||||||
|
{
|
||||||
|
"document_id": str(uuid4()),
|
||||||
|
"title": f"Документ {i}",
|
||||||
|
"content": f"Содержание {i}",
|
||||||
|
"collection_name": "Коллекция"
|
||||||
|
}
|
||||||
|
for i in range(10)
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \
|
||||||
|
patch.object(rag_service.deepseek_client, 'chat_completion') as mock_deepseek:
|
||||||
|
|
||||||
|
mock_search.return_value = many_documents
|
||||||
|
mock_deepseek.return_value = {
|
||||||
|
"content": "Ответ",
|
||||||
|
"usage": {}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
|
||||||
|
|
||||||
|
assert len(result["sources"]) == 5
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rag_service_handles_empty_search_results(self, rag_service):
|
||||||
|
question = "Вопрос без документов"
|
||||||
|
user_telegram_id = "123456789"
|
||||||
|
|
||||||
|
with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \
|
||||||
|
patch.object(rag_service.deepseek_client, 'chat_completion') as mock_deepseek:
|
||||||
|
|
||||||
|
mock_search.return_value = []
|
||||||
|
mock_deepseek.return_value = {
|
||||||
|
"content": "Ответ",
|
||||||
|
"usage": {}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
|
||||||
|
|
||||||
|
assert result["sources"] == []
|
||||||
|
assert "Релевантные документы не найдены" in result.get("answer", "") or \
|
||||||
|
result["answer"] == "No relevant documents found"
|
||||||
133
tests/metrics/test_hit_at_5.py
Normal file
133
tests/metrics/test_hit_at_5.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
import pytest
|
||||||
|
from uuid import uuid4
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_hit_at_5(retrieved_document_ids: List[str], relevant_document_ids: List[str]) -> int:
|
||||||
|
if not retrieved_document_ids or not relevant_document_ids:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
top5_ids = set(retrieved_document_ids[:5])
|
||||||
|
relevant_ids = set(relevant_document_ids)
|
||||||
|
|
||||||
|
return 1 if top5_ids.intersection(relevant_ids) else 0
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_average_hit_at_5(results: List[int]) -> float:
|
||||||
|
if not results:
|
||||||
|
return 0.0
|
||||||
|
return sum(results) / len(results)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHitAt5Metric:
|
||||||
|
|
||||||
|
def test_hit_at_5_returns_1_when_relevant_document_in_top5(self):
|
||||||
|
relevant_ids = [str(uuid4()), str(uuid4())]
|
||||||
|
retrieved_ids = [
|
||||||
|
str(uuid4()),
|
||||||
|
relevant_ids[0],
|
||||||
|
str(uuid4()),
|
||||||
|
str(uuid4()),
|
||||||
|
str(uuid4())
|
||||||
|
]
|
||||||
|
|
||||||
|
result = calculate_hit_at_5(retrieved_ids, relevant_ids)
|
||||||
|
assert result == 1
|
||||||
|
|
||||||
|
def test_hit_at_5_returns_0_when_no_relevant_document_in_top5(self):
|
||||||
|
relevant_ids = [str(uuid4()), str(uuid4())]
|
||||||
|
retrieved_ids = [
|
||||||
|
str(uuid4()),
|
||||||
|
str(uuid4()),
|
||||||
|
str(uuid4()),
|
||||||
|
str(uuid4()),
|
||||||
|
str(uuid4())
|
||||||
|
]
|
||||||
|
|
||||||
|
result = calculate_hit_at_5(retrieved_ids, relevant_ids)
|
||||||
|
assert result == 0
|
||||||
|
|
||||||
|
def test_hit_at_5_returns_1_when_multiple_relevant_documents(self):
|
||||||
|
relevant_ids = [str(uuid4()), str(uuid4()), str(uuid4())]
|
||||||
|
retrieved_ids = [
|
||||||
|
relevant_ids[0],
|
||||||
|
str(uuid4()),
|
||||||
|
relevant_ids[1],
|
||||||
|
str(uuid4()),
|
||||||
|
relevant_ids[2]
|
||||||
|
]
|
||||||
|
|
||||||
|
result = calculate_hit_at_5(retrieved_ids, relevant_ids)
|
||||||
|
assert result == 1
|
||||||
|
|
||||||
|
def test_hit_at_5_handles_empty_lists(self):
|
||||||
|
result = calculate_hit_at_5([], [str(uuid4())])
|
||||||
|
assert result == 0
|
||||||
|
|
||||||
|
result = calculate_hit_at_5([str(uuid4())], [])
|
||||||
|
assert result == 0
|
||||||
|
|
||||||
|
result = calculate_hit_at_5([], [])
|
||||||
|
assert result == 0
|
||||||
|
|
||||||
|
def test_hit_at_5_only_checks_top5(self):
|
||||||
|
relevant_ids = [str(uuid4())]
|
||||||
|
retrieved_ids = [
|
||||||
|
str(uuid4()),
|
||||||
|
str(uuid4()),
|
||||||
|
str(uuid4()),
|
||||||
|
str(uuid4()),
|
||||||
|
str(uuid4()),
|
||||||
|
relevant_ids[0]
|
||||||
|
]
|
||||||
|
|
||||||
|
result = calculate_hit_at_5(retrieved_ids, relevant_ids)
|
||||||
|
assert result == 0
|
||||||
|
|
||||||
|
def test_calculate_average_hit_at_5(self):
|
||||||
|
results = [1, 1, 0, 1, 0, 1, 0, 1, 1, 1]
|
||||||
|
|
||||||
|
average = calculate_average_hit_at_5(results)
|
||||||
|
assert average == 0.7
|
||||||
|
|
||||||
|
def test_calculate_average_hit_at_5_all_ones(self):
|
||||||
|
results = [1, 1, 1, 1, 1]
|
||||||
|
|
||||||
|
average = calculate_average_hit_at_5(results)
|
||||||
|
assert average == 1.0
|
||||||
|
|
||||||
|
def test_calculate_average_hit_at_5_all_zeros(self):
|
||||||
|
results = [0, 0, 0, 0, 0]
|
||||||
|
|
||||||
|
average = calculate_average_hit_at_5(results)
|
||||||
|
assert average == 0.0
|
||||||
|
|
||||||
|
def test_calculate_average_hit_at_5_empty_list(self):
|
||||||
|
average = calculate_average_hit_at_5([])
|
||||||
|
assert average == 0.0
|
||||||
|
|
||||||
|
def test_hit_at_5_quality_threshold(self):
|
||||||
|
results = [1] * 60 + [0] * 40
|
||||||
|
|
||||||
|
average = calculate_average_hit_at_5(results)
|
||||||
|
assert average > 0.5, f"Качество {average} должно быть > 0.5"
|
||||||
|
assert average == 0.6
|
||||||
|
|
||||||
|
def test_hit_at_5_quality_below_threshold(self):
|
||||||
|
results = [1] * 40 + [0] * 60
|
||||||
|
|
||||||
|
average = calculate_average_hit_at_5(results)
|
||||||
|
assert average < 0.5, f"Качество {average} должно быть < 0.5"
|
||||||
|
assert average == 0.4
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("hit_count,total,expected_quality", [
|
||||||
|
(51, 100, 0.51),
|
||||||
|
(50, 100, 0.50),
|
||||||
|
(60, 100, 0.60),
|
||||||
|
(75, 100, 0.75),
|
||||||
|
(100, 100, 1.0),
|
||||||
|
])
|
||||||
|
def test_hit_at_5_various_qualities(self, hit_count, total, expected_quality):
|
||||||
|
results = [1] * hit_count + [0] * (total - hit_count)
|
||||||
|
average = calculate_average_hit_at_5(results)
|
||||||
|
assert average == expected_quality
|
||||||
8
tests/requirements.txt
Normal file
8
tests/requirements.txt
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
pytest==7.4.3
|
||||||
|
pytest-asyncio==0.21.1
|
||||||
|
pytest-cov==4.1.0
|
||||||
|
pytest-mock==3.12.0
|
||||||
|
pytest-timeout==2.2.0
|
||||||
|
httpx>=0.25.2
|
||||||
|
aiohttp>=3.9.1
|
||||||
|
|
||||||
132
tests/unit/test_collection_use_cases.py
Normal file
132
tests/unit/test_collection_use_cases.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
import pytest
|
||||||
|
from uuid import uuid4
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
try:
|
||||||
|
from backend.src.application.use_cases.collection_use_cases import CollectionUseCases
|
||||||
|
from backend.src.shared.exceptions import NotFoundError, ForbiddenError
|
||||||
|
from backend.src.domain.repositories.collection_access_repository import ICollectionAccessRepository
|
||||||
|
except ImportError:
|
||||||
|
from src.application.use_cases.collection_use_cases import CollectionUseCases
|
||||||
|
from src.shared.exceptions import NotFoundError, ForbiddenError
|
||||||
|
from src.domain.repositories.collection_access_repository import ICollectionAccessRepository
|
||||||
|
|
||||||
|
|
||||||
|
class TestCollectionUseCases:
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def collection_use_cases(self, mock_collection_repository, mock_user_repository):
|
||||||
|
mock_access_repository = AsyncMock()
|
||||||
|
mock_access_repository.get_by_user_and_collection = AsyncMock(return_value=None)
|
||||||
|
mock_access_repository.create = AsyncMock()
|
||||||
|
mock_access_repository.delete_by_user_and_collection = AsyncMock(return_value=True)
|
||||||
|
mock_access_repository.list_by_user = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
return CollectionUseCases(
|
||||||
|
collection_repository=mock_collection_repository,
|
||||||
|
access_repository=mock_access_repository,
|
||||||
|
user_repository=mock_user_repository
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_collection_success(self, collection_use_cases, mock_user,
|
||||||
|
mock_collection_repository, mock_user_repository):
|
||||||
|
owner_id = uuid4()
|
||||||
|
mock_user_repository.get_by_id = AsyncMock(return_value=mock_user)
|
||||||
|
mock_collection_repository.create = AsyncMock(return_value=mock_user)
|
||||||
|
|
||||||
|
result = await collection_use_cases.create_collection(
|
||||||
|
name="Тестовая коллекция",
|
||||||
|
owner_id=owner_id,
|
||||||
|
description="Описание",
|
||||||
|
is_public=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
mock_user_repository.get_by_id.assert_called_once_with(owner_id)
|
||||||
|
mock_collection_repository.create.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_collection_user_not_found(self, collection_use_cases, mock_user_repository):
|
||||||
|
owner_id = uuid4()
|
||||||
|
mock_user_repository.get_by_id = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await collection_use_cases.create_collection(
|
||||||
|
name="Коллекция",
|
||||||
|
owner_id=owner_id
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_collection_success(self, collection_use_cases, mock_collection, mock_collection_repository):
|
||||||
|
collection_id = uuid4()
|
||||||
|
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||||
|
|
||||||
|
result = await collection_use_cases.get_collection(collection_id)
|
||||||
|
|
||||||
|
assert result == mock_collection
|
||||||
|
mock_collection_repository.get_by_id.assert_called_once_with(collection_id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_collection_not_found(self, collection_use_cases, mock_collection_repository):
|
||||||
|
collection_id = uuid4()
|
||||||
|
mock_collection_repository.get_by_id = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await collection_use_cases.get_collection(collection_id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_collection_success(self, collection_use_cases, mock_collection, mock_collection_repository):
|
||||||
|
collection_id = uuid4()
|
||||||
|
user_id = uuid4()
|
||||||
|
mock_collection.owner_id = user_id
|
||||||
|
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||||
|
mock_collection_repository.update = AsyncMock(return_value=mock_collection)
|
||||||
|
|
||||||
|
result = await collection_use_cases.update_collection(
|
||||||
|
collection_id=collection_id,
|
||||||
|
user_id=user_id,
|
||||||
|
name="Обновленное название"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert mock_collection.name == "Обновленное название"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_collection_forbidden(self, collection_use_cases, mock_collection, mock_collection_repository):
|
||||||
|
collection_id = uuid4()
|
||||||
|
user_id = uuid4()
|
||||||
|
owner_id = uuid4()
|
||||||
|
mock_collection.owner_id = owner_id
|
||||||
|
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||||
|
|
||||||
|
with pytest.raises(ForbiddenError):
|
||||||
|
await collection_use_cases.update_collection(
|
||||||
|
collection_id=collection_id,
|
||||||
|
user_id=user_id,
|
||||||
|
name="Название"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_access_owner(self, collection_use_cases, mock_collection, mock_collection_repository):
|
||||||
|
collection_id = uuid4()
|
||||||
|
user_id = uuid4()
|
||||||
|
mock_collection.owner_id = user_id
|
||||||
|
mock_collection.is_public = False
|
||||||
|
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||||
|
|
||||||
|
result = await collection_use_cases.check_access(collection_id, user_id)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_access_public(self, collection_use_cases, mock_collection, mock_collection_repository):
|
||||||
|
collection_id = uuid4()
|
||||||
|
user_id = uuid4()
|
||||||
|
owner_id = uuid4()
|
||||||
|
mock_collection.owner_id = owner_id
|
||||||
|
mock_collection.is_public = True
|
||||||
|
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||||
|
|
||||||
|
result = await collection_use_cases.check_access(collection_id, user_id)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
114
tests/unit/test_deepseek_client.py
Normal file
114
tests/unit/test_deepseek_client.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
import httpx
|
||||||
|
from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient
|
||||||
|
from tg_bot.infrastructure.external.deepseek_client import DeepSeekAPIError
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeepSeekClient:
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def deepseek_client(self):
|
||||||
|
return DeepSeekClient(api_key="test_key", api_url="https://api.test.com/v1/chat/completions")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_completion_success(self, deepseek_client):
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "Тестовый вопрос"}
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_response_data = {
|
||||||
|
"choices": [{
|
||||||
|
"message": {
|
||||||
|
"content": "Тестовый ответ от DeepSeek"
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 10,
|
||||||
|
"completion_tokens": 20,
|
||||||
|
"total_tokens": 30
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch('httpx.AsyncClient') as mock_client:
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = mock_response_data
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client_instance = AsyncMock()
|
||||||
|
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
|
||||||
|
mock_client_instance.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_client_instance.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.return_value = mock_client_instance
|
||||||
|
|
||||||
|
result = await deepseek_client.chat_completion(messages)
|
||||||
|
|
||||||
|
assert "content" in result
|
||||||
|
assert result["content"] == "Тестовый ответ от DeepSeek"
|
||||||
|
assert "usage" in result
|
||||||
|
assert result["usage"]["total_tokens"] == 30
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_completion_no_api_key(self):
|
||||||
|
client = DeepSeekClient(api_key=None)
|
||||||
|
messages = [{"role": "user", "content": "Вопрос"}]
|
||||||
|
|
||||||
|
result = await client.chat_completion(messages)
|
||||||
|
|
||||||
|
assert "content" in result
|
||||||
|
assert "DEEPSEEK_API_KEY" in result["content"] or "не установлен" in result["content"]
|
||||||
|
assert result["usage"]["total_tokens"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_completion_api_error(self, deepseek_client):
|
||||||
|
import httpx
|
||||||
|
messages = [{"role": "user", "content": "Вопрос"}]
|
||||||
|
|
||||||
|
with patch('httpx.AsyncClient') as mock_client:
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status_code = 401
|
||||||
|
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||||
|
"Unauthorized", request=MagicMock(), response=mock_response
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client_instance = AsyncMock()
|
||||||
|
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
|
||||||
|
mock_client_instance.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_client_instance.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.return_value = mock_client_instance
|
||||||
|
|
||||||
|
with pytest.raises(DeepSeekAPIError):
|
||||||
|
await deepseek_client.chat_completion(messages)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_completion_with_parameters(self, deepseek_client):
|
||||||
|
messages = [{"role": "user", "content": "Вопрос"}]
|
||||||
|
|
||||||
|
mock_response_data = {
|
||||||
|
"choices": [{"message": {"content": "Ответ"}}],
|
||||||
|
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch('httpx.AsyncClient') as mock_client:
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = mock_response_data
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client_instance = AsyncMock()
|
||||||
|
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
|
||||||
|
mock_client_instance.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_client_instance.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.return_value = mock_client_instance
|
||||||
|
|
||||||
|
result = await deepseek_client.chat_completion(
|
||||||
|
messages,
|
||||||
|
model="deepseek-chat",
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=100
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["content"] == "Ответ"
|
||||||
|
call_args = mock_client_instance.post.call_args
|
||||||
|
assert call_args is not None
|
||||||
141
tests/unit/test_document_use_cases.py
Normal file
141
tests/unit/test_document_use_cases.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
import pytest
|
||||||
|
from uuid import uuid4
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
try:
|
||||||
|
from backend.src.application.use_cases.document_use_cases import DocumentUseCases
|
||||||
|
from backend.src.shared.exceptions import NotFoundError, ForbiddenError
|
||||||
|
from backend.src.application.services.document_parser_service import DocumentParserService
|
||||||
|
except ImportError:
|
||||||
|
from src.application.use_cases.document_use_cases import DocumentUseCases
|
||||||
|
from src.shared.exceptions import NotFoundError, ForbiddenError
|
||||||
|
from src.application.services.document_parser_service import DocumentParserService
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentUseCases:
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def document_use_cases(self, mock_document_repository, mock_collection_repository):
|
||||||
|
mock_parser = AsyncMock()
|
||||||
|
mock_parser.parse_pdf = AsyncMock(return_value=("Парсенный документ", "Содержание"))
|
||||||
|
|
||||||
|
return DocumentUseCases(
|
||||||
|
document_repository=mock_document_repository,
|
||||||
|
collection_repository=mock_collection_repository,
|
||||||
|
parser_service=mock_parser
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_document_success(self, document_use_cases, mock_collection, mock_document_repository, mock_collection_repository):
|
||||||
|
collection_id = uuid4()
|
||||||
|
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||||
|
mock_document_repository.create = AsyncMock(return_value=mock_collection)
|
||||||
|
|
||||||
|
result = await document_use_cases.create_document(
|
||||||
|
collection_id=collection_id,
|
||||||
|
title="Тестовый документ",
|
||||||
|
content="Содержание",
|
||||||
|
metadata={"type": "law"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
mock_collection_repository.get_by_id.assert_called_once_with(collection_id)
|
||||||
|
mock_document_repository.create.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_document_collection_not_found(self, document_use_cases, mock_collection_repository):
|
||||||
|
collection_id = uuid4()
|
||||||
|
mock_collection_repository.get_by_id = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await document_use_cases.create_document(
|
||||||
|
collection_id=collection_id,
|
||||||
|
title="Документ",
|
||||||
|
content="Содержание"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_document_success(self, document_use_cases, mock_document, mock_document_repository):
|
||||||
|
document_id = uuid4()
|
||||||
|
mock_document_repository.get_by_id = AsyncMock(return_value=mock_document)
|
||||||
|
|
||||||
|
result = await document_use_cases.get_document(document_id)
|
||||||
|
|
||||||
|
assert result == mock_document
|
||||||
|
mock_document_repository.get_by_id.assert_called_once_with(document_id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_document_not_found(self, document_use_cases, mock_document_repository):
|
||||||
|
document_id = uuid4()
|
||||||
|
mock_document_repository.get_by_id = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await document_use_cases.get_document(document_id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_document_success(self, document_use_cases, mock_document, mock_collection,
|
||||||
|
mock_document_repository, mock_collection_repository):
|
||||||
|
document_id = uuid4()
|
||||||
|
user_id = uuid4()
|
||||||
|
mock_document.collection_id = uuid4()
|
||||||
|
mock_collection.owner_id = user_id
|
||||||
|
|
||||||
|
mock_document_repository.get_by_id = AsyncMock(return_value=mock_document)
|
||||||
|
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||||
|
mock_document_repository.update = AsyncMock(return_value=mock_document)
|
||||||
|
|
||||||
|
result = await document_use_cases.update_document(
|
||||||
|
document_id=document_id,
|
||||||
|
user_id=user_id,
|
||||||
|
title="Обновленное название"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert mock_document.title == "Обновленное название"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_document_forbidden(self, document_use_cases, mock_document, mock_collection,
|
||||||
|
mock_document_repository, mock_collection_repository):
|
||||||
|
document_id = uuid4()
|
||||||
|
user_id = uuid4()
|
||||||
|
owner_id = uuid4()
|
||||||
|
mock_document.collection_id = uuid4()
|
||||||
|
mock_collection.owner_id = owner_id
|
||||||
|
|
||||||
|
mock_document_repository.get_by_id = AsyncMock(return_value=mock_document)
|
||||||
|
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||||
|
|
||||||
|
with pytest.raises(ForbiddenError):
|
||||||
|
await document_use_cases.update_document(
|
||||||
|
document_id=document_id,
|
||||||
|
user_id=user_id,
|
||||||
|
title="Название"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_document_success(self, document_use_cases, mock_document, mock_collection,
|
||||||
|
mock_document_repository, mock_collection_repository):
|
||||||
|
document_id = uuid4()
|
||||||
|
user_id = uuid4()
|
||||||
|
mock_document.collection_id = uuid4()
|
||||||
|
mock_collection.owner_id = user_id
|
||||||
|
|
||||||
|
mock_document_repository.get_by_id = AsyncMock(return_value=mock_document)
|
||||||
|
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||||
|
mock_document_repository.delete = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
result = await document_use_cases.delete_document(document_id, user_id)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_document_repository.delete.assert_called_once_with(document_id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_collection_documents(self, document_use_cases, mock_collection, mock_documents_list,
|
||||||
|
mock_collection_repository, mock_document_repository):
|
||||||
|
collection_id = uuid4()
|
||||||
|
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||||
|
mock_document_repository.list_by_collection = AsyncMock(return_value=mock_documents_list)
|
||||||
|
|
||||||
|
result = await document_use_cases.list_collection_documents(collection_id, skip=0, limit=10)
|
||||||
|
|
||||||
|
assert len(result) == len(mock_documents_list)
|
||||||
|
mock_document_repository.list_by_collection.assert_called_once_with(collection_id, skip=0, limit=10)
|
||||||
171
tests/unit/test_rag_service.py
Normal file
171
tests/unit/test_rag_service.py
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
from uuid import uuid4
|
||||||
|
from tg_bot.application.services.rag_service import RAGService
|
||||||
|
|
||||||
|
|
||||||
|
class TestRAGService:
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def rag_service(self):
|
||||||
|
service = RAGService()
|
||||||
|
from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient
|
||||||
|
service.deepseek_client = DeepSeekClient()
|
||||||
|
return service
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_documents_in_collections_success(self, rag_service):
|
||||||
|
user_telegram_id = "123456789"
|
||||||
|
query = "трудовой договор"
|
||||||
|
|
||||||
|
mock_documents = [
|
||||||
|
{
|
||||||
|
"document_id": str(uuid4()),
|
||||||
|
"title": "Трудовой кодекс РФ",
|
||||||
|
"content": "Содержание о трудовых договорах",
|
||||||
|
"collection_name": "Законы"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"document_id": str(uuid4()),
|
||||||
|
"title": "Правила оформления",
|
||||||
|
"content": "Как оформить трудовой договор",
|
||||||
|
"collection_name": "Инструкции"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch('aiohttp.ClientSession') as mock_session:
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(return_value={
|
||||||
|
"user_id": str(uuid4())
|
||||||
|
})
|
||||||
|
|
||||||
|
mock_collections_response = AsyncMock()
|
||||||
|
mock_collections_response.status = 200
|
||||||
|
mock_collections_response.json = AsyncMock(return_value=[
|
||||||
|
{"collection_id": str(uuid4()), "name": "Законы"}
|
||||||
|
])
|
||||||
|
|
||||||
|
mock_search_response = AsyncMock()
|
||||||
|
mock_search_response.status = 200
|
||||||
|
mock_search_response.json = AsyncMock(return_value=mock_documents)
|
||||||
|
|
||||||
|
mock_session_instance = MagicMock()
|
||||||
|
mock_session_instance.__aenter__ = AsyncMock(return_value=mock_session_instance)
|
||||||
|
mock_session_instance.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_session_instance.get = AsyncMock(side_effect=[
|
||||||
|
mock_response,
|
||||||
|
mock_collections_response,
|
||||||
|
mock_search_response
|
||||||
|
])
|
||||||
|
mock_session.return_value = mock_session_instance
|
||||||
|
|
||||||
|
result = await rag_service.search_documents_in_collections(
|
||||||
|
user_telegram_id, query, limit_per_collection=5
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) > 0
|
||||||
|
assert result[0]["title"] == "Трудовой кодекс РФ"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_documents_empty_result(self, rag_service):
|
||||||
|
user_telegram_id = "123456789"
|
||||||
|
query = "несуществующий запрос"
|
||||||
|
|
||||||
|
with patch('aiohttp.ClientSession') as mock_session:
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = 200
|
||||||
|
mock_response.json = AsyncMock(return_value={
|
||||||
|
"user_id": str(uuid4())
|
||||||
|
})
|
||||||
|
|
||||||
|
mock_collections_response = AsyncMock()
|
||||||
|
mock_collections_response.status = 200
|
||||||
|
mock_collections_response.json = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
mock_session_instance = MagicMock()
|
||||||
|
mock_session_instance.__aenter__ = AsyncMock(return_value=mock_session_instance)
|
||||||
|
mock_session_instance.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_session_instance.get = AsyncMock(side_effect=[
|
||||||
|
mock_response,
|
||||||
|
mock_collections_response
|
||||||
|
])
|
||||||
|
mock_session.return_value = mock_session_instance
|
||||||
|
|
||||||
|
result = await rag_service.search_documents_in_collections(
|
||||||
|
user_telegram_id, query
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_answer_with_rag_success(self, rag_service, mock_rag_response):
|
||||||
|
question = "Какие права имеет работник?"
|
||||||
|
user_telegram_id = "123456789"
|
||||||
|
|
||||||
|
with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \
|
||||||
|
patch.object(rag_service, 'deepseek_client') as mock_client:
|
||||||
|
|
||||||
|
mock_search.return_value = [
|
||||||
|
{
|
||||||
|
"document_id": str(uuid4()),
|
||||||
|
"title": "Трудовой кодекс",
|
||||||
|
"content": "Работник имеет право на...",
|
||||||
|
"collection_name": "Законы"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_client.chat_completion = AsyncMock(return_value={
|
||||||
|
"content": "Работник имеет следующие права...",
|
||||||
|
"usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}
|
||||||
|
})
|
||||||
|
|
||||||
|
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
|
||||||
|
|
||||||
|
assert "answer" in result
|
||||||
|
assert "sources" in result
|
||||||
|
assert "usage" in result
|
||||||
|
assert len(result["sources"]) <= 5
|
||||||
|
assert result["answer"] != ""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_answer_limits_to_top5(self, rag_service):
|
||||||
|
question = "Тестовый вопрос"
|
||||||
|
user_telegram_id = "123456789"
|
||||||
|
|
||||||
|
many_documents = [
|
||||||
|
{
|
||||||
|
"document_id": str(uuid4()),
|
||||||
|
"title": f"Документ {i}",
|
||||||
|
"content": f"Содержание {i}",
|
||||||
|
"collection_name": "Коллекция"
|
||||||
|
}
|
||||||
|
for i in range(20)
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \
|
||||||
|
patch.object(rag_service, 'deepseek_client') as mock_client:
|
||||||
|
|
||||||
|
mock_search.return_value = many_documents
|
||||||
|
mock_client.chat_completion = AsyncMock(return_value={
|
||||||
|
"content": "Ответ",
|
||||||
|
"usage": {}
|
||||||
|
})
|
||||||
|
|
||||||
|
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
|
||||||
|
|
||||||
|
assert len(result["sources"]) == 5
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_answer_no_documents(self, rag_service):
|
||||||
|
question = "Вопрос без документов"
|
||||||
|
user_telegram_id = "123456789"
|
||||||
|
|
||||||
|
with patch.object(rag_service, 'search_documents_in_collections') as mock_search:
|
||||||
|
mock_search.return_value = []
|
||||||
|
|
||||||
|
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
|
||||||
|
|
||||||
|
assert result["sources"] == []
|
||||||
|
assert "Релевантные документы не найдены" in result.get("answer", "") or \
|
||||||
|
result["answer"] == "No relevant documents found"
|
||||||
193
tests/unit/test_user_service.py
Normal file
193
tests/unit/test_user_service.py
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from tg_bot.domain.services.user_service import UserService
|
||||||
|
from tg_bot.infrastructure.database.models import UserModel
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserService:
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session(self):
|
||||||
|
session = AsyncMock()
|
||||||
|
session.execute = AsyncMock()
|
||||||
|
session.add = MagicMock()
|
||||||
|
session.commit = AsyncMock()
|
||||||
|
session.rollback = AsyncMock()
|
||||||
|
return session
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def user_service(self, mock_session):
|
||||||
|
return UserService(mock_session)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_by_telegram_id_success(self, user_service, mock_session):
|
||||||
|
telegram_id = 123456789
|
||||||
|
mock_user = UserModel(
|
||||||
|
telegram_id=str(telegram_id),
|
||||||
|
username="test_user",
|
||||||
|
first_name="Test",
|
||||||
|
last_name="User"
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none = MagicMock(return_value=mock_user)
|
||||||
|
mock_session.execute.return_value = mock_result
|
||||||
|
|
||||||
|
result = await user_service.get_user_by_telegram_id(telegram_id)
|
||||||
|
|
||||||
|
assert result == mock_user
|
||||||
|
assert result.telegram_id == str(telegram_id)
|
||||||
|
mock_session.execute.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_by_telegram_id_not_found(self, user_service, mock_session):
|
||||||
|
telegram_id = 999999999
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none = MagicMock(return_value=None)
|
||||||
|
mock_session.execute.return_value = mock_result
|
||||||
|
|
||||||
|
result = await user_service.get_user_by_telegram_id(telegram_id)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
mock_session.execute.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_or_create_user_new_user(self, user_service, mock_session):
|
||||||
|
telegram_id = 123456789
|
||||||
|
username = "new_user"
|
||||||
|
first_name = "New"
|
||||||
|
last_name = "User"
|
||||||
|
|
||||||
|
mock_result_not_found = MagicMock()
|
||||||
|
mock_result_not_found.scalar_one_or_none = MagicMock(return_value=None)
|
||||||
|
|
||||||
|
mock_result_found = MagicMock()
|
||||||
|
created_user = UserModel(
|
||||||
|
telegram_id=str(telegram_id),
|
||||||
|
username=username,
|
||||||
|
first_name=first_name,
|
||||||
|
last_name=last_name
|
||||||
|
)
|
||||||
|
mock_result_found.scalar_one_or_none = MagicMock(return_value=created_user)
|
||||||
|
|
||||||
|
mock_session.execute.side_effect = [mock_result_not_found, mock_result_found]
|
||||||
|
|
||||||
|
result = await user_service.get_or_create_user(telegram_id, username, first_name, last_name)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.telegram_id == str(telegram_id)
|
||||||
|
assert result.username == username
|
||||||
|
mock_session.add.assert_called_once()
|
||||||
|
mock_session.commit.assert_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_or_create_user_existing_user(self, user_service, mock_session):
|
||||||
|
telegram_id = 123456789
|
||||||
|
existing_user = UserModel(
|
||||||
|
telegram_id=str(telegram_id),
|
||||||
|
username="old_username",
|
||||||
|
first_name="Old",
|
||||||
|
last_name="Name"
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none = MagicMock(return_value=existing_user)
|
||||||
|
mock_session.execute.return_value = mock_result
|
||||||
|
|
||||||
|
result = await user_service.get_or_create_user(
|
||||||
|
telegram_id, "new_username", "New", "Name"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == existing_user
|
||||||
|
assert result.username == "new_username"
|
||||||
|
assert result.first_name == "New"
|
||||||
|
assert result.last_name == "Name"
|
||||||
|
mock_session.commit.assert_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_user_questions_success(self, user_service, mock_session):
|
||||||
|
telegram_id = 123456789
|
||||||
|
user = UserModel(
|
||||||
|
telegram_id=str(telegram_id),
|
||||||
|
questions_used=5
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none = MagicMock(return_value=user)
|
||||||
|
mock_session.execute.return_value = mock_result
|
||||||
|
|
||||||
|
result = await user_service.update_user_questions(telegram_id)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert user.questions_used == 6
|
||||||
|
mock_session.commit.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_user_questions_user_not_found(self, user_service, mock_session):
|
||||||
|
telegram_id = 999999999
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none = MagicMock(return_value=None)
|
||||||
|
mock_session.execute.return_value = mock_result
|
||||||
|
|
||||||
|
result = await user_service.update_user_questions(telegram_id)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
mock_session.commit.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_activate_premium_success(self, user_service, mock_session):
|
||||||
|
telegram_id = 123456789
|
||||||
|
user = UserModel(
|
||||||
|
telegram_id=str(telegram_id),
|
||||||
|
is_premium=False,
|
||||||
|
premium_until=None
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none = MagicMock(return_value=user)
|
||||||
|
mock_session.execute.return_value = mock_result
|
||||||
|
|
||||||
|
result = await user_service.activate_premium(telegram_id)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert user.is_premium is True
|
||||||
|
assert user.premium_until is not None
|
||||||
|
assert user.premium_until > datetime.now()
|
||||||
|
mock_session.commit.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_activate_premium_extend_existing(self, user_service, mock_session):
|
||||||
|
telegram_id = 123456789
|
||||||
|
existing_premium_until = datetime.now() + timedelta(days=10)
|
||||||
|
user = UserModel(
|
||||||
|
telegram_id=str(telegram_id),
|
||||||
|
is_premium=True,
|
||||||
|
premium_until=existing_premium_until
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none = MagicMock(return_value=user)
|
||||||
|
mock_session.execute.return_value = mock_result
|
||||||
|
|
||||||
|
result = await user_service.activate_premium(telegram_id)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert user.is_premium is True
|
||||||
|
assert user.premium_until > existing_premium_until
|
||||||
|
mock_session.commit.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_activate_premium_user_not_found(self, user_service, mock_session):
|
||||||
|
telegram_id = 999999999
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none = MagicMock(return_value=None)
|
||||||
|
mock_session.execute.return_value = mock_result
|
||||||
|
|
||||||
|
result = await user_service.activate_premium(telegram_id)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
mock_session.commit.assert_not_called()
|
||||||
35
tg_bot/.dockerignore
Normal file
35
tg_bot/.dockerignore
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
ENV/
|
||||||
|
.venv/
|
||||||
|
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
.git/
|
||||||
|
.gitignore
|
||||||
|
.gitattributes
|
||||||
|
|
||||||
|
Dockerfile*
|
||||||
|
docker-compose*.yml
|
||||||
|
.dockerignore
|
||||||
|
|
||||||
|
drone.yml
|
||||||
|
|
||||||
|
tmp/
|
||||||
|
temp/
|
||||||
|
*.tmp
|
||||||
|
|
||||||
|
Thumbs.db
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
@ -8,9 +8,9 @@ RUN apt-get update && apt-get install -y \
|
|||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
RUN pip install -r requirements.txt
|
||||||
|
|
||||||
COPY . .
|
COPY . ./tg_bot/
|
||||||
|
|
||||||
ENV PYTHONPATH=/app
|
ENV PYTHONPATH=/app
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|||||||
@ -1,139 +1,130 @@
|
|||||||
import aiohttp
|
"""
|
||||||
from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient
|
RAG сервис для бота - вызывает API бэкенда
|
||||||
|
"""
|
||||||
from tg_bot.config.settings import settings
|
from tg_bot.config.settings import settings
|
||||||
|
from tg_bot.infrastructure.http_client import create_http_session
|
||||||
BACKEND_URL = "http://localhost:8001/api/v1"
|
|
||||||
|
|
||||||
|
|
||||||
class RAGService:
|
class RAGService:
|
||||||
|
"""Сервис для работы с RAG через API бэкенда"""
|
||||||
|
|
||||||
def __init__(self):
|
async def get_or_create_conversation(
|
||||||
self.deepseek_client = DeepSeekClient()
|
self,
|
||||||
|
user_telegram_id: str,
|
||||||
async def search_documents_in_collections(
|
collection_id: str = None
|
||||||
self,
|
) -> str | None:
|
||||||
user_telegram_id: str,
|
"""Получить или создать беседу для пользователя"""
|
||||||
query: str,
|
|
||||||
limit_per_collection: int = 5
|
|
||||||
) -> list[dict]:
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f"{BACKEND_URL}/users/telegram/{user_telegram_id}"
|
f"{settings.BACKEND_URL}/collections/",
|
||||||
) as user_response:
|
headers={"X-Telegram-ID": user_telegram_id}
|
||||||
if user_response.status != 200:
|
|
||||||
return []
|
|
||||||
|
|
||||||
user_data = await user_response.json()
|
|
||||||
user_uuid = str(user_data.get("user_id"))
|
|
||||||
|
|
||||||
if not user_uuid:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async with session.get(
|
|
||||||
f"{BACKEND_URL}/collections/",
|
|
||||||
headers={"X-Telegram-ID": user_telegram_id}
|
|
||||||
) as collections_response:
|
) as collections_response:
|
||||||
if collections_response.status != 200:
|
if collections_response.status != 200:
|
||||||
return []
|
return None
|
||||||
|
|
||||||
collections = await collections_response.json()
|
collections = await collections_response.json()
|
||||||
|
|
||||||
all_documents = []
|
if not collections:
|
||||||
for collection in collections:
|
if not collection_id:
|
||||||
collection_id = collection.get("collection_id")
|
async with session.post(
|
||||||
|
f"{settings.BACKEND_URL}/collections",
|
||||||
|
json={
|
||||||
|
"name": "Основная коллекция",
|
||||||
|
"description": "Коллекция по умолчанию",
|
||||||
|
"is_public": False
|
||||||
|
},
|
||||||
|
headers={"X-Telegram-ID": user_telegram_id}
|
||||||
|
) as create_collection_response:
|
||||||
|
if create_collection_response.status in [200, 201]:
|
||||||
|
collection_data = await create_collection_response.json()
|
||||||
|
collection_id = collection_data.get("collection_id")
|
||||||
|
else:
|
||||||
|
collection_id = collection_id
|
||||||
|
else:
|
||||||
|
collection_id = collections[0].get("collection_id")
|
||||||
|
|
||||||
if not collection_id:
|
if not collection_id:
|
||||||
continue
|
return None
|
||||||
|
|
||||||
try:
|
async with session.get(
|
||||||
async with aiohttp.ClientSession() as search_session:
|
f"{settings.BACKEND_URL}/conversations",
|
||||||
async with search_session.get(
|
headers={"X-Telegram-ID": user_telegram_id}
|
||||||
f"{BACKEND_URL}/documents/collection/{collection_id}",
|
) as conversations_response:
|
||||||
params={"search": query, "limit": limit_per_collection},
|
if conversations_response.status == 200:
|
||||||
headers={"X-Telegram-ID": user_telegram_id}
|
conversations = await conversations_response.json()
|
||||||
) as search_response:
|
for conv in conversations:
|
||||||
if search_response.status == 200:
|
if conv.get("collection_id") == str(collection_id):
|
||||||
documents = await search_response.json()
|
return conv.get("conversation_id")
|
||||||
for doc in documents:
|
|
||||||
doc["collection_name"] = collection.get("name", "Unknown")
|
|
||||||
all_documents.append(doc)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error searching collection {collection_id}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
return all_documents[:20]
|
async with session.post(
|
||||||
|
f"{settings.BACKEND_URL}/conversations",
|
||||||
|
json={"collection_id": str(collection_id)},
|
||||||
|
headers={"X-Telegram-ID": user_telegram_id}
|
||||||
|
) as create_conversation_response:
|
||||||
|
if create_conversation_response.status in [200, 201]:
|
||||||
|
conversation_data = await create_conversation_response.json()
|
||||||
|
return conversation_data.get("conversation_id")
|
||||||
|
|
||||||
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error searching documents: {e}")
|
print(f"Error getting/creating conversation: {e}")
|
||||||
return []
|
return None
|
||||||
|
|
||||||
async def generate_answer_with_rag(
|
async def generate_answer_with_rag(
|
||||||
self,
|
self,
|
||||||
question: str,
|
question: str,
|
||||||
user_telegram_id: str
|
user_telegram_id: str
|
||||||
) -> dict:
|
) -> dict:
|
||||||
documents = await self.search_documents_in_collections(
|
"""Генерирует ответ используя RAG через API бэкенда"""
|
||||||
user_telegram_id,
|
|
||||||
question
|
|
||||||
)
|
|
||||||
|
|
||||||
context_parts = []
|
|
||||||
sources = []
|
|
||||||
|
|
||||||
for doc in documents[:5]:
|
|
||||||
title = doc.get("title", "Без названия")
|
|
||||||
content = doc.get("content", "")[:1000]
|
|
||||||
collection_name = doc.get("collection_name", "Unknown")
|
|
||||||
|
|
||||||
context_parts.append(f"Документ: {title}\nКоллекция: {collection_name}\nСодержание: {content[:500]}...")
|
|
||||||
sources.append({
|
|
||||||
"title": title,
|
|
||||||
"collection": collection_name,
|
|
||||||
"document_id": doc.get("document_id")
|
|
||||||
})
|
|
||||||
|
|
||||||
context = "\n\n".join(context_parts) if context_parts else "Релевантные документы не найдены."
|
|
||||||
|
|
||||||
system_prompt = """Ты - помощник-юрист, который отвечает на вопросы на основе предоставленных документов.
|
|
||||||
Используй информацию из документов для формирования точного и полезного ответа.
|
|
||||||
Если в документах нет информации для ответа, честно скажи об этом."""
|
|
||||||
|
|
||||||
user_prompt = f"""Контекст из документов:
|
|
||||||
{context}
|
|
||||||
|
|
||||||
Вопрос пользователя: {question}
|
|
||||||
|
|
||||||
Ответь на вопрос, используя информацию из предоставленных документов. Если информации недостаточно, укажи это."""
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
messages = [
|
conversation_id = await self.get_or_create_conversation(user_telegram_id)
|
||||||
{"role": "system", "content": system_prompt},
|
if not conversation_id:
|
||||||
{"role": "user", "content": user_prompt}
|
|
||||||
]
|
|
||||||
|
|
||||||
response = await self.deepseek_client.chat_completion(
|
|
||||||
messages=messages,
|
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=2000
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"answer": response.get("content", "Failed to generate answer"),
|
|
||||||
"sources": sources,
|
|
||||||
"usage": response.get("usage", {})
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error generating answer: {e}")
|
|
||||||
if documents:
|
|
||||||
return {
|
return {
|
||||||
"answer": f"Found {len(documents)} documents but failed to generate answer",
|
"answer": "Не удалось создать беседу. Попробуйте позже.",
|
||||||
"sources": sources[:3],
|
|
||||||
"usage": {}
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"answer": "No relevant documents found",
|
|
||||||
"sources": [],
|
"sources": [],
|
||||||
"usage": {}
|
"usage": {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.BACKEND_URL}/rag/question",
|
||||||
|
json={
|
||||||
|
"conversation_id": str(conversation_id),
|
||||||
|
"question": question,
|
||||||
|
"top_k": 20,
|
||||||
|
"rerank_top_n": 5
|
||||||
|
},
|
||||||
|
headers={"X-Telegram-ID": user_telegram_id}
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
result = await response.json()
|
||||||
|
sources = []
|
||||||
|
for source in result.get("sources", []):
|
||||||
|
sources.append({
|
||||||
|
"title": source.get("title", "Без названия"),
|
||||||
|
"document_id": source.get("document_id"),
|
||||||
|
"chunk_id": source.get("chunk_id"),
|
||||||
|
"index": source.get("index", 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"answer": result.get("answer", "Не удалось сгенерировать ответ."),
|
||||||
|
"sources": sources,
|
||||||
|
"usage": result.get("usage", {}),
|
||||||
|
"conversation_id": str(conversation_id)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
error_text = await response.text()
|
||||||
|
print(f"RAG API error: {response.status} - {error_text}")
|
||||||
|
return {
|
||||||
|
"answer": "Ошибка при генерации ответа. Попробуйте позже.",
|
||||||
|
"sources": [],
|
||||||
|
"usage": {}
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error generating answer with RAG: {e}")
|
||||||
|
return {
|
||||||
|
"answer": "Произошла ошибка при генерации ответа. Попробуйте позже.",
|
||||||
|
"sources": [],
|
||||||
|
"usage": {}
|
||||||
|
}
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
import os
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
|
"""Настройки приложения получаеи из env файла, тут не ищи, мы спрятали:)"""
|
||||||
|
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
env_file=".env",
|
env_file=".env",
|
||||||
env_file_encoding="utf-8",
|
env_file_encoding="utf-8",
|
||||||
@ -13,26 +14,35 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
APP_NAME: str = "VibeLawyerBot"
|
APP_NAME: str = "VibeLawyerBot"
|
||||||
VERSION: str = "0.1.0"
|
VERSION: str = "0.1.0"
|
||||||
DEBUG: bool = True
|
DEBUG: bool = False
|
||||||
TELEGRAM_BOT_TOKEN: str = ""
|
|
||||||
|
TELEGRAM_BOT_TOKEN: str
|
||||||
|
|
||||||
FREE_QUESTIONS_LIMIT: int = 5
|
FREE_QUESTIONS_LIMIT: int = 5
|
||||||
PAYMENT_AMOUNT: float = 500.0
|
PAYMENT_AMOUNT: float = 500.0
|
||||||
DATABASE_URL: str = "sqlite:///data/bot.db"
|
|
||||||
LOG_LEVEL: str = "INFO"
|
LOG_LEVEL: str = "INFO"
|
||||||
LOG_FILE: str = "logs/bot.log"
|
LOG_FILE: str = "logs/bot.log"
|
||||||
|
|
||||||
YOOKASSA_SHOP_ID: str = "1230200"
|
|
||||||
YOOKASSA_SECRET_KEY: str = "test_GVoixmlp0FqohXcyFzFHbRlAUoA3B1I2aMtAkAE_ubw"
|
YOOKASSA_SHOP_ID: str
|
||||||
|
YOOKASSA_SECRET_KEY: str
|
||||||
YOOKASSA_RETURN_URL: str = "https://t.me/vibelawyer_bot"
|
YOOKASSA_RETURN_URL: str = "https://t.me/vibelawyer_bot"
|
||||||
YOOKASSA_WEBHOOK_SECRET: Optional[str] = None
|
YOOKASSA_WEBHOOK_SECRET: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
DEEPSEEK_API_KEY: Optional[str] = None
|
DEEPSEEK_API_KEY: Optional[str] = None
|
||||||
DEEPSEEK_API_URL: str = "https://api.deepseek.com/v1/chat/completions"
|
DEEPSEEK_API_URL: str = "https://api.deepseek.com/v1/chat/completions"
|
||||||
|
|
||||||
|
|
||||||
|
BACKEND_URL: str
|
||||||
|
|
||||||
|
|
||||||
ADMIN_IDS_STR: str = ""
|
ADMIN_IDS_STR: str = ""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ADMIN_IDS(self) -> List[int]:
|
def ADMIN_IDS(self) -> List[int]:
|
||||||
|
"""Список ID администраторов из строки через запятую"""
|
||||||
if not self.ADMIN_IDS_STR:
|
if not self.ADMIN_IDS_STR:
|
||||||
return []
|
return []
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -1,67 +0,0 @@
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy import select
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Optional
|
|
||||||
from tg_bot.infrastructure.database.models import UserModel
|
|
||||||
|
|
||||||
|
|
||||||
class UserService:
|
|
||||||
|
|
||||||
def __init__(self, session: AsyncSession):
|
|
||||||
self.session = session
|
|
||||||
|
|
||||||
async def get_user_by_telegram_id(self, telegram_id: int) -> Optional[UserModel]:
|
|
||||||
result = await self.session.execute(
|
|
||||||
select(UserModel).filter_by(telegram_id=str(telegram_id))
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
async def get_or_create_user(
|
|
||||||
self,
|
|
||||||
telegram_id: int,
|
|
||||||
username: str = "",
|
|
||||||
first_name: str = "",
|
|
||||||
last_name: str = ""
|
|
||||||
) -> UserModel:
|
|
||||||
user = await self.get_user_by_telegram_id(telegram_id)
|
|
||||||
if not user:
|
|
||||||
user = UserModel(
|
|
||||||
telegram_id=str(telegram_id),
|
|
||||||
username=username,
|
|
||||||
first_name=first_name,
|
|
||||||
last_name=last_name
|
|
||||||
)
|
|
||||||
self.session.add(user)
|
|
||||||
await self.session.commit()
|
|
||||||
else:
|
|
||||||
user.username = username
|
|
||||||
user.first_name = first_name
|
|
||||||
user.last_name = last_name
|
|
||||||
await self.session.commit()
|
|
||||||
return user
|
|
||||||
|
|
||||||
async def update_user_questions(self, telegram_id: int) -> bool:
|
|
||||||
user = await self.get_user_by_telegram_id(telegram_id)
|
|
||||||
if user:
|
|
||||||
user.questions_used += 1
|
|
||||||
await self.session.commit()
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def activate_premium(self, telegram_id: int) -> bool:
|
|
||||||
try:
|
|
||||||
user = await self.get_user_by_telegram_id(telegram_id)
|
|
||||||
if user:
|
|
||||||
user.is_premium = True
|
|
||||||
if user.premium_until and user.premium_until > datetime.now():
|
|
||||||
user.premium_until = user.premium_until + timedelta(days=30)
|
|
||||||
else:
|
|
||||||
user.premium_until = datetime.now() + timedelta(days=30)
|
|
||||||
await self.session.commit()
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error activating premium: {e}")
|
|
||||||
await self.session.rollback()
|
|
||||||
return False
|
|
||||||
126
tg_bot/domain/user_service.py
Normal file
126
tg_bot/domain/user_service.py
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
from tg_bot.config.settings import settings
|
||||||
|
from tg_bot.infrastructure.http_client import create_http_session
|
||||||
|
|
||||||
|
|
||||||
|
class User:
|
||||||
|
"""Модель пользователя для телеграм-бота"""
|
||||||
|
def __init__(self, data: dict):
|
||||||
|
self.user_id = data.get("user_id")
|
||||||
|
self.telegram_id = data.get("telegram_id")
|
||||||
|
self.role = data.get("role")
|
||||||
|
created_at_str = data.get("created_at")
|
||||||
|
if created_at_str:
|
||||||
|
try:
|
||||||
|
created_at_str = created_at_str.replace("Z", "+00:00")
|
||||||
|
self.created_at = datetime.fromisoformat(created_at_str)
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
self.created_at = None
|
||||||
|
else:
|
||||||
|
self.created_at = None
|
||||||
|
|
||||||
|
premium_until_str = data.get("premium_until")
|
||||||
|
if premium_until_str:
|
||||||
|
try:
|
||||||
|
premium_until_str = premium_until_str.replace("Z", "+00:00")
|
||||||
|
self.premium_until = datetime.fromisoformat(premium_until_str)
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
self.premium_until = None
|
||||||
|
else:
|
||||||
|
self.premium_until = None
|
||||||
|
|
||||||
|
self.is_premium = data.get("is_premium", False)
|
||||||
|
self.questions_used = data.get("questions_used", 0)
|
||||||
|
|
||||||
|
|
||||||
|
class UserService:
|
||||||
|
"""Сервис для работы с пользователями через API бэкенда"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.backend_url = settings.BACKEND_URL
|
||||||
|
print(f"UserService initialized with BACKEND_URL: {self.backend_url}")
|
||||||
|
|
||||||
|
async def get_user_by_telegram_id(self, telegram_id: int) -> Optional[User]:
|
||||||
|
"""Получить пользователя по Telegram ID"""
|
||||||
|
try:
|
||||||
|
url = f"{self.backend_url}/users/telegram/{telegram_id}"
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.get(url) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
data = await response.json()
|
||||||
|
return User(data)
|
||||||
|
return None
|
||||||
|
except aiohttp.ClientConnectorError as e:
|
||||||
|
print(f"Backend not available at {self.backend_url}: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error getting user: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_or_create_user(
|
||||||
|
self,
|
||||||
|
telegram_id: int,
|
||||||
|
username: str = "",
|
||||||
|
first_name: str = "",
|
||||||
|
last_name: str = ""
|
||||||
|
) -> User:
|
||||||
|
"""Получить или создать пользователя"""
|
||||||
|
user = await self.get_user_by_telegram_id(telegram_id)
|
||||||
|
if not user:
|
||||||
|
try:
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.backend_url}/users",
|
||||||
|
json={"telegram_id": str(telegram_id), "role": "user"}
|
||||||
|
) as response:
|
||||||
|
if response.status in [200, 201]:
|
||||||
|
data = await response.json()
|
||||||
|
return User(data)
|
||||||
|
else:
|
||||||
|
error_text = await response.text()
|
||||||
|
raise Exception(
|
||||||
|
f"Backend API returned status {response.status}: {error_text}. "
|
||||||
|
f"Make sure the backend server is running at {self.backend_url}"
|
||||||
|
)
|
||||||
|
except aiohttp.ClientConnectorError as e:
|
||||||
|
error_msg = (
|
||||||
|
f"Cannot connect to backend API at {self.backend_url}. "
|
||||||
|
f"Please ensure the backend server is running on port 8000. "
|
||||||
|
f"Start it with: cd project/backend && python run.py"
|
||||||
|
)
|
||||||
|
print(f"Error creating user: {error_msg}")
|
||||||
|
print(f"Original error: {e}")
|
||||||
|
raise ConnectionError(error_msg) from e
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error creating user: {e}. Backend URL: {self.backend_url}"
|
||||||
|
print(error_msg)
|
||||||
|
raise
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def update_user_questions(self, telegram_id: int) -> bool:
|
||||||
|
"""Увеличить счетчик использованных вопросов"""
|
||||||
|
try:
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.backend_url}/users/telegram/{telegram_id}/increment-questions"
|
||||||
|
) as response:
|
||||||
|
return response.status == 200
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error updating questions: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def activate_premium(self, telegram_id: int, days: int = 30) -> bool:
|
||||||
|
"""Активировать premium статус"""
|
||||||
|
try:
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.backend_url}/users/telegram/{telegram_id}/activate-premium",
|
||||||
|
params={"days": days}
|
||||||
|
) as response:
|
||||||
|
return response.status == 200
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error activating premium: {e}")
|
||||||
|
return False
|
||||||
2
tg_bot/infrastructure/__init__.py
Normal file
2
tg_bot/infrastructure/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
"""Infrastructure layer for the Telegram bot"""
|
||||||
|
|
||||||
@ -1,19 +0,0 @@
|
|||||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
|
||||||
from tg_bot.config.settings import settings
|
|
||||||
|
|
||||||
database_url = settings.DATABASE_URL
|
|
||||||
if database_url.startswith("sqlite:///"):
|
|
||||||
database_url = database_url.replace("sqlite:///", "sqlite+aiosqlite:///")
|
|
||||||
|
|
||||||
engine = create_async_engine(
|
|
||||||
database_url,
|
|
||||||
echo=settings.DEBUG
|
|
||||||
)
|
|
||||||
|
|
||||||
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
|
||||||
|
|
||||||
async def create_tables():
|
|
||||||
from .models import Base
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
print(f"Таблицы созданы: {settings.DATABASE_URL}")
|
|
||||||
@ -1,39 +0,0 @@
|
|||||||
import uuid
|
|
||||||
from datetime import datetime
|
|
||||||
from sqlalchemy import Column, String, DateTime, Boolean, Integer, Text
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
class UserModel(Base):
|
|
||||||
__tablename__ = "users"
|
|
||||||
user_id = Column("user_id", String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
|
||||||
telegram_id = Column("telegram_id", String(100), nullable=False, unique=True)
|
|
||||||
created_at = Column("created_at", DateTime, default=datetime.utcnow, nullable=False)
|
|
||||||
role = Column("role", String(20), default="user", nullable=False)
|
|
||||||
|
|
||||||
is_premium = Column(Boolean, default=False, nullable=False)
|
|
||||||
premium_until = Column(DateTime, nullable=True)
|
|
||||||
questions_used = Column(Integer, default=0, nullable=False)
|
|
||||||
|
|
||||||
username = Column(String(100), nullable=True)
|
|
||||||
first_name = Column(String(100), nullable=True)
|
|
||||||
last_name = Column(String(100), nullable=True)
|
|
||||||
|
|
||||||
|
|
||||||
class PaymentModel(Base):
|
|
||||||
__tablename__ = "payments"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
payment_id = Column(String(36), default=lambda: str(uuid.uuid4()), nullable=False, unique=True)
|
|
||||||
user_id = Column(Integer, nullable=False)
|
|
||||||
amount = Column(String(20), nullable=False)
|
|
||||||
currency = Column(String(3), default="RUB", nullable=False)
|
|
||||||
status = Column(String(20), default="pending", nullable=False)
|
|
||||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
|
||||||
yookassa_payment_id = Column(String(100), unique=True, nullable=True)
|
|
||||||
description = Column(Text, nullable=True)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"<Payment(user_id={self.user_id}, amount={self.amount}, status={self.status})>"
|
|
||||||
172
tg_bot/infrastructure/external/deepseek_client.py
vendored
172
tg_bot/infrastructure/external/deepseek_client.py
vendored
@ -1,172 +0,0 @@
|
|||||||
import json
|
|
||||||
from typing import Optional, AsyncIterator
|
|
||||||
import httpx
|
|
||||||
from tg_bot.config.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
class DeepSeekAPIError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DeepSeekClient:
|
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None, api_url: str | None = None):
|
|
||||||
self.api_key = api_key or settings.DEEPSEEK_API_KEY
|
|
||||||
self.api_url = api_url or settings.DEEPSEEK_API_URL
|
|
||||||
self.timeout = 60.0
|
|
||||||
|
|
||||||
def _get_headers(self) -> dict[str, str]:
|
|
||||||
if not self.api_key:
|
|
||||||
raise DeepSeekAPIError("API key not set")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {self.api_key}"
|
|
||||||
}
|
|
||||||
|
|
||||||
async def chat_completion(
|
|
||||||
self,
|
|
||||||
messages: list[dict[str, str]],
|
|
||||||
model: str = "deepseek-chat",
|
|
||||||
temperature: float = 0.7,
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
stream: bool = False
|
|
||||||
) -> dict:
|
|
||||||
if not self.api_key:
|
|
||||||
return {
|
|
||||||
"content": "API key not configured",
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"model": model,
|
|
||||||
"messages": messages,
|
|
||||||
"temperature": temperature,
|
|
||||||
"stream": stream
|
|
||||||
}
|
|
||||||
|
|
||||||
if max_tokens is not None:
|
|
||||||
payload["max_tokens"] = max_tokens
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
||||||
response = await client.post(
|
|
||||||
self.api_url,
|
|
||||||
headers=self._get_headers(),
|
|
||||||
json=payload
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
if "choices" in data and len(data["choices"]) > 0:
|
|
||||||
content = data["choices"][0]["message"]["content"]
|
|
||||||
else:
|
|
||||||
raise DeepSeekAPIError("Invalid response format")
|
|
||||||
|
|
||||||
usage = data.get("usage", {})
|
|
||||||
|
|
||||||
return {
|
|
||||||
"content": content,
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
|
||||||
"completion_tokens": usage.get("completion_tokens", 0),
|
|
||||||
"total_tokens": usage.get("total_tokens", 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
error_msg = f"API error: {e.response.status_code}"
|
|
||||||
try:
|
|
||||||
error_data = e.response.json()
|
|
||||||
if "error" in error_data:
|
|
||||||
error_msg = error_data['error'].get('message', error_msg)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
raise DeepSeekAPIError(error_msg) from e
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
raise DeepSeekAPIError(f"Connection error: {str(e)}") from e
|
|
||||||
except Exception as e:
|
|
||||||
raise DeepSeekAPIError(str(e)) from e
|
|
||||||
|
|
||||||
async def stream_chat_completion(
|
|
||||||
self,
|
|
||||||
messages: list[dict[str, str]],
|
|
||||||
model: str = "deepseek-chat",
|
|
||||||
temperature: float = 0.7,
|
|
||||||
max_tokens: Optional[int] = None
|
|
||||||
) -> AsyncIterator[str]:
|
|
||||||
if not self.api_key:
|
|
||||||
yield "API key not configured"
|
|
||||||
return
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"model": model,
|
|
||||||
"messages": messages,
|
|
||||||
"temperature": temperature,
|
|
||||||
"stream": True
|
|
||||||
}
|
|
||||||
|
|
||||||
if max_tokens is not None:
|
|
||||||
payload["max_tokens"] = max_tokens
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
||||||
async with client.stream(
|
|
||||||
"POST",
|
|
||||||
self.api_url,
|
|
||||||
headers=self._get_headers(),
|
|
||||||
json=payload
|
|
||||||
) as response:
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
async for line in response.aiter_lines():
|
|
||||||
if not line.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
if line.startswith("data: "):
|
|
||||||
line = line[6:]
|
|
||||||
|
|
||||||
if line.strip() == "[DONE]":
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = json.loads(line)
|
|
||||||
|
|
||||||
if "choices" in data and len(data["choices"]) > 0:
|
|
||||||
delta = data["choices"][0].get("delta", {})
|
|
||||||
content = delta.get("content", "")
|
|
||||||
if content:
|
|
||||||
yield content
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
error_msg = f"API error: {e.response.status_code}"
|
|
||||||
try:
|
|
||||||
error_data = e.response.json()
|
|
||||||
if "error" in error_data:
|
|
||||||
error_msg = error_data['error'].get('message', error_msg)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
raise DeepSeekAPIError(error_msg) from e
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
raise DeepSeekAPIError(f"Connection error: {str(e)}") from e
|
|
||||||
except Exception as e:
|
|
||||||
raise DeepSeekAPIError(str(e)) from e
|
|
||||||
|
|
||||||
async def health_check(self) -> bool:
|
|
||||||
if not self.api_key:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
test_messages = [{"role": "user", "content": "test"}]
|
|
||||||
await self.chat_completion(test_messages, max_tokens=1)
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
24
tg_bot/infrastructure/http_client.py
Normal file
24
tg_bot/infrastructure/http_client.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import aiohttp
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def create_http_session(timeout: Optional[aiohttp.ClientTimeout] = None) -> aiohttp.ClientSession:
|
||||||
|
"""
|
||||||
|
Создаем сессию для запросов к бэку
|
||||||
|
"""
|
||||||
|
if timeout is None:
|
||||||
|
timeout = aiohttp.ClientTimeout(total=30, connect=10)
|
||||||
|
|
||||||
|
connector = aiohttp.TCPConnector(
|
||||||
|
limit=100,
|
||||||
|
limit_per_host=30
|
||||||
|
)
|
||||||
|
|
||||||
|
return aiohttp.ClientSession(
|
||||||
|
connector=connector,
|
||||||
|
timeout=timeout,
|
||||||
|
headers={
|
||||||
|
"Accept": "application/json"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@ -10,7 +10,8 @@ from tg_bot.infrastructure.telegram.handlers import (
|
|||||||
stats_handler,
|
stats_handler,
|
||||||
question_handler,
|
question_handler,
|
||||||
buy_handler,
|
buy_handler,
|
||||||
collection_handler
|
collection_handler,
|
||||||
|
document_handler
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -25,17 +26,28 @@ async def create_bot() -> tuple[Bot, Dispatcher]:
|
|||||||
dp.include_router(start_handler.router)
|
dp.include_router(start_handler.router)
|
||||||
dp.include_router(help_handler.router)
|
dp.include_router(help_handler.router)
|
||||||
dp.include_router(stats_handler.router)
|
dp.include_router(stats_handler.router)
|
||||||
dp.include_router(question_handler.router)
|
|
||||||
dp.include_router(buy_handler.router)
|
dp.include_router(buy_handler.router)
|
||||||
dp.include_router(collection_handler.router)
|
dp.include_router(collection_handler.router)
|
||||||
|
dp.include_router(document_handler.router)
|
||||||
|
dp.include_router(question_handler.router)
|
||||||
return bot, dp
|
return bot, dp
|
||||||
|
|
||||||
|
|
||||||
async def start_bot():
|
async def start_bot():
|
||||||
bot = None
|
bot = None
|
||||||
try:
|
try:
|
||||||
|
if not settings.TELEGRAM_BOT_TOKEN or not settings.TELEGRAM_BOT_TOKEN.strip():
|
||||||
|
raise ValueError("TELEGRAM_BOT_TOKEN не установлен в переменных окружения или файле .env")
|
||||||
|
|
||||||
bot, dp = await create_bot()
|
bot, dp = await create_bot()
|
||||||
|
|
||||||
|
try:
|
||||||
|
bot_info = await bot.get_me()
|
||||||
|
username = bot_info.username if bot_info.username else f"ID: {bot_info.id}"
|
||||||
|
logger.info(f"Бот успешно подключен: @{username}")
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Неверный токен Telegram бота: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
webhook_info = await bot.get_webhook_info()
|
webhook_info = await bot.get_webhook_info()
|
||||||
if webhook_info.url:
|
if webhook_info.url:
|
||||||
|
|||||||
@ -2,16 +2,14 @@ from aiogram import Router, types
|
|||||||
from aiogram.filters import Command
|
from aiogram.filters import Command
|
||||||
from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton
|
from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
import aiohttp
|
||||||
from tg_bot.config.settings import settings
|
from tg_bot.config.settings import settings
|
||||||
from tg_bot.payment.yookassa.client import yookassa_client
|
from tg_bot.payment.yookassa.client import yookassa_client
|
||||||
from tg_bot.infrastructure.database.database import AsyncSessionLocal
|
from tg_bot.domain.user_service import UserService
|
||||||
from tg_bot.infrastructure.database.models import PaymentModel
|
from datetime import datetime
|
||||||
from tg_bot.domain.services.user_service import UserService
|
|
||||||
from sqlalchemy import select
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
router = Router()
|
router = Router()
|
||||||
|
user_service = UserService()
|
||||||
|
|
||||||
|
|
||||||
@router.message(Command("buy"))
|
@router.message(Command("buy"))
|
||||||
@ -19,23 +17,23 @@ async def cmd_buy(message: Message):
|
|||||||
user_id = message.from_user.id
|
user_id = message.from_user.id
|
||||||
username = message.from_user.username or f"user_{user_id}"
|
username = message.from_user.username or f"user_{user_id}"
|
||||||
|
|
||||||
async with AsyncSessionLocal() as session:
|
try:
|
||||||
try:
|
user = await user_service.get_user_by_telegram_id(user_id)
|
||||||
user_service = UserService(session)
|
|
||||||
user = await user_service.get_user_by_telegram_id(user_id)
|
|
||||||
|
|
||||||
if user and user.is_premium and user.premium_until and user.premium_until > datetime.now():
|
if user and user.is_premium and user.premium_until and user.premium_until > datetime.now():
|
||||||
days_left = (user.premium_until - datetime.now()).days
|
days_left = (user.premium_until - datetime.now()).days
|
||||||
await message.answer(
|
await message.answer(
|
||||||
f"<b>У вас уже есть активная подписка!</b>\n\n"
|
f"<b>У вас уже есть активная подписка!</b>\n\n"
|
||||||
f"• Статус: Premium активен\n"
|
f"• Статус: Premium активен\n"
|
||||||
f"• Действует до: {user.premium_until.strftime('%d.%m.%Y')}\n"
|
f"• Действует до: {user.premium_until.strftime('%d.%m.%Y')}\n"
|
||||||
f"• Осталось дней: {days_left}\n\n"
|
f"• Осталось дней: {days_left}\n\n"
|
||||||
f"Новая подписка будет добавлена к текущей.",
|
f"Новая подписка будет добавлена к текущей.",
|
||||||
parse_mode="HTML"
|
parse_mode="HTML"
|
||||||
)
|
)
|
||||||
except Exception:
|
except aiohttp.ClientError as e:
|
||||||
pass
|
print(f"Не удалось подключиться к backend при проверке подписки: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Ошибка при проверке подписки: {e}")
|
||||||
|
|
||||||
await message.answer(
|
await message.answer(
|
||||||
"*Создаю ссылку для оплаты...*\n\n"
|
"*Создаю ссылку для оплаты...*\n\n"
|
||||||
@ -50,23 +48,7 @@ async def cmd_buy(message: Message):
|
|||||||
user_id=user_id
|
user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
async with AsyncSessionLocal() as session:
|
print(f"Платёж создан в ЮKассе: {payment_data['id']}")
|
||||||
try:
|
|
||||||
payment = PaymentModel(
|
|
||||||
payment_id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
amount=str(settings.PAYMENT_AMOUNT),
|
|
||||||
currency="RUB",
|
|
||||||
status="pending",
|
|
||||||
yookassa_payment_id=payment_data["id"],
|
|
||||||
description="Оплата подписки VibeLawyerBot"
|
|
||||||
)
|
|
||||||
session.add(payment)
|
|
||||||
await session.commit()
|
|
||||||
print(f"Платёж сохранён в БД: {payment.payment_id}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Ошибка сохранения платежа в БД: {e}")
|
|
||||||
await session.rollback()
|
|
||||||
|
|
||||||
keyboard = InlineKeyboardMarkup(
|
keyboard = InlineKeyboardMarkup(
|
||||||
inline_keyboard=[
|
inline_keyboard=[
|
||||||
@ -139,27 +121,15 @@ async def check_payment_status(callback_query: types.CallbackQuery):
|
|||||||
payment = YooPayment.find_one(yookassa_id)
|
payment = YooPayment.find_one(yookassa_id)
|
||||||
|
|
||||||
if payment.status == "succeeded":
|
if payment.status == "succeeded":
|
||||||
async with AsyncSessionLocal() as session:
|
try:
|
||||||
try:
|
success = await user_service.activate_premium(user_id)
|
||||||
result = await session.execute(
|
if success:
|
||||||
select(PaymentModel).filter_by(yookassa_payment_id=yookassa_id)
|
user = await user_service.get_user_by_telegram_id(user_id)
|
||||||
)
|
if user:
|
||||||
db_payment = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if db_payment:
|
|
||||||
db_payment.status = "succeeded"
|
|
||||||
user_service = UserService(session)
|
|
||||||
success = await user_service.activate_premium(user_id)
|
|
||||||
if success:
|
|
||||||
user = await user_service.get_user_by_telegram_id(user_id)
|
|
||||||
await session.commit()
|
|
||||||
if not user:
|
|
||||||
user = await user_service.get_user_by_telegram_id(user_id)
|
|
||||||
|
|
||||||
await callback_query.message.answer(
|
await callback_query.message.answer(
|
||||||
"<b>Оплата подтверждена!</b>\n\n"
|
"<b>Оплата подтверждена!</b>\n\n"
|
||||||
f"Ваш premium-доступ активирован до: "
|
f"Ваш premium-доступ активирован до: "
|
||||||
f"<b>{user.premium_until.strftime('%d.%m.%Y')}</b>\n\n"
|
f"<b>{user.premium_until.strftime('%d.%m.%Y') if user.premium_until else 'Не указано'}</b>\n\n"
|
||||||
"Теперь вы можете:\n"
|
"Теперь вы можете:\n"
|
||||||
"• Задавать неограниченное количество вопросов\n"
|
"• Задавать неограниченное количество вопросов\n"
|
||||||
"• Получать приоритетные ответы\n"
|
"• Получать приоритетные ответы\n"
|
||||||
@ -169,12 +139,23 @@ async def check_payment_status(callback_query: types.CallbackQuery):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await callback_query.message.answer(
|
await callback_query.message.answer(
|
||||||
"<b>Платёж найден в ЮKассе, но не в нашей БД</b>\n\n"
|
"<b>Оплата подтверждена, но не удалось активировать premium</b>\n\n"
|
||||||
"Пожалуйста, обратитесь к администратору.",
|
"Пожалуйста, обратитесь к администратору.",
|
||||||
parse_mode="HTML"
|
parse_mode="HTML"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
else:
|
||||||
print(f"Ошибка обработки платежа: {e}")
|
await callback_query.message.answer(
|
||||||
|
"<b>Оплата подтверждена, но не удалось активировать premium</b>\n\n"
|
||||||
|
"Пожалуйста, обратитесь к администратору.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Ошибка обработки платежа: {e}")
|
||||||
|
await callback_query.message.answer(
|
||||||
|
"<b>Ошибка активации premium</b>\n\n"
|
||||||
|
"Пожалуйста, обратитесь к администратору.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
|
||||||
elif payment.status == "pending":
|
elif payment.status == "pending":
|
||||||
await callback_query.message.answer(
|
await callback_query.message.answer(
|
||||||
@ -206,42 +187,13 @@ async def check_payment_status(callback_query: types.CallbackQuery):
|
|||||||
|
|
||||||
@router.message(Command("mypayments"))
|
@router.message(Command("mypayments"))
|
||||||
async def cmd_my_payments(message: Message):
|
async def cmd_my_payments(message: Message):
|
||||||
user_id = message.from_user.id
|
await message.answer(
|
||||||
|
"<b>История платежей</b>\n\n"
|
||||||
async with AsyncSessionLocal() as session:
|
"История платежей хранится в системе оплаты ЮKassa.\n"
|
||||||
try:
|
"Для проверки статуса подписки используйте команду /stats.\n\n"
|
||||||
result = await session.execute(
|
"Для оформления новой подписки используйте команду /buy",
|
||||||
select(PaymentModel).filter_by(user_id=user_id).order_by(PaymentModel.created_at.desc()).limit(10)
|
parse_mode="HTML"
|
||||||
)
|
)
|
||||||
payments = result.scalars().all()
|
|
||||||
|
|
||||||
if not payments:
|
|
||||||
await message.answer(
|
|
||||||
"<b>У вас пока нет платежей</b>\n\n"
|
|
||||||
"Используйте команду /buy чтобы оформить подписку.",
|
|
||||||
parse_mode="HTML"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
response = ["<b>Ваши последние платежи:</b>\n"]
|
|
||||||
|
|
||||||
for i, payment in enumerate(payments, 1):
|
|
||||||
status_text = "Успешно" if payment.status == "succeeded" else "Ожидание" if payment.status == "pending" else "Ошибка"
|
|
||||||
response.append(
|
|
||||||
f"\n<b>{i}. {payment.amount} руб. ({status_text})</b>\n"
|
|
||||||
f"Статус: {payment.status}\n"
|
|
||||||
f"Дата: {payment.created_at.strftime('%d.%m.%Y %H:%M')}\n"
|
|
||||||
f"ID: <code>{payment.payment_id[:8]}...</code>"
|
|
||||||
)
|
|
||||||
|
|
||||||
response.append("\n\n<i>Полный доступ открывается после успешной оплаты</i>")
|
|
||||||
|
|
||||||
await message.answer(
|
|
||||||
"\n".join(response),
|
|
||||||
parse_mode="HTML"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Ошибка получения платежей: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.message(Command("testcards"))
|
@router.message(Command("testcards"))
|
||||||
|
|||||||
@ -1,18 +1,36 @@
|
|||||||
from aiogram import Router
|
from aiogram import Router, F
|
||||||
from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton, CallbackQuery
|
from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton, CallbackQuery
|
||||||
from aiogram.filters import Command
|
from aiogram.filters import Command, StateFilter
|
||||||
|
from aiogram.fsm.context import FSMContext
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
from urllib.parse import unquote
|
||||||
|
from tg_bot.config.settings import settings
|
||||||
|
from tg_bot.infrastructure.http_client import create_http_session
|
||||||
|
from tg_bot.infrastructure.telegram.states.collection_states import (
|
||||||
|
CollectionAccessStates,
|
||||||
|
CollectionEditStates
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_title(title: str) -> str:
|
||||||
|
if not title:
|
||||||
|
return "Без названия"
|
||||||
|
try:
|
||||||
|
decoded = unquote(title)
|
||||||
|
if decoded != title or '%' not in title:
|
||||||
|
return decoded
|
||||||
|
return title
|
||||||
|
except Exception:
|
||||||
|
return title
|
||||||
|
|
||||||
router = Router()
|
router = Router()
|
||||||
|
|
||||||
BACKEND_URL = "http://localhost:8001/api/v1"
|
|
||||||
|
|
||||||
|
|
||||||
async def get_user_collections(telegram_id: str):
|
async def get_user_collections(telegram_id: str):
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f"{BACKEND_URL}/collections/",
|
f"{settings.BACKEND_URL}/collections/",
|
||||||
headers={"X-Telegram-ID": telegram_id}
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
) as response:
|
) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
@ -25,24 +43,37 @@ async def get_user_collections(telegram_id: str):
|
|||||||
|
|
||||||
async def get_collection_documents(collection_id: str, telegram_id: str):
|
async def get_collection_documents(collection_id: str, telegram_id: str):
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
collection_id = str(collection_id).strip()
|
||||||
|
url = f"{settings.BACKEND_URL}/documents/collection/{collection_id}"
|
||||||
|
print(f"DEBUG get_collection_documents: URL={url}, collection_id={collection_id}, telegram_id={telegram_id}")
|
||||||
|
|
||||||
|
async with create_http_session() as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f"{BACKEND_URL}/documents/collection/{collection_id}",
|
url,
|
||||||
headers={"X-Telegram-ID": telegram_id}
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
) as response:
|
) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
return await response.json()
|
return await response.json()
|
||||||
return []
|
elif response.status == 422:
|
||||||
|
error_text = await response.text()
|
||||||
|
print(f"Validation error getting documents: {response.status} - {error_text}, collection_id: {collection_id}, URL: {url}")
|
||||||
|
return []
|
||||||
|
else:
|
||||||
|
error_text = await response.text()
|
||||||
|
print(f"Error getting documents: {response.status} - {error_text}, collection_id: {collection_id}, URL: {url}")
|
||||||
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error getting documents: {e}")
|
print(f"Exception getting documents: {e}, collection_id: {collection_id}, type: {type(collection_id)}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
async def search_in_collection(collection_id: str, query: str, telegram_id: str):
|
async def search_in_collection(collection_id: str, query: str, telegram_id: str):
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with create_http_session() as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f"{BACKEND_URL}/documents/collection/{collection_id}",
|
f"{settings.BACKEND_URL}/documents/collection/{collection_id}",
|
||||||
params={"search": query},
|
params={"search": query},
|
||||||
headers={"X-Telegram-ID": telegram_id}
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
) as response:
|
) as response:
|
||||||
@ -54,6 +85,91 @@ async def search_in_collection(collection_id: str, query: str, telegram_id: str)
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def get_collection_info(collection_id: str, telegram_id: str):
|
||||||
|
"""Получить информацию о коллекции"""
|
||||||
|
try:
|
||||||
|
collection_id = str(collection_id).strip()
|
||||||
|
url = f"{settings.BACKEND_URL}/collections/{collection_id}"
|
||||||
|
print(f"DEBUG get_collection_info: URL={url}, collection_id={collection_id}, telegram_id={telegram_id}")
|
||||||
|
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.get(
|
||||||
|
url,
|
||||||
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return await response.json()
|
||||||
|
elif response.status == 422:
|
||||||
|
error_text = await response.text()
|
||||||
|
print(f"Validation error getting collection info: {response.status} - {error_text}, collection_id: {collection_id}, URL: {url}")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
error_text = await response.text()
|
||||||
|
print(f"Error getting collection info: {response.status} - {error_text}, collection_id: {collection_id}, URL: {url}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Exception getting collection info: {e}, collection_id: {collection_id}, type: {type(collection_id)}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_collection_access_list(collection_id: str, telegram_id: str):
|
||||||
|
"""Получить список пользователей с доступом к коллекции"""
|
||||||
|
try:
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.get(
|
||||||
|
f"{settings.BACKEND_URL}/collections/{collection_id}/access",
|
||||||
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return await response.json()
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error getting access list: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def grant_collection_access(collection_id: str, telegram_id: str, owner_telegram_id: str):
|
||||||
|
"""Предоставить доступ к коллекции"""
|
||||||
|
try:
|
||||||
|
url = f"{settings.BACKEND_URL}/collections/{collection_id}/access/telegram/{telegram_id}"
|
||||||
|
print(f"DEBUG grant_collection_access: URL={url}, target_telegram_id={telegram_id}, owner_telegram_id={owner_telegram_id}")
|
||||||
|
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.post(
|
||||||
|
url,
|
||||||
|
headers={"X-Telegram-ID": owner_telegram_id}
|
||||||
|
) as response:
|
||||||
|
if response.status == 201:
|
||||||
|
result = await response.json()
|
||||||
|
print(f"DEBUG: Access granted successfully: {result}")
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
error_text = await response.text()
|
||||||
|
print(f"ERROR granting access: status={response.status}, error={error_text}, target_telegram_id={telegram_id}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Exception granting access: {e}, target_telegram_id={telegram_id}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def revoke_collection_access(collection_id: str, telegram_id: str, owner_telegram_id: str):
|
||||||
|
"""Отозвать доступ к коллекции"""
|
||||||
|
try:
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.delete(
|
||||||
|
f"{settings.BACKEND_URL}/collections/{collection_id}/access/telegram/{telegram_id}",
|
||||||
|
headers={"X-Telegram-ID": owner_telegram_id}
|
||||||
|
) as response:
|
||||||
|
return response.status == 204
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error revoking access: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
@router.message(Command("mycollections"))
|
@router.message(Command("mycollections"))
|
||||||
async def cmd_mycollections(message: Message):
|
async def cmd_mycollections(message: Message):
|
||||||
telegram_id = str(message.from_user.id)
|
telegram_id = str(message.from_user.id)
|
||||||
@ -140,7 +256,7 @@ async def cmd_search(message: Message):
|
|||||||
|
|
||||||
response = f"<b>Результаты поиска:</b> \"{query}\"\n\n"
|
response = f"<b>Результаты поиска:</b> \"{query}\"\n\n"
|
||||||
for i, doc in enumerate(results[:5], 1):
|
for i, doc in enumerate(results[:5], 1):
|
||||||
title = doc.get("title", "Без названия")
|
title = decode_title(doc.get("title", "Без названия"))
|
||||||
content = doc.get("content", "")[:200]
|
content = doc.get("content", "")[:200]
|
||||||
response += f"{i}. <b>{title}</b>\n"
|
response += f"{i}. <b>{title}</b>\n"
|
||||||
response += f" <i>{content}...</i>\n\n"
|
response += f" <i>{content}...</i>\n\n"
|
||||||
@ -148,36 +264,495 @@ async def cmd_search(message: Message):
|
|||||||
await message.answer(response, parse_mode="HTML")
|
await message.answer(response, parse_mode="HTML")
|
||||||
|
|
||||||
|
|
||||||
@router.callback_query(lambda c: c.data.startswith("collection:"))
|
@router.callback_query(lambda c: c.data.startswith("collection:") and not c.data.startswith("collection:documents:") and not c.data.startswith("collection:edit:") and not c.data.startswith("collection:access:") and not c.data.startswith("collection:view_access:"))
|
||||||
async def show_collection_documents(callback: CallbackQuery):
|
async def show_collection_menu(callback: CallbackQuery):
|
||||||
collection_id = callback.data.split(":")[1]
|
"""Показать меню коллекции с опциями в зависимости от прав"""
|
||||||
|
parts = callback.data.split(":", 1)
|
||||||
|
if len(parts) < 2:
|
||||||
|
await callback.message.answer(
|
||||||
|
"<b>Ошибка</b>\n\nНеверный формат данных.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
await callback.answer()
|
||||||
|
return
|
||||||
|
|
||||||
|
collection_id = parts[1]
|
||||||
telegram_id = str(callback.from_user.id)
|
telegram_id = str(callback.from_user.id)
|
||||||
|
|
||||||
await callback.answer("Загружаю документы...")
|
print(f"DEBUG: collection_id from callback (menu): {collection_id}, callback_data: {callback.data}")
|
||||||
|
|
||||||
documents = await get_collection_documents(collection_id, telegram_id)
|
await callback.answer("Загружаю информацию...")
|
||||||
|
|
||||||
if not documents:
|
collection_info = await get_collection_info(collection_id, telegram_id)
|
||||||
|
if not collection_info:
|
||||||
await callback.message.answer(
|
await callback.message.answer(
|
||||||
f"<b>Коллекция пуста</b>\n\n"
|
"<b>Ошибка</b>\n\nНе удалось загрузить информацию о коллекции.",
|
||||||
f"В этой коллекции пока нет документов.\n"
|
|
||||||
f"Обратитесь к администратору для добавления документов.",
|
|
||||||
parse_mode="HTML"
|
parse_mode="HTML"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
owner_id = collection_info.get("owner_id")
|
||||||
|
collection_name = collection_info.get("name", "Коллекция")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.get(
|
||||||
|
f"{settings.BACKEND_URL}/users/telegram/{telegram_id}"
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
user_info = await response.json()
|
||||||
|
current_user_id = user_info.get("user_id")
|
||||||
|
is_owner = str(owner_id) == str(current_user_id)
|
||||||
|
else:
|
||||||
|
is_owner = False
|
||||||
|
except:
|
||||||
|
is_owner = False
|
||||||
|
|
||||||
|
keyboard_buttons = []
|
||||||
|
|
||||||
|
collection_id_str = str(collection_id)
|
||||||
|
|
||||||
|
if is_owner:
|
||||||
|
keyboard_buttons = [
|
||||||
|
[InlineKeyboardButton(text="Просмотр документов", callback_data=f"collection:documents:{collection_id_str}")],
|
||||||
|
[InlineKeyboardButton(text="Редактировать коллекцию", callback_data=f"collection:edit:{collection_id_str}")],
|
||||||
|
[InlineKeyboardButton(text="Управление доступом", callback_data=f"collection:access:{collection_id_str}")],
|
||||||
|
[InlineKeyboardButton(text="Загрузить документ", callback_data=f"document:upload:{collection_id_str}")],
|
||||||
|
[InlineKeyboardButton(text="Назад к коллекциям", callback_data="collections:list")]
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
keyboard_buttons = [
|
||||||
|
[InlineKeyboardButton(text="Просмотр документов", callback_data=f"collection:documents:{collection_id_str}")],
|
||||||
|
[InlineKeyboardButton(text="Просмотр доступа", callback_data=f"collection:view_access:{collection_id_str}")],
|
||||||
|
[InlineKeyboardButton(text="Назад к коллекциям", callback_data="collections:list")]
|
||||||
|
]
|
||||||
|
|
||||||
|
keyboard = InlineKeyboardMarkup(inline_keyboard=keyboard_buttons)
|
||||||
|
|
||||||
|
role_text = "<b>Владелец</b>" if is_owner else "<b>Доступ</b>"
|
||||||
|
response = f"<b>{collection_name}</b>\n\n"
|
||||||
|
response += f"{role_text}\n\n"
|
||||||
|
response += f"ID: <code>{collection_id}</code>\n\n"
|
||||||
|
response += "Выберите действие:"
|
||||||
|
|
||||||
|
await callback.message.answer(response, parse_mode="HTML", reply_markup=keyboard)
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(lambda c: c.data.startswith("collection:documents:"))
|
||||||
|
async def show_collection_documents(callback: CallbackQuery):
|
||||||
|
"""Показать документы коллекции"""
|
||||||
|
try:
|
||||||
|
parts = callback.data.split(":", 2)
|
||||||
|
if len(parts) < 3:
|
||||||
|
raise ValueError("Неверный формат callback_data")
|
||||||
|
|
||||||
|
collection_id = parts[2]
|
||||||
|
telegram_id = str(callback.from_user.id)
|
||||||
|
|
||||||
|
print(f"DEBUG: collection_id from callback: {collection_id}, callback_data: {callback.data}")
|
||||||
|
|
||||||
|
await callback.answer("Загружаю документы...")
|
||||||
|
|
||||||
|
collection_info = await get_collection_info(collection_id, telegram_id)
|
||||||
|
if not collection_info:
|
||||||
|
await callback.message.answer(
|
||||||
|
"<b>Ошибка</b>\n\nНе удалось загрузить информацию о коллекции. Проверьте, что у вас есть доступ к этой коллекции.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
documents = await get_collection_documents(collection_id, telegram_id)
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
await callback.message.answer(
|
||||||
|
f"<b>Коллекция пуста</b>\n\n"
|
||||||
|
f"В этой коллекции пока нет документов.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except IndexError:
|
||||||
|
await callback.message.answer(
|
||||||
|
"<b>Ошибка</b>\n\nНеверный формат данных.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
await callback.answer()
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in show_collection_documents: {e}")
|
||||||
|
await callback.message.answer(
|
||||||
|
f"<b>Ошибка</b>\n\nПроизошла ошибка при загрузке документов: {str(e)}",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
await callback.answer()
|
||||||
|
return
|
||||||
|
|
||||||
response = f"<b>Документы в коллекции:</b>\n\n"
|
response = f"<b>Документы в коллекции:</b>\n\n"
|
||||||
|
keyboard_buttons = []
|
||||||
|
|
||||||
for i, doc in enumerate(documents[:10], 1):
|
for i, doc in enumerate(documents[:10], 1):
|
||||||
title = doc.get("title", "Без названия")
|
doc_id = doc.get("document_id")
|
||||||
|
title = decode_title(doc.get("title", "Без названия"))
|
||||||
content_preview = doc.get("content", "")[:100]
|
content_preview = doc.get("content", "")[:100]
|
||||||
response += f"{i}. <b>{title}</b>\n"
|
response += f"{i}. <b>{title}</b>\n"
|
||||||
if content_preview:
|
if content_preview:
|
||||||
response += f" <i>{content_preview}...</i>\n"
|
response += f" <i>{content_preview}...</i>\n"
|
||||||
response += "\n"
|
response += "\n"
|
||||||
|
|
||||||
|
keyboard_buttons.append([
|
||||||
|
InlineKeyboardButton(
|
||||||
|
text=f"{title[:30]}",
|
||||||
|
callback_data=f"document:view:{doc_id}"
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
if len(documents) > 10:
|
if len(documents) > 10:
|
||||||
response += f"\n<i>Показано 10 из {len(documents)} документов</i>"
|
response += f"\n<i>Показано 10 из {len(documents)} документов</i>"
|
||||||
|
|
||||||
await callback.message.answer(response, parse_mode="HTML")
|
|
||||||
|
collection_id_for_back = str(collection_info.get("collection_id", collection_id))
|
||||||
|
keyboard_buttons.append([
|
||||||
|
InlineKeyboardButton(text="Назад", callback_data=f"collection:{collection_id_for_back}")
|
||||||
|
])
|
||||||
|
|
||||||
|
keyboard = InlineKeyboardMarkup(inline_keyboard=keyboard_buttons)
|
||||||
|
await callback.message.answer(response, parse_mode="HTML", reply_markup=keyboard)
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(lambda c: c.data.startswith("collection:access:"))
|
||||||
|
async def show_access_management(callback: CallbackQuery):
|
||||||
|
"""Показать меню управления доступом (только для владельца)"""
|
||||||
|
collection_id = callback.data.split(":")[2]
|
||||||
|
telegram_id = str(callback.from_user.id)
|
||||||
|
|
||||||
|
await callback.answer("Загружаю список доступа...")
|
||||||
|
|
||||||
|
access_list = await get_collection_access_list(collection_id, telegram_id)
|
||||||
|
|
||||||
|
response = "<b>Управление доступом</b>\n\n"
|
||||||
|
response += "<b>Пользователи с доступом:</b>\n\n"
|
||||||
|
|
||||||
|
keyboard_buttons = []
|
||||||
|
|
||||||
|
if access_list:
|
||||||
|
for i, access in enumerate(access_list[:10], 1):
|
||||||
|
user = access.get("user", {})
|
||||||
|
user_telegram_id = user.get("telegram_id", "N/A")
|
||||||
|
role = user.get("role", "user")
|
||||||
|
response += f"{i}. <code>{user_telegram_id}</code> ({role})\n"
|
||||||
|
|
||||||
|
keyboard_buttons.append([
|
||||||
|
InlineKeyboardButton(
|
||||||
|
text=f" Удалить {user_telegram_id}",
|
||||||
|
callback_data=f"access:remove:{collection_id}:{user_telegram_id}"
|
||||||
|
)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
response += "<i>Нет пользователей с доступом</i>\n\n"
|
||||||
|
|
||||||
|
keyboard_buttons.extend([
|
||||||
|
[InlineKeyboardButton(text="Добавить доступ", callback_data=f"access:add:{collection_id}")],
|
||||||
|
[InlineKeyboardButton(text="Назад", callback_data=f"collection:{collection_id}")]
|
||||||
|
])
|
||||||
|
|
||||||
|
keyboard = InlineKeyboardMarkup(inline_keyboard=keyboard_buttons)
|
||||||
|
await callback.message.answer(response, parse_mode="HTML", reply_markup=keyboard)
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(lambda c: c.data.startswith("collection:view_access:"))
|
||||||
|
async def show_access_list(callback: CallbackQuery):
|
||||||
|
"""Показать список пользователей с доступом (read-only для пользователей с доступом)"""
|
||||||
|
collection_id = callback.data.split(":")[2]
|
||||||
|
telegram_id = str(callback.from_user.id)
|
||||||
|
|
||||||
|
await callback.answer("Загружаю список доступа...")
|
||||||
|
|
||||||
|
access_list = await get_collection_access_list(collection_id, telegram_id)
|
||||||
|
|
||||||
|
response = "<b>Пользователи с доступом</b>\n\n"
|
||||||
|
|
||||||
|
if access_list:
|
||||||
|
for i, access in enumerate(access_list[:20], 1):
|
||||||
|
user = access.get("user", {})
|
||||||
|
user_telegram_id = user.get("telegram_id", "N/A")
|
||||||
|
role = user.get("role", "user")
|
||||||
|
response += f"{i}. <code>{user_telegram_id}</code> ({role})\n"
|
||||||
|
else:
|
||||||
|
response += "<i>Нет пользователей с доступом</i>\n"
|
||||||
|
|
||||||
|
keyboard = InlineKeyboardMarkup(inline_keyboard=[[
|
||||||
|
InlineKeyboardButton(text="Назад", callback_data=f"collection:{collection_id}")
|
||||||
|
]])
|
||||||
|
await callback.message.answer(response, parse_mode="HTML", reply_markup=keyboard)
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(lambda c: c.data.startswith("access:add:"))
|
||||||
|
async def add_access_prompt(callback: CallbackQuery, state: FSMContext):
|
||||||
|
"""Запросить пересылку сообщения для добавления доступа"""
|
||||||
|
collection_id = callback.data.split(":")[2]
|
||||||
|
telegram_id = str(callback.from_user.id)
|
||||||
|
|
||||||
|
await state.update_data(collection_id=collection_id)
|
||||||
|
await state.set_state(CollectionAccessStates.waiting_for_username)
|
||||||
|
|
||||||
|
await callback.message.answer(
|
||||||
|
"<b>Добавить доступ</b>\n\n"
|
||||||
|
"Перешлите любое сообщение от пользователя, которому нужно предоставить доступ.\n\n"
|
||||||
|
"<i>Просто перешлите сообщение от нужного пользователя.</i>",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
await callback.answer()
|
||||||
|
|
||||||
|
|
||||||
|
@router.message(StateFilter(CollectionAccessStates.waiting_for_username))
|
||||||
|
async def process_add_access(message: Message, state: FSMContext):
|
||||||
|
"""Обработать добавление доступа через пересылку сообщения"""
|
||||||
|
telegram_id = str(message.from_user.id)
|
||||||
|
data = await state.get_data()
|
||||||
|
collection_id = data.get("collection_id")
|
||||||
|
|
||||||
|
if not collection_id:
|
||||||
|
await message.answer("Ошибка: не указана коллекция")
|
||||||
|
await state.clear()
|
||||||
|
return
|
||||||
|
|
||||||
|
target_telegram_id = None
|
||||||
|
|
||||||
|
if message.forward_from:
|
||||||
|
target_telegram_id = str(message.forward_from.id)
|
||||||
|
elif message.forward_from_chat:
|
||||||
|
await message.answer(
|
||||||
|
"<b>Ошибка</b>\n\n"
|
||||||
|
"Пожалуйста, перешлите сообщение от пользователя, а не из группы или канала.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
await state.clear()
|
||||||
|
return
|
||||||
|
elif message.forward_date:
|
||||||
|
await message.answer(
|
||||||
|
"<b>Информация о пересылке скрыта</b>\n\n"
|
||||||
|
"Пользователь скрыл информацию о пересылке в настройках приватности Telegram.\n\n"
|
||||||
|
"Попросите пользователя временно разрешить пересылку сообщений.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
await state.clear()
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
await message.answer(
|
||||||
|
"<b>Ошибка</b>\n\n"
|
||||||
|
"Пожалуйста, перешлите сообщение от пользователя, которому нужно предоставить доступ.\n\n"
|
||||||
|
"<i>Просто перешлите любое сообщение от нужного пользователя.</i>",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
await state.clear()
|
||||||
|
return
|
||||||
|
|
||||||
|
if not target_telegram_id:
|
||||||
|
await message.answer(
|
||||||
|
"<b>Ошибка</b>\n\n"
|
||||||
|
"Не удалось определить Telegram ID пользователя.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
await state.clear()
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"DEBUG: Attempting to grant access: collection_id={collection_id}, target_telegram_id={target_telegram_id}, owner_telegram_id={telegram_id}")
|
||||||
|
result = await grant_collection_access(collection_id, target_telegram_id, telegram_id)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
user_info = ""
|
||||||
|
if message.forward_from:
|
||||||
|
user_name = message.forward_from.first_name or ""
|
||||||
|
user_username = f"@{message.forward_from.username}" if message.forward_from.username else ""
|
||||||
|
user_info = f"{user_name} {user_username}".strip() or target_telegram_id
|
||||||
|
else:
|
||||||
|
user_info = target_telegram_id
|
||||||
|
|
||||||
|
await message.answer(
|
||||||
|
f"<b>Доступ предоставлен</b>\n\n"
|
||||||
|
f"Пользователю <code>{target_telegram_id}</code> предоставлен доступ к коллекции.\n\n"
|
||||||
|
f"Пользователь: {user_info}\n\n"
|
||||||
|
f"<i>Примечание: Если пользователь еще не взаимодействовал с ботом, он был автоматически создан в системе.</i>",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await message.answer(
|
||||||
|
"<b>Ошибка</b>\n\n"
|
||||||
|
"Не удалось предоставить доступ. Возможно:\n"
|
||||||
|
"• Доступ уже предоставлен\n"
|
||||||
|
"• Произошла ошибка на сервере\n"
|
||||||
|
"• Вы не являетесь владельцем коллекции\n\n"
|
||||||
|
"Проверьте логи сервера для получения подробной информации.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
|
||||||
|
await state.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(lambda c: c.data.startswith("access:remove:"))
|
||||||
|
async def remove_access(callback: CallbackQuery):
|
||||||
|
"""Удалить доступ пользователя"""
|
||||||
|
parts = callback.data.split(":")
|
||||||
|
collection_id = parts[2]
|
||||||
|
target_telegram_id = parts[3]
|
||||||
|
owner_telegram_id = str(callback.from_user.id)
|
||||||
|
|
||||||
|
await callback.answer("Удаляю доступ...")
|
||||||
|
|
||||||
|
result = await revoke_collection_access(collection_id, target_telegram_id, owner_telegram_id)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
await callback.message.answer(
|
||||||
|
f"<b>Доступ отозван</b>\n\n"
|
||||||
|
f"Доступ пользователя <code>{target_telegram_id}</code> отозван.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await callback.message.answer(
|
||||||
|
"<b>Ошибка</b>\n\n"
|
||||||
|
"Не удалось отозвать доступ.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(lambda c: c.data.startswith("collection:edit:"))
|
||||||
|
async def edit_collection_prompt(callback: CallbackQuery, state: FSMContext):
|
||||||
|
"""Запросить данные для редактирования коллекции"""
|
||||||
|
collection_id = callback.data.split(":")[2]
|
||||||
|
telegram_id = str(callback.from_user.id)
|
||||||
|
|
||||||
|
collection_info = await get_collection_info(collection_id, telegram_id)
|
||||||
|
if not collection_info:
|
||||||
|
await callback.message.answer(
|
||||||
|
"<b>Ошибка</b>\n\nНе удалось загрузить информацию о коллекции.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
await callback.answer()
|
||||||
|
return
|
||||||
|
|
||||||
|
await state.update_data(collection_id=collection_id)
|
||||||
|
await state.set_state(CollectionEditStates.waiting_for_name)
|
||||||
|
|
||||||
|
await callback.message.answer(
|
||||||
|
"<b>Редактирование коллекции</b>\n\n"
|
||||||
|
"Отправьте новое название коллекции или /skip чтобы оставить текущее.\n\n"
|
||||||
|
f"Текущее название: <b>{collection_info.get('name', 'Без названия')}</b>",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
await callback.answer()
|
||||||
|
|
||||||
|
|
||||||
|
@router.message(StateFilter(CollectionEditStates.waiting_for_name))
|
||||||
|
async def process_edit_collection_name(message: Message, state: FSMContext):
|
||||||
|
"""Обработать новое название коллекции"""
|
||||||
|
telegram_id = str(message.from_user.id)
|
||||||
|
data = await state.get_data()
|
||||||
|
collection_id = data.get("collection_id")
|
||||||
|
|
||||||
|
if message.text and message.text.strip() == "/skip":
|
||||||
|
new_name = None
|
||||||
|
else:
|
||||||
|
new_name = message.text.strip() if message.text else None
|
||||||
|
|
||||||
|
await state.update_data(name=new_name)
|
||||||
|
await state.set_state(CollectionEditStates.waiting_for_description)
|
||||||
|
|
||||||
|
collection_info = await get_collection_info(collection_id, telegram_id)
|
||||||
|
current_description = collection_info.get("description", "") if collection_info else ""
|
||||||
|
|
||||||
|
await message.answer(
|
||||||
|
"<b>Описание коллекции</b>\n\n"
|
||||||
|
"Отправьте новое описание коллекции или /skip чтобы оставить текущее.\n\n"
|
||||||
|
f"Текущее описание: <i>{current_description[:100] if current_description else 'Нет описания'}...</i>",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.message(StateFilter(CollectionEditStates.waiting_for_description))
|
||||||
|
async def process_edit_collection_description(message: Message, state: FSMContext):
|
||||||
|
"""Обработать новое описание коллекции"""
|
||||||
|
telegram_id = str(message.from_user.id)
|
||||||
|
data = await state.get_data()
|
||||||
|
collection_id = data.get("collection_id")
|
||||||
|
name = data.get("name")
|
||||||
|
|
||||||
|
if message.text and message.text.strip() == "/skip":
|
||||||
|
new_description = None
|
||||||
|
else:
|
||||||
|
new_description = message.text.strip() if message.text else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
update_data = {}
|
||||||
|
if name:
|
||||||
|
update_data["name"] = name
|
||||||
|
if new_description:
|
||||||
|
update_data["description"] = new_description
|
||||||
|
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.put(
|
||||||
|
f"{settings.BACKEND_URL}/collections/{collection_id}",
|
||||||
|
json=update_data,
|
||||||
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
await message.answer(
|
||||||
|
"<b>Коллекция обновлена</b>\n\n"
|
||||||
|
"Изменения сохранены.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
error_text = await response.text()
|
||||||
|
await message.answer(
|
||||||
|
f"<b>Ошибка</b>\n\n"
|
||||||
|
f"Не удалось обновить коллекцию: {error_text}",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
await message.answer(
|
||||||
|
f"<b>Ошибка</b>\n\n"
|
||||||
|
f"Произошла ошибка: {str(e)}",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
|
||||||
|
await state.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(lambda c: c.data == "collections:list")
|
||||||
|
async def back_to_collections(callback: CallbackQuery):
|
||||||
|
"""Вернуться к списку коллекций"""
|
||||||
|
telegram_id = str(callback.from_user.id)
|
||||||
|
collections = await get_user_collections(telegram_id)
|
||||||
|
|
||||||
|
if not collections:
|
||||||
|
await callback.message.answer(
|
||||||
|
"<b>У вас пока нет коллекций</b>\n\n"
|
||||||
|
"Обратитесь к администратору для создания коллекций и добавления документов.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
response = "<b>Ваши коллекции документов:</b>\n\n"
|
||||||
|
keyboard_buttons = []
|
||||||
|
|
||||||
|
for i, collection in enumerate(collections[:10], 1):
|
||||||
|
name = collection.get("name", "Без названия")
|
||||||
|
description = collection.get("description", "")
|
||||||
|
collection_id = collection.get("collection_id")
|
||||||
|
|
||||||
|
response += f"{i}. <b>{name}</b>\n"
|
||||||
|
if description:
|
||||||
|
response += f" <i>{description[:50]}...</i>\n"
|
||||||
|
response += f" ID: <code>{collection_id}</code>\n\n"
|
||||||
|
|
||||||
|
keyboard_buttons.append([
|
||||||
|
InlineKeyboardButton(
|
||||||
|
text=f"{name}",
|
||||||
|
callback_data=f"collection:{collection_id}"
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
keyboard = InlineKeyboardMarkup(inline_keyboard=keyboard_buttons)
|
||||||
|
response += "<i>Нажмите на коллекцию, чтобы посмотреть документы</i>"
|
||||||
|
|
||||||
|
await callback.message.answer(response, parse_mode="HTML", reply_markup=keyboard)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
396
tg_bot/infrastructure/telegram/handlers/document_handler.py
Normal file
396
tg_bot/infrastructure/telegram/handlers/document_handler.py
Normal file
@ -0,0 +1,396 @@
|
|||||||
|
"""
|
||||||
|
Обработчики для работы с документами
|
||||||
|
"""
|
||||||
|
from aiogram import Router, F
|
||||||
|
from aiogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton, CallbackQuery
|
||||||
|
from aiogram.filters import StateFilter
|
||||||
|
from aiogram.fsm.context import FSMContext
|
||||||
|
import aiohttp
|
||||||
|
from urllib.parse import unquote
|
||||||
|
from tg_bot.config.settings import settings
|
||||||
|
from tg_bot.infrastructure.http_client import create_http_session
|
||||||
|
from tg_bot.infrastructure.telegram.states.collection_states import (
|
||||||
|
DocumentEditStates,
|
||||||
|
DocumentUploadStates
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_title(title: str) -> str:
|
||||||
|
"""Декодирует URL-encoded название документа"""
|
||||||
|
if not title:
|
||||||
|
return "Без названия"
|
||||||
|
try:
|
||||||
|
decoded = unquote(title)
|
||||||
|
if decoded != title or '%' not in title:
|
||||||
|
return decoded
|
||||||
|
return title
|
||||||
|
except Exception:
|
||||||
|
return title
|
||||||
|
router = Router()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_document_info(document_id: str, telegram_id: str):
|
||||||
|
"""Получить информацию о документе"""
|
||||||
|
try:
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.get(
|
||||||
|
f"{settings.BACKEND_URL}/documents/{document_id}",
|
||||||
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return await response.json()
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error getting document info: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_document(document_id: str, telegram_id: str):
|
||||||
|
"""Удалить документ"""
|
||||||
|
try:
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.delete(
|
||||||
|
f"{settings.BACKEND_URL}/documents/{document_id}",
|
||||||
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
|
) as response:
|
||||||
|
return response.status == 204
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error deleting document: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def update_document(document_id: str, telegram_id: str, title: str = None, content: str = None):
|
||||||
|
"""Обновить документ"""
|
||||||
|
try:
|
||||||
|
update_data = {}
|
||||||
|
if title:
|
||||||
|
update_data["title"] = title
|
||||||
|
if content:
|
||||||
|
update_data["content"] = content
|
||||||
|
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.put(
|
||||||
|
f"{settings.BACKEND_URL}/documents/{document_id}",
|
||||||
|
json=update_data,
|
||||||
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return await response.json()
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error updating document: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def upload_document_to_collection(collection_id: str, file_data: bytes, filename: str, telegram_id: str):
|
||||||
|
"""Загрузить документ в коллекцию"""
|
||||||
|
try:
|
||||||
|
async with create_http_session() as session:
|
||||||
|
form_data = aiohttp.FormData()
|
||||||
|
form_data.add_field('file', file_data, filename=filename, content_type='application/octet-stream')
|
||||||
|
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.BACKEND_URL}/documents/upload?collection_id={collection_id}",
|
||||||
|
data=form_data,
|
||||||
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
|
) as response:
|
||||||
|
if response.status == 201:
|
||||||
|
return await response.json()
|
||||||
|
else:
|
||||||
|
error_text = await response.text()
|
||||||
|
print(f"Upload error: {response.status} - {error_text}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error uploading document: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(lambda c: c.data.startswith("document:view:"))
|
||||||
|
async def view_document(callback: CallbackQuery):
|
||||||
|
"""Просмотр документа"""
|
||||||
|
document_id = callback.data.split(":")[2]
|
||||||
|
telegram_id = str(callback.from_user.id)
|
||||||
|
|
||||||
|
await callback.answer("Загружаю документ...")
|
||||||
|
|
||||||
|
document = await get_document_info(document_id, telegram_id)
|
||||||
|
if not document:
|
||||||
|
await callback.message.answer(
|
||||||
|
"<b>Ошибка</b>\n\nНе удалось загрузить документ.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
title = decode_title(document.get("title", "Без названия"))
|
||||||
|
content = document.get("content", "")
|
||||||
|
collection_id = document.get("collection_id")
|
||||||
|
|
||||||
|
content_preview = content[:2000] if len(content) > 2000 else content
|
||||||
|
has_more = len(content) > 2000
|
||||||
|
|
||||||
|
response = f"<b>{title}</b>\n\n"
|
||||||
|
response += f"<i>{content_preview}</i>"
|
||||||
|
if has_more:
|
||||||
|
response += "\n\n<i>...</i>"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with create_http_session() as session:
|
||||||
|
async with session.get(
|
||||||
|
f"{settings.BACKEND_URL}/collections/{collection_id}",
|
||||||
|
headers={"X-Telegram-ID": telegram_id}
|
||||||
|
) as response_collection:
|
||||||
|
if response_collection.status == 200:
|
||||||
|
collection_info = await response_collection.json()
|
||||||
|
owner_id = collection_info.get("owner_id")
|
||||||
|
|
||||||
|
async with session.get(
|
||||||
|
f"{settings.BACKEND_URL}/users/telegram/{telegram_id}"
|
||||||
|
) as response_user:
|
||||||
|
if response_user.status == 200:
|
||||||
|
user_info = await response_user.json()
|
||||||
|
current_user_id = user_info.get("user_id")
|
||||||
|
is_owner = str(owner_id) == str(current_user_id)
|
||||||
|
|
||||||
|
keyboard_buttons = []
|
||||||
|
if is_owner:
|
||||||
|
keyboard_buttons = [
|
||||||
|
[InlineKeyboardButton(text="Редактировать", callback_data=f"document:edit:{document_id}")],
|
||||||
|
[InlineKeyboardButton(text="Удалить", callback_data=f"document:delete:{document_id}")],
|
||||||
|
[InlineKeyboardButton(text="Назад", callback_data=f"collection:documents:{collection_id}")]
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
keyboard_buttons = [
|
||||||
|
[InlineKeyboardButton(text="Редактировать", callback_data=f"document:edit:{document_id}")],
|
||||||
|
[InlineKeyboardButton(text="Назад", callback_data=f"collection:documents:{collection_id}")]
|
||||||
|
]
|
||||||
|
|
||||||
|
keyboard = InlineKeyboardMarkup(inline_keyboard=keyboard_buttons)
|
||||||
|
await callback.message.answer(response, parse_mode="HTML", reply_markup=keyboard)
|
||||||
|
return
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
keyboard = InlineKeyboardMarkup(inline_keyboard=[[
|
||||||
|
InlineKeyboardButton(text="Назад", callback_data=f"collection:documents:{collection_id}")
|
||||||
|
]])
|
||||||
|
await callback.message.answer(response, parse_mode="HTML", reply_markup=keyboard)
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(lambda c: c.data.startswith("document:edit:"))
|
||||||
|
async def edit_document_prompt(callback: CallbackQuery, state: FSMContext):
|
||||||
|
"""Запросить данные для редактирования документа"""
|
||||||
|
document_id = callback.data.split(":")[2]
|
||||||
|
telegram_id = str(callback.from_user.id)
|
||||||
|
|
||||||
|
document = await get_document_info(document_id, telegram_id)
|
||||||
|
if not document:
|
||||||
|
await callback.message.answer(
|
||||||
|
"<b>Ошибка</b>\n\nНе удалось загрузить документ.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
await callback.answer()
|
||||||
|
return
|
||||||
|
|
||||||
|
await state.update_data(document_id=document_id)
|
||||||
|
await state.set_state(DocumentEditStates.waiting_for_title)
|
||||||
|
|
||||||
|
await callback.message.answer(
|
||||||
|
"<b>Редактирование документа</b>\n\n"
|
||||||
|
"Отправьте новое название документа или /skip чтобы оставить текущее.\n\n"
|
||||||
|
f"Текущее название: <b>{decode_title(document.get('title', 'Без названия'))}</b>",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
await callback.answer()
|
||||||
|
|
||||||
|
|
||||||
|
@router.message(StateFilter(DocumentEditStates.waiting_for_title))
|
||||||
|
async def process_edit_title(message: Message, state: FSMContext):
|
||||||
|
"""Обработать новое название документа"""
|
||||||
|
telegram_id = str(message.from_user.id)
|
||||||
|
data = await state.get_data()
|
||||||
|
document_id = data.get("document_id")
|
||||||
|
|
||||||
|
if message.text and message.text.strip() == "/skip":
|
||||||
|
new_title = None
|
||||||
|
else:
|
||||||
|
new_title = message.text.strip() if message.text else None
|
||||||
|
|
||||||
|
await state.update_data(title=new_title)
|
||||||
|
await state.set_state(DocumentEditStates.waiting_for_content)
|
||||||
|
|
||||||
|
await message.answer(
|
||||||
|
"<b>Содержимое документа</b>\n\n"
|
||||||
|
"Отправьте новое содержимое документа или /skip чтобы оставить текущее.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.message(StateFilter(DocumentEditStates.waiting_for_content))
|
||||||
|
async def process_edit_content(message: Message, state: FSMContext):
|
||||||
|
"""Обработать новое содержимое документа"""
|
||||||
|
telegram_id = str(message.from_user.id)
|
||||||
|
data = await state.get_data()
|
||||||
|
document_id = data.get("document_id")
|
||||||
|
title = data.get("title")
|
||||||
|
|
||||||
|
if message.text and message.text.strip() == "/skip":
|
||||||
|
new_content = None
|
||||||
|
else:
|
||||||
|
new_content = message.text.strip() if message.text else None
|
||||||
|
|
||||||
|
result = await update_document(document_id, telegram_id, title=title, content=new_content)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
await message.answer(
|
||||||
|
"<b>Документ обновлен</b>\n\n"
|
||||||
|
"Изменения сохранены.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await message.answer(
|
||||||
|
"<b>Ошибка</b>\n\n"
|
||||||
|
"Не удалось обновить документ.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
|
||||||
|
await state.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(lambda c: c.data.startswith("document:delete:"))
|
||||||
|
async def delete_document_confirm(callback: CallbackQuery):
|
||||||
|
"""Подтверждение удаления документа"""
|
||||||
|
document_id = callback.data.split(":")[2]
|
||||||
|
|
||||||
|
keyboard = InlineKeyboardMarkup(inline_keyboard=[
|
||||||
|
[InlineKeyboardButton(text="Да, удалить", callback_data=f"document:delete_confirm:{document_id}")],
|
||||||
|
[InlineKeyboardButton(text="Отмена", callback_data=f"document:view:{document_id}")]
|
||||||
|
])
|
||||||
|
|
||||||
|
await callback.message.answer(
|
||||||
|
"<b>Подтверждение удаления</b>\n\n"
|
||||||
|
"Вы уверены, что хотите удалить этот документ?",
|
||||||
|
parse_mode="HTML",
|
||||||
|
reply_markup=keyboard
|
||||||
|
)
|
||||||
|
await callback.answer()
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(lambda c: c.data.startswith("document:delete_confirm:"))
|
||||||
|
async def delete_document_execute(callback: CallbackQuery):
|
||||||
|
"""Выполнить удаление документа"""
|
||||||
|
document_id = callback.data.split(":")[2]
|
||||||
|
telegram_id = str(callback.from_user.id)
|
||||||
|
|
||||||
|
await callback.answer("Удаляю документ...")
|
||||||
|
|
||||||
|
# Получаем информацию о документе для возврата к коллекции
|
||||||
|
document = await get_document_info(document_id, telegram_id)
|
||||||
|
collection_id = document.get("collection_id") if document else None
|
||||||
|
|
||||||
|
result = await delete_document(document_id, telegram_id)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
await callback.message.answer(
|
||||||
|
"<b>Документ удален</b>",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await callback.message.answer(
|
||||||
|
"<b>Ошибка</b>\n\n"
|
||||||
|
"Не удалось удалить документ.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(lambda c: c.data.startswith("document:upload:"))
|
||||||
|
async def upload_document_prompt(callback: CallbackQuery, state: FSMContext):
|
||||||
|
"""Запросить файл для загрузки"""
|
||||||
|
collection_id = callback.data.split(":")[2]
|
||||||
|
telegram_id = str(callback.from_user.id)
|
||||||
|
|
||||||
|
await state.update_data(collection_id=collection_id)
|
||||||
|
await state.set_state(DocumentUploadStates.waiting_for_file)
|
||||||
|
|
||||||
|
await callback.message.answer(
|
||||||
|
"<b>Загрузка документа</b>\n\n"
|
||||||
|
"Отправьте файл (PDF, PNG, JPG, JPEG, TIFF, BMP).\n\n"
|
||||||
|
"Поддерживаемые форматы:\n"
|
||||||
|
"• PDF\n"
|
||||||
|
"• Изображения: PNG, JPG, JPEG, TIFF, BMP",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
await callback.answer()
|
||||||
|
|
||||||
|
|
||||||
|
@router.message(StateFilter(DocumentUploadStates.waiting_for_file), F.document | F.photo)
|
||||||
|
async def process_upload_document(message: Message, state: FSMContext):
|
||||||
|
"""Обработать загрузку документа"""
|
||||||
|
telegram_id = str(message.from_user.id)
|
||||||
|
data = await state.get_data()
|
||||||
|
collection_id = data.get("collection_id")
|
||||||
|
|
||||||
|
if not collection_id:
|
||||||
|
await message.answer("Ошибка: не указана коллекция")
|
||||||
|
await state.clear()
|
||||||
|
return
|
||||||
|
|
||||||
|
file_id = None
|
||||||
|
filename = None
|
||||||
|
|
||||||
|
if message.document:
|
||||||
|
file_id = message.document.file_id
|
||||||
|
filename = message.document.file_name or "document.pdf"
|
||||||
|
|
||||||
|
supported_extensions = ['.pdf', '.png', '.jpg', '.jpeg', '.tiff', '.bmp']
|
||||||
|
file_ext = filename.lower().rsplit('.', 1)[-1] if '.' in filename else ''
|
||||||
|
if f'.{file_ext}' not in supported_extensions:
|
||||||
|
await message.answer(
|
||||||
|
"<b>Ошибка</b>\n\n"
|
||||||
|
f"Неподдерживаемый формат файла: {file_ext}\n\n"
|
||||||
|
"Поддерживаются: PDF, PNG, JPG, JPEG, TIFF, BMP",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
await state.clear()
|
||||||
|
return
|
||||||
|
elif message.photo:
|
||||||
|
file_id = message.photo[-1].file_id
|
||||||
|
filename = "photo.jpg"
|
||||||
|
else:
|
||||||
|
await message.answer(
|
||||||
|
"<b>Ошибка</b>\n\n"
|
||||||
|
"Пожалуйста, отправьте файл (PDF или изображение).",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
await state.clear()
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
file = await message.bot.get_file(file_id)
|
||||||
|
file_data = await message.bot.download_file(file.file_path)
|
||||||
|
file_bytes = file_data.read()
|
||||||
|
|
||||||
|
result = await upload_document_to_collection(collection_id, file_bytes, filename, telegram_id)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
await message.answer(
|
||||||
|
f"<b>✅ Документ загружен и добавлен в коллекцию</b>\n\n"
|
||||||
|
f"<b>Название:</b> {decode_title(result.get('title', filename))}\n\n"
|
||||||
|
f"📄 Документ сейчас индексируется. Вы получите уведомление, когда индексация завершится.\n\n",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await message.answer(
|
||||||
|
"<b>Ошибка</b>\n\n"
|
||||||
|
"Не удалось загрузить документ.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error uploading document: {e}")
|
||||||
|
await message.answer(
|
||||||
|
"<b>Ошибка</b>\n\n"
|
||||||
|
f"Произошла ошибка при загрузке: {str(e)}",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
|
||||||
|
await state.clear()
|
||||||
|
|
||||||
@ -1,16 +1,14 @@
|
|||||||
from aiogram import Router, types
|
from aiogram import Router, types
|
||||||
from aiogram.types import Message
|
from aiogram.types import Message
|
||||||
from datetime import datetime
|
|
||||||
import aiohttp
|
|
||||||
from tg_bot.config.settings import settings
|
from tg_bot.config.settings import settings
|
||||||
from tg_bot.infrastructure.database.database import AsyncSessionLocal
|
from tg_bot.domain.user_service import UserService, User
|
||||||
from tg_bot.infrastructure.database.models import UserModel
|
|
||||||
from tg_bot.domain.services.user_service import UserService
|
|
||||||
from tg_bot.application.services.rag_service import RAGService
|
from tg_bot.application.services.rag_service import RAGService
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
router = Router()
|
router = Router()
|
||||||
BACKEND_URL = "http://localhost:8001/api/v1"
|
|
||||||
rag_service = RAGService()
|
rag_service = RAGService()
|
||||||
|
user_service = UserService()
|
||||||
|
|
||||||
@router.message()
|
@router.message()
|
||||||
async def handle_question(message: Message):
|
async def handle_question(message: Message):
|
||||||
@ -19,58 +17,37 @@ async def handle_question(message: Message):
|
|||||||
if question_text.startswith('/'):
|
if question_text.startswith('/'):
|
||||||
return
|
return
|
||||||
|
|
||||||
async with AsyncSessionLocal() as session:
|
try:
|
||||||
try:
|
user = await user_service.get_user_by_telegram_id(user_id)
|
||||||
user_service = UserService(session)
|
|
||||||
user = await user_service.get_user_by_telegram_id(user_id)
|
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
user = await user_service.get_or_create_user(
|
user = await user_service.get_or_create_user(
|
||||||
user_id,
|
user_id,
|
||||||
message.from_user.username or "",
|
message.from_user.username or "",
|
||||||
message.from_user.first_name or "",
|
message.from_user.first_name or "",
|
||||||
message.from_user.last_name or ""
|
message.from_user.last_name or ""
|
||||||
)
|
|
||||||
await ensure_user_in_backend(str(user_id), message.from_user)
|
|
||||||
|
|
||||||
if user.is_premium:
|
|
||||||
await process_premium_question(message, user, question_text, user_service)
|
|
||||||
|
|
||||||
elif user.questions_used < settings.FREE_QUESTIONS_LIMIT:
|
|
||||||
await process_free_question(message, user, question_text, user_service)
|
|
||||||
|
|
||||||
else:
|
|
||||||
await handle_limit_exceeded(message, user)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing question: {e}")
|
|
||||||
await message.answer(
|
|
||||||
"Произошла ошибка. Попробуйте позже.",
|
|
||||||
parse_mode="HTML"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if user.is_premium:
|
||||||
|
await process_premium_question(message, user, question_text)
|
||||||
|
|
||||||
async def ensure_user_in_backend(telegram_id: str, telegram_user):
|
elif user.questions_used < settings.FREE_QUESTIONS_LIMIT:
|
||||||
try:
|
await process_free_question(message, user, question_text)
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(
|
else:
|
||||||
f"{BACKEND_URL}/users/telegram/{telegram_id}"
|
await handle_limit_exceeded(message, user)
|
||||||
) as response:
|
|
||||||
if response.status == 200:
|
|
||||||
return
|
|
||||||
|
|
||||||
async with session.post(
|
|
||||||
f"{BACKEND_URL}/users",
|
|
||||||
json={"telegram_id": telegram_id, "role": "user"}
|
|
||||||
) as create_response:
|
|
||||||
if create_response.status in [200, 201]:
|
|
||||||
print(f"Пользователь {telegram_id} создан в backend")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error creating user in backend: {e}")
|
print(f"Error processing question: {e}")
|
||||||
|
await message.answer(
|
||||||
|
"Произошла ошибка. Попробуйте позже.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def process_premium_question(message: Message, user: UserModel, question_text: str, user_service: UserService):
|
async def process_premium_question(message: Message, user: User, question_text: str):
|
||||||
await user_service.update_user_questions(user.telegram_id)
|
await user_service.update_user_questions(int(user.telegram_id))
|
||||||
|
user = await user_service.get_user_by_telegram_id(int(user.telegram_id))
|
||||||
|
|
||||||
await message.bot.send_chat_action(message.chat.id, "typing")
|
await message.bot.send_chat_action(message.chat.id, "typing")
|
||||||
|
|
||||||
@ -83,37 +60,41 @@ async def process_premium_question(message: Message, user: UserModel, question_t
|
|||||||
answer = rag_result.get("answer", "Извините, не удалось сгенерировать ответ.")
|
answer = rag_result.get("answer", "Извините, не удалось сгенерировать ответ.")
|
||||||
sources = rag_result.get("sources", [])
|
sources = rag_result.get("sources", [])
|
||||||
|
|
||||||
await save_conversation_to_backend(
|
# Беседа уже сохранена в бэкенде через API /rag/question
|
||||||
str(message.from_user.id),
|
|
||||||
question_text,
|
import re
|
||||||
answer,
|
formatted_answer = answer
|
||||||
sources
|
formatted_answer = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', formatted_answer)
|
||||||
)
|
formatted_answer = re.sub(r'^(\d+)\.\s+', r'\1. ', formatted_answer, flags=re.MULTILINE)
|
||||||
|
formatted_answer = formatted_answer.replace("- ", "• ")
|
||||||
|
|
||||||
response = (
|
response = (
|
||||||
f"<b>Ваш вопрос:</b>\n"
|
f"<b>Ваш вопрос:</b>\n"
|
||||||
f"<i>{question_text[:200]}</i>\n\n"
|
f"<i>{question_text[:200]}</i>\n\n"
|
||||||
f"<b>Ответ:</b>\n{answer}\n\n"
|
f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||||
|
f"💬 <b>Ответ:</b>\n\n"
|
||||||
|
f"{formatted_answer}\n\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
if sources:
|
if sources:
|
||||||
response += f"<b>Источники из коллекций:</b>\n"
|
response += f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||||
collections_used = {}
|
response += f"📚 <b>Источники:</b>\n"
|
||||||
for source in sources[:5]:
|
for idx, source in enumerate(sources[:5], 1):
|
||||||
collection_name = source.get('collection', 'Неизвестно')
|
title = source.get('title', 'Без названия')
|
||||||
if collection_name not in collections_used:
|
try:
|
||||||
collections_used[collection_name] = []
|
from urllib.parse import unquote
|
||||||
collections_used[collection_name].append(source.get('title', 'Без названия'))
|
decoded = unquote(title)
|
||||||
|
if decoded != title or '%' in title:
|
||||||
for i, (collection_name, titles) in enumerate(collections_used.items(), 1):
|
title = decoded
|
||||||
response += f"{i}. <b>Коллекция:</b> {collection_name}\n"
|
except:
|
||||||
for title in titles[:2]:
|
pass
|
||||||
response += f" • {title}\n"
|
response += f" {idx}. {title}\n"
|
||||||
response += "\n<i>Используйте /mycollections для просмотра всех коллекций</i>\n\n"
|
response += "\n<i>💡 Используйте /mycollections для просмотра всех коллекций</i>\n\n"
|
||||||
|
|
||||||
response += (
|
response += (
|
||||||
f"<b>Статус:</b> Premium (вопросов безлимитно)\n"
|
f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||||
f"<b>Всего вопросов:</b> {user.questions_used}"
|
f"✨ <b>Статус:</b> Premium (вопросов безлимитно)\n"
|
||||||
|
f"📊 <b>Всего вопросов:</b> {user.questions_used}"
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -121,17 +102,20 @@ async def process_premium_question(message: Message, user: UserModel, question_t
|
|||||||
response = (
|
response = (
|
||||||
f"<b>Ваш вопрос:</b>\n"
|
f"<b>Ваш вопрос:</b>\n"
|
||||||
f"<i>{question_text[:200]}</i>\n\n"
|
f"<i>{question_text[:200]}</i>\n\n"
|
||||||
f"Ошибка при генерации ответа. Попробуйте позже.\n\n"
|
f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||||
f"<b>Статус:</b> Premium\n"
|
f"❌ <b>Ошибка при генерации ответа.</b>\n"
|
||||||
f"<b>Всего вопросов:</b> {user.questions_used}"
|
f"Попробуйте позже.\n\n"
|
||||||
|
f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||||
|
f"✨ <b>Статус:</b> Premium\n"
|
||||||
|
f"📊 <b>Всего вопросов:</b> {user.questions_used}"
|
||||||
)
|
)
|
||||||
|
|
||||||
await message.answer(response, parse_mode="HTML")
|
await message.answer(response, parse_mode="HTML")
|
||||||
|
|
||||||
|
|
||||||
async def process_free_question(message: Message, user: UserModel, question_text: str, user_service: UserService):
|
async def process_free_question(message: Message, user: User, question_text: str):
|
||||||
await user_service.update_user_questions(user.telegram_id)
|
await user_service.update_user_questions(int(user.telegram_id))
|
||||||
user = await user_service.get_user_by_telegram_id(user.telegram_id)
|
user = await user_service.get_user_by_telegram_id(int(user.telegram_id))
|
||||||
remaining = settings.FREE_QUESTIONS_LIMIT - user.questions_used
|
remaining = settings.FREE_QUESTIONS_LIMIT - user.questions_used
|
||||||
|
|
||||||
await message.bot.send_chat_action(message.chat.id, "typing")
|
await message.bot.send_chat_action(message.chat.id, "typing")
|
||||||
@ -145,138 +129,69 @@ async def process_free_question(message: Message, user: UserModel, question_text
|
|||||||
answer = rag_result.get("answer", "Извините, не удалось сгенерировать ответ.")
|
answer = rag_result.get("answer", "Извините, не удалось сгенерировать ответ.")
|
||||||
sources = rag_result.get("sources", [])
|
sources = rag_result.get("sources", [])
|
||||||
|
|
||||||
await save_conversation_to_backend(
|
# Уже все сохранили через /rag/question
|
||||||
str(message.from_user.id),
|
|
||||||
question_text,
|
|
||||||
answer,
|
|
||||||
sources
|
|
||||||
)
|
|
||||||
|
|
||||||
|
formatted_answer = answer
|
||||||
|
formatted_answer = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', formatted_answer)
|
||||||
|
formatted_answer = re.sub(r'^(\d+)\.\s+', r'\1. ', formatted_answer, flags=re.MULTILINE)
|
||||||
|
formatted_answer = formatted_answer.replace("- ", "• ")
|
||||||
response = (
|
response = (
|
||||||
f"<b>Ваш вопрос:</b>\n"
|
f"<b>Ваш вопрос:</b>\n"
|
||||||
f"<i>{question_text[:200]}</i>\n\n"
|
f"<i>{question_text[:200]}</i>\n\n"
|
||||||
f"<b>Ответ:</b>\n{answer}\n\n"
|
f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||||
|
f"💬 <b>Ответ:</b>\n\n"
|
||||||
|
f"{formatted_answer}\n\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
if sources:
|
if sources:
|
||||||
response += f"<b>Источники из коллекций:</b>\n"
|
response += f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||||
collections_used = {}
|
response += f"📚 <b>Источники:</b>\n"
|
||||||
for source in sources[:5]:
|
for idx, source in enumerate(sources[:5], 1):
|
||||||
collection_name = source.get('collection', 'Неизвестно')
|
title = source.get('title', 'Без названия')
|
||||||
if collection_name not in collections_used:
|
try:
|
||||||
collections_used[collection_name] = []
|
from urllib.parse import unquote
|
||||||
collections_used[collection_name].append(source.get('title', 'Без названия'))
|
decoded = unquote(title)
|
||||||
|
if decoded != title or '%' in title:
|
||||||
for i, (collection_name, titles) in enumerate(collections_used.items(), 1):
|
title = decoded
|
||||||
response += f"{i}. <b>Коллекция:</b> {collection_name}\n"
|
except:
|
||||||
for title in titles[:2]:
|
pass
|
||||||
response += f" • {title}\n"
|
response += f" {idx}. {title}\n"
|
||||||
response += "\n<i>Используйте /mycollections для просмотра всех коллекций</i>\n\n"
|
response += "\n<i>💡 Используйте /mycollections для просмотра всех коллекций</i>\n\n"
|
||||||
|
|
||||||
response += (
|
response += (
|
||||||
f"<b>Статус:</b> Бесплатный доступ\n"
|
f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||||
f"<b>Использовано вопросов:</b> {user.questions_used}/{settings.FREE_QUESTIONS_LIMIT}\n"
|
f"📊 <b>Статус:</b> Бесплатный доступ\n"
|
||||||
f"<b>Осталось бесплатных:</b> {remaining}\n\n"
|
f"📈 <b>Использовано вопросов:</b> {user.questions_used}/{settings.FREE_QUESTIONS_LIMIT}\n"
|
||||||
|
f"🎯 <b>Осталось бесплатных:</b> {remaining}\n\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
if remaining <= 3 and remaining > 0:
|
if remaining <= 3 and remaining > 0:
|
||||||
response += f"<i>Осталось мало вопросов! Для продолжения используйте /buy</i>\n\n"
|
response += f"⚠️ <i>Осталось мало вопросов! Для продолжения используйте /buy</i>\n\n"
|
||||||
|
|
||||||
response += f"<i>Для безлимитного доступа: /buy</i>"
|
response += f"💎 <i>Для безлимитного доступа: /buy</i>"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error generating answer: {e}")
|
print(f"Error generating answer: {e}")
|
||||||
response = (
|
response = (
|
||||||
f"<b>Ваш вопрос:</b>\n"
|
f"<b>Ваш вопрос:</b>\n"
|
||||||
f"<i>{question_text[:200]}</i>\n\n"
|
f"<i>{question_text[:200]}</i>\n\n"
|
||||||
f"Ошибка при генерации ответа. Попробуйте позже.\n\n"
|
f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||||
f"<b>Статус:</b> Бесплатный доступ\n"
|
f"❌ <b>Ошибка при генерации ответа.</b>\n"
|
||||||
f"<b>Использовано вопросов:</b> {user.questions_used}/{settings.FREE_QUESTIONS_LIMIT}\n"
|
f"Попробуйте позже.\n\n"
|
||||||
f"<b>Осталось бесплатных:</b> {remaining}\n\n"
|
f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||||
f"<i>Для безлимитного доступа: /buy</i>"
|
f"📊 <b>Статус:</b> Бесплатный доступ\n"
|
||||||
|
f"📈 <b>Использовано вопросов:</b> {user.questions_used}/{settings.FREE_QUESTIONS_LIMIT}\n"
|
||||||
|
f"🎯 <b>Осталось бесплатных:</b> {remaining}\n\n"
|
||||||
|
f"💎 <i>Для безлимитного доступа: /buy</i>"
|
||||||
)
|
)
|
||||||
|
|
||||||
await message.answer(response, parse_mode="HTML")
|
await message.answer(response, parse_mode="HTML")
|
||||||
|
|
||||||
|
|
||||||
async def save_conversation_to_backend(telegram_id: str, question: str, answer: str, sources: list):
|
#Сново сохраняется в /rag/question
|
||||||
try:
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(
|
|
||||||
f"{BACKEND_URL}/users/telegram/{telegram_id}"
|
|
||||||
) as user_response:
|
|
||||||
if user_response.status != 200:
|
|
||||||
return
|
|
||||||
user_data = await user_response.json()
|
|
||||||
user_uuid = user_data.get("user_id")
|
|
||||||
|
|
||||||
async with session.get(
|
|
||||||
f"{BACKEND_URL}/collections/",
|
|
||||||
headers={"X-Telegram-ID": telegram_id}
|
|
||||||
) as collections_response:
|
|
||||||
collections = []
|
|
||||||
if collections_response.status == 200:
|
|
||||||
collections = await collections_response.json()
|
|
||||||
|
|
||||||
collection_id = None
|
|
||||||
if collections:
|
|
||||||
collection_id = collections[0].get("collection_id")
|
|
||||||
else:
|
|
||||||
async with session.post(
|
|
||||||
f"{BACKEND_URL}/collections",
|
|
||||||
json={
|
|
||||||
"name": "Основная коллекция",
|
|
||||||
"description": "Коллекция по умолчанию",
|
|
||||||
"is_public": False
|
|
||||||
},
|
|
||||||
headers={"X-Telegram-ID": telegram_id}
|
|
||||||
) as create_collection_response:
|
|
||||||
if create_collection_response.status in [200, 201]:
|
|
||||||
collection_data = await create_collection_response.json()
|
|
||||||
collection_id = collection_data.get("collection_id")
|
|
||||||
|
|
||||||
if not collection_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
async with session.post(
|
|
||||||
f"{BACKEND_URL}/conversations",
|
|
||||||
json={"collection_id": str(collection_id)},
|
|
||||||
headers={"X-Telegram-ID": telegram_id}
|
|
||||||
) as conversation_response:
|
|
||||||
if conversation_response.status not in [200, 201]:
|
|
||||||
return
|
|
||||||
conversation_data = await conversation_response.json()
|
|
||||||
conversation_id = conversation_data.get("conversation_id")
|
|
||||||
|
|
||||||
if not conversation_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
await session.post(
|
|
||||||
f"{BACKEND_URL}/messages",
|
|
||||||
json={
|
|
||||||
"conversation_id": str(conversation_id),
|
|
||||||
"content": question,
|
|
||||||
"role": "user"
|
|
||||||
},
|
|
||||||
headers={"X-Telegram-ID": telegram_id}
|
|
||||||
)
|
|
||||||
|
|
||||||
await session.post(
|
|
||||||
f"{BACKEND_URL}/messages",
|
|
||||||
json={
|
|
||||||
"conversation_id": str(conversation_id),
|
|
||||||
"content": answer,
|
|
||||||
"role": "assistant",
|
|
||||||
"sources": {"documents": sources}
|
|
||||||
},
|
|
||||||
headers={"X-Telegram-ID": telegram_id}
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error saving conversation: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_limit_exceeded(message: Message, user: UserModel):
|
async def handle_limit_exceeded(message: Message, user: User):
|
||||||
response = (
|
response = (
|
||||||
f"<b>Лимит бесплатных вопросов исчерпан!</b>\n\n"
|
f"<b>Лимит бесплатных вопросов исчерпан!</b>\n\n"
|
||||||
|
|
||||||
|
|||||||
@ -4,10 +4,10 @@ from aiogram.types import Message
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from tg_bot.config.settings import settings
|
from tg_bot.config.settings import settings
|
||||||
from tg_bot.infrastructure.database.database import AsyncSessionLocal
|
from tg_bot.domain.user_service import UserService
|
||||||
from tg_bot.domain.services.user_service import UserService
|
|
||||||
|
|
||||||
router = Router()
|
router = Router()
|
||||||
|
user_service = UserService()
|
||||||
|
|
||||||
@router.message(Command("start"))
|
@router.message(Command("start"))
|
||||||
async def cmd_start(message: Message):
|
async def cmd_start(message: Message):
|
||||||
@ -16,22 +16,19 @@ async def cmd_start(message: Message):
|
|||||||
username = message.from_user.username or ""
|
username = message.from_user.username or ""
|
||||||
first_name = message.from_user.first_name or ""
|
first_name = message.from_user.first_name or ""
|
||||||
last_name = message.from_user.last_name or ""
|
last_name = message.from_user.last_name or ""
|
||||||
async with AsyncSessionLocal() as session:
|
try:
|
||||||
try:
|
existing_user = await user_service.get_user_by_telegram_id(user_id)
|
||||||
user_service = UserService(session)
|
user = await user_service.get_or_create_user(
|
||||||
existing_user = await user_service.get_user_by_telegram_id(user_id)
|
user_id,
|
||||||
user = await user_service.get_or_create_user(
|
username,
|
||||||
user_id,
|
first_name,
|
||||||
username,
|
last_name
|
||||||
first_name,
|
)
|
||||||
last_name
|
if not existing_user:
|
||||||
)
|
print(f"Новый пользователь: {user_id}")
|
||||||
if not existing_user:
|
|
||||||
print(f"Новый пользователь: {user_id}")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Ошибка сохранения пользователя: {e}")
|
print(f"Ошибка сохранения пользователя: {e}")
|
||||||
await session.rollback()
|
|
||||||
welcome_text = (
|
welcome_text = (
|
||||||
f"<b>Привет, {first_name}!</b>\n\n"
|
f"<b>Привет, {first_name}!</b>\n\n"
|
||||||
f"Я <b>VibeLawyerBot</b> - ваш помощник в юридических вопросах.\n\n"
|
f"Я <b>VibeLawyerBot</b> - ваш помощник в юридических вопросах.\n\n"
|
||||||
|
|||||||
@ -4,58 +4,56 @@ from aiogram.filters import Command
|
|||||||
from aiogram.types import Message
|
from aiogram.types import Message
|
||||||
|
|
||||||
from tg_bot.config.settings import settings
|
from tg_bot.config.settings import settings
|
||||||
from tg_bot.infrastructure.database.database import AsyncSessionLocal
|
from tg_bot.domain.user_service import UserService
|
||||||
from tg_bot.domain.services.user_service import UserService
|
|
||||||
|
|
||||||
router = Router()
|
router = Router()
|
||||||
|
user_service = UserService()
|
||||||
|
|
||||||
|
|
||||||
@router.message(Command("stats"))
|
@router.message(Command("stats"))
|
||||||
async def cmd_stats(message: Message):
|
async def cmd_stats(message: Message):
|
||||||
user_id = message.from_user.id
|
user_id = message.from_user.id
|
||||||
|
|
||||||
async with AsyncSessionLocal() as session:
|
try:
|
||||||
try:
|
user = await user_service.get_user_by_telegram_id(user_id)
|
||||||
user_service = UserService(session)
|
|
||||||
user = await user_service.get_user_by_telegram_id(user_id)
|
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
stats_text = (
|
stats_text = (
|
||||||
f"<b>Ваша статистика</b>\n\n"
|
f"<b>Ваша статистика</b>\n\n"
|
||||||
f"<b>Основное:</b>\n"
|
f"<b>Основное:</b>\n"
|
||||||
f"• ID: <code>{user_id}</code>\n"
|
f"• ID: <code>{user_id}</code>\n"
|
||||||
f"• Premium: {'Да' if user.is_premium else 'Нет'}\n"
|
f"• Premium: {'Да' if user.is_premium else 'Нет'}\n"
|
||||||
f"• Вопросов использовано: {user.questions_used}/{settings.FREE_QUESTIONS_LIMIT}\n\n"
|
f"• Вопросов использовано: {user.questions_used}/{settings.FREE_QUESTIONS_LIMIT}\n\n"
|
||||||
)
|
|
||||||
|
|
||||||
if user.is_premium:
|
|
||||||
stats_text += (
|
|
||||||
f"<b>Premium статус:</b>\n"
|
|
||||||
f"• Активен до: {user.premium_until.strftime('%d.%m.%Y') if user.premium_until else 'Не указано'}\n"
|
|
||||||
f"• Лимит вопросов: безлимитно\n\n"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
remaining = max(0, settings.FREE_QUESTIONS_LIMIT - user.questions_used)
|
|
||||||
stats_text += (
|
|
||||||
f"<b>Бесплатный доступ:</b>\n"
|
|
||||||
f"• Осталось вопросов: {remaining}\n"
|
|
||||||
f"• Для безлимита: /buy\n\n"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
stats_text = (
|
|
||||||
f"<b>Добро пожаловать!</b>\n\n"
|
|
||||||
f"Вы новый пользователь.\n"
|
|
||||||
f"• Ваш ID: <code>{user_id}</code>\n"
|
|
||||||
f"• Бесплатных вопросов: {settings.FREE_QUESTIONS_LIMIT}\n"
|
|
||||||
f"• Для начала работы просто задайте вопрос!\n\n"
|
|
||||||
f"<i>Используйте /buy для получения полного доступа</i>"
|
|
||||||
)
|
|
||||||
|
|
||||||
await message.answer(stats_text, parse_mode="HTML")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
await message.answer(
|
|
||||||
f"<b>Ошибка получения статистики</b>\n\n"
|
|
||||||
f"Попробуйте позже.",
|
|
||||||
parse_mode="HTML"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if user.is_premium:
|
||||||
|
stats_text += (
|
||||||
|
f"<b>Premium статус:</b>\n"
|
||||||
|
f"• Активен до: {user.premium_until.strftime('%d.%m.%Y') if user.premium_until else 'Не указано'}\n"
|
||||||
|
f"• Лимит вопросов: безлимитно\n\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
remaining = max(0, settings.FREE_QUESTIONS_LIMIT - user.questions_used)
|
||||||
|
stats_text += (
|
||||||
|
f"<b>Бесплатный доступ:</b>\n"
|
||||||
|
f"• Осталось вопросов: {remaining}\n"
|
||||||
|
f"• Для безлимита: /buy\n\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
stats_text = (
|
||||||
|
f"<b>Добро пожаловать!</b>\n\n"
|
||||||
|
f"Вы новый пользователь.\n"
|
||||||
|
f"• Ваш ID: <code>{user_id}</code>\n"
|
||||||
|
f"• Бесплатных вопросов: {settings.FREE_QUESTIONS_LIMIT}\n"
|
||||||
|
f"• Для начала работы просто задайте вопрос!\n\n"
|
||||||
|
f"<i>Используйте /buy для получения полного доступа</i>"
|
||||||
|
)
|
||||||
|
|
||||||
|
await message.answer(stats_text, parse_mode="HTML")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await message.answer(
|
||||||
|
f"<b>Ошибка получения статистики</b>\n\n"
|
||||||
|
f"Попробуйте позже.",
|
||||||
|
parse_mode="HTML"
|
||||||
|
)
|
||||||
|
|||||||
27
tg_bot/infrastructure/telegram/states/collection_states.py
Normal file
27
tg_bot/infrastructure/telegram/states/collection_states.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
"""
|
||||||
|
FSM состояния для работы с коллекциями
|
||||||
|
"""
|
||||||
|
from aiogram.fsm.state import State, StatesGroup
|
||||||
|
|
||||||
|
|
||||||
|
class CollectionAccessStates(StatesGroup):
|
||||||
|
"""Состояния для управления доступом к коллекции"""
|
||||||
|
waiting_for_username = State()
|
||||||
|
|
||||||
|
|
||||||
|
class CollectionEditStates(StatesGroup):
|
||||||
|
"""Состояния для редактирования коллекции"""
|
||||||
|
waiting_for_name = State()
|
||||||
|
waiting_for_description = State()
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentEditStates(StatesGroup):
|
||||||
|
"""Состояния для редактирования документа"""
|
||||||
|
waiting_for_title = State()
|
||||||
|
waiting_for_content = State()
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentUploadStates(StatesGroup):
|
||||||
|
"""Состояния для загрузки документа"""
|
||||||
|
waiting_for_file = State()
|
||||||
|
|
||||||
@ -1,19 +1,17 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
parent_dir = os.path.dirname(current_dir)
|
|
||||||
sys.path.insert(0, parent_dir)
|
|
||||||
|
|
||||||
from tg_bot.config.settings import settings
|
from tg_bot.config.settings import settings
|
||||||
|
|
||||||
|
log_file_path = settings.LOG_FILE
|
||||||
|
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
handlers=[
|
handlers=[
|
||||||
logging.FileHandler(settings.LOG_FILE),
|
logging.FileHandler(log_file_path),
|
||||||
logging.StreamHandler()
|
logging.StreamHandler()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@ -18,10 +18,7 @@ async def handle_yookassa_webhook(request: Request):
|
|||||||
print(f"Webhook received: {event_type}")
|
print(f"Webhook received: {event_type}")
|
||||||
try:
|
try:
|
||||||
from tg_bot.config.settings import settings
|
from tg_bot.config.settings import settings
|
||||||
from tg_bot.domain.services.user_service import UserService
|
from tg_bot.domain.user_service import UserService
|
||||||
from tg_bot.infrastructure.database.database import AsyncSessionLocal
|
|
||||||
from tg_bot.infrastructure.database.models import UserModel
|
|
||||||
from sqlalchemy import select
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
|
|
||||||
if event_type == "payment.succeeded":
|
if event_type == "payment.succeeded":
|
||||||
@ -29,38 +26,34 @@ async def handle_yookassa_webhook(request: Request):
|
|||||||
user_id = payment.get("metadata", {}).get("user_id")
|
user_id = payment.get("metadata", {}).get("user_id")
|
||||||
|
|
||||||
if user_id:
|
if user_id:
|
||||||
async with AsyncSessionLocal() as session:
|
user_service = UserService()
|
||||||
user_service = UserService(session)
|
success = await user_service.activate_premium(int(user_id))
|
||||||
success = await user_service.activate_premium(int(user_id))
|
if success:
|
||||||
if success:
|
print(f"Premium activated for user {user_id}")
|
||||||
print(f"Premium activated for user {user_id}")
|
|
||||||
|
|
||||||
result = await session.execute(
|
user = await user_service.get_user_by_telegram_id(int(user_id))
|
||||||
select(UserModel).filter_by(telegram_id=str(user_id))
|
|
||||||
)
|
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if user and settings.TELEGRAM_BOT_TOKEN:
|
if user and settings.TELEGRAM_BOT_TOKEN:
|
||||||
try:
|
try:
|
||||||
bot = Bot(token=settings.TELEGRAM_BOT_TOKEN)
|
bot = Bot(token=settings.TELEGRAM_BOT_TOKEN)
|
||||||
premium_until = user.premium_until or datetime.now() + timedelta(days=30)
|
premium_until = user.premium_until or datetime.now() + timedelta(days=30)
|
||||||
|
|
||||||
notification = (
|
notification = (
|
||||||
f"<b>Оплата подтверждена!</b>\n\n"
|
f"<b>Оплата подтверждена!</b>\n\n"
|
||||||
f"Premium активирован до {premium_until.strftime('%d.%m.%Y')}"
|
f"Premium активирован до {premium_until.strftime('%d.%m.%Y')}"
|
||||||
)
|
)
|
||||||
|
|
||||||
await bot.send_message(
|
await bot.send_message(
|
||||||
chat_id=int(user_id),
|
chat_id=int(user_id),
|
||||||
text=notification,
|
text=notification,
|
||||||
parse_mode="HTML"
|
parse_mode="HTML"
|
||||||
)
|
)
|
||||||
print(f"Notification sent to user {user_id}")
|
print(f"Notification sent to user {user_id}")
|
||||||
await bot.session.close()
|
await bot.session.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error sending notification: {e}")
|
print(f"Error sending notification: {e}")
|
||||||
else:
|
else:
|
||||||
print(f"User {user_id} not found")
|
print(f"User {user_id} not found or failed to activate premium")
|
||||||
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f"Import error: {e}")
|
print(f"Import error: {e}")
|
||||||
|
|||||||
8
tg_bot/requirements.txt
Normal file
8
tg_bot/requirements.txt
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
pydantic>=2.5.0
|
||||||
|
pydantic-settings>=2.1.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
aiogram>=3.10.0
|
||||||
|
httpx>=0.25.2
|
||||||
|
yookassa>=2.4.0
|
||||||
|
aiohttp>=3.9.1
|
||||||
|
|
||||||
20
tg_bot/run.py
Normal file
20
tg_bot/run.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
|
||||||
|
"""
|
||||||
|
Скрипт для запуска Telegram бота без Docker
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
tg_bot_dir = Path(__file__).parent
|
||||||
|
sys.path.insert(0, str(tg_bot_dir))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
from main import main
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user