Skip to content

Commit

Permalink
Sacrebleu: Add more tokenizers for SacreBLEU metric (#2068)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
  • Loading branch information
3 people authored Oct 9, 2023
1 parent d8962b4 commit f283b8f
Show file tree
Hide file tree
Showing 11 changed files with 230 additions and 52 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/ci-integrate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-20.04, macOS-11, windows-2022]
os: [ubuntu-20.04, macOS-12, windows-2022]
python-version: ["3.8", "3.10"]
requires: ["oldest", "latest"]
exclude:
- { python-version: "3.10", requires: "oldest" }
- { python-version: "3.10", os: "windows" } # todo: https://discuss.pytorch.org/t/numpy-is-not-available-error/146192
include:
- { python-version: "3.10", requires: "latest", os: "ubuntu-22.04" }
- { python-version: "3.10", requires: "latest", os: "macOS-12" }
env:
PYTORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ jobs:
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "1.13.1" }
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.0.0" }
- { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.0.0" }
- { os: "macOS-11", python-version: "3.8", pytorch-version: "1.13.1" }
- { os: "macOS-11", python-version: "3.9", pytorch-version: "1.13.1" }
- { os: "macOS-11", python-version: "3.10", pytorch-version: "2.0.0" }
- { os: "macOS-11", python-version: "3.11", pytorch-version: "2.0.0" }
- { os: "macOS-12", python-version: "3.8", pytorch-version: "1.13.1" }
- { os: "macOS-12", python-version: "3.9", pytorch-version: "1.13.1" }
- { os: "macOS-12", python-version: "3.10", pytorch-version: "2.0.0" }
- { os: "macOS-12", python-version: "3.11", pytorch-version: "2.0.0" }
- { os: "windows-2022", python-version: "3.8", pytorch-version: "1.13.1" }
- { os: "windows-2022", python-version: "3.9", pytorch-version: "1.13.1" }
- { os: "windows-2022", python-version: "3.10", pytorch-version: "2.0.0" }
Expand Down
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

**Note: we move fast, but still we preserve 0.1 version (one feature release) back compatibility.**


## [UnReleased] - 2023-MM-DD

### Added

- Added more tokenizers for `SacreBLEU` metric ([#2068](https://github.com/Lightning-AI/torchmetrics/pull/2068))


- Added `average` argument to multiclass versions of `PrecisionRecallCurve` and `ROC` ([#2084](https://github.com/Lightning-AI/torchmetrics/pull/2084))

- Added error if `NoTrainInceptionV3` is being initialized without `torch-fidelity` not being installed ([#2143](https://github.com/Lightning-AI/torchmetrics/pull/2143))
Expand Down
2 changes: 2 additions & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -164,4 +164,6 @@
.. _Completeness Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.completeness_score.html
.. _Davies-Bouldin Score: https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index
.. _Fowlkes-Mallows Index: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fowlkes_mallows_score.html#sklearn.metrics.fowlkes_mallows_score
.. _FLORES-101: https://arxiv.org/abs/2106.03193
.. _FLORES-200: https://arxiv.org/abs/2207.04672
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
5 changes: 5 additions & 0 deletions requirements/text.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,8 @@ nltk >=3.6, <=3.8.1
tqdm >=4.41.0, <=4.66.1
regex >=2021.9.24, <=2023.8.8
transformers >4.4.0, <4.30.3
mecab-python3 >=1.0.6, <1.1.0
mecab-ko >=1.0.0, <1.1.0
mecab-ko-dic >=1.0.0, <1.1.0
ipadic >=1.0.0, <1.1.0
sentencepiece >=0.1.98, <=0.1.99
2 changes: 1 addition & 1 deletion requirements/text_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ jiwer >=2.3.0, <3.1.0
rouge-score >0.1.0, <=0.1.2
bert_score ==0.3.13
huggingface-hub <0.18 # hotfix, failing SDR for latest PT 1.11
sacrebleu >=2.0.0, <=2.3.1
sacrebleu >=2.3.0, <=2.3.1
199 changes: 177 additions & 22 deletions src/torchmetrics/functional/text/sacre_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,28 @@
# MIT License
# Copyright (c) 2017 - Shujian Huang <[email protected]>

import os
import re
import tempfile
from functools import partial
from typing import ClassVar, Optional, Sequence
from typing import Any, ClassVar, Dict, Optional, Sequence

import torch
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update
from torchmetrics.utilities.imports import _REGEX_AVAILABLE
from torchmetrics.utilities.imports import (
_IPADIC_AVAILABLE,
_MECAB_AVAILABLE,
_MECAB_KO_AVAILABLE,
_MECAB_KO_DIC_AVAILABLE,
_REGEX_AVAILABLE,
_SENTENCEPIECE_AVAILABLE,
)

AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char")
AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char", "ja-mecab", "ko-mecab", "flores101", "flores200")
_Tokenizers_list = Literal["none", "13a", "zh", "intl", "char", "ja-mecab", "ko-mecab", "flores101", "flores200"]

_UCODE_RANGES = (
("\u3400", "\u4db5"), # CJK Unified Ideographs Extension A, release 3.0
Expand Down Expand Up @@ -77,6 +87,14 @@
)


_FLORES_LOCAL_DIR = os.path.join(tempfile.gettempdir(), "torchmetrics-flores")
# Model paths copied from https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/tokenizers/tokenizer_spm.py.
_FLORES_MODELS_URL = {
"flores101": "https://dl.fbaipublicfiles.com/fairseq/models/flores/sacrebleu_tokenizer_spm.model",
"flores200": "https://tinyurl.com/flores200sacrebleuspm",
}


class _SacreBLEUTokenizer:
"""Tokenizer used for SacreBLEU calculation.
Expand Down Expand Up @@ -116,9 +134,18 @@ class _SacreBLEUTokenizer:
"zh": "_tokenize_zh",
"intl": "_tokenize_international",
"char": "_tokenize_char",
"ja-mecab": "_tokenize_ja_mecab",
"ko-mecab": "_tokenize_ko_mecab",
"flores101": "_tokenize_flores_101",
"flores200": "_tokenize_flores_200",
}

def __init__(self, tokenize: Literal["none", "13a", "zh", "intl", "char"], lowercase: bool = False) -> None:
# Keep it as class variable to avoid initializing over and over again
sentencepiece_processors: ClassVar[Dict[str, Optional[Any]]] = {"flores101": None, "flores200": None}

def __init__(self, tokenize: _Tokenizers_list, lowercase: bool = False) -> None:
self._check_tokenizers_validity(tokenize)

self.tokenize_fn = getattr(self, self._TOKENIZE_FN[tokenize])
self.lowercase = lowercase

Expand All @@ -127,9 +154,9 @@ def __call__(self, line: str) -> Sequence[str]:
return self._lower(tokenized_line, self.lowercase).split()

@classmethod
def tokenize(
cls, line: str, tokenize: Literal["none", "13a", "zh", "intl", "char"], lowercase: bool = False
) -> Sequence[str]:
def tokenize(cls, line: str, tokenize: _Tokenizers_list, lowercase: bool = False) -> Sequence[str]:
cls._check_tokenizers_validity(tokenize)

tokenize_fn = getattr(cls, cls._TOKENIZE_FN[tokenize])
tokenized_line = tokenize_fn(line)
return cls._lower(tokenized_line, lowercase).split()
Expand Down Expand Up @@ -274,19 +301,159 @@ def _tokenize_char(cls, line: str) -> str:
"""
return " ".join(char for char in line)

@classmethod
def _tokenize_ja_mecab(cls, line: str) -> str:
"""Tokenizes a Japanese string line using MeCab morphological analyzer.
Args:
line: the input string to tokenize.
Return:
The tokenized string.
"""
import ipadic
import MeCab

tagger = MeCab.Tagger(ipadic.MECAB_ARGS + " -Owakati")

line = line.strip()
return tagger.parse(line).strip()

@classmethod
def _tokenize_ko_mecab(cls, line: str) -> str:
"""Tokenizes a Korean string line using MeCab-korean morphological analyzer.
Args:
line: the input string to tokenize.
Return:
The tokenized string.
"""
import mecab_ko
import mecab_ko_dic

tagger = mecab_ko.Tagger(mecab_ko_dic.MECAB_ARGS + " -Owakati")

line = line.strip()
return tagger.parse(line).strip()

@classmethod
def _tokenize_flores(cls, line: str, tokenize: Literal["flores101", "flores200"]) -> str:
"""Tokenizes a string line using sentencepiece tokenizer.
Args:
line: the input string to tokenize.
tokenize: Tokenization technique to be used.
Return:
The tokenized string.
"""
import sentencepiece

if cls.sentencepiece_processors[tokenize] is None:
cls.sentencepiece_processors[tokenize] = sentencepiece.SentencePieceProcessor()

file_path = os.path.join(_FLORES_LOCAL_DIR, _FLORES_MODELS_URL[tokenize].split("/")[-1])
if not os.path.exists(file_path):
cls.download_flores_file(tokenize)

cls.sentencepiece_processors[tokenize].Load(file_path) # type: ignore[union-attr]

return " ".join(cls.sentencepiece_processors[tokenize].EncodeAsPieces(line)) # type: ignore[union-attr]

@classmethod
def _tokenize_flores_101(cls, line: str) -> str:
"""Tokenizes a string line using sentencepiece tokenizer according to `FLORES-101`_ dataset.
Args:
line: the input string to tokenize.
Return:
The tokenized string.
"""
return cls._tokenize_flores(line, "flores101")

@classmethod
def _tokenize_flores_200(cls, line: str) -> str:
"""Tokenizes a string line using sentencepiece tokenizer according to `FLORES-200`_ dataset.
Args:
line: the input string to tokenize.
Return:
The tokenized string.
"""
return cls._tokenize_flores(line, "flores200")

@staticmethod
def _lower(line: str, lowercase: bool) -> str:
if lowercase:
return line.lower()
return line

@classmethod
def _check_tokenizers_validity(cls, tokenize: _Tokenizers_list) -> None:
"""Check if a supported tokenizer is chosen.
Also check all dependencies of a given tokenizers are installed.
"""
if tokenize not in cls._TOKENIZE_FN:
raise ValueError(f"Unsupported tokenizer selected. Please, choose one of {list(cls._TOKENIZE_FN.keys())}")

if tokenize == "intl" and not _REGEX_AVAILABLE:
raise ModuleNotFoundError(
"`'intl'` tokenization requires that `regex` is installed."
" Use `pip install regex` or `pip install torchmetrics[text]`."
)

if tokenize == "ja-mecab" and not (_MECAB_AVAILABLE and _IPADIC_AVAILABLE):
raise ModuleNotFoundError(
"`'ja-mecab'` tokenization requires that `MeCab` and `ipadic` are installed."
" Use `pip install mecab-python3 ipadic` or `pip install torchmetrics[text]`."
)

if tokenize == "ko-mecab" and not (_MECAB_KO_AVAILABLE and _MECAB_KO_DIC_AVAILABLE):
raise ModuleNotFoundError(
"`'ko-mecab'` tokenization requires that `mecab_ko` and `mecab_ko_dic` are installed."
" Use `pip install mecab_ko mecab_ko_dic` or `pip install torchmetrics[text]`."
)

if "flores" in tokenize and not _SENTENCEPIECE_AVAILABLE:
raise ModuleNotFoundError(
"`'flores101' and 'flores200'` tokenizations require that `sentencepiece` is installed."
" Use `pip install sentencepiece` or `pip install torchmetrics[text]`."
)

@staticmethod
def download_flores_file(model_name: Literal["flores101", "flores200"]) -> None:
"""Download necessary files for `flores` tokenization via `sentencepiece`."""
import ssl
import urllib.request

os.makedirs(_FLORES_LOCAL_DIR, exist_ok=True)

model_url = _FLORES_MODELS_URL[model_name]
file_path = os.path.join(_FLORES_LOCAL_DIR, model_url.split("/")[-1])

try:
with open(file_path, "wb") as out_file, urllib.request.urlopen(model_url) as remote_file:
out_file.write(remote_file.read())
except ssl.SSLError as e:
raise OSError(f"Failed to download {model_name} model.") from e


def sacre_bleu_score(
preds: Sequence[str],
target: Sequence[Sequence[str]],
n_gram: int = 4,
smooth: bool = False,
tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a",
tokenize: _Tokenizers_list = "13a",
lowercase: bool = False,
weights: Optional[Sequence[float]] = None,
) -> Tensor:
Expand All @@ -299,8 +466,8 @@ def sacre_bleu_score(
target: An iterable of iterables of reference corpus
n_gram: Gram value ranged from 1 to 4
smooth: Whether to apply smoothing - see [2]
tokenize: Tokenization technique to be used.
Supported tokenization: ['none', '13a', 'zh', 'intl', 'char']
tokenize: Tokenization technique to be used. Choose between ``'none'``, ``'13a'``, ``'zh'``, ``'intl'``,
``'char'``, ``'ja-mecab'``, ``'ko-mecab'``, ``'flores101'`` and ``'flores200'``.
lowercase: If ``True``, BLEU score over lowercased text is calculated.
weights:
Weights used for unigrams, bigrams, etc. to calculate BLEU score.
Expand Down Expand Up @@ -330,20 +497,8 @@ def sacre_bleu_score(
and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_
"""
if tokenize not in AVAILABLE_TOKENIZERS:
raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.")

if tokenize not in _SacreBLEUTokenizer._TOKENIZE_FN:
raise ValueError(
f"Unsupported tokenizer selected. Please, choose one of {list(_SacreBLEUTokenizer._TOKENIZE_FN.keys())}"
)
if len(preds) != len(target):
raise ValueError(f"Corpus has different size {len(preds)} != {len(target)}")
if tokenize == "intl" and not _REGEX_AVAILABLE:
raise ModuleNotFoundError(
"`'intl'` tokenization requires that `regex` is installed."
" Use `pip install regex` or `pip install torchmetrics[text]`."
)

if weights is not None and len(weights) != n_gram:
raise ValueError(f"List of weights has different weights than `n_gram`: {len(weights)} != {n_gram}")
Expand Down
22 changes: 5 additions & 17 deletions src/torchmetrics/text/sacre_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,17 @@
from typing import Any, Optional, Sequence, Union

from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.text.bleu import _bleu_score_update
from torchmetrics.functional.text.sacre_bleu import _SacreBLEUTokenizer
from torchmetrics.functional.text.sacre_bleu import _SacreBLEUTokenizer, _Tokenizers_list
from torchmetrics.text.bleu import BLEUScore
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _REGEX_AVAILABLE
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["SacreBLEUScore.plot"]


AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char")


class SacreBLEUScore(BLEUScore):
"""Calculate `BLEU score`_ of machine translated text with one or more references.
Expand All @@ -53,8 +49,8 @@ class SacreBLEUScore(BLEUScore):
Args:
n_gram: Gram value ranged from 1 to 4
smooth: Whether to apply smoothing, see `SacreBLEU`_
tokenize: Tokenization technique to be used.
Supported tokenization: ``['none', '13a', 'zh', 'intl', 'char']``
tokenize: Tokenization technique to be used. Choose between ``'none'``, ``'13a'``, ``'zh'``, ``'intl'``,
``'char'``, ``'ja-mecab'``, ``'ko-mecab'``, ``'flores101'`` and ``'flores200'``.
lowercase: If ``True``, BLEU score over lowercased text is calculated.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
weights:
Expand Down Expand Up @@ -95,20 +91,12 @@ def __init__(
self,
n_gram: int = 4,
smooth: bool = False,
tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a",
tokenize: _Tokenizers_list = "13a",
lowercase: bool = False,
weights: Optional[Sequence[float]] = None,
**kwargs: Any,
) -> None:
super().__init__(n_gram=n_gram, smooth=smooth, weights=weights, **kwargs)
if tokenize not in AVAILABLE_TOKENIZERS:
raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.")

if tokenize == "intl" and not _REGEX_AVAILABLE:
raise ModuleNotFoundError(
"`'intl'` tokenization requires that `regex` is installed."
" Use `pip install regex` or `pip install torchmetrics[text]`."
)
self.tokenizer = _SacreBLEUTokenizer(tokenize, lowercase)

def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None:
Expand Down
Loading

0 comments on commit f283b8f

Please sign in to comment.