sync UT with pytorch UT (linalg)
sunjiweiswift committed Dec 12, 2024
1 parent 9e50b68 commit 81a72f1
Showing 3 changed files with 102 additions and 120 deletions.
68 changes: 29 additions & 39 deletions src/ATen/native/xpu/LinearInt4.cpp
@@ -9,28 +9,26 @@

namespace at::native {
Tensor _weight_int4pack_mm_xpu(
const Tensor& input,
const Tensor& weight,
const Tensor& A,
const Tensor& B,
int64_t qGroupSize,
const Tensor& weight_scale_zero_point) {
auto M = input.size(0);
auto N = weight.size(0);
auto K = input.size(1);
const Tensor& qScaleAndZeros) {
auto M = A.size(0);
auto N = B.size(0);
auto K = A.size(1);
TORCH_CHECK(
input.dtype() == kBFloat16 || input.dtype() == kHalf ||
input.dtype() == kFloat,
A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
__func__,
" : expect input to be either 32-bit or 16-bit float tensor.");
" : expect A to be either 32-bit or 16-bit float tensor.");
TORCH_CHECK(A.is_contiguous(), __func__, " : expect A to be contiguous.");
TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor.");

TORCH_CHECK(
weight.dtype() == kByte, __func__, " : expect B to be uint8 tensor.");
TORCH_CHECK(
weight.is_contiguous(), __func__, " : expect B to be contiguous.");
TORCH_CHECK(
weight.size(1) == K / 2,
B.dtype() == kInt || B.dtype() == kUInt32,
__func__,
" : expect B.size(1) to be K/2, got ",
weight.size(1));
" : expect B to be int32 or uint32 tensor.");
TORCH_CHECK(B.is_contiguous(), __func__, " : expect B to be contiguous.");
TORCH_CHECK(B.dim() == 4, __func__, " : expect B to 4d tensor.");

TORCH_CHECK(
qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
@@ -39,35 +37,27 @@ Tensor _weight_int4pack_mm_xpu(
": expect qGroupSize to be 32, 64, 128 or 256, got ",
qGroupSize);

TORCH_CHECK(
weight_scale_zero_point.dim() == 3 &&
weight_scale_zero_point.size(1) == N &&
weight_scale_zero_point.size(2) == 2,
__func__,
": expect weight_scale_zero_point to be 3d tensor with sizes [:, ",
N,
", 2]");
// TORCH_CHECK(
// qScaleAndZeros.dim() == 3 && qScaleAndZeros.size(1) == N &&
// qScaleAndZeros.size(2) == 2,
// __func__,
// ": expect qScaleAndZeros to be 3d tensor with sizes [:, ",
// N,
// ", 2]");

std::optional<Device> common_device = std::nullopt;
c10::impl::check_and_update_common_device(
common_device,
input,
"xpu::_weight_int4pack_mm_with_scales_and_zeros",
"input");
common_device, A, "xpu::_weight_int4pack_mm", "A");
c10::impl::check_and_update_common_device(
common_device,
weight,
"xpu::_weight_int4pack_mm_with_scales_and_zeros",
"weight");
common_device, B, "xpu::_weight_int4pack_mm", "B");
c10::impl::check_and_update_common_device(
common_device,
weight_scale_zero_point,
"xpu::_weight_int4pack_mm_with_scales_and_zeros",
"weight_scale_zero_point");
Tensor output = at::empty({M, N}, input.options());
qScaleAndZeros,
"xpu::_weight_int4pack_mm",
"qScaleAndZeros");
Tensor C = at::empty({M, N}, A.options());

at::native::xpu::linear_int4_kernel(
input, weight, qGroupSize, weight_scale_zero_point, output);
return output;
at::native::xpu::linear_int4_kernel(A, B, qGroupSize, qScaleAndZeros, C);
return C;
}
} // namespace at::native
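
Note: semantically, the rewritten entry point computes C = A @ dequant(B), dequantizing B group-by-group along K with the per-group (scale, zero) pairs in qScaleAndZeros (the commented-out check above expects a [k / qGroupSize, N, 2] table). The snippet below is a self-contained toy reference for that contract only, assuming a simple affine per-group quantizer; the codes stay unpacked in a uint8 carrier, so it says nothing about the backend's actual int4 packing or kernel.

import torch

def group_quantize_ref(w, q_group=32, n_bit=4):
    # Toy per-group affine quantizer along K. w: [K, N] float.
    # Returns codes in [0, 2**n_bit - 1] plus a [K // q_group, N, 2] table of
    # (scale, zero) pairs -- the shape the op expects for qScaleAndZeros.
    K, N = w.shape
    g = w.reshape(K // q_group, q_group, N)
    w_min, w_max = g.amin(dim=1), g.amax(dim=1)
    scale = (w_max - w_min).clamp(min=1e-6) / (2**n_bit - 1)
    zero = w_min
    q = ((g - zero[:, None, :]) / scale[:, None, :]).round().clamp(0, 2**n_bit - 1)
    return q.to(torch.uint8), torch.stack([scale, zero], dim=-1)

def weight_int4_mm_ref(a, q, scale_and_zero):
    # Reference contract: C = A @ dequant(B), applying each K-group's
    # (scale, zero) before the matmul.
    scale = scale_and_zero[..., 0][:, None, :]
    zero = scale_and_zero[..., 1][:, None, :]
    w = (q.float() * scale + zero).reshape(-1, q.shape[-1])
    return a @ w

m, k, n, q_group = 8, 128, 64, 32
a, w = torch.rand(m, k), torch.rand(k, n)
q, sz = group_quantize_ref(w, q_group)
res, ref = weight_int4_mm_ref(a, q, sz), a @ w
assert ((res - ref).abs() / ref).mean() < 0.05  # same tolerance the synced test uses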
114 changes: 56 additions & 58 deletions src/ATen/native/xpu/sycl/LinearInt4.cpp
@@ -7,7 +7,7 @@ template <typename scalar_t = at::Half, int block_size = 16>
struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
LinearInt4KernelFunctor(
const scalar_t* A,
const uint32_t* B,
const int32_t* B,
scalar_t* C,
const scalar_t* B_scale,
const scalar_t* B_zero_point,
@@ -71,7 +71,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
}
sycl::half2 sum = {0.f, 0.f};
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
sum += group_broadcast(sg, tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum[0] + sum[1];
@@ -100,7 +100,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
}
float sum = 0.f;
for (int i = 0; i < SgSize; i += 1) {
sum += sg.shuffle(tmpAcc, i);
sum += group_broadcast(sg, tmpAcc, i);
}
if (sg_id == 0) {
*cptr = sum;
@@ -110,7 +110,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {

private:
const scalar_t* A;
const uint32_t* B;
const int32_t* B;
scalar_t* C;
const scalar_t* B_scale;
const scalar_t* B_zero_point;
@@ -123,15 +123,15 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
};

void linear_int4_kernel(
const Tensor& input,
const Tensor& weight,
const Tensor& A,
const Tensor& B,
int qGroupSize,
const Tensor& weight_scale_zero_point,
Tensor& output) {
const Tensor& qScaleAndZeros,
Tensor& C) {
auto& sycl_queue = at::xpu::getCurrentSYCLQueue();
int64_t m = input.size(0);
int64_t n = input.size(1);
int64_t k = output.size(1);
int64_t m = A.size(0);
int64_t n = A.size(1);
int64_t k = C.size(1);
int constexpr Unroll = 2;
int constexpr SgSize = 16;
sycl::range<1> local_range{SgSize};
@@ -140,57 +140,55 @@ void linear_int4_kernel(
int ldb = n;
int ldc = n;

AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::Half, input.scalar_type(), "linear_int4_kernel", [&]() {
using scalar_t = at::Half;
const scalar_t* input_data = input.data_ptr<scalar_t>();
uint32_t* weight_data = weight.data_ptr<uint32_t>(); // int4x8
// AT_DISPATCH_FLOATING_TYPES_AND(
// at::ScalarType::Half, A.scalar_type(), "linear_int4_kernel", [&]() {
if (A.scalar_type() == at::ScalarType::Half) {
using scalar_t = at::Half;
const scalar_t* input_data = A.data_ptr<scalar_t>();
int32_t* weight_data = B.data_ptr<int32_t>(); // int4x8

scalar_t* output_data = output.data_ptr<scalar_t>();
scalar_t* weight_scale_data =
weight_scale_zero_point.data_ptr<scalar_t>();
LinearInt4KernelFunctor<scalar_t, 16> kfn(
input_data,
weight_data,
output_data,
weight_scale_data,
nullptr,
m,
n,
k,
k,
n,
n);
scalar_t* output_data = C.data_ptr<scalar_t>();
scalar_t* weight_scale_data = qScaleAndZeros.data_ptr<scalar_t>();
LinearInt4KernelFunctor<scalar_t, 16> kfn(
input_data,
weight_data,
output_data,
weight_scale_data,
nullptr,
m,
n,
k,
k,
n,
n);

sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
});
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::BFloat16,
input.scalar_type(),
"linear_int4_kernel",
[&]() {
using scalar_t = at::BFloat16;
const scalar_t* input_data = input.data_ptr<scalar_t>();
uint32_t* weight_data = weight.data_ptr<uint32_t>(); // int4x8
sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
}
// AT_DISPATCH_FLOATING_TYPES_AND(
// at::ScalarType::BFloat16, A.scalar_type(), "linear_int4_kernel", [&]()
// {
else if (A.scalar_type() == at::ScalarType::BFloat16) {
using scalar_t = at::BFloat16;
const scalar_t* input_data = A.data_ptr<scalar_t>();
int32_t* weight_data = B.data_ptr<int32_t>(); // int4x8

scalar_t* output_data = output.data_ptr<scalar_t>();
scalar_t* weight_scale_data =
weight_scale_zero_point.data_ptr<scalar_t>();
LinearInt4KernelFunctor<scalar_t, 16> kfn(
input_data,
weight_data,
output_data,
weight_scale_data,
nullptr,
m,
n,
k,
k,
n,
n);
scalar_t* output_data = C.data_ptr<scalar_t>();
scalar_t* weight_scale_data = qScaleAndZeros.data_ptr<scalar_t>();
LinearInt4KernelFunctor<scalar_t, 16> kfn(
input_data,
weight_data,
output_data,
weight_scale_data,
nullptr,
m,
n,
k,
k,
n,
n);

sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
});
sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
}
}

} // namespace at::native::xpu
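
Note: after this change the functor reads B as int32, but each 32-bit word still carries eight packed 4-bit codes (the "// int4x8" comments), and this path constructs the functor with nullptr for B_zero_point, so only the per-group scales from qScaleAndZeros are applied. A rough host-side sketch of that int4x8 unpacking follows; the low-to-high nibble order is an assumption, since the actual layout is whatever the packing op produced.

import torch

def unpack_int4x8(words):
    # Unpack eight 4-bit codes from every 32-bit word of an int32 tensor,
    # treating each word as an unsigned 32-bit pattern. The result gains a
    # trailing dimension of size 8 with values in [0, 15].
    w = words.to(torch.int64) & 0xFFFFFFFF
    shifts = torch.arange(0, 32, 4, dtype=torch.int64)
    return ((w.unsqueeze(-1) >> shifts) & 0xF).to(torch.uint8)

word = torch.tensor([0x76543210], dtype=torch.int32)
print(unpack_int4x8(word))  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.uint8)

The subgroup reduction also moves from the sub_group shuffle member function to the group_broadcast free function; the loop still sums every lane's accumulator, group_broadcast being the SYCL 2020 group-algorithm way to read a chosen work-item's value.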
40 changes: 17 additions & 23 deletions test/xpu/test_linalg_xpu.py
@@ -244,44 +244,38 @@ def convert_weight_to_int4pack(b):
return b_int4pack, b_scales_and_zeros

def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
self.assertTrue(b_int4pack.dtype is torch.int32)
# self.assertTrue(b_int4pack.dim() == 4)
return torch._weight_int4pack_mm(
a, b_int4pack, q_group, b_scales_and_zeros
)
if self.device_type == 'cpu':
self.assertTrue(b_int4pack.dtype is torch.uint8)
self.assertTrue(b_int4pack.dim() == 2)
return torch._weight_int4pack_mm_for_cpu(
a, b_int4pack, q_group, b_scales_and_zeros
)
else:
self.assertTrue(b_int4pack.dtype is torch.int32)
self.assertTrue(b_int4pack.dim() == 4)
return torch._weight_int4pack_mm(
a, b_int4pack, q_group, b_scales_and_zeros
)

dtype = torch.bfloat16
q_group = 32
inner_k_tiles = 2

torch.manual_seed(1)
a_bf16 = torch.rand((m, k), dtype=dtype, device=device)
b_int4 = rand_int4(k * n, torch.int32, "xpu").reshape(k // 8, n)
group_num = int(k / q_group)

scales = torch.rand([group_num, n], device="xpu", dtype=dtype)
zero_points = rand_int4(group_num * n, torch.int32, "xpu").reshape(
group_num, n // 8
)

b_bf16 = dequantize(b_int4, scales, zero_points, q_group).cpu()

# b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16)
b_bf16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)
b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16)

for dtype in [torch.bfloat16] + ([torch.float16, torch.float32] if device == "cpu" else []):
a = a_bf16.to(dtype=dtype)
b = b_bf16.to(dtype=dtype)
# b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype)
res = weight_int4pack_mm(a, b_int4, scales)
ref = torch.mm(a_bf16, b_bf16)
b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype)
ref = torch.mm(a, b)
res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)

mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)





@dtypes(torch.float, torch.complex64) # Integer matmul just supported on CPU
@setBlasBackendsToDefaultFinally
def matmul_small_brute_force_1d_Nd(self, device, dtype):
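Note: a hedged end-to-end sketch of the path the synced test exercises on XPU. The _group_quantize_tensor import and the packing call mirror what the test's convert_weight_to_int4pack helper (not fully shown above) does for non-CPU devices in upstream test_linalg; the sizes are illustrative, while q_group, inner_k_tiles, and the 5% tolerance are taken from the test.

import torch
from torch.testing._internal.common_quantization import _group_quantize_tensor

m, k, n = 32, 64, 48        # illustrative sizes (the real test parametrizes these)
q_group, inner_k_tiles = 32, 2
device = "xpu"

a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
b = torch.rand((k, n), dtype=torch.bfloat16, device=device)

# Group-quantize b to 4 bit, then pack it into the backend-specific layout
# consumed by torch._weight_int4pack_mm (a 4-D int32 tensor on non-CPU devices).
b_q, b_scales_and_zeros = _group_quantize_tensor(b, n_bit=4, q_group_size=q_group)
b_int4pack = torch._convert_weight_to_int4pack(b_q, inner_k_tiles)

res = torch._weight_int4pack_mm(a, b_int4pack, q_group, b_scales_and_zeros)
ref = torch.mm(a, b)
assert ((res - ref).abs() / ref).mean() < 0.05  # same 5% tolerance as the test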
