From e2d0a882b942d2fd9666fc96436beda5e93ed8dc Mon Sep 17 00:00:00 2001 From: Isaac Schifferer Date: Thu, 3 Oct 2024 02:04:16 +0000 Subject: [PATCH] Rework script predictor function to return script codes and to be more precise about Han variants and Japanese --- silnlp/common/script_utils.py | 37 +++++++++++++++++++++-------- silnlp/nmt/analyze_project_pairs.py | 10 ++++---- silnlp/nmt/config.py | 1 - 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/silnlp/common/script_utils.py b/silnlp/common/script_utils.py index aebfa944..d9c9a41f 100644 --- a/silnlp/common/script_utils.py +++ b/silnlp/common/script_utils.py @@ -2,7 +2,7 @@ import logging from collections import Counter -import hanzidentifier +from hanzidentifier import SIMPLIFIED, TRADITIONAL, identify LOGGER = logging.getLogger(__package__ + ".script_utils") @@ -2032,21 +2032,38 @@ def script(char): return "Unknown" -def get_script(text: str) -> str: +def predict_han_variant(text: str) -> str: + num_trad = 0 + num_simp = 0 + for c in text: + char_type = identify(c) + num_trad += char_type == TRADITIONAL + num_simp += char_type == SIMPLIFIED + + return "Hant" if num_trad > num_simp else "Hans" + + +def predict_script_code(text: str) -> str: if len(text) == 0: return "None" counts = Counter([script(char) for char in text]) - return counts.most_common()[0][0] + pred_script = counts.most_common()[0][0] + + if pred_script in ["Hiragana", "Katakana"] or ( + pred_script == "Han" and ("Hiragana" in counts.keys() or "Katakana" in counts.keys()) + ): + return "Jpan" + if pred_script == "Han": + return predict_han_variant(text) + + return SCRIPT_CODES[pred_script] -def is_represented(script: str, model: str) -> bool: +def is_represented(script_code: str, model: str) -> bool: for model_prefix in REPRESENTED_SCRIPTS: - if model.startswith(model_prefix): - if script in REPRESENTED_SCRIPTS[model_prefix]: - return True - elif script in ["Hiragana", "Katakana"] and "Japanese" in REPRESENTED_SCRIPTS[model_prefix]: - return True + if model.startswith(model_prefix) and script_code in REPRESENTED_SCRIPTS[model_prefix]: + return True return False @@ -2058,7 +2075,7 @@ def main() -> None: with open(args.input, encoding="utf-8-sig") as f: text = f.read() - file_script = get_script(text) + file_script = predict_script_code(text) LOGGER.info(f"Script: {file_script}") diff --git a/silnlp/nmt/analyze_project_pairs.py b/silnlp/nmt/analyze_project_pairs.py index d9e67da8..92057286 100644 --- a/silnlp/nmt/analyze_project_pairs.py +++ b/silnlp/nmt/analyze_project_pairs.py @@ -13,7 +13,7 @@ from ..common.collect_verse_counts import DT_CANON, NT_CANON, OT_CANON, collect_verse_counts from ..common.corpus import filter_parallel_corpus, get_mt_corpus_path, get_scripture_parallel_corpus, include_chapters from ..common.environment import SIL_NLP_ENV -from ..common.script_utils import get_script, is_represented +from ..common.script_utils import is_represented, predict_script_code from ..common.utils import get_git_revision_hash from .clearml_connection import SILClearML from .config import Config, get_data_file_pairs @@ -103,11 +103,11 @@ def get_corpus_stats(config: Config, force_align: bool = False, deutero: bool = filtered_count = parallel_count - len(corpus) filtered_alignment_score = mean(corpus["score"]) - src_script = get_script("".join(corpus["source"][: min(len(corpus["source"]), 3000)])) + src_script = predict_script_code("".join(corpus["source"][: min(len(corpus["source"]), 3000)])) src_script_in_model = ( is_represented(src_script, config.model) if config.model != "SILTransformerBase" else None ) - trg_script = get_script("".join(corpus["target"][: min(len(corpus["target"]), 3000)])) + trg_script = predict_script_code("".join(corpus["target"][: min(len(corpus["target"]), 3000)])) trg_script_in_model = ( is_represented(trg_script, config.model) if config.model != "SILTransformerBase" else None ) @@ -207,11 +207,11 @@ def get_extra_alignments(config: Config, deutero: bool = False) -> List[str]: [is_ot_nt(VerseRef.from_string(vref).book_num) for vref in align_corpus["vref"]] ] parallel_count = len(align_corpus.index) - src_script = get_script("".join(align_corpus["source"][: min(len(align_corpus["source"]), 3000)])) + src_script = predict_script_code("".join(align_corpus["source"][: min(len(align_corpus["source"]), 3000)])) src_script_in_model = ( is_represented(src_script, config.model) if config.model != "SILTransformerBase" else None ) - trg_script = get_script("".join(align_corpus["target"][: min(len(align_corpus["target"]), 3000)])) + trg_script = predict_script_code("".join(align_corpus["target"][: min(len(align_corpus["target"]), 3000)])) trg_script_in_model = ( is_represented(trg_script, config.model) if config.model != "SILTransformerBase" else None ) diff --git a/silnlp/nmt/config.py b/silnlp/nmt/config.py index 300d3b0e..910b9345 100644 --- a/silnlp/nmt/config.py +++ b/silnlp/nmt/config.py @@ -37,7 +37,6 @@ write_corpus, ) from ..common.environment import SIL_NLP_ENV -from ..common.script_utils import get_script, is_represented from ..common.translator import TranslationGroup from ..common.utils import NoiseMethod, Side, create_noise_methods, get_mt_exp_dir, is_set, set_seed from .augment import AugmentMethod, create_augment_methods