[Stablehlo] lowering aten.view to shape.num_elements + stablehlo.compute_reshape_shape (llvm#3125)

`aten.view` supports at most one `-1` in its dim list, so the original calculation of `numel` (a product over the requested dim sizes) is wrong whenever a `-1` is present. Compute the element count from the input tensor's shape instead.
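
To make the failure mode concrete, here is a small standalone C++ sketch (illustrative only, not code from this patch; the shapes are invented) contrasting the old and new ways of obtaining `numel`:

#include <cassert>
#include <cstdint>
#include <vector>

// Product of the requested view dims -- the old approach. The result goes
// negative (and is meaningless as an element count) when a -1 is present.
int64_t productOfDims(const std::vector<int64_t> &dims) {
  int64_t n = 1;
  for (int64_t d : dims)
    n *= d;
  return n;
}

int main() {
  std::vector<int64_t> inputShape = {4, 56, 224}; // 50176 elements
  std::vector<int64_t> viewDims = {-1, 224};      // aten.view(x, {-1, 224})

  assert(productOfDims(viewDims) == -224);        // old "numel": wrong
  int64_t numel = productOfDims(inputShape);      // new numel: from input shape
  assert(numel == 50176 && numel / 224 == 224);   // the -1 resolves to 224
  return 0;
}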
qingyunqu authored Apr 9, 2024
1 parent 42a16fa commit 8d5e257
Showing 2 changed files with 16 additions and 48 deletions.
34 changes: 4 additions & 30 deletions lib/Conversion/TorchToStablehlo/ViewLike.cpp
@@ -13,6 +13,7 @@
 #include "PopulatePatterns.h"

 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/StablehloOps.h"
 #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
@@ -178,8 +179,7 @@ class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {
     }

     auto loc = op.getLoc();
-    auto newRank = dimSizes.size();
-    if (newRank == 0 || rankType.getRank() == 0) {
+    if (dimSizes.size() == 0 || rankType.getRank() == 0) {
       rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
           op,
           OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
@@ -193,35 +193,9 @@ class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {
       return dSize;
     });

-    const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
-    Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
-    if (options.dimSizeIndexBits == 32) {
-      // The i64 calculation is much slower than i32 on some devices, such as
-      // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are
-      // unlikely to exceed the range of i32(4GiB)
-      std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
-        // dimSize: cast i64 -> i32
-        dSize = rewriter.create<arith::TruncIOp>(loc, intType, dSize);
-        return dSize;
-      });
-    }
-
-    Value numel = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(intType, 1));
-    for (auto d : dimSizes) {
-      numel = rewriter.create<arith::MulIOp>(loc, numel, d);
-    }
-    numel = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
-                                                numel);
+    Value numel = rewriter.create<shape::NumElementsOp>(
+        loc, rewriter.create<shape::ShapeOfOp>(loc, adaptor.getSelf()));

-    if (dimSizes.size() == 0) {
-      rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
-          op,
-          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-              op.getType()),
-          adaptor.getSelf());
-      return success();
-    }
     Value stablehloShape =
         rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
     Value computedShape = rewriter.create<stablehlo::ComputeReshapeShapeOp>(
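For readers unfamiliar with `stablehlo.compute_reshape_shape`: it takes the true element count plus a requested shape that may contain one `-1`, and materializes the concrete output shape at runtime. A rough scalar emulation of that behavior (my reading of the op's semantics, not part of this patch):

#include <cstdint>
#include <vector>

// Rough scalar emulation of stablehlo.compute_reshape_shape: resolve a
// single -1 placeholder against the true element count of the input.
std::vector<int64_t> computeReshapeShape(int64_t numElements,
                                         std::vector<int64_t> dims) {
  int64_t knownProduct = 1;
  int64_t placeholder = -1; // index of the -1 entry, if any
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    if (dims[i] == -1)
      placeholder = i;
    else
      knownProduct *= dims[i];
  }
  if (placeholder >= 0)
    dims[placeholder] = numElements / knownProduct;
  return dims;
}
// e.g. computeReshapeShape(50176, {-1, 224}) -> {224, 224}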
30 changes: 12 additions & 18 deletions test/Conversion/TorchToStablehlo/view_like.mlir
@@ -308,15 +308,13 @@ func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,2
 // CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT]]-1, %[[INT]]224 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT]]-1
 // CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT224]]
-// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
-// CHECK: %[[T4:.*]] = arith.muli %[[C1_I64]], %[[T2]] : i64
-// CHECK: %[[T5:.*]] = arith.muli %[[T4]], %[[T3]] : i64
-// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : i64 to index
+// CHECK: %[[T4:.*]] = shape.shape_of %[[T0]] : tensor<?x?x?x?xf32> -> tensor<4xindex>
+// CHECK: %[[T5:.*]] = shape.num_elements %[[T4]] : tensor<4xindex> -> index
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64>
-// CHECK: %[[T7:.*]] = stablehlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64>
-// CHECK: %[[T8:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
-// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
-// CHECK: return %[[T9]] : !torch.vtensor<[?,224],f32>
+// CHECK: %[[T6:.*]] = stablehlo.compute_reshape_shape %[[T5]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64>
+// CHECK: %[[T7:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T6]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
+// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
+// CHECK: return %[[T8]] : !torch.vtensor<[?,224],f32>
 func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
   %int-1 = torch.constant.int -1
   %int224 = torch.constant.int 224
@@ -339,17 +337,13 @@ func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
 // CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT120]]
 // CHECK: %[[T4:.*]] = torch_c.to_i64 %[[INT4]]
 // CHECK: %[[T5:.*]] = torch_c.to_i64 %[[INT64]]
-// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
-// CHECK: %[[T6:.*]] = arith.muli %[[C1_I64]], %[[T2]] : i64
-// CHECK: %[[T7:.*]] = arith.muli %[[T6]], %[[T3]] : i64
-// CHECK: %[[T8:.*]] = arith.muli %[[T7]], %[[T4]] : i64
-// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64
-// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index
+// CHECK: %[[T6:.*]] = shape.shape_of %[[T0]] : tensor<?x?x?x?x?xf32> -> tensor<5xindex>
+// CHECK: %[[T7:.*]] = shape.num_elements %[[T6]] : tensor<5xindex> -> index
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64>
-// CHECK: %[[T11:.*]] = stablehlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64>
-// CHECK: %[[T12:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
-// CHECK: %[[T13:.*]] = torch_c.from_builtin_tensor %[[T12]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
-// CHECK: return %[[T13]] : !torch.vtensor<[?,120,4,64],f32>
+// CHECK: %[[T8:.*]] = stablehlo.compute_reshape_shape %[[T7]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64>
+// CHECK: %[[T9:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T8]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
+// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
+// CHECK: return %[[T10]] : !torch.vtensor<[?,120,4,64],f32>
 func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
   %int-1 = torch.constant.int -1
   %int120 = torch.constant.int 120
