From abecef5e5cee12f284b0deeddcb8484a25950e68 Mon Sep 17 00:00:00 2001 From: Raul Torres Date: Thu, 9 Jan 2025 18:53:30 -0500 Subject: [PATCH] Remove function for assertion --- .../pytest/test_peephole_optimizations.py | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/frontend/test/pytest/test_peephole_optimizations.py b/frontend/test/pytest/test_peephole_optimizations.py index 8974518ab9..67d5abea83 100644 --- a/frontend/test/pytest/test_peephole_optimizations.py +++ b/frontend/test/pytest/test_peephole_optimizations.py @@ -21,23 +21,8 @@ from catalyst import pipeline, qjit from catalyst.passes import cancel_inverses, merge_rotations -default_device = qml.device("default.qubit", wires=1) - # pylint: disable=missing-function-docstring - -def _assert_against_reference(circuit, theta, backend, optimization): - - customized_device = qml.device(backend, wires=1) - - reference_workflow = qml.QNode(circuit, default_device) - qjitted_workflow = qjit(qml.QNode(circuit, customized_device)) - optimized_workflow = qjit(optimization(qml.QNode(circuit, customized_device))) - - assert np.allclose(reference_workflow(theta), qjitted_workflow(theta)) - assert np.allclose(reference_workflow(theta), optimized_workflow(theta)) - - # # cancel_inverses # @@ -53,7 +38,19 @@ def circuit(x): qml.Hadamard(wires=0) return qml.probs() - _assert_against_reference(circuit, theta, backend, cancel_inverses) + reference_workflow = qml.QNode(circuit, qml.device("default.qubit", wires=1)) + + customized_device = qml.device(backend, wires=1) + qjitted_workflow = qjit(qml.QNode(circuit, customized_device)) + optimized_workflow = qjit(cancel_inverses(qml.QNode(circuit, customized_device))) + + assert np.allclose(reference_workflow(theta), qjitted_workflow(theta)) + assert np.allclose(reference_workflow(theta), optimized_workflow(theta)) + + +# +# merge_rotations +# @pytest.mark.parametrize("theta", [42.42]) @@ -72,7 +69,14 @@ def circuit(x): qml.Hadamard(wires=0) return qml.probs() - _assert_against_reference(circuit, theta, backend, merge_rotations) + reference_workflow = qml.QNode(circuit, qml.device("default.qubit", wires=1)) + + customized_device = qml.device(backend, wires=1) + qjitted_workflow = qjit(qml.QNode(circuit, customized_device)) + optimized_workflow = qjit(cancel_inverses(qml.QNode(circuit, customized_device))) + + assert np.allclose(reference_workflow(theta), qjitted_workflow(theta)) + assert np.allclose(reference_workflow(theta), optimized_workflow(theta)) @pytest.mark.parametrize("theta", [42.42])