From 00b30b431f7dfe3b3aed880e33e10f38d5dc301c Mon Sep 17 00:00:00 2001 From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:51:19 +0800 Subject: [PATCH 1/8] Unit test: Evaluate and update skip list of test_meta. (#1103) relate issues: https://github.com/intel/torch-xpu-ops/issues/774, https://github.com/intel/torch-xpu-ops/issues/922 --------- Co-authored-by: Yutao Xu --- .../xpu/sycl/UpSampleBilinear2dKernels.cpp | 4 +- test/xpu/extended/skip_list_common.py | 4 - test/xpu/skip_list_common.py | 83 ++++--------------- test/xpu/xpu_test_utils.py | 2 + 4 files changed, 20 insertions(+), 73 deletions(-) diff --git a/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp index bbf51625d..e5a717495 100644 --- a/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp +++ b/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp @@ -711,8 +711,8 @@ void upsample_bilinear2d_backward_out_kernel( : at::zeros(grad_input.sizes(), grad_input.options()); Tensor grad_output = grad_output_.contiguous(); - scalar_t* idata = grad_input_c.data_ptr(); - scalar_t* odata = grad_output.data_ptr(); + scalar_t* idata = grad_input_c.mutable_data_ptr(); + const scalar_t* odata = grad_output.const_data_ptr(); const accscalar_t rheight = area_pixel_compute_scale( input_height, output_height, align_corners, scales_h); diff --git a/test/xpu/extended/skip_list_common.py b/test/xpu/extended/skip_list_common.py index db53d6b4f..6b5fd653e 100644 --- a/test/xpu/extended/skip_list_common.py +++ b/test/xpu/extended/skip_list_common.py @@ -181,10 +181,6 @@ "test_operator_multinomial_xpu_float32", "test_view_replay_multinomial_xpu_float32" - # https://github.com/intel/torch-xpu-ops/issues/922 - "test_compare_cpu_isin_xpu_bfloat16", - "test_compare_cpu_unique_consecutive_xpu_bfloat16", - # returned index is dependent on input data and implementation detail, and no # specification is given to uniquely identify the correct index # (e.g. index with maximal / minimal value) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 32edc14d9..eab76482b 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -660,9 +660,6 @@ # Unexpected success, CUDA got XFAIL because CUDA does not have historgramadd supported" "test_errors_histogramdd_xpu", - # https://github.com/intel/torch-xpu-ops/issues/922 - "test_dtypes_isin_xpu", - # NotImplementedError: The operator 'aten::_assert_async.msg' is not currently implemented for the XPU device. "test_view_replay_multinomial_xpu_float32", @@ -2558,14 +2555,26 @@ "test_meta_xpu.py": ( # https://github.com/intel/torch-xpu-ops/issues/774 "_jiterator_", - # segment fault + + + # RuntimeError: Short is not supported in oneDNN! Need oneDNN's support, suggest to keep skip. "test_dispatch_meta_outplace_nn_functional_linear_xpu_int16", - "test_dispatch_meta_outplace_nn_functional_linear_xpu_int64", "test_dispatch_symbolic_meta_outplace_nn_functional_linear_xpu_int16", - "test_dispatch_symbolic_meta_outplace_nn_functional_linear_xpu_int64", "test_meta_outplace_nn_functional_linear_xpu_int16", + + # RuntimeError: Long is not supported in oneDNN! Need oneDNN's support, suggest to keep skip. 
+ "test_dispatch_meta_outplace_nn_functional_linear_xpu_int64", + "test_dispatch_symbolic_meta_outplace_nn_functional_linear_xpu_int64", "test_meta_outplace_nn_functional_linear_xpu_int64", + # RuntimeError: Double and complex datatype matmul is not supported in oneDNN + + "test_dispatch_meta_inplace_addbmm_xpu_complex", + "test_dispatch_meta_outplace_addbmm_xpu_complex", + "test_dispatch_symbolic_meta_inplace_addbmm_xpu_complex", + "test_dispatch_symbolic_meta_outplace_addbmm_xpu_complex", + "test_meta_inplace_addbmm_xpu_complex", + "test_meta_outplace_addbmm_xpu_complex", "test_dispatch_meta_inplace_addbmm_xpu_float64", "test_dispatch_meta_inplace_addmm_decomposed_xpu_complex", "test_dispatch_meta_inplace_addmm_decomposed_xpu_float64", @@ -3284,72 +3293,12 @@ "test_meta_outplace_nn_functional_conv_transpose3d_xpu_bfloat16", "test_meta_outplace_nn_functional_conv_transpose3d_xpu_complex", "test_meta_outplace_nn_functional_conv_transpose3d_xpu_float", - # _foreach_norm: RuntimeError: output 1: meta disagrees with real impl: - "test_dispatch_meta_outplace__foreach_norm_xpu_bfloat16", - "test_dispatch_meta_outplace__foreach_norm_xpu_float", - "test_dispatch_symbolic_meta_outplace__foreach_norm_xpu_bfloat16", - "test_dispatch_symbolic_meta_outplace__foreach_norm_xpu_float", - "test_dispatch_symbolic_meta_outplace_all_strides__foreach_norm_xpu_float32", - "test_meta_outplace__foreach_norm_xpu_bfloat16", - "test_meta_outplace__foreach_norm_xpu_float", - # RuntimeError: value cannot be converted to type float without overflow - "test_dispatch_meta_inplace_addbmm_xpu_complex", - "test_dispatch_meta_outplace_addbmm_xpu_complex", - "test_dispatch_symbolic_meta_inplace_addbmm_xpu_complex", - "test_dispatch_symbolic_meta_outplace_addbmm_xpu_complex", - "test_meta_inplace_addbmm_xpu_complex", - "test_meta_outplace_addbmm_xpu_complex", - # RuntimeError: false INTERNAL ASSERT FAILED at "pytorch/aten/src/ATen/native/DispatchStub.cpp":220, please report a bug to PyTorch. DispatchStub: missing kernel for xpu - "test_dispatch_meta_outplace_nanmean_xpu", - "test_dispatch_symbolic_meta_outplace_all_strides_nanmean_xpu_float32", - "test_dispatch_symbolic_meta_outplace_nanmean_xpu", - "test_meta_outplace_nanmean_xpu", - # RuntimeError: "avg_pool2d_xpu" not implemented for 'Long' - # run dtype of cpu. It should run dtypeifcuda. 
add 'nn.functional.avg_pool1d' and 'nn.functional.local_response_norm' to '_xpu_computation_op_list' will skip these case - "test_dispatch_meta_outplace_nn_functional_avg_pool1d_xpu_int64", - "test_dispatch_symbolic_meta_outplace_nn_functional_avg_pool1d_xpu_int64", - "test_meta_outplace_nn_functional_avg_pool1d_xpu_int64", - "test_dispatch_meta_outplace_nn_functional_local_response_norm_xpu_int64", - "test_dispatch_symbolic_meta_outplace_nn_functional_local_response_norm_xpu_int64", - "test_meta_outplace_nn_functional_local_response_norm_xpu_int64", - # RuntimeError: output 0: meta disagrees with real impl: + # Not implemented, try these cases after implementing vdot "test_dispatch_meta_outplace_vdot_xpu_complex", "test_dispatch_symbolic_meta_outplace_vdot_xpu_complex", "test_meta_outplace_vdot_xpu_complex", # Unexpected success: - "test_dispatch_meta_inplace__foreach_lgamma_xpu_bfloat16", - "test_dispatch_meta_inplace__foreach_sigmoid_xpu_complex", - "test_dispatch_meta_outplace__foreach_lgamma_xpu_bfloat16", - "test_dispatch_meta_outplace__foreach_sigmoid_xpu_complex", - "test_dispatch_symbolic_meta_inplace__foreach_lgamma_xpu_bfloat16", - "test_dispatch_symbolic_meta_inplace__foreach_sigmoid_xpu_complex", - "test_dispatch_symbolic_meta_outplace__foreach_lgamma_xpu_bfloat16", - "test_dispatch_symbolic_meta_outplace__foreach_sigmoid_xpu_complex", "test_dispatch_symbolic_meta_outplace_all_strides_narrow_copy_xpu_float32", - "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_channel_shuffle_xpu_float32", - "test_meta_inplace__foreach_lgamma_xpu_bfloat16", - "test_meta_inplace__foreach_sigmoid_xpu_complex", - "test_meta_outplace__foreach_lgamma_xpu_bfloat16", - "test_meta_outplace__foreach_sigmoid_xpu_complex", - # adaptive_max_pool2d: Expected out tensor to have dtype c10::BFloat16/c10::Half/float/double, but got long int instead - "test_dispatch_meta_outplace_nn_functional_adaptive_max_pool1d_xpu_bfloat16", - "test_dispatch_meta_outplace_nn_functional_adaptive_max_pool1d_xpu_float", - "test_dispatch_meta_outplace_nn_functional_adaptive_max_pool2d_xpu_bfloat16", - "test_dispatch_meta_outplace_nn_functional_adaptive_max_pool2d_xpu_float", - "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_adaptive_max_pool1d_xpu_float32", - "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_adaptive_max_pool2d_xpu_float32", - "test_dispatch_symbolic_meta_outplace_nn_functional_adaptive_max_pool1d_xpu_bfloat16", - "test_dispatch_symbolic_meta_outplace_nn_functional_adaptive_max_pool1d_xpu_float", - "test_dispatch_symbolic_meta_outplace_nn_functional_adaptive_max_pool2d_xpu_bfloat16", - "test_dispatch_symbolic_meta_outplace_nn_functional_adaptive_max_pool2d_xpu_float", - - # https://github.com/intel/torch-xpu-ops/issues/922 - "test_dispatch_meta_outplace_isin_xpu_bfloat16", - "test_dispatch_meta_outplace_unique_consecutive_xpu_bfloat16", - "test_dispatch_symbolic_meta_outplace_isin_xpu_bfloat16", - "test_dispatch_symbolic_meta_outplace_unique_consecutive_xpu_bfloat16", - "test_meta_outplace_isin_xpu_bfloat16", - "test_meta_outplace_unique_consecutive_xpu_bfloat16", ), "test_type_promotion_xpu.py": None, diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 5d855a72f..5cf1de64c 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -130,6 +130,7 @@ "nn.functional.hardsigmoid", "nn.functional.softplus", "nn.functional.softshrink", + "nn.functional.local_response_norm", "nextafter", "heaviside", "nonzero", @@ -200,6 +201,7 @@ 
"nn.functional.max_pool3d", "nn.functional.adaptive_avg_pool2d", "nn.functional.adaptive_avg_pool3d", + "nn.functional.avg_pool1d", "nn.functional.avg_pool2d", "nn.functional.avg_pool3d", "nn.functional.embedding", From ce98b918f7c408b86bfa8256a7bf41f4127b4d44 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Wed, 27 Nov 2024 13:28:51 +0800 Subject: [PATCH 2/8] Update skip list in `test_ops_xpu` (#1102) Enable more test cases in `test_ops_xpu.py`. --- test/xpu/skip_list_common.py | 45 ------------------------------------ 1 file changed, 45 deletions(-) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index eab76482b..106c2307c 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -55,15 +55,6 @@ # Issue https://github.com/intel/torch-xpu-ops/issues/327 "test_numpy_ref_linalg_tensorinv_xpu_float64", - # RuntimeError: false INTERNAL ASSERT FAILED at "/home/gta/daisyden/pytorch4/aten/src/ATen/native/DispatchStub.cpp":220, please report a bug to PyTorch. DispatchStub: missing kernel for xpu - "test_out_nanmean_xpu_float32", - "test_out_warning_nanmean_xpu", - - # NameError: name 'nanj' is not defined. Did you mean: 'nan'? - # https://github.com/intel/torch-xpu-ops/issues/768 - "test_python_ref_executor__refs_logaddexp_executor_aten_xpu_complex128", - "test_python_ref_executor__refs_logaddexp_executor_aten_xpu_complex64", - # RuntimeError: could not create a primitive descriptor for a deconvolution # https://github.com/intel/torch-xpu-ops/issues/253 "test_variant_consistency_eager_nn_functional_conv_transpose2d_xpu_complex64", @@ -77,7 +68,6 @@ "test_compare_cpu_linalg_lu_factor_xpu_float32", "test_compare_cpu_linalg_lu_xpu_float32", "test_compare_cpu_special_hermite_polynomial_h_xpu_float32", - "test_compare_cpu_special_zeta_xpu_float32", # XFAIL of CUDA and XPU, unexpected success in fallback "test_out_cholesky_inverse_xpu_float32", @@ -104,9 +94,6 @@ # Cuda skipped it "test_non_standard_bool_values_sort_xpu_bool", # The implementation aligns with CUDA, RuntimeError: "sort" not implemented for 'Bool'. - # Cuda skipped it - "test_non_standard_bool_values_msort_xpu_bool", # The implementation aligns with CUDA, RuntimeError: "msort" not implemented for 'Bool'. - # Cuda XFAIL (stock pytorch commit: e7cf7d0) "test_non_standard_bool_values_argsort_xpu_bool", @@ -635,48 +622,16 @@ "test_noncontiguous_samples_nn_functional_avg_pool1d_xpu_int64", "test_noncontiguous_samples_nn_functional_local_response_norm_xpu_int64", - #AssertionError: The supported dtypes for unique_consecutive on device type xpu are incorrect! - #The following dtypes worked in forward but are not listed by the OpInfo: {torch.bfloat16}. - # XPU supports bfloat16, CUDA doesn't support it. - "test_dtypes_unique_xpu", # RuntimeError: Expected both inputs to be Half, Float or Double tensors but got BFloat16 and BFloat16. # Polar's backward is calculated using complex(), which does not support bfloat16. CUDA fails with same error. #"test_dtypes_polar_xpu", # implemented aten::histogram to align MPS operators coverage, CUDA doesn't support # but test_dtypes infrastructure leverage CUDA supported datatypes "test_dtypes_histogram_xpu", - # The following dtypes worked in forward but are not listed by the OpInfo: {torch.float16}. - # Align with CPU implementation since, - # 1. most cases of nextafter require Half dtype. - # 2. Half dtype is a common dtype in workloads. - # So far CUDA doesn't support Half, so that XPU fails as we aligned claimed dtypes with CUDA in test infra. 
- # https://github.com/intel/torch-xpu-ops/issues/623 - "test_dtypes_nextafter_xpu", - # AssertionError: The supported dtypes for argsort on device type xpu are incorrect! - # The following dtypes worked in forward but are not listed by the OpInfo: {torch.bool}. - # CUDA does not have torch.bool support on argsort. - "test_dtypes_argsort_xpu", # Unexpected success, CUDA got XFAIL because CUDA does not have historgramadd supported" "test_errors_histogramdd_xpu", - # NotImplementedError: The operator 'aten::_assert_async.msg' is not currently implemented for the XPU device. - "test_view_replay_multinomial_xpu_float32", - - # AssertionError: The supported dtypes for nn.functional.max_unpool3d on device type xpu are incorrect! - # The following dtypes worked in forward but are not listed by the OpInfo: {torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}. - "test_dtypes_nn_functional_max_unpool3d_grad_xpu", - "test_dtypes_nn_functional_max_unpool3d_xpu", - - # Unknown error with indexSelectBackward - # AssertionError: The supported dtypes for _refs.nn.functional.pdist on device type xpu are incorrect! - # The following dtypes did not work in forward but are listed by the OpInfo: {torch.float64}. - # Unexpected failures raised the following errors: - # torch.float64 - Native API failed. Native API returns: -5 (PI_ERROR_OUT_OF_RESOURCES) -5 (PI_ERROR_OUT_OF_RESOURCES) - # FATAL: Unexpected page fault from GPU at 0x0, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 3 (PML4), access: 0 (Read), banned: 1, aborting. - # FATAL: Unexpected page fault from GPU at 0x0, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 3 (PML4), access: 0 (Read), banned: 1, aborting. - "test_dtypes__refs_nn_functional_pdist_xpu", - # 2025 bundle std::pow complex result is different on host and device "test_python_ref__refs_square_xpu_complex64", "test_python_ref_torch_fallback__refs_square_xpu_complex64", From 41b282fb12a565f2f7b630b84273bc6d5bb5fc55 Mon Sep 17 00:00:00 2001 From: gaopengff Date: Wed, 27 Nov 2024 15:24:12 +0800 Subject: [PATCH 3/8] Add quantized_maxpool_2d for xpu (#1049) Now we only support datatype of uint8(Byte). Referring the stock pytorch cpu implementation at [code](https://github.com/pytorch/pytorch/blob/v2.5.0/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp#L1452). Waiting https://github.com/intel/torch-xpu-ops/pull/921 to be merged. 
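For reference, the new op is driven from Python roughly as in the test added below (test_quantized_op_xpu.py). The following is a condensed sketch, not part of the patch itself, assuming an XPU build of PyTorch with these kernels registered; it checks the quantized fast path against a float reference, as the test does:

    import torch
    from torch.nn.modules.utils import _pair

    device = torch.device("xpu")
    x = torch.randint(0, 8, (1, 3, 8, 8), dtype=torch.uint8, device=device)
    x = x.contiguous(memory_format=torch.channels_last)

    # Reference path: pool in float32, then cast back to uint8.
    ref = torch.nn.functional.max_pool2d(
        x.to(torch.float32), kernel_size=2, stride=1, padding=0, dilation=1
    ).to(torch.uint8)

    # Quantized path added by this patch; only Byte (uint8) input is accepted.
    out = torch.ops.quantized.max_pool2d(
        x, kernel_size=_pair(2), stride=_pair(1), padding=_pair(0),
        dilation=_pair(1), ceil_mode=False)

    assert torch.equal(out, ref)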
--- .../native/quantized/QuantizedMaxPool2d.cpp | 49 +++ .../quantized/sycl/QuantizedMaxPool2d.cpp | 325 ++++++++++++++++++ .../quantized/sycl/QuantizedMaxPool2d.h | 15 + .../core/test_quantized_op_xpu.py | 56 +++ test/xpu/skip_list_common.py | 8 + test/xpu/xpu_test_utils.py | 1 + yaml/native/native_functions.yaml | 5 + 7 files changed, 459 insertions(+) create mode 100644 src/ATen/native/quantized/QuantizedMaxPool2d.cpp create mode 100644 src/ATen/native/quantized/sycl/QuantizedMaxPool2d.cpp create mode 100644 src/ATen/native/quantized/sycl/QuantizedMaxPool2d.h create mode 100644 test/xpu/quantization/core/test_quantized_op_xpu.py diff --git a/src/ATen/native/quantized/QuantizedMaxPool2d.cpp b/src/ATen/native/quantized/QuantizedMaxPool2d.cpp new file mode 100644 index 000000000..0d559704d --- /dev/null +++ b/src/ATen/native/quantized/QuantizedMaxPool2d.cpp @@ -0,0 +1,49 @@ +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +Tensor quantized_max_pool2d_xpu( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode) { + return xpu::quantized_max_pool2d_kernel( + input, kernel_size, stride, padding, dilation, ceil_mode); +} + +// Keep the registry in the anonymous namespace. +namespace { +class QMaxPool_arr_args final { + public: + static Tensor run( + const Tensor& qx, + std::vector kernel_size, + std::vector stride, + std::vector padding, + std::vector dilation, + bool ceil_mode) { + // Now we only support Byte, qint is not supported. + TORCH_CHECK( + qx.scalar_type() == c10::ScalarType::Byte, + "QuantizedMaxPool2d only supports Byte for xpu now"); + return at::native::quantized_max_pool2d_xpu( + qx, kernel_size, stride, padding, dilation, ceil_mode); + } +}; +} // anonymous namespace + +TORCH_LIBRARY_IMPL(quantized, XPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("quantized::max_pool2d"), + TORCH_FN(QMaxPool_arr_args::run)); +} +} // namespace native +} // namespace at diff --git a/src/ATen/native/quantized/sycl/QuantizedMaxPool2d.cpp b/src/ATen/native/quantized/sycl/QuantizedMaxPool2d.cpp new file mode 100644 index 000000000..d6cd6324a --- /dev/null +++ b/src/ATen/native/quantized/sycl/QuantizedMaxPool2d.cpp @@ -0,0 +1,325 @@ +#pragma clang diagnostic push +#pragma GCC diagnostic push +// Avoid SYCL compiler return-type error +#pragma clang diagnostic ignored "-Wreturn-type" +#pragma GCC diagnostic ignored "-Wreturn-type" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +namespace at::native::xpu { + +namespace { +void check_maxpool2d_params( + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation) { + TORCH_CHECK( + kernel_size.size() == 1 || kernel_size.size() == 2, + "Expected 1d or 2d kernel size, got ", + kernel_size.size()); + TORCH_CHECK( + stride.empty() || stride.size() == 2, + "Expected no strides or 2d strides, got", + stride.size()); + TORCH_CHECK( + padding.size() == 1 || padding.size() == 2, + "Expected 1d or 2d padding, got ", + padding.size()); + TORCH_CHECK( + dilation.size() == 1 || dilation.size() == 2, + "Expected 1d or 2d dilation, got ", + dilation.size()); +} +} // anonymous namespace + +template +struct QuantizedMaxPool2dKernelFunctor { + void operator()(sycl::nd_item<2> item) const { + auto desc = cfg_.get_item_desc(item); + + do { + if (desc.glb_problem < cfg_.problem_) { + int idx = desc.glb_problem; + int64_t b{0}, row{0}, col{0}; + b = idx / 
stride_; + col = idx % oW_; + row = idx / oW_ % oH_; + + int64_t output_base_offset = (b * oW_ * oH_ + row * oW_ + col) * iC_; + + // Get the boundary. + int64_t h_start = row * sH_ - pH_; + int64_t w_start = col * sW_ - pW_; + int64_t h_end = std::min(h_start + (kH_ - 1) * dH_ + 1, iH_); + int64_t w_end = std::min(w_start + (kW_ - 1) * dW_ + 1, iW_); + while (h_start < 0) + h_start += dH_; + while (w_start < 0) + w_start += dW_; + + // Stock pytorch's cpu implementation use vectorized instructions + // through channels such as AVX-512. We use for-loop directly. + int64_t w, h, c; +#pragma unroll + for (c = 0; c < iC_; c++) { + scalar_t maxVal = at::numeric_limits::lower_bound(); +#pragma unroll + for (h = h_start; h < h_end; h += dH_) { +#pragma unroll + for (w = w_start; w < w_end; w += dW_) { + int64_t input_base_offset = (b * iW_ * iH_ + h * iW_ + w) * iC_; + scalar_t val = input_[input_base_offset + c]; + if ((static_cast(val) > maxVal) || at::_isnan(val)) { + maxVal = static_cast(val); + } + } + } + output_[output_base_offset + c] = static_cast(maxVal); + } + } + } while (cfg_.next(item, desc)); + } + + QuantizedMaxPool2dKernelFunctor( + scalar_t* output, + scalar_t* input, + int64_t iC, + int64_t iH, + int64_t iW, + int64_t oH, + int64_t oW, + int64_t kH, + int64_t kW, + int64_t sH, + int64_t sW, + int64_t pH, + int64_t pW, + int64_t dH, + int64_t dW, + int64_t stride, + BatchKernelConfig cfg) + : output_(output), + input_(input), + iC_(iC), + iH_(iH), + iW_(iW), + oH_(oH), + oW_(oW), + kH_(kH), + kW_(kW), + sH_(sH), + sW_(sW), + pH_(pH), + pW_(pW), + dH_(dH), + dW_(dW), + stride_(stride), + cfg_(cfg) {} + + private: + scalar_t* output_; + scalar_t* input_; + int64_t iC_; // input/output channels + int64_t iH_; + int64_t iW_; // input sizes + int64_t oH_; + int64_t oW_; // output sizes + int64_t kH_; + int64_t kW_; // kernel size + int64_t sH_; + int64_t sW_; // strides + int64_t pH_; + int64_t pW_; // padding + int64_t dH_; + int64_t dW_; // dilation + int64_t stride_; + BatchKernelConfig cfg_; +}; + +template +void launch_quantized_max_pool2d_kernel( + scalar_t* output, + scalar_t* input, + int64_t nBatch, + int64_t iC, + int64_t iH, + int64_t iW, + int64_t oH, + int64_t oW, + int64_t kH, + int64_t kW, + int64_t sH, + int64_t sW, + int64_t pH, + int64_t pW, + int64_t dH, + int64_t dW) { + using KernelClass = QuantizedMaxPool2dKernelFunctor; + + auto& queue = at::xpu::getCurrentSYCLQueue(); + int outputSize = nBatch * oH * oW; + int stride = oH * oW; + BatchKernelConfig cfg = BatchKernelConfig::make_config( + 1, outputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive); + auto kfn = KernelClass( + output, + input, + iC, + iH, + iW, + oH, + oW, + kH, + kW, + sH, + sW, + pH, + pW, + dH, + dW, + stride, + cfg); + sycl_kernel_submit(cfg.global_size(), cfg.group_size(), queue, kfn); +} + +Tensor quantized_max_pool2d_kernel( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode) { + check_maxpool2d_params(kernel_size, stride, padding, dilation); + if (stride.empty()) { + stride = kernel_size; + } + Tensor output; + int ndim = input.dim(); + int64_t kH = kernel_size[0]; + int64_t kW = kernel_size[1]; + int64_t sH = stride[0]; + int64_t sW = stride[1]; + int64_t pH = padding[0]; + int64_t pW = padding[1]; + int64_t dH = dilation[0]; + int64_t dW = dilation[1]; + + // Check input dimensions. 
+ TORCH_CHECK(kH > 0 && kW > 0, "kernel_size should be greater than zero."); + TORCH_CHECK(sH > 0 && sW > 0, "strides should be greater than zero."); + TORCH_CHECK( + dH > 0 && dW > 0, + "dilation should be greater than zero. " + "Got (", + dH, + ", ", + dW, + ")"); + TORCH_CHECK( + ndim == 3 || ndim == 4, "Expecting the input tensor of rank 3 or 4."); + + int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1; + int64_t iC = input.size(-3); + int64_t iH = input.size(-2); + int64_t iW = input.size(-1); + int64_t oH = pooling_output_shape(iH, kH, pH, sH, dH, ceil_mode); + int64_t oW = pooling_output_shape(iW, kW, pW, sW, dW, ceil_mode); + int64_t oC = iC; + + TORCH_CHECK( + oH > 0 && oW > 0, + "Given input size: (", + iC, + "x", + iH, + "x", + iW, + "). Calculated output size: (", + oC, + "x", + oH, + "x", + oW, + "). Output size is too small."); + + std::vector oSizes; + if (ndim == 3) { + oSizes = {oC, oH, oW}; + } else { + oSizes = {nbatch, oC, oH, oW}; + } + + // Create an input + output = at::empty( + oSizes, + input.options() + .device(c10::kXPU) + .dtype(input.scalar_type()) + .memory_format(c10::MemoryFormat::ChannelsLast)); + + if (input.is_contiguous(c10::MemoryFormat::ChannelsLast)) { + AT_DISPATCH_INTEGRAL_TYPES( + input.scalar_type(), "quantized_max_pool2d_xpu", [&]() { + launch_quantized_max_pool2d_kernel( + output.data_ptr(), + input.data_ptr(), + nbatch, + iC, + iH, + iW, + oH, + oW, + kH, + kW, + sH, + sW, + pH, + pW, + dH, + dW); + }); + } else { + // If input is uint8 and contiguous memory format, + // Use the channels_last implementation and convert output back to + // contiguous. + auto input_nhwc = input.contiguous(c10::MemoryFormat::ChannelsLast); + AT_DISPATCH_INTEGRAL_TYPES( + input.scalar_type(), "quantized_max_pool2d_xpu", [&]() { + launch_quantized_max_pool2d_kernel( + output.data_ptr(), + input_nhwc.data_ptr(), + nbatch, + iC, + iH, + iW, + oH, + oW, + kH, + kW, + sH, + sW, + pH, + pW, + dH, + dW); + }); + output = output.contiguous(); + } + return output; +} + +} // namespace at::native::xpu + +#pragma GCC diagnostic pop +#pragma clang diagnostic pop diff --git a/src/ATen/native/quantized/sycl/QuantizedMaxPool2d.h b/src/ATen/native/quantized/sycl/QuantizedMaxPool2d.h new file mode 100644 index 000000000..d5f86e68d --- /dev/null +++ b/src/ATen/native/quantized/sycl/QuantizedMaxPool2d.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace at::native::xpu { + +TORCH_XPU_API Tensor quantized_max_pool2d_kernel( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode); + +} // namespace at::native::xpu diff --git a/test/xpu/quantization/core/test_quantized_op_xpu.py b/test/xpu/quantization/core/test_quantized_op_xpu.py new file mode 100644 index 000000000..6091f5bab --- /dev/null +++ b/test/xpu/quantization/core/test_quantized_op_xpu.py @@ -0,0 +1,56 @@ +# Owner(s): ["module: intel"] +import itertools +import torch +from torch.nn.modules.utils import _pair +from torch.testing._internal.common_utils import ( + run_tests, + TestCase, +) +from torch.testing._internal.common_device_type import instantiate_device_type_tests + +try: + from xpu_test_utils import XPUPatchForImport +except Exception as e: + import sys + import os + script_path = os.path.split(__file__)[0] + sys.path.insert(0, os.path.realpath(os.path.join(script_path, "../.."))) + from xpu_test_utils import XPUPatchForImport + +with XPUPatchForImport(False): + from test_quantized_op import TestQuantizedOps + +def 
_test_max_pool2d_pt2e(self): + kernel_list = [2, 3] + stride_list = [1, 2] + padding_list = [0, 2] + dilation_list = [1, 2] + ceil_mode_list = [False, True] + channels_last_input = [False, True] + options = itertools.product(kernel_list, stride_list, padding_list, dilation_list, ceil_mode_list, channels_last_input) + for kernel, stride, padding, dilation, ceil_mode, channels_last in options: + if padding >= (kernel // 2): + # Continue with invalid input + continue + device = torch.device('xpu:0') + input = torch.randint(0, 8, (1, 3, 8, 8), dtype=torch.uint8, device=device) + if channels_last: + input = input.contiguous(memory_format=torch.channels_last) + a_pool = torch.nn.functional.max_pool2d(input.to(torch.float32), kernel_size=kernel, + stride=stride, padding=padding, dilation=dilation, + ceil_mode=ceil_mode).to(torch.uint8) + a_hat = torch.ops.quantized.max_pool2d(input, kernel_size=_pair(kernel), + stride=_pair(stride), padding=_pair(padding), + dilation=_pair(dilation), ceil_mode=ceil_mode) + self.assertEqual(input.is_contiguous(), a_hat.is_contiguous(), + msg="ops.quantized.max_pool2d input output diff memory format") + self.assertEqual(a_pool, a_hat, + msg="ops.quantized.max_pool2d results are off") + +TestQuantizedOps.test_max_pool2d_pt2e = _test_max_pool2d_pt2e + +instantiate_device_type_tests(TestQuantizedOps, globals(), only_for="xpu", allow_xpu=True) + +if __name__ == "__main__": + TestCase._default_dtype_check_enabled = True + run_tests() diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 106c2307c..3748c1d31 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -1982,6 +1982,14 @@ "test_reduction_all_sum_layout2_xpu_float64", ), + "quantization/core/test_quantized_op_xpu.py": ( + # AssertionError: Torch not compiled with CUDA enabled + "test_qgelu_xpu", + "test_qrelu_xpu", + # AttributeError: 'TestQuantizedOpsXPU' object has no attribute 'test_qsoftmax' + "test_qsoftmax_qnnpack_xpu", + ), + "quantization/core/test_workflow_ops_xpu.py": ( # AssertionError: Not equal to tolerance rtol=1e-06, atol=1e-06 # Max absolute difference among violations: 1.731507e+10 diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 5cf1de64c..05a5b8e73 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -962,6 +962,7 @@ def __init__(self, *args): ] common_cuda.TEST_CUDA = True common_cuda.TEST_CUDNN = True + common_cuda.TEST_CUDNN_VERSION = 0 cuda.is_available = lambda: True cuda.is_bf16_supported = lambda: True diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 6b0a1221d..999dcaf28 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -3444,6 +3444,11 @@ autogen: _adaptive_avg_pool2d_backward.out tags: core +- func: quantized_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + dispatch: + XPU: quantized_max_pool2d_xpu + autogen: quantized_max_pool2d.out + - func: embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor dispatch: XPU: embedding_dense_backward_xpu From 518bea47eb7f1cf29a46a2072194816db23087b1 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Thu, 28 Nov 2024 13:29:47 +0800 Subject: [PATCH 4/8] Update the skip list (#1123) Enable the skip cases that can currently passed. 
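For context, the skip lists edited here are plain Python dictionaries keyed by test file name, each mapping to a tuple of test-name substrings to deselect. A minimal sketch of the shape involved, with entry names and comments taken from this file (the dictionary name is illustrative, and the matching logic in the test launcher is not shown):

    skip_dict = {
        "test_ops_xpu.py": (
            # XFAIL of CUDA and XPU, unexpected success in fallback
            "test_out_cholesky_inverse_xpu_float32",
            # test_dtypes infrastructure leverages CUDA supported datatypes
            "test_dtypes_histogram_xpu",
        ),
    }
    # Deleting a name from a tuple re-enables that case on XPU.

Removing entries, as this patch does, is how previously skipped cases get re-enabled.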
--- test/xpu/skip_list_common.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 3748c1d31..7c3aa7f8e 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -918,11 +918,6 @@ # Unexpected success: "test_cpu_gpu_parity_nn_ConvTranspose1d_xpu_complex32", "test_cpu_gpu_parity_nn_ConvTranspose2d_xpu_complex32", - # CPU fallback could not cover these - # CUDA xfails - # Failed: Unexpected success - "test_memory_format_nn_AdaptiveAvgPool2d_xpu_float32", - "test_memory_format_nn_AdaptiveAvgPool2d_xpu_float64", # CPU fallback fails # RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. "test_save_load_nn_GRU_eval_mode_xpu_float32", @@ -1042,10 +1037,6 @@ # AssertionError: False is not true "test_ctc_loss_cudnn_xpu", # want "xpu" in function name "test_ctc_loss_cudnn_tensor", # want "xpu" in function name - # RuntimeError: "multilabel_margin_loss_forward_out_frame" not implemented for 'Half' - "test_MultiLabelMarginLoss_no_batch_dim_mean_cuda_half", - "test_MultiLabelMarginLoss_no_batch_dim_none_cuda_half", - "test_MultiLabelMarginLoss_no_batch_dim_sum_cuda_half", ), "test_indexing_xpu.py": ( @@ -1099,9 +1090,6 @@ "test_autograd_composite_implicit_and_dispatch_registration_xpu", "test_autograd_multiple_dispatch_registrations_xpu", # AttributeError: module 'torch.xpu' has no attribute - "test_graph_save_on_cpu_cuda", - "test_checkpointing_without_reentrant_memory_savings", - "test_flops_and_mem", "test_profiler_emit_nvtx_xpu", # Double and complex datatype matmul is not supported in oneDNN "test_mv_grad_stride_0_xpu", @@ -1891,6 +1879,8 @@ "test_scaled_mm_vs_emulated_float16_xpu", "test_scaled_mm_vs_emulated_float32_xpu", "test_scaled_mm_vs_emulated_row_wise_bfloat16_xpu", + # AssertionError: Torch not compiled with CUDA enabled + "test_zero_dim_tensorwise_which_dim_zero", ), "test_maskedtensor_xpu.py": ( @@ -2351,7 +2341,6 @@ "test_grad_scaler_pass_itself_xpu", "test_pickle_gradscaler_xpu", ### Error #15 in TestTorchDeviceTypeXPU , totally 2 , AssertionError: Tensor-likes are not close! - "test_gradient_all_xpu_float32", "test_index_put_non_accumulate_deterministic_xpu", ### Error #17 in TestTorchDeviceTypeXPU , totally 2 , AssertionError: False is not true "test_sync_warning_xpu", @@ -2364,7 +2353,6 @@ "test_nondeterministic_alert_MaxPool3d_xpu", "test_nondeterministic_alert_NLLLoss_xpu", "test_nondeterministic_alert_interpolate_bilinear_xpu", - "test_nondeterministic_alert_kthvalue_xpu_float64", "test_nondeterministic_alert_put_accumulate_xpu", ### Error #24 in TestTorchDeviceTypeXPU , totally 1 , AttributeError: 'TestTorchDeviceTypeXPU' object has no attribute 'check_device_nondeterministic_alert' "test_nondeterministic_alert_AvgPool3d_xpu", @@ -3302,6 +3290,7 @@ "test_set_default_dtype_works_with_foreach_Rprop_xpu_float64", "test_set_default_dtype_works_with_foreach_SGD_xpu_float64", ), + "test_sparse_xpu.py": ( "test_bmm_deterministic_xpu_float64", # - AssertionError: Torch not compiled with CUDA enabled "test_bmm_oob_xpu", # - NotImplementedError: Could not run 'aten::bmm' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or was ... 
From 43957e2523e9530adb2019fa730b1b43b988bf5a Mon Sep 17 00:00:00 2001 From: Kanya-Mo <167922169+Kanya-Mo@users.noreply.github.com> Date: Wed, 27 Nov 2024 22:37:46 -0800 Subject: [PATCH 5/8] Add upsample_aa op series. (#1106) - [x] _upsample_bicubic2d_aa - [x] _upsample_bicubic2d_aa.out - [x] _upsample_bicubic2d_aa_backward - [x] _upsample_bicubic2d_aa_backward.grad_input - [x] _upsample_bilinear2d_aa - [x] _upsample_bilinear2d_aa.out - [x] _upsample_bilinear2d_aa_backward - [x] _upsample_bilinear2d_aa_backward.grad_input --------- Co-authored-by: Yutao Xu --- src/ATen/native/xpu/UpSample.h | 109 +++ src/ATen/native/xpu/UpSampleBicubic2d.cpp | 28 + src/ATen/native/xpu/UpSampleBilinear2d.cpp | 27 + src/ATen/native/xpu/XPUFallback.template | 2 - .../xpu/sycl/UpSampleBilinear2dKernels.cpp | 627 ++++++++++++++++++ .../xpu/sycl/UpSampleBilinear2dKernels.h | 34 + test/xpu/xpu_test_utils.py | 1 + yaml/native/native_functions.yaml | 40 ++ 8 files changed, 866 insertions(+), 2 deletions(-) diff --git a/src/ATen/native/xpu/UpSample.h b/src/ATen/native/xpu/UpSample.h index 447eacff2..ef9696f41 100644 --- a/src/ATen/native/xpu/UpSample.h +++ b/src/ATen/native/xpu/UpSample.h @@ -316,4 +316,113 @@ static void upsample_increment_value_bounded( return {nbatch, channels, output_width}; } +namespace upsample_antialias { + +// taken from +// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/ +// src/libImaging/Resample.c#L20-L29 +struct BilinearFilterFunctor { + template + accscalar_t operator()(accscalar_t x) const { + if (x < 0) { + x = -x; + } + if (x < 1) { + return 1 - x; + } + return 0; + } + + static const int size = 2; +}; + +// taken from +// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/ +// src/libImaging/Resample.c#L46-L62 +struct BicubicFilterFunctor { + template + accscalar_t operator()(accscalar_t x) const { + // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm + const accscalar_t a = -0.5; + if (x < 0) { + x = -x; + } + if (x < 1) { + return ((a + 2) * x - (a + 3)) * x * x + 1; + } + if (x < 2) { + return (((x - 5) * x + 8) * x - 4) * a; + } + return 0; + } + + static const int size = 4; +}; + +template +static inline void _compute_weights_span( + const int i, + const int input_size, + const accscalar_t scale, + const accscalar_t support, + int& xmin, + int& xsize, + accscalar_t& center) { + center = scale * (i + static_cast(0.5)); + xmin = + max(static_cast(center - support + static_cast(0.5)), + static_cast(0)); + xsize = + min(static_cast(center + support + static_cast(0.5)), + input_size) - + xmin; +} + +template +static inline void _compute_weights( + scalar_t* wt_ptr, + const accscalar_t scale, + int interp_size, + const interp_filter_t& interp_filter, + accscalar_t xmin_m_center, + int xsize) { + accscalar_t invscale = (scale >= 1.0) ? 
1.0 / scale : 1.0; + accscalar_t total_w = 0.0; + int j = 0; + for (j = 0; j < xsize; j++) { + accscalar_t w = interp_filter( + (j + xmin_m_center + static_cast(0.5)) * invscale); + wt_ptr[j] = static_cast(w); + total_w += w; + } + for (j = 0; j < xsize; j++) { + if (total_w != 0.0) { + wt_ptr[j] /= total_w; + } + } + for (; j < interp_size; j++) { + wt_ptr[j] = static_cast(0.0); + } +} + +template +static inline accscalar_t interpolate_aa_single_dim( + const scalar_t* src, + const scalar_t* weights, + int size) { + scalar_t t = static_cast(*src); + scalar_t wts = static_cast(weights[0]); + accscalar_t output = t * wts; + + int j = 1; + for (; j < size; j++) { + wts = static_cast(weights[j]); + t = static_cast(*(src + j)); + output += t * wts; + } + return output; +} + +} // namespace upsample_antialias + } // namespace at::native::xpu diff --git a/src/ATen/native/xpu/UpSampleBicubic2d.cpp b/src/ATen/native/xpu/UpSampleBicubic2d.cpp index b0baf0969..7e0e4de40 100644 --- a/src/ATen/native/xpu/UpSampleBicubic2d.cpp +++ b/src/ATen/native/xpu/UpSampleBicubic2d.cpp @@ -2,10 +2,13 @@ #include #include #include +#include #include #include #include +#include +#include namespace at { namespace native { TORCH_IMPL_FUNC(upsample_bicubic2d_out_xpu) @@ -37,5 +40,30 @@ TORCH_IMPL_FUNC(upsample_bicubic2d_backward_out_xpu) scales_h, scales_w); } + +TORCH_IMPL_FUNC(_upsample_bicubic2d_aa_out_xpu) +(const Tensor& input, + IntArrayRef output_size, + bool align_corners, + std::optional scales_h, + std::optional scales_w, + const Tensor& output) { + xpu::_upsample_bicubic2d_aa_out_kernel( + output, input, output_size, align_corners, scales_h, scales_w); +} + +TORCH_IMPL_FUNC(_upsample_bicubic2d_aa_backward_out_xpu) +(const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + std::optional scales_h, + std::optional scales_w, + const Tensor& grad_input) { + // Nondeterministic because of atomicAdd usage + globalContext().alertNotDeterministic("upsample_bicubic2d_aa_backward_out_xpu"); + xpu::_upsample_bicubic2d_aa_backward_out_kernel( + grad_input, grad_output, output_size, input_size, align_corners, scales_h, scales_w); +} } // namespace native } // namespace at diff --git a/src/ATen/native/xpu/UpSampleBilinear2d.cpp b/src/ATen/native/xpu/UpSampleBilinear2d.cpp index 67fed551c..ee8c37ac0 100644 --- a/src/ATen/native/xpu/UpSampleBilinear2d.cpp +++ b/src/ATen/native/xpu/UpSampleBilinear2d.cpp @@ -6,6 +6,8 @@ #include #include +#include +#include namespace at { namespace native { @@ -38,5 +40,30 @@ TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_xpu) scales_w); } +TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_out_xpu) +(const Tensor& input, + IntArrayRef output_size, + bool align_corners, + std::optional scales_h, + std::optional scales_w, + const Tensor& output) { + xpu::_upsample_bilinear2d_aa_out_kernel( + output, input, output_size, align_corners, scales_h, scales_w); +} + +TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_backward_out_xpu) +(const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + std::optional scales_h, + std::optional scales_w, + const Tensor& grad_input) { + // Nondeterministic because of atomicAdd usage + globalContext().alertNotDeterministic("upsample_bilinear2d_aa_backward_out_xpu"); + xpu::_upsample_bilinear2d_aa_backward_out_kernel( + grad_input, grad_output, output_size, input_size, align_corners, scales_h, scales_w); +} + } // namespace native } // namespace at diff --git 
a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 10e16e2dc..8492a98be 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -189,10 +189,8 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "_thnn_fused_gru_cell", "_to_sparse_csr", "triangular_solve.X", - "_upsample_bilinear2d_aa.out", "_validate_compressed_sparse_indices", "vdot", - "_upsample_bicubic2d_aa.out", }; for (auto& op_name : fallback_list) { m.impl( diff --git a/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp index e5a717495..cd52a2a4e 100644 --- a/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp +++ b/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -740,6 +741,632 @@ void upsample_bilinear2d_backward_out_kernel( }); } +template +struct UpsampleGen2dAaKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<3> item) const { + const int output_x = item.get_global_id(2); + const int output_y = item.get_global_id(1); + + const int interp_height = (int)ceilf(support_h_) * 2 + 1; + const int interp_width = (int)ceilf(support_w_) * 2 + 1; + + auto ptr = + (scalar_t*)shared_.template get_multi_ptr() + .get(); + scalar_t* wx = ptr + interp_width * item.get_local_id(2); + scalar_t* wy = ptr + interp_width * item.get_local_range(2) + + interp_height * item.get_local_id(1); + const int offset = interp_width * item.get_local_range(2) + + interp_height * item.get_local_range(1); + scalar_t* buffer2 = ptr + offset + + interp_height * + (item.get_local_id(2) + + item.get_local_id(1) * item.get_local_range(2)); + + int xmin, xsize, ymin, ysize; + accscalar_t xcenter, ycenter; + + if (output_x < output_width_ && output_y < output_height_) { + upsample_antialias::_compute_weights_span( + output_x, + input_width_, + width_scale_, + support_w_, + xmin, + xsize, + xcenter); + upsample_antialias::_compute_weights_span( + output_y, + input_height_, + height_scale_, + support_h_, + ymin, + ysize, + ycenter); + + if (item.get_local_id(1) == 0) { + // All threadIdx.y have the same wx weights + upsample_antialias::_compute_weights( + wx, + width_scale_, + interp_width, + interp_filter_, + xmin - xcenter, + xsize); + } + + if (item.get_local_id(2) == 0) { + // All threadIdx.x have the same wy weights + upsample_antialias::_compute_weights( + wy, + height_scale_, + interp_height, + interp_filter_, + ymin - ycenter, + ysize); + } + } + + item.barrier(sycl_local_fence); + + if (output_x < output_width_ && output_y < output_height_) { + const scalar_t* buffer1; + auto odata = odata_; + + // Parallelized across batch/channels + for (int i = item.get_group(0); i < batchsize_ * channels_; + i += item.get_global_range(0)) { + int n = i / channels_; + int c = i % channels_; + // interpolate on y-axis for ymin to ymin + ysize + for (int y = 0; y < ysize; y++) { + buffer1 = &(idata_[n][c][ymin + y][xmin]); + buffer2[y] = static_cast( + upsample_antialias:: + interpolate_aa_single_dim( + buffer1, wx, xsize)); + } + odata[n][c][output_y][output_x] = static_cast( + upsample_antialias:: + interpolate_aa_single_dim( + buffer2, wy, ysize)); + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + shared_ = sycl_local_acc_t(local_size_, cgh); + } + + UpsampleGen2dAaKernelFunctor( + const accscalar_t height_scale, + const accscalar_t width_scale, + const PackedTensorAccessor idata, + 
PackedTensorAccessor odata, + InterpFilter interp_filter, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t batchsize, + int64_t channels, + const accscalar_t support_h, + const accscalar_t support_w, + int64_t local_size) + : height_scale_(height_scale), + width_scale_(width_scale), + idata_(idata), + odata_(odata), + interp_filter_(interp_filter), + input_height_(input_height), + input_width_(input_width), + output_height_(output_height), + output_width_(output_width), + batchsize_(batchsize), + channels_(channels), + support_h_(support_h), + support_w_(support_w), + local_size_(local_size) {} + + private: + const accscalar_t height_scale_; + const accscalar_t width_scale_; + const PackedTensorAccessor idata_; + PackedTensorAccessor odata_; + InterpFilter interp_filter_; + int64_t input_height_; + int64_t input_width_; + int64_t output_height_; + int64_t output_width_; + int64_t batchsize_; + int64_t channels_; + const accscalar_t support_h_; + const accscalar_t support_w_; + int64_t local_size_; + sycl_local_acc_t shared_; +}; + +template +struct UpsampleGen2dAaBackwardKernelFunctor + : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<3> item) const { + const int output_x = item.get_global_id(2); + const int output_y = item.get_global_id(1); + + const int interp_height = (int)ceilf(support_h_) * 2 + 1; + const int interp_width = (int)ceilf(support_w_) * 2 + 1; + + auto ptr = + (scalar_t*)shared_.template get_multi_ptr() + .get(); + scalar_t* wx = ptr + interp_width * item.get_local_id(2); + scalar_t* wy = ptr + interp_width * item.get_local_range(2) + + interp_height * item.get_local_id(1); + + int xmin, xsize, ymin, ysize; + accscalar_t xcenter, ycenter; + if (output_x < output_width_ && output_y < output_height_) { + upsample_antialias::_compute_weights_span( + output_x, + input_width_, + width_scale_, + support_w_, + xmin, + xsize, + xcenter); + upsample_antialias::_compute_weights_span( + output_y, + input_height_, + height_scale_, + support_h_, + ymin, + ysize, + ycenter); + + if (item.get_local_id(1) == 0) { + // All threadIdx.y have the same wx weights + upsample_antialias::_compute_weights( + wx, + width_scale_, + interp_width, + interp_filter_, + xmin - xcenter, + xsize); + } + + if (item.get_local_id(2) == 0) { + // All threadIdx.x have the same wy weights + upsample_antialias::_compute_weights( + wy, + height_scale_, + interp_height, + interp_filter_, + ymin - ycenter, + ysize); + } + } + + item.barrier(sycl_local_fence); + + if (output_x < output_width_ && output_y < output_height_) { + // Parallelized across batch/channels + auto idata = idata_; + for (int i = item.get_group(0); i < batchsize_ * channels_; + i += item.get_global_range(0)) { + int n = i / channels_; + int c = i % channels_; + scalar_t out_value = odata_[n][c][output_y][output_x]; + for (int y = 0; y < ysize; y++) { + for (int x = 0; x < xsize; x++) { + upsample_increment_value_bounded( + idata, + n, + c, + input_height_, + input_width_, + ymin + y, + xmin + x, + wx[x] * wy[y] * out_value); + } + } + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + shared_ = sycl_local_acc_t(local_size_, cgh); + } + + UpsampleGen2dAaBackwardKernelFunctor( + const accscalar_t height_scale, + const accscalar_t width_scale, + PackedTensorAccessor idata, + const PackedTensorAccessor odata, + InterpFilter interp_filter, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t 
batchsize, + int64_t channels, + const accscalar_t support_h, + const accscalar_t support_w, + int64_t local_size) + : height_scale_(height_scale), + width_scale_(width_scale), + idata_(idata), + odata_(odata), + interp_filter_(interp_filter), + input_height_(input_height), + input_width_(input_width), + output_height_(output_height), + output_width_(output_width), + batchsize_(batchsize), + channels_(channels), + support_h_(support_h), + support_w_(support_w), + local_size_(local_size) {} + + private: + const accscalar_t height_scale_; + const accscalar_t width_scale_; + PackedTensorAccessor idata_; + const PackedTensorAccessor odata_; + InterpFilter interp_filter_; + int64_t input_height_; + int64_t input_width_; + int64_t output_height_; + int64_t output_width_; + int64_t batchsize_; + int64_t channels_; + const accscalar_t support_h_; + const accscalar_t support_w_; + int64_t local_size_; + sycl_local_acc_t shared_; +}; + +template +void launch_upsample_gen2d_aa_kernel( + const accscalar_t height_scale, + const accscalar_t width_scale, + const PackedTensorAccessor idata, + PackedTensorAccessor odata, + InterpFilter interp_filter, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels, + const accscalar_t support_h, + const accscalar_t support_w) { + auto queue = getCurrentSYCLQueue(); + + const int interp_height = (int)ceilf(support_h) * 2 + 1; + const int interp_width = (int)ceilf(support_w) * 2 + 1; + + auto sharedMemPerBlock = syclLocalMemSize(); + auto total_threads = syclMaxWorkItemsPerTile(); + int maxThreadsPerBlock = std::min( + syclMaxWorkGroupSize< + UpsampleGen2dAaKernelFunctor>(), + 256); // 256 performs better + int block_x = syclMaxSubGroupSize(); + + int numer = + sharedMemPerBlock * 1.0 / sizeof(scalar_t) - interp_width * block_x; + int denom = interp_height * (block_x + 1); + int block_y = lastPow2((unsigned int)(numer / denom)); + block_y = std::min(maxThreadsPerBlock / block_x, block_y); + + int grid_x = std::min( + total_threads, (output_width + block_x - 1) / block_x * block_x); + int grid_y = std::min( + total_threads / grid_x, + (output_height + block_y - 1) / block_y * block_y); + int grid_z = + std::min(total_threads / grid_x / grid_y, nbatch * channels); + + int64_t weights_per_block = interp_width * block_x + interp_height * block_y; + weights_per_block += interp_height * block_y * block_x; + int64_t shmem_size = weights_per_block * sizeof(scalar_t); + TORCH_CHECK( + shmem_size <= sharedMemPerBlock, + "Provided interpolation parameters can not be handled with current algorithm implementation. ", + "Please reduce the scale factor. 
Too much shared memory required: ", + shmem_size, + " vs ", + sharedMemPerBlock); + + UpsampleGen2dAaKernelFunctor kfn( + height_scale, + width_scale, + idata, + odata, + interp_filter, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels, + support_h, + support_w, + weights_per_block); + + sycl_kernel_submit( + sycl::range<3>(grid_z, grid_y, grid_x), + sycl::range<3>(1, block_y, block_x), + queue, + kfn); +} + +template +void launch_upsample_gen2d_aa_backward_kernel( + const accscalar_t height_scale, + const accscalar_t width_scale, + PackedTensorAccessor idata, + const PackedTensorAccessor odata, + InterpFilter interp_filter, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels, + const accscalar_t support_h, + const accscalar_t support_w) { + auto queue = getCurrentSYCLQueue(); + + auto sharedMemPerBlock = syclLocalMemSize(); + auto total_threads = syclMaxWorkItemsPerTile(); + int maxThreadsPerBlock = std::min( + syclMaxWorkGroupSize< + UpsampleGen2dAaKernelFunctor>(), + 256); // 256 performs better + int block_x = syclMaxSubGroupSize(); + int block_y = maxThreadsPerBlock / block_x; + + int grid_x = std::min( + total_threads, (output_width + block_x - 1) / block_x * block_x); + int grid_y = std::min( + total_threads / grid_x, + (output_height + block_y - 1) / block_y * block_y); + int grid_z = + std::min(total_threads / grid_x / grid_y, nbatch * channels); + + const int interp_height = (int)ceilf(support_h) * 2 + 1; + const int interp_width = (int)ceilf(support_w) * 2 + 1; + + int64_t weights_per_block = interp_width * block_x + interp_height * block_y; + int64_t shmem_size = weights_per_block * sizeof(scalar_t); + TORCH_CHECK( + shmem_size <= sharedMemPerBlock, + "Provided interpolation parameters can not be handled with current algorithm implementation. ", + "Please reduce the scale factor. Too much shared memory required: ", + shmem_size, + " vs ", + sharedMemPerBlock); + + UpsampleGen2dAaBackwardKernelFunctor kfn( + height_scale, + width_scale, + idata, + odata, + interp_filter, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels, + support_h, + support_w, + weights_per_block); + + sycl_kernel_submit( + sycl::range<3>(grid_z, grid_y, grid_x), + sycl::range<3>(1, block_y, block_x), + queue, + kfn); +} + +template +void upsample_gen2d_aa_out_kernel( + const Tensor& output, + const Tensor& input_, + IntArrayRef output_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + TensorArg input_arg{input_, "input_", 1}, output_arg{output, "output", 2}; + checkAllSameGPU(__func__, {input_arg, output_arg}); + + // TODO: remove this when the kernel is updated to support the channels_last + // memory format. + auto output_c = output.is_contiguous() + ? 
output + : at::empty(output.sizes(), output.options()); + auto input = input_.contiguous(); + int output_height = output_size[0]; + int output_width = output_size[1]; + int input_height = input.size(2); + int input_width = input.size(3); + int nbatch = input.size(0); + int channels = input.size(1); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "upsample_bilinear2d_xpu", + [&] { + using accscalar_t = acc_type_device; + auto idata = input.packed_accessor64(); + auto odata = output_c.packed_accessor64(); + + const accscalar_t height_scale = area_pixel_compute_scale( + input_height, output_height, align_corners, scales_h); + const accscalar_t width_scale = area_pixel_compute_scale( + input_width, output_width, align_corners, scales_w); + + auto interp_filter = InterpFilter(); + const accscalar_t support_h = static_cast( + (height_scale >= 1.0) ? (interp_filter.size * 0.5) * height_scale + : interp_filter.size * 0.5); + const accscalar_t support_w = static_cast( + (width_scale >= 1.0) ? (interp_filter.size * 0.5) * width_scale + : interp_filter.size * 0.5); + launch_upsample_gen2d_aa_kernel( + height_scale, + width_scale, + idata, + odata, + interp_filter, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels, + support_h, + support_w); + }); + + if (!output.is_contiguous()) { + output.copy_(output_c); + } +} + +template +void upsample_gen2d_aa_backward_out_kernel( + const Tensor& grad_input, + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + TensorArg grad_input_arg{grad_input, "grad_input", 1}, + grad_output_arg{grad_output_, "grad_output_", 2}; + checkAllSameGPU( + "upsample_gen2d_backward_out_cuda", {grad_output_arg, grad_input_arg}); + + int output_height = output_size[0]; + int output_width = output_size[1]; + int input_height = input_size[2]; + int input_width = input_size[3]; + int nbatch = input_size[0]; + int channels = input_size[1]; + + Tensor grad_output = grad_output_.contiguous(); + grad_input.zero_(); + + if (grad_output.sizes() == grad_input.sizes()) { + grad_input.copy_(grad_output_); + return; + } + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + grad_output.scalar_type(), + "upsample_bilinear2d_xpu", + [&] { + using accscalar_t = acc_type_device; + auto idata = grad_input.packed_accessor64(); + auto odata = grad_output.packed_accessor64(); + + const accscalar_t height_scale = area_pixel_compute_scale( + input_height, output_height, align_corners, scales_h); + const accscalar_t width_scale = area_pixel_compute_scale( + input_width, output_width, align_corners, scales_w); + + auto interp_filter = InterpFilter(); + const accscalar_t support_h = static_cast( + (height_scale >= 1.0) ? (interp_filter.size * 0.5) * height_scale + : interp_filter.size * 0.5); + const accscalar_t support_w = static_cast( + (width_scale >= 1.0) ? 
(interp_filter.size * 0.5) * width_scale + : interp_filter.size * 0.5); + launch_upsample_gen2d_aa_backward_kernel( + height_scale, + width_scale, + idata, + odata, + interp_filter, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels, + support_h, + support_w); + }); +} + +void _upsample_bilinear2d_aa_out_kernel( + const Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + return upsample_gen2d_aa_out_kernel< + upsample_antialias::BilinearFilterFunctor>( + output, input, output_size, align_corners, scales_h, scales_w); +} + +void _upsample_bilinear2d_aa_backward_out_kernel( + const Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + return upsample_gen2d_aa_backward_out_kernel< + upsample_antialias::BilinearFilterFunctor>( + grad_input, + grad_output, + output_size, + input_size, + align_corners, + scales_h, + scales_w); +} + +void _upsample_bicubic2d_aa_out_kernel( + const Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + return upsample_gen2d_aa_out_kernel( + output, input, output_size, align_corners, scales_h, scales_w); +} + +void _upsample_bicubic2d_aa_backward_out_kernel( + const Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + return upsample_gen2d_aa_backward_out_kernel< + upsample_antialias::BicubicFilterFunctor>( + grad_input, + grad_output, + output_size, + input_size, + align_corners, + scales_h, + scales_w); +} + } // namespace at::native::xpu #pragma GCC diagnostic pop diff --git a/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.h b/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.h index aa5ee2c09..d7ae0dcf1 100644 --- a/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.h +++ b/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.h @@ -21,4 +21,38 @@ TORCH_XPU_API void upsample_bilinear2d_backward_out_kernel( c10::optional scales_h, c10::optional scales_w); +TORCH_XPU_API void _upsample_bilinear2d_aa_out_kernel( + const Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w); + +TORCH_XPU_API void _upsample_bilinear2d_aa_backward_out_kernel( + const Tensor& grad_input, + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w); + +TORCH_XPU_API void _upsample_bicubic2d_aa_out_kernel( + const Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w); + +TORCH_XPU_API void _upsample_bicubic2d_aa_backward_out_kernel( + const Tensor& grad_input, + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w); + } // namespace at::native::xpu diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 05a5b8e73..6c31415cc 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -209,6 +209,7 @@ "nn.functional.pad", "nn.functional.interpolate", "nn.functional.upsample_bilinear", + "_upsample_bilinear2d_aa", 
"nn.functional.upsample_nearest", "nn.functional.nll_loss", "nn.functional.smooth_l1_loss", diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 999dcaf28..e3bec5484 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -4768,6 +4768,26 @@ python_module: nn structured_delegate: upsample_bicubic2d_backward.grad_input +- func: _upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + XPU: _upsample_bicubic2d_aa_out_xpu + +- func: _upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_bicubic2d_aa.out + +- func: _upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + XPU: _upsample_bicubic2d_aa_backward_out_xpu + +- func: _upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_bicubic2d_aa_backward.grad_input + - func: upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn structured: True @@ -4788,6 +4808,26 @@ python_module: nn structured_delegate: upsample_bilinear2d_backward.grad_input +- func: _upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + XPU: _upsample_bilinear2d_aa_out_xpu + +- func: _upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_bilinear2d_aa.out + +- func: _upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + XPU: _upsample_bilinear2d_aa_backward_out_xpu + +- func: _upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? 
scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_bilinear2d_aa_backward.grad_input + - func: native_norm(Tensor self, Scalar p=2) -> Tensor dispatch: SparseXPU: norm_sparse From 98f47b621e3d8757ea6c03c4dacbad491dc84014 Mon Sep 17 00:00:00 2001 From: mengfei25 Date: Mon, 2 Dec 2024 15:57:45 +0800 Subject: [PATCH 6/8] Nightly wheel test parser env versions (#1132) --- .github/workflows/nightly_ondemand_whl.yml | 39 +++++++++++++--------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/.github/workflows/nightly_ondemand_whl.yml b/.github/workflows/nightly_ondemand_whl.yml index 6b8d0b58f..6bdd0a612 100644 --- a/.github/workflows/nightly_ondemand_whl.yml +++ b/.github/workflows/nightly_ondemand_whl.yml @@ -75,19 +75,20 @@ jobs: ut: ${{ github.event_name == 'schedule' && 'op_regression,op_regression_dev1,op_extended,op_ut,torch_xpu' || inputs.ut }} python: ${{ github.event_name == 'schedule' && '3.10' || inputs.python }} outputs: - TORCH_BRANCH_ID: ${{ steps.pinned.outputs.TORCH_BRANCH_ID }} - TORCH_COMMIT_ID: ${{ steps.pinned.outputs.TORCH_COMMIT_ID }} - DRIVER_VERSION: ${{ steps.pinned.outputs.DRIVER_VERSION }} - KERNEL_VERSION: ${{ steps.pinned.outputs.KERNEL_VERSION }} - BUNDLE_VERSION: ${{ steps.pinned.outputs.BUNDLE_VERSION }} - OS_PRETTY_NAME: ${{ steps.pinned.outputs.OS_PRETTY_NAME }} - GCC_VERSION: ${{ steps.pinned.outputs.GCC_VERSION }} + TORCH_BRANCH_ID: ${{ steps.installed.outputs.TORCH_BRANCH_ID }} + TORCH_COMMIT_ID: ${{ steps.installed.outputs.TORCH_COMMIT_ID }} + TORCH_XPU_OPS_COMMIT: ${{ steps.installed.outputs.TORCH_XPU_OPS_COMMIT }} TORCHBENCH_COMMIT_ID: ${{ steps.pinned.outputs.TORCHBENCH_COMMIT_ID }} TORCHVISION_COMMIT_ID: ${{ steps.pinned.outputs.TORCHVISION_COMMIT_ID }} TORCHAUDIO_COMMIT_ID: ${{ steps.pinned.outputs.TORCHAUDIO_COMMIT_ID }} TRANSFORMERS_VERSION: ${{ steps.pinned.outputs.TRANSFORMERS_VERSION }} TIMM_COMMIT_ID: ${{ steps.pinned.outputs.TIMM_COMMIT_ID }} TRITON_COMMIT_ID: ${{ steps.pinned.outputs.TRITON_COMMIT_ID }} + DRIVER_VERSION: ${{ steps.pinned.outputs.DRIVER_VERSION }} + KERNEL_VERSION: ${{ steps.pinned.outputs.KERNEL_VERSION }} + BUNDLE_VERSION: ${{ steps.pinned.outputs.BUNDLE_VERSION }} + OS_PRETTY_NAME: ${{ steps.pinned.outputs.OS_PRETTY_NAME }} + GCC_VERSION: ${{ steps.pinned.outputs.GCC_VERSION }} TIMEOUT_MODELS: ${{ steps.summary.outputs.TIMEOUT_MODELS }} steps: - name: Checkout torch-xpu-ops @@ -101,32 +102,37 @@ jobs: pip install mkl-static==2025.0.1 mkl-include==2025.0.1 pip install pandas scipy tqdm - name: Prepare Stock Pytorch + id: installed run: | pwd source activate e2e_ci + pip install torch torchvision torchaudio --pre --index-url https://download.pytorch.org/whl/nightly/xpu + echo "TORCH_BRANCH_ID=$(python -c 'import torch; print(torch.__version__)')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}" + echo "TORCH_COMMIT_ID=$(python -c 'import torch; print(torch.version.git_version)')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}" source .github/scripts/env.sh cd ../ && rm -rf pytorch git clone https://github.com/pytorch/pytorch pytorch - cd pytorch && git checkout $(echo ${{ env.pytorch }} |sed 's/^nightly_wheel$/nightly/') + cd pytorch && git checkout ${TORCH_COMMIT_ID} # apply PRs for stock pytorch pip install requests python ../torch-xpu-ops/.github/scripts/apply_torch_pr.py git status && git show -s pip install -r requirements.txt - cd ../ - pip install torch torchvision torchaudio --pre --index-url https://download.pytorch.org/whl/nightly/xpu + echo "TORCH_XPU_OPS_COMMIT=$(> "${GITHUB_ENV}" 
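A hedged usage sketch (not part of the patch): with the _upsample_bilinear2d_aa / _upsample_bicubic2d_aa kernels registered for XPU above, antialiased interpolation should dispatch to the new kernels through the regular Python API. Assumes an XPU-enabled PyTorch build; the tolerance is illustrative.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 64, 64)
ref = F.interpolate(x, size=(16, 16), mode="bilinear", antialias=True)  # CPU reference
if torch.xpu.is_available():
    out = F.interpolate(x.to("xpu"), size=(16, 16), mode="bilinear", antialias=True)
    print(torch.allclose(ref, out.cpu(), atol=1e-4, rtol=1e-4))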
From 98f47b621e3d8757ea6c03c4dacbad491dc84014 Mon Sep 17 00:00:00 2001
From: mengfei25
Date: Mon, 2 Dec 2024 15:57:45 +0800
Subject: [PATCH 6/8] Nightly wheel test parser env versions (#1132)

---
 .github/workflows/nightly_ondemand_whl.yml | 39 +++++++++++++---------
 1 file changed, 23 insertions(+), 16 deletions(-)

diff --git a/.github/workflows/nightly_ondemand_whl.yml b/.github/workflows/nightly_ondemand_whl.yml
index 6b8d0b58f..6bdd0a612 100644
--- a/.github/workflows/nightly_ondemand_whl.yml
+++ b/.github/workflows/nightly_ondemand_whl.yml
@@ -75,19 +75,20 @@ jobs:
       ut: ${{ github.event_name == 'schedule' && 'op_regression,op_regression_dev1,op_extended,op_ut,torch_xpu' || inputs.ut }}
       python: ${{ github.event_name == 'schedule' && '3.10' || inputs.python }}
     outputs:
-      TORCH_BRANCH_ID: ${{ steps.pinned.outputs.TORCH_BRANCH_ID }}
-      TORCH_COMMIT_ID: ${{ steps.pinned.outputs.TORCH_COMMIT_ID }}
-      DRIVER_VERSION: ${{ steps.pinned.outputs.DRIVER_VERSION }}
-      KERNEL_VERSION: ${{ steps.pinned.outputs.KERNEL_VERSION }}
-      BUNDLE_VERSION: ${{ steps.pinned.outputs.BUNDLE_VERSION }}
-      OS_PRETTY_NAME: ${{ steps.pinned.outputs.OS_PRETTY_NAME }}
-      GCC_VERSION: ${{ steps.pinned.outputs.GCC_VERSION }}
+      TORCH_BRANCH_ID: ${{ steps.installed.outputs.TORCH_BRANCH_ID }}
+      TORCH_COMMIT_ID: ${{ steps.installed.outputs.TORCH_COMMIT_ID }}
+      TORCH_XPU_OPS_COMMIT: ${{ steps.installed.outputs.TORCH_XPU_OPS_COMMIT }}
       TORCHBENCH_COMMIT_ID: ${{ steps.pinned.outputs.TORCHBENCH_COMMIT_ID }}
      TORCHVISION_COMMIT_ID: ${{ steps.pinned.outputs.TORCHVISION_COMMIT_ID }}
      TORCHAUDIO_COMMIT_ID: ${{ steps.pinned.outputs.TORCHAUDIO_COMMIT_ID }}
      TRANSFORMERS_VERSION: ${{ steps.pinned.outputs.TRANSFORMERS_VERSION }}
      TIMM_COMMIT_ID: ${{ steps.pinned.outputs.TIMM_COMMIT_ID }}
      TRITON_COMMIT_ID: ${{ steps.pinned.outputs.TRITON_COMMIT_ID }}
+      DRIVER_VERSION: ${{ steps.pinned.outputs.DRIVER_VERSION }}
+      KERNEL_VERSION: ${{ steps.pinned.outputs.KERNEL_VERSION }}
+      BUNDLE_VERSION: ${{ steps.pinned.outputs.BUNDLE_VERSION }}
+      OS_PRETTY_NAME: ${{ steps.pinned.outputs.OS_PRETTY_NAME }}
+      GCC_VERSION: ${{ steps.pinned.outputs.GCC_VERSION }}
       TIMEOUT_MODELS: ${{ steps.summary.outputs.TIMEOUT_MODELS }}
     steps:
     - name: Checkout torch-xpu-ops
@@ -101,32 +102,37 @@
         pip install mkl-static==2025.0.1 mkl-include==2025.0.1
         pip install pandas scipy tqdm
     - name: Prepare Stock Pytorch
+      id: installed
       run: |
        pwd
        source activate e2e_ci
+        pip install torch torchvision torchaudio --pre --index-url https://download.pytorch.org/whl/nightly/xpu
+        echo "TORCH_BRANCH_ID=$(python -c 'import torch; print(torch.__version__)')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
+        echo "TORCH_COMMIT_ID=$(python -c 'import torch; print(torch.version.git_version)')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
        source .github/scripts/env.sh
        cd ../ && rm -rf pytorch
        git clone https://github.com/pytorch/pytorch pytorch
-        cd pytorch && git checkout $(echo ${{ env.pytorch }} |sed 's/^nightly_wheel$/nightly/')
+        cd pytorch && git checkout ${TORCH_COMMIT_ID}
        # apply PRs for stock pytorch
        pip install requests
        python ../torch-xpu-ops/.github/scripts/apply_torch_pr.py
        git status && git show -s
        pip install -r requirements.txt
-        cd ../
-        pip install torch torchvision torchaudio --pre --index-url https://download.pytorch.org/whl/nightly/xpu
+        echo "TORCH_XPU_OPS_COMMIT=$(> "${GITHUB_ENV}"
+        rm -rf third_party/torch-xpu-ops
+        git clone https://github.com/intel/torch-xpu-ops.git third_party/torch-xpu-ops
+        cd third_party/torch-xpu-ops
+        git checkout ${TORCH_XPU_OPS_COMMIT}
     - name: Identify pinned versions
       id: pinned
       run: |
        source activate e2e_ci
        source .github/scripts/env.sh
+        echo "TORCHVISION_COMMIT_ID=$(python -c 'import torchvision; print(torchvision.version.git_version)')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
+        echo "TORCHAUDIO_COMMIT_ID=$(python -c 'import torchaudio; print(torchaudio.version.git_version)')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
+        echo "TRITON_COMMIT_ID=$(python -c 'import triton; print(triton.__version__)')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
        cd ../pytorch
-        echo "TRITON_COMMIT_ID=$(pip list |grep -w pytorch-triton-xpu |awk '{print $2}')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
-        echo "TORCH_BRANCH_ID=nightly" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
-        echo "TORCH_COMMIT_ID=$(pip list |grep -w torch |awk '{print $2}')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
        echo "TORCHBENCH_COMMIT_ID=$(> "${GITHUB_ENV}"
-        echo "TORCHVISION_COMMIT_ID=$(pip list |grep -w torchvision |awk '{print $2}')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
-        echo "TORCHAUDIO_COMMIT_ID=$(pip list |grep -w torchaudio |awk '{print $2}')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
        echo "TRANSFORMERS_VERSION=$(<.ci/docker/ci_commit_pins/huggingface.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
        echo "TIMM_COMMIT_ID=$(<.ci/docker/ci_commit_pins/timm.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
        echo "MODEL_ONLY_NAME=${{ inputs.model }}" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
@@ -271,6 +277,7 @@
        repo="${{ github.repository }}"
        TORCH_BRANCH_ID="${{ needs.Linux-Nightly-Ondemand-E2E-WHL-Tests.outputs.TORCH_BRANCH_ID }}"
        TORCH_COMMIT_ID="${{ needs.Linux-Nightly-Ondemand-E2E-WHL-Tests.outputs.TORCH_COMMIT_ID }}"
+        TORCH_XPU_OPS_COMMIT="${{ needs.Linux-Nightly-Ondemand-E2E-WHL-Tests.outputs.TORCH_XPU_OPS_COMMIT }}"
        DRIVER_VERSION="${{ needs.Linux-Nightly-Ondemand-E2E-WHL-Tests.outputs.DRIVER_VERSION }}"
        KERNEL_VERSION="${{ needs.Linux-Nightly-Ondemand-E2E-WHL-Tests.outputs.KERNEL_VERSION }}"
        BUNDLE_VERSION="${{ needs.Linux-Nightly-Ondemand-E2E-WHL-Tests.outputs.BUNDLE_VERSION }}"
@@ -307,7 +314,7 @@
        fi
        # Test report
        echo -e "**${test_status}** $test_type WHL Test on $(date +'%F'), See: $build_url\n" > ${{ github.workspace }}/report.txt
-        printf "Torch-xpu-ops | PyTorch | Triton\n--- | --- | ---\n${GITHUB_WORKFLOW_SHA:0:7} on ${GITHUB_REF_NAME} | " >> ${{ github.workspace }}/report.txt
+        printf "Torch-xpu-ops | PyTorch | Triton\n--- | --- | ---\n${TORCH_XPU_OPS_COMMIT:0:7} on pinned | " >> ${{ github.workspace }}/report.txt
        printf "[${TORCH_COMMIT_ID:0:7}](https://github.com/pytorch/pytorch/commit/${TORCH_COMMIT_ID:0:7}) on $TORCH_BRANCH_ID | " >> ${{ github.workspace }}/report.txt
        echo -e "[${TRITON_COMMIT_ID:0:7}](https://github.com/intel/intel-xpu-backend-for-triton/commit/${TRITON_COMMIT_ID:0:7}) \n" >> ${{ github.workspace }}/report.txt
        printf "Transformers | Timm | Torchbench | Torchvision | Torchaudio\n--- | --- | --- | --- | ---\n" >> ${{ github.workspace }}/report.txt
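A small Python sketch (not part of the patch) of what the reworked installed/pinned steps record: version and commit metadata is read from the installed wheels themselves rather than parsed out of pip list output. Assumes the nightly XPU wheels are installed; the attribute accesses mirror the one-liners in the workflow above.

import torch
import torchvision
import torchaudio
import triton

print("TORCH_BRANCH_ID       =", torch.__version__)
print("TORCH_COMMIT_ID       =", torch.version.git_version)
print("TORCHVISION_COMMIT_ID =", torchvision.version.git_version)
print("TORCHAUDIO_COMMIT_ID  =", torchaudio.version.git_version)
print("TRITON_COMMIT_ID      =", triton.__version__)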
From 41a06fcfadb679f645c4386842b9a689f7a1f9ab Mon Sep 17 00:00:00 2001
From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com>
Date: Tue, 3 Dec 2024 13:47:43 +0800
Subject: [PATCH 7/8] Reuse the max op implemented by the reduction kernel to
 optimize the global max pooling (#1127)

Fix:
https://github.com/intel/torch-xpu-ops/issues/938
---
 src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp | 143 +++++++++++-------
 src/ATen/native/xpu/sycl/DilatedMaxPool2d.h   |   2 +-
 2 files changed, 92 insertions(+), 53 deletions(-)

diff --git a/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp b/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp
index e21c0160c..cba138a5f 100644
--- a/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp
+++ b/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp
@@ -5,9 +5,10 @@
 #pragma GCC diagnostic ignored "-Wreturn-type"
 
 #include
+#include
 #include
 #include
-#include
+#include
 
 #include
 #include
@@ -541,58 +542,96 @@ void max_pool2d_with_indices_kernel(
   const int64_t outputHeight = output.size(-2);
   const int64_t outputWidth = output.size(-1);
-
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-      kHalf, kBFloat16, input.scalar_type(), "max_pool2d_xpu", [&] {
-        switch (smf) {
-          case MemoryFormat::ChannelsLast: {
-            launch_max_pool2d_kernel(
-                output.mutable_data_ptr<scalar_t>(),
-                indices.mutable_data_ptr<int64_t>(),
-                input.const_data_ptr<scalar_t>(),
-                nbatch,
-                nInputPlane,
-                inputHeight,
-                inputWidth,
-                outputHeight,
-                outputWidth,
-                kH,
-                kW,
-                dH,
-                dW,
-                padH,
-                padW,
-                dilationH,
-                dilationW);
-            break;
-          }
-          case MemoryFormat::Contiguous: {
-            launch_max_pool2d_kernel(
-                output.mutable_data_ptr<scalar_t>(),
-                indices.mutable_data_ptr<int64_t>(),
-                input.const_data_ptr<scalar_t>(),
-                nbatch,
-                nInputPlane,
-                inputHeight,
-                inputWidth,
-                outputHeight,
-                outputWidth,
-                kH,
-                kW,
-                dH,
-                dW,
-                padH,
-                padW,
-                dilationH,
-                dilationW);
-            break;
+  if (outputHeight == 1 && outputWidth == 1 && inputHeight <= kH &&
+      inputWidth <= kW && padH == 0 && padW == 0) {
+    bool is_3d = input_.ndimension() == 3;
+    Tensor indices_, output_;
+    if (is_3d) {
+      indices_ = indices.contiguous();
+      output_ = output.contiguous();
+    } else {
+      indices_ = indices.contiguous(smf);
+      output_ = output.contiguous(smf);
+    }
+    if (!is_3d) {
+      input.resize_({nbatch, nInputPlane, 1, inputHeight * inputWidth}, smf);
+      output_.resize_(
+          {nbatch, nInputPlane, 1, outputHeight * outputWidth}, smf);
+      indices_.resize_(
+          {nbatch, nInputPlane, 1, outputHeight * outputWidth}, smf);
+      at::max_outf(input, 3, true, output_, indices_);
+    } else {
+      at::max_outf(input, 2, true, output_, indices_);
+    }
+
+    if (!is_3d) {
+      input.resize_({nbatch, nInputPlane, inputHeight, inputWidth}, smf);
+      output_.resize_({nbatch, nInputPlane, outputHeight, outputWidth}, smf);
+      indices_.resize_({nbatch, nInputPlane, outputHeight, outputWidth}, smf);
+    }
+
+    if ((is_3d && !indices.is_contiguous()) ||
+        (!is_3d && !indices.is_contiguous(smf))) {
+      indices.copy_(indices_);
+    }
+
+    if ((is_3d && !output.is_contiguous()) ||
+        (!is_3d && !output.is_contiguous(smf))) {
+      output.copy_(output_);
+    }
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        kHalf, kBFloat16, input.scalar_type(), "max_pool2d_xpu", [&] {
+          switch (smf) {
+            case MemoryFormat::ChannelsLast: {
+              launch_max_pool2d_kernel(
+                  output.mutable_data_ptr<scalar_t>(),
+                  indices.mutable_data_ptr<int64_t>(),
+                  input.const_data_ptr<scalar_t>(),
+                  nbatch,
+                  nInputPlane,
+                  inputHeight,
+                  inputWidth,
+                  outputHeight,
+                  outputWidth,
+                  kH,
+                  kW,
+                  dH,
+                  dW,
+                  padH,
+                  padW,
+                  dilationH,
+                  dilationW);
+              break;
+            }
+            case MemoryFormat::Contiguous: {
+              launch_max_pool2d_kernel(
+                  output.mutable_data_ptr<scalar_t>(),
+                  indices.mutable_data_ptr<int64_t>(),
+                  input.const_data_ptr<scalar_t>(),
+                  nbatch,
+                  nInputPlane,
+                  inputHeight,
+                  inputWidth,
+                  outputHeight,
+                  outputWidth,
+                  kH,
+                  kW,
+                  dH,
+                  dW,
+                  padH,
+                  padW,
+                  dilationH,
+                  dilationW);
+              break;
+            }
+            default:
+              TORCH_CHECK(
+                  false,
+                  "Unsupported memory format. Supports only ChannelsLast, Contiguous");
           }
-          default:
-            TORCH_CHECK(
-                false,
-                "Unsupported memory format. Supports only ChannelsLast, Contiguous");
-        }
-      });
+        });
+  }
 }
 
 void max_pool2d_with_indices_backward_kernel(
diff --git a/src/ATen/native/xpu/sycl/DilatedMaxPool2d.h b/src/ATen/native/xpu/sycl/DilatedMaxPool2d.h
index d530560e6..b07041fcb 100644
--- a/src/ATen/native/xpu/sycl/DilatedMaxPool2d.h
+++ b/src/ATen/native/xpu/sycl/DilatedMaxPool2d.h
@@ -1,6 +1,6 @@
 #pragma once
 
-#include
+#include
 
 namespace at::native::xpu {
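A Python sketch (not part of the patch) of the identity the new fast path relies on: when the pooling window covers the whole spatial plane (global max pooling, no padding), max_pool2d reduces to a plain max reduction over the flattened spatial dimensions, which is what the at::max_outf call above computes. The shapes below are arbitrary.

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 13, 17)
pooled, pooled_idx = F.max_pool2d(x, kernel_size=(13, 17), return_indices=True)
values, indices = torch.max(x.flatten(2), dim=2, keepdim=True)  # reduction max + argmax

print(torch.equal(pooled.flatten(2), values))       # True
print(torch.equal(pooled_idx.flatten(2), indices))  # True (ignoring exact ties)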
", - "This is likely due to turning off both the math kernel and the fused kernels."); - } - return static_cast(backend); -} - std::tuple native_multi_head_attention_xpu( const Tensor& query, const Tensor& key, @@ -204,8 +174,12 @@ std::tuple native_multi_head_attention_xpu( value.view({value.size(0), -1, num_head, dim_per_head}).transpose(1, 2); sdp::sdp_params kernel_params{q, k, v, mask, 0.0, false, false}; - auto backend = static_cast( - _fused_sdp_choice_xpu(q, k, v, mask, 0.0, false, {}, false)); + + sdp::SDPBackend backend = sdp::SDPBackend::math; + if (_fused_sdp_choice_stub.is_device_supported(q.device().type())) { + backend = static_cast(_fused_sdp_choice_stub( + q.device().type(), q, k, v, mask, 0.0, false, std::nullopt, false)); + } // strides from packed projection for nested tensors when seq_len is 1 will // be and will trigger a contiguous call in the kernel, so we prevent this diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 8492a98be..1df3cd072 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -184,7 +184,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "_linalg_svd.U", "lu_unpack.out", "ormqr", - "_scaled_dot_product_efficient_attention", "_scaled_mm", "_thnn_fused_gru_cell", "_to_sparse_csr", diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index e3bec5484..40b710c12 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -5969,12 +5969,6 @@ XPU: native_multi_head_attention_xpu autogen: _native_multi_head_attention.out -# This aten function is kept so that we can test the choice function from Python -- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> int - dispatch: - XPU: _fused_sdp_choice_xpu - tags: nondeterministic_seeded - - func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor structured_delegate: argmin.out device_check: NoCheck # TensorIterator