From a54f9a516038fdda868aa6945ac44ccf5b3e4444 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Fri, 24 Jan 2025 00:36:37 +0100
Subject: [PATCH] more debugging

---
 .../exp2024_04_23_baselines/recog_ext/ctc.py | 146 +++++++++++++++++-
 1 file changed, 144 insertions(+), 2 deletions(-)

diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py b/users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py
index 81cb2ff4a..d37e7493e 100644
--- a/users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py
+++ b/users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py
@@ -3,7 +3,7 @@
 """
 
 from __future__ import annotations
-from typing import TypeVar, Optional, Sequence, Tuple, Dict, Generator
+from typing import TypeVar, Optional, Any, Sequence, Tuple, Dict, Generator
 import re
 import functools
 
@@ -175,6 +175,7 @@ def model_recog(
         seq_label = _gather_backrefs_tree(seq_label, backrefs=backrefs, dim_map=backrefs_dim_map)
         _seq_label_print(f"{t=} gather backrefs", seq_label)
 
+        prev = (lm_log_probs, lm_state, lm_scores, seq_label)
         got_new_label_cpu = rf.copy_to_device(got_new_label, "cpu")
         if got_new_label_cpu.raw_tensor.sum().item() > 0:
             (
@@ -230,7 +231,14 @@
 
             # _seq_label_print("masked scatter", seq_label)
 
-        # TODO debug more... compare scores to when fed directly
+        new_state = _state_update(prev, target=target, beam_dim=beam_dim, model=model, lm=lm, lm_scale=lm_scale)
+        _where_deep_check(
+            prev,
+            new_state,
+            (lm_log_probs, lm_state, lm_scores, seq_label),
+            mask=got_new_label,
+            mask_cpu=got_new_label_cpu,
+        )
 
     # seq_log_prob, lm_log_probs: Batch, Beam
     # Add LM EOS score at the end.
@@ -577,6 +585,140 @@ def _extend_dim_name(name: str) -> str:
 
 
 # for debugging:
+def _state_update(
+    prev: Any,
+    *,
+    target: Tensor,
+    beam_dim: Dim,
+    model: Model,
+    lm: TransformerDecoder,
+    lm_scale: float,
+) -> Tuple[Tensor, Any, Tensor, Any]:
+    from returnn.tensor import batch_dim
+
+    lm_log_probs, lm_state, lm_scores, seq_label = prev
+
+    lm_log_probs = rf.gather(lm_log_probs, axis=model.target_dim, indices=target)  # Batch, Beam
+    assert lm_scores.dims_set == lm_log_probs.dims_set == target.dims_set == {batch_dim, beam_dim}
+    lm_scores = lm_scores + lm_log_probs  # Batch, Beam
+
+    lm_logits, lm_state = lm(
+        target,
+        spatial_dim=single_step_dim,
+        state=lm_state,
+    )  # Batch, Beam, Vocab / ...
+    lm_log_probs = rf.log_softmax(lm_logits, axis=model.target_dim)  # Batch, Beam, Vocab
+    lm_log_probs *= lm_scale
+
+    seq_label = _seq_label_append(seq_label, target)
+
+    return lm_log_probs, lm_state, lm_scores, seq_label
+
+
+def _where_deep_check(a: Any, b: Any, ref_result: Any, *, mask: Tensor, mask_cpu: Tensor):
+    import tree
+
+    tree.assert_same_structure(a, b)
+    tree.assert_same_structure(a, ref_result)
+
+    dim_map = {}
+    tree.map_structure(functools.partial(_where_prepare_dims, mask=mask_cpu, dim_map=dim_map), a, b)
+    res = tree.map_structure(functools.partial(_where, mask=mask, mask_cpu=mask_cpu, dim_map=dim_map), a, b)
+
+    check_dim_map = {}
+    tree.map_structure(functools.partial(_where_res_check_equal_prepare_dims, dim_map=check_dim_map), res, ref_result)
+    equal = tree.map_structure_with_path(
+        functools.partial(_where_res_check_equal, dim_map=check_dim_map), res, ref_result
+    )
+    if all(tree.flatten(equal)):
+        return res
+    print("** Error, some elements are not equal:", equal)
+    raise SystemExit(1)
+
+
+def _where_prepare_dims(a: Any, b: Any, *, mask: Tensor, dim_map: Dict[Dim, Dim]):
+    if isinstance(a, Dim):
+        assert isinstance(b, Dim)
+        if a == b:
+            return a
+        if a in dim_map:
+            return dim_map[a]
+        assert b not in dim_map
+        a_size = a.get_size_tensor()
+        b_size = b.get_size_tensor()
+        res_size = rf.where(mask, a_size, b_size, allow_broadcast_all_sources=True)
+        res_dim = Dim(res_size, name=_extend_dim_name(b.name))
+        dim_map[a] = res_dim
+        dim_map[b] = res_dim
+        return res_dim
+    if isinstance(a, Tensor):
+        assert isinstance(b, Tensor)
+        return a  # ignored at this stage
+    raise TypeError(f"_where_prepare_dims: unexpected type ({type(a)}, {type(b)})")
+
+
+def _where(a: Any, b: Any, *, mask: Tensor, mask_cpu: Tensor, dim_map: Dict[Dim, Dim]):
+    if isinstance(a, Dim):
+        assert isinstance(b, Dim)
+        if a == b:
+            return a
+        assert a in dim_map
+        assert b in dim_map
+        return dim_map[a]
+    if isinstance(a, Tensor):
+        assert isinstance(b, Tensor)
+        for d in a.dims:
+            if d in dim_map:
+                a = _expand_slice(a, old_dim=d, new_dim=dim_map[d])
+        for d in b.dims:
+            if d in dim_map:
+                b = _expand_slice(b, old_dim=d, new_dim=dim_map[d])
+        assert a.dims_set == b.dims_set
+        if a.device == "cpu":
+            mask = mask_cpu
+        return rf.where(mask, a, b, allow_broadcast_all_sources=True)
+    raise TypeError(f"_where: unexpected type ({type(a)}, {type(b)})")
+
+
+def _where_res_check_equal_prepare_dims(a: Any, b: Any, *, dim_map: Dict[Dim, Dim]):
+    if isinstance(a, Dim):
+        assert isinstance(b, Dim)
+        if a != b:
+            dim_map[a] = b
+
+
+def _where_res_check_equal(path: Tuple[Any, ...], a: Any, b: Any, *, dim_map: Dict[Dim, Dim]):
+    import torch
+
+    if isinstance(a, Dim):
+        assert isinstance(b, Dim)
+        if a == b:
+            return True
+        assert dim_map[a] == b
+        return _where_res_check_equal(path + ("size",), a.get_size_tensor(), b.get_size_tensor(), dim_map=dim_map)
+    if isinstance(a, Tensor):
+        assert isinstance(b, Tensor)
+        for d in a.dims:
+            if d in dim_map:
+                d_ = dim_map[d]
+                _where_res_check_equal(path + (d,), d, d_, dim_map=dim_map)
+                a, _ = rf.replace_dim(a, in_dim=d, out_dim=d_)
+        assert a.dims_set == b.dims_set
+        a = a.copy_transpose(b.dims)
+        a = a.copy_masked(0)
+        b = b.copy_masked(0)
+        a = rf.copy_to_device(a, "cpu")
+        b = rf.copy_to_device(b, "cpu")
+        try:
+            torch.testing.assert_close(a.raw_tensor, b.raw_tensor)
+        except AssertionError as exc:
+            print(f"** Error in {path} ({a} vs {b}):\n{exc}")
+            return False
+        else:
+            return True
+    raise TypeError(f"_where_res_check_equal: unexpected type ({type(a)}, {type(b)})")
+
+
 def _seq_label_history_init_state(*, vocab_dim: Dim, batch_dims: Sequence[Dim]) -> rf.State:
     hist_dim = Dim(0, name="hist0")
     history = rf.zeros(list(batch_dims) + [hist_dim], dtype="int64", sparse_dim=vocab_dim)