Skip to content

Commit

Permalink
minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoyang1998 committed Mar 29, 2024
1 parent 7bd679f commit 686d2d9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions egs/audioset/AT/zipformer/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
"""Please see the help information of Zipformer.forward
Args:
Expand All @@ -206,7 +206,7 @@ def forward(
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
logits
) # normalize the logits

print(logits.shape)
return logits


Expand Down Expand Up @@ -234,10 +234,10 @@ def export_audio_tagging_model_onnx(
opset_version:
The opset version to use.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
x = torch.zeros(1, 200, 80, dtype=torch.float32)
x_lens = torch.tensor([200], dtype=torch.int64)

model = torch.jit.trace(model, (x, x_lens))
model = torch.jit.script(model)

torch.onnx.export(
model,
Expand All @@ -250,7 +250,7 @@ def export_audio_tagging_model_onnx(
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
# "logits": {0: "N", 1: "T"},
"logits": {0: "N"},
},
)

Expand Down
2 changes: 1 addition & 1 deletion egs/audioset/AT/zipformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def forward_audio_tagging(self, encoder_out, encoder_out_lens):
before padding.
Returns:
A 3-D tensor of shape (N, T, num_classes).
A 2-D tensor of shape (N, num_classes).
"""
logits = self.classifier(encoder_out) # (N, T, num_classes)
padding_mask = make_pad_mask(encoder_out_lens)
Expand Down

0 comments on commit 686d2d9

Please sign in to comment.