Add aten::log_normal_ (#766)
- log_normal_
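For reference, a minimal usage sketch of the new in-place op (a hypothetical example, assuming a PyTorch build with this XPU backend and an "xpu" device available):

import torch

# Fill a tensor in place with log-normal samples; mean/std parameterize the
# underlying normal distribution, so every sample is strictly positive.
x = torch.empty(4, device="xpu")
x.log_normal_(mean=1.0, std=2.0)
print(x)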

---------

Co-authored-by: Yutao Xu <[email protected]>
hjhee and xytintel authored Oct 17, 2024
1 parent 94d0ee6 commit c3f7c54
Showing 7 changed files with 71 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/ATen/native/xpu/Distributions.cpp
@@ -28,5 +28,6 @@
 REGISTER_XPU_DISPATCH(
     multinomial_with_replacement_stub,
     &xpu::multinomial_kernel);
+REGISTER_XPU_DISPATCH(log_normal_stub, &xpu::log_normal_kernel);
 } // namespace native
 } // namespace at
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
@@ -213,7 +213,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
      "_linalg_svd.U",
      "linspace.out",
      "_logcumsumexp",
-      "log_normal_",
      "logspace.out",
      "lu_unpack.out",
      "max_pool3d_with_indices",
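(Removing "log_normal_" from this list takes the op off the XPU-to-CPU fallback path, so calls now dispatch to the native XPU kernel registered in Distributions.cpp above.)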
6 changes: 6 additions & 0 deletions src/ATen/native/xpu/sycl/DistributionKernels.h
@@ -45,4 +45,10 @@ TORCH_XPU_API void exponential_kernel(
     double lambda,
     c10::optional<Generator> gen);
 
+TORCH_XPU_API void log_normal_kernel(
+    TensorIteratorBase& iter,
+    double mean,
+    double std,
+    std::optional<Generator> gen);
+
 } // namespace at::native::xpu
17 changes: 17 additions & 0 deletions src/ATen/native/xpu/sycl/DistributionLogNormalKernel.cpp
@@ -0,0 +1,17 @@
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/xpu/sycl/DistributionTemplates.h>
+#include <ATen/xpu/XPUGeneratorImpl.h>
+
+namespace at::native::xpu {
+
+void log_normal_kernel(
+    TensorIteratorBase& iter,
+    double mean,
+    double std,
+    std::optional<Generator> gen) {
+  auto generator = get_generator_or_default<at::XPUGeneratorImpl>(
+      gen, at::xpu::detail::getDefaultXPUGenerator());
+  at::native::templates::xpu::log_normal_kernel(iter, mean, std, generator);
+}
+
+} // namespace at::native::xpu
38 changes: 38 additions & 0 deletions src/ATen/native/xpu/sycl/DistributionTemplates.h
@@ -698,6 +698,44 @@ void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG gen) {
   });
 }
 
+// ====================== LogNormal ======================
+
+template <typename scalar_t, typename accscalar_t>
+struct LogNormalFunctor {
+  scalar_t operator()(accscalar_t rand) const {
+    return static_cast<scalar_t>(transformation::log_normal<accscalar_t>(
+        transformation::normal<accscalar_t>(rand, mean_, std_)));
+  }
+  LogNormalFunctor(accscalar_t mean, accscalar_t std)
+      : mean_(mean), std_(std) {}
+
+ private:
+  accscalar_t mean_;
+  accscalar_t std_;
+};
+
+template <typename RNG>
+void log_normal_kernel(
+    TensorIteratorBase& iter,
+    double mean,
+    double std,
+    RNG gen) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.dtype(),
+      "log_normal_xpu_",
+      [&] {
+        using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
+        auto mean_ = static_cast<accscalar_t>(mean);
+        auto std_ = static_cast<accscalar_t>(std);
+        // functor shifts/scales the raw normal draw to N(mean, std), then exponentiates
+        LogNormalFunctor<scalar_t, accscalar_t> log_normal_functor(mean_, std_);
+        normal_and_transform<scalar_t, accscalar_t, rand4_engine_calls>(
+            iter, gen, log_normal_functor);
+      });
+}
+
 } // namespace xpu
 } // namespace templates
 } // namespace native
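The kernel composes two transforms: normal_and_transform supplies a normal draw, transformation::normal shifts/scales it to N(mean, std), and transformation::log_normal exponentiates it (matching the CUDA transformation helpers), so each sample is X = exp(Z) with Z ~ N(mean, std). A device-agnostic sketch of the same composition, for illustration only:

import torch

mean, std = 1.0, 2.0
z = torch.randn(1_000_000) * std + mean  # normal transform: Z ~ N(mean, std)
x = z.exp()                              # log_normal transform: X = exp(Z)

# log(X) recovers the parameters of the underlying normal
print(x.log().mean())  # ~= 1.0
print(x.log().std())   # ~= 2.0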
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
@@ -249,6 +249,7 @@
     "square",
     "heaviside",
     "argsort",
+    "log_normal",
 ]
 
 _ops_without_cuda_support = [
8 changes: 8 additions & 0 deletions yaml/native/native_functions.yaml
@@ -4100,6 +4100,14 @@
     XPU: fmin_out
   tags: pointwise
 
+- func: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)
+  device_check: NoCheck # TensorIterator
+  tags: nondeterministic_seeded
+  variants: method
+  dispatch:
+    XPU: log_normal_
+  autogen: log_normal, log_normal.out
+
 - func: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
   tags: nondeterministic_seeded
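Per this schema, mean defaults to 1 and std to 2, the op accepts an explicit Generator (hence the nondeterministic_seeded tag), and autogen adds out-of-place log_normal / log_normal.out overloads. A sketch of seeded, reproducible sampling (hypothetical, assuming the build exposes XPU generators via torch.Generator(device="xpu")):

import torch

g = torch.Generator(device="xpu")
g.manual_seed(0)
a = torch.empty(3, device="xpu").log_normal_(generator=g)  # defaults: mean=1, std=2

g.manual_seed(0)
b = torch.empty(3, device="xpu").log_normal_(generator=g)
assert torch.equal(a, b)  # same seed -> identical samples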
