-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add aten::special_scaled_modified_bessel_* and its variants (#1038)
- [x] special_scaled_modified_bessel_k0 - [x] special_scaled_modified_bessel_k0.out - [x] special_scaled_modified_bessel_k1 - [x] special_scaled_modified_bessel_k1.out - [x] special_xlog1py - [x] special_xlog1py.out - [x] special_zeta - [x] special_zeta.out - [x] special_entr - [x] special_entr.out - [x] special_erfcx - [x] special_erfcx.out Co-authored-by: Yutao Xu <[email protected]>
- Loading branch information
1 parent
76a690a
commit f224138
Showing
19 changed files
with
300 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,4 +21,4 @@ void bessel_j0_kernel(TensorIteratorBase& iter) { | |
}); | ||
} | ||
|
||
} | ||
} // namespace at::native::xpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,4 +21,4 @@ void bessel_y0_kernel(TensorIteratorBase& iter) { | |
}); | ||
} | ||
|
||
} | ||
} // namespace at::native::xpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#include <ATen/Dispatch.h> | ||
#include <ATen/native/Math.h> | ||
#include <ATen/native/TensorIterator.h> | ||
#include <ATen/native/xpu/sycl/Loops.h> | ||
#include <c10/core/Scalar.h> | ||
|
||
#include <ATen/native/xpu/sycl/ScaledModifiedBesselK0Kernel.h> | ||
|
||
namespace at::native::xpu { | ||
|
||
template <typename scalar_t> | ||
struct ScaledModifiedBesselK0Functor { | ||
scalar_t operator()(scalar_t a) const { | ||
return scaled_modified_bessel_k0_forward(a); | ||
} | ||
}; | ||
|
||
void scaled_modified_bessel_k0_kernel(TensorIteratorBase& iter) { | ||
AT_DISPATCH_FLOATING_TYPES( | ||
iter.common_dtype(), "scaled_modified_bessel_k0_xpu", [&]() { | ||
gpu_kernel(iter, ScaledModifiedBesselK0Functor<scalar_t>()); | ||
}); | ||
} | ||
|
||
} // namespace at::native::xpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#pragma once | ||
|
||
#include <ATen/native/TensorIterator.h> | ||
|
||
namespace at::native::xpu { | ||
|
||
TORCH_XPU_API void scaled_modified_bessel_k0_kernel(TensorIteratorBase& iter); | ||
|
||
} // namespace at::native::xpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#include <ATen/Dispatch.h> | ||
#include <ATen/native/Math.h> | ||
#include <ATen/native/TensorIterator.h> | ||
#include <ATen/native/xpu/sycl/Loops.h> | ||
#include <c10/core/Scalar.h> | ||
|
||
#include <ATen/native/xpu/sycl/ScaledModifiedBesselK1Kernel.h> | ||
|
||
namespace at::native::xpu { | ||
|
||
template <typename scalar_t> | ||
struct ScaledModifiedBesselK1Functor { | ||
scalar_t operator()(scalar_t a) const { | ||
return scaled_modified_bessel_k1_forward(a); | ||
} | ||
}; | ||
|
||
void scaled_modified_bessel_k1_kernel(TensorIteratorBase& iter) { | ||
AT_DISPATCH_FLOATING_TYPES( | ||
iter.common_dtype(), "scaled_modified_bessel_k1_xpu", [&]() { | ||
gpu_kernel(iter, ScaledModifiedBesselK1Functor<scalar_t>()); | ||
}); | ||
} | ||
|
||
} // namespace at::native::xpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#pragma once | ||
|
||
#include <ATen/native/TensorIterator.h> | ||
|
||
namespace at::native::xpu { | ||
|
||
TORCH_XPU_API void scaled_modified_bessel_k1_kernel(TensorIteratorBase& iter); | ||
|
||
} // namespace at::native::xpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#include <ATen/Dispatch.h> | ||
#include <ATen/native/BinaryOps.h> | ||
#include <ATen/native/Math.h> | ||
#include <ATen/native/TensorIterator.h> | ||
#include <ATen/native/xpu/sycl/Loops.h> | ||
#include <c10/core/Scalar.h> | ||
|
||
#include <ATen/native/xpu/sycl/ZetaKernel.h> | ||
|
||
namespace at::native::xpu { | ||
|
||
template <typename scalar_t> | ||
struct ZetaFunctor { | ||
scalar_t operator()(scalar_t x, scalar_t q) const { | ||
return zeta<scalar_t, /*is_xpu=*/true>(x, q); | ||
} | ||
}; | ||
|
||
constexpr char zeta_name[] = "zeta"; | ||
void zeta_kernel(TensorIteratorBase& iter) { | ||
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "zeta_xpu", [&]() { | ||
gpu_kernel_with_scalars(iter, ZetaFunctor<scalar_t>()); | ||
}); | ||
} | ||
|
||
} // namespace at::native::xpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#pragma once | ||
|
||
#include <ATen/native/TensorIterator.h> | ||
|
||
namespace at::native::xpu { | ||
|
||
TORCH_XPU_API void zeta_kernel(TensorIteratorBase& iter); | ||
|
||
} // namespace at::native::xpu |
Oops, something went wrong.