Skip to content

Commit

Permalink
Add hotwords support to Swift/Go/Python/C#/Kotlin APIs (#260)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Aug 31, 2023
1 parent 0f8e46d commit 401de81
Show file tree
Hide file tree
Showing 20 changed files with 196 additions and 113 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-ncnn)

set(SHERPA_NCNN_VERSION "2.0.7")
set(SHERPA_NCNN_VERSION "2.1.0")

# Disable warning about
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ data class RecognizerConfig(
var rule1MinTrailingSilence: Float = 2.4f,
var rule2MinTrailingSilence: Float = 1.0f,
var rule3MinUtteranceLength: Float = 30.0f,
var hotwordsFile: String = "",
var hotwordsScore: Float = 1.5f,
)

class SherpaNcnn(
Expand Down
14 changes: 8 additions & 6 deletions c-api-examples/decode-file-c-api.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ int32_t main(int32_t argc, char *argv[]) {
return -1;
}
SherpaNcnnRecognizerConfig config;
memset(&config, 0, sizeof(config));

config.model_config.tokens = argv[1];
config.model_config.encoder_param = argv[2];
config.model_config.encoder_bin = argv[3];
Expand All @@ -57,6 +59,7 @@ int32_t main(int32_t argc, char *argv[]) {
if (argc >= 10 && atoi(argv[9]) > 0) {
num_threads = atoi(argv[9]);
}

config.model_config.num_threads = num_threads;
config.model_config.use_vulkan_compute = 0;

Expand All @@ -65,6 +68,7 @@ int32_t main(int32_t argc, char *argv[]) {
if (argc >= 11) {
config.decoder_config.decoding_method = argv[10];
}

config.decoder_config.num_active_paths = 4;
config.enable_endpoint = 0;
config.rule1_min_trailing_silence = 2.4;
Expand All @@ -73,16 +77,14 @@ int32_t main(int32_t argc, char *argv[]) {

config.feat_config.sampling_rate = 16000;
config.feat_config.feature_dim = 80;
if(argc >= 12) {
if (argc >= 12) {
config.hotwords_file = argv[11];
} else {
config.hotwords_file = "";
}
if(argc == 13) {

if (argc == 13) {
config.hotwords_score = atof(argv[12]);
} else {
config.hotwords_score = 1.5;
}

SherpaNcnnRecognizer *recognizer = CreateRecognizer(&config);

const char *wav_filename = argv[8];
Expand Down
5 changes: 5 additions & 0 deletions scripts/dotnet/sherpa-ncnn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ public struct OnlineRecognizerConfig
public float Rule1MinTrailingSilence;
public float Rule2MinTrailingSilence;
public float Rule3MinUtteranceLength;

[MarshalAs(UnmanagedType.LPStr)]
public string HotwordsFile;

public float HotwordsScore;
}

// please see
Expand Down
8 changes: 8 additions & 0 deletions scripts/go/sherpa_ncnn.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ type RecognizerConfig struct {
Rule1MinTrailingSilence float32
Rule2MinTrailingSilence float32
Rule3MinUtteranceLength float32

HotwordsFile string
HotwordsScore float32
}

// It contains the recognition result for a online stream.
Expand Down Expand Up @@ -148,6 +151,11 @@ func NewRecognizer(config *RecognizerConfig) *Recognizer {
c.rule2_min_trailing_silence = C.float(config.Rule2MinTrailingSilence)
c.rule3_min_utterance_length = C.float(config.Rule3MinUtteranceLength)

c.hotwords_file = C.CString(config.HotwordsFile)
defer C.free(unsafe.Pointer(c.hotwords_file))

c.hotwords_score = C.float(config.HotwordsScore)

recognizer := &Recognizer{}
recognizer.impl = C.CreateRecognizer(&c)

Expand Down
19 changes: 14 additions & 5 deletions sherpa-ncnn/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ struct SherpaNcnnDisplay {
std::unique_ptr<sherpa_ncnn::Display> impl;
};

// Evaluates to `x` if `x` is truthy (non-zero / non-null), otherwise to `y`.
// Used to substitute a default for fields left zero-initialized by the caller.
//
// Every argument is parenthesized so that expressions containing
// low-precedence operators (e.g. a nested `?:`) expand correctly.
//
// Note: `x` appears twice in the expansion, so it is evaluated twice when it
// is truthy. Do not pass expressions with side effects.
#define SHERPA_NCNN_OR(x, y) ((x) ? (x) : (y))

SherpaNcnnRecognizer *CreateRecognizer(
const SherpaNcnnRecognizerConfig *in_config) {
// model_config
Expand All @@ -56,7 +58,7 @@ SherpaNcnnRecognizer *CreateRecognizer(
config.model_config.use_vulkan_compute =
in_config->model_config.use_vulkan_compute;

int32_t num_threads = in_config->model_config.num_threads;
int32_t num_threads = SHERPA_NCNN_OR(in_config->model_config.num_threads, 1);

config.model_config.encoder_opt.num_threads = num_threads;
config.model_config.decoder_opt.num_threads = num_threads;
Expand All @@ -66,8 +68,9 @@ SherpaNcnnRecognizer *CreateRecognizer(
config.decoder_config.method = in_config->decoder_config.decoding_method;
config.decoder_config.num_active_paths =
in_config->decoder_config.num_active_paths;
config.hotwords_file = in_config->hotwords_file;
config.hotwords_score = in_config->hotwords_score;

config.hotwords_file = SHERPA_NCNN_OR(in_config->hotwords_file, "");
config.hotwords_score = SHERPA_NCNN_OR(in_config->hotwords_score, 1.5);

config.enable_endpoint = in_config->enable_endpoint;

Expand All @@ -80,11 +83,17 @@ SherpaNcnnRecognizer *CreateRecognizer(
config.endpoint_config.rule3.min_utterance_length =
in_config->rule3_min_utterance_length;

config.feat_config.sampling_rate = in_config->feat_config.sampling_rate;
config.feat_config.feature_dim = in_config->feat_config.feature_dim;
config.feat_config.sampling_rate =
SHERPA_NCNN_OR(in_config->feat_config.sampling_rate, 16000);

config.feat_config.feature_dim =
SHERPA_NCNN_OR(in_config->feat_config.feature_dim, 80);

auto recognizer = std::make_unique<sherpa_ncnn::Recognizer>(config);

if (!recognizer->GetModel()) {
NCNN_LOGE("Failed to create the recognizer! Please check your config: %s",
config.ToString().c_str());
return nullptr;
}

Expand Down
8 changes: 4 additions & 4 deletions sherpa-ncnn/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ SHERPA_NCNN_API typedef struct SherpaNcnnRecognizerConfig {
/// this value.
/// Used only when enable_endpoint is not 0.
float rule3_min_utterance_length;
/// hotwords file, each line is a hotword which is segmented into char by space
/// if language is something like CJK, segment manually,
/// if language is something like English, segment by bpe model.

/// Path to the hotwords file. Each line is one hotword, segmented into
/// tokens separated by spaces. For CJK-like languages, segment into
/// characters manually; for English-like languages, segment with the bpe model.
const char *hotwords_file;

/// scale of hotwords, used only when hotwords_file is not empty
Expand Down
8 changes: 7 additions & 1 deletion sherpa-ncnn/csrc/decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ struct DecoderResult {
// used only for modified_beam_search
Hypotheses hyps;
};

class Stream;

class Decoder {
public:
virtual ~Decoder() = default;
Expand Down Expand Up @@ -88,7 +90,11 @@ class Decoder {
* and there are no paddings.
*/
virtual void Decode(ncnn::Mat encoder_out, DecoderResult *result) = 0;
virtual void Decode(ncnn::Mat encoder_out, Stream *s, DecoderResult *result){};

// Overload of Decode() that additionally receives the stream `s`.
//
// This default implementation does no decoding: it logs an error through
// NCNN_LOGE and terminates the process with exit(-1). A subclass that wants
// to be called through this overload has to override it.
virtual void Decode(ncnn::Mat encoder_out, Stream *s, DecoderResult *result) {
NCNN_LOGE("Please override it!");
exit(-1);
}
};

} // namespace sherpa_ncnn
Expand Down
3 changes: 2 additions & 1 deletion sherpa-ncnn/csrc/hypothesis.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <unordered_map>
#include <utility>
#include <vector>

#include "sherpa-ncnn/csrc/context-graph.h"

namespace sherpa_ncnn {
Expand All @@ -43,7 +44,7 @@ struct Hypothesis {

Hypothesis() = default;
Hypothesis(const std::vector<int32_t> &ys, double log_prob,
const ContextState *context_state = nullptr)
const ContextState *context_state = nullptr)
: ys(ys), log_prob(log_prob), context_state(context_state) {}

// If two Hypotheses have the same `Key`, then they contain
Expand Down
9 changes: 3 additions & 6 deletions sherpa-ncnn/csrc/modified-beam-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
result->num_trailing_blanks = hyp.num_trailing_blanks;
}


void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
DecoderResult *result) {
int32_t context_size = model_->ContextSize();
Expand All @@ -205,7 +204,6 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
std::vector<Hypothesis> prev = cur.GetTopK(num_active_paths_, true);
cur.Clear();


ncnn::Mat decoder_input = BuildDecoderInput(prev);
ncnn::Mat decoder_out;
if (t == 0 && prev.size() == 1 && prev[0].ys.size() == context_size &&
Expand All @@ -218,14 +216,13 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,

// decoder_out.w == decoder_dim
// decoder_out.h == num_active_paths
ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t));
ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t));

ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out);
// joiner_out.w == vocab_size
// joiner_out.h == num_active_paths
LogSoftmax(&joiner_out);


float *p_joiner_out = joiner_out;

for (int32_t i = 0; i != joiner_out.h; ++i) {
Expand Down Expand Up @@ -255,8 +252,8 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
new_hyp.num_trailing_blanks = 0;
new_hyp.timestamps.push_back(t + frame_offset);
if (s != nullptr && s->GetContextGraph() != nullptr) {
auto context_res = s->GetContextGraph()->ForwardOneStep(
context_state, new_token);
auto context_res =
s->GetContextGraph()->ForwardOneStep(context_state, new_token);
context_score = context_res.first;
new_hyp.context_state = context_res.second;
}
Expand Down
Loading

0 comments on commit 401de81

Please sign in to comment.