diff --git a/fanoutqa/eval/scorer.py b/fanoutqa/eval/scorer.py index 63d6132..1906c30 100644 --- a/fanoutqa/eval/scorer.py +++ b/fanoutqa/eval/scorer.py @@ -134,9 +134,12 @@ def score_rouge(self) -> Tuple[RougeScore, Dict[str, RougeScore]]: results = self.rouge.score(str_answer(q.answer), str_answer(a["answer"])) for k, v in results.items(): scores[k].append(v) - raw_scores[q.id] = RougeScore(**{ - k: RougeScorePart(precision=v.precision, recall=v.recall, fscore=v.fmeasure) for k, v in results.items() - }) + raw_scores[q.id] = RougeScore( + **{ + k: RougeScorePart(precision=v.precision, recall=v.recall, fscore=v.fmeasure) + for k, v in results.items() + } + ) assert all(len(v) == self.eval_len for v in scores.values()) assert len(raw_scores) == self.eval_len