Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: add support for typecasting #27

Merged
merged 14 commits into from
Jun 13, 2024
65 changes: 65 additions & 0 deletions src/autoqasm/converters/typecast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.


"""Converters for integer casting nodes."""

import ast

import gast
from malt.core import ag_ctx, converter
from malt.pyct import templates


class TypecastTransformer(converter.Base):
def visit_Call(self, node: ast.stmt) -> ast.stmt:
"""Converts type casting operations to their AutoQASM counterpart.

Args:
node (ast.stmt): AST node to transform.

Returns:
ast.stmt: Transformed node.
"""
typecasts_supported = ["int"]
node = self.generic_visit(node)

if (
hasattr(node, "func")
and hasattr(node.func, "id")
and node.func.id in typecasts_supported
):
template = f"ag__.{node.func.id}_typecast(argument_)"
new_node = templates.replace(
template,
argument_=node.args,
original=node,
)
new_node = new_node[0].value
else:
new_node = node
return new_node


def transform(node: ast.stmt, ctx: ag_ctx.ControlStatusCtx) -> ast.stmt:
"""Transform int cast nodes.

Args:
node (ast.stmt): AST node to transform.
ctx (ag_ctx.ControlStatusCtx): Transformer context.

Returns:
ast.stmt: Transformed node.
"""

return TypecastTransformer(ctx).visit(node)
1 change: 1 addition & 0 deletions src/autoqasm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@
from .logical import or_ # noqa: F401
from .return_statements import return_output_from_main # noqa: F401
from .slices import GetItemOpts, get_item, set_item # noqa: F401
from .typecast import int_typecast # noqa: F401
35 changes: 35 additions & 0 deletions src/autoqasm/operators/typecast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.


"""Operators for int cast statements."""

from typing import Any

from autoqasm import types as aq_types


def int_typecast(argument_: Any, *args, **kwargs) -> aq_types.IntVar | int:
abidart marked this conversation as resolved.
Show resolved Hide resolved
"""Operator declares the `oq` variable, or sets variable's value if it's
already declared.

Args:
argument_ (Any): object to be converted into an int.

Returns:
IntVar | int : IntVar object if argument is QASM type, else int.
"""
if aq_types.is_qasm_type(argument_):
return aq_types.IntVar(argument_)
rmshaffer marked this conversation as resolved.
Show resolved Hide resolved
else:
return int(argument_, *args, **kwargs)
9 changes: 8 additions & 1 deletion src/autoqasm/transpiler/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,13 @@
from malt.utils import ag_logging as logging

from autoqasm import operators, program, types
from autoqasm.converters import assignments, break_statements, comparisons, return_statements
from autoqasm.converters import (
assignments,
break_statements,
comparisons,
return_statements,
typecast,
)


class PyToOqpy(transpiler.PyToPy):
Expand Down Expand Up @@ -128,6 +134,7 @@ def transform_ast(
# canonicalization creates.
node = continue_statements.transform(node, ctx)
node = return_statements.transform(node, ctx)
node = typecast.transform(node, ctx)
node = assignments.transform(node, ctx)
node = lists.transform(node, ctx)
node = slices.transform(node, ctx)
Expand Down
36 changes: 36 additions & 0 deletions test/unit_tests/autoqasm/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,3 +1266,39 @@ def test(int[32] a, int[32] b) {
test(2, 3);
test(4, 5);"""
assert main.build().to_ir() == expected


class TestTypecasting:
def test_int_typecasting_on_measure(self):
@aq.main(num_qubits=2)
def main():
test = int(measure([0, 1]))
abidart marked this conversation as resolved.
Show resolved Hide resolved

expected_ir = """OPENQASM 3.0;
qubit[2] __qubits__;
bit[2] __bit_0__ = "00";
__bit_0__[0] = measure __qubits__[0];
__bit_0__[1] = measure __qubits__[1];
int[32] test = __bit_0__;"""
abidart marked this conversation as resolved.
Show resolved Hide resolved
assert main.build().to_ir() == expected_ir

def test_int_typecasting_on_string(self):
@aq.main(num_qubits=2)
def main():
test = int("101", 2)
abidart marked this conversation as resolved.
Show resolved Hide resolved

expected_ir = """OPENQASM 3.0;
qubit[2] __qubits__;"""
abidart marked this conversation as resolved.
Show resolved Hide resolved
assert main.build().to_ir() == expected_ir

def test_nested_int_typecasting(self):
@aq.main(num_qubits=2)
def main():
test = 2 * int(measure([0, 1]))
abidart marked this conversation as resolved.
Show resolved Hide resolved

expected_ir = """OPENQASM 3.0;
qubit[2] __qubits__;
bit[2] __bit_0__ = "00";
__bit_0__[0] = measure __qubits__[0];
__bit_0__[1] = measure __qubits__[1];"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't look right. I would expect the output IR to contain something like:

int[32] __int_1__ = __bit_0__;
int[32] test = 2 * __int_1__;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For debugging, might be interesting to know what the IR output is if you modify the program to be like this:

        @aq.main(num_qubits=2)
        def main():
            test = int(2 * measure([0, 1]))

or like this:

        @aq.main(num_qubits=2)
        def main():
            test = 2 * int(measure([0, 1]))
            return test

Copy link
Contributor Author

@abidart abidart Jun 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked and using

@aq.main(num_qubits=2)
def main():
    test = 2 * IntVar(measure([0, 1]))

results in the same IR. I think the discrepancy came from the fact that the variable test was not being used after declared. Your second suggestion helped me realize that adding return test (or other lines that make use of test) results in an IR with the int[32] lines. I updated the unit test to reflect this.

On a second note, I did not add support for arithmetic operations on BitVara in this PR because I focused on int typecasting. To add the arithmetic operations, would you suggest I add methods like __mul__ and __rmul__ to the BitVar class (where self gets casted to an IntVar) or use a converter to intercept the arithmetic operations and cast the BitVars to IntVars?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, good question. I think it's ok to leave out the arithmetic operations from this PR. I think it's ok (and maybe desirable) that the user must manually cast the BitVar to an int before doing arithmetic on it.

assert main.build().to_ir() == expected_ir
Loading