Skip to content

Commit

Permalink
Fix misc issues (#7)
Browse files Browse the repository at this point in the history
* Fix disambiguation symbols bug

* Fix for disambiguation FST
  • Loading branch information
mmcauliffe authored Oct 10, 2023
1 parent 5840eea commit 8eff99d
Show file tree
Hide file tree
Showing 12 changed files with 534 additions and 237 deletions.
11 changes: 0 additions & 11 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,6 @@ if (MSVC)
# some warnings related with fst
add_compile_options(/wd4018 /wd4244 /wd4267 /wd4291 /wd4305)

set(CompilerFlags
CMAKE_CXX_FLAGS
CMAKE_CXX_FLAGS_DEBUG
CMAKE_CXX_FLAGS_RELEASE
CMAKE_C_FLAGS
CMAKE_C_FLAGS_DEBUG
CMAKE_C_FLAGS_RELEASE
)
foreach(CompilerFlag ${CompilerFlags})
string(REPLACE "/MD" "/MT" ${CompilerFlag} "${${CompilerFlag}}")
endforeach()
elseif(APPLE)
set(CMAKE_INSTALL_RPATH "@loader_path")
else()
Expand Down
10 changes: 9 additions & 1 deletion extensions/decoder/decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1716,6 +1716,14 @@ void pybind_training_graph_compiler(py::module &m) {
TrainingGraphCompiler gc(trans_model, ctx_dep, mf, disambig_syms, opts);
return gc;
}))
.def(py::init([](const TransitionModel &trans_model, const ContextDependency &ctx_dep,
const std::string lex_rxfilename, const std::vector<int32> &disambig_syms,
const TrainingGraphCompilerOptions &opts){
VectorFst<StdArc> *lex_fst = fst::ReadFstKaldi(lex_rxfilename);
TrainingGraphCompiler *gc = new TrainingGraphCompiler(trans_model, ctx_dep, lex_fst, disambig_syms, opts);
return gc;
}),
py::return_value_policy::take_ownership)
.def("CompileGraph",
&PyClass::CompileGraph,
"CompileGraph compiles a single training graph its input is a "
Expand Down Expand Up @@ -1792,7 +1800,7 @@ void pybind_training_graph_compiler(py::module &m) {
},
"This function creates FSTs from the text and calls CompileGraphs.",
py::arg("transcripts"),
py::return_value_policy::reference);
py::return_value_policy::take_ownership);
}
}

Expand Down
8 changes: 4 additions & 4 deletions extensions/fstext/pybind_fstext.h
Original file line number Diff line number Diff line change
Expand Up @@ -1132,22 +1132,22 @@ void pybind_const_fst_impl(py::module& m, const std::string& class_name,
.def("Type", &PyClass::Type, "FST typename",
py::return_value_policy::reference)
.def("Copy", &PyClass::Copy,
"Get a copy of this VectorFst. See Fst<>::Copy() for further "
"Get a copy of this ConstFst. See Fst<>::Copy() for further "
"doc.",
py::arg("safe") = false, py::return_value_policy::take_ownership)
.def_static("Read",
// clang-format off
overload_cast_<std::istream&, const fst::FstReadOptions&>()(&PyClass::Read),
// clang-format on
"Reads a VectorFst from an input stream, returning nullptr "
"Reads a ConstFst from an input stream, returning nullptr "
"on error.",
py::arg("strm"), py::arg("opts"),
py::return_value_policy::take_ownership)
.def_static("Read", overload_cast_<const std::string&>()(&PyClass::Read),
"Read a VectorFst from a file, returning nullptr on error; "
"Read a ConstFst from a file, returning nullptr on error; "
"empty "
"filename reads from standard input.",
py::arg("filename"), py::return_value_policy::take_ownership)
py::arg("filename"), py::return_value_policy::reference)
.def("Write",
// clang-format off
(bool (PyClass::*)(std::ostream&, const fst::FstWriteOptions&)const)&PyClass::Write,
Expand Down
35 changes: 35 additions & 0 deletions extensions/ivector/ivector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,7 @@ void pybind_plda(py::module &m) {
py::array_t<double> utterance_two_ivector
){
py::gil_scoped_release gil_release;

Vector<double> ivector_one_dbl;
auto r_one = utterance_one_ivector.unchecked<1>();
ivector_one_dbl.Resize(r_one.shape(0));
Expand All @@ -869,6 +870,40 @@ void pybind_plda(py::module &m) {
1,
ivector_two_dbl));
return score;

},
py::arg("utterance_one_ivector"),
py::arg("utterance_two_ivector"))
.def("log_likelihood_distance_vectorized",
[](
PyClass &plda,
py::array_t<double> utterance_one_ivector,
py::array_t<double> utterance_two_ivector
){
py::gil_scoped_release gil_release;
py::buffer_info buf1 = utterance_one_ivector.request(), buf2 = utterance_two_ivector.request();

auto r_one = utterance_one_ivector.unchecked<2>();
auto r_two = utterance_two_ivector.unchecked<2>();
auto result = py::array_t<BaseFloat>(r_one.shape(0));
py::buffer_info buf3 = result.request();
double *ptr3 = static_cast<double *>(buf3.ptr);
for (py::size_t i = 0; i < r_one.shape(0); i++){
Vector<double> ivector_one_dbl;
ivector_one_dbl.Resize(r_one.shape(1));
Vector<double> ivector_two_dbl;
ivector_two_dbl.Resize(r_two.shape(1));
for (py::size_t j = 0; j < r_one.shape(1); j++){
ivector_one_dbl(j) = r_one(i, j);
ivector_two_dbl(j) = r_two(i, j);

}
ptr3[i] = 1.0 / Exp(plda.LogLikelihoodRatio(ivector_one_dbl,
1,
ivector_two_dbl));;

}
return result;
},
py::arg("utterance_one_ivector"),
py::arg("utterance_two_ivector"))
Expand Down
1 change: 0 additions & 1 deletion extensions/matrix/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "matrix/matrix-common.h"
#include "matrix/sparse-matrix.h"
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include "util/pybind_util.h"

using namespace kaldi;
Expand Down
4 changes: 2 additions & 2 deletions kalpy/decoder/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def __iter__(self) -> typing.Generator[typing.Tuple[str, VectorFst]]:
try:
while not reader.Done():
utt = reader.Key()
fst = reader.Value()
decode_fst = VectorFst(fst)
decode_fst = VectorFst(reader.Value())
reader.FreeCurrent()
yield utt, decode_fst
reader.Next()
finally:
Expand Down
92 changes: 88 additions & 4 deletions kalpy/decoder/training_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,14 @@ def compiler(self):
self._fst = pynini.Fst.read(str(self.lexicon_path))
else:
self._fst = self.lexicon_compiler.fst
disambiguation_symbols = []
if self.lexicon_compiler is not None and self.lexicon_compiler.disambiguation:
disambiguation_symbols = self.lexicon_compiler.disambiguation_symbols
self._compiler = _TrainingGraphCompiler(
self.transition_model,
self.tree,
VectorFst.from_pynini(self._fst),
self.disambiguation_symbols,
disambiguation_symbols,
self.options,
)
return self._compiler
Expand Down Expand Up @@ -227,7 +230,9 @@ def export_graphs(
writer.Close()
logger.info(f"Done {num_done} utterances, errors on {num_error}.")

def compile_fst(self, transcript: str) -> typing.Optional[VectorFst]:
def compile_fst(
self, transcript: str, interjection_words: typing.List[str] = None
) -> typing.Optional[VectorFst]:
"""
Compile a transcript to a training graph
Expand All @@ -251,13 +256,90 @@ def compile_fst(self, transcript: str) -> typing.Optional[VectorFst]:
state_threshold = 256 + 2 * lg_fst.num_states()
lg_fst = pynini.determinize(lg_fst, nstate=state_threshold, weight=weight_threshold)
lg_fst = VectorFst.from_pynini(lg_fst)
disambig_syms_in = (
[]
if not self.lexicon_compiler.disambiguation
else self.lexicon_compiler.disambiguation_symbols
)
lg_fst = fst_determinize_star(lg_fst, use_log=True)
fst_minimize_encoded(lg_fst)
fst_push_special(lg_fst)
clg_fst, disambig_out, ilabels = fst_compose_context(
lg_fst,
disambig_syms_in,
self.tree.ContextWidth(),
self.tree.CentralPosition(),
)
fst_arc_sort(clg_fst, sort_type="ilabel")
h, disambig = make_h_transducer(self.tree, self.transition_model, ilabels)
fst = fst_table_compose(h, clg_fst)
if fst.Start() == pywrapfst.NO_STATE_ID:
logger.debug(f"Falling back to pynini compose for '{transcript}")
h = kaldi_to_pynini(h)
clg_fst = kaldi_to_pynini(clg_fst)
fst = pynini_to_kaldi(pynini.compose(h, clg_fst))
fst_determinize_star(fst, use_log=True)
fst_rm_symbols(fst, disambig)
fst_rm_eps_local(fst)
fst_minimize_encoded(fst)
fst_add_self_loops(
fst, self.transition_model, disambig_syms_in, self.options.self_loop_scale
)
elif interjection_words:
g = pynini.Fst()
start_state = g.add_state()
g.set_start(start_state)
for w in transcript.split():
word_symbol = self.to_int(w)
word_initial_state = g.add_state()
for iw in interjection_words:
if not self.lexicon_compiler.word_table.member(iw):
continue
iw_symbol = self.to_int(iw)
g.add_arc(
word_initial_state - 1,
pywrapfst.Arc(
iw_symbol,
iw_symbol,
pywrapfst.Weight(g.weight_type(), 4.0),
word_initial_state,
),
)
word_final_state = g.add_state()
g.add_arc(
word_initial_state,
pywrapfst.Arc(
word_symbol,
word_symbol,
pywrapfst.Weight.one(g.weight_type()),
word_final_state,
),
)
g.add_arc(
word_initial_state - 1,
pywrapfst.Arc(
word_symbol,
word_symbol,
pywrapfst.Weight.one(g.weight_type()),
word_final_state,
),
)
g.set_final(word_final_state, pywrapfst.Weight.one(g.weight_type()))

lg = pynini.compose(self.lexicon_compiler.fst, g)
lg.optimize()
lg.arcsort("olabel")
lg_fst = VectorFst.from_pynini(lg)

disambig_syms_in = []
if self.lexicon_compiler is not None and self.lexicon_compiler.disambiguation:
disambig_syms_in = self.lexicon_compiler.disambiguation_symbols
lg_fst = fst_determinize_star(lg_fst, use_log=True)
fst_minimize_encoded(lg_fst)
fst_push_special(lg_fst)
clg_fst, disambig_out, ilabels = fst_compose_context(
lg_fst,
self.disambiguation_symbols,
disambig_syms_in,
self.tree.ContextWidth(),
self.tree.CentralPosition(),
)
Expand All @@ -273,7 +355,9 @@ def compile_fst(self, transcript: str) -> typing.Optional[VectorFst]:
fst_rm_symbols(fst, disambig)
fst_rm_eps_local(fst)
fst_minimize_encoded(fst)
fst_add_self_loops(fst, self.transition_model, [], self.options.self_loop_scale)
fst_add_self_loops(
fst, self.transition_model, disambig_syms_in, self.options.self_loop_scale
)
else:
transcript_symbols = [self.to_int(x) for x in transcript.split()]
fst = self.compiler.CompileGraphFromText(transcript_symbols)
Expand Down
18 changes: 10 additions & 8 deletions kalpy/feat/lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,18 @@ def accumulate_stats(
num_done += 1
tot_like += tot_like_this_file
tot_t += tot_weight_this_file
logger.info(
f"Average like for this file is {tot_like_this_file/tot_weight_this_file} "
f"over {tot_weight_this_file} frames."
)
if num_done % 10 == 0:
if tot_weight_this_file != 0:
logger.info(
f"Average like for this file is {tot_like_this_file/tot_weight_this_file} "
f"over {tot_weight_this_file} frames."
)
if num_done % 10 == 0 and tot_t != 0:
logger.info(f"Average per frame so far is {tot_like/tot_t}")
logger.info(f"Done {num_done} files.")
logger.info(
f"Overall avg like per frame (Gaussian only) = {tot_like/tot_t} over {tot_t} frames."
)
if tot_t != 0:
logger.info(
f"Overall avg like per frame (Gaussian only) = {tot_like/tot_t} over {tot_t} frames."
)

def export_stats(
self, file_name: str, feature_archive: FeatureArchive, alignment_archive: AlignmentArchive
Expand Down
Loading

0 comments on commit 8eff99d

Please sign in to comment.