From e772aeb29d48450adf96c93e983313273cd4dc9d Mon Sep 17 00:00:00 2001 From: Lydia Nishimwe Date: Wed, 24 Jul 2024 21:00:15 +0200 Subject: [PATCH 1/2] Update dictionary.py with fairseq overwrite bug fix --- fairseq/data/dictionary.py | 75 +++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 2d061dd424..03df789f7e 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -1,7 +1,6 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import os @@ -9,7 +8,6 @@ from multiprocessing import Pool import torch - from fairseq import utils from fairseq.data import data_utils from fairseq.file_chunker_utils import Chunker, find_offsets @@ -28,19 +26,21 @@ def __init__( eos="", unk="", extra_special_symbols=None, + add_special_symbols=True, ): self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos self.symbols = [] self.count = [] self.indices = {} - self.bos_index = self.add_symbol(bos) - self.pad_index = self.add_symbol(pad) - self.eos_index = self.add_symbol(eos) - self.unk_index = self.add_symbol(unk) - if extra_special_symbols: - for s in extra_special_symbols: - self.add_symbol(s) - self.nspecial = len(self.symbols) + if add_special_symbols: + self.bos_index = self.add_symbol(bos) + self.pad_index = self.add_symbol(pad) + self.eos_index = self.add_symbol(eos) + self.unk_index = self.add_symbol(unk) + if extra_special_symbols: + for s in extra_special_symbols: + self.add_symbol(s) + self.nspecial = len(self.symbols) def __eq__(self, other): return self.indices == other.indices @@ -124,7 +124,7 @@ def unk_string(self, escape=False): else: return self.unk_word - def add_symbol(self, word, n=1, overwrite=False): + def add_symbol(self, word, n=1, overwrite=True): """Adds a word to the dictionary""" if word in self.indices and overwrite: idx = self.indices[word] @@ -215,7 +215,7 @@ def unk(self): return self.unk_index @classmethod - def load(cls, f): + def load(cls, f, add_special_symbols=True): """Loads the dictionary from a text file with the format: Example:: @@ -230,7 +230,7 @@ def load(cls, f): and `#fairseq:duplicate` to keep them (for backward compatibility after bug fix) """ - d = cls() + d = cls(add_special_symbols=add_special_symbols) d.add_from_file(f) return d @@ -259,25 +259,25 @@ def add_from_file(self, f): try: line, field = line.rstrip().rsplit(" ", 1) if field == "#fairseq:overwrite": - overwrite, duplicate = True, False + overwrite = True line, field = line.rsplit(" ", 1) elif field == "#fairseq:duplicate": - overwrite, duplicate = False, True + overwrite = False line, field = line.rsplit(" ", 1) else: - overwrite, duplicate = False, False + if line in self: + raise RuntimeError( + "Duplicate word found when loading Dictionary: '{}'. " + "Duplicate words can overwrite earlier ones by adding the " + "#fairseq:overwrite flag at the end of the corresponding row " + "in the dictionary file. Use the #fairseq:duplicate flag " + "to keep duplicates in the dictionary (backward compatibility " + "after bug fix). If using the Camembert model, please " + "download an updated copy of the model file.".format(word) + ) + overwrite = True # default behaviour count = int(field) word = line - if word in self and not overwrite and not duplicate: - raise RuntimeError( - "Duplicate word found when loading Dictionary: '{}'. " - "Duplicate words can overwrite earlier ones by adding the " - "#fairseq:overwrite flag at the end of the corresponding row " - "in the dictionary file. Use the #fairseq:duplicate flag " - "to keep duplicates in the dictionary (backward compatibility " - "after bug fix). If using the Camembert model, please " - "download an updated copy of the model file.".format(word) - ) self.add_symbol(word, n=count, overwrite=overwrite) except ValueError: raise ValueError( @@ -297,7 +297,7 @@ def _get_meta(self): def _load_meta(self, lines): return 0 - + def save(self, f): """Stores dictionary into a text file""" ex_keys, ex_vals = self._get_meta() @@ -332,19 +332,20 @@ def encode_line( words = line_tokenizer(line) if reverse_order: words = list(reversed(words)) - ids = [] + nwords = len(words) + ids = torch.IntTensor(nwords + 1 if append_eos else nwords) - for word in words: + for i, word in enumerate(words): if add_if_not_exist: - idx = self.add_symbol(word, overwrite=True) + idx = self.add_symbol(word) else: idx = self.index(word) if consumer is not None: consumer(word, idx) - ids.append(idx) + ids[i] = idx if append_eos: - ids.append(self.eos_index) - return torch.tensor(ids, dtype=torch.int32) + ids[nwords] = self.eos_index + return ids @staticmethod def _add_file_to_dictionary_single_worker( @@ -366,7 +367,7 @@ def _add_file_to_dictionary_single_worker( def add_file_to_dictionary(filename, dict, tokenize, num_workers): def merge_result(counter): for w, c in sorted(counter.items()): - dict.add_symbol(w, c, overwrite=True) + dict.add_symbol(w, c) local_file = PathManager.get_local_path(filename) offsets = find_offsets(local_file, num_workers) From 752fb471cecca474f89d669f6f3b3fcd7e1f4712 Mon Sep 17 00:00:00 2001 From: Lydia Nishimwe Date: Wed, 24 Jul 2024 21:03:18 +0200 Subject: [PATCH 2/2] Update test_dictionary.py