Capture rrelu_with_noise noise mutation in compile (#1144)
1. Resolves pytorch/pytorch#142102
2. Fixes mixed device types in the input Tensors of `torch.lerp` (a 0-dim CPU weight tensor combined with XPU inputs)

---------

Co-authored-by: Feng Yuan <[email protected]>
Co-authored-by: chunhuanMeng <[email protected]>
3 people authored Dec 11, 2024
1 parent 4beb7d3 commit 28cdc6b
Showing 5 changed files with 24 additions and 9 deletions.
6 changes: 3 additions & 3 deletions src/ATen/native/xpu/RreluWithNoise.cpp
@@ -6,7 +6,7 @@ namespace native {
 
 Tensor& rrelu_with_noise_out_xpu(
     const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
@@ -18,7 +18,7 @@ Tensor& rrelu_with_noise_out_xpu(
 
 Tensor rrelu_with_noise_xpu(
     const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
@@ -30,7 +30,7 @@ Tensor rrelu_with_noise_xpu(
 
 Tensor& rrelu_with_noise_xpu_(
     Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
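
Note: dropping `const` on `noise` in these signatures makes the in-place write to the noise tensor part of the op's contract, matching the `Tensor(b!)` schema change below. A minimal sketch of the behavior this preserves under `torch.compile` (hypothetical repro, assuming an XPU build; shapes and bounds are illustrative):

```python
import torch

def f(x, noise):
    # In training mode, rrelu_with_noise samples a per-element slope and
    # writes it into `noise`; the compiled graph must capture that mutation.
    return torch.ops.aten.rrelu_with_noise(x, noise, 0.125, 1.0 / 3.0, True)

x = torch.randn(8, device="xpu")
noise = torch.zeros(8, device="xpu")
out = torch.compile(f)(x, noise)
print(noise)  # expected: sampled slopes, no longer all zeros
```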
14 changes: 14 additions & 0 deletions src/ATen/native/xpu/sycl/LerpKernels.cpp
@@ -57,15 +57,29 @@ struct LerpScalarFunctor {
   opmath_t weight_val_;
 };
 
+void lerp_scalar_kernel(
+    at::TensorIteratorBase& iter,
+    const c10::Scalar& weight);
+
 void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
   auto dtype = iter.common_dtype();
   if (at::isComplexType(dtype)) {
     AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_xpu", [&] {
+      if (iter.is_cpu_scalar(3)) {
+        auto weight_val = iter.scalar_value<scalar_t>(3);
+        iter.remove_operand(3);
+        return lerp_scalar_kernel(iter, weight_val);
+      }
       gpu_kernel(iter, LerpTensorComplexFunctor<scalar_t>());
     });
   } else {
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "lerp_xpu", [&] {
+          if (iter.is_cpu_scalar(3)) {
+            auto weight_val = iter.scalar_value<scalar_t>(3);
+            iter.remove_operand(3);
+            return lerp_scalar_kernel(iter, weight_val);
+          }
           gpu_kernel(iter, LerpTensorFunctor<scalar_t>());
         });
   }
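
The new `is_cpu_scalar(3)` branches handle a weight passed as a 0-dim CPU tensor while `self` and `end` live on the XPU device: the weight is read out as a host scalar, removed from the iterator, and the call is re-routed to `lerp_scalar_kernel` (forward-declared above) rather than launching a tensor kernel against a CPU pointer. A usage sketch of the case being fixed (assuming an XPU build):

```python
import torch

start = torch.randn(4, device="xpu")
end = torch.randn(4, device="xpu")
weight = torch.tensor(0.25)  # 0-dim weight tensor left on the CPU

# Mixed devices: previously this combination could fail on XPU; with the
# branch above, the CPU scalar is folded into the scalar kernel path.
out = torch.lerp(start, end, weight)
print(out.device)  # xpu:0
```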
4 changes: 2 additions & 2 deletions src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp
@@ -86,7 +86,7 @@ template <typename scalar_t>
 inline void _rrelu_with_noise_xpu_train(
     Tensor& output,
     const Tensor& input_,
-    const Tensor& noise_,
+    Tensor& noise_,
     const Scalar& lower_,
     const Scalar& upper_,
     std::optional<Generator> generator) {
@@ -153,7 +153,7 @@ inline void _rrelu_with_noise_xpu_train(
 
 Tensor& rrelu_with_noise_kernel(
     const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
2 changes: 1 addition & 1 deletion src/ATen/native/xpu/sycl/RreluWithNoiseKernels.h
@@ -7,7 +7,7 @@ namespace at::native::xpu {
 
 TORCH_XPU_API Tensor& rrelu_with_noise_kernel(
     const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
7 changes: 4 additions & 3 deletions yaml/native/native_functions.yaml
@@ -8184,25 +8184,26 @@
   variants: function
   tags: pointwise
 
-- func: rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+- func: rrelu_with_noise.out(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
   tags: nondeterministic_seeded
   dispatch:
     XPU: rrelu_with_noise_out_xpu
 
-- func: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
+- func: rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
   python_module: nn
   dispatch:
     XPU: rrelu_with_noise_xpu
   tags: nondeterministic_seeded
+  autogen: rrelu_with_noise_functional
 
 - func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor
   python_module: nn
   dispatch:
     CompositeExplicitAutograd: rrelu_with_noise_backward
   autogen: rrelu_with_noise_backward.out
 
-- func: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
+- func: rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
   python_module: nn
   tags: nondeterministic_seeded
   dispatch:
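
The `(b!)` annotation declares `noise` as a mutable alias in the schema, which is what lets functionalization under `torch.compile` see the write; the added `autogen: rrelu_with_noise_functional` entry generates the out-of-place variant that the functionalization pass substitutes. A quick way to confirm the registered schema (sketch, assuming a build that includes these YAML changes):

```python
import torch

# The printed schema should show `Tensor(b!) noise` rather than `Tensor noise`.
print(torch.ops.aten.rrelu_with_noise.default._schema)
```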
