Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Implements RNNT+MMI #1030

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
329 changes: 324 additions & 5 deletions k2/csrc/fsa_algo.cu

Large diffs are not rendered by default.

42 changes: 40 additions & 2 deletions k2/csrc/fsa_algo.h
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ FsaOrVec ReplaceFsa(FsaVec &src, FsaOrVec &index, int32_t symbol_range_begin,
weight 'x.weight', then the reverse of 'src' accepts the reverse of string 'x'
with weight 'x.weight.reverse'.

Implementation notss:
Implementation notes:
The Fsa in k2 only has one start state 0, and the only final state with
the largest state number whose in-coming arcs have "-1" as the label.
So, 1) the start state of 'dest' will correspond to the final state of 'src'.
Expand All @@ -864,13 +864,51 @@ FsaOrVec ReplaceFsa(FsaVec &src, FsaOrVec &index, int32_t symbol_range_begin,
@param [out] dest Output Fsa or FsaVec. At exit, it will be equivalent
to the reverse Fsa of 'src'. Caution: the reverse will
ignore the "-1" label.

@param [out,optional] arc_map For each arc in `dest`, gives the index of
the corresponding arc in `src`. May be nullptr if the map is not needed.

*/
void Reverse(FsaVec &src, FsaVec *dest, Array1<int32_t> *arc_map = nullptr);

/*
 * Generate denominator lattice from sampled linear paths for RNN-T+MMI
 * training.
 *
 * Implementation notes:
 * 1) Generate "states" for each sampled symbol from their left_symbols and
 *    the frame_ids they are sampled from.
 * 2) Sort those "states" for each sequence and then merge identical "states".
 * 3) Map all of the sampled symbols to the merged "states".
 * 4) Remove duplicate arcs.
 *
 * @param [in] sampled_paths  The sampled symbols; it has a regular shape of
 *                  [seq][num_path][path_length]. All its elements MUST
 *                  satisfy `0 <= value < vocab_size`.
 * @param [in] frame_ids  It contains the index of the frame at which each
 *                  symbol was sampled; it has the same shape as
 *                  sampled_paths.
 * @param [in] left_symbols  The left_symbols of the sampled symbols; it has
 *                  a regular shape of [seq][num_path][path_length][context],
 *                  the first three indexes being the same as sampled_paths.
 *                  Each sublist along axis 3 has `context_size` elements.
 *                  All its elements MUST satisfy `0 <= value < vocab_size`.
 * @param [in] sampling_probs  It contains the probability with which each
 *                  symbol was sampled; it has the same shape as
 *                  sampled_paths.
 * @param [in] vocab_size  The vocabulary size.
 * @param [in] context_size  The number of left symbols.
 * @param [out] arc_map  For each arc in the returned Fsa, gives the original
 *                  index (idx012) in sampled_paths that it corresponds to.
 *
 * @return Return the generated lattice.
 */
FsaVec GenerateDenominatorLattice(Ragged<int32_t> &sampled_paths,
                                  Ragged<int32_t> &frame_ids,
                                  Ragged<int32_t> &left_symbols,
                                  Ragged<float> &sampling_probs,
                                  int32_t vocab_size,
                                  int32_t context_size,
                                  Array1<int32_t> *arc_map);

} // namespace k2

#endif // K2_CSRC_FSA_ALGO_H_
49 changes: 49 additions & 0 deletions k2/csrc/fsa_algo_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1383,4 +1383,53 @@ TEST(FsaAlgo, TestLevenshteinGraph) {
}
}

TEST(FsaAlgo, TestGenerateDenominatorLattice) {
  // Exercise both the CPU and the CUDA implementations.
  for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) {
    // Shape [seq][num_path][path_length]; elements in [0, vocab_size).
    Ragged<int32_t> sampled_paths(c, "[ [ [ 3 5 0 4 6 0 2 1 ] "
                                     "    [ 2 0 5 4 0 6 1 2 ] "
                                     "    [ 3 5 2 0 0 1 6 4 ] ] "
                                     "  [ [ 7 0 4 0 6 0 3 0 ] "
                                     "    [ 0 7 3 0 2 0 4 5 ] "
                                     "    [ 7 0 3 4 0 1 2 0 ] ] ]");
    // Frame index each symbol above was sampled at; same shape as
    // sampled_paths.
    Ragged<int32_t> frame_ids(c, "[ [ [ 0 0 0 1 1 1 2 2 ] "
                                 "    [ 0 0 1 1 1 2 2 2 ] "
                                 "    [ 0 0 0 0 1 2 2 2 ] ] "
                                 "  [ [ 0 0 1 1 2 2 3 3 ] "
                                 "    [ 0 1 1 1 2 2 3 1 ] "
                                 "    [ 0 0 1 1 1 2 2 2 ] ] ]");
    // Shape [seq][num_path][path_length][context]; each sublist along axis 3
    // has context_size (here 2) elements.
    Ragged<int32_t> left_symbols(c,
      "[ [ [ [ 0 0 ] [ 0 3 ] [ 3 5 ] [ 3 5 ] [ 5 4 ] [ 4 6 ] [ 4 6 ] [ 6 2 ] ] "
      "    [ [ 0 0 ] [ 0 2 ] [ 0 2 ] [ 2 5 ] [ 5 4 ] [ 5 4 ] [ 4 6 ] [ 6 1 ] ] "
      "    [ [ 0 0 ] [ 0 3 ] [ 3 5 ] [ 5 2 ] [ 5 2 ] [ 5 2 ] [ 2 1 ] [ 1 6 ] ] "
      "  ] "
      "  [ [ [ 0 0 ] [ 0 7 ] [ 0 7 ] [ 7 4 ] [ 7 4 ] [ 4 6 ] [ 4 6 ] [ 6 3 ] ] "
      "    [ [ 0 0 ] [ 0 0 ] [ 0 7 ] [ 7 3 ] [ 7 3 ] [ 3 2 ] [ 3 2 ] [ 0 0 ] ] "
      "    [ [ 0 0 ] [ 0 7 ] [ 0 7 ] [ 7 3 ] [ 3 4 ] [ 3 4 ] [ 4 1 ] [ 1 2 ] ] "
      "  ] ]");

    // Probability with which each symbol was sampled; same shape as
    // sampled_paths.
    Ragged<float> sampling_probs(c, "[ [ [ 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 ] "
                                    "    [ 0.2 0.2 0.2 0.1 0.2 0.2 0.1 0.2 ] "
                                    "    [ 0.1 0.1 0.1 0.3 0.2 0.3 0.3 0.3 ] ] "
                                    "  [ [ 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 ] "
                                    "    [ 0.2 0.2 0.2 0.1 0.2 0.1 0.2 0.2 ] "
                                    "    [ 0.1 0.1 0.2 0.2 0.3 0.3 0.3 0.3 ] ] "
                                    "]");
    // NOTE(review): path_scores is not consumed by GenerateDenominatorLattice
    // (scores are attached to the lattice at the Python level via arc_map);
    // keep the fixture for future use but silence the unused-variable
    // warning.
    Ragged<float> path_scores(c, "[ [ [ 1 1 1 1 1 1 1 1 ] "
                                 "    [ 1 2 2 1 2 2 1 2 ] "
                                 "    [ 1 1 1 3 3 3 3 3 ] ] "
                                 "  [ [ 1 1 1 1 1 1 1 1 ] "
                                 "    [ 1 2 1 2 2 2 2 2 ] "
                                 "    [ 1 1 1 2 3 3 3 3 ] ] ]");
    (void)path_scores;

    Array1<int32_t> arc_map;
    FsaVec lattice = GenerateDenominatorLattice(
        sampled_paths, frame_ids, left_symbols, sampling_probs,
        10 /*vocab_size*/, 2 /*context_size*/, &arc_map);
    // No golden output yet (WIP); log the results for manual inspection.
    K2_LOG(INFO) << arc_map;
    K2_LOG(INFO) << lattice;
    K2_LOG(INFO) << FsaToString(lattice.Index(0, 0));
    K2_LOG(INFO) << FsaToString(lattice.Index(0, 1));
  }
}

} // namespace k2
3 changes: 1 addition & 2 deletions k2/csrc/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ namespace k2 {

// Currently, only used in k2/csrc/rnnt_decode.cu
// See https://github.com/k2-fsa/k2/pull/951#issuecomment-1096650842
K2_CUDA_HOSTDEV __forceinline__ int64_t Pow(int64_t base,
int64_t exponent) {
K2_CUDA_HOSTDEV __forceinline__ int64_t Pow(int64_t base, int64_t exponent) {
K2_CHECK_GE(exponent, 0);
int64_t exp = 0;
int64_t result = 1;
Expand Down
24 changes: 24 additions & 0 deletions k2/python/csrc/torch/fsa_algo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,29 @@ static void PybindReverse(py::module &m) {
py::arg("src"), py::arg("need_arc_map") = true);
}

// Expose k2::GenerateDenominatorLattice to Python as
// `generate_denominator_lattice`, returning the lattice together with the
// arc map as a torch tensor.
static void PybindGenerateDenominatorLattice(py::module &m) {
  m.def(
      "generate_denominator_lattice",
      [](RaggedAny &sampled_paths, RaggedAny &frame_ids,
         RaggedAny &left_symbols, RaggedAny &sampling_probs,
         int32_t vocab_size,
         int32_t context_size) -> std::pair<FsaVec, torch::Tensor> {
        // Run on the device that owns the input ragged tensors.
        DeviceGuard guard(sampled_paths.any.Context());
        Array1<int32_t> arc_mapping;
        FsaVec den_lattice = GenerateDenominatorLattice(
            sampled_paths.any.Specialize<int32_t>(),
            frame_ids.any.Specialize<int32_t>(),
            left_symbols.any.Specialize<int32_t>(),
            sampling_probs.any.Specialize<float>(), vocab_size, context_size,
            &arc_mapping);
        // Convert the arc map to a torch tensor before crossing the
        // C++/Python boundary.
        return std::make_pair(den_lattice, ToTorch(arc_mapping));
      },
      py::arg("sampled_paths"), py::arg("frame_ids"), py::arg("left_symbols"),
      py::arg("sampling_probs"), py::arg("vocab_size"),
      py::arg("context_size"));
}

} // namespace k2

void PybindFsaAlgo(py::module &m) {
Expand All @@ -777,6 +800,7 @@ void PybindFsaAlgo(py::module &m) {
k2::PybindDeterminize(m);
k2::PybindExpandArcs(m);
k2::PybindFixFinalLabels(m);
k2::PybindGenerateDenominatorLattice(m);
k2::PybindIntersect(m);
k2::PybindIntersectDense(m);
k2::PybindIntersectDensePruned(m);
Expand Down
1 change: 1 addition & 0 deletions k2/python/k2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from .fsa_algo import ctc_topo
from .fsa_algo import determinize
from .fsa_algo import expand_ragged_attributes
from .fsa_algo import generate_denominator_lattice
from .fsa_algo import intersect
from .fsa_algo import intersect_device
from .fsa_algo import invert
Expand Down
54 changes: 54 additions & 0 deletions k2/python/k2/fsa_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,3 +1381,57 @@ def union(fsas: Fsa) -> Fsa:
out_fsa = k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc,
arc_map)
return out_fsa


def generate_denominator_lattice(
    sampled_paths: torch.Tensor,
    frame_ids: torch.Tensor,
    left_symbols: torch.Tensor,
    sampling_probs: torch.Tensor,
    path_scores: torch.Tensor,
    vocab_size: int,
    context_size: int,
) -> Fsa:
    """Generate denominator lattice from sampled linear paths for RNN-T+MMI
    training.

    Args:
      sampled_paths:
        The sampled symbols, it has a shape of (seq, num_path, path_length).
        All its elements MUST satisfy `0 <= value < vocab_size`.
      frame_ids:
        It contains the index of the frame at which each symbol was sampled,
        and has the same shape as sampled_paths.
      left_symbols:
        The left_symbols of the sampled symbols, it has a shape of
        (seq, num_path, path_length, context_size), the first three indexes
        are the same as sampled_paths. All its elements MUST satisfy
        `0 <= value < vocab_size`.
      sampling_probs:
        It contains the probability of sampling each symbol, and has the same
        shape as sampled_paths. Normally comes from the output of the
        "predictor" head.
      path_scores:
        It contains the score of each sampled symbol, and has the same shape
        as sampled_paths. It might contain the output of the hybrid head and
        the extra language model output. Note: Autograd is supported for
        this tensor.
      vocab_size:
        The vocabulary size.
      context_size:
        The number of left symbols.

    Returns:
      An Fsa (with 3 axes, i.e. an FsaVec) representing the denominator
      lattice. Its `scores` attribute is the sum of the scores produced by
      the C++ lattice construction and the entries of `path_scores` selected
      via the arc map; gradients flow only through `path_scores`.
    """
    ragged_arc, arc_map = _k2.generate_denominator_lattice(
        sampled_paths=k2.RaggedTensor(sampled_paths),
        frame_ids=k2.RaggedTensor(frame_ids),
        left_symbols=k2.RaggedTensor(left_symbols),
        sampling_probs=k2.RaggedTensor(sampling_probs),
        vocab_size=vocab_size,
        context_size=context_size,
    )
    lattice = Fsa(ragged_arc)
    # Scores coming from the C++ lattice construction (derived from the
    # negated sampling_probs) are treated as constants; the differentiable
    # part of the score comes from path_scores.
    a_value = getattr(lattice, "scores")
    # Enable autograd for path_scores: index_select keeps the autograd graph
    # connected from the per-arc scores back to path_scores.
    b_value = index_select(path_scores.flatten(), arc_map)
    value = a_value + b_value
    setattr(lattice, "scores", value)
    return lattice
1 change: 1 addition & 0 deletions k2/python/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ set(py_test_files
fsa_from_unary_function_ragged_test.py
fsa_from_unary_function_tensor_test.py
fsa_test.py
generate_denominator_lattice_test.py
get_arc_post_test.py
get_backward_scores_test.py
get_forward_scores_test.py
Expand Down
Loading