Commit

modify UT for int4
sunjiweiswift committed Dec 11, 2024
1 parent f912813 commit 2424d54
Showing 1 changed file with 112 additions and 0 deletions.
112 changes: 112 additions & 0 deletions test/xpu/test_linalg_xpu.py
@@ -6,6 +6,7 @@
from torch.testing._internal.common_dtype import floating_and_complex_types_and
from torch.testing._internal.common_cuda import tf32_on_and_off
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
from torch.testing import make_tensor
import unittest
import itertools
@@ -171,6 +172,116 @@ def _test(m, k, n, transpose_a, transpose_b, test_equal=True):
    if not use_transpose_a and not use_transpose_b:
        _test(17, k, n, use_transpose_a, use_transpose_b)

@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
@parametrize("m", [32])
@parametrize("k", [32])
@parametrize("n", [32])
def _int4_mm(self, device, m, k, n):
    def rand_int4(size, dtype=torch.int32, device="xpu"):
        rand = torch.randint(-128, 128, [size // 2], device=device).to(torch.int8)
        return rand.view(dtype=dtype)
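    # Note: rand_int4 draws `size` random int4 values packed 8 per int32 word
    # (size // 2 random int8 bytes reinterpreted as size // 8 int32s).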

    def unpack_weight(qweight, scales, qzeros, q_config):
        group_size = q_config["group_size"]
        bits = q_config["bits"]
        s32_bits = 32

        assert bits == 4
        # An int32 can store 8 * 4-bit values; wf holds the bit offset of each one.
        wf = (
            torch.tensor(list(range(0, s32_bits, bits)), dtype=torch.int32)
            .unsqueeze(0)
            .to("xpu")
        )
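        # For bits=4, wf == [[0, 4, 8, 12, 16, 20, 24, 28]], so element i of a
        # packed int32 is recovered below as (word >> (4 * i)) & 0xF.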
        zeros = torch.bitwise_right_shift(
            torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)
        ).to(torch.int16 if bits == 8 else torch.int8)
        torch.bitwise_and(zeros, (2**bits) - 1, out=zeros)

        zeros = zeros + 1
        zeros = zeros.reshape(scales.shape)

        weight = torch.bitwise_right_shift(
            torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)
        ).to(torch.int16 if bits == 8 else torch.int8)
        torch.bitwise_and(weight, (2**bits) - 1, out=weight)

        return weight, scales, zeros

    def dequantize(qweight, scales, qzeros, group_size):
        q_config = {"group_size": group_size, "bits": 4}
        weight, gptq_scales, gptq_zeros = unpack_weight(
            qweight, scales, qzeros, q_config
        )
        gptq_zeros = (torch.ones_like(gptq_zeros) * 8).to("xpu")  # TODO: zero point hard-coded to 8
        if len(weight.shape) > 2:
            weight = weight.reshape(-1, weight.shape[-1])
        infeatures = weight.shape[0]
        g_idx = torch.tensor(
            [i // q_config["group_size"] for i in range(infeatures)],
            dtype=torch.int32,
        )
        scale_zeros = gptq_zeros * gptq_scales
        # Affine dequantization per group g:
        # scales[g] * q - scales[g] * zp[g] == scales[g] * (q - zp[g]).
        weight = gptq_scales[g_idx.long()] * weight - scale_zeros[g_idx.long()]
        return weight

    def convert_weight_to_int4pack(b):
        b_tmp, b_scales_and_zeros = _group_quantize_tensor(
            b, n_bit=4, q_group_size=q_group
        )
        if self.device_type == 'cpu':
            b_int4pack = torch._convert_weight_to_int4pack_for_cpu(
                b_tmp, inner_k_tiles
            )
        else:
            b_int4pack = torch._convert_weight_to_int4pack(
                b_tmp, inner_k_tiles
            )

        return b_int4pack, b_scales_and_zeros

    def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
        self.assertTrue(b_int4pack.dtype is torch.int32)
        # self.assertTrue(b_int4pack.dim() == 4)
        return torch._weight_int4pack_mm(
            a, b_int4pack, q_group, b_scales_and_zeros
        )
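    # torch._weight_int4pack_mm multiplies `a` by the packed int4 weight,
    # dequantizing per quantization group of size q_group with the
    # scale/zero-point pairs in b_scales_and_zeros.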

    dtype = torch.bfloat16
    q_group = 32
    inner_k_tiles = 2

    torch.manual_seed(1)
    a_bf16 = torch.rand((m, k), dtype=dtype, device=device)
    b_int4 = rand_int4(k * n, torch.int32, "xpu").reshape(k // 8, n)
    group_num = k // q_group

    scales = torch.rand([group_num, n], device="xpu", dtype=dtype)
    zero_points = rand_int4(group_num * n, torch.int32, "xpu").reshape(
        group_num, n // 8
    )

    b_bf16 = dequantize(b_int4, scales, zero_points, q_group).cpu()

    # b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16)

    for dtype in [torch.bfloat16] + ([torch.float16, torch.float32] if device == "cpu" else []):
        a = a_bf16.to(dtype=dtype)
        b = b_bf16.to(dtype=dtype)
        # b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype)
        res = weight_int4pack_mm(a, b_int4, scales)
        ref = torch.mm(a_bf16, b_bf16)

        mean_err = ((res - ref).abs() / ref).mean()
        self.assertTrue(mean_err < 0.05)

@dtypes(torch.float, torch.complex64)  # Integer matmul is only supported on CPU
@setBlasBackendsToDefaultFinally
def matmul_small_brute_force_1d_Nd(self, device, dtype):
@@ -229,6 +340,7 @@ def ck_blas_library(self):
TestLinalg.test_preferred_linalg_library=preferred_linalg_library
TestLinalg.test_addbmm=addbmm
TestLinalg.test__int_mm=_int_mm
TestLinalg.test__int4_mm=_int4_mm
TestLinalg.test_matmul_small_brute_force_1d_Nd=matmul_small_brute_force_1d_Nd
TestLinalg.test_matmul_small_brute_force_2d_Nd=matmul_small_brute_force_2d_Nd
TestLinalg.test_matmul_small_brute_force_3d_Nd=matmul_small_brute_force_3d_Nd
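Assuming a standard pytest setup for this test suite, the newly registered case can be run on its own with, for example:

    pytest test/xpu/test_linalg_xpu.py -k test__int4_mm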
