linear_int4_kernel for XPU #1130

Open. Wants to merge 25 commits into base: main.


@sunjiweiswift sunjiweiswift commented Nov 29, 2024

Pure SYCL path for int4 GEMM.

Benchmark results on PVC-1100. The remaining gap comes from not yet using 2D loads.

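| M | K | N | SrcT | WeiT | DstT | Bandwidth usage |
|---|---|---|------|------|------|-----------------|
| 1 | 4096 | 4096 | float16 | float16 | float16 | 53.7% |
| 1 | 4096 | 11008 | float16 | float16 | float16 | 57.4% |
| 1 | 4096 | 16384 | float16 | float16 | float16 | 59.7% |
| 1 | 12288 | 4096 | float16 | float16 | float16 | 77.3% |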

Besides PVC, the kernel achieves:
- 92.7% bandwidth usage on MTL
- 84.7% bandwidth usage on A750
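
For readers who want to reproduce a bandwidth-usage figure, here is a minimal sketch of how such a percentage could be derived. It is not taken from the PR: the function names, the `group_size=128` default, and the choice to count only weights, scales, activations, and outputs (no zero points) are assumptions for illustration, and the elapsed time and peak bandwidth must come from your own measurement and hardware specification.

```python
def int4_gemm_bytes(m, k, n, group_size=128, elem_bytes=2):
    """Approximate bytes moved by an int4-weight GEMM (assumed traffic model)."""
    weight_bytes = k * n // 2                          # two int4 values per byte
    scale_bytes = (k // group_size) * n * elem_bytes   # one fp16 scale per group
    act_bytes = m * k * elem_bytes                     # fp16 activations
    out_bytes = m * n * elem_bytes                     # fp16 output
    return weight_bytes + scale_bytes + act_bytes + out_bytes


def bandwidth_usage(m, k, n, elapsed_s, peak_bytes_per_s, group_size=128):
    """Fraction of peak memory bandwidth achieved by one kernel run."""
    achieved = int4_gemm_bytes(m, k, n, group_size) / elapsed_s
    return achieved / peak_bytes_per_s
```

For M=1 the weight term dominates, which is why this shape is treated as memory-bound and reported as a share of peak bandwidth rather than as FLOPS.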

Reset to
bfdbaf4

---------

Co-authored-by: mengfei25 <[email protected]>
Co-authored-by: LuFengqing <[email protected]>
Co-authored-by: Ratnam Parikh <[email protected]>
Co-authored-by: Feng Yuan <[email protected]>
@sunjiweiswift changed the title from "Fp zp" to "linear_int4_kernel for XPU" on Nov 29, 2024

@mingfeima mingfeima left a comment


The biggest question is why we need post-op fusion here. Does PyTorch have it with CUDA?

src/ATen/native/xpu/sycl/LinearInt4.cpp
auto aptr = A;
auto cptr = C + g_n;
if constexpr (std::is_same_v<scalar_t, sycl::half>) {
sycl::half2 tmpAcc = {0.f, 0.f};


Is it safe to use half as the accumulation type?

Usually, the accumulation type for both float16 and bfloat16 is float32.

ref: https://github.com/pytorch/pytorch/blob/795f28ac552eb61d02ea02fd64637ba814133bd8/aten/src/ATen/native/cuda/int4mm.cu#L727
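
To make the concern concrete, here is a small illustrative sketch (not the PR's kernel) that emulates a float16 accumulator and a float32 accumulator using only Python's `struct` module. The input values and the single sequential accumulator are assumptions chosen to expose the effect; the actual kernel splits K across work-items and reduces partial sums, so its error would differ.

```python
import struct


def to_half(x):
    # Round a Python float to IEEE float16 and back.
    return struct.unpack("e", struct.pack("e", x))[0]


def to_float(x):
    # Round a Python float to IEEE float32 and back.
    return struct.unpack("f", struct.pack("f", x))[0]


def dot(a, w, round_acc):
    acc = 0.0
    for x, y in zip(a, w):
        prod = to_half(x * y)        # products are kept in float16
        acc = round_acc(acc + prod)  # accumulator rounded to the chosen type
    return acc


K = 4096
a = [0.5] * K  # hypothetical activations
w = [0.5] * K  # hypothetical dequantized weights, e.g. (nibble 9 - 8) * scale 0.5

half_result = dot(a, w, to_half)    # float16 accumulator
float_result = dot(a, w, to_float)  # float32 accumulator
```

With these inputs the float16 accumulator stalls at 512.0, because above 512 the float16 spacing is 0.5 and each added 0.25 rounds away, while the float32 accumulator reaches the exact sum of 1024.0.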

*cptr = sum[0] + sum[1];
}
} else {
scalar_t tmpAcc = 0.f;


You need to be VERY careful about the accumulation type. A slight difference from CUDA may lead to accuracy errors that are very difficult to debug in a final end-to-end model, especially an LLM.

Comment on lines 60 to 65
for (int ikk = 0; ikk < TileK; ikk += 2) {
sycl::half2 tmpA = *(sycl::half2*)&aptr[sg_id * TileK + ikk];
sycl::half2 tmpB = {
static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
tmpAcc += tmpA * tmpB * scale;


Is it possible to do a vectorized load and shift with SYCL? I don't know.
If not, I guess this is the best performance we can get so far. This line should be the major bottleneck.


That depends on IGC auto-vectorization.
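
As an illustration of what a "vectorized load and shift" could look like, here is a sketch in plain Python (not SYCL, and not the PR's code). It compares the per-byte unpacking used in the loop above with a word-wide variant that extracts the nibbles of four packed bytes using two masks and one shift. The function names are hypothetical, the packed bytes are assumed to be unsigned, and the zero point of 8 follows the snippet under review.

```python
def unpack_byte(b):
    """Return (low, high) int4 values from one packed byte, zero point 8."""
    return (b & 0x0F) - 8, (b >> 4) - 8


def unpack_word(w):
    """Unpack eight int4 values from a little-endian 32-bit word."""
    lo = w & 0x0F0F0F0F         # low nibbles of all four bytes at once
    hi = (w >> 4) & 0x0F0F0F0F  # high nibbles of all four bytes at once
    out = []
    for i in range(4):
        out.append(((lo >> (8 * i)) & 0xFF) - 8)
        out.append(((hi >> (8 * i)) & 0xFF) - 8)
    return out


packed = bytes([0x10, 0x32, 0x54, 0xF6])
word = int.from_bytes(packed, "little")
per_byte = [v for b in packed for v in unpack_byte(b)]
per_word = unpack_word(word)
```

Both paths produce the same eight values; the word-wide form simply does the mask and shift once per four bytes, which is the kind of transformation a compiler's auto-vectorizer may or may not apply on its own.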

zero_points = torch.Tensor([8]).to(torch.int8).to("xpu")
weight_ba = weight.transpose(0, 1).contiguous()

out_onednn =torch._weight_int4pack_mm_with_scales_and_zeros(


A more general question: where are we placing _weight_int4pack_mm_with_scales_and_zeros? PyTorch does not have it right now.


will be added in pytorch/pytorch#137566

)

# check gemm + bias + gelu
out_onednn_gelu = torch._weight_int4pack_mm_with_scales_and_zeros(


Where was the signature with "tanh" defined?
Does PyTorch have a packed int4 GEMM with post-op?

@mingfeima

@liangan1 CC

@mingfeima

@sunjiweiswift for the perf benchmarking, please include other configs besides M=1. This would serve as a reference for the final decision. I expect that large M will have worse perf, but that's fine; we still need to know the numbers.

@sunjiweiswift force-pushed the fp_zp branch 2 times, most recently from faa79b7 to 5a08d2e on December 9, 2024 05:25