Skip to content

Commit

Permalink
Enable to stop TTS generation (#1041)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Jun 22, 2024
1 parent 96ab843 commit 9dd0e03
Show file tree
Hide file tree
Showing 32 changed files with 248 additions and 69 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ project(sherpa-onnx)
# ./nodejs-addon-examples
# ./dart-api-examples/
# ./sherpa-onnx/flutter/CHANGELOG.md
set(SHERPA_ONNX_VERSION "1.10.0")
set(SHERPA_ONNX_VERSION "1.10.1")

# Disable warning about
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ class MainActivity : AppCompatActivity() {
private lateinit var speed: EditText
private lateinit var generate: Button
private lateinit var play: Button
private lateinit var stop: Button
// Set by the click handlers (onClickGenerate / onClickStop) and read by callback()
// while generation runs on a background Thread. @Volatile guarantees the worker
// sees the write promptly instead of a stale cached value.
@Volatile
private var stopped: Boolean = false
private var mediaPlayer: MediaPlayer? = null

// see
// https://developer.android.com/reference/kotlin/android/media/AudioTrack
Expand All @@ -49,9 +52,11 @@ class MainActivity : AppCompatActivity() {

generate = findViewById(R.id.generate)
play = findViewById(R.id.play)
stop = findViewById(R.id.stop)

generate.setOnClickListener { onClickGenerate() }
play.setOnClickListener { onClickPlay() }
stop.setOnClickListener { onClickStop() }

sid.setText("0")
speed.setText("1.0")
Expand All @@ -70,7 +75,7 @@ class MainActivity : AppCompatActivity() {
AudioFormat.CHANNEL_OUT_MONO,
AudioFormat.ENCODING_PCM_FLOAT
)
Log.i(TAG, "sampleRate: ${sampleRate}, buffLength: ${bufLength}")
Log.i(TAG, "sampleRate: $sampleRate, buffLength: $bufLength")

val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
.setUsage(AudioAttributes.USAGE_MEDIA)
Expand All @@ -90,8 +95,14 @@ class MainActivity : AppCompatActivity() {
}

// this function is called from C++
private fun callback(samples: FloatArray) {
track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
/**
 * Receives a chunk of generated samples and reports whether generation should go on.
 *
 * @param samples audio samples to hand to [track]
 * @return 1 after the samples were written to [track]; 0 (after stopping [track])
 *         once [stopped] has been set
 */
private fun callback(samples: FloatArray): Int {
    // Guard clause: a stop request wins over writing any further audio.
    if (stopped) {
        track.stop()
        return 0
    }

    track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
    return 1
}

private fun onClickGenerate() {
Expand Down Expand Up @@ -127,6 +138,8 @@ class MainActivity : AppCompatActivity() {
track.play()

play.isEnabled = false
generate.isEnabled = false
stopped = false
Thread {
val audio = tts.generateWithCallback(
text = textStr,
Expand All @@ -140,6 +153,7 @@ class MainActivity : AppCompatActivity() {
if (ok) {
runOnUiThread {
play.isEnabled = true
generate.isEnabled = true
track.stop()
}
}
Expand All @@ -148,11 +162,22 @@ class MainActivity : AppCompatActivity() {

private fun onClickPlay() {
val filename = application.filesDir.absolutePath + "/generated.wav"
val mediaPlayer = MediaPlayer.create(
mediaPlayer?.stop()
mediaPlayer = MediaPlayer.create(
applicationContext,
Uri.fromFile(File(filename))
)
mediaPlayer.start()
mediaPlayer?.start()
}

/**
 * Handles a click on the Stop button.
 *
 * Sets [stopped] so the next [callback] invocation returns 0, re-enables the Play and
 * Generate buttons, discards audio queued on [track], and shuts down any active
 * [mediaPlayer].
 */
private fun onClickStop() {
    stopped = true
    play.isEnabled = true
    generate.isEnabled = true

    // pause() followed by flush() drops samples that were written but not yet played.
    track.pause()
    track.flush()

    // stop() alone keeps the player's native resources alive; release() frees them
    // before the reference is dropped, otherwise every Play/Stop cycle leaks a player.
    mediaPlayer?.run {
        stop()
        release()
    }
    mediaPlayer = null
}

private fun initTts() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class OfflineTts(
text: String,
sid: Int = 0,
speed: Float = 1.0f,
callback: (samples: FloatArray) -> Unit
callback: (samples: FloatArray) -> Int
): GeneratedAudio {
val objArray = generateWithCallbackImpl(
ptr,
Expand Down Expand Up @@ -146,7 +146,7 @@ class OfflineTts(
text: String,
sid: Int = 0,
speed: Float = 1.0f,
callback: (samples: FloatArray) -> Unit
callback: (samples: FloatArray) -> Int
): Array<Any>

companion object {
Expand Down
12 changes: 12 additions & 0 deletions android/SherpaOnnxTts/app/src/main/res/layout/activity_main.xml
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,16 @@
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toBottomOf="@id/generate" />

<Button
android:id="@+id/stop"
android:textAllCaps="false"
android:layout_width="match_parent"
android:layout_height="50dp"
android:layout_marginTop="4dp"
android:text="@string/stop"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toBottomOf="@id/play" />

</androidx.constraintlayout.widget.ConstraintLayout>
1 change: 1 addition & 0 deletions android/SherpaOnnxTts/app/src/main/res/values/strings.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
<string name="text_hint">Please input your text here</string>
<string name="generate">Generate</string>
<string name="play">Play</string>
<string name="stop">Stop</string>
</resources>
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class TtsService : TextToSpeechService() {
return
}

val ttsCallback = { floatSamples: FloatArray ->
val ttsCallback: (FloatArray) -> Int = fun(floatSamples): Int {
// convert FloatArray to ByteArray
val samples = floatArrayToByteArray(floatSamples)
val maxBufferSize: Int = callback.maxBufferSize
Expand All @@ -137,6 +137,9 @@ class TtsService : TextToSpeechService() {
offset += bytesToWrite
}

// 1 means to continue
// 0 means to stop
return 1
}

Log.i(TAG, "text: $text")
Expand All @@ -160,4 +163,4 @@ class TtsService : TextToSpeechService() {
}
return byteArray
}
}
}
2 changes: 1 addition & 1 deletion dart-api-examples/non-streaming-asr/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ environment:

# Add regular dependencies here.
dependencies:
sherpa_onnx: ^1.10.0
sherpa_onnx: ^1.10.1
path: ^1.9.0
args: ^2.5.0

Expand Down
2 changes: 1 addition & 1 deletion dart-api-examples/streaming-asr/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ environment:

# Add regular dependencies here.
dependencies:
sherpa_onnx: ^1.10.0
sherpa_onnx: ^1.10.1
path: ^1.9.0
args: ^2.5.0

Expand Down
4 changes: 4 additions & 0 deletions dart-api-examples/tts/bin/piper.dart
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ void main(List<String> arguments) async {
callback: (Float32List samples) {
print('${samples.length} samples received');
// You can play samples in a separate thread/isolate

// 1 means to continue
// 0 means to stop
return 1;
});
tts.free();

Expand Down
2 changes: 1 addition & 1 deletion dart-api-examples/tts/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ environment:

# Add regular dependencies here.
dependencies:
sherpa_onnx: ^1.10.0
sherpa_onnx: ^1.10.1
path: ^1.9.0
args: ^2.5.0

Expand Down
2 changes: 1 addition & 1 deletion dart-api-examples/vad/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ environment:
sdk: ^3.4.0

dependencies:
sherpa_onnx: ^1.10.0
sherpa_onnx: ^1.10.1
path: ^1.9.0
args: ^2.5.0

Expand Down
4 changes: 4 additions & 0 deletions dotnet-examples/offline-tts-play/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ private static void Run(Options options)
Marshal.Copy(samples, data, 0, n);

dataItems.Add(data);

// 1 means to keep generating
// 0 means to stop generating
return 1;
};

bool playFinished = false;
Expand Down
42 changes: 41 additions & 1 deletion kotlin-api-examples/test_tts.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,46 @@ fun testTts() {
println("Saved to test-en.wav")
}

fun callback(samples: FloatArray): Unit {
/*
1. Unzip test_tts.jar
2.
javap ./com/k2fsa/sherpa/onnx/Test_ttsKt\$testTts\$audio\$1.class
3. It prints:
Compiled from "test_tts.kt"
final class com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 extends kotlin.jvm.internal.FunctionReferenceImpl implements kotlin.jvm.functions.Function1<float[], java.lang.Integer> {
public static final com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 INSTANCE;
com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1();
public final java.lang.Integer invoke(float[]);
public java.lang.Object invoke(java.lang.Object);
static {};
}
4.
javap -s ./com/k2fsa/sherpa/onnx/Test_ttsKt\$testTts\$audio\$1.class
5. It prints
Compiled from "test_tts.kt"
final class com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 extends kotlin.jvm.internal.FunctionReferenceImpl implements kotlin.jvm.functions.Function1<float[], java.lang.Integer> {
public static final com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 INSTANCE;
descriptor: Lcom/k2fsa/sherpa/onnx/Test_ttsKt$testTts$audio$1;
com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1();
descriptor: ()V
public final java.lang.Integer invoke(float[]);
descriptor: ([F)Ljava/lang/Integer;
public java.lang.Object invoke(java.lang.Object);
descriptor: (Ljava/lang/Object;)Ljava/lang/Object;
static {};
descriptor: ()V
}
*/
/**
 * Example generation callback: prints how many samples it received.
 *
 * @param samples the samples passed to this invocation
 * @return always 1 here (1 means to continue, 0 means to stop)
 */
fun callback(samples: FloatArray): Int {
    println("callback got called with ${samples.size} samples")

    val keepGenerating = 1 // use 0 instead to stop
    return keepGenerating
}
Binary file modified mfc-examples/NonStreamingTextToSpeech/NonStreamingTextToSpeech.rc
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ static bool g_started = false;
static bool g_stopped = false;
static bool g_killed = false;

static void AudioGeneratedCallback(const float *s, int32_t n) {
static int32_t AudioGeneratedCallback(const float *s, int32_t n) {
if (n > 0) {
Samples samples;
samples.data = std::vector<float>{s, s + n};
Expand All @@ -66,6 +66,10 @@ static void AudioGeneratedCallback(const float *s, int32_t n) {
g_buffer.samples.push(std::move(samples));
g_started = true;
}
if (g_killed) {
return 0;
}
return 1;
}

static int PlayCallback(const void * /*in*/, void *out,
Expand Down Expand Up @@ -324,6 +328,7 @@ BEGIN_MESSAGE_MAP(CNonStreamingTextToSpeechDlg, CDialogEx)
ON_WM_PAINT()
ON_WM_QUERYDRAGICON()
ON_BN_CLICKED(IDOK, &CNonStreamingTextToSpeechDlg::OnBnClickedOk)
ON_BN_CLICKED(IDC_STOP, &CNonStreamingTextToSpeechDlg::OnBnClickedStop)
END_MESSAGE_MAP()


Expand Down Expand Up @@ -492,11 +497,18 @@ void CNonStreamingTextToSpeechDlg::Init() {
if (tts_) {
SherpaOnnxDestroyOfflineTts(tts_);
}
if (generate_thread_ && generate_thread_->joinable()) {
generate_thread_->join();
}

if (play_thread_ && play_thread_->joinable()) {
play_thread_->join();
}
}


static std::string ToString(const CString &s) {
CT2CA pszConvertedAnsiString( s);
CT2CA pszConvertedAnsiString(s);
return std::string(pszConvertedAnsiString);
}

Expand All @@ -510,7 +522,7 @@ void CNonStreamingTextToSpeechDlg::OnBnClickedOk() {
}

speed_.GetWindowText(s);
float speed = static_cast<float>(_ttof(s));
float speed = static_cast<float>(_ttof(s));
if (speed < 0) {
AfxMessageBox(Utf8ToUtf16("Please input a valid speed").c_str(), MB_OK);
return;
Expand Down Expand Up @@ -541,28 +553,40 @@ void CNonStreamingTextToSpeechDlg::OnBnClickedOk() {
// for simplicity
play_thread_ = std::make_unique<std::thread>(StartPlayback, SherpaOnnxOfflineTtsSampleRate(tts_));

generate_btn_.EnableWindow(FALSE);

const SherpaOnnxGeneratedAudio *audio =
SherpaOnnxOfflineTtsGenerateWithCallback(tts_, ss.c_str(), speaker_id, speed, &AudioGeneratedCallback);

generate_btn_.EnableWindow(TRUE);
if (generate_thread_ && generate_thread_->joinable()) {
generate_thread_->join();
}

output_filename_.GetWindowText(s);
std::string filename = ToString(s);

int ok = SherpaOnnxWriteWave(audio->samples, audio->n, audio->sample_rate,
filename.c_str());
generate_thread_ = std::make_unique<std::thread>([ss, this,filename, speaker_id, speed]() {
std::string text = ss;

SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio);
// generate_btn_.EnableWindow(FALSE);

if (ok) {
// AfxMessageBox(Utf8ToUtf16(std::string("Saved to ") + filename + " successfully").c_str(), MB_OK);
AppendLineToMultilineEditCtrl(my_hint_, std::string("Saved to ") + filename + " successfully");
} else {
// AfxMessageBox(Utf8ToUtf16(std::string("Failed to save to ") + filename).c_str(), MB_OK);
AppendLineToMultilineEditCtrl(my_hint_, std::string("Failed to saved to ") + filename);
}
const SherpaOnnxGeneratedAudio *audio =
SherpaOnnxOfflineTtsGenerateWithCallback(tts_, text.c_str(), speaker_id, speed, &AudioGeneratedCallback);
// generate_btn_.EnableWindow(TRUE);
g_stopped = true;

int ok = SherpaOnnxWriteWave(audio->samples, audio->n, audio->sample_rate,
filename.c_str());

SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio);

if (ok) {
// AfxMessageBox(Utf8ToUtf16(std::string("Saved to ") + filename + " successfully").c_str(), MB_OK);

// AppendLineToMultilineEditCtrl(my_hint_, std::string("Saved to ") + filename + " successfully");
} else {
// AfxMessageBox(Utf8ToUtf16(std::string("Failed to save to ") + filename).c_str(), MB_OK);

// AppendLineToMultilineEditCtrl(my_hint_, std::string("Failed to saved to ") + filename);
}
});

//CDialogEx::OnOK();
}

// Handler for the Stop button (wired to IDC_STOP in the message map).
// Sets g_killed, which makes AudioGeneratedCallback return 0 on its next call.
// NOTE(review): no reset of g_killed back to false is visible in this view, and it is a
// plain bool written here but read inside the callback that runs in generate_thread_ —
// confirm it is cleared before the next generation and that the unsynchronized access is OK.
void CNonStreamingTextToSpeechDlg::OnBnClickedStop() { g_killed = true; }
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,8 @@ class CNonStreamingTextToSpeechDlg : public CDialogEx
private:
Microphone mic_;
std::unique_ptr<std::thread> play_thread_;
std::unique_ptr<std::thread> generate_thread_;

public:
afx_msg void OnBnClickedStop();
};
Loading

0 comments on commit 9dd0e03

Please sign in to comment.