Skip to content

Commit

Permalink
implement fairseq duplicate
Browse files Browse the repository at this point in the history
  • Loading branch information
NISHIMWE Lydia committed Sep 21, 2023
1 parent 9f8882b commit 559d3e8
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 10 deletions.
68 changes: 61 additions & 7 deletions examples/nllb/laser_distillation/laser_distillation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,14 @@ def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks):
src_dictionary, tgt_dictionary, args.student_bpe_symbol, args.teacher_bpe_symbol, interval=1000, samples=5
)
# added to dictionary during setup_task
self.mask_idx = self.src_dictionary.index("<mask>")
if isinstance(self.src_dictionary, BertDictionary):
self.mask_idx = self.src_dictionary.index("[MASK]")
else:
self.mask_idx = self.src_dictionary.index("<mask>")


@classmethod
def setup_task(cls, args, **kwargs):
import pdb
pdb.set_trace()
config = json.load(open(args.configfile))
num_tasks = max([dataset["id"] for dataset in config["train"]]) + 1

Expand All @@ -238,8 +240,12 @@ def setup_task(cls, args, **kwargs):
config["src_vocab"] and config["tgt_vocab"]
), f"Source and target vocab must be specified"

# if args.student_bpe_symbol in [ "bert", "none" ] :
# src_dictionary = BertDictionary.load(config["src_vocab"])
# else:
src_dictionary = Dictionary.load(config["src_vocab"])
src_dictionary.add_symbol("<mask>")
if "<mask>" not in src_dictionary.indices:
src_dictionary.add_symbol("<mask>")
tgt_dictionary = Dictionary.load(config["tgt_vocab"])

logger.info(
Expand Down Expand Up @@ -779,6 +785,15 @@ def reduce_metrics(self, logging_outputs, criterion):
lambda meters: utils.get_perplexity(meters["tlm_loss"].avg),
)

# @classmethod
# def load_dictionary(cls, filename):
# """Load the dictionary from the filename

# Args:
# filename (str): the filename
# """
# return BertDictionary.load(filename)


class SamplePrint:
def __init__(self, source_dictionary, target_dictionary, student_bpe_symbol, teacher_bpe_symbol, interval, samples):
Expand Down Expand Up @@ -834,6 +849,45 @@ def __call__(self, student_src_tokens, teacher_src_tokens, student_teacher_task)
)
)

class BertDictionary(Dictionary):
    """Dictionary that uses BERT's special-token spellings ([CLS], [SEP], ...)."""

    def __init__(self):
        # Deliberately does NOT call Dictionary.__init__: a BERT vocab file
        # already contains its special tokens, so none are pre-registered here.
        self.nspecial = 5
        self.bos_word = "[CLS]"
        self.pad_word = "[PAD]"
        self.eos_word = "[SEP]"
        self.unk_word = "[UNK]"
        self.mask_word = "[MASK]"
        # Special-token positions are unknown until the tokens are added.
        self.bos_index = None
        self.pad_index = None
        self.eos_index = None
        self.unk_index = None
        self.mask_index = None
        self.symbols = []
        self.count = []
        self.indices = {}

    def add_symbol(self, word, n=1, overwrite=False):
        """Register *word* with count *n* and return its index.

        NOTE(review): the ``overwrite`` semantics look inverted relative to
        fairseq's ``Dictionary.add_symbol`` (there, ``overwrite=True`` appends
        a fresh entry while ``overwrite=False`` merges counts). Preserved
        as-is — presumably intentional so duplicate vocab rows can be loaded;
        confirm against the loading code.
        """
        if overwrite and word in self.indices:
            # Known word: only accumulate its count.
            idx = self.indices[word]
            self.count[idx] += n
        else:
            # Append a brand-new entry; if the word already existed, its
            # indices[] mapping is re-pointed at the newest occurrence.
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            # Record where each BERT special token landed.
            for special_word, index_attr in (
                (self.bos_word, "bos_index"),
                (self.pad_word, "pad_index"),
                (self.eos_word, "eos_index"),
                (self.unk_word, "unk_index"),
                (self.mask_word, "mask_index"),
            ):
                if word == special_word:
                    setattr(self, index_attr, idx)
                    break
        return idx

    def mask(self):
        """Index of the [MASK] token (None until [MASK] has been added)."""
        return self.mask_index

@contextmanager
def check_before_after_modelsize(model):
Expand Down Expand Up @@ -879,7 +933,7 @@ def get_laser_lstm_args(args):
lstm_args.decoder_lang_embed_dim = 32
return lstm_args

BERT_TO_TRANSFORMER_KEY_MAPPING = {
HUGGINGFACE_TO_FAIRSEQ_KEY_MAPPING = {
"layer.": "layers.",
"attention.self.query": "self_attn.q_proj",
"attention.self.key": "self_attn.k_proj",
Expand All @@ -893,9 +947,9 @@ def get_laser_lstm_args(args):

def map_transformer_layer_attribute_names(key):
    """Translate a HuggingFace-style state-dict key into its fairseq name.

    Each mapping entry whose HuggingFace substring occurs in *key* is
    substituted with the fairseq equivalent; keys matching no entry are
    returned unchanged.
    """
    translated = key
    for hf_substring, fairseq_substring in HUGGINGFACE_TO_FAIRSEQ_KEY_MAPPING.items():
        # Membership is tested against the ORIGINAL key, so one substitution
        # cannot enable a later mapping entry to fire on its output.
        if hf_substring in key:
            translated = translated.replace(hf_substring, fairseq_substring)
    return translated

# compute weighting per lang
Expand Down
16 changes: 13 additions & 3 deletions fairseq/data/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ def add_from_file(self, f):
if field == "#fairseq:overwrite":
overwrite = True
line, field = line.rsplit(" ", 1)
elif field == "#fairseq:duplicate":
overwrite = False
line, field = line.rsplit(" ", 1)
else:
overwrite = False
count = int(field)
Expand All @@ -266,7 +269,9 @@ def add_from_file(self, f):
"Duplicate words can overwrite earlier ones by adding the "
"#fairseq:overwrite flag at the end of the corresponding row "
"in the dictionary file. If using the Camembert model, please "
"download an updated copy of the model file.".format(word)
"download an updated copy of the model file. "
"Use the #fairseq:duplicate flag to allow duplicates "
"(backward compatibility after overwrite bug fix).".format(word)
)
self.add_symbol(word, n=count, overwrite=overwrite)
except ValueError:
Expand All @@ -279,14 +284,18 @@ def _save(self, f, kv_iterator):
PathManager.mkdirs(os.path.dirname(f))
with PathManager.open(f, "w", encoding="utf-8") as fd:
return self.save(fd)
for k, v in kv_iterator:
print("{} {}".format(k, v), file=f)
for k, v, flag in kv_iterator:
print("{} {} {}".format(k, v, flag), file=f)

def _get_meta(self):
    """Return extra (keys, values) to emit ahead of the regular symbols.

    The base Dictionary carries no metadata, so both lists are empty;
    subclasses may override. Consumed by save(), which prepends these to
    the symbol and count columns.
    """
    return [], []

def _load_meta(self, lines):
    """Consume any metadata header from *lines* and return the number of
    lines used.

    The base Dictionary has no metadata, so nothing is consumed (returns 0);
    subclasses may override.
    """
    return 0

def _get_duplicate_flags(self):
return [ '#fairseq:duplicate' if s in self.symbols[:i] else '' for i, s in enumerate(self.symbols) ]


def save(self, f):
"""Stores dictionary into a text file"""
Expand All @@ -296,6 +305,7 @@ def save(self, f):
zip(
ex_keys + self.symbols[self.nspecial :],
ex_vals + self.count[self.nspecial :],
ex_vals + self._get_duplicate_flags(),
),
)

Expand Down

0 comments on commit 559d3e8

Please sign in to comment.