diff --git a/tests/test_qutip/test_mcsolve.py b/tests/test_qutip/test_mcsolve.py index ac7c095..a3dc6d7 100644 --- a/tests/test_qutip/test_mcsolve.py +++ b/tests/test_qutip/test_mcsolve.py @@ -16,8 +16,8 @@ def H_1_coeff(t, omega): # Test setup for gradient calculation def setup_system(size=2): - a = qt.tensor(qt.destroy(size), qt.qeye(2)).to('jaxdia') - sm = qt.qeye(size).to('jaxdia') & qt.sigmax().to('jaxdia') + a = qt.tensor(qt.destroy(size), qt.qeye(2)).to('jaxdia') + sm = qt.qeye(size).to('jaxdia') & qt.sigmax().to('jaxdia') # Define the Hamiltonian H_0 = 2.0 * jnp.pi * a.dag() * a + 2.0 * jnp.pi * sm.dag() * sm @@ -33,29 +33,29 @@ def setup_system(size=2): # Time list tlist = jnp.linspace(0.0, 1.0, 101) - + return H, state, tlist, c_ops, e_ops # Function for which we want to compute the gradient def f(omega, H, state, tlist, c_ops, e_ops): result = mcsolve( - H, state, tlist, c_ops, e_ops, ntraj=10, - args={"omega": omega}, + H, state, tlist, c_ops, e_ops=e_ops, ntraj=10, + args={"omega": omega}, options={"method": "diffrax"} ) - + return result.expect[0][-1].real # Pytest test case for gradient computation @pytest.mark.parametrize("omega_val", [2.0]) def test_gradient_mcsolve(omega_val): H, state, tlist, c_ops, e_ops = setup_system(size=10) - + # Compute the gradient with respect to omega grad_func = jax.grad(lambda omega: f(omega, H, state, tlist, c_ops, e_ops)) gradient = grad_func(omega_val) - + # Check if the gradient is not None and has the correct shape assert gradient is not None - assert gradient.shape == () - assert jnp.isfinite(gradient) + assert gradient.shape == () + assert jnp.isfinite(gradient)