From 8635fb4334de5854ee3f5286fff30ed35707578f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 5 May 2022 20:58:46 +0800 Subject: [PATCH] Fix decoding for gigaspeech in the libri + giga setup. (#345) --- .../decode-giga.py | 131 ++++++++++++++---- 1 file changed, 104 insertions(+), 27 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py index e6a9a0aee7..a715a2a5ca 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py @@ -69,7 +69,8 @@ from asr_datamodule import AsrDataModule from beam_search import ( beam_search, - fast_beam_search, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, greedy_search, greedy_search_batch, modified_beam_search, @@ -100,27 +101,28 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", ) + parser.add_argument( - "--avg", + "--iter", type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, ) parser.add_argument( - "--avg-last-n", + "--avg", type=int, - default=0, - help="""If positive, --epoch and --avg are ignored and it - will use the last n checkpoints exp_dir/checkpoint-xxx.pt - where xxx is the number of processed batches while - saving that checkpoint. - """, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -146,6 +148,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_nbest_oracle """, ) @@ -165,7 +168,8 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", + Used only when --decoding-method is + fast_beam_search or fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -173,7 +177,7 @@ def get_parser(): type=int, default=4, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search or fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -181,7 +185,7 @@ def get_parser(): type=int, default=8, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search or fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -199,6 +203,23 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for computed nbest oracle WER + when the decoding method is fast_beam_search_nbest_oracle. + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding_method is fast_beam_search_nbest_oracle. + """, + ) return parser @@ -243,7 +264,8 @@ def decode_one_batch( for the format of the `batch`. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is + fast_beam_search or fast_beam_search_nbest_oracle. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -264,7 +286,7 @@ def decode_one_batch( hyps = [] if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search( + hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -275,6 +297,21 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 @@ -328,6 +365,16 @@ def decode_one_batch( f"max_states_{params.max_states}" ): hyps } + elif params.decoding_method == "fast_beam_search_nbest_oracle": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_" + f"num_paths_{params.num_paths}_" + f"nbest_scale_{params.nbest_scale}" + ): hyps + } else: return {f"beam_size_{params.beam_size}": hyps} @@ -463,17 +510,30 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_nbest_oracle", "modified_beam_search", ) params.res_dir = params.exp_dir / "giga" / params.decoding_method - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "fast_beam_search" in params.decoding_method: + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif params.decoding_method == "fast_beam_search_nbest_oracle": params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"-num-paths-{params.num_paths}" + params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -490,8 +550,9 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # is defined in local/train_bpe_model.py + # and is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") + params.unk_id = sp.unk_id() params.vocab_size = sp.get_piece_size() logging.info(params) @@ -499,8 +560,20 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.avg_last_n > 0: - filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) logging.info(f"averaging {filenames}") model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) @@ -519,13 +592,17 @@ def main(): model.to(device) model.eval() model.device = device + model.unk_id = params.unk_id # In beam_search.py, we are using model.decoder() and model.joiner(), # so we have to switch to the branch for the GigaSpeech dataset. model.decoder = model.decoder_giga model.joiner = model.joiner_giga - if params.decoding_method == "fast_beam_search": + if params.decoding_method in ( + "fast_beam_search", + "fast_beam_search_nbest_oracle", + ): decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None