From c109457a26b30d812325d87fd3f5bea70001e252 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 16 Dec 2024 17:11:43 +0100 Subject: [PATCH] diff LMs --- .../exp2024_04_23_baselines/ctc_recog_ext.py | 37 ++++++++++++------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py index c69e0c087..0fe851913 100644 --- a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py +++ b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py @@ -32,7 +32,10 @@ # ... # trafo-n24-d512-noAbsPos-rmsNorm-ffGated-rope-noBias-drop0-b100_5k # _lm_name = "trafo-n96-d512-gelu-drop0-b32_1k" -_lm_name = "trafo-n24-d512-noAbsPos-rmsNorm-ffGated-rope-noBias-drop0-b100_5k" +_lms = { + "n24-d512": "trafo-n24-d512-noAbsPos-rmsNorm-ffGated-rope-noBias-drop0-b100_5k", + "n96-d512": "trafo-n96-d512-gelu-drop0-b32_1k", +} def py(): @@ -47,22 +50,28 @@ def py(): vocab = "spm10k" task = get_librispeech_task_raw_v2(vocab=vocab) ctc_model = _get_ctc_model(_ctc_model_name) - lm = _get_lm_model(_lm_name) prior = get_ctc_prior_probs(ctc_model, task.train_dataset.copy_train_as_static()) tk.register_output(f"{prefix}/ctc-prior", prior) - for prior_scale, lm_scale in [(1.0, 1.0)]: - model = get_ctc_with_lm( - ctc_model=ctc_model, prior=prior, prior_scale=prior_scale, language_model=lm, lm_scale=lm_scale - ) - res = recog_model( - task=task, - model=model, - recog_def=model_recog, - config={"beam_size": 12, "recog_version": 2, "batch_size": 5_000 * ctc_model.definition.batch_size_factor}, - search_rqmt={"time": 24}, - ) - tk.register_output(f"{prefix}/recog-priorScale{prior_scale}-lmScale{lm_scale}", res.output) + for lm_out_name, lm_name in _lms.items(): + lm = _get_lm_model(lm_name) + + for prior_scale, lm_scale in [(1.0, 1.0)]: + model = get_ctc_with_lm( + ctc_model=ctc_model, prior=prior, prior_scale=prior_scale, language_model=lm, lm_scale=lm_scale + ) + res = recog_model( + task=task, + model=model, + recog_def=model_recog, + config={ + "beam_size": 12, + "recog_version": 2, + "batch_size": 5_000 * ctc_model.definition.batch_size_factor, + }, + search_rqmt={"time": 24}, + ) + tk.register_output(f"{prefix}/recog-lm{lm_out_name}-lmScale{lm_scale}-priorScale{prior_scale}", res.output) _sis_prefix: Optional[str] = None