Skip to content

Commit

Permalink
Fix silently failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Sep 15, 2023
1 parent 1154071 commit 3838f78
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyro/distributions/one_one_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,6 @@ def maximum_weight_matching(logits):
from scipy.optimize import linear_sum_assignment

cost = -logits.cpu()
value = linear_sum_assignment(cost.numpy())[0]
value = linear_sum_assignment(cost.numpy())[1]
value = torch.tensor(value, dtype=torch.long, device=logits.device)
return value
2 changes: 1 addition & 1 deletion pyro/distributions/one_two_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def maximum_weight_matching(logits):

cost = -logits.cpu()
cost = torch.cat([cost, cost], dim=-1) # Duplicate destinations.
value = linear_sum_assignment(cost.numpy())[0]
value = linear_sum_assignment(cost.numpy())[1]
value = torch.tensor(value, dtype=torch.long, device=logits.device)
value %= logits.size(1)
return value
4 changes: 3 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,14 @@ def assert_tensors_equal(a, b, prec=0.0, msg=""):
assert a.size() == b.size(), msg
if isinstance(prec, numbers.Number) and prec == 0:
assert (a == b).all(), msg
return
if a.numel() == 0 and b.numel() == 0:
return
b = b.type_as(a)
b = b.cuda(device=a.get_device()) if a.is_cuda else b.cpu()
if not a.dtype.is_floating_point:
return (a == b).all()
assert (a == b).all(), msg
return
# check that NaNs are in the same locations
nan_mask = a != a
assert torch.equal(nan_mask, b != b), msg
Expand Down
1 change: 1 addition & 0 deletions tests/distributions/test_one_one_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def test_mode(num_nodes, dtype):
expected = values[i]
actual = d.mode()
assert_equal(actual, expected)
assert (actual == expected).all()


@pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str)
Expand Down

0 comments on commit 3838f78

Please sign in to comment.