diff --git a/examples/nllb/laser_distillation/laser_distillation_task.py b/examples/nllb/laser_distillation/laser_distillation_task.py index 62ea67aa26..d8190b18d6 100644 --- a/examples/nllb/laser_distillation/laser_distillation_task.py +++ b/examples/nllb/laser_distillation/laser_distillation_task.py @@ -226,6 +226,8 @@ def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks): @classmethod def setup_task(cls, args, **kwargs): + import pdb + pdb.set_trace() config = json.load(open(args.configfile)) num_tasks = max([dataset["id"] for dataset in config["train"]]) + 1 diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index abf6c75c0f..d45f707574 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -126,7 +126,7 @@ def unk_string(self, escape=False): def add_symbol(self, word, n=1, overwrite=False): """Adds a word to the dictionary""" - if word in self.indices and not overwrite: + if word in self.indices and overwrite: idx = self.indices[word] self.count[idx] = self.count[idx] + n return idx @@ -251,6 +251,7 @@ def add_from_file(self, f): for line in lines[indices_start_line:]: try: + line, field = line.rstrip().rsplit(" ", 1) if field == "#fairseq:overwrite": overwrite = True