Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: add support for typecasting #27

Merged
merged 14 commits into from
Jun 13, 2024
69 changes: 69 additions & 0 deletions src/autoqasm/converters/typecast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.


"""Converters for integer casting nodes."""

import ast

import gast
from malt.core import ag_ctx, converter
from malt.pyct import templates


class TypecastTransformer(converter.Base):
def visit_Call(self, node: ast.stmt) -> ast.stmt:
"""Converts type casting operations to their AutoQASM counterpart.

Args:
node (ast.stmt): AST node to transform.

Returns:
ast.stmt: Transformed node.
"""
typecasts_supported = ["int", "float"]

template = """
ag__.typecast(type_, argument_)
"""
rmshaffer marked this conversation as resolved.
Show resolved Hide resolved
node = self.generic_visit(node)
if (
len(node.args) > 1
and hasattr(node.args[1], "func")
and hasattr(node.args[1].func, "id")
and node.args[1].func.id in typecasts_supported
):
new_node = templates.replace(
template,
type_=node.args[1].func.id,
argument_=node.args[1].args,
original=node,
)
new_node = new_node[0].value
else:
new_node = node
return new_node


def transform(node: ast.stmt, ctx: ag_ctx.ControlStatusCtx) -> ast.stmt:
"""Transform int cast nodes.

Args:
node (ast.stmt): AST node to transform.
ctx (ag_ctx.ControlStatusCtx): Transformer context.

Returns:
ast.stmt: Transformed node.
"""

return TypecastTransformer(ctx).visit(node)
1 change: 1 addition & 0 deletions src/autoqasm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@
from .logical import or_ # noqa: F401
from .return_statements import return_output_from_main # noqa: F401
from .slices import GetItemOpts, get_item, set_item # noqa: F401
from .typecast import typecast # noqa: F401
47 changes: 47 additions & 0 deletions src/autoqasm/operators/typecast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.


"""Operators for int cast statements."""

from typing import Any

from autoqasm import types as aq_types


def typecast(type_: type, argument_: Any) -> aq_types.IntVar | int:
"""Operator declares the `oq` variable, or sets variable's value if it's
already declared.

Args:
type_ (type): the type for the conversion
argument_ (Any): object to be converted.

Returns:
IntVar | FloatVar | int | float: IntVar/FloatVar object if argument is QASM type, else int/float.
"""
type_to_aq_type_map = {int: aq_types.IntVar, float: aq_types.FloatVar}
if aq_types.is_qasm_type(argument_):
if (
argument_.size is not None
and argument_.size > 1
and isinstance(argument_, aq_types.BitVar)
):
typecasted_arg = type_to_aq_type_map[type_](argument_[0])
for index in range(1, argument_.size):
typecasted_arg += type_to_aq_type_map[type_](argument_[index]) * 2**index
return typecasted_arg
rmshaffer marked this conversation as resolved.
Show resolved Hide resolved
else:
return type_to_aq_type_map[type_](argument_)
else:
return type_(*argument_)
9 changes: 8 additions & 1 deletion src/autoqasm/transpiler/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,13 @@
from malt.utils import ag_logging as logging

from autoqasm import operators, program, types
from autoqasm.converters import assignments, break_statements, comparisons, return_statements
from autoqasm.converters import (
assignments,
break_statements,
comparisons,
return_statements,
typecast,
)


class PyToOqpy(transpiler.PyToPy):
Expand Down Expand Up @@ -131,6 +137,7 @@ def transform_ast(
node = assignments.transform(node, ctx)
node = lists.transform(node, ctx)
node = slices.transform(node, ctx)
node = typecast.transform(node, ctx)
rmshaffer marked this conversation as resolved.
Show resolved Hide resolved
node = call_trees.transform(node, ctx)
node = control_flow.transform(node, ctx)
node = conditional_expressions.transform(node, ctx)
Expand Down