Commit

more debugging
albertz committed Jan 23, 2025
1 parent d8695fa commit a54f9a5
Showing 1 changed file with 144 additions and 2 deletions.
users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py
@@ -3,7 +3,7 @@
"""

from __future__ import annotations
from typing import TypeVar, Optional, Sequence, Tuple, Dict, Generator
from typing import TypeVar, Optional, Any, Sequence, Tuple, Dict, Generator
import re
import functools

@@ -175,6 +175,7 @@ def model_recog(
        seq_label = _gather_backrefs_tree(seq_label, backrefs=backrefs, dim_map=backrefs_dim_map)
        _seq_label_print(f"{t=} gather backrefs", seq_label)

        # Keep the state from before the masked update, for the consistency check below.
        prev = (lm_log_probs, lm_state, lm_scores, seq_label)
        got_new_label_cpu = rf.copy_to_device(got_new_label, "cpu")
        if got_new_label_cpu.raw_tensor.sum().item() > 0:
            (
@@ -230,7 +231,14 @@

            # _seq_label_print("masked scatter", seq_label)

            # TODO debug more... compare scores to when fed directly
            # Cross-check: recompute the update on the full (unmasked) state and verify
            # that combining it with prev under got_new_label reproduces the
            # masked-select/scatter result. _where takes the first argument where the
            # mask is true, so the freshly updated state comes first.
            new_state = _state_update(prev, target=target, beam_dim=beam_dim, model=model, lm=lm, lm_scale=lm_scale)
            _where_deep_check(
                new_state,
                prev,
                (lm_log_probs, lm_state, lm_scores, seq_label),
                mask=got_new_label,
                mask_cpu=got_new_label_cpu,
            )

    # seq_log_prob, lm_log_probs: Batch, Beam
    # Add LM EOS score at the end.
@@ -577,6 +585,140 @@ def _extend_dim_name(name: str) -> str:
# for debugging:


def _state_update(
    prev: Any,
    *,
    target: Tensor,
    beam_dim: Dim,
    model: Model,
    lm: TransformerDecoder,
    lm_scale: float,
) -> Tuple[Tensor, Any, Tensor, Any]:
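    """
    Apply one LM/label update step to the previous state, without any masking.
    This mirrors the masked-update branch in model_recog ("fed directly"),
    so the two paths can be compared.

    :return: updated (lm_log_probs, lm_state, lm_scores, seq_label) for ``target``
    """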
    from returnn.tensor import batch_dim

    lm_log_probs, lm_state, lm_scores, seq_label = prev

    lm_log_probs = rf.gather(lm_log_probs, axis=model.target_dim, indices=target)  # Batch, Beam
    assert lm_scores.dims_set == lm_log_probs.dims_set == target.dims_set == {batch_dim, beam_dim}
    lm_scores = lm_scores + lm_log_probs  # Batch, Beam

    lm_logits, lm_state = lm(
        target,
        spatial_dim=single_step_dim,
        state=lm_state,
    )  # Batch, Beam, Vocab / ...
    lm_log_probs = rf.log_softmax(lm_logits, axis=model.target_dim)  # Batch, Beam, Vocab
    lm_log_probs *= lm_scale

    seq_label = _seq_label_append(seq_label, target)

    return lm_log_probs, lm_state, lm_scores, seq_label


def _where_deep_check(a: Any, b: Any, ref_result: Any, *, mask: Tensor, mask_cpu: Tensor):
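    """
    Compute elementwise ``where(mask, a, b)`` over two nested structures
    (taking ``a`` where the mask is true), then check that the result equals
    ``ref_result``. Exits with an error if any element differs.
    """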
    import tree

    tree.assert_same_structure(a, b)
    tree.assert_same_structure(a, ref_result)

    dim_map = {}
    tree.map_structure(functools.partial(_where_prepare_dims, mask=mask_cpu, dim_map=dim_map), a, b)
    res = tree.map_structure(functools.partial(_where, mask=mask, mask_cpu=mask_cpu, dim_map=dim_map), a, b)

    check_dim_map = {}
    tree.map_structure(functools.partial(_where_res_check_equal_prepare_dims, dim_map=check_dim_map), res, ref_result)
    equal = tree.map_structure_with_path(
        functools.partial(_where_res_check_equal, dim_map=check_dim_map), res, ref_result
    )
    if all(tree.flatten(equal)):
        return res
    print("** Error, some elements are not equal:", equal)
    raise SystemExit(1)


def _where_prepare_dims(a: Any, b: Any, *, mask: Tensor, dim_map: Dict[Dim, Dim]):
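    """
    For any dim that differs between ``a`` and ``b`` (e.g. a grown history dim),
    create a combined dim whose size is the elementwise ``where(mask, a_size, b_size)``,
    and register it in ``dim_map`` for both the ``a`` and ``b`` dim.
    """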
    if isinstance(a, Dim):
        assert isinstance(b, Dim)
        if a == b:
            return a
        if a in dim_map:
            return dim_map[a]
        assert b not in dim_map
        a_size = a.get_size_tensor()
        b_size = b.get_size_tensor()
        res_size = rf.where(mask, a_size, b_size, allow_broadcast_all_sources=True)
        res_dim = Dim(res_size, name=_extend_dim_name(b.name))
        dim_map[a] = res_dim
        dim_map[b] = res_dim
        return res_dim
    if isinstance(a, Tensor):
        assert isinstance(b, Tensor)
        return a  # ignored at this stage
    raise TypeError(f"_where_prepare_dims: unexpected type ({type(a)}, {type(b)})")


def _where(a: Any, b: Any, *, mask: Tensor, mask_cpu: Tensor, dim_map: Dict[Dim, Dim]):
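    """
    ``rf.where(mask, a, b)`` for a single leaf (Dim or Tensor), first expanding
    any dims that were combined in :func:`_where_prepare_dims` so that both
    sides have matching dims.
    """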
    if isinstance(a, Dim):
        assert isinstance(b, Dim)
        if a == b:
            return a
        assert a in dim_map
        assert b in dim_map
        return dim_map[a]
    if isinstance(a, Tensor):
        assert isinstance(b, Tensor)
        for d in a.dims:
            if d in dim_map:
                a = _expand_slice(a, old_dim=d, new_dim=dim_map[d])
        for d in b.dims:
            if d in dim_map:
                b = _expand_slice(b, old_dim=d, new_dim=dim_map[d])
        assert a.dims_set == b.dims_set
        if a.device == "cpu":
            mask = mask_cpu
        return rf.where(mask, a, b, allow_broadcast_all_sources=True)
    raise TypeError(f"_where: unexpected type ({type(a)}, {type(b)})")


def _where_res_check_equal_prepare_dims(a: Any, b: Any, *, dim_map: Dict[Dim, Dim]):
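    """Collect the mapping between dims of the computed result and the reference result."""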
    if isinstance(a, Dim):
        assert isinstance(b, Dim)
        if a != b:
            dim_map[a] = b


def _where_res_check_equal(path: Tuple[Any, ...], a: Any, b: Any, *, dim_map: Dict[Dim, Dim]):
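    """
    Check whether leaf ``a`` (computed result) equals leaf ``b`` (reference),
    up to the dim mapping. Prints a message and returns False on mismatch.
    """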
    import torch

    if isinstance(a, Dim):
        assert isinstance(b, Dim)
        if a == b:
            return True
        assert dim_map[a] == b
        return _where_res_check_equal(path + ("size",), a.get_size_tensor(), b.get_size_tensor(), dim_map=dim_map)
    if isinstance(a, Tensor):
        assert isinstance(b, Tensor)
        for d in a.dims:
            if d in dim_map:
                d_ = dim_map[d]
                if not _where_res_check_equal(path + (d,), d, d_, dim_map=dim_map):
                    return False
                a, _ = rf.replace_dim(a, in_dim=d, out_dim=d_)
        assert a.dims_set == b.dims_set
        a = a.copy_transpose(b.dims)
        a = a.copy_masked(0)
        b = b.copy_masked(0)
        a = rf.copy_to_device(a, "cpu")
        b = rf.copy_to_device(b, "cpu")
        try:
            torch.testing.assert_close(a.raw_tensor, b.raw_tensor)
        except AssertionError as exc:
            print(f"** Error in {path} ({a} vs {b}):\n{exc}")
            return False
        else:
            return True
    raise TypeError(f"_where_res_check_equal: unexpected type ({type(a)}, {type(b)})")


def _seq_label_history_init_state(*, vocab_dim: Dim, batch_dims: Sequence[Dim]) -> rf.State:
    hist_dim = Dim(0, name="hist0")
    history = rf.zeros(list(batch_dims) + [hist_dim], dtype="int64", sparse_dim=vocab_dim)
