Skip to content

Commit

Permalink
Refactor 'merge rotations' and 'cancel inverses' tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rauletorresc committed Jan 9, 2025
1 parent 1d61d40 commit dc32d99
Showing 1 changed file with 18 additions and 60 deletions.
78 changes: 18 additions & 60 deletions frontend/test/pytest/test_peephole_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,23 @@
from catalyst import pipeline, qjit
from catalyst.passes import cancel_inverses, merge_rotations

default_device = qml.device("default.qubit", wires=1)

# pylint: disable=missing-function-docstring


def _assert_against_reference(circuit, theta, backend, optimization):

customized_device = qml.device(backend, wires=1)

reference_workflow = qml.QNode(circuit, default_device)
qjitted_workflow = qjit(qml.QNode(circuit, customized_device))
optimized_workflow = qjit(optimization(qml.QNode(circuit, customized_device)))

assert np.allclose(reference_workflow(theta), qjitted_workflow(theta))
assert np.allclose(reference_workflow(theta), optimized_workflow(theta))


#
# cancel_inverses
#
Expand All @@ -33,74 +47,19 @@
@pytest.mark.parametrize("theta", [42.42])
def test_cancel_inverses_functionality(theta, backend):

@qjit
def workflow():
@qml.qnode(qml.device(backend, wires=1))
def f(x):
qml.RX(x, wires=0)
qml.Hadamard(wires=0)
qml.Hadamard(wires=0)
return qml.probs()

@cancel_inverses
@qml.qnode(qml.device(backend, wires=1))
def g(x):
qml.RX(x, wires=0)
qml.Hadamard(wires=0)
qml.Hadamard(wires=0)
return qml.probs()

return f(theta), g(theta)

@qml.qnode(qml.device("default.qubit", wires=1))
def reference(x):
def circuit(x):
qml.RX(x, wires=0)
qml.Hadamard(wires=0)
qml.Hadamard(wires=0)
return qml.probs()

assert np.allclose(workflow()[0], workflow()[1])
assert np.allclose(workflow()[1], reference(theta))
_assert_against_reference(circuit, theta, backend, cancel_inverses)


@pytest.mark.parametrize("theta", [42.42])
def test_merge_rotation_functionality(theta, backend):

@qjit
def workflow():
@qml.qnode(qml.device(backend, wires=1))
def f(x):
qml.RX(x, wires=0)
qml.RX(x, wires=0)
qml.RZ(x, wires=0)
qml.adjoint(qml.RZ)(x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.PhaseShift(x, wires=0)
qml.PhaseShift(x, wires=0)
qml.Hadamard(wires=0)
qml.Hadamard(wires=0)
return qml.probs()

@merge_rotations
@qml.qnode(qml.device(backend, wires=1))
def g(x):
qml.RX(x, wires=0)
qml.RX(x, wires=0)
qml.RZ(x, wires=0)
qml.adjoint(qml.RZ)(x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.PhaseShift(x, wires=0)
qml.PhaseShift(x, wires=0)
qml.Hadamard(wires=0)
qml.Hadamard(wires=0)
return qml.probs()

return f(theta), g(theta)

@qml.qnode(qml.device("default.qubit", wires=1))
def reference(x):
def circuit(x):
qml.RX(x, wires=0)
qml.RX(x, wires=0)
qml.RZ(x, wires=0)
Expand All @@ -113,8 +72,7 @@ def reference(x):
qml.Hadamard(wires=0)
return qml.probs()

assert np.allclose(workflow()[0], workflow()[1])
assert np.allclose(workflow()[1], reference(theta))
_assert_against_reference(circuit, theta, backend, merge_rotations)


@pytest.mark.parametrize("theta", [42.42])
Expand Down

0 comments on commit dc32d99

Please sign in to comment.