Skip to content
This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

[WIP] 2-state HMM topo as an alternative to CTC topo #126

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
8 changes: 5 additions & 3 deletions egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
from snowfall.common import save_training_info
from snowfall.common import setup_logger
from snowfall.models import AcousticModel
from snowfall.models.transformer import Noam, Transformer
from snowfall.models.conformer import Conformer
from snowfall.models.transformer import Noam, Transformer
from snowfall.training.diagnostics import measure_gradient_norms, optim_step_and_measure_param_change
from snowfall.training.hmm_topo import build_hmm_topo_2state
from snowfall.training.mmi_graph import MmiTrainingGraphCompiler
from snowfall.training.mmi_graph import create_bigram_phone_lm
from snowfall.training.mmi_graph import get_phone_symbols
Expand Down Expand Up @@ -472,6 +473,7 @@ def main():
phones=phone_symbol_table,
words=word_symbol_table,
device=device,
topo_builder_fn=build_hmm_topo_2state
)
phone_ids = get_phone_symbols(phone_symbol_table)
P = create_bigram_phone_lm(phone_ids)
Expand Down Expand Up @@ -550,15 +552,15 @@ def main():
num_features=40,
nhead=args.nhead,
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
num_classes=2 * (len(phone_ids) + 1), # +1 for the blank symbol
subsampling_factor=4,
num_decoder_layers=num_decoder_layers)
else:
model = Conformer(
num_features=40,
nhead=args.nhead,
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
num_classes=2 * (len(phone_ids) + 1), # +1 for the blank symbol
subsampling_factor=4,
num_decoder_layers=num_decoder_layers)

Expand Down
48 changes: 48 additions & 0 deletions snowfall/training/hmm_topo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import k2
from typing import List


def build_hmm_topo_2state(tokens: List[int]) -> k2.Fsa:
"""
Build a 2-state HMM topology used in Kaldi's chain models.
The first HMM state is entered only once for each token instance,
and the second HMM state is self-looped and optional.

Args:
tokens:
A list of token int IDs, e.g., phones, characters, etc.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably an issue in the baseline, but we shuold be clear whether this list is supposed to contain zero, or perhaps should not contain zero.

The IDs for the first HMM state will be the same as token IDs;
The IDs for the second HMM state are: ``token_id + len(tokens)``
Returns:
An FST that converts a sequence of HMM state IDs to a sequence of token IDs.
"""
followup_tokens = range(len(tokens), len(tokens) * 2)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should it be

followup_tokens = range(len(tokens) + 1, len(tokens) * 2 + 1)

as token id starts from 1?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you're right. In the general case, to avoid surprises, I think that should be len(tokens) + min_token_id.

num_states = len(tokens) + 2 # + start state, + final state
arcs = []

# Start state -> token state
for i in range(0, len(tokens)):
arcs += [f'0 {i + 1} {tokens[i]} {tokens[i]} 0.0']

# Token state self loops
for i in range(0, len(tokens)):
arcs += [f'{i + 1} {i + 1} {followup_tokens[i]} 0 0.0']

# Cross-token transitions
for i in range(0, len(tokens)):
for j in range(0, len(tokens)):
if i != j:
arcs += [f'{i + 1} {j + 1} {tokens[i]} {tokens[i]} 0.0']
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be tokens[j] and tokens[j], instead of tokens[i] and tokens[i]?


# Token state -> superfinal state
for i in range(0, len(tokens)):
arcs += [f'{i + 1} {num_states - 1} -1 -1 0.0']

# Final state
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To fix the problem, you can change

# Final state
arcs += [f'{num_states - 1}']

# Build the FST
arcs = '\n'.join(sorted(arcs))

to

# Build the FST
arcs = '\n'.join(sorted(arcs))

# Final state
arcs += f'\n{num_states - 1}'

k2 expects that the last line contains the final state. Nothing should follow
the final state.

The documentation https://github.com/k2-fsa/k2/blob/1eeeecfac558a6ae4133e2c0b4f0022bee24c786/k2/python/k2/fsa.py#L1078
says

        Caution:
          The first column has to be non-decreasing.

non-decreasing is in numeric, not in alphabetic order. sorted in python sorts in alphabetic.
That is the problem.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above fix is not a complete solution.
If the list is too large, it may result in

1 ....
1 ...
11 ....
2 ....

due to sorted. 11 should come after 2 and it will cause another crash.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing

arcs = '\n'.join(sorted(arcs))

to

arcs = '\n'.join(sorted(arcs, key=lambda arc: int(arc.split()[0])))

should work.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I don't think I would have came up with that so fast myself ;)

arcs += [f'{num_states - 1}']

# Build the FST
arcs = '\n'.join(sorted(arcs))
ans = k2.Fsa.from_str(arcs)
ans = k2.arc_sort(ans)
return ans
33 changes: 16 additions & 17 deletions snowfall/training/mmi_graph.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# Copyright (c) 2020 Xiaomi Corp. (author: Fangjun Kuang)

import k2
import torch
from typing import Iterable
from typing import List
from typing import Tuple

import k2
import torch

from .ctc_graph import build_ctc_topo
from snowfall.common import get_phone_symbols
from .ctc_graph import build_ctc_topo


def create_bigram_phone_lm(phones: List[int]) -> k2.Fsa:
Expand Down Expand Up @@ -47,6 +46,7 @@ def __init__(self,
phones: k2.SymbolTable,
words: k2.SymbolTable,
device: torch.device,
topo_builder_fn=build_ctc_topo,
oov: str = '<UNK>'):
'''
Args:
Expand Down Expand Up @@ -78,10 +78,9 @@ def __init__(self,
phone_symbols = get_phone_symbols(phones)
phone_symbols_with_blank = [0] + phone_symbols

ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
assert ctc_topo.requires_grad is False

self.ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())
H = topo_builder_fn(phone_symbols_with_blank).to(device)
assert H.requires_grad is False
self.H_inv = k2.arc_sort(H.invert_())

def compile(self, texts: Iterable[str],
P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa]:
Expand All @@ -106,28 +105,28 @@ def compile(self, texts: Iterable[str],
assert P.device == self.device
P_with_self_loops = k2.add_epsilon_self_loops(P)

ctc_topo_P = k2.intersect(self.ctc_topo_inv,
P_with_self_loops,
treat_epsilons_specially=False).invert()

ctc_topo_P = k2.arc_sort(ctc_topo_P)
HP = k2.intersect(
self.H_inv,
P_with_self_loops,
treat_epsilons_specially=False
).invert()
HP = k2.arc_sort(HP)

num_graphs = self.build_num_graphs(texts)
num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
num_graphs)

num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

num = k2.compose(ctc_topo_P,
num = k2.compose(HP,
num_graphs_with_self_loops,
treat_epsilons_specially=False)
num = k2.arc_sort(num)

ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
HP_vec = k2.create_fsa_vec([HP.detach()])
indexes = torch.zeros(len(texts),
dtype=torch.int32,
device=self.device)
den = k2.index_fsa(ctc_topo_P_vec, indexes)
den = k2.index_fsa(HP_vec, indexes)

return num, den

Expand Down