diff --git a/test/xpu/test_linalg_xpu.py b/test/xpu/test_linalg_xpu.py
index 3c1c8aed7..bff4eacf4 100644
--- a/test/xpu/test_linalg_xpu.py
+++ b/test/xpu/test_linalg_xpu.py
@@ -6,6 +6,7 @@
 from torch.testing._internal.common_dtype import floating_and_complex_types_and
 from torch.testing._internal.common_cuda import tf32_on_and_off
 from torch.testing._internal.common_mkldnn import bf32_on_and_off
+from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
 from torch.testing import make_tensor
 import unittest
 import itertools
@@ -171,6 +172,116 @@ def _test(m, k, n, transpose_a, transpose_b, test_equal=True):
         if not use_transpose_a and not use_transpose_b:
             _test(17, k, n, use_transpose_a, use_transpose_b)

+@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
+@parametrize("m", [32])
+@parametrize("k", [32])
+@parametrize("n", [32])
+def _int4_mm(self, device, m, k, n):
+    @staticmethod
+    def rand_int4(size, dtype=torch.int32, device="xpu"):
+        rand = torch.randint(-128, 128, [size // 2], device=device).to(torch.int8)
+        return rand.view(dtype=dtype)
+
+    @staticmethod
+    def unpack_weight(qweight, scales, qzeros, q_config):
+        group_size = q_config["group_size"]
+        bits = q_config["bits"]
+        s32_bits = 32
+
+        assert bits == 4
+        # Int32 can store 8 * 4bits data. This is the offset for each data.
+        wf = (
+            torch.tensor(list(range(0, s32_bits, bits)), dtype=torch.int32)
+            .unsqueeze(0)
+            .to("xpu")
+        )
+        zeros = torch.bitwise_right_shift(
+            torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)
+        ).to(torch.int16 if bits == 8 else torch.int8)
+        torch.bitwise_and(zeros, (2**bits) - 1, out=zeros)
+
+        zeros = zeros + 1
+        zeros = zeros.reshape(scales.shape)
+
+        weight = torch.bitwise_right_shift(
+            torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)
+        ).to(torch.int16 if bits == 8 else torch.int8)
+        torch.bitwise_and(weight, (2**bits) - 1, out=weight)
+
+        return weight, scales, zeros
+
+    @staticmethod
+    def dequantize(qweight, scales, qzeros, group_size):
+        q_config = {"group_size": group_size, "bits": 4}
+        weight, gptq_scales, gptq_zeros = unpack_weight(
+            qweight, scales, qzeros, q_config
+        )
+        gptq_zeros = (torch.ones_like(gptq_zeros) * 8).to("xpu")  # TODO: hard code zp
+        if len(weight.shape) > 2:
+            weight = weight.reshape(-1, weight.shape[-1])
+        infeatures = weight.shape[0]
+        g_idx = torch.tensor(
+            [i // q_config["group_size"] for i in range(infeatures)],
+            dtype=torch.int32,
+        )
+        scale_zeros = gptq_zeros * gptq_scales
+        weight = gptq_scales[g_idx.long()] * weight - scale_zeros[g_idx.long()]
+        return weight
+
+    def convert_weight_to_int4pack(b):
+        b_tmp, b_scales_and_zeros = _group_quantize_tensor(
+            b, n_bit=4, q_group_size=q_group
+        )
+        if self.device_type == 'cpu':
+            b_int4pack = torch._convert_weight_to_int4pack_for_cpu(
+                b_tmp, inner_k_tiles
+            )
+        else:
+            b_int4pack = torch._convert_weight_to_int4pack(
+                b_tmp, inner_k_tiles
+            )
+
+        return b_int4pack, b_scales_and_zeros
+
+    def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
+        self.assertTrue(b_int4pack.dtype is torch.int32)
+        # self.assertTrue(b_int4pack.dim() == 4)
+        return torch._weight_int4pack_mm(
+            a, b_int4pack, q_group, b_scales_and_zeros
+        )
+
+    dtype = torch.bfloat16
+    q_group = 32
+    inner_k_tiles = 2
+
+    torch.manual_seed(1)
+    a_bf16 = torch.rand((m, k), dtype=dtype, device=device)
+    b_int4 = rand_int4(k * n, torch.int32, "xpu").reshape(k // 8, n)
+    group_num = int(k / q_group)
+
+    scales = torch.rand([group_num, n], device="xpu", dtype=dtype)
+    zero_points = rand_int4(group_num * n, torch.int32, "xpu").reshape(
+        group_num, n // 8
+    )
+
+    b_bf16 = dequantize(b_int4, scales, zero_points, q_group).cpu()
+
+    # b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16)
+
+    for dtype in [torch.bfloat16] + ([torch.float16, torch.float32] if device == "cpu" else []):
+        a = a_bf16.to(dtype=dtype)
+        b = b_bf16.to(dtype=dtype)
+        # b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype)
+        res = weight_int4pack_mm(a, b_int4, scales)
+        ref = torch.mm(a_bf16, b_bf16)
+
+        mean_err = ((res - ref).abs() / ref).mean()
+        self.assertTrue(mean_err < 0.05)
+
+
+
+
+
 @dtypes(torch.float, torch.complex64)  # Integer matmul just supported on CPU
 @setBlasBackendsToDefaultFinally
 def matmul_small_brute_force_1d_Nd(self, device, dtype):
@@ -229,6 +340,7 @@ def ck_blas_library(self):
 TestLinalg.test_preferred_linalg_library=preferred_linalg_library
 TestLinalg.test_addbmm=addbmm
 TestLinalg.test__int_mm=_int_mm
+TestLinalg.test__int4_mm=_int4_mm
 TestLinalg.test_matmul_small_brute_force_1d_Nd=matmul_small_brute_force_1d_Nd
 TestLinalg.test_matmul_small_brute_force_2d_Nd=matmul_small_brute_force_2d_Nd
 TestLinalg.test_matmul_small_brute_force_3d_Nd=matmul_small_brute_force_3d_Nd
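Note on the packing scheme the new test relies on: unpack_weight recovers eight 4-bit values from each int32 word by right-shifting it by 0, 4, ..., 28 bits and masking with 0xF. Below is a minimal standalone sketch of that shift-and-mask step, illustrative only and not part of the patch; the toy value and the names packed, shifted, and unpacked are made up for the example, and it runs on CPU.

import torch

bits = 4
# Bit offsets 0, 4, 8, ..., 28, one per 4-bit value packed into an int32,
# mirroring the `wf` tensor built inside unpack_weight().
wf = torch.arange(0, 32, bits, dtype=torch.int32)

packed = torch.tensor([0x76543210], dtype=torch.int32)          # nibbles 0..7
shifted = torch.bitwise_right_shift(packed.unsqueeze(-1), wf)   # move each nibble to the low bits
unpacked = torch.bitwise_and(shifted, (2**bits) - 1).to(torch.int8)
print(unpacked)  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.int8)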