Skip to content

Commit

Permalink
Fix issue when targeting a patch version for serialization (#2598)
Browse files Browse the repository at this point in the history
We want to allow targeting patch version because we allow
`-target=current` so it makes sense that current can resolve to a patch
version.

That said we write our compat passes against patch v0 since all opset
changes require a minor version bump and therefore have patch v0.

This PR implements "version flooring" inside the vhlo-to-version pass.
  • Loading branch information
GleasonK authored Oct 23, 2024
1 parent 2952108 commit 7ce6c1c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
15 changes: 15 additions & 0 deletions stablehlo/tests/vhlo/vhlo_to_version_downgrade_patch.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=1.4.1' %s | FileCheck %s

// AllToAll was in the initial StableHLO opset, but changed in v1.5.0 to have
// tuple arguments. Ensure that serializing for 1.4.1 is valid and targets the
// v1.4.0 opset.
//
// This will catch issues in op `isLegal` checks:
// op.minVersion() <= target <= op.maxVersion()

// CHECK-LABEL: vhlo.func_v1 @all_to_all
func.func public @all_to_all(%arg0: tensor<8x8x1xui16>) -> tensor<1x8x8xui16> {
// CHECK: vhlo.all_to_all_v1
%0 = "stablehlo.all_to_all"(%arg0) <{concat_dimension = 2 : i64, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, split_count = 8 : i64, split_dimension = 0 : i64}> : (tensor<8x8x1xui16>) -> tensor<1x8x8xui16>
return %0 : tensor<1x8x8xui16>
}
7 changes: 7 additions & 0 deletions stablehlo/transforms/VhloToVersion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ FailureOr<Version> validateTargetVersion(llvm::StringRef versionRef,
<< " is greater than current version "
<< Version::getCurrentVersion();

// Opset changes warrant a minor version bump, so this conversion assumes
// patch v0 since it is written against the opset at version `X.Y.0`.
if (targetVersion.getPatch() != 0) {
targetVersion =
vhlo::Version(targetVersion.getMajor(), targetVersion.getMinor(), 0);
}

return targetVersion;
}

Expand Down

0 comments on commit 7ce6c1c

Please sign in to comment.