Skip to content

Commit

Permalink
Added an alternative way to initialize/load some models (for non-And…
Browse files Browse the repository at this point in the history
…roid codes), tokens, hotwords and keywords: initializing/loading them from memory buffers.

    The original usage of model initialization is unchanged.
    The supported models are:
    online models:
      OnlineTransducerModel
      OnlineParaformerModel
      OnlineWenetCtcModel
      OnlineZipformer2CtcModel
      OnlineNeMoCtcModel

    offline models:
      OfflineTransducerModel
      OfflineParaformerModel
      OfflineNemoEncDecCtcModel
      OfflineWhisperModel
      OfflineTdnnModel
      OfflineZipformerCtcModel
      OfflineWenetCtcModel
      OfflineSenseVoiceModel

    keywordspotter models:
      KeywordSpotterTransducer
  • Loading branch information
xiao committed Sep 7, 2024
1 parent 888f74b commit 742fdb8
Show file tree
Hide file tree
Showing 57 changed files with 582 additions and 143 deletions.
113 changes: 112 additions & 1 deletion sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,51 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(

recognizer_config.model_config.transducer.encoder =
SHERPA_ONNX_OR(config->model_config.transducer.encoder, "");
recognizer_config.model_config.transducer.encoder_buf_begin =
SHERPA_ONNX_OR(config->model_config.transducer.encoder_buf_begin, nullptr);
recognizer_config.model_config.transducer.encoder_buf_end =
SHERPA_ONNX_OR(config->model_config.transducer.encoder_buf_end, nullptr);

recognizer_config.model_config.transducer.decoder =
SHERPA_ONNX_OR(config->model_config.transducer.decoder, "");
recognizer_config.model_config.transducer.decoder_buf_begin =
SHERPA_ONNX_OR(config->model_config.transducer.decoder_buf_begin, nullptr);
recognizer_config.model_config.transducer.decoder_buf_end =
SHERPA_ONNX_OR(config->model_config.transducer.decoder_buf_end, nullptr);

recognizer_config.model_config.transducer.joiner =
SHERPA_ONNX_OR(config->model_config.transducer.joiner, "");
recognizer_config.model_config.transducer.joiner_buf_begin =
SHERPA_ONNX_OR(config->model_config.transducer.joiner_buf_begin, nullptr);
recognizer_config.model_config.transducer.joiner_buf_end =
SHERPA_ONNX_OR(config->model_config.transducer.joiner_buf_end, nullptr);

recognizer_config.model_config.paraformer.encoder =
SHERPA_ONNX_OR(config->model_config.paraformer.encoder, "");
recognizer_config.model_config.paraformer.encoder_buf_begin =
SHERPA_ONNX_OR(config->model_config.paraformer.encoder_buf_begin, nullptr);
recognizer_config.model_config.paraformer.encoder_buf_end =
SHERPA_ONNX_OR(config->model_config.paraformer.encoder_buf_end, nullptr);

recognizer_config.model_config.paraformer.decoder =
SHERPA_ONNX_OR(config->model_config.paraformer.decoder, "");
recognizer_config.model_config.paraformer.decoder_buf_begin =
SHERPA_ONNX_OR(config->model_config.paraformer.decoder_buf_begin, nullptr);
recognizer_config.model_config.paraformer.decoder_buf_end =
SHERPA_ONNX_OR(config->model_config.paraformer.decoder_buf_end, nullptr);

recognizer_config.model_config.zipformer2_ctc.model =
SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, "");
recognizer_config.model_config.zipformer2_ctc.model_buf_begin =
SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model_buf_begin, nullptr);
recognizer_config.model_config.zipformer2_ctc.model_buf_end =
SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model_buf_end, nullptr);

recognizer_config.model_config.tokens =
SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.tokens_buf_str =
SHERPA_ONNX_OR(config->model_config.tokens_buf_str, "");

recognizer_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
recognizer_config.model_config.provider_config.provider =
Expand Down Expand Up @@ -118,6 +148,7 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
SHERPA_ONNX_OR(config->rule3_min_utterance_length, 20);

recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, "");
recognizer_config.hotwords_buf_str = SHERPA_ONNX_OR(config->hotwords_buf_str, "");
recognizer_config.hotwords_score =
SHERPA_ONNX_OR(config->hotwords_score, 1.5);

Expand Down Expand Up @@ -366,24 +397,66 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig(
recognizer_config.model_config.transducer.encoder_filename =
SHERPA_ONNX_OR(config->model_config.transducer.encoder, "");

recognizer_config.model_config.transducer.encoder_buf_begin =
SHERPA_ONNX_OR(config->model_config.transducer.encoder_buf_begin, nullptr);

recognizer_config.model_config.transducer.encoder_buf_end =
SHERPA_ONNX_OR(config->model_config.transducer.encoder_buf_end, nullptr);

recognizer_config.model_config.transducer.decoder_filename =
SHERPA_ONNX_OR(config->model_config.transducer.decoder, "");

recognizer_config.model_config.transducer.decoder_buf_begin =
SHERPA_ONNX_OR(config->model_config.transducer.decoder_buf_begin, nullptr);

recognizer_config.model_config.transducer.decoder_buf_end =
SHERPA_ONNX_OR(config->model_config.transducer.decoder_buf_end, nullptr);

recognizer_config.model_config.transducer.joiner_filename =
SHERPA_ONNX_OR(config->model_config.transducer.joiner, "");

recognizer_config.model_config.transducer.joiner_buf_begin =
SHERPA_ONNX_OR(config->model_config.transducer.joiner_buf_begin, nullptr);

recognizer_config.model_config.transducer.joiner_buf_end =
SHERPA_ONNX_OR(config->model_config.transducer.joiner_buf_end, nullptr);

recognizer_config.model_config.paraformer.model =
SHERPA_ONNX_OR(config->model_config.paraformer.model, "");

recognizer_config.model_config.paraformer.model_buf_begin =
SHERPA_ONNX_OR(config->model_config.paraformer.model_buf_begin, nullptr);

recognizer_config.model_config.paraformer.model_buf_end =
SHERPA_ONNX_OR(config->model_config.paraformer.model_buf_end, nullptr);

recognizer_config.model_config.nemo_ctc.model =
SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, "");

recognizer_config.model_config.nemo_ctc.model_buf_begin =
SHERPA_ONNX_OR(config->model_config.nemo_ctc.model_buf_begin, nullptr);

recognizer_config.model_config.nemo_ctc.model_buf_end =
SHERPA_ONNX_OR(config->model_config.nemo_ctc.model_buf_end, nullptr);

recognizer_config.model_config.whisper.encoder =
SHERPA_ONNX_OR(config->model_config.whisper.encoder, "");

recognizer_config.model_config.whisper.encoder_buf_begin =
SHERPA_ONNX_OR(config->model_config.whisper.encoder_buf_begin, nullptr);

recognizer_config.model_config.whisper.encoder_buf_end =
SHERPA_ONNX_OR(config->model_config.whisper.encoder_buf_end, nullptr);

recognizer_config.model_config.whisper.decoder =
SHERPA_ONNX_OR(config->model_config.whisper.decoder, "");

recognizer_config.model_config.whisper.decoder_buf_begin =
SHERPA_ONNX_OR(config->model_config.whisper.decoder_buf_begin, nullptr);

recognizer_config.model_config.whisper.decoder_buf_end =
SHERPA_ONNX_OR(config->model_config.whisper.decoder_buf_end, nullptr);

recognizer_config.model_config.whisper.language =
SHERPA_ONNX_OR(config->model_config.whisper.language, "");

Expand All @@ -398,9 +471,16 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig(

recognizer_config.model_config.tdnn.model =
SHERPA_ONNX_OR(config->model_config.tdnn.model, "");
recognizer_config.model_config.tdnn.model_buf_begin =
SHERPA_ONNX_OR(config->model_config.tdnn.model_buf_begin, nullptr);
recognizer_config.model_config.tdnn.model_buf_end =
SHERPA_ONNX_OR(config->model_config.tdnn.model_buf_end, nullptr);

recognizer_config.model_config.tokens =
SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.tokens_buf_str =
SHERPA_ONNX_OR(config->model_config.tokens_buf_str, "");

recognizer_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
recognizer_config.model_config.debug =
Expand Down Expand Up @@ -451,6 +531,7 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig(
SHERPA_ONNX_OR(config->max_active_paths, 4);

recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, "");
recognizer_config.hotwords_buf_str = SHERPA_ONNX_OR(config->hotwords_buf_str, "");
recognizer_config.hotwords_score =
SHERPA_ONNX_OR(config->hotwords_score, 1.5);

Expand Down Expand Up @@ -639,24 +720,53 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);
spotter_config.feat_config.feature_dim =
SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);

spotter_config.model_config.transducer.encoder =
SHERPA_ONNX_OR(config->model_config.transducer.encoder, "");
spotter_config.model_config.transducer.encoder_buf_begin =
SHERPA_ONNX_OR(config->model_config.transducer.encoder_buf_begin, nullptr);
spotter_config.model_config.transducer.encoder_buf_end =
SHERPA_ONNX_OR(config->model_config.transducer.encoder_buf_end, nullptr);

spotter_config.model_config.transducer.decoder =
SHERPA_ONNX_OR(config->model_config.transducer.decoder, "");
spotter_config.model_config.transducer.decoder_buf_begin =
SHERPA_ONNX_OR(config->model_config.transducer.decoder_buf_begin, nullptr);
spotter_config.model_config.transducer.decoder_buf_end =
SHERPA_ONNX_OR(config->model_config.transducer.decoder_buf_end, nullptr);

spotter_config.model_config.transducer.joiner =
SHERPA_ONNX_OR(config->model_config.transducer.joiner, "");
spotter_config.model_config.transducer.joiner_buf_begin =
SHERPA_ONNX_OR(config->model_config.transducer.joiner_buf_begin, nullptr);
spotter_config.model_config.transducer.joiner_buf_end =
SHERPA_ONNX_OR(config->model_config.transducer.joiner_buf_end, nullptr);

spotter_config.model_config.paraformer.encoder =
SHERPA_ONNX_OR(config->model_config.paraformer.encoder, "");
spotter_config.model_config.paraformer.encoder_buf_begin =
SHERPA_ONNX_OR(config->model_config.paraformer.encoder_buf_begin, nullptr);
spotter_config.model_config.paraformer.encoder_buf_end =
SHERPA_ONNX_OR(config->model_config.paraformer.encoder_buf_end, nullptr);

spotter_config.model_config.paraformer.decoder =
SHERPA_ONNX_OR(config->model_config.paraformer.decoder, "");
spotter_config.model_config.paraformer.decoder_buf_begin =
SHERPA_ONNX_OR(config->model_config.paraformer.decoder_buf_begin, nullptr);
spotter_config.model_config.paraformer.decoder_buf_end =
SHERPA_ONNX_OR(config->model_config.paraformer.decoder_buf_end, nullptr);

spotter_config.model_config.zipformer2_ctc.model =
SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, "");
spotter_config.model_config.zipformer2_ctc.model_buf_begin =
SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model_buf_begin, nullptr);
spotter_config.model_config.zipformer2_ctc.model_buf_end =
SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model_buf_end, nullptr);

spotter_config.model_config.tokens =
SHERPA_ONNX_OR(config->model_config.tokens, "");
spotter_config.model_config.tokens_buf_str =
SHERPA_ONNX_OR(config->model_config.tokens_buf_str, "");
spotter_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
spotter_config.model_config.provider_config.provider =
Expand All @@ -681,6 +791,7 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
SHERPA_ONNX_OR(config->keywords_threshold, 0.25);

spotter_config.keywords_file = SHERPA_ONNX_OR(config->keywords_file, "");
spotter_config.keywords_buf_str = SHERPA_ONNX_OR(config->keywords_buf_str, "");

if (config->model_config.debug) {
SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str());
Expand Down
19 changes: 19 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,30 +54,37 @@ extern "C" {
/// from there.
SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig {
  // Encoder model.
  const char *encoder;
  // Optional in-memory encoder, delimited by the begin/end pointers.
  // If non-null, the encoder is loaded from this buffer in priority.
  const char *encoder_buf_begin;
  const char *encoder_buf_end;

  // Decoder model.
  const char *decoder;
  // Optional in-memory decoder, delimited by the begin/end pointers.
  // If non-null, the decoder is loaded from this buffer in priority.
  const char *decoder_buf_begin;
  const char *decoder_buf_end;

  // Joiner model.
  const char *joiner;
  // Optional in-memory joiner, delimited by the begin/end pointers.
  // If non-null, the joiner is loaded from this buffer in priority.
  const char *joiner_buf_begin;
  const char *joiner_buf_end;
} SherpaOnnxOnlineTransducerModelConfig;

// please visit
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html
// to download pre-trained streaming paraformer models
SHERPA_ONNX_API typedef struct SherpaOnnxOnlineParaformerModelConfig {
  // Encoder model.
  const char *encoder;
  // Optional in-memory encoder, delimited by the begin/end pointers.
  // If non-null, the encoder is loaded from this buffer in priority.
  const char *encoder_buf_begin;
  const char *encoder_buf_end;

  // Decoder model.
  const char *decoder;
  // Optional in-memory decoder, delimited by the begin/end pointers.
  // If non-null, the decoder is loaded from this buffer in priority.
  const char *decoder_buf_begin;
  const char *decoder_buf_end;
} SherpaOnnxOnlineParaformerModelConfig;

// Please visit
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html#
// to download pre-trained streaming zipformer2 ctc models
SHERPA_ONNX_API typedef struct SherpaOnnxOnlineZipformer2CtcModelConfig {
  // The model.
  const char *model;
  // Optional in-memory model, delimited by the begin/end pointers.
  // If non-null, the model is loaded from this buffer in priority.
  const char *model_buf_begin;
  const char *model_buf_end;
} SherpaOnnxOnlineZipformer2CtcModelConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig {
SherpaOnnxOnlineTransducerModelConfig transducer;
SherpaOnnxOnlineParaformerModelConfig paraformer;
SherpaOnnxOnlineZipformer2CtcModelConfig zipformer2_ctc;
const char *tokens;
const char *tokens_buf_str; // if non-null, loading the tokens from the buffered string directly in prioriy
int32_t num_threads;
const char *provider;
int32_t debug; // true to print debug information of the model
Expand Down Expand Up @@ -139,6 +146,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig {

/// Path to the hotwords.
const char *hotwords_file;
const char *hotwords_buf_str; // if non-empty, loading the hotwords from the buffered string directly in prioriy

/// Bonus score for each token in hotwords.
float hotwords_score;
Expand Down Expand Up @@ -358,28 +366,36 @@ SHERPA_ONNX_API void SherpaOnnxPrint(const SherpaOnnxDisplay *display,
/// from there.
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTransducerModelConfig {
  // Encoder model.
  const char *encoder;
  // Optional in-memory encoder, delimited by the begin/end pointers.
  // If non-null, the encoder is loaded from this buffer in priority.
  const char *encoder_buf_begin;
  const char *encoder_buf_end;

  // Decoder model.
  const char *decoder;
  // Optional in-memory decoder, delimited by the begin/end pointers.
  // If non-null, the decoder is loaded from this buffer in priority.
  const char *decoder_buf_begin;
  const char *decoder_buf_end;

  // Joiner model.
  const char *joiner;
  // Optional in-memory joiner, delimited by the begin/end pointers.
  // If non-null, the joiner is loaded from this buffer in priority.
  const char *joiner_buf_begin;
  const char *joiner_buf_end;
} SherpaOnnxOfflineTransducerModelConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOfflineParaformerModelConfig {
  // The model.
  const char *model;
  // Optional in-memory model, delimited by the begin/end pointers.
  // If non-null, the model is loaded from this buffer in priority.
  const char *model_buf_begin;
  const char *model_buf_end;
} SherpaOnnxOfflineParaformerModelConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOfflineNemoEncDecCtcModelConfig {
  // The model.
  const char *model;
  // Optional in-memory model, delimited by the begin/end pointers.
  // If non-null, the model is loaded from this buffer in priority.
  const char *model_buf_begin;
  const char *model_buf_end;
} SherpaOnnxOfflineNemoEncDecCtcModelConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig {
  // Encoder model.
  const char *encoder;
  // Optional in-memory encoder, delimited by the begin/end pointers.
  // If non-null, the encoder is loaded from this buffer in priority.
  const char *encoder_buf_begin;
  const char *encoder_buf_end;

  // Decoder model.
  const char *decoder;
  // Optional in-memory decoder, delimited by the begin/end pointers.
  // If non-null, the decoder is loaded from this buffer in priority.
  const char *decoder_buf_begin;
  const char *decoder_buf_end;

  const char *language;
  const char *task;
  int32_t tail_paddings;
} SherpaOnnxOfflineWhisperModelConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig {
  // The model.
  const char *model;
  // Optional in-memory model, delimited by the begin/end pointers.
  // If non-null, the model is loaded from this buffer in priority.
  const char *model_buf_begin;
  const char *model_buf_end;
} SherpaOnnxOfflineTdnnModelConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOfflineLMConfig {
Expand All @@ -401,6 +417,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig {
SherpaOnnxOfflineTdnnModelConfig tdnn;

const char *tokens;
const char *tokens_buf_str; // if non-null, loading the tokens from the buffered string directly in prioriy
int32_t num_threads;
int32_t debug;
const char *provider;
Expand All @@ -425,6 +442,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig {

/// Path to the hotwords.
const char *hotwords_file;
const char *hotwords_buf_str; // if non-null, loading the hotwords from the buffered string directly in priority

/// Bonus score for each token in hotwords.
float hotwords_score;
Expand Down Expand Up @@ -627,6 +645,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig {
float keywords_score;
float keywords_threshold;
const char *keywords_file;
const char *keywords_buf_str;
} SherpaOnnxKeywordSpotterConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/keyword-spotter-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace sherpa_onnx {

std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create(
const KeywordSpotterConfig &config) {
if (!config.model_config.transducer.encoder.empty()) {
if (!config.model_config.transducer.encoder.empty() || config.model_config.transducer.encoder_buf_begin) {
return std::make_unique<KeywordSpotterTransducerImpl>(config);
}

Expand Down
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
explicit KeywordSpotterTransducerImpl(const KeywordSpotterConfig &config)
: config_(config),
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens) {
sym_(config.model_config.tokens_buf_str.empty() ? config.model_config.tokens :
config.model_config.tokens_buf_str, config.model_config.tokens_buf_str.empty() ? false, true) {
if (sym_.Contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
Expand Down
4 changes: 3 additions & 1 deletion sherpa-onnx/csrc/keyword-spotter.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ struct KeywordSpotterConfig {
float keywords_threshold = 0.25;

std::string keywords_file;
std::string keywords_buf_str;

KeywordSpotterConfig() = default;

Expand All @@ -82,7 +83,8 @@ struct KeywordSpotterConfig {
num_trailing_blanks(num_trailing_blanks),
keywords_score(keywords_score),
keywords_threshold(keywords_threshold),
keywords_file(keywords_file) {}
keywords_file(keywords_file),
keywords_buf_str("") {}

void Register(ParseOptions *po);
bool Validate() const;
Expand Down
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/offline-model-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ struct OfflineModelConfig {
std::string telespeech_ctc;

std::string tokens;
std::string tokens_buf_str;
int32_t num_threads = 2;
bool debug = false;
std::string provider = "cpu";
Expand Down Expand Up @@ -70,7 +71,7 @@ struct OfflineModelConfig {
wenet_ctc(wenet_ctc),
sense_voice(sense_voice),
telespeech_ctc(telespeech_ctc),
tokens(tokens),
tokens(tokens), tokens_buf_str(""),
num_threads(num_threads),
debug(debug),
provider(provider),
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@ namespace sherpa_onnx {

struct OfflineNemoEncDecCtcModelConfig {
// The model, as a string.
std::string model;

// Optional in-memory model, delimited by the begin/end pointers. Both default
// to nullptr so that every constructor (including the defaulted one and the
// string one below) leaves them in a well-defined state; a null check on
// model_buf_begin would otherwise read an indeterminate pointer.
const char *model_buf_begin = nullptr;
const char *model_buf_end = nullptr;

OfflineNemoEncDecCtcModelConfig() = default;

// Build a config that refers to the model by string; the buffer pointers
// stay null.
explicit OfflineNemoEncDecCtcModelConfig(const std::string &model)
    : model(model) {}

// Build a config that refers to a model held in memory; `model` stays empty.
// Only the two pointers are stored, the bytes are not copied here.
explicit OfflineNemoEncDecCtcModelConfig(const char *model_buf_begin,
                                         const char *model_buf_end)
    : model_buf_begin(model_buf_begin), model_buf_end(model_buf_end) {}

void Register(ParseOptions *po);
bool Validate() const;

Expand Down
Loading

0 comments on commit 742fdb8

Please sign in to comment.