diff --git a/src/ATen/native/xpu/Distributions.cpp b/src/ATen/native/xpu/Distributions.cpp
index bce51bbf8..264a368f8 100644
--- a/src/ATen/native/xpu/Distributions.cpp
+++ b/src/ATen/native/xpu/Distributions.cpp
@@ -28,5 +28,6 @@ REGISTER_XPU_DISPATCH(
 REGISTER_XPU_DISPATCH(
     multinomial_with_replacement_stub,
     &xpu::multinomial_kernel);
+REGISTER_XPU_DISPATCH(log_normal_stub, &xpu::log_normal_kernel);
 } // namespace native
 } // namespace at
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 7226a6c4e..6819e9c51 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -213,7 +213,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "_linalg_svd.U",
       "linspace.out",
       "_logcumsumexp",
-      "log_normal_",
       "logspace.out",
       "lu_unpack.out",
       "max_pool3d_with_indices",
diff --git a/src/ATen/native/xpu/sycl/DistributionKernels.h b/src/ATen/native/xpu/sycl/DistributionKernels.h
index 88d62933f..494033dbb 100644
--- a/src/ATen/native/xpu/sycl/DistributionKernels.h
+++ b/src/ATen/native/xpu/sycl/DistributionKernels.h
@@ -45,4 +45,10 @@ TORCH_XPU_API void exponential_kernel(
     double lambda,
     c10::optional<Generator> gen);
 
+TORCH_XPU_API void log_normal_kernel(
+    TensorIteratorBase& iter,
+    double mean,
+    double std,
+    std::optional<Generator> gen);
+
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/DistributionLogNormalKernel.cpp b/src/ATen/native/xpu/sycl/DistributionLogNormalKernel.cpp
new file mode 100644
index 000000000..4726209e0
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/DistributionLogNormalKernel.cpp
@@ -0,0 +1,17 @@
+#include <ATen/xpu/XPUGeneratorImpl.h>
+#include <ATen/native/xpu/sycl/DistributionKernels.h>
+#include <ATen/native/xpu/sycl/DistributionTemplates.h>
+
+namespace at::native::xpu {
+
+void log_normal_kernel(
+    TensorIteratorBase& iter,
+    double mean,
+    double std,
+    std::optional<Generator> gen) {
+  auto generator = get_generator_or_default<XPUGeneratorImpl>(
+      gen, at::xpu::detail::getDefaultXPUGenerator());
+  at::native::templates::xpu::log_normal_kernel(iter, mean, std, generator);
+}
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/DistributionTemplates.h b/src/ATen/native/xpu/sycl/DistributionTemplates.h
index f5a5efdb5..f696e50d1 100644
--- a/src/ATen/native/xpu/sycl/DistributionTemplates.h
+++ b/src/ATen/native/xpu/sycl/DistributionTemplates.h
@@ -698,6 +698,44 @@ void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG gen) {
   });
 }
 
+// ====================== LogNormal ======================
+
+template <typename scalar_t, typename accscalar_t>
+struct LogNormalFunctor {
+  scalar_t operator()(accscalar_t rand) const {
+    return static_cast<scalar_t>(transformation::log_normal(
+        transformation::normal(rand, mean_, std_)));
+  }
+  LogNormalFunctor(accscalar_t mean, accscalar_t std)
+      : mean_(mean), std_(std) {}
+
+ private:
+  accscalar_t mean_;
+  accscalar_t std_;
+};
+
+template <typename RNG>
+void log_normal_kernel(
+    TensorIteratorBase& iter,
+    double mean,
+    double std,
+    RNG gen) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.dtype(),
+      "log_normal_xpu_",
+      [&] {
+        using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
+        auto mean_ = static_cast<accscalar_t>(mean);
+        auto std_ = static_cast<accscalar_t>(std);
+        // functor scales the normal sample by std, shifts by mean, then exponentiates
+        LogNormalFunctor<scalar_t, accscalar_t> log_normal_functor(mean_, std_);
+        normal_and_transform<scalar_t, accscalar_t, rand4_engine_calls>(
+            iter, gen, log_normal_functor);
+      });
+}
+
 } // namespace xpu
 } // namespace templates
 } // namespace native
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index 42f6ab715..80922171a 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -249,6 +249,7 @@
     "square",
     "heaviside",
     "argsort",
"log_normal", ] _ops_without_cuda_support = [ diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 71fd4f342..2e9b5aabb 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -4100,6 +4100,14 @@ XPU: fmin_out tags: pointwise +- func: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!) + device_check: NoCheck # TensorIterator + tags: nondeterministic_seeded + variants: method + dispatch: + XPU: log_normal_ + autogen: log_normal, log_normal.out + - func: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator tags: nondeterministic_seeded