diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py new file mode 100644 index 0000000..e3f316b --- /dev/null +++ b/fairseq/checkpoint_utils.py @@ -0,0 +1,936 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import ast +import collections +import contextlib +import inspect +import logging +import os +import re +import time +import traceback +from collections import OrderedDict +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +from fairseq.data import data_utils +from fairseq.dataclass.configs import CheckpointConfig +from fairseq.dataclass.utils import ( + convert_namespace_to_omegaconf, + overwrite_args_by_name, +) +from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP +from fairseq.file_io import PathManager +from fairseq.models import FairseqDecoder, FairseqEncoder +from omegaconf import DictConfig, OmegaConf, open_dict + +logger = logging.getLogger(__name__) + + +def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): + from fairseq import meters + + # only one worker should attempt to create the required dir + if trainer.data_parallel_rank == 0: + os.makedirs(cfg.save_dir, exist_ok=True) + + prev_best = getattr(save_checkpoint, "best", val_loss) + if val_loss is not None: + best_function = max if cfg.maximize_best_checkpoint_metric else min + save_checkpoint.best = best_function(val_loss, prev_best) + + if cfg.no_save: + return None + + trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state + + if not trainer.should_save_checkpoint_on_current_rank: + if trainer.always_call_state_dict_during_save_checkpoint: + trainer.state_dict() + return None + + write_timer = meters.StopwatchMeter() + write_timer.start() + + epoch = epoch_itr.epoch + end_of_epoch = epoch_itr.end_of_epoch() + updates = trainer.get_num_updates() + + logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates") + + def is_better(a, b): + return a >= b if cfg.maximize_best_checkpoint_metric else a <= b + + suffix = trainer.checkpoint_suffix + checkpoint_conds = collections.OrderedDict() + checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( + end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0 + ) + checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = ( + not end_of_epoch + and cfg.save_interval_updates > 0 + and updates % cfg.save_interval_updates == 0 + ) + checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and ( + not hasattr(save_checkpoint, "best") + or is_better(val_loss, save_checkpoint.best) + ) + if val_loss is not None and cfg.keep_best_checkpoints > 0: + worst_best = getattr(save_checkpoint, "best", None) + chkpts = checkpoint_paths( + cfg.save_dir, + pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format( + cfg.best_checkpoint_metric, suffix + ), + ) + if len(chkpts) > 0: + p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0] + worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), "")) + # add random digits to resolve ties + with data_utils.numpy_seed(epoch, updates, val_loss): + rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints) + + checkpoint_conds[ + "checkpoint.best_{}_{:.3f}{}{}.pt".format( + cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix + ) + ] = worst_best is None or is_better(val_loss, worst_best) + checkpoint_conds[ + "checkpoint_last{}.pt".format(suffix) + ] = not cfg.no_last_checkpoints + + extra_state = { + "train_iterator": epoch_itr.state_dict(), + "val_loss": val_loss, + } + + # Going forward, different tasks could expose an API like this to dump all + # the checkpoint worthy attributes in a dictionary which then will be + # merged with the parent dictionary to create the "extra_state". This + # allows for an extensible yet simple design to checkpoint task level + # attributes + if hasattr(trainer.task, "get_checkpoint_dict"): + extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()} + logger.info(f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint") + + if hasattr(save_checkpoint, "best"): + extra_state.update({"best": save_checkpoint.best}) + + checkpoints = [ + os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond + ] + saved_cp = None + if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank: + saved_cp = trainer.save_checkpoint(checkpoints[0], extra_state) + for cp in checkpoints[1:]: + if cfg.write_checkpoints_asynchronously: + # TODO[ioPath]: Need to implement a delayed asynchronous + # file copying/moving feature. + logger.warning( + f"ioPath is not copying {checkpoints[0]} to {cp} " + "since async write mode is on." + ) + else: + assert PathManager.copy( + checkpoints[0], cp, overwrite=True + ), f"Failed to copy {checkpoints[0]} to {cp}" + + write_timer.stop() + logger.info( + "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format( + checkpoints[0], epoch, updates, val_loss, write_timer.sum + ) + ) + + if ( + not end_of_epoch + and cfg.keep_interval_updates > 0 + and trainer.should_save_checkpoint_on_current_rank + ): + # remove old checkpoints; checkpoints are sorted in descending order + if cfg.keep_interval_updates_pattern == -1: + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix) + ) + else: + checkpoints = checkpoint_paths( + cfg.save_dir, + pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix), + keep_match=True, + ) + checkpoints = [ + x[0] + for x in checkpoints + if x[1] % cfg.keep_interval_updates_pattern != 0 + ] + + for old_chk in checkpoints[cfg.keep_interval_updates :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) + + if cfg.keep_last_epochs > 0 and trainer.should_save_checkpoint_on_current_rank: + # remove old epoch checkpoints; checkpoints are sorted in descending order + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix) + ) + for old_chk in checkpoints[cfg.keep_last_epochs :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) + + if cfg.keep_best_checkpoints > 0 and trainer.should_save_checkpoint_on_current_rank: + # only keep the best N checkpoints according to validation metric + checkpoints = checkpoint_paths( + cfg.save_dir, + pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format( + cfg.best_checkpoint_metric, suffix + ), + ) + if not cfg.maximize_best_checkpoint_metric: + checkpoints = checkpoints[::-1] + for old_chk in checkpoints[cfg.keep_best_checkpoints :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) + + return saved_cp + + +def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): + """ + Load a checkpoint and restore the training iterator. + + *passthrough_args* will be passed through to + ``trainer.get_train_iterator``. + """ + + reset_optimizer = cfg.reset_optimizer + reset_lr_scheduler = cfg.reset_lr_scheduler + optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides) + reset_meters = cfg.reset_meters + reset_dataloader = cfg.reset_dataloader + + if cfg.finetune_from_model is not None and ( + reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader + ): + raise ValueError( + "--finetune-from-model can not be set together with either --reset-optimizer" + " or reset_lr_scheduler or reset_meters or reset_dataloader" + ) + + suffix = trainer.checkpoint_suffix + if ( + cfg.restore_file == "checkpoint_last.pt" + ): # default value of restore_file is 'checkpoint_last.pt' + checkpoint_path = os.path.join( + cfg.save_dir, "checkpoint_last{}.pt".format(suffix) + ) + first_launch = not PathManager.exists(checkpoint_path) + if first_launch and getattr(cfg, "continue_once", None) is not None: + checkpoint_path = cfg.continue_once + elif cfg.finetune_from_model is not None and first_launch: + # if there is no last checkpoint to restore, start the finetune from pretrained model + # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. + if PathManager.exists(cfg.finetune_from_model): + checkpoint_path = cfg.finetune_from_model + reset_optimizer = True + reset_lr_scheduler = True + reset_meters = True + reset_dataloader = True + logger.info( + f"loading pretrained model from {checkpoint_path}: " + "optimizer, lr scheduler, meters, dataloader will be reset" + ) + else: + raise ValueError( + f"--finetune-from-model {cfg.finetune_from_model} does not exist" + ) + elif suffix is not None: + checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt") + else: + checkpoint_path = cfg.restore_file + + if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model: + raise ValueError( + "--finetune-from-model and --restore-file (non-default value) " + "can not be specified together: " + str(cfg) + ) + + extra_state = trainer.load_checkpoint( + checkpoint_path, + reset_optimizer, + reset_lr_scheduler, + optimizer_overrides, + reset_meters=reset_meters, + ) + + if ( + extra_state is not None + and "best" in extra_state + and not reset_optimizer + and not reset_meters + ): + save_checkpoint.best = extra_state["best"] + + if extra_state is not None and not reset_dataloader: + # restore iterator from checkpoint + itr_state = extra_state["train_iterator"] + epoch_itr = trainer.get_train_iterator( + epoch=itr_state["epoch"], load_dataset=True, **passthrough_args + ) + epoch_itr.load_state_dict(itr_state) + + # Preload the checkpoint for the task + task_cp_dict = extra_state.get(trainer.task.__class__.__name__, {}) + if task_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"): + trainer.task.set_checkpoint_dict(task_cp_dict) + else: + epoch_itr = trainer.get_train_iterator( + epoch=1, load_dataset=True, **passthrough_args + ) + + trainer.lr_step(epoch_itr.epoch) + + return extra_state, epoch_itr + + +def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False): + """Loads a checkpoint to CPU (with upgrading for backward compatibility). + + If doing single-GPU training or if the checkpoint is only being loaded by at + most one process on each node (current default behavior is for only rank 0 + to read the checkpoint from disk), load_on_all_ranks should be False to + avoid errors from torch.distributed not having been initialized or + torch.distributed.barrier() hanging. + + If all processes on each node may be loading the checkpoint + simultaneously, load_on_all_ranks should be set to True to avoid I/O + conflicts. + + There's currently no support for > 1 but < all processes loading the + checkpoint on each node. + """ + local_path = PathManager.get_local_path(path) + # The locally cached file returned by get_local_path() may be stale for + # remote files that are periodically updated/overwritten (ex: + # checkpoint_last.pt) - so we remove the local copy, sync across processes + # (if needed), and then download a fresh copy. + if local_path != path and PathManager.path_requires_pathmanager(path): + try: + os.remove(local_path) + except FileNotFoundError: + # With potentially multiple processes removing the same file, the + # file being missing is benign (missing_ok isn't available until + # Python 3.8). + pass + if load_on_all_ranks: + torch.distributed.barrier() + local_path = PathManager.get_local_path(path) + + with open(local_path, "rb") as f: + state = torch.load(f, map_location=torch.device("cpu")) + + if "args" in state and state["args"] is not None and arg_overrides is not None: + args = state["args"] + for arg_name, arg_val in arg_overrides.items(): + setattr(args, arg_name, arg_val) + + if "cfg" in state and state["cfg"] is not None: + + # hack to be able to set Namespace in dict config. this should be removed when we update to newer + # omegaconf version that supports object flags, or when we migrate all existing models + from omegaconf import __version__ as oc_version + from omegaconf import _utils + + if oc_version < "2.2": + old_primitive = _utils.is_primitive_type + _utils.is_primitive_type = lambda _: True + + state["cfg"] = OmegaConf.create(state["cfg"]) + + _utils.is_primitive_type = old_primitive + OmegaConf.set_struct(state["cfg"], True) + else: + state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True}) + + if arg_overrides is not None: + overwrite_args_by_name(state["cfg"], arg_overrides) + + state = _upgrade_state_dict(state) + return state + + +def load_model_ensemble( + filenames, + arg_overrides: Optional[Dict[str, Any]] = None, + task=None, + strict=True, + suffix="", + num_shards=1, + state=None, +): + """Loads an ensemble of models. + + Args: + filenames (List[str]): checkpoint files to load + arg_overrides (Dict[str,Any], optional): override model args that + were used during model training + task (fairseq.tasks.FairseqTask, optional): task to use for loading + """ + assert not ( + strict and num_shards > 1 + ), "Cannot load state dict with strict=True and checkpoint shards > 1" + ensemble, args, _task = load_model_ensemble_and_task( + filenames, + arg_overrides, + task, + strict, + suffix, + num_shards, + state, + ) + return ensemble, args + + +def get_maybe_sharded_checkpoint_filename( + filename: str, suffix: str, shard_idx: int, num_shards: int +) -> str: + orig_filename = filename + filename = filename.replace(".pt", suffix + ".pt") + fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt" + model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt" + if PathManager.exists(fsdp_filename): + return fsdp_filename + elif num_shards > 1: + return model_parallel_filename + else: + return filename + + +def load_model_ensemble_and_task( + filenames, + arg_overrides: Optional[Dict[str, Any]] = None, + task=None, + strict=True, + suffix="", + num_shards=1, + state=None, +): + assert state is None or len(filenames) == 1 + + from fairseq import tasks + + assert not ( + strict and num_shards > 1 + ), "Cannot load state dict with strict=True and checkpoint shards > 1" + ensemble = [] + cfg = None + for filename in filenames: + orig_filename = filename + model_shard_state = {"shard_weights": [], "shard_metadata": []} + assert num_shards > 0 + st = time.time() + for shard_idx in range(num_shards): + filename = get_maybe_sharded_checkpoint_filename( + orig_filename, suffix, shard_idx, num_shards + ) + + if not PathManager.exists(filename): + raise IOError("Model file not found: {}".format(filename)) + if state is None: + state = load_checkpoint_to_cpu(filename, arg_overrides) + if "args" in state and state["args"] is not None: + cfg = convert_namespace_to_omegaconf(state["args"]) + elif "cfg" in state and state["cfg"] is not None: + cfg = state["cfg"] + else: + raise RuntimeError( + f"Neither args nor cfg exist in state keys = {state.keys()}" + ) + + if task is None: + task = tasks.setup_task(cfg.task, from_checkpoint=True) + + if "task_state" in state: + task.load_state_dict(state["task_state"]) + + argspec = inspect.getfullargspec(task.build_model) + + if "fsdp_metadata" in state and num_shards > 1: + model_shard_state["shard_weights"].append(state["model"]) + model_shard_state["shard_metadata"].append(state["fsdp_metadata"]) + # check FSDP import before the code goes too far + if not has_FSDP: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + if shard_idx == num_shards - 1: + consolidated_model_state = FSDP.consolidate_shard_weights( + shard_weights=model_shard_state["shard_weights"], + shard_metadata=model_shard_state["shard_metadata"], + ) + if "from_checkpoint" in argspec.args: + model = task.build_model(cfg.model, from_checkpoint=True) + else: + model = task.build_model(cfg.model) + if ( + "optimizer_history" in state + and len(state["optimizer_history"]) > 0 + and "num_updates" in state["optimizer_history"][-1] + ): + model.set_num_updates( + state["optimizer_history"][-1]["num_updates"] + ) + model.load_state_dict( + consolidated_model_state, strict=strict, model_cfg=cfg.model + ) + else: + # model parallel checkpoint or unsharded checkpoint + # support old external tasks + + if "from_checkpoint" in argspec.args: + model = task.build_model(cfg.model, from_checkpoint=True) + else: + model = task.build_model(cfg.model) + if ( + "optimizer_history" in state + and len(state["optimizer_history"]) > 0 + and "num_updates" in state["optimizer_history"][-1] + ): + model.set_num_updates(state["optimizer_history"][-1]["num_updates"]) + model.load_state_dict( + state["model"], strict=strict, model_cfg=cfg.model + ) + + # reset state so it gets loaded for the next model in ensemble + state = None + if shard_idx % 10 == 0 and shard_idx > 0: + elapsed = time.time() - st + logger.info( + f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard" + ) + + # build model for ensemble + ensemble.append(model) + return ensemble, cfg, task + + +def load_model_ensemble_and_task_from_hf_hub( + model_id, + cache_dir: Optional[str] = None, + arg_overrides: Optional[Dict[str, Any]] = None, + **kwargs: Any, +): + try: + from huggingface_hub import snapshot_download + except ImportError: + raise ImportError( + "You need to install huggingface_hub to use `load_from_hf_hub`. " + "See https://pypi.org/project/huggingface-hub/ for installation." + ) + + library_name = "fairseq" + cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix() + cache_dir = snapshot_download( + model_id, cache_dir=cache_dir, library_name=library_name, **kwargs + ) + + _arg_overrides = arg_overrides or {} + _arg_overrides["data"] = cache_dir + return load_model_ensemble_and_task( + [p.as_posix() for p in Path(cache_dir).glob("*.pt")], + arg_overrides=_arg_overrides, + ) + + +def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False): + """Retrieves all checkpoints found in `path` directory. + + Checkpoints are identified by matching filename to the specified pattern. If + the pattern contains groups, the result will be sorted by the first group in + descending order. + """ + pt_regexp = re.compile(pattern) + files = PathManager.ls(path) + + entries = [] + for i, f in enumerate(files): + m = pt_regexp.fullmatch(f) + if m is not None: + idx = float(m.group(1)) if len(m.groups()) > 0 else i + entries.append((idx, m.group(0))) + if keep_match: + return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)] + else: + return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] + + +def torch_persistent_save(obj, filename, async_write: bool = False): + if async_write: + with PathManager.opena(filename, "wb") as f: + _torch_persistent_save(obj, f) + else: + if PathManager.supports_rename(filename): + # do atomic save + with PathManager.open(filename + ".tmp", "wb") as f: + _torch_persistent_save(obj, f) + PathManager.rename(filename + ".tmp", filename) + else: + # fallback to non-atomic save + with PathManager.open(filename, "wb") as f: + _torch_persistent_save(obj, f) + + +def _torch_persistent_save(obj, f): + if isinstance(f, str): + with PathManager.open(f, "wb") as h: + torch_persistent_save(obj, h) + return + for i in range(3): + try: + return torch.save(obj, f) + except Exception: + if i == 2: + logger.error(traceback.format_exc()) + raise + else: + time.sleep(2.5) + + +def _upgrade_state_dict(state): + """Helper for upgrading old model checkpoints.""" + + # add optimizer_history + if "optimizer_history" not in state: + state["optimizer_history"] = [ + {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]} + ] + state["last_optimizer_state"] = state["optimizer"] + del state["optimizer"] + del state["best_loss"] + # move extra_state into sub-dictionary + if "epoch" in state and "extra_state" not in state: + state["extra_state"] = { + "epoch": state["epoch"], + "batch_offset": state["batch_offset"], + "val_loss": state["val_loss"], + } + del state["epoch"] + del state["batch_offset"] + del state["val_loss"] + # reduce optimizer history's memory usage (only keep the last state) + if "optimizer" in state["optimizer_history"][-1]: + state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"] + for optim_hist in state["optimizer_history"]: + del optim_hist["optimizer"] + # record the optimizer class name + if "optimizer_name" not in state["optimizer_history"][-1]: + state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG" + # move best_loss into lr_scheduler_state + if "lr_scheduler_state" not in state["optimizer_history"][-1]: + state["optimizer_history"][-1]["lr_scheduler_state"] = { + "best": state["optimizer_history"][-1]["best_loss"] + } + del state["optimizer_history"][-1]["best_loss"] + # keep track of number of updates + if "num_updates" not in state["optimizer_history"][-1]: + state["optimizer_history"][-1]["num_updates"] = 0 + # use stateful training data iterator + if "train_iterator" not in state["extra_state"]: + state["extra_state"]["train_iterator"] = { + "epoch": state["extra_state"].get("epoch", 0), + "iterations_in_epoch": state["extra_state"].get("batch_offset", 0), + } + + # backward compatibility, cfg updates + if "args" in state and state["args"] is not None: + # old model checkpoints may not have separate source/target positions + if hasattr(state["args"], "max_positions") and not hasattr( + state["args"], "max_source_positions" + ): + state["args"].max_source_positions = state["args"].max_positions + state["args"].max_target_positions = state["args"].max_positions + # default to translation task + if not hasattr(state["args"], "task"): + state["args"].task = "translation" + # --raw-text and --lazy-load are deprecated + if getattr(state["args"], "raw_text", False): + state["args"].dataset_impl = "raw" + elif getattr(state["args"], "lazy_load", False): + state["args"].dataset_impl = "lazy" + # epochs start at 1 + if state["extra_state"]["train_iterator"] is not None: + state["extra_state"]["train_iterator"]["epoch"] = max( + state["extra_state"]["train_iterator"].get("epoch", 1), 1 + ) + # --remove-bpe ==> --postprocess + if hasattr(state["args"], "remove_bpe"): + state["args"].post_process = state["args"].remove_bpe + # --min-lr ==> --stop-min-lr + if hasattr(state["args"], "min_lr"): + state["args"].stop_min_lr = state["args"].min_lr + del state["args"].min_lr + # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion + if hasattr(state["args"], "criterion") and state["args"].criterion in [ + "binary_cross_entropy", + "kd_binary_cross_entropy", + ]: + state["args"].criterion = "wav2vec" + # remove log_keys if it's None (criteria will supply a default value of []) + if hasattr(state["args"], "log_keys") and state["args"].log_keys is None: + delattr(state["args"], "log_keys") + # speech_pretraining => audio pretraining + if ( + hasattr(state["args"], "task") + and state["args"].task == "speech_pretraining" + ): + state["args"].task = "audio_pretraining" + # audio_cpc => wav2vec + if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc": + state["args"].arch = "wav2vec" + # convert legacy float learning rate to List[float] + if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float): + state["args"].lr = [state["args"].lr] + # convert task data arg to a string instead of List[string] + if ( + hasattr(state["args"], "data") + and isinstance(state["args"].data, list) + and len(state["args"].data) > 0 + ): + state["args"].data = state["args"].data[0] + + state["cfg"] = convert_namespace_to_omegaconf(state["args"]) + + if "cfg" in state and state["cfg"] is not None: + cfg = state["cfg"] + with open_dict(cfg): + # any upgrades for Hydra-based configs + if ( + "task" in cfg + and "eval_wer_config" in cfg.task + and isinstance(cfg.task.eval_wer_config.print_alignment, bool) + ): + cfg.task.eval_wer_config.print_alignment = "hard" + if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool): + cfg.generation.print_alignment = ( + "hard" if cfg.generation.print_alignment else None + ) + if ( + "model" in cfg + and "w2v_args" in cfg.model + and cfg.model.w2v_args is not None + and ( + hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args + ) + and hasattr(cfg.model.w2v_args.task, "eval_wer_config") + and cfg.model.w2v_args.task.eval_wer_config is not None + and isinstance( + cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool + ) + ): + cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard" + + return state + + +def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]): + """Prune the given state_dict if desired for LayerDrop + (https://arxiv.org/abs/1909.11556). + + Training with LayerDrop allows models to be robust to pruning at inference + time. This function prunes state_dict to allow smaller models to be loaded + from a larger model and re-maps the existing state_dict for this to occur. + + It's called by functions that load models from checkpoints and does not + need to be called directly. + """ + arch = None + if model_cfg is not None: + arch = ( + model_cfg._name + if isinstance(model_cfg, DictConfig) + else getattr(model_cfg, "arch", None) + ) + + if not model_cfg or arch is None or arch == "ptt_transformer": + # args should not be none, but don't crash if it is. + return state_dict + + encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None) + decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None) + + if not encoder_layers_to_keep and not decoder_layers_to_keep: + return state_dict + + # apply pruning + logger.info( + "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop" + ) + + def create_pruning_pass(layers_to_keep, layer_name): + keep_layers = sorted( + int(layer_string) for layer_string in layers_to_keep.split(",") + ) + mapping_dict = {} + for i in range(len(keep_layers)): + mapping_dict[str(keep_layers[i])] = str(i) + + regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name)) + return {"substitution_regex": regex, "mapping_dict": mapping_dict} + + pruning_passes = [] + if encoder_layers_to_keep: + pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder")) + if decoder_layers_to_keep: + pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder")) + + new_state_dict = {} + for layer_name in state_dict.keys(): + match = re.search(r"\.layers\.(\d+)\.", layer_name) + # if layer has no number in it, it is a supporting layer, such as an + # embedding + if not match: + new_state_dict[layer_name] = state_dict[layer_name] + continue + + # otherwise, layer should be pruned. + original_layer_number = match.group(1) + # figure out which mapping dict to replace from + for pruning_pass in pruning_passes: + if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[ + "substitution_regex" + ].search(layer_name): + new_layer_number = pruning_pass["mapping_dict"][original_layer_number] + substitution_match = pruning_pass["substitution_regex"].search( + layer_name + ) + new_state_key = ( + layer_name[: substitution_match.start(1)] + + new_layer_number + + layer_name[substitution_match.end(1) :] + ) + new_state_dict[new_state_key] = state_dict[layer_name] + + # Since layers are now pruned, *_layers_to_keep are no longer needed. + # This is more of "It would make it work fix" rather than a proper fix. + if isinstance(model_cfg, DictConfig): + context = open_dict(model_cfg) + else: + context = contextlib.ExitStack() + with context: + if hasattr(model_cfg, "encoder_layers_to_keep"): + model_cfg.encoder_layers_to_keep = None + if hasattr(model_cfg, "decoder_layers_to_keep"): + model_cfg.decoder_layers_to_keep = None + + return new_state_dict + + +def load_pretrained_component_from_model( + component: Union[FairseqEncoder, FairseqDecoder], + checkpoint: str, + strict: bool = True, +): + """ + Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the + provided `component` object. If state_dict fails to load, there may be a + mismatch in the architecture of the corresponding `component` found in the + `checkpoint` file. + """ + if not PathManager.exists(checkpoint): + raise IOError("Model file not found: {}".format(checkpoint)) + state = load_checkpoint_to_cpu(checkpoint) + if isinstance(component, FairseqEncoder): + component_type = "encoder" + elif isinstance(component, FairseqDecoder): + component_type = "decoder" + else: + raise ValueError( + "component to load must be either a FairseqEncoder or " + "FairseqDecoder. Loading other component types are not supported." + ) + component_state_dict = OrderedDict() + for key in state["model"].keys(): + if key.startswith(component_type): + # encoder.input_layers.0.0.weight --> input_layers.0.0.weight + component_subkey = key[len(component_type) + 1 :] + component_state_dict[component_subkey] = state["model"][key] + component.load_state_dict(component_state_dict, strict=strict) + return component + + +def verify_checkpoint_directory(save_dir: str) -> None: + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + temp_file_path = os.path.join(save_dir, "dummy") + try: + with open(temp_file_path, "w"): + pass + except OSError as e: + logger.warning( + "Unable to access checkpoint save directory: {}".format(save_dir) + ) + raise e + else: + os.remove(temp_file_path) + + +def save_ema_as_checkpoint(src_path, dst_path): + state = load_ema_from_checkpoint(src_path) + torch_persistent_save(state, dst_path) + + +def load_ema_from_checkpoint(fpath): + """Loads exponential moving averaged (EMA) checkpoint from input and + returns a model with ema weights. + + Args: + fpath: A string path of checkpoint to load from. + + Returns: + A dict of string keys mapping to various values. The 'model' key + from the returned dict should correspond to an OrderedDict mapping + string parameter names to torch Tensors. + """ + params_dict = collections.OrderedDict() + new_state = None + + with PathManager.open(fpath, "rb") as f: + new_state = torch.load( + f, + map_location=( + lambda s, _: torch.serialization.default_restore_location(s, "cpu") + ), + ) + + # EMA model is stored in a separate "extra state" + model_params = new_state["extra_state"]["ema"] + + for key in list(model_params.keys()): + p = model_params[key] + if isinstance(p, torch.HalfTensor): + p = p.float() + if key not in params_dict: + params_dict[key] = p.clone() + # NOTE: clone() is needed in case of p is a shared parameter + else: + raise ValueError("Key {} is repeated in EMA model params.".format(key)) + + if len(params_dict) == 0: + raise ValueError( + f"Input checkpoint path '{fpath}' does not contain " + "ema model weights, is this model trained with EMA?" + ) + + new_state["model"] = params_dict + return new_state diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index e69de29..a3abc8b 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -0,0 +1,7 @@ +from .dictionary import Dictionary +from .fairseq_dataset import FairseqDataset + +__all__ = [ + "Dictionary", + "FairseqDataset", +] diff --git a/fairseq/data/encoders/__init__.py b/fairseq/data/encoders/__init__.py new file mode 100644 index 0000000..7cbe00a --- /dev/null +++ b/fairseq/data/encoders/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import importlib +import os + +from fairseq import registry + + +build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry( + "--tokenizer", + default=None, +) + + +build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry( + "--bpe", + default=None, +) + + +# automatically import any Python files in the encoders/ directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + module = file[: file.find(".py")] + importlib.import_module("fairseq.data.encoders." + module) diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py new file mode 100644 index 0000000..2bde7fc --- /dev/null +++ b/fairseq/data/fairseq_dataset.py @@ -0,0 +1,205 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import numpy as np +import torch.utils.data +from fairseq.data import data_utils + +logger = logging.getLogger(__name__) + + +class EpochListening: + """Mixin for receiving updates whenever the epoch increments.""" + + @property + def can_reuse_epoch_itr_across_epochs(self): + """ + Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for + this dataset across epochs. + + This needs to return ``False`` if the sample sizes can change across + epochs, in which case we may need to regenerate batches at each epoch. + If your dataset relies in ``set_epoch`` then you should consider setting + this to ``False``. + """ + return True + + def set_epoch(self, epoch): + """Will receive the updated epoch number at the beginning of the epoch.""" + pass + + +class FairseqDataset(torch.utils.data.Dataset, EpochListening): + """A dataset that provides helpers for batching.""" + + def __getitem__(self, index): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch suitable for forwarding with a Model + """ + raise NotImplementedError + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + raise NotImplementedError + + def num_tokens_vec(self, indices): + """Return the number of tokens for a set of positions defined by indices. + This value is used to enforce ``--max-tokens`` during batching.""" + raise NotImplementedError + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + raise NotImplementedError + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + return np.arange(len(self), dtype=np.int64) + + @property + def supports_prefetch(self): + """Whether this dataset supports prefetching.""" + return False + + def attr(self, attr: str, index: int): + return getattr(self, attr, None) + + def prefetch(self, indices): + """Prefetch the data required for this epoch.""" + raise NotImplementedError + + def get_batch_shapes(self): + """ + Return a list of valid batch shapes, for example:: + + [(8, 512), (16, 256), (32, 128)] + + The first dimension of each tuple is the batch size and can be ``None`` + to automatically infer the max batch size based on ``--max-tokens``. + The second dimension of each tuple is the max supported length as given + by :func:`fairseq.data.FairseqDataset.num_tokens`. + + This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size` + to restrict batch shapes. This is useful on TPUs to avoid too many + dynamic shapes (and recompilations). + """ + return None + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + """ + Given an ordered set of indices, return batches according to + *max_tokens*, *max_sentences* and *required_batch_size_multiple*. + """ + from fairseq.data import data_utils + + fixed_shapes = self.get_batch_shapes() + if fixed_shapes is not None: + + def adjust_bsz(bsz, num_tokens): + if bsz is None: + assert max_tokens is not None, "Must specify --max-tokens" + bsz = max_tokens // num_tokens + if max_sentences is not None: + bsz = min(bsz, max_sentences) + elif ( + bsz >= required_batch_size_multiple + and bsz % required_batch_size_multiple != 0 + ): + bsz -= bsz % required_batch_size_multiple + return bsz + + fixed_shapes = np.array( + [ + [adjust_bsz(bsz, num_tokens), num_tokens] + for (bsz, num_tokens) in fixed_shapes + ] + ) + + try: + num_tokens_vec = self.num_tokens_vec(indices).astype("int64") + except NotImplementedError: + num_tokens_vec = None + + return data_utils.batch_by_size( + indices, + num_tokens_fn=self.num_tokens, + num_tokens_vec=num_tokens_vec, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + fixed_shapes=fixed_shapes, + ) + + def filter_indices_by_size(self, indices, max_sizes): + """ + Filter a list of sample indices. Remove those that are longer than + specified in *max_sizes*. + + WARNING: don't update, override method in child classes + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + if isinstance(max_sizes, float) or isinstance(max_sizes, int): + if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray): + ignored = indices[self.sizes[indices] > max_sizes].tolist() + indices = indices[self.sizes[indices] <= max_sizes] + elif ( + hasattr(self, "sizes") + and isinstance(self.sizes, list) + and len(self.sizes) == 1 + ): + ignored = indices[self.sizes[0][indices] > max_sizes].tolist() + indices = indices[self.sizes[0][indices] <= max_sizes] + else: + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + else: + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + return indices, ignored + + @property + def supports_fetch_outside_dataloader(self): + """Whether this dataset supports fetching outside the workers of the dataloader.""" + return True + + +class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening): + """ + For datasets that need to be read sequentially, usually because the data is + being streamed or otherwise can't be manipulated on a single machine. + """ + + def __iter__(self): + raise NotImplementedError diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py new file mode 100644 index 0000000..6a5a42a --- /dev/null +++ b/fairseq/data/iterators.py @@ -0,0 +1,879 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import logging +import math +import operator +import os +import queue +import time +from threading import Thread +from typing import Iterator, List + +import numpy as np +import torch +from fairseq.data import data_utils + + +logger = logging.getLogger(__name__) + +# Object used by _background_consumer to signal the source is exhausted +# to the main thread. +_sentinel = object() + + +class CountingIterator(object): + """Wrapper around an iterable that maintains the iteration count. + + Args: + iterable (iterable): iterable to wrap + start (int): starting iteration count. Note that this doesn't + actually advance the iterator. + total (int): override the iterator length returned by ``__len``. + This can be used to truncate *iterator*. + + Attributes: + n (int): number of elements consumed from this iterator + """ + + def __init__(self, iterable, start=None, total=None): + self._itr = iter(iterable) + self.n = start or getattr(iterable, "n", 0) + self.total = total if total is not None else self.n + len(iterable) + + def __len__(self): + return self.total + + def __iter__(self): + return self + + def __next__(self): + if not self.has_next(): + raise StopIteration + try: + x = next(self._itr) + except StopIteration: + raise IndexError( + f"Iterator expected to have length {self.total}, " + f"but exhausted at position {self.n}." + ) + self.n += 1 + return x + + def has_next(self): + """Whether the iterator has been exhausted.""" + return self.n < self.total + + def skip(self, n): + """Fast-forward the iterator by skipping n elements.""" + for _ in range(n): + next(self) + return self + + def take(self, n): + """Truncate the iterator to n elements at most.""" + self.total = min(self.total, n) + # Propagate this change to the underlying iterator + if hasattr(self._itr, "take"): + self._itr.take(max(n - self.n, 0)) + return self + + +class EpochBatchIterating(object): + def __len__(self) -> int: + raise NotImplementedError + + @property + def next_epoch_idx(self): + raise NotImplementedError + + def next_epoch_itr( + self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True + ): + """Return a new iterator over the dataset. + + Args: + shuffle (bool, optional): shuffle batches before returning the + iterator (default: True). + fix_batches_to_gpus (bool, optional): ensure that batches are always + allocated to the same shards across epochs. Requires + that :attr:`dataset` supports prefetching (default: False). + set_dataset_epoch (bool, optional): update the wrapped Dataset with + the new epoch number (default: True). + """ + raise NotImplementedError + + def end_of_epoch(self) -> bool: + """Returns whether the most recent epoch iterator has been exhausted""" + raise NotImplementedError + + @property + def iterations_in_epoch(self) -> int: + """The number of consumed batches in the current epoch.""" + raise NotImplementedError + + def state_dict(self): + """Returns a dictionary containing a whole state of the iterator.""" + raise NotImplementedError + + def load_state_dict(self, state_dict): + """Copies the state of the iterator from the given *state_dict*.""" + raise NotImplementedError + + @property + def first_batch(self): + return "DUMMY" + + +class StreamingEpochBatchIterator(EpochBatchIterating): + """A steaming-style iterator over a :class:`torch.utils.data.IterableDataset`. + + Args: + dataset (~torch.utils.data.Dataset): dataset from which to load the data + max_sentences: batch size + collate_fn (callable): merges a list of samples to form a mini-batch + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). + epoch (int, optional): the epoch to start the iterator from + (default: 1). + buffer_size (int, optional): the number of batches to keep ready in the + queue. Helps speeding up dataloading. When buffer_size is zero, the + default torch.utils.data.DataLoader preloading is used. + timeout (int, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative (default: ``0``). + """ + + def __init__( + self, + dataset, + max_sentences=1, + collate_fn=None, + epoch=1, + num_workers=0, + buffer_size=0, + timeout=0, + persistent_workers=True, + ): + assert isinstance(dataset, torch.utils.data.IterableDataset) + self.dataset = dataset + self.max_sentences = max_sentences + self.collate_fn = collate_fn + self.epoch = max(epoch, 1) # we use 1-based indexing for epochs + self.num_workers = num_workers + self.persistent_workers = persistent_workers and num_workers > 0 + # This upper limit here is to prevent people from abusing this feature + # in a shared computing environment. + self.buffer_size = min(buffer_size, 20) + self.timeout = timeout + + self._current_epoch_iterator = None + + @property + def next_epoch_idx(self): + """Return the epoch index after *next_epoch_itr* is called.""" + if self._current_epoch_iterator is not None and self.end_of_epoch(): + return self.epoch + 1 + else: + return self.epoch + + def next_epoch_itr( + self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True + ): + self.epoch = self.next_epoch_idx + if set_dataset_epoch and hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(self.epoch) + self._current_epoch_iterator = self._get_iterator_for_epoch(self.epoch, shuffle) + return self._current_epoch_iterator + + def end_of_epoch(self) -> bool: + return not self._current_epoch_iterator.has_next() + + @property + def iterations_in_epoch(self) -> int: + if self._current_epoch_iterator is not None: + return self._current_epoch_iterator.n + return 0 + + def state_dict(self): + return { + "epoch": self.epoch, + } + + def load_state_dict(self, state_dict): + self.epoch = state_dict["epoch"] + + def _get_iterator_for_epoch(self, epoch, shuffle, offset=0): + if self.num_workers > 0: + os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" + + # Create data loader + worker_init_fn = getattr(self.dataset, "worker_init_fn", None) + itr = torch.utils.data.DataLoader( + self.dataset, + batch_size=self.max_sentences, + collate_fn=self.collate_fn, + num_workers=self.num_workers, + timeout=self.timeout, + worker_init_fn=worker_init_fn, + pin_memory=True, + persistent_workers=self.persistent_workers, + ) + + # Wrap with a BufferedIterator if needed + if self.buffer_size > 0: + itr = BufferedIterator(self.buffer_size, itr) + + # Wrap with CountingIterator + itr = CountingIterator(itr, start=offset) + + return itr + + +class FrozenBatchSampler: + def __init__( + self, + ordered_batches, + epoch, + fix_batches_to_gpus, + shuffle, + initial_offset, + ): + self.ordered_batches = ordered_batches + self.fix_batches_to_gpus = fix_batches_to_gpus + self.shuffle = shuffle + self.make_batches_for_epoch(epoch, initial_offset) + + def make_batches_for_epoch(self, epoch, offset=0): + self.batches = self.ordered_batches( + epoch, self.fix_batches_to_gpus, self.shuffle + ) + if offset > 0: + self.batches = self.batches[offset:] + + def __iter__(self) -> Iterator[List[int]]: + return iter(self.batches) + + def __len__(self) -> int: + return len(self.batches) + + +class EpochBatchIterator(EpochBatchIterating): + """A multi-epoch iterator over a :class:`torch.utils.data.Dataset`. + + Compared to :class:`torch.utils.data.DataLoader`, this iterator: + + - can be reused across multiple epochs with the :func:`next_epoch_itr` + method (optionally shuffled between epochs) + - can be serialized/deserialized with the :func:`state_dict` and + :func:`load_state_dict` methods + - supports sharding with the *num_shards* and *shard_id* arguments + + Args: + dataset (~torch.utils.data.Dataset): dataset from which to load the data + collate_fn (callable): merges a list of samples to form a mini-batch + batch_sampler (~torch.utils.data.Sampler or a callable): an iterator over batches of + indices, or a callable to create such an iterator (~torch.utils.data.Sampler). + A callable batch_sampler will be called for each epoch to enable per epoch dynamic + batch iterators defined by this callable batch_sampler. + seed (int, optional): seed for random number generator for + reproducibility (default: 1). + num_shards (int, optional): shard the data iterator into N + shards (default: 1). + shard_id (int, optional): which shard of the data iterator to + return (default: 0). + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). + epoch (int, optional): the epoch to start the iterator from + (default: 1). + buffer_size (int, optional): the number of batches to keep ready in the + queue. Helps speeding up dataloading. When buffer_size is zero, the + default torch.utils.data.DataLoader preloading is used. + timeout (int, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative (default: ``0``). + disable_shuffling (bool, optional): force disable shuffling + (default: ``False``). + skip_remainder_batch (bool, optional): if set, discard the last batch in an epoch + for the sake of training stability, as the last batch is usually smaller than + local_batch_size * distributed_word_size (default: ``False``). + grouped_shuffling (bool, optional): enable shuffling batches in groups + of num_shards. Ensures that each GPU receives similar length sequences when + batches are sorted by length. + """ + + def __init__( + self, + dataset, + collate_fn, + batch_sampler, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + buffer_size=0, + timeout=0, + disable_shuffling=False, + skip_remainder_batch=False, + grouped_shuffling=False, + reuse_dataloader=False, + persistent_workers=True, + ): + assert isinstance(dataset, torch.utils.data.Dataset) + self.dataset = dataset + self.collate_fn = collate_fn + self.batch_sampler = batch_sampler + self._frozen_batches = ( + tuple(batch_sampler) if not callable(batch_sampler) else None + ) + self.seed = seed + self.num_shards = num_shards + self.shard_id = shard_id + self.num_workers = num_workers + self.persistent_workers = persistent_workers and num_workers > 0 + # This upper limit here is to prevent people from abusing this feature + # in a shared computing environment. + self.buffer_size = min(buffer_size, 20) + self.timeout = timeout + self.disable_shuffling = disable_shuffling + self.skip_remainder_batch = skip_remainder_batch + self.grouped_shuffling = grouped_shuffling + + self.epoch = max(epoch, 1) # we use 1-based indexing for epochs + self.shuffle = not disable_shuffling + self._cur_epoch_itr = None + self._next_epoch_itr = None + self._supports_prefetch = getattr(dataset, "supports_prefetch", False) + + self.dataloader = None + self.reuse_dataloader = reuse_dataloader + + @property + def frozen_batches(self): + if self._frozen_batches is None: + self._frozen_batches = tuple(self.batch_sampler(self.dataset, self.epoch)) + return self._frozen_batches + + @property + def first_batch(self): + if len(self.frozen_batches) == 0: + raise Exception( + "The dataset is empty. This could indicate " + "that all elements in the dataset have been skipped. " + "Try increasing the max number of allowed tokens or using " + "a larger dataset." + ) + + if getattr(self.dataset, "supports_fetch_outside_dataloader", True): + return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0]]) + else: + return "DUMMY" + + def __len__(self): + return int(math.ceil(len(self.frozen_batches) / float(self.num_shards))) + + @property + def n(self): + return self.iterations_in_epoch + + @property + def next_epoch_idx(self): + """Return the epoch index after *next_epoch_itr* is called.""" + if self._next_epoch_itr is not None: + return self.epoch + elif self._cur_epoch_itr is not None and self.end_of_epoch(): + return self.epoch + 1 + else: + return self.epoch + + def next_epoch_itr( + self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True + ): + """Return a new iterator over the dataset. + + Args: + shuffle (bool, optional): shuffle batches before returning the + iterator (default: True). + fix_batches_to_gpus (bool, optional): ensure that batches are always + allocated to the same shards across epochs. Requires + that :attr:`dataset` supports prefetching (default: False). + set_dataset_epoch (bool, optional): update the wrapped Dataset with + the new epoch number (default: True). + """ + if self.disable_shuffling: + shuffle = False + prev_epoch = self.epoch + self.epoch = self.next_epoch_idx + if set_dataset_epoch and hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(self.epoch) + if self._next_epoch_itr is not None: + self._cur_epoch_itr = self._next_epoch_itr + self._next_epoch_itr = None + else: + if callable(self.batch_sampler) and prev_epoch != self.epoch: + # reset _frozen_batches to refresh the next epoch + self._frozen_batches = None + self._cur_epoch_itr = self._get_iterator_for_epoch( + self.epoch, + shuffle, + fix_batches_to_gpus=fix_batches_to_gpus, + ) + self.shuffle = shuffle + return self._cur_epoch_itr + + def end_of_epoch(self) -> bool: + """Returns whether the most recent epoch iterator has been exhausted""" + return not self._cur_epoch_itr.has_next() + + @property + def iterations_in_epoch(self): + """The number of consumed batches in the current epoch.""" + if self._cur_epoch_itr is not None: + return self._cur_epoch_itr.n + elif self._next_epoch_itr is not None: + return self._next_epoch_itr.n + return 0 + + def state_dict(self): + """Returns a dictionary containing a whole state of the iterator.""" + if self.end_of_epoch(): + epoch = self.epoch + 1 + iter_in_epoch = 0 + else: + epoch = self.epoch + iter_in_epoch = self.iterations_in_epoch + return { + "version": 2, + "epoch": epoch, + "iterations_in_epoch": iter_in_epoch, + "shuffle": self.shuffle, + } + + def load_state_dict(self, state_dict): + """Copies the state of the iterator from the given *state_dict*.""" + self.epoch = state_dict["epoch"] + itr_pos = state_dict.get("iterations_in_epoch", 0) + version = state_dict.get("version", 1) + if itr_pos > 0: + # fast-forward epoch iterator + self._next_epoch_itr = self._get_iterator_for_epoch( + self.epoch, + shuffle=state_dict.get("shuffle", True), + offset=itr_pos, + ) + if self._next_epoch_itr is None: + if version == 1: + # legacy behavior: we finished the epoch, increment epoch counter + self.epoch += 1 + else: + raise RuntimeError( + "Cannot resume training due to dataloader mismatch, please " + "report this to the fairseq developers. You can relaunch " + "training with `--reset-dataloader` and it should work." + ) + else: + self._next_epoch_itr = None + + def _get_iterator_for_epoch( + self, epoch, shuffle, fix_batches_to_gpus=False, offset=0 + ): + if self.reuse_dataloader and self.dataloader is not None: + self.epoch_batch_sampler.make_batches_for_epoch(epoch, offset) + itr = self.dataloader + else: + self.epoch_batch_sampler = FrozenBatchSampler( + self.ordered_batches, + epoch, + fix_batches_to_gpus, + shuffle, + initial_offset=offset, + ) + + if offset > 0 and len(self.epoch_batch_sampler) == 0: + return None + + if self.num_workers > 0: + os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" + + # Create data loader + itr = torch.utils.data.DataLoader( + self.dataset, + collate_fn=self.collate_fn, + batch_sampler=self.epoch_batch_sampler, + num_workers=self.num_workers, + timeout=self.timeout, + pin_memory=True, + persistent_workers=self.persistent_workers, + ) + + if self.reuse_dataloader: + self.dataloader = itr + + # Wrap with a BufferedIterator if needed + if self.buffer_size > 0: + itr = BufferedIterator(self.buffer_size, itr) + + # Wrap with CountingIterator + itr = CountingIterator(itr, start=offset) + + if self.skip_remainder_batch: + # TODO: Below is a lazy implementation which discard the final batch regardless + # of whether it is a full batch or not. + + total_num_itrs = len(itr) - 1 + itr.take(total_num_itrs) + logger.info(f"skip final residual batch, total_num_itrs = {total_num_itrs}") + + return itr + + def ordered_batches(self, epoch, fix_batches_to_gpus, shuffle): + def shuffle_batches(batches, seed): + with data_utils.numpy_seed(seed): + + if self.grouped_shuffling: + grouped_batches = [ + batches[(i * self.num_shards) : ((i + 1) * self.num_shards)] + for i in range((len(batches) // self.num_shards)) + ] + np.random.shuffle(grouped_batches) + batches = list(itertools.chain(*grouped_batches)) + else: + np.random.shuffle(batches) + + return batches + + if self._supports_prefetch: + batches = self.frozen_batches + + if shuffle and not fix_batches_to_gpus: + batches = shuffle_batches(list(batches), self.seed + epoch) + + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) + self.dataset.prefetch([i for s in batches for i in s]) + + if shuffle and fix_batches_to_gpus: + batches = shuffle_batches(batches, self.seed + epoch + self.shard_id) + else: + if shuffle: + batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch) + else: + batches = self.frozen_batches + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) + return batches + + +class GroupedIterator(CountingIterator): + """Wrapper around an iterable that returns groups (chunks) of items. + + Args: + iterable (iterable): iterable to wrap + chunk_size (int): size of each chunk + skip_remainder_batch (bool, optional): if set, discard the last grouped batch in + each training epoch, as the last grouped batch is usually smaller than + local_batch_size * distributed_word_size * chunk_size (default: ``False``). + Attributes: + n (int): number of elements consumed from this iterator + """ + + def __init__(self, iterable, chunk_size, skip_remainder_batch=False): + if skip_remainder_batch: + total_num_itrs = int(math.floor(len(iterable) / float(chunk_size))) + logger.info( + f"skip final residual batch, grouped total_num_itrs = {total_num_itrs}" + ) + else: + total_num_itrs = int(math.ceil(len(iterable) / float(chunk_size))) + logger.info(f"grouped total_num_itrs = {total_num_itrs}") + + itr = _chunk_iterator(iterable, chunk_size, skip_remainder_batch) + super().__init__( + itr, + start=int(math.ceil(getattr(iterable, "n", 0) / float(chunk_size))), + total=total_num_itrs, + ) + self.chunk_size = chunk_size + + if skip_remainder_batch: + self.take(total_num_itrs) + # TODO: [Hack] Here the grouped iterator modifies the base iterator size so that + # training can move into the next epoch once the grouped iterator is exhausted. + # Double-check this implementation in case unexpected behavior occurs. + iterable.take(total_num_itrs * chunk_size) + + +def _chunk_iterator(itr, chunk_size, skip_remainder_batch=False): + chunk = [] + for x in itr: + chunk.append(x) + if len(chunk) == chunk_size: + yield chunk + chunk = [] + if not skip_remainder_batch and len(chunk) > 0: + yield chunk + + +class ShardedIterator(CountingIterator): + """A sharded wrapper around an iterable, padded to length. + + Args: + iterable (iterable): iterable to wrap + num_shards (int): number of shards to split the iterable into + shard_id (int): which shard to iterator over + fill_value (Any, optional): padding value when the iterable doesn't + evenly divide *num_shards* (default: None). + + Attributes: + n (int): number of elements consumed from this iterator + """ + + def __init__( + self, iterable, num_shards, shard_id, fill_value=None, skip_remainder_batch=None + ): + """ + Args: + skip_remainder_batch: ignored""" + if shard_id < 0 or shard_id >= num_shards: + raise ValueError("shard_id must be between 0 and num_shards") + sharded_len = int(math.ceil(len(iterable) / float(num_shards))) + itr = map( + operator.itemgetter(1), + itertools.zip_longest( + range(sharded_len), + itertools.islice(iterable, shard_id, len(iterable), num_shards), + fillvalue=fill_value, + ), + ) + super().__init__( + itr, + start=int(math.ceil(getattr(iterable, "n", 0) / float(num_shards))), + total=sharded_len, + ) + + +class BackgroundConsumer(Thread): + def __init__(self, queue, source, max_len, cuda_device): + Thread.__init__(self) + + self._queue = queue + self._source = source + self._max_len = max_len + self.count = 0 + self.cuda_device = cuda_device + + def run(self): + # set_device to avoid creation of GPU0 context when using pin_memory + if self.cuda_device is not None: + torch.cuda.set_device(self.cuda_device) + + try: + for item in self._source: + self._queue.put(item) + + # Stop if we reached the maximum length + self.count += 1 + if self._max_len is not None and self.count >= self._max_len: + break + + # Signal the consumer we are done. + self._queue.put(_sentinel) + except Exception as e: + self._queue.put(e) + + +class BufferedIterator(object): + def __init__(self, size, iterable): + self._queue = queue.Queue(size) + self._iterable = iterable + self._consumer = None + + self.start_time = time.time() + self.warning_time = None + + self.total = len(iterable) + + def _create_consumer(self): + self._consumer = BackgroundConsumer( + self._queue, + self._iterable, + self.total, + torch.cuda.current_device() if torch.cuda.is_available() else None, + ) + self._consumer.daemon = True + self._consumer.start() + + def __iter__(self): + return self + + def __len__(self): + return self.total + + def take(self, n): + self.total = min(self.total, n) + # Propagate this change to the underlying iterator + if hasattr(self._iterable, "take"): + self._iterable.take(n) + return self + + def __next__(self): + # Create consumer if not created yet + if self._consumer is None: + self._create_consumer() + + # Notify the user if there is a data loading bottleneck + if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)): + if time.time() - self.start_time > 5 * 60: + if ( + self.warning_time is None + or time.time() - self.warning_time > 15 * 60 + ): + logger.debug( + "Data loading buffer is empty or nearly empty. This may " + "indicate a data loading bottleneck, and increasing the " + "number of workers (--num-workers) may help." + ) + self.warning_time = time.time() + + # Get next example + item = self._queue.get(True) + if isinstance(item, Exception): + raise item + if item is _sentinel: + raise StopIteration() + return item + + +class GroupedEpochBatchIterator(EpochBatchIterator): + """Grouped version of EpochBatchIterator + It takes several samplers from different datasets. + Each epoch shuffle the dataset wise sampler individually with different + random seed. The those sub samplers are combined with into + one big samplers with deterministic permutation to mix batches from + different datasets. It will act like EpochBatchIterator but make sure + 1) data from one data set each time + 2) for different workers, they use the same order to fetch the data + so they will use data from the same dataset everytime + mult_rate is used for update_freq > 1 case where we want to make sure update_freq + mini-batches come from same source + """ + + def __init__( + self, + dataset, + collate_fn, + batch_samplers, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=0, + mult_rate=1, + buffer_size=0, + skip_remainder_batch=False, + ): + super().__init__( + dataset, + collate_fn, + batch_samplers, + seed, + num_shards, + shard_id, + num_workers, + epoch, + buffer_size, + skip_remainder_batch=skip_remainder_batch, + ) + # level 0: sub-samplers 1: batch_idx 2: batches + self._frozen_batches = tuple([tuple(sub_batch) for sub_batch in batch_samplers]) + self.step_size = mult_rate * num_shards + + self.lengths = [ + (len(x) // self.step_size) * self.step_size for x in self.frozen_batches + ] + + def __len__(self): + return sum(self.lengths) + + @property + def first_batch(self): + if len(self.frozen_batches) == 0: + raise Exception( + "The dataset is empty. This could indicate " + "that all elements in the dataset have been skipped. " + "Try increasing the max number of allowed tokens or using " + "a larger dataset." + ) + + if self.dataset.supports_fetch_outside_dataloader: + return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0][0]]) + else: + return "DUMMY" + + def _get_iterator_for_epoch( + self, epoch, shuffle, fix_batches_to_gpus=False, offset=0 + ): + def shuffle_batches(batches, seed): + with data_utils.numpy_seed(seed): + np.random.shuffle(batches) + return batches + + def return_full_batches(batch_sets, seed, shuffle): + if shuffle: + batch_sets = [shuffle_batches(list(x), seed) for x in batch_sets] + + batch_sets = [ + batch_sets[i][: self.lengths[i]] for i in range(len(batch_sets)) + ] + batches = list(itertools.chain.from_iterable(batch_sets)) + + if shuffle: + with data_utils.numpy_seed(seed): + idx = np.random.permutation(len(batches) // self.step_size) + if len(idx) * self.step_size != len(batches): + raise ValueError( + "ERROR: %d %d %d %d" + % (len(idx), self.step_size, len(batches), self.shard_id), + ":".join(["%d" % x for x in self.lengths]), + ) + mini_shards = [ + batches[i * self.step_size : (i + 1) * self.step_size] + for i in idx + ] + batches = list(itertools.chain.from_iterable(mini_shards)) + + return batches + + if self._supports_prefetch: + raise NotImplementedError("To be implemented") + else: + batches = return_full_batches( + self.frozen_batches, self.seed + epoch, shuffle + ) + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) + + if offset > 0 and offset >= len(batches): + return None + + if self.num_workers > 0: + os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" + + itr = torch.utils.data.DataLoader( + self.dataset, + collate_fn=self.collate_fn, + batch_sampler=batches[offset:], + num_workers=self.num_workers, + persistent_workers=self.persistent_workers, + ) + if self.buffer_size > 0: + itr = BufferedIterator(self.buffer_size, itr) + + return CountingIterator(itr, start=offset) diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index c8668ec..ba4d7e5 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -108,7 +108,7 @@ class CommonConfig(FairseqDataclass): "help": "log progress every N batches (when progress bar is disabled)" }, ) - log_format: Optional[LOG_FORMAT_CHOICES] = field( + log_format: Optional[LOG_FORMAT_CHOICES] = field( # type: ignore default=None, metadata={"help": "log format to use"} ) log_file: Optional[str] = field( @@ -298,10 +298,10 @@ class DistributedTrainingConfig(FairseqDataclass): "help": "do not spawn multiple processes even if multiple GPUs are visible" }, ) - ddp_backend: DDP_BACKEND_CHOICES = field( + ddp_backend: DDP_BACKEND_CHOICES = field( # type: ignore default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"} ) - ddp_comm_hook: DDP_COMM_HOOK_CHOICES = field( + ddp_comm_hook: DDP_COMM_HOOK_CHOICES = field( # type: ignore default="none", metadata={"help": "communication hook"} ) bucket_cap_mb: int = field( @@ -428,11 +428,11 @@ class DistributedTrainingConfig(FairseqDataclass): "equal the length of the --pipeline-decoder-balance argument" }, ) - pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field( + pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field( # type: ignore default="never", metadata={"help": "checkpointing mode for pipeline model parallelism"}, ) - zero_sharding: ZERO_SHARDING_CHOICES = field( + zero_sharding: ZERO_SHARDING_CHOICES = field( # type: ignore default="none", metadata={"help": "ZeRO sharding"} ) fp16: bool = II("common.fp16") diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py new file mode 100644 index 0000000..f6467d5 --- /dev/null +++ b/fairseq/dataclass/utils.py @@ -0,0 +1,510 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import ast +import inspect +import logging +import os +import re +from argparse import ArgumentError, ArgumentParser, Namespace +from dataclasses import _MISSING_TYPE, MISSING, is_dataclass +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Type + +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.configs import FairseqConfig +from hydra.core.global_hydra import GlobalHydra +from hydra.experimental import compose, initialize +from omegaconf import DictConfig, OmegaConf, open_dict, _utils + +logger = logging.getLogger(__name__) + + +def eval_str_list(x, x_type=float): + if x is None: + return None + if isinstance(x, str): + if len(x) == 0: + return [] + x = ast.literal_eval(x) + try: + return list(map(x_type, x)) + except TypeError: + return [x_type(x)] + + +def interpret_dc_type(field_type): + if isinstance(field_type, str): + raise RuntimeError("field should be a type") + + if field_type == Any: + return str + + typestring = str(field_type) + if re.match( + r"(typing.|^)Union\[(.*), NoneType\]$", typestring + ) or typestring.startswith("typing.Optional"): + return field_type.__args__[0] + return field_type + + +def gen_parser_from_dataclass( + parser: ArgumentParser, + dataclass_instance: FairseqDataclass, + delete_default: bool = False, + with_prefix: Optional[str] = None, +) -> None: + """ + convert a dataclass instance to tailing parser arguments. + + If `with_prefix` is provided, prefix all the keys in the resulting parser with it. It means that we are + building a flat namespace from a structured dataclass (see transformer_config.py for example). + """ + + def argparse_name(name: str): + if name == "data" and (with_prefix is None or with_prefix == ""): + # normally data is positional args, so we don't add the -- nor the prefix + return name + if name == "_name": + # private member, skip + return None + full_name = "--" + name.replace("_", "-") + if with_prefix is not None and with_prefix != "": + # if a prefix is specified, construct the prefixed arg name + full_name = with_prefix + "-" + full_name[2:] # strip -- when composing + return full_name + + def get_kwargs_from_dc( + dataclass_instance: FairseqDataclass, k: str + ) -> Dict[str, Any]: + """k: dataclass attributes""" + + kwargs = {} + + field_type = dataclass_instance._get_type(k) + inter_type = interpret_dc_type(field_type) + + field_default = dataclass_instance._get_default(k) + + if isinstance(inter_type, type) and issubclass(inter_type, Enum): + field_choices = [t.value for t in list(inter_type)] + else: + field_choices = None + + field_help = dataclass_instance._get_help(k) + field_const = dataclass_instance._get_argparse_const(k) + + if isinstance(field_default, str) and field_default.startswith("${"): + kwargs["default"] = field_default + else: + if field_default is MISSING: + kwargs["required"] = True + if field_choices is not None: + kwargs["choices"] = field_choices + if ( + isinstance(inter_type, type) + and (issubclass(inter_type, List) or issubclass(inter_type, Tuple)) + ) or ("List" in str(inter_type) or "Tuple" in str(inter_type)): + if "int" in str(inter_type): + kwargs["type"] = lambda x: eval_str_list(x, int) + elif "float" in str(inter_type): + kwargs["type"] = lambda x: eval_str_list(x, float) + elif "str" in str(inter_type): + kwargs["type"] = lambda x: eval_str_list(x, str) + else: + raise NotImplementedError( + "parsing of type " + str(inter_type) + " is not implemented" + ) + if field_default is not MISSING: + kwargs["default"] = ( + ",".join(map(str, field_default)) + if field_default is not None + else None + ) + elif ( + isinstance(inter_type, type) and issubclass(inter_type, Enum) + ) or "Enum" in str(inter_type): + kwargs["type"] = str + if field_default is not MISSING: + if isinstance(field_default, Enum): + kwargs["default"] = field_default.value + else: + kwargs["default"] = field_default + elif inter_type is bool: + kwargs["action"] = ( + "store_false" if field_default is True else "store_true" + ) + kwargs["default"] = field_default + else: + kwargs["type"] = inter_type + if field_default is not MISSING: + kwargs["default"] = field_default + + # build the help with the hierarchical prefix + if with_prefix is not None and with_prefix != "" and field_help is not None: + field_help = with_prefix[2:] + ": " + field_help + + kwargs["help"] = field_help + if field_const is not None: + kwargs["const"] = field_const + kwargs["nargs"] = "?" + + return kwargs + + for k in dataclass_instance._get_all_attributes(): + field_name = argparse_name(dataclass_instance._get_name(k)) + field_type = dataclass_instance._get_type(k) + if field_name is None: + continue + elif inspect.isclass(field_type) and issubclass(field_type, FairseqDataclass): + # for fields that are of type FairseqDataclass, we can recursively + # add their fields to the namespace (so we add the args from model, task, etc. to the root namespace) + prefix = None + if with_prefix is not None: + # if a prefix is specified, then we don't want to copy the subfields directly to the root namespace + # but we prefix them with the name of the current field. + prefix = field_name + gen_parser_from_dataclass(parser, field_type(), delete_default, prefix) + continue + + kwargs = get_kwargs_from_dc(dataclass_instance, k) + + field_args = [field_name] + alias = dataclass_instance._get_argparse_alias(k) + if alias is not None: + field_args.append(alias) + + if "default" in kwargs: + if isinstance(kwargs["default"], str) and kwargs["default"].startswith( + "${" + ): + if kwargs["help"] is None: + # this is a field with a name that will be added elsewhere + continue + else: + del kwargs["default"] + if delete_default and "default" in kwargs: + del kwargs["default"] + try: + parser.add_argument(*field_args, **kwargs) + except ArgumentError: + pass + + +def _set_legacy_defaults(args, cls): + """Helper to set default arguments based on *add_args*.""" + if not hasattr(cls, "add_args"): + return + + import argparse + + parser = argparse.ArgumentParser( + argument_default=argparse.SUPPRESS, allow_abbrev=False + ) + cls.add_args(parser) + # copied from argparse.py: + defaults = argparse.Namespace() + for action in parser._actions: + if action.dest is not argparse.SUPPRESS: + if not hasattr(defaults, action.dest): + if action.default is not argparse.SUPPRESS: + setattr(defaults, action.dest, action.default) + for key, default_value in vars(defaults).items(): + if not hasattr(args, key): + setattr(args, key, default_value) + + +def _override_attr( + sub_node: str, data_class: Type[FairseqDataclass], args: Namespace +) -> List[str]: + overrides = [] + + if not inspect.isclass(data_class) or not issubclass(data_class, FairseqDataclass): + return overrides + + def get_default(f): + if not isinstance(f.default_factory, _MISSING_TYPE): + return f.default_factory() + return f.default + + for k, v in data_class.__dataclass_fields__.items(): + if k.startswith("_"): + # private member, skip + continue + + val = get_default(v) if not hasattr(args, k) else getattr(args, k) + + field_type = interpret_dc_type(v.type) + if ( + isinstance(val, str) + and not val.startswith("${") # not interpolation + and field_type != str + and ( + not inspect.isclass(field_type) or not issubclass(field_type, Enum) + ) # not choices enum + ): + # upgrade old models that stored complex parameters as string + val = ast.literal_eval(val) + + if isinstance(val, tuple): + val = list(val) + + v_type = getattr(v.type, "__origin__", None) + if ( + (v_type is List or v_type is list or v_type is Optional) + # skip interpolation + and not (isinstance(val, str) and val.startswith("${")) + ): + # if type is int but val is float, then we will crash later - try to convert here + if hasattr(v.type, "__args__"): + t_args = v.type.__args__ + if len(t_args) == 1 and (t_args[0] is float or t_args[0] is int): + val = list(map(t_args[0], val)) + elif val is not None and ( + field_type is int or field_type is bool or field_type is float + ): + try: + val = field_type(val) + except: + pass # ignore errors here, they are often from interpolation args + + if val is None: + overrides.append("{}.{}=null".format(sub_node, k)) + elif val == "": + overrides.append("{}.{}=''".format(sub_node, k)) + elif isinstance(val, str): + val = val.replace("'", r"\'") + overrides.append("{}.{}='{}'".format(sub_node, k, val)) + elif isinstance(val, FairseqDataclass): + overrides += _override_attr(f"{sub_node}.{k}", type(val), args) + elif isinstance(val, Namespace): + sub_overrides, _ = override_module_args(val) + for so in sub_overrides: + overrides.append(f"{sub_node}.{k}.{so}") + else: + overrides.append("{}.{}={}".format(sub_node, k, val)) + + return overrides + + +def migrate_registry( + name, value, registry, args, overrides, deletes, use_name_as_val=False +): + if value in registry: + overrides.append("{}={}".format(name, value)) + overrides.append("{}._name={}".format(name, value)) + overrides.extend(_override_attr(name, registry[value], args)) + elif use_name_as_val and value is not None: + overrides.append("{}={}".format(name, value)) + else: + deletes.append(name) + + +def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: + """use the field in args to overrides those in cfg""" + overrides = [] + deletes = [] + + for k in FairseqConfig.__dataclass_fields__.keys(): + overrides.extend( + _override_attr(k, FairseqConfig.__dataclass_fields__[k].type, args) + ) + + if args is not None: + if hasattr(args, "task"): + from fairseq.tasks import TASK_DATACLASS_REGISTRY + + migrate_registry( + "task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes + ) + else: + deletes.append("task") + + # these options will be set to "None" if they have not yet been migrated + # so we can populate them with the entire flat args + CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"} + + from fairseq.registry import REGISTRIES + + for k, v in REGISTRIES.items(): + if hasattr(args, k): + migrate_registry( + k, + getattr(args, k), + v["dataclass_registry"], + args, + overrides, + deletes, + use_name_as_val=k not in CORE_REGISTRIES, + ) + else: + deletes.append(k) + + no_dc = True + if hasattr(args, "arch"): + from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_MODEL_NAME_REGISTRY + + if args.arch in ARCH_MODEL_REGISTRY: + m_cls = ARCH_MODEL_REGISTRY[args.arch] + dc = getattr(m_cls, "__dataclass", None) + if dc is not None: + m_name = ARCH_MODEL_NAME_REGISTRY[args.arch] + overrides.append("model={}".format(m_name)) + overrides.append("model._name={}".format(args.arch)) + # override model params with those exist in args + overrides.extend(_override_attr("model", dc, args)) + no_dc = False + if no_dc: + deletes.append("model") + + return overrides, deletes + + +class omegaconf_no_object_check: + def __init__(self): + # Changed in https://github.com/omry/omegaconf/pull/911 - both are kept for back compat. + if hasattr(_utils, "is_primitive_type"): + self.old_is_primitive = _utils.is_primitive_type + else: + self.old_is_primitive = _utils.is_primitive_type_annotation + + def __enter__(self): + if hasattr(_utils, "is_primitive_type"): + _utils.is_primitive_type = lambda _: True + else: + _utils.is_primitive_type_annotation = lambda _: True + + def __exit__(self, type, value, traceback): + if hasattr(_utils, "is_primitive_type"): + _utils.is_primitive_type = self.old_is_primitive + else: + _utils.is_primitive_type_annotation = self.old_is_primitive + + +def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: + """Convert a flat argparse.Namespace to a structured DictConfig.""" + + # Here we are using field values provided in args to override counterparts inside config object + overrides, deletes = override_module_args(args) + + # configs will be in fairseq/config after installation + config_path = os.path.join("..", "config") + + GlobalHydra.instance().clear() + + with initialize(config_path=config_path): + try: + composed_cfg = compose("config", overrides=overrides, strict=False) + except: + logger.error("Error when composing. Overrides: " + str(overrides)) + raise + + for k in deletes: + composed_cfg[k] = None + + cfg = OmegaConf.create( + OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True) + ) + + # hack to be able to set Namespace in dict config. this should be removed when we update to newer + # omegaconf version that supports object flags, or when we migrate all existing models + from omegaconf import _utils + + with omegaconf_no_object_check(): + if cfg.task is None and getattr(args, "task", None): + cfg.task = Namespace(**vars(args)) + from fairseq.tasks import TASK_REGISTRY + + _set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task]) + cfg.task._name = args.task + if cfg.model is None and getattr(args, "arch", None): + cfg.model = Namespace(**vars(args)) + from fairseq.models import ARCH_MODEL_REGISTRY + + _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch]) + cfg.model._name = args.arch + if cfg.optimizer is None and getattr(args, "optimizer", None): + cfg.optimizer = Namespace(**vars(args)) + from fairseq.optim import OPTIMIZER_REGISTRY + + _set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer]) + cfg.optimizer._name = args.optimizer + if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None): + cfg.lr_scheduler = Namespace(**vars(args)) + from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY + + _set_legacy_defaults( + cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler] + ) + cfg.lr_scheduler._name = args.lr_scheduler + if cfg.criterion is None and getattr(args, "criterion", None): + cfg.criterion = Namespace(**vars(args)) + from fairseq.criterions import CRITERION_REGISTRY + + _set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion]) + cfg.criterion._name = args.criterion + + OmegaConf.set_struct(cfg, True) + return cfg + + +def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): + # this will be deprecated when we get rid of argparse and model_overrides logic + + from fairseq.registry import REGISTRIES + + with open_dict(cfg): + for k in cfg.keys(): + # "k in cfg" will return false if its a "mandatory value (e.g. ???)" + if k in cfg and isinstance(cfg[k], DictConfig): + if k in overrides and isinstance(overrides[k], dict): + for ok, ov in overrides[k].items(): + if isinstance(ov, dict) and cfg[k][ok] is not None: + overwrite_args_by_name(cfg[k][ok], ov) + else: + cfg[k][ok] = ov + else: + overwrite_args_by_name(cfg[k], overrides) + elif k in cfg and isinstance(cfg[k], Namespace): + for override_key, val in overrides.items(): + setattr(cfg[k], override_key, val) + elif k in overrides: + if ( + k in REGISTRIES + and overrides[k] in REGISTRIES[k]["dataclass_registry"] + ): + cfg[k] = DictConfig( + REGISTRIES[k]["dataclass_registry"][overrides[k]] + ) + overwrite_args_by_name(cfg[k], overrides) + cfg[k]._name = overrides[k] + else: + cfg[k] = overrides[k] + + +def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig, remove_missing=False): + if remove_missing: + + def remove_missing_rec(src_keys, target_cfg): + if is_dataclass(target_cfg): + target_keys = set(target_cfg.__dataclass_fields__.keys()) + else: + target_keys = set(target_cfg.keys()) + + for k in list(src_keys.keys()): + if k not in target_keys: + del src_keys[k] + elif OmegaConf.is_config(src_keys[k]): + tgt = getattr(target_cfg, k) + if tgt is not None and (is_dataclass(tgt) or hasattr(tgt, "keys")): + remove_missing_rec(src_keys[k], tgt) + + with open_dict(cfg): + remove_missing_rec(cfg, dc) + + merged_cfg = OmegaConf.merge(dc, cfg) + merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"] + OmegaConf.set_struct(merged_cfg, True) + return merged_cfg diff --git a/fairseq/distributed/__init__.py b/fairseq/distributed/__init__.py new file mode 100644 index 0000000..d278ccb --- /dev/null +++ b/fairseq/distributed/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .fully_sharded_data_parallel import ( + fsdp_enable_wrap, + fsdp_wrap, + FullyShardedDataParallel, +) + +__all__ = [ + "fsdp_enable_wrap", + "fsdp_wrap", + "FullyShardedDataParallel", +] \ No newline at end of file diff --git a/fairseq/distributed/fully_sharded_data_parallel.py b/fairseq/distributed/fully_sharded_data_parallel.py new file mode 100644 index 0000000..7656d2e --- /dev/null +++ b/fairseq/distributed/fully_sharded_data_parallel.py @@ -0,0 +1,145 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +from typing import Optional + +import torch +from fairseq.dataclass.configs import DistributedTrainingConfig +from fairseq.distributed import utils as dist_utils + + +try: + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP # type: ignore + + has_FSDP = True +except ImportError: + FSDP = torch.nn.Module + has_FSDP = False + + +class FullyShardedDataParallel(FSDP): + """ + A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some + fairseq-specific checkpoint saving/loading logic. + + Args: + use_sharded_state (bool): if True, then ``state_dict`` will return + ``FSDP.local_state_dict`` and ``load_state_dict`` will call + ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will + return the full model weights on data parallel rank 0 (empty on + other ranks) and ``load_state_dict`` will broadcast model weights + from rank 0 to other ranks. + """ + + def __init__(self, *args, use_sharded_state: bool = False, **kwargs): + if not has_FSDP: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + super().__init__(*args, **kwargs) + self.use_sharded_state = use_sharded_state + + @property + def unwrapped_module(self) -> torch.nn.Module: + if self.flatten_parameters: + return self.module.module + else: + return self.module + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if self.use_sharded_state: + return super().local_state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + else: + if self.rank == 0: + return super().state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + else: + # We must call state_dict() due to use of communication + # primitives. But we don't use the result. + super().state_dict() + return destination or {} + + def load_state_dict(self, state_dict, strict=True, model_cfg=None): + if self.use_sharded_state: + return super().load_local_state_dict(state_dict, strict=strict) + else: + state_dict = dist_utils.broadcast_object( + state_dict, src_rank=0, group=self.process_group + ) + return super().load_state_dict(state_dict, strict=strict) + + +class DummyProcessGroup: + def __init__(self, rank: int, size: int): + self._rank = rank + self._size = size + + def rank(self) -> int: + return self._rank + + def size(self) -> int: + return self._size + + +@contextlib.contextmanager +def fsdp_enable_wrap(cfg: DistributedTrainingConfig): + try: + from fairscale.nn import enable_wrap # type: ignore + except ImportError: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + if cfg.memory_efficient_fp16: + assert cfg.fp16 # memory_efficient_fp16 should imply fp16 + group = dist_utils.get_data_parallel_group() + if group is None and cfg.distributed_world_size == 1: + group = DummyProcessGroup(rank=0, size=1) + fsdp_config = { + "process_group": group, + "reshard_after_forward": not cfg.no_reshard_after_forward, + "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16, + "fp32_reduce_scatter": cfg.fp32_reduce_scatter, + "flatten_parameters": not cfg.not_fsdp_flatten_parameters, + "cpu_offload": cfg.cpu_offload, + "compute_dtype": torch.float16 if cfg.fp16 else torch.float32, + "bucket_cap_mb": cfg.bucket_cap_mb, + "state_dict_device": torch.device("cpu"), # reduce GPU mem usage + } + with enable_wrap( + wrapper_cls=FullyShardedDataParallel, + use_sharded_state=cfg.use_sharded_state, + **fsdp_config, + ): + yield + + +def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs): + """ + Helper to wrap layers/modules in FSDP. This falls back to a no-op if + fairscale is not available. + + Args: + module (nn.Module): module to (maybe) wrap + min_num_params (int, Optional): minimum number of layer params to wrap + """ + try: + from fairscale.nn import wrap # type: ignore + + if min_num_params is not None: + num_params = sum(p.numel() for p in module.parameters()) + if num_params >= min_num_params: + return wrap(module, **kwargs) + else: + return module + else: + return wrap(module, **kwargs) + except ImportError: + return module diff --git a/fairseq/file_io.py b/fairseq/file_io.py new file mode 100644 index 0000000..8eca70a --- /dev/null +++ b/fairseq/file_io.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import shutil +from typing import List, Optional + + +logger = logging.getLogger(__file__) + + +try: + from iopath.common.file_io import g_pathmgr as IOPathManager + + try: + # [FB only - for now] AWS PathHandler for PathManager + from .fb_pathhandlers import S3PathHandler + + IOPathManager.register_handler(S3PathHandler()) + except KeyError: + logging.warning("S3PathHandler already registered.") + except ImportError: + logging.debug( + "S3PathHandler couldn't be imported. Either missing fb-only files, or boto3 module." + ) + +except ImportError: + IOPathManager = None + + +class PathManager: + """ + Wrapper for insulating OSS I/O (using Python builtin operations) from + iopath's PathManager abstraction (for transparently handling various + internal backends). + """ + + @staticmethod + def open( + path: str, + mode: str = "r", + buffering: int = -1, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, + ): + if IOPathManager: + return IOPathManager.open( + path=path, + mode=mode, + buffering=buffering, + encoding=encoding, + errors=errors, + newline=newline, + ) + return open( + path, + mode=mode, + buffering=buffering, + encoding=encoding, + errors=errors, + newline=newline, + ) + + @staticmethod + def copy(src_path: str, dst_path: str, overwrite: bool = False) -> bool: + if IOPathManager: + return IOPathManager.copy( + src_path=src_path, dst_path=dst_path, overwrite=overwrite + ) + return shutil.copyfile(src_path, dst_path) + + @staticmethod + def get_local_path(path: str, **kwargs) -> str: + if IOPathManager: + return IOPathManager.get_local_path(path, **kwargs) + return path + + @staticmethod + def exists(path: str) -> bool: + if IOPathManager: + return IOPathManager.exists(path) + return os.path.exists(path) + + @staticmethod + def isfile(path: str) -> bool: + if IOPathManager: + return IOPathManager.isfile(path) + return os.path.isfile(path) + + @staticmethod + def ls(path: str) -> List[str]: + if IOPathManager: + return IOPathManager.ls(path) + return os.listdir(path) + + @staticmethod + def mkdirs(path: str) -> None: + if IOPathManager: + return IOPathManager.mkdirs(path) + os.makedirs(path, exist_ok=True) + + @staticmethod + def rm(path: str) -> None: + if IOPathManager: + return IOPathManager.rm(path) + os.remove(path) + + @staticmethod + def chmod(path: str, mode: int) -> None: + if not PathManager.path_requires_pathmanager(path): + os.chmod(path, mode) + + @staticmethod + def register_handler(handler) -> None: + if IOPathManager: + return IOPathManager.register_handler(handler=handler) + + @staticmethod + def copy_from_local( + local_path: str, dst_path: str, overwrite: bool = False, **kwargs + ) -> None: + if IOPathManager: + return IOPathManager.copy_from_local( + local_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs + ) + return shutil.copyfile(local_path, dst_path) + + @staticmethod + def path_requires_pathmanager(path: str) -> bool: + """Do we require PathManager to access given path?""" + if IOPathManager: + for p in IOPathManager._path_handlers.keys(): + if path.startswith(p): + return True + return False + + @staticmethod + def supports_rename(path: str) -> bool: + # PathManager doesn't yet support renames + return not PathManager.path_requires_pathmanager(path) + + @staticmethod + def rename(src: str, dst: str): + os.rename(src, dst) + + """ + ioPath async PathManager methods: + """ + + @staticmethod + def opena( + path: str, + mode: str = "r", + buffering: int = -1, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, + ): + """ + Return file descriptor with asynchronous write operations. + """ + global IOPathManager + if not IOPathManager: + logging.info("ioPath is initializing PathManager.") + try: + from iopath.common.file_io import PathManager + + IOPathManager = PathManager() + except Exception: + logging.exception("Failed to initialize ioPath PathManager object.") + return IOPathManager.opena( + path=path, + mode=mode, + buffering=buffering, + encoding=encoding, + errors=errors, + newline=newline, + ) + + @staticmethod + def async_close() -> bool: + """ + Wait for files to be written and clean up asynchronous PathManager. + NOTE: `PathManager.async_close()` must be called at the end of any + script that uses `PathManager.opena(...)`. + """ + global IOPathManager + if IOPathManager: + return IOPathManager.async_close() + return False diff --git a/fairseq/logging/__init__.py b/fairseq/logging/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fairseq/logging/meters.py b/fairseq/logging/meters.py new file mode 100644 index 0000000..495bd08 --- /dev/null +++ b/fairseq/logging/meters.py @@ -0,0 +1,351 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import bisect +import time +from collections import OrderedDict +from typing import Dict, Optional + +try: + import torch + + def type_as(a, b): + if torch.is_tensor(a) and torch.is_tensor(b): + return a.to(b) + else: + return a + +except ImportError: + torch = None + + def type_as(a, b): + return a + + +try: + import numpy as np +except ImportError: + np = None + + +class Meter(object): + """Base class for Meters.""" + + def __init__(self): + pass + + def state_dict(self): + return {} + + def load_state_dict(self, state_dict): + pass + + def reset(self): + raise NotImplementedError + + @property + def smoothed_value(self) -> float: + """Smoothed value used for logging.""" + raise NotImplementedError + + +def safe_round(number, ndigits): + if hasattr(number, "__round__"): + return round(number, ndigits) + elif torch is not None and torch.is_tensor(number) and number.numel() == 1: + return safe_round(number.item(), ndigits) + elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"): + return safe_round(number.item(), ndigits) + else: + return number + + +class AverageMeter(Meter): + """Computes and stores the average and current value""" + + def __init__(self, round: Optional[int] = None): + self.round = round + self.reset() + + def reset(self): + self.val = None # most recent update + self.sum = 0 # sum from all updates + self.count = 0 # total n from all updates + + def update(self, val, n=1): + if val is not None: + self.val = val + if n > 0: + self.sum = type_as(self.sum, val) + (val * n) + self.count = type_as(self.count, n) + n + + def state_dict(self): + return { + "val": self.val, + "sum": self.sum, + "count": self.count, + "round": self.round, + } + + def load_state_dict(self, state_dict): + self.val = state_dict["val"] + self.sum = state_dict["sum"] + self.count = state_dict["count"] + self.round = state_dict.get("round", None) + + @property + def avg(self): + return self.sum / self.count if self.count > 0 else self.val + + @property + def smoothed_value(self) -> float: + val = self.avg + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class SumMeter(Meter): + """Computes and stores the sum""" + + def __init__(self, round: Optional[int] = None): + self.round = round + self.reset() + + def reset(self): + self.sum = 0 # sum from all updates + + def update(self, val): + if val is not None: + self.sum = type_as(self.sum, val) + val + + def state_dict(self): + return { + "sum": self.sum, + "round": self.round, + } + + def load_state_dict(self, state_dict): + self.sum = state_dict["sum"] + self.round = state_dict.get("round", None) + + @property + def smoothed_value(self) -> float: + val = self.sum + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class ConcatTensorMeter(Meter): + """Concatenates tensors""" + + def __init__(self, dim=0): + super().__init__() + self.reset() + self.dim = dim + + def reset(self): + self.tensor = None + + def update(self, val): + if self.tensor is None: + self.tensor = val + else: + self.tensor = torch.cat([self.tensor, val], dim=self.dim) + + def state_dict(self): + return { + "tensor": self.tensor, + } + + def load_state_dict(self, state_dict): + self.tensor = state_dict["tensor"] + + @property + def smoothed_value(self) -> float: + return [] # return a dummy value + + +class TimeMeter(Meter): + """Computes the average occurrence of some event per second""" + + def __init__( + self, + init: int = 0, + n: int = 0, + round: Optional[int] = None, + ): + self.round = round + self.reset(init, n) + + def reset(self, init=0, n=0): + self.init = init + self.start = time.perf_counter() + self.n = n + self.i = 0 + + def update(self, val=1): + self.n = type_as(self.n, val) + val + self.i += 1 + + def state_dict(self): + return { + "init": self.elapsed_time, + "n": self.n, + "round": self.round, + } + + def load_state_dict(self, state_dict): + if "start" in state_dict: + # backwards compatibility for old state_dicts + self.reset(init=state_dict["init"]) + else: + self.reset(init=state_dict["init"], n=state_dict["n"]) + self.round = state_dict.get("round", None) + + @property + def avg(self): + return self.n / self.elapsed_time + + @property + def elapsed_time(self): + return self.init + (time.perf_counter() - self.start) + + @property + def smoothed_value(self) -> float: + val = self.avg + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class StopwatchMeter(Meter): + """Computes the sum/avg duration of some event in seconds""" + + def __init__(self, round: Optional[int] = None): + self.round = round + self.sum = 0 + self.n = 0 + self.start_time = None + + def start(self): + self.start_time = time.perf_counter() + + def stop(self, n=1, prehook=None): + if self.start_time is not None: + if prehook is not None: + prehook() + delta = time.perf_counter() - self.start_time + self.sum = self.sum + delta + self.n = type_as(self.n, n) + n + + def reset(self): + self.sum = 0 # cumulative time during which stopwatch was active + self.n = 0 # total n across all start/stop + self.start() + + def state_dict(self): + return { + "sum": self.sum, + "n": self.n, + "round": self.round, + } + + def load_state_dict(self, state_dict): + self.sum = state_dict["sum"] + self.n = state_dict["n"] + self.start_time = None + self.round = state_dict.get("round", None) + + @property + def avg(self): + return self.sum / self.n if self.n > 0 else self.sum + + @property + def elapsed_time(self): + if self.start_time is None: + return 0.0 + return time.perf_counter() - self.start_time + + @property + def smoothed_value(self) -> float: + val = self.avg if self.sum > 0 else self.elapsed_time + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class MetersDict(OrderedDict): + """A sorted dictionary of :class:`Meters`. + + Meters are sorted according to a priority that is given when the + meter is first added to the dictionary. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.priorities = [] + + def __setitem__(self, key, value): + assert key not in self, "MetersDict doesn't support reassignment" + priority, value = value + bisect.insort(self.priorities, (priority, len(self.priorities), key)) + super().__setitem__(key, value) + for _, _, key in self.priorities: # reorder dict to match priorities + self.move_to_end(key) + + def add_meter(self, key, meter, priority): + self.__setitem__(key, (priority, meter)) + + def state_dict(self): + return [ + (pri, key, self[key].__class__.__name__, self[key].state_dict()) + for pri, _, key in self.priorities + # can't serialize DerivedMeter instances + if not isinstance(self[key], MetersDict._DerivedMeter) + ] + + def load_state_dict(self, state_dict): + self.clear() + self.priorities.clear() + for pri, key, meter_cls, meter_state in state_dict: + meter = globals()[meter_cls]() + meter.load_state_dict(meter_state) + self.add_meter(key, meter, pri) + + def get_smoothed_value(self, key: str) -> float: + """Get a single smoothed value.""" + meter = self[key] + if isinstance(meter, MetersDict._DerivedMeter): + return meter.fn(self) + else: + return meter.smoothed_value + + def get_smoothed_values(self) -> Dict[str, float]: + """Get all smoothed values.""" + return OrderedDict( + [ + (key, self.get_smoothed_value(key)) + for key in self.keys() + if not key.startswith("_") + ] + ) + + def reset(self): + """Reset Meter instances.""" + for meter in self.values(): + if isinstance(meter, MetersDict._DerivedMeter): + continue + meter.reset() + + class _DerivedMeter(Meter): + """A Meter whose values are derived from other Meters.""" + + def __init__(self, fn): + self.fn = fn + + def reset(self): + pass diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py new file mode 100644 index 0000000..a7b0b40 --- /dev/null +++ b/fairseq/logging/metrics.py @@ -0,0 +1,336 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +A standalone module for aggregating metrics. + +Metrics can be logged from anywhere using the `log_*` functions defined +in this module. The logged values will be aggregated dynamically based +on the aggregation context in which the logging occurs. See the +:func:`aggregate` context manager for more details. +""" + +import contextlib +import uuid +from collections import defaultdict +from typing import Callable, List, Optional + +from .meters import * + + +# Aggregation contexts are considered "active" when inside the scope +# created by the :func:`aggregate` context manager. +_aggregators = OrderedDict() +_active_aggregators = OrderedDict() +_active_aggregators_cnt = defaultdict(lambda: 0) + + +def reset() -> None: + """Reset all metrics aggregators.""" + _aggregators.clear() + _active_aggregators.clear() + _active_aggregators_cnt.clear() + + # The "default" aggregator observes all logged values. + _aggregators["default"] = MetersDict() + _active_aggregators["default"] = _aggregators["default"] + _active_aggregators_cnt["default"] = 1 + + +reset() + + +@contextlib.contextmanager +def aggregate(name: Optional[str] = None, new_root: bool = False): + """Context manager to aggregate metrics under a given name. + + Aggregations can be nested. If *new_root* is ``False``, then logged + metrics will be recorded along the entire stack of nested + aggregators, including a global "default" aggregator. If *new_root* + is ``True``, then this aggregator will be the root of a new + aggregation stack, thus bypassing any parent aggregators. + + Note that aggregation contexts are uniquely identified by their + *name* (e.g., train, valid). Creating a context with an existing + name will reuse the corresponding :class:`MetersDict` instance. + If no name is given, then a temporary aggregator will be created. + + Usage:: + + with metrics.aggregate("train"): + for step, batch in enumerate(epoch): + with metrics.aggregate("train_inner") as agg: + metrics.log_scalar("loss", get_loss(batch)) + if step % log_interval == 0: + print(agg.get_smoothed_value("loss")) + agg.reset() + print(metrics.get_smoothed_values("train")["loss"]) + + Args: + name (str): name of the aggregation. Defaults to a + random/temporary name if not given explicitly. + new_root (bool): make this aggregation the root of a new + aggregation stack. + """ + if name is None: + # generate a temporary name + name = str(uuid.uuid4()) + assert name not in _aggregators + agg = MetersDict() + else: + assert name != "default" + agg = _aggregators.setdefault(name, MetersDict()) + + if new_root: + backup_aggregators = _active_aggregators.copy() + _active_aggregators.clear() + backup_aggregators_cnt = _active_aggregators_cnt.copy() + _active_aggregators_cnt.clear() + + _active_aggregators[name] = agg + _active_aggregators_cnt[name] += 1 + + yield agg + + _active_aggregators_cnt[name] -= 1 + if _active_aggregators_cnt[name] == 0 and name in _active_aggregators: + del _active_aggregators[name] + + if new_root: + _active_aggregators.clear() + _active_aggregators.update(backup_aggregators) + _active_aggregators_cnt.clear() + _active_aggregators_cnt.update(backup_aggregators_cnt) + + +def get_active_aggregators() -> List[MetersDict]: + return list(_active_aggregators.values()) + + +def log_scalar( + key: str, + value: float, + weight: float = 1, + priority: int = 10, + round: Optional[int] = None, +): + """Log a scalar value. + + Args: + key (str): name of the field to log + value (float): value to log + weight (float): weight that this value contributes to the average. + A weight of 0 will always log the latest value. + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, AverageMeter(round=round), priority) + agg[key].update(value, weight) + + +def log_scalar_sum( + key: str, + value: float, + priority: int = 10, + round: Optional[int] = None, +): + """Log a scalar value that is summed for reporting. + + Args: + key (str): name of the field to log + value (float): value to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, SumMeter(round=round), priority) + agg[key].update(value) + + +def log_concat_tensor( + key: str, + value: torch.Tensor, + priority: int = 10, + dim: int = 0, +): + """Log a scalar value that is summed for reporting. + + Args: + key (str): name of the field to log + value (float): value to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, ConcatTensorMeter(dim=dim), priority) + agg[key].update(value) + + +def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20): + """Log a scalar value derived from other meters. + + Args: + key (str): name of the field to log + fn (Callable[[MetersDict], float]): function that takes a single + argument *meters* and returns the derived value + priority (int): smaller values are logged earlier in the output + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, MetersDict._DerivedMeter(fn), priority) + + +def log_speed( + key: str, + value: float, + priority: int = 30, + round: Optional[int] = None, +): + """Log the rate of some quantity per second. + + Args: + key (str): name of the field to log + value (float): value to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, TimeMeter(round=round), priority) + agg[key].reset() # reset meter on the first call + else: + agg[key].update(value) + + +def log_start_time(key: str, priority: int = 40, round: Optional[int] = None): + """Log the duration of some event in seconds. + + The duration will be computed once :func:`log_stop_time` is called. + + Args: + key (str): name of the field to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, StopwatchMeter(round=round), priority) + agg[key].start() + + +def log_stop_time(key: str, weight: float = 0.0, prehook=None): + """Log the duration of some event in seconds. + + The duration will be computed since :func:`log_start_time` was called. + Set weight > 0 to report the average time instead of the sum. + + Args: + key (str): name of the field to log + weight (float): weight that this time contributes to the average + prehook (function, no arguments): will be called before the timer + is stopped. For example, use prehook=torch.cuda.synchronize to + make sure all gpu operations are done before timer is stopped. + """ + for agg in get_active_aggregators(): + if key in agg: + agg[key].stop(weight, prehook) + + +def log_custom( + new_meter_fn: Callable[[], Meter], + key: str, + *args, + priority: int = 50, + **kwargs, +): + """Log using a custom Meter. + + Any extra *args* or *kwargs* will be passed through to the Meter's + *update* method. + + Args: + new_meter_fn (Callable[[], Meter]): function that returns a new + Meter instance + key (str): name of the field to log + priority (int): smaller values are logged earlier in the output + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, new_meter_fn(), priority) + agg[key].update(*args, **kwargs) + + +def reset_meter(name: str, key: str) -> None: + """Reset Meter instance aggregated under a given *name* and *key*.""" + meter = get_meter(name, key) + if meter is not None: + meter.reset() + + +def reset_meters(name: str) -> None: + """Reset Meter instances aggregated under a given *name*.""" + meters = get_meters(name) + if meters is not None: + meters.reset() + + +def get_meter(name: str, key: str) -> Meter: + """Get a single Meter instance aggregated under *name* and *key*. + + Returns: + Meter or None if no metrics have been logged under *name* and *key*. + """ + if name not in _aggregators: + return None + return _aggregators[name].get(key, None) + + +def get_meters(name: str) -> MetersDict: + """Get Meter instances aggregated under a given *name*. + + Returns: + MetersDict or None if no metrics have been logged under *name*. + """ + return _aggregators.get(name, None) + + +def get_smoothed_value(name: str, key: str) -> float: + """Get a single smoothed value. + + Raises: + KeyError: if no metrics have been logged under *name* and *key*. + """ + return _aggregators[name].get_smoothed_value(key) + + +def get_smoothed_values(name: str) -> Dict[str, float]: + """Get smoothed values aggregated under a given *name*. + + Raises: + KeyError: if no metrics have been logged under *name*. + """ + return _aggregators[name].get_smoothed_values() + + +def state_dict(): + return OrderedDict([(name, agg.state_dict()) for name, agg in _aggregators.items()]) + + +def load_state_dict(state_dict): + for name, agg_state in state_dict.items(): + _aggregators[name] = MetersDict() + _aggregators[name].load_state_dict(agg_state) + + +def xla_metrics_report(): + try: + import torch_xla.debug.metrics as met # type: ignore + + print(met.metrics_report()) + except ImportError: + return diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index e69de29..7689230 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -0,0 +1,230 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import argparse +import importlib +import os + +from contextlib import ExitStack + +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import merge_with_parent +from hydra.core.config_store import ConfigStore +from omegaconf import open_dict, OmegaConf + +from .fairseq_decoder import FairseqDecoder +from .fairseq_encoder import FairseqEncoder +from .fairseq_model import ( + BaseFairseqModel, + FairseqEncoderDecoderModel, + FairseqEncoderModel, + FairseqLanguageModel, + FairseqModel, + FairseqMultiModel, +) + + +MODEL_REGISTRY = {} +MODEL_DATACLASS_REGISTRY = {} +ARCH_MODEL_REGISTRY = {} +ARCH_MODEL_NAME_REGISTRY = {} +ARCH_MODEL_INV_REGISTRY = {} +ARCH_CONFIG_REGISTRY = {} + + +__all__ = [ + "BaseFairseqModel", + "FairseqDecoder", + "FairseqEncoder", + "FairseqEncoderDecoderModel", + "FairseqEncoderModel", + "FairseqLanguageModel", + "FairseqModel", + "FairseqMultiModel", +] + + +def build_model(cfg: FairseqDataclass, task, from_checkpoint=False): + + model = None + model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None) + + if not model_type and len(cfg) == 1: + # this is hit if config object is nested in directory that is named after model type + + model_type = next(iter(cfg)) + if model_type in MODEL_DATACLASS_REGISTRY: + cfg = cfg[model_type] + else: + raise Exception( + "Could not infer model type from directory. Please add _name field to indicate model type. " + "Available models: " + + str(MODEL_DATACLASS_REGISTRY.keys()) + + " Requested model type: " + + model_type + ) + + if model_type in ARCH_MODEL_REGISTRY: + # case 1: legacy models + model = ARCH_MODEL_REGISTRY[model_type] + elif model_type in MODEL_DATACLASS_REGISTRY: + # case 2: config-driven models + model = MODEL_REGISTRY[model_type] + + if model_type in MODEL_DATACLASS_REGISTRY: + # set defaults from dataclass. note that arch name and model name can be the same + dc = MODEL_DATACLASS_REGISTRY[model_type] + + if isinstance(cfg, argparse.Namespace): + cfg = dc.from_namespace(cfg) + else: + cfg = merge_with_parent(dc(), cfg, from_checkpoint) + else: + if model_type in ARCH_CONFIG_REGISTRY: + with open_dict(cfg) if OmegaConf.is_config(cfg) else ExitStack(): + # this calls the different "arch" functions (like base_architecture()) that you indicate + # if you specify --arch on the command line. this is only applicable to the old argparse based models + # hydra models should expose different architectures via different config files + # it will modify the cfg object and default parameters according to the arch + ARCH_CONFIG_REGISTRY[model_type](cfg) + + assert model is not None, ( + f"Could not infer model type from {cfg}. " + "Available models: {}".format(MODEL_DATACLASS_REGISTRY.keys()) + + f" Requested model type: {model_type}" + ) + + return model.build_model(cfg, task) + + +def register_model(name, dataclass=None): + """ + New model types can be added to fairseq with the :func:`register_model` + function decorator. + + For example:: + + @register_model('lstm') + class LSTM(FairseqEncoderDecoderModel): + (...) + + .. note:: All models must implement the :class:`BaseFairseqModel` interface. + Typically you will extend :class:`FairseqEncoderDecoderModel` for + sequence-to-sequence tasks or :class:`FairseqLanguageModel` for + language modeling tasks. + + Args: + name (str): the name of the model + """ + + def register_model_cls(cls): + if name in MODEL_REGISTRY: + return MODEL_REGISTRY[name] + + if not issubclass(cls, BaseFairseqModel): + raise ValueError( + "Model ({}: {}) must extend BaseFairseqModel".format(name, cls.__name__) + ) + MODEL_REGISTRY[name] = cls + if dataclass is not None and not issubclass(dataclass, FairseqDataclass): + raise ValueError( + "Dataclass {} must extend FairseqDataclass".format(dataclass) + ) + + cls.__dataclass = dataclass + if dataclass is not None: + MODEL_DATACLASS_REGISTRY[name] = dataclass + + cs = ConfigStore.instance() + node = dataclass() + node._name = name + cs.store(name=name, group="model", node=node, provider="fairseq") + + @register_model_architecture(name, name) + def noop(_): + pass + + return cls + + return register_model_cls + + +def register_model_architecture(model_name, arch_name): + """ + New model architectures can be added to fairseq with the + :func:`register_model_architecture` function decorator. After registration, + model architectures can be selected with the ``--arch`` command-line + argument. + + For example:: + + @register_model_architecture('lstm', 'lstm_luong_wmt_en_de') + def lstm_luong_wmt_en_de(cfg): + args.encoder_embed_dim = getattr(cfg.model, 'encoder_embed_dim', 1000) + (...) + + The decorated function should take a single argument *cfg*, which is a + :class:`omegaconf.DictConfig`. The decorated function should modify these + arguments in-place to match the desired architecture. + + Args: + model_name (str): the name of the Model (Model must already be + registered) + arch_name (str): the name of the model architecture (``--arch``) + """ + + def register_model_arch_fn(fn): + if model_name not in MODEL_REGISTRY: + raise ValueError( + "Cannot register model architecture for unknown model type ({})".format( + model_name + ) + ) + if arch_name in ARCH_MODEL_REGISTRY: + raise ValueError( + "Cannot register duplicate model architecture ({})".format(arch_name) + ) + if not callable(fn): + raise ValueError( + "Model architecture must be callable ({})".format(arch_name) + ) + ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name] + ARCH_MODEL_NAME_REGISTRY[arch_name] = model_name + ARCH_MODEL_INV_REGISTRY.setdefault(model_name, []).append(arch_name) + ARCH_CONFIG_REGISTRY[arch_name] = fn + return fn + + return register_model_arch_fn + + +def import_models(models_dir, namespace): + for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + importlib.import_module(namespace + "." + model_name) + + # extra `model_parser` for sphinx + if model_name in MODEL_REGISTRY: + parser = argparse.ArgumentParser(add_help=False) + group_archs = parser.add_argument_group("Named architectures") + group_archs.add_argument( + "--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name] + ) + group_args = parser.add_argument_group( + "Additional command-line arguments" + ) + MODEL_REGISTRY[model_name].add_args(group_args) + globals()[model_name + "_parser"] = parser + + +# automatically import any Python files in the models/ directory +models_dir = os.path.dirname(__file__) +import_models(models_dir, "fairseq.models") diff --git a/fairseq/models/fairseq_decoder.py b/fairseq/models/fairseq_decoder.py new file mode 100644 index 0000000..13b73d6 --- /dev/null +++ b/fairseq/models/fairseq_decoder.py @@ -0,0 +1,104 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional, Tuple + +import torch.nn as nn +from fairseq import utils +from torch import Tensor + + +class FairseqDecoder(nn.Module): + """Base class for decoders.""" + + def __init__(self, dictionary): + super().__init__() + self.dictionary = dictionary + self.onnx_trace = False + self.adaptive_softmax = None + + def forward(self, prev_output_tokens, encoder_out=None, **kwargs): + """ + Args: + prev_output_tokens (LongTensor): shifted output tokens of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (dict, optional): output from the encoder, used for + encoder-side attention + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + x, extra = self.extract_features( + prev_output_tokens, encoder_out=encoder_out, **kwargs + ) + x = self.output_layer(x) + return x, extra + + def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs): + """ + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + raise NotImplementedError + + def output_layer(self, features, **kwargs): + """ + Project features to the default output size, e.g., vocabulary size. + + Args: + features (Tensor): features returned by *extract_features*. + """ + raise NotImplementedError + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + def get_normalized_probs_scriptable( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + + if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None: + if sample is not None: + assert "target" in sample + target = sample["target"] + else: + target = None + out = self.adaptive_softmax.get_log_prob(net_output[0], target=target) + return out.exp_() if not log_probs else out + + logits = net_output[0] + if log_probs: + return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace) + else: + return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace) + + def max_positions(self): + """Maximum input length supported by the decoder.""" + return 1e6 # an arbitrary large number + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade old state dicts to work with newer code.""" + return state_dict + + def prepare_for_onnx_export_(self): + self.onnx_trace = True diff --git a/fairseq/models/fairseq_encoder.py b/fairseq/models/fairseq_encoder.py new file mode 100644 index 0000000..08cbde1 --- /dev/null +++ b/fairseq/models/fairseq_encoder.py @@ -0,0 +1,92 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, NamedTuple, Optional + +import torch +import torch.nn as nn +from torch import Tensor + + +EncoderOut = NamedTuple( + "EncoderOut", + [ + ("encoder_out", Tensor), # T x B x C + ("encoder_padding_mask", Optional[Tensor]), # B x T + ("encoder_embedding", Optional[Tensor]), # B x T x C + ("encoder_states", Optional[List[Tensor]]), # List[T x B x C] + ("src_tokens", Optional[Tensor]), # B x T + ("src_lengths", Optional[Tensor]), # B x 1 + ], +) + + +class FairseqEncoder(nn.Module): + """Base class for encoders.""" + + def __init__(self, dictionary): + super().__init__() + self.dictionary = dictionary + + def forward(self, src_tokens, src_lengths=None, **kwargs): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (LongTensor): lengths of each source sentence of shape + `(batch)` + """ + raise NotImplementedError + + def forward_torchscript(self, net_input: Dict[str, Tensor]): + """A TorchScript-compatible version of forward. + + Encoders which use additional arguments may want to override + this method for TorchScript compatibility. + """ + if torch.jit.is_scripting(): + return self.forward( + src_tokens=net_input["src_tokens"], + src_lengths=net_input["src_lengths"], + ) + else: + return self.forward_non_torchscript(net_input) + + @torch.jit.unused + def forward_non_torchscript(self, net_input: Dict[str, Tensor]): + encoder_input = { + k: v for k, v in net_input.items() if k != "prev_output_tokens" + } + return self.forward(**encoder_input) + + def reorder_encoder_out(self, encoder_out, new_order): + """ + Reorder encoder output according to `new_order`. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + `encoder_out` rearranged according to `new_order` + """ + raise NotImplementedError + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return 1e6 # an arbitrary large number + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade old state dicts to work with newer code.""" + return state_dict + + def set_num_updates(self, num_updates): + """State from trainer to pass along to model at every update.""" + + def _apply(m): + if hasattr(m, "set_num_updates") and m != self: + m.set_num_updates(num_updates) + + self.apply(_apply) diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py new file mode 100644 index 0000000..6e4ae6d --- /dev/null +++ b/fairseq/models/fairseq_model.py @@ -0,0 +1,578 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Base classes for various fairseq models. +""" + +import logging +from argparse import Namespace +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.data import Dictionary +from fairseq.dataclass.utils import ( + convert_namespace_to_omegaconf, + gen_parser_from_dataclass, +) +from fairseq.models import FairseqDecoder, FairseqEncoder +from omegaconf import DictConfig +from torch import Tensor + +logger = logging.getLogger(__name__) + + +def check_type(module, expected_type): + if hasattr(module, "unwrapped_module"): + assert isinstance( + module.unwrapped_module, expected_type + ), f"{type(module.unwrapped_module)} != {expected_type}" + else: + assert isinstance(module, expected_type), f"{type(module)} != {expected_type}" + + +class BaseFairseqModel(nn.Module): + """Base class for fairseq models.""" + + def __init__(self): + super().__init__() + self._is_generation_fast = False + + @classmethod + def add_args(cls, parser): + """Add model-specific arguments to the parser.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + # do not set defaults so that settings defaults from various architectures still works + gen_parser_from_dataclass(parser, dc(), delete_default=True) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + raise NotImplementedError("Model must implement the build_model method") + + def get_targets(self, sample, net_output): + """Get targets from either the sample or the net's output.""" + return sample["target"] + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + def get_normalized_probs_scriptable( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel""" + if hasattr(self, "decoder"): + return self.decoder.get_normalized_probs(net_output, log_probs, sample) + elif torch.is_tensor(net_output): + # syntactic sugar for simple models which don't have a decoder + # (e.g., the classification tutorial) + logits = net_output.float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + raise NotImplementedError + + def extract_features(self, *args, **kwargs): + """Similar to *forward* but only return features.""" + return self(*args, **kwargs) + + def max_positions(self): + """Maximum length supported by the model.""" + return None + + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg: Optional[DictConfig] = None, + args: Optional[Namespace] = None, + ): + """Copies parameters and buffers from *state_dict* into this module and + its descendants. + + Overrides the method in :class:`nn.Module`. Compared with that method + this additionally "upgrades" *state_dicts* from old checkpoints. + """ + + if model_cfg is None and args is not None: + logger.warn( + "using 'args' is deprecated, please update your code to use dataclass config" + ) + model_cfg = convert_namespace_to_omegaconf(args).model + + self.upgrade_state_dict(state_dict) + + from fairseq.checkpoint_utils import prune_state_dict + + new_state_dict = prune_state_dict(state_dict, model_cfg) + return super().load_state_dict(new_state_dict, strict) + + def upgrade_state_dict(self, state_dict): + """Upgrade old state dicts to work with newer code.""" + self.upgrade_state_dict_named(state_dict, "") + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade old state dicts to work with newer code. + + Args: + state_dict (dict): state dictionary to upgrade, in place + name (str): the state dict key corresponding to the current module + """ + assert state_dict is not None + + def do_upgrade(m, prefix): + if len(prefix) > 0: + prefix += "." + + for n, c in m.named_children(): + name = prefix + n + if hasattr(c, "upgrade_state_dict_named"): + c.upgrade_state_dict_named(state_dict, name) + elif hasattr(c, "upgrade_state_dict"): + c.upgrade_state_dict(state_dict) + do_upgrade(c, name) + + do_upgrade(self, name) + + def set_num_updates(self, num_updates): + """State from trainer to pass along to model at every update.""" + for m in self.modules(): + if hasattr(m, "set_num_updates") and m != self: + m.set_num_updates(num_updates) + + def set_epoch(self, epoch): + for m in self.modules(): + if hasattr(m, "set_epoch") and m != self: + m.set_epoch(epoch) + + def prepare_for_inference_(self, cfg: DictConfig): + """Prepare model for inference.""" + kwargs = {} + kwargs["beamable_mm_beam_size"] = ( + None + if getattr(cfg.generation, "no_beamable_mm", False) + else getattr(cfg.generation, "beam", 5) + ) + kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False) + if getattr(cfg.generation, "retain_dropout", False): + kwargs["retain_dropout"] = cfg.generation.retain_dropout + kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules + self.make_generation_fast_(**kwargs) + + def make_generation_fast_(self, **kwargs): + """ + Legacy entry point to optimize model for faster generation. + Prefer prepare_for_inference_. + """ + if self._is_generation_fast: + return # only apply once + self._is_generation_fast = True + + # remove weight norm from all modules in the network + def apply_remove_weight_norm(module): + try: + nn.utils.remove_weight_norm(module) + except (AttributeError, ValueError): # this module didn't have weight norm + return + + self.apply(apply_remove_weight_norm) + + def apply_make_generation_fast_(module, prefix): + if len(prefix) > 0: + prefix += "." + + base_func = BaseFairseqModel.make_generation_fast_ + for n, m in module.named_modules(): + if ( + m != self + and hasattr(m, "make_generation_fast_") + # don't call this implementation again, e.g., if + # children modules also inherit from BaseFairseqModel + and m.make_generation_fast_.__func__ is not base_func + ): + name = prefix + n + m.make_generation_fast_(name=name, **kwargs) + + apply_make_generation_fast_(self, "") + + def train(mode=True): + if mode: + raise RuntimeError("cannot train after make_generation_fast") + + # this model should no longer be used for training + self.eval() + self.train = train + + def prepare_for_onnx_export_(self, **kwargs): + """Make model exportable via ONNX trace.""" + seen = set() + + def apply_prepare_for_onnx_export_(module): + if ( + module != self + and hasattr(module, "prepare_for_onnx_export_") + and module not in seen + ): + seen.add(module) + module.prepare_for_onnx_export_(**kwargs) + + self.apply(apply_prepare_for_onnx_export_) + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + **kwargs, + ): + """ + Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model + file. Downloads and caches the pre-trained model file if needed. + + The base implementation returns a + :class:`~fairseq.hub_utils.GeneratorHubInterface`, which can be used to + generate translations or sample from language models. The underlying + :class:`~fairseq.models.FairseqModel` can be accessed via the + *generator.models* attribute. + + Other models may override this to implement custom hub interfaces. + + Args: + model_name_or_path (str): either the name of a pre-trained model to + load or a path/URL to a pre-trained model state dict + checkpoint_file (str, optional): colon-separated list of checkpoint + files in the model archive to ensemble (default: 'model.pt') + data_name_or_path (str, optional): point args.data to the archive + at the given path/URL. Can start with '.' or './' to reuse the + model archive path. + """ + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + **kwargs, + ) + logger.info(x["args"]) + return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"]) + + @classmethod + def hub_models(cls): + return {} + + +class FairseqEncoderDecoderModel(BaseFairseqModel): + """Base class for encoder-decoder models. + + Args: + encoder (FairseqEncoder): the encoder + decoder (FairseqDecoder): the decoder + """ + + def __init__(self, encoder, decoder): + super().__init__() + + self.encoder = encoder + self.decoder = decoder + + check_type(self.encoder, FairseqEncoder) + check_type(self.decoder, FairseqDecoder) + + def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): + """ + Run the forward pass for an encoder-decoder model. + + First feed a batch of source tokens through the encoder. Then, feed the + encoder output and previous decoder outputs (i.e., teacher forcing) to + the decoder to produce the next outputs:: + + encoder_out = self.encoder(src_tokens, src_lengths) + return self.decoder(prev_output_tokens, encoder_out) + + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (LongTensor): source sentence lengths of shape `(batch)` + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + decoder_out = self.decoder( + prev_output_tokens, encoder_out=encoder_out, **kwargs + ) + return decoder_out + + def forward_decoder(self, prev_output_tokens, **kwargs): + return self.decoder(prev_output_tokens, **kwargs) + + def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): + """ + Similar to *forward* but only return features. + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + features = self.decoder.extract_features( + prev_output_tokens, encoder_out=encoder_out, **kwargs + ) + return features + + def output_layer(self, features, **kwargs): + """Project features to the default output size (typically vocabulary size).""" + return self.decoder.output_layer(features, **kwargs) + + def max_positions(self): + """Maximum length supported by the model.""" + return (self.encoder.max_positions(), self.decoder.max_positions()) + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder.max_positions() + + +class FairseqModel(FairseqEncoderDecoderModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + utils.deprecation_warning( + "FairseqModel is deprecated, please use FairseqEncoderDecoderModel " + "or BaseFairseqModel instead", + stacklevel=4, + ) + + +class FairseqMultiModel(BaseFairseqModel): + """Base class for combining multiple encoder-decoder models.""" + + def __init__(self, encoders, decoders): + super().__init__() + assert encoders.keys() == decoders.keys() + self.keys = list(encoders.keys()) + for key in self.keys: + check_type(encoders[key], FairseqEncoder) + check_type(decoders[key], FairseqDecoder) + + self.models = nn.ModuleDict( + { + key: FairseqEncoderDecoderModel(encoders[key], decoders[key]) + for key in self.keys + } + ) + + @staticmethod + def build_shared_embeddings( + dicts: Dict[str, Dictionary], + langs: List[str], + embed_dim: int, + build_embedding: callable, + pretrained_embed_path: Optional[str] = None, + ): + """ + Helper function to build shared embeddings for a set of languages after + checking that all dicts corresponding to those languages are equivalent. + + Args: + dicts: Dict of lang_id to its corresponding Dictionary + langs: languages that we want to share embeddings for + embed_dim: embedding dimension + build_embedding: callable function to actually build the embedding + pretrained_embed_path: Optional path to load pretrained embeddings + """ + shared_dict = dicts[langs[0]] + if any(dicts[lang] != shared_dict for lang in langs): + raise ValueError( + "--share-*-embeddings requires a joined dictionary: " + "--share-encoder-embeddings requires a joined source " + "dictionary, --share-decoder-embeddings requires a joined " + "target dictionary, and --share-all-embeddings requires a " + "joint source + target dictionary." + ) + return build_embedding(shared_dict, embed_dim, pretrained_embed_path) + + def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): + raise NotImplementedError + + def max_positions(self): + """Maximum length supported by the model.""" + return { + key: ( + self.models[key].encoder.max_positions(), + self.models[key].decoder.max_positions(), + ) + for key in self.keys + } + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return min(model.decoder.max_positions() for model in self.models.values()) + + @property + def encoder(self): + return self.models[self.keys[0]].encoder + + @property + def decoder(self): + return self.models[self.keys[0]].decoder + + def forward_decoder(self, prev_output_tokens, **kwargs): + return self.decoder(prev_output_tokens, **kwargs) + + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg=None, + args: Optional[Namespace] = None, + ): + """Copies parameters and buffers from *state_dict* into this module and + its descendants. + + Overrides the method in :class:`nn.Module`. Compared with that method + this additionally "upgrades" *state_dicts* from old checkpoints. + """ + + if model_cfg is None and args is not None: + logger.warn( + "using 'args' is deprecated, please update your code to use dataclass config" + ) + model_cfg = convert_namespace_to_omegaconf(args).model + + self.upgrade_state_dict(state_dict) + + from fairseq.checkpoint_utils import prune_state_dict + + new_state_dict = prune_state_dict(state_dict, model_cfg) + return super().load_state_dict(new_state_dict, strict) + + +class FairseqLanguageModel(BaseFairseqModel): + """Base class for decoder-only models. + + Args: + decoder (FairseqDecoder): the decoder + """ + + def __init__(self, decoder): + super().__init__() + self.decoder = decoder + check_type(self.decoder, FairseqDecoder) + + def forward(self, src_tokens, **kwargs): + """ + Run the forward pass for a decoder-only model. + + Feeds a batch of tokens through the decoder to predict the next tokens. + + Args: + src_tokens (LongTensor): tokens on which to condition the decoder, + of shape `(batch, tgt_len)` + src_lengths (LongTensor): source sentence lengths of shape `(batch)` + + Returns: + tuple: + - the decoder's output of shape `(batch, seq_len, vocab)` + - a dictionary with any model-specific outputs + """ + return self.decoder(src_tokens, **kwargs) + + def forward_decoder(self, prev_output_tokens, **kwargs): + return self.decoder(prev_output_tokens, **kwargs) + + def extract_features(self, src_tokens, **kwargs): + """ + Similar to *forward* but only return features. + + Returns: + tuple: + - the decoder's features of shape `(batch, seq_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + return self.decoder.extract_features(src_tokens, **kwargs) + + def output_layer(self, features, **kwargs): + """Project features to the default output size (typically vocabulary size).""" + return self.decoder.output_layer(features, **kwargs) + + def max_positions(self): + """Maximum length supported by the model.""" + return self.decoder.max_positions() + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder.max_positions() + + @property + def supported_targets(self): + return {"future"} + + +class FairseqEncoderModel(BaseFairseqModel): + """Base class for encoder-only models. + + Args: + encoder (FairseqEncoder): the encoder + """ + + def __init__(self, encoder): + super().__init__() + self.encoder = encoder + check_type(self.encoder, FairseqEncoder) + + def forward(self, src_tokens, src_lengths, **kwargs): + """ + Run the forward pass for a encoder-only model. + + Feeds a batch of tokens through the encoder to generate features. + + Args: + src_tokens (LongTensor): input tokens of shape `(batch, src_len)` + src_lengths (LongTensor): source sentence lengths of shape `(batch)` + + Returns: + the encoder's output, typically of shape `(batch, src_len, features)` + """ + return self.encoder(src_tokens, src_lengths, **kwargs) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + encoder_out = net_output["encoder_out"] + if torch.is_tensor(encoder_out): + logits = encoder_out.float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + raise NotImplementedError + + def max_positions(self): + """Maximum length supported by the model.""" + return self.encoder.max_positions() diff --git a/fairseq/models/hubert/hubert.py b/fairseq/models/hubert/hubert.py index 05c8b2e..8c4b8d0 100644 --- a/fairseq/models/hubert/hubert.py +++ b/fairseq/models/hubert/hubert.py @@ -14,22 +14,21 @@ from fairseq import utils from fairseq.data.data_utils import compute_mask_indices - -# from fairseq.data.dictionary import Dictionary -# from fairseq.dataclass import ChoiceEnum, FairseqDataclass -# from fairseq.models import BaseFairseqModel, register_model -# from fairseq.models.wav2vec.wav2vec2 import ( -# EXTRACTOR_MODE_CHOICES, -# MASKING_DISTRIBUTION_CHOICES, -# LAYER_TYPE_CHOICES, -# ConvFeatureExtractionModel, -# TransformerEncoder, -# ) -# from fairseq.modules import GradMultiply, LayerNorm -# from fairseq.tasks.hubert_pretraining import ( -# HubertPretrainingConfig, -# HubertPretrainingTask, -# ) +from fairseq.data.dictionary import Dictionary +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model +from fairseq.models.wav2vec.wav2vec2 import ( + EXTRACTOR_MODE_CHOICES, + MASKING_DISTRIBUTION_CHOICES, + LAYER_TYPE_CHOICES, + ConvFeatureExtractionModel, + TransformerEncoder, +) +from fairseq.modules import GradMultiply, LayerNorm +from fairseq.tasks.hubert_pretraining import ( + HubertPretrainingConfig, + HubertPretrainingTask, +) logger = logging.getLogger(__name__) @@ -38,7 +37,7 @@ class HubertConfig(FairseqDataclass): label_rate: float = II("task.label_rate") - extractor_mode: EXTRACTOR_MODE_CHOICES = field( + extractor_mode: EXTRACTOR_MODE_CHOICES = field( # type: ignore default="default", metadata={ "help": "mode for feature extractor. default has a single group " @@ -58,10 +57,10 @@ class HubertConfig(FairseqDataclass): encoder_attention_heads: int = field( default=12, metadata={"help": "num encoder attention heads"} ) - activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( # type: ignore default="gelu", metadata={"help": "activation function to use"} ) - layer_type: LAYER_TYPE_CHOICES = field( + layer_type: LAYER_TYPE_CHOICES = field( # type: ignore default="transformer", metadata={"help": "layer type in encoder"} ) @@ -134,7 +133,7 @@ class HubertConfig(FairseqDataclass): default=0.65, metadata={"help": "probability of replacing a token with mask"}, ) - mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( # type: ignore default="static", metadata={"help": "how to choose mask length"} ) mask_other: float = field( @@ -162,7 +161,7 @@ class HubertConfig(FairseqDataclass): default=0.0, metadata={"help": "probability of replacing a feature with 0"}, ) - mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( # type: ignore default="static", metadata={"help": "how to choose mask length for channel masking"}, ) diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py index e3c0653..4679bd7 100644 --- a/fairseq/modules/__init__.py +++ b/fairseq/modules/__init__.py @@ -1,8 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .fp32_group_norm import Fp32GroupNorm +from .grad_multiply import GradMultiply +from .gumbel_vector_quantizer import GumbelVectorQuantizer from .layer_norm import Fp32LayerNorm, LayerNorm -from .multihead_attention import MultiheadAttention +from .multihead_attention import MultiheadAttention +from .same_pad import SamePad, SamePad2d +from .positional_encoding import ( + RelPositionalEncoding, +) +from .espnet_multihead_attention import ( + ESPNETMultiHeadedAttention, + RelPositionMultiHeadedAttention, + RotaryPositionMultiHeadedAttention, +) +from .transpose_last import TransposeLast __all__ = [ + "Fp32GroupNorm", "Fp32LayerNorm", + "GradMultiply", + "GumbelVectorQuantizer", "LayerNorm", "MultiheadAttention", + "RelPositionalEncoding", + "SamePad", + "SamePad2d", + "TransposeLast", + "ESPNETMultiHeadedAttention", + "RelPositionMultiHeadedAttention", + "RotaryPositionMultiHeadedAttention", ] diff --git a/fairseq/modules/checkpoint_activations.py b/fairseq/modules/checkpoint_activations.py new file mode 100644 index 0000000..aa0b592 --- /dev/null +++ b/fairseq/modules/checkpoint_activations.py @@ -0,0 +1,242 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import functools +from typing import Any, Dict, List, Tuple, Union + +import torch +import torch.utils.checkpoint as checkpoint +from fairseq import utils + + +def checkpoint_wrapper(m, offload_to_cpu=False): + """ + A friendlier wrapper for performing activation checkpointing. + + Compared to the PyTorch version, this version: + - wraps an nn.Module, so that all subsequent calls will use checkpointing + - handles keyword arguments in the forward + - handles non-Tensor outputs from the forward + + Usage:: + + checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True) + a, b = checkpointed_module(x, y=3, z=torch.Tensor([1])) + """ + # should I check whether original_forward has already been set? + assert not hasattr( + m, "precheckpoint_forward" + ), "checkpoint function has already been applied?" + m.precheckpoint_forward = m.forward + m.forward = functools.partial( + _checkpointed_forward, + m.precheckpoint_forward, # original_forward + offload_to_cpu, + ) + return m + + +def unwrap_checkpoint(m: torch.nn.Module): + """ + unwrap a module and its children from checkpoint_wrapper + """ + for module in m.modules(): + if hasattr(module, "precheckpoint_forward"): + module.forward = module.precheckpoint_forward + del module.precheckpoint_forward + if hasattr(module, "old_deepcopy_method"): + module.__deepcopy__ = module.old_deepcopy_method + del module.old_deepcopy_method + return m + + +def _checkpointed_forward(original_forward, offload_to_cpu, *args, **kwargs): + # Autograd Functions in PyTorch work best with positional args, since + # the backward must return gradients (or None) for every input argument. + # We can flatten keyword arguments to make this easier. + kwarg_keys, flat_args = pack_kwargs(*args, **kwargs) + parent_ctx_dict = {"offload": offload_to_cpu} + output = CheckpointFunction.apply( + original_forward, parent_ctx_dict, kwarg_keys, *flat_args + ) + if isinstance(output, torch.Tensor): + return output + else: + packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"] + if packed_non_tensor_outputs: + output = unpack_non_tensors(output, packed_non_tensor_outputs) + return output + + +def pack_kwargs(*args, **kwargs) -> Tuple[List[str], List[Any]]: + """ + Usage:: + + kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4) + args, kwargs = unpack_kwargs(kwarg_keys, flat_args) + assert args == [1, 2] + assert kwargs == {"a": 3, "b": 4} + """ + kwarg_keys = [] + flat_args = list(args) + for k, v in kwargs.items(): + kwarg_keys.append(k) + flat_args.append(v) + return kwarg_keys, flat_args + + +def unpack_kwargs( + kwarg_keys: List[str], flat_args: List[Any] +) -> Tuple[List[Any], Dict[str, Any]]: + if len(kwarg_keys) == 0: + return flat_args, {} + args = flat_args[: -len(kwarg_keys)] + kwargs = {k: v for k, v in zip(kwarg_keys, flat_args[-len(kwarg_keys) :])} + return args, kwargs + + +def split_non_tensors( + mixed: Union[torch.Tensor, Tuple[Any]] +) -> Tuple[Tuple[torch.Tensor], Dict[str, List[Any]]]: + """ + Usage:: + + x = torch.Tensor([1]) + y = torch.Tensor([2]) + tensors, packed_non_tensors = split_non_tensors((x, y, None, 3)) + recon = unpack_non_tensors(tensors, packed_non_tensors) + assert recon == (x, y, None, 3) + """ + if isinstance(mixed, torch.Tensor): + return (mixed,), None + tensors = [] + packed_non_tensors = {"is_tensor": [], "objects": []} + for o in mixed: + if isinstance(o, torch.Tensor): + packed_non_tensors["is_tensor"].append(True) + tensors.append(o) + else: + packed_non_tensors["is_tensor"].append(False) + packed_non_tensors["objects"].append(o) + return tuple(tensors), packed_non_tensors + + +def unpack_non_tensors( + tensors: Tuple[torch.Tensor], + packed_non_tensors: Dict[str, List[Any]], +) -> Tuple[Any]: + if packed_non_tensors is None: + return tensors + assert isinstance(packed_non_tensors, dict) + mixed = [] + is_tensor_list = packed_non_tensors["is_tensor"] + objects = packed_non_tensors["objects"] + assert len(tensors) + len(objects) == len(is_tensor_list) + obj_i = tnsr_i = 0 + for is_tensor in is_tensor_list: + if is_tensor: + mixed.append(tensors[tnsr_i]) + tnsr_i += 1 + else: + mixed.append(objects[obj_i]) + obj_i += 1 + return tuple(mixed) + + +class CheckpointFunction(torch.autograd.Function): + """Similar to the torch version, but support non-Tensor outputs. + + The caller is expected to provide a dict (*parent_ctx_dict*) that will hold + the non-Tensor outputs. These should be combined with the Tensor *outputs* + by calling ``unpack_non_tensors``. + """ + + @staticmethod + def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args): + if torch.is_grad_enabled(): # grad may be disabled, e.g., during validation + checkpoint.check_backward_validity(args) + + ctx.run_function = run_function + ctx.kwarg_keys = kwarg_keys + ctx.fwd_rng_state = utils.get_rng_state() + + tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args) + if parent_ctx_dict["offload"]: + ctx.fwd_device = tuple(x.device for x in tensor_inputs) + ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs) + tensor_inputs = tuple( + x.to(torch.device("cpu"), non_blocking=True) for x in tensor_inputs + ) + + else: + ctx.fwd_device, ctx.grad_requirements = None, None + + ctx.save_for_backward(*tensor_inputs) + ctx.packed_non_tensor_inputs = packed_non_tensor_inputs + + with torch.no_grad(): + unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args) + outputs = run_function(*unpacked_args, **unpacked_kwargs) + + if isinstance(outputs, torch.Tensor): + return outputs + else: + # Autograd Functions don't like non-Tensor outputs. We can split the + # non-Tensor and Tensor outputs, returning the former by reference + # through *parent_ctx_dict* and returning the latter directly. + outputs, packed_non_tensor_outputs = split_non_tensors(outputs) + parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "Checkpointing is not compatible with .grad(), please use .backward() if possible" + ) + + tensor_inputs: Tuple = ctx.saved_tensors + tensor_inputs = checkpoint.detach_variable(tensor_inputs) + if ctx.fwd_device is not None: + tensor_inputs = [ + t.to(ctx.fwd_device[i], non_blocking=True) + for i, t in enumerate(tensor_inputs) + ] + for i, need_grad in enumerate(ctx.grad_requirements): + tensor_inputs[i].requires_grad = need_grad + inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs) + + # Store the current states. + bwd_rng_state = utils.get_rng_state() + + # Set the states to what it used to be before the forward pass. + utils.set_rng_state(ctx.fwd_rng_state) + + with torch.enable_grad(): + unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs) + outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs) + tensor_outputs, _ = split_non_tensors(outputs) + # Set the states back to what it was at the start of this function. + utils.set_rng_state(bwd_rng_state) + + # Run backward() with only Tensors that require grad + outputs_with_grad = [] + args_with_grad = [] + for i in range(len(tensor_outputs)): + if tensor_outputs[i].requires_grad: + outputs_with_grad.append(tensor_outputs[i]) + args_with_grad.append(args[i]) + if len(outputs_with_grad) == 0: + raise RuntimeError( + "None of the outputs have requires_grad=True, " + "this checkpoint() is not necessary" + ) + + torch.autograd.backward(outputs_with_grad, args_with_grad) + + grads = tuple( + inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs + ) + return (None, None, None) + grads diff --git a/fairseq/modules/conformer_layer.py b/fairseq/modules/conformer_layer.py new file mode 100644 index 0000000..964af24 --- /dev/null +++ b/fairseq/modules/conformer_layer.py @@ -0,0 +1,301 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Optional + +import torch + +from fairseq.modules import ( + ESPNETMultiHeadedAttention, + LayerNorm, + MultiheadAttention, + RelPositionMultiHeadedAttention, + RotaryPositionMultiHeadedAttention, +) +from fairseq.utils import get_activation_fn + + +class ConvolutionModule(torch.nn.Module): + """Convolution block used in the conformer block""" + + def __init__( + self, + embed_dim, + channels, + depthwise_kernel_size, + dropout, + activation_fn="swish", + bias=False, + export=False, + ): + """ + Args: + embed_dim: Embedding dimension + channels: Number of channels in depthwise conv layers + depthwise_kernel_size: Depthwise conv layer kernel size + dropout: dropout value + activation_fn: Activation function to use after depthwise convolution kernel + bias: If bias should be added to conv layers + export: If layernorm should be exported to jit + """ + super(ConvolutionModule, self).__init__() + assert ( + depthwise_kernel_size - 1 + ) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding" + self.layer_norm = LayerNorm(embed_dim, export=export) + self.pointwise_conv1 = torch.nn.Conv1d( + embed_dim, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.glu = torch.nn.GLU(dim=1) + self.depthwise_conv = torch.nn.Conv1d( + channels, + channels, + depthwise_kernel_size, + stride=1, + padding=(depthwise_kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.batch_norm = torch.nn.BatchNorm1d(channels) + self.activation = get_activation_fn(activation_fn)(channels) + self.pointwise_conv2 = torch.nn.Conv1d( + channels, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.dropout = torch.nn.Dropout(dropout) + + def forward(self, x): + """ + Args: + x: Input of shape B X T X C + Returns: + Tensor of shape B X T X C + """ + x = self.layer_norm(x) + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = self.glu(x) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.batch_norm(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) + x = self.dropout(x) + return x.transpose(1, 2) + + +class FeedForwardModule(torch.nn.Module): + """Positionwise feed forward layer used in conformer""" + + def __init__( + self, + input_feat, + hidden_units, + dropout1, + dropout2, + activation_fn="swish", + bias=True, + ): + """ + Args: + input_feat: Input feature dimension + hidden_units: Hidden unit dimension + dropout1: dropout value for layer1 + dropout2: dropout value for layer2 + activation_fn: Name of activation function + bias: If linear layers should have bias + """ + + super(FeedForwardModule, self).__init__() + self.layer_norm = LayerNorm(input_feat) + self.w_1 = torch.nn.Linear(input_feat, hidden_units, bias=bias) + self.w_2 = torch.nn.Linear(hidden_units, input_feat, bias=bias) + self.dropout1 = torch.nn.Dropout(dropout1) + self.dropout2 = torch.nn.Dropout(dropout2) + self.activation = get_activation_fn(activation_fn)(hidden_units) + + def forward(self, x): + """ + Args: + x: Input Tensor of shape T X B X C + Returns: + Tensor of shape T X B X C + """ + x = self.layer_norm(x) + x = self.w_1(x) + x = self.activation(x) + x = self.dropout1(x) + x = self.w_2(x) + return self.dropout2(x) + + +class ConformerEncoderLayer(torch.nn.Module): + """Conformer block based on https://arxiv.org/abs/2005.08100. We currently don't support relative positional encoding in MHA""" + + def __init__( + self, + embed_dim, + ffn_embed_dim, + attention_heads, + dropout, + use_fp16, + depthwise_conv_kernel_size=31, + activation_fn="swish", + attn_type=None, + pos_enc_type="abs", + ): + """ + Args: + embed_dim: Input embedding dimension + ffn_embed_dim: FFN layer dimension + attention_heads: Number of attention heads in MHA + dropout: dropout value + depthwise_conv_kernel_size: Size of kernel in depthwise conv layer in convolution module + activation_fn: Activation function name to use in convulation block and feed forward block + attn_type: MHA implementation from ESPNET vs fairseq + pos_enc_type: Positional encoding type - abs, rope, rel_pos + """ + self.pos_enc_type = pos_enc_type + super(ConformerEncoderLayer, self).__init__() + + self.ffn1 = FeedForwardModule( + embed_dim, + ffn_embed_dim, + dropout, + dropout, + ) + + self.self_attn_layer_norm = LayerNorm(embed_dim, export=False) + self.self_attn_dropout = torch.nn.Dropout(dropout) + if attn_type == "espnet": + if self.pos_enc_type == "rel_pos": + self.self_attn = RelPositionMultiHeadedAttention( + embed_dim, + attention_heads, + dropout=dropout, + ) + elif self.pos_enc_type == "rope": + self.self_attn = RotaryPositionMultiHeadedAttention( + embed_dim, attention_heads, dropout=dropout, precision=use_fp16 + ) + elif self.pos_enc_type == "abs": + self.self_attn = ESPNETMultiHeadedAttention( + embed_dim, + attention_heads, + dropout=dropout, + ) + else: + raise Exception(f"Unsupported attention type {self.pos_enc_type}") + else: + # Default to fairseq MHA + self.self_attn = MultiheadAttention( + embed_dim, + attention_heads, + dropout=dropout, + ) + + self.conv_module = ConvolutionModule( + embed_dim=embed_dim, + channels=embed_dim, + depthwise_kernel_size=depthwise_conv_kernel_size, + dropout=dropout, + activation_fn=activation_fn, + ) + + self.ffn2 = FeedForwardModule( + embed_dim, + ffn_embed_dim, + dropout, + dropout, + activation_fn=activation_fn, + ) + self.final_layer_norm = LayerNorm(embed_dim, export=False) + + def forward( + self, + x, + encoder_padding_mask: Optional[torch.Tensor], + position_emb: Optional[torch.Tensor] = None, + ): + """ + Args: + x: Tensor of shape T X B X C + encoder_padding_mask: Optional mask tensor + positions: + Returns: + Tensor of shape T X B X C + """ + residual = x + x = self.ffn1(x) + x = x * 0.5 + residual + residual = x + x = self.self_attn_layer_norm(x) + if self.pos_enc_type == "rel_pos": + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask, + pos_emb=position_emb, + need_weights=False, + ) + else: + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask, + need_weights=False, + ) + x = self.self_attn_dropout(x) + x = x + residual + + residual = x + # TBC to BTC + x = x.transpose(0, 1) + x = self.conv_module(x) + # BTC to TBC + x = x.transpose(0, 1) + x = residual + x + + residual = x + x = self.ffn2(x) + + layer_result = x + + x = x * 0.5 + residual + + x = self.final_layer_norm(x) + return x, (attn, layer_result) + + +class ConformerWav2Vec2EncoderLayer(ConformerEncoderLayer): + """Encoder layer for Wav2vec2 encoder""" + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + att_args=None, + position_emb=None, + ): + return super().forward(x, self_attn_padding_mask, position_emb) diff --git a/fairseq/modules/espnet_multihead_attention.py b/fairseq/modules/espnet_multihead_attention.py new file mode 100644 index 0000000..82bc0d7 --- /dev/null +++ b/fairseq/modules/espnet_multihead_attention.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Multi-Head Attention layer definition.""" + +import math + +import torch +from torch import nn + +from fairseq.modules.rotary_positional_embedding import ( + RotaryPositionalEmbedding, + apply_rotary_pos_emb, +) + + +class ESPNETMultiHeadedAttention(nn.Module): + """Multi-Head Attention layer. + Args: + n_head: The number of heads. + n_feat: The number of features. + dropout: Dropout rate. + """ + + def __init__(self, n_feat, n_head, dropout): + """Construct an MultiHeadedAttention object.""" + super(ESPNETMultiHeadedAttention, self).__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = nn.Dropout(p=dropout) + + def forward_qkv(self, query, key, value, **kwargs): + """Transform query, key and value. + Args: + query: Query tensor B X T1 X C + key: Key tensor B X T2 X C + value: Value tensor B X T2 X C + Returns: + torch.Tensor: Transformed query tensor B X n_head X T1 X d_k + torch.Tensor: Transformed key tensor B X n_head X T2 X d_k + torch.Tensor: Transformed value tensor B X n_head X T2 X d_k + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + return q, k, v + + def forward_attention(self, value, scores, mask): + """Compute attention context vector. + Args: + value: Transformed value B X n_head X T2 X d_k. + scores: Attention score B X n_head X T1 X T2 + mask: Mask T2 X B + Returns: + torch.Tensor: Transformed value B X T1 X d_model + weighted by the attention score B X T1 X T2 + """ + n_batch = value.size(0) + if mask is not None: + scores = scores.masked_fill( + mask.unsqueeze(1).unsqueeze(2).to(bool), + float("-inf"), # (batch, head, time1, time2) + ) + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + else: + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query, key, value, key_padding_mask=None, **kwargs): + """Compute scaled dot product attention. + Args: + query (torch.Tensor): Query tensor T X B X C + key (torch.Tensor): Key tensor T X B X C + value (torch.Tensor): Value tensor T X B X C + mask (torch.Tensor): Mask tensor T X B + Returns: + torch.Tensor: Output tensor T X B X D. + """ + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + scores = self.forward_attention(v, scores, key_padding_mask) + scores = scores.transpose(0, 1) + return scores, None + + +class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head: The number of heads. + n_feat: The number of features. + dropout: Dropout rate. + zero_triu: Whether to zero the upper triangular part of attention matrix. + """ + + def __init__(self, n_feat, n_head, dropout, zero_triu=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_feat, n_head, dropout) + self.zero_triu = zero_triu + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.zeros(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.zeros(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x): + """Compute relative positional encoding. + Args: + x: Input tensor B X n_head X T X 2T-1 + Returns: + torch.Tensor: Output tensor. + """ + zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x)[ + :, :, :, : x.size(-1) // 2 + 1 + ] # only keep the positions from 0 to time2 + + if self.zero_triu: + ones = torch.ones((x.size(2), x.size(3)), device=x.device) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs): + """Compute scaled dot product attention. + Args: + query: Query tensor T X B X C + key: Key tensor T X B X C + value: Value tensor T X B X C + pos_emb: Positional embedding tensor B X 2T-1 X C + key_padding_mask: Mask tensor T X B + Returns: + torch.Tensor: Output tensor T X B X C. + """ + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + pos_emb = pos_emb.transpose(0, 1) + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k + ) # (batch, head, time1, time2) + + scores = self.forward_attention(v, scores, key_padding_mask) + scores = scores.transpose(0, 1) + return scores, None + + +class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention): + def __init__( + self, + n_feat, + n_head, + dropout, + precision, + rotary_emd_base=10000, + ): + """Construct an RotaryPositionMultiHeadedAttention object.""" + super().__init__(n_feat, n_head, dropout) + precision = torch.float + self.rotary_ndims = self.d_k # also try self.d_k//2 + if precision == "fp16": + precision = torch.half + + self.rotary_emb = RotaryPositionalEmbedding( + self.rotary_ndims, base=rotary_emd_base, precision=precision + ) + + def forward(self, query, key, value, key_padding_mask=None, **kwargs): + """Compute rotary position attention. + Args: + query: Query tensor T X B X C + key: Key tensor T X B X C + value: Value tensor T X B X C + key_padding_mask: Mask tensor T X B + Returns: + torch.Tensor: Output tensor T X B X D. + Notes: + Assumes self attn + """ + + T, B, C = value.size() + query = query.view(T, B, self.h, self.d_k) + key = key.view(T, B, self.h, self.d_k) + value = value.view(T, B, self.h, self.d_k) + cos, sin = self.rotary_emb(value, seq_len=T) + query, key = apply_rotary_pos_emb( + query, key, cos, sin, offset=0 + ) # offset is based on layer_past + + query = query.view(T, B, self.h * self.d_k) + key = key.view(T, B, self.h * self.d_k) + value = value.view(T, B, self.h * self.d_k) + + # TBD to BTD + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + scores = self.forward_attention(v, scores, key_padding_mask) + scores = scores.transpose(0, 1) + return scores, None diff --git a/fairseq/modules/fp32_group_norm.py b/fairseq/modules/fp32_group_norm.py new file mode 100644 index 0000000..d03aac0 --- /dev/null +++ b/fairseq/modules/fp32_group_norm.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Layer norm done in fp32 (for fp16 training) +""" + +import torch.nn as nn +import torch.nn.functional as F + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) diff --git a/fairseq/modules/grad_multiply.py b/fairseq/modules/grad_multiply.py new file mode 100644 index 0000000..08d15f5 --- /dev/null +++ b/fairseq/modules/grad_multiply.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py new file mode 100644 index 0000000..867b019 --- /dev/null +++ b/fairseq/modules/gumbel_vector_quantizer.py @@ -0,0 +1,212 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GumbelVectorQuantizer(nn.Module): + def __init__( + self, + dim, + num_vars, + temp, + groups, + combine_groups, + vq_dim, + time_first, + activation=nn.GELU(), + weight_proj_depth=1, + weight_proj_factor=1, + hard=True, + std=0, + ): + """Vector quantization using gumbel softmax + + Args: + dim: input dimension (channels) + num_vars: number of quantized vectors per group + temp: temperature for training. this should be a tuple of 3 elements: (start, stop, decay factor) + groups: number of groups for vector quantization + combine_groups: whether to use the vectors for all groups + vq_dim: dimensionality of the resulting quantized vector + time_first: if true, expect input in BxTxC format, otherwise in BxCxT + activation: what activation to use (should be a module). this is only used if weight_proj_depth is > 1 + weight_proj_depth: number of layers (with activation in between) to project input before computing logits + weight_proj_factor: this is used only if weight_proj_depth is > 1. scales the inner dimensionality of + projections by this factor + """ + super().__init__() + + self.groups = groups + self.combine_groups = combine_groups + self.input_dim = dim + self.num_vars = num_vars + self.time_first = time_first + self.hard = hard + + assert ( + vq_dim % groups == 0 + ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation" + + var_dim = vq_dim // groups + num_groups = groups if not combine_groups else 1 + + self.vars = nn.Parameter(torch.FloatTensor(1, num_groups * num_vars, var_dim)) + if std == 0: + nn.init.uniform_(self.vars) + else: + nn.init.normal_(self.vars, mean=0, std=std) + + if weight_proj_depth > 1: + + def block(input_dim, output_dim): + return nn.Sequential(nn.Linear(input_dim, output_dim), activation) + + inner_dim = self.input_dim * weight_proj_factor + self.weight_proj = nn.Sequential( + *[ + block(self.input_dim if i == 0 else inner_dim, inner_dim) + for i in range(weight_proj_depth - 1) + ], + nn.Linear(inner_dim, groups * num_vars), + ) + else: + self.weight_proj = nn.Linear(self.input_dim, groups * num_vars) + nn.init.normal_(self.weight_proj.weight, mean=0, std=1) + nn.init.zeros_(self.weight_proj.bias) + + if isinstance(temp, str): + import ast + + temp = ast.literal_eval(temp) + assert len(temp) == 3, f"{temp}, {len(temp)}" + + self.max_temp, self.min_temp, self.temp_decay = temp + self.curr_temp = self.max_temp + self.codebook_indices = None + + def set_num_updates(self, num_updates): + self.curr_temp = max( + self.max_temp * self.temp_decay**num_updates, self.min_temp + ) + + def get_codebook_indices(self): + if self.codebook_indices is None: + from itertools import product + + p = [range(self.num_vars)] * self.groups + inds = list(product(*p)) + self.codebook_indices = torch.tensor( + inds, dtype=torch.long, device=self.vars.device + ).flatten() + + if not self.combine_groups: + self.codebook_indices = self.codebook_indices.view( + self.num_vars**self.groups, -1 + ) + for b in range(1, self.groups): + self.codebook_indices[:, b] += self.num_vars * b + self.codebook_indices = self.codebook_indices.flatten() + return self.codebook_indices + + def codebook(self): + indices = self.get_codebook_indices() + return ( + self.vars.squeeze(0) + .index_select(0, indices) + .view(self.num_vars**self.groups, -1) + ) + + def sample_from_codebook(self, b, n): + indices = self.get_codebook_indices() + indices = indices.view(-1, self.groups) + cb_size = indices.size(0) + assert ( + n < cb_size + ), f"sample size {n} is greater than size of codebook {cb_size}" + sample_idx = torch.randint(low=0, high=cb_size, size=(b * n,)) + indices = indices[sample_idx] + + z = self.vars.squeeze(0).index_select(0, indices.flatten()).view(b, n, -1) + return z + + def to_codebook_index(self, indices): + res = indices.new_full(indices.shape[:-1], 0) + for i in range(self.groups): + exponent = self.groups - i - 1 + res += indices[..., i] * (self.num_vars**exponent) + return res + + def forward_idx(self, x): + res = self.forward(x, produce_targets=True) + return res["x"], res["targets"] + + def forward(self, x, produce_targets=False): + + result = {"num_vars": self.num_vars * self.groups} + + if not self.time_first: + x = x.transpose(1, 2) + + bsz, tsz, fsz = x.shape + x = x.reshape(-1, fsz) + x = self.weight_proj(x) + x = x.view(bsz * tsz * self.groups, -1) + + with torch.no_grad(): + _, k = x.max(-1) + hard_x = ( + x.new_zeros(*x.shape) + .scatter_(-1, k.view(-1, 1), 1.0) + .view(bsz * tsz, self.groups, -1) + ) + hard_probs = torch.mean(hard_x.float(), dim=0) + result["code_perplexity"] = torch.exp( + -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) + ).sum() + + avg_probs = torch.softmax( + x.view(bsz * tsz, self.groups, -1).float(), dim=-1 + ).mean(dim=0) + result["prob_perplexity"] = torch.exp( + -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1) + ).sum() + + result["temp"] = self.curr_temp + + if self.training: + x = F.gumbel_softmax(x.float(), tau=self.curr_temp, hard=self.hard).type_as( + x + ) + else: + x = hard_x + + x = x.view(bsz * tsz, -1) + + vars = self.vars + if self.combine_groups: + vars = vars.repeat(1, self.groups, 1) + + if produce_targets: + result["targets"] = ( + x.view(bsz * tsz * self.groups, -1) + .argmax(dim=-1) + .view(bsz, tsz, self.groups) + .detach() + ) + + x = x.unsqueeze(-1) * vars + x = x.view(bsz * tsz, self.groups, self.num_vars, -1) + x = x.sum(-2) + x = x.view(bsz, tsz, -1) + + if not self.time_first: + x = x.transpose(1, 2) # BTC -> BCT + + result["x"] = x + + return result diff --git a/fairseq/modules/positional_encoding.py b/fairseq/modules/positional_encoding.py new file mode 100644 index 0000000..67f6353 --- /dev/null +++ b/fairseq/modules/positional_encoding.py @@ -0,0 +1,129 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +import math +import torch + + +class PositionalEncoding(nn.Module): + """Positional encoding. + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + reverse: Whether to reverse the input position. + """ + + def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.reverse = reverse + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + if self.reverse: + position = torch.arange( + x.size(1) - 1, -1, -1.0, dtype=torch.float32 + ).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + Args: + x (torch.Tensor): Input tensor B X T X C + Returns: + torch.Tensor: Encoded tensor B X T X C + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class RelPositionalEncoding(nn.Module): + """Relative positional encoding module (new implementation). + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + """ + + def __init__(self, max_len, d_model): + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i self.seq_len_cached: + self.seq_len_cached = seq_len + t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1)) + self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1)) + return self.cos_cached, self.sin_cached + +# rotary pos emb helpers: +def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat( + (-x2, x1), dim=x1.ndim - 1 + ) # dim=-1 triggers a bug in earlier torch versions + + +def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): + cos, sin = ( + cos[offset : q.shape[0] + offset, ...], + sin[offset : q.shape[0] + offset, ...], + ) + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) diff --git a/fairseq/modules/same_pad.py b/fairseq/modules/same_pad.py new file mode 100644 index 0000000..a3ce413 --- /dev/null +++ b/fairseq/modules/same_pad.py @@ -0,0 +1,33 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from torch import nn + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class SamePad2d(nn.Module): + def __init__(self, kernel_size): + super().__init__() + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + assert len(x.size()) == 4 + if self.remove > 0: + x = x[:, :, : -self.remove, : -self.remove] + return x diff --git a/fairseq/modules/transpose_last.py b/fairseq/modules/transpose_last.py new file mode 100644 index 0000000..d7cca9a --- /dev/null +++ b/fairseq/modules/transpose_last.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +transpose last 2 dimensions of the input +""" + +import torch.nn as nn + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None, tranpose_dim=-2): + super().__init__() + self.deconstruct_idx = deconstruct_idx + self.tranpose_dim = tranpose_dim + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(self.tranpose_dim, -1) diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fairseq/optim/amp_optimizer.py b/fairseq/optim/amp_optimizer.py new file mode 100644 index 0000000..cfe57d0 --- /dev/null +++ b/fairseq/optim/amp_optimizer.py @@ -0,0 +1,106 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +from fairseq import optim +from omegaconf import DictConfig + +logger = logging.getLogger(__name__) + + +class AMPOptimizer(optim.FairseqOptimizer): + """ + Wrap an *optimizer* to support AMP (automatic mixed precision) training. + """ + + def __init__(self, cfg: DictConfig, params, fp32_optimizer, **kwargs): + super().__init__(cfg.optimizer) + self.fp32_optimizer = fp32_optimizer + amp_kwargs = {"init_scale": cfg.common.fp16_init_scale} + if getattr(cfg.common, "amp_scale_window", None) is not None: + amp_kwargs["growth_interval"] = cfg.common.amp_init_scale + self._grad_scaler = torch.cuda.amp.GradScaler(**amp_kwargs) + self.min_loss_scale = cfg.common.min_loss_scale + + @classmethod + def build_optimizer(cls, cfg: DictConfig, params, **kwargs): + """ + Args: + cfg (omegaconf.DictConfig): fairseq args + params (iterable): iterable of parameters to optimize + """ + fp32_optimizer = optim.build_optimizer(cfg.optimizer, params) + return cls(cfg, params, fp32_optimizer, **kwargs) + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves. + + Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this + function additionally dynamically scales the loss to avoid gradient + underflow. + """ + self._grad_scaler.scale(loss).backward() + + def step(self): + self.scaler.step(self.fp32_optimizer) + self.scaler.update() + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm.""" + self.scaler.unscale_(self.optimizer) + grad_norm = self.fp32_optimizer.clip_grad_norm(max_norm, aggregate_norm_fn) + if not torch.isfinite(grad_norm).all(): + new_loss_scale = self.next_loss_scale + if new_loss_scale <= self.min_loss_scale: + raise FloatingPointError( + ( + "AMP: Minimum loss scale reached ({}). Your loss is probably exploding. " + "Try restarting training or use fp32. {}" + ).format(self.min_loss_scale, new_loss_scale) + ) + else: + logger.info( + "AMP: overflow detected, setting scale to " f"to {new_loss_scale}" + ) + return grad_norm + + @property + def scaler(self): + return self._grad_scaler + + @property + def next_loss_scale(self): + return self.scaler.get_scale() * self.scaler.get_backoff_factor() + + @property + def optimizer(self): + return self.fp32_optimizer.optimizer + + @optimizer.setter + def optimizer(self, optimizer): + self.fp32_optimizer.optimizer = optimizer + + @property + def lr_scheduler(self): + return getattr(self.fp32_optimizer, "lr_scheduler", None) + + @property + def optimizer_config(self): + return self.fp32_optimizer.optimizer_config + + def get_lr(self): + return self.fp32_optimizer.get_lr() + + def set_lr(self, lr): + self.fp32_optimizer.set_lr(lr) + + def all_reduce_grads(self, module): + self.fp32_optimizer.all_reduce_grads(module) + + @property + def supports_flat_params(self): + return self.fp32_optimizer.supports_flat_params diff --git a/fairseq/registry.py b/fairseq/registry.py new file mode 100644 index 0000000..904ffcd --- /dev/null +++ b/fairseq/registry.py @@ -0,0 +1,104 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from argparse import Namespace + +from typing import Union +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import merge_with_parent +from hydra.core.config_store import ConfigStore +from omegaconf import DictConfig + +REGISTRIES = {} + + +def setup_registry(registry_name: str, base_class=None, default=None, required=False): + assert registry_name.startswith("--") + registry_name = registry_name[2:].replace("-", "_") + + REGISTRY = {} + REGISTRY_CLASS_NAMES = set() + DATACLASS_REGISTRY = {} + + # maintain a registry of all registries + if registry_name in REGISTRIES: + return # registry already exists + REGISTRIES[registry_name] = { + "registry": REGISTRY, + "default": default, + "dataclass_registry": DATACLASS_REGISTRY, + } + + def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs): + if isinstance(cfg, DictConfig): + choice = cfg._name + + if choice and choice in DATACLASS_REGISTRY: + from_checkpoint = extra_kwargs.get("from_checkpoint", False) + dc = DATACLASS_REGISTRY[choice] + cfg = merge_with_parent(dc(), cfg, remove_missing=from_checkpoint) + elif isinstance(cfg, str): + choice = cfg + if choice in DATACLASS_REGISTRY: + cfg = DATACLASS_REGISTRY[choice]() + else: + choice = getattr(cfg, registry_name, None) + if choice in DATACLASS_REGISTRY: + cfg = DATACLASS_REGISTRY[choice].from_namespace(cfg) + + if choice is None: + if required: + raise ValueError("{} is required!".format(registry_name)) + return None + + cls = REGISTRY[choice] + if hasattr(cls, "build_" + registry_name): + builder = getattr(cls, "build_" + registry_name) + else: + builder = cls + + if "from_checkpoint" in extra_kwargs: + del extra_kwargs["from_checkpoint"] + + return builder(cfg, *extra_args, **extra_kwargs) + + def register_x(name, dataclass=None): + def register_x_cls(cls): + if name in REGISTRY: + raise ValueError( + "Cannot register duplicate {} ({})".format(registry_name, name) + ) + if cls.__name__ in REGISTRY_CLASS_NAMES: + raise ValueError( + "Cannot register {} with duplicate class name ({})".format( + registry_name, cls.__name__ + ) + ) + if base_class is not None and not issubclass(cls, base_class): + raise ValueError( + "{} must extend {}".format(cls.__name__, base_class.__name__) + ) + + if dataclass is not None and not issubclass(dataclass, FairseqDataclass): + raise ValueError( + "Dataclass {} must extend FairseqDataclass".format(dataclass) + ) + + cls.__dataclass = dataclass + if cls.__dataclass is not None: + DATACLASS_REGISTRY[name] = cls.__dataclass + + cs = ConfigStore.instance() + node = dataclass() + node._name = name + cs.store(name=name, group=registry_name, node=node, provider="fairseq") + + REGISTRY[name] = cls + + return cls + + return register_x_cls + + return build_x, register_x, REGISTRY, DATACLASS_REGISTRY diff --git a/fairseq/search.py b/fairseq/search.py new file mode 100644 index 0000000..c7378bb --- /dev/null +++ b/fairseq/search.py @@ -0,0 +1,892 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from typing import List, Optional + +import torch +import torch.nn as nn +from fairseq.token_generation_constraints import ( + ConstraintState, + OrderedConstraintState, + UnorderedConstraintState, +) +from torch import Tensor + + +class Search(nn.Module): + def __init__(self, tgt_dict): + super().__init__() + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() + self.vocab_size = len(tgt_dict) + self.src_lengths = torch.tensor(-1) + self.supports_constraints = False + self.stop_on_max_len = False + + def step( + self, step, lprobs, scores, prev_output_tokens=None, original_batch_idxs=None + ): + """Take a single search step. + + Args: + step: the current search step, starting at 0 + lprobs: (bsz x input_beam_size x vocab_size) + the model's log-probabilities over the vocabulary at the current step + scores: (bsz x input_beam_size x step) + the historical model scores of each hypothesis up to this point + prev_output_tokens: (bsz x step) + the previously generated oputput tokens + original_batch_idxs: (bsz) + the tensor with the batch indices, in the range [0, bsz) + this is useful in case there has been applied a re-ordering + and we need to know the orignal indices + + Return: A tuple of (scores, indices, beams) where: + scores: (bsz x output_beam_size) + the scores of the chosen elements; output_beam_size can be + larger than input_beam_size, e.g., we may return + 2*input_beam_size to account for EOS + indices: (bsz x output_beam_size) + the indices of the chosen elements + beams: (bsz x output_beam_size) + the hypothesis ids of the chosen elements, in the range [0, input_beam_size) + """ + raise NotImplementedError + + @torch.jit.export + def set_src_lengths(self, src_lengths): + self.src_lengths = src_lengths + + @torch.jit.export + def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int): + """Initialize constraint states for constrained decoding (if supported). + + Args: + batch_constraints: (torch.Tensor, optional) + the list of constraints, in packed form + beam_size: (int) + the beam size + Returns: + *encoder_out* rearranged according to *new_order* + """ + pass + + def prune_sentences(self, batch_idxs: Tensor): + """ + Removes constraint states for completed sentences (if supported). + This is called from sequence_generator._generate() when sentences are + deleted from the batch. + + Args: + batch_idxs: Indices of *sentences* whose constraint state should be *kept*. + """ + pass + + def update_constraints(self, active_hypos: Tensor): + """ + Updates the constraint states by selecting the beam items that are retained. + This is called at each time step of sequence_generator._generate() when + the set of 2 * {beam_size} candidate hypotheses are reduced to the beam size. + + Args: + active_hypos: (batch size, beam size) + list of integers denoting, for each sentence, which beam candidate items + should be kept. + """ + pass + + +class BeamSearch(Search): + def __init__(self, tgt_dict): + super().__init__(tgt_dict) + self.constraint_states = None + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores: Optional[Tensor], + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + candidate_multiple: int = 2, + ): + bsz, beam_size, vocab_size = lprobs.size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(bsz, -1), + k=min( + # Take the best `candidate_muliple`(default 2) x beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + candidate_multiple * beam_size, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + ) + scores_buf = top_prediction[0] + indices_buf = top_prediction[1] + # Project back into relative indices and beams + beams_buf = torch.div(indices_buf, vocab_size, rounding_mode="trunc") + indices_buf = indices_buf.fmod(vocab_size) + + # At this point, beams_buf and indices_buf are single-dim and contain relative indices + return scores_buf, indices_buf, beams_buf + + +class PrefixConstrainedBeamSearch(Search): + def __init__(self, tgt_dict, prefix_allowed_tokens_fn): + super().__init__(tgt_dict) + self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn + self.stop_on_max_len = True + + @torch.jit.export + def apply_mask(self, x, prev_output_tokens, original_batch_idxs): + beam_size = x.shape[0] // original_batch_idxs.shape[0] + original_batch_idxs = ( + original_batch_idxs.unsqueeze(-1).repeat((1, beam_size)).flatten().tolist() + ) + + mask = torch.full_like(x, -math.inf) + for sent_i, (sent, batch_i) in enumerate( + zip(prev_output_tokens, original_batch_idxs) + ): + mask[sent_i, :, self.prefix_allowed_tokens_fn(batch_i, sent)] = 0 + + return mask + + @torch.jit.export + def step( + self, + step: int, + lprobs: Tensor, + scores: Tensor, + prev_output_tokens: Tensor, + original_batch_idxs: Tensor, + ): + bsz, beam_size, vocab_size = lprobs.size() + + lprobs += self.apply_mask( + lprobs.view(bsz * beam_size, 1, vocab_size), + prev_output_tokens, + original_batch_idxs, + ).view(bsz, beam_size, vocab_size) + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(bsz, -1), + k=min( + # Take the best beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + ) + scores_buf = top_prediction[0] + indices_buf = top_prediction[1] + beams_buf = indices_buf // vocab_size + indices_buf = indices_buf.fmod(vocab_size) + return scores_buf, indices_buf, beams_buf + + +class LexicallyConstrainedBeamSearch(Search): + """Implements lexically constrained beam search as described in + + Fast Lexically Constrained Decoding with Dynamic Beam + Allocation for Neural Machine Translation. Post & Vilar, + NAACL 2018. https://www.aclweb.org/anthology/N18-1119/ + + and + + Improved Lexically Constrained Decoding for Translation and + Monolingual Rewriting. Hu et al, NAACL + 2019. https://www.aclweb.org/anthology/N19-1090/ + + This is accomplished by maintaining, for each beam hypothesis, a + ConstraintState object (see constraints.py) that tracks which + constraints have been generated and using this information to + shape the beam for each input sentence. + """ + + def __init__(self, tgt_dict, representation): + super().__init__(tgt_dict) + self.representation = representation + self.vocab_size = len(tgt_dict) + self.num_cands = 0 + self.supports_constraints = True + + @torch.jit.export + def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int): + self.constraint_states = [] + for constraint_tensor in batch_constraints: + if self.representation == "ordered": + constraint_state = OrderedConstraintState.create(constraint_tensor) + elif self.representation == "unordered": + constraint_state = UnorderedConstraintState.create(constraint_tensor) + + self.constraint_states.append([constraint_state for i in range(beam_size)]) + + @torch.jit.export + def prune_sentences(self, batch_idxs: Tensor): + self.constraint_states = [ + self.constraint_states[i] for i in batch_idxs.tolist() + ] + + @torch.jit.export + def update_constraints(self, active_hypos: Tensor): + if self.constraint_states: + batch_size = active_hypos.size(0) + for sentid in range(batch_size): + self.constraint_states[sentid] = [ + self.constraint_states[sentid][i] for i in active_hypos[sentid] + ] + + @torch.jit.export + def step( + self, + step: int, + lprobs: Tensor, + scores: Optional[Tensor], + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + """ + A constrained step builds a large candidates list from the following: + - the top 2 * {beam_size} items over the whole beam + - for each item in the beam + - the top {each_k} (default 1) + - all next constraints + We then compute the constrained state of each beam item, and assign + stripe codes: 0 to the best in each bank, 1 to the 2nd-best, and so + on. We then sort by (stripe, score), and truncate the list at + 2 * beam size. + + Args: + step: the decoder step + lprobs: (batch size, beam size, target vocab) + the target-vocab distributions for each item in the beam. + Retrun: A tuple of (scores, indices, beams, constraints) where: + scores: (batch, output beam size) + the scores of the chosen elements + indices: (batch, output beam size) + the target vocab indices of the chosen elements + beams: (batch, output beam size) + the 0-indexed hypothesis ids of the chosen elements + constraints: (batch, output beam size) + the new constraint states + """ + each_k = 1 + device = lprobs.device + + batch_size, beam_size, vocab_size = lprobs.size() + + self.num_cands = min( + # Just take the k-best. We'll get another k from the 1-best from each + # row, plus more from the constraints + beam_size * 2, + lprobs.view(batch_size, -1).size(1) - 1, # -1 so we never select pad + ) + + # STEP 0: Preliminary. Prevent EOS for unfinished hyps across all batch items + constraint_states = self.constraint_states + if constraint_states and step > 0: + not_finished_indices = [] + for sentno, sent_constraints in enumerate(constraint_states): + for beamno, state in enumerate(sent_constraints): + index = sentno * beam_size + beamno + if not state.finished: + not_finished_indices.append(index) + not_finished_indices = torch.tensor(not_finished_indices) + if not_finished_indices.numel() > 0: + lprobs.view(batch_size * beam_size, -1)[ + not_finished_indices, self.eos + ] = -math.inf + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam entry for each batch item + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(batch_size, -1), + self.num_cands, + ) + scores_buf, indices_buf = top_prediction + # Project back into relative indices and beams + beams_buf = indices_buf // vocab_size + indices_buf = indices_buf.fmod(vocab_size) + + # Short circuit if there are no constraints in this batch + if not constraint_states: + return scores_buf, indices_buf, beams_buf + + # STEP 1: get top-1 from each hypothesis across all sentences in the batch + if step > 0: + top_scores, top_indices = torch.topk( + lprobs.view(batch_size * beam_size, -1), + k=each_k, + dim=1, + ) + top_scores = top_scores.view(batch_size, -1) + top_indices = top_indices.view(batch_size, -1) + scores_buf = torch.cat((scores_buf, top_scores), dim=1) + indices_buf = torch.cat((indices_buf, top_indices), dim=1) + new_beams = torch.arange(0, beam_size, device=device).repeat(batch_size, 1) + beams_buf = torch.cat((beams_buf, new_beams), dim=1) + + # Now, process sentences in the batch one by one. + new_scores_buf = torch.zeros((batch_size, 2 * beam_size), device=device) + new_indices_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long() + new_beams_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long() + for sentno, states in enumerate(constraint_states): + scores, indices, beams, new_states = self.step_sentence( + step, + sentno, + lprobs[sentno], + constraint_states[sentno], + beams_buf[sentno].clone(), + indices_buf[sentno].clone(), + scores_buf[sentno].clone(), + ) + new_scores_buf[sentno] = scores + new_indices_buf[sentno] = indices + new_beams_buf[sentno] = beams + self.constraint_states[sentno] = new_states + + return new_scores_buf, new_indices_buf, new_beams_buf + + @torch.jit.export + def step_sentence( + self, + step: int, + sentno: int, + lprobs: Tensor, + constraint_states: List[List[ConstraintState]], + beams_buf: Tensor, + indices_buf: Tensor, + scores_buf: Tensor, + ): + """Does per-sentence processing. Adds all constraints for each + hypothesis to the list of candidates; then removes duplicates, + sorts, and dynamically stripes across the banks. All tensor inputs + are collapsed to those pertaining to a single input sentence. + """ + device = lprobs.device + + # STEP 2: Add all constraints for each beam item + for beamno, state in enumerate(constraint_states): + next_tokens = torch.tensor(list(state.next_tokens()), device=device).long() + if next_tokens.numel() != 0: + indices_buf = torch.cat((indices_buf, next_tokens)) + next_beams = ( + torch.tensor(beamno, device=device) + .repeat(next_tokens.size(0)) + .long() + ) + beams_buf = torch.cat((beams_buf, next_beams)) + next_values = lprobs[beamno].take(next_tokens.view(-1)) + scores_buf = torch.cat((scores_buf, next_values)) + + # At the 0th time step, there is just one beam item + if step == 0: + break + + # STEP 3: Compute the "bank" for each candidate. This is the + # number of constraints it's generated. We need this so that + # we can do round-robin allocation of the beam across these + # banks. If C is the number of constraints, we select the best + # item in bank C, then the best in bank C-1, etc, followed by + # the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so + # on, until the maximum beam size. We accomplish this by + # creating a sort key and striping across the banks. + + # Compute the new states for all candidates + cands_size = indices_buf.size(0) + constraint_states = [ + constraint_states[beams_buf[i]].advance(indices_buf[i]) + for i in range(cands_size) + ] + + banks = torch.tensor([state.bank for state in constraint_states], device=device) + + # STEP 4: Sort + num_constraint_tokens = len(state.tokens) + + # Sort by keys (bank, score) (i.e., sort banks together, and scores + # within banks). AFAIK pytorch doesn't support either stable sort or + # multi-key sorting, so we have to hack this. + MAX_SCORE = -100 + sort_key = (num_constraint_tokens - banks) * MAX_SCORE + scores_buf + sort_values, sort_indices = sort_key.sort(dim=0, descending=True) + scores_buf = scores_buf[sort_indices] + indices_buf = indices_buf[sort_indices] + beams_buf = beams_buf[sort_indices] + banks = banks[sort_indices] + + # Sort the constraints to follow suit + constraint_states = [constraint_states[i] for i in sort_indices] + + # STEP 5: Remove duplicates. The topk calls (overall and + # per-row) plus the per-row generation of constraints will + # produce duplicates. Here we remove them. + + def roll(t): + """Rolls a 1d tensor left by 1. + + [0, 1, 2, 3, 4] becomes [4, 0, 1, 2, 3] + """ + return torch.cat((t[-1].unsqueeze(0), t[0:-1]), dim=0) + + # We map candidates (beam, token_id) to a single dimension. + # This is then shifted by 1. We can then easily identify + # duplicates and create a mask that identifies unique + # extensions. + uniques_mask = beams_buf * (self.vocab_size + 1) + indices_buf + uniques_mask = roll(uniques_mask) != uniques_mask + + # Use the mask to pare down the data structures + scores_buf = torch.masked_select(scores_buf, uniques_mask) + indices_buf = torch.masked_select(indices_buf, uniques_mask) + beams_buf = torch.masked_select(beams_buf, uniques_mask) + banks = torch.masked_select(banks, uniques_mask) + i = 1 + for mask in uniques_mask[1:]: + if not mask: + constraint_states.pop(i) + i += mask + + # STEP 6: Assign IDs round-robin across banks, sort, and + # truncate. Now that the candidates are sorted by (bank, + # score) and uniqed, we dynamically allocate the {beam_size} + # beam by striping across the candidates. These stripes will + # be used as sort keys to do round-robin selection. This is + # accomplished in a single pass with offsets. Sorting by + # highest-banks (furthest-along hypotheses) first ensures + # progress through the constraints. + # + # e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0 + # OLD STRIPES: 0 1 2 0 1 2 3 0 1 2 0 1 + # NEW STRIPES: 0 1+4 2+8 0+1 1+5 2+9 3+11 0+2 1+6 2+10 0+3 1+7 + # = 0 5 10 1 6 11 13 2 7 12 3 8 + # + # Sorting by this then gives the following banks: + # + # 3 2 1 0 3 2 1 0 3 2 1 2 + # + # We'll take the top {beam_size} of these. + stripe_offsets = [offset * (len(banks) + 1) for offset in range(len(banks) + 1)] + stripes = torch.zeros_like(banks) + cur_bank_count = -1 + cur_bank = banks[0] + for i, bank in enumerate(banks): + if bank != cur_bank: + cur_bank_count = 0 + cur_bank = bank + else: + cur_bank_count += 1 + stripes[i] = num_constraint_tokens - bank + stripe_offsets[cur_bank_count] + + # STEP 7: Sort by the stripes values + sort_values, sort_indices = stripes.sort(dim=0) + scores_buf = scores_buf[sort_indices] + indices_buf = indices_buf[sort_indices] + beams_buf = beams_buf[sort_indices] + constraint_states = [constraint_states[i] for i in sort_indices] + + # STEP 8: Truncate to the candidates size! + scores_buf = scores_buf[: self.num_cands] + indices_buf = indices_buf[: self.num_cands] + beams_buf = beams_buf[: self.num_cands] + + return scores_buf, indices_buf, beams_buf, constraint_states + + +class LengthConstrainedBeamSearch(Search): + def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b): + super().__init__(tgt_dict) + self.min_len_a = min_len_a + self.min_len_b = min_len_b + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.beam = BeamSearch(tgt_dict) + self.needs_src_lengths = True + + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + min_lens = self.min_len_a * self.src_lengths + self.min_len_b + max_lens = self.max_len_a * self.src_lengths + self.max_len_b + lprobs[step < min_lens, :, self.eos] = -math.inf + lprobs[step >= max_lens, :, self.eos] = 0 + return self.beam.step(step, lprobs, scores) + + +class DiverseBeamSearch(Search): + """Diverse Beam Search. + + See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence + Models" for details. + + We implement cumulative diversity penalty here as default, optionally provide Hamming diversity described + in the original paper, and a way to interpolate between the two through diversity_discount. + + Take the example below for illustration of cumulative diversity implemented. + A) I like dogs. + B) I like ____. + C) There are ___. + And we are at step=2, trying to fill in the blank: + + Hamming diversity: + Penalty for B from A is 1 for "dogs" and 0 for any other words like "cats". + Penalty for C from A is 1 for "dogs" and 0 for any other words like "cats". + + Cumulative diversity (default): + Penalty for B from A is 3 for "dogs" and 0 for any other words like "cats". + Penalty for C from A is 1 for "dogs" and 0 for any other words like "cats". + B and C differ because B matches with A for "I" and "like" at respective steps incurring 2 cumulative penalty. + + Using divesrity_discount to interpolate between the two: + if diverstiy_discount = 0.5, then + Penalty for B from A is 1.75 (1 + 0.5 + 0.25) for "dogs" and 0 for any other words like "cats". + Penalty for C from A is 1 for "dogs" and 0 for any other words like "cats". + "I" and "like" matched for B and A at step 0 and 1 respectively. Since "I" is two steps away and "like" is one step away, they are discounted by (0.5)^2 and 0.5 respectively. + When diversity_discount = 0, we recover Hammning diversity and when diversity_discount = 1, we recover cumulative diversity. + + NB: During beam search for each diversity group, `candidate_mutiple` is set to 1 rather than BeamSearch default(2). + This is to ensure we have final `beam_size` candidates so that no diversity groups would be dropped during final token selection in sequence generation. + For full backwards compatibility, use diversity_discount=0 and candidate_multiple=2. + + """ + + def __init__( + self, + tgt_dict, + num_groups, + diversity_strength, + diversity_discount=1.0, + candidate_multiple=1, + ): + super().__init__(tgt_dict) + self.num_groups = num_groups + self.diversity_strength = -diversity_strength + self.beam = BeamSearch(tgt_dict) + self.diversity_discount = diversity_discount + self.candidate_multiple = candidate_multiple + + # Float tensor to keep track of overlap between groups. + # Each token shared at the same step between two groups is counted as one. + # Then token counts are discounted by `diversity_discount` for every next timestep. + # Once initialized, dimension is batch_size * num_groups * num_groups. + self.group_overlap = torch.empty(0) + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + if beam_size % self.num_groups != 0: + raise ValueError( + "DiverseBeamSearch requires --beam to be divisible by the number of groups" + ) + + # initialize diversity penalty + diversity_buf = torch.zeros(lprobs[:, 0, :].size()).to(lprobs) + + scores_G, beams_G = [], [] + + # pre-allocating tensor for indices for all groups + indices_G_stacked = torch.empty( + bsz, + int(beam_size / self.num_groups) * self.candidate_multiple, + self.num_groups, + dtype=torch.long, + device=lprobs.device, + ) + + for g in range(self.num_groups): + lprobs_g = lprobs[:, g :: self.num_groups, :] + scores_g = scores[:, g :: self.num_groups, :] if step > 0 else None + + diversity_buf.zero_() + # apply diversity penalty + if g > 0: + indices_ = indices_G_stacked[:, :, :g] + if step > 0: + penalty_val = 1 + self.group_overlap[original_batch_idxs, g, :g] + penalty_val = penalty_val.unsqueeze(1) + else: + penalty_val = torch.ones(bsz, 1, 1) + diversity_buf.scatter_add_( + 1, + indices_.reshape(bsz, -1), + penalty_val.expand(indices_.size()) + .reshape(bsz, -1) + .to(diversity_buf), + ) + + lprobs_g = torch.add( + lprobs_g, + other=diversity_buf.unsqueeze(1), + alpha=self.diversity_strength, + ) + else: + lprobs_g = lprobs_g.contiguous() + + scores_buf, indices_buf, beams_buf = self.beam.step( + step, lprobs_g, scores_g, candidate_multiple=self.candidate_multiple + ) + beams_buf.mul_(self.num_groups).add_(g) + + scores_G.append(scores_buf.clone()) + beams_G.append(beams_buf.clone()) + + indices_G_stacked[:, :, g] = indices_buf + + # interleave results from different groups + scores_buf = torch.stack(scores_G, dim=2).view(bsz, -1) + indices_buf = indices_G_stacked.view(bsz, -1) + beams_buf = torch.stack(beams_G, dim=2).view(bsz, -1) + # find num of overlapped tokens for each group pair + # then discount it for next timestamp + overlap = self.diversity_discount * torch.sum( + indices_G_stacked.unsqueeze(2).eq(indices_G_stacked.unsqueeze(3)), dim=1 + ) + if step == 0: + self.group_overlap = overlap + else: + self.group_overlap[original_batch_idxs] = ( + self.group_overlap[original_batch_idxs] * self.diversity_discount + + overlap + ) + + return scores_buf, indices_buf, beams_buf + + +class Sampling(Search): + sampling_topk: int + sampling_topp: float + + def __init__(self, tgt_dict, sampling_topk=-1, sampling_topp=-1.0): + super().__init__(tgt_dict) + self.sampling_topk = sampling_topk + self.sampling_topp = sampling_topp + + def _sample_topp(self, lprobs): + """Sample among the smallest set of elements whose cumulative probability mass exceeds p. + + See `"The Curious Case of Neural Text Degeneration" + (Holtzman et al., 2019) `_. + + Args: + lprobs: (bsz x input_beam_size x vocab_size) + the model's log-probabilities over the vocabulary at the current step + + Return: A tuple of (trimed_probs, truncated_indices) where: + trimed_probs: (bsz x input_beam_size x ?) + the model's probabilities over the elements selected to sample from. The + width of the third dimension is determined by top-P. + truncated_indices: (bsz x input_beam_size x ?) + the indices of the chosen elements. + """ + probs = lprobs.exp_() + + # sort the last dimension (vocab dimension) in descending order + sorted_probs, sorted_indices = probs.sort(descending=True) + + # compute a mask to indicate the words to be included in the top-P set. + cumsum_probs = sorted_probs.cumsum(dim=2) + mask = cumsum_probs.lt(self.sampling_topp) + + # note that mask was computed by 'lt'. One more word needs to be included + # so that the cumulative probability mass can exceed p. + cumsum_mask = mask.cumsum(dim=2) + last_included = cumsum_mask[:, :, -1:] + last_included.clamp_(0, mask.size()[2] - 1) + mask = mask.scatter_(2, last_included, 1) + + # truncate unnecessary dims. + max_dim = last_included.max() + truncated_mask = mask[:, :, : max_dim + 1] + truncated_probs = sorted_probs[:, :, : max_dim + 1] + truncated_indices = sorted_indices[:, :, : max_dim + 1] + + # trim the words that are not in top-P by setting their probabilities + # to 0, so that they would not be sampled later. + trim_mask = ~truncated_mask + trimed_probs = truncated_probs.masked_fill_(trim_mask, 0) + return trimed_probs, truncated_indices + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + + if self.sampling_topp > 0: + # only sample from the smallest set of words whose cumulative probability mass exceeds p + probs, top_indices = self._sample_topp(lprobs) + elif self.sampling_topk > 0: + # only sample from top-k candidates + lprobs, top_indices = lprobs.topk(self.sampling_topk) + probs = lprobs.exp_() + else: + probs = lprobs.exp_() + + # dummy data to be consistent with true branch for type check + top_indices = torch.empty(0).to(probs) + # sample + if step == 0: + indices_buf = torch.multinomial( + probs.view(bsz, -1), + beam_size, + replacement=True, + ).view(bsz, beam_size) + else: + indices_buf = torch.multinomial( + probs.view(bsz * beam_size, -1), + 1, + replacement=True, + ).view(bsz, beam_size) + + if step == 0: + # expand to beam size + probs = probs.expand(bsz, beam_size, -1) + + # gather scores + scores_buf = torch.gather(probs, dim=2, index=indices_buf.unsqueeze(-1)) + scores_buf = scores_buf.log_().view(bsz, -1) + + # remap indices if using top-k or top-P sampling + if self.sampling_topk > 0 or self.sampling_topp > 0: + indices_buf = torch.gather( + top_indices.expand(bsz, beam_size, -1), + dim=2, + index=indices_buf.unsqueeze(-1), + ).squeeze(2) + + if step == 0: + beams_buf = indices_buf.new_zeros(bsz, beam_size) + else: + beams_buf = torch.arange(0, beam_size).to(indices_buf).repeat(bsz, 1) + # make scores cumulative + scores_buf.add_( + torch.gather(scores[:, :, step - 1], dim=1, index=beams_buf) + ) + + return scores_buf, indices_buf, beams_buf + + +class DiverseSiblingsSearch(Search): + """ + Beam search with diverse siblings. + + See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" for details. + https://arxiv.org/abs/1611.08562 + + 1/ Calculate hypotheses for each beam + 2/ Intra-sibling ordering + 3/ Rewrite scores + 4/ Choose top K hypotheses + + if diversity_rate == 0 is equivalent to BeamSearch + """ + + def __init__(self, tgt_dict, diversity_rate): + super().__init__(tgt_dict) + self.diversity_rate = diversity_rate + self.beam = BeamSearch(tgt_dict) + + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + k = min( + # Take the best 2 x beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size * 2, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ) + s_list: List[Tensor] + i_list: List[Tensor] + s_list = [torch.empty(0).to(lprobs) for i in range(beam_size)] + i_list = [torch.LongTensor().to(device=lprobs.device) for i in range(beam_size)] + sibling_score = torch.arange(1, k + 1).to(lprobs) * self.diversity_rate + + if step == 0: + return self.beam.step(step, lprobs, scores) + lprobs.add_(scores[:, :, step - 1].unsqueeze(-1)) + + # 1/ Calculate hypotheses for each beam + for i in range(beam_size): + torch.topk(lprobs[:, i, :].view(bsz, -1), k, out=(s_list[i], i_list[i])) + i_list[i].fmod_(vocab_size) + + # 2/ Intra-sibling ordering by default from topk + 3/ Rewrite scores + s_list[i].sub_(sibling_score) + + # 4/ Choose top K hypotheses + indices = torch.stack(i_list, dim=1).view(bsz, -1) + + final_scores = torch.empty(0).to(lprobs) + final_indices = torch.LongTensor().to(device=lprobs.device) + final_beams = torch.LongTensor().to(device=lprobs.device) + (final_scores, final_indices) = torch.topk( + torch.stack(s_list, dim=1).view(bsz, -1), + k, + ) + + final_beams = final_indices // k + + for i in range(bsz): + final_indices[i] = indices[i][final_indices[i]] + + return final_scores, final_indices, final_beams diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py new file mode 100644 index 0000000..a15abe3 --- /dev/null +++ b/fairseq/tasks/__init__.py @@ -0,0 +1,138 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import argparse +import importlib +import os + +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import merge_with_parent +from hydra.core.config_store import ConfigStore + +from .fairseq_task import FairseqTask # noqa + + +# register dataclass +TASK_DATACLASS_REGISTRY = {} +TASK_REGISTRY = {} +TASK_CLASS_NAMES = set() + + +def setup_task(cfg: FairseqDataclass, **kwargs): + task = None + task_name = getattr(cfg, "task", None) + + if isinstance(task_name, str): + # legacy tasks + task = TASK_REGISTRY[task_name] + if task_name in TASK_DATACLASS_REGISTRY: + dc = TASK_DATACLASS_REGISTRY[task_name] + cfg = dc.from_namespace(cfg) + else: + task_name = getattr(cfg, "_name", None) + + if task_name and task_name in TASK_DATACLASS_REGISTRY: + remove_missing = "from_checkpoint" in kwargs and kwargs["from_checkpoint"] + dc = TASK_DATACLASS_REGISTRY[task_name] + cfg = merge_with_parent(dc(), cfg, remove_missing=remove_missing) + task = TASK_REGISTRY[task_name] + + assert ( + task is not None + ), f"Could not infer task type from {cfg}. Available argparse tasks: {TASK_REGISTRY.keys()}. Available hydra tasks: {TASK_DATACLASS_REGISTRY.keys()}" + + return task.setup_task(cfg, **kwargs) + + +def register_task(name, dataclass=None): + """ + New tasks can be added to fairseq with the + :func:`~fairseq.tasks.register_task` function decorator. + + For example:: + + @register_task('classification') + class ClassificationTask(FairseqTask): + (...) + + .. note:: + + All Tasks must implement the :class:`~fairseq.tasks.FairseqTask` + interface. + + Args: + name (str): the name of the task + """ + + def register_task_cls(cls): + if name in TASK_REGISTRY: + return TASK_REGISTRY[name] + + if not issubclass(cls, FairseqTask): + raise ValueError( + "Task ({}: {}) must extend FairseqTask".format(name, cls.__name__) + ) + if cls.__name__ in TASK_CLASS_NAMES: + raise ValueError( + "Cannot register task with duplicate class name ({})".format( + cls.__name__ + ) + ) + TASK_REGISTRY[name] = cls + TASK_CLASS_NAMES.add(cls.__name__) + + if dataclass is not None and not issubclass(dataclass, FairseqDataclass): + raise ValueError( + "Dataclass {} must extend FairseqDataclass".format(dataclass) + ) + + cls.__dataclass = dataclass + if dataclass is not None: + TASK_DATACLASS_REGISTRY[name] = dataclass + + cs = ConfigStore.instance() + node = dataclass() + node._name = name + cs.store(name=name, group="task", node=node, provider="fairseq") + + return cls + + return register_task_cls + + +def get_task(name): + return TASK_REGISTRY[name] + + +def import_tasks(tasks_dir, namespace): + for file in os.listdir(tasks_dir): + path = os.path.join(tasks_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + task_name = file[: file.find(".py")] if file.endswith(".py") else file + importlib.import_module(namespace + "." + task_name) + + # expose `task_parser` for sphinx + if task_name in TASK_REGISTRY: + parser = argparse.ArgumentParser(add_help=False) + group_task = parser.add_argument_group("Task name") + # fmt: off + group_task.add_argument('--task', metavar=task_name, + help='Enable this task with: ``--task=' + task_name + '``') + # fmt: on + group_args = parser.add_argument_group( + "Additional command-line arguments" + ) + TASK_REGISTRY[task_name].add_args(group_args) + globals()[task_name + "_parser"] = parser + + +# automatically import any Python files in the tasks/ directory +tasks_dir = os.path.dirname(__file__) +import_tasks(tasks_dir, "fairseq.tasks") diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py new file mode 100644 index 0000000..e39d1d6 --- /dev/null +++ b/fairseq/tasks/fairseq_task.py @@ -0,0 +1,708 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import warnings +from argparse import Namespace +from typing import Any, Callable, Dict, List + +import torch +from fairseq import search, tokenizer, utils +from fairseq.logging import metrics +from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import gen_parser_from_dataclass +from fairseq.optim.amp_optimizer import AMPOptimizer +from omegaconf import DictConfig + + +logger = logging.getLogger(__name__) + + +class StatefulContainer(object): + def __init__(self): + self._state = dict() + self._factories = dict() + + def add_factory(self, name, factory: Callable[[], Any]): + self._factories[name] = factory + + def merge_state_dict(self, state_dict: Dict[str, Any]): + self._state.update(state_dict) + + @property + def state_dict(self) -> Dict[str, Any]: + return self._state + + def __getattr__(self, name): + if name not in self._state and name in self._factories: + self._state[name] = self._factories[name]() + + if name in self._state: + return self._state[name] + + raise AttributeError(f"Task state has no factory for attribute {name}") + + +class FairseqTask(object): + """ + Tasks store dictionaries and provide helpers for loading/iterating over + Datasets, initializing the Model/Criterion and calculating the loss. + + Tasks have limited statefulness. In particular, state that needs to be + saved to/loaded from checkpoints needs to be stored in the `self.state` + :class:`StatefulContainer` object. For example:: + + self.state.add_factory("dictionary", self.load_dictionary) + print(self.state.dictionary) # calls self.load_dictionary() + + This is necessary so that when loading checkpoints, we can properly + recreate the task state after initializing the task instance. + """ + + @classmethod + def add_args(cls, parser): + """Add task-specific arguments to the parser.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) + + @staticmethod + def logging_outputs_can_be_summed(criterion) -> bool: + """ + Whether the logging outputs returned by `train_step` and `valid_step` can + be summed across workers prior to calling `aggregate_logging_outputs`. + Setting this to True will improves distributed training speed. + """ + return criterion.logging_outputs_can_be_summed() + + def __init__(self, cfg: FairseqDataclass, **kwargs): + self.cfg = cfg + self.datasets = dict() + self.dataset_to_epoch_iter = dict() + self.state = StatefulContainer() + + @classmethod + def load_dictionary(cls, filename): + """Load the dictionary from the filename + + Args: + filename (str): the filename + """ + return Dictionary.load(filename) + + @classmethod + def build_dictionary( + cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8 + ): + """Build the dictionary + + Args: + filenames (list): list of filenames + workers (int): number of concurrent workers + threshold (int): defines the minimum word count + nwords (int): defines the total number of words in the final dictionary, + including special symbols + padding_factor (int): can be used to pad the dictionary size to be a + multiple of 8, which is important on some hardware (e.g., Nvidia + Tensor Cores). + """ + d = Dictionary() + for filename in filenames: + Dictionary.add_file_to_dictionary( + filename, d, tokenizer.tokenize_line, workers + ) + d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor) + return d + + @classmethod + def setup_task(cls, cfg: DictConfig, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + cfg (omegaconf.DictConfig): parsed command-line arguments + """ + return cls(cfg, **kwargs) + + def has_sharded_data(self, split): + return os.pathsep in getattr(self.cfg, "data", "") + + def load_dataset( + self, + split: str, + combine: bool = False, + task_cfg: FairseqDataclass = None, + **kwargs, + ): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + combine (bool): combines a split segmented into pieces into one dataset + task_cfg (FairseqDataclass): optional task configuration stored in the checkpoint that can be used + to load datasets + """ + raise NotImplementedError + + def dataset(self, split): + """ + Return a loaded dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + + Returns: + a :class:`~fairseq.data.FairseqDataset` corresponding to *split* + """ + from fairseq.data import FairseqDataset + + if split not in self.datasets: + raise KeyError("Dataset not loaded: " + split) + if not isinstance(self.datasets[split], FairseqDataset): + raise TypeError("Datasets are expected to be of type FairseqDataset") + return self.datasets[split] + + def filter_indices_by_size( + self, indices, dataset, max_positions=None, ignore_invalid_inputs=False + ): + """ + Filter examples that are too large + + Args: + indices (np.array): original array of sample indices + dataset (~fairseq.data.FairseqDataset): dataset to batch + max_positions (optional): max sentence length supported by the + model (default: None). + ignore_invalid_inputs (bool, optional): don't raise Exception for + sentences that are too long (default: False). + Returns: + np.array: array of filtered sample indices + """ + indices, ignored = dataset.filter_indices_by_size(indices, max_positions) + if len(ignored) > 0: + if not ignore_invalid_inputs: + raise Exception( + ( + "Size of sample #{} is invalid (={}) since max_positions={}, " + "skip this example with --skip-invalid-size-inputs-valid-test" + ).format(ignored[0], dataset.size(ignored[0]), max_positions) + ) + logger.warning( + ( + "{:,} samples have invalid sizes and will be skipped, " + "max_positions={}, first few sample ids={}" + ).format(len(ignored), max_positions, ignored[:10]) + ) + return indices + + def can_reuse_epoch_itr(self, dataset): + # We can reuse the epoch iterator across epochs as long as the dataset + # hasn't disabled it. We default to ``False`` here, although in practice + # this will be ``True`` for most datasets that inherit from + # ``FairseqDataset`` due to the base implementation there. + return getattr(dataset, "can_reuse_epoch_itr_across_epochs", False) + + def get_batch_iterator( + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + data_buffer_size=0, + disable_iterator_cache=False, + skip_remainder_batch=False, + grouped_shuffling=False, + update_epoch_batch_itr=False, + ): + """ + Get an iterator that yields batches of data from the given dataset. + + Args: + dataset (~fairseq.data.FairseqDataset): dataset to batch + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + max_positions (optional): max sentence length supported by the + model (default: None). + ignore_invalid_inputs (bool, optional): don't raise Exception for + sentences that are too long (default: False). + required_batch_size_multiple (int, optional): require batch size to + be a multiple of N (default: 1). + seed (int, optional): seed for random number generator for + reproducibility (default: 1). + num_shards (int, optional): shard the data iterator into N + shards (default: 1). + shard_id (int, optional): which shard of the data iterator to + return (default: 0). + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). + epoch (int, optional): the epoch to start the iterator from + (default: 1). + data_buffer_size (int, optional): number of batches to + preload (default: 0). + disable_iterator_cache (bool, optional): don't cache the + EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`) + (default: False). + skip_remainder_batch (bool, optional): if set, discard the last + batch in each training epoch, as the last batch is often smaller than + local_batch_size * distributed_word_size (default: ``True``). + grouped_shuffling (bool, optional): group batches with each groups + containing num_shards batches and shuffle groups. Reduces difference + between sequence lengths among workers for batches sorted by length. + update_epoch_batch_itr (bool optional): if true then donot use the cached + batch iterator for the epoch + + Returns: + ~fairseq.iterators.EpochBatchIterator: a batched iterator over the + given dataset split + """ + can_reuse_epoch_itr = ( + not disable_iterator_cache + and not update_epoch_batch_itr + and self.can_reuse_epoch_itr(dataset) + ) + logger.info(f"can_reuse_epoch_itr = {can_reuse_epoch_itr}") + if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter: + logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch)) + return self.dataset_to_epoch_iter[dataset] + + assert isinstance(dataset, FairseqDataset) + + # initialize the dataset with the correct starting epoch + dataset.set_epoch(epoch) + + def make_batches(dataset, epoch): + logger.info(f"creating new batches for epoch {epoch}") + + # get indices ordered by example size + with data_utils.numpy_seed(seed + epoch): + indices = dataset.ordered_indices() + + # filter examples that are too large + if max_positions is not None: + indices = self.filter_indices_by_size( + indices, dataset, max_positions, ignore_invalid_inputs + ) + + # create mini-batches with given size constraints + batches = dataset.batch_by_size( + indices, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + return batches + + reuse_dataloader = getattr(self.cfg, "reuse_dataloader", True) + persistent_workers = getattr(self.cfg, "persistent_workers", True) + rebuild_batches = getattr(self.cfg, "rebuild_batches", False) + logger.info(f"reuse_dataloader = {reuse_dataloader}") + logger.info(f"rebuild_batches = {rebuild_batches}") + + if rebuild_batches: + logger.info("batches will be rebuilt for each epoch") + batch_sampler = make_batches + else: + batch_sampler = make_batches(dataset, epoch) + + # return a reusable, sharded iterator + epoch_iter = iterators.EpochBatchIterator( + dataset=dataset, + collate_fn=dataset.collater, + batch_sampler=batch_sampler, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + buffer_size=data_buffer_size, + skip_remainder_batch=skip_remainder_batch, + grouped_shuffling=grouped_shuffling, + reuse_dataloader=reuse_dataloader, + persistent_workers=persistent_workers, + ) + + if can_reuse_epoch_itr: + self.dataset_to_epoch_iter[dataset] = epoch_iter + + return epoch_iter + + def build_model(self, cfg: FairseqDataclass, from_checkpoint=False): + """ + Build the :class:`~fairseq.models.BaseFairseqModel` instance for this + task. + + Args: + cfg (FairseqDataclass): configuration object + + Returns: + a :class:`~fairseq.models.BaseFairseqModel` instance + """ + from fairseq import models, quantization_utils + + model = models.build_model(cfg, self, from_checkpoint) + model = quantization_utils.quantize_model_scalar(model, cfg) + return model + + def build_criterion(self, cfg: DictConfig, from_checkpoint=False): + """ + Build the :class:`~fairseq.criterions.FairseqCriterion` instance for + this task. + + Args: + cfg (omegaconf.DictConfig): configration object + + Returns: + a :class:`~fairseq.criterions.FairseqCriterion` instance + """ + from fairseq import criterions + + return criterions.build_criterion(cfg, self, from_checkpoint=from_checkpoint) + + def build_generator( + self, + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=None, + prefix_allowed_tokens_fn=None, + ): + """ + Build a :class:`~fairseq.SequenceGenerator` instance for this + task. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + args (fairseq.dataclass.configs.GenerationConfig): + configuration object (dataclass) for generation + extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass + through to SequenceGenerator + prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]): + If provided, this function constrains the beam search to + allowed tokens only at each step. The provided function + should take 2 arguments: the batch ID (`batch_id: int`) + and a unidimensional tensor of token ids (`inputs_ids: + torch.Tensor`). It has to return a `List[int]` with the + allowed tokens for the next generation step conditioned + on the previously generated tokens (`inputs_ids`) and + the batch ID (`batch_id`). This argument is useful for + constrained generation conditioned on the prefix, as + described in "Autoregressive Entity Retrieval" + (https://arxiv.org/abs/2010.00904) and + https://github.com/facebookresearch/GENRE. + """ + if getattr(args, "score_reference", False): + from fairseq.sequence_scorer import SequenceScorer + + return SequenceScorer( + self.target_dictionary, + compute_alignment=getattr(args, "print_alignment", False), + ) + + from fairseq.sequence_generator import ( + SequenceGenerator, + SequenceGeneratorWithAlignment, + ) + + # Choose search strategy. Defaults to Beam Search. + sampling = getattr(args, "sampling", False) + sampling_topk = getattr(args, "sampling_topk", -1) + sampling_topp = getattr(args, "sampling_topp", -1.0) + diverse_beam_groups = getattr(args, "diverse_beam_groups", -1) + diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5) + match_source_len = getattr(args, "match_source_len", False) + diversity_rate = getattr(args, "diversity_rate", -1) + constrained = getattr(args, "constraints", False) + if prefix_allowed_tokens_fn is None: + prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None) + if ( + sum( + int(cond) + for cond in [ + sampling, + diverse_beam_groups > 0, + match_source_len, + diversity_rate > 0, + ] + ) + > 1 + ): + raise ValueError("Provided Search parameters are mutually exclusive.") + assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling" + assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling" + + if sampling: + search_strategy = search.Sampling( + self.target_dictionary, sampling_topk, sampling_topp + ) + elif diverse_beam_groups > 0: + search_strategy = search.DiverseBeamSearch( + self.target_dictionary, diverse_beam_groups, diverse_beam_strength + ) + elif match_source_len: + # this is useful for tagging applications where the output + # length should match the input length, so we hardcode the + # length constraints for simplicity + search_strategy = search.LengthConstrainedBeamSearch( + self.target_dictionary, + min_len_a=1, + min_len_b=0, + max_len_a=1, + max_len_b=0, + ) + elif diversity_rate > -1: + search_strategy = search.DiverseSiblingsSearch( + self.target_dictionary, diversity_rate + ) + elif constrained: + search_strategy = search.LexicallyConstrainedBeamSearch( + self.target_dictionary, args.constraints + ) + elif prefix_allowed_tokens_fn: + search_strategy = search.PrefixConstrainedBeamSearch( + self.target_dictionary, prefix_allowed_tokens_fn + ) + else: + search_strategy = search.BeamSearch(self.target_dictionary) + + extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} + if seq_gen_cls is None: + if getattr(args, "print_alignment", False): + seq_gen_cls = SequenceGeneratorWithAlignment + extra_gen_cls_kwargs["print_alignment"] = args.print_alignment + else: + seq_gen_cls = SequenceGenerator + + return seq_gen_cls( + models, + self.target_dictionary, + beam_size=getattr(args, "beam", 5), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + search_strategy=search_strategy, + **extra_gen_cls_kwargs, + ) + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + """ + Do forward and backward, and return the loss as computed by *criterion* + for the given *model* and *sample*. + + Args: + sample (dict): the mini-batch. The format is defined by the + :class:`~fairseq.data.FairseqDataset`. + model (~fairseq.models.BaseFairseqModel): the model + criterion (~fairseq.criterions.FairseqCriterion): the criterion + optimizer (~fairseq.optim.FairseqOptimizer): the optimizer + update_num (int): the current update + ignore_grad (bool): multiply loss by 0 if this is set to True + + Returns: + tuple: + - the loss + - the sample size, which is used as the denominator for the + gradient + - logging outputs to display while training + """ + model.train() + model.set_num_updates(update_num) + with torch.autograd.profiler.record_function("forward"): + with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))): + loss, sample_size, logging_output = criterion(model, sample) + if ignore_grad: + loss *= 0 + with torch.autograd.profiler.record_function("backward"): + optimizer.backward(loss) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + model.eval() + with torch.no_grad(): + loss, sample_size, logging_output = criterion(model, sample) + return loss, sample_size, logging_output + + def optimizer_step(self, optimizer, model, update_num): + optimizer.step() + + def build_dataset_for_inference( + self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs + ) -> torch.utils.data.Dataset: + raise NotImplementedError + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, constraints=constraints + ) + + def begin_epoch(self, epoch, model): + """Hook function called before the start of each epoch.""" + pass + + def begin_valid_epoch(self, epoch, model): + """Hook function called before the start of each validation epoch.""" + pass + + def aggregate_logging_outputs(self, logging_outputs, criterion): + """[deprecated] Aggregate logging outputs from data parallel training.""" + utils.deprecation_warning( + "The aggregate_logging_outputs API is deprecated. " + "Please use the reduce_metrics API instead." + ) + with metrics.aggregate() as agg: + self.reduce_metrics(logging_outputs, criterion) + return agg.get_smoothed_values() + + def reduce_metrics(self, logging_outputs, criterion): + """Aggregate logging outputs from data parallel training.""" + # backward compatibility for tasks that override aggregate_logging_outputs + base_func = FairseqTask.aggregate_logging_outputs + self_func = getattr(self, "aggregate_logging_outputs").__func__ + if self_func is not base_func: + utils.deprecation_warning( + "Tasks should implement the reduce_metrics API. " + "Falling back to deprecated aggregate_logging_outputs API." + ) + agg_logging_outputs = self.aggregate_logging_outputs( + logging_outputs, criterion + ) + for k, v in agg_logging_outputs.items(): + metrics.log_scalar(k, v) + return + + if not any("ntokens" in log for log in logging_outputs): + warnings.warn( + "ntokens not found in Criterion logging outputs, cannot log wpb or wps" + ) + else: + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + metrics.log_scalar("wpb", ntokens, priority=180, round=1) + metrics.log_speed("wps", ntokens, priority=90, round=1) + + if not any("nsentences" in log for log in logging_outputs): + warnings.warn( + "nsentences not found in Criterion logging outputs, cannot log bsz" + ) + else: + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + metrics.log_scalar("bsz", nsentences, priority=190, round=1) + + criterion.__class__.reduce_metrics(logging_outputs) + + def state_dict(self): + if self.state is not None: + return self.state.state_dict + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]): + if self.state is not None: + self.state.merge_state_dict(state_dict) + + def max_positions(self): + """Return the max input length allowed by the task.""" + return None + + @property + def source_dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary` (if applicable + for this task).""" + return None + + @property + def target_dictionary(self): + """Return the target :class:`~fairseq.data.Dictionary` (if applicable + for this task).""" + return None + + def build_tokenizer(self, args): + """Build the pre-tokenizer for this task.""" + return encoders.build_tokenizer(args) + + def build_bpe(self, args): + """Build the tokenizer for this task.""" + return encoders.build_bpe(args) + + def get_interactive_tokens_and_lengths(self, lines, encode_fn): + tokens = [ + self.source_dictionary.encode_line( + encode_fn(src_str), add_if_not_exist=False + ).long() + for src_str in lines + ] + lengths = [t.numel() for t in tokens] + return tokens, lengths + + +class LegacyFairseqTask(FairseqTask): + def __init__(self, args: Namespace): + super().__init__(None) + self.args = args + self.datasets = {} + self.dataset_to_epoch_iter = {} + + @classmethod + def setup_task(cls, args: Namespace, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + return cls(args, **kwargs) + + def has_sharded_data(self, split): + return os.pathsep in getattr(self.args, "data", "") + + def build_model(self, args: Namespace, from_checkpoint=False): + """ + Build the :class:`~fairseq.models.BaseFairseqModel` instance for this + task. + + Args: + args (argparse.Namespace): parsed command-line arguments + + Returns: + a :class:`~fairseq.models.BaseFairseqModel` instance + """ + from fairseq import models, quantization_utils + + model = models.build_model(args, self, from_checkpoint) + model = quantization_utils.quantize_model_scalar(model, args) + return model + + def build_criterion(self, args: Namespace): + """ + Build the :class:`~fairseq.criterions.FairseqCriterion` instance for + this task. + + Args: + args (argparse.Namespace): parsed command-line arguments + + Returns: + a :class:`~fairseq.criterions.FairseqCriterion` instance + """ + from fairseq import criterions + + return criterions.build_criterion(args, self) diff --git a/fairseq/tasks/hubert_pretraining.py b/fairseq/tasks/hubert_pretraining.py new file mode 100644 index 0000000..d7e58d3 --- /dev/null +++ b/fairseq/tasks/hubert_pretraining.py @@ -0,0 +1,193 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import os +import sys +from typing import Dict, List, Optional, Tuple + +import numpy as np + +from dataclasses import dataclass, field +from fairseq.data import Dictionary +from fairseq.dataclass.configs import FairseqDataclass +from fairseq.tasks import register_task +from fairseq.tasks.fairseq_task import FairseqTask +from omegaconf import MISSING + +logger = logging.getLogger(__name__) + + +class LabelEncoder(object): + def __init__(self, dictionary: Dictionary) -> None: + self.dictionary = dictionary + + def __call__(self, label: str) -> List[str]: + return self.dictionary.encode_line( + label, + append_eos=False, + add_if_not_exist=False, + ) + + +@dataclass +class HubertPretrainingConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + fine_tuning: bool = field( + default=False, metadata={"help": "set to true if fine-tuning Hubert"} + ) + labels: List[str] = field( + default_factory=lambda: ["ltr"], + metadata={ + "help": ( + "extension of the label files to load, frame-level labels for" + " pre-training, and sequence-level label for fine-tuning" + ) + }, + ) + label_dir: Optional[str] = field( + default=None, + metadata={ + "help": "if set, looks for labels in this directory instead", + }, + ) + label_rate: float = field( + default=-1.0, + metadata={"help": "label frame rate. -1.0 for sequence label"}, + ) + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. audio files will be up/down " + "sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, + ) + enable_padding: bool = field( + default=False, + metadata={"help": "pad shorter samples instead of cropping"}, + ) + max_keep_size: Optional[int] = field( + default=None, + metadata={"help": "exclude sample longer than this"}, + ) + max_sample_size: Optional[int] = field( + default=None, + metadata={"help": "max sample size to crop to for batching"}, + ) + min_sample_size: Optional[int] = field( + default=None, + metadata={"help": "min sample size to crop to for batching"}, + ) + single_target: Optional[bool] = field( + default=False, + metadata={ + "help": "if set, AddTargetDatasets outputs same keys " "as AddTargetDataset" + }, + ) + random_crop: Optional[bool] = field( + default=True, + metadata={"help": "always crop from the beginning if false"}, + ) + pad_audio: Optional[bool] = field( + default=False, + metadata={"help": "pad audio to the longest one in the batch if true"}, + ) + + +@register_task("hubert_pretraining", dataclass=HubertPretrainingConfig) +class HubertPretrainingTask(FairseqTask): + + cfg: HubertPretrainingConfig + + def __init__( + self, + cfg: HubertPretrainingConfig, + ) -> None: + super().__init__(cfg) + + logger.info(f"current directory is {os.getcwd()}") + logger.info(f"HubertPretrainingTask Config {cfg}") + + self.cfg = cfg + self.fine_tuning = cfg.fine_tuning + + if cfg.fine_tuning: + self.state.add_factory("target_dictionary", self.load_dictionaries) + else: + self.state.add_factory("dictionaries", self.load_dictionaries) + + self.blank_symbol = "" + + @property + def source_dictionary(self) -> Optional[Dictionary]: + return None + + @property + def target_dictionary(self) -> Optional[Dictionary]: + return self.state.target_dictionary + + @property + def dictionaries(self) -> List[Dictionary]: + return self.state.dictionaries + + @classmethod + def setup_task( + cls, cfg: HubertPretrainingConfig, **kwargs + ) -> "HubertPretrainingTask": + return cls(cfg) + + def load_dictionaries(self): + label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir + dictionaries = [ + Dictionary.load(f"{label_dir}/dict.{label}.txt") + for label in self.cfg.labels + ] + return dictionaries[0] if self.cfg.fine_tuning else dictionaries + + def get_label_dir(self) -> str: + if self.cfg.label_dir is None: + return self.cfg.data + return self.cfg.label_dir + + def load_dataset(self, split: str, **kwargs) -> None: + from fairseq.data import HubertDataset + + manifest = f"{self.cfg.data}/{split}.tsv" + dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries + pad_list = [dict.pad() for dict in dicts] + eos_list = [dict.eos() for dict in dicts] + procs = [LabelEncoder(dict) for dict in dicts] + paths = [f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels] + + # hubert v1: pad_audio=True, random_crop=False; + self.datasets[split] = HubertDataset( + manifest, + sample_rate=self.cfg.sample_rate, + label_paths=paths, + label_rates=self.cfg.label_rate, + pad_list=pad_list, + eos_list=eos_list, + label_processors=procs, + max_keep_sample_size=self.cfg.max_keep_size, + min_keep_sample_size=self.cfg.min_sample_size, + max_sample_size=self.cfg.max_sample_size, + pad_audio=self.cfg.pad_audio, + normalize=self.cfg.normalize, + store_labels=False, + random_crop=self.cfg.random_crop, + single_target=self.cfg.single_target, + ) + + def max_positions(self) -> Tuple[int, int]: + return (sys.maxsize, sys.maxsize) + + def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array: + return indices diff --git a/fairseq/token_generation_constraints.py b/fairseq/token_generation_constraints.py new file mode 100644 index 0000000..e708dc5 --- /dev/null +++ b/fairseq/token_generation_constraints.py @@ -0,0 +1,506 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Implements tracking of constraints for a beam item. + +A list of constraints is given as a list of one or more token +sequences, each of length at least one token. For example, for an input sentence + +> Die maschinelle Übersetzung ist schwer zu kontrollieren. + +We could have the constraints: +* to influence +* hard + +There are two implementations: +* OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints. +* UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints. + +The difference is that in the first, the constraints are assumed to be +in order; the algorithm will permit zero or more tokens between them. +In the second, the constraints are not ordered, so many orderings will +be explored. + +The same sequence can be present any number of times, and will appear +that many times in the output. +""" + +from collections import Counter +from typing import List, Optional, Set, Tuple + +import torch + + +class ConstraintState: + def __init__(self): + pass + + +def pack_constraints(batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor: + """Takes a list of list of constraints in tensor form (a list of + tensor constraints for each sentence) and transforms it into a + packed Tensor. For example, here is a batch of size 3 with 3, 0, + and 1 constraints: + + [ [ [3 1 2], [3], [4 5 6 7], ] + [], + [ [1 8 9 10 1 4 11 12], ] + ] + + Its corresponding packed structure is: + + [ [ 3 3 1 2 0 3 0 4 5 6 7 0], + [ 0 0 0 0 0 0 0 0 0 0 0 0], + [ 1 1 8 9 10 1 4 11 12 0 0 0] ] + + The packed tensor has shape (batch size, maxlen), where + maxlen is defined below. Each row contains concatenated + constraint tokens for that sentence, with 0 appended after + each constraint. The first item in each row is the number + of constraints for that sentence. So maxlen is the maximum + of + + (number of constraints) + (sum length of constraints) + 1. + + across all sentences in the batch. + """ + # The maximum word length of concatenated constraints for any sentence + max_constraints_len = 1 + for sentence_constraints in batch_constraints: + if len(sentence_constraints): + # number of constraints, plus sum of constrain lens, plus a zero after each + constraints_len = ( + 1 + + sum([c.size(0) for c in sentence_constraints]) + + len(sentence_constraints) + ) + max_constraints_len = max(max_constraints_len, constraints_len) + + batch_size = len(batch_constraints) + constraints_tensor = torch.zeros((batch_size, max_constraints_len)).long() + for i, sentence_constraints in enumerate(batch_constraints): + constraints_tensor[i, 0] = len(sentence_constraints) + offset = 1 + for j, constraint in enumerate(sentence_constraints): + this_len = constraint.size(0) + constraints_tensor[i, offset : offset + this_len] = constraint + offset += this_len + 1 + + return constraints_tensor.long() + + +def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]: + """ + Transforms *one row* of a packed constraint tensor (e.g., for one + sentence in the batch) into a list of constraint tensors. + """ + constraint_list = [] + num_constraints = constraint_tensor[0] + constraints = constraint_tensor.tolist() + offset = 1 + for i in range(num_constraints): + where = constraints.index(0, offset) + constraint_list.append(constraint_tensor[offset:where]) + offset = where + 1 + + return constraint_list + + +class ConstraintNode: + """ + Represents a node in a trie managing unordered constraints. + """ + + def __init__(self, token: int = None, parent=None): + # The token associate with this node (None for the root) + self.token = int(token) if token is not None else None + # The parent (None at the root) + self.parent = parent + # Whether this node is a completed constraint + self.terminal = 0 + # List of child nodes + self.children = {} + + # The cumulative number of constraints from this point in the + # trie forward + self.num_constraints = 0 + + @property + def id(self): + return self.token + + def __str__(self): + term = self.terminal != 0 + return f"[{self.token}].{term}#{self.num_constraints}" + + def __getitem__(self, key: int): + return self.children.get(key, None) + + def next_tokens(self) -> Set[int]: + """The set of child labels.""" + return set(self.children.keys()) + + @staticmethod + def create(constraints: List[List[int]]): + root = ConstraintNode() + for sequence in constraints: + root.add_sequence(sequence) + + return root + + @staticmethod + def print_graph(node: "ConstraintNode"): + if len(node.children) == 0: + return str(node) + else: + s = f"({node}" + for child in node.children.values(): + s += " " + ConstraintNode.print_graph(child) + s += ")" + return s + + def token_counts(self) -> Counter: + """Returns a counter of the number of times each token is used + in a constraint. + """ + token_counts = Counter() + kids = list(self.children.values()) + while len(kids) > 0: + kid = kids.pop() + token_counts[kid.id] += kid.num_constraints + kids += list(kid.children.values()) + + return token_counts + + def tokens(self) -> Set[int]: + """Returns the set of tokens in constraints.""" + return set(self.token_counts().keys()) + + def add_sequence(self, sequence: List[int]): + """Adds a constraint, represented as a list of integers, to + the trie.""" + assert len(sequence) > 0 + + token = int(sequence[0]) + if token not in self.children: + self.children[token] = ConstraintNode(token, parent=self) + + node = self.children[token] + if len(sequence) == 1: + node.terminal += 1 + node.num_constraints += 1 + parent = node.parent + while parent is not None: + parent.num_constraints += 1 + parent = parent.parent + else: + node.add_sequence(sequence[1:]) + + +class UnorderedConstraintState(ConstraintState): + """ + Records progress through the set of constraints for each item in the beam + using a trie. + """ + + def __init__(self, node: ConstraintNode, copy_from: "ConstraintState" = None): + self.node = node + + if copy_from is None: + # The root node + self.root = node + # The set of states in the graph that have been completed + self.completed = Counter() + # The... + self.generated = Counter() + # The list of tokens we need to generate + self.needed_tokens = self.root.tokens() + else: + self.completed = Counter(copy_from.completed) + self.generated = Counter(copy_from.generated) + self.root = copy_from.root + + # Mark the node as generated + if self.node != self.root: + self.generated[node] += 1 + + @staticmethod + def create(constraint_tensor: torch.Tensor): + constraint_list = unpack_constraints(constraint_tensor) + constraint_trie_root = ConstraintNode.create(constraint_list) + return UnorderedConstraintState(constraint_trie_root) + + def __str__(self): + gen_str = ",".join([str(node) for node in self.generated]) + return f"{self.name}/{self.bank}({gen_str})x{self.num_completed}" + + def __copy__(self): + copied_state = UnorderedConstraintState(self.node, copy_from=self) + return copied_state + + def copy(self): + return self.__copy__() + + @property + def name(self): + if self.node.id is None: + return "ROOT" + else: + return str(self.node.id) + + @property + def is_root(self): + return self.node == self.root + + @property + def bank(self): + return sum(self.generated.values()) + + @property + def num_completed(self): + """The number of constraints (not constraint tokens) that are completed. + In addition to the already-completed states, we need to account for the + current state, which might get marked as completed when another token + is generated. + """ + in_final = self.node.terminal and self.completed[self.node] < self.node.terminal + return sum(self.completed.values()) + in_final + + @property + def finished(self): + return self.root.num_constraints - self.num_completed == 0 + + @property + def token_counts(self): + return self.root.token_counts() + + @property + def tokens(self): + return self.root.tokens() + + @property + def num_constraint_tokens(self): + return sum(self.token_counts.values()) + + def next_tokens(self) -> Set[int]: + """Returns the list of tokens that could come next. + These are (a) all tokens extending the root state and, for + non-root states, additionally all tokens extending the current + state.""" + + if self.node != self.root: + return self.root.next_tokens().union(self.node.next_tokens()) + else: + return self.root.next_tokens() + + def advance(self, token: int): + """Reads in a token and advances the state. Here's how it works. + + We can advance to the next state if: + - there is a matching child + - its path isn't blocked + + A path is blocked when all constraints that are descendants of + that node have already been generated, in the current state. + + If we are not able to advance from the current state, we "fall + off the graph" and return to the root state. There, we again + try to advance, checking the same criteria. + + In any case, when falling off the graph, we need to do some + bookkeeping. We: + - check whether any constraints were met (all prefixes of + current state) + - if one is found, mark it as completed + - adjust visited nodes accordingly + """ + token = int(token) + + next_state = None + child = self.node[token] + if child is not None and self.generated[child] < child.num_constraints: + next_state = UnorderedConstraintState(child, copy_from=self) + + def rewind(): + """If we're mid-trie and an "illegal" token is chosen next, we need + to reset our state to the root state. However, along the way, we need + to check whether a prefix of the current trie state represents a state + we could mark as completed. + """ + node = self.node + while node != self.root: + if node.terminal and self.completed[node] < node.terminal: + next_state.completed[node] += 1 + return + + next_state.generated[node] -= 1 + node = node.parent + + # Fall off the graph, check the root + if next_state is None and token in self.root.next_tokens(): + child = self.root[token] + # We can only traverse this edge if it's not saturated + if self.generated[child] < child.num_constraints: + next_state = UnorderedConstraintState(child, copy_from=self) + else: + next_state = UnorderedConstraintState(self.root, copy_from=self) + + # Rewind + rewind() + + elif next_state is None: + next_state = UnorderedConstraintState(self.root, copy_from=self) + # Rewind + rewind() + + return next_state + + +class ConstraintSequence: + def __init__(self, sequences: List[List[int]]): + """Represents a set of possibly multitoken constraints by + concatenating them and internally recording the end points. + """ + self.sequences = [] + self.endpoints = [] + self.num_tokens = 0 + self.tokens = set() + for sequence in sequences: + for token in sequence: + self.tokens.add(token) + self.num_tokens += len(sequence) + self.endpoints += [False for x in range(len(sequence) - 1)] + [True] + self.sequences += sequence + + def __getitem__(self, key: int): + return self.sequences[key] + + def __len__(self): + return len(self.sequences) + + def __str__(self): + return str(self.sequences) + + +class OrderedConstraintState(ConstraintState): + """ + Records progress through the set of linear nonbranching constraints with gaps. + """ + + def __init__(self, sequence: ConstraintSequence, state: int = -1): + self.sequence = sequence + self.state = state + + @staticmethod + def create(constraint_tensor: torch.Tensor): + constraint_list = unpack_constraints(constraint_tensor) + return OrderedConstraintState(ConstraintSequence(constraint_list), -1) + + def __str__(self): + return f"{self.state}/{self.bank}x{self.num_completed}" + + def __copy__(self): + return OrderedConstraintState(self.sequence, self.state) + + def copy(self): + return self.__copy__() + + @property + def num_completed(self): + if self.state == -1: + return 0 + count = len( + list(filter(lambda x: x, self.sequence.endpoints[0 : self.state + 1])) + ) + return count + + @property + def is_root(self): + return self.state == -1 + + @property + def name(self): + if self.state == -1: + return "ROOT" + else: + return str(self.sequence[self.state]) + + @property + def bank(self) -> int: + return self.state + 1 + + @property + def finished(self): + return self.state + 1 == len(self.sequence) + + @property + def token_counts(self): + return self.sequence.token_counts() + + @property + def tokens(self): + return self.sequence.tokens + + @property + def num_constraint_tokens(self): + return sum(self.token_counts.values()) + + def next_tokens(self) -> Set[int]: + """Returns the list of tokens that could come next. + These are (a) all tokens extending the root state and, for + non-root states, additionally all tokens extending the current + state.""" + + tokens = set() + if self.state > 0: + tokens.add(self.sequence[0]) + if not self.finished: + tokens.add(self.sequence[self.state + 1]) + return tokens + + def advance(self, token: int): + """Reads in a token and advances the state. Here's how it works. + + We can advance to the next state if: + - there is a matching child + - its path isn't blocked + + A path is blocked when all constraints that are descendants of + that node have already been generated, in the current state. + + If we are not able to advance from the current state, we "fall + off the graph" and return to the root state. There, we again + try to advance, checking the same criteria. + + In any case, when falling off the graph, we need to do some + bookkeeping. We: + - check whether any constraints were met (all prefixes of + current state) + - if one is found, mark it as completed + - adjust visited nodes accordingly + """ + token = int(token) + # print(f"{self} ADVANCE({token}) {self.sequence} -> ", end="") + + if self.finished: + # Accept anything + next_state = self.copy() + + elif self.sequence[self.state + 1] == token: + # Advance to the next token + next_state = OrderedConstraintState(self.sequence, self.state + 1) + + elif self.sequence.endpoints[self.state]: + # Accept anything between constraints (*) + next_state = self.copy() + + elif token == self.sequence[0]: + # Start over having generated the first token + next_state = OrderedConstraintState(self.sequence, 0) + else: + # Start over from the root + next_state = OrderedConstraintState(self.sequence, -1) + + return next_state diff --git a/fairseq/tokenizer.py b/fairseq/tokenizer.py new file mode 100644 index 0000000..42131f7 --- /dev/null +++ b/fairseq/tokenizer.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import re + + +SPACE_NORMALIZER = re.compile(r"\s+") + + +def tokenize_line(line): + line = SPACE_NORMALIZER.sub(" ", line) + line = line.strip() + return line.split() diff --git a/fairseq/utils.py b/fairseq/utils.py index 4d4b350..72bd35f 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -23,14 +23,14 @@ from fairseq.modules.multihead_attention import MultiheadAttention try: - from amp_C import multi_tensor_l2norm + from amp_C import multi_tensor_l2norm # type: ignore multi_tensor_l2norm_available = True except ImportError: multi_tensor_l2norm_available = False try: - import torch_xla.core.xla_model as xm + import torch_xla.core.xla_model as xm # type: ignore except ImportError: xm = None @@ -128,7 +128,7 @@ def _move_to_cpu(tensor): def move_to_tpu(sample): - import torch_xla.core.xla_model as xm + import torch_xla.core.xla_model as xm # type: ignore device = xm.xla_device() @@ -714,8 +714,8 @@ def get_tpu_device(): def tpu_data_loader(itr): - import torch_xla.core.xla_model as xm - import torch_xla.distributed.parallel_loader as pl + import torch_xla.core.xla_model as xm # type: ignore + import torch_xla.distributed.parallel_loader as pl # type: ignore from fairseq.data import iterators @@ -746,7 +746,7 @@ def index_put(tensor, indices, value): def xla_device_to_cpu(dat): - import torch_xla.core.xla_model as xm + import torch_xla.core.xla_model as xm # type: ignore return xm._maybe_convert_to_cpu(dat) @@ -890,7 +890,7 @@ def train_step(self, sample ....): * Need to launch train.py locally (cannot submit jobs) """ try: - import jurigged + import jurigged # type: ignore except ImportError as e: logger.warning("Please install jurigged: pip install jurigged[develoop]") raise e