update train.py
marcoyang1998 committed Mar 28, 2024
1 parent 711859c commit ebc0f3b
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions egs/librispeech/ASR/whisper/train.py
@@ -23,15 +23,17 @@
   --max-duration 200 \
   --exp-dir whisper/exp_large_v2 \
   --model-name large-v2 \
-  --manifest-dir data/fbank_whisper \
+  --full-libri True \
+  --manifest-dir data/fbank_whisper_80D \
   --deepspeed \
   --deepspeed_config ./whisper/ds_config_zero1.json

 # fine-tuning with ddp
 torchrun --nproc_per_node 8 ./whisper/train.py \
   --max-duration 200 \
   --exp-dir whisper/exp_medium \
-  --manifest-dir data/fbank_whisper \
+  --full-libri True \
+  --manifest-dir data/fbank_whisper_80D \
   --base-lr 1e-5 \
   --model-name medium
 """
@@ -53,7 +55,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 import whisper
-from asr_datamodule import AishellAsrDataModule
+from asr_datamodule import LibriSpeechAsrDataModule
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from label_smoothing import LabelSmoothingLoss
 from lhotse import CutSet, load_manifest
@@ -147,7 +149,7 @@ def get_parser():
         "--model-name",
         type=str,
         default="large-v2",
-        choices=["large-v2", "large-v3", "medium", "small", "tiny"],
+        choices=["large-v2", "large-v3", "medium", "medium.en", "small", "small.en", "tiny", "tiny.en"],
         help="""The model name to use.
         """,
     )
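The added `medium.en`, `small.en`, and `tiny.en` entries are OpenAI's English-only checkpoints, a natural fit for a LibriSpeech recipe. A minimal sketch (not part of this commit, assuming the openai-whisper package) of what a name from `choices` resolves to:

```python
import whisper

# English-only ("*.en") checkpoints are monolingual; the others are multilingual.
model = whisper.load_model("medium.en")  # any name from `choices` above
print(model.is_multilingual)   # False for "*.en" models
print(model.dims.n_audio_ctx)  # 1500 encoder positions, i.e. a 30 s window
```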
@@ -450,8 +452,7 @@ def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:
     batch_idx_train = params.batch_idx_train

     texts = batch["supervisions"]["text"]
-    # remove spaces in texts
-    texts = [text.replace(" ", "") for text in texts]
+    texts = [t[0] + t[1:].lower() for t in texts]

     text_tokens_list = [
         list(tokenizer.sot_sequence_including_notimestamps)
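The dropped lines removed inter-character spaces, which suits Aishell's Chinese transcripts but would destroy English text; the replacement instead keeps the first character and lowercases the rest, since LibriSpeech transcripts are all-uppercase while Whisper was trained on cased text. A quick worked example:

```python
# Old behavior (Aishell): strip spaces between Chinese characters.
print("大 家 好".replace(" ", ""))   # -> 大家好

# New behavior (LibriSpeech): keep the first character, lowercase the rest.
t = "HELLO WORLD HOW ARE YOU"
print(t[0] + t[1:].lower())          # -> Hello world how are you
```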
@@ -744,7 +745,7 @@ def run(rank, world_size, args):
     tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual,
         num_languages=model.num_languages,
-        language="zh",
+        language="en",
         task="transcribe",
     )

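Switching `language` from "zh" to "en" changes the special-token prompt that `sot_sequence_including_notimestamps` (used when building `text_tokens_list` above) prepends to every target. A minimal sketch, assuming the openai-whisper package; the `num_languages` value here is an assumption, since train.py takes it from the model:

```python
import whisper

tokenizer = whisper.tokenizer.get_tokenizer(
    multilingual=True,  # train.py passes model.is_multilingual
    num_languages=99,   # assumed here; train.py passes model.num_languages
    language="en",
    task="transcribe",
)
# Token ids for <|startoftranscript|><|en|><|transcribe|><|notimestamps|>
print(tokenizer.sot_sequence_including_notimestamps)
```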
@@ -800,7 +801,19 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    aishell = AishellAsrDataModule(args)
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    if params.full_libri:
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        if c.duration < 1.0 or c.duration > 20.0:
+            return False
+        return True
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)

     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
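Lhotse's `CutSet.filter` is lazy, so the 1–20 s bound is applied on the fly while the sampler iterates; the cap also leaves headroom inside Whisper's fixed 30 s input window. A standalone sketch of the same filter, with a hypothetical manifest path:

```python
from lhotse import CutSet

# Hypothetical path; the recipe actually loads cuts through the data module.
cuts = CutSet.from_file("data/fbank_whisper_80D/librispeech_cuts_train-clean-100.jsonl.gz")
cuts = cuts.filter(lambda c: 1.0 <= c.duration <= 20.0)  # same bounds as above
```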
@@ -809,8 +822,16 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    train_dl = aishell.train_dataloaders(aishell.train_cuts())
-    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+
+    # do this to prevent Whisper throwing the length mismatch error
+    valid_cuts = valid_cuts.filter(remove_short_and_long_utt)
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
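Two details here: `sampler_state_dict` lets a run resumed via `--start-batch` fast-forward the sampler to the saved position, and Lhotse `CutSet`s support `+`, which merges dev-clean and dev-other into a single validation set; the same duration filter is reused so oversized cuts cannot trigger the length-mismatch error during validation either. A sketch of the concatenation, with hypothetical paths:

```python
from lhotse import CutSet

dev_clean = CutSet.from_file("data/fbank_whisper_80D/librispeech_cuts_dev-clean.jsonl.gz")
dev_other = CutSet.from_file("data/fbank_whisper_80D/librispeech_cuts_dev-other.jsonl.gz")
valid_cuts = dev_clean + dev_other  # lazy concatenation of the two manifests
```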
@@ -911,7 +932,7 @@ def display_and_save_batch(

 def main():
     parser = get_parser()
-    AishellAsrDataModule.add_arguments(parser)
+    LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
