From 8d5e2578b0964f773cb5c1af85809de187a74e7c Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Tue, 9 Apr 2024 14:54:57 +0800 Subject: [PATCH] =?UTF-8?q?[Stablehlo]=20lowering=20aten.view=20to=20shape?= =?UTF-8?q?.num=5Felements=20+=20stablehlo.comp=E2=80=A6=20(#3125)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ute_reshape_shape as that `aten.view` support at most one `-1` in dim list. The original calculation of `numel` is wrong when there is a `-1` in dim list. --- lib/Conversion/TorchToStablehlo/ViewLike.cpp | 34 +++---------------- .../TorchToStablehlo/view_like.mlir | 30 +++++++--------- 2 files changed, 16 insertions(+), 48 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index 507821dee638..f6c879907004 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -13,6 +13,7 @@ #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" @@ -178,8 +179,7 @@ class ConvertAtenViewOp : public ConvertAtenOp { } auto loc = op.getLoc(); - auto newRank = dimSizes.size(); - if (newRank == 0 || rankType.getRank() == 0) { + if (dimSizes.size() == 0 || rankType.getRank() == 0) { rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( @@ -193,35 +193,9 @@ class ConvertAtenViewOp : public ConvertAtenOp { return dSize; }); - const auto &options = ConvertAtenOp::getOptions(); - Type intType = rewriter.getIntegerType(options.dimSizeIndexBits); - if (options.dimSizeIndexBits == 32) { - // The i64 calculation is much slower than i32 on some devices, such as - // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are - // unlikely to exceed the range of i32(4GiB) - std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) { - // dimSize: cast i64 -> i32 - dSize = rewriter.create(loc, intType, dSize); - return dSize; - }); - } - - Value numel = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 1)); - for (auto d : dimSizes) { - numel = rewriter.create(loc, numel, d); - } - numel = rewriter.create(loc, rewriter.getIndexType(), - numel); + Value numel = rewriter.create( + loc, rewriter.create(loc, adaptor.getSelf())); - if (dimSizes.size() == 0) { - rewriter.replaceOpWithNewOp( - op, - OpConversionPattern::getTypeConverter()->convertType( - op.getType()), - adaptor.getSelf()); - return success(); - } Value stablehloShape = rewriter.create(loc, dimSizes); Value computedShape = rewriter.create( diff --git a/test/Conversion/TorchToStablehlo/view_like.mlir b/test/Conversion/TorchToStablehlo/view_like.mlir index 206084873c81..30f33a4fbcea 100644 --- a/test/Conversion/TorchToStablehlo/view_like.mlir +++ b/test/Conversion/TorchToStablehlo/view_like.mlir @@ -308,15 +308,13 @@ func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,2 // CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT]]-1, %[[INT]]224 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT]]-1 // CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT224]] -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T4:.*]] = arith.muli %[[C1_I64]], %[[T2]] : i64 -// CHECK: %[[T5:.*]] = arith.muli %[[T4]], %[[T3]] : i64 -// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : i64 to index +// CHECK: %[[T4:.*]] = shape.shape_of %[[T0]] : tensor -> tensor<4xindex> +// CHECK: %[[T5:.*]] = shape.num_elements %[[T4]] : tensor<4xindex> -> index // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64> -// CHECK: %[[T7:.*]] = stablehlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64> -// CHECK: %[[T8:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor, tensor<2xi64>) -> tensor -// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor -> !torch.vtensor<[?,224],f32> -// CHECK: return %[[T9]] : !torch.vtensor<[?,224],f32> +// CHECK: %[[T6:.*]] = stablehlo.compute_reshape_shape %[[T5]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64> +// CHECK: %[[T7:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T6]] : (tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor -> !torch.vtensor<[?,224],f32> +// CHECK: return %[[T8]] : !torch.vtensor<[?,224],f32> func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> { %int-1 = torch.constant.int -1 %int224 = torch.constant.int 224 @@ -339,17 +337,13 @@ func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT120]] // CHECK: %[[T4:.*]] = torch_c.to_i64 %[[INT4]] // CHECK: %[[T5:.*]] = torch_c.to_i64 %[[INT64]] -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T6:.*]] = arith.muli %[[C1_I64]], %[[T2]] : i64 -// CHECK: %[[T7:.*]] = arith.muli %[[T6]], %[[T3]] : i64 -// CHECK: %[[T8:.*]] = arith.muli %[[T7]], %[[T4]] : i64 -// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64 -// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index +// CHECK: %[[T6:.*]] = shape.shape_of %[[T0]] : tensor -> tensor<5xindex> +// CHECK: %[[T7:.*]] = shape.num_elements %[[T6]] : tensor<5xindex> -> index // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64> -// CHECK: %[[T11:.*]] = stablehlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64> -// CHECK: %[[T12:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor, tensor<4xi64>) -> tensor -// CHECK: %[[T13:.*]] = torch_c.from_builtin_tensor %[[T12]] : tensor -> !torch.vtensor<[?,120,4,64],f32> -// CHECK: return %[[T13]] : !torch.vtensor<[?,120,4,64],f32> +// CHECK: %[[T8:.*]] = stablehlo.compute_reshape_shape %[[T7]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64> +// CHECK: %[[T9:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T8]] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor -> !torch.vtensor<[?,120,4,64],f32> +// CHECK: return %[[T10]] : !torch.vtensor<[?,120,4,64],f32> func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> { %int-1 = torch.constant.int -1 %int120 = torch.constant.int 120