From 11cfd33b10782b20175f0a0139409a327a4dff39 Mon Sep 17 00:00:00 2001 From: Manickavela <50542248+manickavela29@users.noreply.github.com> Date: Mon, 15 Jul 2024 12:22:33 +0530 Subject: [PATCH] encoder only trt ep for transducer (#1130) --- .../csrc/online-zipformer2-transducer-model.cc | 14 +++++++++----- .../csrc/online-zipformer2-transducer-model.h | 5 ++++- sherpa-onnx/csrc/session.cc | 16 +++++++++++++++- sherpa-onnx/csrc/session.h | 3 +++ 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc index e5c448210..b6cb3e173 100644 --- a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc @@ -33,7 +33,9 @@ namespace sherpa_onnx { OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_WARNING), - sess_opts_(GetSessionOptions(config)), + encoder_sess_opts_(GetSessionOptions(config)), + decoder_sess_opts_(GetSessionOptions(config, "decoder")), + joiner_sess_opts_(GetSessionOptions(config, "joiner")), config_(config), allocator_{} { { @@ -57,7 +59,9 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( AAssetManager *mgr, const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_WARNING), config_(config), - sess_opts_(GetSessionOptions(config)), + encoder_sess_opts_(GetSessionOptions(config)), + decoder_sess_opts_(GetSessionOptions(config)), + joiner_sess_opts_(GetSessionOptions(config)), allocator_{} { { auto buf = ReadFile(mgr, config.transducer.encoder); @@ -79,7 +83,7 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, size_t model_data_length) { encoder_sess_ = std::make_unique(env_, model_data, - model_data_length, sess_opts_); + model_data_length, encoder_sess_opts_); GetInputNames(encoder_sess_.get(), &encoder_input_names_, &encoder_input_names_ptr_); @@ -132,7 +136,7 @@ void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, void OnlineZipformer2TransducerModel::InitDecoder(void *model_data, size_t model_data_length) { decoder_sess_ = std::make_unique(env_, model_data, - model_data_length, sess_opts_); + model_data_length, decoder_sess_opts_); GetInputNames(decoder_sess_.get(), &decoder_input_names_, &decoder_input_names_ptr_); @@ -157,7 +161,7 @@ void OnlineZipformer2TransducerModel::InitDecoder(void *model_data, void OnlineZipformer2TransducerModel::InitJoiner(void *model_data, size_t model_data_length) { joiner_sess_ = std::make_unique(env_, model_data, - model_data_length, sess_opts_); + model_data_length, joiner_sess_opts_); GetInputNames(joiner_sess_.get(), &joiner_input_names_, &joiner_input_names_ptr_); diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.h b/sherpa-onnx/csrc/online-zipformer2-transducer-model.h index 07c9e9252..aa0f46f81 100644 --- a/sherpa-onnx/csrc/online-zipformer2-transducer-model.h +++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.h @@ -65,7 +65,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { private: Ort::Env env_; - Ort::SessionOptions sess_opts_; + Ort::SessionOptions encoder_sess_opts_; + Ort::SessionOptions decoder_sess_opts_; + Ort::SessionOptions joiner_sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; std::unique_ptr encoder_sess_; diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index b6fdaaa84..093465063 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -94,7 +94,6 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, std::to_string(trt_config.trt_timing_cache_enable); auto trt_dump_subgraphs = std::to_string(trt_config.trt_dump_subgraphs); - std::vector trt_options = { {"device_id", device_id.c_str()}, {"trt_max_workspace_size", trt_max_workspace_size.c_str()}, @@ -223,6 +222,21 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) { config.provider_config.provider, &config.provider_config); } +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, + const std::string &model_type) { + /* + Transducer models : Only encoder will run with tensorrt, + decoder and joiner will run with cuda + */ + if(config.provider_config.provider == "trt" && + (model_type == "decoder" || model_type == "joiner")) { + return GetSessionOptionsImpl(config.num_threads, + "cuda", &config.provider_config); + } + return GetSessionOptionsImpl(config.num_threads, + config.provider_config.provider, &config.provider_config); +} + Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { return GetSessionOptionsImpl(config.num_threads, config.provider); } diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index a4121436a..691a2ff3c 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -24,6 +24,9 @@ namespace sherpa_onnx { Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, + const std::string &model_type); + Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config);