Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow other dtype than complex128. #22

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/qutip_jax/binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def mul_jaxarray(matrix, value):
# We don't want to check values type in case jax pass a tracer etc.
# But we want to ensure the output is a matrix, thus don't use the
# fast constructor.
return JaxArray(matrix._jxa * value)
AGaliciaMartinez marked this conversation as resolved.
Show resolved Hide resolved
return JaxArray._fast_constructor(matrix._jxa * value, shape=matrix.shape)


def matmul_jaxarray(left, right, scale=1, out=None):
Expand Down Expand Up @@ -119,7 +119,7 @@ def kron_jaxarray(left, right):
Compute the Kronecker product of two matrices. This is used to represent
quantum tensor products of vector spaces.
"""
return JaxArray(jnp.kron(left._jxa, right._jxa))
return JaxArray._fast_constructor(jnp.kron(left._jxa, right._jxa))


def pow_jaxarray(matrix, n):
Expand All @@ -138,7 +138,7 @@ def pow_jaxarray(matrix, n):
"""
if matrix.shape[0] != matrix.shape[1]:
raise ValueError("matrix power only works with square matrices")
return JaxArray(jnp.linalg.matrix_power(matrix._jxa, n))
return JaxArray._fast_constructor(jnp.linalg.matrix_power(matrix._jxa, n))


qutip.data.add.add_specialisations(
Expand Down
22 changes: 11 additions & 11 deletions src/qutip_jax/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
]


def zeros_jaxarray(rows, cols):
def zeros_jaxarray(rows, cols, *, dtype=jnp.complex128):
"""
Creates a matrix representation of zeros with the given dimensions.

Expand All @@ -25,10 +25,10 @@ def zeros_jaxarray(rows, cols):
rows, cols : int
The number of rows and columns in the output matrix.
"""
return JaxArray(jnp.zeros((rows, cols), dtype=jnp.complex128))
return JaxArray._fast_constructor(jnp.zeros((rows, cols), dtype=dtype))


def identity_jaxarray(dimensions, scale=None):
def identity_jaxarray(dimensions, scale=None, *, dtype=jnp.complex128):
"""
Creates a square identity matrix of the given dimension.

Expand All @@ -43,11 +43,11 @@ def identity_jaxarray(dimensions, scale=None):
The element which should be placed on the diagonal.
"""
if scale is None:
return JaxArray(jnp.eye(dimensions, dtype=jnp.complex128))
return JaxArray(jnp.eye(dimensions, dtype=jnp.complex128) * scale)
return JaxArray._fast_constructor(jnp.eye(dimensions, dtype=dtype))
return JaxArray._fast_constructor(jnp.eye(dimensions, dtype=dtype) * scale)


def diag_jaxarray(diagonals, offsets=None, shape=None):
def diag_jaxarray(diagonals, offsets=None, shape=None, *, dtype=jnp.complex128):
"""
Constructs a matrix from diagonals and their offsets.

Expand Down Expand Up @@ -108,10 +108,10 @@ def diag_jaxarray(diagonals, offsets=None, shape=None):

if n_rows == n_cols:
# jax diag only create square matrix
out = jnp.zeros((n_rows, n_cols), dtype=jnp.complex128)
out = jnp.zeros((n_rows, n_cols), dtype=dtype)
for offset, diag in zip(offsets, diagonals):
out += jnp.diag(jnp.array(diag), offset)
out = JaxArray(out)
out = JaxArray._fast_constructor(out)
else:
out = jax_from_dense(
qutip.core.data.dense.diags(diagonals, offsets, shape)
Expand All @@ -120,7 +120,7 @@ def diag_jaxarray(diagonals, offsets=None, shape=None):
return out


def one_element_jaxarray(shape, position, value=None):
def one_element_jaxarray(shape, position, value=None, *, dtype=jnp.complex128):
"""
Creates a matrix with only one nonzero element.

Expand All @@ -141,8 +141,8 @@ def one_element_jaxarray(shape, position, value=None):
)
if value is None:
value = 1.0
out = jnp.zeros(shape, dtype=jnp.complex128)
return JaxArray(out.at[position].set(value))
out = jnp.zeros(shape, dtype=dtype)
return JaxArray._fast_constructor(out.at[position].set(value))


qutip.data.zeros.add_specialisations(
Expand Down
16 changes: 9 additions & 7 deletions src/qutip_jax/jaxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class JaxArray(Data):
_jxa: jnp.ndarray
shape: tuple

def __init__(self, data, shape=None, copy=None):
jxa = jnp.array(data, dtype=jnp.complex128)
def __init__(self, data, shape=None, copy=None, *, dtype=jnp.complex128):
jxa = jnp.array(data, dtype=dtype)

if shape is None:
shape = data.shape
Expand All @@ -45,19 +45,19 @@ def __init__(self, data, shape=None, copy=None):
Data.__init__(self, shape)

def copy(self):
return self.__class__(self._jxa, copy=True)
return JaxArray._fast_constructor(self._jxa.copy(), shape=self.shape)

def to_array(self):
return np.array(self._jxa)

def conj(self):
return self.__class__(self._jxa.conj())
return JaxArray._fast_constructor(self._jxa.conj(), shape=self.shape)

def transpose(self):
return self.__class__(self._jxa.T)
return JaxArray._fast_constructor(self._jxa.T, shape=self.shape[::-1])

def adjoint(self):
return self.__class__(self._jxa.T.conj())
return JaxArray._fast_constructor(self._jxa.T.conj(), shape=self.shape[::-1])

def trace(self):
return jnp.trace(self._jxa)
Expand All @@ -81,8 +81,10 @@ def __matmul__(self, other):
return NotImplemented

@classmethod
def _fast_constructor(cls, array, shape):
def _fast_constructor(cls, array, shape=None):
out = cls.__new__(cls)
if shape is None:
shape = array.shape
Data.__init__(out, shape)
out._jxa = array
return out
Expand Down
6 changes: 3 additions & 3 deletions src/qutip_jax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def eigs_jaxarray(data, isherm=None, vecs=True, sort='low', eigvals=0):

evals, evecs = _eigs_jaxarray(data._jxa, isherm, vecs, eigvals, low_first)

return (evals, JaxArray(evecs, copy=False)) if vecs else evals
return (evals, JaxArray._fast_constructor(evecs)) if vecs else evals


qutip.data.eigs.add_specialisations(
Expand Down Expand Up @@ -109,7 +109,7 @@ def svd_jaxarray(data, vecs=True, full_matrices=True, hermitian=False):
)
if vecs:
u, s, vh = out
return JaxArray(u, copy=False), s, JaxArray(vh, copy=False)
return JaxArray._fast_constructor(u), s, JaxArray._fast_constructor(vh)
return out


Expand Down Expand Up @@ -160,7 +160,7 @@ def solve_jaxarray(matrix: JaxArray, target: JaxArray, method=None,
else:
raise ValueError(f"Unknown solver {method},"
" 'solve' and 'lstsq' are supported.")
return JaxArray(out, copy=False)
return JaxArray._fast_constructor(out)


qutip.data.solve.add_specialisations(
Expand Down
12 changes: 8 additions & 4 deletions src/qutip_jax/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@

@jax.jit
def _cplx2float(arr):
return jnp.stack([arr.real, arr.imag])
if jnp.iscomplexobj(arr):
return jnp.stack([arr.real, arr.imag])
return arr


@jax.jit
def _float2cplx(arr):
return arr[0] + 1j * arr[1]
if arr.ndim == 3:
return arr[0] + 1j * arr[1]
return arr
Comment on lines 15 to +26
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test

Copy link
Member Author

@Ericgig Ericgig Jul 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is covered in test_non_cplx128_Diffrax.



class DiffraxIntegrator(Integrator):
Expand Down Expand Up @@ -49,7 +53,7 @@ def _prepare(self):
def dstate(t, y, args):
state = _float2cplx(y)
H, kwargs = args
d_state = H.matmul_data(t, JaxArray(state), **kwargs)
d_state = H.matmul_data(t, JaxArray._fast_constructor(state), **kwargs)
return _cplx2float(d_state._jxa)

def set_state(self, t, state0):
Expand All @@ -61,7 +65,7 @@ def set_state(self, t, state0):
self._is_set = True

def get_state(self, copy=False):
return self.t, JaxArray(_float2cplx(self.state))
return self.t, JaxArray._fast_constructor(_float2cplx(self.state))

def integrate(self, t, copy=False, **kwargs):
sol = diffrax.diffeqsolve(
Expand Down
2 changes: 1 addition & 1 deletion src/qutip_jax/permute.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def indices_jaxarray(matrix, row_perm=None, col_perm=None):
data = data[np.argsort(row_perm), :]
if col_perm is not None:
data = data[:, np.argsort(col_perm)]
return JaxArray(data)
return JaxArray._fast_constructor(data)


def dimensions_jaxarray(matrix, dimensions, order):
Expand Down
37 changes: 28 additions & 9 deletions src/qutip_jax/qobjevo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from .jaxarray import JaxArray
from qutip.core.coefficient import coefficient_builders
from qutip.core.cy.coefficient import Coefficient
from qutip.core.cy.coefficient import Coefficient, coefficient_function_parameters
from qutip import Qobj


Expand All @@ -18,16 +18,26 @@ class JaxJitCoeff(Coefficient):

def __init__(self, func, args={}, **_):
self.func = func
_f_pythonic, _f_parameters = coefficient_function_parameters(func)
if _f_parameters is not None:
args = {key:val for key, val in args.items() if key in _f_parameters}
else:
args = args.copy()
if not _f_pythonic:
raise TypeError("Jitted coefficient should use a pythonic signature.")
Coefficient.__init__(self, args)

@eqx.filter_jit
def __call__(self, t, _args=None, **kwargs):
if _args:
kwargs.update(_args)
args = self.args.copy()
for key in kwargs:
if key in args:
args[key] = kwargs[key]
if kwargs:
args = self.args.copy()
for key in kwargs:
if key in args:
args[key] = kwargs[key]
else:
args = self.args
return self.func(t, **args)

def replace_arguments(self, _args=None, **kwargs):
Expand Down Expand Up @@ -113,25 +123,34 @@ def __init__(self, qobjevo):

constant = JaxJitCoeff(eqx.filter_jit(lambda t, **_: 1.0))

dtype = None

for part in as_list:
if isinstance(part, Qobj):
qobjs.append(part)
self.coeffs.append(constant)
if isinstance(part.data, JaxArray):
dtype = jnp.promote_types(dtype, part.data._jxa.dtype)
elif (
isinstance(part, list) and isinstance(part[0], Qobj)
):
qobjs.append(part[0])
self.coeffs.append(part[1])
if isinstance(part[0], JaxArray):
dtype = jnp.promote_types(dtype, part[0].data._jxa.dtype)
else:
# TODO:
raise NotImplementedError(
"Function based QobjEvo are not supported"
)

if dtype is None:
dtype=jnp.complex128

if qobjs:
shape = qobjs[0].shape
self.batched_data = jnp.zeros(
shape + (len(qobjs),), dtype=np.complex128
shape + (len(qobjs),), dtype=dtype
)
for i, qobj in enumerate(qobjs):
self.batched_data = self.batched_data.at[:, :, i].set(
Expand All @@ -141,7 +160,7 @@ def __init__(self, qobjevo):
@eqx.filter_jit
def _coeff(self, t, **args):
list_coeffs = [coeff(t, **args) for coeff in self.coeffs]
return jnp.array(list_coeffs, dtype=np.complex128)
return jnp.array(list_coeffs, dtype=self.batched_data.dtype)

def __call__(self, t, **kwargs):
return Qobj(self.data(t, **kwargs), dims=self.dims)
Expand All @@ -150,12 +169,12 @@ def __call__(self, t, **kwargs):
def data(self, t, **kwargs):
coeff = self._coeff(t, **kwargs)
data = jnp.dot(self.batched_data, coeff)
return JaxArray(data)
return JaxArray._fast_constructor(data)

@eqx.filter_jit
def matmul_data(self, t, y, **kwargs):
coeffs = self._coeff(t, **kwargs)
out = JaxArray(jnp.dot(jnp.dot(self.batched_data, coeffs), y._jxa))
out = JaxArray._fast_constructor(jnp.dot(jnp.dot(self.batched_data, coeffs), y._jxa))
return out

def arguments(self, args):
Expand Down
4 changes: 2 additions & 2 deletions src/qutip_jax/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def column_unstack_jaxarray(matrix, rows):
@jit
def split_columns_jaxarray(matrix):
return [
JaxArray(matrix._jxa[:, k]) for k in range(matrix.shape[1])
JaxArray._fast_constructor(matrix._jxa[:, k:k+1]) for k in range(matrix.shape[1])
]


Expand Down Expand Up @@ -119,7 +119,7 @@ def ptrace_jaxarray(matrix, dims, sel):
+ sel + [nd + q for q in sel]
)

return JaxArray(
return JaxArray._fast_constructor(
_ptrace_core(matrix._jxa, dims2, transpose_idx, dtrace, dkeep)
)

Expand Down
6 changes: 3 additions & 3 deletions src/qutip_jax/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ def neg_jaxarray(matrix):
@jit
def adjoint_jaxarray(matrix):
"""Hermitian adjoint (matrix conjugate transpose)."""
return JaxArray(matrix._jxa.T.conj())
return JaxArray._fast_constructor(matrix._jxa.T.conj())


def transpose_jaxarray(matrix):
"""Transpose of a matrix."""
return JaxArray(matrix._jxa.T)
return JaxArray._fast_constructor(matrix._jxa.T)


def conj_jaxarray(matrix):
Expand Down Expand Up @@ -79,7 +79,7 @@ def project_jaxarray(state):
out = _project_bra(state._jxa)
else:
raise ValueError("state must be a ket or a bra.")
return JaxArray(out)
return JaxArray._fast_constructor(out)


qutip.data.neg.add_specialisations(
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
key = random.PRNGKey(1234)

def _random_cplx(shape):
return qutip_jax.JaxArray(
return qutip_jax.JaxArray._fast_constructor(
random.normal(key, shape) + 1j*random.normal(key, shape)
)
13 changes: 12 additions & 1 deletion tests/test_jaxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def test_init(backend, shape, dtype):
array = backend.array(array)
jax_a = JaxArray(array)
assert isinstance(jax_a, JaxArray)
assert jax_a._jxa.dtype == jax.numpy.complex128
if len(shape) == 1:
shape = shape + (1,)
assert jax_a.shape == shape
Expand Down Expand Up @@ -93,6 +92,18 @@ def test_convert():
assert isinstance(sx.data, JaxArray)


def test_alternative_dtype():
ones = jnp.ones((3, 3))
real_array = JaxArray(ones, dtype=jnp.float64)
cplx_array = JaxArray(ones*1j, dtype=jnp.complex64)
assert (real_array * 5.)._jxa.dtype == jnp.float64
assert (cplx_array + cplx_array)._jxa.dtype == jnp.complex64

cplx_array = JaxArray(ones*1j, dtype=jnp.complex64)
real_array = JaxArray(ones, dtype=jnp.float32)
assert (cplx_array @ real_array)._jxa.dtype == jnp.complex64


def test_extract():
ones = jnp.ones((3, 3))
qobj = qutip.Qobj(ones)
Expand Down
Loading