diff --git a/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
index 76c7cc08..86b05c74 100755
--- a/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
+++ b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
@@ -71,6 +71,7 @@ from snowfall.models.contextnet import ContextNet
 from snowfall.training.ctc_graph import build_ctc_topo
 from snowfall.training.mmi_graph import create_bigram_phone_lm
+from snowfall.training.mmi_graph import create_unigram_phone_lm
 from snowfall.training.mmi_graph import get_phone_symbols
 
 
 def nbest_decoding(lats: k2.Fsa, num_paths: int):
@@ -401,6 +402,15 @@
         type=str2bool,
         default=True,
         help='When enabled, it uses vgg style network for subsampling')
+
+    parser.add_argument(
+        '--use-unigram-lm',
+        type=str2bool,
+        default=False,
+        help='True to use unigram LM for P. False to use bigram LM for P. '
+             'This is used only for checkpoint-loading.'
+    )
+
     return parser
 
 
@@ -423,7 +433,10 @@ def main():
 
     output_beam_size = args.output_beam_size
 
-    exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
+    if args.use_unigram_lm:
+        exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer-unigram')
+    else:
+        exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
     setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')
     logging.info(f'output_beam_size: {output_beam_size}')
 
@@ -434,7 +447,12 @@
 
     phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
     phone_ids = get_phone_symbols(phone_symbol_table)
-    P = create_bigram_phone_lm(phone_ids)
+    if args.use_unigram_lm:
+        logging.info('Use unigram LM for P')
+        P = create_unigram_phone_lm(phone_ids)
+    else:
+        logging.info('Use bigram LM for P')
+        P = create_bigram_phone_lm(phone_ids)
 
     phone_ids_with_blank = [0] + phone_ids
     ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))
diff --git a/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py b/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
index f9526481..3601843d 100755
--- a/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
+++ b/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
@@ -44,6 +44,7 @@
 from snowfall.training.diagnostics import measure_gradient_norms, optim_step_and_measure_param_change
 from snowfall.training.mmi_graph import MmiTrainingGraphCompiler
 from snowfall.training.mmi_graph import create_bigram_phone_lm
+from snowfall.training.mmi_graph import create_unigram_phone_lm
 
 
 def get_objf(batch: Dict,
@@ -461,6 +462,14 @@
         'so that they can be simply loaded with torch.jit.load(). '
         '-1 disables this option.'
     )
+
+    parser.add_argument(
+        '--use-unigram-lm',
+        type=str2bool,
+        default=False,
+        help='True to use unigram LM for P. False to use bigram LM for P.'
+    )
+
     return parser
 
 
@@ -487,7 +496,10 @@ def run(rank, world_size, args):
     fix_random_seed(42)
     setup_dist(rank, world_size, args.master_port)
 
-    exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
+    if args.use_unigram_lm:
+        exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer-unigram')
+    else:
+        exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
     setup_logger(f'{exp_dir}/log/log-train-{rank}')
     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
@@ -507,7 +519,14 @@
         device=device,
     )
     phone_ids = lexicon.phone_symbols()
-    P = create_bigram_phone_lm(phone_ids)
+
+    if args.use_unigram_lm:
+        logging.info('Use unigram LM for P')
+        P = create_unigram_phone_lm(phone_ids)
+    else:
+        logging.info('Use bigram LM for P')
+        P = create_bigram_phone_lm(phone_ids)
+
     P.scores = torch.zeros_like(P.scores)
     P = P.to(device)
 
diff --git a/snowfall/training/mmi_graph.py b/snowfall/training/mmi_graph.py
index 758830f7..c970bbb3 100644
--- a/snowfall/training/mmi_graph.py
+++ b/snowfall/training/mmi_graph.py
@@ -38,7 +38,36 @@
             rules += f'{i} {j} {phones[j-1]} 0.0\n'
         rules += f'{i} {final_state} -1 0.0\n'
     rules += f'{final_state}'
-    return k2.Fsa.from_str(rules)
+    ans = k2.Fsa.from_str(rules)
+    return k2.arc_sort(ans)
+
+
+def create_unigram_phone_lm(phones: List[int]) -> k2.Fsa:
+    '''Create a unigram phone LM.
+
+    The resulting FSA (P) has two states: a start state and a
+    final state. For each phone, there is a corresponding self-loop
+    at the start state.
+
+    Caution:
+      blank is not a phone.
+
+    Args:
+      phones:
+        A list of phone IDs.
+
+    Returns:
+      An FSA representing the unigram phone LM.
+    '''
+    assert 0 not in phones
+
+    rules = '0 1 -1 0.0\n'
+    for i in phones:
+        rules += f'0 0 {i} 0.0\n'
+    rules += '1\n'
+
+    ans = k2.Fsa.from_str(rules)
+    return k2.arc_sort(ans)
 
 
 class MmiTrainingGraphCompiler(object):
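
Note (not part of the patch): below is a minimal sketch of the FSA that create_unigram_phone_lm builds, assuming k2 is installed; the phone IDs [1, 2, 3] are made up for illustration. It repeats the string-FSA construction from the hunk above so the resulting acceptor can be inspected directly:

    import k2

    phones = [1, 2, 3]  # hypothetical phone IDs; 0 (blank) must not appear

    # Same construction as in create_unigram_phone_lm: state 0 carries one
    # self-loop per phone, plus an arc labeled -1 into the final state 1.
    rules = '0 1 -1 0.0\n'
    for i in phones:
        rules += f'0 0 {i} 0.0\n'
    rules += '1\n'

    P = k2.arc_sort(k2.Fsa.from_str(rules))
    print(P)  # four arcs: three self-loops at state 0 and one final arc

Because every phone is a self-loop at a single state, this P cannot score a phone in the context of the previous phone; the bigram P above, with one state per phone and an arc for every ordered phone pair, can. Apart from the experiment-directory suffix, the --use-unigram-lm switch only selects between these two structures: all arc scores start at 0.0 either way.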