From dbeb3ea04b65f8b905f00ccd401b14759586ce5f Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Mon, 6 Mar 2023 10:25:31 +0800 Subject: [PATCH] fix comment --- .../metrics/one_minus_norm_edit_distance.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/mmeval/metrics/one_minus_norm_edit_distance.py b/mmeval/metrics/one_minus_norm_edit_distance.py index 5c5e5be8..f6bd0bec 100644 --- a/mmeval/metrics/one_minus_norm_edit_distance.py +++ b/mmeval/metrics/one_minus_norm_edit_distance.py @@ -21,17 +21,17 @@ class OneMinusNormEditDistance(BaseMetric): - unchanged: Do not change prediction texts and labels. - upper: Convert prediction texts and labels into uppercase - characters. + characters. - lower: Convert prediction texts and labels into lowercase - characters. + characters. Usually, it only works for English characters. Defaults to 'unchanged'. invalid_symbol (str): A regular expression to filter out invalid or - not cared characters. Defaults to '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'. + not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'. **kwargs: Keyword parameters passed to :class:`BaseMetric`. - Example: + Examples: >>> from mmeval import OneMinusNormEditDistance >>> metric = OneMinusNormEditDistance() >>> metric(['helL', 'HEL'], ['hello', 'HELLO']) @@ -43,7 +43,7 @@ class OneMinusNormEditDistance(BaseMetric): def __init__(self, letter_case: str = 'unchanged', - invalid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]', + invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]', **kwargs): super().__init__(**kwargs) @@ -51,14 +51,14 @@ def __init__(self, self.letter_case = letter_case self.invalid_symbol = re.compile(invalid_symbol) - def add(self, predictions: Sequence[str], labels: Sequence[str]): # type: ignore # yapf: disable # noqa: E501 + def add(self, predictions: Sequence[str], groundtruths: Sequence[str]): # type: ignore # yapf: disable # noqa: E501 """Process one batch of data and predictions. Args: predictions (list[str]): The prediction texts. - labels (list[str]): The ground truth texts. + groundtruths (list[str]): The ground truth texts. """ - for pred, label in zip(predictions, labels): + for pred, label in zip(predictions, groundtruths): if self.letter_case in ['upper', 'lower']: pred = getattr(pred, self.letter_case)() label = getattr(label, self.letter_case)() @@ -75,11 +75,12 @@ def compute_metric(self, results: List[float]) -> Dict: Returns: dict[str, float]: Nested dicts as results. - - 1-N.E.D (float): One minus the normalized edit distance. + + - 1-N.E.D (float): One minus the normalized edit distance. """ gt_word_num = len(results) norm_ed_sum = sum(results) normalized_edit_distance = norm_ed_sum / max(1.0, gt_word_num) - eval_res = {} - eval_res['1-N.E.D'] = 1.0 - normalized_edit_distance - return eval_res + metric_results = {} + metric_results['1-N.E.D'] = 1.0 - normalized_edit_distance + return metric_results