diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 3680c3dd6ea..eb7634c3f1a 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -616,6 +616,9 @@ same information.
Bug fixes 🐛
+* Fixes incorrect differentiation of `TrotterProduct` when using `diff_method="parameter-shift"`.
+ [(#6432)](https://github.com/PennyLaneAI/pennylane/pull/6432)
+
* `qml.ControlledQubitUnitary` has consistent behaviour with program capture enabled.
[(#6719)](https://github.com/PennyLaneAI/pennylane/pull/6719)
diff --git a/tests/ops/functions/conftest.py b/tests/ops/functions/conftest.py
index 413d7cb02aa..b6264c1d8ba 100644
--- a/tests/ops/functions/conftest.py
+++ b/tests/ops/functions/conftest.py
@@ -64,7 +64,10 @@ def _trotterize_qfunc_dummy(time, theta, phi, wires, flip=False):
(qml.s_prod(1.1, qml.RX(1.1, 0)), {}),
(qml.prod(qml.PauliX(0), qml.PauliY(1), qml.PauliZ(0)), {}),
(qml.ctrl(qml.RX(1.1, 0), 1), {}),
- (qml.exp(qml.PauliX(0), 1.1), {}),
+ (qml.exp(qml.PauliX(0), 1.1), {"skip_differentiation": True}),
+ # FIXME: Generator of Exp is incorrect when coefficient is imaginary
+ # (qml.exp(qml.PauliX(0), 2.9j), {}),
+ (qml.evolve(qml.PauliX(0), -0.5), {}),
(qml.pow(qml.IsingXX(1.1, [0, 1]), 2.5), {}),
(qml.ops.Evolution(qml.PauliX(0), 5.2), {}),
(qml.QutritBasisState([1, 2, 0], wires=[0, 1, 2]), {"skip_differentiation": True}),
diff --git a/tests/ops/op_math/test_exp.py b/tests/ops/op_math/test_exp.py
index 5abead84595..c995d6b1dbd 100644
--- a/tests/ops/op_math/test_exp.py
+++ b/tests/ops/op_math/test_exp.py
@@ -870,7 +870,6 @@ def circuit(phi):
grad = qml.grad(circuit)(phi)
assert qml.math.allclose(grad, -qml.numpy.sin(phi))
- @pytest.mark.xfail # related to #6333
@pytest.mark.autograd
def test_autograd_param_shift_qnode(self):
"""Test execution and gradient with pennylane numpy array."""
diff --git a/tests/templates/test_subroutines/test_trotter.py b/tests/templates/test_subroutines/test_trotter.py
index 4787c974a50..a0593853f41 100644
--- a/tests/templates/test_subroutines/test_trotter.py
+++ b/tests/templates/test_subroutines/test_trotter.py
@@ -452,13 +452,57 @@ def test_copy(self, hamiltonian, time, n, order):
assert op.hyperparameters == new_op.hyperparameters
assert op is not new_op
- @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/pennylane/issues/6333", strict=False)
@pytest.mark.parametrize("hamiltonian", test_hamiltonians)
def test_standard_validity(self, hamiltonian):
"""Test standard validity criteria using assert_valid."""
time, n, order = (4.2, 10, 4)
op = qml.TrotterProduct(hamiltonian, time, n=n, order=order)
- qml.ops.functions.assert_valid(op)
+ qml.ops.functions.assert_valid(op, skip_differentiation=True)
+
+ @pytest.mark.parametrize("hamiltonian", test_hamiltonians)
+ def test_differentiation(self, hamiltonian):
+ """Tests the differentiation of the TrotterProduct with parameter-shift"""
+
+ time, n, order = (4.2, 10, 4)
+
+ dev = qml.device("default.qubit")
+ coeffs, ops = hamiltonian.terms()
+
+ @qml.qnode(dev, diff_method="backprop")
+ def circ_bp(coeffs, time):
+ with qml.queuing.QueuingManager.stop_recording():
+ hamiltonian = qml.dot(coeffs, ops)
+
+ qml.TrotterProduct(hamiltonian, time, n, order)
+ return qml.probs()
+
+ @qml.qnode(dev, diff_method="parameter-shift")
+ def circ_ps(coeffs, time):
+ with qml.queuing.QueuingManager.stop_recording():
+ hamiltonian = qml.dot(coeffs, ops)
+
+ qml.TrotterProduct(hamiltonian, time, n, order)
+ return qml.probs()
+
+ coeffs = qml.numpy.array(coeffs)
+ time = qml.numpy.array(time)
+
+ expected_bp = qml.jacobian(circ_bp)(coeffs, time)
+ assert expected_bp[0].shape == (2**hamiltonian.num_wires, len(coeffs))
+ assert expected_bp[1].shape == (2**hamiltonian.num_wires,)
+
+ ps = qml.jacobian(circ_ps)(coeffs, time)
+ assert ps[0].shape == (2**hamiltonian.num_wires, len(coeffs))
+ assert ps[1].shape == (2**hamiltonian.num_wires,)
+
+ error_msg = (
+ "Parameter-shift does not produce the same Jacobian as with backpropagation. "
+ "This might be a bug, or it might be expected due to the mathematical nature "
+ "of backpropagation, in which case, this test can be skipped for this operator."
+ )
+
+ for actual, expected in zip(ps, expected_bp):
+ assert qml.math.allclose(actual, expected, atol=1e-6), error_msg
# TODO: Remove test when we deprecate ApproxTimeEvolution
@pytest.mark.parametrize("n", (1, 2, 5, 10))