-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
39 changed files
with
8,624 additions
and
34 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .dictionary import Dictionary | ||
from .fairseq_dataset import FairseqDataset | ||
|
||
__all__ = [ | ||
"Dictionary", | ||
"FairseqDataset", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import importlib | ||
import os | ||
|
||
from fairseq import registry | ||
|
||
|
||
# Tokenizer registry, set up for the "--tokenizer" name with no default.
# The fourth value returned by setup_registry is not needed and is bound to "_".
(
    build_tokenizer,
    register_tokenizer,
    TOKENIZER_REGISTRY,
    _,
) = registry.setup_registry("--tokenizer", default=None)


# BPE registry, set up the same way for the "--bpe" name.
(
    build_bpe,
    register_bpe,
    BPE_REGISTRY,
    _,
) = registry.setup_registry("--bpe", default=None)
|
||
|
||
# Automatically import every Python file in the encoders/ directory whose
# name does not start with an underscore. Files are visited in sorted order
# so the import order is deterministic.
# NOTE(review): presumably this lets each module register itself with the
# registries defined in this file - confirm against the encoder modules.
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        # Strip only the trailing ".py". Cutting at the first ".py" found in
        # the name would truncate a file such as "x.pyfoo.py" down to "x".
        module = file[: -len(".py")]
        importlib.import_module("fairseq.data.encoders." + module)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import logging | ||
import numpy as np | ||
import torch.utils.data | ||
from fairseq.data import data_utils | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class EpochListening:
    """Mixin for objects that want to be told when the epoch increments."""

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """Tell whether one :class:`fairseq.data.EpochBatchIterator` can be
        reused for this dataset across epochs.

        Return ``False`` when sample sizes can change across epochs, in which
        case batches may need to be regenerated at each epoch. A dataset that
        relies on ``set_epoch`` should consider returning ``False``.

        The default is ``True``.
        """
        return True

    def set_epoch(self, epoch):
        """Receive the updated epoch number at the beginning of the epoch.

        The default implementation ignores *epoch* and does nothing.
        """
        return None
|
||
|
||
class FairseqDataset(torch.utils.data.Dataset, EpochListening):
    """Dataset base class that adds helpers for batching.

    The accessors that raise ``NotImplementedError`` below are left for
    subclasses to implement.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Combine a list of samples into one mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return how many tokens the sample at *index* contains.

        The value is used to enforce ``--max-tokens`` during batching.
        """
        raise NotImplementedError

    def num_tokens_vec(self, indices):
        """Return the token counts for the positions given by *indices*.

        The value is used to enforce ``--max-tokens`` during batching.
        """
        raise NotImplementedError

    def size(self, index):
        """Return the size of one example as a float or tuple.

        The value is used when filtering a dataset with ``--max-positions``.
        """
        raise NotImplementedError

    def ordered_indices(self):
        """Return the indices in the order batches should be built from.

        The default is ``0 .. len(self) - 1`` as an ``int64`` array.
        """
        count = len(self)
        return np.arange(count, dtype=np.int64)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching (``False`` by default)."""
        return False

    def attr(self, attr: str, index: int):
        """Look up the attribute named *attr* on this dataset.

        Returns ``None`` when the attribute does not exist. *index* is not
        used by this default implementation.
        """
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError

    def get_batch_shapes(self):
        """Return a list of valid batch shapes, or ``None`` for no restriction.

        Example::

            [(8, 512), (16, 256), (32, 128)]

        The first entry of each tuple is the batch size; it may be ``None``
        to have the max batch size inferred from ``--max-tokens``. The second
        entry is the max supported length as given by
        :func:`fairseq.data.FairseqDataset.num_tokens`.

        :func:`fairseq.data.FairseqDataset.batch_by_size` uses these shapes to
        restrict batch shapes, which is useful on TPUs to avoid too many
        dynamic shapes (and recompilations).
        """
        return None

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """Group an ordered set of indices into batches.

        Batches are formed according to *max_tokens*, *max_sentences* and
        *required_batch_size_multiple*; the grouping itself is delegated to
        ``data_utils.batch_by_size``.
        """
        from fairseq.data import data_utils

        shapes = self.get_batch_shapes()
        if shapes is not None:

            def resolve_bsz(bsz, length):
                # A missing batch size is derived from the token budget.
                if bsz is None:
                    assert max_tokens is not None, "Must specify --max-tokens"
                    bsz = max_tokens // length
                # With a sentence cap, the cap alone limits the batch size;
                # rounding to the required multiple applies only without it.
                if max_sentences is not None:
                    return min(bsz, max_sentences)
                if (
                    bsz >= required_batch_size_multiple
                    and bsz % required_batch_size_multiple != 0
                ):
                    bsz -= bsz % required_batch_size_multiple
                return bsz

            shapes = np.array(
                [[resolve_bsz(bsz, length), length] for bsz, length in shapes]
            )

        # The vectorised token counts are optional: datasets that do not
        # implement num_tokens_vec pass None instead.
        try:
            token_counts = self.num_tokens_vec(indices).astype("int64")
        except NotImplementedError:
            token_counts = None

        return data_utils.batch_by_size(
            indices,
            num_tokens_fn=self.num_tokens,
            num_tokens_vec=token_counts,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            fixed_shapes=shapes,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """Drop the sample indices that are longer than *max_sizes* allows.

        WARNING: don't update, override method in child classes

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """

        def split_by_limit(sample_sizes):
            # Same comparisons as a direct mask on the size array: strictly
            # larger samples are removed, the rest are kept.
            removed = indices[sample_sizes > max_sizes].tolist()
            remaining = indices[sample_sizes <= max_sizes]
            return remaining, removed

        # Fast path: a scalar limit together with a precomputed size array,
        # stored either directly or as the only element of a list.
        if isinstance(max_sizes, (float, int)):
            sizes = getattr(self, "sizes", None)
            if isinstance(sizes, np.ndarray):
                return split_by_limit(sizes[indices])
            if isinstance(sizes, list) and len(sizes) == 1:
                return split_by_limit(sizes[0][indices])

        # Every other case is handed to the dynamic filter with self.size.
        remaining, removed = data_utils._filter_by_size_dynamic(
            indices, self.size, max_sizes
        )
        return remaining, removed

    @property
    def supports_fetch_outside_dataloader(self):
        """Whether this dataset supports fetching outside the workers of the
        dataloader (``True`` by default)."""
        return True
|
||
|
||
class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """Base class for datasets that have to be read sequentially.

    This is usually the case because the data is being streamed or otherwise
    can't be manipulated on a single machine.
    """

    def __iter__(self):
        raise NotImplementedError
Oops, something went wrong.