Don't require inlining for shape refinement
GleasonK committed Nov 19, 2024
1 parent 2d42d12 commit e83ab74
Showing 5 changed files with 729 additions and 57 deletions.
305 changes: 299 additions & 6 deletions stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
@@ -3,14 +3,14 @@
func.func @error_illformed(%arg0: tensor<3xf32>, %arg1: tensor<4xf32>) -> tensor<?xf32> {
%0 = stablehlo.abs %arg0 : (tensor<3xf32>) -> tensor<?xf32>
%1 = stablehlo.abs %arg1 : (tensor<4xf32>) -> tensor<?xf32>
-// expected-error@+1{{requires the same shape for all operands and results}}
+// expected-error@+1{{'stablehlo.add' op requires the same shape for all operands and results}}
%2 = stablehlo.add %0, %1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
func.return %2 : tensor<?xf32>
}

// -----

-// expected-error@+1{{must have exactly one block}}
+// expected-error@+1{{'func.func' op must have exactly one block}}
func.func @error_too_many_blocks(%arg0: tensor<f32>) -> tensor<f32> {
cf.br ^bb1(%arg0 : tensor<f32>)
^bb1(%arg1 : tensor<f32>):
@@ -49,6 +49,7 @@ module @has_main {

// -----

// CHECK-LABEL: func @error_unsupported_operation
func.func @error_unsupported_operation(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> index {
// CHECK: stablehlo.add{{.*}} -> tensor<?xf32>
%0 = stablehlo.add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<?xf32>
@@ -596,10 +597,288 @@ func.func @refine_bitcast_convert_different_bitwidths(%arg0 : tensor<4xf32>) ->
// -----

// CHECK-LABEL: func @refine_bitcast_convert_same_bitwidth
-func.func @refine_bitcast_convert_same_bitwidth(%arg0 : tensor<4xf32>) -> tensor<?xi32> {
-  // CHECK: stablehlo.bitcast_convert{{.*}} -> tensor<4xi32>
-  %0 = stablehlo.bitcast_convert %arg0 : (tensor<4xf32>) -> tensor<?xi32>
-  func.return %0 : tensor<?xi32>
+func.func @refine_bitcast_convert_same_bitwidth() -> tensor<?x?x0xf32> {
+  %0 = stablehlo.constant dense<[3, 5, 0]> : tensor<3xi32>
+  %21 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<3xi32>) -> tensor<?x?x0xui32>
+  // CHECK: stablehlo.bitcast_convert{{.*}} -> tensor<3x5x0xf32>
+  %48 = stablehlo.bitcast_convert %21 : (tensor<?x?x0xui32>) -> tensor<?x?x0xf32>
+  return %48 : tensor<?x?x0xf32>
}

// -----

// CHECK-LABEL: module @refine_call
module @refine_call {
func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
%0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
%1 = stablehlo.constant dense<4> : tensor<i32>
// CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
%2 = call @refine_call_callee(%1, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
return %2 : tensor<?xf32>
}
// CHECK: refine_call_callee(%arg0: tensor<4xf32>) -> tensor<4xf32>
func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: stablehlo.constant dense<4>
%0 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi32>) -> tensor<?xf32>
return %1 : tensor<?xf32>
}
}

// -----

// CHECK-LABEL: module @refine_call_dimension_arguments
module @refine_call_dimension_arguments {
func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK: [[RESULT:%.*]] = call @callee
// CHECK: return [[RESULT]]
%0 = stablehlo.constant dense<3> : tensor<i32>
%1 = call @callee(%0, %0, %arg0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// %arg0 and %arg1 are dimension arguments
// CHECK: @callee([[ARG0:%.*]]: tensor<i32>) -> tensor<i32>
func.func private @callee(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
// CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<6>
// CHECK: [[RESULT1:%.*]] = stablehlo.add [[RESULT0]], [[ARG0]]
// CHECK: return [[RESULT1]]
%0 = stablehlo.add %arg0, %arg1: tensor<i32>
%1 = stablehlo.add %0, %arg2: tensor<i32>
return %1 : tensor<i32>
}
}

// -----

// CHECK-LABEL: module @refine_call_prefix_token_and_dimension_arguments
module @refine_call_prefix_token_and_dimension_arguments {
func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK: [[RESULT:%.*]] = call @callee
// CHECK: return [[RESULT]]
%0 = stablehlo.constant dense<3> : tensor<i32>
%token = stablehlo.create_token : !stablehlo.token
%1 = call @callee(%token, %0, %0, %arg0) : (!stablehlo.token, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// %arg0 and %arg1 are dimension arguments
// CHECK: @callee([[ARG_TOKEN:%.*]]: !stablehlo.token, [[ARG0:%.*]]: tensor<i32>
func.func private @callee(%arg_token: !stablehlo.token, %arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
// CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<6>
// CHECK: [[RESULT1:%.*]] = stablehlo.add [[RESULT0]], [[ARG0]]
// CHECK: return [[RESULT1]]
%0 = stablehlo.add %arg0, %arg1: tensor<i32>
%1 = stablehlo.add %0, %arg2: tensor<i32>
return %1 : tensor<i32>
}
}

// -----

// CHECK-LABEL: module @refine_call_dimension_arguments_followed_by_token
module @refine_call_dimension_arguments_followed_by_token {
func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK: [[RESULT:%.*]] = call @callee
// CHECK: return [[RESULT]]
%0 = stablehlo.constant dense<3> : tensor<i32>
%token = stablehlo.create_token : !stablehlo.token
%1 = call @callee(%0, %0, %token, %arg0) : (tensor<i32>, tensor<i32>, !stablehlo.token, tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// %arg0 and %arg1 are dimension arguments
// CHECK: @callee([[ARG_TOKEN:%.*]]: !stablehlo.token, [[ARG0:%.*]]: tensor<i32>
func.func private @callee(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg_token: !stablehlo.token, %arg2: tensor<i32>) -> tensor<i32> {
// CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<6>
// CHECK: [[RESULT1:%.*]] = stablehlo.add [[RESULT0]], [[ARG0]]
// CHECK: return [[RESULT1]]
%0 = stablehlo.add %arg0, %arg1: tensor<i32>
%1 = stablehlo.add %0, %arg2: tensor<i32>
return %1 : tensor<i32>
}
}

// -----

// CHECK-LABEL: module @refine_multiple_call_with_same_context
module @refine_multiple_call_with_same_context {
func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
%0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
%arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
// CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
%1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
%2 = call @refine_call_callee(%arg0_new, %1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
return %2 : tensor<?xf32>
}
// CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
return %arg1 : tensor<?xf32>
}
}

// -----

// CHECK-LABEL: module @refine_multiple_call_constant_function
module @refine_multiple_call_constant_function {
func.func @main(%arg0: tensor<5xf32>) -> tensor<i32> {
// CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<16>
// CHECK: return [[RESULT0]]
%0 = stablehlo.constant dense<4> : tensor<i32>
%1 = call @refine_call_callee(%0, %arg0) : (tensor<i32>, tensor<5xf32>) -> tensor<i32>
%2 = call @refine_call_callee(%0, %arg0) : (tensor<i32>, tensor<5xf32>) -> tensor<i32>
%3 = stablehlo.add %1, %2: tensor<i32>
return %3 : tensor<i32>
}
func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> tensor<i32> {
// CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<8>
// CHECK: return [[RESULT1]]
%0 = stablehlo.add %arg0, %arg0: tensor<i32>
return %0 : tensor<i32>
}
}

// -----

module @refine_call_multiple_with_different_number_dimension_arguments {
func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
%0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
%arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
%1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
// Ensure that the first argument is not a constant at the second call site
%arg0_different_f32 = stablehlo.bitcast_convert %arg0_new : (tensor<i32>) -> tensor<f32>
%arg0_different_i32 = stablehlo.bitcast_convert %arg0_different_f32 : (tensor<f32>) -> tensor<i32>
// expected-error@+1{{incorrect number of operands for callee}}
%2 = call @refine_call_callee(%arg0_different_i32, %1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
return %2 : tensor<?xf32>
}
// expected-error@+1{{'func.func' op refined with incompatible refinement keys}}
func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
return %arg1 : tensor<?xf32>
}
}

// -----

module @refine_call_multiple_different_dimension_arguments {
func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
%0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
%arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
%1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
%arg0_different = stablehlo.add %arg0_new, %arg0_new : tensor<i32>
// expected-error@+1{{incorrect number of operands for callee}}
%2 = call @refine_call_callee(%arg0_different, %1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
return %2 : tensor<?xf32>
}
// expected-error@+1{{'func.func' op refined with incompatible refinement keys}}
func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
return %arg1 : tensor<?xf32>
}
}

// -----

module @refine_call_multiple_different_non_dimension_arguments {
func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
%0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
%arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
%1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
%2 = stablehlo.constant dense<[1., 2.]> : tensor<2xf32>
%3 = stablehlo.concatenate %1, %2, dim = 0 : (tensor<?xf32>, tensor<2xf32>) -> tensor<?xf32>
// expected-error@+1{{incorrect number of operands for callee}}
%4 = call @refine_call_callee(%arg0_new, %3) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
return %4 : tensor<?xf32>
}
// expected-error@+1{{'func.func' op refined with incompatible refinement keys}}
func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
return %arg1 : tensor<?xf32>
}
}

// -----

module @refine_call_recursive {
func.func @main() -> tensor<i32> {
%0 = stablehlo.constant dense<3> : tensor<i32>
%1 = call @refine_call_callee(%0) : (tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// expected-error@+1{{Function refine_call_callee is being refined recursively}}
func.func @refine_call_callee(%arg0: tensor<i32>) -> tensor<i32> {
// expected-error@+1{{incorrect number of operands}}
%0 = call @refine_call_callee(%arg0) : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
}

// -----

module @refine_call_main_argument_unranked {
// CHECK-LABEL: func.func public @main(%arg0: tensor<*xi32>) -> tensor<*xi32>
func.func public @main(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%2 = call @callee(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
return %2 : tensor<*xi32>
}
func.func private @callee(%arg0: tensor<*xi32>) -> tensor<*xi32> {
return %arg0 : tensor<*xi32>
}
}

// -----

module @refine_call_main_argument_dynamic_shape {
// CHECK: func.func public @main(%arg0: tensor<?xi32>) -> tensor<?xi32>
func.func public @main(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%2 = call @callee(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
func.func private @callee(%arg0: tensor<?xi32>) -> tensor<?xi32> {
return %arg0 : tensor<?xi32>
}
}

// -----

module @refine_call_callee_argument_dynamic_shape {
// CHECK: func.func public @main(%arg0: tensor<1xi64>) -> tensor<?xi32>
func.func public @main(%arg0: tensor<1xi64>) -> tensor<?xi32> {
%1 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<1xi64>) -> tensor<?xi32>
%2 = call @callee(%1) : (tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
func.func private @callee(%arg0: tensor<?xi32>) -> tensor<?xi32> {
return %arg0 : tensor<?xi32>
}
}

// -----

// CHECK-LABEL: module @refine_call_dimension_argument_non_scalar
// The non-scalar constant is not folded into the callee
module @refine_call_dimension_argument_non_scalar {
func.func public @main() -> tensor<4xi32> {
// CHECK: dense<[1, 2, 3, 4]> : tensor<4xi32>
%0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
%1 = call @callee(%0) : (tensor<4xi32>) -> tensor<4xi32>
return %1 : tensor<4xi32>
}
func.func private @callee(%arg0: tensor<4xi32>) -> tensor<4xi32> {
// CHECK: return %arg0 : tensor<4xi32>
return %arg0 : tensor<4xi32>
}
}

// -----

// CHECK-LABEL: module @refine_call_dimension_argument_not_integer
module @refine_call_dimension_argument_not_integer {
func.func public @main() -> tensor<f32> {
%0 = stablehlo.constant dense<3.> : tensor<f32>
// CHECK: call @callee({{.*}}) : (tensor<f32>) -> tensor<f32>
%2 = call @callee(%0) : (tensor<f32>) -> tensor<f32>
return %2 : tensor<f32>
}
func.func private @callee(%arg0: tensor<f32>) -> tensor<f32> {
return %arg0 : tensor<f32>
}
}

// -----
@@ -656,6 +935,17 @@ func.func @refine_custom_call_operand_wrapper(%arg0: tensor<10x5xf32>) -> tensor

// -----

// CHECK-LABEL: @refine_custom_call_operand_wrapper_unranked
func.func @refine_custom_call_operand_wrapper_unranked(%arg0: tensor<4xi32>) -> tensor<*xi32> {
// CHECK-NOT: stablehlo.shape_refinement_operand_wrapper
%0 = stablehlo.constant dense<[4]> : tensor<1xi64>
%1 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %0) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<*xi32>
// CHECK: return %arg0 : tensor<4xi32>
func.return %1 : tensor<*xi32>
}

// -----

// CHECK-LABEL: @refine_dot_general
func.func @refine_dot_general(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<?x?x?xf32> {
// CHECK: stablehlo.dot_general{{.*}} -> tensor<2x4x5xf32>
@@ -756,6 +1046,8 @@ func.func @refine_dynamic_reshape(%arg0: tensor<4xf32>) -> tensor<?x?xf32> {
func.return %1 : tensor<?x?xf32>
}



// -----

// CHECK-LABEL: @refine_infer_type_op_interface_supported_dialect_chlo
@@ -908,6 +1200,7 @@ func.func @refine_while(%arg0: tensor<4xf32>) -> tensor<?xf32> {
// -----

// TODO: Implement support for these ops.
// * dynamic_conv (#867).
// * dynamic_fft (#1366).
// * dynamic_reduce_window (#1258).
// * dynamic_rng_bit_generator (#1344).
20 changes: 20 additions & 0 deletions stablehlo/transforms/Passes.td
@@ -356,6 +356,25 @@ def StablehloRefineShapesPass : Pass<"stablehlo-refine-shapes", "ModuleOp"> {

%1 = stablehlo.add %arg0, %arg0 : tensor<16xf32>
```

Modules valid for shape refinement must have the following properties (a
minimal example follows the list):

* All dynamic shapes depend only on the input shapes (no shape dependency
  on the input array contents). We refer to operations that depend
  transitively only on the input shapes (e.g., as given by
  `stablehlo.get_dimension_size`) or on global constants such as the
  resolved values of symbolic integers (i.e. tensor<Axf32> : A = 5) as
  `dimension` operations. All dimension values can be resolved to constants
  through inter-procedural constant folding.
* Intermediate functions may take a number of token arguments (of type
  !stablehlo.token) at the start of the argument list, followed by some
  global constant arguments, which are constant integer scalars such as the
  resolved values of symbolic integers (i.e. tensor<Axf32> : A = 5).
* Some intermediate functions may return computations on global constants,
  e.g. `floordiv` on symint values. Such functions are identified by
  returning only constant values after refinement; they are inlined.
* All calls to a single function must resolve to the same argument shapes,
  and no recursive / co-recursive function calls may be made.
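
For illustration, here is a minimal sketch of a module with these
properties, mirroring the `@refine_call` tests added in this commit (an
editorial example, not part of the pass's normative documentation). The
scalar `tensor<i32>` argument of `@callee` is a dimension argument: it
resolves to the constant 4 at the call site, so refinement can rewrite
`@callee` to operate on static `tensor<4xf32>` shapes without inlining it:

```mlir
module @example {
  func.func @main(%arg0: tensor<4xf32>) -> tensor<?xf32> {
    // The dimension argument: a constant integer scalar.
    %size = stablehlo.constant dense<4> : tensor<i32>
    %0 = stablehlo.bitcast_convert %arg0 : (tensor<4xf32>) -> tensor<?xf32>
    // After refinement: call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32>
    %1 = call @callee(%size, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
    return %1 : tensor<?xf32>
  }
  // After refinement, the dimension argument is dropped from the signature
  // and materialized as a constant in the body:
  //   @callee(%arg1: tensor<4xf32>) -> tensor<4xf32>
  func.func private @callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
    return %arg1 : tensor<?xf32>
  }
}
```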
}];
}

@@ -375,4 +394,5 @@ def VhloToVersionPass : Pass<"vhlo-to-version"> {
Option<"targetVersionOption", "target", "std::string", "",
"The target version. Must be a version of the form #.#.# .">,
];
let dependentDialects = ["mlir::vhlo::VhloDialect"];
}