Skip to content

Commit

Permalink
update the decoding script
Browse files (browse the repository at this point in the history)
  • Loading branch information
marcoyang1998 committed Mar 28, 2024
1 parent cfbc829 commit 5d41dec
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions egs/librispeech/ASR/whisper/decode.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -348,31 +348,26 @@ def save_results(
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f,
f"{test_set_name}-{key}",
results_char,
results,
enable_log=enable_log,
compute_CER=True,
)
test_set_wers[key] = wer

if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))

test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)

s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
Expand All @@ -391,16 +386,21 @@ def main():
params.update(vars(args))
params.res_dir = params.exp_dir / params.method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.method == "beam_search":
params.suffix += f"-beam-search-beam-size-{params.beam_size}"

params.suffix += f"-whisper-{params.model_name}"
setup_logger(
f"{params.res_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
f"{params.res_dir}/log-{params.method}/log-decode-{params.suffix}"
)

options = whisper.DecodingOptions(
task="transcribe",
language="en",
without_timestamps=True,
#beam_size=params.beam_size,
beam_size=params.beam_size if params.method == "beam_search" else None,
)

params.decoding_options = options
params.cleaner = BasicTextNormalizer()
params.normalizer = Normalizer()
Expand Down

0 comments on commit 5d41dec

Please sign in to comment.