From 202ea3467490c5023925ae4df4b5c6faf4c8a9a0 Mon Sep 17 00:00:00 2001 From: Ryan Shaffer <3620100+rmshaffer@users.noreply.github.com> Date: Tue, 4 Jun 2024 16:04:14 -0400 Subject: [PATCH] Fix multi-dimensional ArrayVar declaration --- src/autoqasm/types/types.py | 3 ++- test/unit_tests/autoqasm/test_types.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/autoqasm/types/types.py b/src/autoqasm/types/types.py index 1963049..11bcf2c 100644 --- a/src/autoqasm/types/types.py +++ b/src/autoqasm/types/types.py @@ -18,6 +18,7 @@ from collections.abc import Iterable from typing import Any, List, Union, get_args +import numpy as np import oqpy import oqpy.base from braket.circuits import FreeParameterExpression @@ -114,7 +115,7 @@ def __init__( if not isinstance(init_expression, Iterable): raise errors.InvalidArrayDeclaration("init_expression must be an iterable type.") - dimensions = [len(init_expression)] + dimensions = np.shape(init_expression) super(ArrayVar, self).__init__( init_expression=init_expression, *args, diff --git a/test/unit_tests/autoqasm/test_types.py b/test/unit_tests/autoqasm/test_types.py index ec4ff93..28c56df 100644 --- a/test/unit_tests/autoqasm/test_types.py +++ b/test/unit_tests/autoqasm/test_types.py @@ -756,9 +756,9 @@ def declare_array(): def test_array_supports_multidimensional_arrays(): @aq.main def declare_array(): - aq.ArrayVar([[1, 2], [3, 4]]) + a = aq.ArrayVar([[1, 2, 3], [4, 5, 6]]) # noqa: F841 expected = """OPENQASM 3.0; -array[int[32], 2, 2] a = {{1, 2}, {3, 4}};""" +array[int[32], 2, 3] a = {{1, 2, 3}, {4, 5, 6}};""" - declare_array.build().to_ir() == expected + assert declare_array.build().to_ir() == expected