diff --git a/e3nn_jax/_src/irreps_array.py b/e3nn_jax/_src/irreps_array.py index 66c2fc4..158e62a 100644 --- a/e3nn_jax/_src/irreps_array.py +++ b/e3nn_jax/_src/irreps_array.py @@ -1118,6 +1118,7 @@ def rechunk(self, irreps: IntoIrreps) -> "IrrepsArray": new_chunks = None if self._chunks is not None: + jnp = _infer_backend(self.array) leading_shape = self.shape[:-1] new_chunks = []