Skip to content

Commit

Permalink
fix: only provide key to EP register if option was set, fix #67
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Aug 20, 2023
1 parent 64de7ec commit 2cf47df
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 127 deletions.
210 changes: 84 additions & 126 deletions src/execution_providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,22 @@ impl Default for CUDAExecutionProviderCuDNNConvAlgoSearch {
}
}

#[derive(Debug, Clone)]
#[derive(Default, Debug, Clone)]
pub struct CUDAExecutionProviderOptions {
pub device_id: u32,
pub device_id: Option<u32>,
/// The size limit of the device memory arena in bytes. This size limit is only for the execution provider’s arena.
/// The total device memory usage may be higher.
pub gpu_mem_limit: size_t,
pub gpu_mem_limit: Option<size_t>,
/// The strategy for extending the device memory arena. See [`ArenaExtendStrategy`].
pub arena_extend_strategy: ArenaExtendStrategy,
pub arena_extend_strategy: Option<ArenaExtendStrategy>,
/// ORT leverages cuDNN for convolution operations and the first step in this process is to determine an
/// “optimal” convolution algorithm to use while performing the convolution operation for the given input
/// configuration (input shape, filter shape, etc.) in each `Conv` node. This option controls the type of search
/// done for cuDNN convolution algorithms. See [`CUDAExecutionProviderCuDNNConvAlgoSearch`] for more info.
pub cudnn_conv_algo_search: CUDAExecutionProviderCuDNNConvAlgoSearch,
pub cudnn_conv_algo_search: Option<CUDAExecutionProviderCuDNNConvAlgoSearch>,
/// Whether to do copies in the default stream or use separate streams. The recommended setting is true. If false,
/// there are race conditions and possibly better performance.
pub do_copy_in_default_stream: bool,
pub do_copy_in_default_stream: Option<bool>,
/// ORT leverages cuDNN for convolution operations and the first step in this process is to determine an
/// “optimal” convolution algorithm to use while performing the convolution operation for the given input
/// configuration (input shape, filter shape, etc.) in each `Conv` node. This sub-step involves querying cuDNN for a
Expand All @@ -113,14 +113,14 @@ pub struct CUDAExecutionProviderOptions {
///
/// When `cudnn_conv_use_max_workspace` is false, ORT will clamp the workspace size to 32 MB, which may lead to
/// cuDNN selecting a suboptimal convolution algorithm. The recommended (and default) value is `true`.
pub cudnn_conv_use_max_workspace: bool,
pub cudnn_conv_use_max_workspace: Option<bool>,
/// ORT leverages cuDNN for convolution operations. While cuDNN only takes 4-D or 5-D tensors as input for
/// convolution operations, dimension padding is needed if the input is a 3-D tensor. Given an input tensor of shape
/// `[N, C, D]`, it can be padded to `[N, C, D, 1]` or `[N, C, 1, D]`. While both of these padding methods produce
/// the same output, the performance may differ because different convolution algorithms are selected,
/// especially on some devices such as A100. By default, the input is padded to `[N, C, D, 1]`. Set this option to
/// true to instead use `[N, C, 1, D]`.
pub cudnn_conv1d_pad_to_nc1d: bool,
pub cudnn_conv1d_pad_to_nc1d: Option<bool>,
/// ORT supports the usage of CUDA Graphs to remove CPU overhead associated with launching CUDA kernels
/// sequentially. To enable the usage of CUDA Graphs, set `enable_cuda_graph` to true.
/// Currently, there are some constraints with regards to using the CUDA Graphs feature:
Expand All @@ -145,93 +145,42 @@ pub struct CUDAExecutionProviderOptions {
/// > allocations, capturing the CUDA graph for the model, and then performing a graph replay to ensure that the
/// > graph runs. Due to this, the latency associated with the first `run()` is bound to be high. Subsequent
/// > `run()`s only perform graph replays of the graph captured and cached in the first `run()`.
pub enable_cuda_graph: bool,
pub enable_cuda_graph: Option<bool>,
/// Whether to use strict mode in the `SkipLayerNormalization` implementation. The default and recommended setting
/// is `false`. If enabled, accuracy may improve slightly, but performance may decrease.
pub enable_skip_layer_norm_strict_mode: bool
pub enable_skip_layer_norm_strict_mode: Option<bool>
}

impl Default for CUDAExecutionProviderOptions {
fn default() -> Self {
Self {
device_id: 0,
gpu_mem_limit: size_t::MAX,
arena_extend_strategy: ArenaExtendStrategy::NextPowerOfTwo,
cudnn_conv_algo_search: CUDAExecutionProviderCuDNNConvAlgoSearch::Exhaustive,
do_copy_in_default_stream: true,
cudnn_conv_use_max_workspace: true,
cudnn_conv1d_pad_to_nc1d: false,
enable_cuda_graph: false,
enable_skip_layer_norm_strict_mode: false
}
}
}

#[derive(Debug, Clone)]
#[derive(Default, Debug, Clone)]
pub struct TensorRTExecutionProviderOptions {
pub device_id: u32,
pub max_workspace_size: u32,
pub max_partition_iterations: u32,
pub min_subgraph_size: u32,
pub fp16_enable: bool,
pub int8_enable: bool,
pub int8_calibration_table_name: String,
pub int8_use_native_calibration_table: bool,
pub dla_enable: bool,
pub dla_core: u32,
pub engine_cache_enable: bool,
pub engine_cache_path: String,
pub dump_subgraphs: bool,
pub force_sequential_engine_build: bool,
pub enable_context_memory_sharing: bool,
pub layer_norm_fp32_fallback: bool,
pub timing_cache_enable: bool,
pub force_timing_cache: bool,
pub detailed_build_log: bool,
pub enable_build_heuristics: bool,
pub enable_sparsity: bool,
pub builder_optimization_level: u8,
pub auxiliary_streams: i8,
pub tactic_sources: String,
pub extra_plugin_lib_paths: String,
pub profile_min_shapes: String,
pub profile_max_shapes: String,
pub profile_opt_shapes: String
}

impl Default for TensorRTExecutionProviderOptions {
fn default() -> Self {
Self {
device_id: 0,
max_workspace_size: 1073741824,
max_partition_iterations: 1000,
min_subgraph_size: 1,
fp16_enable: false,
int8_enable: false,
int8_calibration_table_name: String::default(),
int8_use_native_calibration_table: false,
dla_enable: false,
dla_core: 0,
engine_cache_enable: false,
engine_cache_path: String::default(),
dump_subgraphs: false,
force_sequential_engine_build: false,
enable_context_memory_sharing: false,
layer_norm_fp32_fallback: false,
timing_cache_enable: false,
force_timing_cache: false,
detailed_build_log: false,
enable_build_heuristics: false,
enable_sparsity: false,
builder_optimization_level: 3,
auxiliary_streams: -1,
tactic_sources: String::default(),
extra_plugin_lib_paths: String::default(),
profile_min_shapes: String::default(),
profile_max_shapes: String::default(),
profile_opt_shapes: String::default()
}
}
pub device_id: Option<u32>,
pub max_workspace_size: Option<u32>,
pub max_partition_iterations: Option<u32>,
pub min_subgraph_size: Option<u32>,
pub fp16_enable: Option<bool>,
pub int8_enable: Option<bool>,
pub int8_calibration_table_name: Option<String>,
pub int8_use_native_calibration_table: Option<bool>,
pub dla_enable: Option<bool>,
pub dla_core: Option<u32>,
pub engine_cache_enable: Option<bool>,
pub engine_cache_path: Option<String>,
pub dump_subgraphs: Option<bool>,
pub force_sequential_engine_build: Option<bool>,
pub enable_context_memory_sharing: Option<bool>,
pub layer_norm_fp32_fallback: Option<bool>,
pub timing_cache_enable: Option<bool>,
pub force_timing_cache: Option<bool>,
pub detailed_build_log: Option<bool>,
pub enable_build_heuristics: Option<bool>,
pub enable_sparsity: Option<bool>,
pub builder_optimization_level: Option<u8>,
pub auxiliary_streams: Option<i8>,
pub tactic_sources: Option<String>,
pub extra_plugin_lib_paths: Option<String>,
pub profile_min_shapes: Option<String>,
pub profile_max_shapes: Option<String>,
pub profile_opt_shapes: Option<String>
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -350,8 +299,9 @@ pub struct NNAPIExecutionProviderOptions {
/// will be ignored for Android API level 28 and lower.
pub cpu_only: bool
}

#[derive(Debug, Clone)]
enum QNNExecutionHTPPerformanceMode {
pub enum QNNExecutionHTPPerformanceMode {
/// Default mode.
Default,
Burst,
Expand Down Expand Up @@ -379,31 +329,32 @@ impl QNNExecutionHTPPerformanceMode {
}
}
}

#[derive(Debug, Clone)]
pub struct QNNExecutionProviderOptions {
/// The file path to the QNN backend library. On Linux/Android: libQnnCpu.so for CPU backend, libQnnHtp.so for GPU
/// backend.
backend_path: String,
pub backend_path: Option<String>,
/// Set to true to enable QNN graph creation from a cached QNN context file. If enabled, the QNN EP will
/// load from the cached QNN context binary if it exists, and will generate a context binary file if it does not.
qnn_context_cache_enable: bool,
pub qnn_context_cache_enable: Option<bool>,
/// Explicitly provides the QNN context cache file. Defaults to model_file.onnx.bin if not provided.
qnn_context_cache_path: Option<String>,
pub qnn_context_cache_path: Option<String>,
/// QNN profiling level, options: "off", "basic", "detailed". Default to off.
profiling_level: Option<String>,
pub profiling_level: Option<String>,
/// Allows client to set up RPC control latency in microseconds.
rpc_control_latency: Option<u32>,
pub rpc_control_latency: Option<u32>,
/// QNN performance mode, options: "burst", "balanced", "default", "high_performance",
/// "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to
/// "default".
htp_performance_mode: Option<QNNExecutionHTPPerformanceMode>
pub htp_performance_mode: Option<QNNExecutionHTPPerformanceMode>
}

impl Default for QNNExecutionProviderOptions {
fn default() -> Self {
Self {
backend_path: String::from("libQnnHtp.so"),
qnn_context_cache_enable: false,
backend_path: Some(String::from("libQnnHtp.so")),
qnn_context_cache_enable: Some(false),
qnn_context_cache_path: Some(String::from("model_file.onnx.bin")),
profiling_level: Some(String::from("off")),
rpc_control_latency: Some(10),
Expand Down Expand Up @@ -462,8 +413,10 @@ macro_rules! map_keys {
let mut keys = Vec::<CString>::new();
let mut values = Vec::<CString>::new();
$(
keys.push(CString::new(stringify!($fn_name)).unwrap());
values.push(CString::new(($ex).to_string().as_str()).unwrap());
if let Some(v) = $ex {
keys.push(CString::new(stringify!($fn_name)).unwrap());
values.push(CString::new(v.to_string().as_str()).unwrap());
}
)*
assert_eq!(keys.len(), values.len()); // sanity check
let key_ptrs: Vec<*const c_char> = keys.iter().map(|k| k.as_ptr()).collect();
Expand All @@ -474,7 +427,12 @@ macro_rules! map_keys {
}

#[inline(always)]
fn bool_as_int(x: bool) -> i32 {
fn bool_as_int(x: Option<bool>) -> Option<i32> {
x.map(|x| if x { 1 } else { 0 })
}

#[inline(always)]
fn bool_as_int_req(x: bool) -> i32 {
if x { 1 } else { 0 }
}

Expand Down Expand Up @@ -532,15 +490,15 @@ impl ExecutionProvider {
status_to_result(ortsys![unsafe CreateCUDAProviderOptions(&mut cuda_options)]).map_err(OrtError::ExecutionProvider)?;
let (key_ptrs, value_ptrs, len, keys, values) = map_keys! {
device_id = options.device_id,
arena_extend_strategy = match options.arena_extend_strategy {
arena_extend_strategy = options.arena_extend_strategy.as_ref().map(|v| match v {
ArenaExtendStrategy::NextPowerOfTwo => "kNextPowerOfTwo",
ArenaExtendStrategy::SameAsRequested => "kSameAsRequested"
},
cudnn_conv_algo_search = match options.cudnn_conv_algo_search {
}),
cudnn_conv_algo_search = options.cudnn_conv_algo_search.as_ref().map(|v| match v {
CUDAExecutionProviderCuDNNConvAlgoSearch::Exhaustive => "EXHAUSTIVE",
CUDAExecutionProviderCuDNNConvAlgoSearch::Heuristic => "HEURISTIC",
CUDAExecutionProviderCuDNNConvAlgoSearch::Default => "DEFAULT"
},
}),
do_copy_in_default_stream = bool_as_int(options.do_copy_in_default_stream),
cudnn_conv_use_max_workspace = bool_as_int(options.cudnn_conv_use_max_workspace),
cudnn_conv1d_pad_to_nc1d = bool_as_int(options.cudnn_conv1d_pad_to_nc1d),
Expand Down Expand Up @@ -571,11 +529,11 @@ impl ExecutionProvider {
trt_min_subgraph_size = options.min_subgraph_size,
trt_fp16_enable = bool_as_int(options.fp16_enable),
trt_int8_enable = bool_as_int(options.int8_enable),
trt_int8_calibration_table_name = options.int8_calibration_table_name,
trt_int8_calibration_table_name = options.int8_calibration_table_name.clone(),
trt_dla_enable = bool_as_int(options.dla_enable),
trt_dla_core = options.dla_core,
trt_engine_cache_enable = bool_as_int(options.engine_cache_enable),
trt_engine_cache_path = options.engine_cache_path,
trt_engine_cache_path = options.engine_cache_path.clone(),
trt_dump_subgraphs = bool_as_int(options.dump_subgraphs),
trt_force_sequential_engine_build = bool_as_int(options.force_sequential_engine_build),
trt_context_memory_sharing_enable = bool_as_int(options.enable_context_memory_sharing),
Expand All @@ -587,11 +545,11 @@ impl ExecutionProvider {
trt_sparsity_enable = bool_as_int(options.enable_sparsity),
trt_builder_optimization_level = options.builder_optimization_level,
trt_auxiliary_streams = options.auxiliary_streams,
trt_tactic_sources = options.tactic_sources,
trt_extra_plugin_lib_paths = options.extra_plugin_lib_paths,
trt_profile_min_shapes = options.profile_min_shapes,
trt_profile_max_shapes = options.profile_max_shapes,
trt_profile_opt_shapes = options.profile_opt_shapes
trt_tactic_sources = options.tactic_sources.clone(),
trt_extra_plugin_lib_paths = options.extra_plugin_lib_paths.clone(),
trt_profile_min_shapes = options.profile_min_shapes.clone(),
trt_profile_max_shapes = options.profile_max_shapes.clone(),
trt_profile_opt_shapes = options.profile_opt_shapes.clone()
};
if let Err(e) = status_to_result(ortsys![unsafe UpdateTensorRTProviderOptions(trt_options, key_ptrs.as_ptr(), value_ptrs.as_ptr(), len as _)])
.map_err(OrtError::ExecutionProvider)
Expand Down Expand Up @@ -643,18 +601,18 @@ impl ExecutionProvider {
&Self::ROCm(options) => {
let rocm_options = sys::OrtROCMProviderOptions {
device_id: options.device_id,
miopen_conv_exhaustive_search: bool_as_int(options.miopen_conv_exhaustive_search),
miopen_conv_exhaustive_search: bool_as_int_req(options.miopen_conv_exhaustive_search),
gpu_mem_limit: options.gpu_mem_limit as _,
arena_extend_strategy: match options.arena_extend_strategy {
ArenaExtendStrategy::NextPowerOfTwo => 0,
ArenaExtendStrategy::SameAsRequested => 1
},
do_copy_in_default_stream: bool_as_int(options.do_copy_in_default_stream),
has_user_compute_stream: bool_as_int(options.user_compute_stream.is_some()),
do_copy_in_default_stream: bool_as_int_req(options.do_copy_in_default_stream),
has_user_compute_stream: bool_as_int_req(options.user_compute_stream.is_some()),
user_compute_stream: options.user_compute_stream.unwrap_or_else(ptr::null_mut),
default_memory_arena_cfg: options.default_memory_arena_cfg.unwrap_or_else(ptr::null_mut),
tunable_op_enable: bool_as_int(options.tunable_op_enable),
tunable_op_tuning_enable: bool_as_int(options.tunable_op_tuning_enable)
tunable_op_enable: bool_as_int_req(options.tunable_op_enable),
tunable_op_tuning_enable: bool_as_int_req(options.tunable_op_tuning_enable)
};
status_to_result(ortsys![unsafe SessionOptionsAppendExecutionProvider_ROCM(session_options, &rocm_options as *const _)])
.map_err(OrtError::ExecutionProvider)?;
Expand Down Expand Up @@ -697,22 +655,22 @@ impl ExecutionProvider {
.map(|x| x.as_bytes().as_ptr() as *const c_char)
.unwrap_or_else(ptr::null),
context: options.context,
enable_opencl_throttling: bool_as_int(options.enable_opencl_throttling) as _,
enable_dynamic_shapes: bool_as_int(options.enable_dynamic_shapes) as _,
enable_vpu_fast_compile: bool_as_int(options.enable_vpu_fast_compile) as _
enable_opencl_throttling: bool_as_int_req(options.enable_opencl_throttling) as _,
enable_dynamic_shapes: bool_as_int_req(options.enable_dynamic_shapes) as _,
enable_vpu_fast_compile: bool_as_int_req(options.enable_vpu_fast_compile) as _
};
status_to_result(ortsys![unsafe SessionOptionsAppendExecutionProvider_OpenVINO(session_options, &openvino_options as *const _)])
.map_err(OrtError::ExecutionProvider)?;
}
#[cfg(any(feature = "load-dynamic", feature = "qnn"))]
&Self::QNN(options) => {
let (key_ptrs, value_ptrs, len, keys, values) = map_keys! {
backend_path = options.backend_path,
profiling_level = options.profiling_level.clone().unwrap_or("off".to_string()),
backend_path = options.backend_path.clone(),
profiling_level = options.profiling_level.clone(),
qnn_context_cache_enable = bool_as_int(options.qnn_context_cache_enable),
qnn_context_cache_path = options.qnn_context_cache_path.clone().unwrap_or("model_file.onnx.bin".to_string()),
htp_performance_mode = options.htp_performance_mode.clone().unwrap_or(QNNExecutionHTPPerformanceMode::Default).as_str(),
rpc_control_latency = options.rpc_control_latency.unwrap_or(10)
qnn_context_cache_path = options.qnn_context_cache_path.clone(),
htp_performance_mode = options.htp_performance_mode.as_ref().map(|v| v.as_str()),
rpc_control_latency = options.rpc_control_latency
};
status_to_result(ortsys![unsafe SessionOptionsAppendExecutionProvider(
session_options,
Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ pub use self::error::{OrtApiError, OrtError, OrtResult};
pub use self::execution_providers::{
ACLExecutionProviderOptions, ArenaExtendStrategy, CPUExecutionProviderOptions, CUDAExecutionProviderCuDNNConvAlgoSearch, CUDAExecutionProviderOptions,
CoreMLExecutionProviderOptions, DirectMLExecutionProviderOptions, ExecutionProvider, NNAPIExecutionProviderOptions, OneDNNExecutionProviderOptions,
OpenVINOExecutionProviderOptions, ROCmExecutionProviderOptions, TensorRTExecutionProviderOptions
OpenVINOExecutionProviderOptions, QNNExecutionHTPPerformanceMode, QNNExecutionProviderOptions, ROCmExecutionProviderOptions,
TensorRTExecutionProviderOptions
};
pub use self::io_binding::IoBinding;
pub use self::memory::{AllocationDevice, MemoryInfo};
Expand Down

0 comments on commit 2cf47df

Please sign in to comment.