diff --git a/users/vieting/experiments/librispeech/librispeech_960_pretraining/wav2vec2/config_02_fairseq_phoneme.py b/users/vieting/experiments/librispeech/librispeech_960_pretraining/wav2vec2/config_02_fairseq_phoneme.py index c71bbb5c6..ce79cb86d 100644 --- a/users/vieting/experiments/librispeech/librispeech_960_pretraining/wav2vec2/config_02_fairseq_phoneme.py +++ b/users/vieting/experiments/librispeech/librispeech_960_pretraining/wav2vec2/config_02_fairseq_phoneme.py @@ -63,7 +63,7 @@ def get_fairseq_root(commit="e4a2e4e93efbcbaaae52a17ae6600beb2083fb33", fairseq_ return fairseq_root -def run_fairseq_pretraining(exp_name, commit, python_exe_hash_overwrite=None, **kwargs): +def run_fairseq_pretraining(exp_name, commit, python_exe_hash_overwrite=None, checkpoint=None, **kwargs): """ Runs a FairseqHydraTrainingJob to pretrain a wav2vec 2.0 model. @@ -73,6 +73,8 @@ def run_fairseq_pretraining(exp_name, commit, python_exe_hash_overwrite=None, ** python_exe_hash_overwrite (Optional[str]): The hash overwrite for the fairseq_python_exe to use. It should only be used to achieve compatibility with the previous setup structure and should be ignored in all other cases. + checkpoint (Optional[str]): The path to the checkpoint to start from. If None, the training will start + from scratch. **kwargs: Additional arguments to pass to the job. These will be used to overwrite the model configuration. """ # job requirements @@ -93,6 +95,8 @@ def run_fairseq_pretraining(exp_name, commit, python_exe_hash_overwrite=None, ** # generate config fairseq_args = get_fairseq_args(num_gpus=num_gpus) fairseq_args["task"]["alignment"] = alignment + if checkpoint is not None: + fairseq_args["checkpoint"]["continue_once"] = checkpoint for k, v in kwargs.items(): fairseq_args["model"][k] = v fairseq_config = FairseqHydraConfig(fairseq_args) @@ -103,49 +107,6 @@ def run_fairseq_pretraining(exp_name, commit, python_exe_hash_overwrite=None, ** tk.register_output(f"{prefix_name}/{exp_name}/pretraining/scores.png", job.out_plot_se) return job -def run_fairseq_pretraining_from_checkpoint(exp_name, commit, checkpoint, python_exe_hash_overwrite=None, **kwargs): - """ - Runs a FairseqHydraTrainingJob to pretrain a wav2vec 2.0 model. - - Args: - exp_name (str): The name of the experiment, used for output and alias folder. - commit (str): The commit ID of the fairseq_phoneme repository to use. - checkpoint (str): The path to the checkpoint to start from. - python_exe_hash_overwrite (Optional[str]): The hash overwrite for the fairseq_python_exe to use. - It should only be used to achieve compatibility with the previous setup structure and should be ignored - in all other cases. - **kwargs: Additional arguments to pass to the job. - """ - # job requirements - prefix_name = "experiments/librispeech/librispeech_960_pretraining/wav2vec2/" - alignment = get_alignment_hdf() - num_gpus = 8 - fairseq_python_exe = tk.Path("/usr/bin/python3", hash_overwrite=python_exe_hash_overwrite) - fairseq_root = get_fairseq_root(fairseq_exe=fairseq_python_exe, commit=commit) - fairseq_training_args = dict( - save_interval=25, - max_epoch=600, - max_update=420000, - fairseq_root=fairseq_root, - fairseq_python_exe=fairseq_python_exe, - rqmt={"time": 336, "mem": 16, "cpu": 2, "gpu": num_gpus}, - ) - - # generate config - fairseq_args = get_fairseq_args(num_gpus=num_gpus) - fairseq_args["task"]["alignment"] = alignment - fairseq_args["checkpoint"]["continue_once"] = checkpoint - for k, v in kwargs.items(): - fairseq_args["model"][k] = v - - fairseq_config = FairseqHydraConfig(fairseq_args) - - # run pretraining - job = FairseqHydraTrainingJob(fairseq_config, **fairseq_training_args) - job.add_alias(os.path.join(prefix_name, exp_name, "pretraining")) - tk.register_output(f"{prefix_name}/{exp_name}/pretraining/scores.png", job.out_plot_se) - return job - def py(): # negatives other