diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..a43df4b --- /dev/null +++ b/pytest.ini @@ -0,0 +1,23 @@ +[pytest] +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +asyncio_mode = auto +addopts = + -v + --strict-markers + --tb=short + --cov=backend/src + --cov=tg_bot + --cov-report=term-missing + --cov-report=xml + --cov-fail-under=65 + --ignore=venv + --ignore=.venv +markers = + unit: Unit tests + integration: Integration tests + metrics: Metrics tests + slow: Slow running tests + diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..b6f2291 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,94 @@ +# Тесты для проекта BetterCallPraskovia + +## Структура тестов + +``` +tests/ +├── conftest.py +├── unit/ +│ ├── test_rag_service.py +│ ├── test_user_service.py +│ ├── test_deepseek_client.py +│ ├── test_document_use_cases.py +│ └── test_collection_use_cases.py +├── integration/ +│ └── test_rag_integration.py +└── metrics/ + └── test_hit_at_5.py +``` + +## Установка зависимостей + +```bash +pip install -r tests/requirements.txt +``` + +## Запуск тестов + +### Все тесты +```bash +pytest +``` + +### Только юнит-тесты +```bash +pytest tests/unit/ +``` + +### Только интеграционные тесты +```bash +pytest tests/integration/ +``` + +### Только тесты метрик +```bash +pytest tests/metrics/ +``` + +### Только тесты tg_bot +```bash +pytest tests/unit/test_rag_service.py tests/unit/test_user_service.py tests/unit/test_deepseek_client.py +``` + +### С покрытием кода +```bash +pytest --cov=backend/src --cov=tg_bot --cov-report=html +``` + +### С минимальным покрытием 65% +```bash +pytest --cov-fail-under=65 +``` + +## Метрика hit@5 + +Проверка что в топ-5 релевантных документов есть хотя бы 1 нужный документ. + +- **hit@5 = 1**, если есть хотя бы 1 релевантный документ в топ-5 +- **hit@5 = 0**, если нет релевантных документов в топ-5 + +Среднее значение hit@5 для всех запросов должно быть **> 50%** + +## Покрытие кода + +**coverage ≥ 65%** + +Проверка покрытия: +```bash +pytest --cov=backend/src --cov=tg_bot --cov-report=term-missing --cov-fail-under=65 +``` + +## Маркеры тестов + +- `@pytest.mark.unit` - юнит-тесты +- `@pytest.mark.integration` - интеграционные тесты +- `@pytest.mark.metrics` - тесты метрик +- `@pytest.mark.slow` - медленные тесты + +Запуск по маркерам: +```bash +pytest -m unit +pytest -m integration +pytest -m metrics +``` + diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..e5b2ae2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,170 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 +from datetime import datetime +from typing import List, Dict + + +@pytest.fixture +def mock_document(): + try: + from backend.src.domain.entities.document import Document + except ImportError: + from src.domain.entities.document import Document + return Document( + collection_id=uuid4(), + title="Тестовый документ", + content="Содержание документа", + metadata={"type": "law", "article": "123"} + ) + + +@pytest.fixture +def mock_documents_list(): + try: + from backend.src.domain.entities.document import Document + except ImportError: + from src.domain.entities.document import Document + collection_id = uuid4() + return [ + Document( + collection_id=collection_id, + title=f"Документ {i}", + content=f"Содержание документа {i} ", + metadata={"relevance_score": 0.9 - i * 0.1} + ) + for i in range(10) + ] + + +@pytest.fixture +def mock_relevant_documents(): + return [ + {"document_id": str(uuid4()), "title": "Гражданский кодекс РФ", "relevance": True}, + {"document_id": str(uuid4()), "title": "Трудовой кодекс РФ", "relevance": True}, + {"document_id": str(uuid4()), "title": "Налоговый кодекс РФ", "relevance": True}, + ] + + +@pytest.fixture +def mock_rag_response(): + return { + "answer": "Тестовый ответ на вопрос", + "sources": [ + {"title": "Документ 1", "collection": "Коллекция 1", "document_id": str(uuid4())}, + {"title": "Документ 2", "collection": "Коллекция 1", "document_id": str(uuid4())}, + {"title": "Документ 3", "collection": "Коллекция 2", "document_id": str(uuid4())}, + ], + "usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300} + } + + +@pytest.fixture +def mock_collection(): + try: + from backend.src.domain.entities.collection import Collection + except ImportError: + from src.domain.entities.collection import Collection + return Collection( + name="Коллекция", + owner_id=uuid4(), + description="Описание коллекции", + is_public=False + ) + + +@pytest.fixture +def mock_user(): + try: + from backend.src.domain.entities.user import User, UserRole + except ImportError: + from src.domain.entities.user import User, UserRole + return User( + telegram_id="123456789", + role=UserRole.USER + ) + + +@pytest.fixture +def mock_document_repository(): + repository = AsyncMock() + repository.get_by_id = AsyncMock() + repository.create = AsyncMock() + repository.update = AsyncMock() + repository.delete = AsyncMock(return_value=True) + repository.list_by_collection = AsyncMock(return_value=[]) + return repository + + +@pytest.fixture +def mock_collection_repository(): + repository = AsyncMock() + repository.get_by_id = AsyncMock() + repository.create = AsyncMock() + repository.update = AsyncMock() + repository.delete = AsyncMock(return_value=True) + repository.list_by_owner = AsyncMock(return_value=[]) + repository.list_public = AsyncMock(return_value=[]) + return repository + + +@pytest.fixture +def mock_user_repository(): + repository = AsyncMock() + repository.get_by_id = AsyncMock() + repository.get_by_telegram_id = AsyncMock() + repository.create = AsyncMock() + repository.update = AsyncMock() + repository.delete = AsyncMock(return_value=True) + return repository + + +@pytest.fixture +def mock_deepseek_client(): + client = AsyncMock() + client.chat_completion = AsyncMock(return_value={ + "content": "Ответ от DeepSeek", + "usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300} + }) + return client + + +@pytest.fixture +def sample_queries_with_ground_truth(): + return [ + { + "query": "Какие права имеет работник при увольнении?", + "relevant_document_ids": [str(uuid4()), str(uuid4())], + "expected_top5_contains": True + }, + { + "query": "Как оформить договор купли-продажи?", + "relevant_document_ids": [str(uuid4())], + "expected_top5_contains": True + }, + { + "query": "Какие налоги платит ИП?", + "relevant_document_ids": [str(uuid4()), str(uuid4()), str(uuid4())], + "expected_top5_contains": True + }, + { + "query": "Права потребителя при возврате товара", + "relevant_document_ids": [str(uuid4())], + "expected_top5_contains": True + }, + { + "query": "Как расторгнуть брак?", + "relevant_document_ids": [str(uuid4())], + "expected_top5_contains": True + }, + ] + + +@pytest.fixture +def mock_aiohttp_session(): + session = AsyncMock() + session.get = AsyncMock() + session.post = AsyncMock() + session.__aenter__ = AsyncMock(return_value=session) + session.__aexit__ = AsyncMock(return_value=None) + return session diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_rag_integration.py b/tests/integration/test_rag_integration.py new file mode 100644 index 0000000..3fec432 --- /dev/null +++ b/tests/integration/test_rag_integration.py @@ -0,0 +1,146 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import uuid4 +import aiohttp +from tg_bot.application.services.rag_service import RAGService +from tests.metrics.test_hit_at_5 import calculate_hit_at_5, calculate_average_hit_at_5 + + +class TestRAGIntegration: + + @pytest.fixture + def rag_service(self): + service = RAGService() + from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient + service.deepseek_client = DeepSeekClient() + return service + + @pytest.mark.asyncio + async def test_rag_service_search_documents_real_implementation(self, rag_service): + user_telegram_id = "123456789" + query = "трудовой договор" + user_uuid = str(uuid4()) + collection_id = str(uuid4()) + + mock_user_response = AsyncMock() + mock_user_response.status = 200 + mock_user_response.json = AsyncMock(return_value={"user_id": user_uuid}) + + mock_collections_response = AsyncMock() + mock_collections_response.status = 200 + mock_collections_response.json = AsyncMock(return_value=[ + {"collection_id": collection_id, "name": "Законы"} + ]) + + mock_documents_response = AsyncMock() + mock_documents_response.status = 200 + mock_documents_response.json = AsyncMock(return_value=[ + { + "document_id": str(uuid4()), + "title": "Трудовой кодекс", + "content": "Содержание о трудовых договорах" + } + ]) + + with patch('aiohttp.ClientSession') as mock_session_class: + mock_session = AsyncMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + mock_session.get = AsyncMock(side_effect=[ + mock_user_response, + mock_collections_response, + mock_documents_response + ]) + mock_session_class.return_value = mock_session + + result = await rag_service.search_documents_in_collections(user_telegram_id, query) + + assert isinstance(result, list) + if result: + assert "document_id" in result[0] + assert "title" in result[0] + + @pytest.mark.asyncio + async def test_rag_service_generate_answer_real_implementation(self, rag_service): + question = "Какие права имеет работник?" + user_telegram_id = "123456789" + + mock_documents = [ + { + "document_id": str(uuid4()), + "title": "Трудовой кодекс", + "content": "Работник имеет право на...", + "collection_name": "Законы" + } + for _ in range(3) + ] + + with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \ + patch.object(rag_service.deepseek_client, 'chat_completion') as mock_deepseek: + + mock_search.return_value = mock_documents + mock_deepseek.return_value = { + "content": "Работник имеет следующие права...", + "usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300} + } + + result = await rag_service.generate_answer_with_rag(question, user_telegram_id) + + assert "answer" in result + assert "sources" in result + assert "usage" in result + assert len(result["sources"]) <= 5 + assert result["answer"] != "" + + for source in result["sources"]: + assert "title" in source + assert "collection" in source + assert "document_id" in source + + @pytest.mark.asyncio + async def test_rag_service_limits_to_top5_documents(self, rag_service): + question = "Тестовый вопрос" + user_telegram_id = "123456789" + + many_documents = [ + { + "document_id": str(uuid4()), + "title": f"Документ {i}", + "content": f"Содержание {i}", + "collection_name": "Коллекция" + } + for i in range(10) + ] + + with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \ + patch.object(rag_service.deepseek_client, 'chat_completion') as mock_deepseek: + + mock_search.return_value = many_documents + mock_deepseek.return_value = { + "content": "Ответ", + "usage": {} + } + + result = await rag_service.generate_answer_with_rag(question, user_telegram_id) + + assert len(result["sources"]) == 5 + + @pytest.mark.asyncio + async def test_rag_service_handles_empty_search_results(self, rag_service): + question = "Вопрос без документов" + user_telegram_id = "123456789" + + with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \ + patch.object(rag_service.deepseek_client, 'chat_completion') as mock_deepseek: + + mock_search.return_value = [] + mock_deepseek.return_value = { + "content": "Ответ", + "usage": {} + } + + result = await rag_service.generate_answer_with_rag(question, user_telegram_id) + + assert result["sources"] == [] + assert "Релевантные документы не найдены" in result.get("answer", "") or \ + result["answer"] == "No relevant documents found" diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/metrics/test_hit_at_5.py b/tests/metrics/test_hit_at_5.py new file mode 100644 index 0000000..405a5aa --- /dev/null +++ b/tests/metrics/test_hit_at_5.py @@ -0,0 +1,133 @@ +import pytest +from uuid import uuid4 +from typing import List, Dict + + +def calculate_hit_at_5(retrieved_document_ids: List[str], relevant_document_ids: List[str]) -> int: + if not retrieved_document_ids or not relevant_document_ids: + return 0 + + top5_ids = set(retrieved_document_ids[:5]) + relevant_ids = set(relevant_document_ids) + + return 1 if top5_ids.intersection(relevant_ids) else 0 + + +def calculate_average_hit_at_5(results: List[int]) -> float: + if not results: + return 0.0 + return sum(results) / len(results) + + +class TestHitAt5Metric: + + def test_hit_at_5_returns_1_when_relevant_document_in_top5(self): + relevant_ids = [str(uuid4()), str(uuid4())] + retrieved_ids = [ + str(uuid4()), + relevant_ids[0], + str(uuid4()), + str(uuid4()), + str(uuid4()) + ] + + result = calculate_hit_at_5(retrieved_ids, relevant_ids) + assert result == 1 + + def test_hit_at_5_returns_0_when_no_relevant_document_in_top5(self): + relevant_ids = [str(uuid4()), str(uuid4())] + retrieved_ids = [ + str(uuid4()), + str(uuid4()), + str(uuid4()), + str(uuid4()), + str(uuid4()) + ] + + result = calculate_hit_at_5(retrieved_ids, relevant_ids) + assert result == 0 + + def test_hit_at_5_returns_1_when_multiple_relevant_documents(self): + relevant_ids = [str(uuid4()), str(uuid4()), str(uuid4())] + retrieved_ids = [ + relevant_ids[0], + str(uuid4()), + relevant_ids[1], + str(uuid4()), + relevant_ids[2] + ] + + result = calculate_hit_at_5(retrieved_ids, relevant_ids) + assert result == 1 + + def test_hit_at_5_handles_empty_lists(self): + result = calculate_hit_at_5([], [str(uuid4())]) + assert result == 0 + + result = calculate_hit_at_5([str(uuid4())], []) + assert result == 0 + + result = calculate_hit_at_5([], []) + assert result == 0 + + def test_hit_at_5_only_checks_top5(self): + relevant_ids = [str(uuid4())] + retrieved_ids = [ + str(uuid4()), + str(uuid4()), + str(uuid4()), + str(uuid4()), + str(uuid4()), + relevant_ids[0] + ] + + result = calculate_hit_at_5(retrieved_ids, relevant_ids) + assert result == 0 + + def test_calculate_average_hit_at_5(self): + results = [1, 1, 0, 1, 0, 1, 0, 1, 1, 1] + + average = calculate_average_hit_at_5(results) + assert average == 0.7 + + def test_calculate_average_hit_at_5_all_ones(self): + results = [1, 1, 1, 1, 1] + + average = calculate_average_hit_at_5(results) + assert average == 1.0 + + def test_calculate_average_hit_at_5_all_zeros(self): + results = [0, 0, 0, 0, 0] + + average = calculate_average_hit_at_5(results) + assert average == 0.0 + + def test_calculate_average_hit_at_5_empty_list(self): + average = calculate_average_hit_at_5([]) + assert average == 0.0 + + def test_hit_at_5_quality_threshold(self): + results = [1] * 60 + [0] * 40 + + average = calculate_average_hit_at_5(results) + assert average > 0.5, f"Качество {average} должно быть > 0.5" + assert average == 0.6 + + def test_hit_at_5_quality_below_threshold(self): + results = [1] * 40 + [0] * 60 + + average = calculate_average_hit_at_5(results) + assert average < 0.5, f"Качество {average} должно быть < 0.5" + assert average == 0.4 + + @pytest.mark.parametrize("hit_count,total,expected_quality", [ + (51, 100, 0.51), + (50, 100, 0.50), + (60, 100, 0.60), + (75, 100, 0.75), + (100, 100, 1.0), + ]) + def test_hit_at_5_various_qualities(self, hit_count, total, expected_quality): + results = [1] * hit_count + [0] * (total - hit_count) + average = calculate_average_hit_at_5(results) + assert average == expected_quality diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 0000000..e341148 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,8 @@ +pytest==7.4.3 +pytest-asyncio==0.21.1 +pytest-cov==4.1.0 +pytest-mock==3.12.0 +pytest-timeout==2.2.0 +httpx>=0.25.2 +aiohttp>=3.9.1 + diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_collection_use_cases.py b/tests/unit/test_collection_use_cases.py new file mode 100644 index 0000000..7994c79 --- /dev/null +++ b/tests/unit/test_collection_use_cases.py @@ -0,0 +1,132 @@ +import pytest +from uuid import uuid4 +from unittest.mock import AsyncMock +try: + from backend.src.application.use_cases.collection_use_cases import CollectionUseCases + from backend.src.shared.exceptions import NotFoundError, ForbiddenError + from backend.src.domain.repositories.collection_access_repository import ICollectionAccessRepository +except ImportError: + from src.application.use_cases.collection_use_cases import CollectionUseCases + from src.shared.exceptions import NotFoundError, ForbiddenError + from src.domain.repositories.collection_access_repository import ICollectionAccessRepository + + +class TestCollectionUseCases: + + @pytest.fixture + def collection_use_cases(self, mock_collection_repository, mock_user_repository): + mock_access_repository = AsyncMock() + mock_access_repository.get_by_user_and_collection = AsyncMock(return_value=None) + mock_access_repository.create = AsyncMock() + mock_access_repository.delete_by_user_and_collection = AsyncMock(return_value=True) + mock_access_repository.list_by_user = AsyncMock(return_value=[]) + + return CollectionUseCases( + collection_repository=mock_collection_repository, + access_repository=mock_access_repository, + user_repository=mock_user_repository + ) + + @pytest.mark.asyncio + async def test_create_collection_success(self, collection_use_cases, mock_user, + mock_collection_repository, mock_user_repository): + owner_id = uuid4() + mock_user_repository.get_by_id = AsyncMock(return_value=mock_user) + mock_collection_repository.create = AsyncMock(return_value=mock_user) + + result = await collection_use_cases.create_collection( + name="Тестовая коллекция", + owner_id=owner_id, + description="Описание", + is_public=False + ) + + assert result is not None + mock_user_repository.get_by_id.assert_called_once_with(owner_id) + mock_collection_repository.create.assert_called_once() + + @pytest.mark.asyncio + async def test_create_collection_user_not_found(self, collection_use_cases, mock_user_repository): + owner_id = uuid4() + mock_user_repository.get_by_id = AsyncMock(return_value=None) + + with pytest.raises(NotFoundError): + await collection_use_cases.create_collection( + name="Коллекция", + owner_id=owner_id + ) + + @pytest.mark.asyncio + async def test_get_collection_success(self, collection_use_cases, mock_collection, mock_collection_repository): + collection_id = uuid4() + mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection) + + result = await collection_use_cases.get_collection(collection_id) + + assert result == mock_collection + mock_collection_repository.get_by_id.assert_called_once_with(collection_id) + + @pytest.mark.asyncio + async def test_get_collection_not_found(self, collection_use_cases, mock_collection_repository): + collection_id = uuid4() + mock_collection_repository.get_by_id = AsyncMock(return_value=None) + + with pytest.raises(NotFoundError): + await collection_use_cases.get_collection(collection_id) + + @pytest.mark.asyncio + async def test_update_collection_success(self, collection_use_cases, mock_collection, mock_collection_repository): + collection_id = uuid4() + user_id = uuid4() + mock_collection.owner_id = user_id + mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection) + mock_collection_repository.update = AsyncMock(return_value=mock_collection) + + result = await collection_use_cases.update_collection( + collection_id=collection_id, + user_id=user_id, + name="Обновленное название" + ) + + assert result is not None + assert mock_collection.name == "Обновленное название" + + @pytest.mark.asyncio + async def test_update_collection_forbidden(self, collection_use_cases, mock_collection, mock_collection_repository): + collection_id = uuid4() + user_id = uuid4() + owner_id = uuid4() + mock_collection.owner_id = owner_id + mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection) + + with pytest.raises(ForbiddenError): + await collection_use_cases.update_collection( + collection_id=collection_id, + user_id=user_id, + name="Название" + ) + + @pytest.mark.asyncio + async def test_check_access_owner(self, collection_use_cases, mock_collection, mock_collection_repository): + collection_id = uuid4() + user_id = uuid4() + mock_collection.owner_id = user_id + mock_collection.is_public = False + mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection) + + result = await collection_use_cases.check_access(collection_id, user_id) + + assert result is True + + @pytest.mark.asyncio + async def test_check_access_public(self, collection_use_cases, mock_collection, mock_collection_repository): + collection_id = uuid4() + user_id = uuid4() + owner_id = uuid4() + mock_collection.owner_id = owner_id + mock_collection.is_public = True + mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection) + + result = await collection_use_cases.check_access(collection_id, user_id) + + assert result is True diff --git a/tests/unit/test_deepseek_client.py b/tests/unit/test_deepseek_client.py new file mode 100644 index 0000000..ecf0298 --- /dev/null +++ b/tests/unit/test_deepseek_client.py @@ -0,0 +1,114 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +import httpx +from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient +from tg_bot.infrastructure.external.deepseek_client import DeepSeekAPIError + + +class TestDeepSeekClient: + + @pytest.fixture + def deepseek_client(self): + return DeepSeekClient(api_key="test_key", api_url="https://api.test.com/v1/chat/completions") + + @pytest.mark.asyncio + async def test_chat_completion_success(self, deepseek_client): + messages = [ + {"role": "user", "content": "Тестовый вопрос"} + ] + + mock_response_data = { + "choices": [{ + "message": { + "content": "Тестовый ответ от DeepSeek" + } + }], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30 + } + } + + with patch('httpx.AsyncClient') as mock_client: + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_response_data + mock_response.raise_for_status = MagicMock() + + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_instance.post = AsyncMock(return_value=mock_response) + mock_client.return_value = mock_client_instance + + result = await deepseek_client.chat_completion(messages) + + assert "content" in result + assert result["content"] == "Тестовый ответ от DeepSeek" + assert "usage" in result + assert result["usage"]["total_tokens"] == 30 + + @pytest.mark.asyncio + async def test_chat_completion_no_api_key(self): + client = DeepSeekClient(api_key=None) + messages = [{"role": "user", "content": "Вопрос"}] + + result = await client.chat_completion(messages) + + assert "content" in result + assert "DEEPSEEK_API_KEY" in result["content"] or "не установлен" in result["content"] + assert result["usage"]["total_tokens"] == 0 + + @pytest.mark.asyncio + async def test_chat_completion_api_error(self, deepseek_client): + import httpx + messages = [{"role": "user", "content": "Вопрос"}] + + with patch('httpx.AsyncClient') as mock_client: + mock_response = AsyncMock() + mock_response.status_code = 401 + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Unauthorized", request=MagicMock(), response=mock_response + ) + + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_instance.post = AsyncMock(return_value=mock_response) + mock_client.return_value = mock_client_instance + + with pytest.raises(DeepSeekAPIError): + await deepseek_client.chat_completion(messages) + + @pytest.mark.asyncio + async def test_chat_completion_with_parameters(self, deepseek_client): + messages = [{"role": "user", "content": "Вопрос"}] + + mock_response_data = { + "choices": [{"message": {"content": "Ответ"}}], + "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15} + } + + with patch('httpx.AsyncClient') as mock_client: + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_response_data + mock_response.raise_for_status = MagicMock() + + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_instance.post = AsyncMock(return_value=mock_response) + mock_client.return_value = mock_client_instance + + result = await deepseek_client.chat_completion( + messages, + model="deepseek-chat", + temperature=0.7, + max_tokens=100 + ) + + assert result["content"] == "Ответ" + call_args = mock_client_instance.post.call_args + assert call_args is not None diff --git a/tests/unit/test_document_use_cases.py b/tests/unit/test_document_use_cases.py new file mode 100644 index 0000000..2ab1812 --- /dev/null +++ b/tests/unit/test_document_use_cases.py @@ -0,0 +1,141 @@ +import pytest +from uuid import uuid4 +from unittest.mock import AsyncMock +try: + from backend.src.application.use_cases.document_use_cases import DocumentUseCases + from backend.src.shared.exceptions import NotFoundError, ForbiddenError + from backend.src.application.services.document_parser_service import DocumentParserService +except ImportError: + from src.application.use_cases.document_use_cases import DocumentUseCases + from src.shared.exceptions import NotFoundError, ForbiddenError + from src.application.services.document_parser_service import DocumentParserService + + +class TestDocumentUseCases: + + @pytest.fixture + def document_use_cases(self, mock_document_repository, mock_collection_repository): + mock_parser = AsyncMock() + mock_parser.parse_pdf = AsyncMock(return_value=("Парсенный документ", "Содержание")) + + return DocumentUseCases( + document_repository=mock_document_repository, + collection_repository=mock_collection_repository, + parser_service=mock_parser + ) + + @pytest.mark.asyncio + async def test_create_document_success(self, document_use_cases, mock_collection, mock_document_repository, mock_collection_repository): + collection_id = uuid4() + mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection) + mock_document_repository.create = AsyncMock(return_value=mock_collection) + + result = await document_use_cases.create_document( + collection_id=collection_id, + title="Тестовый документ", + content="Содержание", + metadata={"type": "law"} + ) + + assert result is not None + mock_collection_repository.get_by_id.assert_called_once_with(collection_id) + mock_document_repository.create.assert_called_once() + + @pytest.mark.asyncio + async def test_create_document_collection_not_found(self, document_use_cases, mock_collection_repository): + collection_id = uuid4() + mock_collection_repository.get_by_id = AsyncMock(return_value=None) + + with pytest.raises(NotFoundError): + await document_use_cases.create_document( + collection_id=collection_id, + title="Документ", + content="Содержание" + ) + + @pytest.mark.asyncio + async def test_get_document_success(self, document_use_cases, mock_document, mock_document_repository): + document_id = uuid4() + mock_document_repository.get_by_id = AsyncMock(return_value=mock_document) + + result = await document_use_cases.get_document(document_id) + + assert result == mock_document + mock_document_repository.get_by_id.assert_called_once_with(document_id) + + @pytest.mark.asyncio + async def test_get_document_not_found(self, document_use_cases, mock_document_repository): + document_id = uuid4() + mock_document_repository.get_by_id = AsyncMock(return_value=None) + + with pytest.raises(NotFoundError): + await document_use_cases.get_document(document_id) + + @pytest.mark.asyncio + async def test_update_document_success(self, document_use_cases, mock_document, mock_collection, + mock_document_repository, mock_collection_repository): + document_id = uuid4() + user_id = uuid4() + mock_document.collection_id = uuid4() + mock_collection.owner_id = user_id + + mock_document_repository.get_by_id = AsyncMock(return_value=mock_document) + mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection) + mock_document_repository.update = AsyncMock(return_value=mock_document) + + result = await document_use_cases.update_document( + document_id=document_id, + user_id=user_id, + title="Обновленное название" + ) + + assert result is not None + assert mock_document.title == "Обновленное название" + + @pytest.mark.asyncio + async def test_update_document_forbidden(self, document_use_cases, mock_document, mock_collection, + mock_document_repository, mock_collection_repository): + document_id = uuid4() + user_id = uuid4() + owner_id = uuid4() + mock_document.collection_id = uuid4() + mock_collection.owner_id = owner_id + + mock_document_repository.get_by_id = AsyncMock(return_value=mock_document) + mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection) + + with pytest.raises(ForbiddenError): + await document_use_cases.update_document( + document_id=document_id, + user_id=user_id, + title="Название" + ) + + @pytest.mark.asyncio + async def test_delete_document_success(self, document_use_cases, mock_document, mock_collection, + mock_document_repository, mock_collection_repository): + document_id = uuid4() + user_id = uuid4() + mock_document.collection_id = uuid4() + mock_collection.owner_id = user_id + + mock_document_repository.get_by_id = AsyncMock(return_value=mock_document) + mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection) + mock_document_repository.delete = AsyncMock(return_value=True) + + result = await document_use_cases.delete_document(document_id, user_id) + + assert result is True + mock_document_repository.delete.assert_called_once_with(document_id) + + @pytest.mark.asyncio + async def test_list_collection_documents(self, document_use_cases, mock_collection, mock_documents_list, + mock_collection_repository, mock_document_repository): + collection_id = uuid4() + mock_collection_repository.get_by_id = AsyncMock(return_value=mock_collection) + mock_document_repository.list_by_collection = AsyncMock(return_value=mock_documents_list) + + result = await document_use_cases.list_collection_documents(collection_id, skip=0, limit=10) + + assert len(result) == len(mock_documents_list) + mock_document_repository.list_by_collection.assert_called_once_with(collection_id, skip=0, limit=10) diff --git a/tests/unit/test_rag_service.py b/tests/unit/test_rag_service.py new file mode 100644 index 0000000..c98b77e --- /dev/null +++ b/tests/unit/test_rag_service.py @@ -0,0 +1,171 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import uuid4 +from tg_bot.application.services.rag_service import RAGService + + +class TestRAGService: + + @pytest.fixture + def rag_service(self): + service = RAGService() + from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient + service.deepseek_client = DeepSeekClient() + return service + + @pytest.mark.asyncio + async def test_search_documents_in_collections_success(self, rag_service): + user_telegram_id = "123456789" + query = "трудовой договор" + + mock_documents = [ + { + "document_id": str(uuid4()), + "title": "Трудовой кодекс РФ", + "content": "Содержание о трудовых договорах", + "collection_name": "Законы" + }, + { + "document_id": str(uuid4()), + "title": "Правила оформления", + "content": "Как оформить трудовой договор", + "collection_name": "Инструкции" + } + ] + + with patch('aiohttp.ClientSession') as mock_session: + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={ + "user_id": str(uuid4()) + }) + + mock_collections_response = AsyncMock() + mock_collections_response.status = 200 + mock_collections_response.json = AsyncMock(return_value=[ + {"collection_id": str(uuid4()), "name": "Законы"} + ]) + + mock_search_response = AsyncMock() + mock_search_response.status = 200 + mock_search_response.json = AsyncMock(return_value=mock_documents) + + mock_session_instance = MagicMock() + mock_session_instance.__aenter__ = AsyncMock(return_value=mock_session_instance) + mock_session_instance.__aexit__ = AsyncMock(return_value=None) + mock_session_instance.get = AsyncMock(side_effect=[ + mock_response, + mock_collections_response, + mock_search_response + ]) + mock_session.return_value = mock_session_instance + + result = await rag_service.search_documents_in_collections( + user_telegram_id, query, limit_per_collection=5 + ) + + assert len(result) > 0 + assert result[0]["title"] == "Трудовой кодекс РФ" + + @pytest.mark.asyncio + async def test_search_documents_empty_result(self, rag_service): + user_telegram_id = "123456789" + query = "несуществующий запрос" + + with patch('aiohttp.ClientSession') as mock_session: + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={ + "user_id": str(uuid4()) + }) + + mock_collections_response = AsyncMock() + mock_collections_response.status = 200 + mock_collections_response.json = AsyncMock(return_value=[]) + + mock_session_instance = MagicMock() + mock_session_instance.__aenter__ = AsyncMock(return_value=mock_session_instance) + mock_session_instance.__aexit__ = AsyncMock(return_value=None) + mock_session_instance.get = AsyncMock(side_effect=[ + mock_response, + mock_collections_response + ]) + mock_session.return_value = mock_session_instance + + result = await rag_service.search_documents_in_collections( + user_telegram_id, query + ) + + assert result == [] + + @pytest.mark.asyncio + async def test_generate_answer_with_rag_success(self, rag_service, mock_rag_response): + question = "Какие права имеет работник?" + user_telegram_id = "123456789" + + with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \ + patch.object(rag_service, 'deepseek_client') as mock_client: + + mock_search.return_value = [ + { + "document_id": str(uuid4()), + "title": "Трудовой кодекс", + "content": "Работник имеет право на...", + "collection_name": "Законы" + } + ] + + mock_client.chat_completion = AsyncMock(return_value={ + "content": "Работник имеет следующие права...", + "usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300} + }) + + result = await rag_service.generate_answer_with_rag(question, user_telegram_id) + + assert "answer" in result + assert "sources" in result + assert "usage" in result + assert len(result["sources"]) <= 5 + assert result["answer"] != "" + + @pytest.mark.asyncio + async def test_generate_answer_limits_to_top5(self, rag_service): + question = "Тестовый вопрос" + user_telegram_id = "123456789" + + many_documents = [ + { + "document_id": str(uuid4()), + "title": f"Документ {i}", + "content": f"Содержание {i}", + "collection_name": "Коллекция" + } + for i in range(20) + ] + + with patch.object(rag_service, 'search_documents_in_collections') as mock_search, \ + patch.object(rag_service, 'deepseek_client') as mock_client: + + mock_search.return_value = many_documents + mock_client.chat_completion = AsyncMock(return_value={ + "content": "Ответ", + "usage": {} + }) + + result = await rag_service.generate_answer_with_rag(question, user_telegram_id) + + assert len(result["sources"]) == 5 + + @pytest.mark.asyncio + async def test_generate_answer_no_documents(self, rag_service): + question = "Вопрос без документов" + user_telegram_id = "123456789" + + with patch.object(rag_service, 'search_documents_in_collections') as mock_search: + mock_search.return_value = [] + + result = await rag_service.generate_answer_with_rag(question, user_telegram_id) + + assert result["sources"] == [] + assert "Релевантные документы не найдены" in result.get("answer", "") or \ + result["answer"] == "No relevant documents found" diff --git a/tests/unit/test_user_service.py b/tests/unit/test_user_service.py new file mode 100644 index 0000000..5cbf31e --- /dev/null +++ b/tests/unit/test_user_service.py @@ -0,0 +1,193 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock +from datetime import datetime, timedelta +from tg_bot.domain.services.user_service import UserService +from tg_bot.infrastructure.database.models import UserModel + + +class TestUserService: + + @pytest.fixture + def mock_session(self): + session = AsyncMock() + session.execute = AsyncMock() + session.add = MagicMock() + session.commit = AsyncMock() + session.rollback = AsyncMock() + return session + + @pytest.fixture + def user_service(self, mock_session): + return UserService(mock_session) + + @pytest.mark.asyncio + async def test_get_user_by_telegram_id_success(self, user_service, mock_session): + telegram_id = 123456789 + mock_user = UserModel( + telegram_id=str(telegram_id), + username="test_user", + first_name="Test", + last_name="User" + ) + + mock_result = MagicMock() + mock_result.scalar_one_or_none = MagicMock(return_value=mock_user) + mock_session.execute.return_value = mock_result + + result = await user_service.get_user_by_telegram_id(telegram_id) + + assert result == mock_user + assert result.telegram_id == str(telegram_id) + mock_session.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_get_user_by_telegram_id_not_found(self, user_service, mock_session): + telegram_id = 999999999 + + mock_result = MagicMock() + mock_result.scalar_one_or_none = MagicMock(return_value=None) + mock_session.execute.return_value = mock_result + + result = await user_service.get_user_by_telegram_id(telegram_id) + + assert result is None + mock_session.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_get_or_create_user_new_user(self, user_service, mock_session): + telegram_id = 123456789 + username = "new_user" + first_name = "New" + last_name = "User" + + mock_result_not_found = MagicMock() + mock_result_not_found.scalar_one_or_none = MagicMock(return_value=None) + + mock_result_found = MagicMock() + created_user = UserModel( + telegram_id=str(telegram_id), + username=username, + first_name=first_name, + last_name=last_name + ) + mock_result_found.scalar_one_or_none = MagicMock(return_value=created_user) + + mock_session.execute.side_effect = [mock_result_not_found, mock_result_found] + + result = await user_service.get_or_create_user(telegram_id, username, first_name, last_name) + + assert result is not None + assert result.telegram_id == str(telegram_id) + assert result.username == username + mock_session.add.assert_called_once() + mock_session.commit.assert_called() + + @pytest.mark.asyncio + async def test_get_or_create_user_existing_user(self, user_service, mock_session): + telegram_id = 123456789 + existing_user = UserModel( + telegram_id=str(telegram_id), + username="old_username", + first_name="Old", + last_name="Name" + ) + + mock_result = MagicMock() + mock_result.scalar_one_or_none = MagicMock(return_value=existing_user) + mock_session.execute.return_value = mock_result + + result = await user_service.get_or_create_user( + telegram_id, "new_username", "New", "Name" + ) + + assert result == existing_user + assert result.username == "new_username" + assert result.first_name == "New" + assert result.last_name == "Name" + mock_session.commit.assert_called() + + @pytest.mark.asyncio + async def test_update_user_questions_success(self, user_service, mock_session): + telegram_id = 123456789 + user = UserModel( + telegram_id=str(telegram_id), + questions_used=5 + ) + + mock_result = MagicMock() + mock_result.scalar_one_or_none = MagicMock(return_value=user) + mock_session.execute.return_value = mock_result + + result = await user_service.update_user_questions(telegram_id) + + assert result is True + assert user.questions_used == 6 + mock_session.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_update_user_questions_user_not_found(self, user_service, mock_session): + telegram_id = 999999999 + + mock_result = MagicMock() + mock_result.scalar_one_or_none = MagicMock(return_value=None) + mock_session.execute.return_value = mock_result + + result = await user_service.update_user_questions(telegram_id) + + assert result is False + mock_session.commit.assert_not_called() + + @pytest.mark.asyncio + async def test_activate_premium_success(self, user_service, mock_session): + telegram_id = 123456789 + user = UserModel( + telegram_id=str(telegram_id), + is_premium=False, + premium_until=None + ) + + mock_result = MagicMock() + mock_result.scalar_one_or_none = MagicMock(return_value=user) + mock_session.execute.return_value = mock_result + + result = await user_service.activate_premium(telegram_id) + + assert result is True + assert user.is_premium is True + assert user.premium_until is not None + assert user.premium_until > datetime.now() + mock_session.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_activate_premium_extend_existing(self, user_service, mock_session): + telegram_id = 123456789 + existing_premium_until = datetime.now() + timedelta(days=10) + user = UserModel( + telegram_id=str(telegram_id), + is_premium=True, + premium_until=existing_premium_until + ) + + mock_result = MagicMock() + mock_result.scalar_one_or_none = MagicMock(return_value=user) + mock_session.execute.return_value = mock_result + + result = await user_service.activate_premium(telegram_id) + + assert result is True + assert user.is_premium is True + assert user.premium_until > existing_premium_until + mock_session.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_activate_premium_user_not_found(self, user_service, mock_session): + telegram_id = 999999999 + + mock_result = MagicMock() + mock_result.scalar_one_or_none = MagicMock(return_value=None) + mock_session.execute.return_value = mock_result + + result = await user_service.activate_premium(telegram_id) + + assert result is False + mock_session.commit.assert_not_called() diff --git a/requirements.txt b/tg_bot/requirements.txt similarity index 65% rename from requirements.txt rename to tg_bot/requirements.txt index 502efcb..8a95e17 100644 --- a/requirements.txt +++ b/tg_bot/requirements.txt @@ -4,8 +4,7 @@ python-dotenv>=1.0.0 aiogram>=3.10.0 sqlalchemy>=2.0.0 aiosqlite>=0.19.0 -httpx>=0.25.0 +httpx>=0.25.2 yookassa>=2.4.0 -fastapi>=0.104.0 -uvicorn>=0.24.0 -python-multipart>=0.0.6 \ No newline at end of file +aiohttp>=3.9.1 +