diff --git a/.github/scripts/apply_torch_pr.py b/.github/scripts/apply_torch_pr.py
index 8c550bd87..b4b441263 100644
--- a/.github/scripts/apply_torch_pr.py
+++ b/.github/scripts/apply_torch_pr.py
@@ -13,6 +13,8 @@
     "https://github.com/pytorch/pytorch/pull/126516",
     # Modify the tolerance level in TIMM benchmark
     "https://github.com/pytorch/pytorch/pull/129735",
+    # [XPU] Update XPU C Shim Header
+    "https://github.com/pytorch/pytorch/pull/141086",
 ]
 )
 parser.add_argument('--extra-pr-list', '-e', nargs='+',default=[])
diff --git a/.github/scripts/env.sh b/.github/scripts/env.sh
index ab7d7812d..56d8e3930 100644
--- a/.github/scripts/env.sh
+++ b/.github/scripts/env.sh
@@ -1,3 +1,4 @@
 #!/bin/bash
-source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh
+source /opt/intel/oneapi/compiler/latest/env/vars.sh
+source /opt/intel/oneapi/umf/latest/env/vars.sh
 source /opt/intel/oneapi/pti/latest/env/vars.sh
diff --git a/.github/workflows/_linux_ut.yml b/.github/workflows/_linux_ut.yml
index 5ccd22b05..d2f717230 100644
--- a/.github/workflows/_linux_ut.yml
+++ b/.github/workflows/_linux_ut.yml
@@ -97,7 +97,7 @@ jobs:
         run: |
           source activate xpu_op_${ZE_AFFINITY_MASK}
           source .github/scripts/env.sh
-          pip install mkl-static mkl-include
+          pip install mkl-static==2025.0.1 mkl-include==2025.0.1
           cd ../pytorch
           if [[ ${{ inputs.abi }} == '0' ]]; then
             export _GLIBCXX_USE_CXX11_ABI=0
diff --git a/.github/workflows/nightly_ondemand.yml b/.github/workflows/nightly_ondemand.yml
index e33e55d56..2edc06102 100644
--- a/.github/workflows/nightly_ondemand.yml
+++ b/.github/workflows/nightly_ondemand.yml
@@ -123,7 +123,7 @@ jobs:
         conda remove --all -y -n e2e_ci || rm -rf $(dirname ${CONDA_EXE})/../envs/e2e_ci
         conda create -n e2e_ci python=${{ env.python }} cmake ninja -y
         source activate e2e_ci
-        pip install mkl-static mkl-include
+        pip install mkl-static==2025.0.1 mkl-include==2025.0.1
         pip install pandas scipy tqdm
     - name: Prepare Stock Pytorch
       run: |
diff --git a/.github/workflows/nightly_ondemand_rolling.yml b/.github/workflows/nightly_ondemand_rolling.yml
index 309fd58fc..0a27b2b50 100644
--- a/.github/workflows/nightly_ondemand_rolling.yml
+++ b/.github/workflows/nightly_ondemand_rolling.yml
@@ -125,7 +125,7 @@ jobs:
         conda remove --all -y -n e2e_ci || rm -rf $(dirname ${CONDA_EXE})/../envs/e2e_ci
         conda create -n e2e_ci python=${{ env.python }} cmake ninja -y
         source activate e2e_ci
-        pip install mkl-static mkl-include
+        pip install mkl-static==2025.0.1 mkl-include==2025.0.1
         pip install pandas scipy tqdm
     - name: Prepare Stock Pytorch
       run: |
diff --git a/.github/workflows/nightly_ondemand_whl.yml b/.github/workflows/nightly_ondemand_whl.yml
index 44b63cfd4..6b8d0b58f 100644
--- a/.github/workflows/nightly_ondemand_whl.yml
+++ b/.github/workflows/nightly_ondemand_whl.yml
@@ -98,7 +98,7 @@ jobs:
         conda remove --all -y -n e2e_ci || rm -rf $(dirname ${CONDA_EXE})/../envs/e2e_ci
         conda create -n e2e_ci python=${{ env.python }} cmake ninja -y
         source activate e2e_ci
-        pip install mkl-static mkl-include
+        pip install mkl-static==2025.0.1 mkl-include==2025.0.1
         pip install pandas scipy tqdm
     - name: Prepare Stock Pytorch
       run: |
diff --git a/cmake/BuildFlags.cmake b/cmake/BuildFlags.cmake
index a8784097a..d7328a832 100644
--- a/cmake/BuildFlags.cmake
+++ b/cmake/BuildFlags.cmake
@@ -122,7 +122,7 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "MSVC"
     set(SYCL_OFFLINE_COMPILER_CG_OPTIONS "-options '${SYCL_OFFLINE_COMPILER_CG_OPTIONS}'")

     if(WIN32)
-      set(AOT_TARGETS "ats-m150,lnl-m,mtl-u,mtl-h")
+      set(AOT_TARGETS "ats-m150,mtl-u,mtl-h,xe2-lpg,xe2-hpg")
     else()
       set(AOT_TARGETS "pvc,xe-lpg,ats-m150")
     endif()
diff --git a/src/ATen/native/xpu/LinearInt4.cpp b/src/ATen/native/xpu/LinearInt4.cpp
new file mode 100644
index 000000000..7337fa3c7
--- /dev/null
+++ b/src/ATen/native/xpu/LinearInt4.cpp
@@ -0,0 +1,63 @@
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+namespace at::native {
+Tensor _weight_int4pack_mm_xpu(
+    const Tensor& A,
+    const Tensor& B,
+    int64_t qGroupSize,
+    const Tensor& qScaleAndZeros) {
+  auto M = A.size(0);
+  auto N = B.size(0) * 8;
+  auto K = A.size(1);
+  TORCH_CHECK(
+      A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
+      __func__,
+      " : expect A to be either 32-bit or 16-bit float tensor.");
+  TORCH_CHECK(A.is_contiguous(), __func__, " : expect A to be contiguous.");
+  TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor.");
+
+  TORCH_CHECK(
+      B.dtype() == kInt || B.dtype() == kUInt32,
+      __func__,
+      " : expect B to be int32 or uint32 tensor.");
+  TORCH_CHECK(B.is_contiguous(), __func__, " : expect B to be contiguous.");
+  TORCH_CHECK(B.dim() == 4, __func__, " : expect B to be a 4D tensor.");
+
+  TORCH_CHECK(
+      qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
+          qGroupSize == 256,
+      __func__,
+      ": expect qGroupSize to be 32, 64, 128 or 256, got ",
+      qGroupSize);
+
+  // TORCH_CHECK(
+  //     qScaleAndZeros.dim() == 3 && qScaleAndZeros.size(1) == N &&
+  //         qScaleAndZeros.size(2) == 2,
+  //     __func__,
+  //     ": expect qScaleAndZeros to be 3d tensor with sizes [:, ",
+  //     N,
+  //     ", 2]");
+
+  std::optional<Device> common_device = std::nullopt;
+  c10::impl::check_and_update_common_device(
+      common_device, A, "xpu::_weight_int4pack_mm", "A");
+  c10::impl::check_and_update_common_device(
+      common_device, B, "xpu::_weight_int4pack_mm", "B");
+  c10::impl::check_and_update_common_device(
+      common_device,
+      qScaleAndZeros,
+      "xpu::_weight_int4pack_mm",
+      "qScaleAndZeros");
+  Tensor C = at::empty({M, N}, A.options());
+
+  at::native::xpu::linear_int4_kernel(A, B, qGroupSize, qScaleAndZeros, C);
+  return C;
+}
+} // namespace at::native
diff --git a/src/ATen/native/xpu/RreluWithNoise.cpp b/src/ATen/native/xpu/RreluWithNoise.cpp
index f66833983..fb4e2c333 100644
--- a/src/ATen/native/xpu/RreluWithNoise.cpp
+++ b/src/ATen/native/xpu/RreluWithNoise.cpp
@@ -6,7 +6,7 @@ namespace native {

 Tensor& rrelu_with_noise_out_xpu(
     const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
@@ -18,7 +18,7 @@ Tensor& rrelu_with_noise_out_xpu(

 Tensor rrelu_with_noise_xpu(
     const Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
@@ -30,7 +30,7 @@ Tensor rrelu_with_noise_xpu(

 Tensor& rrelu_with_noise_xpu_(
     Tensor& self,
-    const Tensor& noise,
+    Tensor& noise,
     const Scalar& lower,
     const Scalar& upper,
     bool training,
diff --git a/src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.cpp b/src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.cpp
index b96ab461e..c4bfac98f 100644
--- a/src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.cpp
+++ b/src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.cpp
@@ -1,16 +1,17 @@
 #include
+#include
 #include
-
-#include
-
 #include
+#include

 namespace at::native::xpu {

 template <typename scalar_t>
 struct SoftshrinkFunctor {
   scalar_t operator()(scalar_t a) const {
-    return a > lambd_ ? a - lambd_ : (a < -lambd_ ? a + lambd_ : scalar_t(0));
+    return at::_isnan(a)
+        ? a
+        : (a > lambd_ ? a - lambd_ : (a < -lambd_ ?
a + lambd_ : scalar_t(0))); } SoftshrinkFunctor(scalar_t lambd) : lambd_(lambd) {} diff --git a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp new file mode 100644 index 000000000..ab552399a --- /dev/null +++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp @@ -0,0 +1,374 @@ +#include +#include + +namespace at::native::xpu { +static inline int padto_le(int src, int padding) { + return src / padding * padding; +} + +static inline int64_t padto_le(int64_t src, int64_t padding) { + return src / padding * padding; +} + +static inline size_t padto_le(size_t src, int padding) { + return src / size_t(padding) * size_t(padding); +} + +template +struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + LinearInt4KernelFunctor( + const scalar_t* A, + const uint8_t* B, + scalar_t* C, + const scalar_t* ScaleAndZeros, + int m, + int n, + int k, + int lda, + int ldb, + int ldc, + sycl::stream os) + : A(A), + B(B), + C(C), + ScaleAndZeros(ScaleAndZeros), + m(m), + n(n), + k(k), + lda(lda), + ldb(ldb), + ldc(ldc), + os_(os) {} + void sycl_ker_config_convention(sycl::handler& cgh) { + // local_scan_ = sycl_local_acc_t(N_, cgh); + // sycl::stream os_(1024, 128, cgh); + // os_ = sycl::stream(1024, 128, cgh); + } + + void operator()(sycl::nd_item<1> it) const { + int constexpr Unroll = 2; + int constexpr SgSize = 16; + int constexpr TileK = 32; + int constexpr GroupK = SgSize * TileK; + int constexpr blocksize = block_size; + + if (k % (SgSize * 32 * Unroll) == 0) { + int constexpr TileK = 32; + int constexpr GroupK = SgSize * TileK; + + int g_idx = it.get_group(0); + auto sg = it.get_sub_group(); + int sg_id = sg.get_local_id()[0]; + int g_n = g_idx; + auto sptr = ScaleAndZeros + g_n * ldb; + auto bptr = B + g_n * k / 2; + auto aptr = A; + auto cptr = C + g_n; + if constexpr (std::is_same_v) { // Half + sycl::half2 tmpAcc = {0.f, 0.f}; + for (int i = 0; i < k; i += GroupK * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + scalar_t scale = *(sptr + sg_id * (TileK / blocksize) * 2); + scalar_t zero_point = *(sptr + sg_id * (TileK / blocksize) * 2 + 1); +#pragma unroll + for (int ikk = 0; ikk < TileK; ikk += 2) { + sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk]; + sycl::half2 tmpB = { + static_cast((tmps8[ikk / 2] & 0x0f) - zero_point), + static_cast((tmps8[ikk / 2] >> 4) - zero_point)}; + tmpAcc += tmpA * tmpB * scale; + } + sptr += (GroupK / blocksize) * 2; + aptr += GroupK; + bptr += GroupK / 2; + } + } + sycl::half2 sum = {0.f, 0.f}; + for (int i = 0; i < SgSize; i += 1) { + sum += select_from_group(sg, tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum[0] + sum[1]; + } + } else { // Bfloat16 + scalar_t tmpAcc = 0.f; + int constexpr Unroll = 2; + for (int i = 0; i < k; i += GroupK * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + scalar_t scale = *(sptr + sg_id * TileK / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK; ikk += 2) { + tmpAcc += scalar_t(aptr[sg_id * TileK + ikk]) * + static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale; + tmpAcc += scalar_t(aptr[sg_id * TileK + ikk + 1]) * + static_cast((tmps8[ikk / 2] >> 4) - 8) * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; + } + } + float sum = 0.f; + for (int i = 0; i < SgSize; i += 1) { + sum += select_from_group(sg, tmpAcc, i); + 
} + if (sg_id == 0) { + *cptr = sum; + } + } + } else { // k % (SgSize * 32 * Unroll) != 0 + int constexpr TileK = 32; + int constexpr GroupK = SgSize * TileK; + int k_body = padto_le(k, GroupK * Unroll); + int constexpr TileK2 = 8; + int constexpr GroupK2 = SgSize * TileK2; + int k_body2 = padto_le(k, GroupK2 * Unroll); + int g_idx = it.get_group(0); + auto sg = it.get_sub_group(); + int sg_id = sg.get_local_id()[0]; + int g_n = g_idx; + auto sptr = ScaleAndZeros + g_n * ldb; + auto bptr = B + g_n * k / 2; + auto aptr = A; + auto cptr = C + g_n; + if constexpr (std::is_same_v) { // Half + sycl::half2 tmpAcc = {0.f, 0.f}; + int i = 0; + for (; i < k_body; i += GroupK * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + scalar_t scale = 1.f; + // scalar_t scale = *(sptr + sg_id * TileK / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK; ikk += 2) { + sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk]; + sycl::half2 tmpB = {1.f, 1.f}; + // static_cast((tmps8[ikk / 2] & 0x0f) - 8), + // static_cast((tmps8[ikk / 2] >> 4) - 8)}; + tmpAcc += tmpA * tmpB * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; + } + } + if (i + GroupK2 * Unroll < k_body2) { + for (; i < k_body2; i += GroupK2 * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK2 / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK2 / 2); + scalar_t scale = 1.f; + // scalar_t scale = *(sptr + sg_id * TileK2 / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK2; ikk += 2) { + sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK2 + ikk]; + sycl::half2 tmpB = {1.f, 1.f}; + // static_cast((tmps8[ikk / 2] & 0x0f) - 8), + // static_cast((tmps8[ikk / 2] >> 4) - 8)}; + tmpAcc += tmpA * tmpB * scale; + } + sptr += GroupK2 / blocksize; + aptr += GroupK2; + bptr += GroupK2 / 2; + } + } + } + if (i + SgSize * 2 < k) { + for (; i < k; i += SgSize * 2) { + uint8_t tmps8 = *(bptr + sg_id); + scalar_t scale = 1.f; + // scalar_t scale = *(sptr + sg_id * 2 / blocksize); + sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * 2]; + sycl::half2 tmpB = {1.f, 1.f}; + // static_cast((tmps8 & 0x0f) - 8), + // static_cast((tmps8 >> 4) - 8)}; + tmpAcc += tmpA * tmpB * scale; + sptr += SgSize * 2 / blocksize; + aptr += SgSize * 2; + bptr += SgSize * 2 / 2; + } + } + sycl::half2 sum = {0.f, 0.f}; + for (int i = 0; i < SgSize; i += 1) { + sum += select_from_group(sg, tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum[0] + sum[1]; + } + } else { // Bfloat16 + scalar_t tmpAcc = 0.f; + int constexpr Unroll = 2; + int i = 0; + for (; i < k_body; i += GroupK * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK / 2); + scalar_t scale = *(sptr + sg_id * TileK / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK; ikk += 2) { + tmpAcc += scalar_t(aptr[sg_id * TileK + ikk]) * + static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale; + tmpAcc += scalar_t(aptr[sg_id * TileK + ikk + 1]) * + static_cast((tmps8[ikk / 2] >> 4) - 8) * scale; + } + sptr += GroupK / blocksize; + aptr += GroupK; + bptr += GroupK / 2; + } + } + if (i + GroupK2 * Unroll < k_body2) { + for (; i < k_body2; i += GroupK2 * Unroll) { +#pragma unroll + for (int iu = 0; iu < Unroll; iu++) { + uint8_t tmps8[TileK2 / 2]; + *(sycl::vec*)tmps8 = + *(sycl::vec*)(bptr + sg_id * TileK2 / 2); + 
scalar_t scale = *(sptr + sg_id * TileK2 / blocksize); +#pragma unroll + for (int ikk = 0; ikk < TileK2; ikk += 2) { + tmpAcc += scalar_t(aptr[sg_id * TileK2 + ikk]) * + static_cast((tmps8[ikk / 2] & 0x0f) - 8) * scale; + tmpAcc += scalar_t(aptr[sg_id * TileK2 + ikk + 1]) * + static_cast((tmps8[ikk / 2] >> 4) - 8) * scale; + } + sptr += GroupK2 / blocksize; + aptr += GroupK2; + bptr += GroupK2 / 2; + } + } + } + if (i + SgSize * Unroll < k) { + for (; i < k; i += SgSize) { + uint8_t tmps8 = *(bptr + sg_id / 2); + scalar_t scale = *(sptr + sg_id / blocksize); + tmpAcc += scalar_t(aptr[sg_id]) * + static_cast((tmps8 & 0x0f) - 8) * scale; + tmpAcc += scalar_t(aptr[sg_id]) * + static_cast((tmps8 >> 4) - 8) * scale; + sptr += SgSize / blocksize; + aptr += SgSize; + bptr += SgSize / 2; + } + } + float sum = 0.f; + for (int i = 0; i < SgSize; i += 1) { + sum += select_from_group(sg, tmpAcc, i); + } + if (sg_id == 0) { + *cptr = sum; + } + } + } + } + + private: + const scalar_t* A; + const uint8_t* B; + scalar_t* C; + const scalar_t* ScaleAndZeros; + int m; + int n; + int k; + int lda; + int ldb; + int ldc; + + private: + sycl::stream os_; + // sycl::handler cgh_; +}; + +void linear_int4_kernel( + const Tensor& A, + const Tensor& B, + int qGroupSize, + const Tensor& qScaleAndZeros, + Tensor& C) { + auto& sycl_queue = at::xpu::getCurrentSYCLQueue(); + int64_t m = A.size(0); + int64_t n = C.size(1); + int64_t k = A.size(1); + int constexpr Unroll = 2; + int constexpr SgSize = 16; + sycl::range<1> local_range{SgSize}; + sycl::range<1> global_range{static_cast(n) * SgSize}; + + if (A.scalar_type() == at::ScalarType::Half) { + using scalar_t = at::Half; + using scalar_sycl_t = sycl::half; + const scalar_sycl_t* input_data = + reinterpret_cast(A.data_ptr()); + uint8_t* weight_data = + reinterpret_cast(B.data_ptr()); // int4x8 + + scalar_sycl_t* output_data = + reinterpret_cast(C.data_ptr()); + scalar_sycl_t* scale_zeros_data = + reinterpret_cast(qScaleAndZeros.data_ptr()); + + auto cgf = [&](::sycl::handler& cgh) { + auto os = sycl::stream(1024, 768, cgh); + LinearInt4KernelFunctor kfn( + input_data, + weight_data, + output_data, + scale_zeros_data, + m, + n, + k, + k, + (k / qGroupSize) * 2, // scale and zero point combined + n, + os); + kfn.sycl_ker_config_convention(cgh); + cgh.parallel_for>( + ::sycl::nd_range<1>(global_range, local_range), kfn); + }; + sycl_queue.submit(cgf); + + // sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); + } + // AT_DISPATCH_FLOATING_TYPES_AND( + // at::ScalarType::BFloat16, A.scalar_type(), "linear_int4_kernel", + // [&]() + // { + else if (A.scalar_type() == at::ScalarType::BFloat16) { + // using scalar_t = at::BFloat16; + // const scalar_t* input_data = A.data_ptr(); + // int32_t* weight_data = B.data_ptr(); // int4x8 + + // scalar_t* output_data = C.data_ptr(); + // scalar_t* weight_scale_data = qScaleAndZeros.data_ptr(); + // LinearInt4KernelFunctor kfn( + // input_data, + // weight_data, + // output_data, + // weight_scale_data, + // nullptr, + // m, + // n, + // k, + // k, + // n, + // n); + + // sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); + } +} + +} // namespace at::native::xpu \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/LinearInt4.h b/src/ATen/native/xpu/sycl/LinearInt4.h new file mode 100644 index 000000000..c54f3df21 --- /dev/null +++ b/src/ATen/native/xpu/sycl/LinearInt4.h @@ -0,0 +1,14 @@ +#pragma once +#include +#include + +namespace at::native::xpu { + +TORCH_XPU_API void linear_int4_kernel( + const 
Tensor& input, + const Tensor& weight, + int qGroupSize, + const Tensor& weight_scale_zero_point, + Tensor& output); + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/MultiTensorApply.h b/src/ATen/native/xpu/sycl/MultiTensorApply.h index 51ee195a9..e91ab56f5 100644 --- a/src/ATen/native/xpu/sycl/MultiTensorApply.h +++ b/src/ATen/native/xpu/sycl/MultiTensorApply.h @@ -68,7 +68,7 @@ static inline int64_t multi_tensor_apply_fused_kernel_get_chunk_size() { } template -struct MultiTensorApplyKernelFunctor { +struct MultiTensorApplyKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { void operator()(sycl::nd_item<1> item_id) const { // Expand the tuple elements manually and call the callable expandAndCall(item_id, std::index_sequence_for()); @@ -85,6 +85,12 @@ struct MultiTensorApplyKernelFunctor { callable(callable_), args(std::make_tuple(args_...)) {} + void sycl_ker_config_convention(sycl::handler& cgh) { + if constexpr (std::is_base_of_v<__SYCL_KER_CONFIG_CONVENTION__, U>) { + callable.sycl_ker_config_convention(cgh); + } + } + private: template void expandAndCall(sycl::nd_item<1> item_id, std::index_sequence) @@ -117,7 +123,6 @@ void launch_multi_tensor_apply_kernel( U callable, int num_wg, ArgTypes... args) { - auto& q = getCurrentSYCLQueue(); int64_t simd = syclMaxSubGroupSize(); int64_t max_wg_size = multi_tensor_apply_kernel_get_wg_size(simd); @@ -226,7 +231,6 @@ void multi_tensor_apply( std::vector>& tensor_lists, T callable, ArgTypes... args) { - TORCH_CHECK( tensor_lists.size() == depth, "Number of tensor lists has to match he depth"); diff --git a/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp b/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp index 533630175..7f6f33805 100644 --- a/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp +++ b/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp @@ -86,7 +86,7 @@ template inline void _rrelu_with_noise_xpu_train( Tensor& output, const Tensor& input_, - const Tensor& noise_, + Tensor& noise_, const Scalar& lower_, const Scalar& upper_, std::optional generator) { @@ -153,7 +153,7 @@ inline void _rrelu_with_noise_xpu_train( Tensor& rrelu_with_noise_kernel( const Tensor& self, - const Tensor& noise, + Tensor& noise, const Scalar& lower, const Scalar& upper, bool training, diff --git a/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.h b/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.h index 8371c38ab..fa7e568ea 100644 --- a/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.h +++ b/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.h @@ -7,7 +7,7 @@ namespace at::native::xpu { TORCH_XPU_API Tensor& rrelu_with_noise_kernel( const Tensor& self, - const Tensor& noise, + Tensor& noise, const Scalar& lower, const Scalar& upper, bool training, diff --git a/src/ATen/native/xpu/sycl/ScatterGatherKernels.cpp b/src/ATen/native/xpu/sycl/ScatterGatherKernels.cpp index 597be8553..c807655b3 100644 --- a/src/ATen/native/xpu/sycl/ScatterGatherKernels.cpp +++ b/src/ATen/native/xpu/sycl/ScatterGatherKernels.cpp @@ -143,41 +143,64 @@ struct alignas(N) OpaqueType { char data[N]; }; -template +template struct ScatterGatherElementwiseKernelFunctor { void operator()(sycl::nd_item<1> item) const { - constexpr int nv = work_group_size * thread_work_size; + int nv = work_group_size_ * thread_work_size_; auto wg_id = item.get_group_linear_id(); auto local_id = item.get_local_linear_id(); int idx = nv * wg_id + local_id; -#pragma unroll - for (int i = 0; i < thread_work_size; ++i) { + for (int i = 0; i < thread_work_size_; ++i) { if (idx < N_) { 
f_(idx); - idx += work_group_size; + idx += work_group_size_; } } } - ScatterGatherElementwiseKernelFunctor(int N, func_t f) : N_(N), f_(f) {} + ScatterGatherElementwiseKernelFunctor( + int N, + func_t f, + int work_group_size, + int thread_work_size) + : N_(N), + f_(f), + work_group_size_(work_group_size), + thread_work_size_(thread_work_size) {} private: int N_; func_t f_; + int work_group_size_; + int thread_work_size_; }; -template +template static void launch_scatter_gather_kernel(int64_t N, const func_t& f) { TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits::max()); if (N == 0) { return; } - sycl::range<1> local_range{(size_t)nt}; - int num_workgroups = (N + nt * vt - 1) / (nt * vt); - sycl::range<1> global_range{(size_t)(num_workgroups * nt)}; - - auto caller = - ScatterGatherElementwiseKernelFunctor((int)N, f); + using KernelFn = ScatterGatherElementwiseKernelFunctor; + int64_t max_wg_size = syclMaxWorkGroupSize(); + int outputSize = N; + int work_group_size = outputSize > max_wg_size ? max_wg_size : outputSize; + const auto target_global_size = syclMaxWorkItemsPerTile(); + // Each work group size is work_group_size, one full device launch is + // target_global_size, so we can calculate max work group num as below + const int max_work_group_num = target_global_size / work_group_size; + int work_group_num = outputSize / work_group_size < max_work_group_num + ? outputSize / work_group_size + : max_work_group_num; + int draft_work_group_num = + (outputSize + work_group_size - 1) / work_group_size; + + int thread_work_size = draft_work_group_num / work_group_num + 1; + + sycl::range<1> local_range(work_group_size); + sycl::range<1> global_range(work_group_num * work_group_size); + + auto caller = KernelFn((int)N, f, work_group_size, thread_work_size); sycl_kernel_submit( global_range, local_range, at::xpu::getCurrentSYCLQueue(), caller); } @@ -268,11 +291,7 @@ struct ScatterGatherInternalKernel { numel, f); - // TODO: optimize it - constexpr int group_work_items = 256; - constexpr int work_size_per_item = 4; - launch_scatter_gather_kernel( - iter.numel(), loop); + launch_scatter_gather_kernel(iter.numel(), loop); } }; @@ -521,11 +540,7 @@ struct ScatterFillInternalKernel { decltype(offset_calc), func_t>(self_ptr, index_ptr, offset_calc, index_stride, f, src_val); - // TODO: optimize it - constexpr int group_work_items = 256; - constexpr int work_size_per_item = 4; - launch_scatter_gather_kernel( - iter.numel(), loop); + launch_scatter_gather_kernel(iter.numel(), loop); } }; diff --git a/src/ATen/xpu/EmptyTensor.cpp b/src/ATen/xpu/EmptyTensor.cpp index 2550b4dbb..3f5e998f8 100644 --- a/src/ATen/xpu/EmptyTensor.cpp +++ b/src/ATen/xpu/EmptyTensor.cpp @@ -12,6 +12,7 @@ TensorBase empty_xpu( ScalarType dtype, c10::optional device_opt, c10::optional memory_format_opt) { + at::globalContext().lazyInitDevice(c10::DeviceType::XPU); const auto device = device_or_default(device_opt); TORCH_INTERNAL_ASSERT(device.is_xpu()); const c10::DeviceGuard device_guard(device); diff --git a/src/BuildOnWindows.cmake b/src/BuildOnWindows.cmake index b671b305c..b4757acb1 100644 --- a/src/BuildOnWindows.cmake +++ b/src/BuildOnWindows.cmake @@ -3,11 +3,6 @@ set(TORCH_XPU_OPS_LIBRARIES) set(SYCL_LINK_LIBRARIES_KEYWORD PRIVATE) -# Walk around cyclic dependence -# libtorch_xpu.so links to libtorch_xpu_ops.a -# Load libtorch_xpu_ops_aten.so explicitly by torch/__init__.py:_load_dll_libraries (Break cycle) -# libtorch_xpu_ops_aten.so links to libtorch_xpu_ops_sycl_unary_binary_kernels.so and 
libtorch_xpu_ops_sycl_kernels.so -# libtorch_xpu_ops_sycl_unary_binary_kernels.so and libtorch_xpu_ops_sycl_kernels.so links to libtorch_xpu.so add_library( torch_xpu_ops STATIC @@ -21,7 +16,6 @@ add_library( ${ATen_XPU_NATIVE_CPP_SRCS} ${ATen_XPU_GEN_SRCS}) install(TARGETS torch_xpu_ops_aten DESTINATION "${TORCH_INSTALL_LIB_DIR}") -# target_compile_definitions(torch_xpu_ops_aten PRIVATE CAFFE2_BUILD_MAIN_LIB) target_compile_definitions(torch_xpu_ops_aten PRIVATE TORCH_XPU_BUILD_MAIN_LIB) target_link_libraries(torch_xpu_ops_aten PUBLIC torch_xpu) target_link_libraries(torch_xpu_ops_aten PUBLIC torch_cpu) @@ -48,8 +42,11 @@ else() set(ATen_XPU_SYCL_REDUCE_SRCS) set(ATen_XPU_SYCL_ACTIVATION_SRCS) set(ATen_XPU_SYCL_FOREACH_SRCS) + set(ATen_XPU_SYCL_TENSOR_SRCS) + set(ATen_XPU_SYCL_NORM_LOSS_SRCS) + set(ATen_XPU_SYCL_POLY_SRCS) + set(ATen_XPU_SYCL_DISTRIBUTION_SRCS) set(ATen_XPU_SYCL_OTHERS_SRCS) - foreach(sycl_src ${ATen_XPU_SYCL_SRCS}) string(REGEX MATCH "Binary" IS_BINARY ${sycl_src}) string(REGEX MATCH "Unary" IS_UNARY ${sycl_src}) @@ -63,6 +60,13 @@ else() string(REGEX MATCH "Activation" IS_ACTIVATION ${sycl_src}) string(REGEX MATCH "Foreach" IS_FOREACH ${sycl_src}) string(REGEX MATCH "Reduce" IS_REDUCE ${sycl_src}) + string(REGEX MATCH "Tensor" IS_TENSOR ${sycl_src}) + string(REGEX MATCH "Norm" IS_NORM ${sycl_src}) + string(REGEX MATCH "Loss" IS_LOSS ${sycl_src}) + string(REGEX MATCH "Polynomial" IS_POLY ${sycl_src}) + #Move resize kernel to Norm and Loss lib, to resolve symbol. + string(REGEX MATCH "Resize" IS_RESIZE ${sycl_src}) + string(REGEX MATCH "Distribution" IS_DISTRIBUTION ${sycl_src}) if(NOT IS_FOREACH STREQUAL "") list(APPEND ATen_XPU_SYCL_FOREACH_SRCS ${sycl_src}) @@ -74,11 +78,18 @@ else() list(APPEND ATen_XPU_SYCL_REDUCE_SRCS ${sycl_src}) elseif(NOT IS_ACTIVATION STREQUAL "") list(APPEND ATen_XPU_SYCL_ACTIVATION_SRCS ${sycl_src}) + elseif(NOT IS_TENSOR STREQUAL "") + list(APPEND ATen_XPU_SYCL_TENSOR_SRCS ${sycl_src}) + elseif(NOT IS_DISTRIBUTION STREQUAL "") + list(APPEND ATen_XPU_SYCL_DISTRIBUTION_SRCS ${sycl_src}) + elseif(NOT IS_NORM STREQUAL "" OR NOT IS_LOSS STREQUAL "" OR NOT IS_RESIZE STREQUAL "") + list(APPEND ATen_XPU_SYCL_NORM_LOSS_SRCS ${sycl_src}) + elseif(NOT IS_POLY STREQUAL "") + list(APPEND ATen_XPU_SYCL_POLY_SRCS ${sycl_src}) else() list(APPEND ATen_XPU_SYCL_OTHERS_SRCS ${sycl_src}) endif() endforeach() - # Binary kernel lib set(sycl_binary_lib torch_xpu_ops_sycl_binary_kernels) sycl_add_library( @@ -148,7 +159,63 @@ else() # Decouple with PyTorch cmake definition. install(TARGETS ${sycl_foreach_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") + + # Tensor kernel lib + set(sycl_tensor_lib torch_xpu_ops_sycl_tensor_kernels) + sycl_add_library( + ${sycl_tensor_lib} + SHARED + SYCL_SOURCES ${ATen_XPU_SYCL_TENSOR_SRCS}) + target_compile_definitions(${sycl_tensor_lib} PRIVATE TORCH_XPU_BUILD_MAIN_LIB) + target_link_libraries(torch_xpu_ops_aten PUBLIC ${sycl_tensor_lib}) + target_link_libraries(${sycl_tensor_lib} PUBLIC torch_xpu) + list(APPEND TORCH_XPU_OPS_LIBRARIES ${sycl_tensor_lib}) + # Decouple with PyTorch cmake definition. 
+ install(TARGETS ${sycl_tensor_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") + + # Norm and Loss kernel lib + set(sycl_norm_loss_lib torch_xpu_ops_sycl_norm_loss_kernels) + sycl_add_library( + ${sycl_norm_loss_lib} + SHARED + SYCL_SOURCES ${ATen_XPU_SYCL_NORM_LOSS_SRCS}) + target_compile_definitions(${sycl_norm_loss_lib} PRIVATE TORCH_XPU_BUILD_MAIN_LIB) + target_link_libraries(torch_xpu_ops_aten PUBLIC ${sycl_norm_loss_lib}) + target_link_libraries(${sycl_norm_loss_lib} PUBLIC torch_xpu) + list(APPEND TORCH_XPU_OPS_LIBRARIES ${sycl_norm_loss_lib}) + + # Decouple with PyTorch cmake definition. + install(TARGETS ${sycl_norm_loss_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") + + # Polynomial kernel lib + set(sycl_poly_lib torch_xpu_ops_sycl_poly_kernels) + sycl_add_library( + ${sycl_poly_lib} + SHARED + SYCL_SOURCES ${ATen_XPU_SYCL_POLY_SRCS}) + target_compile_definitions(${sycl_poly_lib} PRIVATE TORCH_XPU_BUILD_MAIN_LIB) + target_link_libraries(torch_xpu_ops_aten PUBLIC ${sycl_poly_lib}) + target_link_libraries(${sycl_poly_lib} PUBLIC torch_xpu) + list(APPEND TORCH_XPU_OPS_LIBRARIES ${sycl_poly_lib}) + + # Decouple with PyTorch cmake definition. + install(TARGETS ${sycl_poly_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") + + # Distribution kernel lib + set(sycl_dist_lib torch_xpu_ops_sycl_dist_kernels) + sycl_add_library( + ${sycl_dist_lib} + SHARED + SYCL_SOURCES ${ATen_XPU_SYCL_DISTRIBUTION_SRCS}) + target_compile_definitions(${sycl_dist_lib} PRIVATE TORCH_XPU_BUILD_MAIN_LIB) + target_link_libraries(torch_xpu_ops_aten PUBLIC ${sycl_dist_lib}) + target_link_libraries(${sycl_dist_lib} PUBLIC torch_xpu) + list(APPEND TORCH_XPU_OPS_LIBRARIES ${sycl_dist_lib}) + + # Decouple with PyTorch cmake definition. + install(TARGETS ${sycl_dist_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") + # Other kernel lib set(sycl_lib torch_xpu_ops_sycl_kernels) sycl_add_library( diff --git a/test/xpu/test_binary_ufuncs_xpu.py b/test/xpu/test_binary_ufuncs_xpu.py index 9cff7f9f1..91db3f12a 100644 --- a/test/xpu/test_binary_ufuncs_xpu.py +++ b/test/xpu/test_binary_ufuncs_xpu.py @@ -65,7 +65,7 @@ def to_np(value): else: self.assertRaisesRegex( RuntimeError, - "Found dtype \\w+ but expected \\w+", + r"result type \w+ can't be cast to the desired output type \w+", lambda: actual.pow_(exponent), ) diff --git a/test/xpu/test_int4_linear.py b/test/xpu/test_int4_linear.py new file mode 100644 index 000000000..6d305e1cb --- /dev/null +++ b/test/xpu/test_int4_linear.py @@ -0,0 +1,205 @@ +# Owner(s): ["module: intel"] + +import torch +import pytest + +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) + + +checking_atol = 1e-2 +checking_rtol = 1e-2 + + +class TestSYCLInt4Linear(TestCase): + + @staticmethod + def unpack_weight(qweight, scales, qzeros, q_config): + group_size = q_config["group_size"] + bits = q_config["bits"] + s32_bits = 32 + + assert bits == 4 + # Int32 can store 8 * 4bits data. This is the offset for each data. 
+ wf = ( + torch.tensor(list(range(0, s32_bits, bits)), dtype=torch.int32) + .unsqueeze(0) + .to("xpu") + ) + zeros = torch.bitwise_right_shift( + torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0) + ).to(torch.int16 if bits == 8 else torch.int8) + torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) + + zeros = zeros + 1 + zeros = zeros.reshape(scales.shape) + + weight = torch.bitwise_right_shift( + torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1) + ).to(torch.int16 if bits == 8 else torch.int8) + torch.bitwise_and(weight, (2**bits) - 1, out=weight) + + return weight, scales, zeros + + @staticmethod + def dequantize(qweight, scales, qzeros, group_size): + q_config = {"group_size": group_size, "bits": 4} + weight, gptq_scales, gptq_zeros = TestSYCLInt4Linear.unpack_weight( + qweight, scales, qzeros, q_config + ) + gptq_zeros = (torch.ones_like(gptq_zeros) * 8).to("xpu") # TODO: hard code zp + if len(weight.shape) > 2: + weight = weight.reshape(-1, weight.shape[-1]) + infeatures = weight.shape[0] + g_idx = torch.tensor( + [i // q_config["group_size"] for i in range(infeatures)], + dtype=torch.int32, + ) + scale_zeros = gptq_zeros * gptq_scales + weight = gptq_scales[g_idx.long()] * weight - scale_zeros[g_idx.long()] + return weight + + @staticmethod + def rand_int4(size, dtype=torch.int32, device="xpu"): + rand = torch.randint(-128, 128, [size // 2], device=device).to(torch.int8) + return rand.view(dtype=dtype) + + @parametrize("per_channel", [False], lambda k: "per_channel" * k) + @parametrize("dtype", [torch.float16]) + @parametrize("m,n,k", [(8, 4096, 4096), (1, 4096, 11008), (32, 4096, 4096)]) + def test_gemm_int4(self, m, n, k, per_channel, dtype): + input = torch.rand([m, k], device="xpu", dtype=dtype) + input_torch = input.cpu() + weight = self.rand_int4(k * n, torch.int32, "xpu").reshape(k // 8, n) + + group_size = min(128, k) + if per_channel: + group_size = k + group_num = int(k / group_size) + + scales = torch.rand([group_num, n], device="xpu", dtype=dtype) + zero_points = self.rand_int4(group_num * n, torch.int32, "xpu").reshape( + group_num, n // 8 + ) + + weight_fp = self.dequantize( + weight, scales, zero_points, group_size).cpu() + # check gemm + zero_points = torch.Tensor([8]).to(torch.int8).to("xpu") + weight_ba = weight.transpose(0, 1).contiguous() + + out_onednn =torch._weight_int4pack_mm_with_scales_and_zeros( + input, weight_ba, scales, zero_points, group_size + ) + out_torch = torch.matmul(input_torch, weight_fp) + self.assertEqual( + out_onednn.cpu().float(), + out_torch.float(), + atol=checking_atol, + rtol=checking_rtol, + ) + # check gemm + residual + res0 = torch.rand([m, n], device="xpu", dtype=dtype) + out_onednn_res = torch._weight_int4pack_mm_with_scales_and_zeros( + input, weight_ba, scales, zero_points, group_size, res0) + out_torch_res = out_torch + res0.cpu().float() + self.assertEqual( + out_onednn_res.cpu().float(), + out_torch_res.float(), + atol=checking_atol, + rtol=checking_rtol, + ) + + # check gemm + bias + bias = torch.rand([1, n], device="xpu", dtype=dtype) + out_onednn_bias = torch._weight_int4pack_mm_with_scales_and_zeros( + input, weight_ba, bias, scales, zero_points, group_size) + out_torch_bias = out_torch + bias.cpu().float() + self.assertEqual( + out_onednn_bias.cpu().float(), + out_torch_bias.float(), + atol=checking_atol, + rtol=checking_rtol, + ) + + # check gemm + bias + gelu + out_onednn_gelu = torch._weight_int4pack_mm_with_scales_and_zeros( + input, + weight_ba, + scales, + zero_points, + 
bias, + group_size, + "tanh", + ) + gelu_out = torch.nn.GELU(approximate="tanh")(out_torch_bias) + self.assertEqual( + out_onednn_gelu.cpu().float(), + gelu_out.float(), + atol=checking_atol, + rtol=checking_rtol, + ) + + # check gemm + silu + mul + res0 = torch.rand([m, n], device="xpu", dtype=dtype) + out_onednn_silu = torch._weight_int4pack_mm_with_scales_and_zeros( + input, weight_ba, scales, zero_points, group_size, res0 + ) + silu_mul_out = torch.nn.SiLU()(out_torch) * res0.cpu().float() + self.assertEqual( + out_onednn_silu.cpu().float(), + silu_mul_out.float(), + atol=checking_atol, + rtol=checking_rtol, + ) + + # check gemm + bias + residual + residual + res0 = torch.rand([m, n], device="xpu", dtype=dtype) + res1 = torch.rand([m, n], device="xpu", dtype=dtype) + out_onednn_bias_2res = torch._weight_int4pack_mm_with_scales_and_zeros( + input, + weight_ba, + bias, + res0, + res1, + scales, + zero_points, + group_size, + ) + out_torch_bias_2res = out_torch_bias + res0.cpu().float() + res1.cpu().float() + self.assertEqual( + out_onednn_bias_2res.cpu().float(), + out_torch_bias_2res.float(), + atol=checking_atol, + rtol=checking_rtol, + ) + + # check gemm + bias + residual + res0 = torch.rand([m, n], device="xpu", dtype=dtype) + out_onednn_bias_add = torch._weight_int4pack_mm_with_scales_and_zeros( + input, + weight_ba, + bias, + scales, + zero_points, + group_size, + res0, + ) + out_torch_bias_add = out_torch_bias + res0.cpu().float() + self.assertEqual( + out_onednn_bias_add.cpu().float(), + out_torch_bias_add.float(), + atol=checking_atol, + rtol=checking_rtol, + ) + + +instantiate_parametrized_tests(TestSYCLInt4Linear, globals(), only_for="xpu", allow_xpu=True) + +if __name__ == "__main__": + TestCase._default_dtype_check_enabled = True + run_tests() diff --git a/test/xpu/test_linalg_xpu.py b/test/xpu/test_linalg_xpu.py index 3c1c8aed7..88acd9134 100644 --- a/test/xpu/test_linalg_xpu.py +++ b/test/xpu/test_linalg_xpu.py @@ -6,6 +6,7 @@ from torch.testing._internal.common_dtype import floating_and_complex_types_and from torch.testing._internal.common_cuda import tf32_on_and_off from torch.testing._internal.common_mkldnn import bf32_on_and_off +from torch.testing._internal.common_quantization import _dynamically_quantize_per_channel from torch.testing import make_tensor import unittest import itertools @@ -171,6 +172,114 @@ def _test(m, k, n, transpose_a, transpose_b, test_equal=True): if not use_transpose_a and not use_transpose_b: _test(17, k, n, use_transpose_a, use_transpose_b) +@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") +@parametrize("m", [1]) +@parametrize("k", [256, 512, 1024]) +@parametrize("n", [32, 64]) +def _int4_mm(self, device, m, k, n): + def _group_quantize_tensor(w, n_bit=4, q_group_size=16): + assert w.dim() == 2 + w = w.transpose(0, 1).contiguous() + assert q_group_size > 1 + assert w.shape[-1] % q_group_size == 0 + + to_quant = w.reshape(-1, q_group_size) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2 ** n_bit - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + assert torch.isnan(scales).sum() == 0 + + zeros = min_val + scales * (2 ** (n_bit - 1)) + assert torch.isnan(zeros).sum() == 0 + + out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int) + assert torch.isnan(out).sum() == 0 + + out = out.to(dtype=torch.int32).reshape(w.shape) + if out.device.type != 'cpu' or out.device.type != 'xpu': + out = (out[::, ::2] 
<< 4 | out[::, 1::2]).to(torch.uint8) + + # Scales and zeros for the same q-group should be contiguous, so we can + # load as a 32-bit word + scales = scales.view(w.shape[0], -1) + zeros = zeros.view(w.shape[0], -1) + scales_and_zeros = ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + ) + + if out.device.type != 'xpu': + scales_and_zeros = scales_and_zeros.transpose(0, 1).contiguous() + return out, scales_and_zeros + + def convert_weight_to_int4pack(b): + b_tmp, b_scales_and_zeros = _group_quantize_tensor( + b, n_bit=4, q_group_size=q_group + ) + + if self.device_type == 'cpu': + b_int4pack = torch._convert_weight_to_int4pack_for_cpu( + b_tmp, inner_k_tiles + ) + elif self.device_type == 'xpu': + # b_int4pack = b_tmp.view(torch.int32) + b_int4pack = b_tmp + else: + b_int4pack = torch._convert_weight_to_int4pack( + b_tmp, inner_k_tiles + ) + + return b_int4pack, b_scales_and_zeros + + def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros): + if self.device_type == 'cpu': + self.assertTrue(b_int4pack.dtype is torch.uint8) + self.assertTrue(b_int4pack.dim() == 2) + return torch._weight_int4pack_mm_for_cpu( + a, b_int4pack, q_group, b_scales_and_zeros + ) + elif self.device_type == 'xpu': + self.assertTrue(b_int4pack.dtype is torch.int32 or b_int4pack.dtype is torch.uint8) + self.assertTrue(b_int4pack.dim() == 2) + return torch._weight_int4pack_mm( + a, b_int4pack, q_group, b_scales_and_zeros + ) + else: + self.assertTrue(b_int4pack.dtype is torch.int32) + self.assertTrue(b_int4pack.dim() == 4) + return torch._weight_int4pack_mm( + a, b_int4pack, q_group, b_scales_and_zeros + ) + + q_group = 32 + inner_k_tiles = 2 + + torch.manual_seed(1) + a_bf16 = torch.rand((m, k), dtype=torch.bfloat16, device=device) + b_bf16 = torch.rand((k, n), dtype=torch.bfloat16, device=device) + + b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16) + + for dtype in [torch.float16] + ([torch.float16, torch.float32] if device == "cpu" else []): + a = a_bf16.to(dtype=dtype) + b = b_bf16.to(dtype=dtype) + b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype) + ref = torch.mm(a, b) + res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros) + print(ref) + print(res) + mean_err = ((res - ref).abs() / ref).mean() + self.assertTrue(mean_err < 0.05) + @dtypes(torch.float, torch.complex64) # Integer matmul just supported on CPU @setBlasBackendsToDefaultFinally def matmul_small_brute_force_1d_Nd(self, device, dtype): @@ -229,6 +338,7 @@ def ck_blas_library(self): TestLinalg.test_preferred_linalg_library=preferred_linalg_library TestLinalg.test_addbmm=addbmm TestLinalg.test__int_mm=_int_mm +TestLinalg.test__int4_mm=_int4_mm TestLinalg.test_matmul_small_brute_force_1d_Nd=matmul_small_brute_force_1d_Nd TestLinalg.test_matmul_small_brute_force_2d_Nd=matmul_small_brute_force_2d_Nd TestLinalg.test_matmul_small_brute_force_3d_Nd=matmul_small_brute_force_3d_Nd diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 64e94073b..59c0b3b73 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -8106,17 +8106,18 @@ variants: function tags: pointwise -- func: rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) +- func: rrelu_with_noise.out(Tensor self, Tensor(b!) 
noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
   tags: nondeterministic_seeded
   dispatch:
     XPU: rrelu_with_noise_out_xpu

-- func: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
+- func: rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
   python_module: nn
   dispatch:
     XPU: rrelu_with_noise_xpu
   tags: nondeterministic_seeded
+  autogen: rrelu_with_noise_functional

 - func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor
   python_module: nn
@@ -8124,7 +8125,7 @@
     CompositeExplicitAutograd: rrelu_with_noise_backward
   autogen: rrelu_with_noise_backward.out

-- func: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
+- func: rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
   python_module: nn
   tags: nondeterministic_seeded
   dispatch:
@@ -8468,3 +8469,9 @@
   dispatch:
     SparseXPU: copy_sparse_
   autogen: copy_sparse_to_sparse, copy_sparse_to_sparse.out
+
+- func: _weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor
+  dispatch:
+    XPU: _weight_int4pack_mm_xpu
+  # autogen: _weight_int4pack_mm.out
+  # tags: core
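Usage note: the sketch below is a minimal end-to-end driver for the _weight_int4pack_mm XPU entry registered above. It condenses the _group_quantize_tensor helper and the XPU branch of the _int4_mm test added in test/xpu/test_linalg_xpu.py; the pack_int4 helper name and the shape comments are illustrative only, and it assumes an XPU device plus a PyTorch build that already carries this patch.

import torch

def pack_int4(w, group=32):
    # Condensed from _group_quantize_tensor in this patch (hypothetical helper name):
    # per-group asymmetric 4-bit quantization of a (k, n) weight, packed two values
    # per byte into a (n, k // 2) uint8 tensor plus (n, k // group, 2) scales/zeros.
    w = w.transpose(0, 1).contiguous()                      # (n, k)
    q = w.reshape(-1, group)
    max_val = q.amax(dim=1, keepdim=True)
    min_val = q.amin(dim=1, keepdim=True)
    scales = (max_val - min_val).clamp(min=1e-6) / 15       # 4-bit: 2**4 - 1 levels
    zeros = min_val + scales * 8
    q = q.sub(min_val).div(scales).round().clamp_(0, 15)
    q = q.to(torch.int32).reshape(w.shape)
    packed = (q[:, ::2] << 4 | q[:, 1::2]).to(torch.uint8)  # (n, k // 2)
    scales_and_zeros = torch.cat(
        [scales.view(w.shape[0], -1, 1), zeros.view(w.shape[0], -1, 1)], dim=2)
    return packed, scales_and_zeros

m, k, n, group = 1, 512, 64, 32
a = torch.rand(m, k, dtype=torch.float16, device="xpu")
b = torch.rand(k, n, dtype=torch.float16, device="xpu")
b_packed, b_scales_and_zeros = pack_int4(b, group)

# Dispatches to _weight_int4pack_mm_xpu and linear_int4_kernel added in this patch.
out = torch._weight_int4pack_mm(a, b_packed, group, b_scales_and_zeros)

# The patched test accepts the result when the mean relative error against a
# plain matmul stays below 5e-2.
ref = torch.mm(a, b)
print(((out - ref).abs() / ref).mean())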