From 9d8789303e7f0eeb68908ff05e512b9f14ff74bd Mon Sep 17 00:00:00 2001 From: Eric Giguere Date: Wed, 21 Jun 2023 16:55:59 -0400 Subject: [PATCH 1/4] jaxarray support alternative dtypes --- src/qutip_jax/create.py | 18 +++++++++--------- src/qutip_jax/jaxarray.py | 12 ++++++------ src/qutip_jax/qobjevo.py | 15 +++++++++++++-- tests/test_jaxarray.py | 10 +++++++++- tests/test_ode.py | 21 +++++++++++++++++++-- 5 files changed, 56 insertions(+), 20 deletions(-) diff --git a/src/qutip_jax/create.py b/src/qutip_jax/create.py index bbddba7..019713d 100644 --- a/src/qutip_jax/create.py +++ b/src/qutip_jax/create.py @@ -16,7 +16,7 @@ ] -def zeros_jaxarray(rows, cols): +def zeros_jaxarray(rows, cols, *, dtype=jnp.complex128): """ Creates a matrix representation of zeros with the given dimensions. @@ -25,10 +25,10 @@ def zeros_jaxarray(rows, cols): rows, cols : int The number of rows and columns in the output matrix. """ - return JaxArray(jnp.zeros((rows, cols), dtype=jnp.complex128)) + return JaxArray(jnp.zeros((rows, cols), dtype=dtype)) -def identity_jaxarray(dimensions, scale=None): +def identity_jaxarray(dimensions, scale=None, *, dtype=jnp.complex128): """ Creates a square identity matrix of the given dimension. @@ -43,11 +43,11 @@ def identity_jaxarray(dimensions, scale=None): The element which should be placed on the diagonal. """ if scale is None: - return JaxArray(jnp.eye(dimensions, dtype=jnp.complex128)) - return JaxArray(jnp.eye(dimensions, dtype=jnp.complex128) * scale) + return JaxArray(jnp.eye(dimensions, dtype=dtype)) + return JaxArray(jnp.eye(dimensions, dtype=dtype) * scale) -def diag_jaxarray(diagonals, offsets=None, shape=None): +def diag_jaxarray(diagonals, offsets=None, shape=None, *, dtype=jnp.complex128): """ Constructs a matrix from diagonals and their offsets. @@ -108,7 +108,7 @@ def diag_jaxarray(diagonals, offsets=None, shape=None): if n_rows == n_cols: # jax diag only create square matrix - out = jnp.zeros((n_rows, n_cols), dtype=jnp.complex128) + out = jnp.zeros((n_rows, n_cols), dtype=dtype) for offset, diag in zip(offsets, diagonals): out += jnp.diag(jnp.array(diag), offset) out = JaxArray(out) @@ -118,7 +118,7 @@ def diag_jaxarray(diagonals, offsets=None, shape=None): return out -def one_element_jaxarray(shape, position, value=None): +def one_element_jaxarray(shape, position, value=None, *, dtype=jnp.complex128): """ Creates a matrix with only one nonzero element. @@ -141,7 +141,7 @@ def one_element_jaxarray(shape, position, value=None): + str(shape) ) value = value or 1 - out = jnp.zeros(shape, dtype=jnp.complex128) + out = jnp.zeros(shape, dtype=dtype) return JaxArray(out.at[position].set(value)) diff --git a/src/qutip_jax/jaxarray.py b/src/qutip_jax/jaxarray.py index 4451216..4acfc78 100644 --- a/src/qutip_jax/jaxarray.py +++ b/src/qutip_jax/jaxarray.py @@ -18,8 +18,8 @@ class JaxArray(Data): _jxa: jnp.ndarray shape: tuple - def __init__(self, data, shape=None, copy=None): - jxa = jnp.array(data, dtype=jnp.complex128) + def __init__(self, data, shape=None, copy=None, *, dtype=None): + jxa = jnp.array(data, dtype=dtype) if shape is None: shape = data.shape @@ -46,19 +46,19 @@ def __init__(self, data, shape=None, copy=None): Data.__init__(self, shape) def copy(self): - return self.__class__(self._jxa, copy=True) + return JaxArray._fast_constructor(self._jxa.copy(), shape=self.shape) def to_array(self): return np.array(self._jxa) def conj(self): - return self.__class__(self._jxa.conj()) + return JaxArray._fast_constructor(self._jxa.conj(), shape=self.shape) def transpose(self): - return self.__class__(self._jxa.T) + return JaxArray._fast_constructor(self._jxa.T, shape=self.shape[::-1]) def adjoint(self): - return self.__class__(self._jxa.T.conj()) + return JaxArray._fast_constructor(self._jxa.T.conj(), shape=self.shape[::-1]) def trace(self): return jnp.trace(self._jxa) diff --git a/src/qutip_jax/qobjevo.py b/src/qutip_jax/qobjevo.py index df3a406..3fa1f85 100644 --- a/src/qutip_jax/qobjevo.py +++ b/src/qutip_jax/qobjevo.py @@ -97,6 +97,7 @@ class JaxQobjEvo(eqx.Module): batched_data: jnp.ndarray coeffs: list dims: object = eqx.static_field() + dtype: jnp.dtype def __init__(self, qobjevo): as_list = qobjevo.to_list() @@ -106,26 +107,36 @@ def __init__(self, qobjevo): constant = JaxJitCoeff(eqx.filter_jit(lambda t, **_: 1.)) + dtype = None + for part in as_list: if isinstance(part, Qobj): qobjs.append(part) self.coeffs.append(constant) + if isinstance(part.data, JaxArray): + dtype = jnp.promote_types(dtype, part.data._jxa.dtype) elif ( isinstance(part, list) and isinstance(part[0], Qobj) ): qobjs.append(part[0]) self.coeffs.append(part[1]) + if isinstance(part[0], JaxArray): + dtype = jnp.promote_types(dtype, part[0].data._jxa.dtype) else: # TODO: raise NotImplementedError( "Function based QobjEvo are not supported" ) + if dtype is None: + dtype=jnp.complex128 + self.dtype = dtype + if qobjs: shape = qobjs[0].shape self.batched_data = jnp.zeros( - shape + (len(qobjs),), dtype=np.complex128 + shape + (len(qobjs),), dtype=dtype ) for i, qobj in enumerate(qobjs): self.batched_data = self.batched_data.at[:, :, i].set( @@ -135,7 +146,7 @@ def __init__(self, qobjevo): @eqx.filter_jit def _coeff(self, t, **args): list_coeffs = [coeff(t, **args) for coeff in self.coeffs] - return jnp.array(list_coeffs, dtype=np.complex128) + return jnp.array(list_coeffs, dtype=self.batched_data.dtype) def __call__(self, t, **kwargs): return Qobj(self.data(t, **kwargs), dims=self.dims) diff --git a/tests/test_jaxarray.py b/tests/test_jaxarray.py index fd654cf..ef128c5 100644 --- a/tests/test_jaxarray.py +++ b/tests/test_jaxarray.py @@ -23,7 +23,6 @@ def test_init(backend, shape, dtype): array = backend.array(array) jax_a = JaxArray(array) assert isinstance(jax_a, JaxArray) - assert jax_a._jxa.dtype == jax.numpy.complex128 if len(shape) == 1: shape = shape + (1,) assert jax_a.shape == shape @@ -91,3 +90,12 @@ def test_convert(): sx = qutip.qeye(5, dtype="JaxArray") assert isinstance(sx.data, JaxArray) + + +def test_alternative_dtype(): + ones = jnp.ones((3, 3)) + real_array = JaxArray(ones, dtype=jnp.float64) + cplx_array = JaxArray(ones*1j, dtype=jnp.complex64) + assert (real_array * 5.)._jxa.dtype == jnp.float64 + assert (cplx_array + cplx_array)._jxa.dtype == jnp.complex64 + assert (cplx_array @ real_array)._jxa.dtype == jnp.complex128 diff --git a/tests/test_ode.py b/tests/test_ode.py index a959b10..5f40114 100644 --- a/tests/test_ode.py +++ b/tests/test_ode.py @@ -1,9 +1,11 @@ from qutip import ( - coefficient, num, destroy, create, sesolve, MESolver, basis, settings, QobjEvo + coefficient, num, destroy, create, sesolve, MESolver, basis, settings, QobjEvo, Qobj ) import qutip_jax +from qutip_jax.qobjevo import JaxQobjEvo import pytest import jax +import jax.numpy as jnp import numpy as np settings.core["default_dtype"] = "jax" @@ -67,7 +69,7 @@ def test_ode_step(): assert (solver.step(1) - ref_solver.step(1)).norm() <= 1e-6 - +import jax def test_ode_grad(): H = num(10) c_ops = [QobjEvo([destroy(10), cte], args={"A": 1.0})] @@ -86,3 +88,18 @@ def f(solver, t, A): assert val == pytest.approx(9 * np.exp(- 0.2 * 0.5)) assert dt == pytest.approx(9 * np.exp(- 0.2 * 0.5) * -0.5) assert dA == pytest.approx(9 * np.exp(- 0.2 * 0.5) * -0.2) + + +def test_non_cplx128_JaxQobjEvo(): + op1 = Qobj(qutip_jax.zeros_jaxarray(3, 3, dtype=jnp.float64)) + op2 = Qobj( + qutip_jax.one_element_jaxarray((3, 3), (0, 0), dtype=jnp.float64) + ) + op3 = Qobj(qutip_jax.identity_jaxarray(3, dtype=jnp.float64)) + qevo = QobjEvo( + [op1, [op2, pulse], [op3, cte]], + args={"A":1.0, "u":0.1, "sigma":0.5} + ) + jqevo = JaxQobjEvo(qevo) + assert jqevo.dtype == jnp.float64 + assert jqevo.batched_data.dtype == jnp.float64 From 4e4ba28e7e08c96081a8bd69dd26329fbdf685a8 Mon Sep 17 00:00:00 2001 From: Eric Giguere Date: Wed, 21 Jun 2023 16:58:05 -0400 Subject: [PATCH 2/4] JaxQobjEvo support other dtype --- src/qutip_jax/qobjevo.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/qutip_jax/qobjevo.py b/src/qutip_jax/qobjevo.py index 3fa1f85..d299241 100644 --- a/src/qutip_jax/qobjevo.py +++ b/src/qutip_jax/qobjevo.py @@ -97,7 +97,6 @@ class JaxQobjEvo(eqx.Module): batched_data: jnp.ndarray coeffs: list dims: object = eqx.static_field() - dtype: jnp.dtype def __init__(self, qobjevo): as_list = qobjevo.to_list() @@ -131,7 +130,6 @@ def __init__(self, qobjevo): if dtype is None: dtype=jnp.complex128 - self.dtype = dtype if qobjs: shape = qobjs[0].shape From 0cf90ea21293a573657b77c55ab24db487c21645 Mon Sep 17 00:00:00 2001 From: Eric Giguere Date: Thu, 22 Jun 2023 13:51:43 -0400 Subject: [PATCH 3/4] Support other dtypes --- src/qutip_jax/binops.py | 6 +++--- src/qutip_jax/create.py | 10 +++++----- src/qutip_jax/jaxarray.py | 6 ++++-- src/qutip_jax/linalg.py | 6 +++--- src/qutip_jax/ode.py | 12 ++++++++---- src/qutip_jax/permute.py | 2 +- src/qutip_jax/qobjevo.py | 24 +++++++++++++++++------- src/qutip_jax/reshape.py | 4 ++-- src/qutip_jax/unary.py | 6 +++--- tests/conftest.py | 2 +- tests/test_jaxarray.py | 5 ++++- tests/test_ode.py | 26 ++++++++++++++++++++++++-- 12 files changed, 75 insertions(+), 34 deletions(-) diff --git a/src/qutip_jax/binops.py b/src/qutip_jax/binops.py index f9d762d..cf9ea56 100644 --- a/src/qutip_jax/binops.py +++ b/src/qutip_jax/binops.py @@ -74,7 +74,7 @@ def mul_jaxarray(matrix, value): # We don't want to check values type in case jax pass a tracer etc. # But we want to ensure the output is a matrix, thus don't use the # fast constructor. - return JaxArray(matrix._jxa * value) + return JaxArray._fast_constructor(matrix._jxa * value, shape=matrix.shape) def matmul_jaxarray(left, right, scale=1, out=None): @@ -119,7 +119,7 @@ def kron_jaxarray(left, right): Compute the Kronecker product of two matrices. This is used to represent quantum tensor products of vector spaces. """ - return JaxArray(jnp.kron(left._jxa, right._jxa)) + return JaxArray._fast_constructor(jnp.kron(left._jxa, right._jxa)) def pow_jaxarray(matrix, n): @@ -138,7 +138,7 @@ def pow_jaxarray(matrix, n): """ if matrix.shape[0] != matrix.shape[1]: raise ValueError("matrix power only works with square matrices") - return JaxArray(jnp.linalg.matrix_power(matrix._jxa, n)) + return JaxArray._fast_constructor(jnp.linalg.matrix_power(matrix._jxa, n)) qutip.data.add.add_specialisations( diff --git a/src/qutip_jax/create.py b/src/qutip_jax/create.py index 019713d..0552c70 100644 --- a/src/qutip_jax/create.py +++ b/src/qutip_jax/create.py @@ -25,7 +25,7 @@ def zeros_jaxarray(rows, cols, *, dtype=jnp.complex128): rows, cols : int The number of rows and columns in the output matrix. """ - return JaxArray(jnp.zeros((rows, cols), dtype=dtype)) + return JaxArray._fast_constructor(jnp.zeros((rows, cols), dtype=dtype)) def identity_jaxarray(dimensions, scale=None, *, dtype=jnp.complex128): @@ -43,8 +43,8 @@ def identity_jaxarray(dimensions, scale=None, *, dtype=jnp.complex128): The element which should be placed on the diagonal. """ if scale is None: - return JaxArray(jnp.eye(dimensions, dtype=dtype)) - return JaxArray(jnp.eye(dimensions, dtype=dtype) * scale) + return JaxArray._fast_constructor(jnp.eye(dimensions, dtype=dtype)) + return JaxArray._fast_constructor(jnp.eye(dimensions, dtype=dtype) * scale) def diag_jaxarray(diagonals, offsets=None, shape=None, *, dtype=jnp.complex128): @@ -111,7 +111,7 @@ def diag_jaxarray(diagonals, offsets=None, shape=None, *, dtype=jnp.complex128): out = jnp.zeros((n_rows, n_cols), dtype=dtype) for offset, diag in zip(offsets, diagonals): out += jnp.diag(jnp.array(diag), offset) - out = JaxArray(out) + out = JaxArray._fast_constructor(out) else: out = jax_from_dense(qutip.core.data.dense.diags(diagonals, offsets, shape)) @@ -142,7 +142,7 @@ def one_element_jaxarray(shape, position, value=None, *, dtype=jnp.complex128): ) value = value or 1 out = jnp.zeros(shape, dtype=dtype) - return JaxArray(out.at[position].set(value)) + return JaxArray._fast_constructor(out.at[position].set(value)) qutip.data.zeros.add_specialisations( diff --git a/src/qutip_jax/jaxarray.py b/src/qutip_jax/jaxarray.py index 4acfc78..d6edeb8 100644 --- a/src/qutip_jax/jaxarray.py +++ b/src/qutip_jax/jaxarray.py @@ -18,7 +18,7 @@ class JaxArray(Data): _jxa: jnp.ndarray shape: tuple - def __init__(self, data, shape=None, copy=None, *, dtype=None): + def __init__(self, data, shape=None, copy=None, *, dtype=jnp.complex128): jxa = jnp.array(data, dtype=dtype) if shape is None: @@ -82,8 +82,10 @@ def __matmul__(self, other): return NotImplemented @classmethod - def _fast_constructor(cls, array, shape): + def _fast_constructor(cls, array, shape=None): out = cls.__new__(cls) + if shape is None: + shape = array.shape Data.__init__(out, shape) out._jxa = array return out diff --git a/src/qutip_jax/linalg.py b/src/qutip_jax/linalg.py index 8d909d4..a3339aa 100644 --- a/src/qutip_jax/linalg.py +++ b/src/qutip_jax/linalg.py @@ -65,7 +65,7 @@ def eigs_jaxarray(data, isherm=None, vecs=True, sort='low', eigvals=0): evals, evecs = _eigs_jaxarray(data._jxa, isherm, vecs, eigvals, low_first) - return (evals, JaxArray(evecs, copy=False)) if vecs else evals + return (evals, JaxArray._fast_constructor(evecs)) if vecs else evals qutip.data.eigs.add_specialisations( @@ -109,7 +109,7 @@ def svd_jaxarray(data, vecs=True, full_matrices=True, hermitian=False): ) if vecs: u, s, vh = out - return JaxArray(u, copy=False), s, JaxArray(vh, copy=False) + return JaxArray._fast_constructor(u), s, JaxArray._fast_constructor(vh) return out @@ -160,7 +160,7 @@ def solve_jaxarray(matrix: JaxArray, target: JaxArray, method=None, else: raise ValueError(f"Unknown solver {method}," " 'solve' and 'lstsq' are supported.") - return JaxArray(out, copy=False) + return JaxArray._fast_constructor(out) qutip.data.solve.add_specialisations( diff --git a/src/qutip_jax/ode.py b/src/qutip_jax/ode.py index 766753b..c01e478 100644 --- a/src/qutip_jax/ode.py +++ b/src/qutip_jax/ode.py @@ -14,12 +14,16 @@ @jax.jit def _cplx2float(arr): - return jnp.stack([arr.real, arr.imag]) + if jnp.iscomplexobj(arr): + return jnp.stack([arr.real, arr.imag]) + return arr @jax.jit def _float2cplx(arr): - return arr[0] + 1j * arr[1] + if arr.ndim == 3: + return arr[0] + 1j * arr[1] + return arr class DiffraxIntegrator(Integrator): @@ -49,7 +53,7 @@ def _prepare(self): def dstate(t, y, args): state = _float2cplx(y) H, kwargs = args - d_state = H.matmul_data(t, JaxArray(state), **kwargs) + d_state = H.matmul_data(t, JaxArray._fast_constructor(state), **kwargs) return _cplx2float(d_state._jxa) def set_state(self, t, state0): @@ -61,7 +65,7 @@ def set_state(self, t, state0): self._is_set = True def get_state(self, copy=False): - return self.t, JaxArray(_float2cplx(self.state)) + return self.t, JaxArray._fast_constructor(_float2cplx(self.state)) def integrate(self, t, copy=False, **kwargs): sol = diffrax.diffeqsolve( diff --git a/src/qutip_jax/permute.py b/src/qutip_jax/permute.py index 6c053d3..bf37826 100644 --- a/src/qutip_jax/permute.py +++ b/src/qutip_jax/permute.py @@ -12,7 +12,7 @@ def indices_jaxarray(matrix, row_perm=None, col_perm=None): data = data[np.argsort(row_perm), :] if col_perm is not None: data = data[:, np.argsort(col_perm)] - return JaxArray(data) + return JaxArray._fast_constructor(data) def dimensions_jaxarray(matrix, dimensions, order): diff --git a/src/qutip_jax/qobjevo.py b/src/qutip_jax/qobjevo.py index d299241..9968a53 100644 --- a/src/qutip_jax/qobjevo.py +++ b/src/qutip_jax/qobjevo.py @@ -5,7 +5,7 @@ import numpy as np from .jaxarray import JaxArray from qutip.core.coefficient import coefficient_builders -from qutip.core.cy.coefficient import Coefficient +from qutip.core.cy.coefficient import Coefficient, coefficient_function_parameters from qutip import Qobj @@ -18,16 +18,26 @@ class JaxJitCoeff(Coefficient): def __init__(self, func, args={}, **_): self.func = func + _f_pythonic, _f_parameters = coefficient_function_parameters(func) + if _f_parameters is not None: + args = {key:val for key, val in args.items() if key in _f_parameters} + else: + args = args.copy() + if not _f_pythonic: + raise TypeError("Jitted coefficient should use a pythonic signature.") Coefficient.__init__(self, args) @eqx.filter_jit def __call__(self, t, _args=None, **kwargs): if _args: kwargs.update(_args) - args = self.args.copy() - for key in kwargs: - if key in args: - args[key] = kwargs[key] + if kwargs: + args = self.args.copy() + for key in kwargs: + if key in args: + args[key] = kwargs[key] + else: + args = self.args return self.func(t, **args) def replace_arguments(self, _args=None, **kwargs): @@ -153,12 +163,12 @@ def __call__(self, t, **kwargs): def data(self, t, **kwargs): coeff = self._coeff(t, **kwargs) data = jnp.dot(self.batched_data, coeff) - return JaxArray(data) + return JaxArray._fast_constructor(data) @eqx.filter_jit def matmul_data(self, t, y, **kwargs): coeffs = self._coeff(t, **kwargs) - out = JaxArray(jnp.dot(jnp.dot(self.batched_data, coeffs), y._jxa)) + out = JaxArray._fast_constructor(jnp.dot(jnp.dot(self.batched_data, coeffs), y._jxa)) return out def arguments(self, args): diff --git a/src/qutip_jax/reshape.py b/src/qutip_jax/reshape.py index f77ef8f..75c0f9c 100644 --- a/src/qutip_jax/reshape.py +++ b/src/qutip_jax/reshape.py @@ -58,7 +58,7 @@ def column_unstack_jaxarray(matrix, rows): @jit def split_columns_jaxarray(matrix): return [ - JaxArray(matrix._jxa[:, k]) for k in range(matrix.shape[1]) + JaxArray._fast_constructor(matrix._jxa[:, k:k+1]) for k in range(matrix.shape[1]) ] @@ -117,7 +117,7 @@ def ptrace_jaxarray(matrix, dims, sel): + sel + [nd + q for q in sel] ) - return JaxArray( + return JaxArray._fast_constructor( _ptrace_core(matrix._jxa, dims2, transpose_idx, dtrace, dkeep) ) diff --git a/src/qutip_jax/unary.py b/src/qutip_jax/unary.py index d3b601b..d8afbfd 100644 --- a/src/qutip_jax/unary.py +++ b/src/qutip_jax/unary.py @@ -31,12 +31,12 @@ def neg_jaxarray(matrix): @jit def adjoint_jaxarray(matrix): """Hermitian adjoint (matrix conjugate transpose).""" - return JaxArray(matrix._jxa.T.conj()) + return JaxArray._fast_constructor(matrix._jxa.T.conj()) def transpose_jaxarray(matrix): """Transpose of a matrix.""" - return JaxArray(matrix._jxa.T) + return JaxArray._fast_constructor(matrix._jxa.T) def conj_jaxarray(matrix): @@ -79,7 +79,7 @@ def project_jaxarray(state): out = _project_bra(state._jxa) else: raise ValueError("state must be a ket or a bra.") - return JaxArray(out) + return JaxArray._fast_constructor(out) qutip.data.neg.add_specialisations( diff --git a/tests/conftest.py b/tests/conftest.py index 66ca23e..08d6117 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,6 @@ key = random.PRNGKey(1234) def _random_cplx(shape): - return qutip_jax.JaxArray( + return qutip_jax.JaxArray._fast_constructor( random.normal(key, shape) + 1j*random.normal(key, shape) ) diff --git a/tests/test_jaxarray.py b/tests/test_jaxarray.py index ef128c5..6fffad9 100644 --- a/tests/test_jaxarray.py +++ b/tests/test_jaxarray.py @@ -98,4 +98,7 @@ def test_alternative_dtype(): cplx_array = JaxArray(ones*1j, dtype=jnp.complex64) assert (real_array * 5.)._jxa.dtype == jnp.float64 assert (cplx_array + cplx_array)._jxa.dtype == jnp.complex64 - assert (cplx_array @ real_array)._jxa.dtype == jnp.complex128 + + cplx_array = JaxArray(ones*1j, dtype=jnp.complex64) + real_array = JaxArray(ones, dtype=jnp.float32) + assert (cplx_array @ real_array)._jxa.dtype == jnp.complex64 diff --git a/tests/test_ode.py b/tests/test_ode.py index 5f40114..7c6d448 100644 --- a/tests/test_ode.py +++ b/tests/test_ode.py @@ -3,6 +3,8 @@ ) import qutip_jax from qutip_jax.qobjevo import JaxQobjEvo +from qutip_jax.ode import DiffraxIntegrator + import pytest import jax import jax.numpy as jnp @@ -69,7 +71,7 @@ def test_ode_step(): assert (solver.step(1) - ref_solver.step(1)).norm() <= 1e-6 -import jax + def test_ode_grad(): H = num(10) c_ops = [QobjEvo([destroy(10), cte], args={"A": 1.0})] @@ -101,5 +103,25 @@ def test_non_cplx128_JaxQobjEvo(): args={"A":1.0, "u":0.1, "sigma":0.5} ) jqevo = JaxQobjEvo(qevo) - assert jqevo.dtype == jnp.float64 assert jqevo.batched_data.dtype == jnp.float64 + + +def test_non_real_Diffrax(): + op1 = Qobj(qutip_jax.zeros_jaxarray(3, 3, dtype=jnp.float64)) + op2 = Qobj( + qutip_jax.one_element_jaxarray((3, 3), (0, 0), dtype=jnp.float64) + ) + op3 = Qobj(qutip_jax.identity_jaxarray(3, dtype=jnp.float64)) + qevo = QobjEvo( + [op1, [op2, pulse], [op3, cte]], + args={"A":1.0, "u":0.1, "sigma":0.5} + ) + + ode = DiffraxIntegrator(qevo, {}) + ode.set_state( + 0, + qutip_jax.one_element_jaxarray((3, 1), (2, 0), dtype=jnp.float64) + ) + t, out = ode.integrate(0.1) + assert out._jxa.dtype == jnp.float64 + From e43bcb8669c085e9ba2c3510f73f900b49192e9e Mon Sep 17 00:00:00 2001 From: Eric Giguere Date: Tue, 4 Jul 2023 10:04:28 -0400 Subject: [PATCH 4/4] rename test --- tests/test_ode.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_ode.py b/tests/test_ode.py index 7c6d448..61be829 100644 --- a/tests/test_ode.py +++ b/tests/test_ode.py @@ -106,7 +106,7 @@ def test_non_cplx128_JaxQobjEvo(): assert jqevo.batched_data.dtype == jnp.float64 -def test_non_real_Diffrax(): +def test_non_cplx128_Diffrax(): op1 = Qobj(qutip_jax.zeros_jaxarray(3, 3, dtype=jnp.float64)) op2 = Qobj( qutip_jax.one_element_jaxarray((3, 3), (0, 0), dtype=jnp.float64) @@ -116,12 +116,11 @@ def test_non_real_Diffrax(): [op1, [op2, pulse], [op3, cte]], args={"A":1.0, "u":0.1, "sigma":0.5} ) - + ode = DiffraxIntegrator(qevo, {}) ode.set_state( - 0, + 0, qutip_jax.one_element_jaxarray((3, 1), (2, 0), dtype=jnp.float64) ) t, out = ode.integrate(0.1) assert out._jxa.dtype == jnp.float64 -