From a67a1bc3c02b5264b0d72f3cb084caf5931741f3 Mon Sep 17 00:00:00 2001 From: Arturo Date: Fri, 16 Aug 2024 10:12:02 -0400 Subject: [PATCH] [MINOR] build error in github --- pymlg/jax/base.py | 1 + tests/jax/test_standard_jax.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pymlg/jax/base.py b/pymlg/jax/base.py index 4f78452..6fbca07 100644 --- a/pymlg/jax/base.py +++ b/pymlg/jax/base.py @@ -1,4 +1,5 @@ import jax.numpy as jnp +from jax import numpy as jnp from jax import random, core import numpy as onp diff --git a/tests/jax/test_standard_jax.py b/tests/jax/test_standard_jax.py index c8474a2..2eeca95 100644 --- a/tests/jax/test_standard_jax.py +++ b/tests/jax/test_standard_jax.py @@ -1,6 +1,7 @@ -from jax.config import config -config.update("jax_enable_x64", True) +#from jax.config import config +from jax import config +config.config.update("jax_enable_x64", True) from pymlg.jax import SO3, SE3, SE23, SL3, SO2, SE2 import pytest