From be0842f611d56b3caa20ccb3a21e2a60a4e3c748 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 11 Feb 2020 16:16:24 +0800 Subject: [PATCH] [scripts,egs] Add TDNNF to pytorch. (#3892) --- egs/aishell/s10/chain/egs_dataset.py | 10 +- egs/aishell/s10/chain/inference.py | 5 +- egs/aishell/s10/chain/model.py | 240 ++++++++++++-------- egs/aishell/s10/chain/options.py | 33 +-- egs/aishell/s10/chain/tdnnf_layer.py | 319 +++++++++++++++++++++++++++ egs/aishell/s10/chain/train.py | 11 +- egs/aishell/s10/conf/mfcc_hires.conf | 10 + egs/aishell/s10/local/run_chain.sh | 45 ++-- egs/aishell/s10/local/run_tdnn_1b.sh | 2 +- egs/aishell/s10/run.sh | 48 ++-- 10 files changed, 561 insertions(+), 162 deletions(-) create mode 100644 egs/aishell/s10/chain/tdnnf_layer.py create mode 100644 egs/aishell/s10/conf/mfcc_hires.conf diff --git a/egs/aishell/s10/chain/egs_dataset.py b/egs/aishell/s10/chain/egs_dataset.py index f6210f19e6f..6bd36cf7cb2 100755 --- a/egs/aishell/s10/chain/egs_dataset.py +++ b/egs/aishell/s10/chain/egs_dataset.py @@ -170,10 +170,10 @@ def __call__(self, batch): def _test_nnet_chain_example_dataset(): - egs_dir = '/cache/fangjun/chain/aishell_kaldi_pybind/test' + egs_dir = 'exp/chain/merged_egs' dataset = NnetChainExampleDataset(egs_dir=egs_dir) - egs_left_context = 23 - egs_right_context = 23 + egs_left_context = 29 + egs_right_context = 29 frame_subsampling_factor = 3 collate_fn = NnetChainExampleDatasetCollateFunc( @@ -200,7 +200,9 @@ def _test_nnet_chain_example_dataset(): collate_fn=collate_fn) for b in dataloader: key_list, feature_list, supervision_list = b - assert feature_list[0].shape == (128, 192, 120) + assert feature_list[0].shape == (128, 204, 129) \ + or feature_list[0].shape == (128, 144, 129) \ + or feature_list[0].shape == (128, 165, 129) assert supervision_list[0].weight == 1 supervision_list[0].num_sequences == 128 # minibach size is 128 diff --git a/egs/aishell/s10/chain/inference.py b/egs/aishell/s10/chain/inference.py index 7f5e28416d7..c8ef809ae61 100644 --- a/egs/aishell/s10/chain/inference.py +++ b/egs/aishell/s10/chain/inference.py @@ -34,8 +34,9 @@ def main(): output_dim=args.output_dim, lda_mat_filename=args.lda_mat_filename, hidden_dim=args.hidden_dim, - kernel_size_list=args.kernel_size_list, - stride_list=args.stride_list) + bottleneck_dim=args.bottleneck_dim, + time_stride_list=args.time_stride_list, + conv_stride_list=args.conv_stride_list) load_checkpoint(args.checkpoint, model) diff --git a/egs/aishell/s10/chain/model.py b/egs/aishell/s10/chain/model.py index bf65a4fe158..39d7acb765a 100644 --- a/egs/aishell/s10/chain/model.py +++ b/egs/aishell/s10/chain/model.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) # Apache 2.0 import logging @@ -10,109 +10,118 @@ import torch.nn.functional as F from common import load_lda_mat -''' - input dim=$feat_dim name=input - - # please note that it is important to have input layer with the name=input - # as the layer immediately preceding the fixed-affine-layer to enable - # the use of short notation for the descriptor - fixed-affine-layer name=lda input=Append(-1,0,1) affine-transform-file=$dir/configs/lda.mat - - # the first splicing is moved before the lda layer, so no splicing here - relu-batchnorm-layer name=tdnn1 dim=625 - relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=625 - relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=625 - relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=625 - relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=625 - relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=625 - - ## adding the layers for chain branch - relu-batchnorm-layer name=prefinal-chain input=tdnn6 dim=625 target-rms=0.5 - output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 - - # adding the layers for xent branch - # This block prints the configs for a separate output that will be - # trained with a cross-entropy objective in the 'chain' models... this - # has the effect of regularizing the hidden parts of the model. we use - # 0.5 / args.xent_regularize as the learning rate factor- the factor of - # 0.5 / args.xent_regularize is suitable as it means the xent - # final-layer learns at a rate independent of the regularization - # constant; and the 0.5 was tuned so as to make the relative progress - # similar in the xent and regular final layers. - relu-batchnorm-layer name=prefinal-xent input=tdnn6 dim=625 target-rms=0.5 - output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 -''' +from tdnnf_layer import FactorizedTDNN +from tdnnf_layer import OrthonormalLinear +from tdnnf_layer import PrefinalLayer def get_chain_model(feat_dim, output_dim, hidden_dim, - kernel_size_list, - stride_list, + bottleneck_dim, + time_stride_list, + conv_stride_list, lda_mat_filename=None): model = ChainModel(feat_dim=feat_dim, output_dim=output_dim, lda_mat_filename=lda_mat_filename, hidden_dim=hidden_dim, - kernel_size_list=kernel_size_list, - stride_list=stride_list) + time_stride_list=time_stride_list, + conv_stride_list=conv_stride_list) return model +''' +input dim=43 name=input + +# please note that it is important to have input layer with the name=input +# as the layer immediately preceding the fixed-affine-layer to enable +# the use of short notation for the descriptor +fixed-affine-layer name=lda input=Append(-1,0,1) affine-transform-file=exp/chain_cleaned_1c/tdnn1c_sp/configs/lda.mat + +# the first splicing is moved before the lda layer, so no splicing here +relu-batchnorm-dropout-layer name=tdnn1 l2-regularize=0.008 dropout-proportion=0.0 dropout-per-dim-continuous=true dim=1024 +tdnnf-layer name=tdnnf2 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=1 +tdnnf-layer name=tdnnf3 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=1 +tdnnf-layer name=tdnnf4 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=1 +tdnnf-layer name=tdnnf5 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=0 +tdnnf-layer name=tdnnf6 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3 +tdnnf-layer name=tdnnf7 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3 +tdnnf-layer name=tdnnf8 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3 +tdnnf-layer name=tdnnf9 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3 +tdnnf-layer name=tdnnf10 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3 +tdnnf-layer name=tdnnf11 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3 +tdnnf-layer name=tdnnf12 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3 +tdnnf-layer name=tdnnf13 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3 +linear-component name=prefinal-l dim=256 l2-regularize=0.008 orthonormal-constraint=-1.0 + +prefinal-layer name=prefinal-chain input=prefinal-l l2-regularize=0.008 big-dim=1024 small-dim=256 +output-layer name=output include-log-softmax=false dim=3456 l2-regularize=0.002 + +prefinal-layer name=prefinal-xent input=prefinal-l l2-regularize=0.008 big-dim=1024 small-dim=256 +output-layer name=output-xent dim=3456 learning-rate-factor=5.0 l2-regularize=0.002 +''' + + # Create a network like the above one class ChainModel(nn.Module): def __init__(self, feat_dim, output_dim, - lda_mat_filename, - hidden_dim=625, - kernel_size_list=[1, 3, 3, 3, 3, 3], - stride_list=[1, 1, 3, 1, 1, 1], + lda_mat_filename=None, + hidden_dim=1024, + bottleneck_dim=128, + time_stride_list=[1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1], + conv_stride_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1], frame_subsampling_factor=3): super().__init__() # at present, we support only frame_subsampling_factor to be 3 assert frame_subsampling_factor == 3 - assert len(kernel_size_list) == len(stride_list) - num_layers = len(kernel_size_list) + assert len(time_stride_list) == len(conv_stride_list) + num_layers = len(time_stride_list) + + # tdnn1_affine requires [N, T, C] + self.tdnn1_affine = nn.Linear(in_features=feat_dim * 3, + out_features=hidden_dim) - tdnns = [] + # tdnn1_batchnorm requires [N, C, T] + self.tdnn1_batchnorm = nn.BatchNorm1d(num_features=hidden_dim) + + tdnnfs = [] for i in range(num_layers): - in_channels = hidden_dim - if i == 0: - in_channels = feat_dim * 3 - - kernel_size = kernel_size_list[i] - stride = stride_list[i] - - # we do not need to perform padding in Conv1d because it - # has been included in left/right context while generating egs - layer = nn.Conv1d(in_channels=in_channels, - out_channels=hidden_dim, - kernel_size=kernel_size, - stride=stride) - tdnns.append(layer) - - self.tdnns = nn.ModuleList(tdnns) - self.batch_norms = nn.ModuleList([ - nn.BatchNorm1d(num_features=hidden_dim) for i in range(num_layers) - ]) - - self.prefinal_chain_tdnn = nn.Conv1d(in_channels=hidden_dim, - out_channels=hidden_dim, - kernel_size=1) - self.prefinal_chain_batch_norm = nn.BatchNorm1d(num_features=hidden_dim) - self.output_fc = nn.Linear(in_features=hidden_dim, - out_features=output_dim) - - self.prefinal_xent_tdnn = nn.Conv1d(in_channels=hidden_dim, - out_channels=hidden_dim, - kernel_size=1) - self.prefinal_xent_batch_norm = nn.BatchNorm1d(num_features=hidden_dim) - self.output_xent_fc = nn.Linear(in_features=hidden_dim, - out_features=output_dim) + time_stride = time_stride_list[i] + conv_stride = conv_stride_list[i] + layer = FactorizedTDNN(dim=hidden_dim, + bottleneck_dim=bottleneck_dim, + time_stride=time_stride, + conv_stride=conv_stride) + tdnnfs.append(layer) + + # tdnnfs requires [N, C, T] + self.tdnnfs = nn.ModuleList(tdnnfs) + + # prefinal_l affine requires [N, C, T] + self.prefinal_l = OrthonormalLinear(dim=hidden_dim, + bottleneck_dim=bottleneck_dim * 2, + time_stride=0) + + # prefinal_chain requires [N, C, T] + self.prefinal_chain = PrefinalLayer(big_dim=hidden_dim, + small_dim=bottleneck_dim * 2) + + # output_affine requires [N, T, C] + self.output_affine = nn.Linear(in_features=bottleneck_dim * 2, + out_features=output_dim) + + # prefinal_xent requires [N, C, T] + self.prefinal_xent = PrefinalLayer(big_dim=hidden_dim, + small_dim=bottleneck_dim * 2) + + self.output_xent_affine = nn.Linear(in_features=bottleneck_dim * 2, + out_features=output_dim) if lda_mat_filename: logging.info('Use LDA from {}'.format(lda_mat_filename)) @@ -146,32 +155,69 @@ def forward(self, x): # at this point, x is [N, C, T] - # Conv1d requires input of shape [N, C, T] - for i in range(len(self.tdnns)): - x = self.tdnns[i](x) - x = F.relu(x) - x = self.batch_norms[i](x) + x = x.permute(0, 2, 1) + + # at this point, x is [N, T, C] + + x = self.tdnn1_affine(x) + + # at this point, x is [N, T, C] + + x = F.relu(x) + + x = x.permute(0, 2, 1) + + # at this point, x is [N, C, T] + + x = self.tdnn1_batchnorm(x) + + # tdnnf requires input of shape [N, C, T] + for i in range(len(self.tdnnfs)): + x = self.tdnnfs[i](x) # at this point, x is [N, C, T] - # we have two branches from this point on + x = self.prefinal_l(x) + + # at this point, x is [N, C, T] - # first, for the chain branch - x_chain = self.prefinal_chain_tdnn(x) - x_chain = F.relu(x_chain) - x_chain = self.prefinal_chain_batch_norm(x_chain) - x_chain = x_chain.permute(0, 2, 1) - # at this point, x_chain is [N, T, C] - nnet_output = self.output_fc(x_chain) + # for the output node + nnet_output = self.prefinal_chain(x) - # now for the xent branch - x_xent = self.prefinal_xent_tdnn(x) - x_xent = F.relu(x_xent) - x_xent = self.prefinal_xent_batch_norm(x_xent) - x_xent = x_xent.permute(0, 2, 1) + # at this point, nnet_output is [N, C, T] + nnet_output = nnet_output.permute(0, 2, 1) + # at this point, nnet_output is [N, T, C] + nnet_output = self.output_affine(nnet_output) + + # for the xent node + xent_output = self.prefinal_xent(x) + + # at this point, xent_output is [N, C, T] + xent_output = xent_output.permute(0, 2, 1) + # at this point, xent_output is [N, T, C] + xent_output = self.output_xent_affine(xent_output) - # at this point x_xent is [N, T, C] - xent_output = self.output_xent_fc(x_xent) xent_output = F.log_softmax(xent_output, dim=-1) return nnet_output, xent_output + + def constrain_orthonormal(self): + for i in range(len(self.tdnnfs)): + self.tdnnfs[i].constrain_orthonormal() + + self.prefinal_l.constrain_orthonormal() + self.prefinal_chain.constrain_orthonormal() + self.prefinal_xent.constrain_orthonormal() + + +if __name__ == '__main__': + feat_dim = 43 + output_dim = 4344 + model = ChainModel(feat_dim=feat_dim, output_dim=output_dim) + N = 1 + T = 150 + 27 + 27 + C = feat_dim * 3 + x = torch.arange(N * T * C).reshape(N, T, C).float() + nnet_output, xent_output = model(x) + print(x.shape, nnet_output.shape, xent_output.shape) + model.constrain_orthonormal() diff --git a/egs/aishell/s10/chain/options.py b/egs/aishell/s10/chain/options.py index a2f1231b460..5a6e04f9ba7 100644 --- a/egs/aishell/s10/chain/options.py +++ b/egs/aishell/s10/chain/options.py @@ -129,18 +129,19 @@ def _check_args(args): assert args.feat_dim > 0 assert args.output_dim > 0 assert args.hidden_dim > 0 + assert args.bottleneck_dim > 0 - assert args.kernel_size_list is not None - assert len(args.kernel_size_list) > 0 + assert args.time_stride_list is not None + assert len(args.time_stride_list) > 0 - assert args.stride_list is not None - assert len(args.stride_list) > 0 + assert args.conv_stride_list is not None + assert len(args.conv_stride_list) > 0 - args.kernel_size_list = [int(k) for k in args.kernel_size_list.split(', ')] + args.time_stride_list = [int(k) for k in args.time_stride_list.split(', ')] - args.stride_list = [int(k) for k in args.stride_list.split(', ')] + args.conv_stride_list = [int(k) for k in args.conv_stride_list.split(', ')] - assert len(args.kernel_size_list) == len(args.stride_list) + assert len(args.time_stride_list) == len(args.conv_stride_list) assert args.log_level in ['debug', 'info', 'warning'] @@ -195,15 +196,21 @@ def get_args(): required=True, type=int) - parser.add_argument('--kernel-size-list', - dest='kernel_size_list', - help='kernel size list', + parser.add_argument('--bottleneck-dim', + dest='bottleneck_dim', + help='nn bottleneck dimension', + required=True, + type=int) + + parser.add_argument('--time-stride-list', + dest='time_stride_list', + help='time stride list', required=True, type=str) - parser.add_argument('--stride-list', - dest='stride_list', - help='stride list', + parser.add_argument('--conv-stride-list', + dest='conv_stride_list', + help='conv stride list', required=True, type=str) diff --git a/egs/aishell/s10/chain/tdnnf_layer.py b/egs/aishell/s10/chain/tdnnf_layer.py new file mode 100644 index 00000000000..cf3c5a11862 --- /dev/null +++ b/egs/aishell/s10/chain/tdnnf_layer.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 + +# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def _constrain_orthonormal_internal(M): + ''' + Refer to + void ConstrainOrthonormalInternal(BaseFloat scale, CuMatrixBase *M) + from + https://github.com/kaldi-asr/kaldi/blob/master/src/nnet3/nnet-utils.cc#L982 + + Note that we always use the **floating** case. + ''' + assert M.ndim == 2 + + num_rows = M.size(0) + num_cols = M.size(1) + + assert num_rows <= num_cols + + # P = M * M^T + P = torch.mm(M, M.t()) + P_PT = torch.mm(P, P.t()) + + trace_P = torch.trace(P) + trace_P_P = torch.trace(P_PT) + + scale = torch.sqrt(trace_P_P / trace_P) + + ratio = trace_P_P * num_rows / (trace_P * trace_P) + assert ratio > 0.99 + + update_speed = 0.125 + + if ratio > 1.02: + update_speed *= 0.5 + if ratio > 1.1: + update_speed *= 0.5 + + identity = torch.eye(num_rows, dtype=P.dtype, device=P.device) + P = P - scale * scale * identity + + alpha = update_speed / (scale * scale) + M = M - 4 * alpha * torch.mm(P, M) + return M + + +class OrthonormalLinear(nn.Module): + + def __init__(self, dim, bottleneck_dim, time_stride): + super().__init__() + assert time_stride in [0, 1] + # WARNING(fangjun): kaldi uses [-1, 0] for the first linear layer + # and [0, 1] for the second affine layer; + # we use [-1, 0, 1] for the first linear layer if time_stride == 1 + + if time_stride == 0: + kernel_size = 1 + else: + kernel_size = 3 + + self.kernel_size = kernel_size + + # conv requires [N, C, T] + self.conv = nn.Conv1d(in_channels=dim, + out_channels=bottleneck_dim, + kernel_size=kernel_size, + bias=False) + + def forward(self, x): + # input x is of shape: [batch_size, feat_dim, seq_len] = [N, C, T] + assert x.ndim == 3 + x = self.conv(x) + return x + + def constrain_orthonormal(self): + state_dict = self.conv.state_dict() + w = state_dict['weight'] + # w is of shape [out_channels, in_channels, kernel_size] + out_channels = w.size(0) + in_channels = w.size(1) + kernel_size = w.size(2) + + w = w.reshape(out_channels, -1) + + num_rows = w.size(0) + num_cols = w.size(1) + + need_transpose = False + if num_rows > num_cols: + w = w.t() + need_transpose = True + + w = _constrain_orthonormal_internal(w) + + if need_transpose: + w = w.t() + + w = w.reshape(out_channels, in_channels, kernel_size) + + state_dict['weight'] = w + self.conv.load_state_dict(state_dict) + + +class PrefinalLayer(nn.Module): + + def __init__(self, big_dim, small_dim): + super().__init__() + self.affine = nn.Linear(in_features=small_dim, out_features=big_dim) + self.batchnorm1 = nn.BatchNorm1d(num_features=big_dim) + self.linear = OrthonormalLinear(dim=big_dim, + bottleneck_dim=small_dim, + time_stride=0) + self.batchnorm2 = nn.BatchNorm1d(num_features=small_dim) + + def forward(self, x): + # x is [N, C, T] + x = x.permute(0, 2, 1) + + # at this point, x is [N, T, C] + + x = self.affine(x) + x = F.relu(x) + + # at this point, x is [N, T, C] + + x = x.permute(0, 2, 1) + + # at this point, x is [N, C, T] + + x = self.batchnorm1(x) + + x = self.linear(x) + + x = self.batchnorm2(x) + + return x + + def constrain_orthonormal(self): + self.linear.constrain_orthonormal() + + +class FactorizedTDNN(nn.Module): + ''' + This class implements the following topology in kaldi: + tdnnf-layer name=tdnnf2 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=1 + + References: + - http://danielpovey.com/files/2018_interspeech_tdnnf.pdf + - ConstrainOrthonormalInternal() from + https://github.com/kaldi-asr/kaldi/blob/master/src/nnet3/nnet-utils.cc#L982 + ''' + + def __init__(self, + dim, + bottleneck_dim, + time_stride, + conv_stride, + bypass_scale=0.66): + super().__init__() + + assert conv_stride in [1, 3] + assert abs(bypass_scale) <= 1 + + self.bypass_scale = bypass_scale + + self.conv_stride = conv_stride + + # linear requires [N, C, T] + self.linear = OrthonormalLinear(dim=dim, + bottleneck_dim=bottleneck_dim, + time_stride=time_stride) + + # affine requires [N, C, T] + # WARNING(fangjun): we do not use nn.Linear here + # since we want to use `stride` + self.affine = nn.Conv1d(in_channels=bottleneck_dim, + out_channels=dim, + kernel_size=1, + stride=conv_stride) + + # batchnorm requires [N, C, T] + self.batchnorm = nn.BatchNorm1d(num_features=dim) + + def forward(self, x): + # input x is of shape: [batch_size, feat_dim, seq_len] = [N, C, T] + assert x.ndim == 3 + + # save it for skip connection + input_x = x + + x = self.linear(x) + + # at this point, x is [N, C, T] + + x = self.affine(x) + + # at this point, x is [N, C, T] + + x = F.relu(x) + + # at this point, x is [N, C, T] + + x = self.batchnorm(x) + + # at this point, x is [N, C, T] + + # TODO(fangjun): implement GeneralDropoutComponent in PyTorch + + if self.linear.kernel_size == 3: + x = self.bypass_scale * input_x[:, :, 1:-1:self.conv_stride] + x + else: + x = self.bypass_scale * input_x[:, :, ::self.conv_stride] + x + return x + + def constrain_orthonormal(self): + self.linear.constrain_orthonormal() + + +def _test_constrain_orthonormal(): + + def compute_loss(M): + P = torch.mm(M, M.t()) + P_PT = torch.mm(P, P.t()) + + trace_P = torch.trace(P) + trace_P_P = torch.trace(P_PT) + + scale = torch.sqrt(trace_P_P / trace_P) + + identity = torch.eye(P.size(0), dtype=P.dtype, device=P.device) + Q = P / (scale * scale) - identity + loss = torch.norm(Q, p='fro') # Frobenius norm + + return loss + + w = torch.randn(6, 8) * 10 + + loss = [] + loss.append(compute_loss(w)) + + for i in range(15): + w = _constrain_orthonormal_internal(w) + loss.append(compute_loss(w)) + + for i in range(1, len(loss)): + assert loss[i - 1] > loss[i] + + # TODO(fangjun): draw the loss using matplotlib + # print(loss) + + model = FactorizedTDNN(dim=1024, + bottleneck_dim=128, + time_stride=1, + conv_stride=3) + loss = [] + model.constrain_orthonormal() + loss.append( + compute_loss(model.linear.conv.state_dict()['weight'].reshape(128, -1))) + for i in range(5): + model.constrain_orthonormal() + loss.append( + compute_loss(model.linear.conv.state_dict()['weight'].reshape( + 128, -1))) + + for i in range(1, len(loss)): + assert loss[i - 1] > loss[i] + + +def _test_factorized_tdnn(): + import math + N = 1 + T = 10 + C = 4 + + # case 0: time_stride == 1, conv_stride == 1 + model = FactorizedTDNN(dim=C, + bottleneck_dim=2, + time_stride=1, + conv_stride=1) + x = torch.arange(N * T * C).reshape(N, C, T).float() + y = model(x) + assert y.size(2) == T - 2 + + # case 1: time_stride == 0, conv_stride == 1 + model = FactorizedTDNN(dim=C, + bottleneck_dim=2, + time_stride=0, + conv_stride=1) + y = model(x) + assert y.size(2) == T + + # case 2: time_stride == 1, conv_stride == 3 + model = FactorizedTDNN(dim=C, + bottleneck_dim=2, + time_stride=1, + conv_stride=3) + y = model(x) + assert y.size(2) == math.ceil((T - 2) / 3) + + # case 3: time_stride == 0, conv_stride == 3 + model = FactorizedTDNN(dim=C, + bottleneck_dim=2, + time_stride=0, + conv_stride=3) + y = model(x) + assert y.size(2) == math.ceil(T / 3) + + +if __name__ == '__main__': + torch.manual_seed(20200130) + _test_factorized_tdnn() + _test_constrain_orthonormal() diff --git a/egs/aishell/s10/chain/train.py b/egs/aishell/s10/chain/train.py index 829202fbe94..1f5c6824c97 100644 --- a/egs/aishell/s10/chain/train.py +++ b/egs/aishell/s10/chain/train.py @@ -11,6 +11,7 @@ # disable warnings when loading tensorboard warnings.simplefilter(action='ignore', category=FutureWarning) +import numpy as np import torch import torch.optim as optim from torch.nn.utils import clip_grad_value_ @@ -84,6 +85,11 @@ def train_one_epoch(dataloader, model, device, optimizer, criterion, total_weight += objf_l2_term_weight[2].item() num_frames = nnet_output.shape[0] total_frames += num_frames + + if np.random.choice(4) == 0: + with torch.no_grad(): + model.constrain_orthonormal() + if batch_idx % 100 == 0: logging.info( 'Process {}/{}({:.6f}%) global average objf: {:.6f} over {} ' @@ -135,8 +141,9 @@ def main(): output_dim=args.output_dim, lda_mat_filename=args.lda_mat_filename, hidden_dim=args.hidden_dim, - kernel_size_list=args.kernel_size_list, - stride_list=args.stride_list) + bottleneck_dim=args.bottleneck_dim, + time_stride_list=args.time_stride_list, + conv_stride_list=args.conv_stride_list) start_epoch = 0 num_epochs = args.num_epochs diff --git a/egs/aishell/s10/conf/mfcc_hires.conf b/egs/aishell/s10/conf/mfcc_hires.conf new file mode 100644 index 00000000000..137d00add94 --- /dev/null +++ b/egs/aishell/s10/conf/mfcc_hires.conf @@ -0,0 +1,10 @@ +# config for high-resolution MFCC features, intended for neural network training. +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. +--sample-frequency=16000 # AISHELL-2 is sampled at 16kHz +--num-mel-bins=40 # similar to Google's setup. +--num-ceps=40 # there is no dimensionality reduction. +--low-freq=20 # low cutoff frequency for mel bins +--high-freq=-400 # high cutoff frequency, relative to Nyquist of 8000 (=7600) diff --git a/egs/aishell/s10/local/run_chain.sh b/egs/aishell/s10/local/run_chain.sh index 8ce22d3364b..06b5d47e89f 100755 --- a/egs/aishell/s10/local/run_chain.sh +++ b/egs/aishell/s10/local/run_chain.sh @@ -9,7 +9,7 @@ stage=0 # GPU device id to use (count from 0). # you can also set `CUDA_VISIBLE_DEVICES` and set `device_id=0` -device_id=0 +device_id=6 nj=10 @@ -19,8 +19,8 @@ lat_dir=exp/tri5a_lats # input lat dir treedir=exp/chain/tri5_tree # output tree dir # You should know how to calculate your model's left/right context **manually** -model_left_context=12 -model_right_context=12 +model_left_context=28 +model_right_context=28 egs_left_context=$[$model_left_context + 1] egs_right_context=$[$model_right_context + 1] frames_per_eg=150,110,90 @@ -30,9 +30,10 @@ minibatch_size=128 num_epochs=6 lr=1e-3 -hidden_dim=625 -kernel_size_list="1, 3, 3, 3, 3, 3" # comma separated list -stride_list="1, 1, 3, 1, 1, 1" # comma separated list +hidden_dim=1024 +bottleneck_dim=128 +time_stride_list="1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1" # comma separated list +conv_stride_list="1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1" # comma separated list log_level=info # valid values: debug, info, warning @@ -47,11 +48,15 @@ save_nn_output_as_compressed=false if [[ $stage -le 0 ]]; then for datadir in train dev test; do - dst_dir=data/fbank_pitch/$datadir + dst_dir=data/mfcc_hires/$datadir if [[ ! -f $dst_dir/feats.scp ]]; then + echo "making mfcc features for LF-MMI training" utils/copy_data_dir.sh data/$datadir $dst_dir - echo "making fbank-pitch features for LF-MMI training" - steps/make_fbank_pitch.sh --cmd $train_cmd --nj $nj $dst_dir || exit 1 + steps/make_mfcc.sh \ + --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" \ + --nj $nj \ + $dst_dir || exit 1 steps/compute_cmvn_stats.sh $dst_dir || exit 1 utils/fix_data_dir.sh $dst_dir else @@ -80,12 +85,12 @@ if [[ $stage -le 2 ]]; then # step compared with other recipes. steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ --context-opts "--context-width=2 --central-position=1" \ - --cmd "$train_cmd" 5000 data/train $lang $ali_dir $treedir + --cmd "$train_cmd" 5000 data/mfcc/train $lang $ali_dir $treedir fi if [[ $stage -le 3 ]]; then echo "creating phone language-model" - $train_cmd exp/chain/log/make_phone_lm.log \ + "$train_cmd" exp/chain/log/make_phone_lm.log \ chain-est-phone-lm \ "ark:gunzip -c $treedir/ali.*.gz | ali-to-phones $treedir/final.mdl ark:- ark:- |" \ exp/chain/phone_lm.fst || exit 1 @@ -95,7 +100,7 @@ if [[ $stage -le 4 ]]; then echo "creating denominator FST" copy-transition-model $treedir/final.mdl exp/chain/0.trans_mdl cp $treedir/tree exp/chain - $train_cmd exp/chain/log/make_den_fst.log \ + "$train_cmd" exp/chain/log/make_den_fst.log \ chain-make-den-fst exp/chain/tree exp/chain/0.trans_mdl exp/chain/phone_lm.fst \ exp/chain/den.fst exp/chain/normalization.fst || exit 1 fi @@ -119,7 +124,7 @@ if [[ $stage -le 5 ]]; then --right-tolerance 5 \ --srand 0 \ --stage -10 \ - data/fbank_pitch/train \ + data/mfcc_hires/train \ exp/chain $lat_dir exp/chain/egs fi @@ -157,16 +162,17 @@ if [[ $stage -le 8 ]]; then # sort the options alphabetically python3 ./chain/train.py \ + --bottleneck-dim $bottleneck_dim \ --checkpoint=${train_checkpoint:-} \ + --conv-stride-list "$conv_stride_list" \ --device-id $device_id \ --dir exp/chain/train \ --feat-dim $feat_dim \ --hidden-dim $hidden_dim \ --is-training true \ - --kernel-size-list "$kernel_size_list" \ --log-level $log_level \ --output-dim $output_dim \ - --stride-list "$stride_list" \ + --time-stride-list "$time_stride_list" \ --train.cegs-dir exp/chain/merged_egs \ --train.den-fst exp/chain/den.fst \ --train.egs-left-context $egs_left_context \ @@ -186,20 +192,21 @@ if [[ $stage -le 9 ]]; then best_epoch=$(cat exp/chain/train/best-epoch-info | grep 'best epoch' | awk '{print $NF}') inference_checkpoint=exp/chain/train/epoch-${best_epoch}.pt python3 ./chain/inference.py \ + --bottleneck-dim $bottleneck_dim \ --checkpoint $inference_checkpoint \ + --conv-stride-list "$conv_stride_list" \ --device-id $device_id \ --dir exp/chain/inference/$x \ --feat-dim $feat_dim \ - --feats-scp data/fbank_pitch/$x/feats.scp \ + --feats-scp data/mfcc_hires/$x/feats.scp \ --hidden-dim $hidden_dim \ --is-training false \ - --kernel-size-list "$kernel_size_list" \ --log-level $log_level \ --model-left-context $model_left_context \ --model-right-context $model_right_context \ --output-dim $output_dim \ --save-as-compressed $save_nn_output_as_compressed \ - --stride-list "$stride_list" || exit 1 + --time-stride-list "$time_stride_list" || exit 1 fi done fi @@ -228,7 +235,7 @@ if [[ $stage -le 11 ]]; then for x in test dev; do ./local/score.sh --cmd "$decode_cmd" \ - data/fbank_pitch/$x \ + data/mfcc_hires/$x \ exp/chain/graph \ exp/chain/decode_res/$x || exit 1 done diff --git a/egs/aishell/s10/local/run_tdnn_1b.sh b/egs/aishell/s10/local/run_tdnn_1b.sh index 34aa7fc3fee..6d2b04359ff 100755 --- a/egs/aishell/s10/local/run_tdnn_1b.sh +++ b/egs/aishell/s10/local/run_tdnn_1b.sh @@ -85,7 +85,7 @@ if [[ $stage -le 2 ]]; then # step compared with other recipes. steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ --context-opts "--context-width=2 --central-position=1" \ - --cmd $train_cmd 5000 data/train $lang $ali_dir $treedir + --cmd $train_cmd 5000 data/mfcc/train $lang $ali_dir $treedir fi if [[ $stage -le 3 ]]; then diff --git a/egs/aishell/s10/run.sh b/egs/aishell/s10/run.sh index 50c87d7e94a..5e42fc954cc 100755 --- a/egs/aishell/s10/run.sh +++ b/egs/aishell/s10/run.sh @@ -11,7 +11,7 @@ # You also need a GPU to run this example. # # PyTorch with version `1.3.0dev20191006` has been tested and is -# guaranteed to work. +# known to work. # # Note that we have used Tensorboard to visualize the training loss. # You do **NOT** need to install TensorFlow to use Tensorboard. @@ -22,7 +22,7 @@ data=/data/fangjunkuang/data/aishell data_url=www.openslr.org/resources/33 -nj=10 +nj=30 stage=0 @@ -54,78 +54,78 @@ if [[ $stage -le 4 ]]; then cp data/lang/phones/* data/lang_test/phones/ fi -mfccdir=mfcc if [[ $stage -le 5 ]]; then for x in train dev test; do - steps/make_mfcc_pitch.sh --cmd "$train_cmd" --nj $nj \ - data/$x exp/make_mfcc/$x $mfccdir || exit 1 - steps/compute_cmvn_stats.sh data/$x exp/make_mfcc/$x $mfccdir || exit 1 - utils/fix_data_dir.sh data/$x || exit 1 + dst_dir=data/mfcc/$x + utils/copy_data_dir.sh data/$x $dst_dir + steps/make_mfcc_pitch.sh --cmd "$train_cmd" --nj $nj $dst_dir || exit 1 + steps/compute_cmvn_stats.sh $dst_dir || exit 1 + utils/fix_data_dir.sh $dst_dir || exit 1 done fi if [[ $stage -le 6 ]]; then steps/train_mono.sh --cmd "$train_cmd" --nj $nj \ - data/train data/lang exp/mono || exit 1 + data/mfcc/train data/lang exp/mono || exit 1 fi if [[ $stage -le 7 ]]; then steps/align_si.sh --cmd "$train_cmd" --nj $nj \ - data/train data/lang exp/mono exp/mono_ali || exit 1 + data/mfcc/train data/lang exp/mono exp/mono_ali || exit 1 fi if [[ $stage -le 8 ]]; then steps/train_deltas.sh --cmd "$train_cmd" \ - 2500 20000 data/train data/lang exp/mono_ali exp/tri1 || exit 1 + 2500 20000 data/mfcc/train data/lang exp/mono_ali exp/tri1 || exit 1 fi if [[ $stage -le 9 ]]; then steps/align_si.sh --cmd "$train_cmd" --nj $nj \ - data/train data/lang exp/tri1 exp/tri1_ali || exit 1 + data/mfcc/train data/lang exp/tri1 exp/tri1_ali || exit 1 fi if [[ $stage -le 10 ]]; then steps/train_deltas.sh --cmd "$train_cmd" \ - 2500 20000 data/train data/lang exp/tri1_ali exp/tri2 || exit 1 + 2500 20000 data/mfcc/train data/lang exp/tri1_ali exp/tri2 || exit 1 fi if [[ $stage -le 11 ]]; then steps/align_si.sh --cmd "$train_cmd" --nj $nj \ - data/train data/lang exp/tri2 exp/tri2_ali || exit 1 + data/mfcc/train data/lang exp/tri2 exp/tri2_ali || exit 1 fi if [[ $stage -le 12 ]]; then steps/train_lda_mllt.sh --cmd "$train_cmd" \ - 2500 20000 data/train data/lang exp/tri2_ali exp/tri3a || exit 1 + 2500 20000 data/mfcc/train data/lang exp/tri2_ali exp/tri3a || exit 1 fi if [[ $stage -le 13 ]]; then steps/align_fmllr.sh --cmd "$train_cmd" --nj $nj \ - data/train data/lang exp/tri3a exp/tri3a_ali || exit 1 + data/mfcc/train data/lang exp/tri3a exp/tri3a_ali || exit 1 fi if [[ $stage -le 14 ]]; then - steps/train_sat.sh --cmd $train_cmd \ - 2500 20000 data/train data/lang exp/tri3a_ali exp/tri4a || exit 1 + steps/train_sat.sh --cmd "$train_cmd" \ + 2500 20000 data/mfcc/train data/lang exp/tri3a_ali exp/tri4a || exit 1 fi if [[ $stage -le 15 ]]; then - steps/align_fmllr.sh --cmd $train_cmd --nj $nj \ - data/train data/lang exp/tri4a exp/tri4a_ali + steps/align_fmllr.sh --cmd "$train_cmd" --nj $nj \ + data/mfcc/train data/lang exp/tri4a exp/tri4a_ali fi if [[ $stage -le 16 ]]; then - steps/train_sat.sh --cmd $train_cmd \ - 3500 100000 data/train data/lang exp/tri4a_ali exp/tri5a || exit 1 + steps/train_sat.sh --cmd "$train_cmd" \ + 3500 100000 data/mfcc/train data/lang exp/tri4a_ali exp/tri5a || exit 1 fi if [[ $stage -le 17 ]]; then - steps/align_fmllr.sh --cmd $train_cmd --nj $nj \ - data/train data/lang exp/tri5a exp/tri5a_ali || exit 1 + steps/align_fmllr.sh --cmd "$train_cmd" --nj $nj \ + data/mfcc/train data/lang exp/tri5a exp/tri5a_ali || exit 1 fi if [[ $stage -le 18 ]]; then - steps/align_fmllr_lats.sh --nj $nj --cmd $train_cmd data/train \ + steps/align_fmllr_lats.sh --nj $nj --cmd "$train_cmd" data/mfcc/train \ data/lang exp/tri5a exp/tri5a_lats rm exp/tri5a_lats/fsts.*.gz # save space fi