Skip to content

Commit

Permalink
Fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Sep 26, 2023
1 parent 8e16c7d commit 559b982
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
4 changes: 3 additions & 1 deletion k2/python/k2/rnnt_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -1451,7 +1451,9 @@ def get_rnnt_logprobs_smoothed(
unigram_lm.expand(B, S, C), dim=2, index=symbols.unsqueeze(-1)
) # [B][S][1]

px = px_am + px_lm # [B][S][T+1] if rnnt_type == "regular", otherwise [B][S][T]
px = (
px_am + px_lm
) # [B][S][T+1] if rnnt_type == "regular", otherwise [B][S][T]
px[:, :, :T] -= normalizers[:, :S, :] # px: [B][S][T+1] or [B][S][T]

px_amonly = (
Expand Down
7 changes: 3 additions & 4 deletions k2/python/tests/rnnt_loss_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ def test_rnnt_loss_random(self):
)
assert (
px.shape == (B, S, T)
if rnnt_type != "regular" else (B, S, T + 1)
if rnnt_type != "regular"
else (B, S, T + 1)
)
assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S)
Expand Down Expand Up @@ -484,7 +485,7 @@ def test_rnnt_loss_smoothed(self):
assert torch.allclose(m, expected.to(device))

def test_rnnt_loss_pruned(self):
print (f"\ntest_rnnt_loss_pruned.")
print(f"\ntest_rnnt_loss_pruned.")
B = 4
T = 300
S = 50
Expand Down Expand Up @@ -758,7 +759,6 @@ def test_prune_ranges(self):

print(f"Pruned with old ranges {r} : {loss}")


# Test low s_range values with large S and small T,
# at this circumstance, the s_range would not be enough
# to cover the whole sequence length (in regular rnnt mode)
Expand Down Expand Up @@ -858,7 +858,6 @@ def test_rnnt_loss_pruned_small_s_range(self):

# Check that training with an empty reference does not cause a crash.
def _test_rnnt_loss_empty_reference(self):

B = 1
S = 0
T = 4
Expand Down

0 comments on commit 559b982

Please sign in to comment.