From 1e32bbc3d9a68112299e02566cf4b174b89c24c9 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Fri, 22 Nov 2024 14:16:08 +0800 Subject: [PATCH 01/25] Sync main into release/2.6 branch (#1117) Reset to https://github.com/intel/torch-xpu-ops/commit/bfdbaf444042b21f72f40f1fc22ab30a44bd0ffb --------- Co-authored-by: mengfei25 Co-authored-by: LuFengqing Co-authored-by: Ratnam Parikh <114774508+ratnampa@users.noreply.github.com> Co-authored-by: Feng Yuan --- .github/scripts/apply_torch_pr.py | 2 + .github/scripts/env.sh | 3 +- .github/workflows/_linux_ut.yml | 2 +- .github/workflows/nightly_ondemand.yml | 2 +- .../workflows/nightly_ondemand_rolling.yml | 2 +- .github/workflows/nightly_ondemand_whl.yml | 2 +- cmake/BuildFlags.cmake | 2 +- src/ATen/native/xpu/sycl/MultiTensorApply.h | 10 ++- .../native/xpu/sycl/ScatterGatherKernels.cpp | 61 +++++++++----- src/BuildOnWindows.cmake | 83 +++++++++++++++++-- test/xpu/test_binary_ufuncs_xpu.py | 2 +- 11 files changed, 130 insertions(+), 41 deletions(-) diff --git a/.github/scripts/apply_torch_pr.py b/.github/scripts/apply_torch_pr.py index 8c550bd87..b4b441263 100644 --- a/.github/scripts/apply_torch_pr.py +++ b/.github/scripts/apply_torch_pr.py @@ -13,6 +13,8 @@ "https://github.com/pytorch/pytorch/pull/126516", # Modify the tolerance level in TIMM benchmark "https://github.com/pytorch/pytorch/pull/129735", + # [XPU] Update XPU C Shim Header + "https://github.com/pytorch/pytorch/pull/141086", ] ) parser.add_argument('--extra-pr-list', '-e', nargs='+',default=[]) diff --git a/.github/scripts/env.sh b/.github/scripts/env.sh index ab7d7812d..56d8e3930 100644 --- a/.github/scripts/env.sh +++ b/.github/scripts/env.sh @@ -1,3 +1,4 @@ #!/bin/bash -source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh +source /opt/intel/oneapi/compiler/latest/env/vars.sh +source /opt/intel/oneapi/umf/latest/env/vars.sh source /opt/intel/oneapi/pti/latest/env/vars.sh diff --git a/.github/workflows/_linux_ut.yml b/.github/workflows/_linux_ut.yml index 5ccd22b05..d2f717230 100644 --- a/.github/workflows/_linux_ut.yml +++ b/.github/workflows/_linux_ut.yml @@ -97,7 +97,7 @@ jobs: run: | source activate xpu_op_${ZE_AFFINITY_MASK} source .github/scripts/env.sh - pip install mkl-static mkl-include + pip install mkl-static==2025.0.1 mkl-include==2025.0.1 cd ../pytorch if [[ ${{ inputs.abi }} == '0' ]]; then export _GLIBCXX_USE_CXX11_ABI=0 diff --git a/.github/workflows/nightly_ondemand.yml b/.github/workflows/nightly_ondemand.yml index e33e55d56..2edc06102 100644 --- a/.github/workflows/nightly_ondemand.yml +++ b/.github/workflows/nightly_ondemand.yml @@ -123,7 +123,7 @@ jobs: conda remove --all -y -n e2e_ci || rm -rf $(dirname ${CONDA_EXE})/../envs/e2e_ci conda create -n e2e_ci python=${{ env.python }} cmake ninja -y source activate e2e_ci - pip install mkl-static mkl-include + pip install mkl-static==2025.0.1 mkl-include==2025.0.1 pip install pandas scipy tqdm - name: Prepare Stock Pytorch run: | diff --git a/.github/workflows/nightly_ondemand_rolling.yml b/.github/workflows/nightly_ondemand_rolling.yml index 309fd58fc..0a27b2b50 100644 --- a/.github/workflows/nightly_ondemand_rolling.yml +++ b/.github/workflows/nightly_ondemand_rolling.yml @@ -125,7 +125,7 @@ jobs: conda remove --all -y -n e2e_ci || rm -rf $(dirname ${CONDA_EXE})/../envs/e2e_ci conda create -n e2e_ci python=${{ env.python }} cmake ninja -y source activate e2e_ci - pip install mkl-static mkl-include + pip install mkl-static==2025.0.1 mkl-include==2025.0.1 pip install pandas scipy tqdm - name: Prepare Stock Pytorch 
run: | diff --git a/.github/workflows/nightly_ondemand_whl.yml b/.github/workflows/nightly_ondemand_whl.yml index 44b63cfd4..6b8d0b58f 100644 --- a/.github/workflows/nightly_ondemand_whl.yml +++ b/.github/workflows/nightly_ondemand_whl.yml @@ -98,7 +98,7 @@ jobs: conda remove --all -y -n e2e_ci || rm -rf $(dirname ${CONDA_EXE})/../envs/e2e_ci conda create -n e2e_ci python=${{ env.python }} cmake ninja -y source activate e2e_ci - pip install mkl-static mkl-include + pip install mkl-static==2025.0.1 mkl-include==2025.0.1 pip install pandas scipy tqdm - name: Prepare Stock Pytorch run: | diff --git a/cmake/BuildFlags.cmake b/cmake/BuildFlags.cmake index a8784097a..d7328a832 100644 --- a/cmake/BuildFlags.cmake +++ b/cmake/BuildFlags.cmake @@ -122,7 +122,7 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "MSVC" set(SYCL_OFFLINE_COMPILER_CG_OPTIONS "-options '${SYCL_OFFLINE_COMPILER_CG_OPTIONS}'") if(WIN32) - set(AOT_TARGETS "ats-m150,lnl-m,mtl-u,mtl-h") + set(AOT_TARGETS "ats-m150,mtl-u,mtl-h,xe2-lpg,xe2-hpg") else() set(AOT_TARGETS "pvc,xe-lpg,ats-m150") endif() diff --git a/src/ATen/native/xpu/sycl/MultiTensorApply.h b/src/ATen/native/xpu/sycl/MultiTensorApply.h index 51ee195a9..e91ab56f5 100644 --- a/src/ATen/native/xpu/sycl/MultiTensorApply.h +++ b/src/ATen/native/xpu/sycl/MultiTensorApply.h @@ -68,7 +68,7 @@ static inline int64_t multi_tensor_apply_fused_kernel_get_chunk_size() { } template -struct MultiTensorApplyKernelFunctor { +struct MultiTensorApplyKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { void operator()(sycl::nd_item<1> item_id) const { // Expand the tuple elements manually and call the callable expandAndCall(item_id, std::index_sequence_for()); @@ -85,6 +85,12 @@ struct MultiTensorApplyKernelFunctor { callable(callable_), args(std::make_tuple(args_...)) {} + void sycl_ker_config_convention(sycl::handler& cgh) { + if constexpr (std::is_base_of_v<__SYCL_KER_CONFIG_CONVENTION__, U>) { + callable.sycl_ker_config_convention(cgh); + } + } + private: template void expandAndCall(sycl::nd_item<1> item_id, std::index_sequence) @@ -117,7 +123,6 @@ void launch_multi_tensor_apply_kernel( U callable, int num_wg, ArgTypes... args) { - auto& q = getCurrentSYCLQueue(); int64_t simd = syclMaxSubGroupSize(); int64_t max_wg_size = multi_tensor_apply_kernel_get_wg_size(simd); @@ -226,7 +231,6 @@ void multi_tensor_apply( std::vector>& tensor_lists, T callable, ArgTypes... 
args) { - TORCH_CHECK( tensor_lists.size() == depth, "Number of tensor lists has to match he depth"); diff --git a/src/ATen/native/xpu/sycl/ScatterGatherKernels.cpp b/src/ATen/native/xpu/sycl/ScatterGatherKernels.cpp index 597be8553..c807655b3 100644 --- a/src/ATen/native/xpu/sycl/ScatterGatherKernels.cpp +++ b/src/ATen/native/xpu/sycl/ScatterGatherKernels.cpp @@ -143,41 +143,64 @@ struct alignas(N) OpaqueType { char data[N]; }; -template +template struct ScatterGatherElementwiseKernelFunctor { void operator()(sycl::nd_item<1> item) const { - constexpr int nv = work_group_size * thread_work_size; + int nv = work_group_size_ * thread_work_size_; auto wg_id = item.get_group_linear_id(); auto local_id = item.get_local_linear_id(); int idx = nv * wg_id + local_id; -#pragma unroll - for (int i = 0; i < thread_work_size; ++i) { + for (int i = 0; i < thread_work_size_; ++i) { if (idx < N_) { f_(idx); - idx += work_group_size; + idx += work_group_size_; } } } - ScatterGatherElementwiseKernelFunctor(int N, func_t f) : N_(N), f_(f) {} + ScatterGatherElementwiseKernelFunctor( + int N, + func_t f, + int work_group_size, + int thread_work_size) + : N_(N), + f_(f), + work_group_size_(work_group_size), + thread_work_size_(thread_work_size) {} private: int N_; func_t f_; + int work_group_size_; + int thread_work_size_; }; -template +template static void launch_scatter_gather_kernel(int64_t N, const func_t& f) { TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits::max()); if (N == 0) { return; } - sycl::range<1> local_range{(size_t)nt}; - int num_workgroups = (N + nt * vt - 1) / (nt * vt); - sycl::range<1> global_range{(size_t)(num_workgroups * nt)}; - - auto caller = - ScatterGatherElementwiseKernelFunctor((int)N, f); + using KernelFn = ScatterGatherElementwiseKernelFunctor; + int64_t max_wg_size = syclMaxWorkGroupSize(); + int outputSize = N; + int work_group_size = outputSize > max_wg_size ? max_wg_size : outputSize; + const auto target_global_size = syclMaxWorkItemsPerTile(); + // Each work group size is work_group_size, one full device launch is + // target_global_size, so we can calculate max work group num as below + const int max_work_group_num = target_global_size / work_group_size; + int work_group_num = outputSize / work_group_size < max_work_group_num + ? 
outputSize / work_group_size + : max_work_group_num; + int draft_work_group_num = + (outputSize + work_group_size - 1) / work_group_size; + + int thread_work_size = draft_work_group_num / work_group_num + 1; + + sycl::range<1> local_range(work_group_size); + sycl::range<1> global_range(work_group_num * work_group_size); + + auto caller = KernelFn((int)N, f, work_group_size, thread_work_size); sycl_kernel_submit( global_range, local_range, at::xpu::getCurrentSYCLQueue(), caller); } @@ -268,11 +291,7 @@ struct ScatterGatherInternalKernel { numel, f); - // TODO: optimize it - constexpr int group_work_items = 256; - constexpr int work_size_per_item = 4; - launch_scatter_gather_kernel( - iter.numel(), loop); + launch_scatter_gather_kernel(iter.numel(), loop); } }; @@ -521,11 +540,7 @@ struct ScatterFillInternalKernel { decltype(offset_calc), func_t>(self_ptr, index_ptr, offset_calc, index_stride, f, src_val); - // TODO: optimize it - constexpr int group_work_items = 256; - constexpr int work_size_per_item = 4; - launch_scatter_gather_kernel( - iter.numel(), loop); + launch_scatter_gather_kernel(iter.numel(), loop); } }; diff --git a/src/BuildOnWindows.cmake b/src/BuildOnWindows.cmake index b671b305c..b4757acb1 100644 --- a/src/BuildOnWindows.cmake +++ b/src/BuildOnWindows.cmake @@ -3,11 +3,6 @@ set(TORCH_XPU_OPS_LIBRARIES) set(SYCL_LINK_LIBRARIES_KEYWORD PRIVATE) -# Walk around cyclic dependence -# libtorch_xpu.so links to libtorch_xpu_ops.a -# Load libtorch_xpu_ops_aten.so explicitly by torch/__init__.py:_load_dll_libraries (Break cycle) -# libtorch_xpu_ops_aten.so links to libtorch_xpu_ops_sycl_unary_binary_kernels.so and libtorch_xpu_ops_sycl_kernels.so -# libtorch_xpu_ops_sycl_unary_binary_kernels.so and libtorch_xpu_ops_sycl_kernels.so links to libtorch_xpu.so add_library( torch_xpu_ops STATIC @@ -21,7 +16,6 @@ add_library( ${ATen_XPU_NATIVE_CPP_SRCS} ${ATen_XPU_GEN_SRCS}) install(TARGETS torch_xpu_ops_aten DESTINATION "${TORCH_INSTALL_LIB_DIR}") -# target_compile_definitions(torch_xpu_ops_aten PRIVATE CAFFE2_BUILD_MAIN_LIB) target_compile_definitions(torch_xpu_ops_aten PRIVATE TORCH_XPU_BUILD_MAIN_LIB) target_link_libraries(torch_xpu_ops_aten PUBLIC torch_xpu) target_link_libraries(torch_xpu_ops_aten PUBLIC torch_cpu) @@ -48,8 +42,11 @@ else() set(ATen_XPU_SYCL_REDUCE_SRCS) set(ATen_XPU_SYCL_ACTIVATION_SRCS) set(ATen_XPU_SYCL_FOREACH_SRCS) + set(ATen_XPU_SYCL_TENSOR_SRCS) + set(ATen_XPU_SYCL_NORM_LOSS_SRCS) + set(ATen_XPU_SYCL_POLY_SRCS) + set(ATen_XPU_SYCL_DISTRIBUTION_SRCS) set(ATen_XPU_SYCL_OTHERS_SRCS) - foreach(sycl_src ${ATen_XPU_SYCL_SRCS}) string(REGEX MATCH "Binary" IS_BINARY ${sycl_src}) string(REGEX MATCH "Unary" IS_UNARY ${sycl_src}) @@ -63,6 +60,13 @@ else() string(REGEX MATCH "Activation" IS_ACTIVATION ${sycl_src}) string(REGEX MATCH "Foreach" IS_FOREACH ${sycl_src}) string(REGEX MATCH "Reduce" IS_REDUCE ${sycl_src}) + string(REGEX MATCH "Tensor" IS_TENSOR ${sycl_src}) + string(REGEX MATCH "Norm" IS_NORM ${sycl_src}) + string(REGEX MATCH "Loss" IS_LOSS ${sycl_src}) + string(REGEX MATCH "Polynomial" IS_POLY ${sycl_src}) + #Move resize kernel to Norm and Loss lib, to resolve symbol. 
+ string(REGEX MATCH "Resize" IS_RESIZE ${sycl_src}) + string(REGEX MATCH "Distribution" IS_DISTRIBUTION ${sycl_src}) if(NOT IS_FOREACH STREQUAL "") list(APPEND ATen_XPU_SYCL_FOREACH_SRCS ${sycl_src}) @@ -74,11 +78,18 @@ else() list(APPEND ATen_XPU_SYCL_REDUCE_SRCS ${sycl_src}) elseif(NOT IS_ACTIVATION STREQUAL "") list(APPEND ATen_XPU_SYCL_ACTIVATION_SRCS ${sycl_src}) + elseif(NOT IS_TENSOR STREQUAL "") + list(APPEND ATen_XPU_SYCL_TENSOR_SRCS ${sycl_src}) + elseif(NOT IS_DISTRIBUTION STREQUAL "") + list(APPEND ATen_XPU_SYCL_DISTRIBUTION_SRCS ${sycl_src}) + elseif(NOT IS_NORM STREQUAL "" OR NOT IS_LOSS STREQUAL "" OR NOT IS_RESIZE STREQUAL "") + list(APPEND ATen_XPU_SYCL_NORM_LOSS_SRCS ${sycl_src}) + elseif(NOT IS_POLY STREQUAL "") + list(APPEND ATen_XPU_SYCL_POLY_SRCS ${sycl_src}) else() list(APPEND ATen_XPU_SYCL_OTHERS_SRCS ${sycl_src}) endif() endforeach() - # Binary kernel lib set(sycl_binary_lib torch_xpu_ops_sycl_binary_kernels) sycl_add_library( @@ -148,7 +159,63 @@ else() # Decouple with PyTorch cmake definition. install(TARGETS ${sycl_foreach_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") + + # Tensor kernel lib + set(sycl_tensor_lib torch_xpu_ops_sycl_tensor_kernels) + sycl_add_library( + ${sycl_tensor_lib} + SHARED + SYCL_SOURCES ${ATen_XPU_SYCL_TENSOR_SRCS}) + target_compile_definitions(${sycl_tensor_lib} PRIVATE TORCH_XPU_BUILD_MAIN_LIB) + target_link_libraries(torch_xpu_ops_aten PUBLIC ${sycl_tensor_lib}) + target_link_libraries(${sycl_tensor_lib} PUBLIC torch_xpu) + list(APPEND TORCH_XPU_OPS_LIBRARIES ${sycl_tensor_lib}) + # Decouple with PyTorch cmake definition. + install(TARGETS ${sycl_tensor_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") + + # Norm and Loss kernel lib + set(sycl_norm_loss_lib torch_xpu_ops_sycl_norm_loss_kernels) + sycl_add_library( + ${sycl_norm_loss_lib} + SHARED + SYCL_SOURCES ${ATen_XPU_SYCL_NORM_LOSS_SRCS}) + target_compile_definitions(${sycl_norm_loss_lib} PRIVATE TORCH_XPU_BUILD_MAIN_LIB) + target_link_libraries(torch_xpu_ops_aten PUBLIC ${sycl_norm_loss_lib}) + target_link_libraries(${sycl_norm_loss_lib} PUBLIC torch_xpu) + list(APPEND TORCH_XPU_OPS_LIBRARIES ${sycl_norm_loss_lib}) + + # Decouple with PyTorch cmake definition. + install(TARGETS ${sycl_norm_loss_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") + + # Polynomial kernel lib + set(sycl_poly_lib torch_xpu_ops_sycl_poly_kernels) + sycl_add_library( + ${sycl_poly_lib} + SHARED + SYCL_SOURCES ${ATen_XPU_SYCL_POLY_SRCS}) + target_compile_definitions(${sycl_poly_lib} PRIVATE TORCH_XPU_BUILD_MAIN_LIB) + target_link_libraries(torch_xpu_ops_aten PUBLIC ${sycl_poly_lib}) + target_link_libraries(${sycl_poly_lib} PUBLIC torch_xpu) + list(APPEND TORCH_XPU_OPS_LIBRARIES ${sycl_poly_lib}) + + # Decouple with PyTorch cmake definition. + install(TARGETS ${sycl_poly_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") + + # Distribution kernel lib + set(sycl_dist_lib torch_xpu_ops_sycl_dist_kernels) + sycl_add_library( + ${sycl_dist_lib} + SHARED + SYCL_SOURCES ${ATen_XPU_SYCL_DISTRIBUTION_SRCS}) + target_compile_definitions(${sycl_dist_lib} PRIVATE TORCH_XPU_BUILD_MAIN_LIB) + target_link_libraries(torch_xpu_ops_aten PUBLIC ${sycl_dist_lib}) + target_link_libraries(${sycl_dist_lib} PUBLIC torch_xpu) + list(APPEND TORCH_XPU_OPS_LIBRARIES ${sycl_dist_lib}) + + # Decouple with PyTorch cmake definition. 
+ install(TARGETS ${sycl_dist_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") + # Other kernel lib set(sycl_lib torch_xpu_ops_sycl_kernels) sycl_add_library( diff --git a/test/xpu/test_binary_ufuncs_xpu.py b/test/xpu/test_binary_ufuncs_xpu.py index 9cff7f9f1..91db3f12a 100644 --- a/test/xpu/test_binary_ufuncs_xpu.py +++ b/test/xpu/test_binary_ufuncs_xpu.py @@ -65,7 +65,7 @@ def to_np(value): else: self.assertRaisesRegex( RuntimeError, - "Found dtype \\w+ but expected \\w+", + r"result type \w+ can't be cast to the desired output type \w+", lambda: actual.pow_(exponent), ) From f312190a927b21caa34f18505cf350faaed6f455 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Tue, 3 Dec 2024 14:55:30 +0800 Subject: [PATCH 02/25] [Release-2.6] Fix bugs of `empty_xpu` and `soft_shrink` (#1139) #### Bugfix - [add lazy init for empty_xpu](https://github.com/intel/torch-xpu-ops/pull/1115) - [nan propagation for soft_shrink](https://github.com/intel/torch-xpu-ops/pull/1116/files#diff-b7cb5876d000db957286c8b0e72badb2b7502402c8955334f1cc21c34c98a5b9) --------- Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com> Co-authored-by: ZhiweiYan-96 --- src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.cpp | 9 +++++---- src/ATen/xpu/EmptyTensor.cpp | 1 + 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.cpp b/src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.cpp index b96ab461e..c4bfac98f 100644 --- a/src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.cpp +++ b/src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.cpp @@ -1,16 +1,17 @@ #include +#include #include - -#include - #include +#include namespace at::native::xpu { template struct SoftshrinkFunctor { scalar_t operator()(scalar_t a) const { - return a > lambd_ ? a - lambd_ : (a < -lambd_ ? a + lambd_ : scalar_t(0)); + return at::_isnan(a) + ? a + : (a > lambd_ ? a - lambd_ : (a < -lambd_ ? 
a + lambd_ : scalar_t(0))); } SoftshrinkFunctor(scalar_t lambd) : lambd_(lambd) {} diff --git a/src/ATen/xpu/EmptyTensor.cpp b/src/ATen/xpu/EmptyTensor.cpp index 2550b4dbb..3f5e998f8 100644 --- a/src/ATen/xpu/EmptyTensor.cpp +++ b/src/ATen/xpu/EmptyTensor.cpp @@ -12,6 +12,7 @@ TensorBase empty_xpu( ScalarType dtype, c10::optional device_opt, c10::optional memory_format_opt) { + at::globalContext().lazyInitDevice(c10::DeviceType::XPU); const auto device = device_or_default(device_opt); TORCH_INTERNAL_ASSERT(device.is_xpu()); const c10::DeviceGuard device_guard(device); From 7ecb0b1a56b65dec63837a30972a8ba6f8432477 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Thu, 5 Dec 2024 14:20:27 +0800 Subject: [PATCH 03/25] [Release-2.6] Capture rrelu_with_noise noise mutation in compile (#1145) Resolve: https://github.com/pytorch/pytorch/issues/142102 --- src/ATen/native/xpu/RreluWithNoise.cpp | 6 +++--- src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp | 4 ++-- src/ATen/native/xpu/sycl/RreluWithNoiseKernels.h | 2 +- yaml/native/native_functions.yaml | 7 ++++--- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/ATen/native/xpu/RreluWithNoise.cpp b/src/ATen/native/xpu/RreluWithNoise.cpp index f66833983..fb4e2c333 100644 --- a/src/ATen/native/xpu/RreluWithNoise.cpp +++ b/src/ATen/native/xpu/RreluWithNoise.cpp @@ -6,7 +6,7 @@ namespace native { Tensor& rrelu_with_noise_out_xpu( const Tensor& self, - const Tensor& noise, + Tensor& noise, const Scalar& lower, const Scalar& upper, bool training, @@ -18,7 +18,7 @@ Tensor& rrelu_with_noise_out_xpu( Tensor rrelu_with_noise_xpu( const Tensor& self, - const Tensor& noise, + Tensor& noise, const Scalar& lower, const Scalar& upper, bool training, @@ -30,7 +30,7 @@ Tensor rrelu_with_noise_xpu( Tensor& rrelu_with_noise_xpu_( Tensor& self, - const Tensor& noise, + Tensor& noise, const Scalar& lower, const Scalar& upper, bool training, diff --git a/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp b/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp index 533630175..7f6f33805 100644 --- a/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp +++ b/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp @@ -86,7 +86,7 @@ template inline void _rrelu_with_noise_xpu_train( Tensor& output, const Tensor& input_, - const Tensor& noise_, + Tensor& noise_, const Scalar& lower_, const Scalar& upper_, std::optional generator) { @@ -153,7 +153,7 @@ inline void _rrelu_with_noise_xpu_train( Tensor& rrelu_with_noise_kernel( const Tensor& self, - const Tensor& noise, + Tensor& noise, const Scalar& lower, const Scalar& upper, bool training, diff --git a/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.h b/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.h index 8371c38ab..fa7e568ea 100644 --- a/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.h +++ b/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.h @@ -7,7 +7,7 @@ namespace at::native::xpu { TORCH_XPU_API Tensor& rrelu_with_noise_kernel( const Tensor& self, - const Tensor& noise, + Tensor& noise, const Scalar& lower, const Scalar& upper, bool training, diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 64e94073b..7e9d19e9d 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -8106,17 +8106,18 @@ variants: function tags: pointwise -- func: rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) 
+- func: rrelu_with_noise.out(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn tags: nondeterministic_seeded dispatch: XPU: rrelu_with_noise_out_xpu -- func: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor +- func: rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor python_module: nn dispatch: XPU: rrelu_with_noise_xpu tags: nondeterministic_seeded + autogen: rrelu_with_noise_functional - func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor python_module: nn @@ -8124,7 +8125,7 @@ CompositeExplicitAutograd: rrelu_with_noise_backward autogen: rrelu_with_noise_backward.out -- func: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) +- func: rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) python_module: nn tags: nondeterministic_seeded dispatch: From 5410f510e8a365055ede770c8a7518941dd04ea9 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Fri, 22 Nov 2024 02:58:15 +0000 Subject: [PATCH 04/25] contiguous layout for sycl int4 kernel --- test/xpu/test_int4_linear.py | 205 +++++++++++++++++++++++++++++++++++ 1 file changed, 205 insertions(+) create mode 100644 test/xpu/test_int4_linear.py diff --git a/test/xpu/test_int4_linear.py b/test/xpu/test_int4_linear.py new file mode 100644 index 000000000..6d305e1cb --- /dev/null +++ b/test/xpu/test_int4_linear.py @@ -0,0 +1,205 @@ +# Owner(s): ["module: intel"] + +import torch +import pytest + +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) + + +checking_atol = 1e-2 +checking_rtol = 1e-2 + + +class TestSYCLInt4Linear(TestCase): + + @staticmethod + def unpack_weight(qweight, scales, qzeros, q_config): + group_size = q_config["group_size"] + bits = q_config["bits"] + s32_bits = 32 + + assert bits == 4 + # Int32 can store 8 * 4bits data. This is the offset for each data. 
+ wf = ( + torch.tensor(list(range(0, s32_bits, bits)), dtype=torch.int32) + .unsqueeze(0) + .to("xpu") + ) + zeros = torch.bitwise_right_shift( + torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0) + ).to(torch.int16 if bits == 8 else torch.int8) + torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) + + zeros = zeros + 1 + zeros = zeros.reshape(scales.shape) + + weight = torch.bitwise_right_shift( + torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1) + ).to(torch.int16 if bits == 8 else torch.int8) + torch.bitwise_and(weight, (2**bits) - 1, out=weight) + + return weight, scales, zeros + + @staticmethod + def dequantize(qweight, scales, qzeros, group_size): + q_config = {"group_size": group_size, "bits": 4} + weight, gptq_scales, gptq_zeros = TestSYCLInt4Linear.unpack_weight( + qweight, scales, qzeros, q_config + ) + gptq_zeros = (torch.ones_like(gptq_zeros) * 8).to("xpu") # TODO: hard code zp + if len(weight.shape) > 2: + weight = weight.reshape(-1, weight.shape[-1]) + infeatures = weight.shape[0] + g_idx = torch.tensor( + [i // q_config["group_size"] for i in range(infeatures)], + dtype=torch.int32, + ) + scale_zeros = gptq_zeros * gptq_scales + weight = gptq_scales[g_idx.long()] * weight - scale_zeros[g_idx.long()] + return weight + + @staticmethod + def rand_int4(size, dtype=torch.int32, device="xpu"): + rand = torch.randint(-128, 128, [size // 2], device=device).to(torch.int8) + return rand.view(dtype=dtype) + + @parametrize("per_channel", [False], lambda k: "per_channel" * k) + @parametrize("dtype", [torch.float16]) + @parametrize("m,n,k", [(8, 4096, 4096), (1, 4096, 11008), (32, 4096, 4096)]) + def test_gemm_int4(self, m, n, k, per_channel, dtype): + input = torch.rand([m, k], device="xpu", dtype=dtype) + input_torch = input.cpu() + weight = self.rand_int4(k * n, torch.int32, "xpu").reshape(k // 8, n) + + group_size = min(128, k) + if per_channel: + group_size = k + group_num = int(k / group_size) + + scales = torch.rand([group_num, n], device="xpu", dtype=dtype) + zero_points = self.rand_int4(group_num * n, torch.int32, "xpu").reshape( + group_num, n // 8 + ) + + weight_fp = self.dequantize( + weight, scales, zero_points, group_size).cpu() + # check gemm + zero_points = torch.Tensor([8]).to(torch.int8).to("xpu") + weight_ba = weight.transpose(0, 1).contiguous() + + out_onednn =torch._weight_int4pack_mm_with_scales_and_zeros( + input, weight_ba, scales, zero_points, group_size + ) + out_torch = torch.matmul(input_torch, weight_fp) + self.assertEqual( + out_onednn.cpu().float(), + out_torch.float(), + atol=checking_atol, + rtol=checking_rtol, + ) + # check gemm + residual + res0 = torch.rand([m, n], device="xpu", dtype=dtype) + out_onednn_res = torch._weight_int4pack_mm_with_scales_and_zeros( + input, weight_ba, scales, zero_points, group_size, res0) + out_torch_res = out_torch + res0.cpu().float() + self.assertEqual( + out_onednn_res.cpu().float(), + out_torch_res.float(), + atol=checking_atol, + rtol=checking_rtol, + ) + + # check gemm + bias + bias = torch.rand([1, n], device="xpu", dtype=dtype) + out_onednn_bias = torch._weight_int4pack_mm_with_scales_and_zeros( + input, weight_ba, bias, scales, zero_points, group_size) + out_torch_bias = out_torch + bias.cpu().float() + self.assertEqual( + out_onednn_bias.cpu().float(), + out_torch_bias.float(), + atol=checking_atol, + rtol=checking_rtol, + ) + + # check gemm + bias + gelu + out_onednn_gelu = torch._weight_int4pack_mm_with_scales_and_zeros( + input, + weight_ba, + scales, + zero_points, + 
bias, + group_size, + "tanh", + ) + gelu_out = torch.nn.GELU(approximate="tanh")(out_torch_bias) + self.assertEqual( + out_onednn_gelu.cpu().float(), + gelu_out.float(), + atol=checking_atol, + rtol=checking_rtol, + ) + + # check gemm + silu + mul + res0 = torch.rand([m, n], device="xpu", dtype=dtype) + out_onednn_silu = torch._weight_int4pack_mm_with_scales_and_zeros( + input, weight_ba, scales, zero_points, group_size, res0 + ) + silu_mul_out = torch.nn.SiLU()(out_torch) * res0.cpu().float() + self.assertEqual( + out_onednn_silu.cpu().float(), + silu_mul_out.float(), + atol=checking_atol, + rtol=checking_rtol, + ) + + # check gemm + bias + residual + residual + res0 = torch.rand([m, n], device="xpu", dtype=dtype) + res1 = torch.rand([m, n], device="xpu", dtype=dtype) + out_onednn_bias_2res = torch._weight_int4pack_mm_with_scales_and_zeros( + input, + weight_ba, + bias, + res0, + res1, + scales, + zero_points, + group_size, + ) + out_torch_bias_2res = out_torch_bias + res0.cpu().float() + res1.cpu().float() + self.assertEqual( + out_onednn_bias_2res.cpu().float(), + out_torch_bias_2res.float(), + atol=checking_atol, + rtol=checking_rtol, + ) + + # check gemm + bias + residual + res0 = torch.rand([m, n], device="xpu", dtype=dtype) + out_onednn_bias_add = torch._weight_int4pack_mm_with_scales_and_zeros( + input, + weight_ba, + bias, + scales, + zero_points, + group_size, + res0, + ) + out_torch_bias_add = out_torch_bias + res0.cpu().float() + self.assertEqual( + out_onednn_bias_add.cpu().float(), + out_torch_bias_add.float(), + atol=checking_atol, + rtol=checking_rtol, + ) + + +instantiate_parametrized_tests(TestSYCLInt4Linear, globals(), only_for="xpu", allow_xpu=True) + +if __name__ == "__main__": + TestCase._default_dtype_check_enabled = True + run_tests() From e9311a36c4fb2b0ebfb63a9863e9a009103677eb Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Tue, 26 Nov 2024 09:05:33 +0000 Subject: [PATCH 05/25] push without compile --- src/ATen/native/xpu/sycl/LinearInt4.cpp | 135 ++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 src/ATen/native/xpu/sycl/LinearInt4.cpp diff --git a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp new file mode 100644 index 000000000..796edb791 --- /dev/null +++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp @@ -0,0 +1,135 @@ +#include + +namespace at::native::xpu { + +void linear_int4_kernel( + const Tensor& input, + const Tensor& weight, + const Tensor& weight_scale_zero_point, + const std::optional& weight_bias, + Tensor& output, + int block_size) { + int64_t M = input[0]; + int64_t K = input[1]; + int64_t N = output[1]; + scalar_t* input_data = input.data_ptr(); + int4x8* weight_data = weight.data_ptr(); + scalar_t* weight_scale_data = weight_scale.data_ptr(); +} + +template +struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + LinearInt4KernelFunctor( + scalar_t* A, + int4x8* B, + scalar_t* C, + scalar_t* B_scale, + scalar_t* B_zero_point, + int m, + int n, + int k, + int lda, + int ldb, + int ldc) + : A(A), + B(B), + C(C), + B_scale(B_scale), + B_zero_point(B_zero_point), + m(m), + n(n), + k(k), + lda(lda), + ldb(ldb), + ldc(ldc) {} + + void operator()(sycl::nd_item<1> item) const { + int constexpr Unroll = 2; + int constexpr SgSize = 16; + int constexpr TileK = 32; + int constexpr GroupK = SgSize * TileK; + int constexpr blocksize = 16; + + int g_idx = item.get_group(0); + auto sg = item.get_sub_group(); + int sg_id = sg.get_local_id()[0]; + int g_n = g_idx; + auto sptr = 
B_scale + g_n * ldb; + auto bptr = B + g_n * k / 2; + auto aptr = A; + auto cptr = C + g_n; + if constexpr (std::is_same_v) { + sycl::half2 tmpAcc = {0.f, 0.f}; + for (int i = 0; i < k; i += GroupK * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + scale_t scale = *(sptr + sg_id * TileK / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK; ikk += 2) { + sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk]; + sycl::half2 tmpB = { + static_cast((tmps8[ikk / 2] & 0x0f) - 8), + static_cast((tmps8[ikk / 2] >> 4) - 8)}; + tmpAcc += tmpA * tmpB * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; + } + } + sycl::half2 sum = {0.f, 0.f}; + for (int i = 0; i < SgSize; i += 1) { + sum += sg.shuffle(tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum[0] + sum[1]; + } + } else { + scalar_t tmpAcc = 0.f; + int constexpr Unroll = 2; + for (int i = 0; i < k; i += GroupK * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + scale_t scale = *(sptr + sg_id * TileK / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK; ikk += 2) { + tmpAcc += scale_t(aptr[sg_id * TileK + ikk]) * + static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale; + tmpAcc += scale_t(aptr[sg_id * TileK + ikk + 1]) * + static_cast((tmps8[ikk / 2] >> 4) - 8) * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; + } + } + float sum = 0.f; + for (int i = 0; i < SgSize; i += 1) { + sum += sg.shuffle(tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum; + } + } + } + + private: + scalar_t* A; + int4x8* B; + scalar_t* C; + scalar_t* B_scale; + scalar_t* B_zero_point; + int m; + int n; + int k; + int lda; + int ldb; + int ldc; +}; +} // namespace at::native::xpu \ No newline at end of file From e3eaffad6f73f6504237558d8c83eb7203b2e271 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Thu, 28 Nov 2024 01:46:19 +0000 Subject: [PATCH 06/25] update linearkernel --- src/ATen/native/xpu/sycl/LinearInt4.cpp | 27 ++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp index 796edb791..564dd85de 100644 --- a/src/ATen/native/xpu/sycl/LinearInt4.cpp +++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp @@ -9,12 +9,33 @@ void linear_int4_kernel( const std::optional& weight_bias, Tensor& output, int block_size) { - int64_t M = input[0]; - int64_t K = input[1]; - int64_t N = output[1]; + auto& sycl_queue = at::xpu::getCurrentSYCLQueue(); + int64_t m = input[0]; + int64_t n = input[1]; + int64_t k = output[1]; + int constexpr Unroll = 2; + int constexpr SgSize = 16; + sycl::range<1> group{SgSize}; + sycl::range<1> problem{static_cast(n) * SgSize}; + int lda = k; + int ldb = n; + int ldc = n; scalar_t* input_data = input.data_ptr(); int4x8* weight_data = weight.data_ptr(); + scalar_t* output_data = output.data_ptr(); scalar_t* weight_scale_data = weight_scale.data_ptr(); + auto kfn = LinearInt4KernelFunctor( + input_data, + weight_data, + output_data, + weight_scale_data, + nullptr, + m, + n, + k, + n, + n); + sycl_kernel_submit(::sycl::nd_range<1>(problem, group), sycl_queue, kfn); } template From 2a664af6a02465b2eb0ba6113cebef0648d96708 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Thu, 28 Nov 2024 09:04:21 +0000 Subject: [PATCH 07/25] fix some comiple 
error(not all) --- src/ATen/native/xpu/sycl/LinearInt4.cpp | 96 ++++++++++++++----------- 1 file changed, 53 insertions(+), 43 deletions(-) diff --git a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp index 564dd85de..e973c7e7e 100644 --- a/src/ATen/native/xpu/sycl/LinearInt4.cpp +++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp @@ -1,48 +1,13 @@ #include +#include namespace at::native::xpu { -void linear_int4_kernel( - const Tensor& input, - const Tensor& weight, - const Tensor& weight_scale_zero_point, - const std::optional& weight_bias, - Tensor& output, - int block_size) { - auto& sycl_queue = at::xpu::getCurrentSYCLQueue(); - int64_t m = input[0]; - int64_t n = input[1]; - int64_t k = output[1]; - int constexpr Unroll = 2; - int constexpr SgSize = 16; - sycl::range<1> group{SgSize}; - sycl::range<1> problem{static_cast(n) * SgSize}; - int lda = k; - int ldb = n; - int ldc = n; - scalar_t* input_data = input.data_ptr(); - int4x8* weight_data = weight.data_ptr(); - scalar_t* output_data = output.data_ptr(); - scalar_t* weight_scale_data = weight_scale.data_ptr(); - auto kfn = LinearInt4KernelFunctor( - input_data, - weight_data, - output_data, - weight_scale_data, - nullptr, - m, - n, - k, - n, - n); - sycl_kernel_submit(::sycl::nd_range<1>(problem, group), sycl_queue, kfn); -} - -template +template struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { LinearInt4KernelFunctor( scalar_t* A, - int4x8* B, + uint32_t* B, scalar_t* C, scalar_t* B_scale, scalar_t* B_zero_point, @@ -87,7 +52,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { uint8_t tmps8[TileK / 2]; *(sycl::vec*)tmps8 = *(sycl::vec*)(bptr + sg_id * TileK / 2); - scale_t scale = *(sptr + sg_id * TileK / blocksize); + scalar_t scale = *(sptr + sg_id * TileK / blocksize); #pragma unroll for (int ikk = 0; ikk < TileK; ikk += 2) { sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk]; @@ -117,12 +82,12 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { uint8_t tmps8[TileK / 2]; *(sycl::vec*)tmps8 = *(sycl::vec*)(bptr + sg_id * TileK / 2); - scale_t scale = *(sptr + sg_id * TileK / blocksize); + scalar_t scale = *(sptr + sg_id * TileK / blocksize); #pragma unroll for (int ikk = 0; ikk < TileK; ikk += 2) { - tmpAcc += scale_t(aptr[sg_id * TileK + ikk]) * + tmpAcc += scalar_t(aptr[sg_id * TileK + ikk]) * static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale; - tmpAcc += scale_t(aptr[sg_id * TileK + ikk + 1]) * + tmpAcc += scalar_t(aptr[sg_id * TileK + ikk + 1]) * static_cast((tmps8[ikk / 2] >> 4) - 8) * scale; } sptr += GroupK / blocksize; @@ -142,7 +107,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { private: scalar_t* A; - int4x8* B; + uint32_t* B; scalar_t* C; scalar_t* B_scale; scalar_t* B_zero_point; @@ -153,4 +118,49 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { int ldb; int ldc; }; + +void linear_int4_kernel( + const Tensor& input, + const Tensor& weight, + const Tensor& weight_scale_zero_point, + const std::optional& weight_bias, + Tensor& output, + int block_size) { + auto& sycl_queue = at::xpu::getCurrentSYCLQueue(); + int64_t m = input.size(0); + int64_t n = input.size(1); + int64_t k = output.size(1); + int constexpr Unroll = 2; + int constexpr SgSize = 16; + sycl::range<1> group{SgSize}; + sycl::range<1> problem{static_cast(n) * SgSize}; + int lda = k; + int ldb = n; + int ldc = n; + if (input.scalar_type() == at::kHalf) { + using scalar_t = at::Half; + // 
const auto scalar_t = input.scalar_type(); + scalar_t* input_data = input.data_ptr(); + uint32_t* weight_data = weight.data_ptr(); // int4x8 + + scalar_t* output_data = output.data_ptr(); + scalar_t* weight_scale_data = weight_scale_zero_point.data_ptr(); + auto kfn = LinearInt4KernelFunctor( + input_data, + weight_data, + output_data, + weight_scale_data, + nullptr, + m, + n, + k, + k, + n, + n); + + sycl_kernel_submit(::sycl::nd_range<1>(problem, group), sycl_queue, kfn); + } +} + + } // namespace at::native::xpu \ No newline at end of file From 0156ba57e8c3ea080e88ab8dcb5d093594f98d61 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Thu, 28 Nov 2024 12:13:39 +0000 Subject: [PATCH 08/25] add sycl_ker_config_convention --- src/ATen/native/xpu/sycl/LinearInt4.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp index e973c7e7e..648cc0e6b 100644 --- a/src/ATen/native/xpu/sycl/LinearInt4.cpp +++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp @@ -28,6 +28,9 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { lda(lda), ldb(ldb), ldc(ldc) {} + void sycl_ker_config_convention(sycl::handler& cgh) { + // local_scan_ = sycl_local_acc_t(N_, cgh); + } void operator()(sycl::nd_item<1> item) const { int constexpr Unroll = 2; @@ -132,8 +135,8 @@ void linear_int4_kernel( int64_t k = output.size(1); int constexpr Unroll = 2; int constexpr SgSize = 16; - sycl::range<1> group{SgSize}; - sycl::range<1> problem{static_cast(n) * SgSize}; + sycl::range<1> local_range{SgSize}; + sycl::range<1> global_range{static_cast(n) * SgSize}; int lda = k; int ldb = n; int ldc = n; @@ -145,7 +148,7 @@ void linear_int4_kernel( scalar_t* output_data = output.data_ptr(); scalar_t* weight_scale_data = weight_scale_zero_point.data_ptr(); - auto kfn = LinearInt4KernelFunctor( + LinearInt4KernelFunctor kfn( input_data, weight_data, output_data, @@ -158,9 +161,8 @@ void linear_int4_kernel( n, n); - sycl_kernel_submit(::sycl::nd_range<1>(problem, group), sycl_queue, kfn); + sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); } } - } // namespace at::native::xpu \ No newline at end of file From a58afeca5f02f0feb717f0b18179cbce511e520a Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Fri, 29 Nov 2024 07:07:30 +0000 Subject: [PATCH 09/25] reg kernel for pytorch --- src/ATen/native/xpu/LinearInt4.cpp | 32 +++++++++++++++++++++++++ src/ATen/native/xpu/sycl/LinearInt4.cpp | 5 ++-- src/ATen/native/xpu/sycl/LinearInt4.h | 13 ++++++++++ 3 files changed, 47 insertions(+), 3 deletions(-) create mode 100644 src/ATen/native/xpu/LinearInt4.cpp create mode 100644 src/ATen/native/xpu/sycl/LinearInt4.h diff --git a/src/ATen/native/xpu/LinearInt4.cpp b/src/ATen/native/xpu/LinearInt4.cpp new file mode 100644 index 000000000..b8e114714 --- /dev/null +++ b/src/ATen/native/xpu/LinearInt4.cpp @@ -0,0 +1,32 @@ + +#include +#include +#include +#include + +#include +#include + +namespace at::native { +Tensor& linear_int4_xpu( + const Tensor& input, + const Tensor& weight, + int qGroupSize, + const Tensor& weight_scale_zero_point) { + std::optional common_device = std::nullopt; + c10::impl::check_and_update_common_device( + common_device, input, "xpu::linear_int4", "input"); + c10::impl::check_and_update_common_device( + common_device, weight, "xpu::linear_int4", "weight"); + c10::impl::check_and_update_common_device( + common_device, + weight_scale_zero_point, + "xpu::linear_int4", + "weight_scale_zero_point"); + 
Tensor output = at::empty({0}, input.options()); + + at::native::xpu::linear_int4_kernel( + input, weight, qGroupSize, weight_scale_zero_point, output); + return output; +} +} // namespace at::native diff --git a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp index 648cc0e6b..7105236c8 100644 --- a/src/ATen/native/xpu/sycl/LinearInt4.cpp +++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp @@ -125,10 +125,9 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { void linear_int4_kernel( const Tensor& input, const Tensor& weight, + int qGroupSize, const Tensor& weight_scale_zero_point, - const std::optional& weight_bias, - Tensor& output, - int block_size) { + Tensor& output) { auto& sycl_queue = at::xpu::getCurrentSYCLQueue(); int64_t m = input.size(0); int64_t n = input.size(1); diff --git a/src/ATen/native/xpu/sycl/LinearInt4.h b/src/ATen/native/xpu/sycl/LinearInt4.h new file mode 100644 index 000000000..bcd3c78cc --- /dev/null +++ b/src/ATen/native/xpu/sycl/LinearInt4.h @@ -0,0 +1,13 @@ +#pragma once +#include + +namespace at::native::xpu { + +TORCH_XPU_API void linear_int4_kernel( + const Tensor& input, + const Tensor& weight, + int qGroupSize, + const Tensor& weight_scale_zero_point, + Tensor& output); + +} // namespace at::native::xpu From f487b201c37253fc56c2fdbee3c0d630d7b2f704 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Fri, 29 Nov 2024 09:27:47 +0000 Subject: [PATCH 10/25] add yaml for int4mm --- yaml/native/native_functions.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 7e9d19e9d..3a6d5ba13 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -8469,3 +8469,8 @@ dispatch: SparseXPU: copy_sparse_ autogen: copy_sparse_to_sparse, copy_sparse_to_sparse.out + +- func: _weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor + dispatch: + XPU: linear_int4_xpu + From ce1c89468a21771cc91dfc9fd6549a9e4f5a0788 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Tue, 3 Dec 2024 06:36:54 +0000 Subject: [PATCH 11/25] update yaml file --- yaml/native/native_functions.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 3a6d5ba13..a0c6eff54 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -8473,4 +8473,6 @@ - func: _weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor dispatch: XPU: linear_int4_xpu + autogen: _weight_int4pack_mm_with_scales_and_zeros.out + tags: core From d61b198819ec96650a4d092fa4f029afdd8447c1 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Tue, 3 Dec 2024 07:50:05 +0000 Subject: [PATCH 12/25] Modified some review comments --- src/ATen/native/xpu/LinearInt4.cpp | 37 ++++++++++++++++++++++++- src/ATen/native/xpu/sycl/LinearInt4.cpp | 20 ++++++------- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/src/ATen/native/xpu/LinearInt4.cpp b/src/ATen/native/xpu/LinearInt4.cpp index b8e114714..baa02e341 100644 --- a/src/ATen/native/xpu/LinearInt4.cpp +++ b/src/ATen/native/xpu/LinearInt4.cpp @@ -13,6 +13,41 @@ Tensor& linear_int4_xpu( const Tensor& weight, int qGroupSize, const Tensor& weight_scale_zero_point) { + auto M = input.size(0); + auto N = weight.size(0); + auto K = input.size(1); + TORCH_CHECK( + input.dtype() == kBFloat16 || input.dtype() == 
kHalf || + input.dtype() == kFloat, + __func__, + " : expect input to be either 32-bit or 16-bit float tensor."); + + TORCH_CHECK( + weight.dtype() == kByte, __func__, " : expect B to be uint8 tensor."); + TORCH_CHECK( + weight.is_contiguous(), __func__, " : expect B to be contiguous."); + TORCH_CHECK( + weight.size(1) == K / 2, + __func__, + " : expect B.size(1) to be K/2, got ", + weight.size(1)); + + TORCH_CHECK( + qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 || + qGroupSize == 256, + __func__, + ": expect qGroupSize to be 32, 64, 128 or 256, got ", + qGroupSize); + + TORCH_CHECK( + weight_scale_zero_point.dim() == 3 && + weight_scale_zero_point.size(1) == N && + weight_scale_zero_point.size(2) == 2, + __func__, + ": expect weight_scale_zero_point to be 3d tensor with sizes [:, ", + N, + ", 2]"); + std::optional common_device = std::nullopt; c10::impl::check_and_update_common_device( common_device, input, "xpu::linear_int4", "input"); @@ -23,7 +58,7 @@ Tensor& linear_int4_xpu( weight_scale_zero_point, "xpu::linear_int4", "weight_scale_zero_point"); - Tensor output = at::empty({0}, input.options()); + Tensor output = at::empty({M, N}, input.options()); at::native::xpu::linear_int4_kernel( input, weight, qGroupSize, weight_scale_zero_point, output); diff --git a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp index 7105236c8..6ac13442c 100644 --- a/src/ATen/native/xpu/sycl/LinearInt4.cpp +++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp @@ -6,11 +6,11 @@ namespace at::native::xpu { template struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { LinearInt4KernelFunctor( - scalar_t* A, - uint32_t* B, + const scalar_t* A, + const uint32_t* B, scalar_t* C, - scalar_t* B_scale, - scalar_t* B_zero_point, + const scalar_t* B_scale, + const scalar_t* B_zero_point, int m, int n, int k, @@ -49,10 +49,10 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { auto cptr = C + g_n; if constexpr (std::is_same_v) { sycl::half2 tmpAcc = {0.f, 0.f}; + uint8_t tmps8[TileK / 2]; for (int i = 0; i < k; i += GroupK * Unroll) { #pragma unroll for (int iu = 0; iu < Unroll; iu++) { - uint8_t tmps8[TileK / 2]; *(sycl::vec*)tmps8 = *(sycl::vec*)(bptr + sg_id * TileK / 2); scalar_t scale = *(sptr + sg_id * TileK / blocksize); @@ -109,11 +109,11 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { } private: - scalar_t* A; - uint32_t* B; + const scalar_t* A; + const uint32_t* B; scalar_t* C; - scalar_t* B_scale; - scalar_t* B_zero_point; + const scalar_t* B_scale; + const scalar_t* B_zero_point; int m; int n; int k; @@ -142,7 +142,7 @@ void linear_int4_kernel( if (input.scalar_type() == at::kHalf) { using scalar_t = at::Half; // const auto scalar_t = input.scalar_type(); - scalar_t* input_data = input.data_ptr(); + const scalar_t* input_data = input.data_ptr(); uint32_t* weight_data = weight.data_ptr(); // int4x8 scalar_t* output_data = output.data_ptr(); From d76a0ce486fd5fe9190fb72582538dcd72133d08 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Mon, 9 Dec 2024 11:09:22 +0000 Subject: [PATCH 13/25] modify fun name --- src/ATen/native/xpu/LinearInt4.cpp | 2 +- yaml/native/native_functions.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ATen/native/xpu/LinearInt4.cpp b/src/ATen/native/xpu/LinearInt4.cpp index baa02e341..2a86c8c32 100644 --- a/src/ATen/native/xpu/LinearInt4.cpp +++ b/src/ATen/native/xpu/LinearInt4.cpp @@ -8,7 +8,7 @@ #include namespace at::native { -Tensor& 
linear_int4_xpu( +Tensor& _weight_int4pack_mm_with_scales_and_zeros_xpu( const Tensor& input, const Tensor& weight, int qGroupSize, diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index a0c6eff54..4746e774c 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -8472,7 +8472,7 @@ - func: _weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor dispatch: - XPU: linear_int4_xpu - autogen: _weight_int4pack_mm_with_scales_and_zeros.out + XPU: _weight_int4pack_mm_with_scales_and_zeros_xpu + autogen: _weight_int4pack_mm_with_scales_and_zeros_xpu.out tags: core From 870a3b5623c30328c7d210048da5ed9d9b41d846 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Tue, 10 Dec 2024 02:49:08 +0000 Subject: [PATCH 14/25] autogen: _weight_int4pack_mm_with_scales_and_zeros.out --- yaml/native/native_functions.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 4746e774c..75e63cd78 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -8473,6 +8473,6 @@ - func: _weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor dispatch: XPU: _weight_int4pack_mm_with_scales_and_zeros_xpu - autogen: _weight_int4pack_mm_with_scales_and_zeros_xpu.out + autogen: _weight_int4pack_mm_with_scales_and_zeros.out tags: core From a9627f66d3b8abcdcca7edafa97bc18a51ff8672 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Tue, 10 Dec 2024 05:50:31 +0000 Subject: [PATCH 15/25] param int->int64_t(python int is int64) --- src/ATen/native/xpu/LinearInt4.cpp | 4 ++-- yaml/native/native_functions.yaml | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/ATen/native/xpu/LinearInt4.cpp b/src/ATen/native/xpu/LinearInt4.cpp index 2a86c8c32..a3b6f4ac2 100644 --- a/src/ATen/native/xpu/LinearInt4.cpp +++ b/src/ATen/native/xpu/LinearInt4.cpp @@ -8,10 +8,10 @@ #include namespace at::native { -Tensor& _weight_int4pack_mm_with_scales_and_zeros_xpu( +Tensor _weight_int4pack_mm_with_scales_and_zeros_xpu( const Tensor& input, const Tensor& weight, - int qGroupSize, + int64_t qGroupSize, const Tensor& weight_scale_zero_point) { auto M = input.size(0); auto N = weight.size(0); diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 75e63cd78..1d3664079 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -8473,6 +8473,5 @@ - func: _weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor dispatch: XPU: _weight_int4pack_mm_with_scales_and_zeros_xpu - autogen: _weight_int4pack_mm_with_scales_and_zeros.out - tags: core - + # autogen: _weight_int4pack_mm_with_scales_and_zeros.out + # tags: core From 952ead9a5bc4ce3179beeb29cb33562ef72046ad Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Tue, 10 Dec 2024 08:56:03 +0000 Subject: [PATCH 16/25] use AT_DISPATCH_FLOATING_TYPES_AND --- src/ATen/native/xpu/LinearInt4.cpp | 12 +++-- src/ATen/native/xpu/sycl/LinearInt4.cpp | 71 +++++++++++++++++-------- 2 files changed, 59 insertions(+), 24 deletions(-) diff --git a/src/ATen/native/xpu/LinearInt4.cpp b/src/ATen/native/xpu/LinearInt4.cpp index a3b6f4ac2..ac2e76909 100644 --- a/src/ATen/native/xpu/LinearInt4.cpp +++ b/src/ATen/native/xpu/LinearInt4.cpp @@ -50,13 +50,19 @@ Tensor 
_weight_int4pack_mm_with_scales_and_zeros_xpu( std::optional common_device = std::nullopt; c10::impl::check_and_update_common_device( - common_device, input, "xpu::linear_int4", "input"); + common_device, + input, + "xpu::_weight_int4pack_mm_with_scales_and_zeros", + "input"); c10::impl::check_and_update_common_device( - common_device, weight, "xpu::linear_int4", "weight"); + common_device, + weight, + "xpu::_weight_int4pack_mm_with_scales_and_zeros", + "weight"); c10::impl::check_and_update_common_device( common_device, weight_scale_zero_point, - "xpu::linear_int4", + "xpu::_weight_int4pack_mm_with_scales_and_zeros", "weight_scale_zero_point"); Tensor output = at::empty({M, N}, input.options()); diff --git a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp index 6ac13442c..558250803 100644 --- a/src/ATen/native/xpu/sycl/LinearInt4.cpp +++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp @@ -139,29 +139,58 @@ void linear_int4_kernel( int lda = k; int ldb = n; int ldc = n; - if (input.scalar_type() == at::kHalf) { - using scalar_t = at::Half; - // const auto scalar_t = input.scalar_type(); - const scalar_t* input_data = input.data_ptr(); - uint32_t* weight_data = weight.data_ptr(); // int4x8 - scalar_t* output_data = output.data_ptr(); - scalar_t* weight_scale_data = weight_scale_zero_point.data_ptr(); - LinearInt4KernelFunctor kfn( - input_data, - weight_data, - output_data, - weight_scale_data, - nullptr, - m, - n, - k, - k, - n, - n); + AT_DISPATCH_FLOATING_TYPES_AND( + at::ScalarType::Half, input.scalar_type(), "linear_int4_kernel", [&]() { + using scalar_t = at::Half; + const scalar_t* input_data = input.data_ptr(); + uint32_t* weight_data = weight.data_ptr(); // int4x8 - sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); - } + scalar_t* output_data = output.data_ptr(); + scalar_t* weight_scale_data = + weight_scale_zero_point.data_ptr(); + LinearInt4KernelFunctor kfn( + input_data, + weight_data, + output_data, + weight_scale_data, + nullptr, + m, + n, + k, + k, + n, + n); + + sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); + }); + AT_DISPATCH_FLOATING_TYPES_AND( + at::ScalarType::BFloat16, + input.scalar_type(), + "linear_int4_kernel", + [&]() { + using scalar_t = at::BFloat16; + const scalar_t* input_data = input.data_ptr(); + uint32_t* weight_data = weight.data_ptr(); // int4x8 + + scalar_t* output_data = output.data_ptr(); + scalar_t* weight_scale_data = + weight_scale_zero_point.data_ptr(); + LinearInt4KernelFunctor kfn( + input_data, + weight_data, + output_data, + weight_scale_data, + nullptr, + m, + n, + k, + k, + n, + n); + + sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); + }); } } // namespace at::native::xpu \ No newline at end of file From 93804f9ef2f65804ac389f4ae110ad062859e241 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Wed, 11 Dec 2024 02:39:34 +0000 Subject: [PATCH 17/25] Keep the same name as pytorch's _weight_int4pack_mm --- src/ATen/native/xpu/LinearInt4.cpp | 2 +- yaml/native/native_functions.yaml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ATen/native/xpu/LinearInt4.cpp b/src/ATen/native/xpu/LinearInt4.cpp index ac2e76909..22381fc02 100644 --- a/src/ATen/native/xpu/LinearInt4.cpp +++ b/src/ATen/native/xpu/LinearInt4.cpp @@ -8,7 +8,7 @@ #include namespace at::native { -Tensor _weight_int4pack_mm_with_scales_and_zeros_xpu( +Tensor _weight_int4pack_mm_xpu( const Tensor& input, const Tensor& weight, int64_t qGroupSize, diff --git 
a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 1d3664079..59c0b3b73 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -8470,8 +8470,8 @@ SparseXPU: copy_sparse_ autogen: copy_sparse_to_sparse, copy_sparse_to_sparse.out -- func: _weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor +- func: _weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor dispatch: - XPU: _weight_int4pack_mm_with_scales_and_zeros_xpu - # autogen: _weight_int4pack_mm_with_scales_and_zeros.out + XPU: _weight_int4pack_mm_xpu + # autogen: _weight_int4pack_mm.out # tags: core From 9e50b68c444409a3b57e4dc37bc6a9977a202776 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Wed, 11 Dec 2024 09:02:04 +0000 Subject: [PATCH 18/25] modify UT for int4 --- test/xpu/test_linalg_xpu.py | 112 ++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/test/xpu/test_linalg_xpu.py b/test/xpu/test_linalg_xpu.py index 3c1c8aed7..bff4eacf4 100644 --- a/test/xpu/test_linalg_xpu.py +++ b/test/xpu/test_linalg_xpu.py @@ -6,6 +6,7 @@ from torch.testing._internal.common_dtype import floating_and_complex_types_and from torch.testing._internal.common_cuda import tf32_on_and_off from torch.testing._internal.common_mkldnn import bf32_on_and_off +from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel from torch.testing import make_tensor import unittest import itertools @@ -171,6 +172,116 @@ def _test(m, k, n, transpose_a, transpose_b, test_equal=True): if not use_transpose_a and not use_transpose_b: _test(17, k, n, use_transpose_a, use_transpose_b) +@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") +@parametrize("m", [32]) +@parametrize("k", [32]) +@parametrize("n", [32]) +def _int4_mm(self, device, m, k, n): + @staticmethod + def rand_int4(size, dtype=torch.int32, device="xpu"): + rand = torch.randint(-128, 128, [size // 2], device=device).to(torch.int8) + return rand.view(dtype=dtype) + + @staticmethod + def unpack_weight(qweight, scales, qzeros, q_config): + group_size = q_config["group_size"] + bits = q_config["bits"] + s32_bits = 32 + + assert bits == 4 + # Int32 can store 8 * 4bits data. This is the offset for each data. 
+ wf = ( + torch.tensor(list(range(0, s32_bits, bits)), dtype=torch.int32) + .unsqueeze(0) + .to("xpu") + ) + zeros = torch.bitwise_right_shift( + torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0) + ).to(torch.int16 if bits == 8 else torch.int8) + torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) + + zeros = zeros + 1 + zeros = zeros.reshape(scales.shape) + + weight = torch.bitwise_right_shift( + torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1) + ).to(torch.int16 if bits == 8 else torch.int8) + torch.bitwise_and(weight, (2**bits) - 1, out=weight) + + return weight, scales, zeros + + @staticmethod + def dequantize(qweight, scales, qzeros, group_size): + q_config = {"group_size": group_size, "bits": 4} + weight, gptq_scales, gptq_zeros = unpack_weight( + qweight, scales, qzeros, q_config + ) + gptq_zeros = (torch.ones_like(gptq_zeros) * 8).to("xpu") # TODO: hard code zp + if len(weight.shape) > 2: + weight = weight.reshape(-1, weight.shape[-1]) + infeatures = weight.shape[0] + g_idx = torch.tensor( + [i // q_config["group_size"] for i in range(infeatures)], + dtype=torch.int32, + ) + scale_zeros = gptq_zeros * gptq_scales + weight = gptq_scales[g_idx.long()] * weight - scale_zeros[g_idx.long()] + return weight + + def convert_weight_to_int4pack(b): + b_tmp, b_scales_and_zeros = _group_quantize_tensor( + b, n_bit=4, q_group_size=q_group + ) + if self.device_type == 'cpu': + b_int4pack = torch._convert_weight_to_int4pack_for_cpu( + b_tmp, inner_k_tiles + ) + else: + b_int4pack = torch._convert_weight_to_int4pack( + b_tmp, inner_k_tiles + ) + + return b_int4pack, b_scales_and_zeros + + def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros): + self.assertTrue(b_int4pack.dtype is torch.int32) + # self.assertTrue(b_int4pack.dim() == 4) + return torch._weight_int4pack_mm( + a, b_int4pack, q_group, b_scales_and_zeros + ) + + dtype = torch.bfloat16 + q_group = 32 + inner_k_tiles = 2 + + torch.manual_seed(1) + a_bf16 = torch.rand((m, k), dtype=dtype, device=device) + b_int4 = rand_int4(k * n, torch.int32, "xpu").reshape(k // 8, n) + group_num = int(k / q_group) + + scales = torch.rand([group_num, n], device="xpu", dtype=dtype) + zero_points = rand_int4(group_num * n, torch.int32, "xpu").reshape( + group_num, n // 8 + ) + + b_bf16 = dequantize(b_int4, scales, zero_points, q_group).cpu() + + # b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16) + + for dtype in [torch.bfloat16] + ([torch.float16, torch.float32] if device == "cpu" else []): + a = a_bf16.to(dtype=dtype) + b = b_bf16.to(dtype=dtype) + # b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype) + res = weight_int4pack_mm(a, b_int4, scales) + ref = torch.mm(a_bf16, b_bf16) + + mean_err = ((res - ref).abs() / ref).mean() + self.assertTrue(mean_err < 0.05) + + + + + @dtypes(torch.float, torch.complex64) # Integer matmul just supported on CPU @setBlasBackendsToDefaultFinally def matmul_small_brute_force_1d_Nd(self, device, dtype): @@ -229,6 +340,7 @@ def ck_blas_library(self): TestLinalg.test_preferred_linalg_library=preferred_linalg_library TestLinalg.test_addbmm=addbmm TestLinalg.test__int_mm=_int_mm +TestLinalg.test__int4_mm=_int4_mm TestLinalg.test_matmul_small_brute_force_1d_Nd=matmul_small_brute_force_1d_Nd TestLinalg.test_matmul_small_brute_force_2d_Nd=matmul_small_brute_force_2d_Nd TestLinalg.test_matmul_small_brute_force_3d_Nd=matmul_small_brute_force_3d_Nd From 81a72f1945c04ea9b1904a2017b5290c8967f04c Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Thu, 12 Dec 
2024 07:13:28 +0000 Subject: [PATCH 19/25] sync UT with pytoch UT(linalg) --- src/ATen/native/xpu/LinearInt4.cpp | 68 ++++++-------- src/ATen/native/xpu/sycl/LinearInt4.cpp | 114 ++++++++++++------------ test/xpu/test_linalg_xpu.py | 40 ++++----- 3 files changed, 102 insertions(+), 120 deletions(-) diff --git a/src/ATen/native/xpu/LinearInt4.cpp b/src/ATen/native/xpu/LinearInt4.cpp index 22381fc02..f19f49668 100644 --- a/src/ATen/native/xpu/LinearInt4.cpp +++ b/src/ATen/native/xpu/LinearInt4.cpp @@ -9,28 +9,26 @@ namespace at::native { Tensor _weight_int4pack_mm_xpu( - const Tensor& input, - const Tensor& weight, + const Tensor& A, + const Tensor& B, int64_t qGroupSize, - const Tensor& weight_scale_zero_point) { - auto M = input.size(0); - auto N = weight.size(0); - auto K = input.size(1); + const Tensor& qScaleAndZeros) { + auto M = A.size(0); + auto N = B.size(0); + auto K = A.size(1); TORCH_CHECK( - input.dtype() == kBFloat16 || input.dtype() == kHalf || - input.dtype() == kFloat, + A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat, __func__, - " : expect input to be either 32-bit or 16-bit float tensor."); + " : expect A to be either 32-bit or 16-bit float tensor."); + TORCH_CHECK(A.is_contiguous(), __func__, " : expect A to be contiguous."); + TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor."); TORCH_CHECK( - weight.dtype() == kByte, __func__, " : expect B to be uint8 tensor."); - TORCH_CHECK( - weight.is_contiguous(), __func__, " : expect B to be contiguous."); - TORCH_CHECK( - weight.size(1) == K / 2, + B.dtype() == kInt || B.dtype() == kUInt32, __func__, - " : expect B.size(1) to be K/2, got ", - weight.size(1)); + " : expect B to be int32 or uint32 tensor."); + TORCH_CHECK(B.is_contiguous(), __func__, " : expect B to be contiguous."); + TORCH_CHECK(B.dim() == 4, __func__, " : expect B to 4d tensor."); TORCH_CHECK( qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 || @@ -39,35 +37,27 @@ Tensor _weight_int4pack_mm_xpu( ": expect qGroupSize to be 32, 64, 128 or 256, got ", qGroupSize); - TORCH_CHECK( - weight_scale_zero_point.dim() == 3 && - weight_scale_zero_point.size(1) == N && - weight_scale_zero_point.size(2) == 2, - __func__, - ": expect weight_scale_zero_point to be 3d tensor with sizes [:, ", - N, - ", 2]"); + // TORCH_CHECK( + // qScaleAndZeros.dim() == 3 && qScaleAndZeros.size(1) == N && + // qScaleAndZeros.size(2) == 2, + // __func__, + // ": expect qScaleAndZeros to be 3d tensor with sizes [:, ", + // N, + // ", 2]"); std::optional common_device = std::nullopt; c10::impl::check_and_update_common_device( - common_device, - input, - "xpu::_weight_int4pack_mm_with_scales_and_zeros", - "input"); + common_device, A, "xpu::_weight_int4pack_mm", "A"); c10::impl::check_and_update_common_device( - common_device, - weight, - "xpu::_weight_int4pack_mm_with_scales_and_zeros", - "weight"); + common_device, B, "xpu::_weight_int4pack_mm", "B"); c10::impl::check_and_update_common_device( common_device, - weight_scale_zero_point, - "xpu::_weight_int4pack_mm_with_scales_and_zeros", - "weight_scale_zero_point"); - Tensor output = at::empty({M, N}, input.options()); + qScaleAndZeros, + "xpu::_weight_int4pack_mm", + "qScaleAndZeros"); + Tensor C = at::empty({M, N}, A.options()); - at::native::xpu::linear_int4_kernel( - input, weight, qGroupSize, weight_scale_zero_point, output); - return output; + at::native::xpu::linear_int4_kernel(A, B, qGroupSize, qScaleAndZeros, C); + return C; } } // namespace at::native diff --git 
a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp index 558250803..8ab452ee8 100644 --- a/src/ATen/native/xpu/sycl/LinearInt4.cpp +++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp @@ -7,7 +7,7 @@ template struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { LinearInt4KernelFunctor( const scalar_t* A, - const uint32_t* B, + const int32_t* B, scalar_t* C, const scalar_t* B_scale, const scalar_t* B_zero_point, @@ -71,7 +71,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { } sycl::half2 sum = {0.f, 0.f}; for (int i = 0; i < SgSize; i += 1) { - sum += sg.shuffle(tmpAcc, i); + sum += group_broadcast(sg, tmpAcc, i); } if (sg_id == 0) { *cptr = sum[0] + sum[1]; @@ -100,7 +100,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { } float sum = 0.f; for (int i = 0; i < SgSize; i += 1) { - sum += sg.shuffle(tmpAcc, i); + sum += group_broadcast(sg, tmpAcc, i); } if (sg_id == 0) { *cptr = sum; @@ -110,7 +110,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { private: const scalar_t* A; - const uint32_t* B; + const int32_t* B; scalar_t* C; const scalar_t* B_scale; const scalar_t* B_zero_point; @@ -123,15 +123,15 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { }; void linear_int4_kernel( - const Tensor& input, - const Tensor& weight, + const Tensor& A, + const Tensor& B, int qGroupSize, - const Tensor& weight_scale_zero_point, - Tensor& output) { + const Tensor& qScaleAndZeros, + Tensor& C) { auto& sycl_queue = at::xpu::getCurrentSYCLQueue(); - int64_t m = input.size(0); - int64_t n = input.size(1); - int64_t k = output.size(1); + int64_t m = A.size(0); + int64_t n = A.size(1); + int64_t k = C.size(1); int constexpr Unroll = 2; int constexpr SgSize = 16; sycl::range<1> local_range{SgSize}; @@ -140,57 +140,55 @@ void linear_int4_kernel( int ldb = n; int ldc = n; - AT_DISPATCH_FLOATING_TYPES_AND( - at::ScalarType::Half, input.scalar_type(), "linear_int4_kernel", [&]() { - using scalar_t = at::Half; - const scalar_t* input_data = input.data_ptr(); - uint32_t* weight_data = weight.data_ptr(); // int4x8 + // AT_DISPATCH_FLOATING_TYPES_AND( + // at::ScalarType::Half, A.scalar_type(), "linear_int4_kernel", [&]() { + if (A.scalar_type() == at::ScalarType::Half) { + using scalar_t = at::Half; + const scalar_t* input_data = A.data_ptr(); + int32_t* weight_data = B.data_ptr(); // int4x8 - scalar_t* output_data = output.data_ptr(); - scalar_t* weight_scale_data = - weight_scale_zero_point.data_ptr(); - LinearInt4KernelFunctor kfn( - input_data, - weight_data, - output_data, - weight_scale_data, - nullptr, - m, - n, - k, - k, - n, - n); + scalar_t* output_data = C.data_ptr(); + scalar_t* weight_scale_data = qScaleAndZeros.data_ptr(); + LinearInt4KernelFunctor kfn( + input_data, + weight_data, + output_data, + weight_scale_data, + nullptr, + m, + n, + k, + k, + n, + n); - sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); - }); - AT_DISPATCH_FLOATING_TYPES_AND( - at::ScalarType::BFloat16, - input.scalar_type(), - "linear_int4_kernel", - [&]() { - using scalar_t = at::BFloat16; - const scalar_t* input_data = input.data_ptr(); - uint32_t* weight_data = weight.data_ptr(); // int4x8 + sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); + } + // AT_DISPATCH_FLOATING_TYPES_AND( + // at::ScalarType::BFloat16, A.scalar_type(), "linear_int4_kernel", [&]() + // { + else if (A.scalar_type() == at::ScalarType::BFloat16) { + using scalar_t = 
at::BFloat16; + const scalar_t* input_data = A.data_ptr(); + int32_t* weight_data = B.data_ptr(); // int4x8 - scalar_t* output_data = output.data_ptr(); - scalar_t* weight_scale_data = - weight_scale_zero_point.data_ptr(); - LinearInt4KernelFunctor kfn( - input_data, - weight_data, - output_data, - weight_scale_data, - nullptr, - m, - n, - k, - k, - n, - n); + scalar_t* output_data = C.data_ptr(); + scalar_t* weight_scale_data = qScaleAndZeros.data_ptr(); + LinearInt4KernelFunctor kfn( + input_data, + weight_data, + output_data, + weight_scale_data, + nullptr, + m, + n, + k, + k, + n, + n); - sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); - }); + sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); + } } } // namespace at::native::xpu \ No newline at end of file diff --git a/test/xpu/test_linalg_xpu.py b/test/xpu/test_linalg_xpu.py index bff4eacf4..9b2cc8795 100644 --- a/test/xpu/test_linalg_xpu.py +++ b/test/xpu/test_linalg_xpu.py @@ -244,11 +244,18 @@ def convert_weight_to_int4pack(b): return b_int4pack, b_scales_and_zeros def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros): - self.assertTrue(b_int4pack.dtype is torch.int32) - # self.assertTrue(b_int4pack.dim() == 4) - return torch._weight_int4pack_mm( - a, b_int4pack, q_group, b_scales_and_zeros - ) + if self.device_type == 'cpu': + self.assertTrue(b_int4pack.dtype is torch.uint8) + self.assertTrue(b_int4pack.dim() == 2) + return torch._weight_int4pack_mm_for_cpu( + a, b_int4pack, q_group, b_scales_and_zeros + ) + else: + self.assertTrue(b_int4pack.dtype is torch.int32) + self.assertTrue(b_int4pack.dim() == 4) + return torch._weight_int4pack_mm( + a, b_int4pack, q_group, b_scales_and_zeros + ) dtype = torch.bfloat16 q_group = 32 @@ -256,32 +263,19 @@ def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros): torch.manual_seed(1) a_bf16 = torch.rand((m, k), dtype=dtype, device=device) - b_int4 = rand_int4(k * n, torch.int32, "xpu").reshape(k // 8, n) - group_num = int(k / q_group) - - scales = torch.rand([group_num, n], device="xpu", dtype=dtype) - zero_points = rand_int4(group_num * n, torch.int32, "xpu").reshape( - group_num, n // 8 - ) - - b_bf16 = dequantize(b_int4, scales, zero_points, q_group).cpu() - - # b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16) + b_bf16 = torch.rand((k, n), dtype=torch.bfloat16, device=device) + b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16) for dtype in [torch.bfloat16] + ([torch.float16, torch.float32] if device == "cpu" else []): a = a_bf16.to(dtype=dtype) b = b_bf16.to(dtype=dtype) - # b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype) - res = weight_int4pack_mm(a, b_int4, scales) - ref = torch.mm(a_bf16, b_bf16) + b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype) + ref = torch.mm(a, b) + res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros) mean_err = ((res - ref).abs() / ref).mean() self.assertTrue(mean_err < 0.05) - - - - @dtypes(torch.float, torch.complex64) # Integer matmul just supported on CPU @setBlasBackendsToDefaultFinally def matmul_small_brute_force_1d_Nd(self, device, dtype): From a70df0a7257151d40499e5f0c5840c4699f4e3e6 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Thu, 12 Dec 2024 07:23:00 +0000 Subject: [PATCH 20/25] col-major --- src/ATen/native/xpu/LinearInt4.cpp | 2 +- test/xpu/test_linalg_xpu.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/ATen/native/xpu/LinearInt4.cpp b/src/ATen/native/xpu/LinearInt4.cpp index f19f49668..0f3bab961 100644 
--- a/src/ATen/native/xpu/LinearInt4.cpp +++ b/src/ATen/native/xpu/LinearInt4.cpp @@ -14,7 +14,7 @@ Tensor _weight_int4pack_mm_xpu( int64_t qGroupSize, const Tensor& qScaleAndZeros) { auto M = A.size(0); - auto N = B.size(0); + auto N = B.size(1); auto K = A.size(1); TORCH_CHECK( A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat, diff --git a/test/xpu/test_linalg_xpu.py b/test/xpu/test_linalg_xpu.py index 9b2cc8795..82b4c13f1 100644 --- a/test/xpu/test_linalg_xpu.py +++ b/test/xpu/test_linalg_xpu.py @@ -272,8 +272,9 @@ def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros): b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype) ref = torch.mm(a, b) res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros) - mean_err = ((res - ref).abs() / ref).mean() + print(ref) + print(res) self.assertTrue(mean_err < 0.05) @dtypes(torch.float, torch.complex64) # Integer matmul just supported on CPU From c08382c648e0776836dffa05601b409892daa942 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Fri, 13 Dec 2024 11:57:59 +0000 Subject: [PATCH 21/25] UT pass for B ones --- test/xpu/test_linalg_xpu.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/test/xpu/test_linalg_xpu.py b/test/xpu/test_linalg_xpu.py index 82b4c13f1..fbebec7e9 100644 --- a/test/xpu/test_linalg_xpu.py +++ b/test/xpu/test_linalg_xpu.py @@ -173,9 +173,9 @@ def _test(m, k, n, transpose_a, transpose_b, test_equal=True): _test(17, k, n, use_transpose_a, use_transpose_b) @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") -@parametrize("m", [32]) -@parametrize("k", [32]) -@parametrize("n", [32]) +@parametrize("m", [1]) +@parametrize("k", [1024]) +@parametrize("n", [1024]) def _int4_mm(self, device, m, k, n): @staticmethod def rand_int4(size, dtype=torch.int32, device="xpu"): @@ -257,16 +257,15 @@ def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros): a, b_int4pack, q_group, b_scales_and_zeros ) - dtype = torch.bfloat16 + dtype = torch.float16 q_group = 32 inner_k_tiles = 2 torch.manual_seed(1) a_bf16 = torch.rand((m, k), dtype=dtype, device=device) - b_bf16 = torch.rand((k, n), dtype=torch.bfloat16, device=device) + b_bf16 = torch.ones((k, n), dtype=dtype, device=device) b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16) - - for dtype in [torch.bfloat16] + ([torch.float16, torch.float32] if device == "cpu" else []): + for dtype in [torch.float16] + ([torch.float16, torch.float32] if device == "cpu" else []): a = a_bf16.to(dtype=dtype) b = b_bf16.to(dtype=dtype) b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype) From 14bb4e0bdb970af4dadf23a64f6b8a9226e1df98 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Mon, 16 Dec 2024 01:49:36 +0000 Subject: [PATCH 22/25] update gemv --- src/ATen/native/xpu/LinearInt4.cpp | 2 +- src/ATen/native/xpu/sycl/LinearInt4.cpp | 410 +++++++++++++++++------- src/ATen/native/xpu/sycl/LinearInt4.h | 1 + 3 files changed, 303 insertions(+), 110 deletions(-) diff --git a/src/ATen/native/xpu/LinearInt4.cpp b/src/ATen/native/xpu/LinearInt4.cpp index 0f3bab961..7337fa3c7 100644 --- a/src/ATen/native/xpu/LinearInt4.cpp +++ b/src/ATen/native/xpu/LinearInt4.cpp @@ -14,7 +14,7 @@ Tensor _weight_int4pack_mm_xpu( int64_t qGroupSize, const Tensor& qScaleAndZeros) { auto M = A.size(0); - auto N = B.size(1); + auto N = B.size(0) * 8; auto K = A.size(1); TORCH_CHECK( A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat, diff --git a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp index 
8ab452ee8..67689d82d 100644 --- a/src/ATen/native/xpu/sycl/LinearInt4.cpp +++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp @@ -2,12 +2,23 @@ #include namespace at::native::xpu { +static inline int padto_le(int src, int padding) { + return src / padding * padding; +} + +static inline int64_t padto_le(int64_t src, int64_t padding) { + return src / padding * padding; +} -template +static inline size_t padto_le(size_t src, int padding) { + return src / size_t(padding) * size_t(padding); +} + +template struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { LinearInt4KernelFunctor( const scalar_t* A, - const int32_t* B, + const uint8_t* B, scalar_t* C, const scalar_t* B_scale, const scalar_t* B_zero_point, @@ -16,7 +27,8 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { int k, int lda, int ldb, - int ldc) + int ldc, + sycl::stream os) : A(A), B(B), C(C), @@ -27,90 +39,256 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { k(k), lda(lda), ldb(ldb), - ldc(ldc) {} + ldc(ldc), + os_(os) {} void sycl_ker_config_convention(sycl::handler& cgh) { // local_scan_ = sycl_local_acc_t(N_, cgh); + // sycl::stream os_(1024, 128, cgh); + // os_ = sycl::stream(1024, 128, cgh); } - void operator()(sycl::nd_item<1> item) const { + void operator()(sycl::nd_item<1> it) const { int constexpr Unroll = 2; int constexpr SgSize = 16; int constexpr TileK = 32; int constexpr GroupK = SgSize * TileK; int constexpr blocksize = 16; - int g_idx = item.get_group(0); - auto sg = item.get_sub_group(); - int sg_id = sg.get_local_id()[0]; - int g_n = g_idx; - auto sptr = B_scale + g_n * ldb; - auto bptr = B + g_n * k / 2; - auto aptr = A; - auto cptr = C + g_n; - if constexpr (std::is_same_v) { - sycl::half2 tmpAcc = {0.f, 0.f}; - uint8_t tmps8[TileK / 2]; - for (int i = 0; i < k; i += GroupK * Unroll) { + // int g_idx = item.get_group(0); + // auto sg = item.get_sub_group(); + // int sg_id = sg.get_local_id()[0]; + // int g_n = g_idx; + // auto sptr = B_scale + g_n * ldb; + // auto bptr = B + g_n * k / 2; + // auto aptr = A; + // auto cptr = C + g_n; + if (k % (SgSize * 32 * Unroll) == 0) { + int constexpr TileK = 32; + int constexpr GroupK = SgSize * TileK; + + int g_idx = it.get_group(0); + auto sg = it.get_sub_group(); + int sg_id = sg.get_local_id()[0]; + int g_n = g_idx; + auto sptr = B_scale + g_n * ldb; + auto bptr = B + g_n * k / 2; + auto aptr = A; + auto cptr = C + g_n; + if constexpr (std::is_same_v) { // Half + sycl::half2 tmpAcc = {0.f, 0.f}; + for (int i = 0; i < k; i += GroupK * Unroll) { #pragma unroll - for (int iu = 0; iu < Unroll; iu++) { - *(sycl::vec*)tmps8 = - *(sycl::vec*)(bptr + sg_id * TileK / 2); - scalar_t scale = *(sptr + sg_id * TileK / blocksize); + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + scalar_t scale = 1.f; + // scalar_t scale = *(sptr + sg_id * TileK / blocksize); #pragma unroll - for (int ikk = 0; ikk < TileK; ikk += 2) { - sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk]; - sycl::half2 tmpB = { - static_cast((tmps8[ikk / 2] & 0x0f) - 8), - static_cast((tmps8[ikk / 2] >> 4) - 8)}; - tmpAcc += tmpA * tmpB * scale; + for (int ikk = 0; ikk < TileK; ikk += 2) { + sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk]; + sycl::half2 tmpB = {1.f, 1.f}; + // static_cast((tmps8[ikk / 2] & 0x0f) - 8), + // static_cast((tmps8[ikk / 2] >> 4) - 8)}; + tmpAcc += tmpA * tmpB * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; 
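            // Per K-group iteration the pointers advance together: the
            // scale pointer by GroupK / blocksize entries, the activation
            // pointer by GroupK elements, and the packed int4 weight pointer
            // by GroupK / 2 bytes (two 4-bit values per byte).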
+ bptr += GroupK / 2; } - sptr += GroupK / blocksize; - aptr += GroupK; - bptr += GroupK / 2; } - } - sycl::half2 sum = {0.f, 0.f}; - for (int i = 0; i < SgSize; i += 1) { - sum += group_broadcast(sg, tmpAcc, i); - } - if (sg_id == 0) { - *cptr = sum[0] + sum[1]; - } - } else { - scalar_t tmpAcc = 0.f; - int constexpr Unroll = 2; - for (int i = 0; i < k; i += GroupK * Unroll) { + sycl::half2 sum = {0.f, 0.f}; + for (int i = 0; i < SgSize; i += 1) { + sum += select_from_group(sg, tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum[0] + sum[1]; + } + } else { // Bfloat16 + scalar_t tmpAcc = 0.f; + int constexpr Unroll = 2; + for (int i = 0; i < k; i += GroupK * Unroll) { #pragma unroll - for (int iu = 0; iu < Unroll; iu++) { - uint8_t tmps8[TileK / 2]; - *(sycl::vec*)tmps8 = - *(sycl::vec*)(bptr + sg_id * TileK / 2); - scalar_t scale = *(sptr + sg_id * TileK / blocksize); + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + scalar_t scale = *(sptr + sg_id * TileK / blocksize); #pragma unroll - for (int ikk = 0; ikk < TileK; ikk += 2) { - tmpAcc += scalar_t(aptr[sg_id * TileK + ikk]) * - static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale; - tmpAcc += scalar_t(aptr[sg_id * TileK + ikk + 1]) * - static_cast((tmps8[ikk / 2] >> 4) - 8) * scale; + for (int ikk = 0; ikk < TileK; ikk += 2) { + tmpAcc += scalar_t(aptr[sg_id * TileK + ikk]) * + static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale; + tmpAcc += scalar_t(aptr[sg_id * TileK + ikk + 1]) * + static_cast((tmps8[ikk / 2] >> 4) - 8) * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; } - sptr += GroupK / blocksize; - aptr += GroupK; - bptr += GroupK / 2; + } + float sum = 0.f; + for (int i = 0; i < SgSize; i += 1) { + sum += select_from_group(sg, tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum; } } - float sum = 0.f; - for (int i = 0; i < SgSize; i += 1) { - sum += group_broadcast(sg, tmpAcc, i); - } - if (sg_id == 0) { - *cptr = sum; + } else { // k % (SgSize * 32 * Unroll) != 0 + int constexpr TileK = 32; + int constexpr GroupK = SgSize * TileK; + int k_body = padto_le(k, GroupK * Unroll); + int constexpr TileK2 = 8; + int constexpr GroupK2 = SgSize * TileK2; + int k_body2 = padto_le(k, GroupK2 * Unroll); + int g_idx = it.get_group(0); + auto sg = it.get_sub_group(); + int sg_id = sg.get_local_id()[0]; + int g_n = g_idx; + auto sptr = B_scale + g_n * ldb; + auto bptr = B + g_n * k / 2; + auto aptr = A; + auto cptr = C + g_n; + if constexpr (std::is_same_v) { // Half + sycl::half2 tmpAcc = {0.f, 0.f}; + int i = 0; + for (; i < k_body; i += GroupK * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + scalar_t scale = 1.f; + // scalar_t scale = *(sptr + sg_id * TileK / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK; ikk += 2) { + sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk]; + sycl::half2 tmpB = {1.f, 1.f}; + // static_cast((tmps8[ikk / 2] & 0x0f) - 8), + // static_cast((tmps8[ikk / 2] >> 4) - 8)}; + tmpAcc += tmpA * tmpB * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; + } + } + if (i + GroupK2 * Unroll < k_body2) { + for (; i < k_body2; i += GroupK2 * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK2 / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK2 / 2); + scalar_t scale = 1.f; + // scalar_t scale = *(sptr + 
sg_id * TileK2 / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK2; ikk += 2) { + sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK2 + ikk]; + sycl::half2 tmpB = {1.f, 1.f}; + // static_cast((tmps8[ikk / 2] & 0x0f) - 8), + // static_cast((tmps8[ikk / 2] >> 4) - 8)}; + tmpAcc += tmpA * tmpB * scale; + } + sptr += GroupK2 / blocksize; + aptr += GroupK2; + bptr += GroupK2 / 2; + } + } + } + if (i + SgSize * 2 < k) { + for (; i < k; i += SgSize * 2) { + uint8_t tmps8 = *(bptr + sg_id); + scalar_t scale = 1.f; + // scalar_t scale = *(sptr + sg_id * 2 / blocksize); + sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * 2]; + sycl::half2 tmpB = {1.f, 1.f}; + // static_cast((tmps8 & 0x0f) - 8), + // static_cast((tmps8 >> 4) - 8)}; + tmpAcc += tmpA * tmpB * scale; + sptr += SgSize * 2 / blocksize; + aptr += SgSize * 2; + bptr += SgSize * 2 / 2; + } + } + sycl::half2 sum = {0.f, 0.f}; + for (int i = 0; i < SgSize; i += 1) { + sum += select_from_group(sg, tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum[0] + sum[1]; + } + } else { // Bfloat16 + scalar_t tmpAcc = 0.f; + int constexpr Unroll = 2; + int i = 0; + for (; i < k_body; i += GroupK * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + scalar_t scale = *(sptr + sg_id * TileK / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK; ikk += 2) { + tmpAcc += scalar_t(aptr[sg_id * TileK + ikk]) * + static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale; + tmpAcc += scalar_t(aptr[sg_id * TileK + ikk + 1]) * + static_cast((tmps8[ikk / 2] >> 4) - 8) * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; + } + } + if (i + GroupK2 * Unroll < k_body2) { + for (; i < k_body2; i += GroupK2 * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK2 / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK2 / 2); + scalar_t scale = *(sptr + sg_id * TileK2 / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK2; ikk += 2) { + tmpAcc += scalar_t(aptr[sg_id * TileK2 + ikk]) * + static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale; + tmpAcc += scalar_t(aptr[sg_id * TileK2 + ikk + 1]) * + static_cast((tmps8[ikk / 2] >> 4) - 8) * scale; + } + sptr += GroupK2 / blocksize; + aptr += GroupK2; + bptr += GroupK2 / 2; + } + } + } + if (i + SgSize * Unroll < k) { + for (; i < k; i += SgSize) { + uint8_t tmps8 = *(bptr + sg_id / 2); + scalar_t scale = *(sptr + sg_id / blocksize); + tmpAcc += scalar_t(aptr[sg_id]) * + static_cast((tmps8 & 0x0f) - 8) * scale; + tmpAcc += scalar_t(aptr[sg_id]) * + static_cast((tmps8 >> 4) - 8) * scale; + sptr += SgSize / blocksize; + aptr += SgSize; + bptr += SgSize / 2; + } + } + float sum = 0.f; + for (int i = 0; i < SgSize; i += 1) { + sum += select_from_group(sg, tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum; + } } } } private: const scalar_t* A; - const int32_t* B; + const uint8_t* B; scalar_t* C; const scalar_t* B_scale; const scalar_t* B_zero_point; @@ -120,6 +298,10 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { int lda; int ldb; int ldc; + + private: + sycl::stream os_; + // sycl::handler cgh_; }; void linear_int4_kernel( @@ -130,64 +312,74 @@ void linear_int4_kernel( Tensor& C) { auto& sycl_queue = at::xpu::getCurrentSYCLQueue(); int64_t m = A.size(0); - int64_t n = A.size(1); - int64_t k = C.size(1); + int64_t n = C.size(1); + int64_t k = A.size(1); int constexpr Unroll = 2; int constexpr SgSize = 16; 
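  // Launch geometry: one sub-group of SgSize (16) work-items is dispatched per
  // output element along n, i.e. global_range = n * SgSize with
  // local_range = SgSize; inside the kernel each sub-group sums its lanes'
  // partial dot products via select_from_group and lane 0 writes C[g_n].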
sycl::range<1> local_range{SgSize}; sycl::range<1> global_range{static_cast(n) * SgSize}; - int lda = k; - int ldb = n; - int ldc = n; - // AT_DISPATCH_FLOATING_TYPES_AND( - // at::ScalarType::Half, A.scalar_type(), "linear_int4_kernel", [&]() { if (A.scalar_type() == at::ScalarType::Half) { using scalar_t = at::Half; - const scalar_t* input_data = A.data_ptr(); - int32_t* weight_data = B.data_ptr(); // int4x8 - - scalar_t* output_data = C.data_ptr(); - scalar_t* weight_scale_data = qScaleAndZeros.data_ptr(); - LinearInt4KernelFunctor kfn( - input_data, - weight_data, - output_data, - weight_scale_data, - nullptr, - m, - n, - k, - k, - n, - n); - - sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); + using scalar_sycl_t = sycl::half; + const scalar_sycl_t* input_data = + reinterpret_cast(A.data_ptr()); + uint8_t* weight_data = + reinterpret_cast(B.data_ptr()); // int4x8 + + scalar_sycl_t* output_data = + reinterpret_cast(C.data_ptr()); + scalar_sycl_t* weight_scale_data = + reinterpret_cast(qScaleAndZeros.data_ptr()); + + auto cgf = [&](::sycl::handler& cgh) { + auto os = sycl::stream(1024, 768, cgh); + LinearInt4KernelFunctor kfn( + input_data, + weight_data, + output_data, + weight_scale_data, + nullptr, + m, + n, + k, + k, + n, + n, + os); + kfn.sycl_ker_config_convention(cgh); + cgh.parallel_for>( + ::sycl::nd_range<1>(global_range, local_range), kfn); + }; + sycl_queue.submit(cgf); + + // sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); } // AT_DISPATCH_FLOATING_TYPES_AND( - // at::ScalarType::BFloat16, A.scalar_type(), "linear_int4_kernel", [&]() + // at::ScalarType::BFloat16, A.scalar_type(), "linear_int4_kernel", + // [&]() // { else if (A.scalar_type() == at::ScalarType::BFloat16) { - using scalar_t = at::BFloat16; - const scalar_t* input_data = A.data_ptr(); - int32_t* weight_data = B.data_ptr(); // int4x8 - - scalar_t* output_data = C.data_ptr(); - scalar_t* weight_scale_data = qScaleAndZeros.data_ptr(); - LinearInt4KernelFunctor kfn( - input_data, - weight_data, - output_data, - weight_scale_data, - nullptr, - m, - n, - k, - k, - n, - n); - - sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); + // using scalar_t = at::BFloat16; + // const scalar_t* input_data = A.data_ptr(); + // int32_t* weight_data = B.data_ptr(); // int4x8 + + // scalar_t* output_data = C.data_ptr(); + // scalar_t* weight_scale_data = qScaleAndZeros.data_ptr(); + // LinearInt4KernelFunctor kfn( + // input_data, + // weight_data, + // output_data, + // weight_scale_data, + // nullptr, + // m, + // n, + // k, + // k, + // n, + // n); + + // sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); } } diff --git a/src/ATen/native/xpu/sycl/LinearInt4.h b/src/ATen/native/xpu/sycl/LinearInt4.h index bcd3c78cc..c54f3df21 100644 --- a/src/ATen/native/xpu/sycl/LinearInt4.h +++ b/src/ATen/native/xpu/sycl/LinearInt4.h @@ -1,4 +1,5 @@ #pragma once +#include #include namespace at::native::xpu { From 70a3e13e93c2a838d87ab14e765c417be265ac97 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Tue, 17 Dec 2024 03:10:06 +0000 Subject: [PATCH 23/25] fix scale and zp address --- src/ATen/native/xpu/sycl/LinearInt4.cpp | 48 ++++++++++--------------- 1 file changed, 18 insertions(+), 30 deletions(-) diff --git a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp index 67689d82d..ab552399a 100644 --- a/src/ATen/native/xpu/sycl/LinearInt4.cpp +++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp @@ -14,14 +14,13 @@ static inline size_t padto_le(size_t src, int 
padding) { return src / size_t(padding) * size_t(padding); } -template +template struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { LinearInt4KernelFunctor( const scalar_t* A, const uint8_t* B, scalar_t* C, - const scalar_t* B_scale, - const scalar_t* B_zero_point, + const scalar_t* ScaleAndZeros, int m, int n, int k, @@ -32,8 +31,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { : A(A), B(B), C(C), - B_scale(B_scale), - B_zero_point(B_zero_point), + ScaleAndZeros(ScaleAndZeros), m(m), n(n), k(k), @@ -52,16 +50,8 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { int constexpr SgSize = 16; int constexpr TileK = 32; int constexpr GroupK = SgSize * TileK; - int constexpr blocksize = 16; + int constexpr blocksize = block_size; - // int g_idx = item.get_group(0); - // auto sg = item.get_sub_group(); - // int sg_id = sg.get_local_id()[0]; - // int g_n = g_idx; - // auto sptr = B_scale + g_n * ldb; - // auto bptr = B + g_n * k / 2; - // auto aptr = A; - // auto cptr = C + g_n; if (k % (SgSize * 32 * Unroll) == 0) { int constexpr TileK = 32; int constexpr GroupK = SgSize * TileK; @@ -70,7 +60,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { auto sg = it.get_sub_group(); int sg_id = sg.get_local_id()[0]; int g_n = g_idx; - auto sptr = B_scale + g_n * ldb; + auto sptr = ScaleAndZeros + g_n * ldb; auto bptr = B + g_n * k / 2; auto aptr = A; auto cptr = C + g_n; @@ -82,17 +72,17 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { uint8_t tmps8[TileK / 2]; *(sycl::vec*)tmps8 = *(sycl::vec*)(bptr + sg_id * TileK / 2); - scalar_t scale = 1.f; - // scalar_t scale = *(sptr + sg_id * TileK / blocksize); + scalar_t scale = *(sptr + sg_id * (TileK / blocksize) * 2); + scalar_t zero_point = *(sptr + sg_id * (TileK / blocksize) * 2 + 1); #pragma unroll for (int ikk = 0; ikk < TileK; ikk += 2) { sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk]; - sycl::half2 tmpB = {1.f, 1.f}; - // static_cast((tmps8[ikk / 2] & 0x0f) - 8), - // static_cast((tmps8[ikk / 2] >> 4) - 8)}; + sycl::half2 tmpB = { + static_cast((tmps8[ikk / 2] & 0x0f) - zero_point), + static_cast((tmps8[ikk / 2] >> 4) - zero_point)}; tmpAcc += tmpA * tmpB * scale; } - sptr += GroupK / blocksize; + sptr += (GroupK / blocksize) * 2; aptr += GroupK; bptr += GroupK / 2; } @@ -145,7 +135,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { auto sg = it.get_sub_group(); int sg_id = sg.get_local_id()[0]; int g_n = g_idx; - auto sptr = B_scale + g_n * ldb; + auto sptr = ScaleAndZeros + g_n * ldb; auto bptr = B + g_n * k / 2; auto aptr = A; auto cptr = C + g_n; @@ -290,8 +280,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { const scalar_t* A; const uint8_t* B; scalar_t* C; - const scalar_t* B_scale; - const scalar_t* B_zero_point; + const scalar_t* ScaleAndZeros; int m; int n; int k; @@ -329,26 +318,25 @@ void linear_int4_kernel( scalar_sycl_t* output_data = reinterpret_cast(C.data_ptr()); - scalar_sycl_t* weight_scale_data = + scalar_sycl_t* scale_zeros_data = reinterpret_cast(qScaleAndZeros.data_ptr()); auto cgf = [&](::sycl::handler& cgh) { auto os = sycl::stream(1024, 768, cgh); - LinearInt4KernelFunctor kfn( + LinearInt4KernelFunctor kfn( input_data, weight_data, output_data, - weight_scale_data, - nullptr, + scale_zeros_data, m, n, k, k, - n, + (k / qGroupSize) * 2, // scale and zero point combined n, os); kfn.sycl_ker_config_convention(cgh); - cgh.parallel_for>( 
+ cgh.parallel_for>( ::sycl::nd_range<1>(global_range, local_range), kfn); }; sycl_queue.submit(cgf); From a590ad6e1e8b1d48f34e8e652f3ad3d241a14d60 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Wed, 18 Dec 2024 02:36:49 +0000 Subject: [PATCH 24/25] fix K large than 1024 UT --- test/xpu/test_linalg_xpu.py | 75 +++++++++---------------------------- 1 file changed, 17 insertions(+), 58 deletions(-) diff --git a/test/xpu/test_linalg_xpu.py b/test/xpu/test_linalg_xpu.py index fbebec7e9..012d30f98 100644 --- a/test/xpu/test_linalg_xpu.py +++ b/test/xpu/test_linalg_xpu.py @@ -174,68 +174,21 @@ def _test(m, k, n, transpose_a, transpose_b, test_equal=True): @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") @parametrize("m", [1]) -@parametrize("k", [1024]) -@parametrize("n", [1024]) +@parametrize("k", [1024, 2048]) +@parametrize("n", [48, 64]) def _int4_mm(self, device, m, k, n): - @staticmethod - def rand_int4(size, dtype=torch.int32, device="xpu"): - rand = torch.randint(-128, 128, [size // 2], device=device).to(torch.int8) - return rand.view(dtype=dtype) - - @staticmethod - def unpack_weight(qweight, scales, qzeros, q_config): - group_size = q_config["group_size"] - bits = q_config["bits"] - s32_bits = 32 - - assert bits == 4 - # Int32 can store 8 * 4bits data. This is the offset for each data. - wf = ( - torch.tensor(list(range(0, s32_bits, bits)), dtype=torch.int32) - .unsqueeze(0) - .to("xpu") - ) - zeros = torch.bitwise_right_shift( - torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0) - ).to(torch.int16 if bits == 8 else torch.int8) - torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) - - zeros = zeros + 1 - zeros = zeros.reshape(scales.shape) - - weight = torch.bitwise_right_shift( - torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1) - ).to(torch.int16 if bits == 8 else torch.int8) - torch.bitwise_and(weight, (2**bits) - 1, out=weight) - - return weight, scales, zeros - - @staticmethod - def dequantize(qweight, scales, qzeros, group_size): - q_config = {"group_size": group_size, "bits": 4} - weight, gptq_scales, gptq_zeros = unpack_weight( - qweight, scales, qzeros, q_config - ) - gptq_zeros = (torch.ones_like(gptq_zeros) * 8).to("xpu") # TODO: hard code zp - if len(weight.shape) > 2: - weight = weight.reshape(-1, weight.shape[-1]) - infeatures = weight.shape[0] - g_idx = torch.tensor( - [i // q_config["group_size"] for i in range(infeatures)], - dtype=torch.int32, - ) - scale_zeros = gptq_zeros * gptq_scales - weight = gptq_scales[g_idx.long()] * weight - scale_zeros[g_idx.long()] - return weight - def convert_weight_to_int4pack(b): b_tmp, b_scales_and_zeros = _group_quantize_tensor( b, n_bit=4, q_group_size=q_group ) + if self.device_type == 'cpu': b_int4pack = torch._convert_weight_to_int4pack_for_cpu( b_tmp, inner_k_tiles ) + elif self.device_type == 'xpu': + # b_int4pack = b_tmp.view(torch.int32) + b_int4pack = b_tmp else: b_int4pack = torch._convert_weight_to_int4pack( b_tmp, inner_k_tiles @@ -250,6 +203,12 @@ def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros): return torch._weight_int4pack_mm_for_cpu( a, b_int4pack, q_group, b_scales_and_zeros ) + elif self.device_type == 'xpu': + self.assertTrue(b_int4pack.dtype is torch.int32 or b_int4pack.dtype is torch.uint8) + self.assertTrue(b_int4pack.dim() == 2) + return torch._weight_int4pack_mm( + a, b_int4pack, q_group, b_scales_and_zeros + ) else: self.assertTrue(b_int4pack.dtype is torch.int32) self.assertTrue(b_int4pack.dim() == 4) @@ -257,23 +216,23 @@ def weight_int4pack_mm(a, 
b_int4pack, b_scales_and_zeros): a, b_int4pack, q_group, b_scales_and_zeros ) - dtype = torch.float16 q_group = 32 inner_k_tiles = 2 torch.manual_seed(1) - a_bf16 = torch.rand((m, k), dtype=dtype, device=device) - b_bf16 = torch.ones((k, n), dtype=dtype, device=device) + a_bf16 = torch.rand((m, k), dtype=torch.bfloat16, device=device) + b_bf16 = torch.rand((k, n), dtype=torch.bfloat16, device=device) + b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16) + for dtype in [torch.float16] + ([torch.float16, torch.float32] if device == "cpu" else []): a = a_bf16.to(dtype=dtype) b = b_bf16.to(dtype=dtype) b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype) ref = torch.mm(a, b) res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros) + mean_err = ((res - ref).abs() / ref).mean() - print(ref) - print(res) self.assertTrue(mean_err < 0.05) @dtypes(torch.float, torch.complex64) # Integer matmul just supported on CPU From d6a2f3a7083b7e32efd13858d76d732ffc156446 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Wed, 18 Dec 2024 09:07:11 +0000 Subject: [PATCH 25/25] bug fix for FP16(BF16 maybe incorrect) --- test/xpu/test_linalg_xpu.py | 53 ++++++++++++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/test/xpu/test_linalg_xpu.py b/test/xpu/test_linalg_xpu.py index 012d30f98..88acd9134 100644 --- a/test/xpu/test_linalg_xpu.py +++ b/test/xpu/test_linalg_xpu.py @@ -6,7 +6,7 @@ from torch.testing._internal.common_dtype import floating_and_complex_types_and from torch.testing._internal.common_cuda import tf32_on_and_off from torch.testing._internal.common_mkldnn import bf32_on_and_off -from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel +from torch.testing._internal.common_quantization import _dynamically_quantize_per_channel from torch.testing import make_tensor import unittest import itertools @@ -174,9 +174,53 @@ def _test(m, k, n, transpose_a, transpose_b, test_equal=True): @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") @parametrize("m", [1]) -@parametrize("k", [1024, 2048]) -@parametrize("n", [48, 64]) +@parametrize("k", [256, 512, 1024]) +@parametrize("n", [32, 64]) def _int4_mm(self, device, m, k, n): + def _group_quantize_tensor(w, n_bit=4, q_group_size=16): + assert w.dim() == 2 + w = w.transpose(0, 1).contiguous() + assert q_group_size > 1 + assert w.shape[-1] % q_group_size == 0 + + to_quant = w.reshape(-1, q_group_size) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2 ** n_bit - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + assert torch.isnan(scales).sum() == 0 + + zeros = min_val + scales * (2 ** (n_bit - 1)) + assert torch.isnan(zeros).sum() == 0 + + out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int) + assert torch.isnan(out).sum() == 0 + + out = out.to(dtype=torch.int32).reshape(w.shape) + if out.device.type != 'cpu' or out.device.type != 'xpu': + out = (out[::, ::2] << 4 | out[::, 1::2]).to(torch.uint8) + + # Scales and zeros for the same q-group should be contiguous, so we can + # load as a 32-bit word + scales = scales.view(w.shape[0], -1) + zeros = zeros.view(w.shape[0], -1) + scales_and_zeros = ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + ) + + if out.device.type != 'xpu': + scales_and_zeros = 
scales_and_zeros.transpose(0, 1).contiguous() + return out, scales_and_zeros + def convert_weight_to_int4pack(b): b_tmp, b_scales_and_zeros = _group_quantize_tensor( b, n_bit=4, q_group_size=q_group @@ -231,7 +275,8 @@ def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros): b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype) ref = torch.mm(a, b) res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros) - + print(ref) + print(res) mean_err = ((res - ref).abs() / ref).mean() self.assertTrue(mean_err < 0.05)
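
For reference, the end-to-end flow exercised by the final version of this test can be sketched at the Python level as below. This is an illustrative sketch only: quantize_per_group is a hypothetical helper that condenses the test's local _group_quantize_tensor logic (column-major groups of q_group along K, two 4-bit values packed per uint8 byte, per-group scale/zero pairs kept as an [N, K/q_group, 2] tensor for XPU); it assumes an XPU build and device, and that _weight_int4pack_mm is registered with the signature from the yaml change above. The accepted-error threshold mirrors the test's mean relative error bound.

    import torch

    def quantize_per_group(w, q_group=32, n_bit=4):
        # Hypothetical helper mirroring the test's _group_quantize_tensor:
        # quantize w^T in contiguous groups of q_group along K.
        wt = w.transpose(0, 1).contiguous().float()            # [N, K]
        g = wt.reshape(-1, q_group)                            # [N*K/q_group, q_group]
        max_val, min_val = g.amax(1, keepdim=True), g.amin(1, keepdim=True)
        scales = (max_val - min_val).clamp(min=1e-6) / (2 ** n_bit - 1)
        zeros = min_val + scales * 2 ** (n_bit - 1)
        q = g.sub(min_val).div(scales).round().clamp_(0, 2 ** n_bit - 1).to(torch.int32)
        q = q.reshape(wt.shape)
        packed = (q[:, ::2] << 4 | q[:, 1::2]).to(torch.uint8)  # [N, K/2], two int4 per byte
        scale_zeros = torch.cat(
            [scales.reshape(wt.shape[0], -1, 1), zeros.reshape(wt.shape[0], -1, 1)],
            dim=2,
        )                                                        # [N, K/q_group, 2]
        return packed, scale_zeros

    m, k, n, q_group = 1, 1024, 64, 32
    a = torch.rand(m, k, dtype=torch.bfloat16, device="xpu").to(torch.float16)
    b = torch.rand(k, n, dtype=torch.bfloat16, device="xpu").to(torch.float16)
    b_packed, scale_zeros = quantize_per_group(b, q_group)
    res = torch._weight_int4pack_mm(a, b_packed, q_group, scale_zeros.to(a.dtype))
    ref = torch.mm(a, b)
    mean_err = ((res - ref).abs() / ref).mean()
    assert mean_err < 0.05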