
Commit b4aed58
Merge pull request #268 from facebookresearch/fix-parity
Ensure `laser_encoders` has parity with existing LASER inference code for release
heffernankevin authored Nov 20, 2023
2 parents 90db293 + 77bf7fb commit b4aed58
Showing 3 changed files with 9 additions and 3 deletions.
laser_encoders/laser_tokenizer.py (7 changes: 6 additions & 1 deletion)

@@ -24,11 +24,13 @@

 import sentencepiece as spm
 from sacremoses import MosesDetokenizer, MosesPunctNormalizer
+from unicategories import categories

 from laser_encoders.download_models import LaserModelDownloader
 from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE, SPM_LANGUAGE

 SPACE_NORMALIZER = re.compile(r"\s+")
+NON_PRINT_CHARS = set(c for c in categories["C"].characters())

 logging.basicConfig(
     stream=sys.stdout,
@@ -59,6 +61,9 @@ def __init__(

         assert spm_model.exists(), f"spm model file: {spm_model} does not exist"
         self.moses_punct_normalizer = MosesPunctNormalizer(self.lang, perl_parity=True)
+        # add parity with MOSES release-4.0
+        self.moses_punct_normalizer.substitutions[21] = ("‘", r'"')
+        self.moses_punct_normalizer.substitutions[22] = ("‚", r'"')
         self.moses_detokenizer = MosesDetokenizer()
         self.spm_encoder = spm.SentencePieceProcessor(model_file=str(self.spm_model))
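The two overwritten entries align sacremoses with the quote handling in MOSES release-4.0, which rewrites the single left quotation mark (U+2018) and the single low-9 quotation mark (U+201A) to a plain double quote. A minimal sketch of the effect, assuming sacremoses 0.1.0 with its perl_parity flag (the version pinned below); the sample string and printed result are illustrative, not taken from the repository:

    from sacremoses import MosesPunctNormalizer

    norm = MosesPunctNormalizer("en", perl_parity=True)
    # Overwrite rule slots 21 and 22 so U+2018 and U+201A normalize to '"',
    # matching MOSES release-4.0 rather than sacremoses' default mapping.
    norm.substitutions[21] = ("‘", r'"')
    norm.substitutions[22] = ("‚", r'"')

    print(norm.normalize("he said ‘hello‚"))  # roughly: he said "hello"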

@@ -75,7 +80,7 @@ def log(self, message: str) -> None:

     def tokenize(self, text: str) -> str:
         # Preprocessing
-        sentence_text = "".join(c for c in text if c.isprintable)
+        sentence_text = "".join([c if c not in NON_PRINT_CHARS else " " for c in text])
         if self.normalize_punct:
             sentence_text = self.moses_punct_normalizer.normalize(sentence_text)
         if self.descape:
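The replaced line in tokenize also fixes a genuine bug: the old generator expression tested c.isprintable, the bound method object itself (which is always truthy), instead of calling c.isprintable(), so the filter never removed anything. A self-contained sketch of the new behaviour, assuming unicategories >= 0.1.2; strip_non_printable is a hypothetical helper name, not part of the package:

    from unicategories import categories

    # Unicode major category "C" (control, format, surrogate, private-use,
    # and unassigned code points) is what the tokenizer now treats as
    # non-printable.
    NON_PRINT_CHARS = set(c for c in categories["C"].characters())

    def strip_non_printable(text: str) -> str:
        # Replace each non-printable character with a space instead of
        # dropping it, so a control byte still marks a word boundary.
        return "".join([c if c not in NON_PRINT_CHARS else " " for c in text])

    print(strip_non_printable("ABC\x1f123"))  # ABC 123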
laser_encoders/test_laser_tokenizer.py (2 changes: 1 addition & 1 deletion)

@@ -100,7 +100,7 @@ def test_lowercase(tokenizer):

 def test_is_printable(tokenizer):
     test_data = "Hello, \tWorld! ABC\x1f123"
-    expected_output = "▁hel lo , ▁world ! ▁ab c 12 3"
+    expected_output = "▁hel lo , ▁world ! ▁ab c ▁12 3"
     assert tokenizer.tokenize(test_data) == expected_output


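Why the expected string changes: replacing \x1f with a space gives SentencePiece an explicit word boundary before "123", so that piece gains the ▁ (word-start) marker. A quick way to see this, assuming a local copy of the LASER2 SPM model (the file path below is hypothetical):

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor(model_file="laser2.spm")  # hypothetical local path
    print(" ".join(sp.encode("hello, world! abc\x1f123", out_type=str)))
    print(" ".join(sp.encode("hello, world! abc 123", out_type=str)))
    # The control character is stripped by SPM's own normalizer in the first
    # call (ending "... ▁ab c 12 3"), while the explicit space in the second
    # call should yield a word-start piece (ending "... ▁ab c ▁12 3").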
pyproject.toml (3 changes: 2 additions & 1 deletion)

@@ -11,7 +11,8 @@ readme = "laser_encoders/README.md"
 requires-python = ">=3.8"
 
 dependencies = [
-    'sacremoses>=0.1.0',
+    'sacremoses==0.1.0',
+    'unicategories>=0.1.2',
     'sentencepiece>=0.1.99',
     'numpy>=1.21.3',
     'torch>=1.10.0',
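Pinning sacremoses to an exact version (==0.1.0 instead of >=0.1.0) is presumably deliberate: the constructor's perl_parity flag and the exact positions of entries 21 and 22 in the substitutions table are version-specific, so a newer sacremoses release could silently shift the indices the tokenizer overwrites.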
