Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for float comparison in ASTBridge #2489

Merged
merged 7 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 70 additions & 50 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3374,76 +3374,96 @@ def visit_Compare(self, node):
self.visit(node.left)
left = self.popValue()
self.visit(node.comparators[0])
comparator = self.popValue()
right = self.popValue()
op = node.ops[0]

if isinstance(op, ast.Gt):
if IntegerType.isinstance(left.type):
if F64Type.isinstance(comparator.type):
self.emitFatalError(
"invalid rhs for comparison (f64 type and not i64 type).",
node)
left_type = left.type
right_type = right.type

if IntegerType.isinstance(left_type) and F64Type.isinstance(right_type):
left = arith.SIToFPOp(self.getFloatType(), left).result
elif F64Type.isinstance(left_type) and IntegerType.isinstance(
right_type):
right = arith.SIToFPOp(self.getFloatType(), right).result
elif IntegerType.isinstance(left_type) and IntegerType.isinstance(
right_type):
if IntegerType(left_type).width < IntegerType(right_type).width:
zeroext = IntegerType(left_type).width == 1
left = cc.CastOp(right_type,
left,
sint=not zeroext,
zint=zeroext).result
elif IntegerType(left_type).width > IntegerType(right_type).width:
zeroext = IntegerType(right_type).width == 1
right = cc.CastOp(left_type,
right,
sint=not zeroext,
zint=zeroext).result

self.pushValue(
arith.CmpIOp(self.getIntegerAttr(iTy, 4), left,
comparator).result)
elif F64Type.isinstance(left.type):
if IntegerType.isinstance(comparator.type):
comparator = arith.SIToFPOp(self.getFloatType(),
comparator).result
if isinstance(op, ast.Gt):
if F64Type.isinstance(left.type):
self.pushValue(
arith.CmpFOp(self.getIntegerAttr(iTy, 2), left,
comparator).result)
right).result)
else:
self.pushValue(
arith.CmpIOp(self.getIntegerAttr(iTy, 4), left,
right).result)
return

if isinstance(op, ast.GtE):
self.pushValue(
arith.CmpIOp(self.getIntegerAttr(iTy, 5), left,
comparator).result)
if F64Type.isinstance(left.type):
self.pushValue(
arith.CmpFOp(self.getIntegerAttr(iTy, 3), left,
right).result)
else:
self.pushValue(
arith.CmpIOp(self.getIntegerAttr(iTy, 5), left,
right).result)
return

if isinstance(op, ast.Lt):
self.pushValue(
arith.CmpIOp(self.getIntegerAttr(iTy, 2), left,
comparator).result)
if F64Type.isinstance(left.type):
self.pushValue(
arith.CmpFOp(self.getIntegerAttr(iTy, 4), left,
right).result)
else:
self.pushValue(
arith.CmpIOp(self.getIntegerAttr(iTy, 2), left,
right).result)
return

if isinstance(op, ast.LtE):
self.pushValue(
arith.CmpIOp(self.getIntegerAttr(iTy, 7), left,
comparator).result)
if F64Type.isinstance(left.type):
self.pushValue(
arith.CmpFOp(self.getIntegerAttr(iTy, 5), left,
right).result)
else:
self.pushValue(
arith.CmpIOp(self.getIntegerAttr(iTy, 7), left,
right).result)
return

if isinstance(op, ast.NotEq):
if F64Type.isinstance(left.type) and IntegerType.isinstance(
comparator.type):
left = arith.FPToSIOp(comparator.type, left).result
if IntegerType(left.type).width < IntegerType(
comparator.type).width:
zeroext = IntegerType(left.type).width == 1
left = cc.CastOp(comparator.type,
left,
sint=not zeroext,
zint=zeroext).result
self.pushValue(
arith.CmpIOp(self.getIntegerAttr(iTy, 1), left,
comparator).result)
if F64Type.isinstance(left.type):
self.pushValue(
arith.CmpFOp(self.getIntegerAttr(iTy, 6), left,
right).result)
else:
self.pushValue(
arith.CmpIOp(self.getIntegerAttr(iTy, 1), left,
right).result)
return

if isinstance(op, ast.Eq):
if F64Type.isinstance(left.type) and IntegerType.isinstance(
comparator.type):
left = arith.FPToSIOp(comparator.type, left).result
if IntegerType(left.type).width < IntegerType(
comparator.type).width:
zeroext = IntegerType(left.type).width == 1
left = cc.CastOp(comparator.type,
left,
sint=not zeroext,
zint=zeroext).result
self.pushValue(
arith.CmpIOp(self.getIntegerAttr(iTy, 0), left,
comparator).result)
if F64Type.isinstance(left.type):
self.pushValue(
arith.CmpFOp(self.getIntegerAttr(iTy, 1), left,
right).result)
else:
self.pushValue(
arith.CmpIOp(self.getIntegerAttr(iTy, 0), left,
right).result)
return

def visit_AugAssign(self, node):
Expand Down
98 changes: 98 additions & 0 deletions python/tests/kernel/test_ast_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# ============================================================================ #
# Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

import os
import pytest
import cudaq


def cmpfop(predicate, left, right):
operations = {
2: lambda l, r: l > r,
3: lambda l, r: l >= r,
4: lambda l, r: l < r,
5: lambda l, r: l <= r,
1: lambda l, r: l == r,
6: lambda l, r: l != r,
}
return operations[predicate](left, right)


def cmpiop(predicate, left, right):
operations = {
4: lambda l, r: l > r,
5: lambda l, r: l >= r,
2: lambda l, r: l < r,
7: lambda l, r: l <= r,
0: lambda l, r: l == r,
1: lambda l, r: l != r,
}
return operations[predicate](left, right)


@pytest.mark.parametrize(
"left, right, operation, expected",
[
# Integer comparisons
(3, 5, "Lt", True),
(5, 3, "Gt", True),
(3, 3, "Eq", True),
(3, 5, "LtE", True),
(5, 5, "GtE", True),
(3, 5, "NotEq", True),
(5, 5, "NotEq", False),

# Float comparisons
(3.2, 4.5, "Lt", True),
(4.5, 3.2, "Gt", True),
(3.2, 3.2, "Eq", True),
(3.2, 4.5, "LtE", True),
(4.5, 4.5, "GtE", True),
(3.2, 4.5, "NotEq", True),
(4.5, 4.5, "NotEq", False),

# Mixed comparisons
(3, 4.5, "Lt", True),
(4.5, 3, "Gt", True),
(3, 3.0, "Eq", True),
(3, 4.5, "LtE", True),
(4.5, 4, "GtE", True),
(3, 4.5, "NotEq", True),
],
)
def test_visit_compare(left, right, operation, expected):
sacpis marked this conversation as resolved.
Show resolved Hide resolved
result = None

if operation in ["Gt", "GtE", "Lt", "LtE", "Eq", "NotEq"]:
if isinstance(left, float) or isinstance(right, float):
predicate = {
"Gt": 2,
"GtE": 3,
"Lt": 4,
"LtE": 5,
"Eq": 1,
"NotEq": 6,
}[operation]
result = cmpfop(predicate, left, right)
else:
predicate = {
"Gt": 4,
"GtE": 5,
"Lt": 2,
"LtE": 7,
"Eq": 0,
"NotEq": 1,
}[operation]
result = cmpiop(predicate, left, right)

assert result == expected


if __name__ == "__main__":
loc = os.path.abspath(__file__)
pytest.main([loc, "-rP"])
Loading
Loading