Skip to content

Commit

Permalink
Add two-pass speech recognition android demo
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Sep 10, 2023
1 parent d2990b9 commit 7d66bf4
Show file tree
Hide file tree
Showing 6 changed files with 419 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import android.widget.Button
import android.widget.TextView
import androidx.appcompat.app.AppCompatActivity
import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.*
import kotlin.concurrent.thread

private const val TAG = "sherpa-onnx"
Expand All @@ -25,7 +24,8 @@ class MainActivity : AppCompatActivity() {
// If there is no GPU and useGPU is true, we won't use GPU
private val useGPU: Boolean = true

private lateinit var model: SherpaOnnx
private lateinit var onlineRecognizer: SherpaOnnx
private lateinit var offlineRecognizer: SherpaOnnxOffline
private var audioRecord: AudioRecord? = null
private lateinit var recordButton: Button
private lateinit var textView: TextView
Expand All @@ -35,6 +35,8 @@ class MainActivity : AppCompatActivity() {
private val sampleRateInHz = 16000
private val channelConfig = AudioFormat.CHANNEL_IN_MONO

private var samplesBuffer = arrayListOf<FloatArray>()

// Note: We don't use AudioFormat.ENCODING_PCM_FLOAT
// since the AudioRecord.read(float[]) needs API level >= 23
// but we are targeting API level >= 21
Expand Down Expand Up @@ -69,9 +71,13 @@ class MainActivity : AppCompatActivity() {

ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)

Log.i(TAG, "Start to initialize model")
initModel()
Log.i(TAG, "Finished initializing model")
Log.i(TAG, "Start to initialize first-pass recognizer")
initOnlineRecognizer()
Log.i(TAG, "Finished initializing first-pass recognizer")

Log.i(TAG, "Start to initialize second-pass recognizer")
initOfflineRecognizer()
Log.i(TAG, "Finished initializing second-pass recognizer")

recordButton = findViewById(R.id.record_button)
recordButton.setOnClickListener { onclick() }
Expand All @@ -91,7 +97,8 @@ class MainActivity : AppCompatActivity() {
audioRecord!!.startRecording()
recordButton.setText(R.string.stop)
isRecording = true
model.reset(true)
onlineRecognizer.reset(true)
samplesBuffer.clear()
textView.text = ""
lastText = ""
idx = 0
Expand Down Expand Up @@ -121,29 +128,38 @@ class MainActivity : AppCompatActivity() {
val ret = audioRecord?.read(buffer, 0, buffer.size)
if (ret != null && ret > 0) {
val samples = FloatArray(ret) { buffer[it] / 32768.0f }
model.acceptWaveform(samples, sampleRate=sampleRateInHz)
while (model.isReady()) {
model.decode()
samplesBuffer.add(samples)

onlineRecognizer.acceptWaveform(samples, sampleRate = sampleRateInHz)
while (onlineRecognizer.isReady()) {
onlineRecognizer.decode()
}
runOnUiThread {
val isEndpoint = model.isEndpoint()
val text = model.text

if(text.isNotBlank()) {
if (lastText.isBlank()) {
textView.text = "${idx}: ${text}"
} else {
textView.text = "${lastText}\n${idx}: ${text}"
}
val isEndpoint = onlineRecognizer.isEndpoint()
var textToDisplay = lastText

var text = onlineRecognizer.text
if (text.isNotBlank()) {
if (lastText.isBlank()) {
// textView.text = "${idx}: ${text}"
textToDisplay = "${idx}: ${text}"
} else {
textToDisplay = "${lastText}\n${idx}: ${text}"
}
}

if (isEndpoint) {
onlineRecognizer.reset()

if (isEndpoint) {
model.reset()
if (text.isNotBlank()) {
lastText = "${lastText}\n${idx}: ${text}"
idx += 1
}
if (text.isNotBlank()) {
text = runSecondPass()
lastText = "${lastText}\n${idx}: ${text}"
idx += 1
}
samplesBuffer.clear()
}

runOnUiThread {
textView.text = textToDisplay.lowercase()
}
}
}
Expand Down Expand Up @@ -173,23 +189,59 @@ class MainActivity : AppCompatActivity() {
return true
}

private fun initModel() {
private fun initOnlineRecognizer() {
// Please change getModelConfig() to add new models
// See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// for a list of available models
val type = 0
println("Select model type ${type}")
val firstType = 1
println("Select model type ${firstType} for the first pass")
val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
modelConfig = getModelConfig(type = type)!!,
lmConfig = getOnlineLMConfig(type = type),
modelConfig = getModelConfig(type = firstType)!!,
endpointConfig = getEndpointConfig(),
enableEndpoint = true,
)

model = SherpaOnnx(
onlineRecognizer = SherpaOnnx(
assetManager = application.assets,
config = config,
)
}

private fun initOfflineRecognizer() {
// Please change getOfflineModelConfig() to add new models
// See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// for a list of available models
val secondType = 1
println("Select model type ${secondType} for the second pass")

val config = OfflineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
modelConfig = getOfflineModelConfig(type = secondType)!!,
)

offlineRecognizer = SherpaOnnxOffline(
assetManager = application.assets,
config = config,
)
}

private fun runSecondPass(): String {
var totalSamples = 0
for (a in samplesBuffer) {
totalSamples += a.size
}
var i = 0

val samples = FloatArray(totalSamples)

// todo(fangjun): Make it more efficient
for (a in samplesBuffer) {
for (s in a) {
samples[i] = s
i += 1
}
}
return offlineRecognizer.decode(samples, sampleRateInHz)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ data class EndpointRule(

data class EndpointConfig(
var rule1: EndpointRule = EndpointRule(false, 2.4f, 0.0f),
var rule2: EndpointRule = EndpointRule(true, 1.4f, 0.0f),
var rule2: EndpointRule = EndpointRule(true, 1.2f, 0.0f),
var rule3: EndpointRule = EndpointRule(false, 0.0f, 20.0f)
)

Expand Down Expand Up @@ -48,13 +48,47 @@ data class FeatureConfig(
data class OnlineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineModelConfig,
var lmConfig: OnlineLMConfig,
var lmConfig: OnlineLMConfig = OnlineLMConfig(),
var endpointConfig: EndpointConfig = EndpointConfig(),
var enableEndpoint: Boolean = true,
var decodingMethod: String = "greedy_search",
var maxActivePaths: Int = 4,
)

data class OfflineTransducerModelConfig(
var encoder: String = "",
var decoder: String = "",
var joiner: String = "",
)

data class OfflineParaformerModelConfig(
var model: String = "",
)

data class OfflineWhisperModelConfig(
var encoder: String = "",
var decoder: String = "",
)

data class OfflineModelConfig(
var transducer: OfflineTransducerModelConfig = OfflineTransducerModelConfig(),
var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(),
var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(),
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
var modelType: String = "",
var tokens: String,
)

data class OfflineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OfflineModelConfig,
// var lmConfig: OfflineLMConfig(), // TODO(fangjun): enable it
var decodingMethod: String = "greedy_search",
var maxActivePaths: Int = 4,
)

class SherpaOnnx(
assetManager: AssetManager? = null,
var config: OnlineRecognizerConfig,
Expand Down Expand Up @@ -111,6 +145,46 @@ class SherpaOnnx(
}
}

class SherpaOnnxOffline(
assetManager: AssetManager? = null,
var config: OfflineRecognizerConfig,
) {
private val ptr: Long

init {
if (assetManager != null) {
ptr = new(assetManager, config)
} else {
ptr = newFromFile(config)
}
}

protected fun finalize() {
delete(ptr)
}

fun decode(samples: FloatArray, sampleRate: Int) = decode(ptr, samples, sampleRate)

private external fun delete(ptr: Long)

private external fun new(
assetManager: AssetManager,
config: OfflineRecognizerConfig,
): Long

private external fun newFromFile(
config: OfflineRecognizerConfig,
): Long

private external fun decode(ptr: Long, samples: FloatArray, sampleRate: Int): String

companion object {
init {
System.loadLibrary("sherpa-onnx-jni")
}
}
}

fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
}
Expand All @@ -129,6 +203,10 @@ by following the code)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-zh-14m-2023-02-23
encoder/joiner int8, decoder float32
1 - csukuangfj/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17 (English)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-en-20m-2023-02-17-english
encoder/joiner int8, decoder fp32
*/
fun getModelConfig(type: Int): OnlineModelConfig? {
when (type) {
Expand All @@ -144,8 +222,21 @@ fun getModelConfig(type: Int): OnlineModelConfig? {
modelType = "zipformer",
)
}

1 -> {
val modelDir = "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17"
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "zipformer",
)
}
}
return null;
return null
}

/*
Expand All @@ -171,7 +262,7 @@ fun getOnlineLMConfig(type: Int): OnlineLMConfig {
)
}
}
return OnlineLMConfig();
return OnlineLMConfig()
}

fun getEndpointConfig(): EndpointConfig {
Expand All @@ -181,3 +272,71 @@ fun getEndpointConfig(): EndpointConfig {
rule3 = EndpointRule(false, 0.0f, 20.0f)
)
}

/*
Please see
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models.
We only add a few here. Please change the following code
to add your own. (It should be straightforward to add a new model
by following the code)
@param type
0 - csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 (Chinese)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese
int8
1 - sherpa-onnx-whisper-tiny.en
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en
encoder int8, decoder int8
2 - pkufool/icefall-asr-zipformer-wenetspeech-20230615 (Chinese)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#pkufool-icefall-asr-zipformer-wenetspeech-20230615-chinese
encoder/joiner int8, decoder fp32
*/
fun getOfflineModelConfig(type: Int): OfflineModelConfig? {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-paraformer-zh-2023-03-28"
return OfflineModelConfig(
paraformer = OfflineParaformerModelConfig(
model = "$modelDir/model.int8.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "paraformer",
)
}

1 -> {
val modelDir = "sherpa-onnx-whisper-tiny.en"
return OfflineModelConfig(
whisper = OfflineWhisperModelConfig(
encoder = "$modelDir/tiny.en-encoder.int8.onnx",
decoder = "$modelDir/tiny.en-decoder.int8.onnx",
),
tokens = "$modelDir/tiny.en-tokens.txt",
modelType = "whisper",
)
}

2 -> {
val modelDir = "icefall-asr-zipformer-wenetspeech-20230615"
return OfflineModelConfig(
transducer = OfflineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-12-avg-4.int8.onnx",
decoder = "$modelDir/decoder-epoch-12-avg-4.onnx",
joiner = "$modelDir/joiner-epoch-12-avg-4.int8.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "zipformer",
)
}

}
return null
}
Loading

0 comments on commit 7d66bf4

Please sign in to comment.