diff --git a/python/cudaq/kernel/ast_bridge.py b/python/cudaq/kernel/ast_bridge.py index ffc8a67676..79299f052b 100644 --- a/python/cudaq/kernel/ast_bridge.py +++ b/python/cudaq/kernel/ast_bridge.py @@ -3374,76 +3374,96 @@ def visit_Compare(self, node): self.visit(node.left) left = self.popValue() self.visit(node.comparators[0]) - comparator = self.popValue() + right = self.popValue() op = node.ops[0] - if isinstance(op, ast.Gt): - if IntegerType.isinstance(left.type): - if F64Type.isinstance(comparator.type): - self.emitFatalError( - "invalid rhs for comparison (f64 type and not i64 type).", - node) + left_type = left.type + right_type = right.type + + if IntegerType.isinstance(left_type) and F64Type.isinstance(right_type): + left = arith.SIToFPOp(self.getFloatType(), left).result + elif F64Type.isinstance(left_type) and IntegerType.isinstance( + right_type): + right = arith.SIToFPOp(self.getFloatType(), right).result + elif IntegerType.isinstance(left_type) and IntegerType.isinstance( + right_type): + if IntegerType(left_type).width < IntegerType(right_type).width: + zeroext = IntegerType(left_type).width == 1 + left = cc.CastOp(right_type, + left, + sint=not zeroext, + zint=zeroext).result + elif IntegerType(left_type).width > IntegerType(right_type).width: + zeroext = IntegerType(right_type).width == 1 + right = cc.CastOp(left_type, + right, + sint=not zeroext, + zint=zeroext).result - self.pushValue( - arith.CmpIOp(self.getIntegerAttr(iTy, 4), left, - comparator).result) - elif F64Type.isinstance(left.type): - if IntegerType.isinstance(comparator.type): - comparator = arith.SIToFPOp(self.getFloatType(), - comparator).result + if isinstance(op, ast.Gt): + if F64Type.isinstance(left.type): self.pushValue( arith.CmpFOp(self.getIntegerAttr(iTy, 2), left, - comparator).result) + right).result) + else: + self.pushValue( + arith.CmpIOp(self.getIntegerAttr(iTy, 4), left, + right).result) return if isinstance(op, ast.GtE): - self.pushValue( - arith.CmpIOp(self.getIntegerAttr(iTy, 5), left, - comparator).result) + if F64Type.isinstance(left.type): + self.pushValue( + arith.CmpFOp(self.getIntegerAttr(iTy, 3), left, + right).result) + else: + self.pushValue( + arith.CmpIOp(self.getIntegerAttr(iTy, 5), left, + right).result) return if isinstance(op, ast.Lt): - self.pushValue( - arith.CmpIOp(self.getIntegerAttr(iTy, 2), left, - comparator).result) + if F64Type.isinstance(left.type): + self.pushValue( + arith.CmpFOp(self.getIntegerAttr(iTy, 4), left, + right).result) + else: + self.pushValue( + arith.CmpIOp(self.getIntegerAttr(iTy, 2), left, + right).result) return if isinstance(op, ast.LtE): - self.pushValue( - arith.CmpIOp(self.getIntegerAttr(iTy, 7), left, - comparator).result) + if F64Type.isinstance(left.type): + self.pushValue( + arith.CmpFOp(self.getIntegerAttr(iTy, 5), left, + right).result) + else: + self.pushValue( + arith.CmpIOp(self.getIntegerAttr(iTy, 7), left, + right).result) return if isinstance(op, ast.NotEq): - if F64Type.isinstance(left.type) and IntegerType.isinstance( - comparator.type): - left = arith.FPToSIOp(comparator.type, left).result - if IntegerType(left.type).width < IntegerType( - comparator.type).width: - zeroext = IntegerType(left.type).width == 1 - left = cc.CastOp(comparator.type, - left, - sint=not zeroext, - zint=zeroext).result - self.pushValue( - arith.CmpIOp(self.getIntegerAttr(iTy, 1), left, - comparator).result) + if F64Type.isinstance(left.type): + self.pushValue( + arith.CmpFOp(self.getIntegerAttr(iTy, 6), left, + right).result) + else: + self.pushValue( + arith.CmpIOp(self.getIntegerAttr(iTy, 1), left, + right).result) return if isinstance(op, ast.Eq): - if F64Type.isinstance(left.type) and IntegerType.isinstance( - comparator.type): - left = arith.FPToSIOp(comparator.type, left).result - if IntegerType(left.type).width < IntegerType( - comparator.type).width: - zeroext = IntegerType(left.type).width == 1 - left = cc.CastOp(comparator.type, - left, - sint=not zeroext, - zint=zeroext).result - self.pushValue( - arith.CmpIOp(self.getIntegerAttr(iTy, 0), left, - comparator).result) + if F64Type.isinstance(left.type): + self.pushValue( + arith.CmpFOp(self.getIntegerAttr(iTy, 1), left, + right).result) + else: + self.pushValue( + arith.CmpIOp(self.getIntegerAttr(iTy, 0), left, + right).result) return def visit_AugAssign(self, node): diff --git a/python/tests/kernel/test_ast_compare.py b/python/tests/kernel/test_ast_compare.py new file mode 100644 index 0000000000..e6f9f4b492 --- /dev/null +++ b/python/tests/kernel/test_ast_compare.py @@ -0,0 +1,98 @@ +# ============================================================================ # +# Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # + +import os +import pytest +import cudaq + + +def cmpfop(predicate, left, right): + operations = { + 2: lambda l, r: l > r, + 3: lambda l, r: l >= r, + 4: lambda l, r: l < r, + 5: lambda l, r: l <= r, + 1: lambda l, r: l == r, + 6: lambda l, r: l != r, + } + return operations[predicate](left, right) + + +def cmpiop(predicate, left, right): + operations = { + 4: lambda l, r: l > r, + 5: lambda l, r: l >= r, + 2: lambda l, r: l < r, + 7: lambda l, r: l <= r, + 0: lambda l, r: l == r, + 1: lambda l, r: l != r, + } + return operations[predicate](left, right) + + +@pytest.mark.parametrize( + "left, right, operation, expected", + [ + # Integer comparisons + (3, 5, "Lt", True), + (5, 3, "Gt", True), + (3, 3, "Eq", True), + (3, 5, "LtE", True), + (5, 5, "GtE", True), + (3, 5, "NotEq", True), + (5, 5, "NotEq", False), + + # Float comparisons + (3.2, 4.5, "Lt", True), + (4.5, 3.2, "Gt", True), + (3.2, 3.2, "Eq", True), + (3.2, 4.5, "LtE", True), + (4.5, 4.5, "GtE", True), + (3.2, 4.5, "NotEq", True), + (4.5, 4.5, "NotEq", False), + + # Mixed comparisons + (3, 4.5, "Lt", True), + (4.5, 3, "Gt", True), + (3, 3.0, "Eq", True), + (3, 4.5, "LtE", True), + (4.5, 4, "GtE", True), + (3, 4.5, "NotEq", True), + ], +) +def test_visit_compare(left, right, operation, expected): + result = None + + if operation in ["Gt", "GtE", "Lt", "LtE", "Eq", "NotEq"]: + if isinstance(left, float) or isinstance(right, float): + predicate = { + "Gt": 2, + "GtE": 3, + "Lt": 4, + "LtE": 5, + "Eq": 1, + "NotEq": 6, + }[operation] + result = cmpfop(predicate, left, right) + else: + predicate = { + "Gt": 4, + "GtE": 5, + "Lt": 2, + "LtE": 7, + "Eq": 0, + "NotEq": 1, + }[operation] + result = cmpiop(predicate, left, right) + + assert result == expected + + +if __name__ == "__main__": + loc = os.path.abspath(__file__) + pytest.main([loc, "-rP"]) diff --git a/python/tests/mlir/ast_comparators.py b/python/tests/mlir/ast_comparators.py new file mode 100644 index 0000000000..1479efa961 --- /dev/null +++ b/python/tests/mlir/ast_comparators.py @@ -0,0 +1,187 @@ +# ============================================================================ # +# Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # + +# RUN: PYTHONPATH=../../ pytest -rP %s | FileCheck %s + +import os +import pytest +import cudaq + + +def test_comparison_operators_for_integers(): + + @cudaq.kernel + def test_integer_less_than(): + a = 3 < 5 + + print(test_integer_less_than) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_integer_greater_than(): + a = 5 > 3 + + print(test_integer_greater_than) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_integer_equal_to(): + a = 3 == 3 + + print(test_integer_equal_to) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_integer_less_than_or_equal_to(): + a = 3 <= 5 + + print(test_integer_less_than_or_equal_to) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_integer_greater_than_or_equal_to(): + a = 5 >= 5 + + print(test_integer_greater_than_or_equal_to) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_integer_not_equal_to_true(): + a = 3 != 5 + + print(test_integer_not_equal_to_true) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_integer_not_equal_to_false(): + a = 5 != 5 + + print(test_integer_not_equal_to_false) + + # CHECK-LABEL: %false = arith.constant false + + +def test_comparison_operators_for_floats(): + + @cudaq.kernel + def test_float_less_than(): + a = 3.2 < 4.5 + + print(test_float_less_than) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_float_greater_than(): + a = 4.5 > 3.2 + + print(test_float_greater_than) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_float_equal_to(): + a = 3.2 == 3.2 + + print(test_float_equal_to) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_float_less_than_or_equal_to(): + a = 3.2 <= 4.5 + + print(test_float_less_than_or_equal_to) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_float_greater_than_or_equal_to(): + a = 4.5 >= 4.5 + + print(test_float_greater_than_or_equal_to) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_float_not_equal_to_true(): + a = 3.2 != 4.5 + + print(test_float_not_equal_to_true) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_float_not_equal_to_false(): + a = 4.5 != 4.5 + + print(test_float_not_equal_to_false) + + # CHECK-LABEL: %false = arith.constant false + + +def test_comparison_operators_for_mixed_types(): + + @cudaq.kernel + def test_mixed_less_than(): + a = 3 < 4.5 + + print(test_mixed_less_than) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_mixed_greater_than(): + a = 4.5 > 3 + + print(test_mixed_greater_than) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_mixed_equal_to(): + a = 3 == 3.0 + + print(test_mixed_equal_to) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_mixed_less_than_or_equal_to(): + a = 3 <= 4.5 + + print(test_mixed_less_than_or_equal_to) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_mixed_greater_than_or_equal_to(): + a = 4.5 >= 4 + + print(test_mixed_greater_than_or_equal_to) + + # CHECK-LABEL: %true = arith.constant true + + @cudaq.kernel + def test_mixed_not_equal_to_true(): + a = 3 != 4.5 + + print(test_mixed_not_equal_to_true) + + # CHECK-LABEL: %true = arith.constant true + + +if __name__ == "__main__": + loc = os.path.abspath(__file__) + pytest.main([loc, "-rP"])