From e45bcae86abcbb4d4d4d32ef06d23cb592327479 Mon Sep 17 00:00:00 2001 From: Arjun Bhamra Date: Sat, 19 Oct 2024 01:07:49 -0400 Subject: [PATCH 01/11] Added fix for incorrect pytree output with qml.counts() + tests --- frontend/catalyst/jax_tracer.py | 44 +++++++++--- frontend/test/pytest/test_jax_integration.py | 72 ++++++++++++++++++++ 2 files changed, 106 insertions(+), 10 deletions(-) diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 1bfb118626..57efecd0f0 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -844,6 +844,8 @@ def trace_quantum_measurements( """ shots = get_device_shots(device) out_classical_tracers = [] + # NOTE: Number of qml.counts() we hit + num_counts = 0 for i, o in enumerate(outputs): if isinstance(o, MeasurementProcess): @@ -914,16 +916,8 @@ def trace_quantum_measurements( results = (jnp.asarray(results[0], jnp.int64), results[1]) out_classical_tracers.extend(results) counts_tree = tree_structure(("keys", "counts")) - meas_return_trees_children = out_tree.children() - if len(meas_return_trees_children): - meas_return_trees_children[i] = counts_tree - out_tree = out_tree.make_from_node_data_and_children( - PyTreeRegistry(), - out_tree.node_data(), - meas_return_trees_children, - ) - else: - out_tree = counts_tree + out_tree = replace_child_tree(out_tree, i + 1 + num_counts, counts_tree) + num_counts += 1 elif isinstance(o, StateMP) and not isinstance(o, DensityMatrixMP): assert using_compbasis shape = (2**nqubits,) @@ -940,6 +934,36 @@ def trace_quantum_measurements( return out_classical_tracers, out_tree +def replace_child_tree(tree: PyTreeDef, index: int, subtree: PyTreeDef) -> PyTreeDef: + """ + Replace the index-th leaf node in a PyTreeDef with a given subtree. + + Args: + tree (PyTreeDef): The original PyTree. + index (int): The index of the leaf node to replace. + subtree (PyTreeDef): The new subtree to replace the original leaf node with. + + Returns: + PyTreeDef: The modified PyTree with the replaced leaf node. + """ + + def replace_node(node, idx): + if not node.children(): + # Leaf node => update leaf node counter + idx[0] += 1 + if idx[0] == index: + return subtree + return node + + + return node.make_from_node_data_and_children( + PyTreeRegistry(), + node.node_data(), + [replace_node(child, idx) for child in node.children()] + ) + + return replace_node(tree, [0]) + @debug_logger def is_transform_valid_for_batch_transforms(tape, flat_results): diff --git a/frontend/test/pytest/test_jax_integration.py b/frontend/test/pytest/test_jax_integration.py index 99b78f29b6..9faf31a1e5 100644 --- a/frontend/test/pytest/test_jax_integration.py +++ b/frontend/test/pytest/test_jax_integration.py @@ -489,6 +489,78 @@ def ansatz(i, x): params = jnp.array([0.54, 0.3154, 0.654, 0.123, 0.1, 0.2]) jax.grad(circuit, argnums=0)(params, 3) + def test_pytree_qml_counts_simple(self): + dev = qml.device("lightning.qubit", wires=1, shots=20) + @qjit + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return {"1": qml.counts()} + + result = circuit(0.5) + _, result_tree = jax.tree.flatten(result) + assert "PyTreeDef({'1': (*, *)})" == str(result_tree) + + def test_pytree_qml_counts_nested(self): + dev = qml.device("lightning.qubit", wires=1, shots=20) + @qjit + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return {"1": qml.counts()}, {"2": qml.expval(qml.Z(0))} + + result = circuit(0.5) + _, result_tree = jax.tree.flatten(result) + assert "PyTreeDef(({'1': (*, *)}, {'2': *}))" == str(result_tree) + + + @qjit + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], {"3": qml.expval(qml.Z(0))} + result = circuit(0.5) + _, result_tree = jax.tree.flatten(result) + + assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], {'3': *}))" == str(result_tree) + + def test_pytree_qml_counts_2_nested(self): + dev = qml.device("lightning.qubit", wires=1, shots=20) + @qjit + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [{"3": qml.expval(qml.Z(0))}, {'4': qml.counts()}] + + result = circuit(0.5) + _, result_tree = jax.tree.flatten(result) + assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], [{'3': *}, {'4': (*, *)}]))" == str(result_tree) + + + @qjit + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [{"3": qml.counts()}, {'4': qml.expval(qml.Z(0))}] + result = circuit(0.5) + _, result_tree = jax.tree.flatten(result) + + assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], [{'3': (*, *)}, {'4': *}]))" == str(result_tree) + + + def test_pytree_qml_counts_longer(self): + dev = qml.device("lightning.qubit", wires=1, shots=20) + @qjit + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return [[{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [{"3": qml.expval(qml.Z(0))}, {'4': qml.counts()}], {"5": qml.expval(qml.Z(0))}, {'6': qml.counts()}] + + result = circuit(0.5) + _, result_tree = jax.tree.flatten(result) + assert "PyTreeDef([[{'1': *}, {'2': (*, *)}], [{'3': *}, {'4': (*, *)}], {'5': *}, {'6': (*, *)}])" == str(result_tree) + + if __name__ == "__main__": pytest.main(["-x", __file__]) From 6556507608fab8f38b37f7209003310b7bb39004 Mon Sep 17 00:00:00 2001 From: Arjun Bhamra Date: Sat, 19 Oct 2024 01:15:04 -0400 Subject: [PATCH 02/11] running black linter --- frontend/catalyst/jax_tracer.py | 12 +++--- frontend/test/pytest/test_jax_integration.py | 42 +++++++++++++++----- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 57efecd0f0..34ad7051d5 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -934,19 +934,20 @@ def trace_quantum_measurements( return out_classical_tracers, out_tree + def replace_child_tree(tree: PyTreeDef, index: int, subtree: PyTreeDef) -> PyTreeDef: """ Replace the index-th leaf node in a PyTreeDef with a given subtree. - + Args: tree (PyTreeDef): The original PyTree. index (int): The index of the leaf node to replace. subtree (PyTreeDef): The new subtree to replace the original leaf node with. - + Returns: - PyTreeDef: The modified PyTree with the replaced leaf node. + PyTreeDef: The modified PyTree with the replaced leaf node. """ - + def replace_node(node, idx): if not node.children(): # Leaf node => update leaf node counter @@ -955,11 +956,10 @@ def replace_node(node, idx): return subtree return node - return node.make_from_node_data_and_children( PyTreeRegistry(), node.node_data(), - [replace_node(child, idx) for child in node.children()] + [replace_node(child, idx) for child in node.children()], ) return replace_node(tree, [0]) diff --git a/frontend/test/pytest/test_jax_integration.py b/frontend/test/pytest/test_jax_integration.py index 9faf31a1e5..bec94f5b06 100644 --- a/frontend/test/pytest/test_jax_integration.py +++ b/frontend/test/pytest/test_jax_integration.py @@ -491,11 +491,12 @@ def ansatz(i, x): def test_pytree_qml_counts_simple(self): dev = qml.device("lightning.qubit", wires=1, shots=20) + @qjit @qml.qnode(dev) def circuit(x): qml.RX(x, wires=0) - return {"1": qml.counts()} + return {"1": qml.counts()} result = circuit(0.5) _, result_tree = jax.tree.flatten(result) @@ -503,6 +504,7 @@ def circuit(x): def test_pytree_qml_counts_nested(self): dev = qml.device("lightning.qubit", wires=1, shots=20) + @qjit @qml.qnode(dev) def circuit(x): @@ -513,12 +515,12 @@ def circuit(x): _, result_tree = jax.tree.flatten(result) assert "PyTreeDef(({'1': (*, *)}, {'2': *}))" == str(result_tree) - @qjit @qml.qnode(dev) def circuit(x): qml.RX(x, wires=0) return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], {"3": qml.expval(qml.Z(0))} + result = circuit(0.5) _, result_tree = jax.tree.flatten(result) @@ -526,40 +528,58 @@ def circuit(x): def test_pytree_qml_counts_2_nested(self): dev = qml.device("lightning.qubit", wires=1, shots=20) + @qjit @qml.qnode(dev) def circuit(x): qml.RX(x, wires=0) - return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [{"3": qml.expval(qml.Z(0))}, {'4': qml.counts()}] + return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [ + {"3": qml.expval(qml.Z(0))}, + {"4": qml.counts()}, + ] result = circuit(0.5) _, result_tree = jax.tree.flatten(result) - assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], [{'3': *}, {'4': (*, *)}]))" == str(result_tree) - + assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], [{'3': *}, {'4': (*, *)}]))" == str( + result_tree + ) @qjit @qml.qnode(dev) def circuit(x): qml.RX(x, wires=0) - return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [{"3": qml.counts()}, {'4': qml.expval(qml.Z(0))}] + return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [ + {"3": qml.counts()}, + {"4": qml.expval(qml.Z(0))}, + ] + result = circuit(0.5) _, result_tree = jax.tree.flatten(result) - assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], [{'3': (*, *)}, {'4': *}]))" == str(result_tree) - + assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], [{'3': (*, *)}, {'4': *}]))" == str( + result_tree + ) def test_pytree_qml_counts_longer(self): dev = qml.device("lightning.qubit", wires=1, shots=20) + @qjit @qml.qnode(dev) def circuit(x): qml.RX(x, wires=0) - return [[{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [{"3": qml.expval(qml.Z(0))}, {'4': qml.counts()}], {"5": qml.expval(qml.Z(0))}, {'6': qml.counts()}] + return [ + [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], + [{"3": qml.expval(qml.Z(0))}, {"4": qml.counts()}], + {"5": qml.expval(qml.Z(0))}, + {"6": qml.counts()}, + ] result = circuit(0.5) _, result_tree = jax.tree.flatten(result) - assert "PyTreeDef([[{'1': *}, {'2': (*, *)}], [{'3': *}, {'4': (*, *)}], {'5': *}, {'6': (*, *)}])" == str(result_tree) - + assert ( + "PyTreeDef([[{'1': *}, {'2': (*, *)}], [{'3': *}, {'4': (*, *)}], {'5': *}, {'6': (*, *)}])" + == str(result_tree) + ) if __name__ == "__main__": From da3d038353e269f329260e4b6fd2397999d37911 Mon Sep 17 00:00:00 2001 From: Arjun Bhamra Date: Sat, 19 Oct 2024 01:34:15 -0400 Subject: [PATCH 03/11] codefactor fixes --- frontend/test/pytest/test_jax_integration.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/frontend/test/pytest/test_jax_integration.py b/frontend/test/pytest/test_jax_integration.py index bec94f5b06..97b7253a16 100644 --- a/frontend/test/pytest/test_jax_integration.py +++ b/frontend/test/pytest/test_jax_integration.py @@ -490,6 +490,7 @@ def ansatz(i, x): jax.grad(circuit, argnums=0)(params, 3) def test_pytree_qml_counts_simple(self): + """Test if a single qml.counts() can be used and output correctly.""" dev = qml.device("lightning.qubit", wires=1, shots=20) @qjit @@ -503,6 +504,7 @@ def circuit(x): assert "PyTreeDef({'1': (*, *)})" == str(result_tree) def test_pytree_qml_counts_nested(self): + """Test if nested qml.counts() can be used and output correctly.""" dev = qml.device("lightning.qubit", wires=1, shots=20) @qjit @@ -517,16 +519,17 @@ def circuit(x): @qjit @qml.qnode(dev) - def circuit(x): + def circuit2(x): qml.RX(x, wires=0) return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], {"3": qml.expval(qml.Z(0))} - result = circuit(0.5) + result = circuit2(0.5) _, result_tree = jax.tree.flatten(result) assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], {'3': *}))" == str(result_tree) def test_pytree_qml_counts_2_nested(self): + """Test if multiple nested qml.counts() can be used and output correctly.""" dev = qml.device("lightning.qubit", wires=1, shots=20) @qjit @@ -546,14 +549,14 @@ def circuit(x): @qjit @qml.qnode(dev) - def circuit(x): + def circuit2(x): qml.RX(x, wires=0) return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [ {"3": qml.counts()}, {"4": qml.expval(qml.Z(0))}, ] - result = circuit(0.5) + result = circuit2(0.5) _, result_tree = jax.tree.flatten(result) assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], [{'3': (*, *)}, {'4': *}]))" == str( @@ -561,6 +564,7 @@ def circuit(x): ) def test_pytree_qml_counts_longer(self): + """Test if 3 differently nested qml.counts() can be used and output correctly.""" dev = qml.device("lightning.qubit", wires=1, shots=20) @qjit @@ -577,7 +581,8 @@ def circuit(x): result = circuit(0.5) _, result_tree = jax.tree.flatten(result) assert ( - "PyTreeDef([[{'1': *}, {'2': (*, *)}], [{'3': *}, {'4': (*, *)}], {'5': *}, {'6': (*, *)}])" + "PyTreeDef([[{'1': *}, {'2': (*, *)}], " + + "[{'3': *}, {'4': (*, *)}], {'5': *}, {'6': (*, *)}])" == str(result_tree) ) From 1f0d19b1e0ecfdea9db6aee9e16303e4e2505649 Mon Sep 17 00:00:00 2001 From: Arjun Bhamra Date: Tue, 22 Oct 2024 22:08:47 -0400 Subject: [PATCH 04/11] minor code fixes, correct location of tests + add mcm test, add changelog update --- doc/releases/changelog-dev.md | 3 + frontend/catalyst/jax_tracer.py | 4 +- frontend/test/pytest/test_jax_integration.py | 98 ---------------- frontend/test/pytest/test_pytree_args.py | 115 +++++++++++++++++++ 4 files changed, 120 insertions(+), 100 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3077e426cd..b473d8ce75 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -294,6 +294,8 @@ * Registers the func dialect as a requirement for running the scatter lowering pass. [(#1216)](https://github.com/PennyLaneAI/catalyst/pull/1216) +* Resolves a bug where calling `qml.counts()` within complex and/or nested return expressions did not return the correct `PyTreeDef`. [(#1219)](https://github.com/PennyLaneAI/catalyst/pull/1219) +

Internal changes

* Remove deprecated pennylane code across the frontend. @@ -350,6 +352,7 @@ This release contains contributions from (in alphabetical order): Amintor Dusko, +Arjun Bhamra, Joey Carter, Spencer Comin, Lillian M.A. Frederiksen, diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 34ad7051d5..3a360a8284 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -916,8 +916,8 @@ def trace_quantum_measurements( results = (jnp.asarray(results[0], jnp.int64), results[1]) out_classical_tracers.extend(results) counts_tree = tree_structure(("keys", "counts")) - out_tree = replace_child_tree(out_tree, i + 1 + num_counts, counts_tree) num_counts += 1 + out_tree = replace_child_tree(out_tree, i + num_counts, counts_tree) elif isinstance(o, StateMP) and not isinstance(o, DensityMatrixMP): assert using_compbasis shape = (2**nqubits,) @@ -937,7 +937,7 @@ def trace_quantum_measurements( def replace_child_tree(tree: PyTreeDef, index: int, subtree: PyTreeDef) -> PyTreeDef: """ - Replace the index-th leaf node in a PyTreeDef with a given subtree. + Replace the index-th leaf node in a left-to-right depth-first tree traversal of a PyTreeDef with a given subtree. Args: tree (PyTreeDef): The original PyTree. diff --git a/frontend/test/pytest/test_jax_integration.py b/frontend/test/pytest/test_jax_integration.py index 97b7253a16..123ec3f0a0 100644 --- a/frontend/test/pytest/test_jax_integration.py +++ b/frontend/test/pytest/test_jax_integration.py @@ -489,103 +489,5 @@ def ansatz(i, x): params = jnp.array([0.54, 0.3154, 0.654, 0.123, 0.1, 0.2]) jax.grad(circuit, argnums=0)(params, 3) - def test_pytree_qml_counts_simple(self): - """Test if a single qml.counts() can be used and output correctly.""" - dev = qml.device("lightning.qubit", wires=1, shots=20) - - @qjit - @qml.qnode(dev) - def circuit(x): - qml.RX(x, wires=0) - return {"1": qml.counts()} - - result = circuit(0.5) - _, result_tree = jax.tree.flatten(result) - assert "PyTreeDef({'1': (*, *)})" == str(result_tree) - - def test_pytree_qml_counts_nested(self): - """Test if nested qml.counts() can be used and output correctly.""" - dev = qml.device("lightning.qubit", wires=1, shots=20) - - @qjit - @qml.qnode(dev) - def circuit(x): - qml.RX(x, wires=0) - return {"1": qml.counts()}, {"2": qml.expval(qml.Z(0))} - - result = circuit(0.5) - _, result_tree = jax.tree.flatten(result) - assert "PyTreeDef(({'1': (*, *)}, {'2': *}))" == str(result_tree) - - @qjit - @qml.qnode(dev) - def circuit2(x): - qml.RX(x, wires=0) - return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], {"3": qml.expval(qml.Z(0))} - - result = circuit2(0.5) - _, result_tree = jax.tree.flatten(result) - - assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], {'3': *}))" == str(result_tree) - - def test_pytree_qml_counts_2_nested(self): - """Test if multiple nested qml.counts() can be used and output correctly.""" - dev = qml.device("lightning.qubit", wires=1, shots=20) - - @qjit - @qml.qnode(dev) - def circuit(x): - qml.RX(x, wires=0) - return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [ - {"3": qml.expval(qml.Z(0))}, - {"4": qml.counts()}, - ] - - result = circuit(0.5) - _, result_tree = jax.tree.flatten(result) - assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], [{'3': *}, {'4': (*, *)}]))" == str( - result_tree - ) - - @qjit - @qml.qnode(dev) - def circuit2(x): - qml.RX(x, wires=0) - return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [ - {"3": qml.counts()}, - {"4": qml.expval(qml.Z(0))}, - ] - - result = circuit2(0.5) - _, result_tree = jax.tree.flatten(result) - - assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], [{'3': (*, *)}, {'4': *}]))" == str( - result_tree - ) - - def test_pytree_qml_counts_longer(self): - """Test if 3 differently nested qml.counts() can be used and output correctly.""" - dev = qml.device("lightning.qubit", wires=1, shots=20) - - @qjit - @qml.qnode(dev) - def circuit(x): - qml.RX(x, wires=0) - return [ - [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], - [{"3": qml.expval(qml.Z(0))}, {"4": qml.counts()}], - {"5": qml.expval(qml.Z(0))}, - {"6": qml.counts()}, - ] - - result = circuit(0.5) - _, result_tree = jax.tree.flatten(result) - assert ( - "PyTreeDef([[{'1': *}, {'2': (*, *)}], " - + "[{'3': *}, {'4': (*, *)}], {'5': *}, {'6': (*, *)}])" - == str(result_tree) - ) - - if __name__ == "__main__": pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/test_pytree_args.py b/frontend/test/pytest/test_pytree_args.py index b8f1ab1e61..b5ebfb5338 100644 --- a/frontend/test/pytest/test_pytree_args.py +++ b/frontend/test/pytest/test_pytree_args.py @@ -593,5 +593,120 @@ def classical(x): assert result.a == 4 +class TestPyTreesQmlCounts: + """Test QJIT workflows when using qml.counts in a return expression.""" + + def test_pytree_qml_counts_simple(self): + """Test if a single qml.counts() can be used and output correctly.""" + dev = qml.device("lightning.qubit", wires=1, shots=20) + + @qjit + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return {"1": qml.counts()} + + result = circuit(0.5) + _, result_tree = jax.tree.flatten(result) + assert "PyTreeDef({'1': (*, *)})" == str(result_tree) + + def test_pytree_qml_counts_nested(self): + """Test if nested qml.counts() can be used and output correctly.""" + dev = qml.device("lightning.qubit", wires=1, shots=20) + + @qjit + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return {"1": qml.counts()}, {"2": qml.expval(qml.Z(0))} + + result = circuit(0.5) + _, result_tree = jax.tree.flatten(result) + assert "PyTreeDef(({'1': (*, *)}, {'2': *}))" == str(result_tree) + + @qjit + @qml.qnode(dev) + def circuit2(x): + qml.RX(x, wires=0) + return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], {"3": qml.expval(qml.Z(0))} + + result = circuit2(0.5) + _, result_tree = jax.tree.flatten(result) + + assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], {'3': *}))" == str(result_tree) + + def test_pytree_qml_counts_2_nested(self): + """Test if multiple nested qml.counts() can be used and output correctly.""" + dev = qml.device("lightning.qubit", wires=1, shots=20) + + @qjit + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [ + {"3": qml.expval(qml.Z(0))}, + {"4": qml.counts()}, + ] + + result = circuit(0.5) + _, result_tree = jax.tree.flatten(result) + assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], [{'3': *}, {'4': (*, *)}]))" == str( + result_tree + ) + + @qjit + @qml.qnode(dev) + def circuit2(x): + qml.RX(x, wires=0) + return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [ + {"3": qml.counts()}, + {"4": qml.expval(qml.Z(0))}, + ] + + result = circuit2(0.5) + _, result_tree = jax.tree.flatten(result) + + assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], [{'3': (*, *)}, {'4': *}]))" == str( + result_tree + ) + + def test_pytree_qml_counts_longer(self): + """Test if 3 differently nested qml.counts() can be used and output correctly.""" + dev = qml.device("lightning.qubit", wires=1, shots=20) + + @qjit + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return [ + [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], + [{"3": qml.expval(qml.Z(0))}, {"4": qml.counts()}], + {"5": qml.expval(qml.Z(0))}, + {"6": qml.counts()}, + ] + + result = circuit(0.5) + _, result_tree = jax.tree.flatten(result) + assert ( + "PyTreeDef([[{'1': *}, {'2': (*, *)}], " + + "[{'3': *}, {'4': (*, *)}], {'5': *}, {'6': (*, *)}])" + == str(result_tree) + ) + + def test_pytree_qml_counts_mcm(self): + """Test qml.counts() with mid circuit measurement.""" + dev = qml.device("lightning.qubit", wires=1, shots=20) + + @qml.qjit + @qml.qnode(dev, mcm_method="one-shot", postselect_mode=None) + def circuit(x): + qml.RX(x, wires=0) + measure(0, postselect=1) + return {"hi": qml.counts()}, {"bye": qml.expval(qml.Z(0))}, {"hi": qml.counts()} + + result = circuit(0.9) + _, result_tree = jax.tree.flatten(result) + assert ("PyTreeDef(({'hi': (*, *)}, {'bye': *}, {'hi': (*, *)}))" == str(result_tree)) + if __name__ == "__main__": pytest.main(["-x", __file__]) From 9dba319203e37b496d3c34b162fc979657b9578c Mon Sep 17 00:00:00 2001 From: Arjun Bhamra Date: Tue, 22 Oct 2024 22:11:43 -0400 Subject: [PATCH 05/11] fix for linter --- frontend/test/pytest/test_pytree_args.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_pytree_args.py b/frontend/test/pytest/test_pytree_args.py index b5ebfb5338..9f8b21126f 100644 --- a/frontend/test/pytest/test_pytree_args.py +++ b/frontend/test/pytest/test_pytree_args.py @@ -706,7 +706,10 @@ def circuit(x): result = circuit(0.9) _, result_tree = jax.tree.flatten(result) - assert ("PyTreeDef(({'hi': (*, *)}, {'bye': *}, {'hi': (*, *)}))" == str(result_tree)) + assert ( + "PyTreeDef(({'hi': (*, *)}, {'bye': *}, {'hi': (*, *)}))" + == str(result_tree) + ) if __name__ == "__main__": pytest.main(["-x", __file__]) From ec5a8ecc63147b0f9c17c8a2c46bd0d2c44d5353 Mon Sep 17 00:00:00 2001 From: Arjun Bhamra Date: Wed, 23 Oct 2024 09:37:34 -0400 Subject: [PATCH 06/11] linter line length fix --- frontend/test/pytest/test_jax_integration.py | 1 + frontend/test/pytest/test_pytree_args.py | 6 ++---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/frontend/test/pytest/test_jax_integration.py b/frontend/test/pytest/test_jax_integration.py index 123ec3f0a0..99b78f29b6 100644 --- a/frontend/test/pytest/test_jax_integration.py +++ b/frontend/test/pytest/test_jax_integration.py @@ -489,5 +489,6 @@ def ansatz(i, x): params = jnp.array([0.54, 0.3154, 0.654, 0.123, 0.1, 0.2]) jax.grad(circuit, argnums=0)(params, 3) + if __name__ == "__main__": pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/test_pytree_args.py b/frontend/test/pytest/test_pytree_args.py index 9f8b21126f..ad2d5b32b1 100644 --- a/frontend/test/pytest/test_pytree_args.py +++ b/frontend/test/pytest/test_pytree_args.py @@ -706,10 +706,8 @@ def circuit(x): result = circuit(0.9) _, result_tree = jax.tree.flatten(result) - assert ( - "PyTreeDef(({'hi': (*, *)}, {'bye': *}, {'hi': (*, *)}))" - == str(result_tree) - ) + assert "PyTreeDef(({'hi': (*, *)}, {'bye': *}, {'hi': (*, *)}))" == str(result_tree) + if __name__ == "__main__": pytest.main(["-x", __file__]) From add7e2bb79c1d4e78878394b90fa013c1b044d3a Mon Sep 17 00:00:00 2001 From: Arjun Bhamra Date: Wed, 23 Oct 2024 11:54:52 -0400 Subject: [PATCH 07/11] updated tests to use actual PyTreeDefs, not str rep --- frontend/test/pytest/test_pytree_args.py | 103 ++++++++++++++++------- 1 file changed, 73 insertions(+), 30 deletions(-) diff --git a/frontend/test/pytest/test_pytree_args.py b/frontend/test/pytest/test_pytree_args.py index ad2d5b32b1..2639ba682b 100644 --- a/frontend/test/pytest/test_pytree_args.py +++ b/frontend/test/pytest/test_pytree_args.py @@ -606,9 +606,13 @@ def circuit(x): qml.RX(x, wires=0) return {"1": qml.counts()} - result = circuit(0.5) - _, result_tree = jax.tree.flatten(result) - assert "PyTreeDef({'1': (*, *)})" == str(result_tree) + observed = circuit(0.5) + expected = {"1": (jnp.array((0, 1), dtype=jnp.int64), + jnp.array((0, 3), dtype=jnp.int64))} + + _, expected_shape = tree_flatten(expected) + _, observed_shape = tree_flatten(observed) + assert expected_shape == observed_shape def test_pytree_qml_counts_nested(self): """Test if nested qml.counts() can be used and output correctly.""" @@ -620,9 +624,16 @@ def circuit(x): qml.RX(x, wires=0) return {"1": qml.counts()}, {"2": qml.expval(qml.Z(0))} - result = circuit(0.5) - _, result_tree = jax.tree.flatten(result) - assert "PyTreeDef(({'1': (*, *)}, {'2': *}))" == str(result_tree) + observed = circuit(0.5) + expected = ( + {"1": (jnp.array((0, 1), dtype=jnp.int64), + jnp.array((0, 3), dtype=jnp.int64))}, + {"2": jnp.array(-1, dtype=jnp.float64)}, + ) + + _, expected_shape = tree_flatten(expected) + _, observed_shape = tree_flatten(observed) + assert expected_shape == observed_shape @qjit @qml.qnode(dev) @@ -630,10 +641,16 @@ def circuit2(x): qml.RX(x, wires=0) return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], {"3": qml.expval(qml.Z(0))} - result = circuit2(0.5) - _, result_tree = jax.tree.flatten(result) - - assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], {'3': *}))" == str(result_tree) + observed = circuit2(0.5) + expected = ( + [{"1": jnp.array(-1, dtype=jnp.float64)}, + {"2": (jnp.array((0, 1), dtype=jnp.int64), + jnp.array((0, 3), dtype=jnp.int64))}], + {"3": jnp.array(-1, dtype=jnp.float64)}, + ) + _, expected_shape = tree_flatten(expected) + _, observed_shape = tree_flatten(observed) + assert expected_shape == observed_shape def test_pytree_qml_counts_2_nested(self): """Test if multiple nested qml.counts() can be used and output correctly.""" @@ -648,11 +665,18 @@ def circuit(x): {"4": qml.counts()}, ] - result = circuit(0.5) - _, result_tree = jax.tree.flatten(result) - assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], [{'3': *}, {'4': (*, *)}]))" == str( - result_tree + observed = circuit(0.5) + expected = ( + [{"1": jnp.array(-1, dtype=jnp.float64)}, + {"2": (jnp.array((0, 1), dtype=jnp.int64), + jnp.array((0, 3), dtype=jnp.int64))}], + [{"3": jnp.array(-1, dtype=jnp.float64)}, + {"4": (jnp.array((0, 1), dtype=jnp.int64), + jnp.array((0, 3), dtype=jnp.int64))}], ) + _, expected_shape = tree_flatten(expected) + _, observed_shape = tree_flatten(observed) + assert expected_shape == observed_shape @qjit @qml.qnode(dev) @@ -663,12 +687,18 @@ def circuit2(x): {"4": qml.expval(qml.Z(0))}, ] - result = circuit2(0.5) - _, result_tree = jax.tree.flatten(result) - - assert "PyTreeDef(([{'1': *}, {'2': (*, *)}], [{'3': (*, *)}, {'4': *}]))" == str( - result_tree + observed = circuit2(0.5) + expected = ( + [{"1": jnp.array(-1, dtype=jnp.float64)}, + {"2": (jnp.array((0, 1), dtype=jnp.int64), + jnp.array((0, 3), dtype=jnp.int64))}], + [{"3": (jnp.array((0, 1), dtype=jnp.int64), + jnp.array((0, 3), dtype=jnp.int64))}, + {"4": jnp.array(-1, dtype=jnp.float64)}], ) + _, expected_shape = tree_flatten(expected) + _, observed_shape = tree_flatten(observed) + assert expected_shape == observed_shape def test_pytree_qml_counts_longer(self): """Test if 3 differently nested qml.counts() can be used and output correctly.""" @@ -685,13 +715,21 @@ def circuit(x): {"6": qml.counts()}, ] - result = circuit(0.5) - _, result_tree = jax.tree.flatten(result) - assert ( - "PyTreeDef([[{'1': *}, {'2': (*, *)}], " - + "[{'3': *}, {'4': (*, *)}], {'5': *}, {'6': (*, *)}])" - == str(result_tree) - ) + observed = circuit(0.5) + expected = [ + [{"1": jnp.array(-1, dtype=jnp.float64)}, + {"2": (jnp.array((0, 1), dtype=jnp.int64), + jnp.array((0, 3), dtype=jnp.int64))}], + [{"3": jnp.array(-1, dtype=jnp.float64)}, + {"4": (jnp.array((0, 1), dtype=jnp.int64), + jnp.array((0, 3), dtype=jnp.int64))}], + {"5": jnp.array(-1, dtype=jnp.float64)}, + {"6": (jnp.array((0, 1), dtype=jnp.int64), + jnp.array((0, 3), dtype=jnp.int64))}, + ] + _, expected_shape = tree_flatten(expected) + _, observed_shape = tree_flatten(observed) + assert expected_shape == observed_shape def test_pytree_qml_counts_mcm(self): """Test qml.counts() with mid circuit measurement.""" @@ -704,10 +742,15 @@ def circuit(x): measure(0, postselect=1) return {"hi": qml.counts()}, {"bye": qml.expval(qml.Z(0))}, {"hi": qml.counts()} - result = circuit(0.9) - _, result_tree = jax.tree.flatten(result) - assert "PyTreeDef(({'hi': (*, *)}, {'bye': *}, {'hi': (*, *)}))" == str(result_tree) - + observed = circuit(0.5) + expected = ( + {"hi": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + {"bye": jnp.array(-1, dtype=jnp.float64)}, + {"hi": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + ) + _, expected_shape = tree_flatten(expected) + _, observed_shape = tree_flatten(observed) + assert expected_shape == observed_shape if __name__ == "__main__": pytest.main(["-x", __file__]) From 29697139597b92076c3f082fe5bdc264e5dd9eb6 Mon Sep 17 00:00:00 2001 From: Arjun Bhamra Date: Wed, 23 Oct 2024 11:55:52 -0400 Subject: [PATCH 08/11] linter fixes --- frontend/test/pytest/test_pytree_args.py | 63 +++++++++++++----------- 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/frontend/test/pytest/test_pytree_args.py b/frontend/test/pytest/test_pytree_args.py index 2639ba682b..fb1ad37962 100644 --- a/frontend/test/pytest/test_pytree_args.py +++ b/frontend/test/pytest/test_pytree_args.py @@ -607,9 +607,8 @@ def circuit(x): return {"1": qml.counts()} observed = circuit(0.5) - expected = {"1": (jnp.array((0, 1), dtype=jnp.int64), - jnp.array((0, 3), dtype=jnp.int64))} - + expected = {"1": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))} + _, expected_shape = tree_flatten(expected) _, observed_shape = tree_flatten(observed) assert expected_shape == observed_shape @@ -626,11 +625,10 @@ def circuit(x): observed = circuit(0.5) expected = ( - {"1": (jnp.array((0, 1), dtype=jnp.int64), - jnp.array((0, 3), dtype=jnp.int64))}, + {"1": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, {"2": jnp.array(-1, dtype=jnp.float64)}, ) - + _, expected_shape = tree_flatten(expected) _, observed_shape = tree_flatten(observed) assert expected_shape == observed_shape @@ -643,9 +641,10 @@ def circuit2(x): observed = circuit2(0.5) expected = ( - [{"1": jnp.array(-1, dtype=jnp.float64)}, - {"2": (jnp.array((0, 1), dtype=jnp.int64), - jnp.array((0, 3), dtype=jnp.int64))}], + [ + {"1": jnp.array(-1, dtype=jnp.float64)}, + {"2": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + ], {"3": jnp.array(-1, dtype=jnp.float64)}, ) _, expected_shape = tree_flatten(expected) @@ -667,12 +666,14 @@ def circuit(x): observed = circuit(0.5) expected = ( - [{"1": jnp.array(-1, dtype=jnp.float64)}, - {"2": (jnp.array((0, 1), dtype=jnp.int64), - jnp.array((0, 3), dtype=jnp.int64))}], - [{"3": jnp.array(-1, dtype=jnp.float64)}, - {"4": (jnp.array((0, 1), dtype=jnp.int64), - jnp.array((0, 3), dtype=jnp.int64))}], + [ + {"1": jnp.array(-1, dtype=jnp.float64)}, + {"2": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + ], + [ + {"3": jnp.array(-1, dtype=jnp.float64)}, + {"4": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + ], ) _, expected_shape = tree_flatten(expected) _, observed_shape = tree_flatten(observed) @@ -689,12 +690,14 @@ def circuit2(x): observed = circuit2(0.5) expected = ( - [{"1": jnp.array(-1, dtype=jnp.float64)}, - {"2": (jnp.array((0, 1), dtype=jnp.int64), - jnp.array((0, 3), dtype=jnp.int64))}], - [{"3": (jnp.array((0, 1), dtype=jnp.int64), - jnp.array((0, 3), dtype=jnp.int64))}, - {"4": jnp.array(-1, dtype=jnp.float64)}], + [ + {"1": jnp.array(-1, dtype=jnp.float64)}, + {"2": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + ], + [ + {"3": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + {"4": jnp.array(-1, dtype=jnp.float64)}, + ], ) _, expected_shape = tree_flatten(expected) _, observed_shape = tree_flatten(observed) @@ -717,15 +720,16 @@ def circuit(x): observed = circuit(0.5) expected = [ - [{"1": jnp.array(-1, dtype=jnp.float64)}, - {"2": (jnp.array((0, 1), dtype=jnp.int64), - jnp.array((0, 3), dtype=jnp.int64))}], - [{"3": jnp.array(-1, dtype=jnp.float64)}, - {"4": (jnp.array((0, 1), dtype=jnp.int64), - jnp.array((0, 3), dtype=jnp.int64))}], + [ + {"1": jnp.array(-1, dtype=jnp.float64)}, + {"2": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + ], + [ + {"3": jnp.array(-1, dtype=jnp.float64)}, + {"4": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + ], {"5": jnp.array(-1, dtype=jnp.float64)}, - {"6": (jnp.array((0, 1), dtype=jnp.int64), - jnp.array((0, 3), dtype=jnp.int64))}, + {"6": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, ] _, expected_shape = tree_flatten(expected) _, observed_shape = tree_flatten(observed) @@ -752,5 +756,6 @@ def circuit(x): _, observed_shape = tree_flatten(observed) assert expected_shape == observed_shape + if __name__ == "__main__": pytest.main(["-x", __file__]) From 1cf0aa786456de8b0c778a564c46c6a25756d344 Mon Sep 17 00:00:00 2001 From: Arjun Bhamra <33864884+abhamra@users.noreply.github.com> Date: Wed, 23 Oct 2024 12:00:32 -0400 Subject: [PATCH 09/11] Update doc/releases/changelog-dev.md Co-authored-by: paul0403 <79805239+paul0403@users.noreply.github.com> --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 5d15241264..413b14cc67 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -326,7 +326,7 @@ - Registers the func dialect as a requirement for running the scatter lowering pass. - Emits error if `%input`, `%update` and `%result` are not of length 1 instead of segfaulting. -* Resolves a bug where calling `qml.counts()` within complex and/or nested return expressions did not return the correct `PyTreeDef`. [(#1219)](https://github.com/PennyLaneAI/catalyst/pull/1219) +* Resolves a bug where calling `qml.counts()` within complicated and/or nested return expressions did not return the correct `PyTreeDef`. [(#1219)](https://github.com/PennyLaneAI/catalyst/pull/1219)

Internal changes

From af405f59755dda39ba22a801eb2bf7c59bdfbdd4 Mon Sep 17 00:00:00 2001 From: Arjun Bhamra Date: Wed, 23 Oct 2024 20:01:17 -0400 Subject: [PATCH 10/11] changelog and comment fixes --- doc/releases/changelog-dev.md | 4 +++- frontend/catalyst/jax_tracer.py | 3 ++- frontend/test/pytest/test_pytree_args.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 8364248433..3c258b724b 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -325,11 +325,13 @@ - Registers the func dialect as a requirement for running the scatter lowering pass. - Emits error if `%input`, `%update` and `%result` are not of length 1 instead of segfaulting. + * Fixes a performance issue with vmap with its root cause in the lowering of the scatter operation. [(#1214)](https://github.com/PennyLaneAI/catalyst/pull/1214) + * Resolves a bug where calling `qml.counts()` within complicated and/or nested return expressions did not return the correct `PyTreeDef`. -* [(#1219)](https://github.com/PennyLaneAI/catalyst/pull/1219) + [(#1219)](https://github.com/PennyLaneAI/catalyst/pull/1219)

Internal changes

diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 3a360a8284..3cf3696f5d 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -937,7 +937,8 @@ def trace_quantum_measurements( def replace_child_tree(tree: PyTreeDef, index: int, subtree: PyTreeDef) -> PyTreeDef: """ - Replace the index-th leaf node in a left-to-right depth-first tree traversal of a PyTreeDef with a given subtree. + Replace the index-th leaf node in a left-to-right depth-first tree traversal of a PyTreeDef + with a given subtree. Args: tree (PyTreeDef): The original PyTree. diff --git a/frontend/test/pytest/test_pytree_args.py b/frontend/test/pytest/test_pytree_args.py index fb1ad37962..ff55b45a81 100644 --- a/frontend/test/pytest/test_pytree_args.py +++ b/frontend/test/pytest/test_pytree_args.py @@ -736,7 +736,7 @@ def circuit(x): assert expected_shape == observed_shape def test_pytree_qml_counts_mcm(self): - """Test qml.counts() with mid circuit measurement.""" + """Test qml.counts() with mid-circuit measurement.""" dev = qml.device("lightning.qubit", wires=1, shots=20) @qml.qjit From 5bb216f27b643312c619de255806708e8a49073c Mon Sep 17 00:00:00 2001 From: Arjun Bhamra Date: Mon, 28 Oct 2024 14:15:11 -0400 Subject: [PATCH 11/11] add explanation for num_counts --- frontend/catalyst/jax_tracer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 3cf3696f5d..b6f42fff91 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -844,7 +844,8 @@ def trace_quantum_measurements( """ shots = get_device_shots(device) out_classical_tracers = [] - # NOTE: Number of qml.counts() we hit + # NOTE: Number of qml.counts() we hit, used to update our iteration variable to account + # for additional leaf PyTreeDef nodes. num_counts = 0 for i, o in enumerate(outputs):