Skip to content

Commit

Permalink
fix: Fix generic array functions (#630)
Browse files Browse the repository at this point in the history
See #629 for context.

This PR changes the lowering of classical arrays to use the same
`Option` trick that we use for linear arrays. This fixes #629 in the
short-term. Longer-term we should do #628 to address the problem
properly
  • Loading branch information
mark-koch authored Nov 14, 2024
1 parent 7519b90 commit f4e5655
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 9 deletions.
6 changes: 4 additions & 2 deletions guppylang/std/_internal/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,10 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
assert isinstance(len_arg, ConstArg)
if not self._is_numeric_or_bool_type(ty_arg.ty):
raise GuppyError(err)
base_ty = ty_arg.ty
array_len = len_arg.const
_base_ty = ty_arg.ty
_array_len = len_arg.const
# See https://github.com/CQCL/guppylang/issues/631
raise GuppyError(UnsupportedError(value, "Array results"))
else:
raise GuppyError(err)
node = ResultExpr(value, base_ty, array_len, tag.value)
Expand Down
15 changes: 11 additions & 4 deletions guppylang/std/_internal/compiler/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ class NewArrayCompiler(ArrayCompiler):

def build_classical_array(self, elems: list[Wire]) -> Wire:
"""Lowers a call to `array.__new__` for classical arrays."""
return self.builder.add_op(array_new(self.elem_ty, len(elems)), *elems)
# See https://github.com/CQCL/guppylang/issues/629
return self.build_linear_array(elems)

def build_linear_array(self, elems: list[Wire]) -> Wire:
"""Lowers a call to `array.__new__` for linear arrays."""
Expand All @@ -121,9 +122,12 @@ class ArrayGetitemCompiler(ArrayCompiler):

def build_classical_getitem(self, array: Wire, idx: Wire) -> CallReturnWires:
"""Lowers a call to `array.__getitem__` for classical arrays."""
# See https://github.com/CQCL/guppylang/issues/629
elem_opt_ty = ht.Option(self.elem_ty)
idx = self.builder.add_op(convert_itousize(), idx)
result = self.builder.add_op(array_get(self.elem_ty, self.length), array, idx)
elem = build_unwrap(self.builder, result, "Array index out of bounds")
result = self.builder.add_op(array_get(elem_opt_ty, self.length), array, idx)
elem_opt = build_unwrap(self.builder, result, "Array index out of bounds")
elem = build_unwrap(self.builder, elem_opt, "array.__getitem__: Internal error")
return CallReturnWires(regular_returns=[elem], inout_returns=[array])

def build_linear_getitem(self, array: Wire, idx: Wire) -> CallReturnWires:
Expand Down Expand Up @@ -163,9 +167,12 @@ def build_classical_setitem(
self, array: Wire, idx: Wire, elem: Wire
) -> CallReturnWires:
"""Lowers a call to `array.__setitem__` for classical arrays."""
# See https://github.com/CQCL/guppylang/issues/629
elem_opt_ty = ht.Option(self.elem_ty)
idx = self.builder.add_op(convert_itousize(), idx)
elem_opt = self.builder.add_op(ops.Tag(1, elem_opt_ty), elem)
result = self.builder.add_op(
array_set(self.elem_ty, self.length), array, idx, elem
array_set(elem_opt_ty, self.length), array, idx, elem_opt
)
# Unwrap the result, but we don't have to hold onto the returned old value
_, array = build_unwrap_right(self.builder, result, "Array index out of bounds")
Expand Down
5 changes: 2 additions & 3 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,8 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type:

# Linear elements are turned into an optional to enable unsafe indexing.
# See `ArrayGetitemCompiler` for details.
elem_ty = (
ht.Option(ty_arg.ty.to_hugr()) if ty_arg.ty.linear else ty_arg.ty.to_hugr()
)
# Same also for classical arrays, see https://github.com/CQCL/guppylang/issues/629
elem_ty = ht.Option(ty_arg.ty.to_hugr())

array = hugr.std.PRELUDE.get_type("array")
return array.instantiate([len_arg.to_hugr(), ht.TypeTypeArg(elem_ty)])
Expand Down
20 changes: 20 additions & 0 deletions tests/integration/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,26 @@ def main(a: A @owned, i: int, j: int, k: int) -> A:

validate(module.compile())


def test_generic_function(validate):
module = GuppyModule("test")
module.load(qubit)
T = guppy.type_var("T", linear=True, module=module)
n = guppy.nat_var("n", module=module)

@guppy(module)
def foo(xs: array[T, n] @owned) -> array[T, n]:
return xs

@guppy(module)
def main() -> tuple[array[int, 3], array[qubit, 2]]:
xs = array(1, 2, 3)
ys = array(qubit(), qubit())
return foo(xs), foo(ys)

validate(module.compile())


def test_exec_array(validate, run_int_fn):
module = GuppyModule("test")

Expand Down
3 changes: 3 additions & 0 deletions tests/integration/test_result.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from guppylang.std.builtins import result, nat, array
from tests.util import compile_guppy

Expand All @@ -21,6 +23,7 @@ def main(w: nat, x: int, y: float, z: bool) -> None:
validate(main)


@pytest.mark.skip("See https://github.com/CQCL/guppylang/issues/631")
def test_array(validate):
@compile_guppy
def main(
Expand Down

0 comments on commit f4e5655

Please sign in to comment.