Commit: updates for domain test work

JackTemaki committed Dec 19, 2024
1 parent a4bb413 commit c054221

Showing 4 changed files with 349 additions and 94 deletions.
[File 1 of 4]

@@ -1,18 +1,20 @@
from sisyphus import tk

from i6_core.lexicon.modification import WriteLexiconJob

from i6_experiments.common.helpers.g2p import G2PBasedOovAugmenter
from i6_experiments.common.datasets.librispeech.lexicon import get_bliss_lexicon
from i6_experiments.common.datasets.librispeech.lexicon import get_bliss_lexicon, _get_special_lemma_lexicon
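# note: _get_special_lemma_lexicon is module-private; it is imported here to build a minimal base lexicon without LibriSpeech vocabulary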
from i6_experiments.users.rossenbach.corpus.generate import CreateBlissFromTextLinesJob


from ..tts.tts_phon import get_lexicon


def create_data_lexicon(prefix: str, lm_text_bliss: tk.Path):
def create_data_lexicon(prefix: str, lexicon_bliss: tk.Path):
"""
:param prefix: alias/output path prefix
:param lm_text_bliss:
:param lexicon_bliss: bliss corpus with the words that the lexicon has to cover
:return: g2p-augmented bliss lexicon
"""
ls960_tts_lexicon = get_lexicon(with_blank=False, corpus_key="train-other-960")
@@ -23,7 +25,28 @@ def create_data_lexicon(prefix: str, lm_text_bliss: tk.Path):
apply_args={"concurrent": 5}
)
extended_bliss_lexicon = g2p_augmenter.get_g2p_augmented_bliss_lexicon(
bliss_corpus=lm_text_bliss,
bliss_corpus=lexicon_bliss,
corpus_name="lm_tts_data",
alias_path=prefix,
casing="upper",
)
return extended_bliss_lexicon

def create_data_lexicon_v2(prefix: str, lexicon_bliss: tk.Path):
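"""
Like create_data_lexicon, but the G2P base lexicon is only the static
special-lemma lexicon, without any LibriSpeech vocabulary.

:param prefix: alias/output path prefix
:param lexicon_bliss: bliss corpus with the words that the lexicon has to cover
:return: g2p-augmented bliss lexicon
"""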
ls960_tts_lexicon = get_lexicon(with_blank=False, corpus_key="train-other-960")

static_lexicon = _get_special_lemma_lexicon(
add_unknown_phoneme_and_mapping=False,
add_silence=False,
)
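# write the in-memory static lexicon to a bliss XML file so it can serve as the base lexicon for the G2P augmentation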
static_lexicon_job = WriteLexiconJob(static_lexicon, sort_phonemes=True, sort_lemmata=False)
g2p_augmenter = G2PBasedOovAugmenter(
original_bliss_lexicon=static_lexicon_job.out_bliss_lexicon,
train_lexicon=ls960_tts_lexicon,
apply_args={"concurrent": 5}
)
extended_bliss_lexicon = g2p_augmenter.get_g2p_augmented_bliss_lexicon(
bliss_corpus=lexicon_bliss,
corpus_name="lm_tts_data",
alias_path=prefix,
casing="upper",
@@ -33,6 +56,7 @@ def create_data_lexicon(prefix: str, lm_text_bliss: tk.Path):

def create_data_lexicon_rasr_style(prefix: str, lm_text_bliss: tk.Path, with_unknown: bool):
"""
(base lexicon: the full LibriSpeech RASR lexicon)
:param prefix: alias/output path prefix
:param lm_text_bliss: bliss corpus with the words that the lexicon has to cover
@@ -42,7 +66,7 @@ def create_data_lexicon_rasr_style(prefix: str, lm_text_bliss: tk.Path, with_unk
ls960_tts_lexicon = get_lexicon(with_blank=False, corpus_key="train-other-960")

ls960_rasr_lexicon = get_bliss_lexicon(
use_stress_marker=False, add_unknown_phoneme_and_mapping=with_unknown, add_silence=True
use_stress_marker=False, add_unknown_phoneme_and_mapping=with_unknown, add_silence=True, output_prefix=prefix
)

g2p_augmenter = G2PBasedOovAugmenter(
@@ -58,6 +82,36 @@
)
return extended_bliss_lexicon

def create_data_lexicon_rasr_style_v2(prefix: str, lm_text_bliss: tk.Path, with_unknown: bool):
"""
(pure variant: no LibriSpeech vocabulary, only the static special-lemma lexicon as base)
:param prefix: alias/output path prefix
:param lm_text_bliss: bliss corpus with the words that the lexicon has to cover
:param with_unknown: add the unknown phoneme and lemma mapping to the base lexicon
:return: g2p-augmented bliss lexicon
"""

ls960_tts_lexicon = get_lexicon(with_blank=False, corpus_key="train-other-960")

static_lexicon = _get_special_lemma_lexicon(
add_unknown_phoneme_and_mapping=with_unknown,
add_silence=True,
)
static_lexicon_job = WriteLexiconJob(static_lexicon, sort_phonemes=True, sort_lemmata=False)

g2p_augmenter = G2PBasedOovAugmenter(
original_bliss_lexicon=static_lexicon_job.out_bliss_lexicon,
train_lexicon=ls960_tts_lexicon,
apply_args={"concurrent": 5}
)
extended_bliss_lexicon = g2p_augmenter.get_g2p_augmented_bliss_lexicon(
bliss_corpus=lm_text_bliss,
corpus_name="lm_tts_data",
alias_path=prefix,
casing="upper",
)
return extended_bliss_lexicon


def bliss_from_text(prefix, name, lm_text):
"""
⋮
[File 2 of 4]

@@ -6,6 +6,7 @@
from typing import cast, Optional

from i6_core.tools.parameter_tuning import GetOptimalParametersAsVariableJob
from i6_core.report.report import GenerateReportStringJob

from i6_experiments.common.setups.returnn.datastreams.vocabulary import LabelDatastream

@@ -15,7 +16,18 @@
from i6_experiments.users.rossenbach.experiments.librispeech.ctc_rnnt_standalone_2024.lm import get_4gram_binary_lm
from i6_experiments.users.rossenbach.experiments.librispeech.ctc_rnnt_standalone_2024.pipeline import training, prepare_asr_model, search, ASRModel
from i6_experiments.users.rossenbach.experiments.librispeech.ctc_rnnt_standalone_2024.storage import get_ctc_model, get_synthetic_data, get_rnnt_model, get_aed_model
from i6_experiments.users.rossenbach.experiments.librispeech.ctc_rnnt_standalone_2024.report import tune_and_evalue_report

def report_template(report_values):
from i6_core.util import instanciate_delayed
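# instanciate_delayed resolves sisyphus Variable/Delayed objects in report_values to their computed values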
report_values = instanciate_delayed(report_values)

string = f"Results for {report_values['corpus_name']}:\n"
string += f"Best LM: {report_values['best_lm']}\n"
string += f"Best Prior: {report_values['best_prior']}\n\n"

string += f"Final WER: {report_values['best_wer']}\n"

return string


def bpe_ls960_0924_relposencoder(lex, lm):
@@ -35,6 +47,7 @@ def ctc_tune_and_evaluate_helper(training_name, dev_dataset_tuples, test_dataset
lm_scales, prior_scales):
tune_parameters = []
report_values = {}
tune_values = []
for lm_weight in lm_scales:
for prior_scale in prior_scales:
decoder_config = copy.deepcopy(base_decoder_config)
@@ -51,6 +64,19 @@ def ctc_tune_and_evaluate_helper(training_name, dev_dataset_tuples, test_dataset
**default_returnn
)
tune_parameters.append((lm_weight, prior_scale))
tune_values.append(list(wers.values())[0])
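# pick the (lm_weight, prior_scale) pair with the lowest dev WER; the outputs are sisyphus Variables usable in the report below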
pick_optimal_params_job = GetOptimalParametersAsVariableJob(parameters=tune_parameters, values=tune_values, mode="minimize")
dev_name = list(dev_dataset_tuples.keys())[0]
report_values = {
"corpus_name": dev_name,
"best_lm": pick_optimal_params_job.out_optimal_parameters[0],
"best_prior": pick_optimal_params_job.out_optimal_parameters[1],
"best_wer": pick_optimal_params_job.out_optimal_value
}
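# render the report text via report_template; compress=False presumably keeps the output as plain (non-gzipped) text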
report = GenerateReportStringJob(report_values=report_values, report_template=report_template,
compress=False).out_report
tk.register_output(training_name + "/%s_report.txt" % dev_name, report)


def rnnt_evaluate_helper(
training_name: str,
@@ -117,10 +143,13 @@ def aed_evaluate_helper(
_, med_wmt22_n2_noise03_oggzip = get_synthetic_data("wmt22_medline_v1_sequiturg2p_glowtts460_noise03")
# (dataset, bliss)
ddt_medline_wmt22_noise07 = {"medline_wmt22_n2": (build_test_dataset_from_zip(med_wmt22_n2_oggzip, bpe_ctc_asr_model.settings), med_wmt22_n2_bliss)}
dev_dataset_tuples_noise03 = {"medline_wmt22_n2_noise03": (build_test_dataset_from_zip(med_wmt22_n2_noise03_oggzip, bpe_ctc_asr_model.settings), med_wmt22_n2_bliss)}
ddt_medline_wmt22_noise03 = {"medline_wmt22_n2_noise03": (build_test_dataset_from_zip(med_wmt22_n2_noise03_oggzip, bpe_ctc_asr_model.settings), med_wmt22_n2_bliss)}


bpe_lexicon = build_custom_bpe_lexicon(lex, bpe_ctc_asr_model.label_datastream.codes, bpe_ctc_asr_model.label_datastream.vocab)
bpe_lexicon = {
name: build_custom_bpe_lexicon(lex, bpe_ctc_asr_model.label_datastream.codes, bpe_ctc_asr_model.label_datastream.vocab)
for name, lex in lex.items()
}
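# `lex` is now a dict of bliss lexica, one per LM/lexicon variant; build a matching BPE word lexicon for each (re-using the name `lex` inside the comprehension is safe in Python 3)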
default_decoder_config_bpe = DecoderConfig(
lexicon=bpe_ctc_asr_model.lexicon,
returnn_vocab=bpe_ctc_asr_model.returnn_vocab,
@@ -152,45 +181,82 @@ def aed_evaluate_helper(
lm_scales=[1.6, 1.8, 2.0, 2.2], prior_scales=[0.1, 0.2, 0.3, 0.4]
)

# decoding with changed LM and updated lexicon
ufal_lm_config = DecoderConfig(
lexicon=bpe_lexicon,
returnn_vocab=bpe_ctc_asr_model.returnn_vocab,
beam_size=1024,
beam_size_token=16, # makes it much faster
arpa_lm=lm["ufal_v1_mixlex_v2"],
beam_threshold=14,
)

ctc_tune_and_evaluate_helper(
prefix_name + "/medline_wmt22_ende_n2_ufal_lm_mixlex",
ddt_medline_wmt22_noise07, {}, bpe_ctc_asr_model, ufal_lm_config,
lm_scales=[1.6, 1.8, 2.0, 2.2], prior_scales=[0.1, 0.2, 0.3, 0.4]
)
for lex_lm_key in ["ufal_v1_mixlex_v2", "ufal_v1_3more_only"]:
# decoding with changed LM and updated lexicon
ufal_lm_config = DecoderConfig(
lexicon=bpe_lexicon[lex_lm_key],
returnn_vocab=bpe_ctc_asr_model.returnn_vocab,
beam_size=1024,
beam_size_token=16, # makes it much faster
arpa_lm=lm[lex_lm_key],
beam_threshold=14,
)

ctc_tune_and_evaluate_helper(
prefix_name + "/medline_wmt22_ende_n2_noise03_ufal_lm_mixlex",
dev_dataset_tuples_noise03, {}, bpe_ctc_asr_model, ufal_lm_config,
lm_scales=[1.6, 1.8, 2.0, 2.2], prior_scales=[0.1, 0.2, 0.3, 0.4]
)
ctc_tune_and_evaluate_helper(
prefix_name + f"/medline_wmt22_ende_n2_noise07_{lex_lm_key}",
ddt_medline_wmt22_noise07, {}, bpe_ctc_asr_model, ufal_lm_config,
lm_scales=[1.6, 1.8, 2.0, 2.2], prior_scales=[0.1, 0.2, 0.3, 0.4]
)

ctc_tune_and_evaluate_helper(
prefix_name + f"/medline_wmt22_ende_n2_noise03_{lex_lm_key}",
ddt_medline_wmt22_noise03, {}, bpe_ctc_asr_model, ufal_lm_config,
lm_scales=[1.6, 1.8, 2.0, 2.2], prior_scales=[0.1, 0.2, 0.3, 0.4]
)

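# same decoder settings and LM, but with the "_nols" lexicon variant (a "{lex_lm_key}_nols" entry is assumed to exist in the lexicon dict)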
ufal_lm_config_nols = DecoderConfig(
lexicon=bpe_lexicon[lex_lm_key + "_nols"],
returnn_vocab=bpe_ctc_asr_model.returnn_vocab,
beam_size=1024,
beam_size_token=16, # makes it much faster
arpa_lm=lm[lex_lm_key],
beam_threshold=14,
)

ctc_tune_and_evaluate_helper(
prefix_name + f"/medline_wmt22_ende_n2_noise07_{lex_lm_key}_nols",
ddt_medline_wmt22_noise07, {}, bpe_ctc_asr_model, ufal_lm_config_nols,
lm_scales=[1.6, 1.8, 2.0, 2.2], prior_scales=[0.1, 0.2, 0.3, 0.4]
)


# dev other reference
dev_other_bliss, dev_other_oggzip = get_synthetic_data("dev-other_sequiturg2p_glowtts460_noise07")
_, dev_other_noise06_oggzip = get_synthetic_data("dev-other_sequiturg2p_glowtts460_noise06")
_, dev_other_noise055_oggzip = get_synthetic_data("dev-other_sequiturg2p_glowtts460_noise055")
_, dev_other_noise05_oggzip = get_synthetic_data("dev-other_sequiturg2p_glowtts460_noise05")
_, dev_other_noise03_oggzip = get_synthetic_data("dev-other_sequiturg2p_glowtts460_noise03")
# (dataset, bliss)
ddt_dev_other_noise07 = {"dev_other": (build_test_dataset_from_zip(dev_other_oggzip, bpe_ctc_asr_model.settings), dev_other_bliss)}
dev_dataset_tuples_noise03 = {"dev_other": (build_test_dataset_from_zip(dev_other_noise03_oggzip, bpe_ctc_asr_model.settings), dev_other_bliss)}
ddt_dev_other_noise06 = {"dev_other": (build_test_dataset_from_zip(dev_other_noise06_oggzip, bpe_ctc_asr_model.settings), dev_other_bliss)}
ddt_dev_other_noise055 = {"dev_other": (build_test_dataset_from_zip(dev_other_noise055_oggzip, bpe_ctc_asr_model.settings), dev_other_bliss)}
ddt_dev_other_noise05 = {"dev_other": (build_test_dataset_from_zip(dev_other_noise05_oggzip, bpe_ctc_asr_model.settings), dev_other_bliss)}
ddt_dev_other_noise03 = {"dev_other": (build_test_dataset_from_zip(dev_other_noise03_oggzip, bpe_ctc_asr_model.settings), dev_other_bliss)}

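# tune and evaluate every noise-scale variant of the synthetic dev-other data with the default BPE decoder config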
ctc_tune_and_evaluate_helper(
prefix_name + "/dev_other",
ddt_dev_other_noise07, {}, bpe_ctc_asr_model, default_decoder_config_bpe,
lm_scales=[1.6, 1.8, 2.0, 2.2], prior_scales=[0.1, 0.2, 0.3, 0.4]
)

ctc_tune_and_evaluate_helper(
prefix_name + "/dev_other_noise06",
ddt_dev_other_noise06, {}, bpe_ctc_asr_model, default_decoder_config_bpe,
lm_scales=[1.6, 1.8, 2.0, 2.2], prior_scales=[0.1, 0.2, 0.3, 0.4]
)
ctc_tune_and_evaluate_helper(
prefix_name + "/dev_other_noise055",
ddt_dev_other_noise055, {}, bpe_ctc_asr_model, default_decoder_config_bpe,
lm_scales=[1.6, 1.8, 2.0, 2.2], prior_scales=[0.1, 0.2, 0.3, 0.4]
)
ctc_tune_and_evaluate_helper(
prefix_name + "/dev_other_noise05",
ddt_dev_other_noise05, {}, bpe_ctc_asr_model, default_decoder_config_bpe,
lm_scales=[1.6, 1.8, 2.0, 2.2], prior_scales=[0.1, 0.2, 0.3, 0.4]
)
ctc_tune_and_evaluate_helper(
prefix_name + "/dev_other_noise03",
dev_dataset_tuples_noise03, {}, bpe_ctc_asr_model, default_decoder_config_bpe,
ddt_dev_other_noise03, {}, bpe_ctc_asr_model, default_decoder_config_bpe,
lm_scales=[1.6, 1.8, 2.0, 2.2], prior_scales=[0.1, 0.2, 0.3, 0.4]
)

⋮
[remaining 2 changed files not shown]
