diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
index cf25d27d8c..5ea377887c 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
@@ -41,8 +41,10 @@ bool isValidElementType(Value val) {
 /// detect whether the shapes are exactly the same or not. Hence, return false.
 /// Also, check the ranks of two tensors, they must be in range of (0, 4].
 bool haveSameStaticShape(Value value1, Value value2) {
-  auto valueType1 = value1.getType().cast<ShapedType>();
-  auto valueType2 = value2.getType().cast<ShapedType>();
+  ShapedType valueType1 = value1.getType().cast<ShapedType>();
+  ShapedType valueType2 = value2.getType().cast<ShapedType>();
+  if (!valueType1.hasRank() || !valueType2.hasRank())
+    return false;
   // Different rank, return false.
   if (valueType1.getRank() != valueType2.getRank())
     return false;
@@ -360,8 +362,9 @@ template <>
 bool isSuitableForZDNN<ONNXSoftmaxOp>(ONNXSoftmaxOp op) {
   if (!isValidElementType(op.input()))
     return false;
-  return ((op.axis() == 1 || op.axis() == -1) &&
-          (op.input().getType().cast<ShapedType>().getRank() == 2));
+  ShapedType inputType = op.getType().cast<ShapedType>();
+  return (op.axis() == 1 || op.axis() == -1) && inputType.hasRank() &&
+         (inputType.getRank() == 2);
 }
 
 /// Check legality for ONNXRelu.
@@ -369,7 +372,8 @@ template <>
 bool isSuitableForZDNN<ONNXReluOp>(ONNXReluOp op) {
   if (!isValidElementType(op.X()))
     return false;
-  return (op.X().getType().cast<ShapedType>().getRank() <= 4);
+  ShapedType xType = op.X().getType().cast<ShapedType>();
+  return xType.hasRank() && (xType.getRank() <= 4);
 }
 
 /// Check legality for ONNXTanh.
@@ -377,7 +381,8 @@ template <>
 bool isSuitableForZDNN<ONNXTanhOp>(ONNXTanhOp op) {
   if (!isValidElementType(op.input()))
     return false;
-  return (op.input().getType().cast<ShapedType>().getRank() <= 4);
+  ShapedType inputType = op.getType().cast<ShapedType>();
+  return inputType.hasRank() && (inputType.getRank() <= 4);
 }
 
 /// Check legality for ONNXSigmoid.
@@ -385,7 +390,8 @@ template <>
 bool isSuitableForZDNN<ONNXSigmoidOp>(ONNXSigmoidOp op) {
   if (!isValidElementType(op.X()))
     return false;
-  return (op.X().getType().cast<ShapedType>().getRank() <= 4);
+  ShapedType xType = op.X().getType().cast<ShapedType>();
+  return xType.hasRank() && (xType.getRank() <= 4);
 }
 
 /// Check legality for ONNXLog.
@@ -393,7 +399,8 @@ template <>
 bool isSuitableForZDNN<ONNXLogOp>(ONNXLogOp op) {
   if (!isValidElementType(op.input()))
     return false;
-  return (op.input().getType().cast<ShapedType>().getRank() <= 4);
+  ShapedType inputType = op.input().getType().cast<ShapedType>();
+  return inputType.hasRank() && (inputType.getRank() <= 4);
 }
 
 /// Check legality for ONNXExp.
@@ -401,7 +408,8 @@ template <>
 bool isSuitableForZDNN<ONNXExpOp>(ONNXExpOp op) {
   if (!isValidElementType(op.input()))
     return false;
-  return (op.input().getType().cast<ShapedType>().getRank() <= 4);
+  ShapedType inputType = op.input().getType().cast<ShapedType>();
+  return inputType.hasRank() && (inputType.getRank() <= 4);
 }
 
 /// Check legality for ONNXMatMul.
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td
index 9afcb9c72d..4e17d4c339 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td
@@ -33,7 +33,16 @@ def IsNoneType : Constraint<CPred<"($_self.getType().isa<NoneType>())">>;
 def IsNotNoneType : Constraint<CPred<"!($_self.getType().isa<NoneType>())">>;
 
 class HasRankOf<int rank> : Constraint<
-  CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>
+  CPred<"$0.getType().isa<ShapedType>() && "
+        "$0.getType().cast<ShapedType>().hasRank() && "
+        "$0.getType().cast<ShapedType>().getRank() == " # rank>
+>;
+
+def IsBiasNoneOr1D : Constraint<
+  CPred<"$_self.getType().isa<NoneType>() || "
+        " ($_self.getType().isa<ShapedType>() && "
+        "  $_self.getType().cast<ShapedType>().hasRank() && "
+        "  $_self.getType().cast<ShapedType>().getRank() == 1)">
 >;
 
 class VariadicSizeIs<int n> : Constraint<
@@ -536,14 +545,15 @@ def normalizeONNXGemmTransBPattern : Pat<
   (addBenefit 1)
 >;
 
-def replaceONNXGemmBias1DPattern : Pat<
+
+def replaceONNXGemmBiasNoneOr1DPattern : Pat<
   (ONNXGemmOp $a, $b, $c, $_, $_, $_, $_),
   (ZHighUnstickOp
     (ZHighMatMulOp
       (ZHighStickOp $a, (_2DLayoutAttr)),
       (ZHighStickOp $b, (_2DLayoutAttr)),
       (ZHighStickOp $c, (_1DLayoutAttr)))),
-  [(HasRankOf<1> $c)],
+  [(IsBiasNoneOr1D:$c)],
   (addBenefit 0)
 >;
 
diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gemm.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gemm.mlir
index 91ef05ccc0..60b143baed 100644
--- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gemm.mlir
+++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gemm.mlir
@@ -1,5 +1,23 @@
 // RUN: onnx-mlir-opt --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s
 
+func @test_gemm_bias_none(%arg0 : tensor<10x5xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
+  %bias = "onnx.NoValue"() {value} : () -> none
+  %0 ="onnx.Gemm"(%arg0, %arg1, %bias) {alpha = 1.0 : f32, beta = 1.0 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<10x5xf32>, tensor<5x10xf32>, none) -> tensor<*xf32>
+  "func.return"(%0) : (tensor<*xf32>) -> ()
+
+// CHECK-LABEL: func @test_gemm_bias_none
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x5xf32>, [[PARAM_1_:%.+]]: tensor<5x10xf32>) -> tensor<10x10xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<10x5xf32>) -> tensor<10x5xf32, #zhigh.encoding<{dataLayout = "2D"}>>
+// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<5x10xf32>) -> tensor<5x10xf32, #zhigh.encoding<{dataLayout = "2D"}>>
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<10x5xf32, #zhigh.encoding<{dataLayout = "2D"}>>, tensor<5x10xf32, #zhigh.encoding<{dataLayout = "2D"}>>, none) -> tensor<*xf32>
+// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf32>) -> tensor<10x10xf32>
+// CHECK: return [[VAR_4_]] : tensor<10x10xf32>
+// CHECK: }
+}
+
+// -----
+
 func @test_gemm_bias_1d(%arg0 : tensor<10x5xf32>, %arg1 : tensor<5x10xf32>, %arg2: tensor<10xf32>) -> tensor<*xf32> {
   %0 ="onnx.Gemm"(%arg0, %arg1, %arg2) {alpha = 1.0 : f32, beta = 1.0 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<10x5xf32>, tensor<5x10xf32>, tensor<10xf32>) -> tensor<*xf32>
   "func.return"(%0) : (tensor<*xf32>) -> ()