Skip to content

Commit

Permalink
Judge before UseCachedDecoderOut (#431)
Browse files Browse the repository at this point in the history
Co-authored-by: hiedean <[email protected]>
  • Loading branch information
HieDean and hiedean authored Nov 17, 2023
1 parent eeda1e1 commit 1a6a41e
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,24 @@ void OnlineTransducerGreedySearchDecoder::Decode(
int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
int32_t vocab_size = model_->VocabSize();

Ort::Value decoder_input = model_->BuildDecoderInput(*result);
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
UseCachedDecoderOut(*result, &decoder_out);
Ort::Value decoder_out{nullptr};
bool is_batch_decoder_out_cached = true;
for (const auto &r : *result) {
if (!r.decoder_out) {
is_batch_decoder_out_cached = false;
break;
}
}
if (is_batch_decoder_out_cached) {
auto &r = result->front();
std::vector<int64_t> decoder_out_shape = r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
decoder_out_shape[0] = batch_size;
decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(), decoder_out_shape.data(), decoder_out_shape.size());
UseCachedDecoderOut(*result, &decoder_out);
} else {
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
decoder_out = model_->RunDecoder(std::move(decoder_input));
}

for (int32_t t = 0; t != num_frames; ++t) {
Ort::Value cur_encoder_out =
Expand Down

0 comments on commit 1a6a41e

Please sign in to comment.