diff --git a/egs/aishell/asr/simple_v1/mmi_att_transformer_decode.py b/egs/aishell/asr/simple_v1/mmi_att_transformer_decode.py
index d2c590df..2e4a5676 100644
--- a/egs/aishell/asr/simple_v1/mmi_att_transformer_decode.py
+++ b/egs/aishell/asr/simple_v1/mmi_att_transformer_decode.py
@@ -26,8 +26,8 @@
 from snowfall.common import setup_logger
 from snowfall.decoding.graph import compile_LG
 from snowfall.models import AcousticModel
-from snowfall.models.transformer import Transformer
 from snowfall.models.conformer import Conformer
+from snowfall.models.transformer import Transformer
 from snowfall.training.ctc_graph import build_ctc_topo
 from snowfall.training.mmi_graph import create_bigram_phone_lm
 from snowfall.training.mmi_graph import get_phone_symbols
@@ -268,7 +268,8 @@ def main():
     P.set_scores_stochastic_(model.P_scores)
     print_transition_probabilities(P, phone_symbol_table, phone_ids, filename='P_scores.txt')
 
-    if not os.path.exists(lang_dir / 'LG.pt'):
+    HLG_path = exp_dir / 'HLG.pt'
+    if not HLG_path.exists():
         logging.debug("Loading L_disambig.fst.txt")
         with open(lang_dir / 'L_disambig.fst.txt') as f:
             L = k2.Fsa.from_openfst(f.read(), acceptor=False)
@@ -282,10 +283,10 @@ def main():
                         ctc_topo=ctc_topo,
                         labels_disambig_id_start=first_phone_disambig_id,
                         aux_labels_disambig_id_start=first_word_disambig_id)
-        torch.save(LG.as_dict(), lang_dir / 'LG.pt')
+        torch.save(LG.as_dict(), HLG_path)
     else:
-        logging.debug("Loading pre-compiled LG")
-        d = torch.load(lang_dir / 'LG.pt')
+        logging.debug("Loading pre-compiled HLG")
+        d = torch.load(HLG_path)
         LG = k2.Fsa.from_dict(d)
 
     # load dataset
diff --git a/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
index ba771fdd..f40befab 100755
--- a/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
+++ b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
@@ -25,9 +25,9 @@
 from snowfall.common import setup_logger
 from snowfall.decoding.graph import compile_HLG
 from snowfall.models import AcousticModel
-from snowfall.models.transformer import Transformer
 from snowfall.models.conformer import Conformer
-from snowfall.training.ctc_graph import build_ctc_topo
+from snowfall.models.transformer import Transformer
+from snowfall.training.hmm_topo import build_hmm_topo_2state
 from snowfall.training.mmi_graph import create_bigram_phone_lm
 from snowfall.training.mmi_graph import get_phone_symbols
 
@@ -206,7 +206,7 @@ def main():
     avg = args.avg
     att_rate = args.att_rate
 
-    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan')
+    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-hmm')
     setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')
 
     # load L, G, symbol_table
@@ -218,7 +218,8 @@ def main():
     P = create_bigram_phone_lm(phone_ids)
 
     phone_ids_with_blank = [0] + phone_ids
-    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))
+    # H = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))
+    H = build_hmm_topo_2state(phone_ids_with_blank)
 
     logging.debug("About to load model")
     # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
@@ -235,7 +236,7 @@ def main():
             num_features=40,
             nhead=args.nhead,
             d_model=args.attention_dim,
-            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
+            num_classes=2 * len(phone_ids) + 2,  # 2 HMM states for each phone and for the blank
             subsampling_factor=4,
             num_decoder_layers=num_decoder_layers)
     else:
@@ -243,7 +244,7 @@ def main():
             num_features=40,
             nhead=args.nhead,
             d_model=args.attention_dim,
-            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
+            num_classes=2 * len(phone_ids) + 2,  # 2 HMM states for each phone and for the blank
             subsampling_factor=4,
             num_decoder_layers=num_decoder_layers)
 
@@ -267,7 +268,8 @@ def main():
     P.set_scores_stochastic_(model.P_scores)
     print_transition_probabilities(P, phone_symbol_table, phone_ids, filename='P_scores.txt')
 
-    if not os.path.exists(lang_dir / 'HLG.pt'):
+    HLG_path = exp_dir / 'HLG.pt'
+    if not HLG_path.exists():
         logging.debug("Loading L_disambig.fst.txt")
         with open(lang_dir / 'L_disambig.fst.txt') as f:
             L = k2.Fsa.from_openfst(f.read(), acceptor=False)
@@ -277,14 +279,14 @@ def main():
         first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table)
         first_word_disambig_id = find_first_disambig_symbol(symbol_table)
         HLG = compile_HLG(L=L,
-                        G=G,
-                        H=ctc_topo,
-                        labels_disambig_id_start=first_phone_disambig_id,
-                        aux_labels_disambig_id_start=first_word_disambig_id)
-        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
+                          G=G,
+                          H=H,
+                          labels_disambig_id_start=first_phone_disambig_id,
+                          aux_labels_disambig_id_start=first_word_disambig_id)
+        torch.save(HLG.as_dict(), HLG_path)
     else:
         logging.debug("Loading pre-compiled HLG")
-        d = torch.load(lang_dir / 'HLG.pt')
+        d = torch.load(HLG_path)
         HLG = k2.Fsa.from_dict(d)
 
     # load dataset
diff --git a/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py b/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
index 836f404f..1ab4ac85 100755
--- a/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
+++ b/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
@@ -29,9 +29,10 @@
 from snowfall.common import save_training_info
 from snowfall.common import setup_logger
 from snowfall.models import AcousticModel
-from snowfall.models.transformer import Noam, Transformer
 from snowfall.models.conformer import Conformer
+from snowfall.models.transformer import Noam, Transformer
 from snowfall.training.diagnostics import measure_gradient_norms, optim_step_and_measure_param_change
+from snowfall.training.hmm_topo import build_hmm_topo_2state
 from snowfall.training.mmi_graph import MmiTrainingGraphCompiler
 from snowfall.training.mmi_graph import create_bigram_phone_lm
 from snowfall.training.mmi_graph import get_phone_symbols
@@ -64,7 +65,7 @@ def get_tot_objf_and_num_frames(tot_scores: torch.Tensor,
               frames_per_seq[bad_indexes], " vs. max length ",
max length ", torch.max(frames_per_seq), ", avg ", (torch.sum(frames_per_seq) / frames_per_seq.numel())) - # print("finite_indexes = ", finite_indexes, ", tot_scores = ", tot_scores) + #print("finite_indexes = ", finite_indexes, ", tot_scores = ", tot_scores) ok_frames = frames_per_seq[finite_indexes].sum() all_frames = frames_per_seq.sum() return (tot_scores[finite_indexes].sum(), ok_frames, all_frames) @@ -134,6 +135,7 @@ def get_objf(batch: Dict, num = k2.intersect_dense(num, dense_fsa_vec, 10.0) den = k2.intersect_dense(den, dense_fsa_vec, 10.0) + #den = k2.intersect_dense_pruned(den, dense_fsa_vec, search_beam=10.0, output_beam=10.0, min_active_states=100, max_active_states=1000) num_tot_scores = num.get_tot_scores( log_semiring=True, @@ -446,7 +448,7 @@ def main(): fix_random_seed(42) - exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan') + exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-hmm') setup_logger('{}/log/log-train'.format(exp_dir)) tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard') @@ -467,11 +469,13 @@ def main(): device_id = 0 device = torch.device('cuda', device_id) + logging.info('Initializing the MMI graph compiler') graph_compiler = MmiTrainingGraphCompiler( L_inv=L_inv, phones=phone_symbol_table, words=word_symbol_table, device=device, + topo_builder_fn=build_hmm_topo_2state ) phone_ids = get_phone_symbols(phone_symbol_table) P = create_bigram_phone_lm(phone_ids) @@ -550,7 +554,7 @@ def main(): num_features=40, nhead=args.nhead, d_model=args.attention_dim, - num_classes=len(phone_ids) + 1, # +1 for the blank symbol + num_classes=2 * len(phone_ids) + 2, # +1 for the blank symbol subsampling_factor=4, num_decoder_layers=num_decoder_layers) else: @@ -558,7 +562,7 @@ def main(): num_features=40, nhead=args.nhead, d_model=args.attention_dim, - num_classes=len(phone_ids) + 1, # +1 for the blank symbol + num_classes=2 * len(phone_ids) + 2, # +1 for the blank symbol subsampling_factor=4, num_decoder_layers=num_decoder_layers) diff --git a/snowfall/training/hmm_topo.py b/snowfall/training/hmm_topo.py new file mode 100644 index 00000000..9cc882e9 --- /dev/null +++ b/snowfall/training/hmm_topo.py @@ -0,0 +1,52 @@ +import k2 +from typing import List + + +def build_hmm_topo_2state(tokens: List[int]) -> k2.Fsa: + """ + Build a 2-state HMM topology used in Kaldi's chain models. + The first HMM state is entered only once for each token instance, + and the second HMM state is self-looped and optional. + + Args: + tokens: + A list of token int IDs, e.g., phones, characters, etc. + The IDs for the first HMM state will be the same as token IDs; + The IDs for the second HMM state are: ``token_id + len(tokens)`` + Returns: + An FST that converts a sequence of HMM state IDs to a sequence of token IDs. 
+    """
+    min_token_id = min(tokens)
+    followup_tokens = list(range(
+        len(tokens) + min_token_id,
+        2 * len(tokens) + min_token_id
+    ))
+    num_states = len(tokens) + 2  # + start state, + final state
+    arcs = []
+
+    # Start state -> token state
+    for i in range(0, len(tokens)):
+        arcs += [f'0 {i + 1} {tokens[i]} {tokens[i]} 0.0']
+
+    # Token state self loops
+    for i in range(0, len(tokens)):
+        arcs += [f'{i + 1} {i + 1} {followup_tokens[i]} 0 0.0']
+
+    # Cross-token transitions (enter the first HMM state of the next token)
+    for i in range(0, len(tokens)):
+        for j in range(0, len(tokens)):
+            if i != j:
+                arcs += [f'{i + 1} {j + 1} {tokens[j]} {tokens[j]} 0.0']
+
+    # Token state -> superfinal state
+    for i in range(0, len(tokens)):
+        arcs += [f'{i + 1} {num_states - 1} -1 -1 0.0']
+
+    # Final state
+    arcs += [f'{num_states - 1}']
+
+    # Build the FST
+    arcs = '\n'.join(sorted(arcs, key=lambda arc: int(arc.split()[0])))
+    ans = k2.Fsa.from_str(arcs)
+    ans = k2.arc_sort(ans)
+    return ans
diff --git a/snowfall/training/mmi_graph.py b/snowfall/training/mmi_graph.py
index ecdbd1d8..e0ad29e6 100644
--- a/snowfall/training/mmi_graph.py
+++ b/snowfall/training/mmi_graph.py
@@ -1,14 +1,13 @@
 # Copyright (c) 2020 Xiaomi Corp. (author: Fangjun Kuang)
 
+import k2
+import torch
 from typing import Iterable
 from typing import List
 from typing import Tuple
 
-import k2
-import torch
-
-from .ctc_graph import build_ctc_topo
 from snowfall.common import get_phone_symbols
+from .ctc_graph import build_ctc_topo
 
 
 def create_bigram_phone_lm(phones: List[int]) -> k2.Fsa:
@@ -47,6 +46,7 @@ def __init__(self,
                  phones: k2.SymbolTable,
                  words: k2.SymbolTable,
                  device: torch.device,
+                 topo_builder_fn=build_ctc_topo,
                  oov: str = '<UNK>'):
         '''
         Args:
@@ -78,10 +78,9 @@ def __init__(self,
         phone_symbols = get_phone_symbols(phones)
         phone_symbols_with_blank = [0] + phone_symbols
 
-        ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
-        assert ctc_topo.requires_grad is False
-
-        self.ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())
+        H = topo_builder_fn(phone_symbols_with_blank).to(device)
+        assert H.requires_grad is False
+        self.H_inv = k2.arc_sort(H.invert_())
 
     def compile(self, texts: Iterable[str],
                 P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa]:
@@ -106,28 +105,28 @@ def compile(self, texts: Iterable[str],
         assert P.device == self.device
         P_with_self_loops = k2.add_epsilon_self_loops(P)
 
-        ctc_topo_P = k2.intersect(self.ctc_topo_inv,
-                                  P_with_self_loops,
-                                  treat_epsilons_specially=False).invert()
-
-        ctc_topo_P = k2.arc_sort(ctc_topo_P)
+        HP = k2.intersect(
+            self.H_inv,
+            P_with_self_loops,
+            treat_epsilons_specially=False
+        ).invert()
+        HP = k2.arc_sort(HP)
 
         num_graphs = self.build_num_graphs(texts)
 
         num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
             num_graphs)
-
         num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)
 
-        num = k2.compose(ctc_topo_P,
+        num = k2.compose(HP,
                          num_graphs_with_self_loops,
                          treat_epsilons_specially=False)
         num = k2.arc_sort(num)
 
-        ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
+        HP_vec = k2.create_fsa_vec([HP.detach()])
         indexes = torch.zeros(len(texts),
                               dtype=torch.int32,
                               device=self.device)
-        den = k2.index_fsa(ctc_topo_P_vec, indexes)
+        den = k2.index_fsa(HP_vec, indexes)
 
         return num, den
@@ -163,3 +162,4 @@ def build_num_graphs(self, texts: List[str]) -> k2.Fsa:
                                 treat_epsilons_specially=False).invert_()
         num_graphs = k2.arc_sort(num_graphs)
         return num_graphs
+
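
Usage sketch (not part of the patch): a minimal example of how the new build_hmm_topo_2state topology can be built and inspected, assuming the patch above is applied to snowfall and a matching k2 version is installed. The toy token list below stands in for [0] + phone_ids and shows where the 2 * len(phone_ids) + 2 output-class count used in the training and decoding scripts comes from.

from snowfall.training.hmm_topo import build_hmm_topo_2state

# Toy inventory standing in for [0] + phone_ids: the blank (0) plus two phones.
tokens = [0, 1, 2]
H = build_hmm_topo_2state(tokens)

# Input labels 0..2 are the "entry" HMM states, consumed exactly once per
# token instance; labels 3..5 are the optional self-looped second states.
# An acoustic model driving this topology therefore needs 2 * len(tokens)
# output units, i.e. 2 * len(phone_ids) + 2 when tokens = [0] + phone_ids.
num_classes = 2 * len(tokens)
print('num_classes =', num_classes)

# H is a k2.Fsa transducer (HMM state IDs in, token IDs out); printing it
# lists its arcs in k2's text format.
print(H)

Because the second HMM state is optional, a token can occupy as little as one frame, so the topology plugs into the existing MmiTrainingGraphCompiler and compile_HLG pipeline via the new topo_builder_fn and H arguments without further structural changes.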