From b5d89cddc9e4a0af831d2aafc1ba7dbf0f1b10d0 Mon Sep 17 00:00:00 2001 From: Vineel Pratap Date: Thu, 7 Sep 2023 11:25:28 -0700 Subject: [PATCH] Update align_and_segment.py (#5317) Fix MMS alignment code: add a leading dimension to the emissions and targets passed to F.forced_align, make both length tensors 1-D, and squeeze the returned path before converting it to a list. --- examples/mms/data_prep/align_and_segment.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/mms/data_prep/align_and_segment.py b/examples/mms/data_prep/align_and_segment.py index cd5045eabc..de45d757bd 100644 --- a/examples/mms/data_prep/align_and_segment.py +++ b/examples/mms/data_prep/align_and_segment.py @@ -87,13 +87,14 @@ def get_alignments( blank = dictionary[""] targets = torch.tensor(token_indices, dtype=torch.int32).to(DEVICE) - input_lengths = torch.tensor(emissions.shape[0]) - target_lengths = torch.tensor(targets.shape[0]) - + + input_lengths = torch.tensor(emissions.shape[0]).unsqueeze(-1) + target_lengths = torch.tensor(targets.shape[0]).unsqueeze(-1) path, _ = F.forced_align( - emissions, targets, input_lengths, target_lengths, blank=blank + emissions.unsqueeze(0), targets.unsqueeze(0), input_lengths, target_lengths, blank=blank ) - path = path.to("cpu").tolist() + path = path.squeeze().to("cpu").tolist() + segments = merge_repeats(path, {v: k for k, v in dictionary.items()}) return segments, stride