Skip to content

Commit

Permalink
change: get dimensions for aq.ArrayVar automatically (#22)
Browse files Browse the repository at this point in the history
* change: get dimensions for aq.ArrayVar automatically

reformat

* change: add init_expression to ArrayVar explicitly

* Add test for multidimensional array support

remove print

* Add type hint for init_expression

* Add type hint test for init_expression
  • Loading branch information
Yash-10 authored Jun 3, 2024
1 parent ecce2bc commit 32b2524
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 9 deletions.
19 changes: 17 additions & 2 deletions src/autoqasm/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,31 @@ def __init__(


class ArrayVar(oqpy.ArrayVar):
def __init__(self, *args, annotations: str | Iterable[str] | None = None, **kwargs):
def __init__(
self,
init_expression: Iterable,
*args,
annotations: str | Iterable[str] | None = None,
**kwargs,
):
if (
program.get_program_conversion_context().subroutines_processing
or not program.get_program_conversion_context().at_function_root_scope
):
raise errors.InvalidArrayDeclaration(
"Arrays may only be declared at the root scope of an AutoQASM main function."
)

if not isinstance(init_expression, Iterable):
raise errors.InvalidArrayDeclaration("init_expression must be an iterable type.")

dimensions = [len(init_expression)]
super(ArrayVar, self).__init__(
*args, annotations=make_annotations_list(annotations), **kwargs
init_expression=init_expression,
*args,
annotations=make_annotations_list(annotations),
dimensions=dimensions,
**kwargs,
)
self.name = program.get_program_conversion_context().next_var_name(oqpy.ArrayVar)

Expand Down
52 changes: 45 additions & 7 deletions test/unit_tests/autoqasm/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,9 @@ def test_declare_array():

@aq.main
def declare_array():
a = aq.ArrayVar([1, 2, 3], base_type=aq.IntVar, dimensions=[3])
a = aq.ArrayVar([1, 2, 3], base_type=aq.IntVar)
a[0] = 11
b = aq.ArrayVar([4, 5, 6], base_type=aq.IntVar, dimensions=[3])
b = aq.ArrayVar([4, 5, 6], base_type=aq.IntVar)
b[2] = 14
b = a

Expand All @@ -207,8 +207,8 @@ def test_invalid_array_assignment():

@aq.main
def invalid():
a = aq.ArrayVar([1, 2, 3], base_type=aq.IntVar, dimensions=[3])
b = aq.ArrayVar([4, 5], base_type=aq.IntVar, dimensions=[2])
a = aq.ArrayVar([1, 2, 3], base_type=aq.IntVar)
b = aq.ArrayVar([4, 5], base_type=aq.IntVar)
a = b # noqa: F841

with pytest.raises(aq.errors.InvalidAssignmentStatement):
Expand All @@ -221,7 +221,7 @@ def test_declare_array_in_local_scope():
@aq.main
def declare_array():
if aq.BoolVar(True):
_ = aq.ArrayVar([1, 2, 3], base_type=aq.IntVar, dimensions=[3])
_ = aq.ArrayVar([1, 2, 3], base_type=aq.IntVar)

with pytest.raises(aq.errors.InvalidArrayDeclaration):
declare_array.build()
Expand All @@ -236,7 +236,7 @@ def main() -> list[int]:

@aq.subroutine
def declare_array():
_ = aq.ArrayVar([1, 2, 3], dimensions=[3])
_ = aq.ArrayVar([1, 2, 3])

with pytest.raises(aq.errors.InvalidArrayDeclaration):
main.build()
Expand Down Expand Up @@ -383,7 +383,7 @@ def annotation_test(input: list[int]):

@aq.main
def main():
a = aq.ArrayVar([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dimensions=[10])
a = aq.ArrayVar([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
annotation_test(a)

with pytest.raises(aq.errors.ParameterTypeError):
Expand Down Expand Up @@ -724,3 +724,41 @@ def main():

with pytest.raises(aq.errors.ParameterTypeError):
main.build()


def test_array_does_not_accept_dimensions_argument():
@aq.main
def declare_array():
aq.ArrayVar([1, 2, 3], base_type=aq.IntVar, dimensions=[3])

with pytest.raises(TypeError):
declare_array.build()


def test_array_requires_init_expression():
@aq.main
def declare_array():
aq.ArrayVar()

with pytest.raises(TypeError):
declare_array.build()


def test_array_init_expression_type():
@aq.main
def declare_array():
aq.ArrayVar(1)

with pytest.raises(aq.errors.InvalidArrayDeclaration):
declare_array.build()


def test_array_supports_multidimensional_arrays():
@aq.main
def declare_array():
aq.ArrayVar([[1, 2], [3, 4]])

expected = """OPENQASM 3.0;
array[int[32], 2, 2] a = {{1, 2}, {3, 4}};"""

declare_array.build().to_ir() == expected

0 comments on commit 32b2524

Please sign in to comment.