fix error in accum_grad (#1693)
zzasdf authored Jul 17, 2024
1 parent 2e13298 commit 1115141
Showing 6 changed files with 6 additions and 6 deletions.
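
All six files receive the same one-line fix: the end-of-epoch cleanup that discards a partially accumulated gradient was keyed on batch_idx rather than sub_batch_idx, the counter that drives the accumulation window. Below is a minimal sketch of the gradient-accumulation pattern, written under that assumption; the names train_dl and accum_grad follow the diff, while the loop body is simplified and hypothetical rather than the recipes' actual code:

def train_one_epoch(model, optimizer, train_dl, accum_grad: int):
    sub_batch_idx = 0
    for sub_batch_idx, batch in enumerate(train_dl):
        loss = model(batch).sum()  # placeholder scalar loss
        (loss / accum_grad).backward()  # accumulate scaled gradients
        # Step only on the last sub-batch of each window of size accum_grad.
        if sub_batch_idx % accum_grad == accum_grad - 1:
            optimizer.step()
            optimizer.zero_grad()
    # If the epoch ends mid-window, discard the partial gradients so they
    # cannot leak into the next epoch. This check must use the same counter
    # that gated optimizer.step() above; keying it on a different counter,
    # as the buggy code did with batch_idx, desynchronizes the two.
    if sub_batch_idx % accum_grad != accum_grad - 1:
        optimizer.zero_grad()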
2 changes: 1 addition & 1 deletion egs/librispeech/SSL/hubert/finetune.py
@@ -948,7 +948,7 @@ def save_bad_model(suffix: str = ""):
tb_writer, "train/valid_", params.batch_idx_train
)

-if batch_idx % params.accum_grad != params.accum_grad - 1:
+if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
optimizer.zero_grad()
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
2 changes: 1 addition & 1 deletion egs/librispeech/SSL/hubert/finetune_ce.py
@@ -948,7 +948,7 @@ def save_bad_model(suffix: str = ""):
tb_writer, "train/valid_", params.batch_idx_train
)

-if batch_idx % params.accum_grad != params.accum_grad - 1:
+if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
optimizer.zero_grad()
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
2 changes: 1 addition & 1 deletion egs/librispeech/SSL/hubert/pretrain.py
@@ -774,7 +774,7 @@ def save_bad_model(suffix: str = ""):
tb_writer, "train/valid_", params.batch_idx_train
)

-if batch_idx % params.accum_grad != params.accum_grad - 1:
+if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
optimizer.zero_grad()
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
2 changes: 1 addition & 1 deletion egs/librispeech/SSL/hubert/pretrain_ce.py
@@ -774,7 +774,7 @@ def save_bad_model(suffix: str = ""):
tb_writer, "train/valid_", params.batch_idx_train
)

-if batch_idx % params.accum_grad != params.accum_grad - 1:
+if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
optimizer.zero_grad()
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
2 changes: 1 addition & 1 deletion egs/librispeech/SSL/zipformer/finetune.py
@@ -1245,7 +1245,7 @@ def save_bad_model(suffix: str = ""):
tb_writer, "train/valid_", params.batch_idx_train
)

-if batch_idx % params.accum_grad != params.accum_grad - 1:
+if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
optimizer.zero_grad()
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
2 changes: 1 addition & 1 deletion egs/librispeech/SSL/zipformer/pretrain.py
@@ -1072,7 +1072,7 @@ def save_bad_model(suffix: str = ""):
tb_writer, "train/valid_", params.batch_idx_train
)

-if batch_idx % params.accum_grad != params.accum_grad - 1:
+if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
optimizer.zero_grad()
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
