diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index ee5422abac..6f00bc14d9 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,75 @@ ## Results +### zipformer (zipformer + CTC/AED) + +See for more details. + +[zipformer](./zipformer) + +#### Non-streaming + +##### large-scale model, number of model parameters: 174319650, i.e., 174.3 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +Results of the CTC head: + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| ctc-decoding | 2.29 | 5.14 | --epoch 50 --avg 29 | +| attention-decoder-rescoring-no-ngram | 2.1 | 4.57 | --epoch 50 --avg 29 | + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large \ + --full-libri 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 1 \ + --ctc-loss-scale 0.1 \ + --attention-decoder-loss-scale 0.9 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --max-duration 1200 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in ctc-decoding attention-decoder-rescoring-no-ngram; do + ./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 29 \ + --exp-dir zipformer/exp-large \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 1 \ + --attention-decoder-loss-scale 0.9 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --max-duration 100 \ + --causal 0 \ + --num-paths 100 \ + --decoding-method $m +done +``` + + ### zipformer (zipformer + pruned stateless transducer + CTC) See for more details. diff --git a/egs/librispeech/ASR/zipformer/attention_decoder.py b/egs/librispeech/ASR/zipformer/attention_decoder.py new file mode 100644 index 0000000000..71be2d1ebb --- /dev/null +++ b/egs/librispeech/ASR/zipformer/attention_decoder.py @@ -0,0 +1,573 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import List, Optional + +import k2 +import torch +import torch.nn as nn + +from label_smoothing import LabelSmoothingLoss +from icefall.utils import add_eos, add_sos, make_pad_mask +from scaling import penalize_abs_values_gt + + +class AttentionDecoderModel(nn.Module): + """ + Args: + vocab_size (int): Number of classes. 
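For orientation: with `--use-transducer 0 --use-ctc 1 --use-attention-decoder 1`, the training objective is a weighted sum of the CTC and attention-decoder losses, using the two scales from the training command above (they are expected to sum to 1.0). A minimal sketch of that combination; the real logic lives in the `compute_loss` changes to `zipformer/train.py` further down in this patch:

```python
# Sketch of the CTC/AED loss combination used by this recipe; the values
# mirror the training command above (--ctc-loss-scale 0.1,
# --attention-decoder-loss-scale 0.9).
ctc_loss_scale = 0.1
attention_decoder_loss_scale = 0.9

def total_loss(ctc_loss, attention_decoder_loss):
    # With the transducer head disabled, only these two terms remain.
    return (
        ctc_loss_scale * ctc_loss
        + attention_decoder_loss_scale * attention_decoder_loss
    )
```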
+ decoder_dim: (int,int): embedding dimension of 2 encoder stacks + attention_dim: (int,int): attention dimension of 2 encoder stacks + num_heads (int, int): number of heads + dim_feedforward (int, int): feedforward dimension in 2 encoder stacks + num_encoder_layers (int): number of encoder layers + dropout (float): dropout rate + """ + + def __init__( + self, + vocab_size: int, + decoder_dim: int = 512, + num_decoder_layers: int = 6, + attention_dim: int = 512, + num_heads: int = 8, + feedforward_dim: int = 2048, + memory_dim: int = 512, + sos_id: int = 1, + eos_id: int = 1, + dropout: float = 0.1, + ignore_id: int = -1, + label_smoothing: float = 0.1, + ): + super().__init__() + self.eos_id = eos_id + self.sos_id = sos_id + self.ignore_id = ignore_id + + # For the segment of the warmup period, we let the Embedding + # layer learn something. Then we start to warm up the other encoders. + self.decoder = TransformerDecoder( + vocab_size=vocab_size, + d_model=decoder_dim, + num_decoder_layers=num_decoder_layers, + attention_dim=attention_dim, + num_heads=num_heads, + feedforward_dim=feedforward_dim, + memory_dim=memory_dim, + dropout=dropout, + ) + + # Used to calculate attention-decoder loss + self.loss_fun = LabelSmoothingLoss( + ignore_index=ignore_id, label_smoothing=label_smoothing, reduction="sum" + ) + + def _pre_ys_in_out(self, ys: k2.RaggedTensor, ys_lens: torch.Tensor): + """Prepare ys_in_pad and ys_out_pad.""" + ys_in = add_sos(ys, sos_id=self.sos_id) + # [B, S+1], start with SOS + ys_in_pad = ys_in.pad(mode="constant", padding_value=self.eos_id) + ys_in_lens = ys_lens + 1 + + ys_out = add_eos(ys, eos_id=self.eos_id) + # [B, S+1], end with EOS + ys_out_pad = ys_out.pad(mode="constant", padding_value=self.ignore_id) + + return ys_in_pad.to(torch.int64), ys_in_lens, ys_out_pad.to(torch.int64) + + def calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys: k2.RaggedTensor, + ys_lens: torch.Tensor, + ) -> torch.Tensor: + """Calculate attention-decoder loss. + Args: + encoder_out: (batch, num_frames, encoder_dim) + encoder_out_lens: (batch,) + token_ids: A list of token id list. + + Return: The attention-decoder loss. + """ + ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens) + + # decoder forward + decoder_out = self.decoder( + x=ys_in_pad, + x_lens=ys_in_lens, + memory=encoder_out, + memory_lens=encoder_out_lens, + ) + + loss = self.loss_fun(x=decoder_out, target=ys_out_pad) + return loss + + def nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + token_ids: List[List[int]], + ) -> torch.Tensor: + """Compute negative log likelihood(nll) from attention-decoder. + Args: + encoder_out: (batch, num_frames, encoder_dim) + encoder_out_lens: (batch,) + token_ids: A list of token id list. + + Return: A tensor of shape (batch, num_tokens). 
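Both `calc_att_loss` and `nll` rely on `_pre_ys_in_out` to build the teacher-forcing input/target pair. A small worked sketch of what it produces, assuming the default `sos_id == eos_id == 1` and `ignore_id == -1` (token ids are made up):

```python
import k2
from icefall.utils import add_eos, add_sos

# Two toy utterances with BPE token ids [5, 7, 9] and [4, 6].
ys = k2.RaggedTensor([[5, 7, 9], [4, 6]])

ys_in_pad = add_sos(ys, sos_id=1).pad(mode="constant", padding_value=1)
# tensor([[1, 5, 7, 9],
#         [1, 4, 6, 1]])   decoder input: starts with SOS, padded with EOS

ys_out_pad = add_eos(ys, eos_id=1).pad(mode="constant", padding_value=-1)
# tensor([[5, 7, 9,  1],
#         [4, 6, 1, -1]])  target: ends with EOS, padded with ignore_id
```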
+ """ + ys = k2.RaggedTensor(token_ids).to(device=encoder_out.device) + row_splits = ys.shape.row_splits(1) + ys_lens = row_splits[1:] - row_splits[:-1] + + ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens) + + # decoder forward + decoder_out = self.decoder( + x=ys_in_pad, + x_lens=ys_in_lens, + memory=encoder_out, + memory_lens=encoder_out_lens, + ) + + batch_size, _, num_classes = decoder_out.size() + nll = nn.functional.cross_entropy( + decoder_out.view(-1, num_classes), + ys_out_pad.view(-1), + ignore_index=self.ignore_id, + reduction="none", + ) + nll = nll.view(batch_size, -1) + return nll + + +class TransformerDecoder(nn.Module): + """Transfomer decoder module. + + Args: + vocab_size: output dim + d_model: decoder dimension + num_decoder_layers: number of decoder layers + attention_dim: total dimension of multi head attention + num_heads: number of attention heads + feedforward_dim: hidden dimension of feed_forward module + dropout: dropout rate + """ + + def __init__( + self, + vocab_size: int, + d_model: int = 512, + num_decoder_layers: int = 6, + attention_dim: int = 512, + num_heads: int = 8, + feedforward_dim: int = 2048, + memory_dim: int = 512, + dropout: float = 0.1, + ): + super().__init__() + self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model) + + # Absolute positional encoding + self.pos = PositionalEncoding(d_model, dropout_rate=0.1) + + self.num_layers = num_decoder_layers + self.layers = nn.ModuleList( + [ + DecoderLayer( + d_model=d_model, + attention_dim=attention_dim, + num_heads=num_heads, + feedforward_dim=feedforward_dim, + memory_dim=memory_dim, + dropout=dropout, + ) + for _ in range(num_decoder_layers) + ] + ) + + self.output_layer = nn.Linear(d_model, vocab_size) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + memory: Optional[torch.Tensor] = None, + memory_lens: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x: Input tensor of shape (batch, tgt_len). + x_lens: A tensor of shape (batch,) containing the number of tokens in `x` + before padding. + memory: + Memory sequence of shape (batch, src_len, memory_dim). + memory_lens: + A tensor of shape (batch,) containing the number of frames in + `memory` before padding. + + Returns: + Decoded token logits before softmax (batch, tgt_len, vocab_size) + """ + x = self.embed(x) # (batch, tgt_len, embed_dim) + x = self.pos(x) # (batch, tgt_len, embed_dim) + + x = x.permute(1, 0, 2) # (tgt_len, batch, embed_dim) + + # construct attn_mask for self-attn modules + padding_mask = make_pad_mask(x_lens) # (batch, tgt_len) + causal_mask = subsequent_mask(x.shape[0], device=x.device) # (seq_len, seq_len) + attn_mask = torch.logical_or( + padding_mask.unsqueeze(1), # (batch, 1, seq_len) + torch.logical_not(causal_mask).unsqueeze(0) # (1, seq_len, seq_len) + ) # (batch, seq_len, seq_len) + + if memory is not None: + memory = memory.permute(1, 0, 2) # (src_len, batch, memory_dim) + # construct memory_attn_mask for cross-attn modules + memory_padding_mask = make_pad_mask(memory_lens) # (batch, src_len) + memory_attn_mask = memory_padding_mask.unsqueeze(1) # (batch, 1, src_len) + else: + memory_attn_mask = None + + for i, mod in enumerate(self.layers): + x = mod( + x, + attn_mask=attn_mask, + memory=memory, + memory_attn_mask=memory_attn_mask, + ) + + x = x.permute(1, 0, 2) # (batch, tgt_len, vocab_size) + x = self.output_layer(x) + + return x + + +class DecoderLayer(nn.Module): + """Single decoder layer module. 
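Before the layer internals, it may help to see what the self-attention mask built in `TransformerDecoder.forward` evaluates to. A self-contained sketch (it re-implements `make_pad_mask` locally so the snippet runs without icefall; `True` marks positions that must not be attended to):

```python
import torch

def make_pad_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # True at padded positions, same convention as icefall.utils.make_pad_mask.
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

x_lens = torch.tensor([3, 2])              # two target sequences, padded to length 3
padding_mask = make_pad_mask(x_lens, 3)    # (batch, seq_len)
causal_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool))  # allowed positions

attn_mask = torch.logical_or(
    padding_mask.unsqueeze(1),                    # (batch, 1, seq_len)
    torch.logical_not(causal_mask).unsqueeze(0),  # (1, seq_len, seq_len)
)
# attn_mask[b, i, j] == True means token i of utterance b may not attend to
# position j, either because j is in the future or because j is padding.
```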
+ + Args: + d_model: equal to decoder_dim, total dimension of the decoder + attention_dim: total dimension of multi head attention + num_heads: number of attention heads + feedforward_dim: hidden dimension of feed_forward module + dropout: dropout rate + """ + + def __init__( + self, + d_model: int = 512, + attention_dim: int = 512, + num_heads: int = 8, + feedforward_dim: int = 2048, + memory_dim: int = 512, + dropout: float = 0.1, + ): + """Construct an DecoderLayer object.""" + super(DecoderLayer, self).__init__() + + self.norm_self_attn = nn.LayerNorm(d_model) + self.self_attn = MultiHeadAttention( + d_model, attention_dim, num_heads, dropout=0.0 + ) + + self.norm_src_attn = nn.LayerNorm(d_model) + self.src_attn = MultiHeadAttention( + d_model, attention_dim, num_heads, memory_dim=memory_dim, dropout=0.0 + ) + + self.norm_ff = nn.LayerNorm(d_model) + self.feed_forward = nn.Sequential( + nn.Linear(d_model, feedforward_dim), + Swish(), + nn.Dropout(dropout), + nn.Linear(feedforward_dim, d_model), + ) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + memory: Optional[torch.Tensor] = None, + memory_attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x: Input sequence of shape (seq_len, batch, embed_dim). + attn_mask: A binary mask for self-attention module indicating which + elements will be filled with -inf. + Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). + memory: Memory sequence of shape (seq_len, batch, memory_dim). + memory_attn_mask: A binary mask for cross-attention module indicating which + elements will be filled with -inf. + Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). + """ + # self-attn module + qkv = self.norm_self_attn(x) + self_attn_out = self.self_attn( + query=qkv, key=qkv, value=qkv, attn_mask=attn_mask + ) + x = x + self.dropout(self_attn_out) + + # cross-attn module + q = self.norm_src_attn(x) + src_attn_out = self.src_attn( + query=q, key=memory, value=memory, attn_mask=memory_attn_mask + ) + x = x + self.dropout(src_attn_out) + + # feed-forward module + x = x + self.dropout(self.feed_forward(self.norm_ff(x))) + + return x + + +class MultiHeadAttention(nn.Module): + """Multi-Head Attention layer. + + Args: + embed_dim: total dimension of the model. + attention_dim: dimension in the attention module, but must be a multiple of num_heads. + num_heads: number of parallel attention heads. + memory_dim: dimension of memory embedding, optional. + dropout: a Dropout layer on attn_output_weights. + """ + + def __init__( + self, + embed_dim: int, + attention_dim: int, + num_heads: int, + memory_dim: Optional[int] = None, + dropout: float = 0.0, + ): + super(MultiHeadAttention, self).__init__() + self.embed_dim = embed_dim + self.attention_dim = attention_dim + self.num_heads = num_heads + self.head_dim = attention_dim // num_heads + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, num_heads, attention_dim + ) + self.dropout = dropout + self.name = None # will be overwritten in training code; for diagnostics. 
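With the defaults above (`attention_dim=512`, `num_heads=8`), each head attends in a 64-dimensional subspace. A shape-only sketch of the head splitting and score computation performed in `forward` below, with random tensors standing in for the projected queries and keys:

```python
import torch

attention_dim, num_heads = 512, 8
head_dim = attention_dim // num_heads            # 64
tgt_len, src_len, batch = 7, 50, 2

q = torch.randn(tgt_len, batch, attention_dim)   # output of linear_q
k = torch.randn(src_len, batch, attention_dim)   # output of linear_k

q = q.reshape(tgt_len, batch, num_heads, head_dim).permute(1, 2, 0, 3)
k = k.reshape(src_len, batch, num_heads, head_dim).permute(1, 2, 3, 0)
scores = torch.matmul(q, k) / head_dim**0.5      # (batch, head, tgt_len, src_len)
assert scores.shape == (batch, num_heads, tgt_len, src_len)
```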
+ + self.linear_q = nn.Linear(embed_dim, attention_dim, bias=True) + self.linear_k = nn.Linear( + embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True + ) + self.linear_v = nn.Linear( + embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True + ) + + self.out_proj = nn.Linear(attention_dim, embed_dim, bias=True) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_padding_mask: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Compute dot product attention. + + Args: + query: Query tensor of shape (tgt_len, batch, embed_dim). + key: Key tensor of shape (src_len, batch, embed_dim or memory_dim). + value: Value tensor of shape (src_len, batch, embed_dim or memory_dim). + key_padding_mask: A binary mask indicating which elements are padding. + Its shape is (batch, src_len). + attn_mask: A binary mask indicating which elements will be filled with -inf. + Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len). + + Returns: + Output tensor of shape (tgt_len, batch, embed_dim). + """ + num_heads = self.num_heads + head_dim = self.head_dim + + tgt_len, batch, _ = query.shape + src_len = key.shape[0] + + q = self.linear_q(query) # (tgt_len, batch, num_heads * head_dim) + k = self.linear_k(key) # (src_len, batch, num_heads * head_dim) + v = self.linear_v(value) # (src_len, batch, num_heads * head_dim) + + q = q.reshape(tgt_len, batch, num_heads, head_dim) + q = q.permute(1, 2, 0, 3) # (batch, head, tgt_len, head_dim) + k = k.reshape(src_len, batch, num_heads, head_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, head_dim, src_len) + v = v.reshape(src_len, batch, num_heads, head_dim) + v = v.reshape(src_len, batch * num_heads, head_dim).transpose(0, 1) + + # Note: could remove the scaling operation when using ScaledAdam + # (batch, head, tgt_len, src_len) + attn_weights = torch.matmul(q, k) / math.sqrt(head_dim) + + # From zipformer.py: + # This is a harder way of limiting the attention scores to not be too large. + # It incurs a penalty if any of them has an absolute value greater than 50.0. + # this should be outside the normal range of the attention scores. We use + # this mechanism instead of, say, a limit on entropy, because once the entropy + # gets very small gradients through the softmax can become very small, and + # some mechanisms like that become ineffective. 
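For readers without `scaling.py` at hand: `penalize_abs_values_gt(x, limit, penalty)` returns `x` unchanged in the forward pass, but during backprop it acts roughly as if the auxiliary term sketched below had been added to the training loss (a simplified description; the real implementation uses a memory-efficient custom autograd op):

```python
import torch

def abs_value_penalty(x: torch.Tensor, limit: float = 50.0, penalty: float = 1.0e-04) -> torch.Tensor:
    # Grows linearly with how far |x| exceeds `limit`; zero as long as all
    # attention scores stay inside [-limit, limit].
    return penalty * (x.abs() - limit).clamp(min=0).sum()
```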
+ attn_weights = penalize_abs_values_gt(attn_weights, limit=50.0, penalty=1.0e-04) + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"), + ) + + if attn_mask is not None: + assert ( + attn_mask.shape == (batch, 1, src_len) + or attn_mask.shape == (batch, tgt_len, src_len) + ), attn_mask.shape + attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float("-inf")) + + attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + # (batch * head, tgt_len, head_dim) + attn_output = torch.bmm(attn_weights, v) + assert attn_output.shape == (batch * num_heads, tgt_len, head_dim), attn_output.shape + + attn_output = attn_output.transpose(0, 1).contiguous() + attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim) + + # (batch, tgt_len, embed_dim) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class PositionalEncoding(nn.Module): + """Positional encoding. + Copied from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py#L35. + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def subsequent_mask(size, device="cpu", dtype=torch.bool): + """Create mask for subsequent steps (size, size). 
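In closed form, the `PositionalEncoding` above is the standard sinusoidal encoding, `pe[pos, 2i] = sin(pos / 10000^(2i/d_model))` and `pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model))`, computed via the `exp`/`log` trick in `extend_pe`. A tiny sketch of that computation, which also notes the `sqrt(d_model)` input scaling applied in `forward`:

```python
import math
import torch

d_model, max_len = 8, 4
position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, d_model, 2, dtype=torch.float32)
    * -(math.log(10000.0) / d_model)
)  # equals 1 / 10000^(2i / d_model)
pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# In PositionalEncoding.forward, the token embeddings are first scaled:
#   x = x * sqrt(d_model) + pe[: x.size(1)]
```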
+ + :param int size: size of mask + :param str device: "cpu" or "cuda" or torch.Tensor.device + :param torch.dtype dtype: result dtype + :rtype: torch.Tensor + >>> subsequent_mask(3) + [[1, 0, 0], + [1, 1, 0], + [1, 1, 1]] + """ + ret = torch.ones(size, size, device=device, dtype=dtype) + return torch.tril(ret, out=ret) + + +def _test_attention_decoder_model(): + m = AttentionDecoderModel( + vocab_size=500, + decoder_dim=512, + num_decoder_layers=6, + attention_dim=512, + num_heads=8, + feedforward_dim=2048, + memory_dim=384, + dropout=0.1, + sos_id=1, + eos_id=1, + ignore_id=-1, + ) + + num_param = sum([p.numel() for p in m.parameters()]) + print(f"Number of model parameters: {num_param}") + + m.eval() + encoder_out = torch.randn(2, 50, 384) + encoder_out_lens = torch.full((2,), 50) + token_ids = [[1, 2, 3, 4], [2, 3, 10]] + + nll = m.nll(encoder_out, encoder_out_lens, token_ids) + print(nll) + + +if __name__ == "__main__": + _test_attention_decoder_model() diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py index 1f0f9bfac3..85ceb61b8a 100755 --- a/egs/librispeech/ASR/zipformer/ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode.py @@ -73,6 +73,29 @@ --nbest-scale 1.0 \ --lm-dir data/lm \ --decoding-method whole-lattice-rescoring + +(6) attention-decoder-rescoring-no-ngram +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --decoding-method attention-decoder-rescoring-no-ngram + +(7) attention-decoder-rescoring-with-ngram +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method attention-decoder-rescoring-with-ngram """ @@ -101,6 +124,8 @@ nbest_decoding, nbest_oracle, one_best_decoding, + rescore_with_attention_decoder_no_ngram, + rescore_with_attention_decoder_with_ngram, rescore_with_n_best_list, rescore_with_whole_lattice, ) @@ -212,6 +237,10 @@ def get_parser(): - (6) nbest-oracle. Its WER is the lower bound of any n-best rescoring method can achieve. Useful for debugging n-best rescoring method. + - (7) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding + lattice, rescore them with the attention decoder. + - (8) attention-decoder-rescoring-with-ngram. Extract n paths from the LM + rescored lattice, rescore them with the attention decoder. """, ) @@ -406,6 +435,26 @@ def decode_one_batch( key = "ctc-decoding" return {key: hyps} + if params.decoding_method == "attention-decoder-rescoring-no-ngram": + best_path_dict = rescore_with_attention_decoder_no_ngram( + lattice=lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) + ans = dict() + for a_scale_str, best_path in best_path_dict.items(): + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + ans[a_scale_str] = hyps + return ans + if params.decoding_method == "nbest-oracle": # Note: You can also pass rescored lattices to it. 
# We choose the HLG decoded lattice for speed reasons @@ -446,6 +495,7 @@ def decode_one_batch( assert params.decoding_method in [ "nbest-rescoring", "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", ] lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] @@ -466,6 +516,21 @@ def decode_one_batch( G_with_epsilon_loops=G, lm_scale_list=lm_scale_list, ) + elif params.decoding_method == "attention-decoder-rescoring-with-ngram": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + best_path_dict = rescore_with_attention_decoder_with_ngram( + lattice=rescored_lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) else: assert False, f"Unsupported decoding method: {params.decoding_method}" @@ -564,12 +629,21 @@ def save_results( test_set_name: str, results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): + if params.decoding_method in ( + "attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring" + ): + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. @@ -577,8 +651,8 @@ def save_results( with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results) test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" @@ -616,6 +690,8 @@ def main(): "nbest-rescoring", "whole-lattice-rescoring", "nbest-oracle", + "attention-decoder-rescoring-no-ngram", + "attention-decoder-rescoring-with-ngram", ) params.res_dir = params.exp_dir / params.decoding_method @@ -654,8 +730,10 @@ def main(): params.vocab_size = num_classes # and are defined in local/train_bpe_model.py params.blank_id = 0 + params.eos_id = 1 + params.sos_id = 1 - if params.decoding_method == "ctc-decoding": + if params.decoding_method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]: HLG = None H = k2.ctc_topo( max_token=max_token_id, @@ -679,6 +757,7 @@ def main(): if params.decoding_method in ( "nbest-rescoring", "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", ): if not (params.lm_dir / "G_4_gram.pt").is_file(): logging.info("Loading G_4_gram.fst.txt") @@ -710,7 +789,9 @@ def main(): d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) G = k2.Fsa.from_dict(d) - if params.decoding_method == "whole-lattice-rescoring": + if params.decoding_method in [ + "whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram" + ]: # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = k2.add_epsilon_self_loops(G) diff --git a/egs/librispeech/ASR/zipformer/export.py 
b/egs/librispeech/ASR/zipformer/export.py index 2b8d1aaf36..1f3373cd83 100755 --- a/egs/librispeech/ASR/zipformer/export.py +++ b/egs/librispeech/ASR/zipformer/export.py @@ -404,6 +404,7 @@ def main(): token_table = k2.SymbolTable.from_file(params.tokens) params.blank_id = token_table[""] + params.sos_id = params.eos_id = token_table[""] params.vocab_size = num_tokens(token_table) + 1 logging.info(params) @@ -466,8 +467,6 @@ def main(): device=device, ) ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: assert params.avg > 0, params.avg start = params.epoch - params.avg diff --git a/egs/librispeech/ASR/zipformer/label_smoothing.py b/egs/librispeech/ASR/zipformer/label_smoothing.py new file mode 100644 index 0000000000..52d2eda3bb --- /dev/null +++ b/egs/librispeech/ASR/zipformer/label_smoothing.py @@ -0,0 +1,109 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +class LabelSmoothingLoss(torch.nn.Module): + """ + Implement the LabelSmoothingLoss proposed in the following paper + https://arxiv.org/pdf/1512.00567.pdf + (Rethinking the Inception Architecture for Computer Vision) + + """ + + def __init__( + self, + ignore_index: int = -1, + label_smoothing: float = 0.1, + reduction: str = "sum", + ) -> None: + """ + Args: + ignore_index: + ignored class id + label_smoothing: + smoothing rate (0.0 means the conventional cross entropy loss) + reduction: + It has the same meaning as the reduction in + `torch.nn.CrossEntropyLoss`. It can be one of the following three + values: (1) "none": No reduction will be applied. (2) "mean": the + mean of the output is taken. (3) "sum": the output will be summed. + """ + super().__init__() + assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}" + assert reduction in ("none", "sum", "mean"), reduction + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.reduction = reduction + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute loss between x and target. + + Args: + x: + prediction of dimension + (batch_size, input_length, number_of_classes). + target: + target masked with self.ignore_index of + dimension (batch_size, input_length). + + Returns: + A scalar tensor containing the loss without normalization. 
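Concretely, with `label_smoothing = 0.1` the one-hot target row for each position is mixed with a uniform distribution over all classes; rows whose target equals `ignore_index` are zeroed so they contribute nothing. A small numeric sketch for a single position with 500 classes (class ids are made up):

```python
import torch

num_classes, smoothing = 500, 0.1
target = torch.tensor([7])                                   # true class id
one_hot = torch.nn.functional.one_hot(target, num_classes).float()
true_dist = one_hot * (1 - smoothing) + smoothing / num_classes
# probability of the true class : 0.9 + 0.1 / 500 = 0.9002
# probability of any other class: 0.1 / 500       = 0.0002
# The loss is then -sum(log_softmax(logits) * true_dist) over non-ignored rows.
```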
+ """ + assert x.ndim == 3 + assert target.ndim == 2 + assert x.shape[:2] == target.shape + num_classes = x.size(-1) + x = x.reshape(-1, num_classes) + # Now x is of shape (N*T, C) + + # We don't want to change target in-place below, + # so we make a copy of it here + target = target.clone().reshape(-1) + + ignored = target == self.ignore_index + + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use target[ignored] = 0 here + target = torch.where(ignored, torch.zeros_like(target), target) + + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) + + true_dist = ( + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes + ) + + # Set the value of ignored indexes to 0 + # + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use true_dist[ignored] = 0 here + true_dist = torch.where( + ignored.unsqueeze(1).repeat(1, true_dist.shape[1]), + torch.zeros_like(true_dist), + true_dist, + ) + + loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) + if self.reduction == "sum": + return loss.sum() + elif self.reduction == "mean": + return loss.sum() / (~ignored).sum() + else: + return loss.sum(dim=-1) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 86da3ab29a..bd1ed26d8d 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -34,11 +34,13 @@ def __init__( encoder: EncoderInterface, decoder: Optional[nn.Module] = None, joiner: Optional[nn.Module] = None, + attention_decoder: Optional[nn.Module] = None, encoder_dim: int = 384, decoder_dim: int = 512, vocab_size: int = 500, use_transducer: bool = True, use_ctc: bool = False, + use_attention_decoder: bool = False, ): """A joint CTC & Transducer ASR model. @@ -70,6 +72,8 @@ def __init__( Whether use transducer head. Default: True. use_ctc: Whether use CTC head. Default: False. + use_attention_decoder: + Whether use attention-decoder head. Default: False. 
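To get a standalone feel for the new head, the call that `AsrModel.forward` makes later when `use_attention_decoder` is true can be exercised directly. A sketch, assuming `attention_decoder.py` (and its dependencies) is importable and using made-up dimensions; `memory_dim=384` would correspond to a smaller encoder:

```python
import k2
import torch
from attention_decoder import AttentionDecoderModel

attention_decoder = AttentionDecoderModel(
    vocab_size=500, decoder_dim=512, num_decoder_layers=6,
    attention_dim=512, num_heads=8, feedforward_dim=2048,
    memory_dim=384, sos_id=1, eos_id=1, ignore_id=-1,
)

encoder_out = torch.randn(2, 50, 384)        # (batch, frames, encoder_dim)
encoder_out_lens = torch.full((2,), 50)
y = k2.RaggedTensor([[5, 7, 9], [4, 6]])     # BPE token ids of the transcripts
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]

att_loss = attention_decoder.calc_att_loss(
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    ys=y,
    ys_lens=y_lens,
)
print(att_loss)                               # summed label-smoothed loss
```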
""" super().__init__() @@ -111,6 +115,12 @@ def __init__( nn.LogSoftmax(dim=-1), ) + self.use_attention_decoder = use_attention_decoder + if use_attention_decoder: + self.attention_decoder = attention_decoder + else: + assert attention_decoder is None + def forward_encoder( self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -286,7 +296,7 @@ def forward( prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -308,7 +318,7 @@ def forward( part Returns: Return the transducer losses and CTC loss, - in form of (simple_loss, pruned_loss, ctc_loss) + in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss) Note: Regarding am_scale & lm_scale, it will make the loss-function one of @@ -322,6 +332,8 @@ def forward( assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) + device = x.device + # Compute encoder outputs encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) @@ -333,7 +345,7 @@ def forward( simple_loss, pruned_loss = self.forward_transducer( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - y=y.to(x.device), + y=y.to(device), y_lens=y_lens, prune_range=prune_range, am_scale=am_scale, @@ -355,4 +367,14 @@ def forward( else: ctc_loss = torch.empty(0) - return simple_loss, pruned_loss, ctc_loss + if self.use_attention_decoder: + attention_decoder_loss = self.attention_decoder.calc_att_loss( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ys=y.to(device), + ys_lens=y_lens.to(device), + ) + else: + attention_decoder_loss = torch.empty(0) + + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py index 408d135769..4341ef61f7 100755 --- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py @@ -81,6 +81,15 @@ --sample-rate 16000 \ /path/to/foo.wav \ /path/to/bar.wav + +(5) attention-decoder-rescoring-no-ngram +./zipformer/pretrained_ctc.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500/tokens.txt \ + --method attention-decoder-rescoring-no-ngram \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav """ import argparse @@ -100,6 +109,7 @@ from icefall.decode import ( get_lattice, one_best_decoding, + rescore_with_attention_decoder_no_ngram, rescore_with_n_best_list, rescore_with_whole_lattice, ) @@ -172,6 +182,8 @@ def get_parser(): decoding lattice and then use 1best to decode the rescored lattice. We call it HLG decoding + whole-lattice n-gram LM rescoring. + (4) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding + lattice, rescore them with the attention decoder. 
""", ) @@ -276,6 +288,7 @@ def main(): token_table = k2.SymbolTable.from_file(params.tokens) params.vocab_size = num_tokens(token_table) + 1 # +1 for blank params.blank_id = token_table[""] + params.sos_id = params.eos_id = token_table[""] assert params.blank_id == 0 logging.info(f"{params}") @@ -333,16 +346,13 @@ def main(): dtype=torch.int32, ) - if params.method == "ctc-decoding": - logging.info("Use CTC decoding") + if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]: max_token_id = params.vocab_size - 1 - H = k2.ctc_topo( max_token=max_token_id, modified=False, device=device, ) - lattice = get_lattice( nnet_output=ctc_output, decoding_graph=H, @@ -354,9 +364,23 @@ def main(): subsampling_factor=params.subsampling_factor, ) - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + else: + logging.info("Use attention decoder rescoring without ngram") + best_path_dict = rescore_with_attention_decoder_no_ngram( + lattice=lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + token_ids = get_texts(best_path) hyps = [[token_table[i] for i in ids] for ids in token_ids] elif params.method in [ @@ -430,7 +454,7 @@ def main(): raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" - if params.method == "ctc-decoding": + if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]: for filename, hyp in zip(params.sound_files, hyps): words = "".join(hyp) words = words.replace("▁", " ").strip() diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 04caf2fd80..d87041a52c 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -48,6 +48,8 @@ - transducer loss (default), with `--use-transducer True --use-ctc False` - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` + - ctc loss & attention decoder loss, no transducer loss, + with `--use-transducer False --use-ctc True --use-attention-decoder True` """ @@ -66,6 +68,7 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule +from attention_decoder import AttentionDecoderModel from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -221,6 +224,41 @@ def add_model_arguments(parser: argparse.ArgumentParser): """, ) + parser.add_argument( + "--attention-decoder-dim", + type=int, + default=512, + help="""Dimension used in the attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-layers", + type=int, + default=6, + help="""Number of transformer layers used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-attention-dim", + type=int, + default=512, + help="""Attention dimension used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-heads", + type=int, + default=8, + help="""Number of attention heads used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-feedforward-dim", + type=int, + default=2048, + help="""Feedforward dimension used in 
attention decoder""", + ) + parser.add_argument( "--causal", type=str2bool, @@ -259,6 +297,13 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="If True, use CTC head.", ) + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -404,6 +449,13 @@ def get_parser(): help="Scale for CTC loss.", ) + parser.add_argument( + "--attention-decoder-loss-scale", + type=float, + default=0.8, + help="Scale for attention-decoder loss.", + ) + parser.add_argument( "--seed", type=int, @@ -532,6 +584,9 @@ def get_params() -> AttributeDict: # parameters for zipformer "feature_dim": 80, "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for attention-decoder + "ignore_id": -1, + "label_smoothing": 0.1, "warm_step": 2000, "env_info": get_env_info(), } @@ -604,6 +659,23 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: return joiner +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=params.attention_decoder_dim, + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=params.attention_decoder_attention_dim, + num_heads=params.attention_decoder_num_heads, + feedforward_dim=params.attention_decoder_feedforward_dim, + memory_dim=max(_to_int_tuple(params.encoder_dim)), + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + def get_model(params: AttributeDict) -> nn.Module: assert params.use_transducer or params.use_ctc, ( f"At least one of them should be True, " @@ -621,16 +693,23 @@ def get_model(params: AttributeDict) -> nn.Module: decoder = None joiner = None + if params.use_attention_decoder: + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None + model = AsrModel( encoder_embed=encoder_embed, encoder=encoder, decoder=decoder, joiner=joiner, + attention_decoder=attention_decoder, encoder_dim=max(_to_int_tuple(params.encoder_dim)), decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, use_transducer=params.use_transducer, use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, ) return model @@ -793,7 +872,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -823,6 +902,9 @@ def compute_loss( if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss + if params.use_attention_decoder: + loss += params.attention_decoder_loss_scale * attention_decoder_loss + assert loss.requires_grad == is_training info = MetricsTracker() @@ -837,6 +919,8 @@ def compute_loss( info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_attention_decoder: + info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() return loss, info @@ -1116,10 +1200,16 @@ def run(rank, world_size, args): # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") + params.sos_id = params.eos_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() if not params.use_transducer: - params.ctc_loss_scale = 1.0 + if not params.use_attention_decoder: + params.ctc_loss_scale = 1.0 + else: + 
assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( + params.ctc_loss_scale, params.attention_decoder_loss_scale + ) logging.info(params) diff --git a/icefall/decode.py b/icefall/decode.py index 23f9fb9b3a..3abd5648a1 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -1083,6 +1083,238 @@ def rescore_with_attention_decoder( return ans +def rescore_with_attention_decoder_with_ngram( + lattice: k2.Fsa, + num_paths: int, + attention_decoder: torch.nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + nbest_scale: float = 1.0, + ngram_lm_scale: Optional[float] = None, + attention_scale: Optional[float] = None, + use_double_scores: bool = True, +) -> Dict[str, k2.Fsa]: + """This function extracts `num_paths` paths from the given lattice and uses + an attention decoder to rescore them. The path with the highest score is + the decoding output. + + Args: + lattice: + An FsaVec with axes [utt][state][arc]. + num_paths: + Number of paths to extract from the given lattice for rescoring. + attention_decoder: + A transformer model. See the class "Transformer" in + conformer_ctc/transformer.py for its interface. + encoder_out: + The encoder memory of the given model. It is the output of + the last torch.nn.TransformerEncoder layer in the given model. + Its shape is `(N, T, C)`. + encoder_out_lens: + Length of encoder outputs, with shape of `(N,)`. + nbest_scale: + It's the scale applied to `lattice.scores`. A smaller value + leads to more unique paths at the risk of missing the correct path. + ngram_lm_scale: + Optional. It specifies the scale for n-gram LM scores. + attention_scale: + Optional. It specifies the scale for attention decoder scores. + Returns: + A dict of FsaVec, whose key contains a string + ngram_lm_scale_attention_scale and the value is the + best decoding path for each utterance in the lattice. + """ + max_loop_count = 10 + loop_count = 0 + while loop_count <= max_loop_count: + try: + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # nbest.fsa.scores are all 0s at this point + nbest = nbest.intersect(lattice) + break + except RuntimeError as e: + logging.info(f"Caught exception:\n{e}\n") + logging.info(f"num_paths before decreasing: {num_paths}") + num_paths = int(num_paths / 2) + if loop_count >= max_loop_count or num_paths <= 0: + logging.info("Return None as the resulting lattice is too large.") + return None + logging.info( + "This OOM is not an error. You can ignore it. " + "If your model does not converge well, or --max-duration " + "is too large, or the input sound file is difficult to " + "decode, you will meet this exception." + ) + logging.info(f"num_paths after decreasing: {num_paths}") + loop_count += 1 + + # Now nbest.fsa has its scores set. + # Also, nbest.fsa inherits the attributes from `lattice`. + assert hasattr(nbest.fsa, "lm_scores") + + am_scores = nbest.compute_am_scores() + ngram_lm_scores = nbest.compute_lm_scores() + + # The `tokens` attribute is set inside `compile_hlg.py` + assert hasattr(nbest.fsa, "tokens") + assert isinstance(nbest.fsa.tokens, torch.Tensor) + + path_to_utt_map = nbest.shape.row_ids(1).to(torch.long) + # the shape of memory is (T, N, C), so we use axis=1 here + expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map) + expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map) + + # remove axis corresponding to states. 
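To make the scoring below concrete: each extracted path gets an acoustic score, an n-gram LM score, and an attention-decoder score (`-nll.sum(dim=1)`), and the three are combined for every `(ngram_lm_scale, attention_scale)` pair. A toy numeric sketch for one utterance with three candidate paths (all numbers invented):

```python
import torch

am_scores = torch.tensor([-12.3, -11.8, -13.0])        # acoustic score per path
ngram_lm_scores = torch.tensor([-20.1, -19.5, -21.4])  # n-gram LM score per path
nll = torch.tensor([[2.0, 1.5, 0.8],                   # per-token NLL from the
                    [1.2, 1.9, 1.1],                   # attention decoder
                    [2.5, 2.2, 1.4]])
attention_scores = -nll.sum(dim=1)

n_scale, a_scale = 0.3, 2.0                            # one grid point
tot_scores = am_scores + n_scale * ngram_lm_scores + a_scale * attention_scores
best_path_index = tot_scores.argmax()                  # winner for this utterance
```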
+ tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens) + tokens = tokens.remove_values_leq(0) + token_ids = tokens.tolist() + + nll = attention_decoder.nll( + encoder_out=expanded_encoder_out, + encoder_out_lens=expanded_encoder_out_lens, + token_ids=token_ids, + ) + assert nll.ndim == 2 + assert nll.shape[0] == len(token_ids) + + attention_scores = -nll.sum(dim=1) + + if ngram_lm_scale is None: + ngram_lm_scale_list = [0.01, 0.05, 0.08] + ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] + else: + ngram_lm_scale_list = [ngram_lm_scale] + + if attention_scale is None: + attention_scale_list = [0.01, 0.05, 0.08] + attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] + else: + attention_scale_list = [attention_scale] + + ans = dict() + for n_scale in ngram_lm_scale_list: + for a_scale in attention_scale_list: + tot_scores = ( + am_scores.values + + n_scale * ngram_lm_scores.values + + a_scale * attention_scores + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}" + ans[key] = best_path + return ans + + +def rescore_with_attention_decoder_no_ngram( + lattice: k2.Fsa, + num_paths: int, + attention_decoder: torch.nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + nbest_scale: float = 1.0, + attention_scale: Optional[float] = None, + use_double_scores: bool = True, +) -> Dict[str, k2.Fsa]: + """This function extracts `num_paths` paths from the given lattice and uses + an attention decoder to rescore them. The path with the highest score is + the decoding output. + + Args: + lattice: + An FsaVec with axes [utt][state][arc]. + num_paths: + Number of paths to extract from the given lattice for rescoring. + attention_decoder: + A transformer model. See the class "Transformer" in + conformer_ctc/transformer.py for its interface. + encoder_out: + The encoder memory of the given model. It is the output of + the last torch.nn.TransformerEncoder layer in the given model. + Its shape is `(N, T, C)`. + encoder_out_lens: + Length of encoder outputs, with shape of `(N,)`. + nbest_scale: + It's the scale applied to `lattice.scores`. A smaller value + leads to more unique paths at the risk of missing the correct path. + attention_scale: + Optional. It specifies the scale for attention decoder scores. + + Returns: + A dict of FsaVec, whose key contains a string + ngram_lm_scale_attention_scale and the value is the + best decoding path for each utterance in the lattice. + """ + # path is a ragged tensor with dtype torch.int32. + # It has three axes [utt][path][arc_pos] + path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True) + # Note that labels, aux_labels and scores contains 0s and -1s. + # The last entry in each sublist is -1. 
+ # The axes are [path][token_id] + labels = k2.ragged.index(lattice.labels.contiguous(), path).remove_axis(0) + aux_labels = k2.ragged.index(lattice.aux_labels.contiguous(), path).remove_axis(0) + scores = k2.ragged.index(lattice.scores.contiguous(), path).remove_axis(0) + + # Remove -1 from labels as we will use it to construct a linear FSA + labels = labels.remove_values_eq(-1) + fsa = k2.linear_fsa(labels) + fsa.aux_labels = aux_labels.values + + # utt_to_path_shape has axes [utt][path] + utt_to_path_shape = path.shape.get_layer(0) + scores = k2.RaggedTensor(utt_to_path_shape, scores.sum()) + + path_to_utt_map = utt_to_path_shape.row_ids(1).to(torch.long) + # the shape of memory is (N, T, C), so we use axis=0 here + expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map) + expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map) + + token_ids = aux_labels.remove_values_leq(0).tolist() + + nll = attention_decoder.nll( + encoder_out=expanded_encoder_out, + encoder_out_lens=expanded_encoder_out_lens, + token_ids=token_ids, + ) + assert nll.ndim == 2 + assert nll.shape[0] == len(token_ids) + + attention_scores = -nll.sum(dim=1) + + if attention_scale is None: + attention_scale_list = [0.01, 0.05, 0.08] + attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] + attention_scale_list += [5.0, 6.0, 7.0, 8.0, 9.0] + else: + attention_scale_list = [attention_scale] + + ans = dict() + + for a_scale in attention_scale_list: + tot_scores = scores.values + a_scale * attention_scores + ragged_tot_scores = k2.RaggedTensor(utt_to_path_shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(fsa, max_indexes) + + key = f"attention_scale_{a_scale}" + ans[key] = best_path + return ans + + def rescore_with_rnn_lm( lattice: k2.Fsa, num_paths: int,
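As a wrap-up, a schematic of how the new `rescore_with_attention_decoder_no_ngram` helper is driven from a CTC lattice, mirroring the flow in `pretrained_ctc.py` above; every argument here is a placeholder for an object built in that script, and the real scripts derive `supervision_segments` from the actual supervisions:

```python
import torch
from icefall.decode import get_lattice, rescore_with_attention_decoder_no_ngram

def decode_with_aed_rescoring(ctc_output, H, encoder_out, encoder_out_lens,
                              attention_decoder, params):
    # Schematic only: ctc_output is the CTC log-probs (N, T, C), H the CTC
    # topology, params carries the usual icefall decoding options.
    supervision_segments = torch.tensor(
        [[i, 0, int(encoder_out_lens[i])] for i in range(ctc_output.size(0))],
        dtype=torch.int32,
    )
    lattice = get_lattice(
        nnet_output=ctc_output,
        decoding_graph=H,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
        subsampling_factor=params.subsampling_factor,
    )
    best_path_dict = rescore_with_attention_decoder_no_ngram(
        lattice=lattice,
        num_paths=params.num_paths,
        attention_decoder=attention_decoder,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        nbest_scale=params.nbest_scale,
    )
    # One best path per attention_scale value; take the first entry,
    # as pretrained_ctc.py does.
    return next(iter(best_path_dict.values()))
```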