134 lines
4.0 KiB
Python
134 lines
4.0 KiB
Python
import pytest
|
|
from uuid import uuid4
|
|
from typing import List, Dict
|
|
|
|
|
|
def calculate_hit_at_5(
    retrieved_document_ids: List[str],
    relevant_document_ids: List[str],
    k: int = 5,
) -> int:
    """Return 1 if any relevant document is among the first ``k`` retrieved IDs, else 0.

    This is the per-query Hit@k metric; with the default ``k=5`` it is Hit@5.

    Args:
        retrieved_document_ids: Retrieved document IDs in ranked order; only the
            leading ``k`` entries are inspected.
        relevant_document_ids: IDs of the documents considered relevant.
        k: How many leading retrieved IDs to inspect. Defaults to 5.

    Returns:
        1 on a hit, 0 otherwise (including when either list is empty).

    Raises:
        ValueError: If ``k`` is less than 1.
    """
    if k < 1:
        # A non-positive k would make the slice below inspect the wrong window
        # (e.g. a negative k slices from the end of the ranking).
        raise ValueError(f"k must be >= 1, got {k}")
    if not retrieved_document_ids or not relevant_document_ids:
        return 0

    relevant_ids = set(relevant_document_ids)
    # Only the first k ranked results count; relevant documents ranked lower are ignored.
    return int(any(doc_id in relevant_ids for doc_id in retrieved_document_ids[:k]))
|
|
|
|
|
|
def calculate_average_hit_at_5(results: List[int]) -> float:
    """Average a list of per-query Hit@5 scores (each 0 or 1).

    Args:
        results: Per-query hit values.

    Returns:
        The arithmetic mean of ``results``, or 0.0 when ``results`` is empty
        (avoids dividing by zero).
    """
    return sum(results) / len(results) if results else 0.0
|
|
|
|
|
|
def _random_ids(count: int) -> List[str]:
    """Return ``count`` freshly generated UUID4 strings to use as document IDs."""
    return [str(uuid4()) for _ in range(count)]


class TestHitAt5Metric:
    """Unit tests for ``calculate_hit_at_5`` and ``calculate_average_hit_at_5``."""

    def test_hit_at_5_returns_1_when_relevant_document_in_top5(self):
        relevant_ids = _random_ids(2)
        retrieved_ids = _random_ids(5)
        retrieved_ids[1] = relevant_ids[0]

        assert calculate_hit_at_5(retrieved_ids, relevant_ids) == 1

    def test_hit_at_5_returns_0_when_no_relevant_document_in_top5(self):
        relevant_ids = _random_ids(2)
        retrieved_ids = _random_ids(5)

        assert calculate_hit_at_5(retrieved_ids, relevant_ids) == 0

    def test_hit_at_5_returns_1_when_multiple_relevant_documents(self):
        relevant_ids = _random_ids(3)
        filler = _random_ids(2)
        retrieved_ids = [
            relevant_ids[0],
            filler[0],
            relevant_ids[1],
            filler[1],
            relevant_ids[2],
        ]

        # Several hits still yield 1: the metric is binary per query.
        assert calculate_hit_at_5(retrieved_ids, relevant_ids) == 1

    def test_hit_at_5_handles_empty_lists(self):
        assert calculate_hit_at_5([], _random_ids(1)) == 0
        assert calculate_hit_at_5(_random_ids(1), []) == 0
        assert calculate_hit_at_5([], []) == 0

    def test_hit_at_5_only_checks_top5(self):
        relevant_ids = _random_ids(1)
        # The relevant document sits at the 6th position, just outside the top 5.
        retrieved_ids = _random_ids(5) + [relevant_ids[0]]

        assert calculate_hit_at_5(retrieved_ids, relevant_ids) == 0

    def test_hit_at_5_counts_fifth_position(self):
        relevant_ids = _random_ids(1)
        # Boundary case: the 5th position is the last one that still counts.
        retrieved_ids = _random_ids(4) + [relevant_ids[0]] + _random_ids(3)

        assert calculate_hit_at_5(retrieved_ids, relevant_ids) == 1

    def test_hit_at_5_with_fewer_than_5_retrieved(self):
        relevant_ids = _random_ids(1)
        retrieved_ids = _random_ids(1) + [relevant_ids[0]]

        assert calculate_hit_at_5(retrieved_ids, relevant_ids) == 1
        assert calculate_hit_at_5(_random_ids(2), relevant_ids) == 0

    def test_calculate_average_hit_at_5(self):
        results = [1, 1, 0, 1, 0, 1, 0, 1, 1, 1]

        # Compare floats with a tolerance rather than exact equality.
        assert calculate_average_hit_at_5(results) == pytest.approx(0.7)

    def test_calculate_average_hit_at_5_all_ones(self):
        assert calculate_average_hit_at_5([1, 1, 1, 1, 1]) == pytest.approx(1.0)

    def test_calculate_average_hit_at_5_all_zeros(self):
        assert calculate_average_hit_at_5([0, 0, 0, 0, 0]) == pytest.approx(0.0)

    def test_calculate_average_hit_at_5_empty_list(self):
        # The empty case is a defined sentinel value, so exact equality is intended.
        assert calculate_average_hit_at_5([]) == 0.0

    def test_hit_at_5_quality_threshold(self):
        results = [1] * 60 + [0] * 40

        average = calculate_average_hit_at_5(results)

        # Assertion message (Russian): "Quality {average} must be > 0.5".
        assert average > 0.5, f"Качество {average} должно быть > 0.5"
        assert average == pytest.approx(0.6)

    def test_hit_at_5_quality_below_threshold(self):
        results = [1] * 40 + [0] * 60

        average = calculate_average_hit_at_5(results)

        # Assertion message (Russian): "Quality {average} must be < 0.5".
        assert average < 0.5, f"Качество {average} должно быть < 0.5"
        assert average == pytest.approx(0.4)

    @pytest.mark.parametrize("hit_count,total,expected_quality", [
        (0, 100, 0.0),
        (51, 100, 0.51),
        (50, 100, 0.50),
        (60, 100, 0.60),
        (75, 100, 0.75),
        (100, 100, 1.0),
    ])
    def test_hit_at_5_various_qualities(self, hit_count, total, expected_quality):
        results = [1] * hit_count + [0] * (total - hit_count)

        average = calculate_average_hit_at_5(results)

        assert average == pytest.approx(expected_quality)
|