Support playing audio as it is being generated for Android #477

Merged
merged 3 commits on Dec 9, 2023
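
In brief: this PR wires the TTS engine's new per-chunk callback into an Android AudioTrack opened in MODE_STREAM, so audio starts playing while generation is still running, and the full result is still saved to a WAV file afterwards. A condensed sketch of the call pattern the diff introduces in MainActivity (tts, track, and callback are the fields and methods added below; the text is a placeholder):

    // The native side calls this once per generated chunk; WRITE_BLOCKING
    // applies back-pressure so no samples are dropped.
    private fun callback(samples: FloatArray) {
        track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
    }

    // Generation runs off the UI thread; the callback fires on this worker thread.
    Thread {
        val audio = tts.generateWithCallback(
            text = "Hello",
            sid = 0,
            speed = 1.0f,
            callback = this::callback
        )
        audio.save(application.filesDir.absolutePath + "/generated.wav")
    }.start()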
@@ -1,7 +1,7 @@
package com.k2fsa.sherpa.onnx

import android.content.res.AssetManager
import android.media.MediaPlayer
import android.media.*
import android.net.Uri
import android.os.Bundle
import android.util.Log
@@ -23,6 +23,10 @@ class MainActivity : AppCompatActivity() {
private lateinit var generate: Button
private lateinit var play: Button

// see
// https://developer.android.com/reference/kotlin/android/media/AudioTrack
private lateinit var track: AudioTrack

override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
@@ -31,6 +35,10 @@ class MainActivity : AppCompatActivity() {
initTts()
Log.i(TAG, "Finish initializing TTS")

Log.i(TAG, "Start to initialize AudioTrack")
initAudioTrack()
Log.i(TAG, "Finish initializing AudioTrack")

text = findViewById(R.id.text)
sid = findViewById(R.id.sid)
speed = findViewById(R.id.speed)
@@ -51,6 +59,33 @@ class MainActivity : AppCompatActivity() {
play.isEnabled = false
}

private fun initAudioTrack() {
val sampleRate = tts.sampleRate()
val bufLength = (sampleRate * 0.1).toInt()
Log.i(TAG, "sampleRate: ${sampleRate}, buffLength: ${bufLength}")

val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
.setUsage(AudioAttributes.USAGE_MEDIA)
.build()

val format = AudioFormat.Builder()
.setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
.setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
.setSampleRate(sampleRate)
.build()

track = AudioTrack(
attr, format, bufLength, AudioTrack.MODE_STREAM,
AudioManager.AUDIO_SESSION_ID_GENERATE
)
track.play()
}

// this function is called from C++
private fun callback(samples: FloatArray) {
track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
}

private fun onClickGenerate() {
val sidInt = sid.text.toString().toIntOrNull()
if (sidInt == null || sidInt < 0) {
@@ -79,16 +114,28 @@ class MainActivity : AppCompatActivity() {
return
}

play.isEnabled = false
val audio = tts.generate(text = textStr, sid = sidInt, speed = speedFloat)
track.pause()
track.flush()
track.play()

val filename = application.filesDir.absolutePath + "/generated.wav"
val ok = audio.samples.size > 0 && audio.save(filename)
if (ok) {
play.isEnabled = true
// Play automatically after generation
onClickPlay()
}
play.isEnabled = false
Thread {
val audio = tts.generateWithCallback(
text = textStr,
sid = sidInt,
speed = speedFloat,
callback = this::callback
)

val filename = application.filesDir.absolutePath + "/generated.wav"
val ok = audio.samples.size > 0 && audio.save(filename)
if (ok) {
runOnUiThread {
play.isEnabled = true
track.stop()
}
}
}.start()
}

private fun onClickPlay() {
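A note on the AudioTrack setup above: initAudioTrack() sizes the buffer as sampleRate * 0.1, i.e. the number of samples in 100 ms, while the third argument of this AudioTrack constructor is documented in bytes. A common alternative, shown here only as a sketch that is not part of this PR, is to query the platform's minimum streaming buffer size and build the track with AudioTrack.Builder:

    import android.media.AudioAttributes
    import android.media.AudioFormat
    import android.media.AudioTrack

    // Hypothetical variant of initAudioTrack(): let the platform pick the
    // minimum buffer size (in bytes) for a mono float PCM stream.
    fun buildStreamingTrack(sampleRate: Int): AudioTrack {
        val minBufBytes = AudioTrack.getMinBufferSize(
            sampleRate,
            AudioFormat.CHANNEL_OUT_MONO,
            AudioFormat.ENCODING_PCM_FLOAT
        )
        return AudioTrack.Builder()
            .setAudioAttributes(
                AudioAttributes.Builder()
                    .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
                    .setUsage(AudioAttributes.USAGE_MEDIA)
                    .build()
            )
            .setAudioFormat(
                AudioFormat.Builder()
                    .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
                    .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
                    .setSampleRate(sampleRate)
                    .build()
            )
            .setTransferMode(AudioTrack.MODE_STREAM)
            .setBufferSizeInBytes(minBufBytes)
            .build()
    }

Either way, the combination of MODE_STREAM and blocking writes from the callback is what lets playback begin before generation finishes.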
@@ -54,6 +54,8 @@ class OfflineTts(
}
}

fun sampleRate() = getSampleRate(ptr)

fun generate(
text: String,
sid: Int = 0,
@@ -66,6 +68,19 @@
)
}

fun generateWithCallback(
text: String,
sid: Int = 0,
speed: Float = 1.0f,
callback: (samples: FloatArray) -> Unit
): GeneratedAudio {
val objArray = generateWithCallbackImpl(ptr, text = text, sid = sid, speed = speed, callback = callback)
return GeneratedAudio(
samples = objArray[0] as FloatArray,
sampleRate = objArray[1] as Int
)
}

fun allocate(assetManager: AssetManager? = null) {
if (ptr == 0L) {
if (assetManager != null) {
@@ -97,6 +112,7 @@ class OfflineTts(
): Long

private external fun delete(ptr: Long)
private external fun getSampleRate(ptr: Long): Int

// The returned array has two entries:
// - the first entry is a 1-D float array containing audio samples.
@@ -109,6 +125,14 @@
speed: Float = 1.0f
): Array<Any>

external fun generateWithCallbackImpl(
ptr: Long,
text: String,
sid: Int = 0,
speed: Float = 1.0f,
callback: (samples: FloatArray) -> Unit
): Array<Any>

companion object {
init {
System.loadLibrary("sherpa-onnx-jni")
6 changes: 5 additions & 1 deletion kotlin-api-examples/Main.kt
@@ -2,6 +2,10 @@ package com.k2fsa.sherpa.onnx

import android.content.res.AssetManager

fun callback(samples: FloatArray): Unit {
println("callback got called with ${samples.size} samples");
}

fun main() {
testTts()
testAsr()
@@ -22,7 +26,7 @@ fun testTts() {
)
)
val tts = OfflineTts(config=config)
val audio = tts.generate(text="“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”")
val audio = tts.generateWithCallback(text="“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”", callback=::callback)
audio.save(filename="test-en.wav")
}

102 changes: 51 additions & 51 deletions scripts/apk/generate-tts-apk-script.py
@@ -172,57 +172,57 @@ def get_vits_models() -> List[TtsModel]:
lang="zh",
rule_fsts="vits-zh-aishell3/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-doom",
model_name="doom.onnx",
lang="zh",
rule_fsts="vits-zh-hf-doom/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-echo",
model_name="echo.onnx",
lang="zh",
rule_fsts="vits-zh-hf-echo/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-zenyatta",
model_name="zenyatta.onnx",
lang="zh",
rule_fsts="vits-zh-hf-zenyatta/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-abyssinvoker",
model_name="abyssinvoker.onnx",
lang="zh",
rule_fsts="vits-zh-hf-abyssinvoker/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-keqing",
model_name="keqing.onnx",
lang="zh",
rule_fsts="vits-zh-hf-keqing/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-eula",
model_name="eula.onnx",
lang="zh",
rule_fsts="vits-zh-hf-eula/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-bronya",
model_name="bronya.onnx",
lang="zh",
rule_fsts="vits-zh-hf-bronya/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-theresa",
model_name="theresa.onnx",
lang="zh",
rule_fsts="vits-zh-hf-theresa/rule.fst",
),
# TtsModel(
# model_dir="vits-zh-hf-doom",
# model_name="doom.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-doom/rule.fst",
# ),
# TtsModel(
# model_dir="vits-zh-hf-echo",
# model_name="echo.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-echo/rule.fst",
# ),
# TtsModel(
# model_dir="vits-zh-hf-zenyatta",
# model_name="zenyatta.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-zenyatta/rule.fst",
# ),
# TtsModel(
# model_dir="vits-zh-hf-abyssinvoker",
# model_name="abyssinvoker.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-abyssinvoker/rule.fst",
# ),
# TtsModel(
# model_dir="vits-zh-hf-keqing",
# model_name="keqing.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-keqing/rule.fst",
# ),
# TtsModel(
# model_dir="vits-zh-hf-eula",
# model_name="eula.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-eula/rule.fst",
# ),
# TtsModel(
# model_dir="vits-zh-hf-bronya",
# model_name="bronya.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-bronya/rule.fst",
# ),
# TtsModel(
# model_dir="vits-zh-hf-theresa",
# model_name="theresa.onnx",
# lang="zh",
# rule_fsts="vits-zh-hf-theresa/rule.fst",
# ),
# English (US)
TtsModel(model_dir="vits-vctk", model_name="vits-vctk.onnx", lang="en"),
TtsModel(model_dir="vits-ljs", model_name="vits-ljs.onnx", lang="en"),
# TtsModel(model_dir="vits-ljs", model_name="vits-ljs.onnx", lang="en"),
# fmt: on
]

@@ -238,8 +238,8 @@ def main():
template = environment.from_string(s)
d = dict()

# all_model_list = get_vits_models()
all_model_list = get_piper_models()
all_model_list = get_vits_models()
all_model_list += get_piper_models()
all_model_list += get_coqui_models()

num_models = len(all_model_list)
57 changes: 53 additions & 4 deletions sherpa-onnx/jni/jni.cc
@@ -11,13 +11,15 @@
// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
#include "jni.h" // NOLINT

#include <fstream>
#include <functional>
#include <strstream>
#include <utility>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include <fstream>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
@@ -502,11 +504,14 @@ class SherpaOnnxOfflineTts {
explicit SherpaOnnxOfflineTts(const OfflineTtsConfig &config)
: tts_(config) {}

GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
float speed = 1.0) const {
return tts_.Generate(text, sid, speed);
GeneratedAudio Generate(
const std::string &text, int64_t sid = 0, float speed = 1.0,
std::function<void(const float *, int32_t)> callback = nullptr) const {
return tts_.Generate(text, sid, speed, callback);
}

int32_t SampleRate() const { return tts_.SampleRate(); }

private:
OfflineTts tts_;
};
@@ -628,6 +633,13 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_delete(
delete reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getSampleRate(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
return reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr)
->SampleRate();
}

// see
// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
static jobject NewInteger(JNIEnv *env, int32_t value) {
@@ -663,6 +675,43 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateImpl(JNIEnv *env, jobject /*obj*/,
return obj_arr;
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring text, jint sid,
jfloat speed, jobject callback) {
const char *p_text = env->GetStringUTFChars(text, nullptr);
SHERPA_ONNX_LOGE("string is: %s", p_text);

std::function<void(const float *, int32_t)> callback_wrapper =
[env, callback](const float *samples, int32_t n) {
jclass cls = env->GetObjectClass(callback);
jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V");

jfloatArray samples_arr = env->NewFloatArray(n);
env->SetFloatArrayRegion(samples_arr, 0, n, samples);
env->CallVoidMethod(callback, mid, samples_arr);
};

auto audio =
reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr)->Generate(
p_text, sid, speed, callback_wrapper);

jfloatArray samples_arr = env->NewFloatArray(audio.samples.size());
env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(),
audio.samples.data());

jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
2, env->FindClass("java/lang/Object"), nullptr);

env->SetObjectArrayElement(obj_arr, 0, samples_arr);
env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate));

env->ReleaseStringUTFChars(text, p_text);

return obj_arr;
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(
JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
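Two properties of the JNI bridge above matter when writing a callback on the Kotlin side: the callback is invoked synchronously, on the thread that called generateWithCallback (which is why MainActivity runs generation inside a Thread and re-enables the Play button via runOnUiThread), and every invocation receives a freshly allocated FloatArray (NewFloatArray plus SetFloatArrayRegion), so chunks can be retained without copying. A small sketch under those assumptions; the helper name and the output filename are hypothetical:

    // Collect the streamed chunks and still keep the complete result.
    fun generateAndCollect(
        tts: OfflineTts,
        text: String
    ): Pair<GeneratedAudio, List<FloatArray>> {
        val chunks = mutableListOf<FloatArray>()
        val audio = tts.generateWithCallback(
            text = text,
            callback = { samples: FloatArray ->
                // Each chunk is a fresh array, so storing the reference is safe.
                chunks.add(samples)
            }
        )
        // audio.samples should be the concatenation of the collected chunks.
        audio.save(filename = "streamed.wav")
        return Pair(audio, chunks)
    }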