diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index cdee9a209..fec122ed2 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -58,21 +58,51 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer( recognizer_config.model_config.transducer.encoder = SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); + recognizer_config.model_config.transducer.encoder_buf_begin = + SHERPA_ONNX_OR(config->model_config.transducer.encoder_buf_begin, nullptr); + recognizer_config.model_config.transducer.encoder_buf_end = + SHERPA_ONNX_OR(config->model_config.transducer.encoder_buf_end, nullptr); + recognizer_config.model_config.transducer.decoder = SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); + recognizer_config.model_config.transducer.decoder_buf_begin = + SHERPA_ONNX_OR(config->model_config.transducer.decoder_buf_begin, nullptr); + recognizer_config.model_config.transducer.decoder_buf_end = + SHERPA_ONNX_OR(config->model_config.transducer.decoder_buf_end, nullptr); + recognizer_config.model_config.transducer.joiner = SHERPA_ONNX_OR(config->model_config.transducer.joiner, ""); + recognizer_config.model_config.transducer.joiner_buf_begin = + SHERPA_ONNX_OR(config->model_config.transducer.joiner_buf_begin, nullptr); + recognizer_config.model_config.transducer.joiner_buf_end = + SHERPA_ONNX_OR(config->model_config.transducer.joiner_buf_end, nullptr); recognizer_config.model_config.paraformer.encoder = SHERPA_ONNX_OR(config->model_config.paraformer.encoder, ""); + recognizer_config.model_config.paraformer.encoder_buf_begin = + SHERPA_ONNX_OR(config->model_config.paraformer.encoder_buf_begin, nullptr); + recognizer_config.model_config.paraformer.encoder_buf_end = + SHERPA_ONNX_OR(config->model_config.paraformer.encoder_buf_end, nullptr); + recognizer_config.model_config.paraformer.decoder = SHERPA_ONNX_OR(config->model_config.paraformer.decoder, ""); + recognizer_config.model_config.paraformer.decoder_buf_begin = + SHERPA_ONNX_OR(config->model_config.paraformer.decoder_buf_begin, nullptr); + recognizer_config.model_config.paraformer.decoder_buf_end = + SHERPA_ONNX_OR(config->model_config.paraformer.decoder_buf_end, nullptr); recognizer_config.model_config.zipformer2_ctc.model = SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, ""); + recognizer_config.model_config.zipformer2_ctc.model_buf_begin = + SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model_buf_begin, nullptr); + recognizer_config.model_config.zipformer2_ctc.model_buf_end = + SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model_buf_end, nullptr); recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); + recognizer_config.model_config.tokens_buf_str = + SHERPA_ONNX_OR(config->model_config.tokens_buf_str, ""); + recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); recognizer_config.model_config.provider_config.provider = @@ -118,6 +148,7 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer( SHERPA_ONNX_OR(config->rule3_min_utterance_length, 20); recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); + recognizer_config.hotwords_buf_str = SHERPA_ONNX_OR(config->hotwords_buf_str, ""); recognizer_config.hotwords_score = SHERPA_ONNX_OR(config->hotwords_score, 1.5); @@ -366,24 +397,66 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig( recognizer_config.model_config.transducer.encoder_filename = SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); + recognizer_config.model_config.transducer.encoder_buf_begin = + SHERPA_ONNX_OR(config->model_config.transducer.encoder_buf_begin, nullptr); + + recognizer_config.model_config.transducer.encoder_buf_end = + SHERPA_ONNX_OR(config->model_config.transducer.encoder_buf_end, nullptr); + recognizer_config.model_config.transducer.decoder_filename = SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); + recognizer_config.model_config.transducer.decoder_buf_begin = + SHERPA_ONNX_OR(config->model_config.transducer.decoder_buf_begin, nullptr); + + recognizer_config.model_config.transducer.decoder_buf_end = + SHERPA_ONNX_OR(config->model_config.transducer.decoder_buf_end, nullptr); + recognizer_config.model_config.transducer.joiner_filename = SHERPA_ONNX_OR(config->model_config.transducer.joiner, ""); + recognizer_config.model_config.transducer.joiner_buf_begin = + SHERPA_ONNX_OR(config->model_config.transducer.joiner_buf_begin, nullptr); + + recognizer_config.model_config.transducer.joiner_buf_end = + SHERPA_ONNX_OR(config->model_config.transducer.joiner_buf_end, nullptr); + recognizer_config.model_config.paraformer.model = SHERPA_ONNX_OR(config->model_config.paraformer.model, ""); + recognizer_config.model_config.paraformer.model_buf_begin = + SHERPA_ONNX_OR(config->model_config.paraformer.model_buf_begin, nullptr); + + recognizer_config.model_config.paraformer.model_buf_end = + SHERPA_ONNX_OR(config->model_config.paraformer.model_buf_end, nullptr); + recognizer_config.model_config.nemo_ctc.model = SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, ""); + recognizer_config.model_config.nemo_ctc.model_buf_begin = + SHERPA_ONNX_OR(config->model_config.nemo_ctc.model_buf_begin, nullptr); + + recognizer_config.model_config.nemo_ctc.model_buf_end = + SHERPA_ONNX_OR(config->model_config.nemo_ctc.model_buf_end, nullptr); + recognizer_config.model_config.whisper.encoder = SHERPA_ONNX_OR(config->model_config.whisper.encoder, ""); + recognizer_config.model_config.whisper.encoder_buf_begin = + SHERPA_ONNX_OR(config->model_config.whisper.encoder_buf_begin, nullptr); + + recognizer_config.model_config.whisper.encoder_buf_end = + SHERPA_ONNX_OR(config->model_config.whisper.encoder_buf_end, nullptr); + recognizer_config.model_config.whisper.decoder = SHERPA_ONNX_OR(config->model_config.whisper.decoder, ""); + recognizer_config.model_config.whisper.decoder_buf_begin = + SHERPA_ONNX_OR(config->model_config.whisper.decoder_buf_begin, nullptr); + + recognizer_config.model_config.whisper.decoder_buf_end = + SHERPA_ONNX_OR(config->model_config.whisper.decoder_buf_end, nullptr); + recognizer_config.model_config.whisper.language = SHERPA_ONNX_OR(config->model_config.whisper.language, ""); @@ -398,9 +471,16 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig( recognizer_config.model_config.tdnn.model = SHERPA_ONNX_OR(config->model_config.tdnn.model, ""); + recognizer_config.model_config.tdnn.model_buf_begin = + SHERPA_ONNX_OR(config->model_config.tdnn.model_buf_begin, nullptr); + recognizer_config.model_config.tdnn.model_buf_end = + SHERPA_ONNX_OR(config->model_config.tdnn.model_buf_end, nullptr); recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); + recognizer_config.model_config.tokens_buf_str = + SHERPA_ONNX_OR(config->model_config.tokens_buf_str, ""); + recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); recognizer_config.model_config.debug = @@ -451,6 +531,7 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig( SHERPA_ONNX_OR(config->max_active_paths, 4); recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); + recognizer_config.hotwords_buf_str = SHERPA_ONNX_OR(config->hotwords_buf_str, ""); recognizer_config.hotwords_score = SHERPA_ONNX_OR(config->hotwords_score, 1.5); @@ -639,24 +720,53 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); spotter_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); - + spotter_config.model_config.transducer.encoder = SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); + spotter_config.model_config.transducer.encoder_buf_begin = + SHERPA_ONNX_OR(config->model_config.transducer.encoder_buf_begin, nullptr); + spotter_config.model_config.transducer.encoder_buf_end = + SHERPA_ONNX_OR(config->model_config.transducer.encoder_buf_end, nullptr); + spotter_config.model_config.transducer.decoder = SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); + spotter_config.model_config.transducer.decoder_buf_begin = + SHERPA_ONNX_OR(config->model_config.transducer.decoder_buf_begin, nullptr); + spotter_config.model_config.transducer.decoder_buf_end = + SHERPA_ONNX_OR(config->model_config.transducer.decoder_buf_end, nullptr); + spotter_config.model_config.transducer.joiner = SHERPA_ONNX_OR(config->model_config.transducer.joiner, ""); + spotter_config.model_config.transducer.joiner_buf_begin = + SHERPA_ONNX_OR(config->model_config.transducer.joiner_buf_begin, nullptr); + spotter_config.model_config.transducer.joiner_buf_end = + SHERPA_ONNX_OR(config->model_config.transducer.joiner_buf_end, nullptr); spotter_config.model_config.paraformer.encoder = SHERPA_ONNX_OR(config->model_config.paraformer.encoder, ""); + spotter_config.model_config.paraformer.encoder_buf_begin = + SHERPA_ONNX_OR(config->model_config.paraformer.encoder_buf_begin, nullptr); + spotter_config.model_config.paraformer.encoder_buf_end = + SHERPA_ONNX_OR(config->model_config.paraformer.encoder_buf_end, nullptr); + spotter_config.model_config.paraformer.decoder = SHERPA_ONNX_OR(config->model_config.paraformer.decoder, ""); + spotter_config.model_config.paraformer.decoder_buf_begin = + SHERPA_ONNX_OR(config->model_config.paraformer.decoder_buf_begin, nullptr); + spotter_config.model_config.paraformer.decoder_buf_end = + SHERPA_ONNX_OR(config->model_config.paraformer.decoder_buf_end, nullptr); spotter_config.model_config.zipformer2_ctc.model = SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, ""); + spotter_config.model_config.zipformer2_ctc.model_buf_begin = + SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model_buf_begin, nullptr); + spotter_config.model_config.zipformer2_ctc.model_buf_end = + SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model_buf_end, nullptr); spotter_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); + spotter_config.model_config.tokens_buf_str = + SHERPA_ONNX_OR(config->model_config.tokens_buf_str, ""); spotter_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); spotter_config.model_config.provider_config.provider = @@ -681,6 +791,7 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( SHERPA_ONNX_OR(config->keywords_threshold, 0.25); spotter_config.keywords_file = SHERPA_ONNX_OR(config->keywords_file, ""); + spotter_config.keywords_buf_str = SHERPA_ONNX_OR(config->keywords_buf_str, ""); if (config->model_config.debug) { SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index d4844aed1..b3d23bdc0 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -54,8 +54,11 @@ extern "C" { /// from there. SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig { const char *encoder; + const char *encoder_buf_begin, *encoder_buf_end; // if non-null, loading the encoder from the buffer in prioriy const char *decoder; + const char *decoder_buf_begin, *decoder_buf_end; // if non-null, loading the decoder from the buffer in prioriy const char *joiner; + const char *joiner_buf_begin, *joiner_buf_end; // if non-null, loading the joiner from the buffer in prioriy } SherpaOnnxOnlineTransducerModelConfig; // please visit @@ -63,7 +66,9 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig { // to download pre-trained streaming paraformer models SHERPA_ONNX_API typedef struct SherpaOnnxOnlineParaformerModelConfig { const char *encoder; + const char *encoder_buf_begin, *encoder_buf_end; // if non-null, loading the encoder from the buffer in prioriy const char *decoder; + const char *decoder_buf_begin, *decoder_buf_end; // if non-null, loading the decoder from the buffer in prioriy } SherpaOnnxOnlineParaformerModelConfig; // Please visit @@ -71,6 +76,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineParaformerModelConfig { // to download pre-trained streaming zipformer2 ctc models SHERPA_ONNX_API typedef struct SherpaOnnxOnlineZipformer2CtcModelConfig { const char *model; + const char *model_buf_begin, *model_buf_end; // if non-null, loading the model from the buffer in prioriy } SherpaOnnxOnlineZipformer2CtcModelConfig; SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig { @@ -78,6 +84,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig { SherpaOnnxOnlineParaformerModelConfig paraformer; SherpaOnnxOnlineZipformer2CtcModelConfig zipformer2_ctc; const char *tokens; + const char *tokens_buf_str; // if non-null, loading the tokens from the buffered string directly in prioriy int32_t num_threads; const char *provider; int32_t debug; // true to print debug information of the model @@ -139,6 +146,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig { /// Path to the hotwords. const char *hotwords_file; + const char *hotwords_buf_str; // if non-empty, loading the hotwords from the buffered string directly in prioriy /// Bonus score for each token in hotwords. float hotwords_score; @@ -358,21 +366,28 @@ SHERPA_ONNX_API void SherpaOnnxPrint(const SherpaOnnxDisplay *display, /// from there. SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTransducerModelConfig { const char *encoder; + const char *encoder_buf_begin, *encoder_buf_end; // if non-null, loading the encoder from the buffer in prioriy const char *decoder; + const char *decoder_buf_begin, *decoder_buf_end; // if non-null, loading the decoder from the buffer in prioriy const char *joiner; + const char *joiner_buf_begin, *joiner_buf_end; // if non-null, loading the joiner from the buffer in prioriy } SherpaOnnxOfflineTransducerModelConfig; SHERPA_ONNX_API typedef struct SherpaOnnxOfflineParaformerModelConfig { const char *model; + const char *model_buf_begin, *model_buf_end; // if non-null, loading the model from the buffer in prioriy } SherpaOnnxOfflineParaformerModelConfig; SHERPA_ONNX_API typedef struct SherpaOnnxOfflineNemoEncDecCtcModelConfig { const char *model; + const char *model_buf_begin, *model_buf_end; // if non-null, loading the model from the buffer in prioriy } SherpaOnnxOfflineNemoEncDecCtcModelConfig; SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig { const char *encoder; + const char *encoder_buf_begin, *encoder_buf_end; // if non-null, loading the encoder from the buffer in prioriy const char *decoder; + const char *decoder_buf_begin, *decoder_buf_end; // if non-null, loading the decoder from the buffer in prioriy const char *language; const char *task; int32_t tail_paddings; @@ -380,6 +395,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig { SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig { const char *model; + const char *model_buf_begin, *model_buf_end; // if non-null, loading the model from the buffer in prioriy } SherpaOnnxOfflineTdnnModelConfig; SHERPA_ONNX_API typedef struct SherpaOnnxOfflineLMConfig { @@ -401,6 +417,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig { SherpaOnnxOfflineTdnnModelConfig tdnn; const char *tokens; + const char *tokens_buf_str; // if non-null, loading the tokens from the buffered string directly in prioriy int32_t num_threads; int32_t debug; const char *provider; @@ -425,6 +442,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig { /// Path to the hotwords. const char *hotwords_file; + const char *hotwords_buf_str; // if non-null, loading the tokens from the buffered string directly in prioriy /// Bonus score for each token in hotwords. float hotwords_score; @@ -627,6 +645,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { float keywords_score; float keywords_threshold; const char *keywords_file; + const char *keywords_buf_str; } SherpaOnnxKeywordSpotterConfig; SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter diff --git a/sherpa-onnx/csrc/keyword-spotter-impl.cc b/sherpa-onnx/csrc/keyword-spotter-impl.cc index 1c9d59485..ebf038dfd 100644 --- a/sherpa-onnx/csrc/keyword-spotter-impl.cc +++ b/sherpa-onnx/csrc/keyword-spotter-impl.cc @@ -10,7 +10,7 @@ namespace sherpa_onnx { std::unique_ptr KeywordSpotterImpl::Create( const KeywordSpotterConfig &config) { - if (!config.model_config.transducer.encoder.empty()) { + if (!config.model_config.transducer.encoder.empty() || config.model_config.transducer.encoder_buf_begin) { return std::make_unique(config); } diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h index 2300839f3..8d0c5547a 100644 --- a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h @@ -67,7 +67,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { explicit KeywordSpotterTransducerImpl(const KeywordSpotterConfig &config) : config_(config), model_(OnlineTransducerModel::Create(config.model_config)), - sym_(config.model_config.tokens) { + sym_(config.model_config.tokens_buf_str.empty() ? config.model_config.tokens : + config.model_config.tokens_buf_str, config.model_config.tokens_buf_str.empty() ? false, true) { if (sym_.Contains("")) { unk_id_ = sym_[""]; } diff --git a/sherpa-onnx/csrc/keyword-spotter.h b/sherpa-onnx/csrc/keyword-spotter.h index 3d7935cc3..6f03d2464 100644 --- a/sherpa-onnx/csrc/keyword-spotter.h +++ b/sherpa-onnx/csrc/keyword-spotter.h @@ -68,6 +68,7 @@ struct KeywordSpotterConfig { float keywords_threshold = 0.25; std::string keywords_file; + std::string keywords_buf_str; KeywordSpotterConfig() = default; @@ -82,7 +83,8 @@ struct KeywordSpotterConfig { num_trailing_blanks(num_trailing_blanks), keywords_score(keywords_score), keywords_threshold(keywords_threshold), - keywords_file(keywords_file) {} + keywords_file(keywords_file), + keywords_buf_str("") {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-model-config.h b/sherpa-onnx/csrc/offline-model-config.h index 8eb725e4e..e7774519d 100644 --- a/sherpa-onnx/csrc/offline-model-config.h +++ b/sherpa-onnx/csrc/offline-model-config.h @@ -29,6 +29,7 @@ struct OfflineModelConfig { std::string telespeech_ctc; std::string tokens; + std::string tokens_buf_str; int32_t num_threads = 2; bool debug = false; std::string provider = "cpu"; @@ -70,7 +71,7 @@ struct OfflineModelConfig { wenet_ctc(wenet_ctc), sense_voice(sense_voice), telespeech_ctc(telespeech_ctc), - tokens(tokens), + tokens(tokens), tokens_buf_str(""), num_threads(num_threads), debug(debug), provider(provider), diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h index 9ef7b54a9..542f51085 100644 --- a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h +++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h @@ -12,11 +12,15 @@ namespace sherpa_onnx { struct OfflineNemoEncDecCtcModelConfig { std::string model; + const char *model_buf_begin, *model_buf_end; OfflineNemoEncDecCtcModelConfig() = default; explicit OfflineNemoEncDecCtcModelConfig(const std::string &model) : model(model) {} + explicit OfflineNemoEncDecCtcModelConfig(const char *model_buf_begin, *model_buf_end) + : model_buf_begin(model_buf_begin), model_buf_end(model_buf_end) {} + void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc index 2d790954b..72547fd47 100644 --- a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc @@ -19,8 +19,12 @@ class OfflineNemoEncDecCtcModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { - auto buf = ReadFile(config_.nemo_ctc.model); - Init(buf.data(), buf.size()); + if(config.nemo_ctc.model_buf_begin && config.nemo_ctc.model_buf_end) { + Init(config.nemo_ctc.model_buf_begin, config.nemo_ctc.model_buf_end - config.nemo_ctc.model_buf_begin); + } else { + auto buf = ReadFile(config_.nemo_ctc.model); + Init(buf.data(), buf.size()); + } } #if __ANDROID_API__ >= 9 diff --git a/sherpa-onnx/csrc/offline-paraformer-model-config.h b/sherpa-onnx/csrc/offline-paraformer-model-config.h index f0420dcb6..9d13e50d5 100644 --- a/sherpa-onnx/csrc/offline-paraformer-model-config.h +++ b/sherpa-onnx/csrc/offline-paraformer-model-config.h @@ -12,11 +12,15 @@ namespace sherpa_onnx { struct OfflineParaformerModelConfig { std::string model; + const char *model_buf_begin, *model_buf_end; OfflineParaformerModelConfig() = default; explicit OfflineParaformerModelConfig(const std::string &model) : model(model) {} + explicit OfflineParaformerModelConfig(const char *model_buf_begin, *model_buf_end) + : model_buf_begin(model_buf_begin), model_buf_end(model_buf_end) {} + void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-paraformer-model.cc b/sherpa-onnx/csrc/offline-paraformer-model.cc index ce1851062..df91df561 100644 --- a/sherpa-onnx/csrc/offline-paraformer-model.cc +++ b/sherpa-onnx/csrc/offline-paraformer-model.cc @@ -22,8 +22,12 @@ class OfflineParaformerModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { - auto buf = ReadFile(config_.paraformer.model); - Init(buf.data(), buf.size()); + if(config.paraformer.model_buf_begin && config.paraformer.model_buf_end) { + Init(config.paraformer.model_buf_begin, config.paraformer.model_buf_end - config.paraformer.model_buf_begin); + } else { + auto buf = ReadFile(config_.paraformer.model); + Init(buf.data(), buf.size()); + } } #if __ANDROID_API__ >= 9 diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index 05c1b7981..02418c4c8 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -75,7 +75,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { explicit OfflineRecognizerCtcImpl(const OfflineRecognizerConfig &config) : OfflineRecognizerImpl(config), config_(config), - symbol_table_(config_.model_config.tokens), + symbol_table_(config.model_config.tokens_buf_str.empty() ? config.model_config.tokens : + config.model_config.tokens_buf_str, config.model_config.tokens_buf_str.empty() ? false, true), model_(OfflineCtcModel::Create(config_.model_config)) { Init(); } @@ -103,7 +104,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { config_.feat_config.is_mfcc = true; } - if (!config_.model_config.wenet_ctc.model.empty()) { + if (!config_.model_config.wenet_ctc.model.empty() || config_.model_config.wenet_ctc.model_buf_begin) { // WeNet CTC models assume input samples are in the range // [-32768, 32767], so we set normalize_samples to false config_.feat_config.normalize_samples = false; diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index 5062968cc..eda21d5c1 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -32,22 +32,22 @@ namespace sherpa_onnx { std::unique_ptr OfflineRecognizerImpl::Create( const OfflineRecognizerConfig &config) { - if (!config.model_config.sense_voice.model.empty()) { + if (!config.model_config.sense_voice.model.empty() || config.model_config.sense_voice.model_buf_begin) { return std::make_unique(config); } - if (!config.model_config.paraformer.model.empty()) { + if (!config.model_config.paraformer.model.empty() || config.model_config.paraformer.model_buf_begin) { return std::make_unique(config); } - if (!config.model_config.nemo_ctc.model.empty() || - !config.model_config.zipformer_ctc.model.empty() || - !config.model_config.tdnn.model.empty() || - !config.model_config.wenet_ctc.model.empty()) { + if (!config.model_config.nemo_ctc.model.empty() || config.model_config.nemo_ctc.model_buf_begin || + !config.model_config.zipformer_ctc.model.empty() || config.model_config.zipformer_ctc.model_buf_begin || + !config.model_config.tdnn.model.empty() || config.model_config.tdnn.model_buf_begin || + !config.model_config.wenet_ctc.model.empty() || config.model_config.wenet_ctc.model_buf_begin) { return std::make_unique(config); } - if (!config.model_config.whisper.encoder.empty()) { + if (!config.model_config.whisper.encoder.empty() || config.model_config.whisper.encoder_buf_begin) { return std::make_unique(config); } @@ -82,31 +82,59 @@ std::unique_ptr OfflineRecognizerImpl::Create( sess_opts.SetInterOpNumThreads(1); std::string model_filename; - if (!config.model_config.transducer.encoder_filename.empty()) { + const char *model_buf_begin = nullptr, char *model_buf_end = nullptr; + if(config.model_config.transducer.encoder_buf_begin && config.model_config.transducer.encoder_buf_end) { + model_buf_begin = config.model_config.transducer.encoder_buf_begin; + model_buf_end = config.model_config.transducer.encoder_buf_end; + } else if (!config.model_config.transducer.encoder_filename.empty()) { model_filename = config.model_config.transducer.encoder_filename; + } else if(config.model_config.paraformer.model_buf_begin && config.model_config.paraformer.model_buf_end) { + model_buf_begin = config.model_config.paraformer.model_buf_begin; + model_buf_end = config.model_config.paraformer.model_buf_end; } else if (!config.model_config.paraformer.model.empty()) { model_filename = config.model_config.paraformer.model; + } else if(config.model_config.nemo_ctc.model_buf_begin && config.model_config.nemo_ctc.model_buf_end) { + model_buf_begin = config.model_config.nemo_ctc.model_buf_begin; + model_buf_end = config.model_config.nemo_ctc.model_buf_end; } else if (!config.model_config.nemo_ctc.model.empty()) { model_filename = config.model_config.nemo_ctc.model; } else if (!config.model_config.telespeech_ctc.empty()) { + // loading from the buffered memory for telepseech is not supported yet model_filename = config.model_config.telespeech_ctc; + } else if(config.model_config.tdnn.model_buf_begin && config.model_config.tdnn.model_buf_end) { + model_buf_begin = config.model_config.tdnn.model_buf_begin; + model_buf_end = config.model_config.tdnn.model_buf_end; } else if (!config.model_config.tdnn.model.empty()) { model_filename = config.model_config.tdnn.model; + } else if(config.model_config.zipformer_ctc.model_buf_begin && config.model_config.zipformer_ctc.model_buf_end) { + model_buf_begin = config.model_config.zipformer_ctc.model_buf_begin; + model_buf_end = config.model_config.zipformer_ctc.model_buf_end; } else if (!config.model_config.zipformer_ctc.model.empty()) { model_filename = config.model_config.zipformer_ctc.model; + } else if(config.model_config.wenet_ctc.model_buf_begin && config.model_config.wenet_ctc.model_buf_end) { + model_buf_begin = config.model_config.wenet_ctc.model_buf_begin; + model_buf_end = config.model_config.wenet_ctc.model_buf_end; } else if (!config.model_config.wenet_ctc.model.empty()) { model_filename = config.model_config.wenet_ctc.model; + } else if(config.model_config.whisper.encoder_buf_begin && config.model_config.whisper.encoder_buf_end) { + model_buf_begin = config.model_config.whisper.encoder_buf_begin; + model_buf_end = config.model_config.whisper.encoder_buf_end; } else if (!config.model_config.whisper.encoder.empty()) { model_filename = config.model_config.whisper.encoder; } else { SHERPA_ONNX_LOGE("Please provide a model"); exit(-1); } - - auto buf = ReadFile(model_filename); - - auto encoder_sess = + + unique_ptr encoder_sess; + if(model_buf_begin && model_buf_end) { + encoder_sess = + std::make_unique(env, model_buf_begin, model_buf_end - model_buf_begin, sess_opts); + } else { + buf = ReadFile(model_filename); + encoder_sess = std::make_unique(env, buf.data(), buf.size(), sess_opts); + } Ort::ModelMetadata meta_data = encoder_sess->GetModelMetadata(); diff --git a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h index 525c92cc2..54144d3ce 100644 --- a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h @@ -91,7 +91,8 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { const OfflineRecognizerConfig &config) : OfflineRecognizerImpl(config), config_(config), - symbol_table_(config_.model_config.tokens), + symbol_table_(config.model_config.tokens_buf_str.empty() ? config.model_config.tokens : + config.model_config.tokens_buf_str, config.model_config.tokens_buf_str.empty() ? false, true), model_(std::make_unique(config.model_config)) { if (config.decoding_method == "greedy_search") { int32_t eos_id = symbol_table_[""]; diff --git a/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h b/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h index 6cebf23c7..3bb4f495c 100644 --- a/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h @@ -68,7 +68,8 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl { const OfflineRecognizerConfig &config) : OfflineRecognizerImpl(config), config_(config), - symbol_table_(config_.model_config.tokens), + symbol_table_(config.model_config.tokens_buf_str.empty() ? config.model_config.tokens : + config.model_config.tokens_buf_str, config.model_config.tokens_buf_str.empty() ? false, true), model_(std::make_unique(config.model_config)) { const auto &meta_data = model_->GetModelMetadata(); if (config.decoding_method == "greedy_search") { diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index 05759ac5b..e8241e1b9 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -76,7 +76,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { const OfflineRecognizerConfig &config) : OfflineRecognizerImpl(config), config_(config), - symbol_table_(config_.model_config.tokens), + symbol_table_(config.model_config.tokens_buf_str.empty() ? config.model_config.tokens : + config.model_config.tokens_buf_str, config.model_config.tokens_buf_str.empty() ? false, true), model_(std::make_unique(config_.model_config)) { if (symbol_table_.Contains("")) { unk_id_ = symbol_table_[""]; @@ -95,7 +96,9 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { config_.model_config.bpe_vocab); } - if (!config_.hotwords_file.empty()) { + if (!config_.hotwords_buf_str.empty()) { + InitHotwordsFromBufStr(); + } else if (!config_.hotwords_file.empty()) { InitHotwords(); } @@ -250,7 +253,6 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { return config_; } - void InitHotwords() { // each line in hotwords_file contains space-separated words @@ -271,6 +273,22 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { hotwords_, config_.hotwords_score, boost_scores_); } + void InitHotwordsFromBufStr() { + // init the hotwords from its raw string-format directly, e.g R"(hotword1 :1.5 + // hotword2 :1.5)" + // each line in hotwords_file contains space-separated words + + std::istringstream iss(config_.hotwords_buf_str); + if (!EncodeHotwords(iss, config_.model_config.modeling_unit, symbol_table_, + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for details."); + } + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); + } + #if __ANDROID_API__ >= 9 void InitHotwords(AAssetManager *mgr) { // each line in hotwords_file contains space-separated words diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h index 2f5b9e2a2..986b46658 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h @@ -43,7 +43,8 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { const OfflineRecognizerConfig &config) : OfflineRecognizerImpl(config), config_(config), - symbol_table_(config_.model_config.tokens), + symbol_table_(config.model_config.tokens_buf_str.empty() ? config.model_config.tokens : + config.model_config.tokens_buf_str, config.model_config.tokens_buf_str.empty() ? false, true), model_(std::make_unique( config_.model_config)) { if (config_.decoding_method == "greedy_search") { diff --git a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h index 023700e77..6ba0251b4 100644 --- a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h @@ -55,7 +55,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config) : OfflineRecognizerImpl(config), config_(config), - symbol_table_(config_.model_config.tokens), + symbol_table_(config.model_config.tokens_buf_str.empty() ? config.model_config.tokens : + config.model_config.tokens_buf_str, config.model_config.tokens_buf_str.empty() ? false, true), model_(std::make_unique(config.model_config)) { Init(); } diff --git a/sherpa-onnx/csrc/offline-recognizer.h b/sherpa-onnx/csrc/offline-recognizer.h index 8f0b47a08..2048bdb88 100644 --- a/sherpa-onnx/csrc/offline-recognizer.h +++ b/sherpa-onnx/csrc/offline-recognizer.h @@ -36,6 +36,7 @@ struct OfflineRecognizerConfig { int32_t max_active_paths = 4; std::string hotwords_file; + std::string hotwords_buf_str; float hotwords_score = 1.5; float blank_penalty = 0.0; @@ -64,7 +65,7 @@ struct OfflineRecognizerConfig { ctc_fst_decoder_config(ctc_fst_decoder_config), decoding_method(decoding_method), max_active_paths(max_active_paths), - hotwords_file(hotwords_file), + hotwords_file(hotwords_file), hotwords_buf_str("") hotwords_score(hotwords_score), blank_penalty(blank_penalty), rule_fsts(rule_fsts), diff --git a/sherpa-onnx/csrc/offline-sense-voice-model-config.h b/sherpa-onnx/csrc/offline-sense-voice-model-config.h index 2f724e446..1038abdb5 100644 --- a/sherpa-onnx/csrc/offline-sense-voice-model-config.h +++ b/sherpa-onnx/csrc/offline-sense-voice-model-config.h @@ -12,6 +12,7 @@ namespace sherpa_onnx { struct OfflineSenseVoiceModelConfig { std::string model; + const char *model_buf_begin, *model_buf_end; // "" or "auto" to let the model recognize the language // valid values: @@ -28,6 +29,13 @@ struct OfflineSenseVoiceModelConfig { bool use_itn) : model(model), language(language), use_itn(use_itn) {} + explicit OfflineSenseVoiceModelConfig(const char *model_buf_begin, + const char *model_buf_end, + const std::string &language, + bool use_itn) + : model_buf_begin(model_buf_begin), model_buf_end(model_buf_end), + language(language), use_itn(use_itn) {} + void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-sense-voice-model.cc b/sherpa-onnx/csrc/offline-sense-voice-model.cc index 1d2a14ef5..7004269dc 100644 --- a/sherpa-onnx/csrc/offline-sense-voice-model.cc +++ b/sherpa-onnx/csrc/offline-sense-voice-model.cc @@ -21,8 +21,12 @@ class OfflineSenseVoiceModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { - auto buf = ReadFile(config_.sense_voice.model); - Init(buf.data(), buf.size()); + if(config.sense_voice.model.buf_begin && config.sense_voice.model.buf_end) { + Init(config.sense_voice.model.buf_begin, config.sense_voice.model.buf_end - config.sense_voice.model.buf_begin); + } else { + auto buf = ReadFile(config_.sense_voice.model); + Init(buf.data(), buf.size()); + } } #if __ANDROID_API__ >= 9 diff --git a/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc b/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc index ea91d1c55..506871f37 100644 --- a/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc @@ -21,8 +21,12 @@ class OfflineTdnnCtcModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { - auto buf = ReadFile(config_.tdnn.model); - Init(buf.data(), buf.size()); + if(config.tdnn.model_buf_begin && config.tdnn.model_buf_end) { + Init(config.tdnn.model_buf_begin, config.tdnn.model_buf_end - config.tdnn.model_buf_begin); + } else { + auto buf = ReadFile(config_.tdnn.model); + Init(buf.data(), buf.size()); + } } #if __ANDROID_API__ >= 9 diff --git a/sherpa-onnx/csrc/offline-tdnn-model-config.h b/sherpa-onnx/csrc/offline-tdnn-model-config.h index bddea551c..2748aedfa 100644 --- a/sherpa-onnx/csrc/offline-tdnn-model-config.h +++ b/sherpa-onnx/csrc/offline-tdnn-model-config.h @@ -13,10 +13,14 @@ namespace sherpa_onnx { // for https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn struct OfflineTdnnModelConfig { std::string model; + const char *model_buf_being, *model_buf_end; OfflineTdnnModelConfig() = default; explicit OfflineTdnnModelConfig(const std::string &model) : model(model) {} + explicit OfflineTdnnModelConfig(const char *model_buf_being, const char *model_buf_end) : + model_buf_being(model_buf_being), model_buf_end(model_buf_end) {} + void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-transducer-model-config.h b/sherpa-onnx/csrc/offline-transducer-model-config.h index 1b51f104e..8aa31451d 100644 --- a/sherpa-onnx/csrc/offline-transducer-model-config.h +++ b/sherpa-onnx/csrc/offline-transducer-model-config.h @@ -12,8 +12,12 @@ namespace sherpa_onnx { struct OfflineTransducerModelConfig { std::string encoder_filename; + const char *encoder_buf_begin, *encoder_buf_end; std::string decoder_filename; + const char *decoder_buf_begin, *decoder_buf_end; std::string joiner_filename; + const char *joiner_buf_begin, *joiner_buf_end; + OfflineTransducerModelConfig() = default; OfflineTransducerModelConfig(const std::string &encoder_filename, @@ -23,6 +27,13 @@ struct OfflineTransducerModelConfig { decoder_filename(decoder_filename), joiner_filename(joiner_filename) {} + OfflineTransducerModelConfig(const char *encoder_buf_begin, const char *encoder_buf_end, + const char *decoder_buf_begin, const char *decoder_buf_end, + const char *joiner_buf_begin, const char *joiner_buf_end) + : encoder_buf_begin(encoder_buf_begin), encoder_buf_end(encoder_buf_end), + decoder_buf_begin(decoder_buf_begin), decoder_buf_end(decoder_buf_end), + joiner_buf_begin(joiner_buf_begin), joiner_buf_end(joiner_buf_end) {} + void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-transducer-model.cc b/sherpa-onnx/csrc/offline-transducer-model.cc index 6a297347d..2415ad50b 100644 --- a/sherpa-onnx/csrc/offline-transducer-model.cc +++ b/sherpa-onnx/csrc/offline-transducer-model.cc @@ -22,19 +22,31 @@ class OfflineTransducerModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { - { - auto buf = ReadFile(config.transducer.encoder_filename); - InitEncoder(buf.data(), buf.size()); + if(config.transducer.encoder_buf_begin && config.transducer.encoder_buf_end) { + InitEncoder(config.transducer.encoder_buf_begin, + config.transducer.encoder_buf_end - + config.transducer.encoder_buf_begin); + } else { + auto buf = ReadFile(config.transducer.encoder_filename); + InitEncoder(buf.data(), buf.size()); } - - { - auto buf = ReadFile(config.transducer.decoder_filename); - InitDecoder(buf.data(), buf.size()); + + if(config.transducer.decoder_buf_begin && config.transducer.decoder_buf_end) { + InitDecoder(config.transducer.decoder_buf_begin, + config.transducer.decoder_buf_end - + config.transducer.decoder_buf_begin); + } else { + auto buf = ReadFile(config.transducer.decoder_filename); + InitDecoder(buf.data(), buf.size()); } - { - auto buf = ReadFile(config.transducer.joiner_filename); - InitJoiner(buf.data(), buf.size()); + if(config.transducer.joiner_buf_begin && config.transducer.joiner_buf_end) { + InitJoiner(config.transducer.joiner_buf_begin, + config.transducer.joiner_buf_end - + config.transducer.joiner_buf_begin); + } else { + auto buf = ReadFile(config.transducer.joiner_filename); + InitJoiner(buf.data(), buf.size()); } } diff --git a/sherpa-onnx/csrc/offline-transducer-nemo-model.cc b/sherpa-onnx/csrc/offline-transducer-nemo-model.cc index f18e57da9..3c93105b4 100644 --- a/sherpa-onnx/csrc/offline-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/offline-transducer-nemo-model.cc @@ -24,19 +24,31 @@ class OfflineTransducerNeMoModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { - { - auto buf = ReadFile(config.transducer.encoder_filename); - InitEncoder(buf.data(), buf.size()); + if(config.transducer.encoder_buf_begin && config.transducer.encoder_buf_end) { + InitEncoder(config.transducer.encoder_buf_begin, + config.transducer.encoder_buf_end - + config.transducer.encoder_buf_begin); + } else { + auto buf = ReadFile(config.transducer.encoder_filename); + InitEncoder(buf.data(), buf.size()); } - - { - auto buf = ReadFile(config.transducer.decoder_filename); - InitDecoder(buf.data(), buf.size()); + + if(config.transducer.decoder_buf_begin && config.transducer.decoder_buf_end) { + InitDecoder(config.transducer.decoder_buf_begin, + config.transducer.decoder_buf_end - + config.transducer.decoder_buf_begin); + } else { + auto buf = ReadFile(config.transducer.decoder_filename); + InitDecoder(buf.data(), buf.size()); } - { - auto buf = ReadFile(config.transducer.joiner_filename); - InitJoiner(buf.data(), buf.size()); + if(config.transducer.joiner_buf_begin && config.transducer.joiner_buf_end) { + InitJoiner(config.transducer.joiner_buf_begin, + config.transducer.joiner_buf_end - + config.transducer.joiner_buf_begin); + } else { + auto buf = ReadFile(config.transducer.joiner_filename); + InitJoiner(buf.data(), buf.size()); } } diff --git a/sherpa-onnx/csrc/offline-wenet-ctc-model-config.h b/sherpa-onnx/csrc/offline-wenet-ctc-model-config.h index 4a9b30b80..84499f257 100644 --- a/sherpa-onnx/csrc/offline-wenet-ctc-model-config.h +++ b/sherpa-onnx/csrc/offline-wenet-ctc-model-config.h @@ -12,11 +12,15 @@ namespace sherpa_onnx { struct OfflineWenetCtcModelConfig { std::string model; + const char *model_buf_begin, *model_buf_end; OfflineWenetCtcModelConfig() = default; explicit OfflineWenetCtcModelConfig(const std::string &model) : model(model) {} + explicit OfflineWenetCtcModelConfig(const char *model_buf_begin, char *model_buf_end) + : model_buf_begin(model_buf_begin), model_buf_end(model_buf_end) {} + void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-wenet-ctc-model.cc b/sherpa-onnx/csrc/offline-wenet-ctc-model.cc index 93fdffab8..35dbd1649 100644 --- a/sherpa-onnx/csrc/offline-wenet-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-wenet-ctc-model.cc @@ -19,8 +19,12 @@ class OfflineWenetCtcModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { - auto buf = ReadFile(config_.wenet_ctc.model); - Init(buf.data(), buf.size()); + if(config.wenet_ctc.model_buf_begin && config.wenet_ctc.model_buf_end) { + Init(config.wenet_ctc.model_buf_begin, config.wenet_ctc.model_buf_end - config.wenet_ctc.model_buf_begin); + } else { + auto buf = ReadFile(config_.wenet_ctc.model); + Init(buf.data(), buf.size()); + } } #if __ANDROID_API__ >= 9 diff --git a/sherpa-onnx/csrc/offline-whisper-model-config.h b/sherpa-onnx/csrc/offline-whisper-model-config.h index a612a46d5..bf19fd07d 100644 --- a/sherpa-onnx/csrc/offline-whisper-model-config.h +++ b/sherpa-onnx/csrc/offline-whisper-model-config.h @@ -12,7 +12,9 @@ namespace sherpa_onnx { struct OfflineWhisperModelConfig { std::string encoder; + const char *encoder_buf_begin, *encoder_buf_end; std::string decoder; + const char *decoder_buf_begin, *decoder_buf_end; // Available languages can be found at // https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 @@ -49,6 +51,16 @@ struct OfflineWhisperModelConfig { task(task), tail_paddings(tail_paddings) {} + OfflineWhisperModelConfig(const char *encoder_buf_begin, const char *encoder_buf_end, + const char *decoder_buf_begin, const char *decoder_buf_end, + const std::string &language, + const std::string &task, int32_t tail_paddings) + : encoder_buf_begin(encoder_buf_begin), encoder_buf_end(encoder_buf_end), + decoder_buf_begin(decoder_buf_begin), decoder_buf_end(decoder_buf_end), + language(language), + task(task), + tail_paddings(tail_paddings) {} + void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-whisper-model.cc b/sherpa-onnx/csrc/offline-whisper-model.cc index 7812e1d09..ab1d3126d 100644 --- a/sherpa-onnx/csrc/offline-whisper-model.cc +++ b/sherpa-onnx/csrc/offline-whisper-model.cc @@ -25,12 +25,16 @@ class OfflineWhisperModel::Impl { debug_(config.debug), sess_opts_(GetSessionOptions(config)), allocator_{} { - { + if(config.whisper.encoder_buf_begin && config.whisper.encoder_buf_end) { + Init(config.whisper.encoder_buf_begin, config.whisper.encoder_buf_end - config.whisper.encoder_buf_begin); + } else { auto buf = ReadFile(config.whisper.encoder); InitEncoder(buf.data(), buf.size()); } - { + if(config.whisper.decoder_buf_begin && config.whisper.decoder_buf_end) { + Init(config.whisper.decoder_buf_begin, config.whisper.decoder_buf_end - config.whisper.decoder_buf_begin); + } else { auto buf = ReadFile(config.whisper.decoder); InitDecoder(buf.data(), buf.size()); } diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h index 4f60e832e..f0f3977a0 100644 --- a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h @@ -17,7 +17,7 @@ struct OfflineZipformerAudioTaggingModelConfig { explicit OfflineZipformerAudioTaggingModelConfig(const std::string &model) : model(model) {} - + void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h b/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h index 702575e72..c606caf5c 100644 --- a/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h +++ b/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h @@ -14,11 +14,14 @@ namespace sherpa_onnx { // https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py struct OfflineZipformerCtcModelConfig { std::string model; + const char *model_buf_begin, *model_buf_end; OfflineZipformerCtcModelConfig() = default; explicit OfflineZipformerCtcModelConfig(const std::string &model) : model(model) {} + explicit OfflineZipformerCtcModelConfig(const char *model_buf_begin, const char *model_buf_end) + : model_buf_begin(model_buf_begin), model_buf_end(model_buf_end) {} void Register(ParseOptions *po); diff --git a/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc b/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc index 8db9439e4..05bed2b7b 100644 --- a/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc @@ -21,8 +21,12 @@ class OfflineZipformerCtcModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { - auto buf = ReadFile(config_.zipformer_ctc.model); - Init(buf.data(), buf.size()); + if(config.zipformer_ctc.model_buf_begin && config.zipformer_ctc.model_buf_end) { + Init(config.zipformer_ctc.model_buf_begin, config.zipformer_ctc.model_buf_end - config.zipformer_ctc.model_buf_begin); + } else { + auto buf = ReadFile(config_.zipformer_ctc.model); + Init(buf.data(), buf.size()); + } } #if __ANDROID_API__ >= 9 @@ -118,4 +122,4 @@ int32_t OfflineZipformerCtcModel::SubsamplingFactor() const { return impl_->SubsamplingFactor(); } -} // namespace sherpa_onnx +} // namespace sherpa_onnx \ No newline at end of file diff --git a/sherpa-onnx/csrc/online-conformer-transducer-model.cc b/sherpa-onnx/csrc/online-conformer-transducer-model.cc index 7c252f5a4..2c5888cd4 100644 --- a/sherpa-onnx/csrc/online-conformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-conformer-transducer-model.cc @@ -34,19 +34,31 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel( config_(config), sess_opts_(GetSessionOptions(config)), allocator_{} { - { - auto buf = ReadFile(config.transducer.encoder); - InitEncoder(buf.data(), buf.size()); + if(config.transducer.encoder_buf_begin && config.transducer.encoder_buf_end) { + InitEncoder(config.transducer.encoder_buf_begin, + config.transducer.encoder_buf_end - + config.transducer.encoder_buf_begin); + } else { + auto buf = ReadFile(config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); } - - { - auto buf = ReadFile(config.transducer.decoder); - InitDecoder(buf.data(), buf.size()); + + if(config.transducer.decoder_buf_begin && config.transducer.decoder_buf_end) { + InitDecoder(config.transducer.decoder_buf_begin, + config.transducer.decoder_buf_end - + config.transducer.decoder_buf_begin); + } else { + auto buf = ReadFile(config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); } - { - auto buf = ReadFile(config.transducer.joiner); - InitJoiner(buf.data(), buf.size()); + if(config.transducer.joiner_buf_begin && config.transducer.joiner_buf_end) { + InitJoiner(config.transducer.joiner_buf_begin, + config.transducer.joiner_buf_end - + config.transducer.joiner_buf_begin); + } else { + auto buf = ReadFile(config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); } } diff --git a/sherpa-onnx/csrc/online-ctc-model.cc b/sherpa-onnx/csrc/online-ctc-model.cc index a3a071a72..7f653c854 100644 --- a/sherpa-onnx/csrc/online-ctc-model.cc +++ b/sherpa-onnx/csrc/online-ctc-model.cc @@ -19,11 +19,11 @@ namespace sherpa_onnx { std::unique_ptr OnlineCtcModel::Create( const OnlineModelConfig &config) { - if (!config.wenet_ctc.model.empty()) { + if (!config.wenet_ctc.model.empty() || config.wenet_ctc.model_buf_begin) { return std::make_unique(config); - } else if (!config.zipformer2_ctc.model.empty()) { + } else if (!config.zipformer2_ctc.model.empty() || config.zipformer2_ctc.model_buf_begin) { return std::make_unique(config); - } else if (!config.nemo_ctc.model.empty()) { + } else if (!config.nemo_ctc.model.empty() || config.nemo_ctc.model_buf_begin) { return std::make_unique(config); } else { SHERPA_ONNX_LOGE("Please specify a CTC model"); diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.cc b/sherpa-onnx/csrc/online-lstm-transducer-model.cc index 094cc933c..b93017f26 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.cc +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.cc @@ -32,19 +32,31 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel( config_(config), sess_opts_(GetSessionOptions(config)), allocator_{} { - { - auto buf = ReadFile(config.transducer.encoder); - InitEncoder(buf.data(), buf.size()); + if(config.transducer.encoder_buf_begin && config.transducer.encoder_buf_end) { + InitEncoder(config.transducer.encoder_buf_begin, + config.transducer.encoder_buf_end - + config.transducer.encoder_buf_begin); + } else { + auto buf = ReadFile(config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); } - - { - auto buf = ReadFile(config.transducer.decoder); - InitDecoder(buf.data(), buf.size()); + + if(config.transducer.decoder_buf_begin && config.transducer.decoder_buf_end) { + InitDecoder(config.transducer.decoder_buf_begin, + config.transducer.decoder_buf_end - + config.transducer.decoder_buf_begin); + } else { + auto buf = ReadFile(config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); } - { - auto buf = ReadFile(config.transducer.joiner); - InitJoiner(buf.data(), buf.size()); + if(config.transducer.joiner_buf_begin && config.transducer.joiner_buf_end) { + InitJoiner(config.transducer.joiner_buf_begin, + config.transducer.joiner_buf_end - + config.transducer.joiner_buf_begin); + } else { + auto buf = ReadFile(config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); } } diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h index 0b64e06de..19ba6144c 100644 --- a/sherpa-onnx/csrc/online-model-config.h +++ b/sherpa-onnx/csrc/online-model-config.h @@ -22,7 +22,8 @@ struct OnlineModelConfig { OnlineZipformer2CtcModelConfig zipformer2_ctc; OnlineNeMoCtcModelConfig nemo_ctc; ProviderConfig provider_config; - std::string tokens; + std::string tokens; // the tokens file path + std::string tokens_buf_str; // the tokens' content in a buffered string int32_t num_threads = 1; int32_t warm_up = 0; bool debug = false; @@ -63,7 +64,7 @@ struct OnlineModelConfig { zipformer2_ctc(zipformer2_ctc), nemo_ctc(nemo_ctc), provider_config(provider_config), - tokens(tokens), + tokens(tokens), tokens_buf_str(""), num_threads(num_threads), warm_up(warm_up), debug(debug), diff --git a/sherpa-onnx/csrc/online-nemo-ctc-model-config.h b/sherpa-onnx/csrc/online-nemo-ctc-model-config.h index 4fb1de0ad..fb57f8029 100644 --- a/sherpa-onnx/csrc/online-nemo-ctc-model-config.h +++ b/sherpa-onnx/csrc/online-nemo-ctc-model-config.h @@ -12,11 +12,15 @@ namespace sherpa_onnx { struct OnlineNeMoCtcModelConfig { std::string model; + const char *model_buf_begin, *model_buf_end; OnlineNeMoCtcModelConfig() = default; explicit OnlineNeMoCtcModelConfig(const std::string &model) : model(model) {} + explicit OnlineNeMoCtcModelConfig(const char *model_buf_begin, const char *model_buf_end) + : model_buf_begin(model_buf_begin), model_buf_end(model_buf_end) {} + void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/online-nemo-ctc-model.cc b/sherpa-onnx/csrc/online-nemo-ctc-model.cc index d93ff73b1..1faa6f47d 100644 --- a/sherpa-onnx/csrc/online-nemo-ctc-model.cc +++ b/sherpa-onnx/csrc/online-nemo-ctc-model.cc @@ -30,7 +30,10 @@ class OnlineNeMoCtcModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { - { + if(config.nemo_ctc.model_buf_begin && config.nemo_ctc.model_buf_end) { + Init(config.nemo_ctc.model_buf_begin, config.nemo_ctc.model_buf_end - + config.nemo_ctc.model_buf_begin); + } else { auto buf = ReadFile(config.nemo_ctc.model); Init(buf.data(), buf.size()); } diff --git a/sherpa-onnx/csrc/online-paraformer-model-config.h b/sherpa-onnx/csrc/online-paraformer-model-config.h index 29f33e45c..c15d08efd 100644 --- a/sherpa-onnx/csrc/online-paraformer-model-config.h +++ b/sherpa-onnx/csrc/online-paraformer-model-config.h @@ -12,7 +12,9 @@ namespace sherpa_onnx { struct OnlineParaformerModelConfig { std::string encoder; + const char *encoder_buf_begin, *encoder_buf_end; std::string decoder; + const char *decoder_buf_begin, *decoder_buf_end; OnlineParaformerModelConfig() = default; @@ -20,6 +22,11 @@ struct OnlineParaformerModelConfig { const std::string &decoder) : encoder(encoder), decoder(decoder) {} + OnlineParaformerModelConfig(const char *encoder_buf_begin, const char *encoder_buf_end, + const char *decoder_buf_begin, const char *decoder_buf_end) +: encoder_buf_begin(encoder_buf_begin), encoder_buf_end(encoder_buf_end), + decoder_buf_begin(decoder_buf_begin), decoder_buf_end(decoder_buf_end) {} + void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/online-paraformer-model.cc b/sherpa-onnx/csrc/online-paraformer-model.cc index 9397ff75b..272bf6713 100644 --- a/sherpa-onnx/csrc/online-paraformer-model.cc +++ b/sherpa-onnx/csrc/online-paraformer-model.cc @@ -27,12 +27,20 @@ class OnlineParaformerModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { - { + if(config.paraformer.encoder_buf_begin && config.paraformer.encoder_buf_end) { + InitEncoder(config.paraformer.encoder_buf_begin, + config.paraformer.encoder_buf_end - + config.paraformer.encoder_buf_begin); + } else { auto buf = ReadFile(config.paraformer.encoder); InitEncoder(buf.data(), buf.size()); } - { + if(config.paraformer.decoder_buf_begin && config.paraformer.decoder_buf_end) { + InitDecoder(config.paraformer.decoder_buf_begin, + config.paraformer.decoder_buf_end - + config.paraformer.decoder_buf_begin); + } else { auto buf = ReadFile(config.paraformer.decoder); InitDecoder(buf.data(), buf.size()); } diff --git a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h index 7dd9d8b18..d7b6c15d3 100644 --- a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h @@ -71,9 +71,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { : OnlineRecognizerImpl(config), config_(config), model_(OnlineCtcModel::Create(config.model_config)), - sym_(config.model_config.tokens), + sym_(config.model_config.tokens_buf_str.empty() ? config.model_config.tokens : + config.model_config.tokens_buf_str, config.model_config.tokens_buf_str.empty() ? false, true), endpoint_(config_.endpoint_config) { - if (!config.model_config.wenet_ctc.model.empty()) { + if (!config.model_config.wenet_ctc.model.empty() || config.model_config.wenet_ctc.model_buf_begin) { // WeNet CTC models assume input samples are in the range // [-32768, 32767], so we set normalize_samples to false config_.feat_config.normalize_samples = false; diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc index 399dab49e..f1ff099cf 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.cc +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -27,12 +27,20 @@ namespace sherpa_onnx { std::unique_ptr OnlineRecognizerImpl::Create( const OnlineRecognizerConfig &config) { - if (!config.model_config.transducer.encoder.empty()) { + if (!config.model_config.transducer.encoder.empty() || config.model_config.transducer.encoder_buf_begin) { Ort::Env env(ORT_LOGGING_LEVEL_ERROR); - auto decoder_model = ReadFile(config.model_config.transducer.decoder); - auto sess = std::make_unique( - env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{}); + unique_ptr sess; + if(config.model_config.transducer.decoder_buf_begin && config.model_config.transducer.decoder_buf_end) { + sess = std::make_unique( + env, config.model_config.transducer.decoder_buf_begin, + config.model_config.transducer.decoder_buf_end - config.model_config.transducer.decoder_buf_begin, + Ort::SessionOptions{}); + } else { + auto decoder_model = ReadFile(config.model_config.transducer.decoder); + sess = std::make_unique( + env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{}); + } size_t node_count = sess->GetOutputCount(); @@ -43,13 +51,13 @@ std::unique_ptr OnlineRecognizerImpl::Create( } } - if (!config.model_config.paraformer.encoder.empty()) { + if (!config.model_config.paraformer.encoder.empty() || config.model_config.paraformer.encoder_buf_begin) { return std::make_unique(config); } - if (!config.model_config.wenet_ctc.model.empty() || - !config.model_config.zipformer2_ctc.model.empty() || - !config.model_config.nemo_ctc.model.empty()) { + if (!config.model_config.wenet_ctc.model.empty() || config.model_config.wenet_ctc.model_buf_begin || + !config.model_config.zipformer2_ctc.model.empty() || config.model_config.zipformer2_ctc.model_buf_begin || + !config.model_config.nemo_ctc.model.empty() || config.model_config.nemo_ctc.model_buf_begin) { return std::make_unique(config); } diff --git a/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h index 26fdb08c3..2d6c719ec 100644 --- a/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h @@ -99,7 +99,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { : OnlineRecognizerImpl(config), config_(config), model_(config.model_config), - sym_(config.model_config.tokens), + sym_(config.model_config.tokens_buf_str.empty() ? config.model_config.tokens : + config.model_config.tokens_buf_str, config.model_config.tokens_buf_str.empty() ? false, true), endpoint_(config_.endpoint_config) { if (config.decoding_method != "greedy_search") { SHERPA_ONNX_LOGE( diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index ab1e165f3..a715d5dbf 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -83,7 +83,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { : OnlineRecognizerImpl(config), config_(config), model_(OnlineTransducerModel::Create(config.model_config)), - sym_(config.model_config.tokens), + sym_(config.model_config.tokens_buf_str.empty() ? config.model_config.tokens : + config.model_config.tokens_buf_str, config.model_config.tokens_buf_str.empty() ? false, true), endpoint_(config_.endpoint_config) { if (sym_.Contains("")) { unk_id_ = sym_[""]; @@ -97,7 +98,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { config_.model_config.bpe_vocab); } - if (!config_.hotwords_file.empty()) { + if (!config_.hotwords_buf_str.empty()) { + InitHotwordsFromBufStr(); + } else if (!config_.hotwords_file.empty()) { InitHotwords(); } @@ -412,6 +415,22 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { hotwords_, config_.hotwords_score, boost_scores_); } + void InitHotwordsFromBufStr() { + // init the hotwords from its raw string-format directly, e.g R"(hotword1 :1.5 + // hotword2 :1.5)" + // each line in hotwords_file contains space-separated words + + std::istringstream iss(config_.hotwords_buf_str); + if (!EncodeHotwords(iss, config_.model_config.modeling_unit, sym_, + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for details."); + } + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); + } + #if __ANDROID_API__ >= 9 void InitHotwords(AAssetManager *mgr) { // each line in hotwords_file contains space-separated words diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h index 700054dc2..cc247d3ea 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -44,7 +44,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { const OnlineRecognizerConfig &config) : OnlineRecognizerImpl(config), config_(config), - symbol_table_(config.model_config.tokens), + symbol_table_(config.model_config.tokens_buf_str.empty() ? config.model_config.tokens : + config.model_config.tokens_buf_str, config.model_config.tokens_buf_str.empty() ? false, true), endpoint_(config_.endpoint_config), model_( std::make_unique(config.model_config)) { diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index 7fde367fb..66602e356 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -94,6 +94,7 @@ struct OnlineRecognizerConfig { /// used only for modified_beam_search std::string hotwords_file; + std::string hotwords_buf_str; float hotwords_score = 1.5; float blank_penalty = 0.0; @@ -126,6 +127,7 @@ struct OnlineRecognizerConfig { decoding_method(decoding_method), max_active_paths(max_active_paths), hotwords_file(hotwords_file), + hotwords_buf_str(""), hotwords_score(hotwords_score), blank_penalty(blank_penalty), temperature_scale(temperature_scale), diff --git a/sherpa-onnx/csrc/online-transducer-model-config.h b/sherpa-onnx/csrc/online-transducer-model-config.h index 5d79e25bf..a781bb92b 100644 --- a/sherpa-onnx/csrc/online-transducer-model-config.h +++ b/sherpa-onnx/csrc/online-transducer-model-config.h @@ -12,8 +12,11 @@ namespace sherpa_onnx { struct OnlineTransducerModelConfig { std::string encoder; + const char *encoder_buf_begin, *encoder_buf_end; std::string decoder; + const char *decoder_buf_begin, *decoder_buf_end; std::string joiner; + const char *joiner_buf_begin, *joiner_buf_end; OnlineTransducerModelConfig() = default; OnlineTransducerModelConfig(const std::string &encoder, @@ -21,6 +24,13 @@ struct OnlineTransducerModelConfig { const std::string &joiner) : encoder(encoder), decoder(decoder), joiner(joiner) {} + OnlineTransducerModelConfig(const char *encoder_buf_begin, const char *encoder_buf_end, + const char *decoder_buf_begin, const char *decoder_buf_end, + const char *joiner_buf_begin, const char *joiner_buf_end) +: encoder_buf_begin(encoder_buf_begin), encoder_buf_end(encoder_buf_end), + decoder_buf_begin(decoder_buf_begin), decoder_buf_end(decoder_buf_end), + joiner_buf_begin(joiner_buf_begin), joiner_buf_end(joiner_buf_end) {} + void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc index 16577dd49..9982ab690 100644 --- a/sherpa-onnx/csrc/online-transducer-model.cc +++ b/sherpa-onnx/csrc/online-transducer-model.cc @@ -98,9 +98,14 @@ std::unique_ptr OnlineTransducerModel::Create( ModelType model_type = ModelType::kUnknown; { - auto buffer = ReadFile(config.transducer.encoder); + if(config.transducer.encoder_buf_begin && config.transducer.encoder_buf_end) { + model_type = GetModelType(config.transducer.encoder_buf_begin, + config.transducer.encoder_buf_end - config.transducer.encoder_buf_begin, config.debug); + } else { + auto buffer = ReadFile(config.transducer.encoder); + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); + } - model_type = GetModelType(buffer.data(), buffer.size(), config.debug); } switch (model_type) { diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.cc b/sherpa-onnx/csrc/online-transducer-nemo-model.cc index 4e12da44c..ae3520847 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.cc @@ -38,19 +38,31 @@ class OnlineTransducerNeMoModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { - { - auto buf = ReadFile(config.transducer.encoder); - InitEncoder(buf.data(), buf.size()); + if(config.transducer.encoder_buf_begin && config.transducer.encoder_buf_end) { + InitEncoder(config.transducer.encoder_buf_begin, + config.transducer.encoder_buf_end - + config.transducer.encoder_buf_begin); + } else { + auto buf = ReadFile(config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); } - - { - auto buf = ReadFile(config.transducer.decoder); - InitDecoder(buf.data(), buf.size()); + + if(config.transducer.decoder_buf_begin && config.transducer.decoder_buf_end) { + InitDecoder(config.transducer.decoder_buf_begin, + config.transducer.decoder_buf_end - + config.transducer.decoder_buf_begin); + } else { + auto buf = ReadFile(config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); } - { - auto buf = ReadFile(config.transducer.joiner); - InitJoiner(buf.data(), buf.size()); + if(config.transducer.joiner_buf_begin && config.transducer.joiner_buf_end) { + InitJoiner(config.transducer.joiner_buf_begin, + config.transducer.joiner_buf_end - + config.transducer.joiner_buf_begin); + } else { + auto buf = ReadFile(config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); } } diff --git a/sherpa-onnx/csrc/online-wenet-ctc-model-config.h b/sherpa-onnx/csrc/online-wenet-ctc-model-config.h index b90295ea6..a34ccfb4c 100644 --- a/sherpa-onnx/csrc/online-wenet-ctc-model-config.h +++ b/sherpa-onnx/csrc/online-wenet-ctc-model-config.h @@ -12,6 +12,7 @@ namespace sherpa_onnx { struct OnlineWenetCtcModelConfig { std::string model; + const char *model_buf_begin, *model_buf_end; // --chunk_size from wenet int32_t chunk_size = 16; @@ -27,6 +28,12 @@ struct OnlineWenetCtcModelConfig { chunk_size(chunk_size), num_left_chunks(num_left_chunks) {} + OnlineWenetCtcModelConfig(const char *model_buf_begin, const char *model_buf_end, int32_t chunk_size, + int32_t num_left_chunks) + : model_buf_begin(model_buf_begin), model_buf_end(model_buf_end), + chunk_size(chunk_size), + num_left_chunks(num_left_chunks) {} + void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/online-wenet-ctc-model.cc b/sherpa-onnx/csrc/online-wenet-ctc-model.cc index 1b1605183..c7e9aeb1b 100644 --- a/sherpa-onnx/csrc/online-wenet-ctc-model.cc +++ b/sherpa-onnx/csrc/online-wenet-ctc-model.cc @@ -27,7 +27,9 @@ class OnlineWenetCtcModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { - { + if(config.wenet_ctc.model_buf_begin && config.wenet_ctc.model_buf_end) { + Init(config.wenet_ctc.model_buf_begin, config.wenet_ctc.model_buf_end - config.wenet_ctc.model_buf_begin) + } else { auto buf = ReadFile(config.wenet_ctc.model); Init(buf.data(), buf.size()); } diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc index 324b2b088..0fbf7ac32 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc @@ -34,19 +34,31 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( config_(config), sess_opts_(GetSessionOptions(config)), allocator_{} { - { - auto buf = ReadFile(config.transducer.encoder); - InitEncoder(buf.data(), buf.size()); + if(config.transducer.encoder_buf_begin && config.transducer.encoder_buf_end) { + InitEncoder(config.transducer.encoder_buf_begin, + config.transducer.encoder_buf_end - + config.transducer.encoder_buf_begin); + } else { + auto buf = ReadFile(config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); } - - { - auto buf = ReadFile(config.transducer.decoder); - InitDecoder(buf.data(), buf.size()); + + if(config.transducer.decoder_buf_begin && config.transducer.decoder_buf_end) { + InitDecoder(config.transducer.decoder_buf_begin, + config.transducer.decoder_buf_end - + config.transducer.decoder_buf_begin); + } else { + auto buf = ReadFile(config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); } - { - auto buf = ReadFile(config.transducer.joiner); - InitJoiner(buf.data(), buf.size()); + if(config.transducer.joiner_buf_begin && config.transducer.joiner_buf_end) { + InitJoiner(config.transducer.joiner_buf_begin, + config.transducer.joiner_buf_end - + config.transducer.joiner_buf_begin); + } else { + auto buf = ReadFile(config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); } } diff --git a/sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h b/sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h index 18115c8fe..eb4b6c905 100644 --- a/sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h +++ b/sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h @@ -12,12 +12,16 @@ namespace sherpa_onnx { struct OnlineZipformer2CtcModelConfig { std::string model; + const char *model_buf_begin, *model_buf_end; OnlineZipformer2CtcModelConfig() = default; explicit OnlineZipformer2CtcModelConfig(const std::string &model) : model(model) {} + explicit OnlineZipformer2CtcModelConfig(const char *model_buf_begin, const char *model_buf_end) + : model_buf_begin(model_buf_begin), model_buf_end(model_buf_end) {} + void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc b/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc index 04699a56b..03b7d8cd4 100644 --- a/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc +++ b/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc @@ -31,7 +31,9 @@ class OnlineZipformer2CtcModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { - { + if(config.zipformer2_ctc.model_buf_begin && config.zipformer2_ctc.model_buf_end) { + Init(config.zipformer2_ctc.model_buf_begin, config.zipformer2_ctc.model_buf_end - config.zipformer2_ctc.model_buf_begin) + } else { auto buf = ReadFile(config.zipformer2_ctc.model); Init(buf.data(), buf.size()); } diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc index 0782f06fc..85596b8dd 100644 --- a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc @@ -38,19 +38,31 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( joiner_sess_opts_(GetSessionOptions(config, "joiner")), config_(config), allocator_{} { - { - auto buf = ReadFile(config.transducer.encoder); - InitEncoder(buf.data(), buf.size()); + if(config.transducer.encoder_buf_begin && config.transducer.encoder_buf_end) { + InitEncoder(config.transducer.encoder_buf_begin, + config.transducer.encoder_buf_end - + config.transducer.encoder_buf_begin); + } else { + auto buf = ReadFile(config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); } - - { - auto buf = ReadFile(config.transducer.decoder); - InitDecoder(buf.data(), buf.size()); + + if(config.transducer.decoder_buf_begin && config.transducer.decoder_buf_end) { + InitDecoder(config.transducer.decoder_buf_begin, + config.transducer.decoder_buf_end - + config.transducer.decoder_buf_begin); + } else { + auto buf = ReadFile(config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); } - { - auto buf = ReadFile(config.transducer.joiner); - InitJoiner(buf.data(), buf.size()); + if(config.transducer.joiner_buf_begin && config.transducer.joiner_buf_end) { + InitJoiner(config.transducer.joiner_buf_begin, + config.transducer.joiner_buf_end - + config.transducer.joiner_buf_begin); + } else { + auto buf = ReadFile(config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); } } diff --git a/sherpa-onnx/csrc/symbol-table.cc b/sherpa-onnx/csrc/symbol-table.cc index 8862972b7..5777dff16 100644 --- a/sherpa-onnx/csrc/symbol-table.cc +++ b/sherpa-onnx/csrc/symbol-table.cc @@ -20,9 +20,15 @@ namespace sherpa_onnx { -SymbolTable::SymbolTable(const std::string &filename) { - std::ifstream is(filename); - Init(is); +SymbolTable::SymbolTable(const std::string &filename, bool is_file) { + if(is_file) { + std::ifstream is(filename); + Init(is); + } + else { + std::istringstream iss(filename); + Init(iss); + } } #if __ANDROID_API__ >= 9 diff --git a/sherpa-onnx/csrc/symbol-table.h b/sherpa-onnx/csrc/symbol-table.h index 00d7a69e2..2231a6f0f 100644 --- a/sherpa-onnx/csrc/symbol-table.h +++ b/sherpa-onnx/csrc/symbol-table.h @@ -25,7 +25,7 @@ class SymbolTable { /// sym ID /// /// Fields are separated by space(s). - explicit SymbolTable(const std::string &filename); + explicit SymbolTable(const std::string &filename, bool is_file = true); #if __ANDROID_API__ >= 9 SymbolTable(AAssetManager *mgr, const std::string &filename);