sacrebleu[feat]: Add ja-mecab tokenizer

stancld committed Sep 10, 2023
1 parent 6960602 commit c1a10cc
Showing 5 changed files with 71 additions and 40 deletions.
2 changes: 2 additions & 0 deletions requirements/text.txt
@@ -5,3 +5,5 @@ nltk >=3.6, <=3.8.1
tqdm >=4.41.0, <=4.66.1
regex >=2021.9.24, <=2023.8.8
transformers >4.4.0, <4.30.3
+mecab-python3 >=1.0.6, <1.1.0
+ipadic >=1.0.0, <1.1.0
73 changes: 53 additions & 20 deletions src/torchmetrics/functional/text/sacre_bleu.py
@@ -46,9 +46,10 @@
from typing_extensions import Literal

from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update
-from torchmetrics.utilities.imports import _REGEX_AVAILABLE
+from torchmetrics.utilities.imports import _IPADIC_AVAILABLE, _MECAB_AVAILABLE, _REGEX_AVAILABLE

-AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char")
+AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char", "ja-mecab")
+Tokenizers = Literal["none", "13a", "zh", "intl", "char", "ja-mecab"]

_UCODE_RANGES = (
    ("\u3400", "\u4db5"),  # CJK Unified Ideographs Extension A, release 3.0
@@ -116,9 +117,12 @@ class _SacreBLEUTokenizer:
"zh": "_tokenize_zh",
"intl": "_tokenize_international",
"char": "_tokenize_char",
"ja-mecab": "_tokenize_ja_mecab",
}

-    def __init__(self, tokenize: Literal["none", "13a", "zh", "intl", "char"], lowercase: bool = False) -> None:
+    def __init__(self, tokenize: Tokenizers, lowercase: bool = False) -> None:
+        self._check_tokenizers_validity(tokenize)
+
        self.tokenize_fn = getattr(self, self._TOKENIZE_FN[tokenize])
        self.lowercase = lowercase

@@ -127,9 +131,9 @@ def __call__(self, line: str) -> Sequence[str]:
        return self._lower(tokenized_line, self.lowercase).split()

    @classmethod
-    def tokenize(
-        cls, line: str, tokenize: Literal["none", "13a", "zh", "intl", "char"], lowercase: bool = False
-    ) -> Sequence[str]:
+    def tokenize(cls, line: str, tokenize: Tokenizers, lowercase: bool = False) -> Sequence[str]:
+        cls._check_tokenizers_validity(tokenize)
+
        tokenize_fn = getattr(cls, cls._TOKENIZE_FN[tokenize])
        tokenized_line = tokenize_fn(line)
        return cls._lower(tokenized_line, lowercase).split()
@@ -274,19 +278,60 @@ def _tokenize_char(cls, line: str) -> str:
"""
return " ".join(char for char in line)

+    @classmethod
+    def _tokenize_ja_mecab(cls, line: str) -> str:
+        """Tokenizes a Japanese string line using the MeCab morphological analyzer.
+
+        Args:
+            line: the input string to tokenize.
+
+        Return:
+            The tokenized string.
+        """
+        import ipadic
+        import MeCab
+
+        tagger = MeCab.Tagger(ipadic.MECAB_ARGS + " -Owakati")
+
+        line = line.strip()
+        return tagger.parse(line).strip()

    @staticmethod
    def _lower(line: str, lowercase: bool) -> str:
        if lowercase:
            return line.lower()
        return line

+    @classmethod
+    def _check_tokenizers_validity(cls, tokenize: Tokenizers) -> None:
+        """Check if a supported tokenizer is chosen.
+
+        Also check that all dependencies of the given tokenizer are installed.
+        """
+        if tokenize not in cls._TOKENIZE_FN.keys():
+            raise ValueError(f"Unsupported tokenizer selected. Please, choose one of {list(cls._TOKENIZE_FN.keys())}")
+
+        if tokenize == "intl" and not _REGEX_AVAILABLE:
+            raise ModuleNotFoundError(
+                "`'intl'` tokenization requires that `regex` is installed."
+                " Use `pip install regex` or `pip install torchmetrics[text]`."
+            )
+
+        if tokenize == "ja-mecab" and not (_MECAB_AVAILABLE and _IPADIC_AVAILABLE):
+            raise ModuleNotFoundError(
+                "`'ja-mecab'` tokenization requires that `MeCab` and `ipadic` are installed."
+                " Use `pip install mecab-python3 ipadic` or `pip install torchmetrics[text]`."
+            )


def sacre_bleu_score(
    preds: Sequence[str],
    target: Sequence[Sequence[str]],
    n_gram: int = 4,
    smooth: bool = False,
-    tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a",
+    tokenize: Tokenizers = "13a",
    lowercase: bool = False,
    weights: Optional[Sequence[float]] = None,
) -> Tensor:
@@ -300,7 +345,7 @@ def sacre_bleu_score(
        n_gram: Gram value ranged from 1 to 4
        smooth: Whether to apply smoothing - see [2]
        tokenize: Tokenization technique to be used.
-            Supported tokenization: ['none', '13a', 'zh', 'intl', 'char']
+            Supported tokenization: ['none', '13a', 'zh', 'intl', 'char', 'ja-mecab']
        lowercase: If ``True``, BLEU score over lowercased text is calculated.
        weights:
            Weights used for unigrams, bigrams, etc. to calculate BLEU score.
@@ -330,20 +375,8 @@ def sacre_bleu_score(
        and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_
    """
-    if tokenize not in AVAILABLE_TOKENIZERS:
-        raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.")
-
-    if tokenize not in _SacreBLEUTokenizer._TOKENIZE_FN.keys():
-        raise ValueError(
-            f"Unsupported tokenizer selected. Please, choose one of {list(_SacreBLEUTokenizer._TOKENIZE_FN.keys())}"
-        )
    if len(preds) != len(target):
        raise ValueError(f"Corpus has different size {len(preds)} != {len(target)}")
-    if tokenize == "intl" and not _REGEX_AVAILABLE:
-        raise ModuleNotFoundError(
-            "`'intl'` tokenization requires that `regex` is installed."
-            " Use `pip install regex` or `pip install torchmetrics[text]`."
-        )

    if weights is not None and len(weights) != n_gram:
        raise ValueError(f"List of weights has different weights than `n_gram`: {len(weights)} != {n_gram}")
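For illustration only (not part of this commit), a minimal sketch of what the new `ja-mecab` path does under the hood, assuming `mecab-python3` and `ipadic` are installed; exact token boundaries depend on the dictionary version:

    import ipadic
    import MeCab

    # "-Owakati" makes MeCab emit space-separated surface forms (wakati-gaki)
    tagger = MeCab.Tagger(ipadic.MECAB_ARGS + " -Owakati")
    print(tagger.parse("これは美しい花です。").strip())
    # e.g. これ は 美しい 花 です 。

The functional API then computes BLEU over these morphemes rather than raw characters, e.g.:

    from torchmetrics.functional.text import sacre_bleu_score

    preds = ["これは美しい花です。"]
    target = [["これは美しい花です。"]]
    score = sacre_bleu_score(preds, target, tokenize="ja-mecab")  # tensor(1.) for an exact match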
18 changes: 3 additions & 15 deletions src/torchmetrics/text/sacre_bleu.py
@@ -20,21 +20,17 @@
from typing import Any, Optional, Sequence, Union

from torch import Tensor
-from typing_extensions import Literal

from torchmetrics.functional.text.bleu import _bleu_score_update
-from torchmetrics.functional.text.sacre_bleu import _SacreBLEUTokenizer
+from torchmetrics.functional.text.sacre_bleu import Tokenizers, _SacreBLEUTokenizer
from torchmetrics.text.bleu import BLEUScore
-from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _REGEX_AVAILABLE
+from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["SacreBLEUScore.plot"]


-AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char")
-
-
class SacreBLEUScore(BLEUScore):
"""Calculate `BLEU score`_ of machine translated text with one or more references.
@@ -95,20 +91,12 @@ def __init__(
        self,
        n_gram: int = 4,
        smooth: bool = False,
-        tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a",
+        tokenize: Tokenizers = "13a",
        lowercase: bool = False,
        weights: Optional[Sequence[float]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(n_gram=n_gram, smooth=smooth, weights=weights, **kwargs)
-        if tokenize not in AVAILABLE_TOKENIZERS:
-            raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.")
-
-        if tokenize == "intl" and not _REGEX_AVAILABLE:
-            raise ModuleNotFoundError(
-                "`'intl'` tokenization requires that `regex` is installed."
-                " Use `pip install regex` or `pip install torchmetrics[text]`."
-            )
        self.tokenizer = _SacreBLEUTokenizer(tokenize, lowercase)

    def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None:
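A corresponding usage sketch for the module-based API (again illustrative, not part of the diff):

    from torchmetrics.text import SacreBLEUScore

    metric = SacreBLEUScore(tokenize="ja-mecab")
    metric.update(["これは美しい花です。"], [["これは美しい花です。"]])
    print(metric.compute())  # tensor(1.) for an exact match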
2 changes: 2 additions & 0 deletions src/torchmetrics/utilities/imports.py
@@ -58,5 +58,7 @@
_XLA_AVAILABLE: bool = package_available("torch_xla")
_PIQ_GREATER_EQUAL_0_8: Optional[bool] = compare_version("piq", operator.ge, "0.8.0")
_FASTER_COCO_EVAL_AVAILABLE: bool = package_available("faster_coco_eval")
+_MECAB_AVAILABLE: bool = package_available("MeCab")
+_IPADIC_AVAILABLE: bool = package_available("ipadic")

_LATEX_AVAILABLE: bool = shutil.which("latex") is not None
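The two new flags follow the existing `package_available` pattern; a hypothetical call site could gate the optional imports like so (sketch, not part of the commit):

    from torchmetrics.utilities.imports import _IPADIC_AVAILABLE, _MECAB_AVAILABLE

    if _MECAB_AVAILABLE and _IPADIC_AVAILABLE:
        import ipadic
        import MeCab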
16 changes: 11 additions & 5 deletions tests/unittests/text/test_sacre_bleu.py
@@ -17,7 +17,7 @@

import pytest
from torch import Tensor, tensor
-from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score
+from torchmetrics.functional.text.sacre_bleu import AVAILABLE_TOKENIZERS, sacre_bleu_score
from torchmetrics.text.sacre_bleu import SacreBLEUScore
from torchmetrics.utilities.imports import _SACREBLEU_AVAILABLE

@@ -28,9 +28,6 @@
from sacrebleu.metrics import BLEU


-TOKENIZERS = ("none", "13a", "zh", "intl", "char")
-
-
def _sacrebleu_fn(preds: Sequence[str], targets: Sequence[Sequence[str]], tokenize: str, lowercase: bool) -> Tensor:
    sacrebleu_fn = BLEU(tokenize=tokenize, lowercase=lowercase)
    # Sacrebleu expects different format of input
@@ -44,7 +41,7 @@ def _sacrebleu_fn(preds: Sequence[str], targets: Sequence[Sequence[str]], tokeni
    [(_inputs_multiple_references.preds, _inputs_multiple_references.targets)],
)
@pytest.mark.parametrize(["lowercase"], [(False,), (True,)])
-@pytest.mark.parametrize("tokenize", TOKENIZERS)
+@pytest.mark.parametrize("tokenize", AVAILABLE_TOKENIZERS)
@pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu")
class TestSacreBLEUScore(TextTester):
"""Test class for `SacreBLEUScore` metric."""
@@ -109,3 +106,12 @@ def test_no_and_uniform_weights_class():
    no_weights_score = no_weights_bleu(preds, targets)
    uniform_weights_score = uniform_weights_bleu(preds, targets)
    assert no_weights_score == uniform_weights_score


+def test_tokenize_ja_mecab():
+    """Test that the `ja-mecab` tokenizer works on Japanese text, in alignment with the sacrebleu implementation."""
+    sacrebleu = SacreBLEUScore(tokenize="ja-mecab")
+
+    preds = ["これは美しい花です。"]
+    targets = [["これは美しい花です。", "おいしい寿司を食べたい。"]]
+    assert sacrebleu(preds, targets) == _sacrebleu_fn(preds, targets, tokenize="ja-mecab", lowercase=False)
