tests
This commit is contained in:
parent
c210c4a3c5
commit
cd08f88434
23
pytest.ini
Normal file
23
pytest.ini
Normal file
@ -0,0 +1,23 @@
|
||||
[pytest]
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
asyncio_mode = auto
|
||||
addopts =
|
||||
-v
|
||||
--strict-markers
|
||||
--tb=short
|
||||
--cov=backend/src
|
||||
--cov=tg_bot
|
||||
--cov-report=term-missing
|
||||
--cov-report=xml
|
||||
--cov-fail-under=65
|
||||
--ignore=venv
|
||||
--ignore=.venv
|
||||
markers =
|
||||
unit: Unit tests
|
||||
integration: Integration tests
|
||||
metrics: Metrics tests
|
||||
slow: Slow running tests
|
||||
|
||||
94
tests/README.md
Normal file
94
tests/README.md
Normal file
@ -0,0 +1,94 @@
|
||||
# Тесты для проекта BetterCallPraskovia
|
||||
|
||||
## Структура тестов
|
||||
|
||||
```
|
||||
tests/
|
||||
├── conftest.py
|
||||
├── unit/
|
||||
│ ├── test_rag_service.py
|
||||
│ ├── test_user_service.py
|
||||
│ ├── test_deepseek_client.py
|
||||
│ ├── test_document_use_cases.py
|
||||
│ └── test_collection_use_cases.py
|
||||
├── integration/
|
||||
│ └── test_rag_integration.py
|
||||
└── metrics/
|
||||
└── test_hit_at_5.py
|
||||
```
|
||||
|
||||
## Установка зависимостей
|
||||
|
||||
```bash
|
||||
pip install -r tests/requirements.txt
|
||||
```
|
||||
|
||||
## Запуск тестов
|
||||
|
||||
### Все тесты
|
||||
```bash
|
||||
pytest
|
||||
```
|
||||
|
||||
### Только юнит-тесты
|
||||
```bash
|
||||
pytest tests/unit/
|
||||
```
|
||||
|
||||
### Только интеграционные тесты
|
||||
```bash
|
||||
pytest tests/integration/
|
||||
```
|
||||
|
||||
### Только тесты метрик
|
||||
```bash
|
||||
pytest tests/metrics/
|
||||
```
|
||||
|
||||
### Только тесты tg_bot
|
||||
```bash
|
||||
pytest tests/unit/test_rag_service.py tests/unit/test_user_service.py tests/unit/test_deepseek_client.py
|
||||
```
|
||||
|
||||
### С покрытием кода
|
||||
```bash
|
||||
pytest --cov=backend/src --cov=tg_bot --cov-report=html
|
||||
```
|
||||
|
||||
### С минимальным покрытием 65%
|
||||
```bash
|
||||
pytest --cov-fail-under=65
|
||||
```
|
||||
|
||||
## Метрика hit@5
|
||||
|
||||
Проверка что в топ-5 релевантных документов есть хотя бы 1 нужный документ.
|
||||
|
||||
- **hit@5 = 1**, если есть хотя бы 1 релевантный документ в топ-5
|
||||
- **hit@5 = 0**, если нет релевантных документов в топ-5
|
||||
|
||||
Среднее значение hit@5 для всех запросов должно быть **> 50%**
|
||||
|
||||
## Покрытие кода
|
||||
|
||||
**coverage ≥ 65%**
|
||||
|
||||
Проверка покрытия:
|
||||
```bash
|
||||
pytest --cov=backend/src --cov=tg_bot --cov-report=term-missing --cov-fail-under=65
|
||||
```
|
||||
|
||||
## Маркеры тестов
|
||||
|
||||
- `@pytest.mark.unit` - юнит-тесты
|
||||
- `@pytest.mark.integration` - интеграционные тесты
|
||||
- `@pytest.mark.metrics` - тесты метрик
|
||||
- `@pytest.mark.slow` - медленные тесты
|
||||
|
||||
Запуск по маркерам:
|
||||
```bash
|
||||
pytest -m unit
|
||||
pytest -m integration
|
||||
pytest -m metrics
|
||||
```
|
||||
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
170
tests/conftest.py
Normal file
170
tests/conftest.py
Normal file
@ -0,0 +1,170 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
from datetime import datetime
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document():
|
||||
try:
|
||||
from backend.src.domain.entities.document import Document
|
||||
except ImportError:
|
||||
from src.domain.entities.document import Document
|
||||
return Document(
|
||||
collection_id=uuid4(),
|
||||
title="Тестовый документ",
|
||||
content="Содержание документа",
|
||||
metadata={"type": "law", "article": "123"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_documents_list():
|
||||
try:
|
||||
from backend.src.domain.entities.document import Document
|
||||
except ImportError:
|
||||
from src.domain.entities.document import Document
|
||||
collection_id = uuid4()
|
||||
return [
|
||||
Document(
|
||||
collection_id=collection_id,
|
||||
title=f"Документ {i}",
|
||||
content=f"Содержание документа {i} ",
|
||||
metadata={"relevance_score": 0.9 - i * 0.1}
|
||||
)
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_relevant_documents():
|
||||
return [
|
||||
{"document_id": str(uuid4()), "title": "Гражданский кодекс РФ", "relevance": True},
|
||||
{"document_id": str(uuid4()), "title": "Трудовой кодекс РФ", "relevance": True},
|
||||
{"document_id": str(uuid4()), "title": "Налоговый кодекс РФ", "relevance": True},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_response():
|
||||
return {
|
||||
"answer": "Тестовый ответ на вопрос",
|
||||
"sources": [
|
||||
{"title": "Документ 1", "collection": "Коллекция 1", "document_id": str(uuid4())},
|
||||
{"title": "Документ 2", "collection": "Коллекция 1", "document_id": str(uuid4())},
|
||||
{"title": "Документ 3", "collection": "Коллекция 2", "document_id": str(uuid4())},
|
||||
],
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_collection():
|
||||
try:
|
||||
from backend.src.domain.entities.collection import Collection
|
||||
except ImportError:
|
||||
from src.domain.entities.collection import Collection
|
||||
return Collection(
|
||||
name="Коллекция",
|
||||
owner_id=uuid4(),
|
||||
description="Описание коллекции",
|
||||
is_public=False
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
try:
|
||||
from backend.src.domain.entities.user import User, UserRole
|
||||
except ImportError:
|
||||
from src.domain.entities.user import User, UserRole
|
||||
return User(
|
||||
telegram_id="123456789",
|
||||
role=UserRole.USER
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_repository():
|
||||
repository = AsyncMock()
|
||||
repository.get_by_id = AsyncMock()
|
||||
repository.create = AsyncMock()
|
||||
repository.update = AsyncMock()
|
||||
repository.delete = AsyncMock(return_value=True)
|
||||
repository.list_by_collection = AsyncMock(return_value=[])
|
||||
return repository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_collection_repository():
|
||||
repository = AsyncMock()
|
||||
repository.get_by_id = AsyncMock()
|
||||
repository.create = AsyncMock()
|
||||
repository.update = AsyncMock()
|
||||
repository.delete = AsyncMock(return_value=True)
|
||||
repository.list_by_owner = AsyncMock(return_value=[])
|
||||
repository.list_public = AsyncMock(return_value=[])
|
||||
return repository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_repository():
|
||||
repository = AsyncMock()
|
||||
repository.get_by_id = AsyncMock()
|
||||
repository.get_by_telegram_id = AsyncMock()
|
||||
repository.create = AsyncMock()
|
||||
repository.update = AsyncMock()
|
||||
repository.delete = AsyncMock(return_value=True)
|
||||
return repository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_deepseek_client():
|
||||
client = AsyncMock()
|
||||
client.chat_completion = AsyncMock(return_value={
|
||||
"content": "Ответ от DeepSeek",
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}
|
||||
})
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_queries_with_ground_truth():
|
||||
return [
|
||||
{
|
||||
"query": "Какие права имеет работник при увольнении?",
|
||||
"relevant_document_ids": [str(uuid4()), str(uuid4())],
|
||||
"expected_top5_contains": True
|
||||
},
|
||||
{
|
||||
"query": "Как оформить договор купли-продажи?",
|
||||
"relevant_document_ids": [str(uuid4())],
|
||||
"expected_top5_contains": True
|
||||
},
|
||||
{
|
||||
"query": "Какие налоги платит ИП?",
|
||||
"relevant_document_ids": [str(uuid4()), str(uuid4()), str(uuid4())],
|
||||
"expected_top5_contains": True
|
||||
},
|
||||
{
|
||||
"query": "Права потребителя при возврате товара",
|
||||
"relevant_document_ids": [str(uuid4())],
|
||||
"expected_top5_contains": True
|
||||
},
|
||||
{
|
||||
"query": "Как расторгнуть брак?",
|
||||
"relevant_document_ids": [str(uuid4())],
|
||||
"expected_top5_contains": True
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_aiohttp_session():
|
||||
session = AsyncMock()
|
||||
session.get = AsyncMock()
|
||||
session.post = AsyncMock()
|
||||
session.__aenter__ = AsyncMock(return_value=session)
|
||||
session.__aexit__ = AsyncMock(return_value=None)
|
||||
return session
|
||||
0
tests/integration/__init__.py
Normal file
0
tests/integration/__init__.py
Normal file
146
tests/integration/test_rag_integration.py
Normal file
146
tests/integration/test_rag_integration.py
Normal file
@ -0,0 +1,146 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from uuid import uuid4
|
||||
import aiohttp
|
||||
from tg_bot.application.services.rag_service import RAGService
|
||||
from tests.metrics.test_hit_at_5 import calculate_hit_at_5, calculate_average_hit_at_5
|
||||
|
||||
|
||||
class TestRAGIntegration:
|
||||
|
||||
@pytest.fixture
|
||||
def rag_service(self):
|
||||
service = RAGService()
|
||||
from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient
|
||||
service.deepseek_client = DeepSeekClient()
|
||||
return service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_service_search_documents_real_implementation(self, rag_service):
|
||||
user_telegram_id = "123456789"
|
||||
query = "трудовой договор"
|
||||
user_uuid = str(uuid4())
|
||||
collection_id = str(uuid4())
|
||||
|
||||
mock_user_response = AsyncMock()
|
||||
mock_user_response.status = 200
|
||||
mock_user_response.json = AsyncMock(return_value={"user_id": user_uuid})
|
||||
|
||||
mock_collections_response = AsyncMock()
|
||||
mock_collections_response.status = 200
|
||||
mock_collections_response.json = AsyncMock(return_value=[
|
||||
{"collection_id": collection_id, "name": "Законы"}
|
||||
])
|
||||
|
||||
mock_documents_response = AsyncMock()
|
||||
mock_documents_response.status = 200
|
||||
mock_documents_response.json = AsyncMock(return_value=[
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"title": "Трудовой кодекс",
|
||||
"content": "Содержание о трудовых договорах"
|
||||
}
|
||||
])
|
||||
|
||||
with patch('aiohttp.ClientSession') as mock_session_class:
|
||||
mock_session = AsyncMock()
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_session.get = AsyncMock(side_effect=[
|
||||
mock_user_response,
|
||||
mock_collections_response,
|
||||
mock_documents_response
|
||||
])
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
result = await rag_service.search_documents_in_collections(user_telegram_id, query)
|
||||
|
||||
assert isinstance(result, list)
|
||||
if result:
|
||||
assert "document_id" in result[0]
|
||||
assert "title" in result[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_service_generate_answer_real_implementation(self, rag_service):
|
||||
question = "Какие права имеет работник?"
|
||||
user_telegram_id = "123456789"
|
||||
|
||||
mock_documents = [
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"title": "Трудовой кодекс",
|
||||
"content": "Работник имеет право на...",
|
||||
"collection_name": "Законы"
|
||||
}
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \
|
||||
patch.object(rag_service.deepseek_client, 'chat_completion') as mock_deepseek:
|
||||
|
||||
mock_search.return_value = mock_documents
|
||||
mock_deepseek.return_value = {
|
||||
"content": "Работник имеет следующие права...",
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}
|
||||
}
|
||||
|
||||
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
|
||||
|
||||
assert "answer" in result
|
||||
assert "sources" in result
|
||||
assert "usage" in result
|
||||
assert len(result["sources"]) <= 5
|
||||
assert result["answer"] != ""
|
||||
|
||||
for source in result["sources"]:
|
||||
assert "title" in source
|
||||
assert "collection" in source
|
||||
assert "document_id" in source
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_service_limits_to_top5_documents(self, rag_service):
|
||||
question = "Тестовый вопрос"
|
||||
user_telegram_id = "123456789"
|
||||
|
||||
many_documents = [
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"title": f"Документ {i}",
|
||||
"content": f"Содержание {i}",
|
||||
"collection_name": "Коллекция"
|
||||
}
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \
|
||||
patch.object(rag_service.deepseek_client, 'chat_completion') as mock_deepseek:
|
||||
|
||||
mock_search.return_value = many_documents
|
||||
mock_deepseek.return_value = {
|
||||
"content": "Ответ",
|
||||
"usage": {}
|
||||
}
|
||||
|
||||
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
|
||||
|
||||
assert len(result["sources"]) == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_service_handles_empty_search_results(self, rag_service):
|
||||
question = "Вопрос без документов"
|
||||
user_telegram_id = "123456789"
|
||||
|
||||
with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \
|
||||
patch.object(rag_service.deepseek_client, 'chat_completion') as mock_deepseek:
|
||||
|
||||
mock_search.return_value = []
|
||||
mock_deepseek.return_value = {
|
||||
"content": "Ответ",
|
||||
"usage": {}
|
||||
}
|
||||
|
||||
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
|
||||
|
||||
assert result["sources"] == []
|
||||
assert "Релевантные документы не найдены" in result.get("answer", "") or \
|
||||
result["answer"] == "No relevant documents found"
|
||||
0
tests/metrics/__init__.py
Normal file
0
tests/metrics/__init__.py
Normal file
133
tests/metrics/test_hit_at_5.py
Normal file
133
tests/metrics/test_hit_at_5.py
Normal file
@ -0,0 +1,133 @@
|
||||
import pytest
|
||||
from uuid import uuid4
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
def calculate_hit_at_5(retrieved_document_ids: List[str], relevant_document_ids: List[str]) -> int:
|
||||
if not retrieved_document_ids or not relevant_document_ids:
|
||||
return 0
|
||||
|
||||
top5_ids = set(retrieved_document_ids[:5])
|
||||
relevant_ids = set(relevant_document_ids)
|
||||
|
||||
return 1 if top5_ids.intersection(relevant_ids) else 0
|
||||
|
||||
|
||||
def calculate_average_hit_at_5(results: List[int]) -> float:
|
||||
if not results:
|
||||
return 0.0
|
||||
return sum(results) / len(results)
|
||||
|
||||
|
||||
class TestHitAt5Metric:
|
||||
|
||||
def test_hit_at_5_returns_1_when_relevant_document_in_top5(self):
|
||||
relevant_ids = [str(uuid4()), str(uuid4())]
|
||||
retrieved_ids = [
|
||||
str(uuid4()),
|
||||
relevant_ids[0],
|
||||
str(uuid4()),
|
||||
str(uuid4()),
|
||||
str(uuid4())
|
||||
]
|
||||
|
||||
result = calculate_hit_at_5(retrieved_ids, relevant_ids)
|
||||
assert result == 1
|
||||
|
||||
def test_hit_at_5_returns_0_when_no_relevant_document_in_top5(self):
|
||||
relevant_ids = [str(uuid4()), str(uuid4())]
|
||||
retrieved_ids = [
|
||||
str(uuid4()),
|
||||
str(uuid4()),
|
||||
str(uuid4()),
|
||||
str(uuid4()),
|
||||
str(uuid4())
|
||||
]
|
||||
|
||||
result = calculate_hit_at_5(retrieved_ids, relevant_ids)
|
||||
assert result == 0
|
||||
|
||||
def test_hit_at_5_returns_1_when_multiple_relevant_documents(self):
|
||||
relevant_ids = [str(uuid4()), str(uuid4()), str(uuid4())]
|
||||
retrieved_ids = [
|
||||
relevant_ids[0],
|
||||
str(uuid4()),
|
||||
relevant_ids[1],
|
||||
str(uuid4()),
|
||||
relevant_ids[2]
|
||||
]
|
||||
|
||||
result = calculate_hit_at_5(retrieved_ids, relevant_ids)
|
||||
assert result == 1
|
||||
|
||||
def test_hit_at_5_handles_empty_lists(self):
|
||||
result = calculate_hit_at_5([], [str(uuid4())])
|
||||
assert result == 0
|
||||
|
||||
result = calculate_hit_at_5([str(uuid4())], [])
|
||||
assert result == 0
|
||||
|
||||
result = calculate_hit_at_5([], [])
|
||||
assert result == 0
|
||||
|
||||
def test_hit_at_5_only_checks_top5(self):
|
||||
relevant_ids = [str(uuid4())]
|
||||
retrieved_ids = [
|
||||
str(uuid4()),
|
||||
str(uuid4()),
|
||||
str(uuid4()),
|
||||
str(uuid4()),
|
||||
str(uuid4()),
|
||||
relevant_ids[0]
|
||||
]
|
||||
|
||||
result = calculate_hit_at_5(retrieved_ids, relevant_ids)
|
||||
assert result == 0
|
||||
|
||||
def test_calculate_average_hit_at_5(self):
|
||||
results = [1, 1, 0, 1, 0, 1, 0, 1, 1, 1]
|
||||
|
||||
average = calculate_average_hit_at_5(results)
|
||||
assert average == 0.7
|
||||
|
||||
def test_calculate_average_hit_at_5_all_ones(self):
|
||||
results = [1, 1, 1, 1, 1]
|
||||
|
||||
average = calculate_average_hit_at_5(results)
|
||||
assert average == 1.0
|
||||
|
||||
def test_calculate_average_hit_at_5_all_zeros(self):
|
||||
results = [0, 0, 0, 0, 0]
|
||||
|
||||
average = calculate_average_hit_at_5(results)
|
||||
assert average == 0.0
|
||||
|
||||
def test_calculate_average_hit_at_5_empty_list(self):
|
||||
average = calculate_average_hit_at_5([])
|
||||
assert average == 0.0
|
||||
|
||||
def test_hit_at_5_quality_threshold(self):
|
||||
results = [1] * 60 + [0] * 40
|
||||
|
||||
average = calculate_average_hit_at_5(results)
|
||||
assert average > 0.5, f"Качество {average} должно быть > 0.5"
|
||||
assert average == 0.6
|
||||
|
||||
def test_hit_at_5_quality_below_threshold(self):
|
||||
results = [1] * 40 + [0] * 60
|
||||
|
||||
average = calculate_average_hit_at_5(results)
|
||||
assert average < 0.5, f"Качество {average} должно быть < 0.5"
|
||||
assert average == 0.4
|
||||
|
||||
@pytest.mark.parametrize("hit_count,total,expected_quality", [
|
||||
(51, 100, 0.51),
|
||||
(50, 100, 0.50),
|
||||
(60, 100, 0.60),
|
||||
(75, 100, 0.75),
|
||||
(100, 100, 1.0),
|
||||
])
|
||||
def test_hit_at_5_various_qualities(self, hit_count, total, expected_quality):
|
||||
results = [1] * hit_count + [0] * (total - hit_count)
|
||||
average = calculate_average_hit_at_5(results)
|
||||
assert average == expected_quality
|
||||
8
tests/requirements.txt
Normal file
8
tests/requirements.txt
Normal file
@ -0,0 +1,8 @@
|
||||
pytest==7.4.3
|
||||
pytest-asyncio==0.21.1
|
||||
pytest-cov==4.1.0
|
||||
pytest-mock==3.12.0
|
||||
pytest-timeout==2.2.0
|
||||
httpx>=0.25.2
|
||||
aiohttp>=3.9.1
|
||||
|
||||
0
tests/unit/__init__.py
Normal file
0
tests/unit/__init__.py
Normal file
132
tests/unit/test_collection_use_cases.py
Normal file
132
tests/unit/test_collection_use_cases.py
Normal file
@ -0,0 +1,132 @@
|
||||
import pytest
|
||||
from uuid import uuid4
|
||||
from unittest.mock import AsyncMock
|
||||
try:
|
||||
from backend.src.application.use_cases.collection_use_cases import CollectionUseCases
|
||||
from backend.src.shared.exceptions import NotFoundError, ForbiddenError
|
||||
from backend.src.domain.repositories.collection_access_repository import ICollectionAccessRepository
|
||||
except ImportError:
|
||||
from src.application.use_cases.collection_use_cases import CollectionUseCases
|
||||
from src.shared.exceptions import NotFoundError, ForbiddenError
|
||||
from src.domain.repositories.collection_access_repository import ICollectionAccessRepository
|
||||
|
||||
|
||||
class TestCollectionUseCases:
|
||||
|
||||
@pytest.fixture
|
||||
def collection_use_cases(self, mock_collection_repository, mock_user_repository):
|
||||
mock_access_repository = AsyncMock()
|
||||
mock_access_repository.get_by_user_and_collection = AsyncMock(return_value=None)
|
||||
mock_access_repository.create = AsyncMock()
|
||||
mock_access_repository.delete_by_user_and_collection = AsyncMock(return_value=True)
|
||||
mock_access_repository.list_by_user = AsyncMock(return_value=[])
|
||||
|
||||
return CollectionUseCases(
|
||||
collection_repository=mock_collection_repository,
|
||||
access_repository=mock_access_repository,
|
||||
user_repository=mock_user_repository
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_collection_success(self, collection_use_cases, mock_user,
|
||||
mock_collection_repository, mock_user_repository):
|
||||
owner_id = uuid4()
|
||||
mock_user_repository.get_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_collection_repository.create = AsyncMock(return_value=mock_user)
|
||||
|
||||
result = await collection_use_cases.create_collection(
|
||||
name="Тестовая коллекция",
|
||||
owner_id=owner_id,
|
||||
description="Описание",
|
||||
is_public=False
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
mock_user_repository.get_by_id.assert_called_once_with(owner_id)
|
||||
mock_collection_repository.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_collection_user_not_found(self, collection_use_cases, mock_user_repository):
|
||||
owner_id = uuid4()
|
||||
mock_user_repository.get_by_id = AsyncMock(return_value=None)
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
await collection_use_cases.create_collection(
|
||||
name="Коллекция",
|
||||
owner_id=owner_id
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_collection_success(self, collection_use_cases, mock_collection, mock_collection_repository):
|
||||
collection_id = uuid4()
|
||||
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||
|
||||
result = await collection_use_cases.get_collection(collection_id)
|
||||
|
||||
assert result == mock_collection
|
||||
mock_collection_repository.get_by_id.assert_called_once_with(collection_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_collection_not_found(self, collection_use_cases, mock_collection_repository):
|
||||
collection_id = uuid4()
|
||||
mock_collection_repository.get_by_id = AsyncMock(return_value=None)
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
await collection_use_cases.get_collection(collection_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_collection_success(self, collection_use_cases, mock_collection, mock_collection_repository):
|
||||
collection_id = uuid4()
|
||||
user_id = uuid4()
|
||||
mock_collection.owner_id = user_id
|
||||
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||
mock_collection_repository.update = AsyncMock(return_value=mock_collection)
|
||||
|
||||
result = await collection_use_cases.update_collection(
|
||||
collection_id=collection_id,
|
||||
user_id=user_id,
|
||||
name="Обновленное название"
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert mock_collection.name == "Обновленное название"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_collection_forbidden(self, collection_use_cases, mock_collection, mock_collection_repository):
|
||||
collection_id = uuid4()
|
||||
user_id = uuid4()
|
||||
owner_id = uuid4()
|
||||
mock_collection.owner_id = owner_id
|
||||
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||
|
||||
with pytest.raises(ForbiddenError):
|
||||
await collection_use_cases.update_collection(
|
||||
collection_id=collection_id,
|
||||
user_id=user_id,
|
||||
name="Название"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_access_owner(self, collection_use_cases, mock_collection, mock_collection_repository):
|
||||
collection_id = uuid4()
|
||||
user_id = uuid4()
|
||||
mock_collection.owner_id = user_id
|
||||
mock_collection.is_public = False
|
||||
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||
|
||||
result = await collection_use_cases.check_access(collection_id, user_id)
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_access_public(self, collection_use_cases, mock_collection, mock_collection_repository):
|
||||
collection_id = uuid4()
|
||||
user_id = uuid4()
|
||||
owner_id = uuid4()
|
||||
mock_collection.owner_id = owner_id
|
||||
mock_collection.is_public = True
|
||||
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||
|
||||
result = await collection_use_cases.check_access(collection_id, user_id)
|
||||
|
||||
assert result is True
|
||||
114
tests/unit/test_deepseek_client.py
Normal file
114
tests/unit/test_deepseek_client.py
Normal file
@ -0,0 +1,114 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
import httpx
|
||||
from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient
|
||||
from tg_bot.infrastructure.external.deepseek_client import DeepSeekAPIError
|
||||
|
||||
|
||||
class TestDeepSeekClient:
|
||||
|
||||
@pytest.fixture
|
||||
def deepseek_client(self):
|
||||
return DeepSeekClient(api_key="test_key", api_url="https://api.test.com/v1/chat/completions")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_success(self, deepseek_client):
|
||||
messages = [
|
||||
{"role": "user", "content": "Тестовый вопрос"}
|
||||
]
|
||||
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": "Тестовый ответ от DeepSeek"
|
||||
}
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 20,
|
||||
"total_tokens": 30
|
||||
}
|
||||
}
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_response_data
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
|
||||
mock_client_instance.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_client_instance.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value = mock_client_instance
|
||||
|
||||
result = await deepseek_client.chat_completion(messages)
|
||||
|
||||
assert "content" in result
|
||||
assert result["content"] == "Тестовый ответ от DeepSeek"
|
||||
assert "usage" in result
|
||||
assert result["usage"]["total_tokens"] == 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_no_api_key(self):
|
||||
client = DeepSeekClient(api_key=None)
|
||||
messages = [{"role": "user", "content": "Вопрос"}]
|
||||
|
||||
result = await client.chat_completion(messages)
|
||||
|
||||
assert "content" in result
|
||||
assert "DEEPSEEK_API_KEY" in result["content"] or "не установлен" in result["content"]
|
||||
assert result["usage"]["total_tokens"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_api_error(self, deepseek_client):
|
||||
import httpx
|
||||
messages = [{"role": "user", "content": "Вопрос"}]
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"Unauthorized", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
|
||||
mock_client_instance.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_client_instance.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value = mock_client_instance
|
||||
|
||||
with pytest.raises(DeepSeekAPIError):
|
||||
await deepseek_client.chat_completion(messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_with_parameters(self, deepseek_client):
|
||||
messages = [{"role": "user", "content": "Вопрос"}]
|
||||
|
||||
mock_response_data = {
|
||||
"choices": [{"message": {"content": "Ответ"}}],
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
|
||||
}
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_response_data
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
|
||||
mock_client_instance.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_client_instance.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value = mock_client_instance
|
||||
|
||||
result = await deepseek_client.chat_completion(
|
||||
messages,
|
||||
model="deepseek-chat",
|
||||
temperature=0.7,
|
||||
max_tokens=100
|
||||
)
|
||||
|
||||
assert result["content"] == "Ответ"
|
||||
call_args = mock_client_instance.post.call_args
|
||||
assert call_args is not None
|
||||
141
tests/unit/test_document_use_cases.py
Normal file
141
tests/unit/test_document_use_cases.py
Normal file
@ -0,0 +1,141 @@
|
||||
import pytest
|
||||
from uuid import uuid4
|
||||
from unittest.mock import AsyncMock
|
||||
try:
|
||||
from backend.src.application.use_cases.document_use_cases import DocumentUseCases
|
||||
from backend.src.shared.exceptions import NotFoundError, ForbiddenError
|
||||
from backend.src.application.services.document_parser_service import DocumentParserService
|
||||
except ImportError:
|
||||
from src.application.use_cases.document_use_cases import DocumentUseCases
|
||||
from src.shared.exceptions import NotFoundError, ForbiddenError
|
||||
from src.application.services.document_parser_service import DocumentParserService
|
||||
|
||||
|
||||
class TestDocumentUseCases:
|
||||
|
||||
@pytest.fixture
|
||||
def document_use_cases(self, mock_document_repository, mock_collection_repository):
|
||||
mock_parser = AsyncMock()
|
||||
mock_parser.parse_pdf = AsyncMock(return_value=("Парсенный документ", "Содержание"))
|
||||
|
||||
return DocumentUseCases(
|
||||
document_repository=mock_document_repository,
|
||||
collection_repository=mock_collection_repository,
|
||||
parser_service=mock_parser
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_document_success(self, document_use_cases, mock_collection, mock_document_repository, mock_collection_repository):
|
||||
collection_id = uuid4()
|
||||
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||
mock_document_repository.create = AsyncMock(return_value=mock_collection)
|
||||
|
||||
result = await document_use_cases.create_document(
|
||||
collection_id=collection_id,
|
||||
title="Тестовый документ",
|
||||
content="Содержание",
|
||||
metadata={"type": "law"}
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
mock_collection_repository.get_by_id.assert_called_once_with(collection_id)
|
||||
mock_document_repository.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_document_collection_not_found(self, document_use_cases, mock_collection_repository):
|
||||
collection_id = uuid4()
|
||||
mock_collection_repository.get_by_id = AsyncMock(return_value=None)
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
await document_use_cases.create_document(
|
||||
collection_id=collection_id,
|
||||
title="Документ",
|
||||
content="Содержание"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_document_success(self, document_use_cases, mock_document, mock_document_repository):
|
||||
document_id = uuid4()
|
||||
mock_document_repository.get_by_id = AsyncMock(return_value=mock_document)
|
||||
|
||||
result = await document_use_cases.get_document(document_id)
|
||||
|
||||
assert result == mock_document
|
||||
mock_document_repository.get_by_id.assert_called_once_with(document_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_document_not_found(self, document_use_cases, mock_document_repository):
|
||||
document_id = uuid4()
|
||||
mock_document_repository.get_by_id = AsyncMock(return_value=None)
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
await document_use_cases.get_document(document_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_document_success(self, document_use_cases, mock_document, mock_collection,
|
||||
mock_document_repository, mock_collection_repository):
|
||||
document_id = uuid4()
|
||||
user_id = uuid4()
|
||||
mock_document.collection_id = uuid4()
|
||||
mock_collection.owner_id = user_id
|
||||
|
||||
mock_document_repository.get_by_id = AsyncMock(return_value=mock_document)
|
||||
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||
mock_document_repository.update = AsyncMock(return_value=mock_document)
|
||||
|
||||
result = await document_use_cases.update_document(
|
||||
document_id=document_id,
|
||||
user_id=user_id,
|
||||
title="Обновленное название"
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert mock_document.title == "Обновленное название"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_document_forbidden(self, document_use_cases, mock_document, mock_collection,
|
||||
mock_document_repository, mock_collection_repository):
|
||||
document_id = uuid4()
|
||||
user_id = uuid4()
|
||||
owner_id = uuid4()
|
||||
mock_document.collection_id = uuid4()
|
||||
mock_collection.owner_id = owner_id
|
||||
|
||||
mock_document_repository.get_by_id = AsyncMock(return_value=mock_document)
|
||||
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||
|
||||
with pytest.raises(ForbiddenError):
|
||||
await document_use_cases.update_document(
|
||||
document_id=document_id,
|
||||
user_id=user_id,
|
||||
title="Название"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_document_success(self, document_use_cases, mock_document, mock_collection,
|
||||
mock_document_repository, mock_collection_repository):
|
||||
document_id = uuid4()
|
||||
user_id = uuid4()
|
||||
mock_document.collection_id = uuid4()
|
||||
mock_collection.owner_id = user_id
|
||||
|
||||
mock_document_repository.get_by_id = AsyncMock(return_value=mock_document)
|
||||
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||
mock_document_repository.delete = AsyncMock(return_value=True)
|
||||
|
||||
result = await document_use_cases.delete_document(document_id, user_id)
|
||||
|
||||
assert result is True
|
||||
mock_document_repository.delete.assert_called_once_with(document_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_collection_documents(self, document_use_cases, mock_collection, mock_documents_list,
|
||||
mock_collection_repository, mock_document_repository):
|
||||
collection_id = uuid4()
|
||||
mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection)
|
||||
mock_document_repository.list_by_collection = AsyncMock(return_value=mock_documents_list)
|
||||
|
||||
result = await document_use_cases.list_collection_documents(collection_id, skip=0, limit=10)
|
||||
|
||||
assert len(result) == len(mock_documents_list)
|
||||
mock_document_repository.list_by_collection.assert_called_once_with(collection_id, skip=0, limit=10)
|
||||
171
tests/unit/test_rag_service.py
Normal file
171
tests/unit/test_rag_service.py
Normal file
@ -0,0 +1,171 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from uuid import uuid4
|
||||
from tg_bot.application.services.rag_service import RAGService
|
||||
|
||||
|
||||
class TestRAGService:
|
||||
|
||||
@pytest.fixture
|
||||
def rag_service(self):
|
||||
service = RAGService()
|
||||
from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient
|
||||
service.deepseek_client = DeepSeekClient()
|
||||
return service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_documents_in_collections_success(self, rag_service):
|
||||
user_telegram_id = "123456789"
|
||||
query = "трудовой договор"
|
||||
|
||||
mock_documents = [
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"title": "Трудовой кодекс РФ",
|
||||
"content": "Содержание о трудовых договорах",
|
||||
"collection_name": "Законы"
|
||||
},
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"title": "Правила оформления",
|
||||
"content": "Как оформить трудовой договор",
|
||||
"collection_name": "Инструкции"
|
||||
}
|
||||
]
|
||||
|
||||
with patch('aiohttp.ClientSession') as mock_session:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
"user_id": str(uuid4())
|
||||
})
|
||||
|
||||
mock_collections_response = AsyncMock()
|
||||
mock_collections_response.status = 200
|
||||
mock_collections_response.json = AsyncMock(return_value=[
|
||||
{"collection_id": str(uuid4()), "name": "Законы"}
|
||||
])
|
||||
|
||||
mock_search_response = AsyncMock()
|
||||
mock_search_response.status = 200
|
||||
mock_search_response.json = AsyncMock(return_value=mock_documents)
|
||||
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.__aenter__ = AsyncMock(return_value=mock_session_instance)
|
||||
mock_session_instance.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_session_instance.get = AsyncMock(side_effect=[
|
||||
mock_response,
|
||||
mock_collections_response,
|
||||
mock_search_response
|
||||
])
|
||||
mock_session.return_value = mock_session_instance
|
||||
|
||||
result = await rag_service.search_documents_in_collections(
|
||||
user_telegram_id, query, limit_per_collection=5
|
||||
)
|
||||
|
||||
assert len(result) > 0
|
||||
assert result[0]["title"] == "Трудовой кодекс РФ"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_documents_empty_result(self, rag_service):
|
||||
user_telegram_id = "123456789"
|
||||
query = "несуществующий запрос"
|
||||
|
||||
with patch('aiohttp.ClientSession') as mock_session:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
"user_id": str(uuid4())
|
||||
})
|
||||
|
||||
mock_collections_response = AsyncMock()
|
||||
mock_collections_response.status = 200
|
||||
mock_collections_response.json = AsyncMock(return_value=[])
|
||||
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.__aenter__ = AsyncMock(return_value=mock_session_instance)
|
||||
mock_session_instance.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_session_instance.get = AsyncMock(side_effect=[
|
||||
mock_response,
|
||||
mock_collections_response
|
||||
])
|
||||
mock_session.return_value = mock_session_instance
|
||||
|
||||
result = await rag_service.search_documents_in_collections(
|
||||
user_telegram_id, query
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_answer_with_rag_success(self, rag_service, mock_rag_response):
|
||||
question = "Какие права имеет работник?"
|
||||
user_telegram_id = "123456789"
|
||||
|
||||
with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \
|
||||
patch.object(rag_service, 'deepseek_client') as mock_client:
|
||||
|
||||
mock_search.return_value = [
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"title": "Трудовой кодекс",
|
||||
"content": "Работник имеет право на...",
|
||||
"collection_name": "Законы"
|
||||
}
|
||||
]
|
||||
|
||||
mock_client.chat_completion = AsyncMock(return_value={
|
||||
"content": "Работник имеет следующие права...",
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}
|
||||
})
|
||||
|
||||
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
|
||||
|
||||
assert "answer" in result
|
||||
assert "sources" in result
|
||||
assert "usage" in result
|
||||
assert len(result["sources"]) <= 5
|
||||
assert result["answer"] != ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_answer_limits_to_top5(self, rag_service):
|
||||
question = "Тестовый вопрос"
|
||||
user_telegram_id = "123456789"
|
||||
|
||||
many_documents = [
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"title": f"Документ {i}",
|
||||
"content": f"Содержание {i}",
|
||||
"collection_name": "Коллекция"
|
||||
}
|
||||
for i in range(20)
|
||||
]
|
||||
|
||||
with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \
|
||||
patch.object(rag_service, 'deepseek_client') as mock_client:
|
||||
|
||||
mock_search.return_value = many_documents
|
||||
mock_client.chat_completion = AsyncMock(return_value={
|
||||
"content": "Ответ",
|
||||
"usage": {}
|
||||
})
|
||||
|
||||
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
|
||||
|
||||
assert len(result["sources"]) == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_answer_no_documents(self, rag_service):
|
||||
question = "Вопрос без документов"
|
||||
user_telegram_id = "123456789"
|
||||
|
||||
with patch.object(rag_service, 'search_documents_in_collections') as mock_search:
|
||||
mock_search.return_value = []
|
||||
|
||||
result = await rag_service.generate_answer_with_rag(question, user_telegram_id)
|
||||
|
||||
assert result["sources"] == []
|
||||
assert "Релевантные документы не найдены" in result.get("answer", "") or \
|
||||
result["answer"] == "No relevant documents found"
|
||||
193
tests/unit/test_user_service.py
Normal file
193
tests/unit/test_user_service.py
Normal file
@ -0,0 +1,193 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from datetime import datetime, timedelta
|
||||
from tg_bot.domain.services.user_service import UserService
|
||||
from tg_bot.infrastructure.database.models import UserModel
|
||||
|
||||
|
||||
class TestUserService:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
session = AsyncMock()
|
||||
session.execute = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.commit = AsyncMock()
|
||||
session.rollback = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def user_service(self, mock_session):
|
||||
return UserService(mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_telegram_id_success(self, user_service, mock_session):
|
||||
telegram_id = 123456789
|
||||
mock_user = UserModel(
|
||||
telegram_id=str(telegram_id),
|
||||
username="test_user",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none = MagicMock(return_value=mock_user)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await user_service.get_user_by_telegram_id(telegram_id)
|
||||
|
||||
assert result == mock_user
|
||||
assert result.telegram_id == str(telegram_id)
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_telegram_id_not_found(self, user_service, mock_session):
|
||||
telegram_id = 999999999
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none = MagicMock(return_value=None)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await user_service.get_user_by_telegram_id(telegram_id)
|
||||
|
||||
assert result is None
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_user_new_user(self, user_service, mock_session):
|
||||
telegram_id = 123456789
|
||||
username = "new_user"
|
||||
first_name = "New"
|
||||
last_name = "User"
|
||||
|
||||
mock_result_not_found = MagicMock()
|
||||
mock_result_not_found.scalar_one_or_none = MagicMock(return_value=None)
|
||||
|
||||
mock_result_found = MagicMock()
|
||||
created_user = UserModel(
|
||||
telegram_id=str(telegram_id),
|
||||
username=username,
|
||||
first_name=first_name,
|
||||
last_name=last_name
|
||||
)
|
||||
mock_result_found.scalar_one_or_none = MagicMock(return_value=created_user)
|
||||
|
||||
mock_session.execute.side_effect = [mock_result_not_found, mock_result_found]
|
||||
|
||||
result = await user_service.get_or_create_user(telegram_id, username, first_name, last_name)
|
||||
|
||||
assert result is not None
|
||||
assert result.telegram_id == str(telegram_id)
|
||||
assert result.username == username
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_user_existing_user(self, user_service, mock_session):
|
||||
telegram_id = 123456789
|
||||
existing_user = UserModel(
|
||||
telegram_id=str(telegram_id),
|
||||
username="old_username",
|
||||
first_name="Old",
|
||||
last_name="Name"
|
||||
)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none = MagicMock(return_value=existing_user)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await user_service.get_or_create_user(
|
||||
telegram_id, "new_username", "New", "Name"
|
||||
)
|
||||
|
||||
assert result == existing_user
|
||||
assert result.username == "new_username"
|
||||
assert result.first_name == "New"
|
||||
assert result.last_name == "Name"
|
||||
mock_session.commit.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_questions_success(self, user_service, mock_session):
|
||||
telegram_id = 123456789
|
||||
user = UserModel(
|
||||
telegram_id=str(telegram_id),
|
||||
questions_used=5
|
||||
)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none = MagicMock(return_value=user)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await user_service.update_user_questions(telegram_id)
|
||||
|
||||
assert result is True
|
||||
assert user.questions_used == 6
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_questions_user_not_found(self, user_service, mock_session):
|
||||
telegram_id = 999999999
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none = MagicMock(return_value=None)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await user_service.update_user_questions(telegram_id)
|
||||
|
||||
assert result is False
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_activate_premium_success(self, user_service, mock_session):
|
||||
telegram_id = 123456789
|
||||
user = UserModel(
|
||||
telegram_id=str(telegram_id),
|
||||
is_premium=False,
|
||||
premium_until=None
|
||||
)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none = MagicMock(return_value=user)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await user_service.activate_premium(telegram_id)
|
||||
|
||||
assert result is True
|
||||
assert user.is_premium is True
|
||||
assert user.premium_until is not None
|
||||
assert user.premium_until > datetime.now()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_activate_premium_extend_existing(self, user_service, mock_session):
|
||||
telegram_id = 123456789
|
||||
existing_premium_until = datetime.now() + timedelta(days=10)
|
||||
user = UserModel(
|
||||
telegram_id=str(telegram_id),
|
||||
is_premium=True,
|
||||
premium_until=existing_premium_until
|
||||
)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none = MagicMock(return_value=user)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await user_service.activate_premium(telegram_id)
|
||||
|
||||
assert result is True
|
||||
assert user.is_premium is True
|
||||
assert user.premium_until > existing_premium_until
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_activate_premium_user_not_found(self, user_service, mock_session):
|
||||
telegram_id = 999999999
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none = MagicMock(return_value=None)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await user_service.activate_premium(telegram_id)
|
||||
|
||||
assert result is False
|
||||
mock_session.commit.assert_not_called()
|
||||
@ -4,8 +4,7 @@ python-dotenv>=1.0.0
|
||||
aiogram>=3.10.0
|
||||
sqlalchemy>=2.0.0
|
||||
aiosqlite>=0.19.0
|
||||
httpx>=0.25.0
|
||||
httpx>=0.25.2
|
||||
yookassa>=2.4.0
|
||||
fastapi>=0.104.0
|
||||
uvicorn>=0.24.0
|
||||
python-multipart>=0.0.6
|
||||
aiohttp>=3.9.1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user