Use MMI not CTC model for alignment #203

Status: Open. Wants to merge 2 commits into master. (The repository was archived by the owner on Oct 13, 2022 and is now read-only.)

Changes from all commits:
@@ -242,7 +242,7 @@ def main():

     output_beam_size = args.output_beam_size

-    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-vgg')
+    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-vgg-mmiali')
     setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

     logging.info(f'output_beam_size: {output_beam_size}')

egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py (15 additions & 8 deletions)

@@ -96,9 +96,16 @@ def get_objf(batch: Dict,
             # scale less than one so it will be encouraged
             # to mimic ali_model's output
             ali_model_scale = 500.0 / (global_batch_idx_train // accum_grad + 500)
+            nnet_output_orig = nnet_output
             nnet_output = nnet_output.clone() # or log-softmax backprop will fail.
             nnet_output[:, :,:min_len] += ali_model_scale * ali_model_output[:, :,:min_len]
+
+            x = nnet_output.abs().sum().item()
+            if x - x != 0:
+                print("Warning: reverting nnet output since it seems to be nan.")
+                nnet_output = nnet_output_orig

         # nnet_output is [N, C, T]
         nnet_output = nnet_output.permute(0, 2, 1) # now nnet_output is [N, T, C]
Review comment (Contributor Author), on the NaN check added above:
@GNroy perhaps this is related to the error you had? I found that sometimes I'd get NaNs in the forward pass of the alignment model. I commented out ali_model.eval() as well as making this change, because I suspected it had to do with test-mode batchnorm, but I might have been wrong; I need to test this. It might relate to float16 usage (or a combination of the two).

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!
Actually, I resolved my issue. The NaNs were produced by the encoder part (not a loss or softmax problem, as I thought before). It was fixed with some hyperparameter re-tuning; in particular, setting eps=1e-3 for the optimizer helped.
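The `x - x != 0` test added in this hunk only fires for non-finite values: subtracting a finite number from itself gives exactly zero, while NaN - NaN (and inf - inf) is NaN. The `ali_model_scale` it guards decays from 1.0 towards 0 as training proceeds (500 / (step + 500)), so the alignment model's influence fades over time. Below is a minimal, self-contained sketch of the same guard written with `torch.isfinite`; the function and variable names are illustrative, not part of the PR.

```python
import torch

def guard_non_finite(nnet_output: torch.Tensor,
                     nnet_output_orig: torch.Tensor) -> torch.Tensor:
    """Return nnet_output unless it contains NaN/inf; otherwise fall back to
    the original (pre-alignment-model) output, mirroring the diff's
    `x - x != 0` check."""
    x = nnet_output.abs().sum()
    if not torch.isfinite(x):
        print("Warning: reverting nnet output since it seems to be nan.")
        return nnet_output_orig
    return nnet_output

# Tiny check with a deliberately corrupted tensor.
orig = torch.randn(2, 5, 10)
bad = orig.clone()
bad[0, 0, 0] = float('nan')
assert guard_non_finite(bad, orig) is orig               # NaN detected -> revert
assert guard_non_finite(orig.clone(), orig) is not orig  # finite -> keep
```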



@@ -417,11 +424,11 @@ def get_parser():
         '--use-ali-model',
         type=str2bool,
         default=True,
-        help='If true, we assume that you have run ./ctc_train.py '
-             'and you have some checkpoints inside the directory '
-             'exp-lstm-adam-ctc-musan/ .'
-             'It will use exp-lstm-adam-ctc-musan/epoch-{ali-model-epoch}.pt '
-             'as the pre-trained alignment model'
+        help='If true, we assume that you have run ./mmi_bigram_train.py '
+             'and you have some checkpoints inside the directory '
+             'exp-lstm-adam-mmi-bigram-musan-dist-s4/. It will use '
+             'exp-lstm-adam-mmi-bigram-musan-dist-s4/epoch-{ali-model-epoch}.pt '  # noqa
+             'as the pre-trained alignment model'
     )
     parser.add_argument(
         '--ali-model-epoch',
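The updated help text describes the intended workflow: train the bigram MMI model first, then point the transformer training at one of its checkpoints. Here is a small sketch of how such a checkpoint would be located and loaded, following the directory name and the 'state_dict' key used later in this diff; the epoch number is a made-up example.

```python
from pathlib import Path
import torch

ali_model_epoch = 9  # hypothetical value for --ali-model-epoch
ali_model_fname = Path(
    f'exp-lstm-adam-mmi-bigram-musan-dist-s4/epoch-{ali_model_epoch}.pt')

if ali_model_fname.is_file():
    # Snowfall checkpoints store the weights under the 'state_dict' key,
    # as the loading code further down in this diff also assumes.
    state_dict = torch.load(ali_model_fname, map_location='cpu')['state_dict']
    print(f'Loaded alignment model weights from {ali_model_fname}')
else:
    print(f'{ali_model_fname} not found; run ./mmi_bigram_train.py first.')
```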
@@ -465,7 +472,7 @@ def run(rank, world_size, args):
     fix_random_seed(42)
     setup_dist(rank, world_size, args.master_port)

-    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-vgg')
+    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-vgg-mmiali')
     setup_logger(f'{exp_dir}/log/log-train-{rank}')
     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
@@ -548,13 +555,13 @@ def run(rank, world_size, args):
                               num_classes=len(phone_ids) + 1, # +1 for the blank symbol
                               subsampling_factor=4)

-        ali_model_fname = Path(f'exp-lstm-adam-ctc-musan/epoch-{args.ali_model_epoch}.pt')
+        ali_model_fname = Path(f'exp-lstm-adam-mmi-bigram-musan-dist-s4/epoch-{args.ali_model_epoch}.pt')
         assert ali_model_fname.is_file(), \
             f'ali model filename {ali_model_fname} does not exist!'
         ali_model.load_state_dict(torch.load(ali_model_fname, map_location='cpu')['state_dict'])
         ali_model.to(device)

-        ali_model.eval()
+        # ali_model.eval()
         ali_model.requires_grad_(False)
         logging.info(f'Use ali_model: {ali_model_fname}')
     else:
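Note the distinction the change above relies on: `requires_grad_(False)` only stops gradients from flowing into the alignment model's parameters, while the now-commented-out `eval()` would also switch BatchNorm to its running statistics (and disable Dropout), which is what the review comment above suspects of contributing to the NaNs. A self-contained sketch of the difference, using a toy model rather than the recipe's TdnnLstm1b:

```python
import torch
import torch.nn as nn

# Toy stand-in for the frozen alignment model.
ali_model = nn.Sequential(nn.Conv1d(8, 8, kernel_size=3, padding=1),
                          nn.BatchNorm1d(8))
ali_model.requires_grad_(False)  # no gradients for its parameters

x = torch.randn(4, 8, 50)

# Training mode (eval() commented out, as in this PR): BatchNorm normalises
# with the current batch's statistics and keeps updating its running stats.
y_train_mode = ali_model(x)

# Test mode (what the code did before): BatchNorm uses the stored running
# statistics instead, so the output generally differs.
ali_model.eval()
y_eval_mode = ali_model(x)

print(torch.allclose(y_train_mode, y_eval_mode))  # typically False
assert not y_eval_mode.requires_grad              # frozen either way
```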

egs/librispeech/asr/simple_v1/mmi_bigram_decode.py (2 additions & 2 deletions)

@@ -164,7 +164,7 @@ def get_parser():


 def main():
     args = get_parser().parse_args()
-    exp_dir = Path('exp-lstm-adam-mmi-bigram-musan-dist')
+    exp_dir = Path('exp-lstm-adam-mmi-bigram-musan-dist-s4')
     setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

     # load L, G, symbol_table
@@ -183,7 +183,7 @@ def main():
     device = torch.device('cuda')
     model = TdnnLstm1b(num_features=80,
                        num_classes=len(phone_ids) + 1, # +1 for the blank symbol
-                       subsampling_factor=3)
+                       subsampling_factor=4)
     model.P_scores = torch.nn.Parameter(P.scores.clone(), requires_grad=False)

     checkpoint = os.path.join(exp_dir, f'epoch-{args.epoch}.pt')

egs/librispeech/asr/simple_v1/mmi_bigram_train.py (2 additions & 2 deletions)

@@ -271,7 +271,7 @@ def main():
     num_epochs = 10
     use_adam = True

-    exp_dir = f'exp-lstm-adam-mmi-bigram-musan-dist'
+    exp_dir = f'exp-lstm-adam-mmi-bigram-musan-dist-s4'
     setup_logger('{}/log/log-train'.format(exp_dir), use_console=args.local_rank == 0)
     tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard') if args.local_rank == 0 else None
@@ -356,7 +356,7 @@ def main():
     logging.info("About to create model")
     model = TdnnLstm1b(num_features=80,
                        num_classes=len(phone_ids) + 1, # +1 for the blank symbol
-                       subsampling_factor=3)
+                       subsampling_factor=4)
     model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

     model.to(device)

snowfall/common.py (20 additions & 9 deletions)

@@ -12,6 +12,7 @@
 from typing import Any, Dict, Iterable, List, Optional, TextIO, Tuple, Union

 import k2
+import k2.ragged as k2r
 import kaldialign
 import torch
 import torch.distributed as dist
@@ -314,18 +315,28 @@ def get_texts(best_paths: k2.Fsa, indices: Optional[torch.Tensor] = None) -> Lis
     decoded.
     '''
     # remove any 0's or -1's (there should be no 0's left but may be -1's.)
-    aux_labels = k2.ragged.remove_values_leq(best_paths.aux_labels, 0)
-    aux_shape = k2.ragged.compose_ragged_shapes(best_paths.arcs.shape(),
-                                                aux_labels.shape())
-    # remove the states and arcs axes.
-    aux_shape = k2.ragged.remove_axis(aux_shape, 1)
-    aux_shape = k2.ragged.remove_axis(aux_shape, 1)
-    aux_labels = k2.RaggedInt(aux_shape, aux_labels.values())

+    if isinstance(best_paths.aux_labels, k2.RaggedInt):
+        aux_labels = k2r.remove_values_leq(best_paths.aux_labels, 0)
+        aux_shape = k2r.compose_ragged_shapes(best_paths.arcs.shape(),
+                                              aux_labels.shape())
+
+        # remove the states and arcs axes.
+        aux_shape = k2r.remove_axis(aux_shape, 1)
+        aux_shape = k2r.remove_axis(aux_shape, 1)
+        aux_labels = k2.RaggedInt(aux_shape, aux_labels.values())
+    else:
+        # remove axis corresponding to states.
+        aux_shape = k2r.remove_axis(best_paths.arcs.shape(), 1)
+        aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels)
+        # remove 0's and -1's.
+        aux_labels = k2r.remove_values_leq(aux_labels, 0)

     assert (aux_labels.num_axes() == 2)
-    aux_labels, _ = k2.ragged.index(aux_labels,
+    aux_labels, _ = k2r.index(aux_labels,
                               invert_permutation(indices).to(dtype=torch.int32,
                                                              device=best_paths.device))
-    return k2.ragged.to_list(aux_labels)
+    return k2r.to_list(aux_labels)


 def invert_permutation(indices: torch.Tensor) -> torch.Tensor: