# 2025-12-23 12:08:28 +03:00
#
# 134 lines
# 4.0 KiB
# Python

import pytest
from uuid import uuid4
from typing import List, Dict
def calculate_hit_at_5(retrieved_document_ids: List[str], relevant_document_ids: List[str]) -> int:
    """Return 1 if any of the first five retrieved IDs is relevant, else 0.

    Args:
        retrieved_document_ids: Retrieved document IDs; only the first five
            entries are inspected.
        relevant_document_ids: IDs that count as a hit.

    Returns:
        1 when at least one of the first five retrieved IDs also appears in
        ``relevant_document_ids``; 0 otherwise, including when either
        argument is empty.
    """
    if not retrieved_document_ids or not relevant_document_ids:
        return 0

    # Anything beyond position five does not count towards the metric.
    top5_ids = set(retrieved_document_ids[:5])

    # isdisjoint() answers "is there any overlap?" without materialising
    # a second set or the full intersection.
    return 0 if top5_ids.isdisjoint(relevant_document_ids) else 1
def calculate_average_hit_at_5(results: List[int]) -> float:
    """Return the arithmetic mean of a list of per-query hit values.

    Args:
        results: Hit values to average (the tests in this file pass 0/1
            integers).

    Returns:
        ``sum(results) / len(results)``, or ``0.0`` when ``results`` is
        empty so that callers never hit a division by zero.
    """
    if not results:
        return 0.0
    return sum(results) / len(results)
class TestHitAt5Metric:
    """Tests for ``calculate_hit_at_5`` and ``calculate_average_hit_at_5``.

    Fractional averages are compared with ``pytest.approx`` rather than
    ``==`` so the tests do not depend on the exact floating-point rounding
    of the helper's summation/division.
    """

    def test_hit_at_5_returns_1_when_relevant_document_in_top5(self):
        """A single relevant ID inside the first five positions is a hit."""
        relevant_ids = [str(uuid4()), str(uuid4())]
        retrieved_ids = [
            str(uuid4()),
            relevant_ids[0],
            str(uuid4()),
            str(uuid4()),
            str(uuid4()),
        ]
        result = calculate_hit_at_5(retrieved_ids, relevant_ids)
        assert result == 1

    def test_hit_at_5_returns_0_when_no_relevant_document_in_top5(self):
        """Five unrelated IDs produce a miss."""
        relevant_ids = [str(uuid4()), str(uuid4())]
        retrieved_ids = [
            str(uuid4()),
            str(uuid4()),
            str(uuid4()),
            str(uuid4()),
            str(uuid4()),
        ]
        result = calculate_hit_at_5(retrieved_ids, relevant_ids)
        assert result == 0

    def test_hit_at_5_returns_1_when_multiple_relevant_documents(self):
        """Several relevant IDs in the top five still yield exactly 1."""
        relevant_ids = [str(uuid4()), str(uuid4()), str(uuid4())]
        retrieved_ids = [
            relevant_ids[0],
            str(uuid4()),
            relevant_ids[1],
            str(uuid4()),
            relevant_ids[2],
        ]
        result = calculate_hit_at_5(retrieved_ids, relevant_ids)
        assert result == 1

    def test_hit_at_5_handles_empty_lists(self):
        """An empty list on either side (or both) is a miss."""
        result = calculate_hit_at_5([], [str(uuid4())])
        assert result == 0
        result = calculate_hit_at_5([str(uuid4())], [])
        assert result == 0
        result = calculate_hit_at_5([], [])
        assert result == 0

    def test_hit_at_5_only_checks_top5(self):
        """A relevant ID at position six must not count as a hit."""
        relevant_ids = [str(uuid4())]
        retrieved_ids = [
            str(uuid4()),
            str(uuid4()),
            str(uuid4()),
            str(uuid4()),
            str(uuid4()),
            relevant_ids[0],
        ]
        result = calculate_hit_at_5(retrieved_ids, relevant_ids)
        assert result == 0

    def test_calculate_average_hit_at_5(self):
        """Seven hits out of ten average to 0.7."""
        results = [1, 1, 0, 1, 0, 1, 0, 1, 1, 1]
        average = calculate_average_hit_at_5(results)
        assert average == pytest.approx(0.7)

    def test_calculate_average_hit_at_5_all_ones(self):
        """All hits average to exactly 1.0."""
        results = [1, 1, 1, 1, 1]
        average = calculate_average_hit_at_5(results)
        assert average == 1.0

    def test_calculate_average_hit_at_5_all_zeros(self):
        """All misses average to exactly 0.0."""
        results = [0, 0, 0, 0, 0]
        average = calculate_average_hit_at_5(results)
        assert average == 0.0

    def test_calculate_average_hit_at_5_empty_list(self):
        """An empty result list averages to 0.0 instead of raising."""
        average = calculate_average_hit_at_5([])
        assert average == 0.0

    def test_hit_at_5_quality_threshold(self):
        """60 hits out of 100 is above the 0.5 threshold."""
        results = [1] * 60 + [0] * 40
        average = calculate_average_hit_at_5(results)
        assert average > 0.5, f"Качество {average} должно быть > 0.5"
        assert average == pytest.approx(0.6)

    def test_hit_at_5_quality_below_threshold(self):
        """40 hits out of 100 is below the 0.5 threshold."""
        results = [1] * 40 + [0] * 60
        average = calculate_average_hit_at_5(results)
        assert average < 0.5, f"Качество {average} должно быть < 0.5"
        assert average == pytest.approx(0.4)

    @pytest.mark.parametrize("hit_count,total,expected_quality", [
        (51, 100, 0.51),
        (50, 100, 0.50),
        (60, 100, 0.60),
        (75, 100, 0.75),
        (100, 100, 1.0),
    ])
    def test_hit_at_5_various_qualities(self, hit_count, total, expected_quality):
        """The average equals hit_count / total for several ratios."""
        results = [1] * hit_count + [0] * (total - hit_count)
        average = calculate_average_hit_at_5(results)
        assert average == pytest.approx(expected_quality)