Skip to content

Commit

Permalink
Support asset manager for offline models
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Sep 10, 2023
1 parent 159c71b commit d2990b9
Show file tree
Hide file tree
Showing 19 changed files with 320 additions and 28 deletions.
38 changes: 38 additions & 0 deletions sherpa-onnx/csrc/offline-ctc-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,42 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
return nullptr;
}

#if __ANDROID_API__ >= 9

std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
AAssetManager *mgr, const OfflineModelConfig &config) {
ModelType model_type = ModelType::kUnkown;

std::string filename;
if (!config.nemo_ctc.model.empty()) {
filename = config.nemo_ctc.model;
} else if (!config.tdnn.model.empty()) {
filename = config.tdnn.model;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
}

{
auto buffer = ReadFile(mgr, filename);

model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
}

switch (model_type) {
case ModelType::kEncDecCTCModelBPE:
return std::make_unique<OfflineNemoEncDecCtcModel>(mgr, config);
break;
case ModelType::kTdnn:
return std::make_unique<OfflineTdnnCtcModel>(mgr, config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
}

return nullptr;
}
#endif

} // namespace sherpa_onnx
11 changes: 11 additions & 0 deletions sherpa-onnx/csrc/offline-ctc-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
#include <string>
#include <utility>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"

Expand All @@ -16,9 +21,15 @@ namespace sherpa_onnx {
class OfflineCtcModel {
public:
virtual ~OfflineCtcModel() = default;

static std::unique_ptr<OfflineCtcModel> Create(
const OfflineModelConfig &config);

#if __ANDROID_API__ >= 9
static std::unique_ptr<OfflineCtcModel> Create(
AAssetManager *mgr, const OfflineModelConfig &config);
#endif

/** Run the forward method of the model.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.
Expand Down
7 changes: 7 additions & 0 deletions sherpa-onnx/csrc/offline-lm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) {
return std::make_unique<OfflineRnnLM>(config);
}

#if __ANDROID_API__ >= 9
std::unique_ptr<OfflineLM> OfflineLM::Create(AAssetManager *mgr,
const OfflineLMConfig &config) {
return std::make_unique<OfflineRnnLM>(mgr, config);
}
#endif

void OfflineLM::ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) {
// compute the max token seq so that we know how much space to allocate
Expand Down
10 changes: 10 additions & 0 deletions sherpa-onnx/csrc/offline-lm.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
#include <memory>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
Expand All @@ -20,6 +25,11 @@ class OfflineLM {

static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config);

#if __ANDROID_API__ >= 9
static std::unique_ptr<OfflineLM> Create(AAssetManager *mgr,
const OfflineLMConfig &config);
#endif

/** Rescore a batch of sentences.
*
* @param x A 2-D tensor of shape (N, L) with data type int64.
Expand Down
26 changes: 21 additions & 5 deletions sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,21 @@ class OfflineNemoEncDecCtcModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
Init();
auto buf = ReadFile(config_.nemo_ctc.model);
Init(buf.data(), buf.size());
}

#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.nemo_ctc.model);
Init(buf.data(), buf.size());
}
#endif

std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) {
std::vector<int64_t> shape =
Expand Down Expand Up @@ -57,10 +69,8 @@ class OfflineNemoEncDecCtcModel::Impl {
std::string FeatureNormalizationMethod() const { return normalize_type_; }

private:
void Init() {
auto buf = ReadFile(config_.nemo_ctc.model);

sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);

GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
Expand Down Expand Up @@ -104,6 +114,12 @@ OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}

#if __ANDROID_API__ >= 9
OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
AAssetManager *mgr, const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif

OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default;

std::pair<Ort::Value, Ort::Value> OfflineNemoEncDecCtcModel::Forward(
Expand Down
11 changes: 11 additions & 0 deletions sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
Expand All @@ -23,6 +28,12 @@ namespace sherpa_onnx {
class OfflineNemoEncDecCtcModel : public OfflineCtcModel {
public:
explicit OfflineNemoEncDecCtcModel(const OfflineModelConfig &config);

#if __ANDROID_API__ >= 9
OfflineNemoEncDecCtcModel(AAssetManager *mgr,
const OfflineModelConfig &config);
#endif

~OfflineNemoEncDecCtcModel() override;

/** Run the forward method of the model.
Expand Down
23 changes: 21 additions & 2 deletions sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-ctc-model.h"
Expand Down Expand Up @@ -46,10 +51,24 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
: config_(config),
symbol_table_(config_.model_config.tokens),
model_(OfflineCtcModel::Create(config_.model_config)) {
Init();
}

#if __ANDROID_API__ >= 9
OfflineRecognizerCtcImpl(AAssetManager *mgr,
const OfflineRecognizerConfig &config)
: config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(OfflineCtcModel::Create(mgr, config_.model_config)) {
Init();
}
#endif

void Init() {
config_.feat_config.nemo_normalize_type =
model_->FeatureNormalizationMethod();

if (config.decoding_method == "greedy_search") {
if (config_.decoding_method == "greedy_search") {
if (!symbol_table_.contains("<blk>") &&
!symbol_table_.contains("<eps>")) {
SHERPA_ONNX_LOGE(
Expand All @@ -69,7 +88,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
} else {
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
config.decoding_method.c_str());
config_.decoding_method.c_str());
exit(-1);
}
}
Expand Down
16 changes: 8 additions & 8 deletions sherpa-onnx/csrc/offline-recognizer-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,15 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
if (!config.model_config.model_type.empty()) {
const auto &model_type = config.model_config.model_type;
if (model_type == "transducer") {
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
} else if (model_type == "nemo_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
} else if (model_type == "tdnn") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
} else {
SHERPA_ONNX_LOGE(
"Invalid model_type: %s. Trying to load the model to get its type",
Expand Down Expand Up @@ -216,23 +216,23 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(

if (model_type == "conformer" || model_type == "zipformer" ||
model_type == "zipformer2") {
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
}

if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
}

if (model_type == "EncDecCTCModelBPE") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
}

if (model_type == "tdnn") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
}

if (strncmp(model_type.c_str(), "whisper", 7) == 0) {
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
}

SHERPA_ONNX_LOGE(
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
#include <memory>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-stream.h"
Expand Down
31 changes: 31 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
Expand Down Expand Up @@ -73,6 +78,32 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}
}

#if __ANDROID_API__ >= 9
explicit OfflineRecognizerTransducerImpl(
AAssetManager *mgr, const OfflineRecognizerConfig &config)
: config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(std::make_unique<OfflineTransducerModel>(mgr,
config_.model_config)) {
if (config_.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
} else if (config_.decoding_method == "modified_beam_search") {
if (!config_.lm_config.model.empty()) {
lm_ = OfflineLM::Create(mgr, config.lm_config);
}

decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config_.decoding_method.c_str());
exit(-1);
}
}
#endif

std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const override {
// We create context_graph at this level, because we might have default
Expand Down
25 changes: 23 additions & 2 deletions sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
Expand Down Expand Up @@ -253,16 +258,32 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
: config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineWhisperModel>(config.model_config)) {
Init();
}

#if __ANDROID_API__ >= 9
OfflineRecognizerWhisperImpl(AAssetManager *mgr,
const OfflineRecognizerConfig &config)
: config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(
std::make_unique<OfflineWhisperModel>(mgr, config.model_config)) {
Init();
}

#endif

void Init() {
// tokens.txt from whisper is base64 encoded, so we need to decode it
symbol_table_.ApplyBase64Decode();

if (config.decoding_method == "greedy_search") {
if (config_.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OfflineWhisperGreedySearchDecoder>(
config_.model_config.whisper, model_.get());
} else {
SHERPA_ONNX_LOGE(
"Only greedy_search is supported at present for whisper. Given %s",
config.decoding_method.c_str());
config_.decoding_method.c_str());
exit(-1);
}
}
Expand Down
Loading

0 comments on commit d2990b9

Please sign in to comment.