
Commit

compare nbest_rescore result between unique_word_seqs and unique_token_seqs
glynpu committed Jul 8, 2021
1 parent 6d1e935 commit 5c979cc
Showing 2 changed files with 89 additions and 76 deletions.
155 changes: 84 additions & 71 deletions egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py
@@ -39,87 +39,103 @@
from snowfall.training.mmi_graph import get_phone_symbols


def nbest_decoding(model, encoder_memory, memory_mask, lats: k2.Fsa, num_paths: int):
'''
N-best rescoring with a transformer-decoder model.
The basic idea is to first extract n-best paths from the given lattice,
then extract word_seqs and token_seqs for each path.
The negative log-likelihood of each token_seq under the decoder is used as a 'language model score', called decoder_scores.
An AM score and a 4-gram LM score are also computed for each path.
The total score is a weighted sum of am_scores, fgram_lm_scores and decoder_scores.
The path with the max total score gives the decoding output.
'''

def extract_nbest_list(lats: k2.Fsa, num_paths: int):
# lats has token IDs as labels
# and word IDs as aux_labels.
# First, extract `num_paths` paths for each sequence.
# paths is a k2.RaggedInt with axes [seq][path][arc_pos]
paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

# Both token_seqs and word_seqs are k2.RaggedInt sharing the same shape as `paths`;
# token_seqs contains token IDs and word_seqs contains word IDs.
# Note that they also contain 0s and -1s.
# The last entry in each sublist is -1.
token_seqs = k2.index(lats.labels.contiguous(), paths)
word_seqs = k2.index(lats.aux_labels.contiguous(), paths)

# Note: the above operation supports also the case when
# lats.aux_labels is a ragged tensor. In that case,
# `remove_axis=True` is used inside the pybind11 binding code,
# so the resulting `word_seqs` still has 3 axes, like `paths`.
# The 3 axes are [seq][path][word]

# Remove epsilons (0) and -1 from token_seqs and word_seqs
token_seqs = k2.ragged.remove_values_leq(token_seqs, 0)
word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)
return token_seqs, word_seqs
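# Illustrative example (hypothetical values, not from the commit): with
# num_paths=2 and a single sequence, word_seqs could look like
#   [ [ [1056 73 9] [1056 73] ] ]
# i.e. axes [seq][path][word] with epsilons and -1 already removed;
# token_seqs has the same layout but holds BPE token IDs.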

# Remove repeated sequences to avoid redundant computation later.
#
# Since k2.ragged.unique_sequences will reorder paths within a seq,
# `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index.
# new2old.numel() == unique_word_seqs.num_elements()
def compute_am_flm_scores_1(lats, word_seqs, token_seqs):
'''
Compute AM and 4-gram LM scores using n-best paths deduplicated by word_seqs.
The resulting WER is worse than that of compute_am_flm_scores_2.
'''
# lats has token IDs as labels and word IDs as aux_labels.
unique_word_seqs, _, new2old = k2.ragged.unique_sequences(
word_seqs, need_num_repeats=False, need_new2old_indexes=True)
# Note: unique_word_seqs still has the same axes as word_seqs

seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

# path_to_seq_map is a 1-D torch.Tensor.
# path_to_seq_map[i] is the seq to which the i-th path
# belongs.
path_to_seq_map = seq_to_path_shape.row_ids(1)
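# For example (hypothetical numbers), with 2 seqs and 3 unique paths each,
# path_to_seq_map would be [0, 0, 0, 1, 1, 1].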

# Remove the seq axis.
# Now unique_word_seqs has only two axes [path][word]
unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

# word_fsas is an FsaVec with axes [path][state][arc]
word_fsas = k2.linear_fsa(unique_word_seqs)

word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

am_scores = compute_am_scores(lats, word_fsas_with_epsilon_loops, path_to_seq_map)
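# am_scores holds one total acoustic score per unique path; compute_am_scores
# presumably obtains it by intersecting the word FSAs with the lattice and
# stripping the LM part of the arc scores (assumption, see snowfall).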

# lats has token IDs as labels and word IDs as aux_labels.
# inv_lats has word IDs as labels and token IDs as aux_labels
# Do k2.invert to make it compatible with the function compute_am_scores
# inv_lats = k2.invert(lats)
inv_lats = k2.arc_sort(k2.invert(lats)) # no-op if inv_lats is already arc-sorted

# lats = k2.arc_sort(lats)
fgram_lm_lats = _intersect_device(inv_lats, word_fsas_with_epsilon_loops, path_to_seq_map, sorted_match_a=True)
fgram_lm_lats = k2.top_sort(k2.connect(fgram_lm_lats))

# log_semiring=False is a little better than log_semiring=True.
fgram_tot_scores = fgram_lm_lats.get_tot_scores(use_double_scores=True, log_semiring=False)
fgram_lm_scores = fgram_tot_scores - am_scores
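# The rescored lattice's arc scores contain AM + 4-gram LM, so subtracting
# am_scores is assumed to leave the pure 4-gram LM score per path.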

# Now token_seqs has only two axes [path][token]
token_seqs = k2.ragged.remove_axis(token_seqs, 0)
token_ids, _ = k2.ragged.index(token_seqs, new2old, axis=0)
token_ids = k2.ragged.to_list(token_ids)
return am_scores, fgram_lm_scores, token_ids, new2old

def compute_am_flm_scores_2(lats, word_seqs, token_seqs):
'''
Compute AM and 4-gram LM scores using n-best paths deduplicated by token_seqs.
The resulting WER is better than that of compute_am_flm_scores_1.
'''
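# A possible (unverified) explanation for the WER gap: deduplicating on token
# sequences keeps paths that map to the same word sequence but differ in
# tokenization, and the attention decoder can score those variants differently.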
unique_token_seqs, _, new2old = k2.ragged.unique_sequences(
token_seqs, need_num_repeats=False, need_new2old_indexes=True)

seq_to_path_shape = k2.ragged.get_layer(unique_token_seqs.shape(), 0)
path_to_seq_map = seq_to_path_shape.row_ids(1)

unique_token_seqs = k2.ragged.remove_axis(unique_token_seqs, 0)
# token_fsas is an FsaVec with axes [path][state][arc]
token_fsas = k2.linear_fsa(unique_token_seqs)
token_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(token_fsas)
# lats has token IDs as labels and word IDs as aux_labels.
# inv_lats has word IDs as labels and token IDs as aux_labels
am_scores = compute_am_scores(k2.arc_sort(k2.invert(lats)), token_fsas_with_epsilon_loops, path_to_seq_map)

fgram_lm_lats = _intersect_device(k2.arc_sort(lats), token_fsas_with_epsilon_loops, path_to_seq_map, sorted_match_a=True)
fgram_lm_lats = k2.top_sort(k2.connect(fgram_lm_lats))

# log_semiring=False is a little better than log_semiring=True.
fgram_tot_scores = fgram_lm_lats.get_tot_scores(use_double_scores=True, log_semiring=False)
fgram_lm_scores = fgram_tot_scores - am_scores

# Prepare token_ids for the transformer-decoder scoring done in nbest_decoding.
# Now token_seqs has only two axes [path][token]
token_seqs = k2.ragged.remove_axis(token_seqs, 0)
token_ids, _ = k2.ragged.index(token_seqs, new2old, axis=0)
token_ids = k2.ragged.to_list(token_ids)
return am_scores, fgram_lm_scores, token_ids, new2old

def nbest_decoding(model, encoder_memory, memory_mask, lats: k2.Fsa, num_paths: int):
'''
N-best rescoring with a transformer-decoder model.
The basic idea is to first extract n-best paths from the given lattice,
then extract word_seqs and token_seqs for each path.
The negative log-likelihood of each token_seq under the decoder is used as a 'language model score', called decoder_scores.
An AM score and a 4-gram LM score are also computed for each path.
For every (flm_scale, decoder_scale) pair, the total score is a weighted sum of am_scores, fgram_lm_scores and decoder_scores,
and the path with the max total score gives the decoding output for that pair.
'''
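# Rough sketch of the scoring explored below (notation added for clarity, not
# part of the original commit): for a path p,
#   tot_score(p) = am_score(p)
#                + flm_scale     * fgram_lm_score(p)
#                + decoder_scale * decoder_score(p)
# where decoder_score(p) = -sum_t nll(token_t(p) | encoder_memory).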
# token_seqs, word_seqs, unique_token_seqs, unique_word_seqs = extract_nbest_list(lats, num_paths)
token_seqs, word_seqs = extract_nbest_list(lats, num_paths)

# am_scores, fgram_lm_scores, token_ids, new2old = compute_am_flm_scores_1(lats, word_seqs, token_seqs)
am_scores, fgram_lm_scores, token_ids, new2old = compute_am_flm_scores_2(lats, word_seqs, token_seqs)
# now compute lm scores from transformer decoder
num_seqs = len(token_ids)
time_steps = encoder_memory.shape[0]
feature_dim = encoder_memory.shape[2]
@@ -130,11 +146,22 @@ def nbest_decoding(model, encoder_memory, memory_mask, lats: k2.Fsa, num_paths:
nll = model.decoder_nll(encoder_memory, memory_mask, token_ids=token_ids)
assert nll.shape[0] == num_seqs
decoder_scores = - nll.sum(dim=1)
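# decoder_scores[i] is the log-likelihood of the i-th unique token sequence
# under the attention decoder (negated NLL summed over tokens); higher is better.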
tot_scores = am_scores + fgram_lm_scores + decoder_scores
best_seq_idx = new2old[torch.argmax(tot_scores)]
best_word_seq = [k2.ragged.to_list(word_seqs)[0][best_seq_idx]]

return best_word_seq
flm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0, 4.0, 6.0, 8.0, 10.0]

decoder_scale_list = [0.1, 0.3, 0.5, 0.7, 0.9, 1.0, 2.0, 4.0, 6.0, 8.0, 10.0]
decoder_scale_list += [0.01, 0.03, 0.05, 0.08, 0.09]

ans = dict()
for flm_scale in flm_scale_list:
for decoder_scale in decoder_scale_list:
key = f'lm_scale_{flm_scale}_decoder_scale_{decoder_scale}'
tot_scores = am_scores + flm_scale * fgram_lm_scores + decoder_scale * decoder_scores
best_seq_idx = new2old[torch.argmax(tot_scores)]
best_word_seq = [k2.ragged.to_list(word_seqs)[0][best_seq_idx]]
ans[key] = best_word_seq

return ans

def decode_one_batch(batch: Dict[str, Any],
model: AcousticModel,
@@ -218,28 +245,11 @@ def decode_one_batch(batch: Dict[str, Any],

lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, output_beam_size, 30, 10000)

# TODO(Guo Liyong): figure out a way to combine lm_scale_list with transformer decoder n-best rescore
# lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
# lm_scale_list += [0.45, 0.55, 0.65]
# lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
# lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

lm_scale_list = [0.6] # lowest wer = 2.92 without transformer n-best rescore

if use_whole_lattice:
best_paths_dict = rescore_with_whole_lattice(lattices, G,
lm_scale_list,
need_rescored_lats=True)
ans = dict()
for lm_scale_str, (best_paths, ngram_rescored_lattices) in best_paths_dict.items():
assert best_paths.shape[0] == 1, 'Figuring out a way to do batch decoding'
if nbest_rescore_with_decoder:
best_word_seq = nbest_decoding(model, encoder_memory, memory_mask, ngram_rescored_lattices, num_paths)
hyps = best_word_seq

else:
hyps = get_texts(best_paths, indices)
ans[lm_scale_str] = hyps
# fgram means four-gram
fgram_rescored_lattices = rescore_with_whole_lattice(lattices, G,
lm_scale_list=None,
need_rescored_lats=True)
ans = nbest_decoding(model, encoder_memory, memory_mask, fgram_rescored_lattices, num_paths)
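# ans maps keys like 'lm_scale_0.6_decoder_scale_0.1' to the best word sequence
# for that scale pair, presumably so WER can be compared across settings.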
return ans


@@ -252,6 +262,8 @@ def decode(dataloader: torch.utils.data.DataLoader,
G: k2.Fsa,
use_whole_lattice: bool,
output_beam_size: float):
del HLG.lm_scores
HLG.lm_scores = HLG.scores.clone()
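# Presumably this makes the whole HLG graph score count as 'LM score', so that
# rescore_with_whole_lattice removes it entirely before composing with the
# 4-gram G (assumption, not stated in the code).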
tot_num_cuts = len(dataloader.dataset.cuts)
num_cuts = 0
results = defaultdict(list)
@@ -542,7 +554,8 @@ def main():
HLG.lm_scores = HLG.scores.clone()

librispeech = LibriSpeechAsrDataModule(args)
test_sets = ['test-clean', 'test-other']
# test_sets = ['test-clean', 'test-other']
test_sets = ['test-clean']
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
logging.info(f'* DECODING: {test_set}')

10 changes: 5 additions & 5 deletions snowfall/decoding/lm_rescore.py
@@ -299,12 +299,15 @@ def rescore_with_whole_lattice(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa,
b_to_a_map,
sorted_match_a=True)

rescoring_lats = k2.top_sort(k2.connect(rescoring_lats.to('cpu')).to(device))
rescoring_lats = k2.top_sort(k2.connect(rescoring_lats))

# inv_lats has phone IDs as labels
# and word IDs as aux_labels.
inv_lats = k2.invert(rescoring_lats)
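# When need_rescored_lats is True, the caller (nbest_decoding above) gets the
# whole 4-gram-rescored lattice back instead of per-lm_scale shortest paths.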

if need_rescored_lats:
return inv_lats

ans = dict()
#
# The following implements
@@ -319,8 +322,5 @@

best_paths = k2.shortest_path(inv_lats, use_double_scores=True)
key = f'lm_scale_{lm_scale}'
if need_rescored_lats:
ans[key] = (best_paths, inv_lats)
else:
ans[key] = best_paths
ans[key] = best_paths
return ans
