Fix exporting streaming zipformer models. (#1755)
csukuangfj authored Sep 11, 2024
1 parent 329e34a commit 6f1abd8
Showing 2 changed files with 34 additions and 9 deletions.
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -74,7 +74,6 @@
 import torch
 import torch.nn as nn
 from decoder import Decoder
-from onnxconverter_common import float16
 from onnxruntime.quantization import QuantType, quantize_dynamic
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_model, get_params
@@ -756,6 +755,7 @@ def main():
     logging.info(f"Exported joiner to {joiner_filename}")

     if params.fp16:
+        from onnxconverter_common import float16
         logging.info("Generate fp16 models")

         encoder = onnx.load(encoder_filename)
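With this change, onnxconverter_common becomes an optional dependency: it is imported only when fp16 export is requested via params.fp16, so plain fp32 export no longer requires the package. For reference, a minimal sketch of the conversion this branch performs (the filenames below are illustrative, not the script's exact variables):

    import onnx
    from onnxconverter_common import float16  # only needed for fp16 export

    # Convert an exported fp32 ONNX model to float16; keep_io_types=True
    # leaves the model's inputs/outputs in fp32 for drop-in compatibility.
    model = onnx.load("encoder.onnx")  # hypothetical filename
    model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
    onnx.save(model_fp16, "encoder.fp16.onnx")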
41 changes: 33 additions & 8 deletions egs/librispeech/ASR/zipformer/zipformer.py
@@ -191,14 +191,18 @@ def _to_tuple(x):
                     dim=encoder_dim[i],
                     downsample=downsampling_factor[i],
                     dropout=dropout,
+                    causal=causal,
                 )

             encoders.append(encoder)

         self.encoders = nn.ModuleList(encoders)

         self.downsample_output = SimpleDownsample(
-            max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout
+            max(encoder_dim),
+            downsample=output_downsampling_factor,
+            dropout=dropout,
+            causal=causal,
         )

     def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
@@ -1217,11 +1221,16 @@ class DownsampledZipformer2Encoder(nn.Module):
     """

     def __init__(
-        self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike
+        self,
+        encoder: nn.Module,
+        dim: int,
+        downsample: int,
+        dropout: FloatLike,
+        causal: bool,
     ):
         super(DownsampledZipformer2Encoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsample = SimpleDownsample(dim, downsample, dropout)
+        self.downsample = SimpleDownsample(dim, downsample, dropout, causal)
         self.num_layers = encoder.num_layers
         self.encoder = encoder
         self.upsample = SimpleUpsample(dim, downsample)
@@ -1310,9 +1319,12 @@ class SimpleDownsample(torch.nn.Module):
     Does downsampling with attention, by weighted sum, and a projection.
     """

-    def __init__(self, channels: int, downsample: int, dropout: FloatLike):
+    def __init__(
+        self, channels: int, downsample: int, dropout: FloatLike, causal: bool
+    ):
         super(SimpleDownsample, self).__init__()

+        self.causal = causal
         self.bias = nn.Parameter(torch.zeros(downsample))

         self.name = None  # will be set from training code
@@ -1333,9 +1345,18 @@ def forward(self, src: Tensor) -> Tensor:
         # Pad to an exact multiple of self.downsample
         # right-pad src, repeating the last element.
         pad = d_seq_len * ds - seq_len
-        src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
-        src = torch.cat((src, src_extra), dim=0)
-        assert src.shape[0] == d_seq_len * ds
+
+        if self.causal and torch.jit.is_tracing():
+            assert (
+                pad == 0
+            ), f"pad should be zero for exporting streaming models. Given {pad}"
+
+        # When exporting a streaming model, pad is 0, so the padding branch
+        # below is skipped at trace time.
+        if not self.causal or not torch.jit.is_tracing():
+            src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
+            src = torch.cat((src, src_extra), dim=0)
+
+        assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds)

         src = src.reshape(d_seq_len, ds, batch_size, in_channels)

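Why the new guard works: torch.jit.trace records one concrete execution, so the data-dependent padding (src_extra has length pad) would otherwise be frozen into the exported streaming graph. When exporting a causal model, the chunk length reaching SimpleDownsample is a multiple of the downsample factor, which is exactly what the new assert enforces (pad == 0), so the padding branch can be skipped entirely at trace time. A standalone sketch of the shape arithmetic, assuming d_seq_len is the ceiling of seq_len / ds (consistent with the pad computation above) and using a plain mean in place of the module's learned weighted sum:

    import torch

    ds = 2        # downsample factor
    seq_len = 16  # streaming chunk length; a multiple of ds, so pad == 0
    batch, channels = 1, 4

    d_seq_len = (seq_len + ds - 1) // ds  # ceil(seq_len / ds) -> 8
    pad = d_seq_len * ds - seq_len        # 0 when seq_len % ds == 0

    src = torch.randn(seq_len, batch, channels)
    if pad > 0:  # never taken on the streaming-export path
        src_extra = src[-1:].expand(pad, batch, channels)
        src = torch.cat((src, src_extra), dim=0)

    # Group every ds consecutive frames, then reduce each group to one frame;
    # a plain mean stands in for the learned weighted sum.
    src = src.reshape(d_seq_len, ds, batch, channels)
    out = src.mean(dim=1)
    print(out.shape)  # torch.Size([8, 1, 4])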
@@ -1609,7 +1630,11 @@ def forward(
         k = x[..., query_dim : 2 * query_dim]
         # p is the position-encoding query
         p = x[..., 2 * query_dim :]
-        assert p.shape[-1] == num_heads * pos_head_dim, (p.shape[-1], num_heads, pos_head_dim)
+        assert p.shape[-1] == num_heads * pos_head_dim, (
+            p.shape[-1],
+            num_heads,
+            pos_head_dim,
+        )

         q = self.copy_query(q)  # for diagnostics only, does nothing.
         k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
