diff --git a/src/ATen/native/xpu/BinaryOps.cpp b/src/ATen/native/xpu/BinaryOps.cpp
index a7b38bbc9..087bce516 100644
--- a/src/ATen/native/xpu/BinaryOps.cpp
+++ b/src/ATen/native/xpu/BinaryOps.cpp
@@ -47,6 +47,7 @@ REGISTER_XPU_DISPATCH(maximum_stub, &xpu::maximum_kernel);
 REGISTER_XPU_DISPATCH(minimum_stub, &xpu::minimum_kernel);
 REGISTER_XPU_DISPATCH(sigmoid_backward_stub, &xpu::sigmoid_backward_kernel);
 REGISTER_XPU_DISPATCH(nextafter_stub, &xpu::nextafter_kernel);
+REGISTER_XPU_DISPATCH(heaviside_stub, &xpu::heaviside_kernel);
 REGISTER_XPU_DISPATCH(hypot_stub, &xpu::hypot_kernel);
 REGISTER_XPU_DISPATCH(atan2_stub, &xpu::atan2_kernel);
 REGISTER_XPU_DISPATCH(copysign_stub, &xpu::copysign_kernel);
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 052cb23ee..17f0e00bb 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -171,7 +171,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "frexp.Tensor_out",
       "_fused_moving_avg_obs_fq_helper",
       "geqrf",
-      "heaviside.out",
       "i0.out",
       "igammac.out",
       "igamma.out",
diff --git a/src/ATen/native/xpu/sycl/StepKernels.cpp b/src/ATen/native/xpu/sycl/StepKernels.cpp
index 34f703591..f2f16fd92 100644
--- a/src/ATen/native/xpu/sycl/StepKernels.cpp
+++ b/src/ATen/native/xpu/sycl/StepKernels.cpp
@@ -14,6 +14,13 @@ struct NextafterFunctor {
   }
 };
 
+template <typename scalar_t>
+struct HeavisideFunctor {
+  scalar_t operator()(scalar_t a, scalar_t b) const {
+    return a == 0 ? b : static_cast<scalar_t>(a > 0);
+  }
+};
+
 void nextafter_kernel(TensorIteratorBase& iter) {
   AT_DISPATCH_FLOATING_TYPES_AND2(
       kBFloat16, kHalf, iter.common_dtype(), "nextafter_xpu", [&]() {
@@ -21,4 +28,11 @@ void nextafter_kernel(TensorIteratorBase& iter) {
       });
 }
 
+void heaviside_kernel(TensorIteratorBase& iter) {
+  AT_DISPATCH_ALL_TYPES_AND3(
+      kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_xpu", [&]() {
+        gpu_kernel_with_scalars(iter, HeavisideFunctor<scalar_t>());
+      });
+}
+
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/StepKernels.h b/src/ATen/native/xpu/sycl/StepKernels.h
index c026a6b65..1c94e2c9c 100644
--- a/src/ATen/native/xpu/sycl/StepKernels.h
+++ b/src/ATen/native/xpu/sycl/StepKernels.h
@@ -6,4 +6,6 @@ namespace at::native::xpu {
 
 TORCH_XPU_API void nextafter_kernel(TensorIteratorBase& iter);
 
+TORCH_XPU_API void heaviside_kernel(TensorIteratorBase& iter);
+
 } // namespace at::native::xpu
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index 878262f2a..13653ed20 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -125,6 +125,7 @@
     "nn.functional.softplus",
     "nn.functional.softshrink",
     "nextafter",
+    "heaviside",
    "nonzero",
     "normal",
     "pow",
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index dc6dd9992..dd927e5e3 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -4965,6 +4965,25 @@
   variants: method
   tags: pointwise
 
+- func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck # TensorIterator
+  dispatch:
+    XPU: heaviside_out
+  tags: pointwise
+
+- func: heaviside(Tensor self, Tensor values) -> Tensor
+  device_check: NoCheck # TensorIterator
+  variants: function, method
+  structured_delegate: heaviside.out
+  tags: pointwise
+
+- func: heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!)
+  device_check: NoCheck # TensorIterator
+  variants: method
+  structured_delegate: heaviside.out
+
 - func: logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!)
   python_module: nn
   structured: True
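
For reference, a minimal usage sketch of the op this change enables on the XPU backend. This is not part of the patch; it assumes a PyTorch build with torch-xpu-ops and an available XPU device, and falls back to CPU otherwise:

```python
import torch

# torch.heaviside: 0 where input < 0, `values` where input == 0, 1 where input > 0.
device = "xpu" if torch.xpu.is_available() else "cpu"

x = torch.tensor([-1.5, 0.0, 2.0], device=device)
values = torch.tensor([0.5], device=device)  # must share dtype with x; broadcasts

out = torch.heaviside(x, values)
print(out.cpu())  # expected: tensor([0.0000, 0.5000, 1.0000])
```

With the dispatch registration above and `heaviside.out` removed from the XPU fallback list, this call runs the SYCL kernel directly instead of falling back to CPU.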