From 2cf47df9e7ff3c36177b111e43560acf54180548 Mon Sep 17 00:00:00 2001
From: Carson M
Date: Sun, 20 Aug 2023 13:08:13 -0500
Subject: [PATCH] fix: only provide key to EP register if option was set, fix #67

---
Notes (ignored by `git am`, not part of the commit message):
 * With this change `map_keys!` only emits a key/value pair for options that
   are `Some(..)`; options left as `None` are not passed to the provider.
 * This copy was reconstructed from a whitespace-collapsed extraction: line
   breaks, indentation and `<...>` generic parameters were restored. The
   indentation and the integer type of `rpc_control_latency` are best
   guesses, and the author e-mail address was not recoverable. If context
   lines fail to match, try `git apply --ignore-whitespace`.

 src/execution_providers.rs | 210 +++++++++++++++----------------------
 src/lib.rs                 |   3 +-
 2 files changed, 86 insertions(+), 127 deletions(-)

diff --git a/src/execution_providers.rs b/src/execution_providers.rs
index 93ca207..68b643e 100644
--- a/src/execution_providers.rs
+++ b/src/execution_providers.rs
@@ -89,22 +89,22 @@ impl Default for CUDAExecutionProviderCuDNNConvAlgoSearch {
     }
 }
 
-#[derive(Debug, Clone)]
+#[derive(Default, Debug, Clone)]
 pub struct CUDAExecutionProviderOptions {
-    pub device_id: u32,
+    pub device_id: Option<u32>,
     /// The size limit of the device memory arena in bytes. This size limit is only for the execution provider’s arena.
     /// The total device memory usage may be higher.
-    pub gpu_mem_limit: size_t,
+    pub gpu_mem_limit: Option<size_t>,
     /// The strategy for extending the device memory arena. See [`ArenaExtendStrategy`].
-    pub arena_extend_strategy: ArenaExtendStrategy,
+    pub arena_extend_strategy: Option<ArenaExtendStrategy>,
     /// ORT leverages cuDNN for convolution operations and the first step in this process is to determine an
     /// “optimal” convolution algorithm to use while performing the convolution operation for the given input
     /// configuration (input shape, filter shape, etc.) in each `Conv` node. This option controlls the type of search
     /// done for cuDNN convolution algorithms. See [`CUDAExecutionProviderCuDNNConvAlgoSearch`] for more info.
-    pub cudnn_conv_algo_search: CUDAExecutionProviderCuDNNConvAlgoSearch,
+    pub cudnn_conv_algo_search: Option<CUDAExecutionProviderCuDNNConvAlgoSearch>,
     /// Whether to do copies in the default stream or use separate streams. The recommended setting is true. If false,
     /// there are race conditions and possibly better performance.
-    pub do_copy_in_default_stream: bool,
+    pub do_copy_in_default_stream: Option<bool>,
     /// ORT leverages cuDNN for convolution operations and the first step in this process is to determine an
     /// “optimal” convolution algorithm to use while performing the convolution operation for the given input
     /// configuration (input shape, filter shape, etc.) in each `Conv` node. This sub-step involves querying cuDNN for a
@@ -113,14 +113,14 @@ pub struct CUDAExecutionProviderOptions {
     ///
     /// When `cudnn_conv_use_max_workspace` is false, ORT will clamp the workspace size to 32 MB, which may lead to
     /// cuDNN selecting a suboptimal convolution algorithm. The recommended (and default) value is `true`.
-    pub cudnn_conv_use_max_workspace: bool,
+    pub cudnn_conv_use_max_workspace: Option<bool>,
     /// ORT leverages cuDNN for convolution operations. While cuDNN only takes 4-D or 5-D tensors as input for
     /// convolution operations, dimension padding is needed if the input is a 3-D tensor. Given an input tensor of shape
     /// `[N, C, D]`, it can be padded to `[N, C, D, 1]` or `[N, C, 1, D]`. While both of these padding methods produce
     /// the same output, the performance may differ because different convolution algorithms are selected,
     /// especially on some devices such as A100. By default, the input is padded to `[N, C, D, 1]`. Set this option to
     /// true to instead use `[N, C, 1, D]`.
-    pub cudnn_conv1d_pad_to_nc1d: bool,
+    pub cudnn_conv1d_pad_to_nc1d: Option<bool>,
     /// ORT supports the usage of CUDA Graphs to remove CPU overhead associated with launching CUDA kernels
     /// sequentially. To enable the usage of CUDA Graphs, set `enable_cuda_graph` to true.
     /// Currently, there are some constraints with regards to using the CUDA Graphs feature:
@@ -145,93 +145,42 @@ pub struct CUDAExecutionProviderOptions {
     /// > allocations, capturing the CUDA graph for the model, and then performing a graph replay to ensure that the
     /// > graph runs. Due to this, the latency associated with the first `run()` is bound to be high. Subsequent
     /// > `run()`s only perform graph replays of the graph captured and cached in the first `run()`.
-    pub enable_cuda_graph: bool,
+    pub enable_cuda_graph: Option<bool>,
     /// Whether to use strict mode in the `SkipLayerNormalization` implementation. The default and recommanded setting
     /// is `false`. If enabled, accuracy may improve slightly, but performance may decrease.
-    pub enable_skip_layer_norm_strict_mode: bool
+    pub enable_skip_layer_norm_strict_mode: Option<bool>
 }
 
-impl Default for CUDAExecutionProviderOptions {
-    fn default() -> Self {
-        Self {
-            device_id: 0,
-            gpu_mem_limit: size_t::MAX,
-            arena_extend_strategy: ArenaExtendStrategy::NextPowerOfTwo,
-            cudnn_conv_algo_search: CUDAExecutionProviderCuDNNConvAlgoSearch::Exhaustive,
-            do_copy_in_default_stream: true,
-            cudnn_conv_use_max_workspace: true,
-            cudnn_conv1d_pad_to_nc1d: false,
-            enable_cuda_graph: false,
-            enable_skip_layer_norm_strict_mode: false
-        }
-    }
-}
-
-#[derive(Debug, Clone)]
+#[derive(Default, Debug, Clone)]
 pub struct TensorRTExecutionProviderOptions {
-    pub device_id: u32,
-    pub max_workspace_size: u32,
-    pub max_partition_iterations: u32,
-    pub min_subgraph_size: u32,
-    pub fp16_enable: bool,
-    pub int8_enable: bool,
-    pub int8_calibration_table_name: String,
-    pub int8_use_native_calibration_table: bool,
-    pub dla_enable: bool,
-    pub dla_core: u32,
-    pub engine_cache_enable: bool,
-    pub engine_cache_path: String,
-    pub dump_subgraphs: bool,
-    pub force_sequential_engine_build: bool,
-    pub enable_context_memory_sharing: bool,
-    pub layer_norm_fp32_fallback: bool,
-    pub timing_cache_enable: bool,
-    pub force_timing_cache: bool,
-    pub detailed_build_log: bool,
-    pub enable_build_heuristics: bool,
-    pub enable_sparsity: bool,
-    pub builder_optimization_level: u8,
-    pub auxiliary_streams: i8,
-    pub tactic_sources: String,
-    pub extra_plugin_lib_paths: String,
-    pub profile_min_shapes: String,
-    pub profile_max_shapes: String,
-    pub profile_opt_shapes: String
-}
-
-impl Default for TensorRTExecutionProviderOptions {
-    fn default() -> Self {
-        Self {
-            device_id: 0,
-            max_workspace_size: 1073741824,
-            max_partition_iterations: 1000,
-            min_subgraph_size: 1,
-            fp16_enable: false,
-            int8_enable: false,
-            int8_calibration_table_name: String::default(),
-            int8_use_native_calibration_table: false,
-            dla_enable: false,
-            dla_core: 0,
-            engine_cache_enable: false,
-            engine_cache_path: String::default(),
-            dump_subgraphs: false,
-            force_sequential_engine_build: false,
-            enable_context_memory_sharing: false,
-            layer_norm_fp32_fallback: false,
-            timing_cache_enable: false,
-            force_timing_cache: false,
-            detailed_build_log: false,
-            enable_build_heuristics: false,
-            enable_sparsity: false,
-            builder_optimization_level: 3,
-            auxiliary_streams: -1,
-            tactic_sources: String::default(),
-            extra_plugin_lib_paths: String::default(),
-            profile_min_shapes: String::default(),
-            profile_max_shapes: String::default(),
-            profile_opt_shapes: String::default()
-        }
-    }
+    pub device_id: Option<u32>,
+    pub max_workspace_size: Option<u32>,
+    pub max_partition_iterations: Option<u32>,
+    pub min_subgraph_size: Option<u32>,
+    pub fp16_enable: Option<bool>,
+    pub int8_enable: Option<bool>,
+    pub int8_calibration_table_name: Option<String>,
+    pub int8_use_native_calibration_table: Option<bool>,
+    pub dla_enable: Option<bool>,
+    pub dla_core: Option<u32>,
+    pub engine_cache_enable: Option<bool>,
+    pub engine_cache_path: Option<String>,
+    pub dump_subgraphs: Option<bool>,
+    pub force_sequential_engine_build: Option<bool>,
+    pub enable_context_memory_sharing: Option<bool>,
+    pub layer_norm_fp32_fallback: Option<bool>,
+    pub timing_cache_enable: Option<bool>,
+    pub force_timing_cache: Option<bool>,
+    pub detailed_build_log: Option<bool>,
+    pub enable_build_heuristics: Option<bool>,
+    pub enable_sparsity: Option<bool>,
+    pub builder_optimization_level: Option<u8>,
+    pub auxiliary_streams: Option<i8>,
+    pub tactic_sources: Option<String>,
+    pub extra_plugin_lib_paths: Option<String>,
+    pub profile_min_shapes: Option<String>,
+    pub profile_max_shapes: Option<String>,
+    pub profile_opt_shapes: Option<String>
 }
 
 #[derive(Debug, Clone)]
@@ -350,8 +299,9 @@ pub struct NNAPIExecutionProviderOptions {
     /// will be ignored for Android API level 28 and lower.
     pub cpu_only: bool
 }
+
 #[derive(Debug, Clone)]
-enum QNNExecutionHTPPerformanceMode {
+pub enum QNNExecutionHTPPerformanceMode {
     /// Default mode.
     Default,
     Burst,
@@ -379,31 +329,32 @@ impl QNNExecutionHTPPerformanceMode {
         }
     }
 }
+
 #[derive(Debug, Clone)]
 pub struct QNNExecutionProviderOptions {
     /// The file path to QNN backend library.On Linux/Android: libQnnCpu.so for CPU backend, libQnnHtp.so for GPU
     /// backend.
-    backend_path: String,
+    pub backend_path: Option<String>,
     /// true to enable QNN graph creation from cached QNN context file. If it's enabled: QNN EP will
     /// load from cached QNN context binary if it exist. It will generate a context binary file if it's not exist
-    qnn_context_cache_enable: bool,
+    pub qnn_context_cache_enable: Option<bool>,
     /// explicitly provide the QNN context cache file. Default to model_file.onnx.bin if not provided.
-    qnn_context_cache_path: Option<String>,
+    pub qnn_context_cache_path: Option<String>,
     /// QNN profiling level, options: "off", "basic", "detailed". Default to off.
-    profiling_level: Option<String>,
+    pub profiling_level: Option<String>,
     /// Allows client to set up RPC control latency in microseconds.
-    rpc_control_latency: Option<u32>,
+    pub rpc_control_latency: Option<u32>,
     /// QNN performance mode, options: "burst", "balanced", "default", "high_performance",
     /// "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to
     /// "default".
-    htp_performance_mode: Option<QNNExecutionHTPPerformanceMode>
+    pub htp_performance_mode: Option<QNNExecutionHTPPerformanceMode>
 }
 
 impl Default for QNNExecutionProviderOptions {
     fn default() -> Self {
         Self {
-            backend_path: String::from("libQnnHtp.so"),
-            qnn_context_cache_enable: false,
+            backend_path: Some(String::from("libQnnHtp.so")),
+            qnn_context_cache_enable: Some(false),
             qnn_context_cache_path: Some(String::from("model_file.onnx.bin")),
             profiling_level: Some(String::from("off")),
             rpc_control_latency: Some(10),
@@ -462,8 +413,10 @@ macro_rules! map_keys {
             let mut keys = Vec::<CString>::new();
             let mut values = Vec::<CString>::new();
             $(
-                keys.push(CString::new(stringify!($fn_name)).unwrap());
-                values.push(CString::new(($ex).to_string().as_str()).unwrap());
+                if let Some(v) = $ex {
+                    keys.push(CString::new(stringify!($fn_name)).unwrap());
+                    values.push(CString::new(v.to_string().as_str()).unwrap());
+                }
             )*
             assert_eq!(keys.len(), values.len()); // sanity check
             let key_ptrs: Vec<*const c_char> = keys.iter().map(|k| k.as_ptr()).collect();
@@ -474,7 +427,12 @@ macro_rules! map_keys {
 }
 
 #[inline(always)]
-fn bool_as_int(x: bool) -> i32 {
+fn bool_as_int(x: Option<bool>) -> Option<i32> {
+    x.map(|x| if x { 1 } else { 0 })
+}
+
+#[inline(always)]
+fn bool_as_int_req(x: bool) -> i32 {
     if x { 1 } else { 0 }
 }
 
@@ -532,15 +490,15 @@ impl ExecutionProvider {
                 status_to_result(ortsys![unsafe CreateCUDAProviderOptions(&mut cuda_options)]).map_err(OrtError::ExecutionProvider)?;
                 let (key_ptrs, value_ptrs, len, keys, values) = map_keys! {
                     device_id = options.device_id,
-                    arena_extend_strategy = match options.arena_extend_strategy {
+                    arena_extend_strategy = options.arena_extend_strategy.as_ref().map(|v| match v {
                         ArenaExtendStrategy::NextPowerOfTwo => "kNextPowerOfTwo",
                         ArenaExtendStrategy::SameAsRequested => "kSameAsRequested"
-                    },
-                    cudnn_conv_algo_search = match options.cudnn_conv_algo_search {
+                    }),
+                    cudnn_conv_algo_search = options.cudnn_conv_algo_search.as_ref().map(|v| match v {
                         CUDAExecutionProviderCuDNNConvAlgoSearch::Exhaustive => "EXHAUSTIVE",
                         CUDAExecutionProviderCuDNNConvAlgoSearch::Heuristic => "HEURISTIC",
                         CUDAExecutionProviderCuDNNConvAlgoSearch::Default => "DEFAULT"
-                    },
+                    }),
                     do_copy_in_default_stream = bool_as_int(options.do_copy_in_default_stream),
                     cudnn_conv_use_max_workspace = bool_as_int(options.cudnn_conv_use_max_workspace),
                     cudnn_conv1d_pad_to_nc1d = bool_as_int(options.cudnn_conv1d_pad_to_nc1d),
@@ -571,11 +529,11 @@ impl ExecutionProvider {
                     trt_min_subgraph_size = options.min_subgraph_size,
                     trt_fp16_enable = bool_as_int(options.fp16_enable),
                     trt_int8_enable = bool_as_int(options.int8_enable),
-                    trt_int8_calibration_table_name = options.int8_calibration_table_name,
+                    trt_int8_calibration_table_name = options.int8_calibration_table_name.clone(),
                     trt_dla_enable = bool_as_int(options.dla_enable),
                     trt_dla_core = options.dla_core,
                     trt_engine_cache_enable = bool_as_int(options.engine_cache_enable),
-                    trt_engine_cache_path = options.engine_cache_path,
+                    trt_engine_cache_path = options.engine_cache_path.clone(),
                     trt_dump_subgraphs = bool_as_int(options.dump_subgraphs),
                     trt_force_sequential_engine_build = bool_as_int(options.force_sequential_engine_build),
                     trt_context_memory_sharing_enable = bool_as_int(options.enable_context_memory_sharing),
@@ -587,11 +545,11 @@ impl ExecutionProvider {
                     trt_sparsity_enable = bool_as_int(options.enable_sparsity),
                     trt_builder_optimization_level = options.builder_optimization_level,
                     trt_auxiliary_streams = options.auxiliary_streams,
-                    trt_tactic_sources = options.tactic_sources,
-                    trt_extra_plugin_lib_paths = options.extra_plugin_lib_paths,
-                    trt_profile_min_shapes = options.profile_min_shapes,
-                    trt_profile_max_shapes = options.profile_max_shapes,
-                    trt_profile_opt_shapes = options.profile_opt_shapes
+                    trt_tactic_sources = options.tactic_sources.clone(),
+                    trt_extra_plugin_lib_paths = options.extra_plugin_lib_paths.clone(),
+                    trt_profile_min_shapes = options.profile_min_shapes.clone(),
+                    trt_profile_max_shapes = options.profile_max_shapes.clone(),
+                    trt_profile_opt_shapes = options.profile_opt_shapes.clone()
                 };
                 if let Err(e) = status_to_result(ortsys![unsafe UpdateTensorRTProviderOptions(trt_options, key_ptrs.as_ptr(), value_ptrs.as_ptr(), len as _)])
                     .map_err(OrtError::ExecutionProvider)
@@ -643,18 +601,18 @@ impl ExecutionProvider {
             &Self::ROCm(options) => {
                 let rocm_options = sys::OrtROCMProviderOptions {
                     device_id: options.device_id,
-                    miopen_conv_exhaustive_search: bool_as_int(options.miopen_conv_exhaustive_search),
+                    miopen_conv_exhaustive_search: bool_as_int_req(options.miopen_conv_exhaustive_search),
                     gpu_mem_limit: options.gpu_mem_limit as _,
                     arena_extend_strategy: match options.arena_extend_strategy {
                         ArenaExtendStrategy::NextPowerOfTwo => 0,
                         ArenaExtendStrategy::SameAsRequested => 1
                     },
-                    do_copy_in_default_stream: bool_as_int(options.do_copy_in_default_stream),
-                    has_user_compute_stream: bool_as_int(options.user_compute_stream.is_some()),
+                    do_copy_in_default_stream: bool_as_int_req(options.do_copy_in_default_stream),
+                    has_user_compute_stream: bool_as_int_req(options.user_compute_stream.is_some()),
                     user_compute_stream: options.user_compute_stream.unwrap_or_else(ptr::null_mut),
                     default_memory_arena_cfg: options.default_memory_arena_cfg.unwrap_or_else(ptr::null_mut),
-                    tunable_op_enable: bool_as_int(options.tunable_op_enable),
-                    tunable_op_tuning_enable: bool_as_int(options.tunable_op_tuning_enable)
+                    tunable_op_enable: bool_as_int_req(options.tunable_op_enable),
+                    tunable_op_tuning_enable: bool_as_int_req(options.tunable_op_tuning_enable)
                 };
                 status_to_result(ortsys![unsafe SessionOptionsAppendExecutionProvider_ROCM(session_options, &rocm_options as *const _)])
                     .map_err(OrtError::ExecutionProvider)?;
@@ -697,9 +655,9 @@ impl ExecutionProvider {
                         .map(|x| x.as_bytes().as_ptr() as *const c_char)
                         .unwrap_or_else(ptr::null),
                     context: options.context,
-                    enable_opencl_throttling: bool_as_int(options.enable_opencl_throttling) as _,
-                    enable_dynamic_shapes: bool_as_int(options.enable_dynamic_shapes) as _,
-                    enable_vpu_fast_compile: bool_as_int(options.enable_vpu_fast_compile) as _
+                    enable_opencl_throttling: bool_as_int_req(options.enable_opencl_throttling) as _,
+                    enable_dynamic_shapes: bool_as_int_req(options.enable_dynamic_shapes) as _,
+                    enable_vpu_fast_compile: bool_as_int_req(options.enable_vpu_fast_compile) as _
                 };
                 status_to_result(ortsys![unsafe SessionOptionsAppendExecutionProvider_OpenVINO(session_options, &openvino_options as *const _)])
                     .map_err(OrtError::ExecutionProvider)?;
@@ -707,12 +665,12 @@ impl ExecutionProvider {
             #[cfg(any(feature = "load-dynamic", feature = "qnn"))]
             &Self::QNN(options) => {
                 let (key_ptrs, value_ptrs, len, keys, values) = map_keys! {
-                    backend_path = options.backend_path,
-                    profiling_level = options.profiling_level.clone().unwrap_or("off".to_string()),
+                    backend_path = options.backend_path.clone(),
+                    profiling_level = options.profiling_level.clone(),
                     qnn_context_cache_enable = bool_as_int(options.qnn_context_cache_enable),
-                    qnn_context_cache_path = options.qnn_context_cache_path.clone().unwrap_or("model_file.onnx.bin".to_string()),
-                    htp_performance_mode = options.htp_performance_mode.clone().unwrap_or(QNNExecutionHTPPerformanceMode::Default).as_str(),
-                    rpc_control_latency = options.rpc_control_latency.unwrap_or(10)
+                    qnn_context_cache_path = options.qnn_context_cache_path.clone(),
+                    htp_performance_mode = options.htp_performance_mode.as_ref().map(|v| v.as_str()),
+                    rpc_control_latency = options.rpc_control_latency
                 };
                 status_to_result(ortsys![unsafe SessionOptionsAppendExecutionProvider(
                     session_options,
diff --git a/src/lib.rs b/src/lib.rs
index 4425966..003cf8f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -29,7 +29,8 @@ pub use self::error::{OrtApiError, OrtError, OrtResult};
 pub use self::execution_providers::{
     ACLExecutionProviderOptions, ArenaExtendStrategy, CPUExecutionProviderOptions, CUDAExecutionProviderCuDNNConvAlgoSearch, CUDAExecutionProviderOptions,
     CoreMLExecutionProviderOptions, DirectMLExecutionProviderOptions, ExecutionProvider, NNAPIExecutionProviderOptions, OneDNNExecutionProviderOptions,
-    OpenVINOExecutionProviderOptions, ROCmExecutionProviderOptions, TensorRTExecutionProviderOptions
+    OpenVINOExecutionProviderOptions, QNNExecutionHTPPerformanceMode, QNNExecutionProviderOptions, ROCmExecutionProviderOptions,
+    TensorRTExecutionProviderOptions
 };
 pub use self::io_binding::IoBinding;
 pub use self::memory::{AllocationDevice, MemoryInfo};