Commit acdc333

update export.py and pretrained_ctc.py

yaozengwei committed May 26, 2024
1 parent 84dfb57 commit acdc333
Showing 4 changed files with 63 additions and 14 deletions.
31 changes: 29 additions & 2 deletions egs/librispeech/ASR/zipformer/ctc_decode.py
@@ -73,6 +73,29 @@
     --nbest-scale 1.0 \
     --lm-dir data/lm \
     --decoding-method whole-lattice-rescoring
+
+(6) attention-decoder-rescoring-no-ngram
+./zipformer/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer/exp \
+    --use-ctc 1 \
+    --use-attention-decoder 1 \
+    --max-duration 100 \
+    --decoding-method attention-decoder-rescoring-no-ngram
+
+(7) attention-decoder-rescoring-with-ngram
+./zipformer/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer/exp \
+    --use-ctc 1 \
+    --use-attention-decoder 1 \
+    --max-duration 100 \
+    --hlg-scale 0.6 \
+    --nbest-scale 1.0 \
+    --lm-dir data/lm \
+    --decoding-method attention-decoder-rescoring-with-ngram
 """

@@ -101,10 +124,10 @@
     nbest_decoding,
     nbest_oracle,
     one_best_decoding,
-    rescore_with_n_best_list,
-    rescore_with_whole_lattice,
+    rescore_with_attention_decoder_no_ngram,
+    rescore_with_attention_decoder_with_ngram,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -214,6 +237,10 @@ def get_parser():
           - (6) nbest-oracle. Its WER is the lower bound of any n-best
             rescoring method can achieve. Useful for debugging n-best
             rescoring method.
+          - (7) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
+            lattice, rescore them with the attention decoder.
+          - (8) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
+            rescored lattice, rescore them with the attention decoder.
         """,
     )
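For context (not part of this commit): both new methods follow the same n-best rescoring pattern — extract n paths from a lattice, score each path with the attention decoder, and keep the best-scoring one. Method (7) starts from the plain decoding lattice; method (8) starts from the n-gram-LM-rescored lattice. A minimal runnable sketch of that pattern, with illustrative names (Hypothesis, rescore_nbest, attn_score) rather than icefall's actual API:

    from dataclasses import dataclass
    from typing import Callable, List

    @dataclass
    class Hypothesis:
        token_ids: List[int]   # one path extracted from the lattice
        lattice_score: float   # CTC (plus optional n-gram LM) log-score of the path

    def rescore_nbest(
        hyps: List[Hypothesis],
        attn_score: Callable[[List[int]], float],
        attn_scale: float = 1.0,
    ) -> Hypothesis:
        # Keep the path maximizing lattice score + scaled attention-decoder score.
        return max(
            hyps,
            key=lambda h: h.lattice_score + attn_scale * attn_score(h.token_ids),
        )

    # Toy usage with a stand-in scorer; a real system would run the attention
    # decoder over the encoder output and each path's token sequence.
    hyps = [Hypothesis([5, 9, 2], -12.3), Hypothesis([5, 7, 2], -11.8)]
    best = rescore_nbest(hyps, attn_score=lambda ids: -0.5 * len(ids))
    print(best.token_ids)  # [5, 7, 2]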

3 changes: 1 addition & 2 deletions egs/librispeech/ASR/zipformer/export.py
@@ -404,6 +404,7 @@ def main():

     token_table = k2.SymbolTable.from_file(params.tokens)
     params.blank_id = token_table["<blk>"]
+    params.sos_id = params.eos_id = token_table["<sos/eos>"]
     params.vocab_size = num_tokens(token_table) + 1

     logging.info(params)
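Why export.py needs the added line (context, not part of the diff): the attention decoder consumes a shared <sos/eos> token, so its id must be recorded at export time alongside <blk>. A tokens.txt in these recipes is a plain "symbol id" listing, so the lookup performed through k2.SymbolTable amounts to the following sketch (helper name and path are illustrative):

    def read_token_table(path: str) -> dict:
        # One "<symbol> <id>" pair per line, e.g. "<blk> 0" or "<sos/eos> 1".
        table = {}
        with open(path) as f:
            for line in f:
                sym, idx = line.rsplit(maxsplit=1)
                table[sym] = int(idx)
        return table

    # table = read_token_table("data/lang_bpe_500/tokens.txt")
    # table["<blk>"]      -> blank id (0 by convention in this recipe)
    # table["<sos/eos>"]  -> shared start/end-of-sentence id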
@@ -466,8 +467,6 @@ def main():
                     device=device,
                 )
             )
-        elif params.avg == 1:
-            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
         else:
             assert params.avg > 0, params.avg
             start = params.epoch - params.avg
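With the `elif params.avg == 1` shortcut removed, `--avg 1` now flows through the same else branch as larger values. The idea behind that branch is an arithmetic mean over a window of recent epoch checkpoints; a self-contained sketch, assuming each checkpoint stores its weights under a "model" key as icefall checkpoints do (paths and the exact averaging window are illustrative — the real script delegates to its own averaging helpers):

    import torch

    def average_state_dicts(paths):
        # Arithmetic mean of model weights across checkpoint files.
        avg = None
        for p in paths:
            sd = torch.load(p, map_location="cpu")["model"]
            if avg is None:
                avg = {k: v.clone().float() for k, v in sd.items()}
            else:
                for k in avg:
                    avg[k] += sd[k].float()
        return {k: v / len(paths) for k, v in avg.items()}

    # e.g. --epoch 30 --avg 15 averages a window of epoch checkpoints:
    # paths = [f"zipformer/exp/epoch-{i}.pt" for i in range(16, 31)]
    # model.load_state_dict(average_state_dicts(paths))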
40 changes: 32 additions & 8 deletions egs/librispeech/ASR/zipformer/pretrained_ctc.py
@@ -81,6 +81,15 @@
     --sample-rate 16000 \
     /path/to/foo.wav \
     /path/to/bar.wav
+
+(5) attention-decoder-rescoring-no-ngram
+./zipformer/pretrained_ctc.py \
+    --checkpoint ./zipformer/exp/pretrained.pt \
+    --tokens data/lang_bpe_500/tokens.txt \
+    --method attention-decoder-rescoring-no-ngram \
+    --sample-rate 16000 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
 """

 import argparse
@@ -100,6 +109,7 @@
 from icefall.decode import (
     get_lattice,
     one_best_decoding,
+    rescore_with_attention_decoder_no_ngram,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
@@ -172,6 +182,8 @@ def get_parser():
           decoding lattice and then use 1best to decode the
           rescored lattice.
           We call it HLG decoding + whole-lattice n-gram LM rescoring.
+        (4) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
+          lattice, rescore them with the attention decoder.
         """,
     )

@@ -276,6 +288,7 @@ def main():
     token_table = k2.SymbolTable.from_file(params.tokens)
     params.vocab_size = num_tokens(token_table) + 1  # +1 for blank
     params.blank_id = token_table["<blk>"]
+    params.sos_id = params.eos_id = token_table["<sos/eos>"]
     assert params.blank_id == 0

     logging.info(f"{params}")
@@ -333,16 +346,13 @@ def main():
         dtype=torch.int32,
     )

-    if params.method == "ctc-decoding":
-        logging.info("Use CTC decoding")
+    if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
         max_token_id = params.vocab_size - 1
-
         H = k2.ctc_topo(
             max_token=max_token_id,
             modified=False,
             device=device,
         )
-
         lattice = get_lattice(
             nnet_output=ctc_output,
             decoding_graph=H,
@@ -354,9 +364,23 @@
         subsampling_factor=params.subsampling_factor,
     )

-    best_path = one_best_decoding(
-        lattice=lattice, use_double_scores=params.use_double_scores
-    )
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+    else:
+        logging.info("Use attention decoder rescoring without ngram")
+        best_path_dict = rescore_with_attention_decoder_no_ngram(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            attention_decoder=model.attention_decoder,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            nbest_scale=params.nbest_scale,
+        )
+        best_path = next(iter(best_path_dict.values()))

     token_ids = get_texts(best_path)
     hyps = [[token_table[i] for i in ids] for ids in token_ids]
     elif params.method in [
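One detail worth flagging in the new branch above (an observation, not part of the diff): the rescoring helper returns a dict of results — apparently one best path per rescoring scale tried — and `next(iter(...))` keeps whichever entry was inserted first. The same pattern in plain Python, with illustrative keys:

    # Keys are illustrative; the helper keys each result by the scale
    # applied to the attention-decoder score.
    best_path_dict = {
        "attention_scale_0.5": "path-for-scale-0.5",
        "attention_scale_1.0": "path-for-scale-1.0",
    }
    best_path = next(iter(best_path_dict.values()))
    print(best_path)  # path-for-scale-0.5 (dicts preserve insertion order)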
@@ -430,7 +454,7 @@ def main():
         raise ValueError(f"Unsupported decoding method: {params.method}")

     s = "\n"
-    if params.method == "ctc-decoding":
+    if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
         for filename, hyp in zip(params.sound_files, hyps):
             words = "".join(hyp)
             words = words.replace("▁", " ").strip()
3 changes: 1 addition & 2 deletions egs/librispeech/ASR/zipformer/train.py
@@ -1199,8 +1199,7 @@ def run(rank, world_size, args):

     # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.eos_id = sp.piece_to_id("<sos/eos>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
+    params.sos_id = params.eos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     if not params.use_transducer:
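The train.py change is purely cosmetic: one assignment instead of two identical lookups for the shared <sos/eos> symbol. For reference, a minimal sketch of where these ids come from (model path is illustrative; `<blk>` is defined in local/train_bpe_model.py per the comment above, and presumably `<sos/eos>` likewise):

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/bpe.model")  # illustrative path

    blank_id = sp.piece_to_id("<blk>")
    sos_eos_id = sp.piece_to_id("<sos/eos>")  # one id serves as both sos and eos
    vocab_size = sp.get_piece_size()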