Skip to content

Commit

Permalink
Support c-api (#317)
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool authored Sep 18, 2023
1 parent bd18e07 commit a5d1c90
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 7 deletions.
4 changes: 2 additions & 2 deletions c-api-examples/Makefile
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@

CUR_DIR :=$(shell pwd)

CFLAGS := -I ../
CFLAGS := -I ../ -I ../build/_deps/cargs-src/include/
LDFLAGS := -L ../build/lib
LDFLAGS += -L ../build/_deps/onnxruntime-src/lib
LDFLAGS += -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-native-fbank-core
LDFLAGS += -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-native-fbank-core -lcargs
LDFLAGS += -Wl,-rpath,${CUR_DIR}/../build/lib
LDFLAGS += -Wl,-rpath,${CUR_DIR}/../build/_deps/onnxruntime-src/lib

Expand Down
22 changes: 21 additions & 1 deletion c-api-examples/decode-file-c-api.c
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,21 @@ static struct cag_option options[] = {
.access_name = "decoding-method",
.value_name = "decoding-method",
.description =
"Decoding method: greedy_search (default), modified_beam_search"}};
"Decoding method: greedy_search (default), modified_beam_search"},
{.identifier = 'f',
.access_letters = NULL,
.access_name = "hotwords-file",
.value_name = "hotwords-file",
.description = "The file containing hotwords, one words/phrases per line, "
"and for each phrase the bpe/cjkchar are separated by a "
"space. For example: ▁HE LL O ▁WORLD, 你 好 世 界"},
{.identifier = 's',
.access_letters = NULL,
.access_name = "hotwords-score",
.value_name = "hotwords-score",
.description = "The bonus score for each token in hotwords. Used only "
"when decoding_method is modified_beam_search"},
};

const char *kUsage =
"\n"
Expand Down Expand Up @@ -130,6 +144,12 @@ int32_t main(int32_t argc, char *argv[]) {
case 'm':
config.decoding_method = value;
break;
case 'f':
config.hotwords_file = value;
break;
case 's':
config.hotwords_score = atof(value);
break;
case 'h': {
fprintf(stderr, "%s\n", kUsage);
exit(0);
Expand Down
22 changes: 18 additions & 4 deletions c-api-examples/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,22 @@ if [ ! -f ./decode-file-c-api ]; then
fi

./decode-file-c-api \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
--tokens=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
--encoder=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
--decoder=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
--joiner=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav

# Run with hotwords

echo "礼 拜 二" > hotwords.txt

./decode-file-c-api \
--tokens=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
--encoder=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
--decoder=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
--joiner=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
--hotwords-file=hotwords.txt \
--hotwords-score=1.5 \
--decoding-method=modified_beam_search \
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav
8 changes: 8 additions & 0 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
recognizer_config.endpoint_config.rule3.min_utterance_length =
SHERPA_ONNX_OR(config->rule3_min_utterance_length, 20);

recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, "");
recognizer_config.hotwords_score =
SHERPA_ONNX_OR(config->hotwords_score, 1.5);

if (config->model_config.debug) {
fprintf(stderr, "%s\n", recognizer_config.ToString().c_str());
}
Expand Down Expand Up @@ -297,6 +301,10 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
recognizer_config.max_active_paths =
SHERPA_ONNX_OR(config->max_active_paths, 4);

recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, "");
recognizer_config.hotwords_score =
SHERPA_ONNX_OR(config->hotwords_score, 1.5);

if (config->model_config.debug) {
fprintf(stderr, "%s\n", recognizer_config.ToString().c_str());
}
Expand Down
12 changes: 12 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig {
/// this value.
/// Used only when enable_endpoint is not 0.
float rule3_min_utterance_length;

/// Path to the hotwords file: one word/phrase per line, with the
/// bpe/cjkchar tokens of each phrase separated by spaces.
const char *hotwords_file;

/// Bonus score for each token in hotwords. Used only when the decoding
/// method is modified_beam_search.
float hotwords_score;
} SherpaOnnxOnlineRecognizerConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult {
Expand Down Expand Up @@ -335,6 +341,12 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig {

const char *decoding_method;
int32_t max_active_paths;

/// Path to the hotwords file: one word/phrase per line, with the
/// bpe/cjkchar tokens of each phrase separated by spaces.
const char *hotwords_file;

/// Bonus score for each token in hotwords. Used only when
/// decoding_method is modified_beam_search.
float hotwords_score;
} SherpaOnnxOfflineRecognizerConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizer
Expand Down

0 comments on commit a5d1c90

Please sign in to comment.