Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add timestamps for offline paraformer #310

Merged
merged 4 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .github/scripts/test-offline-transducer.sh
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,30 @@ time $EXE \
$repo/test_wavs/8k.wav

rm -rf $repo

log "------------------------------------------------------------"
log "Run Paraformer (Chinese) with timestamps"
log "------------------------------------------------------------"

repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-09-14
log "Start testing ${repo_url}"
repo=$(basename $repo_url)
log "Download pretrained model and test-data from $repo_url"

GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
ls -lh *.onnx
popd

time $EXE \
--tokens=$repo/tokens.txt \
--paraformer=$repo/model.int8.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav \
$repo/test_wavs/8k.wav

rm -rf $repo
11 changes: 11 additions & 0 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -353,11 +353,22 @@ SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
std::copy(text.begin(), text.end(), const_cast<char *>(r->text));
const_cast<char *>(r->text)[text.size()] = 0;

if (!result.timestamps.empty()) {
r->timestamps = new float[result.timestamps.size()];
std::copy(result.timestamps.begin(), result.timestamps.end(),
r->timestamps);
r->count = result.timestamps.size();
} else {
r->timestamps = nullptr;
r->count = 0;
}

return r;
}

void DestroyOfflineRecognizerResult(
const SherpaOnnxOfflineRecognizerResult *r) {
delete[] r->text;
delete[] r->timestamps;
delete r;
}
8 changes: 8 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,14 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(

SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
const char *text;

// Pointer to continuous memory which holds timestamps
//
// It is NULL if the model does not support timestamps
float *timestamps;

// number of entries in timestamps
int32_t count;
// TODO(fangjun): Add more fields
} SherpaOnnxOfflineRecognizerResult;

Expand Down
8 changes: 7 additions & 1 deletion sherpa-onnx/csrc/offline-paraformer-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ namespace sherpa_onnx {
struct OfflineParaformerDecoderResult {
/// The decoded token IDs
std::vector<int64_t> tokens;

// it contains the start time of each token in seconds
//
// len(timestamps) == len(tokens)
std::vector<float> timestamps;
};

class OfflineParaformerDecoder {
Expand All @@ -28,7 +33,8 @@ class OfflineParaformerDecoder {
* @return Return a vector of size `N` containing the decoded results.
*/
virtual std::vector<OfflineParaformerDecoderResult> Decode(
Ort::Value log_probs, Ort::Value token_num) = 0;
Ort::Value log_probs, Ort::Value token_num,
Ort::Value us_cif_peak = Ort::Value(nullptr)) = 0;
};

} // namespace sherpa_onnx
Expand Down
42 changes: 39 additions & 3 deletions sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

std::vector<OfflineParaformerDecoderResult>
OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs,
Ort::Value /*token_num*/) {
OfflineParaformerGreedySearchDecoder::Decode(
Ort::Value log_probs, Ort::Value /*token_num*/,
Ort::Value us_cif_peak /*=Ort::Value(nullptr)*/
) {
std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
int32_t batch_size = shape[0];
int32_t num_tokens = shape[1];
Expand All @@ -25,12 +30,43 @@ OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs,
for (int32_t k = 0; k != num_tokens; ++k) {
auto max_idx = static_cast<int64_t>(
std::distance(p, std::max_element(p, p + vocab_size)));
if (max_idx == eos_id_) break;
if (max_idx == eos_id_) {
break;
}

results[i].tokens.push_back(max_idx);

p += vocab_size;
}

if (us_cif_peak) {
int32_t dim = us_cif_peak.GetTensorTypeAndShapeInfo().GetShape()[1];

const auto *peak = us_cif_peak.GetTensorData<float>() + i * dim;
std::vector<float> timestamps;
timestamps.reserve(results[i].tokens.size());

// 10.0: frameshift is 10 milliseconds
// 6: LfrWindowSize
// 3: us_cif_peak is upsampled by a factor of 3
// 1000: milliseconds to seconds
float scale = 10.0 * 6 / 3 / 1000;

for (int32_t k = 0; k != dim; ++k) {
if (peak[k] > 1 - 1e-4) {
timestamps.push_back(k * scale);
}
}
timestamps.pop_back();

if (timestamps.size() == results[i].tokens.size()) {
results[i].timestamps = std::move(timestamps);
} else {
SHERPA_ONNX_LOGE("time stamp for batch: %d, %d vs %d", i,
static_cast<int32_t>(results[i].tokens.size()),
static_cast<int32_t>(timestamps.size()));
}
}
}

return results;
Expand Down
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder {
: eos_id_(eos_id) {}

std::vector<OfflineParaformerDecoderResult> Decode(
Ort::Value log_probs, Ort::Value /*token_num*/) override;
Ort::Value log_probs, Ort::Value token_num,
Ort::Value us_cif_peak = Ort::Value(nullptr)) override;

private:
int32_t eos_id_;
Expand Down
14 changes: 6 additions & 8 deletions sherpa-onnx/csrc/offline-paraformer-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <algorithm>
#include <string>
#include <utility>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
Expand Down Expand Up @@ -36,16 +37,13 @@ class OfflineParaformerModel::Impl {
}
#endif

std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) {
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) {
std::array<Ort::Value, 2> inputs = {std::move(features),
std::move(features_length)};

auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());

return {std::move(out[0]), std::move(out[1])};
return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
}

int32_t VocabSize() const { return vocab_size_; }
Expand Down Expand Up @@ -119,7 +117,7 @@ OfflineParaformerModel::OfflineParaformerModel(AAssetManager *mgr,

OfflineParaformerModel::~OfflineParaformerModel() = default;

std::pair<Ort::Value, Ort::Value> OfflineParaformerModel::Forward(
std::vector<Ort::Value> OfflineParaformerModel::Forward(
Ort::Value features, Ort::Value features_length) {
return impl_->Forward(std::move(features), std::move(features_length));
}
Expand Down
11 changes: 7 additions & 4 deletions sherpa-onnx/csrc/offline-paraformer-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_

#include <memory>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
Expand Down Expand Up @@ -35,13 +34,17 @@ class OfflineParaformerModel {
* valid frames in `features` before padding.
* Its dtype is int32_t.
*
* @return Return a pair containing:
* @return Return a vector containing:
* - log_probs: A 3-D tensor of shape (N, T', vocab_size)
* - token_num: A 1-D tensor of shape (N, T') containing number
* of valid tokens in each utterance. Its dtype is int64_t.
* If it is a model supporting timestamps, then there are additional two
* outputs:
* - us_alphas
* - us_cif_peak
*/
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
Ort::Value features_length);
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value features_length);

/** Return the vocabulary size of the model
*/
Expand Down
11 changes: 9 additions & 2 deletions sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ static OfflineRecognitionResult Convert(
const OfflineParaformerDecoderResult &src, const SymbolTable &sym_table) {
OfflineRecognitionResult r;
r.tokens.reserve(src.tokens.size());
r.timestamps = src.timestamps;

std::string text;

Expand Down Expand Up @@ -184,7 +185,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
// i.e., -23.025850929940457f
Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);

std::pair<Ort::Value, Ort::Value> t{nullptr, nullptr};
std::vector<Ort::Value> t;
try {
t = model_->Forward(std::move(x), std::move(x_length));
} catch (const Ort::Exception &ex) {
Expand All @@ -193,7 +194,13 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
return;
}

auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
std::vector<OfflineParaformerDecoderResult> results;
if (t.size() == 2) {
results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
} else {
results =
decoder_->Decode(std::move(t[0]), std::move(t[1]), std::move(t[3]));
}

for (int32_t i = 0; i != n; ++i) {
auto r = Convert(results[i], symbol_table_);
Expand Down
17 changes: 17 additions & 0 deletions swift-api-examples/SherpaOnnx.swift
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,23 @@ class SherpaOnnxOfflineRecongitionResult {
return String(cString: result.pointee.text)
}

var count: Int32 {
return result.pointee.count
}

var timestamps: [Float] {
if let p = result.pointee.timestamps {
var timestamps: [Float] = []
for index in 0..<count {
timestamps.append(p[Int(index)])
}
return timestamps
} else {
let timestamps: [Float] = []
return timestamps
}
}

init(result: UnsafePointer<SherpaOnnxOfflineRecognizerResult>!) {
self.result = result
}
Expand Down
56 changes: 42 additions & 14 deletions swift-api-examples/decode-file-non-streaming.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,45 @@ extension AVAudioPCMBuffer {
}

func run() {
let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
let decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"
let tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"

let whisperConfig = sherpaOnnxOfflineWhisperModelConfig(
encoder: encoder,
decoder: decoder
)
var recognizer: SherpaOnnxOfflineRecognizer
var modelConfig: SherpaOnnxOfflineModelConfig
var modelType = "whisper"
// modelType = "paraformer"

let modelConfig = sherpaOnnxOfflineModelConfig(
tokens: tokens,
whisper: whisperConfig,
debug: 0,
modelType: "whisper"
)
if modelType == "whisper" {
let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
let decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"
let tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"

let whisperConfig = sherpaOnnxOfflineWhisperModelConfig(
encoder: encoder,
decoder: decoder
)

modelConfig = sherpaOnnxOfflineModelConfig(
tokens: tokens,
whisper: whisperConfig,
debug: 0,
modelType: "whisper"
)
} else if modelType == "paraformer" {
let model = "./sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx"
let tokens = "./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt"
let paraformerConfig = sherpaOnnxOfflineParaformerModelConfig(
model: model
)

modelConfig = sherpaOnnxOfflineModelConfig(
tokens: tokens,
paraformer: paraformerConfig,
debug: 0,
modelType: "paraformer"
)
} else {
print("Please specify a supported modelType \(modelType)")
return
}

let featConfig = sherpaOnnxFeatureConfig(
sampleRate: 16000,
Expand All @@ -38,7 +62,7 @@ func run() {
modelConfig: modelConfig
)

let recognizer = SherpaOnnxOfflineRecognizer(config: &config)
recognizer = SherpaOnnxOfflineRecognizer(config: &config)

let filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
let fileURL: NSURL = NSURL(fileURLWithPath: filePath)
Expand All @@ -55,6 +79,10 @@ func run() {
let array: [Float]! = audioFileBuffer?.array()
let result = recognizer.decode(samples: array, sampleRate: Int(audioFormat.sampleRate))
print("\nresult is:\n\(result.text)")
if result.timestamps.count != 0 {
print("\ntimestamps is:\n\(result.timestamps)")
}

}

@main
Expand Down
Loading