diff --git a/src/ATen/native/xpu/LinearInt4.cpp b/src/ATen/native/xpu/LinearInt4.cpp index a3b6f4ac2..ac2e76909 100644 --- a/src/ATen/native/xpu/LinearInt4.cpp +++ b/src/ATen/native/xpu/LinearInt4.cpp @@ -50,13 +50,19 @@ Tensor _weight_int4pack_mm_with_scales_and_zeros_xpu( std::optional common_device = std::nullopt; c10::impl::check_and_update_common_device( - common_device, input, "xpu::linear_int4", "input"); + common_device, + input, + "xpu::_weight_int4pack_mm_with_scales_and_zeros", + "input"); c10::impl::check_and_update_common_device( - common_device, weight, "xpu::linear_int4", "weight"); + common_device, + weight, + "xpu::_weight_int4pack_mm_with_scales_and_zeros", + "weight"); c10::impl::check_and_update_common_device( common_device, weight_scale_zero_point, - "xpu::linear_int4", + "xpu::_weight_int4pack_mm_with_scales_and_zeros", "weight_scale_zero_point"); Tensor output = at::empty({M, N}, input.options()); diff --git a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp index 6ac13442c..558250803 100644 --- a/src/ATen/native/xpu/sycl/LinearInt4.cpp +++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp @@ -139,29 +139,58 @@ void linear_int4_kernel( int lda = k; int ldb = n; int ldc = n; - if (input.scalar_type() == at::kHalf) { - using scalar_t = at::Half; - // const auto scalar_t = input.scalar_type(); - const scalar_t* input_data = input.data_ptr(); - uint32_t* weight_data = weight.data_ptr(); // int4x8 - scalar_t* output_data = output.data_ptr(); - scalar_t* weight_scale_data = weight_scale_zero_point.data_ptr(); - LinearInt4KernelFunctor kfn( - input_data, - weight_data, - output_data, - weight_scale_data, - nullptr, - m, - n, - k, - k, - n, - n); + AT_DISPATCH_FLOATING_TYPES_AND( + at::ScalarType::Half, input.scalar_type(), "linear_int4_kernel", [&]() { + using scalar_t = at::Half; + const scalar_t* input_data = input.data_ptr(); + uint32_t* weight_data = weight.data_ptr(); // int4x8 - sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); - } + scalar_t* output_data = output.data_ptr(); + scalar_t* weight_scale_data = + weight_scale_zero_point.data_ptr(); + LinearInt4KernelFunctor kfn( + input_data, + weight_data, + output_data, + weight_scale_data, + nullptr, + m, + n, + k, + k, + n, + n); + + sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); + }); + AT_DISPATCH_FLOATING_TYPES_AND( + at::ScalarType::BFloat16, + input.scalar_type(), + "linear_int4_kernel", + [&]() { + using scalar_t = at::BFloat16; + const scalar_t* input_data = input.data_ptr(); + uint32_t* weight_data = weight.data_ptr(); // int4x8 + + scalar_t* output_data = output.data_ptr(); + scalar_t* weight_scale_data = + weight_scale_zero_point.data_ptr(); + LinearInt4KernelFunctor kfn( + input_data, + weight_data, + output_data, + weight_scale_data, + nullptr, + m, + n, + k, + k, + n, + n); + + sycl_kernel_submit(global_range, local_range, sycl_queue, kfn); + }); } } // namespace at::native::xpu \ No newline at end of file