From 5a79f329746fa478cadb3acf1a0ee04d17a66333 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 4 Aug 2022 09:48:28 +0800 Subject: [PATCH 1/8] draft of importance sampling algorithm --- k2/python/k2/mmi.py | 104 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 k2/python/k2/mmi.py diff --git a/k2/python/k2/mmi.py b/k2/python/k2/mmi.py new file mode 100644 index 000000000..74c77dc7e --- /dev/null +++ b/k2/python/k2/mmi.py @@ -0,0 +1,104 @@ +import torch +from torch.distributions.categorical import Categorical +from typing import Tuple + +def importance_sampling( + sampling_scores: torch.Tensor, + path_length: int, + num_paths: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + sampling_scores: + The output of predictor head, a tensor of shape (B, S, T, V) containing + the probabilities of emitting symbols at each (t, s) for each sequence. + path_length: + How many symbols we will sample for each path. + num_paths: + How many paths we will sample for each sequence. + + Returns: + Three tensors will be returned. + - sampled_indexs: + A tensor of shape (B, num_paths, path_length), containing the sampled symbol ids. + - sampled_scores: + A tensor of shape (B, num_paths, path_length), containing the sampling probabilities + of the corresponding symbols. + - sampled_t_indexs: + A tensor of shape (B, num_paths, path_length), containing the frame ids, which means + at what frame this symple be sampled. + """ + (B, S, T, V) = sampling_scores.shape + # we sample paths from frame 0 + t_index = torch.zeros( + (B, num_paths), + dtype=torch.int64, + device=sampling_scores.device + ) + # we sample paths from the first symbols (i.e. from null left_context) + s_index = torch.zeros( + (B, num_paths), + dtype=torch.int64, + device=sampling_scores.device + ) + + sampled_indexs = [] + sampled_scores = [] + sampled_t_indexs = [] + + for i in range(path_length): + # select context symbols for paths + # sub_scores : (B, num_paths, T, V) + sub_scores = torch.gather( + sampling_scores, dim=1, + index=s_index.reshape(B, num_paths, 1, 1).expand(B, num_paths, T, V)) + + # select frames for paths + # sub_scores : (B, num_paths, 1, V) + sub_scores = torch.gather( + sub_scores, dim=2, + index=t_index.reshape(B, num_paths, 1, 1).expand(B, num_paths, 1, V)) + + # sub_scores : (B, num_paths, V) + sub_scores = sub_scores.squeeze(2) + # sampler: https://pytorch.org/docs/stable/distributions.html#categorical + sampler = Categorical(probs=sub_scores) + + # sample one symbol for each path + # index : (B, num_paths) + index = sampler.sample() + sampled_indexs.append(index) + + # gather sampling probabilities for corresponding indexs + # score : (B, num_paths, 1) + score = torch.gather(sub_scores, dim=2, index=index.unsqueeze(2)) + sampled_scores.append(score.squeeze(2)) + + sampled_t_indexs.append(t_index) + + # update (t, s) for each path (for regular RNN-T) + # index == 0 means the sampled symbol is blank + t_mask = index == 0 + t_index = torch.where(t_mask, t_index + 1, t_index) + s_index = torch.where(t_mask, s_index + 1, s_index) + + # indexs : (B, num_paths, path_lengths) + indexs = torch.stack(sampled_indexs, dim=0).permute(1,2,0) + # scores : (B, num_paths, path_lengths) + scores = torch.stack(sampled_scores, dim=0).permute(1,2,0) + # t_indexs : (B, num_paths, path_lengths) + t_indexs = torch.stack(sampled_t_indexs, dim=0).permute(1,2,0) + + return indexs, scores, t_indexs + + +if __name__ == "__main__": + B, S, T, V = 2, 10, 20, 10 + path_length = 8 + 
num_path = 3 + logits = torch.randn((B, S, T, V)) + log_prob = torch.softmax(logits, -1) + indexs, scores, t_indexs = importance_sampling(log_prob, path_length, num_path) + print (indexs) + print (scores) + print (t_indexs) From 05aa9f38922653d5a6ac224d49430da3066de395 Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 9 Aug 2022 19:51:15 +0800 Subject: [PATCH 2/8] Generate denominator lattice from sampled paths --- k2/csrc/fsa_algo.cu | 329 +++++++++++++++++- k2/csrc/fsa_algo.h | 42 ++- k2/csrc/fsa_algo_test.cu | 49 +++ k2/csrc/math.h | 3 +- k2/python/csrc/torch/fsa_algo.cu | 24 ++ k2/python/k2/__init__.py | 1 + k2/python/k2/fsa_algo.py | 54 +++ k2/python/k2/mmi.py | 104 ------ k2/python/tests/CMakeLists.txt | 1 + .../generate_denominator_lattice_test.py | 216 ++++++++++++ 10 files changed, 710 insertions(+), 113 deletions(-) delete mode 100644 k2/python/k2/mmi.py create mode 100644 k2/python/tests/generate_denominator_lattice_test.py diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu index 4065713ff..075706dc1 100644 --- a/k2/csrc/fsa_algo.cu +++ b/k2/csrc/fsa_algo.cu @@ -1867,13 +1867,13 @@ FsaOrVec ReplaceFsa(FsaVec &src, FsaOrVec &index, int32_t symbol_range_begin, index_arc_idx2 = idx4; // corresponds to foo=0, so idx3 will be 0; // the idx4 enumerates the arcs leaving it.. } else { - // this is one of the extra `foo` indexes, it's conrespoding index + // this is one of the extra `foo` indexes, it's corresponding index // into `index` is `foo` index minus 1 index_arc_idx2 = idx2 - 1; } int32_t index_arc_idx01x = index_row_splits2_data[idx01]; - // index of the arc in source FSA, FSA that we're replaceing.. + // index of the arc in source FSA, FSA that we're replacing.. int32_t index_arc_idx012 = index_arc_idx01x + index_arc_idx2; Arc index_arc = index_arcs_data[index_arc_idx012]; @@ -1916,8 +1916,8 @@ FsaOrVec ReplaceFsa(FsaVec &src, FsaOrVec &index, int32_t symbol_range_begin, } else { // this arc would point to the initial state of the fsa in src, // the state id bias to current state(the src-state) is the count - // of all the ostates coresponding to the original state util now, - // the idx4 enumerates foo index + // of all the ostates corresponding to the original state until + // now, the idx4 enumerates foo index int32_t idx012_t = idx01x + 0, idx2_t = idx4, idx012x_t = tos_row_splits3_data[idx012_t], @@ -1932,7 +1932,7 @@ FsaOrVec ReplaceFsa(FsaVec &src, FsaOrVec &index, int32_t symbol_range_begin, arc_index_map_idx = index_arc_idx012; } else { // handle the arcs belongs to src // the arc point to the final state of the fsa in src would point to - // the dest state of the arc we're replaceing + // the dest state of the arc we're replacing if (src_arc.label == -1) { oarc.dest_state = orig_dest_state_idx0123 - idx0xxx; } else { @@ -1991,4 +1991,323 @@ FsaOrVec RemoveEpsilonSelfLoops(FsaOrVec &src, return ans; } +FsaVec GenerateDenominatorLattice(Ragged &sampled_paths, + Ragged &frame_ids, + Ragged &left_symbols, + Ragged &sampling_probs, + int32_t vocab_size, + int32_t context_size, + Array1 *arc_map) { + NVTX_RANGE(K2_FUNC); + K2_CHECK(arc_map); + K2_CHECK_EQ(sampled_paths.NumAxes(), 3); + K2_CHECK_EQ(frame_ids.NumAxes(), 3); + K2_CHECK_EQ(left_symbols.NumAxes(), 4); + K2_CHECK_EQ(sampling_probs.NumAxes(), 3); + + K2_DCHECK_EQ(sampled_paths.NumElements(), frame_ids.NumElements()); + K2_DCHECK_EQ(sampled_paths.NumElements(), + left_symbols.NumElements() * context_size); + K2_DCHECK_EQ(sampled_paths.NumElements(), sampling_probs.NumElements()); + for (int32_t i = 0; i < 3; 
++i) { + K2_DCHECK_EQ(sampled_paths.TotSize(i), frame_ids.TotSize(i)); + K2_DCHECK_EQ(sampled_paths.TotSize(i), left_symbols.TotSize(i)); + K2_DCHECK_EQ(sampled_paths.TotSize(i), sampling_probs.TotSize(i)); + } + + ContextPtr c = GetContext( + sampled_paths, frame_ids, left_symbols, sampling_probs); + + // The states indicating we are in on each position of each path, which has + // the same shape as `sampled_paths`, because each symbol in the paths is + // sampled from a specific frame with corresponding left contexts. + // Each state represents a tuple like (t, left_symbols1, left_symbols2...), + // the number of left_symbols equals to the `context_size`. A state is + // calculated from t * V ^ c + \sum_{i=1}^{c} s_i * V ^ (c - i), + // V is vocab_size, c is context_size, s_i is the ith left_symbols. + // For example, if context_size = 2, vocab_size = 10, so, one possible tuple + // would be (2, 4, 5), then the corresponding state is + // 2 * 10 ^ 2 + 4 * 10 + 5 = 245. + Ragged states(sampled_paths.shape); + int32_t num_states = states.NumElements(); + + const int32_t *frame_ids_data = frame_ids.values.Data(), + *left_symbols_row_splits3_data + = left_symbols.RowSplits(3).Data(), + *left_symbols_data = left_symbols.values.Data(); + int64_t *states_data = states.values.Data(); + + // This kernel calculates t * V ^ c for each state. + K2_EVAL( + c, num_states, lambda_init_states_with_t, (int32_t idx012) -> void { + states_data[idx012] + = frame_ids_data[idx012] * Pow(vocab_size, context_size); + }); + + // The following kernels calculate \sum_{i=1}^{c} s_i * V ^ (c - i) + for (int32_t i = 0; i < context_size; ++i) { + K2_EVAL( + c, num_states, lambda_generate_states, (int32_t idx012) -> void { + int32_t left_symbols_idx012x = left_symbols_row_splits3_data[idx012], + left_symbols_idx0123 = left_symbols_idx012x + i, + exp = context_size - i - 1; + states_data[idx012] + += left_symbols_data[left_symbols_idx0123] * Pow(vocab_size, exp); + }); + } + + // Sort those states for each sequence, so as to merge the same states. + // sorted_states has two axes: [seq][state] + auto sorted_states = Ragged( + RemoveAxis(states.shape, 1 /*axis*/), states.values.Clone()); + Array1 sorted_states_new2old(c, num_states); + SortSublists(&sorted_states, &sorted_states_new2old); + + // We need old2new map to find the original consecutive state. + Array1 sorted_states_old2new(c, num_states); + const int32_t *sorted_states_new2old_data = sorted_states_new2old.Data(); + int32_t *sorted_states_old2new_data = sorted_states_old2new.Data(); + K2_EVAL( + c, num_states, lambda_get_old2new, (int32_t i) -> void { + sorted_states_old2new_data[sorted_states_new2old_data[i]] = i; + }); + + // Search "tails concept" in k2/csrc/utils.h for the details of tail array. + // By applying ExclusiveSum on the tail_array, we can get a row_id mapping the + // sorted states to unique_states (i.e. the merged states). 
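+  // For example, if the sorted states of one sequence are [ 3 3 5 5 5 9 ],
+  // the tail array is [ 0 1 0 0 1 1 ], whose exclusive sum [ 0 0 1 1 1 2 ]
+  // maps the six sorted states to the three merged states 3, 5 and 9.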
+ Array1 tail_array(c, num_states); + const int32_t *sorted_states_row_ids1_data = sorted_states.RowIds(1).Data(); + const int64_t *sorted_states_data = sorted_states.values.Data(); + int32_t *tail_array_data = tail_array.Data(); + + K2_EVAL( + c, num_states, lambda_get_tail_array, (int32_t idx01) -> void { + if (idx01 == num_states - 1) tail_array_data[idx01] = 1; + int32_t idx0 = sorted_states_row_ids1_data[idx01], + next_idx0 = sorted_states_row_ids1_data[idx01 + 1]; + if (idx0 == next_idx0 && + sorted_states_data[idx01] == sorted_states_data[idx01 + 1]) + tail_array_data[idx01] = 0; + else + tail_array_data[idx01] = 1; + }); + + Array1 unique_states_row_ids(c, num_states); + ExclusiveSum(tail_array, &unique_states_row_ids); + + // unique_states_shape's shape [merged state][sorted state] + // unique_states_shape.row_splits.Dim() - 1 equals to the number of merged + // states. + RaggedShape unique_states_shape = RaggedShape2( + nullptr, &unique_states_row_ids, unique_states_row_ids.Dim()); + + // We are figuring out the ragged shape of the lattice. + // First, figure out the number of states (i.e. the merged states) for each + // sequence. + // Second, figure out the number of arcs for each merged state. + int32_t num_seqs = states.TotSize(0); + + // Plus 1 here because we will applying ExclusiveSum on this array. + Array1 num_states_for_seqs(c, states.TotSize(0) + 1); + + // "ss" is short for "sorted states" + // "us" is short for "unique states". + const int32_t *ss_row_splits1_data = sorted_states.RowSplits(1).Data(), + *us_row_ids1_data = unique_states_shape.RowIds(1).Data(); + int32_t *num_states_for_seqs_data = num_states_for_seqs.Data(); + + K2_EVAL( + c, num_seqs, lambda_get_num_states, (int32_t idx0) -> void { + int32_t ss_idx0x = ss_row_splits1_data[idx0], + ss_idx0x_next = ss_row_splits1_data[idx0 + 1], + us_idx0 = us_row_ids1_data[ss_idx0x], + us_idx0_next_minus_1 = us_row_ids1_data[ss_idx0x_next - 1], + num_unique_states = us_idx0_next_minus_1 - us_idx0 + 1; + // Plus 2 here, because we need a super dest_state for the last sampled + // symbol of each path, and a final state needed by k2. + num_states_for_seqs_data[idx0] = num_unique_states + 2; + }); + + ExclusiveSum(num_states_for_seqs, &num_states_for_seqs); + RaggedShape seqs_to_states_shape = RaggedShape2( + &num_states_for_seqs, nullptr, -1); + int32_t num_merged_states = seqs_to_states_shape.NumElements(); + + K2_CHECK_EQ(unique_states_shape.RowSplits(1).Dim() - 1 + num_seqs * 2, + num_merged_states); + + // Plus 1 here because we will applying ExclusiveSum on this array. + Array1 num_arcs_for_states( + c, seqs_to_states_shape.NumElements() + 1); + + // "sts" is short for "seqs to states" + // "us" is short for "unique states". + const int32_t *us_row_splits1_data = unique_states_shape.RowSplits(1).Data(), + *sts_row_ids1_data = seqs_to_states_shape.RowIds(1).Data(), + *sts_row_splits1_data + = seqs_to_states_shape.RowSplits(1).Data(); + int32_t *num_arcs_for_states_data = num_arcs_for_states.Data(); + + K2_EVAL( + c, num_merged_states, lambda_get_num_arcs, (int32_t idx01) -> void { + int32_t idx0 = sts_row_ids1_data[idx01], + idx0x_next = sts_row_splits1_data[idx0 + 1], + num_arcs = 0; + // The final state for each sequence. + if (idx01 == idx0x_next - 2) num_arcs = 1; + if (idx01 < idx0x_next - 2) { + // Minus idx0 * 2, because we add extra two states for each sequence. 
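+        // (Each sequence's merged states are laid out as
+        // [ unique states, super dest_state, final state ], so subtracting
+        // idx0 * 2 converts a lattice state index into a unique-state index.)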
+ int32_t us_idx0 = idx01 - idx0 * 2, + us_idx0x = us_row_splits1_data[us_idx0], + us_idx0x_next = us_row_splits1_data[us_idx0 + 1]; + num_arcs = us_idx0x_next - us_idx0x; + } + num_arcs_for_states_data[idx01] = num_arcs; + }); + + ExclusiveSum(num_arcs_for_states, &num_arcs_for_states); + RaggedShape states_to_arcs_shape = RaggedShape2( + &num_arcs_for_states, nullptr, -1); + + RaggedShape arcs_shape = ComposeRaggedShapes( + seqs_to_states_shape, states_to_arcs_shape); + int32_t num_arcs = arcs_shape.NumElements(); + + // Each state (before merging) has a leaving arc, we add a final arc + // to each sequence, so, the total number of arcs equals to + // num_states + num_seqs + K2_CHECK_EQ(num_arcs, num_seqs + num_states); + + // Populate arcs. + // "ss" is short for "sorted states" + const int32_t *sampled_paths_data = sampled_paths.values.Data(), + *arcs_shape_row_ids1_data = arcs_shape.RowIds(1).Data(), + *arcs_shape_row_splits1_data = arcs_shape.RowSplits(1).Data(), + *arcs_shape_row_ids2_data = arcs_shape.RowIds(2).Data(), + *states_row_ids2_data = states.RowIds(2).Data(), + *ss_row_ids1_data = sorted_states.RowIds(1).Data(); + const float *sampling_probs_data = sampling_probs.values.Data(); + Array1 arcs(c, num_arcs); + Arc *arcs_data = arcs.Data(); + + // The arc_map mapping from lattice arcs to original state indexes. + Array1 raw_arc_map(c, num_arcs); + int32_t *raw_arc_map_data = raw_arc_map.Data(); + + K2_EVAL( + c, num_arcs, lambda_set_arcs, (int32_t idx012) -> void { + Arc arc; + int32_t arc_map_value = -1; + int32_t idx01 = arcs_shape_row_ids2_data[idx012], + idx0 = arcs_shape_row_ids1_data[idx01], + idx0x = arcs_shape_row_splits1_data[idx0], + idx1 = idx01 - idx0x; + arc.src_state = idx1; + + // Final arc of the last sequence. + if (idx012 == num_arcs - 1) { + arc.dest_state = idx1 + 1; + arc.label = -1; + arc.score = 0.0; + } else { + int32_t idx01_next = arcs_shape_row_ids2_data[idx012 + 1], + idx0_next = arcs_shape_row_ids1_data[idx01_next]; + // Final arc for each sequence, except the last sequence. + if (idx0 != idx0_next) { + arc.dest_state = idx1 + 1; + arc.label = -1; + arc.score = 0.0; + } else { + // ss_idx01 is the global index of sorted states, minus idx0 here + // because we added an extra final arc for each sequence. + int32_t ss_idx01 = idx012 - idx0, + states_idx012 = sorted_states_new2old_data[ss_idx01]; + + arc_map_value = states_idx012; + arc.label = sampled_paths_data[states_idx012]; + float sampling_prob = sampling_probs_data[states_idx012]; + + int32_t us_idx0 = us_row_ids1_data[ss_idx01], + repeat_num = us_row_splits1_data[us_idx0 + 1] - + us_row_splits1_data[us_idx0]; + + arc.score = -logf(1 - powf(1 - sampling_prob, repeat_num)); + + // Final state of the last sequence, it will point to the added super + // dest_state. + if (states_idx012 == num_states - 1) { + int32_t idx0x_next = arcs_shape_row_splits1_data[idx0 + 1]; + arc.dest_state = idx0x_next - idx0x - 2; + } else { + // states_idx01 is path index + int32_t states_idx01 = states_row_ids2_data[states_idx012], + states_idx01_next = + states_row_ids2_data[states_idx012 + 1], + frame_id = frame_ids_data[states_idx012], + frame_id_next = frame_ids_data[states_idx012 + 1]; + // The first condition means this is the final state of each + // sequence. + // The second condition means we reach final frame at this state, + // the next state will be a start state of another path. + // So, this state points to the added super dest_state. 
+ if (states_idx01 != states_idx01_next || + (states_idx01 == states_idx01_next && + frame_id_next < frame_id)) { + int32_t idx0x_next = + arcs_shape_row_splits1_data[idx0 + 1]; + arc.dest_state = idx0x_next - idx0x - 2; + } else { + // states_idx012 + 1 is the index of original consecutive state. + // "ss" is short for "sorted states" + // "us" is short for "unique states". + int32_t ss_idx01_next = + sorted_states_old2new_data[states_idx012 + 1], + us_idx0_next = us_row_ids1_data[ss_idx01_next]; + arc.dest_state = us_idx0_next + 2 * idx0 - idx0x; + } + } + } + } + arcs_data[idx012] = arc; + raw_arc_map_data[idx012] = arc_map_value; + }); + + FsaVec fsas = Ragged(arcs_shape, arcs); + // arcsort so as to remove duplicate arcs. + Array1 arc_sort_new2old(c, num_arcs); + SortSublists(&fsas, &arc_sort_new2old); + + // remove duplicate arcs, use renumbering + Renumbering renumber_arcs(c, num_arcs); + char *keep_arcs_data = renumber_arcs.Keep().Data(); + K2_EVAL( + c, num_arcs, lambda_set_keep_arcs, (int32_t idx012) -> void { + char keep = 1; + if (idx012 < num_arcs - 1) { + int32_t idx01 = arcs_shape_row_ids2_data[idx012], + idx01_next = arcs_shape_row_ids2_data[idx012 + 1]; + // duplicate arcs, which are arcs with the same symbol going from the + // same src_state to the same dest_state. The symbol will automatically + // be the same if the src_state and dest_state are the same if + // context_size > 0. + if (idx01 == idx01_next && + arcs_data[idx012].src_state == arcs_data[idx012 + 1].src_state && + arcs_data[idx012].dest_state == arcs_data[idx012 + 1].dest_state) { + K2_DCHECK_EQ(arcs_data[idx012].label, arcs_data[idx012 + 1].label); + keep = 0; + } + } + keep_arcs_data[idx012] = keep; + }); + + Array1 renumber_arc_map; + FsaVec final_fsas = Index( + fsas, 2, renumber_arcs.New2Old(), &renumber_arc_map); + + if (arc_map != nullptr) { + *arc_map = raw_arc_map[arc_sort_new2old][renumber_arc_map]; + } + return final_fsas; +} + } // namespace k2 diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index 92dde5dfe..39a73216f 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -842,7 +842,7 @@ FsaOrVec ReplaceFsa(FsaVec &src, FsaOrVec &index, int32_t symbol_range_begin, weight 'x.weight', then the reverse of 'src' accepts the reverse of string 'x' with weight 'x.weight.reverse'. - Implementation notss: + Implementation notes: The Fsa in k2 only has one start state 0, and the only final state with the largest state number whose in-coming arcs have "-1" as the label. So, 1) the start state of 'dest' will correspond to the final state of 'src'. @@ -864,13 +864,51 @@ FsaOrVec ReplaceFsa(FsaVec &src, FsaOrVec &index, int32_t symbol_range_begin, @param [out] dest Output Fsa or FsaVec. At exit, it will be equivalent to the reverse Fsa of 'src'. Caution: the reverse will ignore the "-1" label. - + @param [out,optional] arc_map For each arc in `dest`, gives the index of the corresponding arc in `src` that it corresponds to. */ void Reverse(FsaVec &src, FsaVec *dest, Array1 *arc_map = nullptr); +/* + * Generate denominator lattice from sampled linear paths for RNN-T+MMI + * training. + * + * Implementation notes: + * 1) Generate "states" for each sampled symbol from their left_symbols and + * the frame_ids they are sampled from. + * 2) Sort those "states" for each sequence and then merge the same "states". + * 3) Map all of the sampled symbols to the merged "states". + * 4) Remove duplicate arcs. 
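+ *
+ * Note: an arc leaving a state that was sampled n times gets score
+ * -log(1 - (1 - p)^n), where p is the sampling probability of the arc's
+ * symbol; 1 - (1 - p)^n is the probability of drawing that symbol at
+ * least once in n independent samples from that state.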
+ * + * @param [in] sampled_paths The sampled symbols, it has a regular shape of + * [seq][num_path][path_length]. All its elements MUST satisfy + * `0 <= value < vocab_size. + * @param [in] frame_ids It contains the frame indexes of at which frame we + * sampled the symbols, which has same shape of sampled_paths. + * @param [in] left_symbols The left_symbols of the sampled symbols, it has a + * regular shape of [seq][num_path][path_length][context], the + * first three indexes are the same as sampled_paths. Each + * sublist along axis 3 has `context_size` elements. All its + * elements MUST satisfy `0 <= value < vocab_size`. + * @param [in] sampling_probs It contains the probabilities of sampling each + * symbol, which has the same shape as sampled_paths. + * @param [in] vocab_size The vocabulary size. + * @param [in] context_size The number of left symbols. + * @param [out] arc_map For each arc in the return Fsa, gives the orignal + * index (idx012) in sampled_paths that it corresponds to. + * + * @return Return the generated lattice. + */ +FsaVec GenerateDenominatorLattice(Ragged &sampled_paths, + Ragged &frame_ids, + Ragged &left_symbols, + Ragged &sampling_probs, + int32_t vocab_size, + int32_t context_size, + Array1 *arc_map); + } // namespace k2 #endif // K2_CSRC_FSA_ALGO_H_ diff --git a/k2/csrc/fsa_algo_test.cu b/k2/csrc/fsa_algo_test.cu index 903514318..777a085d8 100644 --- a/k2/csrc/fsa_algo_test.cu +++ b/k2/csrc/fsa_algo_test.cu @@ -1383,4 +1383,53 @@ TEST(FsaAlgo, TestLevenshteinGraph) { } } +TEST(FsaAlgo, TestGenerateDenominatorLattice) { + for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) { + Ragged sampled_paths(c, "[ [ [ 3 5 0 4 6 0 2 1 ] " + " [ 2 0 5 4 0 6 1 2 ] " + " [ 3 5 2 0 0 1 6 4 ] ] " + " [ [ 7 0 4 0 6 0 3 0 ] " + " [ 0 7 3 0 2 0 4 5 ] " + " [ 7 0 3 4 0 1 2 0 ] ] ]"); + Ragged frame_ids(c, "[ [ [ 0 0 0 1 1 1 2 2 ] " + " [ 0 0 1 1 1 2 2 2 ] " + " [ 0 0 0 0 1 2 2 2 ] ] " + " [ [ 0 0 1 1 2 2 3 3 ] " + " [ 0 1 1 1 2 2 3 1 ] " + " [ 0 0 1 1 1 2 2 2 ] ] ]"); + Ragged left_symbols(c, + "[ [ [ [ 0 0 ] [ 0 3 ] [ 3 5 ] [ 3 5 ] [ 5 4 ] [ 4 6 ] [ 4 6 ] [ 6 2 ] ] " + " [ [ 0 0 ] [ 0 2 ] [ 0 2 ] [ 2 5 ] [ 5 4 ] [ 5 4 ] [ 4 6 ] [ 6 1 ] ] " + " [ [ 0 0 ] [ 0 3 ] [ 3 5 ] [ 5 2 ] [ 5 2 ] [ 5 2 ] [ 2 1 ] [ 1 6 ] ] " + " ] " + " [ [ [ 0 0 ] [ 0 7 ] [ 0 7 ] [ 7 4 ] [ 7 4 ] [ 4 6 ] [ 4 6 ] [ 6 3 ] ] " + " [ [ 0 0 ] [ 0 0 ] [ 0 7 ] [ 7 3 ] [ 7 3 ] [ 3 2 ] [ 3 2 ] [ 0 0 ] ] " + " [ [ 0 0 ] [ 0 7 ] [ 0 7 ] [ 7 3 ] [ 3 4 ] [ 3 4 ] [ 4 1 ] [ 1 2 ] ] " + " ] ]"); + + Ragged sampling_probs(c, "[ [ [ 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 ] " + " [ 0.2 0.2 0.2 0.1 0.2 0.2 0.1 0.2 ] " + " [ 0.1 0.1 0.1 0.3 0.2 0.3 0.3 0.3 ] ] " + " [ [ 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 ] " + " [ 0.2 0.2 0.2 0.1 0.2 0.1 0.2 0.2 ] " + " [ 0.1 0.1 0.2 0.2 0.3 0.3 0.3 0.3 ] ] " + "]"); + Ragged path_scores(c, "[ [ [ 1 1 1 1 1 1 1 1 ] " + " [ 1 2 2 1 2 2 1 2 ] " + " [ 1 1 1 3 3 3 3 3 ] ] " + " [ [ 1 1 1 1 1 1 1 1 ] " + " [ 1 2 1 2 2 2 2 2 ] " + " [ 1 1 1 2 3 3 3 3 ] ] ]"); + + Array1 arc_map; + FsaVec lattice = GenerateDenominatorLattice( + sampled_paths, frame_ids, left_symbols, sampling_probs, + 10 /*vocab_size*/, 2 /*context_size*/, &arc_map); + K2_LOG(INFO) << arc_map; + K2_LOG(INFO) << lattice; + K2_LOG(INFO) << FsaToString(lattice.Index(0, 0)); + K2_LOG(INFO) << FsaToString(lattice.Index(0, 1)); + } +} + } // namespace k2 diff --git a/k2/csrc/math.h b/k2/csrc/math.h index 250cbd6ca..bfe7d64a2 100644 --- a/k2/csrc/math.h +++ b/k2/csrc/math.h @@ -29,8 +29,7 @@ namespace k2 { // Currently, only used in 
k2/csrc/rnnt_decode.cu // See https://github.com/k2-fsa/k2/pull/951#issuecomment-1096650842 -K2_CUDA_HOSTDEV __forceinline__ int64_t Pow(int64_t base, - int64_t exponent) { +K2_CUDA_HOSTDEV __forceinline__ int64_t Pow(int64_t base, int64_t exponent) { K2_CHECK_GE(exponent, 0); int64_t exp = 0; int64_t result = 1; diff --git a/k2/python/csrc/torch/fsa_algo.cu b/k2/python/csrc/torch/fsa_algo.cu index 304fde809..61561ec64 100644 --- a/k2/python/csrc/torch/fsa_algo.cu +++ b/k2/python/csrc/torch/fsa_algo.cu @@ -765,6 +765,29 @@ static void PybindReverse(py::module &m) { py::arg("src"), py::arg("need_arc_map") = true); } +static void PybindGenerateDenominatorLattice(py::module &m) { + m.def( + "generate_denominator_lattice", + [](RaggedAny &sampled_paths, RaggedAny &frame_ids, + RaggedAny &left_symbols, RaggedAny &sampling_probs, + int32_t vocab_size, int32_t context_size) + -> std::pair { + DeviceGuard guard(sampled_paths.any.Context()); + Array1 arc_map; + FsaVec lattice = GenerateDenominatorLattice( + sampled_paths.any.Specialize(), + frame_ids.any.Specialize(), + left_symbols.any.Specialize(), + sampling_probs.any.Specialize(), + vocab_size, context_size, &arc_map); + auto arc_map_tensor = ToTorch(arc_map); + return std::make_pair(lattice, arc_map_tensor); + }, + py::arg("sampled_paths"), py::arg("frame_ids"), py::arg("left_symbols"), + py::arg("sampling_probs"), py::arg("vocab_size"), + py::arg("context_size")); +} + } // namespace k2 void PybindFsaAlgo(py::module &m) { @@ -777,6 +800,7 @@ void PybindFsaAlgo(py::module &m) { k2::PybindDeterminize(m); k2::PybindExpandArcs(m); k2::PybindFixFinalLabels(m); + k2::PybindGenerateDenominatorLattice(m); k2::PybindIntersect(m); k2::PybindIntersectDense(m); k2::PybindIntersectDensePruned(m); diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index 0b7c56d87..ac7e4d4ac 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -58,6 +58,7 @@ from .fsa_algo import ctc_topo from .fsa_algo import determinize from .fsa_algo import expand_ragged_attributes +from .fsa_algo import generate_denominator_lattice from .fsa_algo import intersect from .fsa_algo import intersect_device from .fsa_algo import invert diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index cba26c67b..0707cb7df 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -1381,3 +1381,57 @@ def union(fsas: Fsa) -> Fsa: out_fsa = k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc, arc_map) return out_fsa + + +def generate_denominator_lattice( + sampled_paths: torch.Tensor, + frame_ids: torch.Tensor, + left_symbols: torch.Tensor, + sampling_probs: torch.Tensor, + path_scores: torch.Tensor, + vocab_size: int, + context_size: int, +) -> Fsa: + """Generate denominator lattice from sampled linear paths for RNN-T+MMI + training. + + Args: + sampled_paths: + The sampled symbols, it has a shape of (seq, num_path, path_length). + All its elements MUST satisfy `0 <= value < vocab_size. + frame_ids: + It contains the frame indexes of at which frame we sampled the symbols, + which has same shape of sampled_paths. + left_symbols: + The left_symbols of the sampled symbols, it has a shape of + (seq, num_path, path_length, context_size), the first three indexes are + the same as sampled_paths. All its elements MUST satisfy + `0 <= value < vocab_size`. + sampling_probs: + It contains the probabilities of sampling each symbol, which has a + same shape as sampled_paths. Normally comes from the output of + "predictor" head. 
+ path_scores: + It contains the scores of each sampled symbol, which has a same shape as + sampled_paths. It might contain the output of hybrid head and the extra + language model output. Note: Autograd is supported for this tensor. + vocab_size: + The vocabulary size. + context_size: + The number of left symbols. + """ + ragged_arc, arc_map = _k2.generate_denominator_lattice( + sampled_paths=k2.RaggedTensor(sampled_paths), + frame_ids=k2.RaggedTensor(frame_ids), + left_symbols=k2.RaggedTensor(left_symbols), + sampling_probs=k2.RaggedTensor(sampling_probs), + vocab_size=vocab_size, + context_size=context_size, + ) + lattice = Fsa(ragged_arc) + a_value = getattr(lattice, "scores") + # Enable autograd for path_scores + b_value = index_select(path_scores.flatten(), arc_map) + value = a_value + b_value + setattr(lattice, "scores", value) + return lattice diff --git a/k2/python/k2/mmi.py b/k2/python/k2/mmi.py deleted file mode 100644 index 74c77dc7e..000000000 --- a/k2/python/k2/mmi.py +++ /dev/null @@ -1,104 +0,0 @@ -import torch -from torch.distributions.categorical import Categorical -from typing import Tuple - -def importance_sampling( - sampling_scores: torch.Tensor, - path_length: int, - num_paths: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Args: - sampling_scores: - The output of predictor head, a tensor of shape (B, S, T, V) containing - the probabilities of emitting symbols at each (t, s) for each sequence. - path_length: - How many symbols we will sample for each path. - num_paths: - How many paths we will sample for each sequence. - - Returns: - Three tensors will be returned. - - sampled_indexs: - A tensor of shape (B, num_paths, path_length), containing the sampled symbol ids. - - sampled_scores: - A tensor of shape (B, num_paths, path_length), containing the sampling probabilities - of the corresponding symbols. - - sampled_t_indexs: - A tensor of shape (B, num_paths, path_length), containing the frame ids, which means - at what frame this symple be sampled. - """ - (B, S, T, V) = sampling_scores.shape - # we sample paths from frame 0 - t_index = torch.zeros( - (B, num_paths), - dtype=torch.int64, - device=sampling_scores.device - ) - # we sample paths from the first symbols (i.e. 
from null left_context) - s_index = torch.zeros( - (B, num_paths), - dtype=torch.int64, - device=sampling_scores.device - ) - - sampled_indexs = [] - sampled_scores = [] - sampled_t_indexs = [] - - for i in range(path_length): - # select context symbols for paths - # sub_scores : (B, num_paths, T, V) - sub_scores = torch.gather( - sampling_scores, dim=1, - index=s_index.reshape(B, num_paths, 1, 1).expand(B, num_paths, T, V)) - - # select frames for paths - # sub_scores : (B, num_paths, 1, V) - sub_scores = torch.gather( - sub_scores, dim=2, - index=t_index.reshape(B, num_paths, 1, 1).expand(B, num_paths, 1, V)) - - # sub_scores : (B, num_paths, V) - sub_scores = sub_scores.squeeze(2) - # sampler: https://pytorch.org/docs/stable/distributions.html#categorical - sampler = Categorical(probs=sub_scores) - - # sample one symbol for each path - # index : (B, num_paths) - index = sampler.sample() - sampled_indexs.append(index) - - # gather sampling probabilities for corresponding indexs - # score : (B, num_paths, 1) - score = torch.gather(sub_scores, dim=2, index=index.unsqueeze(2)) - sampled_scores.append(score.squeeze(2)) - - sampled_t_indexs.append(t_index) - - # update (t, s) for each path (for regular RNN-T) - # index == 0 means the sampled symbol is blank - t_mask = index == 0 - t_index = torch.where(t_mask, t_index + 1, t_index) - s_index = torch.where(t_mask, s_index + 1, s_index) - - # indexs : (B, num_paths, path_lengths) - indexs = torch.stack(sampled_indexs, dim=0).permute(1,2,0) - # scores : (B, num_paths, path_lengths) - scores = torch.stack(sampled_scores, dim=0).permute(1,2,0) - # t_indexs : (B, num_paths, path_lengths) - t_indexs = torch.stack(sampled_t_indexs, dim=0).permute(1,2,0) - - return indexs, scores, t_indexs - - -if __name__ == "__main__": - B, S, T, V = 2, 10, 20, 10 - path_length = 8 - num_path = 3 - logits = torch.randn((B, S, T, V)) - log_prob = torch.softmax(logits, -1) - indexs, scores, t_indexs = importance_sampling(log_prob, path_length, num_path) - print (indexs) - print (scores) - print (t_indexs) diff --git a/k2/python/tests/CMakeLists.txt b/k2/python/tests/CMakeLists.txt index 455ebdd12..428988121 100644 --- a/k2/python/tests/CMakeLists.txt +++ b/k2/python/tests/CMakeLists.txt @@ -35,6 +35,7 @@ set(py_test_files fsa_from_unary_function_ragged_test.py fsa_from_unary_function_tensor_test.py fsa_test.py + generate_denominator_lattice_test.py get_arc_post_test.py get_backward_scores_test.py get_forward_scores_test.py diff --git a/k2/python/tests/generate_denominator_lattice_test.py b/k2/python/tests/generate_denominator_lattice_test.py new file mode 100644 index 000000000..31a656808 --- /dev/null +++ b/k2/python/tests/generate_denominator_lattice_test.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# To run this single test, use +# +# ctest --verbose -R generate_denominator_lattice_test_py + +import unittest + +import k2 +import torch + +from torch.distributions.categorical import Categorical +from typing import Tuple + + +def _roll_by_shifts( + src: torch.Tensor, shifts: torch.LongTensor +) -> torch.Tensor: + """Roll tensor with different shifts for each row. + + Note: + We assume the src is a 3 dimensions tensor and roll the last dimension. + + Example: + + >>> src = torch.arange(15).reshape((1,3,5)) + >>> src + tensor([[[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14]]]) + >>> shift = torch.tensor([[1, 2, 3]]) + >>> shift + tensor([[1, 2, 3]]) + >>> _roll_by_shifts(src, shift) + tensor([[[ 4, 0, 1, 2, 3], + [ 8, 9, 5, 6, 7], + [12, 13, 14, 10, 11]]]) + """ + assert src.dim() == 3 + (B, T, S) = src.shape + assert shifts.shape == (B, T) + + index = ( + torch.arange(S, device=src.device) + .view((1, S)) + .repeat((T, 1)) + .repeat((B, 1, 1)) + ) + index = (index - shifts.reshape(B, T, 1)) % S + return torch.gather(src, 2, index) + + +def simulate_importance_sampling( + batch_size: int, + vocab_size: int, + path_length: int, + num_paths: int, + context_size: int = 2, + blank_id: int = 0, + device: torch.device = torch.device("cpu"), +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + batch_size: + The number of sequence. + vocab_size: + Vocabulary size. + path_length: + How many symbols we will sample for each path. + num_paths: + How many paths we will sample for each sequence. + context_size: + The number of left symbols. + blank_id: + Stands for null context. + + Returns: + Three tensors will be returned. + - sampled_paths: + A tensor of shape (batch_size, num_paths, path_length), containing the + sampled symbol ids. + - sampling_probs: + A tensor of shape (batch_size, num_paths, path_length), containing the + sampling probabilities of the sampled symbols. + - left_symbols: + A tensor of shape (batch_size, num_paths, path_length, context_size), + containing the left symbols of the sampled symbols. + - frame_ids: + A tensor of shape (batch_size, num_paths, path_length), containing the + frame ids at which we sampled the symbols. 
+ """ + # we sample paths from frame 0 + t_index = torch.zeros( + (batch_size, num_paths), dtype=torch.int64, device=device + ) + + left_symbols = torch.tensor( + [blank_id], dtype=torch.int64, device=device + ).expand(batch_size, num_paths, context_size) + + sampled_paths_list = [] + sampling_probs_list = [] + frame_ids_list = [] + left_symbols_list = [] + + for i in range(path_length): + probs = torch.randn(batch_size, num_paths, vocab_size) + probs = torch.softmax(probs, -1) + # sampler: https://pytorch.org/docs/stable/distributions.html#categorical + sampler = Categorical(probs=probs) + + # sample one symbol for each path + # index : (batch_size, num_paths) + index = sampler.sample() + sampled_paths_list.append(index) + + # gather sampling probabilities for corresponding indexs + # sampling_prob : (batch_size, num_paths, 1) + sampling_probs = torch.gather(probs, dim=2, index=index.unsqueeze(2)) + sampling_probs_list.append(sampling_probs.squeeze(2)) + + frame_ids_list.append(t_index) + + left_symbols_list.append(left_symbols) + + # update (t, s) for each path + # index == 0 means the sampled symbol is blank + t_mask = index == 0 + # t_index = torch.where(t_mask, t_index + 1, t_index) + t_index = t_index + 1 + current_symbols = torch.cat([left_symbols, index.unsqueeze(2)], dim=2) + left_symbols = _roll_by_shifts(current_symbols, t_mask.to(torch.int64)) + left_symbols = left_symbols[:, :, 1:] + + # sampled_paths : (batch_size, num_paths, path_lengths) + sampled_paths = torch.stack(sampled_paths_list, dim=2).int() + # sampling_probs : (batch_size, num_paths, path_lengths) + sampling_probs = torch.stack(sampling_probs_list, dim=2) + # frame_ids : (batch_size , num_paths, path_lengths) + frame_ids = torch.stack(frame_ids_list, dim=2).int() + # left_symbols : (batch_size, num_paths, path_lengths, context_size) + left_symbols = torch.stack(left_symbols_list, dim=2).int() + return sampled_paths, frame_ids, sampling_probs, left_symbols + + +class TestConnect(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.devices = [torch.device("cpu")] + if torch.cuda.is_available() and k2.with_cuda: + cls.devices.append(torch.device("cuda", 0)) + if torch.cuda.device_count() > 1: + torch.cuda.set_device(1) + cls.devices.append(torch.device("cuda", 1)) + + def test(self): + context_size = 2 + batch_size, num_paths, path_length, vocab_size = 2, 3, 10, 10 + ( + sampled_paths_, + frame_ids_, + sampling_probs_, + left_symbols_, + ) = simulate_importance_sampling( + batch_size=batch_size, + vocab_size=vocab_size, + num_paths=num_paths, + path_length=path_length, + context_size=context_size, + ) + path_scores_ = torch.randn( + (batch_size, num_paths, path_length), dtype=torch.float + ) + for device in self.devices: + sampled_paths = sampled_paths_.to(device) + sampling_probs = sampling_probs_.to(device) + frame_ids = frame_ids_.to(device) + left_symbols = left_symbols_.to(device) + path_scores = path_scores_.detach().clone().to(device) + path_scores.requires_grad_(True) + fsa = k2.generate_denominator_lattice( + sampled_paths=sampled_paths, + frame_ids=frame_ids, + left_symbols=left_symbols, + sampling_probs=sampling_probs, + path_scores=path_scores, + vocab_size=vocab_size, + context_size=context_size, + ) + fsa = k2.connect(k2.top_sort(fsa)) + print(fsa) + scores = torch.sum( + fsa.get_tot_scores(log_semiring=True, use_double_scores=False) + ) + scores.backward() + print(path_scores.grad) + + +if __name__ == "__main__": + unittest.main() From 351186924e20c50dbdc6dcf06eb55950637d1ae9 Mon Sep 17 
00:00:00 2001 From: pkufool Date: Tue, 9 Aug 2022 20:01:11 +0800 Subject: [PATCH 3/8] Fix typos --- k2/csrc/fsa_algo.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu index 075706dc1..29eea9822 100644 --- a/k2/csrc/fsa_algo.cu +++ b/k2/csrc/fsa_algo.cu @@ -2107,7 +2107,7 @@ FsaVec GenerateDenominatorLattice(Ragged &sampled_paths, // Second, figure out the number of arcs for each merged state. int32_t num_seqs = states.TotSize(0); - // Plus 1 here because we will applying ExclusiveSum on this array. + // Plus 1 here because we will apply ExclusiveSum on this array. Array1 num_states_for_seqs(c, states.TotSize(0) + 1); // "ss" is short for "sorted states" @@ -2136,7 +2136,7 @@ FsaVec GenerateDenominatorLattice(Ragged &sampled_paths, K2_CHECK_EQ(unique_states_shape.RowSplits(1).Dim() - 1 + num_seqs * 2, num_merged_states); - // Plus 1 here because we will applying ExclusiveSum on this array. + // Plus 1 here because we will apply ExclusiveSum on this array. Array1 num_arcs_for_states( c, seqs_to_states_shape.NumElements() + 1); From 5dc671fbab23ef5eec460003ec6b696764bc8ec0 Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 17 Aug 2022 10:45:05 +0800 Subject: [PATCH 4/8] Only allow states sampled on final frame to be final state; add unnormalized rnnt loss --- k2/csrc/fsa_algo.cu | 87 +++++---- k2/csrc/fsa_algo.h | 2 + k2/csrc/fsa_algo_test.cu | 9 +- k2/python/csrc/torch/fsa_algo.cu | 6 +- k2/python/k2/__init__.py | 1 + k2/python/k2/fsa_algo.py | 6 +- k2/python/k2/rnnt_loss.py | 169 ++++++++++++++++++ .../generate_denominator_lattice_test.py | 29 ++- 8 files changed, 260 insertions(+), 49 deletions(-) diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu index 29eea9822..651fcad61 100644 --- a/k2/csrc/fsa_algo.cu +++ b/k2/csrc/fsa_algo.cu @@ -1995,6 +1995,7 @@ FsaVec GenerateDenominatorLattice(Ragged &sampled_paths, Ragged &frame_ids, Ragged &left_symbols, Ragged &sampling_probs, + Array1 &boundary, int32_t vocab_size, int32_t context_size, Array1 *arc_map) { @@ -2009,6 +2010,7 @@ FsaVec GenerateDenominatorLattice(Ragged &sampled_paths, K2_DCHECK_EQ(sampled_paths.NumElements(), left_symbols.NumElements() * context_size); K2_DCHECK_EQ(sampled_paths.NumElements(), sampling_probs.NumElements()); + K2_DCHECK_EQ(sampled_paths.TotSize(0), boundary.Dim()); for (int32_t i = 0; i < 3; ++i) { K2_DCHECK_EQ(sampled_paths.TotSize(i), frame_ids.TotSize(i)); K2_DCHECK_EQ(sampled_paths.TotSize(i), left_symbols.TotSize(i)); @@ -2123,9 +2125,12 @@ FsaVec GenerateDenominatorLattice(Ragged &sampled_paths, us_idx0 = us_row_ids1_data[ss_idx0x], us_idx0_next_minus_1 = us_row_ids1_data[ss_idx0x_next - 1], num_unique_states = us_idx0_next_minus_1 - us_idx0 + 1; - // Plus 2 here, because we need a super dest_state for the last sampled - // symbol of each path, and a final state needed by k2. - num_states_for_seqs_data[idx0] = num_unique_states + 2; + // Plus 3 here, because we need a super dest_state for the states sampled + // on the last frame (this dest_state will point to the final state), + // a fake super dest_state for the last states of linear paths that + // are not sampled on the last frames (this fake dest_state will be + // removed by connect operation), and a final state needed by k2. 
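+        // i.e. the per-sequence layout of merged states becomes
+        // [ unique states, fake dest_state, super dest_state, final state ].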
+ num_states_for_seqs_data[idx0] = num_unique_states + 3; }); ExclusiveSum(num_states_for_seqs, &num_states_for_seqs); @@ -2133,7 +2138,7 @@ FsaVec GenerateDenominatorLattice(Ragged &sampled_paths, &num_states_for_seqs, nullptr, -1); int32_t num_merged_states = seqs_to_states_shape.NumElements(); - K2_CHECK_EQ(unique_states_shape.RowSplits(1).Dim() - 1 + num_seqs * 2, + K2_CHECK_EQ(unique_states_shape.RowSplits(1).Dim() - 1 + num_seqs * 3, num_merged_states); // Plus 1 here because we will apply ExclusiveSum on this array. @@ -2153,15 +2158,17 @@ FsaVec GenerateDenominatorLattice(Ragged &sampled_paths, int32_t idx0 = sts_row_ids1_data[idx01], idx0x_next = sts_row_splits1_data[idx0 + 1], num_arcs = 0; - // The final state for each sequence. + // The final arc for each sequence. if (idx01 == idx0x_next - 2) num_arcs = 1; - if (idx01 < idx0x_next - 2) { - // Minus idx0 * 2, because we add extra two states for each sequence. - int32_t us_idx0 = idx01 - idx0 * 2, + if (idx01 < idx0x_next - 3) { + // Minus idx0 * 3, because we add extra three states for each sequence. + int32_t us_idx0 = idx01 - idx0 * 3, us_idx0x = us_row_splits1_data[us_idx0], us_idx0x_next = us_row_splits1_data[us_idx0 + 1]; num_arcs = us_idx0x_next - us_idx0x; } + // idx01 == idx0x_next - 3 (i.e. the fake super dest_state) and + // idx01 == idx0x_next -1 (i.e. the final state) don't have arcs. num_arcs_for_states_data[idx01] = num_arcs; }); @@ -2185,6 +2192,7 @@ FsaVec GenerateDenominatorLattice(Ragged &sampled_paths, *arcs_shape_row_splits1_data = arcs_shape.RowSplits(1).Data(), *arcs_shape_row_ids2_data = arcs_shape.RowIds(2).Data(), *states_row_ids2_data = states.RowIds(2).Data(), + *boundary_data = boundary.Data(), *ss_row_ids1_data = sorted_states.RowIds(1).Data(); const float *sampling_probs_data = sampling_probs.values.Data(); Array1 arcs(c, num_arcs); @@ -2233,37 +2241,50 @@ FsaVec GenerateDenominatorLattice(Ragged &sampled_paths, arc.score = -logf(1 - powf(1 - sampling_prob, repeat_num)); - // Final state of the last sequence, it will point to the added super - // dest_state. + K2_DCHECK_LT(frame_ids_data[states_idx012], boundary_data[idx0]); + + int32_t idx0x_next = arcs_shape_row_splits1_data[idx0 + 1]; + + // Handle the final state of last sequence. if (states_idx012 == num_states - 1) { - int32_t idx0x_next = arcs_shape_row_splits1_data[idx0 + 1]; - arc.dest_state = idx0x_next - idx0x - 2; + // If current state is on final frame, it will point to the added + // super dest_state. + if (frame_ids_data[states_idx012] == boundary_data[idx0] - 1) { + arc.dest_state = idx0x_next - idx0x - 2; + } else { + // point to the fake added dest_state. + arc.dest_state = idx0x_next - idx0x - 3; + } } else { // states_idx01 is path index int32_t states_idx01 = states_row_ids2_data[states_idx012], states_idx01_next = - states_row_ids2_data[states_idx012 + 1], - frame_id = frame_ids_data[states_idx012], - frame_id_next = frame_ids_data[states_idx012 + 1]; - // The first condition means this is the final state of each - // sequence. - // The second condition means we reach final frame at this state, - // the next state will be a start state of another path. - // So, this state points to the added super dest_state. 
- if (states_idx01 != states_idx01_next || - (states_idx01 == states_idx01_next && - frame_id_next < frame_id)) { - int32_t idx0x_next = - arcs_shape_row_splits1_data[idx0 + 1]; - arc.dest_state = idx0x_next - idx0x - 2; + states_row_ids2_data[states_idx012 + 1]; + if (states_idx01 != states_idx01_next) { + // If current state is on final frame, it will point to the added + // super dest_state. + if (frame_ids_data[states_idx012] == boundary_data[idx0] - 1) { + arc.dest_state = idx0x_next - idx0x - 2; + } else { + // point to the fake added dest_state. + arc.dest_state = idx0x_next - idx0x - 3; + } } else { - // states_idx012 + 1 is the index of original consecutive state. - // "ss" is short for "sorted states" - // "us" is short for "unique states". - int32_t ss_idx01_next = - sorted_states_old2new_data[states_idx012 + 1], - us_idx0_next = us_row_ids1_data[ss_idx01_next]; - arc.dest_state = us_idx0_next + 2 * idx0 - idx0x; + // If current state is on final frame, it will point to the added + // super dest_state. + if (frame_ids_data[states_idx012] == boundary_data[idx0] - 1 && + frame_ids_data[states_idx012 + 1] != boundary_data[idx0] - 1) { + arc.dest_state = idx0x_next - idx0x - 2; + } else { + // states_idx012 + 1 is the index of original consecutive state. + // "ss" is short for "sorted states" + // "us" is short for "unique states". + int32_t ss_idx01_next = + sorted_states_old2new_data[states_idx012 + 1], + us_idx0_next = us_row_ids1_data[ss_idx01_next]; + // Plus 3 * idx0, because we add 3 state for each sequence + arc.dest_state = us_idx0_next + 3 * idx0 - idx0x; + } } } } diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index 39a73216f..9097bfd53 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -894,6 +894,7 @@ void Reverse(FsaVec &src, FsaVec *dest, Array1 *arc_map = nullptr); * elements MUST satisfy `0 <= value < vocab_size`. * @param [in] sampling_probs It contains the probabilities of sampling each * symbol, which has the same shape as sampled_paths. + * @param [in] boundary It contains the number of frames for each sequence. * @param [in] vocab_size The vocabulary size. * @param [in] context_size The number of left symbols. 
* @param [out] arc_map For each arc in the return Fsa, gives the orignal @@ -905,6 +906,7 @@ FsaVec GenerateDenominatorLattice(Ragged &sampled_paths, Ragged &frame_ids, Ragged &left_symbols, Ragged &sampling_probs, + Array1 &boundary, int32_t vocab_size, int32_t context_size, Array1 *arc_map); diff --git a/k2/csrc/fsa_algo_test.cu b/k2/csrc/fsa_algo_test.cu index 777a085d8..f30ca4d97 100644 --- a/k2/csrc/fsa_algo_test.cu +++ b/k2/csrc/fsa_algo_test.cu @@ -1414,16 +1414,11 @@ TEST(FsaAlgo, TestGenerateDenominatorLattice) { " [ 0.2 0.2 0.2 0.1 0.2 0.1 0.2 0.2 ] " " [ 0.1 0.1 0.2 0.2 0.3 0.3 0.3 0.3 ] ] " "]"); - Ragged path_scores(c, "[ [ [ 1 1 1 1 1 1 1 1 ] " - " [ 1 2 2 1 2 2 1 2 ] " - " [ 1 1 1 3 3 3 3 3 ] ] " - " [ [ 1 1 1 1 1 1 1 1 ] " - " [ 1 2 1 2 2 2 2 2 ] " - " [ 1 1 1 2 3 3 3 3 ] ] ]"); + Array1 boundary(c, "[ 3 4 ]"); Array1 arc_map; FsaVec lattice = GenerateDenominatorLattice( - sampled_paths, frame_ids, left_symbols, sampling_probs, + sampled_paths, frame_ids, left_symbols, sampling_probs, boundary, 10 /*vocab_size*/, 2 /*context_size*/, &arc_map); K2_LOG(INFO) << arc_map; K2_LOG(INFO) << lattice; diff --git a/k2/python/csrc/torch/fsa_algo.cu b/k2/python/csrc/torch/fsa_algo.cu index 61561ec64..0c4763eed 100644 --- a/k2/python/csrc/torch/fsa_algo.cu +++ b/k2/python/csrc/torch/fsa_algo.cu @@ -770,21 +770,23 @@ static void PybindGenerateDenominatorLattice(py::module &m) { "generate_denominator_lattice", [](RaggedAny &sampled_paths, RaggedAny &frame_ids, RaggedAny &left_symbols, RaggedAny &sampling_probs, - int32_t vocab_size, int32_t context_size) + torch::Tensor &boundary, int32_t vocab_size, int32_t context_size) -> std::pair { DeviceGuard guard(sampled_paths.any.Context()); Array1 arc_map; + Array1 boundary_array = FromTorch(boundary); FsaVec lattice = GenerateDenominatorLattice( sampled_paths.any.Specialize(), frame_ids.any.Specialize(), left_symbols.any.Specialize(), sampling_probs.any.Specialize(), + boundary_array, vocab_size, context_size, &arc_map); auto arc_map_tensor = ToTorch(arc_map); return std::make_pair(lattice, arc_map_tensor); }, py::arg("sampled_paths"), py::arg("frame_ids"), py::arg("left_symbols"), - py::arg("sampling_probs"), py::arg("vocab_size"), + py::arg("sampling_probs"), py::arg("boundary"), py::arg("vocab_size"), py::arg("context_size")); } diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index ac7e4d4ac..babe7d5ec 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -100,6 +100,7 @@ from .rnnt_loss import get_rnnt_logprobs_smoothed from .rnnt_loss import get_rnnt_prune_ranges from .rnnt_loss import rnnt_loss +from .rnnt_loss import rnnt_loss_for_numerator from .rnnt_loss import rnnt_loss_pruned from .rnnt_loss import rnnt_loss_simple from .rnnt_loss import rnnt_loss_smoothed diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index 0707cb7df..5d2ed8390 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -1389,6 +1389,7 @@ def generate_denominator_lattice( left_symbols: torch.Tensor, sampling_probs: torch.Tensor, path_scores: torch.Tensor, + boundary: torch.Tensor, vocab_size: int, context_size: int, ) -> Fsa: @@ -1415,6 +1416,8 @@ def generate_denominator_lattice( It contains the scores of each sampled symbol, which has a same shape as sampled_paths. It might contain the output of hybrid head and the extra language model output. Note: Autograd is supported for this tensor. + boundary: + It contains the number of frames for each sequence. vocab_size: The vocabulary size. 
context_size: @@ -1425,6 +1428,7 @@ def generate_denominator_lattice( frame_ids=k2.RaggedTensor(frame_ids), left_symbols=k2.RaggedTensor(left_symbols), sampling_probs=k2.RaggedTensor(sampling_probs), + boundary=boundary, vocab_size=vocab_size, context_size=context_size, ) @@ -1432,6 +1436,6 @@ def generate_denominator_lattice( a_value = getattr(lattice, "scores") # Enable autograd for path_scores b_value = index_select(path_scores.flatten(), arc_map) - value = a_value + b_value + value = b_value - a_value setattr(lattice, "scores", value) return lattice diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index 64e51a4c2..e54db2685 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -366,6 +366,175 @@ def get_rnnt_logprobs_joint( return (px, py) +def get_rnnt_logprobs_joint_for_numerator( + logits: Tensor, + symbols: Tensor, + termination_symbol: int, + boundary: Optional[Tensor] = None, + normalized: int = True, + modified: bool = False, +) -> Tuple[Tensor, Tensor]: + """Reduces RNN-T problem to a compact, standard form that can then be given + (with boundaries) to mutual_information_recursion(). + This function is called from rnnt_loss(). + + Args: + logits: + The output of joiner network, with shape (B, T, S + 1, C), + i.e. batch, time_seq_len, symbol_seq_len+1, num_classes + symbols: + A LongTensor of shape [B][S], containing the symbols at each position + of the sequence. + termination_symbol: + The identity of the termination symbol, must be in {0..C-1} + boundary: + a optional LongTensor of shape [B, 4] with elements interpreted as + [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as + [0, 0, S, T] + if boundary is not supplied. + Most likely you will want begin_symbol and begin_frame to be zero. + modified: if True, each time a real symbol is consumed a frame will + also be consumed, so at most 1 symbol can appear per frame. + Returns: + (px, py) (the names are quite arbitrary):: + + px: logprobs, of shape [B][S][T+1] + py: logprobs, of shape [B][S+1][T] + + in the recursion:: + + p[b,0,0] = 0.0 + if !modified: + p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], + p[b,s,t-1] + py[b,s,t-1]) + if modified: + p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1], + p[b,s,t-1] + py[b,s,t-1]) + .. where p[b][s][t] is the "joint score" of the pair of subsequences of + length s and t respectively. px[b][s][t] represents the probability of + extending the subsequences of length (s,t) by one in the s direction, + given the particular symbol, and py[b][s][t] represents the probability + of extending the subsequences of length (s,t) by one in the t direction, + i.e. of emitting the termination/next-frame symbol. + + if !modified, px[:,:,T] equals -infinity, meaning on the + "one-past-the-last" frame we cannot emit any symbols. + This is simply a way of incorporating + the probability of the termination symbol on the last frame. + """ + assert logits.ndim == 4 + (B, T, S1, C) = logits.shape + S = S1 - 1 + assert symbols.shape == (B, S) + + px = torch.gather( + logits, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1) + ).squeeze(-1) + px = px.permute((0, 2, 1)) + + if not modified: + px = torch.cat( + ( + px, + torch.full( + (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype + ), + ), + dim=2, + ) # now: [B][S][T+1], index [:,:,T] has -inf.. 
+ + py = ( + logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone() + ) # [B][S+1][T] + + if normalized: + normalizers = torch.logsumexp(logits, dim=3) + normalizers = normalizers.permute((0, 2, 1)) + px[:, :, :T] -= normalizers[:, :S, :] + py -= normalizers + + px = px.contiguous() + py = py.contiguous() + + if not modified: + px = fix_for_boundary(px, boundary) + + return (px, py) + + +def rnnt_loss_for_numerator( + logits: Tensor, + symbols: Tensor, + external_lm: Tensor, + termination_symbol: int, + boundary: Optional[Tensor] = None, + modified: bool = False, + normalized: bool = True, + reduction: Optional[str] = "mean", +) -> Tensor: + """A normal RNN-T loss, which uses a 'joiner' network output as input, + i.e. a 4 dimensions tensor. + + Args: + logits: + The output of joiner network, with shape (B, T, S + 1, C), + i.e. batch, time_seq_len, symbol_seq_len+1, num_classes + symbols: + The symbol sequences, a LongTensor of shape [B][S], and elements + in {0..C-1}. + external_lm: + External language model network, with shape (B, S + 1, C). + termination_symbol: + the termination symbol, with 0 <= termination_symbol < C + boundary: + a optional LongTensor of shape [B, 4] with elements interpreted as + [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as + [0, 0, S, T] if boundary is not supplied. + Most likely you will want begin_symbol and begin_frame to be zero. + modified: if True, each time a real symbol is consumed a frame will + also be consumed, so at most 1 symbol can appear per frame. + normalized: + True to do log_softmax normalization, otherwise not. + reduction: + Specifies the reduction to apply to the output: `none`, `mean` or `sum`. + `none`: no reduction will be applied. + `mean`: apply `torch.mean` over the batches. + `sum`: the output will be summed. + Default: `mean` + + Returns: + If recursion is `none`, returns a tensor of shape (B,), containing the + total RNN-T loss values for each element of the batch, otherwise a scalar + with the reduction applied. + """ + px, py = get_rnnt_logprobs_joint_for_numerator( + logits=logits, + symbols=symbols, + termination_symbol=termination_symbol, + boundary=boundary, + normalized=normalized, + modified=modified, + ) + + B, S, T1 = px.shape + px_external_lm = torch.gather( + external_lm[:, :S], dim=2, index=symbols.unsqueeze(-1) + ) # [B][S][1] + px += px_external_lm + + negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) + if reduction == "none": + return -negated_loss + elif reduction == "mean": + return -torch.mean(negated_loss) + elif reduction == "sum": + return -torch.sum(negated_loss) + else: + assert ( + False + ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" + + def rnnt_loss( logits: Tensor, symbols: Tensor, diff --git a/k2/python/tests/generate_denominator_lattice_test.py b/k2/python/tests/generate_denominator_lattice_test.py index 31a656808..29a06fbf5 100644 --- a/k2/python/tests/generate_denominator_lattice_test.py +++ b/k2/python/tests/generate_denominator_lattice_test.py @@ -67,7 +67,7 @@ def _roll_by_shifts( def simulate_importance_sampling( - batch_size: int, + boundary: torch.Tensor, vocab_size: int, path_length: int, num_paths: int, @@ -77,8 +77,9 @@ def simulate_importance_sampling( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: - batch_size: - The number of sequence. + boundary: + It is a tensor with shape (B,), containing the number of frames for + each sequence. vocab_size: Vocabulary size. 
path_length: @@ -106,10 +107,14 @@ def simulate_importance_sampling( frame ids at which we sampled the symbols. """ # we sample paths from frame 0 + batch_size = boundary.numel() + t_index = torch.zeros( (batch_size, num_paths), dtype=torch.int64, device=device ) + t_index_max = boundary.view(batch_size, 1).expand(batch_size, num_paths) + left_symbols = torch.tensor( [blank_id], dtype=torch.int64, device=device ).expand(batch_size, num_paths, context_size) @@ -142,11 +147,20 @@ def simulate_importance_sampling( # update (t, s) for each path # index == 0 means the sampled symbol is blank t_mask = index == 0 - # t_index = torch.where(t_mask, t_index + 1, t_index) + # t_index = torch.where(t_mask, t_index + 1, t_index) t_index = t_index + 1 + + final_mask = t_index >= t_index_max + reach_final = torch.any(final_mask) + if reach_final: + new_t_index = torch.randint(0, torch.min(t_index_max) - 1, (1,)).item() + t_index.masked_fill_(final_mask, new_t_index) + current_symbols = torch.cat([left_symbols, index.unsqueeze(2)], dim=2) left_symbols = _roll_by_shifts(current_symbols, t_mask.to(torch.int64)) left_symbols = left_symbols[:, :, 1:] + if reach_final: + left_symbols.masked_fill_(final_mask.unsqueeze(2), blank_id) # sampled_paths : (batch_size, num_paths, path_lengths) sampled_paths = torch.stack(sampled_paths_list, dim=2).int() @@ -172,13 +186,14 @@ def setUpClass(cls): def test(self): context_size = 2 batch_size, num_paths, path_length, vocab_size = 2, 3, 10, 10 + boundary_ = torch.tensor([6, 9], dtype=torch.int32) ( sampled_paths_, frame_ids_, sampling_probs_, left_symbols_, ) = simulate_importance_sampling( - batch_size=batch_size, + boundary=boundary_, vocab_size=vocab_size, num_paths=num_paths, path_length=path_length, @@ -188,6 +203,7 @@ def test(self): (batch_size, num_paths, path_length), dtype=torch.float ) for device in self.devices: + boundary = boundary_.to(device) sampled_paths = sampled_paths_.to(device) sampling_probs = sampling_probs_.to(device) frame_ids = frame_ids_.to(device) @@ -199,12 +215,13 @@ def test(self): frame_ids=frame_ids, left_symbols=left_symbols, sampling_probs=sampling_probs, + boundary=boundary, path_scores=path_scores, vocab_size=vocab_size, context_size=context_size, ) - fsa = k2.connect(k2.top_sort(fsa)) print(fsa) + fsa = k2.connect(k2.top_sort(fsa)) scores = torch.sum( fsa.get_tot_scores(log_semiring=True, use_double_scores=False) ) From 9d8e3e325bef9174f82deb6784ca820e84a8f4e8 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 18 Aug 2022 11:27:44 +0800 Subject: [PATCH 5/8] Minor fixes --- k2/python/k2/__init__.py | 1 - k2/python/k2/fsa_algo.py | 2 +- k2/python/k2/rnnt_loss.py | 256 +++++++++++--------------------------- 3 files changed, 74 insertions(+), 185 deletions(-) diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index babe7d5ec..ac7e4d4ac 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -100,7 +100,6 @@ from .rnnt_loss import get_rnnt_logprobs_smoothed from .rnnt_loss import get_rnnt_prune_ranges from .rnnt_loss import rnnt_loss -from .rnnt_loss import rnnt_loss_for_numerator from .rnnt_loss import rnnt_loss_pruned from .rnnt_loss import rnnt_loss_simple from .rnnt_loss import rnnt_loss_smoothed diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index 5d2ed8390..448e744b9 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -1436,6 +1436,6 @@ def generate_denominator_lattice( a_value = getattr(lattice, "scores") # Enable autograd for path_scores b_value = 
index_select(path_scores.flatten(), arc_map) - value = b_value - a_value + value = b_value + a_value setattr(lattice, "scores", value) return lattice diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index e54db2685..d00eb9e16 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -199,6 +199,7 @@ def rnnt_loss_simple( symbols: Tensor, termination_symbol: int, boundary: Optional[Tensor] = None, + external_lm: Optional[Tensor] = None, modified: bool = False, reduction: Optional[str] = "mean", return_grad: bool = False, @@ -224,6 +225,8 @@ def rnnt_loss_simple( [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + external_lm: + External language model network, with shape (B, S + 1, C). modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. reduction: @@ -255,6 +258,14 @@ def rnnt_loss_simple( boundary=boundary, modified=modified, ) + + if external_lm is not None: + B, S, T1 = px.shape + px_external_lm = torch.gather( + external_lm[:, :S], dim=2, index=symbols.unsqueeze(-1) + ) # [B][S][1] + px += px_external_lm + scores_and_grads = mutual_information_recursion( px=px, py=py, boundary=boundary, return_grad=return_grad ) @@ -266,9 +277,9 @@ def rnnt_loss_simple( elif reduction == "sum": loss = -torch.sum(negated_loss) else: - assert ( - False - ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" + raise ValueError ( + f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" + ) return (loss, scores_and_grads[1]) if return_grad else loss @@ -277,6 +288,7 @@ def get_rnnt_logprobs_joint( symbols: Tensor, termination_symbol: int, boundary: Optional[Tensor] = None, + normalized: bool = True, modified: bool = False, ) -> Tuple[Tensor, Tensor]: """Reduces RNN-T problem to a compact, standard form that can then be given @@ -298,6 +310,8 @@ def get_rnnt_logprobs_joint( [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + normalized: + True to do log_softmax normalization, otherwise not. modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. Returns: @@ -332,8 +346,10 @@ def get_rnnt_logprobs_joint( S = S1 - 1 assert symbols.shape == (B, S) - normalizers = torch.logsumexp(logits, dim=3) - normalizers = normalizers.permute((0, 2, 1)) + normalizers = None + if normalized: + normalizers = torch.logsumexp(logits, dim=3) + normalizers = normalizers.permute((0, 2, 1)) px = torch.gather( logits, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1) @@ -351,106 +367,14 @@ def get_rnnt_logprobs_joint( dim=2, ) # now: [B][S][T+1], index [:,:,T] has -inf.. - px[:, :, :T] -= normalizers[:, :S, :] - - py = ( - logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone() - ) # [B][S+1][T] - py -= normalizers - px = px.contiguous() - py = py.contiguous() - - if not modified: - px = fix_for_boundary(px, boundary) - - return (px, py) - - -def get_rnnt_logprobs_joint_for_numerator( - logits: Tensor, - symbols: Tensor, - termination_symbol: int, - boundary: Optional[Tensor] = None, - normalized: int = True, - modified: bool = False, -) -> Tuple[Tensor, Tensor]: - """Reduces RNN-T problem to a compact, standard form that can then be given - (with boundaries) to mutual_information_recursion(). - This function is called from rnnt_loss(). 
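The new `normalized` flag above only controls whether the logits are locally normalized before px/py are extracted; subtracting the logsumexp normalizers is the same as applying log_softmax over the vocabulary, as this small sketch (not code from this patch) shows:

import torch

logits = torch.randn(2, 4, 3, 5)                            # (B, T, S + 1, C)
normalizers = torch.logsumexp(logits, dim=3, keepdim=True)
assert torch.allclose(
    logits - normalizers, torch.log_softmax(logits, dim=3), atol=1e-6
)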
- - Args: - logits: - The output of joiner network, with shape (B, T, S + 1, C), - i.e. batch, time_seq_len, symbol_seq_len+1, num_classes - symbols: - A LongTensor of shape [B][S], containing the symbols at each position - of the sequence. - termination_symbol: - The identity of the termination symbol, must be in {0..C-1} - boundary: - a optional LongTensor of shape [B, 4] with elements interpreted as - [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as - [0, 0, S, T] - if boundary is not supplied. - Most likely you will want begin_symbol and begin_frame to be zero. - modified: if True, each time a real symbol is consumed a frame will - also be consumed, so at most 1 symbol can appear per frame. - Returns: - (px, py) (the names are quite arbitrary):: - - px: logprobs, of shape [B][S][T+1] - py: logprobs, of shape [B][S+1][T] - - in the recursion:: - - p[b,0,0] = 0.0 - if !modified: - p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], - p[b,s,t-1] + py[b,s,t-1]) - if modified: - p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1], - p[b,s,t-1] + py[b,s,t-1]) - .. where p[b][s][t] is the "joint score" of the pair of subsequences of - length s and t respectively. px[b][s][t] represents the probability of - extending the subsequences of length (s,t) by one in the s direction, - given the particular symbol, and py[b][s][t] represents the probability - of extending the subsequences of length (s,t) by one in the t direction, - i.e. of emitting the termination/next-frame symbol. - - if !modified, px[:,:,T] equals -infinity, meaning on the - "one-past-the-last" frame we cannot emit any symbols. - This is simply a way of incorporating - the probability of the termination symbol on the last frame. - """ - assert logits.ndim == 4 - (B, T, S1, C) = logits.shape - S = S1 - 1 - assert symbols.shape == (B, S) - - px = torch.gather( - logits, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1) - ).squeeze(-1) - px = px.permute((0, 2, 1)) - - if not modified: - px = torch.cat( - ( - px, - torch.full( - (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype - ), - ), - dim=2, - ) # now: [B][S][T+1], index [:,:,T] has -inf.. + if normalized: + px[:, :, :T] -= normalizers[:, :S, :] py = ( logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone() ) # [B][S+1][T] if normalized: - normalizers = torch.logsumexp(logits, dim=3) - normalizers = normalizers.permute((0, 2, 1)) - px[:, :, :T] -= normalizers[:, :S, :] py -= normalizers px = px.contiguous() @@ -462,14 +386,14 @@ def get_rnnt_logprobs_joint_for_numerator( return (px, py) -def rnnt_loss_for_numerator( +def rnnt_loss( logits: Tensor, symbols: Tensor, - external_lm: Tensor, termination_symbol: int, boundary: Optional[Tensor] = None, - modified: bool = False, + external_lm: Optional[Tensor] = None, normalized: bool = True, + modified: bool = False, reduction: Optional[str] = "mean", ) -> Tensor: """A normal RNN-T loss, which uses a 'joiner' network output as input, @@ -482,8 +406,6 @@ def rnnt_loss_for_numerator( symbols: The symbol sequences, a LongTensor of shape [B][S], and elements in {0..C-1}. - external_lm: - External language model network, with shape (B, S + 1, C). termination_symbol: the termination symbol, with 0 <= termination_symbol < C boundary: @@ -491,75 +413,10 @@ def rnnt_loss_for_numerator( [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. 
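Analogously to px, the py tensor above is just the termination-symbol slice of the joiner output, reordered to [B][S + 1][T]; a toy check (made-up shapes, not code from this patch):

import torch

B, T, S1, C = 2, 4, 3, 5
termination_symbol = 0
logits = torch.randn(B, T, S1, C)

py = logits[:, :, :, termination_symbol].permute(0, 2, 1)   # (B, S + 1, T)
assert py[1, 2, 3] == logits[1, 3, 2, termination_symbol]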
- modified: if True, each time a real symbol is consumed a frame will - also be consumed, so at most 1 symbol can appear per frame. + external_lm: + External language model network, with shape (B, S + 1, C). normalized: True to do log_softmax normalization, otherwise not. - reduction: - Specifies the reduction to apply to the output: `none`, `mean` or `sum`. - `none`: no reduction will be applied. - `mean`: apply `torch.mean` over the batches. - `sum`: the output will be summed. - Default: `mean` - - Returns: - If recursion is `none`, returns a tensor of shape (B,), containing the - total RNN-T loss values for each element of the batch, otherwise a scalar - with the reduction applied. - """ - px, py = get_rnnt_logprobs_joint_for_numerator( - logits=logits, - symbols=symbols, - termination_symbol=termination_symbol, - boundary=boundary, - normalized=normalized, - modified=modified, - ) - - B, S, T1 = px.shape - px_external_lm = torch.gather( - external_lm[:, :S], dim=2, index=symbols.unsqueeze(-1) - ) # [B][S][1] - px += px_external_lm - - negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) - if reduction == "none": - return -negated_loss - elif reduction == "mean": - return -torch.mean(negated_loss) - elif reduction == "sum": - return -torch.sum(negated_loss) - else: - assert ( - False - ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" - - -def rnnt_loss( - logits: Tensor, - symbols: Tensor, - termination_symbol: int, - boundary: Optional[Tensor] = None, - modified: bool = False, - reduction: Optional[str] = "mean", -) -> Tensor: - """A normal RNN-T loss, which uses a 'joiner' network output as input, - i.e. a 4 dimensions tensor. - - Args: - logits: - The output of joiner network, with shape (B, T, S + 1, C), - i.e. batch, time_seq_len, symbol_seq_len+1, num_classes - symbols: - The symbol sequences, a LongTensor of shape [B][S], and elements - in {0..C-1}. - termination_symbol: - the termination symbol, with 0 <= termination_symbol < C - boundary: - a optional LongTensor of shape [B, 4] with elements interpreted as - [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as - [0, 0, S, T] if boundary is not supplied. - Most likely you will want begin_symbol and begin_frame to be zero. modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. reduction: @@ -579,8 +436,17 @@ def rnnt_loss( symbols=symbols, termination_symbol=termination_symbol, boundary=boundary, + normalized=normalized, modified=modified, ) + + if external_lm is not None: + B, S, T1 = px.shape + px_external_lm = torch.gather( + external_lm[:, :S], dim=2, index=symbols.unsqueeze(-1) + ) # [B][S][1] + px += px_external_lm + negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) if reduction == "none": return -negated_loss @@ -589,9 +455,9 @@ def rnnt_loss( elif reduction == "sum": return -torch.sum(negated_loss) else: - assert ( - False - ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" + raise ValueError( + f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" + ) def _adjust_pruning_lower_bound( @@ -837,6 +703,7 @@ def get_rnnt_logprobs_pruned( ranges: Tensor, termination_symbol: int, boundary: Tensor, + normalized: bool = True, modified: bool = False, ) -> Tuple[Tensor, Tensor]: """Construct px, py for mutual_information_recursion with pruned output. 
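The reduction handling shared by these loss functions (now raising ValueError instead of asserting) follows the pattern sketched below, where `negated_loss` stands in for the (B,)-shaped output of mutual_information_recursion; this is only an illustrative extraction, not code from the patch:

import torch

def apply_reduction(negated_loss: torch.Tensor, reduction: str) -> torch.Tensor:
    if reduction == "none":
        return -negated_loss
    elif reduction == "mean":
        return -torch.mean(negated_loss)
    elif reduction == "sum":
        return -torch.sum(negated_loss)
    else:
        raise ValueError(
            f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
        )

print(apply_reduction(torch.tensor([1.0, 2.0]), "mean"))    # tensor(-1.5000)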
@@ -863,6 +730,8 @@ def get_rnnt_logprobs_pruned( [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + normalized: + True to do log_softmax normalization, otherwise not. modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. Returns: @@ -877,7 +746,9 @@ def get_rnnt_logprobs_pruned( assert ranges.shape == (B, T, s_range) (B, S) = symbols.shape - normalizers = torch.logsumexp(logits, dim=3) + normalizers = None + if normalized: + normalizers = torch.logsumexp(logits, dim=3) symbols_with_terminal = torch.cat( ( @@ -902,7 +773,9 @@ def get_rnnt_logprobs_pruned( px = torch.gather( logits, dim=3, index=pruned_symbols.reshape(B, T, s_range, 1) ).squeeze(-1) - px = px - normalizers + + if normalized: + px = px - normalizers # (B, T, S) with index larger than s_range in dim 2 fill with -inf px = torch.cat( @@ -935,7 +808,9 @@ def get_rnnt_logprobs_pruned( ) # now: [B][S][T+1], index [:,:,T] has -inf.. py = logits[:, :, :, termination_symbol].clone() # (B, T, s_range) - py = py - normalizers + + if normalized: + py = py - normalizers # (B, T, S + 1) with index larger than s_range in dim 2 filled with -inf py = torch.cat( @@ -971,7 +846,9 @@ def rnnt_loss_pruned( ranges: Tensor, termination_symbol: int, boundary: Tensor = None, + external_lm: Optional[Tensor] = None, modified: bool = False, + normalized: bool = True, reduction: Optional[str] = "mean", ) -> Tensor: """A RNN-T loss with pruning, which uses a pruned 'joiner' network output @@ -1000,8 +877,12 @@ def rnnt_loss_pruned( [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + external_lm: + External language model network, with shape (B, S + 1, C). modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. + normalized: + True to do log_softmax normalization, otherwise not. reduction: Specifies the reduction to apply to the output: `none`, `mean` or `sum`. `none`: no reduction will be applied. 
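For intuition about the pruned gather above: `ranges` gives, for every frame, a window of s_range symbol positions, and px is computed only for the labels inside that window. A toy sketch with hand-written ranges (in practice they come from get_rnnt_prune_ranges; the gather pattern here is illustrative, not the library code):

import torch

B, T, S, s_range = 1, 4, 5, 2
symbols = torch.arange(10, 10 + S).reshape(B, S)            # toy label ids
ranges = torch.tensor([[[0, 1], [1, 2], [2, 3], [3, 4]]])   # (B, T, s_range)

pruned_symbols = torch.gather(
    symbols.unsqueeze(1).expand(B, T, S), dim=2, index=ranges
)
print(pruned_symbols)  # tensor([[[10, 11], [11, 12], [12, 13], [13, 14]]])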
@@ -1019,8 +900,17 @@ def rnnt_loss_pruned( ranges=ranges, termination_symbol=termination_symbol, boundary=boundary, + normalized=normalized, modified=modified, ) + + if external_lm is not None: + B, S, T1 = px.shape + px_external_lm = torch.gather( + external_lm[:, :S], dim=2, index=symbols.unsqueeze(-1) + ) # [B][S][1] + px += px_external_lm + negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) if reduction == "none": return -negated_loss @@ -1029,9 +919,9 @@ def rnnt_loss_pruned( elif reduction == "sum": return -torch.sum(negated_loss) else: - assert ( - False - ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" + raise ValueError ( + f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" + ) def get_rnnt_logprobs_smoothed( @@ -1339,7 +1229,7 @@ def rnnt_loss_smoothed( elif reduction == "sum": loss = -torch.sum(negated_loss) else: - assert ( - False - ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" + raise ValueError ( + f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" + ) return (loss, scores_and_grads[1]) if return_grad else loss From 2a14970e94dc892336d13ad29df8d81bfc1569dc Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 31 Aug 2022 14:32:34 +0800 Subject: [PATCH 6/8] fix importance sampling scores; return arc_map --- k2/csrc/fsa_algo.cu | 7 ++++++- k2/python/k2/fsa_algo.py | 12 ++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu index 651fcad61..e64371e3c 100644 --- a/k2/csrc/fsa_algo.cu +++ b/k2/csrc/fsa_algo.cu @@ -2239,7 +2239,12 @@ FsaVec GenerateDenominatorLattice(Ragged &sampled_paths, repeat_num = us_row_splits1_data[us_idx0 + 1] - us_row_splits1_data[us_idx0]; - arc.score = -logf(1 - powf(1 - sampling_prob, repeat_num)); + float score = -logf(1 - powf(1 - sampling_prob, repeat_num)); + if (score - score != 0) { + arc.score = 0.0; + } else { + arc.score = score; + } K2_DCHECK_LT(frame_ids_data[states_idx012], boundary_data[idx0]); diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index 448e744b9..a9fbe4862 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -24,6 +24,7 @@ import torch import _k2 import k2 +import logging from . import fsa_properties from .fsa import Fsa @@ -1392,7 +1393,8 @@ def generate_denominator_lattice( boundary: torch.Tensor, vocab_size: int, context_size: int, -) -> Fsa: + return_arc_map: bool = False, +) -> Union[Fsa, Tuple[Fsa, torch.Tensor]]: """Generate denominator lattice from sampled linear paths for RNN-T+MMI training. @@ -1422,6 +1424,8 @@ def generate_denominator_lattice( The vocabulary size. context_size: The number of left symbols. + return_arc_map: + Whether to return arc_map. 
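A Python analogue (not the CUDA code) of the guard added in GenerateDenominatorLattice above: the arc score -log(1 - (1 - p)^n) becomes +inf when the sampling probability p underflows to 0, and `score - score != 0` is a branch-free test that catches both inf and NaN:

import torch

p = torch.tensor([0.5, 0.25, 0.0])                  # sampling probabilities
n = torch.tensor([2.0, 1.0, 3.0])                   # repeat counts
score = -torch.log(1.0 - (1.0 - p) ** n)            # last entry overflows to +inf
bad = (score - score) != 0                          # True for inf and NaN
score = torch.where(bad, torch.zeros_like(score), score)

assert torch.isfinite(score).all() and score[2] == 0.0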
""" ragged_arc, arc_map = _k2.generate_denominator_lattice( sampled_paths=k2.RaggedTensor(sampled_paths), @@ -1436,6 +1440,10 @@ def generate_denominator_lattice( a_value = getattr(lattice, "scores") # Enable autograd for path_scores b_value = index_select(path_scores.flatten(), arc_map) + assert torch.all(a_value >= 0), a_value value = b_value + a_value setattr(lattice, "scores", value) - return lattice + if return_arc_map: + return lattice, arc_map + else: + return lattice From a9fea9df3f0aa913975924b0fa2c4da970329c9a Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 12 Jan 2023 16:23:44 +0800 Subject: [PATCH 7/8] return grad in rnnt_loss_pruned --- k2/python/k2/rnnt_loss.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index d00eb9e16..4e18cebaf 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -850,7 +850,8 @@ def rnnt_loss_pruned( modified: bool = False, normalized: bool = True, reduction: Optional[str] = "mean", -) -> Tensor: + return_grad: bool = False, +) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]: """A RNN-T loss with pruning, which uses a pruned 'joiner' network output as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C), s_range means the symbols number kept for each frame. @@ -889,10 +890,20 @@ def rnnt_loss_pruned( `mean`: apply `torch.mean` over the batches. `sum`: the output will be summed. Default: `mean` + return_grad: + Whether to return grads of px and py, this grad standing for the + occupation probability is the output of the backward with a + `fake gradient`, the `fake gradient` is the same as the gradient you'd + get if you did `torch.autograd.grad((-loss.sum()), [px, py])`, note, the + loss here is the loss with reduction "none". + This is useful to implement the pruned version of rnnt loss. Returns: - If recursion is `none`, returns a tensor of shape (B,), containing the - total RNN-T loss values for each element of the batch, otherwise a scalar - with the reduction applied. + If return_grad is False, returns a tensor of shape (B,), containing the + total RNN-T loss values for each element of the batch if reduction equals + to "none", otherwise a scalar with the reduction applied. + If return_grad is True, the grads of px and py, which is the output of + backward with a `fake gradient`(see above), will be returned too. And the + returned value will be a tuple like (loss, (px_grad, py_grad)). 
""" px, py = get_rnnt_logprobs_pruned( logits=logits, @@ -911,17 +922,21 @@ def rnnt_loss_pruned( ) # [B][S][1] px += px_external_lm - negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) + scores_and_grads = mutual_information_recursion( + px=px, py=py, boundary=boundary, return_grad=return_grad + ) + negated_loss = scores_and_grads[0] if return_grad else scores_and_grads if reduction == "none": - return -negated_loss + loss = -negated_loss elif reduction == "mean": - return -torch.mean(negated_loss) + loss = -torch.mean(negated_loss) elif reduction == "sum": - return -torch.sum(negated_loss) + loss = -torch.sum(negated_loss) else: raise ValueError ( f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" ) + return (loss, scores_and_grads[1]) if return_grad else loss def get_rnnt_logprobs_smoothed( From 2be0926abfb6eb9c1b8f9968ddc9791040db0211 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 13 Jan 2023 10:28:24 +0800 Subject: [PATCH 8/8] Remove redundant code --- k2/python/k2/rnnt_loss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index c3cfa4dd1..de9cfcac3 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -1156,7 +1156,6 @@ def rnnt_loss_pruned( termination_symbol: int, boundary: Tensor = None, normalized: bool = True, - reduction: Optional[str] = "mean", return_grad: bool = False, rnnt_type: str = "regular", delay_penalty: float = 0.0,