Skip to content

Commit

Permalink
diff LMs
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Dec 16, 2024
1 parent 2f2d75c commit c109457
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
# ...
# trafo-n24-d512-noAbsPos-rmsNorm-ffGated-rope-noBias-drop0-b100_5k
# _lm_name = "trafo-n96-d512-gelu-drop0-b32_1k"
_lm_name = "trafo-n24-d512-noAbsPos-rmsNorm-ffGated-rope-noBias-drop0-b100_5k"
# Language models to run recognition with.
# Key: short tag that is embedded in the registered output name
#   (``recog-lm{key}-lmScale...-priorScale...``).
# Value: full LM name, passed to ``_get_lm_model`` inside ``py()``.
_lms = {
"n24-d512": "trafo-n24-d512-noAbsPos-rmsNorm-ffGated-rope-noBias-drop0-b100_5k",
"n96-d512": "trafo-n96-d512-gelu-drop0-b32_1k",
}


def py():
Expand All @@ -47,22 +50,28 @@ def py():
vocab = "spm10k"
task = get_librispeech_task_raw_v2(vocab=vocab)
ctc_model = _get_ctc_model(_ctc_model_name)
lm = _get_lm_model(_lm_name)
prior = get_ctc_prior_probs(ctc_model, task.train_dataset.copy_train_as_static())
tk.register_output(f"{prefix}/ctc-prior", prior)

for prior_scale, lm_scale in [(1.0, 1.0)]:
model = get_ctc_with_lm(
ctc_model=ctc_model, prior=prior, prior_scale=prior_scale, language_model=lm, lm_scale=lm_scale
)
res = recog_model(
task=task,
model=model,
recog_def=model_recog,
config={"beam_size": 12, "recog_version": 2, "batch_size": 5_000 * ctc_model.definition.batch_size_factor},
search_rqmt={"time": 24},
)
tk.register_output(f"{prefix}/recog-priorScale{prior_scale}-lmScale{lm_scale}", res.output)
for lm_out_name, lm_name in _lms.items():
lm = _get_lm_model(lm_name)

for prior_scale, lm_scale in [(1.0, 1.0)]:
model = get_ctc_with_lm(
ctc_model=ctc_model, prior=prior, prior_scale=prior_scale, language_model=lm, lm_scale=lm_scale
)
res = recog_model(
task=task,
model=model,
recog_def=model_recog,
config={
"beam_size": 12,
"recog_version": 2,
"batch_size": 5_000 * ctc_model.definition.batch_size_factor,
},
search_rqmt={"time": 24},
)
tk.register_output(f"{prefix}/recog-lm{lm_out_name}-lmScale{lm_scale}-priorScale{prior_scale}", res.output)


_sis_prefix: Optional[str] = None
Expand Down

0 comments on commit c109457

Please sign in to comment.