Skip to content

Commit

Permalink
Replace Clone() with View()
Browse files Browse the repository at this point in the history
  • Loading branch information
hiedean committed Nov 19, 2023
1 parent ac00eda commit 83d0bc0
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
// now cur_encoder_out is of shape (num_hyps, joiner_dim)

Ort::Value logit = model_->RunJoiner(
std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
std::move(cur_encoder_out), View(&decoder_out));

float *p_logit = logit.GetTensorMutableData<float>();
LogSoftmax(p_logit, vocab_size, num_hyps);
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-conformer-transducer-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ Ort::Value OnlineConformerTransducerModel::RunDecoder(
Ort::Value OnlineConformerTransducerModel::RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) {
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
std::move(decoder_out)};
View(&decoder_out)};
auto logit =
joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
joiner_input.size(), joiner_output_names_ptr_.data(),
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-lstm-transducer-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
Ort::Value OnlineLstmTransducerModel::RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) {
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
std::move(decoder_out)};
View(&decoder_out)};
auto logit =
joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
joiner_input.size(), joiner_output_names_ptr_.data(),
Expand Down
8 changes: 4 additions & 4 deletions sherpa-onnx/csrc/online-rnn-lm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ class OnlineRnnLM::Impl {
return {std::move(out[0]), std::move(next_states)};
}

std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() const {
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() {
std::vector<Ort::Value> ans;
ans.reserve(init_states_.size());
for (const auto &s : init_states_) {
ans.emplace_back(Clone(allocator_, &s));
for (auto &s : init_states_) {
ans.emplace_back(View(&s));
}
return {std::move(Clone(allocator_, &init_scores_.value)), std::move(ans)};
return {std::move(View(&init_scores_.value)), std::move(ans)};
}

private:
Expand Down
8 changes: 5 additions & 3 deletions sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,11 @@ void OnlineTransducerGreedySearchDecoder::Decode(
}
if (is_batch_decoder_out_cached) {
auto &r = result->front();
std::vector<int64_t> decoder_out_shape = r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
std::vector<int64_t> decoder_out_shape =
r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
decoder_out_shape[0] = batch_size;
decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(), decoder_out_shape.data(), decoder_out_shape.size());
decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(),
decoder_out_shape.data(), decoder_out_shape.size());
UseCachedDecoderOut(*result, &decoder_out);
} else {
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
Expand All @@ -112,7 +114,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
Ort::Value cur_encoder_out =
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
Ort::Value logit = model_->RunJoiner(
std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
std::move(cur_encoder_out), View(&decoder_out));

const float *p_logit = logit.GetTensorData<float>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
cur_encoder_out =
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
Ort::Value logit = model_->RunJoiner(
std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
std::move(cur_encoder_out), View(&decoder_out));

float *p_logit = logit.GetTensorMutableData<float>();
LogSoftmax(p_logit, vocab_size, num_hyps);
Expand Down
12 changes: 6 additions & 6 deletions sherpa-onnx/csrc/online-wenet-ctc-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ class OnlineWenetCtcModel::Impl {
std::array<Ort::Value, 6> inputs = {std::move(x),
View(&offset),
View(&required_cache_size_tensor_),
std::move(attn_cache),
std::move(conv_cache),
std::move(attn_mask)};
View(&attn_cache),
View(&conv_cache),
View(&attn_mask)};

auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
Expand Down Expand Up @@ -105,11 +105,11 @@ class OnlineWenetCtcModel::Impl {
// - attn_cache
// - conv_cache
// - offset
std::vector<Ort::Value> GetInitStates() const {
std::vector<Ort::Value> GetInitStates() {
std::vector<Ort::Value> ans;
ans.reserve(3);
ans.push_back(Clone(Allocator(), &attn_cache_));
ans.push_back(Clone(Allocator(), &conv_cache_));
ans.push_back(View(&attn_cache_));
ans.push_back(View(&conv_cache_));

int64_t offset_shape = 1;

Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-zipformer-transducer-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ Ort::Value OnlineZipformerTransducerModel::RunDecoder(
Ort::Value OnlineZipformerTransducerModel::RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) {
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
std::move(decoder_out)};
View(&decoder_out)};
auto logit =
joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
joiner_input.size(), joiner_output_names_ptr_.data(),
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-zipformer2-transducer-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ Ort::Value OnlineZipformer2TransducerModel::RunDecoder(
Ort::Value OnlineZipformer2TransducerModel::RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) {
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
std::move(decoder_out)};
View(&decoder_out)};
auto logit =
joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
joiner_input.size(), joiner_output_names_ptr_.data(),
Expand Down

0 comments on commit 83d0bc0

Please sign in to comment.