diff --git a/CMakeLists.txt b/CMakeLists.txt index 0524b96..9da52e7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,17 +33,6 @@ if (MSVC) # some warnings related with fst add_compile_options(/wd4018 /wd4244 /wd4267 /wd4291 /wd4305) - set(CompilerFlags - CMAKE_CXX_FLAGS - CMAKE_CXX_FLAGS_DEBUG - CMAKE_CXX_FLAGS_RELEASE - CMAKE_C_FLAGS - CMAKE_C_FLAGS_DEBUG - CMAKE_C_FLAGS_RELEASE - ) - foreach(CompilerFlag ${CompilerFlags}) - string(REPLACE "/MD" "/MT" ${CompilerFlag} "${${CompilerFlag}}") - endforeach() elseif(APPLE) set(CMAKE_INSTALL_RPATH "@loader_path") else() diff --git a/extensions/decoder/decoder.cpp b/extensions/decoder/decoder.cpp index faee2e8..5865664 100644 --- a/extensions/decoder/decoder.cpp +++ b/extensions/decoder/decoder.cpp @@ -1716,6 +1716,14 @@ void pybind_training_graph_compiler(py::module &m) { TrainingGraphCompiler gc(trans_model, ctx_dep, mf, disambig_syms, opts); return gc; })) + .def(py::init([](const TransitionModel &trans_model, const ContextDependency &ctx_dep, + const std::string lex_rxfilename, const std::vector<int32> &disambig_syms, + const TrainingGraphCompilerOptions &opts){ + VectorFst<StdArc> *lex_fst = fst::ReadFstKaldi(lex_rxfilename); + TrainingGraphCompiler *gc = new TrainingGraphCompiler(trans_model, ctx_dep, lex_fst, disambig_syms, opts); + return gc; + }), + py::return_value_policy::take_ownership) .def("CompileGraph", &PyClass::CompileGraph, "CompileGraph compiles a single training graph its input is a " @@ -1792,7 +1800,7 @@ void pybind_training_graph_compiler(py::module &m) { }, "This function creates FSTs from the text and calls CompileGraphs.", py::arg("transcripts"), - py::return_value_policy::reference); + py::return_value_policy::take_ownership); } } diff --git a/extensions/fstext/pybind_fstext.h b/extensions/fstext/pybind_fstext.h index 0f72130..3104a54 100644 --- a/extensions/fstext/pybind_fstext.h +++ b/extensions/fstext/pybind_fstext.h @@ -1132,22 +1132,22 @@ void pybind_const_fst_impl(py::module& m, const std::string& class_name, .def("Type", &PyClass::Type, "FST typename", py::return_value_policy::reference) .def("Copy", &PyClass::Copy, - "Get a copy of this VectorFst. See Fst<>::Copy() for further " + "Get a copy of this ConstFst.
See Fst<>::Copy() for further " "doc.", py::arg("safe") = false, py::return_value_policy::take_ownership) .def_static("Read", // clang-format off overload_cast_<std::istream&, const fst::FstReadOptions&>()(&PyClass::Read), // clang-format on - "Reads a VectorFst from an input stream, returning nullptr " + "Reads a ConstFst from an input stream, returning nullptr " "on error.", py::arg("strm"), py::arg("opts"), py::return_value_policy::take_ownership) .def_static("Read", overload_cast_<const std::string&>()(&PyClass::Read), - "Read a VectorFst from a file, returning nullptr on error; " + "Read a ConstFst from a file, returning nullptr on error; " "empty " "filename reads from standard input.", - py::arg("filename"), py::return_value_policy::take_ownership) + py::arg("filename"), py::return_value_policy::reference) .def("Write", // clang-format off (bool (PyClass::*)(std::ostream&, const fst::FstWriteOptions&)const)&PyClass::Write, diff --git a/extensions/ivector/ivector.cpp b/extensions/ivector/ivector.cpp index 74a6754..6fc97c6 100644 --- a/extensions/ivector/ivector.cpp +++ b/extensions/ivector/ivector.cpp @@ -853,6 +853,7 @@ void pybind_plda(py::module &m) { py::array_t<double> utterance_two_ivector ){ py::gil_scoped_release gil_release; + Vector<double> ivector_one_dbl; auto r_one = utterance_one_ivector.unchecked<1>(); ivector_one_dbl.Resize(r_one.shape(0)); @@ -869,6 +870,40 @@ void pybind_plda(py::module &m) { 1, ivector_two_dbl)); return score; + + }, + py::arg("utterance_one_ivector"), + py::arg("utterance_two_ivector")) + .def("log_likelihood_distance_vectorized", + []( + PyClass &plda, + py::array_t<double> utterance_one_ivector, + py::array_t<double> utterance_two_ivector + ){ + py::gil_scoped_release gil_release; + py::buffer_info buf1 = utterance_one_ivector.request(), buf2 = utterance_two_ivector.request(); + + auto r_one = utterance_one_ivector.unchecked<2>(); + auto r_two = utterance_two_ivector.unchecked<2>(); + auto result = py::array_t<double>(r_one.shape(0)); + py::buffer_info buf3 = result.request(); + double *ptr3 = static_cast<double *>(buf3.ptr); + for (py::size_t i = 0; i < r_one.shape(0); i++){ + Vector<double> ivector_one_dbl; + ivector_one_dbl.Resize(r_one.shape(1)); + Vector<double> ivector_two_dbl; + ivector_two_dbl.Resize(r_two.shape(1)); + for (py::size_t j = 0; j < r_one.shape(1); j++){ + ivector_one_dbl(j) = r_one(i, j); + ivector_two_dbl(j) = r_two(i, j); + + } + ptr3[i] = 1.0 / Exp(plda.LogLikelihoodRatio(ivector_one_dbl, + 1, + ivector_two_dbl)); + + } + return result; }, py::arg("utterance_one_ivector"), py::arg("utterance_two_ivector")) diff --git a/extensions/matrix/matrix.cpp b/extensions/matrix/matrix.cpp index 208b7b7..6908b31 100644 --- a/extensions/matrix/matrix.cpp +++ b/extensions/matrix/matrix.cpp @@ -7,7 +7,6 @@ #include "matrix/matrix-common.h" #include "matrix/sparse-matrix.h" #include -#include #include "util/pybind_util.h" using namespace kaldi; diff --git a/kalpy/decoder/data.py b/kalpy/decoder/data.py index fc3dcac..34850e7 100644 --- a/kalpy/decoder/data.py +++ b/kalpy/decoder/data.py @@ -36,8 +36,8 @@ def __iter__(self) -> typing.Generator[typing.Tuple[str, VectorFst]]: try: while not reader.Done(): utt = reader.Key() - fst = reader.Value() - decode_fst = VectorFst(fst) + decode_fst = VectorFst(reader.Value()) + reader.FreeCurrent() yield utt, decode_fst reader.Next() finally: diff --git a/kalpy/decoder/training_graphs.py b/kalpy/decoder/training_graphs.py index 30e3dfc..50f1ee1 100644 --- a/kalpy/decoder/training_graphs.py +++ b/kalpy/decoder/training_graphs.py @@ -138,11 +138,14 @@ def compiler(self): self._fst = pynini.Fst.read(str(self.lexicon_path))
else: self._fst = self.lexicon_compiler.fst + disambiguation_symbols = [] + if self.lexicon_compiler is not None and self.lexicon_compiler.disambiguation: + disambiguation_symbols = self.lexicon_compiler.disambiguation_symbols self._compiler = _TrainingGraphCompiler( self.transition_model, self.tree, VectorFst.from_pynini(self._fst), - self.disambiguation_symbols, + disambiguation_symbols, self.options, ) return self._compiler @@ -227,7 +230,9 @@ def export_graphs( writer.Close() logger.info(f"Done {num_done} utterances, errors on {num_error}.") - def compile_fst(self, transcript: str) -> typing.Optional[VectorFst]: + def compile_fst( + self, transcript: str, interjection_words: typing.List[str] = None + ) -> typing.Optional[VectorFst]: """ Compile a transcript to a training graph @@ -251,13 +256,90 @@ def compile_fst(self, transcript: str) -> typing.Optional[VectorFst]: state_threshold = 256 + 2 * lg_fst.num_states() lg_fst = pynini.determinize(lg_fst, nstate=state_threshold, weight=weight_threshold) lg_fst = VectorFst.from_pynini(lg_fst) + disambig_syms_in = ( + [] + if not self.lexicon_compiler.disambiguation + else self.lexicon_compiler.disambiguation_symbols + ) + lg_fst = fst_determinize_star(lg_fst, use_log=True) + fst_minimize_encoded(lg_fst) + fst_push_special(lg_fst) + clg_fst, disambig_out, ilabels = fst_compose_context( + lg_fst, + disambig_syms_in, + self.tree.ContextWidth(), + self.tree.CentralPosition(), + ) + fst_arc_sort(clg_fst, sort_type="ilabel") + h, disambig = make_h_transducer(self.tree, self.transition_model, ilabels) + fst = fst_table_compose(h, clg_fst) + if fst.Start() == pywrapfst.NO_STATE_ID: + logger.debug(f"Falling back to pynini compose for '{transcript}") + h = kaldi_to_pynini(h) + clg_fst = kaldi_to_pynini(clg_fst) + fst = pynini_to_kaldi(pynini.compose(h, clg_fst)) + fst_determinize_star(fst, use_log=True) + fst_rm_symbols(fst, disambig) + fst_rm_eps_local(fst) + fst_minimize_encoded(fst) + fst_add_self_loops( + fst, self.transition_model, disambig_syms_in, self.options.self_loop_scale + ) + elif interjection_words: + g = pynini.Fst() + start_state = g.add_state() + g.set_start(start_state) + for w in transcript.split(): + word_symbol = self.to_int(w) + word_initial_state = g.add_state() + for iw in interjection_words: + if not self.lexicon_compiler.word_table.member(iw): + continue + iw_symbol = self.to_int(iw) + g.add_arc( + word_initial_state - 1, + pywrapfst.Arc( + iw_symbol, + iw_symbol, + pywrapfst.Weight(g.weight_type(), 4.0), + word_initial_state, + ), + ) + word_final_state = g.add_state() + g.add_arc( + word_initial_state, + pywrapfst.Arc( + word_symbol, + word_symbol, + pywrapfst.Weight.one(g.weight_type()), + word_final_state, + ), + ) + g.add_arc( + word_initial_state - 1, + pywrapfst.Arc( + word_symbol, + word_symbol, + pywrapfst.Weight.one(g.weight_type()), + word_final_state, + ), + ) + g.set_final(word_final_state, pywrapfst.Weight.one(g.weight_type())) + + lg = pynini.compose(self.lexicon_compiler.fst, g) + lg.optimize() + lg.arcsort("olabel") + lg_fst = VectorFst.from_pynini(lg) + disambig_syms_in = [] + if self.lexicon_compiler is not None and self.lexicon_compiler.disambiguation: + disambig_syms_in = self.lexicon_compiler.disambiguation_symbols lg_fst = fst_determinize_star(lg_fst, use_log=True) fst_minimize_encoded(lg_fst) fst_push_special(lg_fst) clg_fst, disambig_out, ilabels = fst_compose_context( lg_fst, - self.disambiguation_symbols, + disambig_syms_in, self.tree.ContextWidth(), self.tree.CentralPosition(), ) @@ -273,7 +355,9 @@ 
def compile_fst(self, transcript: str) -> typing.Optional[VectorFst]: fst_rm_symbols(fst, disambig) fst_rm_eps_local(fst) fst_minimize_encoded(fst) - fst_add_self_loops(fst, self.transition_model, [], self.options.self_loop_scale) + fst_add_self_loops( + fst, self.transition_model, disambig_syms_in, self.options.self_loop_scale + ) else: transcript_symbols = [self.to_int(x) for x in transcript.split()] fst = self.compiler.CompileGraphFromText(transcript_symbols) diff --git a/kalpy/feat/lda.py b/kalpy/feat/lda.py index 9b0a704..2ca9c86 100644 --- a/kalpy/feat/lda.py +++ b/kalpy/feat/lda.py @@ -117,16 +117,18 @@ def accumulate_stats( num_done += 1 tot_like += tot_like_this_file tot_t += tot_weight_this_file - logger.info( - f"Average like for this file is {tot_like_this_file/tot_weight_this_file} " - f"over {tot_weight_this_file} frames." - ) - if num_done % 10 == 0: + if tot_weight_this_file != 0: + logger.info( + f"Average like for this file is {tot_like_this_file/tot_weight_this_file} " + f"over {tot_weight_this_file} frames." + ) + if num_done % 10 == 0 and tot_t != 0: logger.info(f"Average per frame so far is {tot_like/tot_t}") logger.info(f"Done {num_done} files.") - logger.info( - f"Overall avg like per frame (Gaussian only) = {tot_like/tot_t} over {tot_t} frames." - ) + if tot_t != 0: + logger.info( + f"Overall avg like per frame (Gaussian only) = {tot_like/tot_t} over {tot_t} frames." + ) def export_stats( self, file_name: str, feature_archive: FeatureArchive, alignment_archive: AlignmentArchive diff --git a/kalpy/fstext/lexicon.py b/kalpy/fstext/lexicon.py index 0c0e873..3bb1ccc 100644 --- a/kalpy/fstext/lexicon.py +++ b/kalpy/fstext/lexicon.py @@ -5,6 +5,7 @@ import math import pathlib import re +import threading import typing import dataclassy @@ -185,12 +186,17 @@ def __init__( else: self.phone_table.add_symbol(p) self.pronunciations: typing.List[Pronunciation] = [] + self._cached_pronunciations: typing.Set[typing.Tuple[str, str]] = set() self._fst = None - self._kaldi_fst = None self._align_fst = None self._align_lexicon = None self.word_begin_label = word_begin_label self.word_end_label = word_end_label + self.lock = threading.Lock() + self.start_state = None + self.loop_state = None + self.silence_state = None + self.non_silence_state = None def clear(self): self.pronunciations = [] @@ -338,11 +344,9 @@ def align_lexicon(self): self._align_lexicon = WordAlignLatticeLexiconInfo(lex) return self._align_lexicon - @property - def fst(self) -> pynini.Fst: - """Compiled lexicon FST""" - if self._fst is not None: - return self._fst + def create_fsts(self, phonological_rule_fst: pynini.Fst = None): + if self._fst is not None and self._align_fst is not None: + return initial_silence_cost = 0 initial_non_silence_cost = 0 @@ -356,63 +360,224 @@ def fst(self) -> pynini.Fst: final_silence_cost = -math.log(self.final_silence_correction) final_non_silence_cost = -math.log(self.final_non_silence_correction) - base_silence_following_cost = 0 - base_non_silence_following_cost = 0 - if self.silence_probability: - base_silence_following_cost = -math.log(self.silence_probability) - base_non_silence_following_cost = -math.log(1 - self.silence_probability) - self.phone_table.find(self.silence_disambiguation_symbol) - self.word_table.find("") + phone_eps_symbol = self.phone_table.find("") self.word_table.find(self.silence_word) - fst = pynini.Fst() - start_state = fst.add_state() - fst.set_start(start_state) - non_silence_state = fst.add_state() # Also loop state - silence_state = fst.add_state() + 
self._fst = pynini.Fst() + self._align_fst = pynini.Fst() + self.start_state = self._fst.add_state() + self._align_fst.add_state() + self._fst.set_start(self.start_state) + self.non_silence_state = self._fst.add_state() # Also loop state + self._align_fst.add_state() + self.silence_state = self._fst.add_state() + self._align_fst.add_state() + + self._align_fst.set_start(self.start_state) # initial no silence - fst.add_arc( - start_state, + self._fst.add_arc( + self.start_state, pywrapfst.Arc( - self.phone_table.find(self.silence_disambiguation_symbol), + phone_eps_symbol, + self.word_table.find(self.silence_word), + pywrapfst.Weight(self._fst.weight_type(), initial_non_silence_cost), + self.non_silence_state, + ), + ) + self._align_fst.add_arc( + self.start_state, + pywrapfst.Arc( + phone_eps_symbol, self.word_table.find(self.silence_word), - pywrapfst.Weight(fst.weight_type(), initial_non_silence_cost), - non_silence_state, + pywrapfst.Weight(self._align_fst.weight_type(), initial_non_silence_cost), + self.non_silence_state, ), ) # initial silence - fst.add_arc( - start_state, + self._fst.add_arc( + self.start_state, + pywrapfst.Arc( + self.phone_table.find(self.silence_phone), + self.word_table.find(self.silence_word), + pywrapfst.Weight(self._fst.weight_type(), initial_silence_cost), + self.silence_state, + ), + ) + self._align_fst.add_arc( + self.start_state, pywrapfst.Arc( self.phone_table.find(self.silence_phone), self.word_table.find(self.silence_word), - pywrapfst.Weight(fst.weight_type(), initial_silence_cost), - silence_state, + pywrapfst.Weight(self._align_fst.weight_type(), initial_silence_cost), + self.silence_state, ), ) + for pron in self.pronunciations: - word_symbol = self.word_table.find(pron.orthography) - phones = pron.pronunciation.split() - silence_before_cost = ( - -math.log(pron.silence_before_correction) - if pron.silence_before_correction - else 0.0 + self.add_pronunciation(pron, phonological_rule_fst) + if final_silence_cost > 0: + self._fst.set_final( + self.silence_state, pywrapfst.Weight(self._fst.weight_type(), final_silence_cost) ) - non_silence_before_cost = ( - -math.log(pron.non_silence_before_correction) - if pron.non_silence_before_correction - else 0.0 + self._align_fst.set_final( + self.silence_state, + pywrapfst.Weight(self._align_fst.weight_type(), final_silence_cost), ) - silence_following_cost = ( - -math.log(pron.silence_after_probability) - if pron.silence_after_probability - else base_silence_following_cost + else: + self._fst.set_final(self.silence_state, pywrapfst.Weight.one(self._fst.weight_type())) + self._align_fst.set_final( + self.silence_state, pywrapfst.Weight.one(self._align_fst.weight_type()) ) - non_silence_following_cost = ( - -math.log(1 - pron.silence_after_probability) - if pron.silence_after_probability - else base_non_silence_following_cost + if final_non_silence_cost > 0: + self._fst.set_final( + self.non_silence_state, + pywrapfst.Weight(self._fst.weight_type(), final_non_silence_cost), + ) + self._align_fst.set_final( + self.non_silence_state, + pywrapfst.Weight(self._align_fst.weight_type(), final_non_silence_cost), + ) + else: + self._fst.set_final( + self.non_silence_state, pywrapfst.Weight.one(self._fst.weight_type()) + ) + self._align_fst.set_final( + self.non_silence_state, pywrapfst.Weight.one(self._align_fst.weight_type()) + ) + + if ( + self._fst.num_states() <= self.silence_state + 1 + or self._fst.start() == pywrapfst.NO_STATE_ID + ): + num_words = self.word_table.num_symbols() + num_phones = 
self.phone_table.num_symbols() + num_pronunciations = len(self.pronunciations) + raise LexiconError( + f"There was an error compiling the lexicon " + f"({num_words} words, {num_pronunciations} pronunciations, " + f"{num_phones} phones)." ) + self._align_fst.arcsort("olabel") + self._fst.arcsort("olabel") + + @property + def base_silence_following_cost(self): + base_silence_following_cost = 0 + if self.silence_probability: + base_silence_following_cost = -math.log(self.silence_probability) + return base_silence_following_cost + + @property + def base_non_silence_following_cost(self): + base_non_silence_following_cost = 0 + if self.silence_probability: + base_non_silence_following_cost = -math.log(1 - self.silence_probability) + return base_non_silence_following_cost + + @property + def fst(self) -> pynini.Fst: + """Compiled lexicon FST""" + if self._fst is None: + self.create_fsts() + return self._fst + + def _create_word_fst( + self, pronunciation: Pronunciation, phonological_rule_fst: pynini.Fst = None + ): + + pron = pronunciation.pronunciation + if self.position_dependent_phones: + phones = pronunciation.pronunciation.split() + if len(phones) == 1: + phones[0] += "_S" + else: + phones[0] += "_B" + phones[-1] += "_E" + for i in range(1, len(phones) - 1): + phones[i] += "_I" + pron = " ".join(phones) + if not self.word_table.member(pronunciation.orthography): + self.word_table.add_symbol(pronunciation.orthography) + word_symbol = self.word_table.find(pronunciation.orthography) + if self.disambiguation and pronunciation.disambiguation is not None: + pron += f" #{pronunciation.disambiguation}" + probability = pronunciation.probability + weight = pywrapfst.Weight.one("tropical") + if probability is not None: + if probability < 0.01: + probability = 0.01 # Dithering to ensure low probability entries + weight = pywrapfst.Weight("tropical", abs(math.log(probability))) + fst = pynini.accep(pron, weight=weight, token_type=self.phone_table) + if phonological_rule_fst: + fst = pynini.compose(phonological_rule_fst, fst) + fst = pynini.arcmap(fst, map_type="output_epsilon") + arcs = [] + for arc in fst.mutable_arcs(fst.start()): + arc = arc.copy() + arc.olabel = word_symbol + arcs.append(arc) + fst.delete_arcs(fst.start()) + for arc in arcs: + fst.add_arc(fst.start(), arc) + + silence_before_cost = ( + -math.log(pronunciation.silence_before_correction) + if pronunciation.silence_before_correction + else 0.0 + ) + non_silence_before_cost = ( + -math.log(pronunciation.non_silence_before_correction) + if pronunciation.non_silence_before_correction + else 0.0 + ) + silence_following_cost = ( + -math.log(pronunciation.silence_after_probability) + if pronunciation.silence_after_probability + else self.base_silence_following_cost + ) + non_silence_following_cost = ( + -math.log(1 - pronunciation.silence_after_probability) + if pronunciation.silence_after_probability + else self.base_non_silence_following_cost + ) + initial_silence_fst = pynini.union( + pynini.accep( + self.silence_phone, + weight=pywrapfst.Weight("tropical", silence_before_cost), + token_type=self.phone_table, + ), + pynini.accep( + "", + weight=pywrapfst.Weight("tropical", non_silence_before_cost), + token_type=self.phone_table, + ), + ) + + initial_silence_fst = pynini.arcmap(initial_silence_fst, map_type="output_epsilon") + final_silence_fst = pynini.union( + pynini.accep( + self.silence_phone, + weight=pywrapfst.Weight("tropical", silence_following_cost), + token_type=self.phone_table, + ), + pynini.accep( + "", + 
weight=pywrapfst.Weight("tropical", non_silence_following_cost), + token_type=self.phone_table, + ), + ) + final_silence_fst = pynini.arcmap(final_silence_fst, map_type="output_epsilon") + fst = initial_silence_fst + fst + final_silence_fst + fst.optimize() + return fst + + def add_pronunciation( + self, pronunciation: Pronunciation, phonological_rule_fst: pynini.Fst = None + ): + if (pronunciation.orthography, pronunciation.pronunciation) in self._cached_pronunciations: + return + with self.lock: + phones = pronunciation.pronunciation.split() if self.position_dependent_phones: if len(phones) == 1: phones[0] += "_S" @@ -421,94 +586,192 @@ def fst(self) -> pynini.Fst: phones[-1] += "_E" for i in range(1, len(phones) - 1): phones[i] += "_I" - probability = pron.probability + new_phones = ", ".join(sorted({x for x in phones if not self.phone_table.member(x)})) + if new_phones: + raise Exception( + f"The pronunciation '{pronunciation}' had the following phones not in the symbol table: {new_phones}" + ) + pron = " ".join(phones) + fst = pynini.accep(pron, token_type=self.phone_table) + if phonological_rule_fst: + fst = pynini.compose(phonological_rule_fst, fst) + fst.rmepsilon() + self._cached_pronunciations.add( + (pronunciation.orthography, pronunciation.pronunciation) + ) + if not self.word_table.member(pronunciation.orthography): + self.word_table.add_symbol(pronunciation.orthography) + word_symbol = self.word_table.find(pronunciation.orthography) + word_eps_symbol = self.word_table.find("") + phone_eps_symbol = self.phone_table.find("") + silence_before_cost = ( + -math.log(pronunciation.silence_before_correction) + if pronunciation.silence_before_correction + else 0.0 + ) + non_silence_before_cost = ( + -math.log(pronunciation.non_silence_before_correction) + if pronunciation.non_silence_before_correction + else 0.0 + ) + silence_following_cost = ( + -math.log(pronunciation.silence_after_probability) + if pronunciation.silence_after_probability + else self.base_silence_following_cost + ) + non_silence_following_cost = ( + -math.log(1 - pronunciation.silence_after_probability) + if pronunciation.silence_after_probability + else self.base_non_silence_following_cost + ) + probability = pronunciation.probability if probability is None: probability = 1 elif probability < 0.01: probability = 0.01 # Dithering to ensure low probability entries pron_cost = abs(math.log(probability)) - if self.disambiguation and pron.disambiguation is not None: - phones += [f"#{pron.disambiguation}"] - - new_state = fst.add_state() - phone_symbol = self.phone_table.find(phones[0]) - # No silence before the pronunciation - fst.add_arc( - non_silence_state, - pywrapfst.Arc( - phone_symbol, - word_symbol, - pywrapfst.Weight(fst.weight_type(), pron_cost + non_silence_before_cost), - new_state, - ), - ) - # Silence before the pronunciation - fst.add_arc( - silence_state, - pywrapfst.Arc( - phone_symbol, - word_symbol, - pywrapfst.Weight(fst.weight_type(), pron_cost + silence_before_cost), - new_state, - ), - ) - current_state = new_state - for i in range(1, len(phones)): - next_state = fst.add_state() - phone_symbol = self.phone_table.find(phones[i]) - fst.add_arc( - current_state, + start_index = self._fst.num_states() - 1 + align_start_index = self._align_fst.num_states() + num_new_states = fst.num_states() - 1 + self._fst.add_states(num_new_states) + self._align_fst.add_states(num_new_states + 2) + + # FST arcs + for state in fst.states(): + for arc in fst.arcs(state): + if state == fst.start(): + # No silence before 
the pronunciation + self._fst.add_arc( + self.non_silence_state, + pywrapfst.Arc( + arc.ilabel, + word_symbol, + pywrapfst.Weight( + self._fst.weight_type(), pron_cost + non_silence_before_cost + ), + arc.nextstate + start_index, + ), + ) + # Silence before the pronunciation + self._fst.add_arc( + self.silence_state, + pywrapfst.Arc( + arc.ilabel, + word_symbol, + pywrapfst.Weight( + self._fst.weight_type(), pron_cost + silence_before_cost + ), + arc.nextstate + start_index, + ), + ) + + # No silence before the pronunciation + self._align_fst.add_arc( + self.non_silence_state, + pywrapfst.Arc( + self.phone_table.find(self.word_begin_label), + word_symbol, + pywrapfst.Weight( + self._fst.weight_type(), pron_cost + non_silence_before_cost + ), + arc.nextstate + align_start_index - 1, + ), + ) + # Silence before the pronunciation + self._align_fst.add_arc( + self.silence_state, + pywrapfst.Arc( + self.phone_table.find(self.word_begin_label), + word_symbol, + pywrapfst.Weight( + self._fst.weight_type(), pron_cost + silence_before_cost + ), + arc.nextstate + align_start_index - 1, + ), + ) + else: + self._fst.add_arc( + state + start_index, + pywrapfst.Arc( + arc.ilabel, + word_eps_symbol, + arc.weight, + arc.nextstate + start_index, + ), + ) + self._align_fst.add_arc( + state + align_start_index, + pywrapfst.Arc( + arc.ilabel, + word_eps_symbol, + arc.weight, + arc.nextstate + align_start_index, + ), + ) + + if self.disambiguation and pronunciation.disambiguation is not None: + self._fst.add_state() + self._fst.add_arc( + num_new_states + start_index, pywrapfst.Arc( - phone_symbol, - self.word_table.find(""), - pywrapfst.Weight.one(fst.weight_type()), - next_state, + self.phone_table.find(f"#{pronunciation.disambiguation}"), + word_eps_symbol, + pywrapfst.Weight(self._fst.weight_type(), non_silence_following_cost), + num_new_states + start_index + 1, ), ) - current_state = next_state + start_index += 1 + # No silence following the pronunciation - fst.add_arc( - current_state, + self._fst.add_arc( + num_new_states + start_index, pywrapfst.Arc( self.phone_table.find(self.silence_disambiguation_symbol), - self.word_table.find(""), - pywrapfst.Weight(fst.weight_type(), non_silence_following_cost), - non_silence_state, + word_eps_symbol, + pywrapfst.Weight(self._fst.weight_type(), non_silence_following_cost), + self.non_silence_state, ), ) # Silence following the pronunciation - fst.add_arc( - current_state, + self._fst.add_arc( + num_new_states + start_index, pywrapfst.Arc( self.phone_table.find(self.silence_phone), - self.word_table.find(""), - pywrapfst.Weight(fst.weight_type(), silence_following_cost), - silence_state, + word_eps_symbol, + pywrapfst.Weight(self._fst.weight_type(), silence_following_cost), + self.silence_state, ), ) - if final_silence_cost > 0: - fst.set_final(silence_state, pywrapfst.Weight(fst.weight_type(), final_silence_cost)) - else: - fst.set_final(silence_state, pywrapfst.Weight.one(fst.weight_type())) - if final_non_silence_cost > 0: - fst.set_final( - non_silence_state, pywrapfst.Weight(fst.weight_type(), final_non_silence_cost) + self._align_fst.add_arc( + num_new_states + align_start_index, + pywrapfst.Arc( + self.phone_table.find(self.word_end_label), + word_eps_symbol, + pywrapfst.Weight.one(self._align_fst.weight_type()), + num_new_states + align_start_index + 1, + ), ) - else: - fst.set_final(non_silence_state, pywrapfst.Weight.one(fst.weight_type())) - fst.arcsort("olabel") - if fst.num_states() == 0 or fst.start() == pywrapfst.NO_STATE_ID: - num_words = 
self.word_table.num_symbols() - num_phones = self.phone_table.num_symbols() - num_pronunciations = len(self.pronunciations) - raise LexiconError( - f"There was an error compiling the lexicon " - f"({num_words} words, {num_pronunciations} pronunciations, " - f"{num_phones} phones)." + # No silence following the pronunciation + self._align_fst.add_arc( + num_new_states + align_start_index + 1, + pywrapfst.Arc( + phone_eps_symbol, + word_eps_symbol, + pywrapfst.Weight(self._fst.weight_type(), non_silence_following_cost), + self.non_silence_state, + ), + ) + # Silence following the pronunciation + self._align_fst.add_arc( + num_new_states + align_start_index + 1, + pywrapfst.Arc( + self.phone_table.find(self.silence_phone), + word_eps_symbol, + pywrapfst.Weight(self._fst.weight_type(), silence_following_cost), + self.silence_state, + ), ) - self._fst = fst - return self._fst @property def kaldi_fst(self) -> VectorFst: @@ -527,7 +790,6 @@ def load_l_from_file( Path to read HCLG.fst """ self._fst = pynini.Fst.read(str(l_fst_path)) - self._kaldi_fst = VectorFst.Read(str(l_fst_path)) def load_l_align_from_file( self, @@ -546,94 +808,8 @@ def load_l_align_from_file( @property def align_fst(self) -> pynini.Fst: """Compiled FST for aligning lattices when `position_dependent_phones` is False""" - if self._align_fst is not None: - return self._align_fst - fst = pynini.Fst() - start_state = fst.add_state() - loop_state = fst.add_state() - sil_state = fst.add_state() - next_state = fst.add_state() - fst.set_start(start_state) - word_eps_symbol = self.word_table.find("") - phone_eps_symbol = self.phone_table.find("") - sil_cost = -math.log(0.5) - non_sil_cost = sil_cost - fst.add_arc( - start_state, - pywrapfst.Arc( - phone_eps_symbol, - word_eps_symbol, - pywrapfst.Weight(fst.weight_type(), non_sil_cost), - loop_state, - ), - ) - fst.add_arc( - start_state, - pywrapfst.Arc( - phone_eps_symbol, - word_eps_symbol, - pywrapfst.Weight(fst.weight_type(), sil_cost), - sil_state, - ), - ) - fst.add_arc( - sil_state, - pywrapfst.Arc( - self.phone_table.find(self.silence_phone), - self.word_table.find(self.silence_word), - pywrapfst.Weight.one(fst.weight_type()), - loop_state, - ), - ) - - for pron in self.pronunciations: - phones = pron.pronunciation.split() - if self.position_dependent_phones: - if phones[0] != self.silence_phone: - if len(phones) == 1: - phones[0] += "_S" - else: - phones[0] += "_B" - phones[-1] += "_E" - for i in range(1, len(phones) - 1): - phones[i] += "_I" - phones = [self.word_begin_label] + phones + [self.word_end_label] - current_state = loop_state - for i in range(len(phones) - 1): - p_s = self.phone_table.find(phones[i]) - if i == 0: - w_s = self.word_table.find(pron.orthography) - else: - w_s = word_eps_symbol - fst.add_arc( - current_state, - pywrapfst.Arc(p_s, w_s, pywrapfst.Weight.one(fst.weight_type()), next_state), - ) - current_state = next_state - next_state = fst.add_state() - i = len(phones) - 1 - if i >= 0: - p_s = self.phone_table.find(phones[i]) - else: - p_s = phone_eps_symbol - if i <= 0: - w_s = self.word_table.find(pron.orthography) - else: - w_s = word_eps_symbol - fst.add_arc( - current_state, - pywrapfst.Arc( - p_s, w_s, pywrapfst.Weight(fst.weight_type(), non_sil_cost), loop_state - ), - ) - fst.add_arc( - current_state, - pywrapfst.Arc(p_s, w_s, pywrapfst.Weight(fst.weight_type(), sil_cost), sil_state), - ) - fst.delete_states([next_state]) - fst.set_final(loop_state, pywrapfst.Weight.one(fst.weight_type())) - fst.arcsort("olabel") - self._align_fst = fst + if 
self._align_fst is None: + self.create_fsts() return self._align_fst def _create_pronunciation_string( diff --git a/kalpy/gmm/align.py b/kalpy/gmm/align.py index e37ec69..f44818c 100644 --- a/kalpy/gmm/align.py +++ b/kalpy/gmm/align.py @@ -83,9 +83,7 @@ def align_utterance( ) if not successful: return None - return Alignment( - utterance_id, alignment, words, likelihood / len(alignment), per_frame_log_likelihoods - ) + return Alignment(utterance_id, alignment, words, likelihood, per_frame_log_likelihoods) def align_utterances( self, training_graph_archive: FstArchive, feature_archive: FeatureArchive diff --git a/kalpy/gmm/data.py b/kalpy/gmm/data.py index 30ad833..9b9d9bc 100644 --- a/kalpy/gmm/data.py +++ b/kalpy/gmm/data.py @@ -162,6 +162,8 @@ def export_textgrid( output_format: str = TextgridFormats.LONG_TEXTGRID, ): # Create initial textgrid + if file_duration is not None: + file_duration = round(file_duration, 6) tg = tgio.Textgrid() tg.minTimestamp = 0 tg.maxTimestamp = file_duration diff --git a/tests/test_decoder.py b/tests/test_decoder.py index ca52ca4..1d373a9 100644 --- a/tests/test_decoder.py +++ b/tests/test_decoder.py @@ -22,7 +22,9 @@ def test_training_graphs( lc = LexiconCompiler(position_dependent_phones=False) lc.load_pronunciations(dictionary_path) lc.fst.write(str(mono_temp_dir.joinpath("lexicon.fst"))) - gc = TrainingGraphCompiler(mono_model_path, mono_tree_path, lc, lc.word_table) + gc = TrainingGraphCompiler( + mono_model_path, mono_tree_path, str(mono_temp_dir.joinpath("lexicon.fst")), lc.word_table + ) graph = kaldi_to_pynini(gc.compile_fst(acoustic_corpus_text)) assert graph.num_states() > 0 assert graph.start() != pywrapfst.NO_STATE_ID @@ -49,7 +51,9 @@ def test_training_graphs_sat( lc.fst.write(str(sat_temp_dir.joinpath("L_debug.fst"))) lc.word_table.write_text(str(sat_temp_dir.joinpath("words.txt"))) lc.phone_table.write_text(str(sat_temp_dir.joinpath("phones.txt"))) - gc = TrainingGraphCompiler(sat_model_path, sat_tree_path, lc, lc.word_table) + gc = TrainingGraphCompiler( + sat_model_path, sat_tree_path, str(sat_temp_dir.joinpath("L_debug.fst")), lc.word_table + ) graph = kaldi_to_pynini(gc.compile_fst(acoustic_corpus_text)) assert graph.num_states() > 0 assert graph.start() != pywrapfst.NO_STATE_ID
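
Usage sketch for the new string-path TrainingGraphCompiler constructor exercised by the updated tests above. This mirrors tests/test_decoder.py; the import paths and all file paths below are placeholders/assumptions, not taken from this diff.

    # Sketch only: module paths and file paths are assumed, not part of the patch.
    from kalpy.fstext.lexicon import LexiconCompiler
    from kalpy.decoder.training_graphs import TrainingGraphCompiler

    dictionary_path = "path/to/lexicon.txt"  # placeholder
    model_path = "path/to/final.mdl"         # placeholder
    tree_path = "path/to/tree"               # placeholder

    lc = LexiconCompiler(position_dependent_phones=False)
    lc.load_pronunciations(dictionary_path)
    lc.fst.write("lexicon.fst")
    # As in the updated tests, a path to the written lexicon FST is passed
    # in place of the LexiconCompiler object itself.
    gc = TrainingGraphCompiler(model_path, tree_path, "lexicon.fst", lc.word_table)
    graph = gc.compile_fst("this is a test transcript")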
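A similarly hedged sketch for the new Plda.log_likelihood_distance_vectorized binding added in extensions/ivector/ivector.cpp. How the Plda object is constructed or loaded is not shown in this diff, so it is passed in as an argument here.

    # Sketch only: `plda` is an already-loaded kalpy Plda object (loading is outside this diff).
    import numpy as np

    def plda_distances(plda, ivectors_a: np.ndarray, ivectors_b: np.ndarray) -> np.ndarray:
        # Both inputs are 2-D float64 arrays with one iVector per row; rows are paired by index.
        # Each returned entry is 1 / exp(log-likelihood ratio) for that pair, per the binding above.
        return plda.log_likelihood_distance_vectorized(ivectors_a, ivectors_b)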