Skip to content

Commit

Permalink
ADD: HitRate, MRR metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
haru-256 committed Dec 30, 2024
1 parent 4a1dd59 commit bf0cf9e
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 0 deletions.
94 changes: 94 additions & 0 deletions common/tests/test_utils/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch

from utils.metrics import (
MRR,
HitRate,
create_classification_inputs,
create_retrieval_inputs,
hit_rate_v1,
Expand Down Expand Up @@ -96,3 +98,95 @@ def test_hit_rate():
expected = torch.tensor([1.0, 1.0]).mean()
actual = hit_rate_v2(score, target, k=4)
torch.testing.assert_close(actual, expected)


class TestMRR:
def test_forward(self):
mrr = MRR(k=4)

# test forward batch
score = torch.tensor(
[
[0.1, 0.2, 0.3, 0.4],
[0.4, 0.3, 0.2, 0.1],
]
)
target = torch.tensor(
[
[1, 0, 0, 0],
[1, 0, 0, 0],
]
).long()
expected = torch.tensor([1 / 4, 1.0]).mean()
actual = mrr(score, target)
torch.testing.assert_close(actual, expected)

score = torch.tensor(
[
[0.1, 0.2, 0.3, 0.4],
]
)
target = torch.tensor(
[
[1, 0, 0, 0],
]
).long()
expected = torch.tensor([1 / 4]).mean()
actual = mrr(score, target)
torch.testing.assert_close(actual, expected)

# test accumulate
expected = torch.tensor([1 / 4, 1.0, 1 / 4]).mean()
actual = mrr.compute()
torch.testing.assert_close(actual, expected)

# check reset
mrr.reset()
assert len(mrr.mrr) == 0
assert len(mrr.num_queries) == 0


class TestHitRate:
def test_forward(self):
hit_rate = HitRate(k=1)

# test forward batch
score = torch.tensor(
[
[0.1, 0.2, 0.3, 0.4],
[0.4, 0.3, 0.2, 0.1],
]
)
target = torch.tensor(
[
[1, 0, 0, 0],
[1, 0, 0, 0],
]
).long()
expected = torch.tensor([0.0, 1.0]).mean()
actual = hit_rate(score, target)
torch.testing.assert_close(actual, expected)

score = torch.tensor(
[
[0.1, 0.2, 0.3, 0.4],
]
)
target = torch.tensor(
[
[1, 0, 0, 0],
]
).long()
expected = torch.tensor([0.0]).mean()
actual = hit_rate(score, target)
torch.testing.assert_close(actual, expected)

# test accumulate
expected = torch.tensor([0.0, 1.0, 0.0]).mean()
actual = hit_rate.compute()
torch.testing.assert_close(actual, expected)

# check reset
hit_rate.reset()
assert len(hit_rate.hit_rate) == 0
assert len(hit_rate.num_queries) == 0
47 changes: 47 additions & 0 deletions common/utils/metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from torchmetrics import Metric
from torchmetrics.functional.retrieval import retrieval_hit_rate, retrieval_reciprocal_rank


Expand Down Expand Up @@ -151,3 +152,49 @@ def hit_rate_v2(score: torch.Tensor, target: torch.Tensor, k: int = 10) -> torch
dtype=torch.float32,
)
return values.mean()


class MRR(Metric):
def __init__(self, k: int = 10):
"""mean reciprocal rank at k
Args:
k: top k. Defaults to 10.
"""
super().__init__()
self.k = k
self.add_state("mrr", default=[], dist_reduce_fx=None)
self.add_state("num_queries", default=[], dist_reduce_fx=None)

def update(self, score: torch.Tensor, target: torch.Tensor) -> None:
self.mrr.append(mrr(score, target, self.k))
self.num_queries.append(score.size(0))

def compute(self) -> torch.Tensor:
# mrrはqueryごとの平均なので、全体の平均値に変換
_mrr = torch.as_tensor(self.mrr, dtype=torch.float32)
_num_queries = torch.as_tensor(self.num_queries, dtype=torch.float32)
return (_mrr * _num_queries).sum() / _num_queries.sum()


class HitRate(Metric):
def __init__(self, k: int = 10):
"""hit rate at k
Args:
k: top k. Defaults to 10.
"""
super().__init__()
self.k = k
self.add_state("hit_rate", default=[], dist_reduce_fx=None)
self.add_state("num_queries", default=[], dist_reduce_fx=None)

def update(self, score: torch.Tensor, target: torch.Tensor) -> None:
self.hit_rate.append(hit_rate(score, target, self.k))
self.num_queries.append(score.size(0))

def compute(self) -> torch.Tensor:
# hit_rateはqueryごとの平均なので、全体の平均値に変換
_hit_rate = torch.as_tensor(self.hit_rate, dtype=torch.float32)
_num_queries = torch.as_tensor(self.num_queries, dtype=torch.float32)
return (_hit_rate * _num_queries).sum() / _num_queries.sum()

0 comments on commit bf0cf9e

Please sign in to comment.