From bc1b92165520c3f1e123a73f1e748addcadf0b60 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 13 Sep 2023 11:42:22 +0200 Subject: [PATCH] fix --- e3nn_jax/_src/irreps_array.py | 1 + 1 file changed, 1 insertion(+) diff --git a/e3nn_jax/_src/irreps_array.py b/e3nn_jax/_src/irreps_array.py index 66c2fc4..158e62a 100644 --- a/e3nn_jax/_src/irreps_array.py +++ b/e3nn_jax/_src/irreps_array.py @@ -1118,6 +1118,7 @@ def rechunk(self, irreps: IntoIrreps) -> "IrrepsArray": new_chunks = None if self._chunks is not None: + jnp = _infer_backend(self.array) leading_shape = self.shape[:-1] new_chunks = []