diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu index b644e2314..f210f5eb4 100644 --- a/k2/csrc/fsa_algo.cu +++ b/k2/csrc/fsa_algo.cu @@ -2056,13 +2056,13 @@ FsaOrVec ReplaceFsa(FsaVec &src, FsaOrVec &index, int32_t symbol_range_begin, index_arc_idx2 = idx4; // corresponds to foo=0, so idx3 will be 0; // the idx4 enumerates the arcs leaving it.. } else { - // this is one of the extra `foo` indexes, it's conrespoding index + // this is one of the extra `foo` indexes, it's corresponding index // into `index` is `foo` index minus 1 index_arc_idx2 = idx2 - 1; } int32_t index_arc_idx01x = index_row_splits2_data[idx01]; - // index of the arc in source FSA, FSA that we're replaceing.. + // index of the arc in source FSA, FSA that we're replacing.. int32_t index_arc_idx012 = index_arc_idx01x + index_arc_idx2; Arc index_arc = index_arcs_data[index_arc_idx012]; @@ -2105,8 +2105,8 @@ FsaOrVec ReplaceFsa(FsaVec &src, FsaOrVec &index, int32_t symbol_range_begin, } else { // this arc would point to the initial state of the fsa in src, // the state id bias to current state(the src-state) is the count - // of all the ostates coresponding to the original state util now, - // the idx4 enumerates foo index + // of all the ostates corresponding to the original state until + // now, the idx4 enumerates foo index int32_t idx012_t = idx01x + 0, idx2_t = idx4, idx012x_t = tos_row_splits3_data[idx012_t], @@ -2121,7 +2121,7 @@ FsaOrVec ReplaceFsa(FsaVec &src, FsaOrVec &index, int32_t symbol_range_begin, arc_index_map_idx = index_arc_idx012; } else { // handle the arcs belongs to src // the arc point to the final state of the fsa in src would point to - // the dest state of the arc we're replaceing + // the dest state of the arc we're replacing if (src_arc.label == -1) { oarc.dest_state = orig_dest_state_idx0123 - idx0xxx; } else { @@ -2180,4 +2180,349 @@ FsaOrVec RemoveEpsilonSelfLoops(FsaOrVec &src, return ans; } +FsaVec GenerateDenominatorLattice(Ragged &sampled_paths, + Ragged &frame_ids, + Ragged &left_symbols, + Ragged &sampling_probs, + Array1 &boundary, + int32_t vocab_size, + int32_t context_size, + Array1 *arc_map) { + NVTX_RANGE(K2_FUNC); + K2_CHECK(arc_map); + K2_CHECK_EQ(sampled_paths.NumAxes(), 3); + K2_CHECK_EQ(frame_ids.NumAxes(), 3); + K2_CHECK_EQ(left_symbols.NumAxes(), 4); + K2_CHECK_EQ(sampling_probs.NumAxes(), 3); + + K2_DCHECK_EQ(sampled_paths.NumElements(), frame_ids.NumElements()); + K2_DCHECK_EQ(sampled_paths.NumElements(), + left_symbols.NumElements() * context_size); + K2_DCHECK_EQ(sampled_paths.NumElements(), sampling_probs.NumElements()); + K2_DCHECK_EQ(sampled_paths.TotSize(0), boundary.Dim()); + for (int32_t i = 0; i < 3; ++i) { + K2_DCHECK_EQ(sampled_paths.TotSize(i), frame_ids.TotSize(i)); + K2_DCHECK_EQ(sampled_paths.TotSize(i), left_symbols.TotSize(i)); + K2_DCHECK_EQ(sampled_paths.TotSize(i), sampling_probs.TotSize(i)); + } + + ContextPtr c = GetContext( + sampled_paths, frame_ids, left_symbols, sampling_probs); + + // The states indicating we are in on each position of each path, which has + // the same shape as `sampled_paths`, because each symbol in the paths is + // sampled from a specific frame with corresponding left contexts. + // Each state represents a tuple like (t, left_symbols1, left_symbols2...), + // the number of left_symbols equals to the `context_size`. A state is + // calculated from t * V ^ c + \sum_{i=1}^{c} s_i * V ^ (c - i), + // V is vocab_size, c is context_size, s_i is the ith left_symbols. 
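The encoding just described can be written out in a few lines of Python. The sketch below is a standalone illustration (the helper name `encode_state` is not part of k2) and reproduces the worked example given in the next comment lines.

```python
def encode_state(t, left_symbols, vocab_size, context_size):
    """Encode (t, s_1, ..., s_c) as t * V^c + sum_i s_i * V^(c - i)."""
    assert len(left_symbols) == context_size
    state = t * vocab_size ** context_size
    for i, s in enumerate(left_symbols, start=1):
        state += s * vocab_size ** (context_size - i)
    return state

# The worked example from the comment below: context_size = 2, vocab_size = 10,
# tuple (2, 4, 5) -> 2 * 10^2 + 4 * 10 + 5 = 245.
assert encode_state(2, [4, 5], 10, 2) == 245
```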
+ // For example, if context_size = 2, vocab_size = 10, so, one possible tuple + // would be (2, 4, 5), then the corresponding state is + // 2 * 10 ^ 2 + 4 * 10 + 5 = 245. + Ragged states(sampled_paths.shape); + int32_t num_states = states.NumElements(); + + const int32_t *frame_ids_data = frame_ids.values.Data(), + *left_symbols_row_splits3_data + = left_symbols.RowSplits(3).Data(), + *left_symbols_data = left_symbols.values.Data(); + int64_t *states_data = states.values.Data(); + + // This kernel calculates t * V ^ c for each state. + K2_EVAL( + c, num_states, lambda_init_states_with_t, (int32_t idx012) -> void { + states_data[idx012] + = frame_ids_data[idx012] * Pow(vocab_size, context_size); + }); + + // The following kernels calculate \sum_{i=1}^{c} s_i * V ^ (c - i) + for (int32_t i = 0; i < context_size; ++i) { + K2_EVAL( + c, num_states, lambda_generate_states, (int32_t idx012) -> void { + int32_t left_symbols_idx012x = left_symbols_row_splits3_data[idx012], + left_symbols_idx0123 = left_symbols_idx012x + i, + exp = context_size - i - 1; + states_data[idx012] + += left_symbols_data[left_symbols_idx0123] * Pow(vocab_size, exp); + }); + } + + // Sort those states for each sequence, so as to merge the same states. + // sorted_states has two axes: [seq][state] + auto sorted_states = Ragged( + RemoveAxis(states.shape, 1 /*axis*/), states.values.Clone()); + Array1 sorted_states_new2old(c, num_states); + SortSublists(&sorted_states, &sorted_states_new2old); + + // We need old2new map to find the original consecutive state. + Array1 sorted_states_old2new(c, num_states); + const int32_t *sorted_states_new2old_data = sorted_states_new2old.Data(); + int32_t *sorted_states_old2new_data = sorted_states_old2new.Data(); + K2_EVAL( + c, num_states, lambda_get_old2new, (int32_t i) -> void { + sorted_states_old2new_data[sorted_states_new2old_data[i]] = i; + }); + + // Search "tails concept" in k2/csrc/utils.h for the details of tail array. + // By applying ExclusiveSum on the tail_array, we can get a row_id mapping the + // sorted states to unique_states (i.e. the merged states). + Array1 tail_array(c, num_states); + const int32_t *sorted_states_row_ids1_data = sorted_states.RowIds(1).Data(); + const int64_t *sorted_states_data = sorted_states.values.Data(); + int32_t *tail_array_data = tail_array.Data(); + + K2_EVAL( + c, num_states, lambda_get_tail_array, (int32_t idx01) -> void { + if (idx01 == num_states - 1) tail_array_data[idx01] = 1; + int32_t idx0 = sorted_states_row_ids1_data[idx01], + next_idx0 = sorted_states_row_ids1_data[idx01 + 1]; + if (idx0 == next_idx0 && + sorted_states_data[idx01] == sorted_states_data[idx01 + 1]) + tail_array_data[idx01] = 0; + else + tail_array_data[idx01] = 1; + }); + + Array1 unique_states_row_ids(c, num_states); + ExclusiveSum(tail_array, &unique_states_row_ids); + + // unique_states_shape's shape [merged state][sorted state] + // unique_states_shape.row_splits.Dim() - 1 equals to the number of merged + // states. + RaggedShape unique_states_shape = RaggedShape2( + nullptr, &unique_states_row_ids, unique_states_row_ids.Dim()); + + // We are figuring out the ragged shape of the lattice. + // First, figure out the number of states (i.e. the merged states) for each + // sequence. + // Second, figure out the number of arcs for each merged state. + int32_t num_seqs = states.TotSize(0); + + // Plus 1 here because we will apply ExclusiveSum on this array. 
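Before the shape construction continues below, the sort / tail-array / exclusive-sum trick used above to merge equal states can be illustrated with a standalone Python sketch (plain lists standing in for k2's SortSublists and ExclusiveSum; a single sequence for brevity).

```python
sorted_states = [3, 3, 7, 7, 7, 9]   # states of one sequence, already sorted
# "Tails concept": mark the last element of every run of equal values.
tails = [
    1 if i + 1 == len(sorted_states) or sorted_states[i] != sorted_states[i + 1]
    else 0
    for i in range(len(sorted_states))
]                                     # -> [0, 1, 0, 0, 1, 1]
# An exclusive sum of the tail marks maps each sorted element to its merged state.
row_ids, acc = [], 0
for t in tails:
    row_ids.append(acc)
    acc += t
assert row_ids == [0, 0, 1, 1, 1, 2]  # three unique states: 3, 7 and 9
```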
+ Array1 num_states_for_seqs(c, states.TotSize(0) + 1); + + // "ss" is short for "sorted states" + // "us" is short for "unique states". + const int32_t *ss_row_splits1_data = sorted_states.RowSplits(1).Data(), + *us_row_ids1_data = unique_states_shape.RowIds(1).Data(); + int32_t *num_states_for_seqs_data = num_states_for_seqs.Data(); + + K2_EVAL( + c, num_seqs, lambda_get_num_states, (int32_t idx0) -> void { + int32_t ss_idx0x = ss_row_splits1_data[idx0], + ss_idx0x_next = ss_row_splits1_data[idx0 + 1], + us_idx0 = us_row_ids1_data[ss_idx0x], + us_idx0_next_minus_1 = us_row_ids1_data[ss_idx0x_next - 1], + num_unique_states = us_idx0_next_minus_1 - us_idx0 + 1; + // Plus 3 here, because we need a super dest_state for the states sampled + // on the last frame (this dest_state will point to the final state), + // a fake super dest_state for the last states of linear paths that + // are not sampled on the last frames (this fake dest_state will be + // removed by connect operation), and a final state needed by k2. + num_states_for_seqs_data[idx0] = num_unique_states + 3; + }); + + ExclusiveSum(num_states_for_seqs, &num_states_for_seqs); + RaggedShape seqs_to_states_shape = RaggedShape2( + &num_states_for_seqs, nullptr, -1); + int32_t num_merged_states = seqs_to_states_shape.NumElements(); + + K2_CHECK_EQ(unique_states_shape.RowSplits(1).Dim() - 1 + num_seqs * 3, + num_merged_states); + + // Plus 1 here because we will apply ExclusiveSum on this array. + Array1 num_arcs_for_states( + c, seqs_to_states_shape.NumElements() + 1); + + // "sts" is short for "seqs to states" + // "us" is short for "unique states". + const int32_t *us_row_splits1_data = unique_states_shape.RowSplits(1).Data(), + *sts_row_ids1_data = seqs_to_states_shape.RowIds(1).Data(), + *sts_row_splits1_data + = seqs_to_states_shape.RowSplits(1).Data(); + int32_t *num_arcs_for_states_data = num_arcs_for_states.Data(); + + K2_EVAL( + c, num_merged_states, lambda_get_num_arcs, (int32_t idx01) -> void { + int32_t idx0 = sts_row_ids1_data[idx01], + idx0x_next = sts_row_splits1_data[idx0 + 1], + num_arcs = 0; + // The final arc for each sequence. + if (idx01 == idx0x_next - 2) num_arcs = 1; + if (idx01 < idx0x_next - 3) { + // Minus idx0 * 3, because we add extra three states for each sequence. + int32_t us_idx0 = idx01 - idx0 * 3, + us_idx0x = us_row_splits1_data[us_idx0], + us_idx0x_next = us_row_splits1_data[us_idx0 + 1]; + num_arcs = us_idx0x_next - us_idx0x; + } + // idx01 == idx0x_next - 3 (i.e. the fake super dest_state) and + // idx01 == idx0x_next -1 (i.e. the final state) don't have arcs. + num_arcs_for_states_data[idx01] = num_arcs; + }); + + ExclusiveSum(num_arcs_for_states, &num_arcs_for_states); + RaggedShape states_to_arcs_shape = RaggedShape2( + &num_arcs_for_states, nullptr, -1); + + RaggedShape arcs_shape = ComposeRaggedShapes( + seqs_to_states_shape, states_to_arcs_shape); + int32_t num_arcs = arcs_shape.NumElements(); + + // Each state (before merging) has a leaving arc, we add a final arc + // to each sequence, so, the total number of arcs equals to + // num_states + num_seqs + K2_CHECK_EQ(num_arcs, num_seqs + num_states); + + // Populate arcs. 
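As an aside before the arc-population code, the pattern used twice above — a counts array allocated one element longer than needed, then turned into row_splits by an in-place exclusive sum — looks like this in standalone Python (RaggedShape2 in k2 then consumes such row_splits):

```python
# Two sequences with 5 and 4 merged states; each gets 3 extra states as above.
counts = [5 + 3, 4 + 3, 0]   # trailing slot mirrors the `TotSize(0) + 1` sizing
row_splits, acc = [], 0
for cnt in counts:           # exclusive sum written out explicitly
    row_splits.append(acc)
    acc += cnt
# Sequence 0 owns states [0, 8), sequence 1 owns [8, 15); the last entry is the
# total number of states across all sequences.
assert row_splits == [0, 8, 15]
```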
+ // "ss" is short for "sorted states" + const int32_t *sampled_paths_data = sampled_paths.values.Data(), + *arcs_shape_row_ids1_data = arcs_shape.RowIds(1).Data(), + *arcs_shape_row_splits1_data = arcs_shape.RowSplits(1).Data(), + *arcs_shape_row_ids2_data = arcs_shape.RowIds(2).Data(), + *states_row_ids2_data = states.RowIds(2).Data(), + *boundary_data = boundary.Data(), + *ss_row_ids1_data = sorted_states.RowIds(1).Data(); + const float *sampling_probs_data = sampling_probs.values.Data(); + Array1 arcs(c, num_arcs); + Arc *arcs_data = arcs.Data(); + + // The arc_map mapping from lattice arcs to original state indexes. + Array1 raw_arc_map(c, num_arcs); + int32_t *raw_arc_map_data = raw_arc_map.Data(); + + K2_EVAL( + c, num_arcs, lambda_set_arcs, (int32_t idx012) -> void { + Arc arc; + int32_t arc_map_value = -1; + int32_t idx01 = arcs_shape_row_ids2_data[idx012], + idx0 = arcs_shape_row_ids1_data[idx01], + idx0x = arcs_shape_row_splits1_data[idx0], + idx1 = idx01 - idx0x; + arc.src_state = idx1; + + // Final arc of the last sequence. + if (idx012 == num_arcs - 1) { + arc.dest_state = idx1 + 1; + arc.label = -1; + arc.score = 0.0; + } else { + int32_t idx01_next = arcs_shape_row_ids2_data[idx012 + 1], + idx0_next = arcs_shape_row_ids1_data[idx01_next]; + // Final arc for each sequence, except the last sequence. + if (idx0 != idx0_next) { + arc.dest_state = idx1 + 1; + arc.label = -1; + arc.score = 0.0; + } else { + // ss_idx01 is the global index of sorted states, minus idx0 here + // because we added an extra final arc for each sequence. + int32_t ss_idx01 = idx012 - idx0, + states_idx012 = sorted_states_new2old_data[ss_idx01]; + + arc_map_value = states_idx012; + arc.label = sampled_paths_data[states_idx012]; + float sampling_prob = sampling_probs_data[states_idx012]; + + int32_t us_idx0 = us_row_ids1_data[ss_idx01], + repeat_num = us_row_splits1_data[us_idx0 + 1] - + us_row_splits1_data[us_idx0]; + + float score = -logf(1 - powf(1 - sampling_prob, repeat_num)); + if (score - score != 0) { + arc.score = 0.0; + } else { + arc.score = score; + } + + K2_DCHECK_LT(frame_ids_data[states_idx012], boundary_data[idx0]); + + int32_t idx0x_next = arcs_shape_row_splits1_data[idx0 + 1]; + + // Handle the final state of last sequence. + if (states_idx012 == num_states - 1) { + // If current state is on final frame, it will point to the added + // super dest_state. + if (frame_ids_data[states_idx012] == boundary_data[idx0] - 1) { + arc.dest_state = idx0x_next - idx0x - 2; + } else { + // point to the fake added dest_state. + arc.dest_state = idx0x_next - idx0x - 3; + } + } else { + // states_idx01 is path index + int32_t states_idx01 = states_row_ids2_data[states_idx012], + states_idx01_next = + states_row_ids2_data[states_idx012 + 1]; + if (states_idx01 != states_idx01_next) { + // If current state is on final frame, it will point to the added + // super dest_state. + if (frame_ids_data[states_idx012] == boundary_data[idx0] - 1) { + arc.dest_state = idx0x_next - idx0x - 2; + } else { + // point to the fake added dest_state. + arc.dest_state = idx0x_next - idx0x - 3; + } + } else { + // If current state is on final frame, it will point to the added + // super dest_state. + if (frame_ids_data[states_idx012] == boundary_data[idx0] - 1 && + frame_ids_data[states_idx012 + 1] != boundary_data[idx0] - 1) { + arc.dest_state = idx0x_next - idx0x - 2; + } else { + // states_idx012 + 1 is the index of original consecutive state. + // "ss" is short for "sorted states" + // "us" is short for "unique states". 
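One detail worth pinning down before the remaining destination-state logic: the score assigned above is -log(1 - (1 - p)^n), where p is the sampling probability and n is how many sorted states were merged into this one, with `score - score != 0` guarding against NaN/inf. A standalone Python sketch of the same computation:

```python
import math

def arc_score(sampling_prob, repeat_num):
    # 1 - (1 - p)^n is the probability of drawing the symbol at least once
    # over the `repeat_num` merged copies of this state.
    q = 1.0 - (1.0 - sampling_prob) ** repeat_num
    if q <= 0.0:      # mirrors the kernel's NaN/inf guard: p == 0 (or an
        return 0.0    # underflow) would otherwise give -log(0) = inf
    return -math.log(q)

print(round(arc_score(0.1, 1), 3))  # -log(0.1)        ~= 2.303
print(round(arc_score(0.1, 3), 3))  # -log(1 - 0.9^3)  ~= 1.306
print(arc_score(0.0, 2))            # degenerate case -> 0.0
```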
+ int32_t ss_idx01_next = + sorted_states_old2new_data[states_idx012 + 1], + us_idx0_next = us_row_ids1_data[ss_idx01_next]; + // Plus 3 * idx0, because we add 3 state for each sequence + arc.dest_state = us_idx0_next + 3 * idx0 - idx0x; + } + } + } + } + } + arcs_data[idx012] = arc; + raw_arc_map_data[idx012] = arc_map_value; + }); + + FsaVec fsas = Ragged(arcs_shape, arcs); + // arcsort so as to remove duplicate arcs. + Array1 arc_sort_new2old(c, num_arcs); + SortSublists(&fsas, &arc_sort_new2old); + + // remove duplicate arcs, use renumbering + Renumbering renumber_arcs(c, num_arcs); + char *keep_arcs_data = renumber_arcs.Keep().Data(); + K2_EVAL( + c, num_arcs, lambda_set_keep_arcs, (int32_t idx012) -> void { + char keep = 1; + if (idx012 < num_arcs - 1) { + int32_t idx01 = arcs_shape_row_ids2_data[idx012], + idx01_next = arcs_shape_row_ids2_data[idx012 + 1]; + // duplicate arcs, which are arcs with the same symbol going from the + // same src_state to the same dest_state. The symbol will automatically + // be the same if the src_state and dest_state are the same if + // context_size > 0. + if (idx01 == idx01_next && + arcs_data[idx012].src_state == arcs_data[idx012 + 1].src_state && + arcs_data[idx012].dest_state == arcs_data[idx012 + 1].dest_state) { + K2_DCHECK_EQ(arcs_data[idx012].label, arcs_data[idx012 + 1].label); + keep = 0; + } + } + keep_arcs_data[idx012] = keep; + }); + + Array1 renumber_arc_map; + FsaVec final_fsas = Index( + fsas, 2, renumber_arcs.New2Old(), &renumber_arc_map); + + if (arc_map != nullptr) { + *arc_map = raw_arc_map[arc_sort_new2old][renumber_arc_map]; + } + return final_fsas; +} + } // namespace k2 diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index cecf6940a..233456b36 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -842,7 +842,7 @@ FsaOrVec ReplaceFsa(FsaVec &src, FsaOrVec &index, int32_t symbol_range_begin, weight 'x.weight', then the reverse of 'src' accepts the reverse of string 'x' with weight 'x.weight.reverse'. - Implementation notss: + Implementation notes: The Fsa in k2 only has one start state 0, and the only final state with the largest state number whose in-coming arcs have "-1" as the label. So, 1) the start state of 'dest' will correspond to the final state of 'src'. @@ -864,13 +864,53 @@ FsaOrVec ReplaceFsa(FsaVec &src, FsaOrVec &index, int32_t symbol_range_begin, @param [out] dest Output Fsa or FsaVec. At exit, it will be equivalent to the reverse Fsa of 'src'. Caution: the reverse will ignore the "-1" label. - + @param [out,optional] arc_map For each arc in `dest`, gives the index of the corresponding arc in `src` that it corresponds to. */ void Reverse(FsaVec &src, FsaVec *dest, Array1 *arc_map = nullptr); +/* + * Generate denominator lattice from sampled linear paths for RNN-T+MMI + * training. + * + * Implementation notes: + * 1) Generate "states" for each sampled symbol from their left_symbols and + * the frame_ids they are sampled from. + * 2) Sort those "states" for each sequence and then merge the same "states". + * 3) Map all of the sampled symbols to the merged "states". + * 4) Remove duplicate arcs. + * + * @param [in] sampled_paths The sampled symbols, it has a regular shape of + * [seq][num_path][path_length]. All its elements MUST satisfy + * `0 <= value < vocab_size. + * @param [in] frame_ids It contains the frame indexes of at which frame we + * sampled the symbols, which has same shape of sampled_paths. 
+ * @param [in] left_symbols The left_symbols of the sampled symbols, it has a + * regular shape of [seq][num_path][path_length][context], the + * first three indexes are the same as sampled_paths. Each + * sublist along axis 3 has `context_size` elements. All its + * elements MUST satisfy `0 <= value < vocab_size`. + * @param [in] sampling_probs It contains the probabilities of sampling each + * symbol, which has the same shape as sampled_paths. + * @param [in] boundary It contains the number of frames for each sequence. + * @param [in] vocab_size The vocabulary size. + * @param [in] context_size The number of left symbols. + * @param [out] arc_map For each arc in the return Fsa, gives the orignal + * index (idx012) in sampled_paths that it corresponds to. + * + * @return Return the generated lattice. + */ +FsaVec GenerateDenominatorLattice(Ragged &sampled_paths, + Ragged &frame_ids, + Ragged &left_symbols, + Ragged &sampling_probs, + Array1 &boundary, + int32_t vocab_size, + int32_t context_size, + Array1 *arc_map); + } // namespace k2 #endif // K2_CSRC_FSA_ALGO_H_ diff --git a/k2/csrc/fsa_algo_test.cu b/k2/csrc/fsa_algo_test.cu index 903514318..f30ca4d97 100644 --- a/k2/csrc/fsa_algo_test.cu +++ b/k2/csrc/fsa_algo_test.cu @@ -1383,4 +1383,48 @@ TEST(FsaAlgo, TestLevenshteinGraph) { } } +TEST(FsaAlgo, TestGenerateDenominatorLattice) { + for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) { + Ragged sampled_paths(c, "[ [ [ 3 5 0 4 6 0 2 1 ] " + " [ 2 0 5 4 0 6 1 2 ] " + " [ 3 5 2 0 0 1 6 4 ] ] " + " [ [ 7 0 4 0 6 0 3 0 ] " + " [ 0 7 3 0 2 0 4 5 ] " + " [ 7 0 3 4 0 1 2 0 ] ] ]"); + Ragged frame_ids(c, "[ [ [ 0 0 0 1 1 1 2 2 ] " + " [ 0 0 1 1 1 2 2 2 ] " + " [ 0 0 0 0 1 2 2 2 ] ] " + " [ [ 0 0 1 1 2 2 3 3 ] " + " [ 0 1 1 1 2 2 3 1 ] " + " [ 0 0 1 1 1 2 2 2 ] ] ]"); + Ragged left_symbols(c, + "[ [ [ [ 0 0 ] [ 0 3 ] [ 3 5 ] [ 3 5 ] [ 5 4 ] [ 4 6 ] [ 4 6 ] [ 6 2 ] ] " + " [ [ 0 0 ] [ 0 2 ] [ 0 2 ] [ 2 5 ] [ 5 4 ] [ 5 4 ] [ 4 6 ] [ 6 1 ] ] " + " [ [ 0 0 ] [ 0 3 ] [ 3 5 ] [ 5 2 ] [ 5 2 ] [ 5 2 ] [ 2 1 ] [ 1 6 ] ] " + " ] " + " [ [ [ 0 0 ] [ 0 7 ] [ 0 7 ] [ 7 4 ] [ 7 4 ] [ 4 6 ] [ 4 6 ] [ 6 3 ] ] " + " [ [ 0 0 ] [ 0 0 ] [ 0 7 ] [ 7 3 ] [ 7 3 ] [ 3 2 ] [ 3 2 ] [ 0 0 ] ] " + " [ [ 0 0 ] [ 0 7 ] [ 0 7 ] [ 7 3 ] [ 3 4 ] [ 3 4 ] [ 4 1 ] [ 1 2 ] ] " + " ] ]"); + + Ragged sampling_probs(c, "[ [ [ 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 ] " + " [ 0.2 0.2 0.2 0.1 0.2 0.2 0.1 0.2 ] " + " [ 0.1 0.1 0.1 0.3 0.2 0.3 0.3 0.3 ] ] " + " [ [ 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 ] " + " [ 0.2 0.2 0.2 0.1 0.2 0.1 0.2 0.2 ] " + " [ 0.1 0.1 0.2 0.2 0.3 0.3 0.3 0.3 ] ] " + "]"); + Array1 boundary(c, "[ 3 4 ]"); + + Array1 arc_map; + FsaVec lattice = GenerateDenominatorLattice( + sampled_paths, frame_ids, left_symbols, sampling_probs, boundary, + 10 /*vocab_size*/, 2 /*context_size*/, &arc_map); + K2_LOG(INFO) << arc_map; + K2_LOG(INFO) << lattice; + K2_LOG(INFO) << FsaToString(lattice.Index(0, 0)); + K2_LOG(INFO) << FsaToString(lattice.Index(0, 1)); + } +} + } // namespace k2 diff --git a/k2/python/csrc/torch/fsa_algo.cu b/k2/python/csrc/torch/fsa_algo.cu index e8f4c01a9..78ff069e6 100644 --- a/k2/python/csrc/torch/fsa_algo.cu +++ b/k2/python/csrc/torch/fsa_algo.cu @@ -807,6 +807,31 @@ static void PybindReverse(py::module &m) { py::arg("src"), py::arg("need_arc_map") = true); } +static void PybindGenerateDenominatorLattice(py::module &m) { + m.def( + "generate_denominator_lattice", + [](RaggedAny &sampled_paths, RaggedAny &frame_ids, + RaggedAny &left_symbols, RaggedAny &sampling_probs, + torch::Tensor &boundary, 
int32_t vocab_size, int32_t context_size) + -> std::pair { + DeviceGuard guard(sampled_paths.any.Context()); + Array1 arc_map; + Array1 boundary_array = FromTorch(boundary); + FsaVec lattice = GenerateDenominatorLattice( + sampled_paths.any.Specialize(), + frame_ids.any.Specialize(), + left_symbols.any.Specialize(), + sampling_probs.any.Specialize(), + boundary_array, + vocab_size, context_size, &arc_map); + auto arc_map_tensor = ToTorch(arc_map); + return std::make_pair(lattice, arc_map_tensor); + }, + py::arg("sampled_paths"), py::arg("frame_ids"), py::arg("left_symbols"), + py::arg("sampling_probs"), py::arg("boundary"), py::arg("vocab_size"), + py::arg("context_size")); +} + } // namespace k2 void PybindFsaAlgo(py::module &m) { @@ -820,6 +845,7 @@ void PybindFsaAlgo(py::module &m) { k2::PybindDeterminize(m); k2::PybindExpandArcs(m); k2::PybindFixFinalLabels(m); + k2::PybindGenerateDenominatorLattice(m); k2::PybindIntersect(m); k2::PybindIntersectDense(m); k2::PybindIntersectDensePruned(m); diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index f4e04be10..f58e5ae3d 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -58,6 +58,7 @@ from .fsa_algo import ctc_topo from .fsa_algo import determinize from .fsa_algo import expand_ragged_attributes +from .fsa_algo import generate_denominator_lattice from .fsa_algo import intersect from .fsa_algo import intersect_device from .fsa_algo import invert diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index 35d9ca082..690874428 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -24,6 +24,7 @@ import torch import _k2 import k2 +import logging from . import fsa_properties from .fsa import Fsa @@ -1473,3 +1474,68 @@ def union(fsas: Fsa) -> Fsa: out_fsa = k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc, arc_map) return out_fsa + + +def generate_denominator_lattice( + sampled_paths: torch.Tensor, + frame_ids: torch.Tensor, + left_symbols: torch.Tensor, + sampling_probs: torch.Tensor, + path_scores: torch.Tensor, + boundary: torch.Tensor, + vocab_size: int, + context_size: int, + return_arc_map: bool = False, +) -> Union[Fsa, Tuple[Fsa, torch.Tensor]]: + """Generate denominator lattice from sampled linear paths for RNN-T+MMI + training. + + Args: + sampled_paths: + The sampled symbols, it has a shape of (seq, num_path, path_length). + All its elements MUST satisfy `0 <= value < vocab_size. + frame_ids: + It contains the frame indexes of at which frame we sampled the symbols, + which has same shape of sampled_paths. + left_symbols: + The left_symbols of the sampled symbols, it has a shape of + (seq, num_path, path_length, context_size), the first three indexes are + the same as sampled_paths. All its elements MUST satisfy + `0 <= value < vocab_size`. + sampling_probs: + It contains the probabilities of sampling each symbol, which has a + same shape as sampled_paths. Normally comes from the output of + "predictor" head. + path_scores: + It contains the scores of each sampled symbol, which has a same shape as + sampled_paths. It might contain the output of hybrid head and the extra + language model output. Note: Autograd is supported for this tensor. + boundary: + It contains the number of frames for each sequence. + vocab_size: + The vocabulary size. + context_size: + The number of left symbols. + return_arc_map: + Whether to return arc_map. 
+ """ + ragged_arc, arc_map = _k2.generate_denominator_lattice( + sampled_paths=k2.RaggedTensor(sampled_paths), + frame_ids=k2.RaggedTensor(frame_ids), + left_symbols=k2.RaggedTensor(left_symbols), + sampling_probs=k2.RaggedTensor(sampling_probs), + boundary=boundary, + vocab_size=vocab_size, + context_size=context_size, + ) + lattice = Fsa(ragged_arc) + a_value = getattr(lattice, "scores") + # Enable autograd for path_scores + b_value = index_select(path_scores.flatten(), arc_map) + assert torch.all(a_value >= 0), a_value + value = b_value + a_value + setattr(lattice, "scores", value) + if return_arc_map: + return lattice, arc_map + else: + return lattice diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index b9a130d68..de9cfcac3 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -328,6 +328,7 @@ def get_rnnt_logprobs_joint( termination_symbol: int, rnnt_type: str = "regular", boundary: Optional[Tensor] = None, + normalized: bool = True, ) -> Tuple[Tensor, Tensor]: """Reduces RNN-T problem to a compact, standard form that can then be given (with boundaries) to mutual_information_recursion(). @@ -348,6 +349,8 @@ def get_rnnt_logprobs_joint( [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + normalized: + True to do log_softmax normalization, otherwise not. rnnt_type: Specifies the type of rnnt paths: `regular`, `modified` or `constrained`. `regular`: The regular rnnt that taking you to the next frame only if @@ -398,8 +401,10 @@ def get_rnnt_logprobs_joint( assert T >= S, (T, S) assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type - normalizers = torch.logsumexp(logits, dim=3) - normalizers = normalizers.permute((0, 2, 1)) + normalizers = None + if normalized: + normalizers = torch.logsumexp(logits, dim=3) + normalizers = normalizers.permute((0, 2, 1)) px = torch.gather( logits, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1) @@ -417,12 +422,18 @@ def get_rnnt_logprobs_joint( dim=2, ) # now: [B][S][T+1], index [:,:,T] has -inf.. - px[:, :, :T] -= normalizers[:, :S, :] + if normalized: + px[:, :, :T] -= normalizers[:, :S, :] py = ( logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone() ) # [B][S+1][T] - py -= normalizers + + if normalized: + py -= normalizers + + px = px.contiguous() + py = py.contiguous() if rnnt_type == "regular": px = fix_for_boundary(px, boundary) @@ -437,6 +448,7 @@ def rnnt_loss( symbols: Tensor, termination_symbol: int, boundary: Optional[Tensor] = None, + normalized: bool = True, rnnt_type: str = "regular", delay_penalty: float = 0.0, reduction: Optional[str] = "mean", @@ -458,6 +470,8 @@ def rnnt_loss( [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + normalized: + True to do log_softmax normalization, otherwise not. rnnt_type: Specifies the type of rnnt paths: `regular`, `modified` or `constrained`. `regular`: The regular rnnt that taking you to the next frame only if @@ -492,6 +506,7 @@ def rnnt_loss( symbols=symbols, termination_symbol=termination_symbol, boundary=boundary, + normalized=normalized, rnnt_type=rnnt_type, ) @@ -524,7 +539,6 @@ def rnnt_loss( f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" ) - def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor: """Compute a monotonically increasing lower bound of the tensor `x` on the last dimension. 
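Returning briefly to the `normalized` flag introduced above: with `normalized=True` the gathered logits are shifted by a logsumexp over the vocabulary dimension, which is the same as gathering from log-softmaxed logits. A standalone torch sketch of that equivalence (arbitrary shapes, ignoring the permutes and boundary padding of the real code):

```python
import torch

B, T, S, C = 2, 5, 3, 7
logits = torch.randn(B, T, S, C)
index = torch.randint(0, C, (B, 1, S, 1)).expand(B, T, S, 1)

gathered = torch.gather(logits, dim=3, index=index).squeeze(-1)
normalizers = torch.logsumexp(logits, dim=3)
px_normalized = gathered - normalizers          # what normalized=True computes

px_ref = torch.gather(logits.log_softmax(dim=3), dim=3, index=index).squeeze(-1)
assert torch.allclose(px_normalized, px_ref, atol=1e-5)
# With normalized=False the subtraction is skipped and the logits are used
# as-is, so the caller decides what normalization (if any) has been applied.
```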
The basic idea is: we traverse the tensor in reverse order, @@ -561,7 +575,6 @@ def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor: x = torch.flip(x, dims=(-1,)) return x - def _adjust_pruning_lower_bound( s_begin: torch.Tensor, s_range: int ) -> torch.Tensor: @@ -961,6 +974,7 @@ def get_rnnt_logprobs_pruned( ranges: Tensor, termination_symbol: int, boundary: Tensor, + normalized: bool = True, rnnt_type: str = "regular", ) -> Tuple[Tensor, Tensor]: """Construct px, py for mutual_information_recursion with pruned output. @@ -987,6 +1001,8 @@ def get_rnnt_logprobs_pruned( [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + normalized: + True to do log_softmax normalization, otherwise not. rnnt_type: Specifies the type of rnnt paths: `regular`, `modified` or `constrained`. `regular`: The regular rnnt that taking you to the next frame only if @@ -1040,7 +1056,9 @@ def get_rnnt_logprobs_pruned( assert T >= S, (T, S) assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type - normalizers = torch.logsumexp(logits, dim=3) + normalizers = None + if normalized: + normalizers = torch.logsumexp(logits, dim=3) symbols_with_terminal = torch.cat( ( @@ -1065,7 +1083,9 @@ def get_rnnt_logprobs_pruned( px = torch.gather( logits, dim=3, index=pruned_symbols.reshape(B, T, s_range, 1) ).squeeze(-1) - px = px - normalizers + + if normalized: + px = px - normalizers # (B, T, S) with index larger than s_range in dim 2 fill with -inf px = torch.cat( @@ -1098,7 +1118,9 @@ def get_rnnt_logprobs_pruned( ) # now: [B][S][T+1], index [:,:,T] has -inf.. py = logits[:, :, :, termination_symbol].clone() # (B, T, s_range) - py = py - normalizers + + if normalized: + py = py - normalizers # (B, T, S + 1) with index larger than s_range in dim 2 filled with -inf py = torch.cat( @@ -1133,10 +1155,12 @@ def rnnt_loss_pruned( ranges: Tensor, termination_symbol: int, boundary: Tensor = None, + normalized: bool = True, + return_grad: bool = False, rnnt_type: str = "regular", delay_penalty: float = 0.0, reduction: Optional[str] = "mean", -) -> Tensor: +) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]: """A RNN-T loss with pruning, which uses the output of a pruned 'joiner' network as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C), s_range means the number of symbols kept for each frame. @@ -1163,6 +1187,8 @@ def rnnt_loss_pruned( [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as [0, 0, S, T] if boundary is not supplied. Most likely you will want begin_symbol and begin_frame to be zero. + normalized: + True to do log_softmax normalization, otherwise not. rnnt_type: Specifies the type of rnnt paths: `regular`, `modified` or `constrained`. `regular`: The regular rnnt that taking you to the next frame only if @@ -1186,10 +1212,20 @@ def rnnt_loss_pruned( `mean`: apply `torch.mean` over the batches. `sum`: the output will be summed. Default: `mean` + return_grad: + Whether to return grads of px and py, this grad standing for the + occupation probability is the output of the backward with a + `fake gradient`, the `fake gradient` is the same as the gradient you'd + get if you did `torch.autograd.grad((-loss.sum()), [px, py])`, note, the + loss here is the loss with reduction "none". + This is useful to implement the pruned version of rnnt loss. 
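To make the new `normalized` and `return_grad` arguments concrete before the Returns section below, here is a hedged usage sketch built on k2's existing pruned-loss helpers (`rnnt_loss_simple`, `get_rnnt_prune_ranges`); the shapes and the exact call sequence are my assumptions, not something this patch prescribes.

```python
import torch
import k2

B, T, S, C, s_range = 2, 10, 5, 8, 5
am = torch.randn(B, T, C)
lm = torch.randn(B, S + 1, C)
symbols = torch.randint(1, C, (B, S), dtype=torch.int64)
boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

# Occupation probabilities from the simple loss give the pruning ranges.
_, (px_grad, py_grad) = k2.rnnt_loss_simple(
    lm=lm, am=am, symbols=symbols, termination_symbol=0,
    boundary=boundary, reduction="none", return_grad=True)
ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, boundary, s_range)

logits = torch.randn(B, T, s_range, C)  # stand-in for a pruned joiner output
loss, (px_grad_p, py_grad_p) = k2.rnnt_loss_pruned(
    logits=logits, symbols=symbols, ranges=ranges, termination_symbol=0,
    boundary=boundary, normalized=True, return_grad=True, reduction="none")
print(loss.shape, px_grad_p.shape, py_grad_p.shape)
```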
Returns: - If reduction is `none`, returns a tensor of shape (B,), containing the - total RNN-T loss values for each sequence of the batch, otherwise a scalar - with the reduction applied. + If return_grad is False, returns a tensor of shape (B,), containing the + total RNN-T loss values for each element of the batch if reduction equals + to "none", otherwise a scalar with the reduction applied. + If return_grad is True, the grads of px and py, which is the output of + backward with a `fake gradient`(see above), will be returned too. And the + returned value will be a tuple like (loss, (px_grad, py_grad)). """ px, py = get_rnnt_logprobs_pruned( logits=logits, @@ -1197,6 +1233,7 @@ def rnnt_loss_pruned( ranges=ranges, termination_symbol=termination_symbol, boundary=boundary, + normalized=normalized, rnnt_type=rnnt_type, ) @@ -1217,17 +1254,22 @@ def rnnt_loss_pruned( penalty = penalty * delay_penalty px += penalty.to(px.dtype) - negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) + scores_and_grads = mutual_information_recursion( + px=px, py=py, boundary=boundary, return_grad=return_grad + ) + negated_loss = scores_and_grads[0] if return_grad else scores_and_grads + if reduction == "none": - return -negated_loss + loss = -negated_loss elif reduction == "mean": - return -torch.mean(negated_loss) + loss = -torch.mean(negated_loss) elif reduction == "sum": - return -torch.sum(negated_loss) + loss = -torch.sum(negated_loss) else: - raise ValueError( + raise ValueError ( f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}" ) + return (loss, scores_and_grads[1]) if return_grad else loss def get_rnnt_logprobs_smoothed( diff --git a/k2/python/tests/CMakeLists.txt b/k2/python/tests/CMakeLists.txt index 429af80b5..326accedf 100644 --- a/k2/python/tests/CMakeLists.txt +++ b/k2/python/tests/CMakeLists.txt @@ -34,6 +34,7 @@ set(py_test_files fsa_from_unary_function_ragged_test.py fsa_from_unary_function_tensor_test.py fsa_test.py + generate_denominator_lattice_test.py get_arc_post_test.py get_backward_scores_test.py get_forward_scores_test.py diff --git a/k2/python/tests/generate_denominator_lattice_test.py b/k2/python/tests/generate_denominator_lattice_test.py new file mode 100644 index 000000000..29a06fbf5 --- /dev/null +++ b/k2/python/tests/generate_denominator_lattice_test.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# To run this single test, use +# +# ctest --verbose -R generate_denominator_lattice_test_py + +import unittest + +import k2 +import torch + +from torch.distributions.categorical import Categorical +from typing import Tuple + + +def _roll_by_shifts( + src: torch.Tensor, shifts: torch.LongTensor +) -> torch.Tensor: + """Roll tensor with different shifts for each row. + + Note: + We assume the src is a 3 dimensions tensor and roll the last dimension. 
+ + Example: + + >>> src = torch.arange(15).reshape((1,3,5)) + >>> src + tensor([[[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14]]]) + >>> shift = torch.tensor([[1, 2, 3]]) + >>> shift + tensor([[1, 2, 3]]) + >>> _roll_by_shifts(src, shift) + tensor([[[ 4, 0, 1, 2, 3], + [ 8, 9, 5, 6, 7], + [12, 13, 14, 10, 11]]]) + """ + assert src.dim() == 3 + (B, T, S) = src.shape + assert shifts.shape == (B, T) + + index = ( + torch.arange(S, device=src.device) + .view((1, S)) + .repeat((T, 1)) + .repeat((B, 1, 1)) + ) + index = (index - shifts.reshape(B, T, 1)) % S + return torch.gather(src, 2, index) + + +def simulate_importance_sampling( + boundary: torch.Tensor, + vocab_size: int, + path_length: int, + num_paths: int, + context_size: int = 2, + blank_id: int = 0, + device: torch.device = torch.device("cpu"), +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + boundary: + It is a tensor with shape (B,), containing the number of frames for + each sequence. + vocab_size: + Vocabulary size. + path_length: + How many symbols we will sample for each path. + num_paths: + How many paths we will sample for each sequence. + context_size: + The number of left symbols. + blank_id: + Stands for null context. + + Returns: + Three tensors will be returned. + - sampled_paths: + A tensor of shape (batch_size, num_paths, path_length), containing the + sampled symbol ids. + - sampling_probs: + A tensor of shape (batch_size, num_paths, path_length), containing the + sampling probabilities of the sampled symbols. + - left_symbols: + A tensor of shape (batch_size, num_paths, path_length, context_size), + containing the left symbols of the sampled symbols. + - frame_ids: + A tensor of shape (batch_size, num_paths, path_length), containing the + frame ids at which we sampled the symbols. 
+ """ + # we sample paths from frame 0 + batch_size = boundary.numel() + + t_index = torch.zeros( + (batch_size, num_paths), dtype=torch.int64, device=device + ) + + t_index_max = boundary.view(batch_size, 1).expand(batch_size, num_paths) + + left_symbols = torch.tensor( + [blank_id], dtype=torch.int64, device=device + ).expand(batch_size, num_paths, context_size) + + sampled_paths_list = [] + sampling_probs_list = [] + frame_ids_list = [] + left_symbols_list = [] + + for i in range(path_length): + probs = torch.randn(batch_size, num_paths, vocab_size) + probs = torch.softmax(probs, -1) + # sampler: https://pytorch.org/docs/stable/distributions.html#categorical + sampler = Categorical(probs=probs) + + # sample one symbol for each path + # index : (batch_size, num_paths) + index = sampler.sample() + sampled_paths_list.append(index) + + # gather sampling probabilities for corresponding indexs + # sampling_prob : (batch_size, num_paths, 1) + sampling_probs = torch.gather(probs, dim=2, index=index.unsqueeze(2)) + sampling_probs_list.append(sampling_probs.squeeze(2)) + + frame_ids_list.append(t_index) + + left_symbols_list.append(left_symbols) + + # update (t, s) for each path + # index == 0 means the sampled symbol is blank + t_mask = index == 0 + # t_index = torch.where(t_mask, t_index + 1, t_index) + t_index = t_index + 1 + + final_mask = t_index >= t_index_max + reach_final = torch.any(final_mask) + if reach_final: + new_t_index = torch.randint(0, torch.min(t_index_max) - 1, (1,)).item() + t_index.masked_fill_(final_mask, new_t_index) + + current_symbols = torch.cat([left_symbols, index.unsqueeze(2)], dim=2) + left_symbols = _roll_by_shifts(current_symbols, t_mask.to(torch.int64)) + left_symbols = left_symbols[:, :, 1:] + if reach_final: + left_symbols.masked_fill_(final_mask.unsqueeze(2), blank_id) + + # sampled_paths : (batch_size, num_paths, path_lengths) + sampled_paths = torch.stack(sampled_paths_list, dim=2).int() + # sampling_probs : (batch_size, num_paths, path_lengths) + sampling_probs = torch.stack(sampling_probs_list, dim=2) + # frame_ids : (batch_size , num_paths, path_lengths) + frame_ids = torch.stack(frame_ids_list, dim=2).int() + # left_symbols : (batch_size, num_paths, path_lengths, context_size) + left_symbols = torch.stack(left_symbols_list, dim=2).int() + return sampled_paths, frame_ids, sampling_probs, left_symbols + + +class TestConnect(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.devices = [torch.device("cpu")] + if torch.cuda.is_available() and k2.with_cuda: + cls.devices.append(torch.device("cuda", 0)) + if torch.cuda.device_count() > 1: + torch.cuda.set_device(1) + cls.devices.append(torch.device("cuda", 1)) + + def test(self): + context_size = 2 + batch_size, num_paths, path_length, vocab_size = 2, 3, 10, 10 + boundary_ = torch.tensor([6, 9], dtype=torch.int32) + ( + sampled_paths_, + frame_ids_, + sampling_probs_, + left_symbols_, + ) = simulate_importance_sampling( + boundary=boundary_, + vocab_size=vocab_size, + num_paths=num_paths, + path_length=path_length, + context_size=context_size, + ) + path_scores_ = torch.randn( + (batch_size, num_paths, path_length), dtype=torch.float + ) + for device in self.devices: + boundary = boundary_.to(device) + sampled_paths = sampled_paths_.to(device) + sampling_probs = sampling_probs_.to(device) + frame_ids = frame_ids_.to(device) + left_symbols = left_symbols_.to(device) + path_scores = path_scores_.detach().clone().to(device) + path_scores.requires_grad_(True) + fsa = k2.generate_denominator_lattice( + 
sampled_paths=sampled_paths, + frame_ids=frame_ids, + left_symbols=left_symbols, + sampling_probs=sampling_probs, + boundary=boundary, + path_scores=path_scores, + vocab_size=vocab_size, + context_size=context_size, + ) + print(fsa) + fsa = k2.connect(k2.top_sort(fsa)) + scores = torch.sum( + fsa.get_tot_scores(log_semiring=True, use_double_scores=False) + ) + scores.backward() + print(path_scores.grad) + + +if __name__ == "__main__": + unittest.main()