diff --git a/examples/mms/data_prep/align_and_segment.py b/examples/mms/data_prep/align_and_segment.py index cd5045eabc..de45d757bd 100644 --- a/examples/mms/data_prep/align_and_segment.py +++ b/examples/mms/data_prep/align_and_segment.py @@ -87,13 +87,14 @@ def get_alignments( blank = dictionary[""] targets = torch.tensor(token_indices, dtype=torch.int32).to(DEVICE) - input_lengths = torch.tensor(emissions.shape[0]) - target_lengths = torch.tensor(targets.shape[0]) - + + input_lengths = torch.tensor(emissions.shape[0]).unsqueeze(-1) + target_lengths = torch.tensor(targets.shape[0]).unsqueeze(-1) path, _ = F.forced_align( - emissions, targets, input_lengths, target_lengths, blank=blank + emissions.unsqueeze(0), targets.unsqueeze(0), input_lengths, target_lengths, blank=blank ) - path = path.to("cpu").tolist() + path = path.squeeze().to("cpu").tolist() + segments = merge_repeats(path, {v: k for k, v in dictionary.items()}) return segments, stride