Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add java api for hotwords #319

Merged
merged 4 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ data class OnlineRecognizerConfig(
var enableEndpoint: Boolean = true,
var decodingMethod: String = "greedy_search",
var maxActivePaths: Int = 4,
var hotwordsFile: String = "",
var hotwordsScore: Float = 1.5f,
)

class SherpaOnnx(
Expand Down
13 changes: 8 additions & 5 deletions java-api-examples/Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

ENTRY_POINT = ./

LIB_SRC_DIR := ../sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx
Expand Down Expand Up @@ -65,18 +64,22 @@ clean:
mkdir -p ./lib

runfile:
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav

java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile
runhotwords:
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile hotwords.wav

runmic:

java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic

runsrv:
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer ../build/lib/libsherpa-onnx-jni.so ./modeltest.cfg
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer $(shell pwd)/../build/lib/libsherpa-onnx-jni.so ./modeltest.cfg

runclient:
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient ../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32

runclienthotwords:
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./hotwords.wav 32

buildlib: $(LIB_FILES:.java=.class)

Expand Down
2 changes: 2 additions & 0 deletions java-api-examples/modelconfig.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ num_threads=4
enable_endpoint_detection=true
decoding_method=modified_beam_search
max_active_paths=4
hotwords_file=
hotwords_score=1.5
lm_model=
lm_scale=0.5
model_type=zipformer
Expand Down
8 changes: 8 additions & 0 deletions java-api-examples/runtest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ if [ ! -d $repo ];then
git lfs pull --include "*.onnx"
ls -lh *.onnx
popd
ln -s $repo/test_wavs/0.wav hotwords.wav

fi

log $(pwd)
Expand Down Expand Up @@ -64,3 +66,9 @@ cd ../java-api-examples
make all

make runfile

echo "礼 拜 二" > hotwords.txt

sed -i 's/hotwords_file=/hotwords_file=hotwords.txt/g' modeltest.cfg

make runhotwords
6 changes: 5 additions & 1 deletion java-api-examples/src/DecodeFile.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public void initModelWithPara() {
float rule3MinUtteranceLength = 20F;
String decodingMethod = "greedy_search";
int maxActivePaths = 4;
String hotwordsFile = "";
float hotwordsScore = 1.5F;
String lm_model = "";
float lm_scale = 0.5F;
String modelType = "zipformer";
Expand All @@ -69,6 +71,8 @@ public void initModelWithPara() {
lm_model,
lm_scale,
maxActivePaths,
hotwordsFile,
hotwordsScore,
modelType);
streamObj = rcgOjb.createStream();
} catch (Exception e) {
Expand Down Expand Up @@ -158,7 +162,7 @@ public static void main(String[] args) {
try {
String appDir = System.getProperty("user.dir");
System.out.println("appdir=" + appDir);
String fileName = appDir + "/test.wav";
String fileName = appDir + "/" + args[0];
String cfgPath = appDir + "/modeltest.cfg";
String soPath = appDir + "/../build/lib/libsherpa-onnx-jni.so";
OnlineRecognizer.setSoPath(soPath);
Expand Down
18 changes: 8 additions & 10 deletions java-api-examples/src/websocketsrv/AsrWebsocketServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,6 @@ public void onMessage(WebSocket conn, ByteBuffer blob) {
}
}



public boolean streamQueueFind(WebSocket conn) {
return streamQueue.contains(conn);
}
Expand All @@ -151,16 +149,16 @@ public void initModelWithCfg(Map<String, String> cfgMap, String cfgPath) {

rcgOjb = new OnlineRecognizer(cfgPath);
// size of stream thread pool
int streamThreadNum = Integer.valueOf(cfgMap.get("stream_thread_num"));
int streamThreadNum = Integer.valueOf(cfgMap.getOrDefault("stream_thread_num", "16"));
// size of decoder thread pool
int decoderThreadNum = Integer.valueOf(cfgMap.get("decoder_thread_num"));
int decoderThreadNum = Integer.valueOf(cfgMap.getOrDefault("decoder_thread_num", "16"));

// time(ms) idle for decoder thread when no job
int decoderTimeIdle = Integer.valueOf(cfgMap.get("decoder_time_idle"));
int decoderTimeIdle = Integer.valueOf(cfgMap.getOrDefault("decoder_time_idle", "200"));
// size of streams for parallel decoding
int parallelDecoderNum = Integer.valueOf(cfgMap.get("parallel_decoder_num"));
int parallelDecoderNum = Integer.valueOf(cfgMap.getOrDefault("parallel_decoder_num", "16"));
// time(ms) out for connection data
int deocderTimeOut = Integer.valueOf(cfgMap.get("deocder_time_out"));
int deocderTimeOut = Integer.valueOf(cfgMap.getOrDefault("deocder_time_out", "30000"));

// create stream threads
for (int i = 0; i < streamThreadNum; i++) {
Expand Down Expand Up @@ -218,13 +216,13 @@ public static void main(String[] args) throws InterruptedException, IOException

String soPath = args[0];
String cfgPath = args[1];

OnlineRecognizer.setSoPath(soPath);
logger.info("readProperties");
Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath);
int port = Integer.valueOf(cfgMap.get("port"));
int port = Integer.valueOf(cfgMap.getOrDefault("port", "8890"));

int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num"));
int connectionThreadNum = Integer.valueOf(cfgMap.getOrDefault("connection_thread_num", "16"));
AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum);
logger.info("initModelWithCfg");
s.initModelWithCfg(cfgMap, cfgPath);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,48 +44,60 @@ public class OnlineRecognizer {
public OnlineRecognizer(String modelCfgPath) {
Map<String, String> proMap = this.readProperties(modelCfgPath);
try {
int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim());
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
this.sampleRate = sampleRate;
EndpointRule rule1 =
new EndpointRule(
false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F);
false,
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
0.0F);
EndpointRule rule2 =
new EndpointRule(
true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F);
true,
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
0.0F);
EndpointRule rule3 =
new EndpointRule(
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
false,
0.0F,
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);

OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(proMap.get("encoder").trim(), proMap.get("decoder").trim());
OnlineParaformerModelConfig modelParaCfg =
new OnlineParaformerModelConfig(
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig(
proMap.get("encoder").trim(),
proMap.get("decoder").trim(),
proMap.get("joiner").trim());
proMap.getOrDefault("encoder", "").trim(),
proMap.getOrDefault("decoder", "").trim(),
proMap.getOrDefault("joiner", "").trim());
OnlineModelConfig modelCfg =
new OnlineModelConfig(
proMap.get("tokens").trim(),
Integer.parseInt(proMap.get("num_threads").trim()),
proMap.getOrDefault("tokens", "").trim(),
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
false,
proMap.get("model_type").trim(),
proMap.getOrDefault("model_type", "zipformer").trim(),
modelParaCfg,
modelTranCfg);
FeatureConfig featConfig =
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
new FeatureConfig(
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
OnlineLMConfig onlineLmConfig =
new OnlineLMConfig(
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
proMap.getOrDefault("lm_model", "").trim(),
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));

OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig(
featConfig,
modelCfg,
endCfg,
onlineLmConfig,
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
proMap.get("decoding_method").trim(),
Integer.parseInt(proMap.get("max_active_paths").trim()));
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
proMap.getOrDefault("hotwords_file", "").trim(),
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);

Expand All @@ -98,51 +110,61 @@ public OnlineRecognizer(String modelCfgPath) {
public OnlineRecognizer(Object assetManager, String modelCfgPath) {
Map<String, String> proMap = this.readProperties(modelCfgPath);
try {
int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim());
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
this.sampleRate = sampleRate;
EndpointRule rule1 =
new EndpointRule(
false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F);
false,
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
0.0F);
EndpointRule rule2 =
new EndpointRule(
true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F);
true,
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
0.0F);
EndpointRule rule3 =
new EndpointRule(
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
false,
0.0F,
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineParaformerModelConfig modelParaCfg =
new OnlineParaformerModelConfig(
proMap.get("encoder").trim(), proMap.get("decoder").trim());
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig(
proMap.get("encoder").trim(),
proMap.get("decoder").trim(),
proMap.get("joiner").trim());
proMap.getOrDefault("encoder", "").trim(),
proMap.getOrDefault("decoder", "").trim(),
proMap.getOrDefault("joiner", "").trim());

OnlineModelConfig modelCfg =
new OnlineModelConfig(
proMap.get("tokens").trim(),
Integer.parseInt(proMap.get("num_threads").trim()),
proMap.getOrDefault("tokens", "").trim(),
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
false,
proMap.get("model_type").trim(),
proMap.getOrDefault("model_type", "zipformer").trim(),
modelParaCfg,
modelTranCfg);
FeatureConfig featConfig =
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
new FeatureConfig(
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));

OnlineLMConfig onlineLmConfig =
new OnlineLMConfig(
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
proMap.getOrDefault("lm_model", "").trim(),
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));

OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig(
featConfig,
modelCfg,
endCfg,
onlineLmConfig,
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
proMap.get("decoding_method").trim(),
Integer.parseInt(proMap.get("max_active_paths").trim()));
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
proMap.getOrDefault("hotwords_file", "").trim(),
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(assetManager, rcgCfg);

Expand All @@ -168,6 +190,8 @@ public OnlineRecognizer(
String lm_model,
float lm_scale,
int maxActivePaths,
String hotwordsFile,
float hotwordsScore,
String modelType) {
this.sampleRate = sampleRate;
EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
Expand All @@ -189,7 +213,9 @@ public OnlineRecognizer(
onlineLmConfig,
enableEndpointDetection,
decodingMethod,
maxActivePaths);
maxActivePaths,
hotwordsFile,
hotwordsScore);
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
}
Expand All @@ -211,7 +237,6 @@ private Map<String, String> readProperties(String modelCfgPath) {
String key = (String) en.nextElement();
String Property = props.getProperty(key);
proMap.put(key, Property);
// System.out.println(key+"="+Property);
}

} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ public class OnlineRecognizerConfig {
private final boolean enableEndpoint;
private final String decodingMethod;
private final int maxActivePaths;
private final String hotwordsFile;
private final float hotwordsScore;

public OnlineRecognizerConfig(
FeatureConfig featConfig,
Expand All @@ -20,14 +22,18 @@ public OnlineRecognizerConfig(
OnlineLMConfig lmConfig,
boolean enableEndpoint,
String decodingMethod,
int maxActivePaths) {
int maxActivePaths,
String hotwordsFile,
float hotwordsScore) {
this.featConfig = featConfig;
this.modelConfig = modelConfig;
this.endpointConfig = endpointConfig;
this.lmConfig = lmConfig;
this.enableEndpoint = enableEndpoint;
this.decodingMethod = decodingMethod;
this.maxActivePaths = maxActivePaths;
this.hotwordsFile = hotwordsFile;
this.hotwordsScore = hotwordsScore;
}

public OnlineLMConfig getLmConfig() {
Expand Down
18 changes: 18 additions & 0 deletions sherpa-onnx/jni/jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,15 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(cls, "maxActivePaths", "I");
ans.max_active_paths = env->GetIntField(config, fid);

fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.hotwords_file = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(cls, "hotwordsScore", "F");
ans.hotwords_score = env->GetFloatField(config, fid);

//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
Expand Down Expand Up @@ -293,6 +302,15 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(cls, "maxActivePaths", "I");
ans.max_active_paths = env->GetIntField(config, fid);

fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.hotwords_file = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(cls, "hotwordsScore", "F");
ans.hotwords_score = env->GetFloatField(config, fid);

//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
Expand Down
Loading