Commit
Sacrebleu: Add more tokenizers for SacreBLEU metric (Lightning-AI#2068)
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
1 parent 287afb5 · commit aab8851
Showing 11 changed files with 230 additions and 52 deletions.
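For orientation, here is a minimal usage sketch of the functional metric with one of the tokenizers this commit adds (illustrative only: the new tokenizers rely on optional dependencies, and the flores variants download a sentencepiece model on first use):

from torchmetrics.functional.text import sacre_bleu_score

preds = ["the cat is on the mat"]
target = [["there is a cat on the mat", "a cat is on the mat"]]

# default tokenizer, unchanged by this commit
score_13a = sacre_bleu_score(preds, target, tokenize="13a")

# one of the newly added tokenizers; assumes `sentencepiece` is installed
score_flores = sacre_bleu_score(preds, target, tokenize="flores200")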
src/torchmetrics/functional/text/sacre_bleu.py
@@ -37,18 +37,28 @@
 # MIT License
 # Copyright (c) 2017 - Shujian Huang <[email protected]>
 
+import os
 import re
+import tempfile
 from functools import partial
-from typing import ClassVar, Optional, Sequence
+from typing import Any, ClassVar, Dict, Optional, Sequence
 
 import torch
 from torch import Tensor, tensor
 from typing_extensions import Literal
 
 from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update
-from torchmetrics.utilities.imports import _REGEX_AVAILABLE
+from torchmetrics.utilities.imports import (
+    _IPADIC_AVAILABLE,
+    _MECAB_AVAILABLE,
+    _MECAB_KO_AVAILABLE,
+    _MECAB_KO_DIC_AVAILABLE,
+    _REGEX_AVAILABLE,
+    _SENTENCEPIECE_AVAILABLE,
+)
 
-AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char")
+AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char", "ja-mecab", "ko-mecab", "flores101", "flores200")
+_Tokenizers_list = Literal["none", "13a", "zh", "intl", "char", "ja-mecab", "ko-mecab", "flores101", "flores200"]
 
 _UCODE_RANGES = (
     ("\u3400", "\u4db5"),  # CJK Unified Ideographs Extension A, release 3.0
@@ -77,6 +87,14 @@
 )
 
 
+_FLORES_LOCAL_DIR = os.path.join(tempfile.gettempdir(), "torchmetrics-flores")
+# Model paths copied from https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/tokenizers/tokenizer_spm.py.
+_FLORES_MODELS_URL = {
+    "flores101": "https://dl.fbaipublicfiles.com/fairseq/models/flores/sacrebleu_tokenizer_spm.model",
+    "flores200": "https://tinyurl.com/flores200sacrebleuspm",
+}
+
+
 class _SacreBLEUTokenizer:
     """Tokenizer used for SacreBLEU calculation.
 
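The constants added above cache the flores sentencepiece models under the system temp directory. A rough sketch of pre-fetching the flores101 model into that location, e.g. for machines without network access at metric time (illustrative, not part of the commit; the path and URL are taken from the hunk above):

import os
import tempfile
import urllib.request

local_dir = os.path.join(tempfile.gettempdir(), "torchmetrics-flores")
os.makedirs(local_dir, exist_ok=True)

url = "https://dl.fbaipublicfiles.com/fairseq/models/flores/sacrebleu_tokenizer_spm.model"
model_path = os.path.join(local_dir, url.split("/")[-1])
if not os.path.exists(model_path):
    # download once; the metric later finds the file and skips its own download
    urllib.request.urlretrieve(url, model_path)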
@@ -116,9 +134,18 @@ class _SacreBLEUTokenizer:
         "zh": "_tokenize_zh",
         "intl": "_tokenize_international",
         "char": "_tokenize_char",
+        "ja-mecab": "_tokenize_ja_mecab",
+        "ko-mecab": "_tokenize_ko_mecab",
+        "flores101": "_tokenize_flores_101",
+        "flores200": "_tokenize_flores_200",
     }
 
-    def __init__(self, tokenize: Literal["none", "13a", "zh", "intl", "char"], lowercase: bool = False) -> None:
+    # Keep it as class variable to avoid initializing over and over again
+    sentencepiece_processors: ClassVar[Dict[str, Optional[Any]]] = {"flores101": None, "flores200": None}
+
+    def __init__(self, tokenize: _Tokenizers_list, lowercase: bool = False) -> None:
+        self._check_tokenizers_validity(tokenize)
+
         self.tokenize_fn = getattr(self, self._TOKENIZE_FN[tokenize])
         self.lowercase = lowercase
 
@@ -127,9 +154,9 @@ def __call__(self, line: str) -> Sequence[str]:
         return self._lower(tokenized_line, self.lowercase).split()
 
     @classmethod
-    def tokenize(
-        cls, line: str, tokenize: Literal["none", "13a", "zh", "intl", "char"], lowercase: bool = False
-    ) -> Sequence[str]:
+    def tokenize(cls, line: str, tokenize: _Tokenizers_list, lowercase: bool = False) -> Sequence[str]:
+        cls._check_tokenizers_validity(tokenize)
+
         tokenize_fn = getattr(cls, cls._TOKENIZE_FN[tokenize])
         tokenized_line = tokenize_fn(line)
         return cls._lower(tokenized_line, lowercase).split()
@@ -274,19 +301,159 @@ def _tokenize_char(cls, line: str) -> str:
         """
         return " ".join(char for char in line)
 
+    @classmethod
+    def _tokenize_ja_mecab(cls, line: str) -> str:
+        """Tokenizes a Japanese string line using MeCab morphological analyzer.
+        Args:
+            line: the input string to tokenize.
+        Return:
+            The tokenized string.
+        """
+        import ipadic
+        import MeCab
+
+        tagger = MeCab.Tagger(ipadic.MECAB_ARGS + " -Owakati")
+
+        line = line.strip()
+        return tagger.parse(line).strip()
+
+    @classmethod
+    def _tokenize_ko_mecab(cls, line: str) -> str:
+        """Tokenizes a Korean string line using MeCab-korean morphological analyzer.
+        Args:
+            line: the input string to tokenize.
+        Return:
+            The tokenized string.
+        """
+        import mecab_ko
+        import mecab_ko_dic
+
+        tagger = mecab_ko.Tagger(mecab_ko_dic.MECAB_ARGS + " -Owakati")
+
+        line = line.strip()
+        return tagger.parse(line).strip()
+
+    @classmethod
+    def _tokenize_flores(cls, line: str, tokenize: Literal["flores101", "flores200"]) -> str:
+        """Tokenizes a string line using sentencepiece tokenizer.
+        Args:
+            line: the input string to tokenize.
+            tokenize: Tokenization technique to be used.
+        Return:
+            The tokenized string.
+        """
+        import sentencepiece
+
+        if cls.sentencepiece_processors[tokenize] is None:
+            cls.sentencepiece_processors[tokenize] = sentencepiece.SentencePieceProcessor()
+
+            file_path = os.path.join(_FLORES_LOCAL_DIR, _FLORES_MODELS_URL[tokenize].split("/")[-1])
+            if not os.path.exists(file_path):
+                cls.download_flores_file(tokenize)
+
+            cls.sentencepiece_processors[tokenize].Load(file_path)  # type: ignore[union-attr]
+
+        return " ".join(cls.sentencepiece_processors[tokenize].EncodeAsPieces(line))  # type: ignore[union-attr]
+
+    @classmethod
+    def _tokenize_flores_101(cls, line: str) -> str:
+        """Tokenizes a string line using sentencepiece tokenizer according to `FLORES-101`_ dataset.
+        Args:
+            line: the input string to tokenize.
+        Return:
+            The tokenized string.
+        """
+        return cls._tokenize_flores(line, "flores101")
+
+    @classmethod
+    def _tokenize_flores_200(cls, line: str) -> str:
+        """Tokenizes a string line using sentencepiece tokenizer according to `FLORES-200`_ dataset.
+        Args:
+            line: the input string to tokenize.
+        Return:
+            The tokenized string.
+        """
+        return cls._tokenize_flores(line, "flores200")
+
     @staticmethod
     def _lower(line: str, lowercase: bool) -> str:
         if lowercase:
             return line.lower()
         return line
 
+    @classmethod
+    def _check_tokenizers_validity(cls, tokenize: _Tokenizers_list) -> None:
+        """Check if a supported tokenizer is chosen.
+        Also check all dependencies of a given tokenizers are installed.
+        """
+        if tokenize not in cls._TOKENIZE_FN:
+            raise ValueError(f"Unsupported tokenizer selected. Please, choose one of {list(cls._TOKENIZE_FN.keys())}")
+
+        if tokenize == "intl" and not _REGEX_AVAILABLE:
+            raise ModuleNotFoundError(
+                "`'intl'` tokenization requires that `regex` is installed."
+                " Use `pip install regex` or `pip install torchmetrics[text]`."
+            )
+
+        if tokenize == "ja-mecab" and not (_MECAB_AVAILABLE and _IPADIC_AVAILABLE):
+            raise ModuleNotFoundError(
+                "`'ja-mecab'` tokenization requires that `MeCab` and `ipadic` are installed."
+                " Use `pip install mecab-python3 ipadic` or `pip install torchmetrics[text]`."
+            )
+
+        if tokenize == "ko-mecab" and not (_MECAB_KO_AVAILABLE and _MECAB_KO_DIC_AVAILABLE):
+            raise ModuleNotFoundError(
+                "`'ko-mecab'` tokenization requires that `mecab_ko` and `mecab_ko_dic` are installed."
+                " Use `pip install mecab_ko mecab_ko_dic` or `pip install torchmetrics[text]`."
+            )
+
+        if "flores" in tokenize and not _SENTENCEPIECE_AVAILABLE:
+            raise ModuleNotFoundError(
+                "`'flores101' and 'flores200'` tokenizations require that `sentencepiece` is installed."
+                " Use `pip install sentencepiece` or `pip install torchmetrics[text]`."
+            )
+
+    @staticmethod
+    def download_flores_file(model_name: Literal["flores101", "flores200"]) -> None:
+        """Download necessary files for `flores` tokenization via `sentencepiece`."""
+        import ssl
+        import urllib.request
+
+        os.makedirs(_FLORES_LOCAL_DIR, exist_ok=True)
+
+        model_url = _FLORES_MODELS_URL[model_name]
+        file_path = os.path.join(_FLORES_LOCAL_DIR, model_url.split("/")[-1])
+
+        try:
+            with open(file_path, "wb") as out_file, urllib.request.urlopen(model_url) as remote_file:
+                out_file.write(remote_file.read())
+        except ssl.SSLError as e:
+            raise OSError(f"Failed to download {model_name} model.") from e
+
 
 def sacre_bleu_score(
     preds: Sequence[str],
     target: Sequence[Sequence[str]],
     n_gram: int = 4,
     smooth: bool = False,
-    tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a",
+    tokenize: _Tokenizers_list = "13a",
     lowercase: bool = False,
     weights: Optional[Sequence[float]] = None,
 ) -> Tensor:
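The `ja-mecab` tokenizer added in the hunk above is a thin wrapper around MeCab's wakati (word-splitting) output; a standalone sketch of the same calls (assumes `mecab-python3` and `ipadic` are installed):

import ipadic
import MeCab

# identical to what _tokenize_ja_mecab does internally
tagger = MeCab.Tagger(ipadic.MECAB_ARGS + " -Owakati")
tokens = tagger.parse("これはペンです。").strip()  # space-separated morphemes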
@@ -299,8 +466,8 @@ def sacre_bleu_score(
         target: An iterable of iterables of reference corpus
         n_gram: Gram value ranged from 1 to 4
         smooth: Whether to apply smoothing - see [2]
-        tokenize: Tokenization technique to be used.
-            Supported tokenization: ['none', '13a', 'zh', 'intl', 'char']
+        tokenize: Tokenization technique to be used. Choose between ``'none'``, ``'13a'``, ``'zh'``, ``'intl'``,
+            ``'char'``, ``'ja-mecab'``, ``'ko-mecab'``, ``'flores101'`` and ``'flores200'``.
         lowercase: If ``True``, BLEU score over lowercased text is calculated.
         weights:
             Weights used for unigrams, bigrams, etc. to calculate BLEU score.
@@ -330,20 +497,8 @@
         and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_
     """
-    if tokenize not in AVAILABLE_TOKENIZERS:
-        raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.")
-
-    if tokenize not in _SacreBLEUTokenizer._TOKENIZE_FN:
-        raise ValueError(
-            f"Unsupported tokenizer selected. Please, choose one of {list(_SacreBLEUTokenizer._TOKENIZE_FN.keys())}"
-        )
     if len(preds) != len(target):
         raise ValueError(f"Corpus has different size {len(preds)} != {len(target)}")
-    if tokenize == "intl" and not _REGEX_AVAILABLE:
-        raise ModuleNotFoundError(
-            "`'intl'` tokenization requires that `regex` is installed."
-            " Use `pip install regex` or `pip install torchmetrics[text]`."
-        )
 
     if weights is not None and len(weights) != n_gram:
         raise ValueError(f"List of weights has different weights than `n_gram`: {len(weights)} != {n_gram}")
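With the per-tokenizer checks now consolidated in `_SacreBLEUTokenizer._check_tokenizers_validity`, a missing optional package surfaces as a `ModuleNotFoundError` when the metric is evaluated; a sketch of what that looks like (illustrative, assuming the Korean extras are not installed):

from torchmetrics.functional.text import sacre_bleu_score

try:
    sacre_bleu_score(["안녕하세요"], [["안녕하세요"]], tokenize="ko-mecab")
except ModuleNotFoundError as err:
    # message suggests `pip install mecab_ko mecab_ko_dic` or `pip install torchmetrics[text]`
    print(err)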