Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-implement LM rescore for online transducer #1231

Merged
merged 6 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sherpa-onnx/csrc/hypothesis.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,12 @@ struct Hypothesis {
// LM log prob if any.
double lm_log_prob = 0;

// the nn lm score for next token given the current ys
// the nn lm score for next token given the current ys, when using shallow fusion
CopyableOrtValue nn_lm_scores;

// number of tokens already scored by the RNN LM, when rescoring
int32_t cur_scored_pos = 0;

// the nn lm states
std::vector<CopyableOrtValue> nn_lm_states;

Expand Down
5 changes: 4 additions & 1 deletion sherpa-onnx/csrc/online-lm-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ void OnlineLMConfig::Register(ParseOptions *po) {
"Number of threads to run the neural network of LM model");
po->Register("lm-provider", &lm_provider,
"Specify a provider to LM model use: cpu, cuda, coreml");
po->Register("lm-shallow-fusion", &shallow_fusion,
"Boolean whether to use shallow fusion or rescore.");
}

bool OnlineLMConfig::Validate() const {
Expand All @@ -34,7 +36,8 @@ std::string OnlineLMConfig::ToString() const {

os << "OnlineLMConfig(";
os << "model=\"" << model << "\", ";
os << "scale=" << scale << ")";
os << "scale=" << scale << ", ";
os << "shallow_fusion=\"" << (shallow_fusion ? "True" : "False") << "\")";
SilverSulfide marked this conversation as resolved.
Show resolved Hide resolved

return os.str();
}
Expand Down
7 changes: 5 additions & 2 deletions sherpa-onnx/csrc/online-lm-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@ struct OnlineLMConfig {
float scale = 0.5;
int32_t lm_num_threads = 1;
std::string lm_provider = "cpu";
// enable shallow fusion
bool shallow_fusion = true;

OnlineLMConfig() = default;

OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
const std::string &lm_provider)
const std::string &lm_provider, bool shallow_fusion)
: model(model),
scale(scale),
lm_num_threads(lm_num_threads),
lm_provider(lm_provider) {}
lm_provider(lm_provider),
shallow_fusion(shallow_fusion) {}

void Register(ParseOptions *po);
bool Validate() const;
Expand Down
24 changes: 19 additions & 5 deletions sherpa-onnx/csrc/online-lm.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,41 @@ class OnlineLM {

static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);

virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0;
// init states for classic rescore
virtual std::vector<Ort::Value> GetInitStates() = 0;

/** ScoreToken a batch of sentences.
// init states for shallow fusion
virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() = 0;

/** ScoreToken a batch of sentences (shallow fusion).
*
* @param x A 2-D tensor of shape (N, 1) with data type int64.
* @param states It contains the states for the LM model
* @return Return a pair containingo
* @return Return a pair containing
* - log_prob of NN LM
* - updated states
*
*/
virtual std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
Ort::Value x, std::vector<Ort::Value> states) = 0;

/** This function updates lm_lob_prob and nn_lm_scores of hyp
/** This function updates hyp.lm_log_prob of hyps (classic rescore).
*
* @param scale LM score
* @param context_size Context size of the transducer decoder model
* @param hyps It is changed in-place.
*
*/
virtual void ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) = 0;

/** This function updates lm_log_prob and nn_lm_scores of hyp (shallow fusion).
*
* @param scale LM score
* @param hyps It is changed in-place.
*
*/
virtual void ComputeLMScore(float scale, Hypothesis *hyp) = 0;
virtual void ComputeLMScoreSF(float scale, Hypothesis *hyp) = 0;
};

} // namespace sherpa_onnx
Expand Down
4 changes: 2 additions & 2 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {

decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, unk_id_, config_.blank_penalty,
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, config_.blank_penalty,
config_.temperature_scale);

} else if (config.decoding_method == "greedy_search") {
Expand Down Expand Up @@ -156,7 +156,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {

decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, unk_id_, config_.blank_penalty,
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, config_.blank_penalty,
config_.temperature_scale);

} else if (config.decoding_method == "greedy_search") {
Expand Down
90 changes: 82 additions & 8 deletions sherpa-onnx/csrc/online-rnn-lm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ class OnlineRnnLM::Impl {
Init(config);
}

void ComputeLMScore(float scale, Hypothesis *hyp) {
// shallow fusion scoring function
void ComputeLMScoreSF(float scale, Hypothesis *hyp) {
if (hyp->nn_lm_states.empty()) {
auto init_states = GetInitStates();
auto init_states = GetInitStatesSF();
hyp->nn_lm_scores.value = std::move(init_states.first);
hyp->nn_lm_states = Convert(std::move(init_states.second));
}
Expand All @@ -49,6 +50,52 @@ class OnlineRnnLM::Impl {
hyp->nn_lm_states = Convert(std::move(lm_out.second));
}

// Classic (chunk-wise) LM rescore: scores the not-yet-scored tokens of every
// hypothesis in *hyps with the RNN LM and stores the weighted negative
// log-likelihood in hyp.lm_log_prob.
//
// @param scale        LM weight applied to the model's NLL output.
// @param context_size Context size of the transducer decoder; the first
//                     context_size entries of ys are decoder context symbols,
//                     not real output tokens, so they are excluded.
// @param hyps         Modified in-place (lm_log_prob, nn_lm_states,
//                     cur_scored_pos are updated).
void ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) {
Ort::AllocatorWithDefaultOptions allocator;

for (auto &hyp : *hyps) {
for (auto &h_m : hyp) {
auto &h = h_m.second;
auto &ys = h.ys;
// Tokens added since the last rescore; the final token is excluded.
// NOTE(review): ys.size() is unsigned, so this subtraction happens in
// size_t and is then narrowed to int32_t — confirm it cannot underflow
// when ys.size() < context_size + cur_scored_pos + 1.
const int32_t token_num_in_chunk =
ys.size() - context_size - h.cur_scored_pos - 1;

if (token_num_in_chunk < 1) {
continue;
}

// Lazily create per-hypothesis LM states on first use.
if (h.nn_lm_states.empty()) {
h.nn_lm_states = Convert(GetInitStates());
}

// NOTE(review): lm_rescore_min_chunk is not part of the Hypothesis
// change visible in this commit — confirm the field exists.
if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
std::array<int64_t, 2> x_shape{1, token_num_in_chunk};

// Build a (1, token_num_in_chunk) int64 tensor holding the unscored
// tokens (decoder context and already-scored positions skipped).
Ort::Value x = Ort::Value::CreateTensor<int64_t>(
allocator, x_shape.data(), x_shape.size());
int64_t *p_x = x.GetTensorMutableData<int64_t>();
std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1,
p_x);

// streaming forward by NN LM
auto out = ScoreToken(std::move(x),
Convert(std::move(h.nn_lm_states)));

// update NN LM score in hyp
// NOTE(review): this overwrites lm_log_prob with the NLL of the
// current chunk only; since the states stream across chunks, the
// per-chunk scores presumably should accumulate (+=) — confirm.
const float *p_nll = out.first.GetTensorData<float>();
h.lm_log_prob = -scale * (*p_nll);

// update NN LM states in hyp
h.nn_lm_states = Convert(std::move(out.second));

// Advance so these tokens are not scored again next chunk.
h.cur_scored_pos += token_num_in_chunk;
}
}
}
}

std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
Ort::Value x, std::vector<Ort::Value> states) {
std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states[0]),
Expand All @@ -66,7 +113,8 @@ class OnlineRnnLM::Impl {
return {std::move(out[0]), std::move(next_states)};
}

std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() {
// get init states for shallow fusion
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() {
std::vector<Ort::Value> ans;
ans.reserve(init_states_.size());
for (auto &s : init_states_) {
Expand All @@ -75,6 +123,18 @@ class OnlineRnnLM::Impl {
return {View(&init_scores_.value), std::move(ans)};
}

// get init states for classic rescore
std::vector<Ort::Value> GetInitStates() const {
std::vector<Ort::Value> ans;
ans.reserve(init_states_.size());

for (const auto &s : init_states_) {
ans.emplace_back(Clone(allocator_, &s));
}

return ans;
}

private:
void Init(const OnlineLMConfig &config) {
auto buf = ReadFile(config_.model);
Expand Down Expand Up @@ -116,7 +176,7 @@ class OnlineRnnLM::Impl {
states.push_back(std::move(c));
auto pair = ScoreToken(std::move(x), std::move(states));

init_scores_.value = std::move(pair.first);
init_scores_.value = std::move(pair.first); // only used during shallow fusion
init_states_ = std::move(pair.second);
}

Expand Down Expand Up @@ -147,17 +207,31 @@ OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)

// Defaulted out of line — presumably so the pImpl (`Impl`, forward-declared
// in the header) is a complete type where the destructor is instantiated.
OnlineRnnLM::~OnlineRnnLM() = default;

std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStates() {
// classic rescore state init
// Returns per-hypothesis copies of the initial LM states; delegates to Impl.
std::vector<Ort::Value> OnlineRnnLM::GetInitStates() {
return impl_->GetInitStates();
}

// shallow fusion state init
// Returns the initial (scores, states) pair for shallow fusion; delegates
// to Impl.
std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStatesSF() {
return impl_->GetInitStatesSF();
}

// Runs the LM forward on token tensor x with the given states; returns the
// (log_prob, next_states) pair. Delegates to Impl.
std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken(
Ort::Value x, std::vector<Ort::Value> states) {
return impl_->ScoreToken(std::move(x), std::move(states));
}

void OnlineRnnLM::ComputeLMScore(float scale, Hypothesis *hyp) {
return impl_->ComputeLMScore(scale, hyp);
// classic rescore scores
// Updates lm_log_prob of every hypothesis in *hyps; delegates to Impl.
void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) {
return impl_->ComputeLMScore(scale, context_size, hyps);
}

} // namespace sherpa_onnx
// shallow fusion scores
// Updates the LM score of a single hypothesis in-place; delegates to Impl.
void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) {
return impl_->ComputeLMScoreSF(scale, hyp);
}


} // namespace sherpa_onnx
24 changes: 19 additions & 5 deletions sherpa-onnx/csrc/online-rnn-lm.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,41 @@ class OnlineRnnLM : public OnlineLM {

explicit OnlineRnnLM(const OnlineLMConfig &config);

std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override;
// init scores for classic rescore
std::vector<Ort::Value> GetInitStates() override;

/** ScoreToken a batch of sentences.
// init scores for shallow fusion
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() override;

/** ScoreToken a batch of sentences (shallow fusion).
*
* @param x A 2-D tensor of shape (N, L) with data type int64.
* @param states It contains the states for the LM model
* @return Return a pair containingo
* @return Return a pair containing
* - log_prob of NN LM
* - updated states
*
*/
std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
Ort::Value x, std::vector<Ort::Value> states) override;

/** This function updates lm_lob_prob and nn_lm_scores of hyp
/** This function updates hyp.lm_log_prob of hyps (classic rescore).
*
* @param scale LM score
* @param context_size Context size of the transducer decoder model
* @param hyps It is changed in-place.
*
*/
void ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) override;

/** This function updates lm_log_prob and nn_lm_scores of hyp (shallow fusion).
*
* @param scale LM score
* @param hyps It is changed in-place.
*
*/
void ComputeLMScore(float scale, Hypothesis *hyp) override;
void ComputeLMScoreSF(float scale, Hypothesis *hyp) override;

private:
class Impl;
Expand Down
26 changes: 21 additions & 5 deletions sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(

// add log_prob of each hypothesis to p_logprob before taking top_k
for (int32_t i = 0; i != num_hyps; ++i) {
float log_prob = prev[i].log_prob + prev[i].lm_log_prob;

float log_prob = prev[i].log_prob;
if (lm_ && shallow_fusion_) {
log_prob += prev[i].lm_log_prob;
}

for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
*p_logprob += log_prob;
}
Expand Down Expand Up @@ -192,22 +197,28 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
context_score = std::get<0>(context_res);
new_hyp.context_state = std::get<1>(context_res);
}
if (lm_) {
lm_->ComputeLMScore(lm_scale_, &new_hyp);
if (lm_ && shallow_fusion_) {
lm_->ComputeLMScoreSF(lm_scale_, &new_hyp);
}
} else {
++new_hyp.num_trailing_blanks;
}
new_hyp.log_prob = p_logprob[k] + context_score -
if (lm_ && shallow_fusion_) {
new_hyp.log_prob = p_logprob[k] + context_score -
prev_lm_log_prob; // log_prob only includes the
// score of the transducer
} else {
new_hyp.log_prob = p_logprob[k] + context_score; // for rescoring or no LM, previous token score is ignored
}

// export the per-token log scores
if (new_token != 0 && new_token != unk_id_) {
float y_prob = logit_with_temperature[start * vocab_size + k];
new_hyp.ys_probs.push_back(y_prob);

if (lm_) { // export only when LM is used
if (lm_ && shallow_fusion_) { // export only when LM shallow fusion is used
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;

if (lm_scale_ != 0.0) {
lm_prob /= lm_scale_; // remove lm-scale
}
Expand All @@ -227,6 +238,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
} // for (int32_t b = 0; b != batch_size; ++b)
} // for (int32_t t = 0; t != num_frames; ++t)

// classic lm rescore
if (lm_ && !shallow_fusion_) {
lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur);
}

for (int32_t b = 0; b != batch_size; ++b) {
auto &hyps = cur[b];
auto best_hyp = hyps.GetMostProbable(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@ class OnlineTransducerModifiedBeamSearchDecoder
OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
OnlineLM *lm,
int32_t max_active_paths,
float lm_scale, int32_t unk_id,
float lm_scale,
bool shallow_fusion,
int32_t unk_id,
float blank_penalty,
float temperature_scale)
: model_(model),
lm_(lm),
max_active_paths_(max_active_paths),
lm_scale_(lm_scale),
shallow_fusion_(shallow_fusion),
unk_id_(unk_id),
blank_penalty_(blank_penalty),
temperature_scale_(temperature_scale) {}
Expand All @@ -50,6 +53,7 @@ class OnlineTransducerModifiedBeamSearchDecoder

int32_t max_active_paths_;
float lm_scale_; // used only when lm_ is not nullptr
bool shallow_fusion_; // used only when lm_ is not nullptr
int32_t unk_id_;
float blank_penalty_;
float temperature_scale_;
Expand Down
Loading
Loading