Add aten::special_scaled_modified_bessel_* and its variants (#1038)
- [x] special_scaled_modified_bessel_k0
- [x] special_scaled_modified_bessel_k0.out
- [x] special_scaled_modified_bessel_k1
- [x] special_scaled_modified_bessel_k1.out
- [x] special_xlog1py
- [x] special_xlog1py.out
- [x] special_zeta
- [x] special_zeta.out
- [x] special_entr
- [x] special_entr.out
- [x] special_erfcx
- [x] special_erfcx.out

Co-authored-by: Yutao Xu <[email protected]>
yucai-intel and xytintel authored Nov 4, 2024
1 parent 76a690a · commit f224138
Showing 19 changed files with 300 additions and 21 deletions.
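
For orientation, the sketch below (not part of the commit) shows how the newly supported ops could be exercised through ATen's generated at::special_* C++ entry points, assuming an XPU-enabled build and an available device; with the matching names removed from XPUFallback.template further down, these calls should dispatch to the new SYCL kernels rather than falling back to CPU.

#include <ATen/ATen.h>

int main() {
  // Illustrative only: inputs kept positive where the functions expect it.
  at::Tensor x = at::rand({8}, at::device(at::kXPU).dtype(at::kFloat)) + 0.5;

  at::Tensor k0e = at::special_scaled_modified_bessel_k0(x); // scaled_modified_bessel_k0_kernel
  at::Tensor k1e = at::special_scaled_modified_bessel_k1(x); // scaled_modified_bessel_k1_kernel
  at::Tensor xl  = at::special_xlog1py(x, x);                // xlog1py_kernel
  at::Tensor z   = at::special_zeta(x + 1.5, x);             // zeta_kernel
  at::Tensor e   = at::special_entr(x);                      // entr_kernel
  at::Tensor ecx = at::special_erfcx(x);                     // erfcx_kernel
  return 0;
}
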
33 changes: 26 additions & 7 deletions src/ATen/native/xpu/Bessel.cpp
@@ -1,7 +1,7 @@
#include <ATen/native/UnaryOps.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/xpu/sycl/BesselJ0Kernel.h>
#include <ATen/native/xpu/sycl/BesselJ1Kernel.h>
#include <ATen/native/xpu/sycl/BesselY0Kernel.h>
@@ -10,6 +10,8 @@
#include <ATen/native/xpu/sycl/ModifiedBesselI1Kernel.h>
#include <ATen/native/xpu/sycl/ModifiedBesselK0Kernel.h>
#include <ATen/native/xpu/sycl/ModifiedBesselK1Kernel.h>
#include <ATen/native/xpu/sycl/ScaledModifiedBesselK0Kernel.h>
#include <ATen/native/xpu/sycl/ScaledModifiedBesselK1Kernel.h>
#include <ATen/native/xpu/sycl/SphericalBesselJ0Kernel.h>

namespace at {
@@ -18,10 +20,27 @@ REGISTER_XPU_DISPATCH(special_bessel_j0_stub, &xpu::bessel_j0_kernel);
REGISTER_XPU_DISPATCH(special_bessel_j1_stub, &xpu::bessel_j1_kernel);
REGISTER_XPU_DISPATCH(special_bessel_y0_stub, &xpu::bessel_y0_kernel);
REGISTER_XPU_DISPATCH(special_bessel_y1_stub, &xpu::bessel_y1_kernel);
REGISTER_XPU_DISPATCH(special_modified_bessel_i0_stub, &xpu::modified_bessel_i0_kernel);
REGISTER_XPU_DISPATCH(special_modified_bessel_i1_stub, &xpu::modified_bessel_i1_kernel);
REGISTER_XPU_DISPATCH(special_modified_bessel_k0_stub, &xpu::modified_bessel_k0_kernel);
REGISTER_XPU_DISPATCH(special_modified_bessel_k1_stub, &xpu::modified_bessel_k1_kernel);
REGISTER_XPU_DISPATCH(special_spherical_bessel_j0_stub, &xpu::spherical_bessel_j0_kernel);
REGISTER_XPU_DISPATCH(
special_modified_bessel_i0_stub,
&xpu::modified_bessel_i0_kernel);
REGISTER_XPU_DISPATCH(
special_modified_bessel_i1_stub,
&xpu::modified_bessel_i1_kernel);
REGISTER_XPU_DISPATCH(
special_modified_bessel_k0_stub,
&xpu::modified_bessel_k0_kernel);
REGISTER_XPU_DISPATCH(
special_modified_bessel_k1_stub,
&xpu::modified_bessel_k1_kernel);
REGISTER_XPU_DISPATCH(
special_spherical_bessel_j0_stub,
&xpu::spherical_bessel_j0_kernel);
REGISTER_XPU_DISPATCH(
special_scaled_modified_bessel_k0_stub,
&xpu::scaled_modified_bessel_k0_kernel);
REGISTER_XPU_DISPATCH(
special_scaled_modified_bessel_k1_stub,
&xpu::scaled_modified_bessel_k1_kernel);

} // namespace native
} // namespace at
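
Note on the pattern above: REGISTER_XPU_DISPATCH binds each DispatchStub declared in ATen/native (for example special_scaled_modified_bessel_k0_stub) to the SYCL kernel for the XPU backend, so a structured op called on an XPU tensor reaches the kernel registered here. As a rough conceptual sketch only (not ATen's actual DispatchStub implementation), the mechanism amounts to a per-backend function-pointer table:

#include <stdexcept>

struct TensorIteratorBase;                       // stand-in for ATen's iterator type
using kernel_fn = void (*)(TensorIteratorBase&);

struct DispatchStubSketch {
  kernel_fn cpu = nullptr, cuda = nullptr, xpu = nullptr;

  // What a macro like REGISTER_XPU_DISPATCH conceptually does.
  void register_xpu(kernel_fn fn) { xpu = fn; }

  void call_xpu(TensorIteratorBase& iter) const {
    if (!xpu) throw std::runtime_error("no XPU kernel registered for this stub");
    xpu(iter);
  }
};
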
7 changes: 5 additions & 2 deletions src/ATen/native/xpu/BinaryOps.cpp
@@ -14,6 +14,7 @@
#include <ATen/native/xpu/sycl/BinaryMiscOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryRemainderKernel.h>
#include <ATen/native/xpu/sycl/BinaryShiftOpsKernels.h>
#include <ATen/native/xpu/sycl/ChebyshevPolynomialKernels.h>
#include <ATen/native/xpu/sycl/CopysignKernel.h>
#include <ATen/native/xpu/sycl/GcdLcmKernels.h>
#include <ATen/native/xpu/sycl/IGammaKernel.h>
@@ -23,9 +24,9 @@
#include <ATen/native/xpu/sycl/LegendrePolynomialPKernel.h>
#include <ATen/native/xpu/sycl/LogAddExpKernels.h>
#include <ATen/native/xpu/sycl/MaxMinElementwiseKernels.h>
#include <ATen/native/xpu/sycl/StepKernels.h>
#include <ATen/native/xpu/sycl/ChebyshevPolynomialKernels.h>
#include <ATen/native/xpu/sycl/ShiftedChebyshevPolynomialKernels.h>
#include <ATen/native/xpu/sycl/StepKernels.h>
#include <ATen/native/xpu/sycl/ZetaKernel.h>

namespace at {
namespace native {
@@ -65,6 +66,8 @@ REGISTER_XPU_DISPATCH(fmin_stub, &xpu::fmin_kernel);
REGISTER_XPU_DISPATCH(lshift_stub, &xpu::lshift_kernel);
REGISTER_XPU_DISPATCH(rshift_stub, &xpu::rshift_kernel);
REGISTER_XPU_DISPATCH(xlogy_stub, &xpu::xlogy_kernel);
REGISTER_XPU_DISPATCH(xlog1py_stub, &xpu::xlog1py_kernel);
REGISTER_XPU_DISPATCH(zeta_stub, &xpu::zeta_kernel);
REGISTER_XPU_DISPATCH(
hermite_polynomial_h_stub,
&xpu::hermite_polynomial_h_kernel);
3 changes: 3 additions & 0 deletions src/ATen/native/xpu/UnaryOps.cpp
@@ -61,6 +61,7 @@ REGISTER_XPU_DISPATCH(acos_stub, &xpu::acos_kernel);
REGISTER_XPU_DISPATCH(acosh_stub, &xpu::acosh_kernel);
REGISTER_XPU_DISPATCH(erf_stub, &xpu::erf_kernel);
REGISTER_XPU_DISPATCH(erfc_stub, &xpu::erfc_kernel);

REGISTER_XPU_DISPATCH(erfinv_stub, &xpu::erfinv_kernel);
REGISTER_XPU_DISPATCH(exp2_stub, &xpu::exp2_kernel);
REGISTER_XPU_DISPATCH(expm1_stub, &xpu::expm1_kernel);
@@ -86,6 +87,8 @@ REGISTER_XPU_DISPATCH(special_i1_stub, &xpu::i1_kernel);
REGISTER_XPU_DISPATCH(special_i1e_stub, &xpu::i1e_kernel);
REGISTER_XPU_DISPATCH(special_ndtri_stub, &xpu::ndtri_kernel);
REGISTER_XPU_DISPATCH(special_log_ndtr_stub, &xpu::log_ndtr_kernel);
REGISTER_XPU_DISPATCH(special_erfcx_stub, &xpu::erfcx_kernel);
REGISTER_XPU_DISPATCH(special_entr_stub, &xpu::entr_kernel);

} // namespace native
} // namespace at
6 changes: 0 additions & 6 deletions src/ATen/native/xpu/XPUFallback.template
@@ -201,12 +201,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"_segment_reduce_backward",
"sinc.out",
"special_airy_ai.out",
"special_entr.out",
"special_erfcx.out",
"special_scaled_modified_bessel_k0.out",
"special_scaled_modified_bessel_k1.out",
"special_xlog1py.out",
"special_zeta.out",
"_thnn_fused_gru_cell",
"_to_sparse",
"_to_sparse_csr",
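
Dropping these six schemas from the fallback list is what makes the new kernels reachable: any name left in this file is still routed to the generic boxed CPU fallback under the XPU dispatch key. A hedged sketch of what one such per-op fallback registration amounts to (illustrative only; the template's actual wiring and helper names may differ):

#include <ATen/native/CPUFallback.h>
#include <torch/library.h>

// Illustrative sketch: forward one op to ATen's generic boxed CPU fallback.
static void xpu_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  at::native::cpu_fallback(op, stack);
}

TORCH_LIBRARY_IMPL(aten, XPU, m) {
  m.impl("special_airy_ai.out",
         torch::CppFunction::makeFromBoxedFunction<&xpu_cpu_fallback>());
}
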
2 changes: 1 addition & 1 deletion src/ATen/native/xpu/sycl/BesselJ0Kernel.cpp
@@ -21,4 +21,4 @@ void bessel_j0_kernel(TensorIteratorBase& iter) {
});
}

}
} // namespace at::native::xpu
4 changes: 2 additions & 2 deletions src/ATen/native/xpu/sycl/BesselJ1Kernel.cpp
@@ -1,7 +1,7 @@
#include <ATen/Dispatch.h>
#include <ATen/native/xpu/sycl/MathExtensions.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/xpu/sycl/Loops.h>
#include <ATen/native/xpu/sycl/MathExtensions.h>
#include <c10/core/Scalar.h>

#include <ATen/native/xpu/sycl/BesselJ1Kernel.h>
@@ -24,4 +24,4 @@ void bessel_j1_kernel(TensorIteratorBase& iter) {
});
}

}
} // namespace at::native::xpu
2 changes: 1 addition & 1 deletion src/ATen/native/xpu/sycl/BesselY0Kernel.cpp
@@ -21,4 +21,4 @@ void bessel_y0_kernel(TensorIteratorBase& iter) {
});
}

}
} // namespace at::native::xpu
4 changes: 2 additions & 2 deletions src/ATen/native/xpu/sycl/BesselY1Kernel.cpp
@@ -1,7 +1,7 @@
#include <ATen/Dispatch.h>
#include <ATen/native/xpu/sycl/MathExtensions.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/xpu/sycl/Loops.h>
#include <ATen/native/xpu/sycl/MathExtensions.h>
#include <c10/core/Scalar.h>

#include <ATen/native/xpu/sycl/BesselY1Kernel.h>
@@ -21,4 +21,4 @@ void bessel_y1_kernel(TensorIteratorBase& iter) {
});
}

}
} // namespace at::native::xpu
22 changes: 22 additions & 0 deletions src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp
@@ -94,4 +94,26 @@ void xlogy_kernel(TensorIteratorBase& iter) {
[&]() { gpu_kernel_with_scalars(iter, XlogyFunctor<scalar_t>()); });
}

template <typename scalar_t>
struct Xlog1pyFunctor {
scalar_t operator()(scalar_t x, scalar_t y) const {
if (at::_isnan(y)) {
return NAN;
}
if (x == 0) {
return 0;
}
return x * std::log1p(y);
}
};

void xlog1py_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
iter.common_dtype(),
"xlog1py_xpu",
[&]() { gpu_kernel_with_scalars(iter, Xlog1pyFunctor<scalar_t>()); });
}

} // namespace at::native::xpu
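
The new functor implements the usual xlog1py semantics: NaN in y propagates, and x == 0 short-circuits to 0 so that 0 * log1p(y) is well defined even where log1p(y) would be -inf or NaN. A minimal host-side reference of the same rules (not the SYCL kernel):

#include <cmath>
#include <cstdio>

double xlog1py_reference(double x, double y) {
  if (std::isnan(y)) return NAN;  // NaN in y always propagates
  if (x == 0.0) return 0.0;       // defined as 0 even if log1p(y) is -inf or NaN
  return x * std::log1p(y);
}

int main() {
  std::printf("%f\n", xlog1py_reference(0.0, -1.0)); // 0, not -inf
  std::printf("%f\n", xlog1py_reference(2.0, 0.5));  // 2 * log(1.5)
  return 0;
}
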
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h
@@ -12,4 +12,6 @@ TORCH_XPU_API void huber_kernel(TensorIterator& iter, double delta);

TORCH_XPU_API void xlogy_kernel(TensorIteratorBase& iter);

TORCH_XPU_API void xlog1py_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
25 changes: 25 additions & 0 deletions src/ATen/native/xpu/sycl/ScaledModifiedBesselK0Kernel.cpp
@@ -0,0 +1,25 @@
#include <ATen/Dispatch.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/xpu/sycl/Loops.h>
#include <c10/core/Scalar.h>

#include <ATen/native/xpu/sycl/ScaledModifiedBesselK0Kernel.h>

namespace at::native::xpu {

template <typename scalar_t>
struct ScaledModifiedBesselK0Functor {
scalar_t operator()(scalar_t a) const {
return scaled_modified_bessel_k0_forward(a);
}
};

void scaled_modified_bessel_k0_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES(
iter.common_dtype(), "scaled_modified_bessel_k0_xpu", [&]() {
gpu_kernel(iter, ScaledModifiedBesselK0Functor<scalar_t>());
});
}

} // namespace at::native::xpu
9 changes: 9 additions & 0 deletions src/ATen/native/xpu/sycl/ScaledModifiedBesselK0Kernel.h
@@ -0,0 +1,9 @@
#pragma once

#include <ATen/native/TensorIterator.h>

namespace at::native::xpu {

TORCH_XPU_API void scaled_modified_bessel_k0_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
25 changes: 25 additions & 0 deletions src/ATen/native/xpu/sycl/ScaledModifiedBesselK1Kernel.cpp
@@ -0,0 +1,25 @@
#include <ATen/Dispatch.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/xpu/sycl/Loops.h>
#include <c10/core/Scalar.h>

#include <ATen/native/xpu/sycl/ScaledModifiedBesselK1Kernel.h>

namespace at::native::xpu {

template <typename scalar_t>
struct ScaledModifiedBesselK1Functor {
scalar_t operator()(scalar_t a) const {
return scaled_modified_bessel_k1_forward(a);
}
};

void scaled_modified_bessel_k1_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES(
iter.common_dtype(), "scaled_modified_bessel_k1_xpu", [&]() {
gpu_kernel(iter, ScaledModifiedBesselK1Functor<scalar_t>());
});
}

} // namespace at::native::xpu
9 changes: 9 additions & 0 deletions src/ATen/native/xpu/sycl/ScaledModifiedBesselK1Kernel.h
@@ -0,0 +1,9 @@
#pragma once

#include <ATen/native/TensorIterator.h>

namespace at::native::xpu {

TORCH_XPU_API void scaled_modified_bessel_k1_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
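
Both scaled kernels defer to the shared scaled_modified_bessel_k*_forward helpers from ATen/native/Math.h; mathematically these are the exponentially scaled modified Bessel functions of the second kind, k0e(x) = exp(x) * K0(x) and k1e(x) = exp(x) * K1(x). A host-side cross-check sketch using C++17's std::cyl_bessel_k (available in libstdc++ and MSVC, not in libc++; illustrative only, not the SYCL code path):

#include <cmath>
#include <cstdio>

double k0e_reference(double x) { return std::exp(x) * std::cyl_bessel_k(0.0, x); }
double k1e_reference(double x) { return std::exp(x) * std::cyl_bessel_k(1.0, x); }

int main() {
  const double xs[] = {0.5, 1.0, 4.0};
  for (double x : xs)
    std::printf("x=%g  k0e=%.9f  k1e=%.9f\n", x, k0e_reference(x), k1e_reference(x));
  return 0;
}
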
36 changes: 36 additions & 0 deletions src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp
@@ -270,4 +270,40 @@ void log_ndtr_kernel(TensorIteratorBase& iter) {
});
}

template <typename scalar_t>
struct EntrFunctor {
scalar_t operator()(scalar_t x) const {
if (at::_isnan(x)) {
return x;
} else if (x > 0) {
return -x * std::log(x);
} else if (x == 0) {
return 0;
}
return static_cast<scalar_t>(-std::numeric_limits<scalar_t>::infinity());
}
};

void entr_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half,
ScalarType::BFloat16,
iter.common_dtype(),
"entr_xpu",
[&]() { gpu_kernel(iter, EntrFunctor<scalar_t>()); });
}

template <typename scalar_t>
struct ErfcxFunctor {
scalar_t operator()(scalar_t a) const {
return calc_erfcx(a);
}
};

void erfcx_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "erfcx_xpu", [&]() {
gpu_kernel(iter, ErfcxFunctor<scalar_t>());
});
}

} // namespace at::native::xpu
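
The two functors above cover entr(x) = -x * log(x) for x > 0 (0 at x = 0, -inf for negative x, NaN passed through) and erfcx(x) = exp(x^2) * erfc(x), the scaled complementary error function; calc_erfcx avoids the overflow of the naive product for large x. A minimal host-side sketch of the same semantics (not the SYCL kernels):

#include <cmath>
#include <cstdio>
#include <limits>

double entr_reference(double x) {
  if (std::isnan(x)) return x;                      // NaN passes through
  if (x > 0.0) return -x * std::log(x);
  if (x == 0.0) return 0.0;
  return -std::numeric_limits<double>::infinity();  // negative input
}

// Naive erfcx: fine for small x, but overflows / loses precision for large x,
// which is why the kernel uses calc_erfcx instead.
double erfcx_naive(double x) { return std::exp(x * x) * std::erfc(x); }

int main() {
  std::printf("entr(0.5)=%f  erfcx(1)=%f\n", entr_reference(0.5), erfcx_naive(1.0));
  return 0;
}
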
4 changes: 4 additions & 0 deletions src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.h
@@ -30,4 +30,8 @@ TORCH_XPU_API void ndtri_kernel(TensorIteratorBase& iter);

TORCH_XPU_API void log_ndtr_kernel(TensorIteratorBase& iter);

TORCH_XPU_API void entr_kernel(TensorIteratorBase& iter);

TORCH_XPU_API void erfcx_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
26 changes: 26 additions & 0 deletions src/ATen/native/xpu/sycl/ZetaKernel.cpp
@@ -0,0 +1,26 @@
#include <ATen/Dispatch.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/xpu/sycl/Loops.h>
#include <c10/core/Scalar.h>

#include <ATen/native/xpu/sycl/ZetaKernel.h>

namespace at::native::xpu {

template <typename scalar_t>
struct ZetaFunctor {
scalar_t operator()(scalar_t x, scalar_t q) const {
return zeta<scalar_t, /*is_xpu=*/true>(x, q);
}
};

constexpr char zeta_name[] = "zeta";
void zeta_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "zeta_xpu", [&]() {
gpu_kernel_with_scalars(iter, ZetaFunctor<scalar_t>());
});
}

} // namespace at::native::xpu
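
The functor computes the Hurwitz zeta function, zeta(x, q) = sum over k >= 0 of 1 / (q + k)^x, via the shared zeta<scalar_t, true> helper from ATen/native/Math.h rather than a naive sum. For intuition only, a truncated-sum host-side sketch:

#include <cmath>
#include <cstdio>

// Naive truncated Hurwitz zeta: zeta(x, q) ~= sum_{k=0}^{terms-1} (q + k)^(-x).
// Converges only for x > 1, q > 0; illustration, not the kernel's algorithm.
double hurwitz_zeta_naive(double x, double q, int terms) {
  double s = 0.0;
  for (int k = 0; k < terms; ++k) s += std::pow(q + k, -x);
  return s;
}

int main() {
  // zeta(2, 1) approaches pi^2 / 6 ~= 1.644934 as terms grows.
  std::printf("%f\n", hurwitz_zeta_naive(2.0, 1.0, 100000));
  return 0;
}
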
9 changes: 9 additions & 0 deletions src/ATen/native/xpu/sycl/ZetaKernel.h
@@ -0,0 +1,9 @@
#pragma once

#include <ATen/native/TensorIterator.h>

namespace at::native::xpu {

TORCH_XPU_API void zeta_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu