Skip to content

Commit

Permalink
0.6.6
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcauliffe committed Sep 3, 2024
1 parent 30fd402 commit 72a7677
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 95 deletions.
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
default_language_version:
python: python3.8
repos:
- repo: https://github.com/psf/black
rev: 22.10.0
Expand Down
136 changes: 87 additions & 49 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,61 +43,99 @@ find_package(CUDAToolkit)

find_package(pybind11 REQUIRED)
include_directories(extensions)
pybind11_add_module(_kalpy extensions/_kalpy.cpp
extensions/chain/chain.cpp
extensions/cudamatrix/cudamatrix.cpp
extensions/decoder/decoder.cpp
extensions/feat/feat.cpp
extensions/fstext/fstext.cpp
extensions/gmm/gmm.cpp
extensions/hmm/hmm.cpp
extensions/itf/itf.cpp
extensions/ivector/ivector.cpp
extensions/kws/kws.cpp
extensions/lat/lat.cpp
extensions/lm/lm.cpp
#extensions/rnnlm/rnnlm.cpp
extensions/online/online.cpp
extensions/online2/online2.cpp
extensions/matrix/matrix.cpp
extensions/nnet/nnet.cpp
extensions/nnet2/nnet2.cpp
extensions/nnet3/nnet3.cpp
extensions/transform/transform.cpp
extensions/tree/tree.cpp
extensions/util/util.cpp
)
target_link_libraries(_kalpy PUBLIC kaldi-base kaldi-chain
kaldi-matrix
kaldi-cudamatrix
kaldi-hmm
kaldi-online kaldi-online2 kaldi-rnnlm
kaldi-nnet3
kaldi-nnet2 kaldi-nnet
kaldi-kws
kaldi-decoder
kaldi-lat
kaldi-ivector kaldi-lm
kaldi-fstext kaldi-feat
kaldi-transform kaldi-gmm
kaldi-tree
kaldi-util
fst
fstscript
)
if (CUDAToolkit_FOUND)
pybind11_add_module(_kalpy extensions/_kalpy.cpp
extensions/chain/chain.cpp
extensions/cudamatrix/cudamatrix.cpp
extensions/decoder/decoder.cpp
extensions/feat/feat.cpp
extensions/fstext/fstext.cpp
extensions/gmm/gmm.cpp
extensions/hmm/hmm.cpp
extensions/itf/itf.cpp
extensions/ivector/ivector.cpp
extensions/kws/kws.cpp
extensions/lat/lat.cpp
extensions/lm/lm.cpp
#extensions/rnnlm/rnnlm.cpp
extensions/online/online.cpp
extensions/online2/online2.cpp
extensions/matrix/matrix.cpp
extensions/nnet/nnet.cpp
extensions/nnet2/nnet2.cpp
extensions/nnet3/nnet3.cpp
extensions/transform/transform.cpp
extensions/tree/tree.cpp
extensions/util/util.cpp
)
target_link_libraries(_kalpy PUBLIC kaldi-base kaldi-chain
kaldi-matrix
kaldi-cudamatrix
kaldi-hmm
kaldi-online kaldi-online2 kaldi-rnnlm
kaldi-nnet3
kaldi-nnet2 kaldi-nnet
kaldi-kws
kaldi-decoder
kaldi-lat
kaldi-nnet3
kaldi-nnet2 kaldi-nnet
kaldi-ivector kaldi-lm
kaldi-fstext kaldi-feat
kaldi-transform kaldi-gmm
kaldi-tree
kaldi-util
fst
fstscript
)

find_library(KALDI_CUDADECODER kaldi-cudadecoder)
find_library(KALDI_CUDADECODER kaldi-cudadecoder)

if(CUDAToolkit_FOUND AND KALDI_CUDADECODER)
if(KALDI_CUDADECODER)

target_link_libraries(_kalpy PUBLIC kaldi-cudadecoder kaldi-cudafeat
)
endif()
else()
pybind11_add_module(_kalpy extensions/_kalpy.cpp
extensions/decoder/decoder.cpp
extensions/feat/feat.cpp
extensions/fstext/fstext.cpp
extensions/gmm/gmm.cpp
extensions/hmm/hmm.cpp
extensions/itf/itf.cpp
extensions/ivector/ivector.cpp
extensions/kws/kws.cpp
extensions/lat/lat.cpp
extensions/lm/lm.cpp
extensions/online/online.cpp
extensions/online2/online2.cpp
extensions/matrix/matrix.cpp
extensions/transform/transform.cpp
extensions/tree/tree.cpp
extensions/util/util.cpp
)
target_link_libraries(_kalpy PUBLIC kaldi-base
kaldi-matrix
kaldi-hmm
kaldi-online kaldi-online2
kaldi-kws
kaldi-decoder
kaldi-lat
kaldi-ivector kaldi-lm
kaldi-fstext kaldi-feat
kaldi-transform kaldi-gmm
kaldi-tree
kaldi-util
fst
fstscript
)

target_link_libraries(_kalpy PUBLIC kaldi-cudadecoder kaldi-cudafeat
)
endif()
target_compile_definitions(_kalpy
PRIVATE VERSION_INFO="5.5.1068")

if(MSVC)
set_target_properties(_kalpy PROPERTIES
DEFINE_SYMBOL "KALDI_DLL_IMPORTS"
"KALDI_CUMATRIX_DLL_IMPORTS"
#"KALDI_CUMATRIX_DLL_IMPORTS"
"KALDI_UTIL_DLL_IMPORTS")
endif(MSVC)
2 changes: 1 addition & 1 deletion kalpy/decoder/training_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ def export_graphs(
fsts.append(self.compile_fst(t))
elif interjection_words:
fsts = self.compiler.CompileGraphs(transcript_batch)
del transcript_batch
else:
fsts = self.compiler.CompileGraphsFromText(transcript_batch)
del transcript_batch
assert len(fsts) == len(keys)
batch_done = 0
batch_error = 0
Expand Down
123 changes: 80 additions & 43 deletions kalpy/fstext/lexicon.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,50 +837,78 @@ def _create_pronunciation_string(
word_begin_symbol = self.phone_table.find(self.word_begin_label)
word_end_symbol = self.phone_table.find(self.word_end_label)
text = " ".join(self.word_table.find(x) for x in word_symbols)
acceptor = pynini.accep(text, token_type=self.word_table)
phone_to_word = pynini.compose(self.align_fst, acceptor)
phone_fst = pynini.Fst()
current_state = phone_fst.add_state()
phone_fst.set_start(current_state)
for symbol in phone_symbols:
next_state = phone_fst.add_state()
phone_fst.add_arc(
current_state,
pywrapfst.Arc(
symbol, symbol, pywrapfst.Weight.one(phone_fst.weight_type()), next_state
),
)
current_state = next_state
if transcription:
if phone_symbols[-1] == self.phone_table.find(self.silence_phone):
state = current_state - 1
if len(word_symbols) > 1:
text, final_word = text.rsplit(maxsplit=1)
else:
state = current_state
phone_to_word_state = phone_to_word.num_states() - 1
for i in range(self.phone_table.num_symbols()):
if self.phone_table.find(i) == "<eps>":
continue
if self.phone_table.find(i).startswith("#"):
continue
phone_fst.add_arc(
final_word = text
phone_to_word = pynini.compose(
self.align_fst, pynini.accep(text, token_type=self.word_table)
)
if transcription:
final_word_phone_to_word = pynini.compose(
self.align_fst, pynini.accep(final_word, token_type=self.word_table)
)
infinity_weight = pywrapfst.Weight(final_word_phone_to_word.weight_type(), "infinity")
final_word_phone_to_word = pynini.determinize(final_word_phone_to_word)
final_states = []
for i in range(final_word_phone_to_word.num_states()):
if final_word_phone_to_word.final(i) != infinity_weight:
final_states.append(i)
else:
final_word_phone_to_word.set_final(
i, pywrapfst.Weight.one(final_word_phone_to_word.weight_type())
)
extra_state = final_word_phone_to_word.add_state()
final_word_phone_to_word.set_final(
extra_state, pywrapfst.Weight.one(final_word_phone_to_word.weight_type())
)
for state in final_states:
final_word_phone_to_word.add_arc(
state,
pywrapfst.Arc(
word_begin_symbol,
self.phone_table.find("<eps>"),
i,
pywrapfst.Weight.one(phone_fst.weight_type()),
state,
pywrapfst.Weight(final_word_phone_to_word.weight_type(), 10),
extra_state,
),
)
for i in range(self.phone_table.num_symbols()):
if self.phone_table.find(i) == "<eps>":
continue
if self.phone_table.find(i).startswith(self.silence_phone):
continue
if self.phone_table.find(i).startswith("#"):
continue

phone_to_word.add_arc(
phone_to_word_state,
final_word_phone_to_word.add_arc(
extra_state,
pywrapfst.Arc(
i,
self.phone_table.find("<eps>"),
pywrapfst.Weight.one(phone_fst.weight_type()),
phone_to_word_state,
pywrapfst.Weight(final_word_phone_to_word.weight_type(), 10),
extra_state,
),
)
if len(word_symbols) > 1:
phone_to_word = pynini.concat(phone_to_word, final_word_phone_to_word)
else:
phone_to_word = final_word_phone_to_word

phone_fst = pynini.Fst()
current_state = phone_fst.add_state()
phone_fst.set_start(current_state)
for symbol in phone_symbols:
next_state = phone_fst.add_state()
phone_fst.add_arc(
current_state,
pywrapfst.Arc(
symbol, symbol, pywrapfst.Weight.one(phone_fst.weight_type()), next_state
),
)
current_state = next_state
phone_fst.set_final(current_state, pywrapfst.Weight.one(phone_fst.weight_type()))

for s in range(current_state + 1):
phone_fst.add_arc(
s,
Expand All @@ -900,9 +928,12 @@ def _create_pronunciation_string(
s,
),
)

phone_fst.set_final(current_state, pywrapfst.Weight.one(phone_fst.weight_type()))
phone_fst.arcsort("olabel")
if transcription:
inf_weight = pywrapfst.Weight(phone_fst.weight_type(), "infinity")
for state in range(phone_fst.num_states()):
if phone_fst.final(state) != inf_weight:
phone_fst.set_final(state, pywrapfst.Weight(phone_fst.weight_type(), 100))

lattice = pynini.compose(phone_fst, phone_to_word)

Expand Down Expand Up @@ -956,6 +987,8 @@ def phones_to_pronunciations(
current_phone_index = 0
current_word_index = 0
for i, w in enumerate(actual_words):
if current_word_index >= len(word_splits):
break
pron = word_splits[current_word_index]
word_symbol = word_symbols[i]
if pron == self.silence_phone:
Expand All @@ -968,19 +1001,21 @@ def phones_to_pronunciations(
)
current_word_index += 1
current_phone_index += 1
if current_word_index >= len(word_splits):
break
pron = word_splits[current_word_index]

phones = pron.split()
word_intervals.append(
WordCtmInterval(
w,
word_symbol,
intervals[current_phone_index : current_phone_index + len(phones)],
if pron:
phones = pron.split()
word_intervals.append(
WordCtmInterval(
w,
word_symbol,
intervals[current_phone_index : current_phone_index + len(phones)],
)
)
)
current_phone_index += len(phones)
current_phone_index += len(phones)
current_word_index += 1
if current_word_index != len(word_splits):
if current_word_index < len(word_splits):
pron = word_splits[current_word_index]
if pron == self.silence_phone:
word_intervals.append(
Expand All @@ -990,6 +1025,8 @@ def phones_to_pronunciations(
intervals[current_phone_index : current_phone_index + 1],
)
)
if not word_intervals[-1].phones:
del word_intervals[-1]
return HierarchicalCtm(word_intervals, text=text)


Expand Down

0 comments on commit 72a7677

Please sign in to comment.