diff --git a/examples/nllb/laser_distillation/laser_distillation_task.py b/examples/nllb/laser_distillation/laser_distillation_task.py
index d8190b18d6..a3310a8ebf 100644
--- a/examples/nllb/laser_distillation/laser_distillation_task.py
+++ b/examples/nllb/laser_distillation/laser_distillation_task.py
@@ -222,12 +222,14 @@ def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks):
             src_dictionary, tgt_dictionary,
             args.student_bpe_symbol, args.teacher_bpe_symbol,
             interval=1000, samples=5
         )
         # added to dictionary during setup_task
-        self.mask_idx = self.src_dictionary.index("<mask>")
+        if isinstance(self.src_dictionary, BertDictionary):
+            self.mask_idx = self.src_dictionary.index("[MASK]")
+        else:
+            self.mask_idx = self.src_dictionary.index("<mask>")
+
     @classmethod
     def setup_task(cls, args, **kwargs):
-        import pdb
-        pdb.set_trace()
         config = json.load(open(args.configfile))
         num_tasks = max([dataset["id"] for dataset in config["train"]]) + 1
@@ -238,8 +240,12 @@ def setup_task(cls, args, **kwargs):
         assert (
             config["src_vocab"] and config["tgt_vocab"]
         ), f"Source and target vocab must be specified"
 
+        # if args.student_bpe_symbol in ["bert", "none"]:
+        #     src_dictionary = BertDictionary.load(config["src_vocab"])
+        # else:
         src_dictionary = Dictionary.load(config["src_vocab"])
-        src_dictionary.add_symbol("<mask>")
+        if "<mask>" not in src_dictionary.indices:
+            src_dictionary.add_symbol("<mask>")
         tgt_dictionary = Dictionary.load(config["tgt_vocab"])
         logger.info(
@@ -779,6 +785,15 @@ def reduce_metrics(self, logging_outputs, criterion):
             lambda meters: utils.get_perplexity(meters["tlm_loss"].avg),
         )
 
+    # @classmethod
+    # def load_dictionary(cls, filename):
+    #     """Load the dictionary from the filename
+
+    #     Args:
+    #         filename (str): the filename
+    #     """
+    #     return BertDictionary.load(filename)
+
 
 class SamplePrint:
     def __init__(self, source_dictionary, target_dictionary, student_bpe_symbol, teacher_bpe_symbol, interval, samples):
@@ -834,6 +849,45 @@ def __call__(self, student_src_tokens, teacher_src_tokens, student_teacher_task):
                 )
             )
         )
+class BertDictionary(Dictionary):
+    """A mapping from symbols to consecutive integers"""
+    def __init__(self):
+        self.nspecial = 5
+        self.bos_word = "[CLS]"
+        self.pad_word = "[PAD]"
+        self.eos_word = "[SEP]"
+        self.unk_word = "[UNK]"
+        self.mask_word = "[MASK]"
+        self.bos_index = self.pad_index = self.eos_index = self.unk_index = self.mask_index = None
+        self.symbols = []
+        self.count = []
+        self.indices = {}
+
+    def add_symbol(self, word, n=1, overwrite=False):
+        """Adds a word to the dictionary"""
+        if word in self.indices and not overwrite:  # merge counts, as in Dictionary.add_symbol
+            idx = self.indices[word]
+            self.count[idx] = self.count[idx] + n
+        else:
+            idx = len(self.symbols)
+            self.indices[word] = idx
+            self.symbols.append(word)
+            self.count.append(n)
+        if word == self.bos_word:
+            self.bos_index = idx
+        elif word == self.pad_word:
+            self.pad_index = idx
+        elif word == self.eos_word:
+            self.eos_index = idx
+        elif word == self.unk_word:
+            self.unk_index = idx
+        elif word == self.mask_word:
+            self.mask_index = idx
+        return idx
+
+    def mask(self):
+        """Helper to get index of mask symbol"""
+        return self.mask_index
 
 @contextmanager
 def check_before_after_modelsize(model):
@@ -879,23 +933,23 @@ def get_laser_lstm_args(args):
     lstm_args.decoder_lang_embed_dim = 32
     return lstm_args
 
-BERT_TO_TRANSFORMER_KEY_MAPPING = {
+HUGGINGFACE_TO_FAIRSEQ_KEY_MAPPING = {
     "layer.": "layers.",
     "attention.self.query": "self_attn.q_proj",
     "attention.self.key": "self_attn.k_proj",
     "attention.self.value": "self_attn.v_proj",
     "attention.output.dense": "self_attn.out_proj",
     "attention.output.LayerNorm": "self_attn_layer_norm",
     "intermediate.dense": "fc1",
     "output.dense": "fc2",
     "output.LayerNorm": "final_layer_norm",
 }
 def map_transformer_layer_attribute_names(key):
     new_key = key
-    for substring in BERT_TO_TRANSFORMER_KEY_MAPPING.keys():
+    for substring in HUGGINGFACE_TO_FAIRSEQ_KEY_MAPPING.keys():
         if substring in key:
-            new_key = new_key.replace(substring, BERT_TO_TRANSFORMER_KEY_MAPPING[substring])
+            new_key = new_key.replace(substring, HUGGINGFACE_TO_FAIRSEQ_KEY_MAPPING[substring])
     return new_key
 
 
 # compute weighting per lang
diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py
index d45f707574..44690b980c 100644
--- a/fairseq/data/dictionary.py
+++ b/fairseq/data/dictionary.py
@@ -256,6 +256,9 @@ def add_from_file(self, f):
                 if field == "#fairseq:overwrite":
                     overwrite = True
                     line, field = line.rsplit(" ", 1)
+                elif field == "#fairseq:duplicate":
+                    overwrite = True  # keep the duplicate row as its own entry so saved indices load back unchanged
+                    line, field = line.rsplit(" ", 1)
                 else:
                     overwrite = False
                 count = int(field)
@@ -266,7 +269,9 @@ def add_from_file(self, f):
                         "Duplicate words can overwrite earlier ones by adding the "
                         "#fairseq:overwrite flag at the end of the corresponding row "
                         "in the dictionary file. If using the Camembert model, please "
-                        "download an updated copy of the model file.".format(word)
+                        "download an updated copy of the model file. "
+                        "Use the #fairseq:duplicate flag to allow duplicates "
+                        "(backward compatibility after the overwrite bug fix).".format(word)
                     )
                 self.add_symbol(word, n=count, overwrite=overwrite)
             except ValueError:
@@ -279,14 +284,18 @@ def _save(self, f, kv_iterator):
             PathManager.mkdirs(os.path.dirname(f))
             with PathManager.open(f, "w", encoding="utf-8") as fd:
                 return self.save(fd)
-        for k, v in kv_iterator:
-            print("{} {}".format(k, v), file=f)
+        for k, v, flag in kv_iterator:
+            print("{} {} {}".format(k, v, flag).rstrip(), file=f)
 
     def _get_meta(self):
         return [], []
 
     def _load_meta(self, lines):
         return 0
+
+    def _get_duplicate_flags(self):
+        return ["#fairseq:duplicate" if s in self.symbols[:i] else "" for i, s in enumerate(self.symbols)]
+
 
     def save(self, f):
         """Stores dictionary into a text file"""
@@ -296,6 +305,7 @@ def save(self, f):
             zip(
                 ex_keys + self.symbols[self.nspecial :],
                 ex_vals + self.count[self.nspecial :],
+                [""] * len(ex_keys) + self._get_duplicate_flags()[self.nspecial :],
             ),
         )
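
Note (reviewer sketch, not part of the patch): with the changes above, a dictionary containing repeated symbols should round-trip through save()/load() instead of raising "Duplicate word found". A minimal sketch of the expected behavior, assuming the patched fairseq Dictionary; the temp path and the toy symbols are arbitrary:

```python
from fairseq.data.dictionary import Dictionary

d = Dictionary()
d.add_symbol("hello", n=3)
d.add_symbol("world", n=2)
# overwrite=True appends a second "hello" row instead of merging counts,
# so the duplicate keeps its own index.
d.add_symbol("hello", n=1, overwrite=True)

d.save("/tmp/dict.txt")
# /tmp/dict.txt (special symbols are not written out):
#   hello 3
#   world 2
#   hello 1 #fairseq:duplicate

# The flagged row is re-added as its own entry, so indices are preserved.
d2 = Dictionary.load("/tmp/dict.txt")
assert d2.symbols == d.symbols
```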