From 108f9c89e60309ee30e9493776748cbaaddc6793 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 30 Oct 2024 09:03:43 +0100 Subject: [PATCH] test should pass --- ot/backend.py | 7 ++++++- test/gromov/test_lowrank.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index f14d1a81c..a99639445 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1890,7 +1890,12 @@ def _to_numpy(self, a): return a.cpu().detach().numpy() def _from_numpy(self, a, type_as=None): - if isinstance(a, float) or isinstance(a, int): + if ( + isinstance(a, float) + or isinstance(a, int) + or isinstance(a, np.float32) + or isinstance(a, np.float64) + ): a = np.array(a) if type_as is None: return torch.from_numpy(a) diff --git a/test/gromov/test_lowrank.py b/test/gromov/test_lowrank.py index 312ce749d..27e0fcdb0 100644 --- a/test/gromov/test_lowrank.py +++ b/test/gromov/test_lowrank.py @@ -15,8 +15,8 @@ def test__flat_product_operator(): X = np.reshape(1.0 * np.arange(2 * n), (n, d)) A1, A2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X, X, rescale_cost=False) - A1_ = ot.gromov._flat_product_operator(A1) - A2_ = ot.gromov._flat_product_operator(A2) + A1_ = ot.gromov._lowrank._flat_product_operator(A1) + A2_ = ot.gromov._lowrank._flat_product_operator(A2) cost = ot.dist(X, X) # test value