[TensorRT] Enable refitting an embedded engine when provided as byte stream (microsoft#21357)

### Description

This allows refitting an engine from an ONNX model that is not available as a plain file on
disk. This matters for encrypted ONNX files, whose weights can only be supplied as a decrypted in-memory byte stream.
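For illustration, a minimal sketch of the intended usage: the application loads (and, if needed, decrypts) the ONNX model into memory and hands the buffer to the TensorRT EP through the new options. `LoadDecryptedModel`, the surrounding function, and the model file name are hypothetical placeholders; the option fields follow the `OrtTensorRTProviderOptionsV2` struct shown below, and error handling is elided.

```cpp
#include <vector>
#include "onnxruntime_c_api.h"
// Full definition of OrtTensorRTProviderOptionsV2 (header name may vary by packaging):
#include "tensorrt_provider_options.h"

std::vector<char> LoadDecryptedModel();  // hypothetical: decrypts the ONNX file into memory

void CreateSessionWithInMemoryWeights(OrtEnv* env) {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  std::vector<char> model_bytes = LoadDecryptedModel();

  OrtTensorRTProviderOptionsV2 trt_options{};
  trt_options.device_id = 0;
  trt_options.trt_weight_stripped_engine_enable = 1;  // engine was built without weights
  // New in this change: weights come from an in-memory byte stream, not a file path.
  trt_options.trt_onnx_bytestream = model_bytes.data();
  trt_options.trt_onnx_bytestream_size = model_bytes.size();

  OrtSessionOptions* session_options = nullptr;
  api->CreateSessionOptions(&session_options);
  api->SessionOptionsAppendExecutionProvider_TensorRT_V2(session_options, &trt_options);

  // model_bytes must stay alive until the session is created and the engine refitted.
  OrtSession* session = nullptr;
  api->CreateSession(env, ORT_TSTR("model_with_ctx.onnx"), session_options, &session);
  // ... run inference; release session and session_options when done ...
}
```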
gedoensmax authored Jul 20, 2024
1 parent 34cd2e8 commit 5bec522
Showing 13 changed files with 225 additions and 35 deletions.
include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -16,6 +16,7 @@ struct OrtTensorRTProviderOptionsV2 {
int device_id{0}; // cuda device id.
int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream.
void* user_compute_stream{nullptr}; // user specified CUDA compute stream.
// can be updated using: UpdateTensorRTProviderOptionsWithValue
int trt_max_partition_iterations{1000}; // maximum iterations for TensorRT parser to get capability
int trt_min_subgraph_size{1}; // minimum size of TensorRT subgraphs
size_t trt_max_workspace_size{1 << 30}; // maximum workspace size for TensorRT.
@@ -78,6 +79,12 @@ struct OrtTensorRTProviderOptionsV2 {
const char* trt_onnx_model_folder_path{nullptr}; // Folder path relative to the current working directory for
// the ONNX model containing the weights (applicable only when
// the "trt_weight_stripped_engine_enable" option is enabled)
const void* trt_onnx_bytestream{nullptr}; // The byte stream of the original ONNX model containing the weights
// (applicable only when the "trt_weight_stripped_engine_enable"
// option is enabled)
// can be updated using: UpdateTensorRTProviderOptionsWithValue
size_t trt_onnx_bytestream_size{0}; // size of the byte stream provided as "trt_onnx_bytestream"
// can be updated using: UpdateTensorRTProviderOptionsWithValue

const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true
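As the comments in the struct note, both new fields can also be set through the C API's `UpdateTensorRTProviderOptionsWithValue` when the options object is created opaquely with `CreateTensorRTProviderOptions`. A hedged sketch (`api` and `model_bytes` as in the example above; it is assumed here that the size value is passed by pointer, mirroring how pointer-typed options are forwarded):

```cpp
// Sketch: error handling elided.
OrtTensorRTProviderOptionsV2* trt_options = nullptr;
api->CreateTensorRTProviderOptions(&trt_options);

// The buffer pointer is stored directly under the "trt_onnx_bytestream" key ...
api->UpdateTensorRTProviderOptionsWithValue(trt_options, "trt_onnx_bytestream",
                                            model_bytes.data());
// ... while the size is assumed to be read through the supplied pointer.
size_t bytestream_size = model_bytes.size();
api->UpdateTensorRTProviderOptionsWithValue(trt_options, "trt_onnx_bytestream_size",
                                            &bytestream_size);
```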
24 changes: 23 additions & 1 deletion onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
@@ -274,6 +274,9 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
auto& attrs = node->GetAttributes();

const int64_t embed_mode = attrs.at(EMBED_MODE).i();
// Only perform path checks if the model was not provided as a byte buffer
bool make_secure_path_checks = !GetModelPath(graph_viewer).empty();

if (embed_mode) {
// Get engine from byte stream.
const std::string& context_binary = attrs.at(EP_CACHE_CONTEXT).s();
@@ -284,6 +287,23 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from binary data");
}

if (weight_stripped_engine_refit_) {
const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s();
std::string placeholder;
auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename,
onnx_model_folder_path_,
placeholder,
make_secure_path_checks,
onnx_model_bytestream_,
onnx_model_bytestream_size_,
(*trt_engine_).get(),
false /* serialize refitted engine to disk */,
detailed_build_log_);
if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
}
}
} else {
// Get engine from cache file.
std::string cache_path = attrs.at(EP_CACHE_CONTEXT).s();
@@ -343,7 +363,9 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename,
onnx_model_folder_path_,
weight_stripped_engine_cache,
true /* path check for security */,
make_secure_path_checks,
onnx_model_bytestream_,
onnx_model_bytestream_size_,
(*trt_engine_).get(),
true /* serialize refitted engine to disk */,
detailed_build_log_);
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
@@ -52,13 +52,17 @@ class TensorRTCacheModelHandler {
std::string compute_capability,
bool weight_stripped_engine_refit,
std::string onnx_model_folder_path,
const void* onnx_model_bytestream,
size_t onnx_model_bytestream_size,
bool detailed_build_log)
: trt_engine_(trt_engine),
trt_runtime_(trt_runtime),
ep_context_model_path_(ep_context_model_path),
compute_capability_(compute_capability),
weight_stripped_engine_refit_(weight_stripped_engine_refit),
onnx_model_folder_path_(onnx_model_folder_path),
onnx_model_bytestream_(onnx_model_bytestream),
onnx_model_bytestream_size_(onnx_model_bytestream_size),
detailed_build_log_(detailed_build_log) {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);
@@ -74,6 +78,8 @@ class TensorRTCacheModelHandler {
std::string compute_capability_;
bool weight_stripped_engine_refit_;
std::string onnx_model_folder_path_;
const void* onnx_model_bytestream_;
size_t onnx_model_bytestream_size_;
bool detailed_build_log_;
}; // TensorRTCacheModelHandler
} // namespace onnxruntime
81 changes: 61 additions & 20 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1333,6 +1333,14 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
engine_cache_enable_ = info.engine_cache_enable;
weight_stripped_engine_enable_ = info.weight_stripped_engine_enable;
onnx_model_folder_path_ = info.onnx_model_folder_path;
onnx_model_bytestream_ = info.onnx_bytestream;
onnx_model_bytestream_size_ = info.onnx_bytestream_size;
if ((onnx_model_bytestream_ != nullptr && onnx_model_bytestream_size_ == 0) ||
(onnx_model_bytestream_ == nullptr && onnx_model_bytestream_size_ != 0)) {
ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"'trt_onnx_bytestream' and 'trt_onnx_bytestream_size' "
"have to be provided together."));
}
timing_cache_enable_ = info.timing_cache_enable;
force_timing_cache_match_ = info.force_timing_cache;
detailed_build_log_ = info.detailed_build_log;
@@ -1757,7 +1765,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
<< ", trt_ep_context_file_path: " << ep_context_file_path_
<< ", trt_ep_context_embed_mode: " << ep_context_embed_mode_
<< ", trt_cache_prefix: " << cache_prefix_
<< ", trt_engine_hw_compatible: " << engine_hw_compatible_;
<< ", trt_engine_hw_compatible: " << engine_hw_compatible_
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_;
}

TensorrtExecutionProvider::~TensorrtExecutionProvider() {
@@ -2597,38 +2606,61 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
std::string& onnx_model_folder_path,
std::string& weight_stripped_engine_cache_path,
bool path_check,
const void* onnx_model_bytestream,
size_t onnx_model_bytestream_size,
nvinfer1::ICudaEngine* trt_engine,
bool serialize_refitted_engine,
bool detailed_build_log) {
#if NV_TENSORRT_MAJOR >= 10
bool refit_from_file = onnx_model_bytestream == nullptr && onnx_model_bytestream_size == 0;
std::filesystem::path onnx_model_path{onnx_model_folder_path};
onnx_model_path.append(onnx_model_filename);
if (path_check && IsAbsolutePath(onnx_model_path.string())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"For security purpose, the ONNX model path should be set with "
"a relative path, but it is an absolute path: " +
onnx_model_path.string());
}
if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"The ONNX model path has '..'. For security purpose, it's not "
"allowed to point outside the directory.");
}
if (refit_from_file) {
if (!onnx_model_filename.empty()) {
onnx_model_path.append(onnx_model_filename);
}
if (onnx_model_path.empty()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"The ONNX model was not provided as a file path. "
"Please provide an ONNX byte stream to enable refitting of the weight-stripped engine.");
} else {
// check if file path to ONNX is legal
if (path_check && IsAbsolutePath(onnx_model_path.string())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"For security purpose, the ONNX model path should be set with "
"a relative path, but it is an absolute path: " +
onnx_model_path.string());
}
if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"The ONNX model path has '..'. For security purpose, it's not "
"allowed to point outside the directory.");
}

if (!std::filesystem::exists(onnx_model_path)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"The ONNX model " + onnx_model_path.string() +
" does not exist.");
if (!(std::filesystem::exists(onnx_model_path) && std::filesystem::is_regular_file(onnx_model_path))) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"The ONNX model " + onnx_model_path.string() +
" does not exist or is not a regular file.");
}
}
}

// weight-stripped engine refit logic
TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log);
auto refitter = std::unique_ptr<nvinfer1::IRefitter>(nvinfer1::createInferRefitter(*trt_engine, trt_logger));
auto parser_refitter = std::unique_ptr<nvonnxparser::IParserRefitter>(
nvonnxparser::createParserRefitter(*refitter, trt_logger));
if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string());
if (refit_from_file) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from file on disk: " << onnx_model_path.string();
if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string());
}
} else {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from byte array";
if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestraem");
}
}
if (refitter->refitCudaEngine()) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully refitted the weight-stripped engine.";
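Condensed out of the function above, the TensorRT 10 refit sequence amounts to the following sketch; `engine` is the deserialized weight-stripped engine, `logger` a TensorRT logger, and the standalone helper is illustrative rather than part of the EP.

```cpp
#include <memory>
#include "NvInfer.h"
#include "NvOnnxParser.h"

// Wrap the deserialized weight-stripped engine in an IRefitter, let the ONNX
// parser refitter supply the weights, then refit the engine in place.
bool RefitFromBytes(nvinfer1::ICudaEngine& engine, nvinfer1::ILogger& logger,
                    const void* onnx_bytes, size_t onnx_size) {
  auto refitter = std::unique_ptr<nvinfer1::IRefitter>(
      nvinfer1::createInferRefitter(engine, logger));
  auto parser_refitter = std::unique_ptr<nvonnxparser::IParserRefitter>(
      nvonnxparser::createParserRefitter(*refitter, logger));
  // refitFromBytes() consumes the in-memory ONNX model; refitFromFile() is the
  // on-disk equivalent used when no byte stream was provided.
  if (!parser_refitter->refitFromBytes(onnx_bytes, onnx_size)) return false;
  return refitter->refitCudaEngine();  // returns true on success
}
```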
@@ -3212,10 +3244,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
}

if (weight_stripped_engine_refit_) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refit engine from main ONNX file after engine build";
char* onnx = string_buf.data();
size_t onnx_size = string_buf.size();
auto status = RefitEngine(model_path_,
onnx_model_folder_path_,
engine_cache_path,
false /* path check for security */,
onnx,
onnx_size,
trt_engine.get(),
true /* serialize refitted engine to disk */,
detailed_build_log_);
@@ -3685,6 +3722,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
onnx_model_folder_path_,
engine_cache_path,
false /* path check for security */,
onnx_model_bytestream_,
onnx_model_bytestream_size_,
trt_engine,
true /* serialize refitted engine to disk */,
detailed_build_log_);
@@ -3910,6 +3949,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
compute_capability_,
weight_stripped_engine_enable_,
onnx_model_folder_path_,
onnx_model_bytestream_,
onnx_model_bytestream_size_,
detailed_build_log_);
auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer);
if (status != Status::OK()) {
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -274,13 +274,12 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool IsGraphCaptured(int graph_annotation_id) const override;
Status ReplayGraph(int graph_annotation_id) override;

/**
* Refit the weight-stripped engine
*/
static common::Status RefitEngine(std::string onnx_model_filename,
std::string& onnx_model_folder_path,
std::string& weight_stripped_engine_cache_path,
bool path_check,
const void* onnx_model_bytestream,
size_t onnx_model_bytestream_size,
nvinfer1::ICudaEngine* trt_engine,
bool serialize_refitted_engine,
bool detailed_build_log);
@@ -305,6 +304,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool weight_stripped_engine_enable_ = false;
bool weight_stripped_engine_refit_ = false;
std::string onnx_model_folder_path_;
const void* onnx_model_bytestream_;
size_t onnx_model_bytestream_size_;
bool build_heuristics_enable_ = false;
bool sparsity_enable_ = false;
int builder_optimization_level_ = 3;