From 6779ef23d3536ef5f53fc3a9fa4ee79f58fa6161 Mon Sep 17 00:00:00 2001 From: Eric Giguere Date: Thu, 11 Jan 2024 12:54:57 -0500 Subject: [PATCH 1/3] Update for jax 0.4.23 and dims PR --- src/qutip_jax/jaxarray.py | 3 +-- src/qutip_jax/qutip_trees.py | 4 +--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/qutip_jax/jaxarray.py b/src/qutip_jax/jaxarray.py index 20e4dab..5845bdc 100644 --- a/src/qutip_jax/jaxarray.py +++ b/src/qutip_jax/jaxarray.py @@ -1,6 +1,5 @@ import jax.numpy as jnp -from jax import tree_util -from jax.config import config +from jax import tree_util, config import numbers import numpy as np diff --git a/src/qutip_jax/qutip_trees.py b/src/qutip_jax/qutip_trees.py index 91cb04f..ae06a11 100644 --- a/src/qutip_jax/qutip_trees.py +++ b/src/qutip_jax/qutip_trees.py @@ -9,9 +9,7 @@ def qobj_tree_flatten(qobj): children = (qobj.to("jax").data,) aux_data = { - "dims": qobj.dims, - "type": qobj.type, - "superrep": qobj.superrep, + "_dims": qobj._dims, # Attribute that depend on the data are not safe to be set. "_isherm": None, "_isunitary": None, From b91e04d9d56575bef9b788bca20de72a97a3c8e2 Mon Sep 17 00:00:00 2001 From: Eric Giguere Date: Thu, 11 Jan 2024 13:34:56 -0500 Subject: [PATCH 2/3] Filter diffrax warnings --- src/qutip_jax/ode.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/qutip_jax/ode.py b/src/qutip_jax/ode.py index 10a7448..0b9ccdf 100644 --- a/src/qutip_jax/ode.py +++ b/src/qutip_jax/ode.py @@ -1,4 +1,5 @@ import diffrax +import warnings from qutip.solver.integrator import Integrator import jax import jax.numpy as jnp @@ -64,16 +65,22 @@ def get_state(self, copy=False): return self.t, JaxArray(_float2cplx(self.state)) def integrate(self, t, copy=False, **kwargs): - sol = diffrax.diffeqsolve( - self.ODEsystem, - t0=self.t, - t1=t, - y0=self.state, - saveat=diffrax.SaveAt(t1=True, solver_state=True), - solver_state=self.solver_state, - args=(self.system, kwargs), - **self._options, - ) + with warnings.catch_warnings(): + # Diffrax added partial support for complex number, but raise a + # warning when it find a complex anywhere in the tree. + warnings.filterwarnings("ignore", + message="Complex dtype support is work in progress," + ) + sol = diffrax.diffeqsolve( + self.ODEsystem, + t0=self.t, + t1=t, + y0=self.state, + saveat=diffrax.SaveAt(t1=True, solver_state=True), + solver_state=self.solver_state, + args=(self.system, kwargs), + **self._options, + ) self.t = t self.state = sol.ys[0, :] self.solver_state = sol.solver_state From efbf4a77beca9877a5295766f690c46ef619400a Mon Sep 17 00:00:00 2001 From: Eric Giguere Date: Thu, 11 Jan 2024 13:58:48 -0500 Subject: [PATCH 3/3] filter in tests --- .github/workflows/tests.yml | 2 +- .github/workflows/weekly.yml | 2 +- src/qutip_jax/ode.py | 27 ++++++++++----------------- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bf5426a..dc5c5ea 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -57,7 +57,7 @@ jobs: # to see if something hung. timeout-minutes: 60 run: | - pytest --durations=0 --durations-min=1.0 --verbosity=1 --cov=qutip_jax --cov-report= --color=yes -W ignore::UserWarning:qutip + pytest --durations=0 --durations-min=1.0 --verbosity=1 --cov=qutip_jax --cov-report= --color=yes -W ignore::UserWarning:qutip -W "ignore:Complex dtype:UserWarning" # Above flags are: # --durations=0 --durations-min=1.0 # at the end, show a list of all the tests that took longer than a diff --git a/.github/workflows/weekly.yml b/.github/workflows/weekly.yml index 3cd7ade..260ec2c 100644 --- a/.github/workflows/weekly.yml +++ b/.github/workflows/weekly.yml @@ -64,7 +64,7 @@ jobs: # to see if something hung. timeout-minutes: 60 run: | - pytest --durations=0 --durations-min=1.0 --verbosity=1 --color=yes -W ignore::UserWarning:qutip + pytest --durations=0 --durations-min=1.0 --verbosity=1 --color=yes -W ignore::UserWarning:qutip -W "ignore:Complex dtype:UserWarning" # Above flags are: # --durations=0 --durations-min=1.0 # at the end, show a list of all the tests that took longer than a diff --git a/src/qutip_jax/ode.py b/src/qutip_jax/ode.py index 0b9ccdf..10a7448 100644 --- a/src/qutip_jax/ode.py +++ b/src/qutip_jax/ode.py @@ -1,5 +1,4 @@ import diffrax -import warnings from qutip.solver.integrator import Integrator import jax import jax.numpy as jnp @@ -65,22 +64,16 @@ def get_state(self, copy=False): return self.t, JaxArray(_float2cplx(self.state)) def integrate(self, t, copy=False, **kwargs): - with warnings.catch_warnings(): - # Diffrax added partial support for complex number, but raise a - # warning when it find a complex anywhere in the tree. - warnings.filterwarnings("ignore", - message="Complex dtype support is work in progress," - ) - sol = diffrax.diffeqsolve( - self.ODEsystem, - t0=self.t, - t1=t, - y0=self.state, - saveat=diffrax.SaveAt(t1=True, solver_state=True), - solver_state=self.solver_state, - args=(self.system, kwargs), - **self._options, - ) + sol = diffrax.diffeqsolve( + self.ODEsystem, + t0=self.t, + t1=t, + y0=self.state, + saveat=diffrax.SaveAt(t1=True, solver_state=True), + solver_state=self.solver_state, + args=(self.system, kwargs), + **self._options, + ) self.t = t self.state = sol.ys[0, :] self.solver_state = sol.solver_state