diff --git a/.github/workflows/c-api-test-loading-tokens-hotwords-from-memory.yaml b/.github/workflows/c-api-from-buffer.yaml similarity index 92% rename from .github/workflows/c-api-test-loading-tokens-hotwords-from-memory.yaml rename to .github/workflows/c-api-from-buffer.yaml index 6ce4372ba..49a5cf385 100644 --- a/.github/workflows/c-api-test-loading-tokens-hotwords-from-memory.yaml +++ b/.github/workflows/c-api-from-buffer.yaml @@ -1,4 +1,4 @@ -name: c-api-test-loading-tokens-hotwords-from-memory +name: c-api-from-memory on: push: @@ -7,7 +7,7 @@ on: tags: - 'v[0-9]+.[0-9]+.[0-9]+*' paths: - - '.github/workflows/c-api.yaml' + - '.github/workflows/c-api-from-buffer.yaml' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -18,7 +18,7 @@ on: branches: - master paths: - - '.github/workflows/c-api.yaml' + - '.github/workflows/c-api-from-buffer.yaml' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -29,11 +29,11 @@ on: workflow_dispatch: concurrency: - group: c-api-${{ github.ref }} + group: c-api-from-buffer-${{ github.ref }} cancel-in-progress: true jobs: - c_api: + c_api_from_buffer: name: ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: @@ -106,8 +106,9 @@ jobs: curl -SL -O https://huggingface.co/desh2608/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-small/blob/main/data/lang_bpe_500/bpe.model cp bpe.model sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/ rm bpe.model - + printf "▁A ▁T ▁P :1.5\n▁A ▁B ▁C :3.0" > hotwords.txt + mv hotwords.txt ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17 ls -lh sherpa-onnx-streaming-zipformer-en-20M-2023-02-17 echo "---" @@ -115,7 +116,7 @@ jobs: export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH - + ./streaming-zipformer-buffered-tokens-hotwords-c-api - + rm -rf sherpa-onnx-streaming-zipformer-* diff --git a/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c b/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c index 0da5f3317..9a02ec664 100644 --- a/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c +++ b/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c @@ -5,8 +5,8 @@ // // This file demonstrates how to use streaming Zipformer with sherpa-onnx's C -// and with tokens and hotwords loaded from buffered strings instead of from external -// files API. +// and with tokens and hotwords loaded from buffered strings instead of from +// external files API. // clang-format off // // wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 @@ -22,7 +22,7 @@ #include "sherpa-onnx/c-api/c-api.h" static size_t ReadFile(const char *filename, const char **buffer_out) { - FILE *file = fopen(filename, "rb"); + FILE *file = fopen(filename, "r"); if (file == NULL) { fprintf(stderr, "Failed to open %s\n", filename); return -1; @@ -39,7 +39,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { size_t read_bytes = fread(*buffer_out, 1, size, file); if (read_bytes != size) { printf("Errors occured in reading the file %s\n", filename); - free(*buffer_out); + free((void *)*buffer_out); *buffer_out = NULL; fclose(file); return -1; @@ -80,14 +80,14 @@ int32_t main() { size_t token_buf_size = ReadFile(tokens_filename, &tokens_buf); if (token_buf_size < 1) { fprintf(stderr, "Please check your tokens.txt!\n"); - free(tokens_buf); + free((void *)tokens_buf); return -1; } const char *hotwords_buf; size_t hotwords_buf_size = ReadFile(hotwords_filename, &hotwords_buf); if (hotwords_buf_size < 1) { fprintf(stderr, "Please check your hotwords.txt!\n"); - free(hotwords_buf); + free((void *)hotwords_buf); return -1; } @@ -119,9 +119,9 @@ int32_t main() { SherpaOnnxOnlineRecognizer *recognizer = SherpaOnnxCreateOnlineRecognizer(&recognizer_config); - free(tokens_buf); + free((void *)tokens_buf); tokens_buf = NULL; - free(hotwords_buf); + free((void *)hotwords_buf); hotwords_buf = NULL; if (recognizer == NULL) { @@ -199,4 +199,4 @@ int32_t main() { fprintf(stderr, "\n"); return 0; -} \ No newline at end of file +} diff --git a/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart b/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart index 6e822daa1..abc5e1f09 100644 --- a/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart +++ b/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart @@ -234,6 +234,11 @@ final class SherpaOnnxOnlineModelConfig extends Struct { external Pointer modelingUnit; external Pointer bpeVocab; + + external Pointer tokensBuf; + + @Int32() + external int tokensBufSize; } final class SherpaOnnxOnlineCtcFstDecoderConfig extends Struct { @@ -275,6 +280,11 @@ final class SherpaOnnxOnlineRecognizerConfig extends Struct { @Float() external double blankPenalty; + + external Pointer hotwordsBuf; + + @Int32() + external int hotwordsBufSize; } final class SherpaOnnxSileroVadModelConfig extends Struct { diff --git a/scripts/dotnet/OnlineModelConfig.cs b/scripts/dotnet/OnlineModelConfig.cs index 2c7d502e8..7adbaab96 100644 --- a/scripts/dotnet/OnlineModelConfig.cs +++ b/scripts/dotnet/OnlineModelConfig.cs @@ -22,6 +22,8 @@ public OnlineModelConfig() ModelType = ""; ModelingUnit = "cjkchar"; BpeVocab = ""; + TokensBuf = ""; + TokensBufSize = 0; } public OnlineTransducerModelConfig Transducer; @@ -48,6 +50,11 @@ public OnlineModelConfig() [MarshalAs(UnmanagedType.LPStr)] public string BpeVocab; + + [MarshalAs(UnmanagedType.LPStr)] + public string TokensBuf; + + public int TokensBufSize; } -} \ No newline at end of file +} diff --git a/scripts/dotnet/OnlineRecognizerConfig.cs b/scripts/dotnet/OnlineRecognizerConfig.cs index d9a8f610b..bd55a1091 100644 --- a/scripts/dotnet/OnlineRecognizerConfig.cs +++ b/scripts/dotnet/OnlineRecognizerConfig.cs @@ -26,6 +26,8 @@ public OnlineRecognizerConfig() RuleFsts = ""; RuleFars = ""; BlankPenalty = 0.0F; + HotwordsBuf = ""; + HotwordsBufSize = 0; } public FeatureConfig FeatConfig; public OnlineModelConfig ModelConfig; @@ -72,5 +74,10 @@ public OnlineRecognizerConfig() public string RuleFars; public float BlankPenalty; + + [MarshalAs(UnmanagedType.LPStr)] + public string HotwordsBuf; + + public int HotwordsBufSize; } } diff --git a/scripts/go/sherpa_onnx.go b/scripts/go/sherpa_onnx.go index f2637c40f..aeee609ca 100644 --- a/scripts/go/sherpa_onnx.go +++ b/scripts/go/sherpa_onnx.go @@ -89,6 +89,8 @@ type OnlineModelConfig struct { ModelType string // Optional. You can specify it for faster model initialization ModelingUnit string // Optional. cjkchar, bpe, cjkchar+bpe BpeVocab string // Optional. + TokensBuf string // Optional. + TokensBufSize int // Optional. } // Configuration for the feature extractor @@ -133,6 +135,8 @@ type OnlineRecognizerConfig struct { CtcFstDecoderConfig OnlineCtcFstDecoderConfig RuleFsts string RuleFars string + HotwordsBuf string + HotwordsBufSize int } // It contains the recognition result for a online stream. @@ -184,6 +188,11 @@ func NewOnlineRecognizer(config *OnlineRecognizerConfig) *OnlineRecognizer { c.model_config.tokens = C.CString(config.ModelConfig.Tokens) defer C.free(unsafe.Pointer(c.model_config.tokens)) + c.model_config.tokens_buf = C.CString(config.ModelConfig.TokensBuf) + defer C.free(unsafe.Pointer(c.model_config.tokens_buf)) + + c.model_config.tokens_buf_size = C.int(config.ModelConfig.TokensBufSize) + c.model_config.num_threads = C.int(config.ModelConfig.NumThreads) c.model_config.provider = C.CString(config.ModelConfig.Provider) @@ -212,6 +221,11 @@ func NewOnlineRecognizer(config *OnlineRecognizerConfig) *OnlineRecognizer { c.hotwords_file = C.CString(config.HotwordsFile) defer C.free(unsafe.Pointer(c.hotwords_file)) + c.hotwords_buf = C.CString(config.HotwordsBuf) + defer C.free(unsafe.Pointer(c.hotwords_buf)) + + c.hotwords_buf_size = C.int(config.HotwordsBufSize) + c.hotwords_score = C.float(config.HotwordsScore) c.blank_penalty = C.float(config.BlankPenalty) diff --git a/scripts/node-addon-api/src/streaming-asr.cc b/scripts/node-addon-api/src/streaming-asr.cc index 8976d2e7b..6057cade3 100644 --- a/scripts/node-addon-api/src/streaming-asr.cc +++ b/scripts/node-addon-api/src/streaming-asr.cc @@ -120,6 +120,8 @@ SherpaOnnxOnlineModelConfig GetOnlineModelConfig(Napi::Object obj) { SHERPA_ONNX_ASSIGN_ATTR_STR(model_type, modelType); SHERPA_ONNX_ASSIGN_ATTR_STR(modeling_unit, modelingUnit); SHERPA_ONNX_ASSIGN_ATTR_STR(bpe_vocab, bpeVocab); + SHERPA_ONNX_ASSIGN_ATTR_STR(tokens_buf, tokensBuf); + SHERPA_ONNX_ASSIGN_ATTR_INT32(tokens_buf_size, tokensBufSize); return c; } @@ -192,6 +194,8 @@ static Napi::External CreateOnlineRecognizerWrapper( SHERPA_ONNX_ASSIGN_ATTR_STR(rule_fsts, ruleFsts); SHERPA_ONNX_ASSIGN_ATTR_STR(rule_fars, ruleFars); SHERPA_ONNX_ASSIGN_ATTR_FLOAT(blank_penalty, blankPenalty); + SHERPA_ONNX_ASSIGN_ATTR_STR(hotwords_buf, hotwordsBuf); + SHERPA_ONNX_ASSIGN_ATTR_INT32(hotwords_buf_size, hotwordsBufSize); c.ctc_fst_decoder_config = GetCtcFstDecoderConfig(o); @@ -241,6 +245,10 @@ static Napi::External CreateOnlineRecognizerWrapper( delete[] c.model_config.bpe_vocab; } + if (c.model_config.tokens_buf) { + delete[] c.model_config.tokens_buf; + } + if (c.decoding_method) { delete[] c.decoding_method; } @@ -257,6 +265,10 @@ static Napi::External CreateOnlineRecognizerWrapper( delete[] c.rule_fars; } + if (c.hotwords_buf) { + delete[] c.hotwords_buf; + } + if (c.ctc_fst_decoder_config.graph) { delete[] c.ctc_fst_decoder_config.graph; } diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 11dba9816..67746e587 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -91,7 +91,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig { /// if non-null, loading the tokens from the buffered string directly in /// prioriy const char *tokens_buf; - /// byte size excluding the tailing '\0' + /// byte size excluding the trailing '\0' int32_t tokens_buf_size; } SherpaOnnxOnlineModelConfig; diff --git a/sherpa-onnx/csrc/offline-stream.cc b/sherpa-onnx/csrc/offline-stream.cc index 6662d518b..5d42a484e 100644 --- a/sherpa-onnx/csrc/offline-stream.cc +++ b/sherpa-onnx/csrc/offline-stream.cc @@ -4,6 +4,8 @@ #include "sherpa-onnx/csrc/offline-stream.h" +#include + #include #include #include @@ -245,7 +247,7 @@ class OfflineStream::Impl { for (int32_t i = 0; i != n; ++i) { float x = p[i]; x = (x > amin) ? x : amin; - x = std::log10f(x) * multiplier; + x = log10f(x) * multiplier; max_x = (x > max_x) ? x : max_x; p[i] = x; diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 503472e04..50af6b987 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -372,7 +372,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { // segment is incremented only when the last // result is not empty, contains non-blanks and longer than context_size) const auto &r = s->GetResult(); - if (!r.tokens.empty() && r.tokens.back() != 0 && r.tokens.size() > context_size) { + if (!r.tokens.empty() && r.tokens.back() != 0 && + r.tokens.size() > context_size) { s->GetCurrentSegment() += 1; } } @@ -392,7 +393,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { // if last result is not empty, then // preserve last tokens as the context for next result if (static_cast(last_result.tokens.size()) > context_size) { - std::vector context(last_result.tokens.end() - context_size, last_result.tokens.end()); + std::vector context(last_result.tokens.end() - context_size, + last_result.tokens.end()); Hypotheses context_hyp({{context, 0}}); r.hyps = std::move(context_hyp); diff --git a/sherpa-onnx/pascal-api/sherpa_onnx.pas b/sherpa-onnx/pascal-api/sherpa_onnx.pas index dc0684ebc..987b31f14 100644 --- a/sherpa-onnx/pascal-api/sherpa_onnx.pas +++ b/sherpa-onnx/pascal-api/sherpa_onnx.pas @@ -145,6 +145,8 @@ TSherpaOnnxOnlineModelConfig = record ModelType: AnsiString; ModelingUnit: AnsiString; BpeVocab: AnsiString; + TokensBuf: AnsiString; + TokensBufSize: Integer; function ToString: AnsiString; class operator Initialize({$IFDEF FPC}var{$ELSE}out{$ENDIF} Dest: TSherpaOnnxOnlineModelConfig); end; @@ -178,6 +180,8 @@ TSherpaOnnxOnlineRecognizerConfig = record RuleFsts: AnsiString; RuleFars: AnsiString; BlankPenalty: Single; + HotwordsBuf: AnsiString; + HotwordsBufSize: Integer; function ToString: AnsiString; class operator Initialize({$IFDEF FPC}var{$ELSE}out{$ENDIF} Dest: TSherpaOnnxOnlineRecognizerConfig); end; @@ -490,6 +494,8 @@ SherpaOnnxOnlineModelConfig= record ModelType: PAnsiChar; ModelingUnit: PAnsiChar; BpeVocab: PAnsiChar; + TokensBuf: PAnsiChar; + TokensBufSize: cint32; end; SherpaOnnxFeatureConfig = record SampleRate: cint32; @@ -514,6 +520,8 @@ SherpaOnnxOnlineRecognizerConfig = record RuleFsts: PAnsiChar; RuleFars: PAnsiChar; BlankPenalty: cfloat; + HotwordsBuf: PAnsiChar; + HotwordsBufSize: cint32; end; PSherpaOnnxOnlineRecognizerConfig = ^SherpaOnnxOnlineRecognizerConfig; diff --git a/sherpa-onnx/python/csrc/online-punctuation.cc b/sherpa-onnx/python/csrc/online-punctuation.cc index 13aa66b64..decd16f9a 100644 --- a/sherpa-onnx/python/csrc/online-punctuation.cc +++ b/sherpa-onnx/python/csrc/online-punctuation.cc @@ -4,6 +4,8 @@ #include "sherpa-onnx/python/csrc/online-punctuation.h" +#include + #include "sherpa-onnx/csrc/online-punctuation.h" namespace sherpa_onnx { @@ -12,9 +14,11 @@ static void PybindOnlinePunctuationModelConfig(py::module *m) { using PyClass = OnlinePunctuationModelConfig; py::class_(*m, "OnlinePunctuationModelConfig") .def(py::init<>()) - .def(py::init(), - py::arg("cnn_bilstm"), py::arg("bpe_vocab"), py::arg("num_threads") = 1, - py::arg("debug") = false, py::arg("provider") = "cpu") + .def(py::init(), + py::arg("cnn_bilstm"), py::arg("bpe_vocab"), + py::arg("num_threads") = 1, py::arg("debug") = false, + py::arg("provider") = "cpu") .def_readwrite("cnn_bilstm", &PyClass::cnn_bilstm) .def_readwrite("bpe_vocab", &PyClass::bpe_vocab) .def_readwrite("num_threads", &PyClass::num_threads) @@ -30,7 +34,8 @@ static void PybindOnlinePunctuationConfig(py::module *m) { py::class_(*m, "OnlinePunctuationConfig") .def(py::init<>()) - .def(py::init(), py::arg("model_config")) + .def(py::init(), + py::arg("model_config")) .def_readwrite("model_config", &PyClass::model) .def("validate", &PyClass::Validate) .def("__str__", &PyClass::ToString); @@ -43,8 +48,8 @@ void PybindOnlinePunctuation(py::module *m) { py::class_(*m, "OnlinePunctuation") .def(py::init(), py::arg("config"), py::call_guard()) - .def("add_punctuation_with_case", &PyClass::AddPunctuationWithCase, py::arg("text"), - py::call_guard()); + .def("add_punctuation_with_case", &PyClass::AddPunctuationWithCase, + py::arg("text"), py::call_guard()); } } // namespace sherpa_onnx diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift index df46acfab..1bdc82b15 100644 --- a/swift-api-examples/SherpaOnnx.swift +++ b/swift-api-examples/SherpaOnnx.swift @@ -90,7 +90,9 @@ func sherpaOnnxOnlineModelConfig( debug: Int = 0, modelType: String = "", modelingUnit: String = "cjkchar", - bpeVocab: String = "" + bpeVocab: String = "", + tokensBuf: String = "", + tokensBufSize: Int = 0 ) -> SherpaOnnxOnlineModelConfig { return SherpaOnnxOnlineModelConfig( transducer: transducer, @@ -102,7 +104,9 @@ func sherpaOnnxOnlineModelConfig( debug: Int32(debug), model_type: toCPointer(modelType), modeling_unit: toCPointer(modelingUnit), - bpe_vocab: toCPointer(bpeVocab) + bpe_vocab: toCPointer(bpeVocab), + tokens_buf: toCPointer(tokensBuf), + tokens_buf_size: Int32(tokensBufSize) ) } @@ -138,7 +142,9 @@ func sherpaOnnxOnlineRecognizerConfig( ctcFstDecoderConfig: SherpaOnnxOnlineCtcFstDecoderConfig = sherpaOnnxOnlineCtcFstDecoderConfig(), ruleFsts: String = "", ruleFars: String = "", - blankPenalty: Float = 0.0 + blankPenalty: Float = 0.0, + hotwordsBuf: String = "", + hotwordsBufSize: Int = 0 ) -> SherpaOnnxOnlineRecognizerConfig { return SherpaOnnxOnlineRecognizerConfig( feat_config: featConfig, @@ -154,7 +160,9 @@ func sherpaOnnxOnlineRecognizerConfig( ctc_fst_decoder_config: ctcFstDecoderConfig, rule_fsts: toCPointer(ruleFsts), rule_fars: toCPointer(ruleFars), - blank_penalty: blankPenalty + blank_penalty: blankPenalty, + hotwords_buf: toCPointer(hotwordsBuf), + hotwords_buf_size: Int32(hotwordsBufSize) ) } diff --git a/wasm/asr/sherpa-onnx-asr.js b/wasm/asr/sherpa-onnx-asr.js index f0b8bb778..9b966090c 100644 --- a/wasm/asr/sherpa-onnx-asr.js +++ b/wasm/asr/sherpa-onnx-asr.js @@ -155,6 +155,14 @@ function initSherpaOnnxOnlineModelConfig(config, Module) { }; } + if (!('tokensBuf' in config)) { + config.tokensBuf = ''; + } + + if (!('tokensBufSize' in config)) { + config.tokensBufSize = 0; + } + const transducer = initSherpaOnnxOnlineTransducerModelConfig(config.transducer, Module); @@ -164,7 +172,7 @@ function initSherpaOnnxOnlineModelConfig(config, Module) { const ctc = initSherpaOnnxOnlineZipformer2CtcModelConfig( config.zipformer2Ctc, Module); - const len = transducer.len + paraformer.len + ctc.len + 7 * 4; + const len = transducer.len + paraformer.len + ctc.len + 9 * 4; const ptr = Module._malloc(len); let offset = 0; @@ -182,9 +190,10 @@ function initSherpaOnnxOnlineModelConfig(config, Module) { const modelTypeLen = Module.lengthBytesUTF8(config.modelType || '') + 1; const modelingUnitLen = Module.lengthBytesUTF8(config.modelingUnit || '') + 1; const bpeVocabLen = Module.lengthBytesUTF8(config.bpeVocab || '') + 1; + const tokensBufLen = Module.lengthBytesUTF8(config.tokensBuf || '') + 1; - const bufferLen = - tokensLen + providerLen + modelTypeLen + modelingUnitLen + bpeVocabLen; + const bufferLen = tokensLen + providerLen + modelTypeLen + modelingUnitLen + + bpeVocabLen + tokensBufLen; const buffer = Module._malloc(bufferLen); offset = 0; @@ -204,6 +213,9 @@ function initSherpaOnnxOnlineModelConfig(config, Module) { Module.stringToUTF8(config.bpeVocab || '', buffer + offset, bpeVocabLen); offset += bpeVocabLen; + Module.stringToUTF8(config.tokensBuf || '', buffer + offset, tokensBufLen); + offset += tokensBufLen; + offset = transducer.len + paraformer.len + ctc.len; Module.setValue(ptr + offset, buffer, 'i8*'); // tokens offset += 4; @@ -232,6 +244,16 @@ function initSherpaOnnxOnlineModelConfig(config, Module) { 'i8*'); // bpeVocab offset += 4; + Module.setValue( + ptr + offset, + buffer + tokensLen + providerLen + modelTypeLen + modelingUnitLen + + bpeVocabLen, + 'i8*'); // tokens_buf + offset += 4; + + Module.setValue(ptr + offset, config.tokensBufSize || 0, 'i32'); + offset += 4; + return { buffer: buffer, ptr: ptr, len: len, transducer: transducer, paraformer: paraformer, ctc: ctc @@ -275,12 +297,20 @@ function initSherpaOnnxOnlineRecognizerConfig(config, Module) { }; } + if (!('hotwordsBuf' in config)) { + config.hotwordsBuf = ''; + } + + if (!('hotwordsBufSize' in config)) { + config.hotwordsBufSize = 0; + } + const feat = initSherpaOnnxFeatureConfig(config.featConfig, Module); const model = initSherpaOnnxOnlineModelConfig(config.modelConfig, Module); const ctcFstDecoder = initSherpaOnnxOnlineCtcFstDecoderConfig( config.ctcFstDecoderConfig, Module) - const len = feat.len + model.len + 8 * 4 + ctcFstDecoder.len + 3 * 4; + const len = feat.len + model.len + 8 * 4 + ctcFstDecoder.len + 5 * 4; const ptr = Module._malloc(len); let offset = 0; @@ -295,8 +325,9 @@ function initSherpaOnnxOnlineRecognizerConfig(config, Module) { const hotwordsFileLen = Module.lengthBytesUTF8(config.hotwordsFile || '') + 1; const ruleFstsFileLen = Module.lengthBytesUTF8(config.ruleFsts || '') + 1; const ruleFarsFileLen = Module.lengthBytesUTF8(config.ruleFars || '') + 1; - const bufferLen = - decodingMethodLen + hotwordsFileLen + ruleFstsFileLen + ruleFarsFileLen; + const hotwordsBufLen = Module.lengthBytesUTF8(config.hotwordsBuf || '') + 1; + const bufferLen = decodingMethodLen + hotwordsFileLen + ruleFstsFileLen + + ruleFarsFileLen + hotwordsBufLen; const buffer = Module._malloc(bufferLen); offset = 0; @@ -314,6 +345,10 @@ function initSherpaOnnxOnlineRecognizerConfig(config, Module) { Module.stringToUTF8(config.ruleFars || '', buffer + offset, ruleFarsFileLen); offset += ruleFarsFileLen; + Module.stringToUTF8( + config.hotwordsBuf || '', buffer + offset, hotwordsBufLen); + offset += hotwordsBufLen; + offset = feat.len + model.len; Module.setValue(ptr + offset, buffer, 'i8*'); // decoding method offset += 4; @@ -354,6 +389,16 @@ function initSherpaOnnxOnlineRecognizerConfig(config, Module) { Module.setValue(ptr + offset, config.blankPenalty || 0, 'float'); offset += 4; + Module.setValue( + ptr + offset, + buffer + decodingMethodLen + hotwordsFileLen + ruleFstsFileLen + + ruleFarsFileLen, + 'i8*'); + offset += 4; + + Module.setValue(ptr + offset, config.hotwordsBufSize || 0, 'i32'); + offset += 4; + return { buffer: buffer, ptr: ptr, len: len, feat: feat, model: model, ctcFstDecoder: ctcFstDecoder diff --git a/wasm/asr/sherpa-onnx-wasm-main-asr.cc b/wasm/asr/sherpa-onnx-wasm-main-asr.cc index 4267c5c68..ffd90c201 100644 --- a/wasm/asr/sherpa-onnx-wasm-main-asr.cc +++ b/wasm/asr/sherpa-onnx-wasm-main-asr.cc @@ -19,14 +19,14 @@ static_assert(sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) == 1 * 4, ""); static_assert(sizeof(SherpaOnnxOnlineModelConfig) == sizeof(SherpaOnnxOnlineTransducerModelConfig) + sizeof(SherpaOnnxOnlineParaformerModelConfig) + - sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) + 7 * 4, + sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) + 9 * 4, ""); static_assert(sizeof(SherpaOnnxFeatureConfig) == 2 * 4, ""); static_assert(sizeof(SherpaOnnxOnlineCtcFstDecoderConfig) == 2 * 4, ""); static_assert(sizeof(SherpaOnnxOnlineRecognizerConfig) == sizeof(SherpaOnnxFeatureConfig) + sizeof(SherpaOnnxOnlineModelConfig) + 8 * 4 + - sizeof(SherpaOnnxOnlineCtcFstDecoderConfig) + 3 * 4, + sizeof(SherpaOnnxOnlineCtcFstDecoderConfig) + 5 * 4, ""); void MyPrint(SherpaOnnxOnlineRecognizerConfig *config) { @@ -54,6 +54,9 @@ void MyPrint(SherpaOnnxOnlineRecognizerConfig *config) { fprintf(stdout, "model type: %s\n", model_config->model_type); fprintf(stdout, "modeling unit: %s\n", model_config->modeling_unit); fprintf(stdout, "bpe vocab: %s\n", model_config->bpe_vocab); + fprintf(stdout, "tokens_buf: %s\n", + model_config->tokens_buf ? model_config->tokens_buf : ""); + fprintf(stdout, "tokens_buf_size: %d\n", model_config->tokens_buf_size); fprintf(stdout, "----------feat config----------\n"); fprintf(stdout, "sample rate: %d\n", feat->sample_rate); diff --git a/wasm/kws/sherpa-onnx-kws.js b/wasm/kws/sherpa-onnx-kws.js index c9fd7cb6f..dc1712bc9 100644 --- a/wasm/kws/sherpa-onnx-kws.js +++ b/wasm/kws/sherpa-onnx-kws.js @@ -62,12 +62,20 @@ function initSherpaOnnxOnlineTransducerModelConfig(config, Module) { // The user should free the returned pointers function initModelConfig(config, Module) { + if (!('tokensBuf' in config)) { + config.tokensBuf = ''; + } + + if (!('tokensBufSize' in config)) { + config.tokensBufSize = 0; + } + const transducer = initSherpaOnnxOnlineTransducerModelConfig(config.transducer, Module); const paraformer_len = 2 * 4 const ctc_len = 1 * 4 - const len = transducer.len + paraformer_len + ctc_len + 7 * 4; + const len = transducer.len + paraformer_len + ctc_len + 9 * 4; const ptr = Module._malloc(len); Module.HEAPU8.fill(0, ptr, ptr + len); @@ -79,8 +87,9 @@ function initModelConfig(config, Module) { const modelTypeLen = Module.lengthBytesUTF8(config.modelType || '') + 1; const modelingUnitLen = Module.lengthBytesUTF8(config.modelingUnit || '') + 1; const bpeVocabLen = Module.lengthBytesUTF8(config.bpeVocab || '') + 1; - const bufferLen = - tokensLen + providerLen + modelTypeLen + modelingUnitLen + bpeVocabLen; + const tokensBufLen = Module.lengthBytesUTF8(config.tokensBuf || '') + 1; + const bufferLen = tokensLen + providerLen + modelTypeLen + modelingUnitLen + + bpeVocabLen + tokensBufLen; const buffer = Module._malloc(bufferLen); offset = 0; @@ -100,6 +109,9 @@ function initModelConfig(config, Module) { Module.stringToUTF8(config.bpeVocab || '', buffer + offset, bpeVocabLen); offset += bpeVocabLen; + Module.stringToUTF8(config.tokensBuf || '', buffer + offset, tokensBufLen); + offset += tokensBufLen; + offset = transducer.len + paraformer_len + ctc_len; Module.setValue(ptr + offset, buffer, 'i8*'); // tokens offset += 4; @@ -128,6 +140,16 @@ function initModelConfig(config, Module) { 'i8*'); // bpeVocab offset += 4; + Module.setValue( + ptr + offset, + buffer + tokensLen + providerLen + modelTypeLen + modelingUnitLen + + bpeVocabLen, + 'i8*'); // tokens_buf + offset += 4; + + Module.setValue(ptr + offset, config.tokensBufSize || 0, 'i32'); + offset += 4; + return { buffer: buffer, ptr: ptr, len: len, transducer: transducer } diff --git a/wasm/kws/sherpa-onnx-wasm-main-kws.cc b/wasm/kws/sherpa-onnx-wasm-main-kws.cc index cbb3ab37d..cb3627955 100644 --- a/wasm/kws/sherpa-onnx-wasm-main-kws.cc +++ b/wasm/kws/sherpa-onnx-wasm-main-kws.cc @@ -19,7 +19,7 @@ static_assert(sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) == 1 * 4, ""); static_assert(sizeof(SherpaOnnxOnlineModelConfig) == sizeof(SherpaOnnxOnlineTransducerModelConfig) + sizeof(SherpaOnnxOnlineParaformerModelConfig) + - sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) + 7 * 4, + sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) + 9 * 4, ""); static_assert(sizeof(SherpaOnnxFeatureConfig) == 2 * 4, ""); static_assert(sizeof(SherpaOnnxKeywordSpotterConfig) ==