Add top-k parameters and printing of the best path
mmueller00 committed Dec 18, 2024
1 parent 75ee8ba commit daa9610
Showing 2 changed files with 78 additions and 32 deletions.
21 changes: 14 additions & 7 deletions users/mueller/experiments/ctc_baseline/ctc.py
@@ -71,6 +71,7 @@ def py():
blank_prior = True
prior_gradient = False
LM_order = 2
top_k = 1
self_train_subset = 18000

if train_small:
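
The new top_k knob caps the full-sum recombination at the k best label hypotheses per frame; with top_k = 1 the logsumexp over the label axis effectively degenerates to a max, i.e. Viterbi-style best-path scoring. A minimal sketch of that pruning step (an editor's illustration, not part of this commit; shapes and vocabulary size are made up):

import torch

log_q = torch.randn(4, 185)  # (batch, vocab) partial scores, toy values

full_sum = torch.logsumexp(log_q, dim=-1)          # recombine over all labels
topk_scores, _ = torch.topk(log_q, k=1, dim=-1, sorted=False)
pruned_sum = torch.logsumexp(topk_scores, dim=-1)  # recombine over the k best only

print(full_sum - pruned_sum)  # always >= 0: pruning can only drop probability mass
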
@@ -149,13 +150,7 @@ def py():
} if self_training_rounds > 0 else None

for am, lm, prior in [
(0.5, 0.5, 0.5),
# (0.5, 0.3, 0.5),
# (0.5, 0.2, 0.5),
# (0.5, 0.1, 0.5),
# (0.5, 0.05, 0.5),
# (0.5, 0.0, 0.5),
# (0.3, 0.2, 0.5),
(1.0, 0.0, 0.2)
]:
if use_sum_criterion:
training_scales = {
@@ -170,6 +165,7 @@ def py():
sum_str = f"-full_sum" + \
(f"_p{str(training_scales['prior']).replace('.', '')}_l{str(training_scales['lm']).replace('.', '')}_a{str(training_scales['am']).replace('.', '')}" if training_scales else "") + \
(f"_LMorder{LM_order}" if LM_order > 2 else "") + \
(f"_topK{top_k}" if top_k > 0 else "") + \
("_wo_hor_pr" if not horizontal_prior else "") + \
("_wo_blank_pr" if not blank_prior else "") + \
("_wo_pr_grad" if not prior_gradient else "")
@@ -201,6 +197,7 @@ def py():
blank_prior=blank_prior,
prior_gradient=prior_gradient,
LM_order=LM_order,
top_k=top_k,
training_scales=training_scales if use_sum_criterion else None,
self_train_subset=self_train_subset,
)
@@ -242,6 +239,7 @@ def train_exp(
blank_prior: bool = True,
prior_gradient: bool = True,
LM_order: int = 2,
top_k: int = 0,
training_scales: Optional[Dict[str, float]] = None,
self_train_subset: Optional[int] = None,
) -> Optional[ModelWithCheckpoints]:
@@ -353,6 +351,8 @@ def train_exp(
config_self["prior_scale"] = training_scales["prior"]
if not prior_gradient:
config_self["prior_gradient"] = prior_gradient
if top_k > 0:
config_self["top_k"] = top_k

# When testing on a smaller subset we only want one gpu
if self_train_subset is not None:
@@ -966,12 +966,17 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
horizontal_prior = config.bool("horizontal_prior", True)
blank_prior = config.bool("blank_prior", True)
prior_gradient = config.bool("prior_gradient", True)
top_k = config.int("top_k", 0)
use_prior = prior_scale > 0.0

if data.feature_dim and data.feature_dim.dimension == 1:
data = rf.squeeze(data, axis=data.feature_dim)
assert not data.feature_dim # raw audio

if am_scale == 0.7:
print("Data", data)
print("Batch", data.batch)


with uopen(lm_path, "rb") as f:
lm = torch.load(f, map_location=data.device)
@@ -1001,6 +1006,7 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
log_lm_probs=lm,
log_prior=aux_log_prior,
input_lengths=enc_spatial_dim.dyn_size_ext.raw_tensor,
top_k=top_k,
LM_order=lm_order,
am_scale=am_scale,
lm_scale=lm_scale,
@@ -1035,6 +1041,7 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
log_lm_probs=lm,
log_prior=log_prior,
input_lengths=enc_spatial_dim.dyn_size_ext.raw_tensor,
top_k=top_k,
LM_order=lm_order,
am_scale=am_scale,
lm_scale=lm_scale,
89 changes: 64 additions & 25 deletions users/mueller/experiments/ctc_baseline/sum_criterion.py
@@ -23,6 +23,7 @@ def sum_loss(
unk_idx: int = 1,
log_zero: float = float("-inf"),
device: torch.device = torch.device("cpu"),
print_best_path_for_idx: list[int] = [],
):
"""
Sum criterion training for CTC, given by
@@ -83,7 +84,10 @@ def sum_loss(
max_audio_time, batch_size, n_out = log_probs.shape
# scaled log am and lm probs
log_probs = am_scale * log_probs
log_lm_probs = lm_scale * log_lm_probs
if lm_scale == 0.0:
log_lm_probs = torch.zeros_like(log_lm_probs)
else:
log_lm_probs = lm_scale * log_lm_probs
if use_prior:
log_prior = prior_scale * log_prior
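
The lm_scale == 0.0 special case above avoids a subtle NaN: the LM table contains hard zeros (log-probability -inf), and in IEEE arithmetic 0 * -inf is NaN, so scaling by zero would poison the whole recursion. Replacing the table with zeros instead switches the LM off cleanly. A two-line demonstration (editor's sketch, not part of the commit):

import torch

log_lm = torch.tensor([0.0, float("-inf")])  # toy log-probs with a hard zero
print(0.0 * log_lm)              # tensor([0., nan]) -- 0 * -inf is NaN
print(torch.zeros_like(log_lm))  # tensor([0., 0.])  -- safe "LM off"
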

@@ -140,6 +144,12 @@ def sum_loss(
log_q = log_q_label
if top_k > 0:
topk_scores, topk_idx = torch.topk(log_q, top_k, dim=-1, sorted=False)
if print_best_path_for_idx:
with torch.no_grad():
best_path_print = {}
max_val, max_idx = torch.max(log_q, dim=-1)
for idx in print_best_path_for_idx:
best_path_print[idx] = {"str": f"{max_idx[idx] + 2}", "score": "{:.2f}".format(max_val[idx].tolist()), "AM": log_probs[0][idx].tolist()[max_idx[idx] + 2]}

log_lm_probs_wo_eos = log_lm_probs[out_idx_vocab][:, out_idx_vocab].fill_diagonal_(log_zero)
for t in range(1, max_audio_time):
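
The print_best_path_for_idx block added above records, for each tracked batch entry, the framewise argmax label and its score at t = 0; the hunk below extends that record once per frame inside the time loop. The "+ 2" offset appears to map from the blank/EOS-stripped label axis back to full-vocabulary indices (an inference from the surrounding code, not documented in it). A self-contained sketch of the same bookkeeping (editor's illustration, toy shapes):

import torch

log_q = torch.randn(10, 4, 183)  # (time, batch, labels) toy scores
tracked = [0]                    # cf. print_best_path_for_idx

best_path = {i: {"labels": [], "scores": []} for i in tracked}
for t in range(log_q.size(0)):
    max_val, max_idx = torch.max(log_q[t], dim=-1)
    for i in tracked:
        best_path[i]["labels"].append(int(max_idx[i]) + 2)  # back to a full-vocab index
        best_path[i]["scores"].append(float(max_val[i]))

print(best_path[0])
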
@@ -201,6 +211,13 @@ def sum_loss(

if top_k > 0:
topk_scores, topk_idx = torch.topk(log_q, top_k, dim=-1, sorted=False)
if print_best_path_for_idx:
with torch.no_grad():
max_val, max_idx = torch.max(log_q, dim=-1)
for idx in print_best_path_for_idx:
best_path_print[idx]["str"] += f" {max_idx[idx] + 2}"
best_path_print[idx]["score"] += " {:.2f}".format(max_val[idx].tolist()) # / (t+1)
best_path_print[idx]["AM"] += log_probs[t][idx].tolist()[max_idx[idx] + 2]

torch.cuda.empty_cache()

@@ -209,6 +226,10 @@ def sum_loss(
log_q = topk_scores + log_lm_probs[out_idx_vocab, eos_symbol].unsqueeze(0).expand(batch_size, -1).gather(-1, topk_idx)
else:
log_q = log_q + log_lm_probs[out_idx_vocab, eos_symbol].unsqueeze(0)
if print_best_path_for_idx:
with torch.no_grad():
for idx in print_best_path_for_idx:
print(f"Best path for {idx}: {best_path_print[idx]['str']}\nScore: {best_path_print[idx]['score']}\nAM: {-best_path_print[idx]['AM']}")

# sum over the vocab dimension
sum_score = safe_logsumexp(log_q, dim=-1)
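
safe_logsumexp itself is not shown in this diff; a plausible reading is a logsumexp that stays well defined when an entire slice is -inf, where plain logsumexp would propagate NaNs through the backward pass. A sketch under that assumption (an editor's guess, not the repository's implementation):

import torch

def safe_logsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
    max_x, _ = torch.max(x, dim=dim, keepdim=True)
    max_x = torch.where(torch.isinf(max_x), torch.zeros_like(max_x), max_x)
    return (x - max_x).exp().sum(dim=dim).log() + max_x.squeeze(dim)

scores = torch.tensor([[0.0, -1.0], [float("-inf"), float("-inf")]])
print(safe_logsumexp(scores, dim=-1))  # tensor([0.3133, -inf])
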
@@ -681,45 +702,63 @@ def test():
# prior = _calc_log_prior(am, length)
# am = am.permute(1, 0, 2)

am = ag(am, "AM", False)
prior = ag(prior, "prior", False)
# am = ag(am, "AM", False)
# prior = ag(prior, "prior", False)


# loss = sum_loss(
# log_probs=am,
# log_lm_probs=lm,
# log_prior=prior,
# input_lengths=length,
# LM_order=2,
# am_scale=1.0,
# lm_scale=1.9,
# prior_scale=0.2,
# horizontal_prior=True,
# blank_idx=184,
# eos_idx=0,
# )
loss = sum_loss_k(
loss = sum_loss(
log_probs=am,
log_lm_probs=lm,
log_prior=prior,
input_lengths=length,
top_k=1,
top_k=1,
LM_order=2,
am_scale=1.0,
lm_scale=1.9,
prior_scale=0.2,
lm_scale=0.0,
prior_scale=0.0,
horizontal_prior=True,
blank_prior=True,
blank_idx=184,
eos_idx=0,
print_best_path_for_idx=[0]
)
print("OUT", loss[0].tolist())
l += (loss / frames).mean()

del loss, am, prior
torch.cuda.empty_cache()
print(time.time() - s)
l.backward(torch.ones_like(l, device=device))
# del loss, am, prior
# torch.cuda.empty_cache()
# print(time.time() - s)

# targets = torch.tensor([55, 148, 178, 108, 179, 126, 110, 103, 9, 154, 84, 162, 159, 83, 153, 33, 106, 9, 131, 46, 63, 15, 162, 94, 0, 111, 121, 29, 121, 21, 151, 18, 4, 159, 118, 86, 129, 18, 13, 170, 151, 81, 77, 53, 165, 57, 134, 63, 103, 110, 47, 35, 145, 18, 34, 66, 42, 96, 139, 16, 138, 156, 1, 63, 103, 95, 149, 111, 83, 34, 113, 158, 39, 166, 34, 123, 26, 148, 134, 148, 168, 177, 18, 23, 164, 69, 145, 93, 166, 174, 162, 36, 95, 116, 123, 74, 124, 70])
# targets = targets + 2
targets = torch.tensor(
[ 57, 150, 180, 110, 107, 128, 112, 105, 11, 156, 86, 164, 161, 85,
155, 35, 108, 11, 133, 48, 133, 17, 164, 96, 2, 113, 123, 31,
123, 23, 153, 20, 6, 161, 120, 88, 131, 20, 15, 99, 153, 58,
119, 1, 88, 59, 136, 65, 105, 99, 122, 37, 147, 20, 36, 68,
44, 98, 141, 18, 1, 158, 3, 65, 105, 97, 151, 113, 85, 36,
115, 160, 83, 168, 36, 125, 28, 150, 136, 90, 170, 179,
20, 25, 166, 71, 147, 95, 168, 176, 164, 38, 97, 118, 125, 76,
43, 72]
)
# greedy_probs, greedy_idx = torch.max(am[:, 0:1], dim=-1)
# print(greedy_idx.squeeze(-1))
targets = targets.unsqueeze(0)
target_lengths = torch.tensor([targets.size(1)])
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=am[:, 0:1],
targets=targets,
input_lengths=length[0:1],
target_lengths=target_lengths,
blank=184,
reduction="none"
)
print(ctc_loss)
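
The rewritten test drives sum_loss with top_k=1 and both the LM and the prior switched off (lm_scale=0.0, prior_scale=0.0), then cross-checks the result against torch.nn.functional.ctc_loss on a fixed label sequence; the commented-out greedy_probs lines suggest the targets were derived from the framewise argmax. A sketch of that sanity check (editor's illustration; the greedy_ctc_labels helper is made up, not from the repository):

import torch
import torch.nn.functional as F

def greedy_ctc_labels(log_probs: torch.Tensor, blank: int) -> torch.Tensor:
    """Collapse the framewise argmax: merge repeats, then drop blanks."""
    idx = log_probs.argmax(dim=-1)       # (time,)
    idx = torch.unique_consecutive(idx)  # merge repeated labels
    return idx[idx != blank]             # strip blanks

log_probs = torch.randn(50, 1, 185).log_softmax(-1)  # (time, batch=1, vocab)
targets = greedy_ctc_labels(log_probs[:, 0], blank=184).unsqueeze(0)
ctc = F.ctc_loss(
    log_probs, targets,
    input_lengths=torch.tensor([50]),
    target_lengths=torch.tensor([targets.size(1)]),
    blank=184, reduction="none",
)
print(ctc)  # comparable in spirit to the top-1 restricted sum loss
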


# l.backward(torch.ones_like(l, device=device))
e1 = time.time()
print(f"Sum loss took {time.strftime('%H:%M:%S', time.gmtime(e1-s1))}: {l}") # 5:00 mins
# print(f"Sum loss took {time.strftime('%H:%M:%S', time.gmtime(e1-s1))}: {l}") # 5:00 mins

# s2 = time.time()
