From af07894675eb3380a333b79e3522403dafdacedf Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 17 Oct 2024 18:13:07 +0800 Subject: [PATCH] Add hotwords for ctc prefix beam search --- sherpa-onnx/csrc/hypothesis.h | 3 + sherpa-onnx/csrc/offline-ctc-decoder.h | 4 +- sherpa-onnx/csrc/offline-ctc-fst-decoder.cc | 3 +- sherpa-onnx/csrc/offline-ctc-fst-decoder.h | 6 +- .../csrc/offline-ctc-greedy-search-decoder.cc | 3 +- .../csrc/offline-ctc-greedy-search-decoder.h | 6 +- .../offline-ctc-prefix-beam-search-decoder.cc | 43 +++++-- .../offline-ctc-prefix-beam-search-decoder.h | 6 +- .../csrc/offline-recognizer-ctc-impl.h | 109 +++++++++++++++++- sherpa-onnx/csrc/offline-recognizer.cc | 5 +- 10 files changed, 167 insertions(+), 21 deletions(-) diff --git a/sherpa-onnx/csrc/hypothesis.h b/sherpa-onnx/csrc/hypothesis.h index c9ccf80a8..5cc47c984 100644 --- a/sherpa-onnx/csrc/hypothesis.h +++ b/sherpa-onnx/csrc/hypothesis.h @@ -80,6 +80,9 @@ struct Hypothesis { const ContextState *context_state = nullptr) : ys(ys), log_prob(log_prob), context_state(context_state) {} + explicit Hypothesis(const ContextState *context_state) + : context_state(context_state) {} + double TotalLogProb(bool use_ctc = false) const { return LogProb(use_ctc) + lm_log_prob; } diff --git a/sherpa-onnx/csrc/offline-ctc-decoder.h b/sherpa-onnx/csrc/offline-ctc-decoder.h index c9d1b36ff..b5914e939 100644 --- a/sherpa-onnx/csrc/offline-ctc-decoder.h +++ b/sherpa-onnx/csrc/offline-ctc-decoder.h @@ -8,6 +8,7 @@ #include #include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-stream.h" namespace sherpa_onnx { @@ -42,7 +43,8 @@ class OfflineCtcDecoder { * @return Return a vector of size `N` containing the decoded results. 
*/ virtual std::vector Decode( - Ort::Value log_probs, Ort::Value log_probs_length) = 0; + Ort::Value log_probs, Ort::Value log_probs_length, + OfflineStream **ss = nullptr, int32_t n = 0) = 0; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc b/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc index 6c9df3fd3..e62d76a26 100644 --- a/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc @@ -124,7 +124,8 @@ OfflineCtcFstDecoder::OfflineCtcFstDecoder( : config_(config), fst_(ReadGraph(config_.graph)) {} std::vector OfflineCtcFstDecoder::Decode( - Ort::Value log_probs, Ort::Value log_probs_length) { + Ort::Value log_probs, Ort::Value log_probs_length, + OfflineStream **ss /*= nullptr*/, int32_t n /*= 0*/) { std::vector shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); assert(static_cast(shape.size()) == 3); diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder.h b/sherpa-onnx/csrc/offline-ctc-fst-decoder.h index 2b33c14e8..0291f66cf 100644 --- a/sherpa-onnx/csrc/offline-ctc-fst-decoder.h +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder.h @@ -19,8 +19,10 @@ class OfflineCtcFstDecoder : public OfflineCtcDecoder { public: explicit OfflineCtcFstDecoder(const OfflineCtcFstDecoderConfig &config); - std::vector Decode( - Ort::Value log_probs, Ort::Value log_probs_length) override; + std::vector Decode(Ort::Value log_probs, + Ort::Value log_probs_length, + OfflineStream **ss = nullptr, + int32_t n = 0) override; private: OfflineCtcFstDecoderConfig config_; diff --git a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc index 8196e28b3..2aca90dd4 100644 --- a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc @@ -13,7 +13,8 @@ namespace sherpa_onnx { std::vector OfflineCtcGreedySearchDecoder::Decode( - Ort::Value log_probs, Ort::Value log_probs_length) { + Ort::Value log_probs, 
Ort::Value log_probs_length, + OfflineStream **ss /*= nullptr*/, int32_t n /*= 0*/) { std::vector shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); int32_t batch_size = static_cast(shape[0]); int32_t num_frames = static_cast(shape[1]); diff --git a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h index ccc2f728a..ce2c19904 100644 --- a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h @@ -16,8 +16,10 @@ class OfflineCtcGreedySearchDecoder : public OfflineCtcDecoder { explicit OfflineCtcGreedySearchDecoder(int32_t blank_id) : blank_id_(blank_id) {} - std::vector Decode( - Ort::Value log_probs, Ort::Value log_probs_length) override; + std::vector Decode(Ort::Value log_probs, + Ort::Value log_probs_length, + OfflineStream **ss = nullptr, + int32_t n = 0) override; private: int32_t blank_id_; diff --git a/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc b/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc index 10937e81b..273f65210 100644 --- a/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc @@ -8,6 +8,7 @@ #include #include +#include "sherpa-onnx/csrc/context-graph.h" #include "sherpa-onnx/csrc/hypothesis.h" #include "sherpa-onnx/csrc/macros.h" @@ -16,14 +17,16 @@ namespace sherpa_onnx { static std::vector StepWorker(const float *p_log_probs, std::vector &hyps, int32_t blank_id, int32_t vocab_size, - int32_t max_active_paths) { + int32_t max_active_paths, + const ContextGraph *context_graph) { auto topk = TopkIndex(p_log_probs, vocab_size, max_active_paths); Hypotheses next_hyps; for (auto &hyp : hyps) { - Hypothesis new_hyp = hyp; for (auto k : topk) { + Hypothesis new_hyp = hyp; int32_t new_token = k; float log_prob = p_log_probs[k]; + bool update_prefix = false; if (new_token == blank_id) { // Case 0: *a + ε => *a // *aε + ε => *a @@ -36,7 +39,6 
@@ static std::vector StepWorker(const float *p_log_probs, // Prefix does not change, update log_prob of non_blank new_hyp.log_prob_nb = hyp.log_prob_nb + log_prob; new_hyp.log_prob_b = -std::numeric_limits::infinity(); - next_hyps.Add(std::move(new_hyp)); // Case 2: *aε + a => *aa @@ -45,7 +47,7 @@ static std::vector StepWorker(const float *p_log_probs, new_hyp.ys.push_back(new_token); new_hyp.log_prob_nb = hyp.log_prob_b + log_prob; new_hyp.log_prob_b = -std::numeric_limits::infinity(); - next_hyps.Add(std::move(new_hyp)); + update_prefix = true; } else { // Case 3: *a + b => *ab, *aε + b => *ab // Prefix changes, update log_prob of non_blank @@ -53,6 +55,18 @@ static std::vector StepWorker(const float *p_log_probs, new_hyp.ys.push_back(new_token); new_hyp.log_prob_nb = hyp.LogProb(true) + log_prob; new_hyp.log_prob_b = -std::numeric_limits::infinity(); + update_prefix = true; + } + + if (update_prefix) { + float lm_log_prob = hyp.lm_log_prob; + if (context_graph != nullptr && hyp.context_state != nullptr) { + auto context_res = + context_graph->ForwardOneStep(hyp.context_state, new_token); + lm_log_prob = lm_log_prob + std::get<0>(context_res); + new_hyp.context_state = std::get<1>(context_res); + } + new_hyp.lm_log_prob = lm_log_prob; next_hyps.Add(std::move(new_hyp)); } } @@ -61,7 +75,8 @@ static std::vector StepWorker(const float *p_log_probs, } std::vector OfflineCtcPrefixBeamSearchDecoder::Decode( - Ort::Value log_probs, Ort::Value log_probs_length) { + Ort::Value log_probs, Ort::Value log_probs_length, + OfflineStream **ss /*= nullptr*/, int32_t n /*= 0*/) { std::vector shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); int32_t batch_size = static_cast(shape[0]); int32_t num_frames = static_cast(shape[1]); @@ -75,8 +90,17 @@ std::vector OfflineCtcPrefixBeamSearchDecoder::Decode( std::vector> cur; cur.reserve(batch_size); + std::vector context_graphs(batch_size, nullptr); + for (int32_t i = 0; i < batch_size; ++i) { - 
cur.emplace_back(std::vector({Hypothesis()})); + const ContextState *context_state = nullptr; + if (ss != nullptr) { + context_graphs[i] = ss[i]->GetContextGraph(); + if (context_graphs[i] != nullptr) + context_state = context_graphs[i]->Root(); + } + Hypothesis hyp(context_state); + cur.emplace_back(std::vector({hyp})); } for (int32_t t = 0; t < num_frames; ++t) { @@ -85,7 +109,12 @@ std::vector OfflineCtcPrefixBeamSearchDecoder::Decode( const float *p_log_probs = log_probs.GetTensorData() + b * num_frames * vocab_size + t * vocab_size; cur[b] = StepWorker(p_log_probs, cur[b], blank_id_, vocab_size, - max_active_paths_); + max_active_paths_, context_graphs[b].get()); + // for (auto &x : cur[b]) { + // SHERPA_ONNX_LOGE("step : %d, key : %s, ac : %f, lm : %f", t, + // x.Key().c_str(), x.LogProb(true), x.lm_log_prob); + // } + // SHERPA_ONNX_LOGE("\n"); } } } diff --git a/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h b/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h index 97449f572..1504bb870 100644 --- a/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h +++ b/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h @@ -16,8 +16,10 @@ class OfflineCtcPrefixBeamSearchDecoder : public OfflineCtcDecoder { OfflineCtcPrefixBeamSearchDecoder(int32_t max_active_paths, int32_t blank_id) : max_active_paths_(max_active_paths), blank_id_(blank_id) {} - std::vector Decode( - Ort::Value log_probs, Ort::Value log_probs_length) override; + std::vector Decode(Ort::Value log_probs, + Ort::Value log_probs_length, + OfflineStream **ss = nullptr, + int32_t n = 0) override; private: int32_t max_active_paths_; diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index ebf7b25b1..4afedeb93 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -7,6 +7,7 @@ #include #include +#include // NOLINT #include #include #include @@ -25,6 +26,8 @@ 
#include "sherpa-onnx/csrc/offline-recognizer-impl.h" #include "sherpa-onnx/csrc/pad-sequence.h" #include "sherpa-onnx/csrc/symbol-table.h" +#include "sherpa-onnx/csrc/utils.h" +#include "ssentencepiece/csrc/ssentencepiece.h" namespace sherpa_onnx { @@ -151,6 +154,15 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { if (config_.decoding_method == "greedy_search") { decoder_ = std::make_unique(blank_id); } else { + if (!config_.model_config.bpe_vocab.empty()) { + bpe_encoder_ = std::make_unique( + config_.model_config.bpe_vocab); + } + + if (!config_.hotwords_file.empty()) { + InitHotwords(); + } + decoder_ = std::make_unique( config_.max_active_paths, blank_id); } @@ -163,8 +175,46 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { } } + std::unique_ptr CreateStream( + const std::string &hotwords) const override { + auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); + std::istringstream is(hws); + std::vector> current; + std::vector current_scores; + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, + bpe_encoder_.get(), &current, &current_scores)) { + SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", + hotwords.c_str()); + } + + int32_t num_default_hws = hotwords_.size(); + int32_t num_hws = current.size(); + + current.insert(current.end(), hotwords_.begin(), hotwords_.end()); + + if (!current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else if (!current_scores.empty() && boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_default_hws, + config_.hotwords_score); + } else if (current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_hws, + config_.hotwords_score); + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else { + // Do nothing.
+ } + + auto context_graph = std::make_shared( + current, config_.hotwords_score, current_scores); + return std::make_unique(config_.feat_config, context_graph); + } + std::unique_ptr CreateStream() const override { - return std::make_unique(config_.feat_config); + return std::make_unique(config_.feat_config, + hotwords_graph_); } void DecodeStreams(OfflineStream **ss, int32_t n) const override { @@ -218,7 +268,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { -23.025850929940457f); auto t = model_->Forward(std::move(x), std::move(x_length)); - auto results = decoder_->Decode(std::move(t[0]), std::move(t[1])); + auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]), ss, n); int32_t frame_shift_ms = 10; for (int32_t i = 0; i != n; ++i) { @@ -255,7 +305,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { x_length_shape.data(), x_length_shape.size()); auto t = model_->Forward(std::move(x), std::move(x_length)); - auto results = decoder_->Decode(std::move(t[0]), std::move(t[1])); + + OfflineStream *ss[1] = {s}; + auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]), ss, 1); int32_t frame_shift_ms = 10; auto r = Convert(results[0], symbol_table_, frame_shift_ms, @@ -264,9 +316,60 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { s->SetResult(r); } + void InitHotwords() { + // each line in hotwords_file contains space-separated words + + std::ifstream is(config_.hotwords_file); + if (!is) { + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", + config_.hotwords_file.c_str()); + exit(-1); + } + + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for details."); + } + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); + } + +#if __ANDROID_API__ >= 9 + void InitHotwords(AAssetManager *mgr) { + 
// each line in hotwords_file contains space-separated words + + auto buf = ReadFile(mgr, config_.hotwords_file); + + std::istringstream is(std::string(buf.begin(), buf.end())); + + if (!is) { + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", + config_.hotwords_file.c_str()); + exit(-1); + } + + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for details."); + } + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); + } +#endif + private: OfflineRecognizerConfig config_; SymbolTable symbol_table_; + + std::vector> hotwords_; + std::vector boost_scores_; + ContextGraphPtr hotwords_graph_; + std::unique_ptr bpe_encoder_; + std::unique_ptr model_; std::unique_ptr decoder_; }; diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc index f73e35ad6..88e4a271d 100644 --- a/sherpa-onnx/csrc/offline-recognizer.cc +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -68,7 +68,8 @@ bool OfflineRecognizerConfig::Validate() const { } } - if (!hotwords_file.empty() && decoding_method != "modified_beam_search") { + if (!hotwords_file.empty() && (decoding_method != "modified_beam_search" && + decoding_method != "prefix_beam_search")) { SHERPA_ONNX_LOGE( "Please use --decoding-method=modified_beam_search if you" " provide --hotwords-file. Given --decoding-method='%s'", @@ -157,7 +158,7 @@ void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const { } void OfflineRecognizer::SetConfig(const OfflineRecognizerConfig &config) { - impl_->SetConfig(config); + impl_->SetConfig(config); } OfflineRecognizerConfig OfflineRecognizer::GetConfig() const {