From 0222b1e7237b3e96c96aca23e926662e9d8a20dd Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Thu, 3 Nov 2022 12:01:15 +0800 Subject: [PATCH 01/22] allow_partial for intersect_dense_pruned --- k2/csrc/fsa_algo.h | 13 ++++-- k2/csrc/intersect_dense_pruned.cu | 71 ++++++++++++++++++++++++++++--- k2/csrc/intersect_test.cu | 22 +++++++--- k2/python/csrc/torch/fsa_algo.cu | 8 ++-- k2/python/k2/autograd.py | 26 ++++++++--- 5 files changed, 114 insertions(+), 26 deletions(-) diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index 92dde5dfe..4cbec2490 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -161,10 +161,10 @@ void AddEpsilonSelfLoops(FsaOrVec &src, FsaOrVec *dest, @param[in] b_fsas Input FSAs that correspond to neural network outputs (see documentation in fsa.h). @param[in] search_beam Beam for frame-synchronous beam pruning, - e.g. 20. Smaller is faster, larger is more exact - (less pruning). This is the default value; it may be - modified by {min,max}_active which dictate the minimum - or maximum allowed number of active states per frame. + e.g. 20. Smaller is faster, larger is more exact + (less pruning). This is the default value; it may be + modified by {min,max}_active which dictate the minimum + or maximum allowed number of active states per frame. @param[in] output_beam Beam with which we prune the output (analogous to lattice-beam in Kaldi), e.g. 8. We discard arcs in the output that are not on a path that's within @@ -178,6 +178,10 @@ void AddEpsilonSelfLoops(FsaOrVec &src, FsaOrVec *dest, of states are active. The hash size used per FSA is 4 times (this rounded up to a power of 2), so this affects memory consumption. + @param [in] allow_partial If true, we will treat all the states on the + last frame to be final state. If false, we only + care about the real final state in the decoding + graph on the last frame when generating lattice. @param[out] out Output vector of composed, pruned FSAs, with same Dim0() as b_fsas. Elements of it may be empty if the composition was empty, either intrinsically or due to @@ -196,6 +200,7 @@ void AddEpsilonSelfLoops(FsaOrVec &src, FsaOrVec *dest, void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas, float search_beam, float output_beam, int32_t min_active_states, int32_t max_active_states, + bool allow_partial, FsaVec *out, Array1 *arc_map_a, Array1 *arc_map_b); diff --git a/k2/csrc/intersect_dense_pruned.cu b/k2/csrc/intersect_dense_pruned.cu index 6ef8f4f1b..82b501041 100644 --- a/k2/csrc/intersect_dense_pruned.cu +++ b/k2/csrc/intersect_dense_pruned.cu @@ -133,16 +133,23 @@ class MultiGraphDenseIntersectPruned { intersection/composition task. This is advisory, in that it will try not to exceed that but may not always succeed. This determines the hash size. + @param [in] allow_partial If true, we will treat all the states on the + last frame to be final state. If false, we only + care about the real final state in the decoding + graph on the last frame when generating lattice. + */ MultiGraphDenseIntersectPruned(FsaVec &a_fsas, DenseFsaVec &b_fsas, float search_beam, float output_beam, - int32_t min_active, int32_t max_active) + int32_t min_active, int32_t max_active, + bool allow_partial) : a_fsas_(a_fsas), b_fsas_(b_fsas), search_beam_(search_beam), output_beam_(output_beam), min_active_(min_active), max_active_(max_active), + allow_partial_(allow_partial), dynamic_beams_(a_fsas.Context(), b_fsas.shape.Dim0(), search_beam), forward_semaphore_(1) { NVTX_RANGE(K2_FUNC); @@ -498,12 +505,27 @@ class MultiGraphDenseIntersectPruned { int32_t dest_state_idx012 = oarc_idx01x_next + arc_info.u.dest_info_state_idx1; arc.dest_state = dest_state_idx012 - oarc_idx0xx; - arc.label = a_fsas_arcs[arc_info.a_fsas_arc_idx012].label; + int32_t arc_label = a_fsas_arcs[arc_info.a_fsas_arc_idx012].label; + arc.label = arc_label; + int32_t final_t = b_fsas_row_splits1[oarc_idx0+1] - b_fsas_row_splits1[oarc_idx0]; + if (t == final_t - 1 && arc_label != -1) { + if (allow_partial_) { + arc.label = -1; + } else { + // Unreachable code. + K2_LOG(FATAL) << + "arc.labe != -1 on final_arc when allow_partial==false."; + } + } int32_t fsa_id = oarc_idx0, b_fsas_idx0x = b_fsas_row_splits1[fsa_id], b_fsas_idx01 = b_fsas_idx0x + t, - b_fsas_idx2 = (arc.label + 1), + // Use arc_label instead of arc.label to keep track of + // the origial arc index in b_fsas when allow_partial == true. + // Then arc_map_b storages the "correct" arc index instead of + // the non-exist manually added arc pointing to super-final state. + b_fsas_idx2 = (arc_label + 1), b_fsas_arc_idx012 = b_fsas_idx01 * b_fsas_num_cols + b_fsas_idx2; arc.score = arc_info.arc_loglike; @@ -664,6 +686,9 @@ class MultiGraphDenseIntersectPruned { const int32_t *ai_row_ids2 = ai_shape.RowIds(2).Data(); // from state_idx01 to arc_idx01x const int32_t *ai_row_splits2 = ai_shape.RowSplits(2).Data(); + + const int32_t *a_fsas_row_splits1 = a_fsas_.shape.RowSplits(1).Data(); + const int32_t *a_fsas_row_ids1 = a_fsas_.shape.RowIds(1).Data(); // from state_idx01 (into a_fsas_) to arc_idx01x (into a_fsas_) const int32_t *a_fsas_row_splits2 = a_fsas_.shape.RowSplits(2).Data(); @@ -679,6 +704,29 @@ class MultiGraphDenseIntersectPruned { Ragged ai(ai_shape); ArcInfo *ai_data = ai.values.Data(); // uninitialized + // A valid final arc means its label == -1. + auto has_valid_final_arc = Array1(c_, NumFsas(), false); + bool *has_valid_final_arc_data = has_valid_final_arc.Data(); + + if (allow_partial_) { + K2_EVAL( + c_, ai.values.Dim(), set_has_non_inf_arc, (int32_t ai_arc_idx012)->void { + int32_t ai_state_idx01 = ai_row_ids2[ai_arc_idx012], + ai_fsa_idx0 = ai_row_ids1[ai_state_idx01], + ai_arc_idx01x = ai_row_splits2[ai_state_idx01], + ai_arc_idx2 = ai_arc_idx012 - ai_arc_idx01x; + StateInfo sinfo = state_values[ai_state_idx01]; + int32_t a_fsas_arc_idx01x = + a_fsas_row_splits2[sinfo.a_fsas_state_idx01], + a_fsas_arc_idx012 = a_fsas_arc_idx01x + ai_arc_idx2; + Arc arc = arcs[a_fsas_arc_idx012]; + auto final_t = b_fsas_row_splits1[ai_fsa_idx0+1] - b_fsas_row_splits1[ai_fsa_idx0]; + if (final_t - 1 == t && -1 == arc.label) { + has_valid_final_arc_data[ai_fsa_idx0] = true; + } + }); + } + K2_EVAL( c_, ai.values.Dim(), ai_lambda, (int32_t ai_arc_idx012)->void { int32_t ai_state_idx01 = ai_row_ids2[ai_arc_idx012], @@ -698,6 +746,17 @@ class MultiGraphDenseIntersectPruned { K2_DCHECK_LT(static_cast(scores_idx2), static_cast(scores_num_cols)); float acoustic_score = scores_acc(scores_idx01, scores_idx2); + auto dest_state = arc.dest_state; + auto final_t = b_fsas_row_splits1[ai_fsa_idx0+1] - b_fsas_row_splits1[ai_fsa_idx0]; + if (final_t - 1 == t && !has_valid_final_arc_data[ai_fsa_idx0] && + allow_partial_) { + int32_t a_fsas_idx0 = a_fsas_row_ids1[sinfo.a_fsas_state_idx01]; + // state_idx1 is 0-based. + // So "-1" is used when calculating a_fsas_final_state_idx1. + int32_t a_fsas_final_state_idx1 = a_fsas_row_splits1[a_fsas_idx0 + 1] - 1 - a_fsas_row_splits1[a_fsas_idx0]; + dest_state = a_fsas_final_state_idx1; + acoustic_score = 0.0; + } ArcInfo ai; ai.a_fsas_arc_idx012 = a_fsas_arc_idx012; ai.arc_loglike = acoustic_score + arc.score; @@ -709,7 +768,7 @@ class MultiGraphDenseIntersectPruned { // convert to an idx01; this relies on the fact that // sinfo.abs_state_id == arc.src_state + a_fsas_fsa_idx0x. ai.u.dest_a_fsas_state_idx01 = - sinfo.a_fsas_state_idx01 + arc.dest_state - arc.src_state; + sinfo.a_fsas_state_idx01 + dest_state - arc.src_state; ai_data[ai_arc_idx012] = ai; }); return ai; @@ -1459,6 +1518,7 @@ class MultiGraphDenseIntersectPruned { float output_beam_; int32_t min_active_; int32_t max_active_; + bool allow_partial_; Array1 dynamic_beams_; // dynamic beams (initially just search_beam_ // but change due to max_active/min_active // constraints). @@ -1521,13 +1581,14 @@ class MultiGraphDenseIntersectPruned { void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas, float search_beam, float output_beam, int32_t min_active_states, int32_t max_active_states, + bool allow_partial, FsaVec *out, Array1 *arc_map_a, Array1 *arc_map_b) { NVTX_RANGE("IntersectDensePruned"); FsaVec a_vec = FsaToFsaVec(a_fsas); MultiGraphDenseIntersectPruned intersector(a_vec, b_fsas, search_beam, output_beam, min_active_states, - max_active_states); + max_active_states, allow_partial); intersector.Intersect(); intersector.FormatOutput(out, arc_map_a, arc_map_b); diff --git a/k2/csrc/intersect_test.cu b/k2/csrc/intersect_test.cu index 25ac69a03..4e7044cc5 100644 --- a/k2/csrc/intersect_test.cu +++ b/k2/csrc/intersect_test.cu @@ -243,7 +243,7 @@ TEST(Intersect, RandomSingle) { K2_LOG(INFO) << "fsas_b = " << fsas_b; FsaVec out_fsas2; Array1 arc_map_a2, arc_map_b2; - // IntersectDensePruned() treats epsilons as normal symbols, so we need to + // IntersectDense() treats epsilons as normal symbols, so we need to // as well. ArcSort(&fsa); // CAUTION if you later test the arc_maps: we arc-sort here, @@ -339,7 +339,7 @@ TEST(Intersect, RandomFsaVec) { K2_LOG(INFO) << "fsas_b = " << fsas_b; FsaVec out_fsas2; Array1 arc_map_a2, arc_map_b2; - // IntersectDensePruned() treats epsilons as normal symbols, so we need to + // IntersectDense() treats epsilons as normal symbols, so we need to // as well. ArcSort(&fsavec); // CAUTION if you later test the arc_maps: we arc-sort @@ -401,11 +401,12 @@ TEST(IntersectPruned, Simple) { float beam = 100000; int32_t max_active = 10000, min_active = 0; + bool allow_partial = false; FsaVec out_fsas; Array1 arc_map_a, arc_map_b; IntersectDensePruned(fsa, dfsavec, beam, beam, min_active, max_active, - &out_fsas, &arc_map_a, &arc_map_b); + allow_partial, &out_fsas, &arc_map_a, &arc_map_b); K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a << ", arc_map_b = " << arc_map_b; @@ -458,11 +459,12 @@ TEST(IntersectPruned, TwoDense) { float beam = 100000; int32_t max_active = 10000, min_active = 0; + bool allow_partial = false; FsaVec out_fsas; Array1 arc_map_a, arc_map_b; IntersectDensePruned(fsa, dfsavec, beam, beam, min_active, max_active, - &out_fsas, &arc_map_a, &arc_map_b); + allow_partial, &out_fsas, &arc_map_a, &arc_map_b); K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a << ", arc_map_b = " << arc_map_b; @@ -507,11 +509,12 @@ TEST(IntersectPruned, TwoFsas) { float beam = 100000; int32_t max_active = 10000, min_active = 0; + bool allow_partial = false; FsaVec out_fsas; Array1 arc_map_a, arc_map_b; IntersectDensePruned(fsa_vec, dfsavec, beam, beam, min_active, max_active, - &out_fsas, &arc_map_a, &arc_map_b); + allow_partial, &out_fsas, &arc_map_a, &arc_map_b); K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a << ", arc_map_b = " << arc_map_b; @@ -575,8 +578,10 @@ TEST(IntersectPruned, RandomSingle) { FsaVec out_fsas; float beam = 1000.0; int32_t max_active = 10000, min_active = 0; + bool allow_partial = false; + IntersectDensePruned(fsa, dfsavec, beam, beam, min_active, max_active, - &out_fsas, &arc_map_a, &arc_map_b); + allow_partial, &out_fsas, &arc_map_a, &arc_map_b); K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_b = " << arc_map_b; FsaVec fsas_b = ConvertDenseToFsaVec(dfsavec); @@ -679,8 +684,11 @@ TEST(IntersectPruned, RandomFsaVec) { FsaVec out_fsas; float search_beam = 1000.0, output_beam = 1000.0; int32_t min_active = 0, max_active = 10; + bool allow_partial = false; + IntersectDensePruned(fsavec, dfsavec, search_beam, output_beam, min_active, - max_active, &out_fsas, &arc_map_a, &arc_map_b); + max_active, allow_partial, + &out_fsas, &arc_map_a, &arc_map_b); K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a << ", arc_map_b = " << arc_map_b; diff --git a/k2/python/csrc/torch/fsa_algo.cu b/k2/python/csrc/torch/fsa_algo.cu index 304fde809..83f5acab8 100644 --- a/k2/python/csrc/torch/fsa_algo.cu +++ b/k2/python/csrc/torch/fsa_algo.cu @@ -200,7 +200,7 @@ static void PybindIntersectDensePruned(py::module &m) { "intersect_dense_pruned", [](FsaVec &a_fsas, DenseFsaVec &b_fsas, float search_beam, float output_beam, int32_t min_active_states, - int32_t max_active_states) + int32_t max_active_states, bool allow_partial) -> std::tuple { DeviceGuard guard(a_fsas.Context()); Array1 arc_map_a; @@ -208,13 +208,15 @@ static void PybindIntersectDensePruned(py::module &m) { FsaVec out; IntersectDensePruned(a_fsas, b_fsas, search_beam, output_beam, - min_active_states, max_active_states, &out, + min_active_states, max_active_states, + allow_partial, &out, &arc_map_a, &arc_map_b); return std::make_tuple(out, ToTorch(arc_map_a), ToTorch(arc_map_b)); }, py::arg("a_fsas"), py::arg("b_fsas"), py::arg("search_beam"), py::arg("output_beam"), py::arg("min_active_states"), - py::arg("max_active_states")); + py::arg("max_active_states"), + py::arg("allow_partial") = false); } static void PybindIntersectDense(py::module &m) { diff --git a/k2/python/k2/autograd.py b/k2/python/k2/autograd.py index 19a282dfe..5c5e1bb02 100644 --- a/k2/python/k2/autograd.py +++ b/k2/python/k2/autograd.py @@ -358,6 +358,7 @@ def forward(ctx, output_beam: float, min_active_states: int, max_active_states: int, + allow_partial: bool, unused_scores_a: torch.Tensor, unused_scores_b: torch.Tensor, seqframe_idx_name: Optional[str] = None, @@ -383,16 +384,20 @@ def forward(ctx, output_beam: Pruning beam for the output of intersection (vs. best path); equivalent to kaldi's lattice-beam. E.g. 8. - max_active_states: - Maximum number of FSA states that are allowed to be active on any - given frame for any given intersection/composition task. This is - advisory, in that it will try not to exceed that but may not always - succeed. You can use a very large number if no constraint is needed. min_active_states: Minimum number of FSA states that are allowed to be active on any given frame for any given intersection/composition task. This is advisory, in that it will try not to have fewer than this number active. Set it to zero if there is no constraint. + max_active_states: + Maximum number of FSA states that are allowed to be active on any + given frame for any given intersection/composition task. This is + advisory, in that it will try not to exceed that but may not always + succeed. You can use a very large number if no constraint is needed. + allow_partial If true, we will treat all the states on the + last frame to be final state. If false, we only + care about the real final state in the decoding + graph on the last frame when generating lattice. unused_scores_a: It equals to `a_fsas.scores` and its sole purpose is for back propagation. @@ -418,7 +423,8 @@ def forward(ctx, search_beam=search_beam, output_beam=output_beam, min_active_states=min_active_states, - max_active_states=max_active_states) + max_active_states=max_active_states, + allow_partial=allow_partial) out_fsa[0] = Fsa(ragged_arc) @@ -650,6 +656,7 @@ def intersect_dense_pruned(a_fsas: Fsa, b_fsas: DenseFsaVec, search_beam: float, output_beam: float, + allow_partial: bool, min_active_states: int, max_active_states: int, seqframe_idx_name: Optional[str] = None, @@ -684,6 +691,10 @@ def intersect_dense_pruned(a_fsas: Fsa, frame for any given intersection/composition task. This is advisory, in that it will try not to exceed that but may not always succeed. You can use a very large number if no constraint is needed. + allow_partial If true, we will treat all the states on the + last frame to be final state. If false, we only + care about the real final state in the decoding + graph on the last frame when generating lattice. seqframe_idx_name: If set (e.g. to 'seqframe'), an attribute in the output will be created that encodes the sequence-index and the frame-index within that @@ -717,7 +728,8 @@ def intersect_dense_pruned(a_fsas: Fsa, # in `out_fsa[0].scores` _IntersectDensePrunedFunction.apply(a_fsas, b_fsas, out_fsa, search_beam, output_beam, min_active_states, - max_active_states, a_fsas.scores, + max_active_states, allow_partial, + a_fsas.scores, b_fsas.scores, seqframe_idx_name, frame_idx_name) return out_fsa[0] From a8af042b82bfece9f8e365acd86e55ec54bf3158 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Thu, 3 Nov 2022 13:40:36 +0800 Subject: [PATCH 02/22] fix comment --- k2/csrc/fsa_algo.h | 5 +++-- k2/csrc/intersect_dense_pruned.cu | 5 +++-- k2/python/k2/autograd.py | 6 ++++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index 4cbec2490..481beb772 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -178,8 +178,9 @@ void AddEpsilonSelfLoops(FsaOrVec &src, FsaOrVec *dest, of states are active. The hash size used per FSA is 4 times (this rounded up to a power of 2), so this affects memory consumption. - @param [in] allow_partial If true, we will treat all the states on the - last frame to be final state. If false, we only + @param [in] allow_partial If true and there was no final state active, + we will treat all the states on the last frame + to be final state. If false, we only care about the real final state in the decoding graph on the last frame when generating lattice. @param[out] out Output vector of composed, pruned FSAs, with same diff --git a/k2/csrc/intersect_dense_pruned.cu b/k2/csrc/intersect_dense_pruned.cu index 82b501041..0e660758d 100644 --- a/k2/csrc/intersect_dense_pruned.cu +++ b/k2/csrc/intersect_dense_pruned.cu @@ -133,8 +133,9 @@ class MultiGraphDenseIntersectPruned { intersection/composition task. This is advisory, in that it will try not to exceed that but may not always succeed. This determines the hash size. - @param [in] allow_partial If true, we will treat all the states on the - last frame to be final state. If false, we only + @param [in] allow_partial If true and there was no final state active, + we will treat all the states on the last frame + to be final state. If false, we only care about the real final state in the decoding graph on the last frame when generating lattice. diff --git a/k2/python/k2/autograd.py b/k2/python/k2/autograd.py index 5c5e1bb02..d626811fe 100644 --- a/k2/python/k2/autograd.py +++ b/k2/python/k2/autograd.py @@ -394,7 +394,8 @@ def forward(ctx, given frame for any given intersection/composition task. This is advisory, in that it will try not to exceed that but may not always succeed. You can use a very large number if no constraint is needed. - allow_partial If true, we will treat all the states on the + allow_partial If true and there was no final state active, + we will treat all the states on the last frame to be final state. If false, we only care about the real final state in the decoding graph on the last frame when generating lattice. @@ -691,7 +692,8 @@ def intersect_dense_pruned(a_fsas: Fsa, frame for any given intersection/composition task. This is advisory, in that it will try not to exceed that but may not always succeed. You can use a very large number if no constraint is needed. - allow_partial If true, we will treat all the states on the + allow_partial If true and there was no final state active, + we will treat all the states on the last frame to be final state. If false, we only care about the real final state in the decoding graph on the last frame when generating lattice. From f2693038fedeaef1116c9c853f04d2037a6a60b3 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Thu, 3 Nov 2022 21:58:12 +0800 Subject: [PATCH 03/22] set default value and move to the last --- k2/python/k2/autograd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/k2/python/k2/autograd.py b/k2/python/k2/autograd.py index d626811fe..74a94f012 100644 --- a/k2/python/k2/autograd.py +++ b/k2/python/k2/autograd.py @@ -657,11 +657,11 @@ def intersect_dense_pruned(a_fsas: Fsa, b_fsas: DenseFsaVec, search_beam: float, output_beam: float, - allow_partial: bool, min_active_states: int, max_active_states: int, seqframe_idx_name: Optional[str] = None, - frame_idx_name: Optional[str] = None) -> Fsa: + frame_idx_name: Optional[str] = None, + allow_partial: bool = False) -> Fsa: '''Intersect array of FSAs on CPU/GPU. Caution: From 9e29641353d589e9a4cc8540e45cc809b7473523 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 7 Jul 2023 23:00:21 +0800 Subject: [PATCH 04/22] Fix crash and backward --- k2/csrc/intersect_dense_pruned.cu | 6 ++++-- k2/python/k2/autograd.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/k2/csrc/intersect_dense_pruned.cu b/k2/csrc/intersect_dense_pruned.cu index 6d2fbb660..5df072448 100644 --- a/k2/csrc/intersect_dense_pruned.cu +++ b/k2/csrc/intersect_dense_pruned.cu @@ -397,6 +397,7 @@ class MultiGraphDenseIntersectPruned { NVTX_RANGE("FormatOutput"); bool online_decoding = online_decoding_; + bool allow_partial = allow_partial_; if (online_decoding) { K2_CHECK(arc_map_a); K2_CHECK_EQ(arc_map_b, nullptr); @@ -548,7 +549,7 @@ class MultiGraphDenseIntersectPruned { arc.label = arc_label; int32_t final_t = b_fsas_row_splits1[oarc_idx0+1] - b_fsas_row_splits1[oarc_idx0]; if (t == final_t - 1 && arc_label != -1) { - if (allow_partial_) { + if (allow_partial) { arc.label = -1; } else { // Unreachable code. @@ -757,6 +758,7 @@ class MultiGraphDenseIntersectPruned { // A valid final arc means its label == -1. auto has_valid_final_arc = Array1(c_, NumFsas(), false); bool *has_valid_final_arc_data = has_valid_final_arc.Data(); + bool allow_partial = allow_partial_; if (allow_partial_) { K2_EVAL( @@ -799,7 +801,7 @@ class MultiGraphDenseIntersectPruned { auto dest_state = arc.dest_state; auto final_t = b_fsas_row_splits1[ai_fsa_idx0+1] - b_fsas_row_splits1[ai_fsa_idx0]; if (final_t - 1 == t && !has_valid_final_arc_data[ai_fsa_idx0] && - allow_partial_) { + allow_partial) { int32_t a_fsas_idx0 = a_fsas_row_ids1[sinfo.a_fsas_state_idx01]; // state_idx1 is 0-based. // So "-1" is used when calculating a_fsas_final_state_idx1. diff --git a/k2/python/k2/autograd.py b/k2/python/k2/autograd.py index 132da1e69..5d62b472b 100644 --- a/k2/python/k2/autograd.py +++ b/k2/python/k2/autograd.py @@ -473,7 +473,7 @@ def forward(ctx, @staticmethod def backward(ctx, out_fsa_grad: torch.Tensor) \ - -> Tuple[None, None, None, None, None, None, None, torch.Tensor, torch.Tensor]: # noqa + -> Tuple[None, None, None, None, None, None, None, None, torch.Tensor, torch.Tensor, None, None]: # noqa a_scores, b_scores = ctx.saved_tensors arc_map_a = ctx.arc_map_a arc_map_b = ctx.arc_map_b @@ -500,6 +500,7 @@ def backward(ctx, out_fsa_grad: torch.Tensor) \ None, # output_beam None, # min_active_states None, # max_active_states + None, # allow_partial grad_a, # unused_scores_a grad_b, # unused_scores_b None, # seqframe_idx_name From 5280b7ee95432090e2131131fd30f27398bbb1fc Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 11 Jul 2023 12:53:23 +0800 Subject: [PATCH 05/22] Fix allow partial for online decoding --- k2/csrc/intersect_dense_pruned.cu | 49 ++++++++++++++++--------------- k2/torch/bin/online_decode.cu | 2 +- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/k2/csrc/intersect_dense_pruned.cu b/k2/csrc/intersect_dense_pruned.cu index 5df072448..339269277 100644 --- a/k2/csrc/intersect_dense_pruned.cu +++ b/k2/csrc/intersect_dense_pruned.cu @@ -250,7 +250,7 @@ class MultiGraphDenseIntersectPruned { @param [in] frames The frames generated for previously decoded chunks. @param [in] beams Current search beams for each of the sequences, it has `beams.Dim() == num_seqs_`. - @return A pointer to current `frames_`, which would be usefull to + @return A pointer to current `frames_`, which would be useful to generate `DecodeStateInfo` for each sequences. */ const std::vector>* OnlineIntersect( @@ -258,7 +258,7 @@ class MultiGraphDenseIntersectPruned { std::vector> &frames, Array1 &beams) { /* - T is the largest number of (frames+1) of neural net output currently + T is the largest number (frames+1) of neural net output currently received, or the largest number of frames of log-likelihoods we count the final frame with (0, -inf, -inf..) that is used for the final-arc. The largest number of states in the fsas represented by b_fsas equals @@ -275,17 +275,15 @@ class MultiGraphDenseIntersectPruned { b_fsas_ = b_fsas; frames_.swap(frames); dynamic_beams_ = beams.To(c_); - T_ = frames_.size(); + T_ = frames_.size() - 1; // -1 here because we already put the initial frame info to frames_ - int32_t T = T_ + b_fsas_->shape.MaxSize(1) - 1; + int32_t T = T_ + b_fsas_->shape.MaxSize(1); // we'll initially populate frames_[0.. T+1], but discard the one at T+1, // which has no arcs or states, the ones we use are from 0 to T. frames_.reserve(T + 2); - if (T_ == 0) frames_.push_back(InitialFrameInfo()); - for (int32_t t = 0; t <= b_fsas_->shape.MaxSize(1); t++) { if (state_map_.NumKeyBits() == 32) { frames_.push_back(PropagateForward<32>(t, frames_.back().get())); @@ -296,7 +294,8 @@ class MultiGraphDenseIntersectPruned { frames_.push_back(PropagateForward<40>(t, frames_.back().get())); } if (t == b_fsas_->shape.MaxSize(1)) { - PruneTimeRange(T_ - 1, T_ + t); + int32_t start = std::max(0, T_ - 5); + PruneTimeRange(start, T_ + t); } } // The FrameInfo for time T+1 will have no states. We did that @@ -304,9 +303,9 @@ class MultiGraphDenseIntersectPruned { // is set up (it has no arcs but we need the shape). frames_.pop_back(); - int32_t history_t = T_ - 1; + int32_t history_t = T_; - T_ = T - 1; + T_ = T; // partial_final_frame_ is the last frame to generate partial result, // but it should not be the start frame of next chunk decoding. partial_final_frame_ = std::move(frames_.back()); @@ -320,7 +319,7 @@ class MultiGraphDenseIntersectPruned { c_, num_seqs_, lambda_set_final_and_final_t, (int32_t i)->void { int32_t b_chunk_size = b_fsas_row_splits1[i + 1] - b_fsas_row_splits1[i]; - final_t_data[i] = history_t + b_chunk_size - 1; + final_t_data[i] = history_t + b_chunk_size; }); return &frames_; } @@ -393,7 +392,7 @@ class MultiGraphDenseIntersectPruned { } void FormatOutput(FsaVec *ofsa, Array1 *arc_map_a, - Array1 *arc_map_b, bool is_final) { + Array1 *arc_map_b) { NVTX_RANGE("FormatOutput"); bool online_decoding = online_decoding_; @@ -402,11 +401,10 @@ class MultiGraphDenseIntersectPruned { K2_CHECK(arc_map_a); K2_CHECK_EQ(arc_map_b, nullptr); } else { - K2_CHECK(is_final); K2_CHECK(arc_map_a && arc_map_b); } - int32_t T = is_final ? T_ : T_ + 1; + int32_t T = T_; ContextPtr c_cpu = GetCpuContext(); Array1 arcs_data_ptrs(c_cpu, T + 1); Array1 arcs_row_splits1_ptrs(c_cpu, T + 1); @@ -414,12 +412,12 @@ class MultiGraphDenseIntersectPruned { arcs_data_ptrs.Data()[t] = frames_[t]->arcs.values.Data(); arcs_row_splits1_ptrs.Data()[t] = frames_[t]->arcs.RowSplits(1).Data(); } - arcs_data_ptrs.Data()[T] = is_final - ? frames_[T]->arcs.values.Data() - : partial_final_frame_->arcs.values.Data(); + arcs_data_ptrs.Data()[T] = online_decoding + ? partial_final_frame_->arcs.values.Data(); + : frames_[T]->arcs.values.Data() arcs_row_splits1_ptrs.Data()[T] = - is_final ? frames_[T]->arcs.RowSplits(1).Data() - : partial_final_frame_->arcs.RowSplits(1).Data(); + online_decoding ? partial_final_frame_->arcs.RowSplits(1).Data(); + : frames_[T]->arcs.RowSplits(1).Data() // transfer to GPU if we're using a GPU arcs_data_ptrs = arcs_data_ptrs.To(c_); @@ -447,10 +445,11 @@ class MultiGraphDenseIntersectPruned { int32_t *num_extra_states_data = num_extra_states.Data(); K2_EVAL(c_, num_fsas, lambda_set_num_extra_states, (int32_t i) -> void { int32_t final_t; - if (online_decoding) - final_t = is_final ? final_t_data[i] : final_t_data[i] + 1; - else + if (online_decoding) { + final_t = final_t_data[i]; + } else { final_t = b_fsas_row_splits1[i+1] - b_fsas_row_splits1[i]; + } int32_t *arcs_row_splits1_data = arcs_row_splits1_ptrs_data[final_t]; int32_t num_states_final_t = arcs_row_splits1_data[i + 1] - @@ -485,8 +484,8 @@ class MultiGraphDenseIntersectPruned { for (int32_t t = 0; t < T; t++) arcs_shapes[t] = &(frames_[t]->arcs.shape); - arcs_shapes[T] = is_final ? &(frames_[T]->arcs.shape) - : &(partial_final_frame_->arcs.shape); + arcs_shapes[T] = online_decoding ? &(partial_final_frame_->arcs.shape) + : &(frames_[T]->arcs.shape); arcs_shapes[T + 1] = &final_arcs_shape; @@ -547,7 +546,9 @@ class MultiGraphDenseIntersectPruned { arc.dest_state = dest_state_idx012 - oarc_idx0xx; int32_t arc_label = a_fsas_arcs[arc_info.a_fsas_arc_idx012].label; arc.label = arc_label; - int32_t final_t = b_fsas_row_splits1[oarc_idx0+1] - b_fsas_row_splits1[oarc_idx0]; + + int32_t final_t = online_decoding ? final_t_data[oarc_idx0] + :b_fsas_row_splits1[oarc_idx0+1] - b_fsas_row_splits1[oarc_idx0]; if (t == final_t - 1 && arc_label != -1) { if (allow_partial) { arc.label = -1; diff --git a/k2/torch/bin/online_decode.cu b/k2/torch/bin/online_decode.cu index 362b1002a..2c52c7e79 100644 --- a/k2/torch/bin/online_decode.cu +++ b/k2/torch/bin/online_decode.cu @@ -227,7 +227,7 @@ int main(int argc, char *argv[]) { std::vector positions(num_waves, 0); int32_t T = nnet_output.size(1); - int32_t chunk_size = 10; // 20 frames per chunk + int32_t chunk_size = 20; // 20 frames per chunk // simulate asynchronous decoding while (true) { From bdd5150211406c6079a22bed97170a0a3f27933a Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 11 Jul 2023 12:53:55 +0800 Subject: [PATCH 06/22] Add python demo --- k2/torch/bin/hlg_decode.py | 211 ++++++++++++++++++++++++++ k2/torch/bin/online_decode.py | 270 ++++++++++++++++++++++++++++++++++ 2 files changed, 481 insertions(+) create mode 100644 k2/torch/bin/hlg_decode.py create mode 100644 k2/torch/bin/online_decode.py diff --git a/k2/torch/bin/hlg_decode.py b/k2/torch/bin/hlg_decode.py new file mode 100644 index 000000000..aef48e6e8 --- /dev/null +++ b/k2/torch/bin/hlg_decode.py @@ -0,0 +1,211 @@ +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +from k2 import ( + get_lattice, + one_best_decoding, + get_aux_labels, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", type=str, required=True, help="Path to the jit script model. " + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + args.sample_rate = 16000 + args.subsampling_factor = 4 + args.feature_dim = 80 + args.num_classes = 500 + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = torch.jit.load(args.nn_model) + model = model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = args.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=args.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + feature_len = [] + for f in features: + feature_len.append(f.shape[0]) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + # Note: We don't use key padding mask for attention during decoding + nnet_output, _, _ = model(features) + + log_prob = torch.nn.functional.log_softmax(nnet_output, dim=-1) + log_prob_len = torch.tensor(feature_len) // args.subsampling_factor + log_prob_len = log_prob_len.to(device) + + if args.method == "ctc-decoding": + logging.info("Use CTC decoding") + max_token_id = args.num_classes - 1 + + H = k2.ctc_topo(max_token=max_token_id, device=device,) + + lattice = get_lattice( + log_prob=log_prob, + log_prob_len=log_prob_len, + decoding_graph=H, + subsampling_factor=args.subsampling_factor, + ) + + best_path = one_best_decoding(lattice=lattice, use_double_scores=True) + token_ids = get_aux_labels(best_path) + token_sym_table = k2.SymbolTable.from_file(args.tokens) + + hyps = ["".join([token_sym_table[i] for i in ids]) for ids in token_ids] + + else: + assert args.method == "1best", args.method + logging.info(f"Loading HLG from {args.HLG}") + HLG = k2.Fsa.from_dict(torch.load(args.HLG, map_location="cpu")) + HLG = HLG.to(device) + + lattice = get_lattice( + log_prob=log_prob, + log_prob_len=log_prob_len, + decoding_graph=HLG, + subsampling_factor=args.subsampling_factor, + ) + + if args.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding(lattice=lattice, use_double_scores=True) + + hyps = get_aux_labels(best_path) + word_sym_table = k2.SymbolTable.from_file(args.words_file) + hyps = [" ".join([word_sym_table[i] for i in ids]) for ids in hyps] + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = hyp.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + torch.save(lattice.as_dict(), "offline.pt") + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/k2/torch/bin/online_decode.py b/k2/torch/bin/online_decode.py new file mode 100644 index 000000000..76bb6899f --- /dev/null +++ b/k2/torch/bin/online_decode.py @@ -0,0 +1,270 @@ +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +from k2 import ( + DecodeStateInfo, + OnlineDenseIntersecter, + one_best_decoding, + get_aux_labels, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", type=str, required=True, help="Path to the jit script model. " + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + """, + ) + + parser.add_argument( + "--num-streams", + type=int, + default=2, + help="""The number of streams that can be run in parallel.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + args.sample_rate = 16000 + args.subsampling_factor = 4 + args.feature_dim = 80 + args.num_classes = 500 + args.chunk_size = 10 + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = torch.jit.load(args.nn_model) + model = model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = args.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=args.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + feature_len = [] + for f in features: + feature_len.append(f.shape[0]) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + # Note: We don't use key padding mask for attention during decoding + nnet_output, _, _ = model(features) + num_frames = [x // args.subsampling_factor for x in feature_len] + T = nnet_output.shape[1] + + if args.method == "ctc-decoding": + logging.info("Use CTC decoding") + max_token_id = args.num_classes - 1 + decoding_graph = k2.ctc_topo( + max_token=max_token_id, + device=device, + ) + token_sym_table = k2.SymbolTable.from_file(args.tokens) + else: + assert args.method == "1best", args.method + logging.info(f"Loading HLG from {args.HLG}") + decoding_graph = k2.Fsa.from_dict(torch.load(args.HLG, map_location="cpu")) + decoding_graph = decoding_graph.to(device) + word_sym_table = k2.SymbolTable.from_file(args.words_file) + decoding_graph = k2.Fsa.from_fsas([decoding_graph]) + + intersector = k2.OnlineDenseIntersecter( + decoding_graph=decoding_graph, + num_streams=args.num_streams, + search_beam=20, + output_beam=8, + min_active_states=30, + max_active_states=10000, + ) + + state_infos = [None] * len(waves) + positions = [0] * len(waves) + results = [""] * len(waves) + + while True: + current_state_infos = [] + current_nnet_outputs = [] + current_wave_ids = [] + current_num_frames = [] + for i in range(len(waves)): + if positions[i] == num_frames[i]: + continue + current_state_infos.append(state_infos[i]) + current_wave_ids.append(i) + start = positions[i] + if (num_frames[i] - positions[i]) < args.chunk_size: + current_num_frames.append(num_frames[i] - positions[i]) + end = num_frames[i] + positions[i] = num_frames[i] + else: + current_num_frames.append(args.chunk_size) + end = positions[i] + args.chunk_size + positions[i] += args.chunk_size + + current_nnet_outputs.append(nnet_output[i, start:end, :]) + if len(current_wave_ids) == args.num_streams: + break + if len(current_wave_ids) == 0: + break + while len(current_num_frames) < args.num_streams: + current_num_frames.append(1) + current_nnet_outputs.append( + torch.zeros( + (args.chunk_size, nnet_output.shape[2]), + device=nnet_output.device, + ) + ) + current_state_infos.append(None) + + current_nnet_outputs = pad_sequence(current_nnet_outputs, batch_first=True) + supervision_segments = torch.tensor( + # seq_index, start_time, duration + [[i, 0, current_num_frames[i]] for i in range(args.num_streams)], + dtype=torch.int32, + ) + logging.info(f"supervision_segments : {supervision_segments}") + dense_fsa_vec = k2.DenseFsaVec(current_nnet_outputs, supervision_segments) + lattice, current_state_infos = intersector.decode( + dense_fsa_vec, current_state_infos + ) + + best_path = one_best_decoding(lattice=lattice, use_double_scores=True) + symbol_ids = get_aux_labels(best_path) + + if args.method == "ctc-decoding": + hyps = ["".join([token_sym_table[i] for i in ids]) for ids in symbol_ids] + else: + assert args.method == "1best", args.method + hyps = [" ".join([word_sym_table[i] for i in ids]) for ids in symbol_ids] + + s = "\n" + for i in range(len(current_wave_ids)): + state_infos[current_wave_ids[i]] = current_state_infos[i] + results[current_wave_ids[i]] = hyps[i].replace("▁", " ").strip() + s += f"{args.sound_files[current_wave_ids[i]]}:\n" + s += f"{results[current_wave_ids[i]]}\n\n" + logging.info(s) + + torch.save(lattice.as_dict(), "online.pt") + + s = "\n" + for filename, hyp in zip(args.sound_files, results): + s += f"{filename}:\n{hyp}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() From b3a84c751612bc2a2861633de40cfac14f3a8741 Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 11 Jul 2023 13:05:56 +0800 Subject: [PATCH 07/22] Minor fixes --- k2/csrc/intersect_dense_pruned.cu | 37 +++++++++++-------------------- 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/k2/csrc/intersect_dense_pruned.cu b/k2/csrc/intersect_dense_pruned.cu index 339269277..6f91de875 100644 --- a/k2/csrc/intersect_dense_pruned.cu +++ b/k2/csrc/intersect_dense_pruned.cu @@ -413,11 +413,11 @@ class MultiGraphDenseIntersectPruned { arcs_row_splits1_ptrs.Data()[t] = frames_[t]->arcs.RowSplits(1).Data(); } arcs_data_ptrs.Data()[T] = online_decoding - ? partial_final_frame_->arcs.values.Data(); - : frames_[T]->arcs.values.Data() + ? partial_final_frame_->arcs.values.Data() + : frames_[T]->arcs.values.Data(); arcs_row_splits1_ptrs.Data()[T] = - online_decoding ? partial_final_frame_->arcs.RowSplits(1).Data(); - : frames_[T]->arcs.RowSplits(1).Data() + online_decoding ? partial_final_frame_->arcs.RowSplits(1).Data() + : frames_[T]->arcs.RowSplits(1).Data(); // transfer to GPU if we're using a GPU arcs_data_ptrs = arcs_data_ptrs.To(c_); @@ -444,12 +444,8 @@ class MultiGraphDenseIntersectPruned { Array1 num_extra_states(c_, num_fsas + 1); int32_t *num_extra_states_data = num_extra_states.Data(); K2_EVAL(c_, num_fsas, lambda_set_num_extra_states, (int32_t i) -> void { - int32_t final_t; - if (online_decoding) { - final_t = final_t_data[i]; - } else { - final_t = b_fsas_row_splits1[i+1] - b_fsas_row_splits1[i]; - } + int32_t final_t = online_decoding ? final_t_data[i] + : b_fsas_row_splits1[i+1] - b_fsas_row_splits1[i]; int32_t *arcs_row_splits1_data = arcs_row_splits1_ptrs_data[final_t]; int32_t num_states_final_t = arcs_row_splits1_data[i + 1] - @@ -559,16 +555,6 @@ class MultiGraphDenseIntersectPruned { } } - int32_t fsa_id = oarc_idx0, - b_fsas_idx0x = b_fsas_row_splits1[fsa_id], - b_fsas_idx01 = b_fsas_idx0x + t, - // Use arc_label instead of arc.label to keep track of - // the origial arc index in b_fsas when allow_partial == true. - // Then arc_map_b storages the "correct" arc index instead of - // the non-exist manually added arc pointing to super-final state. - b_fsas_idx2 = (arc_label + 1), - b_fsas_arc_idx012 = b_fsas_idx01 * b_fsas_num_cols + b_fsas_idx2; - arc.score = arc_info.arc_loglike; arcs_out_data[oarc_idx0123] = arc; @@ -578,7 +564,11 @@ class MultiGraphDenseIntersectPruned { int32_t fsa_id = oarc_idx0, b_fsas_idx0x = b_fsas_row_splits1[fsa_id], b_fsas_idx01 = b_fsas_idx0x + t, - b_fsas_idx2 = (arc.label + 1), + // Use arc_label instead of arc.label to keep track of + // the origial arc index in b_fsas when allow_partial == true. + // Then arc_map_b storages the "correct" arc index instead of + // the non-exist manually added arc pointing to super-final state. + b_fsas_idx2 = (arc_label + 1), b_fsas_arc_idx012 = b_fsas_idx01 * b_fsas_num_cols + b_fsas_idx2; arc_map_b_data[oarc_idx0123] = b_fsas_arc_idx012; } @@ -705,7 +695,6 @@ class MultiGraphDenseIntersectPruned { NVTX_RANGE(K2_FUNC); Ragged &states = cur_frame->states; const StateInfo *state_values = states.values.Data(); - float minus_inf = -std::numeric_limits::infinity(); // in a_fsas_ (the decoding graphs), maps from state_idx01 to arc_idx01x. const int32_t *fsa_arc_splits = a_fsas_.shape.RowSplits(2).Data(); @@ -1663,7 +1652,7 @@ void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas, auto b_fsas_p = std::make_shared(b_fsas); intersector.Intersect(b_fsas_p); - intersector.FormatOutput(out, arc_map_a, arc_map_b, true); + intersector.FormatOutput(out, arc_map_a, arc_map_b); } OnlineDenseIntersecter::OnlineDenseIntersecter(FsaVec &a_fsas, @@ -1756,7 +1745,7 @@ void OnlineDenseIntersecter::Decode(DenseFsaVec &b_fsas, } const auto new_frames = impl_->OnlineIntersect(b_fsas_p, frames, beams); - impl_->FormatOutput(ofsa, arc_map_a, nullptr/*arc_map_b*/, false); + impl_->FormatOutput(ofsa, arc_map_a, nullptr/*arc_map_b*/); int32_t frames_num = new_frames->size(); std::vector *> frame_states_ptr_vec(frames_num); From 664202152b83d00fe5333c196771bc860bd3194a Mon Sep 17 00:00:00 2001 From: pkufool Date: Sun, 16 Jul 2023 07:49:55 +0800 Subject: [PATCH 08/22] Add is_final --- k2/csrc/intersect_dense_pruned.cu | 82 ++++++++++++++++++------------- k2/csrc/intersect_dense_pruned.h | 9 ++-- k2/python/csrc/torch/fsa_algo.cu | 20 +++++--- k2/torch/bin/online_decode.cu | 17 +++---- k2/torch/bin/online_decode.py | 15 +++--- 5 files changed, 82 insertions(+), 61 deletions(-) diff --git a/k2/csrc/intersect_dense_pruned.cu b/k2/csrc/intersect_dense_pruned.cu index 6f91de875..91f6b88a3 100644 --- a/k2/csrc/intersect_dense_pruned.cu +++ b/k2/csrc/intersect_dense_pruned.cu @@ -95,6 +95,7 @@ class MultiGraphDenseIntersectPruned { forward_semaphore_(1), final_t_(a_fsas.Context(), num_seqs, 0) { NVTX_RANGE(K2_FUNC); + c_ = GetContext(a_fsas.shape); K2_CHECK_GT(search_beam, 0); K2_CHECK_GT(output_beam, 0); @@ -147,7 +148,7 @@ class MultiGraphDenseIntersectPruned { log-likes of each phone. A series of sequences of (in general) different length. */ - void Intersect(std::shared_ptr &b_fsas) { + void Intersect(DenseFsaVec *b_fsas) { /* T is the largest number of (frames+1) of neural net output, or the largest number of frames of log-likelihoods we count the final frame with (0, @@ -254,9 +255,10 @@ class MultiGraphDenseIntersectPruned { generate `DecodeStateInfo` for each sequences. */ const std::vector>* OnlineIntersect( - std::shared_ptr &b_fsas, + DenseFsaVec *b_fsas, std::vector> &frames, - Array1 &beams) { + Array1 &beams, + Array1 &is_final) { /* T is the largest number (frames+1) of neural net output currently received, or the largest number of frames of log-likelihoods we count the @@ -275,6 +277,7 @@ class MultiGraphDenseIntersectPruned { b_fsas_ = b_fsas; frames_.swap(frames); dynamic_beams_ = beams.To(c_); + is_final_ = is_final.To(c_); T_ = frames_.size() - 1; // -1 here because we already put the initial frame info to frames_ @@ -294,7 +297,7 @@ class MultiGraphDenseIntersectPruned { frames_.push_back(PropagateForward<40>(t, frames_.back().get())); } if (t == b_fsas_->shape.MaxSize(1)) { - int32_t start = std::max(0, T_ - 5); + int32_t start = std::max(0, T_ - 3); PruneTimeRange(start, T_ + t); } } @@ -631,6 +634,7 @@ class MultiGraphDenseIntersectPruned { Array1 cutoffs(c_, num_fsas); float *cutoffs_data = cutoffs.Data(); + bool online_decoding = online_decoding_; K2_EVAL( c_, num_fsas, lambda_set_beam_and_cutoffs, (int32_t i)->void { float best_loglike = max_per_fsa_data[i], @@ -641,7 +645,7 @@ class MultiGraphDenseIntersectPruned { float current_min_active = min_active; // Do less pruning on the few final frames, to ensure we don't prune // away final states. - if (t + 5 >= final_t) { + if (!online_decoding && t + 5 >= final_t) { current_min_active = max(min_active, max_active / 2); } if (active_states <= max_active) { @@ -662,7 +666,7 @@ class MultiGraphDenseIntersectPruned { } else { // We modify dynamic_beam when max_active violated only if it's not // last few frames, in order to avoid final states pruning. - if (t + 5 < final_t) { + if (online_decoding || t + 5 < final_t) { // We violated the max_active constraint -> decrease beam if (dynamic_beam > default_beam) dynamic_beam = default_beam; @@ -769,6 +773,9 @@ class MultiGraphDenseIntersectPruned { }); } + bool online_decoding = online_decoding_; + bool *is_final_data = is_final_.Data(); + K2_EVAL( c_, ai.values.Dim(), ai_lambda, (int32_t ai_arc_idx012)->void { int32_t ai_state_idx01 = ai_row_ids2[ai_arc_idx012], @@ -790,8 +797,11 @@ class MultiGraphDenseIntersectPruned { float acoustic_score = scores_acc(scores_idx01, scores_idx2); auto dest_state = arc.dest_state; auto final_t = b_fsas_row_splits1[ai_fsa_idx0+1] - b_fsas_row_splits1[ai_fsa_idx0]; - if (final_t - 1 == t && !has_valid_final_arc_data[ai_fsa_idx0] && - allow_partial) { + bool is_final_chunk = is_final_data[ai_fsa_idx0]; + + if ((online_decoding && final_t - 1 == t && !is_final_chunk) || + (final_t - 1 == t && !has_valid_final_arc_data[ai_fsa_idx0] && + allow_partial)) { int32_t a_fsas_idx0 = a_fsas_row_ids1[sinfo.a_fsas_state_idx01]; // state_idx1 is 0-based. // So "-1" is used when calculating a_fsas_final_state_idx1. @@ -1558,7 +1568,7 @@ class MultiGraphDenseIntersectPruned { int32_t a_fsas_stride_; // 1 if we use a different FSA per sequence // (a_fsas_.Dim0() > 1), 0 if the decoding graph is // shared (a_fsas_.Dim0() == 1). - std::shared_ptr b_fsas_; // nnet_output to be decoded. + DenseFsaVec *b_fsas_; // nnet_output to be decoded. int32_t num_seqs_; // the number of sequences to decode at a time, // i.e. batch size for decoding. int32_t T_; // equals to b_fsas_->shape.MaxSize(1), for @@ -1575,14 +1585,15 @@ class MultiGraphDenseIntersectPruned { bool online_decoding_; // true for online decoding. Array1 final_t_; // record the final frame id of each DenseFsa. - + Array1 is_final_; // For online decoding, it has a dimension of + // b_fsas_->Dim0() indicating whether this is + // the final chunk of current sequence. std::unique_ptr partial_final_frame_; // store the final frame for // partial results int32_t state_map_fsa_stride_; // state_map_fsa_stride_ is a_fsas_.TotSize(1) // if a_fsas_.Dim0() == 1, else 0. - Hash state_map_; // state_map_ maps from: // key == (state_map_fsa_stride_*n) + a_fsas_state_idx01, // where n is the fsa_idx, i.e. the index into b_fsas_ @@ -1649,17 +1660,14 @@ void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas, max_active_states, allow_partial, online_decoding); - - auto b_fsas_p = std::make_shared(b_fsas); - intersector.Intersect(b_fsas_p); + intersector.Intersect(&b_fsas); intersector.FormatOutput(out, arc_map_a, arc_map_b); } OnlineDenseIntersecter::OnlineDenseIntersecter(FsaVec &a_fsas, int32_t num_seqs, float search_beam, float output_beam, - int32_t min_active_states, int32_t max_active_states) { + int32_t min_active_states, int32_t max_active_states, bool allow_partial) { bool online_decoding = true; - bool allow_partial = true; K2_CHECK_EQ(a_fsas.NumAxes(), 3); c_ = a_fsas.Context(); search_beam_ = search_beam; @@ -1684,10 +1692,9 @@ OnlineDenseIntersecter::~OnlineDenseIntersecter(){ } void OnlineDenseIntersecter::Decode(DenseFsaVec &b_fsas, - std::vector> *decode_states, + std::vector *decode_states, FsaVec *ofsa, Array1 *arc_map_a) { - auto b_fsas_p = std::make_shared(b_fsas); - int32_t num_seqs = b_fsas_p->shape.Dim0(); + int32_t num_seqs = b_fsas.shape.Dim0(); K2_CHECK_EQ(num_seqs, static_cast(decode_states->size())); @@ -1696,27 +1703,31 @@ void OnlineDenseIntersecter::Decode(DenseFsaVec &b_fsas, Array1 beams(GetCpuContext(), num_seqs); float *beams_data = beams.Data(); + Array1 is_final(GetCpuContext(), num_seqs); + bool *is_final_data = is_final.Data(); for (int32_t i = 0; i < num_seqs; ++i) { - // initialization - if (!decode_states->at(i)) { - DecodeStateInfo info; + DecodeStateInfo *decode_state_ptr = decode_states->at(i); + K2_CHECK(decode_state_ptr); + // initialization; NumAxes == 1 means this is an uninitialized Ragged + if (decode_state_ptr->states.NumAxes() == 1) { StateInfo sinfo; // start state of decoding graph sinfo.a_fsas_state_idx01 = 0; sinfo.forward_loglike = FloatToOrderedInt(0.0); - info.states = Ragged( + decode_state_ptr->states = Ragged( RegularRaggedShape(c_, 1, 1), Array1(c_, std::vector{sinfo})); - info.arcs = Ragged(RaggedShape(c_, "[ [ [ x ] ] ]"), + decode_state_ptr->arcs = Ragged(RaggedShape(c_, "[ [ [ x ] ] ]"), Array1(c_, std::vector{ArcInfo()})); - info.beam = search_beam_; - decode_states->at(i) = std::make_shared(info); + decode_state_ptr->beam = search_beam_; + decode_state_ptr->is_final = false; } - seq_states_ptr_vec[i] = &(decode_states->at(i)->states); - seq_arcs_ptr_vec[i] = &(decode_states->at(i)->arcs); - beams_data[i] = decode_states->at(i)->beam; + seq_states_ptr_vec[i] = &(decode_state_ptr->states); + seq_arcs_ptr_vec[i] = &(decode_state_ptr->arcs); + beams_data[i] = decode_state_ptr->beam; + is_final_data[i] = decode_state_ptr->is_final; } auto stack_states = Stack(0, num_seqs, seq_states_ptr_vec.data()); @@ -1744,7 +1755,9 @@ void OnlineDenseIntersecter::Decode(DenseFsaVec &b_fsas, frames[i] = std::make_unique(info); } - const auto new_frames = impl_->OnlineIntersect(b_fsas_p, frames, beams); + const auto new_frames = impl_->OnlineIntersect( + &b_fsas, frames, beams, is_final); + impl_->FormatOutput(ofsa, arc_map_a, nullptr/*arc_map_b*/); int32_t frames_num = new_frames->size(); @@ -1766,11 +1779,10 @@ void OnlineDenseIntersecter::Decode(DenseFsaVec &b_fsas, beams = impl_->GetBeams().To(GetCpuContext()); beams_data = beams.Data(); for (int32_t i = 0; i < num_seqs; ++i) { - DecodeStateInfo info; - info.states = seq_states_vec[i]; - info.arcs = seq_arcs_vec[i]; - info.beam = beams_data[i]; - decode_states->at(i) = std::make_shared(info); + DecodeStateInfo* decode_state_ptr = decode_states->at(i); + decode_state_ptr->states = seq_states_vec[i]; + decode_state_ptr->arcs = seq_arcs_vec[i]; + decode_state_ptr->beam = beams_data[i]; } } diff --git a/k2/csrc/intersect_dense_pruned.h b/k2/csrc/intersect_dense_pruned.h index 7f19405f1..f05b86308 100644 --- a/k2/csrc/intersect_dense_pruned.h +++ b/k2/csrc/intersect_dense_pruned.h @@ -117,7 +117,7 @@ class MultiGraphDenseIntersectPruned; // DecodeStateInfo contains the history decoding states for each sequence, this // is normally constructed from `frames_` in MultiGraphDenseIntersectPruned -// bu using `Stack` and `Unstack`. +// by using `Stack` and `Unstack`. struct DecodeStateInfo { // States that survived for the previously decoded frames. Indexed // [frame_idx][state_idx], state_idx just enumerates the active states @@ -135,6 +135,9 @@ struct DecodeStateInfo { // current search beam for this sequence float beam; + + // True if the chunk to be decoded is the final chunk + bool is_final; }; @@ -170,7 +173,7 @@ class OnlineDenseIntersecter { public: OnlineDenseIntersecter(FsaVec &a_fsas, int32_t num_seqs, float search_beam, float output_beam, int32_t min_states, - int32_t max_states); + int32_t max_states, bool allow_partial=true); /* Does intersection/composition for current chunk of nnet_output(given by a DenseFsaVec), sequences in every chunk may come from different @@ -194,7 +197,7 @@ class OnlineDenseIntersecter { will have been assigned to this location. */ void Decode(DenseFsaVec &b_fsas, - std::vector> *decode_states, + std::vector *decode_states, FsaVec *ofsa, Array1 *arc_map_a); ContextPtr &Context() { return c_;} diff --git a/k2/python/csrc/torch/fsa_algo.cu b/k2/python/csrc/torch/fsa_algo.cu index bf21a0d1a..aa681e6c5 100644 --- a/k2/python/csrc/torch/fsa_algo.cu +++ b/k2/python/csrc/torch/fsa_algo.cu @@ -753,8 +753,10 @@ static void PybindLevenshteinGraph(py::module &m) { static void PybindDecodeStateInfo(py::module &m) { using PyClass = DecodeStateInfo; - py::class_> state_info(m, - "DecodeStateInfo"); + py::class_> state_info( + m, "DecodeStateInfo"); + state_info.def(py::init<>()); + state_info.def_readwrite("is_final", &PyClass::is_final); } static void PybindOnlineDenseIntersecter(py::module &m) { @@ -765,15 +767,17 @@ static void PybindOnlineDenseIntersecter(py::module &m) { py::init([](FsaVec &decoding_graph, int32_t num_streams, float search_beam, float output_beam, int32_t min_active_states, - int32_t max_active_states) -> std::unique_ptr { + int32_t max_active_states, + bool allow_partial) -> std::unique_ptr { DeviceGuard guard(decoding_graph.Context()); return std::make_unique(decoding_graph, num_streams, search_beam, output_beam, - min_active_states, max_active_states); + min_active_states, max_active_states, + allow_partial); }), py::arg("decoding_graph"), py::arg("num_streams"), py::arg("search_beam"), py::arg("output_beam"), py::arg("min_active_states"), - py::arg("max_active_states")); + py::arg("max_active_states"), py::arg("allow_partial") = true); intersecter.def( "decode", @@ -784,7 +788,11 @@ static void PybindOnlineDenseIntersecter(py::module &m) { DeviceGuard guard(self.Context()); FsaVec ofsa; Array1 arc_map; - self.Decode(dense_fsa_vec, &decode_states, &ofsa, &arc_map); + std::vector decode_states_ptr(decode_states.size()); + for (size_t i = 0; i < decode_states.size(); ++i) { + decode_states_ptr[i] = decode_states[i].get(); + } + self.Decode(dense_fsa_vec, &decode_states_ptr, &ofsa, &arc_map); torch::Tensor arc_map_tensor = ToTorch(arc_map); return std::make_tuple(ofsa, arc_map_tensor, decode_states); }, diff --git a/k2/torch/bin/online_decode.cu b/k2/torch/bin/online_decode.cu index 2c52c7e79..4ad242127 100644 --- a/k2/torch/bin/online_decode.cu +++ b/k2/torch/bin/online_decode.cu @@ -219,7 +219,7 @@ int main(int argc, char *argv[]) { FLAGS_min_activate_states, FLAGS_max_activate_states); // store decode states for each waves - std::vector> states_info(num_waves); + std::vector states_info(num_waves); // decocding results for each waves std::vector texts(num_waves, ""); @@ -231,8 +231,8 @@ int main(int argc, char *argv[]) { // simulate asynchronous decoding while (true) { - std::vector> current_states_info( - FLAGS_num_streams); + k2::DecodeStateInfo dummy_state_info; + std::vector current_states_info; std::vector num_frame; std::vector current_nnet_output; // which waves we are decoding now @@ -242,12 +242,13 @@ int main(int argc, char *argv[]) { // this wave is done if (num_frames[i] == 0) continue; - current_states_info[current_wave_ids.size()] = states_info[i]; + current_states_info.push_back(&states_info[i]); current_wave_ids.push_back(i); - if (num_frames[i] < chunk_size * subsampling_factor) { + if (num_frames[i] <= chunk_size * subsampling_factor) { num_frame.push_back(num_frames[i]); num_frames[i] = 0; + states_info[i].is_final = true; } else { num_frame.push_back(chunk_size * subsampling_factor); num_frames[i] -= chunk_size * subsampling_factor; @@ -280,6 +281,7 @@ int main(int argc, char *argv[]) { .device(nnet_output.device()); current_nnet_output.push_back( torch::zeros({chunk_size, nnet_output.size(2)}, opts)); + current_states_info.push_back(&dummy_state_info); } auto sub_nnet_output = torch::stack(current_nnet_output); @@ -303,11 +305,6 @@ int main(int argc, char *argv[]) { decoder.Decode(dense_fsa_vec, ¤t_states_info, &fsa, &graph_arc_map); - // update decoding states - for (size_t i = 0; i < current_wave_ids.size(); ++i) { - states_info[current_wave_ids[i]] = current_states_info[i]; - } - k2::FsaClass lattice(fsa); lattice.CopyAttrs(decoding_graph, k2::Array1ToTorch(graph_arc_map)); diff --git a/k2/torch/bin/online_decode.py b/k2/torch/bin/online_decode.py index 76bb6899f..f56e56fc5 100644 --- a/k2/torch/bin/online_decode.py +++ b/k2/torch/bin/online_decode.py @@ -185,7 +185,7 @@ def main(): max_active_states=10000, ) - state_infos = [None] * len(waves) + state_infos = [DecodeStateInfo()] * len(waves) positions = [0] * len(waves) results = [""] * len(waves) @@ -197,19 +197,21 @@ def main(): for i in range(len(waves)): if positions[i] == num_frames[i]: continue - current_state_infos.append(state_infos[i]) - current_wave_ids.append(i) start = positions[i] - if (num_frames[i] - positions[i]) < args.chunk_size: + if (num_frames[i] - positions[i]) <= args.chunk_size: current_num_frames.append(num_frames[i] - positions[i]) end = num_frames[i] positions[i] = num_frames[i] + state_infos[i].is_final = True else: current_num_frames.append(args.chunk_size) end = positions[i] + args.chunk_size positions[i] += args.chunk_size + current_state_infos.append(state_infos[i]) + current_wave_ids.append(i) current_nnet_outputs.append(nnet_output[i, start:end, :]) + if len(current_wave_ids) == args.num_streams: break if len(current_wave_ids) == 0: @@ -222,7 +224,7 @@ def main(): device=nnet_output.device, ) ) - current_state_infos.append(None) + current_state_infos.append(DecodeStateInfo()) current_nnet_outputs = pad_sequence(current_nnet_outputs, batch_first=True) supervision_segments = torch.tensor( @@ -244,6 +246,7 @@ def main(): else: assert args.method == "1best", args.method hyps = [" ".join([word_sym_table[i] for i in ids]) for ids in symbol_ids] + logging.info(f"hyps : {hyps}") s = "\n" for i in range(len(current_wave_ids)): @@ -253,8 +256,6 @@ def main(): s += f"{results[current_wave_ids[i]]}\n\n" logging.info(s) - torch.save(lattice.as_dict(), "online.pt") - s = "\n" for filename, hyp in zip(args.sound_files, results): s += f"{filename}:\n{hyp}\n\n" From a90a122b1093fa15b187e5a21d4c0776e23938fc Mon Sep 17 00:00:00 2001 From: pkufool Date: Sun, 16 Jul 2023 17:08:43 +0800 Subject: [PATCH 09/22] Minor fixes --- k2/csrc/intersect_dense_pruned.cu | 19 +++++++++++++------ k2/torch/bin/online_decode.cu | 2 +- k2/torch/bin/online_decode.py | 1 - 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/k2/csrc/intersect_dense_pruned.cu b/k2/csrc/intersect_dense_pruned.cu index 91f6b88a3..59932e4fd 100644 --- a/k2/csrc/intersect_dense_pruned.cu +++ b/k2/csrc/intersect_dense_pruned.cu @@ -677,7 +677,7 @@ class MultiGraphDenseIntersectPruned { } // no pruning on last frame; we want all final-arcs. // -1 because t starts from 0. - if (t == final_t - 1) dynamic_beam = 1.0e+10; + if (!online_decoding && t == final_t - 1) dynamic_beam = 1.0e+10; dynamic_beams_data[i] = dynamic_beam; cutoffs_data[i] = best_loglike - dynamic_beam; @@ -774,7 +774,10 @@ class MultiGraphDenseIntersectPruned { } bool online_decoding = online_decoding_; - bool *is_final_data = is_final_.Data(); + bool *is_final_data; + if (online_decoding) { + is_final_data = is_final_.Data(); + } K2_EVAL( c_, ai.values.Dim(), ai_lambda, (int32_t ai_arc_idx012)->void { @@ -797,11 +800,15 @@ class MultiGraphDenseIntersectPruned { float acoustic_score = scores_acc(scores_idx01, scores_idx2); auto dest_state = arc.dest_state; auto final_t = b_fsas_row_splits1[ai_fsa_idx0+1] - b_fsas_row_splits1[ai_fsa_idx0]; - bool is_final_chunk = is_final_data[ai_fsa_idx0]; - if ((online_decoding && final_t - 1 == t && !is_final_chunk) || - (final_t - 1 == t && !has_valid_final_arc_data[ai_fsa_idx0] && - allow_partial)) { + bool is_final_chunk = false; + if (online_decoding) { + is_final_chunk = is_final_data[ai_fsa_idx0]; + } + + if (final_t - 1 == t && + ((online_decoding && !is_final_chunk) || + (allow_partial && !has_valid_final_arc_data[ai_fsa_idx0]))) { int32_t a_fsas_idx0 = a_fsas_row_ids1[sinfo.a_fsas_state_idx01]; // state_idx1 is 0-based. // So "-1" is used when calculating a_fsas_final_state_idx1. diff --git a/k2/torch/bin/online_decode.cu b/k2/torch/bin/online_decode.cu index 4ad242127..a345ecce5 100644 --- a/k2/torch/bin/online_decode.cu +++ b/k2/torch/bin/online_decode.cu @@ -227,7 +227,7 @@ int main(int argc, char *argv[]) { std::vector positions(num_waves, 0); int32_t T = nnet_output.size(1); - int32_t chunk_size = 20; // 20 frames per chunk + int32_t chunk_size = 10; // 10 frames per chunk // simulate asynchronous decoding while (true) { diff --git a/k2/torch/bin/online_decode.py b/k2/torch/bin/online_decode.py index f56e56fc5..6d0df70a8 100644 --- a/k2/torch/bin/online_decode.py +++ b/k2/torch/bin/online_decode.py @@ -232,7 +232,6 @@ def main(): [[i, 0, current_num_frames[i]] for i in range(args.num_streams)], dtype=torch.int32, ) - logging.info(f"supervision_segments : {supervision_segments}") dense_fsa_vec = k2.DenseFsaVec(current_nnet_outputs, supervision_segments) lattice, current_state_infos = intersector.decode( dense_fsa_vec, current_state_infos From e93ce564a21e8aad9b5ab45b8a62a92350651658 Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 19 Jul 2023 12:08:28 +0800 Subject: [PATCH 10/22] Minor fixes --- k2/python/csrc/torch/fsa_algo.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/k2/python/csrc/torch/fsa_algo.cu b/k2/python/csrc/torch/fsa_algo.cu index aa681e6c5..24e557145 100644 --- a/k2/python/csrc/torch/fsa_algo.cu +++ b/k2/python/csrc/torch/fsa_algo.cu @@ -753,7 +753,7 @@ static void PybindLevenshteinGraph(py::module &m) { static void PybindDecodeStateInfo(py::module &m) { using PyClass = DecodeStateInfo; - py::class_> state_info( + py::class_ state_info( m, "DecodeStateInfo"); state_info.def(py::init<>()); state_info.def_readwrite("is_final", &PyClass::is_final); @@ -782,15 +782,15 @@ static void PybindOnlineDenseIntersecter(py::module &m) { intersecter.def( "decode", [](PyClass &self, DenseFsaVec &dense_fsa_vec, - std::vector> &decode_states) + std::vector &decode_states) -> std::tuple>> { + std::vector> { DeviceGuard guard(self.Context()); FsaVec ofsa; Array1 arc_map; std::vector decode_states_ptr(decode_states.size()); for (size_t i = 0; i < decode_states.size(); ++i) { - decode_states_ptr[i] = decode_states[i].get(); + decode_states_ptr[i] = &decode_states[i]; } self.Decode(dense_fsa_vec, &decode_states_ptr, &ofsa, &arc_map); torch::Tensor arc_map_tensor = ToTorch(arc_map); From 69a98408a6302d180bf41e4d353c86af3d70200d Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 19 Jul 2023 12:13:25 +0800 Subject: [PATCH 11/22] Fix style --- k2/torch/bin/hlg_decode.py | 22 +++++++++++++++++----- k2/torch/bin/online_decode.py | 34 +++++++++++++++++++++++++--------- 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/k2/torch/bin/hlg_decode.py b/k2/torch/bin/hlg_decode.py index aef48e6e8..be5096e3a 100644 --- a/k2/torch/bin/hlg_decode.py +++ b/k2/torch/bin/hlg_decode.py @@ -22,7 +22,10 @@ def get_parser(): ) parser.add_argument( - "--nn-model", type=str, required=True, help="Path to the jit script model. " + "--nn-model", + type=str, + required=True, + help="Path to the jit script model.", ) parser.add_argument( @@ -144,7 +147,9 @@ def main(): for f in features: feature_len.append(f.shape[0]) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) # Note: We don't use key padding mask for attention during decoding nnet_output, _, _ = model(features) @@ -157,7 +162,10 @@ def main(): logging.info("Use CTC decoding") max_token_id = args.num_classes - 1 - H = k2.ctc_topo(max_token=max_token_id, device=device,) + H = k2.ctc_topo( + max_token=max_token_id, + device=device, + ) lattice = get_lattice( log_prob=log_prob, @@ -187,7 +195,9 @@ def main(): if args.method == "1best": logging.info("Use HLG decoding") - best_path = one_best_decoding(lattice=lattice, use_double_scores=True) + best_path = one_best_decoding( + lattice=lattice, use_double_scores=True + ) hyps = get_aux_labels(best_path) word_sym_table = k2.SymbolTable.from_file(args.words_file) @@ -205,7 +215,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/k2/torch/bin/online_decode.py b/k2/torch/bin/online_decode.py index 6d0df70a8..1e5456ec4 100644 --- a/k2/torch/bin/online_decode.py +++ b/k2/torch/bin/online_decode.py @@ -23,7 +23,10 @@ def get_parser(): ) parser.add_argument( - "--nn-model", type=str, required=True, help="Path to the jit script model. " + "--nn-model", + type=str, + required=True, + help="Path to the jit script model.", ) parser.add_argument( @@ -153,12 +156,13 @@ def main(): for f in features: feature_len.append(f.shape[0]) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) # Note: We don't use key padding mask for attention during decoding nnet_output, _, _ = model(features) num_frames = [x // args.subsampling_factor for x in feature_len] - T = nnet_output.shape[1] if args.method == "ctc-decoding": logging.info("Use CTC decoding") @@ -171,7 +175,9 @@ def main(): else: assert args.method == "1best", args.method logging.info(f"Loading HLG from {args.HLG}") - decoding_graph = k2.Fsa.from_dict(torch.load(args.HLG, map_location="cpu")) + decoding_graph = k2.Fsa.from_dict( + torch.load(args.HLG, map_location="cpu") + ) decoding_graph = decoding_graph.to(device) word_sym_table = k2.SymbolTable.from_file(args.words_file) decoding_graph = k2.Fsa.from_fsas([decoding_graph]) @@ -226,13 +232,17 @@ def main(): ) current_state_infos.append(DecodeStateInfo()) - current_nnet_outputs = pad_sequence(current_nnet_outputs, batch_first=True) + current_nnet_outputs = pad_sequence( + current_nnet_outputs, batch_first=True + ) supervision_segments = torch.tensor( # seq_index, start_time, duration [[i, 0, current_num_frames[i]] for i in range(args.num_streams)], dtype=torch.int32, ) - dense_fsa_vec = k2.DenseFsaVec(current_nnet_outputs, supervision_segments) + dense_fsa_vec = k2.DenseFsaVec( + current_nnet_outputs, supervision_segments + ) lattice, current_state_infos = intersector.decode( dense_fsa_vec, current_state_infos ) @@ -241,10 +251,14 @@ def main(): symbol_ids = get_aux_labels(best_path) if args.method == "ctc-decoding": - hyps = ["".join([token_sym_table[i] for i in ids]) for ids in symbol_ids] + hyps = [ + "".join([token_sym_table[i] for i in ids]) for ids in symbol_ids + ] else: assert args.method == "1best", args.method - hyps = [" ".join([word_sym_table[i] for i in ids]) for ids in symbol_ids] + hyps = [ + " ".join([word_sym_table[i] for i in ids]) for ids in symbol_ids + ] logging.info(f"hyps : {hyps}") s = "\n" @@ -264,7 +278,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() From 2bf8a0278b9eb17ac3d59a84af7eefc0a2835ba9 Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 19 Jul 2023 14:38:26 +0800 Subject: [PATCH 12/22] Fix online intersecter test --- k2/python/tests/online_dense_intersecter_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k2/python/tests/online_dense_intersecter_test.py b/k2/python/tests/online_dense_intersecter_test.py index 2da700bb1..b657264c4 100644 --- a/k2/python/tests/online_dense_intersecter_test.py +++ b/k2/python/tests/online_dense_intersecter_test.py @@ -59,7 +59,7 @@ def test(self): num_chunks = 3 chunk_size = 5 - decode_states = [None] * num_streams + decode_states = [k2.DecodeStateInfo()] * num_streams for i in range(num_chunks): logits = torch.randn( From 62cb12695eca745691714e3b2b010008fbd4281f Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 19 Jul 2023 17:26:15 +0800 Subject: [PATCH 13/22] Fix test --- k2/csrc/intersect_dense_pruned.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/k2/csrc/intersect_dense_pruned.cu b/k2/csrc/intersect_dense_pruned.cu index 59932e4fd..aaf08b5c8 100644 --- a/k2/csrc/intersect_dense_pruned.cu +++ b/k2/csrc/intersect_dense_pruned.cu @@ -95,7 +95,7 @@ class MultiGraphDenseIntersectPruned { forward_semaphore_(1), final_t_(a_fsas.Context(), num_seqs, 0) { NVTX_RANGE(K2_FUNC); - + T_ = 0; c_ = GetContext(a_fsas.shape); K2_CHECK_GT(search_beam, 0); K2_CHECK_GT(output_beam, 0); @@ -774,7 +774,7 @@ class MultiGraphDenseIntersectPruned { } bool online_decoding = online_decoding_; - bool *is_final_data; + bool *is_final_data = nullptr; if (online_decoding) { is_final_data = is_final_.Data(); } From f8ec30d805ffdac4deee73a1a215dbe108adc956 Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 19 Jul 2023 17:33:02 +0800 Subject: [PATCH 14/22] Fix cpp style --- k2/csrc/intersect_dense_pruned.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k2/csrc/intersect_dense_pruned.h b/k2/csrc/intersect_dense_pruned.h index f05b86308..bdfb9b318 100644 --- a/k2/csrc/intersect_dense_pruned.h +++ b/k2/csrc/intersect_dense_pruned.h @@ -173,7 +173,7 @@ class OnlineDenseIntersecter { public: OnlineDenseIntersecter(FsaVec &a_fsas, int32_t num_seqs, float search_beam, float output_beam, int32_t min_states, - int32_t max_states, bool allow_partial=true); + int32_t max_states, bool allow_partial = true); /* Does intersection/composition for current chunk of nnet_output(given by a DenseFsaVec), sequences in every chunk may come from different From d49f76748790df8d89e9d43fcd734b0f4b5eff90 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 3 Aug 2023 15:21:27 +0800 Subject: [PATCH 15/22] Support decoding with a wav scp --- k2/python/k2/dense_fsa_vec.py | 2 +- k2/python/k2/online_dense_intersecter.py | 5 + k2/torch/bin/hlg_decode.py | 213 +++++++++----- k2/torch/bin/online_decode.py | 351 ++++++++++++++--------- 4 files changed, 372 insertions(+), 199 deletions(-) diff --git a/k2/python/k2/dense_fsa_vec.py b/k2/python/k2/dense_fsa_vec.py index 37c7e040c..580fc21ce 100644 --- a/k2/python/k2/dense_fsa_vec.py +++ b/k2/python/k2/dense_fsa_vec.py @@ -102,7 +102,7 @@ def __init__(self, segment_index, start_frame, duration = segment assert 0 <= segment_index < N assert 0 <= start_frame < T - assert duration > 0 + assert duration >= 0 assert start_frame + duration <= T + allow_truncate offset = segment_index * T end_frame = min(start_frame + duration, T) # exclusive diff --git a/k2/python/k2/online_dense_intersecter.py b/k2/python/k2/online_dense_intersecter.py index 4faebd5b0..b44713b23 100644 --- a/k2/python/k2/online_dense_intersecter.py +++ b/k2/python/k2/online_dense_intersecter.py @@ -91,6 +91,7 @@ def __init__( decode_states[1] = new_decode_states[1] ... """ + self.num_streams_ = num_streams self.decoding_graph = decoding_graph self.device = decoding_graph.device self.intersecter = _k2.OnlineDenseIntersecter( @@ -102,6 +103,10 @@ def __init__( max_active_states, ) + @property + def num_streams(self) -> int: + return self.num_streams_ + def decode( self, dense_fsas: DenseFsaVec, decode_states: List[DecodeStateInfo] ) -> Tuple[Fsa, List[DecodeStateInfo]]: diff --git a/k2/torch/bin/hlg_decode.py b/k2/torch/bin/hlg_decode.py index be5096e3a..714538eb0 100644 --- a/k2/torch/bin/hlg_decode.py +++ b/k2/torch/bin/hlg_decode.py @@ -1,7 +1,8 @@ import argparse import logging import math -from typing import List +import os +from typing import Any, Dict, List, Optional, Tuple import k2 import kaldifeat @@ -68,10 +69,29 @@ def get_parser(): """, ) + parser.add_argument( + "--wav-scp", + type=str, + help="""The audio lists to transcribe in wav.scp format""", + ) + + parser.add_argument( + "--output-file", + type=str, + help="The file to write out results to, only used when giving --wav-scp", + ) + + parser.add_argument( + "--batch-size", + type=int, + default=5, + help="The number of wavs in a batch.", + ) + parser.add_argument( "sound_files", type=str, - nargs="+", + nargs="*", help="The input sound file(s) to transcribe. " "Supported formats are those supported by torchaudio.load(). " "For example, wav and flac are supported. " @@ -104,6 +124,61 @@ def read_sound_files( return ans +def decode_one_batch( + params: object, + batch: List[Tuple[str, str]], + model: torch.nn.Module, + feature_extractor: kaldifeat.Fbank, + decoding_graph: k2.Fsa, + token_sym_table: Optional[k2.SymbolTable] = None, + word_sym_table: Optional[k2.SymbolTable] = None, +) -> Dict[str, str]: + device = params.device + filenames = [x[1] for x in batch] + waves = read_sound_files( + filenames=filenames, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + features = feature_extractor(waves) + + feature_len = [] + for f in features: + feature_len.append(f.shape[0]) + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + # Note: We don't use key padding mask for attention during decoding + nnet_output, _, _ = model(features) + + log_prob = torch.nn.functional.log_softmax(nnet_output, dim=-1) + log_prob_len = torch.tensor(feature_len) // params.subsampling_factor + log_prob_len = log_prob_len.to(device) + + lattice = get_lattice( + log_prob=log_prob, + log_prob_len=log_prob_len, + decoding_graph=decoding_graph, + subsampling_factor=params.subsampling_factor, + ) + best_path = one_best_decoding(lattice=lattice, use_double_scores=True) + + hyps = get_aux_labels(best_path) + + if params.method == "ctc-decoding": + hyps = ["".join([token_sym_table[i] for i in ids]) for ids in hyps] + else: + assert params.method == "1best", params.method + hyps = [" ".join([word_sym_table[i] for i in ids]) for ids in hyps] + + results = {} + for i, hyp in enumerate(hyps): + results[batch[i][0]] = hyp.replace("▁", " ").strip() + return results + + def main(): parser = get_parser() args = parser.parse_args() @@ -113,11 +188,36 @@ def main(): args.feature_dim = 80 args.num_classes = 500 + wave_list: List[Tuple[str, str]] = [] + if args.wav_scp is not None: + assert os.path.isfile( + args.wav_scp + ), f"wav_scp not exists : {args.wav_scp}" + assert ( + args.output_file is not None + ), "You should provide output_file when using wav_scp" + with open(args.wav_scp, "r") as f: + for line in f: + toks = line.strip().split() + assert len(toks) == 2, toks + if not os.path.isfile(toks[1]): + logging.warning(f"File {toks[1]} not exists, skipping.") + continue + wave_list.append(toks) + else: + assert len(args.sound_files) > 0, "No wav_scp or waves provided." + for i, f in enumerate(args.sound_files): + if not os.path.isfile(f): + logging.warning(f"File {f} not exists, skipping.") + continue + wave_list.append((i, f)) + device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) + args.device = device - logging.info(f"device: {device}") + logging.info(f"params : {args}") logging.info("Creating model") model = torch.jit.load(args.nn_model) @@ -134,82 +234,59 @@ def main(): fbank = kaldifeat.Fbank(opts) - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, expected_sample_rate=args.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - feature_len = [] - for f in features: - feature_len.append(f.shape[0]) - - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) - - # Note: We don't use key padding mask for attention during decoding - nnet_output, _, _ = model(features) - - log_prob = torch.nn.functional.log_softmax(nnet_output, dim=-1) - log_prob_len = torch.tensor(feature_len) // args.subsampling_factor - log_prob_len = log_prob_len.to(device) - + token_sym_table = None + word_sym_table = None if args.method == "ctc-decoding": logging.info("Use CTC decoding") max_token_id = args.num_classes - 1 - - H = k2.ctc_topo( - max_token=max_token_id, - device=device, - ) - - lattice = get_lattice( - log_prob=log_prob, - log_prob_len=log_prob_len, - decoding_graph=H, - subsampling_factor=args.subsampling_factor, - ) - - best_path = one_best_decoding(lattice=lattice, use_double_scores=True) - token_ids = get_aux_labels(best_path) + decoding_graph = k2.ctc_topo(max_token=max_token_id, device=device,) token_sym_table = k2.SymbolTable.from_file(args.tokens) - - hyps = ["".join([token_sym_table[i] for i in ids]) for ids in token_ids] - else: assert args.method == "1best", args.method logging.info(f"Loading HLG from {args.HLG}") - HLG = k2.Fsa.from_dict(torch.load(args.HLG, map_location="cpu")) - HLG = HLG.to(device) - - lattice = get_lattice( - log_prob=log_prob, - log_prob_len=log_prob_len, - decoding_graph=HLG, - subsampling_factor=args.subsampling_factor, + decoding_graph = k2.Fsa.from_dict( + torch.load(args.HLG, map_location="cpu") ) - - if args.method == "1best": - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=True - ) - - hyps = get_aux_labels(best_path) + decoding_graph = decoding_graph.to(device) word_sym_table = k2.SymbolTable.from_file(args.words_file) - hyps = [" ".join([word_sym_table[i] for i in ids]) for ids in hyps] + decoding_graph = k2.Fsa.from_fsas([decoding_graph]) + + results = {} + start = 0 + while start + args.batch_size <= len(wave_list): + + if start % 100 == 0: + logging.info(f"Decoding progress: {start}/{len(wave_list)}.") + + res = decode_one_batch( + params=args, + batch=wave_list[start : start + args.batch_size], + model=model, + feature_extractor=fbank, + decoding_graph=decoding_graph, + token_sym_table=token_sym_table, + word_sym_table=word_sym_table, + ) + start += args.batch_size - s = "\n" - for filename, hyp in zip(args.sound_files, hyps): - words = hyp.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" - logging.info(s) + results.update(res) - torch.save(lattice.as_dict(), "offline.pt") + logging.info(f"results : {results}") + + if args.wav_scp is not None: + output_dir = os.path.dirname(args.output_file) + if output_dir != "": + os.makedirs(output_dir, exist_ok=True) + with open(args.output_file, "w", encoding="utf-8") as f: + for x in wave_list: + f.write(x[0] + "\t" + results[x[0]] + "\n") + logging.info(f"Decoding results are written to {args.output_file}") + else: + s = "\n" + logging.info(f"results : {results}") + for x in wave_list: + s += f"{x[1]}:\n{results[x[0]]}\n\n" + logging.info(s) logging.info("Decoding Done") diff --git a/k2/torch/bin/online_decode.py b/k2/torch/bin/online_decode.py index 1e5456ec4..33236980d 100644 --- a/k2/torch/bin/online_decode.py +++ b/k2/torch/bin/online_decode.py @@ -1,20 +1,21 @@ import argparse import logging import math -from typing import List +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple import k2 import kaldifeat import torch import torchaudio -from torch.nn.utils.rnn import pad_sequence - from k2 import ( DecodeStateInfo, OnlineDenseIntersecter, one_best_decoding, get_aux_labels, ) +from torch.nn.utils.rnn import pad_sequence def get_parser(): @@ -76,10 +77,29 @@ def get_parser(): help="""The number of streams that can be run in parallel.""", ) + parser.add_argument( + "--wav-scp", + type=str, + help="""The audio lists to transcribe in wav.scp format""", + ) + + parser.add_argument( + "--output-file", + type=str, + help="The file to write out results to, only used when giving --wav-scp", + ) + + parser.add_argument( + "--print-partial", + dest="print_partial", + action="store_true", + help="Whether print partial results.", + ) + parser.add_argument( "sound_files", type=str, - nargs="+", + nargs="*", help="The input sound file(s) to transcribe. " "Supported formats are those supported by torchaudio.load(). " "For example, wav and flac are supported. " @@ -89,27 +109,151 @@ def get_parser(): return parser -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans +@dataclass +class DecodeStream: + # The identifier of wavs. + utt_id: str + # The total number of frames for current nnet_output. + num_frames: int + # The output of encoder. + nnet_output: torch.Tensor + # Current position, index in to feature. + position: int + # Decode state for intersect_dense_pruned. + state_info: DecodeStateInfo + # Current decoding result. + result: str + + +def decode_one_chunk( + params: object, + intersector: k2.OnlineDenseIntersecter, + streams: List[DecodeStream], + token_sym_table: Optional[k2.SymbolTable] = None, + word_sym_table: Optional[k2.SymbolTable] = None, +) -> List[int]: + assert params.num_streams == intersector.num_streams, ( + params.num_streams, + intersector.num_streams, + ) + current_state_infos = [] + current_nnet_outputs = [] + current_num_frames = [] + finised_streams = [] + for i, stream in enumerate(streams): + start = stream.position + if (stream.num_frames - stream.position) <= params.chunk_size: + current_num_frames.append(stream.num_frames - stream.position) + end = stream.num_frames + stream.position = stream.num_frames + stream.state_info.is_final = True + finised_streams.append(i) + else: + current_num_frames.append(params.chunk_size) + end = stream.position + params.chunk_size + stream.position += params.chunk_size + current_state_infos.append(stream.state_info) + current_nnet_outputs.append(stream.nnet_output[start:end, :]) + + while len(current_num_frames) < params.num_streams: + current_num_frames.append(0) + current_nnet_outputs.append( + torch.zeros( + (params.chunk_size, params.num_classes), device=params.device, + ) + ) + current_state_infos.append(DecodeStateInfo()) + + current_nnet_outputs = pad_sequence(current_nnet_outputs, batch_first=True) + supervision_segments = torch.tensor( + # seq_index, start_time, duration + [[i, 0, current_num_frames[i]] for i in range(params.num_streams)], + dtype=torch.int32, + ) + dense_fsa_vec = k2.DenseFsaVec(current_nnet_outputs, supervision_segments) + lattice, current_state_infos = intersector.decode( + dense_fsa_vec, current_state_infos + ) + + best_path = one_best_decoding(lattice=lattice, use_double_scores=True) + symbol_ids = get_aux_labels(best_path) + + if params.method == "ctc-decoding": + assert token_sym_table is not None + hyps = [ + "".join([token_sym_table[i] for i in ids]) for ids in symbol_ids + ] + else: + assert word_sym_table is not None + assert params.method == "1best", params.method + hyps = [ + " ".join([word_sym_table[i] for i in ids]) for ids in symbol_ids + ] + for i, stream in enumerate(streams): + stream.state_info = current_state_infos[i] + stream.result = hyps[i].replace("▁", " ").strip() + return finised_streams + + +def decode_dataset( + params: object, + waves: List[Tuple[str, str]], + model: torch.nn.Module, + feature_extractor: kaldifeat.Fbank, + intersector: k2.OnlineDenseIntersecter, + token_sym_table: Optional[k2.SymbolTable] = None, + word_sym_table: Optional[k2.SymbolTable] = None, +) -> Dict[str, str]: + results = {} + decode_streams = [] + wave_index = 0 + while True: + if wave_index < len(waves) and len(decode_streams) < params.num_streams: + data, sample_rate = torchaudio.load(waves[wave_index][1]) + assert ( + sample_rate == params.sample_rate + ), f"expected sample rate: {params.sample_rate}. Given: {sample_rate}" + data = data[0].to(params.device) + feature = feature_extractor(data) + nnet_output, _, _ = model(feature.unsqueeze(0)) + decode_streams.append( + DecodeStream( + utt_id=waves[wave_index][0], + num_frames=nnet_output.shape[1], + nnet_output=nnet_output[0], + position=0, + state_info=DecodeStateInfo(), + result="", + ) + ) + wave_index += 1 + if wave_index % 100 == 0: + logging.info(f"Decoding progress: {wave_index}/{len(waves)}.") + continue + + if len(decode_streams) == 0: + break + + finised_streams = decode_one_chunk( + params=params, + intersector=intersector, + streams=decode_streams, + token_sym_table=token_sym_table, + word_sym_table=word_sym_table, + ) + + if params.print_partial: + s = "\n" + for stream in decode_streams: + s += f"{stream.utt_id}:\t{stream.result}\n\n" + logging.info(s) + + if finised_streams: + finised_streams = sorted(finised_streams, reverse=True) + for j in finised_streams: + results[decode_streams[j].utt_id] = decode_streams[j].result + del decode_streams[j] + return results def main(): @@ -122,11 +266,38 @@ def main(): args.num_classes = 500 args.chunk_size = 10 + wave_list: List[Tuple[str, str]] = [] + if args.wav_scp is not None: + assert os.path.isfile( + args.wav_scp + ), f"wav_scp not exists : {args.wav_scp}" + assert ( + args.output_file is not None + ), "You should provide output_file when using wav_scp" + with open(args.wav_scp, "r") as f: + for line in f: + toks = line.strip().split() + assert len(toks) == 2, toks + if not os.path.isfile(toks[1]): + logging.warning(f"File {toks[1]} not exists, skipping.") + continue + wave_list.append(toks) + else: + assert len(args.sound_files) > 0, "No wav_scp or waves provided." + for i, f in enumerate(args.sound_files): + if not os.path.isfile(f): + logging.warning(f"File {f} not exists, skipping.") + continue + wave_list.append((i, f)) + + # logging.info(f"wave_list : {wave_list}") + device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) + args.device = device - logging.info(f"device: {device}") + logging.info(f"params : {args}") logging.info("Creating model") model = torch.jit.load(args.nn_model) @@ -143,34 +314,12 @@ def main(): fbank = kaldifeat.Fbank(opts) - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, expected_sample_rate=args.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - feature_len = [] - for f in features: - feature_len.append(f.shape[0]) - - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) - - # Note: We don't use key padding mask for attention during decoding - nnet_output, _, _ = model(features) - num_frames = [x // args.subsampling_factor for x in feature_len] - + token_sym_table = None + word_sym_table = None if args.method == "ctc-decoding": logging.info("Use CTC decoding") max_token_id = args.num_classes - 1 - decoding_graph = k2.ctc_topo( - max_token=max_token_id, - device=device, - ) + decoding_graph = k2.ctc_topo(max_token=max_token_id, device=device,) token_sym_table = k2.SymbolTable.from_file(args.tokens) else: assert args.method == "1best", args.method @@ -191,89 +340,31 @@ def main(): max_active_states=10000, ) - state_infos = [DecodeStateInfo()] * len(waves) - positions = [0] * len(waves) - results = [""] * len(waves) - - while True: - current_state_infos = [] - current_nnet_outputs = [] - current_wave_ids = [] - current_num_frames = [] - for i in range(len(waves)): - if positions[i] == num_frames[i]: - continue - start = positions[i] - if (num_frames[i] - positions[i]) <= args.chunk_size: - current_num_frames.append(num_frames[i] - positions[i]) - end = num_frames[i] - positions[i] = num_frames[i] - state_infos[i].is_final = True - else: - current_num_frames.append(args.chunk_size) - end = positions[i] + args.chunk_size - positions[i] += args.chunk_size - - current_state_infos.append(state_infos[i]) - current_wave_ids.append(i) - current_nnet_outputs.append(nnet_output[i, start:end, :]) - - if len(current_wave_ids) == args.num_streams: - break - if len(current_wave_ids) == 0: - break - while len(current_num_frames) < args.num_streams: - current_num_frames.append(1) - current_nnet_outputs.append( - torch.zeros( - (args.chunk_size, nnet_output.shape[2]), - device=nnet_output.device, - ) - ) - current_state_infos.append(DecodeStateInfo()) - - current_nnet_outputs = pad_sequence( - current_nnet_outputs, batch_first=True - ) - supervision_segments = torch.tensor( - # seq_index, start_time, duration - [[i, 0, current_num_frames[i]] for i in range(args.num_streams)], - dtype=torch.int32, - ) - dense_fsa_vec = k2.DenseFsaVec( - current_nnet_outputs, supervision_segments - ) - lattice, current_state_infos = intersector.decode( - dense_fsa_vec, current_state_infos - ) - - best_path = one_best_decoding(lattice=lattice, use_double_scores=True) - symbol_ids = get_aux_labels(best_path) - - if args.method == "ctc-decoding": - hyps = [ - "".join([token_sym_table[i] for i in ids]) for ids in symbol_ids - ] - else: - assert args.method == "1best", args.method - hyps = [ - " ".join([word_sym_table[i] for i in ids]) for ids in symbol_ids - ] - logging.info(f"hyps : {hyps}") + results = decode_dataset( + params=args, + waves=wave_list, + model=model, + feature_extractor=fbank, + intersector=intersector, + token_sym_table=token_sym_table, + word_sym_table=word_sym_table, + ) + if args.wav_scp is not None: + output_dir = os.path.dirname(args.output_file) + if output_dir != "": + os.makedirs(output_dir, exist_ok=True) + with open(args.output_file, "w", encoding="utf-8") as f: + for x in wave_list: + f.write(x[0] + "\t" + results[x[0]] + "\n") + logging.info(f"Decoding results are written to {args.output_file}") + else: s = "\n" - for i in range(len(current_wave_ids)): - state_infos[current_wave_ids[i]] = current_state_infos[i] - results[current_wave_ids[i]] = hyps[i].replace("▁", " ").strip() - s += f"{args.sound_files[current_wave_ids[i]]}:\n" - s += f"{results[current_wave_ids[i]]}\n\n" + logging.info(f"results : {results}") + for x in wave_list: + s += f"{x[1]}:\n{results[x[0]]}\n\n" logging.info(s) - s = "\n" - for filename, hyp in zip(args.sound_files, results): - s += f"{filename}:\n{hyp}\n\n" - logging.info(s) - logging.info("Decoding Done") From ee497406b08b795aefffc2fc7c379fded847f270 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 7 Aug 2023 07:44:13 +0800 Subject: [PATCH 16/22] quick fix for online decoding --- k2/csrc/intersect_dense_pruned.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/k2/csrc/intersect_dense_pruned.cu b/k2/csrc/intersect_dense_pruned.cu index aaf08b5c8..5f598ce42 100644 --- a/k2/csrc/intersect_dense_pruned.cu +++ b/k2/csrc/intersect_dense_pruned.cu @@ -298,7 +298,8 @@ class MultiGraphDenseIntersectPruned { } if (t == b_fsas_->shape.MaxSize(1)) { int32_t start = std::max(0, T_ - 3); - PruneTimeRange(start, T_ + t); + PruneTimeRange(start, T_ + t - 1); + PruneTimeRange(T_ + t - 1, T_ + t); } } // The FrameInfo for time T+1 will have no states. We did that From 3430ffecfb72c55f62e95792d20ab5e25a28b7e7 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 7 Aug 2023 17:04:10 +0800 Subject: [PATCH 17/22] Add GetFinalFrame; Fix the online decoding issue --- k2/csrc/intersect_dense_pruned.cu | 178 +++++++++++++++-------- k2/csrc/intersect_dense_pruned.h | 3 - k2/python/csrc/torch/fsa_algo.cu | 1 - k2/python/k2/online_dense_intersecter.py | 2 + k2/torch/bin/online_decode.cu | 1 - k2/torch/bin/online_decode.py | 3 +- 6 files changed, 121 insertions(+), 67 deletions(-) diff --git a/k2/csrc/intersect_dense_pruned.cu b/k2/csrc/intersect_dense_pruned.cu index 5f598ce42..e612921ac 100644 --- a/k2/csrc/intersect_dense_pruned.cu +++ b/k2/csrc/intersect_dense_pruned.cu @@ -257,17 +257,7 @@ class MultiGraphDenseIntersectPruned { const std::vector>* OnlineIntersect( DenseFsaVec *b_fsas, std::vector> &frames, - Array1 &beams, - Array1 &is_final) { - /* - T is the largest number (frames+1) of neural net output currently - received, or the largest number of frames of log-likelihoods we count the - final frame with (0, -inf, -inf..) that is used for the final-arc. - The largest number of states in the fsas represented by b_fsas equals - T+1 (e.g. 1 frame would require 2 states, because that 1 frame is the arc - from state 0 to state 1). So the #states is 2 greater than the actual - number of frames in the neural-net output. - */ + Array1 &beams) { K2_CHECK(online_decoding_); K2_CHECK(c_->IsCompatible(*b_fsas->Context())); K2_CHECK_EQ(a_fsas_.shape.Dim0(), 1); @@ -277,17 +267,21 @@ class MultiGraphDenseIntersectPruned { b_fsas_ = b_fsas; frames_.swap(frames); dynamic_beams_ = beams.To(c_); - is_final_ = is_final.To(c_); - T_ = frames_.size() - 1; - // -1 here because we already put the initial frame info to frames_ - int32_t T = T_ + b_fsas_->shape.MaxSize(1); + // T_ is the actual number of frames we have already processed in previous + // chunks, -1 here because frames_ includes the initial frame. + T_ = frames_.size() - 1; + // -1 here because we add extra frame to b_fsas_ (to handle -1 arc) + // see dense_fsa_vec.py for more details of converting nnet_outputs to fsas. + int32_t chunk_size = b_fsas_->shape.MaxSize(1) - 1; + int32_t T = T_ + chunk_size; - // we'll initially populate frames_[0.. T+1], but discard the one at T+1, - // which has no arcs or states, the ones we use are from 0 to T. - frames_.reserve(T + 2); + // plus initial frame, we actually have T + 1 frames. + frames_.reserve(T + 1); - for (int32_t t = 0; t <= b_fsas_->shape.MaxSize(1); t++) { + // we only do PropagateForward for real frames(i.e. not including the extra + // frame we added to b_fsas_. + for (int32_t t = 0; t < chunk_size; t++) { if (state_map_.NumKeyBits() == 32) { frames_.push_back(PropagateForward<32>(t, frames_.back().get())); } else if (state_map_.NumKeyBits() == 36) { @@ -296,25 +290,12 @@ class MultiGraphDenseIntersectPruned { K2_CHECK_EQ(state_map_.NumKeyBits(), 40); frames_.push_back(PropagateForward<40>(t, frames_.back().get())); } - if (t == b_fsas_->shape.MaxSize(1)) { - int32_t start = std::max(0, T_ - 3); - PruneTimeRange(start, T_ + t - 1); - PruneTimeRange(T_ + t - 1, T_ + t); + if (t == chunk_size - 1) { + int32_t start = std::max(0, T_ - 2); + PruneTimeRange(start, T_ + t + 1); } } - // The FrameInfo for time T+1 will have no states. We did that - // last PropagateForward so that the 'arcs' member of frames_[T] - // is set up (it has no arcs but we need the shape). - frames_.pop_back(); - int32_t history_t = T_; - - T_ = T; - // partial_final_frame_ is the last frame to generate partial result, - // but it should not be the start frame of next chunk decoding. - partial_final_frame_ = std::move(frames_.back()); - frames_.pop_back(); - const int32_t *b_fsas_row_splits1 = b_fsas_->shape.RowSplits(1).Data(); int32_t *final_t_data = final_t_.Data(); @@ -325,9 +306,104 @@ class MultiGraphDenseIntersectPruned { b_fsas_row_splits1[i + 1] - b_fsas_row_splits1[i]; final_t_data[i] = history_t + b_chunk_size; }); + + // T_ will be used in FormatOutput, plus 1 here because we need an extra + // frame for final arcs (i.e. the partial_final_frame return by + // GetFinalFrame()) to construct the lattice. + T_ = T + 1; return &frames_; } + /* Propagate the last frame in b_fsas_(i.e. the extra frame containing only 0 + and -infs). See dense_fsa_vec.py to get more details of b_fsas_. + + The purpose of this function is to get the final states to construct + partial results for online decoding. It suppose to be invoked in + FormatOutput when online_decoding_ is True. + + This function returns the final FrameInfo needed by the FormatOutput. The + final_frame->states contains the final state for each sequence (if it has), + the final_frame->arcs actually contains no arc at all, but we need its + shape. + + This function also adds the arcs to frames_.back(), normally the arcs of + frames_.back() will be populated in next ForwardPass, we populate it here + so that we can get valid fsas in FormatOutput. It will not affect the + ForwardPass because the ForwardPass only need the states in frames_.back(). + Actually we will re-expand the arcs in frames_.back() in the next + ForwardPass. + */ + std::unique_ptr GetFinalFrame() { + K2_CHECK(online_decoding_); + + // chunk_size is the index of the added extra frame. + int32_t chunk_size = b_fsas_->shape.MaxSize(1) - 1; + FrameInfo *cur_frame = frames_.back().get(); + + // These are all of the expanded arcs, actually we only need the arcs + // pointing to the final states. + auto arcs = GetArcs(chunk_size, cur_frame); + + int32_t num_fsas = NumFsas(); + + // Number of final states for each sequence, should be 0 or 1. + Array1 num_final_states(c_, num_fsas + 1, 0); + // Keep the arcs pointing to final states. + Renumbering renumber_arcs(c_, arcs.NumElements()); + char *keep_this_arc_data = renumber_arcs.Keep().Data(); + const int32_t *arcs_row_ids1_data = arcs.RowIds(1).Data(), + *arcs_row_ids2_data = arcs.RowIds(2).Data(), + *fsa_row_split1_data = a_fsas_.RowSplits(1).Data(); + int32_t *num_final_states_data = num_final_states.Data(); + ArcInfo *arcs_data = arcs.values.Data(); + + K2_EVAL( + c_, arcs.NumElements(), lambda_renumber_arc, (int32_t idx012) -> void { + int32_t idx01 = arcs_row_ids2_data[idx012], + idx0 = arcs_row_ids1_data[idx01]; + ArcInfo ai = arcs_data[idx012]; + // Arcs pointing to final states have non infinity scores + if (ai.arc_loglike - ai.arc_loglike == 0) { + num_final_states_data[idx0] = 1; + keep_this_arc_data[idx012] = 1; + } else { + keep_this_arc_data[idx012] = 0; + } + }); + + int32_t num_arcs = renumber_arcs.NumNewElems(); + const int32_t *new2old_data = renumber_arcs.New2Old().Data(); + Array1 new_arcs(c_, num_arcs); + ArcInfo *new_arcs_data = new_arcs.Data(); + + K2_EVAL(c_, num_arcs, lambda_set_new_arcs, (int32_t new_idx012) -> void { + int32_t old_idx012 = new2old_data[new_idx012]; + ArcInfo old_ai = arcs_data[old_idx012]; + // Only 1 state (the final state) in next frame, so idx1 is always 0. + old_ai.u.dest_info_state_idx1 = 0; + new_arcs_data[new_idx012] = old_ai; + }); + + auto old2new_rowsplits = renumber_arcs.Old2New(true); + auto old2new_shape = RaggedShape2(&old2new_rowsplits, nullptr, num_arcs); + auto total_shape = ComposeRaggedShapes(arcs.shape, old2new_shape); + auto new_arcs_shape = RemoveAxis(total_shape, 2); + cur_frame->arcs = Ragged(new_arcs_shape, new_arcs); + + std::unique_ptr ans = std::make_unique(); + ExclusiveSum(num_final_states, &num_final_states); + auto final_state_shape = RaggedShape2( + &num_final_states, nullptr, -1); + // No arcs for final frame, but we need its shape in FormatOutput. + auto state_to_arc_shape = RegularRaggedShape( + c_, final_state_shape.NumElements(), 0); + auto final_arc_shape = ComposeRaggedShapes( + final_state_shape, state_to_arc_shape); + ans->arcs = Ragged(final_arc_shape, Array1(c_, 0)); + return ans; + } + + void BackwardPass() { int32_t num_fsas = b_fsas_->shape.Dim0(), num_work_items = max_active_ * num_fsas * T_; @@ -401,7 +477,9 @@ class MultiGraphDenseIntersectPruned { bool online_decoding = online_decoding_; bool allow_partial = allow_partial_; + std::unique_ptr partial_final_frame; if (online_decoding) { + partial_final_frame = std::move(GetFinalFrame()); K2_CHECK(arc_map_a); K2_CHECK_EQ(arc_map_b, nullptr); } else { @@ -417,10 +495,10 @@ class MultiGraphDenseIntersectPruned { arcs_row_splits1_ptrs.Data()[t] = frames_[t]->arcs.RowSplits(1).Data(); } arcs_data_ptrs.Data()[T] = online_decoding - ? partial_final_frame_->arcs.values.Data() + ? partial_final_frame->arcs.values.Data() : frames_[T]->arcs.values.Data(); arcs_row_splits1_ptrs.Data()[T] = - online_decoding ? partial_final_frame_->arcs.RowSplits(1).Data() + online_decoding ? partial_final_frame->arcs.RowSplits(1).Data() : frames_[T]->arcs.RowSplits(1).Data(); // transfer to GPU if we're using a GPU @@ -484,7 +562,7 @@ class MultiGraphDenseIntersectPruned { for (int32_t t = 0; t < T; t++) arcs_shapes[t] = &(frames_[t]->arcs.shape); - arcs_shapes[T] = online_decoding ? &(partial_final_frame_->arcs.shape) + arcs_shapes[T] = online_decoding ? &(partial_final_frame->arcs.shape) : &(frames_[T]->arcs.shape); arcs_shapes[T + 1] = &final_arcs_shape; @@ -774,12 +852,6 @@ class MultiGraphDenseIntersectPruned { }); } - bool online_decoding = online_decoding_; - bool *is_final_data = nullptr; - if (online_decoding) { - is_final_data = is_final_.Data(); - } - K2_EVAL( c_, ai.values.Dim(), ai_lambda, (int32_t ai_arc_idx012)->void { int32_t ai_state_idx01 = ai_row_ids2[ai_arc_idx012], @@ -802,14 +874,8 @@ class MultiGraphDenseIntersectPruned { auto dest_state = arc.dest_state; auto final_t = b_fsas_row_splits1[ai_fsa_idx0+1] - b_fsas_row_splits1[ai_fsa_idx0]; - bool is_final_chunk = false; - if (online_decoding) { - is_final_chunk = is_final_data[ai_fsa_idx0]; - } - if (final_t - 1 == t && - ((online_decoding && !is_final_chunk) || - (allow_partial && !has_valid_final_arc_data[ai_fsa_idx0]))) { + (allow_partial && !has_valid_final_arc_data[ai_fsa_idx0])) { int32_t a_fsas_idx0 = a_fsas_row_ids1[sinfo.a_fsas_state_idx01]; // state_idx1 is 0-based. // So "-1" is used when calculating a_fsas_final_state_idx1. @@ -1021,7 +1087,6 @@ class MultiGraphDenseIntersectPruned { int32_t dest_a_fsas_state_idx01 = info.u.dest_a_fsas_state_idx01; - uint64_t state_map_idx = dest_a_fsas_state_idx01 + fsa_id * state_map_fsa_stride; uint64_t state_idx01; @@ -1593,9 +1658,6 @@ class MultiGraphDenseIntersectPruned { bool online_decoding_; // true for online decoding. Array1 final_t_; // record the final frame id of each DenseFsa. - Array1 is_final_; // For online decoding, it has a dimension of - // b_fsas_->Dim0() indicating whether this is - // the final chunk of current sequence. std::unique_ptr partial_final_frame_; // store the final frame for // partial results @@ -1711,8 +1773,6 @@ void OnlineDenseIntersecter::Decode(DenseFsaVec &b_fsas, Array1 beams(GetCpuContext(), num_seqs); float *beams_data = beams.Data(); - Array1 is_final(GetCpuContext(), num_seqs); - bool *is_final_data = is_final.Data(); for (int32_t i = 0; i < num_seqs; ++i) { DecodeStateInfo *decode_state_ptr = decode_states->at(i); K2_CHECK(decode_state_ptr); @@ -1730,12 +1790,10 @@ void OnlineDenseIntersecter::Decode(DenseFsaVec &b_fsas, Array1(c_, std::vector{ArcInfo()})); decode_state_ptr->beam = search_beam_; - decode_state_ptr->is_final = false; } seq_states_ptr_vec[i] = &(decode_state_ptr->states); seq_arcs_ptr_vec[i] = &(decode_state_ptr->arcs); beams_data[i] = decode_state_ptr->beam; - is_final_data[i] = decode_state_ptr->is_final; } auto stack_states = Stack(0, num_seqs, seq_states_ptr_vec.data()); @@ -1764,7 +1822,7 @@ void OnlineDenseIntersecter::Decode(DenseFsaVec &b_fsas, } const auto new_frames = impl_->OnlineIntersect( - &b_fsas, frames, beams, is_final); + &b_fsas, frames, beams); impl_->FormatOutput(ofsa, arc_map_a, nullptr/*arc_map_b*/); diff --git a/k2/csrc/intersect_dense_pruned.h b/k2/csrc/intersect_dense_pruned.h index bdfb9b318..2d0860724 100644 --- a/k2/csrc/intersect_dense_pruned.h +++ b/k2/csrc/intersect_dense_pruned.h @@ -135,9 +135,6 @@ struct DecodeStateInfo { // current search beam for this sequence float beam; - - // True if the chunk to be decoded is the final chunk - bool is_final; }; diff --git a/k2/python/csrc/torch/fsa_algo.cu b/k2/python/csrc/torch/fsa_algo.cu index 24e557145..f875b3212 100644 --- a/k2/python/csrc/torch/fsa_algo.cu +++ b/k2/python/csrc/torch/fsa_algo.cu @@ -756,7 +756,6 @@ static void PybindDecodeStateInfo(py::module &m) { py::class_ state_info( m, "DecodeStateInfo"); state_info.def(py::init<>()); - state_info.def_readwrite("is_final", &PyClass::is_final); } static void PybindOnlineDenseIntersecter(py::module &m) { diff --git a/k2/python/k2/online_dense_intersecter.py b/k2/python/k2/online_dense_intersecter.py index b44713b23..efb751047 100644 --- a/k2/python/k2/online_dense_intersecter.py +++ b/k2/python/k2/online_dense_intersecter.py @@ -35,6 +35,7 @@ def __init__( output_beam: float, min_active_states: int, max_active_states: int, + allow_partial: bool = True, ) -> None: """Create a new online intersecter object. Args: @@ -101,6 +102,7 @@ def __init__( output_beam, min_active_states, max_active_states, + allow_partial=allow_partial, ) @property diff --git a/k2/torch/bin/online_decode.cu b/k2/torch/bin/online_decode.cu index a345ecce5..f1c75fbd9 100644 --- a/k2/torch/bin/online_decode.cu +++ b/k2/torch/bin/online_decode.cu @@ -248,7 +248,6 @@ int main(int argc, char *argv[]) { if (num_frames[i] <= chunk_size * subsampling_factor) { num_frame.push_back(num_frames[i]); num_frames[i] = 0; - states_info[i].is_final = true; } else { num_frame.push_back(chunk_size * subsampling_factor); num_frames[i] -= chunk_size * subsampling_factor; diff --git a/k2/torch/bin/online_decode.py b/k2/torch/bin/online_decode.py index 33236980d..09e4b18ec 100644 --- a/k2/torch/bin/online_decode.py +++ b/k2/torch/bin/online_decode.py @@ -146,7 +146,6 @@ def decode_one_chunk( current_num_frames.append(stream.num_frames - stream.position) end = stream.num_frames stream.position = stream.num_frames - stream.state_info.is_final = True finised_streams.append(i) else: current_num_frames.append(params.chunk_size) @@ -264,7 +263,7 @@ def main(): args.subsampling_factor = 4 args.feature_dim = 80 args.num_classes = 500 - args.chunk_size = 10 + args.chunk_size = 16 wave_list: List[Tuple[str, str]] = [] if args.wav_scp is not None: From 7db996f094fa6f4c0852ea0a96a18f14252a00ce Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 7 Aug 2023 18:44:07 +0800 Subject: [PATCH 18/22] Fix style --- k2/torch/bin/hlg_decode.py | 11 ++++++++--- k2/torch/bin/online_decode.py | 15 +++++++++++---- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/k2/torch/bin/hlg_decode.py b/k2/torch/bin/hlg_decode.py index 714538eb0..d95a6014c 100644 --- a/k2/torch/bin/hlg_decode.py +++ b/k2/torch/bin/hlg_decode.py @@ -78,7 +78,9 @@ def get_parser(): parser.add_argument( "--output-file", type=str, - help="The file to write out results to, only used when giving --wav-scp", + help=""" + The file to write out results to, only used when giving --wav-scp + """, ) parser.add_argument( @@ -239,7 +241,10 @@ def main(): if args.method == "ctc-decoding": logging.info("Use CTC decoding") max_token_id = args.num_classes - 1 - decoding_graph = k2.ctc_topo(max_token=max_token_id, device=device,) + decoding_graph = k2.ctc_topo( + max_token=max_token_id, + device=device, + ) token_sym_table = k2.SymbolTable.from_file(args.tokens) else: assert args.method == "1best", args.method @@ -260,7 +265,7 @@ def main(): res = decode_one_batch( params=args, - batch=wave_list[start : start + args.batch_size], + batch=wave_list[start: start + args.batch_size], model=model, feature_extractor=fbank, decoding_graph=decoding_graph, diff --git a/k2/torch/bin/online_decode.py b/k2/torch/bin/online_decode.py index 09e4b18ec..1d0bba3f6 100644 --- a/k2/torch/bin/online_decode.py +++ b/k2/torch/bin/online_decode.py @@ -86,7 +86,9 @@ def get_parser(): parser.add_argument( "--output-file", type=str, - help="The file to write out results to, only used when giving --wav-scp", + help=""" + The file to write out results to, only used when giving --wav-scp + """, ) parser.add_argument( @@ -158,7 +160,8 @@ def decode_one_chunk( current_num_frames.append(0) current_nnet_outputs.append( torch.zeros( - (params.chunk_size, params.num_classes), device=params.device, + (params.chunk_size, params.num_classes), + device=params.device, ) ) current_state_infos.append(DecodeStateInfo()) @@ -211,7 +214,8 @@ def decode_dataset( data, sample_rate = torchaudio.load(waves[wave_index][1]) assert ( sample_rate == params.sample_rate - ), f"expected sample rate: {params.sample_rate}. Given: {sample_rate}" + ), f"expected sample rate: {params.sample_rate}. " + f"Given: {sample_rate}" data = data[0].to(params.device) feature = feature_extractor(data) nnet_output, _, _ = model(feature.unsqueeze(0)) @@ -318,7 +322,10 @@ def main(): if args.method == "ctc-decoding": logging.info("Use CTC decoding") max_token_id = args.num_classes - 1 - decoding_graph = k2.ctc_topo(max_token=max_token_id, device=device,) + decoding_graph = k2.ctc_topo( + max_token=max_token_id, + device=device, + ) token_sym_table = k2.SymbolTable.from_file(args.tokens) else: assert args.method == "1best", args.method From d428106bb64a4816a6f6b979998c95585b1938bb Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 7 Aug 2023 18:56:09 +0800 Subject: [PATCH 19/22] Fix ci --- scripts/github_actions/generate_build_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/github_actions/generate_build_matrix.py b/scripts/github_actions/generate_build_matrix.py index 130d17519..7a4ae8d97 100755 --- a/scripts/github_actions/generate_build_matrix.py +++ b/scripts/github_actions/generate_build_matrix.py @@ -183,7 +183,7 @@ def generate_build_matrix(enable_cuda, for_windows, for_macos, test_only_latest_ "torch": torch, "python-version": p, "cuda": c, - "image": f"pytorch/manylinux-builder:cuda{c}", + "image": "pytorch/manylinux-builder:cuda" + c, } ) else: From 2cb1412b35fdcefd308a123acc454a4c2fdf1ad2 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 8 Sep 2023 16:15:19 +0800 Subject: [PATCH 20/22] Fix ci --- .github/workflows/run-tests-cpu.yml | 3 +++ scripts/github_actions/install_torch.sh | 7 +++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/run-tests-cpu.yml b/.github/workflows/run-tests-cpu.yml index e70cd9966..e55f1fbbd 100644 --- a/.github/workflows/run-tests-cpu.yml +++ b/.github/workflows/run-tests-cpu.yml @@ -54,6 +54,9 @@ jobs: torch: ["1.13.1"] python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] build_type: ["Release", "Debug"] + exclude: + - os: macos-latest + python-version: "3.11" steps: # refer to https://github.com/actions/checkout diff --git a/scripts/github_actions/install_torch.sh b/scripts/github_actions/install_torch.sh index 7ba74857a..84eef395f 100755 --- a/scripts/github_actions/install_torch.sh +++ b/scripts/github_actions/install_torch.sh @@ -14,8 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -torch=$TORCH_VERSION -cuda=$CUDA_VERSION +if [ $TORCH_VERSION != "" ] && [ $CUDA_VERSION != ""]; then + torch=$TORCH_VERSION + cuda=$CUDA_VERSION +fi + case ${torch} in 1.5.*) case ${cuda} in From 8ce371fe0c6c848a52a938ee7c7d17c549c276cc Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 22 Sep 2023 19:05:14 +0800 Subject: [PATCH 21/22] fix ci --- .github/workflows/test-k2-as-third-party-lib-cuda-ubuntu.yml | 1 + k2/torch/csrc/CMakeLists.txt | 2 +- scripts/github_actions/install_cuda.sh | 3 +++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test-k2-as-third-party-lib-cuda-ubuntu.yml b/.github/workflows/test-k2-as-third-party-lib-cuda-ubuntu.yml index 439de7d53..29e02c864 100644 --- a/.github/workflows/test-k2-as-third-party-lib-cuda-ubuntu.yml +++ b/.github/workflows/test-k2-as-third-party-lib-cuda-ubuntu.yml @@ -104,6 +104,7 @@ jobs: - name: Install GCC 7 run: | + sudo echo "deb [arch=amd64] http://archive.ubuntu.com/ubuntu focal main universe" >> /etc/apt/sources.list sudo apt-get install -y gcc-7 g++-7 echo "CC=/usr/bin/gcc-7" >> $GITHUB_ENV echo "CXX=/usr/bin/g++-7" >> $GITHUB_ENV diff --git a/k2/torch/csrc/CMakeLists.txt b/k2/torch/csrc/CMakeLists.txt index e4831bca6..39775a627 100644 --- a/k2/torch/csrc/CMakeLists.txt +++ b/k2/torch/csrc/CMakeLists.txt @@ -75,7 +75,7 @@ target_link_libraries(k2_torch_api PUBLIC k2_torch) if(K2_ENABLE_TESTS) add_executable(torch_api_test torch_api_test.cc) - target_link_libraries(torch_api_test PRIVATE k2_torch_api gtest gtest_main) + target_link_libraries(torch_api_test k2_torch_api gtest gtest_main) # NOTE: We set the working directory here so that # it works also on windows. The reason is that diff --git a/scripts/github_actions/install_cuda.sh b/scripts/github_actions/install_cuda.sh index f7a669a45..f94e7d869 100755 --- a/scripts/github_actions/install_cuda.sh +++ b/scripts/github_actions/install_cuda.sh @@ -49,6 +49,9 @@ case "$cuda" in 11.7) url=https://developer.download.nvidia.com/compute/cuda/11.7.1/local_installers/cuda_11.7.1_515.65.01_linux.run ;; + 11.8) + url=https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run + ;; *) echo "Unknown cuda version: $cuda" exit 1 From f9b8cad32b4e868e7cbc3534200380fb3272bb47 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 22 Sep 2023 22:51:02 +0800 Subject: [PATCH 22/22] fix ci --- .github/workflows/test-k2-as-third-party-lib-cuda-ubuntu.yml | 4 +++- scripts/github_actions/install_cudnn.sh | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test-k2-as-third-party-lib-cuda-ubuntu.yml b/.github/workflows/test-k2-as-third-party-lib-cuda-ubuntu.yml index 29e02c864..b79f4ac70 100644 --- a/.github/workflows/test-k2-as-third-party-lib-cuda-ubuntu.yml +++ b/.github/workflows/test-k2-as-third-party-lib-cuda-ubuntu.yml @@ -104,7 +104,9 @@ jobs: - name: Install GCC 7 run: | - sudo echo "deb [arch=amd64] http://archive.ubuntu.com/ubuntu focal main universe" >> /etc/apt/sources.list + sudo apt update + sudo apt install software-properties-common + sudo add-apt-repository "deb [arch=amd64] http://archive.ubuntu.com/ubuntu focal main universe" sudo apt-get install -y gcc-7 g++-7 echo "CC=/usr/bin/gcc-7" >> $GITHUB_ENV echo "CXX=/usr/bin/g++-7" >> $GITHUB_ENV diff --git a/scripts/github_actions/install_cudnn.sh b/scripts/github_actions/install_cudnn.sh index d57018ce0..7bfe681e4 100755 --- a/scripts/github_actions/install_cudnn.sh +++ b/scripts/github_actions/install_cudnn.sh @@ -42,6 +42,9 @@ case $cuda in 11.7) filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz ;; + 11.8) + filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz + ;; *) echo "Unsupported cuda version: $cuda" exit 1