Skip to content

Commit

Permalink
Add hotwords for ctc prefix beam search
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Oct 17, 2024
1 parent 0137902 commit af07894
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 21 deletions.
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/hypothesis.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ struct Hypothesis {
const ContextState *context_state = nullptr)
: ys(ys), log_prob(log_prob), context_state(context_state) {}

// Builds a hypothesis that only records its starting context-graph state;
// every other member keeps its in-class default.
explicit Hypothesis(const ContextState *context_state)
    : context_state{context_state} {}

// Returns LogProb(use_ctc) plus the accumulated lm_log_prob of this
// hypothesis. `use_ctc` is forwarded unchanged to LogProb().
double TotalLogProb(bool use_ctc = false) const {
  const auto base_log_prob = LogProb(use_ctc);
  return base_log_prob + lm_log_prob;
}
Expand Down
4 changes: 3 additions & 1 deletion sherpa-onnx/csrc/offline-ctc-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <vector>

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-stream.h"

namespace sherpa_onnx {

Expand Down Expand Up @@ -42,7 +43,8 @@ class OfflineCtcDecoder {
* @return Return a vector of size `N` containing the decoded results.
*/
virtual std::vector<OfflineCtcDecoderResult> Decode(
Ort::Value log_probs, Ort::Value log_probs_length) = 0;
Ort::Value log_probs, Ort::Value log_probs_length,
OfflineStream **ss = nullptr, int32_t n = 0) = 0;
};

} // namespace sherpa_onnx
Expand Down
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/offline-ctc-fst-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ OfflineCtcFstDecoder::OfflineCtcFstDecoder(
: config_(config), fst_(ReadGraph(config_.graph)) {}

std::vector<OfflineCtcDecoderResult> OfflineCtcFstDecoder::Decode(
Ort::Value log_probs, Ort::Value log_probs_length) {
Ort::Value log_probs, Ort::Value log_probs_length,
OfflineStream **ss /*= nullptr*/, int32_t n /*= 0*/) {
std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();

assert(static_cast<int32_t>(shape.size()) == 3);
Expand Down
6 changes: 4 additions & 2 deletions sherpa-onnx/csrc/offline-ctc-fst-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ class OfflineCtcFstDecoder : public OfflineCtcDecoder {
public:
explicit OfflineCtcFstDecoder(const OfflineCtcFstDecoderConfig &config);

// Overrides OfflineCtcDecoder::Decode. `ss` and `n` are part of the shared
// decoder interface; how this decoder uses them is not visible in this header.
// Only the four-parameter form is declared so that it matches the base class
// and two-argument calls resolve unambiguously through the defaults.
std::vector<OfflineCtcDecoderResult> Decode(Ort::Value log_probs,
                                            Ort::Value log_probs_length,
                                            OfflineStream **ss = nullptr,
                                            int32_t n = 0) override;

private:
OfflineCtcFstDecoderConfig config_;
Expand Down
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
namespace sherpa_onnx {

std::vector<OfflineCtcDecoderResult> OfflineCtcGreedySearchDecoder::Decode(
Ort::Value log_probs, Ort::Value log_probs_length) {
Ort::Value log_probs, Ort::Value log_probs_length,
OfflineStream **ss /*= nullptr*/, int32_t n /*= 0*/) {
std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
int32_t batch_size = static_cast<int32_t>(shape[0]);
int32_t num_frames = static_cast<int32_t>(shape[1]);
Expand Down
6 changes: 4 additions & 2 deletions sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ class OfflineCtcGreedySearchDecoder : public OfflineCtcDecoder {
// Stores the id of the blank token for later use by the decoder.
explicit OfflineCtcGreedySearchDecoder(int32_t blank_id)
    : blank_id_{blank_id} {}

// Overrides OfflineCtcDecoder::Decode. `ss` and `n` are part of the shared
// decoder interface; how this decoder uses them is not visible in this header.
// Only the four-parameter form is declared so that it matches the base class
// and two-argument calls resolve unambiguously through the defaults.
std::vector<OfflineCtcDecoderResult> Decode(Ort::Value log_probs,
                                            Ort::Value log_probs_length,
                                            OfflineStream **ss = nullptr,
                                            int32_t n = 0) override;

private:
int32_t blank_id_;
Expand Down
43 changes: 36 additions & 7 deletions sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/macros.h"

Expand All @@ -16,14 +17,16 @@ namespace sherpa_onnx {
static std::vector<Hypothesis> StepWorker(const float *p_log_probs,
std::vector<Hypothesis> &hyps,
int32_t blank_id, int32_t vocab_size,
int32_t max_active_paths) {
int32_t max_active_paths,
const ContextGraph *context_graph) {
auto topk = TopkIndex(p_log_probs, vocab_size, max_active_paths);
Hypotheses next_hyps;
for (auto &hyp : hyps) {
Hypothesis new_hyp = hyp;
for (auto k : topk) {
Hypothesis new_hyp = hyp;
int32_t new_token = k;
float log_prob = p_log_probs[k];
bool update_prefix = false;
if (new_token == blank_id) {
// Case 0: *a + ε => *a
// *aε + ε => *a
Expand All @@ -36,7 +39,6 @@ static std::vector<Hypothesis> StepWorker(const float *p_log_probs,
// Prefix does not change, update log_prob of non_blank
new_hyp.log_prob_nb = hyp.log_prob_nb + log_prob;
new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();

next_hyps.Add(std::move(new_hyp));

// Case 2: *aε + a => *aa
Expand All @@ -45,14 +47,26 @@ static std::vector<Hypothesis> StepWorker(const float *p_log_probs,
new_hyp.ys.push_back(new_token);
new_hyp.log_prob_nb = hyp.log_prob_b + log_prob;
new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();
next_hyps.Add(std::move(new_hyp));
update_prefix = true;
} else {
// Case 3: *a + b => *ab, *aε + b => *ab
// Prefix changes, update log_prob of non_blank
// Caution: DO NOT use append, as clone is shallow copy
new_hyp.ys.push_back(new_token);
new_hyp.log_prob_nb = hyp.LogProb(true) + log_prob;
new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();
update_prefix = true;
}

if (update_prefix) {
float lm_log_prob = hyp.lm_log_prob;
if (context_graph != nullptr && hyp.context_state != nullptr) {
auto context_res =
context_graph->ForwardOneStep(hyp.context_state, new_token);
lm_log_prob = lm_log_prob + std::get<0>(context_res);
new_hyp.context_state = std::get<1>(context_res);
}
new_hyp.lm_log_prob = lm_log_prob;
next_hyps.Add(std::move(new_hyp));
}
}
Expand All @@ -61,7 +75,8 @@ static std::vector<Hypothesis> StepWorker(const float *p_log_probs,
}

std::vector<OfflineCtcDecoderResult> OfflineCtcPrefixBeamSearchDecoder::Decode(
Ort::Value log_probs, Ort::Value log_probs_length) {
Ort::Value log_probs, Ort::Value log_probs_length,
OfflineStream **ss /*= nullptr*/, int32_t n /*= 0*/) {
std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
int32_t batch_size = static_cast<int32_t>(shape[0]);
int32_t num_frames = static_cast<int32_t>(shape[1]);
Expand All @@ -75,8 +90,17 @@ std::vector<OfflineCtcDecoderResult> OfflineCtcPrefixBeamSearchDecoder::Decode(
std::vector<std::vector<Hypothesis>> cur;
cur.reserve(batch_size);

std::vector<ContextGraphPtr> context_graphs(batch_size, nullptr);

for (int32_t i = 0; i < batch_size; ++i) {
cur.emplace_back(std::vector<Hypothesis>({Hypothesis()}));
const ContextState *context_state = nullptr;
if (ss != nullptr) {
context_graphs[i] = ss[i]->GetContextGraph();
if (context_graphs[i] != nullptr)
context_state = context_graphs[i]->Root();
}
Hypothesis hyp(context_state);
cur.emplace_back(std::vector<Hypothesis>({hyp}));
}

for (int32_t t = 0; t < num_frames; ++t) {
Expand All @@ -85,7 +109,12 @@ std::vector<OfflineCtcDecoderResult> OfflineCtcPrefixBeamSearchDecoder::Decode(
const float *p_log_probs = log_probs.GetTensorData<float>() +
b * num_frames * vocab_size + t * vocab_size;
cur[b] = StepWorker(p_log_probs, cur[b], blank_id_, vocab_size,
max_active_paths_);
max_active_paths_, context_graphs[b].get());
// for (auto &x : cur[b]) {
// SHERPA_ONNX_LOGE("step : %d, key : %s, ac : %f, lm : %f", t,
// x.Key().c_str(), x.LogProb(true), x.lm_log_prob);
// }
// SHERPA_ONNX_LOGE("\n");
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ class OfflineCtcPrefixBeamSearchDecoder : public OfflineCtcDecoder {
// Stores the beam width (`max_active_paths`) and the id of the blank token.
OfflineCtcPrefixBeamSearchDecoder(int32_t max_active_paths, int32_t blank_id)
    : max_active_paths_{max_active_paths}, blank_id_{blank_id} {}

// Overrides OfflineCtcDecoder::Decode. When `ss` is non-null, the
// implementation reads the context graph of ss[i] for each batch item i and
// uses it for hotword biasing; with `ss == nullptr` no context graph is used.
// Only the four-parameter form is declared so that it matches the base class
// and two-argument calls resolve unambiguously through the defaults.
std::vector<OfflineCtcDecoderResult> Decode(Ort::Value log_probs,
                                            Ort::Value log_probs_length,
                                            OfflineStream **ss = nullptr,
                                            int32_t n = 0) override;

private:
int32_t max_active_paths_;
Expand Down
109 changes: 106 additions & 3 deletions sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <ios>
#include <memory>
#include <regex> // NOLINT
#include <sstream>
#include <string>
#include <utility>
Expand All @@ -25,6 +26,8 @@
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/pad-sequence.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"
#include "ssentencepiece/csrc/ssentencepiece.h"

namespace sherpa_onnx {

Expand Down Expand Up @@ -151,6 +154,15 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
if (config_.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
} else {
if (!config_.model_config.bpe_vocab.empty()) {
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
config_.model_config.bpe_vocab);
}

if (!config_.hotwords_file.empty()) {
InitHotwords();
}

decoder_ = std::make_unique<OfflineCtcPrefixBeamSearchDecoder>(
config_.max_active_paths, blank_id);
}
Expand All @@ -163,8 +175,46 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
}

// Creates a stream whose context graph combines the hotwords given for this
// call with the recognizer-wide ones stored in hotwords_ / boost_scores_.
//
// @param hotwords Hotword phrases separated by '/'. Every '/' is replaced by
//                 a newline before the text is passed to EncodeHotwords().
// @return A new OfflineStream built from config_.feat_config and the merged
//         context graph.
//
// If encoding fails, an error is logged and the stream is still created from
// whatever EncodeHotwords() left in its output vectors.
std::unique_ptr<OfflineStream> CreateStream(
    const std::string &hotwords) const override {
  // Plain character substitution: compiling a std::regex on every call just
  // to swap a single character is needlessly expensive.
  std::string hws = hotwords;
  for (char &c : hws) {
    if (c == '/') {
      c = '\n';
    }
  }

  std::istringstream is(hws);
  std::vector<std::vector<int32_t>> current;
  std::vector<float> current_scores;
  if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
                      bpe_encoder_.get(), &current, &current_scores)) {
    SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
                     hotwords.c_str());
  }

  // Number of per-call hotwords, captured before the defaults are appended.
  const std::size_t num_hws = current.size();

  current.insert(current.end(), hotwords_.begin(), hotwords_.end());

  // Scores are only materialised when at least one side carries explicit
  // scores; the side without them is padded with config_.hotwords_score so
  // that the score list stays aligned with `current`.
  if (!current_scores.empty() || !boost_scores_.empty()) {
    if (current_scores.empty()) {
      current_scores.assign(num_hws, config_.hotwords_score);
    }

    if (boost_scores_.empty()) {
      current_scores.insert(current_scores.end(), hotwords_.size(),
                            config_.hotwords_score);
    } else {
      current_scores.insert(current_scores.end(), boost_scores_.begin(),
                            boost_scores_.end());
    }
  }

  auto context_graph = std::make_shared<ContextGraph>(
      current, config_.hotwords_score, current_scores);
  return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
}

// Creates a stream that uses the recognizer-wide hotwords graph.
// hotwords_graph_ is only assigned by InitHotwords(); when that was never
// called the member is passed along in its default-constructed state.
std::unique_ptr<OfflineStream> CreateStream() const override {
  return std::make_unique<OfflineStream>(config_.feat_config,
                                         hotwords_graph_);
}

void DecodeStreams(OfflineStream **ss, int32_t n) const override {
Expand Down Expand Up @@ -218,7 +268,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
-23.025850929940457f);
auto t = model_->Forward(std::move(x), std::move(x_length));

auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]), ss, n);

int32_t frame_shift_ms = 10;
for (int32_t i = 0; i != n; ++i) {
Expand Down Expand Up @@ -255,7 +305,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
x_length_shape.data(), x_length_shape.size());

auto t = model_->Forward(std::move(x), std::move(x_length));
auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]));

OfflineStream *ss[1] = {s};
auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]), ss, 1);
int32_t frame_shift_ms = 10;

auto r = Convert(results[0], symbol_table_, frame_shift_ms,
Expand All @@ -264,9 +316,60 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
s->SetResult(r);
}

// Reads config_.hotwords_file, encodes its contents into hotwords_ and
// boost_scores_, and builds hotwords_graph_ from them.
// Terminates the process with exit(-1) when the file cannot be opened.
void InitHotwords() {
  // each line in hotwords_file contains space-separated words
  std::ifstream hotwords_stream(config_.hotwords_file);
  if (hotwords_stream.fail()) {
    SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
                     config_.hotwords_file.c_str());
    exit(-1);
  }

  const bool all_encoded = EncodeHotwords(
      hotwords_stream, config_.model_config.modeling_unit, symbol_table_,
      bpe_encoder_.get(), &hotwords_, &boost_scores_);
  if (!all_encoded) {
    // Not fatal: the graph below is still built from what was encoded.
    SHERPA_ONNX_LOGE(
        "Failed to encode some hotwords, skip them already, see logs above "
        "for details.");
  }

  hotwords_graph_ = std::make_shared<ContextGraph>(
      hotwords_, config_.hotwords_score, boost_scores_);
}

#if __ANDROID_API__ >= 9
// Android variant: loads config_.hotwords_file through the asset manager,
// encodes its contents into hotwords_ and boost_scores_, and builds
// hotwords_graph_ from them.
void InitHotwords(AAssetManager *mgr) {
  // each line in hotwords_file contains space-separated words
  auto buf = ReadFile(mgr, config_.hotwords_file);

  const std::string contents(buf.begin(), buf.end());
  std::istringstream hotwords_stream(contents);

  if (hotwords_stream.fail()) {
    SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
                     config_.hotwords_file.c_str());
    exit(-1);
  }

  const bool all_encoded = EncodeHotwords(
      hotwords_stream, config_.model_config.modeling_unit, symbol_table_,
      bpe_encoder_.get(), &hotwords_, &boost_scores_);
  if (!all_encoded) {
    // Not fatal: the graph below is still built from what was encoded.
    SHERPA_ONNX_LOGE(
        "Failed to encode some hotwords, skip them already, see logs above "
        "for details.");
  }

  hotwords_graph_ = std::make_shared<ContextGraph>(
      hotwords_, config_.hotwords_score, boost_scores_);
}
#endif

private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;

std::vector<std::vector<int32_t>> hotwords_;
std::vector<float> boost_scores_;
ContextGraphPtr hotwords_graph_;
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;

std::unique_ptr<OfflineCtcModel> model_;
std::unique_ptr<OfflineCtcDecoder> decoder_;
};
Expand Down
5 changes: 3 additions & 2 deletions sherpa-onnx/csrc/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ bool OfflineRecognizerConfig::Validate() const {
}
}

if (!hotwords_file.empty() && decoding_method != "modified_beam_search") {
if (!hotwords_file.empty() && (decoding_method != "modified_beam_search" &&
decoding_method != "prefix_beam_search")) {
SHERPA_ONNX_LOGE(
"Please use --decoding-method=modified_beam_search if you"
" provide --hotwords-file. Given --decoding-method='%s'",
Expand Down Expand Up @@ -157,7 +158,7 @@ void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const {
}

// Forwards the new configuration to the implementation object exactly once.
void OfflineRecognizer::SetConfig(const OfflineRecognizerConfig &config) {
  impl_->SetConfig(config);
}

OfflineRecognizerConfig OfflineRecognizer::GetConfig() const {
Expand Down

0 comments on commit af07894

Please sign in to comment.