diff --git a/.github/scripts/test-offline-transducer.sh b/.github/scripts/test-offline-transducer.sh
index 9fc1107ab..8cb3c5e8d 100755
--- a/.github/scripts/test-offline-transducer.sh
+++ b/.github/scripts/test-offline-transducer.sh
@@ -123,3 +123,30 @@ time $EXE \
   $repo/test_wavs/8k.wav
 
 rm -rf $repo
+
+log "------------------------------------------------------------"
+log "Run Paraformer (Chinese) with timestamps"
+log "------------------------------------------------------------"
+
+repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-09-14
+log "Start testing ${repo_url}"
+repo=$(basename $repo_url)
+log "Download pretrained model and test-data from $repo_url"
+
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo
+git lfs pull --include "*.onnx"
+ls -lh *.onnx
+popd
+
+time $EXE \
+  --tokens=$repo/tokens.txt \
+  --paraformer=$repo/model.int8.onnx \
+  --num-threads=2 \
+  --decoding-method=greedy_search \
+  $repo/test_wavs/0.wav \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/2.wav \
+  $repo/test_wavs/8k.wav
+
+rm -rf $repo
diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc
index 6a989d542..520c47542 100644
--- a/sherpa-onnx/c-api/c-api.cc
+++ b/sherpa-onnx/c-api/c-api.cc
@@ -353,11 +353,22 @@ SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
   std::copy(text.begin(), text.end(), const_cast<char *>(r->text));
   const_cast<char *>(r->text)[text.size()] = 0;
 
+  if (!result.timestamps.empty()) {
+    r->timestamps = new float[result.timestamps.size()];
+    std::copy(result.timestamps.begin(), result.timestamps.end(),
+              r->timestamps);
+    r->count = result.timestamps.size();
+  } else {
+    r->timestamps = nullptr;
+    r->count = 0;
+  }
+
   return r;
 }
 
 void DestroyOfflineRecognizerResult(
     const SherpaOnnxOfflineRecognizerResult *r) {
   delete[] r->text;
+  delete[] r->timestamps;
   delete r;
 }
diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h
index d669bce27..71aa56426 100644
--- a/sherpa-onnx/c-api/c-api.h
+++ b/sherpa-onnx/c-api/c-api.h
@@ -408,6 +408,14 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(
 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
   const char *text;
+
+  // Pointer to a contiguous block of memory that holds the timestamps.
+  //
+  // It is NULL if the model does not support timestamps.
+  float *timestamps;
+
+  // Number of entries in timestamps.
+  int32_t count;
 
   // TODO(fangjun): Add more fields
 } SherpaOnnxOfflineRecognizerResult;
diff --git a/sherpa-onnx/csrc/offline-paraformer-decoder.h b/sherpa-onnx/csrc/offline-paraformer-decoder.h
index 1b783e88d..2effdfe1b 100644
--- a/sherpa-onnx/csrc/offline-paraformer-decoder.h
+++ b/sherpa-onnx/csrc/offline-paraformer-decoder.h
@@ -14,6 +14,11 @@ namespace sherpa_onnx {
 struct OfflineParaformerDecoderResult {
   /// The decoded token IDs
   std::vector<int64_t> tokens;
+
+  // The start time of each token in seconds.
+  //
+  // len(timestamps) == len(tokens)
+  std::vector<float> timestamps;
 };
 
 class OfflineParaformerDecoder {
@@ -28,7 +33,8 @@ class OfflineParaformerDecoder {
   * @return Return a vector of size `N` containing the decoded results.
   */
  virtual std::vector<OfflineParaformerDecoderResult> Decode(
-      Ort::Value log_probs, Ort::Value token_num) = 0;
+      Ort::Value log_probs, Ort::Value token_num,
+      Ort::Value us_cif_peak = Ort::Value(nullptr)) = 0;
 };
 
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc
index 619b33495..c1d89a3ab 100644
--- a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc
+++ b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc
@@ -5,13 +5,18 @@
 #include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
 
 #include <algorithm>
+#include <utility>
 #include <vector>
 
+#include "sherpa-onnx/csrc/macros.h"
+
 namespace sherpa_onnx {
 
 std::vector<OfflineParaformerDecoderResult>
-OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs,
-                                             Ort::Value /*token_num*/) {
+OfflineParaformerGreedySearchDecoder::Decode(
+    Ort::Value log_probs, Ort::Value /*token_num*/,
+    Ort::Value us_cif_peak /*= Ort::Value(nullptr)*/) {
   std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
   int32_t batch_size = shape[0];
   int32_t num_tokens = shape[1];
@@ -25,12 +30,43 @@ OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs,
     for (int32_t k = 0; k != num_tokens; ++k) {
       auto max_idx = static_cast<int64_t>(
           std::distance(p, std::max_element(p, p + vocab_size)));
-      if (max_idx == eos_id_) break;
+      if (max_idx == eos_id_) {
+        break;
+      }
       results[i].tokens.push_back(max_idx);
       p += vocab_size;
     }
+
+    if (us_cif_peak) {
+      int32_t dim = us_cif_peak.GetTensorTypeAndShapeInfo().GetShape()[1];
+
+      const auto *peak = us_cif_peak.GetTensorData<float>() + i * dim;
+      std::vector<float> timestamps;
+      timestamps.reserve(results[i].tokens.size());
+
+      // 10.0: the frame shift is 10 milliseconds
+      // 6: LfrWindowSize
+      // 3: us_cif_peak is upsampled by a factor of 3
+      // 1000: convert milliseconds to seconds
+      float scale = 10.0 * 6 / 3 / 1000;
+
+      for (int32_t k = 0; k != dim; ++k) {
+        if (peak[k] > 1 - 1e-4) {
+          timestamps.push_back(k * scale);
+        }
+      }
+      timestamps.pop_back();
+
+      if (timestamps.size() == results[i].tokens.size()) {
+        results[i].timestamps = std::move(timestamps);
+      } else {
+        SHERPA_ONNX_LOGE("time stamp for batch: %d, %d vs %d", i,
+                         static_cast<int32_t>(results[i].tokens.size()),
+                         static_cast<int32_t>(timestamps.size()));
+      }
+    }
   }
 
   return results;
diff --git a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h
index 1f48e8c84..eba3fc04b 100644
--- a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h
+++ b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h
@@ -17,7 +17,8 @@ class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder {
       : eos_id_(eos_id) {}
 
   std::vector<OfflineParaformerDecoderResult> Decode(
-      Ort::Value log_probs, Ort::Value /*token_num*/) override;
+      Ort::Value log_probs, Ort::Value token_num,
+      Ort::Value us_cif_peak = Ort::Value(nullptr)) override;
 
  private:
   int32_t eos_id_;
diff --git a/sherpa-onnx/csrc/offline-paraformer-model.cc b/sherpa-onnx/csrc/offline-paraformer-model.cc
index 614b2cc61..ce1851062 100644
--- a/sherpa-onnx/csrc/offline-paraformer-model.cc
+++ b/sherpa-onnx/csrc/offline-paraformer-model.cc
@@ -6,6 +6,7 @@
 
 #include <string>
 #include <utility>
+#include <vector>
 
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
@@ -36,16 +37,13 @@ class OfflineParaformerModel::Impl {
   }
 #endif
 
-  std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
-                                            Ort::Value features_length) {
+  std::vector<Ort::Value> Forward(Ort::Value features,
+                                  Ort::Value features_length) {
     std::array<Ort::Value, 2> inputs = {std::move(features),
                                         std::move(features_length)};
 
-    auto out =
-        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
-                   output_names_ptr_.data(), output_names_ptr_.size());
-
-    return {std::move(out[0]), std::move(out[1])};
+    return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
+                      output_names_ptr_.data(), output_names_ptr_.size());
   }
 
   int32_t VocabSize() const { return vocab_size_; }
@@ -119,7 +117,7 @@ OfflineParaformerModel::OfflineParaformerModel(AAssetManager *mgr,
 
 OfflineParaformerModel::~OfflineParaformerModel() = default;
 
-std::pair<Ort::Value, Ort::Value> OfflineParaformerModel::Forward(
+std::vector<Ort::Value> OfflineParaformerModel::Forward(
     Ort::Value features, Ort::Value features_length) {
   return impl_->Forward(std::move(features), std::move(features_length));
 }
diff --git a/sherpa-onnx/csrc/offline-paraformer-model.h b/sherpa-onnx/csrc/offline-paraformer-model.h
index 1fe7e84d5..d5c2329f6 100644
--- a/sherpa-onnx/csrc/offline-paraformer-model.h
+++ b/sherpa-onnx/csrc/offline-paraformer-model.h
@@ -5,7 +5,6 @@
 #define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_
 
 #include <memory>
-#include <utility>
 #include <vector>
 
 #if __ANDROID_API__ >= 9
@@ -35,13 +34,17 @@ class OfflineParaformerModel {
   *                        valid frames in `features` before padding.
   *                        Its dtype is int32_t.
   *
-  * @return Return a pair containing:
+  * @return Return a vector containing:
   *           - log_probs: A 3-D tensor of shape (N, T', vocab_size)
   *           - token_num: A 1-D tensor of shape (N, T') containing number
   *             of valid tokens in each utterance. Its dtype is int64_t.
+  *         If the model supports timestamps, then there are two additional
+  *         outputs:
+  *           - us_alphas
+  *           - us_cif_peak
   */
-  std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
-                                            Ort::Value features_length);
+  std::vector<Ort::Value> Forward(Ort::Value features,
+                                  Ort::Value features_length);
 
   /** Return the vocabulary size of the model */
diff --git a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h
index 9d7186124..3c96f03bc 100644
--- a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h
+++ b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h
@@ -31,6 +31,7 @@ static OfflineRecognitionResult Convert(
     const OfflineParaformerDecoderResult &src, const SymbolTable &sym_table) {
   OfflineRecognitionResult r;
   r.tokens.reserve(src.tokens.size());
+  r.timestamps = src.timestamps;
 
   std::string text;
 
@@ -184,7 +185,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
     // i.e., -23.025850929940457f
     Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);
 
-    std::pair<Ort::Value, Ort::Value> t{nullptr, nullptr};
+    std::vector<Ort::Value> t;
     try {
       t = model_->Forward(std::move(x), std::move(x_length));
     } catch (const Ort::Exception &ex) {
@@ -193,7 +194,13 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
       return;
     }
 
-    auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
+    std::vector<OfflineParaformerDecoderResult> results;
+    if (t.size() == 2) {
+      results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
+    } else {
+      results =
+          decoder_->Decode(std::move(t[0]), std::move(t[1]), std::move(t[3]));
+    }
 
     for (int32_t i = 0; i != n; ++i) {
       auto r = Convert(results[i], symbol_table_);
diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift
index 524499b2a..72c497cf8 100644
--- a/swift-api-examples/SherpaOnnx.swift
+++ b/swift-api-examples/SherpaOnnx.swift
@@ -349,6 +349,23 @@ class SherpaOnnxOfflineRecongitionResult {
     return String(cString: result.pointee.text)
   }
 
+  var count: Int32 {
+    return result.pointee.count
+  }
+
+  var timestamps: [Float] {
+    if let p = result.pointee.timestamps {
+      var timestamps: [Float] = []
+      for index in 0..<count {
+        timestamps.append(p[Int(index)])
+      }
+      return timestamps
+    } else {
+      let timestamps: [Float] = []
+      return timestamps
+    }
+  }
+
   init(result: UnsafePointer<SherpaOnnxOfflineRecognizerResult>!) {
     self.result = result
   }
diff --git a/swift-api-examples/decode-file-non-streaming.swift b/swift-api-examples/decode-file-non-streaming.swift
index a9485c5fd..6d0b4e8b5 100644
--- a/swift-api-examples/decode-file-non-streaming.swift
+++ b/swift-api-examples/decode-file-non-streaming.swift
@@ -13,21 +13,45 @@ extension AVAudioPCMBuffer {
 }
 
 func run() {
-  let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
-  let decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"
-  let tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"
-  let whisperConfig = sherpaOnnxOfflineWhisperModelConfig(
-    encoder: encoder,
-    decoder: decoder
-  )
+  var recognizer: SherpaOnnxOfflineRecognizer
+  var modelConfig: SherpaOnnxOfflineModelConfig
+  var modelType = "whisper"
+  // modelType = "paraformer"
 
-  let modelConfig = sherpaOnnxOfflineModelConfig(
-    tokens: tokens,
-    whisper: whisperConfig,
-    debug: 0,
-    modelType: "whisper"
-  )
+  if modelType == "whisper" {
+    let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
+    let decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"
+    let tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"
+
+    let whisperConfig = sherpaOnnxOfflineWhisperModelConfig(
+      encoder: encoder,
+      decoder: decoder
+    )
+
+    modelConfig = sherpaOnnxOfflineModelConfig(
+      tokens: tokens,
+      whisper: whisperConfig,
+      debug: 0,
+      modelType: "whisper"
+    )
+  } else if modelType == "paraformer" {
+    let model = "./sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx"
+    let tokens = "./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt"
+    let paraformerConfig = sherpaOnnxOfflineParaformerModelConfig(
+      model: model
+    )
+
+    modelConfig = sherpaOnnxOfflineModelConfig(
+      tokens: tokens,
+      paraformer: paraformerConfig,
+      debug: 0,
+      modelType: "paraformer"
+    )
+  } else {
+    print("Unsupported modelType: \(modelType)")
+    return
+  }
 
   let featConfig = sherpaOnnxFeatureConfig(
     sampleRate: 16000,
@@ -38,7 +62,7 @@ func run() {
     modelConfig: modelConfig
   )
 
-  let recognizer = SherpaOnnxOfflineRecognizer(config: &config)
+  recognizer = SherpaOnnxOfflineRecognizer(config: &config)
 
   let filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
   let fileURL: NSURL = NSURL(fileURLWithPath: filePath)
@@ -55,6 +79,10 @@ func run() {
   let array: [Float]! = audioFileBuffer?.array()
   let result = recognizer.decode(samples: array, sampleRate: Int(audioFormat.sampleRate))
   print("\nresult is:\n\(result.text)")
+  if result.timestamps.count != 0 {
+    print("\ntimestamps:\n\(result.timestamps)")
+  }
+
 }
 
 @main
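
Note (not part of the patch): below is a minimal sketch of how a C caller might consume the new timestamps/count fields added to SherpaOnnxOfflineRecognizerResult above. It assumes the result object was obtained from GetOfflineStreamResult() as in the c-api.cc hunk; the helper name PrintResult is hypothetical and not part of the library.

#include <stdio.h>

#include "sherpa-onnx/c-api/c-api.h"

/* Hypothetical helper: print the recognized text and, when available,
 * the per-token start times carried by the result object. */
static void PrintResult(const SherpaOnnxOfflineRecognizerResult *r) {
  printf("text: %s\n", r->text);

  /* timestamps is NULL and count is 0 when the model does not produce
   * the us_cif_peak output, so this check is required. */
  if (r->timestamps != NULL) {
    for (int32_t i = 0; i != r->count; ++i) {
      /* Each entry is the start time of the i-th token, in seconds. */
      printf("token %d starts at %.2f s\n", (int)i, r->timestamps[i]);
    }
  }
}

The caller remains responsible for releasing the result with DestroyOfflineRecognizerResult(), which this patch extends to also delete[] the timestamps array.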