Commit

fix
albertz committed Dec 19, 2024
1 parent 299c0c4 commit a0608cb
Showing 1 changed file with 11 additions and 4 deletions.
users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py: 15 changes (11 additions & 4 deletions)
@@ -72,7 +72,7 @@ def py():
         recog_def=model_recog,
         config={
             "beam_size": beam_size,
-            "recog_version": 5,
+            "recog_version": 6,
             "batch_size": 5_000 * ctc_model.definition.batch_size_factor,
         },
         search_rqmt={"time": 24},
@@ -257,7 +257,7 @@ def model_recog(
     config = get_global_config()
     beam_size = config.int("beam_size", 12)
     version = config.int("recog_version", 1)
-    assert version == 5
+    assert version == 6
 
     batch_dims = data.remaining_dims((data_spatial_dim, data.feature_dim))
     logits, enc, enc_spatial_dim = model(data, in_spatial_dim=data_spatial_dim)
@@ -277,13 +277,20 @@
     )
     label_log_prob_ta = TensorArray.unstack(label_log_prob, axis=enc_spatial_dim)  # t -> Batch, VocabWB
 
-    lm_log_probs = rf.constant(0.0, dims=batch_dims_ + [model.target_dim])  # Batch, InBeam, Vocab
-    lm_state = model.lm.default_initial_state(batch_dims=batch_dims_)  # Batch, InBeam, ...
     target = rf.constant(model.bos_idx, dims=batch_dims_, sparse_dim=model.target_dim)  # Batch, InBeam -> Vocab
     target_wb = rf.constant(
         model.blank_idx, dims=batch_dims_, sparse_dim=model.wb_target_dim
     )  # Batch, InBeam -> VocabWB
 
+    lm_state = model.lm.default_initial_state(batch_dims=batch_dims_)  # Batch, InBeam, ...
+    lm_logits, lm_state = model.lm(
+        target,
+        spatial_dim=single_step_dim,
+        state=lm_state,
+    )  # Batch, InBeam, Vocab / ...
+    lm_log_probs = rf.log_softmax(lm_logits, axis=model.target_dim)  # Flat_Batch_Beam, Vocab
+    lm_log_probs *= model.lm_scale
+
     max_seq_len = int(enc_spatial_dim.get_dim_value())
     seq_targets_wb = []
     seq_backrefs = []
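For readers following the hunk above: the commit replaces the all-zero initial LM scores with one LM call on the BOS token before the decoding loop, so the very first emitted label already carries a scaled LM prior. Below is a minimal, self-contained sketch of that pattern in plain PyTorch; TinyLM, bos_idx, and lm_scale are illustrative stand-ins, not the RETURNN-frontend objects used in ctc_recog_ext.py.

import torch


class TinyLM(torch.nn.Module):
    """Illustrative label-LM stand-in: embedding followed by a linear output layer."""

    def __init__(self, vocab_size: int, hidden: int = 32):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.out = torch.nn.Linear(hidden, vocab_size)

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        # labels: [batch] int64 -> logits: [batch, vocab]
        return self.out(self.embed(labels))


vocab_size, bos_idx, lm_scale, batch = 10, 0, 0.5, 3
lm = TinyLM(vocab_size)

# Before this commit: LM log-probs start at zero, so the LM only affects
# the search from the second label onward.
lm_log_probs_old = torch.zeros(batch, vocab_size)

# After this commit: run the LM once on BOS before the loop, so the first
# label choice is already scored by the (scaled) LM.
bos = torch.full((batch,), bos_idx, dtype=torch.long)
lm_log_probs_new = torch.log_softmax(lm(bos), dim=-1) * lm_scale

print(lm_log_probs_old[0, :3])
print(lm_log_probs_new[0, :3])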
