
linear_int4_kernel for XPU #1130

Open · wants to merge 33 commits into base: main
Changes shown from 1 commit of 33.

Commits:
1e32bbc
Sync main into release/2.6 branch (#1117)
xytintel Nov 22, 2024
f312190
[Release-2.6] Fix bugs of `empty_xpu` and `soft_shrink` (#1139)
xytintel Dec 3, 2024
7ecb0b1
[Release-2.6] Capture rrelu_with_noise noise mutation in compile (#1145)
xytintel Dec 5, 2024
5410f51
contiguous layout for sycl int4 kernel
airMeng Nov 22, 2024
e9311a3
push without compile
sunjiweiswift Nov 26, 2024
e3eaffa
update linearkernel
sunjiweiswift Nov 28, 2024
2a664af
fix some compile errors (not all)
sunjiweiswift Nov 28, 2024
0156ba5
add sycl_ker_config_convention
sunjiweiswift Nov 28, 2024
a58afec
reg kernel for pytorch
sunjiweiswift Nov 29, 2024
f487b20
add yaml for int4mm
sunjiweiswift Nov 29, 2024
ce1c894
update yaml file
sunjiweiswift Dec 3, 2024
d61b198
Modify code per some review comments
sunjiweiswift Dec 3, 2024
d76a0ce
modify function name
sunjiweiswift Dec 9, 2024
870a3b5
autogen: _weight_int4pack_mm_with_scales_and_zeros.out
sunjiweiswift Dec 10, 2024
a9627f6
param int -> int64_t (Python int is int64)
sunjiweiswift Dec 10, 2024
952ead9
use AT_DISPATCH_FLOATING_TYPES_AND
sunjiweiswift Dec 10, 2024
93804f9
Keep the same name as pytorch's _weight_int4pack_mm
sunjiweiswift Dec 11, 2024
9e50b68
modify UT for int4
sunjiweiswift Dec 11, 2024
81a72f1
sync UT with pytorch UT (linalg)
sunjiweiswift Dec 12, 2024
a70df0a
col-major
sunjiweiswift Dec 12, 2024
c08382c
UT passes for B of all ones
sunjiweiswift Dec 13, 2024
14bb4e0
update gemv
sunjiweiswift Dec 16, 2024
70a3e13
fix scale and zp address
sunjiweiswift Dec 17, 2024
a590ad6
fix UT for K larger than 1024
sunjiweiswift Dec 18, 2024
d6a2f3a
bug fix for FP16 (BF16 may still be incorrect)
sunjiweiswift Dec 18, 2024
27f18c2
save
sunjiweiswift Dec 20, 2024
7f94b9b
Merge branch 'main' into fp_zp
sunjiweiswift Dec 20, 2024
42c18e9
bugfix for Big Endian
sunjiweiswift Dec 20, 2024
d832050
Unify BF16 and FP16 function
sunjiweiswift Dec 20, 2024
8385f7e
fix compile warning
sunjiweiswift Dec 20, 2024
f44ed70
modify per review comments
sunjiweiswift Dec 23, 2024
09696b1
Merge branch 'main' into fp_zp
sunjiweiswift Dec 24, 2024
ebe8c7c
Merge branch 'main' into fp_zp
sunjiweiswift Dec 25, 2024
Viewing changes from commit a70df0a7257151d40499e5f0c5840c4699f4e3e6 ("col-major", committed by sunjiweiswift on Dec 12, 2024)
2 changes: 1 addition & 1 deletion src/ATen/native/xpu/LinearInt4.cpp
@@ -14,7 +14,7 @@ Tensor _weight_int4pack_mm_xpu(
     int64_t qGroupSize,
     const Tensor& qScaleAndZeros) {
   auto M = A.size(0);
-  auto N = B.size(0);
+  auto N = B.size(1);
   auto K = A.size(1);
   TORCH_CHECK(
       A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
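
The one-line fix matches the col-major packing this commit adopts: N is now read from B.size(1) rather than B.size(0). As rough context for what the op computes, below is a minimal Python reference for a group-wise int4 matmul. The function name, the [K, N] unpacked code layout, the [K // group_size, N, 2] scale/zero layout, and the (q - 8) * scale + zero dequant convention are illustrative assumptions for this sketch, not the kernel's actual packed format.

import torch

def int4_mm_reference(A, B_q, scales_and_zeros, group_size):
    # A: [M, K] activations; B_q: [K, N] unpacked int4 codes in [0, 15].
    # scales_and_zeros: [K // group_size, N, 2] holding (scale, zero) per
    # quantization group, mirroring the qScaleAndZeros argument (layout assumed).
    K, N = B_q.shape
    B_deq = torch.empty(K, N, dtype=A.dtype)
    for g in range(K // group_size):
        rows = slice(g * group_size, (g + 1) * group_size)
        scale = scales_and_zeros[g, :, 0]  # [N]
        zero = scales_and_zeros[g, :, 1]   # [N]
        # Assumed asymmetric convention: w = (q - 8) * scale + zero.
        B_deq[rows] = (B_q[rows].to(A.dtype) - 8) * scale + zero
    return A @ B_deq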
3 changes: 2 additions & 1 deletion test/xpu/test_linalg_xpu.py
@@ -272,8 +272,9 @@ def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
         b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype)
         ref = torch.mm(a, b)
         res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)
-
         mean_err = ((res - ref).abs() / ref).mean()
+        print(ref)
+        print(res)
         self.assertTrue(mean_err < 0.05)

@dtypes(torch.float, torch.complex64) # Integer matmul just supported on CPU
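
The test tolerates a 5% mean relative error against a full-precision torch.mm reference, since group-wise int4 quantization is lossy. For context, here is a sketch of the kind of group-quantization helper that would produce the unpacked codes and b_scales_and_zeros; PyTorch's test utilities provide their own version, and this illustrative one assumes the same (q - 8) * scale + zero convention as the reference sketch above.

import torch

def group_quantize_reference(w, group_size=32, n_bit=4):
    # w: [K, N] float weight; K must be divisible by group_size.
    K, N = w.shape
    wg = w.reshape(K // group_size, group_size, N)
    max_val = wg.amax(dim=1)  # [K // group_size, N]
    min_val = wg.amin(dim=1)
    scales = (max_val - min_val).clamp(min=1e-6) / (2**n_bit - 1)
    zeros = min_val + scales * 2 ** (n_bit - 1)  # floating-point zero points
    q = ((wg - zeros.unsqueeze(1)) / scales.unsqueeze(1)).round() + 2 ** (n_bit - 1)
    q = q.clamp(0, 2**n_bit - 1).to(torch.int32).reshape(K, N)
    return q, torch.stack([scales, zeros], dim=-1)  # codes, [K//group_size, N, 2]

Quantizing a weight with this helper and multiplying through int4_mm_reference reproduces the test's pattern: mean_err = ((res - ref).abs() / ref).mean() should stay well under the 0.05 threshold for well-conditioned inputs.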