Skip to content

Commit

Permalink
use AT_DISPATCH_FLOATING_TYPES_AND
Browse files: browse the repository at this point in the history
  • Loading branch information
sunjiweiswift committed Dec 10, 2024
1 parent 209ad7f commit e7c0a36
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 24 deletions.
12 changes: 9 additions & 3 deletions src/ATen/native/xpu/LinearInt4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,19 @@ Tensor _weight_int4pack_mm_with_scales_and_zeros_xpu(

std::optional<Device> common_device = std::nullopt;
c10::impl::check_and_update_common_device(
common_device, input, "xpu::linear_int4", "input");
common_device,
input,
"xpu::_weight_int4pack_mm_with_scales_and_zeros",
"input");
c10::impl::check_and_update_common_device(
common_device, weight, "xpu::linear_int4", "weight");
common_device,
weight,
"xpu::_weight_int4pack_mm_with_scales_and_zeros",
"weight");
c10::impl::check_and_update_common_device(
common_device,
weight_scale_zero_point,
"xpu::linear_int4",
"xpu::_weight_int4pack_mm_with_scales_and_zeros",
"weight_scale_zero_point");
Tensor output = at::empty({M, N}, input.options());

Expand Down
71 changes: 50 additions & 21 deletions src/ATen/native/xpu/sycl/LinearInt4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,29 +139,58 @@ void linear_int4_kernel(
int lda = k;
int ldb = n;
int ldc = n;
if (input.scalar_type() == at::kHalf) {
using scalar_t = at::Half;
// const auto scalar_t = input.scalar_type();
const scalar_t* input_data = input.data_ptr<scalar_t>();
uint32_t* weight_data = weight.data_ptr<uint32_t>(); // int4x8

scalar_t* output_data = output.data_ptr<scalar_t>();
scalar_t* weight_scale_data = weight_scale_zero_point.data_ptr<scalar_t>();
LinearInt4KernelFunctor<scalar_t, 16> kfn(
input_data,
weight_data,
output_data,
weight_scale_data,
nullptr,
m,
n,
k,
k,
n,
n);
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::Half, input.scalar_type(), "linear_int4_kernel", [&]() {
using scalar_t = at::Half;
const scalar_t* input_data = input.data_ptr<scalar_t>();
uint32_t* weight_data = weight.data_ptr<uint32_t>(); // int4x8

sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
}
scalar_t* output_data = output.data_ptr<scalar_t>();
scalar_t* weight_scale_data =
weight_scale_zero_point.data_ptr<scalar_t>();
LinearInt4KernelFunctor<scalar_t, 16> kfn(
input_data,
weight_data,
output_data,
weight_scale_data,
nullptr,
m,
n,
k,
k,
n,
n);

sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
});
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::BFloat16,
input.scalar_type(),
"linear_int4_kernel",
[&]() {
using scalar_t = at::BFloat16;
const scalar_t* input_data = input.data_ptr<scalar_t>();
uint32_t* weight_data = weight.data_ptr<uint32_t>(); // int4x8

scalar_t* output_data = output.data_ptr<scalar_t>();
scalar_t* weight_scale_data =
weight_scale_zero_point.data_ptr<scalar_t>();
LinearInt4KernelFunctor<scalar_t, 16> kfn(
input_data,
weight_data,
output_data,
weight_scale_data,
nullptr,
m,
n,
k,
k,
n,
n);

sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
});
}

} // namespace at::native::xpu

0 comments on commit e7c0a36

Please sign in to comment.