Skip to content

Commit

Permalink
Merge pull request #1 from lydianish/nllb
Browse files (browse the repository at this point in the history)
Update dictionary.py with fairseq overwrite bug fix
  • Loading branch information
lydianish authored Jul 24, 2024
2 parents (a8df9d0 + 752fb47), commit af754b1
Showing 1 changed file with 38 additions and 37 deletions.
75 changes: 38 additions & 37 deletions fairseq/data/dictionary.py
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from collections import Counter
from multiprocessing import Pool

import torch

from fairseq import utils
from fairseq.data import data_utils
from fairseq.file_chunker_utils import Chunker, find_offsets
Expand All @@ -28,19 +26,21 @@ def __init__(
eos="</s>",
unk="<unk>",
extra_special_symbols=None,
add_special_symbols=True,
):
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
self.symbols = []
self.count = []
self.indices = {}
self.bos_index = self.add_symbol(bos)
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
if extra_special_symbols:
for s in extra_special_symbols:
self.add_symbol(s)
self.nspecial = len(self.symbols)
if add_special_symbols:
self.bos_index = self.add_symbol(bos)
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
if extra_special_symbols:
for s in extra_special_symbols:
self.add_symbol(s)
self.nspecial = len(self.symbols)

def __eq__(self, other):
return self.indices == other.indices
Expand Down Expand Up @@ -124,7 +124,7 @@ def unk_string(self, escape=False):
else:
return self.unk_word

def add_symbol(self, word, n=1, overwrite=False):
def add_symbol(self, word, n=1, overwrite=True):
"""Adds a word to the dictionary"""
if word in self.indices and overwrite:
idx = self.indices[word]
Expand Down Expand Up @@ -215,7 +215,7 @@ def unk(self):
return self.unk_index

@classmethod
def load(cls, f):
def load(cls, f, add_special_symbols=True):
"""Loads the dictionary from a text file with the format:
Example::
Expand All @@ -230,7 +230,7 @@ def load(cls, f):
and `#fairseq:duplicate` to keep them (for backward compatibility
after bug fix)
"""
d = cls()
d = cls(add_special_symbols=add_special_symbols)
d.add_from_file(f)
return d

Expand Down Expand Up @@ -259,25 +259,25 @@ def add_from_file(self, f):
try:
line, field = line.rstrip().rsplit(" ", 1)
if field == "#fairseq:overwrite":
overwrite, duplicate = True, False
overwrite = True
line, field = line.rsplit(" ", 1)
elif field == "#fairseq:duplicate":
overwrite, duplicate = False, True
overwrite = False
line, field = line.rsplit(" ", 1)
else:
overwrite, duplicate = False, False
if line in self:
raise RuntimeError(
"Duplicate word found when loading Dictionary: '{}'. "
"Duplicate words can overwrite earlier ones by adding the "
"#fairseq:overwrite flag at the end of the corresponding row "
"in the dictionary file. Use the #fairseq:duplicate flag "
"to keep duplicates in the dictionary (backward compatibility "
"after bug fix). If using the Camembert model, please "
"download an updated copy of the model file.".format(word)
)
overwrite = True # default behaviour
count = int(field)
word = line
if word in self and not overwrite and not duplicate:
raise RuntimeError(
"Duplicate word found when loading Dictionary: '{}'. "
"Duplicate words can overwrite earlier ones by adding the "
"#fairseq:overwrite flag at the end of the corresponding row "
"in the dictionary file. Use the #fairseq:duplicate flag "
"to keep duplicates in the dictionary (backward compatibility "
"after bug fix). If using the Camembert model, please "
"download an updated copy of the model file.".format(word)
)
self.add_symbol(word, n=count, overwrite=overwrite)
except ValueError:
raise ValueError(
Expand All @@ -297,7 +297,7 @@ def _get_meta(self):

def _load_meta(self, lines):
return 0

def save(self, f):
"""Stores dictionary into a text file"""
ex_keys, ex_vals = self._get_meta()
Expand Down Expand Up @@ -332,19 +332,20 @@ def encode_line(
words = line_tokenizer(line)
if reverse_order:
words = list(reversed(words))
ids = []
nwords = len(words)
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)

for word in words:
for i, word in enumerate(words):
if add_if_not_exist:
idx = self.add_symbol(word, overwrite=True)
idx = self.add_symbol(word)
else:
idx = self.index(word)
if consumer is not None:
consumer(word, idx)
ids.append(idx)
ids[i] = idx
if append_eos:
ids.append(self.eos_index)
return torch.tensor(ids, dtype=torch.int32)
ids[nwords] = self.eos_index
return ids

@staticmethod
def _add_file_to_dictionary_single_worker(
Expand All @@ -366,7 +367,7 @@ def _add_file_to_dictionary_single_worker(
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
def merge_result(counter):
for w, c in sorted(counter.items()):
dict.add_symbol(w, c, overwrite=True)
dict.add_symbol(w, c)

local_file = PathManager.get_local_path(filename)
offsets = find_offsets(local_file, num_workers)
Expand Down

0 comments on commit af754b1

Please sign in to comment.