Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow partial for intersect_dense_pruned #1218

Merged
merged 25 commits into from
Sep 23, 2023
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/run-tests-cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ jobs:
torch: ["1.13.1"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
build_type: ["Release", "Debug"]
exclude:
- os: macos-latest
python-version: "3.11"

steps:
# refer to https://github.com/actions/checkout
Expand Down
14 changes: 10 additions & 4 deletions k2/csrc/fsa_algo.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,10 @@ void AddEpsilonSelfLoops(FsaOrVec &src, FsaOrVec *dest,
@param[in] b_fsas Input FSAs that correspond to neural network
outputs (see documentation in fsa.h).
@param[in] search_beam Beam for frame-synchronous beam pruning,
e.g. 20. Smaller is faster, larger is more exact
(less pruning). This is the default value; it may be
modified by {min,max}_active which dictate the minimum
or maximum allowed number of active states per frame.
e.g. 20. Smaller is faster, larger is more exact
(less pruning). This is the default value; it may be
modified by {min,max}_active which dictate the minimum
or maximum allowed number of active states per frame.
@param[in] output_beam Beam with which we prune the output (analogous
to lattice-beam in Kaldi), e.g. 8. We discard arcs in
the output that are not on a path that's within
Expand All @@ -178,6 +178,11 @@ void AddEpsilonSelfLoops(FsaOrVec &src, FsaOrVec *dest,
of states are active. The hash size used per FSA is 4
times (this rounded up to a power of 2), so this
affects memory consumption.
@param [in] allow_partial If true and there was no final state active,
we will treat all the states on the last frame
to be final state. If false, we only
care about the real final state in the decoding
graph on the last frame when generating lattice.
@param[out] out Output vector of composed, pruned FSAs, with same
Dim0() as b_fsas. Elements of it may be empty if the
composition was empty, either intrinsically or due to
Expand All @@ -196,6 +201,7 @@ void AddEpsilonSelfLoops(FsaOrVec &src, FsaOrVec *dest,
void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas,
float search_beam, float output_beam,
int32_t min_active_states, int32_t max_active_states,
bool allow_partial,
FsaVec *out, Array1<int32_t> *arc_map_a,
Array1<int32_t> *arc_map_b);

Expand Down
331 changes: 234 additions & 97 deletions k2/csrc/intersect_dense_pruned.cu

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions k2/csrc/intersect_dense_pruned.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class MultiGraphDenseIntersectPruned;

// DecodeStateInfo contains the history decoding states for each sequence, this
// is normally constructed from `frames_` in MultiGraphDenseIntersectPruned
// bu using `Stack` and `Unstack`.
// by using `Stack` and `Unstack`.
struct DecodeStateInfo {
// States that survived for the previously decoded frames. Indexed
// [frame_idx][state_idx], state_idx just enumerates the active states
Expand Down Expand Up @@ -170,7 +170,7 @@ class OnlineDenseIntersecter {
public:
OnlineDenseIntersecter(FsaVec &a_fsas, int32_t num_seqs, float search_beam,
float output_beam, int32_t min_states,
int32_t max_states);
int32_t max_states, bool allow_partial = true);

/* Does intersection/composition for current chunk of nnet_output(given
by a DenseFsaVec), sequences in every chunk may come from different
Expand All @@ -194,7 +194,7 @@ class OnlineDenseIntersecter {
will have been assigned to this location.
*/
void Decode(DenseFsaVec &b_fsas,
std::vector<std::shared_ptr<DecodeStateInfo>> *decode_states,
std::vector<DecodeStateInfo* > *decode_states,
FsaVec *ofsa, Array1<int32_t> *arc_map_a);

ContextPtr &Context() { return c_;}
Expand Down
22 changes: 15 additions & 7 deletions k2/csrc/intersect_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ TEST(Intersect, RandomSingle) {
K2_LOG(INFO) << "fsas_b = " << fsas_b;
FsaVec out_fsas2;
Array1<int32_t> arc_map_a2, arc_map_b2;
// IntersectDensePruned() treats epsilons as normal symbols, so we need to
// IntersectDense() treats epsilons as normal symbols, so we need to
// as well.

ArcSort(&fsa); // CAUTION if you later test the arc_maps: we arc-sort here,
Expand Down Expand Up @@ -339,7 +339,7 @@ TEST(Intersect, RandomFsaVec) {
K2_LOG(INFO) << "fsas_b = " << fsas_b;
FsaVec out_fsas2;
Array1<int32_t> arc_map_a2, arc_map_b2;
// IntersectDensePruned() treats epsilons as normal symbols, so we need to
// IntersectDense() treats epsilons as normal symbols, so we need to
// as well.

ArcSort(&fsavec); // CAUTION if you later test the arc_maps: we arc-sort
Expand Down Expand Up @@ -485,11 +485,12 @@ TEST(IntersectPruned, Simple) {

float beam = 100000;
int32_t max_active = 10000, min_active = 0;
bool allow_partial = false;

FsaVec out_fsas;
Array1<int32_t> arc_map_a, arc_map_b;
IntersectDensePruned(fsa, dfsavec, beam, beam, min_active, max_active,
&out_fsas, &arc_map_a, &arc_map_b);
allow_partial, &out_fsas, &arc_map_a, &arc_map_b);
K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a
<< ", arc_map_b = " << arc_map_b;

Expand Down Expand Up @@ -542,11 +543,12 @@ TEST(IntersectPruned, TwoDense) {

float beam = 100000;
int32_t max_active = 10000, min_active = 0;
bool allow_partial = false;

FsaVec out_fsas;
Array1<int32_t> arc_map_a, arc_map_b;
IntersectDensePruned(fsa, dfsavec, beam, beam, min_active, max_active,
&out_fsas, &arc_map_a, &arc_map_b);
allow_partial, &out_fsas, &arc_map_a, &arc_map_b);
K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a
<< ", arc_map_b = " << arc_map_b;

Expand Down Expand Up @@ -591,11 +593,12 @@ TEST(IntersectPruned, TwoFsas) {

float beam = 100000;
int32_t max_active = 10000, min_active = 0;
bool allow_partial = false;

FsaVec out_fsas;
Array1<int32_t> arc_map_a, arc_map_b;
IntersectDensePruned(fsa_vec, dfsavec, beam, beam, min_active, max_active,
&out_fsas, &arc_map_a, &arc_map_b);
allow_partial, &out_fsas, &arc_map_a, &arc_map_b);
K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a
<< ", arc_map_b = " << arc_map_b;

Expand Down Expand Up @@ -659,8 +662,10 @@ TEST(IntersectPruned, RandomSingle) {
FsaVec out_fsas;
float beam = 1000.0;
int32_t max_active = 10000, min_active = 0;
bool allow_partial = false;

IntersectDensePruned(fsa, dfsavec, beam, beam, min_active, max_active,
&out_fsas, &arc_map_a, &arc_map_b);
allow_partial, &out_fsas, &arc_map_a, &arc_map_b);
K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_b = " << arc_map_b;

FsaVec fsas_b = ConvertDenseToFsaVec(dfsavec);
Expand Down Expand Up @@ -763,8 +768,11 @@ TEST(IntersectPruned, RandomFsaVec) {
FsaVec out_fsas;
float search_beam = 1000.0, output_beam = 1000.0;
int32_t min_active = 0, max_active = 10;
bool allow_partial = false;

IntersectDensePruned(fsavec, dfsavec, search_beam, output_beam, min_active,
max_active, &out_fsas, &arc_map_a, &arc_map_b);
max_active, allow_partial,
&out_fsas, &arc_map_a, &arc_map_b);
K2_LOG(INFO) << "out_fsas = " << out_fsas
<< ", arc_map_a = " << arc_map_a
<< ", arc_map_b = " << arc_map_b;
Expand Down
31 changes: 20 additions & 11 deletions k2/python/csrc/torch/fsa_algo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -202,21 +202,23 @@ static void PybindIntersectDensePruned(py::module &m) {
"intersect_dense_pruned",
[](FsaVec &a_fsas, DenseFsaVec &b_fsas, float search_beam,
float output_beam, int32_t min_active_states,
int32_t max_active_states)
int32_t max_active_states, bool allow_partial)
-> std::tuple<FsaVec, torch::Tensor, torch::Tensor> {
DeviceGuard guard(a_fsas.Context());
Array1<int32_t> arc_map_a;
Array1<int32_t> arc_map_b;
FsaVec out;

IntersectDensePruned(a_fsas, b_fsas, search_beam, output_beam,
min_active_states, max_active_states, &out,
min_active_states, max_active_states,
allow_partial, &out,
&arc_map_a, &arc_map_b);
return std::make_tuple(out, ToTorch(arc_map_a), ToTorch(arc_map_b));
},
py::arg("a_fsas"), py::arg("b_fsas"), py::arg("search_beam"),
py::arg("output_beam"), py::arg("min_active_states"),
py::arg("max_active_states"));
py::arg("max_active_states"),
py::arg("allow_partial") = false);
}

static void PybindIntersectDense(py::module &m) {
Expand Down Expand Up @@ -751,8 +753,9 @@ static void PybindLevenshteinGraph(py::module &m) {

static void PybindDecodeStateInfo(py::module &m) {
using PyClass = DecodeStateInfo;
py::class_<PyClass, std::shared_ptr<PyClass>> state_info(m,
"DecodeStateInfo");
py::class_<PyClass> state_info(
m, "DecodeStateInfo");
state_info.def(py::init<>());
}

static void PybindOnlineDenseIntersecter(py::module &m) {
Expand All @@ -763,26 +766,32 @@ static void PybindOnlineDenseIntersecter(py::module &m) {
py::init([](FsaVec &decoding_graph, int32_t num_streams,
float search_beam, float output_beam,
int32_t min_active_states,
int32_t max_active_states) -> std::unique_ptr<PyClass> {
int32_t max_active_states,
bool allow_partial) -> std::unique_ptr<PyClass> {
DeviceGuard guard(decoding_graph.Context());
return std::make_unique<PyClass>(decoding_graph, num_streams,
search_beam, output_beam,
min_active_states, max_active_states);
min_active_states, max_active_states,
allow_partial);
}),
py::arg("decoding_graph"), py::arg("num_streams"), py::arg("search_beam"),
py::arg("output_beam"), py::arg("min_active_states"),
py::arg("max_active_states"));
py::arg("max_active_states"), py::arg("allow_partial") = true);

intersecter.def(
"decode",
[](PyClass &self, DenseFsaVec &dense_fsa_vec,
std::vector<std::shared_ptr<DecodeStateInfo>> &decode_states)
std::vector<DecodeStateInfo> &decode_states)
-> std::tuple<FsaVec, torch::Tensor,
std::vector<std::shared_ptr<DecodeStateInfo>>> {
std::vector<DecodeStateInfo>> {
DeviceGuard guard(self.Context());
FsaVec ofsa;
Array1<int32_t> arc_map;
self.Decode(dense_fsa_vec, &decode_states, &ofsa, &arc_map);
std::vector<DecodeStateInfo*> decode_states_ptr(decode_states.size());
for (size_t i = 0; i < decode_states.size(); ++i) {
decode_states_ptr[i] = &decode_states[i];
}
self.Decode(dense_fsa_vec, &decode_states_ptr, &ofsa, &arc_map);
torch::Tensor arc_map_tensor = ToTorch(arc_map);
return std::make_tuple(ofsa, arc_map_tensor, decode_states);
},
Expand Down
33 changes: 24 additions & 9 deletions k2/python/k2/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def forward(ctx,
output_beam: float,
min_active_states: int,
max_active_states: int,
allow_partial: bool,
unused_scores_a: torch.Tensor,
unused_scores_b: torch.Tensor,
seqframe_idx_name: Optional[str] = None,
Expand All @@ -383,16 +384,21 @@ def forward(ctx,
output_beam:
Pruning beam for the output of intersection (vs. best path);
equivalent to kaldi's lattice-beam. E.g. 8.
max_active_states:
Maximum number of FSA states that are allowed to be active on any
given frame for any given intersection/composition task. This is
advisory, in that it will try not to exceed that but may not always
succeed. You can use a very large number if no constraint is needed.
min_active_states:
Minimum number of FSA states that are allowed to be active on any
given frame for any given intersection/composition task. This is
advisory, in that it will try not to have fewer than this number
active. Set it to zero if there is no constraint.
max_active_states:
Maximum number of FSA states that are allowed to be active on any
given frame for any given intersection/composition task. This is
advisory, in that it will try not to exceed that but may not always
succeed. You can use a very large number if no constraint is needed.
allow_partial If true and there was no final state active,
we will treat all the states on the
last frame to be final state. If false, we only
care about the real final state in the decoding
graph on the last frame when generating lattice.
unused_scores_a:
It equals to `a_fsas.scores` and its sole purpose is for back
propagation.
Expand All @@ -418,7 +424,8 @@ def forward(ctx,
search_beam=search_beam,
output_beam=output_beam,
min_active_states=min_active_states,
max_active_states=max_active_states)
max_active_states=max_active_states,
allow_partial=allow_partial)

out_fsa[0] = Fsa(ragged_arc)

Expand Down Expand Up @@ -466,7 +473,7 @@ def forward(ctx,

@staticmethod
def backward(ctx, out_fsa_grad: torch.Tensor) \
-> Tuple[None, None, None, None, None, None, None, torch.Tensor, torch.Tensor]: # noqa
-> Tuple[None, None, None, None, None, None, None, None, torch.Tensor, torch.Tensor, None, None]: # noqa
a_scores, b_scores = ctx.saved_tensors
arc_map_a = ctx.arc_map_a
arc_map_b = ctx.arc_map_b
Expand All @@ -493,6 +500,7 @@ def backward(ctx, out_fsa_grad: torch.Tensor) \
None, # output_beam
None, # min_active_states
None, # max_active_states
None, # allow_partial
grad_a, # unused_scores_a
grad_b, # unused_scores_b
None, # seqframe_idx_name
Expand Down Expand Up @@ -663,7 +671,8 @@ def intersect_dense_pruned(a_fsas: Fsa,
min_active_states: int,
max_active_states: int,
seqframe_idx_name: Optional[str] = None,
frame_idx_name: Optional[str] = None) -> Fsa:
frame_idx_name: Optional[str] = None,
allow_partial: bool = False) -> Fsa:
'''Intersect array of FSAs on CPU/GPU.

Caution:
Expand Down Expand Up @@ -694,6 +703,11 @@ def intersect_dense_pruned(a_fsas: Fsa,
frame for any given intersection/composition task. This is advisory,
in that it will try not to exceed that but may not always succeed.
You can use a very large number if no constraint is needed.
allow_partial If true and there was no final state active,
we will treat all the states on the
last frame to be final state. If false, we only
care about the real final state in the decoding
graph on the last frame when generating lattice.
seqframe_idx_name:
If set (e.g. to 'seqframe'), an attribute in the output will be created
that encodes the sequence-index and the frame-index within that
Expand Down Expand Up @@ -727,7 +741,8 @@ def intersect_dense_pruned(a_fsas: Fsa,
# in `out_fsa[0].scores`
_IntersectDensePrunedFunction.apply(a_fsas, b_fsas, out_fsa, search_beam,
output_beam, min_active_states,
max_active_states, a_fsas.scores,
max_active_states, allow_partial,
a_fsas.scores,
b_fsas.scores, seqframe_idx_name,
frame_idx_name)
return out_fsa[0]
Expand Down
2 changes: 1 addition & 1 deletion k2/python/k2/dense_fsa_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(self,
segment_index, start_frame, duration = segment
assert 0 <= segment_index < N
assert 0 <= start_frame < T
assert duration > 0
assert duration >= 0
assert start_frame + duration <= T + allow_truncate
offset = segment_index * T
end_frame = min(start_frame + duration, T) # exclusive
Expand Down
7 changes: 7 additions & 0 deletions k2/python/k2/online_dense_intersecter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
output_beam: float,
min_active_states: int,
max_active_states: int,
allow_partial: bool = True,
) -> None:
"""Create a new online intersecter object.
Args:
Expand Down Expand Up @@ -91,6 +92,7 @@ def __init__(
decode_states[1] = new_decode_states[1]
...
"""
self.num_streams_ = num_streams
self.decoding_graph = decoding_graph
self.device = decoding_graph.device
self.intersecter = _k2.OnlineDenseIntersecter(
Expand All @@ -100,8 +102,13 @@ def __init__(
output_beam,
min_active_states,
max_active_states,
allow_partial=allow_partial,
)

@property
def num_streams(self) -> int:
return self.num_streams_

def decode(
self, dense_fsas: DenseFsaVec, decode_states: List[DecodeStateInfo]
) -> Tuple[Fsa, List[DecodeStateInfo]]:
Expand Down
2 changes: 1 addition & 1 deletion k2/python/tests/online_dense_intersecter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test(self):
num_chunks = 3
chunk_size = 5

decode_states = [None] * num_streams
decode_states = [k2.DecodeStateInfo()] * num_streams

for i in range(num_chunks):
logits = torch.randn(
Expand Down
Loading
Loading