From 28cdc6bdba47f6c3a176834cb9e051cdc6392a47 Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Wed, 11 Dec 2024 13:10:48 +0800
Subject: [PATCH] Capture rrelu_with_noise noise mutation in compile (#1144)

1. Resolves https://github.com/pytorch/pytorch/issues/142102 by capturing
   the `rrelu_with_noise` noise mutation under torch.compile.
2. Fixes mixed device types in the input Tensors of `torch.lerp`.

Repro sketches for both fixes are appended after the diff.

---------

Co-authored-by: Feng Yuan
Co-authored-by: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com>
---
 src/ATen/native/xpu/RreluWithNoise.cpp             |  6 +++---
 src/ATen/native/xpu/sycl/LerpKernels.cpp           | 14 ++++++++++++++
 src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp |  4 ++--
 src/ATen/native/xpu/sycl/RreluWithNoiseKernels.h   |  2 +-
 yaml/native/native_functions.yaml                  |  7 ++++---
 5 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/src/ATen/native/xpu/RreluWithNoise.cpp b/src/ATen/native/xpu/RreluWithNoise.cpp
index f66833983..fb4e2c333 100644
--- a/src/ATen/native/xpu/RreluWithNoise.cpp
+++ b/src/ATen/native/xpu/RreluWithNoise.cpp
@@ -6,7 +6,7 @@ namespace native {
 
 Tensor& rrelu_with_noise_out_xpu(
     const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
@@ -18,7 +18,7 @@ Tensor& rrelu_with_noise_out_xpu(
 
 Tensor rrelu_with_noise_xpu(
     const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
@@ -30,7 +30,7 @@ Tensor rrelu_with_noise_xpu(
 
 Tensor& rrelu_with_noise_xpu_(
     Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
diff --git a/src/ATen/native/xpu/sycl/LerpKernels.cpp b/src/ATen/native/xpu/sycl/LerpKernels.cpp
index 1648f193b..9d7551290 100644
--- a/src/ATen/native/xpu/sycl/LerpKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LerpKernels.cpp
@@ -57,15 +57,29 @@ struct LerpScalarFunctor {
   opmath_t weight_val_;
 };
 
+void lerp_scalar_kernel(
+    at::TensorIteratorBase& iter,
+    const c10::Scalar& weight);
+
 void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
   auto dtype = iter.common_dtype();
   if (at::isComplexType(dtype)) {
     AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_xpu", [&] {
+      if (iter.is_cpu_scalar(3)) {
+        auto weight_val = iter.scalar_value<scalar_t>(3);
+        iter.remove_operand(3);
+        return lerp_scalar_kernel(iter, weight_val);
+      }
       gpu_kernel(iter, LerpTensorComplexFunctor<scalar_t>());
     });
   } else {
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "lerp_xpu", [&] {
+          if (iter.is_cpu_scalar(3)) {
+            auto weight_val = iter.scalar_value<scalar_t>(3);
+            iter.remove_operand(3);
+            return lerp_scalar_kernel(iter, weight_val);
+          }
           gpu_kernel(iter, LerpTensorFunctor<scalar_t>());
         });
   }
diff --git a/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp b/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp
index 533630175..7f6f33805 100644
--- a/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp
+++ b/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp
@@ -86,7 +86,7 @@ template <typename scalar_t>
 inline void _rrelu_with_noise_xpu_train(
     Tensor& output,
     const Tensor& input_,
-    const Tensor& noise_,
+    Tensor& noise_,
     const Scalar& lower_,
     const Scalar& upper_,
     std::optional<Generator> generator) {
@@ -153,7 +153,7 @@ inline void _rrelu_with_noise_xpu_train(
 
 Tensor& rrelu_with_noise_kernel(
     const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
diff --git a/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.h b/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.h
index 8371c38ab..fa7e568ea 100644
--- a/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.h
+++ b/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.h
@@ -7,7 +7,7 @@ namespace at::native::xpu {
 
 TORCH_XPU_API Tensor& rrelu_with_noise_kernel(
     const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index 40b710c12..f76f49fb8 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -8184,17 +8184,18 @@
   variants: function
   tags: pointwise
 
-- func: rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+- func: rrelu_with_noise.out(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
   tags: nondeterministic_seeded
   dispatch:
     XPU: rrelu_with_noise_out_xpu
 
-- func: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
+- func: rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
   python_module: nn
   dispatch:
     XPU: rrelu_with_noise_xpu
   tags: nondeterministic_seeded
+  autogen: rrelu_with_noise_functional
 
 - func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor
   python_module: nn
@@ -8202,7 +8203,7 @@
     CompositeExplicitAutograd: rrelu_with_noise_backward
   autogen: rrelu_with_noise_backward.out
 
-- func: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
+- func: rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
   python_module: nn
   tags: nondeterministic_seeded
  dispatch:
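
A minimal repro sketch for fix (1), using only public PyTorch API. It assumes
an XPU-enabled build; the device string "xpu", shapes, and bounds are
illustrative. The point is that rrelu_with_noise with training=True writes the
sampled slopes into `noise` as a side effect, and torch.compile must capture
that mutation now that the schema marks `noise` as Tensor(b!):

    import torch

    def fn(x, noise):
        # training=True: the sampled per-element slopes are written into
        # `noise` as a side effect of the op.
        return torch.ops.aten.rrelu_with_noise(x, noise, 0.125, 1.0 / 3.0, True)

    x = torch.randn(8, device="xpu")
    noise = torch.zeros_like(x)
    out = torch.compile(fn)(x, noise)

    # With the fix, the compiled run mutates `noise` just as eager mode does
    # (the graph should functionalize via the autogenerated
    # rrelu_with_noise_functional and copy the updated noise back); before,
    # the mutation could be dropped from the compiled graph.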
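
A similar sketch for fix (2): torch.lerp where `start`/`end` live on the XPU
device but `weight` is a 0-dim CPU tensor. This is the mixed-device case the
lerp kernel now accepts by detecting the CPU scalar (is_cpu_scalar(3) in the
hunk above) and rerouting to the scalar kernel; device and shapes are again
illustrative:

    import torch

    start = torch.randn(4, device="xpu")
    end = torch.randn(4, device="xpu")
    weight = torch.tensor(0.5)  # 0-dim tensor left on the CPU

    # Previously this device mix could error; it should now behave like the
    # pure-scalar form torch.lerp(start, end, 0.5).
    out = torch.lerp(start, end, weight)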