Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Oct 8, 2023
1 parent fdc5671 commit 70ff69e
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 14 deletions.
4 changes: 2 additions & 2 deletions cmake/eigen.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ function(download_eigen)
if(EXISTS ${f})
set(eigen_URL "${f}")
file(TO_CMAKE_PATH "${eigen_URL}" eigen_URL)
message(STATUS "Found local downloaded eigen: ${eigen_URL}")
set(eigen_URL2)
break()
endif()
Expand All @@ -34,13 +35,12 @@ function(download_eigen)

FetchContent_GetProperties(eigen)
if(NOT eigen_POPULATED)
message(STATUS "Downloading eigen ${eigen_URL}")
message(STATUS "Downloading eigen from ${eigen_URL}")
FetchContent_Populate(eigen)
endif()
message(STATUS "eigen is downloaded to ${eigen_SOURCE_DIR}")
message(STATUS "eigen's binary dir is ${eigen_BINARY_DIR}")


add_subdirectory(${eigen_SOURCE_DIR} ${eigen_BINARY_DIR} EXCLUDE_FROM_ALL)
endfunction()

Expand Down
5 changes: 3 additions & 2 deletions cmake/kaldifst.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ function(download_kaldifst)
set(kaldifst_HASH "SHA256=79280c0bb08b5ed1a2ab7c21320a2b071f1f0eb10d2f047e8d6f027f0d32b4d2")

# If you don't have access to the Internet,
# please pre-download kaldi_native_io
# please pre-download kaldifst
set(possible_file_locations
$ENV{HOME}/Downloads/kaldifst-1.7.6.tar.gz
${PROJECT_SOURCE_DIR}/kaldifst-1.7.6.tar.gz
Expand All @@ -19,6 +19,7 @@ function(download_kaldifst)
if(EXISTS ${f})
set(kaldifst_URL "${f}")
file(TO_CMAKE_PATH "${kaldifst_URL}" kaldifst_URL)
message(STATUS "Found local downloaded kaldifst: ${kaldifst_URL}")
set(kaldifst_URL2)
break()
endif()
Expand All @@ -34,7 +35,7 @@ function(download_kaldifst)

FetchContent_GetProperties(kaldifst)
if(NOT kaldifst_POPULATED)
message(STATUS "Downloading kaldifst ${kaldifst_URL}")
message(STATUS "Downloading kaldifst from ${kaldifst_URL}")
FetchContent_Populate(kaldifst)
endif()
message(STATUS "kaldifst is downloaded to ${kaldifst_SOURCE_DIR}")
Expand Down
16 changes: 14 additions & 2 deletions sherpa-onnx/csrc/offline-ctc-fst-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,16 @@
#include "fst/fstlib.h"
#include "kaldi-decoder/csrc/decodable-ctc.h"
#include "kaldi-decoder/csrc/eigen.h"
#include "kaldi-decoder/csrc/faster-decoder.h"
#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

// this function is copied from kaldi
// This function is copied from kaldi.
//
// @param filename Path to a StdVectorFst or StdConstFst graph
// @return The caller should free the returned pointer using `delete` to
// avoid a memory leak.
static fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
// read decoding network FST
std::ifstream is(filename);
Expand All @@ -33,7 +38,7 @@ static fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
}
fst::FstReadOptions ropts("<unspecified>", &hdr);

fst::Fst<fst::StdArc> *decode_fst = NULL;
fst::Fst<fst::StdArc> *decode_fst = nullptr;

if (hdr.FstType() == "vector") {
decode_fst = fst::VectorFst<fst::StdArc>::Read(is, ropts);
Expand All @@ -52,6 +57,13 @@ static fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
}
}

/**
* @param decoder Pointer to the FasterDecoder used to decode the input.
* @param p Pointer to a 2-d array of shape (num_frames, vocab_size)
* @param num_frames Number of rows in the 2-d array.
* @param vocab_size Number of columns in the 2-d array.
* @return Return the decoded result.
*/
static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder,
const float *p, int32_t num_frames,
int32_t vocab_size) {
Expand Down
6 changes: 0 additions & 6 deletions sherpa-onnx/csrc/offline-ctc-fst-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <vector>

#include "fst/fst.h"
#include "kaldi-decoder/csrc/faster-decoder.h"
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
Expand All @@ -23,11 +22,6 @@ class OfflineCtcFstDecoder : public OfflineCtcDecoder {
std::vector<OfflineCtcDecoderResult> Decode(
Ort::Value log_probs, Ort::Value log_probs_length) override;

private:
// Decode a single utterance
OfflineCtcDecoderResult Decode(const float *p, int32_t num_frames,
int32_t vocab_size) const;

private:
OfflineCtcFstDecoderConfig config_;

Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
model_->FeatureNormalizationMethod();

if (!config_.ctc_fst_decoder_config.graph.empty()) {
// TODO(fangjun): Support reading the graph from asset_manager on
// Android
decoder_ = std::make_unique<OfflineCtcFstDecoder>(
config_.ctc_fst_decoder_config);
} else if (config_.decoding_method == "greedy_search") {
Expand Down
5 changes: 4 additions & 1 deletion sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,18 @@

namespace sherpa_onnx {

// for https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn
// for
// https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py
struct OfflineZipformerCtcModelConfig {
std::string model;

OfflineZipformerCtcModelConfig() = default;

explicit OfflineZipformerCtcModelConfig(const std::string &model)
: model(model) {}

void Register(ParseOptions *po);

bool Validate() const;

std::string ToString() const;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h
// sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation

Expand All @@ -13,6 +13,7 @@ namespace sherpa_onnx {
void PybindOfflineZipformerCtcModelConfig(py::module *m) {
using PyClass = OfflineZipformerCtcModelConfig;
py::class_<PyClass>(*m, "OfflineZipformerCtcModelConfig")
.def(py::init<>())
.def(py::init<const std::string &>(), py::arg("model"))
.def_readwrite("model", &PyClass::model)
.def("__str__", &PyClass::ToString);
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
OfflineParaformerModelConfig,
OfflineTdnnModelConfig,
OfflineWhisperModelConfig,
OfflineZipformerCtcModelConfig,
)
from _sherpa_onnx import OfflineRecognizer as _Recognizer
from _sherpa_onnx import (
Expand Down

0 comments on commit 70ff69e

Please sign in to comment.