Skip to content

Commit

Permalink
add test for mcsolve and qobj
Browse files Browse the repository at this point in the history
  • Loading branch information
rochisha0 committed Aug 14, 2024
1 parent 946888b commit ae21d43
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 0 deletions.
59 changes: 59 additions & 0 deletions tests/test_qutip/test_mcsolve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest
import jax
import jax.numpy as jnp
import qutip as qt
import qutip_jax as qjax
from qutip import mcsolve
from functools import partial

# Use JAX backend for QuTiP
qjax.use_jax_backend()

# Define time-dependent functions
@partial(jax.jit, static_argnames=("omega",))
def H_1_coeff(t, omega):
return 2.0 * jnp.pi * 0.25 * jnp.cos(2.0 * omega * t)

# Test setup for gradient calculation
def setup_system(size=2):
a = qt.destroy(size).to("jax")
sm = qt.sigmax().to("jax")

# Define the Hamiltonian
H_0 = 2.0 * jnp.pi * a.dag() * a + 2.0 * jnp.pi * sm.dag() * sm
H_1_op = sm * a.dag() + sm.dag() * a

H = [H_0, [H_1_op, qt.coefficient(H_1_coeff, args={"omega": 1.0})]]

state = qt.basis(size, size-1).to("jax")

# Define collapse operators and observables
c_ops = [jnp.sqrt(0.1) * a]
e_ops = [a.dag() * a, sm.dag() * sm]

# Time list
tlist = jnp.linspace(0.0, 10.0, 101)

return H, state, tlist, c_ops, e_ops

# Function for which we want to compute the gradient
def f(omega, H, state, tlist, c_ops, e_ops):
H[1][1] = qt.coefficient(H_1_coeff, args={"omega": omega})

result = mcsolve(H, state, tlist, c_ops, e_ops, ntraj=10, options={"method": "diffrax"})

return result.expect[0][-1].real

# Pytest test case for gradient computation
@pytest.mark.parametrize("omega_val", [1.0, 2.0, 3.0])
def test_gradient_mcsolve(omega_val):
H, state, tlist, c_ops, e_ops = setup_system(size=2)

# Compute the gradient with respect to omega
grad_func = jax.grad(lambda omega: f(omega, H, state, tlist, c_ops, e_ops))
gradient = grad_func(omega_val)

# Check if the gradient is not None and has the correct shape
assert gradient is not None
assert gradient.shape == ()
assert jnp.isfinite(gradient)
95 changes: 95 additions & 0 deletions tests/test_qutip/test_qobj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import pytest
import jax.numpy as jnp
from jax import jit, grad
from qutip import Qobj, basis, rand_dm, sigmax, identity, tensor, expect
import qutip.settings
import qutip_jax

# Set JAX backend for QuTiP
qutip.settings.core["auto_real_casting"] = False
qutip_jax.use_jax_backend()
tol = 1e-6 # Tolerance for assertion

# Initialize quantum objects for testing
with qutip.CoreOptions(default_dtype="jax"):
rho = rand_dm(2)
ket = basis(2, 0)
bra = ket.dag()
op1 = sigmax()
identity_op = identity(2)
composite_op = tensor(op1, identity_op)

def expectation_value(op: Qobj, state: Qobj) -> float:
"""
Compute the expectation value of an operator with respect to a quantum state.
Args:
op (Qobj): The operator (as a Qobj).
state (Qobj): The quantum state (as a Qobj).
Returns:
float: The expectation value.
"""
return expect(op, state)

# Test case for Qobj functions with jax.jit
@pytest.mark.parametrize("func_name, func", [
("copy", lambda x: x.copy()),
("conj", lambda x: x.conj()),
("contract", lambda x: x.contract()),
("cosm", lambda x: x.cosm()),
("dag", lambda x: x.dag()),
("eigenenergies", lambda x: x.eigenenergies()),
("expm", lambda x: x.expm()),
("inv", lambda x: x.inv()),
("logm", lambda x: x.logm()),
("matrix_element", lambda x: x.matrix_element(ket, ket)),
("norm", lambda x: x.norm()),
("overlap", lambda x: x.overlap(op1)),
("ptrace", lambda x: x.ptrace([0])),
("purity", lambda x: x.purity()),
("sinm", lambda x: x.sinm()),
("sqrtm", lambda x: x.sqrtm()),
("tr", lambda x: x.tr()),
("trans", lambda x: x.trans()),
("transform", lambda x: x.transform(identity_op)),
("unit", lambda x: x.unit())
])
def test_qobj_jit(func_name, func):
# Create a jitted function using the given Qobj function
def jit_func(op):
return func(op)

# Apply jit to the function
func_jit = jit(jit_func)
result_jit = func_jit(op1)

# Check if jit result is not None
assert result_jit is not None
print(f"JIT result of {func_name} with respect to Qobj data:", result_jit)

# Test case for Qobj functions with jax.grad
@pytest.mark.parametrize("func_name, func", [
("conj", lambda x: x.conj()),
("contract", lambda x: x.contract()),
("cosm", lambda x: x.cosm()),
("dag", lambda x: x.dag()),
("eigenenergies", lambda x: x.eigenenergies()),
("expm", lambda x: x.expm()),
("inv", lambda x: x.inv()),
("overlap", lambda x: x.overlap(op1)),
("purity", lambda x: x.purity()),
("sinm", lambda x: x.sinm()),
("tr", lambda x: x.tr()),
])
def test_qobj_grad(func_name, func):
# Create a differentiable function using the given Qobj function
def grad_func(op1):
return jnp.real(func(op1))

# Apply grad to the function
grad_func = grad(grad_func)
grad_result = grad_func(op1)

# Check if the gradient is not None
assert grad_result is not None

0 comments on commit ae21d43

Please sign in to comment.