Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support text normalization via rule FST #407

Merged
merged 2 commits into from
Nov 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,11 @@ message(STATUS "SHERPA_ONNX_ENABLE_C_API ${SHERPA_ONNX_ENABLE_C_API}")
message(STATUS "SHERPA_ONNX_ENABLE_WEBSOCKET ${SHERPA_ONNX_ENABLE_WEBSOCKET}")
message(STATUS "SHERPA_ONNX_ENABLE_GPU ${SHERPA_ONNX_ENABLE_GPU}")

set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
# Pick C++14 as the default language standard, but only when no standard has
# been chosen yet, so a value that is already set (for example via
# -DCMAKE_CXX_STANDARD=17 or by an enclosing project) is left untouched.
if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
endif()
# Request the plain ISO dialect (-std=c++NN) instead of the compiler-specific
# one (-std=gnu++NN).
set(CMAKE_CXX_EXTENSIONS OFF)
# Report the standard that ended up selected during configuration.
message(STATUS "C++ Standard version: ${CMAKE_CXX_STANDARD}")

include(CheckIncludeFileCXX)
check_include_file_cxx(alsa/asoundlib.h SHERPA_ONNX_HAS_ALSA)
Expand Down
16 changes: 8 additions & 8 deletions cmake/kaldifst.cmake
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
function(download_kaldifst)
include(FetchContent)

set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.6.tar.gz")
set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.6.tar.gz")
set(kaldifst_HASH "SHA256=79280c0bb08b5ed1a2ab7c21320a2b071f1f0eb10d2f047e8d6f027f0d32b4d2")
set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.8.tar.gz")
set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.8.tar.gz")
set(kaldifst_HASH "SHA256=94613923568ef9a240ba1059b8b9dfe3082daad794934635d99e66248a6687b5")

# If you don't have access to the Internet,
# please pre-download kaldifst
set(possible_file_locations
$ENV{HOME}/Downloads/kaldifst-1.7.6.tar.gz
${PROJECT_SOURCE_DIR}/kaldifst-1.7.6.tar.gz
${PROJECT_BINARY_DIR}/kaldifst-1.7.6.tar.gz
/tmp/kaldifst-1.7.6.tar.gz
/star-fj/fangjun/download/github/kaldifst-1.7.6.tar.gz
$ENV{HOME}/Downloads/kaldifst-1.7.8.tar.gz
${PROJECT_SOURCE_DIR}/kaldifst-1.7.8.tar.gz
${PROJECT_BINARY_DIR}/kaldifst-1.7.8.tar.gz
/tmp/kaldifst-1.7.8.tar.gz
/star-fj/fangjun/download/github/kaldifst-1.7.8.tar.gz
)

foreach(f IN LISTS possible_file_locations)
Expand Down
2 changes: 1 addition & 1 deletion cmake/onnxruntime-linux-aarch64.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ if(NOT BUILD_SHARED_LIBS)
message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
endif()

set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.16.1/onnxruntime-linux-aarch64-1.16.1.tgz")
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.16.1/onnxruntime-linux-aarch64-1.16.1.tgz")
set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/onnxruntime-libs/resolve/main/onnxruntime-linux-aarch64-1.16.1.tgz")
set(onnxruntime_HASH "SHA256=f10851b62eb44f9e811134737e7c6edd15733d2c1549cb6ce403808e9c047385")

Expand Down
2 changes: 1 addition & 1 deletion cmake/onnxruntime-linux-x86_64-gpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ if(NOT SHERPA_ONNX_ENABLE_GPU)
message(FATAL_ERROR "This file is for NVIDIA GPU only. Given SHERPA_ONNX_ENABLE_GPU: ${SHERPA_ONNX_ENABLE_GPU}")
endif()

set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.16.1/onnxruntime-linux-x64-gpu-1.16.1.tgz")
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.16.1/onnxruntime-linux-x64-gpu-1.16.1.tgz")
set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/onnxruntime-libs/resolve/main/onnxruntime-linux-x64-gpu-1.16.1.tgz")
set(onnxruntime_HASH "SHA256=474d5d74b588d54aa3e167f38acc9b1b8d20c292d0db92299bdc33a81eb4492d")

Expand Down
2 changes: 1 addition & 1 deletion cmake/onnxruntime-linux-x86_64.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ if(NOT BUILD_SHARED_LIBS)
message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
endif()

set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.16.1/onnxruntime-linux-x64-1.16.1.tgz")
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.16.1/onnxruntime-linux-x64-1.16.1.tgz")
set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/onnxruntime-libs/resolve/main/onnxruntime-linux-x64-1.16.1.tgz")
set(onnxruntime_HASH "SHA256=53a0f03f71587ed602e99e82773132fc634b74c2d227316fbfd4bf67181e72ed")

Expand Down
2 changes: 1 addition & 1 deletion cmake/onnxruntime-osx-arm64.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ if(NOT BUILD_SHARED_LIBS)
message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
endif()

set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.16.1/onnxruntime-osx-arm64-1.16.1.tgz")
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.16.1/onnxruntime-osx-arm64-1.16.1.tgz")
set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/onnxruntime-libs/resolve/main/onnxruntime-osx-arm64-1.16.1.tgz")
set(onnxruntime_HASH "SHA256=56ca6b8de3a220ea606c2067ba65d11dfa6e4f722e01ac7dc75f7152b81445e0")

Expand Down
2 changes: 1 addition & 1 deletion cmake/onnxruntime-osx-universal.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ if(NOT BUILD_SHARED_LIBS)
message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
endif()

set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.16.1/onnxruntime-osx-universal2-1.16.1.tgz")
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.16.1/onnxruntime-osx-universal2-1.16.1.tgz")
set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/onnxruntime-libs/resolve/main/onnxruntime-osx-universal2-1.16.1.tgz")
set(onnxruntime_HASH "SHA256=e8568a4a3f602c25ea7c3bbd2f085340dff5bb68fa7c859fd763d944105e3d76")

Expand Down
2 changes: 1 addition & 1 deletion cmake/onnxruntime-osx-x86_64.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ if(NOT BUILD_SHARED_LIBS)
message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
endif()

set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.16.1/onnxruntime-osx-x86_64-1.16.1.tgz")
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.16.1/onnxruntime-osx-x86_64-1.16.1.tgz")
set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/onnxruntime-libs/resolve/main/onnxruntime-osx-x86_64-1.16.1.tgz")
set(onnxruntime_HASH "SHA256=0b8ae24401a8f75e1c4f75257d4eaeb1b6d44055e027df4aa4a84e67e0f9b9e3")

Expand Down
46 changes: 41 additions & 5 deletions sherpa-onnx/csrc/offline-tts-vits-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,50 @@
#include "android/asset_manager_jni.h"
#endif

#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-tts-impl.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model.h"
#include "sherpa-onnx/csrc/text-utils.h"

namespace sherpa_onnx {

class OfflineTtsVitsImpl : public OfflineTtsImpl {
public:
explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config)
: model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
: config_(config),
model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
model_->Punctuations(), model_->Language(), config.model.debug,
model_->IsPiper()) {}
model_->IsPiper()) {
if (!config.rule_fsts.empty()) {
std::vector<std::string> files;
SplitStringToVector(config.rule_fsts, ",", false, &files);
tn_list_.reserve(files.size());
for (const auto &f : files) {
if (config.model.debug) {
SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
}
tn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
}
}
}

#if __ANDROID_API__ >= 9
OfflineTtsVitsImpl(AAssetManager *mgr, const OfflineTtsConfig &config)
: model_(std::make_unique<OfflineTtsVitsModel>(mgr, config.model)),
: config_(config),
model_(std::make_unique<OfflineTtsVitsModel>(mgr, config.model)),
lexicon_(mgr, config.model.vits.lexicon, config.model.vits.tokens,
model_->Punctuations(), model_->Language(), config.model.debug,
model_->IsPiper()) {}
model_->IsPiper()) {
if (!config.rule_fsts.empty()) {
SHERPA_ONNX_LOGE("TODO(fangjun): Implement rule FST for Android");
}
}
#endif

GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
GeneratedAudio Generate(const std::string &_text, int64_t sid = 0,
float speed = 1.0) const override {
int32_t num_speakers = model_->NumSpeakers();
if (num_speakers == 0 && sid != 0) {
Expand All @@ -55,6 +75,20 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
sid = 0;
}

std::string text = _text;
if (config_.model.debug) {
SHERPA_ONNX_LOGE("Raw text: %s", text.c_str());
}

if (!tn_list_.empty()) {
for (const auto &tn : tn_list_) {
text = tn->Normalize(text);
if (config_.model.debug) {
SHERPA_ONNX_LOGE("After normalizing: %s", text.c_str());
}
}
}

std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
if (x.empty()) {
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
Expand Down Expand Up @@ -98,7 +132,9 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
}

private:
OfflineTtsConfig config_;
std::unique_ptr<OfflineTtsVitsModel> model_;
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> tn_list_;
Lexicon lexicon_;
};

Expand Down
31 changes: 28 additions & 3 deletions sherpa-onnx/csrc/offline-tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,44 @@

#include <string>

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-tts-impl.h"
#include "sherpa-onnx/csrc/text-utils.h"

namespace sherpa_onnx {

void OfflineTtsConfig::Register(ParseOptions *po) { model.Register(po); }
void OfflineTtsConfig::Register(ParseOptions *po) {
model.Register(po);

bool OfflineTtsConfig::Validate() const { return model.Validate(); }
// Expose rule_fsts as the --tts-rule-fsts command-line option.
// The adjacent string literals are concatenated by the compiler, so each
// piece except the last must end with a space.
po->Register("tts-rule-fsts", &rule_fsts,
             "If not empty, it contains a list of rule FST filenames. "
             "Multiple filenames are separated by a comma and they are "
             "applied from left to right. An example value: "
             "rule1.fst,rule2.fst,rule3.fst");
}

// Checks the configuration before it is used.
//
// When rule_fsts is non-empty it is split on "," and every resulting entry
// must be found by FileExists(); the first missing entry is logged and makes
// the function return false. In every other case the result is whatever
// model.Validate() returns.
bool OfflineTtsConfig::Validate() const {
  if (rule_fsts.empty()) {
    // No rule FSTs were given, so only the model config needs checking.
    return model.Validate();
  }

  std::vector<std::string> filenames;
  SplitStringToVector(rule_fsts, ",", false, &filenames);

  for (const auto &filename : filenames) {
    if (FileExists(filename)) {
      continue;
    }

    SHERPA_ONNX_LOGE("Rule fst %s does not exist. ", filename.c_str());
    return false;
  }

  return model.Validate();
}

std::string OfflineTtsConfig::ToString() const {
std::ostringstream os;

os << "OfflineTtsConfig(";
os << "model=" << model.ToString() << ")";
os << "model=" << model.ToString() << ", ";
os << "rule_fsts=\"" << rule_fsts << "\")";

return os.str();
}
Expand Down
11 changes: 9 additions & 2 deletions sherpa-onnx/csrc/offline-tts.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,17 @@ namespace sherpa_onnx {

struct OfflineTtsConfig {
OfflineTtsModelConfig model;
// If not empty, it contains a list of rule FST filenames.
// Filenames are separated by a comma.
// Example value: rule1.fst,rule2.fst,rule3.fst
//
// If there are multiple rules, they are applied from left to right.
std::string rule_fsts;

OfflineTtsConfig() = default;
explicit OfflineTtsConfig(const OfflineTtsModelConfig &model)
: model(model) {}
OfflineTtsConfig(const OfflineTtsModelConfig &model,
const std::string &rule_fsts)
: model(model), rule_fsts(rule_fsts) {}

void Register(ParseOptions *po);
bool Validate() const;
Expand Down
6 changes: 5 additions & 1 deletion sherpa-onnx/python/csrc/offline-tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/offline-tts.h"

#include <string>

#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/python/csrc/offline-tts-model-config.h"

Expand All @@ -28,8 +30,10 @@ static void PybindOfflineTtsConfig(py::module *m) {
using PyClass = OfflineTtsConfig;
py::class_<PyClass>(*m, "OfflineTtsConfig")
.def(py::init<>())
.def(py::init<const OfflineTtsModelConfig &>(), py::arg("model"))
.def(py::init<const OfflineTtsModelConfig &, const std::string &>(),
py::arg("model"), py::arg("rule_fsts") = "")
.def_readwrite("model", &PyClass::model)
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
.def("__str__", &PyClass::ToString);
}

Expand Down
Loading