Skip to content

Commit

Permalink
Add jni interface and kotlin API examples for TTS. (#381)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 23, 2023
1 parent b582f6c commit 0fdb204
Show file tree
Hide file tree
Showing 15 changed files with 454 additions and 37 deletions.
3 changes: 3 additions & 0 deletions kotlin-api-examples/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
hs_err*
main.jar
vits-zh-aishell3
22 changes: 22 additions & 0 deletions kotlin-api-examples/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,28 @@ package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager

// Entry point of the example program: runs testTts() and then testAsr().
fun main() {
testTts()
testAsr()
}

/**
 * Text-to-speech example.
 *
 * Builds an [OfflineTts] from the files under `./vits-zh-aishell3/`, generates
 * audio for a fixed sentence with speaker id 99 at speed 1.2, and saves the
 * result to `99.wav`.
 */
fun testTts() {
    // Assemble the configuration from the innermost object outwards.
    val vitsConfig = OfflineTtsVitsModelConfig(
        model = "./vits-zh-aishell3/vits-aishell3.onnx",
        lexicon = "./vits-zh-aishell3/lexicon.txt",
        tokens = "./vits-zh-aishell3/tokens.txt",
    )
    val modelConfig = OfflineTtsModelConfig(
        vits = vitsConfig,
        numThreads = 1,
        debug = true,
    )
    val ttsConfig = OfflineTtsConfig(model = modelConfig)

    // No AssetManager is passed, so the engine is created via its file-based path.
    val engine = OfflineTts(config = ttsConfig)
    val generated = engine.generate(text = "林美丽最美丽!", sid = 99, speed = 1.2f)
    generated.save(filename = "99.wav")
}

fun testAsr() {
var featConfig = FeatureConfig(
sampleRate = 16000,
featureDim = 80,
Expand Down
112 changes: 112 additions & 0 deletions kotlin-api-examples/Tts.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright (c) 2023 Xiaomi Corporation
package com.k2fsa.sherpa.onnx

import android.content.res.AssetManager

/**
 * Settings for a VITS text-to-speech model.
 *
 * @property model Path to the model file, e.g. `./vits-zh-aishell3/vits-aishell3.onnx`.
 * @property lexicon Path to the lexicon file, e.g. `./vits-zh-aishell3/lexicon.txt`.
 * @property tokens Path to the tokens file, e.g. `./vits-zh-aishell3/tokens.txt`.
 * @property noiseScale Defaults to 0.667. NOTE(review): presumably a VITS inference
 *   hyper-parameter consumed by the native side; its exact effect is not visible here.
 * @property noiseScaleW Defaults to 0.8. NOTE(review): same caveat as [noiseScale].
 * @property lengthScale Defaults to 1.0. NOTE(review): same caveat as [noiseScale].
 */
data class OfflineTtsVitsModelConfig(
var model: String,
var lexicon: String,
var tokens: String,
var noiseScale: Float = 0.667f,
var noiseScaleW: Float = 0.8f,
var lengthScale: Float = 1.0f,
)

/**
 * Model-level settings for offline text-to-speech.
 *
 * @property vits Configuration of the VITS model to use.
 * @property numThreads Defaults to 1. NOTE(review): presumably the number of threads
 *   used by the native inference runtime; confirm against the JNI implementation.
 * @property debug Defaults to false. NOTE(review): presumably enables extra diagnostic
 *   output on the native side; confirm.
 * @property provider Defaults to `"cpu"`. NOTE(review): presumably selects the
 *   execution provider of the inference runtime; confirm the accepted values.
 */
data class OfflineTtsModelConfig(
var vits: OfflineTtsVitsModelConfig,
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
)

/**
 * Top-level configuration handed to [OfflineTts].
 *
 * @property model The model configuration used to create the engine.
 */
data class OfflineTtsConfig(
var model: OfflineTtsModelConfig,
)

/**
 * Audio produced by [OfflineTts.generate].
 *
 * @property samples 1-D array of audio samples.
 * @property sampleRate Sample rate of [samples].
 */
class GeneratedAudio(
    val samples: FloatArray,
    val sampleRate: Int,
) {
    /**
     * Writes this audio to [filename] through the native `saveImpl` routine.
     *
     * @return the Boolean reported by the native call.
     */
    fun save(filename: String): Boolean {
        return saveImpl(
            filename = filename,
            samples = samples,
            sampleRate = sampleRate,
        )
    }

    // Implemented in native code; the name is part of the JNI binding and must not change.
    private external fun saveImpl(
        filename: String,
        samples: FloatArray,
        sampleRate: Int
    ): Boolean
}

/**
 * Kotlin wrapper around a native offline text-to-speech engine accessed through JNI.
 *
 * A native engine is created when the object is constructed. It can be released
 * explicitly with [free] and re-created with [allocate]; [finalize] releases it
 * as a last resort if the caller never did.
 *
 * @param assetManager When non-null, the engine is created with the asset-manager
 *   based native constructor; otherwise the file-based one is used.
 * @property config Configuration used whenever the native engine is (re-)created.
 */
class OfflineTts(
    assetManager: AssetManager? = null,
    var config: OfflineTtsConfig,
) {
    // Handle of the native engine; 0 means "no engine allocated".
    private var ptr: Long

    init {
        ptr = create(assetManager)
    }

    /**
     * Generates audio for [text].
     *
     * @param text Text to synthesize.
     * @param sid Value forwarded to the native engine as the speaker id.
     * @param speed Value forwarded to the native engine as the speech speed.
     * @return the generated samples together with their sample rate.
     * @throws IllegalStateException if the engine has been released with [free]
     *   and not re-created with [allocate].
     */
    fun generate(
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): GeneratedAudio {
        // Passing a zero handle to native code would dereference a null pointer.
        check(ptr != 0L) { "OfflineTts has been freed; call allocate() before generate()" }

        val objArray = generateImpl(ptr, text = text, sid = sid, speed = speed)
        return GeneratedAudio(
            samples = objArray[0] as FloatArray,
            sampleRate = objArray[1] as Int
        )
    }

    /**
     * Re-creates the native engine if it is not currently allocated.
     * Does nothing when an engine already exists.
     */
    fun allocate(assetManager: AssetManager? = null) {
        if (ptr == 0L) {
            ptr = create(assetManager)
        }
    }

    /** Releases the native engine. Safe to call more than once. */
    fun free() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0L
        }
    }

    // Goes through free() so the native delete is skipped once the handle has been
    // released, and so the handle is zeroed after the final release.
    protected fun finalize() {
        free()
    }

    // Creates a native engine from the current config, choosing the native
    // constructor based on whether an AssetManager is available.
    private fun create(assetManager: AssetManager?): Long =
        if (assetManager != null) {
            new(assetManager, config)
        } else {
            newFromFile(config)
        }

    // NOTE: names and signatures of the external functions below are part of the
    // JNI binding and must stay in sync with the native library.

    private external fun new(
        assetManager: AssetManager,
        config: OfflineTtsConfig,
    ): Long

    private external fun newFromFile(
        config: OfflineTtsConfig,
    ): Long

    private external fun delete(ptr: Long)

    // The returned array has two entries:
    //  - the first entry is a 1-D float array containing audio samples.
    //    Each sample is normalized to the range [-1, 1]
    //  - the second entry is the sample rate
    external fun generateImpl(
        ptr: Long,
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): Array<Any>

    companion object {
        init {
            // Load the native library that provides the external functions above.
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
27 changes: 15 additions & 12 deletions kotlin-api-examples/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,24 @@

set -e


cd ..
mkdir -p build
cd build

cmake \
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
-DBUILD_SHARED_LIBS=ON \
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
-DSHERPA_ONNX_ENABLE_JNI=ON \
..

make -j4
ls -lh lib
if [ ! -f ../build/lib/libsherpa-onnx-jni.dylib ]; then
cmake \
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
-DBUILD_SHARED_LIBS=ON \
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
-DSHERPA_ONNX_ENABLE_JNI=ON \
..

make -j4
ls -lh lib
fi

export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH

Expand All @@ -31,7 +34,7 @@ if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
fi

kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt
kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt Tts.kt

ls -lh main.jar

Expand Down
56 changes: 45 additions & 11 deletions sherpa-onnx/csrc/lexicon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
#include <sstream>
#include <utility>

#if __ANDROID_API__ >= 9
#include <strstream>

#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"

namespace sherpa_onnx {
Expand All @@ -22,11 +30,9 @@ static void ToLowerCase(std::string *in_out) {

// Note: We don't use SymbolTable here since tokens may contain a blank
// in the first column
static std::unordered_map<std::string, int32_t> ReadTokens(
const std::string &tokens) {
static std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) {
std::unordered_map<std::string, int32_t> token2id;

std::ifstream is(tokens);
std::string line;

std::string sym;
Expand Down Expand Up @@ -80,11 +86,43 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
bool debug /*= false*/)
: debug_(debug) {
InitLanguage(language);
InitTokens(tokens);
InitLexicon(lexicon);

{
std::ifstream is(tokens);
InitTokens(is);
}

{
std::ifstream is(lexicon);
InitLexicon(is);
}

InitPunctuations(punctuations);
}

#if __ANDROID_API__ >= 9
// Android variant of the constructor: the tokens and lexicon files are read
// through the asset manager `mgr` instead of being opened from the filesystem.
Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
                 const std::string &tokens, const std::string &punctuations,
                 const std::string &language, bool debug /*= false*/)
    : debug_(debug) {
  InitLanguage(language);

  {
    // Read the whole tokens file into memory and parse it from that buffer.
    auto token_data = ReadFile(mgr, tokens);
    std::istrstream token_stream(token_data.data(), token_data.size());
    InitTokens(token_stream);
  }

  {
    // Same approach for the lexicon file.
    auto lexicon_data = ReadFile(mgr, lexicon);
    std::istrstream lexicon_stream(lexicon_data.data(), lexicon_data.size());
    InitLexicon(lexicon_stream);
  }

  InitPunctuations(punctuations);
}
#endif

std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
const std::string &text) const {
switch (language_) {
Expand Down Expand Up @@ -192,9 +230,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
return ans;
}

void Lexicon::InitTokens(const std::string &tokens) {
token2id_ = ReadTokens(tokens);
}
void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); }

void Lexicon::InitLanguage(const std::string &_lang) {
std::string lang(_lang);
Expand All @@ -209,9 +245,7 @@ void Lexicon::InitLanguage(const std::string &_lang) {
}
}

void Lexicon::InitLexicon(const std::string &lexicon) {
std::ifstream is(lexicon);

void Lexicon::InitLexicon(std::istream &is) {
std::string word;
std::vector<std::string> token_list;
std::string line;
Expand Down
16 changes: 14 additions & 2 deletions sherpa-onnx/csrc/lexicon.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@
#define SHERPA_ONNX_CSRC_LEXICON_H_

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

namespace sherpa_onnx {

// TODO(fangjun): Refactor it to an abstract class
Expand All @@ -20,6 +26,12 @@ class Lexicon {
const std::string &punctuations, const std::string &language,
bool debug = false);

#if __ANDROID_API__ >= 9
// Android-only overload: takes the same arguments as the constructor above
// plus an asset manager `mgr`.
// NOTE(review): presumably `lexicon` and `tokens` are then resolved through
// `mgr` rather than the filesystem; confirm against the definition.
Lexicon(AAssetManager *mgr, const std::string &lexicon,
const std::string &tokens, const std::string &punctuations,
const std::string &language, bool debug = false);
#endif

std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;

private:
Expand All @@ -30,8 +42,8 @@ class Lexicon {
const std::string &text) const;

void InitLanguage(const std::string &lang);
void InitTokens(const std::string &tokens);
void InitLexicon(const std::string &lexicon);
void InitTokens(std::istream &is);
void InitLexicon(std::istream &is);
void InitPunctuations(const std::string &punctuations);

private:
Expand Down
8 changes: 8 additions & 0 deletions sherpa-onnx/csrc/offline-tts-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,12 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
return std::make_unique<OfflineTtsVitsImpl>(config);
}

#if __ANDROID_API__ >= 9
// Android-only factory: creates the TTS implementation from `config`, passing
// the asset manager `mgr` through to the implementation's constructor.
// Currently always returns an OfflineTtsVitsImpl.
std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
AAssetManager *mgr, const OfflineTtsConfig &config) {
// TODO(fangjun): Support other types
return std::make_unique<OfflineTtsVitsImpl>(mgr, config);
}
#endif

} // namespace sherpa_onnx
10 changes: 10 additions & 0 deletions sherpa-onnx/csrc/offline-tts-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
#include <memory>
#include <string>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/offline-tts.h"

namespace sherpa_onnx {
Expand All @@ -18,6 +23,11 @@ class OfflineTtsImpl {

static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);

#if __ANDROID_API__ >= 9
// Android-only overload of Create() that additionally takes an asset manager.
static std::unique_ptr<OfflineTtsImpl> Create(AAssetManager *mgr,
const OfflineTtsConfig &config);
#endif

virtual GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
float speed = 1.0) const = 0;
};
Expand Down
13 changes: 13 additions & 0 deletions sherpa-onnx/csrc/offline-tts-vits-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-tts-impl.h"
Expand All @@ -24,6 +29,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
model_->Punctuations(), model_->Language(),
config.model.debug) {}

#if __ANDROID_API__ >= 9
// Android-only constructor: builds the model and the lexicon using the asset
// manager `mgr`.
// NOTE(review): lexicon_ is initialized from model_->Punctuations() and
// model_->Language(), so this relies on model_ being declared before lexicon_
// in the class (members are initialized in declaration order) -- confirm.
OfflineTtsVitsImpl(AAssetManager *mgr, const OfflineTtsConfig &config)
: model_(std::make_unique<OfflineTtsVitsModel>(mgr, config.model)),
lexicon_(mgr, config.model.vits.lexicon, config.model.vits.tokens,
model_->Punctuations(), model_->Language(),
config.model.debug) {}
#endif

GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
float speed = 1.0) const override {
int32_t num_speakers = model_->NumSpeakers();
Expand Down
Loading

0 comments on commit 0fdb204

Please sign in to comment.