Skip to content

Commit

Permalink
Add aten::heaviside and its variants (#1024)
Browse files Browse the repository at this point in the history
- [x] heaviside.out  
- [x] heaviside
- [x] heaviside_

> (kanya_pt)
kanyamo@adl104051:/localdisk/kanya/pytorch/third_party/torch-xpu-ops$
PYTORCH_DEBUG_XPU_FALLBACK=1 PYTORCH_TEST_WITH_SLOW=1 pytest -vs
./test/xpu/test_ops_xpu.py -k heaviside
> ================================================= test session starts
==================================================
> platform linux -- Python 3.10.15, pytest-8.3.3, pluggy-1.5.0 --
/localdisk/miniforge3/envs/kanya_pt/bin/python3.10
> cachedir: .pytest_cache
> hypothesis profile 'default' ->
database=DirectoryBasedExampleDatabase(PosixPath('/localdisk/kanya/pytorch/third_party/torch-xpu-ops/.hypothesis/examples'))
> rootdir: /localdisk/kanya/pytorch
> configfile: pytest.ini
> plugins: hypothesis-6.114.1
> collected 14054 items / 14004 deselected / 50 selected
> 
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_dtypes__refs_heaviside_xpu
PASSED
> test/xpu/test_ops_xpu.py::TestCommonXPU::test_dtypes_heaviside_xpu
PASSED
> test/xpu/test_ops_xpu.py::TestCommonXPU::test_errors_heaviside_xpu
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_multiple_devices_heaviside_xpu_float32
SKIPPED (fewer than 2 d...)
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_multiple_devices_heaviside_xpu_int64
SKIPPED (fewer than 2 dev...)
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_non_standard_bool_values_heaviside_xpu_bool
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_noncontiguous_samples_heaviside_xpu_float32
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_noncontiguous_samples_heaviside_xpu_int64
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_out__refs_heaviside_xpu_float32
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_out_heaviside_xpu_float32
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_out_warning__refs_heaviside_xpu
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_out_warning_heaviside_xpu
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref__refs_heaviside_xpu_bfloat16
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref__refs_heaviside_xpu_bool
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref__refs_heaviside_xpu_float16
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref__refs_heaviside_xpu_float32
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref__refs_heaviside_xpu_int16
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref__refs_heaviside_xpu_int32
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref__refs_heaviside_xpu_int64
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref__refs_heaviside_xpu_int8
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref__refs_heaviside_xpu_uint8
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_errors__refs_heaviside_xpu
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_executor__refs_heaviside_executor_aten_xpu_bfloat16
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_executor__refs_heaviside_executor_aten_xpu_bool
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_executor__refs_heaviside_executor_aten_xpu_float16
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_executor__refs_heaviside_executor_aten_xpu_float32
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_executor__refs_heaviside_executor_aten_xpu_int16
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_executor__refs_heaviside_executor_aten_xpu_int32
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_executor__refs_heaviside_executor_aten_xpu_int64
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_executor__refs_heaviside_executor_aten_xpu_int8
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_executor__refs_heaviside_executor_aten_xpu_uint8
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_meta__refs_heaviside_xpu_bfloat16
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_meta__refs_heaviside_xpu_bool
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_meta__refs_heaviside_xpu_float16
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_meta__refs_heaviside_xpu_float32
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_meta__refs_heaviside_xpu_int16
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_meta__refs_heaviside_xpu_int32
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_meta__refs_heaviside_xpu_int64
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_meta__refs_heaviside_xpu_int8
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_meta__refs_heaviside_xpu_uint8
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_torch_fallback__refs_heaviside_xpu_bfloat16
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_torch_fallback__refs_heaviside_xpu_bool
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_torch_fallback__refs_heaviside_xpu_float16
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_torch_fallback__refs_heaviside_xpu_float32
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_torch_fallback__refs_heaviside_xpu_int16
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_torch_fallback__refs_heaviside_xpu_int32
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_torch_fallback__refs_heaviside_xpu_int64
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_torch_fallback__refs_heaviside_xpu_int8
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_python_ref_torch_fallback__refs_heaviside_xpu_uint8
PASSED
>
test/xpu/test_ops_xpu.py::TestCommonXPU::test_variant_consistency_eager_heaviside_xpu_float32
PASSED
> 
> =================================== 48 passed, 2 skipped, 14004
deselected in 16.83s ===================================
  • Loading branch information
Kanya-Mo authored Oct 29, 2024
1 parent 2d43f11 commit 9b63af2
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/ATen/native/xpu/BinaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ REGISTER_XPU_DISPATCH(maximum_stub, &xpu::maximum_kernel);
REGISTER_XPU_DISPATCH(minimum_stub, &xpu::minimum_kernel);
REGISTER_XPU_DISPATCH(sigmoid_backward_stub, &xpu::sigmoid_backward_kernel);
REGISTER_XPU_DISPATCH(nextafter_stub, &xpu::nextafter_kernel);
REGISTER_XPU_DISPATCH(heaviside_stub, &xpu::heaviside_kernel);
REGISTER_XPU_DISPATCH(hypot_stub, &xpu::hypot_kernel);
REGISTER_XPU_DISPATCH(atan2_stub, &xpu::atan2_kernel);
REGISTER_XPU_DISPATCH(copysign_stub, &xpu::copysign_kernel);
Expand Down
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"frexp.Tensor_out",
"_fused_moving_avg_obs_fq_helper",
"geqrf",
"heaviside.out",
"i0.out",
"igammac.out",
"igamma.out",
Expand Down
14 changes: 14 additions & 0 deletions src/ATen/native/xpu/sycl/StepKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,25 @@ struct NextafterFunctor {
}
};

template <typename scalar_t>
struct HeavisideFunctor {
scalar_t operator()(scalar_t a, scalar_t b) const {
return a == 0 ? b : static_cast<scalar_t>(a > 0);
}
};

void nextafter_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16, kHalf, iter.common_dtype(), "nextafter_xpu", [&]() {
gpu_kernel_with_scalars(iter, NextafterFunctor<scalar_t>());
});
}

void heaviside_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_ALL_TYPES_AND3(
kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_xpu", [&]() {
gpu_kernel_with_scalars(iter, HeavisideFunctor<scalar_t>());
});
}

} // namespace at::native::xpu
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/sycl/StepKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ namespace at::native::xpu {

TORCH_XPU_API void nextafter_kernel(TensorIteratorBase& iter);

TORCH_XPU_API void heaviside_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
"nn.functional.softplus",
"nn.functional.softshrink",
"nextafter",
"heaviside",
"nonzero",
"normal",
"pow",
Expand Down
19 changes: 19 additions & 0 deletions yaml/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4965,6 +4965,25 @@
variants: method
tags: pointwise

- func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
device_check: NoCheck # TensorIterator
dispatch:
XPU: heaviside_out
tags: pointwise

- func: heaviside(Tensor self, Tensor values) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
structured_delegate: heaviside.out
tags: pointwise

- func: heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
structured_delegate: heaviside.out

- func: logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!)
python_module: nn
structured: True
Expand Down

0 comments on commit 9b63af2

Please sign in to comment.