From 18729f13f1e6fe369642eb1a0d1b6704923449da Mon Sep 17 00:00:00 2001 From: Seung Hyun Lee Date: Mon, 24 Jun 2024 16:28:09 +0900 Subject: [PATCH] Reformat by black non-streaming zipformer recipe for ksponspeech (#1665) --- egs/ksponspeech/ASR/zipformer/ctc_decode.py | 8 +++++--- egs/ksponspeech/ASR/zipformer/decode.py | 6 +++++- egs/ksponspeech/ASR/zipformer/streaming_decode.py | 6 +++++- egs/ksponspeech/ASR/zipformer/train.py | 4 +--- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/egs/ksponspeech/ASR/zipformer/ctc_decode.py b/egs/ksponspeech/ASR/zipformer/ctc_decode.py index 9f04f5d4dc..30bf1610b3 100755 --- a/egs/ksponspeech/ASR/zipformer/ctc_decode.py +++ b/egs/ksponspeech/ASR/zipformer/ctc_decode.py @@ -571,7 +571,9 @@ def save_results( # ref/hyp pairs. errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: - cer = write_error_stats(f, f"{test_set_name}-{key}", results, compute_CER=True) + cer = write_error_stats( + f, f"{test_set_name}-{key}", results, compute_CER=True + ) test_set_cers[key] = cer logging.info("Wrote detailed error stats to {}".format(errs_filename)) @@ -807,7 +809,7 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - + ksponspeech = KsponSpeechAsrDataModule(args) eval_clean_cuts = ksponspeech.eval_clean_cuts() @@ -815,7 +817,7 @@ def main(): eval_clean_dl = ksponspeech.test_dataloaders(eval_clean_cuts) eval_other_dl = ksponspeech.test_dataloaders(eval_other_cuts) - + test_sets = ["eval_clean", "eval_other"] test_dl = [eval_clean_dl, eval_other_dl] diff --git a/egs/ksponspeech/ASR/zipformer/decode.py b/egs/ksponspeech/ASR/zipformer/decode.py index be42898b75..5c21abb790 100755 --- a/egs/ksponspeech/ASR/zipformer/decode.py +++ b/egs/ksponspeech/ASR/zipformer/decode.py @@ -727,7 +727,11 @@ def save_results( ) with open(errs_filename, "w") as f: cer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True, + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=True, ) test_set_cers[key] = cer diff --git a/egs/ksponspeech/ASR/zipformer/streaming_decode.py b/egs/ksponspeech/ASR/zipformer/streaming_decode.py index 9811bac7c3..73a681c6ad 100755 --- a/egs/ksponspeech/ASR/zipformer/streaming_decode.py +++ b/egs/ksponspeech/ASR/zipformer/streaming_decode.py @@ -659,7 +659,11 @@ def save_results( ) with open(errs_filename, "w") as f: cer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True, + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=True, ) test_set_cers[key] = cer diff --git a/egs/ksponspeech/ASR/zipformer/train.py b/egs/ksponspeech/ASR/zipformer/train.py index 5957fe1fb1..b612b6835a 100755 --- a/egs/ksponspeech/ASR/zipformer/train.py +++ b/egs/ksponspeech/ASR/zipformer/train.py @@ -961,9 +961,7 @@ def save_bad_model(suffix: str = ""): scaler.update() optimizer.zero_grad() except Exception as e: - logging.info( - f"Caught exception: {e}." - ) + logging.info(f"Caught exception: {e}.") save_bad_model() display_and_save_batch(batch, params=params, sp=sp) raise