From 0e1beb7b15a4e0ab6b87fa30b9960bbd4bf53259 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Mon, 25 Mar 2024 20:18:48 +0800 Subject: [PATCH] Complement part of activation ops (#49) e.g. relu, threshold, threshold_backward, gelu, gelu_backward, tanh_backward --------- Signed-off-by: Feng Yuan Co-authored-by: Feng Yuan --- .gitignore | 10 ++ src/aten/Activation.cpp | 124 ++++++++++++++++++ src/aten/BinaryOps.cpp | 19 +++ src/aten/sycl/ActivationGeluKernel.cpp | 123 +++++++++++++++++ src/aten/sycl/ActivationGeluKernel.h | 17 +++ src/aten/sycl/ActivationOpsKernels.cpp | 28 ++++ src/aten/sycl/ActivationOpsKernels.h | 13 ++ src/aten/sycl/ActivationThresholdKernel.cpp | 41 ++++++ src/aten/sycl/ActivationThresholdKernel.h | 16 +++ .../sycl/BinaryMiscBackwardOpsKernels.cpp | 49 +++++++ src/aten/sycl/BinaryMiscBackwardOpsKernels.h | 13 ++ src/comm/XPUMathCompat.h | 24 ++++ test/xpu/test_ops.py | 3 + yaml/xpu_functions.yaml | 15 +++ 14 files changed, 495 insertions(+) create mode 100644 .gitignore create mode 100644 src/aten/Activation.cpp create mode 100644 src/aten/sycl/ActivationGeluKernel.cpp create mode 100644 src/aten/sycl/ActivationGeluKernel.h create mode 100644 src/aten/sycl/ActivationOpsKernels.cpp create mode 100644 src/aten/sycl/ActivationOpsKernels.h create mode 100644 src/aten/sycl/ActivationThresholdKernel.cpp create mode 100644 src/aten/sycl/ActivationThresholdKernel.h create mode 100644 src/aten/sycl/BinaryMiscBackwardOpsKernels.cpp create mode 100644 src/aten/sycl/BinaryMiscBackwardOpsKernels.h create mode 100644 src/comm/XPUMathCompat.h diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..a8bdf2194 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +*/*.pyc +*/*.so* +*/**/__pycache__ +*/**/*.dylib* +*/**/*.pyc +*/**/*.pyd +*/**/*.so* +*/**/**/*.pyc +*/**/**/**/*.pyc +*/**/**/**/**/*.pyc diff --git a/src/aten/Activation.cpp b/src/aten/Activation.cpp new file mode 100644 index 000000000..da0848dba --- /dev/null +++ b/src/aten/Activation.cpp @@ -0,0 +1,124 @@ +#include +#include +#include +#include + +#include +#include +#include + +namespace at { + +Tensor XPUNativeFunctions::relu(const Tensor& self) { + Tensor out; + auto iter = TensorIterator::unary_op(out, self); + native::xpu::relu_kernel(iter); + return iter.output(); +} + +Tensor& XPUNativeFunctions::relu_(Tensor& self) { + auto iter = TensorIterator::unary_op(self, self); + native::xpu::relu_kernel(iter); + return self; +} + +Tensor& XPUNativeFunctions::relu_out(const Tensor& self, Tensor& out) { + auto iter = TensorIterator::unary_op(out, self); + native::xpu::relu_kernel(iter); + return out; +} + +Tensor XPUNativeFunctions::threshold( + const Tensor& self, + const Scalar& threshold, + const Scalar& value) { + Tensor out; + auto iter = TensorIterator::binary_op(out, self, self); + native::xpu::threshold_kernel(iter, threshold, value); + return iter.output(); +} + +Tensor& XPUNativeFunctions::threshold_( + Tensor& self, + const Scalar& threshold, + const Scalar& value) { + auto iter = TensorIterator::binary_op(self, self, self); + native::xpu::threshold_kernel(iter, threshold, value); + return self; +} + +Tensor& XPUNativeFunctions::threshold_out( + const Tensor& self, + const Scalar& threshold, + const Scalar& value, + Tensor& out) { + auto iter = TensorIterator::binary_op(out, self, self); + native::xpu::threshold_kernel(iter, threshold, value); + return out; +} + +Tensor XPUNativeFunctions::threshold_backward( + const Tensor& grad_output, + const Tensor& self, + const Scalar& threshold) { + Tensor grad_input; + auto iter = TensorIterator::binary_op(grad_input, self, grad_output); + native::xpu::threshold_kernel(iter, threshold, 0); + return iter.output(); +} + +Tensor& XPUNativeFunctions::threshold_backward_out( + const Tensor& grad_output, + const Tensor& self, + const Scalar& threshold, + Tensor& grad_input) { + auto iter = TensorIterator::binary_op(grad_input, self, grad_output); + native::xpu::threshold_kernel(iter, threshold, 0); + return grad_input; +} + +Tensor XPUNativeFunctions::gelu( + const Tensor& self, + c10::string_view approximate) { + Tensor out; + auto iter = TensorIterator::unary_op(out, self); + native::xpu::gelu_kernel(iter, approximate); + return iter.output(); +} + +Tensor& XPUNativeFunctions::gelu_(Tensor& self, c10::string_view approximate) { + auto iter = TensorIterator::unary_op(self, self); + native::xpu::gelu_kernel(iter, approximate); + return self; +} + +Tensor& XPUNativeFunctions::gelu_out( + const Tensor& self, + c10::string_view approximate, + Tensor& out) { + auto iter = TensorIterator::unary_op(out, self); + native::xpu::gelu_kernel(iter, approximate); + return out; +} + +Tensor XPUNativeFunctions::gelu_backward( + const Tensor& grad_output, + const Tensor& self, + c10::string_view approximate) { + Tensor grad_input; + auto iter = TensorIterator::binary_op(grad_input, grad_output, self); + native::xpu::gelu_backward_kernel(iter, approximate); + return iter.output(); +} + +Tensor& XPUNativeFunctions::gelu_backward_out( + const Tensor& grad_output, + const Tensor& self, + c10::string_view approximate, + Tensor& grad_input) { + auto iter = TensorIterator::binary_op(grad_input, grad_output, self); + native::xpu::gelu_backward_kernel(iter, approximate); + return grad_input; +} + +} // namespace at diff --git a/src/aten/BinaryOps.cpp b/src/aten/BinaryOps.cpp index 503a1eddf..d38525ee7 100644 --- a/src/aten/BinaryOps.cpp +++ b/src/aten/BinaryOps.cpp @@ -5,6 +5,7 @@ #include #include +#include #include namespace at { @@ -331,4 +332,22 @@ Tensor& XPUNativeFunctions::fmod_out( return XPUNativeFunctions::fmod_out(self, wrapper, out); } +Tensor XPUNativeFunctions::tanh_backward( + const Tensor& grad_output, + const Tensor& output) { + Tensor out; + auto iter = TensorIterator::binary_op(out, grad_output, output); + native::xpu::tanh_backward_kernel(iter); + return iter.output(); +} + +Tensor& XPUNativeFunctions::tanh_backward_out( + const Tensor& grad_output, + const Tensor& output, + Tensor& grad_input) { + auto iter = TensorIterator::binary_op(grad_input, grad_output, output); + native::xpu::tanh_backward_kernel(iter); + return grad_input; +} + } // namespace at diff --git a/src/aten/sycl/ActivationGeluKernel.cpp b/src/aten/sycl/ActivationGeluKernel.cpp new file mode 100644 index 000000000..847ce6146 --- /dev/null +++ b/src/aten/sycl/ActivationGeluKernel.cpp @@ -0,0 +1,123 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace xpu { + +template +struct GeluTanhFunctor { + scalar_t operator()(scalar_t x) const { + using opmath_t = at::opmath_type; + constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); + constexpr opmath_t kKappa = 0.044715; + auto x_cube = static_cast(x) * static_cast(x) * + static_cast(x); + auto inner = kBeta * (static_cast(x) + kKappa * x_cube); + return opmath_t(0.5) * static_cast(x) * + (opmath_t(1) + c10::xpu::compat::tanh(inner)); + } +}; + +template +struct GeluTanhBackwardFunctor { + scalar_t operator()(scalar_t dy, scalar_t x) const { + using opmath_t = at::opmath_type; + constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); + constexpr opmath_t kKappa = 0.044715; + auto x_sq = static_cast(x) * static_cast(x); + auto x_cube = x_sq * static_cast(x); + auto inner = kBeta * (static_cast(x) + kKappa * x_cube); + auto tanh_inner = c10::xpu::compat::tanh(inner); + + auto left = opmath_t(0.5) * static_cast(x); + auto right = opmath_t(1) + tanh_inner; + + auto left_derivative = opmath_t(0.5) * right; + + auto tanh_derivative = opmath_t(1) - tanh_inner * tanh_inner; + auto inner_derivative = kBeta * (opmath_t(1) + opmath_t(3) * kKappa * x_sq); + auto right_derivative = left * tanh_derivative * inner_derivative; + + return static_cast(dy) * (left_derivative + right_derivative); + } +}; + +template +struct GeluErfFunctor { + scalar_t operator()(scalar_t x) const { + using opmath_t = at::opmath_type; + constexpr opmath_t kAlpha = M_SQRT1_2; + return static_cast(x) * opmath_t(0.5) * + (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); + } +}; + +template +struct GeluErfBackwardFunctor { + scalar_t operator()(scalar_t dy, scalar_t x) const { + using opmath_t = at::opmath_type; + constexpr opmath_t kBeta = M_2_SQRTPI * M_SQRT1_2 * opmath_t(0.5); + constexpr opmath_t kAlpha = M_SQRT1_2; + const opmath_t cdf = opmath_t(0.5) * + (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); + const opmath_t pdf = c10::xpu::compat::exp( + opmath_t(-0.5) * static_cast(x) * + static_cast(x)) * + kBeta; + return static_cast(dy) * (cdf + static_cast(x) * pdf); + } +}; + +void gelu_kernel(TensorIteratorBase& iter, c10::string_view approximate) { + auto approximate_ = at::native::get_gelutype_enum(approximate); + if (approximate_ == at::native::GeluType::Tanh) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::BFloat16, + at::ScalarType::Half, + iter.dtype(), + "gelu_tanh_xpu", + [&]() { gpu_kernel(iter, GeluTanhFunctor()); }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::BFloat16, + at::ScalarType::Half, + iter.dtype(), + "gelu_erf_xpu", + [&]() { gpu_kernel(iter, GeluErfFunctor()); }); + } +} + +void gelu_backward_kernel( + TensorIteratorBase& iter, + c10::string_view approximate) { + auto approximate_ = at::native::get_gelutype_enum(approximate); + if (approximate_ == at::native::GeluType::Tanh) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::BFloat16, + at::ScalarType::Half, + iter.dtype(), + "gelu_tanh_backward_xpu", + [&]() { + gpu_kernel_with_scalars(iter, GeluTanhBackwardFunctor()); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::BFloat16, + at::ScalarType::Half, + iter.dtype(), + "gelu_erf_backward_xpu", + [&]() { + gpu_kernel_with_scalars(iter, GeluErfBackwardFunctor()); + }); + } +} + +} // namespace xpu +} // namespace native +} // namespace at diff --git a/src/aten/sycl/ActivationGeluKernel.h b/src/aten/sycl/ActivationGeluKernel.h new file mode 100644 index 000000000..6c373a7cd --- /dev/null +++ b/src/aten/sycl/ActivationGeluKernel.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +namespace at { +namespace native { +namespace xpu { + +void gelu_kernel(TensorIteratorBase& iter, c10::string_view approximate); + +void gelu_backward_kernel( + TensorIteratorBase& iter, + c10::string_view approximate); + +} // namespace xpu +} // namespace native +} // namespace at diff --git a/src/aten/sycl/ActivationOpsKernels.cpp b/src/aten/sycl/ActivationOpsKernels.cpp new file mode 100644 index 000000000..31d6c7f0b --- /dev/null +++ b/src/aten/sycl/ActivationOpsKernels.cpp @@ -0,0 +1,28 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace xpu { + +template +struct ReluFunctor { + scalar_t operator()(scalar_t x) const { + return x <= scalar_t{0} ? scalar_t{0} : x; + } +}; + +void relu_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "relu_xpu", [&]() { + gpu_kernel(iter, ReluFunctor()); + }); +} + +} // namespace xpu +} // namespace native +} // namespace at diff --git a/src/aten/sycl/ActivationOpsKernels.h b/src/aten/sycl/ActivationOpsKernels.h new file mode 100644 index 000000000..a96d9dc3d --- /dev/null +++ b/src/aten/sycl/ActivationOpsKernels.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace at { +namespace native { +namespace xpu { + +void relu_kernel(TensorIteratorBase& iter); + +} // namespace xpu +} // namespace native +} // namespace at diff --git a/src/aten/sycl/ActivationThresholdKernel.cpp b/src/aten/sycl/ActivationThresholdKernel.cpp new file mode 100644 index 000000000..5747cb256 --- /dev/null +++ b/src/aten/sycl/ActivationThresholdKernel.cpp @@ -0,0 +1,41 @@ +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace xpu { + +template +struct ThresholdFunctor { + scalar_t operator()(scalar_t x, scalar_t other) const { + return x <= threshold_ ? value_ : other; + } + + ThresholdFunctor(scalar_t threshold, scalar_t value) + : threshold_(threshold), value_(value) {} + + private: + scalar_t threshold_; + scalar_t value_; +}; + +void threshold_kernel( + TensorIteratorBase& iter, + const Scalar& threshold, + const Scalar& value) { + AT_DISPATCH_ALL_TYPES_AND2( + kHalf, kBFloat16, iter.dtype(), "threshold_xpu", [&]() { + scalar_t threshold_ = threshold.to(); + scalar_t value_ = value.to(); + gpu_kernel_with_scalars( + iter, ThresholdFunctor(threshold_, value_)); + }); +} + +} // namespace xpu +} // namespace native +} // namespace at diff --git a/src/aten/sycl/ActivationThresholdKernel.h b/src/aten/sycl/ActivationThresholdKernel.h new file mode 100644 index 000000000..c22a6b18c --- /dev/null +++ b/src/aten/sycl/ActivationThresholdKernel.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +namespace at { +namespace native { +namespace xpu { + +void threshold_kernel( + TensorIteratorBase& iter, + const Scalar& threshold, + const Scalar& value); + +} +} // namespace native +} // namespace at diff --git a/src/aten/sycl/BinaryMiscBackwardOpsKernels.cpp b/src/aten/sycl/BinaryMiscBackwardOpsKernels.cpp new file mode 100644 index 000000000..9299c090e --- /dev/null +++ b/src/aten/sycl/BinaryMiscBackwardOpsKernels.cpp @@ -0,0 +1,49 @@ +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace xpu { + +template +struct TanhBackwardComplexFunctor { + scalar_t operator()(scalar_t a, scalar_t b) const { + using comp_t = at::opmath_type; + const auto one = comp_t{1.}; + const auto comp_b = static_cast(b); + const auto comp_a = static_cast(a); + return static_cast(comp_a * std::conj(one - comp_b * comp_b)); + } +}; + +template +struct TanhBackwardFunctor { + scalar_t operator()(scalar_t a, scalar_t b) const { + return a * (scalar_t{1.} - b * b); + } +}; + +void tanh_backward_kernel(TensorIteratorBase& iter) { + auto dtype = iter.dtype(); + if (isComplexType(dtype)) { + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, dtype, "tanh_backward_complex_xpu", [&]() { + gpu_kernel(iter, TanhBackwardComplexFunctor()); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + dtype, + "tanh_backward_xpu", + [&]() { gpu_kernel(iter, TanhBackwardFunctor()); }); + } +} + +} // namespace xpu +} // namespace native +} // namespace at diff --git a/src/aten/sycl/BinaryMiscBackwardOpsKernels.h b/src/aten/sycl/BinaryMiscBackwardOpsKernels.h new file mode 100644 index 000000000..0aa30db93 --- /dev/null +++ b/src/aten/sycl/BinaryMiscBackwardOpsKernels.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace at { +namespace native { +namespace xpu { + +void tanh_backward_kernel(TensorIteratorBase& iter); + +} // namespace xpu +} // namespace native +} // namespace at diff --git a/src/comm/XPUMathCompat.h b/src/comm/XPUMathCompat.h new file mode 100644 index 000000000..127417010 --- /dev/null +++ b/src/comm/XPUMathCompat.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include + +#define __MATH_FUNCTIONS_DECL__ static inline + +namespace c10::xpu::compat { + +__MATH_FUNCTIONS_DECL__ float exp(float x) { + return ::expf(x); +} +__MATH_FUNCTIONS_DECL__ double exp(double x) { + return ::exp(x); +} + +__MATH_FUNCTIONS_DECL__ float tanh(float x) { + return ::tanhf(x); +} +__MATH_FUNCTIONS_DECL__ double tanh(double x) { + return ::tanh(x); +} + +} // namespace c10::xpu::compat diff --git a/test/xpu/test_ops.py b/test/xpu/test_ops.py index ae6bb73eb..f40dc4fe0 100644 --- a/test/xpu/test_ops.py +++ b/test/xpu/test_ops.py @@ -67,6 +67,9 @@ "reciprocal", "pow", "unfold", + "nn.functional.threshold", + "nn.functional.relu", + "nn.functional.gelu", ] _xpu_tensor_factory_op_list = [ "normal", diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index ce31f0011..06d3441c3 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -45,6 +45,8 @@ supported: - fmod.Scalar - fmod_.Scalar - fmod.Scalar_out + - tanh_backward + - tanh_backward.grad_input - eq.Scalar - eq.Scalar_out - eq_.Scalar @@ -81,6 +83,19 @@ supported: - ge.Tensor - ge.Tensor_out - ge_.Tensor + - relu + - relu_ + - relu.out + - threshold + - threshold_ + - threshold.out + - threshold_backward + - threshold_backward.grad_input + - gelu + - gelu_ + - gelu.out + - gelu_backward + - gelu_backward.grad_input - abs - abs_ - abs.out