From 59ce02515b23332bea01fcd15318943cc0e472e1 Mon Sep 17 00:00:00 2001 From: Christopher Bate Date: Fri, 22 Nov 2024 22:43:47 +0000 Subject: [PATCH] [transforms] Fix simplification patterns for `stablehlo.(and|or)` Fixes an issue in `stablehlo-aggressive-simplification` where `%1` in the below would get replaced by `%arg0`: ``` %0 = stablehlo.constant dense<1> : tensor<2xi32> %1 = stablehlo.and %0, %arg0 : tensor<2xi32> ``` The pattern was checking whether `%0` is equal to `0b1` and was only tested on bools. A similar bug existed for `stablehlo.and`. Fixed by just making sure the constant is integer with all bits set to 1. --- .../stablehlo_aggressive_simplification.mlir | 38 +++++++++++++++++++ ...ablehloAggressiveSimplificationPatterns.td | 14 ++++++- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir index 73da68c87b..24e2398fe7 100644 --- a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir +++ b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir @@ -63,6 +63,25 @@ func.func @and_one(%arg0: tensor<2xi1>) -> tensor<2xi1> { return %1 : tensor<2xi1> } +// CHECK-LABEL: @and_i32_one +func.func @and_i32_one(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = stablehlo.constant dense<1> : tensor<2xi32> + %1 = stablehlo.and %0, %arg0 : tensor<2xi32> + // CHECK: %[[AND:.+]] = stablehlo.and + // CHECK: return %[[AND]] + return %1 : tensor<2xi32> +} + +// CHECK-LABEL: @and_i32_neg_one +// CHECK-SAME: (%[[ARG0:.+]]: tensor<2xi32>) +func.func @and_i32_neg_one(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = stablehlo.constant dense<-1> : tensor<2xi32> + %1 = stablehlo.and %0, %arg0 : tensor<2xi32> + // CHECK-NOT: stablehlo.and + // CHECK: return %[[ARG0]] + return %1 : tensor<2xi32> +} + // ----- ///////// @@ -874,6 +893,25 @@ func.func @or_one(%arg0: tensor<2xi1>) -> tensor<2xi1> { return %1 : tensor<2xi1> } +// CHECK-LABEL: @or_i32_one +func.func @or_i32_one(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = stablehlo.constant dense<1> : tensor<2xi32> + %1 = stablehlo.or %0, %arg0 : tensor<2xi32> + // CHECK: %[[OR:.+]] = stablehlo.or + // CHECK: return %[[OR]] + return %1 : tensor<2xi32> +} + +// CHECK-LABEL: @or_i32_neg_one +func.func @or_i32_neg_one(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = stablehlo.constant dense<-1> : tensor<2xi32> + %1 = stablehlo.or %0, %arg0 : tensor<2xi32> + // CHECK-NOT: stablehlo.or + // CHECK: [[NEG_ONE:%.+]] = stablehlo.constant dense<-1> : tensor<2xi32> + // CHECK: return [[NEG_ONE]] + return %1 : tensor<2xi32> +} + // ----- //////// diff --git a/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td b/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td index faea08078e..9cbcc07ca6 100644 --- a/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td +++ b/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td @@ -72,6 +72,16 @@ def IntOne : AttrConstraint< CPred<"::mlir::matchPattern($_self, m_One())">, "is integer one">; +def IntAllOnes : AttrConstraint< + CPred<[{ + ::mlir::matchPattern($_self, + ::mlir::detail::constant_int_predicate_matcher{ + [](const llvm::APInt &val) { + return val.isAllOnes(); + }}) + }]>, + "is integer with all bits set to 1">; + def IntZero : AttrConstraint< CPred<"::mlir::matchPattern($_self, m_Zero())">,"is integer zero">; @@ -139,7 +149,7 @@ def : Pat<(StableHLO_AndOp $lhs, (StableHLO_ConstantOp:$zero IntZero:$value)), (replaceWithValue $zero)>; // Pattern: and(X, 1) -> X -def : Pat<(StableHLO_AndOp $lhs, (StableHLO_ConstantOp:$one IntOne:$value)), +def : Pat<(StableHLO_AndOp $lhs, (StableHLO_ConstantOp:$one IntAllOnes:$value)), (replaceWithValue $lhs)>; //////// @@ -307,7 +317,7 @@ def : Pat<(StableHLO_MulOp $lhs, (StableHLO_ConstantOp IntOne:$value)), def : CanonicalizeConstantToRhs; // Pattern: or(X, 1) -> 1 -def : Pat<(StableHLO_OrOp $lhs, (StableHLO_ConstantOp:$one IntOne:$value)), +def : Pat<(StableHLO_OrOp $lhs, (StableHLO_ConstantOp:$one IntAllOnes:$value)), (replaceWithValue $one)>; // Pattern: or(X, 0) -> X