Skip to content

Commit

Permalink
Merge pull request #70 from Ericgig/fix_new_warnings
Browse files Browse the repository at this point in the history
Set `e_ops` as keyword parameter
  • Loading branch information
Ericgig authored Nov 8, 2024
2 parents b5b098a + dd7dfc7 commit 4067dfa
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions tests/test_qutip/test_mcsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def H_1_coeff(t, omega):

# Test setup for gradient calculation
def setup_system(size=2):
a = qt.tensor(qt.destroy(size), qt.qeye(2)).to('jaxdia')
sm = qt.qeye(size).to('jaxdia') & qt.sigmax().to('jaxdia')
a = qt.tensor(qt.destroy(size), qt.qeye(2)).to('jaxdia')
sm = qt.qeye(size).to('jaxdia') & qt.sigmax().to('jaxdia')

# Define the Hamiltonian
H_0 = 2.0 * jnp.pi * a.dag() * a + 2.0 * jnp.pi * sm.dag() * sm
Expand All @@ -33,29 +33,29 @@ def setup_system(size=2):

# Time list
tlist = jnp.linspace(0.0, 1.0, 101)

return H, state, tlist, c_ops, e_ops

# Function for which we want to compute the gradient
def f(omega, H, state, tlist, c_ops, e_ops):
result = mcsolve(
H, state, tlist, c_ops, e_ops, ntraj=10,
args={"omega": omega},
H, state, tlist, c_ops, e_ops=e_ops, ntraj=10,
args={"omega": omega},
options={"method": "diffrax"}
)

return result.expect[0][-1].real

# Pytest test case for gradient computation
@pytest.mark.parametrize("omega_val", [2.0])
def test_gradient_mcsolve(omega_val):
H, state, tlist, c_ops, e_ops = setup_system(size=10)

# Compute the gradient with respect to omega
grad_func = jax.grad(lambda omega: f(omega, H, state, tlist, c_ops, e_ops))
gradient = grad_func(omega_val)

# Check if the gradient is not None and has the correct shape
assert gradient is not None
assert gradient.shape == ()
assert jnp.isfinite(gradient)
assert gradient.shape == ()
assert jnp.isfinite(gradient)

0 comments on commit 4067dfa

Please sign in to comment.