support freezing modules
marcoyang1998 committed Mar 28, 2024
1 parent 360f208 commit cfbc829
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions egs/librispeech/ASR/whisper/train.py
@@ -88,15 +88,6 @@
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
    if isinstance(model, DDP):
        # get underlying nn.Module
        model = model.module
    for module in model.modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -226,6 +217,13 @@ def get_parser():
help="Whether to use half precision training.",
)

    parser.add_argument(
        "--freeze-modules",
        type=str,
        default=None,
        help="Prefix of the module names to freeze during fine-tuning.",
    )

    parser = deepspeed.add_config_arguments(parser)

    return parser
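For context, the freezing logic added below matches parameter and module names by string prefix via str.startswith. A minimal standalone sketch of the idea; the toy model, the freeze_by_prefix helper, and the prefix "encoder" are illustrative and not part of this commit:

import torch.nn as nn

def freeze_by_prefix(model: nn.Module, prefix: str) -> None:
    # Disable gradient updates for every parameter whose name starts with `prefix`...
    for name, p in model.named_parameters():
        if name.startswith(prefix):
            p.requires_grad = False
    # ...and put the matching sub-modules in eval mode.
    for name, module in model.named_modules():
        if name.startswith(prefix):
            module.eval()

toy = nn.ModuleDict({"encoder": nn.Linear(4, 4), "decoder": nn.Linear(4, 4)})
freeze_by_prefix(toy, "encoder")
num_trainable = sum(p.numel() for p in toy.parameters() if p.requires_grad)
print(num_trainable)  # 20: only the decoder's weight (16) and bias (4) still train

In a real run one would presumably pass something like --freeze-modules encoder to keep the Whisper encoder fixed; since the match is a plain string prefix, the value has to line up with the names reported by model.named_parameters().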
@@ -583,6 +581,9 @@ def train_one_epoch(
    be set to 0.
    """
    model.train()
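    # model.train() above put every module into training mode; switch the frozen
    # ones back to eval so any mode-dependent layers they contain keep their
    # inference behavior.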
    if params.freeze_modules is not None:
        for name, module in model.named_modules():
            if name.startswith(params.freeze_modules):
                module.eval()

    tot_loss = MetricsTracker()

@@ -630,7 +631,6 @@
            model.step()
        else:
            scaler.scale(loss).backward()
            set_batch_count(model, params.batch_idx_train)
            scheduler.step_batch(params.batch_idx_train)

            scaler.step(optimizer)
@@ -739,8 +739,19 @@ def run(rank, world_size, args):
    replace_whisper_encoder_forward()
    model = whisper.load_model(params.model_name, "cpu")
    del model.alignment_heads

    if params.freeze_modules is not None:
        # Freeze matching parameters and keep the frozen sub-modules in eval mode.
        for name, p in model.named_parameters():
            if name.startswith(params.freeze_modules):
                p.requires_grad = False
                logging.info(f"Freezing {name}")
        for name, module in model.named_modules():
            if name.startswith(params.freeze_modules):
                module.eval()

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")
    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f"Number of model parameters: {num_param}. Total trainable parameters: {num_trainable}")

    tokenizer = whisper.tokenizer.get_tokenizer(
        model.is_multilingual,
