Skip to content

Commit

Permalink
Remove function for assertion
Browse files Browse the repository at this point in the history
  • Loading branch information
rauletorresc committed Jan 9, 2025
1 parent dc32d99 commit abecef5
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions frontend/test/pytest/test_peephole_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,8 @@
from catalyst import pipeline, qjit
from catalyst.passes import cancel_inverses, merge_rotations

default_device = qml.device("default.qubit", wires=1)

# pylint: disable=missing-function-docstring


def _assert_against_reference(circuit, theta, backend, optimization):

customized_device = qml.device(backend, wires=1)

reference_workflow = qml.QNode(circuit, default_device)
qjitted_workflow = qjit(qml.QNode(circuit, customized_device))
optimized_workflow = qjit(optimization(qml.QNode(circuit, customized_device)))

assert np.allclose(reference_workflow(theta), qjitted_workflow(theta))
assert np.allclose(reference_workflow(theta), optimized_workflow(theta))


#
# cancel_inverses
#
Expand All @@ -53,7 +38,19 @@ def circuit(x):
qml.Hadamard(wires=0)
return qml.probs()

_assert_against_reference(circuit, theta, backend, cancel_inverses)
reference_workflow = qml.QNode(circuit, qml.device("default.qubit", wires=1))

customized_device = qml.device(backend, wires=1)
qjitted_workflow = qjit(qml.QNode(circuit, customized_device))
optimized_workflow = qjit(cancel_inverses(qml.QNode(circuit, customized_device)))

assert np.allclose(reference_workflow(theta), qjitted_workflow(theta))
assert np.allclose(reference_workflow(theta), optimized_workflow(theta))


#
# merge_rotations
#


@pytest.mark.parametrize("theta", [42.42])
Expand All @@ -72,7 +69,14 @@ def circuit(x):
qml.Hadamard(wires=0)
return qml.probs()

_assert_against_reference(circuit, theta, backend, merge_rotations)
reference_workflow = qml.QNode(circuit, qml.device("default.qubit", wires=1))

customized_device = qml.device(backend, wires=1)
qjitted_workflow = qjit(qml.QNode(circuit, customized_device))
optimized_workflow = qjit(cancel_inverses(qml.QNode(circuit, customized_device)))

assert np.allclose(reference_workflow(theta), qjitted_workflow(theta))
assert np.allclose(reference_workflow(theta), optimized_workflow(theta))


@pytest.mark.parametrize("theta", [42.42])
Expand Down

0 comments on commit abecef5

Please sign in to comment.