Skip to content

Commit

Permalink
Rework script predictor function to return script codes and to be mor…
Browse files Browse the repository at this point in the history
…e precise about Han variants and Japanese
  • Loading branch information
isaac091 committed Oct 3, 2024
1 parent dc61b10 commit e2d0a88
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
37 changes: 27 additions & 10 deletions silnlp/common/script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from collections import Counter

import hanzidentifier
from hanzidentifier import SIMPLIFIED, TRADITIONAL, identify

LOGGER = logging.getLogger(__package__ + ".script_utils")

Expand Down Expand Up @@ -2032,21 +2032,38 @@ def script(char):
return "Unknown"


def get_script(text: str) -> str:
def predict_han_variant(text: str) -> str:
    """Guess whether ``text`` is written in Traditional or Simplified Han.

    Every character of ``text`` is classified individually with ``identify``.
    Characters labelled ``TRADITIONAL`` and characters labelled ``SIMPLIFIED``
    are tallied; any other label is ignored.

    Returns:
        "Hant" when strictly more characters are labelled ``TRADITIONAL`` than
        ``SIMPLIFIED``; otherwise "Hans" (ties and empty text included).
    """
    char_types = [identify(ch) for ch in text]
    traditional_total = sum(1 for kind in char_types if kind == TRADITIONAL)
    simplified_total = sum(1 for kind in char_types if kind == SIMPLIFIED)

    if traditional_total > simplified_total:
        return "Hant"
    return "Hans"


def predict_script_code(text: str) -> str:
    """Predict the script code of ``text`` from its most frequent script.

    Each character is mapped to a script name with ``script``; the most common
    script name decides the result:

    * Hiragana or Katakana, or Han with any kana present anywhere in the
      text, yields "Jpan".
    * Han without kana is delegated to ``predict_han_variant``.
    * Any other script name is looked up in ``SCRIPT_CODES``.

    Returns:
        The predicted script code, or the literal string "None" (not the
        ``None`` object) when ``text`` is empty.
    """
    if not text:
        return "None"

    counts = Counter(script(char) for char in text)
    pred_script = counts.most_common(1)[0][0]

    # Kana anywhere in the text marks Han-dominant text as Japanese, not Chinese.
    has_kana = "Hiragana" in counts or "Katakana" in counts
    if pred_script in ("Hiragana", "Katakana") or (pred_script == "Han" and has_kana):
        return "Jpan"
    if pred_script == "Han":
        return predict_han_variant(text)

    return SCRIPT_CODES[pred_script]


def is_represented(script_code: str, model: str) -> bool:
    """Report whether ``script_code`` is listed for the model named ``model``.

    ``REPRESENTED_SCRIPTS`` is keyed by model-name prefix. The script counts as
    represented when ``model`` starts with some prefix whose entry contains
    ``script_code``.

    Args:
        script_code: Script code to look for, e.g. the value returned by
            ``predict_script_code``.
        model: Model name, matched against the prefixes with ``str.startswith``.

    Returns:
        True if a matching prefix lists ``script_code``, otherwise False.
    """
    return any(
        model.startswith(model_prefix) and script_code in REPRESENTED_SCRIPTS[model_prefix]
        for model_prefix in REPRESENTED_SCRIPTS
    )


Expand All @@ -2058,7 +2075,7 @@ def main() -> None:
with open(args.input, encoding="utf-8-sig") as f:
text = f.read()

file_script = get_script(text)
file_script = predict_script_code(text)
LOGGER.info(f"Script: {file_script}")


Expand Down
10 changes: 5 additions & 5 deletions silnlp/nmt/analyze_project_pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..common.collect_verse_counts import DT_CANON, NT_CANON, OT_CANON, collect_verse_counts
from ..common.corpus import filter_parallel_corpus, get_mt_corpus_path, get_scripture_parallel_corpus, include_chapters
from ..common.environment import SIL_NLP_ENV
from ..common.script_utils import get_script, is_represented
from ..common.script_utils import is_represented, predict_script_code
from ..common.utils import get_git_revision_hash
from .clearml_connection import SILClearML
from .config import Config, get_data_file_pairs
Expand Down Expand Up @@ -103,11 +103,11 @@ def get_corpus_stats(config: Config, force_align: bool = False, deutero: bool =
filtered_count = parallel_count - len(corpus)
filtered_alignment_score = mean(corpus["score"])

src_script = get_script("".join(corpus["source"][: min(len(corpus["source"]), 3000)]))
src_script = predict_script_code("".join(corpus["source"][: min(len(corpus["source"]), 3000)]))
src_script_in_model = (
is_represented(src_script, config.model) if config.model != "SILTransformerBase" else None
)
trg_script = get_script("".join(corpus["target"][: min(len(corpus["target"]), 3000)]))
trg_script = predict_script_code("".join(corpus["target"][: min(len(corpus["target"]), 3000)]))
trg_script_in_model = (
is_represented(trg_script, config.model) if config.model != "SILTransformerBase" else None
)
Expand Down Expand Up @@ -207,11 +207,11 @@ def get_extra_alignments(config: Config, deutero: bool = False) -> List[str]:
[is_ot_nt(VerseRef.from_string(vref).book_num) for vref in align_corpus["vref"]]
]
parallel_count = len(align_corpus.index)
src_script = get_script("".join(align_corpus["source"][: min(len(align_corpus["source"]), 3000)]))
src_script = predict_script_code("".join(align_corpus["source"][: min(len(align_corpus["source"]), 3000)]))
src_script_in_model = (
is_represented(src_script, config.model) if config.model != "SILTransformerBase" else None
)
trg_script = get_script("".join(align_corpus["target"][: min(len(align_corpus["target"]), 3000)]))
trg_script = predict_script_code("".join(align_corpus["target"][: min(len(align_corpus["target"]), 3000)]))
trg_script_in_model = (
is_represented(trg_script, config.model) if config.model != "SILTransformerBase" else None
)
Expand Down
1 change: 0 additions & 1 deletion silnlp/nmt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
write_corpus,
)
from ..common.environment import SIL_NLP_ENV
from ..common.script_utils import get_script, is_represented
from ..common.translator import TranslationGroup
from ..common.utils import NoiseMethod, Side, create_noise_methods, get_mt_exp_dir, is_set, set_seed
from .augment import AugmentMethod, create_augment_methods
Expand Down

0 comments on commit e2d0a88

Please sign in to comment.