Skip to content

Commit

Permalink
Address review comments: validate linear_int4 inputs, const-qualify kernel input pointers, and allocate the output as {M, N}
Browse files Browse the repository at this point in the history
  • Loading branch information
sunjiweiswift committed Dec 9, 2024
1 parent 16cf764 commit 5a08d2e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 11 deletions.
37 changes: 36 additions & 1 deletion src/ATen/native/xpu/LinearInt4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,41 @@ Tensor& linear_int4_xpu(
const Tensor& weight,
int qGroupSize,
const Tensor& weight_scale_zero_point) {
auto M = input.size(0);
auto N = weight.size(0);
auto K = input.size(1);
TORCH_CHECK(
input.dtype() == kBFloat16 || input.dtype() == kHalf ||
input.dtype() == kFloat,
__func__,
" : expect input to be either 32-bit or 16-bit float tensor.");

TORCH_CHECK(
weight.dtype() == kByte, __func__, " : expect B to be uint8 tensor.");
TORCH_CHECK(
weight.is_contiguous(), __func__, " : expect B to be contiguous.");
TORCH_CHECK(
weight.size(1) == K / 2,
__func__,
" : expect B.size(1) to be K/2, got ",
weight.size(1));

TORCH_CHECK(
qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
qGroupSize == 256,
__func__,
": expect qGroupSize to be 32, 64, 128 or 256, got ",
qGroupSize);

TORCH_CHECK(
weight_scale_zero_point.dim() == 3 &&
weight_scale_zero_point.size(1) == N &&
weight_scale_zero_point.size(2) == 2,
__func__,
": expect weight_scale_zero_point to be 3d tensor with sizes [:, ",
N,
", 2]");

std::optional<Device> common_device = std::nullopt;
c10::impl::check_and_update_common_device(
common_device, input, "xpu::linear_int4", "input");
Expand All @@ -23,7 +58,7 @@ Tensor& linear_int4_xpu(
weight_scale_zero_point,
"xpu::linear_int4",
"weight_scale_zero_point");
Tensor output = at::empty({0}, input.options());
Tensor output = at::empty({M, N}, input.options());

at::native::xpu::linear_int4_kernel(
input, weight, qGroupSize, weight_scale_zero_point, output);
Expand Down
20 changes: 10 additions & 10 deletions src/ATen/native/xpu/sycl/LinearInt4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ namespace at::native::xpu {
template <typename scalar_t = at::Half, int block_size = 16>
struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
LinearInt4KernelFunctor(
scalar_t* A,
uint32_t* B,
const scalar_t* A,
const uint32_t* B,
scalar_t* C,
scalar_t* B_scale,
scalar_t* B_zero_point,
const scalar_t* B_scale,
const scalar_t* B_zero_point,
int m,
int n,
int k,
Expand Down Expand Up @@ -49,10 +49,10 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
auto cptr = C + g_n;
if constexpr (std::is_same_v<scalar_t, sycl::half>) {
sycl::half2 tmpAcc = {0.f, 0.f};
uint8_t tmps8[TileK / 2];
for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
for (int iu = 0; iu < Unroll; iu++) {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
scalar_t scale = *(sptr + sg_id * TileK / blocksize);
Expand Down Expand Up @@ -109,11 +109,11 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
}

private:
scalar_t* A;
uint32_t* B;
const scalar_t* A;
const uint32_t* B;
scalar_t* C;
scalar_t* B_scale;
scalar_t* B_zero_point;
const scalar_t* B_scale;
const scalar_t* B_zero_point;
int m;
int n;
int k;
Expand Down Expand Up @@ -142,7 +142,7 @@ void linear_int4_kernel(
if (input.scalar_type() == at::kHalf) {
using scalar_t = at::Half;
// const auto scalar_t = input.scalar_type();
scalar_t* input_data = input.data_ptr<scalar_t>();
const scalar_t* input_data = input.data_ptr<scalar_t>();
uint32_t* weight_data = weight.data_ptr<uint32_t>(); // int4x8

scalar_t* output_data = output.data_ptr<scalar_t>();
Expand Down

0 comments on commit 5a08d2e

Please sign in to comment.