Skip to content

Commit

Permalink
add transform part of the dq matmul tool chain (microsoft#21374)
Browse files Browse the repository at this point in the history
### Description

This is a partial change from
[fajin/qdqmatmulnbitstoolchain](microsoft#21180).
The original PR is blocked by Web CI failures.

MatMulNBits is a heavily optimized matmul operation. Currently a MatMul
can be converted to MatMulNBits to speed up the model inference.
However, MatMulNBits is an ORT only op. To make the graph compatible
with ONNX ops and utilize MatMulNBits at the same time, we introduce
Q/DQ support for MatMulNBits.

To convert MatMul ops in a model to MatMulNBits:
1. use matmul_4bits_quantizer.py to convert MatMul to DQ + MatMul using
QDQ mode.
2. In ORT session, DQ + MatMul is fused to MatMulNBits

#### Note
MatMulNBits assume B weight is uint4. When no zp is provided, zp
defaults to 8, which is different from DQ. DQ defaults zp to 0 when no
zp provided. And DQ supports int4. Therefore some conversions are
introduced during DQ + MatMul --> MatMulNBits step.

#### Perf
Using QDQ format will increase the model initialization time and memory
consumption. With current implement, model init time increased from ~4s
to ~9s, and memory consumption increased from ~2.8GB to ~4.8GB.
The memory increase is due to 
1. in optimizer, after transpose the B weight, a in-memory tensor proto
is created using protobuf's arena.
2. in finalize step, when saving initializer and prepacking, ORT arena
is used to create buffers for initializers.

The memory allocated by arenas cannot be fully deallocated.
If disable ORT arena memory allocation, the memory consumptions of both
QDQ format and original format are ~2.2GB.
The time increase is mainly due to multiple memory copy, but can be
further optimized.

### Motivation and Context
Please see description for details.
  • Loading branch information
fajin-corp authored Jul 20, 2024
1 parent 5bec522 commit 11bf309
Show file tree
Hide file tree
Showing 16 changed files with 833 additions and 37 deletions.
7 changes: 5 additions & 2 deletions include/onnxruntime/core/optimizer/graph_transformer_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "core/common/inlined_containers.h"
#include "core/framework/session_options.h"
#include "core/optimizer/graph_transformer.h"
#include "core/platform/threadpool.h"

#if !defined(ORT_MINIMAL_BUILD)
#include "core/optimizer/rule_based_graph_transformer.h"
Expand Down Expand Up @@ -49,7 +50,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& execution_provider /*required by constant folding*/,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr);

#endif // !defined(ORT_MINIMAL_BUILD)

Expand Down Expand Up @@ -78,7 +80,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr);

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,8 @@ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed
// - "0": Gemm FastMath mode is not enabled. [DEFAULT]
// - "1": Gemm FastMath mode is enabled.
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";

// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option.
// Refer to MatMulNBits op schema for more details.
// If not provided, default is 4.
static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level";
26 changes: 21 additions & 5 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
#include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/platform/threadpool.h"

#if !defined(ORT_MINIMAL_BUILD)

Expand Down Expand Up @@ -187,7 +188,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/
const InlinedHashSet<std::string>& rules_and_transformers_to_disable) {
const InlinedHashSet<std::string>& rules_and_transformers_to_disable,
concurrency::ThreadPool* intra_op_thread_pool) {
InlinedVector<std::unique_ptr<GraphTransformer>> transformers;
const bool disable_quant_qdq =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1";
Expand Down Expand Up @@ -287,6 +289,10 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
onnxruntime::kJsExecutionProvider};
const InlinedHashSet<std::string_view> cpu_dml_eps = {onnxruntime::kCpuExecutionProvider,
onnxruntime::kDmlExecutionProvider};
const int64_t qdq_matmulnbits_accuracy_level =
ParseStringWithClassicLocale<int64_t>(
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel,
"4"));
#ifdef MLAS_TARGET_AMD64_IX86
const bool avx2_precision_mode =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow();
Expand All @@ -300,7 +306,10 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
if (!qdq_is_int8_allowed) {
transformers.emplace_back(std::make_unique<QDQS8ToU8Transformer>(avx2_precision_mode, cpu_ep));
}
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed));
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed,
SatApplyContextVariant{},
qdq_matmulnbits_accuracy_level,
intra_op_thread_pool));
}

transformers.emplace_back(std::make_unique<GemmActivationFusion>(cpu_ep));
Expand Down Expand Up @@ -409,7 +418,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable) {
const InlinedHashSet<std::string>& rules_and_transformers_to_disable,
concurrency::ThreadPool* intra_op_thread_pool) {
InlinedVector<std::unique_ptr<GraphTransformer>> transformers;
const bool saving = std::holds_alternative<SatRuntimeOptimizationSaveContext>(apply_context);

Expand All @@ -423,12 +433,18 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const bool qdq_is_int8_allowed =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQIsInt8Allowed,
QDQIsInt8Allowed() ? "1" : "0") == "1";

const int64_t qdq_matmulnbits_accuracy_level =
ParseStringWithClassicLocale<int64_t>(
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel,
"4"));
// runtime optimizations only support CPU EP now
const InlinedHashSet<std::string_view> cpu_ep = {onnxruntime::kCpuExecutionProvider};

if (!disable_quant_qdq) {
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed, apply_context));
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed,
apply_context,
qdq_matmulnbits_accuracy_level,
intra_op_thread_pool));
}

transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_ep, apply_context));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
// Licensed under the MIT License.

#include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h"

#include "core/optimizer/qdq_transformer/qdq_util.h"
#include "core/optimizer/initializer.h"
#include "core/graph/node_attr_utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/mlas/inc/mlas_q4.h"

namespace onnxruntime {
namespace QDQ {

Expand Down Expand Up @@ -273,6 +275,175 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select
}
}

DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(int64_t accuracy_level,
concurrency::ThreadPool* intra_op_thread_pool)
: accuracy_level_{accuracy_level},
domain_{kMSDomain},
op_type_{"MatMulNBits"},
value_moves_{[]() {
NTO::NodeLocation target{NTO::NodeType::kTarget, 0};
return std::vector<NodeAndMoveInfo>{
MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput),
MoveAll(target, ArgType::kOutput)};
}()},
intra_op_thread_pool_{intra_op_thread_pool} {
ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4");
}

NodeAttributes
DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) const {
NodeAttributes extra_attributes;

const auto* dq_node = runtime_state.selected_nodes.Input(0);
auto& attrs = dq_node->GetAttributes();
const auto* weight_shape = dq_node->InputDefs()[0]->Shape();

utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes);
utils::SetNodeAttribute(utils::MakeAttribute("N", weight_shape->dim(1).dim_value()), extra_attributes);
utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), extra_attributes);
// currently only 4bits is supported. In the future, derive bits from DQ's weight type.
utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast<int64_t>(4)), extra_attributes);
utils::SetNodeAttribute(utils::MakeAttribute("block_size", attrs.at("block_size").i()), extra_attributes);

return extra_attributes;
}

Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
const NodesToOptimize& selected_nodes,
Node& replacement_node) const {
const auto* dq_node = selected_nodes.Input(0);
const auto* weight_arg = dq_node->InputDefs()[0];
const auto* scale_arg = dq_node->InputDefs()[1];
const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr;
const auto& attrs = dq_node->GetAttributes();

const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr;
const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr;
const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr;
graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto);
graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto);
if (zp_arg) {
graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto);
}

auto K = weight_arg->Shape()->dim(0).dim_value();
auto N = weight_arg->Shape()->dim(1).dim_value();
auto block_size = attrs.at("block_size").i();
auto quant_num = (K + block_size - 1) / block_size;
auto blob_bytes = (block_size + 1) / 2;

// Unfortunately iterating the source data is complicated, the data maybe in
// external file, a raw buffer, or a repeated field depending on the data
// type. UnpackTensor() already contains some of these logic and is closest
// to what we need. But it does not handle external data.
Initializer weight_src(*weight_tensor_proto, graph.ModelPath());
Initializer scale_src(*scale_tensor_proto, graph.ModelPath());
std::optional<Initializer> zp_src;
Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName(weight_arg->Name() + "_T"),
std::vector<int64_t>{N, quant_num, blob_bytes});
Initializer scale_dst(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(scale_src.data_type()),
graph.GenerateNodeArgName(scale_arg->Name() + "_T"),
std::vector<int64_t>{N * quant_num});
std::optional<Initializer> zp_dst;

if (zp_tensor_proto) {
zp_src.emplace(*zp_tensor_proto, graph.ModelPath());
zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName(zp_arg->Name() + "_T"),
std::vector<int64_t>{N * ((quant_num + 1) / 2)});
} else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"),
std::vector<int64_t>{N * ((quant_num + 1) / 2)});
}

if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) {
MlasQDQTransposeBlockwiseQuantized<float, 4, true>(
weight_src.DataAsByteSpan().data(),
scale_src.data<float>(),
zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<float>(),
zp_dst ? zp_dst->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
static_cast<int>(block_size),
intra_op_thread_pool_);
} else {
MlasQDQTransposeBlockwiseQuantized<float, 4, false>(
weight_src.DataAsByteSpan().data(),
scale_src.data<float>(),
zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<float>(),
zp_dst ? zp_dst->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
static_cast<int>(block_size),
intra_op_thread_pool_);
}
} else {
if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) {
MlasQDQTransposeBlockwiseQuantized<MLFloat16, 4, true>(
weight_src.DataAsByteSpan().data(),
scale_src.data<MLFloat16>(),
zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<MLFloat16>(),
zp_dst ? zp_dst->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
static_cast<int>(block_size),
intra_op_thread_pool_);

} else {
MlasQDQTransposeBlockwiseQuantized<MLFloat16, 4, false>(
weight_src.DataAsByteSpan().data(),
scale_src.data<MLFloat16>(),
zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<MLFloat16>(),
zp_dst ? zp_dst->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
static_cast<int>(block_size),
intra_op_thread_pool_);
}
}

ONNX_NAMESPACE::TensorProto weight_T_tp;
ONNX_NAMESPACE::TensorProto scale_T_tp;
std::optional<ONNX_NAMESPACE::TensorProto> zp_T_tp;

// TODO(fajin): external_data to memory location to avoid arena allocation
// https://github.com/microsoft/onnxruntime/pull/12465
weight_dst.ToProto(weight_T_tp);
scale_dst.ToProto(scale_T_tp);
if (zp_dst) {
zp_T_tp.emplace();
zp_dst->ToProto(zp_T_tp.value());
}

auto& input_defs = replacement_node.MutableInputDefs();
input_defs.push_back(&graph_utils::AddInitializer(graph, weight_T_tp));
replacement_node.MutableInputArgsCount().push_back(1);
input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp));
replacement_node.MutableInputArgsCount().push_back(1);

if (zp_T_tp) {
input_defs.push_back(&graph_utils::AddInitializer(graph, zp_T_tp.value()));
replacement_node.MutableInputArgsCount().push_back(1);
}

return Status::OK();
}

static std::vector<NodeAndMoveInfo> GetGemmMoveInfo(bool does_q_node_exist) {
NTO::NodeLocation dq_A{NTO::NodeType::kInput, 0};
NTO::NodeLocation dq_B{NTO::NodeType::kInput, 1};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

#pragma once

#include <memory>
#include <string>
#include <vector>

#include "core/optimizer/selectors_actions/actions.h"
#include "core/platform/threadpool.h"

namespace onnxruntime {

Expand Down Expand Up @@ -76,6 +81,30 @@ struct MatMulReplaceWithQLinear : public Action {
BinaryReplaceWithQLinear qlinear_matmul_replacer_;
};

// used together with DQMatMulNodeGroupSelector, which does the sanity check
struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew {
DQMatMulToMatMulNBitsAction(int64_t accuracy_level,
concurrency::ThreadPool* intra_op_thread_pool);

private:
std::string OpType(const RuntimeState&) const override { return op_type_; }

std::string Domain(const RuntimeState&) const override { return domain_; }

NodeAttributes ExtraAttributes(const RuntimeState&) const override;

std::vector<NodeAndMoveInfo> ValueMoves(const RuntimeState&) const override { return value_moves_; }

// transpose initializers, and add to the MatMulNBits inputs
Status ProcessNewNode(Graph&, const NodesToOptimize&, Node&) const override;

const int64_t accuracy_level_;
const std::string domain_;
const std::string op_type_;
const std::vector<NodeAndMoveInfo> value_moves_;
concurrency::ThreadPool* intra_op_thread_pool_;
};

struct GemmReplaceWithQuant : public Action {
GemmReplaceWithQuant();

Expand Down
Loading

0 comments on commit 11bf309

Please sign in to comment.