diff --git a/src/ATen/native/xpu/LinearInt4.cpp b/src/ATen/native/xpu/LinearInt4.cpp index b8e114714..baa02e341 100644 --- a/src/ATen/native/xpu/LinearInt4.cpp +++ b/src/ATen/native/xpu/LinearInt4.cpp @@ -13,6 +13,41 @@ Tensor& linear_int4_xpu( const Tensor& weight, int qGroupSize, const Tensor& weight_scale_zero_point) { + auto M = input.size(0); + auto N = weight.size(0); + auto K = input.size(1); + TORCH_CHECK( + input.dtype() == kBFloat16 || input.dtype() == kHalf || + input.dtype() == kFloat, + __func__, + " : expect input to be either 32-bit or 16-bit float tensor."); + + TORCH_CHECK( + weight.dtype() == kByte, __func__, " : expect B to be uint8 tensor."); + TORCH_CHECK( + weight.is_contiguous(), __func__, " : expect B to be contiguous."); + TORCH_CHECK( + weight.size(1) == K / 2, + __func__, + " : expect B.size(1) to be K/2, got ", + weight.size(1)); + + TORCH_CHECK( + qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 || + qGroupSize == 256, + __func__, + ": expect qGroupSize to be 32, 64, 128 or 256, got ", + qGroupSize); + + TORCH_CHECK( + weight_scale_zero_point.dim() == 3 && + weight_scale_zero_point.size(1) == N && + weight_scale_zero_point.size(2) == 2, + __func__, + ": expect weight_scale_zero_point to be 3d tensor with sizes [:, ", + N, + ", 2]"); + std::optional common_device = std::nullopt; c10::impl::check_and_update_common_device( common_device, input, "xpu::linear_int4", "input"); @@ -23,7 +58,7 @@ Tensor& linear_int4_xpu( weight_scale_zero_point, "xpu::linear_int4", "weight_scale_zero_point"); - Tensor output = at::empty({0}, input.options()); + Tensor output = at::empty({M, N}, input.options()); at::native::xpu::linear_int4_kernel( input, weight, qGroupSize, weight_scale_zero_point, output); diff --git a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp index 7105236c8..6ac13442c 100644 --- a/src/ATen/native/xpu/sycl/LinearInt4.cpp +++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp @@ -6,11 +6,11 @@ namespace at::native::xpu { template struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { LinearInt4KernelFunctor( - scalar_t* A, - uint32_t* B, + const scalar_t* A, + const uint32_t* B, scalar_t* C, - scalar_t* B_scale, - scalar_t* B_zero_point, + const scalar_t* B_scale, + const scalar_t* B_zero_point, int m, int n, int k, @@ -49,10 +49,10 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { auto cptr = C + g_n; if constexpr (std::is_same_v) { sycl::half2 tmpAcc = {0.f, 0.f}; + uint8_t tmps8[TileK / 2]; for (int i = 0; i < k; i += GroupK * Unroll) { #pragma unroll for (int iu = 0; iu < Unroll; iu++) { - uint8_t tmps8[TileK / 2]; *(sycl::vec*)tmps8 = *(sycl::vec*)(bptr + sg_id * TileK / 2); scalar_t scale = *(sptr + sg_id * TileK / blocksize); @@ -109,11 +109,11 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { } private: - scalar_t* A; - uint32_t* B; + const scalar_t* A; + const uint32_t* B; scalar_t* C; - scalar_t* B_scale; - scalar_t* B_zero_point; + const scalar_t* B_scale; + const scalar_t* B_zero_point; int m; int n; int k; @@ -142,7 +142,7 @@ void linear_int4_kernel( if (input.scalar_type() == at::kHalf) { using scalar_t = at::Half; // const auto scalar_t = input.scalar_type(); - scalar_t* input_data = input.data_ptr(); + const scalar_t* input_data = input.data_ptr(); uint32_t* weight_data = weight.data_ptr(); // int4x8 scalar_t* output_data = output.data_ptr();