From 686d2d9787ab92f62abbefe1bb1119d48db2cafe Mon Sep 17 00:00:00 2001 From: marcoyang Date: Fri, 29 Mar 2024 19:08:21 +0800 Subject: [PATCH] minor updates --- egs/audioset/AT/zipformer/export-onnx.py | 12 ++++++------ egs/audioset/AT/zipformer/model.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/egs/audioset/AT/zipformer/export-onnx.py b/egs/audioset/AT/zipformer/export-onnx.py index 25bafc8771..5fc98f8b69 100755 --- a/egs/audioset/AT/zipformer/export-onnx.py +++ b/egs/audioset/AT/zipformer/export-onnx.py @@ -180,7 +180,7 @@ def forward( self, x: torch.Tensor, x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: """Please see the help information of Zipformer.forward Args: @@ -206,7 +206,7 @@ def forward( logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as( logits ) # normalize the logits - + print(logits.shape) return logits @@ -234,10 +234,10 @@ def export_audio_tagging_model_onnx( opset_version: The opset version to use. """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) + x = torch.zeros(1, 200, 80, dtype=torch.float32) + x_lens = torch.tensor([200], dtype=torch.int64) - model = torch.jit.trace(model, (x, x_lens)) + model = torch.jit.script(model) torch.onnx.export( model, @@ -250,7 +250,7 @@ def export_audio_tagging_model_onnx( dynamic_axes={ "x": {0: "N", 1: "T"}, "x_lens": {0: "N"}, - # "logits": {0: "N", 1: "T"}, + "logits": {0: "N"}, }, ) diff --git a/egs/audioset/AT/zipformer/model.py b/egs/audioset/AT/zipformer/model.py index 7661ab4b67..f189eac622 100644 --- a/egs/audioset/AT/zipformer/model.py +++ b/egs/audioset/AT/zipformer/model.py @@ -144,7 +144,7 @@ def forward_audio_tagging(self, encoder_out, encoder_out_lens): before padding. Returns: - A 3-D tensor of shape (N, T, num_classes). + A 2-D tensor of shape (N, num_classes). 
""" logits = self.classifier(encoder_out) # (N, T, num_classes) padding_mask = make_pad_mask(encoder_out_lens)