Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StableHLO Testdata Refresh #2441

Merged
merged 9 commits into from
Aug 26, 2024
Merged

StableHLO Testdata Refresh #2441

merged 9 commits into from
Aug 26, 2024

Conversation

GleasonK
Copy link
Member

StableHLO Testdata Refresh

Note: Careful attempting to review.. this is all generated by a JAX script to dump their parameterized JAX primitive tests to StableHLO test cases, so the majority of this PR isn't interesting. I'll highlight some interesting bits below. The testgen script will remain an internal tool until it's cleaned up / validated a bit then we can upload it to build_tools.

I noticed that we didn't have any CHLO tests to capture precision changes as we go about making CHLO decompositions more precise.

High level overview

Better names

Better is relative.. but these names are now shorter and more readable, and shouldn't break any users with "max file length too long" anymore.

$ ls stablehlo/testdata/
abs_bfloat16_20_20.mlir
abs_complex128_20_20.mlir
abs_complex64_20_20.mlir
abs_float16_20_20.mlir
abs_float32_20_20.mlir
abs_float64_20_20.mlir
abs_int16_20_20.mlir
abs_int32_20_20.mlir
abs_int64_20_20.mlir
abs_int8_20_20.mlir
acos_bfloat16_20_20.mlir
...

CustomCalls more closely model Check dialect ops

Pretty much all the same, just no longer do the custom_call's return anything. This will allow us to write a trivial pass to convert to/from check ops, by more closely modeling check ops.

// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt > %t.0
// RUN: stablehlo-opt %s > %t.1
// RUN: diff %t.0 %t.1

module @jit_main attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main() -> (tensor<20x20xbf16> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @inputs() : () -> tensor<20x20xbf16>
    %1 = call @expected() : () -> tensor<20x20xbf16>
    %2 = stablehlo.abs %0 : tensor<20x20xbf16>
    stablehlo.custom_call @check.expect_close(%2, %1) {has_side_effect = true} : (tensor<20x20xbf16>, tensor<20x20xbf16>) -> ()
    return %2 : tensor<20x20xbf16>
  }
  func.func private @inputs() -> (tensor<20x20xbf16> {mhlo.layout_mode = "default"}) {
    %cst = stablehlo.constant dense<"__elided__"> : tensor<20x20xbf16>
    return %cst : tensor<20x20xbf16>
  }
  func.func private @expected() -> (tensor<20x20xbf16> {mhlo.layout_mode = "default"}) {
    %cst = stablehlo.constant dense<"__elided__"> : tensor<20x20xbf16>
    return %cst : tensor<20x20xbf16>
  }
}

New CHLO tests

Now we have tests that leverage CHLO decompositions and compare them to reference semantics from JAX/XLA.

$ cat stablehlo/testdata/acosh_bfloat16_20_20_chlo.mlir

// RUN: stablehlo-opt --chlo-pre-serialization-pipeline -inline %s | stablehlo-translate --interpret
...
module @jit_main attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main() -> (tensor<20x20xbf16> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @inputs() : () -> tensor<20x20xbf16>
    %1 = call @expected() : () -> tensor<20x20xbf16>
    %2 = chlo.acosh %0 : tensor<20x20xbf16> -> tensor<20x20xbf16>
    stablehlo.custom_call @check.expect_close(%2, %1) {has_side_effect = true} : (tensor<20x20xbf16>, tensor<20x20xbf16>) -> ()
    return %2 : tensor<20x20xbf16>
  }
  // ...
}

New bugs discovered ( 😄 )

Filtering a few sort-related bugs on #2440. Note: All disabled tests contain RUN-DISABLED(reason)

$ grep -rnwl stablehlo/testdata -e 'RUN-DISABLED(#2440)'
stablehlo/testdata/top_k_float32_5_chlo.mlir
stablehlo/testdata/top_k_float64_5_3_chlo.mlir
stablehlo/testdata/sort_int8_5_7.mlir
stablehlo/testdata/top_k_uint8_5_3_chlo.mlir
stablehlo/testdata/sort_uint64_5_7.mlir
stablehlo/testdata/sort_float32_5_7.mlir
stablehlo/testdata/top_k_bool_5_3_chlo.mlir
stablehlo/testdata/top_k_int8_5_3_chlo.mlir
stablehlo/testdata/sort_bfloat16_5_7.mlir
stablehlo/testdata/sort_uint8_5_7.mlir
stablehlo/testdata/sort_int64_5_7.mlir
stablehlo/testdata/sort_complex64_5_7.mlir
stablehlo/testdata/top_k_uint16_5_3_chlo.mlir
stablehlo/testdata/sort_bool_5_7.mlir
stablehlo/testdata/top_k_int16_5_3_chlo.mlir
stablehlo/testdata/top_k_uint64_5_3_chlo.mlir
stablehlo/testdata/sort_int32_100_int32_100.mlir
stablehlo/testdata/top_k_int32_6_chlo.mlir
stablehlo/testdata/top_k_int32_5_3_chlo.mlir
stablehlo/testdata/top_k_float32_5_3_chlo.mlir
stablehlo/testdata/top_k_float16_5_3_chlo.mlir
stablehlo/testdata/sort_float64_5_7.mlir
stablehlo/testdata/sort_uint16_5_7.mlir
stablehlo/testdata/sort_int32_100_int32_100_float32_100.mlir
stablehlo/testdata/sort_int16_5_7.mlir
stablehlo/testdata/sort_float32_5.mlir
stablehlo/testdata/top_k_uint32_5_3_chlo.mlir
stablehlo/testdata/top_k_bfloat16_5_3_chlo.mlir
stablehlo/testdata/sort_complex128_5_7.mlir
stablehlo/testdata/top_k_int64_5_3_chlo.mlir
stablehlo/testdata/sort_float16_5_7.mlir
stablehlo/testdata/sort_int32_5_7.mlir
stablehlo/testdata/sort_uint32_5_7.mlir

Accuracy filters

$ grep -rnwl stablehlo/testdata -e 'RUN-DISABLED(inaccurate)'
stablehlo/testdata/dot_general_uint32_4_3_bfloat16_3_6.mlir
stablehlo/testdata/pow_complex64_20_30_complex64_20_30.mlir
stablehlo/testdata/dot_general_uint16_4_3_bfloat16_3_6.mlir
stablehlo/testdata/dot_general_int8_4_3_float16_3_6.mlir
stablehlo/testdata/dot_general_int32_4_3_float16_3_6.mlir
stablehlo/testdata/sin_complex64_20_20.mlir
stablehlo/testdata/tan_complex64_20_20_chlo.mlir
stablehlo/testdata/cumprod_complex64_8_9.mlir
stablehlo/testdata/dot_general_uint64_4_3_float16_3_6.mlir
stablehlo/testdata/dot_general_uint32_4_3_float16_3_6.mlir

Future work

OSS the test generation script

Hopefully can get to this later in the week. Just need to figure out how to go about testing it. At a minimum I'll put it in an gist that may go stale at any time.

Make accuracy constraints tighter

We don't have many accuracy guarantees today, and some of these have fairly loose tolerances. Would be good to make this more strict over time.

@GleasonK GleasonK force-pushed the testdata-refresh branch 3 times, most recently from 2d972db to 171dc74 Compare July 16, 2024 23:52
@sdasgup3 sdasgup3 self-requested a review August 19, 2024 17:37
sdasgup3 added a commit that referenced this pull request Aug 22, 2024
#2441 proposes a new format of
testdata where the custom_call value should not be return from the
testdata function. This PR updated the generator to impose the new
format.
@ghpvnist ghpvnist merged commit 8602e09 into openxla:main Aug 26, 2024
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants