Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: support integer division #29

Merged
merged 9 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions src/autoqasm/converters/arithmetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""Converters for aritmetic operator nodes"""

import ast

import gast
from malt.core import ag_ctx, converter
from malt.pyct import templates

ARITHMETIC_OPERATORS = {
gast.FloorDiv: "ag__.fd_",
}


class ArithmeticTransformer(converter.Base):
"""Transformer for arithmetic nodes."""

def visit_BinOp(self, node: ast.stmt) -> ast.stmt:
"""Transforms a BinOp node.
Args :
node(ast.stmt) : AST node to transform
Returns :
ast.stmt : Transformed node
"""
node = self.generic_visit(node)
op_type = type(node.op)
if op_type not in ARITHMETIC_OPERATORS:
return node

template = f"{ARITHMETIC_OPERATORS[op_type]}(lhs_,rhs_)"

new_node = templates.replace(
template,
lhs_=node.left,
rhs_=node.right,
original=node,
)[0].value

return new_node


def transform(node: ast.stmt, ctx: ag_ctx.ControlStatusCtx) -> ast.stmt:
"""Transform arithmetic nodes.
Args:
node(ast.stmt) : AST node to transform
ctx (ag_ctx.ControlStatusCtx) : Transformer context.
Returns :
ast.stmt : Transformed node.
"""

return ArithmeticTransformer(ctx).visit(node)
1 change: 1 addition & 0 deletions src/autoqasm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from malt.impl.api import autograph_artifact # noqa: F401
from malt.operators.variables import Undefined, UndefinedReturnValue, ld, ldu # noqa: F401

from .arithmetic import fd_ # noqa: F401
from .assignments import assign_for_output, assign_stmt # noqa: F401
from .comparisons import gt_, gteq_, lt_, lteq_ # noqa: F401
from .conditional_expressions import if_exp # noqa: F401
Expand Down
74 changes: 74 additions & 0 deletions src/autoqasm/operators/arithmetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""Operators for arithmetic operators: // """

from __future__ import annotations

from autoqasm import program
from autoqasm import types as aq_types

from .utils import _register_and_convert_parameters


def fd_(
jcjaskula-aws marked this conversation as resolved.
Show resolved Hide resolved
num_: aq_types.IntVar | aq_types.FloatVar | int | float,
jcjaskula-aws marked this conversation as resolved.
Show resolved Hide resolved
den_: aq_types.IntVar | aq_types.FloatVar | int | float,
) -> int | aq_types.IntVar:
"""Functional form of "//".
Args:
num_ (aq_types.IntVar | aq_types.FloatVar | int | float) :
The numerator of the integer division
den_ (aq_types.IntVar | aq_types.FloatVar | int | float) :
The denominator of the integer division
Returns :
int | IntVar : integer division, IntVar if either numerator or denominator
are QASM types, else int
"""
if aq_types.is_qasm_type(num_) or aq_types.is_qasm_type(den_):
return _oqpy_fd(num_, den_)
else:
return _py_fd(num_, den_)


def _oqpy_fd(
num_: aq_types.IntVar | aq_types.FloatVar,
den_: aq_types.IntVar | aq_types.FloatVar,
abidart marked this conversation as resolved.
Show resolved Hide resolved
) -> aq_types.IntVar:
num_, den_ = _register_and_convert_parameters(num_, den_)
oqpy_program = program.get_program_conversion_context().get_oqpy_program()
num_is_float = isinstance(num_, aq_types.FloatVar)
den_is_float = isinstance(den_, aq_types.FloatVar)
abidart marked this conversation as resolved.
Show resolved Hide resolved

# if they are of different types, then one must cast to FloatVar
if num_is_float or den_is_float:
if num_is_float:
float_var = aq_types.FloatVar()
oqpy_program.declare(float_var)
oqpy_program.set(float_var, den_)

Check warning on line 58 in src/autoqasm/operators/arithmetic.py

View check run for this annotation

Codecov / codecov/patch

src/autoqasm/operators/arithmetic.py#L56-L58

Added lines #L56 - L58 were not covered by tests
if den_is_float:
float_var = aq_types.FloatVar()
oqpy_program.declare(float_var)
oqpy_program.set(float_var, num_)
abidart marked this conversation as resolved.
Show resolved Hide resolved

result = aq_types.IntVar()
oqpy_program.declare(result)
oqpy_program.set(result, num_ / den_)
return result


def _py_fd(
num_: int | float,
den_: int | float,
) -> int:
return num_ // den_
9 changes: 8 additions & 1 deletion src/autoqasm/transpiler/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,13 @@
from malt.utils import ag_logging as logging

from autoqasm import operators, program, types
from autoqasm.converters import assignments, break_statements, comparisons, return_statements
from autoqasm.converters import (
arithmetic,
assignments,
break_statements,
comparisons,
return_statements,
)


class PyToOqpy(transpiler.PyToPy):
Expand Down Expand Up @@ -135,6 +141,7 @@ def transform_ast(
node = control_flow.transform(node, ctx)
node = conditional_expressions.transform(node, ctx)
node = comparisons.transform(node, ctx)
node = arithmetic.transform(node, ctx)
node = logical_expressions.transform(node, ctx)
node = variables.transform(node, ctx)

Expand Down
13 changes: 13 additions & 0 deletions src/autoqasm/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from openpulse import ast

from autoqasm import errors, program
from autoqasm.errors import UnsupportedFeatureError


def is_qasm_type(val: Any) -> bool:
Expand Down Expand Up @@ -152,10 +153,22 @@
)
self.name = program.get_program_conversion_context().next_var_name(oqpy.FloatVar)

def __floordiv__(self, other):
raise UnsupportedFeatureError("Integer division is supported by OpenQASM.")

Check warning on line 157 in src/autoqasm/types/types.py

View check run for this annotation

Codecov / codecov/patch

src/autoqasm/types/types.py#L157

Added line #L157 was not covered by tests
rmshaffer marked this conversation as resolved.
Show resolved Hide resolved

def __rfloordiv__(self, other):
raise UnsupportedFeatureError("Integer division is supported by OpenQASM.")

Check warning on line 160 in src/autoqasm/types/types.py

View check run for this annotation

Codecov / codecov/patch

src/autoqasm/types/types.py#L160

Added line #L160 was not covered by tests


class IntVar(oqpy.IntVar):
def __init__(self, *args, annotations: str | Iterable[str] | None = None, **kwargs):
super(IntVar, self).__init__(
*args, annotations=make_annotations_list(annotations), **kwargs
)
self.name = program.get_program_conversion_context().next_var_name(oqpy.IntVar)

def __floordiv__(self, other):
raise UnsupportedFeatureError("Integer division is supported by OpenQASM.")

Check warning on line 171 in src/autoqasm/types/types.py

View check run for this annotation

Codecov / codecov/patch

src/autoqasm/types/types.py#L171

Added line #L171 was not covered by tests

def __rfloordiv__(self, other):
raise UnsupportedFeatureError("Integer division is supported by OpenQASM.")

Check warning on line 174 in src/autoqasm/types/types.py

View check run for this annotation

Codecov / codecov/patch

src/autoqasm/types/types.py#L174

Added line #L174 was not covered by tests
49 changes: 49 additions & 0 deletions test/unit_tests/autoqasm/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,3 +930,52 @@ def test_list_ops():
assert np.array_equal(c, [[2, 3, 4], [2, 3, 4]])

assert test_list_ops.build().to_ir()


class TestFloorDiv:
jcjaskula-aws marked this conversation as resolved.
Show resolved Hide resolved
def test_integer_division_on_intvars(self):
@aq.main(num_qubits=2)
def main():
a = aq.IntVar(5)
b = aq.IntVar(2)
c = a // b # noqa: F841

expected_ir = """OPENQASM 3.0;
int[32] c;
qubit[2] __qubits__;
int[32] a = 5;
int[32] b = 2;
int[32] __int_2__;
__int_2__ = a / b;
c = __int_2__;"""
assert main.build().to_ir() == expected_ir

def test_integer_division_on_mixed_vars(self):
@aq.main(num_qubits=2)
def main():
a = aq.IntVar(5)
b = aq.FloatVar(2.3)
c = a // b # noqa: F841
abidart marked this conversation as resolved.
Show resolved Hide resolved

expected_ir = """OPENQASM 3.0;
int[32] c;
qubit[2] __qubits__;
int[32] a = 5;
float[64] b = 2.3;
float[64] __float_2__;
__float_2__ = a;
int[32] __int_3__;
__int_3__ = a / b;
c = __int_3__;"""
assert main.build().to_ir() == expected_ir
Copy link
Contributor

@rmshaffer rmshaffer Jun 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For full validation of your code's correctness - you could also run this program on a simulator and validate that the results of the floor division are what you expect.

Running on a simulator is pretty simple, see here for an example:

def _test_parametric_on_local_sim(program: aq.Program, inputs: dict[str, float]) -> np.ndarray:
device = LocalSimulator(backend=McmSimulator())
task = device.run(program, shots=100, inputs=inputs)
assert isinstance(task, LocalQuantumTask)
assert isinstance(task.result().measurements, dict)
return task.result().measurements

You don't have inputs, so you could ignore that part. And you'd just grab task.result().measurements which is a dict containing the values of all of the variables in the program.

Copy link
Contributor Author

@abidart abidart Jun 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much, this helped a ton! Should add tests involving the local simulator to test_operators.py as well or was this suggestion mostly for local development?

On the local development note, I realized that every time OpenQASM casts a FloatVar that is not a round number to an IntVar I get a warning coming from openqasm/_helpers/casting.py:96. For example,

@aq.main
def floor_div():
    a = aq.IntVar(2)
    b = 4.3
    c = b // a  # noqa: F841

results in the warning UserWarning: Integer overflow for value 2.15 and size 32.

This is not directly related to this PR since:

@aq.main
def int_to_float():
    a = aq.IntVar(2.15)

nets the same warning. The warning can show up when floats/FloatVars are involved in // because this PR casts the result of the normal division to an IntVar. Do you think I should do something in the context of this PR to silence the warning when the user does integer division or should I leave it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, I see, that's a warning from the simulator here:
https://github.com/amazon-braket/amazon-braket-default-simulator-python/blob/a4d7f98cb123ae6a1092972e728d2dbb93cb27b5/src/braket/default_simulator/openqasm/_helpers/casting.py#L83-L99

It looks like that warning should be modified to account for that case - not sure yet if that change should happen in the autoqasm repo, or in the amazon-braket-default-simulator-python repo. Feel free to open an issue here to track it though, with the details you just provided! It definitely doesn't need to be fixed within this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing, will do!


def test_integer_division_on_python_types(self):
@aq.main(num_qubits=2)
def main():
a = 5
b = 2.3
c = a // b # noqa: F841

expected_ir = """OPENQASM 3.0;
qubit[2] __qubits__;"""
assert main.build().to_ir() == expected_ir
Loading