Skip to content

Commit

Permalink
Use inspect module to get argument counts (#1151)
Browse files Browse the repository at this point in the history
Using `qml.cond` on PennyLane operators raises an exception because the
target does not have a `__code__` attribute:
```py
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=1))
def f(x: int):
    qml.cond(x < 5, qml.Hadamard)(0)
    return qml.probs()
```
```
AttributeError: type object 'Hadamard' has no attribute '__code__'
```

[sc-74292]
  • Loading branch information
dime10 authored Sep 24, 2024
1 parent 316bf26 commit f122439
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 12 deletions.
27 changes: 18 additions & 9 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@
Array([[1], [0], [1], [1], [0], [1],[0]], dtype=int64))
```

* A new function `catalyst.passes.pipeline` allows the quantum circuit transformation pass pipeline for QNodes within a qjit-compiled workflow to be configured.
* A new function `catalyst.passes.pipeline` allows the quantum circuit transformation pass pipeline
for QNodes within a qjit-compiled workflow to be configured.
[(#1131)](https://github.com/PennyLaneAI/catalyst/pull/1131)

```python
Expand Down Expand Up @@ -92,10 +93,12 @@
return jnp.abs(circuit_pipeline(x) - circuit_other(x))
```

For a list of available passes, please see the [catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes).
For a list of available passes, please see the
[catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes).

The pass pipeline order and options can be configured *globally* for a
qjit-compiled function, by using the `circuit_transform_pipeline` argument of the :func:`~.qjit` decorator.
The pass pipeline order and options can be configured *globally* for a qjit-compiled
function, by using the `circuit_transform_pipeline` argument of the :func:`~.qjit`
decorator.

```python
my_passes = {
Expand All @@ -114,11 +117,13 @@
Available MLIR passes are now documented and available within the
[catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes).

* Catalyst Autograph now supports updating a single index or a slice of JAX arrays using Python's array assignment operator syntax.
* Catalyst Autograph now supports updating a single index or a slice of JAX arrays using Python's
array assignment operator syntax.
[(#769)](https://github.com/PennyLaneAI/catalyst/pull/769)
[(#1143)](https://github.com/PennyLaneAI/catalyst/pull/1143)

Using operator assignment syntax in favor of `at...op` expressions is now possible for the following operations:
Using operator assignment syntax in favor of `at...op` expressions is now possible for the
following operations:
* `x[i] += y` in favor of `x.at[i].add(y)`
* `x[i] -= y` in favor of `x.at[i].add(-y)`
* `x[i] *= y` in favor of `x.at[i].multiply(y)`
Expand Down Expand Up @@ -154,8 +159,9 @@
in the program in tensors, including scalars, leading to unnecessary memory allocations for
programs compiled to CPU via MLIR to LLVM pipeline.

* Bufferization of `gradient.ForwardOp` and `gradient.ReverseOp` now requires 3 steps: `gradient-preprocessing`,
`gradient-bufferize`, and `gradient-postprocessing`. `gradient-bufferize` has a new rewrite for `gradient.ReturnOp`.
* Bufferization of `gradient.ForwardOp` and `gradient.ReverseOp` now requires 3 steps:
`gradient-preprocessing`, `gradient-bufferize`, and `gradient-postprocessing`.
`gradient-bufferize` has a new rewrite for `gradient.ReturnOp`.
[(#1139)](https://github.com/PennyLaneAI/catalyst/pull/1139)

* The decorator `self_inverses` now supports all Hermitian Gates.
Expand Down Expand Up @@ -196,9 +202,12 @@

<h3>Bug fixes</h3>

* Resolve a bug in the `vmap` function when passing shapeless values to the target.
* Resolve a bug in the `vmap` function when passing shapeless values to the target.
[(#1150)](https://github.com/PennyLaneAI/catalyst/pull/1150)

* Fix error message displayed when using `qml.cond` on callables with arguments.
[(#1151)](https://github.com/PennyLaneAI/catalyst/pull/1151)

<h3>Internal changes</h3>

* Update Enzyme to version `v0.0.149`.
Expand Down
7 changes: 4 additions & 3 deletions frontend/catalyst/api_extensions/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# pylint: disable=too-many-lines

import inspect
from typing import Any, Callable, List

import jax
Expand Down Expand Up @@ -236,7 +237,7 @@ def conditional_fn():
"""

def _decorator(true_fn: Callable):
if true_fn.__code__.co_argcount != 0:
if len(inspect.signature(true_fn).parameters):
raise TypeError("Conditional 'True' function is not allowed to have any arguments")
return CondCallable(pred, true_fn)

Expand Down Expand Up @@ -584,7 +585,7 @@ def else_if(self, pred):
"""

def decorator(branch_fn):
if branch_fn.__code__.co_argcount != 0:
if len(inspect.signature(branch_fn).parameters):
raise TypeError(
"Conditional 'else if' function is not allowed to have any arguments"
)
Expand All @@ -603,7 +604,7 @@ def otherwise(self, otherwise_fn):
Returns:
self
"""
if otherwise_fn.__code__.co_argcount != 0:
if len(inspect.signature(otherwise_fn).parameters):
raise TypeError("Conditional 'False' function is not allowed to have any arguments")
self.otherwise_fn = otherwise_fn
return self
Expand Down
13 changes: 13 additions & 0 deletions frontend/test/pytest/test_conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,19 @@ def conditional_flip():
assert circuit(False) == 0
assert circuit(True) == 1

def test_argument_error_with_callables(self):
"""Test for the error when arguments are supplied and the target is not a function."""

@qml.qnode(qml.device("lightning.qubit", wires=1))
def f(x: int):

qml.cond(x < 5, qml.Hadamard)(0)

return qml.probs()

with pytest.raises(TypeError, match="not allowed to have any arguments"):
qjit(f)


class TestInterpretationConditional:
"""Test that the conditional operation's execution is semantically equivalent
Expand Down

0 comments on commit f122439

Please sign in to comment.