Add ctc-greedy-search decoding method to the zipformer recipe #1690

Merged
57 changes: 41 additions & 16 deletions egs/librispeech/ASR/zipformer/ctc_decode.py
@@ -21,7 +21,16 @@
"""
Usage:

(1) ctc-greedy-search
./zipformer/ctc_decode.py \
    --epoch 30 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --use-ctc 1 \
    --max-duration 600 \
    --decoding-method ctc-greedy-search

(2) ctc-decoding
./zipformer/ctc_decode.py \
    --epoch 30 \
    --avg 15 \
@@ -30,7 +39,7 @@
    --max-duration 600 \
    --decoding-method ctc-decoding

(3) 1best
./zipformer/ctc_decode.py \
    --epoch 30 \
    --avg 15 \
@@ -40,7 +49,7 @@
    --hlg-scale 0.6 \
    --decoding-method 1best

(4) nbest
./zipformer/ctc_decode.py \
    --epoch 30 \
    --avg 15 \
@@ -50,7 +59,7 @@
    --hlg-scale 0.6 \
    --decoding-method nbest

(5) nbest-rescoring
./zipformer/ctc_decode.py \
    --epoch 30 \
    --avg 15 \
@@ -62,7 +71,7 @@
    --lm-dir data/lm \
    --decoding-method nbest-rescoring

(6) whole-lattice-rescoring
./zipformer/ctc_decode.py \
    --epoch 30 \
    --avg 15 \
@@ -74,7 +83,7 @@
    --lm-dir data/lm \
    --decoding-method whole-lattice-rescoring

(7) attention-decoder-rescoring-no-ngram
./zipformer/ctc_decode.py \
    --epoch 30 \
    --avg 15 \
@@ -84,7 +93,7 @@
    --max-duration 100 \
    --decoding-method attention-decoder-rescoring-no-ngram

(8) attention-decoder-rescoring-with-ngram
./zipformer/ctc_decode.py \
    --epoch 30 \
    --avg 15 \
@@ -120,6 +129,7 @@
    load_checkpoint,
)
from icefall.decode import (
    ctc_greedy_search,
    get_lattice,
    nbest_decoding,
    nbest_oracle,
@@ -220,26 +230,29 @@ def get_parser():
default="ctc-decoding",
help="""Decoding method.
Supported values are:
- (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
- (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (2) 1best. Extract the best path from the decoding lattice as the
- (2) ctc-decoding. Use CTC decoding. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (3) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (3) nbest. Extract n paths from the decoding lattice; the path
- (4) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (4) nbest-rescoring. Extract n paths from the decoding lattice,
- (5) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (5) whole-lattice-rescoring. Rescore the decoding lattice with an
- (6) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
you have trained an RNN LM using ./rnn_lm/train.py
- (6) nbest-oracle. Its WER is the lower bound of any n-best
- (7) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
- (7) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
- (8) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
lattice, rescore them with the attention decoder.
- (8) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
- (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
rescored lattice, rescore them with the attention decoder.
""",
)
@@ -381,6 +394,15 @@ def decode_one_batch(
    encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
    ctc_output = model.ctc_output(encoder_out)  # (N, T, C)

if params.decoding_method == "ctc-greedy-search":
hyps = ctc_greedy_search(ctc_output, encoder_out_lens)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(hyps)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "ctc-greedy-search"
return {key: hyps}

    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
@@ -684,6 +706,7 @@ def main():
    params.update(vars(args))

    assert params.decoding_method in (
        "ctc-greedy-search",
        "ctc-decoding",
        "1best",
        "nbest",
@@ -733,7 +756,9 @@
    params.eos_id = 1
    params.sos_id = 1

    if params.decoding_method in [
        "ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram"
    ]:
        HLG = None
        H = k2.ctc_topo(
            max_token=max_token_id,
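For reference, the new decoding path can also be exercised standalone. Below is a minimal sketch of how ctc_greedy_search and the BPE model fit together outside of decode_one_batch(); the BPE model path, the vocabulary size of 500, and the random tensors are illustrative stand-ins, not values from this PR:

import sentencepiece as spm
import torch

from icefall.decode import ctc_greedy_search

bpe_model = spm.SentencePieceProcessor()
bpe_model.load("data/lang_bpe_500/bpe.model")  # illustrative path

# Fake (N, T, C) CTC log-probabilities; in the recipe these come from
# model.ctc_output(encoder_out).
ctc_output = torch.randn(2, 10, 500).log_softmax(dim=-1)
encoder_out_lens = torch.tensor([10, 7])  # valid frames per utterance

token_ids = ctc_greedy_search(ctc_output, encoder_out_lens)  # List[List[int]]
texts = bpe_model.decode(token_ids)  # List[str], e.g., ['xxx yyy zzz', ...]
hyps = [t.split() for t in texts]  # List[List[str]], one word list per utterance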
31 changes: 31 additions & 0 deletions icefall/decode.py
@@ -1473,3 +1473,34 @@ def rescore_with_rnn_lm(
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_rnn_lm_scale_{r_scale}" # noqa
ans[key] = best_path
return ans


def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
    # From https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
    new_hyp: List[int] = []
    cur = 0
    while cur < len(hyp):
        # Keep the first symbol of each run unless it is the blank (id 0).
        if hyp[cur] != 0:
            new_hyp.append(hyp[cur])
        # Skip over the rest of the run of identical symbols.
        prev = cur
        while cur < len(hyp) and hyp[cur] == hyp[prev]:
            cur += 1
    return new_hyp
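
# Illustrative behavior of remove_duplicates_and_blank (blank id is 0):
#   remove_duplicates_and_blank([0, 1, 1, 0, 2, 2, 2, 0, 0, 3]) -> [1, 2, 3]
#   remove_duplicates_and_blank([1, 1, 0, 1]) -> [1, 1]
# In the second case a blank separates two runs of the same symbol,
# so both occurrences are kept.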


def ctc_greedy_search(
    ctc_output: torch.Tensor, encoder_out_lens: torch.Tensor
) -> List[List[int]]:
    """CTC greedy search.

    Args:
      ctc_output: (batch, seq_len, vocab_size)
      encoder_out_lens: (batch,)
    Returns:
      List[List[int]]: greedy search result
    """
    batch = ctc_output.shape[0]
    index = ctc_output.argmax(dim=-1)  # (batch, seq_len)
    hyps = [index[i].tolist()[: encoder_out_lens[i]] for i in range(batch)]
    hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
    return hyps
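
To see the greedy-search path end to end, here is a small self-contained check; the tiny vocabulary and hand-built scores are purely illustrative (argmax is unaffected by whether the scores are normalized log-probabilities):

import torch

from icefall.decode import ctc_greedy_search

# One utterance, 4 frames, vocabulary of 5 tokens (0 is the blank).
logits = torch.zeros(1, 4, 5)
logits[0, 0, 2] = 10.0  # frame 0 -> token 2
logits[0, 1, 2] = 10.0  # frame 1 -> token 2 (repeat, collapsed)
logits[0, 2, 3] = 10.0  # frame 2 -> token 3
logits[0, 3, 4] = 10.0  # frame 3 -> token 4, but it lies beyond the
                        # valid length below and is ignored

hyps = ctc_greedy_search(logits, torch.tensor([3]))
assert hyps == [[2, 3]]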