diff --git a/lib/Conversion/StablehloToEmitC/StablehloToEmitC.cpp b/lib/Conversion/StablehloToEmitC/StablehloToEmitC.cpp
index 73148047..1de065e7 100644
--- a/lib/Conversion/StablehloToEmitC/StablehloToEmitC.cpp
+++ b/lib/Conversion/StablehloToEmitC/StablehloToEmitC.cpp
@@ -518,6 +518,10 @@ void populateStablehloToEmitcPatterns(MLIRContext *ctx,
                                                     "emitc::stablehlo::sin");
   patterns.add<CallOpConversion<stablehlo::SqrtOp>>(
       ctx, "emitc::stablehlo::sqrt");
+  // `stablehlo.rsqrt` is lowered like the other unary elementwise ops: one
+  // opaque call to `emitc::stablehlo::rsqrt`, no dedicated pattern class.
+  patterns.add<CallOpConversion<stablehlo::RsqrtOp>>(
+      ctx, "emitc::stablehlo::rsqrt");
   patterns.add<CallOpConversion<stablehlo::TanhOp>>(
       ctx, "emitc::stablehlo::tanh");
 
@@ -616,6 +620,7 @@ struct ConvertStablehloToEmitCPass
                         stablehlo::RoundOp,
                         stablehlo::SineOp,
                         stablehlo::SqrtOp,
+                        stablehlo::RsqrtOp,
                         stablehlo::TanhOp>();
 
     // StableHLO binary elementwise ops.
diff --git a/reference-implementation/include/emitc/core_ops.h b/reference-implementation/include/emitc/core_ops.h
index beb279ca..2d862187 100644
--- a/reference-implementation/include/emitc/core_ops.h
+++ b/reference-implementation/include/emitc/core_ops.h
@@ -44,6 +44,17 @@ inline Src ceil(Src x) {
   return unary<Src>(x, f);
 }
 
+// ReciprocalOp
+// Applies `1 / element` through `unary`. This is a plain reciprocal; the
+// square root needed for rsqrt is composed by the caller.
+template <typename Src>
+inline Src reciprocal(Src x) {
+  using ET_Src = typename get_element_type<Src>::type;
+
+  auto f = [](ET_Src element) { return (static_cast<ET_Src>(1.0) / element); };
+
+  return unary<Src>(x, f);
+}
+
 // ConvertOp
 template <typename Dest, typename Src>
 inline Dest convert(Src x) {
diff --git a/reference-implementation/include/emitc/stablehlo.h b/reference-implementation/include/emitc/stablehlo.h
index a349734b..7999f2ff 100644
--- a/reference-implementation/include/emitc/stablehlo.h
+++ b/reference-implementation/include/emitc/stablehlo.h
@@ -170,6 +170,13 @@ inline Src sqrt(Src x) {
   return emitc::sqrt<Src>(x);
 }
 
+// RsqrtOp
+// Reciprocal square root: the reciprocal of `emitc::sqrt(x)`.
+template <typename Src>
+inline Src rsqrt(Src x) {
+  return emitc::reciprocal<Src>(emitc::sqrt<Src>(x));
+}
+
 // TanhOp
 template <typename Src>
 inline Src tanh(Src x) {
diff --git a/reference-implementation/include/emitc/tosa.h b/reference-implementation/include/emitc/tosa.h
index 787d62a3..9cc2d63c 100644
--- a/reference-implementation/include/emitc/tosa.h
+++ b/reference-implementation/include/emitc/tosa.h
@@ -94,11 +94,7 @@ inline Src negate(Src x) {
 // ReciprocalOp
 template <typename Src>
 inline Src reciprocal(Src x) {
-  using ET_Src = typename get_element_type<Src>::type;
-
-  auto f = [](ET_Src element) { return (static_cast<ET_Src>(1.0) / element); };
-
-  return unary<Src>(x, f);
+  return emitc::reciprocal<Src>(x);
 }
 
 // RescaleOp
diff --git a/reference-implementation/unittests/stablehlo.cpp b/reference-implementation/unittests/stablehlo.cpp
index 71333fac..aae61bab 100644
--- a/reference-implementation/unittests/stablehlo.cpp
+++ b/reference-implementation/unittests/stablehlo.cpp
@@ -627,5 +627,47 @@ TEST(stablehlo, sqrt) {
   }
 }
 
+TEST(stablehlo, rsqrt) {
+  // Expected values are 1 / sqrt(x), so every input below is positive.
+  {
+    Tensor0D<float> x{4.0f};
+    Tensor0D<float> expected_result{0.5f};
+    Tensor0D<float> result = stablehlo::rsqrt(x);
+
+    EXPECT_THAT(result, Pointwise(FloatEq(), expected_result));
+  }
+  {
+    Tensor1D<double, 2> x{0.25, 2.0};
+    Tensor1D<double, 2> expected_result{2.0, 0.7071067811865475};
+    Tensor1D<double, 2> result = stablehlo::rsqrt(x);
+
+    EXPECT_THAT(result, Pointwise(DoubleEq(), expected_result));
+  }
+  {
+    Tensor2D<float, 2, 3> x{1.0f, 4.0f, 16.0f, 64.0f, 0.25f, 0.0625f};
+    Tensor2D<float, 2, 3> expected_result{1.0f,   0.5f, 0.25f,
+                                          0.125f, 2.0f, 4.0f};
+    Tensor2D<float, 2, 3> result = stablehlo::rsqrt(x);
+
+    EXPECT_THAT(result, Pointwise(FloatEq(), expected_result));
+  }
+  {
+    Tensor3D<double, 2, 1, 2> x{100.0, 6.25, 1.0e+4, 0.015625};
+    Tensor3D<double, 2, 1, 2> expected_result{0.1, 0.4, 0.01, 8.0};
+    Tensor3D<double, 2, 1, 2> result = stablehlo::rsqrt(x);
+
+    EXPECT_THAT(result, Pointwise(DoubleEq(), expected_result));
+  }
+  {
+    Tensor4D<float, 2, 2, 1, 2> x{1.0f,  4.0f,  16.0f,   25.0f,
+                                  64.0f, 0.25f, 0.0625f, 1.0e+4f};
+    Tensor4D<float, 2, 2, 1, 2> expected_result{
+        1.0f, 0.5f, 0.25f, 0.2f, 0.125f, 2.0f, 4.0f, 0.01f};
+    Tensor4D<float, 2, 2, 1, 2> result = stablehlo::rsqrt(x);
+
+    EXPECT_THAT(result, Pointwise(FloatEq(), expected_result));
+  }
+}
+
 TEST(stablehlo, tanh) {
   EXPECT_NEAR(0.0f, stablehlo::tanh(0.0f), EPSILON);