Skip to content

Commit

Permalink
small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Dec 18, 2024
1 parent b26c1ca commit a4bb413
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def model_recog(
dims=batch_dims + [beam_dim],
in_dim=packed_new_label_dim,
dim_map=packed_new_label_dim_map,
)
) # Batch, InBeam, Vocab / ...

seq_log_prob = seq_log_prob + label_log_prob_ta[t] # Batch, InBeam, VocabWB

Expand Down Expand Up @@ -343,6 +343,7 @@ def model_recog(
seq_targets_wb.append(target_wb)
seq_backrefs.append(backrefs)

lm_log_probs = rf.gather(lm_log_probs, indices=backrefs) # Batch, Beam, Vocab
lm_state = tree.map_structure(functools.partial(_gather_backrefs, backrefs=backrefs), lm_state)
prev_target = rf.gather(prev_target, indices=backrefs) # Batch, Beam -> Vocab
prev_target_wb = rf.gather(prev_target_wb, indices=backrefs) # Batch, Beam -> VocabWB
Expand Down

0 comments on commit a4bb413

Please sign in to comment.