Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] add CTC prefix beam search / hotwords / shallow fusion #1439

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ set(sources
offline-ctc-fst-decoder.cc
offline-ctc-greedy-search-decoder.cc
offline-ctc-model.cc
offline-ctc-prefix-beam-search-decoder.cc
offline-lm-config.cc
offline-lm.cc
offline-model-config.cc
Expand Down
49 changes: 31 additions & 18 deletions sherpa-onnx/csrc/hypothesis.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2023 Xiaomi Corporation
* Copyright (c) 2023-2024 Xiaomi Corporation
* Copyright (c) 2023 Pingfeng Luo
*/

Expand All @@ -10,53 +10,66 @@

namespace sherpa_onnx {

void Hypotheses::Add(Hypothesis hyp) {
void Hypotheses::Add(Hypothesis hyp, bool use_ctc /*= false */) {
auto key = hyp.Key();
auto it = hyps_dict_.find(key);
if (it == hyps_dict_.end()) {
hyps_dict_[key] = std::move(hyp);
} else {
it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob);
if (use_ctc) {
it->second.log_prob_b =
LogAdd<double>()(it->second.log_prob_b, hyp.log_prob_b);
it->second.log_prob_nb =
LogAdd<double>()(it->second.log_prob_nb, hyp.log_prob_nb);
} else {
it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob);
}
}
}

// Return the hypothesis with the largest total score.
//
// If length_norm is true, each hypothesis' score is divided by the length
// of its token sequence before comparison.  use_ctc selects which acoustic
// score is used (see Hypothesis::TotalLogProb).
//
// Precondition: the dictionary is non-empty (std::max_element on an empty
// range would dereference end()).
Hypothesis Hypotheses::GetMostProbable(bool length_norm,
                                       bool use_ctc /*= false */) const {
  if (length_norm == false) {
    return std::max_element(
               hyps_dict_.begin(), hyps_dict_.end(),
               [use_ctc](const auto &left, const auto &right) -> bool {
                 return left.second.TotalLogProb(use_ctc) <
                        right.second.TotalLogProb(use_ctc);
               })
        ->second;
  } else {
    // for length_norm is true
    return std::max_element(
               hyps_dict_.begin(), hyps_dict_.end(),
               [use_ctc](const auto &left, const auto &right) -> bool {
                 return left.second.TotalLogProb(use_ctc) /
                            left.second.ys.size() <
                        right.second.TotalLogProb(use_ctc) /
                            right.second.ys.size();
               })
        ->second;
  }
}

std::vector<Hypothesis> Hypotheses::GetTopK(int32_t k, bool length_norm) const {
std::vector<Hypothesis> Hypotheses::GetTopK(int32_t k, bool length_norm,
bool use_ctc /*= false*/) const {
k = std::max(k, 1);
k = std::min(k, Size());

std::vector<Hypothesis> all_hyps = Vec();

if (length_norm == false) {
std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
[](const auto &a, const auto &b) {
return a.TotalLogProb() > b.TotalLogProb();
[use_ctc](const auto &a, const auto &b) {
return a.TotalLogProb(use_ctc) >
b.TotalLogProb(use_ctc);
});
} else {
// for length_norm is true
std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
[](const auto &a, const auto &b) {
return a.TotalLogProb() / a.ys.size() >
b.TotalLogProb() / b.ys.size();
[use_ctc](const auto &a, const auto &b) {
return a.TotalLogProb(use_ctc) / a.ys.size() >
b.TotalLogProb(use_ctc) / b.ys.size();
});
}

Expand Down
33 changes: 28 additions & 5 deletions sherpa-onnx/csrc/hypothesis.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2023 Xiaomi Corporation
* Copyright (c) 2023-2024 Xiaomi Corporation
* Copyright (c) 2023 Pingfeng Luo
*
*/
Expand Down Expand Up @@ -48,6 +48,12 @@ struct Hypothesis {
// It contains only acoustic scores
double log_prob = 0;

// The total score of ys which ends with blank token in log space
double log_prob_b = 0;

// The total score of ys which ends with non blank token in log space
double log_prob_nb = -std::numeric_limits<float>::infinity();

// LM log prob if any.
double lm_log_prob = 0;

Expand All @@ -74,7 +80,21 @@ struct Hypothesis {
const ContextState *context_state = nullptr)
: ys(ys), log_prob(log_prob), context_state(context_state) {}

// Construct an empty hypothesis rooted at the given context-graph state
// (used as the initial hypothesis of CTC prefix beam search).
explicit Hypothesis(const ContextState *context_state)
    : context_state(context_state) {}

// Total score = acoustic score + LM/context score.
double TotalLogProb(bool use_ctc = false) const {
  return LogProb(use_ctc) + lm_log_prob;
}

// The acoustic log probability.
//
// For CTC (use_ctc == true) the prefix score is the log-sum-exp of the
// blank-ending and non-blank-ending scores; otherwise return the
// transducer-style log_prob.
double LogProb(bool use_ctc = false) const {
  if (use_ctc) {
    return LogAdd<double>()(log_prob_b, log_prob_nb);
  } else {
    return log_prob;
  }
}

// If two Hypotheses have the same `Key`, then they contain
// the same token sequence.
Expand Down Expand Up @@ -112,20 +132,23 @@ class Hypotheses {

// Add hyp to this object. If it already exists, its log_prob
// is updated with the given hyp using log-sum-exp.
void Add(Hypothesis hyp);
void Add(Hypothesis hyp, bool use_ctc = false);

// Get the hyp that has the largest log_prob.
// If length_norm is true, hyp's log_prob is divided by
// len(hyp.ys) before comparison.
Hypothesis GetMostProbable(bool length_norm) const;
Hypothesis GetMostProbable(bool length_norm, bool use_ctc = false) const;

// Get the k hyps that have the largest log_prob.
// If length_norm is true, hyp's log_prob is divided by
// len(hyp.ys) before comparison.
std::vector<Hypothesis> GetTopK(int32_t k, bool length_norm) const;
std::vector<Hypothesis> GetTopK(int32_t k, bool length_norm,
bool use_ctc = false) const;

int32_t Size() const { return hyps_dict_.size(); }

std::vector<Hypothesis> ToList() const { return Vec(); }

std::string ToString() const {
std::ostringstream os;
for (const auto &p : hyps_dict_) {
Expand Down
4 changes: 3 additions & 1 deletion sherpa-onnx/csrc/offline-ctc-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <vector>

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-stream.h"

namespace sherpa_onnx {

Expand Down Expand Up @@ -42,7 +43,8 @@ class OfflineCtcDecoder {
* @return Return a vector of size `N` containing the decoded results.
*/
virtual std::vector<OfflineCtcDecoderResult> Decode(
Ort::Value log_probs, Ort::Value log_probs_length) = 0;
Ort::Value log_probs, Ort::Value log_probs_length,
OfflineStream **ss = nullptr, int32_t n = 0) = 0;
};

} // namespace sherpa_onnx
Expand Down
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/offline-ctc-fst-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ OfflineCtcFstDecoder::OfflineCtcFstDecoder(
: config_(config), fst_(ReadGraph(config_.graph)) {}

std::vector<OfflineCtcDecoderResult> OfflineCtcFstDecoder::Decode(
Ort::Value log_probs, Ort::Value log_probs_length) {
Ort::Value log_probs, Ort::Value log_probs_length,
OfflineStream **ss /*= nullptr*/, int32_t n /*= 0*/) {
std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();

assert(static_cast<int32_t>(shape.size()) == 3);
Expand Down
6 changes: 4 additions & 2 deletions sherpa-onnx/csrc/offline-ctc-fst-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ class OfflineCtcFstDecoder : public OfflineCtcDecoder {
public:
explicit OfflineCtcFstDecoder(const OfflineCtcFstDecoderConfig &config);

std::vector<OfflineCtcDecoderResult> Decode(
Ort::Value log_probs, Ort::Value log_probs_length) override;
std::vector<OfflineCtcDecoderResult> Decode(Ort::Value log_probs,
Ort::Value log_probs_length,
OfflineStream **ss = nullptr,
int32_t n = 0) override;

private:
OfflineCtcFstDecoderConfig config_;
Expand Down
5 changes: 3 additions & 2 deletions sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h
// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation

Expand All @@ -13,7 +13,8 @@
namespace sherpa_onnx {

std::vector<OfflineCtcDecoderResult> OfflineCtcGreedySearchDecoder::Decode(
Ort::Value log_probs, Ort::Value log_probs_length) {
Ort::Value log_probs, Ort::Value log_probs_length,
OfflineStream **ss /*= nullptr*/, int32_t n /*= 0*/) {
std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
int32_t batch_size = static_cast<int32_t>(shape[0]);
int32_t num_frames = static_cast<int32_t>(shape[1]);
Expand Down
6 changes: 4 additions & 2 deletions sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ class OfflineCtcGreedySearchDecoder : public OfflineCtcDecoder {
explicit OfflineCtcGreedySearchDecoder(int32_t blank_id)
: blank_id_(blank_id) {}

std::vector<OfflineCtcDecoderResult> Decode(
Ort::Value log_probs, Ort::Value log_probs_length) override;
std::vector<OfflineCtcDecoderResult> Decode(Ort::Value log_probs,
Ort::Value log_probs_length,
OfflineStream **ss = nullptr,
int32_t n = 0) override;

private:
int32_t blank_id_;
Expand Down
132 changes: 132 additions & 0 deletions sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

// One time step of CTC prefix beam search for a single utterance.
//
// p_log_probs points at the log-probability row of the current frame
// (vocab_size entries).  hyps holds the active prefixes from the previous
// frame.  Returns the max_active_paths best prefixes after consuming this
// frame, ranked by LogAdd(log_prob_b, log_prob_nb) + lm_log_prob.
//
// context_graph (may be nullptr) provides hotword / shallow-fusion scores
// that are applied whenever a prefix is extended by a new token.
static std::vector<Hypothesis> StepWorker(const float *p_log_probs,
                                          std::vector<Hypothesis> &hyps,
                                          int32_t blank_id, int32_t vocab_size,
                                          int32_t max_active_paths,
                                          const ContextGraph *context_graph) {
  // Only the max_active_paths highest-scoring tokens of this frame are
  // considered for expansion.
  auto topk = TopkIndex(p_log_probs, vocab_size, max_active_paths);
  Hypotheses next_hyps;
  for (auto &hyp : hyps) {
    for (auto k : topk) {
      Hypothesis new_hyp = hyp;
      int32_t new_token = k;
      float log_prob = p_log_probs[k];
      bool update_prefix = false;
      if (new_token == blank_id) {
        // Case 0: *a + ε => *a
        //         *aε + ε => *a
        // Prefix does not change, update log_prob of blank
        new_hyp.log_prob_nb = -std::numeric_limits<float>::infinity();
        new_hyp.log_prob_b = hyp.LogProb(true) + log_prob;
        // use_ctc must be true so that duplicate prefixes are merged on
        // (log_prob_b, log_prob_nb); merging on the unused log_prob field
        // would silently drop probability mass.
        next_hyps.Add(std::move(new_hyp), true);
      } else if (hyp.ys.size() > 0 && hyp.ys.back() == new_token) {
        // Case 1: *a + a => *a
        // Prefix does not change, update log_prob of non_blank
        new_hyp.log_prob_nb = hyp.log_prob_nb + log_prob;
        new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();
        next_hyps.Add(std::move(new_hyp), true);

        // Case 2: *aε + a => *aa
        // Prefix changes, update log_prob of blank
        new_hyp = hyp;
        new_hyp.ys.push_back(new_token);
        new_hyp.log_prob_nb = hyp.log_prob_b + log_prob;
        new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();
        update_prefix = true;
      } else {
        // Case 3: *a + b => *ab, *aε + b => *ab
        // Prefix changes, update log_prob of non_blank
        // Caution: DO NOT use append, as clone is shallow copy
        new_hyp.ys.push_back(new_token);
        new_hyp.log_prob_nb = hyp.LogProb(true) + log_prob;
        new_hyp.log_prob_b = -std::numeric_limits<float>::infinity();
        update_prefix = true;
      }

      if (update_prefix) {
        // The prefix grew by one token: let the context graph (hotwords /
        // shallow fusion) update the LM-style score for the new token.
        float lm_log_prob = hyp.lm_log_prob;
        if (context_graph != nullptr && hyp.context_state != nullptr) {
          auto context_res =
              context_graph->ForwardOneStep(hyp.context_state, new_token);
          lm_log_prob = lm_log_prob + std::get<0>(context_res);
          new_hyp.context_state = std::get<1>(context_res);
        }
        new_hyp.lm_log_prob = lm_log_prob;
        next_hyps.Add(std::move(new_hyp), true);
      }
    }
  }
  // Keep the best max_active_paths prefixes; no length norm, CTC scoring.
  return next_hyps.GetTopK(max_active_paths, false, true);
}

// Run CTC prefix beam search over a batch of utterances.
//
// log_probs: 3-D float tensor of shape (batch, num_frames, vocab_size)
//            containing per-frame log probabilities.
// log_probs_length: 1-D int64 tensor with the number of valid frames per
//            utterance.
// ss/n: optional array of n streams used to fetch per-stream context
//       graphs (hotwords).  NOTE(review): n is never validated against
//       batch_size; presumably the caller guarantees n == batch_size
//       whenever ss != nullptr -- confirm.
std::vector<OfflineCtcDecoderResult> OfflineCtcPrefixBeamSearchDecoder::Decode(
    Ort::Value log_probs, Ort::Value log_probs_length,
    OfflineStream **ss /*= nullptr*/, int32_t n /*= 0*/) {
  std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
  int32_t batch_size = static_cast<int32_t>(shape[0]);
  int32_t num_frames = static_cast<int32_t>(shape[1]);
  int32_t vocab_size = static_cast<int32_t>(shape[2]);

  const int64_t *p_log_probs_length = log_probs_length.GetTensorData<int64_t>();

  std::vector<OfflineCtcDecoderResult> ans;
  ans.reserve(batch_size);

  // cur[b] holds the active hypotheses (prefixes) of utterance b.
  std::vector<std::vector<Hypothesis>> cur;
  cur.reserve(batch_size);

  std::vector<ContextGraphPtr> context_graphs(batch_size, nullptr);

  // Seed each utterance with a single empty hypothesis, rooted at the
  // stream's context graph when hotwords are in use.
  for (int32_t i = 0; i < batch_size; ++i) {
    const ContextState *context_state = nullptr;
    if (ss != nullptr) {
      context_graphs[i] = ss[i]->GetContextGraph();
      if (context_graphs[i] != nullptr)
        context_state = context_graphs[i]->Root();
    }
    Hypothesis hyp(context_state);
    cur.emplace_back(std::vector<Hypothesis>({hyp}));
  }

  // Advance every utterance frame by frame; utterances shorter than
  // num_frames simply stop being updated once t reaches their length.
  for (int32_t t = 0; t < num_frames; ++t) {
    for (int32_t b = 0; b < batch_size; ++b) {
      if (t < p_log_probs_length[b]) {
        // Pointer to the (b, t, :) row of the log-prob tensor.
        const float *p_log_probs = log_probs.GetTensorData<float>() +
                                   b * num_frames * vocab_size + t * vocab_size;
        cur[b] = StepWorker(p_log_probs, cur[b], blank_id_, vocab_size,
                            max_active_paths_, context_graphs[b].get());
        // for (auto &x : cur[b]) {
        //   SHERPA_ONNX_LOGE("step : %d, key : %s, ac : %f, lm : %f", t,
        //                    x.Key().c_str(), x.LogProb(true), x.lm_log_prob);
        // }
        // SHERPA_ONNX_LOGE("\n");
      }
    }
  }

  // Pick the best hypothesis per utterance (CTC scoring, no length norm)
  // and emit its token sequence.
  for (int32_t b = 0; b != batch_size; ++b) {
    Hypotheses hyps(cur[b]);
    Hypothesis best_hyp = hyps.GetMostProbable(false, true);
    OfflineCtcDecoderResult r;
    r.tokens = best_hyp.ys;
    ans.push_back(std::move(r));
  }
  return ans;
}

} // namespace sherpa_onnx
Loading
Loading