-
Notifications
You must be signed in to change notification settings - Fork 42
[WIP] 2-state HMM topo as an alternative to CTC topo #126
base: master
Are you sure you want to change the base?
Changes from 2 commits
8586e31
2221209
d2696ca
d1efefa
1f35b96
ee992eb
497e075
5eebbac
268898d
bdc812e
eac0a65
c85e53a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import k2 | ||
from typing import List | ||
|
||
|
||
def build_hmm_topo_2state(tokens: List[int]) -> k2.Fsa: | ||
""" | ||
Build a 2-state HMM topology used in Kaldi's chain models. | ||
The first HMM state is entered only once for each token instance, | ||
and the second HMM state is self-looped and optional. | ||
|
||
Args: | ||
tokens: | ||
A list of token int IDs, e.g., phones, characters, etc. | ||
The IDs for the first HMM state will be the same as token IDs; | ||
The IDs for the second HMM state are: ``token_id + len(tokens)`` | ||
Returns: | ||
An FST that converts a sequence of HMM state IDs to a sequence of token IDs. | ||
""" | ||
followup_tokens = range(len(tokens), len(tokens) * 2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should it be followup_tokens = range(len(tokens) + 1, len(tokens) * 2 + 1) as token id starts from 1? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, you're right. In the general case, to avoid surprises, I think that should be |
||
num_states = len(tokens) + 2 # + start state, + final state | ||
arcs = [] | ||
|
||
# Start state -> token state | ||
for i in range(0, len(tokens)): | ||
arcs += [f'0 {i + 1} {tokens[i]} {tokens[i]} 0.0'] | ||
|
||
# Token state self loops | ||
for i in range(0, len(tokens)): | ||
arcs += [f'{i + 1} {i + 1} {followup_tokens[i]} 0 0.0'] | ||
|
||
# Cross-token transitions | ||
for i in range(0, len(tokens)): | ||
for j in range(0, len(tokens)): | ||
if i != j: | ||
arcs += [f'{i + 1} {j + 1} {tokens[i]} {tokens[i]} 0.0'] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't this be tokens[j] and tokens[j], instead of tokens[i] and tokens[i]? |
||
|
||
# Token state -> superfinal state | ||
for i in range(0, len(tokens)): | ||
arcs += [f'{i + 1} {num_states - 1} -1 -1 0.0'] | ||
|
||
# Final state | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To fix the problem, you can change # Final state
arcs += [f'{num_states - 1}']
# Build the FST
arcs = '\n'.join(sorted(arcs)) to # Build the FST
arcs = '\n'.join(sorted(arcs))
# Final state
arcs += f'\n{num_states - 1}' k2 expects that the last line contains the final state. Nothing should follow The documentation https://github.com/k2-fsa/k2/blob/1eeeecfac558a6ae4133e2c0b4f0022bee24c786/k2/python/k2/fsa.py#L1078 Caution:
The first column has to be non-decreasing.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The above fix is not a complete solution.
due to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changing arcs = '\n'.join(sorted(arcs)) to arcs = '\n'.join(sorted(arcs, key=lambda arc: int(arc.split()[0]))) should work. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! I don't think I would have came up with that so fast myself ;) |
||
arcs += [f'{num_states - 1}'] | ||
|
||
# Build the FST | ||
arcs = '\n'.join(sorted(arcs)) | ||
ans = k2.Fsa.from_str(arcs) | ||
ans = k2.arc_sort(ans) | ||
return ans |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is probably an issue in the baseline, but we shuold be clear whether this list is supposed to contain zero, or perhaps should not contain zero.