172 lines
6.7 KiB
Python
172 lines
6.7 KiB
Python
import pytest
|
||
from unittest.mock import AsyncMock, patch, MagicMock
|
||
from uuid import uuid4
|
||
from tg_bot.application.services.rag_service import RAGService
|
||
|
||
|
||
class TestRAGService:
    """Unit tests for ``RAGService`` document search and RAG answer generation.

    HTTP traffic is replaced by patching ``aiohttp.ClientSession``; LLM calls
    are replaced by patching the service's ``deepseek_client`` attribute.
    """

    @pytest.fixture
    def rag_service(self):
        """Return a ``RAGService`` with a ``deepseek_client`` attribute assigned.

        The attribute is set explicitly so that
        ``patch.object(rag_service, 'deepseek_client')`` in the tests below has
        an existing attribute to replace.
        """
        service = RAGService()

        from tg_bot.infrastructure.external.deepseek_client import DeepSeekClient

        # NOTE(review): this builds a real DeepSeekClient. If its constructor
        # reads credentials or opens connections, consider a mock here — TODO confirm.
        service.deepseek_client = DeepSeekClient()
        return service

    @staticmethod
    def _json_response(payload, status=200):
        """Build a mock HTTP response with ``status`` and an awaitable ``json()``."""
        response = AsyncMock()
        response.status = status
        response.json = AsyncMock(return_value=payload)
        return response

    @staticmethod
    def _mock_session(mock_session_cls, responses):
        """Configure a patched ``aiohttp.ClientSession`` class and return its session.

        The session acts as an async context manager that yields itself. Each
        ``await session.get(...)`` returns the next item of ``responses``, in order.
        """
        session = MagicMock()
        session.__aenter__ = AsyncMock(return_value=session)
        session.__aexit__ = AsyncMock(return_value=None)
        # NOTE(review): ``get`` is an AsyncMock, so this matches call sites of the
        # form ``await session.get(...)``. An ``async with session.get(...)``
        # call site would need a context-manager mock instead — confirm against
        # RAGService.
        session.get = AsyncMock(side_effect=list(responses))
        mock_session_cls.return_value = session
        return session

    @pytest.mark.asyncio
    async def test_search_documents_in_collections_success(self, rag_service):
        """Documents returned by the search endpoint are passed back to the caller."""
        user_telegram_id = "123456789"
        query = "трудовой договор"

        mock_documents = [
            {
                "document_id": str(uuid4()),
                "title": "Трудовой кодекс РФ",
                "content": "Содержание о трудовых договорах",
                "collection_name": "Законы",
            },
            {
                "document_id": str(uuid4()),
                "title": "Правила оформления",
                "content": "Как оформить трудовой договор",
                "collection_name": "Инструкции",
            },
        ]

        # NOTE(review): patching 'aiohttp.ClientSession' only takes effect if the
        # service looks the class up as ``aiohttp.ClientSession`` at call time.
        with patch('aiohttp.ClientSession') as mock_session:
            # Payloads in GET order: user id, collection list, matching documents.
            self._mock_session(mock_session, [
                self._json_response({"user_id": str(uuid4())}),
                self._json_response([{"collection_id": str(uuid4()), "name": "Законы"}]),
                self._json_response(mock_documents),
            ])

            result = await rag_service.search_documents_in_collections(
                user_telegram_id, query, limit_per_collection=5
            )

            assert len(result) > 0
            assert result[0]["title"] == "Трудовой кодекс РФ"

    @pytest.mark.asyncio
    async def test_search_documents_empty_result(self, rag_service):
        """A user with no collections gets an empty result list."""
        user_telegram_id = "123456789"
        query = "несуществующий запрос"

        with patch('aiohttp.ClientSession') as mock_session:
            # Payloads in GET order: user id, then an empty collection list.
            self._mock_session(mock_session, [
                self._json_response({"user_id": str(uuid4())}),
                self._json_response([]),
            ])

            result = await rag_service.search_documents_in_collections(
                user_telegram_id, query
            )

            assert result == []

    @pytest.mark.asyncio
    async def test_generate_answer_with_rag_success(self, rag_service):
        """A RAG answer includes ``answer``, ``sources`` and ``usage`` fields."""
        question = "Какие права имеет работник?"
        user_telegram_id = "123456789"

        with patch.object(rag_service, 'search_documents_in_collections',
                          new_callable=AsyncMock) as mock_search, \
                patch.object(rag_service, 'deepseek_client') as mock_client:

            mock_search.return_value = [
                {
                    "document_id": str(uuid4()),
                    "title": "Трудовой кодекс",
                    "content": "Работник имеет право на...",
                    "collection_name": "Законы",
                }
            ]

            mock_client.chat_completion = AsyncMock(return_value={
                "content": "Работник имеет следующие права...",
                "usage": {"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300},
            })

            result = await rag_service.generate_answer_with_rag(question, user_telegram_id)

            assert "answer" in result
            assert "sources" in result
            assert "usage" in result
            assert len(result["sources"]) <= 5
            assert result["answer"] != ""

    @pytest.mark.asyncio
    async def test_generate_answer_limits_to_top5(self, rag_service):
        """Only five sources are reported even when the search returns more."""
        question = "Тестовый вопрос"
        user_telegram_id = "123456789"

        many_documents = [
            {
                "document_id": str(uuid4()),
                "title": f"Документ {i}",
                "content": f"Содержание {i}",
                "collection_name": "Коллекция",
            }
            for i in range(20)
        ]

        with patch.object(rag_service, 'search_documents_in_collections',
                          new_callable=AsyncMock) as mock_search, \
                patch.object(rag_service, 'deepseek_client') as mock_client:

            mock_search.return_value = many_documents
            mock_client.chat_completion = AsyncMock(return_value={
                "content": "Ответ",
                "usage": {},
            })

            result = await rag_service.generate_answer_with_rag(question, user_telegram_id)

            assert len(result["sources"]) == 5

    @pytest.mark.asyncio
    async def test_generate_answer_no_documents(self, rag_service):
        """With no matching documents, there are no sources and a 'not found' answer."""
        question = "Вопрос без документов"
        user_telegram_id = "123456789"

        # NOTE(review): deepseek_client is not patched here. This presumes the
        # service answers without calling the LLM when nothing is found — confirm.
        with patch.object(rag_service, 'search_documents_in_collections',
                          new_callable=AsyncMock) as mock_search:
            mock_search.return_value = []

            result = await rag_service.generate_answer_with_rag(question, user_telegram_id)

            assert result["sources"] == []
            # Read the answer once, so a missing key fails the assertion
            # instead of raising KeyError.
            answer = result.get("answer", "")
            assert "Релевантные документы не найдены" in answer or \
                answer == "No relevant documents found"