diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py index 41f63773fb..278dc53efe 100644 --- a/source/tests/consistent/test_type_embedding.py +++ b/source/tests/consistent/test_type_embedding.py @@ -40,8 +40,6 @@ else: TypeEmbedNetJAX = object if INSTALLED_ARRAY_API_STRICT: - import array_api_strict - from ..array_api_strict.utils.type_embed import TypeEmbedNet as TypeEmbedNetStrict else: TypeEmbedNetStrict = None @@ -133,8 +131,7 @@ def eval_jax(self, jax_obj: Any) -> Any: def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: out = array_api_strict_obj() return [ - np.asarray(x) if hasattr(x, "__array_namespace__") else x - for x in (out,) + np.asarray(x) if hasattr(x, "__array_namespace__") else x for x in (out,) ] def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]: