From 62eea62493b7c850739c8fc08ab4bce4f4a9e9dc Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 12 Nov 2024 22:15:47 -0800 Subject: [PATCH] [Quant][Onednn] add linear_dynamic_fp16 ops (#140376) **About this PR** This PR adds the following ops for `linear_dynamic_fp16` in onednn namespace. These ops are intended for PT2E quantization eager mode. - `onednn::linear_prepack_fp16`: packs fp32 weight to an fp16 MkldnnCPU tensor. - `onednn::linear_dynamic_fp16`: takes an fp32 CPU tensor and an fp16 MkldnnCPU tensor and compute linear in fp32 - `onednn::linear_relu_dynamic_fp16`: similar as the former and apply relu on output. **Test plan** `python test/test_quantization.py -k test_linear_dynamic_fp16_onednn` **Implementation** These ops call oneDNN lib under the hood. It's worth noting that oneDNN does not support f32 * f16 -> f32 computation, so we have to convert fp16 weight to fp32 before computation. And weight is still in plain format after packing. **Correctness and performance** Correctness is guaranteed by UT. Performance of the new ops may be better than the FBGEMM implementation when weight shape is small but worse when weight shape is large. It's because weight dtype conversion and computation are not fused. For example, I ran benchmarks on an Intel(R) Xeon(R) Platinum 8490H machine with different cores and shapes. When using 1 core per instance, the new implementation generally is faster for weight shape < 1024 * 1024. When using more cores, the threshold will increase. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140376 Approved by: https://github.com/jerryzh168, https://github.com/jgong5 --- .../native/quantized/cpu/qlinear_dynamic.cpp | 119 ++++++++++++++++++ .../native/quantized/cpu/qlinear_prepack.cpp | 36 ++++++ aten/src/ATen/native/quantized/library.cpp | 4 + test/quantization/core/test_quantized_op.py | 33 +++++ 4 files changed, 192 insertions(+) diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp index 95854d14b2c25..091e309cd95d8 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -616,6 +617,92 @@ at::Tensor PackedLinearWeightsOnednn::apply_dynamic_relu( std::move(input), reduce_range); } +static at::Tensor linear_dynamic_fp16_with_onednn_weight( + at::Tensor input, + at::Tensor onednn_weight, // fp16 tensor from MkldnnCPU + std::optional bias, + bool relu_fused) { + using ideep::tensor; + const int64_t dim = input.dim(); + TORCH_CHECK(input.scalar_type() == c10::ScalarType::Float, + "onednn linear dynamic fp16: data type of input should be float."); + TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Half, + "onednn linear dynamic fp16: data type of weight should be half."); + + // If the input has more than two dimensions, we will reshape it to a 2-dimensional form + // for calculation and subsequently reshape the output back. + auto input_contig = + dim == 2 ? input.contiguous() : input.reshape({-1, input.size(dim - 1)}).contiguous(); + + auto src = at::native::itensor_from_tensor(input_contig); + auto packed_weight = at::native::itensor_from_mkldnn(onednn_weight); + int64_t K = input.size(dim - 1), M = input.numel() / K, N = packed_weight.get_dim(1); + + auto output_size = input.sizes().vec(); + output_size[dim - 1] = N; + + std::optional onednn_bias{std::nullopt}; + bool with_bias = bias.has_value(); + at::Tensor bias_val_float; + if (with_bias) { + bias_val_float = bias.value().to(at::kFloat); + if (bias_val_float.dim() == 1) { + auto b_reshape = bias_val_float.reshape({1, bias_val_float.size(0)}); + onednn_bias = at::native::itensor_view_from_dense(b_reshape); + } else { + onednn_bias = at::native::itensor_view_from_dense(bias_val_float); + } + } + std::vector src_dims = {M, K}; + std::vector dst_dims = {M, N}; + at::Tensor output = at::empty( + dst_dims, + device(c10::kCPU) + .dtype(c10::kFloat) + ); + if (output.numel() == 0) { + return output; + } + tensor dst = at::native::itensor_view_from_dense(output); + static tensor empty_tensor; + static tensor::desc empty_tensor_desc; + + // Create matmul primitive + auto src_dtype = ideep::data_type::f32; + auto src_desc = tensor::desc(src_dims, src_dtype, ideep::format_tag::any); + // onednn does not support f32f16f32 matmul, so we get primitive with f32 weight desc + // weight is stored in f16 and reordered to f32 below by `reorder_if_differ_in` + auto weights_desc = tensor::desc(packed_weight.get_dims(), ideep::data_type::f32, ideep::format_tag::any); + auto dst_dtype = dst.get_data_type(); + auto dst_desc = tensor::desc(dst_dims, dst_dtype, ideep::format_tag::any); + auto bias_desc = with_bias ? + tensor::desc(onednn_bias.value().get_dims(), ideep::data_type::f32, ideep::format_tag::any) : + empty_tensor_desc; + // Get op attr for primitive + auto op_attr = relu_fused ? ideep::attr_t::fuse_relu() : ideep::attr_t(); + op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + auto engine = ideep::engine::cpu_engine(); + auto primitive_desc = with_bias ? + dnnl::matmul::primitive_desc(engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr) : + dnnl::matmul::primitive_desc(engine, src_desc, weights_desc, dst_desc, op_attr); + auto primitive = dnnl::matmul(primitive_desc); + + // Convert weight from f16 to f32 with layout changes + auto expected_weight = packed_weight.reorder_if_differ_in(primitive_desc.weights_desc()); + + // Prepare args and execute primitive + tensor scratchpad(primitive_desc.scratchpad_desc()); + ideep::exec_args args; + args.insert({DNNL_ARG_SRC, src}); + args.insert({DNNL_ARG_WEIGHTS, expected_weight}); + args.insert({DNNL_ARG_DST, dst}); + args.insert({DNNL_ARG_SCRATCHPAD, scratchpad}); + if (with_bias) { + args.insert({DNNL_ARG_BIAS, onednn_bias.value()}); + } + primitive.execute(ideep::stream::default_stream(), args); + return dim == 2 ? output : output.reshape(output_size); +} #endif // #if AT_MKLDNN_ENABLED() namespace at::native { @@ -786,6 +873,32 @@ at::Tensor wrapped_fbgemm_linear_fp16_weight_meta(const at::Tensor& input, const #endif // USE_FBGEMM } +class LinearDynamicFp16Onednn final { + public: + static Tensor run( + Tensor act, // int8 CPU tensor, not QTensor + Tensor onednn_weight, // int8 tensor from MkldnnCPU + std::optional bias) { +#if AT_MKLDNN_ENABLED() + return linear_dynamic_fp16_with_onednn_weight( + act, onednn_weight, bias, /*relu_fused*/false); +#endif + TORCH_CHECK(false, "Unimplemented (linear_dynamic_fp16_with_onednn_weight)"); + } + + static Tensor run_relu( + Tensor act, // int8 CPU tensor, not QTensor + Tensor onednn_weight, // int8 tensor from MkldnnCPU + std::optional bias) { +#if AT_MKLDNN_ENABLED() + return linear_dynamic_fp16_with_onednn_weight( + act, onednn_weight, bias, /*relu_fused*/true); +#endif + TORCH_CHECK(false, "Unimplemented (linear_dynamic_fp16_with_onednn_weight)"); + } + +}; + TORCH_LIBRARY_IMPL(quantized, CPU, m) { register_linear_params(); @@ -834,5 +947,11 @@ TORCH_LIBRARY_IMPL(_quantized, Meta, m) { wrapped_fbgemm_linear_fp16_weight_meta); } +TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) { + m.impl(TORCH_SELECTIVE_NAME("onednn::linear_dynamic_fp16"), + TORCH_FN(LinearDynamicFp16Onednn::run)); + m.impl(TORCH_SELECTIVE_NAME("onednn::linear_relu_dynamic_fp16"), + TORCH_FN(LinearDynamicFp16Onednn::run_relu)); +} } // namespace } // namespace at::native diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index f4c55b2a3cfe4..d9e3d484d02d2 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -309,6 +309,23 @@ inline at::Tensor pack_weight_to_onednn_tensor( return packed_weight; } +inline at::Tensor pack_weight_to_fp16_onednn_tensor( + at::Tensor& weight, + std::optional>& input_shape) { + weight = at::_saturate_weight_to_fp16(weight); + std::vector w_dims = weight.sizes().vec(); + auto weight_fp16 = weight.to(at::kHalf); + ideep::tensor wei = ideep::tensor({w_dims, dnnl::memory::data_type::f16}, weight_fp16.data_ptr()); + auto expected_weight = wei.transpose(0, 1); // oneDNN requires transposed weight + // Onednn does not support f32f16f32 matmul, so we need to convert weight to f32 before compute + // Therefore, we just return weight in plain format + auto packed_weight = at::native::new_with_itensor_mkldnn( + std::move(expected_weight), + c10::kHalf, + weight.options().device_opt()); + return packed_weight; +} + #endif // #if AT_MKLDNN_ENABLED() namespace at::native { @@ -672,6 +689,21 @@ class QLinearPackWeightInt8Onednn final { } }; +class QLinearPackWeightFp16Onednn final { + public: + static at::Tensor run( + // NOLINTNEXTLINE(performance-unnecessary-value-param) + [[maybe_unused]] at::Tensor weight, // Not QTensor + // NOLINTNEXTLINE(performance-unnecessary-value-param) + [[maybe_unused]] std::optional> input_shape) { +#if AT_MKLDNN_ENABLED() + return pack_weight_to_fp16_onednn_tensor(weight, input_shape); +#else + TORCH_CHECK(false, "Unimplemented as onednn is not available."); +#endif + } +}; + TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { register_linear_params(); m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8::run)); @@ -716,5 +748,9 @@ TORCH_LIBRARY_IMPL(onednn, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_prepack"), TORCH_FN(QLinearPackWeightInt8Onednn::run)); } +TORCH_LIBRARY_IMPL(onednn, CPU, m) { + m.impl(TORCH_SELECTIVE_NAME("onednn::linear_prepack_fp16"), TORCH_FN(QLinearPackWeightFp16Onednn::run)); +} + } // namespace } // namespace at::native diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 05341366a9dfa..72dcda2b74de4 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -272,7 +272,11 @@ TORCH_LIBRARY(onednn, m) { // Linear prepack m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_prepack(Tensor weight, int[]? x_shape) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("onednn::linear_prepack_fp16(Tensor weight, int[]? x_shape) -> Tensor")); + // Linear + m.def(TORCH_SELECTIVE_SCHEMA("onednn::linear_dynamic_fp16(Tensor x, Tensor w, Tensor? bias) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("onednn::linear_relu_dynamic_fp16(Tensor x, Tensor w, Tensor? bias) -> Tensor")); // Linear with unary postop m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor")); diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index f7c7330a8c991..0e419989d3560 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -3750,6 +3750,39 @@ def test_dynamic_convtranspose3d(self): return # TODO: fix MakeDeConvOutputShape overflowing for convT3d with qnnpack self._test_qconv_op_impl(q_mod, dq_op, dim, dtype) + @skipIfNoONEDNN + def test_linear_dynamic_fp16_onednn(self): + + options = itertools.product( + (2, 4), # batch_size + (4, 5, 12), # input_channels + (4, 7, 8), # output_channels + (True, False), # use_bias + (True, False), # use_relu + ) + for batch_size, input_channels, output_channels, use_bias, use_relu in options: + qlinear_prepack = torch.ops.onednn.linear_prepack_fp16 + if use_relu: + qlinear_dynamic = torch.ops.onednn.linear_relu_dynamic_fp16 + else: + qlinear_dynamic = torch.ops.onednn.linear_dynamic_fp16 + + x = torch.randn(batch_size, input_channels) + w = torch.randn(output_channels, input_channels) + bias = torch.randn(output_channels) if use_bias else None + + w_packed = qlinear_prepack(w, x.shape) + out = qlinear_dynamic(x, w_packed, bias) + + # qlinear_dynamic_fp16 uses FP32 activation tensors and FP16 weight tensors + # output is FP32 + w_fp16 = w.to(torch.float16).to(torch.float32) + ref = F.linear(x, w_fp16, bias) + if use_relu: + ref.relu_() + + self.assertEqual(out, ref) + class TestQuantizedLinear(TestCase): def _test_qlinear_impl(self, batch_size, input_channels, output_channels, use_bias,