[Quant][Onednn] add linear_dynamic_fp16 ops (pytorch#140376)
**About this PR**
This PR adds the following ops for `linear_dynamic_fp16` in the onednn namespace. They are intended for PT2E quantization in eager mode. A usage sketch follows the list.
- `onednn::linear_prepack_fp16`: packs an fp32 weight into an fp16 MkldnnCPU tensor.
- `onednn::linear_dynamic_fp16`: takes an fp32 CPU tensor and an fp16 MkldnnCPU tensor and computes linear in fp32.
- `onednn::linear_relu_dynamic_fp16`: same as the former, then applies ReLU to the output.
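
A minimal usage sketch (mirroring the unit test below; shapes are illustrative and a oneDNN-enabled build is assumed):

```python
import torch

x = torch.randn(2, 4)     # fp32 activation on CPU
w = torch.randn(8, 4)     # fp32 weight, shape (out_features, in_features)
bias = torch.randn(8)     # optional fp32 bias

# Pack the fp32 weight into an fp16 MkldnnCPU tensor (kept in plain format).
w_packed = torch.ops.onednn.linear_prepack_fp16(w, x.shape)

# Dynamic fp16 linear: fp32 input x fp16 weight -> fp32 output.
out = torch.ops.onednn.linear_dynamic_fp16(x, w_packed, bias)

# Variant with fused ReLU.
out_relu = torch.ops.onednn.linear_relu_dynamic_fp16(x, w_packed, bias)
```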

**Test plan**
`python test/test_quantization.py -k test_linear_dynamic_fp16_onednn`

**Implementation**
These ops call the oneDNN library under the hood. Note that oneDNN does not support f32 * f16 -> f32 computation, so the fp16 weight has to be converted to fp32 before computation. The weight also remains in plain format after packing.
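
Because the weight is stored in fp16 and converted back to fp32 for the matmul, the result should match an fp32 linear with an fp16-rounded weight, which is also the reference used in the unit test. A minimal sketch of those expected numerics:

```python
import torch
import torch.nn.functional as F

def linear_dynamic_fp16_reference(x: torch.Tensor, w: torch.Tensor, bias=None) -> torch.Tensor:
    # The weight is rounded through fp16, while the matmul itself runs in fp32,
    # so the op's output should match this fp32 reference computation.
    w_fp16_rounded = w.to(torch.float16).to(torch.float32)
    return F.linear(x, w_fp16_rounded, bias)
```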

**Correctness and performance**
Correctness is verified by the unit test above.
Performance of the new ops may be better than the FBGEMM implementation when the weight shape is small, but worse when it is large, because the weight dtype conversion and the computation are not fused.
For example, benchmarks on an Intel(R) Xeon(R) Platinum 8490H machine with different core counts and shapes show that, with 1 core per instance, the new implementation is generally faster for weight shapes smaller than 1024 * 1024; with more cores, that threshold increases.
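
A rough single-core timing sketch along those lines is below. The FBGEMM baseline via `torch.ops.quantized.linear_prepack_fp16` / `torch.ops.quantized.linear_dynamic_fp16` is an assumed comparison setup (it is not part of this PR and requires a build with FBGEMM); exact numbers depend on the machine and thread settings.

```python
import time
import torch

torch.set_num_threads(1)  # mimic the 1-core-per-instance setting

def bench(fn, iters=200, warmup=20):
    # Simple wall-clock timing of a callable.
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

M, K, N = 4, 1024, 1024
x = torch.randn(M, K)
w = torch.randn(N, K)
bias = torch.randn(N)

# New oneDNN path (this PR).
w_onednn = torch.ops.onednn.linear_prepack_fp16(w, x.shape)
t_onednn = bench(lambda: torch.ops.onednn.linear_dynamic_fp16(x, w_onednn, bias))

# Assumed FBGEMM baseline for comparison.
w_fbgemm = torch.ops.quantized.linear_prepack_fp16(w, bias)
t_fbgemm = bench(lambda: torch.ops.quantized.linear_dynamic_fp16(x, w_fbgemm))

print(f"onednn: {t_onednn * 1e6:.1f} us/iter, fbgemm: {t_fbgemm * 1e6:.1f} us/iter")
```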

Pull Request resolved: pytorch#140376
Approved by: https://github.com/jerryzh168, https://github.com/jgong5
Xia-Weiwen authored and pytorchmergebot committed Nov 14, 2024
1 parent 99c8d5a commit 62eea62
Showing 4 changed files with 192 additions and 0 deletions.
119 changes: 119 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
@@ -7,6 +7,7 @@
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <torch/library.h>

@@ -616,6 +617,92 @@ at::Tensor PackedLinearWeightsOnednn::apply_dynamic_relu(
std::move(input), reduce_range);
}

static at::Tensor linear_dynamic_fp16_with_onednn_weight(
at::Tensor input,
at::Tensor onednn_weight, // fp16 tensor from MkldnnCPU
std::optional<at::Tensor> bias,
bool relu_fused) {
using ideep::tensor;
const int64_t dim = input.dim();
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Float,
"onednn linear dynamic fp16: data type of input should be float.");
TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Half,
"onednn linear dynamic fp16: data type of weight should be half.");

// If the input has more than two dimensions, we will reshape it to a 2-dimensional form
// for calculation and subsequently reshape the output back.
auto input_contig =
dim == 2 ? input.contiguous() : input.reshape({-1, input.size(dim - 1)}).contiguous();

auto src = at::native::itensor_from_tensor(input_contig);
auto packed_weight = at::native::itensor_from_mkldnn(onednn_weight);
int64_t K = input.size(dim - 1), M = input.numel() / K, N = packed_weight.get_dim(1);

auto output_size = input.sizes().vec();
output_size[dim - 1] = N;

std::optional<ideep::tensor> onednn_bias{std::nullopt};
bool with_bias = bias.has_value();
at::Tensor bias_val_float;
if (with_bias) {
bias_val_float = bias.value().to(at::kFloat);
if (bias_val_float.dim() == 1) {
auto b_reshape = bias_val_float.reshape({1, bias_val_float.size(0)});
onednn_bias = at::native::itensor_view_from_dense(b_reshape);
} else {
onednn_bias = at::native::itensor_view_from_dense(bias_val_float);
}
}
std::vector<int64_t> src_dims = {M, K};
std::vector<int64_t> dst_dims = {M, N};
at::Tensor output = at::empty(
dst_dims,
device(c10::kCPU)
.dtype(c10::kFloat)
);
if (output.numel() == 0) {
return output;
}
tensor dst = at::native::itensor_view_from_dense(output);
static tensor empty_tensor;
static tensor::desc empty_tensor_desc;

// Create matmul primitive
auto src_dtype = ideep::data_type::f32;
auto src_desc = tensor::desc(src_dims, src_dtype, ideep::format_tag::any);
// onednn does not support f32f16f32 matmul, so we get primitive with f32 weight desc
// weight is stored in f16 and reordered to f32 below by `reorder_if_differ_in`
auto weights_desc = tensor::desc(packed_weight.get_dims(), ideep::data_type::f32, ideep::format_tag::any);
auto dst_dtype = dst.get_data_type();
auto dst_desc = tensor::desc(dst_dims, dst_dtype, ideep::format_tag::any);
auto bias_desc = with_bias ?
tensor::desc(onednn_bias.value().get_dims(), ideep::data_type::f32, ideep::format_tag::any) :
empty_tensor_desc;
// Get op attr for primitive
auto op_attr = relu_fused ? ideep::attr_t::fuse_relu() : ideep::attr_t();
op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
auto engine = ideep::engine::cpu_engine();
auto primitive_desc = with_bias ?
dnnl::matmul::primitive_desc(engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr) :
dnnl::matmul::primitive_desc(engine, src_desc, weights_desc, dst_desc, op_attr);
auto primitive = dnnl::matmul(primitive_desc);

// Convert weight from f16 to f32 with layout changes
auto expected_weight = packed_weight.reorder_if_differ_in(primitive_desc.weights_desc());

// Prepare args and execute primitive
tensor scratchpad(primitive_desc.scratchpad_desc());
ideep::exec_args args;
args.insert({DNNL_ARG_SRC, src});
args.insert({DNNL_ARG_WEIGHTS, expected_weight});
args.insert({DNNL_ARG_DST, dst});
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
if (with_bias) {
args.insert({DNNL_ARG_BIAS, onednn_bias.value()});
}
primitive.execute(ideep::stream::default_stream(), args);
return dim == 2 ? output : output.reshape(output_size);
}
#endif // #if AT_MKLDNN_ENABLED()

namespace at::native {
@@ -786,6 +873,32 @@ at::Tensor wrapped_fbgemm_linear_fp16_weight_meta(const at::Tensor& input, const
#endif // USE_FBGEMM
}

class LinearDynamicFp16Onednn final {
public:
static Tensor run(
Tensor act, // fp32 CPU tensor, not QTensor
Tensor onednn_weight, // fp16 tensor from MkldnnCPU
std::optional<Tensor> bias) {
#if AT_MKLDNN_ENABLED()
return linear_dynamic_fp16_with_onednn_weight(
act, onednn_weight, bias, /*relu_fused*/false);
#endif
TORCH_CHECK(false, "Unimplemented (linear_dynamic_fp16_with_onednn_weight)");
}

static Tensor run_relu(
Tensor act, // fp32 CPU tensor, not QTensor
Tensor onednn_weight, // fp16 tensor from MkldnnCPU
std::optional<Tensor> bias) {
#if AT_MKLDNN_ENABLED()
return linear_dynamic_fp16_with_onednn_weight(
act, onednn_weight, bias, /*relu_fused*/true);
#endif
TORCH_CHECK(false, "Unimplemented (linear_dynamic_fp16_with_onednn_weight)");
}

};


TORCH_LIBRARY_IMPL(quantized, CPU, m) {
register_linear_params();
@@ -834,5 +947,11 @@ TORCH_LIBRARY_IMPL(_quantized, Meta, m) {
wrapped_fbgemm_linear_fp16_weight_meta);
}

TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) {
m.impl(TORCH_SELECTIVE_NAME("onednn::linear_dynamic_fp16"),
TORCH_FN(LinearDynamicFp16Onednn::run));
m.impl(TORCH_SELECTIVE_NAME("onednn::linear_relu_dynamic_fp16"),
TORCH_FN(LinearDynamicFp16Onednn::run_relu));
}
} // namespace
} // namespace at::native
36 changes: 36 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
@@ -309,6 +309,23 @@ inline at::Tensor pack_weight_to_onednn_tensor(
return packed_weight;
}

inline at::Tensor pack_weight_to_fp16_onednn_tensor(
at::Tensor& weight,
std::optional<torch::List<int64_t>>& input_shape) {
weight = at::_saturate_weight_to_fp16(weight);
std::vector<int64_t> w_dims = weight.sizes().vec();
auto weight_fp16 = weight.to(at::kHalf);
ideep::tensor wei = ideep::tensor({w_dims, dnnl::memory::data_type::f16}, weight_fp16.data_ptr());
auto expected_weight = wei.transpose(0, 1); // oneDNN requires transposed weight
// oneDNN does not support f32 * f16 -> f32 matmul, so the weight will be converted to f32 before computation.
// Therefore, we just return the weight in plain format.
auto packed_weight = at::native::new_with_itensor_mkldnn(
std::move(expected_weight),
c10::kHalf,
weight.options().device_opt());
return packed_weight;
}

#endif // #if AT_MKLDNN_ENABLED()

namespace at::native {
@@ -672,6 +689,21 @@ class QLinearPackWeightInt8Onednn final {
}
};

class QLinearPackWeightFp16Onednn final {
public:
static at::Tensor run(
// NOLINTNEXTLINE(performance-unnecessary-value-param)
[[maybe_unused]] at::Tensor weight, // Not QTensor
// NOLINTNEXTLINE(performance-unnecessary-value-param)
[[maybe_unused]] std::optional<torch::List<int64_t>> input_shape) {
#if AT_MKLDNN_ENABLED()
return pack_weight_to_fp16_onednn_tensor(weight, input_shape);
#else
TORCH_CHECK(false, "Unimplemented as onednn is not available.");
#endif
}
};

TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
register_linear_params();
m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8::run));
@@ -716,5 +748,9 @@ TORCH_LIBRARY_IMPL(onednn, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_prepack"), TORCH_FN(QLinearPackWeightInt8Onednn::run));
}

TORCH_LIBRARY_IMPL(onednn, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("onednn::linear_prepack_fp16"), TORCH_FN(QLinearPackWeightFp16Onednn::run));
}

} // namespace
} // namespace at::native
4 changes: 4 additions & 0 deletions aten/src/ATen/native/quantized/library.cpp
@@ -272,7 +272,11 @@ TORCH_LIBRARY(onednn, m) {

// Linear prepack
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_prepack(Tensor weight, int[]? x_shape) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("onednn::linear_prepack_fp16(Tensor weight, int[]? x_shape) -> Tensor"));

// Linear
m.def(TORCH_SELECTIVE_SCHEMA("onednn::linear_dynamic_fp16(Tensor x, Tensor w, Tensor? bias) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("onednn::linear_relu_dynamic_fp16(Tensor x, Tensor w, Tensor? bias) -> Tensor"));
// Linear with unary postop
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor"));
33 changes: 33 additions & 0 deletions test/quantization/core/test_quantized_op.py
@@ -3750,6 +3750,39 @@ def test_dynamic_convtranspose3d(self):
return # TODO: fix MakeDeConvOutputShape overflowing for convT3d with qnnpack
self._test_qconv_op_impl(q_mod, dq_op, dim, dtype)

@skipIfNoONEDNN
def test_linear_dynamic_fp16_onednn(self):

options = itertools.product(
(2, 4), # batch_size
(4, 5, 12), # input_channels
(4, 7, 8), # output_channels
(True, False), # use_bias
(True, False), # use_relu
)
for batch_size, input_channels, output_channels, use_bias, use_relu in options:
qlinear_prepack = torch.ops.onednn.linear_prepack_fp16
if use_relu:
qlinear_dynamic = torch.ops.onednn.linear_relu_dynamic_fp16
else:
qlinear_dynamic = torch.ops.onednn.linear_dynamic_fp16

x = torch.randn(batch_size, input_channels)
w = torch.randn(output_channels, input_channels)
bias = torch.randn(output_channels) if use_bias else None

w_packed = qlinear_prepack(w, x.shape)
out = qlinear_dynamic(x, w_packed, bias)

# qlinear_dynamic_fp16 uses FP32 activation tensors and FP16 weight tensors
# output is FP32
w_fp16 = w.to(torch.float16).to(torch.float32)
ref = F.linear(x, w_fp16, bias)
if use_relu:
ref.relu_()

self.assertEqual(out, ref)


class TestQuantizedLinear(TestCase):
def _test_qlinear_impl(self, batch_size, input_channels, output_channels, use_bias,
