Skip to content

Commit

Permalink
Reformat by black non-streaming zipformer recipe for ksponspeech (k2-…
Browse files Browse the repository at this point in the history
  • Loading branch information
whsqkaak authored and Your Name committed Aug 9, 2024
1 parent 5dae5fb commit 18729f1
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 8 deletions.
8 changes: 5 additions & 3 deletions egs/ksponspeech/ASR/zipformer/ctc_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,9 @@ def save_results(
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
cer = write_error_stats(f, f"{test_set_name}-{key}", results, compute_CER=True)
cer = write_error_stats(
f, f"{test_set_name}-{key}", results, compute_CER=True
)
test_set_cers[key] = cer

logging.info("Wrote detailed error stats to {}".format(errs_filename))
Expand Down Expand Up @@ -807,15 +809,15 @@ def main():

# we need cut ids to display recognition results.
args.return_cuts = True

ksponspeech = KsponSpeechAsrDataModule(args)

eval_clean_cuts = ksponspeech.eval_clean_cuts()
eval_other_cuts = ksponspeech.eval_other_cuts()

eval_clean_dl = ksponspeech.test_dataloaders(eval_clean_cuts)
eval_other_dl = ksponspeech.test_dataloaders(eval_other_cuts)

test_sets = ["eval_clean", "eval_other"]
test_dl = [eval_clean_dl, eval_other_dl]

Expand Down
6 changes: 5 additions & 1 deletion egs/ksponspeech/ASR/zipformer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,11 @@ def save_results(
)
with open(errs_filename, "w") as f:
cer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True,
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
)
test_set_cers[key] = cer

Expand Down
6 changes: 5 additions & 1 deletion egs/ksponspeech/ASR/zipformer/streaming_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,11 @@ def save_results(
)
with open(errs_filename, "w") as f:
cer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True,
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
)
test_set_cers[key] = cer

Expand Down
4 changes: 1 addition & 3 deletions egs/ksponspeech/ASR/zipformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,9 +961,7 @@ def save_bad_model(suffix: str = ""):
scaler.update()
optimizer.zero_grad()
except Exception as e:
logging.info(
f"Caught exception: {e}."
)
logging.info(f"Caught exception: {e}.")
save_bad_model()
display_and_save_batch(batch, params=params, sp=sp)
raise
Expand Down

0 comments on commit 18729f1

Please sign in to comment.