diff --git a/docs/conf.py b/docs/conf.py
index d6d150c1f0..52971a27e7 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -20,10 +20,11 @@
import os
import sys
+
# source code directory, relative to this file, for sphinx-autobuild
-sys.path.insert(0, os.path.abspath('..'))
+sys.path.insert(0, os.path.abspath(".."))
-source_suffix = ['.rst']
+source_suffix = [".rst"]
# -- General configuration ------------------------------------------------
@@ -35,34 +36,34 @@
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
- 'sphinx.ext.autodoc',
- 'sphinx.ext.intersphinx',
- 'sphinx.ext.viewcode',
- 'sphinx.ext.napoleon',
- 'sphinxarg.ext',
+ "sphinx.ext.autodoc",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.viewcode",
+ "sphinx.ext.napoleon",
+ "sphinxarg.ext",
]
# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
# The master toctree document.
-master_doc = 'index'
+master_doc = "index"
# General information about the project.
-project = 'fairseq'
-copyright = '2019, Facebook AI Research (FAIR)'
-author = 'Facebook AI Research (FAIR)'
+project = "fairseq"
+copyright = "2019, Facebook AI Research (FAIR)"
+author = "Facebook AI Research (FAIR)"
-github_doc_root = 'https://github.com/pytorch/fairseq/tree/master/docs/'
+github_doc_root = "https://github.com/pytorch/fairseq/tree/master/docs/"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
-version = '0.9.0'
+version = "0.9.0"
# The full version, including alpha/beta/rc tags.
-release = '0.9.0'
+release = "0.9.0"
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
@@ -74,11 +75,11 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
-highlight_language = 'python'
+pygments_style = "sphinx"
+highlight_language = "python"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
@@ -89,7 +90,7 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
-html_theme = 'sphinx_rtd_theme'
+html_theme = "sphinx_rtd_theme"
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
@@ -100,11 +101,11 @@
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
html_context = {
- 'css_files': [
- '_static/theme_overrides.css', # override wide tables in RTD theme
+ "css_files": [
+ "_static/theme_overrides.css", # override wide tables in RTD theme
],
}
@@ -113,7 +114,7 @@
#
# This is required for the alabaster theme
# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
-#html_sidebars = {
+# html_sidebars = {
# '**': [
# 'about.html',
# 'navigation.html',
@@ -121,12 +122,12 @@
# 'searchbox.html',
# 'donate.html',
# ]
-#}
+# }
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
- 'numpy': ('http://docs.scipy.org/doc/numpy/', None),
- 'python': ('https://docs.python.org/', None),
- 'torch': ('https://pytorch.org/docs/master/', None),
+ "numpy": ("http://docs.scipy.org/doc/numpy/", None),
+ "python": ("https://docs.python.org/", None),
+ "torch": ("https://pytorch.org/docs/master/", None),
}
diff --git a/examples/__init__.py b/examples/__init__.py
index 9369be1b77..9a6b08a75b 100644
--- a/examples/__init__.py
+++ b/examples/__init__.py
@@ -3,6 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-__version__ = '0.9.0'
+__version__ = "0.9.0"
import examples.noisychannel # noqa
diff --git a/examples/backtranslation/deduplicate_lines.py b/examples/backtranslation/deduplicate_lines.py
index 35a407e556..50e458328c 100644
--- a/examples/backtranslation/deduplicate_lines.py
+++ b/examples/backtranslation/deduplicate_lines.py
@@ -7,8 +7,8 @@
import argparse
import fileinput
import hashlib
-from multiprocessing import Pool
import sys
+from multiprocessing import Pool
def get_hashes_and_lines(raw_line):
@@ -18,12 +18,12 @@ def get_hashes_and_lines(raw_line):
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('--workers', type=int, default=10)
- parser.add_argument('files', nargs='*', help='input files')
+ parser.add_argument("--workers", type=int, default=10)
+ parser.add_argument("files", nargs="*", help="input files")
args = parser.parse_args()
seen = set()
- with fileinput.input(args.files, mode='rb') as h:
+ with fileinput.input(args.files, mode="rb") as h:
pool = Pool(args.workers)
results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
for i, (hash, raw_line) in enumerate(results):
@@ -37,5 +37,5 @@ def main():
print(file=sys.stderr, flush=True)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/backtranslation/extract_bt_data.py b/examples/backtranslation/extract_bt_data.py
index 26a46942c8..e766391e87 100644
--- a/examples/backtranslation/extract_bt_data.py
+++ b/examples/backtranslation/extract_bt_data.py
@@ -11,26 +11,38 @@
def main():
- parser = argparse.ArgumentParser(description=(
- 'Extract back-translations from the stdout of fairseq-generate. '
- 'If there are multiply hypotheses for a source, we only keep the first one. '
- ))
- parser.add_argument('--output', required=True, help='output prefix')
- parser.add_argument('--srclang', required=True, help='source language (extracted from H-* lines)')
- parser.add_argument('--tgtlang', required=True, help='target language (extracted from S-* lines)')
- parser.add_argument('--minlen', type=int, help='min length filter')
- parser.add_argument('--maxlen', type=int, help='max length filter')
- parser.add_argument('--ratio', type=float, help='ratio filter')
- parser.add_argument('files', nargs='*', help='input files')
+ parser = argparse.ArgumentParser(
+ description=(
+ "Extract back-translations from the stdout of fairseq-generate. "
+ "If there are multiply hypotheses for a source, we only keep the first one. "
+ )
+ )
+ parser.add_argument("--output", required=True, help="output prefix")
+ parser.add_argument(
+ "--srclang", required=True, help="source language (extracted from H-* lines)"
+ )
+ parser.add_argument(
+ "--tgtlang", required=True, help="target language (extracted from S-* lines)"
+ )
+ parser.add_argument("--minlen", type=int, help="min length filter")
+ parser.add_argument("--maxlen", type=int, help="max length filter")
+ parser.add_argument("--ratio", type=float, help="ratio filter")
+ parser.add_argument("files", nargs="*", help="input files")
args = parser.parse_args()
def validate(src, tgt):
- srclen = len(src.split(' ')) if src != '' else 0
- tgtlen = len(tgt.split(' ')) if tgt != '' else 0
+ srclen = len(src.split(" ")) if src != "" else 0
+ tgtlen = len(tgt.split(" ")) if tgt != "" else 0
if (
(args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
- or (args.maxlen is not None and (srclen > args.maxlen or tgtlen > args.maxlen))
- or (args.ratio is not None and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio))
+ or (
+ args.maxlen is not None
+ and (srclen > args.maxlen or tgtlen > args.maxlen)
+ )
+ or (
+ args.ratio is not None
+ and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)
+ )
):
return False
return True
@@ -41,19 +53,20 @@ def safe_index(toks, index, default):
except IndexError:
return default
- with open(args.output + '.' + args.srclang, 'w') as src_h, \
- open(args.output + '.' + args.tgtlang, 'w') as tgt_h:
+ with open(args.output + "." + args.srclang, "w") as src_h, open(
+ args.output + "." + args.tgtlang, "w"
+ ) as tgt_h:
for line in tqdm(fileinput.input(args.files)):
- if line.startswith('S-'):
- tgt = safe_index(line.rstrip().split('\t'), 1, '')
- elif line.startswith('H-'):
+ if line.startswith("S-"):
+ tgt = safe_index(line.rstrip().split("\t"), 1, "")
+ elif line.startswith("H-"):
if tgt is not None:
- src = safe_index(line.rstrip().split('\t'), 2, '')
+ src = safe_index(line.rstrip().split("\t"), 2, "")
if validate(src, tgt):
print(src, file=src_h)
print(tgt, file=tgt_h)
tgt = None
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/byte_level_bpe/get_bitext.py b/examples/byte_level_bpe/get_bitext.py
index 7770ea667b..6ac1eeec1e 100644
--- a/examples/byte_level_bpe/get_bitext.py
+++ b/examples/byte_level_bpe/get_bitext.py
@@ -4,203 +4,251 @@
# LICENSE file in the root directory of this source tree.
-import os.path as op
import argparse
import os
-from multiprocessing import cpu_count
+import os.path as op
from collections import namedtuple
-from typing import Optional, List
+from multiprocessing import cpu_count
+from typing import List, Optional
import sentencepiece as sp
-
-from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
-from fairseq.data.encoders.byte_utils import byte_encode
-from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
-from fairseq.data.encoders.characters import Characters
from fairseq.data.encoders.byte_bpe import ByteBPE
+from fairseq.data.encoders.byte_utils import byte_encode
from fairseq.data.encoders.bytes import Bytes
+from fairseq.data.encoders.characters import Characters
+from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
+from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
-SPLITS = ['train', 'valid', 'test']
+SPLITS = ["train", "valid", "test"]
def _convert_xml(in_path: str, out_path: str):
- with open(in_path) as f, open(out_path, 'w') as f_o:
+ with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
ss = s.strip()
- if not ss.startswith('', '').split('">')
+ ss = ss.replace("", "").split('">')
assert len(ss) == 2
- f_o.write(ss[1].strip() + '\n')
+ f_o.write(ss[1].strip() + "\n")
def _convert_train(in_path: str, out_path: str):
- with open(in_path) as f, open(out_path, 'w') as f_o:
+ with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
ss = s.strip()
- if ss.startswith('<'):
+ if ss.startswith("<"):
continue
- f_o.write(ss.strip() + '\n')
+ f_o.write(ss.strip() + "\n")
def _get_bytes(in_path: str, out_path: str):
- with open(in_path) as f, open(out_path, 'w') as f_o:
+ with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
- f_o.write(Bytes.encode(s.strip()) + '\n')
+ f_o.write(Bytes.encode(s.strip()) + "\n")
def _get_chars(in_path: str, out_path: str):
- with open(in_path) as f, open(out_path, 'w') as f_o:
+ with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
- f_o.write(Characters.encode(s.strip()) + '\n')
+ f_o.write(Characters.encode(s.strip()) + "\n")
def pretokenize(in_path: str, out_path: str, src: str, tgt: str):
- Args = namedtuple('Args', ['moses_source_lang', 'moses_target_lang',
- 'moses_no_dash_splits', 'moses_no_escape'])
- args = Args(moses_source_lang=src, moses_target_lang=tgt,
- moses_no_dash_splits=False, moses_no_escape=False)
+ Args = namedtuple(
+ "Args",
+ [
+ "moses_source_lang",
+ "moses_target_lang",
+ "moses_no_dash_splits",
+ "moses_no_escape",
+ ],
+ )
+ args = Args(
+ moses_source_lang=src,
+ moses_target_lang=tgt,
+ moses_no_dash_splits=False,
+ moses_no_escape=False,
+ )
pretokenizer = MosesTokenizer(args)
- with open(in_path) as f, open(out_path, 'w') as f_o:
+ with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
- f_o.write(pretokenizer.encode(s.strip()) + '\n')
+ f_o.write(pretokenizer.encode(s.strip()) + "\n")
def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str):
- with open(out_path, 'w') as f_o:
+ with open(out_path, "w") as f_o:
for lang in [src, tgt]:
- with open(f'{in_path_prefix}.{lang}') as f:
+ with open(f"{in_path_prefix}.{lang}") as f:
for s in f:
- f_o.write(byte_encode(s.strip()) + '\n')
+ f_o.write(byte_encode(s.strip()) + "\n")
def _get_bpe(in_path: str, model_prefix: str, vocab_size: int):
arguments = [
- f'--input={in_path}', f'--model_prefix={model_prefix}',
- f'--model_type=bpe', f'--vocab_size={vocab_size}',
- '--character_coverage=1.0', '--normalization_rule_name=identity',
- f'--num_threads={cpu_count()}'
+ f"--input={in_path}",
+ f"--model_prefix={model_prefix}",
+ f"--model_type=bpe",
+ f"--vocab_size={vocab_size}",
+ "--character_coverage=1.0",
+ "--normalization_rule_name=identity",
+ f"--num_threads={cpu_count()}",
]
- sp.SentencePieceTrainer.Train(' '.join(arguments))
+ sp.SentencePieceTrainer.Train(" ".join(arguments))
def _apply_bbpe(model_path: str, in_path: str, out_path: str):
- Args = namedtuple('Args', ['sentencepiece_model_path'])
+ Args = namedtuple("Args", ["sentencepiece_model_path"])
args = Args(sentencepiece_model_path=model_path)
tokenizer = ByteBPE(args)
- with open(in_path) as f, open(out_path, 'w') as f_o:
+ with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
- f_o.write(tokenizer.encode(s.strip()) + '\n')
+ f_o.write(tokenizer.encode(s.strip()) + "\n")
def _apply_bpe(model_path: str, in_path: str, out_path: str):
- Args = namedtuple('Args', ['sentencepiece_model'])
+ Args = namedtuple("Args", ["sentencepiece_model"])
args = Args(sentencepiece_model=model_path)
tokenizer = SentencepieceBPE(args)
- with open(in_path) as f, open(out_path, 'w') as f_o:
+ with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
- f_o.write(tokenizer.encode(s.strip()) + '\n')
+ f_o.write(tokenizer.encode(s.strip()) + "\n")
def _concat_files(in_paths: List[str], out_path: str):
- with open(out_path, 'w') as f_o:
+ with open(out_path, "w") as f_o:
for p in in_paths:
with open(p) as f:
for r in f:
f_o.write(r)
-def preprocess_iwslt17(root: str, src: str, tgt: str, bpe_size: Optional[int],
- need_chars: bool, bbpe_size: Optional[int],
- need_bytes: bool):
+def preprocess_iwslt17(
+ root: str,
+ src: str,
+ tgt: str,
+ bpe_size: Optional[int],
+ need_chars: bool,
+ bbpe_size: Optional[int],
+ need_bytes: bool,
+):
# extract bitext
- in_root = op.join(root, f'{src}-{tgt}')
+ in_root = op.join(root, f"{src}-{tgt}")
for lang in [src, tgt]:
_convert_train(
- op.join(in_root, f'train.tags.{src}-{tgt}.{lang}'),
- op.join(root, f'train.{lang}')
+ op.join(in_root, f"train.tags.{src}-{tgt}.{lang}"),
+ op.join(root, f"train.{lang}"),
)
_convert_xml(
- op.join(in_root, f'IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml'),
- op.join(root, f'valid.{lang}')
+ op.join(in_root, f"IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml"),
+ op.join(root, f"valid.{lang}"),
)
_convert_xml(
- op.join(in_root, f'IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml'),
- op.join(root, f'test.{lang}')
+ op.join(in_root, f"IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml"),
+ op.join(root, f"test.{lang}"),
)
# pre-tokenize
for lang in [src, tgt]:
for split in SPLITS:
- pretokenize(op.join(root, f'{split}.{lang}'),
- op.join(root, f'{split}.moses.{lang}'), src, tgt)
+ pretokenize(
+ op.join(root, f"{split}.{lang}"),
+ op.join(root, f"{split}.moses.{lang}"),
+ src,
+ tgt,
+ )
# tokenize with BPE vocabulary
if bpe_size is not None:
# learn vocabulary
- concated_train_path = op.join(root, 'train.all')
+ concated_train_path = op.join(root, "train.all")
_concat_files(
- [op.join(root, 'train.moses.fr'), op.join(root, 'train.moses.en')],
- concated_train_path
+ [op.join(root, "train.moses.fr"), op.join(root, "train.moses.en")],
+ concated_train_path,
)
- bpe_model_prefix = op.join(root, f'spm_bpe{bpe_size}')
+ bpe_model_prefix = op.join(root, f"spm_bpe{bpe_size}")
_get_bpe(concated_train_path, bpe_model_prefix, bpe_size)
os.remove(concated_train_path)
# apply
for lang in [src, tgt]:
for split in SPLITS:
_apply_bpe(
- bpe_model_prefix + '.model',
- op.join(root, f'{split}.moses.{lang}'),
- op.join(root, f'{split}.moses.bpe{bpe_size}.{lang}')
+ bpe_model_prefix + ".model",
+ op.join(root, f"{split}.moses.{lang}"),
+ op.join(root, f"{split}.moses.bpe{bpe_size}.{lang}"),
)
# tokenize with bytes vocabulary
if need_bytes:
for lang in [src, tgt]:
for split in SPLITS:
- _get_bytes(op.join(root, f'{split}.moses.{lang}'),
- op.join(root, f'{split}.moses.bytes.{lang}'))
+ _get_bytes(
+ op.join(root, f"{split}.moses.{lang}"),
+ op.join(root, f"{split}.moses.bytes.{lang}"),
+ )
# tokenize with characters vocabulary
if need_chars:
for lang in [src, tgt]:
for split in SPLITS:
- _get_chars(op.join(root, f'{split}.moses.{lang}'),
- op.join(root, f'{split}.moses.chars.{lang}'))
+ _get_chars(
+ op.join(root, f"{split}.moses.{lang}"),
+ op.join(root, f"{split}.moses.chars.{lang}"),
+ )
# tokenize with byte-level BPE vocabulary
if bbpe_size is not None:
# learn vocabulary
- bchar_path = op.join(root, 'train.bchar')
- _convert_to_bchar(op.join(root, 'train.moses'), src, tgt, bchar_path)
- bbpe_model_prefix = op.join(root, f'spm_bbpe{bbpe_size}')
+ bchar_path = op.join(root, "train.bchar")
+ _convert_to_bchar(op.join(root, "train.moses"), src, tgt, bchar_path)
+ bbpe_model_prefix = op.join(root, f"spm_bbpe{bbpe_size}")
_get_bpe(bchar_path, bbpe_model_prefix, bbpe_size)
os.remove(bchar_path)
# apply
for lang in [src, tgt]:
for split in SPLITS:
_apply_bbpe(
- bbpe_model_prefix + '.model',
- op.join(root, f'{split}.moses.{lang}'),
- op.join(root, f'{split}.moses.bbpe{bbpe_size}.{lang}')
+ bbpe_model_prefix + ".model",
+ op.join(root, f"{split}.moses.{lang}"),
+ op.join(root, f"{split}.moses.bbpe{bbpe_size}.{lang}"),
)
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('--root', type=str, default='data')
- parser.add_argument('--bpe-vocab', default=None, type=int,
- help='Generate tokenized bitext with BPE of size K.'
- 'Default to None (disabled).')
- parser.add_argument('--bbpe-vocab', default=None, type=int,
- help='Generate tokenized bitext with BBPE of size K.'
- 'Default to None (disabled).')
- parser.add_argument('--byte-vocab', action='store_true',
- help='Generate tokenized bitext with bytes vocabulary')
- parser.add_argument('--char-vocab', action='store_true',
- help='Generate tokenized bitext with chars vocabulary')
+ parser.add_argument("--root", type=str, default="data")
+ parser.add_argument(
+ "--bpe-vocab",
+ default=None,
+ type=int,
+ help="Generate tokenized bitext with BPE of size K."
+ "Default to None (disabled).",
+ )
+ parser.add_argument(
+ "--bbpe-vocab",
+ default=None,
+ type=int,
+ help="Generate tokenized bitext with BBPE of size K."
+ "Default to None (disabled).",
+ )
+ parser.add_argument(
+ "--byte-vocab",
+ action="store_true",
+ help="Generate tokenized bitext with bytes vocabulary",
+ )
+ parser.add_argument(
+ "--char-vocab",
+ action="store_true",
+ help="Generate tokenized bitext with chars vocabulary",
+ )
args = parser.parse_args()
- preprocess_iwslt17(args.root, 'fr', 'en', args.bpe_vocab, args.char_vocab,
- args.bbpe_vocab, args.byte_vocab)
+ preprocess_iwslt17(
+ args.root,
+ "fr",
+ "en",
+ args.bpe_vocab,
+ args.char_vocab,
+ args.bbpe_vocab,
+ args.byte_vocab,
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/byte_level_bpe/gru_transformer.py b/examples/byte_level_bpe/gru_transformer.py
index 7ba8e4084f..d4efa93a4d 100644
--- a/examples/byte_level_bpe/gru_transformer.py
+++ b/examples/byte_level_bpe/gru_transformer.py
@@ -11,7 +11,7 @@
import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import register_model, register_model_architecture
-from fairseq.models.transformer import TransformerModel, TransformerEncoder
+from fairseq.models.transformer import TransformerEncoder, TransformerModel
@register_model("gru_transformer")
@@ -24,9 +24,12 @@ def build_encoder(cls, args, src_dict, embed_tokens):
class GRUTransformerEncoder(TransformerEncoder):
def __init__(self, args, dictionary, embed_tokens):
super().__init__(args, dictionary, embed_tokens)
- self.emb_ctx = nn.GRU(input_size=embed_tokens.embedding_dim,
- hidden_size=embed_tokens.embedding_dim // 2,
- num_layers=1, bidirectional=True)
+ self.emb_ctx = nn.GRU(
+ input_size=embed_tokens.embedding_dim,
+ hidden_size=embed_tokens.embedding_dim // 2,
+ num_layers=1,
+ bidirectional=True,
+ )
def forward_embedding(self, src_tokens):
# embed tokens and positions
diff --git a/examples/constrained_decoding/normalize.py b/examples/constrained_decoding/normalize.py
index 2a7ae03102..4ae2b5111b 100755
--- a/examples/constrained_decoding/normalize.py
+++ b/examples/constrained_decoding/normalize.py
@@ -16,11 +16,12 @@ def main(args):
print(normalizer.normalize(line.rstrip()), flush=True)
-if __name__ == '__main__':
+if __name__ == "__main__":
import argparse
+
parser = argparse.ArgumentParser()
- parser.add_argument('--lang', '-l', default='en')
- parser.add_argument('--penn', '-p', action='store_true')
+ parser.add_argument("--lang", "-l", default="en")
+ parser.add_argument("--penn", "-p", action="store_true")
args = parser.parse_args()
main(args)
diff --git a/examples/constrained_decoding/tok.py b/examples/constrained_decoding/tok.py
index 9215a66538..b1f888a8c0 100755
--- a/examples/constrained_decoding/tok.py
+++ b/examples/constrained_decoding/tok.py
@@ -6,12 +6,14 @@
# LICENSE file in the root directory of this source tree.
import sys
+
import sacremoses
def main(args):
"""Tokenizes, preserving tabs"""
mt = sacremoses.MosesTokenizer(lang=args.lang)
+
def tok(s):
return mt.tokenize(s, return_str=True)
@@ -20,12 +22,13 @@ def tok(s):
print(*parts, sep="\t", flush=True)
-if __name__ == '__main__':
+if __name__ == "__main__":
import argparse
+
parser = argparse.ArgumentParser()
- parser.add_argument('--lang', '-l', default='en')
- parser.add_argument('--penn', '-p', action='store_true')
- parser.add_argument('--fields', '-f', help="fields to tokenize")
+ parser.add_argument("--lang", "-l", default="en")
+ parser.add_argument("--penn", "-p", action="store_true")
+ parser.add_argument("--fields", "-f", help="fields to tokenize")
args = parser.parse_args()
main(args)
diff --git a/examples/criss/mining/mine.py b/examples/criss/mining/mine.py
index a902a4ab64..c86f73ae87 100644
--- a/examples/criss/mining/mine.py
+++ b/examples/criss/mining/mine.py
@@ -3,14 +3,15 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import faiss
-import numpy as np
-import glob
import argparse
+import glob
from subprocess import check_call
+import faiss
+import numpy as np
+
-GB = 1024*1024*1024
+GB = 1024 * 1024 * 1024
def call(cmd):
@@ -18,14 +19,14 @@ def call(cmd):
check_call(cmd, shell=True)
-def get_batches(directory, lang, prefix='all_avg_pool'):
+def get_batches(directory, lang, prefix="all_avg_pool"):
print(f"Finding in {directory}/{prefix}.{lang}*")
- files = glob.glob(f'{directory}/{prefix}.{lang}*')
+ files = glob.glob(f"{directory}/{prefix}.{lang}*")
emb_files = []
txt_files = []
for emb_fi in files:
emb_files.append(emb_fi)
- txt_fi = emb_fi.replace(prefix, 'sentences')
+ txt_fi = emb_fi.replace(prefix, "sentences")
txt_files.append(txt_fi)
return emb_files, txt_files
@@ -38,7 +39,7 @@ def load_batch(emb_file, dim):
return embeddings
-def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction='x2y'):
+def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"):
sims = []
inds = []
xfrom = 0
@@ -53,7 +54,7 @@ def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction='x2y'):
y_batch = load_batch(y_batch_f, dim)
neighbor_size = min(k, y_batch.shape[0])
yto = yfrom + y_batch.shape[0]
- print('{}-{} -> {}-{}'.format(xfrom, xto, yfrom, yto))
+ print("{}-{} -> {}-{}".format(xfrom, xto, yfrom, yto))
idx = faiss.IndexFlatIP(dim)
idx = faiss.index_cpu_to_all_gpus(idx)
idx.add(y_batch)
@@ -86,8 +87,10 @@ def score(sim, fwd_mean, bwd_mean, margin):
return margin(sim, (fwd_mean + bwd_mean) / 2)
-def score_candidates(sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False):
- print(' - scoring {:d} candidates'.format(sim_mat.shape[0]))
+def score_candidates(
+ sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False
+):
+ print(" - scoring {:d} candidates".format(sim_mat.shape[0]))
scores = np.zeros(candidate_inds.shape)
for i in range(scores.shape[0]):
for j in range(scores.shape[1]):
@@ -106,42 +109,50 @@ def load_text(files):
return all_sentences
-if __name__ == '__main__':
- parser = argparse.ArgumentParser(description='Mine bitext')
- parser.add_argument('--src-lang', help='Source language')
- parser.add_argument('--tgt-lang', help='Target language')
- parser.add_argument('--dict-path', help='Path to dictionary file', default='dict.txt')
- parser.add_argument('--spm-path', help='Path to SPM model file', default='sentence.bpe.model')
- parser.add_argument('--dim', type=int, default=1024,
- help='Embedding dimension')
- parser.add_argument('--mem', type=int, default=5,
- help='Memory in GB')
- parser.add_argument('--src-dir', help='Source directory')
- parser.add_argument('--tgt-dir', help='Target directory')
- parser.add_argument('--output', help='Output path')
- parser.add_argument('--neighborhood', type=int, default=4,
- help='Embedding dimension')
- parser.add_argument('--threshold', type=float, default=1.06,
- help='Threshold on mined bitext')
- parser.add_argument('--valid-size', type=int, default=2000,
- help='Number of sentences used for validation set')
- parser.add_argument('--min-count', type=int, default=50000,
- help='Min num sentences used for each language')
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Mine bitext")
+ parser.add_argument("--src-lang", help="Source language")
+ parser.add_argument("--tgt-lang", help="Target language")
+ parser.add_argument(
+ "--dict-path", help="Path to dictionary file", default="dict.txt"
+ )
+ parser.add_argument(
+ "--spm-path", help="Path to SPM model file", default="sentence.bpe.model"
+ )
+ parser.add_argument("--dim", type=int, default=1024, help="Embedding dimension")
+ parser.add_argument("--mem", type=int, default=5, help="Memory in GB")
+ parser.add_argument("--src-dir", help="Source directory")
+ parser.add_argument("--tgt-dir", help="Target directory")
+ parser.add_argument("--output", help="Output path")
+ parser.add_argument(
+ "--neighborhood", type=int, default=4, help="Embedding dimension"
+ )
+ parser.add_argument(
+ "--threshold", type=float, default=1.06, help="Threshold on mined bitext"
+ )
+ parser.add_argument(
+ "--valid-size",
+ type=int,
+ default=2000,
+ help="Number of sentences used for validation set",
+ )
+ parser.add_argument(
+ "--min-count",
+ type=int,
+ default=50000,
+ help="Min num sentences used for each language",
+ )
args = parser.parse_args()
x_batches_f, x_sents_f = get_batches(args.src_dir, args.src_lang)
y_batches_f, y_sents_f = get_batches(args.tgt_dir, args.tgt_lang)
margin = lambda a, b: a / b
y2x_sim, y2x_ind = knnGPU_sharded(
- y_batches_f, x_batches_f,
- args.dim,
- args.neighborhood,
- direction='y2x')
+ y_batches_f, x_batches_f, args.dim, args.neighborhood, direction="y2x"
+ )
x2y_sim, x2y_ind = knnGPU_sharded(
- x_batches_f, y_batches_f,
- args.dim,
- args.neighborhood,
- direction='x2y')
+ x_batches_f, y_batches_f, args.dim, args.neighborhood, direction="x2y"
+ )
x2y_mean = x2y_sim.mean(axis=1)
y2x_mean = y2x_sim.mean(axis=1)
@@ -149,8 +160,13 @@ def load_text(files):
bwd_scores = score_candidates(y2x_sim, y2x_ind, y2x_mean, x2y_mean, margin)
fwd_best = x2y_ind[np.arange(x2y_sim.shape[0]), fwd_scores.argmax(axis=1)]
bwd_best = y2x_ind[np.arange(y2x_sim.shape[0]), bwd_scores.argmax(axis=1)]
- indices = np.stack((np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
- np.concatenate((fwd_best, np.arange(y2x_ind.shape[0])))), axis=1)
+ indices = np.stack(
+ (
+ np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
+ np.concatenate((fwd_best, np.arange(y2x_ind.shape[0]))),
+ ),
+ axis=1,
+ )
scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1)))
x_sentences = load_text(x_sents_f)
@@ -162,20 +178,20 @@ def load_text(files):
directory = args.output
call(f"mkdir -p {directory}")
src_out = open(
- f'{directory}/all.{args.src_lang}',
- mode='w',
- encoding='utf-8',
- errors='surrogateescape')
+ f"{directory}/all.{args.src_lang}",
+ mode="w",
+ encoding="utf-8",
+ errors="surrogateescape",
+ )
tgt_out = open(
- f'{directory}/all.{args.tgt_lang}',
- mode='w',
- encoding='utf-8',
- errors='surrogateescape')
+ f"{directory}/all.{args.tgt_lang}",
+ mode="w",
+ encoding="utf-8",
+ errors="surrogateescape",
+ )
scores_out = open(
- f'{directory}/all.scores',
- mode='w',
- encoding='utf-8',
- errors='surrogateescape')
+ f"{directory}/all.scores", mode="w", encoding="utf-8", errors="surrogateescape"
+ )
count = 0
for i in np.argsort(-scores):
src_ind, trg_ind = indices[i]
@@ -195,20 +211,23 @@ def load_text(files):
scores_out.close()
print(f"Found {count} pairs for threshold={threshold}")
- with open(f'{directory}/all.{args.src_lang}') as all_s, \
- open(f'{directory}/all.{args.tgt_lang}') as all_t, \
- open(f'{directory}/valid.{args.src_lang}', 'w') as valid_s, \
- open(f'{directory}/valid.{args.tgt_lang}', 'w') as valid_t, \
- open(f'{directory}/train.{args.src_lang}', 'w') as train_s, \
- open(f'{directory}/train.{args.tgt_lang}', 'w') as train_t:
- count = 0
- for s_line, t_line in zip(all_s, all_t):
- s_line = s_line.split('\t')[1]
- t_line = t_line.split('\t')[1]
- if count >= args.valid_size:
- train_s.write(s_line)
- train_t.write(t_line)
- else:
- valid_s.write(s_line)
- valid_t.write(t_line)
- count += 1
+ with open(f"{directory}/all.{args.src_lang}") as all_s, open(
+ f"{directory}/all.{args.tgt_lang}"
+ ) as all_t, open(f"{directory}/valid.{args.src_lang}", "w") as valid_s, open(
+ f"{directory}/valid.{args.tgt_lang}", "w"
+ ) as valid_t, open(
+ f"{directory}/train.{args.src_lang}", "w"
+ ) as train_s, open(
+ f"{directory}/train.{args.tgt_lang}", "w"
+ ) as train_t:
+ count = 0
+ for s_line, t_line in zip(all_s, all_t):
+ s_line = s_line.split("\t")[1]
+ t_line = t_line.split("\t")[1]
+ if count >= args.valid_size:
+ train_s.write(s_line)
+ train_t.write(t_line)
+ else:
+ valid_s.write(s_line)
+ valid_t.write(t_line)
+ count += 1
diff --git a/examples/criss/save_encoder.py b/examples/criss/save_encoder.py
index 8132bbf0fa..4d0f17f0f2 100644
--- a/examples/criss/save_encoder.py
+++ b/examples/criss/save_encoder.py
@@ -7,27 +7,29 @@
Translate pre-processed data with a trained model.
"""
+import numpy as np
import torch
-
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.sequence_generator import EnsembleModel
-import numpy as np
-def get_avg_pool(models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False):
+def get_avg_pool(
+ models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False
+):
model = EnsembleModel(models)
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
- k: v for k, v in sample['net_input'].items()
- if k != 'prev_output_tokens'
+ k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
}
# compute the encoder output for each beam
encoder_outs = model.forward_encoder(encoder_input)
np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
- encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(np.float32)
+ encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(
+ np.float32
+ )
encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
if has_langtok:
encoder_mask = encoder_mask[1:, :, :]
@@ -38,13 +40,15 @@ def get_avg_pool(models, sample, prefix_tokens, src_dict, remove_bpe, has_langto
def main(args):
- assert args.path is not None, '--path required for generation!'
- assert not args.sampling or args.nbest == args.beam, \
- '--sampling requires --nbest to be equal to --beam'
- assert args.replace_unk is None or args.raw_text, \
- '--replace-unk requires a raw text dataset (--raw-text)'
-
- args.beam=1
+ assert args.path is not None, "--path required for generation!"
+ assert (
+ not args.sampling or args.nbest == args.beam
+ ), "--sampling requires --nbest to be equal to --beam"
+ assert (
+ args.replace_unk is None or args.raw_text
+ ), "--replace-unk requires a raw text dataset (--raw-text)"
+
+ args.beam = 1
utils.import_user_module(args)
if args.max_tokens is None:
@@ -58,15 +62,15 @@ def main(args):
# Set dictionaries
try:
- src_dict = getattr(task, 'source_dictionary', None)
+ src_dict = getattr(task, "source_dictionary", None)
except NotImplementedError:
src_dict = None
tgt_dict = task.target_dictionary
# Load ensemble
- print('| loading model(s) from {}'.format(args.path))
+ print("| loading model(s) from {}".format(args.path))
models, _model_args = checkpoint_utils.load_model_ensemble(
- args.path.split(':'),
+ args.path.split(":"),
arg_overrides=eval(args.model_overrides),
task=task,
)
@@ -105,9 +109,9 @@ def main(args):
shard_id = 0
all_avg_pool = None
encoder_has_langtok = (
- hasattr(task.args, 'encoder_langtok')
+ hasattr(task.args, "encoder_langtok")
and task.args.encoder_langtok is not None
- and hasattr(task.args, 'lang_tok_replacing_bos_eos')
+ and hasattr(task.args, "lang_tok_replacing_bos_eos")
and not task.args.lang_tok_replacing_bos_eos
)
with progress_bar.build_progress_bar(args, itr) as t:
@@ -116,34 +120,42 @@ def main(args):
print("Skipping None")
continue
sample = utils.move_to_cuda(sample) if use_cuda else sample
- if 'net_input' not in sample:
+ if "net_input" not in sample:
continue
prefix_tokens = None
if args.prefix_size > 0:
- prefix_tokens = sample['target'][:, :args.prefix_size]
+ prefix_tokens = sample["target"][:, : args.prefix_size]
with torch.no_grad():
avg_pool = get_avg_pool(
- models, sample, prefix_tokens, src_dict,
- args.remove_bpe,
- has_langtok=encoder_has_langtok)
+ models,
+ sample,
+ prefix_tokens,
+ src_dict,
+ args.remove_bpe,
+ has_langtok=encoder_has_langtok,
+ )
if all_avg_pool is not None:
all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
else:
all_avg_pool = avg_pool
- if not isinstance(sample['id'], list):
- sample_ids = sample['id'].tolist()
+ if not isinstance(sample["id"], list):
+ sample_ids = sample["id"].tolist()
else:
- sample_ids = sample['id']
+ sample_ids = sample["id"]
for i, sample_id in enumerate(sample_ids):
# Remove padding
- src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
+ src_tokens = utils.strip_pad(
+ sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
+ )
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
- src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
+ src_str = task.dataset(args.gen_subset).src.get_original_text(
+ sample_id
+ )
else:
if src_dict is not None:
src_str = src_dict.string(src_tokens, args.remove_bpe)
@@ -152,37 +164,50 @@ def main(args):
if not args.quiet:
if src_dict is not None:
- print('S-{}\t{}'.format(sample_id, src_str))
+ print("S-{}\t{}".format(sample_id, src_str))
source_sentences.append(f"{sample_id}\t{src_str}")
- num_sentences += sample['nsentences']
+ num_sentences += sample["nsentences"]
if all_avg_pool.shape[0] >= 1000000:
- with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}',
- 'w') as avg_pool_file:
+ with open(
+ f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}",
+ "w",
+ ) as avg_pool_file:
all_avg_pool.tofile(avg_pool_file)
- with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file:
- sentence_file.writelines(f'{line}\n' for line in source_sentences)
+ with open(
+ f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}",
+ "w",
+ ) as sentence_file:
+ sentence_file.writelines(f"{line}\n" for line in source_sentences)
all_avg_pool = None
source_sentences = []
shard_id += 1
if all_avg_pool is not None:
- with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}',
- 'w') as avg_pool_file:
+ with open(
+ f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w"
+ ) as avg_pool_file:
all_avg_pool.tofile(avg_pool_file)
- with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file:
- sentence_file.writelines(f'{line}\n' for line in source_sentences)
+ with open(
+ f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w"
+ ) as sentence_file:
+ sentence_file.writelines(f"{line}\n" for line in source_sentences)
return None
def cli_main():
parser = options.get_generation_parser()
- parser.add_argument('--encoder-save-dir', default='', type=str, metavar='N',
- help='directory to save encoder outputs')
+ parser.add_argument(
+ "--encoder-save-dir",
+ default="",
+ type=str,
+ metavar="N",
+ help="directory to save encoder outputs",
+ )
args = options.parse_args_and_arch(parser)
main(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
cli_main()
diff --git a/examples/criss/sentence_retrieval/encoder_analysis.py b/examples/criss/sentence_retrieval/encoder_analysis.py
index c0d74af23a..b41bfbe387 100644
--- a/examples/criss/sentence_retrieval/encoder_analysis.py
+++ b/examples/criss/sentence_retrieval/encoder_analysis.py
@@ -3,10 +3,11 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import numpy as np
import argparse
import glob
+import numpy as np
+
DIM = 1024
@@ -14,9 +15,13 @@
def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False):
target_ids = [tid for tid in target_embs]
source_mat = np.stack(source_embs.values(), axis=0)
- normalized_source_mat = source_mat / np.linalg.norm(source_mat, axis=1, keepdims=True)
+ normalized_source_mat = source_mat / np.linalg.norm(
+ source_mat, axis=1, keepdims=True
+ )
target_mat = np.stack(target_embs.values(), axis=0)
- normalized_target_mat = target_mat / np.linalg.norm(target_mat, axis=1, keepdims=True)
+ normalized_target_mat = target_mat / np.linalg.norm(
+ target_mat, axis=1, keepdims=True
+ )
sim_mat = normalized_source_mat.dot(normalized_target_mat.T)
if return_sim_mat:
return sim_mat
@@ -36,14 +41,14 @@ def load_embeddings(directory, LANGS):
lang_dir = f"{directory}/{lang}"
embedding_files = glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*")
for embed_file in embedding_files:
- shard_id = embed_file.split('.')[-1]
+ shard_id = embed_file.split(".")[-1]
embeddings = np.fromfile(embed_file, dtype=np.float32)
num_rows = embeddings.shape[0] // DIM
embeddings = embeddings.reshape((num_rows, DIM))
- with open(f'{lang_dir}/sentences.{lang}.{shard_id}') as sentence_file:
+ with open(f"{lang_dir}/sentences.{lang}.{shard_id}") as sentence_file:
for idx, line in enumerate(sentence_file):
- sentence_id, sentence = line.strip().split('\t')
+ sentence_id, sentence = line.strip().split("\t")
sentence_texts[lang][sentence_id] = sentence
sentence_embeddings[lang][sentence_id] = embeddings[idx, :]
@@ -55,7 +60,7 @@ def compute_accuracy(directory, LANGS):
top_1_accuracy = {}
- top1_str = " ".join(LANGS) + '\n'
+ top1_str = " ".join(LANGS) + "\n"
for source_lang in LANGS:
top_1_accuracy[source_lang] = {}
top1_str += f"{source_lang} "
@@ -63,8 +68,8 @@ def compute_accuracy(directory, LANGS):
top1 = 0
top5 = 0
neighbors_map = compute_dist(
- sentence_embeddings[source_lang],
- sentence_embeddings[target_lang])
+ sentence_embeddings[source_lang], sentence_embeddings[target_lang]
+ )
for sentence_id, neighbors in neighbors_map.items():
if sentence_id == neighbors[0]:
top1 += 1
@@ -75,17 +80,13 @@ def compute_accuracy(directory, LANGS):
top1_str += "\n"
print(top1_str)
- print(top1_str, file=open(f"{directory}/accuracy", 'w'))
+ print(top1_str, file=open(f"{directory}/accuracy", "w"))
if __name__ == "__main__":
- parser = argparse.ArgumentParser(description='Analyze encoder outputs')
- parser.add_argument('directory',
- help='Source language corpus'
- )
- parser.add_argument('--langs',
- help='List of langs'
- )
+ parser = argparse.ArgumentParser(description="Analyze encoder outputs")
+ parser.add_argument("directory", help="Source language corpus")
+ parser.add_argument("--langs", help="List of langs")
args = parser.parse_args()
- langs = args.langs.split(',')
+ langs = args.langs.split(",")
compute_accuracy(args.directory, langs)
diff --git a/examples/latent_depth/src/__init__.py b/examples/latent_depth/src/__init__.py
index 8a86fa5817..c5fa76039f 100644
--- a/examples/latent_depth/src/__init__.py
+++ b/examples/latent_depth/src/__init__.py
@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .models import latent_multilingual_transformer # noqa
-from .modules import latent_layers # noqa
-from .loss import latent_depth # noqa
-from . import multilingual_translation_latent_depth # noqa
+from . import multilingual_translation_latent_depth # noqa
+from .loss import latent_depth # noqa
+from .models import latent_multilingual_transformer # noqa
+from .modules import latent_layers # noqa
diff --git a/examples/latent_depth/src/loss/latent_depth.py b/examples/latent_depth/src/loss/latent_depth.py
index f647c758ee..a3b9535eca 100644
--- a/examples/latent_depth/src/loss/latent_depth.py
+++ b/examples/latent_depth/src/loss/latent_depth.py
@@ -3,8 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import torch
import math
+
+import torch
from torch.nn.modules.loss import _Loss
@@ -19,17 +20,16 @@ def forward(self, layer_samples, lang_idx, update_num, sample_size):
eps = 1e-7
if prior == "uniform":
# uniform prior
- kl_loss = (samples * (
- torch.log(samples + eps) - math.log(0.5)
- )).sum(-1)
+ kl_loss = (samples * (torch.log(samples + eps) - math.log(0.5))).sum(-1)
elif prior == "agged_posterior":
# aggregated posterior
y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
agged_q = torch.sum(y_t, dim=0)
row_norm = agged_q.sum(-1)
normed_agg_q = agged_q / row_norm
- kl_loss = (samples * (
- torch.log(samples + eps) - torch.log(normed_agg_q + eps))).sum(-1)
+ kl_loss = (
+ samples * (torch.log(samples + eps) - torch.log(normed_agg_q + eps))
+ ).sum(-1)
else:
raise NotImplementedError("The specified prior is not implemented.")
@@ -37,7 +37,9 @@ def forward(self, layer_samples, lang_idx, update_num, sample_size):
kl_loss /= layer_samples[0].size()[0]
kl_weight = min(
self.args.sparsity_weight,
- (update_num - self.args.soft_update) * self.args.sparsity_weight / self.args.anneal_updates
+ (update_num - self.args.soft_update)
+ * self.args.sparsity_weight
+ / self.args.anneal_updates,
)
kl_loss *= kl_weight * sample_size
return kl_loss
@@ -58,15 +60,17 @@ def forward(self, layer_samples_list, update_num, sample_size):
share_loss = 0
global_sparsity_loss = 0
layer_samples = torch.stack(layer_samples_list, dim=0)
- if ((self.args.target_layers > 0 or self.args.share_weight > 0) and
- update_num > (self.args.soft_update + self.args.anneal_updates)):
+ if (
+ self.args.target_layers > 0 or self.args.share_weight > 0
+ ) and update_num > (self.args.soft_update + self.args.anneal_updates):
# anneal sparsity weight
if update_num < (self.args.anneal_updates + self.args.soft_update):
weight_anneal = 0
elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
weight_anneal = (
(update_num - self.args.soft_update - self.args.anneal_updates)
- * self.args.share_weight / self.args.anneal_updates
+ * self.args.share_weight
+ / self.args.anneal_updates
)
else:
weight_anneal = 1
@@ -75,12 +79,21 @@ def forward(self, layer_samples_list, update_num, sample_size):
layer_utilization /= layer_samples.size()[0]
if self.args.share_weight > 0:
# encouraging sharing across languages
- share_loss = sum(-1.0 * v * math.log(v) for v in layer_utilization if v > 0)
- batch_loss += weight_anneal * self.args.share_weight * sample_size * share_loss
+ share_loss = sum(
+ -1.0 * v * math.log(v) for v in layer_utilization if v > 0
+ )
+ batch_loss += (
+ weight_anneal * self.args.share_weight * sample_size * share_loss
+ )
if self.args.target_layers > 0:
# computed expected number of layers selected
expeted_layers = sum(layer_utilization)
# compute l2 loss wrt target number of layers
global_sparsity_loss = (expeted_layers - self.args.target_layers) ** 2
- batch_loss += weight_anneal * self.args.share_weight * sample_size * global_sparsity_loss
+ batch_loss += (
+ weight_anneal
+ * self.args.share_weight
+ * sample_size
+ * global_sparsity_loss
+ )
return batch_loss
diff --git a/examples/latent_depth/src/models/latent_multilingual_transformer.py b/examples/latent_depth/src/models/latent_multilingual_transformer.py
index 97573cbd75..9e075fcc47 100644
--- a/examples/latent_depth/src/models/latent_multilingual_transformer.py
+++ b/examples/latent_depth/src/models/latent_multilingual_transformer.py
@@ -3,34 +3,31 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from fairseq.models import (
- register_model,
- register_model_architecture,
-)
+from fairseq.models import register_model, register_model_architecture
+from fairseq.models.multilingual_transformer import MultilingualTransformerModel
from fairseq.models.transformer import (
- base_architecture,
- TransformerEncoder,
TransformerDecoder,
+ TransformerEncoder,
+ base_architecture,
)
-from fairseq.models.multilingual_transformer import MultilingualTransformerModel
-from .latent_transformer import (
- LatentTransformerEncoder,
- LatentTransformerDecoder,
-)
+from .latent_transformer import LatentTransformerDecoder, LatentTransformerEncoder
-@register_model('latent_multilingual_transformer')
+@register_model("latent_multilingual_transformer")
class LatentMultilingualTransformerModel(MultilingualTransformerModel):
"""A variant of standard multilingual Transformer models which encoder and/or
- decoders supports latent depth, as is in "Deep Transformer with Latent Depth"
+ decoders supports latent depth, as is in "Deep Transformer with Latent Depth"
(https://arxiv.org/abs/2009.13102).
"""
+
@classmethod
def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
if is_encoder:
if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
- return LatentTransformerEncoder(args, lang_dict, embed_tokens, num_logits=len(langs))
+ return LatentTransformerEncoder(
+ args, lang_dict, embed_tokens, num_logits=len(langs)
+ )
else:
return TransformerEncoder(args, lang_dict, embed_tokens)
else:
@@ -42,19 +39,21 @@ def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
return TransformerDecoder(args, lang_dict, embed_tokens)
-@register_model_architecture('latent_multilingual_transformer', 'latent_multilingual_transformer')
+@register_model_architecture(
+ "latent_multilingual_transformer", "latent_multilingual_transformer"
+)
def latent_multilingual_architecture(args):
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
- args.encoder_layers = getattr(args, 'encoder_layers', 12)
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
- args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
- args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
- args.decoder_layers = getattr(args, 'decoder_layers', 24)
- args.share_encoders = getattr(args, 'share_encoders', True)
- args.share_decoders = getattr(args, 'share_decoders', True)
- args.share_encoder_embeddings = getattr(args, 'share_encoder_embeddings', True)
- args.share_decoder_embeddings = getattr(args, 'share_decoder_embeddings', True)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
+ args.encoder_layers = getattr(args, "encoder_layers", 12)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
+ args.decoder_layers = getattr(args, "decoder_layers", 24)
+ args.share_encoders = getattr(args, "share_encoders", True)
+ args.share_decoders = getattr(args, "share_decoders", True)
+ args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", True)
+ args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", True)
base_architecture(args)
diff --git a/examples/latent_depth/src/models/latent_transformer.py b/examples/latent_depth/src/models/latent_transformer.py
index 5d47340f58..db30239bff 100644
--- a/examples/latent_depth/src/models/latent_transformer.py
+++ b/examples/latent_depth/src/models/latent_transformer.py
@@ -7,26 +7,27 @@
import torch.nn as nn
from fairseq.models.fairseq_encoder import EncoderOut
-from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
-from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
-from ..modules.latent_layers import LayerSelect
+from fairseq.models.transformer import TransformerDecoder, TransformerEncoder
+from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
from torch import Tensor
+from ..modules.latent_layers import LayerSelect
+
class LatentTransformerEncoder(TransformerEncoder):
"""Latent depth (https://arxiv.org/abs/2009.13102) implemented in
TransformerEncoder.
"""
+
def __init__(self, args, dictionary, embed_tokens, num_logits=1):
self.num_logits = num_logits
self.num_layers = args.encoder_layers
super().__init__(args, dictionary, embed_tokens)
self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
self.lang_idx = None
- self.layers = nn.ModuleList([
- self._build_encoder_layer(args, idx)
- for idx in range(args.encoder_layers)
- ])
+ self.layers = nn.ModuleList(
+ [self._build_encoder_layer(args, idx) for idx in range(args.encoder_layers)]
+ )
def set_lang_idx(self, lang_idx):
self.lang_idx = lang_idx
@@ -50,6 +51,7 @@ class LatentTransformerEncoderLayer(TransformerEncoderLayer):
layer_select (LayerSelect, optional): instance of LayerSelect module with logits
parameters and sampling method.
"""
+
def __init__(self, args, idx, layer_select=None):
super().__init__(args)
self.idx = idx
@@ -63,7 +65,10 @@ class LatentTransformerDecoder(TransformerDecoder):
"""Latent depth (https://arxiv.org/abs/2009.13102) implemented in
TransformerDecoder.
"""
- def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1):
+
+ def __init__(
+ self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1
+ ):
self.num_logits = num_logits
self.num_layers = args.decoder_layers
super().__init__(
@@ -71,16 +76,20 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, num_lo
)
self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
self.lang_idx = None
- self.layers = nn.ModuleList([
- self._build_decoder_layer(args, no_encoder_attn, idx)
- for idx in range(args.decoder_layers)
- ])
+ self.layers = nn.ModuleList(
+ [
+ self._build_decoder_layer(args, no_encoder_attn, idx)
+ for idx in range(args.decoder_layers)
+ ]
+ )
def set_lang_idx(self, lang_idx):
self.lang_idx = lang_idx
def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
- return LatentTransformerDecoderLayer(args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn)
+ return LatentTransformerDecoderLayer(
+ args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn
+ )
def forward(
self,
@@ -119,8 +128,15 @@ class LatentTransformerDecoderLayer(TransformerDecoderLayer):
(default: False).
"""
+
def __init__(
- self, args, idx, layer_select=None, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
+ self,
+ args,
+ idx,
+ layer_select=None,
+ no_encoder_attn=False,
+ add_bias_kv=False,
+ add_zero_attn=False,
):
super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
self.idx = idx
diff --git a/examples/latent_depth/src/modules/latent_layers.py b/examples/latent_depth/src/modules/latent_layers.py
index e772ac3237..a2b8ab4476 100644
--- a/examples/latent_depth/src/modules/latent_layers.py
+++ b/examples/latent_depth/src/modules/latent_layers.py
@@ -12,6 +12,7 @@ class LayerSelect(nn.Module):
either (soft) weighting or (hard) selection of residual connection.
https://arxiv.org/abs/2009.13102
"""
+
def __init__(self, num_layers, num_logits, args):
super(LayerSelect, self).__init__()
self.args = args
@@ -27,14 +28,14 @@ def __init__(self, num_layers, num_logits, args):
@staticmethod
def add_args(parser):
parser.add_argument(
- '--soft-select',
- action='store_true',
- help='use soft samples in training an inference'
+ "--soft-select",
+ action="store_true",
+ help="use soft samples in training an inference",
)
- parser.add_argument('--sampling-tau', type=float, help='sampling temperature')
+ parser.add_argument("--sampling-tau", type=float, help="sampling temperature")
def sample(self, logit_idx):
- """ To leverage the efficiency of distributed training, samples for all
+ """To leverage the efficiency of distributed training, samples for all
layers are computed at once for each logit_idx. Logits are parameters
learnt independent of each other.
@@ -43,7 +44,9 @@ def sample(self, logit_idx):
"""
assert logit_idx is not None
self.samples = self._gumbel_sigmoid(
- self.layer_logits[logit_idx, :].detach() if self.detach_grad else self.layer_logits[logit_idx, :],
+ self.layer_logits[logit_idx, :].detach()
+ if self.detach_grad
+ else self.layer_logits[logit_idx, :],
dim=-1,
tau=self.tau,
hard=self.hard_select,
@@ -54,10 +57,20 @@ def forward(self, i):
sample = self.samples[i]
return sample
- def _gumbel_sigmoid(self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5):
+ def _gumbel_sigmoid(
+ self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5
+ ):
# ~Gumbel(0,1)
- gumbels1 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
- gumbels2 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
+ gumbels1 = (
+ -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
+ .exponential_()
+ .log()
+ )
+ gumbels2 = (
+ -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
+ .exponential_()
+ .log()
+ )
# Difference of two gumbels because we apply a sigmoid
gumbels1 = (logits + gumbels1 - gumbels2) / tau
y_soft = gumbels1.sigmoid()
diff --git a/examples/latent_depth/src/multilingual_translation_latent_depth.py b/examples/latent_depth/src/multilingual_translation_latent_depth.py
index 1a19f8f8f9..b5cd51a470 100644
--- a/examples/latent_depth/src/multilingual_translation_latent_depth.py
+++ b/examples/latent_depth/src/multilingual_translation_latent_depth.py
@@ -5,10 +5,11 @@
from fairseq.tasks import register_task
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
+
from .loss.latent_depth import LatentLayersKLLoss, LatentLayersSparsityLoss
-@register_task('multilingual_translation_latent_depth')
+@register_task("multilingual_translation_latent_depth")
class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
"""A task for multiple translation with latent depth.
@@ -39,7 +40,9 @@ def add_args(parser):
def __init__(self, args, dicts, training):
super().__init__(args, dicts, training)
- self.src_langs, self.tgt_langs = zip(*[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs])
+ self.src_langs, self.tgt_langs = zip(
+ *[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs]
+ )
if self.training and self.encoder_latent_layer:
assert self.args.share_encoders
if self.training and self.decoder_latent_layer:
@@ -47,46 +50,56 @@ def __init__(self, args, dicts, training):
if training or self.encoder_latent_layer or self.decoder_latent_layer:
self.lang_pairs = args.lang_pairs
else:
- self.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
+ self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
self.eval_lang_pairs = self.lang_pairs
self.model_lang_pairs = self.lang_pairs
if self.training and (self.encoder_latent_layer or self.decoder_latent_layer):
self.kl_loss = LatentLayersKLLoss(self.args)
self.sparsity_loss = LatentLayersSparsityLoss(self.args)
- def _per_lang_pair_train_loss(self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad):
+ def _per_lang_pair_train_loss(
+ self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad
+ ):
src, tgt = lang_pair.split("-")
if self.encoder_latent_layer:
src_lang_idx = self.src_lang_idx_dict[src]
model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
- model.models[lang_pair].encoder.layer_select.hard_select = update_num > self.args.soft_update
+ model.models[lang_pair].encoder.layer_select.hard_select = (
+ update_num > self.args.soft_update
+ )
if self.decoder_latent_layer:
tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
- model.models[lang_pair].decoder.layer_select.hard_select = update_num > self.args.soft_update
+ model.models[lang_pair].decoder.layer_select.hard_select = (
+ update_num > self.args.soft_update
+ )
- loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
+ loss, sample_size, logging_output = criterion(
+ model.models[lang_pair], sample[lang_pair]
+ )
if self.encoder_latent_layer:
none_samples = sum(
- 1 if x is None else 0 for x in model.models[lang_pair].encoder.layer_select.layer_samples
+ 1 if x is None else 0
+ for x in model.models[lang_pair].encoder.layer_select.layer_samples
)
if none_samples == 0 or self.args.prior != "agged_posterior":
loss += self.kl_loss(
model.models[lang_pair].encoder.layer_select.layer_samples,
src_lang_idx,
update_num,
- sample_size
+ sample_size,
)
if self.decoder_latent_layer:
none_samples = sum(
- 1 if x is None else 0 for x in model.models[lang_pair].decoder.layer_select.layer_samples
+ 1 if x is None else 0
+ for x in model.models[lang_pair].decoder.layer_select.layer_samples
)
if none_samples == 0 or self.args.prior != "agged_posterior":
loss += self.kl_loss(
model.models[lang_pair].decoder.layer_select.layer_samples,
tgt_lang_idx,
update_num,
- sample_size
+ sample_size,
)
if ignore_grad:
loss *= 0
@@ -99,18 +112,31 @@ def _per_lang_pair_train_loss(self, lang_pair, model, update_num, criterion, sam
return loss, sample_size, logging_output
- def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
+ def train_step(
+ self, sample, model, criterion, optimizer, update_num, ignore_grad=False
+ ):
agg_loss, agg_sample_size, agg_logging_output = super().train_step(
- sample, model, criterion, optimizer, update_num, ignore_grad)
+ sample, model, criterion, optimizer, update_num, ignore_grad
+ )
# compute auxiliary loss from layere sparsity, based on all samples from all languages
if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
sparsity_loss = 0
if self.encoder_latent_layer:
sparsity_loss += self.sparsity_loss(
- next(iter(model.models.values())).encoder.layer_select.layer_samples, update_num, agg_sample_size)
+ next(
+ iter(model.models.values())
+ ).encoder.layer_select.layer_samples,
+ update_num,
+ agg_sample_size,
+ )
if self.decoder_latent_layer:
sparsity_loss += self.sparsity_loss(
- next(iter(model.models.values())).decoder.layer_select.layer_samples, update_num, agg_sample_size)
+ next(
+ iter(model.models.values())
+ ).decoder.layer_select.layer_samples,
+ update_num,
+ agg_sample_size,
+ )
if sparsity_loss > 0:
optimizer.backward(sparsity_loss)
return agg_loss, agg_sample_size, agg_logging_output
@@ -123,10 +149,14 @@ def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample):
if self.decoder_latent_layer:
tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
- loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
+ loss, sample_size, logging_output = criterion(
+ model.models[lang_pair], sample[lang_pair]
+ )
return loss, sample_size, logging_output
- def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None):
+ def inference_step(
+ self, generator, models, sample, prefix_tokens=None, constraints=None
+ ):
if self.encoder_latent_layer or self.decoder_latent_layer:
for model in models:
if self.encoder_latent_layer:
@@ -137,15 +167,23 @@ def inference_step(self, generator, models, sample, prefix_tokens=None, constrai
assert model.decoder.layer_select is not None
tgt_lang_idx = self.tgt_lang_idx_dict[self.args.target_lang]
model.decoder.set_lang_idx(tgt_lang_idx)
- return super().inference_step(generator, models, sample, prefix_tokens, constraints)
+ return super().inference_step(
+ generator, models, sample, prefix_tokens, constraints
+ )
@property
def encoder_latent_layer(self):
- return hasattr(self.args, "encoder_latent_layer") and self.args.encoder_latent_layer
+ return (
+ hasattr(self.args, "encoder_latent_layer")
+ and self.args.encoder_latent_layer
+ )
@property
def decoder_latent_layer(self):
- return hasattr(self.args, "decoder_latent_layer") and self.args.decoder_latent_layer
+ return (
+ hasattr(self.args, "decoder_latent_layer")
+ and self.args.decoder_latent_layer
+ )
@property
def src_lang_idx_dict(self):
diff --git a/examples/linformer/src/models/linformer_roberta.py b/examples/linformer/src/models/linformer_roberta.py
index 722f5a4b9e..913351f238 100644
--- a/examples/linformer/src/models/linformer_roberta.py
+++ b/examples/linformer/src/models/linformer_roberta.py
@@ -8,37 +8,40 @@
import logging
-from fairseq.models import (
- register_model,
- register_model_architecture,
-)
-from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder
+from fairseq.models import register_model, register_model_architecture
+from fairseq.models.roberta import RobertaEncoder, RobertaModel
-from fairseq.models.roberta import (
- RobertaModel,
- RobertaEncoder,
-)
+from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder
logger = logging.getLogger(__name__)
-@register_model('linformer_roberta')
+@register_model("linformer_roberta")
class LinformerModel(RobertaModel):
-
@staticmethod
def add_args(parser):
RobertaModel.add_args(parser)
# add args for Linformer
- parser.add_argument('--compressed', type=int,
- help='compressed ratio of sequence length')
- parser.add_argument('--shared-kv-compressed', type=int,
- help='share compressed matrix between k and v, in each layer')
- parser.add_argument('--shared-layer-kv-compressed', type=int,
- help='share compressed matrix between k and v and across all layers')
- parser.add_argument('--freeze-compress', type=int,
- help='freeze the parameters in compressed layer')
+ parser.add_argument(
+ "--compressed", type=int, help="compressed ratio of sequence length"
+ )
+ parser.add_argument(
+ "--shared-kv-compressed",
+ type=int,
+ help="share compressed matrix between k and v, in each layer",
+ )
+ parser.add_argument(
+ "--shared-layer-kv-compressed",
+ type=int,
+ help="share compressed matrix between k and v and across all layers",
+ )
+ parser.add_argument(
+ "--freeze-compress",
+ type=int,
+ help="freeze the parameters in compressed layer",
+ )
@classmethod
def build_model(cls, args, task):
@@ -47,7 +50,7 @@ def build_model(cls, args, task):
# make sure all arguments are present
base_architecture(args)
- if not hasattr(args, 'max_positions'):
+ if not hasattr(args, "max_positions"):
args.max_positions = args.tokens_per_sample
encoder = LinformerEncoder(args, task.source_dictionary)
@@ -85,47 +88,47 @@ def __init__(self, args, dictionary):
)
-@register_model_architecture('linformer_roberta', 'linformer_roberta')
+@register_model_architecture("linformer_roberta", "linformer_roberta")
def base_architecture(args):
- args.encoder_layers = getattr(args, 'encoder_layers', 12)
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12)
-
- args.activation_fn = getattr(args, 'activation_fn', 'gelu')
- args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
-
- args.dropout = getattr(args, 'dropout', 0.1)
- args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
- args.activation_dropout = getattr(args, 'activation_dropout', 0.0)
- args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)
- args.encoder_layers_to_keep = getattr(args, 'encoder_layers_to_keep', None)
- args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0)
- args.compressed = getattr(args, 'compressed', 4)
- args.shared_kv_compressed = getattr(args, 'shared_kv_compressed', 0)
- args.shared_layer_kv_compressed = getattr(args, 'shared_layer_kv_compressed', 0)
- args.freeze_compress = getattr(args, 'freeze_compress', 0)
-
-
-@register_model_architecture('linformer_roberta', 'linformer_roberta_base')
+ args.encoder_layers = getattr(args, "encoder_layers", 12)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
+
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
+ args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
+
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+ args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
+ args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
+ args.compressed = getattr(args, "compressed", 4)
+ args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
+ args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)
+ args.freeze_compress = getattr(args, "freeze_compress", 0)
+
+
+@register_model_architecture("linformer_roberta", "linformer_roberta_base")
def linformer_roberta_base_architecture(args):
base_architecture(args)
-@register_model_architecture('linformer_roberta', 'linformer_roberta_large')
+@register_model_architecture("linformer_roberta", "linformer_roberta_large")
def linformer_roberta_large_architecture(args):
- args.encoder_layers = getattr(args, 'encoder_layers', 24)
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
-
- args.activation_fn = getattr(args, 'activation_fn', 'gelu')
- args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
-
- args.dropout = getattr(args, 'dropout', 0.1)
- args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
- args.activation_dropout = getattr(args, 'activation_dropout', 0.0)
- args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)
- args.compressed = getattr(args, 'compressed', 4)
- args.shared_kv_compressed = getattr(args, 'shared_kv_compressed', 0)
- args.shared_layer_kv_compressed = getattr(args, 'shared_layer_kv_compressed', 0)
+ args.encoder_layers = getattr(args, "encoder_layers", 24)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
+
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
+ args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
+
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+ args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
+ args.compressed = getattr(args, "compressed", 4)
+ args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
+ args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)
diff --git a/examples/linformer/src/modules/linformer_sentence_encoder.py b/examples/linformer/src/modules/linformer_sentence_encoder.py
index e3d170023d..d6de9eeaae 100644
--- a/examples/linformer/src/modules/linformer_sentence_encoder.py
+++ b/examples/linformer/src/modules/linformer_sentence_encoder.py
@@ -6,8 +6,8 @@
import math
import torch.nn as nn
-
from fairseq.modules import TransformerSentenceEncoder
+
from .linformer_sentence_encoder_layer import LinformerSentenceEncoderLayer
@@ -117,7 +117,9 @@ def build_transformer_sentence_encoder_layer(
qn_block_size,
):
if self.shared_layer_kv_compressed == 1:
- compress_layer = nn.Linear(self.max_seq_len, self.max_seq_len // self.compressed)
+ compress_layer = nn.Linear(
+ self.max_seq_len, self.max_seq_len // self.compressed
+ )
# intialize parameters for compressed layer
nn.init.xavier_uniform_(compress_layer.weight, gain=1 / math.sqrt(2))
if self.freeze_compress == 1:
@@ -139,8 +141,7 @@ def build_transformer_sentence_encoder_layer(
max_seq_len=self.max_seq_len,
shared_kv_compressed=self.shared_kv_compressed,
shared_compress_layer=(
- None if self.shared_layer_kv_compressed == 0
- else self.compress_layer
+ None if self.shared_layer_kv_compressed == 0 else self.compress_layer
),
freeze_compress=self.freeze_compress,
)
@@ -156,7 +157,8 @@ def upgrade_state_dict_named(self, state_dict, name):
if self.shared_layer_kv_compressed:
for layer_idx in range(len(self.layers)):
new_k = prefix + "layers.{0}.shared_compress_layer.{1}".format(
- layer_idx, k[len(prefix + 'compress_layer.'):],
+ layer_idx,
+ k[len(prefix + "compress_layer.") :],
)
items_to_add[new_k] = state_dict[k]
diff --git a/examples/linformer/src/modules/linformer_sentence_encoder_layer.py b/examples/linformer/src/modules/linformer_sentence_encoder_layer.py
index e0a6047ce8..d27c5afd09 100644
--- a/examples/linformer/src/modules/linformer_sentence_encoder_layer.py
+++ b/examples/linformer/src/modules/linformer_sentence_encoder_layer.py
@@ -6,6 +6,7 @@
from typing import Callable
from fairseq.modules import TransformerSentenceEncoderLayer
+
from .multihead_linear_attention import MultiheadLinearAttention
@@ -23,7 +24,7 @@ def __init__(
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
- activation_fn: str = 'relu',
+ activation_fn: str = "relu",
export: bool = False,
q_noise: float = 0.0,
qn_block_size: int = 8,
diff --git a/examples/linformer/src/modules/multihead_linear_attention.py b/examples/linformer/src/modules/multihead_linear_attention.py
index 472cd4e3ea..ba2c36b1ef 100644
--- a/examples/linformer/src/modules/multihead_linear_attention.py
+++ b/examples/linformer/src/modules/multihead_linear_attention.py
@@ -9,10 +9,10 @@
import torch
import torch.nn.functional as F
from fairseq import utils
-from torch import Tensor, nn
-from torch.nn import Parameter
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.quant_noise import quant_noise
+from torch import Tensor, nn
+from torch.nn import Parameter
@with_incremental_state
@@ -65,16 +65,24 @@ def __init__(
"Self-attention requires query, key and " "value to be of the same size"
)
- self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
- self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
- self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
+ self.k_proj = quant_noise(
+ nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
+ )
+ self.v_proj = quant_noise(
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
+ )
+ self.q_proj = quant_noise(
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
+ )
# used for compress sequence to subsequence
if shared_compress_layer is None:
self.compress_seq_len = max_seq_len // compressed
self.compress_k = nn.Linear(max_seq_len, self.compress_seq_len, bias=False)
if shared_kv_compressed == 0:
- self.compress_v = nn.Linear(max_seq_len, self.compress_seq_len, bias=False)
+ self.compress_v = nn.Linear(
+ max_seq_len, self.compress_seq_len, bias=False
+ )
self.layerwise_sharing = False
else:
self.compress_k = shared_compress_layer
@@ -83,7 +91,9 @@ def __init__(
self.layerwise_sharing = True
self.shared_kv_compressed = shared_kv_compressed
- self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
+ self.out_proj = quant_noise(
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
+ )
if add_bias_kv:
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
@@ -116,22 +126,28 @@ def reset_parameters(self):
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
- if not self.layerwise_sharing: # otherwise, we already initialize the parameters
- nn.init.xavier_uniform_(self.compress_k.weight, gain=1/math.sqrt(2))
+ if (
+ not self.layerwise_sharing
+ ): # otherwise, we already initialize the parameters
+ nn.init.xavier_uniform_(self.compress_k.weight, gain=1 / math.sqrt(2))
if self.shared_kv_compressed == 0:
- nn.init.xavier_uniform_(self.compress_v.weight, gain=1/math.sqrt(2))
+ nn.init.xavier_uniform_(
+ self.compress_v.weight, gain=1 / math.sqrt(2)
+ )
else:
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.q_proj.weight)
- if not self.layerwise_sharing: # otherwise, we already initialize the parameters
+ if (
+ not self.layerwise_sharing
+ ): # otherwise, we already initialize the parameters
nn.init.xavier_uniform_(self.compress_k.weight)
if self.shared_kv_compressed == 0:
nn.init.xavier_uniform_(self.compress_v.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
- nn.init.constant_(self.out_proj.bias, 0.)
+ nn.init.constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
@@ -189,14 +205,26 @@ def forward(
q = self.q_proj(query)
k_input = query.permute(1, 2, 0).contiguous() # B * C * T
- k_input = F.linear(k_input, self.compress_k.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous()
+ k_input = (
+ F.linear(k_input, self.compress_k.weight[:, 0:tgt_len])
+ .permute(2, 0, 1)
+ .contiguous()
+ )
k = self.k_proj(k_input)
v_input = query.permute(1, 2, 0).contiguous() # B * C * T
if self.shared_kv_compressed == 0:
- v_input = F.linear(v_input, self.compress_v.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous()
+ v_input = (
+ F.linear(v_input, self.compress_v.weight[:, 0:tgt_len])
+ .permute(2, 0, 1)
+ .contiguous()
+ )
if self.shared_kv_compressed == 1: # use shared kv compressed linear layer
- v_input = F.linear(v_input, self.compress_k.weight[:, 0: tgt_len]).permute(2, 0, 1).contiguous()
+ v_input = (
+ F.linear(v_input, self.compress_k.weight[:, 0:tgt_len])
+ .permute(2, 0, 1)
+ .contiguous()
+ )
v = self.v_proj(v_input)
elif self.encoder_decoder_attention:
# encoder-decoder attention
@@ -302,7 +330,9 @@ def forward(
)
attn_weights = torch.bmm(q, k.transpose(1, 2))
- attn_weights = MultiheadLinearAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+ attn_weights = MultiheadLinearAttention.apply_sparse_mask(
+ attn_weights, tgt_len, src_len, bsz
+ )
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
@@ -385,7 +415,9 @@ def _append_prev_key_padding_mask(
@torch.jit.export
def reorder_incremental_state(
- self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
+ self,
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+ new_order: Tensor,
):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer = self._get_input_buffer(incremental_state)
@@ -393,7 +425,9 @@ def reorder_incremental_state(
for k in input_buffer.keys():
input_buffer_k = input_buffer[k]
if input_buffer_k is not None:
- if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(0):
+ if self.encoder_decoder_attention and input_buffer_k.size(
+ 0
+ ) == new_order.size(0):
break
input_buffer[k] = input_buffer_k.index_select(0, new_order)
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
@@ -428,8 +462,8 @@ def upgrade_state_dict_named(self, state_dict, name):
# in_proj_weight used to be q + k + v with same dimensions
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
- items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim:2 * dim]
- items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim:]
+ items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
+ items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
keys_to_remove.append(k)
@@ -438,9 +472,9 @@ def upgrade_state_dict_named(self, state_dict, name):
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
- dim:2 * dim
+ dim : 2 * dim
]
- items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim:]
+ items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
keys_to_remove.append(prefix + "in_proj_bias")
diff --git a/examples/m2m_100/tokenizers/tokenize_indic.py b/examples/m2m_100/tokenizers/tokenize_indic.py
index c1303b3d15..a44fad07f7 100644
--- a/examples/m2m_100/tokenizers/tokenize_indic.py
+++ b/examples/m2m_100/tokenizers/tokenize_indic.py
@@ -8,14 +8,16 @@
import sys
-from indicnlp.tokenize.indic_tokenize import trivial_tokenize
from indicnlp.normalize.indic_normalize import IndicNormalizerFactory
+from indicnlp.tokenize.indic_tokenize import trivial_tokenize
-factory=IndicNormalizerFactory()
-normalizer=factory.get_normalizer(sys.argv[1],remove_nuktas=False,nasals_mode='do_nothing')
+
+factory = IndicNormalizerFactory()
+normalizer = factory.get_normalizer(
+ sys.argv[1], remove_nuktas=False, nasals_mode="do_nothing"
+)
for line in sys.stdin:
- normalized_line=normalizer.normalize(line.strip())
- tokenized_line=' '.join(trivial_tokenize(normalized_line, sys.argv[1]))
+ normalized_line = normalizer.normalize(line.strip())
+ tokenized_line = " ".join(trivial_tokenize(normalized_line, sys.argv[1]))
print(tokenized_line)
-
diff --git a/examples/m2m_100/tokenizers/tokenize_thai.py b/examples/m2m_100/tokenizers/tokenize_thai.py
index 7c7b7ebfaa..9c72cb8905 100644
--- a/examples/m2m_100/tokenizers/tokenize_thai.py
+++ b/examples/m2m_100/tokenizers/tokenize_thai.py
@@ -8,5 +8,6 @@
from pythainlp import word_tokenize
+
for line in sys.stdin:
print(" ".join(word_tokenize(line.strip())))
diff --git a/examples/m2m_100/tokenizers/tokenize_zh.py b/examples/m2m_100/tokenizers/tokenize_zh.py
index 531a7fb49b..674b5849cb 100644
--- a/examples/m2m_100/tokenizers/tokenize_zh.py
+++ b/examples/m2m_100/tokenizers/tokenize_zh.py
@@ -6,7 +6,9 @@
import fileinput
+
import sacrebleu
+
for line in fileinput.input():
print(sacrebleu.tokenize_zh(line))
diff --git a/examples/megatron_11b/detok.py b/examples/megatron_11b/detok.py
index a77a0b4960..49921b28a1 100644
--- a/examples/megatron_11b/detok.py
+++ b/examples/megatron_11b/detok.py
@@ -6,19 +6,27 @@
import argparse
import fileinput
+
import sacremoses
def main():
- parser = argparse.ArgumentParser(description='')
- parser.add_argument('files', nargs='*', help='input files')
+ parser = argparse.ArgumentParser(description="")
+ parser.add_argument("files", nargs="*", help="input files")
args = parser.parse_args()
detok = sacremoses.MosesDetokenizer()
for line in fileinput.input(args.files, openhook=fileinput.hook_compressed):
- print(detok.detokenize(line.strip().split(' ')).replace(' @', '').replace('@ ', '').replace(' =', '=').replace('= ', '=').replace(' – ', '–'))
+ print(
+ detok.detokenize(line.strip().split(" "))
+ .replace(" @", "")
+ .replace("@ ", "")
+ .replace(" =", "=")
+ .replace("= ", "=")
+ .replace(" – ", "–")
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/noisychannel/rerank.py b/examples/noisychannel/rerank.py
index a5927a53b3..4df424e6b5 100644
--- a/examples/noisychannel/rerank.py
+++ b/examples/noisychannel/rerank.py
@@ -7,21 +7,22 @@
from multiprocessing import Pool
import numpy as np
-
from fairseq import options
from fairseq.data import dictionary
from fairseq.scoring import bleu
from . import (
rerank_generate,
+ rerank_options,
rerank_score_bw,
rerank_score_lm,
- rerank_options,
rerank_utils,
)
-def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize):
+def score_target_hypo(
+ args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize
+):
print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
@@ -61,11 +62,21 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
bitext2_score = None
bitext2_backwards = None
- score = rerank_utils.get_score(a, b, c, target_len,
- bitext1.rescore_score[i], bitext2_score, lm_score=lm_score,
- lenpen=lenpen, src_len=bitext1.source_lengths[i],
- tgt_len=bitext1.target_lengths[i], bitext1_backwards=bitext1.backwards,
- bitext2_backwards=bitext2_backwards, normalize=normalize)
+ score = rerank_utils.get_score(
+ a,
+ b,
+ c,
+ target_len,
+ bitext1.rescore_score[i],
+ bitext2_score,
+ lm_score=lm_score,
+ lenpen=lenpen,
+ src_len=bitext1.source_lengths[i],
+ tgt_len=bitext1.target_lengths[i],
+ bitext1_backwards=bitext1.backwards,
+ bitext2_backwards=bitext2_backwards,
+ normalize=normalize,
+ )
if score > best_score:
best_score = score
@@ -88,8 +99,11 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
for key in range(len(gen_keys)):
if args.prefix_len is None:
assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
- "pred and rescore hypo mismatch: i: " + str(key) + ", "
- + str(hypo_lst[key]) + str(gen_keys[key])
+ "pred and rescore hypo mismatch: i: "
+ + str(key)
+ + ", "
+ + str(hypo_lst[key])
+ + str(gen_keys[key])
+ str(gen_output.no_bpe_hypo[key])
)
sys_tok = dict.encode_line(hypo_lst[key])
@@ -97,7 +111,9 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
scorer.add(ref_tok, sys_tok)
else:
- full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
+ full_hypo = rerank_utils.get_full_from_prefix(
+ hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
+ )
sys_tok = dict.encode_line(full_hypo)
ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
scorer.add(ref_tok, sys_tok)
@@ -107,20 +123,31 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
# recover the orinal ids from n best list generation
for key in range(len(gen_output.no_bpe_target)):
if args.prefix_len is None:
- assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], \
- "pred and rescore hypo mismatch:"+"i:"+str(key)+str(hypo_lst[key]) + str(gen_output.no_bpe_hypo[key])
+ assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
+ "pred and rescore hypo mismatch:"
+ + "i:"
+ + str(key)
+ + str(hypo_lst[key])
+ + str(gen_output.no_bpe_hypo[key])
+ )
ordered_hypos[gen_keys[key]] = hypo_lst[key]
- ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]
+ ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
+ gen_keys[key]
+ ]
else:
- full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
+ full_hypo = rerank_utils.get_full_from_prefix(
+ hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
+ )
ordered_hypos[gen_keys[key]] = full_hypo
- ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]
+ ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
+ gen_keys[key]
+ ]
# write the hypos in the original order from nbest list generation
if args.num_shards == (len(bitext1_lst)):
- with open(target_outfile, 'w') as t:
- with open(hypo_outfile, 'w') as h:
+ with open(target_outfile, "w") as t:
+ with open(hypo_outfile, "w") as h:
for key in range(len(ordered_hypos)):
t.write(ordered_targets[key])
h.write(ordered_hypos[key])
@@ -135,17 +162,38 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
def match_target_hypo(args, target_outfile, hypo_outfile):
"""combine scores from the LM and bitext models, and write the top scoring hypothesis to a file"""
if len(args.weight1) == 1:
- res = score_target_hypo(args, args.weight1[0], args.weight2[0],
- args.weight3[0], args.lenpen[0], target_outfile,
- hypo_outfile, True, args.normalize)
+ res = score_target_hypo(
+ args,
+ args.weight1[0],
+ args.weight2[0],
+ args.weight3[0],
+ args.lenpen[0],
+ target_outfile,
+ hypo_outfile,
+ True,
+ args.normalize,
+ )
rerank_scores = [res]
else:
print("launching pool")
with Pool(32) as p:
- rerank_scores = p.starmap(score_target_hypo,
- [(args, args.weight1[i], args.weight2[i], args.weight3[i],
- args.lenpen[i], target_outfile, hypo_outfile,
- False, args.normalize) for i in range(len(args.weight1))])
+ rerank_scores = p.starmap(
+ score_target_hypo,
+ [
+ (
+ args,
+ args.weight1[i],
+ args.weight2[i],
+ args.weight3[i],
+ args.lenpen[i],
+ target_outfile,
+ hypo_outfile,
+ False,
+ args.normalize,
+ )
+ for i in range(len(args.weight1))
+ ],
+ )
if len(rerank_scores) > 1:
best_index = np.argmax(rerank_scores)
@@ -155,11 +203,22 @@ def match_target_hypo(args, target_outfile, hypo_outfile):
print("best weight1", args.weight1[best_index])
print("best weight2", args.weight2[best_index])
print("best weight3", args.weight3[best_index])
- return args.lenpen[best_index], args.weight1[best_index], \
- args.weight2[best_index], args.weight3[best_index], best_score
+ return (
+ args.lenpen[best_index],
+ args.weight1[best_index],
+ args.weight2[best_index],
+ args.weight3[best_index],
+ best_score,
+ )
else:
- return args.lenpen[0], args.weight1[0], args.weight2[0], args.weight3[0], rerank_scores[0]
+ return (
+ args.lenpen[0],
+ args.weight1[0],
+ args.weight2[0],
+ args.weight3[0],
+ rerank_scores[0],
+ )
def load_score_files(args):
@@ -175,55 +234,100 @@ def load_score_files(args):
for shard_id in shard_ids:
using_nbest = args.nbest_list is not None
- pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
- backwards_preprocessed_dir, lm_preprocessed_dir = \
- rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
- args.gen_model_name, shard_id, args.num_shards, args.sampling,
- args.prefix_len, args.target_prefix_frac, args.source_prefix_frac)
-
- rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
- rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
-
- score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
- target_prefix_frac=args.target_prefix_frac,
- source_prefix_frac=args.source_prefix_frac,
- backwards=args.backwards1)
+ (
+ pre_gen,
+ left_to_right_preprocessed_dir,
+ right_to_left_preprocessed_dir,
+ backwards_preprocessed_dir,
+ lm_preprocessed_dir,
+ ) = rerank_utils.get_directories(
+ args.data_dir_name,
+ args.num_rescore,
+ args.gen_subset,
+ args.gen_model_name,
+ shard_id,
+ args.num_shards,
+ args.sampling,
+ args.prefix_len,
+ args.target_prefix_frac,
+ args.source_prefix_frac,
+ )
+
+ rerank1_is_gen = (
+ args.gen_model == args.score_model1 and args.source_prefix_frac is None
+ )
+ rerank2_is_gen = (
+ args.gen_model == args.score_model2 and args.source_prefix_frac is None
+ )
+
+ score1_file = rerank_utils.rescore_file_name(
+ pre_gen,
+ args.prefix_len,
+ args.model1_name,
+ target_prefix_frac=args.target_prefix_frac,
+ source_prefix_frac=args.source_prefix_frac,
+ backwards=args.backwards1,
+ )
if args.score_model2 is not None:
- score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
- target_prefix_frac=args.target_prefix_frac,
- source_prefix_frac=args.source_prefix_frac,
- backwards=args.backwards2)
+ score2_file = rerank_utils.rescore_file_name(
+ pre_gen,
+ args.prefix_len,
+ args.model2_name,
+ target_prefix_frac=args.target_prefix_frac,
+ source_prefix_frac=args.source_prefix_frac,
+ backwards=args.backwards2,
+ )
if args.language_model is not None:
- lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True)
+ lm_score_file = rerank_utils.rescore_file_name(
+ pre_gen, args.prefix_len, args.lm_name, lm_file=True
+ )
# get gen output
- predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"
+ predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
if using_nbest:
print("Using predefined n-best list from interactive.py")
predictions_bpe_file = args.nbest_list
- gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe,
- nbest=using_nbest, prefix_len=args.prefix_len,
- target_prefix_frac=args.target_prefix_frac)
+ gen_output = rerank_utils.BitextOutputFromGen(
+ predictions_bpe_file,
+ bpe_symbol=args.remove_bpe,
+ nbest=using_nbest,
+ prefix_len=args.prefix_len,
+ target_prefix_frac=args.target_prefix_frac,
+ )
if rerank1_is_gen:
bitext1 = gen_output
else:
- bitext1 = rerank_utils.BitextOutput(score1_file, args.backwards1, args.right_to_left1,
- args.remove_bpe, args.prefix_len, args.target_prefix_frac,
- args.source_prefix_frac)
+ bitext1 = rerank_utils.BitextOutput(
+ score1_file,
+ args.backwards1,
+ args.right_to_left1,
+ args.remove_bpe,
+ args.prefix_len,
+ args.target_prefix_frac,
+ args.source_prefix_frac,
+ )
if args.score_model2 is not None or args.nbest_list is not None:
if rerank2_is_gen:
bitext2 = gen_output
else:
- bitext2 = rerank_utils.BitextOutput(score2_file, args.backwards2, args.right_to_left2,
- args.remove_bpe, args.prefix_len, args.target_prefix_frac,
- args.source_prefix_frac)
-
- assert bitext2.source_lengths == bitext1.source_lengths, \
- "source lengths for rescoring models do not match"
- assert bitext2.target_lengths == bitext1.target_lengths, \
- "target lengths for rescoring models do not match"
+ bitext2 = rerank_utils.BitextOutput(
+ score2_file,
+ args.backwards2,
+ args.right_to_left2,
+ args.remove_bpe,
+ args.prefix_len,
+ args.target_prefix_frac,
+ args.source_prefix_frac,
+ )
+
+ assert (
+ bitext2.source_lengths == bitext1.source_lengths
+ ), "source lengths for rescoring models do not match"
+ assert (
+ bitext2.target_lengths == bitext1.target_lengths
+ ), "target lengths for rescoring models do not match"
else:
if args.diff_bpe:
assert args.score_model2 is None
@@ -232,8 +336,13 @@ def load_score_files(args):
bitext2 = None
if args.language_model is not None:
- lm_res1 = rerank_utils.LMOutput(lm_score_file, args.lm_dict, args.prefix_len,
- args.remove_bpe, args.target_prefix_frac)
+ lm_res1 = rerank_utils.LMOutput(
+ lm_score_file,
+ args.lm_dict,
+ args.prefix_len,
+ args.remove_bpe,
+ args.target_prefix_frac,
+ )
else:
lm_res1 = None
@@ -259,28 +368,46 @@ def rerank(args):
shard_ids = [args.shard_id]
for shard_id in shard_ids:
- pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
- backwards_preprocessed_dir, lm_preprocessed_dir = \
- rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
- args.gen_model_name, shard_id, args.num_shards, args.sampling,
- args.prefix_len, args.target_prefix_frac, args.source_prefix_frac)
+ (
+ pre_gen,
+ left_to_right_preprocessed_dir,
+ right_to_left_preprocessed_dir,
+ backwards_preprocessed_dir,
+ lm_preprocessed_dir,
+ ) = rerank_utils.get_directories(
+ args.data_dir_name,
+ args.num_rescore,
+ args.gen_subset,
+ args.gen_model_name,
+ shard_id,
+ args.num_shards,
+ args.sampling,
+ args.prefix_len,
+ args.target_prefix_frac,
+ args.source_prefix_frac,
+ )
rerank_generate.gen_and_reprocess_nbest(args)
rerank_score_bw.score_bw(args)
rerank_score_lm.score_lm(args)
if args.write_hypos is None:
- write_targets = pre_gen+"/matched_targets"
- write_hypos = pre_gen+"/matched_hypos"
+ write_targets = pre_gen + "/matched_targets"
+ write_hypos = pre_gen + "/matched_hypos"
else:
- write_targets = args.write_hypos+"_targets" + args.gen_subset
- write_hypos = args.write_hypos+"_hypos" + args.gen_subset
+ write_targets = args.write_hypos + "_targets" + args.gen_subset
+ write_hypos = args.write_hypos + "_hypos" + args.gen_subset
if args.all_shards:
write_targets += "_all_shards"
write_hypos += "_all_shards"
- best_lenpen, best_weight1, best_weight2, best_weight3, best_score = \
- match_target_hypo(args, write_targets, write_hypos)
+ (
+ best_lenpen,
+ best_weight1,
+ best_weight2,
+ best_weight3,
+ best_score,
+ ) = match_target_hypo(args, write_targets, write_hypos)
return best_lenpen, best_weight1, best_weight2, best_weight3, best_score
@@ -291,5 +418,5 @@ def cli_main():
rerank(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
cli_main()
diff --git a/examples/noisychannel/rerank_generate.py b/examples/noisychannel/rerank_generate.py
index d2da6eacf9..4356b3387e 100644
--- a/examples/noisychannel/rerank_generate.py
+++ b/examples/noisychannel/rerank_generate.py
@@ -8,9 +8,9 @@
Generate n-best translations using a trained model.
"""
-from contextlib import redirect_stdout
import os
import subprocess
+from contextlib import redirect_stdout
from fairseq import options
from fairseq_cli import generate, preprocess
@@ -22,8 +22,12 @@ def gen_and_reprocess_nbest(args):
if args.score_dict_dir is None:
args.score_dict_dir = args.data
if args.prefix_len is not None:
- assert args.right_to_left1 is False, "prefix length not compatible with right to left models"
- assert args.right_to_left2 is False, "prefix length not compatible with right to left models"
+ assert (
+ args.right_to_left1 is False
+ ), "prefix length not compatible with right to left models"
+ assert (
+ args.right_to_left2 is False
+ ), "prefix length not compatible with right to left models"
if args.nbest_list is not None:
assert args.score_model2 is None
@@ -35,27 +39,50 @@ def gen_and_reprocess_nbest(args):
scorer1_src = args.source_lang
scorer1_tgt = args.target_lang
- store_data = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+args.data_dir_name
+ store_data = (
+ os.path.join(os.path.dirname(__file__)) + "/rerank_data/" + args.data_dir_name
+ )
if not os.path.exists(store_data):
os.makedirs(store_data)
- pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
- backwards_preprocessed_dir, lm_preprocessed_dir = \
- rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
- args.gen_model_name, args.shard_id, args.num_shards,
- args.sampling, args.prefix_len, args.target_prefix_frac,
- args.source_prefix_frac)
- assert not (args.right_to_left1 and args.backwards1), "backwards right to left not supported"
- assert not (args.right_to_left2 and args.backwards2), "backwards right to left not supported"
- assert not (args.prefix_len is not None and args.target_prefix_frac is not None), \
- "target prefix frac and target prefix len incompatible"
+ (
+ pre_gen,
+ left_to_right_preprocessed_dir,
+ right_to_left_preprocessed_dir,
+ backwards_preprocessed_dir,
+ lm_preprocessed_dir,
+ ) = rerank_utils.get_directories(
+ args.data_dir_name,
+ args.num_rescore,
+ args.gen_subset,
+ args.gen_model_name,
+ args.shard_id,
+ args.num_shards,
+ args.sampling,
+ args.prefix_len,
+ args.target_prefix_frac,
+ args.source_prefix_frac,
+ )
+ assert not (
+ args.right_to_left1 and args.backwards1
+ ), "backwards right to left not supported"
+ assert not (
+ args.right_to_left2 and args.backwards2
+ ), "backwards right to left not supported"
+ assert not (
+ args.prefix_len is not None and args.target_prefix_frac is not None
+ ), "target prefix frac and target prefix len incompatible"
# make directory to store generation results
if not os.path.exists(pre_gen):
os.makedirs(pre_gen)
- rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
- rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
+ rerank1_is_gen = (
+ args.gen_model == args.score_model1 and args.source_prefix_frac is None
+ )
+ rerank2_is_gen = (
+ args.gen_model == args.score_model2 and args.source_prefix_frac is None
+ )
if args.nbest_list is not None:
rerank2_is_gen = True
@@ -70,17 +97,25 @@ def gen_and_reprocess_nbest(args):
if not os.path.exists(backwards_preprocessed_dir):
os.makedirs(backwards_preprocessed_dir)
- score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
- target_prefix_frac=args.target_prefix_frac,
- source_prefix_frac=args.source_prefix_frac,
- backwards=args.backwards1)
+ score1_file = rerank_utils.rescore_file_name(
+ pre_gen,
+ args.prefix_len,
+ args.model1_name,
+ target_prefix_frac=args.target_prefix_frac,
+ source_prefix_frac=args.source_prefix_frac,
+ backwards=args.backwards1,
+ )
if args.score_model2 is not None:
- score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
- target_prefix_frac=args.target_prefix_frac,
- source_prefix_frac=args.source_prefix_frac,
- backwards=args.backwards2)
+ score2_file = rerank_utils.rescore_file_name(
+ pre_gen,
+ args.prefix_len,
+ args.model2_name,
+ target_prefix_frac=args.target_prefix_frac,
+ source_prefix_frac=args.source_prefix_frac,
+ backwards=args.backwards2,
+ )
- predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"
+ predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
using_nbest = args.nbest_list is not None
@@ -92,17 +127,29 @@ def gen_and_reprocess_nbest(args):
if not os.path.isfile(predictions_bpe_file):
print("STEP 1: generate predictions using the p(T|S) model with bpe")
print(args.data)
- param1 = [args.data,
- "--path", args.gen_model,
- "--shard-id", str(args.shard_id),
- "--num-shards", str(args.num_shards),
- "--nbest", str(args.num_rescore),
- "--batch-size", str(args.batch_size),
- "--beam", str(args.num_rescore),
- "--batch-size", str(args.num_rescore),
- "--gen-subset", args.gen_subset,
- "--source-lang", args.source_lang,
- "--target-lang", args.target_lang]
+ param1 = [
+ args.data,
+ "--path",
+ args.gen_model,
+ "--shard-id",
+ str(args.shard_id),
+ "--num-shards",
+ str(args.num_shards),
+ "--nbest",
+ str(args.num_rescore),
+ "--batch-size",
+ str(args.batch_size),
+ "--beam",
+ str(args.num_rescore),
+ "--batch-size",
+ str(args.num_rescore),
+ "--gen-subset",
+ args.gen_subset,
+ "--source-lang",
+ args.source_lang,
+ "--target-lang",
+ args.target_lang,
+ ]
if args.sampling:
param1 += ["--sampling"]
@@ -110,124 +157,229 @@ def gen_and_reprocess_nbest(args):
input_args = options.parse_args_and_arch(gen_parser, param1)
print(input_args)
- with open(predictions_bpe_file, 'w') as f:
+ with open(predictions_bpe_file, "w") as f:
with redirect_stdout(f):
generate.main(input_args)
- gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe,
- nbest=using_nbest, prefix_len=args.prefix_len,
- target_prefix_frac=args.target_prefix_frac)
+ gen_output = rerank_utils.BitextOutputFromGen(
+ predictions_bpe_file,
+ bpe_symbol=args.remove_bpe,
+ nbest=using_nbest,
+ prefix_len=args.prefix_len,
+ target_prefix_frac=args.target_prefix_frac,
+ )
if args.diff_bpe:
- rerank_utils.write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
- gen_output.no_bpe_target, pre_gen+"/source_gen_bpe."+args.source_lang,
- pre_gen+"/target_gen_bpe."+args.target_lang,
- pre_gen+"/reference_gen_bpe."+args.target_lang)
+ rerank_utils.write_reprocessed(
+ gen_output.no_bpe_source,
+ gen_output.no_bpe_hypo,
+ gen_output.no_bpe_target,
+ pre_gen + "/source_gen_bpe." + args.source_lang,
+ pre_gen + "/target_gen_bpe." + args.target_lang,
+ pre_gen + "/reference_gen_bpe." + args.target_lang,
+ )
bitext_bpe = args.rescore_bpe_code
- bpe_src_param = ["-c", bitext_bpe,
- "--input", pre_gen+"/source_gen_bpe."+args.source_lang,
- "--output", pre_gen+"/rescore_data."+args.source_lang]
- bpe_tgt_param = ["-c", bitext_bpe,
- "--input", pre_gen+"/target_gen_bpe."+args.target_lang,
- "--output", pre_gen+"/rescore_data."+args.target_lang]
-
- subprocess.call(["python",
- os.path.join(os.path.dirname(__file__),
- "subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param,
- shell=False)
-
- subprocess.call(["python",
- os.path.join(os.path.dirname(__file__),
- "subword-nmt/subword_nmt/apply_bpe.py")] + bpe_tgt_param,
- shell=False)
-
- if (not os.path.isfile(score1_file) and not rerank1_is_gen) or \
- (args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen):
- print("STEP 2: process the output of generate.py so we have clean text files with the translations")
+ bpe_src_param = [
+ "-c",
+ bitext_bpe,
+ "--input",
+ pre_gen + "/source_gen_bpe." + args.source_lang,
+ "--output",
+ pre_gen + "/rescore_data." + args.source_lang,
+ ]
+ bpe_tgt_param = [
+ "-c",
+ bitext_bpe,
+ "--input",
+ pre_gen + "/target_gen_bpe." + args.target_lang,
+ "--output",
+ pre_gen + "/rescore_data." + args.target_lang,
+ ]
+
+ subprocess.call(
+ [
+ "python",
+ os.path.join(
+ os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
+ ),
+ ]
+ + bpe_src_param,
+ shell=False,
+ )
+
+ subprocess.call(
+ [
+ "python",
+ os.path.join(
+ os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
+ ),
+ ]
+ + bpe_tgt_param,
+ shell=False,
+ )
+
+ if (not os.path.isfile(score1_file) and not rerank1_is_gen) or (
+ args.score_model2 is not None
+ and not os.path.isfile(score2_file)
+ and not rerank2_is_gen
+ ):
+ print(
+ "STEP 2: process the output of generate.py so we have clean text files with the translations"
+ )
rescore_file = "/rescore_data"
if args.prefix_len is not None:
- prefix_len_rescore_file = rescore_file + "prefix"+str(args.prefix_len)
+ prefix_len_rescore_file = rescore_file + "prefix" + str(args.prefix_len)
if args.target_prefix_frac is not None:
- target_prefix_frac_rescore_file = rescore_file + "target_prefix_frac"+str(args.target_prefix_frac)
+ target_prefix_frac_rescore_file = (
+ rescore_file + "target_prefix_frac" + str(args.target_prefix_frac)
+ )
if args.source_prefix_frac is not None:
- source_prefix_frac_rescore_file = rescore_file + "source_prefix_frac"+str(args.source_prefix_frac)
+ source_prefix_frac_rescore_file = (
+ rescore_file + "source_prefix_frac" + str(args.source_prefix_frac)
+ )
if not args.right_to_left1 or not args.right_to_left2:
if not args.diff_bpe:
- rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
- pre_gen+rescore_file+"."+args.source_lang,
- pre_gen+rescore_file+"."+args.target_lang,
- pre_gen+"/reference_file", bpe_symbol=args.remove_bpe)
+ rerank_utils.write_reprocessed(
+ gen_output.source,
+ gen_output.hypo,
+ gen_output.target,
+ pre_gen + rescore_file + "." + args.source_lang,
+ pre_gen + rescore_file + "." + args.target_lang,
+ pre_gen + "/reference_file",
+ bpe_symbol=args.remove_bpe,
+ )
if args.prefix_len is not None:
bw_rescore_file = prefix_len_rescore_file
- rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
- pre_gen+prefix_len_rescore_file+"."+args.source_lang,
- pre_gen+prefix_len_rescore_file+"."+args.target_lang,
- pre_gen+"/reference_file", prefix_len=args.prefix_len,
- bpe_symbol=args.remove_bpe)
+ rerank_utils.write_reprocessed(
+ gen_output.source,
+ gen_output.hypo,
+ gen_output.target,
+ pre_gen + prefix_len_rescore_file + "." + args.source_lang,
+ pre_gen + prefix_len_rescore_file + "." + args.target_lang,
+ pre_gen + "/reference_file",
+ prefix_len=args.prefix_len,
+ bpe_symbol=args.remove_bpe,
+ )
elif args.target_prefix_frac is not None:
bw_rescore_file = target_prefix_frac_rescore_file
- rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
- pre_gen+target_prefix_frac_rescore_file+"."+args.source_lang,
- pre_gen+target_prefix_frac_rescore_file+"."+args.target_lang,
- pre_gen+"/reference_file", bpe_symbol=args.remove_bpe,
- target_prefix_frac=args.target_prefix_frac)
+ rerank_utils.write_reprocessed(
+ gen_output.source,
+ gen_output.hypo,
+ gen_output.target,
+ pre_gen
+ + target_prefix_frac_rescore_file
+ + "."
+ + args.source_lang,
+ pre_gen
+ + target_prefix_frac_rescore_file
+ + "."
+ + args.target_lang,
+ pre_gen + "/reference_file",
+ bpe_symbol=args.remove_bpe,
+ target_prefix_frac=args.target_prefix_frac,
+ )
else:
bw_rescore_file = rescore_file
if args.source_prefix_frac is not None:
fw_rescore_file = source_prefix_frac_rescore_file
- rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
- pre_gen+source_prefix_frac_rescore_file+"."+args.source_lang,
- pre_gen+source_prefix_frac_rescore_file+"."+args.target_lang,
- pre_gen+"/reference_file", bpe_symbol=args.remove_bpe,
- source_prefix_frac=args.source_prefix_frac)
+ rerank_utils.write_reprocessed(
+ gen_output.source,
+ gen_output.hypo,
+ gen_output.target,
+ pre_gen
+ + source_prefix_frac_rescore_file
+ + "."
+ + args.source_lang,
+ pre_gen
+ + source_prefix_frac_rescore_file
+ + "."
+ + args.target_lang,
+ pre_gen + "/reference_file",
+ bpe_symbol=args.remove_bpe,
+ source_prefix_frac=args.source_prefix_frac,
+ )
else:
fw_rescore_file = rescore_file
if args.right_to_left1 or args.right_to_left2:
- rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
- pre_gen+"/right_to_left_rescore_data."+args.source_lang,
- pre_gen+"/right_to_left_rescore_data."+args.target_lang,
- pre_gen+"/right_to_left_reference_file",
- right_to_left=True, bpe_symbol=args.remove_bpe)
+ rerank_utils.write_reprocessed(
+ gen_output.source,
+ gen_output.hypo,
+ gen_output.target,
+ pre_gen + "/right_to_left_rescore_data." + args.source_lang,
+ pre_gen + "/right_to_left_rescore_data." + args.target_lang,
+ pre_gen + "/right_to_left_reference_file",
+ right_to_left=True,
+ bpe_symbol=args.remove_bpe,
+ )
print("STEP 3: binarize the translations")
- if not args.right_to_left1 or args.score_model2 is not None and not args.right_to_left2 or not rerank1_is_gen:
+ if (
+ not args.right_to_left1
+ or args.score_model2 is not None
+ and not args.right_to_left2
+ or not rerank1_is_gen
+ ):
if args.backwards1 or args.backwards2:
if args.backwards_score_dict_dir is not None:
bw_dict = args.backwards_score_dict_dir
else:
bw_dict = args.score_dict_dir
- bw_preprocess_param = ["--source-lang", scorer1_src,
- "--target-lang", scorer1_tgt,
- "--trainpref", pre_gen+bw_rescore_file,
- "--srcdict", bw_dict + "/dict." + scorer1_src + ".txt",
- "--tgtdict", bw_dict + "/dict." + scorer1_tgt + ".txt",
- "--destdir", backwards_preprocessed_dir]
+ bw_preprocess_param = [
+ "--source-lang",
+ scorer1_src,
+ "--target-lang",
+ scorer1_tgt,
+ "--trainpref",
+ pre_gen + bw_rescore_file,
+ "--srcdict",
+ bw_dict + "/dict." + scorer1_src + ".txt",
+ "--tgtdict",
+ bw_dict + "/dict." + scorer1_tgt + ".txt",
+ "--destdir",
+ backwards_preprocessed_dir,
+ ]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(bw_preprocess_param)
preprocess.main(input_args)
- preprocess_param = ["--source-lang", scorer1_src,
- "--target-lang", scorer1_tgt,
- "--trainpref", pre_gen+fw_rescore_file,
- "--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt",
- "--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt",
- "--destdir", left_to_right_preprocessed_dir]
+ preprocess_param = [
+ "--source-lang",
+ scorer1_src,
+ "--target-lang",
+ scorer1_tgt,
+ "--trainpref",
+ pre_gen + fw_rescore_file,
+ "--srcdict",
+ args.score_dict_dir + "/dict." + scorer1_src + ".txt",
+ "--tgtdict",
+ args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
+ "--destdir",
+ left_to_right_preprocessed_dir,
+ ]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_param)
preprocess.main(input_args)
if args.right_to_left1 or args.right_to_left2:
- preprocess_param = ["--source-lang", scorer1_src,
- "--target-lang", scorer1_tgt,
- "--trainpref", pre_gen+"/right_to_left_rescore_data",
- "--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt",
- "--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt",
- "--destdir", right_to_left_preprocessed_dir]
+ preprocess_param = [
+ "--source-lang",
+ scorer1_src,
+ "--target-lang",
+ scorer1_tgt,
+ "--trainpref",
+ pre_gen + "/right_to_left_rescore_data",
+ "--srcdict",
+ args.score_dict_dir + "/dict." + scorer1_src + ".txt",
+ "--tgtdict",
+ args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
+ "--destdir",
+ right_to_left_preprocessed_dir,
+ ]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_param)
preprocess.main(input_args)
@@ -241,5 +393,5 @@ def cli_main():
gen_and_reprocess_nbest(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
cli_main()
diff --git a/examples/noisychannel/rerank_options.py b/examples/noisychannel/rerank_options.py
index a425fb295b..ca7a2e0a61 100644
--- a/examples/noisychannel/rerank_options.py
+++ b/examples/noisychannel/rerank_options.py
@@ -6,14 +6,14 @@
from fairseq import options
-def get_reranking_parser(default_task='translation'):
- parser = options.get_parser('Generation and reranking', default_task)
+def get_reranking_parser(default_task="translation"):
+ parser = options.get_parser("Generation and reranking", default_task)
add_reranking_args(parser)
return parser
-def get_tuning_parser(default_task='translation'):
- parser = options.get_parser('Reranking tuning', default_task)
+def get_tuning_parser(default_task="translation"):
+ parser = options.get_parser("Reranking tuning", default_task)
add_reranking_args(parser)
add_tuning_args(parser)
return parser
@@ -110,17 +110,40 @@ def add_reranking_args(parser):
def add_tuning_args(parser):
group = parser.add_argument_group("Tuning")
- group.add_argument('--lower-bound', default=[-0.7], nargs='+', type=float,
- help='lower bound of search space')
- group.add_argument('--upper-bound', default=[3], nargs='+', type=float,
- help='upper bound of search space')
- group.add_argument('--tune-param', default=['lenpen'], nargs='+',
- choices=['lenpen', 'weight1', 'weight2', 'weight3'],
- help='the parameter(s) to tune')
- group.add_argument('--tune-subset', default='valid', choices=['valid', 'test', 'train'],
- help='the subset to tune on ')
- group.add_argument('--num-trials', default=1000, type=int,
- help='number of trials to do for random search')
- group.add_argument('--share-weights', action='store_true',
- help='share weight2 and weight 3')
+ group.add_argument(
+ "--lower-bound",
+ default=[-0.7],
+ nargs="+",
+ type=float,
+ help="lower bound of search space",
+ )
+ group.add_argument(
+ "--upper-bound",
+ default=[3],
+ nargs="+",
+ type=float,
+ help="upper bound of search space",
+ )
+ group.add_argument(
+ "--tune-param",
+ default=["lenpen"],
+ nargs="+",
+ choices=["lenpen", "weight1", "weight2", "weight3"],
+ help="the parameter(s) to tune",
+ )
+ group.add_argument(
+ "--tune-subset",
+ default="valid",
+ choices=["valid", "test", "train"],
+ help="the subset to tune on ",
+ )
+ group.add_argument(
+ "--num-trials",
+ default=1000,
+ type=int,
+ help="number of trials to do for random search",
+ )
+ group.add_argument(
+ "--share-weights", action="store_true", help="share weight2 and weight 3"
+ )
return group
diff --git a/examples/noisychannel/rerank_score_bw.py b/examples/noisychannel/rerank_score_bw.py
index 6a875e9fe3..895673b1cc 100644
--- a/examples/noisychannel/rerank_score_bw.py
+++ b/examples/noisychannel/rerank_score_bw.py
@@ -3,8 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from contextlib import redirect_stdout
import os
+from contextlib import redirect_stdout
from fairseq import options
from fairseq_cli import generate
@@ -13,82 +13,124 @@
def score_bw(args):
- if args.backwards1:
- scorer1_src = args.target_lang
- scorer1_tgt = args.source_lang
+ if args.backwards1:
+ scorer1_src = args.target_lang
+ scorer1_tgt = args.source_lang
+ else:
+ scorer1_src = args.source_lang
+ scorer1_tgt = args.target_lang
+
+ if args.score_model2 is not None:
+ if args.backwards2:
+ scorer2_src = args.target_lang
+ scorer2_tgt = args.source_lang
else:
- scorer1_src = args.source_lang
- scorer1_tgt = args.target_lang
-
- if args.score_model2 is not None:
- if args.backwards2:
- scorer2_src = args.target_lang
- scorer2_tgt = args.source_lang
- else:
- scorer2_src = args.source_lang
- scorer2_tgt = args.target_lang
-
- rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
- rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
-
- pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
- backwards_preprocessed_dir, lm_preprocessed_dir = \
- rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
- args.gen_model_name, args.shard_id, args.num_shards,
- args.sampling, args.prefix_len, args.target_prefix_frac,
- args.source_prefix_frac)
-
- score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
- target_prefix_frac=args.target_prefix_frac,
- source_prefix_frac=args.source_prefix_frac,
- backwards=args.backwards1)
-
- if args.score_model2 is not None:
- score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
- target_prefix_frac=args.target_prefix_frac,
- source_prefix_frac=args.source_prefix_frac,
- backwards=args.backwards2)
-
- if args.right_to_left1:
- rerank_data1 = right_to_left_preprocessed_dir
- elif args.backwards1:
- rerank_data1 = backwards_preprocessed_dir
+ scorer2_src = args.source_lang
+ scorer2_tgt = args.target_lang
+
+ rerank1_is_gen = (
+ args.gen_model == args.score_model1 and args.source_prefix_frac is None
+ )
+ rerank2_is_gen = (
+ args.gen_model == args.score_model2 and args.source_prefix_frac is None
+ )
+
+ (
+ pre_gen,
+ left_to_right_preprocessed_dir,
+ right_to_left_preprocessed_dir,
+ backwards_preprocessed_dir,
+ lm_preprocessed_dir,
+ ) = rerank_utils.get_directories(
+ args.data_dir_name,
+ args.num_rescore,
+ args.gen_subset,
+ args.gen_model_name,
+ args.shard_id,
+ args.num_shards,
+ args.sampling,
+ args.prefix_len,
+ args.target_prefix_frac,
+ args.source_prefix_frac,
+ )
+
+ score1_file = rerank_utils.rescore_file_name(
+ pre_gen,
+ args.prefix_len,
+ args.model1_name,
+ target_prefix_frac=args.target_prefix_frac,
+ source_prefix_frac=args.source_prefix_frac,
+ backwards=args.backwards1,
+ )
+
+ if args.score_model2 is not None:
+ score2_file = rerank_utils.rescore_file_name(
+ pre_gen,
+ args.prefix_len,
+ args.model2_name,
+ target_prefix_frac=args.target_prefix_frac,
+ source_prefix_frac=args.source_prefix_frac,
+ backwards=args.backwards2,
+ )
+
+ if args.right_to_left1:
+ rerank_data1 = right_to_left_preprocessed_dir
+ elif args.backwards1:
+ rerank_data1 = backwards_preprocessed_dir
+ else:
+ rerank_data1 = left_to_right_preprocessed_dir
+
+ gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]
+ if not rerank1_is_gen and not os.path.isfile(score1_file):
+ print("STEP 4: score the translations for model 1")
+
+ model_param1 = [
+ "--path",
+ args.score_model1,
+ "--source-lang",
+ scorer1_src,
+ "--target-lang",
+ scorer1_tgt,
+ ]
+ gen_model1_param = [rerank_data1] + gen_param + model_param1
+
+ gen_parser = options.get_generation_parser()
+ input_args = options.parse_args_and_arch(gen_parser, gen_model1_param)
+
+ with open(score1_file, "w") as f:
+ with redirect_stdout(f):
+ generate.main(input_args)
+
+ if (
+ args.score_model2 is not None
+ and not os.path.isfile(score2_file)
+ and not rerank2_is_gen
+ ):
+ print("STEP 4: score the translations for model 2")
+
+ if args.right_to_left2:
+ rerank_data2 = right_to_left_preprocessed_dir
+ elif args.backwards2:
+ rerank_data2 = backwards_preprocessed_dir
else:
- rerank_data1 = left_to_right_preprocessed_dir
-
- gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]
- if not rerank1_is_gen and not os.path.isfile(score1_file):
- print("STEP 4: score the translations for model 1")
-
- model_param1 = ["--path", args.score_model1, "--source-lang", scorer1_src, "--target-lang", scorer1_tgt]
- gen_model1_param = [rerank_data1] + gen_param + model_param1
-
- gen_parser = options.get_generation_parser()
- input_args = options.parse_args_and_arch(gen_parser, gen_model1_param)
-
- with open(score1_file, 'w') as f:
- with redirect_stdout(f):
- generate.main(input_args)
-
- if args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen:
- print("STEP 4: score the translations for model 2")
-
- if args.right_to_left2:
- rerank_data2 = right_to_left_preprocessed_dir
- elif args.backwards2:
- rerank_data2 = backwards_preprocessed_dir
- else:
- rerank_data2 = left_to_right_preprocessed_dir
+ rerank_data2 = left_to_right_preprocessed_dir
- model_param2 = ["--path", args.score_model2, "--source-lang", scorer2_src, "--target-lang", scorer2_tgt]
- gen_model2_param = [rerank_data2] + gen_param + model_param2
+ model_param2 = [
+ "--path",
+ args.score_model2,
+ "--source-lang",
+ scorer2_src,
+ "--target-lang",
+ scorer2_tgt,
+ ]
+ gen_model2_param = [rerank_data2] + gen_param + model_param2
- gen_parser = options.get_generation_parser()
- input_args = options.parse_args_and_arch(gen_parser, gen_model2_param)
+ gen_parser = options.get_generation_parser()
+ input_args = options.parse_args_and_arch(gen_parser, gen_model2_param)
- with open(score2_file, 'w') as f:
- with redirect_stdout(f):
- generate.main(input_args)
+ with open(score2_file, "w") as f:
+ with redirect_stdout(f):
+ generate.main(input_args)
def cli_main():
@@ -97,5 +139,5 @@ def cli_main():
score_bw(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
cli_main()
diff --git a/examples/noisychannel/rerank_score_lm.py b/examples/noisychannel/rerank_score_lm.py
index 74b858e3c8..fa3aa64462 100644
--- a/examples/noisychannel/rerank_score_lm.py
+++ b/examples/noisychannel/rerank_score_lm.py
@@ -12,22 +12,38 @@
def score_lm(args):
using_nbest = args.nbest_list is not None
- pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
- backwards_preprocessed_dir, lm_preprocessed_dir = \
- rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
- args.gen_model_name, args.shard_id, args.num_shards,
- args.sampling, args.prefix_len, args.target_prefix_frac,
- args.source_prefix_frac)
-
- predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"
+ (
+ pre_gen,
+ left_to_right_preprocessed_dir,
+ right_to_left_preprocessed_dir,
+ backwards_preprocessed_dir,
+ lm_preprocessed_dir,
+ ) = rerank_utils.get_directories(
+ args.data_dir_name,
+ args.num_rescore,
+ args.gen_subset,
+ args.gen_model_name,
+ args.shard_id,
+ args.num_shards,
+ args.sampling,
+ args.prefix_len,
+ args.target_prefix_frac,
+ args.source_prefix_frac,
+ )
+
+ predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
if using_nbest:
print("Using predefined n-best list from interactive.py")
predictions_bpe_file = args.nbest_list
- gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest)
+ gen_output = rerank_utils.BitextOutputFromGen(
+ predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest
+ )
if args.language_model is not None:
- lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True)
+ lm_score_file = rerank_utils.rescore_file_name(
+ pre_gen, args.prefix_len, args.lm_name, lm_file=True
+ )
if args.language_model is not None and not os.path.isfile(lm_score_file):
print("STEP 4.5: language modeling for P(T)")
@@ -38,10 +54,21 @@ def score_lm(args):
else:
bpe_status = "different"
- rerank_utils.lm_scoring(lm_preprocessed_dir, bpe_status, gen_output, pre_gen,
- args.lm_dict, args.lm_name, args.language_model,
- args.lm_bpe_code, 128, lm_score_file, args.target_lang,
- args.source_lang, prefix_len=args.prefix_len)
+ rerank_utils.lm_scoring(
+ lm_preprocessed_dir,
+ bpe_status,
+ gen_output,
+ pre_gen,
+ args.lm_dict,
+ args.lm_name,
+ args.language_model,
+ args.lm_bpe_code,
+ 128,
+ lm_score_file,
+ args.target_lang,
+ args.source_lang,
+ prefix_len=args.prefix_len,
+ )
def cli_main():
@@ -50,5 +77,5 @@ def cli_main():
score_lm(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
cli_main()
diff --git a/examples/noisychannel/rerank_tune.py b/examples/noisychannel/rerank_tune.py
index 789096b3fa..1be71744a3 100644
--- a/examples/noisychannel/rerank_tune.py
+++ b/examples/noisychannel/rerank_tune.py
@@ -5,8 +5,8 @@
import argparse
import random
-import numpy as np
+import numpy as np
from fairseq import options
from . import rerank, rerank_options
@@ -14,7 +14,7 @@
def random_search(args):
param_values = []
- tuneable_parameters = ['lenpen', 'weight1', 'weight2', 'weight3']
+ tuneable_parameters = ["lenpen", "weight1", "weight2", "weight3"]
initial_params = [args.lenpen, args.weight1, args.weight2, args.weight3]
for i, elem in enumerate(initial_params):
if type(elem) is not list:
@@ -33,51 +33,60 @@ def random_search(args):
param_values += initial_params
random.seed(args.seed)
- random_params = np.array([
- [random.uniform(args.lower_bound[i], args.upper_bound[i]) for i in range(len(args.tune_param))]
- for k in range(args.num_trials)
- ])
- set_params = np.array([
- [initial_params[i][0] for i in range(len(tuneable_parameters))]
- for k in range(args.num_trials)
- ])
+ random_params = np.array(
+ [
+ [
+ random.uniform(args.lower_bound[i], args.upper_bound[i])
+ for i in range(len(args.tune_param))
+ ]
+ for k in range(args.num_trials)
+ ]
+ )
+ set_params = np.array(
+ [
+ [initial_params[i][0] for i in range(len(tuneable_parameters))]
+ for k in range(args.num_trials)
+ ]
+ )
random_params = np.concatenate((random_params, set_params), 1)
rerank_args = vars(args).copy()
if args.nbest_list:
- rerank_args['gen_subset'] = 'test'
+ rerank_args["gen_subset"] = "test"
else:
- rerank_args['gen_subset'] = args.tune_subset
+ rerank_args["gen_subset"] = args.tune_subset
for k in range(len(tune_parameters)):
rerank_args[tune_parameters[k]] = list(random_params[:, k])
if args.share_weights:
- k = tune_parameters.index('weight2')
- rerank_args['weight3'] = list(random_params[:, k])
+ k = tune_parameters.index("weight2")
+ rerank_args["weight3"] = list(random_params[:, k])
rerank_args = argparse.Namespace(**rerank_args)
- best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank(rerank_args)
+ best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank(
+ rerank_args
+ )
rerank_args = vars(args).copy()
- rerank_args['lenpen'] = [best_lenpen]
- rerank_args['weight1'] = [best_weight1]
- rerank_args['weight2'] = [best_weight2]
- rerank_args['weight3'] = [best_weight3]
+ rerank_args["lenpen"] = [best_lenpen]
+ rerank_args["weight1"] = [best_weight1]
+ rerank_args["weight2"] = [best_weight2]
+ rerank_args["weight3"] = [best_weight3]
# write the hypothesis from the valid set from the best trial
if args.gen_subset != "valid":
- rerank_args['gen_subset'] = "valid"
+ rerank_args["gen_subset"] = "valid"
rerank_args = argparse.Namespace(**rerank_args)
rerank.rerank(rerank_args)
# test with the best hyperparameters on gen subset
rerank_args = vars(args).copy()
- rerank_args['gen_subset'] = args.gen_subset
- rerank_args['lenpen'] = [best_lenpen]
- rerank_args['weight1'] = [best_weight1]
- rerank_args['weight2'] = [best_weight2]
- rerank_args['weight3'] = [best_weight3]
+ rerank_args["gen_subset"] = args.gen_subset
+ rerank_args["lenpen"] = [best_lenpen]
+ rerank_args["weight1"] = [best_weight1]
+ rerank_args["weight2"] = [best_weight2]
+ rerank_args["weight3"] = [best_weight3]
rerank_args = argparse.Namespace(**rerank_args)
rerank.rerank(rerank_args)
@@ -89,5 +98,5 @@ def cli_main():
random_search(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
cli_main()
diff --git a/examples/noisychannel/rerank_utils.py b/examples/noisychannel/rerank_utils.py
index e1fcf918c5..2c6bf1b1af 100644
--- a/examples/noisychannel/rerank_utils.py
+++ b/examples/noisychannel/rerank_utils.py
@@ -3,11 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from contextlib import redirect_stdout
import math
import os
import re
import subprocess
+from contextlib import redirect_stdout
from fairseq import options
from fairseq_cli import eval_lm, preprocess
@@ -20,7 +20,7 @@ def reprocess(fle):
# per source, so the values for hypothesis_dict are lists.
# parses output of generate.py
- with open(fle, 'r') as f:
+ with open(fle, "r") as f:
txt = f.read()
"""reprocess generate.py output"""
@@ -45,7 +45,9 @@ def reprocess(fle):
if line_type == "H":
h_txt = line[j:]
hypo = re.search(hp, h_txt)
- assert hypo is not None, ("regular expression failed to find the hypothesis scoring")
+ assert (
+ hypo is not None
+ ), "regular expression failed to find the hypothesis scoring"
_, i = hypo.span()
score = hypo.group()
if id_num in hypothesis_dict:
@@ -56,9 +58,9 @@ def reprocess(fle):
score_dict[id_num] = [float(score)]
elif line_type == "S":
- source_dict[id_num] = (line[j:])
+ source_dict[id_num] = line[j:]
elif line_type == "T":
- target_dict[id_num] = (line[j:])
+ target_dict[id_num] = line[j:]
elif line_type == "P":
pos_scores = (line[j:]).split()
pos_scores = [float(x) for x in pos_scores]
@@ -72,7 +74,7 @@ def reprocess(fle):
def reprocess_nbest(fle):
"""reprocess interactive.py output"""
- with open(fle, 'r') as f:
+ with open(fle, "r") as f:
txt = f.read()
source_dict = {}
@@ -82,7 +84,7 @@ def reprocess_nbest(fle):
pos_score_dict = {}
lines = txt.split("\n")
- hp = re.compile(r'[-]?\d+[.]?\d+')
+ hp = re.compile(r"[-]?\d+[.]?\d+")
j = -1
for _i, line in enumerate(lines):
@@ -119,59 +121,76 @@ def reprocess_nbest(fle):
return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
-def write_reprocessed(sources, hypos, targets, source_outfile,
- hypo_outfile, target_outfile, right_to_left=False,
- prefix_len=None, bpe_symbol=None,
- target_prefix_frac=None, source_prefix_frac=None):
+def write_reprocessed(
+ sources,
+ hypos,
+ targets,
+ source_outfile,
+ hypo_outfile,
+ target_outfile,
+ right_to_left=False,
+ prefix_len=None,
+ bpe_symbol=None,
+ target_prefix_frac=None,
+ source_prefix_frac=None,
+):
"""writes nbest hypothesis for rescoring"""
- assert not (prefix_len is not None and target_prefix_frac is not None), \
- "in writing reprocessed, only one type of prefix may be used"
- assert not (prefix_len is not None and source_prefix_frac is not None), \
- "in writing reprocessed, only one type of prefix may be used"
- assert not (target_prefix_frac is not None and source_prefix_frac is not None), \
- "in writing reprocessed, only one type of prefix may be used"
-
- with open(source_outfile, 'w') as source_file, \
- open(hypo_outfile, 'w') as hypo_file, \
- open(target_outfile, 'w') as target_file:
+ assert not (
+ prefix_len is not None and target_prefix_frac is not None
+ ), "in writing reprocessed, only one type of prefix may be used"
+ assert not (
+ prefix_len is not None and source_prefix_frac is not None
+ ), "in writing reprocessed, only one type of prefix may be used"
+ assert not (
+ target_prefix_frac is not None and source_prefix_frac is not None
+ ), "in writing reprocessed, only one type of prefix may be used"
+
+ with open(source_outfile, "w") as source_file, open(
+ hypo_outfile, "w"
+ ) as hypo_file, open(target_outfile, "w") as target_file:
assert len(sources) == len(hypos), "sources and hypos list length mismatch"
if right_to_left:
for i in range(len(sources)):
- for j in range(len(hypos[i])):
- if prefix_len is None:
- hypo_file.write(make_right_to_left(hypos[i][j])+"\n")
- else:
- raise NotImplementedError()
- source_file.write(make_right_to_left(sources[i])+"\n")
- target_file.write(make_right_to_left(targets[i])+"\n")
+ for j in range(len(hypos[i])):
+ if prefix_len is None:
+ hypo_file.write(make_right_to_left(hypos[i][j]) + "\n")
+ else:
+ raise NotImplementedError()
+ source_file.write(make_right_to_left(sources[i]) + "\n")
+ target_file.write(make_right_to_left(targets[i]) + "\n")
else:
for i in sorted(sources.keys()):
- for j in range(len(hypos[i])):
- if prefix_len is not None:
- shortened = get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len)+"\n"
- hypo_file.write(shortened)
- source_file.write(sources[i])
- target_file.write(targets[i])
- elif target_prefix_frac is not None:
- num_words, shortened, num_bpe_tokens = \
- calc_length_from_frac(hypos[i][j], target_prefix_frac, bpe_symbol)
- shortened += "\n"
- hypo_file.write(shortened)
- source_file.write(sources[i])
- target_file.write(targets[i])
- elif source_prefix_frac is not None:
- num_words, shortened, num_bpe_tokensn = \
- calc_length_from_frac(sources[i], source_prefix_frac, bpe_symbol)
- shortened += "\n"
- hypo_file.write(hypos[i][j])
- source_file.write(shortened)
- target_file.write(targets[i])
- else:
- hypo_file.write(hypos[i][j])
- source_file.write(sources[i])
- target_file.write(targets[i])
+ for j in range(len(hypos[i])):
+ if prefix_len is not None:
+ shortened = (
+ get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len)
+ + "\n"
+ )
+ hypo_file.write(shortened)
+ source_file.write(sources[i])
+ target_file.write(targets[i])
+ elif target_prefix_frac is not None:
+ num_words, shortened, num_bpe_tokens = calc_length_from_frac(
+ hypos[i][j], target_prefix_frac, bpe_symbol
+ )
+ shortened += "\n"
+ hypo_file.write(shortened)
+ source_file.write(sources[i])
+ target_file.write(targets[i])
+ elif source_prefix_frac is not None:
+ num_words, shortened, num_bpe_tokensn = calc_length_from_frac(
+ sources[i], source_prefix_frac, bpe_symbol
+ )
+ shortened += "\n"
+ hypo_file.write(hypos[i][j])
+ source_file.write(shortened)
+ target_file.write(targets[i])
+ else:
+ hypo_file.write(hypos[i][j])
+ source_file.write(sources[i])
+ target_file.write(targets[i])
def calc_length_from_frac(bpe_sentence, prefix_frac, bpe_symbol):
@@ -207,7 +226,9 @@ def get_prefix_from_len(sentence, bpe_symbol, prefix_len):
if bpe_count == 0:
return sentence[:prefix_len]
else:
- return sentence[:prefix_len]+get_prefix_from_len(sentence[prefix_len:], bpe_symbol, bpe_count)
+ return sentence[:prefix_len] + get_prefix_from_len(
+ sentence[prefix_len:], bpe_symbol, bpe_count
+ )
def get_num_bpe_tokens_from_len(sentence, bpe_symbol, prefix_len):
@@ -225,9 +246,9 @@ def make_right_to_left(line):
def remove_bpe(line, bpe_symbol):
- line = line.replace("\n", '')
- line = (line + ' ').replace(bpe_symbol, '').rstrip()
- return line+("\n")
+ line = line.replace("\n", "")
+ line = (line + " ").replace(bpe_symbol, "").rstrip()
+ return line + ("\n")
def remove_bpe_dict(pred_dict, bpe_symbol):
@@ -242,7 +263,7 @@ def remove_bpe_dict(pred_dict, bpe_symbol):
def parse_bleu_scoring(line):
- p = re.compile(r'(BLEU4 = )\d+[.]\d+')
+ p = re.compile(r"(BLEU4 = )\d+[.]\d+")
res = re.search(p, line)
assert res is not None, line
return float(res.group()[8:])
@@ -259,9 +280,21 @@ def get_full_from_prefix(hypo_prefix, hypos):
raise Exception()
-def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=None,
- lenpen=None, src_len=None, tgt_len=None, bitext1_backwards=False,
- bitext2_backwards=False, normalize=False):
+def get_score(
+ a,
+ b,
+ c,
+ target_len,
+ bitext_score1,
+ bitext_score2=None,
+ lm_score=None,
+ lenpen=None,
+ src_len=None,
+ tgt_len=None,
+ bitext1_backwards=False,
+ bitext2_backwards=False,
+ normalize=False,
+):
if bitext1_backwards:
bitext1_norm = src_len
else:
@@ -275,9 +308,13 @@ def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=N
bitext2_norm = 1
bitext_score2 = 0
if normalize:
- score = a*bitext_score1/bitext1_norm + b*bitext_score2/bitext2_norm+c*lm_score/src_len
+ score = (
+ a * bitext_score1 / bitext1_norm
+ + b * bitext_score2 / bitext2_norm
+ + c * lm_score / src_len
+ )
else:
- score = a*bitext_score1 + b*bitext_score2+c*lm_score
+ score = a * bitext_score1 + b * bitext_score2 + c * lm_score
if lenpen is not None:
score /= (target_len) ** float(lenpen)
@@ -286,8 +323,16 @@ def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=N
class BitextOutput(object):
- def __init__(self, output_file, backwards, right_to_left, bpe_symbol,
- prefix_len=None, target_prefix_frac=None, source_prefix_frac=None):
+ def __init__(
+ self,
+ output_file,
+ backwards,
+ right_to_left,
+ bpe_symbol,
+ prefix_len=None,
+ target_prefix_frac=None,
+ source_prefix_frac=None,
+ ):
"""process output from rescoring"""
source, hypo, score, target, pos_score = reprocess(output_file)
if backwards:
@@ -296,7 +341,9 @@ def __init__(self, output_file, backwards, right_to_left, bpe_symbol,
self.hypo_fracs = target_prefix_frac
# remove length penalty so we can use raw scores
- score, num_bpe_tokens = get_score_from_pos(pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards)
+ score, num_bpe_tokens = get_score_from_pos(
+ pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards
+ )
source_lengths = {}
target_lengths = {}
@@ -341,7 +388,9 @@ def __init__(self, output_file, backwards, right_to_left, bpe_symbol,
score[i] = float(score[i][0])
pos_score[i] = pos_score[i][0]
else:
- assert len(hypo[i]) == 1, "expected only one hypothesis per source sentence"
+ assert (
+ len(hypo[i]) == 1
+ ), "expected only one hypothesis per source sentence"
source[i] = remove_bpe(source[i], bpe_symbol)
target[i] = remove_bpe(target[i], bpe_symbol)
hypo[i] = remove_bpe(hypo[i][0], bpe_symbol)
@@ -360,11 +409,26 @@ def __init__(self, output_file, backwards, right_to_left, bpe_symbol,
class BitextOutputFromGen(object):
- def __init__(self, predictions_bpe_file, bpe_symbol=None, nbest=False, prefix_len=None, target_prefix_frac=None):
+ def __init__(
+ self,
+ predictions_bpe_file,
+ bpe_symbol=None,
+ nbest=False,
+ prefix_len=None,
+ target_prefix_frac=None,
+ ):
if nbest:
- pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess_nbest(predictions_bpe_file)
+ (
+ pred_source,
+ pred_hypo,
+ pred_score,
+ pred_target,
+ pred_pos_score,
+ ) = reprocess_nbest(predictions_bpe_file)
else:
- pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess(predictions_bpe_file)
+ pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess(
+ predictions_bpe_file
+ )
assert len(pred_source) == len(pred_hypo)
assert len(pred_source) == len(pred_score)
@@ -372,8 +436,9 @@ def __init__(self, predictions_bpe_file, bpe_symbol=None, nbest=False, prefix_le
assert len(pred_source) == len(pred_pos_score)
# remove length penalty so we can use raw scores
- pred_score, num_bpe_tokens = get_score_from_pos(pred_pos_score, prefix_len, pred_hypo,
- bpe_symbol, target_prefix_frac, False)
+ pred_score, num_bpe_tokens = get_score_from_pos(
+ pred_pos_score, prefix_len, pred_hypo, bpe_symbol, target_prefix_frac, False
+ )
self.source = pred_source
self.target = pred_target
@@ -414,7 +479,9 @@ def __init__(self, predictions_bpe_file, bpe_symbol=None, nbest=False, prefix_le
index += 1
-def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards):
+def get_score_from_pos(
+ pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards
+):
score_dict = {}
num_bpe_tokens_dict = {}
assert prefix_len is None or hypo_frac is None
@@ -423,11 +490,15 @@ def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_f
num_bpe_tokens_dict[key] = []
for i in range(len(pos_score_dict[key])):
if prefix_len is not None and not backwards:
- num_bpe_tokens = get_num_bpe_tokens_from_len(hypo_dict[key][i], bpe_symbol, prefix_len)
+ num_bpe_tokens = get_num_bpe_tokens_from_len(
+ hypo_dict[key][i], bpe_symbol, prefix_len
+ )
score_dict[key].append(sum(pos_score_dict[key][i][:num_bpe_tokens]))
num_bpe_tokens_dict[key].append(num_bpe_tokens)
elif hypo_frac is not None:
- num_words, shortened, hypo_prefix_len = calc_length_from_frac(hypo_dict[key][i], hypo_frac, bpe_symbol)
+ num_words, shortened, hypo_prefix_len = calc_length_from_frac(
+ hypo_dict[key][i], hypo_frac, bpe_symbol
+ )
score_dict[key].append(sum(pos_score_dict[key][i][:hypo_prefix_len]))
num_bpe_tokens_dict[key].append(hypo_prefix_len)
else:
@@ -437,10 +508,26 @@ def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_f
class LMOutput(object):
- def __init__(self, lm_score_file, lm_dict=None, prefix_len=None, bpe_symbol=None, target_prefix_frac=None):
- lm_sentences, lm_sen_scores, lm_sen_pos_scores, lm_no_bpe_sentences, lm_bpe_tokens = \
- parse_lm(lm_score_file, prefix_len=prefix_len,
- bpe_symbol=bpe_symbol, target_prefix_frac=target_prefix_frac)
+ def __init__(
+ self,
+ lm_score_file,
+ lm_dict=None,
+ prefix_len=None,
+ bpe_symbol=None,
+ target_prefix_frac=None,
+ ):
+ (
+ lm_sentences,
+ lm_sen_scores,
+ lm_sen_pos_scores,
+ lm_no_bpe_sentences,
+ lm_bpe_tokens,
+ ) = parse_lm(
+ lm_score_file,
+ prefix_len=prefix_len,
+ bpe_symbol=bpe_symbol,
+ target_prefix_frac=target_prefix_frac,
+ )
self.sentences = lm_sentences
self.score = lm_sen_scores
@@ -452,7 +539,7 @@ def __init__(self, lm_score_file, lm_dict=None, prefix_len=None, bpe_symbol=None
def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=None):
"""parse output of eval_lm"""
- with open(input_file, 'r') as f:
+ with open(input_file, "r") as f:
text = f.readlines()
text = text[7:]
cleaned_text = text[:-2]
@@ -467,20 +554,23 @@ def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=No
if tokens[0].isdigit():
line_id = int(tokens[0])
scores = [float(x[1:-1]) for x in tokens[2::2]]
- sentences[line_id] = " ".join(tokens[1::2][:-1])+"\n"
+ sentences[line_id] = " ".join(tokens[1::2][:-1]) + "\n"
if bpe_symbol is not None:
# exclude symbol to match output from generate.py
- bpe_sen = " ".join(tokens[1::2][:-1])+"\n"
+ bpe_sen = " ".join(tokens[1::2][:-1]) + "\n"
no_bpe_sen = remove_bpe(bpe_sen, bpe_symbol)
no_bpe_sentences[line_id] = no_bpe_sen
if prefix_len is not None:
- num_bpe_tokens = get_num_bpe_tokens_from_len(bpe_sen, bpe_symbol, prefix_len)
+ num_bpe_tokens = get_num_bpe_tokens_from_len(
+ bpe_sen, bpe_symbol, prefix_len
+ )
sen_scores[line_id] = sum(scores[:num_bpe_tokens])
num_bpe_tokens_dict[line_id] = num_bpe_tokens
elif target_prefix_frac is not None:
- num_words, shortened, target_prefix_len = calc_length_from_frac(bpe_sen, target_prefix_frac,
- bpe_symbol)
+ num_words, shortened, target_prefix_len = calc_length_from_frac(
+ bpe_sen, target_prefix_frac, bpe_symbol
+ )
sen_scores[line_id] = sum(scores[:target_prefix_len])
num_bpe_tokens_dict[line_id] = target_prefix_len
else:
@@ -492,160 +582,269 @@ def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=No
return sentences, sen_scores, sen_pos_scores, no_bpe_sentences, num_bpe_tokens_dict
-def get_directories(data_dir_name, num_rescore, gen_subset,
- fw_name, shard_id, num_shards,
- sampling=False, prefix_len=None,
- target_prefix_frac=None, source_prefix_frac=None):
- nbest_file_id = "nbest_" + str(num_rescore) + \
- "_subset_" + gen_subset + \
- "_fw_name_" + fw_name + \
- "_shard_" + str(shard_id) + \
- "_of_" + str(num_shards)
+def get_directories(
+ data_dir_name,
+ num_rescore,
+ gen_subset,
+ fw_name,
+ shard_id,
+ num_shards,
+ sampling=False,
+ prefix_len=None,
+ target_prefix_frac=None,
+ source_prefix_frac=None,
+):
+ nbest_file_id = (
+ "nbest_"
+ + str(num_rescore)
+ + "_subset_"
+ + gen_subset
+ + "_fw_name_"
+ + fw_name
+ + "_shard_"
+ + str(shard_id)
+ + "_of_"
+ + str(num_shards)
+ )
if sampling:
nbest_file_id += "_sampling"
# the directory containing all information for this nbest list
- pre_gen = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+data_dir_name+"/"+nbest_file_id
+ pre_gen = (
+ os.path.join(os.path.dirname(__file__))
+ + "/rerank_data/"
+ + data_dir_name
+ + "/"
+ + nbest_file_id
+ )
# the directory to store the preprocessed nbest list, for left to right rescoring
- left_to_right_preprocessed_dir = pre_gen+"/left_to_right_preprocessed"
+ left_to_right_preprocessed_dir = pre_gen + "/left_to_right_preprocessed"
if source_prefix_frac is not None:
- left_to_right_preprocessed_dir = left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac)
+ left_to_right_preprocessed_dir = (
+ left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac)
+ )
# the directory to store the preprocessed nbest list, for right to left rescoring
- right_to_left_preprocessed_dir = pre_gen+"/right_to_left_preprocessed"
+ right_to_left_preprocessed_dir = pre_gen + "/right_to_left_preprocessed"
# the directory to store the preprocessed nbest list, for backwards rescoring
- backwards_preprocessed_dir = pre_gen+"/backwards"
+ backwards_preprocessed_dir = pre_gen + "/backwards"
if target_prefix_frac is not None:
- backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_frac"+str(target_prefix_frac)
+ backwards_preprocessed_dir = (
+ backwards_preprocessed_dir + "/prefix_frac" + str(target_prefix_frac)
+ )
elif prefix_len is not None:
- backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_"+str(prefix_len)
+ backwards_preprocessed_dir = (
+ backwards_preprocessed_dir + "/prefix_" + str(prefix_len)
+ )
# the directory to store the preprocessed nbest list, for rescoring with P(T)
- lm_preprocessed_dir = pre_gen+"/lm_preprocessed"
-
- return pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
- backwards_preprocessed_dir, lm_preprocessed_dir
-
-
-def lm_scoring(preprocess_directory, bpe_status, gen_output, pre_gen,
- cur_lm_dict, cur_lm_name, cur_language_model, cur_lm_bpe_code,
- batch_size, lm_score_file, target_lang, source_lang, prefix_len=None):
+ lm_preprocessed_dir = pre_gen + "/lm_preprocessed"
+
+ return (
+ pre_gen,
+ left_to_right_preprocessed_dir,
+ right_to_left_preprocessed_dir,
+ backwards_preprocessed_dir,
+ lm_preprocessed_dir,
+ )
+
+
+def lm_scoring(
+ preprocess_directory,
+ bpe_status,
+ gen_output,
+ pre_gen,
+ cur_lm_dict,
+ cur_lm_name,
+ cur_language_model,
+ cur_lm_bpe_code,
+ batch_size,
+ lm_score_file,
+ target_lang,
+ source_lang,
+ prefix_len=None,
+):
if prefix_len is not None:
- assert bpe_status == "different", "bpe status must be different to use prefix len"
+ assert (
+ bpe_status == "different"
+ ), "bpe status must be different to use prefix len"
if bpe_status == "no bpe":
# run lm on output without bpe
- write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
- gen_output.no_bpe_target, pre_gen+"/rescore_data_no_bpe.de",
- pre_gen+"/rescore_data_no_bpe.en", pre_gen+"/reference_file_no_bpe")
-
- preprocess_lm_param = ["--only-source",
- "--trainpref", pre_gen+"/rescore_data_no_bpe."+target_lang,
- "--srcdict", cur_lm_dict,
- "--destdir", preprocess_directory]
+ write_reprocessed(
+ gen_output.no_bpe_source,
+ gen_output.no_bpe_hypo,
+ gen_output.no_bpe_target,
+ pre_gen + "/rescore_data_no_bpe.de",
+ pre_gen + "/rescore_data_no_bpe.en",
+ pre_gen + "/reference_file_no_bpe",
+ )
+
+ preprocess_lm_param = [
+ "--only-source",
+ "--trainpref",
+ pre_gen + "/rescore_data_no_bpe." + target_lang,
+ "--srcdict",
+ cur_lm_dict,
+ "--destdir",
+ preprocess_directory,
+ ]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_lm_param)
preprocess.main(input_args)
- eval_lm_param = [preprocess_directory,
- "--path", cur_language_model,
- "--output-word-probs",
- "--batch-size", str(batch_size),
- "--max-tokens", "1024",
- "--sample-break-mode", "eos",
- "--gen-subset", "train"]
+ eval_lm_param = [
+ preprocess_directory,
+ "--path",
+ cur_language_model,
+ "--output-word-probs",
+ "--batch-size",
+ str(batch_size),
+ "--max-tokens",
+ "1024",
+ "--sample-break-mode",
+ "eos",
+ "--gen-subset",
+ "train",
+ ]
eval_lm_parser = options.get_eval_lm_parser()
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
- with open(lm_score_file, 'w') as f:
+ with open(lm_score_file, "w") as f:
with redirect_stdout(f):
eval_lm.main(input_args)
elif bpe_status == "shared":
- preprocess_lm_param = ["--only-source",
- "--trainpref", pre_gen+"/rescore_data."+target_lang,
- "--srcdict", cur_lm_dict,
- "--destdir", preprocess_directory]
- preprocess_parser = options.get_preprocessing_parser()
- input_args = preprocess_parser.parse_args(preprocess_lm_param)
- preprocess.main(input_args)
-
- eval_lm_param = [preprocess_directory,
- "--path", cur_language_model,
- "--output-word-probs",
- "--batch-size", str(batch_size),
- "--sample-break-mode", "eos",
- "--gen-subset", "train"]
-
- eval_lm_parser = options.get_eval_lm_parser()
- input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
-
- with open(lm_score_file, 'w') as f:
- with redirect_stdout(f):
- eval_lm.main(input_args)
+ preprocess_lm_param = [
+ "--only-source",
+ "--trainpref",
+ pre_gen + "/rescore_data." + target_lang,
+ "--srcdict",
+ cur_lm_dict,
+ "--destdir",
+ preprocess_directory,
+ ]
+ preprocess_parser = options.get_preprocessing_parser()
+ input_args = preprocess_parser.parse_args(preprocess_lm_param)
+ preprocess.main(input_args)
+
+ eval_lm_param = [
+ preprocess_directory,
+ "--path",
+ cur_language_model,
+ "--output-word-probs",
+ "--batch-size",
+ str(batch_size),
+ "--sample-break-mode",
+ "eos",
+ "--gen-subset",
+ "train",
+ ]
+
+ eval_lm_parser = options.get_eval_lm_parser()
+ input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
+
+ with open(lm_score_file, "w") as f:
+ with redirect_stdout(f):
+ eval_lm.main(input_args)
elif bpe_status == "different":
- rescore_file = pre_gen+"/rescore_data_no_bpe"
- rescore_bpe = pre_gen+"/rescore_data_new_bpe"
+ rescore_file = pre_gen + "/rescore_data_no_bpe"
+ rescore_bpe = pre_gen + "/rescore_data_new_bpe"
rescore_file += "."
rescore_bpe += "."
- write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
- gen_output.no_bpe_target, rescore_file+source_lang,
- rescore_file+target_lang, pre_gen+"/reference_file_no_bpe",
- bpe_symbol=None)
+ write_reprocessed(
+ gen_output.no_bpe_source,
+ gen_output.no_bpe_hypo,
+ gen_output.no_bpe_target,
+ rescore_file + source_lang,
+ rescore_file + target_lang,
+ pre_gen + "/reference_file_no_bpe",
+ bpe_symbol=None,
+ )
# apply LM bpe to nbest list
- bpe_src_param = ["-c", cur_lm_bpe_code,
- "--input", rescore_file+target_lang,
- "--output", rescore_bpe+target_lang]
- subprocess.call(["python",
- os.path.join(os.path.dirname(__file__),
- "subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param,
- shell=False)
+ bpe_src_param = [
+ "-c",
+ cur_lm_bpe_code,
+ "--input",
+ rescore_file + target_lang,
+ "--output",
+ rescore_bpe + target_lang,
+ ]
+ subprocess.call(
+ [
+ "python",
+ os.path.join(
+ os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
+ ),
+ ]
+ + bpe_src_param,
+ shell=False,
+ )
# uncomment to use fastbpe instead of subword-nmt bpe
# bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code]
# subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False)
preprocess_dir = preprocess_directory
- preprocess_lm_param = ["--only-source",
- "--trainpref", rescore_bpe+target_lang,
- "--srcdict", cur_lm_dict,
- "--destdir", preprocess_dir]
+ preprocess_lm_param = [
+ "--only-source",
+ "--trainpref",
+ rescore_bpe + target_lang,
+ "--srcdict",
+ cur_lm_dict,
+ "--destdir",
+ preprocess_dir,
+ ]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_lm_param)
preprocess.main(input_args)
- eval_lm_param = [preprocess_dir,
- "--path", cur_language_model,
- "--output-word-probs",
- "--batch-size", str(batch_size),
- "--max-tokens", "1024",
- "--sample-break-mode", "eos",
- "--gen-subset", "train"]
+ eval_lm_param = [
+ preprocess_dir,
+ "--path",
+ cur_language_model,
+ "--output-word-probs",
+ "--batch-size",
+ str(batch_size),
+ "--max-tokens",
+ "1024",
+ "--sample-break-mode",
+ "eos",
+ "--gen-subset",
+ "train",
+ ]
eval_lm_parser = options.get_eval_lm_parser()
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
- with open(lm_score_file, 'w') as f:
+ with open(lm_score_file, "w") as f:
with redirect_stdout(f):
eval_lm.main(input_args)
-def rescore_file_name(nbest_dir, prefix_len, scorer_name, lm_file=False,
- target_prefix_frac=None, source_prefix_frac=None, backwards=None):
+def rescore_file_name(
+ nbest_dir,
+ prefix_len,
+ scorer_name,
+ lm_file=False,
+ target_prefix_frac=None,
+ source_prefix_frac=None,
+ backwards=None,
+):
if lm_file:
- score_file = nbest_dir+"/lm_score_translations_model_"+scorer_name+".txt"
+ score_file = nbest_dir + "/lm_score_translations_model_" + scorer_name + ".txt"
else:
- score_file = nbest_dir+"/"+scorer_name+"_score_translations.txt"
+ score_file = nbest_dir + "/" + scorer_name + "_score_translations.txt"
if backwards:
if prefix_len is not None:
- score_file += "prefix_len"+str(prefix_len)
+ score_file += "prefix_len" + str(prefix_len)
elif target_prefix_frac is not None:
- score_file += "target_prefix_frac"+str(target_prefix_frac)
+ score_file += "target_prefix_frac" + str(target_prefix_frac)
else:
if source_prefix_frac is not None:
- score_file += "source_prefix_frac"+str(source_prefix_frac)
+ score_file += "source_prefix_frac" + str(source_prefix_frac)
return score_file
diff --git a/examples/paraphraser/paraphrase.py b/examples/paraphraser/paraphrase.py
index 405df296b3..d3422fb3db 100644
--- a/examples/paraphraser/paraphrase.py
+++ b/examples/paraphraser/paraphrase.py
@@ -13,57 +13,66 @@
def main():
- parser = argparse.ArgumentParser(description='')
- parser.add_argument('--en2fr', required=True,
- help='path to en2fr model')
- parser.add_argument('--fr2en', required=True,
- help='path to fr2en mixture of experts model')
- parser.add_argument('--user-dir',
- help='path to fairseq examples/translation_moe/src directory')
- parser.add_argument('--num-experts', type=int, default=10,
- help='(keep at 10 unless using a different model)')
- parser.add_argument('files', nargs='*', default=['-'],
- help='input files to paraphrase; "-" for stdin')
+ parser = argparse.ArgumentParser(description="")
+ parser.add_argument("--en2fr", required=True, help="path to en2fr model")
+ parser.add_argument(
+ "--fr2en", required=True, help="path to fr2en mixture of experts model"
+ )
+ parser.add_argument(
+ "--user-dir", help="path to fairseq examples/translation_moe/src directory"
+ )
+ parser.add_argument(
+ "--num-experts",
+ type=int,
+ default=10,
+ help="(keep at 10 unless using a different model)",
+ )
+ parser.add_argument(
+ "files",
+ nargs="*",
+ default=["-"],
+ help='input files to paraphrase; "-" for stdin',
+ )
args = parser.parse_args()
if args.user_dir is None:
args.user_dir = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # examples/
- 'translation_moe',
- 'src',
+ "translation_moe",
+ "src",
)
if os.path.exists(args.user_dir):
- logging.info('found user_dir:' + args.user_dir)
+ logging.info("found user_dir:" + args.user_dir)
else:
raise RuntimeError(
- 'cannot find fairseq examples/translation_moe/src '
- '(tried looking here: {})'.format(args.user_dir)
+ "cannot find fairseq examples/translation_moe/src "
+ "(tried looking here: {})".format(args.user_dir)
)
- logging.info('loading en2fr model from:' + args.en2fr)
+ logging.info("loading en2fr model from:" + args.en2fr)
en2fr = TransformerModel.from_pretrained(
model_name_or_path=args.en2fr,
- tokenizer='moses',
- bpe='sentencepiece',
+ tokenizer="moses",
+ bpe="sentencepiece",
).eval()
- logging.info('loading fr2en model from:' + args.fr2en)
+ logging.info("loading fr2en model from:" + args.fr2en)
fr2en = TransformerModel.from_pretrained(
model_name_or_path=args.fr2en,
- tokenizer='moses',
- bpe='sentencepiece',
+ tokenizer="moses",
+ bpe="sentencepiece",
user_dir=args.user_dir,
- task='translation_moe',
+ task="translation_moe",
).eval()
def gen_paraphrases(en):
fr = en2fr.translate(en)
return [
- fr2en.translate(fr, inference_step_args={'expert': i})
+ fr2en.translate(fr, inference_step_args={"expert": i})
for i in range(args.num_experts)
]
- logging.info('Type the input sentence and press return:')
+ logging.info("Type the input sentence and press return:")
for line in fileinput.input(args.files):
line = line.strip()
if len(line) == 0:
@@ -72,5 +81,5 @@ def gen_paraphrases(en):
print(paraphrase)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/pointer_generator/postprocess.py b/examples/pointer_generator/postprocess.py
index a01434b5ce..b213aed80f 100755
--- a/examples/pointer_generator/postprocess.py
+++ b/examples/pointer_generator/postprocess.py
@@ -4,9 +4,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import sys
-import re
import argparse
+import re
+import sys
class OOVIndexError(IndexError):
@@ -25,8 +25,8 @@ def __init__(self, pos, source_seq, target_seq):
def replace_oovs(source_in, target_in, target_out):
"""Replaces tokens in the target text with the corresponding word in
- the source text.
- """
+ the source text.
+ """
oov_re = re.compile("^$")
diff --git a/examples/pointer_generator/preprocess.py b/examples/pointer_generator/preprocess.py
index 4b7a5ab9c5..f72ca7d3d9 100755
--- a/examples/pointer_generator/preprocess.py
+++ b/examples/pointer_generator/preprocess.py
@@ -10,8 +10,8 @@
def replace_oovs(source_in, target_in, vocabulary, source_out, target_out):
"""Replaces out-of-vocabulary words in source and target text with ,
- where N in is the position of the word in the source sequence.
- """
+ where N in is the position of the word in the source sequence.
+ """
def format_unk(pos):
return "".format(pos)
diff --git a/examples/pointer_generator/src/transformer_pg.py b/examples/pointer_generator/src/transformer_pg.py
index af933b3495..079fdda581 100644
--- a/examples/pointer_generator/src/transformer_pg.py
+++ b/examples/pointer_generator/src/transformer_pg.py
@@ -8,19 +8,17 @@
import torch
import torch.nn as nn
-
-from fairseq import utils, metrics
+from fairseq import metrics, utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import (
- TransformerModel,
+ DEFAULT_MAX_SOURCE_POSITIONS,
+ DEFAULT_MAX_TARGET_POSITIONS,
TransformerDecoder,
TransformerEncoder,
+ TransformerModel,
base_architecture,
- DEFAULT_MAX_SOURCE_POSITIONS,
- DEFAULT_MAX_TARGET_POSITIONS,
)
-
from torch import Tensor
diff --git a/examples/roberta/commonsense_qa/commonsense_qa_task.py b/examples/roberta/commonsense_qa/commonsense_qa_task.py
index 7ed2bc36a4..216093f708 100644
--- a/examples/roberta/commonsense_qa/commonsense_qa_task.py
+++ b/examples/roberta/commonsense_qa/commonsense_qa_task.py
@@ -8,40 +8,44 @@
import numpy as np
import torch
-
from fairseq.data import (
- data_utils,
Dictionary,
- encoders,
IdDataset,
ListDataset,
NestedDictionaryDataset,
- NumSamplesDataset,
NumelDataset,
+ NumSamplesDataset,
RawLabelDataset,
RightPadDataset,
SortDataset,
+ data_utils,
+ encoders,
)
-from fairseq.tasks import register_task, LegacyFairseqTask
+from fairseq.tasks import LegacyFairseqTask, register_task
-@register_task('commonsense_qa')
+@register_task("commonsense_qa")
class CommonsenseQATask(LegacyFairseqTask):
"""Task to finetune RoBERTa for Commonsense QA."""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
- parser.add_argument('data', metavar='DIR',
- help='path to data directory; we load .jsonl')
- parser.add_argument('--init-token', type=int, default=None,
- help='add token at the beginning of each batch item')
- parser.add_argument('--num-classes', type=int, default=5)
+ parser.add_argument(
+ "data", metavar="DIR", help="path to data directory; we load .jsonl"
+ )
+ parser.add_argument(
+ "--init-token",
+ type=int,
+ default=None,
+ help="add token at the beginning of each batch item",
+ )
+ parser.add_argument("--num-classes", type=int, default=5)
def __init__(self, args, vocab):
super().__init__(args)
self.vocab = vocab
- self.mask = vocab.add_symbol('')
+ self.mask = vocab.add_symbol("")
self.bpe = encoders.build_bpe(args)
@@ -53,20 +57,24 @@ def load_dictionary(cls, filename):
filename (str): the filename
"""
dictionary = Dictionary.load(filename)
- dictionary.add_symbol('')
+ dictionary.add_symbol("")
return dictionary
@classmethod
def setup_task(cls, args, **kwargs):
- assert args.criterion == 'sentence_ranking', 'Must set --criterion=sentence_ranking'
+ assert (
+ args.criterion == "sentence_ranking"
+ ), "Must set --criterion=sentence_ranking"
# load data and label dictionaries
- vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
- print('| dictionary: {} types'.format(len(vocab)))
+ vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
+ print("| dictionary: {} types".format(len(vocab)))
return cls(args, vocab)
- def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
+ def load_dataset(
+ self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
+ ):
"""Load a given dataset split.
Args:
@@ -77,16 +85,18 @@ def binarize(s, append_bos=False):
if self.bpe is not None:
s = self.bpe.encode(s)
tokens = self.vocab.encode_line(
- s, append_eos=True, add_if_not_exist=False,
+ s,
+ append_eos=True,
+ add_if_not_exist=False,
).long()
if append_bos and self.args.init_token is not None:
tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
return tokens
if data_path is None:
- data_path = os.path.join(self.args.data, split + '.jsonl')
+ data_path = os.path.join(self.args.data, split + ".jsonl")
if not os.path.exists(data_path):
- raise FileNotFoundError('Cannot find data: {}'.format(data_path))
+ raise FileNotFoundError("Cannot find data: {}".format(data_path))
src_tokens = [[] for i in range(self.args.num_classes)]
src_lengths = [[] for i in range(self.args.num_classes)]
@@ -95,20 +105,23 @@ def binarize(s, append_bos=False):
with open(data_path) as h:
for line in h:
example = json.loads(line.strip())
- if 'answerKey' in example:
- label = ord(example['answerKey']) - ord('A')
+ if "answerKey" in example:
+ label = ord(example["answerKey"]) - ord("A")
labels.append(label)
- question = example['question']['stem']
- assert len(example['question']['choices']) == self.args.num_classes
+ question = example["question"]["stem"]
+ assert len(example["question"]["choices"]) == self.args.num_classes
# format: ` Q: Where would I not want a fox? A: hen house `
- question = 'Q: ' + question
+ question = "Q: " + question
question_toks = binarize(question, append_bos=True)
- for i, choice in enumerate(example['question']['choices']):
- src = 'A: ' + choice['text']
+ for i, choice in enumerate(example["question"]["choices"]):
+ src = "A: " + choice["text"]
src_bin = torch.cat([question_toks, binarize(src)])
src_tokens[i].append(src_bin)
src_lengths[i].append(len(src_bin))
- assert all(len(src_tokens[0]) == len(src_tokens[i]) for i in range(self.args.num_classes))
+ assert all(
+ len(src_tokens[0]) == len(src_tokens[i])
+ for i in range(self.args.num_classes)
+ )
assert len(src_tokens[0]) == len(src_lengths[0])
assert len(labels) == 0 or len(labels) == len(src_tokens[0])
@@ -118,24 +131,26 @@ def binarize(s, append_bos=False):
src_lengths[i] = ListDataset(src_lengths[i])
dataset = {
- 'id': IdDataset(),
- 'nsentences': NumSamplesDataset(),
- 'ntokens': NumelDataset(src_tokens[0], reduce=True),
+ "id": IdDataset(),
+ "nsentences": NumSamplesDataset(),
+ "ntokens": NumelDataset(src_tokens[0], reduce=True),
}
for i in range(self.args.num_classes):
- dataset.update({
- 'net_input{}'.format(i + 1): {
- 'src_tokens': RightPadDataset(
- src_tokens[i],
- pad_idx=self.source_dictionary.pad(),
- ),
- 'src_lengths': src_lengths[i],
+ dataset.update(
+ {
+ "net_input{}".format(i + 1): {
+ "src_tokens": RightPadDataset(
+ src_tokens[i],
+ pad_idx=self.source_dictionary.pad(),
+ ),
+ "src_lengths": src_lengths[i],
+ }
}
- })
+ )
if len(labels) > 0:
- dataset.update({'target': RawLabelDataset(labels)})
+ dataset.update({"target": RawLabelDataset(labels)})
dataset = NestedDictionaryDataset(
dataset,
@@ -149,17 +164,18 @@ def binarize(s, append_bos=False):
sort_order=[np.random.permutation(len(dataset))],
)
- print('| Loaded {} with {} samples'.format(split, len(dataset)))
+ print("| Loaded {} with {} samples".format(split, len(dataset)))
self.datasets[split] = dataset
return self.datasets[split]
def build_model(self, args):
from fairseq import models
+
model = models.build_model(args, self)
model.register_classification_head(
- 'sentence_classification_head',
+ "sentence_classification_head",
num_classes=1,
)
diff --git a/examples/roberta/multiprocessing_bpe_encoder.py b/examples/roberta/multiprocessing_bpe_encoder.py
index f0240c210f..43fe0451bf 100644
--- a/examples/roberta/multiprocessing_bpe_encoder.py
+++ b/examples/roberta/multiprocessing_bpe_encoder.py
@@ -8,7 +8,6 @@
import argparse
import contextlib
import sys
-
from collections import Counter
from multiprocessing import Pool
@@ -26,23 +25,23 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--encoder-json",
- help='path to encoder.json',
+ help="path to encoder.json",
)
parser.add_argument(
"--vocab-bpe",
type=str,
- help='path to vocab.bpe',
+ help="path to vocab.bpe",
)
parser.add_argument(
"--inputs",
nargs="+",
- default=['-'],
+ default=["-"],
help="input files to filter/encode",
)
parser.add_argument(
"--outputs",
nargs="+",
- default=['-'],
+ default=["-"],
help="path to save encoded outputs",
)
parser.add_argument(
@@ -53,18 +52,21 @@ def main():
parser.add_argument("--workers", type=int, default=20)
args = parser.parse_args()
- assert len(args.inputs) == len(args.outputs), \
- "number of input and output paths should match"
+ assert len(args.inputs) == len(
+ args.outputs
+ ), "number of input and output paths should match"
with contextlib.ExitStack() as stack:
inputs = [
stack.enter_context(open(input, "r", encoding="utf-8"))
- if input != "-" else sys.stdin
+ if input != "-"
+ else sys.stdin
for input in args.inputs
]
outputs = [
stack.enter_context(open(output, "w", encoding="utf-8"))
- if output != "-" else sys.stdout
+ if output != "-"
+ else sys.stdout
for output in args.outputs
]
@@ -87,7 +89,6 @@ def main():
class MultiprocessingEncoder(object):
-
def __init__(self, args):
self.args = args
diff --git a/examples/roberta/preprocess_RACE.py b/examples/roberta/preprocess_RACE.py
index f6f606a389..cdd6607271 100644
--- a/examples/roberta/preprocess_RACE.py
+++ b/examples/roberta/preprocess_RACE.py
@@ -25,7 +25,7 @@ def get_examples(data_dir, set_type):
examples = []
levels = ["middle", "high"]
- set_type_c = set_type.split('-')
+ set_type_c = set_type.split("-")
if len(set_type_c) == 2:
levels = [set_type_c[1]]
set_type = set_type_c[0]
@@ -33,13 +33,13 @@ def get_examples(data_dir, set_type):
cur_dir = os.path.join(data_dir, set_type, level)
for filename in os.listdir(cur_dir):
cur_path = os.path.join(cur_dir, filename)
- with open(cur_path, 'r') as f:
+ with open(cur_path, "r") as f:
cur_data = json.load(f)
answers = cur_data["answers"]
options = cur_data["options"]
questions = cur_data["questions"]
context = cur_data["article"].replace("\n", " ")
- context = re.sub(r'\s+', ' ', context)
+ context = re.sub(r"\s+", " ", context)
for i in range(len(answers)):
label = ord(answers[i]) - ord("A")
qa_list = []
@@ -50,7 +50,7 @@ def get_examples(data_dir, set_type):
qa_cat = question.replace("_", option)
else:
qa_cat = " ".join([question, option])
- qa_cat = re.sub(r'\s+', ' ', qa_cat)
+ qa_cat = re.sub(r"\s+", " ", qa_cat)
qa_list.append(qa_cat)
examples.append(InputExample(context, qa_list, label))
@@ -64,11 +64,11 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input-dir",
- help='input directory for downloaded RACE dataset',
+ help="input directory for downloaded RACE dataset",
)
parser.add_argument(
"--output-dir",
- help='output directory for extracted data',
+ help="output directory for extracted data",
)
args = parser.parse_args()
@@ -77,17 +77,20 @@ def main():
for set_type in ["train", "dev", "test-middle", "test-high"]:
examples = get_examples(args.input_dir, set_type)
- qa_file_paths = [os.path.join(args.output_dir, set_type + ".input" + str(i + 1)) for i in range(4)]
- qa_files = [open(qa_file_path, 'w') for qa_file_path in qa_file_paths]
+ qa_file_paths = [
+ os.path.join(args.output_dir, set_type + ".input" + str(i + 1))
+ for i in range(4)
+ ]
+ qa_files = [open(qa_file_path, "w") for qa_file_path in qa_file_paths]
outf_context_path = os.path.join(args.output_dir, set_type + ".input0")
outf_label_path = os.path.join(args.output_dir, set_type + ".label")
- outf_context = open(outf_context_path, 'w')
- outf_label = open(outf_label_path, 'w')
+ outf_context = open(outf_context_path, "w")
+ outf_label = open(outf_label_path, "w")
for example in examples:
- outf_context.write(example.paragraph + '\n')
+ outf_context.write(example.paragraph + "\n")
for i in range(4):
- qa_files[i].write(example.qa_list[i] + '\n')
- outf_label.write(str(example.label) + '\n')
+ qa_files[i].write(example.qa_list[i] + "\n")
+ outf_label.write(str(example.label) + "\n")
for f in qa_files:
f.close()
@@ -95,5 +98,5 @@ def main():
outf_context.close()
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/roberta/wsc/wsc_criterion.py b/examples/roberta/wsc/wsc_criterion.py
index dd909ab20c..1a5901234b 100644
--- a/examples/roberta/wsc/wsc_criterion.py
+++ b/examples/roberta/wsc/wsc_criterion.py
@@ -7,19 +7,17 @@
import torch
import torch.nn.functional as F
-
from fairseq import utils
-from fairseq.data import encoders
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
+from fairseq.data import encoders
-@register_criterion('wsc')
+@register_criterion("wsc")
class WSCCriterion(LegacyFairseqCriterion):
-
def __init__(self, args, task):
super().__init__(args, task)
if self.args.save_predictions is not None:
- self.prediction_h = open(self.args.save_predictions, 'w')
+ self.prediction_h = open(self.args.save_predictions, "w")
else:
self.prediction_h = None
self.bpe = encoders.build_bpe(args)
@@ -32,12 +30,16 @@ def __del__(self):
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
- parser.add_argument('--wsc-margin-alpha', type=float, metavar='A', default=1.0)
- parser.add_argument('--wsc-margin-beta', type=float, metavar='B', default=0.0)
- parser.add_argument('--wsc-cross-entropy', action='store_true',
- help='use cross entropy formulation instead of margin loss')
- parser.add_argument('--save-predictions', metavar='FILE',
- help='file to save predictions to')
+ parser.add_argument("--wsc-margin-alpha", type=float, metavar="A", default=1.0)
+ parser.add_argument("--wsc-margin-beta", type=float, metavar="B", default=0.0)
+ parser.add_argument(
+ "--wsc-cross-entropy",
+ action="store_true",
+ help="use cross entropy formulation instead of margin loss",
+ )
+ parser.add_argument(
+ "--save-predictions", metavar="FILE", help="file to save predictions to"
+ )
def get_masked_input(self, tokens, mask):
masked_tokens = tokens.clone()
@@ -60,27 +62,26 @@ def get_loss(self, query_lprobs, cand_lprobs):
)
else:
return (
- - query_lprobs
- + self.args.wsc_margin_alpha * (
- cand_lprobs - query_lprobs + self.args.wsc_margin_beta
- ).clamp(min=0)
+ -query_lprobs
+ + self.args.wsc_margin_alpha
+ * (cand_lprobs - query_lprobs + self.args.wsc_margin_beta).clamp(min=0)
).sum()
def forward(self, model, sample, reduce=True):
# compute loss and accuracy
- loss, nloss = 0., 0
+ loss, nloss = 0.0, 0
ncorrect, nqueries = 0, 0
- for i, label in enumerate(sample['labels']):
+ for i, label in enumerate(sample["labels"]):
query_lprobs = self.get_lprobs(
model,
- sample['query_tokens'][i].unsqueeze(0),
- sample['query_masks'][i].unsqueeze(0),
+ sample["query_tokens"][i].unsqueeze(0),
+ sample["query_masks"][i].unsqueeze(0),
)
cand_lprobs = self.get_lprobs(
model,
- sample['candidate_tokens'][i],
- sample['candidate_masks'][i],
+ sample["candidate_tokens"][i],
+ sample["candidate_masks"][i],
)
pred = (query_lprobs >= cand_lprobs).all().item()
@@ -95,72 +96,72 @@ def forward(self, model, sample, reduce=True):
nloss += 1
loss += self.get_loss(query_lprobs, cand_lprobs)
- id = sample['id'][i].item()
+ id = sample["id"][i].item()
if self.prediction_h is not None:
- print('{}\t{}\t{}'.format(id, pred, label), file=self.prediction_h)
+ print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)
if nloss == 0:
loss = torch.tensor(0.0, requires_grad=True)
sample_size = nqueries if nqueries > 0 else 1
logging_output = {
- 'loss': utils.item(loss.data) if reduce else loss.data,
- 'ntokens': sample['ntokens'],
- 'nsentences': sample['nsentences'],
- 'sample_size': sample_size,
- 'ncorrect': ncorrect,
- 'nqueries': nqueries,
+ "loss": utils.item(loss.data) if reduce else loss.data,
+ "ntokens": sample["ntokens"],
+ "nsentences": sample["nsentences"],
+ "sample_size": sample_size,
+ "ncorrect": ncorrect,
+ "nqueries": nqueries,
}
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
- loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
- ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
- nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
- sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
agg_output = {
- 'loss': loss_sum / sample_size / math.log(2),
- 'ntokens': ntokens,
- 'nsentences': nsentences,
- 'sample_size': sample_size,
+ "loss": loss_sum / sample_size / math.log(2),
+ "ntokens": ntokens,
+ "nsentences": nsentences,
+ "sample_size": sample_size,
}
- ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
- nqueries = sum(log.get('nqueries', 0) for log in logging_outputs)
+ ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
+ nqueries = sum(log.get("nqueries", 0) for log in logging_outputs)
if nqueries > 0:
- agg_output['accuracy'] = ncorrect / float(nqueries)
+ agg_output["accuracy"] = ncorrect / float(nqueries)
return agg_output
-@register_criterion('winogrande')
+@register_criterion("winogrande")
class WinograndeCriterion(WSCCriterion):
def forward(self, model, sample, reduce=True):
# compute loss and accuracy
query_lprobs = self.get_lprobs(
model,
- sample['query_tokens'],
- sample['query_masks'],
+ sample["query_tokens"],
+ sample["query_masks"],
)
cand_lprobs = self.get_lprobs(
model,
- sample['candidate_tokens'],
- sample['candidate_masks'],
+ sample["candidate_tokens"],
+ sample["candidate_masks"],
)
pred = query_lprobs >= cand_lprobs
loss = self.get_loss(query_lprobs, cand_lprobs)
- sample_size = sample['query_tokens'].size(0)
+ sample_size = sample["query_tokens"].size(0)
ncorrect = pred.sum().item()
logging_output = {
- 'loss': utils.item(loss.data) if reduce else loss.data,
- 'ntokens': sample['ntokens'],
- 'nsentences': sample['nsentences'],
- 'sample_size': sample_size,
- 'ncorrect': ncorrect,
- 'nqueries': sample_size,
+ "loss": utils.item(loss.data) if reduce else loss.data,
+ "ntokens": sample["ntokens"],
+ "nsentences": sample["nsentences"],
+ "sample_size": sample_size,
+ "ncorrect": ncorrect,
+ "nqueries": sample_size,
}
return loss, sample_size, logging_output
diff --git a/examples/roberta/wsc/wsc_task.py b/examples/roberta/wsc/wsc_task.py
index 058e3eea23..602ea737ed 100644
--- a/examples/roberta/wsc/wsc_task.py
+++ b/examples/roberta/wsc/wsc_task.py
@@ -10,47 +10,51 @@
import numpy as np
import torch
import torch.nn.functional as F
-
from fairseq import utils
from fairseq.data import (
- data_utils,
Dictionary,
- encoders,
IdDataset,
ListDataset,
NestedDictionaryDataset,
- NumSamplesDataset,
NumelDataset,
+ NumSamplesDataset,
PadDataset,
SortDataset,
+ data_utils,
+ encoders,
)
-from fairseq.tasks import register_task, LegacyFairseqTask
+from fairseq.tasks import LegacyFairseqTask, register_task
from . import wsc_utils
-@register_task('wsc')
+@register_task("wsc")
class WSCTask(LegacyFairseqTask):
"""Task to finetune RoBERTa for Winograd Schemas."""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
- parser.add_argument('data', metavar='DIR',
- help='path to data directory; we load .jsonl')
- parser.add_argument('--init-token', type=int, default=None,
- help='add token at the beginning of each batch item')
+ parser.add_argument(
+ "data", metavar="DIR", help="path to data directory; we load .jsonl"
+ )
+ parser.add_argument(
+ "--init-token",
+ type=int,
+ default=None,
+ help="add token at the beginning of each batch item",
+ )
def __init__(self, args, vocab):
super().__init__(args)
self.vocab = vocab
- self.mask = vocab.add_symbol('')
+ self.mask = vocab.add_symbol("")
self.bpe = encoders.build_bpe(args)
self.tokenizer = encoders.build_tokenizer(args)
# hack to handle GPT-2 BPE, which includes leading spaces
- if args.bpe == 'gpt2':
+ if args.bpe == "gpt2":
self.leading_space = True
self.trailing_space = False
else:
@@ -65,16 +69,16 @@ def load_dictionary(cls, filename):
filename (str): the filename
"""
dictionary = Dictionary.load(filename)
- dictionary.add_symbol('')
+ dictionary.add_symbol("")
return dictionary
@classmethod
def setup_task(cls, args, **kwargs):
- assert args.criterion == 'wsc', 'Must set --criterion=wsc'
+ assert args.criterion == "wsc", "Must set --criterion=wsc"
# load data and label dictionaries
- vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
- print('| dictionary: {} types'.format(len(vocab)))
+ vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
+ print("| dictionary: {} types".format(len(vocab)))
return cls(args, vocab)
@@ -84,7 +88,9 @@ def binarize(self, s: str, append_eos: bool = False):
if self.bpe is not None:
s = self.bpe.encode(s)
tokens = self.vocab.encode_line(
- s, append_eos=append_eos, add_if_not_exist=False,
+ s,
+ append_eos=append_eos,
+ add_if_not_exist=False,
).long()
if self.args.init_token is not None:
tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
@@ -98,19 +104,21 @@ def binarize_with_mask(self, txt, prefix, suffix, leading_space, trailing_space)
mask = torch.zeros_like(toks, dtype=torch.bool)
mask_start = len(self.binarize(prefix))
mask_size = len(self.binarize(leading_space + txt))
- mask[mask_start:mask_start + mask_size] = 1
+ mask[mask_start : mask_start + mask_size] = 1
return toks, mask
- def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
+ def load_dataset(
+ self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
+ ):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if data_path is None:
- data_path = os.path.join(self.args.data, split + '.jsonl')
+ data_path = os.path.join(self.args.data, split + ".jsonl")
if not os.path.exists(data_path):
- raise FileNotFoundError('Cannot find data: {}'.format(data_path))
+ raise FileNotFoundError("Cannot find data: {}".format(data_path))
query_tokens = []
query_masks = []
@@ -121,13 +129,15 @@ def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_onl
labels = []
for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
- prefix = sentence[:pronoun_span.start].text
- suffix = sentence[pronoun_span.end:].text_with_ws
+ prefix = sentence[: pronoun_span.start].text
+ suffix = sentence[pronoun_span.end :].text_with_ws
# spaCy spans include trailing spaces, but we need to know about
# leading spaces for the GPT-2 BPE
- leading_space = ' ' if sentence[:pronoun_span.start].text_with_ws.endswith(' ') else ''
- trailing_space = ' ' if pronoun_span.text_with_ws.endswith(' ') else ''
+ leading_space = (
+ " " if sentence[: pronoun_span.start].text_with_ws.endswith(" ") else ""
+ )
+ trailing_space = " " if pronoun_span.text_with_ws.endswith(" ") else ""
# get noun phrases, excluding pronouns and anything overlapping with the query
cand_spans = wsc_utils.filter_noun_chunks(
@@ -152,7 +162,11 @@ def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_onl
cand_toks, cand_masks = [], []
for cand_span in cand_spans:
toks, mask = self.binarize_with_mask(
- cand_span.text, prefix, suffix, leading_space, trailing_space,
+ cand_span.text,
+ prefix,
+ suffix,
+ leading_space,
+ trailing_space,
)
cand_toks.append(toks)
cand_masks.append(mask)
@@ -176,17 +190,17 @@ def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_onl
candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
candidate_masks = ListDataset(candidate_masks, candidate_lengths)
- labels = ListDataset(labels, [1]*len(labels))
+ labels = ListDataset(labels, [1] * len(labels))
dataset = {
- 'id': IdDataset(),
- 'query_tokens': query_tokens,
- 'query_masks': query_masks,
- 'candidate_tokens': candidate_tokens,
- 'candidate_masks': candidate_masks,
- 'labels': labels,
- 'nsentences': NumSamplesDataset(),
- 'ntokens': NumelDataset(query_tokens, reduce=True),
+ "id": IdDataset(),
+ "query_tokens": query_tokens,
+ "query_masks": query_masks,
+ "candidate_tokens": candidate_tokens,
+ "candidate_masks": candidate_masks,
+ "labels": labels,
+ "nsentences": NumSamplesDataset(),
+ "ntokens": NumelDataset(query_tokens, reduce=True),
}
nested_dataset = NestedDictionaryDataset(
@@ -210,9 +224,9 @@ def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_onl
def build_dataset_for_inference(self, sample_json):
with tempfile.NamedTemporaryFile(buffering=0) as h:
- h.write((json.dumps(sample_json) + '\n').encode('utf-8'))
+ h.write((json.dumps(sample_json) + "\n").encode("utf-8"))
dataset = self.load_dataset(
- 'disambiguate_pronoun',
+ "disambiguate_pronoun",
data_path=h.name,
return_only=True,
)
@@ -239,19 +253,19 @@ def get_lprobs(tokens, mask):
return scores
cand_lprobs = get_lprobs(
- sample['candidate_tokens'][0],
- sample['candidate_masks'][0],
+ sample["candidate_tokens"][0],
+ sample["candidate_masks"][0],
)
- if sample['query_tokens'][0] is not None:
+ if sample["query_tokens"][0] is not None:
query_lprobs = get_lprobs(
- sample['query_tokens'][0].unsqueeze(0),
- sample['query_masks'][0].unsqueeze(0),
+ sample["query_tokens"][0].unsqueeze(0),
+ sample["query_masks"][0].unsqueeze(0),
)
return (query_lprobs >= cand_lprobs).all().item() == 1
else:
best_idx = cand_lprobs.argmax().item()
- full_cand = sample['candidate_tokens'][0][best_idx]
- mask = sample['candidate_masks'][0][best_idx]
+ full_cand = sample["candidate_tokens"][0][best_idx]
+ mask = sample["candidate_masks"][0][best_idx]
toks = full_cand[mask.bool()]
return self.bpe.decode(self.source_dictionary.string(toks)).strip()
@@ -264,7 +278,7 @@ def target_dictionary(self):
return self.vocab
-@register_task('winogrande')
+@register_task("winogrande")
class WinograndeTask(WSCTask):
"""
Task for WinoGrande dataset. Efficient implementation for Winograd schema
@@ -273,24 +287,26 @@ class WinograndeTask(WSCTask):
@classmethod
def setup_task(cls, args, **kwargs):
- assert args.criterion == 'winogrande', 'Must set --criterion=winogrande'
+ assert args.criterion == "winogrande", "Must set --criterion=winogrande"
# load data and label dictionaries
- vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
- print('| dictionary: {} types'.format(len(vocab)))
+ vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
+ print("| dictionary: {} types".format(len(vocab)))
return cls(args, vocab)
- def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
+ def load_dataset(
+ self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
+ ):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if data_path is None:
- data_path = os.path.join(self.args.data, split + '.jsonl')
+ data_path = os.path.join(self.args.data, split + ".jsonl")
if not os.path.exists(data_path):
- raise FileNotFoundError('Cannot find data: {}'.format(data_path))
+ raise FileNotFoundError("Cannot find data: {}".format(data_path))
query_tokens = []
query_masks = []
@@ -299,19 +315,23 @@ def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_onl
candidate_masks = []
candidate_lengths = []
- itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == 'test'))
+ itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == "test"))
for sample in itr:
sentence, pronoun_span, query, cand_text = sample
- prefix = sentence[:pronoun_span[0]].rstrip()
- suffix = sentence[pronoun_span[1]:]
+ prefix = sentence[: pronoun_span[0]].rstrip()
+ suffix = sentence[pronoun_span[1] :]
- leading_space = ' ' if sentence[:pronoun_span[0]].endswith(' ') else ''
- trailing_space = ''
+ leading_space = " " if sentence[: pronoun_span[0]].endswith(" ") else ""
+ trailing_space = ""
if query is not None:
query_toks, query_mask = self.binarize_with_mask(
- query, prefix, suffix, leading_space, trailing_space,
+ query,
+ prefix,
+ suffix,
+ leading_space,
+ trailing_space,
)
query_len = len(query_toks)
else:
@@ -322,7 +342,11 @@ def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_onl
query_lengths.append(query_len)
cand_toks, cand_mask = self.binarize_with_mask(
- cand_text, prefix, suffix, leading_space, trailing_space,
+ cand_text,
+ prefix,
+ suffix,
+ leading_space,
+ trailing_space,
)
candidate_tokens.append(cand_toks)
@@ -342,17 +366,19 @@ def get_pad_dataset_fn(tokens, length, pad_idx):
query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0)
candidate_lengths = np.array(candidate_lengths)
- candidate_tokens = get_pad_dataset_fn(candidate_tokens, candidate_lengths, self.vocab.pad())
+ candidate_tokens = get_pad_dataset_fn(
+ candidate_tokens, candidate_lengths, self.vocab.pad()
+ )
candidate_masks = get_pad_dataset_fn(candidate_masks, candidate_lengths, 0)
dataset = {
- 'id': IdDataset(),
- 'query_tokens': query_tokens,
- 'query_masks': query_masks,
- 'candidate_tokens': candidate_tokens,
- 'candidate_masks': candidate_masks,
- 'nsentences': NumSamplesDataset(),
- 'ntokens': NumelDataset(query_tokens, reduce=True),
+ "id": IdDataset(),
+ "query_tokens": query_tokens,
+ "query_masks": query_masks,
+ "candidate_tokens": candidate_tokens,
+ "candidate_masks": candidate_masks,
+ "nsentences": NumSamplesDataset(),
+ "ntokens": NumelDataset(query_tokens, reduce=True),
}
nested_dataset = NestedDictionaryDataset(
diff --git a/examples/roberta/wsc/wsc_utils.py b/examples/roberta/wsc/wsc_utils.py
index 2d4822479e..da6ba74383 100644
--- a/examples/roberta/wsc/wsc_utils.py
+++ b/examples/roberta/wsc/wsc_utils.py
@@ -3,48 +3,48 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from functools import lru_cache
import json
+from functools import lru_cache
def convert_sentence_to_json(sentence):
- if '_' in sentence:
- prefix, rest = sentence.split('_', 1)
- query, rest = rest.split('_', 1)
- query_index = len(prefix.rstrip().split(' '))
+ if "_" in sentence:
+ prefix, rest = sentence.split("_", 1)
+ query, rest = rest.split("_", 1)
+ query_index = len(prefix.rstrip().split(" "))
else:
query, query_index = None, None
- prefix, rest = sentence.split('[', 1)
- pronoun, rest = rest.split(']', 1)
- pronoun_index = len(prefix.rstrip().split(' '))
+ prefix, rest = sentence.split("[", 1)
+ pronoun, rest = rest.split("]", 1)
+ pronoun_index = len(prefix.rstrip().split(" "))
- sentence = sentence.replace('_', '').replace('[', '').replace(']', '')
+ sentence = sentence.replace("_", "").replace("[", "").replace("]", "")
return {
- 'idx': 0,
- 'text': sentence,
- 'target': {
- 'span1_index': query_index,
- 'span1_text': query,
- 'span2_index': pronoun_index,
- 'span2_text': pronoun,
+ "idx": 0,
+ "text": sentence,
+ "target": {
+ "span1_index": query_index,
+ "span1_text": query,
+ "span2_index": pronoun_index,
+ "span2_text": pronoun,
},
}
def extended_noun_chunks(sentence):
noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
- np_start, cur_np = 0, 'NONE'
+ np_start, cur_np = 0, "NONE"
for i, token in enumerate(sentence):
- np_type = token.pos_ if token.pos_ in {'NOUN', 'PROPN'} else 'NONE'
+ np_type = token.pos_ if token.pos_ in {"NOUN", "PROPN"} else "NONE"
if np_type != cur_np:
- if cur_np != 'NONE':
+ if cur_np != "NONE":
noun_chunks.add((np_start, i))
- if np_type != 'NONE':
+ if np_type != "NONE":
np_start = i
cur_np = np_type
- if cur_np != 'NONE':
+ if cur_np != "NONE":
noun_chunks.add((np_start, len(sentence)))
return [sentence[s:e] for (s, e) in sorted(noun_chunks)]
@@ -61,14 +61,14 @@ def find_token(sentence, start_pos):
def find_span(sentence, search_text, start=0):
search_text = search_text.lower()
for tok in sentence[start:]:
- remainder = sentence[tok.i:].text.lower()
+ remainder = sentence[tok.i :].text.lower()
if remainder.startswith(search_text):
len_to_consume = len(search_text)
start_idx = tok.idx
- for next_tok in sentence[tok.i:]:
+ for next_tok in sentence[tok.i :]:
end_idx = next_tok.idx + len(next_tok.text)
if end_idx - start_idx == len_to_consume:
- span = sentence[tok.i:next_tok.i + 1]
+ span = sentence[tok.i : next_tok.i + 1]
return span
return None
@@ -76,13 +76,15 @@ def find_span(sentence, search_text, start=0):
@lru_cache(maxsize=1)
def get_detokenizer():
from sacremoses import MosesDetokenizer
- detok = MosesDetokenizer(lang='en')
+
+ detok = MosesDetokenizer(lang="en")
return detok
@lru_cache(maxsize=1)
def get_spacy_nlp():
import en_core_web_lg
+
nlp = en_core_web_lg.load()
return nlp
@@ -95,45 +97,45 @@ def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
for line in fin:
sample = json.loads(line.strip())
- if positive_only and 'label' in sample and not sample['label']:
+ if positive_only and "label" in sample and not sample["label"]:
# only consider examples where the query is correct
continue
- target = sample['target']
+ target = sample["target"]
# clean up the query
- query = target['span1_text']
+ query = target["span1_text"]
if query is not None:
- if '\n' in query:
+ if "\n" in query:
continue
- if query.endswith('.') or query.endswith(','):
+ if query.endswith(".") or query.endswith(","):
query = query[:-1]
# split tokens
- tokens = sample['text'].split(' ')
+ tokens = sample["text"].split(" ")
def strip_pronoun(x):
return x.rstrip('.,"')
# find the pronoun
- pronoun_idx = target['span2_index']
- pronoun = strip_pronoun(target['span2_text'])
+ pronoun_idx = target["span2_index"]
+ pronoun = strip_pronoun(target["span2_text"])
if strip_pronoun(tokens[pronoun_idx]) != pronoun:
# hack: sometimes the index is misaligned
if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
pronoun_idx += 1
else:
- raise Exception('Misaligned pronoun!')
+ raise Exception("Misaligned pronoun!")
assert strip_pronoun(tokens[pronoun_idx]) == pronoun
# split tokens before and after the pronoun
before = tokens[:pronoun_idx]
- after = tokens[pronoun_idx + 1:]
+ after = tokens[pronoun_idx + 1 :]
# the GPT BPE attaches leading spaces to tokens, so we keep track
# of whether we need spaces before or after the pronoun
- leading_space = ' ' if pronoun_idx > 0 else ''
- trailing_space = ' ' if len(after) > 0 else ''
+ leading_space = " " if pronoun_idx > 0 else ""
+ trailing_space = " " if len(after) > 0 else ""
# detokenize
before = detok.detokenize(before, return_str=True)
@@ -142,14 +144,14 @@ def strip_pronoun(x):
# hack: when the pronoun ends in a period (or comma), move the
# punctuation to the "after" part
- if pronoun.endswith('.') or pronoun.endswith(','):
+ if pronoun.endswith(".") or pronoun.endswith(","):
after = pronoun[-1] + trailing_space + after
pronoun = pronoun[:-1]
# hack: when the "after" part begins with a comma or period, remove
# the trailing space
- if after.startswith('.') or after.startswith(','):
- trailing_space = ''
+ if after.startswith(".") or after.startswith(","):
+ trailing_space = ""
# parse sentence with spacy
sentence = nlp(before + leading_space + pronoun + trailing_space + after)
@@ -164,13 +166,13 @@ def strip_pronoun(x):
# convert to format where pronoun is surrounded by "[]" and
# query is surrounded by "_"
query_span = find_span(sentence, query)
- query_with_ws = '_{}_{}'.format(
+ query_with_ws = "_{}_{}".format(
query_span.text,
- (' ' if query_span.text_with_ws.endswith(' ') else '')
+ (" " if query_span.text_with_ws.endswith(" ") else ""),
)
- pronoun_with_ws = '[{}]{}'.format(
+ pronoun_with_ws = "[{}]{}".format(
pronoun_span.text,
- (' ' if pronoun_span.text_with_ws.endswith(' ') else '')
+ (" " if pronoun_span.text_with_ws.endswith(" ") else ""),
)
if query_span.start < pronoun_span.start:
first = (query_span, query_with_ws)
@@ -179,41 +181,45 @@ def strip_pronoun(x):
first = (pronoun_span, pronoun_with_ws)
second = (query_span, query_with_ws)
sentence = (
- sentence[:first[0].start].text_with_ws
+ sentence[: first[0].start].text_with_ws
+ first[1]
- + sentence[first[0].end:second[0].start].text_with_ws
+ + sentence[first[0].end : second[0].start].text_with_ws
+ second[1]
- + sentence[second[0].end:].text
+ + sentence[second[0].end :].text
)
- yield sentence, sample.get('label', None)
+ yield sentence, sample.get("label", None)
else:
- yield sentence, pronoun_span, query, sample.get('label', None)
+ yield sentence, pronoun_span, query, sample.get("label", None)
def winogrande_jsonl_iterator(input_fname, eval=False):
with open(input_fname) as fin:
for line in fin:
sample = json.loads(line.strip())
- sentence, option1, option2 = sample['sentence'], sample['option1'],\
- sample['option2']
+ sentence, option1, option2 = (
+ sample["sentence"],
+ sample["option1"],
+ sample["option2"],
+ )
- pronoun_span = (sentence.index('_'), sentence.index('_') + 1)
+ pronoun_span = (sentence.index("_"), sentence.index("_") + 1)
if eval:
query, cand = option1, option2
else:
- query = option1 if sample['answer'] == '1' else option2
- cand = option2 if sample['answer'] == '1' else option1
+ query = option1 if sample["answer"] == "1" else option2
+ cand = option2 if sample["answer"] == "1" else option1
yield sentence, pronoun_span, query, cand
-def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact_match=False):
+def filter_noun_chunks(
+ chunks, exclude_pronouns=False, exclude_query=None, exact_match=False
+):
if exclude_pronouns:
chunks = [
- np for np in chunks if (
- np.lemma_ != '-PRON-'
- and not all(tok.pos_ == 'PRON' for tok in np)
- )
+ np
+ for np in chunks
+ if (np.lemma_ != "-PRON-" and not all(tok.pos_ == "PRON" for tok in np))
]
if exclude_query is not None:
@@ -224,9 +230,8 @@ def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact
found = False
for excl in excl_txt:
if (
- (not exact_match and (lower_chunk in excl or excl in lower_chunk))
- or lower_chunk == excl
- ):
+ not exact_match and (lower_chunk in excl or excl in lower_chunk)
+ ) or lower_chunk == excl:
found = True
break
if not found:
diff --git a/examples/simultaneous_translation/__init__.py b/examples/simultaneous_translation/__init__.py
index e6963d6d1b..446fc86c8a 100644
--- a/examples/simultaneous_translation/__init__.py
+++ b/examples/simultaneous_translation/__init__.py
@@ -3,4 +3,4 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from . import criterions, models, eval # noqa
+from . import criterions, eval, models # noqa
diff --git a/examples/simultaneous_translation/criterions/__init__.py b/examples/simultaneous_translation/criterions/__init__.py
index 84dc80ad95..08791bfff3 100644
--- a/examples/simultaneous_translation/criterions/__init__.py
+++ b/examples/simultaneous_translation/criterions/__init__.py
@@ -6,6 +6,7 @@
import importlib
import os
+
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
criterion_name = file[: file.find(".py")]
diff --git a/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py b/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py
index d4d544ec5f..b3c8f6d53f 100644
--- a/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py
+++ b/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py
@@ -3,21 +3,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from examples.simultaneous_translation.utils.latency import LatencyTraining
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
- LabelSmoothedCrossEntropyCriterion
-)
-
-from examples.simultaneous_translation.utils.latency import (
- LatencyTraining
+ LabelSmoothedCrossEntropyCriterion,
)
-@register_criterion('latency_augmented_label_smoothed_cross_entropy')
+@register_criterion("latency_augmented_label_smoothed_cross_entropy")
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
LabelSmoothedCrossEntropyCriterion
):
-
def __init__(self, args, task):
super().__init__(args, task)
self.eps = args.label_smoothing
@@ -40,7 +36,7 @@ def __init__(self, args, task):
def add_args(parser):
super(
LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
- LatencyAugmentedLabelSmoothedCrossEntropyCriterion
+ LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
).add_args(parser)
"""Add criterion-specific arguments to the parser."""
# fmt: off
@@ -69,7 +65,8 @@ def compute_loss(self, model, net_output, sample, reduce=True):
# Get latency loss
latency_loss = self.latency_train.loss(
- attn_list, source_padding_mask, target_padding_mask)
+ attn_list, source_padding_mask, target_padding_mask
+ )
loss += latency_loss
diff --git a/examples/simultaneous_translation/eval/agents/__init__.py b/examples/simultaneous_translation/eval/agents/__init__.py
index 1c23fc1ad9..511e7b2474 100644
--- a/examples/simultaneous_translation/eval/agents/__init__.py
+++ b/examples/simultaneous_translation/eval/agents/__init__.py
@@ -5,16 +5,20 @@
import importlib
import os
+
from fairseq import registry
-build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry('--agent-type')
+
+build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry(
+ "--agent-type"
+)
-DEFAULT_EOS = ''
+DEFAULT_EOS = ""
GET = 0
SEND = 1
for file in os.listdir(os.path.dirname(__file__)):
- if file.endswith('.py') and not file.startswith('_'):
- module = file[:file.find('.py')]
- importlib.import_module('agents.' + module)
+ if file.endswith(".py") and not file.startswith("_"):
+ module = file[: file.find(".py")]
+ importlib.import_module("agents." + module)
diff --git a/examples/simultaneous_translation/eval/agents/agent.py b/examples/simultaneous_translation/eval/agents/agent.py
index 1977a24dd9..997392cf9b 100644
--- a/examples/simultaneous_translation/eval/agents/agent.py
+++ b/examples/simultaneous_translation/eval/agents/agent.py
@@ -3,14 +3,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from . import GET, SEND, DEFAULT_EOS
import time
-from multiprocessing.pool import ThreadPool as Pool
from functools import partial
+from multiprocessing.pool import ThreadPool as Pool
+
+from . import DEFAULT_EOS, GET, SEND
class Agent(object):
"an agent needs to follow this pattern"
+
def __init__(self, *args, **kwargs):
pass
@@ -40,26 +42,26 @@ def decode(self, session, low=0, high=100000, num_thread=10):
with Pool(10) as p:
p.map(
partial(self._decode_one, session),
- [sent_id for sent_id in range(low, high + 1)]
+ [sent_id for sent_id in range(low, high + 1)],
)
else:
for sent_id in range(low, high + 1):
self._decode_one(session, sent_id)
- print(f'Finished {low} to {high} in {time.time() - t0}s')
+ print(f"Finished {low} to {high} in {time.time() - t0}s")
def _decode_one(self, session, sent_id):
action = {}
self.reset()
states = self.init_states()
- while action.get('value', None) != DEFAULT_EOS:
+ while action.get("value", None) != DEFAULT_EOS:
# take an action
action = self.policy(states)
- if action['key'] == GET:
+ if action["key"] == GET:
new_states = session.get_src(sent_id, action["value"])
states = self.update_states(states, new_states)
- elif action['key'] == SEND:
- session.send_hypo(sent_id, action['value'])
+ elif action["key"] == SEND:
+ session.send_hypo(sent_id, action["value"])
print(" ".join(states["tokens"]["tgt"]))
diff --git a/examples/simultaneous_translation/eval/agents/simul_trans_agent.py b/examples/simultaneous_translation/eval/agents/simul_trans_agent.py
index 1b6960c5fa..071b9e89ce 100644
--- a/examples/simultaneous_translation/eval/agents/simul_trans_agent.py
+++ b/examples/simultaneous_translation/eval/agents/simul_trans_agent.py
@@ -3,11 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from . agent import Agent
-from . import DEFAULT_EOS, GET, SEND
-from fairseq import checkpoint_utils, utils, tasks
-import os
import json
+import os
+
+from fairseq import checkpoint_utils, tasks, utils
+
+from . import DEFAULT_EOS, GET, SEND
+from .agent import Agent
class SimulTransAgent(Agent):
@@ -51,13 +53,15 @@ def load_dictionary(self, task):
raise NotImplementedError
def load_model(self, args):
- args.user_dir = os.path.join(os.path.dirname(__file__), '..', '..')
+ args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..")
utils.import_user_module(args)
filename = args.model_path
if not os.path.exists(filename):
raise IOError("Model file not found: {}".format(filename))
- state = checkpoint_utils.load_checkpoint_to_cpu(filename, json.loads(args.model_overrides))
+ state = checkpoint_utils.load_checkpoint_to_cpu(
+ filename, json.loads(args.model_overrides)
+ )
saved_args = state["args"]
saved_args.data = args.data_bin
@@ -79,7 +83,7 @@ def init_states(self):
"steps": {"src": 0, "tgt": 0},
"finished": False,
"finish_read": False,
- "model_states": {}
+ "model_states": {},
}
def update_states(self, states, new_state):
@@ -115,38 +119,38 @@ def finish_read(self, states):
def write_action(self, states):
token, index = self.model.predict_from_states(states)
- if index == self.dict["tgt"].eos() or len(states["tokens"]["tgt"]) > self.max_len:
+ if (
+ index == self.dict["tgt"].eos()
+ or len(states["tokens"]["tgt"]) > self.max_len
+ ):
# Finish this sentence is predict EOS
states["finished"] = True
end_idx_last_full_word = self._target_length(states)
else:
states["tokens"]["tgt"] += [token]
- end_idx_last_full_word = (
- self.word_splitter["tgt"]
- .end_idx_last_full_word(states["tokens"]["tgt"])
+ end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word(
+ states["tokens"]["tgt"]
)
self._append_indices(states, [index], "tgt")
if end_idx_last_full_word > states["steps"]["tgt"]:
# Only sent detokenized full words to the server
word = self.word_splitter["tgt"].merge(
- states["tokens"]["tgt"][
- states["steps"]["tgt"]: end_idx_last_full_word
- ]
+ states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word]
)
states["steps"]["tgt"] = end_idx_last_full_word
states["segments"]["tgt"] += [word]
- return {'key': SEND, 'value': word}
+ return {"key": SEND, "value": word}
else:
return None
def read_action(self, states):
- return {'key': GET, 'value': None}
+ return {"key": GET, "value": None}
def finish_action(self):
- return {'key': SEND, 'value': DEFAULT_EOS}
+ return {"key": SEND, "value": DEFAULT_EOS}
def reset(self):
pass
@@ -160,4 +164,4 @@ def _append_indices(self, states, new_indices, key):
states["indices"][key] += new_indices
def _target_length(self, states):
- return len(states["tokens"]['tgt'])
+ return len(states["tokens"]["tgt"])
diff --git a/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py b/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py
index 65f7cbd313..7c34817bf6 100644
--- a/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py
+++ b/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py
@@ -3,10 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from . simul_trans_agent import SimulTransAgent
-from . import DEFAULT_EOS, GET
-from . import register_agent
-from . word_splitter import SPLITTER_DICT
+from . import DEFAULT_EOS, GET, register_agent
+from .simul_trans_agent import SimulTransAgent
+from .word_splitter import SPLITTER_DICT
@register_agent("simul_trans_text")
@@ -15,11 +14,11 @@ def build_word_splitter(self, args):
self.word_splitter = {}
self.word_splitter["src"] = SPLITTER_DICT[args.src_splitter_type](
- getattr(args, f"src_splitter_path")
- )
+ getattr(args, f"src_splitter_path")
+ )
self.word_splitter["tgt"] = SPLITTER_DICT[args.tgt_splitter_type](
- getattr(args, f"tgt_splitter_path")
- )
+ getattr(args, f"tgt_splitter_path")
+ )
def load_dictionary(self, task):
self.dict = {}
@@ -37,12 +36,16 @@ def update_states(self, states, new_state):
tokens = self.word_splitter["src"].split(new_word)
# Get indices from dictionary
# You can change to you own dictionary
- indices = self.dict["src"].encode_line(
- tokens,
- line_tokenizer=lambda x: x,
- add_if_not_exist=False,
- append_eos=False
- ).tolist()
+ indices = (
+ self.dict["src"]
+ .encode_line(
+ tokens,
+ line_tokenizer=lambda x: x,
+ add_if_not_exist=False,
+ append_eos=False,
+ )
+ .tolist()
+ )
else:
tokens = [new_word]
indices = [self.dict["src"].eos()]
@@ -61,11 +64,11 @@ def read_action(self, states):
# At leat one word is read
if len(states["tokens"]["src"]) == 0:
- return {'key': GET, 'value': None}
+ return {"key": GET, "value": None}
# Only request new word if there is no buffered tokens
if len(states["tokens"]["src"]) <= states["steps"]["src"]:
- return {'key': GET, 'value': None}
+ return {"key": GET, "value": None}
return None
diff --git a/examples/simultaneous_translation/eval/agents/word_splitter.py b/examples/simultaneous_translation/eval/agents/word_splitter.py
index ea564f21ee..c3f71200a5 100644
--- a/examples/simultaneous_translation/eval/agents/word_splitter.py
+++ b/examples/simultaneous_translation/eval/agents/word_splitter.py
@@ -40,6 +40,7 @@ class BPEWordSplitter(object):
def __init__(self, model_path):
super().__init__()
from subword_nmt.apply_bpe import BPE
+
with open(model_path) as f:
self.model = BPE(f)
@@ -48,7 +49,7 @@ def split(self, string):
def end_idx_last_full_word(self, tokens):
# Begin of word indices
- bow_indices = [0] + [i + 1 for i, t in enumerate(tokens[1:]) if t[-2:] != '@@']
+ bow_indices = [0] + [i + 1 for i, t in enumerate(tokens[1:]) if t[-2:] != "@@"]
if len(bow_indices) < 2:
return 0
@@ -63,6 +64,7 @@ class SentencePieceModelWordSplitter(object):
def __init__(self, model_path):
super().__init__()
import sentencepiece as spm
+
self.model = spm.SentencePieceProcessor()
self.model.Load(model_path)
@@ -71,7 +73,7 @@ def split(self, string):
def end_idx_last_full_word(self, tokens):
# Begin of word indices
- bow_indices = [i for i, t in enumerate(tokens) if t[0] == '\u2581']
+ bow_indices = [i for i, t in enumerate(tokens) if t[0] == "\u2581"]
if len(bow_indices) < 2:
return 0
diff --git a/examples/simultaneous_translation/eval/client.py b/examples/simultaneous_translation/eval/client.py
index 5cbaa71d31..3ca4ea73b8 100644
--- a/examples/simultaneous_translation/eval/client.py
+++ b/examples/simultaneous_translation/eval/client.py
@@ -3,19 +3,20 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import requests
from typing import Optional
+
+import requests
from scorers import build_scorer
class SimulSTEvaluationService(object):
- DEFAULT_HOSTNAME = 'localhost'
+ DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 12321
def __init__(self, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT):
self.hostname = hostname
self.port = port
- self.base_url = f'http://{self.hostname}:{self.port}'
+ self.base_url = f"http://{self.hostname}:{self.port}"
def __enter__(self):
self.new_session()
@@ -25,56 +26,53 @@ def __exit__(self, exc_type, exc_val, exc_tb):
def new_session(self):
# start eval session
- url = f'{self.base_url}'
+ url = f"{self.base_url}"
try:
_ = requests.post(url)
except Exception as e:
- print(f'Failed to start an evaluation session: {e}')
+ print(f"Failed to start an evaluation session: {e}")
- print('Evaluation session started.')
+ print("Evaluation session started.")
return self
def get_scores(self):
# end eval session
- url = f'{self.base_url}/result'
+ url = f"{self.base_url}/result"
try:
r = requests.get(url)
- print('Scores: {}'.format(r.json()))
- print('Evaluation session finished.')
+ print("Scores: {}".format(r.json()))
+ print("Evaluation session finished.")
except Exception as e:
- print(f'Failed to end an evaluation session: {e}')
+ print(f"Failed to end an evaluation session: {e}")
def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
- url = f'{self.base_url}/src'
+ url = f"{self.base_url}/src"
params = {"sent_id": sent_id}
if extra_params is not None:
for key in extra_params.keys():
params[key] = extra_params[key]
try:
- r = requests.get(
- url,
- params=params
- )
+ r = requests.get(url, params=params)
except Exception as e:
- print(f'Failed to request a source segment: {e}')
+ print(f"Failed to request a source segment: {e}")
return r.json()
def send_hypo(self, sent_id: int, hypo: str) -> None:
- url = f'{self.base_url}/hypo'
+ url = f"{self.base_url}/hypo"
params = {"sent_id": sent_id}
try:
requests.put(url, params=params, data=hypo.encode("utf-8"))
except Exception as e:
- print(f'Failed to send a translated segment: {e}')
+ print(f"Failed to send a translated segment: {e}")
def corpus_info(self):
- url = f'{self.base_url}'
+ url = f"{self.base_url}"
try:
r = requests.get(url)
except Exception as e:
- print(f'Failed to request corpus information: {e}')
+ print(f"Failed to request corpus information: {e}")
return r.json()
diff --git a/examples/simultaneous_translation/eval/eval_latency.py b/examples/simultaneous_translation/eval/eval_latency.py
index 12cfaa4ed1..50021de47c 100644
--- a/examples/simultaneous_translation/eval/eval_latency.py
+++ b/examples/simultaneous_translation/eval/eval_latency.py
@@ -3,20 +3,21 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from examples.simultaneous_translation.utils.latency import LatencyInference
import argparse
-import torch
import json
+import torch
+from examples.simultaneous_translation.utils.latency import LatencyInference
+
LATENCY_METRICS = [
- 'differentiable_average_lagging',
- 'average_lagging',
- 'average_proportion',
+ "differentiable_average_lagging",
+ "average_lagging",
+ "average_proportion",
]
-class LatencyScorer():
+class LatencyScorer:
def __init__(self, start_from_zero=True):
self.recorder = []
self.scores = {}
@@ -26,10 +27,7 @@ def __init__(self, start_from_zero=True):
def update_reorder(self, list_of_dict):
self.recorder = []
for info in list_of_dict:
- delays = [
- int(x) - int(not self.start_from_zero)
- for x in info["delays"]
- ]
+ delays = [int(x) - int(not self.start_from_zero) for x in info["delays"]]
delays = torch.LongTensor(delays).unsqueeze(0)
src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0)
@@ -59,7 +57,7 @@ def score(cls, list_of_dict, start_from_zero=True):
scorer = LatencyInference()
recorder = []
- with open(args.input, 'r') as f:
+ with open(args.input, "r") as f:
for line in f:
info = json.loads(line)
@@ -74,7 +72,7 @@ def score(cls, list_of_dict, start_from_zero=True):
average_results = {}
for metric in LATENCY_METRICS:
- average_results[metric] = sum(
- [x[metric][0, 0].item() for x in recorder]
- ) / len(recorder)
+ average_results[metric] = sum([x[metric][0, 0].item() for x in recorder]) / len(
+ recorder
+ )
print(f"{metric}: {average_results[metric]}")
diff --git a/examples/simultaneous_translation/eval/evaluate.py b/examples/simultaneous_translation/eval/evaluate.py
index 07f93e7fb0..2f7474621a 100644
--- a/examples/simultaneous_translation/eval/evaluate.py
+++ b/examples/simultaneous_translation/eval/evaluate.py
@@ -5,37 +5,48 @@
import argparse
+from agents import build_agent
from client import SimulSTEvaluationService, SimulSTLocalEvaluationService
from fairseq.registry import REGISTRIES
-from agents import build_agent
-DEFAULT_HOSTNAME = 'localhost'
+
+DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 12321
def get_args():
parser = argparse.ArgumentParser()
- parser.add_argument('--hostname', type=str, default=DEFAULT_HOSTNAME,
- help='server hostname')
- parser.add_argument('--port', type=int, default=DEFAULT_PORT,
- help='server port number')
- parser.add_argument('--agent-type', default='simul_trans_text',
- help='Agent type')
- parser.add_argument('--scorer-type', default='text',
- help='Scorer type')
- parser.add_argument('--start-idx', type=int, default=0,
- help='Start index of the sentence to evaluate')
- parser.add_argument('--end-idx', type=int, default=float('inf'),
- help='End index of the sentence to evaluate')
- parser.add_argument('--scores', action="store_true",
- help='Request scores from server')
- parser.add_argument('--reset-server', action="store_true",
- help='Reset the server')
- parser.add_argument('--num-threads', type=int, default=10,
- help='Number of threads used by agent')
- parser.add_argument('--local', action="store_true", default=False,
- help='Local evaluation')
+ parser.add_argument(
+ "--hostname", type=str, default=DEFAULT_HOSTNAME, help="server hostname"
+ )
+ parser.add_argument(
+ "--port", type=int, default=DEFAULT_PORT, help="server port number"
+ )
+ parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type")
+ parser.add_argument("--scorer-type", default="text", help="Scorer type")
+ parser.add_argument(
+ "--start-idx",
+ type=int,
+ default=0,
+ help="Start index of the sentence to evaluate",
+ )
+ parser.add_argument(
+ "--end-idx",
+ type=int,
+ default=float("inf"),
+ help="End index of the sentence to evaluate",
+ )
+ parser.add_argument(
+ "--scores", action="store_true", help="Request scores from server"
+ )
+ parser.add_argument("--reset-server", action="store_true", help="Reset the server")
+ parser.add_argument(
+ "--num-threads", type=int, default=10, help="Number of threads used by agent"
+ )
+ parser.add_argument(
+ "--local", action="store_true", default=False, help="Local evaluation"
+ )
args, _ = parser.parse_known_args()
diff --git a/examples/simultaneous_translation/eval/scorers/__init__.py b/examples/simultaneous_translation/eval/scorers/__init__.py
index c7fbb5495d..0a0e0a0518 100644
--- a/examples/simultaneous_translation/eval/scorers/__init__.py
+++ b/examples/simultaneous_translation/eval/scorers/__init__.py
@@ -5,15 +5,15 @@
import importlib
import os
+
from fairseq import registry
-(
- build_scorer,
- register_scorer,
- SCORER_REGISTRIES,
- _
-) = registry.setup_registry('--scorer-type')
+
+
+(build_scorer, register_scorer, SCORER_REGISTRIES, _) = registry.setup_registry(
+ "--scorer-type"
+)
for file in os.listdir(os.path.dirname(__file__)):
- if file.endswith('.py') and not file.startswith('_'):
- module = file[:file.find('.py')]
- importlib.import_module('scorers.' + module)
+ if file.endswith(".py") and not file.startswith("_"):
+ module = file[: file.find(".py")]
+ importlib.import_module("scorers." + module)
diff --git a/examples/simultaneous_translation/eval/scorers/scorer.py b/examples/simultaneous_translation/eval/scorers/scorer.py
index d16f130e75..d6d3e30aef 100644
--- a/examples/simultaneous_translation/eval/scorers/scorer.py
+++ b/examples/simultaneous_translation/eval/scorers/scorer.py
@@ -3,16 +3,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from vizseq.scorers.bleu import BLEUScorer
-from vizseq.scorers.ter import TERScorer
-from vizseq.scorers.meteor import METEORScorer
-from examples.simultaneous_translation.eval.eval_latency import LatencyScorer
-from collections import defaultdict
import json
import os
+from collections import defaultdict
+
+from examples.simultaneous_translation.eval.eval_latency import LatencyScorer
+from vizseq.scorers.bleu import BLEUScorer
+from vizseq.scorers.meteor import METEORScorer
+from vizseq.scorers.ter import TERScorer
-DEFAULT_EOS = ''
+DEFAULT_EOS = ""
class SimulScorer(object):
@@ -23,7 +24,7 @@ def __init__(self, args):
self.output_files = {
"text": os.path.join(args.output, "text"),
"delay": os.path.join(args.output, "delay"),
- "scores": os.path.join(args.output, "scores")
+ "scores": os.path.join(args.output, "scores"),
}
else:
self.output_files = None
@@ -52,14 +53,7 @@ def send_src(self, sent_id, *args):
def recv_hyp(self, sent_id, list_of_tokens):
for token in list_of_tokens:
- self.translations[
- sent_id
- ].append(
- (
- token,
- self.steps[sent_id]
- )
- )
+ self.translations[sent_id].append((token, self.steps[sent_id]))
def reset(self):
self.steps = defaultdict(int)
@@ -76,8 +70,9 @@ def score(self):
delays += [[t[1] for t in self.translations[i]]]
bleu_score = BLEUScorer(
- sent_level=False, corpus_level=True,
- extra_args={'bleu_tokenizer': self.tokenizer}
+ sent_level=False,
+ corpus_level=True,
+ extra_args={"bleu_tokenizer": self.tokenizer},
).score(translations, [self.data["tgt"]])
ter_score = TERScorer(sent_level=False, corpus_level=True).score(
@@ -92,16 +87,16 @@ def score(self):
{"src_len": src_len, "delays": delay}
for src_len, delay in zip(self.src_lengths(), delays)
],
- start_from_zero=False
+ start_from_zero=False,
)
scores = {
- 'BLEU': bleu_score[0],
- 'TER': ter_score[0],
- 'METEOR': meteor_score[0],
- 'DAL': latency_score['differentiable_average_lagging'],
- 'AL': latency_score['average_lagging'],
- 'AP': latency_score['average_proportion'],
+ "BLEU": bleu_score[0],
+ "TER": ter_score[0],
+ "METEOR": meteor_score[0],
+ "DAL": latency_score["differentiable_average_lagging"],
+ "AL": latency_score["average_lagging"],
+ "AP": latency_score["average_proportion"],
}
if self.output_files is not None:
@@ -109,9 +104,9 @@ def score(self):
os.makedirs(self.output_dir, exist_ok=True)
self.write_results_to_file(translations, delays, scores)
except BaseException as be:
- print(f'Failed to write results to {self.output_dir}.')
+ print(f"Failed to write results to {self.output_dir}.")
print(be)
- print('Skip writing predictions')
+ print("Skip writing predictions")
return scores
@@ -125,12 +120,8 @@ def write_results_to_file(self, translations, delays, scores):
with open(self.output_files["delay"], "w") as f:
for i, delay in enumerate(delays):
f.write(
- json.dumps(
- {
- "src_len": self.src_lengths()[i],
- "delays": delay
- }
- ) + "\n"
+ json.dumps({"src_len": self.src_lengths()[i], "delays": delay})
+ + "\n"
)
with open(self.output_files["scores"], "w") as f:
@@ -163,7 +154,7 @@ def _load_wav_info_from_json(cls, file):
list_to_return.append(
{
"path": item["input"]["path"].strip(),
- "length": item["input"]["length_ms"]
+ "length": item["input"]["length_ms"],
}
)
return list_to_return
diff --git a/examples/simultaneous_translation/eval/scorers/text_scorer.py b/examples/simultaneous_translation/eval/scorers/text_scorer.py
index 4a5daaff21..649a2c7e5c 100644
--- a/examples/simultaneous_translation/eval/scorers/text_scorer.py
+++ b/examples/simultaneous_translation/eval/scorers/text_scorer.py
@@ -3,8 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from . scorer import SimulScorer
from . import register_scorer
+from .scorer import SimulScorer
@register_scorer("text")
@@ -13,7 +13,7 @@ def __init__(self, args):
super().__init__(args)
self.data = {
"src": self._load_text_file(args.src_file, split=True),
- "tgt": self._load_text_file(args.tgt_file, split=False)
+ "tgt": self._load_text_file(args.tgt_file, split=False),
}
def send_src(self, sent_id, *args):
@@ -21,7 +21,7 @@ def send_src(self, sent_id, *args):
dict_to_return = {
"sent_id": sent_id,
"segment_id": self.steps[sent_id],
- "segment": self.eos
+ "segment": self.eos,
}
# Consider EOS
self.steps[sent_id] = len(self.data["src"][sent_id]) + 1
@@ -29,7 +29,7 @@ def send_src(self, sent_id, *args):
dict_to_return = {
"sent_id": sent_id,
"segment_id": self.steps[sent_id],
- "segment": self.data["src"][sent_id][self.steps[sent_id]]
+ "segment": self.data["src"][sent_id][self.steps[sent_id]],
}
self.steps[sent_id] += 1
diff --git a/examples/simultaneous_translation/eval/server.py b/examples/simultaneous_translation/eval/server.py
index a108881e38..e44ceaff85 100644
--- a/examples/simultaneous_translation/eval/server.py
+++ b/examples/simultaneous_translation/eval/server.py
@@ -3,12 +3,14 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
-import sys
import json
-from tornado import web, ioloop
+import sys
+
from scorers import build_scorer
+from tornado import ioloop, web
+
-DEFAULT_HOSTNAME = 'localhost'
+DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 12321
@@ -34,10 +36,10 @@ def get(self):
class SourceHandler(ScorerHandler):
def get(self):
- sent_id = int(self.get_argument('sent_id'))
+ sent_id = int(self.get_argument("sent_id"))
segment_size = None
if "segment_size" in self.request.arguments:
- string = self.get_argument('segment_size')
+ string = self.get_argument("segment_size")
if len(string) > 0:
segment_size = int(string)
@@ -48,8 +50,8 @@ def get(self):
class HypothesisHandler(ScorerHandler):
def put(self):
- sent_id = int(self.get_argument('sent_id'))
- list_of_tokens = self.request.body.decode('utf-8').strip().split()
+ sent_id = int(self.get_argument("sent_id"))
+ list_of_tokens = self.request.body.decode("utf-8").strip().split()
self.scorer.recv_hyp(sent_id, list_of_tokens)
@@ -67,18 +69,21 @@ def add_args():
def start_server(scorer, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, debug=False):
- app = web.Application([
- (r'/result', ResultHandler, dict(scorer=scorer)),
- (r'/src', SourceHandler, dict(scorer=scorer)),
- (r'/hypo', HypothesisHandler, dict(scorer=scorer)),
- (r'/', EvalSessionHandler, dict(scorer=scorer)),
- ], debug=debug)
+ app = web.Application(
+ [
+ (r"/result", ResultHandler, dict(scorer=scorer)),
+ (r"/src", SourceHandler, dict(scorer=scorer)),
+ (r"/hypo", HypothesisHandler, dict(scorer=scorer)),
+ (r"/", EvalSessionHandler, dict(scorer=scorer)),
+ ],
+ debug=debug,
+ )
app.listen(port, max_buffer_size=1024 ** 3)
sys.stdout.write(f"Evaluation Server Started. Listening to port {port}\n")
ioloop.IOLoop.current().start()
-if __name__ == '__main__':
+if __name__ == "__main__":
args = add_args()
scorer = build_scorer(args)
start_server(scorer, args.hostname, args.port, args.debug)
diff --git a/examples/simultaneous_translation/models/__init__.py b/examples/simultaneous_translation/models/__init__.py
index 138006ed8c..083da43732 100644
--- a/examples/simultaneous_translation/models/__init__.py
+++ b/examples/simultaneous_translation/models/__init__.py
@@ -6,7 +6,10 @@
import importlib
import os
+
for file in os.listdir(os.path.dirname(__file__)):
- if file.endswith('.py') and not file.startswith('_'):
- model_name = file[:file.find('.py')]
- importlib.import_module('examples.simultaneous_translation.models.' + model_name)
+ if file.endswith(".py") and not file.startswith("_"):
+ model_name = file[: file.find(".py")]
+ importlib.import_module(
+ "examples.simultaneous_translation.models." + model_name
+ )
diff --git a/examples/simultaneous_translation/models/transformer_monotonic_attention.py b/examples/simultaneous_translation/models/transformer_monotonic_attention.py
index 759f195386..ab8adf3aab 100644
--- a/examples/simultaneous_translation/models/transformer_monotonic_attention.py
+++ b/examples/simultaneous_translation/models/transformer_monotonic_attention.py
@@ -6,42 +6,34 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
-from fairseq.models import (
- register_model,
- register_model_architecture,
+from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
+ TransformerMonotonicDecoderLayer,
+ TransformerMonotonicEncoderLayer,
)
-
-
+from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
- TransformerModel,
- TransformerEncoder,
TransformerDecoder,
+ TransformerEncoder,
+ TransformerModel,
base_architecture,
transformer_iwslt_de_en,
transformer_vaswani_wmt_en_de_big,
)
-from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
- TransformerMonotonicDecoderLayer,
- TransformerMonotonicEncoderLayer
-)
-
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
-@register_model('transformer_unidirectional')
+@register_model("transformer_unidirectional")
class TransformerUnidirectionalModel(TransformerModel):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
-@register_model('transformer_monotonic')
+@register_model("transformer_monotonic")
class TransformerMonotonicModel(TransformerModel):
-
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
@@ -62,26 +54,17 @@ def _indices_from_states(self, states):
)
tgt_indices = tensor(
- [
- [self.decoder.dictionary.eos()]
- + states["indices"]["tgt"]
- ]
+ [[self.decoder.dictionary.eos()] + states["indices"]["tgt"]]
)
else:
- src_indices = states["indices"]["src"][: 1 +
- states["steps"]["src"]]
+ src_indices = states["indices"]["src"][: 1 + states["steps"]["src"]]
tgt_indices = states["indices"]["tgt"]
return src_indices, None, tgt_indices
def predict_from_states(self, states):
- decoder_states = self.decoder.output_layer(
- states["decoder_features"]
- )
- lprobs = self.get_normalized_probs(
- [decoder_states[:, -1:]],
- log_probs=True
- )
+ decoder_states = self.decoder.output_layer(states["decoder_features"])
+ lprobs = self.get_normalized_probs([decoder_states[:, -1:]], log_probs=True)
index = lprobs.argmax(dim=-1)
@@ -90,25 +73,24 @@ def predict_from_states(self, states):
return token, index[0, 0].item()
def decision_from_states(self, states):
- '''
+ """
This funcion take states dictionary as input, and gives the agent
a decision of whether read a token from server. Moreover, the decoder
states are also calculated here so we can directly generate a target
token without recompute every thing
- '''
+ """
self.eval()
if len(states["tokens"]["src"]) == 0:
return 0
- src_indices, src_lengths, tgt_indices = self._indices_from_states(
- states)
+ src_indices, src_lengths, tgt_indices = self._indices_from_states(states)
# Update encoder states if needed
if (
- "encoder_states" not in states or
- states["encoder_states"][0].size(1) <= states["steps"]["src"]
+ "encoder_states" not in states
+ or states["encoder_states"][0].size(1) <= states["steps"]["src"]
):
encoder_out_dict = self.encoder(src_indices, src_lengths)
states["encoder_states"] = encoder_out_dict
@@ -136,16 +118,14 @@ def decision_from_states(self, states):
class TransformerMonotonicEncoder(TransformerEncoder):
-
def __init__(self, args, dictionary, embed_tokens):
super().__init__(args, dictionary, embed_tokens)
self.dictionary = dictionary
self.layers = nn.ModuleList([])
- self.layers.extend([
- TransformerMonotonicEncoderLayer(args)
- for i in range(args.encoder_layers)
- ])
+ self.layers.extend(
+ [TransformerMonotonicEncoderLayer(args) for i in range(args.encoder_layers)]
+ )
class TransformerMonotonicDecoder(TransformerDecoder):
@@ -166,19 +146,24 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
self.dictionary = dictionary
self.layers = nn.ModuleList([])
- self.layers.extend([
- TransformerMonotonicDecoderLayer(args, no_encoder_attn)
- for _ in range(args.decoder_layers)
- ])
+ self.layers.extend(
+ [
+ TransformerMonotonicDecoderLayer(args, no_encoder_attn)
+ for _ in range(args.decoder_layers)
+ ]
+ )
def pre_attention(
- self, prev_output_tokens, encoder_out_dict,
- incremental_state=None
+ self, prev_output_tokens, encoder_out_dict, incremental_state=None
):
- positions = self.embed_positions(
- prev_output_tokens,
- incremental_state=incremental_state,
- ) if self.embed_positions is not None else None
+ positions = (
+ self.embed_positions(
+ prev_output_tokens,
+ incremental_state=incremental_state,
+ )
+ if self.embed_positions is not None
+ else None
+ )
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
@@ -216,8 +201,7 @@ def post_attention(self, x):
return x
def extract_features(
- self, prev_output_tokens, encoder_out,
- incremental_state=None, **unused
+ self, prev_output_tokens, encoder_out, incremental_state=None, **unused
):
"""
Similar to *forward* but only return features.
@@ -228,14 +212,8 @@ def extract_features(
- a dictionary with any model-specific outputs
"""
# incremental_state = None
- (
- x,
- encoder_outs,
- encoder_padding_mask
- ) = self.pre_attention(
- prev_output_tokens,
- encoder_out,
- incremental_state
+ (x, encoder_outs, encoder_padding_mask) = self.pre_attention(
+ prev_output_tokens, encoder_out, incremental_state
)
attn = None
inner_states = [x]
@@ -250,7 +228,8 @@ def extract_features(
encoder_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
self_attn_mask=self.buffered_future_mask(x)
- if incremental_state is None else None,
+ if incremental_state is None
+ else None,
)
inner_states.append(x)
@@ -261,38 +240,30 @@ def extract_features(
step_list.append(curr_steps)
if incremental_state.get("online", False):
- p_choose = attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())
-
- new_steps = (
- curr_steps
- + (p_choose < 0.5).t().type_as(curr_steps)
+ p_choose = (
+ attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())
)
+ new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps)
+
if (new_steps >= incremental_state["steps"]["src"]).any():
# We need to prune the last self_attn saved_state
# if model decide not to read
# otherwise there will be duplicated saved_state
for j in range(i + 1):
- self.layers[j].prune_incremental_state(
- incremental_state)
+ self.layers[j].prune_incremental_state(incremental_state)
return x, {"action": 0}
- if (
- incremental_state is not None
- and not incremental_state.get("online", False)
- ):
+ if incremental_state is not None and not incremental_state.get("online", False):
# Here is for fast evaluation
- fastest_step = torch.max(
- torch.cat(step_list, dim=1),
- dim=1,
- keepdim=True
- )[0] + 1
+ fastest_step = (
+ torch.max(torch.cat(step_list, dim=1), dim=1, keepdim=True)[0] + 1
+ )
if "fastest_step" in incremental_state:
incremental_state["fastest_step"] = torch.cat(
- [incremental_state["fastest_step"], fastest_step],
- dim=1
+ [incremental_state["fastest_step"], fastest_step], dim=1
)
else:
incremental_state["fastest_step"] = fastest_step
@@ -310,25 +281,19 @@ def extract_features(
def reorder_incremental_state(self, incremental_state, new_order):
super().reorder_incremental_state(incremental_state, new_order)
if "fastest_step" in incremental_state:
- incremental_state["fastest_step"] = (
- incremental_state["fastest_step"]
- .index_select(0, new_order)
- )
+ incremental_state["fastest_step"] = incremental_state[
+ "fastest_step"
+ ].index_select(0, new_order)
-@register_model_architecture(
- 'transformer_monotonic',
- 'transformer_monotonic'
-)
+@register_model_architecture("transformer_monotonic", "transformer_monotonic")
def base_monotonic_rchitecture(args):
base_architecture(args)
- args.encoder_unidirectional = getattr(
- args, 'encoder_unidirectional', False)
+ args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False)
@register_model_architecture(
- 'transformer_monotonic',
- 'transformer_monotonic_iwslt_de_en'
+ "transformer_monotonic", "transformer_monotonic_iwslt_de_en"
)
def transformer_monotonic_iwslt_de_en(args):
transformer_iwslt_de_en(args)
@@ -337,24 +302,21 @@ def transformer_monotonic_iwslt_de_en(args):
# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture(
- 'transformer_monotonic',
- 'transformer_monotonic_vaswani_wmt_en_de_big'
+ "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_de_big"
)
def transformer_monotonic_vaswani_wmt_en_de_big(args):
transformer_vaswani_wmt_en_de_big(args)
@register_model_architecture(
- 'transformer_monotonic',
- 'transformer_monotonic_vaswani_wmt_en_fr_big'
+ "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_fr_big"
)
def transformer_monotonic_vaswani_wmt_en_fr_big(args):
transformer_monotonic_vaswani_wmt_en_fr_big(args)
@register_model_architecture(
- 'transformer_unidirectional',
- 'transformer_unidirectional_iwslt_de_en'
+ "transformer_unidirectional", "transformer_unidirectional_iwslt_de_en"
)
def transformer_unidirectional_iwslt_de_en(args):
transformer_iwslt_de_en(args)
diff --git a/examples/simultaneous_translation/modules/__init__.py b/examples/simultaneous_translation/modules/__init__.py
index 8fd9d379a5..ad64774de4 100644
--- a/examples/simultaneous_translation/modules/__init__.py
+++ b/examples/simultaneous_translation/modules/__init__.py
@@ -7,14 +7,18 @@
import os
from fairseq import registry
+
+
(
build_monotonic_attention,
register_monotonic_attention,
MONOTONIC_ATTENTION_REGISTRY,
- _
-) = registry.setup_registry('--simul-type')
+ _,
+) = registry.setup_registry("--simul-type")
for file in os.listdir(os.path.dirname(__file__)):
- if file.endswith('.py') and not file.startswith('_'):
- model_name = file[:file.find('.py')]
- importlib.import_module('examples.simultaneous_translation.modules.' + model_name)
+ if file.endswith(".py") and not file.startswith("_"):
+ model_name = file[: file.find(".py")]
+ importlib.import_module(
+ "examples.simultaneous_translation.modules." + model_name
+ )
diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py
index d508b8cfba..c09725ac9a 100644
--- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py
+++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py
@@ -4,22 +4,19 @@
# LICENSE file in the root directory of this source tree.
import math
+
import torch
-import torch.nn.functional as F
import torch.nn as nn
-
-from fairseq import utils
-
-from fairseq.modules import MultiheadAttention
-
+import torch.nn.functional as F
from examples.simultaneous_translation.utils.functions import (
exclusive_cumprod,
- lengths_to_mask
+ lengths_to_mask,
)
-
-
+from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
+from fairseq.modules import MultiheadAttention
from fairseq.utils import convert_padding_direction
+
from . import register_monotonic_attention
@@ -28,6 +25,7 @@ class MonotonicAttention(nn.Module):
"""
Abstract class of monotonic attentions
"""
+
def __init__(self, args):
self.eps = args.attention_eps
self.mass_preservation = args.mass_preservation
@@ -38,7 +36,8 @@ def __init__(self, args):
self.energy_bias_init = args.energy_bias_init
self.energy_bias = (
nn.Parameter(self.energy_bias_init * torch.ones([1]))
- if args.energy_bias is True else 0
+ if args.energy_bias is True
+ else 0
)
@staticmethod
@@ -90,7 +89,7 @@ def attn_energy(self, q_proj, k_proj, key_padding_mask=None):
if key_padding_mask is not None:
attn_energy = attn_energy.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
- float('-inf'),
+ float("-inf"),
)
return attn_energy
@@ -131,10 +130,7 @@ def expected_alignment_train(self, p_choose, key_padding_mask):
alpha_i = (
p_choose[:, i]
* cumprod_1mp[:, i]
- * torch.cumsum(
- previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i],
- dim=1
- )
+ * torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1)
).clamp(0, 1.0)
previous_attn.append(alpha_i.unsqueeze(1))
@@ -170,8 +166,7 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state
# prev_monotonic_step: bsz, num_heads
bsz = bsz_num_heads // self.num_heads
prev_monotonic_step = monotonic_cache.get(
- "step",
- p_choose.new_zeros([bsz, self.num_heads]).long()
+ "step", p_choose.new_zeros([bsz, self.num_heads]).long()
)
bsz, num_heads = prev_monotonic_step.size()
assert num_heads == self.num_heads
@@ -181,8 +176,7 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state
p_choose = p_choose.view(bsz, num_heads, src_len)
if key_padding_mask is not None:
- src_lengths = src_len - \
- key_padding_mask.sum(dim=1, keepdim=True).long()
+ src_lengths = src_len - key_padding_mask.sum(dim=1, keepdim=True).long()
else:
src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len
@@ -197,10 +191,7 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state
# left_pad_source = True:
step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
- max_steps = (
- src_lengths - 1 if self.mass_preservation
- else src_lengths
- )
+ max_steps = src_lengths - 1 if self.mass_preservation else src_lengths
# finish_read: bsz, num_heads
finish_read = new_monotonic_step.eq(max_steps)
@@ -210,11 +201,11 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state
# only choose the p at monotonic steps
# p_choose_i: bsz , self.num_heads
p_choose_i = (
- p_choose
- .gather(
+ p_choose.gather(
2,
- (step_offset + new_monotonic_step).unsqueeze(2)
- .clamp(0, src_len - 1)
+ (step_offset + new_monotonic_step)
+ .unsqueeze(2)
+ .clamp(0, src_len - 1),
)
).squeeze(2)
@@ -239,21 +230,17 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state
# alpha: bsz * num_heads, 1, src_len
# new_monotonic_step: bsz, num_heads
- alpha = (
- p_choose
- .new_zeros([bsz * self.num_heads, src_len])
- .scatter(
- 1,
- (step_offset + new_monotonic_step).view(bsz *
- self.num_heads, 1).clamp(0, src_len - 1),
- 1
- )
+ alpha = p_choose.new_zeros([bsz * self.num_heads, src_len]).scatter(
+ 1,
+ (step_offset + new_monotonic_step)
+ .view(bsz * self.num_heads, 1)
+ .clamp(0, src_len - 1),
+ 1,
)
if not self.mass_preservation:
alpha = alpha.masked_fill(
- (new_monotonic_step == max_steps).view(bsz * self.num_heads, 1),
- 0
+ (new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), 0
)
alpha = alpha.unsqueeze(1)
@@ -266,8 +253,14 @@ def v_proj_output(self, value):
raise NotImplementedError
def forward(
- self, query, key, value,
- key_padding_mask=None, incremental_state=None, *args, **kwargs,
+ self,
+ query,
+ key,
+ value,
+ key_padding_mask=None,
+ incremental_state=None,
+ *args,
+ **kwargs,
):
tgt_len, bsz, embed_dim = query.size()
@@ -280,25 +273,24 @@ def forward(
# expected alignment alpha
# bsz * self.num_heads, tgt_len, src_len
if incremental_state is not None:
- alpha = self.expected_alignment_infer(p_choose, key_padding_mask, incremental_state)
+ alpha = self.expected_alignment_infer(
+ p_choose, key_padding_mask, incremental_state
+ )
else:
alpha = self.expected_alignment_train(p_choose, key_padding_mask)
# expected attention beta
# bsz * self.num_heads, tgt_len, src_len
- beta = self.expected_attention(alpha, query, key, value, key_padding_mask, incremental_state)
+ beta = self.expected_attention(
+ alpha, query, key, value, key_padding_mask, incremental_state
+ )
attn_weights = beta
v_proj = self.v_proj_output(value)
attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)
- attn = (
- attn
- .transpose(0, 1)
- .contiguous()
- .view(tgt_len, bsz, embed_dim)
- )
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
@@ -318,26 +310,32 @@ def reorder_incremental_state(self, incremental_state, new_order):
self._set_monotonic_buffer(incremental_state, input_buffer)
def _get_monotonic_buffer(self, incremental_state):
- return utils.get_incremental_state(
- self,
- incremental_state,
- 'monotonic',
- ) or {}
+ return (
+ utils.get_incremental_state(
+ self,
+ incremental_state,
+ "monotonic",
+ )
+ or {}
+ )
def _set_monotonic_buffer(self, incremental_state, buffer):
utils.set_incremental_state(
self,
incremental_state,
- 'monotonic',
+ "monotonic",
buffer,
)
def get_pointer(self, incremental_state):
- return utils.get_incremental_state(
- self,
- incremental_state,
- 'monotonic',
- ) or {}
+ return (
+ utils.get_incremental_state(
+ self,
+ incremental_state,
+ "monotonic",
+ )
+ or {}
+ )
def get_fastest_pointer(self, incremental_state):
return self.get_pointer(incremental_state)["step"].max(0)[0]
@@ -354,23 +352,22 @@ def set_pointer(self, incremental_state, p_choose):
utils.set_incremental_state(
self,
incremental_state,
- 'monotonic',
+ "monotonic",
{"step": buffer},
)
@register_monotonic_attention("hard_aligned")
class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
-
def __init__(self, args):
MultiheadAttention.__init__(
self,
embed_dim=args.decoder_embed_dim,
num_heads=args.decoder_attention_heads,
- kdim=getattr(args, 'encoder_embed_dim', None),
- vdim=getattr(args, 'encoder_embed_dim', None),
+ kdim=getattr(args, "encoder_embed_dim", None),
+ vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
- encoder_decoder_attention=True
+ encoder_decoder_attention=True,
)
MonotonicAttention.__init__(self, args)
@@ -395,21 +392,33 @@ def input_projections(self, query, key, value, name):
bsz = query.size(1)
q = self.q_in_proj[name](query)
q *= self.scaling
- q = q.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ q = (
+ q.contiguous()
+ .view(-1, bsz * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
else:
q = None
if key is not None:
bsz = key.size(1)
k = self.k_in_proj[name](key)
- k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ k = (
+ k.contiguous()
+ .view(-1, bsz * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
else:
k = None
if value is not None:
bsz = value.size(1)
v = self.v_in_proj[name](value)
- v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ v = (
+ v.contiguous()
+ .view(-1, bsz * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
else:
v = None
@@ -441,8 +450,7 @@ def p_choose(self, query, key, key_padding_mask=None):
if self.training:
# add noise here to encourage discretness
noise = (
- torch
- .normal(self.noise_mean, self.noise_var, attn_energy.size())
+ torch.normal(self.noise_mean, self.noise_var, attn_energy.size())
.type_as(attn_energy)
.to(attn_energy.device)
)
@@ -454,9 +462,9 @@ def p_choose(self, query, key, key_padding_mask=None):
return p_choose.view(-1, tgt_len, src_len)
def expected_attention(self, alpha, *args):
- '''
+ """
For MMA-H, beta = alpha
- '''
+ """
return alpha
def v_proj_output(self, value):
@@ -479,13 +487,19 @@ def init_soft_attention(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
- nn.init.xavier_uniform_(self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2))
- nn.init.xavier_uniform_(self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2))
+ nn.init.xavier_uniform_(
+ self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
+ )
+ nn.init.xavier_uniform_(
+ self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
+ )
else:
nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)
- def expected_attention(self, alpha, query, key, value, key_padding_mask, incremental_state):
+ def expected_attention(
+ self, alpha, query, key, value, key_padding_mask, incremental_state
+ ):
# monotonic attention, we will calculate milk here
bsz_x_num_heads, tgt_len, src_len = alpha.size()
bsz = int(bsz_x_num_heads / self.num_heads)
@@ -507,9 +521,10 @@ def expected_attention(self, alpha, query, key, value, key_padding_mask, increme
step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
monotonic_step += step_offset
mask = lengths_to_mask(
- monotonic_step.view(-1), soft_energy.size(2), 1).unsqueeze(1)
+ monotonic_step.view(-1), soft_energy.size(2), 1
+ ).unsqueeze(1)
- soft_energy = soft_energy.masked_fill(~ mask.bool(), float('-inf'))
+ soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf"))
soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
exp_soft_energy = torch.exp(soft_energy)
exp_soft_energy_sum = exp_soft_energy.sum(dim=2)
@@ -524,14 +539,20 @@ def expected_attention(self, alpha, query, key, value, key_padding_mask, increme
if key_padding_mask is not None:
if key_padding_mask.any():
exp_soft_energy_cumsum = (
- exp_soft_energy_cumsum.view(-1, self.num_heads, tgt_len, src_len)
- .masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps)
+ exp_soft_energy_cumsum.view(
+ -1, self.num_heads, tgt_len, src_len
+ )
+ .masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps
+ )
.view(-1, tgt_len, src_len)
)
inner_items = alpha / exp_soft_energy_cumsum
- beta = exp_soft_energy * torch.cumsum(inner_items.flip(dims=[2]), dim=2).flip(dims=[2])
+ beta = exp_soft_energy * torch.cumsum(
+ inner_items.flip(dims=[2]), dim=2
+ ).flip(dims=[2])
beta = self.dropout_module(beta)
@@ -547,7 +568,9 @@ def __init__(self, args):
self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
self.k_in_proj["soft"] = self.k_in_proj["monotonic"]
self.waitk_lagging = args.waitk_lagging
- assert self.waitk_lagging > 0, f"Lagging has to been larger than 0, get {self.waitk_lagging}."
+ assert (
+ self.waitk_lagging > 0
+ ), f"Lagging has to been larger than 0, get {self.waitk_lagging}."
@staticmethod
def add_args(parser):
@@ -556,10 +579,13 @@ def add_args(parser):
MonotonicMultiheadAttentionWaitk,
).add_args(parser)
- parser.add_argument('--waitk-lagging', type=int, required=True,
- help='Wait k lagging')
+ parser.add_argument(
+ "--waitk-lagging", type=int, required=True, help="Wait k lagging"
+ )
- def p_choose(self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None):
+ def p_choose(
+ self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None
+ ):
"""
query: bsz, tgt_len
key: bsz, src_len
@@ -574,16 +600,22 @@ def p_choose(self, query, key, key_padding_mask=None, attn_mask=None, incrementa
if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any():
# Left pad source
# add -1 to the end
- p_choose = p_choose.masked_fill(key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1)
- p_choose = convert_padding_direction(p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True)
+ p_choose = p_choose.masked_fill(
+ key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1
+ )
+ p_choose = convert_padding_direction(
+ p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True
+ )
p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query)
# remove -1
p_choose[p_choose.eq(-1)] = 0
# Extend to each head
p_choose = (
- p_choose.contiguous().unsqueeze(1)
- .expand(-1, self.num_heads, -1, -1).contiguous()
+ p_choose.contiguous()
+ .unsqueeze(1)
+ .expand(-1, self.num_heads, -1, -1)
+ .contiguous()
.view(-1, tgt_len, src_len)
)
diff --git a/examples/simultaneous_translation/modules/monotonic_transformer_layer.py b/examples/simultaneous_translation/modules/monotonic_transformer_layer.py
index a9545b2540..442b7d487d 100644
--- a/examples/simultaneous_translation/modules/monotonic_transformer_layer.py
+++ b/examples/simultaneous_translation/modules/monotonic_transformer_layer.py
@@ -3,37 +3,32 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from fairseq.modules import (
- LayerNorm,
- TransformerEncoderLayer,
- TransformerDecoderLayer
-)
+from fairseq.modules import LayerNorm, TransformerDecoderLayer, TransformerEncoderLayer
from . import build_monotonic_attention
class TransformerMonotonicEncoderLayer(TransformerEncoderLayer):
-
def forward(self, x, encoder_padding_mask):
seq_len, _, _ = x.size()
attn_mask = x.new_ones([seq_len, seq_len]).triu(1)
- attn_mask = attn_mask.masked_fill(attn_mask.bool(), float('-inf'))
+ attn_mask = attn_mask.masked_fill(attn_mask.bool(), float("-inf"))
return super().forward(x, encoder_padding_mask, attn_mask)
class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
-
- def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
+ def __init__(
+ self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
+ ):
super().__init__(
args,
no_encoder_attn=True,
add_bias_kv=add_bias_kv,
- add_zero_attn=add_zero_attn
+ add_zero_attn=add_zero_attn,
)
self.encoder_attn = build_monotonic_attention(args)
self.encoder_attn_layer_norm = LayerNorm(
- self.embed_dim,
- export=getattr(args, 'char_inputs', False)
+ self.embed_dim, export=getattr(args, "char_inputs", False)
)
def prune_incremental_state(self, incremental_state):
@@ -46,12 +41,8 @@ def prune(module):
input_buffer = {}
break
module._set_input_buffer(incremental_state, input_buffer)
+
prune(self.self_attn)
def get_steps(self, incremental_state):
- return (
- self.encoder_attn
- ._get_monotonic_buffer(
- incremental_state
- ).get("step", 0)
- )
+ return self.encoder_attn._get_monotonic_buffer(incremental_state).get("step", 0)
diff --git a/examples/simultaneous_translation/utils/__init__.py b/examples/simultaneous_translation/utils/__init__.py
index 8e5886008f..be0ba4d99a 100644
--- a/examples/simultaneous_translation/utils/__init__.py
+++ b/examples/simultaneous_translation/utils/__init__.py
@@ -9,6 +9,6 @@
# automatically import any Python files in the criterions/ directory
for file in os.listdir(os.path.dirname(__file__)):
- if file.endswith('.py') and not file.startswith('_'):
- module = file[:file.find('.py')]
- importlib.import_module('examples.simultaneous_translation.utils.' + module)
+ if file.endswith(".py") and not file.startswith("_"):
+ module = file[: file.find(".py")]
+ importlib.import_module("examples.simultaneous_translation.utils." + module)
diff --git a/examples/simultaneous_translation/utils/functions.py b/examples/simultaneous_translation/utils/functions.py
index 620dd1d866..f795b5f31c 100644
--- a/examples/simultaneous_translation/utils/functions.py
+++ b/examples/simultaneous_translation/utils/functions.py
@@ -16,7 +16,9 @@ def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10):
tensor_size = list(tensor.size())
tensor_size[dim] = 1
return_tensor = safe_cumprod(
- torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim), dim=dim, eps=eps
+ torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim),
+ dim=dim,
+ eps=eps,
)
if dim == 0:
@@ -132,12 +134,14 @@ def moving_sum(x, start_idx: int, end_idx: int):
# batch_size, 1, src_len
moving_sum_weight = x.new_ones([1, 1, end_idx + start_idx - 1])
- moving_sum = torch.nn.functional.conv1d(
- x,
- moving_sum_weight,
- padding=start_idx + end_idx - 1
- ).squeeze(1).t()
- moving_sum = moving_sum[end_idx: -start_idx]
+ moving_sum = (
+ torch.nn.functional.conv1d(
+ x, moving_sum_weight, padding=start_idx + end_idx - 1
+ )
+ .squeeze(1)
+ .t()
+ )
+ moving_sum = moving_sum[end_idx:-start_idx]
assert src_len == moving_sum.size(0)
assert batch_size == moving_sum.size(1)
diff --git a/examples/simultaneous_translation/utils/latency.py b/examples/simultaneous_translation/utils/latency.py
index 9d09584176..5d800a5d9e 100644
--- a/examples/simultaneous_translation/utils/latency.py
+++ b/examples/simultaneous_translation/utils/latency.py
@@ -18,7 +18,7 @@ def prepare_latency_metric(
src_lens,
target_padding_mask=None,
batch_first: bool = False,
- start_from_zero: bool = True
+ start_from_zero: bool = True,
):
assert len(delays.size()) == 2
assert len(src_lens.size()) == 2
@@ -59,11 +59,7 @@ def __call__(
start_from_zero: bool = True,
):
delays, src_lens, tgt_lens, target_padding_mask = self.prepare_latency_metric(
- delays,
- src_lens,
- target_padding_mask,
- batch_first,
- start_from_zero
+ delays, src_lens, target_padding_mask, batch_first, start_from_zero
)
return self.cal_metric(delays, src_lens, tgt_lens, target_padding_mask)
@@ -89,10 +85,13 @@ class AverageProportion(LatencyMetric):
AP = 1 / (|x||y]) sum_i^|Y| deleys_i
"""
+
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
if target_padding_mask is not None:
- AP = torch.sum(delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True)
+ AP = torch.sum(
+ delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True
+ )
else:
AP = torch.sum(delays, dim=0, keepdim=True)
@@ -116,14 +115,24 @@ class AverageLagging(LatencyMetric):
gamma = |y| / |x|
tau = argmin_i(delays_i = |x|)
"""
+
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
# tau = argmin_i(delays_i = |x|)
tgt_len, bsz = delays.size()
lagging_padding_mask = delays >= src_lens
- lagging_padding_mask = torch.nn.functional.pad(lagging_padding_mask.t(), (1, 0)).t()[:-1, :]
+ lagging_padding_mask = torch.nn.functional.pad(
+ lagging_padding_mask.t(), (1, 0)
+ ).t()[:-1, :]
gamma = tgt_lens / src_lens
- lagging = delays - torch.arange(delays.size(0)).unsqueeze(1).type_as(delays).expand_as(delays) / gamma
+ lagging = (
+ delays
+ - torch.arange(delays.size(0))
+ .unsqueeze(1)
+ .type_as(delays)
+ .expand_as(delays)
+ / gamma
+ )
lagging.masked_fill_(lagging_padding_mask, 0)
tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=0, keepdim=True)
AL = lagging.sum(dim=0, keepdim=True) / tau
@@ -149,6 +158,7 @@ class DifferentiableAverageLagging(LatencyMetric):
2. max(delays_i, delays'_{i-1} + 1 / gamma)
"""
+
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
tgt_len, bsz = delays.size()
@@ -163,13 +173,18 @@ def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
new_delays[i] = torch.cat(
[
new_delays[i - 1].unsqueeze(0) + 1 / gamma,
- delays[i].unsqueeze(0)
+ delays[i].unsqueeze(0),
],
- dim=0
+ dim=0,
).max(dim=0)[0]
DAL = (
- new_delays - torch.arange(delays.size(0)).unsqueeze(1).type_as(delays).expand_as(delays) / gamma
+ new_delays
+ - torch.arange(delays.size(0))
+ .unsqueeze(1)
+ .type_as(delays)
+ .expand_as(delays)
+ / gamma
)
if target_padding_mask is not None:
DAL = DAL.masked_fill(target_padding_mask, 0)
@@ -186,7 +201,7 @@ def prepare_latency_metric(
src_lens,
target_padding_mask=None,
batch_first: bool = True,
- start_from_zero: bool = True
+ start_from_zero: bool = True,
):
assert batch_first
assert len(delays.size()) == 3
@@ -256,25 +271,21 @@ def __call__(self, monotonic_step, src_lens):
src_lens = src_lens
- delays = (
- monotonic_step
- .view(monotonic_step.size(0), -1, monotonic_step.size(-1))
- .max(dim=1)[0]
- )
+ delays = monotonic_step.view(
+ monotonic_step.size(0), -1, monotonic_step.size(-1)
+ ).max(dim=1)[0]
- delays = (
- delays.masked_fill(delays >= src_lens, 0)
- + (src_lens - 1)
- .expand_as(delays)
- .masked_fill(delays < src_lens, 0)
- )
+ delays = delays.masked_fill(delays >= src_lens, 0) + (src_lens - 1).expand_as(
+ delays
+ ).masked_fill(delays < src_lens, 0)
return_dict = {}
for key, func in self.metric_calculator.items():
return_dict[key] = func(
- delays.float(), src_lens.float(),
+ delays.float(),
+ src_lens.float(),
target_padding_mask=None,
batch_first=True,
- start_from_zero=True
+ start_from_zero=True,
).t()
return return_dict
@@ -282,8 +293,13 @@ def __call__(self, monotonic_step, src_lens):
class LatencyTraining(object):
def __init__(
- self, avg_weight, var_weight, avg_type, var_type,
- stay_on_last_token, average_method,
+ self,
+ avg_weight,
+ var_weight,
+ avg_type,
+ var_type,
+ stay_on_last_token,
+ average_method,
):
self.avg_weight = avg_weight
self.var_weight = var_weight
@@ -319,17 +335,12 @@ def expected_delays_from_attention(
attention = attention.view(-1, tgt_len, src_len)
if not self.stay_on_last_token:
- residual_attention = \
- 1 - attention[:, :, :-1].sum(dim=2, keepdim=True)
- attention = torch.cat(
- [attention[:, :, :-1], residual_attention],
- dim=2
- )
+ residual_attention = 1 - attention[:, :, :-1].sum(dim=2, keepdim=True)
+ attention = torch.cat([attention[:, :, :-1], residual_attention], dim=2)
# bsz * num_heads_x_num_layers, tgt_len, src_len for MMA
steps = (
- torch
- .arange(1, 1 + src_len)
+ torch.arange(1, 1 + src_len)
.unsqueeze(0)
.unsqueeze(1)
.expand_as(attention)
@@ -355,15 +366,12 @@ def expected_delays_from_attention(
src_lens = src_lens.view(-1, 1)
# bsz * num_heads_num_layers, tgt_len, src_len
- expected_delays = (steps * attention).sum(dim=2).view(
- bsz, num_heads_x_layers, tgt_len
+ expected_delays = (
+ (steps * attention).sum(dim=2).view(bsz, num_heads_x_layers, tgt_len)
)
if target_padding_mask is not None:
- expected_delays.masked_fill_(
- target_padding_mask.unsqueeze(1),
- 0
- )
+ expected_delays.masked_fill_(target_padding_mask.unsqueeze(1), 0)
return expected_delays, src_lens
@@ -371,8 +379,7 @@ def avg_loss(self, expected_delays, src_lens, target_padding_mask):
bsz, num_heads_x_layers, tgt_len = expected_delays.size()
target_padding_mask = (
- target_padding_mask
- .unsqueeze(1)
+ target_padding_mask.unsqueeze(1)
.expand_as(expected_delays)
.contiguous()
.view(-1, tgt_len)
@@ -396,8 +403,11 @@ def avg_loss(self, expected_delays, src_lens, target_padding_mask):
if self.avg_weight > 0.0:
if self.avg_type in self.metric_calculator:
average_delays = self.metric_calculator[self.avg_type](
- expected_delays, src_lens, target_padding_mask,
- batch_first=True, start_from_zero=False
+ expected_delays,
+ src_lens,
+ target_padding_mask,
+ batch_first=True,
+ start_from_zero=False,
)
else:
raise RuntimeError(f"{self.avg_type} is not supported.")
@@ -408,12 +418,17 @@ def avg_loss(self, expected_delays, src_lens, target_padding_mask):
return 0.0
def var_loss(self, expected_delays, src_lens, target_padding_mask):
- src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[:, :1]
+ src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[
+ :, :1
+ ]
if self.var_weight > 0.0:
if self.var_type in self.variance_calculator:
variance_delays = self.variance_calculator[self.var_type](
- expected_delays, src_lens, target_padding_mask,
- batch_first=True, start_from_zero=False
+ expected_delays,
+ src_lens,
+ target_padding_mask,
+ batch_first=True,
+ start_from_zero=False,
)
else:
raise RuntimeError(f"{self.var_type} is not supported.")
diff --git a/examples/speech_recognition/__init__.py b/examples/speech_recognition/__init__.py
index cd780902e3..0278f6a273 100644
--- a/examples/speech_recognition/__init__.py
+++ b/examples/speech_recognition/__init__.py
@@ -1 +1 @@
-from . import tasks, criterions, models # noqa
+from . import criterions, models, tasks # noqa
diff --git a/examples/speech_recognition/criterions/ASG_loss.py b/examples/speech_recognition/criterions/ASG_loss.py
index 8f932bcd5b..7493654afc 100644
--- a/examples/speech_recognition/criterions/ASG_loss.py
+++ b/examples/speech_recognition/criterions/ASG_loss.py
@@ -6,9 +6,9 @@
# LICENSE file in the root directory of this source tree.
import torch
+from examples.speech_recognition.data.replabels import pack_replabels
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
-from examples.speech_recognition.data.replabels import pack_replabels
@register_criterion("asg_loss")
diff --git a/examples/speech_recognition/data/__init__.py b/examples/speech_recognition/data/__init__.py
index 737a22ec3a..47bb6e24dd 100644
--- a/examples/speech_recognition/data/__init__.py
+++ b/examples/speech_recognition/data/__init__.py
@@ -5,6 +5,7 @@
from .asr_dataset import AsrDataset
+
__all__ = [
- 'AsrDataset',
+ "AsrDataset",
]
diff --git a/examples/speech_recognition/data/asr_dataset.py b/examples/speech_recognition/data/asr_dataset.py
index 47969a2853..63a6fcac85 100644
--- a/examples/speech_recognition/data/asr_dataset.py
+++ b/examples/speech_recognition/data/asr_dataset.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import os
+
import numpy as np
from fairseq.data import FairseqDataset
@@ -30,16 +31,22 @@ class AsrDataset(FairseqDataset):
"""
def __init__(
- self, aud_paths, aud_durations_ms, tgt,
- tgt_dict, ids, speakers,
- num_mel_bins=80, frame_length=25.0, frame_shift=10.0
+ self,
+ aud_paths,
+ aud_durations_ms,
+ tgt,
+ tgt_dict,
+ ids,
+ speakers,
+ num_mel_bins=80,
+ frame_length=25.0,
+ frame_shift=10.0,
):
assert frame_length > 0
assert frame_shift > 0
assert all(x > frame_length for x in aud_durations_ms)
self.frame_sizes = [
- int(1 + (d - frame_length) / frame_shift)
- for d in aud_durations_ms
+ int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms
]
assert len(aud_paths) > 0
@@ -57,13 +64,17 @@ def __init__(
self.frame_shift = frame_shift
self.s2s_collater = Seq2SeqCollater(
- 0, 1, pad_index=self.tgt_dict.pad(),
- eos_index=self.tgt_dict.eos(), move_eos_to_beginning=True
+ 0,
+ 1,
+ pad_index=self.tgt_dict.pad(),
+ eos_index=self.tgt_dict.eos(),
+ move_eos_to_beginning=True,
)
def __getitem__(self, index):
import torchaudio
import torchaudio.compliance.kaldi as kaldi
+
tgt_item = self.tgt[index] if self.tgt is not None else None
path = self.aud_paths[index]
@@ -74,7 +85,7 @@ def __getitem__(self, index):
sound,
num_mel_bins=self.num_mel_bins,
frame_length=self.frame_length,
- frame_shift=self.frame_shift
+ frame_shift=self.frame_shift,
)
output_cmvn = data_utils.apply_mv_norm(output)
diff --git a/examples/speech_recognition/data/collaters.py b/examples/speech_recognition/data/collaters.py
index 14740d48b7..6acfec876b 100644
--- a/examples/speech_recognition/data/collaters.py
+++ b/examples/speech_recognition/data/collaters.py
@@ -12,18 +12,18 @@
from __future__ import absolute_import, division, print_function, unicode_literals
-import numpy as np
+import numpy as np
import torch
from fairseq.data import data_utils as fairseq_data_utils
class Seq2SeqCollater(object):
"""
- Implements collate function mainly for seq2seq tasks
- This expects each sample to contain feature (src_tokens) and
- targets.
- This collator is also used for aligned training task.
+ Implements collate function mainly for seq2seq tasks
+ This expects each sample to contain feature (src_tokens) and
+ targets.
+ This collator is also used for aligned training task.
"""
def __init__(
diff --git a/examples/speech_recognition/datasets/asr_prep_json.py b/examples/speech_recognition/datasets/asr_prep_json.py
index 2bab825b89..b8db8ff166 100644
--- a/examples/speech_recognition/datasets/asr_prep_json.py
+++ b/examples/speech_recognition/datasets/asr_prep_json.py
@@ -6,52 +6,74 @@
from __future__ import absolute_import, division, print_function, unicode_literals
-from collections import namedtuple
-import concurrent.futures
-from itertools import chain
import argparse
-import os
+import concurrent.futures
import json
-import sentencepiece as spm
import multiprocessing
+import os
+from collections import namedtuple
+from itertools import chain
+import sentencepiece as spm
from fairseq.data import Dictionary
+
MILLISECONDS_TO_SECONDS = 0.001
def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
import torchaudio
+
input = {}
output = {}
si, ei = torchaudio.info(aud_path)
- input["length_ms"] = int(si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS)
+ input["length_ms"] = int(
+ si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS
+ )
input["path"] = aud_path
token = " ".join(sp.EncodeAsPieces(lable))
ids = tgt_dict.encode_line(token, append_eos=False)
output["text"] = lable
output["token"] = token
- output["tokenid"] = ', '.join(map(str, [t.tolist() for t in ids]))
+ output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids]))
return {utt_id: {"input": input, "output": output}}
def main():
parser = argparse.ArgumentParser()
- parser.add_argument("--audio-dirs", nargs="+", default=['-'], required=True,
- help="input directories with audio files")
- parser.add_argument("--labels", required=True,
- help="aggregated input labels with format per line",
- type=argparse.FileType('r', encoding='UTF-8'))
- parser.add_argument("--spm-model", required=True,
- help="sentencepiece model to use for encoding",
- type=argparse.FileType('r', encoding='UTF-8'))
- parser.add_argument("--dictionary", required=True,
- help="file to load fairseq dictionary from",
- type=argparse.FileType('r', encoding='UTF-8'))
+ parser.add_argument(
+ "--audio-dirs",
+ nargs="+",
+ default=["-"],
+ required=True,
+ help="input directories with audio files",
+ )
+ parser.add_argument(
+ "--labels",
+ required=True,
+ help="aggregated input labels with format per line",
+ type=argparse.FileType("r", encoding="UTF-8"),
+ )
+ parser.add_argument(
+ "--spm-model",
+ required=True,
+ help="sentencepiece model to use for encoding",
+ type=argparse.FileType("r", encoding="UTF-8"),
+ )
+ parser.add_argument(
+ "--dictionary",
+ required=True,
+ help="file to load fairseq dictionary from",
+ type=argparse.FileType("r", encoding="UTF-8"),
+ )
parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
- parser.add_argument("--output", required=True, type=argparse.FileType('w'),
- help="path to save json output")
+ parser.add_argument(
+ "--output",
+ required=True,
+ type=argparse.FileType("w"),
+ help="path to save json output",
+ )
args = parser.parse_args()
sp = spm.SentencePieceProcessor()
@@ -64,15 +86,17 @@ def main():
(utt_id, label) = line.split(" ", 1)
labels[utt_id] = label
if len(labels) == 0:
- raise Exception('No labels found in ', args.labels_path)
+ raise Exception("No labels found in ", args.labels_path)
- Sample = namedtuple('Sample', 'aud_path utt_id')
+ Sample = namedtuple("Sample", "aud_path utt_id")
samples = []
- for path, _, files in chain.from_iterable(os.walk(path) for path in args.audio_dirs):
+ for path, _, files in chain.from_iterable(
+ os.walk(path) for path in args.audio_dirs
+ ):
for f in files:
if f.endswith(args.audio_format):
if len(os.path.splitext(f)) != 2:
- raise Exception('Expect file name. Got: ', f)
+ raise Exception("Expect file name. Got: ", f)
utt_id = os.path.splitext(f)[0]
if utt_id not in labels:
continue
@@ -81,12 +105,17 @@ def main():
utts = {}
num_cpu = multiprocessing.cpu_count()
with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
- future_to_sample = {executor.submit(process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict): s for s in samples}
+ future_to_sample = {
+ executor.submit(
+ process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
+ ): s
+ for s in samples
+ }
for future in concurrent.futures.as_completed(future_to_sample):
try:
data = future.result()
except Exception as exc:
- print('generated an exception: ', exc)
+ print("generated an exception: ", exc)
else:
utts.update(data)
json.dump({"utts": utts}, args.output, indent=4)
diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py
index fe5f607d1a..a197ab5a63 100644
--- a/examples/speech_recognition/infer.py
+++ b/examples/speech_recognition/infer.py
@@ -8,17 +8,17 @@
Run inference for pre-processed data with a trained model.
"""
-import editdistance
import logging
import math
import os
import sys
+import editdistance
import numpy as np
import torch
-from fairseq import checkpoint_utils, options, progress_bar, utils, tasks
-from fairseq.logging.meters import StopwatchMeter, TimeMeter
+from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
+from fairseq.logging.meters import StopwatchMeter, TimeMeter
logging.basicConfig()
@@ -52,10 +52,12 @@ def add_asr_eval_argument(parser):
"--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
)
parser.add_argument(
- "--w2l-decoder", choices=["viterbi", "kenlm", "fairseqlm"], help="use a w2l decoder"
+ "--w2l-decoder",
+ choices=["viterbi", "kenlm", "fairseqlm"],
+ help="use a w2l decoder",
)
parser.add_argument("--lexicon", help="lexicon for w2l decoder")
- parser.add_argument("--unit-lm", action='store_true', help="if using a unit lm")
+ parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
parser.add_argument("--beam-threshold", type=float, default=25.0)
parser.add_argument("--beam-size-token", type=float, default=100)
@@ -87,10 +89,10 @@ def check_args(args):
# assert args.path is not None, "--path required for generation!"
# assert args.results_path is not None, "--results_path required for generation!"
assert (
- not args.sampling or args.nbest == args.beam
+ not args.sampling or args.nbest == args.beam
), "--sampling requires --nbest to be equal to --beam"
assert (
- args.replace_unk is None or args.raw_text
+ args.replace_unk is None or args.raw_text
), "--replace-unk requires a raw text dataset (--raw-text)"
@@ -110,7 +112,7 @@ def get_dataset_itr(args, task, models):
def process_predictions(
- args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
+ args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
):
for hypo in hypos[: min(len(hypos), args.nbest)]:
hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
@@ -122,16 +124,25 @@ def process_predictions(
if res_files is not None:
print(
- "{} ({}-{})".format(hyp_pieces, speaker, id), file=res_files["hypo.units"]
+ "{} ({}-{})".format(hyp_pieces, speaker, id),
+ file=res_files["hypo.units"],
+ )
+ print(
+ "{} ({}-{})".format(hyp_words, speaker, id),
+ file=res_files["hypo.words"],
)
- print("{} ({}-{})".format(hyp_words, speaker, id), file=res_files["hypo.words"])
tgt_pieces = tgt_dict.string(target_tokens)
tgt_words = post_process(tgt_pieces, args.remove_bpe)
if res_files is not None:
- print("{} ({}-{})".format(tgt_pieces, speaker, id), file=res_files["ref.units"])
- print("{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"])
+ print(
+ "{} ({}-{})".format(tgt_pieces, speaker, id),
+ file=res_files["ref.units"],
+ )
+ print(
+ "{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]
+ )
# only score top hypothesis
if not args.quiet:
logger.debug("HYPO:" + hyp_words)
@@ -146,7 +157,7 @@ def process_predictions(
def prepare_result_files(args):
def get_res_file(file_prefix):
if args.num_shards > 1:
- file_prefix = f'{args.shard_id}_{file_prefix}'
+ file_prefix = f"{args.shard_id}_{file_prefix}"
path = os.path.join(
args.results_path,
"{}-{}-{}.txt".format(
@@ -166,15 +177,17 @@ def get_res_file(file_prefix):
}
-def load_models_and_criterions(filenames, data_path, arg_overrides=None, task=None, model_state=None):
+def load_models_and_criterions(
+ filenames, data_path, arg_overrides=None, task=None, model_state=None
+):
models = []
criterions = []
if arg_overrides is None:
arg_overrides = {}
- arg_overrides['wer_args'] = None
- arg_overrides['data'] = data_path
+ arg_overrides["wer_args"] = None
+ arg_overrides["data"] = data_path
if filenames is None:
assert model_state is not None
@@ -205,8 +218,7 @@ def load_models_and_criterions(filenames, data_path, arg_overrides=None, task=No
def optimize_models(args, use_cuda, models):
- """Optimize ensemble for generation
- """
+ """Optimize ensemble for generation"""
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
@@ -229,7 +241,7 @@ def generate(self, models, sample, **unused):
emissions = np.stack(self.emissions[ids])
except:
print([x.shape for x in self.emissions[ids]])
- raise Exception('invalid sizes')
+ raise Exception("invalid sizes")
emissions = torch.from_numpy(emissions)
return self.decoder.decode(emissions)
@@ -300,7 +312,9 @@ def build_generator(args):
return W2lFairseqLMDecoder(args, task.target_dictionary)
else:
- print('only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment')
+ print(
+ "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
+ )
# please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
generator = build_generator(args)
@@ -361,7 +375,11 @@ def build_generator(args):
encoder_out = models[0](**sample["net_input"])
feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
for i, id in enumerate(sample["id"]):
- padding = encoder_out["encoder_padding_mask"][i].cpu().numpy() if encoder_out["encoder_padding_mask"] is not None else None
+ padding = (
+ encoder_out["encoder_padding_mask"][i].cpu().numpy()
+ if encoder_out["encoder_padding_mask"] is not None
+ else None
+ )
features[id.item()] = (feat[i], padding)
continue
hypos = task.inference_step(generator, models, sample, prefix_tokens)
@@ -372,20 +390,31 @@ def build_generator(args):
speaker = None
# id = task.dataset(args.gen_subset).ids[int(sample_id)]
id = sample_id
- toks = sample["target"][i, :] if 'target_label' not in sample else sample["target_label"][i, :]
- target_tokens = (
- utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
+ toks = (
+ sample["target"][i, :]
+ if "target_label" not in sample
+ else sample["target_label"][i, :]
)
+ target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
# Process top predictions
errs, length = process_predictions(
- args, hypos[i], None, tgt_dict, target_tokens, res_files, speaker, id
+ args,
+ hypos[i],
+ None,
+ tgt_dict,
+ target_tokens,
+ res_files,
+ speaker,
+ id,
)
errs_t += errs
lengths_t += length
wps_meter.update(num_generated_tokens)
t.log({"wps": round(wps_meter.avg)})
- num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
+ num_sentences += (
+ sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
+ )
wer = None
if args.dump_emissions:
@@ -413,7 +442,7 @@ def build_generator(args):
gen_timer.sum,
num_sentences / gen_timer.sum,
1.0 / gen_timer.avg,
- )
+ )
)
logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
return task, wer
@@ -424,6 +453,7 @@ def make_parser():
parser = add_asr_eval_argument(parser)
return parser
+
def cli_main():
parser = make_parser()
args = options.parse_args_and_arch(parser)
diff --git a/examples/speech_recognition/models/__init__.py b/examples/speech_recognition/models/__init__.py
index 66ad2b0a1f..0ad9663f11 100644
--- a/examples/speech_recognition/models/__init__.py
+++ b/examples/speech_recognition/models/__init__.py
@@ -1,7 +1,8 @@
import importlib
import os
+
for file in os.listdir(os.path.dirname(__file__)):
- if file.endswith('.py') and not file.startswith('_'):
- model_name = file[:file.find('.py')]
- importlib.import_module('examples.speech_recognition.models.' + model_name)
+ if file.endswith(".py") and not file.startswith("_"):
+ model_name = file[: file.find(".py")]
+ importlib.import_module("examples.speech_recognition.models." + model_name)
diff --git a/examples/speech_recognition/models/vggtransformer.py b/examples/speech_recognition/models/vggtransformer.py
index e9a45ac73e..97974360a4 100644
--- a/examples/speech_recognition/models/vggtransformer.py
+++ b/examples/speech_recognition/models/vggtransformer.py
@@ -9,18 +9,22 @@
import torch
import torch.nn as nn
+from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
+ FairseqEncoderDecoderModel,
FairseqEncoderModel,
FairseqIncrementalDecoder,
- FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
-from fairseq.modules import LinearizedConvolution
-from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
-from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer, VGGBlock
+from fairseq.modules import (
+ LinearizedConvolution,
+ TransformerDecoderLayer,
+ TransformerEncoderLayer,
+ VGGBlock,
+)
@register_model("asr_vggtransformer")
@@ -29,6 +33,7 @@ class VGGTransformerModel(FairseqEncoderDecoderModel):
Transformers with convolutional context for ASR
https://arxiv.org/abs/1904.11660
"""
+
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@@ -602,18 +607,22 @@ def __init__(
self.layers = nn.ModuleList()
if conv_config[-1][0] != transformer_config[0][0]:
self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0]))
- self.layers.append(TransformerDecoderLayer(
- prepare_transformer_decoder_params(*transformer_config[0])
- ))
+ self.layers.append(
+ TransformerDecoderLayer(
+ prepare_transformer_decoder_params(*transformer_config[0])
+ )
+ )
for i in range(1, len(transformer_config)):
if transformer_config[i - 1][0] != transformer_config[i][0]:
self.layers.append(
Linear(transformer_config[i - 1][0], transformer_config[i][0])
)
- self.layers.append(TransformerDecoderLayer(
- prepare_transformer_decoder_params(*transformer_config[i])
- ))
+ self.layers.append(
+ TransformerDecoderLayer(
+ prepare_transformer_decoder_params(*transformer_config[i])
+ )
+ )
self.fc_out = Linear(transformer_config[-1][0], vocab_size)
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
@@ -713,6 +722,7 @@ def _transpose_if_inference(self, x, incremental_state):
x = x.transpose(0, 1)
return x
+
@register_model("asr_vggtransformer_encoder")
class VGGTransformerEncoderModel(FairseqEncoderModel):
def __init__(self, encoder):
diff --git a/examples/speech_recognition/models/w2l_conv_glu_enc.py b/examples/speech_recognition/models/w2l_conv_glu_enc.py
index 26f27553d4..655a9b0d19 100644
--- a/examples/speech_recognition/models/w2l_conv_glu_enc.py
+++ b/examples/speech_recognition/models/w2l_conv_glu_enc.py
@@ -10,7 +10,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
diff --git a/examples/speech_recognition/tasks/__init__.py b/examples/speech_recognition/tasks/__init__.py
index fb9e98372d..ffa5f3bd8c 100644
--- a/examples/speech_recognition/tasks/__init__.py
+++ b/examples/speech_recognition/tasks/__init__.py
@@ -1,7 +1,8 @@
import importlib
import os
+
for file in os.listdir(os.path.dirname(__file__)):
- if file.endswith('.py') and not file.startswith('_'):
- task_name = file[:file.find('.py')]
- importlib.import_module('examples.speech_recognition.tasks.' + task_name)
+ if file.endswith(".py") and not file.startswith("_"):
+ task_name = file[: file.find(".py")]
+ importlib.import_module("examples.speech_recognition.tasks." + task_name)
diff --git a/examples/speech_recognition/tasks/speech_recognition.py b/examples/speech_recognition/tasks/speech_recognition.py
index 769ce4ff54..d9f011d55f 100644
--- a/examples/speech_recognition/tasks/speech_recognition.py
+++ b/examples/speech_recognition/tasks/speech_recognition.py
@@ -9,10 +9,10 @@
import sys
import torch
-from fairseq.data import Dictionary
-from fairseq.tasks import register_task, LegacyFairseqTask
from examples.speech_recognition.data import AsrDataset
from examples.speech_recognition.data.replabels import replabel_symbol
+from fairseq.data import Dictionary
+from fairseq.tasks import LegacyFairseqTask, register_task
def get_asr_dataset_from_json(data_json_path, tgt_dict):
@@ -78,10 +78,20 @@ def add_args(parser):
parser.add_argument(
"--silence-token", default="\u2581", help="token for silence (used by w2l)"
)
- parser.add_argument('--max-source-positions', default=sys.maxsize, type=int, metavar='N',
- help='max number of frames in the source sequence')
- parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
- help='max number of tokens in the target sequence')
+ parser.add_argument(
+ "--max-source-positions",
+ default=sys.maxsize,
+ type=int,
+ metavar="N",
+ help="max number of frames in the source sequence",
+ )
+ parser.add_argument(
+ "--max-target-positions",
+ default=1024,
+ type=int,
+ metavar="N",
+ help="max number of tokens in the target sequence",
+ )
def __init__(self, args, tgt_dict):
super().__init__(args)
diff --git a/examples/speech_recognition/w2l_decoder.py b/examples/speech_recognition/w2l_decoder.py
index 020aac5593..2a1d8a779d 100644
--- a/examples/speech_recognition/w2l_decoder.py
+++ b/examples/speech_recognition/w2l_decoder.py
@@ -9,16 +9,18 @@
Wav2letter decoders.
"""
-from collections import namedtuple, deque
import gc
import itertools as it
-import numpy as np
-import torch
import os.path as osp
import warnings
+from collections import deque, namedtuple
+
+import numpy as np
+import torch
+from examples.speech_recognition.data.replabels import unpack_replabels
from fairseq import tasks
from fairseq.utils import apply_to_sample
-from examples.speech_recognition.data.replabels import unpack_replabels
+
try:
from wav2letter.common import create_word_dict, load_words
diff --git a/examples/speech_to_text/data_utils.py b/examples/speech_to_text/data_utils.py
index 1983f70c10..1efeff4df1 100644
--- a/examples/speech_to_text/data_utils.py
+++ b/examples/speech_to_text/data_utils.py
@@ -4,66 +4,76 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from multiprocessing import cpu_count
+import csv
import os
import os.path as op
-from glob import glob
import zipfile
-import csv
from functools import reduce
-from typing import Dict, Any, List
-from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
+from glob import glob
+from multiprocessing import cpu_count
+from typing import Any, Dict, List
+import numpy as np
import sentencepiece as sp
+from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
+from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN
from tqdm import tqdm
-import numpy as np
-from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN
-UNK_TOKEN, UNK_TOKEN_ID = '', 3
-BOS_TOKEN, BOS_TOKEN_ID = '', 0
-EOS_TOKEN, EOS_TOKEN_ID = '', 2
-PAD_TOKEN, PAD_TOKEN_ID = '', 1
+UNK_TOKEN, UNK_TOKEN_ID = "", 3
+BOS_TOKEN, BOS_TOKEN_ID = "", 0
+EOS_TOKEN, EOS_TOKEN_ID = "", 2
+PAD_TOKEN, PAD_TOKEN_ID = "", 1
def gen_vocab(
- input_path: str, output_path_prefix: str, model_type='bpe',
- vocab_size=1000,
+ input_path: str,
+ output_path_prefix: str,
+ model_type="bpe",
+ vocab_size=1000,
):
# Train SentencePiece Model
arguments = [
- f'--input={input_path}',
- f'--model_prefix={output_path_prefix}',
- f'--model_type={model_type}',
- f'--vocab_size={vocab_size}',
- '--character_coverage=1.0',
- f'--num_threads={cpu_count()}',
- f'--unk_id={UNK_TOKEN_ID}',
- f'--bos_id={BOS_TOKEN_ID}',
- f'--eos_id={EOS_TOKEN_ID}',
- f'--pad_id={PAD_TOKEN_ID}'
+ f"--input={input_path}",
+ f"--model_prefix={output_path_prefix}",
+ f"--model_type={model_type}",
+ f"--vocab_size={vocab_size}",
+ "--character_coverage=1.0",
+ f"--num_threads={cpu_count()}",
+ f"--unk_id={UNK_TOKEN_ID}",
+ f"--bos_id={BOS_TOKEN_ID}",
+ f"--eos_id={EOS_TOKEN_ID}",
+ f"--pad_id={PAD_TOKEN_ID}",
]
- sp.SentencePieceTrainer.Train(' '.join(arguments))
+ sp.SentencePieceTrainer.Train(" ".join(arguments))
# Export fairseq dictionary
spm = sp.SentencePieceProcessor()
- spm.Load(output_path_prefix + '.model')
+ spm.Load(output_path_prefix + ".model")
vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
- assert vocab.get(UNK_TOKEN_ID) == UNK_TOKEN and \
- vocab.get(PAD_TOKEN_ID) == PAD_TOKEN and \
- vocab.get(BOS_TOKEN_ID) == BOS_TOKEN and \
- vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
+ assert (
+ vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
+ and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
+ and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
+ and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
+ )
vocab = {
- i: s for i, s in vocab.items()
+ i: s
+ for i, s in vocab.items()
if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
}
- with open(output_path_prefix + '.txt', 'w') as f_out:
+ with open(output_path_prefix + ".txt", "w") as f_out:
for _, s in sorted(vocab.items(), key=lambda x: x[0]):
- f_out.write(f'{s} 1\n')
+ f_out.write(f"{s} 1\n")
-def extract_fbank_features(waveform, sample_rate, output_path=None,
- n_mel_bins=80, apply_utterance_cmvn=True,
- overwrite=False):
+def extract_fbank_features(
+ waveform,
+ sample_rate,
+ output_path=None,
+ n_mel_bins=80,
+ apply_utterance_cmvn=True,
+ overwrite=False,
+):
if output_path is not None and op.exists(output_path) and not overwrite:
return
@@ -74,8 +84,10 @@ def extract_fbank_features(waveform, sample_rate, output_path=None,
if features is None:
features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
if features is None:
- raise ImportError('Please install pyKaldi or torchaudio to enable '
- 'online filterbank feature extraction')
+ raise ImportError(
+ "Please install pyKaldi or torchaudio to enable "
+ "online filterbank feature extraction"
+ )
if apply_utterance_cmvn:
cmvn = UtteranceCMVN(norm_means=True, norm_vars=True)
@@ -89,8 +101,8 @@ def extract_fbank_features(waveform, sample_rate, output_path=None,
def create_zip(data_root, zip_path):
cwd = os.path.abspath(os.curdir)
os.chdir(data_root)
- with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_STORED) as f:
- for filename in tqdm(glob('*.npy')):
+ with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f:
+ for filename in tqdm(glob("*.npy")):
f.write(filename)
os.chdir(cwd)
@@ -101,69 +113,80 @@ def is_npy_data(data: bytes) -> bool:
def get_zip_manifest(zip_root, zip_filename):
zip_path = op.join(zip_root, zip_filename)
- with zipfile.ZipFile(zip_path, mode='r') as f:
+ with zipfile.ZipFile(zip_path, mode="r") as f:
info = f.infolist()
manifest = {}
for i in tqdm(info):
utt_id = op.splitext(i.filename)[0]
offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
- manifest[utt_id] = f'{zip_filename}:{offset}:{file_size}'
- with open(zip_path, 'rb') as f:
+ manifest[utt_id] = f"{zip_filename}:{offset}:{file_size}"
+ with open(zip_path, "rb") as f:
f.seek(offset)
data = f.read(file_size)
assert len(data) > 1 and is_npy_data(data)
return manifest
-def gen_config_yaml(data_root, spm_filename, yaml_filename='config.yaml',
- specaugment_policy='lb'):
- assert specaugment_policy in {'lb', 'ld'}
+def gen_config_yaml(
+ data_root, spm_filename, yaml_filename="config.yaml", specaugment_policy="lb"
+):
+ assert specaugment_policy in {"lb", "ld"}
data_root = op.abspath(data_root)
writer = S2TDataConfigWriter(op.join(data_root, yaml_filename))
writer.set_audio_root(op.abspath(data_root))
writer.set_vocab_filename(spm_filename.replace(".model", ".txt"))
writer.set_input_channels(1)
writer.set_input_feat_per_channel(80)
- if specaugment_policy == 'lb':
+ if specaugment_policy == "lb":
writer.set_specaugment_lb_policy()
else:
writer.set_specaugment_ld_policy()
writer.set_bpe_tokenizer(
- {'bpe': 'sentencepiece',
- 'sentencepiece_model': op.join(data_root, spm_filename)}
+ {
+ "bpe": "sentencepiece",
+ "sentencepiece_model": op.join(data_root, spm_filename),
+ }
)
- writer.set_feature_transforms('_train', ['specaugment'])
+ writer.set_feature_transforms("_train", ["specaugment"])
writer.flush()
def save_df_to_tsv(dataframe, path):
- dataframe.to_csv(path, sep="\t", header=True, index=False, encoding="utf-8",
- escapechar='\\', quoting=csv.QUOTE_NONE)
+ dataframe.to_csv(
+ path,
+ sep="\t",
+ header=True,
+ index=False,
+ encoding="utf-8",
+ escapechar="\\",
+ quoting=csv.QUOTE_NONE,
+ )
-def filter_manifest_df(df, is_train_split=False, extra_filters=None,
- min_n_frames=5, max_n_frames=3000):
+def filter_manifest_df(
+ df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
+):
filters = {
- 'no speech': df['audio'] == '',
- f'short speech (<{min_n_frames} frames)': df['n_frames'] < min_n_frames,
- 'empty sentence': df['tgt_text'] == '',
+ "no speech": df["audio"] == "",
+ f"short speech (<{min_n_frames} frames)": df["n_frames"] < min_n_frames,
+ "empty sentence": df["tgt_text"] == "",
}
if is_train_split:
- filters[f'long speech (>{max_n_frames} frames)'] = \
- df['n_frames'] > max_n_frames
+ filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
if extra_filters is not None:
filters.update(extra_filters)
invalid = reduce(lambda x, y: x | y, filters.values())
valid = ~invalid
print(
- '| ' + ', '.join(f'{n}: {f.sum()}' for n, f in filters.items()) +
- f', total {invalid.sum()} filtered, {valid.sum()} remained.'
+ "| "
+ + ", ".join(f"{n}: {f.sum()}" for n, f in filters.items())
+ + f", total {invalid.sum()} filtered, {valid.sum()} remained."
)
return df[valid]
class S2TDataConfigWriter(object):
- DEFAULT_VOCAB_FILENAME = 'dict.txt'
+ DEFAULT_VOCAB_FILENAME = "dict.txt"
DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
DEFAULT_INPUT_CHANNELS = 1
@@ -171,48 +194,69 @@ def __init__(self, yaml_path):
try:
import yaml
except ImportError:
- print('Please install PyYAML to load YAML files for S2T data config')
+ print("Please install PyYAML to load YAML files for S2T data config")
self.yaml = yaml
self.yaml_path = yaml_path
self.config = {}
def flush(self):
- with open(self.yaml_path, 'w') as f:
+ with open(self.yaml_path, "w") as f:
self.yaml.dump(self.config, f)
- def set_audio_root(self, audio_root=''):
- self.config['audio_root'] = audio_root
-
- def set_vocab_filename(self, vocab_filename='dict.txt'):
- self.config['vocab_filename'] = vocab_filename
-
- def set_specaugment(self, time_wrap_w: int, freq_mask_n: int,
- freq_mask_f: int, time_mask_n: int, time_mask_t: int,
- time_mask_p: float):
- self.config['specaugment'] = {
- 'time_wrap_W': time_wrap_w, 'freq_mask_N': freq_mask_n,
- 'freq_mask_F': freq_mask_f, 'time_mask_N': time_mask_n,
- 'time_mask_T': time_mask_t, 'time_mask_p': time_mask_p,
+ def set_audio_root(self, audio_root=""):
+ self.config["audio_root"] = audio_root
+
+ def set_vocab_filename(self, vocab_filename="dict.txt"):
+ self.config["vocab_filename"] = vocab_filename
+
+ def set_specaugment(
+ self,
+ time_wrap_w: int,
+ freq_mask_n: int,
+ freq_mask_f: int,
+ time_mask_n: int,
+ time_mask_t: int,
+ time_mask_p: float,
+ ):
+ self.config["specaugment"] = {
+ "time_wrap_W": time_wrap_w,
+ "freq_mask_N": freq_mask_n,
+ "freq_mask_F": freq_mask_f,
+ "time_mask_N": time_mask_n,
+ "time_mask_T": time_mask_t,
+ "time_mask_p": time_mask_p,
}
def set_specaugment_lb_policy(self):
- self.set_specaugment(time_wrap_w=0, freq_mask_n=1, freq_mask_f=27,
- time_mask_n=1, time_mask_t=100, time_mask_p=1.0)
+ self.set_specaugment(
+ time_wrap_w=0,
+ freq_mask_n=1,
+ freq_mask_f=27,
+ time_mask_n=1,
+ time_mask_t=100,
+ time_mask_p=1.0,
+ )
def set_specaugment_ld_policy(self):
- self.set_specaugment(time_wrap_w=0, freq_mask_n=2, freq_mask_f=27,
- time_mask_n=2, time_mask_t=100, time_mask_p=1.0)
+ self.set_specaugment(
+ time_wrap_w=0,
+ freq_mask_n=2,
+ freq_mask_f=27,
+ time_mask_n=2,
+ time_mask_t=100,
+ time_mask_p=1.0,
+ )
def set_input_channels(self, input_channels=1):
- self.config['input_channels'] = input_channels
+ self.config["input_channels"] = input_channels
def set_input_feat_per_channel(self, input_feat_per_channel=80):
- self.config['input_feat_per_channel'] = input_feat_per_channel
+ self.config["input_feat_per_channel"] = input_feat_per_channel
def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
- self.config['bpe_tokenizer'] = bpe_tokenizer
+ self.config["bpe_tokenizer"] = bpe_tokenizer
def set_feature_transforms(self, split, transforms: List[str]):
- if 'transforms' not in self.config:
- self.config['transforms'] = {}
- self.config['transforms'][split] = transforms
+ if "transforms" not in self.config:
+ self.config["transforms"] = {}
+ self.config["transforms"][split] = transforms
diff --git a/examples/speech_to_text/prep_covost_data.py b/examples/speech_to_text/prep_covost_data.py
index a70e24e04d..e8a028b446 100644
--- a/examples/speech_to_text/prep_covost_data.py
+++ b/examples/speech_to_text/prep_covost_data.py
@@ -5,30 +5,35 @@
# LICENSE file in the root directory of this source tree.
import argparse
+import csv
import logging
-from tempfile import NamedTemporaryFile
import os
import os.path as op
import shutil
-from typing import Tuple, Optional
-import csv
+from tempfile import NamedTemporaryFile
+from typing import Optional, Tuple
-from torchaudio.datasets.utils import download_url, extract_archive
-from tqdm import tqdm
import pandas as pd
-from torch.utils.data import Dataset
import torchaudio
-from torch import Tensor
-
from examples.speech_to_text.data_utils import (
- gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
- extract_fbank_features, gen_config_yaml, filter_manifest_df
+ create_zip,
+ extract_fbank_features,
+ filter_manifest_df,
+ gen_config_yaml,
+ gen_vocab,
+ get_zip_manifest,
+ save_df_to_tsv,
)
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchaudio.datasets.utils import download_url, extract_archive
+from tqdm import tqdm
+
log = logging.getLogger(__name__)
-MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
+MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
class CoVoST(Dataset):
@@ -44,40 +49,82 @@ class CoVoST(Dataset):
found at root path. (default: ``False``).
"""
- CV_URL_TEMPLATE = "https://voice-prod-bundler-ee1969a6ce8178826482b88" \
- "e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
- COVOST_URL_TEMPLATE = "https://dl.fbaipublicfiles.com/covost/" \
- "covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
+ CV_URL_TEMPLATE = (
+ "https://voice-prod-bundler-ee1969a6ce8178826482b88"
+ "e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
+ )
+ COVOST_URL_TEMPLATE = (
+ "https://dl.fbaipublicfiles.com/covost/"
+ "covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
+ )
VERSIONS = {2}
- SPLITS = ['train', 'dev', 'test']
+ SPLITS = ["train", "dev", "test"]
CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"}
XX_EN_LANGUAGES = {
- 1: ['fr', 'de', 'nl', 'ru', 'es', 'it', 'tr', 'fa', 'sv-SE', 'mn',
- 'zh-CN'],
- 2: ['fr', 'de', 'es', 'ca', 'it', 'ru', 'zh-CN', 'pt', 'fa', 'et', 'mn',
- 'nl', 'tr', 'ar', 'sv-SE', 'lv', 'sl', 'ta', 'ja', 'id', 'cy']
+ 1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
+ 2: [
+ "fr",
+ "de",
+ "es",
+ "ca",
+ "it",
+ "ru",
+ "zh-CN",
+ "pt",
+ "fa",
+ "et",
+ "mn",
+ "nl",
+ "tr",
+ "ar",
+ "sv-SE",
+ "lv",
+ "sl",
+ "ta",
+ "ja",
+ "id",
+ "cy",
+ ],
}
EN_XX_LANGUAGES = {
1: [],
- 2: ['de', 'tr', 'fa', 'sv-SE', 'mn', 'zh-CN', 'cy', 'ca', 'sl', 'et',
- 'id',
- 'ar', 'ta', 'lv', 'ja']
+ 2: [
+ "de",
+ "tr",
+ "fa",
+ "sv-SE",
+ "mn",
+ "zh-CN",
+ "cy",
+ "ca",
+ "sl",
+ "et",
+ "id",
+ "ar",
+ "ta",
+ "lv",
+ "ja",
+ ],
}
def __init__(
- self, root: str, split: str, source_language: str,
- target_language: Optional[str] = None, version: int = 2,
- download: bool = False
+ self,
+ root: str,
+ split: str,
+ source_language: str,
+ target_language: Optional[str] = None,
+ version: int = 2,
+ download: bool = False,
) -> None:
assert version in self.VERSIONS and split in self.SPLITS
assert source_language is not None
- self.no_translation = (target_language is None)
+ self.no_translation = target_language is None
if not self.no_translation:
- assert 'en' in {source_language, target_language}
- if source_language == 'en':
+ assert "en" in {source_language, target_language}
+ if source_language == "en":
assert target_language in self.EN_XX_LANGUAGES[version]
else:
assert source_language in self.XX_EN_LANGUAGES[version]
@@ -85,51 +132,60 @@ def __init__(
# Hack here so that we can get "split" column from CoVoST TSV.
# Note that we use CoVoST train split for ASR which is an extension
# to Common Voice train split.
- target_language = 'de' if source_language == 'en' else 'en'
+ target_language = "de" if source_language == "en" else "en"
- self.root = os.path.join(root, 'raw')
+ self.root = os.path.join(root, "raw")
os.makedirs(self.root, exist_ok=True)
- cv_url = self.CV_URL_TEMPLATE.format(ver=self.CV_VERSION_ID[version],
- lang=source_language)
+ cv_url = self.CV_URL_TEMPLATE.format(
+ ver=self.CV_VERSION_ID[version], lang=source_language
+ )
cv_archive = os.path.join(self.root, os.path.basename(cv_url))
if download:
if not os.path.isfile(cv_archive):
download_url(cv_url, self.root, hash_value=None)
extract_archive(cv_archive)
- covost_url = self.COVOST_URL_TEMPLATE.format(src_lang=source_language,
- tgt_lang=target_language)
+ covost_url = self.COVOST_URL_TEMPLATE.format(
+ src_lang=source_language, tgt_lang=target_language
+ )
covost_archive = os.path.join(self.root, os.path.basename(covost_url))
if download:
if not os.path.isfile(covost_archive):
download_url(covost_url, self.root, hash_value=None)
extract_archive(covost_archive)
- cv_tsv = self.load_from_tsv(os.path.join(self.root, 'validated.tsv'))
+ cv_tsv = self.load_from_tsv(os.path.join(self.root, "validated.tsv"))
covost_tsv = self.load_from_tsv(
- os.path.join(self.root,
- os.path.basename(covost_url).replace('.tar.gz', ''))
+ os.path.join(self.root, os.path.basename(covost_url).replace(".tar.gz", ""))
+ )
+ df = pd.merge(
+ left=cv_tsv[["path", "sentence", "client_id"]],
+ right=covost_tsv[["path", "translation", "split"]],
+ how="inner",
+ on="path",
)
- df = pd.merge(left=cv_tsv[['path', 'sentence', 'client_id']],
- right=covost_tsv[['path', 'translation', 'split']],
- how='inner', on='path')
- if split == 'train':
- df = df[(df['split'] == split) | (df['split'] == f'{split}_covost')]
+ if split == "train":
+ df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
else:
- df = df[df['split'] == split]
- self.data = df.to_dict(orient='index').items()
+ df = df[df["split"] == split]
+ self.data = df.to_dict(orient="index").items()
self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])]
@classmethod
def load_from_tsv(cls, path: str):
return pd.read_csv(
- path, sep='\t', header=0, encoding='utf-8', escapechar='\\',
- quoting=csv.QUOTE_NONE, na_filter=False
+ path,
+ sep="\t",
+ header=0,
+ encoding="utf-8",
+ escapechar="\\",
+ quoting=csv.QUOTE_NONE,
+ na_filter=False,
)
def __getitem__(
- self, n: int
+ self, n: int
) -> Tuple[Tensor, int, str, str, Optional[str], str, str]:
"""Load the n-th sample from the dataset.
@@ -141,12 +197,12 @@ def __getitem__(
sample_id)``
"""
data = self.data[n]
- path = os.path.join(self.root, 'clips', data['path'])
+ path = os.path.join(self.root, "clips", data["path"])
waveform, sample_rate = torchaudio.load(path)
- sentence = data['sentence']
- translation = None if self.no_translation else data['translation']
- speaker_id = data['client_id']
- _id = data['path'].replace('.mp3', '')
+ sentence = data["sentence"]
+ translation = None if self.no_translation else data["translation"]
+ speaker_id = data["client_id"]
+ _id = data["path"].replace(".mp3", "")
return waveform, sample_rate, sentence, translation, speaker_id, _id
def __len__(self) -> int:
@@ -157,76 +213,82 @@ def process(args):
root = op.join(args.data_root, args.src_lang)
os.makedirs(root, exist_ok=True)
# Extract features
- feature_root = op.join(root, 'fbank80')
+ feature_root = op.join(root, "fbank80")
os.makedirs(feature_root, exist_ok=True)
for split in CoVoST.SPLITS:
- print(f'Fetching split {split}...')
- dataset = CoVoST(root, split, args.src_lang, args.tgt_lang,
- download=True)
- print('Extracting log mel filter bank features...')
+ print(f"Fetching split {split}...")
+ dataset = CoVoST(root, split, args.src_lang, args.tgt_lang, download=True)
+ print("Extracting log mel filter bank features...")
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
- extract_fbank_features(waveform, sample_rate,
- op.join(feature_root, f'{utt_id}.npy'))
+ extract_fbank_features(
+ waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
+ )
# Pack features into ZIP
- zip_filename = 'fbank80.zip'
+ zip_filename = "fbank80.zip"
zip_path = op.join(root, zip_filename)
- print('ZIPing features...')
+ print("ZIPing features...")
create_zip(feature_root, zip_path)
- print('Fetching ZIP manifest...')
- zip_manifest = get_zip_manifest(args.data_root,
- f'{args.src_lang}/{zip_filename}')
+ print("Fetching ZIP manifest...")
+ zip_manifest = get_zip_manifest(args.data_root, f"{args.src_lang}/{zip_filename}")
# Generate TSV manifest
- print('Generating manifest...')
+ print("Generating manifest...")
train_text = []
- task = f'asr_{args.src_lang}'
+ task = f"asr_{args.src_lang}"
if args.tgt_lang is not None:
- task = f'st_{args.src_lang}_{args.tgt_lang}'
+ task = f"st_{args.src_lang}_{args.tgt_lang}"
for split in CoVoST.SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
- manifest['id'].append(utt_id)
- manifest['audio'].append(zip_manifest[utt_id])
+ manifest["id"].append(utt_id)
+ manifest["audio"].append(zip_manifest[utt_id])
duration_ms = int(wav.size(1) / sr * 1000)
- manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
- manifest['tgt_text'].append(
- src_utt if args.tgt_lang is None else tgt_utt
- )
- manifest['speaker'].append(speaker_id)
- is_train_split = split.startswith('train')
+ manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
+ manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
+ manifest["speaker"].append(speaker_id)
+ is_train_split = split.startswith("train")
if is_train_split:
- train_text.extend(manifest['tgt_text'])
+ train_text.extend(manifest["tgt_text"])
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
- save_df_to_tsv(df, op.join(root, f'{split}_{task}.tsv'))
+ save_df_to_tsv(df, op.join(root, f"{split}_{task}.tsv"))
# Generate vocab
- vocab_size_str = '' if args.vocab_type == 'char' else str(args.vocab_size)
- spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size_str}_{task}'
- with NamedTemporaryFile(mode='w') as f:
+ vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+ spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
+ with NamedTemporaryFile(mode="w") as f:
for t in train_text:
- f.write(t + '\n')
- gen_vocab(f.name, op.join(root, spm_filename_prefix),
- args.vocab_type, args.vocab_size)
+ f.write(t + "\n")
+ gen_vocab(
+ f.name, op.join(root, spm_filename_prefix), args.vocab_type, args.vocab_size
+ )
# Generate config YAML
- gen_config_yaml(root, spm_filename_prefix + '.model',
- yaml_filename=f'config_{task}.yaml',
- specaugment_policy='lb')
+ gen_config_yaml(
+ root,
+ spm_filename_prefix + ".model",
+ yaml_filename=f"config_{task}.yaml",
+ specaugment_policy="lb",
+ )
# Clean up
shutil.rmtree(feature_root)
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('--data-root', '-d', required=True, type=str)
- parser.add_argument('--vocab-type', default='unigram', required=True,
- type=str, choices=['bpe', 'unigram', 'char']),
- parser.add_argument('--vocab-size', default=1000, type=int)
- parser.add_argument('--src-lang', '-s', required=True, type=str)
- parser.add_argument('--tgt-lang', '-t', type=str)
+ parser.add_argument("--data-root", "-d", required=True, type=str)
+ parser.add_argument(
+ "--vocab-type",
+ default="unigram",
+ required=True,
+ type=str,
+ choices=["bpe", "unigram", "char"],
+ ),
+ parser.add_argument("--vocab-size", default=1000, type=int)
+ parser.add_argument("--src-lang", "-s", required=True, type=str)
+ parser.add_argument("--tgt-lang", "-t", type=str)
args = parser.parse_args()
process(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/speech_to_text/prep_librispeech_data.py b/examples/speech_to_text/prep_librispeech_data.py
index 4f003ec505..95fcec8fe3 100644
--- a/examples/speech_to_text/prep_librispeech_data.py
+++ b/examples/speech_to_text/prep_librispeech_data.py
@@ -6,91 +6,114 @@
import argparse
import logging
-from tempfile import NamedTemporaryFile
import os
-import shutil
import os.path as op
+import shutil
+from tempfile import NamedTemporaryFile
-from tqdm import tqdm
-from torchaudio.datasets import LIBRISPEECH
import pandas as pd
-
from examples.speech_to_text.data_utils import (
- gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
- extract_fbank_features, gen_config_yaml
+ create_zip,
+ extract_fbank_features,
+ gen_config_yaml,
+ gen_vocab,
+ get_zip_manifest,
+ save_df_to_tsv,
)
+from torchaudio.datasets import LIBRISPEECH
+from tqdm import tqdm
+
log = logging.getLogger(__name__)
-SPLITS = ['train-clean-100', 'train-clean-360', 'train-other-500', 'dev-clean',
- 'dev-other', 'test-clean', 'test-other']
+SPLITS = [
+ "train-clean-100",
+ "train-clean-360",
+ "train-other-500",
+ "dev-clean",
+ "dev-other",
+ "test-clean",
+ "test-other",
+]
-MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
+MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
def process(args):
os.makedirs(args.output_root, exist_ok=True)
# Extract features
- feature_root = op.join(args.output_root, 'fbank80')
+ feature_root = op.join(args.output_root, "fbank80")
os.makedirs(feature_root, exist_ok=True)
for split in SPLITS:
- print(f'Fetching split {split}...')
+ print(f"Fetching split {split}...")
dataset = LIBRISPEECH(args.output_root, url=split, download=True)
- print('Extracting log mel filter bank features...')
+ print("Extracting log mel filter bank features...")
for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset):
- sample_id = f'{spk_id}-{chapter_id}-{utt_id}'
- extract_fbank_features(wav, sample_rate,
- op.join(feature_root, f'{sample_id}.npy'))
+ sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
+ extract_fbank_features(
+ wav, sample_rate, op.join(feature_root, f"{sample_id}.npy")
+ )
# Pack features into ZIP
- zip_filename = 'fbank80.zip'
+ zip_filename = "fbank80.zip"
zip_path = op.join(args.output_root, zip_filename)
- print('ZIPing features...')
+ print("ZIPing features...")
create_zip(feature_root, zip_path)
- print('Fetching ZIP manifest...')
+ print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(args.output_root, zip_filename)
# Generate TSV manifest
- print('Generating manifest...')
+ print("Generating manifest...")
train_text = []
for split in SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = LIBRISPEECH(args.output_root, url=split)
for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset):
- sample_id = f'{spk_id}-{chapter_id}-{utt_id}'
- manifest['id'].append(sample_id)
- manifest['audio'].append(zip_manifest[sample_id])
+ sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
+ manifest["id"].append(sample_id)
+ manifest["audio"].append(zip_manifest[sample_id])
duration_ms = int(wav.size(1) / sample_rate * 1000)
- manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
- manifest['tgt_text'].append(utt)
- manifest['speaker'].append(spk_id)
- save_df_to_tsv(pd.DataFrame.from_dict(manifest),
- op.join(args.output_root, f'{split}.tsv'))
- if split.startswith('train'):
- train_text.extend(manifest['tgt_text'])
+ manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
+ manifest["tgt_text"].append(utt)
+ manifest["speaker"].append(spk_id)
+ save_df_to_tsv(
+ pd.DataFrame.from_dict(manifest), op.join(args.output_root, f"{split}.tsv")
+ )
+ if split.startswith("train"):
+ train_text.extend(manifest["tgt_text"])
# Generate vocab
- vocab_size = '' if args.vocab_type == 'char' else str(args.vocab_size)
- spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size}'
- with NamedTemporaryFile(mode='w') as f:
+ vocab_size = "" if args.vocab_type == "char" else str(args.vocab_size)
+ spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size}"
+ with NamedTemporaryFile(mode="w") as f:
for t in train_text:
- f.write(t + '\n')
- gen_vocab(f.name, op.join(args.output_root, spm_filename_prefix),
- args.vocab_type, args.vocab_size)
+ f.write(t + "\n")
+ gen_vocab(
+ f.name,
+ op.join(args.output_root, spm_filename_prefix),
+ args.vocab_type,
+ args.vocab_size,
+ )
# Generate config YAML
- gen_config_yaml(args.output_root, spm_filename_prefix + '.model',
- specaugment_policy='ld')
+ gen_config_yaml(
+ args.output_root, spm_filename_prefix + ".model", specaugment_policy="ld"
+ )
# Clean up
shutil.rmtree(feature_root)
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('--output-root', '-o', required=True, type=str)
- parser.add_argument('--vocab-type', default='unigram', required=True,
- type=str, choices=['bpe', 'unigram', 'char']),
- parser.add_argument('--vocab-size', default=10000, type=int)
+ parser.add_argument("--output-root", "-o", required=True, type=str)
+ parser.add_argument(
+ "--vocab-type",
+ default="unigram",
+ required=True,
+ type=str,
+ choices=["bpe", "unigram", "char"],
+ ),
+ parser.add_argument("--vocab-size", default=10000, type=int)
args = parser.parse_args()
process(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/speech_to_text/prep_mustc_data.py b/examples/speech_to_text/prep_mustc_data.py
index 6c0a9b7132..5593d2e7e2 100644
--- a/examples/speech_to_text/prep_mustc_data.py
+++ b/examples/speech_to_text/prep_mustc_data.py
@@ -6,29 +6,34 @@
import argparse
import logging
-from tempfile import NamedTemporaryFile
import os
import os.path as op
import shutil
-from typing import Tuple
from itertools import groupby
+from tempfile import NamedTemporaryFile
+from typing import Tuple
-from tqdm import tqdm
import pandas as pd
-from torch.utils.data import Dataset
import torchaudio
-from torch import Tensor
-
from examples.speech_to_text.data_utils import (
- gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
- extract_fbank_features, gen_config_yaml, filter_manifest_df
+ create_zip,
+ extract_fbank_features,
+ filter_manifest_df,
+ gen_config_yaml,
+ gen_vocab,
+ get_zip_manifest,
+ save_df_to_tsv,
)
+from torch import Tensor
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
log = logging.getLogger(__name__)
-MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
-TASKS = ['asr', 'st']
+MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
+TASKS = ["asr", "st"]
class MUSTC(Dataset):
@@ -37,49 +42,55 @@ class MUSTC(Dataset):
waveform, sample_rate, source utterance, target utterance, speaker_id,
utterance_id
"""
- SPLITS = ['train', 'dev', 'tst-COMMON', 'tst-HE']
- LANGUAGES = ['de', 'es', 'fr', 'it', 'nl', 'pt', 'ro', 'ru']
+
+ SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"]
+ LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
def __init__(self, root: str, lang: str, split: str) -> None:
assert split in self.SPLITS and lang in self.LANGUAGES
- _root = op.join(root, f'en-{lang}', 'data', split)
- wav_root, txt_root = op.join(_root, 'wav'), op.join(_root, 'txt')
+ _root = op.join(root, f"en-{lang}", "data", split)
+ wav_root, txt_root = op.join(_root, "wav"), op.join(_root, "txt")
assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root)
# Load audio segments
try:
import yaml
except ImportError:
- print('Please install PyYAML to load YAML files for '
- 'the MuST-C dataset')
- with open(op.join(txt_root, f'{split}.yaml')) as f:
+ print("Please install PyYAML to load YAML files for " "the MuST-C dataset")
+ with open(op.join(txt_root, f"{split}.yaml")) as f:
segments = yaml.load(f, Loader=yaml.BaseLoader)
# Load source and target utterances
- for _lang in ['en', lang]:
- with open(op.join(txt_root, f'{split}.{_lang}')) as f:
+ for _lang in ["en", lang]:
+ with open(op.join(txt_root, f"{split}.{_lang}")) as f:
utterances = [r.strip() for r in f]
assert len(segments) == len(utterances)
for i, u in enumerate(utterances):
segments[i][_lang] = u
# Gather info
self.data = []
- for wav_filename, _seg_group in groupby(segments, lambda x: x['wav']):
+ for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
wav_path = op.join(wav_root, wav_filename)
sample_rate = torchaudio.info(wav_path)[0].rate
- seg_group = sorted(_seg_group, key=lambda x: x['offset'])
+ seg_group = sorted(_seg_group, key=lambda x: x["offset"])
for i, segment in enumerate(seg_group):
- offset = int(float(segment['offset']) * sample_rate)
- n_frames = int(float(segment['duration']) * sample_rate)
- _id = f'{op.splitext(wav_filename)[0]}_{i}'
+ offset = int(float(segment["offset"]) * sample_rate)
+ n_frames = int(float(segment["duration"]) * sample_rate)
+ _id = f"{op.splitext(wav_filename)[0]}_{i}"
self.data.append(
- (wav_path, offset, n_frames, sample_rate, segment['en'],
- segment[lang], segment['speaker_id'], _id)
+ (
+ wav_path,
+ offset,
+ n_frames,
+ sample_rate,
+ segment["en"],
+ segment[lang],
+ segment["speaker_id"],
+ _id,
+ )
)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
- wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = \
- self.data[n]
- waveform, _ = torchaudio.load(wav_path, offset=offset,
- num_frames=n_frames)
+ wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
+ waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames)
return waveform, sr, src_utt, tgt_utt, spk_id, utt_id
def __len__(self) -> int:
@@ -88,85 +99,102 @@ def __len__(self) -> int:
def process(args):
for lang in MUSTC.LANGUAGES:
- cur_root = op.join(args.data_root, f'en-{lang}')
+ cur_root = op.join(args.data_root, f"en-{lang}")
if not op.isdir(cur_root):
- print(f'{cur_root} does not exist. Skipped.')
+ print(f"{cur_root} does not exist. Skipped.")
continue
# Extract features
- feature_root = op.join(cur_root, 'fbank80')
+ feature_root = op.join(cur_root, "fbank80")
os.makedirs(feature_root, exist_ok=True)
for split in MUSTC.SPLITS:
- print(f'Fetching split {split}...')
+ print(f"Fetching split {split}...")
dataset = MUSTC(args.data_root, lang, split)
- print('Extracting log mel filter bank features...')
+ print("Extracting log mel filter bank features...")
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
- extract_fbank_features(waveform, sample_rate,
- op.join(feature_root, f'{utt_id}.npy'))
+ extract_fbank_features(
+ waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
+ )
# Pack features into ZIP
- zip_filename = 'fbank80.zip'
+ zip_filename = "fbank80.zip"
zip_path = op.join(cur_root, zip_filename)
- print('ZIPing features...')
+ print("ZIPing features...")
create_zip(feature_root, zip_path)
- print('Fetching ZIP manifest...')
- zip_manifest = get_zip_manifest(args.data_root,
- f'en-{lang}/{zip_filename}')
+ print("Fetching ZIP manifest...")
+ zip_manifest = get_zip_manifest(args.data_root, f"en-{lang}/{zip_filename}")
# Generate TSV manifest
- print('Generating manifest...')
+ print("Generating manifest...")
train_text = {task: [] for task in TASKS}
for split in MUSTC.SPLITS:
- is_train_split = split.startswith('train')
+ is_train_split = split.startswith("train")
manifest = {c: [] for c in MANIFEST_COLUMNS}
text = {task: [] for task in TASKS}
dataset = MUSTC(args.data_root, lang, split)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
- manifest['id'].append(utt_id)
- manifest['audio'].append(zip_manifest[utt_id])
+ manifest["id"].append(utt_id)
+ manifest["audio"].append(zip_manifest[utt_id])
duration_ms = int(wav.size(1) / sr * 1000)
- manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
- text['asr'].append(src_utt)
- text['st'].append(tgt_utt)
- manifest['speaker'].append(speaker_id)
+ manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
+ text["asr"].append(src_utt)
+ text["st"].append(tgt_utt)
+ manifest["speaker"].append(speaker_id)
if is_train_split:
for task in TASKS:
train_text[task].extend(text[task])
for task in TASKS:
- manifest['tgt_text'] = text[task]
+ manifest["tgt_text"] = text[task]
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
- save_df_to_tsv(df, op.join(cur_root, f'{split}_{task}.tsv'))
+ save_df_to_tsv(df, op.join(cur_root, f"{split}_{task}.tsv"))
# Generate vocab
for task in TASKS:
vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size
- if task == 'st':
+ if task == "st":
vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size
- vocab_size_str = '' if vocab_type == 'char' else str(vocab_size)
- spm_filename_prefix = f'spm_{vocab_type}{vocab_size_str}_{task}'
- with NamedTemporaryFile(mode='w') as f:
+ vocab_size_str = "" if vocab_type == "char" else str(vocab_size)
+ spm_filename_prefix = f"spm_{vocab_type}{vocab_size_str}_{task}"
+ with NamedTemporaryFile(mode="w") as f:
for t in train_text[task]:
- f.write(t + '\n')
- gen_vocab(f.name, op.join(cur_root, spm_filename_prefix),
- vocab_type, vocab_size)
+ f.write(t + "\n")
+ gen_vocab(
+ f.name,
+ op.join(cur_root, spm_filename_prefix),
+ vocab_type,
+ vocab_size,
+ )
# Generate config YAML
- gen_config_yaml(cur_root, spm_filename_prefix + '.model',
- yaml_filename=f'config_{task}.yaml',
- specaugment_policy='lb')
+ gen_config_yaml(
+ cur_root,
+ spm_filename_prefix + ".model",
+ yaml_filename=f"config_{task}.yaml",
+ specaugment_policy="lb",
+ )
# Clean up
shutil.rmtree(feature_root)
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('--data-root', '-d', required=True, type=str)
- parser.add_argument('--asr-vocab-type', default='unigram', required=True,
- type=str, choices=['bpe', 'unigram', 'char']),
- parser.add_argument('--st-vocab-type', default='unigram', required=True,
- type=str, choices=['bpe', 'unigram', 'char']),
- parser.add_argument('--asr-vocab-size', default=5000, type=int)
- parser.add_argument('--st-vocab-size', default=8000, type=int)
+ parser.add_argument("--data-root", "-d", required=True, type=str)
+ parser.add_argument(
+ "--asr-vocab-type",
+ default="unigram",
+ required=True,
+ type=str,
+ choices=["bpe", "unigram", "char"],
+ ),
+ parser.add_argument(
+ "--st-vocab-type",
+ default="unigram",
+ required=True,
+ type=str,
+ choices=["bpe", "unigram", "char"],
+ ),
+ parser.add_argument("--asr-vocab-size", default=5000, type=int)
+ parser.add_argument("--st-vocab-size", default=8000, type=int)
args = parser.parse_args()
process(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/translation_moe/score.py b/examples/translation_moe/score.py
index b68cc828a7..9a529a9850 100644
--- a/examples/translation_moe/score.py
+++ b/examples/translation_moe/score.py
@@ -12,9 +12,9 @@
"""
import argparse
-from itertools import chain
-import sys
import random
+import sys
+from itertools import chain
import numpy as np
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
@@ -22,17 +22,21 @@
def main():
parser = argparse.ArgumentParser(sys.argv[0])
- parser.add_argument('--sys', nargs='*', default='', metavar='FILE',
- help='path to system output')
- parser.add_argument('--ref', default='', metavar='FILE',
- help='path to references')
- parser.add_argument('--output', default='', metavar='FILE',
- help='print outputs into a pretty format')
+ parser.add_argument(
+ "--sys", nargs="*", default="", metavar="FILE", help="path to system output"
+ )
+ parser.add_argument("--ref", default="", metavar="FILE", help="path to references")
+ parser.add_argument(
+ "--output",
+ default="",
+ metavar="FILE",
+ help="print outputs into a pretty format",
+ )
args = parser.parse_args()
if args.sys:
src, tgt, hypos, log_probs = load_sys(args.sys)
- print('pairwise BLEU: %.2f' % pairwise(hypos))
+ print("pairwise BLEU: %.2f" % pairwise(hypos))
if args.output:
merge(src, tgt, hypos, log_probs, args.output)
@@ -58,18 +62,18 @@ def load_sys(paths):
# S: source
# T: target
# D: detokenized system output
- if line.startswith(('S-', 'T-', 'D-')):
- i = int(line[line.find('-')+1:line.find('\t')])
- if line.startswith('S-'):
- src[i] = line.split('\t')[1]
- if line.startswith('T-'):
- tgt[i] = line.split('\t')[1]
- if line.startswith('D-'):
+ if line.startswith(("S-", "T-", "D-")):
+ i = int(line[line.find("-") + 1 : line.find("\t")])
+ if line.startswith("S-"):
+ src[i] = line.split("\t")[1]
+ if line.startswith("T-"):
+ tgt[i] = line.split("\t")[1]
+ if line.startswith("D-"):
if i not in hypos:
hypos[i] = []
log_probs[i] = []
- hypos[i].append(line.split('\t')[2])
- log_probs[i].append(float(line.split('\t')[1]))
+ hypos[i].append(line.split("\t")[2])
+ log_probs[i].append(float(line.split("\t")[1]))
return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)
@@ -79,34 +83,34 @@ def load_ref(path):
src, tgt, refs = [], [], []
i = 0
while i < len(lines):
- if lines[i].startswith('S-'):
- src.append(lines[i].split('\t')[1].rstrip())
+ if lines[i].startswith("S-"):
+ src.append(lines[i].split("\t")[1].rstrip())
i += 1
- elif lines[i].startswith('T-'):
- tgt.append(lines[i].split('\t')[1].rstrip())
+ elif lines[i].startswith("T-"):
+ tgt.append(lines[i].split("\t")[1].rstrip())
i += 1
else:
a = []
- while i < len(lines) and lines[i].startswith('R'):
- a.append(lines[i].split('\t')[1].rstrip())
+ while i < len(lines) and lines[i].startswith("R"):
+ a.append(lines[i].split("\t")[1].rstrip())
i += 1
refs.append(a)
return src, tgt, refs
def merge(src, tgt, hypos, log_probs, path):
- with open(path, 'w') as f:
+ with open(path, "w") as f:
for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
- f.write(s + '\n')
- f.write(t + '\n')
- f.write('\n')
+ f.write(s + "\n")
+ f.write(t + "\n")
+ f.write("\n")
for h, lp in zip(hs, lps):
- f.write('\t%f\t%s\n' % (lp, h.strip()))
- f.write('------------------------------------------------------\n')
+ f.write("\t%f\t%s\n" % (lp, h.strip()))
+ f.write("------------------------------------------------------\n")
def corpus_bleu(sys_stream, ref_streams):
- bleu = _corpus_bleu(sys_stream, ref_streams, tokenize='none')
+ bleu = _corpus_bleu(sys_stream, ref_streams, tokenize="none")
return bleu.score
@@ -116,9 +120,11 @@ def sentence_bleu(hypothesis, reference):
bleu.counts[i] += 1
bleu.totals[i] += 1
bleu = compute_bleu(
- bleu.counts, bleu.totals,
- bleu.sys_len, bleu.ref_len,
- smooth_method='exp',
+ bleu.counts,
+ bleu.totals,
+ bleu.sys_len,
+ bleu.ref_len,
+ smooth_method="exp",
)
return bleu.score
@@ -150,7 +156,7 @@ def multi_ref(refs, hypos):
best = [k for k in range(len(rs)) if s[k] == s[j]]
a.add(random.choice(best))
ref_cnt += len(a)
- print('#refs covered: %.2f' % (ref_cnt / len(refs)))
+ print("#refs covered: %.2f" % (ref_cnt / len(refs)))
# transpose refs and hypos
refs = list(zip(*refs))
@@ -160,33 +166,32 @@ def multi_ref(refs, hypos):
k = len(hypos)
m = len(refs)
flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
- duplicated_refs = [
- [ref for ref in refs_i for _ in range(k)]
- for refs_i in refs
- ]
+ duplicated_refs = [[ref for ref in refs_i for _ in range(k)] for refs_i in refs]
loo_bleus = []
for held_out_ref in range(m):
- remaining_refs = duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref+1:]
+ remaining_refs = (
+ duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref + 1 :]
+ )
assert len(remaining_refs) == m - 1
loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
- print('average multi-reference BLEU (leave-one-out): %.2f' % np.mean(loo_bleus))
+ print("average multi-reference BLEU (leave-one-out): %.2f" % np.mean(loo_bleus))
def intra_ref(refs):
- print('ref pairwise BLEU: %.2f' % pairwise(refs))
+ print("ref pairwise BLEU: %.2f" % pairwise(refs))
refs = list(zip(*refs))
m = len(refs)
concat_h = []
concat_rest = [[] for j in range(m - 1)]
for i, h in enumerate(refs):
- rest = refs[:i] + refs[i+1:]
+ rest = refs[:i] + refs[i + 1 :]
concat_h.append(h)
for j in range(m - 1):
concat_rest[j].extend(rest[j])
concat_h = list(chain.from_iterable(concat_h))
bleu = corpus_bleu(concat_h, concat_rest)
- print('multi-reference BLEU (leave-one-out): %.2f' % bleu)
+ print("multi-reference BLEU (leave-one-out): %.2f" % bleu)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/translation_moe/src/logsumexp_moe.py b/examples/translation_moe/src/logsumexp_moe.py
index 0379f226b0..fb299daecb 100644
--- a/examples/translation_moe/src/logsumexp_moe.py
+++ b/examples/translation_moe/src/logsumexp_moe.py
@@ -21,6 +21,6 @@ def forward(ctx, logp, posterior, dim=-1):
@staticmethod
def backward(ctx, grad_output):
- posterior, = ctx.saved_tensors
+ (posterior,) = ctx.saved_tensors
grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
return grad_logp, None, None
diff --git a/examples/translation_moe/src/mean_pool_gating_network.py b/examples/translation_moe/src/mean_pool_gating_network.py
index 25743b4e98..484b6ac912 100644
--- a/examples/translation_moe/src/mean_pool_gating_network.py
+++ b/examples/translation_moe/src/mean_pool_gating_network.py
@@ -26,15 +26,15 @@ def __init__(self, embed_dim, num_experts, dropout=None):
def forward(self, encoder_out):
if not (
- hasattr(encoder_out, 'encoder_out')
- and hasattr(encoder_out, 'encoder_padding_mask')
+ hasattr(encoder_out, "encoder_out")
+ and hasattr(encoder_out, "encoder_padding_mask")
and encoder_out.encoder_out.size(2) == self.embed_dim
):
- raise ValueError('Unexpected format for encoder_out')
+ raise ValueError("Unexpected format for encoder_out")
# mean pooling over time
encoder_padding_mask = encoder_out.encoder_padding_mask # B x T
- encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C
+ encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C
if encoder_padding_mask is not None:
encoder_out = encoder_out.clone() # required because of transpose above
encoder_out[encoder_padding_mask] = 0
diff --git a/examples/translation_moe/src/translation_moe.py b/examples/translation_moe/src/translation_moe.py
index 5455dd6681..ae458aaad3 100644
--- a/examples/translation_moe/src/translation_moe.py
+++ b/examples/translation_moe/src/translation_moe.py
@@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.
import torch
-
from fairseq import metrics, utils
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask
@@ -13,7 +12,7 @@
from .mean_pool_gating_network import MeanPoolGatingNetwork
-@register_task('translation_moe')
+@register_task("translation_moe")
class TranslationMoETask(TranslationTask):
"""
Translation task for Mixture of Experts (MoE) models.
@@ -58,19 +57,19 @@ def add_args(parser):
# fmt: on
def __init__(self, args, src_dict, tgt_dict):
- if args.method == 'sMoElp':
+ if args.method == "sMoElp":
# soft MoE with learned prior
self.uniform_prior = False
self.hard_selection = False
- elif args.method == 'sMoEup':
+ elif args.method == "sMoEup":
# soft MoE with uniform prior
self.uniform_prior = True
self.hard_selection = False
- elif args.method == 'hMoElp':
+ elif args.method == "hMoElp":
# hard MoE with learned prior
self.uniform_prior = False
self.hard_selection = True
- elif args.method == 'hMoEup':
+ elif args.method == "hMoEup":
# hard MoE with uniform prior
self.uniform_prior = True
self.hard_selection = True
@@ -78,50 +77,56 @@ def __init__(self, args, src_dict, tgt_dict):
# add indicator tokens for each expert
for i in range(args.num_experts):
# add to both dictionaries in case we're sharing embeddings
- src_dict.add_symbol(''.format(i))
- tgt_dict.add_symbol(''.format(i))
+ src_dict.add_symbol("".format(i))
+ tgt_dict.add_symbol("".format(i))
super().__init__(args, src_dict, tgt_dict)
def build_model(self, args):
from fairseq import models
+
model = models.build_model(args, self)
- if not self.uniform_prior and not hasattr(model, 'gating_network'):
+ if not self.uniform_prior and not hasattr(model, "gating_network"):
if self.args.mean_pool_gating_network:
- if getattr(args, 'mean_pool_gating_network_encoder_dim', None):
+ if getattr(args, "mean_pool_gating_network_encoder_dim", None):
encoder_dim = args.mean_pool_gating_network_encoder_dim
- elif getattr(args, 'encoder_embed_dim', None):
+ elif getattr(args, "encoder_embed_dim", None):
# assume that encoder_embed_dim is the encoder's output dimension
encoder_dim = args.encoder_embed_dim
else:
- raise ValueError('Must specify --mean-pool-gating-network-encoder-dim')
+ raise ValueError(
+ "Must specify --mean-pool-gating-network-encoder-dim"
+ )
- if getattr(args, 'mean_pool_gating_network_dropout', None):
+ if getattr(args, "mean_pool_gating_network_dropout", None):
dropout = args.mean_pool_gating_network_dropout
- elif getattr(args, 'dropout', None):
+ elif getattr(args, "dropout", None):
dropout = args.dropout
else:
- raise ValueError('Must specify --mean-pool-gating-network-dropout')
+ raise ValueError("Must specify --mean-pool-gating-network-dropout")
model.gating_network = MeanPoolGatingNetwork(
- encoder_dim, args.num_experts, dropout,
+ encoder_dim,
+ args.num_experts,
+ dropout,
)
else:
raise ValueError(
- 'translation_moe task with learned prior requires the model to '
- 'have a gating network; try using --mean-pool-gating-network'
+ "translation_moe task with learned prior requires the model to "
+ "have a gating network; try using --mean-pool-gating-network"
)
return model
def expert_index(self, i):
- return i + self.tgt_dict.index('')
+ return i + self.tgt_dict.index("")
def _get_loss(self, sample, model, criterion):
- assert hasattr(criterion, 'compute_loss'), \
- 'translation_moe task requires the criterion to implement the compute_loss() method'
+ assert hasattr(
+ criterion, "compute_loss"
+ ), "translation_moe task requires the criterion to implement the compute_loss() method"
k = self.args.num_experts
- bsz = sample['target'].size(0)
+ bsz = sample["target"].size(0)
def get_lprob_y(encoder_out, prev_output_tokens_k):
net_output = model.decoder(
@@ -134,20 +139,22 @@ def get_lprob_y(encoder_out, prev_output_tokens_k):
def get_lprob_yz(winners=None):
encoder_out = model.encoder(
- src_tokens=sample['net_input']['src_tokens'],
- src_lengths=sample['net_input']['src_lengths'],
+ src_tokens=sample["net_input"]["src_tokens"],
+ src_lengths=sample["net_input"]["src_lengths"],
)
if winners is None:
lprob_y = []
for i in range(k):
- prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
+ prev_output_tokens_k = sample["net_input"][
+ "prev_output_tokens"
+ ].clone()
assert not prev_output_tokens_k.requires_grad
prev_output_tokens_k[:, 0] = self.expert_index(i)
lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
lprob_y = torch.cat(lprob_y, dim=1) # -> B x K
else:
- prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
+ prev_output_tokens_k = sample["net_input"]["prev_output_tokens"].clone()
prev_output_tokens_k[:, 0] = self.expert_index(winners)
lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k) # -> B
@@ -177,17 +184,21 @@ def get_lprob_yz(winners=None):
loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
loss = loss.sum()
- sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
+ sample_size = (
+ sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
+ )
logging_output = {
- 'loss': utils.item(loss.data),
- 'ntokens': sample['ntokens'],
- 'nsentences': bsz,
- 'sample_size': sample_size,
- 'posterior': prob_z_xy.float().sum(dim=0).cpu(),
+ "loss": utils.item(loss.data),
+ "ntokens": sample["ntokens"],
+ "nsentences": bsz,
+ "sample_size": sample_size,
+ "posterior": prob_z_xy.float().sum(dim=0).cpu(),
}
return loss, sample_size, logging_output
- def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
+ def train_step(
+ self, sample, model, criterion, optimizer, update_num, ignore_grad=False
+ ):
model.train()
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
if ignore_grad:
@@ -201,7 +212,15 @@ def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
return loss, sample_size, logging_output
- def inference_step(self, generator, models, sample, prefix_tokens=None, expert=None, constraints=None):
+ def inference_step(
+ self,
+ generator,
+ models,
+ sample,
+ prefix_tokens=None,
+ expert=None,
+ constraints=None,
+ ):
expert = expert or self.args.gen_expert
with torch.no_grad():
return generator.generate(
@@ -215,6 +234,6 @@ def inference_step(self, generator, models, sample, prefix_tokens=None, expert=N
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
metrics.log_scalar(
- 'posterior',
- sum(log['posterior'] for log in logging_outputs if 'posterior' in log)
+ "posterior",
+ sum(log["posterior"] for log in logging_outputs if "posterior" in log),
)
diff --git a/examples/unsupervised_quality_estimation/aggregate_scores.py b/examples/unsupervised_quality_estimation/aggregate_scores.py
index 35a6baf67d..66d50d07ff 100644
--- a/examples/unsupervised_quality_estimation/aggregate_scores.py
+++ b/examples/unsupervised_quality_estimation/aggregate_scores.py
@@ -4,37 +4,38 @@
# LICENSE file in the root directory of this source tree.
import argparse
-import numpy as np
import sys
+import numpy as np
+
aggregate_funcs = {
- 'std': np.std,
- 'var': np.var,
- 'median': np.median,
- 'mean': np.mean,
- 'min': np.min,
- 'max': np.max,
+ "std": np.std,
+ "var": np.var,
+ "median": np.median,
+ "mean": np.mean,
+ "min": np.min,
+ "max": np.max,
}
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('-i', '--input_file', required=True, type=str)
- parser.add_argument('-n', '--repeat_times', required=True, type=int)
- parser.add_argument('-o', '--output_file', required=False)
- parser.add_argument('-f', '--func', required=False, default='mean')
+ parser.add_argument("-i", "--input_file", required=True, type=str)
+ parser.add_argument("-n", "--repeat_times", required=True, type=int)
+ parser.add_argument("-o", "--output_file", required=False)
+ parser.add_argument("-f", "--func", required=False, default="mean")
args = parser.parse_args()
- stream = open(args.output_file, 'w') if args.output_file else sys.stdout
+ stream = open(args.output_file, "w") if args.output_file else sys.stdout
segment_scores = []
for line in open(args.input_file):
segment_scores.append(float(line.strip()))
if len(segment_scores) == args.repeat_times:
- stream.write('{}\n'.format(aggregate_funcs[args.func](segment_scores)))
+ stream.write("{}\n".format(aggregate_funcs[args.func](segment_scores)))
segment_scores = []
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/unsupervised_quality_estimation/meteor.py b/examples/unsupervised_quality_estimation/meteor.py
index ed4ba4ec34..4a214e794d 100644
--- a/examples/unsupervised_quality_estimation/meteor.py
+++ b/examples/unsupervised_quality_estimation/meteor.py
@@ -4,14 +4,13 @@
# LICENSE file in the root directory of this source tree.
import argparse
+import math
import os
-import sys
import subprocess
+import sys
import tempfile
-import math
-
-from itertools import combinations
from collections import defaultdict
+from itertools import combinations
def read_translations(path, n_repeats):
@@ -19,7 +18,7 @@ def read_translations(path, n_repeats):
segment_translations = []
translations = defaultdict(list)
for line in open(path):
- segment_translations.append(' '.join(line.split()))
+ segment_translations.append(" ".join(line.split()))
if len(segment_translations) == n_repeats:
translations[segment_counter] = segment_translations
segment_translations = []
@@ -30,42 +29,55 @@ def read_translations(path, n_repeats):
def generate_input(translations, n_repeats):
_, ref_path = tempfile.mkstemp()
_, mt_path = tempfile.mkstemp()
- ref_fh = open(ref_path, 'w')
- mt_fh = open(mt_path, 'w')
+ ref_fh = open(ref_path, "w")
+ mt_fh = open(mt_path, "w")
for segid in sorted(translations.keys()):
assert len(translations[segid]) == n_repeats
indexes = combinations(range(n_repeats), 2)
for idx1, idx2 in indexes:
- mt_fh.write(translations[segid][idx1].strip() + '\n')
- ref_fh.write(translations[segid][idx2].strip() + '\n')
- sys.stderr.write('\nSaved translations to %s and %s' % (ref_path, mt_path))
+ mt_fh.write(translations[segid][idx1].strip() + "\n")
+ ref_fh.write(translations[segid][idx2].strip() + "\n")
+ sys.stderr.write("\nSaved translations to %s and %s" % (ref_path, mt_path))
return ref_path, mt_path
-def run_meteor(ref_path, mt_path, metric_path, lang='en'):
+def run_meteor(ref_path, mt_path, metric_path, lang="en"):
_, out_path = tempfile.mkstemp()
- subprocess.call([
- 'java', '-Xmx2G', '-jar', metric_path, mt_path, ref_path,
- '-p', '0.5 0.2 0.6 0.75', # default parameters, only changed alpha to give equal weight to P and R
- '-norm',
- '-l', lang], stdout=open(out_path, 'w'))
+ subprocess.call(
+ [
+ "java",
+ "-Xmx2G",
+ "-jar",
+ metric_path,
+ mt_path,
+ ref_path,
+ "-p",
+ "0.5 0.2 0.6 0.75", # default parameters, only changed alpha to give equal weight to P and R
+ "-norm",
+ "-l",
+ lang,
+ ],
+ stdout=open(out_path, "w"),
+ )
os.remove(ref_path)
os.remove(mt_path)
- sys.stderr.write('\nSaved Meteor output to %s' % out_path)
+ sys.stderr.write("\nSaved Meteor output to %s" % out_path)
return out_path
def read_output(meteor_output_path, n_repeats):
- n_combinations = math.factorial(n_repeats)/(math.factorial(2) * math.factorial(n_repeats - 2))
+ n_combinations = math.factorial(n_repeats) / (
+ math.factorial(2) * math.factorial(n_repeats - 2)
+ )
raw_scores = []
average_scores = []
for line in open(meteor_output_path):
- if not line.startswith('Segment '):
+ if not line.startswith("Segment "):
continue
- score = float(line.strip().split('\t')[1])
+ score = float(line.strip().split("\t")[1])
raw_scores.append(score)
if len(raw_scores) == n_combinations:
- average_scores.append(sum(raw_scores)/n_combinations)
+ average_scores.append(sum(raw_scores) / n_combinations)
raw_scores = []
os.remove(meteor_output_path)
return average_scores
@@ -73,25 +85,25 @@ def read_output(meteor_output_path, n_repeats):
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('-i', '--input')
- parser.add_argument('-n', '--repeat_times', type=int)
- parser.add_argument('-m', '--meteor')
- parser.add_argument('-o', '--output')
+ parser.add_argument("-i", "--input")
+ parser.add_argument("-n", "--repeat_times", type=int)
+ parser.add_argument("-m", "--meteor")
+ parser.add_argument("-o", "--output")
args = parser.parse_args()
translations = read_translations(args.infile, args.repetitions)
- sys.stderr.write('\nGenerating input for Meteor...')
+ sys.stderr.write("\nGenerating input for Meteor...")
ref_path, mt_path = generate_input(translations, args.repetitions)
- sys.stderr.write('\nRunning Meteor...')
+ sys.stderr.write("\nRunning Meteor...")
out_path = run_meteor(ref_path, mt_path, args.meteor)
- sys.stderr.write('\nReading output...')
+ sys.stderr.write("\nReading output...")
scores = read_output(out_path, args.repetitions)
- sys.stderr.write('\nWriting results...')
- with open(args.output, 'w') as o:
+ sys.stderr.write("\nWriting results...")
+ with open(args.output, "w") as o:
for scr in scores:
- o.write('{}\n'.format(scr))
+ o.write("{}\n".format(scr))
o.close()
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/unsupervised_quality_estimation/repeat_lines.py b/examples/unsupervised_quality_estimation/repeat_lines.py
index 661ca17c1b..5a04851a74 100644
--- a/examples/unsupervised_quality_estimation/repeat_lines.py
+++ b/examples/unsupervised_quality_estimation/repeat_lines.py
@@ -8,21 +8,21 @@
def _normalize_spaces(line):
- return ' '.join(line.split())
+ return " ".join(line.split())
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('-i', '--input_file', required=True, type=str)
- parser.add_argument('-n', '--repeat_times', required=True, type=int)
- parser.add_argument('-o', '--output_file', required=False, type=str)
+ parser.add_argument("-i", "--input_file", required=True, type=str)
+ parser.add_argument("-n", "--repeat_times", required=True, type=int)
+ parser.add_argument("-o", "--output_file", required=False, type=str)
args = parser.parse_args()
- stream = open(args.output_file, 'w') if args.output_file else sys.stdout
+ stream = open(args.output_file, "w") if args.output_file else sys.stdout
for line in open(args.input_file):
for _ in range(args.repeat_times):
- stream.write(_normalize_spaces(line) + '\n')
+ stream.write(_normalize_spaces(line) + "\n")
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/wav2vec/vq-wav2vec_featurize.py b/examples/wav2vec/vq-wav2vec_featurize.py
index 0d658c07ca..baabc1d365 100644
--- a/examples/wav2vec/vq-wav2vec_featurize.py
+++ b/examples/wav2vec/vq-wav2vec_featurize.py
@@ -8,30 +8,31 @@
Helper script to pre-compute embeddings for a wav2letter++ dataset
"""
+import argparse
+import glob
+import os
+import os.path as osp
import pprint
-import glob, os, argparse
+import soundfile as sf
import torch
+import tqdm
+from fairseq.models.wav2vec.wav2vec import Wav2VecModel
from torch import nn
+from torch.utils.data import DataLoader
+
try:
import tqdm
except:
print("Install tqdm to use --log-format=tqdm")
-from fairseq.models.wav2vec.wav2vec import Wav2VecModel
-
-import tqdm
-import soundfile as sf
-from torch.utils.data import DataLoader
-import os.path as osp
-
class FilesDataset:
def __init__(self, files, labels):
self.files = files
if labels and osp.exists(labels):
- with open(labels, 'r') as lbl_f:
+ with open(labels, "r") as lbl_f:
self.labels = [line.rstrip() for line in lbl_f]
else:
self.labels = labels
@@ -50,7 +51,7 @@ def __getitem__(self, index):
if self.labels:
if isinstance(self.labels, str):
lbl_file = osp.splitext(fname)[0] + "." + self.labels
- with open(lbl_file, 'r') as lblf:
+ with open(lbl_file, "r") as lblf:
lbls = lblf.readline()
assert lbls is not None
else:
@@ -116,24 +117,24 @@ def process_splits(self):
assert len(files) > 0
if self.args.shard is not None:
- files = files[self.args.shard::self.args.num_shards]
+ files = files[self.args.shard :: self.args.num_shards]
lbls = []
- with open(self.data_file(split), 'w') as srcf:
+ with open(self.data_file(split), "w") as srcf:
for line, lbl in self.iterate(files):
print(line, file=srcf)
if self.args.labels:
- lbls.append(lbl + '\n')
+ lbls.append(lbl + "\n")
if self.args.labels:
assert all(a is not None for a in lbls)
- with open(self.lbl_file(split), 'w') as lblf:
+ with open(self.lbl_file(split), "w") as lblf:
lblf.writelines(lbls)
def iterate(self, files):
data = self.load_data(files)
- for samples in tqdm.tqdm(data, total=len(files)//32):
+ for samples in tqdm.tqdm(data, total=len(files) // 32):
for wav, lbl in samples:
x = wav.unsqueeze(0).float().cuda()
@@ -162,7 +163,6 @@ def iterate(self, files):
idx = torch.cat(result, dim=0)
yield " ".join("-".join(map(str, a.tolist())) for a in idx), lbl
-
def lbl_file(self, name):
shard_part = "" if self.args.shard is None else f".{self.args.shard}"
return osp.join(self.output_dir, f"{name}.lbl{shard_part}")
@@ -230,7 +230,9 @@ def __call__(self):
self.process_splits()
- if hasattr(self.model.feature_extractor, "vars") and (self.args.shard is None or self.args.shard == 0):
+ if hasattr(self.model.feature_extractor, "vars") and (
+ self.args.shard is None or self.args.shard == 0
+ ):
vars = (
self.model.feature_extractor.vars.view(
self.model.feature_extractor.banks,
@@ -248,4 +250,4 @@ def __call__(self):
write_data = DatasetWriter()
write_data()
- print("Done.")
\ No newline at end of file
+ print("Done.")
diff --git a/examples/wav2vec/wav2vec_featurize.py b/examples/wav2vec/wav2vec_featurize.py
index 445a5d0213..9283930587 100644
--- a/examples/wav2vec/wav2vec_featurize.py
+++ b/examples/wav2vec/wav2vec_featurize.py
@@ -14,13 +14,12 @@
from shutil import copy
import h5py
-import soundfile as sf
import numpy as np
+import soundfile as sf
import torch
-from torch import nn
import tqdm
-
from fairseq.models.wav2vec.wav2vec import Wav2VecModel
+from torch import nn
def read_audio(fname):
@@ -33,7 +32,6 @@ def read_audio(fname):
class PretrainedWav2VecModel(nn.Module):
-
def __init__(self, fname):
super().__init__()
@@ -55,32 +53,33 @@ def forward(self, x):
class EmbeddingWriterConfig(argparse.ArgumentParser):
-
def __init__(self):
super().__init__("Pre-compute embeddings for wav2letter++ datasets")
kwargs = {"action": "store", "type": str, "required": True}
- self.add_argument("--input", "-i",
- help="Input Directory", **kwargs)
- self.add_argument("--output", "-o",
- help="Output Directory", **kwargs)
- self.add_argument("--model",
- help="Path to model checkpoint", **kwargs)
- self.add_argument("--split",
- help="Dataset Splits", nargs='+', **kwargs)
- self.add_argument("--ext", default="wav", required=False,
- help="Audio file extension")
-
- self.add_argument("--no-copy-labels", action="store_true",
- help="Do not copy label files. Useful for large datasets, use --targetdir in wav2letter then.")
- self.add_argument("--use-feat", action="store_true",
- help="Use the feature vector ('z') instead of context vector ('c') for features")
- self.add_argument("--gpu",
- help="GPU to use", default=0, type=int)
-
-
-class Prediction():
+ self.add_argument("--input", "-i", help="Input Directory", **kwargs)
+ self.add_argument("--output", "-o", help="Output Directory", **kwargs)
+ self.add_argument("--model", help="Path to model checkpoint", **kwargs)
+ self.add_argument("--split", help="Dataset Splits", nargs="+", **kwargs)
+ self.add_argument(
+ "--ext", default="wav", required=False, help="Audio file extension"
+ )
+
+ self.add_argument(
+ "--no-copy-labels",
+ action="store_true",
+ help="Do not copy label files. Useful for large datasets, use --targetdir in wav2letter then.",
+ )
+ self.add_argument(
+ "--use-feat",
+ action="store_true",
+ help="Use the feature vector ('z') instead of context vector ('c') for features",
+ )
+ self.add_argument("--gpu", help="GPU to use", default=0, type=int)
+
+
+class Prediction:
""" Lightweight wrapper around a fairspeech embedding model """
def __init__(self, fname, gpu=0):
@@ -95,7 +94,7 @@ def __call__(self, x):
return z.squeeze(0).cpu().numpy(), c.squeeze(0).cpu().numpy()
-class H5Writer():
+class H5Writer:
""" Write features as hdf5 file in wav2letter++ compatible format """
def __init__(self, fname):
@@ -112,7 +111,7 @@ def write(self, data):
class EmbeddingDatasetWriter(object):
- """ Given a model and a wav2letter++ dataset, pre-compute and store embeddings
+ """Given a model and a wav2letter++ dataset, pre-compute and store embeddings
Args:
input_root, str :
@@ -123,13 +122,17 @@ class EmbeddingDatasetWriter(object):
Dataset split
"""
- def __init__(self, input_root, output_root, split,
- model_fname,
- extension="wav",
- gpu=0,
- verbose=False,
- use_feat=False,
- ):
+ def __init__(
+ self,
+ input_root,
+ output_root,
+ split,
+ model_fname,
+ extension="wav",
+ gpu=0,
+ verbose=False,
+ use_feat=False,
+ ):
assert os.path.exists(model_fname)
@@ -143,8 +146,9 @@ def __init__(self, input_root, output_root, split,
self.extension = extension
self.use_feat = use_feat
- assert os.path.exists(self.input_path), \
- "Input path '{}' does not exist".format(self.input_path)
+ assert os.path.exists(self.input_path), "Input path '{}' does not exist".format(
+ self.input_path
+ )
def _progress(self, iterable, **kwargs):
if self.verbose:
@@ -176,7 +180,11 @@ def get_output_path(self, fname=None):
def copy_labels(self):
self.require_output_path()
- labels = list(filter(lambda x: self.extension not in x, glob.glob(self.get_input_path("*"))))
+ labels = list(
+ filter(
+ lambda x: self.extension not in x, glob.glob(self.get_input_path("*"))
+ )
+ )
for fname in tqdm.tqdm(labels):
copy(fname, self.output_path)
@@ -191,10 +199,16 @@ def write_features(self):
paths = self.input_fnames
- fnames_context = map(lambda x: os.path.join(self.output_path, x.replace("." + self.extension, ".h5context")), \
- map(os.path.basename, paths))
+ fnames_context = map(
+ lambda x: os.path.join(
+ self.output_path, x.replace("." + self.extension, ".h5context")
+ ),
+ map(os.path.basename, paths),
+ )
- for name, target_fname in self._progress(zip(paths, fnames_context), total=len(self)):
+ for name, target_fname in self._progress(
+ zip(paths, fnames_context), total=len(self)
+ ):
wav, sr = read_audio(name)
z, c = self.model(wav)
feat = z if self.use_feat else c
@@ -204,7 +218,8 @@ def write_features(self):
def __repr__(self):
return "EmbeddingDatasetWriter ({n_files} files)\n\tinput:\t{input_root}\n\toutput:\t{output_root}\n\tsplit:\t{split})".format(
- n_files=len(self), **self.__dict__)
+ n_files=len(self), **self.__dict__
+ )
if __name__ == "__main__":
diff --git a/examples/wav2vec/wav2vec_manifest.py b/examples/wav2vec/wav2vec_manifest.py
index c80f9883df..1d27f58afc 100644
--- a/examples/wav2vec/wav2vec_manifest.py
+++ b/examples/wav2vec/wav2vec_manifest.py
@@ -10,32 +10,50 @@
import argparse
import glob
import os
-import soundfile
import random
+import soundfile
+
def get_parser():
parser = argparse.ArgumentParser()
- parser.add_argument('root', metavar='DIR', help='root directory containing flac files to index')
- parser.add_argument('--valid-percent', default=0.01, type=float, metavar='D',
- help='percentage of data to use as validation set (between 0 and 1)')
- parser.add_argument('--dest', default='.', type=str, metavar='DIR', help='output directory')
- parser.add_argument('--ext', default='flac', type=str, metavar='EXT', help='extension to look for')
- parser.add_argument('--seed', default=42, type=int, metavar='N', help='random seed')
- parser.add_argument('--path-must-contain', default=None, type=str, metavar='FRAG',
- help='if set, path must contain this substring for a file to be included in the manifest')
+ parser.add_argument(
+ "root", metavar="DIR", help="root directory containing flac files to index"
+ )
+ parser.add_argument(
+ "--valid-percent",
+ default=0.01,
+ type=float,
+ metavar="D",
+ help="percentage of data to use as validation set (between 0 and 1)",
+ )
+ parser.add_argument(
+ "--dest", default=".", type=str, metavar="DIR", help="output directory"
+ )
+ parser.add_argument(
+ "--ext", default="flac", type=str, metavar="EXT", help="extension to look for"
+ )
+ parser.add_argument("--seed", default=42, type=int, metavar="N", help="random seed")
+ parser.add_argument(
+ "--path-must-contain",
+ default=None,
+ type=str,
+ metavar="FRAG",
+ help="if set, path must contain this substring for a file to be included in the manifest",
+ )
return parser
def main(args):
- assert args.valid_percent >= 0 and args.valid_percent <= 1.
+ assert args.valid_percent >= 0 and args.valid_percent <= 1.0
dir_path = os.path.realpath(args.root)
- search_path = os.path.join(dir_path, '**/*.' + args.ext)
+ search_path = os.path.join(dir_path, "**/*." + args.ext)
rand = random.Random(args.seed)
- with open(os.path.join(args.dest, 'train.tsv'), 'w') as train_f, open(
- os.path.join(args.dest, 'valid.tsv'), 'w') as valid_f:
+ with open(os.path.join(args.dest, "train.tsv"), "w") as train_f, open(
+ os.path.join(args.dest, "valid.tsv"), "w"
+ ) as valid_f:
print(dir_path, file=train_f)
print(dir_path, file=valid_f)
@@ -47,10 +65,12 @@ def main(args):
frames = soundfile.info(fname).frames
dest = train_f if rand.random() > args.valid_percent else valid_f
- print('{}\t{}'.format(os.path.relpath(file_path, dir_path), frames), file=dest)
+ print(
+ "{}\t{}".format(os.path.relpath(file_path, dir_path), frames), file=dest
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
main(args)
diff --git a/fairseq/__init__.py b/fairseq/__init__.py
index a4244c8a3a..cac3d0e43b 100644
--- a/fairseq/__init__.py
+++ b/fairseq/__init__.py
@@ -4,16 +4,17 @@
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
-__all__ = ['pdb']
-__version__ = '1.0.0a0'
+__all__ = ["pdb"]
+__version__ = "1.0.0a0"
import sys
# backwards compatibility to support `from fairseq.meters import AverageMeter`
from fairseq.logging import meters, metrics, progress_bar # noqa
-sys.modules['fairseq.meters'] = meters
-sys.modules['fairseq.metrics'] = metrics
-sys.modules['fairseq.progress_bar'] = progress_bar
+
+sys.modules["fairseq.meters"] = meters
+sys.modules["fairseq.metrics"] = metrics
+sys.modules["fairseq.progress_bar"] = progress_bar
import fairseq.criterions # noqa
import fairseq.models # noqa
diff --git a/fairseq/benchmark/__init__.py b/fairseq/benchmark/__init__.py
index 926f3ce739..f6584661bd 100644
--- a/fairseq/benchmark/__init__.py
+++ b/fairseq/benchmark/__init__.py
@@ -4,9 +4,4 @@
# LICENSE file in the root directory of this source tree.
# import models/tasks to register them
-from . import ( # noqa
- dummy_lm,
- dummy_masked_lm,
- dummy_model,
- dummy_mt,
-)
+from . import dummy_lm, dummy_masked_lm, dummy_model, dummy_mt # noqa
diff --git a/fairseq/benchmark/dummy_lm.py b/fairseq/benchmark/dummy_lm.py
index 3c400e9d7f..6429d04de3 100644
--- a/fairseq/benchmark/dummy_lm.py
+++ b/fairseq/benchmark/dummy_lm.py
@@ -7,25 +7,27 @@
import numpy as np
import torch
-
from fairseq.data import Dictionary, FairseqDataset
-from fairseq.tasks import register_task, LegacyFairseqTask
+from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
-@register_task('dummy_lm')
+@register_task("dummy_lm")
class DummyLMTask(LegacyFairseqTask):
-
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
- parser.add_argument('--dict-size', default=49996, type=int)
- parser.add_argument('--dataset-size', default=100000, type=int)
- parser.add_argument('--tokens-per-sample', default=512, type=int,
- help='max number of total tokens over all segments '
- 'per sample for BERT dataset')
+ parser.add_argument("--dict-size", default=49996, type=int)
+ parser.add_argument("--dataset-size", default=100000, type=int)
+ parser.add_argument(
+ "--tokens-per-sample",
+ default=512,
+ type=int,
+ help="max number of total tokens over all segments "
+ "per sample for BERT dataset",
+ )
def __init__(self, args, dictionary):
super().__init__(args)
@@ -44,8 +46,8 @@ def setup_task(cls, args, **kwargs):
"""Setup the task. """
dictionary = Dictionary()
for i in range(args.dict_size):
- dictionary.add_symbol('word{}'.format(i))
- logger.info('dictionary: {} types'.format(len(dictionary)))
+ dictionary.add_symbol("word{}".format(i))
+ logger.info("dictionary: {} types".format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
@@ -59,16 +61,16 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
self.datasets[split] = DummyDataset(
{
- 'id': 1,
- 'net_input': {
- 'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
- 'src_lengths': torch.full(
- (bsz, ), self.args.tokens_per_sample, dtype=torch.long
+ "id": 1,
+ "net_input": {
+ "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
+ "src_lengths": torch.full(
+ (bsz,), self.args.tokens_per_sample, dtype=torch.long
),
},
- 'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
- 'nsentences': bsz,
- 'ntokens': bsz * self.args.tokens_per_sample,
+ "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
+ "nsentences": bsz,
+ "ntokens": bsz * self.args.tokens_per_sample,
},
num_items=self.args.dataset_size,
item_size=self.args.tokens_per_sample,
@@ -84,7 +86,6 @@ def target_dictionary(self):
class DummyDataset(FairseqDataset):
-
def __init__(self, batch, num_items, item_size):
super().__init__()
self.batch = batch
diff --git a/fairseq/benchmark/dummy_masked_lm.py b/fairseq/benchmark/dummy_masked_lm.py
index 621265d452..ab506fe1d5 100644
--- a/fairseq/benchmark/dummy_masked_lm.py
+++ b/fairseq/benchmark/dummy_masked_lm.py
@@ -7,32 +7,34 @@
import numpy as np
import torch
-
from fairseq.data import Dictionary, FairseqDataset
-from fairseq.tasks import register_task, LegacyFairseqTask
+from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
-@register_task('dummy_masked_lm')
+@register_task("dummy_masked_lm")
class DummyMaskedLMTask(LegacyFairseqTask):
-
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
- parser.add_argument('--dict-size', default=49995, type=int)
- parser.add_argument('--dataset-size', default=100000, type=int)
- parser.add_argument('--tokens-per-sample', default=512, type=int,
- help='max number of total tokens over all segments '
- 'per sample for BERT dataset')
+ parser.add_argument("--dict-size", default=49995, type=int)
+ parser.add_argument("--dataset-size", default=100000, type=int)
+ parser.add_argument(
+ "--tokens-per-sample",
+ default=512,
+ type=int,
+ help="max number of total tokens over all segments "
+ "per sample for BERT dataset",
+ )
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
# add mask token
- self.mask_idx = dictionary.add_symbol('')
+ self.mask_idx = dictionary.add_symbol("")
dictionary.pad_to_multiple_(8) # often faster if divisible by 8
mask_idx = 0
@@ -52,8 +54,8 @@ def setup_task(cls, args, **kwargs):
"""Setup the task. """
dictionary = Dictionary()
for i in range(args.dict_size):
- dictionary.add_symbol('word{}'.format(i))
- logger.info('dictionary: {} types'.format(len(dictionary)))
+ dictionary.add_symbol("word{}".format(i))
+ logger.info("dictionary: {} types".format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
@@ -67,16 +69,16 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
self.datasets[split] = DummyDataset(
{
- 'id': 1,
- 'net_input': {
- 'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
- 'src_lengths': torch.full(
- (bsz, ), self.args.tokens_per_sample, dtype=torch.long
+ "id": 1,
+ "net_input": {
+ "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
+ "src_lengths": torch.full(
+ (bsz,), self.args.tokens_per_sample, dtype=torch.long
),
},
- 'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
- 'nsentences': bsz,
- 'ntokens': bsz * self.args.tokens_per_sample,
+ "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
+ "nsentences": bsz,
+ "ntokens": bsz * self.args.tokens_per_sample,
},
num_items=self.args.dataset_size,
item_size=self.args.tokens_per_sample,
@@ -92,7 +94,6 @@ def target_dictionary(self):
class DummyDataset(FairseqDataset):
-
def __init__(self, batch, num_items, item_size):
super().__init__()
self.batch = batch
diff --git a/fairseq/benchmark/dummy_model.py b/fairseq/benchmark/dummy_model.py
index 817cdb34bb..ff26e4fe65 100644
--- a/fairseq/benchmark/dummy_model.py
+++ b/fairseq/benchmark/dummy_model.py
@@ -5,7 +5,6 @@
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq.data import Dictionary
from fairseq.models import (
FairseqDecoder,
@@ -15,17 +14,16 @@
)
-@register_model('dummy_model')
+@register_model("dummy_model")
class DummyModel(FairseqLanguageModel):
-
def __init__(self, args, encoder):
super().__init__(encoder)
self.args = args
@staticmethod
def add_args(parser):
- parser.add_argument('--num-layers', type=int, default=24)
- parser.add_argument('--embed-dim', type=int, default=1024)
+ parser.add_argument("--num-layers", type=int, default=24)
+ parser.add_argument("--embed-dim", type=int, default=1024)
@classmethod
def build_model(cls, args, task):
@@ -41,32 +39,35 @@ def forward(self, src_tokens, masked_tokens=None, **kwargs):
class DummyEncoder(FairseqDecoder):
-
def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
super().__init__(Dictionary())
self.embed = nn.Embedding(
num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
)
- self.layers_a = nn.ModuleList([
- nn.Sequential(
- nn.LayerNorm(embed_dim),
- nn.Linear(embed_dim, 3*embed_dim), # q, k, v input projection
- nn.Linear(3*embed_dim, embed_dim), # skip self-attention
- nn.Linear(embed_dim, embed_dim), # output projection
- nn.Dropout(),
- )
- for i in range(num_layers)
- ])
- self.layers_b = nn.ModuleList([
- nn.Sequential(
- nn.LayerNorm(embed_dim),
- nn.Linear(embed_dim, 4*embed_dim), # FFN
- nn.ReLU(),
- nn.Linear(4*embed_dim, embed_dim), # FFN
- nn.Dropout(0.1),
- )
- for i in range(num_layers)
- ])
+ self.layers_a = nn.ModuleList(
+ [
+ nn.Sequential(
+ nn.LayerNorm(embed_dim),
+ nn.Linear(embed_dim, 3 * embed_dim), # q, k, v input projection
+ nn.Linear(3 * embed_dim, embed_dim), # skip self-attention
+ nn.Linear(embed_dim, embed_dim), # output projection
+ nn.Dropout(),
+ )
+ for i in range(num_layers)
+ ]
+ )
+ self.layers_b = nn.ModuleList(
+ [
+ nn.Sequential(
+ nn.LayerNorm(embed_dim),
+ nn.Linear(embed_dim, 4 * embed_dim), # FFN
+ nn.ReLU(),
+ nn.Linear(4 * embed_dim, embed_dim), # FFN
+ nn.Dropout(0.1),
+ )
+ for i in range(num_layers)
+ ]
+ )
self.out_proj = nn.Linear(embed_dim, num_embed)
def forward(self, tokens, masked_tokens=None):
@@ -90,6 +91,6 @@ def get_normalized_probs(self, net_output, log_probs, sample=None):
return F.softmax(logits, dim=-1)
-@register_model_architecture('dummy_model', 'dummy_model')
+@register_model_architecture("dummy_model", "dummy_model")
def base_architecture(args):
pass
diff --git a/fairseq/benchmark/dummy_mt.py b/fairseq/benchmark/dummy_mt.py
index 2f8d65d5be..4ca7be93a3 100644
--- a/fairseq/benchmark/dummy_mt.py
+++ b/fairseq/benchmark/dummy_mt.py
@@ -7,24 +7,22 @@
import numpy as np
import torch
-
from fairseq.data import Dictionary, FairseqDataset
-from fairseq.tasks import register_task, LegacyFairseqTask
+from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
-@register_task('dummy_mt')
+@register_task("dummy_mt")
class DummyMTTask(LegacyFairseqTask):
-
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
- parser.add_argument('--dict-size', default=49996, type=int)
- parser.add_argument('--dataset-size', default=100000, type=int)
- parser.add_argument('--src-len', default=30, type=int)
- parser.add_argument('--tgt-len', default=30, type=int)
+ parser.add_argument("--dict-size", default=49996, type=int)
+ parser.add_argument("--dataset-size", default=100000, type=int)
+ parser.add_argument("--src-len", default=30, type=int)
+ parser.add_argument("--tgt-len", default=30, type=int)
def __init__(self, args, dictionary):
super().__init__(args)
@@ -41,8 +39,8 @@ def setup_task(cls, args, **kwargs):
"""Setup the task. """
dictionary = Dictionary()
for i in range(args.dict_size):
- dictionary.add_symbol('word{}'.format(i))
- logger.info('dictionary: {} types'.format(len(dictionary)))
+ dictionary.add_symbol("word{}".format(i))
+ logger.info("dictionary: {} types".format(len(dictionary)))
args.max_source_positions = args.src_len + dictionary.pad() + 2
args.max_target_positions = args.tgt_len + dictionary.pad() + 2
@@ -62,17 +60,17 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
self.datasets[split] = DummyDataset(
{
- 'id': 1,
- 'net_input': {
- 'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
- 'src_lengths': torch.full(
- (bsz, ), self.args.src_len, dtype=torch.long
+ "id": 1,
+ "net_input": {
+ "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
+ "src_lengths": torch.full(
+ (bsz,), self.args.src_len, dtype=torch.long
),
- 'prev_output_tokens': tgt.clone(),
+ "prev_output_tokens": tgt.clone(),
},
- 'target': tgt,
- 'nsentences': bsz,
- 'ntokens': bsz * self.args.tgt_len,
+ "target": tgt,
+ "nsentences": bsz,
+ "ntokens": bsz * self.args.tgt_len,
},
num_items=self.args.dataset_size,
item_size=item_size,
@@ -88,7 +86,6 @@ def target_dictionary(self):
class DummyDataset(FairseqDataset):
-
def __init__(self, batch, num_items, item_size):
super().__init__()
self.batch = batch
diff --git a/fairseq/binarizer.py b/fairseq/binarizer.py
index ec3b90f211..0255c084b5 100644
--- a/fairseq/binarizer.py
+++ b/fairseq/binarizer.py
@@ -6,9 +6,10 @@
import os
from collections import Counter
-from fairseq.tokenizer import tokenize_line
import torch
from fairseq.file_io import PathManager
+from fairseq.tokenizer import tokenize_line
+
def safe_readline(f):
pos = f.tell()
diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py
index 60ab3190c7..75e2c68ca3 100644
--- a/fairseq/checkpoint_utils.py
+++ b/fairseq/checkpoint_utils.py
@@ -67,12 +67,14 @@ def is_better(a, b):
or is_better(val_loss, save_checkpoint.best)
)
if val_loss is not None and args.keep_best_checkpoints > 0:
- checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format(
- args.best_checkpoint_metric, val_loss)] = (
- not hasattr(save_checkpoint, "best")
- or is_better(val_loss, save_checkpoint.best)
+ checkpoint_conds[
+ "checkpoint.best_{}_{:.2f}.pt".format(args.best_checkpoint_metric, val_loss)
+ ] = not hasattr(save_checkpoint, "best") or is_better(
+ val_loss, save_checkpoint.best
)
- checkpoint_conds["checkpoint_last{}.pt".format(suffix)] = not args.no_last_checkpoints
+ checkpoint_conds[
+ "checkpoint_last{}.pt".format(suffix)
+ ] = not args.no_last_checkpoints
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
if hasattr(save_checkpoint, "best"):
@@ -112,10 +114,14 @@ def is_better(a, b):
if args.keep_best_checkpoints > 0:
# only keep the best N checkpoints according to validation metric
checkpoints = checkpoint_paths(
- args.save_dir, pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(args.best_checkpoint_metric))
+ args.save_dir,
+ pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
+ args.best_checkpoint_metric
+ ),
+ )
if not args.maximize_best_checkpoint_metric:
checkpoints = checkpoints[::-1]
- for old_chk in checkpoints[args.keep_best_checkpoints:]:
+ for old_chk in checkpoints[args.keep_best_checkpoints :]:
if os.path.lexists(old_chk):
os.remove(old_chk)
@@ -133,16 +139,23 @@ def load_checkpoint(args, trainer, **passthrough_args):
reset_meters = args.reset_meters
reset_dataloader = args.reset_dataloader
- if getattr(args, 'finetune_from_model', None) is not None \
- and (reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader):
- raise ValueError("--finetune-from-model can not be set together with either --reset-optimizer"
- " or reset_lr_scheduler or reset_meters or reset_dataloader")
+ if getattr(args, "finetune_from_model", None) is not None and (
+ reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
+ ):
+ raise ValueError(
+ "--finetune-from-model can not be set together with either --reset-optimizer"
+ " or reset_lr_scheduler or reset_meters or reset_dataloader"
+ )
suffix = getattr(args, "checkpoint_suffix", "")
- if args.restore_file == "checkpoint_last.pt": # default value of restore_file is 'checkpoint_last.pt'
- checkpoint_path = os.path.join(args.save_dir, "checkpoint_last{}.pt".format(suffix))
+ if (
+ args.restore_file == "checkpoint_last.pt"
+ ): # default value of restore_file is 'checkpoint_last.pt'
+ checkpoint_path = os.path.join(
+ args.save_dir, "checkpoint_last{}.pt".format(suffix)
+ )
first_launch = not PathManager.exists(checkpoint_path)
- if getattr(args, 'finetune_from_model', None) is not None and first_launch:
+ if getattr(args, "finetune_from_model", None) is not None and first_launch:
# if there is no last checkpoint to restore, start the finetune from pretrained model
# else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
if PathManager.exists(args.finetune_from_model):
@@ -151,19 +164,26 @@ def load_checkpoint(args, trainer, **passthrough_args):
reset_lr_scheduler = True
reset_meters = True
reset_dataloader = True
- logger.info(f'loading pretrained model from {checkpoint_path}: '
- 'optimizer, lr scheduler, meters, dataloader will be reset')
+ logger.info(
+ f"loading pretrained model from {checkpoint_path}: "
+ "optimizer, lr scheduler, meters, dataloader will be reset"
+ )
else:
- raise ValueError(f'--funetune-from-model {args.finetune_from_model} does not exist')
+ raise ValueError(
+ f"--funetune-from-model {args.finetune_from_model} does not exist"
+ )
elif getattr(args, "model_parallel_size", 1) > 1:
checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
else:
checkpoint_path = args.restore_file
- if args.restore_file != "checkpoint_last.pt" and getattr(args, 'finetune_from_model', None):
+ if args.restore_file != "checkpoint_last.pt" and getattr(
+ args, "finetune_from_model", None
+ ):
raise ValueError(
- '--finetune-from-model and --restore-file (non-default value) '
- 'can not be specified together: ' + str(args))
+ "--finetune-from-model and --restore-file (non-default value) "
+ "can not be specified together: " + str(args)
+ )
extra_state = trainer.load_checkpoint(
checkpoint_path,
@@ -213,7 +233,9 @@ def load_checkpoint_to_cpu(path, arg_overrides=None):
return state
-def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, suffix='', num_shards=1):
+def load_model_ensemble(
+ filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1
+):
"""Loads an ensemble of models.
Args:
@@ -222,18 +244,28 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, s
were used during model training
task (fairseq.tasks.FairseqTask, optional): task to use for loading
"""
- assert not (strict and num_shards > 1), \
- "Cannot load state dict with strict=True and checkpoint shards > 1"
+ assert not (
+ strict and num_shards > 1
+ ), "Cannot load state dict with strict=True and checkpoint shards > 1"
ensemble, args, _task = load_model_ensemble_and_task(
- filenames, arg_overrides, task, strict, suffix, num_shards,
+ filenames,
+ arg_overrides,
+ task,
+ strict,
+ suffix,
+ num_shards,
)
return ensemble, args
-def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None, strict=True, suffix='', num_shards=1):
+def load_model_ensemble_and_task(
+ filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1
+):
from fairseq import tasks
- assert not (strict and num_shards > 1), \
- "Cannot load state dict with strict=True and checkpoint shards > 1"
+
+ assert not (
+ strict and num_shards > 1
+ ), "Cannot load state dict with strict=True and checkpoint shards > 1"
ensemble = []
for filename in filenames:
orig_filename = filename
@@ -533,7 +565,9 @@ def verify_checkpoint_directory(save_dir: str) -> None:
with open(temp_file_path, "w"):
pass
except OSError as e:
- logger.warning("Unable to access checkpoint save directory: {}".format(save_dir))
+ logger.warning(
+ "Unable to access checkpoint save directory: {}".format(save_dir)
+ )
raise e
else:
os.remove(temp_file_path)
diff --git a/fairseq/criterions/composite_loss.py b/fairseq/criterions/composite_loss.py
index 6671c696e9..65341c2d3b 100644
--- a/fairseq/criterions/composite_loss.py
+++ b/fairseq/criterions/composite_loss.py
@@ -3,13 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from torch import nn
-
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
+from torch import nn
-@register_criterion('composite_loss')
+@register_criterion("composite_loss")
class CompositeLoss(FairseqCriterion):
"""This is a composite loss that, given a list of model outputs and a list of targets,
computes an average of losses for each output-target pair"""
@@ -40,7 +39,6 @@ def build_criterion(cls, args, task):
underlying_criterion = CompositeLoss.build_underlying_criterion(args, task)
class FakeModel(nn.Module):
-
def __init__(self, model, net_out, target):
super().__init__()
self.model = model
@@ -51,7 +49,9 @@ def forward(self, **unused):
return self.net_out
def get_normalized_probs(self, net_output, log_probs, sample=None):
- return self.model.get_normalized_probs(net_output, log_probs, sample=sample)
+ return self.model.get_normalized_probs(
+ net_output, log_probs, sample=sample
+ )
def get_targets(self, *unused):
return self.target
@@ -61,14 +61,13 @@ def decoder(self):
return self.model.decoder
class _CompositeLoss(FairseqCriterion):
-
def __init__(self, task, underlying_criterion):
super().__init__(task)
self.underlying_criterion = underlying_criterion
def forward(self, model, sample, reduce=True):
- net_outputs = model(**sample['net_input'])
- targets = sample['target']
+ net_outputs = model(**sample["net_input"])
+ targets = sample["target"]
bsz = targets[0].size(0)
loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_()
@@ -77,7 +76,7 @@ def forward(self, model, sample, reduce=True):
logging_output = {}
for o, t in zip(net_outputs[0], targets):
m = FakeModel(model, (o, net_outputs[1]), t)
- sample['target'] = t
+ sample["target"] = t
l, ss, logging_output = self.underlying_criterion(m, sample, reduce)
loss += l
sample_size += ss
@@ -85,12 +84,14 @@ def forward(self, model, sample, reduce=True):
loss.div_(len(targets))
sample_size /= len(targets)
- logging_output['loss'] = utils.item(loss.data) if reduce else loss.data
+ logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
- return underlying_criterion.__class__.aggregate_logging_outputs(logging_outputs)
+ return underlying_criterion.__class__.aggregate_logging_outputs(
+ logging_outputs
+ )
@staticmethod
def reduce_metrics(logging_outputs) -> None:
diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py
index 0239a548a9..ef94a86327 100644
--- a/fairseq/criterions/fairseq_criterion.py
+++ b/fairseq/criterions/fairseq_criterion.py
@@ -6,25 +6,23 @@
import inspect
from typing import Any, Dict, List
-from torch.nn.modules.loss import _Loss
-
from fairseq import metrics, utils
from fairseq.dataclass.utils import gen_parser_from_dataclass
+from torch.nn.modules.loss import _Loss
class FairseqCriterion(_Loss):
-
def __init__(self, task):
super().__init__()
self.task = task
- if hasattr(task, 'target_dictionary'):
+ if hasattr(task, "target_dictionary"):
tgt_dict = task.target_dictionary
self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100
@classmethod
def add_args(cls, parser):
"""Add criterion-specific arguments to the parser."""
- dc = getattr(cls, '__dataclass', None)
+ dc = getattr(cls, "__dataclass", None)
if dc is not None:
gen_parser_from_dataclass(parser, dc())
@@ -43,20 +41,20 @@ def build_criterion(cls, args, task):
):
# we haven't implemented inference for these argument types,
# but PRs welcome :)
- raise NotImplementedError('{} not supported'.format(p.kind))
+ raise NotImplementedError("{} not supported".format(p.kind))
assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}
- if p.name == 'task':
- init_args['task'] = task
+ if p.name == "task":
+ init_args["task"] = task
elif hasattr(args, p.name):
init_args[p.name] = getattr(args, p.name)
elif p.default != p.empty:
pass # we'll use the default value
else:
raise NotImplementedError(
- 'Unable to infer Criterion arguments, please implement '
- '{}.build_criterion'.format(cls.__name__)
+ "Unable to infer Criterion arguments, please implement "
+ "{}.build_criterion".format(cls.__name__)
)
return cls(**init_args)
@@ -76,8 +74,8 @@ def aggregate_logging_outputs(
) -> Dict[str, Any]:
"""Aggregate logging outputs from data parallel training."""
utils.deprecation_warning(
- 'The aggregate_logging_outputs API is deprecated. '
- 'Please use the reduce_metrics API instead.'
+ "The aggregate_logging_outputs API is deprecated. "
+ "Please use the reduce_metrics API instead."
)
raise NotImplementedError
@@ -85,12 +83,12 @@ def aggregate_logging_outputs(
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
"""Aggregate logging outputs from data parallel training."""
utils.deprecation_warning(
- 'Criterions should implement the reduce_metrics API. '
- 'Falling back to deprecated aggregate_logging_outputs API.'
+ "Criterions should implement the reduce_metrics API. "
+ "Falling back to deprecated aggregate_logging_outputs API."
)
agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
for k, v in agg_logging_outputs.items():
- if k in {'nsentences', 'ntokens', 'sample_size'}:
+ if k in {"nsentences", "ntokens", "sample_size"}:
continue
metrics.log_scalar(k, v)
@@ -105,15 +103,14 @@ def logging_outputs_can_be_summed() -> bool:
class LegacyFairseqCriterion(FairseqCriterion):
-
def __init__(self, args, task):
super().__init__(task=task)
self.args = args
utils.deprecation_warning(
- 'Criterions should take explicit arguments instead of an '
- 'argparse.Namespace object, please update your criterion by '
- 'extending FairseqCriterion instead of LegacyFairseqCriterion.'
+ "Criterions should take explicit arguments instead of an "
+ "argparse.Namespace object, please update your criterion by "
+ "extending FairseqCriterion instead of LegacyFairseqCriterion."
)
@classmethod
diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/criterions/label_smoothed_cross_entropy.py
index 931a8f76d5..2dc7f7a47d 100644
--- a/fairseq/criterions/label_smoothed_cross_entropy.py
+++ b/fairseq/criterions/label_smoothed_cross_entropy.py
@@ -6,7 +6,6 @@
import math
import torch
-
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
@@ -18,8 +17,8 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=T
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
- nll_loss.masked_fill_(pad_mask, 0.)
- smooth_loss.masked_fill_(pad_mask, 0.)
+ nll_loss.masked_fill_(pad_mask, 0.0)
+ smooth_loss.masked_fill_(pad_mask, 0.0)
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
@@ -27,14 +26,20 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=T
nll_loss = nll_loss.sum()
smooth_loss = smooth_loss.sum()
eps_i = epsilon / lprobs.size(-1)
- loss = (1. - epsilon) * nll_loss + eps_i * smooth_loss
+ loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
return loss, nll_loss
-@register_criterion('label_smoothed_cross_entropy')
+@register_criterion("label_smoothed_cross_entropy")
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
- def __init__(self, task, sentence_avg, label_smoothing,
- ignore_prefix_size=0, report_accuracy=False):
+ def __init__(
+ self,
+ task,
+ sentence_avg,
+ label_smoothing,
+ ignore_prefix_size=0,
+ report_accuracy=False,
+ ):
super().__init__(task)
self.sentence_avg = sentence_avg
self.eps = label_smoothing
@@ -61,20 +66,22 @@ def forward(self, model, sample, reduce=True):
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
- net_output = model(**sample['net_input'])
+ net_output = model(**sample["net_input"])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
- sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
+ sample_size = (
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
+ )
logging_output = {
- 'loss': loss.data,
- 'nll_loss': nll_loss.data,
- 'ntokens': sample['ntokens'],
- 'nsentences': sample['target'].size(0),
- 'sample_size': sample_size,
+ "loss": loss.data,
+ "nll_loss": nll_loss.data,
+ "ntokens": sample["ntokens"],
+ "nsentences": sample["target"].size(0),
+ "sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
- logging_output['n_correct'] = utils.item(n_correct.data)
- logging_output['total'] = utils.item(total.data)
+ logging_output["n_correct"] = utils.item(n_correct.data)
+ logging_output["total"] = utils.item(total.data)
return loss, sample_size, logging_output
def get_lprobs_and_target(self, model, net_output, sample):
@@ -82,17 +89,21 @@ def get_lprobs_and_target(self, model, net_output, sample):
target = model.get_targets(sample, net_output)
if self.ignore_prefix_size > 0:
if getattr(lprobs, "batch_first", False):
- lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
- target = target[:, self.ignore_prefix_size:].contiguous()
+ lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
+ target = target[:, self.ignore_prefix_size :].contiguous()
else:
- lprobs = lprobs[self.ignore_prefix_size:, :, :].contiguous()
- target = target[self.ignore_prefix_size:, :].contiguous()
+ lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
+ target = target[self.ignore_prefix_size :, :].contiguous()
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
loss, nll_loss = label_smoothed_nll_loss(
- lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce,
+ lprobs,
+ target,
+ self.eps,
+ ignore_index=self.padding_idx,
+ reduce=reduce,
)
return loss, nll_loss
@@ -100,34 +111,43 @@ def compute_accuracy(self, model, net_output, sample):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
mask = target.ne(self.padding_idx)
n_correct = torch.sum(
- lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)))
+ lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
+ )
total = torch.sum(mask)
return n_correct, total
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
- loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
- nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
- ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
- sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+ nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
- metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
- metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3)
- metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))
+ metrics.log_scalar(
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+ )
+ metrics.log_scalar(
+ "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
+ )
+ metrics.log_derived(
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
+ )
- total = utils.item(sum(log.get('total', 0) for log in logging_outputs))
+ total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
if total > 0:
- metrics.log_scalar('total', total)
+ metrics.log_scalar("total", total)
n_correct = utils.item(
- sum(log.get('n_correct', 0) for log in logging_outputs)
+ sum(log.get("n_correct", 0) for log in logging_outputs)
)
- metrics.log_scalar('n_correct', n_correct)
+ metrics.log_scalar("n_correct", n_correct)
metrics.log_derived(
- 'accuracy',
+ "accuracy",
lambda meters: round(
- meters['n_correct'].sum * 100.0 / meters['total'].sum, 3
- ) if meters['total'].sum > 0 else float('nan'),
+ meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
+ )
+ if meters["total"].sum > 0
+ else float("nan"),
)
@staticmethod
diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
index cfc7e008cd..73cfa05310 100644
--- a/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
+++ b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
@@ -11,9 +11,10 @@
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
-@register_criterion('label_smoothed_cross_entropy_with_alignment')
-class LabelSmoothedCrossEntropyCriterionWithAlignment(LabelSmoothedCrossEntropyCriterion):
-
+@register_criterion("label_smoothed_cross_entropy_with_alignment")
+class LabelSmoothedCrossEntropyCriterionWithAlignment(
+ LabelSmoothedCrossEntropyCriterion
+):
def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda):
super().__init__(task, sentence_avg, label_smoothing)
self.alignment_lambda = alignment_lambda
@@ -22,8 +23,13 @@ def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda):
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
LabelSmoothedCrossEntropyCriterion.add_args(parser)
- parser.add_argument('--alignment-lambda', default=0.05, type=float, metavar='D',
- help='weight for the alignment loss')
+ parser.add_argument(
+ "--alignment-lambda",
+ default=0.05,
+ type=float,
+ metavar="D",
+ help="weight for the alignment loss",
+ )
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
@@ -33,41 +39,46 @@ def forward(self, model, sample, reduce=True):
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
- net_output = model(**sample['net_input'])
+ net_output = model(**sample["net_input"])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
- sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
+ sample_size = (
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
+ )
logging_output = {
- 'loss': utils.item(loss.data) if reduce else loss.data,
- 'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
- 'ntokens': sample['ntokens'],
- 'nsentences': sample['target'].size(0),
- 'sample_size': sample_size,
+ "loss": utils.item(loss.data) if reduce else loss.data,
+ "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
+ "ntokens": sample["ntokens"],
+ "nsentences": sample["target"].size(0),
+ "sample_size": sample_size,
}
alignment_loss = None
# Compute alignment loss only for training set and non dummy batches.
- if 'alignments' in sample and sample['alignments'] is not None:
+ if "alignments" in sample and sample["alignments"] is not None:
alignment_loss = self.compute_alignment_loss(sample, net_output)
if alignment_loss is not None:
- logging_output['alignment_loss'] = utils.item(alignment_loss.data)
+ logging_output["alignment_loss"] = utils.item(alignment_loss.data)
loss += self.alignment_lambda * alignment_loss
return loss, sample_size, logging_output
def compute_alignment_loss(self, sample, net_output):
- attn_prob = net_output[1]['attn'][0]
+ attn_prob = net_output[1]["attn"][0]
bsz, tgt_sz, src_sz = attn_prob.shape
attn = attn_prob.view(bsz * tgt_sz, src_sz)
- align = sample['alignments']
- align_weights = sample['align_weights'].float()
+ align = sample["alignments"]
+ align_weights = sample["align_weights"].float()
if len(align) > 0:
# Alignment loss computation. align (shape [:, 2]) contains the src-tgt index pairs corresponding to
# the alignments. align_weights (shape [:]) contains the 1 / frequency of a tgt index for normalizing.
- loss = -((attn[align[:, 1][:, None], align[:, 0][:, None]]).log() * align_weights[:, None]).sum()
+ loss = -(
+ (attn[align[:, 1][:, None], align[:, 0][:, None]]).log()
+ * align_weights[:, None]
+ ).sum()
else:
return None
@@ -76,16 +87,33 @@ def compute_alignment_loss(self, sample, net_output):
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
- loss_sum = utils.item(sum(log.get('loss', 0) for log in logging_outputs))
- nll_loss_sum = utils.item(sum(log.get('nll_loss', 0) for log in logging_outputs))
- alignment_loss_sum = utils.item(sum(log.get('alignment_loss', 0) for log in logging_outputs))
- ntokens = utils.item(sum(log.get('ntokens', 0) for log in logging_outputs))
- sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs))
-
- metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
- metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3)
- metrics.log_scalar('alignment_loss', alignment_loss_sum / sample_size / math.log(2), sample_size, round=3)
- metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))
+ loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
+ nll_loss_sum = utils.item(
+ sum(log.get("nll_loss", 0) for log in logging_outputs)
+ )
+ alignment_loss_sum = utils.item(
+ sum(log.get("alignment_loss", 0) for log in logging_outputs)
+ )
+ ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
+ sample_size = utils.item(
+ sum(log.get("sample_size", 0) for log in logging_outputs)
+ )
+
+ metrics.log_scalar(
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+ )
+ metrics.log_scalar(
+ "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
+ )
+ metrics.log_scalar(
+ "alignment_loss",
+ alignment_loss_sum / sample_size / math.log(2),
+ sample_size,
+ round=3,
+ )
+ metrics.log_derived(
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
+ )
@staticmethod
def logging_outputs_can_be_summed() -> bool:
diff --git a/fairseq/criterions/legacy_masked_lm.py b/fairseq/criterions/legacy_masked_lm.py
index 3dbfdfbe46..c70608c5a1 100644
--- a/fairseq/criterions/legacy_masked_lm.py
+++ b/fairseq/criterions/legacy_masked_lm.py
@@ -7,7 +7,6 @@
import torch
import torch.nn.functional as F
-
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
@@ -18,8 +17,9 @@ def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
ignore_index is the same as the default value for F.cross_entropy in
pytorch.
"""
- assert logits.size(0) == targets.size(-1), \
- "Logits and Targets tensor shapes don't match up"
+ assert logits.size(0) == targets.size(
+ -1
+ ), "Logits and Targets tensor shapes don't match up"
loss = F.nll_loss(
F.log_softmax(logits, -1, dtype=torch.float32),
@@ -30,7 +30,7 @@ def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
return loss
-@register_criterion('legacy_masked_lm_loss')
+@register_criterion("legacy_masked_lm_loss")
class LegacyMaskedLmLoss(FairseqCriterion):
"""
Implementation for the loss used in masked language model (MLM) training.
@@ -57,11 +57,18 @@ def __init__(self, task, masked_lm_only, nsp_loss_weight):
def add_args(parser):
"""Args for MaskedLM Loss"""
# Default for masked_lm_only is False so as to not break BERT training
- parser.add_argument('--masked-lm-only', default=False,
- action='store_true', help='compute MLM loss only')
- parser.add_argument('--nsp-loss-weight', default=1.0, type=float,
- help='weight for next sentence prediction'
- ' loss (default 1)')
+ parser.add_argument(
+ "--masked-lm-only",
+ default=False,
+ action="store_true",
+ help="compute MLM loss only",
+ )
+ parser.add_argument(
+ "--nsp-loss-weight",
+ default=1.0,
+ type=float,
+ help="weight for next sentence prediction" " loss (default 1)",
+ )
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
@@ -74,22 +81,21 @@ def forward(self, model, sample, reduce=True):
# reshape lm_logits from (N,T,C) to (N*T,C)
lm_logits = lm_logits.view(-1, lm_logits.size(-1))
- lm_targets = sample['lm_target'].view(-1)
- lm_loss = compute_cross_entropy_loss(
- lm_logits, lm_targets, self.padding_idx)
+ lm_targets = sample["lm_target"].view(-1)
+ lm_loss = compute_cross_entropy_loss(lm_logits, lm_targets, self.padding_idx)
# compute the number of tokens for which loss is computed. This is used
# to normalize the loss
ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
loss = lm_loss / ntokens
- nsentences = sample['nsentences']
+ nsentences = sample["nsentences"]
# nsentences = 0
# Compute sentence loss if masked_lm_only is False
sentence_loss = None
if not self.masked_lm_only:
- sentence_logits = output_metadata['sentence_logits']
- sentence_targets = sample['sentence_target'].view(-1)
+ sentence_logits = output_metadata["sentence_logits"]
+ sentence_targets = sample["sentence_target"].view(-1)
# This needs to be recomputed due to some differences between
# TokenBlock and BlockPair dataset. This can be resolved with a
# refactor of BERTModel which we will do in the future.
@@ -102,7 +108,8 @@ def forward(self, model, sample, reduce=True):
# refactor in the BERT model.
if sentence_logits is not None:
sentence_loss = compute_cross_entropy_loss(
- sentence_logits, sentence_targets)
+ sentence_logits, sentence_targets
+ )
loss += self.nsp_loss_weight * (sentence_loss / nsentences)
@@ -111,36 +118,54 @@ def forward(self, model, sample, reduce=True):
# here sample_size is just used for logging
sample_size = 1
logging_output = {
- 'loss': utils.item(loss.data) if reduce else loss.data,
- 'lm_loss': utils.item(lm_loss.data) if reduce else lm_loss.data,
+ "loss": utils.item(loss.data) if reduce else loss.data,
+ "lm_loss": utils.item(lm_loss.data) if reduce else lm_loss.data,
# sentence loss is not always computed
- 'sentence_loss': (
- (
- utils.item(sentence_loss.data) if reduce
- else sentence_loss.data
- ) if sentence_loss is not None else 0.0
+ "sentence_loss": (
+ (utils.item(sentence_loss.data) if reduce else sentence_loss.data)
+ if sentence_loss is not None
+ else 0.0
),
- 'ntokens': ntokens,
- 'nsentences': nsentences,
- 'sample_size': sample_size,
+ "ntokens": ntokens,
+ "nsentences": nsentences,
+ "sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
- lm_loss_sum = sum(log.get('lm_loss', 0) for log in logging_outputs)
- sentence_loss_sum = sum(
- log.get('sentence_loss', 0) for log in logging_outputs)
- ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
- nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
- sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
- agg_loss = sum(log.get('loss', 0) for log in logging_outputs)
-
- metrics.log_scalar('loss', agg_loss / sample_size / math.log(2) if sample_size > 0 else 0., sample_size, round=3)
- metrics.log_scalar('lm_loss', lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0., ntokens, round=3)
- metrics.log_scalar('sentence_loss', sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0., nsentences, round=3)
- metrics.log_scalar('nll_loss', lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0., ntokens, round=3)
+ lm_loss_sum = sum(log.get("lm_loss", 0) for log in logging_outputs)
+ sentence_loss_sum = sum(log.get("sentence_loss", 0) for log in logging_outputs)
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+ agg_loss = sum(log.get("loss", 0) for log in logging_outputs)
+
+ metrics.log_scalar(
+ "loss",
+ agg_loss / sample_size / math.log(2) if sample_size > 0 else 0.0,
+ sample_size,
+ round=3,
+ )
+ metrics.log_scalar(
+ "lm_loss",
+ lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
+ ntokens,
+ round=3,
+ )
+ metrics.log_scalar(
+ "sentence_loss",
+ sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0.0,
+ nsentences,
+ round=3,
+ )
+ metrics.log_scalar(
+ "nll_loss",
+ lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
+ ntokens,
+ round=3,
+ )
@staticmethod
def logging_outputs_can_be_summed() -> bool:
diff --git a/fairseq/criterions/masked_lm.py b/fairseq/criterions/masked_lm.py
index f62ed805f2..b04cfbff6d 100644
--- a/fairseq/criterions/masked_lm.py
+++ b/fairseq/criterions/masked_lm.py
@@ -7,12 +7,11 @@
import torch
import torch.nn.functional as F
-
from fairseq import metrics, modules, utils
from fairseq.criterions import FairseqCriterion, register_criterion
-@register_criterion('masked_lm')
+@register_criterion("masked_lm")
class MaskedLmLoss(FairseqCriterion):
"""
Implementation for the loss used in masked language model (MLM) training.
@@ -30,7 +29,7 @@ def forward(self, model, sample, reduce=True):
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
- masked_tokens = sample['target'].ne(self.padding_idx)
+ masked_tokens = sample["target"].ne(self.padding_idx)
sample_size = masked_tokens.int().sum()
# Rare: when all tokens are masked, project all tokens.
@@ -39,7 +38,7 @@ def forward(self, model, sample, reduce=True):
# (see github.com/pytorch/pytorch/issues/26247).
if self.tpu:
masked_tokens = None # always project all tokens on TPU
- elif masked_tokens.device == torch.device('cpu'):
+ elif masked_tokens.device == torch.device("cpu"):
if not masked_tokens.any():
masked_tokens = None
else:
@@ -49,7 +48,7 @@ def forward(self, model, sample, reduce=True):
masked_tokens.new([True]),
)
- logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
+ logits = model(**sample["net_input"], masked_tokens=masked_tokens)[0]
targets = model.get_targets(sample, [logits])
if masked_tokens is not None:
targets = targets[masked_tokens]
@@ -57,26 +56,30 @@ def forward(self, model, sample, reduce=True):
loss = modules.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
- reduction='sum',
+ reduction="sum",
ignore_index=self.padding_idx,
)
logging_output = {
- 'loss': loss if self.tpu else loss.data,
- 'ntokens': sample['ntokens'],
- 'nsentences': sample['nsentences'],
- 'sample_size': sample_size,
+ "loss": loss if self.tpu else loss.data,
+ "ntokens": sample["ntokens"],
+ "nsentences": sample["nsentences"],
+ "sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
- loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
- sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
- metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
- metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['loss'].avg))
+ metrics.log_scalar(
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+ )
+ metrics.log_derived(
+ "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
+ )
@staticmethod
def logging_outputs_can_be_summed() -> bool:
diff --git a/fairseq/criterions/nat_loss.py b/fairseq/criterions/nat_loss.py
index 3326734d55..cdc7da861d 100644
--- a/fairseq/criterions/nat_loss.py
+++ b/fairseq/criterions/nat_loss.py
@@ -5,17 +5,15 @@
import math
-import torch.nn.functional as F
import torch
-from torch import Tensor
-
+import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
+from torch import Tensor
@register_criterion("nat_loss")
class LabelSmoothedDualImitationCriterion(FairseqCriterion):
-
def __init__(self, task, label_smoothing):
super().__init__(task)
self.label_smoothing = label_smoothing
@@ -24,23 +22,23 @@ def __init__(self, task, label_smoothing):
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
parser.add_argument(
- '--label-smoothing',
- default=0.,
+ "--label-smoothing",
+ default=0.0,
type=float,
- metavar='D',
- help='epsilon for label smoothing, 0 means no label smoothing',
+ metavar="D",
+ help="epsilon for label smoothing, 0 means no label smoothing",
)
def _compute_loss(
self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0
):
"""
- outputs: batch x len x d_model
- targets: batch x len
- masks: batch x len
+ outputs: batch x len x d_model
+ targets: batch x len
+ masks: batch x len
- policy_logprob: if there is some policy
- depends on the likelihood score as rewards.
+ policy_logprob: if there is some policy
+ depends on the likelihood score as rewards.
"""
def mean_ds(x: Tensor, dim=None) -> Tensor:
@@ -49,6 +47,7 @@ def mean_ds(x: Tensor, dim=None) -> Tensor:
if dim is None
else x.float().mean(dim).type_as(x)
)
+
if masks is not None:
outputs, targets = outputs[masks], targets[masks]
@@ -58,16 +57,17 @@ def mean_ds(x: Tensor, dim=None) -> Tensor:
else:
logits = F.log_softmax(outputs, dim=-1)
if targets.dim() == 1:
- losses = F.nll_loss(logits, targets.to(logits.device), reduction='none')
+ losses = F.nll_loss(logits, targets.to(logits.device), reduction="none")
else: # soft-labels
- losses = F.kl_div(logits, targets.to(logits.device), reduction='none')
+ losses = F.kl_div(logits, targets.to(logits.device), reduction="none")
losses = losses.sum(-1)
nll_loss = mean_ds(losses)
if label_smoothing > 0:
- loss = nll_loss * (
- 1 - label_smoothing) - mean_ds(logits) * label_smoothing
+ loss = (
+ nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing
+ )
else:
loss = nll_loss
@@ -103,14 +103,14 @@ def forward(self, model, sample, reduce=True):
outputs[obj].get("tgt"),
outputs[obj].get("mask", None),
outputs[obj].get("ls", 0.0),
- name=obj + '-loss',
- factor=outputs[obj].get("factor", 1.0)
+ name=obj + "-loss",
+ factor=outputs[obj].get("factor", 1.0),
)
else:
_losses = self._custom_loss(
outputs[obj].get("loss"),
- name=obj + '-loss',
- factor=outputs[obj].get("factor", 1.0)
+ name=obj + "-loss",
+ factor=outputs[obj].get("factor", 1.0),
)
losses += [_losses]
@@ -118,8 +118,7 @@ def forward(self, model, sample, reduce=True):
nll_loss += [_losses.get("nll_loss", 0.0)]
loss = sum(l["loss"] for l in losses)
- nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 \
- else loss.new_tensor(0)
+ nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 else loss.new_tensor(0)
# NOTE:
# we don't need to use sample_size as denominator for the gradient
@@ -145,13 +144,21 @@ def forward(self, model, sample, reduce=True):
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
- sample_size = utils.item(sum(log.get("sample_size", 0) for log in logging_outputs))
+ sample_size = utils.item(
+ sum(log.get("sample_size", 0) for log in logging_outputs)
+ )
loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
nll_loss = utils.item(sum(log.get("nll_loss", 0) for log in logging_outputs))
- metrics.log_scalar('loss', loss / sample_size / math.log(2), sample_size, round=3)
- metrics.log_scalar('nll_loss', nll_loss / sample_size / math.log(2), sample_size, round=3)
- metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['loss'].avg))
+ metrics.log_scalar(
+ "loss", loss / sample_size / math.log(2), sample_size, round=3
+ )
+ metrics.log_scalar(
+ "nll_loss", nll_loss / sample_size / math.log(2), sample_size, round=3
+ )
+ metrics.log_derived(
+ "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
+ )
for key in logging_outputs[0]:
if key[-5:] == "-loss":
diff --git a/fairseq/criterions/sentence_prediction.py b/fairseq/criterions/sentence_prediction.py
index 4ba1317856..9519fdc56d 100644
--- a/fairseq/criterions/sentence_prediction.py
+++ b/fairseq/criterions/sentence_prediction.py
@@ -7,14 +7,12 @@
import torch
import torch.nn.functional as F
-
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
-@register_criterion('sentence_prediction')
+@register_criterion("sentence_prediction")
class SentencePredictionCriterion(FairseqCriterion):
-
def __init__(self, task, classification_head_name, regression_target):
super().__init__(task)
self.classification_head_name = classification_head_name
@@ -37,12 +35,12 @@ def forward(self, model, sample, reduce=True):
3) logging outputs to display while training
"""
assert (
- hasattr(model, 'classification_heads')
+ hasattr(model, "classification_heads")
and self.classification_head_name in model.classification_heads
- ), 'model must provide sentence classification head for --criterion=sentence_prediction'
+ ), "model must provide sentence classification head for --criterion=sentence_prediction"
logits, _ = model(
- **sample['net_input'],
+ **sample["net_input"],
features_only=True,
classification_head_name=self.classification_head_name,
)
@@ -51,39 +49,45 @@ def forward(self, model, sample, reduce=True):
if not self.regression_target:
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
- loss = F.nll_loss(lprobs, targets, reduction='sum')
+ loss = F.nll_loss(lprobs, targets, reduction="sum")
else:
logits = logits.view(-1).float()
targets = targets.float()
- loss = F.mse_loss(logits, targets, reduction='sum')
+ loss = F.mse_loss(logits, targets, reduction="sum")
logging_output = {
- 'loss': loss.data,
- 'ntokens': sample['ntokens'],
- 'nsentences': sample_size,
- 'sample_size': sample_size,
+ "loss": loss.data,
+ "ntokens": sample["ntokens"],
+ "nsentences": sample_size,
+ "sample_size": sample_size,
}
if not self.regression_target:
preds = logits.argmax(dim=1)
- logging_output['ncorrect'] = (preds == targets).sum()
+ logging_output["ncorrect"] = (preds == targets).sum()
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
- loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
- ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
- nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
- sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
- metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
+ metrics.log_scalar(
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+ )
if sample_size != ntokens:
- metrics.log_scalar('nll_loss', loss_sum / ntokens / math.log(2), ntokens, round=3)
-
- if len(logging_outputs) > 0 and 'ncorrect' in logging_outputs[0]:
- ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
- metrics.log_scalar('accuracy', 100.0 * ncorrect / nsentences, nsentences, round=1)
+ metrics.log_scalar(
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
+ )
+
+ if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
+ ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
+ metrics.log_scalar(
+ "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
+ )
@staticmethod
def logging_outputs_can_be_summed() -> bool:
diff --git a/fairseq/criterions/sentence_ranking.py b/fairseq/criterions/sentence_ranking.py
index 52a0a177d8..d4c76341d4 100644
--- a/fairseq/criterions/sentence_ranking.py
+++ b/fairseq/criterions/sentence_ranking.py
@@ -7,19 +7,17 @@
import torch
import torch.nn.functional as F
-
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
-@register_criterion('sentence_ranking')
+@register_criterion("sentence_ranking")
class SentenceRankingCriterion(FairseqCriterion):
-
def __init__(self, task, ranking_head_name, save_predictions, num_classes):
super().__init__(task)
self.ranking_head_name = ranking_head_name
if save_predictions is not None:
- self.prediction_h = open(save_predictions, 'w')
+ self.prediction_h = open(save_predictions, "w")
else:
self.prediction_h = None
self.num_classes = num_classes
@@ -47,14 +45,14 @@ def forward(self, model, sample, reduce=True):
3) logging outputs to display while training
"""
assert (
- hasattr(model, 'classification_heads')
+ hasattr(model, "classification_heads")
and self.ranking_head_name in model.classification_heads
- ), 'model must provide sentence ranking head for --criterion=sentence_ranking'
+ ), "model must provide sentence ranking head for --criterion=sentence_ranking"
scores = []
for idx in range(self.num_classes):
score, _ = model(
- **sample['net_input{idx}'.format(idx=idx+1)],
+ **sample["net_input{idx}".format(idx=idx + 1)],
classification_head_name=self.ranking_head_name,
)
scores.append(score)
@@ -62,49 +60,55 @@ def forward(self, model, sample, reduce=True):
logits = torch.cat(scores, dim=1)
sample_size = logits.size(0)
- if 'target' in sample:
+ if "target" in sample:
targets = model.get_targets(sample, [logits]).view(-1)
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
- loss = F.nll_loss(lprobs, targets, reduction='sum')
+ loss = F.nll_loss(lprobs, targets, reduction="sum")
else:
targets = None
loss = torch.tensor(0.0, requires_grad=True)
if self.prediction_h is not None:
preds = logits.argmax(dim=1)
- for i, (id, pred) in enumerate(zip(sample['id'].tolist(), preds.tolist())):
+ for i, (id, pred) in enumerate(zip(sample["id"].tolist(), preds.tolist())):
if targets is not None:
label = targets[i].item()
- print('{}\t{}\t{}'.format(id, pred, label), file=self.prediction_h)
+ print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)
else:
- print('{}\t{}'.format(id, pred), file=self.prediction_h)
+ print("{}\t{}".format(id, pred), file=self.prediction_h)
logging_output = {
- 'loss': loss.data,
- 'ntokens': sample['ntokens'],
- 'nsentences': sample_size,
- 'sample_size': sample_size,
+ "loss": loss.data,
+ "ntokens": sample["ntokens"],
+ "nsentences": sample_size,
+ "sample_size": sample_size,
}
if targets is not None:
- logging_output['ncorrect'] = (logits.argmax(dim=1) == targets).sum()
+ logging_output["ncorrect"] = (logits.argmax(dim=1) == targets).sum()
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
- loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
- ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
- nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
- sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
-
- metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+
+ metrics.log_scalar(
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+ )
if sample_size != ntokens:
- metrics.log_scalar('nll_loss', loss_sum / ntokens / math.log(2), ntokens, round=3)
+ metrics.log_scalar(
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
+ )
- if len(logging_outputs) > 0 and 'ncorrect' in logging_outputs[0]:
- ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
- metrics.log_scalar('accuracy', 100.0 * ncorrect / nsentences, nsentences, round=1)
+ if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
+ ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
+ metrics.log_scalar(
+ "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
+ )
@staticmethod
def logging_outputs_can_be_summed() -> bool:
diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py
index cc743524d2..6ac7557dcc 100644
--- a/fairseq/criterions/wav2vec_criterion.py
+++ b/fairseq/criterions/wav2vec_criterion.py
@@ -7,15 +7,13 @@
import torch
import torch.nn.functional as F
-
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.logging.meters import safe_round
-@register_criterion('wav2vec')
+@register_criterion("wav2vec")
class Wav2vecCriterion(FairseqCriterion):
-
def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
super().__init__(task)
self.infonce = infonce
@@ -42,12 +40,12 @@ def forward(self, model, sample, reduce=True, log_pred=False):
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
- net_output = model(**sample['net_input'])
+ net_output = model(**sample["net_input"])
logits = model.get_logits(net_output).float()
target = model.get_targets(sample, net_output)
weights = None
- if hasattr(model, 'get_target_weights') and not self.infonce:
+ if hasattr(model, "get_target_weights") and not self.infonce:
weights = model.get_target_weights(target, net_output)
if torch.is_tensor(weights):
weights = weights.float()
@@ -55,9 +53,18 @@ def forward(self, model, sample, reduce=True, log_pred=False):
losses = []
if self.infonce:
- loss = F.cross_entropy(logits, target, reduction="sum" if reduce else "none",)
+ loss = F.cross_entropy(
+ logits,
+ target,
+ reduction="sum" if reduce else "none",
+ )
else:
- loss = F.binary_cross_entropy_with_logits(logits, target.float(), weights, reduction="sum" if reduce else "none",)
+ loss = F.binary_cross_entropy_with_logits(
+ logits,
+ target.float(),
+ weights,
+ reduction="sum" if reduce else "none",
+ )
sample_size = target.numel() if self.infonce else target.long().sum().item()
losses.append(loss.detach().clone())
@@ -69,7 +76,9 @@ def forward(self, model, sample, reduce=True, log_pred=False):
extra_losses = [extra_losses]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
- assert len(extra_losses) == len(self.loss_weights), f'{len(extra_losses)}, {len(self.loss_weights)}'
+ assert len(extra_losses) == len(
+ self.loss_weights
+ ), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, coef in zip(extra_losses, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
@@ -77,10 +86,10 @@ def forward(self, model, sample, reduce=True, log_pred=False):
losses.append(p)
logging_output = {
- 'loss': loss.item() if reduce else loss,
- 'ntokens': sample_size,
- 'nsentences': sample['id'].numel(),
- 'sample_size': sample_size,
+ "loss": loss.item() if reduce else loss,
+ "ntokens": sample_size,
+ "nsentences": sample["id"].numel(),
+ "sample_size": sample_size,
}
for lk in self.log_keys:
@@ -89,7 +98,7 @@ def forward(self, model, sample, reduce=True, log_pred=False):
if len(losses) > 1:
for i, l in enumerate(losses):
- logging_output[f'loss_{i}'] = l.item()
+ logging_output[f"loss_{i}"] = l.item()
if self.infonce:
with torch.no_grad():
@@ -108,21 +117,27 @@ def forward(self, model, sample, reduce=True, log_pred=False):
logging_output["count"] = count
if log_pred:
- logging_output['logits'] = logits.cpu().numpy()
- logging_output['target'] = target.cpu().numpy()
+ logging_output["logits"] = logits.cpu().numpy()
+ logging_output["target"] = target.cpu().numpy()
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
- loss_sum = utils.item(sum(log.get('loss', 0) for log in logging_outputs))
- ntokens = utils.item(sum(log.get('ntokens', 0) for log in logging_outputs))
- nsentences = utils.item(sum(log.get('nsentences', 0) for log in logging_outputs))
- sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs))
-
- metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
- metrics.log_scalar('ntokens', ntokens)
- metrics.log_scalar('nsentences', nsentences)
+ loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
+ ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
+ nsentences = utils.item(
+ sum(log.get("nsentences", 0) for log in logging_outputs)
+ )
+ sample_size = utils.item(
+ sum(log.get("sample_size", 0) for log in logging_outputs)
+ )
+
+ metrics.log_scalar(
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+ )
+ metrics.log_scalar("ntokens", ntokens)
+ metrics.log_scalar("nsentences", nsentences)
correct = sum(log.get("correct", 0) for log in logging_outputs)
metrics.log_scalar("_correct", correct)
@@ -130,21 +145,31 @@ def reduce_metrics(logging_outputs) -> None:
total = sum(log.get("count", 0) for log in logging_outputs)
metrics.log_scalar("_total", total)
-
if total > 0:
metrics.log_derived(
"accuracy",
- lambda meters: safe_round(meters["_correct"].sum / meters["_total"].sum, 5)
+ lambda meters: safe_round(
+ meters["_correct"].sum / meters["_total"].sum, 5
+ )
if meters["_total"].sum > 0
else float("nan"),
)
- builtin_keys = {'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'}
+ builtin_keys = {
+ "loss",
+ "ntokens",
+ "nsentences",
+ "sample_size",
+ "correct",
+ "count",
+ }
for k in logging_outputs[0]:
if k not in builtin_keys:
- val = sum(log.get(k, 0) for log in logging_outputs) / len(logging_outputs)
- if k.startswith('loss'):
+ val = sum(log.get(k, 0) for log in logging_outputs) / len(
+ logging_outputs
+ )
+ if k.startswith("loss"):
metrics.log_scalar(k, val / sample_size / math.log(2), sample_size)
else:
metrics.log_scalar(k, val, round=3)
diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py
index 785a0aa643..9b30813955 100644
--- a/fairseq/data/__init__.py
+++ b/fairseq/data/__init__.py
@@ -20,7 +20,12 @@
from .concat_sentences_dataset import ConcatSentencesDataset
from .denoising_dataset import DenoisingDataset
from .id_dataset import IdDataset
-from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset
+from .indexed_dataset import (
+ IndexedCachedDataset,
+ IndexedDataset,
+ IndexedRawTextDataset,
+ MMapIndexedDataset,
+)
from .language_pair_dataset import LanguagePairDataset
from .list_dataset import ListDataset
from .lm_context_window_dataset import LMContextWindowDataset
@@ -60,60 +65,60 @@
)
__all__ = [
- 'AddTargetDataset',
- 'AppendTokenDataset',
- 'BacktranslationDataset',
- 'BaseWrapperDataset',
- 'BucketPadLengthDataset',
- 'ColorizeDataset',
- 'ConcatDataset',
- 'ConcatSentencesDataset',
- 'CountingIterator',
- 'DenoisingDataset',
- 'Dictionary',
- 'EncodedFastaDataset',
- 'EpochBatchIterator',
- 'FairseqDataset',
- 'FairseqIterableDataset',
- 'FastaDataset',
- 'GroupedIterator',
- 'IdDataset',
- 'IndexedCachedDataset',
- 'IndexedDataset',
- 'IndexedRawTextDataset',
- 'LanguagePairDataset',
- 'LeftPadDataset',
- 'ListDataset',
- 'LMContextWindowDataset',
- 'LRUCacheDataset',
- 'MaskTokensDataset',
- 'MMapIndexedDataset',
- 'MonolingualDataset',
- 'MultiCorpusSampledDataset',
- 'NestedDictionaryDataset',
- 'NoisingDataset',
- 'NumelDataset',
- 'NumSamplesDataset',
- 'OffsetTokensDataset',
- 'PadDataset',
- 'PrependDataset',
- 'PrependTokenDataset',
- 'ReplaceDataset',
- 'RollDataset',
- 'FileAudioDataset',
- 'RawLabelDataset',
- 'ResamplingDataset',
- 'RightPadDataset',
- 'RoundRobinZipDatasets',
- 'SampledMultiDataset',
- 'SampledMultiEpochDataset',
- 'ShardedIterator',
- 'SortDataset',
- 'StripTokenDataset',
- 'SubsampleDataset',
- 'TokenBlockDataset',
- 'TransformEosDataset',
- 'TransformEosLangPairDataset',
- 'TruncateDataset',
- 'TruncatedDictionary',
+ "AddTargetDataset",
+ "AppendTokenDataset",
+ "BacktranslationDataset",
+ "BaseWrapperDataset",
+ "BucketPadLengthDataset",
+ "ColorizeDataset",
+ "ConcatDataset",
+ "ConcatSentencesDataset",
+ "CountingIterator",
+ "DenoisingDataset",
+ "Dictionary",
+ "EncodedFastaDataset",
+ "EpochBatchIterator",
+ "FairseqDataset",
+ "FairseqIterableDataset",
+ "FastaDataset",
+ "GroupedIterator",
+ "IdDataset",
+ "IndexedCachedDataset",
+ "IndexedDataset",
+ "IndexedRawTextDataset",
+ "LanguagePairDataset",
+ "LeftPadDataset",
+ "ListDataset",
+ "LMContextWindowDataset",
+ "LRUCacheDataset",
+ "MaskTokensDataset",
+ "MMapIndexedDataset",
+ "MonolingualDataset",
+ "MultiCorpusSampledDataset",
+ "NestedDictionaryDataset",
+ "NoisingDataset",
+ "NumelDataset",
+ "NumSamplesDataset",
+ "OffsetTokensDataset",
+ "PadDataset",
+ "PrependDataset",
+ "PrependTokenDataset",
+ "ReplaceDataset",
+ "RollDataset",
+ "FileAudioDataset",
+ "RawLabelDataset",
+ "ResamplingDataset",
+ "RightPadDataset",
+ "RoundRobinZipDatasets",
+ "SampledMultiDataset",
+ "SampledMultiEpochDataset",
+ "ShardedIterator",
+ "SortDataset",
+ "StripTokenDataset",
+ "SubsampleDataset",
+ "TokenBlockDataset",
+ "TransformEosDataset",
+ "TransformEosLangPairDataset",
+ "TruncateDataset",
+ "TruncatedDictionary",
]
diff --git a/fairseq/data/add_target_dataset.py b/fairseq/data/add_target_dataset.py
index 3a42dd7a2e..9ef467058b 100644
--- a/fairseq/data/add_target_dataset.py
+++ b/fairseq/data/add_target_dataset.py
@@ -5,12 +5,20 @@
import torch
-from . import BaseWrapperDataset
-from . import data_utils
+from . import BaseWrapperDataset, data_utils
class AddTargetDataset(BaseWrapperDataset):
- def __init__(self, dataset, labels, pad, eos, batch_targets, process_label=None, add_to_input=False):
+ def __init__(
+ self,
+ dataset,
+ labels,
+ pad,
+ eos,
+ batch_targets,
+ process_label=None,
+ add_to_input=False,
+ ):
super().__init__(dataset)
self.labels = labels
self.batch_targets = batch_targets
@@ -20,7 +28,11 @@ def __init__(self, dataset, labels, pad, eos, batch_targets, process_label=None,
self.add_to_input = add_to_input
def get_label(self, index):
- return self.labels[index] if self.process_label is None else self.process_label(self.labels[index])
+ return (
+ self.labels[index]
+ if self.process_label is None
+ else self.process_label(self.labels[index])
+ )
def __getitem__(self, index):
item = self.dataset[index]
@@ -51,6 +63,8 @@ def collater(self, samples):
if self.add_to_input:
eos = target.new_full((target.size(0), 1), self.eos)
collated["target"] = torch.cat([target, eos], dim=-1).long()
- collated["net_input"]["prev_output_tokens"] = torch.cat([eos, target], dim=-1).long()
+ collated["net_input"]["prev_output_tokens"] = torch.cat(
+ [eos, target], dim=-1
+ ).long()
collated["ntokens"] += target.size(0)
- return collated
\ No newline at end of file
+ return collated
diff --git a/fairseq/data/append_token_dataset.py b/fairseq/data/append_token_dataset.py
index 7298129f62..87695bd0f5 100644
--- a/fairseq/data/append_token_dataset.py
+++ b/fairseq/data/append_token_dataset.py
@@ -10,7 +10,6 @@
class AppendTokenDataset(BaseWrapperDataset):
-
def __init__(self, dataset, token=None):
super().__init__(dataset)
self.token = token
diff --git a/fairseq/data/audio/audio_utils.py b/fairseq/data/audio/audio_utils.py
index 3731721953..de08669851 100644
--- a/fairseq/data/audio/audio_utils.py
+++ b/fairseq/data/audio/audio_utils.py
@@ -1,11 +1,11 @@
import os.path as op
-from typing import Union, BinaryIO, Optional, Tuple
+from typing import BinaryIO, Optional, Tuple, Union
import numpy as np
def get_waveform(
- path_or_fp: Union[str, BinaryIO], normalization=True
+ path_or_fp: Union[str, BinaryIO], normalization=True
) -> Tuple[np.ndarray, int]:
"""Get the waveform and sample rate of a 16-bit mono-channel WAV or FLAC.
@@ -15,15 +15,15 @@ def get_waveform(
"""
if isinstance(path_or_fp, str):
ext = op.splitext(op.basename(path_or_fp))[1]
- if ext not in {'.flac', '.wav'}:
- raise ValueError(f'Unsupported audio format: {ext}')
+ if ext not in {".flac", ".wav"}:
+ raise ValueError(f"Unsupported audio format: {ext}")
try:
import soundfile as sf
except ImportError:
- raise ImportError('Please install soundfile to load WAV/FLAC file')
+ raise ImportError("Please install soundfile to load WAV/FLAC file")
- waveform, sample_rate = sf.read(path_or_fp, dtype='float32')
+ waveform, sample_rate = sf.read(path_or_fp, dtype="float32")
if not normalization:
waveform *= 2 ** 15 # denormalized to 16-bit signed integers
return waveform, sample_rate
@@ -56,9 +56,11 @@ def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarr
try:
import torch
import torchaudio.compliance.kaldi as ta_kaldi
+
waveform = torch.from_numpy(waveform).unsqueeze(0)
- features = ta_kaldi.fbank(waveform, num_mel_bins=n_bins,
- sample_frequency=sample_rate)
+ features = ta_kaldi.fbank(
+ waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
+ )
return features.numpy()
except ImportError:
return None
@@ -75,7 +77,9 @@ def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray:
if features is None:
features = _get_torchaudio_fbank(sound, sample_rate, n_bins)
if features is None:
- raise ImportError('Please install pyKaldi or torchaudio to enable '
- 'online filterbank feature extraction')
+ raise ImportError(
+ "Please install pyKaldi or torchaudio to enable "
+ "online filterbank feature extraction"
+ )
return features
diff --git a/fairseq/data/audio/feature_transforms/__init__.py b/fairseq/data/audio/feature_transforms/__init__.py
index 399956a33b..359fa06971 100644
--- a/fairseq/data/audio/feature_transforms/__init__.py
+++ b/fairseq/data/audio/feature_transforms/__init__.py
@@ -1,7 +1,7 @@
import importlib
import os
-from typing import Optional, Dict
from abc import ABC, abstractmethod
+from typing import Dict, Optional
class AudioFeatureTransform(ABC):
@@ -18,14 +18,16 @@ def from_config_dict(cls, config: Optional[Dict] = None):
def register_audio_feature_transform(name):
def register_audio_feature_transform_cls(cls):
if name in AUDIO_FEATURE_TRANSFORM_REGISTRY:
- raise ValueError(f'Cannot register duplicate transform ({name})')
+ raise ValueError(f"Cannot register duplicate transform ({name})")
if not issubclass(cls, AudioFeatureTransform):
- raise ValueError(f'Transform ({name}: {cls.__name__}) must extend '
- 'AudioFeatureTransform')
+ raise ValueError(
+ f"Transform ({name}: {cls.__name__}) must extend "
+ "AudioFeatureTransform"
+ )
if cls.__name__ in AUDIO_FEATURE_TRANSFORM_CLASS_NAMES:
raise ValueError(
- f'Cannot register audio feature transform with duplicate '
- f'class name ({cls.__name__})'
+ f"Cannot register audio feature transform with duplicate "
+ f"class name ({cls.__name__})"
)
AUDIO_FEATURE_TRANSFORM_REGISTRY[name] = cls
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES.add(cls.__name__)
@@ -42,19 +44,19 @@ def get_audio_feature_transform(name):
for file in os.listdir(transforms_dir):
path = os.path.join(transforms_dir, file)
if (
- not file.startswith('_')
- and not file.startswith('.')
- and (file.endswith('.py') or os.path.isdir(path))
+ not file.startswith("_")
+ and not file.startswith(".")
+ and (file.endswith(".py") or os.path.isdir(path))
):
- name = file[:file.find('.py')] if file.endswith('.py') else file
- importlib.import_module('fairseq.data.audio.feature_transforms.' + name)
+ name = file[: file.find(".py")] if file.endswith(".py") else file
+ importlib.import_module("fairseq.data.audio.feature_transforms." + name)
class CompositeAudioFeatureTransform(AudioFeatureTransform):
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
- _transforms = _config.get('transforms')
+ _transforms = _config.get("transforms")
if _transforms is None:
return None
transforms = [
@@ -72,6 +74,9 @@ def __call__(self, x):
return x
def __repr__(self):
- format_string = [self.__class__.__name__ + '('] + \
- [f" {t.__repr__()}" for t in self.transforms] + [')']
- return '\n'.join(format_string)
+ format_string = (
+ [self.__class__.__name__ + "("]
+ + [f" {t.__repr__()}" for t in self.transforms]
+ + [")"]
+ )
+ return "\n".join(format_string)
diff --git a/fairseq/data/audio/feature_transforms/global_cmvn.py b/fairseq/data/audio/feature_transforms/global_cmvn.py
index f9c92a66b1..d512fed300 100644
--- a/fairseq/data/audio/feature_transforms/global_cmvn.py
+++ b/fairseq/data/audio/feature_transforms/global_cmvn.py
@@ -1,10 +1,11 @@
import numpy as np
from fairseq.data.audio.feature_transforms import (
- AudioFeatureTransform, register_audio_feature_transform
+ AudioFeatureTransform,
+ register_audio_feature_transform,
)
-@register_audio_feature_transform('global_cmvn')
+@register_audio_feature_transform("global_cmvn")
class GlobalCMVN(AudioFeatureTransform):
"""Global CMVN (cepstral mean and variance normalization). The global mean
and variance need to be pre-computed and stored in NumPy format (.npz)."""
@@ -12,11 +13,11 @@ class GlobalCMVN(AudioFeatureTransform):
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
- return GlobalCMVN(_config.get('stats_npz_path'))
+ return GlobalCMVN(_config.get("stats_npz_path"))
def __init__(self, stats_npz_path):
stats = np.load(stats_npz_path)
- self.mean, self.std = stats['mean'], stats['std']
+ self.mean, self.std = stats["mean"], stats["std"]
def __call__(self, x):
x = np.subtract(x, self.mean)
diff --git a/fairseq/data/audio/feature_transforms/specaugment.py b/fairseq/data/audio/feature_transforms/specaugment.py
index e4c36bde3c..2ef4778b85 100644
--- a/fairseq/data/audio/feature_transforms/specaugment.py
+++ b/fairseq/data/audio/feature_transforms/specaugment.py
@@ -3,13 +3,13 @@
from typing import Optional
import numpy as np
-
from fairseq.data.audio.feature_transforms import (
- AudioFeatureTransform, register_audio_feature_transform
+ AudioFeatureTransform,
+ register_audio_feature_transform,
)
-@register_audio_feature_transform('specaugment')
+@register_audio_feature_transform("specaugment")
class SpecAugmentTransform(AudioFeatureTransform):
"""SpecAugment (https://arxiv.org/abs/1904.08779)"""
@@ -17,13 +17,13 @@ class SpecAugmentTransform(AudioFeatureTransform):
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return SpecAugmentTransform(
- _config.get('time_warp_W', 0),
- _config.get('freq_mask_N', 0),
- _config.get('freq_mask_F', 0),
- _config.get('time_mask_N', 0),
- _config.get('time_mask_T', 0),
- _config.get('time_mask_p', 0.0),
- _config.get('mask_value', None),
+ _config.get("time_warp_W", 0),
+ _config.get("freq_mask_N", 0),
+ _config.get("freq_mask_F", 0),
+ _config.get("time_mask_N", 0),
+ _config.get("time_mask_T", 0),
+ _config.get("time_mask_p", 0.0),
+ _config.get("mask_value", None),
)
def __init__(
@@ -41,15 +41,15 @@ def __init__(
mask_value, numbers.Number
), f"mask_value (type: {type(mask_value)}) must be None or a number"
if freq_mask_n > 0:
- assert (
- freq_mask_f > 0
- ), f"freq_mask_F ({freq_mask_f}) " \
- f"must be larger than 0 when doing freq masking."
+ assert freq_mask_f > 0, (
+ f"freq_mask_F ({freq_mask_f}) "
+ f"must be larger than 0 when doing freq masking."
+ )
if time_mask_n > 0:
- assert (
- time_mask_t > 0
- ), f"time_mask_T ({time_mask_t}) must be larger than 0 when " \
- f"doing time masking."
+ assert time_mask_t > 0, (
+ f"time_mask_T ({time_mask_t}) must be larger than 0 when "
+ f"doing time masking."
+ )
self.time_warp_w = time_warp_w
self.freq_mask_n = freq_mask_n
@@ -60,14 +60,21 @@ def __init__(
self.mask_value = mask_value
def __repr__(self):
- return self.__class__.__name__ + '(' + ', '.join(
- [f'time_warp_w={self.time_warp_w}',
- f'freq_mask_n={self.freq_mask_n}',
- f'freq_mask_f={self.freq_mask_f}',
- f'time_mask_n={self.time_mask_n}',
- f'time_mask_t={self.time_mask_t}',
- f'time_mask_p={self.time_mask_p}']
- ) + ')'
+ return (
+ self.__class__.__name__
+ + "("
+ + ", ".join(
+ [
+ f"time_warp_w={self.time_warp_w}",
+ f"freq_mask_n={self.freq_mask_n}",
+ f"freq_mask_f={self.freq_mask_f}",
+ f"time_mask_n={self.time_mask_n}",
+ f"time_mask_t={self.time_mask_t}",
+ f"time_mask_p={self.time_mask_p}",
+ ]
+ )
+ + ")"
+ )
def __call__(self, spectrogram):
assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
@@ -89,14 +96,12 @@ def __call__(self, spectrogram):
if self.time_warp_w > 0:
if 2 * self.time_warp_w < num_frames:
import cv2
- w0 = np.random.randint(
- self.time_warp_w, num_frames - self.time_warp_w
- )
+
+ w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
w = np.random.randint(0, self.time_warp_w)
upper, lower = distorted[:w0, :], distorted[w0:, :]
upper = cv2.resize(
- upper, dsize=(num_freqs, w0 + w),
- interpolation=cv2.INTER_LINEAR
+ upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
)
lower = cv2.resize(
lower,
@@ -109,7 +114,7 @@ def __call__(self, spectrogram):
f = np.random.randint(0, self.freq_mask_f)
f0 = np.random.randint(0, num_freqs - f)
if f != 0:
- distorted[:, f0: f0 + f] = mask_value
+ distorted[:, f0 : f0 + f] = mask_value
max_time_mask_t = min(
self.time_mask_t, math.floor(num_frames * self.time_mask_p)
@@ -121,6 +126,6 @@ def __call__(self, spectrogram):
t = np.random.randint(0, max_time_mask_t)
t0 = np.random.randint(0, num_frames - t)
if t != 0:
- distorted[t0: t0 + t, :] = mask_value
+ distorted[t0 : t0 + t, :] = mask_value
return distorted
diff --git a/fairseq/data/audio/feature_transforms/utterance_cmvn.py b/fairseq/data/audio/feature_transforms/utterance_cmvn.py
index cbedd360d0..6bbd0ae821 100644
--- a/fairseq/data/audio/feature_transforms/utterance_cmvn.py
+++ b/fairseq/data/audio/feature_transforms/utterance_cmvn.py
@@ -1,11 +1,11 @@
import numpy as np
-
from fairseq.data.audio.feature_transforms import (
- AudioFeatureTransform, register_audio_feature_transform
+ AudioFeatureTransform,
+ register_audio_feature_transform,
)
-@register_audio_feature_transform('utterance_cmvn')
+@register_audio_feature_transform("utterance_cmvn")
class UtteranceCMVN(AudioFeatureTransform):
"""Utterance-level CMVN (cepstral mean and variance normalization)"""
@@ -13,16 +13,18 @@ class UtteranceCMVN(AudioFeatureTransform):
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return UtteranceCMVN(
- _config.get('norm_means', True),
- _config.get('norm_vars', True),
+ _config.get("norm_means", True),
+ _config.get("norm_vars", True),
)
def __init__(self, norm_means=True, norm_vars=True):
self.norm_means, self.norm_vars = norm_means, norm_vars
def __repr__(self):
- return self.__class__.__name__ + \
- f'(norm_means={self.norm_means}, norm_vars={self.norm_vars})'
+ return (
+ self.__class__.__name__
+ + f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
+ )
def __call__(self, x):
mean = x.mean(axis=0)
diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py
index 09838a54e0..8d6ce85ecc 100644
--- a/fairseq/data/audio/raw_audio_dataset.py
+++ b/fairseq/data/audio/raw_audio_dataset.py
@@ -4,16 +4,17 @@
# LICENSE file in the root directory of this source tree.
-import os
import logging
-import numpy as np
+import os
import sys
+import numpy as np
import torch
import torch.nn.functional as F
from .. import FairseqDataset
+
logger = logging.getLogger(__name__)
@@ -72,11 +73,7 @@ def crop_to_max_size(self, wav, target_size):
return wav[start:end]
def collater(self, samples):
- samples = [
- s
- for s in samples
- if s["source"] is not None
- ]
+ samples = [s for s in samples if s["source"] is not None]
if len(samples) == 0:
return {}
diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py
index df360b2c74..aefe95658d 100644
--- a/fairseq/data/audio/speech_to_text_dataset.py
+++ b/fairseq/data/audio/speech_to_text_dataset.py
@@ -3,54 +3,61 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import logging
-import re
-from typing import List, Tuple, Optional, Dict
-import os.path as op
import csv
import io
+import logging
+import os.path as op
+import re
+from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
-from fairseq.data import (FairseqDataset, Dictionary, ResamplingDataset,
- ConcatDataset, data_utils as fairseq_data_utils)
+from fairseq.data import (
+ ConcatDataset,
+ Dictionary,
+ FairseqDataset,
+ ResamplingDataset,
+ data_utils as fairseq_data_utils,
+)
from fairseq.data.audio.audio_utils import get_fbank, get_waveform
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
+
logging.basicConfig(
- format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
- datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO,
)
logger = logging.getLogger(__name__)
class S2TDataConfig(object):
"""Wrapper class for data config YAML"""
+
def __init__(self, yaml_path):
try:
import yaml
except ImportError:
- print('Please install PyYAML to load YAML files for '
- 'S2T data config')
+ print("Please install PyYAML to load YAML files for " "S2T data config")
self.config = {}
if op.isfile(yaml_path):
try:
with open(yaml_path) as f:
self.config = yaml.load(f, Loader=yaml.FullLoader)
except Exception as e:
- logger.info(f'Failed to load config from {yaml_path}: {e}')
+ logger.info(f"Failed to load config from {yaml_path}: {e}")
else:
- logger.info(f'Cannot find {yaml_path}')
+ logger.info(f"Cannot find {yaml_path}")
@property
def vocab_filename(self):
"""fairseq vocabulary file under data root"""
- return self.config.get('vocab_filename', 'dict.txt')
+ return self.config.get("vocab_filename", "dict.txt")
@property
def shuffle(self) -> bool:
"""Shuffle dataset samples before batching"""
- return self.config.get('shuffle', False)
+ return self.config.get("shuffle", False)
@property
def pre_tokenizer(self) -> Dict:
@@ -58,7 +65,7 @@ def pre_tokenizer(self) -> Dict:
a dictionary with `tokenizer` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
- return self.config.get('pre_tokenizer', {'tokenizer': None})
+ return self.config.get("pre_tokenizer", {"tokenizer": None})
@property
def bpe_tokenizer(self) -> Dict:
@@ -66,54 +73,55 @@ def bpe_tokenizer(self) -> Dict:
a dictionary with `bpe` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
- return self.config.get('bpe_tokenizer', None)
+ return self.config.get("bpe_tokenizer", None)
@property
def prepend_tgt_lang_tag(self) -> bool:
"""Prepend target lang ID token as the target BOS (e.g. for to-many
multilingual setting). During inference, this requires `--prefix-size 1`
to force BOS to be lang ID token."""
- return self.config.get('prepend_tgt_lang_tag', False)
+ return self.config.get("prepend_tgt_lang_tag", False)
@property
def input_feat_per_channel(self):
"""The dimension of input features (per audio channel)"""
- return self.config.get('input_feat_per_channel', 80)
+ return self.config.get("input_feat_per_channel", 80)
@property
def input_channels(self):
"""The number of channels in the input audio"""
- return self.config.get('input_channels', 1)
+ return self.config.get("input_channels", 1)
@property
def sampling_alpha(self):
"""Hyper-parameter alpha = 1/T for temperature-based resampling.
(alpha = 1 for no resampling)"""
- return self.config.get('sampling_alpha', 1.)
+ return self.config.get("sampling_alpha", 1.0)
@property
def use_audio_input(self):
"""Needed by the dataset loader to see if the model requires
raw audio as inputs."""
- return self.config.get('use_audio_input', False)
+ return self.config.get("use_audio_input", False)
@property
def audio_root(self):
"""Audio paths in the manifest TSV can be relative and this provides
the root path. Set this to empty string when using absolute paths."""
- return self.config.get('audio_root', '')
+ return self.config.get("audio_root", "")
def get_feature_transforms(self, split, is_train):
"""Split-specific feature transforms. Allowing train set wildcard `_train`,
evaluation set wildcard `_eval` and general wildcard `*` for matching."""
from copy import deepcopy
+
cfg = deepcopy(self.config)
- _cur = cfg.get('transforms', {})
+ _cur = cfg.get("transforms", {})
cur = _cur.get(split)
- cur = _cur.get('_train') if cur is None and is_train else cur
- cur = _cur.get('_eval') if cur is None and not is_train else cur
- cur = _cur.get('*') if cur is None else cur
- cfg['transforms'] = cur
+ cur = _cur.get("_train") if cur is None and is_train else cur
+ cur = _cur.get("_eval") if cur is None and not is_train else cur
+ cur = _cur.get("*") if cur is None else cur
+ cfg["transforms"] = cur
return cfg
@@ -122,13 +130,13 @@ def is_npy_data(data: bytes) -> bool:
def is_flac_or_wav_data(data: bytes) -> bool:
- is_flac = (data[0] == 102 and data[1] == 76)
- is_wav = (data[0] == 82 and data[1] == 73)
+ is_flac = data[0] == 102 and data[1] == 76
+ is_wav = data[0] == 82 and data[1] == 73
return is_flac or is_wav
def read_from_uncompressed_zip(file_path, offset, file_size) -> bytes:
- with open(file_path, 'rb') as f:
+ with open(file_path, "rb") as f:
f.seek(offset)
data = f.read(file_size)
return data
@@ -136,15 +144,15 @@ def read_from_uncompressed_zip(file_path, offset, file_size) -> bytes:
def get_features_from_npy_or_audio(path):
ext = op.splitext(op.basename(path))[1]
- if ext not in {'.npy', '.flac', '.wav'}:
+ if ext not in {".npy", ".flac", ".wav"}:
raise ValueError(f'Unsupported file format for "{path}"')
- return np.load(path) if ext == '.npy' else get_fbank(path)
+ return np.load(path) if ext == ".npy" else get_fbank(path)
def get_features_or_waveform_from_uncompressed_zip(
- path, byte_offset, byte_size, need_waveform=False
+ path, byte_offset, byte_size, need_waveform=False
):
- assert path.endswith('.zip')
+ assert path.endswith(".zip")
data = read_from_uncompressed_zip(path, byte_offset, byte_size)
f = io.BytesIO(data)
if is_npy_data(data):
@@ -169,9 +177,9 @@ def get_features_or_waveform(path: str, need_waveform=False):
Returns:
features_or_waveform (numpy.ndarray): speech features or waveform.
"""
- _path, *extra = path.split(':')
+ _path, *extra = path.split(":")
if not op.exists(_path):
- raise FileNotFoundError(f'File not found: {_path}')
+ raise FileNotFoundError(f"File not found: {_path}")
if len(extra) == 0:
if need_waveform:
@@ -183,13 +191,14 @@ def get_features_or_waveform(path: str, need_waveform=False):
_path, extra[0], extra[1], need_waveform=need_waveform
)
else:
- raise ValueError(f'Invalid path: {path}')
+ raise ValueError(f"Invalid path: {path}")
return features_or_waveform
-def _collate_frames(frames: List[torch.Tensor],
- is_audio_input: bool = False) -> torch.Tensor:
+def _collate_frames(
+ frames: List[torch.Tensor], is_audio_input: bool = False
+) -> torch.Tensor:
"""
Convert a list of 2D frames into a padded 3D tensor
Args:
@@ -209,24 +218,24 @@ def _collate_frames(frames: List[torch.Tensor],
class SpeechToTextDataset(FairseqDataset):
- LANG_TAG_TEMPLATE = ''
+ LANG_TAG_TEMPLATE = ""
def __init__(
- self,
- split: str,
- is_train_split: bool,
- data_cfg: S2TDataConfig,
- audio_paths: List[str],
- n_frames: List[int],
- src_texts: Optional[List[str]] = None,
- tgt_texts: Optional[List[str]] = None,
- speakers: Optional[List[str]] = None,
- src_langs: Optional[List[str]] = None,
- tgt_langs: Optional[List[str]] = None,
- ids: Optional[List[str]] = None,
- tgt_dict: Optional[Dictionary] = None,
- pre_tokenizer=None,
- bpe_tokenizer=None,
+ self,
+ split: str,
+ is_train_split: bool,
+ data_cfg: S2TDataConfig,
+ audio_paths: List[str],
+ n_frames: List[int],
+ src_texts: Optional[List[str]] = None,
+ tgt_texts: Optional[List[str]] = None,
+ speakers: Optional[List[str]] = None,
+ src_langs: Optional[List[str]] = None,
+ tgt_langs: Optional[List[str]] = None,
+ ids: Optional[List[str]] = None,
+ tgt_dict: Optional[Dictionary] = None,
+ pre_tokenizer=None,
+ bpe_tokenizer=None,
):
self.split, self.is_train_split = split, is_train_split
self.data_cfg = data_cfg
@@ -239,8 +248,9 @@ def __init__(
assert src_langs is None or len(src_langs) == self.n_samples
assert tgt_langs is None or len(tgt_langs) == self.n_samples
assert ids is None or len(ids) == self.n_samples
- assert (tgt_dict is None and tgt_texts is None) or \
- (tgt_dict is not None and tgt_texts is not None)
+ assert (tgt_dict is None and tgt_texts is None) or (
+ tgt_dict is not None and tgt_texts is not None
+ )
self.tgt_dict = tgt_dict
self.check_tgt_lang_tag()
self.src_texts, self.tgt_texts = src_texts, tgt_texts
@@ -258,21 +268,24 @@ def __init__(
logger.info(self.__repr__())
def __repr__(self):
- return self.__class__.__name__ + \
- f'(split="{self.split}", n_samples={self.n_samples}, ' \
- f'prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, ' \
- f'shuffle={self.shuffle}, transforms={self.feature_transforms})'
+ return (
+ self.__class__.__name__
+ + f'(split="{self.split}", n_samples={self.n_samples}, '
+ f"prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, "
+ f"shuffle={self.shuffle}, transforms={self.feature_transforms})"
+ )
@classmethod
def is_lang_tag(cls, token):
- pattern = cls.LANG_TAG_TEMPLATE.replace('{}', '(.*)')
+ pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
return re.match(pattern, token)
def check_tgt_lang_tag(self):
if self.data_cfg.prepend_tgt_lang_tag:
assert self.tgt_langs is not None and self.tgt_dict is not None
- tgt_lang_tags = [self.LANG_TAG_TEMPLATE.format(t)
- for t in set(self.tgt_langs)]
+ tgt_lang_tags = [
+ self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
+ ]
assert all(t in self.tgt_dict for t in tgt_lang_tags)
def tokenize_text(self, text: str):
@@ -283,7 +296,7 @@ def tokenize_text(self, text: str):
return text
def __getitem__(
- self, index: int
+ self, index: int
) -> Tuple[int, torch.Tensor, Optional[torch.Tensor]]:
source = get_features_or_waveform(
self.audio_paths[index], need_waveform=self.data_cfg.use_audio_input
@@ -308,18 +321,15 @@ def __getitem__(
def __len__(self):
return self.n_samples
- def collater(
- self, samples: List[Tuple[int, torch.Tensor, torch.Tensor]]
- ) -> Dict:
+ def collater(self, samples: List[Tuple[int, torch.Tensor, torch.Tensor]]) -> Dict:
if len(samples) == 0:
return {}
indices = torch.tensor([i for i, _, _ in samples], dtype=torch.long)
- frames = _collate_frames([s for _, s, _ in samples],
- self.data_cfg.use_audio_input)
- # sort samples by descending number of frames
- n_frames = torch.tensor(
- [s.size(0) for _, s, _ in samples], dtype=torch.long
+ frames = _collate_frames(
+ [s for _, s, _ in samples], self.data_cfg.use_audio_input
)
+ # sort samples by descending number of frames
+ n_frames = torch.tensor([s.size(0) for _, s, _ in samples], dtype=torch.long)
n_frames, order = n_frames.sort(descending=True)
indices = indices.index_select(0, order)
frames = frames.index_select(0, order)
@@ -329,16 +339,22 @@ def collater(
ntokens = None
if self.tgt_texts is not None:
target = fairseq_data_utils.collate_tokens(
- [t for _, _, t in samples], self.tgt_dict.pad(),
- self.tgt_dict.eos(), left_pad=False, move_eos_to_beginning=False
+ [t for _, _, t in samples],
+ self.tgt_dict.pad(),
+ self.tgt_dict.eos(),
+ left_pad=False,
+ move_eos_to_beginning=False,
)
target = target.index_select(0, order)
target_lengths = torch.tensor(
[t.size(0) for _, _, t in samples], dtype=torch.long
).index_select(0, order)
prev_output_tokens = fairseq_data_utils.collate_tokens(
- [t for _, _, t in samples], self.tgt_dict.pad(),
- self.tgt_dict.eos(), left_pad=False, move_eos_to_beginning=True
+ [t for _, _, t in samples],
+ self.tgt_dict.pad(),
+ self.tgt_dict.eos(),
+ left_pad=False,
+ move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, order)
ntokens = sum(t.size(0) for _, _, t in samples)
@@ -364,7 +380,7 @@ def size(self, index):
t_len = 0
if self.tgt_texts is not None:
tokenized = self.tokenize_text(self.tgt_texts[index])
- t_len = len(tokenized.split(' '))
+ t_len = len(tokenized.split(" "))
return self.n_frames[index], t_len
@property
@@ -390,43 +406,59 @@ def prefetch(self, indices):
class SpeechToTextDatasetCreator(object):
# mandatory columns
- KEY_ID, KEY_AUDIO, KEY_N_FRAMES = 'id', 'audio', 'n_frames'
- KEY_TGT_TEXT = 'tgt_text'
+ KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
+ KEY_TGT_TEXT = "tgt_text"
# optional columns
- KEY_SPEAKER, KEY_SRC_TEXT = 'speaker', 'src_text'
- KEY_SRC_LANG, KEY_TGT_LANG = 'src_lang', 'tgt_lang'
+ KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
+ KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
# default values
- DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ''
+ DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""
@classmethod
- def _from_list(cls, split_name: str, is_train_split,
- samples: List[List[Dict]], data_cfg: S2TDataConfig, tgt_dict,
- pre_tokenizer, bpe_tokenizer) -> SpeechToTextDataset:
+ def _from_list(
+ cls,
+ split_name: str,
+ is_train_split,
+ samples: List[List[Dict]],
+ data_cfg: S2TDataConfig,
+ tgt_dict,
+ pre_tokenizer,
+ bpe_tokenizer,
+ ) -> SpeechToTextDataset:
audio_paths, n_frames, src_texts, tgt_texts, ids = [], [], [], [], []
speakers, src_langs, tgt_langs = [], [], []
for s in samples:
ids.extend([ss[cls.KEY_ID] for ss in s])
- audio_paths.extend([op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO])
- for ss in s])
+ audio_paths.extend(
+ [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
+ )
n_frames.extend([int(ss[cls.KEY_N_FRAMES]) for ss in s])
tgt_texts.extend([ss[cls.KEY_TGT_TEXT] for ss in s])
- src_texts.extend([ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT)
- for ss in s])
- speakers.extend([ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER)
- for ss in s])
- src_langs.extend([ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG)
- for ss in s])
- tgt_langs.extend([ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG)
- for ss in s])
+ src_texts.extend(
+ [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
+ )
+ speakers.extend([ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s])
+ src_langs.extend([ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s])
+ tgt_langs.extend([ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s])
return SpeechToTextDataset(
- split_name, is_train_split, data_cfg, audio_paths, n_frames,
- src_texts, tgt_texts, speakers, src_langs, tgt_langs, ids, tgt_dict,
- pre_tokenizer, bpe_tokenizer
+ split_name,
+ is_train_split,
+ data_cfg,
+ audio_paths,
+ n_frames,
+ src_texts,
+ tgt_texts,
+ speakers,
+ src_langs,
+ tgt_langs,
+ ids,
+ tgt_dict,
+ pre_tokenizer,
+ bpe_tokenizer,
)
@classmethod
- def _get_size_ratios(cls, ids: List[str], sizes: List[int],
- alpha: float = 1.):
+ def _get_size_ratios(cls, ids: List[str], sizes: List[int], alpha: float = 1.0):
"""Size ratios for temperature-based sampling
(https://arxiv.org/abs/1907.05019)"""
_sizes = np.array(sizes)
@@ -444,35 +476,58 @@ def _get_size_ratios(cls, ids: List[str], sizes: List[int],
return size_ratio.tolist()
@classmethod
- def from_tsv(cls, root: str, data_cfg: S2TDataConfig, splits: str, tgt_dict,
- pre_tokenizer, bpe_tokenizer, is_train_split: bool, epoch: int,
- seed: int) -> SpeechToTextDataset:
+ def from_tsv(
+ cls,
+ root: str,
+ data_cfg: S2TDataConfig,
+ splits: str,
+ tgt_dict,
+ pre_tokenizer,
+ bpe_tokenizer,
+ is_train_split: bool,
+ epoch: int,
+ seed: int,
+ ) -> SpeechToTextDataset:
samples = []
- _splits = splits.split(',')
+ _splits = splits.split(",")
for split in _splits:
- tsv_path = op.join(root, f'{split}.tsv')
+ tsv_path = op.join(root, f"{split}.tsv")
if not op.isfile(tsv_path):
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
with open(tsv_path) as f:
reader = csv.DictReader(
- f, delimiter='\t', quotechar=None, doublequote=False,
- lineterminator='\n', quoting=csv.QUOTE_NONE
+ f,
+ delimiter="\t",
+ quotechar=None,
+ doublequote=False,
+ lineterminator="\n",
+ quoting=csv.QUOTE_NONE,
)
samples.append([dict(e) for e in reader])
assert len(samples) > 0
- datasets = [cls._from_list(name, is_train_split, [s], data_cfg, tgt_dict,
- pre_tokenizer, bpe_tokenizer)
- for name, s in zip(_splits, samples)]
+ datasets = [
+ cls._from_list(
+ name,
+ is_train_split,
+ [s],
+ data_cfg,
+ tgt_dict,
+ pre_tokenizer,
+ bpe_tokenizer,
+ )
+ for name, s in zip(_splits, samples)
+ ]
- if is_train_split and len(_splits) > 1 and data_cfg.sampling_alpha != 1.:
+ if is_train_split and len(_splits) > 1 and data_cfg.sampling_alpha != 1.0:
# temperature-based sampling
size_ratios = cls._get_size_ratios(
_splits, [len(s) for s in samples], alpha=data_cfg.sampling_alpha
)
datasets = [
- ResamplingDataset(d, size_ratio=r, seed=seed, epoch=epoch,
- replace=(r >= 1.))
+ ResamplingDataset(
+ d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
+ )
for d, r in zip(datasets, size_ratios)
]
return ConcatDataset(datasets)
diff --git a/fairseq/data/backtranslation_dataset.py b/fairseq/data/backtranslation_dataset.py
index 0007a01506..8f70c90df3 100644
--- a/fairseq/data/backtranslation_dataset.py
+++ b/fairseq/data/backtranslation_dataset.py
@@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.
import torch
-
from fairseq import utils
from . import FairseqDataset
@@ -36,16 +35,18 @@ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
generated_sources = generate_fn(s)
- id_to_src = {
- sample['id']: sample['source'] for sample in samples
- }
+ id_to_src = {sample["id"]: sample["source"] for sample in samples}
# Go through each tgt sentence in batch and its corresponding best
# generated hypothesis and create a backtranslation data pair
# {id: id, source: generated backtranslation, target: original tgt}
return [
- {'id': id.item(), 'target': id_to_src[id.item()], 'source': hypos[0]['tokens'].cpu()}
- for id, hypos in zip(collated_samples['id'], generated_sources)
+ {
+ "id": id.item(),
+ "target": id_to_src[id.item()],
+ "source": hypos[0]["tokens"].cpu(),
+ }
+ for id, hypos in zip(collated_samples["id"], generated_sources)
]
@@ -87,8 +88,9 @@ def __init__(
):
self.tgt_dataset = tgt_dataset
self.backtranslation_fn = backtranslation_fn
- self.output_collater = output_collater if output_collater is not None \
- else tgt_dataset.collater
+ self.output_collater = (
+ output_collater if output_collater is not None else tgt_dataset.collater
+ )
self.cuda = cuda if torch.cuda.is_available() else False
self.src_dict = src_dict
self.tgt_dict = tgt_dict
@@ -126,14 +128,12 @@ def collater(self, samples):
Returns:
dict: a mini-batch with keys coming from *output_collater*
"""
- if samples[0].get('is_dummy', False):
+ if samples[0].get("is_dummy", False):
return samples
samples = backtranslate_samples(
samples=samples,
collate_fn=self.tgt_dataset.collater,
- generate_fn=(
- lambda net_input: self.backtranslation_fn(net_input)
- ),
+ generate_fn=(lambda net_input: self.backtranslation_fn(net_input)),
cuda=self.cuda,
)
return self.output_collater(samples)
@@ -159,7 +159,7 @@ def size(self, index):
@property
def supports_prefetch(self):
- return getattr(self.tgt_dataset, 'supports_prefetch', False)
+ return getattr(self.tgt_dataset, "supports_prefetch", False)
def prefetch(self, indices):
return self.tgt_dataset.prefetch(indices)
diff --git a/fairseq/data/base_wrapper_dataset.py b/fairseq/data/base_wrapper_dataset.py
index 680dcce9ae..134d398b47 100644
--- a/fairseq/data/base_wrapper_dataset.py
+++ b/fairseq/data/base_wrapper_dataset.py
@@ -9,7 +9,6 @@
class BaseWrapperDataset(FairseqDataset):
-
def __init__(self, dataset):
super().__init__()
self.dataset = dataset
@@ -21,7 +20,7 @@ def __len__(self):
return len(self.dataset)
def collater(self, samples):
- if hasattr(self.dataset, 'collater'):
+ if hasattr(self.dataset, "collater"):
return self.dataset.collater(samples)
else:
return default_collate(samples)
@@ -41,7 +40,7 @@ def ordered_indices(self):
@property
def supports_prefetch(self):
- return getattr(self.dataset, 'supports_prefetch', False)
+ return getattr(self.dataset, "supports_prefetch", False)
def attr(self, attr: str, index: int):
return self.dataset.attr(attr, index)
@@ -75,5 +74,5 @@ def can_reuse_epoch_itr_across_epochs(self):
def set_epoch(self, epoch):
super().set_epoch(epoch)
- if hasattr(self.dataset, 'set_epoch'):
+ if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)
diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py
index 6f53d01188..cda8834ac8 100644
--- a/fairseq/data/bucket_pad_length_dataset.py
+++ b/fairseq/data/bucket_pad_length_dataset.py
@@ -5,7 +5,6 @@
import numpy as np
import torch.nn.functional as F
-
from fairseq.data import BaseWrapperDataset
@@ -40,7 +39,7 @@ def __init__(
np.percentile(
sizes,
np.linspace(0, 100, num_buckets + 1),
- interpolation='lower',
+ interpolation="lower",
)[1:]
)
diff --git a/fairseq/data/colorize_dataset.py b/fairseq/data/colorize_dataset.py
index 89e0e04142..6ef097bff1 100644
--- a/fairseq/data/colorize_dataset.py
+++ b/fairseq/data/colorize_dataset.py
@@ -10,6 +10,7 @@
class ColorizeDataset(BaseWrapperDataset):
""" Adds 'colors' property to net input that is obtained from the provided color getter for use by models """
+
def __init__(self, dataset, color_getter):
super().__init__(dataset)
self.color_getter = color_getter
diff --git a/fairseq/data/concat_dataset.py b/fairseq/data/concat_dataset.py
index 0091a28e47..01a4078bb1 100644
--- a/fairseq/data/concat_dataset.py
+++ b/fairseq/data/concat_dataset.py
@@ -49,7 +49,7 @@ def _get_dataset_and_sample_index(self, idx: int):
def collater(self, samples, **extra_args):
# For now only supports datasets with same underlying collater implementations
- if hasattr(self.datasets[0], 'collater'):
+ if hasattr(self.datasets[0], "collater"):
return self.datasets[0].collater(samples, **extra_args)
else:
return default_collate(samples, **extra_args)
@@ -92,14 +92,16 @@ def ordered_indices(self):
# special handling for concatenating lang_pair_datasets
indices = np.arange(len(self))
sizes = self.sizes
- tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
- src_sizes = sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
+ tgt_sizes = (
+ sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
+ )
+ src_sizes = (
+ sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
+ )
# sort by target length, then source length
if tgt_sizes is not None:
- indices = indices[
- np.argsort(tgt_sizes[indices], kind='mergesort')
- ]
- return indices[np.argsort(src_sizes[indices], kind='mergesort')]
+ indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
+ return indices[np.argsort(src_sizes[indices], kind="mergesort")]
else:
return np.argsort(self.sizes)
@@ -107,7 +109,7 @@ def prefetch(self, indices):
frm = 0
for to, ds in zip(self.cumulative_sizes, self.datasets):
real_size = len(ds)
- if getattr(ds, 'supports_prefetch', False):
+ if getattr(ds, "supports_prefetch", False):
ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
frm = to
@@ -118,5 +120,5 @@ def can_reuse_epoch_itr_across_epochs(self):
def set_epoch(self, epoch):
super().set_epoch(epoch)
for ds in self.datasets:
- if hasattr(ds, 'set_epoch'):
+ if hasattr(ds, "set_epoch"):
ds.set_epoch(epoch)
diff --git a/fairseq/data/concat_sentences_dataset.py b/fairseq/data/concat_sentences_dataset.py
index 55445ee1c7..625a29370e 100644
--- a/fairseq/data/concat_sentences_dataset.py
+++ b/fairseq/data/concat_sentences_dataset.py
@@ -9,12 +9,12 @@
class ConcatSentencesDataset(FairseqDataset):
-
def __init__(self, *datasets):
super().__init__()
self.datasets = datasets
- assert all(len(ds) == len(datasets[0]) for ds in datasets), \
- 'datasets must have the same length'
+ assert all(
+ len(ds) == len(datasets[0]) for ds in datasets
+ ), "datasets must have the same length"
def __getitem__(self, index):
return torch.cat([ds[index] for ds in self.datasets])
@@ -40,17 +40,15 @@ def ordered_indices(self):
@property
def supports_prefetch(self):
- return any(
- getattr(ds, 'supports_prefetch', False) for ds in self.datasets
- )
+ return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)
def prefetch(self, indices):
for ds in self.datasets:
- if getattr(ds, 'supports_prefetch', False):
+ if getattr(ds, "supports_prefetch", False):
ds.prefetch(indices)
def set_epoch(self, epoch):
super().set_epoch(epoch)
for ds in self.datasets:
- if hasattr(ds, 'set_epoch'):
+ if hasattr(ds, "set_epoch"):
ds.set_epoch(epoch)
diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py
index a8c480c5b1..81f457365a 100644
--- a/fairseq/data/data_utils.py
+++ b/fairseq/data/data_utils.py
@@ -12,8 +12,7 @@
import logging
import os
import warnings
-
-from typing import Tuple, Optional
+from typing import Optional, Tuple
import numpy as np
import torch
@@ -26,19 +25,26 @@ def infer_language_pair(path):
"""Infer language pair from filename: .-.(...).idx"""
src, dst = None, None
for filename in os.listdir(path):
- parts = filename.split('.')
- if len(parts) >= 3 and len(parts[1].split('-')) == 2:
- return parts[1].split('-')
+ parts = filename.split(".")
+ if len(parts) >= 3 and len(parts[1].split("-")) == 2:
+ return parts[1].split("-")
return src, dst
-def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False,
- pad_to_length=None, pad_to_multiple=1):
+def collate_tokens(
+ values,
+ pad_idx,
+ eos_idx=None,
+ left_pad=False,
+ move_eos_to_beginning=False,
+ pad_to_length=None,
+ pad_to_multiple=1,
+):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size = max(v.size(0) for v in values)
size = size if pad_to_length is None else max(size, pad_to_length)
if pad_to_multiple != 1 and size % pad_to_multiple != 0:
- size = int(((size-0.1)//pad_to_multiple + 1) * pad_to_multiple)
+ size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
res = values[0].new(len(values), size).fill_(pad_idx)
def copy_tensor(src, dst):
@@ -54,11 +60,13 @@ def copy_tensor(src, dst):
dst.copy_(src)
for i, v in enumerate(values):
- copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
+ copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
return res
-def load_indexed_dataset(path, dictionary=None, dataset_impl=None, combine=False, default='cached'):
+def load_indexed_dataset(
+ path, dictionary=None, dataset_impl=None, combine=False, default="cached"
+):
"""A helper function for loading indexed datasets.
Args:
@@ -74,9 +82,10 @@ def load_indexed_dataset(path, dictionary=None, dataset_impl=None, combine=False
"""
from fairseq.data.concat_dataset import ConcatDataset
import fairseq.data.indexed_dataset as indexed_dataset
+
datasets = []
for k in itertools.count():
- path_k = path + (str(k) if k > 0 else '')
+ path_k = path + (str(k) if k > 0 else "")
path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
dataset_impl_k = dataset_impl
@@ -90,7 +99,7 @@ def load_indexed_dataset(path, dictionary=None, dataset_impl=None, combine=False
)
if dataset is None:
break
- logger.info('loaded {} examples from: {}'.format(len(dataset), path_k))
+ logger.info("loaded {} examples from: {}".format(len(dataset), path_k))
datasets.append(dataset)
if not combine:
break
@@ -148,8 +157,10 @@ def check_size(idx):
assert isinstance(idx_size, dict)
intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
return all(
- all(a is None or b is None or a <= b
- for a, b in zip(idx_size[key], max_positions[key]))
+ all(
+ a is None or b is None or a <= b
+ for a, b in zip(idx_size[key], max_positions[key])
+ )
for key in intersect_keys
)
else:
@@ -166,6 +177,7 @@ def check_size(idx):
a is None or b is None or a <= b
for a, b in zip(size_fn(idx), max_positions)
)
+
ignored = []
itr = collect_filtered(check_size, indices, ignored)
indices = np.fromiter(itr, dtype=np.int64, count=-1)
@@ -186,37 +198,47 @@ def filter_by_size(indices, dataset, max_positions, raise_exception=False):
any elements are filtered (default: False).
"""
warnings.warn(
- 'data_utils.filter_by_size is deprecated. '
- 'Use `FairseqDataset::filter_indices_by_size` instead.',
- stacklevel=2
+ "data_utils.filter_by_size is deprecated. "
+ "Use `FairseqDataset::filter_indices_by_size` instead.",
+ stacklevel=2,
)
if isinstance(max_positions, float) or isinstance(max_positions, int):
- if hasattr(dataset, 'sizes') and isinstance(dataset.sizes, np.ndarray):
+ if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
ignored = indices[dataset.sizes[indices] > max_positions].tolist()
indices = indices[dataset.sizes[indices] <= max_positions]
- elif hasattr(dataset, 'sizes') and isinstance(dataset.sizes, list) and len(dataset.sizes) == 1:
+ elif (
+ hasattr(dataset, "sizes")
+ and isinstance(dataset.sizes, list)
+ and len(dataset.sizes) == 1
+ ):
ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
indices = indices[dataset.sizes[0][indices] <= max_positions]
else:
- indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
+ indices, ignored = _filter_by_size_dynamic(
+ indices, dataset.size, max_positions
+ )
else:
indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
if len(ignored) > 0 and raise_exception:
- raise Exception((
- 'Size of sample #{} is invalid (={}) since max_positions={}, '
- 'skip this example with --skip-invalid-size-inputs-valid-test'
- ).format(ignored[0], dataset.size(ignored[0]), max_positions))
+ raise Exception(
+ (
+ "Size of sample #{} is invalid (={}) since max_positions={}, "
+ "skip this example with --skip-invalid-size-inputs-valid-test"
+ ).format(ignored[0], dataset.size(ignored[0]), max_positions)
+ )
if len(ignored) > 0:
- logger.warning((
- '{} samples have invalid sizes and will be skipped, '
- 'max_positions={}, first few sample ids={}'
- ).format(len(ignored), max_positions, ignored[:10]))
+ logger.warning(
+ (
+ "{} samples have invalid sizes and will be skipped, "
+ "max_positions={}, first few sample ids={}"
+ ).format(len(ignored), max_positions, ignored[:10])
+ )
return indices
def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
- """ Filter a list of sample indices. Remove those that are longer
+ """Filter a list of sample indices. Remove those that are longer
than specified in max_sizes.
Args:
@@ -238,21 +260,26 @@ def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_siz
ignored = indices[src_sizes[indices] > max_src_size]
else:
ignored = indices[
- (src_sizes[indices] > max_src_size) |
- (tgt_sizes[indices] > max_tgt_size)]
+ (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
+ ]
if len(ignored) > 0:
if tgt_sizes is None:
indices = indices[src_sizes[indices] <= max_src_size]
else:
indices = indices[
- (src_sizes[indices] <= max_src_size) &
- (tgt_sizes[indices] <= max_tgt_size)]
+ (src_sizes[indices] <= max_src_size)
+ & (tgt_sizes[indices] <= max_tgt_size)
+ ]
return indices, ignored.tolist()
def batch_by_size(
- indices, num_tokens_fn, max_tokens=None, max_sentences=None,
- required_batch_size_multiple=1, fixed_shapes=None,
+ indices,
+ num_tokens_fn,
+ max_tokens=None,
+ max_sentences=None,
+ required_batch_size_multiple=1,
+ fixed_shapes=None,
):
"""
Yield mini-batches of indices bucketed by size. Batches may contain
@@ -274,12 +301,13 @@ def batch_by_size(
"""
try:
from fairseq.data.data_utils_fast import (
- batch_by_size_fast, batch_fixed_shapes_fast,
+ batch_by_size_fast,
+ batch_fixed_shapes_fast,
)
except ImportError:
raise ImportError(
- 'Please build Cython components with: `pip install --editable .` '
- 'or `python setup.py build_ext --inplace`'
+ "Please build Cython components with: `pip install --editable .` "
+ "or `python setup.py build_ext --inplace`"
)
max_tokens = max_tokens if max_tokens is not None else -1
@@ -291,14 +319,20 @@ def batch_by_size(
if fixed_shapes is None:
return batch_by_size_fast(
- indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult,
+ indices,
+ num_tokens_fn,
+ max_tokens,
+ max_sentences,
+ bsz_mult,
)
else:
fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
- sort_order = np.lexsort([
- fixed_shapes[:, 1].argsort(), # length
- fixed_shapes[:, 0].argsort(), # bsz
- ])
+ sort_order = np.lexsort(
+ [
+ fixed_shapes[:, 1].argsort(), # length
+ fixed_shapes[:, 0].argsort(), # bsz
+ ]
+ )
fixed_shapes_sorted = fixed_shapes[sort_order]
return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
@@ -306,26 +340,27 @@ def batch_by_size(
def post_process(sentence: str, symbol: str):
if symbol == "sentencepiece":
sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
- elif symbol == 'wordpiece':
+ elif symbol == "wordpiece":
sentence = sentence.replace(" ", "").replace("_", " ").strip()
- elif symbol == 'letter':
+ elif symbol == "letter":
sentence = sentence.replace(" ", "").replace("|", " ").strip()
elif symbol == "_EOW":
sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
- elif symbol is not None and symbol != 'none':
+ elif symbol is not None and symbol != "none":
sentence = (sentence + " ").replace(symbol, "").rstrip()
return sentence
+
def compute_mask_indices(
- shape: Tuple[int, int],
- padding_mask: Optional[torch.Tensor],
- mask_prob: float,
- mask_length: int,
- mask_type: str = "static",
- mask_other: float = 0.0,
- min_masks: int = 0,
- no_overlap: bool = False,
- min_space: int = 0,
+ shape: Tuple[int, int],
+ padding_mask: Optional[torch.Tensor],
+ mask_prob: float,
+ mask_length: int,
+ mask_type: str = "static",
+ mask_other: float = 0.0,
+ min_masks: int = 0,
+ no_overlap: bool = False,
+ min_space: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape
@@ -390,13 +425,14 @@ def compute_mask_indices(
if no_overlap:
mask_idc = []
+
def arrange(s, e, length, keep_length):
- span_start = np.random.randint(s, e-length)
+ span_start = np.random.randint(s, e - length)
mask_idc.extend(span_start + i for i in range(length))
new_parts = []
if span_start - s - min_space >= keep_length:
- new_parts.append((s, span_start-min_space+1))
+ new_parts.append((s, span_start - min_space + 1))
if e - span_start - keep_length - min_space > keep_length:
new_parts.append((span_start + length + min_space, e))
return new_parts
@@ -404,7 +440,10 @@ def arrange(s, e, length, keep_length):
parts = [(0, sz)]
min_length = min(lengths)
for length in sorted(lengths, reverse=True):
- lens = np.fromiter((e - s if e-s >= length+min_space else 0 for s, e in parts), np.int)
+ lens = np.fromiter(
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
+ np.int,
+ )
l_sum = np.sum(lens)
if l_sum == 0:
break
@@ -416,7 +455,7 @@ def arrange(s, e, length, keep_length):
else:
min_len = min(lengths)
if sz - min_len <= num_mask:
- min_len = sz - num_mask - 1
+ min_len = sz - num_mask - 1
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
@@ -442,10 +481,11 @@ def arrange(s, e, length, keep_length):
def get_mem_usage():
try:
import psutil
+
mb = 1024 * 1024
- return f'used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb'
+ return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
except ImportError:
- return 'N/A'
+ return "N/A"
def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor:
diff --git a/fairseq/data/denoising_dataset.py b/fairseq/data/denoising_dataset.py
index 4fe560b0a7..bdb62c8d5d 100644
--- a/fairseq/data/denoising_dataset.py
+++ b/fairseq/data/denoising_dataset.py
@@ -3,11 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+import math
+
import numpy as np
import torch
-import math
-from . import data_utils, FairseqDataset
+from . import FairseqDataset, data_utils
def collate(
@@ -34,53 +35,59 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
pad_to_length=pad_to_length,
)
- id = torch.LongTensor([s['id'] for s in samples])
+ id = torch.LongTensor([s["id"] for s in samples])
src_tokens = merge(
- 'source', left_pad=left_pad_source,
- pad_to_length=pad_to_length['source'] if pad_to_length is not None else None,
+ "source",
+ left_pad=left_pad_source,
+ pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
)
# sort by descending source length
- src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
+ src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = None
target = None
- if samples[0].get('target', None) is not None:
+ if samples[0].get("target", None) is not None:
target = merge(
- 'target', left_pad=left_pad_target,
- pad_to_length=pad_to_length['target'] if pad_to_length is not None else None,
+ "target",
+ left_pad=left_pad_target,
+ pad_to_length=pad_to_length["target"]
+ if pad_to_length is not None
+ else None,
)
target = target.index_select(0, sort_order)
- ntokens = sum(len(s['target']) for s in samples)
+ ntokens = sum(len(s["target"]) for s in samples)
if input_feeding:
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
- 'target',
+ "target",
left_pad=left_pad_target,
move_eos_to_beginning=True,
- pad_to_length=pad_to_length['target'] if pad_to_length is not None else None,
+ pad_to_length=pad_to_length["target"]
+ if pad_to_length is not None
+ else None,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
else:
- ntokens = sum(len(s['source']) for s in samples)
+ ntokens = sum(len(s["source"]) for s in samples)
batch = {
- 'id': id,
- 'ntokens': ntokens,
- 'net_input': {
- 'src_tokens': src_tokens,
- 'src_lengths': src_lengths,
+ "id": id,
+ "ntokens": ntokens,
+ "net_input": {
+ "src_tokens": src_tokens,
+ "src_lengths": src_lengths,
},
- 'target': target,
- 'nsentences': samples[0]['source'].size(0),
- 'sort_order': sort_order,
+ "target": target,
+ "nsentences": samples[0]["source"].size(0),
+ "sort_order": sort_order,
}
if prev_output_tokens is not None:
- batch['net_input']['prev_output_tokens'] = prev_output_tokens
+ batch["net_input"]["prev_output_tokens"] = prev_output_tokens
return batch
@@ -130,25 +137,25 @@ def __init__(
self.insert_ratio = args.insert
self.rotate_ratio = args.rotate
self.permute_sentence_ratio = args.permute_sentences
- self.eos = (eos if eos is not None else vocab.eos())
+ self.eos = eos if eos is not None else vocab.eos()
self.item_transform_func = item_transform_func
- if args.bpe != 'gpt2':
+ if args.bpe != "gpt2":
self.full_stop_index = self.vocab.eos()
else:
- assert args.bpe == 'gpt2'
- self.full_stop_index = self.vocab.index('13')
+ assert args.bpe == "gpt2"
+ self.full_stop_index = self.vocab.index("13")
self.replace_length = args.replace_length
if self.replace_length not in [-1, 0, 1]:
- raise ValueError(f'invalid arg: replace_length={self.replace_length}')
- if args.mask_length not in ['subword', 'word', 'span-poisson']:
- raise ValueError(f'invalid arg: mask-length={args.mask_length}')
- if args.mask_length == 'subword' and args.replace_length not in [0, 1]:
- raise ValueError(f'if using subwords, use replace-length=1 or 0')
+ raise ValueError(f"invalid arg: replace_length={self.replace_length}")
+ if args.mask_length not in ["subword", "word", "span-poisson"]:
+ raise ValueError(f"invalid arg: mask-length={args.mask_length}")
+ if args.mask_length == "subword" and args.replace_length not in [0, 1]:
+ raise ValueError(f"if using subwords, use replace-length=1 or 0")
self.mask_span_distribution = None
- if args.mask_length == 'span-poisson':
+ if args.mask_length == "span-poisson":
_lambda = args.poisson_lambda
lambda_to_the_k = 1
@@ -158,7 +165,7 @@ def __init__(
for k in range(0, 128):
ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
lambda_to_the_k *= _lambda
- k_factorial *= (k + 1)
+ k_factorial *= k + 1
if ps[-1] < 0.0000001:
break
ps = torch.FloatTensor(ps)
@@ -200,16 +207,16 @@ def __getitem__(self, index):
assert source[0] == self.vocab.bos()
assert source[-1] == self.eos
return {
- 'id': index,
- 'source': source,
- 'target': target,
+ "id": index,
+ "source": source,
+ "target": target,
}
def __len__(self):
return len(self.dataset)
def permute_sentences(self, source, p=1.0):
- full_stops = (source == self.full_stop_index)
+ full_stops = source == self.full_stop_index
# Pretend it ends with a full stop so last span is a sentence
full_stops[-2] = 1
@@ -226,8 +233,8 @@ def permute_sentences(self, source, p=1.0):
# Ignore at start
index = 1
for i in ordering:
- sentence = source[(sentence_ends[i - 1] if i > 0 else 1):sentence_ends[i]]
- result[index:index + sentence.size(0)] = sentence
+ sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
+ result[index : index + sentence.size(0)] = sentence
index += sentence.size(0)
return result
@@ -253,7 +260,13 @@ def add_whole_word_mask(self, source, p):
# Make sure we have enough to mask
cum_length = torch.cumsum(lengths, 0)
while cum_length[-1] < num_to_mask:
- lengths = torch.cat([lengths, self.mask_span_distribution.sample(sample_shape=(num_to_mask,))], dim=0)
+ lengths = torch.cat(
+ [
+ lengths,
+ self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
+ ],
+ dim=0,
+ )
cum_length = torch.cumsum(lengths, 0)
# Trim to masking budget
@@ -276,19 +289,25 @@ def add_whole_word_mask(self, source, p):
lengths = torch.ones((num_to_mask,)).long()
assert is_word_start[-1] == 0
word_starts = is_word_start.nonzero(as_tuple=False)
- indices = word_starts[torch.randperm(word_starts.size(0))[:num_to_mask]].squeeze(1)
+ indices = word_starts[
+ torch.randperm(word_starts.size(0))[:num_to_mask]
+ ].squeeze(1)
mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
source_length = source.size(0)
assert source_length - 1 not in indices
to_keep = torch.ones(source_length, dtype=torch.bool)
- is_word_start[-1] = 255 # acts as a long length, so spans don't go over the end of doc
+ is_word_start[
+ -1
+ ] = 255 # acts as a long length, so spans don't go over the end of doc
if self.replace_length == 0:
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
- source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),))
+ source[indices[mask_random]] = torch.randint(
+ 1, len(self.vocab), size=(mask_random.sum(),)
+ )
if self.mask_span_distribution is not None:
assert len(lengths.size()) == 1
@@ -307,7 +326,9 @@ def add_whole_word_mask(self, source, p):
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
- source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),))
+ source[indices[mask_random]] = torch.randint(
+ 1, len(self.vocab), size=(mask_random.sum(),)
+ )
else:
# A bit faster when all lengths are 1
while indices.size(0) > 0:
@@ -320,7 +341,9 @@ def add_whole_word_mask(self, source, p):
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
- source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),))
+ source[indices[mask_random]] = torch.randint(
+ 1, len(self.vocab), size=(mask_random.sum(),)
+ )
assert source_length - 1 not in indices
@@ -360,7 +383,9 @@ def add_insertion_noise(self, tokens, p):
num_random = int(math.ceil(n * self.random_ratio))
result[noise_indices[num_random:]] = self.mask_idx
- result[noise_indices[:num_random]] = torch.randint(low=1, high=len(self.vocab), size=(num_random,))
+ result[noise_indices[:num_random]] = torch.randint(
+ low=1, high=len(self.vocab), size=(num_random,)
+ )
result[~noise_mask] = tokens
@@ -375,8 +400,8 @@ def collater(self, samples, pad_to_length=None):
dict: a mini-batch of data
"""
return collate(
- samples, self.vocab.pad(), self.eos, self.vocab,
- pad_to_length=pad_to_length)
+ samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
+ )
def num_tokens(self, index):
"""Return the number of tokens in a sample. This value is used to
@@ -395,7 +420,7 @@ def ordered_indices(self):
indices = np.random.permutation(len(self))
else:
indices = np.arange(len(self))
- return indices[np.argsort(self.sizes[indices], kind='mergesort')]
+ return indices[np.argsort(self.sizes[indices], kind="mergesort")]
def prefetch(self, indices):
self.src.prefetch(indices)
@@ -404,8 +429,8 @@ def prefetch(self, indices):
@property
def supports_prefetch(self):
return (
- hasattr(self.src, 'supports_prefetch')
+ hasattr(self.src, "supports_prefetch")
and self.src.supports_prefetch
- and hasattr(self.tgt, 'supports_prefetch')
+ and hasattr(self.tgt, "supports_prefetch")
and self.tgt.supports_prefetch
)
diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py
index 3d11f93137..e2df08e092 100644
--- a/fairseq/data/dictionary.py
+++ b/fairseq/data/dictionary.py
@@ -251,8 +251,7 @@ def add_from_file(self, f):
"Duplicate words can overwrite earlier ones by adding the "
"#fairseq:overwrite flag at the end of the corresponding row "
"in the dictionary file. If using the Camembert model, please "
- "download an updated copy of the model file."
- .format(word)
+ "download an updated copy of the model file.".format(word)
)
self.add_symbol(word, n=count, overwrite=overwrite)
except ValueError:
diff --git a/fairseq/data/encoders/__init__.py b/fairseq/data/encoders/__init__.py
index d796496b86..2e807d8ae7 100644
--- a/fairseq/data/encoders/__init__.py
+++ b/fairseq/data/encoders/__init__.py
@@ -11,19 +11,19 @@
build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry(
- '--tokenizer',
+ "--tokenizer",
default=None,
)
build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry(
- '--bpe',
+ "--bpe",
default=None,
)
# automatically import any Python files in the encoders/ directory
for file in os.listdir(os.path.dirname(__file__)):
- if file.endswith('.py') and not file.startswith('_'):
- module = file[:file.find('.py')]
- importlib.import_module('fairseq.data.encoders.' + module)
+ if file.endswith(".py") and not file.startswith("_"):
+ module = file[: file.find(".py")]
+ importlib.import_module("fairseq.data.encoders." + module)
diff --git a/fairseq/data/encoders/byte_bpe.py b/fairseq/data/encoders/byte_bpe.py
index 1d78ff9150..0d2da3ea1a 100644
--- a/fairseq/data/encoders/byte_bpe.py
+++ b/fairseq/data/encoders/byte_bpe.py
@@ -6,11 +6,15 @@
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
-from fairseq.data.encoders.byte_utils import (byte_encode, smart_byte_decode,
- SPACE, SPACE_ESCAPE)
+from fairseq.data.encoders.byte_utils import (
+ SPACE,
+ SPACE_ESCAPE,
+ byte_encode,
+ smart_byte_decode,
+)
-@register_bpe('byte_bpe')
+@register_bpe("byte_bpe")
class ByteBPE(object):
@staticmethod
def add_args(parser):
@@ -23,10 +27,13 @@ def __init__(self, args):
vocab = file_utils.cached_path(args.sentencepiece_model_path)
try:
import sentencepiece as spm
+
self.sp = spm.SentencePieceProcessor()
self.sp.Load(vocab)
except ImportError:
- raise ImportError('Please install sentencepiece with: pip install sentencepiece')
+ raise ImportError(
+ "Please install sentencepiece with: pip install sentencepiece"
+ )
def encode(self, x: str) -> str:
byte_encoded = byte_encode(x)
@@ -34,5 +41,5 @@ def encode(self, x: str) -> str:
@staticmethod
def decode(x: str) -> str:
- unescaped = x.replace(SPACE, '').replace(SPACE_ESCAPE, SPACE)
+ unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
return smart_byte_decode(unescaped)
diff --git a/fairseq/data/encoders/byte_utils.py b/fairseq/data/encoders/byte_utils.py
index 7c4bb74713..a305c08092 100644
--- a/fairseq/data/encoders/byte_utils.py
+++ b/fairseq/data/encoders/byte_utils.py
@@ -5,13 +5,13 @@
import re
-WHITESPACE_NORMALIZER = re.compile(r'\s+')
+
+WHITESPACE_NORMALIZER = re.compile(r"\s+")
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)
# excluding non-breaking space (160) here
PRINTABLE_LATIN = set(
- list(range(32, 126 + 1)) + list(range(161, 172 + 1)) +
- list(range(174, 255 + 1))
+ list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1))
)
BYTE_TO_BCHAR = {
b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256)
@@ -21,19 +21,19 @@
def byte_encode(x: str) -> str:
normalized = WHITESPACE_NORMALIZER.sub(SPACE, x)
- return ''.join([BYTE_TO_BCHAR[b] for b in normalized.encode('utf-8')])
+ return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")])
def byte_decode(x: str) -> str:
try:
- return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode('utf-8')
+ return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
except ValueError:
- return ''
+ return ""
def smart_byte_decode(x: str) -> str:
output = byte_decode(x)
- if output == '':
+ if output == "":
# DP the best recovery (max valid chars) if it's broken
n_bytes = len(x)
f = [0 for _ in range(n_bytes + 1)]
@@ -41,11 +41,11 @@ def smart_byte_decode(x: str) -> str:
for i in range(1, n_bytes + 1):
f[i], pt[i] = f[i - 1], i - 1
for j in range(1, min(4, i) + 1):
- if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j: i])) > 0:
+ if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0:
f[i], pt[i] = f[i - j] + 1, i - j
cur_pt = n_bytes
while cur_pt > 0:
if f[cur_pt] == f[pt[cur_pt]] + 1:
- output = byte_decode(x[pt[cur_pt]: cur_pt]) + output
+ output = byte_decode(x[pt[cur_pt] : cur_pt]) + output
cur_pt = pt[cur_pt]
return output
diff --git a/fairseq/data/encoders/bytes.py b/fairseq/data/encoders/bytes.py
index 8bace19c53..bb9554ed53 100644
--- a/fairseq/data/encoders/bytes.py
+++ b/fairseq/data/encoders/bytes.py
@@ -5,11 +5,15 @@
from fairseq.data.encoders import register_bpe
-from fairseq.data.encoders.byte_utils import (byte_encode, smart_byte_decode,
- SPACE, SPACE_ESCAPE)
+from fairseq.data.encoders.byte_utils import (
+ SPACE,
+ SPACE_ESCAPE,
+ byte_encode,
+ smart_byte_decode,
+)
-@register_bpe('bytes')
+@register_bpe("bytes")
class Bytes(object):
def __init__(self, args):
pass
@@ -26,5 +30,5 @@ def encode(x: str) -> str:
@staticmethod
def decode(x: str) -> str:
- unescaped = x.replace(SPACE, '').replace(SPACE_ESCAPE, SPACE)
+ unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
return smart_byte_decode(unescaped)
diff --git a/fairseq/data/encoders/characters.py b/fairseq/data/encoders/characters.py
index db6a58a650..cffc57511c 100644
--- a/fairseq/data/encoders/characters.py
+++ b/fairseq/data/encoders/characters.py
@@ -6,11 +6,12 @@
from fairseq.data.encoders import register_bpe
+
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)
-@register_bpe('characters')
+@register_bpe("characters")
class Characters(object):
def __init__(self, args):
pass
@@ -26,4 +27,4 @@ def encode(x: str) -> str:
@staticmethod
def decode(x: str) -> str:
- return x.replace(SPACE, '').replace(SPACE_ESCAPE, SPACE)
+ return x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
diff --git a/fairseq/data/encoders/fastbpe.py b/fairseq/data/encoders/fastbpe.py
index ea0badd544..74d4ad8504 100644
--- a/fairseq/data/encoders/fastbpe.py
+++ b/fairseq/data/encoders/fastbpe.py
@@ -7,9 +7,8 @@
from fairseq.data.encoders import register_bpe
-@register_bpe('fastbpe')
+@register_bpe("fastbpe")
class fastBPE(object):
-
@staticmethod
def add_args(parser):
# fmt: off
@@ -19,17 +18,18 @@ def add_args(parser):
def __init__(self, args):
if args.bpe_codes is None:
- raise ValueError('--bpe-codes is required for --bpe=fastbpe')
+ raise ValueError("--bpe-codes is required for --bpe=fastbpe")
codes = file_utils.cached_path(args.bpe_codes)
try:
import fastBPE
+
self.bpe = fastBPE.fastBPE(codes)
self.bpe_symbol = "@@ "
except ImportError:
- raise ImportError('Please install fastBPE with: pip install fastBPE')
+ raise ImportError("Please install fastBPE with: pip install fastBPE")
def encode(self, x: str) -> str:
return self.bpe.apply([x])[0]
def decode(self, x: str) -> str:
- return (x + ' ').replace(self.bpe_symbol, '').rstrip()
+ return (x + " ").replace(self.bpe_symbol, "").rstrip()
diff --git a/fairseq/data/encoders/gpt2_bpe.py b/fairseq/data/encoders/gpt2_bpe.py
index 54e0593d00..8ac099a688 100644
--- a/fairseq/data/encoders/gpt2_bpe.py
+++ b/fairseq/data/encoders/gpt2_bpe.py
@@ -9,13 +9,12 @@
from .gpt2_bpe_utils import get_encoder
-DEFAULT_ENCODER_JSON = 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
-DEFAULT_VOCAB_BPE = 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
+DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
+DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
-@register_bpe('gpt2')
+@register_bpe("gpt2")
class GPT2BPE(object):
-
@staticmethod
def add_args(parser):
# fmt: off
@@ -29,21 +28,20 @@ def add_args(parser):
def __init__(self, args):
encoder_json = file_utils.cached_path(
- getattr(args, 'gpt2_encoder_json', DEFAULT_ENCODER_JSON)
+ getattr(args, "gpt2_encoder_json", DEFAULT_ENCODER_JSON)
)
vocab_bpe = file_utils.cached_path(
- getattr(args, 'gpt2_vocab_bpe', DEFAULT_VOCAB_BPE)
+ getattr(args, "gpt2_vocab_bpe", DEFAULT_VOCAB_BPE)
)
self.bpe = get_encoder(encoder_json, vocab_bpe)
def encode(self, x: str) -> str:
- return ' '.join(map(str, self.bpe.encode(x)))
+ return " ".join(map(str, self.bpe.encode(x)))
def decode(self, x: str) -> str:
- return self.bpe.decode([
- int(tok) if tok not in {'', ''} else tok
- for tok in x.split()
- ])
+ return self.bpe.decode(
+ [int(tok) if tok not in {"", ""} else tok for tok in x.split()]
+ )
def is_beginning_of_word(self, x: str) -> bool:
- return self.decode(x).startswith(' ')
+ return self.decode(x).startswith(" ")
diff --git a/fairseq/data/encoders/gpt2_bpe_utils.py b/fairseq/data/encoders/gpt2_bpe_utils.py
index 1917f82314..688d4e36e3 100644
--- a/fairseq/data/encoders/gpt2_bpe_utils.py
+++ b/fairseq/data/encoders/gpt2_bpe_utils.py
@@ -5,8 +5,8 @@
Original license: MIT
"""
-from functools import lru_cache
import json
+from functools import lru_cache
@lru_cache()
@@ -20,17 +20,22 @@ def bytes_to_unicode():
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
- bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ bs = (
+ list(range(ord("!"), ord("~") + 1))
+ + list(range(ord("¡"), ord("¬") + 1))
+ + list(range(ord("®"), ord("ÿ") + 1))
+ )
cs = bs[:]
n = 0
- for b in range(2**8):
+ for b in range(2 ** 8):
if b not in bs:
bs.append(b)
- cs.append(2**8+n)
+ cs.append(2 ** 8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
+
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
@@ -42,25 +47,28 @@ def get_pairs(word):
prev_char = char
return pairs
-class Encoder:
- def __init__(self, encoder, bpe_merges, errors='replace'):
+class Encoder:
+ def __init__(self, encoder, bpe_merges, errors="replace"):
self.encoder = encoder
- self.decoder = {v:k for k,v in self.encoder.items()}
- self.errors = errors # how to handle errors in decoding
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
- self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
try:
import regex as re
+
self.re = re
except ImportError:
- raise ImportError('Please install regex with: pip install regex')
+ raise ImportError("Please install regex with: pip install regex")
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
- self.pat = self.re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+ self.pat = self.re.compile(
+ r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+ )
def bpe(self, token):
if token in self.cache:
@@ -72,7 +80,7 @@ def bpe(self, token):
return token
while True:
- bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
@@ -87,8 +95,8 @@ def bpe(self, token):
new_word.extend(word[i:])
break
- if word[i] == first and i < len(word)-1 and word[i+1] == second:
- new_word.append(first+second)
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
@@ -99,28 +107,33 @@ def bpe(self, token):
break
else:
pairs = get_pairs(word)
- word = ' '.join(word)
+ word = " ".join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
for token in self.re.findall(self.pat, text):
- token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
- bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+ bpe_tokens.extend(
+ self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
+ )
return bpe_tokens
def decode(self, tokens):
- text = ''.join([self.decoder.get(token, token) for token in tokens])
- text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+ text = "".join([self.decoder.get(token, token) for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode(
+ "utf-8", errors=self.errors
+ )
return text
+
def get_encoder(encoder_json_path, vocab_bpe_path):
- with open(encoder_json_path, 'r') as f:
+ with open(encoder_json_path, "r") as f:
encoder = json.load(f)
- with open(vocab_bpe_path, 'r', encoding="utf-8") as f:
+ with open(vocab_bpe_path, "r", encoding="utf-8") as f:
bpe_data = f.read()
- bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
+ bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
return Encoder(
encoder=encoder,
bpe_merges=bpe_merges,
diff --git a/fairseq/data/encoders/hf_bert_bpe.py b/fairseq/data/encoders/hf_bert_bpe.py
index 16adc45aee..a968fe8857 100644
--- a/fairseq/data/encoders/hf_bert_bpe.py
+++ b/fairseq/data/encoders/hf_bert_bpe.py
@@ -6,9 +6,8 @@
from fairseq.data.encoders import register_bpe
-@register_bpe('bert')
+@register_bpe("bert")
class BertBPE(object):
-
@staticmethod
def add_args(parser):
# fmt: off
@@ -24,25 +23,26 @@ def __init__(self, args):
from transformers import BertTokenizer
except ImportError:
raise ImportError(
- 'Please install transformers with: pip install transformers'
+ "Please install transformers with: pip install transformers"
)
- if 'bpe_vocab_file' in args:
+ if "bpe_vocab_file" in args:
self.bert_tokenizer = BertTokenizer(
- args.bpe_vocab_file,
- do_lower_case=not args.bpe_cased
+ args.bpe_vocab_file, do_lower_case=not args.bpe_cased
)
else:
- vocab_file_name = 'bert-base-cased' if args.bpe_cased else 'bert-base-uncased'
+ vocab_file_name = (
+ "bert-base-cased" if args.bpe_cased else "bert-base-uncased"
+ )
self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name)
def encode(self, x: str) -> str:
- return ' '.join(self.bert_tokenizer.tokenize(x))
+ return " ".join(self.bert_tokenizer.tokenize(x))
def decode(self, x: str) -> str:
return self.bert_tokenizer.clean_up_tokenization(
- self.bert_tokenizer.convert_tokens_to_string(x.split(' '))
+ self.bert_tokenizer.convert_tokens_to_string(x.split(" "))
)
def is_beginning_of_word(self, x: str) -> bool:
- return not x.startswith('##')
+ return not x.startswith("##")
diff --git a/fairseq/data/encoders/hf_byte_bpe.py b/fairseq/data/encoders/hf_byte_bpe.py
index 2767df044e..544d408273 100644
--- a/fairseq/data/encoders/hf_byte_bpe.py
+++ b/fairseq/data/encoders/hf_byte_bpe.py
@@ -6,9 +6,8 @@
from fairseq.data.encoders import register_bpe
-@register_bpe('hf_byte_bpe')
+@register_bpe("hf_byte_bpe")
class HuggingFaceByteLevelBPE(object):
-
@staticmethod
def add_args(parser):
# fmt: off
@@ -23,24 +22,22 @@ def __init__(self, args):
from tokenizers import ByteLevelBPETokenizer
except ImportError:
raise ImportError(
- 'Please install huggingface/tokenizers with: '
- 'pip install tokenizers'
+ "Please install huggingface/tokenizers with: " "pip install tokenizers"
)
self.bpe = ByteLevelBPETokenizer(
args.bpe_vocab,
args.bpe_merges,
- add_prefix_space=getattr(args, 'bpe_add_prefix_space', False),
+ add_prefix_space=getattr(args, "bpe_add_prefix_space", False),
)
def encode(self, x: str) -> str:
- return ' '.join(map(str, self.bpe.encode(x).ids))
+ return " ".join(map(str, self.bpe.encode(x).ids))
def decode(self, x: str) -> str:
- return self.bpe.decode([
- int(tok) if tok not in {'', ''} else tok
- for tok in x.split()
- ])
+ return self.bpe.decode(
+ [int(tok) if tok not in {"", ""} else tok for tok in x.split()]
+ )
def is_beginning_of_word(self, x: str) -> bool:
- return self.decode(x).startswith(' ')
+ return self.decode(x).startswith(" ")
diff --git a/fairseq/data/encoders/moses_tokenizer.py b/fairseq/data/encoders/moses_tokenizer.py
index b1e7478b9d..8c24844263 100644
--- a/fairseq/data/encoders/moses_tokenizer.py
+++ b/fairseq/data/encoders/moses_tokenizer.py
@@ -6,9 +6,8 @@
from fairseq.data.encoders import register_tokenizer
-@register_tokenizer('moses')
+@register_tokenizer("moses")
class MosesTokenizer(object):
-
@staticmethod
def add_args(parser):
# fmt: off
@@ -25,17 +24,20 @@ def add_args(parser):
def __init__(self, args):
self.args = args
- if getattr(args, 'moses_source_lang', None) is None:
- args.moses_source_lang = getattr(args, 'source_lang', 'en')
- if getattr(args, 'moses_target_lang', None) is None:
- args.moses_target_lang = getattr(args, 'target_lang', 'en')
+ if getattr(args, "moses_source_lang", None) is None:
+ args.moses_source_lang = getattr(args, "source_lang", "en")
+ if getattr(args, "moses_target_lang", None) is None:
+ args.moses_target_lang = getattr(args, "target_lang", "en")
try:
from sacremoses import MosesTokenizer, MosesDetokenizer
+
self.tok = MosesTokenizer(args.moses_source_lang)
self.detok = MosesDetokenizer(args.moses_target_lang)
except ImportError:
- raise ImportError('Please install Moses tokenizer with: pip install sacremoses')
+ raise ImportError(
+ "Please install Moses tokenizer with: pip install sacremoses"
+ )
def encode(self, x: str) -> str:
return self.tok.tokenize(
diff --git a/fairseq/data/encoders/nltk_tokenizer.py b/fairseq/data/encoders/nltk_tokenizer.py
index 3db8ee5652..3b617e7314 100644
--- a/fairseq/data/encoders/nltk_tokenizer.py
+++ b/fairseq/data/encoders/nltk_tokenizer.py
@@ -6,18 +6,18 @@
from fairseq.data.encoders import register_tokenizer
-@register_tokenizer('nltk')
+@register_tokenizer("nltk")
class NLTKTokenizer(object):
-
def __init__(self, source_lang=None, target_lang=None):
try:
from nltk.tokenize import word_tokenize
+
self.word_tokenize = word_tokenize
except ImportError:
- raise ImportError('Please install nltk with: pip install nltk')
+ raise ImportError("Please install nltk with: pip install nltk")
def encode(self, x: str) -> str:
- return ' '.join(self.word_tokenize(x))
+ return " ".join(self.word_tokenize(x))
def decode(self, x: str) -> str:
return x
diff --git a/fairseq/data/encoders/sentencepiece_bpe.py b/fairseq/data/encoders/sentencepiece_bpe.py
index e5ff5db389..b25c6caebe 100644
--- a/fairseq/data/encoders/sentencepiece_bpe.py
+++ b/fairseq/data/encoders/sentencepiece_bpe.py
@@ -7,9 +7,8 @@
from fairseq.data.encoders import register_bpe
-@register_bpe('sentencepiece')
+@register_bpe("sentencepiece")
class SentencepieceBPE(object):
-
@staticmethod
def add_args(parser):
# fmt: off
@@ -21,23 +20,26 @@ def __init__(self, args):
sentencepiece_model = file_utils.cached_path(args.sentencepiece_model)
try:
import sentencepiece as spm
+
self.sp = spm.SentencePieceProcessor()
self.sp.Load(sentencepiece_model)
except ImportError:
- raise ImportError('Please install sentencepiece with: pip install sentencepiece')
+ raise ImportError(
+ "Please install sentencepiece with: pip install sentencepiece"
+ )
def encode(self, x: str) -> str:
- return ' '.join(self.sp.EncodeAsPieces(x))
+ return " ".join(self.sp.EncodeAsPieces(x))
def decode(self, x: str) -> str:
- return x.replace(' ', '').replace('\u2581', ' ').strip()
+ return x.replace(" ", "").replace("\u2581", " ").strip()
def is_beginning_of_word(self, x: str) -> bool:
- if x in ['', '', '', '']:
+ if x in ["", "", "", ""]:
# special elements are always considered beginnings
# HACK: this logic is already present in fairseq/tasks/masked_lm.py
# but these special tokens are also contained in the sentencepiece
# vocabulary which causes duplicate special tokens. This hack makes
# sure that they are all taken into account.
return True
- return x.startswith('\u2581')
+ return x.startswith("\u2581")
diff --git a/fairseq/data/encoders/space_tokenizer.py b/fairseq/data/encoders/space_tokenizer.py
index 670001a8e8..3bc7ce4958 100644
--- a/fairseq/data/encoders/space_tokenizer.py
+++ b/fairseq/data/encoders/space_tokenizer.py
@@ -8,14 +8,13 @@
from fairseq.data.encoders import register_tokenizer
-@register_tokenizer('space')
+@register_tokenizer("space")
class SpaceTokenizer(object):
-
def __init__(self, source_lang=None, target_lang=None):
self.space_tok = re.compile(r"\s+")
def encode(self, x: str) -> str:
- return self.space_tok.sub(' ', x)
+ return self.space_tok.sub(" ", x)
def decode(self, x: str) -> str:
return x
diff --git a/fairseq/data/encoders/subword_nmt_bpe.py b/fairseq/data/encoders/subword_nmt_bpe.py
index 78f19b43ea..e85f99af39 100644
--- a/fairseq/data/encoders/subword_nmt_bpe.py
+++ b/fairseq/data/encoders/subword_nmt_bpe.py
@@ -7,9 +7,8 @@
from fairseq.data.encoders import register_bpe
-@register_bpe('subword_nmt')
+@register_bpe("subword_nmt")
class SubwordNMTBPE(object):
-
@staticmethod
def add_args(parser):
# fmt: off
@@ -21,15 +20,20 @@ def add_args(parser):
def __init__(self, args):
if args.bpe_codes is None:
- raise ValueError('--bpe-codes is required for --bpe=subword_nmt')
+ raise ValueError("--bpe-codes is required for --bpe=subword_nmt")
codes = file_utils.cached_path(args.bpe_codes)
try:
from subword_nmt import apply_bpe
+
bpe_parser = apply_bpe.create_parser()
- bpe_args = bpe_parser.parse_args([
- '--codes', codes,
- '--separator', args.bpe_separator,
- ])
+ bpe_args = bpe_parser.parse_args(
+ [
+ "--codes",
+ codes,
+ "--separator",
+ args.bpe_separator,
+ ]
+ )
self.bpe = apply_bpe.BPE(
bpe_args.codes,
bpe_args.merges,
@@ -37,12 +41,14 @@ def __init__(self, args):
None,
bpe_args.glossaries,
)
- self.bpe_symbol = bpe_args.separator + ' '
+ self.bpe_symbol = bpe_args.separator + " "
except ImportError:
- raise ImportError('Please install subword_nmt with: pip install subword-nmt')
+ raise ImportError(
+ "Please install subword_nmt with: pip install subword-nmt"
+ )
def encode(self, x: str) -> str:
return self.bpe.process_line(x)
def decode(self, x: str) -> str:
- return (x + ' ').replace(self.bpe_symbol, '').rstrip()
+ return (x + " ").replace(self.bpe_symbol, "").rstrip()
diff --git a/fairseq/data/encoders/utils.py b/fairseq/data/encoders/utils.py
index a0e491c143..d93eb532ef 100644
--- a/fairseq/data/encoders/utils.py
+++ b/fairseq/data/encoders/utils.py
@@ -10,19 +10,21 @@
def get_whole_word_mask(args, dictionary):
bpe = encoders.build_bpe(args)
if bpe is not None:
+
def is_beginning_of_word(i):
if i < dictionary.nspecial:
# special elements are always considered beginnings
return True
tok = dictionary[i]
- if tok.startswith('madeupword'):
+ if tok.startswith("madeupword"):
return True
try:
return bpe.is_beginning_of_word(tok)
except ValueError:
return True
- mask_whole_words = torch.ByteTensor(list(
- map(is_beginning_of_word, range(len(dictionary)))
- ))
+
+ mask_whole_words = torch.ByteTensor(
+ list(map(is_beginning_of_word, range(len(dictionary))))
+ )
return mask_whole_words
return None
diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py
index caaef8f713..ed08c1ba20 100644
--- a/fairseq/data/fairseq_dataset.py
+++ b/fairseq/data/fairseq_dataset.py
@@ -5,7 +5,6 @@
import numpy as np
import torch.utils.data
-
from fairseq.data import data_utils
@@ -112,7 +111,7 @@ def batch_by_size(
def adjust_bsz(bsz, num_tokens):
if bsz is None:
- assert max_tokens is not None, 'Must specify --max-tokens'
+ assert max_tokens is not None, "Must specify --max-tokens"
bsz = max_tokens // num_tokens
if max_sentences is not None:
bsz = min(bsz, max_sentences)
@@ -120,13 +119,15 @@ def adjust_bsz(bsz, num_tokens):
bsz >= required_batch_size_multiple
and bsz % required_batch_size_multiple != 0
):
- bsz -= (bsz % required_batch_size_multiple)
+ bsz -= bsz % required_batch_size_multiple
return bsz
- fixed_shapes = np.array([
- [adjust_bsz(bsz, num_tokens), num_tokens]
- for (bsz, num_tokens) in fixed_shapes
- ])
+ fixed_shapes = np.array(
+ [
+ [adjust_bsz(bsz, num_tokens), num_tokens]
+ for (bsz, num_tokens) in fixed_shapes
+ ]
+ )
return data_utils.batch_by_size(
indices,
@@ -154,16 +155,24 @@ def filter_indices_by_size(self, indices, max_sizes):
list: list of removed indices
"""
if isinstance(max_sizes, float) or isinstance(max_sizes, int):
- if hasattr(self, 'sizes') and isinstance(self.sizes, np.ndarray):
+ if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray):
ignored = indices[self.sizes[indices] > max_sizes].tolist()
indices = indices[self.sizes[indices] <= max_sizes]
- elif hasattr(self, 'sizes') and isinstance(self.sizes, list) and len(self.sizes) == 1:
+ elif (
+ hasattr(self, "sizes")
+ and isinstance(self.sizes, list)
+ and len(self.sizes) == 1
+ ):
ignored = indices[self.sizes[0][indices] > max_sizes].tolist()
indices = indices[self.sizes[0][indices] <= max_sizes]
else:
- indices, ignored = data_utils._filter_by_size_dynamic(indices, self.size, max_sizes)
+ indices, ignored = data_utils._filter_by_size_dynamic(
+ indices, self.size, max_sizes
+ )
else:
- indices, ignored = data_utils._filter_by_size_dynamic(indices, self.size, max_sizes)
+ indices, ignored = data_utils._filter_by_size_dynamic(
+ indices, self.size, max_sizes
+ )
return indices, ignored
@property
diff --git a/fairseq/data/id_dataset.py b/fairseq/data/id_dataset.py
index 6a73ba1ff7..3e4d7969cf 100644
--- a/fairseq/data/id_dataset.py
+++ b/fairseq/data/id_dataset.py
@@ -9,7 +9,6 @@
class IdDataset(FairseqDataset):
-
def __getitem__(self, index):
return index
diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py
index 55bf0ca585..3efecab3a6 100644
--- a/fairseq/data/indexed_dataset.py
+++ b/fairseq/data/indexed_dataset.py
@@ -3,18 +3,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from functools import lru_cache
import os
import shutil
import struct
+from functools import lru_cache
import numpy as np
import torch
-
-from . import FairseqDataset
from fairseq.data.fasta_dataset import FastaDataset
from fairseq.file_io import PathManager
+from . import FairseqDataset
+
def __best_fitting_dtype(vocab_size=None):
if vocab_size is not None and vocab_size < 65500:
@@ -24,56 +24,59 @@ def __best_fitting_dtype(vocab_size=None):
def get_available_dataset_impl():
- return ['raw', 'lazy', 'cached', 'mmap', 'fasta']
+ return ["raw", "lazy", "cached", "mmap", "fasta"]
def infer_dataset_impl(path):
if IndexedRawTextDataset.exists(path):
- return 'raw'
+ return "raw"
elif IndexedDataset.exists(path):
- with open(index_file_path(path), 'rb') as f:
+ with open(index_file_path(path), "rb") as f:
magic = f.read(8)
if magic == IndexedDataset._HDR_MAGIC:
- return 'cached'
+ return "cached"
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
- return 'mmap'
+ return "mmap"
else:
return None
elif FastaDataset.exists(path):
- return 'fasta'
+ return "fasta"
else:
return None
def make_builder(out_file, impl, vocab_size=None):
- if impl == 'mmap':
- return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
- elif impl == 'fasta':
+ if impl == "mmap":
+ return MMapIndexedDatasetBuilder(
+ out_file, dtype=__best_fitting_dtype(vocab_size)
+ )
+ elif impl == "fasta":
raise NotImplementedError
else:
return IndexedDatasetBuilder(out_file)
def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
- if impl == 'raw' and IndexedRawTextDataset.exists(path):
+ if impl == "raw" and IndexedRawTextDataset.exists(path):
assert dictionary is not None
return IndexedRawTextDataset(path, dictionary)
- elif impl == 'lazy' and IndexedDataset.exists(path):
+ elif impl == "lazy" and IndexedDataset.exists(path):
return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing)
- elif impl == 'cached' and IndexedDataset.exists(path):
+ elif impl == "cached" and IndexedDataset.exists(path):
return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
- elif impl == 'mmap' and MMapIndexedDataset.exists(path):
+ elif impl == "mmap" and MMapIndexedDataset.exists(path):
return MMapIndexedDataset(path)
- elif impl == 'fasta' and FastaDataset.exists(path):
+ elif impl == "fasta" and FastaDataset.exists(path):
from fairseq.data.fasta_dataset import EncodedFastaDataset
+
return EncodedFastaDataset(path, dictionary)
return None
def dataset_exists(path, impl):
- if impl == 'raw':
+ if impl == "raw":
return IndexedRawTextDataset.exists(path)
- elif impl == 'mmap':
+ elif impl == "mmap":
return MMapIndexedDataset.exists(path)
else:
return IndexedDataset.exists(path)
@@ -97,7 +100,7 @@ def write_longs(f, a):
5: np.int64,
6: np.float,
7: np.double,
- 8: np.uint16
+ 8: np.uint16,
}
@@ -109,16 +112,17 @@ def code(dtype):
def index_file_path(prefix_path):
- return prefix_path + '.idx'
+ return prefix_path + ".idx"
def data_file_path(prefix_path):
- return prefix_path + '.bin'
+ return prefix_path + ".bin"
class IndexedDataset(FairseqDataset):
"""Loader for TorchNet IndexedDataset"""
- _HDR_MAGIC = b'TNTIDX\x00\x00'
+
+ _HDR_MAGIC = b"TNTIDX\x00\x00"
def __init__(self, path, fix_lua_indexing=False):
super().__init__()
@@ -128,27 +132,27 @@ def __init__(self, path, fix_lua_indexing=False):
self.read_index(path)
def read_index(self, path):
- with open(index_file_path(path), 'rb') as f:
+ with open(index_file_path(path), "rb") as f:
magic = f.read(8)
assert magic == self._HDR_MAGIC, (
- 'Index file doesn\'t match expected format. '
- 'Make sure that --dataset-impl is configured properly.'
+ "Index file doesn't match expected format. "
+ "Make sure that --dataset-impl is configured properly."
)
version = f.read(8)
- assert struct.unpack('= self._len:
- raise IndexError('index out of range')
+ raise IndexError("index out of range")
def __del__(self):
if self.data_file:
@@ -159,7 +163,7 @@ def __getitem__(self, i):
if not self.data_file:
self.read_data(self.path)
self.check_index(i)
- tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
+ tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
@@ -179,8 +183,8 @@ def size(self, index):
@staticmethod
def exists(path):
- return (
- PathManager.exists(index_file_path(path)) and PathManager.exists(data_file_path(path))
+ return PathManager.exists(index_file_path(path)) and PathManager.exists(
+ data_file_path(path)
)
@property
@@ -189,7 +193,6 @@ def supports_prefetch(self):
class IndexedCachedDataset(IndexedDataset):
-
def __init__(self, path, fix_lua_indexing=False):
super().__init__(path, fix_lua_indexing=fix_lua_indexing)
self.cache = None
@@ -214,7 +217,7 @@ def prefetch(self, indices):
for i in indices:
self.cache_index[i] = ptx
size = self.data_offsets[i + 1] - self.data_offsets[i]
- a = self.cache[ptx: ptx + size]
+ a = self.cache[ptx : ptx + size]
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
ptx += size
@@ -226,10 +229,10 @@ def prefetch(self, indices):
@lru_cache(maxsize=8)
def __getitem__(self, i):
self.check_index(i)
- tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
+ tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
ptx = self.cache_index[i]
- np.copyto(a, self.cache[ptx: ptx + a.size])
+ np.copyto(a, self.cache[ptx : ptx + a.size])
item = torch.from_numpy(a).long()
if self.fix_lua_indexing:
item -= 1 # subtract 1 for 0-based indexing
@@ -250,12 +253,14 @@ def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
self.size = len(self.tokens_list)
def read_data(self, path, dictionary):
- with open(path, 'r', encoding='utf-8') as f:
+ with open(path, "r", encoding="utf-8") as f:
for line in f:
- self.lines.append(line.strip('\n'))
+ self.lines.append(line.strip("\n"))
tokens = dictionary.encode_line(
- line, add_if_not_exist=False,
- append_eos=self.append_eos, reverse_order=self.reverse_order,
+ line,
+ add_if_not_exist=False,
+ append_eos=self.append_eos,
+ reverse_order=self.reverse_order,
).long()
self.tokens_list.append(tokens)
self.sizes.append(len(tokens))
@@ -263,7 +268,7 @@ def read_data(self, path, dictionary):
def check_index(self, i):
if i < 0 or i >= self.size:
- raise IndexError('index out of range')
+ raise IndexError("index out of range")
@lru_cache(maxsize=8)
def __getitem__(self, i):
@@ -299,11 +304,11 @@ class IndexedDatasetBuilder(object):
np.int32: 4,
np.int64: 8,
np.float: 4,
- np.double: 8
+ np.double: 8,
}
def __init__(self, out_file, dtype=np.int32):
- self.out_file = open(out_file, 'wb')
+ self.out_file = open(out_file, "wb")
self.dtype = dtype
self.data_offsets = [0]
self.dim_offsets = [0]
@@ -330,7 +335,7 @@ def merge_file_(self, another_file):
for dim_offset in index.dim_offsets[1:]:
self.dim_offsets.append(begin + dim_offset)
- with open(data_file_path(another_file), 'rb') as f:
+ with open(data_file_path(another_file), "rb") as f:
while True:
data = f.read(1024)
if data:
@@ -340,11 +345,11 @@ def merge_file_(self, another_file):
def finalize(self, index_file):
self.out_file.close()
- index = open(index_file, 'wb')
- index.write(b'TNTIDX\x00\x00')
- index.write(struct.pack('= self.total:
raise RuntimeError(
- 'Mismatch between actual and expected iterable length. '
- 'Please report this to the fairseq developers.'
+ "Mismatch between actual and expected iterable length. "
+ "Please report this to the fairseq developers."
)
self.n += 1
yield x
@@ -138,7 +137,11 @@ def load_state_dict(self, state_dict):
class StreamingEpochBatchIterator(EpochBatchIterating):
def __init__(
- self, dataset, epoch=1, num_shards=1, shard_id=0,
+ self,
+ dataset,
+ epoch=1,
+ num_shards=1,
+ shard_id=0,
):
assert isinstance(dataset, torch.utils.data.IterableDataset)
self.dataset = dataset
@@ -178,11 +181,11 @@ def iterations_in_epoch(self) -> int:
def state_dict(self):
return {
- 'epoch': self.epoch,
+ "epoch": self.epoch,
}
def load_state_dict(self, state_dict):
- self.epoch = state_dict['epoch']
+ self.epoch = state_dict["epoch"]
class EpochBatchIterator(EpochBatchIterating):
@@ -222,14 +225,25 @@ class EpochBatchIterator(EpochBatchIterating):
"""
def __init__(
- self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
- num_workers=0, epoch=1, buffer_size=0, timeout=0,
+ self,
+ dataset,
+ collate_fn,
+ batch_sampler,
+ seed=1,
+ num_shards=1,
+ shard_id=0,
+ num_workers=0,
+ epoch=1,
+ buffer_size=0,
+ timeout=0,
):
assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
self.collate_fn = collate_fn
self.batch_sampler = batch_sampler
- self._frozen_batches = tuple(batch_sampler) if not callable(batch_sampler) else None
+ self._frozen_batches = (
+ tuple(batch_sampler) if not callable(batch_sampler) else None
+ )
self.seed = seed
self.num_shards = num_shards
self.shard_id = shard_id
@@ -243,7 +257,7 @@ def __init__(
self.shuffle = True
self._cur_epoch_itr = None
self._next_epoch_itr = None
- self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)
+ self._supports_prefetch = getattr(dataset, "supports_prefetch", False)
@property
def frozen_batches(self):
@@ -303,7 +317,9 @@ def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
# reset _frozen_batches to refresh the next epoch
self._frozen_batches = None
self._cur_epoch_itr = self._get_iterator_for_epoch(
- self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
+ self.epoch,
+ shuffle,
+ fix_batches_to_gpus=fix_batches_to_gpus,
)
self.shuffle = shuffle
return self._cur_epoch_itr
@@ -330,22 +346,22 @@ def state_dict(self):
epoch = self.epoch
iter_in_epoch = self.iterations_in_epoch
return {
- 'version': 2,
- 'epoch': epoch,
- 'iterations_in_epoch': iter_in_epoch,
- 'shuffle': self.shuffle,
+ "version": 2,
+ "epoch": epoch,
+ "iterations_in_epoch": iter_in_epoch,
+ "shuffle": self.shuffle,
}
def load_state_dict(self, state_dict):
"""Copies the state of the iterator from the given *state_dict*."""
- self.epoch = state_dict['epoch']
- itr_pos = state_dict.get('iterations_in_epoch', 0)
- version = state_dict.get('version', 1)
+ self.epoch = state_dict["epoch"]
+ itr_pos = state_dict.get("iterations_in_epoch", 0)
+ version = state_dict.get("version", 1)
if itr_pos > 0:
# fast-forward epoch iterator
self._next_epoch_itr = self._get_iterator_for_epoch(
self.epoch,
- shuffle=state_dict.get('shuffle', True),
+ shuffle=state_dict.get("shuffle", True),
offset=itr_pos,
)
if self._next_epoch_itr is None:
@@ -354,15 +370,16 @@ def load_state_dict(self, state_dict):
self.epoch += 1
else:
raise RuntimeError(
- 'Cannot resume training due to dataloader mismatch, please '
- 'report this to the fairseq developers. You can relaunch '
- 'training with `--reset-dataloader` and it should work.'
+ "Cannot resume training due to dataloader mismatch, please "
+ "report this to the fairseq developers. You can relaunch "
+ "training with `--reset-dataloader` and it should work."
)
else:
self._next_epoch_itr = None
- def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False, offset=0):
-
+ def _get_iterator_for_epoch(
+ self, epoch, shuffle, fix_batches_to_gpus=False, offset=0
+ ):
def shuffle_batches(batches, seed):
with data_utils.numpy_seed(seed):
np.random.shuffle(batches)
@@ -374,9 +391,9 @@ def shuffle_batches(batches, seed):
if shuffle and not fix_batches_to_gpus:
batches = shuffle_batches(list(batches), self.seed + epoch)
- batches = list(ShardedIterator(
- batches, self.num_shards, self.shard_id, fill_value=[]
- ))
+ batches = list(
+ ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
+ )
self.dataset.prefetch([i for s in batches for i in s])
if shuffle and fix_batches_to_gpus:
@@ -386,15 +403,15 @@ def shuffle_batches(batches, seed):
batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch)
else:
batches = self.frozen_batches
- batches = list(ShardedIterator(
- batches, self.num_shards, self.shard_id, fill_value=[]
- ))
+ batches = list(
+ ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
+ )
if offset > 0 and offset >= len(batches):
return None
if self.num_workers > 0:
- os.environ['PYTHONWARNINGS'] = 'ignore:semaphore_tracker:UserWarning'
+ os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"
# Create data loader
itr = torch.utils.data.DataLoader(
@@ -429,7 +446,7 @@ def __init__(self, iterable, chunk_size):
itr = _chunk_iterator(iterable, chunk_size)
super().__init__(
itr,
- start=int(math.ceil(getattr(iterable, 'n', 0) / float(chunk_size))),
+ start=int(math.ceil(getattr(iterable, "n", 0) / float(chunk_size))),
total=int(math.ceil(len(iterable) / float(chunk_size))),
)
self.chunk_size = chunk_size
@@ -462,7 +479,7 @@ class ShardedIterator(CountingIterator):
def __init__(self, iterable, num_shards, shard_id, fill_value=None):
if shard_id < 0 or shard_id >= num_shards:
- raise ValueError('shard_id must be between 0 and num_shards')
+ raise ValueError("shard_id must be between 0 and num_shards")
sharded_len = int(math.ceil(len(iterable) / float(num_shards)))
itr = map(
operator.itemgetter(1),
@@ -474,7 +491,7 @@ def __init__(self, iterable, num_shards, shard_id, fill_value=None):
)
super().__init__(
itr,
- start=int(math.ceil(getattr(iterable, 'n', 0) / float(num_shards))),
+ start=int(math.ceil(getattr(iterable, "n", 0) / float(num_shards))),
total=sharded_len,
)
@@ -545,7 +562,10 @@ def __next__(self):
# Notify the user if there is a data loading bottleneck
if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)):
if time.time() - self.start_time > 5 * 60:
- if self.warning_time is None or time.time() - self.warning_time > 15 * 60:
+ if (
+ self.warning_time is None
+ or time.time() - self.warning_time > 15 * 60
+ ):
logger.debug(
"Data loading buffer is empty or nearly empty. This may "
"indicate a data loading bottleneck, and increasing the "
diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py
index 3014354e7c..62e7109b33 100644
--- a/fairseq/data/language_pair_dataset.py
+++ b/fairseq/data/language_pair_dataset.py
@@ -7,8 +7,7 @@
import numpy as np
import torch
-
-from fairseq.data import data_utils, FairseqDataset
+from fairseq.data import FairseqDataset, data_utils
logger = logging.getLogger(__name__)
@@ -30,7 +29,10 @@ def collate(
def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
return data_utils.collate_tokens(
[s[key] for s in samples],
- pad_idx, eos_idx, left_pad, move_eos_to_beginning,
+ pad_idx,
+ eos_idx,
+ left_pad,
+ move_eos_to_beginning,
pad_to_length=pad_to_length,
pad_to_multiple=pad_to_multiple,
)
@@ -38,7 +40,10 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
def check_alignment(alignment, src_len, tgt_len):
if alignment is None or len(alignment) == 0:
return False
- if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1:
+ if (
+ alignment[:, 0].max().item() >= src_len - 1
+ or alignment[:, 1].max().item() >= tgt_len - 1
+ ):
logger.warning("alignment size mismatch found, skipping alignment!")
return False
return True
@@ -53,78 +58,90 @@ def compute_alignment_weights(alignments):
index 3 is repeated twice)
"""
align_tgt = alignments[:, 1]
- _, align_tgt_i, align_tgt_c = torch.unique(align_tgt, return_inverse=True, return_counts=True)
+ _, align_tgt_i, align_tgt_c = torch.unique(
+ align_tgt, return_inverse=True, return_counts=True
+ )
align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
- return 1. / align_weights.float()
+ return 1.0 / align_weights.float()
- id = torch.LongTensor([s['id'] for s in samples])
+ id = torch.LongTensor([s["id"] for s in samples])
src_tokens = merge(
- 'source', left_pad=left_pad_source,
- pad_to_length=pad_to_length['source'] if pad_to_length is not None else None
+ "source",
+ left_pad=left_pad_source,
+ pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
)
# sort by descending source length
- src_lengths = torch.LongTensor([
- s['source'].ne(pad_idx).long().sum() for s in samples
- ])
+ src_lengths = torch.LongTensor(
+ [s["source"].ne(pad_idx).long().sum() for s in samples]
+ )
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = None
target = None
- if samples[0].get('target', None) is not None:
+ if samples[0].get("target", None) is not None:
target = merge(
- 'target', left_pad=left_pad_target,
- pad_to_length=pad_to_length['target'] if pad_to_length is not None else None,
+ "target",
+ left_pad=left_pad_target,
+ pad_to_length=pad_to_length["target"]
+ if pad_to_length is not None
+ else None,
)
target = target.index_select(0, sort_order)
- tgt_lengths = torch.LongTensor([
- s['target'].ne(pad_idx).long().sum() for s in samples
- ]).index_select(0, sort_order)
+ tgt_lengths = torch.LongTensor(
+ [s["target"].ne(pad_idx).long().sum() for s in samples]
+ ).index_select(0, sort_order)
ntokens = tgt_lengths.sum().item()
- if samples[0].get('prev_output_tokens', None) is not None:
- prev_output_tokens = merge('prev_output_tokens', left_pad=left_pad_target)
+ if samples[0].get("prev_output_tokens", None) is not None:
+ prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
elif input_feeding:
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
- 'target',
+ "target",
left_pad=left_pad_target,
move_eos_to_beginning=True,
- pad_to_length=pad_to_length['target'] if pad_to_length is not None else None,
+ pad_to_length=pad_to_length["target"]
+ if pad_to_length is not None
+ else None,
)
else:
ntokens = src_lengths.sum().item()
batch = {
- 'id': id,
- 'nsentences': len(samples),
- 'ntokens': ntokens,
- 'net_input': {
- 'src_tokens': src_tokens,
- 'src_lengths': src_lengths,
+ "id": id,
+ "nsentences": len(samples),
+ "ntokens": ntokens,
+ "net_input": {
+ "src_tokens": src_tokens,
+ "src_lengths": src_lengths,
},
- 'target': target,
+ "target": target,
}
if prev_output_tokens is not None:
- batch['net_input']['prev_output_tokens'] = prev_output_tokens.index_select(0, sort_order)
+ batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
+ 0, sort_order
+ )
- if samples[0].get('alignment', None) is not None:
- bsz, tgt_sz = batch['target'].shape
- src_sz = batch['net_input']['src_tokens'].shape[1]
+ if samples[0].get("alignment", None) is not None:
+ bsz, tgt_sz = batch["target"].shape
+ src_sz = batch["net_input"]["src_tokens"].shape[1]
offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
- offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz)
+ offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz
if left_pad_source:
- offsets[:, 0] += (src_sz - src_lengths)
+ offsets[:, 0] += src_sz - src_lengths
if left_pad_target:
- offsets[:, 1] += (tgt_sz - tgt_lengths)
+ offsets[:, 1] += tgt_sz - tgt_lengths
alignments = [
alignment + offset
- for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths)
- for alignment in [samples[align_idx]['alignment'].view(-1, 2)]
+ for align_idx, offset, src_len, tgt_len in zip(
+ sort_order, offsets, src_lengths, tgt_lengths
+ )
+ for alignment in [samples[align_idx]["alignment"].view(-1, 2)]
if check_alignment(alignment, src_len, tgt_len)
]
@@ -132,8 +149,8 @@ def compute_alignment_weights(alignments):
alignments = torch.cat(alignments, dim=0)
align_weights = compute_alignment_weights(alignments)
- batch['alignments'] = alignments
- batch['align_weights'] = align_weights
+ batch["alignments"] = alignments
+ batch["align_weights"] = align_weights
if samples[0].get("constraints", None) is not None:
# Collate the packed constraints across the samples, padding to
@@ -142,7 +159,7 @@ def compute_alignment_weights(alignments):
max_len = max(lens)
constraints = torch.zeros((len(samples), max(lens))).long()
for i, sample in enumerate(samples):
- constraints[i, 0:lens[i]] = samples[i].get("constraints")
+ constraints[i, 0 : lens[i]] = samples[i].get("constraints")
batch["constraints"] = constraints
return batch
@@ -188,14 +205,23 @@ class LanguagePairDataset(FairseqDataset):
"""
def __init__(
- self, src, src_sizes, src_dict,
- tgt=None, tgt_sizes=None, tgt_dict=None,
- left_pad_source=True, left_pad_target=False,
- shuffle=True, input_feeding=True,
- remove_eos_from_source=False, append_eos_to_target=False,
+ self,
+ src,
+ src_sizes,
+ src_dict,
+ tgt=None,
+ tgt_sizes=None,
+ tgt_dict=None,
+ left_pad_source=True,
+ left_pad_target=False,
+ shuffle=True,
+ input_feeding=True,
+ remove_eos_from_source=False,
+ append_eos_to_target=False,
align_dataset=None,
constraints=None,
- append_bos=False, eos=None,
+ append_bos=False,
+ eos=None,
num_buckets=0,
src_lang_id=None,
tgt_lang_id=None,
@@ -206,12 +232,18 @@ def __init__(
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
if tgt is not None:
- assert len(src) == len(tgt), "Source and target must contain the same number of examples"
+ assert len(src) == len(
+ tgt
+ ), "Source and target must contain the same number of examples"
self.src = src
self.tgt = tgt
self.src_sizes = np.array(src_sizes)
self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
- self.sizes = np.vstack((self.src_sizes, self.tgt_sizes)).T if self.tgt_sizes is not None else self.src_sizes
+ self.sizes = (
+ np.vstack((self.src_sizes, self.tgt_sizes)).T
+ if self.tgt_sizes is not None
+ else self.src_sizes
+ )
self.src_dict = src_dict
self.tgt_dict = tgt_dict
self.left_pad_source = left_pad_source
@@ -222,14 +254,17 @@ def __init__(
self.append_eos_to_target = append_eos_to_target
self.align_dataset = align_dataset
if self.align_dataset is not None:
- assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
+ assert (
+ self.tgt_sizes is not None
+ ), "Both source and target needed when alignments are provided"
self.constraints = constraints
self.append_bos = append_bos
- self.eos = (eos if eos is not None else src_dict.eos())
+ self.eos = eos if eos is not None else src_dict.eos()
self.src_lang_id = src_lang_id
self.tgt_lang_id = tgt_lang_id
if num_buckets > 0:
from fairseq.data import BucketPadLengthDataset
+
self.src = BucketPadLengthDataset(
self.src,
sizes=self.src_sizes,
@@ -238,7 +273,7 @@ def __init__(
left_pad=self.left_pad_source,
)
self.src_sizes = self.src.sizes
- logger.info('bucketing source lengths: {}'.format(list(self.src.buckets)))
+ logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
if self.tgt is not None:
self.tgt = BucketPadLengthDataset(
self.tgt,
@@ -248,15 +283,16 @@ def __init__(
left_pad=self.left_pad_target,
)
self.tgt_sizes = self.tgt.sizes
- logger.info('bucketing target lengths: {}'.format(list(self.tgt.buckets)))
+ logger.info(
+ "bucketing target lengths: {}".format(list(self.tgt.buckets))
+ )
# determine bucket sizes using self.num_tokens, which will return
# the padded lengths (thanks to BucketPadLengthDataset)
num_tokens = np.vectorize(self.num_tokens, otypes=[np.long])
self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
self.buckets = [
- (None, num_tokens)
- for num_tokens in np.unique(self.bucketed_num_tokens)
+ (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
]
else:
self.buckets = None
@@ -292,12 +328,12 @@ def __getitem__(self, index):
src_item = self.src[index][:-1]
example = {
- 'id': index,
- 'source': src_item,
- 'target': tgt_item,
+ "id": index,
+ "source": src_item,
+ "target": tgt_item,
}
if self.align_dataset is not None:
- example['alignment'] = self.align_dataset[index]
+ example["alignment"] = self.align_dataset[index]
if self.constraints is not None:
example["constraints"] = self.constraints[index]
return example
@@ -352,27 +388,33 @@ def collater(self, samples, pad_to_length=None):
pad_to_multiple=self.pad_to_multiple,
)
if self.src_lang_id is not None or self.tgt_lang_id is not None:
- src_tokens = res['net_input']['src_tokens']
+ src_tokens = res["net_input"]["src_tokens"]
bsz = src_tokens.size(0)
if self.src_lang_id is not None:
- res['net_input']['src_lang_id'] = torch.LongTensor(
- [[self.src_lang_id]]
- ).expand(bsz, 1).to(src_tokens)
+ res["net_input"]["src_lang_id"] = (
+ torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
+ )
if self.tgt_lang_id is not None:
- res['tgt_lang_id'] = torch.LongTensor(
- [[self.tgt_lang_id]]
- ).expand(bsz, 1).to(src_tokens)
+ res["tgt_lang_id"] = (
+ torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
+ )
return res
def num_tokens(self, index):
"""Return the number of tokens in a sample. This value is used to
enforce ``--max-tokens`` during batching."""
- return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
+ return max(
+ self.src_sizes[index],
+ self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
+ )
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
- return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
+ return (
+ self.src_sizes[index],
+ self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
+ )
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
@@ -384,22 +426,19 @@ def ordered_indices(self):
if self.buckets is None:
# sort by target length, then source length
if self.tgt_sizes is not None:
- indices = indices[
- np.argsort(self.tgt_sizes[indices], kind='mergesort')
- ]
- return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
+ indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
+ return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
else:
# sort by bucketed_num_tokens, which is:
# max(padded_src_len, padded_tgt_len)
return indices[
- np.argsort(self.bucketed_num_tokens[indices], kind='mergesort')
+ np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
]
@property
def supports_prefetch(self):
- return (
- getattr(self.src, 'supports_prefetch', False)
- and (getattr(self.tgt, 'supports_prefetch', False) or self.tgt is None)
+ return getattr(self.src, "supports_prefetch", False) and (
+ getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
)
def prefetch(self, indices):
@@ -410,7 +449,7 @@ def prefetch(self, indices):
self.align_dataset.prefetch(indices)
def filter_indices_by_size(self, indices, max_sizes):
- """ Filter a list of sample indices. Remove those that are longer
+ """Filter a list of sample indices. Remove those that are longer
than specified in max_sizes.
Args:
diff --git a/fairseq/data/legacy/__init__.py b/fairseq/data/legacy/__init__.py
index 1acaafeb09..9bd5c72b5e 100644
--- a/fairseq/data/legacy/__init__.py
+++ b/fairseq/data/legacy/__init__.py
@@ -3,13 +3,14 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
from .block_pair_dataset import BlockPairDataset
from .masked_lm_dataset import MaskedLMDataset
+from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
+
__all__ = [
- 'BertDictionary',
- 'BlockPairDataset',
- 'MaskedLMDataset',
- 'MaskedLMDictionary',
+ "BertDictionary",
+ "BlockPairDataset",
+ "MaskedLMDataset",
+ "MaskedLMDictionary",
]
diff --git a/fairseq/data/legacy/block_pair_dataset.py b/fairseq/data/legacy/block_pair_dataset.py
index b9fc814147..ba069b4605 100644
--- a/fairseq/data/legacy/block_pair_dataset.py
+++ b/fairseq/data/legacy/block_pair_dataset.py
@@ -7,7 +7,6 @@
import numpy as np
import torch
-
from fairseq.data import FairseqDataset
diff --git a/fairseq/data/legacy/masked_lm_dataset.py b/fairseq/data/legacy/masked_lm_dataset.py
index 953aa85dd4..dd8ea2c60a 100644
--- a/fairseq/data/legacy/masked_lm_dataset.py
+++ b/fairseq/data/legacy/masked_lm_dataset.py
@@ -4,18 +4,14 @@
# LICENSE file in the root directory of this source tree.
import math
+from typing import Dict, List, Tuple
import numpy as np
import torch
-
-from typing import Dict, List, Tuple
-
-from fairseq.data import FairseqDataset, data_utils
-
-from fairseq.data import Dictionary
+from fairseq.data import Dictionary, FairseqDataset, data_utils
+from fairseq.data.concat_dataset import ConcatDataset
from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
from fairseq.data.token_block_dataset import TokenBlockDataset
-from fairseq.data.concat_dataset import ConcatDataset
class MaskedLMDataset(FairseqDataset):
@@ -55,29 +51,31 @@ class MaskedLMDataset(FairseqDataset):
"""
def __init__(
- self,
- dataset: FairseqDataset,
- sizes: np.ndarray,
- vocab: Dictionary,
- pad_idx: int,
- mask_idx: int,
- classif_token_idx: int,
- sep_token_idx: int,
- seed: int = 1,
- shuffle: bool = True,
- has_pairs: bool = True,
- segment_id: int = 0,
- masking_ratio: float = 0.15,
- masking_prob: float = 0.8,
- random_token_prob: float = 0.1
+ self,
+ dataset: FairseqDataset,
+ sizes: np.ndarray,
+ vocab: Dictionary,
+ pad_idx: int,
+ mask_idx: int,
+ classif_token_idx: int,
+ sep_token_idx: int,
+ seed: int = 1,
+ shuffle: bool = True,
+ has_pairs: bool = True,
+ segment_id: int = 0,
+ masking_ratio: float = 0.15,
+ masking_prob: float = 0.8,
+ random_token_prob: float = 0.1,
):
# Make sure the input datasets are the ones supported
assert (
- isinstance(dataset, TokenBlockDataset) or
- isinstance(dataset, BlockPairDataset) or
- isinstance(dataset, ConcatDataset)
- ), "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or " \
- "ConcatDataset"
+ isinstance(dataset, TokenBlockDataset)
+ or isinstance(dataset, BlockPairDataset)
+ or isinstance(dataset, ConcatDataset)
+ ), (
+ "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or "
+ "ConcatDataset"
+ )
self.dataset = dataset
self.sizes = np.array(sizes)
@@ -99,10 +97,7 @@ def __init__(
if not has_pairs:
self.sizes = self.sizes + 1
- def __getitem__(
- self,
- index: int
- ):
+ def __getitem__(self, index: int):
# if has_pairs, then expect 2 blocks and a sentence target
if self.has_pairs:
(block_one, block_two, sentence_target) = self.dataset[index]
@@ -120,11 +115,11 @@ def __len__(self):
return len(self.dataset)
def _mask_block(
- self,
- sentence: np.ndarray,
- mask_idx: int,
- pad_idx: int,
- dictionary_token_range: Tuple,
+ self,
+ sentence: np.ndarray,
+ mask_idx: int,
+ pad_idx: int,
+ dictionary_token_range: Tuple,
):
"""
Mask tokens for Masked Language Model training
@@ -166,22 +161,15 @@ def _mask_block(
# masking_prob + random_token_prob (Eg: 0.9)
elif rand < (self.masking_prob + self.random_token_prob):
# sample random token from dictionary
- masked_sent[i] = (
- np.random.randint(
- dictionary_token_range[0], dictionary_token_range[1]
- )
+ masked_sent[i] = np.random.randint(
+ dictionary_token_range[0], dictionary_token_range[1]
)
else:
target[i] = pad_idx
return masked_sent, target
- def _collate(
- self,
- samples: List[Dict],
- pad_idx: int,
- eos_idx: int
- ):
+ def _collate(self, samples: List[Dict], pad_idx: int, eos_idx: int):
"""
Does the heavy lifting for creating a batch from the input list of
examples. The logic is as follows:
@@ -215,12 +203,13 @@ def _collate(
# mask according to specified probabilities.
masked_blk_one, masked_tgt_one = self._mask_block(
- s["block_one"], self.mask_idx, self.pad_idx, token_range,
+ s["block_one"],
+ self.mask_idx,
+ self.pad_idx,
+ token_range,
)
- tokens = np.concatenate([
- [self.classif_token_idx], masked_blk_one
- ])
+ tokens = np.concatenate([[self.classif_token_idx], masked_blk_one])
targets = np.concatenate([[self.pad_idx], masked_tgt_one])
segments = np.ones(len(tokens)) * self.segment_id
@@ -232,9 +221,9 @@ def _collate(
targets_one = np.concatenate([targets, [self.pad_idx]])
masked_blk_two, masked_tgt_two = self._mask_block(
- s["block_two"], self.mask_idx, self.pad_idx, token_range)
- tokens_two = np.concatenate(
- [masked_blk_two, [self.sep_token_idx]])
+ s["block_two"], self.mask_idx, self.pad_idx, token_range
+ )
+ tokens_two = np.concatenate([masked_blk_two, [self.sep_token_idx]])
targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]])
# block + 1 sep + 1 special (CLS)
@@ -254,6 +243,7 @@ def merge(key):
return data_utils.collate_tokens(
[s[key] for s in samples], pad_idx, eos_idx, left_pad=False
)
+
return {
"id": torch.LongTensor([s["id"] for s in samples]),
"ntokens": sum(len(s["source"]) for s in samples),
@@ -262,16 +252,13 @@ def merge(key):
"segment_labels": merge("segment_labels"),
},
"lm_target": merge("lm_target"),
- "sentence_target": torch.LongTensor(
- [s["sentence_target"] for s in samples]
- ) if self.has_pairs else None,
+ "sentence_target": torch.LongTensor([s["sentence_target"] for s in samples])
+ if self.has_pairs
+ else None,
"nsentences": len(samples),
}
- def collater(
- self,
- samples: List[Dict]
- ):
+ def collater(self, samples: List[Dict]):
"""Merge a list of samples to form a mini-batch.
Args:
@@ -282,20 +269,14 @@ def collater(
"""
return self._collate(samples, self.vocab.pad(), self.vocab.eos())
- def num_tokens(
- self,
- index: int
- ):
+ def num_tokens(self, index: int):
"""
Return the number of tokens in a sample. This value is used to
enforce max-tokens during batching.
"""
return self.sizes[index]
- def size(
- self,
- index: int
- ):
+ def size(self, index: int):
"""
Return an example's size as a float or tuple. This value is used when
filtering a dataset with max-positions.
diff --git a/fairseq/data/legacy/masked_lm_dictionary.py b/fairseq/data/legacy/masked_lm_dictionary.py
index bff4bcb5ec..dee88f7a3e 100644
--- a/fairseq/data/legacy/masked_lm_dictionary.py
+++ b/fairseq/data/legacy/masked_lm_dictionary.py
@@ -11,12 +11,13 @@ class MaskedLMDictionary(Dictionary):
Dictionary for Masked Language Modelling tasks. This extends Dictionary by
adding the mask symbol.
"""
+
def __init__(
self,
- pad='',
- eos='',
- unk='',
- mask='',
+ pad="",
+ eos="",
+ unk="",
+ mask="",
):
super().__init__(pad=pad, eos=eos, unk=unk)
self.mask_word = mask
@@ -33,14 +34,15 @@ class BertDictionary(MaskedLMDictionary):
Dictionary for BERT task. This extends MaskedLMDictionary by adding support
for cls and sep symbols.
"""
+
def __init__(
self,
- pad='',
- eos='',
- unk='',
- mask='',
- cls='',
- sep=''
+ pad="",
+ eos="",
+ unk="",
+ mask="",
+ cls="",
+ sep="",
):
super().__init__(pad=pad, eos=eos, unk=unk, mask=mask)
self.cls_word = cls
diff --git a/fairseq/data/list_dataset.py b/fairseq/data/list_dataset.py
index b96bba3437..12f00aa436 100644
--- a/fairseq/data/list_dataset.py
+++ b/fairseq/data/list_dataset.py
@@ -7,7 +7,6 @@
class ListDataset(BaseWrapperDataset):
-
def __init__(self, dataset, sizes=None):
super().__init__(dataset)
self._sizes = sizes
diff --git a/fairseq/data/lm_context_window_dataset.py b/fairseq/data/lm_context_window_dataset.py
index 17ba08bc7f..29ad887b7d 100644
--- a/fairseq/data/lm_context_window_dataset.py
+++ b/fairseq/data/lm_context_window_dataset.py
@@ -5,7 +5,6 @@
import numpy as np
import torch
-
from fairseq.data.monolingual_dataset import MonolingualDataset
from . import FairseqDataset
@@ -35,11 +34,11 @@ def collater(self, samples):
pad = self.pad_idx
max_sample_len = self.tokens_per_sample + self.context_window
- bsz, tsz = sample['net_input']['src_tokens'].shape
+ bsz, tsz = sample["net_input"]["src_tokens"].shape
start_idxs = [0] * bsz
- toks = sample['net_input']['src_tokens']
- lengths = sample['net_input']['src_lengths']
- tgt = sample['target']
+ toks = sample["net_input"]["src_tokens"]
+ lengths = sample["net_input"]["src_lengths"]
+ tgt = sample["target"]
new_toks = np.empty([bsz, tsz + self.context_window], dtype=np.int64)
new_tgt = np.full([bsz, tsz + self.context_window], pad, dtype=np.int64)
sample_lens = toks.ne(pad).long().sum(dim=1).cpu()
@@ -50,13 +49,15 @@ def collater(self, samples):
self.prev_tokens = self.prev_tokens[extra:]
pads = np.full(self.context_window - len(self.prev_tokens), pad)
new_toks[i] = np.concatenate([self.prev_tokens, toks[i].numpy(), pads])
- new_tgt[i, len(self.prev_tokens):len(self.prev_tokens) + len(tgt[i])] = tgt[i]
+ new_tgt[
+ i, len(self.prev_tokens) : len(self.prev_tokens) + len(tgt[i])
+ ] = tgt[i]
start_idxs[i] = len(self.prev_tokens)
lengths[i] += len(self.prev_tokens)
- self.prev_tokens = new_toks[i][new_toks[i] != pad][-self.context_window:]
- sample['net_input']['src_tokens'] = torch.from_numpy(new_toks)
- sample['target'] = torch.from_numpy(new_tgt)
- sample['start_indices'] = start_idxs
+ self.prev_tokens = new_toks[i][new_toks[i] != pad][-self.context_window :]
+ sample["net_input"]["src_tokens"] = torch.from_numpy(new_toks)
+ sample["target"] = torch.from_numpy(new_tgt)
+ sample["start_indices"] = start_idxs
return sample
@@ -72,7 +73,7 @@ def ordered_indices(self):
@property
def supports_prefetch(self):
- return getattr(self.dataset, 'supports_prefetch', False)
+ return getattr(self.dataset, "supports_prefetch", False)
def prefetch(self, indices):
return self.dataset.prefetch(indices)
diff --git a/fairseq/data/lru_cache_dataset.py b/fairseq/data/lru_cache_dataset.py
index 833a2c75cb..a7854ac170 100644
--- a/fairseq/data/lru_cache_dataset.py
+++ b/fairseq/data/lru_cache_dataset.py
@@ -9,7 +9,6 @@
class LRUCacheDataset(BaseWrapperDataset):
-
def __init__(self, dataset, token=None):
super().__init__(dataset)
diff --git a/fairseq/data/mask_tokens_dataset.py b/fairseq/data/mask_tokens_dataset.py
index 31f5459307..8ea86245f7 100644
--- a/fairseq/data/mask_tokens_dataset.py
+++ b/fairseq/data/mask_tokens_dataset.py
@@ -7,8 +7,7 @@
import numpy as np
import torch
-
-from fairseq.data import data_utils, Dictionary
+from fairseq.data import Dictionary, data_utils
from . import BaseWrapperDataset, LRUCacheDataset
@@ -86,7 +85,7 @@ def __init__(
weights = np.array(self.vocab.count)
else:
weights = np.ones(len(self.vocab))
- weights[:self.vocab.nspecial] = 0
+ weights[: self.vocab.nspecial] = 0
self.weights = weights / weights.sum()
self.epoch = 0
@@ -105,10 +104,11 @@ def __getitem__(self, index: int):
item = self.dataset[index]
sz = len(item)
- assert self.mask_idx not in item, \
- 'Dataset contains mask_idx (={}), this is not expected!'.format(
- self.mask_idx,
- )
+ assert (
+ self.mask_idx not in item
+ ), "Dataset contains mask_idx (={}), this is not expected!".format(
+ self.mask_idx,
+ )
if self.mask_whole_words is not None:
word_begins_mask = self.mask_whole_words.gather(0, item)
@@ -122,7 +122,8 @@ def __getitem__(self, index: int):
mask = np.full(sz, False)
num_mask = int(
# add a random number for probabilistic rounding
- self.mask_prob * sz + np.random.rand()
+ self.mask_prob * sz
+ + np.random.rand()
)
mask[np.random.choice(sz, num_mask, replace=False)] = True
diff --git a/fairseq/data/monolingual_dataset.py b/fairseq/data/monolingual_dataset.py
index 76c3772374..ec73f1fda8 100644
--- a/fairseq/data/monolingual_dataset.py
+++ b/fairseq/data/monolingual_dataset.py
@@ -6,7 +6,7 @@
import numpy as np
import torch
-from . import data_utils, FairseqDataset
+from . import FairseqDataset, data_utils
def collate(samples, pad_idx, eos_idx):
@@ -17,33 +17,39 @@ def merge(key, is_list=False):
if is_list:
res = []
for i in range(len(samples[0][key])):
- res.append(data_utils.collate_tokens(
- [s[key][i] for s in samples], pad_idx, eos_idx, left_pad=False,
- ))
+ res.append(
+ data_utils.collate_tokens(
+ [s[key][i] for s in samples],
+ pad_idx,
+ eos_idx,
+ left_pad=False,
+ )
+ )
return res
else:
return data_utils.collate_tokens(
- [s[key] for s in samples], pad_idx, eos_idx, left_pad=False,
+ [s[key] for s in samples],
+ pad_idx,
+ eos_idx,
+ left_pad=False,
)
- src_tokens = merge('source')
- if samples[0]['target'] is not None:
- is_target_list = isinstance(samples[0]['target'], list)
- target = merge('target', is_target_list)
+ src_tokens = merge("source")
+ if samples[0]["target"] is not None:
+ is_target_list = isinstance(samples[0]["target"], list)
+ target = merge("target", is_target_list)
else:
target = src_tokens
return {
- 'id': torch.LongTensor([s['id'] for s in samples]),
- 'nsentences': len(samples),
- 'ntokens': sum(len(s['source']) for s in samples),
- 'net_input': {
- 'src_tokens': src_tokens,
- 'src_lengths': torch.LongTensor([
- s['source'].numel() for s in samples
- ]),
+ "id": torch.LongTensor([s["id"] for s in samples]),
+ "nsentences": len(samples),
+ "ntokens": sum(len(s["source"]) for s in samples),
+ "net_input": {
+ "src_tokens": src_tokens,
+ "src_lengths": torch.LongTensor([s["source"].numel() for s in samples]),
},
- 'target': target,
+ "target": target,
}
@@ -59,8 +65,17 @@ class MonolingualDataset(FairseqDataset):
(default: True).
"""
- def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle,
- targets=None, add_bos_token=False):
+ def __init__(
+ self,
+ dataset,
+ sizes,
+ src_vocab,
+ tgt_vocab,
+ add_eos_for_other_targets,
+ shuffle,
+ targets=None,
+ add_bos_token=False,
+ ):
self.dataset = dataset
self.sizes = np.array(sizes)
self.vocab = src_vocab
@@ -69,8 +84,9 @@ def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targe
self.shuffle = shuffle
self.add_bos_token = add_bos_token
- assert targets is None or all(t in {'self', 'future', 'past'} for t in targets), \
- "targets must be none or one of 'self', 'future', 'past'"
+ assert targets is None or all(
+ t in {"self", "future", "past"} for t in targets
+ ), "targets must be none or one of 'self', 'future', 'past'"
if targets is not None and len(targets) == 0:
targets = None
self.targets = targets
@@ -86,12 +102,14 @@ def __getitem__(self, index):
# Right-to-left language models should condition on *source* and
# predict *past_target*.
source, future_target, past_target = self.dataset[index]
- source, target = self._make_source_target(source, future_target, past_target)
+ source, target = self._make_source_target(
+ source, future_target, past_target
+ )
else:
source = self.dataset[index]
target = None
source, target = self._maybe_add_bos(source, target)
- return {'id': index, 'source': source, 'target': target}
+ return {"id": index, "source": source, "target": target}
def __len__(self):
return len(self.dataset)
@@ -100,27 +118,38 @@ def _make_source_target(self, source, future_target, past_target):
if self.targets is not None:
target = []
- if self.add_eos_for_other_targets and (('self' in self.targets) or ('past' in self.targets)) \
- and source[-1] != self.vocab.eos():
+ if (
+ self.add_eos_for_other_targets
+ and (("self" in self.targets) or ("past" in self.targets))
+ and source[-1] != self.vocab.eos()
+ ):
# append eos at the end of source
source = torch.cat([source, source.new([self.vocab.eos()])])
- if 'future' in self.targets:
- future_target = torch.cat([future_target, future_target.new([self.vocab.pad()])])
- if 'past' in self.targets:
+ if "future" in self.targets:
+ future_target = torch.cat(
+ [future_target, future_target.new([self.vocab.pad()])]
+ )
+ if "past" in self.targets:
# first token is before the start of sentence which is only used in "none" break mode when
# add_eos_for_other_targets is False
- past_target = torch.cat([past_target.new([self.vocab.pad()]), past_target[1:], source[-2, None]])
+ past_target = torch.cat(
+ [
+ past_target.new([self.vocab.pad()]),
+ past_target[1:],
+ source[-2, None],
+ ]
+ )
for t in self.targets:
- if t == 'self':
+ if t == "self":
target.append(source)
- elif t == 'future':
+ elif t == "future":
target.append(future_target)
- elif t == 'past':
+ elif t == "past":
target.append(past_target)
else:
- raise Exception('invalid target ' + t)
+ raise Exception("invalid target " + t)
if len(target) == 1:
target = target[0]
@@ -138,6 +167,7 @@ def _maybe_add_bos(self, source, target):
def _filter_vocab(self, target):
if len(self.tgt_vocab) != len(self.vocab):
+
def _filter(target):
mask = target.ge(len(self.tgt_vocab))
if mask.any():
@@ -194,7 +224,7 @@ def ordered_indices(self):
@property
def supports_prefetch(self):
- return getattr(self.dataset, 'supports_prefetch', False)
+ return getattr(self.dataset, "supports_prefetch", False)
def prefetch(self, indices):
self.dataset.prefetch(indices)
diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py
index 7ce269a4df..8c14f4e3ad 100644
--- a/fairseq/data/multilingual/multilingual_data_manager.py
+++ b/fairseq/data/multilingual/multilingual_data_manager.py
@@ -6,9 +6,9 @@
import itertools
import json
import logging
+import math
import os
from collections import OrderedDict, defaultdict
-import math
from fairseq import utils
from fairseq.data import (
@@ -197,7 +197,7 @@ def add_args(parser):
)
parser.add_argument(
"--fixed-dictionary",
- help='Fixed dictionary to use with model path',
+ help="Fixed dictionary to use with model path",
default=None,
type=str,
)
@@ -266,7 +266,9 @@ def load_langs(cls, args, **kwargs):
langs = sorted(langs)
logger.info(f"inferred language list: {langs}")
elif args.lang_dict:
- with open(PathManager.get_local_path(args.lang_dict), "r", encoding="utf-8") as f:
+ with open(
+ PathManager.get_local_path(args.lang_dict), "r", encoding="utf-8"
+ ) as f:
langs = [lang.strip() for lang in f.readlines() if lang.strip()]
logger.info(
f"loaded language list from {args.lang_dict} as they are ordered in file"
@@ -292,7 +294,9 @@ def estimate_global_pass_epoch(self, epoch):
if self.args.virtual_epoch_size is None or self.args.virtual_data_size is None:
return None
# one epoch more for remaining data in each shard
- virtual_epochs_per_shard = math.ceil(self.args.virtual_data_size / self.args.virtual_epoch_size)
+ virtual_epochs_per_shard = math.ceil(
+ self.args.virtual_data_size / self.args.virtual_epoch_size
+ )
# note that fairseq epoch / shard_epoch starts from 1
shard_epoch = (epoch - 1) // virtual_epochs_per_shard + 1
return shard_epoch
@@ -809,7 +813,7 @@ def _get_shard_num_dict(cls, split, paths):
for f in files:
if f.startswith(split) and f.endswith(".idx"):
# idx files of the form "{split}.{src}-{tgt}.{lang}.idx"
- direction = f.split('.')[-3]
+ direction = f.split(".")[-3]
directions.add(direction)
for direction in directions:
shards[direction] += 1
diff --git a/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/data/multilingual/sampled_multi_dataset.py
index 5270675124..3f544b099f 100644
--- a/fairseq/data/multilingual/sampled_multi_dataset.py
+++ b/fairseq/data/multilingual/sampled_multi_dataset.py
@@ -3,25 +3,25 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from typing import List
-from enum import Enum
-from collections import OrderedDict
-from collections import defaultdict
-from bisect import bisect_right
+import datetime
import hashlib
import logging
-import datetime
import time
+from bisect import bisect_right
+from collections import OrderedDict, defaultdict
+from enum import Enum
+from typing import List
import numpy as np
import torch
-
from fairseq import distributed_utils
from fairseq.data import FairseqDataset, data_utils
def get_time_gap(s, e):
- return (datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)).__str__()
+ return (
+ datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)
+ ).__str__()
logger = logging.getLogger(__name__)
@@ -80,7 +80,7 @@ def __init__(
eval_key=None,
collate_format=CollateFormat.single,
virtual_size=default_virtual_size_func,
- split='',
+ split="",
shared_collater=False,
shuffle=True,
):
@@ -126,9 +126,7 @@ def _clean_if_not_none(self, var_list):
del v
def _reset_cached_properties(self):
- self._clean_if_not_none([
- self._sizes, self._cur_indices
- ])
+ self._clean_if_not_none([self._sizes, self._cur_indices])
self._sizes = None
self._cur_indices = None
@@ -142,10 +140,14 @@ def setup_sampling(self, sample_ratios, virtual_size):
if not isinstance(sample_ratios, np.ndarray):
sample_ratios = np.array(sample_ratios)
self.sample_ratios = sample_ratios
- virtual_size = default_virtual_size_func if virtual_size is None else virtual_size
+ virtual_size = (
+ default_virtual_size_func if virtual_size is None else virtual_size
+ )
self.virtual_size = (
- virtual_size(self.datasets, self.sample_ratios) if callable(virtual_size)
- else virtual_size)
+ virtual_size(self.datasets, self.sample_ratios)
+ if callable(virtual_size)
+ else virtual_size
+ )
def adjust_sampling(self, epoch, sampling_ratios, virtual_size):
if sampling_ratios is not None:
@@ -166,10 +168,12 @@ def _sync_sample_ratios(self, ratios):
return ret
def random_choice_in_dataset(self, rng, dataset, choice_size):
- if hasattr(dataset, 'random_choice_in_dataset'):
+ if hasattr(dataset, "random_choice_in_dataset"):
return dataset.random_choice_in_dataset(rng, choice_size)
dataset_size = len(dataset)
- return rng.choice(dataset_size, choice_size, replace=(choice_size > dataset_size))
+ return rng.choice(
+ dataset_size, choice_size, replace=(choice_size > dataset_size)
+ )
def get_virtual_indices(self, rng, datasets, sample_ratios, virtual_size):
def get_counts(sample_ratios):
@@ -178,7 +182,9 @@ def get_counts(sample_ratios):
assert diff >= 0
# due to round-offs, the size might not match the desired sizes
if diff > 0:
- dataset_indices = rng.choice(len(sample_ratios), size=diff, p=sample_ratios)
+ dataset_indices = rng.choice(
+ len(sample_ratios), size=diff, p=sample_ratios
+ )
for i in dataset_indices:
counts[i] += 1
return counts
@@ -189,7 +195,8 @@ def get_in_dataset_indices(datasets, sizes, sample_ratios):
# if the desired counts are large, sample with replacement:
indices = [
self.random_choice_in_dataset(rng, d, c)
- for c, d in zip(counts, datasets)]
+ for c, d in zip(counts, datasets)
+ ]
return indices
sizes = [len(d) for d in datasets]
@@ -207,8 +214,8 @@ def get_in_dataset_indices(datasets, sizes, sample_ratios):
assert cumulative_sizes[-1] == virtual_size
if virtual_size < sum(sizes):
logger.warning(
- f'virtual data size ({virtual_size}) is less than real data size ({sum(sizes)}).'
- ' If virtual size << real data size, there could be data coverage issue.'
+ f"virtual data size ({virtual_size}) is less than real data size ({sum(sizes)})."
+ " If virtual size << real data size, there could be data coverage issue."
)
in_dataset_indices = np.hstack(in_dataset_indices)
return in_dataset_indices, cumulative_sizes, virtual_sizes_per_dataset
@@ -237,26 +244,34 @@ def collater(self, samples, **extra_args):
"""Merge a list of samples to form a mini-batch."""
if len(samples) == 0:
return None
- if self.collate_format == 'ordered_dict':
+ if self.collate_format == "ordered_dict":
collect_samples = [[] for _ in range(len(self.datasets))]
for (i, sample) in samples:
collect_samples[i].append(sample)
- batch = OrderedDict([
- (self.keys[i], dataset.collater(collect_samples[i]))
- for i, (key, dataset) in enumerate(zip(self.keys, self.datasets))
- if len(collect_samples[i]) > 0
- ])
- elif self.shared_collater:
- batch = self.datasets[0].collater(
- [s for _, s in samples]
+ batch = OrderedDict(
+ [
+ (self.keys[i], dataset.collater(collect_samples[i]))
+ for i, (key, dataset) in enumerate(zip(self.keys, self.datasets))
+ if len(collect_samples[i]) > 0
+ ]
)
+ elif self.shared_collater:
+ batch = self.datasets[0].collater([s for _, s in samples])
else:
samples_dict = defaultdict(list)
- pad_to_length = defaultdict(int) if 'pad_to_length' not in extra_args else extra_args['pad_to_length']
+ pad_to_length = (
+ defaultdict(int)
+ if "pad_to_length" not in extra_args
+ else extra_args["pad_to_length"]
+ )
for ds_idx, s in samples:
- pad_to_length['source'] = max(pad_to_length['source'], s['source'].size(0))
- if s['target'] is not None:
- pad_to_length['target'] = max(pad_to_length['target'], s['target'].size(0))
+ pad_to_length["source"] = max(
+ pad_to_length["source"], s["source"].size(0)
+ )
+ if s["target"] is not None:
+ pad_to_length["target"] = max(
+ pad_to_length["target"], s["target"].size(0)
+ )
samples_dict[ds_idx].append(s)
batches = [
self.datasets[i].collater(samples_dict[i], pad_to_length=pad_to_length)
@@ -268,7 +283,9 @@ def straight_data(tensors):
batch = torch.cat(tensors, dim=0)
return batch
- src_lengths = straight_data([b['net_input']['src_lengths'] for b in batches])
+ src_lengths = straight_data(
+ [b["net_input"]["src_lengths"] for b in batches]
+ )
src_lengths, sort_order = src_lengths.sort(descending=True)
def straight_order(tensors):
@@ -276,22 +293,31 @@ def straight_order(tensors):
return batch.index_select(0, sort_order)
batch = {
- 'id': straight_order([b['id'] for b in batches]),
- 'nsentences': sum(b['nsentences'] for b in batches),
- 'ntokens': sum(b['ntokens'] for b in batches),
- 'net_input': {
- 'src_tokens': straight_order([b['net_input']['src_tokens'] for b in batches]),
- 'src_lengths': src_lengths,
+ "id": straight_order([b["id"] for b in batches]),
+ "nsentences": sum(b["nsentences"] for b in batches),
+ "ntokens": sum(b["ntokens"] for b in batches),
+ "net_input": {
+ "src_tokens": straight_order(
+ [b["net_input"]["src_tokens"] for b in batches]
+ ),
+ "src_lengths": src_lengths,
},
- 'target': straight_order([b['target'] for b in batches]) if batches[0]['target'] is not None else None,
+ "target": straight_order([b["target"] for b in batches])
+ if batches[0]["target"] is not None
+ else None,
}
- if 'prev_output_tokens' in batches[0]['net_input']:
- batch['net_input']['prev_output_tokens'] = straight_order(
- [b['net_input']['prev_output_tokens'] for b in batches])
- if 'src_lang_id' in batches[0]['net_input']:
- batch['net_input']['src_lang_id'] = straight_order([b['net_input']['src_lang_id'] for b in batches])
- if 'tgt_lang_id' in batches[0]:
- batch['tgt_lang_id'] = straight_order([b['tgt_lang_id'] for b in batches])
+ if "prev_output_tokens" in batches[0]["net_input"]:
+ batch["net_input"]["prev_output_tokens"] = straight_order(
+ [b["net_input"]["prev_output_tokens"] for b in batches]
+ )
+ if "src_lang_id" in batches[0]["net_input"]:
+ batch["net_input"]["src_lang_id"] = straight_order(
+ [b["net_input"]["src_lang_id"] for b in batches]
+ )
+ if "tgt_lang_id" in batches[0]:
+ batch["tgt_lang_id"] = straight_order(
+ [b["tgt_lang_id"] for b in batches]
+ )
return batch
@property
@@ -300,7 +326,9 @@ def sizes(self):
return self._sizes
start_time = time.time()
in_sub_dataset_indices = [
- self._cur_indices[0 if i == 0 else self.cumulated_sizes[i-1]:self.cumulated_sizes[i]]
+ self._cur_indices[
+ 0 if i == 0 else self.cumulated_sizes[i - 1] : self.cumulated_sizes[i]
+ ]
for i in range(len(self.datasets))
]
sub_dataset_sizes = [
@@ -308,7 +336,7 @@ def sizes(self):
for d, indices in zip(self.datasets, in_sub_dataset_indices)
]
self._sizes = np.vstack(sub_dataset_sizes)
- logger.info(f'sizes() calling time: {get_time_gap(start_time, time.time())}')
+ logger.info(f"sizes() calling time: {get_time_gap(start_time, time.time())}")
return self._sizes
def ordered_indices(self):
@@ -319,14 +347,14 @@ def ordered_indices(self):
sizes = self.sizes
tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
- src_sizes = sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
+ src_sizes = (
+ sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
+ )
# sort by target length, then source length
if tgt_sizes is not None:
- indices = indices[
- np.argsort(tgt_sizes[indices], kind='mergesort')
- ]
- sort_indices = indices[np.argsort(src_sizes[indices], kind='mergesort')]
+ indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
+ sort_indices = indices[np.argsort(src_sizes[indices], kind="mergesort")]
return sort_indices
def prefetch(self, indices):
@@ -347,7 +375,7 @@ def set_epoch(self, epoch):
# re-enter so return
return
for d in self.datasets:
- if hasattr(d, 'set_epoch'):
+ if hasattr(d, "set_epoch"):
d.set_epoch(epoch)
self._cur_epoch = epoch
self._establish_virtual_datasets()
@@ -362,37 +390,52 @@ def _establish_virtual_datasets(self):
# Generate a weighted sample of indices as a function of the
# random seed and the current epoch.
rng = np.random.RandomState(
- [
- int(hashlib.sha1(str(self.__class__.__name__).encode('utf-8')).hexdigest(), 16) % (2 ** 32),
- self.seed % (2 ** 32), # global seed
- self._cur_epoch, # epoch index,
- ]
+ [
+ int(
+ hashlib.sha1(
+ str(self.__class__.__name__).encode("utf-8")
+ ).hexdigest(),
+ 16,
+ )
+ % (2 ** 32),
+ self.seed % (2 ** 32), # global seed
+ self._cur_epoch, # epoch index,
+ ]
+ )
+ self._clean_if_not_none(
+ [self.cumulated_sizes, self.virtual_size_per_dataset, self._sizes]
)
- self._clean_if_not_none([
- self.cumulated_sizes, self.virtual_size_per_dataset, self._sizes
- ])
self._sizes = None
indices, cumulated_sizes, virtual_size_per_dataset = self.get_virtual_indices(
- rng, self.datasets, self.sample_ratios, self.virtual_size)
+ rng, self.datasets, self.sample_ratios, self.virtual_size
+ )
self._cur_indices = indices
self.cumulated_sizes = cumulated_sizes
self.virtual_size_per_dataset = virtual_size_per_dataset
raw_sizes = [len(d) for d in self.datasets]
sampled_sizes = self.virtual_size_per_dataset
- logger.info(f'[{self.split}] Raw sizes: {str(dict(zip(self.keys, raw_sizes)))}; '
- f'raw total size: {sum(raw_sizes)}')
- logger.info(f'[{self.split}] Resampled sizes: {str(dict(zip(self.keys, sampled_sizes)))}; '
- f'resampled total size: {sum(sampled_sizes)}')
+ logger.info(
+ f"[{self.split}] Raw sizes: {str(dict(zip(self.keys, raw_sizes)))}; "
+ f"raw total size: {sum(raw_sizes)}"
+ )
+ logger.info(
+ f"[{self.split}] Resampled sizes: {str(dict(zip(self.keys, sampled_sizes)))}; "
+ f"resampled total size: {sum(sampled_sizes)}"
+ )
if self.sample_ratios is not None:
- logger.info(f'[{self.split}] Upsampling ratios: {str(dict(zip(self.keys, self.sample_ratios)))}')
+ logger.info(
+ f"[{self.split}] Upsampling ratios: {str(dict(zip(self.keys, self.sample_ratios)))}"
+ )
else:
- logger.info(f'[{self.split}] A concat dataset')
- logger.info(f'[{self.split}] virtual dataset established time: {get_time_gap(start_time, time.time())}')
+ logger.info(f"[{self.split}] A concat dataset")
+ logger.info(
+ f"[{self.split}] virtual dataset established time: {get_time_gap(start_time, time.time())}"
+ )
def filter_indices_by_size(self, indices, max_sizes):
- """ Filter a list of sample indices. Remove those that are longer
+ """Filter a list of sample indices. Remove those that are longer
than specified in max_sizes.
Args:
@@ -406,6 +449,10 @@ def filter_indices_by_size(self, indices, max_sizes):
"""
sizes = self.sizes
tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
- src_sizes = sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
+ src_sizes = (
+ sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
+ )
- return data_utils.filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes)
+ return data_utils.filter_paired_dataset_indices_by_size(
+ src_sizes, tgt_sizes, indices, max_sizes
+ )
diff --git a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py
index 81ff78f705..17387b2f85 100644
--- a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py
+++ b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py
@@ -4,12 +4,13 @@
# LICENSE file in the root directory of this source tree.
import hashlib
-import math
import logging
+import math
import numpy as np
from fairseq.data import SampledMultiDataset
-from .sampled_multi_dataset import default_virtual_size_func, CollateFormat
+
+from .sampled_multi_dataset import CollateFormat, default_virtual_size_func
logger = logging.getLogger(__name__)
@@ -44,6 +45,7 @@ class SampledMultiEpochDataset(SampledMultiDataset):
shard_epoch (int): the real epoch number for shard selection.
shuffle (bool): whether or not to shuffle data (default: True).
"""
+
def __init__(
self,
datasets,
@@ -53,7 +55,7 @@ def __init__(
eval_key=None,
collate_format=CollateFormat.single,
virtual_size=default_virtual_size_func,
- split='',
+ split="",
virtual_epoch_size=None,
shared_collater=False,
shard_epoch=1,
@@ -79,14 +81,22 @@ def __init__(
)
def _setup(self, epoch):
- self.virtual_epoch_size = self.virtual_epoch_size if self.virtual_epoch_size is not None else self.virtual_size
+ self.virtual_epoch_size = (
+ self.virtual_epoch_size
+ if self.virtual_epoch_size is not None
+ else self.virtual_size
+ )
if self.virtual_epoch_size > self.virtual_size:
- logger.warning(f'virtual epoch size {self.virtual_epoch_size} '
- f'is greater than virtual dataset size {self.virtual_size}')
+ logger.warning(
+ f"virtual epoch size {self.virtual_epoch_size} "
+ f"is greater than virtual dataset size {self.virtual_size}"
+ )
self.virtual_epoch_size = self.virtual_size
self.num_virtual_epochs = math.ceil(self.virtual_size / self.virtual_epoch_size)
self._current_epoch_start_index = self._get_epoch_start_index(epoch)
- logger.info(f'virtual epoch size {self.virtual_epoch_size}; virtual dataset size {self.virtual_size}')
+ logger.info(
+ f"virtual epoch size {self.virtual_epoch_size}; virtual dataset size {self.virtual_size}"
+ )
def _map_epoch_index_to_global(self, index):
index = self._current_epoch_start_index + index
@@ -99,7 +109,8 @@ def sizes(self):
return self._epoch_sizes
_sizes = super().sizes
indices = self._random_global_indices[
- self._current_epoch_start_index:self._current_epoch_start_index + len(self)
+ self._current_epoch_start_index : self._current_epoch_start_index
+ + len(self)
]
self._epoch_sizes = _sizes[indices]
# del super()._sizes to save memory
@@ -114,7 +125,8 @@ def _get_dataset_and_index(self, index):
def __len__(self):
return (
self.virtual_epoch_size
- if self._current_epoch_start_index + self.virtual_epoch_size < self.virtual_size
+ if self._current_epoch_start_index + self.virtual_epoch_size
+ < self.virtual_size
else self.virtual_size - self._current_epoch_start_index
)
@@ -136,38 +148,52 @@ def _get_epoch_start_index(self, epoch):
def _next_global_indices(self, epoch):
rng = np.random.RandomState(
- [
- int(hashlib.sha1(str(self.__class__.__name__).encode('utf-8')).hexdigest(), 16) % (2 ** 32),
- self.seed % (2 ** 32), # global seed
- epoch, # epoch index,
- ]
+ [
+ int(
+ hashlib.sha1(
+ str(self.__class__.__name__).encode("utf-8")
+ ).hexdigest(),
+ 16,
+ )
+ % (2 ** 32),
+ self.seed % (2 ** 32), # global seed
+ epoch, # epoch index,
+ ]
)
del self._random_global_indices
- self._random_global_indices = rng.choice(self.virtual_size, self.virtual_size, replace=False)
+ self._random_global_indices = rng.choice(
+ self.virtual_size, self.virtual_size, replace=False
+ )
if self.load_next_shard is None:
self.load_next_shard = False
else:
# increase shard epoch for next loading
self.shard_epoch += 1
self.load_next_shard = True
- logger.info('to load next epoch/shard in next load_dataset: '
- f'epoch={epoch}/shard_epoch={self.shard_epoch}')
+ logger.info(
+ "to load next epoch/shard in next load_dataset: "
+ f"epoch={epoch}/shard_epoch={self.shard_epoch}"
+ )
def _next_virtual_epoch(self, epoch):
index = self._get_epoch_start_index(epoch)
if index == 0 or self._random_global_indices is None:
# need to start from the beginning,
# so call super().set_epoch(epoch) to establish the global virtual indices
- logger.info('establishing a new set of global virtual indices for '
- f'epoch={epoch}/shard_epoch={self.shard_epoch}')
+ logger.info(
+ "establishing a new set of global virtual indices for "
+ f"epoch={epoch}/shard_epoch={self.shard_epoch}"
+ )
super().set_epoch(epoch)
self._next_global_indices(epoch)
else:
self._cur_epoch = epoch
# reset cache sizes and ordered_indices for the epoch after moving to a new epoch
- self._clean_if_not_none([
- self._epoch_sizes,
- ])
+ self._clean_if_not_none(
+ [
+ self._epoch_sizes,
+ ]
+ )
self._epoch_sizes = None
self._current_epoch_start_index = index
diff --git a/fairseq/data/multilingual/sampling_method.py b/fairseq/data/multilingual/sampling_method.py
index 6a9d39f7a6..140c68f01d 100644
--- a/fairseq/data/multilingual/sampling_method.py
+++ b/fairseq/data/multilingual/sampling_method.py
@@ -3,8 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from typing import List
import logging
+from typing import List
logger = logging.getLogger(__name__)
@@ -16,18 +16,20 @@ def uniform(dataset_sizes: List[int]):
def temperature_sampling(dataset_sizes, temp):
total_size = sum(dataset_sizes)
- return [(size / total_size) ** (1.0/temp) for size in dataset_sizes]
+ return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes]
def make_temperature_sampling(temp=1.0):
def sampling_func(dataset_sizes):
return temperature_sampling(dataset_sizes, temp)
+
return sampling_func
def make_ratio_sampling(ratios):
def sampling_func(dataset_sizes):
return ratios
+
return sampling_func
@@ -35,13 +37,23 @@ class SamplingMethod:
@staticmethod
def add_arguments(parser):
parser.add_argument(
- '--sampling-method',
- choices=['uniform', 'temperature', 'concat', 'RoundRobin', ],
+ "--sampling-method",
+ choices=[
+ "uniform",
+ "temperature",
+ "concat",
+ "RoundRobin",
+ ],
type=str,
- default='concat',
- help='The method to sample data per language pairs')
- parser.add_argument('--sampling-temperature', default=1.5, type=float,
- help='only work with --sampling-method temperature')
+ default="concat",
+ help="The method to sample data per language pairs",
+ )
+ parser.add_argument(
+ "--sampling-temperature",
+ default=1.5,
+ type=float,
+ help="only work with --sampling-method temperature",
+ )
@staticmethod
def build_sampler(args, task):
@@ -56,10 +68,10 @@ def is_adaptive(self):
def sampling_method_selector(self):
args = self.args
- logger.info(f'selected sampler: {args.sampling_method}')
- if args.sampling_method == 'uniform':
+ logger.info(f"selected sampler: {args.sampling_method}")
+ if args.sampling_method == "uniform":
return uniform
- elif args.sampling_method == 'temperature' or self.is_adaptive():
+ elif args.sampling_method == "temperature" or self.is_adaptive():
return make_temperature_sampling(float(args.sampling_temperature))
else:
# default to concating all data set together
diff --git a/fairseq/data/nested_dictionary_dataset.py b/fairseq/data/nested_dictionary_dataset.py
index ebc56303b9..52e74abdda 100644
--- a/fairseq/data/nested_dictionary_dataset.py
+++ b/fairseq/data/nested_dictionary_dataset.py
@@ -15,14 +15,14 @@ def _flatten(dico, prefix=None):
"""Flatten a nested dictionary."""
new_dico = OrderedDict()
if isinstance(dico, dict):
- prefix = prefix + '.' if prefix is not None else ''
+ prefix = prefix + "." if prefix is not None else ""
for k, v in dico.items():
if v is None:
continue
new_dico.update(_flatten(v, prefix + k))
elif isinstance(dico, list):
for i, v in enumerate(dico):
- new_dico.update(_flatten(v, prefix + '.[' + str(i) + ']'))
+ new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]"))
else:
new_dico = OrderedDict({prefix: dico})
return new_dico
@@ -32,10 +32,10 @@ def _unflatten(dico):
"""Unflatten a flattened dictionary into a nested dictionary."""
new_dico = OrderedDict()
for full_k, v in dico.items():
- full_k = full_k.split('.')
+ full_k = full_k.split(".")
node = new_dico
for k in full_k[:-1]:
- if k.startswith('[') and k.endswith(']'):
+ if k.startswith("[") and k.endswith("]"):
k = int(k[1:-1])
if k not in node:
node[k] = OrderedDict()
@@ -45,7 +45,6 @@ def _unflatten(dico):
class NestedDictionaryDataset(FairseqDataset):
-
def __init__(self, defn, sizes=None):
super().__init__()
self.defn = _flatten(defn)
@@ -53,11 +52,17 @@ def __init__(self, defn, sizes=None):
first = None
for v in self.defn.values():
- if not isinstance(v, (FairseqDataset, torch.utils.data.Dataset, )):
- raise ValueError('Expected Dataset but found: {}'.format(v.__class__))
+ if not isinstance(
+ v,
+ (
+ FairseqDataset,
+ torch.utils.data.Dataset,
+ ),
+ ):
+ raise ValueError("Expected Dataset but found: {}".format(v.__class__))
first = first or v
if len(v) > 0:
- assert len(v) == len(first), 'dataset lengths must match'
+ assert len(v) == len(first), "dataset lengths must match"
self._len = len(first)
@@ -107,7 +112,7 @@ def supports_prefetch(self):
def prefetch(self, indices):
"""Prefetch the data required for this epoch."""
for ds in self.defn.values():
- if getattr(ds, 'supports_prefetch', False):
+ if getattr(ds, "supports_prefetch", False):
ds.prefetch(indices)
@property
diff --git a/fairseq/data/noising.py b/fairseq/data/noising.py
index 5801ae6eac..9643d1aa6a 100644
--- a/fairseq/data/noising.py
+++ b/fairseq/data/noising.py
@@ -3,32 +3,34 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import torch
import numpy as np
-
+import torch
from fairseq.data import data_utils
class WordNoising(object):
"""Generate a noisy version of a sentence, without changing words themselves."""
+
def __init__(self, dictionary, bpe_cont_marker="@@", bpe_end_marker=None):
self.dictionary = dictionary
self.bpe_end = None
if bpe_cont_marker:
- self.bpe_end = np.array([
- not self.dictionary[i].endswith(bpe_cont_marker)
- for i in range(len(self.dictionary))
- ])
+ self.bpe_end = np.array(
+ [
+ not self.dictionary[i].endswith(bpe_cont_marker)
+ for i in range(len(self.dictionary))
+ ]
+ )
elif bpe_end_marker:
- self.bpe_end = np.array([
- self.dictionary[i].endswith(bpe_end_marker)
- for i in range(len(self.dictionary))
- ])
+ self.bpe_end = np.array(
+ [
+ self.dictionary[i].endswith(bpe_end_marker)
+ for i in range(len(self.dictionary))
+ ]
+ )
self.get_word_idx = (
- self._get_bpe_word_idx
- if self.bpe_end is not None
- else self._get_token_idx
+ self._get_bpe_word_idx if self.bpe_end is not None else self._get_token_idx
)
def noising(self, x, lengths, noising_prob=0.0):
@@ -44,7 +46,7 @@ def _get_bpe_word_idx(self, x):
# x: (T x B)
bpe_end = self.bpe_end[x]
- if (x.size(0) == 1 and x.size(1) == 1):
+ if x.size(0) == 1 and x.size(1) == 1:
# Special case when we only have one word in x. If x = [[N]],
# bpe_end is a scalar (bool) instead of a 2-dim array of bools,
# which makes the sum operation below fail.
@@ -70,7 +72,13 @@ class WordDropout(WordNoising):
then dropped words will be removed. Otherwise, it will be replaced by the
blank_idx."""
- def __init__(self, dictionary, default_dropout_prob=0.1, bpe_cont_marker="@@", bpe_end_marker=None):
+ def __init__(
+ self,
+ dictionary,
+ default_dropout_prob=0.1,
+ bpe_cont_marker="@@",
+ bpe_end_marker=None,
+ ):
super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
self.default_dropout_prob = default_dropout_prob
@@ -108,13 +116,12 @@ def noising(self, x, lengths, dropout_prob=None, blank_idx=None):
else:
keep = np.random.rand(num_words) >= dropout_prob
- words = x[:lengths[i], i].tolist()
+ words = x[: lengths[i], i].tolist()
# TODO: speed up the following loop
# drop words from the input according to keep
new_s = [
- w if keep[word_idx[j, i]] else blank_idx
- for j, w in enumerate(words)
+ w if keep[word_idx[j, i]] else blank_idx for j, w in enumerate(words)
]
new_s = [w for w in new_s if w is not None]
# we need to have at least one word in the sentence (more than the
@@ -132,11 +139,10 @@ def noising(self, x, lengths, dropout_prob=None, blank_idx=None):
# re-construct input
modified_lengths = torch.LongTensor(modified_lengths)
modified_x = torch.LongTensor(
- modified_lengths.max(),
- modified_lengths.size(0)
+ modified_lengths.max(), modified_lengths.size(0)
).fill_(self.dictionary.pad())
for i in range(modified_lengths.size(0)):
- modified_x[:modified_lengths[i], i].copy_(torch.LongTensor(sentences[i]))
+ modified_x[: modified_lengths[i], i].copy_(torch.LongTensor(sentences[i]))
return modified_x, modified_lengths
@@ -144,7 +150,13 @@ def noising(self, x, lengths, dropout_prob=None, blank_idx=None):
class WordShuffle(WordNoising):
"""Shuffle words by no more than k positions."""
- def __init__(self, dictionary, default_max_shuffle_distance=3, bpe_cont_marker="@@", bpe_end_marker=None):
+ def __init__(
+ self,
+ dictionary,
+ default_max_shuffle_distance=3,
+ bpe_cont_marker="@@",
+ bpe_end_marker=None,
+ ):
super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
self.default_max_shuffle_distance = 3
@@ -189,6 +201,7 @@ class UnsupervisedMTNoising(WordNoising):
Implements the default configuration for noising in UnsupervisedMT
(github.com/facebookresearch/UnsupervisedMT)
"""
+
def __init__(
self,
dictionary,
@@ -275,8 +288,13 @@ def __init__(
self.src_dataset = src_dataset
self.src_dict = src_dict
self.seed = seed
- self.noiser = noiser if noiser is not None else noising_class(
- dictionary=src_dict, **kwargs,
+ self.noiser = (
+ noiser
+ if noiser is not None
+ else noising_class(
+ dictionary=src_dict,
+ **kwargs,
+ )
)
def __getitem__(self, index):
diff --git a/fairseq/data/num_samples_dataset.py b/fairseq/data/num_samples_dataset.py
index 9d7ea44019..99a17495c7 100644
--- a/fairseq/data/num_samples_dataset.py
+++ b/fairseq/data/num_samples_dataset.py
@@ -7,7 +7,6 @@
class NumSamplesDataset(FairseqDataset):
-
def __getitem__(self, index):
return 1
diff --git a/fairseq/data/numel_dataset.py b/fairseq/data/numel_dataset.py
index 50087e5857..ac86dfd2f1 100644
--- a/fairseq/data/numel_dataset.py
+++ b/fairseq/data/numel_dataset.py
@@ -10,7 +10,6 @@
class NumelDataset(BaseWrapperDataset):
-
def __init__(self, dataset, reduce=False):
super().__init__(dataset)
self.reduce = reduce
diff --git a/fairseq/data/offset_tokens_dataset.py b/fairseq/data/offset_tokens_dataset.py
index a6fd559a30..6fabbdcdaa 100644
--- a/fairseq/data/offset_tokens_dataset.py
+++ b/fairseq/data/offset_tokens_dataset.py
@@ -7,7 +7,6 @@
class OffsetTokensDataset(BaseWrapperDataset):
-
def __init__(self, dataset, offset):
super().__init__(dataset)
self.offset = offset
diff --git a/fairseq/data/pad_dataset.py b/fairseq/data/pad_dataset.py
index 4c13b549aa..8075bba6a9 100644
--- a/fairseq/data/pad_dataset.py
+++ b/fairseq/data/pad_dataset.py
@@ -9,7 +9,6 @@
class PadDataset(BaseWrapperDataset):
-
def __init__(self, dataset, pad_idx, left_pad):
super().__init__(dataset)
self.pad_idx = pad_idx
@@ -20,12 +19,10 @@ def collater(self, samples):
class LeftPadDataset(PadDataset):
-
def __init__(self, dataset, pad_idx):
super().__init__(dataset, pad_idx, left_pad=True)
class RightPadDataset(PadDataset):
-
def __init__(self, dataset, pad_idx):
super().__init__(dataset, pad_idx, left_pad=False)
diff --git a/fairseq/data/plasma_utils.py b/fairseq/data/plasma_utils.py
index 33f250eea9..2b12646783 100644
--- a/fairseq/data/plasma_utils.py
+++ b/fairseq/data/plasma_utils.py
@@ -33,6 +33,7 @@ def plasma(self):
if self._plasma is None and not self.disable:
try:
import pyarrow.plasma as plasma
+
self._plasma = plasma
except ImportError:
self._plasma = None
@@ -45,11 +46,15 @@ def start_server(self):
assert self.path is None
self._server_tmp = tempfile.NamedTemporaryFile()
self.path = self._server_tmp.name
- self._server = subprocess.Popen([
- 'plasma_store',
- '-m', str(int(1.05 * self.array.nbytes)),
- '-s', self.path,
- ])
+ self._server = subprocess.Popen(
+ [
+ "plasma_store",
+ "-m",
+ str(int(1.05 * self.array.nbytes)),
+ "-s",
+ self.path,
+ ]
+ )
@property
def client(self):
@@ -65,11 +70,11 @@ def __getstate__(self):
self.start_server()
self.object_id = self.client.put(self.array)
state = self.__dict__.copy()
- del state['array']
- state['_client'] = None
- state['_server'] = None
- state['_server_tmp'] = None
- state['_plasma'] = None
+ del state["array"]
+ state["_client"] = None
+ state["_server"] = None
+ state["_server_tmp"] = None
+ state["_plasma"] = None
return state
def __setstate__(self, state):
diff --git a/fairseq/data/prepend_token_dataset.py b/fairseq/data/prepend_token_dataset.py
index 9dac71badf..fd1331f4c4 100644
--- a/fairseq/data/prepend_token_dataset.py
+++ b/fairseq/data/prepend_token_dataset.py
@@ -10,7 +10,6 @@
class PrependTokenDataset(BaseWrapperDataset):
-
def __init__(self, dataset, token=None):
super().__init__(dataset)
self.token = token
diff --git a/fairseq/data/raw_label_dataset.py b/fairseq/data/raw_label_dataset.py
index e67170f1a5..d054904f41 100644
--- a/fairseq/data/raw_label_dataset.py
+++ b/fairseq/data/raw_label_dataset.py
@@ -9,7 +9,6 @@
class RawLabelDataset(FairseqDataset):
-
def __init__(self, labels):
super().__init__()
self.labels = labels
diff --git a/fairseq/data/replace_dataset.py b/fairseq/data/replace_dataset.py
index 3bc52f0fb5..5aac2ba96b 100644
--- a/fairseq/data/replace_dataset.py
+++ b/fairseq/data/replace_dataset.py
@@ -9,12 +9,12 @@
class ReplaceDataset(BaseWrapperDataset):
"""Replaces tokens found in the dataset by a specified replacement token
- Args:
- dataset (~torch.utils.data.Dataset): dataset to replace tokens in
- replace_map(Dictionary[int,int]): map of token to replace -> replacement token
- offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be
- as many as the number of objects returned by the underlying dataset __getitem__ method.
- """
+ Args:
+ dataset (~torch.utils.data.Dataset): dataset to replace tokens in
+ replace_map(Dictionary[int,int]): map of token to replace -> replacement token
+ offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be
+ as many as the number of objects returned by the underlying dataset __getitem__ method.
+ """
def __init__(self, dataset, replace_map, offsets):
super().__init__(dataset)
diff --git a/fairseq/data/resampling_dataset.py b/fairseq/data/resampling_dataset.py
index ffb25ac668..3d3b993164 100644
--- a/fairseq/data/resampling_dataset.py
+++ b/fairseq/data/resampling_dataset.py
@@ -6,7 +6,6 @@
import logging
import numpy as np
-
from fairseq.data import BaseWrapperDataset, plasma_utils
@@ -112,7 +111,7 @@ def can_reuse_epoch_itr_across_epochs(self):
return False
def set_epoch(self, epoch):
- logger.debug('ResamplingDataset.set_epoch: {}'.format(epoch))
+ logger.debug("ResamplingDataset.set_epoch: {}".format(epoch))
super().set_epoch(epoch)
if epoch == self._cur_epoch:
diff --git a/fairseq/data/roll_dataset.py b/fairseq/data/roll_dataset.py
index d07800d0f6..a2915eeb3e 100644
--- a/fairseq/data/roll_dataset.py
+++ b/fairseq/data/roll_dataset.py
@@ -9,7 +9,6 @@
class RollDataset(BaseWrapperDataset):
-
def __init__(self, dataset, shifts):
super().__init__(dataset)
self.shifts = shifts
diff --git a/fairseq/data/round_robin_zip_datasets.py b/fairseq/data/round_robin_zip_datasets.py
index 5bfc966ce8..690823fc86 100644
--- a/fairseq/data/round_robin_zip_datasets.py
+++ b/fairseq/data/round_robin_zip_datasets.py
@@ -40,16 +40,19 @@ def __init__(self, datasets, eval_key=None):
self._ordered_indices = None
def _map_index(self, key, index):
- assert self._ordered_indices is not None, \
- 'Must call RoundRobinZipDatasets.ordered_indices() first'
+ assert (
+ self._ordered_indices is not None
+ ), "Must call RoundRobinZipDatasets.ordered_indices() first"
return self._ordered_indices[key][index % len(self.datasets[key])]
def __getitem__(self, index):
if self.eval_key is None:
- return OrderedDict([
- (key, dataset[self._map_index(key, index)])
- for key, dataset in self.datasets.items()
- ])
+ return OrderedDict(
+ [
+ (key, dataset[self._map_index(key, index)])
+ for key, dataset in self.datasets.items()
+ ]
+ )
else:
# at evaluation time it's useful to pass-through batches from a single key
return self.datasets[self.eval_key][self._map_index(self.eval_key, index)]
@@ -62,10 +65,12 @@ def collater(self, samples):
if len(samples) == 0:
return None
if self.eval_key is None:
- return OrderedDict([
- (key, dataset.collater([sample[key] for sample in samples]))
- for key, dataset in self.datasets.items()
- ])
+ return OrderedDict(
+ [
+ (key, dataset.collater([sample[key] for sample in samples]))
+ for key, dataset in self.datasets.items()
+ ]
+ )
else:
# at evaluation time it's useful to pass-through batches from a single key
return self.datasets[self.eval_key].collater(samples)
@@ -92,16 +97,18 @@ def ordered_indices(self):
# Call the underlying dataset's ordered_indices() here, so that we
# get the same random ordering as we would have from using the
# underlying dataset directly.
- self._ordered_indices = OrderedDict([
- (key, dataset.ordered_indices())
- for key, dataset in self.datasets.items()
- ])
+ self._ordered_indices = OrderedDict(
+ [
+ (key, dataset.ordered_indices())
+ for key, dataset in self.datasets.items()
+ ]
+ )
return np.arange(len(self))
@property
def supports_prefetch(self):
return all(
- getattr(dataset, 'supports_prefetch', False)
+ getattr(dataset, "supports_prefetch", False)
for dataset in self.datasets.values()
)
diff --git a/fairseq/data/shorten_dataset.py b/fairseq/data/shorten_dataset.py
index 85659d101e..6ebb5d88fe 100644
--- a/fairseq/data/shorten_dataset.py
+++ b/fairseq/data/shorten_dataset.py
@@ -10,8 +10,7 @@
class TruncateDataset(BaseWrapperDataset):
- """Truncate a sequence by returning the first truncation_length tokens
- """
+ """Truncate a sequence by returning the first truncation_length tokens"""
def __init__(self, dataset, truncation_length):
super().__init__(dataset)
@@ -23,7 +22,7 @@ def __getitem__(self, index):
item = self.dataset[index]
item_len = item.size(0)
if item_len > self.truncation_length:
- item = item[:self.truncation_length]
+ item = item[: self.truncation_length]
return item
@property
@@ -35,8 +34,7 @@ def __len__(self):
class RandomCropDataset(TruncateDataset):
- """Truncate a sequence by returning a random crop of truncation_length tokens
- """
+ """Truncate a sequence by returning a random crop of truncation_length tokens"""
def __init__(self, dataset, truncation_length, seed=1):
super().__init__(dataset, truncation_length)
@@ -58,9 +56,10 @@ def __getitem__(self, index):
excess = item_len - self.truncation_length
if excess > 0:
start_idx = np.random.randint(0, excess)
- item = item[start_idx:start_idx+self.truncation_length]
+ item = item[start_idx : start_idx + self.truncation_length]
return item
+
def maybe_shorten_dataset(
dataset,
split,
@@ -69,10 +68,11 @@ def maybe_shorten_dataset(
tokens_per_sample,
seed,
):
- truncate_split = split in shorten_data_split_list.split(',') \
- or len(shorten_data_split_list) == 0
- if shorten_method == 'truncate' and truncate_split:
+ truncate_split = (
+ split in shorten_data_split_list.split(",") or len(shorten_data_split_list) == 0
+ )
+ if shorten_method == "truncate" and truncate_split:
dataset = TruncateDataset(dataset, tokens_per_sample)
- elif shorten_method == 'random_crop' and truncate_split:
+ elif shorten_method == "random_crop" and truncate_split:
dataset = RandomCropDataset(dataset, tokens_per_sample, seed)
return dataset
diff --git a/fairseq/data/sort_dataset.py b/fairseq/data/sort_dataset.py
index 9b510b93a0..b3890e7279 100644
--- a/fairseq/data/sort_dataset.py
+++ b/fairseq/data/sort_dataset.py
@@ -9,7 +9,6 @@
class SortDataset(BaseWrapperDataset):
-
def __init__(self, dataset, sort_order):
super().__init__(dataset)
if not isinstance(sort_order, (list, tuple)):
diff --git a/fairseq/data/strip_token_dataset.py b/fairseq/data/strip_token_dataset.py
index e388db0e5f..cae39ba4d2 100644
--- a/fairseq/data/strip_token_dataset.py
+++ b/fairseq/data/strip_token_dataset.py
@@ -7,7 +7,6 @@
class StripTokenDataset(BaseWrapperDataset):
-
def __init__(self, dataset, id_to_strip):
super().__init__(dataset)
self.id_to_strip = id_to_strip
diff --git a/fairseq/data/subsample_dataset.py b/fairseq/data/subsample_dataset.py
index 7eca9d4cb3..48feaf883f 100644
--- a/fairseq/data/subsample_dataset.py
+++ b/fairseq/data/subsample_dataset.py
@@ -16,10 +16,10 @@
class SubsampleDataset(BaseWrapperDataset):
"""Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples
- Args:
- dataset (~torch.utils.data.Dataset): dataset to subsample
- size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive)
- """
+ Args:
+ dataset (~torch.utils.data.Dataset): dataset to subsample
+ size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive)
+ """
def __init__(self, dataset, size_ratio, shuffle=False):
super().__init__(dataset)
diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py
index cae872c310..aa33f9d06f 100644
--- a/fairseq/data/token_block_dataset.py
+++ b/fairseq/data/token_block_dataset.py
@@ -5,7 +5,6 @@
import numpy as np
import torch
-
from fairseq.data import FairseqDataset, plasma_utils
@@ -31,6 +30,7 @@ class TokenBlockDataset(FairseqDataset):
'complete_doc' break mode). Typically 1 if the sentences have eos
and 0 otherwise.
"""
+
def __init__(
self,
dataset,
@@ -49,8 +49,8 @@ def __init__(
)
except ImportError:
raise ImportError(
- 'Please build Cython components with: `pip install --editable .` '
- 'or `python setup.py build_ext --inplace`'
+ "Please build Cython components with: `pip install --editable .` "
+ "or `python setup.py build_ext --inplace`"
)
super().__init__()
@@ -69,13 +69,15 @@ def __init__(
sizes = sizes.numpy()
sizes = sizes.astype(np.int64)
- break_mode = break_mode if break_mode is not None else 'none'
+ break_mode = break_mode if break_mode is not None else "none"
# For "eos" break-mode, block_size is not required parameters.
if break_mode == "eos" and block_size is None:
block_size = 0
- slice_indices = _get_slice_indices_fast(sizes, str(break_mode), block_size, document_sep_len)
+ slice_indices = _get_slice_indices_fast(
+ sizes, str(break_mode), block_size, document_sep_len
+ )
self._sizes = slice_indices[:, 1] - slice_indices[:, 0]
# build index mapping block indices to the underlying dataset indices
diff --git a/fairseq/data/transform_eos_dataset.py b/fairseq/data/transform_eos_dataset.py
index 4ce5ad811b..fb14ff018e 100644
--- a/fairseq/data/transform_eos_dataset.py
+++ b/fairseq/data/transform_eos_dataset.py
@@ -33,11 +33,11 @@ def __init__(
has_target=True,
):
if not isinstance(dataset, FairseqDataset):
- raise ValueError('dataset must be an instance of FairseqDataset')
+ raise ValueError("dataset must be an instance of FairseqDataset")
if append_eos_to_src and remove_eos_from_src:
- raise ValueError('cannot combine append_eos_to_src and remove_eos_from_src')
+ raise ValueError("cannot combine append_eos_to_src and remove_eos_from_src")
if append_eos_to_tgt and remove_eos_from_tgt:
- raise ValueError('cannot combine append_eos_to_tgt and remove_eos_from_tgt')
+ raise ValueError("cannot combine append_eos_to_tgt and remove_eos_from_tgt")
self.dataset = dataset
self.eos = torch.LongTensor([eos])
@@ -75,24 +75,23 @@ def __len__(self):
return len(self.dataset)
def collater(self, samples):
-
def transform(item):
if self.append_eos_to_src:
- self.eos = self.eos.to(device=item['source'].device)
- self._check_src(item['source'], expect_eos=False)
- item['source'] = torch.cat([item['source'], self.eos])
+ self.eos = self.eos.to(device=item["source"].device)
+ self._check_src(item["source"], expect_eos=False)
+ item["source"] = torch.cat([item["source"], self.eos])
if self.remove_eos_from_src:
- self.eos = self.eos.to(device=item['source'].device)
- self._check_src(item['source'], expect_eos=True)
- item['source'] = item['source'][:-1]
+ self.eos = self.eos.to(device=item["source"].device)
+ self._check_src(item["source"], expect_eos=True)
+ item["source"] = item["source"][:-1]
if self.append_eos_to_tgt:
- self.eos = self.eos.to(device=item['target'].device)
- self._check_tgt(item['target'], expect_eos=False)
- item['target'] = torch.cat([item['target'], self.eos])
+ self.eos = self.eos.to(device=item["target"].device)
+ self._check_tgt(item["target"], expect_eos=False)
+ item["target"] = torch.cat([item["target"], self.eos])
if self.remove_eos_from_tgt:
- self.eos = self.eos.to(device=item['target'].device)
- self._check_tgt(item['target'], expect_eos=True)
- item['target'] = item['target'][:-1]
+ self.eos = self.eos.to(device=item["target"].device)
+ self._check_tgt(item["target"], expect_eos=True)
+ item["target"] = item["target"][:-1]
return item
samples = list(map(transform, samples))
@@ -115,7 +114,7 @@ def ordered_indices(self):
@property
def supports_prefetch(self):
- return getattr(self.dataset, 'supports_prefetch', False)
+ return getattr(self.dataset, "supports_prefetch", False)
def prefetch(self, indices):
return self.dataset.prefetch(indices)
diff --git a/fairseq/data/transform_eos_lang_pair_dataset.py b/fairseq/data/transform_eos_lang_pair_dataset.py
index 2783824838..1dd3d93d2b 100644
--- a/fairseq/data/transform_eos_lang_pair_dataset.py
+++ b/fairseq/data/transform_eos_lang_pair_dataset.py
@@ -4,10 +4,12 @@
# LICENSE file in the root directory of this source tree.
-from . import FairseqDataset
-import torch
from typing import Optional
+import torch
+
+from . import FairseqDataset
+
class TransformEosLangPairDataset(FairseqDataset):
"""A :class:`~fairseq.data.FairseqDataset` wrapper that transform bos on
@@ -50,25 +52,37 @@ def collater(self, samples, **extra_args):
if self.new_src_eos is not None:
if self.dataset.left_pad_source:
- assert(samples['net_input']['src_tokens'][:, -1] != self.src_eos).sum() == 0
- samples['net_input']['src_tokens'][:, -1] = self.new_src_eos
+ assert (
+ samples["net_input"]["src_tokens"][:, -1] != self.src_eos
+ ).sum() == 0
+ samples["net_input"]["src_tokens"][:, -1] = self.new_src_eos
else:
- eos_idx = samples['net_input']['src_lengths'] - 1
- assert(
- samples['net_input']['src_tokens'][torch.arange(eos_idx.size(0)), eos_idx] != self.src_eos
+ eos_idx = samples["net_input"]["src_lengths"] - 1
+ assert (
+ samples["net_input"]["src_tokens"][
+ torch.arange(eos_idx.size(0)), eos_idx
+ ]
+ != self.src_eos
).sum() == 0
- eos_idx = eos_idx.resize_(len(samples['net_input']['src_lengths']), 1)
- samples['net_input']['src_tokens'].scatter_(1, eos_idx, self.new_src_eos)
+ eos_idx = eos_idx.resize_(len(samples["net_input"]["src_lengths"]), 1)
+ samples["net_input"]["src_tokens"].scatter_(
+ 1, eos_idx, self.new_src_eos
+ )
- if self.new_tgt_bos is not None and 'prev_output_tokens' in samples['net_input']:
+ if (
+ self.new_tgt_bos is not None
+ and "prev_output_tokens" in samples["net_input"]
+ ):
if self.dataset.left_pad_target:
# TODO: support different padding direction on target side
raise NotImplementedError(
- 'TransformEosLangPairDataset does not implement --left-pad-target True option'
+ "TransformEosLangPairDataset does not implement --left-pad-target True option"
)
else:
- assert (samples['net_input']['prev_output_tokens'][:, 0] != self.tgt_bos).sum() == 0
- samples['net_input']['prev_output_tokens'][:, 0] = self.new_tgt_bos
+ assert (
+ samples["net_input"]["prev_output_tokens"][:, 0] != self.tgt_bos
+ ).sum() == 0
+ samples["net_input"]["prev_output_tokens"][:, 0] = self.new_tgt_bos
return samples
@@ -88,7 +102,7 @@ def ordered_indices(self):
@property
def supports_prefetch(self):
- return getattr(self.dataset, 'supports_prefetch', False)
+ return getattr(self.dataset, "supports_prefetch", False)
def prefetch(self, indices):
return self.dataset.prefetch(indices)
diff --git a/fairseq/dataclass/data_class.py b/fairseq/dataclass/data_class.py
index 0685c968d5..ed1d12d865 100644
--- a/fairseq/dataclass/data_class.py
+++ b/fairseq/dataclass/data_class.py
@@ -113,12 +113,13 @@ class CommonParams(FairseqDataclass):
default="", metadata={"help": "suffix to add to the checkpoint file name"}
)
checkpoint_shard_count: int = field(
- default=1, metadata={
+ default=1,
+ metadata={
"help": "Number of shards containing the checkpoint - "
- "if the checkpoint is over 300GB, it is preferable "
- "to split it into shards to prevent OOM on CPU while loading "
- "the checkpoint"
- }
+ "if the checkpoint is over 300GB, it is preferable "
+ "to split it into shards to prevent OOM on CPU while loading "
+ "the checkpoint"
+ },
)
quantization_config_path: Optional[str] = field(
default=None, metadata={"help": "path to quantization config file"}
@@ -307,7 +308,10 @@ class DatasetParams(FairseqDataclass):
default=8, metadata={"help": "batch size will be a multiplier of this value"}
)
required_seq_len_multiple: int = field(
- default=1, metadata={"help": "maximum sequence length in batch will be a multiplier of this value"}
+ default=1,
+ metadata={
+ "help": "maximum sequence length in batch will be a multiplier of this value"
+ },
)
dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = field(
default=None, metadata={"help": "output dataset implementation"}
@@ -351,8 +355,7 @@ class DatasetParams(FairseqDataclass):
batch_size_valid: Optional[int] = field(
default=None,
metadata={
- "help": "batch size of the validation batch"
- " (defaults to --batch-size)"
+ "help": "batch size of the validation batch" " (defaults to --batch-size)"
},
)
curriculum: int = field(
diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py
index 9ab235d16d..599cc2b4c2 100644
--- a/fairseq/dataclass/utils.py
+++ b/fairseq/dataclass/utils.py
@@ -164,7 +164,9 @@ def get_kwargs_from_dc(
raise NotImplementedError()
if field_default is not MISSING:
kwargs["default"] = ",".join(map(str, field_default))
- elif (isinstance(inter_type, type) and issubclass(inter_type, Enum)) or "Enum" in str(inter_type):
+ elif (
+ isinstance(inter_type, type) and issubclass(inter_type, Enum)
+ ) or "Enum" in str(inter_type):
kwargs["type"] = str
if field_default is not MISSING:
if isinstance(field_default, Enum):
@@ -184,7 +186,7 @@ def get_kwargs_from_dc(
kwargs["help"] = field_help
if field_const is not None:
kwargs["const"] = field_const
- kwargs["nargs"] = '?'
+ kwargs["nargs"] = "?"
return kwargs
for k in dataclass_instance._get_all_attributes():
diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py
index ab5aad1425..bcb0595e6e 100644
--- a/fairseq/distributed_utils.py
+++ b/fairseq/distributed_utils.py
@@ -16,7 +16,6 @@
import torch
import torch.distributed as dist
-
from fairseq import utils
@@ -28,83 +27,103 @@ def is_master(args):
def infer_init_method(args, force_distributed=False):
- if args.distributed_init_method is not None or getattr(args, 'tpu', False):
+ if args.distributed_init_method is not None or getattr(args, "tpu", False):
return
if args.pipeline_model_parallel:
- balance_exists = args.pipeline_balance is not None or \
- args.pipeline_encoder_balance is not None or \
- args.pipeline_decoder_balance is not None
- devices_exist = args.pipeline_devices is not None or \
- args.pipeline_encoder_devices is not None or \
- args.pipeline_decoder_devices is not None
+ balance_exists = (
+ args.pipeline_balance is not None
+ or args.pipeline_encoder_balance is not None
+ or args.pipeline_decoder_balance is not None
+ )
+ devices_exist = (
+ args.pipeline_devices is not None
+ or args.pipeline_encoder_devices is not None
+ or args.pipeline_decoder_devices is not None
+ )
if not balance_exists:
- raise ValueError('--pipeline-balance is currently required for pipeline model parallelism')
+ raise ValueError(
+ "--pipeline-balance is currently required for pipeline model parallelism"
+ )
if not devices_exist:
- raise ValueError('--pipeline-devices is currently required for pipeline model parallelism')
+ raise ValueError(
+ "--pipeline-devices is currently required for pipeline model parallelism"
+ )
args.pipeline_balance = utils.eval_str_list(args.pipeline_balance, type=int)
if args.pipeline_devices is not None:
args.pipeline_devices = utils.eval_str_list(args.pipeline_devices, type=int)
num_pipeline_devices = len(set(args.pipeline_devices))
else:
- args.pipeline_encoder_devices = utils.eval_str_list(args.pipeline_encoder_devices, type=int)
- args.pipeline_decoder_devices = utils.eval_str_list(args.pipeline_decoder_devices, type=int)
- num_pipeline_devices = len(set(args.pipeline_encoder_devices + args.pipeline_decoder_devices))
+ args.pipeline_encoder_devices = utils.eval_str_list(
+ args.pipeline_encoder_devices, type=int
+ )
+ args.pipeline_decoder_devices = utils.eval_str_list(
+ args.pipeline_decoder_devices, type=int
+ )
+ num_pipeline_devices = len(
+ set(args.pipeline_encoder_devices + args.pipeline_decoder_devices)
+ )
gpus_per_node = torch.cuda.device_count()
- assert gpus_per_node >= num_pipeline_devices and gpus_per_node % num_pipeline_devices == 0, (
- 'the number of unique device IDs in --pipeline-devices must evenly divide '
- 'the number of GPUs per node (multi-node pipelining is not yet supported)'
+ assert (
+ gpus_per_node >= num_pipeline_devices
+ and gpus_per_node % num_pipeline_devices == 0
+ ), (
+ "the number of unique device IDs in --pipeline-devices must evenly divide "
+ "the number of GPUs per node (multi-node pipelining is not yet supported)"
)
num_pipelines_per_node = gpus_per_node // num_pipeline_devices
# support torch.distributed.launch
- if all(key in os.environ for key in [
- 'MASTER_ADDR', 'MASTER_PORT', 'WORLD_SIZE', 'RANK'
- ]):
- args.distributed_init_method = 'env://'
- args.distributed_world_size = int(os.environ['WORLD_SIZE'])
- args.distributed_rank = int(os.environ['RANK'])
+ if all(
+ key in os.environ
+ for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"]
+ ):
+ args.distributed_init_method = "env://"
+ args.distributed_world_size = int(os.environ["WORLD_SIZE"])
+ args.distributed_rank = int(os.environ["RANK"])
# processes are created by torch.distributed.launch
args.distributed_no_spawn = True
# we can determine the init method automatically for Slurm
elif args.distributed_port > 0:
- node_list = os.environ.get('SLURM_STEP_NODELIST')
+ node_list = os.environ.get("SLURM_STEP_NODELIST")
if node_list is None:
- node_list = os.environ.get('SLURM_JOB_NODELIST')
+ node_list = os.environ.get("SLURM_JOB_NODELIST")
if node_list is not None:
try:
- hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
- args.distributed_init_method = 'tcp://{host}:{port}'.format(
- host=hostnames.split()[0].decode('utf-8'),
+ hostnames = subprocess.check_output(
+ ["scontrol", "show", "hostnames", node_list]
+ )
+ args.distributed_init_method = "tcp://{host}:{port}".format(
+ host=hostnames.split()[0].decode("utf-8"),
port=args.distributed_port,
)
- nnodes = int(os.environ.get('SLURM_NNODES'))
- ntasks_per_node = os.environ.get('SLURM_NTASKS_PER_NODE')
+ nnodes = int(os.environ.get("SLURM_NNODES"))
+ ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
if ntasks_per_node is not None:
ntasks_per_node = int(ntasks_per_node)
else:
- ntasks = int(os.environ.get('SLURM_NTASKS'))
- nnodes = int(os.environ.get('SLURM_NNODES'))
+ ntasks = int(os.environ.get("SLURM_NTASKS"))
+ nnodes = int(os.environ.get("SLURM_NNODES"))
assert ntasks % nnodes == 0
ntasks_per_node = int(ntasks / nnodes)
if ntasks_per_node == 1:
gpus_per_node = torch.cuda.device_count()
- node_id = int(os.environ.get('SLURM_NODEID'))
+ node_id = int(os.environ.get("SLURM_NODEID"))
args.distributed_rank = node_id * gpus_per_node
args.distributed_world_size = nnodes * gpus_per_node
elif args.pipeline_model_parallel:
assert ntasks_per_node == num_pipelines_per_node, (
- 'SLURM --ntasks-per-node must match number of pipelines per '
- 'node (={})'.format(num_pipelines_per_node)
+ "SLURM --ntasks-per-node must match number of pipelines per "
+ "node (={})".format(num_pipelines_per_node)
)
args.distributed_no_spawn = True
# For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on
# the first node, [1, 2] on the second node, etc. This
# matches torch.distributed.launch.
- node_id = int(os.environ.get('SLURM_NODEID'))
- local_id = int(os.environ.get('SLURM_LOCALID'))
+ node_id = int(os.environ.get("SLURM_NODEID"))
+ local_id = int(os.environ.get("SLURM_LOCALID"))
args.distributed_rank = node_id * num_pipelines_per_node + local_id
# In the above example, device_id will always be in [0, 1],
# which also matches torch.distributed.launch.
@@ -115,8 +134,8 @@ def infer_init_method(args, force_distributed=False):
else:
assert ntasks_per_node == args.distributed_world_size // nnodes
args.distributed_no_spawn = True
- args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
- args.device_id = int(os.environ.get('SLURM_LOCALID'))
+ args.distributed_rank = int(os.environ.get("SLURM_PROCID"))
+ args.device_id = int(os.environ.get("SLURM_LOCALID"))
except subprocess.CalledProcessError as e: # scontrol failed
raise e
except FileNotFoundError: # Slurm is not installed
@@ -126,7 +145,7 @@ def infer_init_method(args, force_distributed=False):
# fallback for single node with multiple GPUs
assert args.distributed_world_size <= torch.cuda.device_count()
port = random.randint(10000, 20000)
- args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
+ args.distributed_init_method = "tcp://localhost:{port}".format(port=port)
if args.pipeline_model_parallel:
if not args.distributed_no_spawn:
@@ -134,7 +153,9 @@ def infer_init_method(args, force_distributed=False):
# distributed_world_size to be based on the total number of GPUs, so
# we need to correct them to be based on the number of pipelines.
assert args.distributed_world_size % num_pipeline_devices == 0
- args.distributed_world_size = args.distributed_world_size // num_pipeline_devices
+ args.distributed_world_size = (
+ args.distributed_world_size // num_pipeline_devices
+ )
# In the case of 4-way MP on nodes with 8 GPUs, we want
# distributed_rank to be the starting GPU index for each pipeline
# i.e., 0, 2, ...
@@ -152,14 +173,16 @@ def infer_init_method(args, force_distributed=False):
# if there's multiple pipelines on a node (e.g., 4-way MP on an 8
# GPU node), we need to adjust pipeline_devices accordingly
logger.debug(
- "setting CUDA device={} on rank {}"
- .format(args.device_id, args.distributed_rank)
+ "setting CUDA device={} on rank {}".format(
+ args.device_id, args.distributed_rank
+ )
)
torch.cuda.set_device(args.device_id)
args.pipeline_devices = [args.device_id + d for d in args.pipeline_devices]
logger.info(
- "setting pipeline_devices={} on rank {}"
- .format(args.pipeline_devices, args.distributed_rank),
+ "setting pipeline_devices={} on rank {}".format(
+ args.pipeline_devices, args.distributed_rank
+ ),
)
elif not args.distributed_no_spawn:
args.distributed_num_procs = min(
@@ -169,22 +192,30 @@ def infer_init_method(args, force_distributed=False):
def distributed_init(args):
- if not getattr(args, 'tpu', False):
+ if not getattr(args, "tpu", False):
if torch.distributed.is_initialized():
- warnings.warn('Distributed is already initialized, cannot initialize twice!')
+ warnings.warn(
+ "Distributed is already initialized, cannot initialize twice!"
+ )
else:
- logger.info('distributed init (rank {}): {}'.format(
- args.distributed_rank, args.distributed_init_method,
- ))
+ logger.info(
+ "distributed init (rank {}): {}".format(
+ args.distributed_rank,
+ args.distributed_init_method,
+ )
+ )
dist.init_process_group(
backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank,
)
- logger.info('initialized host {} as rank {}'.format(
- socket.gethostname(), args.distributed_rank,
- ))
+ logger.info(
+ "initialized host {} as rank {}".format(
+ socket.gethostname(),
+ args.distributed_rank,
+ )
+ )
# perform a dummy all-reduce to initialize the NCCL communicator
if torch.cuda.is_available():
@@ -193,10 +224,11 @@ def distributed_init(args):
args.distributed_rank = torch.distributed.get_rank()
else:
import torch_xla.core.xla_model as xm
+
assert xm.xrt_world_size() == args.distributed_world_size
args.device_id = xm.get_local_ordinal()
args.distributed_rank = xm.get_ordinal()
- xm.rendezvous('distributed_init') # wait for all workers
+ xm.rendezvous("distributed_init") # wait for all workers
xm.mark_step()
if not is_master(args):
@@ -211,14 +243,14 @@ def distributed_init(args):
)
except ImportError:
raise ImportError(
- '\n\nPlease install the megatron submodule:'
- '\n\n git submodule update --init '
- 'fairseq/model_parallel/megatron'
+ "\n\nPlease install the megatron submodule:"
+ "\n\n git submodule update --init "
+ "fairseq/model_parallel/megatron"
)
initialize_model_parallel(args.model_parallel_size)
model_parallel_cuda_manual_seed(args.seed)
model_part_number = get_model_parallel_rank()
- args.checkpoint_suffix += '-model_part-{0}'.format(model_part_number)
+ args.checkpoint_suffix += "-model_part-{0}".format(model_part_number)
return args.distributed_rank
@@ -227,11 +259,11 @@ def distributed_main(i, main, args, kwargs):
if torch.cuda.is_available() and not args.cpu and not getattr(args, "tpu", False):
torch.cuda.set_device(args.device_id)
if args.distributed_rank is None: # torch.multiprocessing.spawn
- args.distributed_rank = kwargs.pop('start_rank', 0) + i
+ args.distributed_rank = kwargs.pop("start_rank", 0) + i
args.distributed_rank = distributed_init(args)
- after_distributed_init_fn = kwargs.pop('after_distributed_init_fn', None)
+ after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None)
if after_distributed_init_fn:
args = after_distributed_init_fn(args)
@@ -247,7 +279,7 @@ def call_main(args, main, **kwargs):
if not args.distributed_no_spawn:
start_rank = args.distributed_rank
args.distributed_rank = None # assign automatically
- kwargs['start_rank'] = start_rank
+ kwargs["start_rank"] = start_rank
torch.multiprocessing.spawn(
fn=distributed_main,
args=(main, args, kwargs),
@@ -257,6 +289,7 @@ def call_main(args, main, **kwargs):
distributed_main(args.device_id, main, args, kwargs)
elif getattr(args, "tpu", False) and args.distributed_world_size > 1:
import torch_xla.distributed.xla_multiprocessing as xmp
+
torch.multiprocessing.set_sharing_strategy("file_system")
xmp.spawn(
fn=distributed_main,
@@ -281,9 +314,10 @@ def get_default_group():
def all_reduce(tensor, group=None):
- if isinstance(group, tuple) and group[0] == 'tpu':
+ if isinstance(group, tuple) and group[0] == "tpu":
import torch_xla.core.xla_model as xm
- return xm.all_reduce('sum', [tensor], groups=group[1])
+
+ return xm.all_reduce("sum", [tensor], groups=group[1])
else:
if group is None:
group = get_default_group()
@@ -306,8 +340,10 @@ def all_gather_list(data, group=None, max_size=16384):
world_size = get_world_size()
buffer_size = max_size * world_size
- if not hasattr(all_gather_list, '_buffer') or \
- all_gather_list._buffer.numel() < buffer_size:
+ if (
+ not hasattr(all_gather_list, "_buffer")
+ or all_gather_list._buffer.numel() < buffer_size
+ ):
all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
buffer = all_gather_list._buffer
@@ -320,12 +356,14 @@ def all_gather_list(data, group=None, max_size=16384):
header_size = 4 # size of header that contains the length of the encoded data
size = header_size + enc_size
if size > max_size:
- raise ValueError('encoded data size ({}) exceeds max_size ({})'.format(size, max_size))
+ raise ValueError(
+ "encoded data size ({}) exceeds max_size ({})".format(size, max_size)
+ )
header = struct.pack(">I", enc_size)
cpu_buffer[:size] = torch.ByteTensor(list(header + enc))
start = rank * max_size
- buffer[start:start + size].copy_(cpu_buffer[:size])
+ buffer[start : start + size].copy_(cpu_buffer[:size])
all_reduce(buffer, group=group)
@@ -333,20 +371,24 @@ def all_gather_list(data, group=None, max_size=16384):
try:
result = []
for i in range(world_size):
- out_buffer = buffer[i * max_size:(i + 1) * max_size]
- enc_size, = struct.unpack(">I", bytes(out_buffer[:header_size].tolist()))
+ out_buffer = buffer[i * max_size : (i + 1) * max_size]
+ (enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist()))
if enc_size > 0:
- result.append(pickle.loads(bytes(out_buffer[header_size:header_size + enc_size].tolist())))
+ result.append(
+ pickle.loads(
+ bytes(out_buffer[header_size : header_size + enc_size].tolist())
+ )
+ )
return result
except pickle.UnpicklingError:
raise Exception(
- 'Unable to unpickle data from other workers. all_gather_list requires all '
- 'workers to enter the function together, so this error usually indicates '
- 'that the workers have fallen out of sync somehow. Workers can fall out of '
- 'sync if one of them runs out of memory, or if there are other conditions '
- 'in your training script that can cause one worker to finish an epoch '
- 'while other workers are still iterating over their portions of the data. '
- 'Try rerunning with --ddp-backend=no_c10d and see if that helps.'
+ "Unable to unpickle data from other workers. all_gather_list requires all "
+ "workers to enter the function together, so this error usually indicates "
+ "that the workers have fallen out of sync somehow. Workers can fall out of "
+ "sync if one of them runs out of memory, or if there are other conditions "
+ "in your training script that can cause one worker to finish an epoch "
+ "while other workers are still iterating over their portions of the data. "
+ "Try rerunning with --ddp-backend=no_c10d and see if that helps."
)
diff --git a/fairseq/file_utils.py b/fairseq/file_utils.py
index 62278b367d..0a94ac7112 100644
--- a/fairseq/file_utils.py
+++ b/fairseq/file_utils.py
@@ -10,25 +10,28 @@
"""
import fnmatch
-from functools import wraps, partial
-from hashlib import sha256
-from io import open
import json
import logging
import os
import shutil
import tarfile
import tempfile
+from functools import partial, wraps
+from hashlib import sha256
+from io import open
try:
from torch.hub import _get_torch_home
+
torch_cache_home = _get_torch_home()
except ImportError:
torch_cache_home = os.path.expanduser(
- os.getenv('TORCH_HOME', os.path.join(
- os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
-default_cache_path = os.path.join(torch_cache_home, 'pytorch_fairseq')
+ os.getenv(
+ "TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")
+ )
+ )
+default_cache_path = os.path.join(torch_cache_home, "pytorch_fairseq")
try:
from urllib.parse import urlparse
@@ -37,11 +40,10 @@
try:
from pathlib import Path
- PYTORCH_FAIRSEQ_CACHE = Path(
- os.getenv('PYTORCH_FAIRSEQ_CACHE', default_cache_path))
+
+ PYTORCH_FAIRSEQ_CACHE = Path(os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path))
except (AttributeError, ImportError):
- PYTORCH_FAIRSEQ_CACHE = os.getenv(
- 'PYTORCH_FAIRSEQ_CACHE', default_cache_path)
+ PYTORCH_FAIRSEQ_CACHE = os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path)
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
@@ -67,17 +69,23 @@ def load_archive_file(archive_file):
if resolved_archive_file == archive_file:
logger.info("loading archive file {}".format(archive_file))
else:
- logger.info("loading archive file {} from cache at {}".format(
- archive_file, resolved_archive_file))
+ logger.info(
+ "loading archive file {} from cache at {}".format(
+ archive_file, resolved_archive_file
+ )
+ )
# Extract archive to temp dir and replace .tar.bz2 if necessary
tempdir = None
if not os.path.isdir(resolved_archive_file):
tempdir = tempfile.mkdtemp()
- logger.info("extracting archive file {} to temp dir {}".format(
- resolved_archive_file, tempdir))
+ logger.info(
+ "extracting archive file {} to temp dir {}".format(
+ resolved_archive_file, tempdir
+ )
+ )
ext = os.path.splitext(archive_file)[1][1:]
- with tarfile.open(resolved_archive_file, 'r:' + ext) as archive:
+ with tarfile.open(resolved_archive_file, "r:" + ext) as archive:
top_dir = os.path.commonprefix(archive.getnames())
archive.extractall(tempdir)
os.remove(resolved_archive_file)
@@ -93,14 +101,14 @@ def url_to_filename(url, etag=None):
If `etag` is specified, append its hash to the URL's, delimited
by a period.
"""
- url_bytes = url.encode('utf-8')
+ url_bytes = url.encode("utf-8")
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
if etag:
- etag_bytes = etag.encode('utf-8')
+ etag_bytes = etag.encode("utf-8")
etag_hash = sha256(etag_bytes)
- filename += '.' + etag_hash.hexdigest()
+ filename += "." + etag_hash.hexdigest()
return filename
@@ -119,14 +127,14 @@ def filename_to_url(filename, cache_dir=None):
if not os.path.exists(cache_path):
raise EnvironmentError("file {} not found".format(cache_path))
- meta_path = cache_path + '.json'
+ meta_path = cache_path + ".json"
if not os.path.exists(meta_path):
raise EnvironmentError("file {} not found".format(meta_path))
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
- url = metadata['url']
- etag = metadata['etag']
+ url = metadata["url"]
+ etag = metadata["etag"]
return url, etag
@@ -147,18 +155,20 @@ def cached_path(url_or_filename, cache_dir=None):
parsed = urlparse(url_or_filename)
- if parsed.scheme in ('http', 'https', 's3'):
+ if parsed.scheme in ("http", "https", "s3"):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
- elif parsed.scheme == '':
+ elif parsed.scheme == "":
# File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename))
else:
# Something unknown
- raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
+ raise ValueError(
+ "unable to parse {} as a URL or as a local path".format(url_or_filename)
+ )
def split_s3_path(url):
@@ -183,6 +193,7 @@ def s3_request(func):
@wraps(func)
def wrapper(url, *args, **kwargs):
from botocore.exceptions import ClientError
+
try:
return func(url, *args, **kwargs)
except ClientError as exc:
@@ -198,6 +209,7 @@ def wrapper(url, *args, **kwargs):
def s3_etag(url):
"""Check ETag on S3 object."""
import boto3
+
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_object = s3_resource.Object(bucket_name, s3_path)
@@ -208,6 +220,7 @@ def s3_etag(url):
def s3_get(url, temp_file):
"""Pull a file directly from S3."""
import boto3
+
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
@@ -215,12 +228,18 @@ def s3_get(url, temp_file):
def request_wrap_timeout(func, url):
import requests
+
for attempt, timeout in enumerate([10, 20, 40, 60, 60]):
try:
return func(timeout=timeout)
except requests.exceptions.Timeout as e:
- logger.warning("Request for %s timed-out (attempt %d). Retrying with a timeout of %d secs",
- url, attempt, timeout, exc_info=e)
+ logger.warning(
+ "Request for %s timed-out (attempt %d). Retrying with a timeout of %d secs",
+ url,
+ attempt,
+ timeout,
+ exc_info=e,
+ )
continue
raise RuntimeError(f"Unable to fetch file {url}")
@@ -230,7 +249,7 @@ def http_get(url, temp_file):
from tqdm import tqdm
req = request_wrap_timeout(partial(requests.get, url, stream=True), url)
- content_length = req.headers.get('Content-Length')
+ content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024):
@@ -259,7 +278,10 @@ def get_from_cache(url, cache_dir=None):
else:
try:
import requests
- response = request_wrap_timeout(partial(requests.head, url, allow_redirects=True), url)
+
+ response = request_wrap_timeout(
+ partial(requests.head, url, allow_redirects=True), url
+ )
if response.status_code != 200:
etag = None
else:
@@ -275,8 +297,8 @@ def get_from_cache(url, cache_dir=None):
# If we don't have a connection (etag is None) and can't identify the file
# try to get the last downloaded one
if not os.path.exists(cache_path) and etag is None:
- matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
- matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
+ matching_files = fnmatch.filter(os.listdir(cache_dir), filename + ".*")
+ matching_files = list(filter(lambda s: not s.endswith(".json"), matching_files))
if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1])
@@ -298,13 +320,13 @@ def get_from_cache(url, cache_dir=None):
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
- with open(cache_path, 'wb') as cache_file:
+ with open(cache_path, "wb") as cache_file:
shutil.copyfileobj(temp_file, cache_file)
logger.info("creating metadata file for %s", cache_path)
- meta = {'url': url, 'etag': etag}
- meta_path = cache_path + '.json'
- with open(meta_path, 'w') as meta_file:
+ meta = {"url": url, "etag": etag}
+ meta_path = cache_path + ".json"
+ with open(meta_path, "w") as meta_file:
output_string = json.dumps(meta)
meta_file.write(output_string)
@@ -314,12 +336,12 @@ def get_from_cache(url, cache_dir=None):
def read_set_from_file(filename):
- '''
+ """
Extract a de-duped collection (set) of text from a file.
Expected file format is one item per line.
- '''
+ """
collection = set()
- with open(filename, 'r', encoding='utf-8') as file_:
+ with open(filename, "r", encoding="utf-8") as file_:
for line in file_:
collection.add(line.rstrip())
return collection
diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py
index b56135abf3..b293e54e2a 100644
--- a/fairseq/hub_utils.py
+++ b/fairseq/hub_utils.py
@@ -8,13 +8,12 @@
import copy
import logging
import os
-from typing import List, Dict, Iterator, Tuple, Any
+from typing import Any, Dict, Iterator, List, Tuple
import torch
-from torch import nn
-
from fairseq import utils
from fairseq.data import encoders
+from torch import nn
logger = logging.getLogger(__name__)
@@ -22,8 +21,8 @@
def from_pretrained(
model_name_or_path,
- checkpoint_file='model.pt',
- data_name_or_path='.',
+ checkpoint_file="model.pt",
+ data_name_or_path=".",
archive_map=None,
**kwargs
):
@@ -39,34 +38,34 @@ def from_pretrained(
# for each model
if isinstance(model_name_or_path, dict):
for k, v in model_name_or_path.items():
- if k == 'checkpoint_file':
+ if k == "checkpoint_file":
checkpoint_file = v
elif (
- k != 'path'
+ k != "path"
# only set kwargs that don't already have overrides
and k not in kwargs
):
kwargs[k] = v
- model_name_or_path = model_name_or_path['path']
+ model_name_or_path = model_name_or_path["path"]
model_path = file_utils.load_archive_file(model_name_or_path)
# convenience hack for loading data and BPE codes from model archive
- if data_name_or_path.startswith('.'):
- kwargs['data'] = os.path.abspath(os.path.join(model_path, data_name_or_path))
+ if data_name_or_path.startswith("."):
+ kwargs["data"] = os.path.abspath(os.path.join(model_path, data_name_or_path))
else:
- kwargs['data'] = file_utils.load_archive_file(data_name_or_path)
+ kwargs["data"] = file_utils.load_archive_file(data_name_or_path)
for file, arg in {
- 'code': 'bpe_codes',
- 'bpecodes': 'bpe_codes',
- 'sentencepiece.bpe.model': 'sentencepiece_model',
+ "code": "bpe_codes",
+ "bpecodes": "bpe_codes",
+ "sentencepiece.bpe.model": "sentencepiece_model",
}.items():
path = os.path.join(model_path, file)
if os.path.exists(path):
kwargs[arg] = path
- if 'user_dir' in kwargs:
- utils.import_user_module(argparse.Namespace(user_dir=kwargs['user_dir']))
+ if "user_dir" in kwargs:
+ utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"]))
models, args, task = checkpoint_utils.load_model_ensemble_and_task(
[os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)],
@@ -74,9 +73,9 @@ def from_pretrained(
)
return {
- 'args': args,
- 'task': task,
- 'models': models,
+ "args": args,
+ "task": task,
+ "models": models,
}
@@ -100,7 +99,7 @@ def __init__(self, args, task, models):
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
- self.align_dict = utils.load_align_dict(getattr(args, 'replace_unk', None))
+ self.align_dict = utils.load_align_dict(getattr(args, "replace_unk", None))
self.tokenizer = encoders.build_tokenizer(args)
self.bpe = encoders.build_bpe(args)
@@ -110,28 +109,37 @@ def __init__(self, args, task, models):
)
# this is useful for determining the device
- self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
+ self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))
@property
def device(self):
return self._float_tensor.device
- def translate(self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs) -> List[str]:
+ def translate(
+ self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs
+ ) -> List[str]:
return self.sample(sentences, beam, verbose, **kwargs)
- def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> List[str]:
+ def sample(
+ self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
+ ) -> List[str]:
if isinstance(sentences, str):
return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
tokenized_sentences = [self.encode(sentence) for sentence in sentences]
batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
- return [self.decode(hypos[0]['tokens']) for hypos in batched_hypos]
+ return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]
def score(self, sentences: List[str], **kwargs):
if isinstance(sentences, str):
return self.score([sentences], **kwargs)[0]
# NOTE: this doesn't support translation tasks currently
tokenized_sentences = [self.encode(sentence) for sentence in sentences]
- return [hypos[0] for hypos in self.generate(tokenized_sentences, score_reference=True, **kwargs)]
+ return [
+ hypos[0]
+ for hypos in self.generate(
+ tokenized_sentences, score_reference=True, **kwargs
+ )
+ ]
def generate(
self,
@@ -174,17 +182,33 @@ def getarg(name, default):
for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
src_str_with_unk = self.string(source_tokens)
- logger.info('S\t{}'.format(src_str_with_unk))
+ logger.info("S\t{}".format(src_str_with_unk))
for hypo in target_hypotheses:
- hypo_str = self.decode(hypo['tokens'])
- logger.info('H\t{}\t{}'.format(hypo['score'], hypo_str))
- logger.info('P\t{}'.format(
- ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
- ))
- if hypo['alignment'] is not None and getarg('print_alignment', False):
- logger.info('A\t{}'.format(
- ' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in hypo['alignment']])
- ))
+ hypo_str = self.decode(hypo["tokens"])
+ logger.info("H\t{}\t{}".format(hypo["score"], hypo_str))
+ logger.info(
+ "P\t{}".format(
+ " ".join(
+ map(
+ lambda x: "{:.4f}".format(x),
+ hypo["positional_scores"].tolist(),
+ )
+ )
+ )
+ )
+ if hypo["alignment"] is not None and getarg(
+ "print_alignment", False
+ ):
+ logger.info(
+ "A\t{}".format(
+ " ".join(
+ [
+ "{}-{}".format(src_idx, tgt_idx)
+ for src_idx, tgt_idx in hypo["alignment"]
+ ]
+ )
+ )
+ )
return outputs
def encode(self, sentence: str) -> torch.LongTensor:
diff --git a/fairseq/incremental_decoding_utils.py b/fairseq/incremental_decoding_utils.py
index 91128e8879..b26e6cd01c 100644
--- a/fairseq/incremental_decoding_utils.py
+++ b/fairseq/incremental_decoding_utils.py
@@ -3,14 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from typing import Dict, Optional
import uuid
+from typing import Dict, Optional
from torch import Tensor
class FairseqIncrementalState(object):
-
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.init_incremental_state()
@@ -46,5 +45,7 @@ def set_incremental_state(
def with_incremental_state(cls):
- cls.__bases__ = (FairseqIncrementalState,) + tuple(b for b in cls.__bases__ if b != FairseqIncrementalState)
+ cls.__bases__ = (FairseqIncrementalState,) + tuple(
+ b for b in cls.__bases__ if b != FairseqIncrementalState
+ )
return cls
diff --git a/fairseq/iterative_refinement_generator.py b/fairseq/iterative_refinement_generator.py
index 6ac805988a..4fb0946f49 100644
--- a/fairseq/iterative_refinement_generator.py
+++ b/fairseq/iterative_refinement_generator.py
@@ -5,20 +5,15 @@
from collections import namedtuple
-import torch
import numpy as np
-
+import torch
from fairseq import utils
-DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
- 'output_tokens',
- 'output_scores',
- 'attn',
- 'step',
- 'max_step',
- 'history'
-])
+DecoderOut = namedtuple(
+ "IterativeRefinementDecoderOut",
+ ["output_tokens", "output_scores", "attn", "step", "max_step", "history"],
+)
class IterativeRefinementGenerator(object):
@@ -103,11 +98,12 @@ def generate_batched_itr(
ref = utils.strip_pad(sample["target"][i, :], self.pad)
yield id, src, ref, hypos[i]
-
@torch.no_grad()
def generate(self, models, sample, prefix_tokens=None, constraints=None):
if constraints is not None:
- raise NotImplementedError("Constrained decoding with the IterativeRefinementGenerator is not supported")
+ raise NotImplementedError(
+ "Constrained decoding with the IterativeRefinementGenerator is not supported"
+ )
# TODO: iterative refinement generator does not support ensemble for now.
if not self.retain_dropout:
@@ -117,13 +113,17 @@ def generate(self, models, sample, prefix_tokens=None, constraints=None):
model, reranker = models[0], None
if self.reranking:
assert len(models) > 1, "Assuming the last checkpoint is the reranker"
- assert self.beam_size > 1, "Reranking requires multiple translation for each example"
+ assert (
+ self.beam_size > 1
+ ), "Reranking requires multiple translation for each example"
reranker = models[-1]
models = models[:-1]
- if len(models) > 1 and hasattr(model, 'enable_ensemble'):
- assert model.allow_ensemble, "{} does not support ensembling".format(model.__class__.__name__)
+ if len(models) > 1 and hasattr(model, "enable_ensemble"):
+ assert model.allow_ensemble, "{} does not support ensembling".format(
+ model.__class__.__name__
+ )
model.enable_ensemble(models)
# TODO: better encoder inputs?
@@ -136,13 +136,22 @@ def generate(self, models, sample, prefix_tokens=None, constraints=None):
prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
if self.beam_size > 1:
- assert model.allow_length_beam, \
- "{} does not support decoding with length beam.".format(model.__class__.__name__)
+ assert (
+ model.allow_length_beam
+ ), "{} does not support decoding with length beam.".format(
+ model.__class__.__name__
+ )
# regenerate data based on length-beam
- length_beam_order = utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
- encoder_out = model.encoder.reorder_encoder_out(encoder_out, length_beam_order)
- prev_decoder_out = model.regenerate_length_beam(prev_decoder_out, self.beam_size)
+ length_beam_order = (
+ utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
+ )
+ encoder_out = model.encoder.reorder_encoder_out(
+ encoder_out, length_beam_order
+ )
+ prev_decoder_out = model.regenerate_length_beam(
+ prev_decoder_out, self.beam_size
+ )
bsz = bsz * self.beam_size
sent_idxs = torch.arange(bsz)
@@ -206,7 +215,10 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
if self.adaptive:
# terminate if there is a loop
terminated, out_tokens, out_scores, out_attn = is_a_loop(
- prev_output_tokens, decoder_out.output_tokens, decoder_out.output_scores, decoder_out.attn
+ prev_output_tokens,
+ decoder_out.output_tokens,
+ decoder_out.output_scores,
+ decoder_out.attn,
)
decoder_out = decoder_out._replace(
output_tokens=out_tokens,
@@ -215,7 +227,9 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
)
else:
- terminated = decoder_out.output_tokens.new_zeros(decoder_out.output_tokens.size(0)).bool()
+ terminated = decoder_out.output_tokens.new_zeros(
+ decoder_out.output_tokens.size(0)
+ ).bool()
if step == self.max_iter: # reach last iteration, terminate
terminated.fill_(1)
@@ -225,7 +239,9 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
finalized_tokens = decoder_out.output_tokens[terminated]
finalized_scores = decoder_out.output_scores[terminated]
finalized_attn = (
- None if (decoder_out.attn is None or decoder_out.attn.size(0) == 0) else decoder_out.attn[terminated]
+ None
+ if (decoder_out.attn is None or decoder_out.attn.size(0) == 0)
+ else decoder_out.attn[terminated]
)
if self.retain_history:
@@ -242,13 +258,11 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
]
if self.retain_history:
- finalized[finalized_idxs[i]][0]['history'] = []
+ finalized[finalized_idxs[i]][0]["history"] = []
for j in range(len(finalized_history_tokens)):
- finalized[finalized_idxs[i]][0]['history'].append(
+ finalized[finalized_idxs[i]][0]["history"].append(
finalized_hypos(
- step,
- finalized_history_tokens[j][i],
- None, None
+ step, finalized_history_tokens[j][i], None, None
)
)
@@ -268,7 +282,9 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
if decoder_out.history is not None
else None,
)
- encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero(as_tuple=False).squeeze())
+ encoder_out = model.encoder.reorder_encoder_out(
+ encoder_out, not_terminated.nonzero(as_tuple=False).squeeze()
+ )
sent_idxs = sent_idxs[not_terminated]
prev_output_tokens = prev_decoder_out.output_tokens.clone()
@@ -280,38 +296,64 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
# aggregate information from length beam
finalized = [
- finalized[np.argmax(
- [finalized[self.beam_size * i + j][0]['score'] for j in range(self.beam_size)]
- ) + self.beam_size * i] for i in range(len(finalized) // self.beam_size)
+ finalized[
+ np.argmax(
+ [
+ finalized[self.beam_size * i + j][0]["score"]
+ for j in range(self.beam_size)
+ ]
+ )
+ + self.beam_size * i
]
+ for i in range(len(finalized) // self.beam_size)
+ ]
return finalized
def rerank(self, reranker, finalized, encoder_input, beam_size):
-
def rebuild_batch(finalized):
- finalized_tokens = [f[0]['tokens'] for f in finalized]
+ finalized_tokens = [f[0]["tokens"] for f in finalized]
finalized_maxlen = max(f.size(0) for f in finalized_tokens)
- final_output_tokens = finalized_tokens[0].new_zeros(len(finalized_tokens), finalized_maxlen).fill_(self.pad)
+ final_output_tokens = (
+ finalized_tokens[0]
+ .new_zeros(len(finalized_tokens), finalized_maxlen)
+ .fill_(self.pad)
+ )
for i, f in enumerate(finalized_tokens):
- final_output_tokens[i, :f.size(0)] = f
+ final_output_tokens[i, : f.size(0)] = f
return final_output_tokens
final_output_tokens = rebuild_batch(finalized)
- final_output_tokens[:, 0] = self.eos # autoregressive model assumes starting with EOS
+ final_output_tokens[
+ :, 0
+ ] = self.eos # autoregressive model assumes starting with EOS
reranker_encoder_out = reranker.encoder(*encoder_input)
- length_beam_order = utils.new_arange(
- final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1)).t().reshape(-1)
- reranker_encoder_out = reranker.encoder.reorder_encoder_out(reranker_encoder_out, length_beam_order)
+ length_beam_order = (
+ utils.new_arange(
+ final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1)
+ )
+ .t()
+ .reshape(-1)
+ )
+ reranker_encoder_out = reranker.encoder.reorder_encoder_out(
+ reranker_encoder_out, length_beam_order
+ )
reranking_scores = reranker.get_normalized_probs(
- reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out), True, None)
+ reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out),
+ True,
+ None,
+ )
reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None])
reranking_masks = final_output_tokens[:, 1:].ne(self.pad)
- reranking_scores = reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1)
- reranking_scores = reranking_scores / reranking_masks.sum(1).type_as(reranking_scores)
+ reranking_scores = (
+ reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1)
+ )
+ reranking_scores = reranking_scores / reranking_masks.sum(1).type_as(
+ reranking_scores
+ )
for i in range(len(finalized)):
- finalized[i][0]['score'] = reranking_scores[i]
+ finalized[i][0]["score"] = reranking_scores[i]
return finalized
diff --git a/fairseq/legacy_distributed_data_parallel.py b/fairseq/legacy_distributed_data_parallel.py
index 9832f2c97a..44f87c7c42 100644
--- a/fairseq/legacy_distributed_data_parallel.py
+++ b/fairseq/legacy_distributed_data_parallel.py
@@ -14,9 +14,9 @@
training with `--update-freq`.
"""
+import copy
from collections import OrderedDict
from contextlib import contextmanager
-import copy
import torch
from torch import nn
@@ -42,7 +42,7 @@ class LegacyDistributedDataParallel(nn.Module):
performing all-reduce (default: 256M).
"""
- def __init__(self, module, world_size, process_group=None, buffer_size=2**28):
+ def __init__(self, module, world_size, process_group=None, buffer_size=2 ** 28):
super().__init__()
self.module = module
@@ -66,7 +66,6 @@ def __init__(self, module, world_size, process_group=None, buffer_size=2**28):
paramlists[device] += [param]
self.per_device_params = list(paramlists.values())
-
def __getstate__(self):
attrs = copy.copy(self.__dict__)
return attrs
@@ -99,10 +98,10 @@ def all_reduce_params(params):
for p in params:
sz = p.numel()
if p.grad is not None:
- buffer[offset:offset+sz].copy_(p.grad.data.view(-1))
+ buffer[offset : offset + sz].copy_(p.grad.data.view(-1))
nonzero_buffer = True
else:
- buffer[offset:offset+sz].zero_()
+ buffer[offset : offset + sz].zero_()
offset += sz
else:
# we only have a single grad to all-reduce
@@ -111,7 +110,7 @@ def all_reduce_params(params):
buffer = p.grad.data
nonzero_buffer = True
elif p.numel() <= self.buffer.numel():
- buffer = buffer[:p.numel()]
+ buffer = buffer[: p.numel()]
buffer.zero_()
else:
buffer = torch.zeros_like(p)
@@ -126,9 +125,9 @@ def all_reduce_params(params):
for p in params:
sz = p.numel()
if p.grad is not None:
- p.grad.data.copy_(buffer[offset:offset+sz].view_as(p))
+ p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))
else:
- p.grad = buffer[offset:offset+sz].view_as(p).clone()
+ p.grad = buffer[offset : offset + sz].view_as(p).clone()
offset += sz
def reduction_fn():
@@ -149,9 +148,11 @@ def reduction_fn():
if param.grad is None:
param.grad = torch.zeros_like(param)
if param.grad.requires_grad:
- raise RuntimeError("DistributedDataParallel only works "
- "with gradients that don't require "
- "grad")
+ raise RuntimeError(
+ "DistributedDataParallel only works "
+ "with gradients that don't require "
+ "grad"
+ )
sz = param.numel()
if sz > self.buffer.numel():
# all-reduce big params directly
diff --git a/fairseq/logging/meters.py b/fairseq/logging/meters.py
index 78e6d4d224..6793ef54e6 100644
--- a/fairseq/logging/meters.py
+++ b/fairseq/logging/meters.py
@@ -4,10 +4,11 @@
# LICENSE file in the root directory of this source tree.
import bisect
-from collections import OrderedDict
import time
+from collections import OrderedDict
from typing import Dict, Optional
+
try:
import torch
@@ -16,6 +17,8 @@ def type_as(a, b):
return a.to(b)
else:
return a
+
+
except ImportError:
torch = None
@@ -51,11 +54,11 @@ def smoothed_value(self) -> float:
def safe_round(number, ndigits):
- if hasattr(number, '__round__'):
+ if hasattr(number, "__round__"):
return round(number, ndigits)
elif torch is not None and torch.is_tensor(number) and number.numel() == 1:
return safe_round(number.item(), ndigits)
- elif np is not None and np.ndim(number) == 0 and hasattr(number, 'item'):
+ elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"):
return safe_round(number.item(), ndigits)
else:
return number
@@ -82,17 +85,17 @@ def update(self, val, n=1):
def state_dict(self):
return {
- 'val': self.val,
- 'sum': self.sum,
- 'count': self.count,
- 'round': self.round,
+ "val": self.val,
+ "sum": self.sum,
+ "count": self.count,
+ "round": self.round,
}
def load_state_dict(self, state_dict):
- self.val = state_dict['val']
- self.sum = state_dict['sum']
- self.count = state_dict['count']
- self.round = state_dict.get('round', None)
+ self.val = state_dict["val"]
+ self.sum = state_dict["sum"]
+ self.count = state_dict["count"]
+ self.round = state_dict.get("round", None)
@property
def avg(self):
@@ -130,18 +133,18 @@ def update(self, val=1):
def state_dict(self):
return {
- 'init': self.elapsed_time,
- 'n': self.n,
- 'round': self.round,
+ "init": self.elapsed_time,
+ "n": self.n,
+ "round": self.round,
}
def load_state_dict(self, state_dict):
- if 'start' in state_dict:
+ if "start" in state_dict:
# backwards compatibility for old state_dicts
- self.reset(init=state_dict['init'])
+ self.reset(init=state_dict["init"])
else:
- self.reset(init=state_dict['init'], n=state_dict['n'])
- self.round = state_dict.get('round', None)
+ self.reset(init=state_dict["init"], n=state_dict["n"])
+ self.round = state_dict.get("round", None)
@property
def avg(self):
@@ -186,16 +189,16 @@ def reset(self):
def state_dict(self):
return {
- 'sum': self.sum,
- 'n': self.n,
- 'round': self.round,
+ "sum": self.sum,
+ "n": self.n,
+ "round": self.round,
}
def load_state_dict(self, state_dict):
- self.sum = state_dict['sum']
- self.n = state_dict['n']
+ self.sum = state_dict["sum"]
+ self.n = state_dict["n"]
self.start_time = None
- self.round = state_dict.get('round', None)
+ self.round = state_dict.get("round", None)
@property
def avg(self):
@@ -204,7 +207,7 @@ def avg(self):
@property
def elapsed_time(self):
if self.start_time is None:
- return 0.
+ return 0.0
return time.perf_counter() - self.start_time
@property
@@ -263,11 +266,13 @@ def get_smoothed_value(self, key: str) -> float:
def get_smoothed_values(self) -> Dict[str, float]:
"""Get all smoothed values."""
- return OrderedDict([
- (key, self.get_smoothed_value(key))
- for key in self.keys()
- if not key.startswith("_")
- ])
+ return OrderedDict(
+ [
+ (key, self.get_smoothed_value(key))
+ for key in self.keys()
+ if not key.startswith("_")
+ ]
+ )
def reset(self):
"""Reset Meter instances."""
diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py
index 6ca1d201e0..7b56e31592 100644
--- a/fairseq/logging/metrics.py
+++ b/fairseq/logging/metrics.py
@@ -11,11 +11,11 @@
:func:`aggregate` context manager for more details.
"""
-from collections import defaultdict, OrderedDict
import contextlib
import time
-from typing import Callable, Dict, List, Optional
import uuid
+from collections import OrderedDict, defaultdict
+from typing import Callable, Dict, List, Optional
from .meters import *
@@ -184,7 +184,7 @@ def log_start_time(key: str, priority: int = 40, round: Optional[int] = None):
agg[key].start()
-def log_stop_time(key: str, weight: float = 0., prehook=None):
+def log_stop_time(key: str, weight: float = 0.0, prehook=None):
"""Log the duration of some event in seconds.
The duration will be computed since :func:`log_start_time` was called.
@@ -279,10 +279,7 @@ def get_smoothed_values(name: str) -> Dict[str, float]:
def state_dict():
- return OrderedDict([
- (name, agg.state_dict())
- for name, agg in _aggregators.items()
- ])
+ return OrderedDict([(name, agg.state_dict()) for name, agg in _aggregators.items()])
def load_state_dict(state_dict):
diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py
index 97e4162ea0..63e5394815 100644
--- a/fairseq/logging/progress_bar.py
+++ b/fairseq/logging/progress_bar.py
@@ -32,29 +32,30 @@ def progress_bar(
epoch: Optional[int] = None,
prefix: Optional[str] = None,
tensorboard_logdir: Optional[str] = None,
- default_log_format: str = 'tqdm',
+ default_log_format: str = "tqdm",
):
if log_format is None:
log_format = default_log_format
- if log_format == 'tqdm' and not sys.stderr.isatty():
- log_format = 'simple'
+ if log_format == "tqdm" and not sys.stderr.isatty():
+ log_format = "simple"
- if log_format == 'json':
+ if log_format == "json":
bar = JsonProgressBar(iterator, epoch, prefix, log_interval)
- elif log_format == 'none':
+ elif log_format == "none":
bar = NoopProgressBar(iterator, epoch, prefix)
- elif log_format == 'simple':
+ elif log_format == "simple":
bar = SimpleProgressBar(iterator, epoch, prefix, log_interval)
- elif log_format == 'tqdm':
+ elif log_format == "tqdm":
bar = TqdmProgressBar(iterator, epoch, prefix)
else:
- raise ValueError('Unknown log format: {}'.format(log_format))
+ raise ValueError("Unknown log format: {}".format(log_format))
if tensorboard_logdir:
try:
# [FB only] custom wrapper for TensorBoard
import palaas # noqa
from .fb_tbmf_wrapper import FbTbmfWrapper
+
bar = FbTbmfWrapper(bar, log_interval)
except ImportError:
bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)
@@ -67,14 +68,14 @@ def build_progress_bar(
iterator,
epoch: Optional[int] = None,
prefix: Optional[str] = None,
- default: str = 'tqdm',
- no_progress_bar: str = 'none',
+ default: str = "tqdm",
+ no_progress_bar: str = "none",
):
"""Legacy wrapper that takes an argparse.Namespace."""
- if getattr(args, 'no_progress_bar', False):
+ if getattr(args, "no_progress_bar", False):
default = no_progress_bar
- if getattr(args, 'distributed_rank', 0) == 0:
- tensorboard_logdir = getattr(args, 'tensorboard_logdir', None)
+ if getattr(args, "distributed_rank", 0) == 0:
+ tensorboard_logdir = getattr(args, "tensorboard_logdir", None)
else:
tensorboard_logdir = None
return progress_bar(
@@ -90,13 +91,13 @@ def build_progress_bar(
def format_stat(stat):
if isinstance(stat, Number):
- stat = '{:g}'.format(stat)
+ stat = "{:g}".format(stat)
elif isinstance(stat, AverageMeter):
- stat = '{:.3f}'.format(stat.avg)
+ stat = "{:.3f}".format(stat.avg)
elif isinstance(stat, TimeMeter):
- stat = '{:g}'.format(round(stat.avg))
+ stat = "{:g}".format(round(stat.avg))
elif isinstance(stat, StopwatchMeter):
- stat = '{:g}'.format(round(stat.sum))
+ stat = "{:g}".format(round(stat.sum))
elif torch.is_tensor(stat):
stat = stat.tolist()
return stat
@@ -104,15 +105,16 @@ def format_stat(stat):
class BaseProgressBar(object):
"""Abstract class for progress bars."""
+
def __init__(self, iterable, epoch=None, prefix=None):
self.iterable = iterable
- self.n = getattr(iterable, 'n', 0)
+ self.n = getattr(iterable, "n", 0)
self.epoch = epoch
- self.prefix = ''
+ self.prefix = ""
if epoch is not None:
- self.prefix += 'epoch {:03d}'.format(epoch)
+ self.prefix += "epoch {:03d}".format(epoch)
if prefix is not None:
- self.prefix += ' | {}'.format(prefix)
+ self.prefix += " | {}".format(prefix)
def __len__(self):
return len(self.iterable)
@@ -135,12 +137,10 @@ def print(self, stats, tag=None, step=None):
raise NotImplementedError
def _str_commas(self, stats):
- return ', '.join(key + '=' + stats[key].strip()
- for key in stats.keys())
+ return ", ".join(key + "=" + stats[key].strip() for key in stats.keys())
def _str_pipes(self, stats):
- return ' | '.join(key + ' ' + stats[key].strip()
- for key in stats.keys())
+ return " | ".join(key + " " + stats[key].strip() for key in stats.keys())
def _format_stats(self, stats):
postfix = OrderedDict(stats)
@@ -177,11 +177,7 @@ def __iter__(self):
def log(self, stats, tag=None, step=None):
"""Log intermediate stats according to log_interval."""
step = step or self.i or 0
- if (
- step > 0
- and self.log_interval is not None
- and step % self.log_interval == 0
- ):
+ if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
update = (
self.epoch - 1 + (self.i + 1) / float(self.size)
if self.epoch is not None
@@ -195,7 +191,9 @@ def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
self.stats = stats
if tag is not None:
- self.stats = OrderedDict([(tag + '_' + k, v) for k, v in self.stats.items()])
+ self.stats = OrderedDict(
+ [(tag + "_" + k, v) for k, v in self.stats.items()]
+ )
stats = self._format_stats(self.stats, epoch=self.epoch)
with rename_logger(logger, tag):
logger.info(json.dumps(stats))
@@ -203,9 +201,9 @@ def print(self, stats, tag=None, step=None):
def _format_stats(self, stats, epoch=None, update=None):
postfix = OrderedDict()
if epoch is not None:
- postfix['epoch'] = epoch
+ postfix["epoch"] = epoch
if update is not None:
- postfix['update'] = round(update, 3)
+ postfix["update"] = round(update, 3)
# Preprocess stats according to datatype
for key in stats.keys():
postfix[key] = format_stat(stats[key])
@@ -249,24 +247,21 @@ def __iter__(self):
def log(self, stats, tag=None, step=None):
"""Log intermediate stats according to log_interval."""
step = step or self.i or 0
- if (
- step > 0
- and self.log_interval is not None
- and step % self.log_interval == 0
- ):
+ if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
stats = self._format_stats(stats)
postfix = self._str_commas(stats)
with rename_logger(logger, tag):
logger.info(
- '{}: {:5d} / {:d} {}'
- .format(self.prefix, self.i + 1, self.size, postfix)
+ "{}: {:5d} / {:d} {}".format(
+ self.prefix, self.i + 1, self.size, postfix
+ )
)
def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
postfix = self._str_pipes(self._format_stats(stats))
with rename_logger(logger, tag):
- logger.info('{} | {}'.format(self.prefix, postfix))
+ logger.info("{} | {}".format(self.prefix, postfix))
class TqdmProgressBar(BaseProgressBar):
@@ -275,6 +270,7 @@ class TqdmProgressBar(BaseProgressBar):
def __init__(self, iterable, epoch=None, prefix=None):
super().__init__(iterable, epoch, prefix)
from tqdm import tqdm
+
self.tqdm = tqdm(
iterable,
self.prefix,
@@ -293,7 +289,7 @@ def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
postfix = self._str_pipes(self._format_stats(stats))
with rename_logger(logger, tag):
- logger.info('{} | {}'.format(self.prefix, postfix))
+ logger.info("{} | {}".format(self.prefix, postfix))
try:
@@ -329,7 +325,7 @@ def _writer(self, key):
_writers = _tensorboard_writers
if key not in _writers:
_writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key))
- _writers[key].add_text('sys.argv', " ".join(sys.argv))
+ _writers[key].add_text("sys.argv", " ".join(sys.argv))
return _writers[key]
def __iter__(self):
@@ -346,12 +342,12 @@ def print(self, stats, tag=None, step=None):
self.wrapped_bar.print(stats, tag=tag, step=step)
def _log_to_tensorboard(self, stats, tag=None, step=None):
- writer = self._writer(tag or '')
+ writer = self._writer(tag or "")
if writer is None:
return
if step is None:
- step = stats['num_updates']
- for key in stats.keys() - {'num_updates'}:
+ step = stats["num_updates"]
+ for key in stats.keys() - {"num_updates"}:
if isinstance(stats[key], AverageMeter):
writer.add_scalar(key, stats[key].val, step)
elif isinstance(stats[key], Number):
diff --git a/fairseq/model_parallel/__init__.py b/fairseq/model_parallel/__init__.py
index cc563db40b..69f2168487 100644
--- a/fairseq/model_parallel/__init__.py
+++ b/fairseq/model_parallel/__init__.py
@@ -3,4 +3,4 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from . import criterions, modules, models # noqa
+from . import criterions, models, modules # noqa
diff --git a/fairseq/model_parallel/criterions/__init__.py b/fairseq/model_parallel/criterions/__init__.py
index b74de55982..6239b50362 100644
--- a/fairseq/model_parallel/criterions/__init__.py
+++ b/fairseq/model_parallel/criterions/__init__.py
@@ -9,6 +9,6 @@
# automatically import any Python files in the criterions/ directory
for file in os.listdir(os.path.dirname(__file__)):
- if file.endswith('.py') and not file.startswith('_'):
- module = file[:file.find('.py')]
- importlib.import_module('fairseq.model_parallel.criterions.' + module)
+ if file.endswith(".py") and not file.startswith("_"):
+ module = file[: file.find(".py")]
+ importlib.import_module("fairseq.model_parallel.criterions." + module)
diff --git a/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py b/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py
index eab8f9af4e..35c50ee152 100644
--- a/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py
+++ b/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py
@@ -8,24 +8,27 @@
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
+
try:
- from fairseq.model_parallel.megatron.mpu.cross_entropy import vocab_parallel_cross_entropy
+ from fairseq.model_parallel.megatron.mpu.cross_entropy import (
+ vocab_parallel_cross_entropy,
+ )
+
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
-@register_criterion('vocab_parallel_cross_entropy')
+@register_criterion("vocab_parallel_cross_entropy")
class VocabParallelCrossEntropyCriterion(FairseqCriterion):
-
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg
if not has_megatron_submodule:
raise ImportError(
- '\n\nPlease install the megatron submodule:'
- '\n\n git submodule update --init '
- 'fairseq/model_parallel/megatron'
+ "\n\nPlease install the megatron submodule:"
+ "\n\n git submodule update --init "
+ "fairseq/model_parallel/megatron"
)
def forward(self, model, sample, reduce=True):
@@ -36,33 +39,43 @@ def forward(self, model, sample, reduce=True):
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
- net_output = model(**sample['net_input'])
- target = sample['target']
+ net_output = model(**sample["net_input"])
+ target = sample["target"]
loss = vocab_parallel_cross_entropy(net_output[0].float(), target)
loss = (loss * (target != self.padding_idx)).sum()
- sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
+ sample_size = (
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
+ )
logging_output = {
- 'loss': utils.item(loss.data) if reduce else loss.data,
- 'ntokens': sample['ntokens'],
- 'nsentences': sample['target'].size(0),
- 'sample_size': sample_size,
+ "loss": utils.item(loss.data) if reduce else loss.data,
+ "ntokens": sample["ntokens"],
+ "nsentences": sample["target"].size(0),
+ "sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
- loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
- ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
- sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
- metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
+ metrics.log_scalar(
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+ )
if sample_size != ntokens:
- metrics.log_scalar('nll_loss', loss_sum / ntokens / math.log(2), ntokens, round=3)
- metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))
+ metrics.log_scalar(
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
+ )
+ metrics.log_derived(
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
+ )
else:
- metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['loss'].avg))
+ metrics.log_derived(
+ "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
+ )
@staticmethod
def logging_outputs_can_be_summed() -> bool:
diff --git a/fairseq/model_parallel/megatron_trainer.py b/fairseq/model_parallel/megatron_trainer.py
index d1142a993c..761ffc8e61 100644
--- a/fairseq/model_parallel/megatron_trainer.py
+++ b/fairseq/model_parallel/megatron_trainer.py
@@ -10,6 +10,7 @@
from fairseq import distributed_utils
from fairseq.trainer import Trainer
+
try:
from fairseq.model_parallel.megatron.mpu import (
get_data_parallel_group,
@@ -18,20 +19,21 @@
get_model_parallel_group,
get_model_parallel_src_rank,
)
+
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
class MegatronTrainer(Trainer):
- """Main class for model parallel with data parallel training.
- """
+ """Main class for model parallel with data parallel training."""
+
def __init__(self, args, task, model, criterion):
if not has_megatron_submodule:
raise ImportError(
- '\n\nPlease install the megatron submodule:'
- '\n\n git submodule update --init '
- 'fairseq/model_parallel/megatron'
+ "\n\nPlease install the megatron submodule:"
+ "\n\n git submodule update --init "
+ "fairseq/model_parallel/megatron"
)
super().__init__(args, task, model, criterion)
@@ -57,6 +59,7 @@ def _aggregate_model_parallel_grad_norm(total_norm):
distributed_utils.all_reduce(total_norm, group=get_model_parallel_group())
total_norm = total_norm ** 0.5
return total_norm
+
return self.optimizer.clip_grad_norm(
clip_norm,
aggregate_norm_fn=_aggregate_model_parallel_grad_norm,
diff --git a/fairseq/model_parallel/models/__init__.py b/fairseq/model_parallel/models/__init__.py
index a3207981ad..3532479e52 100644
--- a/fairseq/model_parallel/models/__init__.py
+++ b/fairseq/model_parallel/models/__init__.py
@@ -11,6 +11,10 @@
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
path = os.path.join(models_dir, file)
- if not file.startswith('_') and not file.startswith('.') and (file.endswith('.py') or os.path.isdir(path)):
- model_name = file[:file.find('.py')] if file.endswith('.py') else file
- module = importlib.import_module('fairseq.model_parallel.models.' + model_name)
+ if (
+ not file.startswith("_")
+ and not file.startswith(".")
+ and (file.endswith(".py") or os.path.isdir(path))
+ ):
+ model_name = file[: file.find(".py")] if file.endswith(".py") else file
+ module = importlib.import_module("fairseq.model_parallel.models." + model_name)
diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py
index e11f491486..eb81ded341 100644
--- a/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py
+++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py
@@ -3,31 +3,35 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from collections import namedtuple
import math
+from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq import options, utils
from fairseq.modules import (
AdaptiveSoftmax,
LayerNorm,
- PositionalEmbedding,
MultiheadAttention,
+ PositionalEmbedding,
)
-EncoderOut = namedtuple('TransformerEncoderOut', [
- 'encoder_out', # T x B x C
- 'encoder_padding_mask', # B x T
- 'encoder_embedding', # B x T x C
- 'encoder_states', # List[T x B x C]
-])
+
+EncoderOut = namedtuple(
+ "TransformerEncoderOut",
+ [
+ "encoder_out", # T x B x C
+ "encoder_padding_mask", # B x T
+ "encoder_embedding", # B x T x C
+ "encoder_states", # List[T x B x C]
+ ],
+)
class TransformerEncoderEmbedding(nn.Module):
""" Encoder Embedding + Positional Embedding """
+
def __init__(self, args, embed_tokens):
super().__init__()
self.dropout = args.dropout
@@ -40,11 +44,17 @@ def __init__(self, args, embed_tokens):
self.padding_idx = embed_tokens.padding_idx
embed_dim = embed_tokens.embedding_dim
self.embed_scale = math.sqrt(embed_dim)
- self.embed_positions = PositionalEmbedding(
- args.max_source_positions, embed_dim, self.padding_idx,
- learned=args.encoder_learned_pos,
- ) if not args.no_token_positional_embeddings else None
- if getattr(args, 'layernorm_embedding', False):
+ self.embed_positions = (
+ PositionalEmbedding(
+ args.max_source_positions,
+ embed_dim,
+ self.padding_idx,
+ learned=args.encoder_learned_pos,
+ )
+ if not args.no_token_positional_embeddings
+ else None
+ )
+ if getattr(args, "layernorm_embedding", False):
self.layernorm_embedding = LayerNorm(embed_dim)
else:
self.layernorm_embedding = None
@@ -77,9 +87,10 @@ def forward(self, input):
class TransformerEncoderLayerNorm(nn.Module):
"""
- Layer norm at the the end of all encoder layers if
- args.encoder_enormalize_before = True
+ Layer norm at the the end of all encoder layers if
+ args.encoder_enormalize_before = True
"""
+
def __init__(self, args, embed_dim):
super().__init__()
if args.encoder_normalize_before:
@@ -99,30 +110,45 @@ def forward(self, input):
class TransformerDecoderEmbedding(nn.Module):
""" Decoder Embedding + Positional Embedding """
+
def __init__(self, args, embed_tokens):
super().__init__()
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
- input_embed_dim = sum(e.embedding_dim for e in embed_tokens) \
- if isinstance(embed_tokens, nn.ModuleList) \
+ input_embed_dim = (
+ sum(e.embedding_dim for e in embed_tokens)
+ if isinstance(embed_tokens, nn.ModuleList)
else embed_tokens.embedding_dim
+ )
embed_dim = args.decoder_embed_dim
self.output_embed_dim = args.decoder_output_dim
- padding_idx = embed_tokens[0].padding_idx \
- if isinstance(embed_tokens, nn.ModuleList) \
+ padding_idx = (
+ embed_tokens[0].padding_idx
+ if isinstance(embed_tokens, nn.ModuleList)
else embed_tokens.padding_idx
+ )
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim
- self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None
+ self.project_in_dim = (
+ Linear(input_embed_dim, embed_dim, bias=False)
+ if embed_dim != input_embed_dim
+ else None
+ )
- self.embed_positions = PositionalEmbedding(
- args.max_target_positions, embed_dim, padding_idx,
- learned=args.decoder_learned_pos,
- ) if not args.no_token_positional_embeddings else None
+ self.embed_positions = (
+ PositionalEmbedding(
+ args.max_target_positions,
+ embed_dim,
+ padding_idx,
+ learned=args.decoder_learned_pos,
+ )
+ if not args.no_token_positional_embeddings
+ else None
+ )
def forward(self, input):
mt_task = False
@@ -147,10 +173,14 @@ def forward(self, input):
encoder_padding_mask = None
incremental_state = None
- positions = self.embed_positions(
- prev_output_tokens,
- incremental_state=incremental_state,
- ) if self.embed_positions is not None else None
+ positions = (
+ self.embed_positions(
+ prev_output_tokens,
+ incremental_state=incremental_state,
+ )
+ if self.embed_positions is not None
+ else None
+ )
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
@@ -190,8 +220,11 @@ def __init__(self, args, embed_tokens, dictionary):
self.output_embed_dim = args.decoder_output_dim
embed_dim = args.decoder_embed_dim
- self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False) \
- if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None
+ self.project_out_dim = (
+ Linear(embed_dim, self.output_embed_dim, bias=False)
+ if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
+ else None
+ )
self.adaptive_softmax = None
if args.adaptive_softmax_cutoff is not None:
assert not isinstance(embed_tokens, nn.ModuleList)
@@ -205,10 +238,16 @@ def __init__(self, args, embed_tokens, dictionary):
tie_proj=args.tie_adaptive_proj,
)
elif not self.share_input_output_embed:
- self.embed_tokens = nn.Parameter(torch.Tensor(len(dictionary), self.output_embed_dim))
- nn.init.normal_(self.embed_tokens, mean=0, std=self.output_embed_dim ** -0.5)
+ self.embed_tokens = nn.Parameter(
+ torch.Tensor(len(dictionary), self.output_embed_dim)
+ )
+ nn.init.normal_(
+ self.embed_tokens, mean=0, std=self.output_embed_dim ** -0.5
+ )
- if args.decoder_normalize_before and not getattr(args, 'no_decoder_final_norm', False):
+ if args.decoder_normalize_before and not getattr(
+ args, "no_decoder_final_norm", False
+ ):
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
@@ -245,7 +284,7 @@ def output_layer(self, features, **kwargs):
output = F.linear(features[:, :, sidx:eidx], emb.weight)
else:
output += F.linear(features[:, :, sidx:eidx], emb.weight)
-
+
return output
else:
return F.linear(features, self.embed_tokens.weight)
@@ -273,18 +312,20 @@ def __init__(self, args):
super().__init__()
self.embed_dim = args.encoder_embed_dim
self.self_attn = MultiheadAttention(
- self.embed_dim, args.encoder_attention_heads,
- dropout=args.attention_dropout, self_attention=True
+ self.embed_dim,
+ args.encoder_attention_heads,
+ dropout=args.attention_dropout,
+ self_attention=True,
)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout = args.dropout
self.activation_fn = utils.get_activation_fn(
- activation=getattr(args, 'activation_fn', 'relu')
+ activation=getattr(args, "activation_fn", "relu")
)
- self.activation_dropout = getattr(args, 'activation_dropout', 0)
+ self.activation_dropout = getattr(args, "activation_dropout", 0)
if self.activation_dropout == 0:
# for backwards compatibility with models that use args.relu_dropout
- self.activation_dropout = getattr(args, 'relu_dropout', 0)
+ self.activation_dropout = getattr(args, "relu_dropout", 0)
self.normalize_before = args.encoder_normalize_before
self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
@@ -296,17 +337,12 @@ def upgrade_state_dict_named(self, state_dict, name):
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
`...final_layer_norm.weight`
"""
- layer_norm_map = {
- '0': 'self_attn_layer_norm',
- '1': 'final_layer_norm'
- }
+ layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
for old, new in layer_norm_map.items():
- for m in ('weight', 'bias'):
- k = '{}.layer_norms.{}.{}'.format(name, old, m)
+ for m in ("weight", "bias"):
+ k = "{}.layer_norms.{}.{}".format(name, old, m)
if k in state_dict:
- state_dict[
- '{}.{}.{}'.format(name, new, m)
- ] = state_dict[k]
+ state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
del state_dict[k]
def forward(self, input):
@@ -330,7 +366,9 @@ def forward(self, input):
prev_output_tokens = input[2]
residual = x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
- x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask)
+ x, _ = self.self_attn(
+ query=x, key=x, value=x, key_padding_mask=encoder_padding_mask
+ )
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
@@ -370,7 +408,9 @@ class TransformerDecoderLayer(nn.Module):
(default: False).
"""
- def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
+ def __init__(
+ self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
+ ):
super().__init__()
self.embed_dim = args.decoder_embed_dim
self.self_attn = MultiheadAttention(
@@ -379,22 +419,22 @@ def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn
dropout=args.attention_dropout,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
- self_attention=True
+ self_attention=True,
)
self.dropout = args.dropout
self.activation_fn = utils.get_activation_fn(
- activation=getattr(args, 'activation_fn', 'relu')
+ activation=getattr(args, "activation_fn", "relu")
)
- self.activation_dropout = getattr(args, 'activation_dropout', 0)
+ self.activation_dropout = getattr(args, "activation_dropout", 0)
if self.activation_dropout == 0:
# for backwards compatibility with models that use args.relu_dropout
- self.activation_dropout = getattr(args, 'relu_dropout', 0)
+ self.activation_dropout = getattr(args, "relu_dropout", 0)
self.normalize_before = args.decoder_normalize_before
# use layerNorm rather than FusedLayerNorm for exporting.
# char_inputs can be used to determint this.
# TODO remove this once we update apex with the fix
- export = getattr(args, 'char_inputs', False)
+ export = getattr(args, "char_inputs", False)
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
if no_encoder_attn:
@@ -404,8 +444,8 @@ def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn
self.encoder_attn = MultiheadAttention(
self.embed_dim,
args.decoder_attention_heads,
- kdim=getattr(args, 'encoder_embed_dim', None),
- vdim=getattr(args, 'encoder_embed_dim', None),
+ kdim=getattr(args, "encoder_embed_dim", None),
+ vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
@@ -520,10 +560,18 @@ def forward(self, input):
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
- if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
- self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
+ if (
+ not hasattr(self, "_future_mask")
+ or self._future_mask is None
+ or self._future_mask.device != tensor.device
+ ):
+ self._future_mask = torch.triu(
+ utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
+ )
if self._future_mask.size(0) < dim:
- self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
+ self._future_mask = torch.triu(
+ utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
+ )
return self._future_mask[:dim, :dim]
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
@@ -548,5 +596,5 @@ def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
if bias:
- nn.init.constant_(m.bias, 0.)
+ nn.init.constant_(m.bias, 0.0)
return m
diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py
index 65a087a3fb..cbfc6ae4a0 100644
--- a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py
+++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py
@@ -5,7 +5,19 @@
import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
from fairseq import utils
+from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import (
+ Embedding,
+ TransformerDecoderEmbedding,
+ TransformerDecoderLayer,
+ TransformerDecoderOutputLayer,
+ TransformerEncoderEmbedding,
+ TransformerEncoderLayer,
+ TransformerEncoderLayerNorm,
+)
from fairseq.models import (
BaseFairseqModel,
FairseqDecoder,
@@ -13,25 +25,14 @@
register_model,
register_model_architecture,
)
+from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import (
base_architecture,
transformer_iwslt_de_en,
transformer_wmt_en_de_big,
)
from fairseq.modules import SinusoidalPositionalEmbedding
-from fairseq.models.fairseq_encoder import EncoderOut
-from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import (
- Embedding,
- TransformerEncoderLayer,
- TransformerDecoderLayer,
- TransformerEncoderEmbedding,
- TransformerEncoderLayerNorm,
- TransformerDecoderEmbedding,
- TransformerDecoderOutputLayer,
-)
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
+
logger = logging.getLogger(__name__)
@@ -40,24 +41,27 @@
DEFAULT_MAX_TARGET_POSITIONS = 1024
-@register_model('pipeline_parallel_transformer')
+@register_model("pipeline_parallel_transformer")
class PipelineParallelTransformerModel(BaseFairseqModel):
def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint):
try:
from fairscale.nn import Pipe
except ImportError:
- raise ImportError('Please install fairscale with: pip install fairscale')
+ raise ImportError("Please install fairscale with: pip install fairscale")
super().__init__()
assert isinstance(encoder, FairseqEncoder)
assert isinstance(decoder, FairseqDecoder)
- encoder_module_list = \
- [encoder.embedding_layer] + \
- list(encoder.encoder_layers) + \
- [encoder.final_layer_norm]
+ encoder_module_list = (
+ [encoder.embedding_layer]
+ + list(encoder.encoder_layers)
+ + [encoder.final_layer_norm]
+ )
self.num_encoder_modules = len(encoder_module_list)
- decoder_module_list = [decoder.embedding_layer] + \
- list(decoder.decoder_layers) + \
- [decoder.decoder_output_layer]
+ decoder_module_list = (
+ [decoder.embedding_layer]
+ + list(decoder.decoder_layers)
+ + [decoder.decoder_output_layer]
+ )
self.num_decoder_modules = len(decoder_module_list)
module_list = encoder_module_list + decoder_module_list
self.devices = devices
@@ -69,14 +73,12 @@ def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint):
checkpoint=checkpoint,
)
self.encoder_max_positions = self.max_positions_helper(
- encoder.embedding_layer,
- 'max_source_positions'
+ encoder.embedding_layer, "max_source_positions"
)
self.decoder_max_positions = self.max_positions_helper(
- decoder.embedding_layer,
- 'max_target_positions'
+ decoder.embedding_layer, "max_target_positions"
)
- self.adaptive_softmax = getattr(decoder, 'adaptive_softmax', None)
+ self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None)
# Note: To be populated during inference
self.encoder = None
self.decoder = None
@@ -87,9 +89,10 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens):
input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst)
return self.model(input)
else:
- assert self.encoder is not None and self.decoder is not None, \
- "encoder and decoder need to be initialized by " + \
- "calling the `prepare_for_inference_()` method"
+ assert self.encoder is not None and self.decoder is not None, (
+ "encoder and decoder need to be initialized by "
+ + "calling the `prepare_for_inference_()` method"
+ )
encoder_output_tuple = self.encoder(input)
return self.decoder(encoder_output_tuple)
@@ -109,7 +112,9 @@ def prepare_for_inference_(self, args):
module_count += 1
self.model = None
self.encoder = TransformerEncoder(args, None, None, encoder_module_list)
- self.decoder = TransformerDecoder(args, None, None, decoder_module_list=decoder_module_list)
+ self.decoder = TransformerDecoder(
+ args, None, None, decoder_module_list=decoder_module_list
+ )
@staticmethod
def add_args(parser):
@@ -178,20 +183,22 @@ def build_model_base(cls, args, task):
# make sure all arguments are present in older models
base_architecture(args)
- if not hasattr(args, 'max_source_positions'):
+ if not hasattr(args, "max_source_positions"):
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
- if not hasattr(args, 'max_target_positions'):
+ if not hasattr(args, "max_target_positions"):
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1):
- assert embed_dim % num_embed_chunks == 0, \
- f"Number of embedding chunks = {num_embed_chunks} should be " + \
- f"divisible by the embedding dimension = {embed_dim}"
- assert path is None or num_embed_chunks == 1, \
- "Loading embedding from a path with number of embedding chunks > 1" + \
- " is not yet supported"
+ assert embed_dim % num_embed_chunks == 0, (
+ f"Number of embedding chunks = {num_embed_chunks} should be "
+ + f"divisible by the embedding dimension = {embed_dim}"
+ )
+ assert path is None or num_embed_chunks == 1, (
+ "Loading embedding from a path with number of embedding chunks > 1"
+ + " is not yet supported"
+ )
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
# if provided, load from preloaded dictionaries
@@ -205,30 +212,45 @@ def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1):
for i in range(num_embed_chunks):
emb.append(Embedding(num_embeddings, embed_chunk_dim, padding_idx))
return emb
+
num_embed_chunks = args.num_embedding_chunks
if args.share_all_embeddings:
if src_dict != tgt_dict:
- raise ValueError('--share-all-embeddings requires a joined dictionary')
+ raise ValueError("--share-all-embeddings requires a joined dictionary")
if args.encoder_embed_dim != args.decoder_embed_dim:
raise ValueError(
- '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
+ "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
+ )
if args.decoder_embed_path and (
- args.decoder_embed_path != args.encoder_embed_path):
- raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path')
+ args.decoder_embed_path != args.encoder_embed_path
+ ):
+ raise ValueError(
+ "--share-all-embeddings not compatible with --decoder-embed-path"
+ )
encoder_embed_tokens = build_embedding(
- src_dict, args.encoder_embed_dim, args.encoder_embed_path, num_embed_chunks,
+ src_dict,
+ args.encoder_embed_dim,
+ args.encoder_embed_path,
+ num_embed_chunks,
)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
- assert args.share_decoder_input_output_embed or num_embed_chunks == 1, \
- "Not sharing decoder I/O embeddings is not yet supported with number of " + \
- "embedding chunks > 1"
+ assert args.share_decoder_input_output_embed or num_embed_chunks == 1, (
+ "Not sharing decoder I/O embeddings is not yet supported with number of "
+ + "embedding chunks > 1"
+ )
encoder_embed_tokens = build_embedding(
- src_dict, args.encoder_embed_dim, args.encoder_embed_path, num_embed_chunks,
+ src_dict,
+ args.encoder_embed_dim,
+ args.encoder_embed_path,
+ num_embed_chunks,
)
decoder_embed_tokens = build_embedding(
- tgt_dict, args.decoder_embed_dim, args.decoder_embed_path, num_embed_chunks,
+ tgt_dict,
+ args.decoder_embed_dim,
+ args.decoder_embed_path,
+ num_embed_chunks,
)
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
@@ -263,21 +285,24 @@ def max_positions(self):
"""Maximum length supported by the model."""
return (self.encoder_max_positions, self.decoder_max_positions)
- def max_positions_helper(self, embedding_layer,
- max_positions_field='max_source_positions'):
+ def max_positions_helper(
+ self, embedding_layer, max_positions_field="max_source_positions"
+ ):
"""Maximum input length supported by the encoder or decoder."""
if embedding_layer.embed_positions is None:
return getattr(embedding_layer, max_positions_field)
- return min(getattr(embedding_layer, max_positions_field),
- embedding_layer.embed_positions.max_positions)
+ return min(
+ getattr(embedding_layer, max_positions_field),
+ embedding_layer.embed_positions.max_positions,
+ )
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
- if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
+ if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
if sample is not None:
- assert 'target' in sample
- target = sample['target']
+ assert "target" in sample
+ target = sample["target"]
else:
target = None
out = self.adaptive_softmax.get_log_prob(net_output, target=target)
@@ -303,7 +328,7 @@ def load_state_dict(self, state_dict, strict=True, args=None):
this additionally "upgrades" *state_dicts* from old checkpoints.
"""
self.upgrade_state_dict(state_dict)
- is_regular_transformer = not any('model.partitions' in k for k in state_dict)
+ is_regular_transformer = not any("model.partitions" in k for k in state_dict)
if is_regular_transformer:
state_dict = self.convert_to_pipeline_parallel_state_dict(state_dict)
return super().load_state_dict(state_dict, strict)
@@ -313,27 +338,50 @@ def convert_to_pipeline_parallel_state_dict(self, state_dict):
encoder_layer_idx = 0
decoder_layer_idx = 0
encoder_key_suffixes = [
- 'self_attn.k_proj.weight', 'self_attn.k_proj.bias',
- 'self_attn.v_proj.weight', 'self_attn.v_proj.bias',
- 'self_attn.q_proj.weight', 'self_attn.q_proj.bias',
- 'self_attn.out_proj.weight', 'self_attn.out_proj.bias',
- 'self_attn_layer_norm.weight', 'self_attn_layer_norm.bias', 'fc1.weight',
- 'fc1.bias', 'fc2.weight', 'fc2.bias', 'final_layer_norm.weight',
- 'final_layer_norm.bias',
+ "self_attn.k_proj.weight",
+ "self_attn.k_proj.bias",
+ "self_attn.v_proj.weight",
+ "self_attn.v_proj.bias",
+ "self_attn.q_proj.weight",
+ "self_attn.q_proj.bias",
+ "self_attn.out_proj.weight",
+ "self_attn.out_proj.bias",
+ "self_attn_layer_norm.weight",
+ "self_attn_layer_norm.bias",
+ "fc1.weight",
+ "fc1.bias",
+ "fc2.weight",
+ "fc2.bias",
+ "final_layer_norm.weight",
+ "final_layer_norm.bias",
]
decoder_key_suffixes = [
- 'self_attn.k_proj.weight', 'self_attn.k_proj.bias',
- 'self_attn.v_proj.weight', 'self_attn.v_proj.bias',
- 'self_attn.q_proj.weight', 'self_attn.q_proj.bias',
- 'self_attn.out_proj.weight', 'self_attn.out_proj.bias',
- 'self_attn_layer_norm.weight', 'self_attn_layer_norm.bias',
- 'encoder_attn.k_proj.weight', 'encoder_attn.k_proj.bias',
- 'encoder_attn.v_proj.weight', 'encoder_attn.v_proj.bias',
- 'encoder_attn.q_proj.weight', 'encoder_attn.q_proj.bias',
- 'encoder_attn.out_proj.weight', 'encoder_attn.out_proj.bias',
- 'encoder_attn_layer_norm.weight', 'encoder_attn_layer_norm.bias',
- 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias',
- 'final_layer_norm.weight', 'final_layer_norm.bias'
+ "self_attn.k_proj.weight",
+ "self_attn.k_proj.bias",
+ "self_attn.v_proj.weight",
+ "self_attn.v_proj.bias",
+ "self_attn.q_proj.weight",
+ "self_attn.q_proj.bias",
+ "self_attn.out_proj.weight",
+ "self_attn.out_proj.bias",
+ "self_attn_layer_norm.weight",
+ "self_attn_layer_norm.bias",
+ "encoder_attn.k_proj.weight",
+ "encoder_attn.k_proj.bias",
+ "encoder_attn.v_proj.weight",
+ "encoder_attn.v_proj.bias",
+ "encoder_attn.q_proj.weight",
+ "encoder_attn.q_proj.bias",
+ "encoder_attn.out_proj.weight",
+ "encoder_attn.out_proj.bias",
+ "encoder_attn_layer_norm.weight",
+ "encoder_attn_layer_norm.bias",
+ "fc1.weight",
+ "fc1.bias",
+ "fc2.weight",
+ "fc2.bias",
+ "final_layer_norm.weight",
+ "final_layer_norm.bias",
]
for pid, partition in enumerate(self.model.partitions):
logger.info(f"Begin Partition {pid}")
@@ -376,29 +424,32 @@ class TransformerEncoder(FairseqEncoder):
def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None):
super().__init__(dictionary)
- self.register_buffer('version', torch.Tensor([3]))
+ self.register_buffer("version", torch.Tensor([3]))
try:
from fairscale.nn import Pipe
except ImportError:
- raise ImportError('Please install fairscale with: pip install fairscale')
+ raise ImportError("Please install fairscale with: pip install fairscale")
if encoder_module_list is None:
embedding_layer = TransformerEncoderEmbedding(args, embed_tokens)
- layers = [
- TransformerEncoderLayer(args) for i in range(args.encoder_layers)
- ]
+ layers = [TransformerEncoderLayer(args) for i in range(args.encoder_layers)]
if isinstance(embed_tokens, nn.ModuleList):
emb_dim = sum(e.embedding_dim for e in embed_tokens)
else:
emb_dim = embed_tokens.embedding_dim
final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim)
encoder_module_list = [embedding_layer] + layers + [final_layer_norm]
- self.use_pipeline = (getattr(args, "pipeline_encoder_balance", None) is not None)
+ self.use_pipeline = getattr(args, "pipeline_encoder_balance", None) is not None
if self.use_pipeline:
- encoder_balance = utils.eval_str_list(args.pipeline_encoder_balance, type=int)
- encoder_devices = utils.eval_str_list(args.pipeline_encoder_devices, type=int)
- assert sum(encoder_balance) == len(encoder_module_list), \
- f"Sum of encoder_balance={encoder_balance} is not equal " + \
- f"to num_encoder_modules={len(encoder_module_list)}"
+ encoder_balance = utils.eval_str_list(
+ args.pipeline_encoder_balance, type=int
+ )
+ encoder_devices = utils.eval_str_list(
+ args.pipeline_encoder_devices, type=int
+ )
+ assert sum(encoder_balance) == len(encoder_module_list), (
+ f"Sum of encoder_balance={encoder_balance} is not equal "
+ + f"to num_encoder_modules={len(encoder_module_list)}"
+ )
self.model = Pipe(
module=nn.Sequential(*encoder_module_list),
balance=encoder_balance,
@@ -433,7 +484,9 @@ def forward(self, src_tokens, src_lengths):
Only populated if *return_all_hiddens* is True.
)
"""
- dummy_prev_output_tokens = torch.zeros(1, dtype=src_tokens.dtype, device=src_tokens.device)
+ dummy_prev_output_tokens = torch.zeros(
+ 1, dtype=src_tokens.dtype, device=src_tokens.device
+ )
input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens)
if self.use_pipeline:
input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
@@ -465,11 +518,15 @@ def reorder_encoder_out(self, encoder_out, new_order):
)
if encoder_out.encoder_padding_mask is not None:
encoder_out = encoder_out._replace(
- encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(0, new_order)
+ encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(
+ 0, new_order
+ )
)
if encoder_out.encoder_embedding is not None:
encoder_out = encoder_out._replace(
- encoder_embedding=encoder_out.encoder_embedding.index_select(0, new_order)
+ encoder_embedding=encoder_out.encoder_embedding.index_select(
+ 0, new_order
+ )
)
if encoder_out.encoder_states is not None:
for idx, state in enumerate(encoder_out.encoder_states):
@@ -480,8 +537,10 @@ def max_positions(self):
"""Maximum input length supported by the encoder."""
if self.embedding_layer.embed_positions is None:
return self.embedding_layer.max_source_positions
- return min(self.embedding_layer.max_source_positions,
- self.embedding_layer.embed_positions.max_positions)
+ return min(
+ self.embedding_layer.max_source_positions,
+ self.embedding_layer.embed_positions.max_positions,
+ )
class TransformerDecoder(FairseqDecoder):
@@ -497,28 +556,42 @@ class TransformerDecoder(FairseqDecoder):
(default: False).
"""
- def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, decoder_module_list=None):
+ def __init__(
+ self,
+ args,
+ dictionary,
+ embed_tokens,
+ no_encoder_attn=False,
+ decoder_module_list=None,
+ ):
super().__init__(dictionary)
- self.register_buffer('version', torch.Tensor([3]))
+ self.register_buffer("version", torch.Tensor([3]))
try:
from fairscale.nn import Pipe
except ImportError:
- raise ImportError('Please install fairscale with: pip install fairscale')
+ raise ImportError("Please install fairscale with: pip install fairscale")
if decoder_module_list is None:
embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
layers = [
TransformerDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
]
- decoder_output_layer = TransformerDecoderOutputLayer(args, embed_tokens, dictionary)
+ decoder_output_layer = TransformerDecoderOutputLayer(
+ args, embed_tokens, dictionary
+ )
decoder_module_list = [embedding_layer] + layers + [decoder_output_layer]
- self.use_pipeline = (getattr(args, "pipeline_decoder_balance", None) is not None)
+ self.use_pipeline = getattr(args, "pipeline_decoder_balance", None) is not None
if self.use_pipeline:
- decoder_balance = utils.eval_str_list(args.pipeline_decoder_balance, type=int)
- decoder_devices = utils.eval_str_list(args.pipeline_decoder_devices, type=int)
- assert sum(decoder_balance) == len(decoder_module_list), \
- f"Sum of decoder_balance={decoder_balance} is not equal " + \
- f"to num_decoder_modules={len(decoder_module_list)}"
+ decoder_balance = utils.eval_str_list(
+ args.pipeline_decoder_balance, type=int
+ )
+ decoder_devices = utils.eval_str_list(
+ args.pipeline_decoder_devices, type=int
+ )
+ assert sum(decoder_balance) == len(decoder_module_list), (
+ f"Sum of decoder_balance={decoder_balance} is not equal "
+ + f"to num_decoder_modules={len(decoder_module_list)}"
+ )
self.model = Pipe(
module=nn.Sequential(*decoder_module_list),
balance=decoder_balance,
@@ -531,7 +604,11 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, decode
self.decoder_layers = nn.Sequential(*decoder_module_list[1:-1])
self.decoder_output_layer = decoder_module_list[-1]
- def forward(self, prev_output_tokens, encoder_out=None,):
+ def forward(
+ self,
+ prev_output_tokens,
+ encoder_out=None,
+ ):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
@@ -548,14 +625,18 @@ def forward(self, prev_output_tokens, encoder_out=None,):
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
- input_tuple = (encoder_out.encoder_out, encoder_out.encoder_padding_mask, prev_output_tokens)
+ input_tuple = (
+ encoder_out.encoder_out,
+ encoder_out.encoder_padding_mask,
+ prev_output_tokens,
+ )
if self.use_pipeline:
input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
- return (self.model(input_tuple), )
+ return (self.model(input_tuple),)
else:
embed_layer_output = self.embedding_layer(input_tuple)
state = self.decoder_layers(embed_layer_output)
- return (self.decoder_output_layer(state), )
+ return (self.decoder_output_layer(state),)
def output_layer(self, features, **kwargs):
"""Project features to the vocabulary size."""
@@ -572,43 +653,51 @@ def max_positions(self):
"""Maximum output length supported by the decoder."""
if self.embedding_layer.embed_positions is None:
return self.embedding_layer.max_target_positions
- return min(self.embedding_layer.max_target_positions,
- self.embedding_layer.embed_positions.max_positions)
+ return min(
+ self.embedding_layer.max_target_positions,
+ self.embedding_layer.embed_positions.max_positions,
+ )
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if (
- not hasattr(self, '_future_mask')
+ not hasattr(self, "_future_mask")
or self._future_mask is None
or self._future_mask.device != tensor.device
or self._future_mask.size(0) < dim
):
- self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
+ self._future_mask = torch.triu(
+ utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
+ )
return self._future_mask[:dim, :dim]
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
- weights_key = '{}.embed_positions.weights'.format(name)
+ weights_key = "{}.embed_positions.weights".format(name)
if weights_key in state_dict:
del state_dict[weights_key]
- state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
+ state_dict[
+ "{}.embed_positions._float_tensor".format(name)
+ ] = torch.FloatTensor(1)
for i in range(len(self.layers)):
# update layer norms
layer_norm_map = {
- '0': 'self_attn_layer_norm',
- '1': 'encoder_attn_layer_norm',
- '2': 'final_layer_norm'
+ "0": "self_attn_layer_norm",
+ "1": "encoder_attn_layer_norm",
+ "2": "final_layer_norm",
}
for old, new in layer_norm_map.items():
- for m in ('weight', 'bias'):
- k = '{}.layers.{}.layer_norms.{}.{}'.format(name, i, old, m)
+ for m in ("weight", "bias"):
+ k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
if k in state_dict:
- state_dict['{}.layers.{}.{}.{}'.format(name, i, new, m)] = state_dict[k]
+ state_dict[
+ "{}.layers.{}.{}.{}".format(name, i, new, m)
+ ] = state_dict[k]
del state_dict[k]
- version_key = '{}.version'.format(name)
+ version_key = "{}.version".format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
@@ -618,13 +707,15 @@ def upgrade_state_dict_named(self, state_dict, name):
return state_dict
-@register_model_architecture('pipeline_parallel_transformer',
- 'transformer_iwslt_de_en_pipeline_parallel')
+@register_model_architecture(
+ "pipeline_parallel_transformer", "transformer_iwslt_de_en_pipeline_parallel"
+)
def transformer_iwslt_de_en_dist(args):
transformer_iwslt_de_en(args)
-@register_model_architecture('pipeline_parallel_transformer',
- 'transformer_wmt_en_de_big_pipeline_parallel')
+@register_model_architecture(
+ "pipeline_parallel_transformer", "transformer_wmt_en_de_big_pipeline_parallel"
+)
def transformer_wmt_en_de_big_dist(args):
transformer_wmt_en_de_big(args)
diff --git a/fairseq/model_parallel/models/roberta/model.py b/fairseq/model_parallel/models/roberta/model.py
index ed49fbb338..68ad88d2a5 100644
--- a/fairseq/model_parallel/models/roberta/model.py
+++ b/fairseq/model_parallel/models/roberta/model.py
@@ -11,27 +11,19 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq import utils
-from fairseq.models import (
- FairseqEncoder,
- register_model,
- register_model_architecture,
-)
+from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoder
+from fairseq.models import FairseqEncoder, register_model, register_model_architecture
from fairseq.models.roberta import (
- RobertaModel,
+ RobertaClassificationHead,
RobertaEncoder,
RobertaLMHead,
- RobertaClassificationHead,
-)
-from fairseq.modules import (
- LayerNorm,
- TransformerSentenceEncoder,
-)
-from fairseq.model_parallel.modules import (
- ModelParallelTransformerSentenceEncoder,
+ RobertaModel,
)
+from fairseq.modules import LayerNorm, TransformerSentenceEncoder
from fairseq.modules.transformer_sentence_encoder import init_bert_params
+
+
try:
from fairseq.model_parallel.megatron.mpu import (
copy_to_model_parallel_region,
@@ -39,6 +31,7 @@
ColumnParallelLinear,
RowParallelLinear,
)
+
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
@@ -46,10 +39,8 @@
logger = logging.getLogger(__name__)
-@register_model('model_parallel_roberta')
+@register_model("model_parallel_roberta")
class ModelParallelRobertaModel(RobertaModel):
-
-
def __init__(self, args, encoder):
super().__init__(args, encoder)
@@ -69,18 +60,25 @@ def build_model(cls, args, task):
task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
- if not hasattr(args, 'max_positions'):
+ if not hasattr(args, "max_positions"):
args.max_positions = args.tokens_per_sample
- if getattr(args, 'untie_weights_roberta', False):
+ if getattr(args, "untie_weights_roberta", False):
raise NotImplementedError(
- '--untie-weights-roberta is not supported in model parallel mode'
+ "--untie-weights-roberta is not supported in model parallel mode"
)
encoder = ModelParallelRobertaEncoder(args, task.source_dictionary)
return cls(args, encoder)
- def forward(self, src_tokens, features_only=False, return_all_hiddens=False, classification_head_name=None, **kwargs):
+ def forward(
+ self,
+ src_tokens,
+ features_only=False,
+ return_all_hiddens=False,
+ classification_head_name=None,
+ **kwargs
+ ):
if classification_head_name is not None:
features_only = True
@@ -90,7 +88,9 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, cla
x = self.classification_heads[classification_head_name](x)
return x, extra
- def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
+ def register_classification_head(
+ self, name, num_classes=None, inner_dim=None, **kwargs
+ ):
"""Register a classification head."""
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
@@ -98,7 +98,7 @@ def register_classification_head(self, name, num_classes=None, inner_dim=None, *
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
- 'and inner_dim {} (prev: {})'.format(
+ "and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
@@ -146,7 +146,9 @@ def forward(self, features, masked_tokens=None, **kwargs):
class ModelParallelRobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
- def __init__(self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout):
+ def __init__(
+ self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout
+ ):
super().__init__()
self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True)
self.activation_fn = utils.get_activation_fn(activation_fn)
@@ -206,7 +208,14 @@ def __init__(self, args, dictionary):
weight=self.sentence_encoder.embed_tokens.weight,
)
- def forward(self, src_tokens, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused):
+ def forward(
+ self,
+ src_tokens,
+ features_only=False,
+ return_all_hiddens=False,
+ masked_tokens=None,
+ **unused
+ ):
"""
Args:
src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
@@ -223,7 +232,9 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, mas
is a list of hidden states. Note that the hidden
states have shape `(src_len, batch, vocab)`.
"""
- x, extra = self.extract_features(src_tokens, return_all_hiddens=return_all_hiddens)
+ x, extra = self.extract_features(
+ src_tokens, return_all_hiddens=return_all_hiddens
+ )
if not features_only:
x = self.output_layer(x, masked_tokens=masked_tokens)
return x, extra
@@ -234,7 +245,7 @@ def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
last_state_only=not return_all_hiddens,
)
features = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C
- return features, {'inner_states': inner_states if return_all_hiddens else None}
+ return features, {"inner_states": inner_states if return_all_hiddens else None}
def output_layer(self, features, masked_tokens=None, **unused):
return self.lm_head(features, masked_tokens)
@@ -244,33 +255,33 @@ def max_positions(self):
return self.args.max_positions
-@register_model_architecture('model_parallel_roberta', 'model_parallel_roberta')
+@register_model_architecture("model_parallel_roberta", "model_parallel_roberta")
def base_architecture(args):
- args.encoder_layers = getattr(args, 'encoder_layers', 12)
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12)
+ args.encoder_layers = getattr(args, "encoder_layers", 12)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
- args.activation_fn = getattr(args, 'activation_fn', 'gelu')
- args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
+ args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
- args.dropout = getattr(args, 'dropout', 0.1)
- args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
- args.activation_dropout = getattr(args, 'activation_dropout', 0.0)
- args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)
- args.encoder_layers_to_keep = getattr(args, 'encoder_layers_to_keep', None)
- args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0)
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+ args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
+ args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
-@register_model_architecture('model_parallel_roberta', 'model_parallel_roberta_base')
+@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_base")
def roberta_base_architecture(args):
base_architecture(args)
-@register_model_architecture('model_parallel_roberta', 'model_parallel_roberta_large')
+@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_large")
def roberta_large_architecture(args):
- args.encoder_layers = getattr(args, 'encoder_layers', 24)
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
+ args.encoder_layers = getattr(args, "encoder_layers", 24)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
base_architecture(args)
diff --git a/fairseq/model_parallel/models/transformer.py b/fairseq/model_parallel/models/transformer.py
index 3ba539319f..4f34645226 100644
--- a/fairseq/model_parallel/models/transformer.py
+++ b/fairseq/model_parallel/models/transformer.py
@@ -7,21 +7,17 @@
import torch.nn as nn
import torch.nn.functional as F
-
-from fairseq.models import (
- register_model,
+from fairseq.model_parallel.modules import (
+ ModelParallelTransformerDecoderLayer,
+ ModelParallelTransformerEncoderLayer,
)
-
+from fairseq.models import register_model
from fairseq.models.transformer import (
TransformerDecoder,
TransformerEncoder,
TransformerModel,
)
-from fairseq.model_parallel.modules import (
- ModelParallelTransformerDecoderLayer,
- ModelParallelTransformerEncoderLayer,
-)
try:
from fairseq.model_parallel.megatron.mpu import (
@@ -29,6 +25,7 @@
gather_from_model_parallel_region,
VocabParallelEmbedding,
)
+
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
@@ -37,18 +34,19 @@
logger = logging.getLogger(__name__)
-@register_model('model_parallel_transformer')
+@register_model("model_parallel_transformer")
class ModelParallelTransformerModel(TransformerModel):
"""
Model parallel Transformer model.
"""
+
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
if not has_megatron_submodule:
raise ImportError(
- '\n\nPlease install the megatron submodule:'
- '\n\n git submodule update --init '
- 'fairseq/model_parallel/megatron'
+ "\n\nPlease install the megatron submodule:"
+ "\n\n git submodule update --init "
+ "fairseq/model_parallel/megatron"
)
dictionary.pad_to_multiple_(args.model_parallel_size * 8)
num_embeddings = len(dictionary)
@@ -57,10 +55,15 @@ def build_embedding(cls, args, dictionary, embed_dim, path=None):
def _vocab_init(tensor, **kwargs):
nn.init.normal_(tensor, mean=0, std=num_embeddings ** -0.5)
nn.init.constant_(tensor[1], 0)
- emb = VocabParallelEmbedding(num_embeddings, embed_dim, padding_idx, init_method=_vocab_init)
+
+ emb = VocabParallelEmbedding(
+ num_embeddings, embed_dim, padding_idx, init_method=_vocab_init
+ )
# if provided, load from preloaded dictionaries
if path:
- raise NotImplementedError("Loading of embedding from path is not supported for model parallel")
+ raise NotImplementedError(
+ "Loading of embedding from path is not supported for model parallel"
+ )
return emb
@classmethod
@@ -73,7 +76,7 @@ def build_decoder(cls, args, tgt_dict, embed_tokens):
args,
tgt_dict,
embed_tokens,
- no_encoder_attn=getattr(args, 'no_cross_attention', False),
+ no_encoder_attn=getattr(args, "no_cross_attention", False),
)
@@ -100,7 +103,7 @@ def output_layer(self, features, **kwargs):
"""Project features to the vocabulary size."""
if not self.share_input_output_embed:
raise NotImplementedError(
- 'Model parallel training currently requires --share-decoder-input-output-embed'
+ "Model parallel training currently requires --share-decoder-input-output-embed"
)
features = copy_to_model_parallel_region(features)
@@ -108,6 +111,6 @@ def output_layer(self, features, **kwargs):
# project back to size of vocabulary
x = self.output_projection(features)
- if getattr(self.args, 'criterion') != 'vocab_parallel_cross_entropy':
+ if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy":
x = gather_from_model_parallel_region(x).contiguous()
return x
diff --git a/fairseq/model_parallel/models/transformer_lm.py b/fairseq/model_parallel/models/transformer_lm.py
index 492dad653c..ed378c4320 100644
--- a/fairseq/model_parallel/models/transformer_lm.py
+++ b/fairseq/model_parallel/models/transformer_lm.py
@@ -4,14 +4,14 @@
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
-
-from fairseq.models import register_model, register_model_architecture
-from fairseq.models.transformer_lm import (
- TransformerLanguageModel,
-)
from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder
+from fairseq.models import register_model, register_model_architecture
+from fairseq.models.transformer_lm import TransformerLanguageModel
+
+
try:
from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding
+
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
@@ -20,17 +20,16 @@
DEFAULT_MAX_TARGET_POSITIONS = 1024
-@register_model('model_parallel_transformer_lm')
+@register_model("model_parallel_transformer_lm")
class ModelParallelTransformerLanguageModel(TransformerLanguageModel):
-
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
if not has_megatron_submodule:
raise ImportError(
- '\n\nPlease install the megatron submodule:'
- '\n\n git submodule update --init '
- 'fairseq/model_parallel/megatron'
+ "\n\nPlease install the megatron submodule:"
+ "\n\n git submodule update --init "
+ "fairseq/model_parallel/megatron"
)
# make sure all arguments are present in older models
@@ -42,18 +41,29 @@ def build_model(cls, args, task):
if args.decoder_layers_to_keep:
args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
- if getattr(args, 'max_target_positions', None) is None:
- args.max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS)
+ if getattr(args, "max_target_positions", None) is None:
+ args.max_target_positions = getattr(
+ args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
+ )
if args.character_embeddings:
- raise NotImplementedError("Character embeddings is not supported for model parallel")
+ raise NotImplementedError(
+ "Character embeddings is not supported for model parallel"
+ )
elif args.adaptive_input:
- raise NotImplementedError("Adaptive input is not supported for model parallel")
+ raise NotImplementedError(
+ "Adaptive input is not supported for model parallel"
+ )
else:
- embed_tokens = cls.build_embedding(args, task.source_dictionary, args.decoder_input_dim)
+ embed_tokens = cls.build_embedding(
+ args, task.source_dictionary, args.decoder_input_dim
+ )
decoder = ModelParallelTransformerDecoder(
- args, task.target_dictionary, embed_tokens, no_encoder_attn=True,
+ args,
+ task.target_dictionary,
+ embed_tokens,
+ no_encoder_attn=True,
)
return cls(decoder)
@@ -62,78 +72,94 @@ def build_embedding(cls, args, dictionary, embed_dim, path=None):
def _vocab_init(tensor, **kwargs):
nn.init.normal_(tensor, mean=0, std=embed_dim ** -0.5)
nn.init.constant_(tensor[1], 0)
- embed_tokens = VocabParallelEmbedding(len(dictionary), embed_dim, dictionary.pad(), init_method=_vocab_init)
+
+ embed_tokens = VocabParallelEmbedding(
+ len(dictionary), embed_dim, dictionary.pad(), init_method=_vocab_init
+ )
return embed_tokens
def base_lm_architecture(args):
# backward compatibility for older model checkpoints
- if hasattr(args, 'no_tie_adaptive_proj'):
+ if hasattr(args, "no_tie_adaptive_proj"):
# previous models defined --no-tie-adaptive-proj, so use the existence of
# that option to determine if this is an "old" model checkpoint
args.no_decoder_final_norm = True # old models always set this to True
if args.no_tie_adaptive_proj is False:
args.tie_adaptive_proj = True
- if hasattr(args, 'decoder_final_norm'):
+ if hasattr(args, "decoder_final_norm"):
args.no_decoder_final_norm = not args.decoder_final_norm
- args.activation_fn = getattr(args, 'activation_fn', 'relu')
- args.dropout = getattr(args, 'dropout', 0.1)
- args.attention_dropout = getattr(args, 'attention_dropout', 0.0)
- args.activation_dropout = getattr(args, 'activation_dropout', 0.0)
- args.relu_dropout = getattr(args, 'relu_dropout', 0.0)
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
- args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
- args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
- args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048)
- args.decoder_layers = getattr(args, 'decoder_layers', 6)
- args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
+ args.activation_fn = getattr(args, "activation_fn", "relu")
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.0)
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+ args.relu_dropout = getattr(args, "relu_dropout", 0.0)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
+ args.decoder_output_dim = getattr(
+ args, "decoder_output_dim", args.decoder_embed_dim
+ )
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
# Model training is not stable without this
args.decoder_normalize_before = True
- args.no_decoder_final_norm = getattr(args, 'no_decoder_final_norm', False)
- args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
- args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
- args.adaptive_softmax_factor = getattr(args, 'adaptive_softmax_factor', 4)
- args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
- args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
- args.character_embeddings = getattr(args, 'character_embeddings', False)
- args.character_filters = getattr(args, 'character_filters', '[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]')
- args.character_embedding_dim = getattr(args, 'character_embedding_dim', 4)
- args.char_embedder_highway_layers = getattr(args, 'char_embedder_highway_layers', 2)
- args.adaptive_input = getattr(args, 'adaptive_input', False)
- args.adaptive_input_factor = getattr(args, 'adaptive_input_factor', 4)
- args.adaptive_input_cutoff = getattr(args, 'adaptive_input_cutoff', None)
- args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False)
- args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', False)
- args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
- args.decoder_layerdrop = getattr(args, 'decoder_layerdrop', 0.0)
- args.decoder_layers_to_keep = getattr(args, 'decoder_layers_to_keep', None)
- args.layernorm_embedding = getattr(args, 'layernorm_embedding', False)
- args.no_scale_embedding = getattr(args, 'no_scale_embedding', False)
- args.quant_noise_pq = getattr(args, 'quant_noise_pq', 0.0)
- args.quant_noise_pq_block_size = getattr(args, 'quant_noise_pq_block_size', 8)
- args.quant_noise_scalar = getattr(args, 'quant_noise_scalar', 0.0)
- args.add_bos_token = getattr(args, 'add_bos_token', False)
-
-@register_model_architecture('model_parallel_transformer_lm', 'transformer_lm_megatron')
+ args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+ args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
+ args.no_token_positional_embeddings = getattr(
+ args, "no_token_positional_embeddings", False
+ )
+ args.share_decoder_input_output_embed = getattr(
+ args, "share_decoder_input_output_embed", False
+ )
+ args.character_embeddings = getattr(args, "character_embeddings", False)
+ args.character_filters = getattr(
+ args,
+ "character_filters",
+ "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]",
+ )
+ args.character_embedding_dim = getattr(args, "character_embedding_dim", 4)
+ args.char_embedder_highway_layers = getattr(args, "char_embedder_highway_layers", 2)
+ args.adaptive_input = getattr(args, "adaptive_input", False)
+ args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
+ args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)
+ args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
+ args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
+ args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
+ args.quant_noise_pq = getattr(args, "quant_noise_pq", 0.0)
+ args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
+ args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0.0)
+ args.add_bos_token = getattr(args, "add_bos_token", False)
+
+
+@register_model_architecture("model_parallel_transformer_lm", "transformer_lm_megatron")
def transformer_lm_megatron(args):
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 3072)
- args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 3072 * 4)
- args.decoder_layers = getattr(args, 'decoder_layers', 72)
- args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 32)
- args.dropout = getattr(args, 'dropout', 0.1)
- args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
- args.activation_fn = getattr(args, 'activation_fn', 'gelu')
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072)
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 4)
+ args.decoder_layers = getattr(args, "decoder_layers", 72)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
base_lm_architecture(args)
-@register_model_architecture('model_parallel_transformer_lm', 'transformer_lm_megatron_11b')
+@register_model_architecture(
+ "model_parallel_transformer_lm", "transformer_lm_megatron_11b"
+)
def transformer_lm_megatron_11b(args):
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 3072)
- args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 3072 * 6)
- args.decoder_layers = getattr(args, 'decoder_layers', 72)
- args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 32)
- args.dropout = getattr(args, 'dropout', 0.1)
- args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
- args.activation_fn = getattr(args, 'activation_fn', 'gelu')
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072)
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 6)
+ args.decoder_layers = getattr(args, "decoder_layers", 72)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
base_lm_architecture(args)
diff --git a/fairseq/model_parallel/modules/__init__.py b/fairseq/model_parallel/modules/__init__.py
index 26401dcc7c..fb45b3c9e0 100644
--- a/fairseq/model_parallel/modules/__init__.py
+++ b/fairseq/model_parallel/modules/__init__.py
@@ -5,14 +5,19 @@
"""isort:skip_file"""
from .multihead_attention import ModelParallelMultiheadAttention
-from .transformer_layer import ModelParallelTransformerEncoderLayer, ModelParallelTransformerDecoderLayer
-from .transformer_sentence_encoder_layer import ModelParallelTransformerSentenceEncoderLayer
+from .transformer_layer import (
+ ModelParallelTransformerEncoderLayer,
+ ModelParallelTransformerDecoderLayer,
+)
+from .transformer_sentence_encoder_layer import (
+ ModelParallelTransformerSentenceEncoderLayer,
+)
from .transformer_sentence_encoder import ModelParallelTransformerSentenceEncoder
__all__ = [
- 'ModelParallelMultiheadAttention',
- 'ModelParallelTransformerEncoderLayer',
- 'ModelParallelTransformerDecoderLayer',
- 'ModelParallelTransformerSentenceEncoder',
- 'ModelParallelTransformerSentenceEncoderLayer',
+ "ModelParallelMultiheadAttention",
+ "ModelParallelTransformerEncoderLayer",
+ "ModelParallelTransformerDecoderLayer",
+ "ModelParallelTransformerSentenceEncoder",
+ "ModelParallelTransformerSentenceEncoderLayer",
]
diff --git a/fairseq/model_parallel/modules/multihead_attention.py b/fairseq/model_parallel/modules/multihead_attention.py
index f55a712b01..4164bf9131 100644
--- a/fairseq/model_parallel/modules/multihead_attention.py
+++ b/fairseq/model_parallel/modules/multihead_attention.py
@@ -8,9 +8,10 @@
import torch
import torch.nn.functional as F
from fairseq import utils
-from torch import Tensor, nn
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
+from torch import Tensor, nn
+
try:
from fairseq.model_parallel.megatron.mpu import (
@@ -19,6 +20,7 @@
ColumnParallelLinear,
RowParallelLinear,
)
+
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
@@ -46,9 +48,9 @@ def __init__(
super().__init__()
if not has_megatron_submodule:
raise ImportError(
- '\n\nPlease install the megatron submodule:'
- '\n\n git submodule update --init '
- 'fairseq/model_parallel/megatron'
+ "\n\nPlease install the megatron submodule:"
+ "\n\n git submodule update --init "
+ "fairseq/model_parallel/megatron"
)
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
@@ -74,14 +76,22 @@ def __init__(
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
- assert not self.self_attention or self.qkv_same_dim, (
- "Self-attention requires query, key and value to be of the same size"
- )
+ assert (
+ not self.self_attention or self.qkv_same_dim
+ ), "Self-attention requires query, key and value to be of the same size"
- self.k_proj = ColumnParallelLinear(self.kdim, embed_dim, bias=bias, gather_output=False)
- self.v_proj = ColumnParallelLinear(self.vdim, embed_dim, bias=bias, gather_output=False)
- self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias, gather_output=False)
- self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias, input_is_parallel=True)
+ self.k_proj = ColumnParallelLinear(
+ self.kdim, embed_dim, bias=bias, gather_output=False
+ )
+ self.v_proj = ColumnParallelLinear(
+ self.vdim, embed_dim, bias=bias, gather_output=False
+ )
+ self.q_proj = ColumnParallelLinear(
+ embed_dim, embed_dim, bias=bias, gather_output=False
+ )
+ self.out_proj = RowParallelLinear(
+ embed_dim, embed_dim, bias=bias, input_is_parallel=True
+ )
self.tpu = False
@@ -145,7 +155,6 @@ def forward(
v = self.v_proj(value)
q *= self.scaling
-
q = (
q.contiguous()
.view(tgt_len, bsz * self.num_heads_partition, self.head_dim)
@@ -169,7 +178,9 @@ def forward(
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
- prev_key = _prev_key.view(bsz * self.num_heads_partition, -1, self.head_dim)
+ prev_key = _prev_key.view(
+ bsz * self.num_heads_partition, -1, self.head_dim
+ )
if static_kv:
k = prev_key
else:
@@ -178,7 +189,9 @@ def forward(
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
- prev_value = _prev_value.view(bsz * self.num_heads_partition, -1, self.head_dim)
+ prev_value = _prev_value.view(
+ bsz * self.num_heads_partition, -1, self.head_dim
+ )
if static_kv:
v = prev_value
else:
@@ -188,16 +201,22 @@ def forward(
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
- key_padding_mask = ModelParallelMultiheadAttention._append_prev_key_padding_mask(
- key_padding_mask=key_padding_mask,
- prev_key_padding_mask=prev_key_padding_mask,
- batch_size=bsz,
- src_len=k.size(1),
- static_kv=static_kv,
+ key_padding_mask = (
+ ModelParallelMultiheadAttention._append_prev_key_padding_mask(
+ key_padding_mask=key_padding_mask,
+ prev_key_padding_mask=prev_key_padding_mask,
+ batch_size=bsz,
+ src_len=k.size(1),
+ static_kv=static_kv,
+ )
)
- saved_state["prev_key"] = k.view(bsz, self.num_heads_partition, -1, self.head_dim)
- saved_state["prev_value"] = v.view(bsz, self.num_heads_partition, -1, self.head_dim)
+ saved_state["prev_key"] = k.view(
+ bsz, self.num_heads_partition, -1, self.head_dim
+ )
+ saved_state["prev_value"] = v.view(
+ bsz, self.num_heads_partition, -1, self.head_dim
+ )
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
@@ -216,7 +235,11 @@ def forward(
attn_weights = torch.bmm(q, k.transpose(1, 2))
- assert list(attn_weights.size()) == [bsz * self.num_heads_partition, tgt_len, src_len]
+ assert list(attn_weights.size()) == [
+ bsz * self.num_heads_partition,
+ tgt_len,
+ src_len,
+ ]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
@@ -224,20 +247,23 @@ def forward(
if key_padding_mask is not None:
# don't attend to padding symbols
- attn_weights = attn_weights.view(bsz, self.num_heads_partition, tgt_len, src_len)
+ attn_weights = attn_weights.view(
+ bsz, self.num_heads_partition, tgt_len, src_len
+ )
if not self.tpu:
attn_weights = attn_weights.masked_fill(
- key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+ float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
- attn_weights = attn_weights.masked_fill(key_padding_mask, float('-inf'))
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
- attn_weights = attn_weights.view(bsz * self.num_heads_partition, tgt_len, src_len)
+ attn_weights = attn_weights.view(
+ bsz * self.num_heads_partition, tgt_len, src_len
+ )
- attn_weights_float = utils.softmax(
- attn_weights, dim=-1
- )
+ attn_weights_float = utils.softmax(attn_weights, dim=-1)
attn_weights = attn_weights_float.type_as(attn_weights)
with get_cuda_rng_tracker().fork():
@@ -245,7 +271,11 @@ def forward(
assert v is not None
attn = torch.bmm(attn_probs, v)
- assert list(attn.size()) == [bsz * self.num_heads_partition, tgt_len, self.head_dim]
+ assert list(attn.size()) == [
+ bsz * self.num_heads_partition,
+ tgt_len,
+ self.head_dim,
+ ]
embed_dim_partition = embed_dim // self.model_parallel_size
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim_partition)
attn = self.out_proj(attn)
diff --git a/fairseq/model_parallel/modules/transformer_layer.py b/fairseq/model_parallel/modules/transformer_layer.py
index 30b23d518c..7ab53c6e5f 100644
--- a/fairseq/model_parallel/modules/transformer_layer.py
+++ b/fairseq/model_parallel/modules/transformer_layer.py
@@ -3,18 +3,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from fairseq.modules import (
- TransformerEncoderLayer,
- TransformerDecoderLayer,
-)
-
from fairseq.model_parallel.modules import ModelParallelMultiheadAttention
+from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
+
try:
from fairseq.model_parallel.megatron.mpu import (
ColumnParallelLinear,
RowParallelLinear,
)
+
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
@@ -23,7 +21,7 @@
class ModelParallelTransformerEncoderLayer(TransformerEncoderLayer):
"""Encoder layer block over multiple gpus.
- See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
+ See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
"""
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
@@ -48,8 +46,9 @@ def build_self_attention(self, embed_dim, args, **unused_kwargs):
class ModelParallelTransformerDecoderLayer(TransformerDecoderLayer):
"""Decoder layer block.
- See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
+ See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
"""
+
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
if q_noise > 0:
raise NotImplementedError
diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder.py b/fairseq/model_parallel/modules/transformer_sentence_encoder.py
index a2a6eb81fa..a5d50a33c6 100644
--- a/fairseq/model_parallel/modules/transformer_sentence_encoder.py
+++ b/fairseq/model_parallel/modules/transformer_sentence_encoder.py
@@ -3,11 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+import random
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
+from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoderLayer
from fairseq.modules import (
LayerNorm,
MultiheadAttention,
@@ -15,24 +17,21 @@
TransformerSentenceEncoder,
)
-from fairseq.model_parallel.modules import (
- ModelParallelTransformerSentenceEncoderLayer,
-)
try:
from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding
+
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
-import random
-
class ModelParallelTransformerSentenceEncoder(TransformerSentenceEncoder):
"""
Implementation for a Model Parallel Bi-directional Transformer based
Sentence Encoder used in BERT/XLM style pre-trained models.
"""
+
def build_embedding(self, vocab_size, embedding_dim, padding_idx):
return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx)
diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py
index d09158b7f1..e10bf52332 100644
--- a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py
+++ b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py
@@ -5,17 +5,17 @@
import torch
import torch.nn.functional as F
-
from fairseq import utils
-from fairseq.modules import (
- TransformerSentenceEncoderLayer
-)
from fairseq.model_parallel.modules import ModelParallelMultiheadAttention
+from fairseq.modules import TransformerSentenceEncoderLayer
+
+
try:
from fairseq.model_parallel.megatron.mpu import (
ColumnParallelLinear,
RowParallelLinear,
)
+
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
@@ -26,6 +26,7 @@ class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLay
Implements a Model Parallel Transformer Encoder Layer used in
BERT/XLM style pre-trained models.
"""
+
def build_fc1(self, input_dim, output_dim, **unused):
return ColumnParallelLinear(input_dim, output_dim, gather_output=False)
@@ -40,10 +41,7 @@ def build_self_attention(
**kwargs,
):
return ModelParallelMultiheadAttention(
- embed_dim,
- num_attention_heads,
- dropout=dropout,
- self_attention=True
+ embed_dim, num_attention_heads, dropout=dropout, self_attention=True
)
def forward(
diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py
index 48c59cb91d..cdabe36010 100644
--- a/fairseq/models/bart/hub_interface.py
+++ b/fairseq/models/bart/hub_interface.py
@@ -5,14 +5,12 @@
import copy
import logging
+from typing import List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-
-from typing import List
-
from fairseq import utils
from fairseq.data import encoders
@@ -34,19 +32,23 @@ def __init__(self, args, task, model):
self.bpe = encoders.build_bpe(args)
- self.max_positions = min(utils.resolve_max_positions(
- self.task.max_positions(),
- self.model.max_positions(),
- ))
+ self.max_positions = min(
+ utils.resolve_max_positions(
+ self.task.max_positions(),
+ self.model.max_positions(),
+ )
+ )
# this is useful for determining the device
- self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
+ self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))
@property
def device(self):
return self._float_tensor.device
- def encode(self, sentence: str, *addl_sentences, no_separator=True) -> torch.LongTensor:
+ def encode(
+ self, sentence: str, *addl_sentences, no_separator=True
+ ) -> torch.LongTensor:
"""
BPE-encode a sentence (or multiple sentences).
@@ -67,12 +69,12 @@ def encode(self, sentence: str, *addl_sentences, no_separator=True) -> torch.Lon
[0, 8331, 2]
"""
tokens = self.bpe.encode(sentence)
- if len(tokens.split(' ')) > self.max_positions - 2:
- tokens = ' '.join(tokens.split(' ')[:self.max_positions - 2])
- bpe_sentence = ' ' + tokens + ' '
+ if len(tokens.split(" ")) > self.max_positions - 2:
+ tokens = " ".join(tokens.split(" ")[: self.max_positions - 2])
+ bpe_sentence = " " + tokens + " "
for s in addl_sentences:
- bpe_sentence += (' ' if not no_separator else '')
- bpe_sentence += ' ' + self.bpe.encode(s) + ' '
+ bpe_sentence += " " if not no_separator else ""
+ bpe_sentence += " " + self.bpe.encode(s) + " "
tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False)
return tokens.long()
@@ -81,10 +83,12 @@ def decode(self, tokens: torch.LongTensor):
tokens = tokens.cpu().numpy()
if tokens[0] == self.task.source_dictionary.bos():
tokens = tokens[1:] # remove
- eos_mask = (tokens == self.task.source_dictionary.eos())
+ eos_mask = tokens == self.task.source_dictionary.eos()
doc_mask = eos_mask[1:] & eos_mask[:-1]
sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
- sentences = [self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences]
+ sentences = [
+ self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences
+ ]
if len(sentences) == 1:
return sentences[0]
return sentences
@@ -96,18 +100,23 @@ def _build_sample(self, src_tokens: List[torch.LongTensor]):
[x.numel() for x in src_tokens],
)
sample = dataset.collater(dataset)
- sample = utils.apply_to_sample(
- lambda tensor: tensor.to(self.device),
- sample
- )
+ sample = utils.apply_to_sample(lambda tensor: tensor.to(self.device), sample)
return sample
- def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> str:
+ def sample(
+ self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
+ ) -> str:
input = [self.encode(sentence) for sentence in sentences]
hypos = self.generate(input, beam, verbose, **kwargs)
- return [self.decode(x['tokens']) for x in hypos]
-
- def generate(self, tokens: List[torch.LongTensor], beam: int = 5, verbose: bool = False, **kwargs) -> torch.LongTensor:
+ return [self.decode(x["tokens"]) for x in hypos]
+
+ def generate(
+ self,
+ tokens: List[torch.LongTensor],
+ beam: int = 5,
+ verbose: bool = False,
+ **kwargs
+ ) -> torch.LongTensor:
sample = self._build_sample(tokens)
# build generator using current args as well as any kwargs
@@ -120,34 +129,40 @@ def generate(self, tokens: List[torch.LongTensor], beam: int = 5, verbose: bool
generator,
[self.model],
sample,
- prefix_tokens=sample['net_input']['src_tokens'].new_zeros((len(tokens), 1)).fill_(self.task.source_dictionary.bos()),
+ prefix_tokens=sample["net_input"]["src_tokens"]
+ .new_zeros((len(tokens), 1))
+ .fill_(self.task.source_dictionary.bos()),
)
if verbose:
src_str_with_unk = self.string(tokens)
- logger.info('S\t{}'.format(src_str_with_unk))
+ logger.info("S\t{}".format(src_str_with_unk))
def getarg(name, default):
return getattr(gen_args, name, getattr(self.args, name, default))
# Process top predictions
hypos = [x[0] for x in translations]
- hypos = [v for _, v in sorted(zip(sample['id'].tolist(), hypos))]
+ hypos = [v for _, v in sorted(zip(sample["id"].tolist(), hypos))]
return hypos
- def extract_features(self, tokens: torch.LongTensor, return_all_hiddens: bool = False) -> torch.Tensor:
+ def extract_features(
+ self, tokens: torch.LongTensor, return_all_hiddens: bool = False
+ ) -> torch.Tensor:
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
if tokens.size(-1) > min(self.model.max_positions()):
- raise ValueError('tokens exceeds maximum length: {} > {}'.format(
- tokens.size(-1), self.model.max_positions()
- ))
+ raise ValueError(
+ "tokens exceeds maximum length: {} > {}".format(
+ tokens.size(-1), self.model.max_positions()
+ )
+ )
tokens.to(device=self.device),
prev_output_tokens = tokens.clone()
prev_output_tokens[:, 0] = tokens.gather(
1,
- (tokens.ne(self.task.source_dictionary.pad()).sum(dim=1)- 1).unsqueeze(-1),
+ (tokens.ne(self.task.source_dictionary.pad()).sum(dim=1) - 1).unsqueeze(-1),
).squeeze()
prev_output_tokens[:, 1:] = tokens[:, :-1]
@@ -160,7 +175,7 @@ def extract_features(self, tokens: torch.LongTensor, return_all_hiddens: bool =
)
if return_all_hiddens:
# convert from T x B x C -> B x T x C
- inner_states = extra['inner_states']
+ inner_states = extra["inner_states"]
return [inner_state.transpose(0, 1) for inner_state in inner_states]
else:
return features # just the last layer's features
diff --git a/fairseq/models/bart/model.py b/fairseq/models/bart/model.py
index 90e79e4651..0f22352b68 100644
--- a/fairseq/models/bart/model.py
+++ b/fairseq/models/bart/model.py
@@ -11,12 +11,8 @@
import torch
import torch.nn as nn
-
from fairseq import utils
-from fairseq.models import (
- register_model,
- register_model_architecture,
-)
+from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import TransformerModel
from fairseq.modules.transformer_sentence_encoder import init_bert_params
@@ -26,17 +22,16 @@
logger = logging.getLogger(__name__)
-@register_model('bart')
+@register_model("bart")
class BARTModel(TransformerModel):
-
@classmethod
def hub_models(cls):
return {
- 'bart.base': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz',
- 'bart.large': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz',
- 'bart.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz',
- 'bart.large.cnn': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz',
- 'bart.large.xsum': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz',
+ "bart.base": "http://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz",
+ "bart.large": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz",
+ "bart.large.mnli": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz",
+ "bart.large.cnn": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz",
+ "bart.large.xsum": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz",
}
def __init__(self, args, encoder, decoder):
@@ -51,28 +46,35 @@ def __init__(self, args, encoder, decoder):
def add_args(parser):
super(BARTModel, BARTModel).add_args(parser)
parser.add_argument(
- '--pooler-dropout', type=float, metavar='D',
- help='dropout probability in the masked_lm pooler layers'
+ "--pooler-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability in the masked_lm pooler layers",
)
parser.add_argument(
- '--pooler-activation-fn',
+ "--pooler-activation-fn",
choices=utils.get_available_activation_fns(),
- help='activation function to use for pooler layer'
+ help="activation function to use for pooler layer",
)
parser.add_argument(
- '--spectral-norm-classification-head',
- action='store_true',
- help='Apply spectral normalization on the classification head'
+ "--spectral-norm-classification-head",
+ action="store_true",
+ help="Apply spectral normalization on the classification head",
)
@property
def supported_targets(self):
- return {'self'}
+ return {"self"}
def forward(
- self, src_tokens, src_lengths, prev_output_tokens,
- features_only=False, classification_head_name=None,
- token_embeddings=None, **kwargs
+ self,
+ src_tokens,
+ src_lengths,
+ prev_output_tokens,
+ features_only=False,
+ classification_head_name=None,
+ token_embeddings=None,
+ **kwargs,
):
if classification_head_name is not None:
features_only = True
@@ -103,12 +105,13 @@ def forward(
def from_pretrained(
cls,
model_name_or_path,
- checkpoint_file='model.pt',
- data_name_or_path='.',
- bpe='gpt2',
+ checkpoint_file="model.pt",
+ data_name_or_path=".",
+ bpe="gpt2",
**kwargs,
):
from fairseq import hub_utils
+
x = hub_utils.from_pretrained(
model_name_or_path,
checkpoint_file,
@@ -118,9 +121,11 @@ def from_pretrained(
load_checkpoint_heads=True,
**kwargs,
)
- return BARTHubInterface(x['args'], x['task'], x['models'][0])
+ return BARTHubInterface(x["args"], x["task"], x["models"][0])
- def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
+ def register_classification_head(
+ self, name, num_classes=None, inner_dim=None, **kwargs
+ ):
"""Register a classification head."""
logger.info("Registering classification head: {0}".format(name))
if name in self.classification_heads:
@@ -129,7 +134,7 @@ def register_classification_head(self, name, num_classes=None, inner_dim=None, *
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
- 'and inner_dim {} (prev: {})'.format(
+ "and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
@@ -139,43 +144,54 @@ def register_classification_head(self, name, num_classes=None, inner_dim=None, *
num_classes=num_classes,
activation_fn=self.args.pooler_activation_fn,
pooler_dropout=self.args.pooler_dropout,
- do_spectral_norm=self.args.spectral_norm_classification_head
+ do_spectral_norm=self.args.spectral_norm_classification_head,
)
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
- prefix = name + '.' if name != '' else ''
- current_head_names = [] if not hasattr(self, 'classification_heads') else \
- self.classification_heads.keys()
+ prefix = name + "." if name != "" else ""
+ current_head_names = (
+ []
+ if not hasattr(self, "classification_heads")
+ else self.classification_heads.keys()
+ )
# Handle new classification heads present in the state dict.
keys_to_delete = []
for k in state_dict.keys():
- if not k.startswith(prefix + 'classification_heads.'):
+ if not k.startswith(prefix + "classification_heads."):
continue
- head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
- num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
- inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
+ head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
+ num_classes = state_dict[
+ prefix + "classification_heads." + head_name + ".out_proj.weight"
+ ].size(0)
+ inner_dim = state_dict[
+ prefix + "classification_heads." + head_name + ".dense.weight"
+ ].size(0)
- if getattr(self.args, 'load_checkpoint_heads', False):
+ if getattr(self.args, "load_checkpoint_heads", False):
if head_name not in current_head_names:
self.register_classification_head(head_name, num_classes, inner_dim)
else:
if head_name not in current_head_names:
logger.warning(
- 'deleting classification head ({}) from checkpoint '
- 'not present in current model: {}'.format(head_name, k)
+ "deleting classification head ({}) from checkpoint "
+ "not present in current model: {}".format(head_name, k)
)
keys_to_delete.append(k)
elif (
- num_classes != self.classification_heads[head_name].out_proj.out_features
- or inner_dim != self.classification_heads[head_name].dense.out_features
+ num_classes
+ != self.classification_heads[head_name].out_proj.out_features
+ or inner_dim
+ != self.classification_heads[head_name].dense.out_features
):
logger.warning(
- 'deleting classification head ({}) from checkpoint '
- 'with different dimensions than current model: {}'.format(head_name, k)
+ "deleting classification head ({}) from checkpoint "
+ "with different dimensions than current model: {}".format(
+ head_name, k
+ )
)
keys_to_delete.append(k)
for k in keys_to_delete:
@@ -187,55 +203,66 @@ def truncate_emb(key):
# When finetuning on translation task, remove last row of
# embedding matrix that corresponds to mask_idx token.
- loaded_dict_size = state_dict['encoder.embed_tokens.weight'].size(0)
- if loaded_dict_size == len(self.encoder.dictionary) + 1 and '' not in self.encoder.dictionary:
- truncate_emb('encoder.embed_tokens.weight')
- truncate_emb('decoder.embed_tokens.weight')
- truncate_emb('encoder.output_projection.weight')
- truncate_emb('decoder.output_projection.weight')
+ loaded_dict_size = state_dict["encoder.embed_tokens.weight"].size(0)
+ if (
+ loaded_dict_size == len(self.encoder.dictionary) + 1
+ and "" not in self.encoder.dictionary
+ ):
+ truncate_emb("encoder.embed_tokens.weight")
+ truncate_emb("decoder.embed_tokens.weight")
+ truncate_emb("encoder.output_projection.weight")
+ truncate_emb("decoder.output_projection.weight")
# When continued pretraining on new set of languages for mbart,
# add extra lang embeddings at the end of embed_tokens.
# Note: newly added languages are assumed to have been added at the end.
- if self.args.task == 'multilingual_denoising' and loaded_dict_size < len(self.encoder.dictionary):
+ if self.args.task == "multilingual_denoising" and loaded_dict_size < len(
+ self.encoder.dictionary
+ ):
logger.info(
- "Adding extra language embeddings not found in pretrained model for "\
+ "Adding extra language embeddings not found in pretrained model for "
"continued pretraining of MBART on new set of languages."
)
- loaded_mask_token_embedding = state_dict['encoder.embed_tokens.weight'][-1, :]
+ loaded_mask_token_embedding = state_dict["encoder.embed_tokens.weight"][
+ -1, :
+ ]
num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size
- embed_dim = state_dict['encoder.embed_tokens.weight'].size(1)
+ embed_dim = state_dict["encoder.embed_tokens.weight"].size(1)
new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
- nn.init.normal_(
- new_lang_embed_to_add,
- mean=0,
- std=embed_dim ** -0.5
- )
+ nn.init.normal_(new_lang_embed_to_add, mean=0, std=embed_dim ** -0.5)
new_lang_embed_to_add = new_lang_embed_to_add.to(
- dtype=state_dict['encoder.embed_tokens.weight'].dtype,
+ dtype=state_dict["encoder.embed_tokens.weight"].dtype,
)
- state_dict['encoder.embed_tokens.weight'] = torch.cat([
- state_dict['encoder.embed_tokens.weight'][:loaded_dict_size-1, :],
- new_lang_embed_to_add,
- loaded_mask_token_embedding.unsqueeze(0)]
+ state_dict["encoder.embed_tokens.weight"] = torch.cat(
+ [
+ state_dict["encoder.embed_tokens.weight"][
+ : loaded_dict_size - 1, :
+ ],
+ new_lang_embed_to_add,
+ loaded_mask_token_embedding.unsqueeze(0),
+ ]
)
- state_dict['decoder.embed_tokens.weight'] = torch.cat([
- state_dict['decoder.embed_tokens.weight'][:loaded_dict_size-1, :],
- new_lang_embed_to_add,
- loaded_mask_token_embedding.unsqueeze(0)]
+ state_dict["decoder.embed_tokens.weight"] = torch.cat(
+ [
+ state_dict["decoder.embed_tokens.weight"][
+ : loaded_dict_size - 1, :
+ ],
+ new_lang_embed_to_add,
+ loaded_mask_token_embedding.unsqueeze(0),
+ ]
)
# Copy any newly-added classification heads into the state dict
# with their current weights.
- if hasattr(self, 'classification_heads'):
+ if hasattr(self, "classification_heads"):
cur_state = self.classification_heads.state_dict()
for k, v in cur_state.items():
- if prefix + 'classification_heads.' + k not in state_dict:
- logger.info('Overwriting', prefix + 'classification_heads.' + k)
- state_dict[prefix + 'classification_heads.' + k] = v
+ if prefix + "classification_heads." + k not in state_dict:
+ logger.info("Overwriting", prefix + "classification_heads." + k)
+ state_dict[prefix + "classification_heads." + k] = v
class BARTClassificationHead(nn.Module):
@@ -248,7 +275,7 @@ def __init__(
num_classes,
activation_fn,
pooler_dropout,
- do_spectral_norm=False
+ do_spectral_norm=False,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
@@ -269,67 +296,73 @@ def forward(self, features, **kwargs):
return x
-@register_model_architecture('bart', 'bart_large')
+@register_model_architecture("bart", "bart_large")
def bart_large_architecture(args):
- args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4*1024)
- args.encoder_layers = getattr(args, 'encoder_layers', 12)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
- args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
- args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True)
- args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
- args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
- args.decoder_layers = getattr(args, 'decoder_layers', 12)
- args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
- args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False)
- args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', True)
- args.attention_dropout = getattr(args, 'attention_dropout', 0.)
- args.relu_dropout = getattr(args, 'relu_dropout', 0.)
- args.dropout = getattr(args, 'dropout', 0.1)
- args.max_target_positions = getattr(args, 'max_target_positions', 1024)
- args.max_source_positions = getattr(args, 'max_source_positions', 1024)
- args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
- args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
- args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', True)
- args.share_all_embeddings = getattr(args, 'share_all_embeddings', True)
-
- args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
- args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
-
- args.no_scale_embedding = getattr(args, 'no_scale_embedding', True)
- args.layernorm_embedding = getattr(args, 'layernorm_embedding', True)
-
- args.activation_fn = getattr(args, 'activation_fn', 'gelu')
- args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
- args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)
-
-
-@register_model_architecture('bart', 'bart_base')
+ args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 1024)
+ args.encoder_layers = getattr(args, "encoder_layers", 12)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
+ args.decoder_ffn_embed_dim = getattr(
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
+ )
+ args.decoder_layers = getattr(args, "decoder_layers", 12)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.0)
+ args.relu_dropout = getattr(args, "relu_dropout", 0.0)
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.max_target_positions = getattr(args, "max_target_positions", 1024)
+ args.max_source_positions = getattr(args, "max_source_positions", 1024)
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+ args.share_decoder_input_output_embed = getattr(
+ args, "share_decoder_input_output_embed", True
+ )
+ args.share_all_embeddings = getattr(args, "share_all_embeddings", True)
+
+ args.decoder_output_dim = getattr(
+ args, "decoder_output_dim", args.decoder_embed_dim
+ )
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
+
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
+
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
+ args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
+ args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
+
+
+@register_model_architecture("bart", "bart_base")
def bart_base_architecture(args):
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4*768)
- args.encoder_layers = getattr(args, 'encoder_layers', 6)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12)
- args.decoder_layers = getattr(args, 'decoder_layers', 6)
- args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 12)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 768)
+ args.encoder_layers = getattr(args, "encoder_layers", 6)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
bart_large_architecture(args)
-@register_model_architecture('bart', 'mbart_large')
+@register_model_architecture("bart", "mbart_large")
def mbart_large_architecture(args):
- args.no_scale_embedding = getattr(args, 'no_scale_embedding', False)
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
bart_large_architecture(args)
-@register_model_architecture('bart', 'mbart_base')
+@register_model_architecture("bart", "mbart_base")
def mbart_base_architecture(args):
- args.no_scale_embedding = getattr(args, 'no_scale_embedding', False)
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
bart_base_architecture(args)
-@register_model_architecture('bart', 'mbart_base_wmt20')
+@register_model_architecture("bart", "mbart_base_wmt20")
def mbart_base_wmt20_architecture(args):
- args.layernorm_embedding = getattr(args, 'layernorm_embedding', False)
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
mbart_base_architecture(args)
diff --git a/fairseq/models/composite_encoder.py b/fairseq/models/composite_encoder.py
index 60d1473f5f..4e20fe3a83 100644
--- a/fairseq/models/composite_encoder.py
+++ b/fairseq/models/composite_encoder.py
@@ -43,7 +43,9 @@ def forward(self, src_tokens, src_lengths):
def reorder_encoder_out(self, encoder_out, new_order):
"""Reorder encoder output according to new_order."""
for key in self.encoders:
- encoder_out[key] = self.encoders[key].reorder_encoder_out(encoder_out[key], new_order)
+ encoder_out[key] = self.encoders[key].reorder_encoder_out(
+ encoder_out[key], new_order
+ )
return encoder_out
def max_positions(self):
diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py
index 4fe02b20dd..ece10c6333 100644
--- a/fairseq/models/distributed_fairseq_model.py
+++ b/fairseq/models/distributed_fairseq_model.py
@@ -6,7 +6,6 @@
import inspect
import torch.nn as nn
-
from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
@@ -32,7 +31,7 @@ def DistributedFairseqModel(args, model, process_group=None):
"""
# determine which DDP class to extend
assert isinstance(model, nn.Module)
- if args.distributed_wrapper == 'DDP' and args.ddp_backend == 'c10d':
+ if args.distributed_wrapper == "DDP" and args.ddp_backend == "c10d":
ddp_class = nn.parallel.DistributedDataParallel
init_kwargs = dict(
module=model,
@@ -43,23 +42,23 @@ def DistributedFairseqModel(args, model, process_group=None):
process_group=process_group,
)
# Maintain backward compatibility
- if 'check_reduction' in inspect.getargspec(ddp_class)[0]:
- init_kwargs['check_reduction'] = True
- if 'find_unused_parameters' in inspect.getargspec(ddp_class)[0]:
- init_kwargs['find_unused_parameters'] = args.find_unused_parameters
- elif args.distributed_wrapper == 'DDP' and args.ddp_backend == 'no_c10d':
+ if "check_reduction" in inspect.getargspec(ddp_class)[0]:
+ init_kwargs["check_reduction"] = True
+ if "find_unused_parameters" in inspect.getargspec(ddp_class)[0]:
+ init_kwargs["find_unused_parameters"] = args.find_unused_parameters
+ elif args.distributed_wrapper == "DDP" and args.ddp_backend == "no_c10d":
ddp_class = LegacyDistributedDataParallel
init_kwargs = dict(
module=model,
world_size=args.distributed_world_size,
- buffer_size=2**28,
+ buffer_size=2 ** 28,
process_group=process_group,
)
- elif args.distributed_wrapper == 'SlowMo':
+ elif args.distributed_wrapper == "SlowMo":
if _GOSSIP_DISABLED:
raise ImportError(
- 'Cannot find gossip library. Please install from: '
- 'github.com/facebookresearch/stochastic_gradient_push'
+ "Cannot find gossip library. Please install from: "
+ "github.com/facebookresearch/stochastic_gradient_push"
)
ddp_class = gossip.GossipDataParallel
@@ -82,11 +81,11 @@ def DistributedFairseqModel(args, model, process_group=None):
broadcast_buffers=args.broadcast_buffers,
nprocs_per_node=args.nprocs_per_node,
slowmo_momentum=args.slowmo_momentum,
- localsgd=(args.slowmo_algorithm == 'LocalSGD'),
- localsgd_frequency=args.localsgd_frequency
+ localsgd=(args.slowmo_algorithm == "LocalSGD"),
+ localsgd_frequency=args.localsgd_frequency,
)
else:
- raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend)
+ raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)
class _DistributedFairseqModel(ddp_class):
"""Extend DistributedDataParallel to check for missing
@@ -96,7 +95,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __getattr__(self, name):
- wrapped_module = super().__getattr__('module')
+ wrapped_module = super().__getattr__("module")
if hasattr(wrapped_module, name):
return getattr(wrapped_module, name)
return super().__getattr__(name)
diff --git a/fairseq/models/fairseq_encoder.py b/fairseq/models/fairseq_encoder.py
index 7ddc0fba01..c8873daa28 100644
--- a/fairseq/models/fairseq_encoder.py
+++ b/fairseq/models/fairseq_encoder.py
@@ -3,11 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from typing import Dict, List, NamedTuple, Optional
+
import torch
import torch.nn as nn
-from typing import Dict, List, NamedTuple, Optional
from torch import Tensor
+
EncoderOut = NamedTuple(
"EncoderOut",
[
@@ -55,9 +57,7 @@ def forward_torchscript(self, net_input: Dict[str, Tensor]):
@torch.jit.unused
def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
encoder_input = {
- k: v
- for k, v in net_input.items()
- if k != "prev_output_tokens"
+ k: v for k, v in net_input.items() if k != "prev_output_tokens"
}
return self.forward(**encoder_input)
@@ -86,6 +86,7 @@ def set_num_updates(self, num_updates):
"""State from trainer to pass along to model at every update."""
def _apply(m):
- if hasattr(m, 'set_num_updates') and m != self:
+ if hasattr(m, "set_num_updates") and m != self:
m.set_num_updates(num_updates)
+
self.apply(_apply)
diff --git a/fairseq/models/fairseq_incremental_decoder.py b/fairseq/models/fairseq_incremental_decoder.py
index 68e583fea8..cc72a0f8f3 100644
--- a/fairseq/models/fairseq_incremental_decoder.py
+++ b/fairseq/models/fairseq_incremental_decoder.py
@@ -6,10 +6,9 @@
import logging
from typing import Dict, Optional
-from torch import Tensor
-
-from fairseq.models import FairseqDecoder
from fairseq.incremental_decoding_utils import with_incremental_state
+from fairseq.models import FairseqDecoder
+from torch import Tensor
logger = logging.getLogger(__name__)
@@ -41,7 +40,9 @@ class FairseqIncrementalDecoder(FairseqDecoder):
def __init__(self, dictionary):
super().__init__(dictionary)
- def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
+ def forward(
+ self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
+ ):
"""
Args:
prev_output_tokens (LongTensor): shifted output tokens of shape
@@ -58,7 +59,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None,
"""
raise NotImplementedError
- def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
+ def extract_features(
+ self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
+ ):
"""
Returns:
tuple:
@@ -92,19 +95,22 @@ def reorder_incremental_state_scripting(
calling :func:`reorder_incremental_state` directly.
"""
for module in self.modules():
- if hasattr(module, 'reorder_incremental_state'):
+ if hasattr(module, "reorder_incremental_state"):
result = module.reorder_incremental_state(incremental_state, new_order)
if result is not None:
incremental_state = result
def set_beam_size(self, beam_size):
"""Sets the beam size in the decoder and all children."""
- if getattr(self, '_beam_size', -1) != beam_size:
+ if getattr(self, "_beam_size", -1) != beam_size:
seen = set()
def apply_set_beam_size(module):
- if module != self and hasattr(module, 'set_beam_size') \
- and module not in seen:
+ if (
+ module != self
+ and hasattr(module, "set_beam_size")
+ and module not in seen
+ ):
seen.add(module)
module.set_beam_size(beam_size)
diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py
index facb7d011b..bfd41777b2 100644
--- a/fairseq/models/fairseq_model.py
+++ b/fairseq/models/fairseq_model.py
@@ -223,7 +223,7 @@ def apply_prepare_for_tpu_(module):
@classmethod
def upgrade_args(cls, args):
- if hasattr(args, 'max_sentences') and not hasattr(args, 'batch_size'):
+ if hasattr(args, "max_sentences") and not hasattr(args, "batch_size"):
args.batch_size = args.max_sentences
@classmethod
diff --git a/fairseq/models/fconv.py b/fairseq/models/fconv.py
index c60a2f4e5f..c99a215101 100644
--- a/fairseq/models/fconv.py
+++ b/fairseq/models/fconv.py
@@ -8,22 +8,25 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
- FairseqIncrementalDecoder,
FairseqEncoderDecoderModel,
+ FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.modules import (
- AdaptiveSoftmax, BeamableMM, FairseqDropout, GradMultiply, LearnedPositionalEmbedding,
+ AdaptiveSoftmax,
+ BeamableMM,
+ FairseqDropout,
+ GradMultiply,
+ LearnedPositionalEmbedding,
LinearizedConvolution,
)
-@register_model('fconv')
+@register_model("fconv")
class FConvModel(FairseqEncoderDecoderModel):
"""
A fully convolutional model, i.e. a convolutional encoder and a
@@ -44,23 +47,30 @@ class FConvModel(FairseqEncoderDecoderModel):
@classmethod
def hub_models(cls):
-
def moses_subword(path):
return {
- 'path': path,
- 'tokenizer': 'moses',
- 'bpe': 'subword_nmt',
+ "path": path,
+ "tokenizer": "moses",
+ "bpe": "subword_nmt",
}
return {
- 'conv.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2'),
- 'conv.wmt14.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2'),
- 'conv.wmt17.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2'),
+ "conv.wmt14.en-fr": moses_subword(
+ "https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2"
+ ),
+ "conv.wmt14.en-de": moses_subword(
+ "https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2"
+ ),
+ "conv.wmt17.en-de": moses_subword(
+ "https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2"
+ ),
}
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
- self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention)
+ self.encoder.num_attention_layers = sum(
+ layer is not None for layer in decoder.attention
+ )
@staticmethod
def add_args(parser):
@@ -147,8 +157,13 @@ class FConvEncoder(FairseqEncoder):
"""
def __init__(
- self, dictionary, embed_dim=512, embed_dict=None, max_positions=1024,
- convolutions=((512, 3),) * 20, dropout=0.1,
+ self,
+ dictionary,
+ embed_dim=512,
+ embed_dict=None,
+ max_positions=1024,
+ convolutions=((512, 3),) * 20,
+ dropout=0.1,
):
super().__init__(dictionary)
self.dropout_module = FairseqDropout(
@@ -160,7 +175,9 @@ def __init__(
self.padding_idx = dictionary.pad()
self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
if embed_dict:
- self.embed_tokens = utils.load_embedding(embed_dict, self.dictionary, self.embed_tokens)
+ self.embed_tokens = utils.load_embedding(
+ embed_dict, self.dictionary, self.embed_tokens
+ )
self.embed_positions = PositionalEmbedding(
max_positions,
@@ -181,15 +198,23 @@ def __init__(
residual_dim = out_channels
else:
residual_dim = layer_in_channels[-residual]
- self.projections.append(Linear(residual_dim, out_channels)
- if residual_dim != out_channels else None)
+ self.projections.append(
+ Linear(residual_dim, out_channels)
+ if residual_dim != out_channels
+ else None
+ )
if kernel_size % 2 == 1:
padding = kernel_size // 2
else:
padding = 0
self.convolutions.append(
- ConvTBC(in_channels, out_channels * 2, kernel_size,
- dropout=dropout, padding=padding)
+ ConvTBC(
+ in_channels,
+ out_channels * 2,
+ kernel_size,
+ dropout=dropout,
+ padding=padding,
+ )
)
self.residuals.append(residual)
in_channels = out_channels
@@ -232,7 +257,9 @@ def forward(self, src_tokens, src_lengths):
residuals = [x]
# temporal convolutions
- for proj, conv, res_layer in zip(self.projections, self.convolutions, self.residuals):
+ for proj, conv, res_layer in zip(
+ self.projections, self.convolutions, self.residuals
+ ):
if res_layer > 0:
residual = residuals[-res_layer]
residual = residual if proj is None else proj(residual)
@@ -274,19 +301,20 @@ def forward(self, src_tokens, src_lengths):
y = (x + input_embedding) * math.sqrt(0.5)
return {
- 'encoder_out': (x, y),
- 'encoder_padding_mask': encoder_padding_mask, # B x T
+ "encoder_out": (x, y),
+ "encoder_padding_mask": encoder_padding_mask, # B x T
}
def reorder_encoder_out(self, encoder_out, new_order):
- if encoder_out['encoder_out'] is not None:
- encoder_out['encoder_out'] = (
- encoder_out['encoder_out'][0].index_select(0, new_order),
- encoder_out['encoder_out'][1].index_select(0, new_order),
+ if encoder_out["encoder_out"] is not None:
+ encoder_out["encoder_out"] = (
+ encoder_out["encoder_out"][0].index_select(0, new_order),
+ encoder_out["encoder_out"][1].index_select(0, new_order),
)
- if encoder_out['encoder_padding_mask'] is not None:
- encoder_out['encoder_padding_mask'] = \
- encoder_out['encoder_padding_mask'].index_select(0, new_order)
+ if encoder_out["encoder_padding_mask"] is not None:
+ encoder_out["encoder_padding_mask"] = encoder_out[
+ "encoder_padding_mask"
+ ].index_select(0, new_order)
return encoder_out
def max_positions(self):
@@ -313,10 +341,11 @@ def forward(self, x, target_embedding, encoder_out, encoder_padding_mask):
# don't attend over padding
if encoder_padding_mask is not None:
- x = x.float().masked_fill(
- encoder_padding_mask.unsqueeze(1),
- float('-inf')
- ).type_as(x) # FP16 support: cast to float and back
+ x = (
+ x.float()
+ .masked_fill(encoder_padding_mask.unsqueeze(1), float("-inf"))
+ .type_as(x)
+ ) # FP16 support: cast to float and back
# softmax over last dim
sz = x.size()
@@ -331,7 +360,9 @@ def forward(self, x, target_embedding, encoder_out, encoder_padding_mask):
if encoder_padding_mask is None:
x = x * (s * math.sqrt(1.0 / s))
else:
- s = s - encoder_padding_mask.type_as(x).sum(dim=1, keepdim=True) # exclude padding
+ s = s - encoder_padding_mask.type_as(x).sum(
+ dim=1, keepdim=True
+ ) # exclude padding
s = s.unsqueeze(-1)
x = x * (s * s.rsqrt())
@@ -343,20 +374,29 @@ def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs):
"""Replace torch.bmm with BeamableMM."""
if beamable_mm_beam_size is not None:
del self.bmm
- self.add_module('bmm', BeamableMM(beamable_mm_beam_size))
+ self.add_module("bmm", BeamableMM(beamable_mm_beam_size))
class FConvDecoder(FairseqIncrementalDecoder):
"""Convolutional decoder"""
def __init__(
- self, dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256,
- max_positions=1024, convolutions=((512, 3),) * 20, attention=True,
- dropout=0.1, share_embed=False, positional_embeddings=True,
- adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0.,
+ self,
+ dictionary,
+ embed_dim=512,
+ embed_dict=None,
+ out_embed_dim=256,
+ max_positions=1024,
+ convolutions=((512, 3),) * 20,
+ attention=True,
+ dropout=0.1,
+ share_embed=False,
+ positional_embeddings=True,
+ adaptive_softmax_cutoff=None,
+ adaptive_softmax_dropout=0.0,
):
super().__init__(dictionary)
- self.register_buffer('version', torch.Tensor([2]))
+ self.register_buffer("version", torch.Tensor([2]))
self.dropout_module = FairseqDropout(
dropout, module_name=self.__class__.__name__
)
@@ -368,20 +408,28 @@ def __init__(
# expand True into [True, True, ...] and do the same with False
attention = [attention] * len(convolutions)
if not isinstance(attention, list) or len(attention) != len(convolutions):
- raise ValueError('Attention is expected to be a list of booleans of '
- 'length equal to the number of layers.')
+ raise ValueError(
+ "Attention is expected to be a list of booleans of "
+ "length equal to the number of layers."
+ )
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
if embed_dict:
- self.embed_tokens = utils.load_embedding(embed_dict, self.dictionary, self.embed_tokens)
+ self.embed_tokens = utils.load_embedding(
+ embed_dict, self.dictionary, self.embed_tokens
+ )
- self.embed_positions = PositionalEmbedding(
- max_positions,
- embed_dim,
- padding_idx,
- ) if positional_embeddings else None
+ self.embed_positions = (
+ PositionalEmbedding(
+ max_positions,
+ embed_dim,
+ padding_idx,
+ )
+ if positional_embeddings
+ else None
+ )
self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
self.projections = nn.ModuleList()
@@ -395,14 +443,23 @@ def __init__(
residual_dim = out_channels
else:
residual_dim = layer_in_channels[-residual]
- self.projections.append(Linear(residual_dim, out_channels)
- if residual_dim != out_channels else None)
+ self.projections.append(
+ Linear(residual_dim, out_channels)
+ if residual_dim != out_channels
+ else None
+ )
self.convolutions.append(
- LinearizedConv1d(in_channels, out_channels * 2, kernel_size,
- padding=(kernel_size - 1), dropout=dropout)
+ LinearizedConv1d(
+ in_channels,
+ out_channels * 2,
+ kernel_size,
+ padding=(kernel_size - 1),
+ dropout=dropout,
+ )
+ )
+ self.attention.append(
+ AttentionLayer(out_channels, embed_dim) if attention[i] else None
)
- self.attention.append(AttentionLayer(out_channels, embed_dim)
- if attention[i] else None)
self.residuals.append(residual)
in_channels = out_channels
layer_in_channels.append(out_channels)
@@ -412,26 +469,35 @@ def __init__(
if adaptive_softmax_cutoff is not None:
assert not share_embed
- self.adaptive_softmax = AdaptiveSoftmax(num_embeddings, in_channels, adaptive_softmax_cutoff,
- dropout=adaptive_softmax_dropout)
+ self.adaptive_softmax = AdaptiveSoftmax(
+ num_embeddings,
+ in_channels,
+ adaptive_softmax_cutoff,
+ dropout=adaptive_softmax_dropout,
+ )
else:
self.fc2 = Linear(in_channels, out_embed_dim)
if share_embed:
- assert out_embed_dim == embed_dim, \
- "Shared embed weights implies same dimensions " \
+ assert out_embed_dim == embed_dim, (
+ "Shared embed weights implies same dimensions "
" out_embed_dim={} vs embed_dim={}".format(out_embed_dim, embed_dim)
+ )
self.fc3 = nn.Linear(out_embed_dim, num_embeddings)
self.fc3.weight = self.embed_tokens.weight
else:
self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
- def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
+ def forward(
+ self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
+ ):
if encoder_out is not None:
- encoder_padding_mask = encoder_out['encoder_padding_mask']
- encoder_out = encoder_out['encoder_out']
+ encoder_padding_mask = encoder_out["encoder_padding_mask"]
+ encoder_out = encoder_out["encoder_out"]
# split and transpose encoder outputs
- encoder_a, encoder_b = self._split_encoder_out(encoder_out, incremental_state)
+ encoder_a, encoder_b = self._split_encoder_out(
+ encoder_out, incremental_state
+ )
if self.embed_positions is not None:
pos_embed = self.embed_positions(prev_output_tokens, incremental_state)
@@ -457,8 +523,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None,
avg_attn_scores = None
num_attn_layers = len(self.attention)
residuals = [x]
- for proj, conv, attention, res_layer in zip(self.projections, self.convolutions, self.attention,
- self.residuals):
+ for proj, conv, attention, res_layer in zip(
+ self.projections, self.convolutions, self.attention, self.residuals
+ ):
if res_layer > 0:
residual = residuals[-res_layer]
residual = residual if proj is None else proj(residual)
@@ -473,7 +540,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None,
if attention is not None:
x = self._transpose_if_training(x, incremental_state)
- x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask)
+ x, attn_scores = attention(
+ x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask
+ )
if not self.training and self.need_attn:
attn_scores = attn_scores / num_attn_layers
@@ -502,23 +571,31 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None,
def reorder_incremental_state(self, incremental_state, new_order):
super().reorder_incremental_state(incremental_state, new_order)
- encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out')
+ encoder_out = utils.get_incremental_state(
+ self, incremental_state, "encoder_out"
+ )
if encoder_out is not None:
encoder_out = tuple(eo.index_select(0, new_order) for eo in encoder_out)
- utils.set_incremental_state(self, incremental_state, 'encoder_out', encoder_out)
+ utils.set_incremental_state(
+ self, incremental_state, "encoder_out", encoder_out
+ )
def max_positions(self):
"""Maximum output length supported by the decoder."""
- return self.embed_positions.max_positions if self.embed_positions is not None else float('inf')
+ return (
+ self.embed_positions.max_positions
+ if self.embed_positions is not None
+ else float("inf")
+ )
def upgrade_state_dict(self, state_dict):
- if utils.item(state_dict.get('decoder.version', torch.Tensor([1]))[0]) < 2:
+ if utils.item(state_dict.get("decoder.version", torch.Tensor([1]))[0]) < 2:
# old models use incorrect weight norm dimension
for i, conv in enumerate(self.convolutions):
# reconfigure weight norm
nn.utils.remove_weight_norm(conv)
self.convolutions[i] = nn.utils.weight_norm(conv, dim=0)
- state_dict['decoder.version'] = torch.Tensor([1])
+ state_dict["decoder.version"] = torch.Tensor([1])
return state_dict
def make_generation_fast_(self, need_attn=False, **kwargs):
@@ -535,7 +612,9 @@ def _split_encoder_out(self, encoder_out, incremental_state):
This is cached when doing incremental inference.
"""
- cached_result = utils.get_incremental_state(self, incremental_state, 'encoder_out')
+ cached_result = utils.get_incremental_state(
+ self, incremental_state, "encoder_out"
+ )
if cached_result is not None:
return cached_result
@@ -545,7 +624,7 @@ def _split_encoder_out(self, encoder_out, incremental_state):
result = (encoder_a, encoder_b)
if incremental_state is not None:
- utils.set_incremental_state(self, incremental_state, 'encoder_out', result)
+ utils.set_incremental_state(self, incremental_state, "encoder_out", result)
return result
def _transpose_if_training(self, x, incremental_state):
@@ -567,7 +646,11 @@ def extend_conv_spec(convolutions):
elif len(spec) == 2:
extended.append(spec + (1,))
else:
- raise Exception('invalid number of parameters in convolution spec ' + str(spec) + '. expected 2 or 3')
+ raise Exception(
+ "invalid number of parameters in convolution spec "
+ + str(spec)
+ + ". expected 2 or 3"
+ )
return tuple(extended)
@@ -585,7 +668,7 @@ def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx):
return m
-def Linear(in_features, out_features, dropout=0.):
+def Linear(in_features, out_features, dropout=0.0):
"""Weight-normalized Linear layer (input: N x T x C)"""
m = nn.Linear(in_features, out_features)
nn.init.normal_(m.weight, mean=0, std=math.sqrt((1 - dropout) / in_features))
@@ -593,7 +676,7 @@ def Linear(in_features, out_features, dropout=0.):
return nn.utils.weight_norm(m)
-def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0., **kwargs):
+def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
"""Weight-normalized Conv1d layer optimized for decoding"""
m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
@@ -602,9 +685,10 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0., **kwarg
return nn.utils.weight_norm(m, dim=2)
-def ConvTBC(in_channels, out_channels, kernel_size, dropout=0., **kwargs):
+def ConvTBC(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
"""Weight-normalized Conv1d layer"""
from fairseq.modules import ConvTBC
+
m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs)
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
nn.init.normal_(m.weight, mean=0, std=std)
@@ -612,61 +696,61 @@ def ConvTBC(in_channels, out_channels, kernel_size, dropout=0., **kwargs):
return nn.utils.weight_norm(m, dim=2)
-@register_model_architecture('fconv', 'fconv')
+@register_model_architecture("fconv", "fconv")
def base_architecture(args):
- args.dropout = getattr(args, 'dropout', 0.1)
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
- args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
- args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 20')
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
- args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
- args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 3)] * 20')
- args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
- args.decoder_attention = getattr(args, 'decoder_attention', 'True')
- args.share_input_output_embed = getattr(args, 'share_input_output_embed', False)
-
-
-@register_model_architecture('fconv', 'fconv_iwslt_de_en')
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
+ args.encoder_layers = getattr(args, "encoder_layers", "[(512, 3)] * 20")
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+ args.decoder_layers = getattr(args, "decoder_layers", "[(512, 3)] * 20")
+ args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256)
+ args.decoder_attention = getattr(args, "decoder_attention", "True")
+ args.share_input_output_embed = getattr(args, "share_input_output_embed", False)
+
+
+@register_model_architecture("fconv", "fconv_iwslt_de_en")
def fconv_iwslt_de_en(args):
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
- args.encoder_layers = getattr(args, 'encoder_layers', '[(256, 3)] * 4')
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
- args.decoder_layers = getattr(args, 'decoder_layers', '[(256, 3)] * 3')
- args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
+ args.encoder_layers = getattr(args, "encoder_layers", "[(256, 3)] * 4")
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)
+ args.decoder_layers = getattr(args, "decoder_layers", "[(256, 3)] * 3")
+ args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256)
base_architecture(args)
-@register_model_architecture('fconv', 'fconv_wmt_en_ro')
+@register_model_architecture("fconv", "fconv_wmt_en_ro")
def fconv_wmt_en_ro(args):
- args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
+ args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512)
base_architecture(args)
-@register_model_architecture('fconv', 'fconv_wmt_en_de')
+@register_model_architecture("fconv", "fconv_wmt_en_de")
def fconv_wmt_en_de(args):
- convs = '[(512, 3)] * 9' # first 9 layers have 512 units
- convs += ' + [(1024, 3)] * 4' # next 4 layers have 1024 units
- convs += ' + [(2048, 1)] * 2' # final 2 layers use 1x1 convolutions
-
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
- args.encoder_layers = getattr(args, 'encoder_layers', convs)
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 768)
- args.decoder_layers = getattr(args, 'decoder_layers', convs)
- args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
+ convs = "[(512, 3)] * 9" # first 9 layers have 512 units
+ convs += " + [(1024, 3)] * 4" # next 4 layers have 1024 units
+ convs += " + [(2048, 1)] * 2" # final 2 layers use 1x1 convolutions
+
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
+ args.encoder_layers = getattr(args, "encoder_layers", convs)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768)
+ args.decoder_layers = getattr(args, "decoder_layers", convs)
+ args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512)
base_architecture(args)
-@register_model_architecture('fconv', 'fconv_wmt_en_fr')
+@register_model_architecture("fconv", "fconv_wmt_en_fr")
def fconv_wmt_en_fr(args):
- convs = '[(512, 3)] * 6' # first 6 layers have 512 units
- convs += ' + [(768, 3)] * 4' # next 4 layers have 768 units
- convs += ' + [(1024, 3)] * 3' # next 3 layers have 1024 units
- convs += ' + [(2048, 1)] * 1' # next 1 layer uses 1x1 convolutions
- convs += ' + [(4096, 1)] * 1' # final 1 layer uses 1x1 convolutions
-
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
- args.encoder_layers = getattr(args, 'encoder_layers', convs)
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 768)
- args.decoder_layers = getattr(args, 'decoder_layers', convs)
- args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
+ convs = "[(512, 3)] * 6" # first 6 layers have 512 units
+ convs += " + [(768, 3)] * 4" # next 4 layers have 768 units
+ convs += " + [(1024, 3)] * 3" # next 3 layers have 1024 units
+ convs += " + [(2048, 1)] * 1" # next 1 layer uses 1x1 convolutions
+ convs += " + [(4096, 1)] * 1" # final 1 layer uses 1x1 convolutions
+
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
+ args.encoder_layers = getattr(args, "encoder_layers", convs)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768)
+ args.decoder_layers = getattr(args, "decoder_layers", convs)
+ args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512)
base_architecture(args)
diff --git a/fairseq/models/fconv_lm.py b/fairseq/models/fconv_lm.py
index 4c3c5c66dd..07391eaa29 100644
--- a/fairseq/models/fconv_lm.py
+++ b/fairseq/models/fconv_lm.py
@@ -12,7 +12,7 @@
from fairseq.models.fconv import FConvDecoder
-@register_model('fconv_lm')
+@register_model("fconv_lm")
class FConvLanguageModel(FairseqLanguageModel):
def __init__(self, decoder):
super().__init__(decoder)
@@ -20,21 +20,45 @@ def __init__(self, decoder):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
- parser.add_argument('--dropout', type=float, metavar='D',
- help='dropout probability')
- parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
- help='decoder embedding dimension')
- parser.add_argument('--decoder-layers', type=str, metavar='EXPR',
- help='decoder layers [(dim, kernel_size), ...]')
- parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
- help='decoder output embedding dimension')
- parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
- help='comma separated list of adaptive softmax cutoff points. '
- 'Must be used with adaptive_loss criterion')
- parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
- help='sets adaptive softmax dropout for the tail projections')
- parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
- help='decoder attention [True, ...]')
+ parser.add_argument(
+ "--dropout", type=float, metavar="D", help="dropout probability"
+ )
+ parser.add_argument(
+ "--decoder-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder embedding dimension",
+ )
+ parser.add_argument(
+ "--decoder-layers",
+ type=str,
+ metavar="EXPR",
+ help="decoder layers [(dim, kernel_size), ...]",
+ )
+ parser.add_argument(
+ "--decoder-out-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder output embedding dimension",
+ )
+ parser.add_argument(
+ "--adaptive-softmax-cutoff",
+ metavar="EXPR",
+ help="comma separated list of adaptive softmax cutoff points. "
+ "Must be used with adaptive_loss criterion",
+ )
+ parser.add_argument(
+ "--adaptive-softmax-dropout",
+ type=float,
+ metavar="D",
+ help="sets adaptive softmax dropout for the tail projections",
+ )
+ parser.add_argument(
+ "--decoder-attention",
+ type=str,
+ metavar="EXPR",
+ help="decoder attention [True, ...]",
+ )
@classmethod
def build_model(cls, args, task):
@@ -42,7 +66,9 @@ def build_model(cls, args, task):
# make sure all arguments are present in older models
base_lm_architecture(args)
- if hasattr(args, 'max_target_positions') and not hasattr(args, 'tokens_per_sample'):
+ if hasattr(args, "max_target_positions") and not hasattr(
+ args, "tokens_per_sample"
+ ):
args.tokens_per_sample = args.max_target_positions
decoder = FConvDecoder(
@@ -57,48 +83,53 @@ def build_model(cls, args, task):
positional_embeddings=False,
adaptive_softmax_cutoff=(
utils.eval_str_list(args.adaptive_softmax_cutoff, type=int)
- if args.criterion == 'adaptive_loss' else None
+ if args.criterion == "adaptive_loss"
+ else None
),
adaptive_softmax_dropout=args.adaptive_softmax_dropout,
)
return FConvLanguageModel(decoder)
-@register_model_architecture('fconv_lm', 'fconv_lm')
+@register_model_architecture("fconv_lm", "fconv_lm")
def base_lm_architecture(args):
- args.dropout = getattr(args, 'dropout', 0.1)
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 128)
- args.decoder_layers = getattr(args, 'decoder_layers', '[(1268, 4)] * 13')
- args.decoder_attention = getattr(args, 'decoder_attention', 'False')
- args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
- args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 128)
+ args.decoder_layers = getattr(args, "decoder_layers", "[(1268, 4)] * 13")
+ args.decoder_attention = getattr(args, "decoder_attention", "False")
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
-@register_model_architecture('fconv_lm', 'fconv_lm_dauphin_wikitext103')
+@register_model_architecture("fconv_lm", "fconv_lm_dauphin_wikitext103")
def fconv_lm_dauphin_wikitext103(args):
- layers = '[(850, 6)] * 3'
- layers += ' + [(850, 1)] * 1'
- layers += ' + [(850, 5)] * 4'
- layers += ' + [(850, 1)] * 1'
- layers += ' + [(850, 4)] * 3'
- layers += ' + [(1024, 4)] * 1'
- layers += ' + [(2048, 4)] * 1'
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 280)
- args.decoder_layers = getattr(args, 'decoder_layers', layers)
- args.decoder_attention = getattr(args, 'decoder_attention', 'False')
- args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,20000,200000')
+ layers = "[(850, 6)] * 3"
+ layers += " + [(850, 1)] * 1"
+ layers += " + [(850, 5)] * 4"
+ layers += " + [(850, 1)] * 1"
+ layers += " + [(850, 4)] * 3"
+ layers += " + [(1024, 4)] * 1"
+ layers += " + [(2048, 4)] * 1"
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 280)
+ args.decoder_layers = getattr(args, "decoder_layers", layers)
+ args.decoder_attention = getattr(args, "decoder_attention", "False")
+ args.adaptive_softmax_cutoff = getattr(
+ args, "adaptive_softmax_cutoff", "10000,20000,200000"
+ )
base_lm_architecture(args)
-@register_model_architecture('fconv_lm', 'fconv_lm_dauphin_gbw')
+@register_model_architecture("fconv_lm", "fconv_lm_dauphin_gbw")
def fconv_lm_dauphin_gbw(args):
- layers = '[(512, 5)]'
- layers += ' + [(128, 1, 0), (128, 5, 0), (512, 1, 3)] * 3'
- layers += ' + [(512, 1, 0), (512, 5, 0), (1024, 1, 3)] * 3'
- layers += ' + [(1024, 1, 0), (1024, 5, 0), (2048, 1, 3)] * 6'
- layers += ' + [(1024, 1, 0), (1024, 5, 0), (4096, 1, 3)]'
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 128)
- args.decoder_layers = getattr(args, 'decoder_layers', layers)
- args.decoder_attention = getattr(args, 'decoder_attention', 'False')
- args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000')
+ layers = "[(512, 5)]"
+ layers += " + [(128, 1, 0), (128, 5, 0), (512, 1, 3)] * 3"
+ layers += " + [(512, 1, 0), (512, 5, 0), (1024, 1, 3)] * 3"
+ layers += " + [(1024, 1, 0), (1024, 5, 0), (2048, 1, 3)] * 6"
+ layers += " + [(1024, 1, 0), (1024, 5, 0), (4096, 1, 3)]"
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 128)
+ args.decoder_layers = getattr(args, "decoder_layers", layers)
+ args.decoder_attention = getattr(args, "decoder_attention", "False")
+ args.adaptive_softmax_cutoff = getattr(
+ args, "adaptive_softmax_cutoff", "10000,50000,200000"
+ )
base_lm_architecture(args)
diff --git a/fairseq/models/fconv_self_att.py b/fairseq/models/fconv_self_att.py
index c3582da96f..8357ef7847 100644
--- a/fairseq/models/fconv_self_att.py
+++ b/fairseq/models/fconv_self_att.py
@@ -10,8 +10,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq import checkpoint_utils
+from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.models import (
CompositeEncoder,
FairseqDecoder,
@@ -21,48 +21,49 @@
register_model_architecture,
)
from fairseq.modules import (
- FairseqDropout,
DownsampledMultiHeadAttention,
+ FairseqDropout,
GradMultiply,
LayerNorm,
LearnedPositionalEmbedding,
LinearizedConvolution,
)
-from fairseq.incremental_decoding_utils import with_incremental_state
+
logger = logging.getLogger(__name__)
-@register_model('fconv_self_att')
+@register_model("fconv_self_att")
class FConvModelSelfAtt(FairseqEncoderDecoderModel):
-
@classmethod
def hub_models(cls):
return {
- 'conv.stories.pretrained': {
- 'path': 'https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz',
- 'checkpoint_file': 'pretrained_checkpoint.pt',
- 'tokenizer': 'nltk',
+ "conv.stories.pretrained": {
+ "path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz",
+ "checkpoint_file": "pretrained_checkpoint.pt",
+ "tokenizer": "nltk",
},
- 'conv.stories': {
- 'path': 'https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz',
- 'checkpoint_file': 'fusion_checkpoint.pt',
- 'tokenizer': 'nltk',
- 'pretrained': 'True',
- 'pretrained_checkpoint': './pretrained_checkpoint.pt',
+ "conv.stories": {
+ "path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz",
+ "checkpoint_file": "fusion_checkpoint.pt",
+ "tokenizer": "nltk",
+ "pretrained": "True",
+ "pretrained_checkpoint": "./pretrained_checkpoint.pt",
},
# Test set containing dictionaries
- 'data.stories': 'https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2',
+ "data.stories": "https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2",
}
def __init__(self, encoder, decoder, pretrained_encoder=None):
super().__init__(encoder, decoder)
- self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention)
+ self.encoder.num_attention_layers = sum(
+ layer is not None for layer in decoder.attention
+ )
self.pretrained_encoder = pretrained_encoder
if self.pretrained_encoder is None:
- encoders = {'encoder': encoder}
+ encoders = {"encoder": encoder}
else:
- encoders = {'encoder': encoder, 'pretrained': self.pretrained_encoder}
+ encoders = {"encoder": encoder, "pretrained": self.pretrained_encoder}
# for fusion model, CompositeEncoder contains both pretrained and training encoders
# these are forwarded and then combined in the decoder
self.encoder = CompositeEncoder(encoders)
@@ -113,9 +114,11 @@ def build_model(cls, args, task):
trained_encoder, trained_decoder = None, None
pretrained = eval(args.pretrained)
if pretrained:
- logger.info('loading pretrained model')
+ logger.info("loading pretrained model")
if not os.path.exists(args.pretrained_checkpoint):
- new_pretrained_checkpoint = os.path.join(args.data, args.pretrained_checkpoint)
+ new_pretrained_checkpoint = os.path.join(
+ args.data, args.pretrained_checkpoint
+ )
if os.path.exists(new_pretrained_checkpoint):
args.pretrained_checkpoint = new_pretrained_checkpoint
trained_model = checkpoint_utils.load_model_ensemble(
@@ -169,9 +172,15 @@ def pretrained(self):
class FConvEncoder(FairseqEncoder):
"""Convolutional encoder"""
+
def __init__(
- self, dictionary, embed_dim=512, max_positions=1024,
- convolutions=((512, 3),) * 20, dropout=0.1, attention=False,
+ self,
+ dictionary,
+ embed_dim=512,
+ max_positions=1024,
+ convolutions=((512, 3),) * 20,
+ dropout=0.1,
+ attention=False,
attention_nheads=1,
):
super().__init__(dictionary)
@@ -205,14 +214,18 @@ def expand_bool_array(val):
self.attproj = nn.ModuleList()
for i, (out_channels, kernel_size) in enumerate(convolutions):
self.projections.append(
- Linear(in_channels, out_channels) if in_channels != out_channels else None
+ Linear(in_channels, out_channels)
+ if in_channels != out_channels
+ else None
)
self.convolutions.append(
ConvTBC(in_channels, out_channels * 2, kernel_size, dropout=dropout)
)
self.attention.append(
- SelfAttention(out_channels, embed_dim, attention_nheads) if attention[i] else None
+ SelfAttention(out_channels, embed_dim, attention_nheads)
+ if attention[i]
+ else None
)
in_channels = out_channels
@@ -235,7 +248,9 @@ def forward(self, src_tokens, src_lengths):
x = x.transpose(0, 1)
# temporal convolutions
- for proj, conv, attention in zip(self.projections, self.convolutions, self.attention):
+ for proj, conv, attention in zip(
+ self.projections, self.convolutions, self.attention
+ ):
residual = x if proj is None else proj(x)
if encoder_padding_mask is not None:
@@ -268,23 +283,24 @@ def forward(self, src_tokens, src_lengths):
y = (x + input_embedding.transpose(0, 1)) * math.sqrt(0.5)
return {
- 'encoder_out': (x, y),
- 'encoder_padding_mask': encoder_padding_mask, # B x T
+ "encoder_out": (x, y),
+ "encoder_padding_mask": encoder_padding_mask, # B x T
}
def reorder_encoder_out(self, encoder_out, new_order):
- encoder_out['encoder_out'] = tuple(
- eo.index_select(0, new_order) for eo in encoder_out['encoder_out']
+ encoder_out["encoder_out"] = tuple(
+ eo.index_select(0, new_order) for eo in encoder_out["encoder_out"]
)
- if encoder_out['encoder_padding_mask'] is not None:
- encoder_out['encoder_padding_mask'] = \
- encoder_out['encoder_padding_mask'].index_select(0, new_order)
+ if encoder_out["encoder_padding_mask"] is not None:
+ encoder_out["encoder_padding_mask"] = encoder_out[
+ "encoder_padding_mask"
+ ].index_select(0, new_order)
- if 'pretrained' in encoder_out:
- encoder_out['pretrained']['encoder_out'] = tuple(
+ if "pretrained" in encoder_out:
+ encoder_out["pretrained"]["encoder_out"] = tuple(
eo.index_select(0, new_order)
- for eo in encoder_out['pretrained']['encoder_out']
+ for eo in encoder_out["pretrained"]["encoder_out"]
)
return encoder_out
@@ -297,15 +313,27 @@ def max_positions(self):
@with_incremental_state
class FConvDecoder(FairseqDecoder):
"""Convolutional decoder"""
+
def __init__(
- self, dictionary, embed_dim=512, out_embed_dim=256, max_positions=1024,
- convolutions=((512, 3),) * 8, attention=True, dropout=0.1,
- selfattention=False, attention_nheads=1, selfattention_nheads=1,
- project_input=False, gated_attention=False, downsample=False,
- pretrained=False, trained_decoder=None,
+ self,
+ dictionary,
+ embed_dim=512,
+ out_embed_dim=256,
+ max_positions=1024,
+ convolutions=((512, 3),) * 8,
+ attention=True,
+ dropout=0.1,
+ selfattention=False,
+ attention_nheads=1,
+ selfattention_nheads=1,
+ project_input=False,
+ gated_attention=False,
+ downsample=False,
+ pretrained=False,
+ trained_decoder=None,
):
super().__init__(dictionary)
- self.register_buffer('version', torch.Tensor([2]))
+ self.register_buffer("version", torch.Tensor([2]))
self.pretrained = pretrained
self.pretrained_decoder = trained_decoder
self.dropout_module = FairseqDropout(
@@ -324,8 +352,10 @@ def expand_bool_array(val):
selfattention = expand_bool_array(selfattention)
if not isinstance(attention, list) or len(attention) != len(convolutions):
- raise ValueError('Attention is expected to be a list of booleans of '
- 'length equal to the number of layers.')
+ raise ValueError(
+ "Attention is expected to be a list of booleans of "
+ "length equal to the number of layers."
+ )
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
@@ -345,31 +375,49 @@ def expand_bool_array(val):
self.attproj = nn.ModuleList()
for i, (out_channels, kernel_size) in enumerate(convolutions):
self.projections.append(
- Linear(in_channels, out_channels) if in_channels != out_channels else None
+ Linear(in_channels, out_channels)
+ if in_channels != out_channels
+ else None
)
self.convolutions.append(
LinearizedConv1d(
- in_channels, out_channels * 2, kernel_size,
- padding=(kernel_size - 1), dropout=dropout,
+ in_channels,
+ out_channels * 2,
+ kernel_size,
+ padding=(kernel_size - 1),
+ dropout=dropout,
)
)
self.attention.append(
DownsampledMultiHeadAttention(
- out_channels, embed_dim, attention_nheads,
- project_input=project_input, gated=False, downsample=False,
- ) if attention[i] else None
+ out_channels,
+ embed_dim,
+ attention_nheads,
+ project_input=project_input,
+ gated=False,
+ downsample=False,
+ )
+ if attention[i]
+ else None
)
self.attproj.append(
- Linear(out_channels, embed_dim, dropout=dropout) if attention[i] else None
+ Linear(out_channels, embed_dim, dropout=dropout)
+ if attention[i]
+ else None
)
self.selfattention.append(
SelfAttention(
- out_channels, embed_dim, selfattention_nheads,
- project_input=project_input, gated=gated_attention,
+ out_channels,
+ embed_dim,
+ selfattention_nheads,
+ project_input=project_input,
+ gated=gated_attention,
downsample=downsample,
- ) if selfattention[i] else None
+ )
+ if selfattention[i]
+ else None
)
in_channels = out_channels
@@ -379,18 +427,22 @@ def expand_bool_array(val):
# model fusion
if self.pretrained:
# independent gates are learned from the concatenated input
- self.gate1 = nn.Sequential(Linear(out_embed_dim*2, out_embed_dim), nn.Sigmoid())
- self.gate2 = nn.Sequential(Linear(out_embed_dim*2, out_embed_dim), nn.Sigmoid())
+ self.gate1 = nn.Sequential(
+ Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid()
+ )
+ self.gate2 = nn.Sequential(
+ Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid()
+ )
# pretrained and trained models are joined
self.joining = nn.Sequential(
- Linear(out_embed_dim*2, out_embed_dim*2),
- LayerNorm(out_embed_dim*2),
+ Linear(out_embed_dim * 2, out_embed_dim * 2),
+ LayerNorm(out_embed_dim * 2),
nn.GLU(),
- Linear(out_embed_dim, out_embed_dim*2),
- LayerNorm(out_embed_dim*2),
+ Linear(out_embed_dim, out_embed_dim * 2),
+ LayerNorm(out_embed_dim * 2),
nn.GLU(),
Linear(out_embed_dim, out_embed_dim),
- LayerNorm(out_embed_dim)
+ LayerNorm(out_embed_dim),
)
# pretrained model contains an output layer that is nhid -> vocab size
# but the models are combined in their hidden state
@@ -400,13 +452,14 @@ def expand_bool_array(val):
def save_output():
def hook(a, b, output):
self.pretrained_outputs["out"] = output
+
return hook
self.pretrained_decoder.fc2.register_forward_hook(save_output())
def forward(self, prev_output_tokens, encoder_out):
- trained_encoder_out = encoder_out['pretrained'] if self.pretrained else None
- encoder_out = encoder_out['encoder']['encoder_out']
+ trained_encoder_out = encoder_out["pretrained"] if self.pretrained else None
+ encoder_out = encoder_out["encoder"]["encoder_out"]
encoder_a, encoder_b = self._split_encoder_out(encoder_out)
@@ -427,7 +480,11 @@ def forward(self, prev_output_tokens, encoder_out):
# temporal convolutions
avg_attn_scores = None
for proj, conv, attention, selfattention, attproj in zip(
- self.projections, self.convolutions, self.attention, self.selfattention, self.attproj
+ self.projections,
+ self.convolutions,
+ self.attention,
+ self.selfattention,
+ self.attproj,
):
residual = x if proj is None else proj(x)
@@ -438,7 +495,9 @@ def forward(self, prev_output_tokens, encoder_out):
# attention
if attention is not None:
r = x
- x, attn_scores = attention(attproj(x) + target_embedding, encoder_a, encoder_b)
+ x, attn_scores = attention(
+ attproj(x) + target_embedding, encoder_a, encoder_b
+ )
x = x + r
if not self.training and self.need_attn:
if avg_attn_scores is None:
@@ -462,7 +521,9 @@ def forward(self, prev_output_tokens, encoder_out):
# fusion gating
if self.pretrained:
- trained_x, _ = self.pretrained_decoder.forward(prev_output_tokens, trained_encoder_out)
+ trained_x, _ = self.pretrained_decoder.forward(
+ prev_output_tokens, trained_encoder_out
+ )
y = torch.cat([x, self.pretrained_outputs["out"]], dim=-1)
gate1 = self.gate1(y)
gate2 = self.gate2(y)
@@ -493,12 +554,25 @@ def _split_encoder_out(self, encoder_out):
class SelfAttention(nn.Module):
-
- def __init__(self, out_channels, embed_dim, num_heads, project_input=False, gated=False, downsample=False):
+ def __init__(
+ self,
+ out_channels,
+ embed_dim,
+ num_heads,
+ project_input=False,
+ gated=False,
+ downsample=False,
+ ):
super().__init__()
self.attention = DownsampledMultiHeadAttention(
- out_channels, embed_dim, num_heads, dropout=0, bias=True,
- project_input=project_input, gated=gated, downsample=downsample,
+ out_channels,
+ embed_dim,
+ num_heads,
+ dropout=0,
+ bias=True,
+ project_input=project_input,
+ gated=gated,
+ downsample=downsample,
)
self.in_proj_q = Linear(out_channels, embed_dim)
self.in_proj_k = Linear(out_channels, embed_dim)
@@ -510,7 +584,9 @@ def forward(self, x):
query = self.in_proj_q(x)
key = self.in_proj_k(x)
value = self.in_proj_v(x)
- x, _ = self.attention(query, key, value, mask_future_timesteps=True, use_scalar_bias=True)
+ x, _ = self.attention(
+ query, key, value, mask_future_timesteps=True, use_scalar_bias=True
+ )
return self.ln(x + residual)
@@ -526,7 +602,7 @@ def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx):
return m
-def Linear(in_features, out_features, dropout=0.):
+def Linear(in_features, out_features, dropout=0.0):
"""Weight-normalized Linear layer (input: N x T x C)"""
m = nn.Linear(in_features, out_features)
m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
@@ -534,7 +610,7 @@ def Linear(in_features, out_features, dropout=0.):
return m
-def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0., **kwargs):
+def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
"""Weight-normalized Conv1d layer optimized for decoding"""
m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
@@ -543,9 +619,10 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0., **kwarg
return m
-def ConvTBC(in_channels, out_channels, kernel_size, dropout=0., **kwargs):
+def ConvTBC(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
"""Weight-normalized Conv1d layer"""
from fairseq.modules import ConvTBC
+
m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs)
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
m.weight.data.normal_(mean=0, std=std)
@@ -553,37 +630,45 @@ def ConvTBC(in_channels, out_channels, kernel_size, dropout=0., **kwargs):
return m
-@register_model_architecture('fconv_self_att', 'fconv_self_att')
+@register_model_architecture("fconv_self_att", "fconv_self_att")
def base_architecture(args):
- args.dropout = getattr(args, 'dropout', 0.1)
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
- args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 3')
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
- args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 3)] * 8')
- args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
- args.decoder_attention = getattr(args, 'decoder_attention', 'True')
- args.self_attention = getattr(args, 'self_attention', 'False')
- args.encoder_attention = getattr(args, 'encoder_attention', 'False')
- args.multihead_attention_nheads = getattr(args, 'multihead_attention_nheads', 1)
- args.multihead_self_attention_nheads = getattr(args, 'multihead_self_attention_nheads', 1)
- args.encoder_attention_nheads = getattr(args, 'encoder_attention_nheads', 1)
- args.project_input = getattr(args, 'project_input', 'False')
- args.gated_attention = getattr(args, 'gated_attention', 'False')
- args.downsample = getattr(args, 'downsample', 'False')
- args.pretrained_checkpoint = getattr(args, 'pretrained_checkpoint', '')
- args.pretrained = getattr(args, 'pretrained', 'False')
-
-
-@register_model_architecture('fconv_self_att', 'fconv_self_att_wp')
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_layers = getattr(args, "encoder_layers", "[(512, 3)] * 3")
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
+ args.decoder_layers = getattr(args, "decoder_layers", "[(512, 3)] * 8")
+ args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256)
+ args.decoder_attention = getattr(args, "decoder_attention", "True")
+ args.self_attention = getattr(args, "self_attention", "False")
+ args.encoder_attention = getattr(args, "encoder_attention", "False")
+ args.multihead_attention_nheads = getattr(args, "multihead_attention_nheads", 1)
+ args.multihead_self_attention_nheads = getattr(
+ args, "multihead_self_attention_nheads", 1
+ )
+ args.encoder_attention_nheads = getattr(args, "encoder_attention_nheads", 1)
+ args.project_input = getattr(args, "project_input", "False")
+ args.gated_attention = getattr(args, "gated_attention", "False")
+ args.downsample = getattr(args, "downsample", "False")
+ args.pretrained_checkpoint = getattr(args, "pretrained_checkpoint", "")
+ args.pretrained = getattr(args, "pretrained", "False")
+
+
+@register_model_architecture("fconv_self_att", "fconv_self_att_wp")
def fconv_self_att_wp(args):
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
- args.encoder_layers = getattr(args, 'encoder_layers', '[(128, 3)] * 2 + [(512,3)] * 1')
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
- args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 4)] * 4 + [(768, 4)] * 2 + [(1024, 4)] * 1')
- args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
- args.self_attention = getattr(args, 'self_attention', 'True')
- args.multihead_self_attention_nheads = getattr(args, 'multihead_self_attention_nheads', 4)
- args.project_input = getattr(args, 'project_input', 'True')
- args.gated_attention = getattr(args, 'gated_attention', 'True')
- args.downsample = getattr(args, 'downsample', 'True')
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
+ args.encoder_layers = getattr(
+ args, "encoder_layers", "[(128, 3)] * 2 + [(512,3)] * 1"
+ )
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)
+ args.decoder_layers = getattr(
+ args, "decoder_layers", "[(512, 4)] * 4 + [(768, 4)] * 2 + [(1024, 4)] * 1"
+ )
+ args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256)
+ args.self_attention = getattr(args, "self_attention", "True")
+ args.multihead_self_attention_nheads = getattr(
+ args, "multihead_self_attention_nheads", 4
+ )
+ args.project_input = getattr(args, "project_input", "True")
+ args.gated_attention = getattr(args, "gated_attention", "True")
+ args.downsample = getattr(args, "downsample", "True")
base_architecture(args)
diff --git a/fairseq/models/huggingface/__init__.py b/fairseq/models/huggingface/__init__.py
index 633315f54d..f7911c2c8e 100644
--- a/fairseq/models/huggingface/__init__.py
+++ b/fairseq/models/huggingface/__init__.py
@@ -12,9 +12,9 @@
for file in os.listdir(models_dir):
path = os.path.join(models_dir, file)
if (
- not file.startswith('_')
- and not file.startswith('.')
- and (file.endswith('.py') or os.path.isdir(path))
+ not file.startswith("_")
+ and not file.startswith(".")
+ and (file.endswith(".py") or os.path.isdir(path))
):
- model_name = file[:file.find('.py')] if file.endswith('.py') else file
- module = importlib.import_module('fairseq.models.huggingface.' + model_name)
+ model_name = file[: file.find(".py")] if file.endswith(".py") else file
+ module = importlib.import_module("fairseq.models.huggingface." + model_name)
diff --git a/fairseq/models/huggingface/hf_gpt2.py b/fairseq/models/huggingface/hf_gpt2.py
index e81954ff65..a823453794 100644
--- a/fairseq/models/huggingface/hf_gpt2.py
+++ b/fairseq/models/huggingface/hf_gpt2.py
@@ -16,13 +16,15 @@
register_model_architecture,
)
+
try:
# Prepend the transformers submodule to the path, so that
# it's prioritized over other installations. This allows
# making local changes in the submodule.
- hf_path = os.path.join(os.path.dirname(__file__), 'transformers', 'src')
+ hf_path = os.path.join(os.path.dirname(__file__), "transformers", "src")
sys.path.insert(0, hf_path)
from transformers import GPT2Config, GPT2LMHeadModel
+
sys.path.remove(hf_path)
has_hf = True
except ImportError:
@@ -35,18 +37,17 @@
DEFAULT_MAX_TARGET_POSITIONS = 1024
-@register_model('hf_gpt2')
+@register_model("hf_gpt2")
class HuggingFaceGPT2LanguageModel(FairseqLanguageModel):
-
def __init__(self, decoder):
super().__init__(decoder)
if not has_hf:
raise ImportError(
- '\n\nPlease install huggingface/transformers with:'
- '\n\n pip install transformers'
- '\n\nOr to make local edits, install the submodule:'
- '\n\n git submodule update --init '
- 'fairseq/models/huggingface/transformers'
+ "\n\nPlease install huggingface/transformers with:"
+ "\n\n pip install transformers"
+ "\n\nOr to make local edits, install the submodule:"
+ "\n\n git submodule update --init "
+ "fairseq/models/huggingface/transformers"
)
@staticmethod
@@ -74,17 +75,16 @@ def build_model(cls, args, task):
class HuggingFaceGPT2Decoder(FairseqIncrementalDecoder):
-
def __init__(self, args, task):
super().__init__(task.target_dictionary)
if not has_hf:
raise ImportError(
- '\n\nPlease install huggingface/transformers with:'
- '\n\n pip install transformers'
- '\n\nOr to make local edits, install the submodule:'
- '\n\n git submodule update --init '
- 'fairseq/models/huggingface/transformers'
+ "\n\nPlease install huggingface/transformers with:"
+ "\n\n pip install transformers"
+ "\n\nOr to make local edits, install the submodule:"
+ "\n\n git submodule update --init "
+ "fairseq/models/huggingface/transformers"
)
config = GPT2Config(
@@ -115,7 +115,7 @@ def forward(
):
features = self.extract_features(prev_output_tokens, incremental_state)
lm_logits = self.model.lm_head(features)
- return (lm_logits, )
+ return (lm_logits,)
def extract_features(
self,
@@ -154,38 +154,38 @@ def max_positions(self):
return self.model.config.n_positions - 1
-@register_model_architecture('hf_gpt2', 'hf_gpt2')
+@register_model_architecture("hf_gpt2", "hf_gpt2")
def default_architecture(args):
- if getattr(args, 'max_target_positions', None) is None:
+ if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = getattr(
- args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS
+ args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
)
- args.embed_dim = getattr(args, 'embed_dim', 768)
- args.num_attention_heads = getattr(args, 'num_attention_heads', 12)
- args.num_layers = getattr(args, 'num_layers', 12)
- args.dropout = getattr(args, 'dropout', 0.1)
- args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
+ args.embed_dim = getattr(args, "embed_dim", 768)
+ args.num_attention_heads = getattr(args, "num_attention_heads", 12)
+ args.num_layers = getattr(args, "num_layers", 12)
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
-@register_model_architecture('hf_gpt2', 'hf_gpt2_medium')
+@register_model_architecture("hf_gpt2", "hf_gpt2_medium")
def hf_gpt2_medium(args):
- args.embed_dim = getattr(args, 'embed_dim', 1024)
- args.num_attention_heads = getattr(args, 'num_attention_heads', 16)
- args.num_layers = getattr(args, 'num_layers', 24)
+ args.embed_dim = getattr(args, "embed_dim", 1024)
+ args.num_attention_heads = getattr(args, "num_attention_heads", 16)
+ args.num_layers = getattr(args, "num_layers", 24)
default_architecture(args)
-@register_model_architecture('hf_gpt2', 'hf_gpt2_large')
+@register_model_architecture("hf_gpt2", "hf_gpt2_large")
def hf_gpt2_large(args):
- args.embed_dim = getattr(args, 'embed_dim', 1280)
- args.num_attention_heads = getattr(args, 'num_attention_heads', 20)
- args.num_layers = getattr(args, 'num_layers', 36)
+ args.embed_dim = getattr(args, "embed_dim", 1280)
+ args.num_attention_heads = getattr(args, "num_attention_heads", 20)
+ args.num_layers = getattr(args, "num_layers", 36)
default_architecture(args)
-@register_model_architecture('hf_gpt2', 'hf_gpt2_xl')
+@register_model_architecture("hf_gpt2", "hf_gpt2_xl")
def hf_gpt2_xl(args):
- args.embed_dim = getattr(args, 'embed_dim', 1600)
- args.num_attention_heads = getattr(args, 'num_attention_heads', 25)
- args.num_layers = getattr(args, 'num_layers', 48)
+ args.embed_dim = getattr(args, "embed_dim", 1600)
+ args.num_attention_heads = getattr(args, "num_attention_heads", 25)
+ args.num_layers = getattr(args, "num_layers", 48)
default_architecture(args)
diff --git a/fairseq/models/lightconv.py b/fairseq/models/lightconv.py
index 09d4d0be2e..b614da3665 100644
--- a/fairseq/models/lightconv.py
+++ b/fairseq/models/lightconv.py
@@ -8,12 +8,11 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
- FairseqIncrementalDecoder,
FairseqEncoderDecoderModel,
+ FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
@@ -22,13 +21,13 @@
DynamicConv,
FairseqDropout,
LayerNorm,
- PositionalEmbedding,
LightweightConv,
MultiheadAttention,
+ PositionalEmbedding,
)
-@register_model('lightconv')
+@register_model("lightconv")
class LightConvModel(FairseqEncoderDecoderModel):
"""
LightConv and DynamicConv model from `"Pay Less Attention with Lightweight and Dynamic Convolutions" (Wu, et al, 2019)
@@ -81,75 +80,175 @@ def __init__(self, encoder, decoder):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
- parser.add_argument('--dropout', type=float, metavar='D',
- help='dropout probability')
- parser.add_argument('--attention-dropout', type=float, metavar='D',
- help='dropout probability for attention weights')
- parser.add_argument('--relu-dropout', type=float, metavar='D',
- help='dropout probability after ReLU in FFN')
- parser.add_argument('--input-dropout', type=float, metavar='D',
- help='dropout probability of the inputs')
- parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
- help='path to pre-trained encoder embedding')
- parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
- help='encoder embedding dimension')
- parser.add_argument('--encoder-conv-dim', type=int, metavar='N',
- help='encoder embedding dimension')
- parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
- help='encoder embedding dimension for FFN')
- parser.add_argument('--encoder-layers', type=int, metavar='N',
- help='num encoder layers')
- parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
- help='num encoder attention heads or LightConv/DynamicConv heads')
- parser.add_argument('--encoder-normalize-before', action='store_true',
- help='apply layernorm before each encoder block')
- parser.add_argument('--encoder-learned-pos', action='store_true',
- help='use learned positional embeddings in the encoder')
- parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
- help='path to pre-trained decoder embedding')
- parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
- help='decoder embedding dimension')
- parser.add_argument('--decoder-conv-dim', type=int, metavar='N',
- help='decoder embedding dimension')
- parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
- help='decoder embedding dimension for FFN')
- parser.add_argument('--decoder-layers', type=int, metavar='N',
- help='num decoder layers')
- parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
- help='num decoder attention heads or LightConv/DynamicConv heads')
- parser.add_argument('--decoder-learned-pos', action='store_true',
- help='use learned positional embeddings in the decoder')
- parser.add_argument('--decoder-normalize-before', action='store_true',
- help='apply layernorm before each decoder block')
- parser.add_argument('--share-decoder-input-output-embed', action='store_true',
- help='share decoder input and output embeddings')
- parser.add_argument('--share-all-embeddings', action='store_true',
- help='share encoder, decoder and output embeddings'
- ' (requires shared dictionary and embed dim)')
- parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
- help='comma separated list of adaptive softmax cutoff points. '
- 'Must be used with adaptive_loss criterion'),
- parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
- help='sets adaptive softmax dropout for the tail projections')
+ parser.add_argument(
+ "--dropout", type=float, metavar="D", help="dropout probability"
+ )
+ parser.add_argument(
+ "--attention-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability for attention weights",
+ )
+ parser.add_argument(
+ "--relu-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability after ReLU in FFN",
+ )
+ parser.add_argument(
+ "--input-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability of the inputs",
+ )
+ parser.add_argument(
+ "--encoder-embed-path",
+ type=str,
+ metavar="STR",
+ help="path to pre-trained encoder embedding",
+ )
+ parser.add_argument(
+ "--encoder-embed-dim",
+ type=int,
+ metavar="N",
+ help="encoder embedding dimension",
+ )
+ parser.add_argument(
+ "--encoder-conv-dim",
+ type=int,
+ metavar="N",
+ help="encoder embedding dimension",
+ )
+ parser.add_argument(
+ "--encoder-ffn-embed-dim",
+ type=int,
+ metavar="N",
+ help="encoder embedding dimension for FFN",
+ )
+ parser.add_argument(
+ "--encoder-layers", type=int, metavar="N", help="num encoder layers"
+ )
+ parser.add_argument(
+ "--encoder-attention-heads",
+ type=int,
+ metavar="N",
+ help="num encoder attention heads or LightConv/DynamicConv heads",
+ )
+ parser.add_argument(
+ "--encoder-normalize-before",
+ action="store_true",
+ help="apply layernorm before each encoder block",
+ )
+ parser.add_argument(
+ "--encoder-learned-pos",
+ action="store_true",
+ help="use learned positional embeddings in the encoder",
+ )
+ parser.add_argument(
+ "--decoder-embed-path",
+ type=str,
+ metavar="STR",
+ help="path to pre-trained decoder embedding",
+ )
+ parser.add_argument(
+ "--decoder-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder embedding dimension",
+ )
+ parser.add_argument(
+ "--decoder-conv-dim",
+ type=int,
+ metavar="N",
+ help="decoder embedding dimension",
+ )
+ parser.add_argument(
+ "--decoder-ffn-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder embedding dimension for FFN",
+ )
+ parser.add_argument(
+ "--decoder-layers", type=int, metavar="N", help="num decoder layers"
+ )
+ parser.add_argument(
+ "--decoder-attention-heads",
+ type=int,
+ metavar="N",
+ help="num decoder attention heads or LightConv/DynamicConv heads",
+ )
+ parser.add_argument(
+ "--decoder-learned-pos",
+ action="store_true",
+ help="use learned positional embeddings in the decoder",
+ )
+ parser.add_argument(
+ "--decoder-normalize-before",
+ action="store_true",
+ help="apply layernorm before each decoder block",
+ )
+ parser.add_argument(
+ "--share-decoder-input-output-embed",
+ action="store_true",
+ help="share decoder input and output embeddings",
+ )
+ parser.add_argument(
+ "--share-all-embeddings",
+ action="store_true",
+ help="share encoder, decoder and output embeddings"
+ " (requires shared dictionary and embed dim)",
+ )
+ parser.add_argument(
+ "--adaptive-softmax-cutoff",
+ metavar="EXPR",
+ help="comma separated list of adaptive softmax cutoff points. "
+ "Must be used with adaptive_loss criterion",
+ ),
+ parser.add_argument(
+ "--adaptive-softmax-dropout",
+ type=float,
+ metavar="D",
+ help="sets adaptive softmax dropout for the tail projections",
+ )
"""LightConv and DynamicConv arguments"""
- parser.add_argument('--encoder-kernel-size-list', type=lambda x: utils.eval_str_list(x, int),
- help='list of kernel size (default: "[3,7,15,31,31,31,31]")')
- parser.add_argument('--decoder-kernel-size-list', type=lambda x: utils.eval_str_list(x, int),
- help='list of kernel size (default: "[3,7,15,31,31,31]")')
- parser.add_argument('--encoder-glu', type=utils.eval_bool,
- help='glu after in proj')
- parser.add_argument('--decoder-glu', type=utils.eval_bool,
- help='glu after in proj')
- parser.add_argument('--encoder-conv-type', default='dynamic', type=str,
- choices=['dynamic', 'lightweight'],
- help='type of convolution')
- parser.add_argument('--decoder-conv-type', default='dynamic', type=str,
- choices=['dynamic', 'lightweight'],
- help='type of convolution')
- parser.add_argument('--weight-softmax', default=True, type=utils.eval_bool)
- parser.add_argument('--weight-dropout', type=float, metavar='D',
- help='dropout probability for conv weights')
+ parser.add_argument(
+ "--encoder-kernel-size-list",
+ type=lambda x: utils.eval_str_list(x, int),
+ help='list of kernel size (default: "[3,7,15,31,31,31,31]")',
+ )
+ parser.add_argument(
+ "--decoder-kernel-size-list",
+ type=lambda x: utils.eval_str_list(x, int),
+ help='list of kernel size (default: "[3,7,15,31,31,31]")',
+ )
+ parser.add_argument(
+ "--encoder-glu", type=utils.eval_bool, help="glu after in proj"
+ )
+ parser.add_argument(
+ "--decoder-glu", type=utils.eval_bool, help="glu after in proj"
+ )
+ parser.add_argument(
+ "--encoder-conv-type",
+ default="dynamic",
+ type=str,
+ choices=["dynamic", "lightweight"],
+ help="type of convolution",
+ )
+ parser.add_argument(
+ "--decoder-conv-type",
+ default="dynamic",
+ type=str,
+ choices=["dynamic", "lightweight"],
+ help="type of convolution",
+ )
+ parser.add_argument("--weight-softmax", default=True, type=utils.eval_bool)
+ parser.add_argument(
+ "--weight-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability for conv weights",
+ )
@classmethod
def build_model(cls, args, task):
@@ -158,9 +257,9 @@ def build_model(cls, args, task):
# make sure all arguments are present in older models
base_architecture(args)
- if not hasattr(args, 'max_source_positions'):
+ if not hasattr(args, "max_source_positions"):
args.max_source_positions = 1024
- if not hasattr(args, 'max_target_positions'):
+ if not hasattr(args, "max_target_positions"):
args.max_target_positions = 1024
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
@@ -177,13 +276,19 @@ def build_embedding(dictionary, embed_dim, path=None):
if args.share_all_embeddings:
if src_dict != tgt_dict:
- raise RuntimeError('--share-all-embeddings requires a joined dictionary')
+ raise RuntimeError(
+ "--share-all-embeddings requires a joined dictionary"
+ )
if args.encoder_embed_dim != args.decoder_embed_dim:
raise RuntimeError(
- '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
+ "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
+ )
if args.decoder_embed_path and (
- args.decoder_embed_path != args.encoder_embed_path):
- raise RuntimeError('--share-all-embeddings not compatible with --decoder-embed-path')
+ args.decoder_embed_path != args.encoder_embed_path
+ ):
+ raise RuntimeError(
+ "--share-all-embeddings not compatible with --decoder-embed-path"
+ )
encoder_embed_tokens = build_embedding(
src_dict, args.encoder_embed_dim, args.encoder_embed_path
)
@@ -215,7 +320,9 @@ class LightConvEncoder(FairseqEncoder):
def __init__(self, args, dictionary, embed_tokens):
super().__init__(dictionary)
- self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__)
+ self.dropout_module = FairseqDropout(
+ args.dropout, module_name=self.__class__.__name__
+ )
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
@@ -223,17 +330,27 @@ def __init__(self, args, dictionary, embed_tokens):
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
- self.embed_positions = PositionalEmbedding(
- args.max_source_positions, embed_dim, self.padding_idx,
- learned=args.encoder_learned_pos,
- ) if not args.no_token_positional_embeddings else None
+ self.embed_positions = (
+ PositionalEmbedding(
+ args.max_source_positions,
+ embed_dim,
+ self.padding_idx,
+ learned=args.encoder_learned_pos,
+ )
+ if not args.no_token_positional_embeddings
+ else None
+ )
self.layers = nn.ModuleList([])
- self.layers.extend([
- LightConvEncoderLayer(args, kernel_size=args.encoder_kernel_size_list[i])
- for i in range(args.encoder_layers)
- ])
- self.register_buffer('version', torch.Tensor([2]))
+ self.layers.extend(
+ [
+ LightConvEncoderLayer(
+ args, kernel_size=args.encoder_kernel_size_list[i]
+ )
+ for i in range(args.encoder_layers)
+ ]
+ )
+ self.register_buffer("version", torch.Tensor([2]))
self.normalize = args.encoder_normalize_before
if self.normalize:
self.layer_norm = LayerNorm(embed_dim)
@@ -273,8 +390,8 @@ def forward(self, src_tokens, **unused):
x = self.layer_norm(x)
return {
- 'encoder_out': x, # T x B x C
- 'encoder_padding_mask': encoder_padding_mask, # B x T
+ "encoder_out": x, # T x B x C
+ "encoder_padding_mask": encoder_padding_mask, # B x T
}
def reorder_encoder_out(self, encoder_out, new_order):
@@ -288,12 +405,14 @@ def reorder_encoder_out(self, encoder_out, new_order):
Returns:
*encoder_out* rearranged according to *new_order*
"""
- if encoder_out['encoder_out'] is not None:
- encoder_out['encoder_out'] = \
- encoder_out['encoder_out'].index_select(1, new_order)
- if encoder_out['encoder_padding_mask'] is not None:
- encoder_out['encoder_padding_mask'] = \
- encoder_out['encoder_padding_mask'].index_select(0, new_order)
+ if encoder_out["encoder_out"] is not None:
+ encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
+ 1, new_order
+ )
+ if encoder_out["encoder_padding_mask"] is not None:
+ encoder_out["encoder_padding_mask"] = encoder_out[
+ "encoder_padding_mask"
+ ].index_select(0, new_order)
return encoder_out
def max_positions(self):
@@ -316,9 +435,13 @@ class LightConvDecoder(FairseqIncrementalDecoder):
Default: ``False``
"""
- def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True):
+ def __init__(
+ self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True
+ ):
super().__init__(dictionary)
- self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__)
+ self.dropout_module = FairseqDropout(
+ args.dropout, module_name=self.__class__.__name__
+ )
self.share_input_output_embed = args.share_decoder_input_output_embed
input_embed_dim = embed_tokens.embedding_dim
@@ -331,23 +454,40 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim
- self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None
-
- self.embed_positions = PositionalEmbedding(
- args.max_target_positions, embed_dim, padding_idx,
- learned=args.decoder_learned_pos,
- ) if not args.no_token_positional_embeddings else None
+ self.project_in_dim = (
+ Linear(input_embed_dim, embed_dim, bias=False)
+ if embed_dim != input_embed_dim
+ else None
+ )
+
+ self.embed_positions = (
+ PositionalEmbedding(
+ args.max_target_positions,
+ embed_dim,
+ padding_idx,
+ learned=args.decoder_learned_pos,
+ )
+ if not args.no_token_positional_embeddings
+ else None
+ )
self.layers = nn.ModuleList([])
- self.layers.extend([
- LightConvDecoderLayer(args, no_encoder_attn, kernel_size=args.decoder_kernel_size_list[i])
- for i in range(args.decoder_layers)
- ])
+ self.layers.extend(
+ [
+ LightConvDecoderLayer(
+ args, no_encoder_attn, kernel_size=args.decoder_kernel_size_list[i]
+ )
+ for i in range(args.decoder_layers)
+ ]
+ )
self.adaptive_softmax = None
- self.project_out_dim = Linear(embed_dim, output_embed_dim, bias=False) \
- if embed_dim != output_embed_dim and not args.tie_adaptive_weights else None
+ self.project_out_dim = (
+ Linear(embed_dim, output_embed_dim, bias=False)
+ if embed_dim != output_embed_dim and not args.tie_adaptive_weights
+ else None
+ )
if args.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
@@ -360,14 +500,18 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_
tie_proj=args.tie_adaptive_proj,
)
elif not self.share_input_output_embed:
- self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), output_embed_dim))
+ self.embed_out = nn.Parameter(
+ torch.Tensor(len(dictionary), output_embed_dim)
+ )
nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim ** -0.5)
- self.register_buffer('version', torch.Tensor([2]))
+ self.register_buffer("version", torch.Tensor([2]))
self.normalize = args.decoder_normalize_before and final_norm
if self.normalize:
self.layer_norm = LayerNorm(embed_dim)
- def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
+ def forward(
+ self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
+ ):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
@@ -385,10 +529,14 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None,
tgt_len, src_len)`
"""
# embed positions
- positions = self.embed_positions(
- prev_output_tokens,
- incremental_state=incremental_state,
- ) if self.embed_positions is not None else None
+ positions = (
+ self.embed_positions(
+ prev_output_tokens,
+ incremental_state=incremental_state,
+ )
+ if self.embed_positions is not None
+ else None
+ )
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
@@ -415,8 +563,10 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None,
for layer in self.layers:
x, attn = layer(
x,
- encoder_out['encoder_out'] if encoder_out is not None else None,
- encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
+ encoder_out["encoder_out"] if encoder_out is not None else None,
+ encoder_out["encoder_padding_mask"]
+ if encoder_out is not None
+ else None,
incremental_state,
)
inner_states.append(x)
@@ -437,7 +587,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None,
else:
x = F.linear(x, self.embed_out)
- return x, {'attn': attn, 'inner_states': inner_states}
+ return x, {"attn": attn, "inner_states": inner_states}
def max_positions(self):
"""Maximum output length supported by the decoder."""
@@ -447,10 +597,18 @@ def max_positions(self):
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
- if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
- self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
+ if (
+ not hasattr(self, "_future_mask")
+ or self._future_mask is None
+ or self._future_mask.device != tensor.device
+ ):
+ self._future_mask = torch.triu(
+ utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
+ )
if self._future_mask.size(0) < dim:
- self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
+ self._future_mask = torch.triu(
+ utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
+ )
return self._future_mask[:dim, :dim]
@@ -466,31 +624,49 @@ def __init__(self, args, kernel_size=0):
super().__init__()
self.embed_dim = args.encoder_embed_dim
self.conv_dim = args.encoder_conv_dim
- padding_l = kernel_size // 2 if kernel_size % 2 == 1 else ((kernel_size - 1) // 2, kernel_size // 2)
+ padding_l = (
+ kernel_size // 2
+ if kernel_size % 2 == 1
+ else ((kernel_size - 1) // 2, kernel_size // 2)
+ )
if args.encoder_glu:
- self.linear1 = Linear(self.embed_dim, 2*self.conv_dim)
+ self.linear1 = Linear(self.embed_dim, 2 * self.conv_dim)
self.act = nn.GLU()
else:
self.linear1 = Linear(self.embed_dim, self.conv_dim)
self.act = None
- if args.encoder_conv_type == 'lightweight':
- self.conv = LightweightConv(self.conv_dim, kernel_size, padding_l=padding_l,
- weight_softmax=args.weight_softmax,
- num_heads=args.encoder_attention_heads,
- weight_dropout=args.weight_dropout)
- elif args.encoder_conv_type == 'dynamic':
- self.conv = DynamicConv(self.conv_dim, kernel_size, padding_l=padding_l,
- weight_softmax=args.weight_softmax,
- num_heads=args.encoder_attention_heads,
- weight_dropout=args.weight_dropout)
+ if args.encoder_conv_type == "lightweight":
+ self.conv = LightweightConv(
+ self.conv_dim,
+ kernel_size,
+ padding_l=padding_l,
+ weight_softmax=args.weight_softmax,
+ num_heads=args.encoder_attention_heads,
+ weight_dropout=args.weight_dropout,
+ )
+ elif args.encoder_conv_type == "dynamic":
+ self.conv = DynamicConv(
+ self.conv_dim,
+ kernel_size,
+ padding_l=padding_l,
+ weight_softmax=args.weight_softmax,
+ num_heads=args.encoder_attention_heads,
+ weight_dropout=args.weight_dropout,
+ )
else:
raise NotImplementedError
self.linear2 = Linear(self.conv_dim, self.embed_dim)
- self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__)
- self.relu_dropout_module = FairseqDropout(args.relu_dropout, module_name=self.__class__.__name__)
- self.input_dropout_module = FairseqDropout(args.input_dropout, module_name=self.__class__.__name__)
+ self.dropout_module = FairseqDropout(
+ args.dropout, module_name=self.__class__.__name__
+ )
+ self.relu_dropout_module = FairseqDropout(
+ args.relu_dropout, module_name=self.__class__.__name__
+ )
+ self.input_dropout_module = FairseqDropout(
+ args.input_dropout, module_name=self.__class__.__name__
+ )
self.normalize_before = args.encoder_normalize_before
self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
@@ -538,8 +714,14 @@ def maybe_layer_norm(self, i, x, before=False, after=False):
return x
def extra_repr(self):
- return 'dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}'.format(
- self.dropout_module.p, self.relu_dropout_module.p, self.input_dropout_module.p, self.normalize_before)
+ return (
+ "dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}".format(
+ self.dropout_module.p,
+ self.relu_dropout_module.p,
+ self.input_dropout_module.p,
+ self.normalize_before,
+ )
+ )
class LightConvDecoderLayer(nn.Module):
@@ -557,28 +739,42 @@ def __init__(self, args, no_encoder_attn=False, kernel_size=0):
self.embed_dim = args.decoder_embed_dim
self.conv_dim = args.decoder_conv_dim
if args.decoder_glu:
- self.linear1 = Linear(self.embed_dim, 2*self.conv_dim)
+ self.linear1 = Linear(self.embed_dim, 2 * self.conv_dim)
self.act = nn.GLU()
else:
self.linear1 = Linear(self.embed_dim, self.conv_dim)
self.act = None
- if args.decoder_conv_type == 'lightweight':
- self.conv = LightweightConv(self.conv_dim, kernel_size, padding_l=kernel_size-1,
- weight_softmax=args.weight_softmax,
- num_heads=args.decoder_attention_heads,
- weight_dropout=args.weight_dropout)
- elif args.decoder_conv_type == 'dynamic':
- self.conv = DynamicConv(self.conv_dim, kernel_size, padding_l=kernel_size-1,
- weight_softmax=args.weight_softmax,
- num_heads=args.decoder_attention_heads,
- weight_dropout=args.weight_dropout)
+ if args.decoder_conv_type == "lightweight":
+ self.conv = LightweightConv(
+ self.conv_dim,
+ kernel_size,
+ padding_l=kernel_size - 1,
+ weight_softmax=args.weight_softmax,
+ num_heads=args.decoder_attention_heads,
+ weight_dropout=args.weight_dropout,
+ )
+ elif args.decoder_conv_type == "dynamic":
+ self.conv = DynamicConv(
+ self.conv_dim,
+ kernel_size,
+ padding_l=kernel_size - 1,
+ weight_softmax=args.weight_softmax,
+ num_heads=args.decoder_attention_heads,
+ weight_dropout=args.weight_dropout,
+ )
else:
raise NotImplementedError
self.linear2 = Linear(self.conv_dim, self.embed_dim)
- self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__)
- self.relu_dropout_module = FairseqDropout(args.relu_dropout, module_name=self.__class__.__name__)
- self.input_dropout_module = FairseqDropout(args.input_dropout, module_name=self.__class__.__name__)
+ self.dropout_module = FairseqDropout(
+ args.dropout, module_name=self.__class__.__name__
+ )
+ self.relu_dropout_module = FairseqDropout(
+ args.relu_dropout, module_name=self.__class__.__name__
+ )
+ self.input_dropout_module = FairseqDropout(
+ args.input_dropout, module_name=self.__class__.__name__
+ )
self.normalize_before = args.decoder_normalize_before
self.conv_layer_norm = LayerNorm(self.embed_dim)
@@ -588,8 +784,10 @@ def __init__(self, args, no_encoder_attn=False, kernel_size=0):
self.encoder_attn_layer_norm = None
else:
self.encoder_attn = MultiheadAttention(
- self.embed_dim, args.decoder_attention_heads,
- dropout=args.attention_dropout, encoder_decoder_attention=True,
+ self.embed_dim,
+ args.decoder_attention_heads,
+ dropout=args.attention_dropout,
+ encoder_decoder_attention=True,
)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
@@ -599,9 +797,17 @@ def __init__(self, args, no_encoder_attn=False, kernel_size=0):
self.final_layer_norm = LayerNorm(self.embed_dim)
self.need_attn = True
- def forward(self, x, encoder_out, encoder_padding_mask, incremental_state,
- prev_conv_state=None, prev_attn_state=None, conv_mask=None,
- conv_padding_mask=None):
+ def forward(
+ self,
+ x,
+ encoder_out,
+ encoder_padding_mask,
+ incremental_state,
+ prev_conv_state=None,
+ prev_attn_state=None,
+ conv_mask=None,
+ conv_padding_mask=None,
+ ):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -671,8 +877,14 @@ def make_generation_fast_(self, need_attn=False, **kwargs):
self.need_attn = need_attn
def extra_repr(self):
- return 'dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}'.format(
- self.dropout_module.p, self.relu_dropout_module.p, self.input_dropout_module.p, self.normalize_before)
+ return (
+ "dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}".format(
+ self.dropout_module.p,
+ self.relu_dropout_module.p,
+ self.input_dropout_module.p,
+ self.normalize_before,
+ )
+ )
def Embedding(num_embeddings, embedding_dim, padding_idx):
@@ -686,101 +898,121 @@ def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
if bias:
- nn.init.constant_(m.bias, 0.)
+ nn.init.constant_(m.bias, 0.0)
return m
-@register_model_architecture('lightconv', 'lightconv')
+@register_model_architecture("lightconv", "lightconv")
def base_architecture(args):
- args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 2048)
- args.encoder_layers = getattr(args, 'encoder_layers', 7)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
- args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
- args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False)
- args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
- args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
- args.decoder_layers = getattr(args, 'decoder_layers', 6)
- args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
- args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False)
- args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
- args.attention_dropout = getattr(args, 'attention_dropout', 0.)
- args.relu_dropout = getattr(args, 'relu_dropout', 0.)
- args.dropout = getattr(args, 'dropout', 0.1)
- args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
- args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
- args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
- args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
- args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
-
- args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
- args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
-
- args.encoder_conv_dim = getattr(args, 'encoder_conv_dim', args.encoder_embed_dim)
- args.decoder_conv_dim = getattr(args, 'decoder_conv_dim', args.decoder_embed_dim)
-
- args.encoder_kernel_size_list = getattr(args, 'encoder_kernel_size_list', [3, 7, 15, 31, 31, 31, 31])
- args.decoder_kernel_size_list = getattr(args, 'decoder_kernel_size_list', [3, 7, 15, 31, 31, 31])
+ args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
+ args.encoder_layers = getattr(args, "encoder_layers", 7)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
+ args.decoder_ffn_embed_dim = getattr(
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
+ )
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.0)
+ args.relu_dropout = getattr(args, "relu_dropout", 0.0)
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+ args.share_decoder_input_output_embed = getattr(
+ args, "share_decoder_input_output_embed", False
+ )
+ args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
+ args.no_token_positional_embeddings = getattr(
+ args, "no_token_positional_embeddings", False
+ )
+
+ args.decoder_output_dim = getattr(
+ args, "decoder_output_dim", args.decoder_embed_dim
+ )
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
+
+ args.encoder_conv_dim = getattr(args, "encoder_conv_dim", args.encoder_embed_dim)
+ args.decoder_conv_dim = getattr(args, "decoder_conv_dim", args.decoder_embed_dim)
+
+ args.encoder_kernel_size_list = getattr(
+ args, "encoder_kernel_size_list", [3, 7, 15, 31, 31, 31, 31]
+ )
+ args.decoder_kernel_size_list = getattr(
+ args, "decoder_kernel_size_list", [3, 7, 15, 31, 31, 31]
+ )
if len(args.encoder_kernel_size_list) == 1:
- args.encoder_kernel_size_list = args.encoder_kernel_size_list * args.encoder_layers
+ args.encoder_kernel_size_list = (
+ args.encoder_kernel_size_list * args.encoder_layers
+ )
if len(args.decoder_kernel_size_list) == 1:
- args.decoder_kernel_size_list = args.decoder_kernel_size_list * args.decoder_layers
- assert len(args.encoder_kernel_size_list) == args.encoder_layers, "encoder_kernel_size_list doesn't match encoder_layers"
- assert len(args.decoder_kernel_size_list) == args.decoder_layers, "decoder_kernel_size_list doesn't match decoder_layers"
- args.encoder_glu = getattr(args, 'encoder_glu', True)
- args.decoder_glu = getattr(args, 'decoder_glu', True)
- args.input_dropout = getattr(args, 'input_dropout', 0.1)
- args.weight_dropout = getattr(args, 'weight_dropout', args.attention_dropout)
-
-
-@register_model_architecture('lightconv', 'lightconv_iwslt_de_en')
+ args.decoder_kernel_size_list = (
+ args.decoder_kernel_size_list * args.decoder_layers
+ )
+ assert (
+ len(args.encoder_kernel_size_list) == args.encoder_layers
+ ), "encoder_kernel_size_list doesn't match encoder_layers"
+ assert (
+ len(args.decoder_kernel_size_list) == args.decoder_layers
+ ), "decoder_kernel_size_list doesn't match decoder_layers"
+ args.encoder_glu = getattr(args, "encoder_glu", True)
+ args.decoder_glu = getattr(args, "decoder_glu", True)
+ args.input_dropout = getattr(args, "input_dropout", 0.1)
+ args.weight_dropout = getattr(args, "weight_dropout", args.attention_dropout)
+
+
+@register_model_architecture("lightconv", "lightconv_iwslt_de_en")
def lightconv_iwslt_de_en(args):
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
- args.encoder_layers = getattr(args, 'encoder_layers', 7)
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
- args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
- args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
- args.decoder_layers = getattr(args, 'decoder_layers', 6)
- args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
- args.weight_dropout = getattr(args, 'weight_dropout', 0.1)
- args.encoder_glu = getattr(args, 'encoder_glu', False)
- args.decoder_glu = getattr(args, 'decoder_glu', False)
- args.input_dropout = getattr(args, 'input_dropout', 0.0)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
+ args.encoder_layers = getattr(args, "encoder_layers", 7)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+ args.weight_dropout = getattr(args, "weight_dropout", 0.1)
+ args.encoder_glu = getattr(args, "encoder_glu", False)
+ args.decoder_glu = getattr(args, "decoder_glu", False)
+ args.input_dropout = getattr(args, "input_dropout", 0.0)
base_architecture(args)
-@register_model_architecture('lightconv', 'lightconv_wmt_en_de')
+@register_model_architecture("lightconv", "lightconv_wmt_en_de")
def lightconv_wmt_en_de(args):
base_architecture(args)
-@register_model_architecture('lightconv', 'lightconv_wmt_en_de_big')
+@register_model_architecture("lightconv", "lightconv_wmt_en_de_big")
def lightconv_wmt_en_de_big(args):
- args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
- args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
- args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
- args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
- args.dropout = getattr(args, 'dropout', 0.3)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
+ args.dropout = getattr(args, "dropout", 0.3)
base_architecture(args)
-@register_model_architecture('lightconv', 'lightconv_wmt_en_fr_big')
+@register_model_architecture("lightconv", "lightconv_wmt_en_fr_big")
def lightconv_wmt_en_fr_big(args):
- args.dropout = getattr(args, 'dropout', 0.1)
+ args.dropout = getattr(args, "dropout", 0.1)
lightconv_wmt_en_de_big(args)
-@register_model_architecture('lightconv', 'lightconv_wmt_zh_en_big')
+@register_model_architecture("lightconv", "lightconv_wmt_zh_en_big")
def lightconv_wmt_zh_en_big(args):
- args.dropout = getattr(args, 'dropout', 0.2)
- args.attention_dropout = getattr(args, 'attention_dropout', 0.2)
- args.weight_dropout = getattr(args, 'weight_dropout', 0.2)
+ args.dropout = getattr(args, "dropout", 0.2)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.2)
+ args.weight_dropout = getattr(args, "weight_dropout", 0.2)
lightconv_wmt_en_de_big(args)
diff --git a/fairseq/models/lightconv_lm.py b/fairseq/models/lightconv_lm.py
index 861f6430e9..1d9efc4e42 100644
--- a/fairseq/models/lightconv_lm.py
+++ b/fairseq/models/lightconv_lm.py
@@ -9,17 +9,11 @@
register_model,
register_model_architecture,
)
-from fairseq.models.lightconv import (
- Embedding,
- LightConvDecoder,
-)
-from fairseq.modules import (
- AdaptiveInput,
- CharacterTokenEmbedder,
-)
+from fairseq.models.lightconv import Embedding, LightConvDecoder
+from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder
-@register_model('lightconv_lm')
+@register_model("lightconv_lm")
class LightConvLanguageModel(FairseqLanguageModel):
def __init__(self, decoder):
super().__init__(decoder)
@@ -27,72 +21,182 @@ def __init__(self, decoder):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
- parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
- help='dropout probability')
- parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
- help='dropout probability for attention weights')
- parser.add_argument('--relu-dropout', default=0., type=float, metavar='D',
- help='dropout probability after ReLU in FFN')
- parser.add_argument('--input-dropout', type=float, metavar='D',
- help='dropout probability of the inputs')
- parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
- help='decoder embedding dimension')
- parser.add_argument('--decoder-output-dim', type=int, metavar='N',
- help='decoder output dimension')
- parser.add_argument('--decoder-input-dim', type=int, metavar='N',
- help='decoder input dimension')
- parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
- help='decoder embedding dimension for FFN')
- parser.add_argument('--decoder-layers', type=int, metavar='N',
- help='num decoder layers')
- parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
- help='num decoder attention heads or LightConv/DynamicConv heads')
- parser.add_argument('--decoder-normalize-before', default=False, action='store_true',
- help='apply layernorm before each decoder block')
- parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
- help='comma separated list of adaptive softmax cutoff points. '
- 'Must be used with adaptive_loss criterion')
- parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
- help='sets adaptive softmax dropout for the tail projections')
- parser.add_argument('--adaptive-softmax-factor', type=float, metavar='N',
- help='adaptive input factor')
- parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
- help='if set, disables positional embeddings (outside self attention)')
- parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
- help='share decoder input and output embeddings')
- parser.add_argument('--character-embeddings', default=False, action='store_true',
- help='if set, uses character embedding convolutions to produce token embeddings')
- parser.add_argument('--character-filters', type=str, metavar='LIST',
- default='[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]',
- help='size of character embeddings')
- parser.add_argument('--character-embedding-dim', type=int, metavar='N', default=4,
- help='size of character embeddings')
- parser.add_argument('--char-embedder-highway-layers', type=int, metavar='N', default=2,
- help='number of highway layers for character token embeddder')
- parser.add_argument('--adaptive-input', default=False, action='store_true',
- help='if set, uses adaptive input')
- parser.add_argument('--adaptive-input-factor', type=float, metavar='N',
- help='adaptive input factor')
- parser.add_argument('--adaptive-input-cutoff', metavar='EXPR',
- help='comma separated list of adaptive input cutoff points.')
- parser.add_argument('--tie-adaptive-weights', action='store_true',
- help='if set, ties the weights of adaptive softmax and adaptive input')
- parser.add_argument('--tie-adaptive-proj', action='store_true',
- help='if set, ties the projection weights of adaptive softmax and adaptive input')
- parser.add_argument('--decoder-learned-pos', action='store_true',
- help='use learned positional embeddings in the decoder')
+ parser.add_argument(
+ "--dropout",
+ default=0.1,
+ type=float,
+ metavar="D",
+ help="dropout probability",
+ )
+ parser.add_argument(
+ "--attention-dropout",
+ default=0.0,
+ type=float,
+ metavar="D",
+ help="dropout probability for attention weights",
+ )
+ parser.add_argument(
+ "--relu-dropout",
+ default=0.0,
+ type=float,
+ metavar="D",
+ help="dropout probability after ReLU in FFN",
+ )
+ parser.add_argument(
+ "--input-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability of the inputs",
+ )
+ parser.add_argument(
+ "--decoder-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder embedding dimension",
+ )
+ parser.add_argument(
+ "--decoder-output-dim",
+ type=int,
+ metavar="N",
+ help="decoder output dimension",
+ )
+ parser.add_argument(
+ "--decoder-input-dim", type=int, metavar="N", help="decoder input dimension"
+ )
+ parser.add_argument(
+ "--decoder-ffn-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder embedding dimension for FFN",
+ )
+ parser.add_argument(
+ "--decoder-layers", type=int, metavar="N", help="num decoder layers"
+ )
+ parser.add_argument(
+ "--decoder-attention-heads",
+ type=int,
+ metavar="N",
+ help="num decoder attention heads or LightConv/DynamicConv heads",
+ )
+ parser.add_argument(
+ "--decoder-normalize-before",
+ default=False,
+ action="store_true",
+ help="apply layernorm before each decoder block",
+ )
+ parser.add_argument(
+ "--adaptive-softmax-cutoff",
+ metavar="EXPR",
+ help="comma separated list of adaptive softmax cutoff points. "
+ "Must be used with adaptive_loss criterion",
+ )
+ parser.add_argument(
+ "--adaptive-softmax-dropout",
+ type=float,
+ metavar="D",
+ help="sets adaptive softmax dropout for the tail projections",
+ )
+ parser.add_argument(
+ "--adaptive-softmax-factor",
+ type=float,
+ metavar="N",
+ help="adaptive input factor",
+ )
+ parser.add_argument(
+ "--no-token-positional-embeddings",
+ default=False,
+ action="store_true",
+ help="if set, disables positional embeddings (outside self attention)",
+ )
+ parser.add_argument(
+ "--share-decoder-input-output-embed",
+ default=False,
+ action="store_true",
+ help="share decoder input and output embeddings",
+ )
+ parser.add_argument(
+ "--character-embeddings",
+ default=False,
+ action="store_true",
+ help="if set, uses character embedding convolutions to produce token embeddings",
+ )
+ parser.add_argument(
+ "--character-filters",
+ type=str,
+ metavar="LIST",
+ default="[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]",
+ help="size of character embeddings",
+ )
+ parser.add_argument(
+ "--character-embedding-dim",
+ type=int,
+ metavar="N",
+ default=4,
+ help="size of character embeddings",
+ )
+ parser.add_argument(
+ "--char-embedder-highway-layers",
+ type=int,
+ metavar="N",
+ default=2,
+ help="number of highway layers for character token embeddder",
+ )
+ parser.add_argument(
+ "--adaptive-input",
+ default=False,
+ action="store_true",
+ help="if set, uses adaptive input",
+ )
+ parser.add_argument(
+ "--adaptive-input-factor",
+ type=float,
+ metavar="N",
+ help="adaptive input factor",
+ )
+ parser.add_argument(
+ "--adaptive-input-cutoff",
+ metavar="EXPR",
+ help="comma separated list of adaptive input cutoff points.",
+ )
+ parser.add_argument(
+ "--tie-adaptive-weights",
+ action="store_true",
+ help="if set, ties the weights of adaptive softmax and adaptive input",
+ )
+ parser.add_argument(
+ "--tie-adaptive-proj",
+ action="store_true",
+ help="if set, ties the projection weights of adaptive softmax and adaptive input",
+ )
+ parser.add_argument(
+ "--decoder-learned-pos",
+ action="store_true",
+ help="use learned positional embeddings in the decoder",
+ )
"""LightConv and DynamicConv arguments"""
- parser.add_argument('--decoder-kernel-size-list', type=lambda x: utils.eval_str_list(x, int),
- help='list of kernel size (default: "[3,7,15,31,31,31]")')
- parser.add_argument('--decoder-glu', type=utils.eval_bool,
- help='glu after in proj')
- parser.add_argument('--decoder-conv-type', default='dynamic', type=str,
- choices=['dynamic', 'lightweight'],
- help='type of convolution')
- parser.add_argument('--weight-softmax', default=True, type=utils.eval_bool)
- parser.add_argument('--weight-dropout', type=float, metavar='D',
- help='dropout probability for conv weights')
+ parser.add_argument(
+ "--decoder-kernel-size-list",
+ type=lambda x: utils.eval_str_list(x, int),
+ help='list of kernel size (default: "[3,7,15,31,31,31]")',
+ )
+ parser.add_argument(
+ "--decoder-glu", type=utils.eval_bool, help="glu after in proj"
+ )
+ parser.add_argument(
+ "--decoder-conv-type",
+ default="dynamic",
+ type=str,
+ choices=["dynamic", "lightweight"],
+ help="type of convolution",
+ )
+ parser.add_argument("--weight-softmax", default=True, type=utils.eval_bool)
+ parser.add_argument(
+ "--weight-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability for conv weights",
+ )
@classmethod
def build_model(cls, args, task):
@@ -101,76 +205,102 @@ def build_model(cls, args, task):
# make sure all arguments are present in older models
base_lm_architecture(args)
- if getattr(args, 'max_source_positions', None) is None:
+ if getattr(args, "max_source_positions", None) is None:
args.max_source_positions = args.tokens_per_sample
- if getattr(args, 'max_target_positions', None) is None:
+ if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = args.tokens_per_sample
if args.character_embeddings:
- embed_tokens = CharacterTokenEmbedder(task.dictionary, eval(args.character_filters),
- args.character_embedding_dim,
- args.decoder_embed_dim,
- args.char_embedder_highway_layers,
- )
+ embed_tokens = CharacterTokenEmbedder(
+ task.dictionary,
+ eval(args.character_filters),
+ args.character_embedding_dim,
+ args.decoder_embed_dim,
+ args.char_embedder_highway_layers,
+ )
elif args.adaptive_input:
- embed_tokens = AdaptiveInput(len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim,
- args.adaptive_input_factor, args.decoder_embed_dim,
- utils.eval_str_list(args.adaptive_input_cutoff, type=int))
+ embed_tokens = AdaptiveInput(
+ len(task.dictionary),
+ task.dictionary.pad(),
+ args.decoder_input_dim,
+ args.adaptive_input_factor,
+ args.decoder_embed_dim,
+ utils.eval_str_list(args.adaptive_input_cutoff, type=int),
+ )
else:
- embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())
+ embed_tokens = Embedding(
+ len(task.dictionary), args.decoder_input_dim, task.dictionary.pad()
+ )
if args.tie_adaptive_weights:
assert args.adaptive_input
assert args.adaptive_input_factor == args.adaptive_softmax_factor
- assert args.adaptive_softmax_cutoff == args.adaptive_input_cutoff, '{} != {}'.format(
- args.adaptive_softmax_cutoff, args.adaptive_input_cutoff)
+ assert (
+ args.adaptive_softmax_cutoff == args.adaptive_input_cutoff
+ ), "{} != {}".format(
+ args.adaptive_softmax_cutoff, args.adaptive_input_cutoff
+ )
assert args.decoder_input_dim == args.decoder_output_dim
- decoder = LightConvDecoder(args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
+ decoder = LightConvDecoder(
+ args,
+ task.output_dictionary,
+ embed_tokens,
+ no_encoder_attn=True,
+ final_norm=False,
+ )
return LightConvLanguageModel(decoder)
-@register_model_architecture('lightconv_lm', 'lightconv_lm')
+@register_model_architecture("lightconv_lm", "lightconv_lm")
def base_lm_architecture(args):
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
- args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048)
- args.decoder_layers = getattr(args, 'decoder_layers', 6)
- args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
- args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
- args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
- args.adaptive_softmax_factor = getattr(args, 'adaptive_softmax_factor', 4)
- args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+ args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
- args.character_embeddings = getattr(args, 'character_embeddings', False)
+ args.character_embeddings = getattr(args, "character_embeddings", False)
- args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
- args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
- args.decoder_conv_dim = getattr(args, 'decoder_conv_dim', args.decoder_embed_dim)
+ args.decoder_output_dim = getattr(
+ args, "decoder_output_dim", args.decoder_embed_dim
+ )
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
+ args.decoder_conv_dim = getattr(args, "decoder_conv_dim", args.decoder_embed_dim)
# The model training is not stable without this
args.decoder_normalize_before = True
- args.adaptive_input = getattr(args, 'adaptive_input', False)
- args.adaptive_input_factor = getattr(args, 'adaptive_input_factor', 4)
- args.adaptive_input_cutoff = getattr(args, 'adaptive_input_cutoff', None)
+ args.adaptive_input = getattr(args, "adaptive_input", False)
+ args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
+ args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)
- args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False)
- args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', False)
+ args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
+ args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
- args.decoder_kernel_size_list = getattr(args, 'decoder_kernel_size_list', [3, 7, 15, 31, 31, 31])
+ args.decoder_kernel_size_list = getattr(
+ args, "decoder_kernel_size_list", [3, 7, 15, 31, 31, 31]
+ )
if len(args.decoder_kernel_size_list) == 1:
- args.decoder_kernel_size_list = args.decoder_kernel_size_list * args.decoder_layers
- assert len(args.decoder_kernel_size_list) == args.decoder_layers, "decoder_kernel_size_list doesn't match decoder_layers"
- args.decoder_glu = getattr(args, 'decoder_glu', True)
- args.input_dropout = getattr(args, 'input_dropout', 0.1)
- args.weight_dropout = getattr(args, 'weight_dropout', args.attention_dropout)
+ args.decoder_kernel_size_list = (
+ args.decoder_kernel_size_list * args.decoder_layers
+ )
+ assert (
+ len(args.decoder_kernel_size_list) == args.decoder_layers
+ ), "decoder_kernel_size_list doesn't match decoder_layers"
+ args.decoder_glu = getattr(args, "decoder_glu", True)
+ args.input_dropout = getattr(args, "input_dropout", 0.1)
+ args.weight_dropout = getattr(args, "weight_dropout", args.attention_dropout)
-@register_model_architecture('lightconv_lm', 'lightconv_lm_gbw')
+@register_model_architecture("lightconv_lm", "lightconv_lm_gbw")
def lightconv_lm_gbw(args):
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
- args.dropout = getattr(args, 'dropout', 0.1)
- args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
- args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
- args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
base_lm_architecture(args)
diff --git a/fairseq/models/lstm.py b/fairseq/models/lstm.py
index 8404cafe1d..1a9dca3c75 100644
--- a/fairseq/models/lstm.py
+++ b/fairseq/models/lstm.py
@@ -3,28 +3,28 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from typing import Dict, List, Optional, Tuple
+
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
- FairseqIncrementalDecoder,
FairseqEncoderDecoderModel,
+ FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.modules import AdaptiveSoftmax, FairseqDropout
from torch import Tensor
-from typing import Dict, List, Optional, Tuple
DEFAULT_MAX_SOURCE_POSITIONS = 1e5
DEFAULT_MAX_TARGET_POSITIONS = 1e5
-@register_model('lstm')
+@register_model("lstm")
class LSTMModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@@ -89,10 +89,14 @@ def build_model(cls, args, task):
base_architecture(args)
if args.encoder_layers != args.decoder_layers:
- raise ValueError('--encoder-layers must match --decoder-layers')
+ raise ValueError("--encoder-layers must match --decoder-layers")
- max_source_positions = getattr(args, 'max_source_positions', DEFAULT_MAX_SOURCE_POSITIONS)
- max_target_positions = getattr(args, 'max_target_positions', DEFAULT_MAX_TARGET_POSITIONS)
+ max_source_positions = getattr(
+ args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS
+ )
+ max_target_positions = getattr(
+ args, "max_target_positions", DEFAULT_MAX_TARGET_POSITIONS
+ )
def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
num_embeddings = len(dictionary)
@@ -104,7 +108,8 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
if args.encoder_embed_path:
pretrained_encoder_embed = load_pretrained_embedding_from_file(
- args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim)
+ args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim
+ )
else:
num_embeddings = len(task.source_dictionary)
pretrained_encoder_embed = Embedding(
@@ -114,16 +119,17 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
if args.share_all_embeddings:
# double check all parameters combinations are valid
if task.source_dictionary != task.target_dictionary:
- raise ValueError('--share-all-embeddings requires a joint dictionary')
+ raise ValueError("--share-all-embeddings requires a joint dictionary")
if args.decoder_embed_path and (
- args.decoder_embed_path != args.encoder_embed_path):
+ args.decoder_embed_path != args.encoder_embed_path
+ ):
raise ValueError(
- '--share-all-embed not compatible with --decoder-embed-path'
+ "--share-all-embed not compatible with --decoder-embed-path"
)
if args.encoder_embed_dim != args.decoder_embed_dim:
raise ValueError(
- '--share-all-embeddings requires --encoder-embed-dim to '
- 'match --decoder-embed-dim'
+ "--share-all-embeddings requires --encoder-embed-dim to "
+ "match --decoder-embed-dim"
)
pretrained_decoder_embed = pretrained_encoder_embed
args.share_decoder_input_output_embed = True
@@ -134,14 +140,15 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
pretrained_decoder_embed = load_pretrained_embedding_from_file(
args.decoder_embed_path,
task.target_dictionary,
- args.decoder_embed_dim
+ args.decoder_embed_dim,
)
# one last double check of parameter combinations
if args.share_decoder_input_output_embed and (
- args.decoder_embed_dim != args.decoder_out_embed_dim):
+ args.decoder_embed_dim != args.decoder_out_embed_dim
+ ):
raise ValueError(
- '--share-decoder-input-output-embeddings requires '
- '--decoder-embed-dim to match --decoder-out-embed-dim'
+ "--share-decoder-input-output-embeddings requires "
+ "--decoder-embed-dim to match --decoder-out-embed-dim"
)
if args.encoder_freeze_embed:
@@ -174,7 +181,8 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
share_input_output_embed=args.share_decoder_input_output_embed,
adaptive_softmax_cutoff=(
utils.eval_str_list(args.adaptive_softmax_cutoff, type=int)
- if args.criterion == 'adaptive_loss' else None
+ if args.criterion == "adaptive_loss"
+ else None
),
max_target_positions=max_target_positions,
residuals=False,
@@ -190,23 +198,38 @@ def forward(
):
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths)
decoder_out = self.decoder(
- prev_output_tokens, encoder_out=encoder_out, incremental_state=incremental_state
+ prev_output_tokens,
+ encoder_out=encoder_out,
+ incremental_state=incremental_state,
)
return decoder_out
class LSTMEncoder(FairseqEncoder):
"""LSTM encoder."""
+
def __init__(
- self, dictionary, embed_dim=512, hidden_size=512, num_layers=1,
- dropout_in=0.1, dropout_out=0.1, bidirectional=False,
- left_pad=True, pretrained_embed=None, padding_idx=None,
+ self,
+ dictionary,
+ embed_dim=512,
+ hidden_size=512,
+ num_layers=1,
+ dropout_in=0.1,
+ dropout_out=0.1,
+ bidirectional=False,
+ left_pad=True,
+ pretrained_embed=None,
+ padding_idx=None,
max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS,
):
super().__init__(dictionary)
self.num_layers = num_layers
- self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__)
- self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__)
+ self.dropout_in_module = FairseqDropout(
+ dropout_in, module_name=self.__class__.__name__
+ )
+ self.dropout_out_module = FairseqDropout(
+ dropout_out, module_name=self.__class__.__name__
+ )
self.bidirectional = bidirectional
self.hidden_size = hidden_size
self.max_source_positions = max_source_positions
@@ -222,7 +245,7 @@ def __init__(
input_size=embed_dim,
hidden_size=hidden_size,
num_layers=num_layers,
- dropout=self.dropout_out_module.p if num_layers > 1 else 0.,
+ dropout=self.dropout_out_module.p if num_layers > 1 else 0.0,
bidirectional=bidirectional,
)
self.left_pad = left_pad
@@ -281,7 +304,9 @@ def forward(
packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0))
# unpack outputs and apply dropout
- x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_idx*1.0)
+ x, _ = nn.utils.rnn.pad_packed_sequence(
+ packed_outs, padding_value=self.padding_idx * 1.0
+ )
x = self.dropout_out_module(x)
assert list(x.size()) == [seqlen, bsz, self.output_units]
@@ -291,24 +316,28 @@ def forward(
encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
- return tuple((
- x, # seq_len x batch x hidden
- final_hiddens, # num_layers x batch x num_directions*hidden
- final_cells, # num_layers x batch x num_directions*hidden
- encoder_padding_mask, # seq_len x batch
- ))
+ return tuple(
+ (
+ x, # seq_len x batch x hidden
+ final_hiddens, # num_layers x batch x num_directions*hidden
+ final_cells, # num_layers x batch x num_directions*hidden
+ encoder_padding_mask, # seq_len x batch
+ )
+ )
def combine_bidir(self, outs, bsz: int):
out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
return out.view(self.num_layers, bsz, -1)
def reorder_encoder_out(self, encoder_out, new_order):
- return tuple((
- encoder_out[0].index_select(1, new_order),
- encoder_out[1].index_select(1, new_order),
- encoder_out[2].index_select(1, new_order),
- encoder_out[3].index_select(1, new_order),
- ))
+ return tuple(
+ (
+ encoder_out[0].index_select(1, new_order),
+ encoder_out[1].index_select(1, new_order),
+ encoder_out[2].index_select(1, new_order),
+ encoder_out[3].index_select(1, new_order),
+ )
+ )
def max_positions(self):
"""Maximum input length supported by the encoder."""
@@ -320,7 +349,9 @@ def __init__(self, input_embed_dim, source_embed_dim, output_embed_dim, bias=Fal
super().__init__()
self.input_proj = Linear(input_embed_dim, source_embed_dim, bias=bias)
- self.output_proj = Linear(input_embed_dim + source_embed_dim, output_embed_dim, bias=bias)
+ self.output_proj = Linear(
+ input_embed_dim + source_embed_dim, output_embed_dim, bias=bias
+ )
def forward(self, input, source_hids, encoder_padding_mask):
# input: bsz x input_embed_dim
@@ -334,10 +365,11 @@ def forward(self, input, source_hids, encoder_padding_mask):
# don't attend over padding
if encoder_padding_mask is not None:
- attn_scores = attn_scores.float().masked_fill_(
- encoder_padding_mask,
- float('-inf')
- ).type_as(attn_scores) # FP16 support: cast to float and back
+ attn_scores = (
+ attn_scores.float()
+ .masked_fill_(encoder_padding_mask, float("-inf"))
+ .type_as(attn_scores)
+ ) # FP16 support: cast to float and back
attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz
@@ -350,17 +382,31 @@ def forward(self, input, source_hids, encoder_padding_mask):
class LSTMDecoder(FairseqIncrementalDecoder):
"""LSTM decoder."""
+
def __init__(
- self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512,
- num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
- encoder_output_units=512, pretrained_embed=None,
- share_input_output_embed=False, adaptive_softmax_cutoff=None,
+ self,
+ dictionary,
+ embed_dim=512,
+ hidden_size=512,
+ out_embed_dim=512,
+ num_layers=1,
+ dropout_in=0.1,
+ dropout_out=0.1,
+ attention=True,
+ encoder_output_units=512,
+ pretrained_embed=None,
+ share_input_output_embed=False,
+ adaptive_softmax_cutoff=None,
max_target_positions=DEFAULT_MAX_TARGET_POSITIONS,
residuals=False,
):
super().__init__(dictionary)
- self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__)
- self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__)
+ self.dropout_in_module = FairseqDropout(
+ dropout_in, module_name=self.__class__.__name__
+ )
+ self.dropout_out_module = FairseqDropout(
+ dropout_out, module_name=self.__class__.__name__
+ )
self.hidden_size = hidden_size
self.share_input_output_embed = share_input_output_embed
self.need_attn = True
@@ -386,17 +432,23 @@ def __init__(
# disable input feeding if there is no encoder
# input feeding is described in arxiv.org/abs/1508.04025
input_feed_size = 0 if encoder_output_units == 0 else hidden_size
- self.layers = nn.ModuleList([
- LSTMCell(
- input_size=input_feed_size + embed_dim if layer == 0 else hidden_size,
- hidden_size=hidden_size,
- )
- for layer in range(num_layers)
- ])
+ self.layers = nn.ModuleList(
+ [
+ LSTMCell(
+ input_size=input_feed_size + embed_dim
+ if layer == 0
+ else hidden_size,
+ hidden_size=hidden_size,
+ )
+ for layer in range(num_layers)
+ ]
+ )
if attention:
# TODO make bias configurable
- self.attention = AttentionLayer(hidden_size, encoder_output_units, hidden_size, bias=False)
+ self.attention = AttentionLayer(
+ hidden_size, encoder_output_units, hidden_size, bias=False
+ )
else:
self.attention = None
@@ -406,7 +458,10 @@ def __init__(
if adaptive_softmax_cutoff is not None:
# setting adaptive_softmax dropout to dropout_out for now but can be redefined
self.adaptive_softmax = AdaptiveSoftmax(
- num_embeddings, hidden_size, adaptive_softmax_cutoff, dropout=dropout_out,
+ num_embeddings,
+ hidden_size,
+ adaptive_softmax_cutoff,
+ dropout=dropout_out,
)
elif not self.share_input_output_embed:
self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
@@ -459,7 +514,9 @@ def extract_features(
# initialize previous states (or get from cache during incremental generation)
if incremental_state is not None and len(incremental_state) > 0:
- prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state)
+ prev_hiddens, prev_cells, input_feed = self.get_cached_state(
+ incremental_state
+ )
elif encoder_out is not None:
# setup recurrent cells
prev_hiddens = [encoder_hiddens[i] for i in range(self.num_layers)]
@@ -475,9 +532,12 @@ def extract_features(
prev_cells = [zero_state for i in range(self.num_layers)]
input_feed = None
- assert srclen > 0 or self.attention is None, \
- "attention is not supported if there are no encoder outputs"
- attn_scores = x.new_zeros(srclen, seqlen, bsz) if self.attention is not None else None
+ assert (
+ srclen > 0 or self.attention is None
+ ), "attention is not supported if there are no encoder outputs"
+ attn_scores = (
+ x.new_zeros(srclen, seqlen, bsz) if self.attention is not None else None
+ )
outs = []
for j in range(seqlen):
# input feeding: concatenate context vector from previous time step
@@ -502,7 +562,9 @@ def extract_features(
# apply attention using the last layer's hidden state
if self.attention is not None:
assert attn_scores is not None
- out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs, encoder_padding_mask)
+ out, attn_scores[:, j, :] = self.attention(
+ hidden, encoder_outs, encoder_padding_mask
+ )
else:
out = hidden
out = self.dropout_out_module(out)
@@ -523,9 +585,9 @@ def extract_features(
"prev_hiddens": prev_hiddens_tensor,
"prev_cells": prev_cells_tensor,
"input_feed": input_feed,
- }
+ },
)
- self.set_incremental_state(incremental_state, 'cached_state', cache_state)
+ self.set_incremental_state(incremental_state, "cached_state", cache_state)
# collect outputs across time steps
x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)
@@ -533,7 +595,7 @@ def extract_features(
# T x B x C -> B x T x C
x = x.transpose(1, 0)
- if hasattr(self, 'additional_fc') and self.adaptive_softmax is None:
+ if hasattr(self, "additional_fc") and self.adaptive_softmax is None:
x = self.additional_fc(x)
x = self.dropout_out_module(x)
# srclen x tgtlen x bsz -> bsz x tgtlen x srclen
@@ -557,7 +619,7 @@ def get_cached_state(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
) -> Tuple[List[Tensor], List[Tensor], Optional[Tensor]]:
- cached_state = self.get_incremental_state(incremental_state, 'cached_state')
+ cached_state = self.get_incremental_state(incremental_state, "cached_state")
assert cached_state is not None
prev_hiddens_ = cached_state["prev_hiddens"]
assert prev_hiddens_ is not None
@@ -565,7 +627,9 @@ def get_cached_state(
assert prev_cells_ is not None
prev_hiddens = [prev_hiddens_[i] for i in range(self.num_layers)]
prev_cells = [prev_cells_[j] for j in range(self.num_layers)]
- input_feed = cached_state["input_feed"] # can be None for decoder-only language models
+ input_feed = cached_state[
+ "input_feed"
+ ] # can be None for decoder-only language models
return prev_hiddens, prev_cells, input_feed
def reorder_incremental_state(
@@ -586,9 +650,9 @@ def reorder_incremental_state(
"prev_hiddens": torch.stack(prev_hiddens),
"prev_cells": torch.stack(prev_cells),
"input_feed": input_feed,
- }
+ },
)
- self.set_incremental_state(incremental_state, 'cached_state', cached_state_new),
+ self.set_incremental_state(incremental_state, "cached_state", cached_state_new),
return
def max_positions(self):
@@ -609,7 +673,7 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
def LSTM(input_size, hidden_size, **kwargs):
m = nn.LSTM(input_size, hidden_size, **kwargs)
for name, param in m.named_parameters():
- if 'weight' in name or 'bias' in name:
+ if "weight" in name or "bias" in name:
param.data.uniform_(-0.1, 0.1)
return m
@@ -617,12 +681,12 @@ def LSTM(input_size, hidden_size, **kwargs):
def LSTMCell(input_size, hidden_size, **kwargs):
m = nn.LSTMCell(input_size, hidden_size, **kwargs)
for name, param in m.named_parameters():
- if 'weight' in name or 'bias' in name:
+ if "weight" in name or "bias" in name:
param.data.uniform_(-0.1, 0.1)
return m
-def Linear(in_features, out_features, bias=True, dropout=0.):
+def Linear(in_features, out_features, bias=True, dropout=0.0):
"""Linear layer (input: N x T x C)"""
m = nn.Linear(in_features, out_features, bias=bias)
m.weight.data.uniform_(-0.1, 0.1)
@@ -631,51 +695,59 @@ def Linear(in_features, out_features, bias=True, dropout=0.):
return m
-@register_model_architecture('lstm', 'lstm')
+@register_model_architecture("lstm", "lstm")
def base_architecture(args):
- args.dropout = getattr(args, 'dropout', 0.1)
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
- args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
- args.encoder_freeze_embed = getattr(args, 'encoder_freeze_embed', False)
- args.encoder_hidden_size = getattr(args, 'encoder_hidden_size', args.encoder_embed_dim)
- args.encoder_layers = getattr(args, 'encoder_layers', 1)
- args.encoder_bidirectional = getattr(args, 'encoder_bidirectional', False)
- args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', args.dropout)
- args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', args.dropout)
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
- args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
- args.decoder_freeze_embed = getattr(args, 'decoder_freeze_embed', False)
- args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', args.decoder_embed_dim)
- args.decoder_layers = getattr(args, 'decoder_layers', 1)
- args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
- args.decoder_attention = getattr(args, 'decoder_attention', '1')
- args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout)
- args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
- args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
- args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
- args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000')
-
-
-@register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
+ args.encoder_freeze_embed = getattr(args, "encoder_freeze_embed", False)
+ args.encoder_hidden_size = getattr(
+ args, "encoder_hidden_size", args.encoder_embed_dim
+ )
+ args.encoder_layers = getattr(args, "encoder_layers", 1)
+ args.encoder_bidirectional = getattr(args, "encoder_bidirectional", False)
+ args.encoder_dropout_in = getattr(args, "encoder_dropout_in", args.dropout)
+ args.encoder_dropout_out = getattr(args, "encoder_dropout_out", args.dropout)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+ args.decoder_freeze_embed = getattr(args, "decoder_freeze_embed", False)
+ args.decoder_hidden_size = getattr(
+ args, "decoder_hidden_size", args.decoder_embed_dim
+ )
+ args.decoder_layers = getattr(args, "decoder_layers", 1)
+ args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512)
+ args.decoder_attention = getattr(args, "decoder_attention", "1")
+ args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout)
+ args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout)
+ args.share_decoder_input_output_embed = getattr(
+ args, "share_decoder_input_output_embed", False
+ )
+ args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
+ args.adaptive_softmax_cutoff = getattr(
+ args, "adaptive_softmax_cutoff", "10000,50000,200000"
+ )
+
+
+@register_model_architecture("lstm", "lstm_wiseman_iwslt_de_en")
def lstm_wiseman_iwslt_de_en(args):
- args.dropout = getattr(args, 'dropout', 0.1)
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
- args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', 0)
- args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', 0)
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
- args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
- args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', 0)
- args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
+ args.encoder_dropout_in = getattr(args, "encoder_dropout_in", 0)
+ args.encoder_dropout_out = getattr(args, "encoder_dropout_out", 0)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)
+ args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256)
+ args.decoder_dropout_in = getattr(args, "decoder_dropout_in", 0)
+ args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout)
base_architecture(args)
-@register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
+@register_model_architecture("lstm", "lstm_luong_wmt_en_de")
def lstm_luong_wmt_en_de(args):
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000)
- args.encoder_layers = getattr(args, 'encoder_layers', 4)
- args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', 0)
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1000)
- args.decoder_layers = getattr(args, 'decoder_layers', 4)
- args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1000)
- args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', 0)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1000)
+ args.encoder_layers = getattr(args, "encoder_layers", 4)
+ args.encoder_dropout_out = getattr(args, "encoder_dropout_out", 0)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1000)
+ args.decoder_layers = getattr(args, "decoder_layers", 4)
+ args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 1000)
+ args.decoder_dropout_out = getattr(args, "decoder_dropout_out", 0)
base_architecture(args)
diff --git a/fairseq/models/lstm_lm.py b/fairseq/models/lstm_lm.py
index 1a39b95289..454f0ac36f 100644
--- a/fairseq/models/lstm_lm.py
+++ b/fairseq/models/lstm_lm.py
@@ -5,15 +5,17 @@
from fairseq import utils
from fairseq.models import (
- FairseqLanguageModel, register_model, register_model_architecture
-)
-from fairseq.models.lstm import (
- LSTMDecoder, Embedding
+ FairseqLanguageModel,
+ register_model,
+ register_model_architecture,
)
+from fairseq.models.lstm import Embedding, LSTMDecoder
+
DEFAULT_MAX_TARGET_POSITIONS = 1e5
-@register_model('lstm_lm')
+
+@register_model("lstm_lm")
class LSTMLanguageModel(FairseqLanguageModel):
def __init__(self, decoder):
super().__init__(decoder)
@@ -60,10 +62,12 @@ def build_model(cls, args, task):
# make sure all arguments are present in older models
base_architecture(args)
- if getattr(args, 'max_target_positions', None) is not None:
+ if getattr(args, "max_target_positions", None) is not None:
max_target_positions = args.max_target_positions
else:
- max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS)
+ max_target_positions = getattr(
+ args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
+ )
def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
num_embeddings = len(dictionary)
@@ -76,21 +80,21 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
pretrained_decoder_embed = None
if args.decoder_embed_path:
pretrained_decoder_embed = load_pretrained_embedding_from_file(
- args.decoder_embed_path,
- task.target_dictionary,
- args.decoder_embed_dim
+ args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim
)
if args.share_decoder_input_output_embed:
# double check all parameters combinations are valid
if task.source_dictionary != task.target_dictionary:
- raise ValueError('--share-decoder-input-output-embeddings requires a joint dictionary')
+ raise ValueError(
+ "--share-decoder-input-output-embeddings requires a joint dictionary"
+ )
if args.decoder_embed_dim != args.decoder_out_embed_dim:
raise ValueError(
- '--share-decoder-input-output-embeddings requires '
- '--decoder-embed-dim to match --decoder-out-embed-dim'
- )
+ "--share-decoder-input-output-embeddings requires "
+ "--decoder-embed-dim to match --decoder-out-embed-dim"
+ )
decoder = LSTMDecoder(
dictionary=task.dictionary,
@@ -106,26 +110,33 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
share_input_output_embed=args.share_decoder_input_output_embed,
adaptive_softmax_cutoff=(
utils.eval_str_list(args.adaptive_softmax_cutoff, type=int)
- if args.criterion == 'adaptive_loss' else None
+ if args.criterion == "adaptive_loss"
+ else None
),
max_target_positions=max_target_positions,
- residuals=args.residuals
+ residuals=args.residuals,
)
return cls(decoder)
-@register_model_architecture('lstm_lm', 'lstm_lm')
+@register_model_architecture("lstm_lm", "lstm_lm")
def base_architecture(args):
- args.dropout = getattr(args, 'dropout', 0.1)
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
- args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
- args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', args.decoder_embed_dim)
- args.decoder_layers = getattr(args, 'decoder_layers', 1)
- args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
- args.decoder_attention = getattr(args, 'decoder_attention', '0')
- args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout)
- args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
- args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
- args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000')
- args.residuals = getattr(args, 'residuals', False)
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+ args.decoder_hidden_size = getattr(
+ args, "decoder_hidden_size", args.decoder_embed_dim
+ )
+ args.decoder_layers = getattr(args, "decoder_layers", 1)
+ args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512)
+ args.decoder_attention = getattr(args, "decoder_attention", "0")
+ args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout)
+ args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout)
+ args.share_decoder_input_output_embed = getattr(
+ args, "share_decoder_input_output_embed", False
+ )
+ args.adaptive_softmax_cutoff = getattr(
+ args, "adaptive_softmax_cutoff", "10000,50000,200000"
+ )
+ args.residuals = getattr(args, "residuals", False)
diff --git a/fairseq/models/masked_lm.py b/fairseq/models/masked_lm.py
index 35a6323ef2..c786de9125 100644
--- a/fairseq/models/masked_lm.py
+++ b/fairseq/models/masked_lm.py
@@ -8,11 +8,10 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq import utils
from fairseq.models import (
- FairseqEncoderModel,
FairseqEncoder,
+ FairseqEncoderModel,
register_model,
register_model_architecture,
)
@@ -27,12 +26,13 @@
logger = logging.getLogger(__name__)
-@register_model('masked_lm')
+@register_model("masked_lm")
class MaskedLMModel(FairseqEncoderModel):
"""
Class for training a Masked Language Model. It also supports an
additional sentence level prediction if the sent-loss argument is set.
"""
+
def __init__(self, args, encoder):
super().__init__(encoder)
self.args = args
@@ -40,66 +40,111 @@ def __init__(self, args, encoder):
# if specified then apply bert initialization on the model. We need
# to explictly call this to make sure that the output embeddings
# and projection layers are also correctly initialized
- if getattr(args, 'apply_bert_init', False):
+ if getattr(args, "apply_bert_init", False):
self.apply(init_bert_params)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# Arguments related to dropout
- parser.add_argument('--dropout', type=float, metavar='D',
- help='dropout probability')
- parser.add_argument('--attention-dropout', type=float,
- metavar='D', help='dropout probability for'
- ' attention weights')
- parser.add_argument('--act-dropout', type=float,
- metavar='D', help='dropout probability after'
- ' activation in FFN')
+ parser.add_argument(
+ "--dropout", type=float, metavar="D", help="dropout probability"
+ )
+ parser.add_argument(
+ "--attention-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability for" " attention weights",
+ )
+ parser.add_argument(
+ "--act-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability after" " activation in FFN",
+ )
# Arguments related to hidden states and self-attention
- parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
- help='encoder embedding dimension for FFN')
- parser.add_argument('--encoder-layers', type=int, metavar='N',
- help='num encoder layers')
- parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
- help='num encoder attention heads')
+ parser.add_argument(
+ "--encoder-ffn-embed-dim",
+ type=int,
+ metavar="N",
+ help="encoder embedding dimension for FFN",
+ )
+ parser.add_argument(
+ "--encoder-layers", type=int, metavar="N", help="num encoder layers"
+ )
+ parser.add_argument(
+ "--encoder-attention-heads",
+ type=int,
+ metavar="N",
+ help="num encoder attention heads",
+ )
# Arguments related to input and output embeddings
- parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
- help='encoder embedding dimension')
- parser.add_argument('--share-encoder-input-output-embed',
- action='store_true', help='share encoder input'
- ' and output embeddings')
- parser.add_argument('--encoder-learned-pos', action='store_true',
- help='use learned positional embeddings in the encoder')
- parser.add_argument('--no-token-positional-embeddings',
- action='store_true',
- help='if set, disables positional embeddings'
- ' (outside self attention)')
- parser.add_argument('--num-segment', type=int, metavar='N',
- help='num segment in the input')
- parser.add_argument('--max-positions', type=int,
- help='number of positional embeddings to learn')
+ parser.add_argument(
+ "--encoder-embed-dim",
+ type=int,
+ metavar="N",
+ help="encoder embedding dimension",
+ )
+ parser.add_argument(
+ "--share-encoder-input-output-embed",
+ action="store_true",
+ help="share encoder input" " and output embeddings",
+ )
+ parser.add_argument(
+ "--encoder-learned-pos",
+ action="store_true",
+ help="use learned positional embeddings in the encoder",
+ )
+ parser.add_argument(
+ "--no-token-positional-embeddings",
+ action="store_true",
+ help="if set, disables positional embeddings" " (outside self attention)",
+ )
+ parser.add_argument(
+ "--num-segment", type=int, metavar="N", help="num segment in the input"
+ )
+ parser.add_argument(
+ "--max-positions", type=int, help="number of positional embeddings to learn"
+ )
# Arguments related to sentence level prediction
- parser.add_argument('--sentence-class-num', type=int, metavar='N',
- help='number of classes for sentence task')
- parser.add_argument('--sent-loss', action='store_true', help='if set,'
- ' calculate sentence level predictions')
+ parser.add_argument(
+ "--sentence-class-num",
+ type=int,
+ metavar="N",
+ help="number of classes for sentence task",
+ )
+ parser.add_argument(
+ "--sent-loss",
+ action="store_true",
+ help="if set," " calculate sentence level predictions",
+ )
# Arguments related to parameter initialization
- parser.add_argument('--apply-bert-init', action='store_true',
- help='use custom param initialization for BERT')
+ parser.add_argument(
+ "--apply-bert-init",
+ action="store_true",
+ help="use custom param initialization for BERT",
+ )
# misc params
- parser.add_argument('--activation-fn',
- choices=utils.get_available_activation_fns(),
- help='activation function to use')
- parser.add_argument('--pooler-activation-fn',
- choices=utils.get_available_activation_fns(),
- help='Which activation function to use for pooler layer.')
- parser.add_argument('--encoder-normalize-before', action='store_true',
- help='apply layernorm before each encoder block')
+ parser.add_argument(
+ "--activation-fn",
+ choices=utils.get_available_activation_fns(),
+ help="activation function to use",
+ )
+ parser.add_argument(
+ "--pooler-activation-fn",
+ choices=utils.get_available_activation_fns(),
+ help="Which activation function to use for pooler layer.",
+ )
+ parser.add_argument(
+ "--encoder-normalize-before",
+ action="store_true",
+ help="apply layernorm before each encoder block",
+ )
def forward(self, src_tokens, segment_labels=None, **kwargs):
return self.encoder(src_tokens, segment_labels=segment_labels, **kwargs)
@@ -113,7 +158,7 @@ def build_model(cls, args, task):
# make sure all arguments are present in older models
base_architecture(args)
- if not hasattr(args, 'max_positions'):
+ if not hasattr(args, "max_positions"):
args.max_positions = args.tokens_per_sample
logger.info(args)
@@ -160,14 +205,16 @@ def __init__(self, args, dictionary):
self.lm_output_learned_bias = None
# Remove head is set to true during fine-tuning
- self.load_softmax = not getattr(args, 'remove_head', False)
+ self.load_softmax = not getattr(args, "remove_head", False)
self.masked_lm_pooler = nn.Linear(
args.encoder_embed_dim, args.encoder_embed_dim
)
self.pooler_activation = utils.get_activation_fn(args.pooler_activation_fn)
- self.lm_head_transform_weight = nn.Linear(args.encoder_embed_dim, args.encoder_embed_dim)
+ self.lm_head_transform_weight = nn.Linear(
+ args.encoder_embed_dim, args.encoder_embed_dim
+ )
self.activation_fn = utils.get_activation_fn(args.activation_fn)
self.layer_norm = LayerNorm(args.encoder_embed_dim)
@@ -177,16 +224,12 @@ def __init__(self, args, dictionary):
if not self.share_input_output_embed:
self.embed_out = nn.Linear(
- args.encoder_embed_dim,
- self.vocab_size,
- bias=False
+ args.encoder_embed_dim, self.vocab_size, bias=False
)
if args.sent_loss:
self.sentence_projection_layer = nn.Linear(
- args.encoder_embed_dim,
- self.sentence_out_dim,
- bias=False
+ args.encoder_embed_dim, self.sentence_out_dim, bias=False
)
def forward(self, src_tokens, segment_labels=None, masked_tokens=None, **unused):
@@ -227,8 +270,9 @@ def forward(self, src_tokens, segment_labels=None, masked_tokens=None, **unused)
pooled_output = self.pooler_activation(self.masked_lm_pooler(sentence_rep))
# project back to size of vocabulary
- if self.share_input_output_embed \
- and hasattr(self.sentence_encoder.embed_tokens, 'weight'):
+ if self.share_input_output_embed and hasattr(
+ self.sentence_encoder.embed_tokens, "weight"
+ ):
x = F.linear(x, self.sentence_encoder.embed_tokens.weight)
elif self.embed_out is not None:
x = self.embed_out(x)
@@ -239,9 +283,9 @@ def forward(self, src_tokens, segment_labels=None, masked_tokens=None, **unused)
sentence_logits = self.sentence_projection_layer(pooled_output)
return x, {
- 'inner_states': inner_states,
- 'pooled_output': pooled_output,
- 'sentence_logits': sentence_logits
+ "inner_states": inner_states,
+ "pooled_output": pooled_output,
+ "sentence_logits": sentence_logits,
}
def max_positions(self):
@@ -250,103 +294,110 @@ def max_positions(self):
def upgrade_state_dict_named(self, state_dict, name):
if isinstance(
- self.sentence_encoder.embed_positions,
- SinusoidalPositionalEmbedding
+ self.sentence_encoder.embed_positions, SinusoidalPositionalEmbedding
):
state_dict[
- name + '.sentence_encoder.embed_positions._float_tensor'
+ name + ".sentence_encoder.embed_positions._float_tensor"
] = torch.FloatTensor(1)
if not self.load_softmax:
for k in list(state_dict.keys()):
if (
- "embed_out.weight" in k or
- "sentence_projection_layer.weight" in k or
- "lm_output_learned_bias" in k
+ "embed_out.weight" in k
+ or "sentence_projection_layer.weight" in k
+ or "lm_output_learned_bias" in k
):
del state_dict[k]
return state_dict
-@register_model_architecture('masked_lm', 'masked_lm')
+@register_model_architecture("masked_lm", "masked_lm")
def base_architecture(args):
- args.dropout = getattr(args, 'dropout', 0.1)
- args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
- args.act_dropout = getattr(args, 'act_dropout', 0.0)
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+ args.act_dropout = getattr(args, "act_dropout", 0.0)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
- args.encoder_layers = getattr(args, 'encoder_layers', 6)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
+ args.encoder_layers = getattr(args, "encoder_layers", 6)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
- args.share_encoder_input_output_embed = getattr(args, 'share_encoder_input_output_embed', False)
- args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False)
- args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
- args.num_segment = getattr(args, 'num_segment', 2)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
+ args.share_encoder_input_output_embed = getattr(
+ args, "share_encoder_input_output_embed", False
+ )
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
+ args.no_token_positional_embeddings = getattr(
+ args, "no_token_positional_embeddings", False
+ )
+ args.num_segment = getattr(args, "num_segment", 2)
- args.sentence_class_num = getattr(args, 'sentence_class_num', 2)
- args.sent_loss = getattr(args, 'sent_loss', False)
+ args.sentence_class_num = getattr(args, "sentence_class_num", 2)
+ args.sent_loss = getattr(args, "sent_loss", False)
- args.apply_bert_init = getattr(args, 'apply_bert_init', False)
+ args.apply_bert_init = getattr(args, "apply_bert_init", False)
- args.activation_fn = getattr(args, 'activation_fn', 'relu')
- args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
- args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
+ args.activation_fn = getattr(args, "activation_fn", "relu")
+ args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
-@register_model_architecture('masked_lm', 'bert_base')
+@register_model_architecture("masked_lm", "bert_base")
def bert_base_architecture(args):
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
args.share_encoder_input_output_embed = getattr(
- args, 'share_encoder_input_output_embed', True)
+ args, "share_encoder_input_output_embed", True
+ )
args.no_token_positional_embeddings = getattr(
- args, 'no_token_positional_embeddings', False)
- args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True)
- args.num_segment = getattr(args, 'num_segment', 2)
+ args, "no_token_positional_embeddings", False
+ )
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
+ args.num_segment = getattr(args, "num_segment", 2)
- args.encoder_layers = getattr(args, 'encoder_layers', 12)
+ args.encoder_layers = getattr(args, "encoder_layers", 12)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
- args.sentence_class_num = getattr(args, 'sentence_class_num', 2)
- args.sent_loss = getattr(args, 'sent_loss', True)
+ args.sentence_class_num = getattr(args, "sentence_class_num", 2)
+ args.sent_loss = getattr(args, "sent_loss", True)
- args.apply_bert_init = getattr(args, 'apply_bert_init', True)
+ args.apply_bert_init = getattr(args, "apply_bert_init", True)
- args.activation_fn = getattr(args, 'activation_fn', 'gelu')
- args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
- args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True)
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
+ args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
base_architecture(args)
-@register_model_architecture('masked_lm', 'bert_large')
+@register_model_architecture("masked_lm", "bert_large")
def bert_large_architecture(args):
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
- args.encoder_layers = getattr(args, 'encoder_layers', 24)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
+ args.encoder_layers = getattr(args, "encoder_layers", 24)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
bert_base_architecture(args)
-@register_model_architecture('masked_lm', 'xlm_base')
+@register_model_architecture("masked_lm", "xlm_base")
def xlm_architecture(args):
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.share_encoder_input_output_embed = getattr(
- args, 'share_encoder_input_output_embed', True)
+ args, "share_encoder_input_output_embed", True
+ )
args.no_token_positional_embeddings = getattr(
- args, 'no_token_positional_embeddings', False)
- args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True)
- args.num_segment = getattr(args, 'num_segment', 1)
+ args, "no_token_positional_embeddings", False
+ )
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
+ args.num_segment = getattr(args, "num_segment", 1)
- args.encoder_layers = getattr(args, 'encoder_layers', 6)
+ args.encoder_layers = getattr(args, "encoder_layers", 6)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
- args.sent_loss = getattr(args, 'sent_loss', False)
+ args.sent_loss = getattr(args, "sent_loss", False)
- args.activation_fn = getattr(args, 'activation_fn', 'gelu')
- args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
- args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
- args.apply_bert_init = getattr(args, 'apply_bert_init', True)
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+ args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
+ args.apply_bert_init = getattr(args, "apply_bert_init", True)
base_architecture(args)
diff --git a/fairseq/models/model_utils.py b/fairseq/models/model_utils.py
index 46ec62f772..732d66b1d5 100644
--- a/fairseq/models/model_utils.py
+++ b/fairseq/models/model_utils.py
@@ -60,7 +60,9 @@ def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor:
@torch.jit.script
-def fill_tensors(x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int) -> Optional[Tensor]:
+def fill_tensors(
+ x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int
+) -> Optional[Tensor]:
"""
Filling tensor x with y at masked positions (dim=0).
"""
@@ -82,9 +84,9 @@ def fill_tensors(x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: in
elif x.size(1) > y.size(1):
x[mask] = torch.tensor(padding_idx).type_as(x)
if x.dim() == 2:
- x[mask, :y.size(1)] = y
+ x[mask, : y.size(1)] = y
else:
- x[mask, :y.size(1), :] = y
+ x[mask, : y.size(1), :] = y
else:
x[mask] = y
return x
diff --git a/fairseq/models/multilingual_transformer.py b/fairseq/models/multilingual_transformer.py
index 91a413753c..e3fbbd5710 100644
--- a/fairseq/models/multilingual_transformer.py
+++ b/fairseq/models/multilingual_transformer.py
@@ -12,15 +12,15 @@
register_model_architecture,
)
from fairseq.models.transformer import (
- base_architecture,
Embedding,
- TransformerModel,
- TransformerEncoder,
TransformerDecoder,
+ TransformerEncoder,
+ TransformerModel,
+ base_architecture,
)
-@register_model('multilingual_transformer')
+@register_model("multilingual_transformer")
class MultilingualTransformerModel(FairseqMultiModel):
"""Train Transformer models for multiple language pairs simultaneously.
@@ -44,31 +44,44 @@ def __init__(self, encoders, decoders):
def add_args(parser):
"""Add model-specific arguments to the parser."""
TransformerModel.add_args(parser)
- parser.add_argument('--share-encoder-embeddings', action='store_true',
- help='share encoder embeddings across languages')
- parser.add_argument('--share-decoder-embeddings', action='store_true',
- help='share decoder embeddings across languages')
- parser.add_argument('--share-encoders', action='store_true',
- help='share encoders across languages')
- parser.add_argument('--share-decoders', action='store_true',
- help='share decoders across languages')
+ parser.add_argument(
+ "--share-encoder-embeddings",
+ action="store_true",
+ help="share encoder embeddings across languages",
+ )
+ parser.add_argument(
+ "--share-decoder-embeddings",
+ action="store_true",
+ help="share decoder embeddings across languages",
+ )
+ parser.add_argument(
+ "--share-encoders",
+ action="store_true",
+ help="share encoders across languages",
+ )
+ parser.add_argument(
+ "--share-decoders",
+ action="store_true",
+ help="share decoders across languages",
+ )
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
+
assert isinstance(task, MultilingualTranslationTask)
# make sure all arguments are present in older models
base_multilingual_architecture(args)
- if not hasattr(args, 'max_source_positions'):
+ if not hasattr(args, "max_source_positions"):
args.max_source_positions = 1024
- if not hasattr(args, 'max_target_positions'):
+ if not hasattr(args, "max_target_positions"):
args.max_target_positions = 1024
- src_langs = [lang_pair.split('-')[0] for lang_pair in task.model_lang_pairs]
- tgt_langs = [lang_pair.split('-')[1] for lang_pair in task.model_lang_pairs]
+ src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs]
+ tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs]
if args.share_encoders:
args.share_encoder_embeddings = True
@@ -90,10 +103,14 @@ def build_embedding(dictionary, embed_dim, path=None):
if args.share_all_embeddings:
if args.encoder_embed_dim != args.decoder_embed_dim:
raise ValueError(
- '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
+ "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
+ )
if args.decoder_embed_path and (
- args.decoder_embed_path != args.encoder_embed_path):
- raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path')
+ args.decoder_embed_path != args.encoder_embed_path
+ ):
+ raise ValueError(
+ "--share-all-embeddings not compatible with --decoder-embed-path"
+ )
shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
dicts=task.dicts,
langs=task.langs,
@@ -105,24 +122,20 @@ def build_embedding(dictionary, embed_dim, path=None):
args.share_decoder_input_output_embed = True
else:
if args.share_encoder_embeddings:
- shared_encoder_embed_tokens = (
- FairseqMultiModel.build_shared_embeddings(
- dicts=task.dicts,
- langs=src_langs,
- embed_dim=args.encoder_embed_dim,
- build_embedding=build_embedding,
- pretrained_embed_path=args.encoder_embed_path,
- )
+ shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
+ dicts=task.dicts,
+ langs=src_langs,
+ embed_dim=args.encoder_embed_dim,
+ build_embedding=build_embedding,
+ pretrained_embed_path=args.encoder_embed_path,
)
if args.share_decoder_embeddings:
- shared_decoder_embed_tokens = (
- FairseqMultiModel.build_shared_embeddings(
- dicts=task.dicts,
- langs=tgt_langs,
- embed_dim=args.decoder_embed_dim,
- build_embedding=build_embedding,
- pretrained_embed_path=args.decoder_embed_path,
- )
+ shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
+ dicts=task.dicts,
+ langs=tgt_langs,
+ embed_dim=args.decoder_embed_dim,
+ build_embedding=build_embedding,
+ pretrained_embed_path=args.decoder_embed_path,
)
# encoders/decoders for each language
@@ -134,10 +147,13 @@ def get_encoder(lang):
encoder_embed_tokens = shared_encoder_embed_tokens
else:
encoder_embed_tokens = build_embedding(
- task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path
+ task.dicts[lang],
+ args.encoder_embed_dim,
+ args.encoder_embed_path,
)
lang_encoders[lang] = cls._get_module_class(
- True, args, task.dicts[lang], encoder_embed_tokens, src_langs)
+ True, args, task.dicts[lang], encoder_embed_tokens, src_langs
+ )
return lang_encoders[lang]
def get_decoder(lang):
@@ -146,10 +162,13 @@ def get_decoder(lang):
decoder_embed_tokens = shared_decoder_embed_tokens
else:
decoder_embed_tokens = build_embedding(
- task.dicts[lang], args.decoder_embed_dim, args.decoder_embed_path
+ task.dicts[lang],
+ args.decoder_embed_dim,
+ args.decoder_embed_path,
)
lang_decoders[lang] = cls._get_module_class(
- False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs)
+ False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs
+ )
return lang_decoders[lang]
# shared encoders/decoders (if applicable)
@@ -161,8 +180,12 @@ def get_decoder(lang):
encoders, decoders = OrderedDict(), OrderedDict()
for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
- encoders[lang_pair] = shared_encoder if shared_encoder is not None else get_encoder(src)
- decoders[lang_pair] = shared_decoder if shared_decoder is not None else get_decoder(tgt)
+ encoders[lang_pair] = (
+ shared_encoder if shared_encoder is not None else get_encoder(src)
+ )
+ decoders[lang_pair] = (
+ shared_decoder if shared_decoder is not None else get_decoder(tgt)
+ )
return MultilingualTransformerModel(encoders, decoders)
@@ -174,30 +197,32 @@ def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
def load_state_dict(self, state_dict, strict=True, args=None):
state_dict_subset = state_dict.copy()
for k, _ in state_dict.items():
- assert k.startswith('models.')
- lang_pair = k.split('.')[1]
+ assert k.startswith("models.")
+ lang_pair = k.split(".")[1]
if lang_pair not in self.models:
del state_dict_subset[k]
super().load_state_dict(state_dict_subset, strict=strict, args=args)
-@register_model_architecture('multilingual_transformer', 'multilingual_transformer')
+@register_model_architecture("multilingual_transformer", "multilingual_transformer")
def base_multilingual_architecture(args):
base_architecture(args)
- args.share_encoder_embeddings = getattr(args, 'share_encoder_embeddings', False)
- args.share_decoder_embeddings = getattr(args, 'share_decoder_embeddings', False)
- args.share_encoders = getattr(args, 'share_encoders', False)
- args.share_decoders = getattr(args, 'share_decoders', False)
+ args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", False)
+ args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", False)
+ args.share_encoders = getattr(args, "share_encoders", False)
+ args.share_decoders = getattr(args, "share_decoders", False)
-@register_model_architecture('multilingual_transformer', 'multilingual_transformer_iwslt_de_en')
+@register_model_architecture(
+ "multilingual_transformer", "multilingual_transformer_iwslt_de_en"
+)
def multilingual_transformer_iwslt_de_en(args):
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
- args.encoder_layers = getattr(args, 'encoder_layers', 6)
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
- args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
- args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
- args.decoder_layers = getattr(args, 'decoder_layers', 6)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
+ args.encoder_layers = getattr(args, "encoder_layers", 6)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
base_multilingual_architecture(args)
diff --git a/fairseq/models/nat/cmlm_transformer.py b/fairseq/models/nat/cmlm_transformer.py
index 86c770569d..c876e9453c 100644
--- a/fairseq/models/nat/cmlm_transformer.py
+++ b/fairseq/models/nat/cmlm_transformer.py
@@ -38,26 +38,34 @@ def forward(
# encoding
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
# length prediction
- length_out = self.decoder.forward_length(normalize=False, encoder_out=encoder_out)
- length_tgt = self.decoder.forward_length_prediction(length_out, encoder_out, tgt_tokens)
+ length_out = self.decoder.forward_length(
+ normalize=False, encoder_out=encoder_out
+ )
+ length_tgt = self.decoder.forward_length_prediction(
+ length_out, encoder_out, tgt_tokens
+ )
# decoding
word_ins_out = self.decoder(
normalize=False,
prev_output_tokens=prev_output_tokens,
- encoder_out=encoder_out)
+ encoder_out=encoder_out,
+ )
word_ins_mask = prev_output_tokens.eq(self.unk)
return {
"word_ins": {
- "out": word_ins_out, "tgt": tgt_tokens,
- "mask": word_ins_mask, "ls": self.args.label_smoothing,
- "nll_loss": True
+ "out": word_ins_out,
+ "tgt": tgt_tokens,
+ "mask": word_ins_mask,
+ "ls": self.args.label_smoothing,
+ "nll_loss": True,
},
"length": {
- "out": length_out, "tgt": length_tgt,
- "factor": self.decoder.length_loss_factor
- }
+ "out": length_out,
+ "tgt": length_tgt,
+ "factor": self.decoder.length_loss_factor,
+ },
}
def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
@@ -98,7 +106,7 @@ def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwar
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
- history=history
+ history=history,
)
diff --git a/fairseq/models/nat/fairseq_nat_model.py b/fairseq/models/nat/fairseq_nat_model.py
index d37a234ba9..1dbc29d0f4 100644
--- a/fairseq/models/nat/fairseq_nat_model.py
+++ b/fairseq/models/nat/fairseq_nat_model.py
@@ -4,9 +4,13 @@
# LICENSE file in the root directory of this source tree.
import math
-import torch
-from fairseq.models.transformer import TransformerModel, TransformerEncoder, TransformerDecoder
+import torch
+from fairseq.models.transformer import (
+ TransformerDecoder,
+ TransformerEncoder,
+ TransformerModel,
+)
from fairseq.modules.transformer_sentence_encoder import init_bert_params
@@ -22,22 +26,31 @@ def stack(key):
return torch.stack(outs, -1) if outs[0] is not None else None
return _encoder_out._replace(
- encoder_out=stack('encoder_out'),
- encoder_embedding=stack('encoder_embedding'),
- encoder_states=stack('encoder_states')
+ encoder_out=stack("encoder_out"),
+ encoder_embedding=stack("encoder_embedding"),
+ encoder_states=stack("encoder_states"),
)
+
return wrapper
def ensemble_decoder(func):
def wrapper(self, normalize=False, encoder_out=None, *args, **kwargs):
if self.ensemble_models is None or len(self.ensemble_models) == 1:
- return func(self, normalize=normalize, encoder_out=encoder_out, *args, **kwargs)
+ return func(
+ self, normalize=normalize, encoder_out=encoder_out, *args, **kwargs
+ )
action_outs = [
- func(model, normalize=normalize, encoder_out=encoder_out._replace(
- encoder_out=encoder_out.encoder_out[:, :, :, i]
- ), *args, **kwargs)
+ func(
+ model,
+ normalize=normalize,
+ encoder_out=encoder_out._replace(
+ encoder_out=encoder_out.encoder_out[:, :, :, i]
+ ),
+ *args,
+ **kwargs
+ )
for i, model in enumerate(self.ensemble_models)
]
@@ -51,19 +64,19 @@ def wrapper(self, normalize=False, encoder_out=None, *args, **kwargs):
if i == 0 and normalize:
ensembled_outs += [
torch.logsumexp(
- torch.stack([a[i] for a in action_outs], -1),
- dim=-1) - math.log(len(self.ensemble_models))
+ torch.stack([a[i] for a in action_outs], -1), dim=-1
+ )
+ - math.log(len(self.ensemble_models))
]
elif action_outs[0][i] is not None:
- ensembled_outs += [
- torch.stack([a[i] for a in action_outs], -1)
- ]
+ ensembled_outs += [torch.stack([a[i] for a in action_outs], -1)]
else:
ensembled_outs += [None]
if len(ensembled_outs) == 1:
return ensembled_outs[0]
return tuple(ensembled_outs)
+
return wrapper
@@ -71,6 +84,7 @@ class FairseqNATModel(TransformerModel):
"""
Abstract class for all nonautoregressive-based models
"""
+
def __init__(self, args, encoder, decoder):
super().__init__(args, encoder, decoder)
self.tgt_dict = decoder.dictionary
diff --git a/fairseq/models/nat/insertion_transformer.py b/fairseq/models/nat/insertion_transformer.py
index a5f3c1abc5..bc28000f59 100644
--- a/fairseq/models/nat/insertion_transformer.py
+++ b/fairseq/models/nat/insertion_transformer.py
@@ -6,17 +6,16 @@
import numpy as np
import torch
import torch.nn.functional as F
-
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nat import (
+ FairseqNATModel,
LevenshteinTransformerDecoder,
LevenshteinTransformerModel,
- FairseqNATModel,
- ensemble_decoder
+ ensemble_decoder,
)
from fairseq.models.transformer import Linear
-from fairseq.utils import new_arange
from fairseq.modules.transformer_sentence_encoder import init_bert_params
+from fairseq.utils import new_arange
class NegativeDistanceScore(object):
@@ -58,7 +57,8 @@ def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx, vocab_size, ta
from fairseq import libnat
except ImportError as e:
import sys
- sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
+
+ sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
raise e
B = in_tokens.size(0)
@@ -147,7 +147,7 @@ def forward(
word_ins_out = self.decoder.forward_word_ins(
normalize=False,
prev_output_tokens=prev_output_tokens,
- encoder_out=encoder_out
+ encoder_out=encoder_out,
)
word_ins_tgt = _get_ins_targets(
@@ -162,9 +162,11 @@ def forward(
return {
"word_ins": {
- "out": word_ins_out, "tgt": word_ins_tgt,
- "mask": word_ins_masks, "ls": self.args.label_smoothing,
- "nll_loss": True
+ "out": word_ins_out,
+ "tgt": word_ins_tgt,
+ "mask": word_ins_masks,
+ "ls": self.args.label_smoothing,
+ "nll_loss": True,
}
}
@@ -178,9 +180,7 @@ def forward_decoder(
# TODO: decoding for InsertionTransformer
word_ins_score = self.decoder.forward_word_ins(
- normalize=True,
- prev_output_tokens=output_tokens,
- encoder_out=encoder_out
+ normalize=True, prev_output_tokens=output_tokens, encoder_out=encoder_out
)
if eos_penalty > 0.0:
@@ -202,7 +202,7 @@ def forward_decoder(
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
- history=history
+ history=history,
)
diff --git a/fairseq/models/nat/iterative_nonautoregressive_transformer.py b/fairseq/models/nat/iterative_nonautoregressive_transformer.py
index dc340c387d..bc39509980 100644
--- a/fairseq/models/nat/iterative_nonautoregressive_transformer.py
+++ b/fairseq/models/nat/iterative_nonautoregressive_transformer.py
@@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.
import torch
-
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nat import NATransformerModel
@@ -44,8 +43,16 @@ def _sequential_poisoning(s, V, beta=0.33, bos=2, eos=3, pad=1):
def gumbel_noise(input, TINY=1e-8):
- return input.new_zeros(*input.size()).uniform_().add_(
- TINY).log_().neg_().add_(TINY).log_().neg_()
+ return (
+ input.new_zeros(*input.size())
+ .uniform_()
+ .add_(TINY)
+ .log_()
+ .neg_()
+ .add_(TINY)
+ .log_()
+ .neg_()
+ )
@register_model("iterative_nonautoregressive_transformer")
@@ -53,12 +60,21 @@ class IterNATransformerModel(NATransformerModel):
@staticmethod
def add_args(parser):
NATransformerModel.add_args(parser)
- parser.add_argument("--train-step", type=int,
- help="number of refinement iterations during training")
- parser.add_argument("--dae-ratio", type=float,
- help="the probability of switching to the denoising auto-encoder loss")
- parser.add_argument("--stochastic-approx", action="store_true",
- help="sampling from the decoder as the inputs for next iteration")
+ parser.add_argument(
+ "--train-step",
+ type=int,
+ help="number of refinement iterations during training",
+ )
+ parser.add_argument(
+ "--dae-ratio",
+ type=float,
+ help="the probability of switching to the denoising auto-encoder loss",
+ )
+ parser.add_argument(
+ "--stochastic-approx",
+ action="store_true",
+ help="sampling from the decoder as the inputs for next iteration",
+ )
@classmethod
def build_model(cls, args, task):
@@ -78,14 +94,18 @@ def forward(
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
# length prediction
- length_out = self.decoder.forward_length(normalize=False, encoder_out=encoder_out)
- length_tgt = self.decoder.forward_length_prediction(length_out, encoder_out, tgt_tokens)
+ length_out = self.decoder.forward_length(
+ normalize=False, encoder_out=encoder_out
+ )
+ length_tgt = self.decoder.forward_length_prediction(
+ length_out, encoder_out, tgt_tokens
+ )
# decoding
word_ins_outs, word_ins_tgts, word_ins_masks = [], [], []
for t in range(self.train_step):
word_ins_out = self.decoder(
- normalize=False,
+ normalize=False,
prev_output_tokens=prev_output_tokens,
encoder_out=encoder_out,
step=t,
@@ -133,14 +153,17 @@ def forward(
return {
"word_ins": {
- "out": word_ins_out, "tgt": word_ins_tgt,
- "mask": word_ins_mask, "ls": self.args.label_smoothing,
- "nll_loss": True
+ "out": word_ins_out,
+ "tgt": word_ins_tgt,
+ "mask": word_ins_mask,
+ "ls": self.args.label_smoothing,
+ "nll_loss": True,
},
"length": {
- "out": length_out, "tgt": length_tgt,
- "factor": self.decoder.length_loss_factor
- }
+ "out": length_out,
+ "tgt": length_tgt,
+ "factor": self.decoder.length_loss_factor,
+ },
}
diff --git a/fairseq/models/nat/levenshtein_transformer.py b/fairseq/models/nat/levenshtein_transformer.py
index e1748145c3..f7a3f003ca 100644
--- a/fairseq/models/nat/levenshtein_transformer.py
+++ b/fairseq/models/nat/levenshtein_transformer.py
@@ -6,33 +6,26 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq.iterative_refinement_generator import DecoderOut
from fairseq.models import register_model, register_model_architecture
-from fairseq.models.transformer import (
- Embedding,
- TransformerDecoderLayer
-)
-
-from fairseq.models.nat import (
- FairseqNATModel,
- FairseqNATDecoder,
- ensemble_decoder
-)
-
+from fairseq.models.nat import FairseqNATDecoder, FairseqNATModel, ensemble_decoder
+from fairseq.models.transformer import Embedding, TransformerDecoderLayer
from fairseq.modules.transformer_sentence_encoder import init_bert_params
-
from .levenshtein_utils import (
- _skip, _skip_encoder_out, _fill,
- _get_ins_targets, _get_del_targets,
- _apply_ins_masks, _apply_ins_words, _apply_del_words
+ _apply_del_words,
+ _apply_ins_masks,
+ _apply_ins_words,
+ _fill,
+ _get_del_targets,
+ _get_ins_targets,
+ _skip,
+ _skip_encoder_out,
)
@register_model("levenshtein_transformer")
class LevenshteinTransformerModel(FairseqNATModel):
-
@property
def allow_length_beam(self):
return False
@@ -63,8 +56,8 @@ def add_args(parser):
)
parser.add_argument(
"--sampling-for-deletion",
- action='store_true',
- help='instead of argmax, use sampling to predict the tokens'
+ action="store_true",
+ help="instead of argmax, use sampling to predict the tokens",
)
@classmethod
@@ -93,19 +86,19 @@ def forward(
mask_ins_out, _ = self.decoder.forward_mask_ins(
normalize=False,
prev_output_tokens=prev_output_tokens,
- encoder_out=encoder_out
+ encoder_out=encoder_out,
)
word_ins_out, _ = self.decoder.forward_word_ins(
normalize=False,
prev_output_tokens=masked_tgt_tokens,
- encoder_out=encoder_out
+ encoder_out=encoder_out,
)
# make online prediction
if self.decoder.sampling_for_deletion:
word_predictions = torch.multinomial(
- F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1).view(
- word_ins_out.size(0), -1)
+ F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1
+ ).view(word_ins_out.size(0), -1)
else:
word_predictions = F.log_softmax(word_ins_out, dim=-1).max(2)[1]
@@ -118,23 +111,29 @@ def forward(
word_del_out, _ = self.decoder.forward_word_del(
normalize=False,
prev_output_tokens=word_predictions,
- encoder_out=encoder_out)
+ encoder_out=encoder_out,
+ )
word_del_masks = word_predictions.ne(self.pad)
return {
"mask_ins": {
- "out": mask_ins_out, "tgt": mask_ins_targets,
- "mask": mask_ins_masks, "ls": 0.01,
+ "out": mask_ins_out,
+ "tgt": mask_ins_targets,
+ "mask": mask_ins_masks,
+ "ls": 0.01,
},
"word_ins": {
- "out": word_ins_out, "tgt": tgt_tokens,
- "mask": masked_tgt_masks, "ls": self.args.label_smoothing,
- "nll_loss": True
+ "out": word_ins_out,
+ "tgt": tgt_tokens,
+ "mask": masked_tgt_masks,
+ "ls": self.args.label_smoothing,
+ "nll_loss": True,
},
"word_del": {
- "out": word_del_out, "tgt": word_del_targets,
- "mask": word_del_masks
- }
+ "out": word_del_out,
+ "tgt": word_del_targets,
+ "mask": word_del_masks,
+ },
}
def forward_decoder(
@@ -164,7 +163,7 @@ def forward_decoder(
word_del_score, word_del_attn = self.decoder.forward_word_del(
normalize=True,
prev_output_tokens=_skip(output_tokens, can_del_word),
- encoder_out=_skip_encoder_out(self.encoder, encoder_out, can_del_word)
+ encoder_out=_skip_encoder_out(self.encoder, encoder_out, can_del_word),
)
word_del_pred = word_del_score.max(-1)[1].bool()
@@ -179,7 +178,7 @@ def forward_decoder(
)
output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
output_scores = _fill(output_scores, can_del_word, _scores, 0)
- attn = _fill(attn, can_del_word, _attn, 0.)
+ attn = _fill(attn, can_del_word, _attn, 0.0)
if history is not None:
history.append(output_tokens.clone())
@@ -190,7 +189,7 @@ def forward_decoder(
mask_ins_score, _ = self.decoder.forward_mask_ins(
normalize=True,
prev_output_tokens=_skip(output_tokens, can_ins_mask),
- encoder_out=_skip_encoder_out(self.encoder, encoder_out, can_ins_mask)
+ encoder_out=_skip_encoder_out(self.encoder, encoder_out, can_ins_mask),
)
if eos_penalty > 0.0:
mask_ins_score[:, :, 0] = mask_ins_score[:, :, 0] - eos_penalty
@@ -219,7 +218,7 @@ def forward_decoder(
word_ins_score, word_ins_attn = self.decoder.forward_word_ins(
normalize=True,
prev_output_tokens=_skip(output_tokens, can_ins_word),
- encoder_out=_skip_encoder_out(self.encoder, encoder_out, can_ins_word)
+ encoder_out=_skip_encoder_out(self.encoder, encoder_out, can_ins_word),
)
word_ins_score, word_ins_pred = word_ins_score.max(-1)
_tokens, _scores = _apply_ins_words(
@@ -232,7 +231,7 @@ def forward_decoder(
output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
output_scores = _fill(output_scores, can_ins_word, _scores, 0)
- attn = _fill(attn, can_ins_word, word_ins_attn, 0.)
+ attn = _fill(attn, can_ins_word, word_ins_attn, 0.0)
if history is not None:
history.append(output_tokens.clone())
@@ -247,7 +246,7 @@ def forward_decoder(
output_tokens=output_tokens,
output_scores=output_scores,
attn=attn,
- history=history
+ history=history,
)
def initialize_output_tokens(self, encoder_out, src_tokens):
@@ -265,7 +264,7 @@ def initialize_output_tokens(self, encoder_out, src_tokens):
attn=None,
step=0,
max_step=0,
- history=None
+ history=None,
)
@@ -283,29 +282,40 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
self.embed_word_del = Embedding(2, self.output_embed_dim, None)
# del_word, ins_mask, ins_word
- self.early_exit = [int(i) for i in args.early_exit.split(',')]
+ self.early_exit = [int(i) for i in args.early_exit.split(",")]
assert len(self.early_exit) == 3
# copy layers for mask-predict/deletion
self.layers_msk = None
if getattr(args, "no_share_maskpredictor", False):
- self.layers_msk = nn.ModuleList([
- TransformerDecoderLayer(args, no_encoder_attn)
- for _ in range(self.early_exit[1])
- ])
+ self.layers_msk = nn.ModuleList(
+ [
+ TransformerDecoderLayer(args, no_encoder_attn)
+ for _ in range(self.early_exit[1])
+ ]
+ )
self.layers_del = None
if getattr(args, "no_share_discriminator", False):
- self.layers_del = nn.ModuleList([
- TransformerDecoderLayer(args, no_encoder_attn)
- for _ in range(self.early_exit[0])
- ])
+ self.layers_del = nn.ModuleList(
+ [
+ TransformerDecoderLayer(args, no_encoder_attn)
+ for _ in range(self.early_exit[0])
+ ]
+ )
if getattr(args, "share_discriminator_maskpredictor", False):
- assert getattr(args, "no_share_discriminator", False), "must set saperate discriminator"
+ assert getattr(
+ args, "no_share_discriminator", False
+ ), "must set saperate discriminator"
self.layers_msk = self.layers_del
def extract_features(
- self, prev_output_tokens, encoder_out=None, early_exit=None, layers=None, **unused
+ self,
+ prev_output_tokens,
+ encoder_out=None,
+ early_exit=None,
+ layers=None,
+ **unused
):
"""
Similar to *forward* but only return features.
@@ -344,7 +354,7 @@ def extract_features(
decoder_padding_mask = prev_output_tokens.eq(self.padding_idx)
layers = self.layers if layers is None else layers
early_exit = len(layers) if early_exit is None else early_exit
- for _, layer in enumerate(layers[: early_exit]):
+ for _, layer in enumerate(layers[:early_exit]):
x, attn, _ = layer(
x,
encoder_out.encoder_out if encoder_out is not None else None,
@@ -368,33 +378,45 @@ def extract_features(
@ensemble_decoder
def forward_mask_ins(self, normalize, encoder_out, prev_output_tokens, **unused):
features, extra = self.extract_features(
- prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[1], layers=self.layers_msk, **unused
+ prev_output_tokens,
+ encoder_out=encoder_out,
+ early_exit=self.early_exit[1],
+ layers=self.layers_msk,
+ **unused
)
features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
decoder_out = F.linear(features_cat, self.embed_mask_ins.weight)
if normalize:
- return F.log_softmax(decoder_out, -1), extra['attn']
- return decoder_out, extra['attn']
+ return F.log_softmax(decoder_out, -1), extra["attn"]
+ return decoder_out, extra["attn"]
@ensemble_decoder
def forward_word_ins(self, normalize, encoder_out, prev_output_tokens, **unused):
features, extra = self.extract_features(
- prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[2], layers=self.layers, **unused
+ prev_output_tokens,
+ encoder_out=encoder_out,
+ early_exit=self.early_exit[2],
+ layers=self.layers,
+ **unused
)
decoder_out = self.output_layer(features)
if normalize:
- return F.log_softmax(decoder_out, -1), extra['attn']
- return decoder_out, extra['attn']
+ return F.log_softmax(decoder_out, -1), extra["attn"]
+ return decoder_out, extra["attn"]
@ensemble_decoder
def forward_word_del(self, normalize, encoder_out, prev_output_tokens, **unused):
features, extra = self.extract_features(
- prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[0], layers=self.layers_del, **unused
+ prev_output_tokens,
+ encoder_out=encoder_out,
+ early_exit=self.early_exit[0],
+ layers=self.layers_del,
+ **unused
)
decoder_out = F.linear(features, self.embed_word_del.weight)
if normalize:
- return F.log_softmax(decoder_out, -1), extra['attn']
- return decoder_out, extra['attn']
+ return F.log_softmax(decoder_out, -1), extra["attn"]
+ return decoder_out, extra["attn"]
@register_model_architecture("levenshtein_transformer", "levenshtein_transformer")
@@ -439,7 +461,9 @@ def levenshtein_base_architecture(args):
args.early_exit = getattr(args, "early_exit", "6,6,6")
args.no_share_discriminator = getattr(args, "no_share_discriminator", False)
args.no_share_maskpredictor = getattr(args, "no_share_maskpredictor", False)
- args.share_discriminator_maskpredictor = getattr(args, "share_discriminator_maskpredictor", False)
+ args.share_discriminator_maskpredictor = getattr(
+ args, "share_discriminator_maskpredictor", False
+ )
args.no_share_last_layer = getattr(args, "no_share_last_layer", False)
diff --git a/fairseq/models/nat/levenshtein_utils.py b/fairseq/models/nat/levenshtein_utils.py
index 11fb29578b..375a98c2e1 100644
--- a/fairseq/models/nat/levenshtein_utils.py
+++ b/fairseq/models/nat/levenshtein_utils.py
@@ -9,21 +9,27 @@
# -------------- Helper Functions --------------------------------------------------- #
+
def load_libnat():
try:
from fairseq import libnat_cuda
+
return libnat_cuda, True
except ImportError as e:
- print(str(e) + '... fall back to CPU version')
+ print(str(e) + "... fall back to CPU version")
try:
from fairseq import libnat
+
return libnat, False
except ImportError as e:
import sys
- sys.stderr.write("ERROR: missing libnat_cuda. run `python setup.py build_ext --inplace`\n")
+
+ sys.stderr.write(
+ "ERROR: missing libnat_cuda. run `python setup.py build_ext --inplace`\n"
+ )
raise e
@@ -34,14 +40,18 @@ def _get_ins_targets_cuda(in_tokens, out_tokens, padding_idx, unk_idx):
in_masks = in_tokens.ne(padding_idx)
out_masks = out_tokens.ne(padding_idx)
mask_ins_targets, masked_tgt_masks = libnat.generate_insertion_labels(
- out_tokens.int(), libnat.levenshtein_distance(
- in_tokens.int(), out_tokens.int(),
- in_masks.sum(1).int(), out_masks.sum(1).int()
- )
+ out_tokens.int(),
+ libnat.levenshtein_distance(
+ in_tokens.int(),
+ out_tokens.int(),
+ in_masks.sum(1).int(),
+ out_masks.sum(1).int(),
+ ),
)
masked_tgt_masks = masked_tgt_masks.bool() & out_masks
- mask_ins_targets = mask_ins_targets.type_as(
- in_tokens)[:, 1:in_masks.size(1)].masked_fill_(~in_masks[:, 1:], 0)
+ mask_ins_targets = mask_ins_targets.type_as(in_tokens)[
+ :, 1 : in_masks.size(1)
+ ].masked_fill_(~in_masks[:, 1:], 0)
masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
@@ -73,7 +83,8 @@ def _get_ins_targets_cpu(in_tokens, out_tokens, padding_idx, unk_idx):
mask_label + [0 for _ in range(out_seq_len - len(mask_label))]
)
mask_ins_targets = [
- mask_input[1:-1] + [0 for _ in range(in_seq_len - 1 - len(mask_input[1:-1]))]
+ mask_input[1:-1]
+ + [0 for _ in range(in_seq_len - 1 - len(mask_input[1:-1]))]
for mask_input in mask_inputs
]
@@ -100,18 +111,23 @@ def _get_del_targets_cuda(in_tokens, out_tokens, padding_idx):
word_del_targets = libnat.generate_deletion_labels(
in_tokens.int(),
libnat.levenshtein_distance(
- in_tokens.int(), out_tokens.int(),
- in_masks.sum(1).int(), out_masks.sum(1).int()
- )
+ in_tokens.int(),
+ out_tokens.int(),
+ in_masks.sum(1).int(),
+ out_masks.sum(1).int(),
+ ),
+ )
+ word_del_targets = word_del_targets.type_as(in_tokens).masked_fill_(
+ ~in_masks, 0
)
- word_del_targets = word_del_targets.type_as(in_tokens).masked_fill_(~in_masks, 0)
return word_del_targets
def _get_del_targets_cpu(in_tokens, out_tokens, padding_idx):
out_seq_len = out_tokens.size(1)
with torch.cuda.device_of(in_tokens):
in_tokens_list = [
- [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
+ [t for t in s if t != padding_idx]
+ for i, s in enumerate(in_tokens.tolist())
]
out_tokens_list = [
[t for t in s if t != padding_idx]
@@ -149,10 +165,7 @@ def _apply_ins_masks(
out_lengths = in_lengths + mask_ins_pred.sum(1)
out_max_len = out_lengths.max()
- out_masks = (
- new_arange(out_lengths, out_max_len)[None, :]
- < out_lengths[:, None]
- )
+ out_masks = new_arange(out_lengths, out_max_len)[None, :] < out_lengths[:, None]
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
out_tokens = (
@@ -173,9 +186,7 @@ def _apply_ins_masks(
return out_tokens, out_scores
-def _apply_ins_words(
- in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx
-):
+def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx):
word_ins_masks = in_tokens.eq(unk_idx)
out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
@@ -200,11 +211,7 @@ def _apply_del_words(
word_del_pred.masked_fill_(~in_masks, 1)
word_del_pred.masked_fill_(bos_eos_masks, 0)
- reordering = (
- new_arange(in_tokens)
- .masked_fill_(word_del_pred, max_len)
- .sort(1)[1]
- )
+ reordering = new_arange(in_tokens).masked_fill_(word_del_pred, max_len).sort(1)[1]
out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
@@ -216,7 +223,7 @@ def _apply_del_words(
if in_attn is not None:
_mask = word_del_pred[:, :, None].expand_as(in_attn)
_reordering = reordering[:, :, None].expand_as(in_attn)
- out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)
+ out_attn = in_attn.masked_fill(_mask, 0.0).gather(1, _reordering)
return out_tokens, out_scores, out_attn
@@ -250,7 +257,9 @@ def _skip_encoder_out(encoder, encoder_out, mask):
if not mask.any():
return encoder_out
else:
- return encoder.reorder_encoder_out(encoder_out, mask.nonzero(as_tuple=False).squeeze())
+ return encoder.reorder_encoder_out(
+ encoder_out, mask.nonzero(as_tuple=False).squeeze()
+ )
def _fill(x, mask, y, padding_idx):
@@ -276,9 +285,9 @@ def _fill(x, mask, y, padding_idx):
elif x.size(1) > y.size(1):
x[mask] = padding_idx
if x.dim() == 2:
- x[mask, :y.size(1)] = y
+ x[mask, : y.size(1)] = y
else:
- x[mask, :y.size(1), :] = y
+ x[mask, : y.size(1), :] = y
else:
x[mask] = y
return x
diff --git a/fairseq/models/nat/nat_crf_transformer.py b/fairseq/models/nat/nat_crf_transformer.py
index 8dd3a08f72..d4b3cd931c 100644
--- a/fairseq/models/nat/nat_crf_transformer.py
+++ b/fairseq/models/nat/nat_crf_transformer.py
@@ -4,8 +4,8 @@
# LICENSE file in the root directory of this source tree.
-from fairseq.models.nat import NATransformerModel, base_architecture
from fairseq.models import register_model, register_model_architecture
+from fairseq.models.nat import NATransformerModel, base_architecture
from fairseq.modules import DynamicCRF
@@ -16,7 +16,7 @@ def __init__(self, args, encoder, decoder):
self.crf_layer = DynamicCRF(
num_embedding=len(self.tgt_dict),
low_rank=args.crf_lowrank_approx,
- beam_size=args.crf_beam_approx
+ beam_size=args.crf_beam_approx,
)
@property
@@ -26,12 +26,21 @@ def allow_ensemble(self):
@staticmethod
def add_args(parser):
NATransformerModel.add_args(parser)
- parser.add_argument("--crf-lowrank-approx", type=int,
- help="the dimension of low-rank approximation of transition")
- parser.add_argument("--crf-beam-approx", type=int,
- help="the beam size for apporixmating the normalizing factor")
- parser.add_argument("--word-ins-loss-factor", type=float,
- help="weights on NAT loss used to co-training with CRF loss.")
+ parser.add_argument(
+ "--crf-lowrank-approx",
+ type=int,
+ help="the dimension of low-rank approximation of transition",
+ )
+ parser.add_argument(
+ "--crf-beam-approx",
+ type=int,
+ help="the beam size for apporixmating the normalizing factor",
+ )
+ parser.add_argument(
+ "--word-ins-loss-factor",
+ type=float,
+ help="weights on NAT loss used to co-training with CRF loss.",
+ )
def forward(
self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
@@ -40,14 +49,19 @@ def forward(
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
# length prediction
- length_out = self.decoder.forward_length(normalize=False, encoder_out=encoder_out)
- length_tgt = self.decoder.forward_length_prediction(length_out, encoder_out, tgt_tokens)
+ length_out = self.decoder.forward_length(
+ normalize=False, encoder_out=encoder_out
+ )
+ length_tgt = self.decoder.forward_length_prediction(
+ length_out, encoder_out, tgt_tokens
+ )
# decoding
word_ins_out = self.decoder(
normalize=False,
prev_output_tokens=prev_output_tokens,
- encoder_out=encoder_out)
+ encoder_out=encoder_out,
+ )
word_ins_tgt, word_ins_mask = tgt_tokens, tgt_tokens.ne(self.pad)
# compute the log-likelihood of CRF
@@ -56,17 +70,19 @@ def forward(
return {
"word_ins": {
- "out": word_ins_out, "tgt": word_ins_tgt,
- "mask": word_ins_mask, "ls": self.args.label_smoothing,
- "nll_loss": True, "factor": self.args.word_ins_loss_factor
- },
- "word_crf": {
- "loss": crf_nll
+ "out": word_ins_out,
+ "tgt": word_ins_tgt,
+ "mask": word_ins_mask,
+ "ls": self.args.label_smoothing,
+ "nll_loss": True,
+ "factor": self.args.word_ins_loss_factor,
},
+ "word_crf": {"loss": crf_nll},
"length": {
- "out": length_out, "tgt": length_tgt,
- "factor": self.decoder.length_loss_factor
- }
+ "out": length_out,
+ "tgt": length_tgt,
+ "factor": self.decoder.length_loss_factor,
+ },
}
def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
@@ -77,9 +93,7 @@ def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwar
# execute the decoder and get emission scores
output_masks = output_tokens.ne(self.pad)
word_ins_out = self.decoder(
- normalize=False,
- prev_output_tokens=output_tokens,
- encoder_out=encoder_out
+ normalize=False, prev_output_tokens=output_tokens, encoder_out=encoder_out
)
# run viterbi decoding through CRF
@@ -93,7 +107,7 @@ def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwar
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
- history=history
+ history=history,
)
diff --git a/fairseq/models/nat/nonautoregressive_ensembles.py b/fairseq/models/nat/nonautoregressive_ensembles.py
index 2ed4d956e0..46bb8aac43 100644
--- a/fairseq/models/nat/nonautoregressive_ensembles.py
+++ b/fairseq/models/nat/nonautoregressive_ensembles.py
@@ -7,14 +7,13 @@
import torch
import torch.nn.functional as F
-
from fairseq.models.nat import (
+ _apply_del_words,
+ _apply_ins_masks,
+ _apply_ins_words,
_fill,
_skip,
_skip_encoder_out,
- _apply_ins_masks,
- _apply_ins_words,
- _apply_del_words,
)
@@ -43,7 +42,7 @@ def __init__(self, models):
self.encoder = _EnsembleModelEncoder(self.models)
def has_encoder(self):
- return hasattr(self.models[0], 'encoder')
+ return hasattr(self.models[0], "encoder")
def max_decoder_positions(self):
return min(m.max_decoder_positions() for m in self.models)
@@ -69,7 +68,9 @@ def __init__(self, models):
super().__init__(models)
@torch.no_grad()
- def forward_decoder(self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=None, **kwargs):
+ def forward_decoder(
+ self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=None, **kwargs
+ ):
# LevT ensembling
# A pipeline of three steps: deletion, placeholder, and word insertion.
# We need to average scores in each step in a pipeline way because of dependence.
@@ -83,7 +84,11 @@ def forward_decoder(self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=
max_lens = output_tokens.new().fill_(255)
else:
if encoder_outs[0].encoder_padding_mask is None:
- src_lens = encoder_outs[0].encoder_out.new(bsz).fill_(encoder_outs[0].encoder_out.size(1))
+ src_lens = (
+ encoder_outs[0]
+ .encoder_out.new(bsz)
+ .fill_(encoder_outs[0].encoder_out.size(1))
+ )
else:
src_lens = (~encoder_outs[0].encoder_padding_mask).sum(1)
max_lens = (src_lens * max_ratio).clamp(min=10).long()
@@ -104,13 +109,13 @@ def forward_decoder(self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=
can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
if can_ins_mask.sum() != 0:
output_tokens, output_scores = self.forward_mask_ins(
- encoder_outs,
- output_tokens,
- output_scores,
- can_ins_mask,
- eos_penalty,
- max_lens,
- )
+ encoder_outs,
+ output_tokens,
+ output_scores,
+ can_ins_mask,
+ eos_penalty,
+ max_lens,
+ )
# insert words
can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
@@ -132,10 +137,12 @@ def forward_decoder(self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=
output_tokens=output_tokens,
output_scores=output_scores,
attn=attn,
- history=None
+ history=None,
)
- def forward_word_del(self, encoder_outs, output_tokens, output_scores, attn, can_del_word):
+ def forward_word_del(
+ self, encoder_outs, output_tokens, output_scores, attn, can_del_word
+ ):
word_del_score_avg = []
word_del_attn_avg = []
for model, encoder_out in zip(self.models, encoder_outs):
@@ -146,10 +153,12 @@ def forward_word_del(self, encoder_outs, output_tokens, output_scores, attn, can
word_del_score = F.log_softmax(word_del_out, 2)
word_del_score_avg.append(word_del_score)
word_del_attn_avg.append(word_del_attn)
- word_del_score_avg = torch.logsumexp(torch.stack(word_del_score_avg, dim=0), dim=0) - math.log(len(self.models))
+ word_del_score_avg = torch.logsumexp(
+ torch.stack(word_del_score_avg, dim=0), dim=0
+ ) - math.log(len(self.models))
word_del_pred = word_del_score_avg.max(-1)[1].bool()
if word_del_attn_avg[0] is not None:
- word_del_attn_avg = torch.stack(word_del_attn_avg, dim=0)/len(self.models)
+ word_del_attn_avg = torch.stack(word_del_attn_avg, dim=0) / len(self.models)
else:
word_del_attn_avg = None
@@ -164,10 +173,18 @@ def forward_word_del(self, encoder_outs, output_tokens, output_scores, attn, can
)
output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
output_scores = _fill(output_scores, can_del_word, _scores, 0)
- attn = _fill(attn, can_del_word, _attn, 0.)
+ attn = _fill(attn, can_del_word, _attn, 0.0)
return output_tokens, output_scores, attn
- def forward_mask_ins(self, encoder_outs, output_tokens, output_scores, can_ins_mask, eos_penalty, max_lens):
+ def forward_mask_ins(
+ self,
+ encoder_outs,
+ output_tokens,
+ output_scores,
+ can_ins_mask,
+ eos_penalty,
+ max_lens,
+ ):
mask_ins_score_avg = []
for model, encoder_out in zip(self.models, encoder_outs):
mask_ins_out, _ = model.decoder.forward_mask_ins(
@@ -178,7 +195,9 @@ def forward_mask_ins(self, encoder_outs, output_tokens, output_scores, can_ins_m
if eos_penalty > 0.0:
mask_ins_score[:, :, 0] -= eos_penalty
mask_ins_score_avg.append(mask_ins_score)
- mask_ins_score_avg = torch.logsumexp(torch.stack(mask_ins_score_avg, dim=0), dim=0) - math.log(len(self.models))
+ mask_ins_score_avg = torch.logsumexp(
+ torch.stack(mask_ins_score_avg, dim=0), dim=0
+ ) - math.log(len(self.models))
mask_ins_pred = mask_ins_score_avg.max(-1)[1]
mask_ins_pred = torch.min(
mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
@@ -195,7 +214,9 @@ def forward_mask_ins(self, encoder_outs, output_tokens, output_scores, can_ins_m
output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
return output_tokens, output_scores
- def forward_word_ins(self, encoder_outs, output_tokens, output_scores, attn, can_ins_word):
+ def forward_word_ins(
+ self, encoder_outs, output_tokens, output_scores, attn, can_ins_word
+ ):
word_ins_score_avg = []
word_ins_attn_avg = []
for model, encoder_out in zip(self.models, encoder_outs):
@@ -206,9 +227,11 @@ def forward_word_ins(self, encoder_outs, output_tokens, output_scores, attn, can
word_ins_score = F.log_softmax(word_ins_out, 2)
word_ins_score_avg.append(word_ins_score)
word_ins_attn_avg.append(word_ins_attn)
- word_ins_score_avg = torch.logsumexp(torch.stack(word_ins_score_avg, dim=0), dim=0) - math.log(len(self.models))
+ word_ins_score_avg = torch.logsumexp(
+ torch.stack(word_ins_score_avg, dim=0), dim=0
+ ) - math.log(len(self.models))
if word_ins_attn_avg[0] is not None:
- word_ins_attn_avg = torch.stack(word_ins_attn_avg, dim=0)/len(self.models)
+ word_ins_attn_avg = torch.stack(word_ins_attn_avg, dim=0) / len(self.models)
else:
word_ins_attn_avg = None
word_ins_score_max, word_ins_pred = word_ins_score_avg.max(-1)
@@ -223,7 +246,7 @@ def forward_word_ins(self, encoder_outs, output_tokens, output_scores, attn, can
output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
output_scores = _fill(output_scores, can_ins_word, _scores, 0)
- attn = _fill(attn, can_ins_word, word_ins_attn, 0.)
+ attn = _fill(attn, can_ins_word, word_ins_attn, 0.0)
return output_tokens, output_scores, attn
def initialize_output_tokens(self, encoder_outs, src_tokens):
diff --git a/fairseq/models/nat/nonautoregressive_transformer.py b/fairseq/models/nat/nonautoregressive_transformer.py
index 050755c308..735297fc29 100644
--- a/fairseq/models/nat/nonautoregressive_transformer.py
+++ b/fairseq/models/nat/nonautoregressive_transformer.py
@@ -5,17 +5,11 @@
import torch
import torch.nn.functional as F
-
from fairseq import utils
from fairseq.iterative_refinement_generator import DecoderOut
from fairseq.models import register_model, register_model_architecture
+from fairseq.models.nat import FairseqNATDecoder, FairseqNATModel, ensemble_decoder
from fairseq.models.transformer import Embedding
-
-from fairseq.models.nat import (
- FairseqNATModel,
- FairseqNATDecoder,
- ensemble_decoder
-)
from fairseq.modules.transformer_sentence_encoder import init_bert_params
@@ -48,7 +42,6 @@ def _uniform_assignment(src_lens, trg_lens):
@register_model("nonautoregressive_transformer")
class NATransformerModel(FairseqNATModel):
-
@property
def allow_length_beam(self):
return True
@@ -58,14 +51,26 @@ def add_args(parser):
FairseqNATModel.add_args(parser)
# length prediction
- parser.add_argument("--src-embedding-copy", action="store_true",
- help="copy encoder word embeddings as the initial input of the decoder")
- parser.add_argument("--pred-length-offset", action="store_true",
- help="predicting the length difference between the target and source sentences")
- parser.add_argument("--sg-length-pred", action="store_true",
- help="stop the gradients back-propagated from the length predictor")
- parser.add_argument("--length-loss-factor", type=float,
- help="weights on the length prediction loss")
+ parser.add_argument(
+ "--src-embedding-copy",
+ action="store_true",
+ help="copy encoder word embeddings as the initial input of the decoder",
+ )
+ parser.add_argument(
+ "--pred-length-offset",
+ action="store_true",
+ help="predicting the length difference between the target and source sentences",
+ )
+ parser.add_argument(
+ "--sg-length-pred",
+ action="store_true",
+ help="stop the gradients back-propagated from the length predictor",
+ )
+ parser.add_argument(
+ "--length-loss-factor",
+ type=float,
+ help="weights on the length prediction loss",
+ )
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
@@ -81,25 +86,33 @@ def forward(
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
# length prediction
- length_out = self.decoder.forward_length(normalize=False, encoder_out=encoder_out)
- length_tgt = self.decoder.forward_length_prediction(length_out, encoder_out, tgt_tokens)
+ length_out = self.decoder.forward_length(
+ normalize=False, encoder_out=encoder_out
+ )
+ length_tgt = self.decoder.forward_length_prediction(
+ length_out, encoder_out, tgt_tokens
+ )
# decoding
word_ins_out = self.decoder(
normalize=False,
prev_output_tokens=prev_output_tokens,
- encoder_out=encoder_out)
+ encoder_out=encoder_out,
+ )
return {
"word_ins": {
- "out": word_ins_out, "tgt": tgt_tokens,
- "mask": tgt_tokens.ne(self.pad), "ls": self.args.label_smoothing,
- "nll_loss": True
+ "out": word_ins_out,
+ "tgt": tgt_tokens,
+ "mask": tgt_tokens.ne(self.pad),
+ "ls": self.args.label_smoothing,
+ "nll_loss": True,
},
"length": {
- "out": length_out, "tgt": length_tgt,
- "factor": self.decoder.length_loss_factor
- }
+ "out": length_out,
+ "tgt": length_tgt,
+ "factor": self.decoder.length_loss_factor,
+ },
}
def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
@@ -126,14 +139,14 @@ def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwar
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
- history=history
+ history=history,
)
def initialize_output_tokens(self, encoder_out, src_tokens):
# length prediction
length_tgt = self.decoder.forward_length_prediction(
self.decoder.forward_length(normalize=True, encoder_out=encoder_out),
- encoder_out=encoder_out
+ encoder_out=encoder_out,
)
max_length = length_tgt.clamp_(min=2).max()
@@ -158,13 +171,17 @@ def initialize_output_tokens(self, encoder_out, src_tokens):
attn=None,
step=0,
max_step=0,
- history=None
+ history=None,
)
def regenerate_length_beam(self, decoder_out, beam_size):
output_tokens = decoder_out.output_tokens
length_tgt = output_tokens.ne(self.pad).sum(1)
- length_tgt = length_tgt[:, None] + utils.new_arange(length_tgt, 1, beam_size) - beam_size // 2
+ length_tgt = (
+ length_tgt[:, None]
+ + utils.new_arange(length_tgt, 1, beam_size)
+ - beam_size // 2
+ )
length_tgt = length_tgt.view(-1).clamp_(min=2)
max_length = length_tgt.max()
idx_length = utils.new_arange(length_tgt, max_length)
@@ -183,8 +200,7 @@ def regenerate_length_beam(self, decoder_out, beam_size):
).type_as(decoder_out.output_scores)
return decoder_out._replace(
- output_tokens=initial_output_tokens,
- output_scores=initial_output_scores
+ output_tokens=initial_output_tokens, output_scores=initial_output_scores
)
diff --git a/fairseq/models/roberta/alignment_utils.py b/fairseq/models/roberta/alignment_utils.py
index 45d2e37194..ccc7f74cb9 100644
--- a/fairseq/models/roberta/alignment_utils.py
+++ b/fairseq/models/roberta/alignment_utils.py
@@ -29,23 +29,25 @@ def clean(text):
# remove whitespaces to simplify alignment
bpe_tokens = [roberta.task.source_dictionary.string([x]) for x in bpe_tokens]
- bpe_tokens = [clean(roberta.bpe.decode(x) if x not in {'', ''} else x) for x in bpe_tokens]
+ bpe_tokens = [
+ clean(roberta.bpe.decode(x) if x not in {"", ""} else x) for x in bpe_tokens
+ ]
other_tokens = [clean(str(o)) for o in other_tokens]
# strip leading
bpe_tokens = bpe_tokens[1:]
- assert ''.join(bpe_tokens) == ''.join(other_tokens)
+ assert "".join(bpe_tokens) == "".join(other_tokens)
# create alignment from every word to a list of BPE tokens
alignment = []
- bpe_toks = filter(lambda item: item[1] != '', enumerate(bpe_tokens, start=1))
+ bpe_toks = filter(lambda item: item[1] != "", enumerate(bpe_tokens, start=1))
j, bpe_tok = next(bpe_toks)
for other_tok in other_tokens:
bpe_indices = []
while True:
if other_tok.startswith(bpe_tok):
bpe_indices.append(j)
- other_tok = other_tok[len(bpe_tok):]
+ other_tok = other_tok[len(bpe_tok) :]
try:
j, bpe_tok = next(bpe_toks)
except StopIteration:
@@ -53,11 +55,11 @@ def clean(text):
elif bpe_tok.startswith(other_tok):
# other_tok spans multiple BPE tokens
bpe_indices.append(j)
- bpe_tok = bpe_tok[len(other_tok):]
- other_tok = ''
+ bpe_tok = bpe_tok[len(other_tok) :]
+ other_tok = ""
else:
raise Exception('Cannot align "{}" and "{}"'.format(other_tok, bpe_tok))
- if other_tok == '':
+ if other_tok == "":
break
assert len(bpe_indices) > 0
alignment.append(bpe_indices)
@@ -96,20 +98,21 @@ def align_features_to_words(roberta, features, alignment):
def spacy_nlp():
- if getattr(spacy_nlp, '_nlp', None) is None:
+ if getattr(spacy_nlp, "_nlp", None) is None:
try:
from spacy.lang.en import English
+
spacy_nlp._nlp = English()
except ImportError:
- raise ImportError('Please install spacy with: pip install spacy')
+ raise ImportError("Please install spacy with: pip install spacy")
return spacy_nlp._nlp
def spacy_tokenizer():
- if getattr(spacy_tokenizer, '_tokenizer', None) is None:
+ if getattr(spacy_tokenizer, "_tokenizer", None) is None:
try:
nlp = spacy_nlp()
spacy_tokenizer._tokenizer = nlp.Defaults.create_tokenizer(nlp)
except ImportError:
- raise ImportError('Please install spacy with: pip install spacy')
+ raise ImportError("Please install spacy with: pip install spacy")
return spacy_tokenizer._tokenizer
diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py
index 20456b3f5c..526823bd1f 100644
--- a/fairseq/models/roberta/hub_interface.py
+++ b/fairseq/models/roberta/hub_interface.py
@@ -7,7 +7,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq import utils
from fairseq.data import encoders
@@ -27,13 +26,15 @@ def __init__(self, args, task, model):
self.bpe = encoders.build_bpe(args)
# this is useful for determining the device
- self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
+ self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))
@property
def device(self):
return self._float_tensor.device
- def encode(self, sentence: str, *addl_sentences, no_separator=False) -> torch.LongTensor:
+ def encode(
+ self, sentence: str, *addl_sentences, no_separator=False
+ ) -> torch.LongTensor:
"""
BPE-encode a sentence (or multiple sentences).
@@ -54,11 +55,13 @@ def encode(self, sentence: str, *addl_sentences, no_separator=False) -> torch.Lo
>>> roberta.encode('world').tolist()
[0, 8331, 2]
"""
- bpe_sentence = ' ' + self.bpe.encode(sentence) + ' '
+ bpe_sentence = " " + self.bpe.encode(sentence) + " "
for s in addl_sentences:
- bpe_sentence += (' ' if not no_separator else '')
- bpe_sentence += ' ' + self.bpe.encode(s) + ' '
- tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False, add_if_not_exist=False)
+ bpe_sentence += " " if not no_separator else ""
+ bpe_sentence += " " + self.bpe.encode(s) + " "
+ tokens = self.task.source_dictionary.encode_line(
+ bpe_sentence, append_eos=False, add_if_not_exist=False
+ )
return tokens.long()
def decode(self, tokens: torch.LongTensor):
@@ -66,21 +69,27 @@ def decode(self, tokens: torch.LongTensor):
tokens = tokens.numpy()
if tokens[0] == self.task.source_dictionary.bos():
tokens = tokens[1:] # remove
- eos_mask = (tokens == self.task.source_dictionary.eos())
+ eos_mask = tokens == self.task.source_dictionary.eos()
doc_mask = eos_mask[1:] & eos_mask[:-1]
sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
- sentences = [self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences]
+ sentences = [
+ self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences
+ ]
if len(sentences) == 1:
return sentences[0]
return sentences
- def extract_features(self, tokens: torch.LongTensor, return_all_hiddens: bool = False) -> torch.Tensor:
+ def extract_features(
+ self, tokens: torch.LongTensor, return_all_hiddens: bool = False
+ ) -> torch.Tensor:
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
if tokens.size(-1) > self.model.max_positions():
- raise ValueError('tokens exceeds maximum length: {} > {}'.format(
- tokens.size(-1), self.model.max_positions()
- ))
+ raise ValueError(
+ "tokens exceeds maximum length: {} > {}".format(
+ tokens.size(-1), self.model.max_positions()
+ )
+ )
features, extra = self.model(
tokens.to(device=self.device),
features_only=True,
@@ -88,7 +97,7 @@ def extract_features(self, tokens: torch.LongTensor, return_all_hiddens: bool =
)
if return_all_hiddens:
# convert from T x B x C -> B x T x C
- inner_states = extra['inner_states']
+ inner_states = extra["inner_states"]
return [inner_state.transpose(0, 1) for inner_state in inner_states]
else:
return features # just the last layer's features
@@ -107,7 +116,9 @@ def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = Fal
return logits
return F.log_softmax(logits, dim=-1)
- def extract_features_aligned_to_words(self, sentence: str, return_all_hiddens: bool = False) -> torch.Tensor:
+ def extract_features_aligned_to_words(
+ self, sentence: str, return_all_hiddens: bool = False
+ ) -> torch.Tensor:
"""Extract RoBERTa features, aligned to spaCy's word-level tokenizer."""
from fairseq.models.roberta import alignment_utils
from spacy.tokens import Doc
@@ -122,31 +133,42 @@ def extract_features_aligned_to_words(self, sentence: str, return_all_hiddens: b
alignment = alignment_utils.align_bpe_to_words(self, bpe_toks, spacy_toks_ws)
# extract features and align them
- features = self.extract_features(bpe_toks, return_all_hiddens=return_all_hiddens)
+ features = self.extract_features(
+ bpe_toks, return_all_hiddens=return_all_hiddens
+ )
features = features.squeeze(0)
- aligned_feats = alignment_utils.align_features_to_words(self, features, alignment)
+ aligned_feats = alignment_utils.align_features_to_words(
+ self, features, alignment
+ )
# wrap in spaCy Doc
doc = Doc(
nlp.vocab,
- words=[''] + [x.text for x in spacy_toks] + [''],
- spaces=[True] + [x.endswith(' ') for x in spacy_toks_ws[:-1]] + [True, False],
+ words=[""] + [x.text for x in spacy_toks] + [""],
+ spaces=[True]
+ + [x.endswith(" ") for x in spacy_toks_ws[:-1]]
+ + [True, False],
)
assert len(doc) == aligned_feats.size(0)
- doc.user_token_hooks['vector'] = lambda token: aligned_feats[token.i]
+ doc.user_token_hooks["vector"] = lambda token: aligned_feats[token.i]
return doc
def fill_mask(self, masked_input: str, topk: int = 5):
- masked_token = ''
- assert masked_token in masked_input and masked_input.count(masked_token) == 1, \
- "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(masked_token)
+ masked_token = ""
+ assert (
+ masked_token in masked_input and masked_input.count(masked_token) == 1
+ ), "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(
+ masked_token
+ )
text_spans = masked_input.split(masked_token)
- text_spans_bpe = (' {0} '.format(masked_token)).join(
- [self.bpe.encode(text_span.rstrip()) for text_span in text_spans]
- ).strip()
+ text_spans_bpe = (
+ (" {0} ".format(masked_token))
+ .join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans])
+ .strip()
+ )
tokens = self.task.source_dictionary.encode_line(
- ' ' + text_spans_bpe + ' ',
+ " " + text_spans_bpe + " ",
append_eos=False,
add_if_not_exist=False,
)
@@ -167,25 +189,31 @@ def fill_mask(self, masked_input: str, topk: int = 5):
topk_predicted_token_bpe = self.task.source_dictionary.string(index)
topk_filled_outputs = []
- for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(' ')):
+ for index, predicted_token_bpe in enumerate(
+ topk_predicted_token_bpe.split(" ")
+ ):
predicted_token = self.bpe.decode(predicted_token_bpe)
# Quick hack to fix https://github.com/pytorch/fairseq/issues/1306
- if predicted_token_bpe.startswith('\u2581'):
- predicted_token = ' ' + predicted_token
+ if predicted_token_bpe.startswith("\u2581"):
+ predicted_token = " " + predicted_token
if " {0}".format(masked_token) in masked_input:
- topk_filled_outputs.append((
- masked_input.replace(
- ' {0}'.format(masked_token), predicted_token
- ),
- values[index].item(),
- predicted_token,
- ))
+ topk_filled_outputs.append(
+ (
+ masked_input.replace(
+ " {0}".format(masked_token), predicted_token
+ ),
+ values[index].item(),
+ predicted_token,
+ )
+ )
else:
- topk_filled_outputs.append((
- masked_input.replace(masked_token, predicted_token),
- values[index].item(),
- predicted_token,
- ))
+ topk_filled_outputs.append(
+ (
+ masked_input.replace(masked_token, predicted_token),
+ values[index].item(),
+ predicted_token,
+ )
+ )
return topk_filled_outputs
def disambiguate_pronoun(self, sentence: str) -> bool:
@@ -198,7 +226,10 @@ def disambiguate_pronoun(self, sentence: str) -> bool:
>>> disambiguate_pronoun('The trophy would not fit in the brown suitcase because [it] was too big.')
'The trophy'
"""
- assert hasattr(self.task, 'disambiguate_pronoun'), \
- 'roberta.disambiguate_pronoun() requires a model trained with the WSC task.'
+ assert hasattr(
+ self.task, "disambiguate_pronoun"
+ ), "roberta.disambiguate_pronoun() requires a model trained with the WSC task."
with utils.model_eval(self.model):
- return self.task.disambiguate_pronoun(self.model, sentence, use_cuda=self.device.type == 'cuda')
+ return self.task.disambiguate_pronoun(
+ self.model, sentence, use_cuda=self.device.type == "cuda"
+ )
diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py
index 0917927e34..6ce216a6bf 100644
--- a/fairseq/models/roberta/model.py
+++ b/fairseq/models/roberta/model.py
@@ -11,7 +11,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
@@ -19,12 +18,9 @@
register_model,
register_model_architecture,
)
-from fairseq.modules import (
- LayerNorm,
- TransformerSentenceEncoder,
-)
-from fairseq.modules.transformer_sentence_encoder import init_bert_params
+from fairseq.modules import LayerNorm, TransformerSentenceEncoder
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
+from fairseq.modules.transformer_sentence_encoder import init_bert_params
from .hub_interface import RobertaHubInterface
@@ -32,16 +28,15 @@
logger = logging.getLogger(__name__)
-@register_model('roberta')
+@register_model("roberta")
class RobertaModel(FairseqEncoderModel):
-
@classmethod
def hub_models(cls):
return {
- 'roberta.base': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz',
- 'roberta.large': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz',
- 'roberta.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz',
- 'roberta.large.wsc': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz',
+ "roberta.base": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz",
+ "roberta.large": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz",
+ "roberta.large.mnli": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz",
+ "roberta.large.wsc": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz",
}
def __init__(self, args, encoder):
@@ -56,50 +51,117 @@ def __init__(self, args, encoder):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
- parser.add_argument('--encoder-layers', type=int, metavar='L',
- help='num encoder layers')
- parser.add_argument('--encoder-embed-dim', type=int, metavar='H',
- help='encoder embedding dimension')
- parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='F',
- help='encoder embedding dimension for FFN')
- parser.add_argument('--encoder-attention-heads', type=int, metavar='A',
- help='num encoder attention heads')
- parser.add_argument('--activation-fn',
- choices=utils.get_available_activation_fns(),
- help='activation function to use')
- parser.add_argument('--pooler-activation-fn',
- choices=utils.get_available_activation_fns(),
- help='activation function to use for pooler layer')
- parser.add_argument('--encoder-normalize-before', action='store_true',
- help='apply layernorm before each encoder block')
- parser.add_argument('--dropout', type=float, metavar='D',
- help='dropout probability')
- parser.add_argument('--attention-dropout', type=float, metavar='D',
- help='dropout probability for attention weights')
- parser.add_argument('--activation-dropout', type=float, metavar='D',
- help='dropout probability after activation in FFN')
- parser.add_argument('--pooler-dropout', type=float, metavar='D',
- help='dropout probability in the masked_lm pooler layers')
- parser.add_argument('--max-positions', type=int,
- help='number of positional embeddings to learn')
- parser.add_argument('--load-checkpoint-heads', action='store_true',
- help='(re-)register and load heads when loading checkpoints')
+ parser.add_argument(
+ "--encoder-layers", type=int, metavar="L", help="num encoder layers"
+ )
+ parser.add_argument(
+ "--encoder-embed-dim",
+ type=int,
+ metavar="H",
+ help="encoder embedding dimension",
+ )
+ parser.add_argument(
+ "--encoder-ffn-embed-dim",
+ type=int,
+ metavar="F",
+ help="encoder embedding dimension for FFN",
+ )
+ parser.add_argument(
+ "--encoder-attention-heads",
+ type=int,
+ metavar="A",
+ help="num encoder attention heads",
+ )
+ parser.add_argument(
+ "--activation-fn",
+ choices=utils.get_available_activation_fns(),
+ help="activation function to use",
+ )
+ parser.add_argument(
+ "--pooler-activation-fn",
+ choices=utils.get_available_activation_fns(),
+ help="activation function to use for pooler layer",
+ )
+ parser.add_argument(
+ "--encoder-normalize-before",
+ action="store_true",
+ help="apply layernorm before each encoder block",
+ )
+ parser.add_argument(
+ "--dropout", type=float, metavar="D", help="dropout probability"
+ )
+ parser.add_argument(
+ "--attention-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability for attention weights",
+ )
+ parser.add_argument(
+ "--activation-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability after activation in FFN",
+ )
+ parser.add_argument(
+ "--pooler-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability in the masked_lm pooler layers",
+ )
+ parser.add_argument(
+ "--max-positions", type=int, help="number of positional embeddings to learn"
+ )
+ parser.add_argument(
+ "--load-checkpoint-heads",
+ action="store_true",
+ help="(re-)register and load heads when loading checkpoints",
+ )
# args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
- parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
- help='LayerDrop probability for encoder')
- parser.add_argument('--encoder-layers-to-keep', default=None,
- help='which layers to *keep* when pruning as a comma-separated list')
+ parser.add_argument(
+ "--encoder-layerdrop",
+ type=float,
+ metavar="D",
+ default=0,
+ help="LayerDrop probability for encoder",
+ )
+ parser.add_argument(
+ "--encoder-layers-to-keep",
+ default=None,
+ help="which layers to *keep* when pruning as a comma-separated list",
+ )
# args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
- parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0,
- help='iterative PQ quantization noise at training time')
- parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8,
- help='block size of quantization noise at training time')
- parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
- help='scalar quantization noise and scalar quantization at training time')
- parser.add_argument('--untie-weights-roberta', action='store_true',
- help='Untie weights between embeddings and classifiers in RoBERTa')
- parser.add_argument('--spectral-norm-classification-head', action='store_true', default=False,
- help='Apply spectral normalization on the classification head')
+ parser.add_argument(
+ "--quant-noise-pq",
+ type=float,
+ metavar="D",
+ default=0,
+ help="iterative PQ quantization noise at training time",
+ )
+ parser.add_argument(
+ "--quant-noise-pq-block-size",
+ type=int,
+ metavar="D",
+ default=8,
+ help="block size of quantization noise at training time",
+ )
+ parser.add_argument(
+ "--quant-noise-scalar",
+ type=float,
+ metavar="D",
+ default=0,
+ help="scalar quantization noise and scalar quantization at training time",
+ )
+ parser.add_argument(
+ "--untie-weights-roberta",
+ action="store_true",
+ help="Untie weights between embeddings and classifiers in RoBERTa",
+ )
+ parser.add_argument(
+ "--spectral-norm-classification-head",
+ action="store_true",
+ default=False,
+ help="Apply spectral normalization on the classification head",
+ )
@classmethod
def build_model(cls, args, task):
@@ -108,13 +170,20 @@ def build_model(cls, args, task):
# make sure all arguments are present
base_architecture(args)
- if not hasattr(args, 'max_positions'):
+ if not hasattr(args, "max_positions"):
args.max_positions = args.tokens_per_sample
encoder = RobertaEncoder(args, task.source_dictionary)
return cls(args, encoder)
- def forward(self, src_tokens, features_only=False, return_all_hiddens=False, classification_head_name=None, **kwargs):
+ def forward(
+ self,
+ src_tokens,
+ features_only=False,
+ return_all_hiddens=False,
+ classification_head_name=None,
+ **kwargs
+ ):
if classification_head_name is not None:
features_only = True
@@ -132,7 +201,9 @@ def get_normalized_probs(self, net_output, log_probs, sample=None):
else:
return F.softmax(logits, dim=-1)
- def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
+ def register_classification_head(
+ self, name, num_classes=None, inner_dim=None, **kwargs
+ ):
"""Register a classification head."""
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
@@ -140,7 +211,7 @@ def register_classification_head(self, name, num_classes=None, inner_dim=None, *
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
- 'and inner_dim {} (prev: {})'.format(
+ "and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
@@ -157,11 +228,19 @@ def register_classification_head(self, name, num_classes=None, inner_dim=None, *
@property
def supported_targets(self):
- return {'self'}
+ return {"self"}
@classmethod
- def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', bpe='gpt2', **kwargs):
+ def from_pretrained(
+ cls,
+ model_name_or_path,
+ checkpoint_file="model.pt",
+ data_name_or_path=".",
+ bpe="gpt2",
+ **kwargs
+ ):
from fairseq import hub_utils
+
x = hub_utils.from_pretrained(
model_name_or_path,
checkpoint_file,
@@ -171,15 +250,15 @@ def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_na
load_checkpoint_heads=True,
**kwargs,
)
- return RobertaHubInterface(x['args'], x['task'], x['models'][0])
+ return RobertaHubInterface(x["args"], x["task"], x["models"][0])
def upgrade_state_dict_named(self, state_dict, name):
- prefix = name + '.' if name != '' else ''
+ prefix = name + "." if name != "" else ""
# rename decoder -> encoder before upgrading children modules
for k in list(state_dict.keys()):
- if k.startswith(prefix + 'decoder'):
- new_k = prefix + 'encoder' + k[len(prefix + 'decoder'):]
+ if k.startswith(prefix + "decoder"):
+ new_k = prefix + "encoder" + k[len(prefix + "decoder") :]
state_dict[new_k] = state_dict[k]
del state_dict[k]
@@ -188,35 +267,44 @@ def upgrade_state_dict_named(self, state_dict, name):
# Handle new classification heads present in the state dict.
current_head_names = (
- [] if not hasattr(self, 'classification_heads')
+ []
+ if not hasattr(self, "classification_heads")
else self.classification_heads.keys()
)
keys_to_delete = []
for k in state_dict.keys():
- if not k.startswith(prefix + 'classification_heads.'):
+ if not k.startswith(prefix + "classification_heads."):
continue
- head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
- num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
- inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
+ head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
+ num_classes = state_dict[
+ prefix + "classification_heads." + head_name + ".out_proj.weight"
+ ].size(0)
+ inner_dim = state_dict[
+ prefix + "classification_heads." + head_name + ".dense.weight"
+ ].size(0)
- if getattr(self.args, 'load_checkpoint_heads', False):
+ if getattr(self.args, "load_checkpoint_heads", False):
if head_name not in current_head_names:
self.register_classification_head(head_name, num_classes, inner_dim)
else:
if head_name not in current_head_names:
logger.warning(
- 'deleting classification head ({}) from checkpoint '
- 'not present in current model: {}'.format(head_name, k)
+ "deleting classification head ({}) from checkpoint "
+ "not present in current model: {}".format(head_name, k)
)
keys_to_delete.append(k)
elif (
- num_classes != self.classification_heads[head_name].out_proj.out_features
- or inner_dim != self.classification_heads[head_name].dense.out_features
+ num_classes
+ != self.classification_heads[head_name].out_proj.out_features
+ or inner_dim
+ != self.classification_heads[head_name].dense.out_features
):
logger.warning(
- 'deleting classification head ({}) from checkpoint '
- 'with different dimensions than current model: {}'.format(head_name, k)
+ "deleting classification head ({}) from checkpoint "
+ "with different dimensions than current model: {}".format(
+ head_name, k
+ )
)
keys_to_delete.append(k)
for k in keys_to_delete:
@@ -224,12 +312,12 @@ def upgrade_state_dict_named(self, state_dict, name):
# Copy any newly-added classification heads into the state dict
# with their current weights.
- if hasattr(self, 'classification_heads'):
+ if hasattr(self, "classification_heads"):
cur_state = self.classification_heads.state_dict()
for k, v in cur_state.items():
- if prefix + 'classification_heads.' + k not in state_dict:
- logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
- state_dict[prefix + 'classification_heads.' + k] = v
+ if prefix + "classification_heads." + k not in state_dict:
+ logger.info("Overwriting " + prefix + "classification_heads." + k)
+ state_dict[prefix + "classification_heads." + k] = v
class RobertaLMHead(nn.Module):
@@ -284,7 +372,8 @@ def __init__(
if do_spectral_norm:
if q_noise != 0:
raise NotImplementedError(
- "Attempting to use Spectral Normalization with Quant Noise. This is not officially supported")
+ "Attempting to use Spectral Normalization with Quant Noise. This is not officially supported"
+ )
self.out_proj = torch.nn.utils.spectral_norm(self.out_proj)
def forward(self, features, **kwargs):
@@ -326,7 +415,7 @@ def __init__(self, args, dictionary):
q_noise=args.quant_noise_pq,
qn_block_size=args.quant_noise_pq_block_size,
)
- args.untie_weights_roberta = getattr(args, 'untie_weights_roberta', False)
+ args.untie_weights_roberta = getattr(args, "untie_weights_roberta", False)
self.lm_head = RobertaLMHead(
embed_dim=args.encoder_embed_dim,
@@ -339,7 +428,14 @@ def __init__(self, args, dictionary):
),
)
- def forward(self, src_tokens, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused):
+ def forward(
+ self,
+ src_tokens,
+ features_only=False,
+ return_all_hiddens=False,
+ masked_tokens=None,
+ **unused
+ ):
"""
Args:
src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
@@ -356,7 +452,9 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, mas
is a list of hidden states. Note that the hidden
states have shape `(src_len, batch, vocab)`.
"""
- x, extra = self.extract_features(src_tokens, return_all_hiddens=return_all_hiddens)
+ x, extra = self.extract_features(
+ src_tokens, return_all_hiddens=return_all_hiddens
+ )
if not features_only:
x = self.output_layer(x, masked_tokens=masked_tokens)
return x, extra
@@ -365,10 +463,10 @@ def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs):
inner_states, _ = self.sentence_encoder(
src_tokens,
last_state_only=not return_all_hiddens,
- token_embeddings=kwargs.get('token_embeddings', None),
+ token_embeddings=kwargs.get("token_embeddings", None),
)
features = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C
- return features, {'inner_states': inner_states if return_all_hiddens else None}
+ return features, {"inner_states": inner_states if return_all_hiddens else None}
def output_layer(self, features, masked_tokens=None, **unused):
return self.lm_head(features, masked_tokens)
@@ -378,44 +476,46 @@ def max_positions(self):
return self.args.max_positions
-@register_model_architecture('roberta', 'roberta')
+@register_model_architecture("roberta", "roberta")
def base_architecture(args):
- args.encoder_layers = getattr(args, 'encoder_layers', 12)
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12)
-
- args.activation_fn = getattr(args, 'activation_fn', 'gelu')
- args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
-
- args.dropout = getattr(args, 'dropout', 0.1)
- args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
- args.activation_dropout = getattr(args, 'activation_dropout', 0.0)
- args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)
- args.encoder_layers_to_keep = getattr(args, 'encoder_layers_to_keep', None)
- args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0)
- args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0)
- args.spectral_norm_classification_head = getattr(args, 'spectral_nrom_classification_head', False)
-
-
-@register_model_architecture('roberta', 'roberta_base')
+ args.encoder_layers = getattr(args, "encoder_layers", 12)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
+
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
+ args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
+
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+ args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
+ args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
+ args.spectral_norm_classification_head = getattr(
+ args, "spectral_nrom_classification_head", False
+ )
+
+
+@register_model_architecture("roberta", "roberta_base")
def roberta_base_architecture(args):
base_architecture(args)
-@register_model_architecture('roberta', 'roberta_large')
+@register_model_architecture("roberta", "roberta_large")
def roberta_large_architecture(args):
- args.encoder_layers = getattr(args, 'encoder_layers', 24)
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
+ args.encoder_layers = getattr(args, "encoder_layers", 24)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
base_architecture(args)
-@register_model_architecture('roberta', 'xlm')
+@register_model_architecture("roberta", "xlm")
def xlm_architecture(args):
- args.encoder_layers = getattr(args, 'encoder_layers', 16)
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1280)
- args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1280*4)
- args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
+ args.encoder_layers = getattr(args, "encoder_layers", 16)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1280)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1280 * 4)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
base_architecture(args)
diff --git a/fairseq/models/roberta/model_camembert.py b/fairseq/models/roberta/model_camembert.py
index eb57d81d8d..46447546fa 100644
--- a/fairseq/models/roberta/model_camembert.py
+++ b/fairseq/models/roberta/model_camembert.py
@@ -12,25 +12,32 @@
from .model import RobertaModel
-@register_model('camembert')
+@register_model("camembert")
class CamembertModel(RobertaModel):
-
@classmethod
def hub_models(cls):
return {
- 'camembert': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz',
- 'camembert.v0': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz',
- 'camembert-base': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz',
- 'camembert-large': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-large.tar.gz',
- 'camembert-base-ccnet': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet.tar.gz',
- 'camembert-base-ccnet-4gb': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet-4gb.tar.gz',
- 'camembert-base-wikipedia-4gb': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-wikipedia-4gb.tar.gz',
- 'camembert-base-oscar-4gb': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-oscar-4gb.tar.gz',
+ "camembert": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz",
+ "camembert.v0": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz",
+ "camembert-base": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz",
+ "camembert-large": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-large.tar.gz",
+ "camembert-base-ccnet": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet.tar.gz",
+ "camembert-base-ccnet-4gb": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet-4gb.tar.gz",
+ "camembert-base-wikipedia-4gb": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-wikipedia-4gb.tar.gz",
+ "camembert-base-oscar-4gb": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-oscar-4gb.tar.gz",
}
@classmethod
- def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', bpe='sentencepiece', **kwargs):
+ def from_pretrained(
+ cls,
+ model_name_or_path,
+ checkpoint_file="model.pt",
+ data_name_or_path=".",
+ bpe="sentencepiece",
+ **kwargs
+ ):
from fairseq import hub_utils
+
x = hub_utils.from_pretrained(
model_name_or_path,
checkpoint_file,
@@ -40,4 +47,4 @@ def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_na
load_checkpoint_heads=True,
**kwargs,
)
- return RobertaHubInterface(x['args'], x['task'], x['models'][0])
+ return RobertaHubInterface(x["args"], x["task"], x["models"][0])
diff --git a/fairseq/models/roberta/model_xlmr.py b/fairseq/models/roberta/model_xlmr.py
index fa71a27d12..5886880f73 100644
--- a/fairseq/models/roberta/model_xlmr.py
+++ b/fairseq/models/roberta/model_xlmr.py
@@ -12,19 +12,26 @@
from .model import RobertaModel
-@register_model('xlmr')
+@register_model("xlmr")
class XLMRModel(RobertaModel):
-
@classmethod
def hub_models(cls):
return {
- 'xlmr.base': 'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz',
- 'xlmr.large': 'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz',
+ "xlmr.base": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz",
+ "xlmr.large": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz",
}
@classmethod
- def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', bpe='sentencepiece', **kwargs):
+ def from_pretrained(
+ cls,
+ model_name_or_path,
+ checkpoint_file="model.pt",
+ data_name_or_path=".",
+ bpe="sentencepiece",
+ **kwargs
+ ):
from fairseq import hub_utils
+
x = hub_utils.from_pretrained(
model_name_or_path,
checkpoint_file,
@@ -34,4 +41,4 @@ def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_na
load_checkpoint_heads=True,
**kwargs,
)
- return RobertaHubInterface(x['args'], x['task'], x['models'][0])
+ return RobertaHubInterface(x["args"], x["task"], x["models"][0])
diff --git a/fairseq/models/speech_to_text/__init__.py b/fairseq/models/speech_to_text/__init__.py
index 351c16fee5..5d7f59b3a6 100644
--- a/fairseq/models/speech_to_text/__init__.py
+++ b/fairseq/models/speech_to_text/__init__.py
@@ -4,4 +4,4 @@
# LICENSE file in the root directory of this source tree.
from .berard import * # noqa
-from .s2t_transformer import * # noqa
+from .s2t_transformer import * # noqa
diff --git a/fairseq/models/speech_to_text/berard.py b/fairseq/models/speech_to_text/berard.py
index f5ae46eeb2..c505e3acaa 100644
--- a/fairseq/models/speech_to_text/berard.py
+++ b/fairseq/models/speech_to_text/berard.py
@@ -6,16 +6,15 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq import checkpoint_utils, utils
+from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
FairseqEncoder,
- FairseqIncrementalDecoder,
FairseqEncoderDecoderModel,
+ FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
-from fairseq.data.data_utils import lengths_to_padding_mask
@register_model("s2t_berard")
@@ -40,48 +39,85 @@ def __init__(self, encoder, decoder):
@staticmethod
def add_args(parser):
- parser.add_argument("--input-layers", type=str, metavar="EXPR",
- help="List of linear layer dimensions. These "
- "layers are applied to the input features and "
- "are followed by tanh and possibly dropout.")
parser.add_argument(
- "--dropout", type=float, metavar="D",
+ "--input-layers",
+ type=str,
+ metavar="EXPR",
+ help="List of linear layer dimensions. These "
+ "layers are applied to the input features and "
+ "are followed by tanh and possibly dropout.",
+ )
+ parser.add_argument(
+ "--dropout",
+ type=float,
+ metavar="D",
help="Dropout probability to use in the encoder/decoder. "
- "Note that this parameters control dropout in various places, "
- "there is no fine-grained control for dropout for embeddings "
- "vs LSTM layers for example."
+ "Note that this parameters control dropout in various places, "
+ "there is no fine-grained control for dropout for embeddings "
+ "vs LSTM layers for example.",
+ )
+ parser.add_argument(
+ "--in-channels",
+ type=int,
+ metavar="N",
+ help="Number of encoder input channels. " "Typically value is 1.",
+ )
+ parser.add_argument(
+ "--conv-layers",
+ type=str,
+ metavar="EXPR",
+ help="List of conv layers " "(format: (channels, kernel, stride)).",
+ )
+ parser.add_argument(
+ "--num-blstm-layers",
+ type=int,
+ metavar="N",
+ help="Number of encoder bi-LSTM layers.",
)
- parser.add_argument("--in-channels", type=int, metavar="N",
- help="Number of encoder input channels. "
- "Typically value is 1.")
- parser.add_argument("--conv-layers", type=str, metavar="EXPR",
- help="List of conv layers "
- "(format: (channels, kernel, stride)).")
- parser.add_argument("--num-blstm-layers", type=int, metavar="N",
- help="Number of encoder bi-LSTM layers.")
- parser.add_argument("--lstm-size", type=int, metavar="N",
- help="LSTM hidden size.")
parser.add_argument(
- "--decoder-embed-dim", type=int, metavar="N",
- help="Embedding dimension of the decoder target tokens."
+ "--lstm-size", type=int, metavar="N", help="LSTM hidden size."
)
- parser.add_argument("--decoder-hidden-dim", type=int, metavar="N",
- help="Decoder LSTM hidden dimension.")
- parser.add_argument("--decoder-num-layers", type=int, metavar="N",
- help="Number of decoder LSTM layers.")
- parser.add_argument("--attention-dim", type=int, metavar="N",
- help="Hidden layer dimension in MLP attention.")
parser.add_argument(
- "--output-layer-dim", type=int, metavar="N",
- help="Hidden layer dim for linear layer prior to output projection."
+ "--decoder-embed-dim",
+ type=int,
+ metavar="N",
+ help="Embedding dimension of the decoder target tokens.",
)
parser.add_argument(
- "--load-pretrained-encoder-from", type=str, metavar="STR",
- help="model to take encoder weights from (for initialization)"
+ "--decoder-hidden-dim",
+ type=int,
+ metavar="N",
+ help="Decoder LSTM hidden dimension.",
)
parser.add_argument(
- "--load-pretrained-decoder-from", type=str, metavar="STR",
- help="model to take decoder weights from (for initialization)"
+ "--decoder-num-layers",
+ type=int,
+ metavar="N",
+ help="Number of decoder LSTM layers.",
+ )
+ parser.add_argument(
+ "--attention-dim",
+ type=int,
+ metavar="N",
+ help="Hidden layer dimension in MLP attention.",
+ )
+ parser.add_argument(
+ "--output-layer-dim",
+ type=int,
+ metavar="N",
+ help="Hidden layer dim for linear layer prior to output projection.",
+ )
+ parser.add_argument(
+ "--load-pretrained-encoder-from",
+ type=str,
+ metavar="STR",
+ help="model to take encoder weights from (for initialization)",
+ )
+ parser.add_argument(
+ "--load-pretrained-decoder-from",
+ type=str,
+ metavar="STR",
+ help="model to take decoder weights from (for initialization)",
)
@classmethod
@@ -170,8 +206,7 @@ def __init__(
if dropout > 0:
self.input_layers.append(
nn.Sequential(
- nn.Linear(in_features, out_features),
- nn.Dropout(p=dropout)
+ nn.Linear(in_features, out_features), nn.Dropout(p=dropout)
)
)
else:
@@ -194,9 +229,7 @@ def __init__(
padding=conv_kernel_size // 2,
)
)
- self.conv_kernel_sizes_and_strides.append(
- (conv_kernel_size, conv_stride)
- )
+ self.conv_kernel_sizes_and_strides.append((conv_kernel_size, conv_stride))
in_channels = out_channels
lstm_input_dim //= conv_stride
@@ -241,8 +274,7 @@ def forward(self, src_tokens, src_lengths=None, **kwargs):
# (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) ->
# (T, B, C * feat)
- x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len,
- bsz, -1)
+ x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, bsz, -1)
input_lengths = src_lengths.clone()
for k, s in self.conv_kernel_sizes_and_strides:
@@ -261,8 +293,9 @@ def forward(self, src_tokens, src_lengths=None, **kwargs):
if self.dropout is not None:
x = self.dropout(x)
- encoder_padding_mask = lengths_to_padding_mask(output_lengths).to(
- src_tokens.device).t()
+ encoder_padding_mask = (
+ lengths_to_padding_mask(output_lengths).to(src_tokens.device).t()
+ )
return {
"encoder_out": x, # (T, B, C)
@@ -293,8 +326,7 @@ def __init__(self, decoder_hidden_state_dim, context_dim, attention_dim):
self.context_dim = context_dim
self.attention_dim = attention_dim
# W_ae and b_a
- self.encoder_proj = nn.Linear(context_dim, self.attention_dim,
- bias=True)
+ self.encoder_proj = nn.Linear(context_dim, self.attention_dim, bias=True)
# W_ad
self.decoder_proj = nn.Linear(
decoder_hidden_state_dim, self.attention_dim, bias=False
@@ -314,8 +346,7 @@ def forward(self, decoder_state, source_hids, encoder_padding_mask):
# (src_len*bsz) x attention_dim
encoder_component = self.encoder_proj(flat_source_hids)
# src_len x bsz x attention_dim
- encoder_component = encoder_component.view(src_len, bsz,
- self.attention_dim)
+ encoder_component = encoder_component.view(src_len, bsz, self.attention_dim)
# 1 x bsz x attention_dim
decoder_component = self.decoder_proj(decoder_state).unsqueeze(0)
# Sum with broadcasting and apply the non linearity
@@ -400,8 +431,9 @@ def __init__(
)
self.output_projection = nn.Linear(output_layer_dim, num_embeddings)
- def forward(self, prev_output_tokens, encoder_out=None,
- incremental_state=None, **kwargs):
+ def forward(
+ self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
+ ):
encoder_padding_mask = encoder_out["encoder_padding_mask"]
encoder_outs = encoder_out["encoder_out"]
@@ -428,9 +460,7 @@ def forward(self, prev_output_tokens, encoder_out=None,
if cached_state is not None:
prev_hiddens, prev_cells = cached_state
else:
- prev_hiddens = [
- encoder_out["encoder_out"].mean(dim=0)
- ] * self.num_layers
+ prev_hiddens = [encoder_out["encoder_out"].mean(dim=0)] * self.num_layers
prev_cells = [x.new_zeros(bsz, self.hidden_size)] * self.num_layers
attn_scores = x.new_zeros(bsz, srclen)
@@ -510,9 +540,7 @@ def reorder_state(state):
return state.index_select(0, new_order)
new_state = tuple(map(reorder_state, cached_state))
- utils.set_incremental_state(
- self, incremental_state, "cached_state", new_state
- )
+ utils.set_incremental_state(self, incremental_state, "cached_state", new_state)
@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard")
@@ -538,8 +566,7 @@ def berard(args):
)
-@register_model_architecture(model_name="s2t_berard",
- arch_name="s2t_berard_256_3_3")
+@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard_256_3_3")
def berard_256_3_3(args):
"""Used in
* "Harnessing Indirect Training Data for End-to-End Automatic Speech
@@ -553,8 +580,7 @@ def berard_256_3_3(args):
berard(args)
-@register_model_architecture(model_name="s2t_berard",
- arch_name="s2t_berard_512_3_2")
+@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard_512_3_2")
def berard_512_3_2(args):
args.num_blstm_layers = getattr(args, "num_blstm_layers", 3)
args.lstm_size = getattr(args, "lstm_size", 512)
@@ -567,8 +593,7 @@ def berard_512_3_2(args):
berard(args)
-@register_model_architecture(model_name="s2t_berard",
- arch_name="s2t_berard_512_5_3")
+@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard_512_5_3")
def berard_512_5_3(args):
args.num_blstm_layers = getattr(args, "num_blstm_layers", 5)
args.lstm_size = getattr(args, "lstm_size", 512)
diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py
index 3492f691f7..8e48964f79 100644
--- a/fairseq/models/speech_to_text/s2t_transformer.py
+++ b/fairseq/models/speech_to_text/s2t_transformer.py
@@ -6,14 +6,22 @@
import torch
import torch.nn as nn
-from fairseq import utils, checkpoint_utils
-from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel,
- register_model, register_model_architecture)
-from fairseq.models.fairseq_encoder import EncoderOut
+from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
+from fairseq.models import (
+ FairseqEncoder,
+ FairseqEncoderDecoderModel,
+ register_model,
+ register_model_architecture,
+)
+from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import Embedding, TransformerDecoder
-from fairseq.modules import (PositionalEmbedding, TransformerEncoderLayer,
- FairseqDropout, LayerNorm)
+from fairseq.modules import (
+ FairseqDropout,
+ LayerNorm,
+ PositionalEmbedding,
+ TransformerEncoderLayer,
+)
from torch import Tensor
@@ -31,15 +39,23 @@ class Conv1dSubsampler(nn.Module):
out_channels (int): the number of output channels
kernel_sizes (List[int]): the kernel size for each convolutional layer
"""
- def __init__(self, in_channels: int, mid_channels: int, out_channels: int,
- kernel_sizes: List[int] = (3, 3)):
+
+ def __init__(
+ self,
+ in_channels: int,
+ mid_channels: int,
+ out_channels: int,
+ kernel_sizes: List[int] = (3, 3),
+ ):
super(Conv1dSubsampler, self).__init__()
self.n_layers = len(kernel_sizes)
self.conv_layers = nn.ModuleList(
nn.Conv1d(
in_channels if i == 0 else mid_channels // 2,
mid_channels if i < self.n_layers - 1 else out_channels * 2,
- k, stride=2, padding=k // 2
+ k,
+ stride=2,
+ padding=k // 2,
)
for i, k in enumerate(kernel_sizes)
)
@@ -76,48 +92,109 @@ def __init__(self, encoder, decoder):
def add_args(parser):
"""Add model-specific arguments to the parser."""
# input
- parser.add_argument("--conv-kernel-sizes", type=str, metavar="N",
- help="kernel sizes of Conv1d subsampling layers")
- parser.add_argument("--conv-channels", type=int, metavar="N",
- help="# of channels in Conv1d subsampling layers")
+ parser.add_argument(
+ "--conv-kernel-sizes",
+ type=str,
+ metavar="N",
+ help="kernel sizes of Conv1d subsampling layers",
+ )
+ parser.add_argument(
+ "--conv-channels",
+ type=int,
+ metavar="N",
+ help="# of channels in Conv1d subsampling layers",
+ )
# Transformer
- parser.add_argument("--activation-fn", type=str, default='relu',
- choices=utils.get_available_activation_fns(),
- help="activation function to use")
- parser.add_argument("--dropout", type=float, metavar="D",
- help="dropout probability")
- parser.add_argument("--attention-dropout", type=float, metavar="D",
- help="dropout probability for attention weights")
- parser.add_argument("--activation-dropout", "--relu-dropout",
- type=float, metavar="D",
- help="dropout probability after activation in FFN.")
- parser.add_argument("--encoder-embed-dim", type=int, metavar="N",
- help="encoder embedding dimension")
- parser.add_argument("--encoder-ffn-embed-dim", type=int, metavar="N",
- help="encoder embedding dimension for FFN")
- parser.add_argument("--encoder-layers", type=int, metavar="N",
- help="num encoder layers")
- parser.add_argument("--encoder-attention-heads", type=int, metavar="N",
- help="num encoder attention heads")
- parser.add_argument("--encoder-normalize-before", action="store_true",
- help="apply layernorm before each encoder block")
- parser.add_argument("--decoder-embed-dim", type=int, metavar="N",
- help="decoder embedding dimension")
- parser.add_argument("--decoder-ffn-embed-dim", type=int, metavar="N",
- help="decoder embedding dimension for FFN")
- parser.add_argument("--decoder-layers", type=int, metavar="N",
- help="num decoder layers")
- parser.add_argument("--decoder-attention-heads", type=int, metavar="N",
- help="num decoder attention heads")
- parser.add_argument("--decoder-normalize-before", action="store_true",
- help="apply layernorm before each decoder block")
- parser.add_argument("--layernorm-embedding", action="store_true",
- help="add layernorm to embedding")
- parser.add_argument("--no-scale-embedding", action="store_true",
- help="if True, dont scale embeddings")
parser.add_argument(
- "--load-pretrained-encoder-from", type=str, metavar="STR",
- help="model to take encoder weights from (for initialization)"
+ "--activation-fn",
+ type=str,
+ default="relu",
+ choices=utils.get_available_activation_fns(),
+ help="activation function to use",
+ )
+ parser.add_argument(
+ "--dropout", type=float, metavar="D", help="dropout probability"
+ )
+ parser.add_argument(
+ "--attention-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability for attention weights",
+ )
+ parser.add_argument(
+ "--activation-dropout",
+ "--relu-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability after activation in FFN.",
+ )
+ parser.add_argument(
+ "--encoder-embed-dim",
+ type=int,
+ metavar="N",
+ help="encoder embedding dimension",
+ )
+ parser.add_argument(
+ "--encoder-ffn-embed-dim",
+ type=int,
+ metavar="N",
+ help="encoder embedding dimension for FFN",
+ )
+ parser.add_argument(
+ "--encoder-layers", type=int, metavar="N", help="num encoder layers"
+ )
+ parser.add_argument(
+ "--encoder-attention-heads",
+ type=int,
+ metavar="N",
+ help="num encoder attention heads",
+ )
+ parser.add_argument(
+ "--encoder-normalize-before",
+ action="store_true",
+ help="apply layernorm before each encoder block",
+ )
+ parser.add_argument(
+ "--decoder-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder embedding dimension",
+ )
+ parser.add_argument(
+ "--decoder-ffn-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder embedding dimension for FFN",
+ )
+ parser.add_argument(
+ "--decoder-layers", type=int, metavar="N", help="num decoder layers"
+ )
+ parser.add_argument(
+ "--decoder-attention-heads",
+ type=int,
+ metavar="N",
+ help="num decoder attention heads",
+ )
+ parser.add_argument(
+ "--decoder-normalize-before",
+ action="store_true",
+ help="apply layernorm before each decoder block",
+ )
+ parser.add_argument(
+ "--layernorm-embedding",
+ action="store_true",
+ help="add layernorm to embedding",
+ )
+ parser.add_argument(
+ "--no-scale-embedding",
+ action="store_true",
+ help="if True, dont scale embeddings",
+ )
+ parser.add_argument(
+ "--load-pretrained-encoder-from",
+ type=str,
+ metavar="STR",
+ help="model to take encoder weights from (for initialization)",
)
@classmethod
@@ -127,14 +204,15 @@ def build_encoder(cls, args):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from
)
- logger.info(f'loaded pretrained encoder from: '
- f'{args.load_pretrained_encoder_from}')
+ logger.info(
+ f"loaded pretrained encoder from: "
+ f"{args.load_pretrained_encoder_from}"
+ )
return encoder
@classmethod
def build_decoder(cls, args, task, embed_tokens):
- return TransformerDecoderScriptable(args, task.target_dictionary,
- embed_tokens)
+ return TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
@classmethod
def build_model(cls, args, task):
@@ -148,8 +226,9 @@ def build_embedding(dictionary, embed_dim):
padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx)
- decoder_embed_tokens = build_embedding(task.target_dictionary,
- args.decoder_embed_dim)
+ decoder_embed_tokens = build_embedding(
+ task.target_dictionary, args.decoder_embed_dim
+ )
encoder = cls.build_encoder(args)
decoder = cls.build_decoder(args, task, decoder_embed_tokens)
return cls(encoder, decoder)
@@ -161,8 +240,7 @@ def get_normalized_probs(
sample: Optional[Dict[str, Tensor]] = None,
):
# net_output['encoder_out'] is a (B, T, D) tensor
- lprobs = self.get_normalized_probs_scriptable(net_output, log_probs,
- sample)
+ lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
lprobs.batch_first = True
return lprobs
@@ -172,10 +250,10 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens):
argument in its input, which is not supported in torchscript. This
method overrites the forward method definition without **kwargs.
"""
- encoder_out = self.encoder(src_tokens=src_tokens,
- src_lengths=src_lengths)
- decoder_out = self.decoder(prev_output_tokens=prev_output_tokens,
- encoder_out=encoder_out)
+ encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
+ decoder_out = self.decoder(
+ prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
+ )
return decoder_out
@@ -196,13 +274,13 @@ def __init__(self, args):
self.subsample = Conv1dSubsampler(
args.input_feat_per_channel * args.input_channels,
- args.conv_channels, args.encoder_embed_dim,
- [int(k) for k in args.conv_kernel_sizes.split(',')]
+ args.conv_channels,
+ args.encoder_embed_dim,
+ [int(k) for k in args.conv_kernel_sizes.split(",")],
)
self.embed_positions = PositionalEmbedding(
- args.max_source_positions, args.encoder_embed_dim,
- self.padding_idx
+ args.max_source_positions, args.encoder_embed_dim, self.padding_idx
)
self.transformer_layers = nn.ModuleList(
@@ -232,9 +310,12 @@ def forward(self, src_tokens, src_lengths):
x = self.layer_norm(x)
return EncoderOut(
- encoder_out=x, encoder_padding_mask=encoder_padding_mask,
- encoder_embedding=None, encoder_states=None, src_tokens=None,
- src_lengths=None
+ encoder_out=x,
+ encoder_padding_mask=encoder_padding_mask,
+ encoder_embedding=None,
+ encoder_states=None,
+ src_tokens=None,
+ src_lengths=None,
)
@torch.jit.export
@@ -245,8 +326,7 @@ def reorder_encoder_out(self, encoder_out: EncoderOut, new_order):
variables for Torchscript Optional refinement
"""
- encoder_padding_mask: Optional[Tensor] = \
- encoder_out.encoder_padding_mask
+ encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask
encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding
new_encoder_out = (
@@ -294,40 +374,40 @@ def extract_features(
):
# call scriptable method from parent class
x, _ = self.extract_features_scriptable(
- prev_output_tokens, encoder_out, incremental_state,
- full_context_alignment, alignment_layer, alignment_heads,
+ prev_output_tokens,
+ encoder_out,
+ incremental_state,
+ full_context_alignment,
+ alignment_layer,
+ alignment_heads,
)
return x, None
-@register_model_architecture(model_name="s2t_transformer",
- arch_name="s2t_transformer")
+@register_model_architecture(model_name="s2t_transformer", arch_name="s2t_transformer")
def base_architecture(args):
# Convolutional subsampler
- args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", '5,5')
+ args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
args.conv_channels = getattr(args, "conv_channels", 1024)
# Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
- args.encoder_normalize_before = getattr(args, "encoder_normalize_before",
- True)
- args.decoder_embed_dim = getattr(args, "decoder_embed_dim",
- args.encoder_embed_dim)
- args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim",
- args.encoder_ffn_embed_dim)
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
+ args.decoder_ffn_embed_dim = getattr(
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
+ )
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
- args.decoder_normalize_before = getattr(args, "decoder_normalize_before",
- True)
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
args.activation_fn = getattr(args, "activation_fn", "relu")
- args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff",
- None)
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
@@ -337,10 +417,10 @@ def base_architecture(args):
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
- args.decoder_output_dim = getattr(args, "decoder_output_dim",
- args.decoder_embed_dim)
- args.decoder_input_dim = getattr(args, "decoder_input_dim",
- args.decoder_embed_dim)
+ args.decoder_output_dim = getattr(
+ args, "decoder_output_dim", args.decoder_embed_dim
+ )
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
@@ -380,8 +460,7 @@ def s2t_transformer_mp(args):
@register_model_architecture("s2t_transformer", "s2t_transformer_l")
def s2t_transformer_l(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
- args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim",
- 1024 * 4)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.dropout = getattr(args, "dropout", 0.2)
diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py
index ca1c6aaf5c..fbb7ce2338 100644
--- a/fairseq/models/transformer.py
+++ b/fairseq/models/transformer.py
@@ -30,6 +30,7 @@
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
+
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
@@ -308,7 +309,9 @@ def __init__(self, args, dictionary, embed_tokens):
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
- self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__)
+ self.dropout_module = FairseqDropout(
+ args.dropout, module_name=self.__class__.__name__
+ )
self.encoder_layerdrop = args.encoder_layerdrop
embed_dim = embed_tokens.embedding_dim
@@ -543,7 +546,9 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
self.register_buffer("version", torch.Tensor([3]))
self._future_mask = torch.empty(0)
- self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__)
+ self.dropout_module = FairseqDropout(
+ args.dropout, module_name=self.__class__.__name__
+ )
self.decoder_layerdrop = args.decoder_layerdrop
self.share_input_output_embed = args.share_decoder_input_output_embed
diff --git a/fairseq/models/transformer_align.py b/fairseq/models/transformer_align.py
index c80cc4341c..eaf585bd10 100644
--- a/fairseq/models/transformer_align.py
+++ b/fairseq/models/transformer_align.py
@@ -5,9 +5,9 @@
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
+ TransformerModel,
base_architecture,
transformer_wmt_en_de_big,
- TransformerModel,
)
diff --git a/fairseq/models/transformer_from_pretrained_xlm.py b/fairseq/models/transformer_from_pretrained_xlm.py
index bd03c8450f..236d9942e1 100644
--- a/fairseq/models/transformer_from_pretrained_xlm.py
+++ b/fairseq/models/transformer_from_pretrained_xlm.py
@@ -19,7 +19,6 @@
@register_model("transformer_from_pretrained_xlm")
class TransformerFromPretrainedXLMModel(TransformerModel):
-
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
@@ -96,25 +95,24 @@ def upgrade_state_dict_with_xlm_weights(
for search_key in ["embed_tokens", "embed_positions", "layers"]:
if search_key in key:
- subkey = key[key.find(search_key):]
+ subkey = key[key.find(search_key) :]
assert subkey in state_dict, (
"{} Transformer encoder / decoder "
"state_dict does not contain {}. Cannot "
"load {} from pretrained XLM checkpoint "
"{} into Transformer.".format(
- str(state_dict.keys()),
- subkey, key, pretrained_xlm_checkpoint)
+ str(state_dict.keys()), subkey, key, pretrained_xlm_checkpoint
)
+ )
state_dict[subkey] = xlm_state_dict[key]
return state_dict
class TransformerEncoderFromPretrainedXLM(TransformerEncoder):
-
def __init__(self, args, dictionary, embed_tokens):
super().__init__(args, dictionary, embed_tokens)
- if getattr(args, 'init_decoder_only', False):
+ if getattr(args, "init_decoder_only", False):
# Don't load XLM weights for encoder if --init-decoder-only
return
@@ -130,10 +128,9 @@ def __init__(self, args, dictionary, embed_tokens):
class TransformerDecoderFromPretrainedXLM(TransformerDecoder):
-
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
super().__init__(args, dictionary, embed_tokens, no_encoder_attn)
- if getattr(args, 'init_encoder_only', False):
+ if getattr(args, "init_encoder_only", False):
# Don't load XLM weights for decoder if --init-encoder-only
return
assert hasattr(args, "pretrained_xlm_checkpoint"), (
diff --git a/fairseq/models/wav2vec/wav2vec.py b/fairseq/models/wav2vec/wav2vec.py
index 905df824f3..772995b526 100644
--- a/fairseq/models/wav2vec/wav2vec.py
+++ b/fairseq/models/wav2vec/wav2vec.py
@@ -10,7 +10,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
from fairseq.modules import (
Fp32GroupNorm,
@@ -21,6 +20,7 @@
)
from fairseq.utils import buffered_arange
+
logger = logging.getLogger(__name__)
diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py
index 4f1ab2277f..6a0f787601 100644
--- a/fairseq/models/wav2vec/wav2vec2.py
+++ b/fairseq/models/wav2vec/wav2vec2.py
@@ -5,14 +5,12 @@
import logging
import math
-import numpy as np
+from typing import List, Tuple
+import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-
-from typing import List, Tuple
-
from fairseq import utils
from fairseq.data.data_utils import compute_mask_indices
from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py
index e47e1f7009..52ca9a8007 100644
--- a/fairseq/models/wav2vec/wav2vec2_asr.py
+++ b/fairseq/models/wav2vec/wav2vec2_asr.py
@@ -6,19 +6,17 @@
import contextlib
import copy
import math
-import numpy as np
+import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq import checkpoint_utils, tasks, utils
-
from fairseq.models import (
+ BaseFairseqModel,
FairseqEncoder,
- FairseqIncrementalDecoder,
FairseqEncoderDecoderModel,
- BaseFairseqModel,
+ FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
@@ -335,7 +333,9 @@ def __init__(self, args, tgt_dict=None):
state = None
w2v_args = args.w2v_args
- assert args.normalize == w2v_args.normalize, 'Fine-tuning works best when data normalization is the same'
+ assert (
+ args.normalize == w2v_args.normalize
+ ), "Fine-tuning works best when data normalization is the same"
w2v_args.data = args.data
task = tasks.setup_task(w2v_args)
@@ -358,7 +358,7 @@ def __init__(self, args, tgt_dict=None):
if tgt_dict is not None:
self.proj = Linear(d, len(tgt_dict))
- elif getattr(args, 'decoder_embed_dim', d) != d:
+ elif getattr(args, "decoder_embed_dim", d) != d:
self.proj = Linear(d, args.decoder_embed_dim)
else:
self.proj = None
@@ -668,6 +668,8 @@ def seq2seq_architecture(args):
args.decoder_dropout = getattr(args, "decoder_dropout", 0)
args.decoder_attention_dropout = getattr(args, "decoder_attention_dropout", 0)
args.decoder_activation_dropout = getattr(args, "decoder_activation_dropout", 0)
- args.share_decoder_input_output_embed = getattr(args, "share_decoder_input_output_embed", False)
+ args.share_decoder_input_output_embed = getattr(
+ args, "share_decoder_input_output_embed", False
+ )
base_architecture(args)
diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py
index 52432e0de4..e2326ac6e3 100644
--- a/fairseq/modules/__init__.py
+++ b/fairseq/modules/__init__.py
@@ -37,40 +37,40 @@
from .vggblock import VGGBlock
__all__ = [
- 'AdaptiveInput',
- 'AdaptiveSoftmax',
- 'BeamableMM',
- 'CharacterTokenEmbedder',
- 'ConvTBC',
- 'cross_entropy',
- 'DownsampledMultiHeadAttention',
- 'DynamicConv1dTBC',
- 'DynamicConv',
- 'DynamicCRF',
- 'FairseqDropout',
- 'Fp32GroupNorm',
- 'Fp32LayerNorm',
- 'gelu',
- 'gelu_accurate',
- 'GradMultiply',
- 'GumbelVectorQuantizer',
- 'KmeansVectorQuantizer',
- 'LayerDropModuleList',
- 'LayerNorm',
- 'LearnedPositionalEmbedding',
- 'LightweightConv1dTBC',
- 'LightweightConv',
- 'LinearizedConvolution',
- 'MultiheadAttention',
- 'PositionalEmbedding',
- 'SamePad',
- 'ScalarBias',
- 'SinusoidalPositionalEmbedding',
- 'TransformerSentenceEncoderLayer',
- 'TransformerSentenceEncoder',
- 'TransformerDecoderLayer',
- 'TransformerEncoderLayer',
- 'TransposeLast',
- 'VGGBlock',
- 'unfold1d',
+ "AdaptiveInput",
+ "AdaptiveSoftmax",
+ "BeamableMM",
+ "CharacterTokenEmbedder",
+ "ConvTBC",
+ "cross_entropy",
+ "DownsampledMultiHeadAttention",
+ "DynamicConv1dTBC",
+ "DynamicConv",
+ "DynamicCRF",
+ "FairseqDropout",
+ "Fp32GroupNorm",
+ "Fp32LayerNorm",
+ "gelu",
+ "gelu_accurate",
+ "GradMultiply",
+ "GumbelVectorQuantizer",
+ "KmeansVectorQuantizer",
+ "LayerDropModuleList",
+ "LayerNorm",
+ "LearnedPositionalEmbedding",
+ "LightweightConv1dTBC",
+ "LightweightConv",
+ "LinearizedConvolution",
+ "MultiheadAttention",
+ "PositionalEmbedding",
+ "SamePad",
+ "ScalarBias",
+ "SinusoidalPositionalEmbedding",
+ "TransformerSentenceEncoderLayer",
+ "TransformerSentenceEncoder",
+ "TransformerDecoderLayer",
+ "TransformerEncoderLayer",
+ "TransposeLast",
+ "VGGBlock",
+ "unfold1d",
]
diff --git a/fairseq/modules/adaptive_input.py b/fairseq/modules/adaptive_input.py
index 4cfe8fca66..446534a9f8 100644
--- a/fairseq/modules/adaptive_input.py
+++ b/fairseq/modules/adaptive_input.py
@@ -4,15 +4,14 @@
# LICENSE file in the root directory of this source tree.
+from typing import List
+
import torch
-from torch import nn
from fairseq.modules.quant_noise import quant_noise
-
-from typing import List
+from torch import nn
class AdaptiveInput(nn.Module):
-
def __init__(
self,
vocab_size: int,
@@ -29,8 +28,9 @@ def __init__(
if vocab_size > cutoff[-1]:
cutoff = cutoff + [vocab_size]
else:
- assert vocab_size == cutoff[
- -1], 'cannot specify cutoff larger than vocab size'
+ assert (
+ vocab_size == cutoff[-1]
+ ), "cannot specify cutoff larger than vocab size"
self.cutoff = cutoff
self.embedding_dim = output_dim
@@ -43,7 +43,9 @@ def __init__(
dim = int(initial_dim // (factor ** i))
seq = nn.Sequential(
nn.Embedding(size, dim, self.padding_idx),
- quant_noise(nn.Linear(dim, output_dim, bias=False), q_noise, qn_block_size),
+ quant_noise(
+ nn.Linear(dim, output_dim, bias=False), q_noise, qn_block_size
+ ),
)
self.embeddings.append(seq)
@@ -54,12 +56,12 @@ def init_weights(m):
if isinstance(m, nn.Embedding):
nn.init.normal_(m.weight, mean=0, std=m.weight.shape[1] ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
- elif hasattr(m, 'weight'):
+ elif hasattr(m, "weight"):
nn.init.xavier_uniform_(m.weight)
self.apply(init_weights)
- self.register_buffer('_float_tensor', torch.FloatTensor(1))
+ self.register_buffer("_float_tensor", torch.FloatTensor(1))
def weights_for_band(self, band: int):
return self.embeddings[band][0].weight, self.embeddings[band][1].weight
diff --git a/fairseq/modules/adaptive_softmax.py b/fairseq/modules/adaptive_softmax.py
index 8e47134a70..ae0c77ba0f 100644
--- a/fairseq/modules/adaptive_softmax.py
+++ b/fairseq/modules/adaptive_softmax.py
@@ -3,13 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import operator
import functools
+import operator
import torch
import torch.nn.functional as F
-from fairseq.modules.quant_noise import quant_noise
from fairseq.modules.fairseq_dropout import FairseqDropout
+from fairseq.modules.quant_noise import quant_noise
from torch import nn
@@ -29,23 +29,29 @@ def __init__(self, weights, input_dim, num_classes, q_noise, qn_block_size):
tied_emb, _ = weights
self.num_words, emb_dim = tied_emb.size()
- self.word_proj = quant_noise(TiedLinear(tied_emb, transpose=False), q_noise, qn_block_size)
+ self.word_proj = quant_noise(
+ TiedLinear(tied_emb, transpose=False), q_noise, qn_block_size
+ )
if input_dim != emb_dim:
self.word_proj = nn.Sequential(
- quant_noise(nn.Linear(input_dim, emb_dim, bias=False), q_noise, qn_block_size),
+ quant_noise(
+ nn.Linear(input_dim, emb_dim, bias=False), q_noise, qn_block_size
+ ),
self.word_proj,
)
- self.class_proj = quant_noise(nn.Linear(input_dim, num_classes, bias=False), q_noise, qn_block_size)
+ self.class_proj = quant_noise(
+ nn.Linear(input_dim, num_classes, bias=False), q_noise, qn_block_size
+ )
self.out_dim = self.num_words + num_classes
- self.register_buffer('_float_tensor', torch.FloatTensor(1))
+ self.register_buffer("_float_tensor", torch.FloatTensor(1))
def forward(self, input):
inp_sz = functools.reduce(operator.mul, input.shape[:-1], 1)
out = self._float_tensor.new(inp_sz, self.out_dim)
- out[:, :self.num_words] = self.word_proj(input.view(inp_sz, -1))
- out[:, self.num_words:] = self.class_proj(input.view(inp_sz, -1))
+ out[:, : self.num_words] = self.word_proj(input.view(inp_sz, -1))
+ out[:, self.num_words :] = self.class_proj(input.view(inp_sz, -1))
return out
@@ -56,21 +62,34 @@ class AdaptiveSoftmax(nn.Module):
approximation for GPUs" (http://arxiv.org/abs/1609.04309).
"""
- def __init__(self, vocab_size, input_dim, cutoff, dropout, factor=4., adaptive_inputs=None, tie_proj=False,
- q_noise=0, qn_block_size=8):
+ def __init__(
+ self,
+ vocab_size,
+ input_dim,
+ cutoff,
+ dropout,
+ factor=4.0,
+ adaptive_inputs=None,
+ tie_proj=False,
+ q_noise=0,
+ qn_block_size=8,
+ ):
super().__init__()
if vocab_size > cutoff[-1]:
cutoff = cutoff + [vocab_size]
else:
- assert vocab_size == cutoff[
- -1], 'cannot specify cutoff larger than vocab size'
+ assert (
+ vocab_size == cutoff[-1]
+ ), "cannot specify cutoff larger than vocab size"
output_dim = cutoff[0] + len(cutoff) - 1
self.vocab_size = vocab_size
self.cutoff = cutoff
- self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__)
+ self.dropout_module = FairseqDropout(
+ dropout, module_name=self.__class__.__name__
+ )
self.input_dim = input_dim
self.factor = factor
self.q_noise = q_noise
@@ -79,38 +98,69 @@ def __init__(self, vocab_size, input_dim, cutoff, dropout, factor=4., adaptive_i
self.lsm = nn.LogSoftmax(dim=1)
if adaptive_inputs is not None:
- self.head = TiedHeadModule(adaptive_inputs.weights_for_band(0), input_dim, len(cutoff) - 1, self.q_noise, self.qn_block_size)
+ self.head = TiedHeadModule(
+ adaptive_inputs.weights_for_band(0),
+ input_dim,
+ len(cutoff) - 1,
+ self.q_noise,
+ self.qn_block_size,
+ )
else:
- self.head = quant_noise(nn.Linear(input_dim, output_dim, bias=False), self.q_noise, self.qn_block_size)
+ self.head = quant_noise(
+ nn.Linear(input_dim, output_dim, bias=False),
+ self.q_noise,
+ self.qn_block_size,
+ )
self._make_tail(adaptive_inputs, tie_proj)
def init_weights(m):
- if hasattr(m, 'weight') and not isinstance(m, TiedLinear) and not isinstance(m, TiedHeadModule):
+ if (
+ hasattr(m, "weight")
+ and not isinstance(m, TiedLinear)
+ and not isinstance(m, TiedHeadModule)
+ ):
nn.init.xavier_uniform_(m.weight)
self.apply(init_weights)
- self.register_buffer('version', torch.LongTensor([1]))
+ self.register_buffer("version", torch.LongTensor([1]))
def _make_tail(self, adaptive_inputs=None, tie_proj=False):
self.tail = nn.ModuleList()
for i in range(len(self.cutoff) - 1):
dim = int(self.input_dim // self.factor ** (i + 1))
- tied_emb, tied_proj = adaptive_inputs.weights_for_band(i + 1) \
- if adaptive_inputs is not None else (None, None)
+ tied_emb, tied_proj = (
+ adaptive_inputs.weights_for_band(i + 1)
+ if adaptive_inputs is not None
+ else (None, None)
+ )
if tied_proj is not None:
if tie_proj:
- proj = quant_noise(TiedLinear(tied_proj, transpose=True), self.q_noise, self.qn_block_size)
+ proj = quant_noise(
+ TiedLinear(tied_proj, transpose=True),
+ self.q_noise,
+ self.qn_block_size,
+ )
else:
- proj = quant_noise(nn.Linear(tied_proj.size(0), tied_proj.size(1), bias=False), self.q_noise, self.qn_block_size)
+ proj = quant_noise(
+ nn.Linear(tied_proj.size(0), tied_proj.size(1), bias=False),
+ self.q_noise,
+ self.qn_block_size,
+ )
else:
- proj = quant_noise(nn.Linear(self.input_dim, dim, bias=False), self.q_noise, self.qn_block_size)
+ proj = quant_noise(
+ nn.Linear(self.input_dim, dim, bias=False),
+ self.q_noise,
+ self.qn_block_size,
+ )
if tied_emb is None:
- out_proj = nn.Linear(dim, self.cutoff[i + 1] - self.cutoff[i], bias=False)
+ out_proj = nn.Linear(
+ dim, self.cutoff[i + 1] - self.cutoff[i], bias=False
+ )
else:
out_proj = TiedLinear(tied_emb, transpose=False)
@@ -123,9 +173,9 @@ def _make_tail(self, adaptive_inputs=None, tie_proj=False):
self.tail.append(m)
def upgrade_state_dict_named(self, state_dict, name):
- version_name = name + '.version'
+ version_name = name + ".version"
if version_name not in state_dict:
- raise Exception('This version of the model is no longer supported')
+ raise Exception("This version of the model is no longer supported")
def adapt_target(self, target):
"""
@@ -194,7 +244,7 @@ def get_log_prob(self, input, target):
head_sz = self.cutoff[0] + len(self.tail)
log_probs[:, :head_sz] = self.lsm(head_y)
- tail_priors = log_probs[:, self.cutoff[0]: head_sz].clone()
+ tail_priors = log_probs[:, self.cutoff[0] : head_sz].clone()
for i in range(len(self.tail)):
start = self.cutoff[i]
@@ -203,12 +253,16 @@ def get_log_prob(self, input, target):
if target_idxs is None:
tail_out = log_probs[:, start:end]
tail_out.copy_(self.tail[i](input))
- log_probs[:, start:end] = self.lsm(tail_out).add_(tail_priors[:, i, None])
+ log_probs[:, start:end] = self.lsm(tail_out).add_(
+ tail_priors[:, i, None]
+ )
elif target_idxs[i] is not None:
idxs = target_idxs[i]
tail_out = log_probs[idxs, start:end]
tail_out.copy_(self.tail[i](input[idxs]))
- log_probs[idxs, start:end] = self.lsm(tail_out).add_(tail_priors[idxs, i, None])
+ log_probs[idxs, start:end] = self.lsm(tail_out).add_(
+ tail_priors[idxs, i, None]
+ )
log_probs = log_probs.view(bsz, length, -1)
return log_probs
diff --git a/fairseq/modules/beamable_mm.py b/fairseq/modules/beamable_mm.py
index df77105a94..eff1a4607f 100644
--- a/fairseq/modules/beamable_mm.py
+++ b/fairseq/modules/beamable_mm.py
@@ -15,16 +15,18 @@ class BeamableMM(nn.Module):
inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)}
with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
"""
+
def __init__(self, beam_size=None):
super(BeamableMM, self).__init__()
self.beam_size = beam_size
def forward(self, input1, input2):
if (
- not self.training and # test mode
- self.beam_size is not None and # beam size is set
- input1.dim() == 3 and # only support batched input
- input1.size(1) == 1 # single time step update
+ not self.training
+ and self.beam_size is not None # test mode
+ and input1.dim() == 3 # beam size is set
+ and input1.size(1) # only support batched input
+ == 1 # single time step update
):
bsz, beam = input1.size(0), self.beam_size
diff --git a/fairseq/modules/character_token_embedder.py b/fairseq/modules/character_token_embedder.py
index 3abdaf4f28..181221b61b 100644
--- a/fairseq/modules/character_token_embedder.py
+++ b/fairseq/modules/character_token_embedder.py
@@ -7,10 +7,10 @@
from typing import List, Tuple
import torch
-from torch import nn
import torch.nn.functional as F
-
from fairseq.data import Dictionary
+from torch import nn
+
CHAR_PAD_IDX = 0
CHAR_EOS_IDX = 257
@@ -21,14 +21,14 @@
class CharacterTokenEmbedder(torch.nn.Module):
def __init__(
- self,
- vocab: Dictionary,
- filters: List[Tuple[int, int]],
- char_embed_dim: int,
- word_embed_dim: int,
- highway_layers: int,
- max_char_len: int = 50,
- char_inputs: bool = False
+ self,
+ vocab: Dictionary,
+ filters: List[Tuple[int, int]],
+ char_embed_dim: int,
+ word_embed_dim: int,
+ highway_layers: int,
+ max_char_len: int = 50,
+ char_inputs: bool = False,
):
super(CharacterTokenEmbedder, self).__init__()
@@ -52,7 +52,9 @@ def __init__(
self.projection = nn.Linear(last_dim, word_embed_dim)
- assert vocab is not None or char_inputs, "vocab must be set if not using char inputs"
+ assert (
+ vocab is not None or char_inputs
+ ), "vocab must be set if not using char inputs"
self.vocab = None
if vocab is not None:
self.set_vocab(vocab, max_char_len)
@@ -79,7 +81,11 @@ def set_vocab(self, vocab, max_char_len):
word_to_char[i] = torch.LongTensor(char_idxs)
if truncated > 0:
- logger.info('truncated {} words longer than {} characters'.format(truncated, max_char_len))
+ logger.info(
+ "truncated {} words longer than {} characters".format(
+ truncated, max_char_len
+ )
+ )
self.vocab = vocab
self.word_to_char = word_to_char
@@ -93,12 +99,14 @@ def reset_parameters(self):
nn.init.xavier_normal_(self.symbol_embeddings)
nn.init.xavier_uniform_(self.projection.weight)
- nn.init.constant_(self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.)
- nn.init.constant_(self.projection.bias, 0.)
+ nn.init.constant_(
+ self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.0
+ )
+ nn.init.constant_(self.projection.bias, 0.0)
def forward(
- self,
- input: torch.Tensor,
+ self,
+ input: torch.Tensor,
):
if self.char_inputs:
chars = input.view(-1, self.max_char_len)
@@ -113,7 +121,9 @@ def forward(
unk = None
else:
flat_words = input.view(-1)
- chars = self.word_to_char[flat_words.type_as(self.word_to_char)].type_as(input)
+ chars = self.word_to_char[flat_words.type_as(self.word_to_char)].type_as(
+ input
+ )
pads = flat_words.eq(self.vocab.pad())
eos = flat_words.eq(self.vocab.eos())
unk = flat_words.eq(self.vocab.unk())
@@ -121,11 +131,17 @@ def forward(
word_embs = self._convolve(chars)
if self.onnx_trace:
if pads.any():
- word_embs = torch.where(pads.unsqueeze(1), word_embs.new_zeros(1), word_embs)
+ word_embs = torch.where(
+ pads.unsqueeze(1), word_embs.new_zeros(1), word_embs
+ )
if eos.any():
- word_embs = torch.where(eos.unsqueeze(1), self.symbol_embeddings[self.eos_idx], word_embs)
+ word_embs = torch.where(
+ eos.unsqueeze(1), self.symbol_embeddings[self.eos_idx], word_embs
+ )
if unk is not None and unk.any():
- word_embs = torch.where(unk.unsqueeze(1), self.symbol_embeddings[self.unk_idx], word_embs)
+ word_embs = torch.where(
+ unk.unsqueeze(1), self.symbol_embeddings[self.unk_idx], word_embs
+ )
else:
if pads.any():
word_embs[pads] = 0
@@ -137,8 +153,8 @@ def forward(
return word_embs.view(input.size()[:2] + (-1,))
def _convolve(
- self,
- char_idxs: torch.Tensor,
+ self,
+ char_idxs: torch.Tensor,
):
char_embs = self.char_embeddings(char_idxs)
char_embs = char_embs.transpose(1, 2) # BTC -> BCT
@@ -166,15 +182,12 @@ class Highway(torch.nn.Module):
Adopted from the AllenNLP implementation.
"""
- def __init__(
- self,
- input_dim: int,
- num_layers: int = 1
- ):
+ def __init__(self, input_dim: int, num_layers: int = 1):
super(Highway, self).__init__()
self.input_dim = input_dim
- self.layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2)
- for _ in range(num_layers)])
+ self.layers = nn.ModuleList(
+ [nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)]
+ )
self.activation = nn.ReLU()
self.reset_parameters()
@@ -186,15 +199,12 @@ def reset_parameters(self):
# setting the bias on `B(x)` to be positive, because that means `g` will be biased to
# be high, so we will carry the input forward. The bias on `B(x)` is the second half
# of the bias vector in each Linear layer.
- nn.init.constant_(layer.bias[self.input_dim:], 1)
+ nn.init.constant_(layer.bias[self.input_dim :], 1)
- nn.init.constant_(layer.bias[:self.input_dim], 0)
+ nn.init.constant_(layer.bias[: self.input_dim], 0)
nn.init.xavier_normal_(layer.weight)
- def forward(
- self,
- x: torch.Tensor
- ):
+ def forward(self, x: torch.Tensor):
for layer in self.layers:
projection = layer(x)
proj_x, gate = projection.chunk(2, dim=-1)
diff --git a/fairseq/modules/conv_tbc.py b/fairseq/modules/conv_tbc.py
index 1aa3eff9dc..2dc46c4b9b 100644
--- a/fairseq/modules/conv_tbc.py
+++ b/fairseq/modules/conv_tbc.py
@@ -13,6 +13,7 @@ class ConvTBC(torch.nn.Module):
The implementation uses gemm to perform the convolution. This implementation
is faster than cuDNN for small kernel sizes.
"""
+
def __init__(self, in_channels, out_channels, kernel_size, padding=0):
super(ConvTBC, self).__init__()
self.in_channels = in_channels
@@ -20,17 +21,22 @@ def __init__(self, in_channels, out_channels, kernel_size, padding=0):
self.kernel_size = _single(kernel_size)
self.padding = _single(padding)
- self.weight = torch.nn.Parameter(torch.Tensor(
- self.kernel_size[0], in_channels, out_channels))
+ self.weight = torch.nn.Parameter(
+ torch.Tensor(self.kernel_size[0], in_channels, out_channels)
+ )
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
def forward(self, input):
- return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding[0])
+ return torch.conv_tbc(
+ input.contiguous(), self.weight, self.bias, self.padding[0]
+ )
def __repr__(self):
- s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
- ', padding={padding}')
+ s = (
+ "{name}({in_channels}, {out_channels}, kernel_size={kernel_size}"
+ ", padding={padding}"
+ )
if self.bias is None:
- s += ', bias=False'
- s += ')'
+ s += ", bias=False"
+ s += ")"
return s.format(name=self.__class__.__name__, **self.__dict__)
diff --git a/fairseq/modules/cross_entropy.py b/fairseq/modules/cross_entropy.py
index b46143f3af..0d2beb44bb 100644
--- a/fairseq/modules/cross_entropy.py
+++ b/fairseq/modules/cross_entropy.py
@@ -12,10 +12,13 @@
logger = logging.getLogger(__name__)
-def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction='mean'):
+def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction="mean"):
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
return F.nll_loss(
- lprobs, target, ignore_index=ignore_index, reduction=reduction,
+ lprobs,
+ target,
+ ignore_index=ignore_index,
+ reduction=reduction,
)
@@ -23,29 +26,34 @@ def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction='mean'):
import xentropy_cuda
from apex.contrib import xentropy
- logger.info('using fused cross entropy')
+ logger.info("using fused cross entropy")
- def cross_entropy(logits, target, ignore_index=-100, reduction='mean'):
- if logits.device == torch.device('cpu'):
+ def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
+ if logits.device == torch.device("cpu"):
return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
else:
- half_to_float = (logits.dtype == torch.half)
+ half_to_float = logits.dtype == torch.half
losses = xentropy.SoftmaxCrossEntropyLoss.apply(
- logits, target, 0.0, ignore_index, half_to_float,
+ logits,
+ target,
+ 0.0,
+ ignore_index,
+ half_to_float,
)
- if reduction == 'sum':
+ if reduction == "sum":
return losses.sum()
- elif reduction == 'mean':
+ elif reduction == "mean":
if ignore_index >= 0:
return losses.sum() / target.ne(ignore_index).sum()
else:
return losses.mean()
- elif reduction == 'none':
+ elif reduction == "none":
return losses
else:
raise NotImplementedError
+
except ImportError:
- def cross_entropy(logits, target, ignore_index=-100, reduction='mean'):
+ def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
diff --git a/fairseq/modules/downsampled_multihead_attention.py b/fairseq/modules/downsampled_multihead_attention.py
index eeaf9bbdd3..2cdece3f7f 100644
--- a/fairseq/modules/downsampled_multihead_attention.py
+++ b/fairseq/modules/downsampled_multihead_attention.py
@@ -9,22 +9,33 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from fairseq.modules.scalar_bias import scalar_bias
from fairseq.modules.fairseq_dropout import FairseqDropout
+from fairseq.modules.scalar_bias import scalar_bias
class SingleHeadAttention(nn.Module):
"""
Single-head attention that supports Gating and Downsampling
"""
+
def __init__(
- self, out_channels, embed_dim, head_dim, head_index, dropout=0.,
- bias=True, project_input=True, gated=False, downsample=False,
+ self,
+ out_channels,
+ embed_dim,
+ head_dim,
+ head_index,
+ dropout=0.0,
+ bias=True,
+ project_input=True,
+ gated=False,
+ downsample=False,
num_heads=1,
):
super().__init__()
self.embed_dim = embed_dim
- self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__)
+ self.dropout_module = FairseqDropout(
+ dropout, module_name=self.__class__.__name__
+ )
self.head_index = head_index
self.head_dim = head_dim
self.project_input = project_input
@@ -58,11 +69,16 @@ def __init__(
else:
self.out_proj = Linear(out_proj_size, out_channels, bias=bias)
- self.scaling = self.head_dim**-0.5
+ self.scaling = self.head_dim ** -0.5
def forward(
- self, query, key, value, mask_future_timesteps=False,
- key_padding_mask=None, use_scalar_bias=False,
+ self,
+ query,
+ key,
+ value,
+ mask_future_timesteps=False,
+ key_padding_mask=None,
+ use_scalar_bias=False,
):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
@@ -106,16 +122,17 @@ def forward(
attn_weights = torch.bmm(q, k.transpose(1, 2))
if mask_future_timesteps:
- assert query.size() == key.size(), \
- 'mask_future_timesteps only applies to self-attention'
+ assert (
+ query.size() == key.size()
+ ), "mask_future_timesteps only applies to self-attention"
attn_weights *= torch.tril(
attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(),
diagonal=-1,
- )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0)
+ )[:, :: self.head_index + 1 if self.downsample else 1].unsqueeze(0)
attn_weights += torch.triu(
attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(),
- diagonal=0
- )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0)
+ diagonal=0,
+ )[:, :: self.head_index + 1 if self.downsample else 1].unsqueeze(0)
tgt_size = tgt_len
if use_scalar_bias:
attn_weights = scalar_bias(attn_weights, 2)
@@ -128,7 +145,9 @@ def forward(
if self.downsample:
attn_weights = attn_weights.view(bsz, 1, tgt_len, src_len)
else:
- attn_weights = attn_weights.view(size, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(
+ size, self.num_heads, tgt_len, src_len
+ )
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
-math.inf,
@@ -152,9 +171,17 @@ class DownsampledMultiHeadAttention(nn.ModuleList):
"""
Multi-headed attention with Gating and Downsampling
"""
+
def __init__(
- self, out_channels, embed_dim, num_heads, dropout=0., bias=True,
- project_input=True, gated=False, downsample=False,
+ self,
+ out_channels,
+ embed_dim,
+ num_heads,
+ dropout=0.0,
+ bias=True,
+ project_input=True,
+ gated=False,
+ downsample=False,
):
self.embed_dim = embed_dim
self.num_heads = num_heads
@@ -169,9 +196,16 @@ def __init__(
for index in range(self.num_heads):
attention_heads.append(
SingleHeadAttention(
- out_channels, self.embed_dim, self.head_dim, index,
- dropout, bias, self.project_input, self.gated,
- self.downsample, self.num_heads,
+ out_channels,
+ self.embed_dim,
+ self.head_dim,
+ index,
+ dropout,
+ bias,
+ self.project_input,
+ self.gated,
+ self.downsample,
+ self.num_heads,
)
)
super().__init__(modules=attention_heads)
@@ -181,13 +215,26 @@ def __init__(
# if not being downsampled, we can do the heads with one linear layer instead of separate ones
super().__init__()
self.attention_module = SingleHeadAttention(
- out_channels, self.embed_dim, self.head_dim, 1, dropout,
- bias, self.project_input, self.gated, self.downsample, self.num_heads,
+ out_channels,
+ self.embed_dim,
+ self.head_dim,
+ 1,
+ dropout,
+ bias,
+ self.project_input,
+ self.gated,
+ self.downsample,
+ self.num_heads,
)
def forward(
- self, query, key, value, mask_future_timesteps=False,
- key_padding_mask=None, use_scalar_bias=False,
+ self,
+ query,
+ key,
+ value,
+ mask_future_timesteps=False,
+ key_padding_mask=None,
+ use_scalar_bias=False,
):
src_len, bsz, embed_dim = key.size()
tgt_len = query.size(0)
@@ -205,7 +252,12 @@ def forward(
for attention_head_number in range(self.num_heads):
# call the forward of each attention head
_attn, _attn_weight = self[attention_head_number](
- query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias,
+ query,
+ key,
+ value,
+ mask_future_timesteps,
+ key_padding_mask,
+ use_scalar_bias,
)
attn.append(_attn)
attn_weights.append(_attn_weight)
@@ -214,13 +266,20 @@ def forward(
return full_attn, attn_weights[0].clone()
else:
_attn, _attn_weight = self.attention_module(
- query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias,
+ query,
+ key,
+ value,
+ mask_future_timesteps,
+ key_padding_mask,
+ use_scalar_bias,
)
attn.append(_attn)
attn_weights.append(_attn_weight)
full_attn = torch.cat(attn, dim=2)
full_attn_weights = torch.cat(attn_weights)
- full_attn_weights = full_attn_weights.view(bsz, self.num_heads, tgt_size, src_len)
+ full_attn_weights = full_attn_weights.view(
+ bsz, self.num_heads, tgt_size, src_len
+ )
full_attn_weights = full_attn_weights.sum(dim=1) / self.num_heads
return full_attn, full_attn_weights
@@ -229,15 +288,16 @@ class Downsample(nn.Module):
"""
Selects every nth element, where n is the index
"""
+
def __init__(self, index):
super().__init__()
self.index = index
def forward(self, x):
- return x[::self.index+1]
+ return x[:: self.index + 1]
-def Linear(in_features, out_features, dropout=0., bias=True):
+def Linear(in_features, out_features, dropout=0.0, bias=True):
"""Weight-normalized Linear layer (input: B x T x C)"""
m = nn.Linear(in_features, out_features, bias=bias)
m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
@@ -245,12 +305,12 @@ def Linear(in_features, out_features, dropout=0., bias=True):
return nn.utils.weight_norm(m)
-def GatedLinear(in_features, out_features, dropout=0., bias=True):
+def GatedLinear(in_features, out_features, dropout=0.0, bias=True):
"""Weight-normalized Linear layer (input: B x T x C) with interspersed GLU units"""
return nn.Sequential(
- Linear(in_features, out_features*4, dropout, bias),
+ Linear(in_features, out_features * 4, dropout, bias),
nn.GLU(),
- Linear(out_features*2, out_features*2, dropout, bias),
+ Linear(out_features * 2, out_features * 2, dropout, bias),
nn.GLU(),
- Linear(out_features, out_features, dropout, bias)
+ Linear(out_features, out_features, dropout, bias),
)
diff --git a/fairseq/modules/dynamic_convolution.py b/fairseq/modules/dynamic_convolution.py
index 5a8ecb99a8..5999a04539 100644
--- a/fairseq/modules/dynamic_convolution.py
+++ b/fairseq/modules/dynamic_convolution.py
@@ -6,43 +6,63 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq import utils
-from .unfold import unfold1d
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
+from .unfold import unfold1d
+
-def DynamicConv(input_size, kernel_size=1, padding_l=None, num_heads=1,
- weight_dropout=0., weight_softmax=False,
- renorm_padding=False, bias=False, conv_bias=False,
- query_size=None, in_proj=False):
+def DynamicConv(
+ input_size,
+ kernel_size=1,
+ padding_l=None,
+ num_heads=1,
+ weight_dropout=0.0,
+ weight_softmax=False,
+ renorm_padding=False,
+ bias=False,
+ conv_bias=False,
+ query_size=None,
+ in_proj=False,
+):
if torch.cuda.is_available():
try:
from fairseq.modules.dynamicconv_layer import DynamicconvLayer
- return DynamicconvLayer(input_size, kernel_size=kernel_size,
- padding_l=padding_l, num_heads=num_heads,
- weight_dropout=weight_dropout,
- weight_softmax=weight_softmax, bias=bias)
+
+ return DynamicconvLayer(
+ input_size,
+ kernel_size=kernel_size,
+ padding_l=padding_l,
+ num_heads=num_heads,
+ weight_dropout=weight_dropout,
+ weight_softmax=weight_softmax,
+ bias=bias,
+ )
except ImportError as e:
print(e)
- return DynamicConv1dTBC(input_size, kernel_size=kernel_size,
- padding_l=padding_l, num_heads=num_heads,
- weight_dropout=weight_dropout,
- weight_softmax=weight_softmax, bias=bias)
+ return DynamicConv1dTBC(
+ input_size,
+ kernel_size=kernel_size,
+ padding_l=padding_l,
+ num_heads=num_heads,
+ weight_dropout=weight_dropout,
+ weight_softmax=weight_softmax,
+ bias=bias,
+ )
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
if bias:
- nn.init.constant_(m.bias, 0.)
+ nn.init.constant_(m.bias, 0.0)
return m
@with_incremental_state
class DynamicConv1dTBC(nn.Module):
- '''Dynamic lightweight convolution taking T x B x C inputs
+ """Dynamic lightweight convolution taking T x B x C inputs
Args:
input_size: # of channels of the input
kernel_size: convolution channels
@@ -64,25 +84,42 @@ class DynamicConv1dTBC(nn.Module):
weight: the learnable weights of the module of shape
`(num_heads, 1, kernel_size)`
bias: the learnable bias of the module of shape `(input_size)`
- '''
- def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1,
- weight_dropout=0., weight_softmax=False,
- renorm_padding=False, bias=False, conv_bias=False,
- query_size=None, in_proj=False):
+ """
+
+ def __init__(
+ self,
+ input_size,
+ kernel_size=1,
+ padding_l=None,
+ num_heads=1,
+ weight_dropout=0.0,
+ weight_softmax=False,
+ renorm_padding=False,
+ bias=False,
+ conv_bias=False,
+ query_size=None,
+ in_proj=False,
+ ):
super().__init__()
self.input_size = input_size
self.query_size = input_size if query_size is None else query_size
self.kernel_size = kernel_size
self.padding_l = padding_l
self.num_heads = num_heads
- self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__)
+ self.weight_dropout_module = FairseqDropout(
+ weight_dropout, module_name=self.__class__.__name__
+ )
self.weight_softmax = weight_softmax
self.renorm_padding = renorm_padding
if in_proj:
- self.weight_linear = Linear(self.input_size, self.input_size + num_heads * kernel_size * 1)
+ self.weight_linear = Linear(
+ self.input_size, self.input_size + num_heads * kernel_size * 1
+ )
else:
- self.weight_linear = Linear(self.query_size, num_heads * kernel_size * 1, bias=bias)
+ self.weight_linear = Linear(
+ self.query_size, num_heads * kernel_size * 1, bias=bias
+ )
if conv_bias:
self.conv_bias = nn.Parameter(torch.Tensor(input_size))
else:
@@ -91,22 +128,27 @@ def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1,
@property
def in_proj(self):
- return self.weight_linear.out_features == self.input_size + self.num_heads * self.kernel_size
+ return (
+ self.weight_linear.out_features
+ == self.input_size + self.num_heads * self.kernel_size
+ )
def reset_parameters(self):
self.weight_linear.reset_parameters()
if self.conv_bias is not None:
- nn.init.constant_(self.conv_bias, 0.)
+ nn.init.constant_(self.conv_bias, 0.0)
def forward(self, x, incremental_state=None, query=None, unfold=None):
- '''Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C
+ """Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C
args:
x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
incremental_state: A dict to keep the state
unfold: unfold the input or not. If not, we use the matrix trick instead
query: use the specified query to predict the conv filters
- '''
- unfold = x.size(0) > 512 if unfold is None else unfold # use unfold mode as default for long sequence to save memory
+ """
+ unfold = (
+ x.size(0) > 512 if unfold is None else unfold
+ ) # use unfold mode as default for long sequence to save memory
unfold = unfold or (incremental_state is not None)
assert query is None or not self.in_proj
@@ -122,8 +164,8 @@ def forward(self, x, incremental_state=None, query=None, unfold=None):
return output
def _forward_unfolded(self, x, incremental_state, query):
- '''The conventional implementation of convolutions.
- Unfolding the input by having a window shifting to the right.'''
+ """The conventional implementation of convolutions.
+ Unfolding the input by having a window shifting to the right."""
T, B, C = x.size()
K, H = self.kernel_size, self.num_heads
R = C // H
@@ -132,9 +174,11 @@ def _forward_unfolded(self, x, incremental_state, query):
if self.in_proj:
proj = self.weight_linear(x)
x = proj.narrow(2, 0, self.input_size).contiguous()
- weight = proj.narrow(2, self.input_size, H*K).contiguous().view(T*B*H, -1)
+ weight = (
+ proj.narrow(2, self.input_size, H * K).contiguous().view(T * B * H, -1)
+ )
else:
- weight = self.weight_linear(query).view(T*B*H, -1)
+ weight = self.weight_linear(query).view(T * B * H, -1)
# renorm_padding is only implemented in _forward_expanded
assert not self.renorm_padding or incremental_state is not None
@@ -145,23 +189,25 @@ def _forward_unfolded(self, x, incremental_state, query):
input_buffer = x.new()
x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
if self.kernel_size > 1:
- self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:])
- x_unfold = x_unfold.view(T*B*H, R, -1)
+ self._set_input_buffer(
+ incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :]
+ )
+ x_unfold = x_unfold.view(T * B * H, R, -1)
else:
padding_l = self.padding_l
- if K > T and padding_l == K-1:
- weight = weight.narrow(1, K-T, T)
- K, padding_l = T, T-1
+ if K > T and padding_l == K - 1:
+ weight = weight.narrow(1, K - T, T)
+ K, padding_l = T, T - 1
# unfold the input: T x B x C --> T' x B x C x K
x_unfold = unfold1d(x, K, padding_l, 0)
- x_unfold = x_unfold.view(T*B*H, R, K)
+ x_unfold = x_unfold.view(T * B * H, R, K)
if self.weight_softmax and not self.renorm_padding:
weight = F.softmax(weight, dim=1)
weight = weight.narrow(1, 0, K)
if incremental_state is not None:
- weight = weight[:, -x_unfold.size(2):]
+ weight = weight[:, -x_unfold.size(2) :]
K = weight.size(1)
if self.weight_softmax and self.renorm_padding:
@@ -174,10 +220,10 @@ def _forward_unfolded(self, x, incremental_state, query):
return output
def _forward_expanded(self, x, incremental_stat, query):
- '''Turn the convolution filters into band matrices and do matrix multiplication.
+ """Turn the convolution filters into band matrices and do matrix multiplication.
This is faster when the sequence is short, but less memory efficient.
This is not used in the decoder during inference.
- '''
+ """
T, B, C = x.size()
K, H = self.kernel_size, self.num_heads
R = C // H
@@ -185,22 +231,26 @@ def _forward_expanded(self, x, incremental_stat, query):
if self.in_proj:
proj = self.weight_linear(x)
x = proj.narrow(2, 0, self.input_size).contiguous()
- weight = proj.narrow(2, self.input_size, H*K).contiguous().view(T*B*H, -1)
+ weight = (
+ proj.narrow(2, self.input_size, H * K).contiguous().view(T * B * H, -1)
+ )
else:
- weight = self.weight_linear(query).view(T*B*H, -1)
+ weight = self.weight_linear(query).view(T * B * H, -1)
if not self.renorm_padding:
if self.weight_softmax:
weight = F.softmax(weight, dim=1)
weight = self.weight_dropout_module(weight, inplace=False)
weight = weight.narrow(1, 0, K).contiguous()
- weight = weight.view(T, B*H, K).transpose(0, 1)
+ weight = weight.view(T, B * H, K).transpose(0, 1)
- x = x.view(T, B*H, R).transpose(0, 1)
+ x = x.view(T, B * H, R).transpose(0, 1)
if self.weight_softmax and self.renorm_padding:
# turn the convolution filters into band matrices
- weight_expanded = weight.new(B*H, T, T+K-1).fill_(float('-inf'))
- weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight)
+ weight_expanded = weight.new(B * H, T, T + K - 1).fill_(float("-inf"))
+ weight_expanded.as_strided(
+ (B * H, T, K), (T * (T + K - 1), T + K, 1)
+ ).copy_(weight)
weight_expanded = weight_expanded.narrow(2, self.padding_l, T)
# normalize the weight over valid positions like self-attention
weight_expanded = F.softmax(weight_expanded, dim=2)
@@ -208,12 +258,14 @@ def _forward_expanded(self, x, incremental_stat, query):
else:
P = self.padding_l
# For efficieny, we cut the kernel size and reduce the padding when the kernel is larger than the length
- if K > T and P == K-1:
- weight = weight.narrow(2, K-T, T)
- K, P = T, T-1
+ if K > T and P == K - 1:
+ weight = weight.narrow(2, K - T, T)
+ K, P = T, T - 1
# turn the convolution filters into band matrices
- weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False)
- weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight)
+ weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False)
+ weight_expanded.as_strided(
+ (B * H, T, K), (T * (T + K - 1), T + K, 1)
+ ).copy_(weight)
weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T
output = torch.bmm(weight_expanded, x)
output = output.transpose(0, 1).contiguous().view(T, B, C)
@@ -226,20 +278,27 @@ def reorder_incremental_state(self, incremental_state, new_order):
self._set_input_buffer(incremental_state, input_buffer)
def _get_input_buffer(self, incremental_state):
- return utils.get_incremental_state(self, incremental_state, 'input_buffer')
+ return utils.get_incremental_state(self, incremental_state, "input_buffer")
def _set_input_buffer(self, incremental_state, new_buffer):
- return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)
+ return utils.set_incremental_state(
+ self, incremental_state, "input_buffer", new_buffer
+ )
def extra_repr(self):
- s = '{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, conv_bias={}, renorm_padding={}, in_proj={}'.format(
- self.input_size, self.kernel_size, self.padding_l,
- self.num_heads, self.weight_softmax, self.conv_bias is not None, self.renorm_padding,
+ s = "{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, conv_bias={}, renorm_padding={}, in_proj={}".format(
+ self.input_size,
+ self.kernel_size,
+ self.padding_l,
+ self.num_heads,
+ self.weight_softmax,
+ self.conv_bias is not None,
+ self.renorm_padding,
self.in_proj,
)
if self.query_size != self.input_size:
- s += ', query_size={}'.format(self.query_size)
- if self.weight_dropout_module.p > 0.:
- s += ', weight_dropout={}'.format(self.weight_dropout_module.p)
+ s += ", query_size={}".format(self.query_size)
+ if self.weight_dropout_module.p > 0.0:
+ s += ", weight_dropout={}".format(self.weight_dropout_module.p)
return s
diff --git a/fairseq/modules/dynamic_crf_layer.py b/fairseq/modules/dynamic_crf_layer.py
index 6f5acf3772..8fcc6b8d26 100644
--- a/fairseq/modules/dynamic_crf_layer.py
+++ b/fairseq/modules/dynamic_crf_layer.py
@@ -27,16 +27,16 @@ def logsumexp(x, dim=1):
class DynamicCRF(nn.Module):
"""Dynamic CRF layer is used to approximate the traditional
- Conditional Random Fields (CRF)
- $P(y | x) = 1/Z(x) exp(sum_i s(y_i, x) + sum_i t(y_{i-1}, y_i, x))$
+ Conditional Random Fields (CRF)
+ $P(y | x) = 1/Z(x) exp(sum_i s(y_i, x) + sum_i t(y_{i-1}, y_i, x))$
- where in this function, we assume the emition scores (s) are given,
- and the transition score is a |V| x |V| matrix $M$
+ where in this function, we assume the emition scores (s) are given,
+ and the transition score is a |V| x |V| matrix $M$
- in the following two aspects:
- (1) it used a low-rank approximation for the transition matrix:
- $M = E_1 E_2^T$
- (2) it used a beam to estimate the normalizing factor Z(x)
+ in the following two aspects:
+ (1) it used a low-rank approximation for the transition matrix:
+ $M = E_1 E_2^T$
+ (2) it used a beam to estimate the normalizing factor Z(x)
"""
def __init__(self, num_embedding, low_rank=32, beam_size=64):
@@ -51,7 +51,8 @@ def __init__(self, num_embedding, low_rank=32, beam_size=64):
def extra_repr(self):
return "vocab_size={}, low_rank={}, beam_size={}".format(
- self.vocb, self.rank, self.beam)
+ self.vocb, self.rank, self.beam
+ )
def forward(self, emissions, targets, masks, beam=None):
"""
@@ -104,26 +105,27 @@ def _compute_normalizer(self, emissions, targets=None, masks=None, beam=None):
beam = beam if beam is not None else self.beam
batch_size, seq_len = emissions.size()[:2]
if targets is not None:
- _emissions = emissions.scatter(2, targets[:, :, None], np.float('inf'))
+ _emissions = emissions.scatter(2, targets[:, :, None], np.float("inf"))
beam_targets = _emissions.topk(beam, 2)[1]
beam_emission_scores = emissions.gather(2, beam_targets)
else:
beam_emission_scores, beam_targets = emissions.topk(beam, 2)
beam_transition_score1 = self.E1(beam_targets[:, :-1]) # B x (T-1) x K x D
- beam_transition_score2 = self.E2(beam_targets[:, 1:]) # B x (T-1) x K x D
+ beam_transition_score2 = self.E2(beam_targets[:, 1:]) # B x (T-1) x K x D
beam_transition_matrix = torch.bmm(
beam_transition_score1.view(-1, beam, self.rank),
- beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
+ beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2),
+ )
beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam)
# compute the normalizer in the log-space
score = beam_emission_scores[:, 0] # B x K
for i in range(1, seq_len):
- next_score = score[:, :, None] + beam_transition_matrix[:, i-1]
+ next_score = score[:, :, None] + beam_transition_matrix[:, i - 1]
next_score = logsumexp(next_score, dim=1) + beam_emission_scores[:, i]
if masks is not None:
- score = torch.where(masks[:, i:i+1], next_score, score)
+ score = torch.where(masks[:, i : i + 1], next_score, score)
else:
score = next_score
@@ -137,10 +139,11 @@ def _viterbi_decode(self, emissions, masks=None, beam=None):
batch_size, seq_len = emissions.size()[:2]
beam_emission_scores, beam_targets = emissions.topk(beam, 2)
beam_transition_score1 = self.E1(beam_targets[:, :-1]) # B x (T-1) x K x D
- beam_transition_score2 = self.E2(beam_targets[:, 1:]) # B x (T-1) x K x D
+ beam_transition_score2 = self.E2(beam_targets[:, 1:]) # B x (T-1) x K x D
beam_transition_matrix = torch.bmm(
beam_transition_score1.view(-1, beam, self.rank),
- beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
+ beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2),
+ )
beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam)
traj_tokens, traj_scores = [], []
@@ -148,17 +151,19 @@ def _viterbi_decode(self, emissions, masks=None, beam=None):
# compute the normalizer in the log-space
score = beam_emission_scores[:, 0] # B x K
- dummy = torch.arange(beam, device=score.device).expand(*score.size()).contiguous()
+ dummy = (
+ torch.arange(beam, device=score.device).expand(*score.size()).contiguous()
+ )
for i in range(1, seq_len):
traj_scores.append(score)
- _score = score[:, :, None] + beam_transition_matrix[:, i-1]
+ _score = score[:, :, None] + beam_transition_matrix[:, i - 1]
_score, _index = _score.max(dim=1)
_score = _score + beam_emission_scores[:, i]
if masks is not None:
- score = torch.where(masks[:, i: i+1], _score, score)
- index = torch.where(masks[:, i: i+1], _index, dummy)
+ score = torch.where(masks[:, i : i + 1], _score, score)
+ index = torch.where(masks[:, i : i + 1], _index, dummy)
else:
score, index = _score, _index
traj_tokens.append(index)
diff --git a/fairseq/modules/dynamicconv_layer/cuda_function_gen.py b/fairseq/modules/dynamicconv_layer/cuda_function_gen.py
index 926d6ca846..9304f99eb8 100644
--- a/fairseq/modules/dynamicconv_layer/cuda_function_gen.py
+++ b/fairseq/modules/dynamicconv_layer/cuda_function_gen.py
@@ -77,7 +77,7 @@ def gen_forward():
}
"""
- with open("dynamicconv_cuda_forward.cu", 'w') as forward:
+ with open("dynamicconv_cuda_forward.cu", "w") as forward:
forward.write(head)
forward.write(switch)
for k in kernels:
@@ -191,7 +191,7 @@ def gen_backward():
}
"""
- with open("dynamicconv_cuda_backward.cu", 'w') as backward:
+ with open("dynamicconv_cuda_backward.cu", "w") as backward:
backward.write(head)
for seq in seqs:
backward.write(sequence_if.format(seq=seq))
diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py
index 52cc1e8118..4a683d2690 100644
--- a/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py
+++ b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py
@@ -3,20 +3,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+import dynamicconv_cuda
import torch
-from torch import nn
-from torch.autograd import Function
import torch.nn.functional as F
-
-import dynamicconv_cuda
from fairseq import utils
-from fairseq.modules.unfold import unfold1d
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
+from fairseq.modules.unfold import unfold1d
+from torch import nn
+from torch.autograd import Function
class dynamicconvFunction(Function):
-
@staticmethod
def forward(ctx, x, weights, padding_l):
ctx.padding_l = padding_l
@@ -28,9 +26,8 @@ def forward(ctx, x, weights, padding_l):
@staticmethod
def backward(ctx, grad_output):
outputs = dynamicconv_cuda.backward(
- grad_output.contiguous(),
- ctx.padding_l,
- *ctx.saved_tensors)
+ grad_output.contiguous(), ctx.padding_l, *ctx.saved_tensors
+ )
grad_input, grad_weights = outputs
return grad_input, grad_weights, None
@@ -38,17 +35,17 @@ def backward(ctx, grad_output):
@with_incremental_state
class DynamicconvLayer(nn.Module):
def __init__(
- self,
- input_size,
- kernel_size=1,
- padding_l=None,
- weight_softmax=False,
- num_heads=1,
- weight_dropout=0.,
- bias=False,
- renorm_padding=False,
- conv_bias=False,
- query_size=None,
+ self,
+ input_size,
+ kernel_size=1,
+ padding_l=None,
+ weight_softmax=False,
+ num_heads=1,
+ weight_dropout=0.0,
+ bias=False,
+ renorm_padding=False,
+ conv_bias=False,
+ query_size=None,
):
super(DynamicconvLayer, self).__init__()
@@ -58,7 +55,9 @@ def __init__(
self.padding_l = padding_l
self.num_heads = num_heads
self.weight_softmax = weight_softmax
- self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__)
+ self.weight_dropout_module = FairseqDropout(
+ weight_dropout, module_name=self.__class__.__name__
+ )
self.renorm_padding = renorm_padding
self.bias = bias
@@ -72,8 +71,8 @@ def __init__(
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight_linear.weight)
if self.conv_bias is not None:
- nn.init.constant_(self.conv_bias, 0.)
- nn.init.constant_(self.weight_linaer.bias, 0.)
+ nn.init.constant_(self.conv_bias, 0.0)
+ nn.init.constant_(self.weight_linaer.bias, 0.0)
def forward(self, x, incremental_state=None, query=None, unfold=None):
@@ -83,7 +82,9 @@ def forward(self, x, incremental_state=None, query=None, unfold=None):
# during inference time, incremental BMM is faster
if incremental_state is not None:
- unfold = x.size(0) > 512 if unfold is None else unfold # use unfold mode as default for long sequence to save memory
+ unfold = (
+ x.size(0) > 512 if unfold is None else unfold
+ ) # use unfold mode as default for long sequence to save memory
unfold = unfold or (incremental_state is not None)
assert query is None
@@ -110,7 +111,9 @@ def forward(self, x, incremental_state=None, query=None, unfold=None):
weight = weight.permute(1, 2, 3, 0).contiguous()
self.filters = weight
x = x.permute(1, 2, 0).contiguous()
- output = dynamicconvFunction.apply(x, weight, self.padding_l).permute(2, 0, 1)
+ output = dynamicconvFunction.apply(x, weight, self.padding_l).permute(
+ 2, 0, 1
+ )
if self.conv_bias is not None:
output = output + self.conv_bias.view(1, 1, -1)
return output
@@ -122,20 +125,22 @@ def reorder_incremental_state(self, incremental_state, new_order):
self._set_input_buffer(incremental_state, input_buffer)
def _get_input_buffer(self, incremental_state):
- return utils.get_incremental_state(self, incremental_state, 'input_buffer')
+ return utils.get_incremental_state(self, incremental_state, "input_buffer")
def _set_input_buffer(self, incremental_state, new_buffer):
- return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)
+ return utils.set_incremental_state(
+ self, incremental_state, "input_buffer", new_buffer
+ )
def _forward_unfolded(self, x, incremental_state, query):
- '''The conventional implementation of convolutions.
- Unfolding the input by having a window shifting to the right.'''
+ """The conventional implementation of convolutions.
+ Unfolding the input by having a window shifting to the right."""
T, B, C = x.size()
K, H = self.kernel_size, self.num_heads
R = C // H
assert R * H == C == self.input_size
- weight = self.weight_linear(query).view(T*B*H, -1)
+ weight = self.weight_linear(query).view(T * B * H, -1)
# renorm_padding is only implemented in _forward_expanded
assert not self.renorm_padding or incremental_state is not None
@@ -146,23 +151,25 @@ def _forward_unfolded(self, x, incremental_state, query):
input_buffer = x.new()
x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
if self.kernel_size > 1:
- self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:])
- x_unfold = x_unfold.view(T*B*H, R, -1)
+ self._set_input_buffer(
+ incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :]
+ )
+ x_unfold = x_unfold.view(T * B * H, R, -1)
else:
padding_l = self.padding_l
- if K > T and padding_l == K-1:
- weight = weight.narrow(1, K-T, T)
- K, padding_l = T, T-1
+ if K > T and padding_l == K - 1:
+ weight = weight.narrow(1, K - T, T)
+ K, padding_l = T, T - 1
# unfold the input: T x B x C --> T' x B x C x K
x_unfold = unfold1d(x, K, padding_l, 0)
- x_unfold = x_unfold.view(T*B*H, R, K)
+ x_unfold = x_unfold.view(T * B * H, R, K)
if self.weight_softmax and not self.renorm_padding:
weight = F.softmax(weight, dim=1)
weight = weight.narrow(1, 0, K)
if incremental_state is not None:
- weight = weight[:, -x_unfold.size(2):]
+ weight = weight[:, -x_unfold.size(2) :]
K = weight.size(1)
if self.weight_softmax and self.renorm_padding:
@@ -175,28 +182,30 @@ def _forward_unfolded(self, x, incremental_state, query):
return output
def _forward_expanded(self, x, incremental_stat, query):
- '''Turn the convolution filters into band matrices and do matrix multiplication.
+ """Turn the convolution filters into band matrices and do matrix multiplication.
This is faster when the sequence is short, but less memory efficient.
This is not used in the decoder during inference.
- '''
+ """
T, B, C = x.size()
K, H = self.kernel_size, self.num_heads
R = C // H
assert R * H == C == self.input_size
- weight = self.weight_linear(query).view(T*B*H, -1)
+ weight = self.weight_linear(query).view(T * B * H, -1)
if not self.renorm_padding:
if self.weight_softmax:
weight = F.softmax(weight, dim=1)
weight = self.weight_dropout_module(weight, inplace=False)
weight = weight.narrow(1, 0, K).contiguous()
- weight = weight.view(T, B*H, K).transpose(0, 1)
+ weight = weight.view(T, B * H, K).transpose(0, 1)
- x = x.view(T, B*H, R).transpose(0, 1)
+ x = x.view(T, B * H, R).transpose(0, 1)
if self.weight_softmax and self.renorm_padding:
# turn the convolution filters into band matrices
- weight_expanded = weight.new(B*H, T, T+K-1).fill_(float('-inf'))
- weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight)
+ weight_expanded = weight.new(B * H, T, T + K - 1).fill_(float("-inf"))
+ weight_expanded.as_strided(
+ (B * H, T, K), (T * (T + K - 1), T + K, 1)
+ ).copy_(weight)
weight_expanded = weight_expanded.narrow(2, self.padding_l, T)
# normalize the weight over valid positions like self-attention
weight_expanded = F.softmax(weight_expanded, dim=2)
@@ -204,12 +213,14 @@ def _forward_expanded(self, x, incremental_stat, query):
else:
P = self.padding_l
# For efficieny, we cut the kernel size and reduce the padding when the kernel is larger than the length
- if K > T and P == K-1:
- weight = weight.narrow(2, K-T, T)
- K, P = T, T-1
+ if K > T and P == K - 1:
+ weight = weight.narrow(2, K - T, T)
+ K, P = T, T - 1
# turn the convolution filters into band matrices
- weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False)
- weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight)
+ weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False)
+ weight_expanded.as_strided(
+ (B * H, T, K), (T * (T + K - 1), T + K, 1)
+ ).copy_(weight)
weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T
output = torch.bmm(weight_expanded, x)
output = output.transpose(0, 1).contiguous().view(T, B, C)
diff --git a/fairseq/modules/dynamicconv_layer/setup.py b/fairseq/modules/dynamicconv_layer/setup.py
index 4d789c3283..6a21f7e2ee 100644
--- a/fairseq/modules/dynamicconv_layer/setup.py
+++ b/fairseq/modules/dynamicconv_layer/setup.py
@@ -5,19 +5,19 @@
# LICENSE file in the root directory of this source tree.
from setuptools import setup
-from torch.utils.cpp_extension import CUDAExtension, BuildExtension
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
setup(
- name='dynamicconv_layer',
+ name="dynamicconv_layer",
ext_modules=[
CUDAExtension(
- name='dynamicconv_cuda',
+ name="dynamicconv_cuda",
sources=[
- 'dynamicconv_cuda.cpp',
- 'dynamicconv_cuda_kernel.cu',
+ "dynamicconv_cuda.cpp",
+ "dynamicconv_cuda_kernel.cu",
],
),
],
- cmdclass={
- 'build_ext': BuildExtension
- })
+ cmdclass={"build_ext": BuildExtension},
+)
diff --git a/fairseq/modules/fairseq_dropout.py b/fairseq/modules/fairseq_dropout.py
index cbfacf477f..f070a804e6 100644
--- a/fairseq/modules/fairseq_dropout.py
+++ b/fairseq/modules/fairseq_dropout.py
@@ -14,7 +14,6 @@
class FairseqDropout(nn.Module):
-
def __init__(self, p, module_name=None):
super().__init__()
self.p = p
@@ -37,16 +36,16 @@ def make_generation_fast_(
if retain_dropout:
if retain_dropout_modules is not None and self.module_name is None:
logger.warning(
- 'Cannot enable dropout during inference for module {} '
- 'because module_name was not set'.format(name)
+ "Cannot enable dropout during inference for module {} "
+ "because module_name was not set".format(name)
)
elif (
retain_dropout_modules is None # if None, apply to all modules
or self.module_name in retain_dropout_modules
):
logger.info(
- 'Enabling dropout during inference for module: {}'.format(name)
+ "Enabling dropout during inference for module: {}".format(name)
)
self.apply_during_inference = True
else:
- logger.info('Disabling dropout for module: {}'.format(name))
+ logger.info("Disabling dropout for module: {}".format(name))
diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py
index 01ddd2298b..47657bb0ab 100644
--- a/fairseq/modules/gumbel_vector_quantizer.py
+++ b/fairseq/modules/gumbel_vector_quantizer.py
@@ -83,6 +83,7 @@ def set_num_updates(self, num_updates):
self.curr_temp = max(
self.max_temp * self.temp_decay ** num_updates, self.min_temp
)
+
def get_codebook_indices(self):
if self.codebook_indices is None:
from itertools import product
@@ -106,8 +107,8 @@ def codebook(self):
indices = self.get_codebook_indices()
return (
self.vars.squeeze(0)
- .index_select(0, indices)
- .view(self.num_vars ** self.groups, -1)
+ .index_select(0, indices)
+ .view(self.num_vars ** self.groups, -1)
)
def sample_from_codebook(self, b, n):
@@ -115,7 +116,7 @@ def sample_from_codebook(self, b, n):
indices = indices.view(-1, self.groups)
cb_size = indices.size(0)
assert (
- n < cb_size
+ n < cb_size
), f"sample size {n} is greater than size of codebook {cb_size}"
sample_idx = torch.randint(low=0, high=cb_size, size=(b * n,))
indices = indices[sample_idx]
diff --git a/fairseq/modules/kmeans_vector_quantizer.py b/fairseq/modules/kmeans_vector_quantizer.py
index be56e6081b..040db1e83e 100644
--- a/fairseq/modules/kmeans_vector_quantizer.py
+++ b/fairseq/modules/kmeans_vector_quantizer.py
@@ -5,7 +5,6 @@
import torch
import torch.nn as nn
-
from fairseq.modules import Fp32GroupNorm
@@ -13,17 +12,17 @@ class KmeansVectorQuantizer(nn.Module):
def __init__(
self, dim, num_vars, groups, combine_groups, vq_dim, time_first, gamma=0.25
):
- '''Vector quantization using straight pass-through estimator (i.e. kmeans)
-
- Args:
- dim: input dimension (channels)
- num_vars: number of quantized vectors per group
- groups: number of groups for vector quantization
- combine_groups: whether to use the vectors for all groups
- vq_dim: dimensionality of the resulting quantized vector
- time_first: if true, expect input in BxTxC format, otherwise in BxCxT
- gamma: commitment loss coefficient
- '''
+ """Vector quantization using straight pass-through estimator (i.e. kmeans)
+
+ Args:
+ dim: input dimension (channels)
+ num_vars: number of quantized vectors per group
+ groups: number of groups for vector quantization
+ combine_groups: whether to use the vectors for all groups
+ vq_dim: dimensionality of the resulting quantized vector
+ time_first: if true, expect input in BxTxC format, otherwise in BxCxT
+ gamma: commitment loss coefficient
+ """
super().__init__()
self.groups = groups
@@ -51,7 +50,7 @@ def __init__(
self.mse_mean = nn.MSELoss(reduction="mean")
def _pass_grad(self, x, y):
- """ Manually set gradient for backward pass.
+ """Manually set gradient for backward pass.
for y = f(x), ensure that during the backward pass,
dL/dy = dL/dx regardless of f(x).
Returns:
@@ -102,9 +101,9 @@ def forward(self, x, produce_targets=False):
x = self._pass_grad(ze, zq)
hard_x = (
- idx.new_zeros(bsz*tsz*self.groups, self.num_vars)
- .scatter_(-1, idx.view(-1, 1), 1.0)
- .view(bsz * tsz, self.groups, -1)
+ idx.new_zeros(bsz * tsz * self.groups, self.num_vars)
+ .scatter_(-1, idx.view(-1, 1), 1.0)
+ .view(bsz * tsz, self.groups, -1)
)
hard_probs = torch.mean(hard_x.float(), dim=0)
result["code_perplexity"] = torch.exp(
diff --git a/fairseq/modules/layer_norm.py b/fairseq/modules/layer_norm.py
index 7b1d241436..234609d9e2 100644
--- a/fairseq/modules/layer_norm.py
+++ b/fairseq/modules/layer_norm.py
@@ -22,6 +22,7 @@ def forward(self, x):
with torch.cuda.device(x.device):
return super().forward(x)
+
except ImportError:
has_fused_layernorm = False
diff --git a/fairseq/modules/lightconv_layer/cuda_function_gen.py b/fairseq/modules/lightconv_layer/cuda_function_gen.py
index afec9e19e7..a25433dd8e 100644
--- a/fairseq/modules/lightconv_layer/cuda_function_gen.py
+++ b/fairseq/modules/lightconv_layer/cuda_function_gen.py
@@ -91,7 +91,7 @@ def gen_forward():
}
"""
- with open("lightconv_cuda_forward.cu", 'w') as forward:
+ with open("lightconv_cuda_forward.cu", "w") as forward:
forward.write(head)
for seq in seqs:
forward.write(sequence_if.format(seq=seq))
@@ -261,7 +261,7 @@ def gen_backward():
thresh = [32, 32, 64, 128, 256, -1, -1, -1]
max_mem = [-1, -1, -1, -1, -1, 192, 96, 64]
- with open("lightconv_cuda_backward.cu", 'w') as backward:
+ with open("lightconv_cuda_backward.cu", "w") as backward:
backward.write(head)
for (k, t, mem) in zip(kernels, thresh, max_mem):
backward.write(case_k.format(k=k))
diff --git a/fairseq/modules/lightconv_layer/lightconv_layer.py b/fairseq/modules/lightconv_layer/lightconv_layer.py
index 9b4c9a951e..e7e597f474 100644
--- a/fairseq/modules/lightconv_layer/lightconv_layer.py
+++ b/fairseq/modules/lightconv_layer/lightconv_layer.py
@@ -3,19 +3,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+import lightconv_cuda
import torch
-from torch import nn
-from torch.autograd import Function
import torch.nn.functional as F
-
-import lightconv_cuda
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
+from torch import nn
+from torch.autograd import Function
class lightconvFunction(Function):
-
@staticmethod
def forward(ctx, x, weights, padding_l):
ctx.padding_l = padding_l
@@ -27,9 +25,8 @@ def forward(ctx, x, weights, padding_l):
@staticmethod
def backward(ctx, grad_output):
outputs = lightconv_cuda.backward(
- grad_output.contiguous(),
- ctx.padding_l,
- *ctx.saved_tensors)
+ grad_output.contiguous(), ctx.padding_l, *ctx.saved_tensors
+ )
grad_input, grad_weights = outputs
return grad_input, grad_weights, None
@@ -37,14 +34,14 @@ def backward(ctx, grad_output):
@with_incremental_state
class LightconvLayer(nn.Module):
def __init__(
- self,
- input_size,
- kernel_size=1,
- padding_l=None,
- weight_softmax=False,
- num_heads=1,
- weight_dropout=0.,
- bias=False,
+ self,
+ input_size,
+ kernel_size=1,
+ padding_l=None,
+ weight_softmax=False,
+ num_heads=1,
+ weight_dropout=0.0,
+ bias=False,
):
super(LightconvLayer, self).__init__()
self.input_size = input_size
@@ -52,7 +49,9 @@ def __init__(
self.padding_l = padding_l
self.num_heads = num_heads
self.weight_softmax = weight_softmax
- self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__)
+ self.weight_dropout_module = FairseqDropout(
+ weight_dropout, module_name=self.__class__.__name__
+ )
self.weight = nn.Parameter(torch.Tensor(num_heads, kernel_size))
if bias:
@@ -62,16 +61,16 @@ def __init__(
self.reset_parameters()
def upgrade_state_dict_named(self, state_dict, name):
- prefix = name + '.' if name != '' else ''
+ prefix = name + "." if name != "" else ""
for k, v in state_dict.items():
- if k.endswith(prefix + 'weight'):
+ if k.endswith(prefix + "weight"):
if v.dim() == 3 and v.size(1) == 1:
state_dict[k] = v.squeeze(1)
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
- nn.init.constant_(self.bias, 0.)
+ nn.init.constant_(self.bias, 0.0)
def forward(self, x, incremental_state=None):
@@ -85,18 +84,25 @@ def forward(self, x, incremental_state=None):
input_buffer = x.new()
x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
if self.kernel_size > 1:
- self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:])
- x_unfold = x_unfold.view(T*B*H, R, -1)
+ self._set_input_buffer(
+ incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :]
+ )
+ x_unfold = x_unfold.view(T * B * H, R, -1)
weight = self.weight
if self.weight_softmax:
weight = F.softmax(weight.float(), dim=1).type_as(weight)
- weight = weight[:, -x_unfold.size(2):]
+ weight = weight[:, -x_unfold.size(2) :]
K = weight.size(1)
- weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1)
+ weight = (
+ weight.view(1, H, K)
+ .expand(T * B, H, K)
+ .contiguous()
+ .view(T * B * H, K, 1)
+ )
weight = self.weight_dropout_module(weight)
output = torch.bmm(x_unfold, weight) # T*B*H x R x 1
@@ -120,10 +126,12 @@ def reorder_incremental_state(self, incremental_state, new_order):
self._set_input_buffer(incremental_state, input_buffer)
def _get_input_buffer(self, incremental_state):
- return utils.get_incremental_state(self, incremental_state, 'input_buffer')
+ return utils.get_incremental_state(self, incremental_state, "input_buffer")
def _set_input_buffer(self, incremental_state, new_buffer):
- return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)
+ return utils.set_incremental_state(
+ self, incremental_state, "input_buffer", new_buffer
+ )
def half(self):
return self._apply(lambda t: t.half() if t.is_floating_point() else t)
diff --git a/fairseq/modules/lightconv_layer/setup.py b/fairseq/modules/lightconv_layer/setup.py
index 0eac1df03c..052635be79 100644
--- a/fairseq/modules/lightconv_layer/setup.py
+++ b/fairseq/modules/lightconv_layer/setup.py
@@ -5,16 +5,19 @@
# LICENSE file in the root directory of this source tree.
from setuptools import setup
-from torch.utils.cpp_extension import CUDAExtension, BuildExtension
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
setup(
- name='lightconv_layer',
+ name="lightconv_layer",
ext_modules=[
- CUDAExtension('lightconv_cuda', [
- 'lightconv_cuda.cpp',
- 'lightconv_cuda_kernel.cu',
- ]),
+ CUDAExtension(
+ "lightconv_cuda",
+ [
+ "lightconv_cuda.cpp",
+ "lightconv_cuda_kernel.cu",
+ ],
+ ),
],
- cmdclass={
- 'build_ext': BuildExtension
- })
+ cmdclass={"build_ext": BuildExtension},
+)
diff --git a/fairseq/modules/lightweight_convolution.py b/fairseq/modules/lightweight_convolution.py
index 3d4cddb134..ec11a95079 100644
--- a/fairseq/modules/lightweight_convolution.py
+++ b/fairseq/modules/lightweight_convolution.py
@@ -6,32 +6,49 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from fairseq import utils
-from fairseq.modules.unfold import unfold1d
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
+from fairseq.modules.unfold import unfold1d
-def LightweightConv(input_size, kernel_size=1, padding_l=None, num_heads=1,
- weight_dropout=0., weight_softmax=False, bias=False):
+def LightweightConv(
+ input_size,
+ kernel_size=1,
+ padding_l=None,
+ num_heads=1,
+ weight_dropout=0.0,
+ weight_softmax=False,
+ bias=False,
+):
if torch.cuda.is_available():
try:
from fairseq.modules.lightconv_layer import LightconvLayer
- return LightconvLayer(input_size, kernel_size=kernel_size,
- padding_l=padding_l, num_heads=num_heads,
- weight_dropout=weight_dropout,
- weight_softmax=weight_softmax, bias=bias)
+
+ return LightconvLayer(
+ input_size,
+ kernel_size=kernel_size,
+ padding_l=padding_l,
+ num_heads=num_heads,
+ weight_dropout=weight_dropout,
+ weight_softmax=weight_softmax,
+ bias=bias,
+ )
except ImportError as e:
print(e)
- return LightweightConv1dTBC(input_size, kernel_size=kernel_size,
- padding_l=padding_l, num_heads=num_heads,
- weight_dropout=weight_dropout,
- weight_softmax=weight_softmax, bias=bias)
+ return LightweightConv1dTBC(
+ input_size,
+ kernel_size=kernel_size,
+ padding_l=padding_l,
+ num_heads=num_heads,
+ weight_dropout=weight_dropout,
+ weight_softmax=weight_softmax,
+ bias=bias,
+ )
class LightweightConv1d(nn.Module):
- '''Lightweight Convolution assuming the input is BxCxT
+ """Lightweight Convolution assuming the input is BxCxT
This is just an example that explains LightConv clearer than the TBC version.
We don't use this module in the model.
@@ -51,10 +68,18 @@ class LightweightConv1d(nn.Module):
weight: the learnable weights of the module of shape
`(num_heads, 1, kernel_size)`
bias: the learnable bias of the module of shape `(input_size)`
- '''
-
- def __init__(self, input_size, kernel_size=1, padding=0, num_heads=1,
- weight_softmax=False, bias=False, weight_dropout=0.):
+ """
+
+ def __init__(
+ self,
+ input_size,
+ kernel_size=1,
+ padding=0,
+ num_heads=1,
+ weight_softmax=False,
+ bias=False,
+ weight_dropout=0.0,
+ ):
super().__init__()
self.input_size = input_size
self.kernel_size = kernel_size
@@ -67,19 +92,21 @@ def __init__(self, input_size, kernel_size=1, padding=0, num_heads=1,
self.bias = nn.Parameter(torch.Tensor(input_size))
else:
self.bias = None
- self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__)
+ self.weight_dropout_module = FairseqDropout(
+ weight_dropout, module_name=self.__class__.__name__
+ )
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
- nn.init.constant_(self.bias, 0.)
+ nn.init.constant_(self.bias, 0.0)
def forward(self, input):
- '''
+ """
input size: B x C x T
output size: B x C x T
- '''
+ """
B, C, T = input.size()
H = self.num_heads
@@ -103,7 +130,7 @@ def forward(self, input):
@with_incremental_state
class LightweightConv1dTBC(nn.Module):
- '''Lightweight Convolution assuming the input is TxBxC
+ """Lightweight Convolution assuming the input is TxBxC
Args:
input_size: # of channels of the input
kernel_size: convolution channels
@@ -121,15 +148,26 @@ class LightweightConv1dTBC(nn.Module):
weight: the learnable weights of the module of shape
`(num_heads, 1, kernel_size)`
bias: the learnable bias of the module of shape `(input_size)`
- '''
- def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1,
- weight_dropout=0., weight_softmax=False, bias=False):
+ """
+
+ def __init__(
+ self,
+ input_size,
+ kernel_size=1,
+ padding_l=None,
+ num_heads=1,
+ weight_dropout=0.0,
+ weight_softmax=False,
+ bias=False,
+ ):
super().__init__()
self.input_size = input_size
self.kernel_size = kernel_size
self.padding_l = padding_l
self.num_heads = num_heads
- self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__)
+ self.weight_dropout_module = FairseqDropout(
+ weight_dropout, module_name=self.__class__.__name__
+ )
self.weight_softmax = weight_softmax
self.weight = nn.Parameter(torch.Tensor(num_heads, 1, kernel_size))
@@ -144,15 +182,15 @@ def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1,
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
- nn.init.constant_(self.bias, 0.)
+ nn.init.constant_(self.bias, 0.0)
def forward(self, x, incremental_state=None, unfold=False):
- '''Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C
+ """Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C
args:
x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
incremental_state: A dict to keep the state
unfold: unfold the input or not. If not, we use the matrix trick instead
- '''
+ """
unfold = unfold or (incremental_state is not None)
if unfold:
@@ -168,8 +206,8 @@ def prepare_for_onnx_export_(self):
self.onnx_trace = True
def _forward_unfolded(self, x, incremental_state):
- '''The conventional implementation of convolutions.
- Unfolding the input by having a window shifting to the right.'''
+ """The conventional implementation of convolutions.
+ Unfolding the input by having a window shifting to the right."""
T, B, C = x.size()
K, H = self.kernel_size, self.num_heads
R = C // H
@@ -182,21 +220,27 @@ def _forward_unfolded(self, x, incremental_state):
input_buffer = x.new()
x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
if self.kernel_size > 1:
- self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:])
- x_unfold = x_unfold.view(T*B*H, R, -1)
+ self._set_input_buffer(
+ incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :]
+ )
+ x_unfold = x_unfold.view(T * B * H, R, -1)
else:
# unfold the input: T x B x C --> T' x B x C x K
x_unfold = unfold1d(x, self.kernel_size, self.padding_l, 0)
- x_unfold = x_unfold.view(T*B*H, R, K)
+ x_unfold = x_unfold.view(T * B * H, R, K)
if self.weight_softmax:
- weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight)
+ weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(
+ weight
+ )
if incremental_state is not None:
- weight = weight[:, -x_unfold.size(2):]
+ weight = weight[:, -x_unfold.size(2) :]
K = weight.size(1)
- weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1)
+ weight = (
+ weight.view(1, H, K).expand(T * B, H, K).contiguous().view(T * B * H, K, 1)
+ )
weight = self.weight_dropout_module(weight)
output = torch.bmm(x_unfold, weight) # T*B*H x R x 1
@@ -204,10 +248,10 @@ def _forward_unfolded(self, x, incremental_state):
return output
def _forward_expanded(self, x, incremental_state):
- '''Turn the convolution filters into band matrices and do matrix multiplication.
+ """Turn the convolution filters into band matrices and do matrix multiplication.
This is faster when the sequence is short, but less memory efficient.
This is not used in the decoder during inference.
- '''
+ """
T, B, C = x.size()
K, H = self.kernel_size, self.num_heads
R = C // H
@@ -215,18 +259,22 @@ def _forward_expanded(self, x, incremental_state):
weight = self.weight.view(H, K)
if self.weight_softmax:
- weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight)
- weight = weight.view(1, H, K).expand(T*B, H, K).contiguous()
- weight = weight.view(T, B*H, K).transpose(0, 1)
+ weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(
+ weight
+ )
+ weight = weight.view(1, H, K).expand(T * B, H, K).contiguous()
+ weight = weight.view(T, B * H, K).transpose(0, 1)
- x = x.view(T, B*H, R).transpose(0, 1)
+ x = x.view(T, B * H, R).transpose(0, 1)
P = self.padding_l
- if K > T and P == K-1:
- weight = weight.narrow(2, K-T, T)
- K, P = T, T-1
+ if K > T and P == K - 1:
+ weight = weight.narrow(2, K - T, T)
+ K, P = T, T - 1
# turn the convolution filters into band matrices
- weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False)
- weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight)
+ weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False)
+ weight_expanded.as_strided((B * H, T, K), (T * (T + K - 1), T + K, 1)).copy_(
+ weight
+ )
weight_expanded = weight_expanded.narrow(2, P, T)
weight_expanded = self.weight_dropout_module(weight_expanded)
@@ -241,16 +289,22 @@ def reorder_incremental_state(self, incremental_state, new_order):
self._set_input_buffer(incremental_state, input_buffer)
def _get_input_buffer(self, incremental_state):
- return utils.get_incremental_state(self, incremental_state, 'input_buffer')
+ return utils.get_incremental_state(self, incremental_state, "input_buffer")
def _set_input_buffer(self, incremental_state, new_buffer):
- return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)
+ return utils.set_incremental_state(
+ self, incremental_state, "input_buffer", new_buffer
+ )
def extra_repr(self):
- s = '{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, bias={}'.format(
- self.input_size, self.kernel_size, self.padding_l,
- self.num_heads, self.weight_softmax, self.bias is not None
+ s = "{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, bias={}".format(
+ self.input_size,
+ self.kernel_size,
+ self.padding_l,
+ self.num_heads,
+ self.weight_softmax,
+ self.bias is not None,
)
- if self.weight_dropout_module.p > 0.:
- s += ', weight_dropout={}'.format(self.weight_dropout_module.p)
+ if self.weight_dropout_module.p > 0.0:
+ s += ", weight_dropout={}".format(self.weight_dropout_module.p)
return s
diff --git a/fairseq/modules/linearized_convolution.py b/fairseq/modules/linearized_convolution.py
index 3dd4b151c1..09a8f201c0 100644
--- a/fairseq/modules/linearized_convolution.py
+++ b/fairseq/modules/linearized_convolution.py
@@ -5,11 +5,11 @@
import torch
import torch.nn.functional as F
-
from fairseq import utils
-from .conv_tbc import ConvTBC
from fairseq.incremental_decoding_utils import with_incremental_state
+from .conv_tbc import ConvTBC
+
@with_incremental_state
class LinearizedConvolution(ConvTBC):
@@ -26,17 +26,17 @@ def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
self._linearized_weight = None
self.register_backward_hook(self._clear_linearized_weight)
- def state_dict(self, destination=None, prefix='', keep_vars=False):
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
state = ConvTBC.state_dict(self, destination, prefix, keep_vars=keep_vars)
# don't store redundant _linearized_weight in checkpoints
- if prefix + '_linearized_weight' in state:
- del state[prefix + '_linearized_weight']
+ if prefix + "_linearized_weight" in state:
+ del state[prefix + "_linearized_weight"]
return state
def upgrade_state_dict_named(self, state_dict, name):
- prefix = name + '.' if name != '' else ''
- if prefix + '_linearized_weight' in state_dict:
- del state_dict[prefix + '_linearized_weight']
+ prefix = name + "." if name != "" else ""
+ if prefix + "_linearized_weight" in state_dict:
+ del state_dict[prefix + "_linearized_weight"]
def forward(self, input, incremental_state=None):
"""
@@ -52,7 +52,7 @@ def forward(self, input, incremental_state=None):
output = super().forward(input)
if self.kernel_size[0] > 1 and self.padding[0] > 0:
# remove future timesteps added by padding
- output = output[:-self.padding[0], :, :]
+ output = output[: -self.padding[0], :, :]
return output
# reshape weight
@@ -83,17 +83,21 @@ def reorder_incremental_state(self, incremental_state, new_order):
self._set_input_buffer(incremental_state, input_buffer)
def _get_input_buffer(self, incremental_state):
- return utils.get_incremental_state(self, incremental_state, 'input_buffer')
+ return utils.get_incremental_state(self, incremental_state, "input_buffer")
def _set_input_buffer(self, incremental_state, new_buffer):
- return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)
+ return utils.set_incremental_state(
+ self, incremental_state, "input_buffer", new_buffer
+ )
def _get_linearized_weight(self):
if self._linearized_weight is None:
kw = self.kernel_size[0]
weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
assert weight.size() == (self.out_channels, kw, self.in_channels)
- self._linearized_weight = torch.nn.Parameter(weight.view(self.out_channels, -1))
+ self._linearized_weight = torch.nn.Parameter(
+ weight.view(self.out_channels, -1)
+ )
return self._linearized_weight
def _clear_linearized_weight(self, *args):
diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py
index 90b635af2b..99f95deb5f 100644
--- a/fairseq/modules/multihead_attention.py
+++ b/fairseq/modules/multihead_attention.py
@@ -8,13 +8,12 @@
import torch
import torch.nn.functional as F
-from torch import Tensor, nn
-from torch.nn import Parameter
-
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
+from torch import Tensor, nn
+from torch.nn import Parameter
@with_incremental_state
@@ -63,11 +62,19 @@ def __init__(
"Self-attention requires query, key and " "value to be of the same size"
)
- self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
- self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
- self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
+ self.k_proj = quant_noise(
+ nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
+ )
+ self.v_proj = quant_noise(
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
+ )
+ self.q_proj = quant_noise(
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
+ )
- self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
+ self.out_proj = quant_noise(
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
+ )
if add_bias_kv:
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
@@ -102,7 +109,7 @@ def reset_parameters(self):
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
- nn.init.constant_(self.out_proj.bias, 0.)
+ nn.init.constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
@@ -333,11 +340,11 @@ def forward(
if not self.tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
- float("-inf")
+ float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
- attn_weights = attn_weights.masked_fill(key_padding_mask, float('-inf'))
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -411,7 +418,9 @@ def _append_prev_key_padding_mask(
@torch.jit.export
def reorder_incremental_state(
- self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
+ self,
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+ new_order: Tensor,
):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer = self._get_input_buffer(incremental_state)
@@ -419,7 +428,9 @@ def reorder_incremental_state(
for k in input_buffer.keys():
input_buffer_k = input_buffer[k]
if input_buffer_k is not None:
- if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(0):
+ if self.encoder_decoder_attention and input_buffer_k.size(
+ 0
+ ) == new_order.size(0):
break
input_buffer[k] = input_buffer_k.index_select(0, new_order)
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
diff --git a/fairseq/modules/positional_embedding.py b/fairseq/modules/positional_embedding.py
index 511460fcb7..8e94e35edb 100644
--- a/fairseq/modules/positional_embedding.py
+++ b/fairseq/modules/positional_embedding.py
@@ -4,15 +4,16 @@
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
+
from .learned_positional_embedding import LearnedPositionalEmbedding
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
def PositionalEmbedding(
- num_embeddings: int,
- embedding_dim: int,
- padding_idx: int,
- learned: bool = False,
+ num_embeddings: int,
+ embedding_dim: int,
+ padding_idx: int,
+ learned: bool = False,
):
if learned:
# if padding_idx is specified then offset the embedding ids by
@@ -27,6 +28,8 @@ def PositionalEmbedding(
nn.init.constant_(m.weight[padding_idx], 0)
else:
m = SinusoidalPositionalEmbedding(
- embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1,
+ embedding_dim,
+ padding_idx,
+ init_size=num_embeddings + padding_idx + 1,
)
return m
diff --git a/fairseq/modules/quant_noise.py b/fairseq/modules/quant_noise.py
index b38ea263d3..d777dfbb6c 100644
--- a/fairseq/modules/quant_noise.py
+++ b/fairseq/modules/quant_noise.py
@@ -39,13 +39,17 @@ def quant_noise(module, p, block_size):
# 2D matrix
if not is_conv:
- assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"
+ assert (
+ module.weight.size(1) % block_size == 0
+ ), "Input features must be a multiple of block sizes"
# 4D matrix
else:
# 1x1 convolutions
if module.kernel_size == (1, 1):
- assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
+ assert (
+ module.in_channels % block_size == 0
+ ), "Input channels must be a multiple of block sizes"
# regular convolutions
else:
k = module.kernel_size[0] * module.kernel_size[1]
@@ -61,7 +65,9 @@ def _forward_pre_hook(mod, input):
out_features = weight.size(0)
# split weight matrix into blocks and randomly drop selected blocks
- mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
+ mask = torch.zeros(
+ in_features // block_size * out_features, device=weight.device
+ )
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
@@ -73,16 +79,27 @@ def _forward_pre_hook(mod, input):
# split weight matrix into blocks and randomly drop selected blocks
if mod.kernel_size == (1, 1):
- mask = torch.zeros(int(in_channels // block_size * out_channels), device=weight.device)
+ mask = torch.zeros(
+ int(in_channels // block_size * out_channels),
+ device=weight.device,
+ )
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
else:
- mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
+ mask = torch.zeros(
+ weight.size(0), weight.size(1), device=weight.device
+ )
mask.bernoulli_(p)
- mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+ mask = (
+ mask.unsqueeze(2)
+ .unsqueeze(3)
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+ )
# scale weights and apply mask
- mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript
+ mask = mask.to(
+ torch.bool
+ ) # x.bool() is not currently supported in TorchScript
s = 1 / (1 - p)
mod.weight.data = s * weight.masked_fill(mask, 0)
diff --git a/fairseq/modules/quantization/pq/em.py b/fairseq/modules/quantization/pq/em.py
index 420d8afda2..6f15c3e46b 100644
--- a/fairseq/modules/quantization/pq/em.py
+++ b/fairseq/modules/quantization/pq/em.py
@@ -3,9 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+import logging
import os
import random
-import logging
from collections import Counter
import torch
diff --git a/fairseq/modules/quantization/pq/modules/__init__.py b/fairseq/modules/quantization/pq/modules/__init__.py
index f52f6f37a6..b67c8e8ad6 100644
--- a/fairseq/modules/quantization/pq/modules/__init__.py
+++ b/fairseq/modules/quantization/pq/modules/__init__.py
@@ -4,5 +4,5 @@
# LICENSE file in the root directory of this source tree.
from .qconv import PQConv2d # NOQA
-from .qlinear import PQLinear # NOQA
from .qemb import PQEmbedding # NOQA
+from .qlinear import PQLinear # NOQA
diff --git a/fairseq/modules/quantization/pq/modules/qemb.py b/fairseq/modules/quantization/pq/modules/qemb.py
index 98d856d04e..3a74ad3c4c 100644
--- a/fairseq/modules/quantization/pq/modules/qemb.py
+++ b/fairseq/modules/quantization/pq/modules/qemb.py
@@ -27,9 +27,19 @@ class PQEmbedding(nn.Module):
the non-quantized nn.Embedding module for a standard training loop.
"""
- def __init__(self, centroids, assignments, num_embeddings, embedding_dim,
- padding_idx=None, max_norm=None, norm_type=2.,
- scale_grad_by_freq=False, sparse=False, _weight=None):
+ def __init__(
+ self,
+ centroids,
+ assignments,
+ num_embeddings,
+ embedding_dim,
+ padding_idx=None,
+ max_norm=None,
+ norm_type=2.0,
+ scale_grad_by_freq=False,
+ sparse=False,
+ _weight=None,
+ ):
super(PQEmbedding, self).__init__()
self.block_size = centroids.size(1)
self.n_centroids = centroids.size(0)
@@ -37,9 +47,13 @@ def __init__(self, centroids, assignments, num_embeddings, embedding_dim,
self.embedding_dim = embedding_dim
if padding_idx is not None:
if padding_idx > 0:
- assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
+ assert (
+ padding_idx < self.num_embeddings
+ ), "Padding_idx must be within num_embeddings"
elif padding_idx < 0:
- assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
+ assert (
+ padding_idx >= -self.num_embeddings
+ ), "Padding_idx must be within num_embeddings"
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
self.max_norm = max_norm
@@ -67,21 +81,27 @@ def weight(self):
def forward(self, input):
return F.embedding(
- input, self.weight, self.padding_idx, self.max_norm,
- self.norm_type, self.scale_grad_by_freq, self.sparse)
+ input,
+ self.weight,
+ self.padding_idx,
+ self.max_norm,
+ self.norm_type,
+ self.scale_grad_by_freq,
+ self.sparse,
+ )
def extra_repr(self):
- s = '{num_embeddings}, {embedding_dim}'
+ s = "{num_embeddings}, {embedding_dim}"
if self.padding_idx is not None:
- s += ', padding_idx={padding_idx}'
+ s += ", padding_idx={padding_idx}"
if self.max_norm is not None:
- s += ', max_norm={max_norm}'
+ s += ", max_norm={max_norm}"
if self.norm_type != 2:
- s += ', norm_type={norm_type}'
+ s += ", norm_type={norm_type}"
if self.scale_grad_by_freq is not False:
- s += ', scale_grad_by_freq={scale_grad_by_freq}'
+ s += ", scale_grad_by_freq={scale_grad_by_freq}"
if self.sparse is not False:
- s += ', sparse=True'
- s += ', n_centroids={n_centroids}, block_size={block_size}'
+ s += ", sparse=True"
+ s += ", n_centroids={n_centroids}, block_size={block_size}"
return s.format(**self.__dict__)
diff --git a/fairseq/modules/quantization/pq/utils.py b/fairseq/modules/quantization/pq/utils.py
index 57aaa1b7a3..03b15e4b1b 100644
--- a/fairseq/modules/quantization/pq/utils.py
+++ b/fairseq/modules/quantization/pq/utils.py
@@ -8,10 +8,10 @@
from operator import attrgetter, itemgetter
import numpy as np
-import torch.nn as nn
import torch.distributed as dist
+import torch.nn as nn
-from .modules import PQConv2d, PQLinear, PQEmbedding
+from .modules import PQConv2d, PQEmbedding, PQLinear
from .pq import PQ
@@ -63,7 +63,9 @@ def quantize_model_(
for layer in quantized_layers:
# book-keeping
- is_master_process = (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0)
+ is_master_process = (not dist.is_initialized()) or (
+ dist.is_initialized() and dist.get_rank() == 0
+ )
verbose = verbose and is_master_process
# get block size and centroids
@@ -71,11 +73,13 @@ def quantize_model_(
block_size = get_param(module, layer, block_sizes_config)
n_centroids = get_param(module, layer, n_centroids_config)
if verbose:
- logging.info(f"Quantizing layer {layer} with block size {block_size} and {n_centroids} centroids")
+ logging.info(
+ f"Quantizing layer {layer} with block size {block_size} and {n_centroids} centroids"
+ )
# quantize layer
weight = module.weight.data.clone()
- is_bias = 'bias' in [x[0] for x in module.named_parameters()]
+ is_bias = "bias" in [x[0] for x in module.named_parameters()]
bias = module.bias.data.clone() if is_bias else None
quantizer = PQ(
weight,
@@ -238,9 +242,7 @@ def get_param(module, layer_name, param_config):
if "*" in params:
feature_value = "*"
else:
- raise KeyError(
- f"name={layer_name} not in config for {module}"
- )
+ raise KeyError(f"name={layer_name} not in config for {module}")
else:
feature_value = feature_values[0]
diff --git a/fairseq/modules/quantization/scalar/modules/__init__.py b/fairseq/modules/quantization/scalar/modules/__init__.py
index ead4669611..8031d9cdb2 100644
--- a/fairseq/modules/quantization/scalar/modules/__init__.py
+++ b/fairseq/modules/quantization/scalar/modules/__init__.py
@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from .qact import ActivationQuantizer # NOQA
from .qconv import IntConv2d # NOQA
-from .qlinear import IntLinear # NOQA
from .qemb import IntEmbedding # NOQA
-from .qact import ActivationQuantizer # NOQA
+from .qlinear import IntLinear # NOQA
diff --git a/fairseq/modules/quantization/scalar/modules/qact.py b/fairseq/modules/quantization/scalar/modules/qact.py
index a9f79011c1..c5dd1d6336 100644
--- a/fairseq/modules/quantization/scalar/modules/qact.py
+++ b/fairseq/modules/quantization/scalar/modules/qact.py
@@ -32,8 +32,16 @@ class ActivationQuantizer:
- The activations are hard-clamped in [-clamp_threshold, clamp_threshold]
to prevent overflow during the backward pass
"""
- def __init__(self, module, p=1, update_step=1000, bits=8,
- method="histogram", clamp_threshold=5):
+
+ def __init__(
+ self,
+ module,
+ p=1,
+ update_step=1000,
+ bits=8,
+ method="histogram",
+ clamp_threshold=5,
+ ):
self.module = module
self.p = p
self.update_step = update_step
@@ -72,7 +80,7 @@ def quantize_hook(module, x, y):
noise = (y_q - y).masked_fill(mask.bool(), 0)
# using straight-through estimator (STE)
- clamp_low = - self.scale * self.zero_point
+ clamp_low = -self.scale * self.zero_point
clamp_high = self.scale * (2 ** self.bits - 1 - self.zero_point)
return torch.clamp(y, clamp_low.item(), clamp_high.item()) + noise.detach()
diff --git a/fairseq/modules/quantization/scalar/modules/qconv.py b/fairseq/modules/quantization/scalar/modules/qconv.py
index d718c9b90d..83788c6f71 100644
--- a/fairseq/modules/quantization/scalar/modules/qconv.py
+++ b/fairseq/modules/quantization/scalar/modules/qconv.py
@@ -118,9 +118,12 @@ def forward(self, input):
noise = (weight_quantized - self.weight).masked_fill(mask.bool(), 0)
# using straight-through estimator (STE)
- clamp_low = - self.scale * self.zero_point
+ clamp_low = -self.scale * self.zero_point
clamp_high = self.scale * (2 ** self.bits - 1 - self.zero_point)
- weight = torch.clamp(self.weight, clamp_low.item(), clamp_high.item()) + noise.detach()
+ weight = (
+ torch.clamp(self.weight, clamp_low.item(), clamp_high.item())
+ + noise.detach()
+ )
# return output
output = self._conv_forward(input, weight)
diff --git a/fairseq/modules/quantization/scalar/modules/qemb.py b/fairseq/modules/quantization/scalar/modules/qemb.py
index 835b2782a7..d6cf06e587 100644
--- a/fairseq/modules/quantization/scalar/modules/qemb.py
+++ b/fairseq/modules/quantization/scalar/modules/qemb.py
@@ -37,7 +37,7 @@ def __init__(
embedding_dim,
padding_idx=None,
max_norm=None,
- norm_type=2.,
+ norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
@@ -51,9 +51,13 @@ def __init__(
self.embedding_dim = embedding_dim
if padding_idx is not None:
if padding_idx > 0:
- assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
+ assert (
+ padding_idx < self.num_embeddings
+ ), "Padding_idx must be within num_embeddings"
elif padding_idx < 0:
- assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
+ assert (
+ padding_idx >= -self.num_embeddings
+ ), "Padding_idx must be within num_embeddings"
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
self.max_norm = max_norm
@@ -63,8 +67,10 @@ def __init__(
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
self.reset_parameters()
else:
- assert list(_weight.shape) == [num_embeddings, embedding_dim], \
- 'Shape of weight does not match num_embeddings and embedding_dim'
+ assert list(_weight.shape) == [
+ num_embeddings,
+ embedding_dim,
+ ], "Shape of weight does not match num_embeddings and embedding_dim"
self.weight = nn.Parameter(_weight)
self.sparse = sparse
@@ -106,27 +112,36 @@ def forward(self, input):
noise = (weight_quantized - self.weight).masked_fill(mask.bool(), 0)
# using straight-through estimator (STE)
- clamp_low = - self.scale * self.zero_point
+ clamp_low = -self.scale * self.zero_point
clamp_high = self.scale * (2 ** self.bits - 1 - self.zero_point)
- weight = torch.clamp(self.weight, clamp_low.item(), clamp_high.item()) + noise.detach()
+ weight = (
+ torch.clamp(self.weight, clamp_low.item(), clamp_high.item())
+ + noise.detach()
+ )
# return output
output = F.embedding(
- input, weight, self.padding_idx, self.max_norm,
- self.norm_type, self.scale_grad_by_freq, self.sparse)
+ input,
+ weight,
+ self.padding_idx,
+ self.max_norm,
+ self.norm_type,
+ self.scale_grad_by_freq,
+ self.sparse,
+ )
return output
def extra_repr(self):
- s = '{num_embeddings}, {embedding_dim}'
+ s = "{num_embeddings}, {embedding_dim}"
if self.padding_idx is not None:
- s += ', padding_idx={padding_idx}'
+ s += ", padding_idx={padding_idx}"
if self.max_norm is not None:
- s += ', max_norm={max_norm}'
+ s += ", max_norm={max_norm}"
if self.norm_type != 2:
- s += ', norm_type={norm_type}'
+ s += ", norm_type={norm_type}"
if self.scale_grad_by_freq is not False:
- s += ', scale_grad_by_freq={scale_grad_by_freq}'
+ s += ", scale_grad_by_freq={scale_grad_by_freq}"
if self.sparse is not False:
- s += ', sparse=True'
- s += 'quant_noise={p}, bits={bits}, method={method}'
+ s += ", sparse=True"
+ s += "quant_noise={p}, bits={bits}, method={method}"
return s.format(**self.__dict__)
diff --git a/fairseq/modules/quantization/scalar/modules/qlinear.py b/fairseq/modules/quantization/scalar/modules/qlinear.py
index 2d4b27dc6c..9db1559386 100644
--- a/fairseq/modules/quantization/scalar/modules/qlinear.py
+++ b/fairseq/modules/quantization/scalar/modules/qlinear.py
@@ -91,9 +91,12 @@ def forward(self, input):
noise = (weight_quantized - self.weight).masked_fill(mask.bool(), 0)
# using straight-through estimator (STE)
- clamp_low = - self.scale * self.zero_point
+ clamp_low = -self.scale * self.zero_point
clamp_high = self.scale * (2 ** self.bits - 1 - self.zero_point)
- weight = torch.clamp(self.weight, clamp_low.item(), clamp_high.item()) + noise.detach()
+ weight = (
+ torch.clamp(self.weight, clamp_low.item(), clamp_high.item())
+ + noise.detach()
+ )
# return output
output = F.linear(input, weight, self.bias)
diff --git a/fairseq/modules/quantization/scalar/ops.py b/fairseq/modules/quantization/scalar/ops.py
index 90bc737cc8..2a855159be 100644
--- a/fairseq/modules/quantization/scalar/ops.py
+++ b/fairseq/modules/quantization/scalar/ops.py
@@ -12,7 +12,9 @@ def emulate_int(w, bits, method, scale=None, zero_point=None):
def quantize(w, scale, zero_point):
- return (torch.clamp(torch.round(w / scale + zero_point), 0, 255) - zero_point) * scale
+ return (
+ torch.clamp(torch.round(w / scale + zero_point), 0, 255) - zero_point
+ ) * scale
def emulate_int8_histogram(w, scale=None, zero_point=None):
diff --git a/fairseq/modules/quantization/scalar/utils.py b/fairseq/modules/quantization/scalar/utils.py
index 4071f7b80a..32cf616568 100644
--- a/fairseq/modules/quantization/scalar/utils.py
+++ b/fairseq/modules/quantization/scalar/utils.py
@@ -6,11 +6,11 @@
import logging
from operator import attrgetter
-import torch.nn as nn
import torch.distributed as dist
+import torch.nn as nn
-from ..pq.utils import get_layers, attrsetter
-from .modules import IntConv2d, IntLinear, IntEmbedding, ActivationQuantizer
+from ..pq.utils import attrsetter, get_layers
+from .modules import ActivationQuantizer, IntConv2d, IntEmbedding, IntLinear
MAPPING = {nn.Linear: IntLinear, nn.Embedding: IntEmbedding, nn.Conv2d: IntConv2d}
@@ -34,15 +34,25 @@ def quantize_model_(model, p=0.2, bits=8, update_step=3000):
for layer in quantized_layers:
# book-keeping
- is_master_process = (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0)
+ is_master_process = (not dist.is_initialized()) or (
+ dist.is_initialized() and dist.get_rank() == 0
+ )
# recover module
module = attrgetter(layer)(model)
if is_master_process:
- logging.info(f"Quantizing layer {layer} with bits={bits} and QuantNoise={p}")
+ logging.info(
+ f"Quantizing layer {layer} with bits={bits} and QuantNoise={p}"
+ )
# quantization params
- q_params = {"p": p, "update_step": update_step, "bits": bits, "method": "histogram", "counter": 0}
+ q_params = {
+ "p": p,
+ "update_step": update_step,
+ "bits": bits,
+ "method": "histogram",
+ "counter": 0,
+ }
# instantiate the quantized counterpart
if isinstance(module, tuple(MAPPING.keys())):
diff --git a/fairseq/modules/sparse_multihead_attention.py b/fairseq/modules/sparse_multihead_attention.py
index 61430195c2..3cbd9d6785 100644
--- a/fairseq/modules/sparse_multihead_attention.py
+++ b/fairseq/modules/sparse_multihead_attention.py
@@ -4,12 +4,14 @@
# LICENSE file in the root directory of this source tree.
import math
+
import torch
+
from .multihead_attention import MultiheadAttention
class SparseMultiheadAttention(MultiheadAttention):
- """ Sparse Multi-Headed Attention.
+ """Sparse Multi-Headed Attention.
"Generating Long Sequences with Sparse Transformers". Implements
fixed factorized self attention, where l=stride and c=expressivity.
@@ -19,19 +21,40 @@ class SparseMultiheadAttention(MultiheadAttention):
as in the paper.
"""
- def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
- add_bias_kv=False, add_zero_attn=False, self_attention=False,
- encoder_decoder_attention=False, stride=32, expressivity=8, is_bidirectional=True):
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ kdim=None,
+ vdim=None,
+ dropout=0.0,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False,
+ self_attention=False,
+ encoder_decoder_attention=False,
+ stride=32,
+ expressivity=8,
+ is_bidirectional=True,
+ ):
super().__init__(
- embed_dim, num_heads, kdim, vdim, dropout, bias, add_bias_kv,
- add_zero_attn, self_attention, encoder_decoder_attention
+ embed_dim,
+ num_heads,
+ kdim,
+ vdim,
+ dropout,
+ bias,
+ add_bias_kv,
+ add_zero_attn,
+ self_attention,
+ encoder_decoder_attention,
)
self.is_bidirectional = is_bidirectional
self.stride = stride
self.expressivity = expressivity
- assert(self.stride > 0 and self.stride >= self.expressivity)
+ assert self.stride > 0 and self.stride >= self.expressivity
# Used for Ai(2) calculations - beginning of [l-c, l] range
def compute_checkpoint(self, word_index):
@@ -40,7 +63,8 @@ def compute_checkpoint(self, word_index):
else:
checkpoint_index = (
math.floor(word_index / self.stride) * self.stride
- + self.stride - self.expressivity
+ + self.stride
+ - self.expressivity
)
return checkpoint_index
@@ -48,12 +72,15 @@ def compute_checkpoint(self, word_index):
def compute_subset_summaries(self, absolute_max):
checkpoint_index = self.compute_checkpoint(0)
subset_two = set()
- while checkpoint_index <= absolute_max-1:
- summary = set(range(checkpoint_index, min(
- checkpoint_index+self.expressivity+1, absolute_max)
- ))
+ while checkpoint_index <= absolute_max - 1:
+ summary = set(
+ range(
+ checkpoint_index,
+ min(checkpoint_index + self.expressivity + 1, absolute_max),
+ )
+ )
subset_two = subset_two.union(summary)
- checkpoint_index = self.compute_checkpoint(checkpoint_index+self.stride)
+ checkpoint_index = self.compute_checkpoint(checkpoint_index + self.stride)
return subset_two
# Sparse Transformer Fixed Attention Pattern: https://arxiv.org/pdf/1904.10509.pdf
@@ -65,12 +92,19 @@ def compute_fixed_attention_subset(self, word_index, tgt_len):
absolute_max = tgt_len
# Subset 1 - whole window
- rounded_index = math.floor((word_index + self.stride) / self.stride) * self.stride
+ rounded_index = (
+ math.floor((word_index + self.stride) / self.stride) * self.stride
+ )
if word_index % self.stride == 0 and word_index != 0:
- subset_one = set(range(word_index-self.stride, min(absolute_max, word_index+1)))
+ subset_one = set(
+ range(word_index - self.stride, min(absolute_max, word_index + 1))
+ )
else:
- subset_one = set(range(max(0, rounded_index - self.stride), min(
- absolute_max, rounded_index+1))
+ subset_one = set(
+ range(
+ max(0, rounded_index - self.stride),
+ min(absolute_max, rounded_index + 1),
+ )
)
# Subset 2 - summary per window
@@ -83,8 +117,8 @@ def compute_fixed_attention_subset(self, word_index, tgt_len):
# Compute sparse mask - if bidirectional, can pre-compute and store
def buffered_sparse_mask(self, tensor, tgt_len, src_len):
- assert(tgt_len > self.stride)
- sparse_mask = torch.empty((tgt_len, src_len)).float().fill_(float('-inf'))
+ assert tgt_len > self.stride
+ sparse_mask = torch.empty((tgt_len, src_len)).float().fill_(float("-inf"))
# If bidirectional, subset 2 is the same for every index
subset_summaries = set()
@@ -100,5 +134,7 @@ def buffered_sparse_mask(self, tensor, tgt_len, src_len):
def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
sparse_mask = self.buffered_sparse_mask(attn_weights, tgt_len, src_len)
- sparse_mask = sparse_mask.unsqueeze(0).expand(bsz * self.num_heads, tgt_len, src_len)
+ sparse_mask = sparse_mask.unsqueeze(0).expand(
+ bsz * self.num_heads, tgt_len, src_len
+ )
attn_weights += sparse_mask
diff --git a/fairseq/modules/sparse_transformer_sentence_encoder.py b/fairseq/modules/sparse_transformer_sentence_encoder.py
index 3d50d5a882..f41ec09327 100644
--- a/fairseq/modules/sparse_transformer_sentence_encoder.py
+++ b/fairseq/modules/sparse_transformer_sentence_encoder.py
@@ -5,7 +5,9 @@
import torch.nn as nn
from fairseq.modules import TransformerSentenceEncoder
-from fairseq.modules.sparse_transformer_sentence_encoder_layer import SparseTransformerSentenceEncoderLayer
+from fairseq.modules.sparse_transformer_sentence_encoder_layer import (
+ SparseTransformerSentenceEncoderLayer,
+)
class SparseTransformerSentenceEncoder(TransformerSentenceEncoder):
@@ -43,12 +45,27 @@ def __init__(
) -> None:
super().__init__(
- padding_idx, vocab_size, num_encoder_layers, embedding_dim,
- ffn_embedding_dim, num_attention_heads, dropout, attention_dropout,
- activation_dropout, max_seq_len, num_segments, use_position_embeddings,
- offset_positions_by_padding, encoder_normalize_before, apply_bert_init,
- activation_fn, learned_pos_embedding, embed_scale, freeze_embeddings,
- n_trans_layers_to_freeze, export
+ padding_idx,
+ vocab_size,
+ num_encoder_layers,
+ embedding_dim,
+ ffn_embedding_dim,
+ num_attention_heads,
+ dropout,
+ attention_dropout,
+ activation_dropout,
+ max_seq_len,
+ num_segments,
+ use_position_embeddings,
+ offset_positions_by_padding,
+ encoder_normalize_before,
+ apply_bert_init,
+ activation_fn,
+ learned_pos_embedding,
+ embed_scale,
+ freeze_embeddings,
+ n_trans_layers_to_freeze,
+ export,
)
self.layers = nn.ModuleList(
diff --git a/fairseq/modules/sparse_transformer_sentence_encoder_layer.py b/fairseq/modules/sparse_transformer_sentence_encoder_layer.py
index 21c2fe4d5a..d95da59c24 100644
--- a/fairseq/modules/sparse_transformer_sentence_encoder_layer.py
+++ b/fairseq/modules/sparse_transformer_sentence_encoder_layer.py
@@ -20,7 +20,7 @@ def __init__(
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
- activation_fn: str = 'relu',
+ activation_fn: str = "relu",
export: bool = False,
is_bidirectional: bool = True,
stride: int = 32,
@@ -28,8 +28,14 @@ def __init__(
) -> None:
super().__init__(
- embedding_dim, ffn_embedding_dim, num_attention_heads, dropout,
- attention_dropout, activation_dropout, activation_fn, export
+ embedding_dim,
+ ffn_embedding_dim,
+ num_attention_heads,
+ dropout,
+ attention_dropout,
+ activation_dropout,
+ activation_fn,
+ export,
)
self.self_attn = SparseMultiheadAttention(
diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py
index 9965f2f26c..48cd4c7314 100644
--- a/fairseq/modules/transformer_layer.py
+++ b/fairseq/modules/transformer_layer.py
@@ -9,10 +9,11 @@
import torch.nn as nn
from fairseq import utils
from fairseq.modules import LayerNorm, MultiheadAttention
-from fairseq.modules.quant_noise import quant_noise
from fairseq.modules.fairseq_dropout import FairseqDropout
+from fairseq.modules.quant_noise import quant_noise
from torch import Tensor
+
class TransformerEncoderLayer(nn.Module):
"""Encoder layer block.
@@ -35,7 +36,9 @@ def __init__(self, args):
self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)
self.self_attn = self.build_self_attention(self.embed_dim, args)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
- self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__)
+ self.dropout_module = FairseqDropout(
+ args.dropout, module_name=self.__class__.__name__
+ )
self.activation_fn = utils.get_activation_fn(
activation=getattr(args, "activation_fn", "relu")
)
@@ -48,19 +51,29 @@ def __init__(self, args):
)
self.normalize_before = args.encoder_normalize_before
self.fc1 = self.build_fc1(
- self.embed_dim, args.encoder_ffn_embed_dim, self.quant_noise, self.quant_noise_block_size
+ self.embed_dim,
+ args.encoder_ffn_embed_dim,
+ self.quant_noise,
+ self.quant_noise_block_size,
)
self.fc2 = self.build_fc2(
- args.encoder_ffn_embed_dim, self.embed_dim, self.quant_noise, self.quant_noise_block_size
+ args.encoder_ffn_embed_dim,
+ self.embed_dim,
+ self.quant_noise,
+ self.quant_noise_block_size,
)
self.final_layer_norm = LayerNorm(self.embed_dim)
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
- return quant_noise(nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size)
+ return quant_noise(
+ nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
+ )
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
- return quant_noise(nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size)
+ return quant_noise(
+ nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
+ )
def build_self_attention(self, embed_dim, args):
return MultiheadAttention(
@@ -164,7 +177,9 @@ def __init__(
):
super().__init__()
self.embed_dim = args.decoder_embed_dim
- self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__)
+ self.dropout_module = FairseqDropout(
+ args.dropout, module_name=self.__class__.__name__
+ )
self.quant_noise = getattr(args, "quant_noise_pq", 0)
self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)
@@ -178,14 +193,17 @@ def __init__(
)
self.activation_fn = utils.get_activation_fn(
- activation=str(args.activation_fn) if getattr(args, "activation_fn", None) is not None else "relu"
+ activation=str(args.activation_fn)
+ if getattr(args, "activation_fn", None) is not None
+ else "relu"
)
activation_dropout_p = getattr(args, "activation_dropout", 0)
if activation_dropout_p == 0:
# for backwards compatibility with models that use args.relu_dropout
activation_dropout_p = getattr(args, "relu_dropout", 0)
self.activation_dropout_module = FairseqDropout(
- float(activation_dropout_p), module_name=self.__class__.__name__)
+ float(activation_dropout_p), module_name=self.__class__.__name__
+ )
self.normalize_before = args.decoder_normalize_before
# use layerNorm rather than FusedLayerNorm for exporting.
@@ -202,10 +220,16 @@ def __init__(
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
self.fc1 = self.build_fc1(
- self.embed_dim, args.decoder_ffn_embed_dim, self.quant_noise, self.quant_noise_block_size
+ self.embed_dim,
+ args.decoder_ffn_embed_dim,
+ self.quant_noise,
+ self.quant_noise_block_size,
)
self.fc2 = self.build_fc2(
- args.decoder_ffn_embed_dim, self.embed_dim, self.quant_noise, self.quant_noise_block_size
+ args.decoder_ffn_embed_dim,
+ self.embed_dim,
+ self.quant_noise,
+ self.quant_noise_block_size,
)
self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
@@ -219,7 +243,9 @@ def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
- def build_self_attention(self, embed_dim, args, add_bias_kv=False, add_zero_attn=False):
+ def build_self_attention(
+ self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
+ ):
return MultiheadAttention(
embed_dim,
args.decoder_attention_heads,
diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py
index 74cd1d0664..208488f562 100644
--- a/fairseq/modules/transformer_sentence_encoder.py
+++ b/fairseq/modules/transformer_sentence_encoder.py
@@ -102,7 +102,9 @@ def __init__(
super().__init__()
self.padding_idx = padding_idx
self.vocab_size = vocab_size
- self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__)
+ self.dropout_module = FairseqDropout(
+ dropout, module_name=self.__class__.__name__
+ )
self.layerdrop = layerdrop
self.max_seq_len = max_seq_len
self.embedding_dim = embedding_dim
@@ -148,21 +150,23 @@ def __init__(
self.layers = LayerDropModuleList(p=self.layerdrop)
else:
self.layers = nn.ModuleList([])
- self.layers.extend([
- self.build_transformer_sentence_encoder_layer(
- embedding_dim=self.embedding_dim,
- ffn_embedding_dim=ffn_embedding_dim,
- num_attention_heads=num_attention_heads,
- dropout=self.dropout_module.p,
- attention_dropout=attention_dropout,
- activation_dropout=activation_dropout,
- activation_fn=activation_fn,
- export=export,
- q_noise=q_noise,
- qn_block_size=qn_block_size,
- )
- for _ in range(num_encoder_layers)
- ])
+ self.layers.extend(
+ [
+ self.build_transformer_sentence_encoder_layer(
+ embedding_dim=self.embedding_dim,
+ ffn_embedding_dim=ffn_embedding_dim,
+ num_attention_heads=num_attention_heads,
+ dropout=self.dropout_module.p,
+ attention_dropout=attention_dropout,
+ activation_dropout=activation_dropout,
+ activation_fn=activation_fn,
+ export=export,
+ q_noise=q_noise,
+ qn_block_size=qn_block_size,
+ )
+ for _ in range(num_encoder_layers)
+ ]
+ )
if encoder_normalize_before:
self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
diff --git a/fairseq/modules/transformer_sentence_encoder_layer.py b/fairseq/modules/transformer_sentence_encoder_layer.py
index 383938f68f..3589c60fe6 100644
--- a/fairseq/modules/transformer_sentence_encoder_layer.py
+++ b/fairseq/modules/transformer_sentence_encoder_layer.py
@@ -7,15 +7,10 @@
import torch
import torch.nn as nn
-
from fairseq import utils
-from fairseq.modules import (
- LayerNorm,
- MultiheadAttention,
-)
-from fairseq.modules.quant_noise import quant_noise
+from fairseq.modules import LayerNorm, MultiheadAttention
from fairseq.modules.fairseq_dropout import FairseqDropout
-
+from fairseq.modules.quant_noise import quant_noise
class TransformerSentenceEncoderLayer(nn.Module):
@@ -32,7 +27,7 @@ def __init__(
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
- activation_fn: str = 'relu',
+ activation_fn: str = "relu",
export: bool = False,
q_noise: float = 0.0,
qn_block_size: int = 8,
@@ -45,8 +40,12 @@ def __init__(
# Initialize parameters
self.embedding_dim = embedding_dim
- self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__)
- self.activation_dropout_module = FairseqDropout(activation_dropout, module_name=self.__class__.__name__)
+ self.dropout_module = FairseqDropout(
+ dropout, module_name=self.__class__.__name__
+ )
+ self.activation_dropout_module = FairseqDropout(
+ activation_dropout, module_name=self.__class__.__name__
+ )
# Initialize blocks
self.activation_fn = utils.get_activation_fn(activation_fn)
@@ -79,14 +78,10 @@ def __init__(
self.final_layer_norm = LayerNorm(self.embedding_dim, export=export)
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
- return quant_noise(
- nn.Linear(input_dim, output_dim), q_noise, qn_block_size
- )
+ return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
- return quant_noise(
- nn.Linear(input_dim, output_dim), q_noise, qn_block_size
- )
+ return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
def build_self_attention(
self,
diff --git a/fairseq/modules/unfold.py b/fairseq/modules/unfold.py
index 3a142db698..138272f1ef 100644
--- a/fairseq/modules/unfold.py
+++ b/fairseq/modules/unfold.py
@@ -7,11 +7,13 @@
def unfold1d(x, kernel_size, padding_l, pad_value=0):
- '''unfold T x B x C to T x B x C x K'''
+ """unfold T x B x C to T x B x C x K"""
if kernel_size > 1:
T, B, C = x.size()
- x = F.pad(x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value)
- x = x.as_strided((T, B, C, kernel_size), (B*C, C, 1, B*C))
+ x = F.pad(
+ x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value
+ )
+ x = x.as_strided((T, B, C, kernel_size), (B * C, C, 1, B * C))
else:
x = x.unsqueeze(3)
return x
diff --git a/fairseq/nan_detector.py b/fairseq/nan_detector.py
index 0d7d8d7d79..faa8031d46 100644
--- a/fairseq/nan_detector.py
+++ b/fairseq/nan_detector.py
@@ -4,14 +4,16 @@
# LICENSE file in the root directory of this source tree.
import logging
+
import torch
+
logger = logging.getLogger(__name__)
class NanDetector:
"""
- Detects the first NaN or Inf in forward and/or backward pass and logs, together with the module name
+ Detects the first NaN or Inf in forward and/or backward pass and logs, together with the module name
"""
def __init__(self, model, forward=True, backward=True):
@@ -83,7 +85,7 @@ def _apply(self, module, inp, x, backward):
f" input max: {inp.max().item()}, input min: {inp.min().item()}"
)
- has_printed_attr = 'has_printed_b' if backward else 'has_printed_f'
+ has_printed_attr = "has_printed_b" if backward else "has_printed_f"
logger.warning(err)
setattr(self, has_printed_attr, True)
elif isinstance(x, dict):
diff --git a/fairseq/optim/adadelta.py b/fairseq/optim/adadelta.py
index 9b311ae38a..f1a2154977 100644
--- a/fairseq/optim/adadelta.py
+++ b/fairseq/optim/adadelta.py
@@ -5,10 +5,10 @@
import torch.optim
-from . import register_optimizer, LegacyFairseqOptimizer
+from . import LegacyFairseqOptimizer, register_optimizer
-@register_optimizer('adadelta')
+@register_optimizer("adadelta")
class Adadelta(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
@@ -36,10 +36,10 @@ def optimizer_config(self):
different learning rate.
"""
return {
- 'lr': self.args.lr[0],
- 'rho': self.args.adadelta_rho,
- 'eps': self.args.adadelta_eps,
- 'weight_decay': self.args.weight_decay,
+ "lr": self.args.lr[0],
+ "rho": self.args.adadelta_rho,
+ "eps": self.args.adadelta_eps,
+ "weight_decay": self.args.weight_decay,
}
@property
diff --git a/fairseq/optim/adafactor.py b/fairseq/optim/adafactor.py
index ab69e0e58d..91745ce10e 100644
--- a/fairseq/optim/adafactor.py
+++ b/fairseq/optim/adafactor.py
@@ -4,13 +4,14 @@
# LICENSE file in the root directory of this source tree.
import math
+
import torch
import torch.optim
-from . import register_optimizer, LegacyFairseqOptimizer
+from . import LegacyFairseqOptimizer, register_optimizer
-@register_optimizer('adafactor')
+@register_optimizer("adafactor")
class FairseqAdafactor(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
@@ -50,15 +51,15 @@ def optimizer_config(self):
Might require search for appropriate configuration.
"""
return {
- 'lr': self.args.lr[0],
- 'eps': eval(self.args.adafactor_eps),
- 'clip_threshold': self.args.clip_threshold,
- 'decay_rate': self.args.decay_rate,
- 'beta1': self.args.beta1,
- 'weight_decay': self.args.weight_decay,
- 'scale_parameter': self.args.scale_parameter, # defaults to False
- 'relative_step': self.args.relative_step, # defaults to False
- 'warmup_init': self.args.warmup_init,
+ "lr": self.args.lr[0],
+ "eps": eval(self.args.adafactor_eps),
+ "clip_threshold": self.args.clip_threshold,
+ "decay_rate": self.args.decay_rate,
+ "beta1": self.args.beta1,
+ "weight_decay": self.args.weight_decay,
+ "scale_parameter": self.args.scale_parameter, # defaults to False
+ "relative_step": self.args.relative_step, # defaults to False
+ "warmup_init": self.args.warmup_init,
}
@@ -96,17 +97,35 @@ class Adafactor(torch.optim.Optimizer):
whether warm-up initialization is being used (default: False)
"""
- def __init__(self, params, lr=None, eps=(1e-30, 1e-3), clip_threshold=1.0,
- decay_rate=-0.8, beta1=None, weight_decay=0.0, scale_parameter=True,
- relative_step=True, warmup_init=False):
+ def __init__(
+ self,
+ params,
+ lr=None,
+ eps=(1e-30, 1e-3),
+ clip_threshold=1.0,
+ decay_rate=-0.8,
+ beta1=None,
+ weight_decay=0.0,
+ scale_parameter=True,
+ relative_step=True,
+ warmup_init=False,
+ ):
if lr is not None and relative_step:
- raise ValueError('Cannot combine manual lr and relative_step options')
+ raise ValueError("Cannot combine manual lr and relative_step options")
if warmup_init and not relative_step:
- raise ValueError('warmup_init requires relative_step=True')
-
- defaults = dict(lr=lr, eps=eps, clip_threshold=clip_threshold, decay_rate=decay_rate,
- beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
- relative_step=relative_step, warmup_init=warmup_init)
+ raise ValueError("warmup_init requires relative_step=True")
+
+ defaults = dict(
+ lr=lr,
+ eps=eps,
+ clip_threshold=clip_threshold,
+ decay_rate=decay_rate,
+ beta1=beta1,
+ weight_decay=weight_decay,
+ scale_parameter=scale_parameter,
+ relative_step=relative_step,
+ warmup_init=warmup_init,
+ )
super(Adafactor, self).__init__(params, defaults)
@property
@@ -118,18 +137,20 @@ def supports_flat_params(self):
return False
def _get_lr(self, param_group, param_state):
- rel_step_sz = param_group['lr']
- if param_group['relative_step']:
- min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
- rel_step_sz = min(min_step, 1.0/math.sqrt(param_state['step']))
+ rel_step_sz = param_group["lr"]
+ if param_group["relative_step"]:
+ min_step = (
+ 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
+ )
+ rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
param_scale = 1.0
- if param_group['scale_parameter']:
- param_scale = max(param_group['eps'][1], param_state['RMS'])
+ if param_group["scale_parameter"]:
+ param_scale = max(param_group["eps"][1], param_state["RMS"])
return param_scale * rel_step_sz
def _get_options(self, param_group, param_shape):
factored = len(param_shape) >= 2
- use_first_moment = param_group['beta1'] is not None
+ use_first_moment = param_group["beta1"] is not None
return factored, use_first_moment
def _rms(self, tensor):
@@ -137,8 +158,10 @@ def _rms(self, tensor):
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
r_factor = (
- exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)
- ).rsqrt_().unsqueeze(-1)
+ (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True))
+ .rsqrt_()
+ .unsqueeze(-1)
+ )
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
@@ -154,14 +177,14 @@ def step(self, closure=None):
loss = closure()
for group in self.param_groups:
- for p in group['params']:
+ for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
- raise RuntimeError('Adafactor does not support sparse gradients.')
+ raise RuntimeError("Adafactor does not support sparse gradients.")
state = self.state[p]
grad_shape = grad.shape
@@ -169,65 +192,73 @@ def step(self, closure=None):
factored, use_first_moment = self._get_options(group, grad_shape)
# State Initialization
if len(state) == 0:
- state['step'] = 0
+ state["step"] = 0
if use_first_moment:
# Exponential moving average of gradient values
- state['exp_avg'] = torch.zeros_like(grad)
+ state["exp_avg"] = torch.zeros_like(grad)
if factored:
- state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
- state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
+ state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
+ state["exp_avg_sq_col"] = torch.zeros(
+ grad_shape[:-2] + grad_shape[-1:]
+ ).to(grad)
else:
- state['exp_avg_sq'] = torch.zeros_like(grad)
+ state["exp_avg_sq"] = torch.zeros_like(grad)
- state['RMS'] = 0
+ state["RMS"] = 0
else:
if use_first_moment:
- state['exp_avg'] = state['exp_avg'].to(grad)
+ state["exp_avg"] = state["exp_avg"].to(grad)
if factored:
- state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
- state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
+ state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
+ state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
else:
- state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
+ state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
p_data_fp32 = p.data
if p.data.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()
- state['step'] += 1
- state['RMS'] = self._rms(p_data_fp32)
- group['lr'] = self._get_lr(group, state)
+ state["step"] += 1
+ state["RMS"] = self._rms(p_data_fp32)
+ group["lr"] = self._get_lr(group, state)
- beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
- update = (grad**2) + group['eps'][0]
+ beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
+ update = (grad ** 2) + group["eps"][0]
if factored:
- exp_avg_sq_row = state['exp_avg_sq_row']
- exp_avg_sq_col = state['exp_avg_sq_col']
+ exp_avg_sq_row = state["exp_avg_sq_row"]
+ exp_avg_sq_col = state["exp_avg_sq_col"]
- exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
- exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
+ exp_avg_sq_row.mul_(beta2t).add_(
+ update.mean(dim=-1), alpha=1.0 - beta2t
+ )
+ exp_avg_sq_col.mul_(beta2t).add_(
+ update.mean(dim=-2), alpha=1.0 - beta2t
+ )
# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
- exp_avg_sq = state['exp_avg_sq']
+ exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
update = exp_avg_sq.rsqrt().mul_(grad)
update.div_(
- (self._rms(update) / group['clip_threshold']).clamp_(min=1.0)
+ (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)
)
- update.mul_(group['lr'])
+ update.mul_(group["lr"])
if use_first_moment:
- exp_avg = state['exp_avg']
- exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
+ exp_avg = state["exp_avg"]
+ exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"])
update = exp_avg
- if group['weight_decay'] != 0:
- p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
+ if group["weight_decay"] != 0:
+ p_data_fp32.add_(
+ p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
+ )
p_data_fp32.add_(-update)
diff --git a/fairseq/optim/adagrad.py b/fairseq/optim/adagrad.py
index 5056752776..a79b6c39da 100644
--- a/fairseq/optim/adagrad.py
+++ b/fairseq/optim/adagrad.py
@@ -5,10 +5,10 @@
import torch.optim
-from . import register_optimizer, LegacyFairseqOptimizer
+from . import LegacyFairseqOptimizer, register_optimizer
-@register_optimizer('adagrad')
+@register_optimizer("adagrad")
class Adagrad(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
@@ -31,8 +31,8 @@ def optimizer_config(self):
different learning rate.
"""
return {
- 'lr': self.args.lr[0],
- 'weight_decay': self.args.weight_decay,
+ "lr": self.args.lr[0],
+ "weight_decay": self.args.weight_decay,
}
@property
diff --git a/fairseq/optim/adamax.py b/fairseq/optim/adamax.py
index 195e7a90d8..577a688166 100644
--- a/fairseq/optim/adamax.py
+++ b/fairseq/optim/adamax.py
@@ -6,10 +6,10 @@
import torch
import torch.optim
-from . import register_optimizer, LegacyFairseqOptimizer
+from . import LegacyFairseqOptimizer, register_optimizer
-@register_optimizer('adamax')
+@register_optimizer("adamax")
class FairseqAdamax(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
@@ -38,11 +38,11 @@ def optimizer_config(self):
different learning rate.
"""
return {
- 'lr': self.args.lr[0],
- 'betas': eval(self.args.adamax_betas),
- 'eps': self.args.adamax_eps,
- 'weight_decay': self.args.weight_decay,
- 'bias_correction': not self.args.no_bias_correction,
+ "lr": self.args.lr[0],
+ "betas": eval(self.args.adamax_betas),
+ "eps": self.args.adamax_eps,
+ "weight_decay": self.args.weight_decay,
+ "bias_correction": not self.args.no_bias_correction,
}
@@ -67,8 +67,15 @@ class Adamax(torch.optim.Optimizer):
__ https://arxiv.org/abs/1412.6980
"""
- def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, bias_correction=True):
+ def __init__(
+ self,
+ params,
+ lr=2e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ bias_correction=True,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@@ -80,8 +87,13 @@ def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
- defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
- bias_correction=bias_correction)
+ defaults = dict(
+ lr=lr,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ bias_correction=bias_correction,
+ )
super(Adamax, self).__init__(params, defaults)
@property
@@ -104,12 +116,12 @@ def step(self, closure=None):
loss = closure()
for group in self.param_groups:
- for p in group['params']:
+ for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
- raise RuntimeError('Adamax does not support sparse gradients')
+ raise RuntimeError("Adamax does not support sparse gradients")
p_data_fp32 = p.data
if p.data.dtype in {torch.float16, torch.bfloat16}:
@@ -119,18 +131,18 @@ def step(self, closure=None):
# State initialization
if len(state) == 0:
- state['step'] = 0
- state['exp_avg'] = torch.zeros_like(p_data_fp32)
- state['exp_inf'] = torch.zeros_like(p_data_fp32)
+ state["step"] = 0
+ state["exp_avg"] = torch.zeros_like(p_data_fp32)
+ state["exp_inf"] = torch.zeros_like(p_data_fp32)
else:
- state['exp_avg'] = state['exp_avg'].to(p_data_fp32)
- state['exp_inf'] = state['exp_inf'].to(p_data_fp32)
+ state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
+ state["exp_inf"] = state["exp_inf"].to(p_data_fp32)
- exp_avg, exp_inf = state['exp_avg'], state['exp_inf']
- beta1, beta2 = group['betas']
- eps = group['eps']
+ exp_avg, exp_inf = state["exp_avg"], state["exp_inf"]
+ beta1, beta2 = group["betas"]
+ eps = group["eps"]
- state['step'] += 1
+ state["step"] += 1
# Update biased first moment estimate.
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
@@ -142,13 +154,15 @@ def step(self, closure=None):
out=exp_inf,
)
- step_size = group['lr']
- if group['bias_correction']:
- bias_correction = 1 - beta1 ** state['step']
+ step_size = group["lr"]
+ if group["bias_correction"]:
+ bias_correction = 1 - beta1 ** state["step"]
step_size /= bias_correction
- if group['weight_decay'] != 0:
- p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
+ if group["weight_decay"] != 0:
+ p_data_fp32.add_(
+ p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
+ )
p_data_fp32.addcdiv_(exp_avg, exp_inf.add(eps), value=-step_size)
diff --git a/fairseq/optim/dynamic_loss_scaler.py b/fairseq/optim/dynamic_loss_scaler.py
index 9d1f0b2c05..c5da604220 100644
--- a/fairseq/optim/dynamic_loss_scaler.py
+++ b/fairseq/optim/dynamic_loss_scaler.py
@@ -3,11 +3,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-class DynamicLossScaler(object):
+class DynamicLossScaler(object):
def __init__(
- self, init_scale=2.**15, scale_factor=2., scale_window=2000,
- tolerance=0.05, threshold=None, min_loss_scale=1e-4
+ self,
+ init_scale=2.0 ** 15,
+ scale_factor=2.0,
+ scale_window=2000,
+ tolerance=0.05,
+ threshold=None,
+ min_loss_scale=1e-4,
):
self.loss_scale = init_scale
self.scale_factor = scale_factor
@@ -36,7 +41,7 @@ def _decrease_loss_scale(self):
def check_overflow(self, grad_norm):
# detect inf and nan
- if grad_norm == float('inf') or grad_norm != grad_norm:
+ if grad_norm == float("inf") or grad_norm != grad_norm:
# overflow has occured
prev_scale = self.loss_scale
iter_since_rescale = self._iter - self._last_rescale_iter
@@ -53,11 +58,13 @@ def check_overflow(self, grad_norm):
# Use FloatingPointError as an uncommon error that parent
# functions can safely catch to stop training.
self.loss_scale = prev_scale
- raise FloatingPointError((
- 'Minimum loss scale reached ({}). Your loss is probably exploding. '
- 'Try lowering the learning rate, using gradient clipping or '
- 'increasing the batch size.'
- ).format(self.min_loss_scale))
+ raise FloatingPointError(
+ (
+ "Minimum loss scale reached ({}). Your loss is probably exploding. "
+ "Try lowering the learning rate, using gradient clipping or "
+ "increasing the batch size."
+ ).format(self.min_loss_scale)
+ )
self._iter += 1
- raise OverflowError('setting loss scale to: ' + str(self.loss_scale))
+ raise OverflowError("setting loss scale to: " + str(self.loss_scale))
diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py
index b602e51818..8a10399a8b 100644
--- a/fairseq/optim/fairseq_optimizer.py
+++ b/fairseq/optim/fairseq_optimizer.py
@@ -109,8 +109,8 @@ def step(self, closure=None, scale=1.0):
if self.supports_step_with_scale:
self.optimizer.step(closure, scale=scale)
else:
- if scale != 1.:
- self.multiply_grads(1. / scale)
+ if scale != 1.0:
+ self.multiply_grads(1.0 / scale)
self.optimizer.step(closure)
def zero_grad(self):
diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py
index edb4f536ea..b622fbde44 100644
--- a/fairseq/optim/fp16_optimizer.py
+++ b/fairseq/optim/fp16_optimizer.py
@@ -3,41 +3,35 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from itertools import chain
from collections import defaultdict
+from itertools import chain
import torch
-
from fairseq import optim, utils
from .dynamic_loss_scaler import DynamicLossScaler
class _FP16OptimizerMixin(object):
-
def __init__(self, *args, **kwargs):
# forward __init__ call to the next class in mro(method resolution order)
super().__init__(*args, **kwargs)
- self._multiply_factor = 1.
+ self._multiply_factor = 1.0
@property
def has_flat_params(self):
- return (
- torch.is_tensor(self.fp32_params) or
- (
- isinstance(self.fp32_params, dict) and
- all(torch.is_tensor(t) for t in self.fp32_params.values())
- )
+ return torch.is_tensor(self.fp32_params) or (
+ isinstance(self.fp32_params, dict)
+ and all(torch.is_tensor(t) for t in self.fp32_params.values())
)
@classmethod
def build_fp32_params(cls, args, params, flatten=True):
# create FP32 copy of parameters and grads
if flatten:
- is_pipeline_parallel = (
- getattr(args, 'pipeline_model_parallel', False)
- and getattr(args, 'distributed_no_spawn', False)
- )
+ is_pipeline_parallel = getattr(
+ args, "pipeline_model_parallel", False
+ ) and getattr(args, "distributed_no_spawn", False)
total_param_size = sum(p.data.numel() for p in params)
devices = [torch.cuda.current_device()]
if is_pipeline_parallel:
@@ -45,19 +39,25 @@ def build_fp32_params(cls, args, params, flatten=True):
fp32_params = {}
for device in devices:
if is_pipeline_parallel:
- device_param_size = sum(p.data.numel() for p in params if p.device.index == device)
+ device_param_size = sum(
+ p.data.numel() for p in params if p.device.index == device
+ )
device_params = [p for p in params if p.device.index == device]
else:
device_param_size = total_param_size
device_params = params
- fp32_params[device] = device_params[0].new(0).float().new(device_param_size)
+ fp32_params[device] = (
+ device_params[0].new(0).float().new(device_param_size)
+ )
offset = 0
for p in device_params:
numel = p.data.numel()
- fp32_params[device][offset:offset+numel].copy_(p.data.view(-1))
+ fp32_params[device][offset : offset + numel].copy_(p.data.view(-1))
offset += numel
fp32_params[device] = torch.nn.Parameter(fp32_params[device])
- fp32_params[device].grad = fp32_params[device].data.new(device_param_size)
+ fp32_params[device].grad = fp32_params[device].data.new(
+ device_param_size
+ )
return fp32_params
else:
fp32_params = []
@@ -71,7 +71,7 @@ def state_dict(self):
"""Return the optimizer's state dict."""
state_dict = self.fp32_optimizer.state_dict()
if self.scaler is not None:
- state_dict['loss_scale'] = self.scaler.loss_scale
+ state_dict["loss_scale"] = self.scaler.loss_scale
return state_dict
def load_state_dict(self, state_dict, optimizer_overrides=None):
@@ -82,8 +82,8 @@ def load_state_dict(self, state_dict, optimizer_overrides=None):
allows us to resume training from a checkpoint using a new set of
optimizer args.
"""
- if 'loss_scale' in state_dict and self.scaler is not None:
- self.scaler.loss_scale = state_dict['loss_scale']
+ if "loss_scale" in state_dict and self.scaler is not None:
+ self.scaler.loss_scale = state_dict["loss_scale"]
self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)
def backward(self, loss):
@@ -111,9 +111,15 @@ def _sync_fp16_grads_to_fp32(self):
device_params = device_params_dict[device]
offset = 0
for p in device_params:
- grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
+ grad_data = (
+ p.grad.data
+ if p.grad is not None
+ else p.data.new_zeros(p.data.shape)
+ )
numel = grad_data.numel()
- self.fp32_params[device].grad.data[offset:offset+numel].copy_(grad_data.view(-1))
+ self.fp32_params[device].grad.data[
+ offset : offset + numel
+ ].copy_(grad_data.view(-1))
offset += numel
else:
for p, p32 in zip(self.fp16_params, self.fp32_params):
@@ -138,7 +144,11 @@ def _sync_fp32_params_to_fp16(self):
offset = 0
for p in device_params:
numel = p.data.numel()
- p.data.copy_(self.fp32_params[device].data[offset:offset+numel].view_as(p.data))
+ p.data.copy_(
+ self.fp32_params[device]
+ .data[offset : offset + numel]
+ .view_as(p.data)
+ )
offset += numel
else:
for p, p32 in zip(self.fp16_params, self.fp32_params):
@@ -148,9 +158,9 @@ def _sync_fp32_params_to_fp16(self):
def _unscale_grads(self):
self._sync_fp16_grads_to_fp32()
- if self._multiply_factor != 1.:
+ if self._multiply_factor != 1.0:
self.fp32_optimizer.multiply_grads(self._multiply_factor)
- self._multiply_factor = 1.
+ self._multiply_factor = 1.0
def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``."""
@@ -160,7 +170,9 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
"""Clips gradient norm and updates dynamic loss scaler."""
self._sync_fp16_grads_to_fp32()
- grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm(0, aggregate_norm_fn)
+ grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm(
+ 0, aggregate_norm_fn
+ )
if self.scaler is not None:
if grad_norm > max_norm > 0.0:
@@ -177,8 +189,8 @@ def step(self, closure=None):
"""Performs a single optimization step."""
self._sync_fp16_grads_to_fp32()
- if getattr(self, 'supports_step_with_scale', False):
- self.fp32_optimizer.step(closure, scale=(1. / self._multiply_factor))
+ if getattr(self, "supports_step_with_scale", False):
+ self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor))
else:
self._unscale_grads()
self.fp32_optimizer.step(closure)
@@ -199,14 +211,14 @@ def zero_grad(self):
for fp32_params in self.fp32_params.values():
fp32_params.grad.zero_()
else:
- raise("self.fp32_params must be a tensor or dict")
+ raise ("self.fp32_params must be a tensor or dict")
else:
for p32 in self.fp32_params:
p32.grad.zero_()
self._needs_sync = False
if self.scaler is not None:
- self._multiply_factor = 1. / float(self.scaler.loss_scale)
+ self._multiply_factor = 1.0 / float(self.scaler.loss_scale)
class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
@@ -220,24 +232,26 @@ def __init__(self, args, params, fp32_optimizer, fp32_params):
self.fp32_optimizer = fp32_optimizer
self.fp32_params = fp32_params
- if getattr(args, 'fp16_scale_window', None) is None:
+ if getattr(args, "fp16_scale_window", None) is None:
if len(args.update_freq) > 1:
raise ValueError(
- '--fp16-scale-window must be given explicitly when using a '
- 'custom --update-freq schedule'
+ "--fp16-scale-window must be given explicitly when using a "
+ "custom --update-freq schedule"
)
- data_parallel_size = int(args.distributed_world_size / args.model_parallel_size)
- scale_window = int(2**14 / data_parallel_size / args.update_freq[0])
+ data_parallel_size = int(
+ args.distributed_world_size / args.model_parallel_size
+ )
+ scale_window = int(2 ** 14 / data_parallel_size / args.update_freq[0])
else:
scale_window = args.fp16_scale_window
- if not getattr(args, 'bf16', False):
+ if not getattr(args, "bf16", False):
self.scaler = DynamicLossScaler(
init_scale=args.fp16_init_scale,
scale_window=scale_window,
tolerance=args.fp16_scale_tolerance,
threshold=args.threshold_loss_scale,
- min_loss_scale=args.min_loss_scale
+ min_loss_scale=args.min_loss_scale,
)
else:
# disable loss scaling for bfloat16
@@ -250,8 +264,8 @@ def build_optimizer(cls, args, params):
args (argparse.Namespace): fairseq args
params (iterable): iterable of parameters to optimize
"""
- flatten = not getattr(args, 'fp16_no_flatten_grads', False)
- if getattr(args, 'bf16', False):
+ flatten = not getattr(args, "fp16_no_flatten_grads", False)
+ if getattr(args, "bf16", False):
flatten = False # mixed precision is faster on TPUs without flat grads
fp32_params = cls.build_fp32_params(args, params, flatten=flatten)
if flatten:
@@ -260,8 +274,8 @@ def build_optimizer(cls, args, params):
fp32_optimizer = optim.build_optimizer(args, fp32_params)
if flatten and not fp32_optimizer.supports_flat_params:
raise RuntimeError(
- 'chosen optimizer does not support flat params, '
- 'please set --fp16-no-flatten-grads'
+ "chosen optimizer does not support flat params, "
+ "please set --fp16-no-flatten-grads"
)
return cls(args, params, fp32_optimizer, fp32_params)
@@ -285,11 +299,10 @@ def set_lr(self, lr):
class _MemoryEfficientFP16OptimizerMixin(object):
-
def __init__(self, *args, **kwargs):
# forward __init__ call to the next class in MRO (method resolution order)
super().__init__(*args, **kwargs)
- self._multiply_factor = 1.
+ self._multiply_factor = 1.0
@property
def has_flat_params(self):
@@ -299,7 +312,7 @@ def state_dict(self):
"""Return the optimizer's state dict."""
state_dict = self.wrapped_optimizer.state_dict()
if self.scaler is not None:
- state_dict['loss_scale'] = self.scaler.loss_scale
+ state_dict["loss_scale"] = self.scaler.loss_scale
return state_dict
def load_state_dict(self, state_dict, optimizer_overrides=None):
@@ -310,8 +323,8 @@ def load_state_dict(self, state_dict, optimizer_overrides=None):
allows us to resume training from a checkpoint using a new set of
optimizer args.
"""
- if 'loss_scale' in state_dict and self.scaler is not None:
- self.scaler.loss_scale = state_dict['loss_scale']
+ if "loss_scale" in state_dict and self.scaler is not None:
+ self.scaler.loss_scale = state_dict["loss_scale"]
self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides)
@@ -320,17 +333,17 @@ def load_state_dict(self, state_dict, optimizer_overrides=None):
# params are FP16 while the optimizer state is FP32 and we don't want
# to cast. A workaround is to manually copy back the original state
# after the optimizer has been loaded.
- if not getattr(self.optimizer, 'disable_mem_eff_fp16_loading_hack', False):
+ if not getattr(self.optimizer, "disable_mem_eff_fp16_loading_hack", False):
groups = self.optimizer.param_groups
- saved_groups = state_dict['param_groups']
+ saved_groups = state_dict["param_groups"]
id_map = {
old_id: p
for old_id, p in zip(
- chain(*(g['params'] for g in saved_groups)),
- chain(*(g['params'] for g in groups))
+ chain(*(g["params"] for g in saved_groups)),
+ chain(*(g["params"] for g in groups)),
)
}
- for k, v in state_dict['state'].items():
+ for k, v in state_dict["state"].items():
if k in id_map:
param = id_map[k]
self.optimizer.state[param] = v
@@ -347,9 +360,9 @@ def backward(self, loss):
loss.backward()
def _unscale_grads(self):
- if self._multiply_factor != 1.:
+ if self._multiply_factor != 1.0:
self.wrapped_optimizer.multiply_grads(self._multiply_factor)
- self._multiply_factor = 1.
+ self._multiply_factor = 1.0
def multiply_grads(self, c):
"""Multiplies grads by a constant *c*."""
@@ -358,11 +371,13 @@ def multiply_grads(self, c):
def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
"""Clips gradient norm and updates dynamic loss scaler."""
max_norm = float(max_norm)
- grad_norm = self._multiply_factor * self.wrapped_optimizer.clip_grad_norm(0, aggregate_norm_fn)
+ grad_norm = self._multiply_factor * self.wrapped_optimizer.clip_grad_norm(
+ 0, aggregate_norm_fn
+ )
if self.scaler is not None:
grad_norm_cpu = float(grad_norm)
- if grad_norm_cpu > max_norm > 0.:
+ if grad_norm_cpu > max_norm > 0.0:
self._multiply_factor *= max_norm / grad_norm_cpu
# detect overflow and adjust loss scale
@@ -375,9 +390,9 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
def step(self, closure=None):
"""Performs a single optimization step."""
- if getattr(self, 'supports_step_with_scale', False):
+ if getattr(self, "supports_step_with_scale", False):
# NOTE(msb) optimizer divides by scale factor
- self.wrapped_optimizer.step(closure, scale=(1. / self._multiply_factor))
+ self.wrapped_optimizer.step(closure, scale=(1.0 / self._multiply_factor))
else:
self._unscale_grads()
self.wrapped_optimizer.step(closure)
@@ -389,12 +404,14 @@ def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
self.wrapped_optimizer.zero_grad()
if self.scaler is not None:
- self._multiply_factor = 1. / float(self.scaler.loss_scale)
+ self._multiply_factor = 1.0 / float(self.scaler.loss_scale)
else:
- self._multiply_factor = 1.
+ self._multiply_factor = 1.0
-class MemoryEfficientFP16Optimizer(_MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer):
+class MemoryEfficientFP16Optimizer(
+ _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer
+):
"""
Wrap an *optimizer* to support FP16 (mixed precision) training.
@@ -413,30 +430,32 @@ class MemoryEfficientFP16Optimizer(_MemoryEfficientFP16OptimizerMixin, optim.Fai
def __init__(self, args, params, optimizer):
if not optimizer.supports_memory_efficient_fp16:
raise ValueError(
- 'Unsupported optimizer: {}'.format(optimizer.__class__.__name__)
+ "Unsupported optimizer: {}".format(optimizer.__class__.__name__)
)
super().__init__(args)
self.wrapped_optimizer = optimizer
- if getattr(args, 'fp16_scale_window', None) is None:
+ if getattr(args, "fp16_scale_window", None) is None:
if len(args.update_freq) > 1:
raise ValueError(
- '--fp16-scale-window must be given explicitly when using a '
- 'custom --update-freq schedule'
+ "--fp16-scale-window must be given explicitly when using a "
+ "custom --update-freq schedule"
)
- data_parallel_size = int(args.distributed_world_size / args.model_parallel_size)
- scale_window = 2**14 / data_parallel_size / args.update_freq[0]
+ data_parallel_size = int(
+ args.distributed_world_size / args.model_parallel_size
+ )
+ scale_window = 2 ** 14 / data_parallel_size / args.update_freq[0]
else:
scale_window = args.fp16_scale_window
- if not getattr(args, 'bf16', False):
+ if not getattr(args, "bf16", False):
self.scaler = DynamicLossScaler(
init_scale=args.fp16_init_scale,
scale_window=scale_window,
tolerance=args.fp16_scale_tolerance,
threshold=args.threshold_loss_scale,
- min_loss_scale=args.min_loss_scale
+ min_loss_scale=args.min_loss_scale,
)
else:
# disable loss scaling for bfloat16
diff --git a/fairseq/optim/fused_adam.py b/fairseq/optim/fused_adam.py
index 9024451aff..1780f9c0bb 100644
--- a/fairseq/optim/fused_adam.py
+++ b/fairseq/optim/fused_adam.py
@@ -21,6 +21,7 @@ def get_fused_adam_class():
# `--deprecated_fused_adam` option when building apex.
global fused_adam_cuda
import importlib
+
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
return FusedAdamV1
except ImportError:
@@ -28,6 +29,7 @@ def get_fused_adam_class():
# fallback to the newer interface
from apex.optimizers import FusedAdam as _FusedAdam # noqa
from apex.multi_tensor_apply import multi_tensor_applier
+
if multi_tensor_applier.available:
return FusedAdamV2
except ImportError:
@@ -67,23 +69,32 @@ class FusedAdamV1(torch.optim.Optimizer):
https://openreview.net/forum?id=ryQu7f-RZ
"""
- def __init__(self, params,
- lr=1e-3, bias_correction=True,
- betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt=False,
- weight_decay=0., max_grad_norm=0., amsgrad=False):
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ bias_correction=True,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ eps_inside_sqrt=False,
+ weight_decay=0.0,
+ max_grad_norm=0.0,
+ amsgrad=False,
+ ):
global fused_adam_cuda
import importlib
+
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
if amsgrad:
- raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
+ raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
defaults = {
- 'lr': lr,
- 'bias_correction': bias_correction,
- 'betas': betas,
- 'eps': eps,
- 'weight_decay': weight_decay,
- 'max_grad_norm': max_grad_norm,
+ "lr": lr,
+ "bias_correction": bias_correction,
+ "betas": betas,
+ "eps": eps,
+ "weight_decay": weight_decay,
+ "max_grad_norm": max_grad_norm,
}
super().__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
@@ -100,7 +111,7 @@ def supports_flat_params(self):
def supports_step_with_scale(self):
return True
- def step(self, closure=None, grads=None, scale=1., grad_norms=None):
+ def step(self, closure=None, grads=None, scale=1.0, grad_norms=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
@@ -130,23 +141,25 @@ def step(self, closure=None, grads=None, scale=1., grad_norms=None):
grads_group = grads
if grad_norms is None:
- grad_norms = [None]*len(self.param_groups)
+ grad_norms = [None] * len(self.param_groups)
- for group, grads_this_group, grad_norm in zip(self.param_groups, grads_group, grad_norms):
+ for group, grads_this_group, grad_norm in zip(
+ self.param_groups, grads_group, grad_norms
+ ):
if grads_this_group is None:
- grads_this_group = [None]*len(group['params'])
+ grads_this_group = [None] * len(group["params"])
# compute combined scale factor for this group
combined_scale = scale
- if group.get('max_grad_norm', 0) > 0:
+ if group.get("max_grad_norm", 0) > 0:
# norm is in fact norm*scale
- clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']
+ clip = ((grad_norm / scale) + 1e-6) / group["max_grad_norm"]
if clip > 1:
combined_scale = clip * scale
- bias_correction = 1 if group.get('bias_correction', 1) else 0
+ bias_correction = 1 if group.get("bias_correction", 1) else 0
- for p, grad in zip(group['params'], grads_this_group):
+ for p, grad in zip(group["params"], grads_this_group):
# note: p.grad should not ever be set for correct
# operation of mixed precision optimizer that sometimes
# sends None gradients
@@ -156,8 +169,8 @@ def step(self, closure=None, grads=None, scale=1., grad_norms=None):
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
- 'FusedAdam does not support sparse gradients, '
- 'please consider SparseAdam instead'
+ "FusedAdam does not support sparse gradients, "
+ "please consider SparseAdam instead"
)
p_data_fp32 = p.data.float()
@@ -166,37 +179,39 @@ def step(self, closure=None, grads=None, scale=1., grad_norms=None):
# State initialization
if len(state) == 0:
- state['step'] = 0
+ state["step"] = 0
# Exponential moving average of gradient values
- state['exp_avg'] = torch.zeros_like(p_data_fp32)
+ state["exp_avg"] = torch.zeros_like(p_data_fp32)
# Exponential moving average of squared gradient values
- state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+ state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
else:
- state['exp_avg'] = state['exp_avg'].to(p_data_fp32)
- state['exp_avg_sq'] = state['exp_avg_sq'].to(p_data_fp32)
+ state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
+ state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
- exp_avg = state['exp_avg']
- exp_avg_sq = state['exp_avg_sq']
- beta1, beta2 = group['betas']
+ exp_avg = state["exp_avg"]
+ exp_avg_sq = state["exp_avg_sq"]
+ beta1, beta2 = group["betas"]
- state['step'] += 1
+ state["step"] += 1
out_p = p.data
with torch.cuda.device(p.device):
- fused_adam_cuda.adam(p_data_fp32,
- out_p,
- exp_avg,
- exp_avg_sq,
- grad,
- group['lr'],
- beta1,
- beta2,
- group['eps'],
- combined_scale,
- state['step'],
- self.eps_mode,
- bias_correction,
- group['weight_decay'])
+ fused_adam_cuda.adam(
+ p_data_fp32,
+ out_p,
+ exp_avg,
+ exp_avg_sq,
+ grad,
+ group["lr"],
+ beta1,
+ beta2,
+ group["eps"],
+ combined_scale,
+ state["step"],
+ self.eps_mode,
+ bias_correction,
+ group["weight_decay"],
+ )
return loss
@@ -213,8 +228,10 @@ class FusedAdamV2(FusedAdam):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- if not hasattr(self, 'multi_tensor_adam'):
- raise Exception('Apex installation is outdated. Please install an updated version of apex.')
+ if not hasattr(self, "multi_tensor_adam"):
+ raise Exception(
+ "Apex installation is outdated. Please install an updated version of apex."
+ )
@property
def supports_memory_efficient_fp16(self):
@@ -224,89 +241,108 @@ def supports_memory_efficient_fp16(self):
def supports_flat_params(self):
return True
- def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):
+ def step(
+ self,
+ closure=None,
+ grads=None,
+ output_params=None,
+ scale=None,
+ grad_norms=None,
+ ):
"""Performs a single optimization step."""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
- bias_correction = 1 if group['bias_correction'] else 0
- beta1, beta2 = group['betas']
+ bias_correction = 1 if group["bias_correction"] else 0
+ beta1, beta2 = group["betas"]
# assume same step across group now to simplify things
# per parameter step can be easily support by making it tensor, or pass list into kernel
- if 'step' in group:
- group['step'] += 1
+ if "step" in group:
+ group["step"] += 1
else:
- group['step'] = 1
+ group["step"] = 1
# create lists for multi-tensor apply
g_16, p_16, orig_p_16, m_16, v_16 = [], [], [], [], []
g_32, p_32, m_32, v_32 = [], [], [], []
- for p in group['params']:
+ for p in group["params"]:
if p.grad is None:
continue
if p.grad.data.is_sparse:
raise RuntimeError(
- 'FusedAdam does not support sparse gradients, '
- 'please consider SparseAdam instead'
+ "FusedAdam does not support sparse gradients, "
+ "please consider SparseAdam instead"
)
state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
- state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float)
+ state["exp_avg"] = torch.zeros_like(p.data, dtype=torch.float)
# Exponential moving average of squared gradient values
- state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float)
+ state["exp_avg_sq"] = torch.zeros_like(
+ p.data, dtype=torch.float
+ )
else:
- state['exp_avg'] = state['exp_avg'].to(device=p.data.device, dtype=torch.float)
- state['exp_avg_sq'] = state['exp_avg_sq'].to(device=p.data.device, dtype=torch.float)
+ state["exp_avg"] = state["exp_avg"].to(
+ device=p.data.device, dtype=torch.float
+ )
+ state["exp_avg_sq"] = state["exp_avg_sq"].to(
+ device=p.data.device, dtype=torch.float
+ )
if p.dtype == torch.float16:
g_16.append(p.grad.data.float())
p_16.append(p.data.float())
orig_p_16.append(p.data)
- m_16.append(state['exp_avg'])
- v_16.append(state['exp_avg_sq'])
+ m_16.append(state["exp_avg"])
+ v_16.append(state["exp_avg_sq"])
elif p.dtype == torch.float32:
g_32.append(p.grad.data)
p_32.append(p.data)
- m_32.append(state['exp_avg'])
- v_32.append(state['exp_avg_sq'])
+ m_32.append(state["exp_avg"])
+ v_32.append(state["exp_avg_sq"])
else:
- raise RuntimeError('FusedAdam only support fp16 and fp32.')
+ raise RuntimeError("FusedAdam only support fp16 and fp32.")
with torch.cuda.device(p.device):
- if(len(g_16) > 0):
- multi_tensor_applier(self.multi_tensor_adam,
- self._dummy_overflow_buf,
- [g_16, p_16, m_16, v_16],
- group['lr'],
- beta1,
- beta2,
- group['eps'],
- group['step'],
- self.adam_w_mode,
- bias_correction,
- group['weight_decay'])
+ if len(g_16) > 0:
+ multi_tensor_applier(
+ self.multi_tensor_adam,
+ self._dummy_overflow_buf,
+ [g_16, p_16, m_16, v_16],
+ group["lr"],
+ beta1,
+ beta2,
+ group["eps"],
+ group["step"],
+ self.adam_w_mode,
+ bias_correction,
+ group["weight_decay"],
+ )
for orig_p, p in zip(orig_p_16, p_16):
orig_p.copy_(p.data)
- if(len(g_32) > 0):
- multi_tensor_applier(self.multi_tensor_adam,
- self._dummy_overflow_buf,
- [g_32, p_32, m_32, v_32],
- group['lr'],
- beta1,
- beta2,
- group['eps'],
- group['step'],
- self.adam_w_mode,
- bias_correction,
- group['weight_decay'])
+ if len(g_32) > 0:
+ multi_tensor_applier(
+ self.multi_tensor_adam,
+ self._dummy_overflow_buf,
+ [g_32, p_32, m_32, v_32],
+ group["lr"],
+ beta1,
+ beta2,
+ group["eps"],
+ group["step"],
+ self.adam_w_mode,
+ bias_correction,
+ group["weight_decay"],
+ )
return loss
+
+
except ImportError:
pass
diff --git a/fairseq/optim/fused_lamb.py b/fairseq/optim/fused_lamb.py
index d48ecbc8e0..f4f2bdb0c6 100644
--- a/fairseq/optim/fused_lamb.py
+++ b/fairseq/optim/fused_lamb.py
@@ -3,10 +3,10 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from fairseq.optim import register_optimizer, LegacyFairseqOptimizer
+from fairseq.optim import LegacyFairseqOptimizer, register_optimizer
-@register_optimizer('lamb')
+@register_optimizer("lamb")
class FairseqLAMB(LegacyFairseqOptimizer):
"""LAMB optimizer."""
@@ -14,9 +14,10 @@ def __init__(self, args, params):
super().__init__(args)
try:
from apex.optimizers import FusedLAMB
+
self._optimizer = FusedLAMB(params, **self.optimizer_config)
except ImportError:
- raise ImportError('Please install apex to use LAMB optimizer')
+ raise ImportError("Please install apex to use LAMB optimizer")
@staticmethod
def add_args(parser):
@@ -39,10 +40,10 @@ def optimizer_config(self):
different learning rate.
"""
return {
- 'lr': self.args.lr[0],
- 'betas': eval(self.args.lamb_betas),
- 'eps': self.args.lamb_eps,
- 'weight_decay': self.args.weight_decay,
+ "lr": self.args.lr[0],
+ "betas": eval(self.args.lamb_betas),
+ "eps": self.args.lamb_eps,
+ "weight_decay": self.args.weight_decay,
}
@property
diff --git a/fairseq/optim/lr_scheduler/fixed_schedule.py b/fairseq/optim/lr_scheduler/fixed_schedule.py
index 9a30195fab..7ca7826ed2 100644
--- a/fairseq/optim/lr_scheduler/fixed_schedule.py
+++ b/fairseq/optim/lr_scheduler/fixed_schedule.py
@@ -3,10 +3,10 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from . import register_lr_scheduler, LegacyFairseqLRScheduler
+from . import LegacyFairseqLRScheduler, register_lr_scheduler
-@register_lr_scheduler('fixed')
+@register_lr_scheduler("fixed")
class FixedSchedule(LegacyFairseqLRScheduler):
"""Decay the LR on a fixed schedule."""
@@ -14,11 +14,11 @@ def __init__(self, args, optimizer):
super().__init__(args, optimizer)
# set defaults
- args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0
+ args.warmup_updates = getattr(args, "warmup_updates", 0) or 0
self.lr = args.lr[0]
if args.warmup_updates > 0:
- self.warmup_factor = 1. / args.warmup_updates
+ self.warmup_factor = 1.0 / args.warmup_updates
else:
self.warmup_factor = 1
@@ -35,11 +35,11 @@ def add_args(parser):
# fmt: on
def state_dict(self):
- return {'lr': self.lr}
+ return {"lr": self.lr}
def load_state_dict(self, state_dict):
- if 'lr' in state_dict:
- self.lr = state_dict['lr']
+ if "lr" in state_dict:
+ self.lr = state_dict["lr"]
def get_next_lr(self, epoch):
lrs = self.args.lr
@@ -48,7 +48,9 @@ def get_next_lr(self, epoch):
next_lr = lrs[min(epoch, len(lrs) - 1)]
else:
# annneal based on lr_shrink
- next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal)
+ next_lr = lrs[-1] * self.args.lr_shrink ** (
+ epoch + 1 - self.args.force_anneal
+ )
return next_lr
def step(self, epoch, val_loss=None):
diff --git a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py
index 73e8b170bc..ea8e647668 100644
--- a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py
+++ b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py
@@ -3,10 +3,10 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from . import register_lr_scheduler, LegacyFairseqLRScheduler
+from . import LegacyFairseqLRScheduler, register_lr_scheduler
-@register_lr_scheduler('polynomial_decay')
+@register_lr_scheduler("polynomial_decay")
class PolynomialDecaySchedule(LegacyFairseqLRScheduler):
"""Decay the LR on a fixed schedule."""
@@ -14,11 +14,11 @@ def __init__(self, args, optimizer):
super().__init__(args, optimizer)
# set defaults
- args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0
+ args.warmup_updates = getattr(args, "warmup_updates", 0) or 0
self.lr = args.lr[0]
if args.warmup_updates > 0:
- self.warmup_factor = 1. / args.warmup_updates
+ self.warmup_factor = 1.0 / args.warmup_updates
else:
self.warmup_factor = 1
self.end_learning_rate = args.end_learning_rate
@@ -29,13 +29,23 @@ def __init__(self, args, optimizer):
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
- parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
- help='force annealing at specified epoch')
- parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
- help='warmup the learning rate linearly for the first N updates')
- parser.add_argument('--end-learning-rate', default=0.0, type=float)
- parser.add_argument('--power', default=1.0, type=float)
- parser.add_argument('--total-num-update', default=1000000, type=int)
+ parser.add_argument(
+ "--force-anneal",
+ "--fa",
+ type=int,
+ metavar="N",
+ help="force annealing at specified epoch",
+ )
+ parser.add_argument(
+ "--warmup-updates",
+ default=0,
+ type=int,
+ metavar="N",
+ help="warmup the learning rate linearly for the first N updates",
+ )
+ parser.add_argument("--end-learning-rate", default=0.0, type=float)
+ parser.add_argument("--power", default=1.0, type=float)
+ parser.add_argument("--total-num-update", default=1000000, type=int)
def get_next_lr(self, epoch):
lrs = self.args.lr
@@ -64,7 +74,9 @@ def step_update(self, num_updates):
else:
warmup = self.args.warmup_updates
lr_range = self.lr - self.end_learning_rate
- pct_remaining = 1 - (num_updates - warmup) / (self.total_num_update - warmup)
+ pct_remaining = 1 - (num_updates - warmup) / (
+ self.total_num_update - warmup
+ )
lr = lr_range * pct_remaining ** (self.power) + self.end_learning_rate
self.optimizer.set_lr(lr)
return self.optimizer.get_lr()
diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py
index 5199b09a3e..82bb36efe9 100644
--- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py
+++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py
@@ -5,10 +5,10 @@
import torch.optim.lr_scheduler
-from . import register_lr_scheduler, LegacyFairseqLRScheduler
+from . import LegacyFairseqLRScheduler, register_lr_scheduler
-@register_lr_scheduler('reduce_lr_on_plateau')
+@register_lr_scheduler("reduce_lr_on_plateau")
class ReduceLROnPlateau(LegacyFairseqLRScheduler):
"""
Decay the LR by a factor every time the validation loss plateaus.
@@ -30,13 +30,16 @@ def __init__(self, args, optimizer):
super().__init__(args, optimizer)
if len(args.lr) > 1:
raise ValueError(
- 'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.'
- ' Consider --lr-scheduler=fixed instead.'
+ "Cannot use a fixed learning rate schedule with reduce_lr_on_plateau."
+ " Consider --lr-scheduler=fixed instead."
)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
- self.optimizer.optimizer, patience=args.lr_patience, factor=args.lr_shrink,
- mode='max' if args.maximize_best_checkpoint_metric else 'min',
- threshold=args.lr_threshold)
+ self.optimizer.optimizer,
+ patience=args.lr_patience,
+ factor=args.lr_shrink,
+ mode="max" if args.maximize_best_checkpoint_metric else "min",
+ threshold=args.lr_threshold,
+ )
warmup_end_lr = args.lr[0]
# if no warm up, sets initial lr to be args.lr[0]
if args.warmup_init_lr < 0:
@@ -76,15 +79,15 @@ def add_args(parser):
def state_dict(self):
"""Return the LR scheduler state dict."""
return {
- 'best': self.lr_scheduler.best,
- 'last_epoch': self.lr_scheduler.last_epoch,
+ "best": self.lr_scheduler.best,
+ "last_epoch": self.lr_scheduler.last_epoch,
}
def load_state_dict(self, state_dict):
"""Load an LR scheduler state dict."""
- self.lr_scheduler.best = state_dict['best']
- if 'last_epoch' in state_dict:
- self.lr_scheduler.last_epoch = state_dict['last_epoch']
+ self.lr_scheduler.best = state_dict["best"]
+ if "last_epoch" in state_dict:
+ self.lr_scheduler.last_epoch = state_dict["last_epoch"]
def step(self, epoch, val_loss=None):
"""
@@ -103,7 +106,7 @@ def step_update(self, num_updates):
# if there is warmup
if self.args.warmup_updates > 0:
if num_updates <= self.args.warmup_updates:
- self.lr = self.args.warmup_init_lr + num_updates*self.lr_step
+ self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
self.optimizer.set_lr(self.lr)
else:
if self.warmup_end is False:
diff --git a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py
index 95c5576f20..c573237f11 100644
--- a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py
+++ b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py
@@ -3,11 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from . import register_lr_scheduler, LegacyFairseqLRScheduler
import math
+from . import LegacyFairseqLRScheduler, register_lr_scheduler
-@register_lr_scheduler('tri_stage')
+
+@register_lr_scheduler("tri_stage")
class TriStageLRSchedule(LegacyFairseqLRScheduler):
"""Tristage learning rate schedulr
@@ -50,8 +51,8 @@ def __init__(self, args, optimizer):
super().__init__(args, optimizer)
if len(args.lr) > 1:
raise ValueError(
- 'Cannot use a fixed learning rate schedule with tri-stage lr.'
- ' Consider --lr-scheduler=fixed instead.'
+ "Cannot use a fixed learning rate schedule with tri-stage lr."
+ " Consider --lr-scheduler=fixed instead."
)
# calculate LR at each point
@@ -65,7 +66,8 @@ def __init__(self, args, optimizer):
self.decay_steps = args.decay_steps
self.warmup_rate = (
- (self.peak_lr - self.init_lr) / self.warmup_steps if self.warmup_steps != 0
+ (self.peak_lr - self.init_lr) / self.warmup_steps
+ if self.warmup_steps != 0
else 0
)
self.decay_factor = -math.log(args.final_lr_scale) / args.decay_steps
diff --git a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py
index 67e1df65e1..0f3193f2b8 100644
--- a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py
+++ b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py
@@ -5,10 +5,10 @@
import math
-from . import register_lr_scheduler, LegacyFairseqLRScheduler
+from . import LegacyFairseqLRScheduler, register_lr_scheduler
-@register_lr_scheduler('triangular')
+@register_lr_scheduler("triangular")
class TriangularSchedule(LegacyFairseqLRScheduler):
"""Assign LR based on a triangular cyclical schedule.
@@ -19,13 +19,13 @@ def __init__(self, args, optimizer):
super().__init__(args, optimizer)
if len(args.lr) > 1:
raise ValueError(
- 'Cannot use a fixed learning rate schedule with triangular.'
- ' Consider --lr-scheduler=fixed instead.'
+ "Cannot use a fixed learning rate schedule with triangular."
+ " Consider --lr-scheduler=fixed instead."
)
lr = args.lr[0]
- assert args.max_lr > lr, 'max_lr must be more than lr'
+ assert args.max_lr > lr, "max_lr must be more than lr"
self.min_lr = lr
self.max_lr = args.max_lr
self.stepsize = args.lr_period_updates // 2
diff --git a/fairseq/optim/sgd.py b/fairseq/optim/sgd.py
index b558f41ab0..8e34fb99a1 100644
--- a/fairseq/optim/sgd.py
+++ b/fairseq/optim/sgd.py
@@ -5,10 +5,10 @@
import torch.optim
-from . import register_optimizer, LegacyFairseqOptimizer
+from . import LegacyFairseqOptimizer, register_optimizer
-@register_optimizer('sgd')
+@register_optimizer("sgd")
class SGD(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
@@ -33,9 +33,9 @@ def optimizer_config(self):
different learning rate.
"""
return {
- 'lr': self.args.lr[0],
- 'momentum': self.args.momentum,
- 'weight_decay': self.args.weight_decay,
+ "lr": self.args.lr[0],
+ "momentum": self.args.momentum,
+ "weight_decay": self.args.weight_decay,
}
@property
diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py
index 8c508f41f2..a035a1c1f9 100644
--- a/fairseq/optim/shard.py
+++ b/fairseq/optim/shard.py
@@ -6,6 +6,7 @@
try:
from fairscale.optim import OSS
+
_has_fairscale = True
except ImportError:
_has_fairscale = False
@@ -14,8 +15,7 @@
def shard_(args, optimizer, group):
if not _has_fairscale:
raise ImportError(
- '\n\nPlease install the fairscale package:'
- '\n\n pip install fairscale'
+ "\n\nPlease install the fairscale package:" "\n\n pip install fairscale"
)
class FairseqOSS(OSS):
@@ -26,9 +26,16 @@ def disable_mem_eff_fp16_loading_hack(self):
def __getattr__(self, name):
if name.startswith("supports") and hasattr(self.optim, name):
return getattr(self.optim, name)
- raise AttributeError("'FairseqOSS' object has no attribute {0!r}".format(name))
+ raise AttributeError(
+ "'FairseqOSS' object has no attribute {0!r}".format(name)
+ )
torch_optimizer = optimizer.optimizer
optim_cls = type(torch_optimizer)
-
- optimizer.optimizer = FairseqOSS(torch_optimizer.param_groups, optim_cls, group=group, **optimizer.optimizer_config)
+
+ optimizer.optimizer = FairseqOSS(
+ torch_optimizer.param_groups,
+ optim_cls,
+ group=group,
+ **optimizer.optimizer_config
+ )
diff --git a/fairseq/options.py b/fairseq/options.py
index 31ed28a80e..1a24fccaec 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -168,7 +168,9 @@ def parse_args_and_arch(
args = parser.parse_args(input_args)
extra = None
# Post-process args.
- if (hasattr(args, "batch_size_valid") and args.batch_size_valid is None) or not hasattr(args, "batch_size_valid"):
+ if (
+ hasattr(args, "batch_size_valid") and args.batch_size_valid is None
+ ) or not hasattr(args, "batch_size_valid"):
args.batch_size_valid = args.batch_size
if hasattr(args, "max_tokens_valid") and args.max_tokens_valid is None:
args.max_tokens_valid = args.max_tokens
diff --git a/fairseq/pdb.py b/fairseq/pdb.py
index f1ce3c46bc..1ba6ef0d33 100644
--- a/fairseq/pdb.py
+++ b/fairseq/pdb.py
@@ -9,7 +9,7 @@
import sys
-__all__ = ['set_trace']
+__all__ = ["set_trace"]
_stdin = [None]
diff --git a/fairseq/quantization_utils.py b/fairseq/quantization_utils.py
index a7f5ade9b3..69dd61d785 100644
--- a/fairseq/quantization_utils.py
+++ b/fairseq/quantization_utils.py
@@ -12,7 +12,7 @@
def quantize_model_scalar(model, args):
- quant_noise_scalar = getattr(args, 'quant_noise_scalar', 0)
+ quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
if quant_noise_scalar > 0:
# quantize_model edits the model in place
scalar.quantize_model_(model, p=quant_noise_scalar, bits=8, update_step=1000)
@@ -20,12 +20,11 @@ def quantize_model_scalar(model, args):
class Quantizer(object):
-
def __init__(self, config_path, max_epoch, max_update):
try:
import yaml
except ImportError:
- raise ImportError('Please install yaml with: pip install yaml')
+ raise ImportError("Please install yaml with: pip install yaml")
# parse config
if config_path:
@@ -46,22 +45,23 @@ def __init__(self, config_path, max_epoch, max_update):
num_iterations = len(self.layers_to_quantize)
if max_epoch > 0:
assert max_epoch % num_iterations == 0, (
- 'for iterative PQ, --max-epoch (={}) must be evenly divisible by '
- 'len(layers_to_quantize) (={})'.format(max_epoch, num_iterations)
+ "for iterative PQ, --max-epoch (={}) must be evenly divisible by "
+ "len(layers_to_quantize) (={})".format(max_epoch, num_iterations)
)
self.epoch_schedule = max_epoch // num_iterations
else:
self.epoch_schedule = None
if max_update > 0:
assert max_update % num_iterations == 0, (
- 'for iterative PQ, --max-update (={}) must be evenly divisible by '
- 'len(layers_to_quantize) (={})'.format(max_update, num_iterations)
+ "for iterative PQ, --max-update (={}) must be evenly divisible by "
+ "len(layers_to_quantize) (={})".format(max_update, num_iterations)
)
self.update_schedule = max_update // num_iterations
else:
self.update_schedule = None
- assert (self.epoch_schedule is not None) ^ (self.update_schedule is not None), \
- 'for iterative PQ, cannot specify both --max-update and --max-epoch'
+ assert (self.epoch_schedule is not None) ^ (
+ self.update_schedule is not None
+ ), "for iterative PQ, cannot specify both --max-update and --max-epoch"
# 0 is a special value for quantization step, which will force
# the first call to begin_epoch() to call step()
@@ -80,7 +80,7 @@ def step(self):
return
logger.info(
- 'quantizing model (step={}; layers_to_quantize[step]={})'.format(
+ "quantizing model (step={}; layers_to_quantize[step]={})".format(
self.quantization_step, self.layers_to_quantize[self.quantization_step]
)
)
@@ -92,7 +92,7 @@ def step(self):
self.n_centroids_config,
step=self.quantization_step,
)
- logger.info('quantized layers: {}'.format(quantized_layers))
+ logger.info("quantized layers: {}".format(quantized_layers))
logger.info(self.size_tracker)
self.quantization_step += 1
@@ -125,18 +125,18 @@ def step_update(self, num_updates):
def state_dict(self):
return {
- 'n_centroids_config': self.n_centroids_config,
- 'block_sizes_config': self.block_sizes_config,
- 'layers_to_quantize': self.layers_to_quantize,
- 'epoch_schedule': self.epoch_schedule,
- 'update_schedule': self.update_schedule,
- 'quantization_step': self.quantization_step,
+ "n_centroids_config": self.n_centroids_config,
+ "block_sizes_config": self.block_sizes_config,
+ "layers_to_quantize": self.layers_to_quantize,
+ "epoch_schedule": self.epoch_schedule,
+ "update_schedule": self.update_schedule,
+ "quantization_step": self.quantization_step,
}
def load_state_dict(self, state_dict):
- self.n_centroids_config = state_dict['n_centroids_config']
- self.block_sizes_config = state_dict['block_sizes_config']
- self.layers_to_quantize = state_dict['layers_to_quantize']
- self.epoch_schedule = state_dict['epoch_schedule']
- self.update_schedule = state_dict['update_schedule']
- self.quantization_step = state_dict['quantization_step']
+ self.n_centroids_config = state_dict["n_centroids_config"]
+ self.block_sizes_config = state_dict["block_sizes_config"]
+ self.layers_to_quantize = state_dict["layers_to_quantize"]
+ self.epoch_schedule = state_dict["epoch_schedule"]
+ self.update_schedule = state_dict["update_schedule"]
+ self.quantization_step = state_dict["quantization_step"]
diff --git a/fairseq/scoring/__init__.py b/fairseq/scoring/__init__.py
index 4468f2ad21..4be0cb5188 100644
--- a/fairseq/scoring/__init__.py
+++ b/fairseq/scoring/__init__.py
@@ -49,6 +49,7 @@ def build_scorer(args, tgt_dict):
args.scoring = "sacrebleu"
if args.scoring == "bleu":
from fairseq.scoring import bleu
+
return bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
return _build_scorer(args)
diff --git a/fairseq/scoring/bleu.py b/fairseq/scoring/bleu.py
index a45d44b003..7f8bd73bf5 100644
--- a/fairseq/scoring/bleu.py
+++ b/fairseq/scoring/bleu.py
@@ -8,7 +8,6 @@
import sys
import torch
-
from fairseq.scoring import BaseScorer, register_scorer
from fairseq.scoring.tokenizer import EvaluationTokenizer
@@ -33,11 +32,12 @@ class SacrebleuScorer(BaseScorer):
def __init__(self, args):
super(SacrebleuScorer, self).__init__(args)
import sacrebleu
+
self.sacrebleu = sacrebleu
self.tokenizer = EvaluationTokenizer(
tokenizer_type=self.args.sacrebleu_tokenizer,
lowercase=self.args.sacrebleu_lowercase,
- character_tokenization=self.args.sacrebleu_char_level
+ character_tokenization=self.args.sacrebleu_char_level,
)
@staticmethod
@@ -63,8 +63,9 @@ def result_string(self, order=4):
if order != 4:
raise NotImplementedError
# tokenization and lowercasing are performed by self.tokenizer instead.
- return self.sacrebleu.corpus_bleu(self.pred, [self.ref],
- tokenize='none').format()
+ return self.sacrebleu.corpus_bleu(
+ self.pred, [self.ref], tokenize="none"
+ ).format()
@register_scorer("bleu")
@@ -78,7 +79,9 @@ def __init__(self, pad, eos, unk):
try:
from fairseq import libbleu
except ImportError as e:
- sys.stderr.write("ERROR: missing libbleu.so. run `pip install --editable .`\n")
+ sys.stderr.write(
+ "ERROR: missing libbleu.so. run `pip install --editable .`\n"
+ )
raise e
self.C = ctypes.cdll.LoadLibrary(libbleu.__file__)
diff --git a/fairseq/scoring/chrf.py b/fairseq/scoring/chrf.py
index b932a43604..0d6cb77383 100644
--- a/fairseq/scoring/chrf.py
+++ b/fairseq/scoring/chrf.py
@@ -6,11 +6,12 @@
from fairseq.scoring import BaseScorer, register_scorer
-@register_scorer('chrf')
+@register_scorer("chrf")
class ChrFScorer(BaseScorer):
def __init__(self, args):
super(ChrFScorer, self).__init__(args)
import sacrebleu
+
self.sacrebleu = sacrebleu
def add_string(self, ref, pred):
diff --git a/fairseq/scoring/tokenizer.py b/fairseq/scoring/tokenizer.py
index c9d5218e1e..dbcc6e4d10 100644
--- a/fairseq/scoring/tokenizer.py
+++ b/fairseq/scoring/tokenizer.py
@@ -19,13 +19,18 @@ class EvaluationTokenizer(object):
category) from text.
character_tokenization (bool): tokenize the text to characters.
"""
+
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)
- ALL_TOKENIZER_TYPES = ['none', '13a', 'intl', 'zh', 'ja-mecab']
+ ALL_TOKENIZER_TYPES = ["none", "13a", "intl", "zh", "ja-mecab"]
- def __init__(self, tokenizer_type: str = '13a', lowercase: bool = False,
- punctuation_removal: bool = False,
- character_tokenization: bool = False):
+ def __init__(
+ self,
+ tokenizer_type: str = "13a",
+ lowercase: bool = False,
+ punctuation_removal: bool = False,
+ character_tokenization: bool = False,
+ ):
from sacrebleu.tokenizers import TOKENIZERS
assert tokenizer_type in self.ALL_TOKENIZER_TYPES
@@ -38,8 +43,9 @@ def __init__(self, tokenizer_type: str = '13a', lowercase: bool = False,
def remove_punctuation(cls, sent: str):
"""Remove punctuation based on Unicode category."""
return cls.SPACE.join(
- t for t in sent.split(cls.SPACE)
- if not all(unicodedata.category(c)[0] == 'P' for c in t)
+ t
+ for t in sent.split(cls.SPACE)
+ if not all(unicodedata.category(c)[0] == "P" for c in t)
)
def tokenize(self, sent: str):
diff --git a/fairseq/scoring/wer.py b/fairseq/scoring/wer.py
index 61c5fd950e..21efefd9b8 100644
--- a/fairseq/scoring/wer.py
+++ b/fairseq/scoring/wer.py
@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from fairseq.scoring import register_scorer, BaseScorer
+from fairseq.scoring import BaseScorer, register_scorer
from fairseq.scoring.tokenizer import EvaluationTokenizer
@@ -15,7 +15,7 @@ def __init__(self, args):
try:
import editdistance as ed
except ImportError:
- raise ImportError('Please install editdistance to use WER scorer')
+ raise ImportError("Please install editdistance to use WER scorer")
self.ed = ed
self.tokenizer = EvaluationTokenizer(
tokenizer_type=self.args.wer_tokenizer,
@@ -52,7 +52,4 @@ def result_string(self):
return f"WER: {self.score():.2f}"
def score(self):
- return (
- 100.0 * self.distance / self.ref_length if self.ref_length > 0
- else 0
- )
+ return 100.0 * self.distance / self.ref_length if self.ref_length > 0 else 0
diff --git a/fairseq/search.py b/fairseq/search.py
index 2c21b66bbd..d5ea68b4ce 100644
--- a/fairseq/search.py
+++ b/fairseq/search.py
@@ -4,17 +4,16 @@
# LICENSE file in the root directory of this source tree.
import math
-from typing import Optional, List
+from typing import List, Optional
import torch
import torch.nn as nn
-from torch import Tensor
-
from fairseq.token_generation_constraints import (
ConstraintState,
- UnorderedConstraintState,
OrderedConstraintState,
+ UnorderedConstraintState,
)
+from torch import Tensor
class Search(nn.Module):
diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py
index 7ce797746f..ddfb67853f 100644
--- a/fairseq/sequence_generator.py
+++ b/fairseq/sequence_generator.py
@@ -34,7 +34,7 @@ def __init__(
eos=None,
symbols_to_strip_from_output=None,
lm_model=None,
- lm_weight=1.0
+ lm_weight=1.0,
):
"""Generates translations of a given source sentence.
@@ -69,7 +69,9 @@ def __init__(
self.eos = tgt_dict.eos() if eos is None else eos
self.symbols_to_strip_from_output = (
symbols_to_strip_from_output.union({self.eos})
- if symbols_to_strip_from_output is not None else {self.eos})
+ if symbols_to_strip_from_output is not None
+ else {self.eos}
+ )
self.vocab_size = len(tgt_dict)
self.beam_size = beam_size
# the max beam size is the dictionary size - 1, since we never select pad
@@ -92,7 +94,9 @@ def __init__(
# We only need to set src_lengths in LengthConstrainedBeamSearch.
# As a module attribute, setting it would break in multithread
# settings when the model is shared.
- self.should_set_src_lengths = hasattr(self.search, 'needs_src_lengths') and self.search.needs_src_lengths
+ self.should_set_src_lengths = (
+ hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
+ )
self.model.eval()
@@ -188,19 +192,21 @@ def _generate(
)
net_input = sample["net_input"]
- if 'src_tokens' in net_input:
- src_tokens = net_input['src_tokens']
+ if "src_tokens" in net_input:
+ src_tokens = net_input["src_tokens"]
# length of the source text being the character length except EndOfSentence and pad
- src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
- elif 'source' in net_input:
- src_tokens = net_input['source']
src_lengths = (
- net_input['padding_mask'].size(-1) - net_input['padding_mask'].sum(-1)
- if net_input['padding_mask'] is not None
+ (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
+ )
+ elif "source" in net_input:
+ src_tokens = net_input["source"]
+ src_lengths = (
+ net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
+ if net_input["padding_mask"] is not None
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
)
else:
- raise Exception('expected src_tokens or source in net input')
+ raise Exception("expected src_tokens or source in net input")
# bsz: total number of sentences in beam
# Note that src_tokens may have more than 2 dimenions (i.e. audio features)
@@ -208,7 +214,9 @@ def _generate(
beam_size = self.beam_size
if constraints is not None and not self.search.supports_constraints:
- raise NotImplementedError("Target-side constraints were provided, but search method doesn't support them")
+ raise NotImplementedError(
+ "Target-side constraints were provided, but search method doesn't support them"
+ )
# Initialize constraints, when active
self.search.init_constraints(constraints, beam_size)
@@ -421,10 +429,14 @@ def _generate(
new_bsz = bsz - len(finalized_sents)
# construct batch_idxs which holds indices of batches to keep for the next pass
- batch_mask = torch.ones(bsz, dtype=torch.bool, device=cand_indices.device)
+ batch_mask = torch.ones(
+ bsz, dtype=torch.bool, device=cand_indices.device
+ )
batch_mask[finalized_sents] = False
# TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
- batch_idxs = torch.arange(bsz, device=cand_indices.device).masked_select(batch_mask)
+ batch_idxs = torch.arange(
+ bsz, device=cand_indices.device
+ ).masked_select(batch_mask)
# Choose the subset of the hypothesized constraints that will continue
self.search.prune_sentences(batch_idxs)
@@ -519,10 +531,14 @@ def _generate(
# sort by score descending
for sent in range(len(finalized)):
- scores = torch.tensor([float(elem["score"].item()) for elem in finalized[sent]])
+ scores = torch.tensor(
+ [float(elem["score"].item()) for elem in finalized[sent]]
+ )
_, sorted_scores_indices = torch.sort(scores, descending=True)
finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
- finalized[sent] = torch.jit.annotate(List[Dict[str, Tensor]], finalized[sent])
+ finalized[sent] = torch.jit.annotate(
+ List[Dict[str, Tensor]], finalized[sent]
+ )
return finalized
def _prefix_tokens(
@@ -787,10 +803,7 @@ def max_decoder_positions(self):
def forward_encoder(self, net_input: Dict[str, Tensor]):
if not self.has_encoder():
return None
- return [
- model.encoder.forward_torchscript(net_input)
- for model in self.models
- ]
+ return [model.encoder.forward_torchscript(net_input) for model in self.models]
@torch.jit.export
def forward_decoder(
@@ -915,9 +928,12 @@ def generate(self, models, sample, **kwargs):
src_tokens = sample["net_input"]["src_tokens"]
bsz = src_tokens.shape[0]
beam_size = self.beam_size
- src_tokens, src_lengths, prev_output_tokens, tgt_tokens = self._prepare_batch_for_alignment(
- sample, finalized
- )
+ (
+ src_tokens,
+ src_lengths,
+ prev_output_tokens,
+ tgt_tokens,
+ ) = self._prepare_batch_for_alignment(sample, finalized)
if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
else:
@@ -927,9 +943,9 @@ def generate(self, models, sample, **kwargs):
]
if src_tokens.device != "cpu":
- src_tokens = src_tokens.to('cpu')
- tgt_tokens = tgt_tokens.to('cpu')
- attn = [i.to('cpu') for i in attn]
+ src_tokens = src_tokens.to("cpu")
+ tgt_tokens = tgt_tokens.to("cpu")
+ attn = [i.to("cpu") for i in attn]
# Process the attn matrix to extract hard alignments.
for i in range(bsz * beam_size):
diff --git a/fairseq/sequence_scorer.py b/fairseq/sequence_scorer.py
index c8ded1930c..411d4df444 100644
--- a/fairseq/sequence_scorer.py
+++ b/fairseq/sequence_scorer.py
@@ -3,9 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import torch
import sys
+import torch
from fairseq import utils
@@ -13,7 +13,11 @@ class SequenceScorer(object):
"""Scores the target for a given source sentence."""
def __init__(
- self, tgt_dict, softmax_batch=None, compute_alignment=False, eos=None,
+ self,
+ tgt_dict,
+ softmax_batch=None,
+ compute_alignment=False,
+ eos=None,
symbols_to_strip_from_output=None,
):
self.pad = tgt_dict.pad()
@@ -23,12 +27,14 @@ def __init__(
self.compute_alignment = compute_alignment
self.symbols_to_strip_from_output = (
symbols_to_strip_from_output.union({self.eos})
- if symbols_to_strip_from_output is not None else {self.eos})
+ if symbols_to_strip_from_output is not None
+ else {self.eos}
+ )
@torch.no_grad()
def generate(self, models, sample, **kwargs):
"""Score a batch of translations."""
- net_input = sample['net_input']
+ net_input = sample["net_input"]
def batch_for_softmax(dec_out, target):
# assumes decoder_out[0] is the only thing needed (may not be correct for future models!)
@@ -52,7 +58,7 @@ def gather_target_probs(probs, target):
)
return probs
- orig_target = sample['target']
+ orig_target = sample["target"]
# compute scores for each model in the ensemble
avg_probs = None
@@ -62,13 +68,15 @@ def gather_target_probs(probs, target):
decoder_out = model(**net_input)
attn = decoder_out[1] if len(decoder_out) > 1 else None
if type(attn) is dict:
- attn = attn.get('attn', None)
+ attn = attn.get("attn", None)
batched = batch_for_softmax(decoder_out, orig_target)
probs, idx = None, 0
for bd, tgt, is_single in batched:
- sample['target'] = tgt
- curr_prob = model.get_normalized_probs(bd, log_probs=len(models) == 1, sample=sample).data
+ sample["target"] = tgt
+ curr_prob = model.get_normalized_probs(
+ bd, log_probs=len(models) == 1, sample=sample
+ ).data
if is_single:
probs = gather_target_probs(curr_prob, orig_target)
else:
@@ -76,12 +84,14 @@ def gather_target_probs(probs, target):
probs = curr_prob.new(orig_target.numel())
step = curr_prob.size(0) * curr_prob.size(1)
end = step + idx
- tgt_probs = gather_target_probs(curr_prob.view(tgt.shape + (curr_prob.size(-1),)), tgt)
+ tgt_probs = gather_target_probs(
+ curr_prob.view(tgt.shape + (curr_prob.size(-1),)), tgt
+ )
probs[idx:end] = tgt_probs.view(-1)
idx = end
- sample['target'] = orig_target
+ sample["target"] = orig_target
- probs = probs.view(sample['target'].shape)
+ probs = probs.view(sample["target"].shape)
if avg_probs is None:
avg_probs = probs
@@ -104,21 +114,24 @@ def gather_target_probs(probs, target):
bsz = avg_probs.size(0)
hypos = []
- start_idxs = sample['start_indices'] if 'start_indices' in sample else [0] * bsz
+ start_idxs = sample["start_indices"] if "start_indices" in sample else [0] * bsz
for i in range(bsz):
# remove padding from ref
- ref = utils.strip_pad(sample['target'][i, start_idxs[i]:], self.pad) \
- if sample['target'] is not None else None
+ ref = (
+ utils.strip_pad(sample["target"][i, start_idxs[i] :], self.pad)
+ if sample["target"] is not None
+ else None
+ )
tgt_len = ref.numel()
- avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len]
+ avg_probs_i = avg_probs[i][start_idxs[i] : start_idxs[i] + tgt_len]
score_i = avg_probs_i.sum() / tgt_len
if avg_attn is not None:
avg_attn_i = avg_attn[i]
if self.compute_alignment:
alignment = utils.extract_hard_alignment(
avg_attn_i,
- sample['net_input']['src_tokens'][i],
- sample['target'][i],
+ sample["net_input"]["src_tokens"][i],
+ sample["target"][i],
self.pad,
self.eos,
)
@@ -126,11 +139,15 @@ def gather_target_probs(probs, target):
alignment = None
else:
avg_attn_i = alignment = None
- hypos.append([{
- 'tokens': ref,
- 'score': score_i,
- 'attention': avg_attn_i,
- 'alignment': alignment,
- 'positional_scores': avg_probs_i,
- }])
+ hypos.append(
+ [
+ {
+ "tokens": ref,
+ "score": score_i,
+ "attention": avg_attn_i,
+ "alignment": alignment,
+ "positional_scores": avg_probs_i,
+ }
+ ]
+ )
return hypos
diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py
index 75bcfaa8db..ff2342afa9 100644
--- a/fairseq/tasks/audio_pretraining.py
+++ b/fairseq/tasks/audio_pretraining.py
@@ -8,7 +8,8 @@
import os
import sys
-from fairseq.data import FileAudioDataset, Dictionary, AddTargetDataset
+from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset
+
from . import LegacyFairseqTask, register_task
@@ -24,9 +25,7 @@ def __call__(self, label):
@register_task("audio_pretraining")
class AudioPretrainingTask(LegacyFairseqTask):
- """
-
- """
+ """"""
@staticmethod
def add_args(parser):
@@ -137,11 +136,11 @@ def max_positions(self):
return (sys.maxsize, sys.maxsize)
def filter_indices_by_size(
- self,
- indices,
- dataset,
- max_positions=None,
- ignore_invalid_inputs=False,
+ self,
+ indices,
+ dataset,
+ max_positions=None,
+ ignore_invalid_inputs=False,
):
# we do not need to filter by size in this task as dataloaders take care of this
return indices
diff --git a/fairseq/tasks/cross_lingual_lm.py b/fairseq/tasks/cross_lingual_lm.py
index a7ce1f1ad5..8f8fe7e2de 100644
--- a/fairseq/tasks/cross_lingual_lm.py
+++ b/fairseq/tasks/cross_lingual_lm.py
@@ -3,31 +3,24 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from collections import OrderedDict
import itertools
import logging
import os
+from collections import OrderedDict
import numpy as np
-
-from fairseq import tokenizer
-from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
-
-from fairseq.data import (
- Dictionary,
- ConcatDataset,
- data_utils,
- TokenBlockDataset,
-)
+from fairseq import tokenizer, utils
+from fairseq.data import ConcatDataset, Dictionary, TokenBlockDataset, data_utils
from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
+from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset
-from fairseq.tasks import register_task, LegacyFairseqTask
-from fairseq import utils
+from fairseq.tasks import LegacyFairseqTask, register_task
+
logger = logging.getLogger(__name__)
-@register_task('cross_lingual_lm')
+@register_task("cross_lingual_lm")
class CrossLingualLMTask(LegacyFairseqTask):
"""
Task for training cross-lingual language models.
@@ -41,17 +34,29 @@ class CrossLingualLMTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
- parser.add_argument('data', help='colon separated path to data directories list, \
- will be iterated upon during epochs in round-robin manner')
- parser.add_argument('--tokens-per-sample', default=512, type=int,
- help='max number of total tokens over all segments'
- ' per sample')
- parser.add_argument('--monolingual-langs', default='en', type=str,
- help='comma separated list of languages for which we'
- ' want to train XLM on')
- parser.add_argument('--shuffle', action='store_true',
- help='shuffle each monolingual dataset while'
- ' training')
+ parser.add_argument(
+ "data",
+ help="colon separated path to data directories list, \
+ will be iterated upon during epochs in round-robin manner",
+ )
+ parser.add_argument(
+ "--tokens-per-sample",
+ default=512,
+ type=int,
+ help="max number of total tokens over all segments" " per sample",
+ )
+ parser.add_argument(
+ "--monolingual-langs",
+ default="en",
+ type=str,
+ help="comma separated list of languages for which we"
+ " want to train XLM on",
+ )
+ parser.add_argument(
+ "--shuffle",
+ action="store_true",
+ help="shuffle each monolingual dataset while" " training",
+ )
def __init__(self, args, dictionary):
super().__init__(args)
@@ -60,16 +65,13 @@ def __init__(self, args, dictionary):
self.distributed_world_size = args.distributed_world_size
self.langs2id = self._lang_to_id(args.monolingual_langs)
- def _lang_to_id(
- self,
- languages: str
- ):
+ def _lang_to_id(self, languages: str):
"""
Build a map from languages to ids. These ids are used as segment labels
for cross-lingual LM training.
"""
lang2id = {}
- langs = [l.strip() for l in languages.split(',')]
+ langs = [l.strip() for l in languages.split(",")]
for id, lang in enumerate(langs):
lang2id[lang] = id
return lang2id
@@ -79,10 +81,14 @@ def load_dictionary(cls, filename):
return MaskedLMDictionary.load(filename)
@classmethod
- def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8):
+ def build_dictionary(
+ cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
+ ):
d = MaskedLMDictionary()
for filename in filenames:
- Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
+ Dictionary.add_file_to_dictionary(
+ filename, d, tokenizer.tokenize_line, workers
+ )
d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
return d
@@ -93,8 +99,8 @@ def target_dictionary(self):
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task."""
- dictionary = MaskedLMDictionary.load(os.path.join(args.data, 'dict.txt'))
- logger.info('dictionary: {} types'.format(len(dictionary)))
+ dictionary = MaskedLMDictionary.load(os.path.join(args.data, "dict.txt"))
+ logger.info("dictionary: {} types".format(len(dictionary)))
return cls(args, dictionary)
def _load_single_lang_dataset(self, split, epoch):
@@ -105,27 +111,36 @@ def _load_single_lang_dataset(self, split, epoch):
data_path = paths[(epoch - 1) % len(paths)]
for k in itertools.count():
- split_k = split + (str(k) if k > 0 else '')
+ split_k = split + (str(k) if k > 0 else "")
path = os.path.join(data_path, split_k)
- ds = data_utils.load_indexed_dataset(path, self.dictionary, self.args.dataset_impl)
+ ds = data_utils.load_indexed_dataset(
+ path, self.dictionary, self.args.dataset_impl
+ )
if ds is None:
if k > 0:
break
else:
- raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
+ raise FileNotFoundError(
+ "Dataset not found: {} ({})".format(split, data_path)
+ )
# Since we append each block with the classification_token,
# we need to effectively create blocks of length
# tokens_per_sample-1
loaded_datasets.append(
TokenBlockDataset(
- ds, ds.sizes, self.args.tokens_per_sample - 1,
- pad=self.dictionary.pad(), eos=self.dictionary.eos(),
+ ds,
+ ds.sizes,
+ self.args.tokens_per_sample - 1,
+ pad=self.dictionary.pad(),
+ eos=self.dictionary.eos(),
)
)
- logger.info('{} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))
+ logger.info(
+ "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1]))
+ )
if len(loaded_datasets) == 1:
dataset = loaded_datasets[0]
@@ -146,9 +161,11 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
for lang in self.langs2id.keys():
# Datasets are expected to be in "split.lang" format (Eg: train.en)
- language_split = '{}.{}'.format(split, lang)
+ language_split = "{}.{}".format(split, lang)
- block_dataset, sizes = self._load_single_lang_dataset(split=language_split, epoch=epoch)
+ block_dataset, sizes = self._load_single_lang_dataset(
+ split=language_split, epoch=epoch
+ )
dataset_map[lang] = MaskedLMDataset(
dataset=block_dataset,
@@ -158,13 +175,17 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
mask_idx=self.dictionary.mask(),
classif_token_idx=self.dictionary.eos(),
sep_token_idx=self.dictionary.eos(),
- shuffle=getattr(self.args, 'shuffle', False),
+ shuffle=getattr(self.args, "shuffle", False),
has_pairs=False,
segment_id=self.langs2id[lang],
seed=self.seed,
)
self.datasets[split] = MultiCorpusSampledDataset(dataset_map)
- logger.info('{} {} {} examples'.format(
- utils.split_paths(self.args.data)[epoch - 1], split, len(self.datasets[split]))
+ logger.info(
+ "{} {} {} examples".format(
+ utils.split_paths(self.args.data)[epoch - 1],
+ split,
+ len(self.datasets[split]),
+ )
)
diff --git a/fairseq/tasks/denoising.py b/fairseq/tasks/denoising.py
index ea6db45c75..3e88bf0ed0 100644
--- a/fairseq/tasks/denoising.py
+++ b/fairseq/tasks/denoising.py
@@ -6,24 +6,24 @@
import logging
import os
+from fairseq import utils
from fairseq.data import (
- data_utils,
- Dictionary,
AppendTokenDataset,
DenoisingDataset,
+ Dictionary,
PrependTokenDataset,
StripTokenDataset,
TokenBlockDataset,
+ data_utils,
)
from fairseq.data.encoders.utils import get_whole_word_mask
-from fairseq.tasks import register_task, LegacyFairseqTask
-from fairseq import utils
+from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
-@register_task('denoising')
+@register_task("denoising")
class DenoisingTask(LegacyFairseqTask):
"""
Denoising task for applying sequence to sequence denoising. (ie. BART)
@@ -32,58 +32,88 @@ class DenoisingTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
- parser.add_argument('data', help='path to data directory')
- parser.add_argument('--tokens-per-sample', default=512, type=int,
- help='max number of total tokens over all segments'
- ' per sample for dataset')
+ parser.add_argument("data", help="path to data directory")
parser.add_argument(
- '--sample-break-mode', default="complete_doc", type=str,
- help='mode for breaking sentence',
+ "--tokens-per-sample",
+ default=512,
+ type=int,
+ help="max number of total tokens over all segments"
+ " per sample for dataset",
)
parser.add_argument(
- '--mask', default=0.0, type=float,
- help='fraction of words/subwords that will be masked',
+ "--sample-break-mode",
+ default="complete_doc",
+ type=str,
+ help="mode for breaking sentence",
)
parser.add_argument(
- '--mask-random', default=0.0, type=float,
- help='instead of using [MASK], use random token this often'
+ "--mask",
+ default=0.0,
+ type=float,
+ help="fraction of words/subwords that will be masked",
)
parser.add_argument(
- '--insert', default=0.0, type=float,
- help='insert this percentage of additional random tokens',
+ "--mask-random",
+ default=0.0,
+ type=float,
+ help="instead of using [MASK], use random token this often",
)
parser.add_argument(
- '--permute', default=0.0, type=float,
- help='take this proportion of subwords and permute them',
+ "--insert",
+ default=0.0,
+ type=float,
+ help="insert this percentage of additional random tokens",
)
parser.add_argument(
- '--rotate', default=0.5, type=float,
- help='rotate this proportion of inputs',
+ "--permute",
+ default=0.0,
+ type=float,
+ help="take this proportion of subwords and permute them",
)
parser.add_argument(
- '--poisson-lambda', default=3.0, type=float,
- help='randomly shuffle sentences for this proportion of inputs'
+ "--rotate",
+ default=0.5,
+ type=float,
+ help="rotate this proportion of inputs",
)
parser.add_argument(
- '--permute-sentences', default=0.0, type=float,
- help='shuffle this proportion of sentences in all inputs'
+ "--poisson-lambda",
+ default=3.0,
+ type=float,
+ help="randomly shuffle sentences for this proportion of inputs",
)
parser.add_argument(
- '--mask-length', default="subword", type=str,
- choices=['subword', 'word', 'span-poisson'],
- help='mask length to choose'
+ "--permute-sentences",
+ default=0.0,
+ type=float,
+ help="shuffle this proportion of sentences in all inputs",
)
parser.add_argument(
- '--replace-length', default=-1, type=int,
- help='when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)'
+ "--mask-length",
+ default="subword",
+ type=str,
+ choices=["subword", "word", "span-poisson"],
+ help="mask length to choose",
)
parser.add_argument(
- '--max-source-positions', default=1024, type=int, metavar='N',
- help='max number of tokens in the source sequence'
+ "--replace-length",
+ default=-1,
+ type=int,
+ help="when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)",
)
parser.add_argument(
- '--max-target-positions', default=1024, type=int, metavar='N',
- help='max number of tokens in the target sequence'
+ "--max-source-positions",
+ default=1024,
+ type=int,
+ metavar="N",
+ help="max number of tokens in the source sequence",
+ )
+ parser.add_argument(
+ "--max-target-positions",
+ default=1024,
+ type=int,
+ metavar="N",
+ help="max number of tokens in the target sequence",
)
def __init__(self, args, dictionary):
@@ -92,15 +122,14 @@ def __init__(self, args, dictionary):
self.seed = args.seed
# add mask token
- self.mask_idx = self.dictionary.add_symbol('')
+ self.mask_idx = self.dictionary.add_symbol("")
@classmethod
def setup_task(cls, args, **kwargs):
- """Setup the task.
- """
- dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
- logger.info('dictionary: {} types'.format(len(dictionary)))
- if not hasattr(args, 'shuffle_instance'):
+ """Setup the task."""
+ dictionary = Dictionary.load(os.path.join(args.data, "dict.txt"))
+ logger.info("dictionary: {} types".format(len(dictionary)))
+ if not hasattr(args, "shuffle_instance"):
args.shuffle_instance = False
return cls(args, dictionary)
@@ -122,32 +151,42 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
combine=combine,
)
if dataset is None:
- raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))
+ raise FileNotFoundError(
+ "Dataset not found: {} ({})".format(split, split_path)
+ )
dataset = StripTokenDataset(dataset, self.dictionary.eos())
# create continuous blocks of tokens
dataset = TokenBlockDataset(
- dataset,
- dataset.sizes,
- self.args.tokens_per_sample - 2, # one less for and one for
- pad=self.dictionary.pad(),
- eos=self.dictionary.eos(),
- break_mode=self.args.sample_break_mode,
- document_sep_len=0
+ dataset,
+ dataset.sizes,
+ self.args.tokens_per_sample - 2, # one less for and one for
+ pad=self.dictionary.pad(),
+ eos=self.dictionary.eos(),
+ break_mode=self.args.sample_break_mode,
+ document_sep_len=0,
)
# prepend beginning-of-sentence token (, equiv. to [CLS] in BERT)
dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())
- mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
- if self.args.mask_length != 'subword' else None
+ mask_whole_words = (
+ get_whole_word_mask(self.args, self.source_dictionary)
+ if self.args.mask_length != "subword"
+ else None
+ )
self.datasets[split] = DenoisingDataset(
- dataset, dataset.sizes, self.dictionary, self.mask_idx,
- mask_whole_words, shuffle=self.args.shuffle_instance,
- seed=self.seed, args=self.args
+ dataset,
+ dataset.sizes,
+ self.dictionary,
+ self.mask_idx,
+ mask_whole_words,
+ shuffle=self.args.shuffle_instance,
+ seed=self.seed,
+ args=self.args,
)
logger.info(
"Split: {0}, Loaded {1} samples of denoising_dataset".format(
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index a8bfaa532d..0a96aeb1ea 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -10,9 +10,10 @@
import torch
from fairseq import metrics, search, tokenizer, utils
-from fairseq.data import Dictionary, FairseqDataset, data_utils, iterators, encoders
+from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators
from fairseq.dataclass.utils import gen_parser_from_dataclass
+
logger = logging.getLogger(__name__)
@@ -358,7 +359,7 @@ def build_generator(
)
elif prefix_allowed_tokens_fn:
search_strategy = search.PrefixConstrainedBeamSearch(
- self.target_dictionary, prefix_allowed_tokens_fn
+ self.target_dictionary, prefix_allowed_tokens_fn
)
else:
search_strategy = search.BeamSearch(self.target_dictionary)
diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py
index 5477c28aa9..8792c6481c 100644
--- a/fairseq/tasks/language_modeling.py
+++ b/fairseq/tasks/language_modeling.py
@@ -27,7 +27,7 @@
)
from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.data.shorten_dataset import maybe_shorten_dataset
-from fairseq.dataclass import FairseqDataclass, ChoiceEnum
+from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II
diff --git a/fairseq/tasks/legacy_masked_lm.py b/fairseq/tasks/legacy_masked_lm.py
index 4e0390cdca..9754976549 100644
--- a/fairseq/tasks/legacy_masked_lm.py
+++ b/fairseq/tasks/legacy_masked_lm.py
@@ -8,26 +8,18 @@
import os
import numpy as np
-
-from fairseq import tokenizer
-from fairseq.data import (
- ConcatDataset,
- indexed_dataset,
- data_utils,
-)
-
-from fairseq.data import Dictionary
+from fairseq import tokenizer, utils
+from fairseq.data import ConcatDataset, Dictionary, data_utils, indexed_dataset
from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
from fairseq.data.legacy.masked_lm_dictionary import BertDictionary
-from fairseq.tasks import register_task, LegacyFairseqTask
-from fairseq import utils
+from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
-@register_task('legacy_masked_lm')
+@register_task("legacy_masked_lm")
class LegacyMaskedLMTask(LegacyFairseqTask):
"""
Task for training Masked LM (BERT) model.
@@ -38,13 +30,22 @@ class LegacyMaskedLMTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
- parser.add_argument('data', help='colon separated path to data directories list, \
- will be iterated upon during epochs in round-robin manner')
- parser.add_argument('--tokens-per-sample', default=512, type=int,
- help='max number of total tokens over all segments'
- ' per sample for BERT dataset')
- parser.add_argument('--break-mode', default="doc", type=str, help='mode for breaking sentence')
- parser.add_argument('--shuffle-dataset', action='store_true', default=False)
+ parser.add_argument(
+ "data",
+ help="colon separated path to data directories list, \
+ will be iterated upon during epochs in round-robin manner",
+ )
+ parser.add_argument(
+ "--tokens-per-sample",
+ default=512,
+ type=int,
+ help="max number of total tokens over all segments"
+ " per sample for BERT dataset",
+ )
+ parser.add_argument(
+ "--break-mode", default="doc", type=str, help="mode for breaking sentence"
+ )
+ parser.add_argument("--shuffle-dataset", action="store_true", default=False)
def __init__(self, args, dictionary):
super().__init__(args)
@@ -56,10 +57,14 @@ def load_dictionary(cls, filename):
return BertDictionary.load(filename)
@classmethod
- def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8):
+ def build_dictionary(
+ cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
+ ):
d = BertDictionary()
for filename in filenames:
- Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
+ Dictionary.add_file_to_dictionary(
+ filename, d, tokenizer.tokenize_line, workers
+ )
d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
return d
@@ -69,12 +74,11 @@ def target_dictionary(self):
@classmethod
def setup_task(cls, args, **kwargs):
- """Setup the task.
- """
+ """Setup the task."""
paths = utils.split_paths(args.data)
assert len(paths) > 0
- dictionary = BertDictionary.load(os.path.join(paths[0], 'dict.txt'))
- logger.info('dictionary: {} types'.format(len(dictionary)))
+ dictionary = BertDictionary.load(os.path.join(paths[0], "dict.txt"))
+ logger.info("dictionary: {} types".format(len(dictionary)))
return cls(args, dictionary)
@@ -92,7 +96,7 @@ def load_dataset(self, split, epoch=1, combine=False):
logger.info("data_path", data_path)
for k in itertools.count():
- split_k = split + (str(k) if k > 0 else '')
+ split_k = split + (str(k) if k > 0 else "")
path = os.path.join(data_path, split_k)
ds = indexed_dataset.make_dataset(
path,
@@ -105,7 +109,9 @@ def load_dataset(self, split, epoch=1, combine=False):
if k > 0:
break
else:
- raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
+ raise FileNotFoundError(
+ "Dataset not found: {} ({})".format(split, data_path)
+ )
with data_utils.numpy_seed(self.seed + k):
loaded_datasets.append(
@@ -119,7 +125,9 @@ def load_dataset(self, split, epoch=1, combine=False):
)
)
- logger.info('{} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))
+ logger.info(
+ "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1]))
+ )
if not combine:
break
diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py
index 10b234a96b..56086f5e81 100644
--- a/fairseq/tasks/masked_lm.py
+++ b/fairseq/tasks/masked_lm.py
@@ -7,64 +7,99 @@
import os
import numpy as np
-
+from fairseq import utils
from fairseq.data import (
- data_utils,
Dictionary,
IdDataset,
MaskTokensDataset,
NestedDictionaryDataset,
NumelDataset,
NumSamplesDataset,
- RightPadDataset,
PrependTokenDataset,
+ RightPadDataset,
SortDataset,
TokenBlockDataset,
+ data_utils,
)
-from fairseq.tasks import register_task, LegacyFairseqTask
-from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.data.encoders.utils import get_whole_word_mask
-from fairseq import utils
+from fairseq.data.shorten_dataset import maybe_shorten_dataset
+from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
-@register_task('masked_lm')
+@register_task("masked_lm")
class MaskedLMTask(LegacyFairseqTask):
"""Task for training masked language models (e.g., BERT, RoBERTa)."""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
- parser.add_argument('data', help='colon separated path to data directories list, \
- will be iterated upon during epochs in round-robin manner')
- parser.add_argument('--sample-break-mode', default='complete',
- choices=['none', 'complete', 'complete_doc', 'eos'],
- help='If omitted or "none", fills each sample with tokens-per-sample '
- 'tokens. If set to "complete", splits samples only at the end '
- 'of sentence, but may include multiple sentences per sample. '
- '"complete_doc" is similar but respects doc boundaries. '
- 'If set to "eos", includes only one sentence per sample.')
- parser.add_argument('--tokens-per-sample', default=512, type=int,
- help='max number of total tokens over all segments '
- 'per sample for BERT dataset')
- parser.add_argument('--mask-prob', default=0.15, type=float,
- help='probability of replacing a token with mask')
- parser.add_argument('--leave-unmasked-prob', default=0.1, type=float,
- help='probability that a masked token is unmasked')
- parser.add_argument('--random-token-prob', default=0.1, type=float,
- help='probability of replacing a token with a random token')
- parser.add_argument('--freq-weighted-replacement', default=False, action='store_true',
- help='sample random replacement words based on word frequencies')
- parser.add_argument('--mask-whole-words', default=False, action='store_true',
- help='mask whole words; you may also want to set --bpe')
- parser.add_argument('--shorten-method', default='none',
- choices=['none', 'truncate', 'random_crop'],
- help='if not none, shorten sequences that exceed --tokens-per-sample')
- parser.add_argument('--shorten-data-split-list', default='',
- help='comma-separated list of dataset splits to apply shortening to, '
- 'e.g., "train,valid" (default: all dataset splits)')
+ parser.add_argument(
+ "data",
+ help="colon separated path to data directories list, \
+ will be iterated upon during epochs in round-robin manner",
+ )
+ parser.add_argument(
+ "--sample-break-mode",
+ default="complete",
+ choices=["none", "complete", "complete_doc", "eos"],
+ help='If omitted or "none", fills each sample with tokens-per-sample '
+ 'tokens. If set to "complete", splits samples only at the end '
+ "of sentence, but may include multiple sentences per sample. "
+ '"complete_doc" is similar but respects doc boundaries. '
+ 'If set to "eos", includes only one sentence per sample.',
+ )
+ parser.add_argument(
+ "--tokens-per-sample",
+ default=512,
+ type=int,
+ help="max number of total tokens over all segments "
+ "per sample for BERT dataset",
+ )
+ parser.add_argument(
+ "--mask-prob",
+ default=0.15,
+ type=float,
+ help="probability of replacing a token with mask",
+ )
+ parser.add_argument(
+ "--leave-unmasked-prob",
+ default=0.1,
+ type=float,
+ help="probability that a masked token is unmasked",
+ )
+ parser.add_argument(
+ "--random-token-prob",
+ default=0.1,
+ type=float,
+ help="probability of replacing a token with a random token",
+ )
+ parser.add_argument(
+ "--freq-weighted-replacement",
+ default=False,
+ action="store_true",
+ help="sample random replacement words based on word frequencies",
+ )
+ parser.add_argument(
+ "--mask-whole-words",
+ default=False,
+ action="store_true",
+ help="mask whole words; you may also want to set --bpe",
+ )
+ parser.add_argument(
+ "--shorten-method",
+ default="none",
+ choices=["none", "truncate", "random_crop"],
+ help="if not none, shorten sequences that exceed --tokens-per-sample",
+ )
+ parser.add_argument(
+ "--shorten-data-split-list",
+ default="",
+ help="comma-separated list of dataset splits to apply shortening to, "
+ 'e.g., "train,valid" (default: all dataset splits)',
+ )
def __init__(self, args, dictionary):
super().__init__(args)
@@ -72,14 +107,14 @@ def __init__(self, args, dictionary):
self.seed = args.seed
# add mask token
- self.mask_idx = dictionary.add_symbol('')
+ self.mask_idx = dictionary.add_symbol("")
@classmethod
def setup_task(cls, args, **kwargs):
paths = utils.split_paths(args.data)
assert len(paths) > 0
- dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
- logger.info('dictionary: {} types'.format(len(dictionary)))
+ dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
+ logger.info("dictionary: {} types".format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
@@ -100,7 +135,9 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
combine=combine,
)
if dataset is None:
- raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))
+ raise FileNotFoundError(
+ "Dataset not found: {} ({})".format(split, split_path)
+ )
dataset = maybe_shorten_dataset(
dataset,
@@ -120,14 +157,17 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
eos=self.source_dictionary.eos(),
break_mode=self.args.sample_break_mode,
)
- logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path))
+ logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))
# prepend beginning-of-sentence token (, equiv. to [CLS] in BERT)
dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
# create masked input and targets
- mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
- if self.args.mask_whole_words else None
+ mask_whole_words = (
+ get_whole_word_mask(self.args, self.source_dictionary)
+ if self.args.mask_whole_words
+ else None
+ )
src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
dataset,
@@ -148,20 +188,20 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
self.datasets[split] = SortDataset(
NestedDictionaryDataset(
{
- 'id': IdDataset(),
- 'net_input': {
- 'src_tokens': RightPadDataset(
+ "id": IdDataset(),
+ "net_input": {
+ "src_tokens": RightPadDataset(
src_dataset,
pad_idx=self.source_dictionary.pad(),
),
- 'src_lengths': NumelDataset(src_dataset, reduce=False),
+ "src_lengths": NumelDataset(src_dataset, reduce=False),
},
- 'target': RightPadDataset(
+ "target": RightPadDataset(
tgt_dataset,
pad_idx=self.source_dictionary.pad(),
),
- 'nsentences': NumSamplesDataset(),
- 'ntokens': NumelDataset(src_dataset, reduce=True),
+ "nsentences": NumSamplesDataset(),
+ "ntokens": NumelDataset(src_dataset, reduce=True),
},
sizes=[src_dataset.sizes],
),
@@ -179,17 +219,17 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
self.args.tokens_per_sample - 1, # one less for
pad=self.source_dictionary.pad(),
eos=self.source_dictionary.eos(),
- break_mode='eos',
+ break_mode="eos",
),
pad_idx=self.source_dictionary.pad(),
)
src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos())
src_dataset = NestedDictionaryDataset(
{
- 'id': IdDataset(),
- 'net_input': {
- 'src_tokens': src_dataset,
- 'src_lengths': NumelDataset(src_dataset, reduce=False),
+ "id": IdDataset(),
+ "net_input": {
+ "src_tokens": src_dataset,
+ "src_lengths": NumelDataset(src_dataset, reduce=False),
},
},
sizes=src_lengths,
diff --git a/fairseq/tasks/multilingual_denoising.py b/fairseq/tasks/multilingual_denoising.py
index 18ee717fff..d1c914917f 100644
--- a/fairseq/tasks/multilingual_denoising.py
+++ b/fairseq/tasks/multilingual_denoising.py
@@ -7,62 +7,74 @@
import os
import numpy as np
-
from fairseq.data import (
- data_utils,
- Dictionary,
AppendTokenDataset,
ConcatDataset,
DenoisingDataset,
+ Dictionary,
PrependTokenDataset,
ResamplingDataset,
SortDataset,
TokenBlockDataset,
+ data_utils,
)
-from .denoising import DenoisingTask
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq.tasks import register_task
+from .denoising import DenoisingTask
+
logger = logging.getLogger(__name__)
-@register_task('multilingual_denoising')
+@register_task("multilingual_denoising")
class MultilingualDenoisingTask(DenoisingTask):
-
@staticmethod
def add_args(parser):
DenoisingTask.add_args(parser)
- parser.add_argument('--multilang-sampling-alpha', type=float, default=1.0,
- help='smoothing alpha for sample ratios across multiple datasets')
- parser.add_argument('--add-lang-token', default=False, action='store_true')
- parser.add_argument('--langs', type=str, help="language ids we are considering", default=None)
- parser.add_argument('--no-whole-word-mask-langs', type=str, default='', metavar='N',
- help='languages without spacing between words dont support whole word masking')
+ parser.add_argument(
+ "--multilang-sampling-alpha",
+ type=float,
+ default=1.0,
+ help="smoothing alpha for sample ratios across multiple datasets",
+ )
+ parser.add_argument("--add-lang-token", default=False, action="store_true")
+ parser.add_argument(
+ "--langs", type=str, help="language ids we are considering", default=None
+ )
+ parser.add_argument(
+ "--no-whole-word-mask-langs",
+ type=str,
+ default="",
+ metavar="N",
+ help="languages without spacing between words dont support whole word masking",
+ )
@classmethod
def setup_task(cls, args, **kwargs):
- """Setup the task.
- """
- paths = args.data.split(':')
+ """Setup the task."""
+ paths = args.data.split(":")
assert len(paths) > 0
- dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
+ dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
data_path = paths[0]
if args.langs is None:
- languages = sorted([
- name for name in os.listdir(data_path)
- if os.path.isdir(os.path.join(data_path, name))
- ])
+ languages = sorted(
+ [
+ name
+ for name in os.listdir(data_path)
+ if os.path.isdir(os.path.join(data_path, name))
+ ]
+ )
else:
- languages = args.langs.split(',')
+ languages = args.langs.split(",")
if args.add_lang_token:
for lang in languages:
- dictionary.add_symbol('[{}]'.format(lang))
+ dictionary.add_symbol("[{}]".format(lang))
logger.info("dictionary: {} types".format(len(dictionary)))
- if not hasattr(args, 'shuffle_instance'):
+ if not hasattr(args, "shuffle_instance"):
args.shuffle_instance = False
return cls(args, dictionary)
@@ -72,7 +84,7 @@ def __init__(self, args, dictionary):
self.seed = args.seed
# add mask token
- self.mask_idx = self.dictionary.add_symbol('')
+ self.mask_idx = self.dictionary.add_symbol("")
self.langs = args.langs
self.args = args
@@ -92,30 +104,32 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
Args:
split (str): name of the split (e.g., train, valid, test)
"""
- paths = self.args.data.split(':')
+ paths = self.args.data.split(":")
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
if self.langs is None:
- languages = sorted([
- name for name in os.listdir(data_path)
- if os.path.isdir(os.path.join(data_path, name))
- ])
+ languages = sorted(
+ [
+ name
+ for name in os.listdir(data_path)
+ if os.path.isdir(os.path.join(data_path, name))
+ ]
+ )
else:
- languages = self.langs.split(',')
+ languages = self.langs.split(",")
for name in languages:
p = os.path.join(data_path, name)
assert os.path.exists(p), "data not found: {}".format(p)
logger.info("Training on {0} languages: {1}".format(len(languages), languages))
- logger.info("Language to id mapping: ", {
- lang: id for id, lang in enumerate(languages)
- }
+ logger.info(
+ "Language to id mapping: ", {lang: id for id, lang in enumerate(languages)}
)
mask_whole_words = get_whole_word_mask(self.args, self.dictionary)
- language_without_segmentations = self.args.no_whole_word_mask_langs.split(',')
+ language_without_segmentations = self.args.no_whole_word_mask_langs.split(",")
lang_datasets = []
for language in languages:
split_path = os.path.join(data_path, language, split)
@@ -127,10 +141,15 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
combine=combine,
)
if dataset is None:
- raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))
+ raise FileNotFoundError(
+ "Dataset not found: {} ({})".format(split, split_path)
+ )
- end_token = self.source_dictionary.index('[{}]'.format(language)) \
- if self.args.add_lang_token else self.source_dictionary.eos()
+ end_token = (
+ self.source_dictionary.index("[{}]".format(language))
+ if self.args.add_lang_token
+ else self.source_dictionary.eos()
+ )
# create continuous blocks of tokens
dataset = TokenBlockDataset(
@@ -141,13 +160,17 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
eos=end_token,
break_mode=self.args.sample_break_mode,
)
- logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path))
+ logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))
# prepend beginning-of-sentence token (, equiv. to [CLS] in BERT)
dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
dataset = AppendTokenDataset(dataset, end_token)
- lang_mask_whole_words = mask_whole_words if language not in language_without_segmentations else None
+ lang_mask_whole_words = (
+ mask_whole_words
+ if language not in language_without_segmentations
+ else None
+ )
lang_dataset = DenoisingDataset(
dataset,
dataset.sizes,
@@ -157,7 +180,9 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
shuffle=self.args.shuffle_instance,
seed=self.seed,
args=self.args,
- eos=None if not self.args.add_lang_token else self.source_dictionary.index('[{}]'.format(language)),
+ eos=None
+ if not self.args.add_lang_token
+ else self.source_dictionary.index("[{}]".format(language)),
)
lang_datasets.append(lang_dataset)
@@ -166,7 +191,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
dtype=float,
)
logger.info(
- 'loaded total {} blocks for all languages'.format(
+ "loaded total {} blocks for all languages".format(
int(dataset_lengths.sum()),
)
)
@@ -174,17 +199,21 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
# For train subset, additionally up or down sample languages.
sample_probs = self._get_sample_prob(dataset_lengths)
logger.info(
- "Sample probability by language: {}".format({
- lang: "{0:.4f}".format(sample_probs[id])
- for id, lang in enumerate(languages)
- })
+ "Sample probability by language: {}".format(
+ {
+ lang: "{0:.4f}".format(sample_probs[id])
+ for id, lang in enumerate(languages)
+ }
+ )
)
size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
logger.info(
- "Up/Down Sampling ratio by language: {}".format({
- lang: "{0:.2f}".format(size_ratio[id])
- for id, lang in enumerate(languages)
- })
+ "Up/Down Sampling ratio by language: {}".format(
+ {
+ lang: "{0:.2f}".format(size_ratio[id])
+ for id, lang in enumerate(languages)
+ }
+ )
)
resampled_lang_datasets = [
@@ -204,13 +233,13 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
dataset = ConcatDataset(lang_datasets)
lang_splits = [split]
for lang_id, lang_dataset in enumerate(lang_datasets):
- split_name = split + '_' + languages[lang_id]
+ split_name = split + "_" + languages[lang_id]
lang_splits.append(split_name)
self.datasets[split_name] = lang_dataset
if split in self.args.valid_subset:
self.args.valid_subset = self.args.valid_subset.replace(
- split, ','.join(lang_splits)
+ split, ",".join(lang_splits)
)
with data_utils.numpy_seed(self.args.seed + epoch):
diff --git a/fairseq/tasks/multilingual_masked_lm.py b/fairseq/tasks/multilingual_masked_lm.py
index 110e580a73..9e6ce4b8a2 100644
--- a/fairseq/tasks/multilingual_masked_lm.py
+++ b/fairseq/tasks/multilingual_masked_lm.py
@@ -8,12 +8,10 @@
import numpy as np
import torch
-
+from fairseq import utils
from fairseq.data import (
- data_utils,
- Dictionary,
- encoders,
ConcatDataset,
+ Dictionary,
IdDataset,
MaskTokensDataset,
NestedDictionaryDataset,
@@ -25,45 +23,79 @@
ResamplingDataset,
SortDataset,
TokenBlockDataset,
+ data_utils,
+ encoders,
)
-from fairseq.tasks import register_task, LegacyFairseqTask
-from fairseq import utils
+from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
-@register_task('multilingual_masked_lm')
+@register_task("multilingual_masked_lm")
class MultiLingualMaskedLMTask(LegacyFairseqTask):
"""Task for training masked language models (e.g., BERT, RoBERTa)."""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
- parser.add_argument('data', help='colon separated path to data directories list, \
- will be iterated upon during epochs in round-robin manner')
- parser.add_argument('--sample-break-mode', default='complete',
- choices=['none', 'complete', 'complete_doc', 'eos'],
- help='If omitted or "none", fills each sample with tokens-per-sample '
- 'tokens. If set to "complete", splits samples only at the end '
- 'of sentence, but may include multiple sentences per sample. '
- '"complete_doc" is similar but respects doc boundaries. '
- 'If set to "eos", includes only one sentence per sample.')
- parser.add_argument('--tokens-per-sample', default=512, type=int,
- help='max number of total tokens over all segments '
- 'per sample for BERT dataset')
- parser.add_argument('--mask-prob', default=0.15, type=float,
- help='probability of replacing a token with mask')
- parser.add_argument('--leave-unmasked-prob', default=0.1, type=float,
- help='probability that a masked token is unmasked')
- parser.add_argument('--random-token-prob', default=0.1, type=float,
- help='probability of replacing a token with a random token')
- parser.add_argument('--freq-weighted-replacement', action='store_true',
- help='sample random replacement words based on word frequencies')
- parser.add_argument('--mask-whole-words', default=False, action='store_true',
- help='mask whole words; you may also want to set --bpe')
- parser.add_argument('--multilang-sampling-alpha', type=float, default=1.0,
- help='smoothing alpha for sample rations across multiple datasets')
+ parser.add_argument(
+ "data",
+ help="colon separated path to data directories list, \
+ will be iterated upon during epochs in round-robin manner",
+ )
+ parser.add_argument(
+ "--sample-break-mode",
+ default="complete",
+ choices=["none", "complete", "complete_doc", "eos"],
+ help='If omitted or "none", fills each sample with tokens-per-sample '
+ 'tokens. If set to "complete", splits samples only at the end '
+ "of sentence, but may include multiple sentences per sample. "
+ '"complete_doc" is similar but respects doc boundaries. '
+ 'If set to "eos", includes only one sentence per sample.',
+ )
+ parser.add_argument(
+ "--tokens-per-sample",
+ default=512,
+ type=int,
+ help="max number of total tokens over all segments "
+ "per sample for BERT dataset",
+ )
+ parser.add_argument(
+ "--mask-prob",
+ default=0.15,
+ type=float,
+ help="probability of replacing a token with mask",
+ )
+ parser.add_argument(
+ "--leave-unmasked-prob",
+ default=0.1,
+ type=float,
+ help="probability that a masked token is unmasked",
+ )
+ parser.add_argument(
+ "--random-token-prob",
+ default=0.1,
+ type=float,
+ help="probability of replacing a token with a random token",
+ )
+ parser.add_argument(
+ "--freq-weighted-replacement",
+ action="store_true",
+ help="sample random replacement words based on word frequencies",
+ )
+ parser.add_argument(
+ "--mask-whole-words",
+ default=False,
+ action="store_true",
+ help="mask whole words; you may also want to set --bpe",
+ )
+ parser.add_argument(
+ "--multilang-sampling-alpha",
+ type=float,
+ default=1.0,
+ help="smoothing alpha for sample rations across multiple datasets",
+ )
def __init__(self, args, dictionary):
super().__init__(args)
@@ -71,14 +103,14 @@ def __init__(self, args, dictionary):
self.seed = args.seed
# add mask token
- self.mask_idx = dictionary.add_symbol('')
+ self.mask_idx = dictionary.add_symbol("")
@classmethod
def setup_task(cls, args, **kwargs):
paths = utils.split_paths(args.data)
assert len(paths) > 0
- dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
- logger.info('dictionary: {} types'.format(len(dictionary)))
+ dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
+ logger.info("dictionary: {} types".format(len(dictionary)))
return cls(args, dictionary)
def _get_whole_word_mask(self):
@@ -92,16 +124,16 @@ def is_beginning_of_word(i):
# special elements are always considered beginnings
return True
tok = self.source_dictionary[i]
- if tok.startswith('madeupword'):
+ if tok.startswith("madeupword"):
return True
try:
return bpe.is_beginning_of_word(tok)
except ValueError:
return True
- mask_whole_words = torch.ByteTensor(list(
- map(is_beginning_of_word, range(len(self.source_dictionary)))
- ))
+ mask_whole_words = torch.ByteTensor(
+ list(map(is_beginning_of_word, range(len(self.source_dictionary))))
+ )
else:
mask_whole_words = None
return mask_whole_words
@@ -127,14 +159,14 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
data_path = paths[(epoch - 1) % len(paths)]
languages = sorted(
- name for name in os.listdir(data_path)
+ name
+ for name in os.listdir(data_path)
if os.path.isdir(os.path.join(data_path, name))
)
logger.info("Training on {0} languages: {1}".format(len(languages), languages))
- logger.info("Language to id mapping: ", {
- lang: id for id, lang in enumerate(languages)
- }
+ logger.info(
+ "Language to id mapping: ", {lang: id for id, lang in enumerate(languages)}
)
mask_whole_words = self._get_whole_word_mask()
@@ -149,7 +181,9 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
combine=combine,
)
if dataset is None:
- raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))
+ raise FileNotFoundError(
+ "Dataset not found: {} ({})".format(split, split_path)
+ )
# create continuous blocks of tokens
dataset = TokenBlockDataset(
@@ -160,7 +194,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
eos=self.source_dictionary.eos(),
break_mode=self.args.sample_break_mode,
)
- logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path))
+ logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))
# prepend beginning-of-sentence token (, equiv. to [CLS] in BERT)
dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
@@ -180,50 +214,53 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
lang_dataset = NestedDictionaryDataset(
{
- 'net_input': {
- 'src_tokens': PadDataset(
+ "net_input": {
+ "src_tokens": PadDataset(
src_dataset,
pad_idx=self.source_dictionary.pad(),
left_pad=False,
),
- 'src_lengths': NumelDataset(src_dataset, reduce=False),
+ "src_lengths": NumelDataset(src_dataset, reduce=False),
},
- 'target': PadDataset(
+ "target": PadDataset(
tgt_dataset,
pad_idx=self.source_dictionary.pad(),
left_pad=False,
),
- 'nsentences': NumSamplesDataset(),
- 'ntokens': NumelDataset(src_dataset, reduce=True),
- 'lang_id': RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
+ "nsentences": NumSamplesDataset(),
+ "ntokens": NumelDataset(src_dataset, reduce=True),
+ "lang_id": RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
},
sizes=[src_dataset.sizes],
)
lang_datasets.append(lang_dataset)
-
dataset_lengths = np.array(
[len(d) for d in lang_datasets],
dtype=float,
)
logger.info(
- 'loaded total {} blocks for all languages'.format(
+ "loaded total {} blocks for all languages".format(
dataset_lengths.sum(),
)
)
if split == self.args.train_subset:
# For train subset, additionally up or down sample languages.
sample_probs = self._get_sample_prob(dataset_lengths)
- logger.info("Sample probability by language: ", {
+ logger.info(
+ "Sample probability by language: ",
+ {
lang: "{0:.4f}".format(sample_probs[id])
for id, lang in enumerate(languages)
- }
+ },
)
size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
- logger.info("Up/Down Sampling ratio by language: ", {
+ logger.info(
+ "Up/Down Sampling ratio by language: ",
+ {
lang: "{0:.2f}".format(size_ratio[id])
for id, lang in enumerate(languages)
- }
+ },
)
resampled_lang_datasets = [
@@ -241,7 +278,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
dataset = ConcatDataset(lang_datasets)
lang_splits = [split]
for lang_id, lang_dataset in enumerate(lang_datasets):
- split_name = split + '_' + languages[lang_id]
+ split_name = split + "_" + languages[lang_id]
lang_splits.append(split_name)
self.datasets[split_name] = lang_dataset
@@ -250,7 +287,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
# in more generic ways.
if split in self.args.valid_subset:
self.args.valid_subset = self.args.valid_subset.replace(
- split, ','.join(lang_splits)
+ split, ",".join(lang_splits)
)
with data_utils.numpy_seed(self.args.seed + epoch):
@@ -272,7 +309,7 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
self.args.tokens_per_sample - 1, # one less for
pad=self.source_dictionary.pad(),
eos=self.source_dictionary.eos(),
- break_mode='eos',
+ break_mode="eos",
),
pad_idx=self.source_dictionary.pad(),
left_pad=False,
@@ -280,10 +317,10 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos())
src_dataset = NestedDictionaryDataset(
{
- 'id': IdDataset(),
- 'net_input': {
- 'src_tokens': src_dataset,
- 'src_lengths': NumelDataset(src_dataset, reduce=False),
+ "id": IdDataset(),
+ "net_input": {
+ "src_tokens": src_dataset,
+ "src_lengths": NumelDataset(src_dataset, reduce=False),
},
},
sizes=src_lengths,
diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py
index 161eb436ec..f6cb17f12a 100644
--- a/fairseq/tasks/multilingual_translation.py
+++ b/fairseq/tasks/multilingual_translation.py
@@ -3,14 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from collections import OrderedDict
+import contextlib
import logging
import os
-from fairseq import options
-import contextlib
-import torch
+from collections import OrderedDict
-from fairseq import metrics, utils
+import torch
+from fairseq import metrics, options, utils
from fairseq.data import (
Dictionary,
LanguagePairDataset,
@@ -20,24 +19,24 @@
from fairseq.models import FairseqMultiModel
from fairseq.tasks.translation import load_langpair_dataset
-from . import register_task, LegacyFairseqTask
+from . import LegacyFairseqTask, register_task
+
logger = logging.getLogger(__name__)
def _lang_token(lang: str):
- return '__{}__'.format(lang)
+ return "__{}__".format(lang)
def _lang_token_index(dic: Dictionary, lang: str):
"""Return language token index."""
idx = dic.index(_lang_token(lang))
- assert idx != dic.unk_index, \
- 'cannot find language token for lang {}'.format(lang)
+ assert idx != dic.unk_index, "cannot find language token for lang {}".format(lang)
return idx
-@register_task('multilingual_translation')
+@register_task("multilingual_translation")
class MultilingualTranslationTask(LegacyFairseqTask):
"""A task for training multiple translation models simultaneously.
@@ -99,7 +98,7 @@ def __init__(self, args, dicts, training):
if training:
self.lang_pairs = args.lang_pairs
else:
- self.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
+ self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
# eval_lang_pairs for multilingual translation is usually all of the
# lang_pairs. However for other multitask settings or when we want to
# optimize for certain languages we want to use a different subset. Thus
@@ -123,10 +122,14 @@ def prepare(cls, args, **kargs):
args.left_pad_target = utils.eval_bool(args.left_pad_target)
if args.lang_pairs is None:
- raise ValueError('--lang-pairs is required. List all the language pairs in the training objective.')
+ raise ValueError(
+ "--lang-pairs is required. List all the language pairs in the training objective."
+ )
if isinstance(args.lang_pairs, str):
- args.lang_pairs = args.lang_pairs.split(',')
- sorted_langs = sorted(list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')}))
+ args.lang_pairs = args.lang_pairs.split(",")
+ sorted_langs = sorted(
+ list({x for lang_pair in args.lang_pairs for x in lang_pair.split("-")})
+ )
if args.source_lang is not None or args.target_lang is not None:
training = False
else:
@@ -137,7 +140,9 @@ def prepare(cls, args, **kargs):
for lang in sorted_langs:
paths = utils.split_paths(args.data)
assert len(paths) > 0
- dicts[lang] = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(lang)))
+ dicts[lang] = cls.load_dictionary(
+ os.path.join(paths[0], "dict.{}.txt".format(lang))
+ )
if len(dicts) > 0:
assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
@@ -145,13 +150,13 @@ def prepare(cls, args, **kargs):
if args.encoder_langtok is not None or args.decoder_langtok:
for lang_to_add in sorted_langs:
dicts[lang].add_symbol(_lang_token(lang_to_add))
- logger.info('[{}] dictionary: {} types'.format(lang, len(dicts[lang])))
+ logger.info("[{}] dictionary: {} types".format(lang, len(dicts[lang])))
return dicts, training
def get_encoder_langtok(self, src_lang, tgt_lang):
if self.args.encoder_langtok is None:
return self.dicts[src_lang].eos()
- if self.args.encoder_langtok == 'src':
+ if self.args.encoder_langtok == "src":
return _lang_token_index(self.dicts[src_lang], src_lang)
else:
return _lang_token_index(self.dicts[src_lang], tgt_lang)
@@ -161,14 +166,24 @@ def get_decoder_langtok(self, tgt_lang):
return self.dicts[tgt_lang].eos()
return _lang_token_index(self.dicts[tgt_lang], tgt_lang)
- def alter_dataset_langtok(self, lang_pair_dataset,
- src_eos=None, src_lang=None, tgt_eos=None, tgt_lang=None):
+ def alter_dataset_langtok(
+ self,
+ lang_pair_dataset,
+ src_eos=None,
+ src_lang=None,
+ tgt_eos=None,
+ tgt_lang=None,
+ ):
if self.args.encoder_langtok is None and not self.args.decoder_langtok:
return lang_pair_dataset
new_src_eos = None
- if self.args.encoder_langtok is not None and src_eos is not None \
- and src_lang is not None and tgt_lang is not None:
+ if (
+ self.args.encoder_langtok is not None
+ and src_eos is not None
+ and src_lang is not None
+ and tgt_lang is not None
+ ):
new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang)
else:
src_eos = None
@@ -194,10 +209,16 @@ def load_dataset(self, split, epoch=1, **kwargs):
data_path = paths[(epoch - 1) % len(paths)]
def language_pair_dataset(lang_pair):
- src, tgt = lang_pair.split('-')
+ src, tgt = lang_pair.split("-")
langpair_dataset = load_langpair_dataset(
- data_path, split, src, self.dicts[src], tgt, self.dicts[tgt],
- combine=True, dataset_impl=self.args.dataset_impl,
+ data_path,
+ split,
+ src,
+ self.dicts[src],
+ tgt,
+ self.dicts[tgt],
+ combine=True,
+ dataset_impl=self.args.dataset_impl,
upsample_primary=self.args.upsample_primary,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
@@ -213,68 +234,100 @@ def language_pair_dataset(lang_pair):
)
self.datasets[split] = RoundRobinZipDatasets(
- OrderedDict([
- (lang_pair, language_pair_dataset(lang_pair))
- for lang_pair in self.lang_pairs
- ]),
- eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang),
+ OrderedDict(
+ [
+ (lang_pair, language_pair_dataset(lang_pair))
+ for lang_pair in self.lang_pairs
+ ]
+ ),
+ eval_key=None
+ if self.training
+ else "%s-%s" % (self.args.source_lang, self.args.target_lang),
)
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
if constraints is not None:
- raise NotImplementedError("Constrained decoding with the multilingual_translation task is not supported")
+ raise NotImplementedError(
+ "Constrained decoding with the multilingual_translation task is not supported"
+ )
lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang)
return RoundRobinZipDatasets(
- OrderedDict([(
- lang_pair,
- self.alter_dataset_langtok(
- LanguagePairDataset(
- src_tokens, src_lengths,
- self.source_dictionary
- ),
- src_eos=self.source_dictionary.eos(),
- src_lang=self.args.source_lang,
- tgt_eos=self.target_dictionary.eos(),
- tgt_lang=self.args.target_lang,
- ),
- )]),
+ OrderedDict(
+ [
+ (
+ lang_pair,
+ self.alter_dataset_langtok(
+ LanguagePairDataset(
+ src_tokens, src_lengths, self.source_dictionary
+ ),
+ src_eos=self.source_dictionary.eos(),
+ src_lang=self.args.source_lang,
+ tgt_eos=self.target_dictionary.eos(),
+ tgt_lang=self.args.target_lang,
+ ),
+ )
+ ]
+ ),
eval_key=lang_pair,
)
def build_model(self, args):
def check_args():
messages = []
- if len(set(self.args.lang_pairs).symmetric_difference(args.lang_pairs)) != 0:
- messages.append('--lang-pairs should include all the language pairs {}.'.format(args.lang_pairs))
+ if (
+ len(set(self.args.lang_pairs).symmetric_difference(args.lang_pairs))
+ != 0
+ ):
+ messages.append(
+ "--lang-pairs should include all the language pairs {}.".format(
+ args.lang_pairs
+ )
+ )
if self.args.encoder_langtok != args.encoder_langtok:
- messages.append('--encoder-langtok should be {}.'.format(args.encoder_langtok))
+ messages.append(
+ "--encoder-langtok should be {}.".format(args.encoder_langtok)
+ )
if self.args.decoder_langtok != args.decoder_langtok:
- messages.append('--decoder-langtok should {} be set.'.format("" if args.decoder_langtok else "not"))
+ messages.append(
+ "--decoder-langtok should {} be set.".format(
+ "" if args.decoder_langtok else "not"
+ )
+ )
if len(messages) > 0:
- raise ValueError(' '.join(messages))
+ raise ValueError(" ".join(messages))
# Check if task args are consistant with model args
check_args()
from fairseq import models
+
model = models.build_model(args, self)
if not isinstance(model, FairseqMultiModel):
- raise ValueError('MultilingualTranslationTask requires a FairseqMultiModel architecture')
+ raise ValueError(
+ "MultilingualTranslationTask requires a FairseqMultiModel architecture"
+ )
return model
- def _per_lang_pair_train_loss(self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad):
- loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
+ def _per_lang_pair_train_loss(
+ self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad
+ ):
+ loss, sample_size, logging_output = criterion(
+ model.models[lang_pair], sample[lang_pair]
+ )
if ignore_grad:
loss *= 0
optimizer.backward(loss)
return loss, sample_size, logging_output
- def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
+ def train_step(
+ self, sample, model, criterion, optimizer, update_num, ignore_grad=False
+ ):
model.train()
from collections import defaultdict
- agg_loss, agg_sample_size, agg_logging_output = 0., 0., defaultdict(float)
+
+ agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float)
curr_lang_pairs = [
lang_pair
for lang_pair in self.model_lang_pairs
@@ -282,18 +335,27 @@ def train_step(self, sample, model, criterion, optimizer, update_num, ignore_gra
]
for idx, lang_pair in enumerate(curr_lang_pairs):
+
def maybe_no_sync():
if (
self.args.distributed_world_size > 1
- and hasattr(model, 'no_sync')
+ and hasattr(model, "no_sync")
and idx < len(curr_lang_pairs) - 1
):
return model.no_sync()
else:
return contextlib.ExitStack() # dummy contextmanager
+
with maybe_no_sync():
loss, sample_size, logging_output = self._per_lang_pair_train_loss(
- lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad)
+ lang_pair,
+ model,
+ update_num,
+ criterion,
+ sample,
+ optimizer,
+ ignore_grad,
+ )
agg_loss += loss.detach().item()
# TODO make summing of the sample sizes configurable
agg_sample_size += sample_size
@@ -309,11 +371,18 @@ def valid_step(self, sample, model, criterion):
model.eval()
with torch.no_grad():
from collections import defaultdict
- agg_loss, agg_sample_size, agg_logging_output = 0., 0., defaultdict(float)
+
+ agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float)
for lang_pair in self.eval_lang_pairs:
- if lang_pair not in sample or sample[lang_pair] is None or len(sample[lang_pair]) == 0:
+ if (
+ lang_pair not in sample
+ or sample[lang_pair] is None
+ or len(sample[lang_pair]) == 0
+ ):
continue
- loss, sample_size, logging_output = self._per_lang_pair_valid_loss(lang_pair, model, criterion, sample)
+ loss, sample_size, logging_output = self._per_lang_pair_valid_loss(
+ lang_pair, model, criterion, sample
+ )
agg_loss += loss.data.item()
# TODO make summing of the sample sizes configurable
agg_sample_size += sample_size
@@ -322,10 +391,14 @@ def valid_step(self, sample, model, criterion):
agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k]
return agg_loss, agg_sample_size, agg_logging_output
- def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None):
+ def inference_step(
+ self, generator, models, sample, prefix_tokens=None, constraints=None
+ ):
with torch.no_grad():
if self.args.decoder_langtok:
- bos_token = _lang_token_index(self.target_dictionary, self.args.target_lang)
+ bos_token = _lang_token_index(
+ self.target_dictionary, self.args.target_lang
+ )
else:
bos_token = self.target_dictionary.eos()
return generator.generate(
@@ -340,7 +413,7 @@ def reduce_metrics(self, logging_outputs, criterion):
with metrics.aggregate():
# pass 'sample_size', 'nsentences', 'ntokens' stats to fairseq_task
super().reduce_metrics(logging_outputs, criterion)
- for k in ['sample_size', 'nsentences', 'ntokens']:
+ for k in ["sample_size", "nsentences", "ntokens"]:
metrics.log_scalar(k, sum(l[k] for l in logging_outputs))
@property
@@ -360,10 +433,17 @@ def target_dictionary(self):
def max_positions(self):
"""Return the max sentence length allowed by the task."""
if len(self.datasets.values()) == 0:
- return {'%s-%s' % (self.args.source_lang, self.args.target_lang):
- (self.args.max_source_positions, self.args.max_target_positions)}
- return OrderedDict([
- (key, (self.args.max_source_positions, self.args.max_target_positions))
- for split in self.datasets.keys()
- for key in self.datasets[split].datasets.keys()
- ])
+ return {
+ "%s-%s"
+ % (self.args.source_lang, self.args.target_lang): (
+ self.args.max_source_positions,
+ self.args.max_target_positions,
+ )
+ }
+ return OrderedDict(
+ [
+ (key, (self.args.max_source_positions, self.args.max_target_positions))
+ for split in self.datasets.keys()
+ for key in self.datasets[split].datasets.keys()
+ ]
+ )
diff --git a/fairseq/tasks/semisupervised_translation.py b/fairseq/tasks/semisupervised_translation.py
index c81d362886..b2f9bf9a73 100644
--- a/fairseq/tasks/semisupervised_translation.py
+++ b/fairseq/tasks/semisupervised_translation.py
@@ -3,27 +3,28 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from collections import OrderedDict
import logging
import os
+from collections import OrderedDict
+from fairseq import utils
from fairseq.data import (
BacktranslationDataset,
- data_utils,
- indexed_dataset,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
LanguagePairDataset,
NoisingDataset,
RoundRobinZipDatasets,
+ data_utils,
+ indexed_dataset,
)
from fairseq.models import FairseqMultiModel
from fairseq.sequence_generator import SequenceGenerator
-from .multilingual_translation import MultilingualTranslationTask
from . import register_task
-from fairseq import utils
+from .multilingual_translation import MultilingualTranslationTask
+
logger = logging.getLogger(__name__)
@@ -46,18 +47,20 @@ def parse_lambda_config(x):
x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000
# iterations, then will linearly increase to 1 until iteration 2000
"""
- split = x.split(',')
+ split = x.split(",")
if len(split) == 1:
return float(x), None
else:
split = [s.split(os.pathsep) for s in split]
assert all(len(s) == 2 for s in split)
assert all(k.isdigit() for k, _ in split)
- assert all(int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1))
+ assert all(
+ int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1)
+ )
return float(split[0][1]), [(int(k), float(v)) for k, v in split]
-@register_task('semisupervised_translation')
+@register_task("semisupervised_translation")
class SemisupervisedTranslationTask(MultilingualTranslationTask):
"""A task for training multiple translation models simultaneously.
@@ -119,13 +122,19 @@ def add_args(parser):
def __init__(self, args, dicts, training):
super().__init__(args, dicts, training)
- self.lambda_parallel, self.lambda_parallel_steps = parse_lambda_config(args.lambda_parallel_config)
- self.lambda_otf_bt, self.lambda_otf_bt_steps = parse_lambda_config(args.lambda_otf_bt_config)
- self.lambda_denoising, self.lambda_denoising_steps = parse_lambda_config(args.lambda_denoising_config)
- if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None):
+ self.lambda_parallel, self.lambda_parallel_steps = parse_lambda_config(
+ args.lambda_parallel_config
+ )
+ self.lambda_otf_bt, self.lambda_otf_bt_steps = parse_lambda_config(
+ args.lambda_otf_bt_config
+ )
+ self.lambda_denoising, self.lambda_denoising_steps = parse_lambda_config(
+ args.lambda_denoising_config
+ )
+ if self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None:
denoising_lang_pairs = [
"%s-%s" % (tgt, tgt)
- for tgt in {lang_pair.split('-')[1] for lang_pair in args.lang_pairs}
+ for tgt in {lang_pair.split("-")[1] for lang_pair in args.lang_pairs}
]
self.model_lang_pairs = self.model_lang_pairs + denoising_lang_pairs
self.backtranslate_datasets = {}
@@ -144,39 +153,71 @@ def load_dataset(self, split, epoch=1, **kwargs):
def split_exists(split, src, tgt, lang):
if src is not None:
- filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
+ filename = os.path.join(
+ data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
+ )
else:
- filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, src, tgt))
+ filename = os.path.join(
+ data_path, "{}.{}-None.{}".format(split, src, tgt)
+ )
return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)
def load_indexed_dataset(path, dictionary):
- return data_utils.load_indexed_dataset(path, dictionary, self.args.dataset_impl)
+ return data_utils.load_indexed_dataset(
+ path, dictionary, self.args.dataset_impl
+ )
# load parallel datasets
src_datasets, tgt_datasets = {}, {}
- if (self.lambda_parallel > 0.0 or self.lambda_parallel_steps is not None or not split.startswith("train")):
+ if (
+ self.lambda_parallel > 0.0
+ or self.lambda_parallel_steps is not None
+ or not split.startswith("train")
+ ):
for lang_pair in self.lang_pairs:
- src, tgt = lang_pair.split('-')
+ src, tgt = lang_pair.split("-")
if split_exists(split, src, tgt, src):
- prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
+ prefix = os.path.join(
+ data_path, "{}.{}-{}.".format(split, src, tgt)
+ )
elif split_exists(split, tgt, src, src):
- prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
+ prefix = os.path.join(
+ data_path, "{}.{}-{}.".format(split, tgt, src)
+ )
else:
continue
- src_datasets[lang_pair] = load_indexed_dataset(prefix + src, self.dicts[src])
- tgt_datasets[lang_pair] = load_indexed_dataset(prefix + tgt, self.dicts[tgt])
- logger.info('parallel-{} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))
+ src_datasets[lang_pair] = load_indexed_dataset(
+ prefix + src, self.dicts[src]
+ )
+ tgt_datasets[lang_pair] = load_indexed_dataset(
+ prefix + tgt, self.dicts[tgt]
+ )
+ logger.info(
+ "parallel-{} {} {} examples".format(
+ data_path, split, len(src_datasets[lang_pair])
+ )
+ )
if len(src_datasets) == 0:
- raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
+ raise FileNotFoundError(
+ "Dataset not found: {} ({})".format(split, data_path)
+ )
# back translation datasets
backtranslate_datasets = {}
- if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and split.startswith("train"):
+ if (
+ self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
+ ) and split.startswith("train"):
for lang_pair in self.lang_pairs:
- src, tgt = lang_pair.split('-')
+ src, tgt = lang_pair.split("-")
if not split_exists(split, tgt, None, tgt):
- raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, data_path))
- filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
+ raise FileNotFoundError(
+ "Dataset not found: backtranslation {} ({})".format(
+ split, data_path
+ )
+ )
+ filename = os.path.join(
+ data_path, "{}.{}-None.{}".format(split, tgt, tgt)
+ )
dataset = load_indexed_dataset(filename, self.dicts[tgt])
lang_pair_dataset_tgt = LanguagePairDataset(
dataset,
@@ -203,7 +244,8 @@ def load_indexed_dataset(path, dictionary):
tgt_lang=src,
),
backtranslation_fn=self.backtranslators[lang_pair],
- src_dict=self.dicts[src], tgt_dict=self.dicts[tgt],
+ src_dict=self.dicts[src],
+ tgt_dict=self.dicts[tgt],
output_collater=self.alter_dataset_langtok(
lang_pair_dataset=lang_pair_dataset,
src_eos=self.dicts[src].eos(),
@@ -212,19 +254,30 @@ def load_indexed_dataset(path, dictionary):
tgt_lang=tgt,
).collater,
)
- logger.info('backtranslate-{}: {} {} {} examples'.format(
- tgt, data_path, split, len(backtranslate_datasets[lang_pair]),
- ))
- self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair]
+ logger.info(
+ "backtranslate-{}: {} {} {} examples".format(
+ tgt,
+ data_path,
+ split,
+ len(backtranslate_datasets[lang_pair]),
+ )
+ )
+ self.backtranslate_datasets[lang_pair] = backtranslate_datasets[
+ lang_pair
+ ]
# denoising autoencoder
noising_datasets = {}
- if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None) and split.startswith("train"):
+ if (
+ self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None
+ ) and split.startswith("train"):
for lang_pair in self.lang_pairs:
- _, tgt = lang_pair.split('-')
+ _, tgt = lang_pair.split("-")
if not split_exists(split, tgt, None, tgt):
continue
- filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
+ filename = os.path.join(
+ data_path, "{}.{}-None.{}".format(split, tgt, tgt)
+ )
tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
noising_dataset = NoisingDataset(
@@ -251,17 +304,26 @@ def load_indexed_dataset(path, dictionary):
tgt_eos=self.dicts[tgt].eos(),
tgt_lang=tgt,
)
- logger.info('denoising-{}: {} {} {} examples'.format(
- tgt, data_path, split, len(noising_datasets[lang_pair]),
- ))
+ logger.info(
+ "denoising-{}: {} {} {} examples".format(
+ tgt,
+ data_path,
+ split,
+ len(noising_datasets[lang_pair]),
+ )
+ )
def language_pair_dataset(lang_pair):
- src, tgt = lang_pair.split('-')
+ src, tgt = lang_pair.split("-")
src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
return self.alter_dataset_langtok(
LanguagePairDataset(
- src_dataset, src_dataset.sizes, self.dicts[src],
- tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
+ src_dataset,
+ src_dataset.sizes,
+ self.dicts[src],
+ tgt_dataset,
+ tgt_dataset.sizes,
+ self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
),
@@ -272,31 +334,42 @@ def language_pair_dataset(lang_pair):
)
self.datasets[split] = RoundRobinZipDatasets(
- OrderedDict([
- (lang_pair, language_pair_dataset(lang_pair))
- for lang_pair in src_datasets.keys()
- ] + [
- (_get_bt_dataset_key(lang_pair), dataset)
- for lang_pair, dataset in backtranslate_datasets.items()
- ] + [
- (_get_denoising_dataset_key(lang_pair), dataset)
- for lang_pair, dataset in noising_datasets.items()
- ]),
- eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang),
+ OrderedDict(
+ [
+ (lang_pair, language_pair_dataset(lang_pair))
+ for lang_pair in src_datasets.keys()
+ ]
+ + [
+ (_get_bt_dataset_key(lang_pair), dataset)
+ for lang_pair, dataset in backtranslate_datasets.items()
+ ]
+ + [
+ (_get_denoising_dataset_key(lang_pair), dataset)
+ for lang_pair, dataset in noising_datasets.items()
+ ]
+ ),
+ eval_key=None
+ if self.training
+ else "%s-%s" % (self.args.source_lang, self.args.target_lang),
)
def build_model(self, args):
from fairseq import models
+
model = models.build_model(args, self)
if not isinstance(model, FairseqMultiModel):
- raise ValueError('SemisupervisedTranslationTask requires a FairseqMultiModel architecture')
+ raise ValueError(
+ "SemisupervisedTranslationTask requires a FairseqMultiModel architecture"
+ )
# create SequenceGenerator for each model that has backtranslation dependency on it
self.sequence_generators = {}
- if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and self.training:
+ if (
+ self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
+ ) and self.training:
for lang_pair in self.lang_pairs:
- src, tgt = lang_pair.split('-')
- key = '{}-{}'.format(tgt, src)
+ src, tgt = lang_pair.split("-")
+ key = "{}-{}".format(tgt, src)
self.sequence_generators[key] = SequenceGenerator(
[model.models[key]],
tgt_dict=self.dicts[src],
@@ -307,7 +380,8 @@ def build_model(self, args):
decoder_lang_tok_idx = self.get_decoder_langtok(src)
def backtranslate_fn(
- sample, model=model.models[key],
+ sample,
+ model=model.models[key],
bos_token=decoder_lang_tok_idx,
sequence_generator=self.sequence_generators[key],
):
@@ -316,17 +390,20 @@ def backtranslate_fn(
sample,
bos_token=bos_token,
)
+
self.backtranslators[lang_pair] = backtranslate_fn
return model
- def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
+ def train_step(
+ self, sample, model, criterion, optimizer, update_num, ignore_grad=False
+ ):
model.train()
if update_num > 0:
self.update_step(update_num)
- agg_loss, agg_sample_size, agg_logging_output = 0., 0., {}
+ agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, {}
def forward_backward(model, samples, logging_output_key, weight):
nonlocal agg_loss, agg_sample_size, agg_logging_output
@@ -347,18 +424,33 @@ def forward_backward(model, samples, logging_output_key, weight):
if self.lambda_parallel > 0.0:
for lang_pair in self.lang_pairs:
- forward_backward(model.models[lang_pair], sample[lang_pair], lang_pair, self.lambda_parallel)
+ forward_backward(
+ model.models[lang_pair],
+ sample[lang_pair],
+ lang_pair,
+ self.lambda_parallel,
+ )
if self.lambda_otf_bt > 0.0:
for lang_pair in self.lang_pairs:
sample_key = _get_bt_dataset_key(lang_pair)
- forward_backward(model.models[lang_pair], sample[sample_key], sample_key, self.lambda_otf_bt)
+ forward_backward(
+ model.models[lang_pair],
+ sample[sample_key],
+ sample_key,
+ self.lambda_otf_bt,
+ )
if self.lambda_denoising > 0.0:
for lang_pair in self.lang_pairs:
- _, tgt = lang_pair.split('-')
+ _, tgt = lang_pair.split("-")
sample_key = _get_denoising_dataset_key(lang_pair)
- forward_backward(model.models['{0}-{0}'.format(tgt)], sample[sample_key], sample_key, self.lambda_denoising)
+ forward_backward(
+ model.models["{0}-{0}".format(tgt)],
+ sample[sample_key],
+ sample_key,
+ self.lambda_denoising,
+ )
return agg_loss, agg_sample_size, agg_logging_output
@@ -367,7 +459,11 @@ def lambda_step_func(config, n_iter):
"""
Update a lambda value according to its schedule configuration.
"""
- ranges = [i for i in range(len(config) - 1) if config[i][0] <= n_iter < config[i + 1][0]]
+ ranges = [
+ i
+ for i in range(len(config) - 1)
+ if config[i][0] <= n_iter < config[i + 1][0]
+ ]
if len(ranges) == 0:
assert n_iter >= config[-1][0]
return config[-1][1]
@@ -378,8 +474,12 @@ def lambda_step_func(config, n_iter):
return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a)
if self.lambda_parallel_steps is not None:
- self.lambda_parallel = lambda_step_func(self.lambda_parallel_steps, num_updates)
+ self.lambda_parallel = lambda_step_func(
+ self.lambda_parallel_steps, num_updates
+ )
if self.lambda_denoising_steps is not None:
- self.lambda_denoising = lambda_step_func(self.lambda_denoising_steps, num_updates)
+ self.lambda_denoising = lambda_step_func(
+ self.lambda_denoising_steps, num_updates
+ )
if self.lambda_otf_bt_steps is not None:
self.lambda_otf_bt = lambda_step_func(self.lambda_otf_bt_steps, num_updates)
diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py
index d9a82faddd..69dc996e6a 100644
--- a/fairseq/tasks/sentence_prediction.py
+++ b/fairseq/tasks/sentence_prediction.py
@@ -7,16 +7,14 @@
import os
import numpy as np
-
from fairseq import utils
from fairseq.data import (
ConcatSentencesDataset,
- data_utils,
Dictionary,
IdDataset,
NestedDictionaryDataset,
- NumSamplesDataset,
NumelDataset,
+ NumSamplesDataset,
OffsetTokensDataset,
PrependTokenDataset,
RawLabelDataset,
@@ -24,15 +22,16 @@
RollDataset,
SortDataset,
StripTokenDataset,
+ data_utils,
)
-from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.data.shorten_dataset import maybe_shorten_dataset
+from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
-@register_task('sentence_prediction')
+@register_task("sentence_prediction")
class SentencePredictionTask(LegacyFairseqTask):
"""
Sentence (or sentence pair) prediction (classification or regression) task.
@@ -44,30 +43,51 @@ class SentencePredictionTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
- parser.add_argument('data', metavar='FILE',
- help='file prefix for data')
- parser.add_argument('--num-classes', type=int, default=-1,
- help='number of classes or regression targets')
- parser.add_argument('--init-token', type=int, default=None,
- help='add token at the beginning of each batch item')
- parser.add_argument('--separator-token', type=int, default=None,
- help='add separator token between inputs')
- parser.add_argument('--regression-target', action='store_true', default=False)
- parser.add_argument('--no-shuffle', action='store_true', default=False)
- parser.add_argument('--shorten-method', default='none',
- choices=['none', 'truncate', 'random_crop'],
- help='if not none, shorten sequences that exceed --tokens-per-sample')
- parser.add_argument('--shorten-data-split-list', default='',
- help='comma-separated list of dataset splits to apply shortening to, '
- 'e.g., "train,valid" (default: all dataset splits)')
- parser.add_argument('--add-prev-output-tokens', action='store_true', default=False,
- help='add prev_output_tokens to sample, used for encoder-decoder arch')
+ parser.add_argument("data", metavar="FILE", help="file prefix for data")
+ parser.add_argument(
+ "--num-classes",
+ type=int,
+ default=-1,
+ help="number of classes or regression targets",
+ )
+ parser.add_argument(
+ "--init-token",
+ type=int,
+ default=None,
+ help="add token at the beginning of each batch item",
+ )
+ parser.add_argument(
+ "--separator-token",
+ type=int,
+ default=None,
+ help="add separator token between inputs",
+ )
+ parser.add_argument("--regression-target", action="store_true", default=False)
+ parser.add_argument("--no-shuffle", action="store_true", default=False)
+ parser.add_argument(
+ "--shorten-method",
+ default="none",
+ choices=["none", "truncate", "random_crop"],
+ help="if not none, shorten sequences that exceed --tokens-per-sample",
+ )
+ parser.add_argument(
+ "--shorten-data-split-list",
+ default="",
+ help="comma-separated list of dataset splits to apply shortening to, "
+ 'e.g., "train,valid" (default: all dataset splits)',
+ )
+ parser.add_argument(
+ "--add-prev-output-tokens",
+ action="store_true",
+ default=False,
+ help="add prev_output_tokens to sample, used for encoder-decoder arch",
+ )
def __init__(self, args, data_dictionary, label_dictionary):
super().__init__(args)
self.dictionary = data_dictionary
self._label_dictionary = label_dictionary
- if not hasattr(args, 'max_positions'):
+ if not hasattr(args, "max_positions"):
self._max_positions = (
args.max_source_positions,
args.max_target_positions,
@@ -84,36 +104,37 @@ def load_dictionary(cls, args, filename, source=True):
filename (str): the filename
"""
dictionary = Dictionary.load(filename)
- dictionary.add_symbol('')
+ dictionary.add_symbol("")
return dictionary
@classmethod
def setup_task(cls, args, **kwargs):
- assert args.num_classes > 0, 'Must set --num-classes'
+ assert args.num_classes > 0, "Must set --num-classes"
# load data dictionary
data_dict = cls.load_dictionary(
args,
- os.path.join(args.data, 'input0', 'dict.txt'),
+ os.path.join(args.data, "input0", "dict.txt"),
source=True,
)
- logger.info('[input] dictionary: {} types'.format(len(data_dict)))
+ logger.info("[input] dictionary: {} types".format(len(data_dict)))
label_dict = None
if not args.regression_target:
# load label dictionary
label_dict = cls.load_dictionary(
args,
- os.path.join(args.data, 'label', 'dict.txt'),
+ os.path.join(args.data, "label", "dict.txt"),
source=False,
)
- logger.info('[label] dictionary: {} types'.format(len(label_dict)))
+ logger.info("[label] dictionary: {} types".format(len(label_dict)))
else:
label_dict = data_dict
return cls(args, data_dict, label_dict)
def load_dataset(self, split, combine=False, **kwargs):
"""Load a given dataset split (e.g., train, valid, test)."""
+
def get_path(type, split):
return os.path.join(self.args.data, type, split)
@@ -128,9 +149,11 @@ def make_dataset(type, dictionary):
)
return dataset
- input0 = make_dataset('input0', self.source_dictionary)
- assert input0 is not None, 'could not find dataset: {}'.format(get_path(type, split))
- input1 = make_dataset('input1', self.source_dictionary)
+ input0 = make_dataset("input0", self.source_dictionary)
+ assert input0 is not None, "could not find dataset: {}".format(
+ get_path(type, split)
+ )
+ input1 = make_dataset("input1", self.source_dictionary)
if self.args.init_token is not None:
input0 = PrependTokenDataset(input0, self.args.init_token)
@@ -156,16 +179,16 @@ def make_dataset(type, dictionary):
)
dataset = {
- 'id': IdDataset(),
- 'net_input': {
- 'src_tokens': RightPadDataset(
+ "id": IdDataset(),
+ "net_input": {
+ "src_tokens": RightPadDataset(
src_tokens,
pad_idx=self.source_dictionary.pad(),
),
- 'src_lengths': NumelDataset(src_tokens, reduce=False),
+ "src_lengths": NumelDataset(src_tokens, reduce=False),
},
- 'nsentences': NumSamplesDataset(),
- 'ntokens': NumelDataset(src_tokens, reduce=True),
+ "nsentences": NumSamplesDataset(),
+ "ntokens": NumelDataset(src_tokens, reduce=True),
}
if self.args.add_prev_output_tokens:
@@ -173,12 +196,12 @@ def make_dataset(type, dictionary):
RollDataset(src_tokens, 1),
pad_idx=self.dictionary.pad(),
)
- dataset['net_input'].update(
+ dataset["net_input"].update(
prev_output_tokens=prev_tokens_dataset,
)
if not self.args.regression_target:
- label_dataset = make_dataset('label', self.label_dictionary)
+ label_dataset = make_dataset("label", self.label_dictionary)
if label_dataset is not None:
dataset.update(
target=OffsetTokensDataset(
@@ -190,21 +213,24 @@ def make_dataset(type, dictionary):
)
)
else:
- label_path = "{0}.label".format(get_path('label', split))
+ label_path = "{0}.label".format(get_path("label", split))
if os.path.exists(label_path):
def parse_regression_target(i, line):
values = line.split()
- assert len(values) == self.args.num_classes, \
- f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"'
+ assert (
+ len(values) == self.args.num_classes
+ ), f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"'
return [float(x) for x in values]
with open(label_path) as h:
dataset.update(
- target=RawLabelDataset([
- parse_regression_target(i, line.strip())
- for i, line in enumerate(h.readlines())
- ])
+ target=RawLabelDataset(
+ [
+ parse_regression_target(i, line.strip())
+ for i, line in enumerate(h.readlines())
+ ]
+ )
)
nested_dataset = NestedDictionaryDataset(
@@ -228,10 +254,11 @@ def parse_regression_target(i, line):
def build_model(self, args):
from fairseq import models
+
model = models.build_model(args, self)
model.register_classification_head(
- getattr(args, 'classification_head_name', 'sentence_classification_head'),
+ getattr(args, "classification_head_name", "sentence_classification_head"),
num_classes=self.args.num_classes,
)
diff --git a/fairseq/tasks/sentence_ranking.py b/fairseq/tasks/sentence_ranking.py
index a1d332a3ca..bed44f34e5 100644
--- a/fairseq/tasks/sentence_ranking.py
+++ b/fairseq/tasks/sentence_ranking.py
@@ -7,30 +7,29 @@
import os
import numpy as np
-
from fairseq import utils
from fairseq.data import (
ConcatSentencesDataset,
- data_utils,
Dictionary,
IdDataset,
NestedDictionaryDataset,
- NumSamplesDataset,
NumelDataset,
+ NumSamplesDataset,
PrependTokenDataset,
RawLabelDataset,
RightPadDataset,
SortDataset,
- TruncateDataset
+ TruncateDataset,
+ data_utils,
)
-from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.data.shorten_dataset import maybe_shorten_dataset
+from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
-@register_task('sentence_ranking')
+@register_task("sentence_ranking")
class SentenceRankingTask(LegacyFairseqTask):
"""
Ranking task on multiple sentences.
@@ -42,23 +41,34 @@ class SentenceRankingTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
- parser.add_argument('data', metavar='FILE',
- help='file prefix for data')
- parser.add_argument('--num-classes', type=int,
- help='number of sentences to be ranked')
- parser.add_argument('--init-token', type=int,
- help='add token at the beginning of each batch item')
- parser.add_argument('--separator-token', type=int,
- help='add separator token between inputs')
- parser.add_argument('--no-shuffle', action='store_true')
- parser.add_argument('--shorten-method', default='none',
- choices=['none', 'truncate', 'random_crop'],
- help='if not none, shorten sequences that exceed --tokens-per-sample')
- parser.add_argument('--shorten-data-split-list', default='',
- help='comma-separated list of dataset splits to apply shortening to, '
- 'e.g., "train,valid" (default: all dataset splits)')
- parser.add_argument('--max-option-length', type=int,
- help='max length for each option')
+ parser.add_argument("data", metavar="FILE", help="file prefix for data")
+ parser.add_argument(
+ "--num-classes", type=int, help="number of sentences to be ranked"
+ )
+ parser.add_argument(
+ "--init-token",
+ type=int,
+ help="add token at the beginning of each batch item",
+ )
+ parser.add_argument(
+ "--separator-token", type=int, help="add separator token between inputs"
+ )
+ parser.add_argument("--no-shuffle", action="store_true")
+ parser.add_argument(
+ "--shorten-method",
+ default="none",
+ choices=["none", "truncate", "random_crop"],
+ help="if not none, shorten sequences that exceed --tokens-per-sample",
+ )
+ parser.add_argument(
+ "--shorten-data-split-list",
+ default="",
+ help="comma-separated list of dataset splits to apply shortening to, "
+ 'e.g., "train,valid" (default: all dataset splits)',
+ )
+ parser.add_argument(
+ "--max-option-length", type=int, help="max length for each option"
+ )
def __init__(self, args, dictionary):
super().__init__(args)
@@ -72,21 +82,22 @@ def load_dictionary(cls, args, filename, source=True):
filename (str): the filename
"""
dictionary = Dictionary.load(filename)
- dictionary.add_symbol('')
+ dictionary.add_symbol("")
return dictionary
@classmethod
def setup_task(cls, args, **kwargs):
- assert args.criterion == 'sentence_ranking', \
- 'Must set --criterion=sentence_ranking'
+ assert (
+ args.criterion == "sentence_ranking"
+ ), "Must set --criterion=sentence_ranking"
# load data dictionary
data_dict = cls.load_dictionary(
args,
- os.path.join(args.data, 'input0', 'dict.txt'),
+ os.path.join(args.data, "input0", "dict.txt"),
source=True,
)
- logger.info('[input] dictionary: {} types'.format(len(data_dict)))
+ logger.info("[input] dictionary: {} types".format(len(data_dict)))
return SentenceRankingTask(args, data_dict)
def load_dataset(self, split, combine=False, **kwargs):
@@ -106,12 +117,9 @@ def make_dataset(type, dictionary):
)
return dataset
- input0 = make_dataset('input0', self.source_dictionary)
+ input0 = make_dataset("input0", self.source_dictionary)
input_options = [
- make_dataset(
- 'input{idx}'.format(idx=idx + 1),
- self.source_dictionary
- )
+ make_dataset("input{idx}".format(idx=idx + 1), self.source_dictionary)
for idx in range(self.args.num_classes)
]
@@ -123,7 +131,9 @@ def make_dataset(type, dictionary):
if self.args.init_token is not None:
input_option = PrependTokenDataset(input_option, self.args.init_token)
if self.args.max_option_length is not None:
- input_option = TruncateDataset(input_option, self.args.max_option_length)
+ input_option = TruncateDataset(
+ input_option, self.args.max_option_length
+ )
src_token = ConcatSentencesDataset(input_option, input0)
src_token = maybe_shorten_dataset(
src_token,
@@ -139,31 +149,31 @@ def make_dataset(type, dictionary):
shuffle = np.random.permutation(len(src_tokens[0]))
dataset = {
- 'id': IdDataset(),
- 'nsentences': NumSamplesDataset(),
- 'ntokens': NumelDataset(src_tokens[0], reduce=True),
+ "id": IdDataset(),
+ "nsentences": NumSamplesDataset(),
+ "ntokens": NumelDataset(src_tokens[0], reduce=True),
}
for src_token_idx in range(len(src_tokens)):
dataset.update(
{
- 'net_input{idx}'.format(idx=src_token_idx+1): {
- 'src_tokens': RightPadDataset(
+ "net_input{idx}".format(idx=src_token_idx + 1): {
+ "src_tokens": RightPadDataset(
src_tokens[src_token_idx],
pad_idx=self.source_dictionary.pad(),
),
- 'src_lengths': NumelDataset(src_tokens[src_token_idx], reduce=False),
+ "src_lengths": NumelDataset(
+ src_tokens[src_token_idx], reduce=False
+ ),
}
}
)
- label_path = '{}.label'.format(get_path('label', split))
+ label_path = "{}.label".format(get_path("label", split))
if os.path.exists(label_path):
with open(label_path) as h:
dataset.update(
- target=RawLabelDataset([
- int(x.strip()) for x in h.readlines()
- ])
+ target=RawLabelDataset([int(x.strip()) for x in h.readlines()])
)
nested_dataset = NestedDictionaryDataset(
@@ -187,10 +197,11 @@ def make_dataset(type, dictionary):
def build_model(self, args):
from fairseq import models
+
model = models.build_model(args, self)
model.register_classification_head(
- getattr(args, 'ranking_head_name', 'sentence_classification_head'),
+ getattr(args, "ranking_head_name", "sentence_classification_head"),
num_classes=1,
)
diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py
index b17ad22602..6d222f0de3 100644
--- a/fairseq/tasks/speech_to_text.py
+++ b/fairseq/tasks/speech_to_text.py
@@ -4,38 +4,51 @@
# LICENSE file in the root directory of this source tree.
import logging
-from argparse import Namespace
import os.path as op
+from argparse import Namespace
-from fairseq.data import encoders, Dictionary
+from fairseq.data import Dictionary, encoders
from fairseq.data.audio.speech_to_text_dataset import (
- SpeechToTextDataset, SpeechToTextDatasetCreator, S2TDataConfig
+ S2TDataConfig,
+ SpeechToTextDataset,
+ SpeechToTextDatasetCreator,
)
from fairseq.tasks import FairseqTask, register_task
+
logging.basicConfig(
- format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
- datefmt='%Y-%m-%d %H:%M:%S',
- level=logging.INFO,
- )
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO,
+)
logger = logging.getLogger(__name__)
-@register_task('speech_to_text')
+@register_task("speech_to_text")
class SpeechToTextTask(FairseqTask):
@staticmethod
def add_args(parser):
- parser.add_argument('data', help='manifest root path')
+ parser.add_argument("data", help="manifest root path")
+ parser.add_argument(
+ "--config-yaml",
+ type=str,
+ default="config.yaml",
+ help="Configuration YAML filename (under manifest root)",
+ )
+ parser.add_argument(
+ "--max-source-positions",
+ default=6000,
+ type=int,
+ metavar="N",
+ help="max number of tokens in the source sequence",
+ )
parser.add_argument(
- '--config-yaml', type=str, default='config.yaml',
- help='Configuration YAML filename (under manifest root)'
+ "--max-target-positions",
+ default=1024,
+ type=int,
+ metavar="N",
+ help="max number of tokens in the target sequence",
)
- parser.add_argument('--max-source-positions', default=6000, type=int,
- metavar='N',
- help='max number of tokens in the source sequence')
- parser.add_argument('--max-target-positions', default=1024, type=int,
- metavar='N',
- help='max number of tokens in the target sequence')
def __init__(self, args, tgt_dict):
super().__init__(args)
@@ -47,31 +60,41 @@ def setup_task(cls, args, **kwargs):
data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml))
dict_path = op.join(args.data, data_cfg.vocab_filename)
if not op.isfile(dict_path):
- raise FileNotFoundError(f'Dict not found: {dict_path}')
+ raise FileNotFoundError(f"Dict not found: {dict_path}")
tgt_dict = Dictionary.load(dict_path)
- logger.info(f'dictionary size ({data_cfg.vocab_filename}): '
- f'{len(tgt_dict):,}')
+ logger.info(
+ f"dictionary size ({data_cfg.vocab_filename}): " f"{len(tgt_dict):,}"
+ )
- if getattr(args, 'train_subset', None) is not None:
- if not all(s.startswith('train') for s in args.train_subset.split(',')):
+ if getattr(args, "train_subset", None) is not None:
+ if not all(s.startswith("train") for s in args.train_subset.split(",")):
raise ValueError('Train splits should be named like "train*".')
return cls(args, tgt_dict)
def build_criterion(self, args):
from fairseq import criterions
+
if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1:
- raise ValueError('Please set "--ignore-prefix-size 1" since '
- 'target language ID token is prepended as BOS.')
+ raise ValueError(
+ 'Please set "--ignore-prefix-size 1" since '
+ "target language ID token is prepended as BOS."
+ )
return criterions.build_criterion(args, self)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
- is_train_split = split.startswith('train')
+ is_train_split = split.startswith("train")
pre_tokenizer = self.build_tokenizer(self.args)
bpe_tokenizer = self.build_bpe(self.args)
self.datasets[split] = SpeechToTextDatasetCreator.from_tsv(
- self.args.data, self.data_cfg, split, self.tgt_dict,
- pre_tokenizer, bpe_tokenizer, is_train_split=is_train_split,
- epoch=epoch, seed=self.args.seed
+ self.args.data,
+ self.data_cfg,
+ split,
+ self.tgt_dict,
+ pre_tokenizer,
+ bpe_tokenizer,
+ is_train_split=is_train_split,
+ epoch=epoch,
+ seed=self.args.seed,
)
@property
@@ -91,30 +114,35 @@ def build_model(self, args):
return super(SpeechToTextTask, self).build_model(args)
def build_generator(
- self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None,
+ self,
+ models,
+ args,
+ seq_gen_cls=None,
+ extra_gen_cls_kwargs=None,
):
if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
- raise ValueError('Please set "--prefix-size 1" since '
- 'target language ID token is prepended as BOS.')
+ raise ValueError(
+ 'Please set "--prefix-size 1" since '
+ "target language ID token is prepended as BOS."
+ )
lang_token_ids = {
- i for s, i in self.tgt_dict.indices.items()
+ i
+ for s, i in self.tgt_dict.indices.items()
if SpeechToTextDataset.is_lang_tag(s)
}
- extra_gen_cls_kwargs = {'symbols_to_strip_from_output': lang_token_ids}
+ extra_gen_cls_kwargs = {"symbols_to_strip_from_output": lang_token_ids}
return super().build_generator(
- models, args, seq_gen_cls=None,
- extra_gen_cls_kwargs=extra_gen_cls_kwargs
+ models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs
)
def build_tokenizer(self, args):
- logger.info(f'pre-tokenizer: {self.data_cfg.pre_tokenizer}')
+ logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}")
return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer))
def build_bpe(self, args):
- logger.info(f'tokenizer: {self.data_cfg.bpe_tokenizer}')
+ logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}")
return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer))
@classmethod
def build_dataset_for_inference(cls, audio_paths, n_frames):
- return SpeechToTextDataset('interactive', False, {}, audio_paths,
- n_frames)
+ return SpeechToTextDataset("interactive", False, {}, audio_paths, n_frames)
diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py
index a04924605c..79007a6d9f 100644
--- a/fairseq/tasks/translation.py
+++ b/fairseq/tasks/translation.py
@@ -3,28 +3,27 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from argparse import Namespace
-import json
import itertools
+import json
import logging
import os
-from fairseq import options
-import numpy as np
+from argparse import Namespace
-from fairseq import metrics, utils
+import numpy as np
+from fairseq import metrics, options, utils
from fairseq.data import (
AppendTokenDataset,
ConcatDataset,
- data_utils,
- encoders,
- indexed_dataset,
LanguagePairDataset,
PrependTokenDataset,
StripTokenDataset,
TruncateDataset,
+ data_utils,
+ encoders,
+ indexed_dataset,
)
+from fairseq.tasks import LegacyFairseqTask, register_task
-from fairseq.tasks import register_task, LegacyFairseqTask
EVAL_BLEU_ORDER = 4
@@ -33,40 +32,53 @@
def load_langpair_dataset(
- data_path, split,
- src, src_dict,
- tgt, tgt_dict,
- combine, dataset_impl, upsample_primary,
- left_pad_source, left_pad_target, max_source_positions,
- max_target_positions, prepend_bos=False, load_alignments=False,
- truncate_source=False, append_source_id=False,
+ data_path,
+ split,
+ src,
+ src_dict,
+ tgt,
+ tgt_dict,
+ combine,
+ dataset_impl,
+ upsample_primary,
+ left_pad_source,
+ left_pad_target,
+ max_source_positions,
+ max_target_positions,
+ prepend_bos=False,
+ load_alignments=False,
+ truncate_source=False,
+ append_source_id=False,
num_buckets=0,
shuffle=True,
pad_to_multiple=1,
):
-
def split_exists(split, src, tgt, lang, data_path):
- filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
+ filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
src_datasets = []
tgt_datasets = []
for k in itertools.count():
- split_k = split + (str(k) if k > 0 else '')
+ split_k = split + (str(k) if k > 0 else "")
# infer langcode
if split_exists(split_k, src, tgt, src, data_path):
- prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
+ prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
elif split_exists(split_k, tgt, src, src, data_path):
- prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
+ prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
else:
if k > 0:
break
else:
- raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
+ raise FileNotFoundError(
+ "Dataset not found: {} ({})".format(split, data_path)
+ )
- src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
+ src_dataset = data_utils.load_indexed_dataset(
+ prefix + src, src_dict, dataset_impl
+ )
if truncate_source:
src_dataset = AppendTokenDataset(
TruncateDataset(
@@ -77,13 +89,17 @@ def split_exists(split, src, tgt, lang, data_path):
)
src_datasets.append(src_dataset)
- tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
+ tgt_dataset = data_utils.load_indexed_dataset(
+ prefix + tgt, tgt_dict, dataset_impl
+ )
if tgt_dataset is not None:
tgt_datasets.append(tgt_dataset)
- logger.info('{} {} {}-{} {} examples'.format(
- data_path, split_k, src, tgt, len(src_datasets[-1])
- ))
+ logger.info(
+ "{} {} {}-{} {} examples".format(
+ data_path, split_k, src, tgt, len(src_datasets[-1])
+ )
+ )
if not combine:
break
@@ -110,31 +126,42 @@ def split_exists(split, src, tgt, lang, data_path):
eos = None
if append_source_id:
- src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(src)))
+ src_dataset = AppendTokenDataset(
+ src_dataset, src_dict.index("[{}]".format(src))
+ )
if tgt_dataset is not None:
- tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
- eos = tgt_dict.index('[{}]'.format(tgt))
+ tgt_dataset = AppendTokenDataset(
+ tgt_dataset, tgt_dict.index("[{}]".format(tgt))
+ )
+ eos = tgt_dict.index("[{}]".format(tgt))
align_dataset = None
if load_alignments:
- align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
+ align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
- align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)
+ align_dataset = data_utils.load_indexed_dataset(
+ align_path, None, dataset_impl
+ )
tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
return LanguagePairDataset(
- src_dataset, src_dataset.sizes, src_dict,
- tgt_dataset, tgt_dataset_sizes, tgt_dict,
+ src_dataset,
+ src_dataset.sizes,
+ src_dict,
+ tgt_dataset,
+ tgt_dataset_sizes,
+ tgt_dict,
left_pad_source=left_pad_source,
left_pad_target=left_pad_target,
- align_dataset=align_dataset, eos=eos,
+ align_dataset=align_dataset,
+ eos=eos,
num_buckets=num_buckets,
shuffle=shuffle,
pad_to_multiple=pad_to_multiple,
)
-@register_task('translation')
+@register_task("translation")
class TranslationTask(LegacyFairseqTask):
"""
Translate from one (source) language to another (target) language.
@@ -227,18 +254,26 @@ def setup_task(cls, args, **kwargs):
assert len(paths) > 0
# find language pair automatically
if args.source_lang is None or args.target_lang is None:
- args.source_lang, args.target_lang = data_utils.infer_language_pair(paths[0])
+ args.source_lang, args.target_lang = data_utils.infer_language_pair(
+ paths[0]
+ )
if args.source_lang is None or args.target_lang is None:
- raise Exception('Could not infer language pair, please provide it explicitly')
+ raise Exception(
+ "Could not infer language pair, please provide it explicitly"
+ )
# load dictionaries
- src_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.source_lang)))
- tgt_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.target_lang)))
+ src_dict = cls.load_dictionary(
+ os.path.join(paths[0], "dict.{}.txt".format(args.source_lang))
+ )
+ tgt_dict = cls.load_dictionary(
+ os.path.join(paths[0], "dict.{}.txt".format(args.target_lang))
+ )
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
- logger.info('[{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
- logger.info('[{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))
+ logger.info("[{}] dictionary: {} types".format(args.source_lang, len(src_dict)))
+ logger.info("[{}] dictionary: {} types".format(args.target_lang, len(tgt_dict)))
return cls(args, src_dict, tgt_dict)
@@ -259,8 +294,14 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
src, tgt = self.args.source_lang, self.args.target_lang
self.datasets[split] = load_langpair_dataset(
- data_path, split, src, self.src_dict, tgt, self.tgt_dict,
- combine=combine, dataset_impl=self.args.dataset_impl,
+ data_path,
+ split,
+ src,
+ self.src_dict,
+ tgt,
+ self.tgt_dict,
+ combine=combine,
+ dataset_impl=self.args.dataset_impl,
upsample_primary=self.args.upsample_primary,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
@@ -269,45 +310,52 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
load_alignments=self.args.load_alignments,
truncate_source=self.args.truncate_source,
num_buckets=self.args.num_batch_buckets,
- shuffle=(split != 'test'),
+ shuffle=(split != "test"),
pad_to_multiple=self.args.required_seq_len_multiple,
)
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
- return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary,
- tgt_dict=self.target_dictionary,
- constraints=constraints)
+ return LanguagePairDataset(
+ src_tokens,
+ src_lengths,
+ self.source_dictionary,
+ tgt_dict=self.target_dictionary,
+ constraints=constraints,
+ )
def build_model(self, args):
model = super().build_model(args)
- if getattr(args, 'eval_bleu', False):
- assert getattr(args, 'eval_bleu_detok', None) is not None, (
- '--eval-bleu-detok is required if using --eval-bleu; '
- 'try --eval-bleu-detok=moses (or --eval-bleu-detok=space '
- 'to disable detokenization, e.g., when using sentencepiece)'
+ if getattr(args, "eval_bleu", False):
+ assert getattr(args, "eval_bleu_detok", None) is not None, (
+ "--eval-bleu-detok is required if using --eval-bleu; "
+ "try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
+ "to disable detokenization, e.g., when using sentencepiece)"
+ )
+ detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}")
+ self.tokenizer = encoders.build_tokenizer(
+ Namespace(
+ tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args
+ )
+ )
+
+ gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}")
+ self.sequence_generator = self.build_generator(
+ [model], Namespace(**gen_args)
)
- detok_args = json.loads(getattr(args, 'eval_bleu_detok_args', '{}') or '{}')
- self.tokenizer = encoders.build_tokenizer(Namespace(
- tokenizer=getattr(args, 'eval_bleu_detok', None),
- **detok_args
- ))
-
- gen_args = json.loads(getattr(args, 'eval_bleu_args', '{}') or '{}')
- self.sequence_generator = self.build_generator([model], Namespace(**gen_args))
return model
def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
if self.args.eval_bleu:
bleu = self._inference_with_bleu(self.sequence_generator, sample, model)
- logging_output['_bleu_sys_len'] = bleu.sys_len
- logging_output['_bleu_ref_len'] = bleu.ref_len
+ logging_output["_bleu_sys_len"] = bleu.sys_len
+ logging_output["_bleu_ref_len"] = bleu.ref_len
# we split counts into separate entries so that they can be
# summed efficiently across workers using fast-stat-sync
assert len(bleu.counts) == EVAL_BLEU_ORDER
for i in range(EVAL_BLEU_ORDER):
- logging_output['_bleu_counts_' + str(i)] = bleu.counts[i]
- logging_output['_bleu_totals_' + str(i)] = bleu.totals[i]
+ logging_output["_bleu_counts_" + str(i)] = bleu.counts[i]
+ logging_output["_bleu_totals_" + str(i)] = bleu.totals[i]
return loss, sample_size, logging_output
def reduce_metrics(self, logging_outputs, criterion):
@@ -319,34 +367,35 @@ def sum_logs(key):
counts, totals = [], []
for i in range(EVAL_BLEU_ORDER):
- counts.append(sum_logs('_bleu_counts_' + str(i)))
- totals.append(sum_logs('_bleu_totals_' + str(i)))
+ counts.append(sum_logs("_bleu_counts_" + str(i)))
+ totals.append(sum_logs("_bleu_totals_" + str(i)))
if max(totals) > 0:
# log counts as numpy arrays -- log_scalar will sum them correctly
- metrics.log_scalar('_bleu_counts', np.array(counts))
- metrics.log_scalar('_bleu_totals', np.array(totals))
- metrics.log_scalar('_bleu_sys_len', sum_logs('_bleu_sys_len'))
- metrics.log_scalar('_bleu_ref_len', sum_logs('_bleu_ref_len'))
+ metrics.log_scalar("_bleu_counts", np.array(counts))
+ metrics.log_scalar("_bleu_totals", np.array(totals))
+ metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
+ metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))
def compute_bleu(meters):
import inspect
import sacrebleu
+
fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
- if 'smooth_method' in fn_sig:
- smooth = {'smooth_method': 'exp'}
+ if "smooth_method" in fn_sig:
+ smooth = {"smooth_method": "exp"}
else:
- smooth = {'smooth': 'exp'}
+ smooth = {"smooth": "exp"}
bleu = sacrebleu.compute_bleu(
- correct=meters['_bleu_counts'].sum,
- total=meters['_bleu_totals'].sum,
- sys_len=meters['_bleu_sys_len'].sum,
- ref_len=meters['_bleu_ref_len'].sum,
+ correct=meters["_bleu_counts"].sum,
+ total=meters["_bleu_totals"].sum,
+ sys_len=meters["_bleu_sys_len"].sum,
+ ref_len=meters["_bleu_ref_len"].sum,
**smooth
)
return round(bleu.score, 2)
- metrics.log_derived('bleu', compute_bleu)
+ metrics.log_derived("bleu", compute_bleu)
def max_positions(self):
"""Return the max sentence length allowed by the task."""
@@ -374,9 +423,7 @@ def decode(toks, escape_unk=False):
# BLEU scores. Instead, we use a somewhat more verbose
# alternative that is unlikely to appear in the real
# reference, but doesn't get split into multiple tokens.
- unk_string=(
- "UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"
- ),
+ unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
)
if self.tokenizer:
s = self.tokenizer.decode(s)
@@ -385,15 +432,17 @@ def decode(toks, escape_unk=False):
gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None)
hyps, refs = [], []
for i in range(len(gen_out)):
- hyps.append(decode(gen_out[i][0]['tokens']))
- refs.append(decode(
- utils.strip_pad(sample['target'][i], self.tgt_dict.pad()),
- escape_unk=True, # don't count as matches to the hypo
- ))
+ hyps.append(decode(gen_out[i][0]["tokens"]))
+ refs.append(
+ decode(
+ utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
+ escape_unk=True, # don't count as matches to the hypo
+ )
+ )
if self.args.eval_bleu_print_samples:
- logger.info('example hypothesis: ' + hyps[0])
- logger.info('example reference: ' + refs[0])
+ logger.info("example hypothesis: " + hyps[0])
+ logger.info("example reference: " + refs[0])
if self.args.eval_tokenized_bleu:
- return sacrebleu.corpus_bleu(hyps, [refs], tokenize='none')
+ return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none")
else:
return sacrebleu.corpus_bleu(hyps, [refs])
diff --git a/fairseq/tasks/translation_from_pretrained_bart.py b/fairseq/tasks/translation_from_pretrained_bart.py
index 4d574ffc82..8710b7fe7d 100644
--- a/fairseq/tasks/translation_from_pretrained_bart.py
+++ b/fairseq/tasks/translation_from_pretrained_bart.py
@@ -4,15 +4,14 @@
# LICENSE file in the root directory of this source tree.
import torch
-
-from fairseq.data import LanguagePairDataset
from fairseq import utils
+from fairseq.data import LanguagePairDataset
-from .translation import load_langpair_dataset, TranslationTask
from . import register_task
+from .translation import TranslationTask, load_langpair_dataset
-@register_task('translation_from_pretrained_bart')
+@register_task("translation_from_pretrained_bart")
class TranslationFromPretrainedBARTTask(TranslationTask):
"""
Translate from source language to target language with a model initialized with a multilingual pretrain.
@@ -52,11 +51,11 @@ def add_args(parser):
def __init__(self, args, src_dict, tgt_dict):
super().__init__(args, src_dict, tgt_dict)
- self.langs = args.langs.split(',')
+ self.langs = args.langs.split(",")
for d in [src_dict, tgt_dict]:
for l in self.langs:
- d.add_symbol('[{}]'.format(l))
- d.add_symbol('')
+ d.add_symbol("[{}]".format(l))
+ d.add_symbol("")
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
@@ -72,50 +71,62 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
src, tgt = self.args.source_lang, self.args.target_lang
self.datasets[split] = load_langpair_dataset(
- data_path, split, src, self.src_dict, tgt, self.tgt_dict,
- combine=combine, dataset_impl=self.args.dataset_impl,
+ data_path,
+ split,
+ src,
+ self.src_dict,
+ tgt,
+ self.tgt_dict,
+ combine=combine,
+ dataset_impl=self.args.dataset_impl,
upsample_primary=self.args.upsample_primary,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
- max_source_positions=getattr(self.args, 'max_source_positions', 1024),
- max_target_positions=getattr(self.args, 'max_target_positions', 1024),
+ max_source_positions=getattr(self.args, "max_source_positions", 1024),
+ max_target_positions=getattr(self.args, "max_target_positions", 1024),
load_alignments=self.args.load_alignments,
- prepend_bos=getattr(self.args, 'prepend_bos', False),
- append_source_id=True
- )
+ prepend_bos=getattr(self.args, "prepend_bos", False),
+ append_source_id=True,
+ )
def build_generator(self, models, args, **unused):
- if getattr(args, 'score_reference', False):
+ if getattr(args, "score_reference", False):
from fairseq.sequence_scorer import SequenceScorer
+
return SequenceScorer(
self.target_dictionary,
- eos=self.tgt_dict.index('[{}]'.format(self.args.target_lang))
+ eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)),
)
else:
from fairseq.sequence_generator import SequenceGenerator
+
return SequenceGenerator(
models,
self.target_dictionary,
- beam_size=getattr(args, 'beam', 5),
- max_len_a=getattr(args, 'max_len_a', 0),
- max_len_b=getattr(args, 'max_len_b', 200),
- min_len=getattr(args, 'min_len', 1),
- normalize_scores=(not getattr(args, 'unnormalized', False)),
- len_penalty=getattr(args, 'lenpen', 1),
- unk_penalty=getattr(args, 'unkpen', 0),
- temperature=getattr(args, 'temperature', 1.),
- match_source_len=getattr(args, 'match_source_len', False),
- no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
- eos=self.tgt_dict.index('[{}]'.format(self.args.target_lang))
+ beam_size=getattr(args, "beam", 5),
+ max_len_a=getattr(args, "max_len_a", 0),
+ max_len_b=getattr(args, "max_len_b", 200),
+ min_len=getattr(args, "min_len", 1),
+ normalize_scores=(not getattr(args, "unnormalized", False)),
+ len_penalty=getattr(args, "lenpen", 1),
+ unk_penalty=getattr(args, "unkpen", 0),
+ temperature=getattr(args, "temperature", 1.0),
+ match_source_len=getattr(args, "match_source_len", False),
+ no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
+ eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)),
)
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
- src_lang_id = self.source_dictionary.index('[{}]'.format(self.args.source_lang))
+ src_lang_id = self.source_dictionary.index("[{}]".format(self.args.source_lang))
source_tokens = []
for s_t in src_tokens:
s_t = torch.cat([s_t, s_t.new(1).fill_(src_lang_id)])
source_tokens.append(s_t)
- dataset = LanguagePairDataset(source_tokens, src_lengths, self.source_dictionary,
- tgt_dict=self.target_dictionary,
- constraints=constraints)
+ dataset = LanguagePairDataset(
+ source_tokens,
+ src_lengths,
+ self.source_dictionary,
+ tgt_dict=self.target_dictionary,
+ constraints=constraints,
+ )
return dataset
diff --git a/fairseq/tasks/translation_lev.py b/fairseq/tasks/translation_lev.py
index 3af9bb2532..4678774922 100644
--- a/fairseq/tasks/translation_lev.py
+++ b/fairseq/tasks/translation_lev.py
@@ -6,15 +6,14 @@
import os
import torch
-
+from fairseq import utils
from fairseq.data import LanguagePairDataset
-
-from fairseq.utils import new_arange
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask, load_langpair_dataset
-from fairseq import utils
+from fairseq.utils import new_arange
+
-@register_task('translation_lev')
+@register_task("translation_lev")
class TranslationLevenshteinTask(TranslationTask):
"""
Translation (Sequence Generation) task for Levenshtein Transformer
@@ -46,8 +45,14 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
src, tgt = self.args.source_lang, self.args.target_lang
self.datasets[split] = load_langpair_dataset(
- data_path, split, src, self.src_dict, tgt, self.tgt_dict,
- combine=combine, dataset_impl=self.args.dataset_impl,
+ data_path,
+ split,
+ src,
+ self.src_dict,
+ tgt,
+ self.tgt_dict,
+ combine=combine,
+ dataset_impl=self.args.dataset_impl,
upsample_primary=self.args.upsample_primary,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
@@ -66,24 +71,32 @@ def _random_delete(target_tokens):
target_mask = target_tokens.eq(pad)
target_score = target_tokens.clone().float().uniform_()
target_score.masked_fill_(
- target_tokens.eq(bos) | target_tokens.eq(eos), 0.0)
+ target_tokens.eq(bos) | target_tokens.eq(eos), 0.0
+ )
target_score.masked_fill_(target_mask, 1)
target_score, target_rank = target_score.sort(1)
target_length = target_mask.size(1) - target_mask.float().sum(
- 1, keepdim=True)
+ 1, keepdim=True
+ )
# do not delete and (we assign 0 score for them)
- target_cutoff = 2 + ((target_length - 2) * target_score.new_zeros(
- target_score.size(0), 1).uniform_()).long()
+ target_cutoff = (
+ 2
+ + (
+ (target_length - 2)
+ * target_score.new_zeros(target_score.size(0), 1).uniform_()
+ ).long()
+ )
target_cutoff = target_score.sort(1)[1] >= target_cutoff
- prev_target_tokens = target_tokens.gather(
- 1, target_rank).masked_fill_(target_cutoff, pad).gather(
- 1,
- target_rank.masked_fill_(target_cutoff,
- max_len).sort(1)[1])
- prev_target_tokens = prev_target_tokens[:, :prev_target_tokens.
- ne(pad).sum(1).max()]
+ prev_target_tokens = (
+ target_tokens.gather(1, target_rank)
+ .masked_fill_(target_cutoff, pad)
+ .gather(1, target_rank.masked_fill_(target_cutoff, max_len).sort(1)[1])
+ )
+ prev_target_tokens = prev_target_tokens[
+ :, : prev_target_tokens.ne(pad).sum(1).max()
+ ]
return prev_target_tokens
@@ -93,9 +106,9 @@ def _random_mask(target_tokens):
eos = self.tgt_dict.eos()
unk = self.tgt_dict.unk()
- target_masks = target_tokens.ne(pad) & \
- target_tokens.ne(bos) & \
- target_tokens.ne(eos)
+ target_masks = (
+ target_tokens.ne(pad) & target_tokens.ne(bos) & target_tokens.ne(eos)
+ )
target_score = target_tokens.clone().float().uniform_()
target_score.masked_fill_(~target_masks, 2.0)
target_length = target_masks.sum(1).float()
@@ -105,7 +118,8 @@ def _random_mask(target_tokens):
_, target_rank = target_score.sort(1)
target_cutoff = new_arange(target_rank) < target_length[:, None].long()
prev_target_tokens = target_tokens.masked_fill(
- target_cutoff.scatter(1, target_rank, target_cutoff), unk)
+ target_cutoff.scatter(1, target_rank, target_cutoff), unk
+ )
return prev_target_tokens
def _full_mask(target_tokens):
@@ -114,17 +128,18 @@ def _full_mask(target_tokens):
eos = self.tgt_dict.eos()
unk = self.tgt_dict.unk()
- target_mask = target_tokens.eq(bos) | target_tokens.eq(
- eos) | target_tokens.eq(pad)
+ target_mask = (
+ target_tokens.eq(bos) | target_tokens.eq(eos) | target_tokens.eq(pad)
+ )
return target_tokens.masked_fill(~target_mask, unk)
- if self.args.noise == 'random_delete':
+ if self.args.noise == "random_delete":
return _random_delete(target_tokens)
- elif self.args.noise == 'random_mask':
+ elif self.args.noise == "random_mask":
return _random_mask(target_tokens)
- elif self.args.noise == 'full_mask':
+ elif self.args.noise == "full_mask":
return _full_mask(target_tokens)
- elif self.args.noise == 'no_noise':
+ elif self.args.noise == "no_noise":
return target_tokens
else:
raise NotImplementedError
@@ -132,34 +147,34 @@ def _full_mask(target_tokens):
def build_generator(self, models, args, **unused):
# add models input to match the API for SequenceGenerator
from fairseq.iterative_refinement_generator import IterativeRefinementGenerator
+
return IterativeRefinementGenerator(
self.target_dictionary,
- eos_penalty=getattr(args, 'iter_decode_eos_penalty', 0.0),
- max_iter=getattr(args, 'iter_decode_max_iter', 10),
- beam_size=getattr(args, 'iter_decode_with_beam', 1),
- reranking=getattr(args, 'iter_decode_with_external_reranker', False),
- decoding_format=getattr(args, 'decoding_format', None),
- adaptive=not getattr(args, 'iter_decode_force_max_iter', False),
- retain_history=getattr(args, 'retain_iter_history', False))
+ eos_penalty=getattr(args, "iter_decode_eos_penalty", 0.0),
+ max_iter=getattr(args, "iter_decode_max_iter", 10),
+ beam_size=getattr(args, "iter_decode_with_beam", 1),
+ reranking=getattr(args, "iter_decode_with_external_reranker", False),
+ decoding_format=getattr(args, "decoding_format", None),
+ adaptive=not getattr(args, "iter_decode_force_max_iter", False),
+ retain_history=getattr(args, "retain_iter_history", False),
+ )
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
if constraints is not None:
# Though see Susanto et al. (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.325/
- raise NotImplementedError("Constrained decoding with the translation_lev task is not supported")
+ raise NotImplementedError(
+ "Constrained decoding with the translation_lev task is not supported"
+ )
return LanguagePairDataset(
src_tokens, src_lengths, self.source_dictionary, append_bos=True
)
- def train_step(self,
- sample,
- model,
- criterion,
- optimizer,
- update_num,
- ignore_grad=False):
+ def train_step(
+ self, sample, model, criterion, optimizer, update_num, ignore_grad=False
+ ):
model.train()
- sample['prev_target'] = self.inject_noise(sample['target'])
+ sample["prev_target"] = self.inject_noise(sample["target"])
loss, sample_size, logging_output = criterion(model, sample)
if ignore_grad:
loss *= 0
@@ -169,6 +184,6 @@ def train_step(self,
def valid_step(self, sample, model, criterion):
model.eval()
with torch.no_grad():
- sample['prev_target'] = self.inject_noise(sample['target'])
+ sample["prev_target"] = self.inject_noise(sample["target"])
loss, sample_size, logging_output = criterion(model, sample)
return loss, sample_size, logging_output
diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py
index 960b82e1e8..95a2d162c0 100644
--- a/fairseq/tasks/translation_multi_simple_epoch.py
+++ b/fairseq/tasks/translation_multi_simple_epoch.py
@@ -3,34 +3,40 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import logging
import datetime
+import logging
import time
import torch
from fairseq.data import (
- data_utils,
FairseqDataset,
- iterators,
LanguagePairDataset,
ListDataset,
+ data_utils,
+ iterators,
+)
+from fairseq.data.multilingual.multilingual_data_manager import (
+ MultilingualDatasetManager,
)
-
-from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.data.multilingual.sampling_method import SamplingMethod
-from fairseq.data.multilingual.multilingual_data_manager import MultilingualDatasetManager
+from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.utils import FileContentsAction
+
###
def get_time_gap(s, e):
- return (datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)).__str__()
+ return (
+ datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)
+ ).__str__()
+
+
###
logger = logging.getLogger(__name__)
-@register_task('translation_multi_simple_epoch')
+@register_task("translation_multi_simple_epoch")
class TranslationMultiSimpleEpochTask(LegacyFairseqTask):
"""
Translate from one (source) language to another (target) language.
@@ -79,7 +85,7 @@ def __init__(self, args, langs, dicts, training):
if training:
self.lang_pairs = args.lang_pairs
else:
- self.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
+ self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
# eval_lang_pairs for multilingual translation is usually all of the
# lang_pairs. However for other multitask settings or when we want to
# optimize for certain languages we want to use a different subset. Thus
@@ -92,7 +98,8 @@ def __init__(self, args, langs, dicts, training):
self.model_lang_pairs = self.lang_pairs
self.sampling_method = SamplingMethod.build_sampler(args, self)
self.data_manager = MultilingualDatasetManager.setup_data_manager(
- args, self.lang_pairs, langs, dicts, self.sampling_method)
+ args, self.lang_pairs, langs, dicts, self.sampling_method
+ )
@classmethod
def setup_task(cls, args, **kwargs):
@@ -130,59 +137,67 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
else:
# estimate the shard epoch from virtual data size and virtual epoch size
shard_epoch = self.data_manager.estimate_global_pass_epoch(epoch)
- logger.info(f'loading data for {split} epoch={epoch}/{shard_epoch}')
+ logger.info(f"loading data for {split} epoch={epoch}/{shard_epoch}")
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
if split in self.datasets:
del self.datasets[split]
- logger.info('old dataset deleted manually')
+ logger.info("old dataset deleted manually")
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
self.datasets[split] = self.data_manager.load_sampled_multi_epoch_dataset(
split,
self.training,
- epoch=epoch, combine=combine, shard_epoch=shard_epoch, **kwargs
+ epoch=epoch,
+ combine=combine,
+ shard_epoch=shard_epoch,
+ **kwargs,
)
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
if constraints is not None:
- raise NotImplementedError("Constrained decoding with the multilingual_translation task is not supported")
+ raise NotImplementedError(
+ "Constrained decoding with the multilingual_translation task is not supported"
+ )
src_data = ListDataset(src_tokens, src_lengths)
dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary)
- src_langtok_spec, tgt_langtok_spec = self.args.langtoks['main']
+ src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"]
if self.args.lang_tok_replacing_bos_eos:
dataset = self.data_manager.alter_dataset_langtok(
- dataset,
- src_eos=self.source_dictionary.eos(),
- src_lang=self.args.source_lang,
- tgt_eos=self.target_dictionary.eos(),
- tgt_lang=self.args.target_lang,
- src_langtok_spec=src_langtok_spec,
- tgt_langtok_spec=tgt_langtok_spec,
- )
+ dataset,
+ src_eos=self.source_dictionary.eos(),
+ src_lang=self.args.source_lang,
+ tgt_eos=self.target_dictionary.eos(),
+ tgt_lang=self.args.target_lang,
+ src_langtok_spec=src_langtok_spec,
+ tgt_langtok_spec=tgt_langtok_spec,
+ )
else:
dataset.src = self.data_manager.src_dataset_tranform_func(
self.args.source_lang,
self.args.target_lang,
dataset=dataset.src,
spec=src_langtok_spec,
- )
+ )
return dataset
def build_generator(
- self, models, args,
- seq_gen_cls=None, extra_gen_cls_kwargs=None,
+ self,
+ models,
+ args,
+ seq_gen_cls=None,
+ extra_gen_cls_kwargs=None,
):
- if not getattr(args, 'keep_inference_langtok', False):
- _, tgt_langtok_spec = self.args.langtoks['main']
+ if not getattr(args, "keep_inference_langtok", False):
+ _, tgt_langtok_spec = self.args.langtoks["main"]
if tgt_langtok_spec:
- tgt_lang_tok = self.data_manager.get_decoder_langtok(self.args.target_lang, tgt_langtok_spec)
+ tgt_lang_tok = self.data_manager.get_decoder_langtok(
+ self.args.target_lang, tgt_langtok_spec
+ )
extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
- extra_gen_cls_kwargs['symbols_to_strip_from_output'] = {tgt_lang_tok}
+ extra_gen_cls_kwargs["symbols_to_strip_from_output"] = {tgt_lang_tok}
return super().build_generator(
- models, args,
- seq_gen_cls=None,
- extra_gen_cls_kwargs=extra_gen_cls_kwargs
+ models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs
)
def build_model(self, args):
@@ -192,30 +207,37 @@ def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
return loss, sample_size, logging_output
- def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None):
+ def inference_step(
+ self, generator, models, sample, prefix_tokens=None, constraints=None
+ ):
with torch.no_grad():
- _, tgt_langtok_spec = self.args.langtoks['main']
+ _, tgt_langtok_spec = self.args.langtoks["main"]
if not self.args.lang_tok_replacing_bos_eos:
if prefix_tokens is None and tgt_langtok_spec:
- tgt_lang_tok = self.data_manager.get_decoder_langtok(self.args.target_lang, tgt_langtok_spec)
- src_tokens = sample['net_input']['src_tokens']
+ tgt_lang_tok = self.data_manager.get_decoder_langtok(
+ self.args.target_lang, tgt_langtok_spec
+ )
+ src_tokens = sample["net_input"]["src_tokens"]
bsz = src_tokens.size(0)
- prefix_tokens = torch.LongTensor(
- [[tgt_lang_tok]]
- ).expand(bsz, 1).to(src_tokens)
+ prefix_tokens = (
+ torch.LongTensor([[tgt_lang_tok]]).expand(bsz, 1).to(src_tokens)
+ )
return generator.generate(
- models,
- sample,
- prefix_tokens=prefix_tokens,
- constraints=constraints,
+ models,
+ sample,
+ prefix_tokens=prefix_tokens,
+ constraints=constraints,
)
else:
return generator.generate(
- models,
- sample,
- prefix_tokens=prefix_tokens,
- bos_token=self.data_manager.get_decoder_langtok(self.args.target_lang, tgt_langtok_spec)
- if tgt_langtok_spec else self.target_dictionary.eos(),
+ models,
+ sample,
+ prefix_tokens=prefix_tokens,
+ bos_token=self.data_manager.get_decoder_langtok(
+ self.args.target_lang, tgt_langtok_spec
+ )
+ if tgt_langtok_spec
+ else self.target_dictionary.eos(),
)
def reduce_metrics(self, logging_outputs, criterion):
@@ -234,15 +256,18 @@ def target_dictionary(self):
return next(iter(self.dicts.values()))
def create_batch_sampler_func(
- self, max_positions, ignore_invalid_inputs,
- max_tokens, max_sentences,
+ self,
+ max_positions,
+ ignore_invalid_inputs,
+ max_tokens,
+ max_sentences,
required_batch_size_multiple=1,
seed=1,
):
- def construct_batch_sampler(
- dataset, epoch
- ):
- splits = [s for s, _ in self.datasets.items() if self.datasets[s] == dataset]
+ def construct_batch_sampler(dataset, epoch):
+ splits = [
+ s for s, _ in self.datasets.items() if self.datasets[s] == dataset
+ ]
split = splits[0] if len(splits) > 0 else None
# NEW implementation
if epoch is not None:
@@ -255,7 +280,9 @@ def construct_batch_sampler(
with data_utils.numpy_seed(seed):
indices = dataset.ordered_indices()
- logger.info(f'[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}')
+ logger.info(
+ f"[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}"
+ )
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
# filter examples that are too large
@@ -264,7 +291,9 @@ def construct_batch_sampler(
indices = self.filter_indices_by_size(
indices, dataset, max_positions, ignore_invalid_inputs
)
- logger.info(f'[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}')
+ logger.info(
+ f"[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}"
+ )
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
# create mini-batches with given size constraints
@@ -276,19 +305,34 @@ def construct_batch_sampler(
required_batch_size_multiple=required_batch_size_multiple,
)
- logger.info(f'[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}')
- logger.info(f'[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}')
+ logger.info(
+ f"[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}"
+ )
+ logger.info(
+ f"[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}"
+ )
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
return batch_sampler
+
return construct_batch_sampler
# we need to override get_batch_iterator because we want to reset the epoch iterator each time
def get_batch_iterator(
- self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
- ignore_invalid_inputs=False, required_batch_size_multiple=1,
- seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
- data_buffer_size=0, disable_iterator_cache=False,
+ self,
+ dataset,
+ max_tokens=None,
+ max_sentences=None,
+ max_positions=None,
+ ignore_invalid_inputs=False,
+ required_batch_size_multiple=1,
+ seed=1,
+ num_shards=1,
+ shard_id=0,
+ num_workers=0,
+ epoch=1,
+ data_buffer_size=0,
+ disable_iterator_cache=False,
):
"""
Get an iterator that yields batches of data from the given dataset.
@@ -329,9 +373,7 @@ def get_batch_iterator(
assert isinstance(dataset, FairseqDataset)
if dataset in self.dataset_to_epoch_iter:
return self.dataset_to_epoch_iter[dataset]
- if (
- self.args.sampling_method == 'RoundRobin'
- ):
+ if self.args.sampling_method == "RoundRobin":
batch_iter = super().get_batch_iterator(
dataset,
max_tokens=max_tokens,
@@ -351,8 +393,10 @@ def get_batch_iterator(
return batch_iter
construct_batch_sampler = self.create_batch_sampler_func(
- max_positions, ignore_invalid_inputs,
- max_tokens, max_sentences,
+ max_positions,
+ ignore_invalid_inputs,
+ max_tokens,
+ max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
seed=seed,
)
diff --git a/fairseq/token_generation_constraints.py b/fairseq/token_generation_constraints.py
index 7077199fd9..e708dc51bc 100644
--- a/fairseq/token_generation_constraints.py
+++ b/fairseq/token_generation_constraints.py
@@ -27,10 +27,11 @@
that many times in the output.
"""
+from collections import Counter
+from typing import List, Optional, Set, Tuple
+
import torch
-from collections import Counter
-from typing import Tuple, List, Optional, Set
class ConstraintState:
def __init__(self):
@@ -70,7 +71,11 @@ def pack_constraints(batch_constraints: List[List[torch.Tensor]]) -> torch.Tenso
for sentence_constraints in batch_constraints:
if len(sentence_constraints):
# number of constraints, plus sum of constrain lens, plus a zero after each
- constraints_len = 1 + sum([c.size(0) for c in sentence_constraints]) + len(sentence_constraints)
+ constraints_len = (
+ 1
+ + sum([c.size(0) for c in sentence_constraints])
+ + len(sentence_constraints)
+ )
max_constraints_len = max(max_constraints_len, constraints_len)
batch_size = len(batch_constraints)
@@ -80,7 +85,7 @@ def pack_constraints(batch_constraints: List[List[torch.Tensor]]) -> torch.Tenso
offset = 1
for j, constraint in enumerate(sentence_constraints):
this_len = constraint.size(0)
- constraints_tensor[i, offset:offset+this_len] = constraint
+ constraints_tensor[i, offset : offset + this_len] = constraint
offset += this_len + 1
return constraints_tensor.long()
@@ -107,6 +112,7 @@ class ConstraintNode:
"""
Represents a node in a trie managing unordered constraints.
"""
+
def __init__(self, token: int = None, parent=None):
# The token associate with this node (None for the root)
self.token = int(token) if token is not None else None
@@ -198,9 +204,8 @@ class UnorderedConstraintState(ConstraintState):
Records progress through the set of constraints for each item in the beam
using a trie.
"""
- def __init__(self,
- node: ConstraintNode,
- copy_from: "ConstraintState" = None):
+
+ def __init__(self, node: ConstraintNode, copy_from: "ConstraintState" = None):
self.node = node
if copy_from is None:
@@ -383,9 +388,8 @@ class OrderedConstraintState(ConstraintState):
"""
Records progress through the set of linear nonbranching constraints with gaps.
"""
- def __init__(self,
- sequence: ConstraintSequence,
- state: int = -1):
+
+ def __init__(self, sequence: ConstraintSequence, state: int = -1):
self.sequence = sequence
self.state = state
@@ -407,7 +411,9 @@ def copy(self):
def num_completed(self):
if self.state == -1:
return 0
- count = len(list(filter(lambda x: x, self.sequence.endpoints[0:self.state+1])))
+ count = len(
+ list(filter(lambda x: x, self.sequence.endpoints[0 : self.state + 1]))
+ )
return count
@property
diff --git a/fairseq/tokenizer.py b/fairseq/tokenizer.py
index 8c4d694aa0..42131f7b1d 100644
--- a/fairseq/tokenizer.py
+++ b/fairseq/tokenizer.py
@@ -5,6 +5,7 @@
import re
+
SPACE_NORMALIZER = re.compile(r"\s+")
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index 5d68783bfb..0069b79425 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -8,14 +8,13 @@
"""
import contextlib
-from itertools import chain
import logging
import sys
import time
+from itertools import chain
from typing import Any, Dict, List
import torch
-
from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics
@@ -43,20 +42,21 @@ def __init__(self, args, task, model, criterion, quantizer=None):
# catalog shared parameters
shared_params = _catalog_shared_params(model)
- self.tpu = getattr(args, 'tpu', False)
+ self.tpu = getattr(args, "tpu", False)
self.cuda = torch.cuda.is_available() and not args.cpu and not self.tpu
if self.cuda:
- self.device = torch.device('cuda')
+ self.device = torch.device("cuda")
elif self.tpu:
self.device = utils.get_tpu_device(args)
else:
- self.device = torch.device('cpu')
+ self.device = torch.device("cpu")
# copy model and criterion to current device/dtype
self._criterion = criterion
self._model = model
if self.tpu:
import torch_xla.core.xla_model as xm
+
self._model = xm.send_cpu_data_to_device(self._model, self.device)
if args.fp16:
self._criterion = self._criterion.half()
@@ -77,7 +77,7 @@ def __init__(self, args, task, model, criterion, quantizer=None):
ref = _get_module_by_path(self._model, shared_param[0])
for path in shared_param[1:]:
logger.info(
- 'detected shared parameter: {} <- {}'.format(shared_param[0], path)
+ "detected shared parameter: {} <- {}".format(shared_param[0], path)
)
_set_module_by_path(self._model, path, ref)
@@ -134,7 +134,7 @@ def data_parallel_world_size(self):
@property
def data_parallel_process_group(self):
if self.tpu:
- return ('tpu', None)
+ return ("tpu", None)
else:
return None
@@ -156,8 +156,9 @@ def criterion(self):
and not self.tpu
):
self._wrapped_criterion = models.DistributedFairseqModel(
- self.args, self._criterion,
- process_group=self.data_parallel_process_group
+ self.args,
+ self._criterion,
+ process_group=self.data_parallel_process_group,
)
else:
self._wrapped_criterion = self._criterion
@@ -172,8 +173,9 @@ def model(self):
and not self.tpu
):
self._wrapped_model = models.DistributedFairseqModel(
- self.args, self._model,
- process_group=self.data_parallel_process_group
+ self.args,
+ self._model,
+ process_group=self.data_parallel_process_group,
)
else:
self._wrapped_model = self._model
@@ -219,17 +221,20 @@ def _build_optimizer(self):
if self.args.use_bmuf:
self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)
- if self.args.zero_sharding == 'os':
- if (self.args.fp16
- and not self.args.memory_efficient_fp16
- and not self.args.memory_efficient_bf16
+ if self.args.zero_sharding == "os":
+ if (
+ self.args.fp16
+ and not self.args.memory_efficient_fp16
+ and not self.args.memory_efficient_bf16
) and not self.args.fp16_no_flatten_grads:
raise ValueError(
- "ZeRO is incomptabile with fp16 and flattened grads. "
- "Please use --fp16-no-flatten-grads"
+ "ZeRO is incomptabile with fp16 and flattened grads. "
+ "Please use --fp16-no-flatten-grads"
)
else:
- optim.shard_(self.args, self._optimizer, self.data_parallel_process_group)
+ optim.shard_(
+ self.args, self._optimizer, self.data_parallel_process_group
+ )
# We should initialize the learning rate scheduler immediately after
# building the optimizer, so that the initial learning rate is set.
@@ -416,7 +421,7 @@ def begin_epoch(self, epoch):
if self.tpu:
import torch_xla.core.xla_model as xm
- xm.rendezvous('begin_epoch') # wait for all workers
+ xm.rendezvous("begin_epoch") # wait for all workers
xm.mark_step()
def begin_valid_epoch(self, epoch):
@@ -511,13 +516,14 @@ def maybe_no_sync():
# To handle gradient accumulation use case, we explicitly
# mark step here for every forward pass without a backward pass
import torch_xla.core.xla_model as xm
+
xm.mark_step()
if is_dummy_batch:
if torch.is_tensor(sample_size):
sample_size.zero_()
else:
- sample_size *= 0.
+ sample_size *= 0.0
if torch.is_tensor(sample_size):
sample_size = sample_size.float()
@@ -527,27 +533,42 @@ def maybe_no_sync():
# gather logging outputs from all replicas
if self._sync_stats():
train_time = self._local_cumulative_training_time()
- logging_outputs, (sample_size, ooms, total_train_time) = self._aggregate_logging_outputs(
- logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch,
+ logging_outputs, (
+ sample_size,
+ ooms,
+ total_train_time,
+ ) = self._aggregate_logging_outputs(
+ logging_outputs,
+ sample_size,
+ ooms,
+ train_time,
+ ignore=is_dummy_batch,
+ )
+ self._cumulative_training_time = (
+ total_train_time / self.data_parallel_world_size
)
- self._cumulative_training_time = total_train_time / self.data_parallel_world_size
- if hasattr(self.model, 'all_reduce'):
+ if hasattr(self.model, "all_reduce"):
self.model.all_reduce()
overflow = False
try:
if self.tpu and self.data_parallel_world_size > 1:
import torch_xla.core.xla_model as xm
+
gradients = xm._fetch_gradients(self.optimizer.optimizer)
- xm.all_reduce('sum', gradients, scale=1.0 / self.data_parallel_world_size)
+ xm.all_reduce(
+ "sum", gradients, scale=1.0 / self.data_parallel_world_size
+ )
with torch.autograd.profiler.record_function("multiply-grads"):
# multiply gradients by (# GPUs / sample_size) since DDP
# already normalizes by the number of GPUs. Thus we get
# (sum_of_gradients / sample_size).
if not self.args.use_bmuf:
- self.optimizer.multiply_grads(self.data_parallel_world_size / sample_size)
+ self.optimizer.multiply_grads(
+ self.data_parallel_world_size / sample_size
+ )
elif sample_size > 0: # BMUF needs to check sample size
num = self.data_parallel_world_size if self._sync_stats() else 1
self.optimizer.multiply_grads(num / sample_size)
@@ -559,7 +580,7 @@ def maybe_no_sync():
# check that grad norms are consistent across workers
if (
not self.args.use_bmuf
- and self.args.distributed_wrapper != 'SlowMo'
+ and self.args.distributed_wrapper != "SlowMo"
and not self.tpu
):
self._check_grad_norms(grad_norm)
@@ -573,14 +594,18 @@ def maybe_no_sync():
# out where it fails
with NanDetector(self.get_model()):
self.task.train_step(
- sample, self.model, self.criterion, self.optimizer, self.get_num_updates(),
- ignore_grad=False
+ sample,
+ self.model,
+ self.criterion,
+ self.optimizer,
+ self.get_num_updates(),
+ ignore_grad=False,
)
raise
except OverflowError as e:
overflow = True
logger.info("NOTE: overflow detected, " + str(e))
- grad_norm = torch.tensor(0.).cuda()
+ grad_norm = torch.tensor(0.0).cuda()
self.zero_grad()
except RuntimeError as e:
if "out of memory" in str(e):
@@ -589,18 +614,23 @@ def maybe_no_sync():
raise e
# Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
- if hasattr(self.model, 'perform_additional_optimizer_actions'):
- if hasattr(self.optimizer, 'fp32_params'):
- self.model.perform_additional_optimizer_actions(self.optimizer.optimizer, self.optimizer.fp32_params)
+ if hasattr(self.model, "perform_additional_optimizer_actions"):
+ if hasattr(self.optimizer, "fp32_params"):
+ self.model.perform_additional_optimizer_actions(
+ self.optimizer.optimizer, self.optimizer.fp32_params
+ )
else:
- self.model.perform_additional_optimizer_actions(self.optimizer.optimizer)
+ self.model.perform_additional_optimizer_actions(
+ self.optimizer.optimizer
+ )
- if not overflow or self.args.distributed_wrapper == 'SlowMo':
+ if not overflow or self.args.distributed_wrapper == "SlowMo":
self.set_num_updates(self.get_num_updates() + 1)
if self.tpu:
# mark step on TPUs
import torch_xla.core.xla_model as xm
+
xm.mark_step()
# only log stats every log_interval steps
@@ -609,17 +639,27 @@ def maybe_no_sync():
if self.get_num_updates() % self.args.log_interval == 0:
# log memory usage
mem_info = xm.get_memory_info(self.device)
- gb_free = mem_info['kb_free'] / 1024 / 1024
- gb_total = mem_info['kb_total'] / 1024 / 1024
+ gb_free = mem_info["kb_free"] / 1024 / 1024
+ gb_total = mem_info["kb_total"] / 1024 / 1024
metrics.log_scalar(
- 'gb_free', gb_free, priority=1500, round=1, weight=0,
+ "gb_free",
+ gb_free,
+ priority=1500,
+ round=1,
+ weight=0,
)
metrics.log_scalar(
- 'gb_total', gb_total, priority=1600, round=1, weight=0,
+ "gb_total",
+ gb_total,
+ priority=1600,
+ round=1,
+ weight=0,
)
logging_output = self._reduce_and_log_stats(
- logging_outputs, sample_size, grad_norm,
+ logging_outputs,
+ sample_size,
+ grad_norm,
)
# log whenever there's an XLA compilation, since these
@@ -629,7 +669,9 @@ def maybe_no_sync():
else:
# log stats
logging_output = self._reduce_and_log_stats(
- logging_outputs, sample_size, grad_norm,
+ logging_outputs,
+ sample_size,
+ grad_norm,
)
# clear CUDA cache to reduce memory fragmentation
@@ -639,7 +681,8 @@ def maybe_no_sync():
and (
(self.get_num_updates() + self.args.empty_cache_freq - 1)
% self.args.empty_cache_freq
- ) == 0
+ )
+ == 0
):
torch.cuda.empty_cache()
@@ -660,7 +703,8 @@ def valid_step(self, sample, raise_oom=False):
"""Do forward pass in evaluation mode."""
if self.tpu:
import torch_xla.core.xla_model as xm
- xm.rendezvous('valid_step') # wait for all workers
+
+ xm.rendezvous("valid_step") # wait for all workers
xm.mark_step()
with torch.no_grad():
@@ -700,12 +744,14 @@ def valid_step(self, sample, raise_oom=False):
if torch.is_tensor(sample_size):
sample_size.zero_()
else:
- sample_size *= 0.
+ sample_size *= 0.0
# gather logging outputs from all replicas
if self.data_parallel_world_size > 1:
- logging_outputs, (sample_size, ) = self._aggregate_logging_outputs(
- logging_outputs, sample_size, ignore=is_dummy_batch,
+ logging_outputs, (sample_size,) = self._aggregate_logging_outputs(
+ logging_outputs,
+ sample_size,
+ ignore=is_dummy_batch,
)
# log validation stats
@@ -744,10 +790,10 @@ def get_meter(self, name):
"""[deprecated] Get a specific meter by name."""
from fairseq import meters
- if 'get_meter' not in self._warn_once:
- self._warn_once.add('get_meter')
+ if "get_meter" not in self._warn_once:
+ self._warn_once.add("get_meter")
utils.deprecation_warning(
- 'Trainer.get_meter is deprecated. Please use fairseq.metrics instead.'
+ "Trainer.get_meter is deprecated. Please use fairseq.metrics instead."
)
train_meters = metrics.get_meters("train")
@@ -772,7 +818,7 @@ def get_meter(self, name):
elif name in {"valid_loss", "valid_nll_loss"}:
# support for legacy train.py, which assumed these meters
# are always initialized
- k = name[len("valid_"):]
+ k = name[len("valid_") :]
m = metrics.get_meter("valid", k)
return m or meters.AverageMeter()
elif name == "oom":
@@ -820,8 +866,10 @@ def _prepare_sample(self, sample):
if self.cuda:
if self.pipeline_model_parallel:
- if 'target' in sample:
- sample['target'] = utils.move_to_cuda(sample['target'], device=self.last_device)
+ if "target" in sample:
+ sample["target"] = utils.move_to_cuda(
+ sample["target"], device=self.last_device
+ )
else:
sample = utils.move_to_cuda(sample)
@@ -855,10 +903,9 @@ def _sync_stats(self):
if self.data_parallel_world_size == 1:
return False
elif self.args.use_bmuf:
- return (
- (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
- and (self.get_num_updates() + 1) > self.args.warmup_iterations
- )
+ return (self.get_num_updates() + 1) % self.args.global_sync_iter == 0 and (
+ self.get_num_updates() + 1
+ ) > self.args.warmup_iterations
else:
return True
@@ -899,13 +946,15 @@ def _all_gather_list_sync(
raise NotImplementedError
if ignore:
logging_outputs = []
- results = list(zip(
- *distributed_utils.all_gather_list(
- [logging_outputs] + list(extra_stats_to_sum),
- max_size=getattr(self.args, 'all_gather_list_size', 16384),
- group=self.data_parallel_process_group,
+ results = list(
+ zip(
+ *distributed_utils.all_gather_list(
+ [logging_outputs] + list(extra_stats_to_sum),
+ max_size=getattr(self.args, "all_gather_list_size", 16384),
+ group=self.data_parallel_process_group,
+ )
)
- ))
+ )
logging_outputs, extra_stats_to_sum = results[0], results[1:]
logging_outputs = list(chain.from_iterable(logging_outputs))
extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
@@ -925,7 +974,7 @@ def _fast_stat_sync_sum(
"""
data = {}
for i, stat in enumerate(extra_stats_to_sum):
- data['extra_stats_' + str(i)] = stat
+ data["extra_stats_" + str(i)] = stat
if len(logging_outputs) > 0:
log_keys = list(logging_outputs[0].keys())
for k in log_keys:
@@ -934,21 +983,19 @@ def _fast_stat_sync_sum(
else:
v = logging_outputs[0][k]
v = torch.zeros_like(v) if torch.is_tensor(v) else 0
- data['logging_outputs_' + k] = v
+ data["logging_outputs_" + k] = v
else:
log_keys = None
data = distributed_utils.all_reduce_dict(
- data,
- device=self.device,
- group=self.data_parallel_process_group
+ data, device=self.device, group=self.data_parallel_process_group
)
extra_stats_to_sum = [
- data['extra_stats_' + str(i)] for i in range(len(extra_stats_to_sum))
+ data["extra_stats_" + str(i)] for i in range(len(extra_stats_to_sum))
]
if log_keys is not None:
- logging_outputs = [{k: data['logging_outputs_' + k] for k in log_keys}]
+ logging_outputs = [{k: data["logging_outputs_" + k] for k in log_keys}]
else:
logging_outputs = []
return logging_outputs, extra_stats_to_sum
@@ -959,8 +1006,7 @@ def _check_grad_norms(self, grad_norm):
self._grad_norm_buf.zero_()
self._grad_norm_buf[self.data_parallel_rank] = grad_norm
distributed_utils.all_reduce(
- self._grad_norm_buf,
- group=self.data_parallel_process_group
+ self._grad_norm_buf, group=self.data_parallel_process_group
)
def is_consistent(tensor):
@@ -975,7 +1021,9 @@ def is_consistent(tensor):
"rank {:3d} = {:.8f}".format(r, n)
for r, n in enumerate(self._grad_norm_buf.tolist())
)
- error_detail = "grad_norm across the workers:\n{}\n".format(pretty_detail)
+ error_detail = "grad_norm across the workers:\n{}\n".format(
+ pretty_detail
+ )
raise RuntimeError(
"Fatal error: gradients are inconsistent between workers. "
"Try --ddp-backend=no_c10d. "
@@ -988,7 +1036,7 @@ def is_consistent(tensor):
def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
if grad_norm is not None:
- metrics.log_speed("ups", 1., priority=100, round=2)
+ metrics.log_speed("ups", 1.0, priority=100, round=2)
metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
if self.args.clip_norm > 0:
metrics.log_scalar(
@@ -1030,6 +1078,7 @@ def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
def _check_xla_compilation(self):
import torch_xla.debug.metrics as met
+
compile_stats = met.metric_data("CompileTime")
if compile_stats is None:
return
@@ -1037,41 +1086,42 @@ def _check_xla_compilation(self):
if num_xla_compiles > self._num_xla_compiles:
logger.warning(
"XLA compilation detected on device #{}; too many of these can lead "
- "to slow training, but we expect a few in the beginning"
- .format(self.args.distributed_rank)
+ "to slow training, but we expect a few in the beginning".format(
+ self.args.distributed_rank
+ )
)
self._num_xla_compiles = num_xla_compiles
-def _catalog_shared_params(module, memo=None, prefix=''):
+def _catalog_shared_params(module, memo=None, prefix=""):
if memo is None:
first_call = True
memo = {}
else:
first_call = False
for name, param in module._parameters.items():
- param_prefix = prefix + ('.' if prefix else '') + name
+ param_prefix = prefix + ("." if prefix else "") + name
if param not in memo:
memo[param] = []
memo[param].append(param_prefix)
for name, m in module._modules.items():
if m is None:
continue
- submodule_prefix = prefix + ('.' if prefix else '') + name
+ submodule_prefix = prefix + ("." if prefix else "") + name
_catalog_shared_params(m, memo, submodule_prefix)
if first_call:
return [x for x in memo.values() if len(x) > 1]
def _get_module_by_path(module, path):
- path = path.split('.')
+ path = path.split(".")
for name in path:
module = getattr(module, name)
return module
def _set_module_by_path(module, path, value):
- path = path.split('.')
+ path = path.split(".")
for name in path[:-1]:
module = getattr(module, name)
setattr(module, path[-1], value)
diff --git a/fairseq/utils.py b/fairseq/utils.py
index 1a18bf5e6c..fdbf66cf3f 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -445,7 +445,7 @@ def import_user_module(args):
# temporary directory and symlink the user_dir under a new name, which is
# a deterministic hash of the original module_path.
with tempfile.TemporaryDirectory() as tmpdirname:
- unique_mod_name = 'fairseq_user_dir_{}'.format(hash(module_path) % 100000)
+ unique_mod_name = "fairseq_user_dir_{}".format(hash(module_path) % 100000)
os.symlink(module_path, os.path.join(tmpdirname, unique_mod_name))
sys.path.insert(0, tmpdirname)
diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py
index 64c83673e6..9a4ff8ee39 100644
--- a/fairseq_cli/eval_lm.py
+++ b/fairseq_cli/eval_lm.py
@@ -13,21 +13,19 @@
import os
import torch
-
-from fairseq import checkpoint_utils, options, tasks, utils
+from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.data import LMContextWindowDataset
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
-from fairseq import distributed_utils
logging.basicConfig(
- format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
- datefmt='%Y-%m-%d %H:%M:%S',
- level=os.environ.get('LOGLEVEL', 'INFO').upper(),
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
)
-logger = logging.getLogger('fairseq_cli.eval_lm')
+logger = logging.getLogger("fairseq_cli.eval_lm")
class WordStat(object):
@@ -40,10 +38,10 @@ def __init__(self, word, is_bpe):
self.missing_next_words = 0
def add(self, log_prob, next_word_prob):
- """ increments counters for the sum of log probs of current word and next
- word (given context ending at current word). Since the next word might be at the end of the example,
- or it might be not counted because it is not an ending subword unit,
- also keeps track of how many of those we have seen """
+ """increments counters for the sum of log probs of current word and next
+ word (given context ending at current word). Since the next word might be at the end of the example,
+ or it might be not counted because it is not an ending subword unit,
+ also keeps track of how many of those we have seen"""
if next_word_prob is not None:
self.next_word_prob += next_word_prob
else:
@@ -52,12 +50,18 @@ def add(self, log_prob, next_word_prob):
self.count += 1
def __str__(self):
- return '{}\t{}\t{}\t{}\t{}\t{}'.format(self.word, self.count, self.log_prob, self.is_bpe,
- self.next_word_prob, self.count - self.missing_next_words)
+ return "{}\t{}\t{}\t{}\t{}\t{}".format(
+ self.word,
+ self.count,
+ self.log_prob,
+ self.is_bpe,
+ self.next_word_prob,
+ self.count - self.missing_next_words,
+ )
def main(parsed_args, **unused_kwargs):
- assert parsed_args.path is not None, '--path required for evaluation!'
+ assert parsed_args.path is not None, "--path required for evaluation!"
if torch.cuda.is_available() and not parsed_args.cpu:
torch.cuda.set_device(parsed_args.device_id)
@@ -71,7 +75,7 @@ def main(parsed_args, **unused_kwargs):
task = tasks.setup_task(parsed_args)
# Load ensemble
- logger.info('loading model(s) from {}'.format(parsed_args.path))
+ logger.info("loading model(s) from {}".format(parsed_args.path))
models, args = checkpoint_utils.load_model_ensemble(
parsed_args.path.split(os.pathsep),
arg_overrides=eval(parsed_args.model_overrides),
@@ -83,8 +87,12 @@ def main(parsed_args, **unused_kwargs):
for arg in vars(parsed_args).keys():
if arg not in {
- 'self_target', 'future_target', 'past_target', 'tokens_per_sample',
- 'output_size_dictionary', 'add_bos_token',
+ "self_target",
+ "future_target",
+ "past_target",
+ "tokens_per_sample",
+ "output_size_dictionary",
+ "add_bos_token",
}:
setattr(args, arg, getattr(parsed_args, arg))
@@ -102,7 +110,7 @@ def main(parsed_args, **unused_kwargs):
context_window=args.context_window,
pad_idx=task.source_dictionary.pad(),
)
- logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))
+ logger.info("{} {} {} examples".format(args.data, args.gen_subset, len(dataset)))
# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
for model in models:
@@ -114,15 +122,17 @@ def main(parsed_args, **unused_kwargs):
assert len(models) > 0
- logger.info('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))
+ logger.info(
+ "num. model params: {}".format(sum(p.numel() for p in models[0].parameters()))
+ )
itr = task.get_batch_iterator(
dataset=dataset,
max_tokens=args.max_tokens or 36000,
max_sentences=args.batch_size,
- max_positions=utils.resolve_max_positions(*[
- model.max_positions() for model in models
- ]),
+ max_positions=utils.resolve_max_positions(
+ *[model.max_positions() for model in models]
+ ),
ignore_invalid_inputs=True,
num_shards=args.num_shards,
shard_id=args.shard_id,
@@ -133,17 +143,17 @@ def main(parsed_args, **unused_kwargs):
itr,
log_format=args.log_format,
log_interval=args.log_interval,
- default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
+ default_log_format=("tqdm" if not args.no_progress_bar else "none"),
)
gen_timer = StopwatchMeter()
scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)
- score_sum = 0.
+ score_sum = 0.0
count = 0
if args.remove_bpe is not None:
- if args.remove_bpe == 'sentencepiece':
+ if args.remove_bpe == "sentencepiece":
raise NotImplementedError
else:
bpe_cont = args.remove_bpe.rstrip()
@@ -162,25 +172,25 @@ def main(parsed_args, **unused_kwargs):
wps_meter = TimeMeter()
for sample in progress:
- if 'net_input' not in sample:
+ if "net_input" not in sample:
continue
sample = utils.move_to_cuda(sample) if use_cuda else sample
gen_timer.start()
hypos = scorer.generate(models, sample)
- gen_timer.stop(sample['ntokens'])
+ gen_timer.stop(sample["ntokens"])
for i, hypos_i in enumerate(hypos):
hypo = hypos_i[0]
- sample_id = sample['id'][i]
+ sample_id = sample["id"][i]
- tokens = hypo['tokens']
+ tokens = hypo["tokens"]
tgt_len = tokens.numel()
- pos_scores = hypo['positional_scores'].float()
+ pos_scores = hypo["positional_scores"].float()
- if getattr(args, 'add_bos_token', False):
- assert hypo['tokens'][0].item() == task.target_dictionary.bos()
+ if getattr(args, "add_bos_token", False):
+ assert hypo["tokens"][0].item() == task.target_dictionary.bos()
tokens = tokens[1:]
pos_scores = pos_scores[1:]
@@ -192,18 +202,18 @@ def main(parsed_args, **unused_kwargs):
pos_scores[i + 1] += pos_scores[i]
pos_scores[i] = 0
- inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
+ inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf"))
if inf_scores.any():
logger.info(
- 'skipping tokens with inf scores:',
- task.target_dictionary.string(tokens[inf_scores.nonzero()])
+ "skipping tokens with inf scores:",
+ task.target_dictionary.string(tokens[inf_scores.nonzero()]),
)
pos_scores = pos_scores[(~inf_scores).nonzero()]
score_sum += pos_scores.sum().cpu()
count += pos_scores.numel() - skipped_toks
if args.output_word_probs or args.output_word_stats:
- w = ''
+ w = ""
word_prob = []
is_bpe = False
for i in range(len(tokens)):
@@ -223,25 +233,36 @@ def main(parsed_args, **unused_kwargs):
break
ind += 1
- word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item(), next_prob)
+ word_stats.setdefault(w, WordStat(w, is_bpe)).add(
+ pos_scores[i].item(), next_prob
+ )
is_bpe = False
- w = ''
+ w = ""
if args.output_word_probs:
logger.info(
- str(int(sample_id)) + " "
- + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
+ str(int(sample_id))
+ + " "
+ + (
+ "\t".join(
+ "{} [{:2f}]".format(x[0], x[1]) for x in word_prob
+ )
+ )
)
- wps_meter.update(sample['ntokens'])
- progress.log({'wps': round(wps_meter.avg)})
+ wps_meter.update(sample["ntokens"])
+ progress.log({"wps": round(wps_meter.avg)})
avg_nll_loss = -score_sum / count / math.log(2) # convert to base 2
- logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
- gen_timer.n, gen_timer.sum, 1. / gen_timer.avg
- ))
- logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
- avg_nll_loss, 2**avg_nll_loss
- ))
+ logger.info(
+ "Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format(
+ gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg
+ )
+ )
+ logger.info(
+ "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(
+ avg_nll_loss, 2 ** avg_nll_loss
+ )
+ )
if args.output_word_stats:
for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
@@ -254,5 +275,5 @@ def cli_main():
distributed_utils.call_main(args, main)
-if __name__ == '__main__':
+if __name__ == "__main__":
cli_main()
diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py
index 0064b88a95..8ddf981cc3 100644
--- a/fairseq_cli/generate.py
+++ b/fairseq_cli/generate.py
@@ -8,39 +8,41 @@
"""
import ast
-from itertools import chain
import logging
import math
import os
import sys
+from itertools import chain
import numpy as np
-
import torch
-
from fairseq import checkpoint_utils, options, scoring, tasks, utils
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter
def main(args):
- assert args.path is not None, '--path required for generation!'
- assert not args.sampling or args.nbest == args.beam, \
- '--sampling requires --nbest to be equal to --beam'
- assert args.replace_unk is None or args.dataset_impl == 'raw', \
- '--replace-unk requires a raw text dataset (--dataset-impl=raw)'
+ assert args.path is not None, "--path required for generation!"
+ assert (
+ not args.sampling or args.nbest == args.beam
+ ), "--sampling requires --nbest to be equal to --beam"
+ assert (
+ args.replace_unk is None or args.dataset_impl == "raw"
+ ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"
if args.results_path is not None:
os.makedirs(args.results_path, exist_ok=True)
- output_path = os.path.join(args.results_path, 'generate-{}.txt'.format(args.gen_subset))
- with open(output_path, 'w', buffering=1, encoding='utf-8') as h:
+ output_path = os.path.join(
+ args.results_path, "generate-{}.txt".format(args.gen_subset)
+ )
+ with open(output_path, "w", buffering=1, encoding="utf-8") as h:
return _main(args, h)
else:
return _main(args, sys.stdout)
def get_symbols_to_strip_from_output(generator):
- if hasattr(generator, 'symbols_to_strip_from_output'):
+ if hasattr(generator, "symbols_to_strip_from_output"):
return generator.symbols_to_strip_from_output
else:
return {generator.eos}
@@ -48,12 +50,12 @@ def get_symbols_to_strip_from_output(generator):
def _main(args, output_file):
logging.basicConfig(
- format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
- datefmt='%Y-%m-%d %H:%M:%S',
- level=os.environ.get('LOGLEVEL', 'INFO').upper(),
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=output_file,
)
- logger = logging.getLogger('fairseq_cli.generate')
+ logger = logging.getLogger("fairseq_cli.generate")
utils.import_user_module(args)
@@ -74,7 +76,7 @@ def _main(args, output_file):
# Set dictionaries
try:
- src_dict = getattr(task, 'source_dictionary', None)
+ src_dict = getattr(task, "source_dictionary", None)
except NotImplementedError:
src_dict = None
tgt_dict = task.target_dictionary
@@ -82,7 +84,7 @@ def _main(args, output_file):
overrides = ast.literal_eval(args.model_overrides)
# Load ensemble
- logger.info('loading model(s) from {}'.format(args.path))
+ logger.info("loading model(s) from {}".format(args.path))
models, _model_args = checkpoint_utils.load_model_ensemble(
utils.split_paths(args.path),
arg_overrides=overrides,
@@ -93,7 +95,7 @@ def _main(args, output_file):
)
if args.lm_path is not None:
- overrides['data'] = args.data
+ overrides["data"] = args.data
try:
lms, _ = checkpoint_utils.load_model_ensemble(
@@ -102,8 +104,10 @@ def _main(args, output_file):
task=None,
)
except:
- logger.warning(f"Failed to load language model! Please make sure that the language model dict is the same "
- f"as target dict and is located in the data dir ({args.data})")
+ logger.warning(
+ f"Failed to load language model! Please make sure that the language model dict is the same "
+ f"as target dict and is located in the data dir ({args.data})"
+ )
raise
assert len(lms) == 1
@@ -130,8 +134,7 @@ def _main(args, output_file):
max_tokens=args.max_tokens,
max_sentences=args.batch_size,
max_positions=utils.resolve_max_positions(
- task.max_positions(),
- *[model.max_positions() for model in models]
+ task.max_positions(), *[model.max_positions() for model in models]
),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=args.required_batch_size_multiple,
@@ -144,17 +147,16 @@ def _main(args, output_file):
itr,
log_format=args.log_format,
log_interval=args.log_interval,
- default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
+ default_log_format=("tqdm" if not args.no_progress_bar else "none"),
)
# Initialize generator
gen_timer = StopwatchMeter()
- extra_gen_cls_kwargs = {
- 'lm_model': lms[0],
- 'lm_weight': args.lm_weight
- }
- generator = task.build_generator(models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs)
+ extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": args.lm_weight}
+ generator = task.build_generator(
+ models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs
+ )
# Handle tokenization and BPE
tokenizer = task.build_tokenizer(args)
@@ -174,39 +176,51 @@ def decode_fn(x):
wps_meter = TimeMeter()
for sample in progress:
sample = utils.move_to_cuda(sample) if use_cuda else sample
- if 'net_input' not in sample:
+ if "net_input" not in sample:
continue
prefix_tokens = None
if args.prefix_size > 0:
- prefix_tokens = sample['target'][:, :args.prefix_size]
+ prefix_tokens = sample["target"][:, : args.prefix_size]
constraints = None
if "constraints" in sample:
constraints = sample["constraints"]
gen_timer.start()
- hypos = task.inference_step(generator, models, sample, prefix_tokens=prefix_tokens, constraints=constraints)
- num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
+ hypos = task.inference_step(
+ generator,
+ models,
+ sample,
+ prefix_tokens=prefix_tokens,
+ constraints=constraints,
+ )
+ num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
gen_timer.stop(num_generated_tokens)
- for i, sample_id in enumerate(sample['id'].tolist()):
- has_target = sample['target'] is not None
+ for i, sample_id in enumerate(sample["id"].tolist()):
+ has_target = sample["target"] is not None
# Remove padding
- if 'src_tokens' in sample['net_input']:
- src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
+ if "src_tokens" in sample["net_input"]:
+ src_tokens = utils.strip_pad(
+ sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
+ )
else:
src_tokens = None
target_tokens = None
if has_target:
- target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()
+ target_tokens = (
+ utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
+ )
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
- target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
+ target_str = task.dataset(args.gen_subset).tgt.get_original_text(
+ sample_id
+ )
else:
if src_dict is not None:
src_str = src_dict.string(src_tokens, args.remove_bpe)
@@ -217,7 +231,9 @@ def decode_fn(x):
target_tokens,
args.remove_bpe,
escape_unk=True,
- extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
+ extra_symbols_to_ignore=get_symbols_to_strip_from_output(
+ generator
+ ),
)
src_str = decode_fn(src_str)
@@ -226,16 +242,16 @@ def decode_fn(x):
if not args.quiet:
if src_dict is not None:
- print('S-{}\t{}'.format(sample_id, src_str), file=output_file)
+ print("S-{}\t{}".format(sample_id, src_str), file=output_file)
if has_target:
- print('T-{}\t{}'.format(sample_id, target_str), file=output_file)
+ print("T-{}\t{}".format(sample_id, target_str), file=output_file)
# Process top predictions
- for j, hypo in enumerate(hypos[i][:args.nbest]):
+ for j, hypo in enumerate(hypos[i][: args.nbest]):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
- hypo_tokens=hypo['tokens'].int().cpu(),
+ hypo_tokens=hypo["tokens"].int().cpu(),
src_str=src_str,
- alignment=hypo['alignment'],
+ alignment=hypo["alignment"],
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
@@ -243,71 +259,116 @@ def decode_fn(x):
)
detok_hypo_str = decode_fn(hypo_str)
if not args.quiet:
- score = hypo['score'] / math.log(2) # convert to base 2
+ score = hypo["score"] / math.log(2) # convert to base 2
# original hypothesis (after tokenization and BPE)
- print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str), file=output_file)
+ print(
+ "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
+ file=output_file,
+ )
# detokenized hypothesis
- print('D-{}\t{}\t{}'.format(sample_id, score, detok_hypo_str), file=output_file)
- print('P-{}\t{}'.format(
- sample_id,
- ' '.join(map(
- lambda x: '{:.4f}'.format(x),
- # convert from base e to base 2
- hypo['positional_scores'].div_(math.log(2)).tolist(),
- ))
- ), file=output_file)
+ print(
+ "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str),
+ file=output_file,
+ )
+ print(
+ "P-{}\t{}".format(
+ sample_id,
+ " ".join(
+ map(
+ lambda x: "{:.4f}".format(x),
+ # convert from base e to base 2
+ hypo["positional_scores"]
+ .div_(math.log(2))
+ .tolist(),
+ )
+ ),
+ ),
+ file=output_file,
+ )
if args.print_alignment:
- print('A-{}\t{}'.format(
- sample_id,
- ' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in alignment])
- ), file=output_file)
+ print(
+ "A-{}\t{}".format(
+ sample_id,
+ " ".join(
+ [
+ "{}-{}".format(src_idx, tgt_idx)
+ for src_idx, tgt_idx in alignment
+ ]
+ ),
+ ),
+ file=output_file,
+ )
if args.print_step:
- print('I-{}\t{}'.format(sample_id, hypo['steps']), file=output_file)
+ print(
+ "I-{}\t{}".format(sample_id, hypo["steps"]),
+ file=output_file,
+ )
- if getattr(args, 'retain_iter_history', False):
- for step, h in enumerate(hypo['history']):
+ if getattr(args, "retain_iter_history", False):
+ for step, h in enumerate(hypo["history"]):
_, h_str, _ = utils.post_process_prediction(
- hypo_tokens=h['tokens'].int().cpu(),
+ hypo_tokens=h["tokens"].int().cpu(),
src_str=src_str,
alignment=None,
align_dict=None,
tgt_dict=tgt_dict,
remove_bpe=None,
)
- print('E-{}_{}\t{}'.format(sample_id, step, h_str), file=output_file)
+ print(
+ "E-{}_{}\t{}".format(sample_id, step, h_str),
+ file=output_file,
+ )
# Score only the top hypothesis
if has_target and j == 0:
if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE
- target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True)
- hypo_tokens = tgt_dict.encode_line(detok_hypo_str, add_if_not_exist=True)
- if hasattr(scorer, 'add_string'):
+ target_tokens = tgt_dict.encode_line(
+ target_str, add_if_not_exist=True
+ )
+ hypo_tokens = tgt_dict.encode_line(
+ detok_hypo_str, add_if_not_exist=True
+ )
+ if hasattr(scorer, "add_string"):
scorer.add_string(target_str, detok_hypo_str)
else:
scorer.add(target_tokens, hypo_tokens)
wps_meter.update(num_generated_tokens)
- progress.log({'wps': round(wps_meter.avg)})
- num_sentences += sample["nsentences"] if "nsentences" in sample else sample['id'].numel()
-
- logger.info('NOTE: hypothesis and token scores are output in base 2')
- logger.info('Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
- num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
+ progress.log({"wps": round(wps_meter.avg)})
+ num_sentences += (
+ sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
+ )
+
+ logger.info("NOTE: hypothesis and token scores are output in base 2")
+ logger.info(
+ "Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
+ num_sentences,
+ gen_timer.n,
+ gen_timer.sum,
+ num_sentences / gen_timer.sum,
+ 1.0 / gen_timer.avg,
+ )
+ )
if has_target:
if args.bpe and not args.sacrebleu:
if args.remove_bpe:
logger.warning(
- "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization")
+ "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
+ )
else:
logger.warning(
- "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization")
+ "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization"
+ )
# use print to be consistent with other main outputs: S-, H-, T-, D- and so on
print(
- 'Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()),
- file=output_file)
+ "Generate {} with beam={}: {}".format(
+ args.gen_subset, args.beam, scorer.result_string()
+ ),
+ file=output_file,
+ )
return scorer
@@ -318,5 +379,5 @@ def cli_main():
main(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
cli_main()
diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py
index fc4b46e39d..de3893a385 100644
--- a/fairseq_cli/interactive.py
+++ b/fairseq_cli/interactive.py
@@ -7,34 +7,33 @@
Translate raw text with a trained model. Batches data on-the-fly.
"""
-from collections import namedtuple
import fileinput
import logging
import math
+import os
import sys
import time
-import os
+from collections import namedtuple
import numpy as np
-
import torch
-
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.data import encoders
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from fairseq_cli.generate import get_symbols_to_strip_from_output
+
logging.basicConfig(
- format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
- datefmt='%Y-%m-%d %H:%M:%S',
- level=os.environ.get('LOGLEVEL', 'INFO').upper(),
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=sys.stdout,
)
-logger = logging.getLogger('fairseq_cli.interactive')
+logger = logging.getLogger("fairseq_cli.interactive")
-Batch = namedtuple('Batch', 'ids src_tokens src_lengths constraints')
-Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
+Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
+Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")
def buffered_read(input, buffer_size):
@@ -64,11 +63,14 @@ def encode_fn_target(x):
# Convert each List[str] to List[Tensor]
for i, constraint_list in enumerate(batch_constraints):
- batch_constraints[i] = [task.target_dictionary.encode_line(
- encode_fn_target(constraint),
- append_eos=False,
- add_if_not_exist=False,
- ) for constraint in constraint_list]
+ batch_constraints[i] = [
+ task.target_dictionary.encode_line(
+ encode_fn_target(constraint),
+ append_eos=False,
+ add_if_not_exist=False,
+ )
+ for constraint in constraint_list
+ ]
tokens = [
task.source_dictionary.encode_line(
@@ -84,16 +86,18 @@ def encode_fn_target(x):
lengths = [t.numel() for t in tokens]
itr = task.get_batch_iterator(
- dataset=task.build_dataset_for_inference(tokens, lengths, constraints=constraints_tensor),
+ dataset=task.build_dataset_for_inference(
+ tokens, lengths, constraints=constraints_tensor
+ ),
max_tokens=args.max_tokens,
max_sentences=args.batch_size,
max_positions=max_positions,
- ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test
+ ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
).next_epoch_itr(shuffle=False)
for batch in itr:
- ids = batch['id']
- src_tokens = batch['net_input']['src_tokens']
- src_lengths = batch['net_input']['src_lengths']
+ ids = batch["id"]
+ src_tokens = batch["net_input"]["src_tokens"]
+ src_lengths = batch["net_input"]["src_lengths"]
constraints = batch.get("constraints", None)
yield Batch(
@@ -115,10 +119,12 @@ def main(args):
if args.max_tokens is None and args.batch_size is None:
args.batch_size = 1
- assert not args.sampling or args.nbest == args.beam, \
- '--sampling requires --nbest to be equal to --beam'
- assert not args.batch_size or args.batch_size <= args.buffer_size, \
- '--batch-size cannot be larger than --buffer-size'
+ assert (
+ not args.sampling or args.nbest == args.beam
+ ), "--sampling requires --nbest to be equal to --beam"
+ assert (
+ not args.batch_size or args.batch_size <= args.buffer_size
+ ), "--batch-size cannot be larger than --buffer-size"
logger.info(args)
@@ -133,7 +139,7 @@ def main(args):
task = tasks.setup_task(args)
# Load ensemble
- logger.info('loading model(s) from {}'.format(args.path))
+ logger.info("loading model(s) from {}".format(args.path))
models, _model_args = checkpoint_utils.load_model_ensemble(
args.path.split(os.pathsep),
arg_overrides=eval(args.model_overrides),
@@ -181,17 +187,18 @@ def decode_fn(x):
align_dict = utils.load_align_dict(args.replace_unk)
max_positions = utils.resolve_max_positions(
- task.max_positions(),
- *[model.max_positions() for model in models]
+ task.max_positions(), *[model.max_positions() for model in models]
)
if args.constraints:
- logger.warning("NOTE: Constrained decoding currently assumes a shared subword vocabulary.")
+ logger.warning(
+ "NOTE: Constrained decoding currently assumes a shared subword vocabulary."
+ )
if args.buffer_size > 1:
- logger.info('Sentence buffer size: %s', args.buffer_size)
- logger.info('NOTE: hypothesis and token scores are output in base 2')
- logger.info('Type the input sentence and press return:')
+ logger.info("Sentence buffer size: %s", args.buffer_size)
+ logger.info("NOTE: hypothesis and token scores are output in base 2")
+ logger.info("Type the input sentence and press return:")
start_id = 0
for inputs in buffered_read(args.input, args.buffer_size):
results = []
@@ -207,13 +214,15 @@ def decode_fn(x):
constraints = constraints.cuda()
sample = {
- 'net_input': {
- 'src_tokens': src_tokens,
- 'src_lengths': src_lengths,
+ "net_input": {
+ "src_tokens": src_tokens,
+ "src_lengths": src_lengths,
},
}
translate_start_time = time.time()
- translations = task.inference_step(generator, models, sample, constraints=constraints)
+ translations = task.inference_step(
+ generator, models, sample, constraints=constraints
+ )
translate_time = time.time() - translate_start_time
total_translate_time += translate_time
list_constraints = [[] for _ in range(bsz)]
@@ -222,56 +231,75 @@ def decode_fn(x):
for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
constraints = list_constraints[i]
- results.append((start_id + id, src_tokens_i, hypos,
- { "constraints": constraints,
- "time": translate_time / len(translations) }
- ))
+ results.append(
+ (
+ start_id + id,
+ src_tokens_i,
+ hypos,
+ {
+ "constraints": constraints,
+ "time": translate_time / len(translations),
+ },
+ )
+ )
# sort output to match input order
for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
if src_dict is not None:
src_str = src_dict.string(src_tokens, args.remove_bpe)
- print('S-{}\t{}'.format(id_, src_str))
+ print("S-{}\t{}".format(id_, src_str))
print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
for constraint in info["constraints"]:
- print("C-{}\t{}".format(id_, tgt_dict.string(constraint, args.remove_bpe)))
+ print(
+ "C-{}\t{}".format(
+ id_, tgt_dict.string(constraint, args.remove_bpe)
+ )
+ )
# Process top predictions
- for hypo in hypos[:min(len(hypos), args.nbest)]:
+ for hypo in hypos[: min(len(hypos), args.nbest)]:
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
- hypo_tokens=hypo['tokens'].int().cpu(),
+ hypo_tokens=hypo["tokens"].int().cpu(),
src_str=src_str,
- alignment=hypo['alignment'],
+ alignment=hypo["alignment"],
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
)
detok_hypo_str = decode_fn(hypo_str)
- score = hypo['score'] / math.log(2) # convert to base 2
+ score = hypo["score"] / math.log(2) # convert to base 2
# original hypothesis (after tokenization and BPE)
- print('H-{}\t{}\t{}'.format(id_, score, hypo_str))
+ print("H-{}\t{}\t{}".format(id_, score, hypo_str))
# detokenized hypothesis
- print('D-{}\t{}\t{}'.format(id_, score, detok_hypo_str))
- print('P-{}\t{}'.format(
- id_,
- ' '.join(map(
- lambda x: '{:.4f}'.format(x),
- # convert from base e to base 2
- hypo['positional_scores'].div_(math.log(2)).tolist(),
- ))
- ))
- if args.print_alignment:
- alignment_str = " ".join(["{}-{}".format(src, tgt) for src, tgt in alignment])
- print('A-{}\t{}'.format(
+ print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str))
+ print(
+ "P-{}\t{}".format(
id_,
- alignment_str
- ))
+ " ".join(
+ map(
+ lambda x: "{:.4f}".format(x),
+ # convert from base e to base 2
+ hypo["positional_scores"].div_(math.log(2)).tolist(),
+ )
+ ),
+ )
+ )
+ if args.print_alignment:
+ alignment_str = " ".join(
+ ["{}-{}".format(src, tgt) for src, tgt in alignment]
+ )
+ print("A-{}\t{}".format(id_, alignment_str))
# update running id_ counter
start_id += len(inputs)
- logger.info("Total time: {:.3f} seconds; translation time: {:.3f}".format(time.time() - start_time, total_translate_time))
+ logger.info(
+ "Total time: {:.3f} seconds; translation time: {:.3f}".format(
+ time.time() - start_time, total_translate_time
+ )
+ )
+
def cli_main():
parser = options.get_interactive_generation_parser()
@@ -279,5 +307,5 @@ def cli_main():
distributed_utils.call_main(args, main)
-if __name__ == '__main__':
+if __name__ == "__main__":
cli_main()
diff --git a/fairseq_cli/preprocess.py b/fairseq_cli/preprocess.py
index 3fe5131324..fa77da8dba 100644
--- a/fairseq_cli/preprocess.py
+++ b/fairseq_cli/preprocess.py
@@ -7,26 +7,26 @@
Data pre-processing: build vocabularies and binarize training data.
"""
-from collections import Counter
-from itertools import zip_longest
import logging
-from multiprocessing import Pool
import os
import shutil
import sys
+from collections import Counter
+from itertools import zip_longest
+from multiprocessing import Pool
from fairseq import options, tasks, utils
-from fairseq.data import indexed_dataset
from fairseq.binarizer import Binarizer
+from fairseq.data import indexed_dataset
logging.basicConfig(
- format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
- datefmt='%Y-%m-%d %H:%M:%S',
- level=os.environ.get('LOGLEVEL', 'INFO').upper(),
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=sys.stdout,
)
-logger = logging.getLogger('fairseq_cli.preprocess')
+logger = logging.getLogger("fairseq_cli.preprocess")
def main(args):
@@ -34,9 +34,11 @@ def main(args):
os.makedirs(args.destdir, exist_ok=True)
- logger.addHandler(logging.FileHandler(
- filename=os.path.join(args.destdir, 'preprocess.log'),
- ))
+ logger.addHandler(
+ logging.FileHandler(
+ filename=os.path.join(args.destdir, "preprocess.log"),
+ )
+ )
logger.info(args)
task = tasks.get_task(args.task)
@@ -74,31 +76,39 @@ def build_dictionary(filenames, src=False, tgt=False):
raise FileExistsError(dict_path(args.target_lang))
if args.joined_dictionary:
- assert not args.srcdict or not args.tgtdict, \
- "cannot use both --srcdict and --tgtdict with --joined-dictionary"
+ assert (
+ not args.srcdict or not args.tgtdict
+ ), "cannot use both --srcdict and --tgtdict with --joined-dictionary"
if args.srcdict:
src_dict = task.load_dictionary(args.srcdict)
elif args.tgtdict:
src_dict = task.load_dictionary(args.tgtdict)
else:
- assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
+ assert (
+ args.trainpref
+ ), "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary(
- {train_path(lang) for lang in [args.source_lang, args.target_lang]}, src=True
+ {train_path(lang) for lang in [args.source_lang, args.target_lang]},
+ src=True,
)
tgt_dict = src_dict
else:
if args.srcdict:
src_dict = task.load_dictionary(args.srcdict)
else:
- assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
+ assert (
+ args.trainpref
+ ), "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary([train_path(args.source_lang)], src=True)
if target:
if args.tgtdict:
tgt_dict = task.load_dictionary(args.tgtdict)
else:
- assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
+ assert (
+ args.trainpref
+ ), "--trainpref must be set if --tgtdict is not specified"
tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)
else:
tgt_dict = None
@@ -135,18 +145,20 @@ def merge_result(worker_result):
prefix,
lang,
offsets[worker_id],
- offsets[worker_id + 1]
+ offsets[worker_id + 1],
),
- callback=merge_result
+ callback=merge_result,
)
pool.close()
- ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"),
- impl=args.dataset_impl, vocab_size=len(vocab))
+ ds = indexed_dataset.make_builder(
+ dataset_dest_file(args, output_prefix, lang, "bin"),
+ impl=args.dataset_impl,
+ vocab_size=len(vocab),
+ )
merge_result(
Binarizer.binarize(
- input_file, vocab, lambda t: ds.add_item(t),
- offset=0, end=offsets[1]
+ input_file, vocab, lambda t: ds.add_item(t), offset=0, end=offsets[1]
)
)
if num_workers > 1:
@@ -175,7 +187,7 @@ def make_binary_alignment_dataset(input_prefix, output_prefix, num_workers):
nseq = [0]
def merge_result(worker_result):
- nseq[0] += worker_result['nseq']
+ nseq[0] += worker_result["nseq"]
input_file = input_prefix
offsets = Binarizer.find_offsets(input_file, num_workers)
@@ -192,19 +204,23 @@ def merge_result(worker_result):
utils.parse_alignment,
prefix,
offsets[worker_id],
- offsets[worker_id + 1]
+ offsets[worker_id + 1],
),
- callback=merge_result
+ callback=merge_result,
)
pool.close()
- ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, None, "bin"),
- impl=args.dataset_impl)
+ ds = indexed_dataset.make_builder(
+ dataset_dest_file(args, output_prefix, None, "bin"), impl=args.dataset_impl
+ )
merge_result(
Binarizer.binarize_alignments(
- input_file, utils.parse_alignment, lambda t: ds.add_item(t),
- offset=0, end=offsets[1]
+ input_file,
+ utils.parse_alignment,
+ lambda t: ds.add_item(t),
+ offset=0,
+ end=offsets[1],
)
)
if num_workers > 1:
@@ -218,12 +234,7 @@ def merge_result(worker_result):
ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))
- logger.info(
- "[alignments] {}: parsed {} alignments".format(
- input_file,
- nseq[0]
- )
- )
+ logger.info("[alignments] {}: parsed {} alignments".format(input_file, nseq[0]))
def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
if args.dataset_impl == "raw":
@@ -242,7 +253,9 @@ def make_all(lang, vocab):
if args.validpref:
for k, validpref in enumerate(args.validpref.split(",")):
outprefix = "valid{}".format(k) if k > 0 else "valid"
- make_dataset(vocab, validpref, outprefix, lang, num_workers=args.workers)
+ make_dataset(
+ vocab, validpref, outprefix, lang, num_workers=args.workers
+ )
if args.testpref:
for k, testpref in enumerate(args.testpref.split(",")):
outprefix = "test{}".format(k) if k > 0 else "test"
@@ -250,11 +263,23 @@ def make_all(lang, vocab):
def make_all_alignments():
if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix):
- make_binary_alignment_dataset(args.trainpref + "." + args.align_suffix, "train.align", num_workers=args.workers)
+ make_binary_alignment_dataset(
+ args.trainpref + "." + args.align_suffix,
+ "train.align",
+ num_workers=args.workers,
+ )
if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix):
- make_binary_alignment_dataset(args.validpref + "." + args.align_suffix, "valid.align", num_workers=args.workers)
+ make_binary_alignment_dataset(
+ args.validpref + "." + args.align_suffix,
+ "valid.align",
+ num_workers=args.workers,
+ )
if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix):
- make_binary_alignment_dataset(args.testpref + "." + args.align_suffix, "test.align", num_workers=args.workers)
+ make_binary_alignment_dataset(
+ args.testpref + "." + args.align_suffix,
+ "test.align",
+ num_workers=args.workers,
+ )
make_all(args.source_lang, src_dict)
if target:
@@ -269,9 +294,9 @@ def make_all_alignments():
src_file_name = train_path(args.source_lang)
tgt_file_name = train_path(args.target_lang)
freq_map = {}
- with open(args.alignfile, "r", encoding='utf-8') as align_file:
- with open(src_file_name, "r", encoding='utf-8') as src_file:
- with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
+ with open(args.alignfile, "r", encoding="utf-8") as align_file:
+ with open(src_file_name, "r", encoding="utf-8") as src_file:
+ with open(tgt_file_name, "r", encoding="utf-8") as tgt_file:
for a, s, t in zip_longest(align_file, src_file, tgt_file):
si = src_dict.encode_line(s, add_if_not_exist=False)
ti = tgt_dict.encode_line(t, add_if_not_exist=False)
@@ -297,38 +322,47 @@ def make_all_alignments():
align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)
with open(
- os.path.join(
- args.destdir,
- "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
- ),
- "w", encoding='utf-8'
+ os.path.join(
+ args.destdir,
+ "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
+ ),
+ "w",
+ encoding="utf-8",
) as f:
for k, v in align_dict.items():
print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos=True):
- ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"),
- impl=args.dataset_impl, vocab_size=len(vocab))
+ ds = indexed_dataset.make_builder(
+ dataset_dest_file(args, output_prefix, lang, "bin"),
+ impl=args.dataset_impl,
+ vocab_size=len(vocab),
+ )
def consumer(tensor):
ds.add_item(tensor)
- res = Binarizer.binarize(filename, vocab, consumer, append_eos=append_eos,
- offset=offset, end=end)
+ res = Binarizer.binarize(
+ filename, vocab, consumer, append_eos=append_eos, offset=offset, end=end
+ )
ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
return res
def binarize_alignments(args, filename, parse_alignment, output_prefix, offset, end):
- ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, None, "bin"),
- impl=args.dataset_impl, vocab_size=None)
+ ds = indexed_dataset.make_builder(
+ dataset_dest_file(args, output_prefix, None, "bin"),
+ impl=args.dataset_impl,
+ vocab_size=None,
+ )
def consumer(tensor):
ds.add_item(tensor)
- res = Binarizer.binarize_alignments(filename, parse_alignment, consumer, offset=offset,
- end=end)
+ res = Binarizer.binarize_alignments(
+ filename, parse_alignment, consumer, offset=offset, end=end
+ )
ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))
return res
diff --git a/fairseq_cli/score.py b/fairseq_cli/score.py
index 59631c2d65..b8354eb95a 100644
--- a/fairseq_cli/score.py
+++ b/fairseq_cli/score.py
@@ -11,12 +11,14 @@
import os
import sys
-from fairseq.scoring import bleu
from fairseq.data import dictionary
+from fairseq.scoring import bleu
def get_parser():
- parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.')
+ parser = argparse.ArgumentParser(
+ description="Command-line script for BLEU scoring."
+ )
# fmt: off
parser.add_argument('-s', '--sys', default='-', help='system output')
parser.add_argument('-r', '--ref', required=True, help='references')
@@ -37,10 +39,10 @@ def cli_main():
args = parser.parse_args()
print(args)
- assert args.sys == '-' or os.path.exists(args.sys), \
- "System output file {} does not exist".format(args.sys)
- assert os.path.exists(args.ref), \
- "Reference file {} does not exist".format(args.ref)
+ assert args.sys == "-" or os.path.exists(
+ args.sys
+ ), "System output file {} does not exist".format(args.sys)
+ assert os.path.exists(args.ref), "Reference file {} does not exist".format(args.ref)
dict = dictionary.Dictionary()
@@ -57,17 +59,23 @@ def readlines(fd):
def score(fdsys):
with open(args.ref) as fdref:
print(sacrebleu.corpus_bleu(fdsys, [fdref]))
+
elif args.sentence_bleu:
+
def score(fdsys):
with open(args.ref) as fdref:
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
- for i, (sys_tok, ref_tok) in enumerate(zip(readlines(fdsys), readlines(fdref))):
+ for i, (sys_tok, ref_tok) in enumerate(
+ zip(readlines(fdsys), readlines(fdref))
+ ):
scorer.reset(one_init=True)
sys_tok = dict.encode_line(sys_tok)
ref_tok = dict.encode_line(ref_tok)
scorer.add(ref_tok, sys_tok)
print(i, scorer.result_string(args.order))
+
else:
+
def score(fdsys):
with open(args.ref) as fdref:
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
@@ -77,12 +85,12 @@ def score(fdsys):
scorer.add(ref_tok, sys_tok)
print(scorer.result_string(args.order))
- if args.sys == '-':
+ if args.sys == "-":
score(sys.stdin)
else:
- with open(args.sys, 'r') as f:
+ with open(args.sys, "r") as f:
score(f)
-if __name__ == '__main__':
+if __name__ == "__main__":
cli_main()
diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py
index 717a776c8f..df857550d1 100644
--- a/fairseq_cli/validate.py
+++ b/fairseq_cli/validate.py
@@ -5,31 +5,31 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from itertools import chain
import logging
import os
import sys
+from itertools import chain
import torch
-
from fairseq import checkpoint_utils, distributed_utils, options, utils
from fairseq.logging import metrics, progress_bar
logging.basicConfig(
- format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
- datefmt='%Y-%m-%d %H:%M:%S',
- level=os.environ.get('LOGLEVEL', 'INFO').upper(),
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=sys.stdout,
)
-logger = logging.getLogger('fairseq_cli.validate')
+logger = logging.getLogger("fairseq_cli.validate")
def main(args, override_args=None):
utils.import_user_module(args)
- assert args.max_tokens is not None or args.batch_size is not None, \
- 'Must specify batch size either with --max-tokens or --batch-size'
+ assert (
+ args.max_tokens is not None or args.batch_size is not None
+ ), "Must specify batch size either with --max-tokens or --batch-size"
use_fp16 = args.fp16
use_cuda = torch.cuda.is_available() and not args.cpu
@@ -39,12 +39,12 @@ def main(args, override_args=None):
if override_args is not None:
overrides = vars(override_args)
- overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
+ overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
else:
overrides = None
# Load ensemble
- logger.info('loading model(s) from {}'.format(args.path))
+ logger.info("loading model(s) from {}".format(args.path))
models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
[args.path],
arg_overrides=overrides,
@@ -66,12 +66,12 @@ def main(args, override_args=None):
criterion = task.build_criterion(model_args)
criterion.eval()
- for subset in args.valid_subset.split(','):
+ for subset in args.valid_subset.split(","):
try:
task.load_dataset(subset, combine=False, epoch=1)
dataset = task.dataset(subset)
except KeyError:
- raise Exception('Cannot find dataset: ' + subset)
+ raise Exception("Cannot find dataset: " + subset)
# Initialize data iterator
itr = task.get_batch_iterator(
@@ -95,7 +95,7 @@ def main(args, override_args=None):
log_format=args.log_format,
log_interval=args.log_interval,
prefix=f"valid on '{subset}' subset",
- default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
+ default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
)
log_outputs = []
@@ -108,7 +108,7 @@ def main(args, override_args=None):
if args.distributed_world_size > 1:
log_outputs = distributed_utils.all_gather_list(
log_outputs,
- max_size=getattr(args, 'all_gather_list_size', 16384),
+ max_size=getattr(args, "all_gather_list_size", 16384),
)
log_outputs = list(chain.from_iterable(log_outputs))
@@ -130,5 +130,5 @@ def cli_main():
distributed_utils.call_main(args, main, override_args=override_args)
-if __name__ == '__main__':
+if __name__ == "__main__":
cli_main()
diff --git a/hubconf.py b/hubconf.py
index c63fa8ae89..ce7d76cfe1 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -6,14 +6,20 @@
import functools
import importlib
+from fairseq.hub_utils import ( # noqa; noqa
+ BPEHubInterface as bpe,
+ TokenizerHubInterface as tokenizer,
+)
+from fairseq.models import MODEL_REGISTRY # noqa
+
dependencies = [
- 'dataclasses',
- 'hydra',
- 'numpy',
- 'regex',
- 'requests',
- 'torch',
+ "dataclasses",
+ "hydra",
+ "numpy",
+ "regex",
+ "requests",
+ "torch",
]
@@ -26,11 +32,11 @@
# Hack: the hydra package is provided under the "hydra-core" name in
# pypi. We don't want the user mistakenly calling `pip install hydra`
# since that will install an unrelated package.
- if dep == 'hydra':
- dep = 'hydra-core'
+ if dep == "hydra":
+ dep = "hydra-core"
missing_deps.append(dep)
if len(missing_deps) > 0:
- raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps)))
+ raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps)))
# torch.hub doesn't build Cython components, so if they are not found then try
@@ -42,22 +48,18 @@
import cython # noqa
import os
from setuptools import sandbox
+
sandbox.run_setup(
- os.path.join(os.path.dirname(__file__), 'setup.py'),
- ['build_ext', '--inplace'],
+ os.path.join(os.path.dirname(__file__), "setup.py"),
+ ["build_ext", "--inplace"],
)
except ImportError:
print(
- 'Unable to build Cython components. Please make sure Cython is '
- 'installed if the torch.hub model you are loading depends on it.'
+ "Unable to build Cython components. Please make sure Cython is "
+ "installed if the torch.hub model you are loading depends on it."
)
-from fairseq.hub_utils import BPEHubInterface as bpe # noqa
-from fairseq.hub_utils import TokenizerHubInterface as tokenizer # noqa
-from fairseq.models import MODEL_REGISTRY # noqa
-
-
# automatically expose models defined in FairseqModel::hub_models
for _model_type, _cls in MODEL_REGISTRY.items():
for model_name in _cls.hub_models().keys():
diff --git a/scripts/average_checkpoints.py b/scripts/average_checkpoints.py
index 9d69671e7e..c512f802bc 100644
--- a/scripts/average_checkpoints.py
+++ b/scripts/average_checkpoints.py
@@ -6,10 +6,10 @@
import argparse
import collections
-import torch
import os
import re
+import torch
from fairseq.file_io import PathManager
@@ -30,26 +30,26 @@ def average_checkpoints(inputs):
num_models = len(inputs)
for fpath in inputs:
- with PathManager.open(fpath, 'rb') as f:
+ with PathManager.open(fpath, "rb") as f:
state = torch.load(
f,
map_location=(
- lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
+ lambda s, _: torch.serialization.default_restore_location(s, "cpu")
),
)
# Copies over the settings from the first checkpoint
if new_state is None:
new_state = state
- model_params = state['model']
+ model_params = state["model"]
model_params_keys = list(model_params.keys())
if params_keys is None:
params_keys = model_params_keys
elif params_keys != model_params_keys:
raise KeyError(
- 'For checkpoint {}, expected list of params: {}, '
- 'but found: {}'.format(f, params_keys, model_params_keys)
+ "For checkpoint {}, expected list of params: {}, "
+ "but found: {}".format(f, params_keys, model_params_keys)
)
for k in params_keys:
@@ -69,7 +69,7 @@ def average_checkpoints(inputs):
averaged_params[k].div_(num_models)
else:
averaged_params[k] //= num_models
- new_state['model'] = averaged_params
+ new_state["model"] = averaged_params
return new_state
@@ -77,9 +77,9 @@ def last_n_checkpoints(paths, n, update_based, upper_bound=None):
assert len(paths) == 1
path = paths[0]
if update_based:
- pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
+ pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt")
else:
- pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
+ pt_regexp = re.compile(r"checkpoint(\d+)\.pt")
files = PathManager.ls(path)
entries = []
@@ -90,14 +90,16 @@ def last_n_checkpoints(paths, n, update_based, upper_bound=None):
if upper_bound is None or sort_key <= upper_bound:
entries.append((sort_key, m.group(0)))
if len(entries) < n:
- raise Exception('Found {} checkpoint files but need at least {}', len(entries), n)
+ raise Exception(
+ "Found {} checkpoint files but need at least {}", len(entries), n
+ )
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
def main():
parser = argparse.ArgumentParser(
- description='Tool to average the params of input checkpoints to '
- 'produce a new checkpoint',
+ description="Tool to average the params of input checkpoints to "
+ "produce a new checkpoint",
)
# fmt: off
parser.add_argument('--inputs', required=True, nargs='+',
@@ -129,22 +131,28 @@ def main():
elif args.num_epoch_checkpoints is not None:
num = args.num_epoch_checkpoints
- assert args.checkpoint_upper_bound is None or (args.num_epoch_checkpoints is not None or args.num_update_checkpoints is not None), \
- '--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints'
- assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \
- 'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints'
+ assert args.checkpoint_upper_bound is None or (
+ args.num_epoch_checkpoints is not None
+ or args.num_update_checkpoints is not None
+ ), "--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints"
+ assert (
+ args.num_epoch_checkpoints is None or args.num_update_checkpoints is None
+ ), "Cannot combine --num-epoch-checkpoints and --num-update-checkpoints"
if num is not None:
args.inputs = last_n_checkpoints(
- args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound,
+ args.inputs,
+ num,
+ is_update_based,
+ upper_bound=args.checkpoint_upper_bound,
)
- print('averaging checkpoints: ', args.inputs)
+ print("averaging checkpoints: ", args.inputs)
new_state = average_checkpoints(args.inputs)
- with PathManager.open(args.output, 'wb') as f:
+ with PathManager.open(args.output, "wb") as f:
torch.save(new_state, f)
- print('Finished writing averaged checkpoint to {}'.format(args.output))
+ print("Finished writing averaged checkpoint to {}".format(args.output))
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/scripts/build_sym_alignment.py b/scripts/build_sym_alignment.py
index bb0cac09dd..0ca5c18f7b 100644
--- a/scripts/build_sym_alignment.py
+++ b/scripts/build_sym_alignment.py
@@ -27,7 +27,7 @@
def main():
- parser = argparse.ArgumentParser(description='symmetric alignment builer')
+ parser = argparse.ArgumentParser(description="symmetric alignment builer")
# fmt: off
parser.add_argument('--fast_align_dir',
help='path to fast_align build directory')
@@ -47,40 +47,40 @@ def main():
# fmt: on
args = parser.parse_args()
- fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align')
- symal_bin = os.path.join(args.mosesdecoder_dir, 'bin', 'symal')
+ fast_align_bin = os.path.join(args.fast_align_dir, "fast_align")
+ symal_bin = os.path.join(args.mosesdecoder_dir, "bin", "symal")
sym_fast_align_bin = os.path.join(
- args.mosesdecoder_dir, 'scripts', 'ems',
- 'support', 'symmetrize-fast-align.perl')
+ args.mosesdecoder_dir, "scripts", "ems", "support", "symmetrize-fast-align.perl"
+ )
# create joined file
- joined_file = os.path.join(args.output_dir, 'text.joined')
- with open(args.source_file, 'r', encoding='utf-8') as src, open(args.target_file, 'r', encoding='utf-8') as tgt:
- with open(joined_file, 'w', encoding='utf-8') as joined:
+ joined_file = os.path.join(args.output_dir, "text.joined")
+ with open(args.source_file, "r", encoding="utf-8") as src, open(
+ args.target_file, "r", encoding="utf-8"
+ ) as tgt:
+ with open(joined_file, "w", encoding="utf-8") as joined:
for s, t in zip_longest(src, tgt):
- print('{} ||| {}'.format(s.strip(), t.strip()), file=joined)
+ print("{} ||| {}".format(s.strip(), t.strip()), file=joined)
- bwd_align_file = os.path.join(args.output_dir, 'align.backward')
+ bwd_align_file = os.path.join(args.output_dir, "align.backward")
# run forward alignment
- fwd_align_file = os.path.join(args.output_dir, 'align.forward')
- fwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v > {FWD}'.format(
- FASTALIGN=fast_align_bin,
- JOINED=joined_file,
- FWD=fwd_align_file)
+ fwd_align_file = os.path.join(args.output_dir, "align.forward")
+ fwd_fast_align_cmd = "{FASTALIGN} -i {JOINED} -d -o -v > {FWD}".format(
+ FASTALIGN=fast_align_bin, JOINED=joined_file, FWD=fwd_align_file
+ )
assert os.system(fwd_fast_align_cmd) == 0
# run backward alignment
- bwd_align_file = os.path.join(args.output_dir, 'align.backward')
- bwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}'.format(
- FASTALIGN=fast_align_bin,
- JOINED=joined_file,
- BWD=bwd_align_file)
+ bwd_align_file = os.path.join(args.output_dir, "align.backward")
+ bwd_fast_align_cmd = "{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}".format(
+ FASTALIGN=fast_align_bin, JOINED=joined_file, BWD=bwd_align_file
+ )
assert os.system(bwd_fast_align_cmd) == 0
# run symmetrization
- sym_out_file = os.path.join(args.output_dir, 'aligned')
- sym_cmd = '{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}'.format(
+ sym_out_file = os.path.join(args.output_dir, "aligned")
+ sym_cmd = "{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}".format(
SYMFASTALIGN=sym_fast_align_bin,
FWD=fwd_align_file,
BWD=bwd_align_file,
@@ -88,10 +88,10 @@ def main():
TGT=args.target_file,
OUT=sym_out_file,
HEURISTIC=args.sym_heuristic,
- SYMAL=symal_bin
+ SYMAL=symal_bin,
)
assert os.system(sym_cmd) == 0
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/scripts/compare_namespaces.py b/scripts/compare_namespaces.py
index db5121189a..bc24db624f 100644
--- a/scripts/compare_namespaces.py
+++ b/scripts/compare_namespaces.py
@@ -6,13 +6,13 @@
def main():
- ns1 = eval(input('Namespace 1: '))
- ns2 = eval(input('Namespace 2: '))
+ ns1 = eval(input("Namespace 1: "))
+ ns2 = eval(input("Namespace 2: "))
def keys(ns):
ks = set()
for k in dir(ns):
- if not k.startswith('_'):
+ if not k.startswith("_"):
ks.add(k)
return ks
@@ -22,23 +22,25 @@ def keys(ns):
def print_keys(ks, ns1, ns2=None):
for k in ks:
if ns2 is None:
- print('{}\t{}'.format(k, getattr(ns1, k, None)))
+ print("{}\t{}".format(k, getattr(ns1, k, None)))
else:
- print('{}\t{}\t{}'.format(k, getattr(ns1, k, None), getattr(ns2, k, None)))
+ print(
+ "{}\t{}\t{}".format(k, getattr(ns1, k, None), getattr(ns2, k, None))
+ )
- print('Keys unique to namespace 1:')
+ print("Keys unique to namespace 1:")
print_keys(k1 - k2, ns1)
print()
- print('Keys unique to namespace 2:')
+ print("Keys unique to namespace 2:")
print_keys(k2 - k1, ns2)
print()
- print('Overlapping keys with different values:')
- ks = [k for k in k1 & k2 if getattr(ns1, k, 'None') != getattr(ns2, k, 'None')]
+ print("Overlapping keys with different values:")
+ ks = [k for k in k1 & k2 if getattr(ns1, k, "None") != getattr(ns2, k, "None")]
print_keys(ks, ns1, ns2)
print()
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/scripts/constraints/extract.py b/scripts/constraints/extract.py
index 8f9bc4ad14..f6155d0a05 100755
--- a/scripts/constraints/extract.py
+++ b/scripts/constraints/extract.py
@@ -10,12 +10,13 @@
import argparse
import random
import sys
+
from sacrebleu import extract_ngrams
def get_phrase(words, index, length):
- assert(index < len(words) - length + 1)
- phr = ' '.join(words[index:index+length])
+ assert index < len(words) - length + 1
+ phr = " ".join(words[index : index + length])
for i in range(index, index + length):
words.pop(index)
return phr
@@ -33,8 +34,8 @@ def add_constraint(constraint):
constraints.append(constraint)
source = line.rstrip()
- if '\t' in line:
- source, target = line.split('\t')
+ if "\t" in line:
+ source, target = line.split("\t")
if args.add_sos:
target = f" {target}"
if args.add_eos:
@@ -53,8 +54,12 @@ def add_constraint(constraint):
segment = words.pop(segmentno)
tokens = segment.split()
phrase_index = random.choice(range(len(tokens)))
- choice = " ".join(tokens[phrase_index:min(len(tokens), phrase_index + args.len)])
- for j in range(phrase_index, min(len(tokens), phrase_index + args.len)):
+ choice = " ".join(
+ tokens[phrase_index : min(len(tokens), phrase_index + args.len)]
+ )
+ for j in range(
+ phrase_index, min(len(tokens), phrase_index + args.len)
+ ):
tokens.pop(phrase_index)
if phrase_index > 0:
words.append(" ".join(tokens[0:phrase_index]))
@@ -73,11 +78,15 @@ def add_constraint(constraint):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--number', '-n', type=int, default=1, help="number of phrases")
- parser.add_argument('--len', '-l', type=int, default=1, help="phrase length")
- parser.add_argument('--add-sos', default=False, action='store_true', help='add token')
- parser.add_argument('--add-eos', default=False, action='store_true', help='add token')
- parser.add_argument('--seed', "-s", default=0, type=int)
+ parser.add_argument("--number", "-n", type=int, default=1, help="number of phrases")
+ parser.add_argument("--len", "-l", type=int, default=1, help="phrase length")
+ parser.add_argument(
+ "--add-sos", default=False, action="store_true", help="add token"
+ )
+ parser.add_argument(
+ "--add-eos", default=False, action="store_true", help="add token"
+ )
+ parser.add_argument("--seed", "-s", default=0, type=int)
args = parser.parse_args()
main(args)
diff --git a/scripts/constraints/validate.py b/scripts/constraints/validate.py
index 6d1a4a0885..d531ad9f39 100755
--- a/scripts/constraints/validate.py
+++ b/scripts/constraints/validate.py
@@ -7,6 +7,7 @@
import sys
+
"""Reads in a fairseq output file, and verifies that the constraints
(C- lines) are present in the output (the first H- line). Assumes that
constraints are listed prior to the first hypothesis.
diff --git a/scripts/count_docs.py b/scripts/count_docs.py
index 8d185398a7..58d85af85e 100644
--- a/scripts/count_docs.py
+++ b/scripts/count_docs.py
@@ -17,15 +17,15 @@
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('input')
- parser.add_argument('--gzip', action='store_true')
+ parser.add_argument("input")
+ parser.add_argument("--gzip", action="store_true")
args = parser.parse_args()
def gopen():
if args.gzip:
- return gzip.open(args.input, 'r')
+ return gzip.open(args.input, "r")
else:
- return open(args.input, 'r', encoding='utf-8')
+ return open(args.input, "r", encoding="utf-8")
num_lines = []
num_toks = []
@@ -54,5 +54,5 @@ def gopen():
print("average num toks per doc: {}".format(np.mean(num_toks)))
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/scripts/read_binarized.py b/scripts/read_binarized.py
index f48409beb4..a414095d03 100644
--- a/scripts/read_binarized.py
+++ b/scripts/read_binarized.py
@@ -6,12 +6,13 @@
import argparse
-from fairseq.data import data_utils, Dictionary, indexed_dataset
+from fairseq.data import Dictionary, data_utils, indexed_dataset
def get_parser():
parser = argparse.ArgumentParser(
- description='writes text from binarized file to stdout')
+ description="writes text from binarized file to stdout"
+ )
# fmt: off
parser.add_argument('--dataset-impl', help='dataset implementation',
choices=indexed_dataset.get_available_dataset_impl())
@@ -31,17 +32,17 @@ def main():
args.input,
dictionary,
dataset_impl=args.dataset_impl,
- default='lazy',
+ default="lazy",
)
for tensor_line in dataset:
if dictionary is None:
- line = ' '.join([str(int(x)) for x in tensor_line])
+ line = " ".join([str(int(x)) for x in tensor_line])
else:
line = dictionary.string(tensor_line)
print(line)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/scripts/rm_pt.py b/scripts/rm_pt.py
index 21976cee4f..6cd063d21f 100644
--- a/scripts/rm_pt.py
+++ b/scripts/rm_pt.py
@@ -11,9 +11,9 @@
import sys
-pt_regexp = re.compile(r'checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt')
-pt_regexp_epoch_based = re.compile(r'checkpoint(\d+)\.pt')
-pt_regexp_update_based = re.compile(r'checkpoint_\d+_(\d+)\.pt')
+pt_regexp = re.compile(r"checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt")
+pt_regexp_epoch_based = re.compile(r"checkpoint(\d+)\.pt")
+pt_regexp_update_based = re.compile(r"checkpoint_\d+_(\d+)\.pt")
def parse_checkpoints(files):
@@ -42,18 +42,31 @@ def every_n_checkpoints(files, n):
def main():
parser = argparse.ArgumentParser(
description=(
- 'Recursively delete checkpoint files from `root_dir`, '
- 'but preserve checkpoint_best.pt and checkpoint_last.pt'
+ "Recursively delete checkpoint files from `root_dir`, "
+ "but preserve checkpoint_best.pt and checkpoint_last.pt"
)
)
- parser.add_argument('root_dirs', nargs='*')
- parser.add_argument('--save-last', type=int, default=0, help='number of last checkpoints to save')
- parser.add_argument('--save-every', type=int, default=0, help='interval of checkpoints to save')
- parser.add_argument('--preserve-test', action='store_true',
- help='preserve checkpoints in dirs that start with test_ prefix (default: delete them)')
- parser.add_argument('--delete-best', action='store_true', help='delete checkpoint_best.pt')
- parser.add_argument('--delete-last', action='store_true', help='delete checkpoint_last.pt')
- parser.add_argument('--no-dereference', action='store_true', help='don\'t dereference symlinks')
+ parser.add_argument("root_dirs", nargs="*")
+ parser.add_argument(
+ "--save-last", type=int, default=0, help="number of last checkpoints to save"
+ )
+ parser.add_argument(
+ "--save-every", type=int, default=0, help="interval of checkpoints to save"
+ )
+ parser.add_argument(
+ "--preserve-test",
+ action="store_true",
+ help="preserve checkpoints in dirs that start with test_ prefix (default: delete them)",
+ )
+ parser.add_argument(
+ "--delete-best", action="store_true", help="delete checkpoint_best.pt"
+ )
+ parser.add_argument(
+ "--delete-last", action="store_true", help="delete checkpoint_last.pt"
+ )
+ parser.add_argument(
+ "--no-dereference", action="store_true", help="don't dereference symlinks"
+ )
args = parser.parse_args()
files_to_desymlink = []
@@ -72,15 +85,11 @@ def main():
continue
full_path = os.path.join(root, file)
if (
- (
- not os.path.basename(root).startswith('test_')
- or args.preserve_test
- )
- and (
- (file == 'checkpoint_last.pt' and not args.delete_last)
- or (file == 'checkpoint_best.pt' and not args.delete_best)
- or file in to_save
- )
+ not os.path.basename(root).startswith("test_") or args.preserve_test
+ ) and (
+ (file == "checkpoint_last.pt" and not args.delete_last)
+ or (file == "checkpoint_best.pt" and not args.delete_best)
+ or file in to_save
):
if os.path.islink(full_path) and not args.no_dereference:
files_to_desymlink.append(full_path)
@@ -90,43 +99,43 @@ def main():
files_to_delete.append(full_path)
if len(files_to_desymlink) == 0 and len(files_to_delete) == 0:
- print('Nothing to do.')
+ print("Nothing to do.")
sys.exit(0)
files_to_desymlink = sorted(files_to_desymlink)
files_to_preserve = sorted(files_to_preserve)
files_to_delete = sorted(files_to_delete)
- print('Operations to perform (in order):')
+ print("Operations to perform (in order):")
if len(files_to_desymlink) > 0:
for file in files_to_desymlink:
- print(' - preserve (and dereference symlink): ' + file)
+ print(" - preserve (and dereference symlink): " + file)
if len(files_to_preserve) > 0:
for file in files_to_preserve:
- print(' - preserve: ' + file)
+ print(" - preserve: " + file)
if len(files_to_delete) > 0:
for file in files_to_delete:
- print(' - delete: ' + file)
+ print(" - delete: " + file)
while True:
- resp = input('Continue? (Y/N): ')
- if resp.strip().lower() == 'y':
+ resp = input("Continue? (Y/N): ")
+ if resp.strip().lower() == "y":
break
- elif resp.strip().lower() == 'n':
+ elif resp.strip().lower() == "n":
sys.exit(0)
- print('Executing...')
+ print("Executing...")
if len(files_to_desymlink) > 0:
for file in files_to_desymlink:
realpath = os.path.realpath(file)
- print('rm ' + file)
+ print("rm " + file)
os.remove(file)
- print('cp {} {}'.format(realpath, file))
+ print("cp {} {}".format(realpath, file))
shutil.copyfile(realpath, file)
if len(files_to_delete) > 0:
for file in files_to_delete:
- print('rm ' + file)
+ print("rm " + file)
os.remove(file)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/scripts/shard_docs.py b/scripts/shard_docs.py
index 87d7c22d4f..97232c3c84 100644
--- a/scripts/shard_docs.py
+++ b/scripts/shard_docs.py
@@ -14,21 +14,23 @@
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('input')
- parser.add_argument('--num-shards', type=int)
+ parser.add_argument("input")
+ parser.add_argument("--num-shards", type=int)
args = parser.parse_args()
assert args.num_shards is not None and args.num_shards > 1
- with open(args.input, 'r', encoding='utf-8') as h:
+ with open(args.input, "r", encoding="utf-8") as h:
with contextlib.ExitStack() as stack:
outputs = [
- stack.enter_context(open(args.input + ".shard" + str(i), "w", encoding="utf-8"))
+ stack.enter_context(
+ open(args.input + ".shard" + str(i), "w", encoding="utf-8")
+ )
for i in range(args.num_shards)
]
doc = []
- first_doc = [True]*args.num_shards
+ first_doc = [True] * args.num_shards
def output_doc(i):
if not first_doc[i]:
@@ -48,5 +50,5 @@ def output_doc(i):
output_doc(num_docs % args.num_shards)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/scripts/split_train_valid_docs.py b/scripts/split_train_valid_docs.py
index 9adf99634c..ff15978528 100644
--- a/scripts/split_train_valid_docs.py
+++ b/scripts/split_train_valid_docs.py
@@ -15,12 +15,13 @@
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('input')
- parser.add_argument('sample_output', help='train output file')
- parser.add_argument('remainder_output', help='valid output file')
- parser.add_argument('-k', type=int, help="remainder size")
- parser.add_argument('--lines', action='store_true',
- help='split lines instead of docs')
+ parser.add_argument("input")
+ parser.add_argument("sample_output", help="train output file")
+ parser.add_argument("remainder_output", help="valid output file")
+ parser.add_argument("-k", type=int, help="remainder size")
+ parser.add_argument(
+ "--lines", action="store_true", help="split lines instead of docs"
+ )
args = parser.parse_args()
assert args.k is not None
@@ -43,7 +44,7 @@ def update_sample(doc):
num_docs[0] += 1
doc.clear()
- with open(args.input, 'r', encoding='utf-8') as h:
+ with open(args.input, "r", encoding="utf-8") as h:
doc = []
for i, line in enumerate(h):
if line.strip() == "": # empty line indicates new document
@@ -62,7 +63,7 @@ def update_sample(doc):
assert len(sample) == args.k
- with open(args.sample_output, 'w', encoding='utf-8') as out:
+ with open(args.sample_output, "w", encoding="utf-8") as out:
first = True
for doc in sample:
if not first and not args.lines:
@@ -71,7 +72,7 @@ def update_sample(doc):
for line in doc:
out.write(line)
- with open(args.remainder_output, 'w', encoding='utf-8') as out:
+ with open(args.remainder_output, "w", encoding="utf-8") as out:
first = True
for doc in remainder:
if not first and not args.lines:
@@ -81,5 +82,5 @@ def update_sample(doc):
out.write(line)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/scripts/spm_decode.py b/scripts/spm_decode.py
index bd3961ab97..1c18b1d2a7 100644
--- a/scripts/spm_decode.py
+++ b/scripts/spm_decode.py
@@ -14,8 +14,9 @@
def main():
parser = argparse.ArgumentParser()
- parser.add_argument("--model", required=True,
- help="sentencepiece model to use for decoding")
+ parser.add_argument(
+ "--model", required=True, help="sentencepiece model to use for decoding"
+ )
parser.add_argument("--input", required=True, help="input file to decode")
parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
args = parser.parse_args()
@@ -24,11 +25,15 @@ def main():
sp.Load(args.model)
if args.input_format == "piece":
+
def decode(l):
return "".join(sp.DecodePieces(l))
+
elif args.input_format == "id":
+
def decode(l):
return "".join(sp.DecodeIds(l))
+
else:
raise NotImplementedError
@@ -43,5 +48,6 @@ def tok2int(tok):
elif args.input_format == "piece":
print(decode(line.rstrip().split()))
+
if __name__ == "__main__":
main()
diff --git a/scripts/spm_encode.py b/scripts/spm_encode.py
index e1cb54192a..83facfb3b1 100644
--- a/scripts/spm_encode.py
+++ b/scripts/spm_encode.py
@@ -16,53 +16,73 @@
def main():
parser = argparse.ArgumentParser()
- parser.add_argument("--model", required=True,
- help="sentencepiece model to use for encoding")
- parser.add_argument("--inputs", nargs="+", default=['-'],
- help="input files to filter/encode")
- parser.add_argument("--outputs", nargs="+", default=['-'],
- help="path to save encoded outputs")
+ parser.add_argument(
+ "--model", required=True, help="sentencepiece model to use for encoding"
+ )
+ parser.add_argument(
+ "--inputs", nargs="+", default=["-"], help="input files to filter/encode"
+ )
+ parser.add_argument(
+ "--outputs", nargs="+", default=["-"], help="path to save encoded outputs"
+ )
parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
- parser.add_argument("--min-len", type=int, metavar="N",
- help="filter sentence pairs with fewer than N tokens")
- parser.add_argument("--max-len", type=int, metavar="N",
- help="filter sentence pairs with more than N tokens")
+ parser.add_argument(
+ "--min-len",
+ type=int,
+ metavar="N",
+ help="filter sentence pairs with fewer than N tokens",
+ )
+ parser.add_argument(
+ "--max-len",
+ type=int,
+ metavar="N",
+ help="filter sentence pairs with more than N tokens",
+ )
args = parser.parse_args()
- assert len(args.inputs) == len(args.outputs), \
- "number of input and output paths should match"
+ assert len(args.inputs) == len(
+ args.outputs
+ ), "number of input and output paths should match"
sp = spm.SentencePieceProcessor()
sp.Load(args.model)
if args.output_format == "piece":
+
def encode(l):
return sp.EncodeAsPieces(l)
+
elif args.output_format == "id":
+
def encode(l):
return list(map(str, sp.EncodeAsIds(l)))
+
else:
raise NotImplementedError
if args.min_len is not None or args.max_len is not None:
+
def valid(line):
- return (
- (args.min_len is None or len(line) >= args.min_len)
- and (args.max_len is None or len(line) <= args.max_len)
+ return (args.min_len is None or len(line) >= args.min_len) and (
+ args.max_len is None or len(line) <= args.max_len
)
+
else:
+
def valid(lines):
return True
with contextlib.ExitStack() as stack:
inputs = [
- stack.enter_context(open(input, "r", encoding="utf-8")) \
- if input != "-" else sys.stdin
+ stack.enter_context(open(input, "r", encoding="utf-8"))
+ if input != "-"
+ else sys.stdin
for input in args.inputs
]
outputs = [
- stack.enter_context(open(output, "w", encoding="utf-8")) \
- if output != "-" else sys.stdout
+ stack.enter_context(open(output, "w", encoding="utf-8"))
+ if output != "-"
+ else sys.stdout
for output in args.outputs
]
diff --git a/setup.py b/setup.py
index 21e05d8da6..ad2ea2088b 100644
--- a/setup.py
+++ b/setup.py
@@ -5,22 +5,23 @@
# LICENSE file in the root directory of this source tree.
import os
-from setuptools import setup, find_packages, Extension
import sys
+from setuptools import Extension, find_packages, setup
+
if sys.version_info < (3, 6):
- sys.exit('Sorry, Python >= 3.6 is required for fairseq.')
+ sys.exit("Sorry, Python >= 3.6 is required for fairseq.")
-with open('README.md') as f:
+with open("README.md") as f:
readme = f.read()
-if sys.platform == 'darwin':
- extra_compile_args = ['-stdlib=libc++', '-O3']
+if sys.platform == "darwin":
+ extra_compile_args = ["-stdlib=libc++", "-O3"]
else:
- extra_compile_args = ['-std=c++11', '-O3']
+ extra_compile_args = ["-std=c++11", "-O3"]
class NumpyExtension(Extension):
@@ -33,6 +34,7 @@ def __init__(self, *args, **kwargs):
@property
def include_dirs(self):
import numpy
+
return self.__include_dirs + [numpy.get_include()]
@include_dirs.setter
@@ -42,23 +44,23 @@ def include_dirs(self, dirs):
extensions = [
Extension(
- 'fairseq.libbleu',
+ "fairseq.libbleu",
sources=[
- 'fairseq/clib/libbleu/libbleu.cpp',
- 'fairseq/clib/libbleu/module.cpp',
+ "fairseq/clib/libbleu/libbleu.cpp",
+ "fairseq/clib/libbleu/module.cpp",
],
extra_compile_args=extra_compile_args,
),
NumpyExtension(
- 'fairseq.data.data_utils_fast',
- sources=['fairseq/data/data_utils_fast.pyx'],
- language='c++',
+ "fairseq.data.data_utils_fast",
+ sources=["fairseq/data/data_utils_fast.pyx"],
+ language="c++",
extra_compile_args=extra_compile_args,
),
NumpyExtension(
- 'fairseq.data.token_block_utils_fast',
- sources=['fairseq/data/token_block_utils_fast.pyx'],
- language='c++',
+ "fairseq.data.token_block_utils_fast",
+ sources=["fairseq/data/token_block_utils_fast.pyx"],
+ language="c++",
extra_compile_args=extra_compile_args,
),
]
@@ -70,94 +72,104 @@ def include_dirs(self, dirs):
try:
# torch is not available when generating docs
from torch.utils import cpp_extension
- extensions.extend([
- cpp_extension.CppExtension(
- 'fairseq.libnat',
- sources=[
- 'fairseq/clib/libnat/edit_dist.cpp',
- ],
- )
- ])
- if 'CUDA_HOME' in os.environ:
- extensions.extend([
+ extensions.extend(
+ [
cpp_extension.CppExtension(
- 'fairseq.libnat_cuda',
+ "fairseq.libnat",
sources=[
- 'fairseq/clib/libnat_cuda/edit_dist.cu',
- 'fairseq/clib/libnat_cuda/binding.cpp'
+ "fairseq/clib/libnat/edit_dist.cpp",
],
- )])
- cmdclass['build_ext'] = cpp_extension.BuildExtension
+ )
+ ]
+ )
+
+ if "CUDA_HOME" in os.environ:
+ extensions.extend(
+ [
+ cpp_extension.CppExtension(
+ "fairseq.libnat_cuda",
+ sources=[
+ "fairseq/clib/libnat_cuda/edit_dist.cu",
+ "fairseq/clib/libnat_cuda/binding.cpp",
+ ],
+ )
+ ]
+ )
+ cmdclass["build_ext"] = cpp_extension.BuildExtension
except ImportError:
pass
-if 'READTHEDOCS' in os.environ:
+if "READTHEDOCS" in os.environ:
# don't build extensions when generating docs
extensions = []
- if 'build_ext' in cmdclass:
- del cmdclass['build_ext']
+ if "build_ext" in cmdclass:
+ del cmdclass["build_ext"]
# use CPU build of PyTorch
dependency_links = [
- 'https://download.pytorch.org/whl/cpu/torch-1.3.0%2Bcpu-cp36-cp36m-linux_x86_64.whl'
+ "https://download.pytorch.org/whl/cpu/torch-1.3.0%2Bcpu-cp36-cp36m-linux_x86_64.whl"
]
else:
dependency_links = []
-if 'clean' in sys.argv[1:]:
+if "clean" in sys.argv[1:]:
# Source: https://bit.ly/2NLVsgE
print("deleting Cython files...")
import subprocess
- subprocess.run(['rm -f fairseq/*.so fairseq/**/*.so fairseq/*.pyd fairseq/**/*.pyd'], shell=True)
+
+ subprocess.run(
+ ["rm -f fairseq/*.so fairseq/**/*.so fairseq/*.pyd fairseq/**/*.pyd"],
+ shell=True,
+ )
setup(
- name='fairseq',
- version='0.9.0',
- description='Facebook AI Research Sequence-to-Sequence Toolkit',
- url='https://github.com/pytorch/fairseq',
+ name="fairseq",
+ version="0.9.0",
+ description="Facebook AI Research Sequence-to-Sequence Toolkit",
+ url="https://github.com/pytorch/fairseq",
classifiers=[
- 'Intended Audience :: Science/Research',
- 'License :: OSI Approved :: MIT License',
- 'Programming Language :: Python :: 3.6',
- 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: MIT License",
+ "Programming Language :: Python :: 3.6",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
],
long_description=readme,
- long_description_content_type='text/markdown',
+ long_description_content_type="text/markdown",
setup_requires=[
- 'cython',
- 'numpy',
- 'setuptools>=18.0',
+ "cython",
+ "numpy",
+ "setuptools>=18.0",
],
install_requires=[
- 'cffi',
- 'cython',
- 'dataclasses',
- 'editdistance',
- 'hydra-core',
- 'numpy',
- 'regex',
- 'sacrebleu>=1.4.12',
- 'torch',
- 'tqdm',
+ "cffi",
+ "cython",
+ "dataclasses",
+ "editdistance",
+ "hydra-core",
+ "numpy",
+ "regex",
+ "sacrebleu>=1.4.12",
+ "torch",
+ "tqdm",
],
dependency_links=dependency_links,
- packages=find_packages(exclude=['scripts', 'tests']),
+ packages=find_packages(exclude=["scripts", "tests"]),
ext_modules=extensions,
- test_suite='tests',
+ test_suite="tests",
entry_points={
- 'console_scripts': [
- 'fairseq-eval-lm = fairseq_cli.eval_lm:cli_main',
- 'fairseq-generate = fairseq_cli.generate:cli_main',
- 'fairseq-interactive = fairseq_cli.interactive:cli_main',
- 'fairseq-preprocess = fairseq_cli.preprocess:cli_main',
- 'fairseq-score = fairseq_cli.score:cli_main',
- 'fairseq-train = fairseq_cli.train:cli_main',
- 'fairseq-validate = fairseq_cli.validate:cli_main',
+ "console_scripts": [
+ "fairseq-eval-lm = fairseq_cli.eval_lm:cli_main",
+ "fairseq-generate = fairseq_cli.generate:cli_main",
+ "fairseq-interactive = fairseq_cli.interactive:cli_main",
+ "fairseq-preprocess = fairseq_cli.preprocess:cli_main",
+ "fairseq-score = fairseq_cli.score:cli_main",
+ "fairseq-train = fairseq_cli.train:cli_main",
+ "fairseq-validate = fairseq_cli.validate:cli_main",
],
},
cmdclass=cmdclass,
diff --git a/tests/speech_recognition/asr_test_base.py b/tests/speech_recognition/asr_test_base.py
index 4f3d3fceb7..0341031394 100644
--- a/tests/speech_recognition/asr_test_base.py
+++ b/tests/speech_recognition/asr_test_base.py
@@ -7,6 +7,7 @@
import numpy as np
import torch
+from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.dictionary import Dictionary
from fairseq.models import (
@@ -18,7 +19,6 @@
FairseqModel,
)
from fairseq.tasks.fairseq_task import LegacyFairseqTask
-from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
DEFAULT_TEST_VOCAB_SIZE = 100
@@ -172,9 +172,8 @@ def check_encoder_output(encoder_output, batch_size=None):
"encoder_padding_mask must be a torch.Tensor" + _current_postion_info()
)
return False, msg
- if (
- mask.dtype != torch.uint8
- and (not hasattr(torch, 'bool') or mask.dtype != torch.bool)
+ if mask.dtype != torch.uint8 and (
+ not hasattr(torch, "bool") or mask.dtype != torch.bool
):
msg = (
"encoder_padding_mask must have dtype of uint8"
@@ -516,14 +515,16 @@ def setUpArgs(self):
def setUp(self):
args = self.setUpArgs()
self.model = DummyEncoderModel(encoder=DummyEncoder())
- self.criterion = self.criterion_cls.build_criterion(args=args, task=DummyTask(args))
+ self.criterion = self.criterion_cls.build_criterion(
+ args=args, task=DummyTask(args)
+ )
def get_src_tokens(self, correct_prediction, aggregate):
"""
- correct_prediction: True if the net_output (src_tokens) should
- predict the correct target
- aggregate: True if the criterion expects net_output (src_tokens)
- aggregated across time axis
+ correct_prediction: True if the net_output (src_tokens) should
+ predict the correct target
+ aggregate: True if the criterion expects net_output (src_tokens)
+ aggregated across time axis
"""
predicted_idx = 0 if correct_prediction else 1
if aggregate:
diff --git a/tests/speech_recognition/test_cross_entropy.py b/tests/speech_recognition/test_cross_entropy.py
index 508d490e01..b05400ed95 100644
--- a/tests/speech_recognition/test_cross_entropy.py
+++ b/tests/speech_recognition/test_cross_entropy.py
@@ -4,7 +4,10 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from examples.speech_recognition.criterions.cross_entropy_acc import CrossEntropyWithAccCriterion
+from examples.speech_recognition.criterions.cross_entropy_acc import (
+ CrossEntropyWithAccCriterion,
+)
+
from .asr_test_base import CrossEntropyCriterionTestBase
diff --git a/tests/speech_recognition/test_data_utils.py b/tests/speech_recognition/test_data_utils.py
index 5ca7c5c2a1..a72e0b6694 100644
--- a/tests/speech_recognition/test_data_utils.py
+++ b/tests/speech_recognition/test_data_utils.py
@@ -6,18 +6,57 @@
import unittest
import torch
-
from examples.speech_recognition.data import data_utils
class DataUtilsTest(unittest.TestCase):
-
def test_normalization(self):
- sample_len1 = torch.tensor([[-0.7661, -1.3889, -2.0972, -0.9134, -0.7071, -0.9765, -0.8700, -0.8283,
- 0.7512, 1.3211, 2.1532, 2.1174, 1.2800, 1.2633, 1.6147, 1.6322,
- 2.0723, 3.1522, 3.2852, 2.2309, 2.5569, 2.2183, 2.2862, 1.5886,
- 0.8773, 0.8725, 1.2662, 0.9899, 1.1069, 1.3926, 1.2795, 1.1199,
- 1.1477, 1.2687, 1.3843, 1.1903, 0.8355, 1.1367, 1.2639, 1.4707]])
+ sample_len1 = torch.tensor(
+ [
+ [
+ -0.7661,
+ -1.3889,
+ -2.0972,
+ -0.9134,
+ -0.7071,
+ -0.9765,
+ -0.8700,
+ -0.8283,
+ 0.7512,
+ 1.3211,
+ 2.1532,
+ 2.1174,
+ 1.2800,
+ 1.2633,
+ 1.6147,
+ 1.6322,
+ 2.0723,
+ 3.1522,
+ 3.2852,
+ 2.2309,
+ 2.5569,
+ 2.2183,
+ 2.2862,
+ 1.5886,
+ 0.8773,
+ 0.8725,
+ 1.2662,
+ 0.9899,
+ 1.1069,
+ 1.3926,
+ 1.2795,
+ 1.1199,
+ 1.1477,
+ 1.2687,
+ 1.3843,
+ 1.1903,
+ 0.8355,
+ 1.1367,
+ 1.2639,
+ 1.4707,
+ ]
+ ]
+ )
out = data_utils.apply_mv_norm(sample_len1)
assert not torch.isnan(out).any()
assert (out == sample_len1).all()
diff --git a/tests/test_average_checkpoints.py b/tests/test_average_checkpoints.py
index 8ed298c3c9..f348b56b86 100644
--- a/tests/test_average_checkpoints.py
+++ b/tests/test_average_checkpoints.py
@@ -5,16 +5,14 @@
import collections
import os
+import shutil
import tempfile
import unittest
-import shutil
import numpy as np
import torch
-from torch import nn
-
-
from scripts.average_checkpoints import average_checkpoints
+from torch import nn
class ModelWithSharedParameter(nn.Module):
@@ -37,33 +35,33 @@ class TestAverageCheckpoints(unittest.TestCase):
def test_average_checkpoints(self):
params_0 = collections.OrderedDict(
[
- ('a', torch.DoubleTensor([100.0])),
- ('b', torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
- ('c', torch.IntTensor([7, 8, 9])),
+ ("a", torch.DoubleTensor([100.0])),
+ ("b", torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
+ ("c", torch.IntTensor([7, 8, 9])),
]
)
params_1 = collections.OrderedDict(
[
- ('a', torch.DoubleTensor([1.0])),
- ('b', torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
- ('c', torch.IntTensor([2, 2, 2])),
+ ("a", torch.DoubleTensor([1.0])),
+ ("b", torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
+ ("c", torch.IntTensor([2, 2, 2])),
]
)
params_avg = collections.OrderedDict(
[
- ('a', torch.DoubleTensor([50.5])),
- ('b', torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
+ ("a", torch.DoubleTensor([50.5])),
+ ("b", torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
# We expect truncation for integer division
- ('c', torch.IntTensor([4, 5, 5])),
+ ("c", torch.IntTensor([4, 5, 5])),
]
)
fd_0, path_0 = tempfile.mkstemp()
fd_1, path_1 = tempfile.mkstemp()
- torch.save(collections.OrderedDict([('model', params_0)]), path_0)
- torch.save(collections.OrderedDict([('model', params_1)]), path_1)
+ torch.save(collections.OrderedDict([("model", params_0)]), path_0)
+ torch.save(collections.OrderedDict([("model", params_1)]), path_1)
- output = average_checkpoints([path_0, path_1])['model']
+ output = average_checkpoints([path_0, path_1])["model"]
os.close(fd_0)
os.remove(path_0)
@@ -71,28 +69,27 @@ def test_average_checkpoints(self):
os.remove(path_1)
for (k_expected, v_expected), (k_out, v_out) in zip(
- params_avg.items(), output.items()):
+ params_avg.items(), output.items()
+ ):
self.assertEqual(
- k_expected, k_out, 'Key mismatch - expected {} but found {}. '
- '(Expected list of keys: {} vs actual list of keys: {})'.format(
+ k_expected,
+ k_out,
+ "Key mismatch - expected {} but found {}. "
+ "(Expected list of keys: {} vs actual list of keys: {})".format(
k_expected, k_out, params_avg.keys(), output.keys()
- )
+ ),
)
np.testing.assert_allclose(
v_expected.numpy(),
v_out.numpy(),
- err_msg='Tensor value mismatch for key {}'.format(k_expected)
+ err_msg="Tensor value mismatch for key {}".format(k_expected),
)
def test_average_checkpoints_with_shared_parameters(self):
-
def _construct_model_with_shared_parameters(path, value):
m = ModelWithSharedParameter()
nn.init.constant_(m.FC1.weight, value)
- torch.save(
- {'model': m.state_dict()},
- path
- )
+ torch.save({"model": m.state_dict()}, path)
return m
tmpdir = tempfile.mkdtemp()
@@ -112,32 +109,26 @@ def _construct_model_with_shared_parameters(path, value):
new_model = average_checkpoints(paths)
self.assertTrue(
torch.equal(
- new_model['model']['embedding.weight'],
- (m1.embedding.weight +
- m2.embedding.weight +
- m3.embedding.weight) / 3.0
+ new_model["model"]["embedding.weight"],
+ (m1.embedding.weight + m2.embedding.weight + m3.embedding.weight) / 3.0,
)
)
self.assertTrue(
torch.equal(
- new_model['model']['FC1.weight'],
- (m1.FC1.weight +
- m2.FC1.weight +
- m3.FC1.weight) / 3.0
+ new_model["model"]["FC1.weight"],
+ (m1.FC1.weight + m2.FC1.weight + m3.FC1.weight) / 3.0,
)
)
self.assertTrue(
torch.equal(
- new_model['model']['FC2.weight'],
- (m1.FC2.weight +
- m2.FC2.weight +
- m3.FC2.weight) / 3.0
+ new_model["model"]["FC2.weight"],
+ (m1.FC2.weight + m2.FC2.weight + m3.FC2.weight) / 3.0,
)
)
shutil.rmtree(tmpdir)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_backtranslation_dataset.py b/tests/test_backtranslation_dataset.py
index 23ae333761..dffc3b4938 100644
--- a/tests/test_backtranslation_dataset.py
+++ b/tests/test_backtranslation_dataset.py
@@ -5,8 +5,8 @@
import unittest
+import tests.utils as test_utils
import torch
-
from fairseq.data import (
BacktranslationDataset,
LanguagePairDataset,
@@ -14,15 +14,17 @@
)
from fairseq.sequence_generator import SequenceGenerator
-import tests.utils as test_utils
-
class TestBacktranslationDataset(unittest.TestCase):
-
def setUp(self):
- self.tgt_dict, self.w1, self.w2, self.src_tokens, self.src_lengths, self.model = (
- test_utils.sequence_generator_setup()
- )
+ (
+ self.tgt_dict,
+ self.w1,
+ self.w2,
+ self.src_tokens,
+ self.src_lengths,
+ self.model,
+ ) = test_utils.sequence_generator_setup()
dummy_src_samples = self.src_tokens
@@ -30,7 +32,9 @@ def setUp(self):
self.cuda = torch.cuda.is_available()
def _backtranslation_dataset_helper(
- self, remove_eos_from_input_src, remove_eos_from_output_src,
+ self,
+ remove_eos_from_input_src,
+ remove_eos_from_output_src,
):
tgt_dataset = LanguagePairDataset(
src=self.tgt_dataset,
@@ -94,17 +98,20 @@ def _backtranslation_dataset_helper(
def test_backtranslation_dataset_no_eos_in_output_src(self):
self._backtranslation_dataset_helper(
- remove_eos_from_input_src=False, remove_eos_from_output_src=True,
+ remove_eos_from_input_src=False,
+ remove_eos_from_output_src=True,
)
def test_backtranslation_dataset_with_eos_in_output_src(self):
self._backtranslation_dataset_helper(
- remove_eos_from_input_src=False, remove_eos_from_output_src=False,
+ remove_eos_from_input_src=False,
+ remove_eos_from_output_src=False,
)
def test_backtranslation_dataset_no_eos_in_input_src(self):
self._backtranslation_dataset_helper(
- remove_eos_from_input_src=True, remove_eos_from_output_src=False,
+ remove_eos_from_input_src=True,
+ remove_eos_from_output_src=False,
)
def assertTensorEqual(self, t1, t2):
diff --git a/tests/test_binaries.py b/tests/test_binaries.py
index aa5a6c69d1..4b87afea55 100644
--- a/tests/test_binaries.py
+++ b/tests/test_binaries.py
@@ -4,31 +4,27 @@
# LICENSE file in the root directory of this source tree.
import contextlib
-from io import StringIO
import logging
import os
import random
import tempfile
import unittest
+from io import StringIO
import torch
-
from fairseq import options
-from fairseq_cli import train
-from fairseq_cli import eval_lm
-from fairseq_cli import validate
+from fairseq_cli import eval_lm, train, validate
from tests.utils import (
create_dummy_data,
+ generate_main,
preprocess_lm_data,
- preprocess_translation_data,
preprocess_summarization_data,
+ preprocess_translation_data,
train_translation_model,
- generate_main,
)
class TestTranslation(unittest.TestCase):
-
def setUp(self):
logging.disable(logging.CRITICAL)
@@ -37,180 +33,271 @@ def tearDown(self):
def test_fconv(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_fconv') as data_dir:
+ with tempfile.TemporaryDirectory("test_fconv") as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
- train_translation_model(data_dir, 'fconv_iwslt_de_en')
+ train_translation_model(data_dir, "fconv_iwslt_de_en")
generate_main(data_dir)
def test_raw(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_fconv_raw') as data_dir:
+ with tempfile.TemporaryDirectory("test_fconv_raw") as data_dir:
create_dummy_data(data_dir)
- preprocess_translation_data(data_dir, ['--dataset-impl', 'raw'])
- train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--dataset-impl', 'raw'])
- generate_main(data_dir, ['--dataset-impl', 'raw'])
+ preprocess_translation_data(data_dir, ["--dataset-impl", "raw"])
+ train_translation_model(
+ data_dir, "fconv_iwslt_de_en", ["--dataset-impl", "raw"]
+ )
+ generate_main(data_dir, ["--dataset-impl", "raw"])
def test_update_freq(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_update_freq') as data_dir:
+ with tempfile.TemporaryDirectory("test_update_freq") as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
- train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--update-freq', '3'])
+ train_translation_model(
+ data_dir, "fconv_iwslt_de_en", ["--update-freq", "3"]
+ )
generate_main(data_dir)
def test_max_positions(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_max_positions') as data_dir:
+ with tempfile.TemporaryDirectory("test_max_positions") as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
with self.assertRaises(Exception) as context:
train_translation_model(
- data_dir, 'fconv_iwslt_de_en', ['--max-target-positions', '5'],
+ data_dir,
+ "fconv_iwslt_de_en",
+ ["--max-target-positions", "5"],
)
self.assertTrue(
- 'skip this example with --skip-invalid-size-inputs-valid-test' in str(context.exception)
+ "skip this example with --skip-invalid-size-inputs-valid-test"
+ in str(context.exception)
)
train_translation_model(
- data_dir, 'fconv_iwslt_de_en',
- ['--max-target-positions', '5', '--skip-invalid-size-inputs-valid-test'],
+ data_dir,
+ "fconv_iwslt_de_en",
+ [
+ "--max-target-positions",
+ "5",
+ "--skip-invalid-size-inputs-valid-test",
+ ],
)
with self.assertRaises(Exception) as context:
generate_main(data_dir)
- generate_main(data_dir, ['--skip-invalid-size-inputs-valid-test'])
+ generate_main(data_dir, ["--skip-invalid-size-inputs-valid-test"])
def test_generation(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_sampling') as data_dir:
+ with tempfile.TemporaryDirectory("test_sampling") as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
- train_translation_model(data_dir, 'fconv_iwslt_de_en')
- generate_main(data_dir, [
- '--sampling',
- '--temperature', '2',
- '--beam', '2',
- '--nbest', '2',
- ])
- generate_main(data_dir, [
- '--sampling',
- '--sampling-topk', '3',
- '--beam', '2',
- '--nbest', '2',
- ])
- generate_main(data_dir, [
- '--sampling',
- '--sampling-topp', '0.2',
- '--beam', '2',
- '--nbest', '2',
- ])
- generate_main(data_dir, [
- '--diversity-rate', '0.5',
- '--beam', '6',
- ])
+ train_translation_model(data_dir, "fconv_iwslt_de_en")
+ generate_main(
+ data_dir,
+ [
+ "--sampling",
+ "--temperature",
+ "2",
+ "--beam",
+ "2",
+ "--nbest",
+ "2",
+ ],
+ )
+ generate_main(
+ data_dir,
+ [
+ "--sampling",
+ "--sampling-topk",
+ "3",
+ "--beam",
+ "2",
+ "--nbest",
+ "2",
+ ],
+ )
+ generate_main(
+ data_dir,
+ [
+ "--sampling",
+ "--sampling-topp",
+ "0.2",
+ "--beam",
+ "2",
+ "--nbest",
+ "2",
+ ],
+ )
+ generate_main(
+ data_dir,
+ [
+ "--diversity-rate",
+ "0.5",
+ "--beam",
+ "6",
+ ],
+ )
with self.assertRaises(ValueError):
- generate_main(data_dir, [
- '--diverse-beam-groups', '4',
- '--match-source-len',
- ])
- generate_main(data_dir, ['--prefix-size', '2'])
- generate_main(data_dir, ['--retain-dropout'])
+ generate_main(
+ data_dir,
+ [
+ "--diverse-beam-groups",
+ "4",
+ "--match-source-len",
+ ],
+ )
+ generate_main(data_dir, ["--prefix-size", "2"])
+ generate_main(data_dir, ["--retain-dropout"])
def test_eval_bleu(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_eval_bleu') as data_dir:
+ with tempfile.TemporaryDirectory("test_eval_bleu") as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
- train_translation_model(data_dir, 'fconv_iwslt_de_en', [
- '--eval-bleu',
- '--eval-bleu-print-samples',
- '--eval-bleu-remove-bpe',
- '--eval-bleu-detok', 'space',
- '--eval-bleu-args', '{"beam": 4, "min_len": 10}',
- ])
+ train_translation_model(
+ data_dir,
+ "fconv_iwslt_de_en",
+ [
+ "--eval-bleu",
+ "--eval-bleu-print-samples",
+ "--eval-bleu-remove-bpe",
+ "--eval-bleu-detok",
+ "space",
+ "--eval-bleu-args",
+ '{"beam": 4, "min_len": 10}',
+ ],
+ )
def test_lstm(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_lstm') as data_dir:
+ with tempfile.TemporaryDirectory("test_lstm") as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
- train_translation_model(data_dir, 'lstm_wiseman_iwslt_de_en', [
- '--encoder-layers', '2',
- '--decoder-layers', '2',
- '--encoder-embed-dim', '8',
- '--decoder-embed-dim', '8',
- '--decoder-out-embed-dim', '8',
- ])
+ train_translation_model(
+ data_dir,
+ "lstm_wiseman_iwslt_de_en",
+ [
+ "--encoder-layers",
+ "2",
+ "--decoder-layers",
+ "2",
+ "--encoder-embed-dim",
+ "8",
+ "--decoder-embed-dim",
+ "8",
+ "--decoder-out-embed-dim",
+ "8",
+ ],
+ )
generate_main(data_dir)
def test_lstm_bidirectional(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_lstm_bidirectional') as data_dir:
+ with tempfile.TemporaryDirectory("test_lstm_bidirectional") as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
- train_translation_model(data_dir, 'lstm', [
- '--encoder-layers', '2',
- '--encoder-bidirectional',
- '--encoder-hidden-size', '16',
- '--encoder-embed-dim', '8',
- '--decoder-embed-dim', '8',
- '--decoder-out-embed-dim', '8',
- '--decoder-layers', '2',
- ])
+ train_translation_model(
+ data_dir,
+ "lstm",
+ [
+ "--encoder-layers",
+ "2",
+ "--encoder-bidirectional",
+ "--encoder-hidden-size",
+ "16",
+ "--encoder-embed-dim",
+ "8",
+ "--decoder-embed-dim",
+ "8",
+ "--decoder-out-embed-dim",
+ "8",
+ "--decoder-layers",
+ "2",
+ ],
+ )
generate_main(data_dir)
def test_transformer(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_transformer') as data_dir:
+ with tempfile.TemporaryDirectory("test_transformer") as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
- train_translation_model(data_dir, 'transformer_iwslt_de_en', [
- '--encoder-layers', '2',
- '--decoder-layers', '2',
- '--encoder-embed-dim', '8',
- '--decoder-embed-dim', '8',
- ], run_validation=True)
+ train_translation_model(
+ data_dir,
+ "transformer_iwslt_de_en",
+ [
+ "--encoder-layers",
+ "2",
+ "--decoder-layers",
+ "2",
+ "--encoder-embed-dim",
+ "8",
+ "--decoder-embed-dim",
+ "8",
+ ],
+ run_validation=True,
+ )
generate_main(data_dir)
def test_multilingual_transformer(self):
# test with all combinations of encoder/decoder lang tokens
- encoder_langtok_flags = [[], ['--encoder-langtok', 'src'], ['--encoder-langtok', 'tgt']]
- decoder_langtok_flags = [[], ['--decoder-langtok']]
+ encoder_langtok_flags = [
+ [],
+ ["--encoder-langtok", "src"],
+ ["--encoder-langtok", "tgt"],
+ ]
+ decoder_langtok_flags = [[], ["--decoder-langtok"]]
with contextlib.redirect_stdout(StringIO()):
for i in range(len(encoder_langtok_flags)):
for j in range(len(decoder_langtok_flags)):
enc_ltok_flag = encoder_langtok_flags[i]
dec_ltok_flag = decoder_langtok_flags[j]
- with tempfile.TemporaryDirectory(f'test_multilingual_transformer_{i}_{j}') as data_dir:
+ with tempfile.TemporaryDirectory(
+ f"test_multilingual_transformer_{i}_{j}"
+ ) as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(
data_dir,
- arch='multilingual_transformer',
- task='multilingual_translation',
+ arch="multilingual_transformer",
+ task="multilingual_translation",
extra_flags=[
- '--encoder-layers', '2',
- '--decoder-layers', '2',
- '--encoder-embed-dim', '8',
- '--decoder-embed-dim', '8',
- ] + enc_ltok_flag + dec_ltok_flag,
- lang_flags=['--lang-pairs', 'in-out,out-in'],
+ "--encoder-layers",
+ "2",
+ "--decoder-layers",
+ "2",
+ "--encoder-embed-dim",
+ "8",
+ "--decoder-embed-dim",
+ "8",
+ ]
+ + enc_ltok_flag
+ + dec_ltok_flag,
+ lang_flags=["--lang-pairs", "in-out,out-in"],
run_validation=True,
extra_valid_flags=enc_ltok_flag + dec_ltok_flag,
)
generate_main(
data_dir,
extra_flags=[
- '--task', 'multilingual_translation',
- '--lang-pairs', 'in-out,out-in',
- '--source-lang', 'in',
- '--target-lang', 'out',
- ] + enc_ltok_flag + dec_ltok_flag,
+ "--task",
+ "multilingual_translation",
+ "--lang-pairs",
+ "in-out,out-in",
+ "--source-lang",
+ "in",
+ "--target-lang",
+ "out",
+ ]
+ + enc_ltok_flag
+ + dec_ltok_flag,
)
def test_multilingual_translation_latent_depth(self):
# test with latent depth in encoder, decoder, or both
- encoder_latent_layer = [[], ['--encoder-latent-layer']]
- decoder_latent_layer = [[], ['--decoder-latent-layer']]
+ encoder_latent_layer = [[], ["--encoder-latent-layer"]]
+ decoder_latent_layer = [[], ["--decoder-latent-layer"]]
with contextlib.redirect_stdout(StringIO()):
for i in range(len(encoder_latent_layer)):
for j in range(len(decoder_latent_layer)):
@@ -218,186 +305,298 @@ def test_multilingual_translation_latent_depth(self):
continue
enc_ll_flag = encoder_latent_layer[i]
dec_ll_flag = decoder_latent_layer[j]
- with tempfile.TemporaryDirectory(f'test_multilingual_translation_latent_depth_{i}_{j}') as data_dir:
+ with tempfile.TemporaryDirectory(
+ f"test_multilingual_translation_latent_depth_{i}_{j}"
+ ) as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(
- data_dir,
- extra_flags=['--joined-dictionary']
+ data_dir, extra_flags=["--joined-dictionary"]
)
train_translation_model(
data_dir,
- arch='latent_multilingual_transformer',
- task='multilingual_translation_latent_depth',
+ arch="latent_multilingual_transformer",
+ task="multilingual_translation_latent_depth",
extra_flags=[
- '--user-dir', 'examples/latent_depth/src',
- '--encoder-layers', '2',
- '--decoder-layers', '2',
- '--encoder-embed-dim', '8',
- '--decoder-embed-dim', '8',
- '--share-encoders',
- '--share-decoders',
- '--sparsity-weight', '0.1',
- ] + enc_ll_flag + dec_ll_flag,
- lang_flags=['--lang-pairs', 'in-out,out-in'],
+ "--user-dir",
+ "examples/latent_depth/src",
+ "--encoder-layers",
+ "2",
+ "--decoder-layers",
+ "2",
+ "--encoder-embed-dim",
+ "8",
+ "--decoder-embed-dim",
+ "8",
+ "--share-encoders",
+ "--share-decoders",
+ "--sparsity-weight",
+ "0.1",
+ ]
+ + enc_ll_flag
+ + dec_ll_flag,
+ lang_flags=["--lang-pairs", "in-out,out-in"],
run_validation=True,
- extra_valid_flags=['--user-dir', 'examples/latent_depth/src'] + enc_ll_flag + dec_ll_flag,
+ extra_valid_flags=[
+ "--user-dir",
+ "examples/latent_depth/src",
+ ]
+ + enc_ll_flag
+ + dec_ll_flag,
)
generate_main(
data_dir,
extra_flags=[
- '--user-dir', 'examples/latent_depth/src',
- '--task', 'multilingual_translation_latent_depth',
- '--lang-pairs', 'in-out,out-in',
- '--source-lang', 'in',
- '--target-lang', 'out',
- ] + enc_ll_flag + dec_ll_flag,
+ "--user-dir",
+ "examples/latent_depth/src",
+ "--task",
+ "multilingual_translation_latent_depth",
+ "--lang-pairs",
+ "in-out,out-in",
+ "--source-lang",
+ "in",
+ "--target-lang",
+ "out",
+ ]
+ + enc_ll_flag
+ + dec_ll_flag,
)
def test_translation_multi_simple_epoch(self):
# test with all combinations of encoder/decoder lang tokens
- encoder_langtok_flags = [[], ['--encoder-langtok', 'src'], ['--encoder-langtok', 'tgt']]
- decoder_langtok_flags = [[], ['--decoder-langtok']]
+ encoder_langtok_flags = [
+ [],
+ ["--encoder-langtok", "src"],
+ ["--encoder-langtok", "tgt"],
+ ]
+ decoder_langtok_flags = [[], ["--decoder-langtok"]]
with contextlib.redirect_stdout(StringIO()):
for i in range(len(encoder_langtok_flags)):
for j in range(len(decoder_langtok_flags)):
enc_ltok_flag = encoder_langtok_flags[i]
dec_ltok_flag = decoder_langtok_flags[j]
- with tempfile.TemporaryDirectory(f'test_translation_multi_simple_epoch_{i}_{j}') as data_dir:
+ with tempfile.TemporaryDirectory(
+ f"test_translation_multi_simple_epoch_{i}_{j}"
+ ) as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(
- data_dir,
- extra_flags=['--joined-dictionary']
+ data_dir, extra_flags=["--joined-dictionary"]
)
train_translation_model(
data_dir,
- arch='transformer',
- task='translation_multi_simple_epoch',
+ arch="transformer",
+ task="translation_multi_simple_epoch",
extra_flags=[
- '--encoder-layers', '2',
- '--decoder-layers', '2',
- '--encoder-embed-dim', '8',
- '--decoder-embed-dim', '8',
- '--sampling-method', 'temperature',
- '--sampling-temperature', '1.5',
- '--virtual-epoch-size', '1000',
- ] + enc_ltok_flag + dec_ltok_flag,
- lang_flags=['--lang-pairs', 'in-out,out-in'],
+ "--encoder-layers",
+ "2",
+ "--decoder-layers",
+ "2",
+ "--encoder-embed-dim",
+ "8",
+ "--decoder-embed-dim",
+ "8",
+ "--sampling-method",
+ "temperature",
+ "--sampling-temperature",
+ "1.5",
+ "--virtual-epoch-size",
+ "1000",
+ ]
+ + enc_ltok_flag
+ + dec_ltok_flag,
+ lang_flags=["--lang-pairs", "in-out,out-in"],
run_validation=True,
extra_valid_flags=enc_ltok_flag + dec_ltok_flag,
)
generate_main(
data_dir,
extra_flags=[
- '--task', 'translation_multi_simple_epoch',
- '--lang-pairs', 'in-out,out-in',
- '--source-lang', 'in',
- '--target-lang', 'out',
- ] + enc_ltok_flag + dec_ltok_flag,
+ "--task",
+ "translation_multi_simple_epoch",
+ "--lang-pairs",
+ "in-out,out-in",
+ "--source-lang",
+ "in",
+ "--target-lang",
+ "out",
+ ]
+ + enc_ltok_flag
+ + dec_ltok_flag,
)
def test_transformer_cross_self_attention(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_transformer_cross_self_attention') as data_dir:
+ with tempfile.TemporaryDirectory(
+ "test_transformer_cross_self_attention"
+ ) as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
- train_translation_model(data_dir, 'transformer_iwslt_de_en', [
- '--encoder-layers', '2',
- '--decoder-layers', '2',
- '--encoder-embed-dim', '8',
- '--decoder-embed-dim', '8',
- '--decoder-embed-dim', '8',
- '--no-cross-attention',
- '--cross-self-attention',
- ], run_validation=True)
+ train_translation_model(
+ data_dir,
+ "transformer_iwslt_de_en",
+ [
+ "--encoder-layers",
+ "2",
+ "--decoder-layers",
+ "2",
+ "--encoder-embed-dim",
+ "8",
+ "--decoder-embed-dim",
+ "8",
+ "--decoder-embed-dim",
+ "8",
+ "--no-cross-attention",
+ "--cross-self-attention",
+ ],
+ run_validation=True,
+ )
generate_main(data_dir, extra_flags=[])
def test_transformer_pointer_generator(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_transformer_pointer_generator') as data_dir:
+ with tempfile.TemporaryDirectory(
+ "test_transformer_pointer_generator"
+ ) as data_dir:
create_dummy_data(data_dir)
preprocess_summarization_data(data_dir)
train_translation_model(
data_dir,
- 'transformer_pointer_generator',
+ "transformer_pointer_generator",
extra_flags=[
- '--user-dir', 'examples/pointer_generator/src',
- '--encoder-layers', '2',
- '--decoder-layers', '2',
- '--encoder-embed-dim', '8',
- '--decoder-embed-dim', '8',
- '--alignment-layer', '-1',
- '--alignment-heads', '1',
- '--source-position-markers', '0',
+ "--user-dir",
+ "examples/pointer_generator/src",
+ "--encoder-layers",
+ "2",
+ "--decoder-layers",
+ "2",
+ "--encoder-embed-dim",
+ "8",
+ "--decoder-embed-dim",
+ "8",
+ "--alignment-layer",
+ "-1",
+ "--alignment-heads",
+ "1",
+ "--source-position-markers",
+ "0",
],
run_validation=True,
- extra_valid_flags=['--user-dir', 'examples/pointer_generator/src'],
+ extra_valid_flags=["--user-dir", "examples/pointer_generator/src"],
)
generate_main(
data_dir,
- extra_flags=['--user-dir', 'examples/pointer_generator/src'],
+ extra_flags=["--user-dir", "examples/pointer_generator/src"],
)
def test_lightconv(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_lightconv') as data_dir:
+ with tempfile.TemporaryDirectory("test_lightconv") as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
- train_translation_model(data_dir, 'lightconv_iwslt_de_en', [
- '--encoder-conv-type', 'lightweight',
- '--decoder-conv-type', 'lightweight',
- '--encoder-embed-dim', '8',
- '--decoder-embed-dim', '8',
- ])
+ train_translation_model(
+ data_dir,
+ "lightconv_iwslt_de_en",
+ [
+ "--encoder-conv-type",
+ "lightweight",
+ "--decoder-conv-type",
+ "lightweight",
+ "--encoder-embed-dim",
+ "8",
+ "--decoder-embed-dim",
+ "8",
+ ],
+ )
generate_main(data_dir)
def test_dynamicconv(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_dynamicconv') as data_dir:
+ with tempfile.TemporaryDirectory("test_dynamicconv") as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
- train_translation_model(data_dir, 'lightconv_iwslt_de_en', [
- '--encoder-conv-type', 'dynamic',
- '--decoder-conv-type', 'dynamic',
- '--encoder-embed-dim', '8',
- '--decoder-embed-dim', '8',
- ])
+ train_translation_model(
+ data_dir,
+ "lightconv_iwslt_de_en",
+ [
+ "--encoder-conv-type",
+ "dynamic",
+ "--decoder-conv-type",
+ "dynamic",
+ "--encoder-embed-dim",
+ "8",
+ "--decoder-embed-dim",
+ "8",
+ ],
+ )
generate_main(data_dir)
def test_cmlm_transformer(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_cmlm_transformer') as data_dir:
+ with tempfile.TemporaryDirectory("test_cmlm_transformer") as data_dir:
create_dummy_data(data_dir)
- preprocess_translation_data(data_dir, ['--joined-dictionary'])
- train_translation_model(data_dir, 'cmlm_transformer', [
- '--apply-bert-init',
- '--criterion', 'nat_loss',
- '--noise', 'full_mask',
- '--pred-length-offset',
- '--length-loss-factor', '0.1'
- ], task='translation_lev')
- generate_main(data_dir, [
- '--task', 'translation_lev',
- '--iter-decode-max-iter', '9',
- '--iter-decode-eos-penalty', '0',
- '--print-step',
- ])
+ preprocess_translation_data(data_dir, ["--joined-dictionary"])
+ train_translation_model(
+ data_dir,
+ "cmlm_transformer",
+ [
+ "--apply-bert-init",
+ "--criterion",
+ "nat_loss",
+ "--noise",
+ "full_mask",
+ "--pred-length-offset",
+ "--length-loss-factor",
+ "0.1",
+ ],
+ task="translation_lev",
+ )
+ generate_main(
+ data_dir,
+ [
+ "--task",
+ "translation_lev",
+ "--iter-decode-max-iter",
+ "9",
+ "--iter-decode-eos-penalty",
+ "0",
+ "--print-step",
+ ],
+ )
def test_nonautoregressive_transformer(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_nonautoregressive_transformer') as data_dir:
+ with tempfile.TemporaryDirectory(
+ "test_nonautoregressive_transformer"
+ ) as data_dir:
create_dummy_data(data_dir)
- preprocess_translation_data(data_dir, ['--joined-dictionary'])
- train_translation_model(data_dir, 'nonautoregressive_transformer', [
- '--apply-bert-init', '--src-embedding-copy', '--criterion',
- 'nat_loss', '--noise', 'full_mask', '--pred-length-offset',
- '--length-loss-factor', '0.1'
- ], task='translation_lev')
- generate_main(data_dir, [
- '--task', 'translation_lev',
- '--iter-decode-max-iter', '0',
- '--iter-decode-eos-penalty', '0',
- '--print-step',
- ])
+ preprocess_translation_data(data_dir, ["--joined-dictionary"])
+ train_translation_model(
+ data_dir,
+ "nonautoregressive_transformer",
+ [
+ "--apply-bert-init",
+ "--src-embedding-copy",
+ "--criterion",
+ "nat_loss",
+ "--noise",
+ "full_mask",
+ "--pred-length-offset",
+ "--length-loss-factor",
+ "0.1",
+ ],
+ task="translation_lev",
+ )
+ generate_main(
+ data_dir,
+ [
+ "--task",
+ "translation_lev",
+ "--iter-decode-max-iter",
+ "0",
+ "--iter-decode-eos-penalty",
+ "0",
+ "--print-step",
+ ],
+ )
# def test_nat_crf_transformer(self):
# with contextlib.redirect_stdout(StringIO()):
@@ -421,78 +620,139 @@ def test_nonautoregressive_transformer(self):
def test_iterative_nonautoregressive_transformer(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_iterative_nonautoregressive_transformer') as data_dir:
+ with tempfile.TemporaryDirectory(
+ "test_iterative_nonautoregressive_transformer"
+ ) as data_dir:
create_dummy_data(data_dir)
- preprocess_translation_data(data_dir, ['--joined-dictionary'])
- train_translation_model(data_dir, 'iterative_nonautoregressive_transformer', [
- '--apply-bert-init', '--src-embedding-copy', '--criterion',
- 'nat_loss', '--noise', 'full_mask', '--stochastic-approx',
- '--dae-ratio', '0.5', '--train-step', '3'
- ], task='translation_lev')
- generate_main(data_dir, [
- '--task', 'translation_lev',
- '--iter-decode-max-iter', '9',
- '--iter-decode-eos-penalty', '0',
- '--print-step',
- ])
+ preprocess_translation_data(data_dir, ["--joined-dictionary"])
+ train_translation_model(
+ data_dir,
+ "iterative_nonautoregressive_transformer",
+ [
+ "--apply-bert-init",
+ "--src-embedding-copy",
+ "--criterion",
+ "nat_loss",
+ "--noise",
+ "full_mask",
+ "--stochastic-approx",
+ "--dae-ratio",
+ "0.5",
+ "--train-step",
+ "3",
+ ],
+ task="translation_lev",
+ )
+ generate_main(
+ data_dir,
+ [
+ "--task",
+ "translation_lev",
+ "--iter-decode-max-iter",
+ "9",
+ "--iter-decode-eos-penalty",
+ "0",
+ "--print-step",
+ ],
+ )
def test_insertion_transformer(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_insertion_transformer') as data_dir:
+ with tempfile.TemporaryDirectory("test_insertion_transformer") as data_dir:
create_dummy_data(data_dir)
- preprocess_translation_data(data_dir, ['--joined-dictionary'])
- train_translation_model(data_dir, 'insertion_transformer', [
- '--apply-bert-init', '--criterion', 'nat_loss', '--noise',
- 'random_mask'
- ], task='translation_lev')
- generate_main(data_dir, [
- '--task', 'translation_lev',
- '--iter-decode-max-iter', '9',
- '--iter-decode-eos-penalty', '0',
- '--print-step',
- ])
+ preprocess_translation_data(data_dir, ["--joined-dictionary"])
+ train_translation_model(
+ data_dir,
+ "insertion_transformer",
+ [
+ "--apply-bert-init",
+ "--criterion",
+ "nat_loss",
+ "--noise",
+ "random_mask",
+ ],
+ task="translation_lev",
+ )
+ generate_main(
+ data_dir,
+ [
+ "--task",
+ "translation_lev",
+ "--iter-decode-max-iter",
+ "9",
+ "--iter-decode-eos-penalty",
+ "0",
+ "--print-step",
+ ],
+ )
def test_mixture_of_experts(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_moe') as data_dir:
+ with tempfile.TemporaryDirectory("test_moe") as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
- train_translation_model(data_dir, 'transformer_iwslt_de_en', [
- '--task', 'translation_moe',
- '--user-dir', 'examples/translation_moe/src',
- '--method', 'hMoElp',
- '--mean-pool-gating-network',
- '--num-experts', '3',
- '--encoder-layers', '2',
- '--decoder-layers', '2',
- '--encoder-embed-dim', '8',
- '--decoder-embed-dim', '8',
- ])
- generate_main(data_dir, [
- '--task', 'translation_moe',
- '--user-dir', 'examples/translation_moe/src',
- '--method', 'hMoElp',
- '--mean-pool-gating-network',
- '--num-experts', '3',
- '--gen-expert', '0'
- ])
+ train_translation_model(
+ data_dir,
+ "transformer_iwslt_de_en",
+ [
+ "--task",
+ "translation_moe",
+ "--user-dir",
+ "examples/translation_moe/src",
+ "--method",
+ "hMoElp",
+ "--mean-pool-gating-network",
+ "--num-experts",
+ "3",
+ "--encoder-layers",
+ "2",
+ "--decoder-layers",
+ "2",
+ "--encoder-embed-dim",
+ "8",
+ "--decoder-embed-dim",
+ "8",
+ ],
+ )
+ generate_main(
+ data_dir,
+ [
+ "--task",
+ "translation_moe",
+ "--user-dir",
+ "examples/translation_moe/src",
+ "--method",
+ "hMoElp",
+ "--mean-pool-gating-network",
+ "--num-experts",
+ "3",
+ "--gen-expert",
+ "0",
+ ],
+ )
def test_alignment(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_alignment') as data_dir:
+ with tempfile.TemporaryDirectory("test_alignment") as data_dir:
create_dummy_data(data_dir, alignment=True)
- preprocess_translation_data(data_dir, ['--align-suffix', 'align'])
+ preprocess_translation_data(data_dir, ["--align-suffix", "align"])
train_translation_model(
data_dir,
- 'transformer_align',
+ "transformer_align",
[
- '--encoder-layers', '2',
- '--decoder-layers', '2',
- '--encoder-embed-dim', '8',
- '--decoder-embed-dim', '8',
- '--load-alignments',
- '--alignment-layer', '1',
- '--criterion', 'label_smoothed_cross_entropy_with_alignment',
+ "--encoder-layers",
+ "2",
+ "--decoder-layers",
+ "2",
+ "--encoder-embed-dim",
+ "8",
+ "--decoder-embed-dim",
+ "8",
+ "--load-alignments",
+ "--alignment-layer",
+ "1",
+ "--criterion",
+ "label_smoothed_cross_entropy_with_alignment",
],
run_validation=True,
)
@@ -500,21 +760,27 @@ def test_alignment(self):
def test_alignment_full_context(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_alignment') as data_dir:
+ with tempfile.TemporaryDirectory("test_alignment") as data_dir:
create_dummy_data(data_dir, alignment=True)
- preprocess_translation_data(data_dir, ['--align-suffix', 'align'])
+ preprocess_translation_data(data_dir, ["--align-suffix", "align"])
train_translation_model(
data_dir,
- 'transformer_align',
+ "transformer_align",
[
- '--encoder-layers', '2',
- '--decoder-layers', '2',
- '--encoder-embed-dim', '8',
- '--decoder-embed-dim', '8',
- '--load-alignments',
- '--alignment-layer', '1',
- '--criterion', 'label_smoothed_cross_entropy_with_alignment',
- '--full-context-alignment',
+ "--encoder-layers",
+ "2",
+ "--decoder-layers",
+ "2",
+ "--encoder-embed-dim",
+ "8",
+ "--decoder-embed-dim",
+ "8",
+ "--load-alignments",
+ "--alignment-layer",
+ "1",
+ "--criterion",
+ "label_smoothed_cross_entropy_with_alignment",
+ "--full-context-alignment",
],
run_validation=True,
)
@@ -522,7 +788,6 @@ def test_alignment_full_context(self):
class TestStories(unittest.TestCase):
-
def setUp(self):
logging.disable(logging.CRITICAL)
@@ -531,37 +796,55 @@ def tearDown(self):
def test_fconv_self_att_wp(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_fconv_self_att_wp') as data_dir:
+ with tempfile.TemporaryDirectory("test_fconv_self_att_wp") as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
config = [
- '--encoder-layers', '[(128, 3)] * 2',
- '--decoder-layers', '[(128, 3)] * 2',
- '--decoder-attention', 'True',
- '--encoder-attention', 'False',
- '--gated-attention', 'True',
- '--self-attention', 'True',
- '--project-input', 'True',
- '--encoder-embed-dim', '8',
- '--decoder-embed-dim', '8',
- '--decoder-out-embed-dim', '8',
- '--multihead-self-attention-nheads', '2'
+ "--encoder-layers",
+ "[(128, 3)] * 2",
+ "--decoder-layers",
+ "[(128, 3)] * 2",
+ "--decoder-attention",
+ "True",
+ "--encoder-attention",
+ "False",
+ "--gated-attention",
+ "True",
+ "--self-attention",
+ "True",
+ "--project-input",
+ "True",
+ "--encoder-embed-dim",
+ "8",
+ "--decoder-embed-dim",
+ "8",
+ "--decoder-out-embed-dim",
+ "8",
+ "--multihead-self-attention-nheads",
+ "2",
]
- train_translation_model(data_dir, 'fconv_self_att_wp', config)
+ train_translation_model(data_dir, "fconv_self_att_wp", config)
generate_main(data_dir)
# fusion model
- os.rename(os.path.join(data_dir, 'checkpoint_last.pt'), os.path.join(data_dir, 'pretrained.pt'))
- config.extend([
- '--pretrained', 'True',
- '--pretrained-checkpoint', os.path.join(data_dir, 'pretrained.pt'),
- '--save-dir', os.path.join(data_dir, 'fusion_model'),
- ])
- train_translation_model(data_dir, 'fconv_self_att_wp', config)
+ os.rename(
+ os.path.join(data_dir, "checkpoint_last.pt"),
+ os.path.join(data_dir, "pretrained.pt"),
+ )
+ config.extend(
+ [
+ "--pretrained",
+ "True",
+ "--pretrained-checkpoint",
+ os.path.join(data_dir, "pretrained.pt"),
+ "--save-dir",
+ os.path.join(data_dir, "fusion_model"),
+ ]
+ )
+ train_translation_model(data_dir, "fconv_self_att_wp", config)
class TestLanguageModeling(unittest.TestCase):
-
def setUp(self):
logging.disable(logging.CRITICAL)
@@ -570,84 +853,134 @@ def tearDown(self):
def test_fconv_lm(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_fconv_lm') as data_dir:
+ with tempfile.TemporaryDirectory("test_fconv_lm") as data_dir:
create_dummy_data(data_dir)
preprocess_lm_data(data_dir)
- train_language_model(data_dir, 'fconv_lm', [
- '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
- '--decoder-embed-dim', '280',
- '--optimizer', 'nag',
- '--lr', '0.1',
- ])
+ train_language_model(
+ data_dir,
+ "fconv_lm",
+ [
+ "--decoder-layers",
+ "[(850, 3)] * 2 + [(1024,4)]",
+ "--decoder-embed-dim",
+ "280",
+ "--optimizer",
+ "nag",
+ "--lr",
+ "0.1",
+ ],
+ )
eval_lm_main(data_dir)
- generate_main(data_dir, [
- '--task', 'language_modeling',
- '--sample-break-mode', 'eos',
- '--tokens-per-sample', '500',
- ])
+ generate_main(
+ data_dir,
+ [
+ "--task",
+ "language_modeling",
+ "--sample-break-mode",
+ "eos",
+ "--tokens-per-sample",
+ "500",
+ ],
+ )
def test_transformer_lm(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_transformer_lm') as data_dir:
+ with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
create_dummy_data(data_dir)
preprocess_lm_data(data_dir)
train_language_model(
- data_dir, 'transformer_lm', ['--add-bos-token'], run_validation=True,
+ data_dir,
+ "transformer_lm",
+ ["--add-bos-token"],
+ run_validation=True,
)
eval_lm_main(data_dir)
- generate_main(data_dir, [
- '--task', 'language_modeling',
- '--sample-break-mode', 'eos',
- '--tokens-per-sample', '500',
- ])
+ generate_main(
+ data_dir,
+ [
+ "--task",
+ "language_modeling",
+ "--sample-break-mode",
+ "eos",
+ "--tokens-per-sample",
+ "500",
+ ],
+ )
def test_lightconv_lm(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_lightconv_lm') as data_dir:
+ with tempfile.TemporaryDirectory("test_lightconv_lm") as data_dir:
create_dummy_data(data_dir)
preprocess_lm_data(data_dir)
train_language_model(
- data_dir, 'lightconv_lm', ['--add-bos-token'], run_validation=True,
+ data_dir,
+ "lightconv_lm",
+ ["--add-bos-token"],
+ run_validation=True,
)
eval_lm_main(data_dir)
- generate_main(data_dir, [
- '--task', 'language_modeling',
- '--sample-break-mode', 'eos',
- '--tokens-per-sample', '500',
- ])
+ generate_main(
+ data_dir,
+ [
+ "--task",
+ "language_modeling",
+ "--sample-break-mode",
+ "eos",
+ "--tokens-per-sample",
+ "500",
+ ],
+ )
def test_lstm_lm(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_lstm_lm') as data_dir:
+ with tempfile.TemporaryDirectory("test_lstm_lm") as data_dir:
create_dummy_data(data_dir)
preprocess_lm_data(data_dir)
train_language_model(
- data_dir, 'lstm_lm', ['--add-bos-token'], run_validation=True,
+ data_dir,
+ "lstm_lm",
+ ["--add-bos-token"],
+ run_validation=True,
)
eval_lm_main(data_dir)
- generate_main(data_dir, [
- '--task', 'language_modeling',
- '--sample-break-mode', 'eos',
- '--tokens-per-sample', '500',
- ])
+ generate_main(
+ data_dir,
+ [
+ "--task",
+ "language_modeling",
+ "--sample-break-mode",
+ "eos",
+ "--tokens-per-sample",
+ "500",
+ ],
+ )
def test_lstm_lm_residuals(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_lstm_lm_residuals') as data_dir:
+ with tempfile.TemporaryDirectory("test_lstm_lm_residuals") as data_dir:
create_dummy_data(data_dir)
preprocess_lm_data(data_dir)
train_language_model(
- data_dir, 'lstm_lm', ['--add-bos-token', '--residuals'], run_validation=True,
+ data_dir,
+ "lstm_lm",
+ ["--add-bos-token", "--residuals"],
+ run_validation=True,
)
eval_lm_main(data_dir)
- generate_main(data_dir, [
- '--task', 'language_modeling',
- '--sample-break-mode', 'eos',
- '--tokens-per-sample', '500',
- ])
+ generate_main(
+ data_dir,
+ [
+ "--task",
+ "language_modeling",
+ "--sample-break-mode",
+ "eos",
+ "--tokens-per-sample",
+ "500",
+ ],
+ )
-class TestMaskedLanguageModel(unittest.TestCase):
+class TestMaskedLanguageModel(unittest.TestCase):
def setUp(self):
logging.disable(logging.CRITICAL)
@@ -666,32 +999,52 @@ def test_roberta_masked_lm(self):
with tempfile.TemporaryDirectory("test_roberta_mlm") as data_dir:
create_dummy_data(data_dir)
preprocess_lm_data(data_dir)
- train_masked_lm(data_dir, "roberta_base", extra_flags=["--encoder-layers", "2"])
+ train_masked_lm(
+ data_dir, "roberta_base", extra_flags=["--encoder-layers", "2"]
+ )
def test_roberta_sentence_prediction(self):
num_classes = 3
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory("test_roberta_head") as data_dir:
create_dummy_roberta_head_data(data_dir, num_classes=num_classes)
- preprocess_lm_data(os.path.join(data_dir, 'input0'))
- preprocess_lm_data(os.path.join(data_dir, 'label'))
+ preprocess_lm_data(os.path.join(data_dir, "input0"))
+ preprocess_lm_data(os.path.join(data_dir, "label"))
train_roberta_head(data_dir, "roberta_base", num_classes=num_classes)
def test_roberta_regression_single(self):
num_classes = 1
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory("test_roberta_regression_single") as data_dir:
- create_dummy_roberta_head_data(data_dir, num_classes=num_classes, regression=True)
- preprocess_lm_data(os.path.join(data_dir, 'input0'))
- train_roberta_head(data_dir, "roberta_base", num_classes=num_classes, extra_flags=['--regression-target'])
+ with tempfile.TemporaryDirectory(
+ "test_roberta_regression_single"
+ ) as data_dir:
+ create_dummy_roberta_head_data(
+ data_dir, num_classes=num_classes, regression=True
+ )
+ preprocess_lm_data(os.path.join(data_dir, "input0"))
+ train_roberta_head(
+ data_dir,
+ "roberta_base",
+ num_classes=num_classes,
+ extra_flags=["--regression-target"],
+ )
def test_roberta_regression_multiple(self):
num_classes = 3
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory("test_roberta_regression_multiple") as data_dir:
- create_dummy_roberta_head_data(data_dir, num_classes=num_classes, regression=True)
- preprocess_lm_data(os.path.join(data_dir, 'input0'))
- train_roberta_head(data_dir, "roberta_base", num_classes=num_classes, extra_flags=['--regression-target'])
+ with tempfile.TemporaryDirectory(
+ "test_roberta_regression_multiple"
+ ) as data_dir:
+ create_dummy_roberta_head_data(
+ data_dir, num_classes=num_classes, regression=True
+ )
+ preprocess_lm_data(os.path.join(data_dir, "input0"))
+ train_roberta_head(
+ data_dir,
+ "roberta_base",
+ num_classes=num_classes,
+ extra_flags=["--regression-target"],
+ )
def test_linformer_roberta_masked_lm(self):
with contextlib.redirect_stdout(StringIO()):
@@ -702,8 +1055,10 @@ def test_linformer_roberta_masked_lm(self):
data_dir,
"linformer_roberta_base",
extra_flags=[
- "--user-dir", "examples/linformer/src",
- "--encoder-layers", "2",
+ "--user-dir",
+ "examples/linformer/src",
+ "--encoder-layers",
+ "2",
],
)
@@ -712,8 +1067,8 @@ def test_linformer_roberta_sentence_prediction(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory("test_linformer_roberta_head") as data_dir:
create_dummy_roberta_head_data(data_dir, num_classes=num_classes)
- preprocess_lm_data(os.path.join(data_dir, 'input0'))
- preprocess_lm_data(os.path.join(data_dir, 'label'))
+ preprocess_lm_data(os.path.join(data_dir, "input0"))
+ preprocess_lm_data(os.path.join(data_dir, "label"))
train_roberta_head(
data_dir,
"linformer_roberta_base",
@@ -724,27 +1079,43 @@ def test_linformer_roberta_sentence_prediction(self):
def test_linformer_roberta_regression_single(self):
num_classes = 1
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory("test_linformer_roberta_regression_single") as data_dir:
- create_dummy_roberta_head_data(data_dir, num_classes=num_classes, regression=True)
- preprocess_lm_data(os.path.join(data_dir, 'input0'))
+ with tempfile.TemporaryDirectory(
+ "test_linformer_roberta_regression_single"
+ ) as data_dir:
+ create_dummy_roberta_head_data(
+ data_dir, num_classes=num_classes, regression=True
+ )
+ preprocess_lm_data(os.path.join(data_dir, "input0"))
train_roberta_head(
data_dir,
"linformer_roberta_base",
num_classes=num_classes,
- extra_flags=["--regression-target", "--user-dir", "examples/linformer/src"],
+ extra_flags=[
+ "--regression-target",
+ "--user-dir",
+ "examples/linformer/src",
+ ],
)
def test_linformer_roberta_regression_multiple(self):
num_classes = 3
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory("test_linformer_roberta_regression_multiple") as data_dir:
- create_dummy_roberta_head_data(data_dir, num_classes=num_classes, regression=True)
- preprocess_lm_data(os.path.join(data_dir, 'input0'))
+ with tempfile.TemporaryDirectory(
+ "test_linformer_roberta_regression_multiple"
+ ) as data_dir:
+ create_dummy_roberta_head_data(
+ data_dir, num_classes=num_classes, regression=True
+ )
+ preprocess_lm_data(os.path.join(data_dir, "input0"))
train_roberta_head(
data_dir,
"linformer_roberta_base",
num_classes=num_classes,
- extra_flags=["--regression-target", "--user-dir", "examples/linformer/src"],
+ extra_flags=[
+ "--regression-target",
+ "--user-dir",
+ "examples/linformer/src",
+ ],
)
def _test_pretrained_masked_lm_for_translation(self, learned_pos_emb, encoder_only):
@@ -755,7 +1126,7 @@ def _test_pretrained_masked_lm_for_translation(self, learned_pos_emb, encoder_on
train_legacy_masked_language_model(
data_dir,
arch="masked_lm",
- extra_args=('--encoder-learned-pos',) if learned_pos_emb else ()
+ extra_args=("--encoder-learned-pos",) if learned_pos_emb else (),
)
with tempfile.TemporaryDirectory(
"test_mlm_translation"
@@ -793,10 +1164,13 @@ def _test_pretrained_masked_lm_for_translation(self, learned_pos_emb, encoder_on
"500",
"--max-target-positions",
"500",
- ] + (
+ ]
+ + (
["--encoder-learned-pos", "--decoder-learned-pos"]
- if learned_pos_emb else []
- ) + (['--init-encoder-only'] if encoder_only else []),
+ if learned_pos_emb
+ else []
+ )
+ + (["--init-encoder-only"] if encoder_only else []),
task="translation_from_pretrained_xlm",
)
@@ -814,8 +1188,8 @@ def test_r4f_roberta(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory("test_r4f_roberta_head") as data_dir:
create_dummy_roberta_head_data(data_dir, num_classes=num_classes)
- preprocess_lm_data(os.path.join(data_dir, 'input0'))
- preprocess_lm_data(os.path.join(data_dir, 'label'))
+ preprocess_lm_data(os.path.join(data_dir, "input0"))
+ preprocess_lm_data(os.path.join(data_dir, "label"))
train_roberta_head(
data_dir,
"roberta_base",
@@ -824,8 +1198,8 @@ def test_r4f_roberta(self):
"--user-dir",
"examples/rxf/src",
"--criterion",
- 'sentence_prediction_r3f',
- '--spectral-norm-classification-head',
+ "sentence_prediction_r3f",
+ "--spectral-norm-classification-head",
],
)
@@ -890,13 +1264,13 @@ def train_legacy_masked_language_model(data_dir, arch, extra_args=()):
"raw",
"--num-workers",
"0",
- ] + list(extra_args),
+ ]
+ + list(extra_args),
)
train.main(train_args)
class TestOptimizers(unittest.TestCase):
-
def setUp(self):
logging.disable(logging.CRITICAL)
@@ -905,27 +1279,39 @@ def tearDown(self):
def test_optimizers(self):
with contextlib.redirect_stdout(StringIO()):
- with tempfile.TemporaryDirectory('test_optimizers') as data_dir:
+ with tempfile.TemporaryDirectory("test_optimizers") as data_dir:
# Use just a bit of data and tiny model to keep this test runtime reasonable
create_dummy_data(data_dir, num_examples=10, maxlen=5)
preprocess_translation_data(data_dir)
- optimizers = ['adafactor', 'adam', 'nag', 'adagrad', 'sgd', 'adadelta']
- last_checkpoint = os.path.join(data_dir, 'checkpoint_last.pt')
+ optimizers = ["adafactor", "adam", "nag", "adagrad", "sgd", "adadelta"]
+ last_checkpoint = os.path.join(data_dir, "checkpoint_last.pt")
for optimizer in optimizers:
if os.path.exists(last_checkpoint):
os.remove(last_checkpoint)
- train_translation_model(data_dir, 'lstm', [
- '--required-batch-size-multiple', '1',
- '--encoder-layers', '1',
- '--encoder-hidden-size', '32',
- '--decoder-layers', '1',
- '--optimizer', optimizer,
- ])
+ train_translation_model(
+ data_dir,
+ "lstm",
+ [
+ "--required-batch-size-multiple",
+ "1",
+ "--encoder-layers",
+ "1",
+ "--encoder-hidden-size",
+ "32",
+ "--decoder-layers",
+ "1",
+ "--optimizer",
+ optimizer,
+ ],
+ )
generate_main(data_dir)
-def create_dummy_roberta_head_data(data_dir, num_examples=100, maxlen=10, num_classes=2, regression=False):
- input_dir = 'input0'
+def create_dummy_roberta_head_data(
+ data_dir, num_examples=100, maxlen=10, num_classes=2, regression=False
+):
+ input_dir = "input0"
+
def _create_dummy_data(filename):
random_data = torch.rand(num_examples * maxlen)
input_data = 97 + torch.floor(26 * random_data).int()
@@ -933,29 +1319,29 @@ def _create_dummy_data(filename):
output_data = torch.rand((num_examples, num_classes))
else:
output_data = 1 + torch.floor(num_classes * torch.rand(num_examples)).int()
- with open(os.path.join(data_dir, input_dir, filename+'.out'), 'w') as f_in:
- label_filename = filename+'.label' if regression else filename+'.out'
- with open(os.path.join(data_dir, 'label', label_filename), 'w') as f_out:
+ with open(os.path.join(data_dir, input_dir, filename + ".out"), "w") as f_in:
+ label_filename = filename + ".label" if regression else filename + ".out"
+ with open(os.path.join(data_dir, "label", label_filename), "w") as f_out:
offset = 0
for i in range(num_examples):
# write example input
ex_len = random.randint(1, maxlen)
- ex_str = ' '.join(map(chr, input_data[offset:offset+ex_len]))
+ ex_str = " ".join(map(chr, input_data[offset : offset + ex_len]))
print(ex_str, file=f_in)
# write example label
if regression:
- class_str = ' '.join(map(str, output_data[i].numpy()))
+ class_str = " ".join(map(str, output_data[i].numpy()))
print(class_str, file=f_out)
else:
- class_str = 'class{}'.format(output_data[i])
+ class_str = "class{}".format(output_data[i])
print(class_str, file=f_out)
offset += ex_len
os.mkdir(os.path.join(data_dir, input_dir))
- os.mkdir(os.path.join(data_dir, 'label'))
- _create_dummy_data('train')
- _create_dummy_data('valid')
- _create_dummy_data('test')
+ os.mkdir(os.path.join(data_dir, "label"))
+ _create_dummy_data("train")
+ _create_dummy_data("valid")
+ _create_dummy_data("test")
def train_masked_lm(data_dir, arch, extra_flags=None):
@@ -963,20 +1349,32 @@ def train_masked_lm(data_dir, arch, extra_flags=None):
train_args = options.parse_args_and_arch(
train_parser,
[
- '--task', 'masked_lm',
+ "--task",
+ "masked_lm",
+ data_dir,
+ "--arch",
+ arch,
+ "--optimizer",
+ "adam",
+ "--lr",
+ "0.0001",
+ "--criterion",
+ "masked_lm",
+ "--batch-size",
+ "500",
+ "--save-dir",
data_dir,
- '--arch', arch,
- '--optimizer', 'adam',
- '--lr', '0.0001',
- '--criterion', 'masked_lm',
- '--batch-size', '500',
- '--save-dir', data_dir,
- '--max-epoch', '1',
- '--no-progress-bar',
- '--distributed-world-size', '1',
- '--ddp-backend', 'no_c10d',
- '--num-workers', '0',
- ] + (extra_flags or []),
+ "--max-epoch",
+ "1",
+ "--no-progress-bar",
+ "--distributed-world-size",
+ "1",
+ "--ddp-backend",
+ "no_c10d",
+ "--num-workers",
+ "0",
+ ]
+ + (extra_flags or []),
)
train.main(train_args)
@@ -986,24 +1384,40 @@ def train_roberta_head(data_dir, arch, num_classes=2, extra_flags=None):
train_args = options.parse_args_and_arch(
train_parser,
[
- '--task', 'sentence_prediction',
+ "--task",
+ "sentence_prediction",
data_dir,
- '--arch', arch,
- '--encoder-layers', '2',
- '--num-classes', str(num_classes),
- '--optimizer', 'adam',
- '--lr', '0.0001',
- '--criterion', 'sentence_prediction',
- '--max-tokens', '500',
- '--max-positions', '500',
- '--batch-size', '500',
- '--save-dir', data_dir,
- '--max-epoch', '1',
- '--no-progress-bar',
- '--distributed-world-size', '1',
- '--ddp-backend', 'no_c10d',
- '--num-workers', '0',
- ] + (extra_flags or []),
+ "--arch",
+ arch,
+ "--encoder-layers",
+ "2",
+ "--num-classes",
+ str(num_classes),
+ "--optimizer",
+ "adam",
+ "--lr",
+ "0.0001",
+ "--criterion",
+ "sentence_prediction",
+ "--max-tokens",
+ "500",
+ "--max-positions",
+ "500",
+ "--batch-size",
+ "500",
+ "--save-dir",
+ data_dir,
+ "--max-epoch",
+ "1",
+ "--no-progress-bar",
+ "--distributed-world-size",
+ "1",
+ "--ddp-backend",
+ "no_c10d",
+ "--num-workers",
+ "0",
+ ]
+ + (extra_flags or []),
)
train.main(train_args)
@@ -1013,22 +1427,36 @@ def train_language_model(data_dir, arch, extra_flags=None, run_validation=False)
train_args = options.parse_args_and_arch(
train_parser,
[
- '--task', 'language_modeling',
+ "--task",
+ "language_modeling",
+ data_dir,
+ "--arch",
+ arch,
+ "--optimizer",
+ "adam",
+ "--lr",
+ "0.0001",
+ "--criterion",
+ "adaptive_loss",
+ "--adaptive-softmax-cutoff",
+ "5,10,15",
+ "--max-tokens",
+ "500",
+ "--tokens-per-sample",
+ "500",
+ "--save-dir",
data_dir,
- '--arch', arch,
- '--optimizer', 'adam',
- '--lr', '0.0001',
- '--criterion', 'adaptive_loss',
- '--adaptive-softmax-cutoff', '5,10,15',
- '--max-tokens', '500',
- '--tokens-per-sample', '500',
- '--save-dir', data_dir,
- '--max-epoch', '1',
- '--no-progress-bar',
- '--distributed-world-size', '1',
- '--ddp-backend', 'no_c10d',
- '--num-workers', '0',
- ] + (extra_flags or []),
+ "--max-epoch",
+ "1",
+ "--no-progress-bar",
+ "--distributed-world-size",
+ "1",
+ "--ddp-backend",
+ "no_c10d",
+ "--num-workers",
+ "0",
+ ]
+ + (extra_flags or []),
)
train.main(train_args)
@@ -1038,14 +1466,19 @@ def train_language_model(data_dir, arch, extra_flags=None, run_validation=False)
validate_args = options.parse_args_and_arch(
validate_parser,
[
- '--task', 'language_modeling',
+ "--task",
+ "language_modeling",
data_dir,
- '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
- '--valid-subset', 'valid',
- '--max-tokens', '500',
- '--no-progress-bar',
- '--num-workers', '0',
- ]
+ "--path",
+ os.path.join(data_dir, "checkpoint_last.pt"),
+ "--valid-subset",
+ "valid",
+ "--max-tokens",
+ "500",
+ "--no-progress-bar",
+ "--num-workers",
+ "0",
+ ],
)
validate.main(validate_args)
@@ -1056,9 +1489,11 @@ def eval_lm_main(data_dir):
eval_lm_parser,
[
data_dir,
- '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
- '--no-progress-bar',
- '--num-workers', '0',
+ "--path",
+ os.path.join(data_dir, "checkpoint_last.pt"),
+ "--no-progress-bar",
+ "--num-workers",
+ "0",
],
)
eval_lm.main(eval_lm_args)
@@ -1124,10 +1559,11 @@ def train_masked_language_model(data_dir, arch, extra_args=()):
"raw",
"--num-workers",
"0",
- ] + list(extra_args),
+ ]
+ + list(extra_args),
)
train.main(train_args)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_bmuf.py b/tests/test_bmuf.py
index 30563bdb50..0165b2955b 100644
--- a/tests/test_bmuf.py
+++ b/tests/test_bmuf.py
@@ -4,13 +4,12 @@
# LICENSE file in the root directory of this source tree.
import argparse
-from multiprocessing import Manager
import random
import unittest
+from multiprocessing import Manager
import torch
import torch.nn as nn
-
from fairseq import distributed_utils, optim
@@ -169,5 +168,5 @@ def assertAlmostEqual(self, t1, t2):
self.assertLess((t1 - t2).abs().max(), 1e-4)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_character_token_embedder.py b/tests/test_character_token_embedder.py
index 81042c2a3f..24940ebd21 100644
--- a/tests/test_character_token_embedder.py
+++ b/tests/test_character_token_embedder.py
@@ -3,9 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import torch
import unittest
+import torch
from fairseq.data import Dictionary
from fairseq.modules import CharacterTokenEmbedder
@@ -13,12 +13,14 @@
class TestCharacterTokenEmbedder(unittest.TestCase):
def test_character_token_embedder(self):
vocab = Dictionary()
- vocab.add_symbol('hello')
- vocab.add_symbol('there')
+ vocab.add_symbol("hello")
+ vocab.add_symbol("there")
- embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2)
+ embedder = CharacterTokenEmbedder(
+ vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2
+ )
- test_sents = [['hello', 'unk', 'there'], ['there'], ['hello', 'there']]
+ test_sents = [["hello", "unk", "there"], ["there"], ["hello", "there"]]
max_len = max(len(s) for s in test_sents)
input = torch.LongTensor(len(test_sents), max_len + 2).fill_(vocab.pad())
for i in range(len(test_sents)):
@@ -42,5 +44,5 @@ def assertAlmostEqual(self, t1, t2):
self.assertLess((t1 - t2).abs().max(), 1e-6)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_concat_dataset.py b/tests/test_concat_dataset.py
index dbdb2ac518..d94aeffd48 100644
--- a/tests/test_concat_dataset.py
+++ b/tests/test_concat_dataset.py
@@ -40,25 +40,19 @@ def setUp(self):
)
def test_concat_dataset_basics(self):
- d = ConcatDataset(
- [self.dataset_1, self.dataset_2]
- )
- assert(len(d) == 2)
- assert(d[0]['source'][0] == 1)
- assert(d[1]['source'][0] == 2)
+ d = ConcatDataset([self.dataset_1, self.dataset_2])
+ assert len(d) == 2
+ assert d[0]["source"][0] == 1
+ assert d[1]["source"][0] == 2
- d = ConcatDataset(
- [self.dataset_1, self.dataset_2], sample_ratios=[1, 2]
- )
- assert(len(d) == 3)
- assert(d[0]['source'][0] == 1)
- assert(d[1]['source'][0] == 2)
- assert(d[2]['source'][0] == 2)
+ d = ConcatDataset([self.dataset_1, self.dataset_2], sample_ratios=[1, 2])
+ assert len(d) == 3
+ assert d[0]["source"][0] == 1
+ assert d[1]["source"][0] == 2
+ assert d[2]["source"][0] == 2
- d = ConcatDataset(
- [self.dataset_1, self.dataset_2], sample_ratios=[2, 1]
- )
- assert(len(d) == 3)
- assert(d[0]['source'][0] == 1)
- assert(d[1]['source'][0] == 1)
- assert(d[2]['source'][0] == 2)
+ d = ConcatDataset([self.dataset_1, self.dataset_2], sample_ratios=[2, 1])
+ assert len(d) == 3
+ assert d[0]["source"][0] == 1
+ assert d[1]["source"][0] == 1
+ assert d[2]["source"][0] == 2
diff --git a/tests/test_constraints.py b/tests/test_constraints.py
index 3f63c8ace5..1c37f7e1fb 100755
--- a/tests/test_constraints.py
+++ b/tests/test_constraints.py
@@ -4,9 +4,9 @@
# LICENSE file in the root directory of this source tree.
import sys
-import torch
import unittest
+import torch
from fairseq.token_generation_constraints import *
@@ -17,26 +17,27 @@ def tensorize(constraints: List[List[int]]) -> torch.Tensor:
class TestHelperRoutines(unittest.TestCase):
def setUp(self):
self.examples = [
+ ([[]], torch.tensor([[0]])),
+ ([[], []], torch.tensor([[0], [0]])),
+ ([[torch.tensor([1, 2])], []], torch.tensor([[1, 1, 2, 0], [0, 0, 0, 0]])),
(
- [[]],
- torch.tensor([[0]])
- ),
- (
- [[], []],
- torch.tensor([[0], [0]])
- ),
- (
- [[torch.tensor([1, 2])], []],
- torch.tensor([[1, 1, 2, 0], [0, 0, 0, 0]])
+ [
+ [
+ torch.tensor([3, 1, 2]),
+ torch.tensor([3]),
+ torch.tensor([4, 5, 6, 7]),
+ ],
+ [],
+ [torch.tensor([1, 8, 9, 10, 1, 4, 11, 12])],
+ ],
+ torch.tensor(
+ [
+ [3, 3, 1, 2, 0, 3, 0, 4, 5, 6, 7, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 1, 8, 9, 10, 1, 4, 11, 12, 0, 0, 0],
+ ]
+ ),
),
- (
- [[torch.tensor([3, 1, 2]), torch.tensor([3]), torch.tensor([4, 5, 6, 7])],
- [],
- [ torch.tensor([1, 8, 9, 10, 1, 4, 11, 12]) ]],
- torch.tensor([[3, 3, 1, 2, 0, 3, 0, 4, 5, 6, 7, 0],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [1, 1, 8, 9, 10, 1, 4, 11, 12, 0, 0, 0]])
- )
]
def test_packing(self):
@@ -53,20 +54,24 @@ def setUp(self):
(
tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
"([None].False#6 ([1].True#4 ([2].False#1 [3].True#1) [3].True#1 [4].True#1) ([4].False#2 ([5].True#2 ([6].False#1 [7].True#1))))",
- { 1: 4, 2: 1, 3: 2, 4: 3, 5: 2, 6: 1, 7: 1 }
+ {1: 4, 2: 1, 3: 2, 4: 3, 5: 2, 6: 1, 7: 1},
+ ),
+ ([], "[None].False#0", {}),
+ (tensorize([[0]]), "([None].False#1 [0].True#1)", {0: 1}),
+ (
+ tensorize([[100000, 1, 2, 3, 4, 5]]),
+ "([None].False#1 ([100000].False#1 ([1].False#1 ([2].False#1 ([3].False#1 ([4].False#1 [5].True#1))))))",
+ {100000: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1},
),
- ( [], "[None].False#0", {} ),
- ( tensorize([[0]]), "([None].False#1 [0].True#1)", { 0: 1 } ),
- ( tensorize([[100000, 1, 2, 3, 4, 5]]), "([None].False#1 ([100000].False#1 ([1].False#1 ([2].False#1 ([3].False#1 ([4].False#1 [5].True#1))))))", { 100000: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1 } ),
(
tensorize([[1, 2], [1, 2]]),
"([None].False#2 ([1].False#2 [2].True#2))",
- { 1: 2, 2: 2 },
+ {1: 2, 2: 2},
),
(
tensorize([[1, 2], [3, 4]]),
"([None].False#2 ([1].False#1 [2].True#1) ([3].False#1 [4].True#1))",
- { 1: 1, 2: 1, 3: 1, 4: 1},
+ {1: 1, 2: 1, 3: 1, 4: 1},
),
]
@@ -74,65 +79,65 @@ def setUp(self):
(
self.examples[0][0],
[],
- { "bank": 0, "num_completed": 0, "finished": False, "is_root": True },
+ {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
),
(
self.examples[0][0],
[1, 2],
- { "bank": 2, "num_completed": 0, "finished": False, "is_root": False },
+ {"bank": 2, "num_completed": 0, "finished": False, "is_root": False},
),
(
self.examples[0][0],
[1, 2, 94],
- { "bank": 1, "num_completed": 1, "finished": False, "is_root": True },
+ {"bank": 1, "num_completed": 1, "finished": False, "is_root": True},
),
(
self.examples[0][0],
[1, 3, 999, 1, 4],
- { "bank": 4, "num_completed": 2, "finished": False, "is_root": False },
+ {"bank": 4, "num_completed": 2, "finished": False, "is_root": False},
),
(
self.examples[0][0],
[1, 3, 999, 1, 4, 999],
- { "bank": 4, "num_completed": 2, "finished": False, "is_root": True },
+ {"bank": 4, "num_completed": 2, "finished": False, "is_root": True},
),
(
self.examples[0][0],
[4, 5, 6, 8],
- { "bank": 2, "num_completed": 1, "finished": False, "is_root": True },
+ {"bank": 2, "num_completed": 1, "finished": False, "is_root": True},
),
(
self.examples[0][0],
# Tricky, because in last three, goes down [1->4] branch, could miss [1] and [4->5]
# [[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]],
[1, 2, 3, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5],
- { "bank": 14, "num_completed": 6, "finished": True, "is_root": False },
+ {"bank": 14, "num_completed": 6, "finished": True, "is_root": False},
),
(
self.examples[0][0],
[1, 2, 3, 999, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5, 117],
- { "bank": 14, "num_completed": 6, "finished": True, "is_root": True },
+ {"bank": 14, "num_completed": 6, "finished": True, "is_root": True},
),
(
tensorize([[1], [2, 3]]),
# Should not be able to get credit for entering 1 a second time
[1, 1],
- { "bank": 1, "num_completed": 1, "finished": False, "is_root": True },
+ {"bank": 1, "num_completed": 1, "finished": False, "is_root": True},
),
(
self.examples[4][0],
[1, 2, 1, 2],
- { "bank": 4, "num_completed": 2, "finished": True, "is_root": False },
+ {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
),
(
self.examples[4][0],
[1, 2, 1, 2, 1],
- { "bank": 4, "num_completed": 2, "finished": True, "is_root": True },
+ {"bank": 4, "num_completed": 2, "finished": True, "is_root": True},
),
(
self.examples[5][0],
[1, 2, 3, 4, 5],
- { "bank": 4, "num_completed": 2, "finished": True, "is_root": True },
+ {"bank": 4, "num_completed": 2, "finished": True, "is_root": True},
),
]
@@ -143,8 +148,12 @@ def test_graphs(self):
for example in self.examples:
constraints, expected, gold_counts = example
c = ConstraintNode.create(constraints)
- assert ConstraintNode.print_graph(c) == expected, f"got {ConstraintNode.print_graph(c)}, expected {expected}"
- assert c.token_counts() == gold_counts, f"{c} got {c.token_counts()} wanted {gold_counts}"
+ assert (
+ ConstraintNode.print_graph(c) == expected
+ ), f"got {ConstraintNode.print_graph(c)}, expected {expected}"
+ assert (
+ c.token_counts() == gold_counts
+ ), f"{c} got {c.token_counts()} wanted {gold_counts}"
def test_next_tokens(self):
"""
@@ -159,7 +168,9 @@ def test_next_tokens(self):
state = UnorderedConstraintState(root)
for token in sequence:
all_tokens = root_tokens.union(state.node.children.keys())
- assert all_tokens == state.next_tokens(), f"ALL {all_tokens} NEXT {state.next_tokens()}"
+ assert (
+ all_tokens == state.next_tokens()
+ ), f"ALL {all_tokens} NEXT {state.next_tokens()}"
state = state.advance(token)
def test_sequences(self):
@@ -171,7 +182,9 @@ def test_sequences(self):
for attr in expected.keys():
result[attr] = getattr(state, attr)
- assert result == expected, f"TEST({tokens}) GOT: {result} WANTED: {expected}"
+ assert (
+ result == expected
+ ), f"TEST({tokens}) GOT: {result} WANTED: {expected}"
class TestOrderedConstraintState(unittest.TestCase):
@@ -180,62 +193,62 @@ def setUp(self):
(
tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
[],
- { "bank": 0, "num_completed": 0, "finished": False, "is_root": True },
+ {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
),
(
tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
[1, 2],
- { "bank": 2, "num_completed": 0, "finished": False, "is_root": False },
+ {"bank": 2, "num_completed": 0, "finished": False, "is_root": False},
),
(
tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
[1, 2, 94],
- { "bank": 0, "num_completed": 0, "finished": False, "is_root": True },
+ {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
),
(
tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
[1, 3, 999, 1, 4],
- { "bank": 0, "num_completed": 0, "finished": False, "is_root": True },
+ {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
),
(
tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
[1, 2, 3, 999, 999],
- { "bank": 3, "num_completed": 1, "finished": False, "is_root": False },
+ {"bank": 3, "num_completed": 1, "finished": False, "is_root": False},
),
(
tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
[1, 2, 3, 77, 1, 3, 1],
- { "bank": 6, "num_completed": 2, "finished": False, "is_root": False },
+ {"bank": 6, "num_completed": 2, "finished": False, "is_root": False},
),
(
tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
[1, 2, 3, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5],
- { "bank": 14, "num_completed": 6, "finished": True, "is_root": False },
+ {"bank": 14, "num_completed": 6, "finished": True, "is_root": False},
),
(
tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
[1, 2, 999, 1, 2, 3, 999, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5, 117],
- { "bank": 14, "num_completed": 6, "finished": True, "is_root": False },
+ {"bank": 14, "num_completed": 6, "finished": True, "is_root": False},
),
(
tensorize([[1], [2, 3]]),
[1, 1],
- { "bank": 1, "num_completed": 1, "finished": False, "is_root": False },
+ {"bank": 1, "num_completed": 1, "finished": False, "is_root": False},
),
(
tensorize([[1, 2], [1, 2]]),
[1, 2, 1, 2],
- { "bank": 4, "num_completed": 2, "finished": True, "is_root": False },
+ {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
),
(
tensorize([[1, 2], [1, 2]]),
[1, 2, 1, 2, 1],
- { "bank": 4, "num_completed": 2, "finished": True, "is_root": False },
+ {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
),
(
tensorize([[1, 2], [3, 4]]),
[1, 2, 3, 4, 5],
- { "bank": 4, "num_completed": 2, "finished": True, "is_root": False },
+ {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
),
]
@@ -247,8 +260,10 @@ def test_sequences(self):
result = {}
for attr in expected.keys():
result[attr] = getattr(state, attr)
- assert result == expected, f"TEST({tokens}) GOT: {result} WANTED: {expected}"
+ assert (
+ result == expected
+ ), f"TEST({tokens}) GOT: {result} WANTED: {expected}"
+
if __name__ == "__main__":
unittest.main()
-
diff --git a/tests/test_convtbc.py b/tests/test_convtbc.py
index fc2ac0b5dc..3a3c9b91e7 100644
--- a/tests/test_convtbc.py
+++ b/tests/test_convtbc.py
@@ -3,14 +3,14 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import torch
import unittest
-from fairseq.modules import ConvTBC
+
+import torch
import torch.nn as nn
+from fairseq.modules import ConvTBC
class TestConvTBC(unittest.TestCase):
-
def test_convtbc(self):
# ksz, in_channels, out_channels
conv_tbc = ConvTBC(4, 5, kernel_size=3, padding=1)
@@ -27,7 +27,9 @@ def test_convtbc(self):
output_tbc = conv_tbc(input_tbc)
output1d = conv1d(input1d)
- self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data)
+ self.assertAlmostEqual(
+ output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data
+ )
grad_tbc = torch.randn(output_tbc.size())
grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous()
@@ -35,14 +37,18 @@ def test_convtbc(self):
output_tbc.backward(grad_tbc)
output1d.backward(grad1d)
- self.assertAlmostEqual(conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data)
+ self.assertAlmostEqual(
+ conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data
+ )
self.assertAlmostEqual(conv_tbc.bias.grad.data, conv1d.bias.grad.data)
- self.assertAlmostEqual(input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data)
+ self.assertAlmostEqual(
+ input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data
+ )
def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertLess((t1 - t2).abs().max(), 1e-4)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_dictionary.py b/tests/test_dictionary.py
index d9a1ec72c8..81ce102f4f 100644
--- a/tests/test_dictionary.py
+++ b/tests/test_dictionary.py
@@ -8,31 +8,39 @@
import unittest
import torch
-
from fairseq.data import Dictionary
class TestDictionary(unittest.TestCase):
-
def test_finalize(self):
txt = [
- 'A B C D',
- 'B C D',
- 'C D',
- 'D',
+ "A B C D",
+ "B C D",
+ "C D",
+ "D",
]
- ref_ids1 = list(map(torch.IntTensor, [
- [4, 5, 6, 7, 2],
- [5, 6, 7, 2],
- [6, 7, 2],
- [7, 2],
- ]))
- ref_ids2 = list(map(torch.IntTensor, [
- [7, 6, 5, 4, 2],
- [6, 5, 4, 2],
- [5, 4, 2],
- [4, 2],
- ]))
+ ref_ids1 = list(
+ map(
+ torch.IntTensor,
+ [
+ [4, 5, 6, 7, 2],
+ [5, 6, 7, 2],
+ [6, 7, 2],
+ [7, 2],
+ ],
+ )
+ )
+ ref_ids2 = list(
+ map(
+ torch.IntTensor,
+ [
+ [7, 6, 5, 4, 2],
+ [6, 5, 4, 2],
+ [5, 4, 2],
+ [4, 2],
+ ],
+ )
+ )
# build dictionary
d = Dictionary()
@@ -59,7 +67,7 @@ def assertMatch(ids, ref_ids):
assertMatch(finalized_ids, ref_ids2)
# write to disk and reload
- with tempfile.NamedTemporaryFile(mode='w') as tmp_dict:
+ with tempfile.NamedTemporaryFile(mode="w") as tmp_dict:
d.save(tmp_dict.name)
d = Dictionary.load(tmp_dict.name)
reload_ids = get_ids(d)
@@ -77,40 +85,32 @@ def test_overwrite(self):
)
d = Dictionary()
d.add_from_file(dict_file)
- self.assertEqual(d.index(''), 1)
- self.assertEqual(d.index('foo'), 3)
- self.assertEqual(d.index(''), 4)
- self.assertEqual(d.index(''), 5)
- self.assertEqual(d.index(''), 6)
- self.assertEqual(d.index(','), 7)
- self.assertEqual(d.index('▁de'), 8)
+ self.assertEqual(d.index(""), 1)
+ self.assertEqual(d.index("foo"), 3)
+ self.assertEqual(d.index(""), 4)
+ self.assertEqual(d.index(""), 5)
+ self.assertEqual(d.index(""), 6)
+ self.assertEqual(d.index(","), 7)
+ self.assertEqual(d.index("▁de"), 8)
def test_no_overwrite(self):
# for example, Camembert overwrites , and
dict_file = io.StringIO(
- " 999\n"
- " 999\n"
- " 999\n"
- ", 999\n"
- "▁de 999\n"
+ " 999\n" " 999\n" " 999\n" ", 999\n" "▁de 999\n"
)
d = Dictionary()
- with self.assertRaisesRegex(RuntimeError, 'Duplicate'):
+ with self.assertRaisesRegex(RuntimeError, "Duplicate"):
d.add_from_file(dict_file)
def test_space(self):
# for example, character models treat space as a symbol
- dict_file = io.StringIO(
- " 999\n"
- "a 999\n"
- "b 999\n"
- )
+ dict_file = io.StringIO(" 999\n" "a 999\n" "b 999\n")
d = Dictionary()
d.add_from_file(dict_file)
- self.assertEqual(d.index(' '), 4)
- self.assertEqual(d.index('a'), 5)
- self.assertEqual(d.index('b'), 6)
+ self.assertEqual(d.index(" "), 4)
+ self.assertEqual(d.index("a"), 5)
+ self.assertEqual(d.index("b"), 6)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_file_io.py b/tests/test_file_io.py
index ffcc6a3eef..aef5b80d18 100644
--- a/tests/test_file_io.py
+++ b/tests/test_file_io.py
@@ -1,14 +1,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import sys
-import tempfile
import os
import shutil
-
-from typing import Optional
-
+import sys
+import tempfile
import unittest
+from typing import Optional
from unittest.mock import MagicMock
@@ -34,14 +32,16 @@ def tearDownClass(cls) -> None:
def test_file_io(self):
from fairseq.file_io import PathManager
+
with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f:
s = f.read()
self.assertEqual(s, self._tmpfile_contents)
def test_file_io_oss(self):
# Mock fvcore to simulate oss environment.
- sys.modules['fvcore'] = MagicMock()
+ sys.modules["fvcore"] = MagicMock()
from fairseq.file_io import PathManager
+
with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f:
s = f.read()
self.assertEqual(s, self._tmpfile_contents)
diff --git a/tests/test_fp16_optimizer.py b/tests/test_fp16_optimizer.py
index bca341af1a..c4195273e3 100644
--- a/tests/test_fp16_optimizer.py
+++ b/tests/test_fp16_optimizer.py
@@ -8,13 +8,11 @@
import unittest
import torch
-
from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
-@unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU')
+@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestGradientScaling(unittest.TestCase):
-
def setUp(self):
self.x = torch.tensor([2.0]).cuda().half()
weight = 3.0
@@ -30,16 +28,16 @@ def setUp(self):
self.params = list(self.model.parameters())
self.namespace_dls = argparse.Namespace(
- optimizer='adam',
+ optimizer="adam",
lr=[0.1],
- adam_betas='(0.9, 0.999)',
+ adam_betas="(0.9, 0.999)",
adam_eps=1e-8,
weight_decay=0.0,
fp16_init_scale=1,
fp16_scale_window=1,
fp16_scale_tolerance=1,
threshold_loss_scale=1,
- min_loss_scale=1e-4
+ min_loss_scale=1e-4,
)
def run_iter(self, model, params, optimizer):
@@ -47,15 +45,25 @@ def run_iter(self, model, params, optimizer):
y = model(self.x)
loss = self.loss_fn(y, self.target)
optimizer.backward(loss)
- self.assertEqual(loss, torch.tensor(1., device='cuda:0', dtype=torch.float16))
+ self.assertEqual(loss, torch.tensor(1.0, device="cuda:0", dtype=torch.float16))
grad_norm = optimizer.clip_grad_norm(0)
self.assertAlmostEqual(grad_norm.item(), 2.2361, 4)
optimizer.step()
- self.assertEqual(model.weight, torch.tensor([[3.0996]], device='cuda:0', dtype=torch.float16, requires_grad=True))
- self.assertEqual(model.bias, torch.tensor([5.1016], device='cuda:0', dtype=torch.float16, requires_grad=True))
- self.assertEqual(optimizer.scaler.loss_scale, 2.)
+ self.assertEqual(
+ model.weight,
+ torch.tensor(
+ [[3.0996]], device="cuda:0", dtype=torch.float16, requires_grad=True
+ ),
+ )
+ self.assertEqual(
+ model.bias,
+ torch.tensor(
+ [5.1016], device="cuda:0", dtype=torch.float16, requires_grad=True
+ ),
+ )
+ self.assertEqual(optimizer.scaler.loss_scale, 2.0)
def test_mixed_precision(self):
model = copy.deepcopy(self.model)
@@ -63,18 +71,28 @@ def test_mixed_precision(self):
optimizer = FP16Optimizer.build_optimizer(self.namespace_dls, params)
self.run_iter(model, params, optimizer)
- self.assertTrue(all(
- torch.all(fp32_params.eq(torch.tensor([3.1000, 5.1000], device='cuda:0', requires_grad=True)))
- for fp32_params in optimizer.fp32_params.values()
- ))
+ self.assertTrue(
+ all(
+ torch.all(
+ fp32_params.eq(
+ torch.tensor(
+ [3.1000, 5.1000], device="cuda:0", requires_grad=True
+ )
+ )
+ )
+ for fp32_params in optimizer.fp32_params.values()
+ )
+ )
def test_memory_efficient(self):
model = copy.deepcopy(self.model)
params = list(model.parameters())
- optimizer = MemoryEfficientFP16Optimizer.build_optimizer(self.namespace_dls, params)
+ optimizer = MemoryEfficientFP16Optimizer.build_optimizer(
+ self.namespace_dls, params
+ )
self.run_iter(model, params, optimizer)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_inference_dropout.py b/tests/test_inference_dropout.py
index 4857bc7a87..fd5edd43d6 100644
--- a/tests/test_inference_dropout.py
+++ b/tests/test_inference_dropout.py
@@ -6,12 +6,11 @@
import logging
import unittest
-from tests.test_sequence_generator import get_dummy_task_and_parser
from fairseq.models.transformer import TransformerModel
+from tests.test_sequence_generator import get_dummy_task_and_parser
class TestInferenceDropout(unittest.TestCase):
-
def setUp(self):
self.task, self.parser = get_dummy_task_and_parser()
TransformerModel.add_args(self.parser)
@@ -55,7 +54,10 @@ def test_applies_training_mode(self):
def test_retain_modules(self):
self.args.retain_dropout = True
- self.args.retain_dropout_modules = ['TransformerEncoder', 'TransformerEncoderLayer']
+ self.args.retain_dropout_modules = [
+ "TransformerEncoder",
+ "TransformerEncoderLayer",
+ ]
self.transformer_model = TransformerModel.build_model(self.args, self.task)
self.transformer_model.prepare_for_inference_(self.args)
assert self.transformer_model.encoder.dropout_module.apply_during_inference
diff --git a/tests/test_iterators.py b/tests/test_iterators.py
index 9e444d154b..3d2c4d6251 100644
--- a/tests/test_iterators.py
+++ b/tests/test_iterators.py
@@ -9,7 +9,6 @@
class TestIterators(unittest.TestCase):
-
def test_counting_iterator(self, ref=None, itr=None):
if ref is None:
assert itr is None
@@ -109,7 +108,7 @@ def test_counting_iterator_buffered_iterator_take(self):
self.assertFalse(itr.has_next())
self.assertRaises(StopIteration, next, buffered_itr)
- ref = list(range(4,10))
+ ref = list(range(4, 10))
buffered_itr = iterators.BufferedIterator(2, ref)
itr = iterators.CountingIterator(buffered_itr, start=4)
itr.take(5)
@@ -120,5 +119,5 @@ def test_counting_iterator_buffered_iterator_take(self):
self.assertRaises(StopIteration, next, buffered_itr)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_label_smoothing.py b/tests/test_label_smoothing.py
index 94e5ccf1f3..04c0f974ac 100644
--- a/tests/test_label_smoothing.py
+++ b/tests/test_label_smoothing.py
@@ -7,16 +7,15 @@
import copy
import unittest
+import tests.utils as test_utils
import torch
-
from fairseq.criterions.cross_entropy import CrossEntropyCriterion
-from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
-
-import tests.utils as test_utils
+from fairseq.criterions.label_smoothed_cross_entropy import (
+ LabelSmoothedCrossEntropyCriterion,
+)
class TestLabelSmoothing(unittest.TestCase):
-
def setUp(self):
# build dictionary
self.d = test_utils.dummy_dictionary(3)
@@ -30,8 +29,14 @@ def setUp(self):
# build dataset
self.data = [
# the first batch item has padding
- {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, eos])},
- {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, w1, eos])},
+ {
+ "source": torch.LongTensor([w1, eos]),
+ "target": torch.LongTensor([w1, eos]),
+ },
+ {
+ "source": torch.LongTensor([w1, eos]),
+ "target": torch.LongTensor([w1, w1, eos]),
+ },
]
self.sample = next(test_utils.dummy_dataloader(self.data))
@@ -39,23 +44,35 @@ def setUp(self):
self.args = argparse.Namespace()
self.args.sentence_avg = False
self.args.report_accuracy = False
- self.args.probs = torch.FloatTensor([
- # pad eos unk w1 w2 w3
- [0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05],
- [0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10],
- [0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15],
- ]).unsqueeze(0).expand(2, 3, 7) # add batch dimension
+ self.args.probs = (
+ torch.FloatTensor(
+ [
+ # pad eos unk w1 w2 w3
+ [0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05],
+ [0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10],
+ [0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15],
+ ]
+ )
+ .unsqueeze(0)
+ .expand(2, 3, 7)
+ ) # add batch dimension
self.task = test_utils.TestTranslationTask.setup_task(self.args, self.d, self.d)
self.model = self.task.build_model(self.args)
def test_nll_loss(self):
self.args.label_smoothing = 0.1
nll_crit = CrossEntropyCriterion.build_criterion(self.args, self.task)
- smooth_crit = LabelSmoothedCrossEntropyCriterion.build_criterion(self.args, self.task)
- nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
- smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
- self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6)
- self.assertLess(abs(nll_loss - smooth_logging_output['nll_loss']), 1e-6)
+ smooth_crit = LabelSmoothedCrossEntropyCriterion.build_criterion(
+ self.args, self.task
+ )
+ nll_loss, nll_sample_size, nll_logging_output = nll_crit(
+ self.model, self.sample
+ )
+ smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(
+ self.model, self.sample
+ )
+ self.assertLess(abs(nll_loss - nll_logging_output["loss"]), 1e-6)
+ self.assertLess(abs(nll_loss - smooth_logging_output["nll_loss"]), 1e-6)
def test_padding(self):
self.args.label_smoothing = 0.1
@@ -86,9 +103,15 @@ def test_reduction(self):
def test_zero_eps(self):
self.args.label_smoothing = 0.0
nll_crit = CrossEntropyCriterion.build_criterion(self.args, self.task)
- smooth_crit = LabelSmoothedCrossEntropyCriterion.build_criterion(self.args, self.task)
- nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
- smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
+ smooth_crit = LabelSmoothedCrossEntropyCriterion.build_criterion(
+ self.args, self.task
+ )
+ nll_loss, nll_sample_size, nll_logging_output = nll_crit(
+ self.model, self.sample
+ )
+ smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(
+ self.model, self.sample
+ )
self.assertAlmostEqual(nll_loss, smooth_loss)
def assertAlmostEqual(self, t1, t2):
@@ -96,5 +119,5 @@ def assertAlmostEqual(self, t1, t2):
self.assertLess((t1 - t2).abs().max(), 1e-6)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_lstm_jitable.py b/tests/test_lstm_jitable.py
index d97652fb77..38f79d1793 100644
--- a/tests/test_lstm_jitable.py
+++ b/tests/test_lstm_jitable.py
@@ -92,19 +92,21 @@ def test_assert_jit_vs_nonjit_(self):
idx = len(task.source_dictionary)
iter = 100
# Inject random input and check output
- seq_len_tensor = torch.randint(1, 10, (iter, ))
- num_samples_tensor = torch.randint(1, 10, (iter, ))
+ seq_len_tensor = torch.randint(1, 10, (iter,))
+ num_samples_tensor = torch.randint(1, 10, (iter,))
for i in range(iter):
seq_len = seq_len_tensor[i]
num_samples = num_samples_tensor[i]
- src_token = torch.randint(0, idx, (num_samples, seq_len)),
- src_lengths = torch.randint(1, seq_len+1, (num_samples,))
+ src_token = (torch.randint(0, idx, (num_samples, seq_len)),)
+ src_lengths = torch.randint(1, seq_len + 1, (num_samples,))
src_lengths, _ = torch.sort(src_lengths, descending=True)
# Force the first sample to have seq_len
src_lengths[0] = seq_len
- prev_output_token = torch.randint(0, idx, (num_samples, 1)),
+ prev_output_token = (torch.randint(0, idx, (num_samples, 1)),)
result = model(src_token[0], src_lengths, prev_output_token[0], None)
- scripted_result = scripted_model(src_token[0], src_lengths, prev_output_token[0], None)
+ scripted_result = scripted_model(
+ src_token[0], src_lengths, prev_output_token[0], None
+ )
self.assertTensorEqual(result[0], scripted_result[0])
self.assertTensorEqual(result[1], scripted_result[1])
diff --git a/tests/test_memory_efficient_fp16.py b/tests/test_memory_efficient_fp16.py
index bd2b8faeb4..e10636d96a 100644
--- a/tests/test_memory_efficient_fp16.py
+++ b/tests/test_memory_efficient_fp16.py
@@ -8,14 +8,12 @@
import unittest
import torch
-
from fairseq.optim.adam import FairseqAdam
from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer
-@unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU')
+@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestMemoryEfficientFP16(unittest.TestCase):
-
def setUp(self):
logging.disable(logging.CRITICAL)
@@ -31,7 +29,7 @@ def test_load_state_dict(self):
optimizer = FairseqAdam(
argparse.Namespace(
lr=[0.00001],
- adam_betas='(0.9, 0.999)',
+ adam_betas="(0.9, 0.999)",
adam_eps=1e-8,
weight_decay=0.0,
),
@@ -64,5 +62,5 @@ def test_load_state_dict(self):
self.assertTrue(v_i.dtype == torch.float32)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 060291808e..2de6969cf4 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -10,69 +10,68 @@
class TestMetrics(unittest.TestCase):
-
def test_nesting(self):
with metrics.aggregate() as a:
- metrics.log_scalar('loss', 1)
+ metrics.log_scalar("loss", 1)
with metrics.aggregate() as b:
- metrics.log_scalar('loss', 2)
+ metrics.log_scalar("loss", 2)
- self.assertEqual(a.get_smoothed_values()['loss'], 1.5)
- self.assertEqual(b.get_smoothed_values()['loss'], 2)
+ self.assertEqual(a.get_smoothed_values()["loss"], 1.5)
+ self.assertEqual(b.get_smoothed_values()["loss"], 2)
def test_new_root(self):
with metrics.aggregate() as a:
- metrics.log_scalar('loss', 1)
+ metrics.log_scalar("loss", 1)
with metrics.aggregate(new_root=True) as b:
- metrics.log_scalar('loss', 2)
+ metrics.log_scalar("loss", 2)
- self.assertEqual(a.get_smoothed_values()['loss'], 1)
- self.assertEqual(b.get_smoothed_values()['loss'], 2)
+ self.assertEqual(a.get_smoothed_values()["loss"], 1)
+ self.assertEqual(b.get_smoothed_values()["loss"], 2)
def test_nested_new_root(self):
with metrics.aggregate() as layer1:
- metrics.log_scalar('loss', 1)
+ metrics.log_scalar("loss", 1)
with metrics.aggregate(new_root=True) as layer2:
- metrics.log_scalar('loss', 2)
+ metrics.log_scalar("loss", 2)
with metrics.aggregate() as layer3:
- metrics.log_scalar('loss', 3)
+ metrics.log_scalar("loss", 3)
with metrics.aggregate(new_root=True) as layer4:
- metrics.log_scalar('loss', 4)
- metrics.log_scalar('loss', 1.5)
+ metrics.log_scalar("loss", 4)
+ metrics.log_scalar("loss", 1.5)
- self.assertEqual(layer4.get_smoothed_values()['loss'], 4)
- self.assertEqual(layer3.get_smoothed_values()['loss'], 3)
- self.assertEqual(layer2.get_smoothed_values()['loss'], 2.5)
- self.assertEqual(layer1.get_smoothed_values()['loss'], 1.25)
+ self.assertEqual(layer4.get_smoothed_values()["loss"], 4)
+ self.assertEqual(layer3.get_smoothed_values()["loss"], 3)
+ self.assertEqual(layer2.get_smoothed_values()["loss"], 2.5)
+ self.assertEqual(layer1.get_smoothed_values()["loss"], 1.25)
def test_named(self):
name = str(uuid.uuid4())
metrics.reset_meters(name)
with metrics.aggregate(name):
- metrics.log_scalar('loss', 1)
+ metrics.log_scalar("loss", 1)
- metrics.log_scalar('loss', 3)
+ metrics.log_scalar("loss", 3)
with metrics.aggregate(name):
- metrics.log_scalar('loss', 2)
+ metrics.log_scalar("loss", 2)
- self.assertEqual(metrics.get_smoothed_values(name)['loss'], 1.5)
+ self.assertEqual(metrics.get_smoothed_values(name)["loss"], 1.5)
def test_nested_duplicate_names(self):
name = str(uuid.uuid4())
metrics.reset_meters(name)
with metrics.aggregate(name):
- metrics.log_scalar('loss', 1)
+ metrics.log_scalar("loss", 1)
with metrics.aggregate() as other:
with metrics.aggregate(name):
- metrics.log_scalar('loss', 2)
- metrics.log_scalar('loss', 6)
+ metrics.log_scalar("loss", 2)
+ metrics.log_scalar("loss", 6)
- self.assertEqual(metrics.get_smoothed_values(name)['loss'], 3)
- self.assertEqual(other.get_smoothed_values()['loss'], 2)
+ self.assertEqual(metrics.get_smoothed_values(name)["loss"], 3)
+ self.assertEqual(other.get_smoothed_values()["loss"], 2)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_multihead_attention.py b/tests/test_multihead_attention.py
index 324d8e3eb5..9aa9cb2f87 100644
--- a/tests/test_multihead_attention.py
+++ b/tests/test_multihead_attention.py
@@ -3,8 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import torch
import unittest
+
+import torch
from fairseq.modules.multihead_attention import MultiheadAttention
@@ -47,8 +48,8 @@ def test_append_prev_key_padding_mask(self):
if key_padding_mask is not None:
self.assertTrue(
torch.all(torch.eq(key_padding_mask, c[2])),
- f'Unexpected resultant key padding mask: {key_padding_mask}'
- f' given current: {c[0]} and previous: {c[1]}',
+ f"Unexpected resultant key padding mask: {key_padding_mask}"
+ f" given current: {c[0]} and previous: {c[1]}",
)
self.assertEqual(key_padding_mask.size(0), bsz)
self.assertEqual(key_padding_mask.size(1), src_len)
@@ -56,5 +57,5 @@ def test_append_prev_key_padding_mask(self):
self.assertIsNone(c[2])
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_noising.py b/tests/test_noising.py
index da792a1826..b3d0d123c4 100644
--- a/tests/test_noising.py
+++ b/tests/test_noising.py
@@ -408,7 +408,10 @@ def test_word_blank_without_eos(self):
self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
def _get_noising_dataset_batch(
- self, src_tokens_no_pad, src_dict, append_eos_to_tgt=False,
+ self,
+ src_tokens_no_pad,
+ src_dict,
+ append_eos_to_tgt=False,
):
"""
Constructs a NoisingDataset and the corresponding
@@ -433,7 +436,8 @@ def _get_noising_dataset_batch(
src=noising_dataset, tgt=tgt, src_sizes=None, src_dict=src_dict
)
language_pair_dataset = TransformEosDataset(
- language_pair_dataset, src_dict.eos(),
+ language_pair_dataset,
+ src_dict.eos(),
append_eos_to_tgt=append_eos_to_tgt,
)
diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py
index 80f2948250..517e23c39e 100644
--- a/tests/test_reproducibility.py
+++ b/tests/test_reproducibility.py
@@ -4,11 +4,11 @@
# LICENSE file in the root directory of this source tree.
import contextlib
-from io import StringIO
import json
import os
import tempfile
import unittest
+from io import StringIO
import torch
@@ -16,13 +16,12 @@
class TestReproducibility(unittest.TestCase):
-
def _test_reproducibility(
self,
name,
extra_flags=None,
delta=0.0001,
- resume_checkpoint='checkpoint1.pt',
+ resume_checkpoint="checkpoint1.pt",
max_epoch=3,
):
def get_last_log_stats_containing_string(log_records, search_string):
@@ -41,63 +40,99 @@ def get_last_log_stats_containing_string(log_records, search_string):
# train epochs 1 and 2 together
with self.assertLogs() as logs:
test_binaries.train_translation_model(
- data_dir, 'fconv_iwslt_de_en', [
- '--dropout', '0.0',
- '--log-format', 'json',
- '--log-interval', '1',
- '--max-epoch', str(max_epoch),
- ] + extra_flags,
+ data_dir,
+ "fconv_iwslt_de_en",
+ [
+ "--dropout",
+ "0.0",
+ "--log-format",
+ "json",
+ "--log-interval",
+ "1",
+ "--max-epoch",
+ str(max_epoch),
+ ]
+ + extra_flags,
)
- train_log = get_last_log_stats_containing_string(logs.records, 'train_loss')
- valid_log = get_last_log_stats_containing_string(logs.records, 'valid_loss')
+ train_log = get_last_log_stats_containing_string(logs.records, "train_loss")
+ valid_log = get_last_log_stats_containing_string(logs.records, "valid_loss")
# train epoch 2, resuming from previous checkpoint 1
os.rename(
os.path.join(data_dir, resume_checkpoint),
- os.path.join(data_dir, 'checkpoint_last.pt'),
+ os.path.join(data_dir, "checkpoint_last.pt"),
)
with self.assertLogs() as logs:
test_binaries.train_translation_model(
- data_dir, 'fconv_iwslt_de_en', [
- '--dropout', '0.0',
- '--log-format', 'json',
- '--log-interval', '1',
- '--max-epoch', str(max_epoch),
- ] + extra_flags,
+ data_dir,
+ "fconv_iwslt_de_en",
+ [
+ "--dropout",
+ "0.0",
+ "--log-format",
+ "json",
+ "--log-interval",
+ "1",
+ "--max-epoch",
+ str(max_epoch),
+ ]
+ + extra_flags,
)
- train_res_log = get_last_log_stats_containing_string(logs.records, 'train_loss')
- valid_res_log = get_last_log_stats_containing_string(logs.records, 'valid_loss')
+ train_res_log = get_last_log_stats_containing_string(
+ logs.records, "train_loss"
+ )
+ valid_res_log = get_last_log_stats_containing_string(
+ logs.records, "valid_loss"
+ )
- for k in ['train_loss', 'train_ppl', 'train_num_updates', 'train_gnorm']:
- self.assertAlmostEqual(float(train_log[k]), float(train_res_log[k]), delta=delta)
- for k in ['valid_loss', 'valid_ppl', 'valid_num_updates', 'valid_best_loss']:
- self.assertAlmostEqual(float(valid_log[k]), float(valid_res_log[k]), delta=delta)
+ for k in ["train_loss", "train_ppl", "train_num_updates", "train_gnorm"]:
+ self.assertAlmostEqual(
+ float(train_log[k]), float(train_res_log[k]), delta=delta
+ )
+ for k in [
+ "valid_loss",
+ "valid_ppl",
+ "valid_num_updates",
+ "valid_best_loss",
+ ]:
+ self.assertAlmostEqual(
+ float(valid_log[k]), float(valid_res_log[k]), delta=delta
+ )
def test_reproducibility(self):
- self._test_reproducibility('test_reproducibility')
+ self._test_reproducibility("test_reproducibility")
- @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU')
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
def test_reproducibility_fp16(self):
- self._test_reproducibility('test_reproducibility_fp16', [
- '--fp16',
- '--fp16-init-scale', '4096',
- ], delta=0.011)
+ self._test_reproducibility(
+ "test_reproducibility_fp16",
+ [
+ "--fp16",
+ "--fp16-init-scale",
+ "4096",
+ ],
+ delta=0.011,
+ )
- @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU')
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
def test_reproducibility_memory_efficient_fp16(self):
- self._test_reproducibility('test_reproducibility_memory_efficient_fp16', [
- '--memory-efficient-fp16',
- '--fp16-init-scale', '4096',
- ])
+ self._test_reproducibility(
+ "test_reproducibility_memory_efficient_fp16",
+ [
+ "--memory-efficient-fp16",
+ "--fp16-init-scale",
+ "4096",
+ ],
+ )
def test_mid_epoch_reproducibility(self):
self._test_reproducibility(
- 'test_mid_epoch_reproducibility',
- ['--save-interval-updates', '3'],
- resume_checkpoint='checkpoint_1_3.pt',
+ "test_mid_epoch_reproducibility",
+ ["--save-interval-updates", "3"],
+ resume_checkpoint="checkpoint_1_3.pt",
max_epoch=1,
)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_resampling_dataset.py b/tests/test_resampling_dataset.py
index 0d142f5a8d..ccb53a253c 100644
--- a/tests/test_resampling_dataset.py
+++ b/tests/test_resampling_dataset.py
@@ -7,7 +7,6 @@
import unittest
import numpy as np
-
from fairseq.data import ListDataset, ResamplingDataset
diff --git a/tests/test_sequence_generator.py b/tests/test_sequence_generator.py
index 517aa77d59..c890b655ff 100644
--- a/tests/test_sequence_generator.py
+++ b/tests/test_sequence_generator.py
@@ -11,9 +11,8 @@
import torch
from fairseq import search
from fairseq.data.dictionary import Dictionary
-
from fairseq.models.transformer import TransformerModel
-from fairseq.sequence_generator import SequenceGenerator, EnsembleModel
+from fairseq.sequence_generator import EnsembleModel, SequenceGenerator
from fairseq.tasks.fairseq_task import LegacyFairseqTask
@@ -109,7 +108,6 @@ def _test_save_and_load(self, scripted_module):
class TestJitSequeneceGenerator(TestJitSequenceGeneratorBase):
-
@unittest.skipIf(
torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release"
)
@@ -130,7 +128,6 @@ def test_ensemble_sequence_generator(self):
class TestJitEnsemble(TestJitSequenceGeneratorBase):
-
@unittest.skipIf(
torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release"
)
@@ -190,9 +187,14 @@ def assertTensorEqual(self, t1, t2):
class TestSequeneceGenerator(TestSequenceGeneratorBase):
def setUp(self):
- self.tgt_dict, self.w1, self.w2, src_tokens, src_lengths, self.model = (
- test_utils.sequence_generator_setup()
- )
+ (
+ self.tgt_dict,
+ self.w1,
+ self.w2,
+ src_tokens,
+ src_lengths,
+ self.model,
+ ) = test_utils.sequence_generator_setup()
self.sample = {
"net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths}
}
@@ -276,7 +278,9 @@ def test_with_lenpen_favoring_long_hypos(self):
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen)
def test_maxlen(self):
- generator = SequenceGenerator([self.model], self.tgt_dict, beam_size=2, max_len_b=2)
+ generator = SequenceGenerator(
+ [self.model], self.tgt_dict, beam_size=2, max_len_b=2
+ )
hypos = generator.forward(self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
@@ -294,21 +298,27 @@ def test_maxlen(self):
def test_encoder_with_different_output_len(self):
args = self.model.encoder.args
- task = test_utils.TestTranslationTask.setup_task(args, self.tgt_dict, self.tgt_dict)
+ task = test_utils.TestTranslationTask.setup_task(
+ args, self.tgt_dict, self.tgt_dict
+ )
reshaping_model = test_utils.TestReshapingModel.build_model(args, task)
- generator = SequenceGenerator([reshaping_model], self.tgt_dict, beam_size=2, max_len_b=2)
+ generator = SequenceGenerator(
+ [reshaping_model], self.tgt_dict, beam_size=2, max_len_b=2
+ )
hypos = generator.forward(self.sample)
for sent in [0, 1]:
for beam in [0, 1]:
- assert hypos[sent][beam]['attention'] is not None
+ assert hypos[sent][beam]["attention"] is not None
def test_generation_with_additional_input(self):
args = self.model.encoder.args
- task = test_utils.TestTranslationTask.setup_task(args, self.tgt_dict, self.tgt_dict)
+ task = test_utils.TestTranslationTask.setup_task(
+ args, self.tgt_dict, self.tgt_dict
+ )
add_input_model = test_utils.TestAdditionalInputModel.build_model(args, task)
generator = SequenceGenerator([add_input_model], self.tgt_dict, beam_size=2)
sample = self.sample.copy()
- sample['net_input']['fancy_other_input'] = sample['net_input']['src_tokens']
+ sample["net_input"]["fancy_other_input"] = sample["net_input"]["src_tokens"]
hypos = generator.forward(self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
@@ -317,7 +327,6 @@ def test_generation_with_additional_input(self):
class TestDiverseBeamSearch(TestSequenceGeneratorBase):
-
def setUp(self):
# construct dummy dictionary
d = test_utils.dummy_dictionary(vocab_size=2)
@@ -329,45 +338,53 @@ def setUp(self):
self.w2 = 5
# construct source data
- self.src_tokens = torch.LongTensor([
- [self.w1, self.w2, self.eos],
- [self.w1, self.w2, self.eos],
- ])
+ self.src_tokens = torch.LongTensor(
+ [
+ [self.w1, self.w2, self.eos],
+ [self.w1, self.w2, self.eos],
+ ]
+ )
self.src_lengths = torch.LongTensor([2, 2])
args = argparse.Namespace()
- unk = 0.
+ unk = 0.0
args.beam_probs = [
# step 0:
- torch.FloatTensor([
- # eos w1 w2
- # sentence 1:
- [0.0, unk, 0.9, 0.1], # beam 1
- [0.0, unk, 0.9, 0.1], # beam 2
- # sentence 2:
- [0.0, unk, 0.7, 0.3],
- [0.0, unk, 0.7, 0.3],
- ]),
+ torch.FloatTensor(
+ [
+ # eos w1 w2
+ # sentence 1:
+ [0.0, unk, 0.9, 0.1], # beam 1
+ [0.0, unk, 0.9, 0.1], # beam 2
+ # sentence 2:
+ [0.0, unk, 0.7, 0.3],
+ [0.0, unk, 0.7, 0.3],
+ ]
+ ),
# step 1:
- torch.FloatTensor([
- # eos w1 w2
- # sentence 1:
- [0.0, unk, 0.6, 0.4],
- [0.0, unk, 0.6, 0.4],
- # sentence 2:
- [0.25, unk, 0.35, 0.4],
- [0.25, unk, 0.35, 0.4],
- ]),
+ torch.FloatTensor(
+ [
+ # eos w1 w2
+ # sentence 1:
+ [0.0, unk, 0.6, 0.4],
+ [0.0, unk, 0.6, 0.4],
+ # sentence 2:
+ [0.25, unk, 0.35, 0.4],
+ [0.25, unk, 0.35, 0.4],
+ ]
+ ),
# step 2:
- torch.FloatTensor([
- # eos w1 w2
- # sentence 1:
- [1.0, unk, 0.0, 0.0],
- [1.0, unk, 0.0, 0.0],
- # sentence 2:
- [0.9, unk, 0.1, 0.0],
- [0.9, unk, 0.1, 0.0],
- ]),
+ torch.FloatTensor(
+ [
+ # eos w1 w2
+ # sentence 1:
+ [1.0, unk, 0.0, 0.0],
+ [1.0, unk, 0.0, 0.0],
+ # sentence 2:
+ [0.9, unk, 0.1, 0.0],
+ [0.9, unk, 0.1, 0.0],
+ ]
+ ),
]
task = test_utils.TestTranslationTask.setup_task(args, d, d)
@@ -375,11 +392,21 @@ def setUp(self):
self.tgt_dict = task.target_dictionary
def test_diverse_beam_search(self):
- search_strategy = search.DiverseBeamSearch(self.tgt_dict, num_groups=2, diversity_strength=0.)
+ search_strategy = search.DiverseBeamSearch(
+ self.tgt_dict, num_groups=2, diversity_strength=0.0
+ )
generator = SequenceGenerator(
- [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy,
+ [self.model],
+ self.tgt_dict,
+ beam_size=2,
+ search_strategy=search_strategy,
)
- sample = {'net_input': {'src_tokens': self.src_tokens, 'src_lengths': self.src_lengths}}
+ sample = {
+ "net_input": {
+ "src_tokens": self.src_tokens,
+ "src_lengths": self.src_lengths,
+ }
+ }
hypos = generator.forward(sample)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
@@ -439,7 +466,6 @@ def test_diverse_beam_search(self):
class TestTopPSamplingSearch(TestSequenceGeneratorBase):
-
def setUp(self):
# construct dummy dictionary
d = test_utils.dummy_dictionary(vocab_size=2)
@@ -451,14 +477,16 @@ def setUp(self):
self.w2 = 5
# construct source data
- self.src_tokens = torch.LongTensor([
- [self.w1, self.w2, self.eos],
- [self.w1, self.w2, self.eos],
- ])
+ self.src_tokens = torch.LongTensor(
+ [
+ [self.w1, self.w2, self.eos],
+ [self.w1, self.w2, self.eos],
+ ]
+ )
self.src_lengths = torch.LongTensor([2, 2])
args = argparse.Namespace()
- unk = 0.
+ unk = 0.0
# The minimal probability of top 2 tokens.
self.min_top2_prob = 0.75
# The minimal probability of the top 1 token.
@@ -470,29 +498,35 @@ def setUp(self):
args.beam_probs = [
# step 0:
- torch.FloatTensor([
- # eos w1 w2
- [0.0, unk, 1.0, 0.0],
- [0.0, unk, 1.0, 0.0],
- [0.0, unk, 1.0, 0.0],
- [0.0, unk, 1.0, 0.0],
- ]),
+ torch.FloatTensor(
+ [
+ # eos w1 w2
+ [0.0, unk, 1.0, 0.0],
+ [0.0, unk, 1.0, 0.0],
+ [0.0, unk, 1.0, 0.0],
+ [0.0, unk, 1.0, 0.0],
+ ]
+ ),
# step 1:
- torch.FloatTensor([
- # eos w1 w2
- [eos_prob, unk, w1_prob, w2_prob],
- [eos_prob, unk, w1_prob, w2_prob],
- [eos_prob, unk, w1_prob, w2_prob],
- [eos_prob, unk, w1_prob, w2_prob],
- ]),
+ torch.FloatTensor(
+ [
+ # eos w1 w2
+ [eos_prob, unk, w1_prob, w2_prob],
+ [eos_prob, unk, w1_prob, w2_prob],
+ [eos_prob, unk, w1_prob, w2_prob],
+ [eos_prob, unk, w1_prob, w2_prob],
+ ]
+ ),
# step 2:
- torch.FloatTensor([
- # eos w1 w2
- [1.0, unk, 0.0, 0.0],
- [1.0, unk, 0.0, 0.0],
- [1.0, unk, 0.0, 0.0],
- [1.0, unk, 0.0, 0.0],
- ]),
+ torch.FloatTensor(
+ [
+ # eos w1 w2
+ [1.0, unk, 0.0, 0.0],
+ [1.0, unk, 0.0, 0.0],
+ [1.0, unk, 0.0, 0.0],
+ [1.0, unk, 0.0, 0.0],
+ ]
+ ),
]
task = test_utils.TestTranslationTask.setup_task(args, d, d)
@@ -502,14 +536,17 @@ def setUp(self):
def test_topp_sampling_search_low_prob(self):
# Given a prob low enough to top-P sampling, we expect only the top
# 1 token to be sampled, which always results in the same output.
- low_sampling_topp = self.min_top1_prob/2.0
- search_strategy = search.Sampling(self.tgt_dict, sampling_topp=low_sampling_topp)
+ low_sampling_topp = self.min_top1_prob / 2.0
+ search_strategy = search.Sampling(
+ self.tgt_dict, sampling_topp=low_sampling_topp
+ )
generator = SequenceGenerator(
- [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy)
+ [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy
+ )
sample = {
- 'net_input': {
- 'src_tokens': self.src_tokens,
- 'src_lengths': self.src_lengths
+ "net_input": {
+ "src_tokens": self.src_tokens,
+ "src_lengths": self.src_lengths,
}
}
hypos = generator.forward(sample)
@@ -530,55 +567,74 @@ def test_topp_sampling_search_low_prob(self):
def test_topp_sampling_search_high_prob(self):
# Given a prob high enough to top-P sampling, any of the top 2
# tokens could be sampled. This can cause different outputs.
- high_sampling_topp = (self.min_top1_prob+self.min_top2_prob)/2.0
- search_strategy = search.Sampling(self.tgt_dict, sampling_topp=high_sampling_topp)
+ high_sampling_topp = (self.min_top1_prob + self.min_top2_prob) / 2.0
+ search_strategy = search.Sampling(
+ self.tgt_dict, sampling_topp=high_sampling_topp
+ )
generator = SequenceGenerator(
- [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy)
+ [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy
+ )
sample = {
- 'net_input': {
- 'src_tokens': self.src_tokens,
- 'src_lengths': self.src_lengths
+ "net_input": {
+ "src_tokens": self.src_tokens,
+ "src_lengths": self.src_lengths,
}
}
hypos = generator.forward(sample)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
- self.assertTrue(self.hypoTokens(hypos[0][0], [w1, w1, eos]) or
- self.hypoTokens(hypos[0][0], [w1, w2, eos]))
- self.assertTrue(self.hypoScore(hypos[0][0], [1.0, 0.4, 1.0]) or
- self.hypoScore(hypos[0][0], [1.0, 0.35, 1.0]))
+ self.assertTrue(
+ self.hypoTokens(hypos[0][0], [w1, w1, eos])
+ or self.hypoTokens(hypos[0][0], [w1, w2, eos])
+ )
+ self.assertTrue(
+ self.hypoScore(hypos[0][0], [1.0, 0.4, 1.0])
+ or self.hypoScore(hypos[0][0], [1.0, 0.35, 1.0])
+ )
# sentence 1, beam 2
- self.assertTrue(self.hypoTokens(hypos[0][1], [w1, w1, eos]) or
- self.hypoTokens(hypos[0][1], [w1, w2, eos]))
- self.assertTrue(self.hypoScore(hypos[0][1], [1.0, 0.4, 1.0]) or
- self.hypoScore(hypos[0][1], [1.0, 0.35, 1.0]))
+ self.assertTrue(
+ self.hypoTokens(hypos[0][1], [w1, w1, eos])
+ or self.hypoTokens(hypos[0][1], [w1, w2, eos])
+ )
+ self.assertTrue(
+ self.hypoScore(hypos[0][1], [1.0, 0.4, 1.0])
+ or self.hypoScore(hypos[0][1], [1.0, 0.35, 1.0])
+ )
# sentence 2, beam 1
- self.assertTrue(self.hypoTokens(hypos[1][0], [w1, w1, eos]) or
- self.hypoTokens(hypos[1][0], [w1, w2, eos]))
- self.assertTrue(self.hypoScore(hypos[1][0], [1.0, 0.4, 1.0]) or
- self.hypoScore(hypos[1][0], [1.0, 0.35, 1.0]))
+ self.assertTrue(
+ self.hypoTokens(hypos[1][0], [w1, w1, eos])
+ or self.hypoTokens(hypos[1][0], [w1, w2, eos])
+ )
+ self.assertTrue(
+ self.hypoScore(hypos[1][0], [1.0, 0.4, 1.0])
+ or self.hypoScore(hypos[1][0], [1.0, 0.35, 1.0])
+ )
# sentence 2, beam 2
- self.assertTrue(self.hypoTokens(hypos[1][1], [w1, w1, eos]) or
- self.hypoTokens(hypos[1][1], [w1, w2, eos]))
- self.assertTrue(self.hypoScore(hypos[1][1], [1.0, 0.4, 1.0]) or
- self.hypoScore(hypos[1][1], [1.0, 0.35, 1.0]))
+ self.assertTrue(
+ self.hypoTokens(hypos[1][1], [w1, w1, eos])
+ or self.hypoTokens(hypos[1][1], [w1, w2, eos])
+ )
+ self.assertTrue(
+ self.hypoScore(hypos[1][1], [1.0, 0.4, 1.0])
+ or self.hypoScore(hypos[1][1], [1.0, 0.35, 1.0])
+ )
def hypoTokens(self, hypo, tokens):
- return self.tensorEqual(hypo['tokens'], torch.LongTensor(tokens))
+ return self.tensorEqual(hypo["tokens"], torch.LongTensor(tokens))
- def hypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
+ def hypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.0):
pos_scores = torch.FloatTensor(pos_probs).log()
- if not self.almostEqual(hypo['positional_scores'], pos_scores):
+ if not self.almostEqual(hypo["positional_scores"], pos_scores):
return False
- if pos_scores.numel() != hypo['tokens'].numel():
+ if pos_scores.numel() != hypo["tokens"].numel():
return False
score = pos_scores.sum()
if normalized:
score /= pos_scores.numel() ** lenpen
- return abs(score - hypo['score']) < 1e-6
+ return abs(score - hypo["score"]) < 1e-6
def almostEqual(self, t1, t2):
return t1.size() == t2.size() and (t1 - t2).abs().max() < 1e-4
diff --git a/tests/test_sequence_scorer.py b/tests/test_sequence_scorer.py
index a7c2a53a90..42f9447b59 100644
--- a/tests/test_sequence_scorer.py
+++ b/tests/test_sequence_scorer.py
@@ -6,15 +6,12 @@
import argparse
import unittest
+import tests.utils as test_utils
import torch
-
from fairseq.sequence_scorer import SequenceScorer
-import tests.utils as test_utils
-
class TestSequenceScorer(unittest.TestCase):
-
def test_sequence_scorer(self):
# construct dummy dictionary
d = test_utils.dummy_dictionary(vocab_size=2)
@@ -28,52 +25,60 @@ def test_sequence_scorer(self):
# construct dataloader
data = [
{
- 'source': torch.LongTensor([w1, w2, eos]),
- 'target': torch.LongTensor([w1, w2, w1, eos]),
+ "source": torch.LongTensor([w1, w2, eos]),
+ "target": torch.LongTensor([w1, w2, w1, eos]),
},
{
- 'source': torch.LongTensor([w2, eos]),
- 'target': torch.LongTensor([w2, w1, eos]),
+ "source": torch.LongTensor([w2, eos]),
+ "target": torch.LongTensor([w2, w1, eos]),
},
{
- 'source': torch.LongTensor([w2, eos]),
- 'target': torch.LongTensor([w2, eos]),
+ "source": torch.LongTensor([w2, eos]),
+ "target": torch.LongTensor([w2, eos]),
},
]
data_itr = test_utils.dummy_dataloader(data)
# specify expected output probabilities
args = argparse.Namespace()
- unk = 0.
+ unk = 0.0
args.beam_probs = [
# step 0:
- torch.FloatTensor([
- # eos w1 w2
- [0.0, unk, 0.6, 0.4], # sentence 1
- [0.0, unk, 0.4, 0.6], # sentence 2
- [0.0, unk, 0.7, 0.3], # sentence 3
- ]),
+ torch.FloatTensor(
+ [
+ # eos w1 w2
+ [0.0, unk, 0.6, 0.4], # sentence 1
+ [0.0, unk, 0.4, 0.6], # sentence 2
+ [0.0, unk, 0.7, 0.3], # sentence 3
+ ]
+ ),
# step 1:
- torch.FloatTensor([
- # eos w1 w2
- [0.0, unk, 0.2, 0.7], # sentence 1
- [0.0, unk, 0.8, 0.2], # sentence 2
- [0.7, unk, 0.1, 0.2], # sentence 3
- ]),
+ torch.FloatTensor(
+ [
+ # eos w1 w2
+ [0.0, unk, 0.2, 0.7], # sentence 1
+ [0.0, unk, 0.8, 0.2], # sentence 2
+ [0.7, unk, 0.1, 0.2], # sentence 3
+ ]
+ ),
# step 2:
- torch.FloatTensor([
- # eos w1 w2
- [0.10, unk, 0.50, 0.4], # sentence 1
- [0.15, unk, 0.15, 0.7], # sentence 2
- [0.00, unk, 0.00, 0.0], # sentence 3
- ]),
+ torch.FloatTensor(
+ [
+ # eos w1 w2
+ [0.10, unk, 0.50, 0.4], # sentence 1
+ [0.15, unk, 0.15, 0.7], # sentence 2
+ [0.00, unk, 0.00, 0.0], # sentence 3
+ ]
+ ),
# step 3:
- torch.FloatTensor([
- # eos w1 w2
- [0.9, unk, 0.05, 0.05], # sentence 1
- [0.0, unk, 0.00, 0.0], # sentence 2
- [0.0, unk, 0.00, 0.0], # sentence 3
- ]),
+ torch.FloatTensor(
+ [
+ # eos w1 w2
+ [0.9, unk, 0.05, 0.05], # sentence 1
+ [0.0, unk, 0.00, 0.0], # sentence 2
+ [0.0, unk, 0.00, 0.0], # sentence 3
+ ]
+ ),
]
expected_scores = [
[0.6, 0.7, 0.5, 0.9], # sentence 1
@@ -86,21 +91,21 @@ def test_sequence_scorer(self):
scorer = SequenceScorer(task.target_dictionary)
for sample in data_itr:
hypos = task.inference_step(scorer, [model], sample)
- for id, hypos_id in zip(sample['id'].tolist(), hypos):
- self.assertHypoTokens(hypos_id[0], data[id]['target'])
+ for id, hypos_id in zip(sample["id"].tolist(), hypos):
+ self.assertHypoTokens(hypos_id[0], data[id]["target"])
self.assertHypoScore(hypos_id[0], expected_scores[id])
def assertHypoTokens(self, hypo, tokens):
- self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))
+ self.assertTensorEqual(hypo["tokens"], torch.LongTensor(tokens))
- def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
+ def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.0):
pos_scores = torch.FloatTensor(pos_probs).log()
- self.assertAlmostEqual(hypo['positional_scores'], pos_scores)
- self.assertEqual(pos_scores.numel(), hypo['tokens'].numel())
+ self.assertAlmostEqual(hypo["positional_scores"], pos_scores)
+ self.assertEqual(pos_scores.numel(), hypo["tokens"].numel())
score = pos_scores.sum()
if normalized:
- score /= pos_scores.numel()**lenpen
- self.assertLess(abs(score - hypo['score']), 1e-6)
+ score /= pos_scores.numel() ** lenpen
+ self.assertLess(abs(score - hypo["score"]), 1e-6)
def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
@@ -111,5 +116,5 @@ def assertTensorEqual(self, t1, t2):
self.assertEqual(t1.ne(t2).long().sum(), 0)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_sparse_multihead_attention.py b/tests/test_sparse_multihead_attention.py
index eaf9742cdf..3e32b25a7f 100644
--- a/tests/test_sparse_multihead_attention.py
+++ b/tests/test_sparse_multihead_attention.py
@@ -3,46 +3,112 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import torch
import unittest
+
+import torch
from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention
class TestSparseMultiheadAttention(unittest.TestCase):
def test_sparse_multihead_attention(self):
attn_weights = torch.randn(1, 8, 8)
- bidirectional_sparse_mask = torch.tensor([
- [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
- [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
- [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
- [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
- [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
- [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
- [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
- [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0]
- ])
-
- bidirectional_attention = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=True)
- bidirectional_attention_sparse_mask = bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8)
- torch.all(torch.eq(bidirectional_attention_sparse_mask, bidirectional_sparse_mask))
-
- sparse_mask = torch.tensor([
- [0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'),
- float('-inf'), float('-inf')],
- [0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
- [0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
- [0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf')],
- [0, 0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf')],
- [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, float('-inf'), float('-inf')],
- [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, float('-inf')],
- [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
- ])
-
- attention = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=False)
+ bidirectional_sparse_mask = torch.tensor(
+ [
+ [0, 0, 0, 0, 0, float("-inf"), float("-inf"), 0],
+ [0, 0, 0, 0, 0, float("-inf"), float("-inf"), 0],
+ [0, 0, 0, 0, 0, float("-inf"), float("-inf"), 0],
+ [0, 0, 0, 0, 0, float("-inf"), float("-inf"), 0],
+ [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0],
+ [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0],
+ [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0],
+ [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0],
+ ]
+ )
+
+ bidirectional_attention = SparseMultiheadAttention(
+ 16, 1, stride=4, expressivity=1, is_bidirectional=True
+ )
+ bidirectional_attention_sparse_mask = (
+ bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8)
+ )
+ torch.all(
+ torch.eq(bidirectional_attention_sparse_mask, bidirectional_sparse_mask)
+ )
+
+ sparse_mask = torch.tensor(
+ [
+ [
+ 0,
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ ],
+ [
+ 0,
+ 0,
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ ],
+ [
+ 0,
+ 0,
+ 0,
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ ],
+ [
+ 0,
+ 0,
+ 0,
+ 0,
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ ],
+ [0, 0, 0, 0, 0, float("-inf"), float("-inf"), float("-inf")],
+ [
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ 0,
+ 0,
+ 0,
+ float("-inf"),
+ float("-inf"),
+ ],
+ [
+ float("-inf"),
+ float("-inf"),
+ float("-inf"),
+ 0,
+ 0,
+ 0,
+ 0,
+ float("-inf"),
+ ],
+ [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0],
+ ]
+ )
+
+ attention = SparseMultiheadAttention(
+ 16, 1, stride=4, expressivity=1, is_bidirectional=False
+ )
attention_sparse_mask = attention.buffered_sparse_mask(attn_weights, 8, 8)
torch.all(torch.eq(attention_sparse_mask, sparse_mask))
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_token_block_dataset.py b/tests/test_token_block_dataset.py
index 41abb194da..ea315b4e67 100644
--- a/tests/test_token_block_dataset.py
+++ b/tests/test_token_block_dataset.py
@@ -5,15 +5,12 @@
import unittest
+import tests.utils as test_utils
import torch
-
from fairseq.data import TokenBlockDataset
-import tests.utils as test_utils
-
class TestTokenBlockDataset(unittest.TestCase):
-
def _build_dataset(self, data, **kwargs):
sizes = [len(x) for x in data]
underlying_ds = test_utils.TestDataset(data)
@@ -25,7 +22,7 @@ def test_eos_break_mode(self):
torch.tensor([1], dtype=torch.long),
torch.tensor([8, 7, 6, 1], dtype=torch.long),
]
- ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
+ ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode="eos")
self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
self.assertEqual(ds[1].tolist(), [1])
self.assertEqual(ds[2].tolist(), [8, 7, 6, 1])
@@ -35,7 +32,7 @@ def test_eos_break_mode(self):
torch.tensor([8, 7, 6, 1], dtype=torch.long),
torch.tensor([1], dtype=torch.long),
]
- ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
+ ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode="eos")
self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
self.assertEqual(ds[1].tolist(), [8, 7, 6, 1])
self.assertEqual(ds[2].tolist(), [1])
@@ -46,7 +43,7 @@ def test_block_break_mode(self):
torch.tensor([8, 7, 6, 1], dtype=torch.long),
torch.tensor([9, 1], dtype=torch.long),
]
- ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='none')
+ ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode="none")
self.assertEqual(ds[0].tolist(), [5, 4, 3])
self.assertEqual(ds[1].tolist(), [2, 1, 8])
self.assertEqual(ds[2].tolist(), [7, 6, 1])
@@ -58,7 +55,9 @@ def test_complete_break_mode(self):
torch.tensor([8, 7, 6, 1], dtype=torch.long),
torch.tensor([9, 1], dtype=torch.long),
]
- ds = self._build_dataset(data, block_size=6, pad=0, eos=1, break_mode='complete')
+ ds = self._build_dataset(
+ data, block_size=6, pad=0, eos=1, break_mode="complete"
+ )
self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
self.assertEqual(ds[1].tolist(), [8, 7, 6, 1, 9, 1])
@@ -68,7 +67,9 @@ def test_complete_break_mode(self):
torch.tensor([1], dtype=torch.long),
torch.tensor([6, 1], dtype=torch.long),
]
- ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='complete')
+ ds = self._build_dataset(
+ data, block_size=3, pad=0, eos=1, break_mode="complete"
+ )
self.assertEqual(ds[0].tolist(), [4, 3, 2, 1])
self.assertEqual(ds[1].tolist(), [5, 1, 1])
self.assertEqual(ds[2].tolist(), [6, 1])
diff --git a/tests/test_train.py b/tests/test_train.py
index 048acaca54..1b7e027c0c 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -10,17 +10,16 @@
from unittest.mock import MagicMock, patch
import torch
-
-from fairseq import data, checkpoint_utils
+from fairseq import checkpoint_utils, data
def mock_trainer(epoch, num_updates, iterations_in_epoch):
trainer = MagicMock()
trainer.load_checkpoint.return_value = {
- 'train_iterator': {
- 'epoch': epoch,
- 'iterations_in_epoch': iterations_in_epoch,
- 'shuffle': False,
+ "train_iterator": {
+ "epoch": epoch,
+ "iterations_in_epoch": iterations_in_epoch,
+ "shuffle": False,
},
}
trainer.get_num_updates.return_value = num_updates
@@ -38,10 +37,17 @@ def mock_dict():
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
tokens_ds = data.TokenBlockDataset(
- tokens, sizes=[tokens.size(-1)], block_size=1, pad=0, eos=1, include_targets=False,
+ tokens,
+ sizes=[tokens.size(-1)],
+ block_size=1,
+ pad=0,
+ eos=1,
+ include_targets=False,
)
trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
- dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
+ dataset = data.LanguagePairDataset(
+ tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False
+ )
epoch_itr = data.EpochBatchIterator(
dataset=dataset,
collate_fn=dataset.collater,
@@ -52,7 +58,7 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoc
def get_mock_args(finetune_from_model=None):
args_mock = MagicMock()
- args_mock.optimizer_overrides = '{}'
+ args_mock.optimizer_overrides = "{}"
args_mock.reset_dataloader = False
args_mock.reset_meters = False
args_mock.reset_optimizer = False
@@ -63,15 +69,14 @@ def get_mock_args(finetune_from_model=None):
class TestLoadCheckpoint(unittest.TestCase):
-
def setUp(self):
self.args_mock = get_mock_args()
self.patches = {
- 'os.makedirs': MagicMock(),
- 'os.path.join': MagicMock(),
- 'os.path.isfile': MagicMock(return_value=True),
- 'os.path.isabs': MagicMock(return_value=False),
- 'fairseq.file_io.PathManager.exists': MagicMock(return_value=False),
+ "os.makedirs": MagicMock(),
+ "os.path.join": MagicMock(),
+ "os.path.isfile": MagicMock(return_value=True),
+ "os.path.isabs": MagicMock(return_value=False),
+ "fairseq.file_io.PathManager.exists": MagicMock(return_value=False),
}
self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
[p.start() for p in self.applied_patches]
@@ -95,7 +100,7 @@ def test_load_partial_checkpoint(self):
self.assertEqual(epoch_itr.epoch, 2)
self.assertEqual(epoch_itr.iterations_in_epoch, 50)
- self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
+ self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 50)
self.assertEqual(epoch_itr.iterations_in_epoch, 51)
for _ in range(150 - 52):
@@ -120,27 +125,32 @@ def test_load_full_checkpoint(self):
self.assertEqual(epoch_itr.epoch, 3)
self.assertEqual(epoch_itr.iterations_in_epoch, 0)
- self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)
+ self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 0)
def test_load_no_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
- self.patches['os.path.isfile'].return_value = False
+ self.patches["os.path.isfile"].return_value = False
_, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 1)
self.assertEqual(epoch_itr.iterations_in_epoch, 0)
- self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)
+ self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 0)
def test_finetune_from_model_args_conflict(self):
with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
- for arg in ['reset_optimizer', 'reset_lr_scheduler', 'reset_meters', 'reset_dataloader']:
+ for arg in [
+ "reset_optimizer",
+ "reset_lr_scheduler",
+ "reset_meters",
+ "reset_dataloader",
+ ]:
with self.subTest(arg=arg):
args_mock = get_mock_args("/temp/checkpoint_pretrained.pt")
setattr(args_mock, arg, True)
@@ -149,7 +159,8 @@ def test_finetune_from_model_args_conflict(self):
self.assertTrue(
"--finetune-from-model can not be set together with either --reset-optimizer"
- " or reset_lr_scheduler or reset_meters or reset_dataloader" in str(context.exception)
+ " or reset_lr_scheduler or reset_meters or reset_dataloader"
+ in str(context.exception)
)
def test_finetune_from_model(self):
@@ -165,11 +176,18 @@ def mock_finetune_exist(path):
return True
else:
return False
- self.patches['fairseq.file_io.PathManager.exists'].side_effect = mock_finetune_exist
+
+ self.patches[
+ "fairseq.file_io.PathManager.exists"
+ ].side_effect = mock_finetune_exist
_, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
- checkpoint_path, reset_optimizer, reset_lr_scheduler, \
- optimizer_overrides = trainer.load_checkpoint.call_args[0]
- reset_meters = trainer.load_checkpoint.call_args[1]['reset_meters']
+ (
+ checkpoint_path,
+ reset_optimizer,
+ reset_lr_scheduler,
+ optimizer_overrides,
+ ) = trainer.load_checkpoint.call_args[0]
+ reset_meters = trainer.load_checkpoint.call_args[1]["reset_meters"]
self.assertTrue(reset_optimizer)
self.assertTrue(reset_lr_scheduler)
self.assertTrue(reset_meters)
@@ -185,19 +203,26 @@ def test_finetune_from_model_resume(self):
# launch second time
# both restore_file=checkpoint_last.pt and finetune_from_model are set
def mock_finetune_exist(path):
- if path == from_model_path or path.endsWith('checkpoint_last.pt'):
+ if path == from_model_path or path.endsWith("checkpoint_last.pt"):
return True
else:
return False
- self.patches['fairseq.file_io.PathManager.exists'].side_effect = mock_finetune_exist
+
+ self.patches[
+ "fairseq.file_io.PathManager.exists"
+ ].side_effect = mock_finetune_exist
_, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
- checkpoint_path, reset_optimizer, reset_lr_scheduler, \
- optimizer_overrides = trainer.load_checkpoint.call_args[0]
- reset_meters = trainer.load_checkpoint.call_args[1]['reset_meters']
+ (
+ checkpoint_path,
+ reset_optimizer,
+ reset_lr_scheduler,
+ optimizer_overrides,
+ ) = trainer.load_checkpoint.call_args[0]
+ reset_meters = trainer.load_checkpoint.call_args[1]["reset_meters"]
self.assertFalse(reset_optimizer)
self.assertFalse(reset_lr_scheduler)
self.assertFalse(reset_meters)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 35fb115dda..79195903e0 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -6,24 +6,26 @@
import unittest
import torch
-
from fairseq import utils
class TestUtils(unittest.TestCase):
-
def test_convert_padding_direction(self):
pad = 1
- left_pad = torch.LongTensor([
- [2, 3, 4, 5, 6],
- [1, 7, 8, 9, 10],
- [1, 1, 1, 11, 12],
- ])
- right_pad = torch.LongTensor([
- [2, 3, 4, 5, 6],
- [7, 8, 9, 10, 1],
- [11, 12, 1, 1, 1],
- ])
+ left_pad = torch.LongTensor(
+ [
+ [2, 3, 4, 5, 6],
+ [1, 7, 8, 9, 10],
+ [1, 1, 1, 11, 12],
+ ]
+ )
+ right_pad = torch.LongTensor(
+ [
+ [2, 3, 4, 5, 6],
+ [7, 8, 9, 10, 1],
+ [11, 12, 1, 1, 1],
+ ]
+ )
self.assertAlmostEqual(
right_pad,
@@ -44,26 +46,34 @@ def test_convert_padding_direction(self):
def test_make_positions(self):
pad = 1
- left_pad_input = torch.LongTensor([
- [9, 9, 9, 9, 9],
- [1, 9, 9, 9, 9],
- [1, 1, 1, 9, 9],
- ])
- left_pad_output = torch.LongTensor([
- [2, 3, 4, 5, 6],
- [1, 2, 3, 4, 5],
- [1, 1, 1, 2, 3],
- ])
- right_pad_input = torch.LongTensor([
- [9, 9, 9, 9, 9],
- [9, 9, 9, 9, 1],
- [9, 9, 1, 1, 1],
- ])
- right_pad_output = torch.LongTensor([
- [2, 3, 4, 5, 6],
- [2, 3, 4, 5, 1],
- [2, 3, 1, 1, 1],
- ])
+ left_pad_input = torch.LongTensor(
+ [
+ [9, 9, 9, 9, 9],
+ [1, 9, 9, 9, 9],
+ [1, 1, 1, 9, 9],
+ ]
+ )
+ left_pad_output = torch.LongTensor(
+ [
+ [2, 3, 4, 5, 6],
+ [1, 2, 3, 4, 5],
+ [1, 1, 1, 2, 3],
+ ]
+ )
+ right_pad_input = torch.LongTensor(
+ [
+ [9, 9, 9, 9, 9],
+ [9, 9, 9, 9, 1],
+ [9, 9, 1, 1, 1],
+ ]
+ )
+ right_pad_output = torch.LongTensor(
+ [
+ [2, 3, 4, 5, 6],
+ [2, 3, 4, 5, 1],
+ [2, 3, 1, 1, 1],
+ ]
+ )
self.assertAlmostEqual(
left_pad_output,
@@ -82,9 +92,9 @@ def test_clip_grad_norm_(self):
params = [torch.nn.Parameter(torch.zeros(5)) for i in range(3)]
for p in params:
- p.grad = torch.full((5,), fill_value=2.)
+ p.grad = torch.full((5,), fill_value=2.0)
grad_norm = utils.clip_grad_norm_(params, 1.0)
- exp_grad_norm = torch.full((15,), fill_value=2.).norm()
+ exp_grad_norm = torch.full((15,), fill_value=2.0).norm()
self.assertTrue(torch.is_tensor(grad_norm))
self.assertEqual(grad_norm, exp_grad_norm)
@@ -100,5 +110,5 @@ def assertAlmostEqual(self, t1, t2):
self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/utils.py b/tests/utils.py
index 44a35fdccf..91feca6b2a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -7,10 +7,10 @@
import os
import random
import sys
+from io import StringIO
+
import torch
import torch.nn.functional as F
-
-from io import StringIO
from fairseq import options, utils
from fairseq.data import Dictionary
from fairseq.data.language_pair_dataset import collate
@@ -20,18 +20,11 @@
FairseqIncrementalDecoder,
)
from fairseq.models.fairseq_encoder import EncoderOut
-from fairseq.tasks import LegacyFairseqTask
-from fairseq.tasks import FairseqTask
-from fairseq_cli import (
- generate,
- interactive,
- preprocess,
- train,
- validate,
-)
+from fairseq.tasks import FairseqTask, LegacyFairseqTask
+from fairseq_cli import generate, interactive, preprocess, train, validate
-def dummy_dictionary(vocab_size, prefix='token_'):
+def dummy_dictionary(vocab_size, prefix="token_"):
d = Dictionary()
for i in range(vocab_size):
token = prefix + str(i)
@@ -51,8 +44,8 @@ def dummy_dataloader(
# add any missing data to samples
for i, sample in enumerate(samples):
- if 'id' not in sample:
- sample['id'] = i
+ if "id" not in sample:
+ sample["id"] = i
# create dataloader
dataset = TestDataset(samples)
@@ -77,48 +70,86 @@ def sequence_generator_setup():
src_lengths = torch.LongTensor([2, 2])
args = argparse.Namespace()
- unk = 0.
+ unk = 0.0
args.beam_probs = [
# step 0:
- torch.FloatTensor([
- # eos w1 w2
- # sentence 1:
- [0.0, unk, 0.9, 0.1], # beam 1
- [0.0, unk, 0.9, 0.1], # beam 2
- # sentence 2:
- [0.0, unk, 0.7, 0.3],
- [0.0, unk, 0.7, 0.3],
- ]),
+ torch.FloatTensor(
+ [
+ # eos w1 w2
+ # sentence 1:
+ [0.0, unk, 0.9, 0.1], # beam 1
+ [0.0, unk, 0.9, 0.1], # beam 2
+ # sentence 2:
+ [0.0, unk, 0.7, 0.3],
+ [0.0, unk, 0.7, 0.3],
+ ]
+ ),
# step 1:
- torch.FloatTensor([
- # eos w1 w2 prefix
- # sentence 1:
- [1.0, unk, 0.0, 0.0], # w1: 0.9 (emit: w1 : 0.9*1.0)
- [0.0, unk, 0.9, 0.1], # w2: 0.1
- # sentence 2:
- [0.25, unk, 0.35, 0.4], # w1: 0.7 (don't emit: w1 : 0.7*0.25)
- [0.00, unk, 0.10, 0.9], # w2: 0.3
- ]),
+ torch.FloatTensor(
+ [
+ # eos w1 w2 prefix
+ # sentence 1:
+ [1.0, unk, 0.0, 0.0], # w1: 0.9 (emit: w1 : 0.9*1.0)
+ [0.0, unk, 0.9, 0.1], # w2: 0.1
+ # sentence 2:
+ [0.25, unk, 0.35, 0.4], # w1: 0.7 (don't emit: w1 : 0.7*0.25)
+ [0.00, unk, 0.10, 0.9], # w2: 0.3
+ ]
+ ),
# step 2:
- torch.FloatTensor([
- # eos w1 w2 prefix
- # sentence 1:
- [0.0, unk, 0.1, 0.9], # w2 w1: 0.1*0.9
- [0.6, unk, 0.2, 0.2], # w2 w2: 0.1*0.1 (emit: w2 w2 : 0.1*0.1*0.6)
- # sentence 2:
- [0.60, unk, 0.4, 0.00], # w1 w2: 0.7*0.4 (emit: w1 w2 : 0.7*0.4*0.6)
- [0.01, unk, 0.0, 0.99], # w2 w2: 0.3*0.9
- ]),
+ torch.FloatTensor(
+ [
+ # eos w1 w2 prefix
+ # sentence 1:
+ [0.0, unk, 0.1, 0.9], # w2 w1: 0.1*0.9
+ [
+ 0.6,
+ unk,
+ 0.2,
+ 0.2,
+ ], # w2 w2: 0.1*0.1 (emit: w2 w2 : 0.1*0.1*0.6)
+ # sentence 2:
+ [
+ 0.60,
+ unk,
+ 0.4,
+ 0.00,
+ ], # w1 w2: 0.7*0.4 (emit: w1 w2 : 0.7*0.4*0.6)
+ [0.01, unk, 0.0, 0.99], # w2 w2: 0.3*0.9
+ ]
+ ),
# step 3:
- torch.FloatTensor([
- # eos w1 w2 prefix
- # sentence 1:
- [1.0, unk, 0.0, 0.0], # w2 w1 w2: 0.1*0.9*0.9 (emit: w2 w1 w2 : 0.1*0.9*0.9*1.0)
- [1.0, unk, 0.0, 0.0], # w2 w1 w1: 0.1*0.9*0.1 (emit: w2 w1 w1 : 0.1*0.9*0.1*1.0)
- # sentence 2:
- [0.1, unk, 0.5, 0.4], # w2 w2 w2: 0.3*0.9*0.99 (emit: w2 w2 w2 : 0.3*0.9*0.99*0.1)
- [1.0, unk, 0.0, 0.0], # w1 w2 w1: 0.7*0.4*0.4 (emit: w1 w2 w1 : 0.7*0.4*0.4*1.0)
- ]),
+ torch.FloatTensor(
+ [
+ # eos w1 w2 prefix
+ # sentence 1:
+ [
+ 1.0,
+ unk,
+ 0.0,
+ 0.0,
+ ], # w2 w1 w2: 0.1*0.9*0.9 (emit: w2 w1 w2 : 0.1*0.9*0.9*1.0)
+ [
+ 1.0,
+ unk,
+ 0.0,
+ 0.0,
+ ], # w2 w1 w1: 0.1*0.9*0.1 (emit: w2 w1 w1 : 0.1*0.9*0.1*1.0)
+ # sentence 2:
+ [
+ 0.1,
+ unk,
+ 0.5,
+ 0.4,
+ ], # w2 w2 w2: 0.3*0.9*0.99 (emit: w2 w2 w2 : 0.3*0.9*0.99*0.1)
+ [
+ 1.0,
+ unk,
+ 0.0,
+ 0.0,
+ ], # w1 w2 w1: 0.7*0.4*0.4 (emit: w1 w2 w1 : 0.7*0.4*0.4*1.0)
+ ]
+ ),
]
task = TestTranslationTask.setup_task(args, d, d)
@@ -132,18 +163,18 @@ def create_dummy_data(data_dir, num_examples=100, maxlen=20, alignment=False):
def _create_dummy_data(filename):
data = torch.rand(num_examples * maxlen)
data = 97 + torch.floor(26 * data).int()
- with open(os.path.join(data_dir, filename), 'w') as h:
+ with open(os.path.join(data_dir, filename), "w") as h:
offset = 0
for _ in range(num_examples):
ex_len = random.randint(1, maxlen)
- ex_str = ' '.join(map(chr, data[offset:offset+ex_len]))
+ ex_str = " ".join(map(chr, data[offset : offset + ex_len]))
print(ex_str, file=h)
offset += ex_len
def _create_dummy_alignment_data(filename_src, filename_tgt, filename):
- with open(os.path.join(data_dir, filename_src), 'r') as src_f, \
- open(os.path.join(data_dir, filename_tgt), 'r') as tgt_f, \
- open(os.path.join(data_dir, filename), 'w') as h:
+ with open(os.path.join(data_dir, filename_src), "r") as src_f, open(
+ os.path.join(data_dir, filename_tgt), "r"
+ ) as tgt_f, open(os.path.join(data_dir, filename), "w") as h:
for src, tgt in zip(src_f, tgt_f):
src_len = len(src.split())
tgt_len = len(tgt.split())
@@ -151,31 +182,42 @@ def _create_dummy_alignment_data(filename_src, filename_tgt, filename):
num_alignments = random.randint(avg_len // 2, 2 * avg_len)
src_indices = torch.floor(torch.rand(num_alignments) * src_len).int()
tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int()
- ex_str = ' '.join(["{}-{}".format(src, tgt) for src, tgt in zip(src_indices, tgt_indices)])
+ ex_str = " ".join(
+ [
+ "{}-{}".format(src, tgt)
+ for src, tgt in zip(src_indices, tgt_indices)
+ ]
+ )
print(ex_str, file=h)
- _create_dummy_data('train.in')
- _create_dummy_data('train.out')
- _create_dummy_data('valid.in')
- _create_dummy_data('valid.out')
- _create_dummy_data('test.in')
- _create_dummy_data('test.out')
+ _create_dummy_data("train.in")
+ _create_dummy_data("train.out")
+ _create_dummy_data("valid.in")
+ _create_dummy_data("valid.out")
+ _create_dummy_data("test.in")
+ _create_dummy_data("test.out")
if alignment:
- _create_dummy_alignment_data('train.in', 'train.out', 'train.align')
- _create_dummy_alignment_data('valid.in', 'valid.out', 'valid.align')
- _create_dummy_alignment_data('test.in', 'test.out', 'test.align')
+ _create_dummy_alignment_data("train.in", "train.out", "train.align")
+ _create_dummy_alignment_data("valid.in", "valid.out", "valid.align")
+ _create_dummy_alignment_data("test.in", "test.out", "test.align")
def preprocess_lm_data(data_dir):
preprocess_parser = options.get_preprocessing_parser()
- preprocess_args = preprocess_parser.parse_args([
- '--only-source',
- '--trainpref', os.path.join(data_dir, 'train.out'),
- '--validpref', os.path.join(data_dir, 'valid.out'),
- '--testpref', os.path.join(data_dir, 'test.out'),
- '--destdir', data_dir,
- ])
+ preprocess_args = preprocess_parser.parse_args(
+ [
+ "--only-source",
+ "--trainpref",
+ os.path.join(data_dir, "train.out"),
+ "--validpref",
+ os.path.join(data_dir, "valid.out"),
+ "--testpref",
+ os.path.join(data_dir, "test.out"),
+ "--destdir",
+ data_dir,
+ ]
+ )
preprocess.main(preprocess_args)
@@ -183,15 +225,24 @@ def preprocess_translation_data(data_dir, extra_flags=None):
preprocess_parser = options.get_preprocessing_parser()
preprocess_args = preprocess_parser.parse_args(
[
- '--source-lang', 'in',
- '--target-lang', 'out',
- '--trainpref', os.path.join(data_dir, 'train'),
- '--validpref', os.path.join(data_dir, 'valid'),
- '--testpref', os.path.join(data_dir, 'test'),
- '--thresholdtgt', '0',
- '--thresholdsrc', '0',
- '--destdir', data_dir,
- ] + (extra_flags or []),
+ "--source-lang",
+ "in",
+ "--target-lang",
+ "out",
+ "--trainpref",
+ os.path.join(data_dir, "train"),
+ "--validpref",
+ os.path.join(data_dir, "valid"),
+ "--testpref",
+ os.path.join(data_dir, "test"),
+ "--thresholdtgt",
+ "0",
+ "--thresholdsrc",
+ "0",
+ "--destdir",
+ data_dir,
+ ]
+ + (extra_flags or []),
)
preprocess.main(preprocess_args)
@@ -200,43 +251,72 @@ def preprocess_summarization_data(data_dir, extra_flags=None):
preprocess_parser = options.get_preprocessing_parser()
preprocess_args = preprocess_parser.parse_args(
[
- '--source-lang', 'in',
- '--target-lang', 'out',
- '--trainpref', os.path.join(data_dir, 'train'),
- '--validpref', os.path.join(data_dir, 'valid'),
- '--testpref', os.path.join(data_dir, 'test'),
- '--thresholdtgt', '0',
- '--thresholdsrc', '0',
- '--joined-dictionary',
- '--destdir', data_dir,
- ] + (extra_flags or []),
+ "--source-lang",
+ "in",
+ "--target-lang",
+ "out",
+ "--trainpref",
+ os.path.join(data_dir, "train"),
+ "--validpref",
+ os.path.join(data_dir, "valid"),
+ "--testpref",
+ os.path.join(data_dir, "test"),
+ "--thresholdtgt",
+ "0",
+ "--thresholdsrc",
+ "0",
+ "--joined-dictionary",
+ "--destdir",
+ data_dir,
+ ]
+ + (extra_flags or []),
)
preprocess.main(preprocess_args)
-def train_translation_model(data_dir, arch, extra_flags=None, task='translation', run_validation=False,
- lang_flags=None, extra_valid_flags=None):
+def train_translation_model(
+ data_dir,
+ arch,
+ extra_flags=None,
+ task="translation",
+ run_validation=False,
+ lang_flags=None,
+ extra_valid_flags=None,
+):
if lang_flags is None:
lang_flags = [
- '--source-lang', 'in',
- '--target-lang', 'out',
+ "--source-lang",
+ "in",
+ "--target-lang",
+ "out",
]
train_parser = options.get_training_parser()
train_args = options.parse_args_and_arch(
train_parser,
[
- '--task', task,
+ "--task",
+ task,
data_dir,
- '--save-dir', data_dir,
- '--arch', arch,
- '--optimizer', 'nag',
- '--lr', '0.05',
- '--max-tokens', '500',
- '--max-epoch', '1',
- '--no-progress-bar',
- '--distributed-world-size', '1',
- '--num-workers', '0',
- ] + lang_flags + (extra_flags or []),
+ "--save-dir",
+ data_dir,
+ "--arch",
+ arch,
+ "--optimizer",
+ "nag",
+ "--lr",
+ "0.05",
+ "--max-tokens",
+ "500",
+ "--max-epoch",
+ "1",
+ "--no-progress-bar",
+ "--distributed-world-size",
+ "1",
+ "--num-workers",
+ "0",
+ ]
+ + lang_flags
+ + (extra_flags or []),
)
train.main(train_args)
@@ -246,14 +326,21 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation'
validate_args = options.parse_args_and_arch(
validate_parser,
[
- '--task', task,
+ "--task",
+ task,
data_dir,
- '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
- '--valid-subset', 'valid',
- '--max-tokens', '500',
- '--no-progress-bar',
- '--num-workers', '0',
- ] + lang_flags + (extra_valid_flags or [])
+ "--path",
+ os.path.join(data_dir, "checkpoint_last.pt"),
+ "--valid-subset",
+ "valid",
+ "--max-tokens",
+ "500",
+ "--no-progress-bar",
+ "--num-workers",
+ "0",
+ ]
+ + lang_flags
+ + (extra_valid_flags or []),
)
validate.main(validate_args)
@@ -261,21 +348,28 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation'
def generate_main(data_dir, extra_flags=None):
if extra_flags is None:
extra_flags = [
- '--print-alignment',
+ "--print-alignment",
]
generate_parser = options.get_generation_parser()
generate_args = options.parse_args_and_arch(
generate_parser,
[
data_dir,
- '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
- '--beam', '3',
- '--batch-size', '64',
- '--max-len-b', '5',
- '--gen-subset', 'valid',
- '--no-progress-bar',
- '--num-workers', '0',
- ] + (extra_flags or []),
+ "--path",
+ os.path.join(data_dir, "checkpoint_last.pt"),
+ "--beam",
+ "3",
+ "--batch-size",
+ "64",
+ "--max-len-b",
+ "5",
+ "--gen-subset",
+ "valid",
+ "--no-progress-bar",
+ "--num-workers",
+ "0",
+ ]
+ + (extra_flags or []),
)
# evaluate model in batch mode
@@ -283,16 +377,15 @@ def generate_main(data_dir, extra_flags=None):
# evaluate model interactively
generate_args.buffer_size = 0
- generate_args.input = '-'
+ generate_args.input = "-"
generate_args.batch_size = None
orig_stdin = sys.stdin
- sys.stdin = StringIO('h e l l o\n')
+ sys.stdin = StringIO("h e l l o\n")
interactive.main(generate_args)
sys.stdin = orig_stdin
class TestDataset(torch.utils.data.Dataset):
-
def __init__(self, data):
super().__init__()
self.data = data
@@ -306,7 +399,6 @@ def __len__(self):
class TestTranslationTask(LegacyFairseqTask):
-
def __init__(self, args, src_dict, tgt_dict, model):
super().__init__(args)
self.src_dict = src_dict
@@ -369,8 +461,8 @@ def reorder_encoder_out(self, encoder_out, new_order):
class TestIncrementalDecoder(FairseqIncrementalDecoder):
def __init__(self, args, dictionary):
super().__init__(dictionary)
- assert hasattr(args, 'beam_probs') or hasattr(args, 'probs')
- args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100)
+ assert hasattr(args, "beam_probs") or hasattr(args, "probs")
+ args.max_decoder_positions = getattr(args, "max_decoder_positions", 100)
self.args = args
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
@@ -384,18 +476,19 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
# determine number of steps
if incremental_state is not None:
# cache step number
- step = utils.get_incremental_state(self, incremental_state, 'step')
+ step = utils.get_incremental_state(self, incremental_state, "step")
if step is None:
step = 0
- utils.set_incremental_state(self, incremental_state, 'step', step + 1)
+ utils.set_incremental_state(self, incremental_state, "step", step + 1)
steps = [step]
else:
steps = list(range(tgt_len))
# define output in terms of raw probs
- if hasattr(self.args, 'probs'):
- assert self.args.probs.dim() == 3, \
- 'expected probs to have size bsz*steps*vocab'
+ if hasattr(self.args, "probs"):
+ assert (
+ self.args.probs.dim() == 3
+ ), "expected probs to have size bsz*steps*vocab"
probs = self.args.probs.index_select(1, torch.LongTensor(steps))
else:
probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
@@ -403,7 +496,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
# args.beam_probs gives the probability for every vocab element,
# starting with eos, then unknown, and then the rest of the vocab
if step < len(self.args.beam_probs):
- probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
+ probs[:, i, self.dictionary.eos() :] = self.args.beam_probs[step]
else:
probs[:, i, self.dictionary.eos()] = 1.0
@@ -475,8 +568,8 @@ def __init__(self, args, dictionary):
self.args = args
def forward(self, src_tokens, src_lengths=None, **kwargs):
- assert 'fancy_other_input' in kwargs
- assert kwargs['fancy_other_input'] is not None
+ assert "fancy_other_input" in kwargs
+ assert kwargs["fancy_other_input"] is not None
return EncoderOut(
encoder_out=src_tokens,
encoder_padding_mask=None,
@@ -508,8 +601,8 @@ def build_model(cls, args, task):
return cls(encoder, decoder)
def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
- encoder_out = self.encoder(
- src_tokens, src_lengths=src_lengths, **kwargs)
+ encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(
- prev_output_tokens, encoder_out=encoder_out, **kwargs)
+ prev_output_tokens, encoder_out=encoder_out, **kwargs
+ )
return decoder_out
diff --git a/train.py b/train.py
index 3967ef48f3..321de3d9b5 100644
--- a/train.py
+++ b/train.py
@@ -10,5 +10,5 @@
from fairseq_cli.train import cli_main
-if __name__ == '__main__':
+if __name__ == "__main__":
cli_main()