From a12ebfab2294f4ede5cd4c85e3e63d994851caba Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 7 Sep 2023 15:12:29 +0800 Subject: [PATCH] treat unk as blank (#299) --- .../csrc/online-recognizer-transducer-impl.h | 21 +++++++++++++------ ...online-transducer-greedy-search-decoder.cc | 4 +++- .../online-transducer-greedy-search-decoder.h | 6 ++++-- ...transducer-modified-beam-search-decoder.cc | 4 +++- ...-transducer-modified-beam-search-decoder.h | 6 ++++-- 5 files changed, 29 insertions(+), 12 deletions(-) diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 27f58687b..e08993dc1 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -57,6 +57,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { model_(OnlineTransducerModel::Create(config.model_config)), sym_(config.model_config.tokens), endpoint_(config_.endpoint_config) { + if (sym_.contains("")) { + unk_id_ = sym_[""]; + } + if (config.decoding_method == "modified_beam_search") { if (!config_.lm_config.model.empty()) { lm_ = OnlineLM::Create(config.lm_config); @@ -64,10 +68,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { decoder_ = std::make_unique( model_.get(), lm_.get(), config_.max_active_paths, - config_.lm_config.scale); + config_.lm_config.scale, unk_id_); } else if (config.decoding_method == "greedy_search") { - decoder_ = - std::make_unique(model_.get()); + decoder_ = std::make_unique( + model_.get(), unk_id_); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config.decoding_method.c_str()); @@ -82,13 +86,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { model_(OnlineTransducerModel::Create(mgr, config.model_config)), sym_(mgr, config.model_config.tokens), endpoint_(config_.endpoint_config) { + if (sym_.contains("")) { + unk_id_ = sym_[""]; + } + if (config.decoding_method == "modified_beam_search") { decoder_ = std::make_unique( model_.get(), lm_.get(), config_.max_active_paths, - config_.lm_config.scale); + config_.lm_config.scale, unk_id_); } else if (config.decoding_method == "greedy_search") { - decoder_ = - std::make_unique(model_.get()); + decoder_ = std::make_unique( + model_.get(), unk_id_); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config.decoding_method.c_str()); @@ -268,6 +276,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { std::unique_ptr decoder_; SymbolTable sym_; Endpoint endpoint_; + int32_t unk_id_ = -1; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index 965285ce7..e90426bdc 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -108,7 +108,9 @@ void OnlineTransducerGreedySearchDecoder::Decode( static_cast(p_logit), std::max_element(static_cast(p_logit), static_cast(p_logit) + vocab_size))); - if (y != 0) { + // blank id is hardcoded to 0 + // also, it treats unk as blank + if (y != 0 && y != unk_id_) { emitted = true; r.tokens.push_back(y); r.timestamps.push_back(t + r.frame_offset); diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h index f7fa7ddf4..363cefedd 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h @@ -14,8 +14,9 @@ namespace sherpa_onnx { class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { public: - explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model) - : model_(model) {} + OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model, + int32_t unk_id) + : model_(model), unk_id_(unk_id) {} OnlineTransducerDecoderResult GetEmptyResult() const override; @@ -26,6 +27,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { private: OnlineTransducerModel *model_; // Not owned + int32_t unk_id_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index fef673472..a98f19dad 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -155,7 +155,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( float context_score = 0; auto context_state = new_hyp.context_state; - if (new_token != 0) { + // blank is hardcoded to 0 + // also, it treats unk as blank + if (new_token != 0 && new_token != unk_id_) { new_hyp.ys.push_back(new_token); new_hyp.timestamps.push_back(t + frame_offset); new_hyp.num_trailing_blanks = 0; diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h index d05c5167b..bc0cfb559 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h @@ -21,11 +21,12 @@ class OnlineTransducerModifiedBeamSearchDecoder OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, OnlineLM *lm, int32_t max_active_paths, - float lm_scale) + float lm_scale, int32_t unk_id) : model_(model), lm_(lm), max_active_paths_(max_active_paths), - lm_scale_(lm_scale) {} + lm_scale_(lm_scale), + unk_id_(unk_id) {} OnlineTransducerDecoderResult GetEmptyResult() const override; @@ -45,6 +46,7 @@ class OnlineTransducerModifiedBeamSearchDecoder int32_t max_active_paths_; float lm_scale_; // used only when lm_ is not nullptr + int32_t unk_id_; }; } // namespace sherpa_onnx