diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index e90426bdc..c2fc1103d 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -89,9 +89,24 @@ void OnlineTransducerGreedySearchDecoder::Decode( int32_t num_frames = static_cast(encoder_out_shape[1]); int32_t vocab_size = model_->VocabSize(); - Ort::Value decoder_input = model_->BuildDecoderInput(*result); - Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); - UseCachedDecoderOut(*result, &decoder_out); + Ort::Value decoder_out{nullptr}; + bool is_batch_decoder_out_cached = true; + for (const auto &r : *result) { + if (!r.decoder_out) { + is_batch_decoder_out_cached = false; + break; + } + } + if (is_batch_decoder_out_cached) { + auto &r = result->front(); + std::vector decoder_out_shape = r.decoder_out.GetTensorTypeAndShapeInfo().GetShape(); + decoder_out_shape[0] = batch_size; + decoder_out = Ort::Value::CreateTensor(model_->Allocator(), decoder_out_shape.data(), decoder_out_shape.size()); + UseCachedDecoderOut(*result, &decoder_out); + } else { + Ort::Value decoder_input = model_->BuildDecoderInput(*result); + decoder_out = model_->RunDecoder(std::move(decoder_input)); + } for (int32_t t = 0; t != num_frames; ++t) { Ort::Value cur_encoder_out =