simplify
AndreasPlt committed Dec 18, 2024
1 parent 2845bdd commit 11cb4a8
Showing 1 changed file with 5 additions and 44 deletions.
@@ -63,7 +63,7 @@ def get_fairseq_root(commit="e4a2e4e93efbcbaaae52a17ae6600beb2083fb33", fairseq_
     return fairseq_root
 
 
-def run_fairseq_pretraining(exp_name, commit, python_exe_hash_overwrite=None, **kwargs):
+def run_fairseq_pretraining(exp_name, commit, python_exe_hash_overwrite=None, checkpoint=None, **kwargs):
     """
     Runs a FairseqHydraTrainingJob to pretrain a wav2vec 2.0 model.
@@ -73,6 +73,8 @@ def run_fairseq_pretraining(exp_name, commit, python_exe_hash_overwrite=None, **
         python_exe_hash_overwrite (Optional[str]): The hash overwrite for the fairseq_python_exe to use.
             It should only be used to achieve compatibility with the previous setup structure and should be ignored
             in all other cases.
+        checkpoint (Optional[str]): The path to the checkpoint to start from. If None, the training will start
+            from scratch.
         **kwargs: Additional arguments to pass to the job. These will be used to overwrite the model configuration.
     """
     # job requirements
@@ -93,6 +95,8 @@ def run_fairseq_pretraining(exp_name, commit, python_exe_hash_overwrite=None, **
     # generate config
     fairseq_args = get_fairseq_args(num_gpus=num_gpus)
     fairseq_args["task"]["alignment"] = alignment
+    if checkpoint is not None:
+        fairseq_args["checkpoint"]["continue_once"] = checkpoint
     for k, v in kwargs.items():
         fairseq_args["model"][k] = v
     fairseq_config = FairseqHydraConfig(fairseq_args)
@@ -103,49 +107,6 @@ def run_fairseq_pretraining(exp_name, commit, python_exe_hash_overwrite=None, **
     tk.register_output(f"{prefix_name}/{exp_name}/pretraining/scores.png", job.out_plot_se)
     return job
 
-def run_fairseq_pretraining_from_checkpoint(exp_name, commit, checkpoint, python_exe_hash_overwrite=None, **kwargs):
-    """
-    Runs a FairseqHydraTrainingJob to pretrain a wav2vec 2.0 model.
-    Args:
-        exp_name (str): The name of the experiment, used for output and alias folder.
-        commit (str): The commit ID of the fairseq_phoneme repository to use.
-        checkpoint (str): The path to the checkpoint to start from.
-        python_exe_hash_overwrite (Optional[str]): The hash overwrite for the fairseq_python_exe to use.
-            It should only be used to achieve compatibility with the previous setup structure and should be ignored
-            in all other cases.
-        **kwargs: Additional arguments to pass to the job.
-    """
-    # job requirements
-    prefix_name = "experiments/librispeech/librispeech_960_pretraining/wav2vec2/"
-    alignment = get_alignment_hdf()
-    num_gpus = 8
-    fairseq_python_exe = tk.Path("/usr/bin/python3", hash_overwrite=python_exe_hash_overwrite)
-    fairseq_root = get_fairseq_root(fairseq_exe=fairseq_python_exe, commit=commit)
-    fairseq_training_args = dict(
-        save_interval=25,
-        max_epoch=600,
-        max_update=420000,
-        fairseq_root=fairseq_root,
-        fairseq_python_exe=fairseq_python_exe,
-        rqmt={"time": 336, "mem": 16, "cpu": 2, "gpu": num_gpus},
-    )
-
-    # generate config
-    fairseq_args = get_fairseq_args(num_gpus=num_gpus)
-    fairseq_args["task"]["alignment"] = alignment
-    fairseq_args["checkpoint"]["continue_once"] = checkpoint
-    for k, v in kwargs.items():
-        fairseq_args["model"][k] = v
-
-    fairseq_config = FairseqHydraConfig(fairseq_args)
-
-    # run pretraining
-    job = FairseqHydraTrainingJob(fairseq_config, **fairseq_training_args)
-    job.add_alias(os.path.join(prefix_name, exp_name, "pretraining"))
-    tk.register_output(f"{prefix_name}/{exp_name}/pretraining/scores.png", job.out_plot_se)
-    return job
 
 
 def py():
     # negatives other
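After this commit, continuing pretraining from a checkpoint goes through the same entry point as training from scratch, instead of the removed run_fairseq_pretraining_from_checkpoint. A minimal usage sketch; the experiment names, commit ID, and checkpoint path below are hypothetical placeholders:

# Usage sketch for the unified entry point; experiment names, the commit ID,
# and the checkpoint path are hypothetical placeholders.

# Pretraining from scratch (checkpoint defaults to None):
job = run_fairseq_pretraining(
    exp_name="w2v2_base_from_scratch",
    commit="<fairseq_phoneme_commit>",
)

# Continuing from an existing checkpoint, which previously required the
# separate run_fairseq_pretraining_from_checkpoint helper:
job = run_fairseq_pretraining(
    exp_name="w2v2_base_continued",
    commit="<fairseq_phoneme_commit>",
    checkpoint="/path/to/checkpoint_last.pt",
)

Both calls build the same job otherwise; as the diff shows, the only difference is that a non-None checkpoint is written into the checkpoint.continue_once entry of the generated Hydra config.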
