Skip to content

Commit

Permalink
[transforms] Fix simplification patterns for stablehlo.(and|or)
Browse files Browse the repository at this point in the history
Fixes an issue in `stablehlo-aggressive-simplification` where `%1` in
the below would get replaced by `%arg0`:

```
  %0 = stablehlo.constant dense<1> : tensor<2xi32>
  %1 = stablehlo.and %0, %arg0 : tensor<2xi32>
```

The pattern was checking whether `%0` is equal to `0b1` and was
only tested on bools. A similar bug existed for `stablehlo.and`. Fixed
by just making sure the constant is integer with all bits set to 1.
  • Loading branch information
christopherbate committed Nov 22, 2024
1 parent f21104d commit 59ce025
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,25 @@ func.func @and_one(%arg0: tensor<2xi1>) -> tensor<2xi1> {
return %1 : tensor<2xi1>
}

// CHECK-LABEL: @and_i32_one
func.func @and_i32_one(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = stablehlo.constant dense<1> : tensor<2xi32>
%1 = stablehlo.and %0, %arg0 : tensor<2xi32>
// CHECK: %[[AND:.+]] = stablehlo.and
// CHECK: return %[[AND]]
return %1 : tensor<2xi32>
}

// CHECK-LABEL: @and_i32_neg_one
// CHECK-SAME: (%[[ARG0:.+]]: tensor<2xi32>)
func.func @and_i32_neg_one(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = stablehlo.constant dense<-1> : tensor<2xi32>
%1 = stablehlo.and %0, %arg0 : tensor<2xi32>
// CHECK-NOT: stablehlo.and
// CHECK: return %[[ARG0]]
return %1 : tensor<2xi32>
}

// -----

/////////
Expand Down Expand Up @@ -874,6 +893,25 @@ func.func @or_one(%arg0: tensor<2xi1>) -> tensor<2xi1> {
return %1 : tensor<2xi1>
}

// CHECK-LABEL: @or_i32_one
func.func @or_i32_one(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = stablehlo.constant dense<1> : tensor<2xi32>
%1 = stablehlo.or %0, %arg0 : tensor<2xi32>
// CHECK: %[[OR:.+]] = stablehlo.or
// CHECK: return %[[OR]]
return %1 : tensor<2xi32>
}

// CHECK-LABEL: @or_i32_neg_one
func.func @or_i32_neg_one(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = stablehlo.constant dense<-1> : tensor<2xi32>
%1 = stablehlo.or %0, %arg0 : tensor<2xi32>
// CHECK-NOT: stablehlo.or
// CHECK: [[NEG_ONE:%.+]] = stablehlo.constant dense<-1> : tensor<2xi32>
// CHECK: return [[NEG_ONE]]
return %1 : tensor<2xi32>
}

// -----

////////
Expand Down
14 changes: 12 additions & 2 deletions stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ def IntOne : AttrConstraint<
CPred<"::mlir::matchPattern($_self, m_One())">,
"is integer one">;

def IntAllOnes : AttrConstraint<
CPred<[{
::mlir::matchPattern($_self,
::mlir::detail::constant_int_predicate_matcher{
[](const llvm::APInt &val) {
return val.isAllOnes();
}})
}]>,
"is integer with all bits set to 1">;

def IntZero : AttrConstraint<
CPred<"::mlir::matchPattern($_self, m_Zero())">,"is integer zero">;

Expand Down Expand Up @@ -139,7 +149,7 @@ def : Pat<(StableHLO_AndOp $lhs, (StableHLO_ConstantOp:$zero IntZero:$value)),
(replaceWithValue $zero)>;

// Pattern: and(X, 1) -> X
def : Pat<(StableHLO_AndOp $lhs, (StableHLO_ConstantOp:$one IntOne:$value)),
def : Pat<(StableHLO_AndOp $lhs, (StableHLO_ConstantOp:$one IntAllOnes:$value)),
(replaceWithValue $lhs)>;

////////
Expand Down Expand Up @@ -307,7 +317,7 @@ def : Pat<(StableHLO_MulOp $lhs, (StableHLO_ConstantOp IntOne:$value)),
def : CanonicalizeConstantToRhs<StableHLO_OrOp>;

// Pattern: or(X, 1) -> 1
def : Pat<(StableHLO_OrOp $lhs, (StableHLO_ConstantOp:$one IntOne:$value)),
def : Pat<(StableHLO_OrOp $lhs, (StableHLO_ConstantOp:$one IntAllOnes:$value)),
(replaceWithValue $one)>;

// Pattern: or(X, 0) -> X
Expand Down

0 comments on commit 59ce025

Please sign in to comment.