[scripts,egs] Add TDNNF to pytorch. (#3892)

csukuangfj authored Feb 11, 2020
1 parent ee517cd commit be0842f
Showing 10 changed files with 561 additions and 162 deletions.
10 changes: 6 additions & 4 deletions egs/aishell/s10/chain/egs_dataset.py
@@ -170,10 +170,10 @@ def __call__(self, batch):


def _test_nnet_chain_example_dataset():
egs_dir = '/cache/fangjun/chain/aishell_kaldi_pybind/test'
egs_dir = 'exp/chain/merged_egs'
dataset = NnetChainExampleDataset(egs_dir=egs_dir)
egs_left_context = 23
egs_right_context = 23
egs_left_context = 29
egs_right_context = 29
frame_subsampling_factor = 3

collate_fn = NnetChainExampleDatasetCollateFunc(
@@ -200,7 +200,9 @@ def _test_nnet_chain_example_dataset():
collate_fn=collate_fn)
for b in dataloader:
key_list, feature_list, supervision_list = b
assert feature_list[0].shape == (128, 192, 120)
assert feature_list[0].shape == (128, 204, 129) \
or feature_list[0].shape == (128, 144, 129) \
or feature_list[0].shape == (128, 165, 129)
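# the last dim, 129, is feat_dim * 3: the 43-dim features spliced with
# +/-1 frame of context (cf. Append(-1,0,1) in the network config below);
# T varies with the chunk width of the merged egs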
assert supervision_list[0].weight == 1
assert supervision_list[0].num_sequences == 128  # minibatch size is 128

5 changes: 3 additions & 2 deletions egs/aishell/s10/chain/inference.py
@@ -34,8 +34,9 @@ def main():
output_dim=args.output_dim,
lda_mat_filename=args.lda_mat_filename,
hidden_dim=args.hidden_dim,
kernel_size_list=args.kernel_size_list,
stride_list=args.stride_list)
bottleneck_dim=args.bottleneck_dim,
time_stride_list=args.time_stride_list,
conv_stride_list=args.conv_stride_list)

load_checkpoint(args.checkpoint, model)

240 changes: 143 additions & 97 deletions egs/aishell/s10/chain/model.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

import logging
@@ -10,109 +10,118 @@
import torch.nn.functional as F

from common import load_lda_mat
'''
input dim=$feat_dim name=input
# please note that it is important to have input layer with the name=input
# as the layer immediately preceding the fixed-affine-layer to enable
# the use of short notation for the descriptor
fixed-affine-layer name=lda input=Append(-1,0,1) affine-transform-file=$dir/configs/lda.mat
# the first splicing is moved before the lda layer, so no splicing here
relu-batchnorm-layer name=tdnn1 dim=625
relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=625
relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=625
relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=625
relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=625
relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=625
## adding the layers for chain branch
relu-batchnorm-layer name=prefinal-chain input=tdnn6 dim=625 target-rms=0.5
output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5
# adding the layers for xent branch
# This block prints the configs for a separate output that will be
# trained with a cross-entropy objective in the 'chain' models... this
# has the effect of regularizing the hidden parts of the model. we use
# 0.5 / args.xent_regularize as the learning rate factor- the factor of
# 0.5 / args.xent_regularize is suitable as it means the xent
# final-layer learns at a rate independent of the regularization
# constant; and the 0.5 was tuned so as to make the relative progress
# similar in the xent and regular final layers.
relu-batchnorm-layer name=prefinal-xent input=tdnn6 dim=625 target-rms=0.5
output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5
'''
from tdnnf_layer import FactorizedTDNN
from tdnnf_layer import OrthonormalLinear
from tdnnf_layer import PrefinalLayer
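# Assumed mapping to the network config quoted below: FactorizedTDNN
# implements the tdnnf-layer lines, OrthonormalLinear the prefinal-l
# linear-component with its orthonormal constraint, and PrefinalLayer the
# prefinal-chain/prefinal-xent layers.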


def get_chain_model(feat_dim,
output_dim,
hidden_dim,
kernel_size_list,
stride_list,
bottleneck_dim,
time_stride_list,
conv_stride_list,
lda_mat_filename=None):
model = ChainModel(feat_dim=feat_dim,
output_dim=output_dim,
lda_mat_filename=lda_mat_filename,
hidden_dim=hidden_dim,
kernel_size_list=kernel_size_list,
stride_list=stride_list)
time_stride_list=time_stride_list,
conv_stride_list=conv_stride_list)
return model


'''
input dim=43 name=input
# please note that it is important to have input layer with the name=input
# as the layer immediately preceding the fixed-affine-layer to enable
# the use of short notation for the descriptor
fixed-affine-layer name=lda input=Append(-1,0,1) affine-transform-file=exp/chain_cleaned_1c/tdnn1c_sp/configs/lda.mat
# the first splicing is moved before the lda layer, so no splicing here
relu-batchnorm-dropout-layer name=tdnn1 l2-regularize=0.008 dropout-proportion=0.0 dropout-per-dim-continuous=true dim=1024
tdnnf-layer name=tdnnf2 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=1
tdnnf-layer name=tdnnf3 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=1
tdnnf-layer name=tdnnf4 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=1
tdnnf-layer name=tdnnf5 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=0
tdnnf-layer name=tdnnf6 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
tdnnf-layer name=tdnnf7 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
tdnnf-layer name=tdnnf8 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
tdnnf-layer name=tdnnf9 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
tdnnf-layer name=tdnnf10 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
tdnnf-layer name=tdnnf11 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
tdnnf-layer name=tdnnf12 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
tdnnf-layer name=tdnnf13 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
linear-component name=prefinal-l dim=256 l2-regularize=0.008 orthonormal-constraint=-1.0
prefinal-layer name=prefinal-chain input=prefinal-l l2-regularize=0.008 big-dim=1024 small-dim=256
output-layer name=output include-log-softmax=false dim=3456 l2-regularize=0.002
prefinal-layer name=prefinal-xent input=prefinal-l l2-regularize=0.008 big-dim=1024 small-dim=256
output-layer name=output-xent dim=3456 learning-rate-factor=5.0 l2-regularize=0.002
'''
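
# A minimal sketch of the factorized TDNN (TDNN-F) layer described by the
# tdnnf-layer lines above, following Povey et al. (2018), "Semi-Orthogonal
# Low-Rank Matrix Factorization for Deep Neural Networks": a projection to a
# small bottleneck covering time offsets (-time_stride, 0), a projection back
# covering (0, time_stride), ReLU + batchnorm, and a bypass connection scaled
# by 0.66. It omits conv_stride (subsampling) and the semi-orthogonal
# constraint on the first factor; the actual FactorizedTDNN in tdnnf_layer.py
# may differ in detail.
import torch.nn as nn
import torch.nn.functional as F


class TdnnfSketch(nn.Module):

    def __init__(self, dim=1024, bottleneck_dim=128, time_stride=1,
                 bypass_scale=0.66):
        super().__init__()
        kernel_size = 2 if time_stride > 0 else 1
        dilation = max(time_stride, 1)
        # factor 1: dim -> bottleneck_dim, covering offsets (-time_stride, 0)
        self.linear = nn.Conv1d(dim, bottleneck_dim, kernel_size,
                                dilation=dilation, bias=False)
        # factor 2: bottleneck_dim -> dim, covering offsets (0, time_stride)
        self.affine = nn.Conv1d(bottleneck_dim, dim, kernel_size,
                                dilation=dilation)
        self.batchnorm = nn.BatchNorm1d(num_features=dim)
        self.bypass_scale = bypass_scale

    def forward(self, x):
        # x is [N, C, T]; with no padding, each kernel_size=2 conv trims
        # time_stride frames, so crop the bypass input to match
        y = self.batchnorm(F.relu(self.affine(self.linear(x))))
        s = (x.size(2) - y.size(2)) // 2
        return self.bypass_scale * x[:, :, s:s + y.size(2)] + y
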


# Create a network like the one in the config above
class ChainModel(nn.Module):

def __init__(self,
feat_dim,
output_dim,
lda_mat_filename,
hidden_dim=625,
kernel_size_list=[1, 3, 3, 3, 3, 3],
stride_list=[1, 1, 3, 1, 1, 1],
lda_mat_filename=None,
hidden_dim=1024,
bottleneck_dim=128,
time_stride_list=[1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1],
conv_stride_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1],
frame_subsampling_factor=3):
super().__init__()

# at present, we only support a frame_subsampling_factor of 3
assert frame_subsampling_factor == 3

assert len(kernel_size_list) == len(stride_list)
num_layers = len(kernel_size_list)
assert len(time_stride_list) == len(conv_stride_list)
num_layers = len(time_stride_list)

# tdnn1_affine requires [N, T, C]
self.tdnn1_affine = nn.Linear(in_features=feat_dim * 3,
out_features=hidden_dim)

tdnns = []
# tdnn1_batchnorm requires [N, C, T]
self.tdnn1_batchnorm = nn.BatchNorm1d(num_features=hidden_dim)

tdnnfs = []
for i in range(num_layers):
in_channels = hidden_dim
if i == 0:
in_channels = feat_dim * 3

kernel_size = kernel_size_list[i]
stride = stride_list[i]

# we do not need to perform padding in Conv1d because it
# has been included in left/right context while generating egs
layer = nn.Conv1d(in_channels=in_channels,
out_channels=hidden_dim,
kernel_size=kernel_size,
stride=stride)
tdnns.append(layer)

self.tdnns = nn.ModuleList(tdnns)
self.batch_norms = nn.ModuleList([
nn.BatchNorm1d(num_features=hidden_dim) for i in range(num_layers)
])

self.prefinal_chain_tdnn = nn.Conv1d(in_channels=hidden_dim,
out_channels=hidden_dim,
kernel_size=1)
self.prefinal_chain_batch_norm = nn.BatchNorm1d(num_features=hidden_dim)
self.output_fc = nn.Linear(in_features=hidden_dim,
out_features=output_dim)

self.prefinal_xent_tdnn = nn.Conv1d(in_channels=hidden_dim,
out_channels=hidden_dim,
kernel_size=1)
self.prefinal_xent_batch_norm = nn.BatchNorm1d(num_features=hidden_dim)
self.output_xent_fc = nn.Linear(in_features=hidden_dim,
out_features=output_dim)
time_stride = time_stride_list[i]
conv_stride = conv_stride_list[i]
layer = FactorizedTDNN(dim=hidden_dim,
bottleneck_dim=bottleneck_dim,
time_stride=time_stride,
conv_stride=conv_stride)
tdnnfs.append(layer)

# tdnnfs requires [N, C, T]
self.tdnnfs = nn.ModuleList(tdnnfs)

# prefinal_l affine requires [N, C, T]
self.prefinal_l = OrthonormalLinear(dim=hidden_dim,
bottleneck_dim=bottleneck_dim * 2,
time_stride=0)

# prefinal_chain requires [N, C, T]
self.prefinal_chain = PrefinalLayer(big_dim=hidden_dim,
small_dim=bottleneck_dim * 2)

# output_affine requires [N, T, C]
self.output_affine = nn.Linear(in_features=bottleneck_dim * 2,
out_features=output_dim)

# prefinal_xent requires [N, C, T]
self.prefinal_xent = PrefinalLayer(big_dim=hidden_dim,
small_dim=bottleneck_dim * 2)

self.output_xent_affine = nn.Linear(in_features=bottleneck_dim * 2,
out_features=output_dim)

if lda_mat_filename:
logging.info('Use LDA from {}'.format(lda_mat_filename))
@@ -146,32 +155,69 @@ def forward(self, x):

# at this point, x is [N, C, T]

# Conv1d requires input of shape [N, C, T]
for i in range(len(self.tdnns)):
x = self.tdnns[i](x)
x = F.relu(x)
x = self.batch_norms[i](x)
x = x.permute(0, 2, 1)

# at this point, x is [N, T, C]

x = self.tdnn1_affine(x)

# at this point, x is [N, T, C]

x = F.relu(x)

x = x.permute(0, 2, 1)

# at this point, x is [N, C, T]

x = self.tdnn1_batchnorm(x)

# tdnnf requires input of shape [N, C, T]
for i in range(len(self.tdnnfs)):
x = self.tdnnfs[i](x)

# at this point, x is [N, C, T]

# we have two branches from this point on
x = self.prefinal_l(x)

# at this point, x is [N, C, T]

# first, for the chain branch
x_chain = self.prefinal_chain_tdnn(x)
x_chain = F.relu(x_chain)
x_chain = self.prefinal_chain_batch_norm(x_chain)
x_chain = x_chain.permute(0, 2, 1)
# at this point, x_chain is [N, T, C]
nnet_output = self.output_fc(x_chain)
# for the output node
nnet_output = self.prefinal_chain(x)

# now for the xent branch
x_xent = self.prefinal_xent_tdnn(x)
x_xent = F.relu(x_xent)
x_xent = self.prefinal_xent_batch_norm(x_xent)
x_xent = x_xent.permute(0, 2, 1)
# at this point, nnet_output is [N, C, T]
nnet_output = nnet_output.permute(0, 2, 1)
# at this point, nnet_output is [N, T, C]
nnet_output = self.output_affine(nnet_output)

# for the xent node
xent_output = self.prefinal_xent(x)

# at this point, xent_output is [N, C, T]
xent_output = xent_output.permute(0, 2, 1)
# at this point, xent_output is [N, T, C]
xent_output = self.output_xent_affine(xent_output)

# at this point x_xent is [N, T, C]
xent_output = self.output_xent_fc(x_xent)
xent_output = F.log_softmax(xent_output, dim=-1)

return nnet_output, xent_output

def constrain_orthonormal(self):
for i in range(len(self.tdnnfs)):
self.tdnnfs[i].constrain_orthonormal()

self.prefinal_l.constrain_orthonormal()
self.prefinal_chain.constrain_orthonormal()
self.prefinal_xent.constrain_orthonormal()
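

# A minimal sketch of the semi-orthogonal update behind
# constrain_orthonormal(), again following Povey et al. (2018): one step
# nudges a weight matrix M toward M M^T = alpha^2 I, where alpha "floats"
# to a value the matrix prefers. The actual implementation in
# tdnnf_layer.py also schedules the update speed and reshapes conv weight
# tensors, so treat this as an illustration only.
def constrain_orthonormal_step(M, update_speed=0.125):
    # orient M so that rows <= cols, making M M^T ~ alpha^2 I attainable
    if M.size(0) > M.size(1):
        return constrain_orthonormal_step(M.t(), update_speed).t()

    P = torch.mm(M, M.t())
    trace_P = P.trace()
    trace_P_P = torch.mm(P, P.t()).trace()
    scale2 = trace_P_P / trace_P  # the floating value of alpha^2

    # one gradient step on f(M) = ||M M^T - alpha^2 I||_F^2
    Q = P - scale2 * torch.eye(P.size(0), dtype=P.dtype, device=P.device)
    return M - (4.0 * update_speed / scale2) * torch.mm(Q, M)


# Repeated application drives M M^T toward a multiple of the identity, e.g.
#   M = torch.randn(128, 256)
#   for _ in range(10):
#       M = constrain_orthonormal_step(M)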


if __name__ == '__main__':
feat_dim = 43
output_dim = 4344
model = ChainModel(feat_dim=feat_dim, output_dim=output_dim)
N = 1
T = 150 + 27 + 27
C = feat_dim * 3
x = torch.arange(N * T * C).reshape(N, T, C).float()
nnet_output, xent_output = model(x)
print(x.shape, nnet_output.shape, xent_output.shape)
model.constrain_orthonormal()
33 changes: 20 additions & 13 deletions egs/aishell/s10/chain/options.py
@@ -129,18 +129,19 @@ def _check_args(args):
assert args.feat_dim > 0
assert args.output_dim > 0
assert args.hidden_dim > 0
assert args.bottleneck_dim > 0

assert args.kernel_size_list is not None
assert len(args.kernel_size_list) > 0
assert args.time_stride_list is not None
assert len(args.time_stride_list) > 0

assert args.stride_list is not None
assert len(args.stride_list) > 0
assert args.conv_stride_list is not None
assert len(args.conv_stride_list) > 0

args.kernel_size_list = [int(k) for k in args.kernel_size_list.split(', ')]
args.time_stride_list = [int(k) for k in args.time_stride_list.split(', ')]

args.stride_list = [int(k) for k in args.stride_list.split(', ')]
args.conv_stride_list = [int(k) for k in args.conv_stride_list.split(', ')]

assert len(args.kernel_size_list) == len(args.stride_list)
assert len(args.time_stride_list) == len(args.conv_stride_list)

assert args.log_level in ['debug', 'info', 'warning']
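
# Note the delimiter above: the lists are split on ', ' (comma followed by
# a space), so each list must reach the parser as a single quoted string
# with exactly that spacing. A minimal check, with illustrative values
# (the recipe's run script holds the real ones):
time_stride_list = '1, 1, 1, 0, 3, 3, 3, 3, 3, 3, 3, 3'
assert [int(k) for k in time_stride_list.split(', ')] == \
    [1, 1, 1, 0, 3, 3, 3, 3, 3, 3, 3, 3]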

@@ -195,15 +196,21 @@ def get_args():
required=True,
type=int)

parser.add_argument('--kernel-size-list',
dest='kernel_size_list',
help='kernel size list',
parser.add_argument('--bottleneck-dim',
dest='bottleneck_dim',
help='nn bottleneck dimension',
required=True,
type=int)

parser.add_argument('--time-stride-list',
dest='time_stride_list',
help='time stride list',
required=True,
type=str)

parser.add_argument('--stride-list',
dest='stride_list',
help='stride list',
parser.add_argument('--conv-stride-list',
dest='conv_stride_list',
help='conv stride list',
required=True,
type=str)
