diff --git a/egs/librispeech/asr/simple_v1/.gitignore b/egs/librispeech/asr/simple_v1/.gitignore index 2211df63..21f82734 100644 --- a/egs/librispeech/asr/simple_v1/.gitignore +++ b/egs/librispeech/asr/simple_v1/.gitignore @@ -1 +1,3 @@ *.txt +data +exp diff --git a/egs/librispeech/asr/simple_v1/local/add_silence_to_transcript.py b/egs/librispeech/asr/simple_v1/local/add_silence_to_transcript.py new file mode 100755 index 00000000..ca441f39 --- /dev/null +++ b/egs/librispeech/asr/simple_v1/local/add_silence_to_transcript.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +''' +Add silence with a given probability after each word in the transcript. + +If the input transcript contains: + + hello world + foo bar koo + zoo + +Then the output transcript **may** look like the following: + + !SIL hello !SIL world !SIL + foo bar !SIL koo !SIL + !SIL zoo !SIL + +(Assume !SIL represents silence.) +''' + +from pathlib import Path + +import argparse +import random + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--transcript', + type=str, + help='The input transcript file.' + 'We assume that the transcript file consists of ' + 'lines. Each line consists of space separated words.') + parser.add_argument('--sil-word', + type=str, + default='!SIL', + help='The word that represents silence.') + parser.add_argument('--sil-prob', + type=float, + default=0.5, + help='The probability for adding a ' + 'silence after each world.') + parser.add_argument('--seed', + type=int, + default=None, + help='The seed for random number generators.') + + return parser.parse_args() + + +def need_silence(sil_prob: float) -> bool: + ''' + Args: + sil_prob: + The probability to add a silence. + Returns: + Return True if a silence is needed. + Return False otherwise. + ''' + return random.uniform(0, 1) <= sil_prob + + +def process_line(line: str, sil_word: str, sil_prob: float) -> None: + '''Process a single line from the transcript. + + Args: + line: + A str containing space separated words. + sil_word: + The symbol indicating silence. + sil_prob: + The probability for adding a silence after each word. + Returns: + Return None. + ''' + words = line.strip().split() + for i, word in enumerate(words): + if i == 0: + # beginning of the line + if need_silence(sil_prob): + print(sil_word, end=' ') + + print(word, end=' ') + + if need_silence(sil_prob): + print(sil_word, end=' ') + + # end of the line, print a new line + if i == len(words) - 1: + print() + + +def main(): + args = get_args() + random.seed(args.seed) + + assert Path(args.transcript).is_file() + assert len(args.sil_word) > 0 + assert 0 < args.sil_prob < 1 + + with open(args.transcript) as f: + for line in f: + process_line(line=line, + sil_word=args.sil_word, + sil_prob=args.sil_prob) + + +if __name__ == '__main__': + main() diff --git a/egs/librispeech/asr/simple_v1/local/convert_transcript_to_corpus.py b/egs/librispeech/asr/simple_v1/local/convert_transcript_to_corpus.py new file mode 100755 index 00000000..62c60074 --- /dev/null +++ b/egs/librispeech/asr/simple_v1/local/convert_transcript_to_corpus.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +''' +Convert a transcript file to a corpus for LM training with +the help of a lexicon. If the lexicon contains phones, the resulting +LM will be a phone LM; If the lexicon contains word pieces, +the resulting LM will be a word piece LM. 
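In this recipe, the transcript fed to this script is the silence-augmented output of
local/add_silence_to_transcript.py above. A minimal, self-contained sketch of that
script's per-word logic (the seed and the example line are arbitrary; the actual output
depends on the random draws):

    import random

    random.seed(20210629)
    sil_word, sil_prob = '!SIL', 0.5
    out = []
    for i, word in enumerate('hello world'.split()):
        if i == 0 and random.uniform(0, 1) <= sil_prob:
            out.append(sil_word)          # optional silence at the line start
        out.append(word)
        if random.uniform(0, 1) <= sil_prob:
            out.append(sil_word)          # optional silence after each word
    print(' '.join(out))                  # e.g. 'hello !SIL world !SIL'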
+ +If a word has multiple pronunciations, only the first one is used. + +If the input transcript is: + + hello zoo world hello + world zoo + foo zoo world hellO + +and if the lexicon is + + SPN + hello h e l l o 2 + hello h e l l o + world w o r l d + zoo z o o + +Then the output is + + h e l l o 2 z o o w o r l d h e l l o 2 + w o r l d z o o + SPN z o o w o r l d SPN +''' + +from pathlib import Path +from typing import Dict + +import argparse + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--transcript', + type=str, + help='The input transcript file.' + 'We assume that the transcript file consists of ' + 'lines. Each line consists of space separated words.') + parser.add_argument('--lexicon', type=str, help='The input lexicon file.') + parser.add_argument('--oov', + type=str, + default='', + help='The OOV word.') + + return parser.parse_args() + + +def read_lexicon(filename: str) -> Dict[str, str]: + ''' + Args: + filename: + Filename to the lexicon. Each line in the lexicon + has the following format: + + word p1 p2 p3 + + where the first field is a word and the remaining fields + are the pronunciations of the word. Fields are separated + by spaces. + Returns: + Return a dict whose keys are words and values are the pronunciations. + ''' + ans = dict() + with open(filename) as f: + for line in f: + line = line.strip() + + if len(line) == 0: + # skip empty lines + continue + + fields = line.split() + assert len(fields) >= 2 + + word = fields[0] + pron = ' '.join(fields[1:]) + + if word not in ans: + # In case a word has multiple pronunciations, + # we only use the first one + ans[word] = pron + return ans + + +def process_line(lexicon: Dict[str, str], line: str, oov_pron: str) -> None: + ''' + Args: + lexicon: + A dict containing pronunciations. Its keys are words and values + are pronunciations. + line: + A line of transcript consisting of space separated words. + oov_pron: + The pronunciation of the oov word if a word in line is not present + in the lexicon. + Returns: + Return None. + ''' + words = line.strip().split() + for i, w in enumerate(words): + pron = lexicon.get(w, oov_pron) + print(pron, end=' ') + if i == len(words) - 1: + # end of the line, prints a new line + print() + + +def main(): + args = get_args() + assert Path(args.lexicon).is_file() + assert Path(args.transcript).is_file() + assert len(args.oov) > 0 + + lexicon = read_lexicon(args.lexicon) + assert args.oov in lexicon + + oov_pron = lexicon[args.oov] + + with open(args.transcript) as f: + for line in f: + process_line(lexicon=lexicon, line=line, oov_pron=oov_pron) + + +if __name__ == '__main__': + main() diff --git a/egs/librispeech/asr/simple_v1/local/make_kn_lm.py b/egs/librispeech/asr/simple_v1/local/make_kn_lm.py new file mode 100755 index 00000000..58b721d2 --- /dev/null +++ b/egs/librispeech/asr/simple_v1/local/make_kn_lm.py @@ -0,0 +1,377 @@ +#!/usr/bin/env python3 + +# Copyright 2016 Johns Hopkins University (Author: Daniel Povey) +# 2018 Ruizhe Huang +# Apache 2.0. + +# This is an implementation of computing Kneser-Ney smoothed language model +# in the same way as srilm. 
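# A note on the input corpus: in this recipe it is produced by
# local/convert_transcript_to_corpus.py above, which replaces every word by its
# first pronunciation in the lexicon and maps out-of-vocabulary words to the
# pronunciation of the --oov word, essentially
#
#     pron = lexicon.get(word, oov_pron)
#
# so, with the lexicon from that script's docstring, a line "foo zoo" would
# become "SPN z o o".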
This is a back-off, unmodified version of +# Kneser-Ney smoothing, which produces the same results as the following +# command (as an example) of srilm: +# +# $ ngram-count -order 4 -kn-modify-counts-at-end -ukndiscount -gt1min 0 -gt2min 0 -gt3min 0 -gt4min 0 \ +# -text corpus.txt -lm lm.arpa +# +# The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py +# The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html + +import sys +import os +import re +import io +import math +import argparse +from collections import Counter, defaultdict + + +parser = argparse.ArgumentParser(description=""" + Generate kneser-ney language model as arpa format. By default, + it will read the corpus from standard input, and output to standard output. + """) +parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram") +parser.add_argument("-text", type=str, default=None, help="Path to the corpus file") +parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models") +parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level") +args = parser.parse_args() + +default_encoding = "latin-1" # For encoding-agnostic scripts, we assume byte stream as input. + # Need to be very careful about the use of strip() and split() + # in this case, because there is a latin-1 whitespace character + # (nbsp) which is part of the unicode encoding range. + # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717 +strip_chars = " \t\r\n" +whitespace = re.compile("[ \t]+") + + +class CountsForHistory: + # This class (which is more like a struct) stores the counts seen in a + # particular history-state. It is used inside class NgramCounts. + # It really does the job of a dict from int to float, but it also + # keeps track of the total count. + def __init__(self): + # The 'lambda: defaultdict(float)' is an anonymous function taking no + # arguments that returns a new defaultdict(float). + self.word_to_count = defaultdict(int) + self.word_to_context = defaultdict(set) # using a set to count the number of unique contexts + self.word_to_f = dict() # discounted probability + self.word_to_bow = dict() # back-off weight + self.total_count = 0 + + def words(self): + return self.word_to_count.keys() + + def __str__(self): + # e.g. returns ' total=12: 3->4, 4->6, -1->2' + return ' total={0}: {1}'.format( + str(self.total_count), + ', '.join(['{0} -> {1}'.format(word, count) + for word, count in self.word_to_count.items()])) + + def add_count(self, predicted_word, context_word, count): + assert count >= 0 + + self.total_count += count + self.word_to_count[predicted_word] += count + if context_word is not None: + self.word_to_context[predicted_word].add(context_word) + + +class NgramCounts: + # A note on data-structure. Firstly, all words are represented as + # integers. We store n-gram counts as an array, indexed by (history-length + # == n-gram order minus one) (note: python calls arrays "lists") of dicts + # from histories to counts, where histories are arrays of integers and + # "counts" are dicts from integer to float. For instance, when + # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd + # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an + # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict. 
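    # (Histories are stored as tuples rather than lists, since lists are not
    # hashable; the update above is really
    #     self.counts[3][(5, 6, 7)][8] += 1
    # where self.counts[3] maps history tuples to CountsForHistory objects and
    # the innermost mapping is CountsForHistory.word_to_count.)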
+ def __init__(self, ngram_order, bos_symbol='', eos_symbol=''): + assert ngram_order >= 2 + + self.ngram_order = ngram_order + self.bos_symbol = bos_symbol + self.eos_symbol = eos_symbol + + self.counts = [] + for n in range(ngram_order): + self.counts.append(defaultdict(lambda: CountsForHistory())) + + self.d = [] # list of discounting factor for each order of ngram + + # adds a raw count (called while processing input data). + # Suppose we see the sequence '6 7 8 9' and ngram_order=4, 'history' + # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be + # 1. + def add_count(self, history, predicted_word, context_word, count): + self.counts[len(history)][history].add_count(predicted_word, context_word, count) + + # 'line' is a string containing a sequence of integer word-ids. + # This function adds the un-smoothed counts from this line of text. + def add_raw_counts_from_line(self, line): + if line == '': + words = [self.bos_symbol, self.eos_symbol] + else: + words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol] + + for i in range(len(words)): + for n in range(1, self.ngram_order+1): + if i + n > len(words): + break + ngram = words[i: i + n] + predicted_word = ngram[-1] + history = tuple(ngram[: -1]) + if i == 0 or n == self.ngram_order: + context_word = None + else: + context_word = words[i-1] + + self.add_count(history, predicted_word, context_word, 1) + + def add_raw_counts_from_standard_input(self): + lines_processed = 0 + infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding) # byte stream as input + for line in infile: + line = line.strip(strip_chars) + self.add_raw_counts_from_line(line) + lines_processed += 1 + if lines_processed == 0 or args.verbose > 0: + print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr) + + def add_raw_counts_from_file(self, filename): + lines_processed = 0 + with open(filename, encoding=default_encoding) as fp: + for line in fp: + line = line.strip(strip_chars) + self.add_raw_counts_from_line(line) + lines_processed += 1 + if lines_processed == 0 or args.verbose > 0: + print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr) + + def cal_discounting_constants(self): + # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N), + # where n1_N is the number of unique N-grams with count = 1 (counts-of-counts). + # This constant is used similarly to absolute discounting. + # Return value: d is a list of floats, where d[N+1] = D_N + + self.d = [0] # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0 + # This is a special case: as we currently assumed having seen all vocabularies in the dictionary, + # but perhaps this is not the case for some other scenarios. + for n in range(1, self.ngram_order): + this_order_counts = self.counts[n] + n1 = 0 + n2 = 0 + for hist, counts_for_hist in this_order_counts.items(): + stat = Counter(counts_for_hist.word_to_count.values()) + n1 += stat[1] + n2 += stat[2] + assert n1 + 2 * n2 > 0 + self.d.append(n1 * 1.0 / (n1 + 2 * n2)) + + def cal_f(self): + # f(a_z) is a probability distribution of word sequence a_z. + # Typically f(a_z) is discounted to be less than the ML estimate so we have + # some leftover probability for the z words unseen in the context (a_). 
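        # (The D0/D1 in the formulas below are the constants computed in
        # cal_discounting_constants() above. A toy example: if n1 = 300
        # bigram types occur exactly once and n2 = 100 occur exactly twice,
        # then D = 300 / (300 + 2 * 100) = 0.6.)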
+ # + # f(a_z) = (c(a_z) - D0) / c(a_) ;; for highest order N-grams + # f(_z) = (n(*_z) - D1) / n(*_*) ;; for lower order N-grams + + # highest order N-grams + n = self.ngram_order - 1 + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + for w, c in counts_for_hist.word_to_count.items(): + counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count + + # lower order N-grams + for n in range(0, self.ngram_order - 1): + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + + n_star_star = 0 + for w in counts_for_hist.word_to_count.keys(): + n_star_star += len(counts_for_hist.word_to_context[w]) + + if n_star_star != 0: + for w in counts_for_hist.word_to_count.keys(): + n_star_z = len(counts_for_hist.word_to_context[w]) + counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star + else: # patterns begin with , they do not have "modified count", so use raw count instead + for w in counts_for_hist.word_to_count.keys(): + n_star_z = counts_for_hist.word_to_count[w] + counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count + + def cal_bow(self): + # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram. + # Thus, two sorts of ngrams do not have a bow: + # 1) highest order ngram + # 2) ngrams ending in + # + # bow(a_) = (1 - Sum_Z1 f(a_z)) / (1 - Sum_Z1 f(_z)) + # Note that Z1 is the set of all words with c(a_z) > 0 + + # highest order N-grams + n = self.ngram_order - 1 + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + counts_for_hist.word_to_bow[w] = None + + # lower order N-grams + for n in range(0, self.ngram_order - 1): + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + if w == self.eos_symbol: + counts_for_hist.word_to_bow[w] = None + else: + a_ = hist + (w,) + + assert len(a_) < self.ngram_order + assert a_ in self.counts[len(a_)].keys() + + a_counts_for_hist = self.counts[len(a_)][a_] + + sum_z1_f_a_z = 0 + for u in a_counts_for_hist.word_to_count.keys(): + sum_z1_f_a_z += a_counts_for_hist.word_to_f[u] + + sum_z1_f_z = 0 + _ = a_[1:] + _counts_for_hist = self.counts[len(_)][_] + for u in a_counts_for_hist.word_to_count.keys(): # Should be careful here: what is Z1 + sum_z1_f_z += _counts_for_hist.word_to_f[u] + + counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z) + + def print_raw_counts(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])) + res.sort(reverse=True) + for r in res: + print(r) + + def print_modified_counts(self, info_string): + # these are useful for debug. 
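        # The "modified count" printed here is the Kneser-Ney continuation
        # count: the number of distinct context words an n-gram was observed
        # after, i.e. len(word_to_context[w]). N-grams for which no context
        # was recorded (the highest-order n-grams and those at the beginning
        # of a line) fall back to their raw count.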
+ print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + modified_count = len(counts_for_hist.word_to_context[w]) + raw_count = counts_for_hist.word_to_count[w] + + if modified_count == 0: + res.append("{0}\t{1}".format(ngram, raw_count)) + else: + res.append("{0}\t{1}".format(ngram, modified_count)) + res.sort(reverse=True) + for r in res: + print(r) + + def print_f(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + f = counts_for_hist.word_to_f[w] + if f == 0: # f() is always 0 + f = 1e-99 + + res.append("{0}\t{1}".format(ngram, math.log(f, 10))) + res.sort(reverse=True) + for r in res: + print(r) + + def print_f_and_bow(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + f = counts_for_hist.word_to_f[w] + if f == 0: # f() is always 0 + f = 1e-99 + + bow = counts_for_hist.word_to_bow[w] + if bow is None: + res.append("{1}\t{0}".format(ngram, math.log(f, 10))) + else: + res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10))) + res.sort(reverse=True) + for r in res: + print(r) + + def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')): + # print as ARPA format. + + print('\\data\\', file=fout) + for hist_len in range(self.ngram_order): + # print the number of n-grams. 
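            # For a bigram LM this loop emits lines like (counts illustrative):
            #     ngram 1=5000
            #     ngram 2=200000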
+ print('ngram {0}={1}'.format( + hist_len + 1, + sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])), + file=fout + ) + + print('', file=fout) + + for hist_len in range(self.ngram_order): + print('\\{0}-grams:'.format(hist_len + 1), file=fout) + + this_order_counts = self.counts[hist_len] + for hist, counts_for_hist in this_order_counts.items(): + for word in counts_for_hist.word_to_count.keys(): + ngram = hist + (word,) + prob = counts_for_hist.word_to_f[word] + bow = counts_for_hist.word_to_bow[word] + + if prob == 0: # f() is always 0 + prob = 1e-99 + + line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram)) + if bow is not None: + line += '\t{0}'.format('%.7f' % math.log10(bow)) + print(line, file=fout) + print('', file=fout) + print('\\end\\', file=fout) + + +if __name__ == "__main__": + + ngram_counts = NgramCounts(args.ngram_order) + + if args.text is None: + ngram_counts.add_raw_counts_from_standard_input() + else: + assert os.path.isfile(args.text) + ngram_counts.add_raw_counts_from_file(args.text) + + ngram_counts.cal_discounting_constants() + ngram_counts.cal_f() + ngram_counts.cal_bow() + + if args.lm is None: + ngram_counts.print_as_arpa() + else: + with open(args.lm, 'w', encoding=default_encoding) as f: + ngram_counts.print_as_arpa(fout=f) diff --git a/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py index 76c7cc08..66ce811b 100755 --- a/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py +++ b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py @@ -70,7 +70,6 @@ from snowfall.models.conformer import Conformer from snowfall.models.contextnet import ContextNet from snowfall.training.ctc_graph import build_ctc_topo -from snowfall.training.mmi_graph import create_bigram_phone_lm from snowfall.training.mmi_graph import get_phone_symbols def nbest_decoding(lats: k2.Fsa, num_paths: int): @@ -434,7 +433,6 @@ def main(): phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt') phone_ids = get_phone_symbols(phone_symbol_table) - P = create_bigram_phone_lm(phone_ids) phone_ids_with_blank = [0] + phone_ids ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank)) @@ -475,8 +473,6 @@ def main(): else: raise NotImplementedError("Model of type " + str(model_type) + " is not implemented") - model.P_scores = torch.nn.Parameter(P.scores.clone(), requires_grad=False) - if avg == 1: checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt') load_checkpoint(checkpoint, model) diff --git a/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py b/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py index f9526481..debd2ff1 100755 --- a/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py +++ b/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py @@ -28,6 +28,7 @@ from lhotse.utils import fix_random_seed, nullcontext from snowfall.common import describe, str2bool +from snowfall.common import find_first_disambig_symbol from snowfall.common import load_checkpoint, save_checkpoint from snowfall.common import save_training_info from snowfall.common import setup_logger @@ -43,13 +44,11 @@ from snowfall.objectives import LFMMILoss, encode_supervisions from snowfall.training.diagnostics import measure_gradient_norms, optim_step_and_measure_param_change from snowfall.training.mmi_graph import MmiTrainingGraphCompiler -from snowfall.training.mmi_graph import create_bigram_phone_lm def get_objf(batch: Dict, model: 
AcousticModel, ali_model: Optional[AcousticModel], - P: k2.Fsa, device: torch.device, graph_compiler: MmiTrainingGraphCompiler, use_pruned_intersect: bool, @@ -74,7 +73,6 @@ def get_objf(batch: Dict, loss_fn = LFMMILoss( graph_compiler=graph_compiler, - P=P, den_scale=den_scale, use_pruned_intersect=use_pruned_intersect ) @@ -92,7 +90,10 @@ def get_objf(batch: Dict, nnet_output, encoder_memory, memory_mask = model(feature, supervisions) if att_rate != 0.0: - att_loss = model.module.decoder_forward(encoder_memory, memory_mask, supervisions, graph_compiler) + if hasattr(model, 'module'): + att_loss = model.module.decoder_forward(encoder_memory, memory_mask, supervisions, graph_compiler) + else: + att_loss = model.decoder_forward(encoder_memory, memory_mask, supervisions, graph_compiler) if (ali_model is not None and global_batch_idx_train is not None and global_batch_idx_train // accum_grad < 4000): @@ -153,7 +154,6 @@ def maybe_log_gradients(tag: str): def get_validation_objf(dataloader: torch.utils.data.DataLoader, model: AcousticModel, ali_model: Optional[AcousticModel], - P: k2.Fsa, device: torch.device, graph_compiler: MmiTrainingGraphCompiler, use_pruned_intersect: bool, @@ -172,7 +172,6 @@ def get_validation_objf(dataloader: torch.utils.data.DataLoader, batch=batch, model=model, ali_model=ali_model, - P=P, device=device, graph_compiler=graph_compiler, use_pruned_intersect=use_pruned_intersect, @@ -192,7 +191,6 @@ def train_one_epoch(dataloader: torch.utils.data.DataLoader, valid_dataloader: torch.utils.data.DataLoader, model: AcousticModel, ali_model: Optional[AcousticModel], - P: k2.Fsa, device: torch.device, graph_compiler: MmiTrainingGraphCompiler, use_pruned_intersect: bool, @@ -213,7 +211,6 @@ def train_one_epoch(dataloader: torch.utils.data.DataLoader, dataloader: Training dataloader valid_dataloader: Validation dataloader model: Acoustic model to be trained - P: An FSA representing the bigram phone LM device: Training device, torch.device("cpu") or torch.device("cuda", device_id) graph_compiler: MMI training graph compiler optimizer: Training optimizer @@ -250,15 +247,10 @@ def train_one_epoch(dataloader: torch.utils.data.DataLoader, timestamp = datetime.now() time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds() - if forward_count == 1 or accum_grad == 1: - P.set_scores_stochastic_(model.module.P_scores) - assert P.requires_grad is True - curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf( batch=batch, model=model, ali_model=ali_model, - P=P, device=device, graph_compiler=graph_compiler, use_pruned_intersect=use_pruned_intersect, @@ -307,7 +299,6 @@ def train_one_epoch(dataloader: torch.utils.data.DataLoader, dataloader=valid_dataloader, model=model, ali_model=ali_model, - P=P, device=device, graph_compiler=graph_compiler, use_pruned_intersect=use_pruned_intersect, @@ -333,7 +324,10 @@ def train_one_epoch(dataloader: torch.utils.data.DataLoader, tb_writer.add_scalar('train/global_valid_average_objf', valid_average_objf, global_batch_idx_train) - model.module.write_tensorboard_diagnostics(tb_writer, global_step=global_batch_idx_train) + if hasattr(model, 'module'): + model.module.write_tensorboard_diagnostics(tb_writer, global_step=global_batch_idx_train) + else: + model.write_tensorboard_diagnostics(tb_writer, global_step=global_batch_idx_train) prev_timestamp = datetime.now() return total_objf / total_frames, valid_average_objf, global_batch_idx_train @@ -485,7 +479,8 @@ def run(rank, world_size, args): use_pruned_intersect = 
args.use_pruned_intersect fix_random_seed(42) - setup_dist(rank, world_size, args.master_port) + if world_size > 1: + setup_dist(rank, world_size, args.master_port) exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer') setup_logger(f'{exp_dir}/log/log-train-{rank}') @@ -502,14 +497,39 @@ def run(rank, world_size, args): device_id = rank device = torch.device('cuda', device_id) + if not Path(lang_dir / 'P.pt').is_file(): + logging.debug(f'Loading P from {lang_dir}/P.fst.txt') + with open(lang_dir / 'P.fst.txt') as f: + # P is not an acceptor because there is + # a back-off state, whose incoming arcs + # have label #0 and aux_label eps. + P = k2.Fsa.from_openfst(f.read(), acceptor=False) + + phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt') + first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table) + + # P.aux_labels is not needed in later computations, so + # remove it here. + del P.aux_labels + # CAUTION(fangjun): The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + P.labels[P.labels >= first_phone_disambig_id] = 0 + + P = k2.remove_epsilon(P) + P = k2.arc_sort(P) + torch.save(P.as_dict(), lang_dir / 'P.pt') + else: + logging.debug('Loading pre-compiled P') + d = torch.load(lang_dir / 'P.pt') + P = k2.Fsa.from_dict(d) + graph_compiler = MmiTrainingGraphCompiler( lexicon=lexicon, + P=P, device=device, ) phone_ids = lexicon.phone_symbols() - P = create_bigram_phone_lm(phone_ids) - P.scores = torch.zeros_like(P.scores) - P = P.to(device) librispeech = LibriSpeechAsrDataModule(args) train_dl = librispeech.train_dataloaders() @@ -557,8 +577,6 @@ def run(rank, world_size, args): else: raise NotImplementedError("Model of type " + str(model_type) + " is not implemented") - model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True) - if args.torchscript: logging.info('Applying TorchScript to model...') model = torch.jit.script(model) @@ -566,7 +584,8 @@ def run(rank, world_size, args): model.to(device) describe(model) - model = DDP(model, device_ids=[rank]) + if world_size > 1: + model = DDP(model, device_ids=[rank]) # Now for the alignment model, if any if args.use_ali_model: @@ -624,7 +643,6 @@ def run(rank, world_size, args): valid_dataloader=valid_dl, model=model, ali_model=ali_model, - P=P, device=device, graph_compiler=graph_compiler, use_pruned_intersect=use_pruned_intersect, @@ -696,8 +714,9 @@ def run(rank, world_size, args): local_rank=rank) logging.warning('Done') - torch.distributed.barrier() - cleanup_dist() + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() def main(): @@ -706,7 +725,10 @@ def main(): args = parser.parse_args() world_size = args.world_size assert world_size >= 1 - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) torch.set_num_threads(1) diff --git a/egs/librispeech/asr/simple_v1/run.sh b/egs/librispeech/asr/simple_v1/run.sh index 1aa18626..4fcd1c4d 100755 --- a/egs/librispeech/asr/simple_v1/run.sh +++ b/egs/librispeech/asr/simple_v1/run.sh @@ -9,6 +9,29 @@ set -eou pipefail stage=0 +libri_dirs=( +/root/fangjun/data/librispeech/LibriSpeech +/export/corpora5/LibriSpeech +/home/storage04/zhuangweiji/data/open-source-data/librispeech/LibriSpeech +/export/common/data/corpora/ASR/openslr/SLR12/LibriSpeech +) + +libri_dir= +for d in ${libri_dirs[@]}; do + if [ -d $d ]; 
then + libri_dir=$d + break + fi +done + +if [ ! -d $libri_dir/train-clean-100 ]; then + echo "Please set LibriSpeech dataset path before running this script" + exit 1 +fi + +echo "LibriSpeech dataset dir: $libri_dir" + + if [ $stage -le 1 ]; then local/download_lm.sh "openslr.org/resources/11" data/local/lm fi @@ -70,6 +93,56 @@ if [ $stage -le 4 ]; then fi if [ $stage -le 5 ]; then + mkdir -p data/local/tmp + if [ ! -f data/local/tmp/transcript.txt ]; then + echo "Generating data/local/tmp/transcript.txt" + files=$( + find "$libri_dir/train-clean-100" -name "*.trans.txt" + find "$libri_dir/train-clean-360" -name "*.trans.txt" + find "$libri_dir/train-other-500" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > data/local/tmp/transcript.txt + fi +fi + +if [ $stage -le 6 ]; then + # this stage takes about 3 minutes + mkdir -p data/lm + if [ ! -f data/lm/P.arpa ]; then + echo "Generating data/lm/P.arpa" + ./local/add_silence_to_transcript.py \ + --transcript data/local/tmp/transcript.txt \ + --sil-word "!SIL" \ + --sil-prob 0.5 \ + --seed 20210629 \ + > data/lm/transcript_with_sil.txt + + ./local/convert_transcript_to_corpus.py \ + --transcript data/lm/transcript_with_sil.txt \ + --lexicon data/local/dict_nosp/lexicon.txt \ + --oov "" \ + > data/lm/corpus.txt + + ./local/make_kn_lm.py \ + -ngram-order 2 \ + -text data/lm/corpus.txt \ + -lm data/lm/P.arpa + fi +fi + +if [ $stage -le 7 ]; then + if [ ! -f data/lang_nosp/P.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="data/lang_nosp/phones.txt" \ + --disambig-symbol='#0' \ + --max-order=2 \ + data/lm/P.arpa > data/lang_nosp/P.fst.txt + fi +fi + +if [ $stage -le 8 ]; then python3 ./prepare.py fi @@ -79,7 +152,7 @@ fi # # exit 0 -if [ $stage -le 6 ]; then +if [ $stage -le 9 ]; then # python3 ./train.py # ctc training # python3 ./mmi_bigram_train.py # ctc training + bigram phone LM # python3 ./mmi_mbr_train.py @@ -99,7 +172,7 @@ if [ $stage -le 6 ]; then # python3 -m torch.distributed.launch --nproc_per_node=$ngpus ./mmi_bigram_train.py --world_size $ngpus fi -if [ $stage -le 7 ]; then +if [ $stage -le 10 ]; then # python3 ./decode.py # ctc decoding # python3 ./mmi_bigram_decode.py --epoch 9 # python3 ./mmi_mbr_decode.py diff --git a/snowfall/common.py b/snowfall/common.py index 40b727b5..55880790 100755 --- a/snowfall/common.py +++ b/snowfall/common.py @@ -88,9 +88,11 @@ def load_checkpoint( src_key = '{}.{}'.format('module', key) dst_state_dict[key] = src_state_dict.pop(src_key) assert len(src_state_dict) == 0 - model.load_state_dict(dst_state_dict) + model.load_state_dict(dst_state_dict, strict=False) else: - model.load_state_dict(checkpoint['state_dict']) + model.load_state_dict(checkpoint['state_dict'], strict=False) + # Note we used strict=False above so that the current code + # can load models trained with P_scores. 
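        # (torch's load_state_dict returns a named tuple of missing and
        # unexpected keys when strict=False; if desired, the ignored entries
        # could be logged, e.g.
        #     result = model.load_state_dict(checkpoint['state_dict'], strict=False)
        #     logging.warning(f'ignored keys: {result.unexpected_keys}')
        # This is only a possible refinement, not part of the change above.)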
model.num_features = checkpoint['num_features'] model.num_classes = checkpoint['num_classes'] @@ -151,9 +153,9 @@ def average_checkpoint(filenames: List[Pathlike], model: AcousticModel) -> Dict[ src_key = '{}.{}'.format('module', key) dst_state_dict[key] = src_state_dict.pop(src_key) assert len(src_state_dict) == 0 - model.load_state_dict(dst_state_dict) + model.load_state_dict(dst_state_dict, strict=False) else: - model.load_state_dict(checkpoint['state_dict']) + model.load_state_dict(checkpoint['state_dict'], strict=False) model.num_features = checkpoint['num_features'] model.num_classes = checkpoint['num_classes'] diff --git a/snowfall/decoding/lm_rescore.py b/snowfall/decoding/lm_rescore.py index 1edef233..529cb1d0 100644 --- a/snowfall/decoding/lm_rescore.py +++ b/snowfall/decoding/lm_rescore.py @@ -311,10 +311,9 @@ def rescore_with_whole_lattice(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa, # scores = (scores - lm_scores)/lm_scale + lm_scores # = scores/lm_scale + lm_scores*(1 - 1/lm_scale) # - saved_scores = inv_lats.scores.clone() + saved_am_scores = inv_lats.scores - inv_lats.lm_scores for lm_scale in lm_scale_list: - am_scores = saved_scores - inv_lats.lm_scores - am_scores /= lm_scale + am_scores = saved_am_scores / lm_scale inv_lats.scores = am_scores + inv_lats.lm_scores best_paths = k2.shortest_path(inv_lats, use_double_scores=True) diff --git a/snowfall/objectives/mmi.py b/snowfall/objectives/mmi.py index 88cd55af..abefe643 100644 --- a/snowfall/objectives/mmi.py +++ b/snowfall/objectives/mmi.py @@ -14,7 +14,6 @@ def _compute_mmi_loss_exact_optimized( texts: List[str], supervision_segments: torch.Tensor, graph_compiler: MmiTrainingGraphCompiler, - P: k2.Fsa, den_scale: float = 1.0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ''' @@ -36,13 +35,10 @@ def _compute_mmi_loss_exact_optimized( A 2-D tensor that will be passed to :func:`k2.DenseFsaVec`. graph_compiler: Used to build num_graphs and den_graphs - P: - Represents a bigram Fsa. den_scale: The scale applied to the denominator tot_scores. ''' num_graphs, den_graphs = graph_compiler.compile(texts, - P, replicate_den=False) dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) @@ -111,7 +107,6 @@ def _compute_mmi_loss_exact_non_optimized( texts: List[str], supervision_segments: torch.Tensor, graph_compiler: MmiTrainingGraphCompiler, - P: k2.Fsa, den_scale: float = 1.0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ''' @@ -124,7 +119,6 @@ def _compute_mmi_loss_exact_non_optimized( It uses less memory at the cost of speed. It is slower. ''' num_graphs, den_graphs = graph_compiler.compile(texts, - P, replicate_den=True) dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) @@ -149,7 +143,6 @@ def _compute_mmi_loss_pruned( texts: List[str], supervision_segments: torch.Tensor, graph_compiler: MmiTrainingGraphCompiler, - P: k2.Fsa, den_scale: float = 1.0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ''' @@ -163,7 +156,6 @@ def _compute_mmi_loss_pruned( to pruning. 
''' num_graphs, den_graphs = graph_compiler.compile(texts, - P, replicate_den=False) dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) @@ -200,13 +192,11 @@ class LFMMILoss(nn.Module): def __init__( self, graph_compiler: MmiTrainingGraphCompiler, - P: k2.Fsa, use_pruned_intersect: bool = False, den_scale: float = 1.0, ): super().__init__() self.graph_compiler = graph_compiler - self.P = P self.den_scale = den_scale self.use_pruned_intersect = use_pruned_intersect @@ -223,5 +213,4 @@ def forward(self, nnet_output: torch.Tensor, texts: List[str], texts=texts, supervision_segments=supervision_segments, graph_compiler=self.graph_compiler, - P=self.P, den_scale=self.den_scale) diff --git a/snowfall/training/mmi_graph.py b/snowfall/training/mmi_graph.py index 758830f7..e2f22b06 100644 --- a/snowfall/training/mmi_graph.py +++ b/snowfall/training/mmi_graph.py @@ -46,6 +46,7 @@ class MmiTrainingGraphCompiler(object): def __init__( self, lexicon: Lexicon, + P: k2.Fsa, device: torch.device, oov: str = '' ): @@ -53,15 +54,19 @@ def __init__( Args: L_inv: Its labels are words, while its aux_labels are phones. - phones: - The phone symbol table. - words: - The word symbol table. - oov: - Out of vocabulary word. + P: + A phone bigram LM if the pronunciations in the lexicon are in phones; + a word piece bigram if the pronunciations in the lexicon are word pieces. + phones: + The phone symbol table. + words: + The word symbol table. + oov: + Out of vocabulary word. ''' self.lexicon = lexicon L_inv = self.lexicon.L_inv.to(device) + P = P.to(device) if L_inv.properties & k2.fsa_properties.ARC_SORTED != 0: L_inv = k2.arc_sort(L_inv) @@ -81,11 +86,20 @@ def __init__( ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device) assert ctc_topo.requires_grad is False - self.ctc_topo_inv = k2.arc_sort(ctc_topo.invert_()) + ctc_topo_inv = k2.arc_sort(ctc_topo.invert_()) + + P_with_self_loops = k2.add_epsilon_self_loops(P) + + ctc_topo_P = k2.intersect(ctc_topo_inv, + P_with_self_loops, + treat_epsilons_specially=False).invert() + + self.ctc_topo_P = k2.arc_sort(ctc_topo_P) + + def compile(self, texts: Iterable[str], - P: k2.Fsa, replicate_den: bool = True) -> Tuple[k2.Fsa, k2.Fsa]: '''Create numerator and denominator graphs from transcripts and the bigram phone LM. @@ -94,8 +108,6 @@ def compile(self, texts: A list of transcripts. Within a transcript, words are separated by spaces. - P: - The bigram phone LM created by :func:`create_bigram_phone_lm`. replicate_den: If True, the returned den_graph is replicated to match the number of FSAs in the returned num_graph; if False, the returned den_graph @@ -110,33 +122,18 @@ def compile(self, shape of the `num_graph` if replicate_den is True; otherwise, it is an FsaVec containing only a single FSA. 
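        Example:
          ctc_topo_P is now pre-computed in the constructor from the CTC
          topology and P, so P is no longer passed to this method. A sketch
          of the intended usage, mirroring mmi_att_transformer_train.py
          (`lexicon`, `device` and `first_phone_disambig_id` are assumed to
          be set up as in that script):

            with open('data/lang_nosp/P.fst.txt') as f:
                P = k2.Fsa.from_openfst(f.read(), acceptor=False)
            del P.aux_labels                                    # not needed here
            P.labels[P.labels >= first_phone_disambig_id] = 0   # map #0 arcs to epsilon
            P = k2.arc_sort(k2.remove_epsilon(P))

            compiler = MmiTrainingGraphCompiler(lexicon=lexicon, P=P,
                                                device=device)
            num_graphs, den_graphs = compiler.compile(['HELLO WORLD'],
                                                      replicate_den=True)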
''' - self_device = str(self.device) - if self_device == 'cuda': - # the compilers graph device does not specify GPU ID, just check that both tensors are on GPU - assert str(P.device).startswith( - 'cuda'), f'Assertion failed: GraphCompiler uses on "cuda", but P is on "{P.device}"' - else: - assert str(P.device) == str(self.device), f'Assertion failed: "{P.device} == {self.device}"' - P_with_self_loops = k2.add_epsilon_self_loops(P) - - ctc_topo_P = k2.intersect(self.ctc_topo_inv, - P_with_self_loops, - treat_epsilons_specially=False).invert() - - ctc_topo_P = k2.arc_sort(ctc_topo_P) - num_graphs = self.build_num_graphs(texts) num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops( num_graphs) num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops) - num = k2.compose(ctc_topo_P, + num = k2.compose(self.ctc_topo_P, num_graphs_with_self_loops, treat_epsilons_specially=False) num = k2.arc_sort(num) - ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()]) + ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P]) if replicate_den: indexes = torch.zeros(len(texts), dtype=torch.int32,