-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy patheval_mt.py
94 lines (82 loc) · 2.84 KB
/
eval_mt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from datasets import load_dataset
from transformers.pipelines.pt_utils import KeyDataset
import numpy as np
from src.Dialects import (
AfricanAmericanVernacular,
IndianDialect,
ColloquialSingaporeDialect,
ChicanoDialect,
AppalachianDialect,
NigerianDialect,
BlackSouthAfricanDialect,
)
from sacrebleu.metrics import BLEU
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
TASK = "translation"
# Smaller alternative checkpoint: "facebook/nllb-200-distilled-600M".
CKPT = "facebook/nllb-200-distilled-1.3B"

# Language codes handed to the translation pipeline: one English source code,
# and a target code for each WMT19 pair suffix used below.
src_lang = "eng_Latn"
tgt_lang_dict = dict(
    de="deu_Latn",
    ru="rus_Cyrl",
    zh="zho_Hans",
    gu="guj_Gujr",
)

# CUDA device index: the 1.3B checkpoint goes to GPU 3, anything else to GPU 1.
if "1.3B" in CKPT:
    device = 3
else:
    device = 1
def dialect_factory(dialect):
    """Return a batched ``Dataset.map`` callback that rewrites ``src`` into *dialect*.

    Args:
        dialect: Class constructed as ``dialect(morphosyntax=True)`` whose
            instances provide ``convert_sae_to_dialect(text)``.

    Returns:
        A function that takes a batch dict holding a ``"src"`` list, replaces
        that list with the converted sentences (same order), and returns the
        same dict object.
    """

    def dialect_transform(examples):
        # A fresh converter is built on every batch call; the closure itself
        # only captures the class (NOTE(review): presumably to keep it cheap to
        # ship to the num_proc workers — confirm).
        converter = dialect(morphosyntax=True)
        examples["src"] = list(map(converter.convert_sae_to_dialect, examples["src"]))
        return examples

    return dialect_transform
def flatten_factory(target):
    """Return a per-example ``Dataset.map`` callback that flattens a translation pair.

    Args:
        target: Key of the target-language text inside ``example["translation"]``.

    Returns:
        A function that reads ``example["translation"]``, stores its ``"en"``
        text as ``"src"`` and its *target* text as ``"tgt"``, removes the
        ``"translation"`` key, and returns the same dict object.
    """

    def flatten(example):
        pair = example["translation"]
        example["src"] = pair["en"]
        example["tgt"] = pair[target]
        del example["translation"]
        return example

    return flatten
def translate_factory(pipe, *, batch_size=16):
    """Return a batched ``Dataset.map`` callback that translates ``src`` with *pipe*.

    Args:
        pipe: Callable invoked as ``pipe(texts, batch_size=batch_size)`` that
            yields one mapping with a ``"translation_text"`` key per input text.
        batch_size: Batch size forwarded to *pipe*. Defaults to 16, the value
            that used to be hard-coded here, so existing callers are unchanged.

    Returns:
        A function that takes a batch dict holding a ``"src"`` list, stores the
        translations under ``"tgt_pred"`` (same order as ``"src"``), and
        returns the same dict object.
    """

    def translate(examples):
        examples["tgt_pred"] = [
            out["translation_text"]
            for out in pipe(examples["src"], batch_size=batch_size)
        ]
        return examples

    return translate
model = AutoModelForSeq2SeqLM.from_pretrained(CKPT).to("cuda:" + str(device))
tokenizer = AutoTokenizer.from_pretrained(CKPT)

# English source variants to evaluate. ``None`` leaves the WMT19 source
# sentences unmodified (reported as "Standard American").
DIALECTS = [
    None,
    AfricanAmericanVernacular,
    IndianDialect,
    ColloquialSingaporeDialect,
    ChicanoDialect,
    AppalachianDialect,
    NigerianDialect,
    BlackSouthAfricanDialect,
]

# For each target language, translate the WMT19 validation sources (in every
# dialect variant) and print a corpus BLEU score with 1000 bootstrap resamples.
for lang in ["de", "gu", "zh", "ru"]:
    dataset = load_dataset(f"WillHeld/wmt19-valid-only-{lang}_en")["validation"]
    pipe = pipeline(
        TASK,
        model=model,
        tokenizer=tokenizer,
        src_lang=src_lang,
        tgt_lang=tgt_lang_dict[lang],
        max_length=400,
        device=device,
    )
    bleu = BLEU(trg_lang=lang)
    # Flattening does not depend on the dialect, so do it once per language.
    # ``Dataset.map`` returns a new dataset, so every dialect pass below starts
    # from the same untouched src/tgt columns.
    flat_dataset = dataset.map(flatten_factory(lang))
    for dialect in DIALECTS:
        d_dataset = flat_dataset
        if dialect is not None:
            # Instantiated here only to read its display name.
            dialect_name = dialect(morphosyntax=True).dialect_name
            d_dataset = d_dataset.map(
                dialect_factory(dialect), num_proc=24, batched=True
            )
        else:
            dialect_name = "Standard American"
        d_dataset = d_dataset.map(translate_factory(pipe), batched=True)
        # The unused local RNG that used to sit here was removed. sacrebleu
        # seeds its own bootstrap resampler (NOTE(review): reportedly via the
        # SACREBLEU_SEED env var, default 12345 — confirm for the pinned version).
        res = bleu.corpus_score(
            list(d_dataset["tgt_pred"]),
            [list(d_dataset["tgt"])],
            n_bootstrap=1000,
        )
        print(f"{dialect_name} en -> {lang}")
        # Characters Latin-1 cannot encode are printed as "?".
        print(res.format().encode("latin-1", "replace").decode("latin-1"))