From 42a16fa9121f5e9725786d5e3ef8a4da0cdc0d3c Mon Sep 17 00:00:00 2001
From: Xinyu Yang
Date: Tue, 9 Apr 2024 11:06:53 +0800
Subject: [PATCH] [Torch] Support Aten_CastFloatOp. (#3115)

By canonicalizing Aten_CastFloatOp into AtenToDtypeOp
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 25 +++++++++++++++++++
 lib/Dialect/Torch/IR/TorchOps.cpp             | 21 ++++++++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      |  8 ++++++
 projects/pt1/e2e_testing/xfail_sets.py        |  2 ++
 .../build_tools/abstract_interp_lib_gen.py    |  6 +++++
 .../build_tools/torch_ods_gen.py              |  1 +
 .../torch_mlir_e2e_test/test_suite/basic.py   | 20 +++++++++++++++
 7 files changed, 83 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 7252be53c960..aef700e28c05 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10972,6 +10972,31 @@ def Torch_AtenToDeviceOp : Torch_Op<"aten.to.device", [
   }];
 }
 
+def Torch_Aten_CastFloatOp : Torch_Op<"aten._cast_Float", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_cast_Float : (Tensor, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_BoolType:$non_blocking
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_CastFloatOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void Aten_CastFloatOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasCanonicalizer = 1;
+}
+
 def Torch_AtenTypeAsOp : Torch_Op<"aten.type_as", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index dfd63c71317e..5837e745caf9 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -939,6 +939,27 @@ void AtenToOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
+//===----------------------------------------------------------------------===//
+// Aten_CastFloatOp
+//===----------------------------------------------------------------------===//
+
+void Aten_CastFloatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                   MLIRContext *context) {
+  // `aten.cast_float` -> `aten.to.dtype`
+  patterns.add(+[](Aten_CastFloatOp op, PatternRewriter &rewriter) {
+    auto self = op.getSelf();
+    auto loc = op.getLoc();
+    Value constNone = rewriter.create<ConstantNoneOp>(loc);
+    Value f32Type = rewriter.create<ConstantIntOp>(
+        loc, (int)torch_upstream::ScalarType::Float);
+    Value constFalse = rewriter.create<ConstantBoolOp>(loc, false);
+    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), self, f32Type,
+                                               op.getNonBlocking(), constFalse,
+                                               constNone);
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenViewOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index dde2c40d7053..3e72ee8c6edf 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6786,6 +6786,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten._cast_Float\"(%arg0: !torch.list<int>, %arg1: !torch.bool) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.type_as\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -12714,6 +12718,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten._cast_Float\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.bool) -> !torch.int {\n"
+"    %int6 = torch.constant.int 6\n"
+"    return %int6 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.type_as\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 5e2f2510216a..a896f195941e 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -412,6 +412,7 @@
     "AtenRoundIntModule_basic",
     "AtenSubFloatModule_basic",
     "AtenToDeviceModule_basic",
+    "Aten_CastFloatModule_basic",
     "AvgPool1dStaticModule_basic",
     "AvgPool2dStaticModule_basic",
     "BaddbmmBroadcast1DInputModule_basic",
@@ -943,6 +944,7 @@
     "AtenRoundIntModule_basic",
     "AtenInstanceNormModule_basic",
     "AtenToDeviceModule_basic",
+    "Aten_CastFloatModule_basic",
     "BaddbmmBroadcast1DInputModule_basic",
     "BaddbmmBroadcast2DInputModule_basic",
     "BaddbmmDynamicModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 628cda1ccf85..36f258462cbd 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -391,6 +391,9 @@ def aten〇to〇device〡shape(self: List[int], device: device, dtype: int, non_
 def aten〇to〇other〡shape(self: List[int], other: List[int], non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇_cast_Float〡shape(self: List[int], non_blocking: bool = False) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇type_as〡shape(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -4325,6 +4328,9 @@ def aten〇to〇other〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype
     other_rank, other_dtype = other_rank_dtype
     return other_dtype
 
+def aten〇_cast_Float〡dtype(self_rank_dtype: Tuple[int, int], non_blocking: bool = False) -> int:
+    return torch.float32
+
 @check_dtype_function(_check_two_tensor_op())
 def aten〇type_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
     other_rank, other_dtype = other_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 8fc6ac4eea48..e1118eef4186 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -669,6 +669,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)", has_canonicalizer=True)
     emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
     emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
+    emit("aten::_cast_Float : (Tensor, bool) -> (Tensor)", has_canonicalizer=True)
     emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
     emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
     emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index f164f0beea4b..8152a2b3eea1 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -4240,6 +4240,26 @@ def AtenToDeviceModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class Aten_CastFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 4], torch.int64, True),
+    ])
+
+    def forward(self, val):
+        return torch.ops.aten._cast_Float(val)
+
+@register_test_case(module_factory=lambda: Aten_CastFloatModule())
+def Aten_CastFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 4))
+
+
+# ==============================================================================
+
 class UpSampleNearest2dBackward(torch.nn.Module):
     def __init__(self):
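
For reference (an illustrative sketch, not part of the patch itself): the new canonicalization rewrites aten._cast_Float into aten.to.dtype with the dtype fixed to torch.float32 (ScalarType code 6), copy set to false, and memory_format set to none, while forwarding the original non_blocking operand. The SSA value names and the [2,4] si64 input shape below are assumptions chosen to mirror the added e2e test; only the op names and operand order follow the replaceOpWithNewOp<AtenToDtypeOp> call in the patch.

  // Before canonicalization (assumed names and shapes):
  %0 = torch.aten._cast_Float %self, %nb : !torch.vtensor<[2,4],si64>, !torch.bool -> !torch.vtensor<[2,4],f32>

  // After canonicalization: dtype 6 is torch.float32, copy is false, memory_format is none.
  %none = torch.constant.none
  %int6 = torch.constant.int 6
  %false = torch.constant.bool false
  %0 = torch.aten.to.dtype %self, %int6, %nb, %false, %none : !torch.vtensor<[2,4],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,4],f32>

The Aten_CastFloatModule e2e test added above exercises this path end to end: an int64 input is cast to float32 through the canonicalized aten.to.dtype form.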