diff --git a/source/tests/consistent/test_activation.py b/source/tests/consistent/test_activation.py index 87589b5da4..b23cf7e066 100644 --- a/source/tests/consistent/test_activation.py +++ b/source/tests/consistent/test_activation.py @@ -72,7 +72,7 @@ def test_arary_api_strict(self): input = xp.asarray(self.random_input) test = get_activation_fn_dp(self.activation)(input) - np.testing.assert_allclose(self.ref, np.array(test), atol=1e-10) + np.testing.assert_allclose(self.ref, to_numpy_array(test), atol=1e-10) @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") def test_jax_consistent_with_ref(self):