Skip to content

Commit

Permalink
test should pass
Browse files Browse the repository at this point in the history
  • Loading branch information
rflamary committed Oct 30, 2024
1 parent ce6d2ea commit 108f9c8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
7 changes: 6 additions & 1 deletion ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,7 +1890,12 @@ def _to_numpy(self, a):
return a.cpu().detach().numpy()

def _from_numpy(self, a, type_as=None):
if isinstance(a, float) or isinstance(a, int):
if (
isinstance(a, float)
or isinstance(a, int)
or isinstance(a, np.float32)
or isinstance(a, np.float64)
):
a = np.array(a)
if type_as is None:
return torch.from_numpy(a)
Expand Down
4 changes: 2 additions & 2 deletions test/gromov/test_lowrank.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def test__flat_product_operator():
X = np.reshape(1.0 * np.arange(2 * n), (n, d))
A1, A2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X, X, rescale_cost=False)

A1_ = ot.gromov._flat_product_operator(A1)
A2_ = ot.gromov._flat_product_operator(A2)
A1_ = ot.gromov._lowrank._flat_product_operator(A1)
A2_ = ot.gromov._lowrank._flat_product_operator(A2)
cost = ot.dist(X, X)

# test value
Expand Down

0 comments on commit 108f9c8

Please sign in to comment.