Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: add support for typecasting #27

Merged
merged 14 commits into from
Jun 13, 2024
64 changes: 64 additions & 0 deletions src/autoqasm/converters/typecast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.


"""Converters for integer casting nodes."""

import ast

from malt.core import ag_ctx, converter
from malt.pyct import templates


class TypecastTransformer(converter.Base):
def visit_Call(self, node: ast.stmt) -> ast.stmt:
"""Converts type casting operations to their AutoQASM counterpart.

Args:
node (ast.stmt): AST node to transform.

Returns:
ast.stmt: Transformed node.
"""
typecasts_supported = ["int"]
node = self.generic_visit(node)

if (
hasattr(node, "func")
and hasattr(node.func, "id")
and node.func.id in typecasts_supported
):
template = f"ag__.{node.func.id}_typecast(argument_)"
new_node = templates.replace(
template,
argument_=node.args,
original=node,
)
new_node = new_node[0].value
else:
new_node = node
return new_node


def transform(node: ast.stmt, ctx: ag_ctx.ControlStatusCtx) -> ast.stmt:
"""Transform int cast nodes.

Args:
node (ast.stmt): AST node to transform.
ctx (ag_ctx.ControlStatusCtx): Transformer context.

Returns:
ast.stmt: Transformed node.
"""

return TypecastTransformer(ctx).visit(node)
1 change: 1 addition & 0 deletions src/autoqasm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@
from .logical import or_ # noqa: F401
from .return_statements import return_output_from_main # noqa: F401
from .slices import GetItemOpts, get_item, set_item # noqa: F401
from .typecast import int_typecast # noqa: F401
36 changes: 36 additions & 0 deletions src/autoqasm/operators/typecast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.


"""Operators for int cast statements."""
from __future__ import annotations

from typing import Any

from autoqasm import types as aq_types


def int_typecast(argument_: Any, *args, **kwargs) -> aq_types.IntVar | int:
abidart marked this conversation as resolved.
Show resolved Hide resolved
"""Operator declares the `oq` variable, or sets variable's value if it's
already declared.

Args:
argument_ (Any): object to be converted into an int.

Returns:
IntVar | int : IntVar object if argument is QASM type, else int.
"""
if aq_types.is_qasm_type(argument_):
return aq_types.IntVar(argument_)
rmshaffer marked this conversation as resolved.
Show resolved Hide resolved
else:
return int(argument_, *args, **kwargs)
9 changes: 8 additions & 1 deletion src/autoqasm/transpiler/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,13 @@
from malt.utils import ag_logging as logging

from autoqasm import operators, program, types
from autoqasm.converters import assignments, break_statements, comparisons, return_statements
from autoqasm.converters import (
assignments,
break_statements,
comparisons,
return_statements,
typecast,
)


class PyToOqpy(transpiler.PyToPy):
Expand Down Expand Up @@ -128,6 +134,7 @@ def transform_ast(
# canonicalization creates.
node = continue_statements.transform(node, ctx)
node = return_statements.transform(node, ctx)
node = typecast.transform(node, ctx)
node = assignments.transform(node, ctx)
node = lists.transform(node, ctx)
node = slices.transform(node, ctx)
Expand Down
40 changes: 40 additions & 0 deletions test/unit_tests/autoqasm/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,3 +1266,43 @@ def test(int[32] a, int[32] b) {
test(2, 3);
test(4, 5);"""
assert main.build().to_ir() == expected


class TestTypecasting:
def test_int_typecasting_on_measure(self):
@aq.main(num_qubits=2)
def main():
test = int(measure([0, 1])) # noqa: F841

expected_ir = """OPENQASM 3.0;
qubit[2] __qubits__;
bit[2] __bit_0__ = "00";
__bit_0__[0] = measure __qubits__[0];
__bit_0__[1] = measure __qubits__[1];
int[32] test = __bit_0__;"""
abidart marked this conversation as resolved.
Show resolved Hide resolved
assert main.build().to_ir() == expected_ir

def test_int_typecasting_on_string(self):
@aq.main(num_qubits=2)
def main():
test = int("101", 2) # noqa: F841

expected_ir = """OPENQASM 3.0;
qubit[2] __qubits__;"""
abidart marked this conversation as resolved.
Show resolved Hide resolved
assert main.build().to_ir() == expected_ir

def test_nested_int_typecasting(self):
@aq.main(num_qubits=2)
def main():
test = 2 * int(measure([0, 1])) # noqa: F841
return test
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be worth having two versions of this test, one without the return and one with the return. In both cases the test variable should be declared and assigned - it shouldn't matter whether it is used afterward.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added! Could you confirm that in the case where there is no return statement, there should not be an instruction involving the 2 *. Currently, I am getting:

with return

OPENQASM 3.0;
output int[32] test;
qubit[2] __qubits__;
bit[2] __bit_0__ = "00";
__bit_0__[0] = measure __qubits__[0];
__bit_0__[1] = measure __qubits__[1];
int[32] __int_1__;
__int_1__ = __bit_0__;
test = 2 * __int_1__;

without return

OPENQASM 3.0;
qubit[2] __qubits__;
bit[2] __bit_0__ = "00";
__bit_0__[0] = measure __qubits__[0];
__bit_0__[1] = measure __qubits__[1];
int[32] __int_1__;
__int_1__ = __bit_0__;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's interesting - I still would expect the test variable to be declared and assigned 2 * __int_1__ there. But functionally it doesn't matter, of course. I expect this would be an issue in the assignment converter and outside the scope of your PR. It shouldn't block moving forward with your changes.

Copy link
Contributor Author

@abidart abidart Jun 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I can look into this in a new issue/PR? I noticed that

@aq.main(num_qubits=2)
def main():
    test = 4 * IntVar(2)

results in

OPENQASM 3.0;
qubit[2] __qubits__;

but

@aq.main(num_qubits=2)
def main():
    test = IntVar(2)

leads to

OPENQASM 3.0;
qubit[2] __qubits__;
int[32] test = 2;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, if you'd like to look into this separately, that would be great! I'd recommend opening a new issue with those details.


expected_ir = """OPENQASM 3.0;
int[32] __int_1__ = __bit_0__;
rmshaffer marked this conversation as resolved.
Show resolved Hide resolved
output int[32] test;
qubit[2] __qubits__;
bit[2] __bit_0__ = "00";
__bit_0__[0] = measure __qubits__[0];
__bit_0__[1] = measure __qubits__[1];
test = 2 * __int_1__;"""
assert main.build().to_ir() == expected_ir