From e69ec8ed89dbbb3698bdf1c0030231724f4313a5 Mon Sep 17 00:00:00 2001 From: taejinp Date: Wed, 13 Nov 2024 18:46:20 -0800 Subject: [PATCH 01/47] Adding the first pr files models and dataset Signed-off-by: taejinp --- .../sortformer_diar_HL_callhome_part1.yaml | 18 + .../sortformer_diar_HL_dihard.yaml | 17 + .../sortformer_diar_encoder_infer.py | 132 ++ .../sortformer_diar_encoder_train.py | 54 + .../asr/data/audio_to_diar_label.py | 490 ++++++- .../asr/data/audio_to_diar_label_lhotse.py | 76 + .../asr/models/sortformer_diar_models.py | 565 ++++++++ .../asr/parts/utils/asr_multispeaker_utils.py | 1231 +++++++++++++++++ .../common/parts/preprocessing/collections.py | 192 ++- 9 files changed, 2759 insertions(+), 16 deletions(-) create mode 100644 examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_callhome_part1.yaml create mode 100644 examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_dihard.yaml create mode 100644 examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py create mode 100644 examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_train.py create mode 100644 nemo/collections/asr/data/audio_to_diar_label_lhotse.py create mode 100644 nemo/collections/asr/models/sortformer_diar_models.py create mode 100644 nemo/collections/asr/parts/utils/asr_multispeaker_utils.py diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_callhome_part1.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_callhome_part1.yaml new file mode 100644 index 000000000000..6b960e2d5950 --- /dev/null +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_callhome_part1.yaml @@ -0,0 +1,18 @@ +# Postprocessing parameters for timestamp outputs from speaker diarization models. +# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper: +# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). +# These parameters were optimized with with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656. +# These parameters were optimized on the development split of DIHARD3 dataset. See https://arxiv.org/pdf/2012.01477. +# Trial 17903 finished with value: 0.10261257411949805 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.0, 'min_duration_on': 0.39, 'min_duration_off': 0.39}. Best is trial 17903 with value: 0.10261257411949805. +# Trial 24682 finished with value: 0.10257785779242055 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.01, 'min_duration_on': 0.42, 'min_duration_off': 0.34}. Best is trial 24682 with value: 0.10257785779242055. +parameters: + window_length_in_sec: 0.0 # Not used + shift_length_in_sec: 0.0 # Not used + smoothing: False # Not used + overlap: 0.5 # Not used + onset: 0.53 # Onset threshold for detecting the beginning and end of a speech + offset: 0.49 # Offset threshold for detecting the end of a speech + pad_onset: 0.23 # Adding durations before each speech segment + pad_offset: 0.01 # Adding durations after each speech segment + min_duration_on: 0.42 # Threshold for small non-speech deletion + min_duration_off: 0.34 # Threshold for short speech segment deletion \ No newline at end of file diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_dihard.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_dihard.yaml new file mode 100644 index 000000000000..bb9f362ad619 --- /dev/null +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_dihard.yaml @@ -0,0 +1,17 @@ +# Postprocessing parameters for timestamp outputs from speaker diarization models. +# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper: +# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). +# These parameters were optimized with with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656. +# These parameters were optimized on CallHome Dataset from the NIST SRE 2000 Disc8, especially from the split2 specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/tree/master/egs/callhome_diarization/v2. +# Trial 180 finished with value: 0.12329626986650599 and parameters: {'onset': 0.56, 'offset': 0.81, 'pad_onset': 0.05, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.16}. Best is trial 180 with value: 0.12329626986650599. +parameters: + window_length_in_sec: 0.0 # Not used + shift_length_in_sec: 0.0 # Not used + smoothing: False # Not used + overlap: 0.5 # Not used + onset: 0.64 # Onset threshold for detecting the beginning and end of a speech + offset: 0.74 # Offset threshold for detecting the end of a speech + pad_onset: 0.06 # Adding durations before each speech segment + pad_offset: 0.0 # Adding durations after each speech segment + min_duration_on: 0.1 # Threshold for small non-speech deletion + min_duration_off: 0.15 # Threshold for short speech segment deletion \ No newline at end of file diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py new file mode 100644 index 000000000000..aafd2b2cb6ed --- /dev/null +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py @@ -0,0 +1,132 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytorch_lightning as pl +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything +import seaborn as sns +import numpy as np + +from nemo.collections.asr.models import SortformerEncLabelModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager +seed_everything(42) +import torch +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.manifold import TSNE +import pandas as pd +from nemo.collections.asr.data.audio_to_msdd_mock_label import generate_mock_embs + +def plot_enc_tsne(x, targets, memo): + # x = enc_states_list[-1].squeeze(0).cpu().detach().numpy() + tsne = TSNE(n_components=2, verbose=False, random_state=100) + zembs = tsne.fit_transform(x) + + # Step 1: Create a new column filled with 0.5 + new_column = torch.full((targets.size(0), 1), 0.5) + # Step 2: Concatenate the new column with the original tensor + updated_targets = torch.cat((new_column, targets), dim=1) + + df = pd.DataFrame() + df["y"] = updated_targets.argmax(dim=1).detach().cpu().numpy() + df["comp-1"] = zembs[:,0] + df["comp-2"] = zembs[:,1] + + # Plotting using seaborn + plt.figure(figsize=(10, 8)) + sns.scatterplot(x="comp-1", y="comp-2", hue=df.y.tolist(), + palette=sns.color_palette("hls", 10), + data=df).set(title="SortFormer HiddenState T-SNE projection") + + # Save the plot as a PNG file in the specified directory + plt.savefig(f'/home/taejinp/Downloads/tsne_plots/tsne_sortformer_plot_{memo}.png') + +def remove_speaker_models(ckpt_path): + ckpt_instance = torch.load(ckpt_path) + _state_dict = ckpt_instance['state_dict'] + + key_list = list(_state_dict.keys()) + for key in key_list: + if '_speaker_model.' in key or '_speaker_model_decoder.' in key: + # import ipdb; ipdb.set_trace() + del _state_dict[key] + + target_path = ckpt_path.replace('.ckpt', '.removed.ckpt') + torch.save(ckpt_instance, target_path) + return target_path + + +# @hydra_runner(config_path="../conf/neural_diarizer", config_name="msdd_5scl_15_05_50Povl_256x3x32x2.yaml") +def main(): + # logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + # trainer = pl.Trainer(**cfg.trainer) + # exp_manager(trainer, cfg.get("exp_manager", None)) + # ckpt_path = "/disk_c/taejinp_backup/msdd_model_train/NVB_SFmr_MixMockEmbsTest/version_18_f0:84/checkpoints/e613.ckpt" + ckpt_path = "/disk_c/taejinp_backup/msdd_model_train/SFmr_MixMockEmbsTest/version_21/checkpoints/ep2255.ckpt" + target_path = remove_speaker_models(ckpt_path) + sortformer_model = SortformerEncLabelModel.load_from_checkpoint(checkpoint_path=target_path) + unit_len = 25 + targets = torch.eye(4,4).repeat_interleave(unit_len,1).t() + targets[:,2:] = 0 + # targets[:,3:] = 0 + targets = targets[:2*unit_len, :] + new_column = torch.full((targets.size(0), 1), 0.5) + updated_targets = torch.cat((new_column, targets), dim=1) + mock_embs, audio_signal_length, targets = generate_mock_embs(targets=targets, seed=315, + mock_emb_noise_std=0.03, + mock_emb_degree_of_freedom=4, + min_noise_std=0.01,) + mock_embs = mock_embs.unsqueeze(0) + audio_signal = mock_embs + + audio_signal, audio_signal_length, targets + + audio_signal = audio_signal.cuda() + ms_seg_counts = torch.tensor([]).cuda() + ms_seg_timestamps = torch.tensor([]).cuda() + scale_mapping = torch.tensor([]).cuda() + sortformer_model.alpha = 0.0 + + _preds_mean, preds_, attn_score_stack, enc_states_list, preds_list = sortformer_model.forward( + audio_signal=audio_signal, + audio_signal_length=audio_signal_length, + ms_seg_timestamps=ms_seg_timestamps, + ms_seg_counts=ms_seg_counts, + scale_mapping=scale_mapping, + temp_targets=targets, + ) + + audio_signal_np = audio_signal.squeeze(0).cpu().detach().numpy() + plot_enc_tsne(audio_signal_np, targets, memo=f'input', ) + for layer_c in range(len(enc_states_list)): + print(f"Plotting TSNE for layer {layer_c} ...") + x = enc_states_list[layer_c].squeeze(0).cpu().detach().numpy() + plot_enc_tsne(x, targets, memo=f'layer{layer_c}', ) + preds = preds_.squeeze(0).cpu().detach().numpy() + plot_enc_tsne(preds, targets, memo=f'preds', ) + _preds_mean = _preds_mean.squeeze(0).cpu().detach().numpy() + plot_enc_tsne(_preds_mean, targets, memo=f'preds_mean', ) + + # Optionally, you can also show the plot if desired + plt.show() + import ipdb; ipdb.set_trace() + + # msdd_model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer) + # trainer.fit(msdd_model) + + +if __name__ == '__main__': + main() diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_train.py new file mode 100644 index 000000000000..fb350113d596 --- /dev/null +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_train.py @@ -0,0 +1,54 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytorch_lightning as pl +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything + +from nemo.collections.asr.models import SortformerEncLabelModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + +""" +Example training session (single GPU training on telephonic datasets) + +python ./multiscale_diar_decoder.py --config-path='../conf/neural_diarizer' --config-name='msdd_5scl_15_05_50Povl_256x3x32x2.yaml' \ + trainer.devices=1 \ + model.base.diarizer.speaker_embeddings.model_path="titanet_large" \ + model.train_ds.manifest_filepath="" \ + model.validation_ds.manifest_filepath="" \ + model.train_ds.emb_dir="" \ + model.validation_ds.emb_dir="" \ + exp_manager.name='sample_train' \ + exp_manager.exp_dir='./msdd_exp' +""" + +seed_everything(42) + + +@hydra_runner(config_path="../conf/neural_diarizer", config_name="msdd_5scl_15_05_50Povl_256x3x32x2.yaml") +def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + sortformer_model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer) + # Initialize the weights of the model from another model, if provided via config + sortformer_model.maybe_init_from_pretrained_checkpoint(cfg) + trainer.fit(sortformer_model) + + +if __name__ == '__main__': + + main() diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index a1cb6d0f1bdc..ffad8e4fd072 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,16 +15,17 @@ import os from collections import OrderedDict from statistics import mode -from typing import Dict, Optional - +from typing import Dict, List, Tuple, Optional import torch +import numpy as np from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat -from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, prepare_split_data -from nemo.collections.common.parts.preprocessing.collections import DiarizationSpeechLabel +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import find_first_nonzero +from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, prepare_split_data, get_subsegments +from nemo.collections.common.parts.preprocessing.collections import DiarizationSpeechLabel, EndtoEndDiarizationSpeechLabel from nemo.core.classes import Dataset from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType, ProbsType - +from nemo.utils import logging def get_scale_mapping_list(uniq_timestamps): """ @@ -62,7 +63,7 @@ def get_scale_mapping_list(uniq_timestamps): return scale_mapping_argmat -def extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict=None, target_spks=None): +def extract_seg_info_from_rttm(rttm_lines, mapping_dict=None, target_spks=None): """ Get RTTM lines containing speaker labels, start time and end time. target_spks contains two targeted speaker indices for creating groundtruth label files. Only speakers in target_spks variable will be @@ -139,6 +140,128 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec, return fr_level_target +def get_subsegments_to_timestamps( + subsegments: List[Tuple[float, float]], + feat_per_sec: int = 100, + max_end_ts: float=None, + decimals=2 + ): + """ + Convert subsegment timestamps to scale timestamps by multiplying with the feature rate and rounding. + All `ts` related tensors are dimensioned as (N, 2), where N is the number of subsegments. + + Args: + subsegments (List[Tuple[float, float]]): + A list of tuples where each tuple contains the start and end times of a subsegment. + feat_per_sec (int, optional): + The number of feature frames per second. Defaults to 100. + max_end_ts (float, optional): + The maximum end timestamp to clip the results. If None, no clipping is applied. Defaults to None. + decimals (int, optional): + The number of decimal places to round the timestamps. Defaults to 2. + + Returns: + ts (torch.tensor): + A tensor containing the scaled and rounded timestamps for each subsegment. + """ + seg_ts = (torch.tensor(subsegments) * feat_per_sec).float() + ts_round = torch.round(seg_ts, decimals=decimals) + ts = ts_round.long() + ts[:, 1] = ts[:, 0] + ts[:, 1] + if max_end_ts is not None: + ts = np.clip(ts, 0, int(max_end_ts*feat_per_sec)) + return ts + +def extract_frame_info_from_rttm(uniq_id, offset, duration, rttm_lines, round_digits=3): + """ + Extracts RTTM lines containing speaker labels, start time, and end time for a given audio segment. + + Args: + uniq_id (str): Unique identifier for the audio file and corresponding RTTM file. + offset (float): The starting time offset for the segment of interest. + duration (float): The duration of the segment of interest. + rttm_lines (list): List of RTTM lines in string format. + round_digits (int, optional): Number of decimal places to round the start and end times. Defaults to 3. + + Returns: + rttm_mat (tuple): A tuple containing lists of start times, end times, and speaker labels. + sess_to_global_spkids (dict): A mapping from session-specific speaker indices to global speaker identifiers. + """ + rttm_stt, rttm_end = offset, offset + duration + stt_list, end_list, speaker_list, speaker_set = [], [], [], [] + sess_to_global_spkids = dict() + + for rttm_line in rttm_lines: + start, end, speaker = convert_rttm_line(rttm_line) + + # Skip invalid RTTM lines where the start time is greater than the end time. + if start > end: + continue + + # Check if the RTTM segment overlaps with the specified segment of interest. + if (end > rttm_stt and start < rttm_end) or (start < rttm_end and end > rttm_stt): + # Adjust the start and end times to fit within the segment of interest. + start, end = max(start, rttm_stt), min(end, rttm_end) + else: + continue + + # Round the start and end times to the specified number of decimal places. + end_list.append(round(end, round_digits)) + stt_list.append(round(start, round_digits)) + + # Assign a unique index to each speaker and maintain a mapping. + if speaker not in speaker_set: + speaker_set.append(speaker) + speaker_list.append(speaker_set.index(speaker)) + sess_to_global_spkids.update({speaker_set.index(speaker): speaker}) + + rttm_mat = (stt_list, end_list, speaker_list) + return rttm_mat, sess_to_global_spkids + +def get_frame_targets_from_rttm( + rttm_timestamps: list, + offset: float, + duration: float, + round_digits: int, + feat_per_sec: int, + max_spks: int, + ): + """ + Create a multi-dimensional vector sequence containing speaker timestamp information in RTTM. + The unit-length is the frame shift length of the acoustic feature. The feature-level annotations + `feat_level_target` will later be converted to base-segment level diarization label. + + Args: + rttm_timestamps (list): + List containing start and end time for each speaker segment label. + stt_list, end_list and speaker_list are contained. + feat_per_sec (int): + Number of feature frames per second. This quantity is determined by window_stride variable in preprocessing module. + target_spks (tuple): + Speaker indices that are generated from combinations. If there are only one or two speakers, + only a single target_spks variable is generated. + + Returns: + feat_level_target (torch.tensor): + Tensor containing label for each feature level frame. + """ + stt_list, end_list, speaker_list = rttm_timestamps + sorted_speakers = sorted(list(set(speaker_list))) + total_fr_len = int(duration * feat_per_sec) + if len(sorted_speakers) > max_spks: + logging.warning(f"Number of speakers in RTTM file {len(sorted_speakers)} exceeds the maximum number of speakers: {max_spks}! Only {max_spks} first speakers remain, and this will affect frame metrics!") + feat_level_target = torch.zeros(total_fr_len, max_spks) + for count, (stt, end, spk_rttm_key) in enumerate(zip(stt_list, end_list, speaker_list)): + if end < offset or stt > offset + duration: + continue + stt, end = max(offset, stt), min(offset + duration, end) + spk = spk_rttm_key + if spk < max_spks: + stt_fr, end_fr = int((stt - offset) * feat_per_sec), int((end - offset)* feat_per_sec) + feat_level_target[stt_fr:end_fr, spk] = 1 + return feat_level_target + + class _AudioMSDDTrainDataset(Dataset): """ Dataset class that loads a json file containing paths to audio files, @@ -338,7 +461,7 @@ def parse_rttm_for_ms_targets(self, sample): """ rttm_lines = open(sample.rttm_file).readlines() uniq_id = self.get_uniq_id_with_range(sample) - rttm_timestamps = extract_seg_info_from_rttm(uniq_id, rttm_lines) + rttm_timestamps = extract_seg_info_from_rttm(rttm_lines) fr_level_target = assign_frame_level_spk_vector( rttm_timestamps, self.round_digits, self.frame_per_sec, target_spks=sample.target_spks ) @@ -370,14 +493,14 @@ def get_uniq_id_with_range(self, sample, deci=3): def get_ms_seg_timestamps(self, sample): """ - Get start and end time of segments in each scale. + Get start and end time of each diarization frame. Args: sample: `DiarizationSpeechLabel` instance from preprocessing.collections Returns: ms_seg_timestamps (torch.tensor): - Tensor containing Multiscale segment timestamps. + Tensor containing timestamps for each frame. ms_seg_counts (torch.tensor): Number of segments for each scale. This information is used for reshaping embedding batch during forward propagation. @@ -529,7 +652,7 @@ def parse_rttm_multiscale(self, sample): rttm_lines = open(sample.rttm_file).readlines() uniq_id = os.path.splitext(os.path.basename(sample.rttm_file))[0] mapping_dict = self.emb_dict[max(self.emb_dict.keys())][uniq_id]['mapping'] - rttm_timestamps = extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict, sample.target_spks) + rttm_timestamps = extract_seg_info_from_rttm(rttm_lines, mapping_dict, sample.target_spks) fr_level_target = assign_frame_level_spk_vector( rttm_timestamps, self.round_digits, self.frame_per_sec, sample.target_spks ) @@ -851,3 +974,348 @@ def __init__( def msdd_infer_collate_fn(self, batch): return _msdd_infer_collate_fn(self, batch) + +class _AudioToSpeechE2ESpkDiarDataset(Dataset): + """ + Dataset class that loads a json file containing paths to audio files, + RTTM files and number of speakers. This Dataset class is designed for + training or fine-tuning speaker embedding extractor and diarization decoder + at the same time. + + Example: + {"audio_filepath": "/path/to/audio_0.wav", "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_0.rttm} + ... + {"audio_filepath": "/path/to/audio_n.wav", "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_n.rttm} + + Args: + manifest_filepath (str): + Path to input manifest json files. + multiargs_dict (dict): + Dictionary containing the parameters for multiscale segmentation and clustering. + soft_label_thres (float): + Threshold that determines the label of each segment based on RTTM file information. + featurizer: + Featurizer instance for generating audio_signal from the raw waveform. + window_stride (float): + Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames. + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports.""" + output_types = { + "audio_signal": NeuralType(('B', 'T'), AudioSignal()), + "audio_length": NeuralType(('B'), LengthsType()), + "targets": NeuralType(('B', 'T', 'C'), ProbsType()), + "target_len": NeuralType(('B', 'C'), LengthsType()), + } + + return output_types + + def __init__( + self, + *, + manifest_filepath: str, + soft_label_thres: float, + session_len_sec: float, + num_spks: int, + featurizer, + window_stride: float, + min_subsegment_duration: float = 0.03, + global_rank: int = 0, + dtype=torch.float16, + round_digits: int = 2, + soft_targets: bool = False, + subsampling_factor: int = 8, + ): + super().__init__() + self.collection = EndtoEndDiarizationSpeechLabel( + manifests_files=manifest_filepath.split(','), + round_digits=round_digits, + ) + self.featurizer = featurizer + self.round_digits = round_digits + self.feat_per_sec = int(1 / window_stride) + self.diar_frame_length = round(subsampling_factor * window_stride, round_digits) + self.session_len_sec = session_len_sec + self.soft_label_thres = soft_label_thres + self.max_spks = num_spks + self.min_subsegment_duration = min_subsegment_duration + self.dtype = dtype + self.use_asr_style_frame_count = True + self.soft_targets = soft_targets + self.round_digits = 2 + self.floor_decimal = 10 ** self.round_digits + + def __len__(self): + return len(self.collection) + + def get_uniq_id_with_range(self, sample, deci=3): + """ + Generate unique training sample ID from unique file ID, offset and duration. The start-end time added + unique ID is required for identifying the sample since multiple short audio samples are generated from a single + audio file. The start time and end time of the audio stream uses millisecond units if `deci=3`. + + Args: + sample: + `DiarizationSpeechLabel` instance from collections. + + Returns: + uniq_id (str): + Unique sample ID which includes start and end time of the audio stream. + Example: abc1001_3122_6458 + """ + bare_uniq_id = os.path.splitext(os.path.basename(sample.rttm_file))[0] + offset = str(int(round(sample.offset, deci) * pow(10, deci))) + endtime = str(int(round(sample.offset + sample.duration, deci) * pow(10, deci))) + uniq_id = f"{bare_uniq_id}_{offset}_{endtime}" + return uniq_id + + def parse_rttm_for_targets_and_lens(self, uniq_id, rttm_file, offset, duration, target_len): + """ + Generate target tensor variable by extracting groundtruth diarization labels from an RTTM file. + This function converts (start, end, speaker_id) format into base-scale (the finest scale) segment level + diarization label in a matrix form. + + Example of seg_target: + [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]] + """ + rttm_lines = open(rttm_file).readlines() + rttm_timestamps, sess_to_global_spkids = extract_frame_info_from_rttm(uniq_id, offset, duration, rttm_lines) + + fr_level_target = get_frame_targets_from_rttm(rttm_timestamps=rttm_timestamps, + offset=offset, + duration=duration, + round_digits=self.round_digits, + feat_per_sec=self.feat_per_sec, + max_spks=self.max_spks) + + soft_target_seg = self.get_soft_targets_seg(feat_level_target=fr_level_target, + target_len=target_len) + if self.soft_targets: + step_target = soft_target_seg + else: + step_target = (soft_target_seg >= self.soft_label_thres).float() + return step_target + + def get_soft_targets_seg(self, feat_level_target, target_len): + """ + Generate the final targets for the actual diarization step. + Here, frame level means step level which is also referred to as segments. + We follow the original paper and refer to the step level as "frames". + + Args: + feat_level_target (torch.tensor): + Tensor variable containing hard-labels of speaker activity in each feature-level segment. + target_len (torch.tensor): + Numbers of ms segments + + Returns: + soft_target_seg (torch.tensor): + Tensor variable containing soft-labels of speaker activity in each step-level segment. + """ + num_seg = torch.max(target_len) + targets = torch.zeros(num_seg, self.max_spks) + stride = int(self.feat_per_sec * self.diar_frame_length) + for index in range(num_seg): + if index == 0: + seg_stt_feat = 0 + else: + seg_stt_feat = stride * index - 1 - int(stride / 2) + if index == num_seg - 1: + seg_end_feat = feat_level_target.shape[0] + else: + seg_end_feat = stride * index - 1 + int(stride / 2) + targets[index] = torch.mean(feat_level_target[seg_stt_feat:seg_end_feat+1, :], axis=0) + return targets + + def get_segment_timestamps( + self, + duration: float, + offset: float = 0, + sample_rate: int = 16000, + ): + """ + Get start and end time of segments in each scale. + + Args: + sample: + `DiarizationSpeechLabel` instance from preprocessing.collections + Returns: + segment_timestamps (torch.tensor): + Tensor containing Multiscale segment timestamps. + target_len (torch.tensor): + Number of segments for each scale. This information is used for reshaping embedding batch + during forward propagation. + """ + subsegments = get_subsegments(offset=offset, + window=round(self.diar_frame_length * 2, self.round_digits), + shift=self.diar_frame_length, + duration=duration, + min_subsegment_duration=self.min_subsegment_duration, + use_asr_style_frame_count=self.use_asr_style_frame_count, + sample_rate=sample_rate, + feat_per_sec=self.feat_per_sec, + ) + if self.use_asr_style_frame_count: + effective_dur = np.ceil((1+duration*sample_rate)/int(sample_rate/self.feat_per_sec)).astype(int)/self.feat_per_sec + else: + effective_dur = duration + ts_tensor = get_subsegments_to_timestamps(subsegments, self.feat_per_sec, decimals=2, max_end_ts=(offset+effective_dur)) + target_len = torch.tensor([ts_tensor.shape[0]]) + return target_len + + def __getitem__(self, index): + sample = self.collection[index] + if sample.offset is None: + sample.offset = 0 + offset = sample.offset + if self.session_len_sec < 0: + session_len_sec = sample.duration + else: + session_len_sec = min(sample.duration, self.session_len_sec) + + uniq_id = self.get_uniq_id_with_range(sample) + audio_signal = self.featurizer.process(sample.audio_file, offset=offset, duration=session_len_sec) + + # We should resolve the length mis-match from the round-off errors: `session_len_sec` and `audio_signal.shape[0]` + session_len_sec = np.floor(audio_signal.shape[0] / self.featurizer.sample_rate * self.floor_decimal)/self.floor_decimal + audio_signal = audio_signal[:round(self.featurizer.sample_rate*session_len_sec)] + + audio_signal_length = torch.tensor(audio_signal.shape[0]).long() + audio_signal, audio_signal_length = audio_signal.to('cpu'), audio_signal_length.to('cpu') + target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate) + targets = self.parse_rttm_for_targets_and_lens(uniq_id=uniq_id, + rttm_file=sample.rttm_file, + offset=offset, + duration=session_len_sec, + target_len=target_len) + return audio_signal, audio_signal_length, targets, target_len + +def _eesd_train_collate_fn(self, batch): + """ + Collate a batch of variables needed for training the end-to-end speaker diarization (EESD) model + from raw waveforms to diarization labels. The following variables are included in the training/validation batch: + + Args: + batch (tuple): + A tuple containing the variables for diarization training. + + Returns: + audio_signal (torch.Tensor): + A tensor containing the raw waveform samples (time series) loaded from the `audio_filepath` in the input manifest file. + feature_length (torch.Tensor): + A tensor containing the lengths of the raw waveform samples. + targets (torch.Tensor): + Groundtruth speaker labels for the given input embedding sequence. + target_lens (torch.Tensor): + A tensor containing the number of segments for each sample in the batch, necessary for reshaping inputs to the EESD model. + """ + packed_batch = list(zip(*batch)) + audio_signal, feature_length, targets, target_len = packed_batch + audio_signal_list, feature_length_list = [], [] + target_len_list, targets_list = [], [] + + max_raw_feat_len = max([x.shape[0] for x in audio_signal]) + max_target_len = max([x.shape[0] for x in targets]) + if max([len(feat.shape) for feat in audio_signal]) > 1: + max_ch = max([feat.shape[1] for feat in audio_signal]) + else: + max_ch = 1 + for feat, feat_len, tgt, segment_ct in batch: + seq_len = tgt.shape[0] + if len(feat.shape) > 1: + pad_feat = (0, 0, 0, max_raw_feat_len - feat.shape[0]) + else: + pad_feat = (0, max_raw_feat_len - feat.shape[0]) + if feat.shape[0] < feat_len: + feat_len_pad = feat_len - feat.shape[0] + feat = torch.nn.functional.pad(feat, (0, feat_len_pad)) + pad_tgt = (0, 0, 0, max_target_len - seq_len) + padded_feat = torch.nn.functional.pad(feat, pad_feat) + padded_tgt = torch.nn.functional.pad(tgt, pad_tgt) + if max_ch > 1 and padded_feat.shape[1] < max_ch: + feat_ch_pad = max_ch - padded_feat.shape[1] + padded_feat = torch.nn.functional.pad(padded_feat, (0, feat_ch_pad)) + audio_signal_list.append(padded_feat) + feature_length_list.append(feat_len.clone().detach()) + target_len_list.append(segment_ct.clone().detach()) + targets_list.append(padded_tgt) + audio_signal = torch.stack(audio_signal_list) + feature_length = torch.stack(feature_length_list) + target_lens = torch.stack(target_len_list) + targets = torch.stack(targets_list) + return audio_signal, feature_length, targets, target_lens + +class AudioToSpeechE2ESpkDiarDataset(_AudioToSpeechE2ESpkDiarDataset): + """ + Dataset class for loading a JSON file containing paths to audio files, + RTTM (Rich Transcription Time Marked) files, and the number of speakers. + This class is designed for training or fine-tuning a speaker embedding + extractor and diarization decoder simultaneously. + + The JSON manifest file should have entries in the following format: + + Example: + { + "audio_filepath": "/path/to/audio_0.wav", + "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_0.rttm" + } + ... + { + "audio_filepath": "/path/to/audio_n.wav", + "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_n.rttm" + } + + Args: + manifest_filepath (str): + Path to the input manifest JSON file containing paths to audio and RTTM files. + soft_label_thres (float): + Threshold for assigning soft labels to segments based on RTTM file information. + session_len_sec (float): + Duration of each session (in seconds) for training or fine-tuning. + num_spks (int): + Number of speakers in the audio files. + featurizer: + Instance of a featurizer for generating features from the raw waveform. + window_stride (float): + Window stride (in seconds) for extracting acoustic features, used to calculate + the number of feature frames. + global_rank (int): + Global rank of the current process (used for distributed training). + soft_targets (bool): + Whether or not to use soft targets during training. + + Methods: + eesd_train_collate_fn(batch): + Collates a batch of data for end-to-end speaker diarization training. + """ + def __init__( + self, + *, + manifest_filepath: str, + soft_label_thres: float, + session_len_sec: float, + num_spks: int, + featurizer, + window_stride, + global_rank: int, + soft_targets: bool, + ): + super().__init__( + manifest_filepath=manifest_filepath, + soft_label_thres=soft_label_thres, + session_len_sec=session_len_sec, + num_spks=num_spks, + featurizer=featurizer, + window_stride=window_stride, + global_rank=global_rank, + soft_targets=soft_targets, + ) + + def eesd_train_collate_fn(self, batch): + return _eesd_train_collate_fn(self, batch) \ No newline at end of file diff --git a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py new file mode 100644 index 000000000000..e223e4ef2a56 --- /dev/null +++ b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Tuple + +import torch.utils.data +from lhotse.dataset import AudioSamples +from lhotse.dataset.collation import collate_matrices + +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( + speaker_to_target, + get_hidden_length_from_sample_length, +) + +class LhotseAudioToSpeechE2ESpkDiarDataset(torch.utils.data.Dataset): + """ + This dataset is based on diarization datasets from audio_to_eesd_label.py. + Unlike native NeMo datasets, Lhotse dataset defines only the mapping from + a CutSet (meta-data) to a mini-batch with PyTorch tensors. + Specifically, it performs tokenization, I/O, augmentation, and feature extraction (if any). + Managing data, sampling, de-duplication across workers/nodes etc. is all handled + by Lhotse samplers instead. + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), + 'a_sig_length': NeuralType(tuple('B'), LengthsType()), + 'targets': NeuralType(('B', 'T', 'N'), LabelsType()), + 'target_length': NeuralType(tuple('B'), LengthsType()), + 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), + } + + def __init__(self, cfg): + super().__init__() + self.load_audio = AudioSamples(fault_tolerant=True) + self.cfg = cfg + self.num_speakers = self.cfg.get('num_speakers', 4) + self.num_sample_per_mel_frame = int(self.cfg.get('window_stride', 0.01) * self.cfg.get('sample_rate', 16000)) # 160 + self.num_mel_frame_per_target_frame = int(self.cfg.get('subsampling_factor', 8)) + self.spk_tar_all_zero = self.cfg.get('spk_tar_all_zero',False) + + def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: + audio, audio_lens, cuts = self.load_audio(cuts) + speaker_activities = [] + for cut in cuts: + speaker_activity = speaker_to_target( + a_cut=cut, + num_speakers=self.num_speakers, + num_sample_per_mel_frame=self.num_sample_per_mel_frame, + num_mel_frame_per_asr_frame=self.num_mel_frame_per_target_frame, + spk_tar_all_zero=self.spk_tar_all_zero, + boundary_segments=True + ) + speaker_activities.append(speaker_activity) + targets = collate_matrices(speaker_activities).to(audio.dtype) + target_lens_list = [] + for audio_len in audio_lens: + target_fr_len = get_hidden_length_from_sample_length(audio_len, self.num_sample_per_mel_frame, self.num_mel_frame_per_target_frame) + target_lens_list.append([target_fr_len]) + target_lens = torch.tensor(target_lens_list) + + return audio, audio_lens, targets, target_lens diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py new file mode 100644 index 000000000000..c389f0eb627f --- /dev/null +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -0,0 +1,565 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import itertools +import random +import torch +from collections import OrderedDict +from typing import Dict, List, Optional, Union +from hydra.utils import instantiate +from omegaconf import DictConfig +from pytorch_lightning import Trainer +from tqdm import tqdm +from nemo.core.classes import ModelPT +from nemo.core.classes.common import PretrainedModelInfo +from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType +from nemo.core.neural_types.elements import ProbsType +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.asr.data.audio_to_diar_label_lhotse import LhotseAudioToSpeechE2ESpkDiarDataset +from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset +from nemo.collections.asr.metrics.multi_binary_acc import MultiBinaryAccuracy +from nemo.collections.asr.models.asr_model import ExportableEncDecModel +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_pil_targets, get_ats_targets +from nemo.utils import logging + +try: + from torch.cuda.amp import autocast +except ImportError: + from contextlib import contextmanager + + @contextmanager + def autocast(enabled=None): + yield + +# torch.backends.cudnn.enabled = False + +__all__ = ['SortformerEncLabelModel'] + +class SortformerEncLabelModel(ModelPT, ExportableEncDecModel): + """ + Encoder class for Sortformer diarization model. + Model class creates training, validation methods for setting up data performing model forward pass. + + This model class expects config dict for: + * preprocessor + * Transformer Encoder + * FastConformer Encoder + * Sortformer Modules + """ + + @classmethod + def list_available_models(cls) -> List[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + result = [] + return result + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + """ + Initialize an Sortformer Diarizer model and a pretrained NEST encoder. + In this init function, training and validation datasets are prepared. + """ + random.seed(42) + self._trainer = trainer if trainer else None + self._cfg = cfg + + if self._trainer: + self.world_size = trainer.num_nodes * trainer.num_devices + else: + self.world_size = 1 + + if self._trainer is not None and self._cfg.get('augmentor', None) is not None: + self.augmentor = process_augmentations(self._cfg.augmentor) + else: + self.augmentor = None + super().__init__(cfg=self._cfg, trainer=trainer) + self.preprocessor = SortformerEncLabelModel.from_config_dict(self._cfg.preprocessor) + + if hasattr(self._cfg, 'spec_augment') and self._cfg.spec_augment is not None: + self.spec_augmentation = SortformerEncLabelModel.from_config_dict(self._cfg.spec_augment) + else: + self.spec_augmentation = None + + self.encoder = SortformerEncLabelModel.from_config_dict(self._cfg.encoder) + self.sortformer_modules = SortformerEncLabelModel.from_config_dict(self._cfg.sortformer_modules) + self.transformer_encoder = SortformerEncLabelModel.from_config_dict(self._cfg.transformer_encoder) + self._init_loss_weights() + + self.eps = 1e-3 + self.loss = instantiate(self._cfg.loss) + + self.streaming_mode = self._cfg.get("streaming_mode", False) + self.save_hyperparameters("cfg") + self._init_eval_metrics() + + speaker_inds = list(range(self._cfg.max_num_of_spks)) + self.speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) # Get all permutations + + def _init_loss_weights(self): + pil_weight = self._cfg.get("pil_weight", 0.0) + ats_weight = self._cfg.get("ats_weight", 1.0) + if pil_weight + ats_weight == 0: + raise ValueError(f"weights for PIL {pil_weight} and ATS {ats_weight} cannot sum to 0") + self.pil_weight = pil_weight/(pil_weight + ats_weight) + self.ats_weight = ats_weight/(pil_weight + ats_weight) + logging.info(f"Normalized weights for PIL {self.pil_weight} and ATS {self.ats_weight}") + + def _init_eval_metrics(self): + """ + If there is no label, then the evaluation metrics will be based on Permutation Invariant Loss (PIL). + """ + self._accuracy_test = MultiBinaryAccuracy() + self._accuracy_train = MultiBinaryAccuracy() + self._accuracy_valid = MultiBinaryAccuracy() + + self._accuracy_test_ats = MultiBinaryAccuracy() + self._accuracy_train_ats = MultiBinaryAccuracy() + self._accuracy_valid_ats = MultiBinaryAccuracy() + + def _reset_train_metrics(self): + self._accuracy_train.reset() + self._accuracy_train_ats.reset() + + def _reset_valid_metrics(self): + self._accuracy_valid.reset() + self._accuracy_valid_ats.reset() + + def __setup_dataloader_from_config(self, config): + # Switch to lhotse dataloader if specified in the config + if config.get("use_lhotse"): + return get_lhotse_dataloader_from_config( + config, + global_rank=self.global_rank, + world_size=self.world_size, + dataset=LhotseAudioToSpeechE2ESpkDiarDataset(cfg=config), + ) + + featurizer = WaveformFeaturizer( + sample_rate=config['sample_rate'], int_values=config.get('int_values', False), augmentor=self.augmentor + ) + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + logging.info(f"Loading dataset from {config.manifest_filepath}") + + if self._trainer is not None: + global_rank = self._trainer.global_rank + else: + global_rank = 0 + time_flag = time.time() + logging.info("AAB: Starting Dataloader Instance loading... Step A") + + dataset = AudioToSpeechE2ESpkDiarDataset( + manifest_filepath=config.manifest_filepath, + soft_label_thres=config.soft_label_thres, + session_len_sec=config.session_len_sec, + num_spks=config.num_spks, + featurizer=featurizer, + window_stride=self._cfg.preprocessor.window_stride, + global_rank=global_rank, + soft_targets=config.soft_targets if 'soft_targets' in config else False, + ) + logging.info(f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader step B: {time.time() - time_flag}") + + self.data_collection = dataset.collection + self.collate_ds = dataset + + dataloader_instance = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config.batch_size, + collate_fn=self.collate_ds.eesd_train_collate_fn, + drop_last=config.get('drop_last', False), + shuffle=False, + num_workers=config.get('num_workers', 1), + pin_memory=config.get('pin_memory', False), + ) + logging.info(f"AAC: Dataloader Instance loading is done ETA Step B done: {time.time() - time_flag}") + return dataloader_instance + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + self._train_dl = self.__setup_dataloader_from_config(config=train_data_config,) + + def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]): + self._validation_dl = self.__setup_dataloader_from_config(config=val_data_layer_config,) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + self._test_dl = self.__setup_dataloader_from_config(config=test_data_config,) + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + audio_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + audio_eltype = AudioSignal() + return { + "audio_signal": NeuralType(('B', 'T'), audio_eltype), + "audio_signal_length": NeuralType(('B',), LengthsType()), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + return OrderedDict( + { + "preds": NeuralType(('B', 'T', 'C'), ProbsType()), + } + ) + + def frontend_encoder(self, processed_signal, processed_signal_length): + """ + Generate encoder outputs from frontend encoder. + + Args: + process_signal (torch.Tensor): tensor containing audio-feature (mel spectrogram, mfcc, etc.) + processed_signal_length (torch.Tensor): tensor containing lengths of audio signal in integers + + Returns: + emb_seq (torch.Tensor): tensor containing encoder outputs + emb_seq_length (torch.Tensor): tensor containing lengths of encoder outputs + """ + # Spec augment is not applied during evaluation/testing + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + self.encoder = self.encoder.to(self.device) + emb_seq, emb_seq_length = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + emb_seq = emb_seq.transpose(1, 2) + if self._cfg.encoder.d_model != self._cfg.tf_d_model: + self.sortformer_modules.encoder_proj = self.sortformer_modules.encoder_proj.to(self.device) + emb_seq = self.sortformer_modules.encoder_proj(emb_seq) + return emb_seq, emb_seq_length + + def forward_infer(self, emb_seq): + """ + The main forward pass for diarization for offline diarization inference. + + Args: + emb_seq (torch.Tensor): tensor containing FastConformer encoder states (embedding vectors). + Dimension: (batch_size, diar_frame_count, emb_dim) + + Returns: + preds (torch.Tensor): Sorted tensor containing Sigmoid values for predicted speaker labels. + Dimension: (batch_size, diar_frame_count, num_speakers) + encoder_states_list (list): List containing total speaker memory for each step for debugging purposes + Dimension: [(batch_size, diar_frame_count, inner dim), ... ] + """ + encoder_mask = self.sortformer_modules.length_to_mask(emb_seq) + trans_emb_seq = self.transformer_encoder(encoder_states=emb_seq, encoder_mask=encoder_mask) + preds = self.sortformer_modules.forward_speaker_sigmoids(trans_emb_seq) + return preds + + def process_signal(self, audio_signal, audio_signal_length): + """ + Extract audio features from time-series signal for further processing in the model. + + This function performs the following steps: + 1. Moves the audio signal to the correct device. + 2. Normalizes the time-series audio signal. + 3. Extrac audio feature from from the time-series audio signal using the model's preprocessor. + + Args: + audio_signal (torch.Tensor): The input audio signal. + Shape: (batch_size, num_samples) + audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + Shape: (batch_size,) + + Returns: + tuple: A tuple containing: + - processed_signal (torch.Tensor): The preprocessed audio signal. + Shape: (batch_size, num_features, num_frames) + - processed_signal_length (torch.Tensor): The length of each processed signal. + Shape: (batch_size,) + """ + audio_signal = audio_signal.to(self.device) + audio_signal = (1/(audio_signal.max()+self.eps)) * audio_signal + processed_signal, processed_signal_length = self.preprocessor(input_signal=audio_signal, length=audio_signal_length) + return processed_signal, processed_signal_length + + def forward( + self, + audio_signal, + audio_signal_length, + ): + """ + Forward pass for training and inference. + + Args: + audio_signal (torch.Tensor): tensor containing audio waveform + Dimension: (batch_size, num_samples) + audio_signal_length (torch.Tensor): tensor containing lengths of audio waveforms + Dimension: (batch_size,) + + Returns: + preds (torch.Tensor): Sorted tensor containing predicted speaker labels + Dimension: (batch_size, diar_frame_count, num_speakers) + encoder_states_list (list): List containing total speaker memory for each step for debugging purposes + Dimension: [(batch_size, diar_frame_count, inner dim), ] + """ + processed_signal, processed_signal_length = self.process_signal(audio_signal=audio_signal, audio_signal_length=audio_signal_length) + processed_signal = processed_signal[:, :, :processed_signal_length.max()] + if self._cfg.get("streaming_mode", False): + raise NotImplementedError("Streaming mode is not implemented yet.") + else: + emb_seq, _ = self.frontend_encoder(processed_signal=processed_signal, processed_signal_length=processed_signal_length) + preds = self.forward_infer(emb_seq) + return preds + + def _get_aux_train_evaluations(self, preds, targets, target_lens): + """ + Compute auxiliary training evaluations including losses and metrics. + + This function calculates various losses and metrics for the training process, + including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) + based evaluations. + + Args: + preds (torch.Tensor): Predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + targets (torch.Tensor): Ground truth speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + target_lens (torch.Tensor): Lengths of target sequences. + Shape: (batch_size,) + + Returns: + (dict): A dictionary containing the following training metrics. + """ + targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + ats_loss = self.loss(probs=preds, labels=targets_ats, target_lens=target_lens) + pil_loss = self.loss(probs=preds, labels=targets_pil, target_lens=target_lens) + loss = self.ats_weight * ats_loss + self.pil_weight * pil_loss + + self._accuracy_train(preds, targets_pil, target_lens) + train_f1_acc, train_precision, train_recall = self._accuracy_train.compute() + + self._accuracy_train_ats(preds, targets_ats, target_lens) + train_f1_acc_ats, _, _ = self._accuracy_train_ats.compute() + + train_metrics = { + 'loss': loss, + 'ats_loss': ats_loss, + 'pil_loss': pil_loss, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'train_f1_acc': train_f1_acc, + 'train_precision': train_precision, + 'train_recall': train_recall, + 'train_f1_acc_ats': train_f1_acc_ats, + } + return train_metrics + + def training_step(self, batch: list) -> dict: + """ + Performs a single training step. + + Args: + batch (list): A list containing the following elements: + - audio_signal (torch.Tensor): The input audio signal in time-series format. + - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + - targets (torch.Tensor): The target labels for the batch. + - target_lens (torch.Tensor): The length of each target sequence in the batch. + batch_idx (int): The index of the current batch. + + Returns: + (dict): A dictionary containing the 'loss' key with the calculated loss value. + """ + audio_signal, audio_signal_length, targets, target_lens = batch + preds = self.forward(audio_signal=audio_signal, audio_signal_length=audio_signal_length) + train_metrics = self._get_aux_train_evaluations(preds, targets, target_lens) + self._reset_train_metrics() + self.log_dict(train_metrics, sync_dist=True, on_step=True, on_epoch=False, logger=True) + return {'loss': train_metrics['loss']} + + def _get_aux_validation_evaluations(self, preds, targets, target_lens): + """ + Compute auxiliary validation evaluations including losses and metrics. + This function calculates various losses and metrics for the validation process, + including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) + based evaluations. + + Args: + preds (torch.Tensor): Predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + targets (torch.Tensor): Ground truth speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + target_lens (torch.Tensor): Lengths of target sequences. + Shape: (batch_size,) + + Returns: + dict: A dictionary containing the following validation metrics + """ + targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + + val_ats_loss = self.loss(probs=preds, labels=targets_ats, target_lens=target_lens) + val_pil_loss = self.loss(probs=preds, labels=targets_pil, target_lens=target_lens) + val_loss = self.ats_weight * val_ats_loss + self.pil_weight * val_pil_loss + + self._accuracy_valid(preds, targets_pil, target_lens) + val_f1_acc, val_precision, val_recall = self._accuracy_valid.compute() + + self._accuracy_valid_ats(preds, targets_ats, target_lens) + valid_f1_acc_ats, _, _ = self._accuracy_valid_ats.compute() + + self._accuracy_valid.reset() + self._accuracy_valid_ats.reset() + + val_metrics = { + 'val_loss': val_loss, + 'val_ats_loss': val_ats_loss, + 'val_pil_loss': val_pil_loss, + 'val_f1_acc': val_f1_acc, + 'val_precision': val_precision, + 'val_recall': val_recall, + 'val_f1_acc_ats': valid_f1_acc_ats, + } + return val_metrics + + def validation_step(self, batch: list, dataloader_idx: int = 0): + """ + Performs a single validation step. + + This method processes a batch of data during the validation phase. It forward passes + the audio signal through the model, computes various validation metrics, and stores + these metrics for later aggregation. + + Args: + batch (list): A list containing the following elements: + - audio_signal (torch.Tensor): The input audio signal. + - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + - targets (torch.Tensor): The target labels for the batch. + - target_lens (torch.Tensor): The length of each target sequence in the batch. + batch_idx (int): The index of the current batch. + dataloader_idx (int, optional): The index of the dataloader in case of multiple + validation dataloaders. Defaults to 0. + + Returns: + dict: A dictionary containing various validation metrics for this batch. + """ + audio_signal, audio_signal_length, targets, target_lens = batch + preds = self.forward( + audio_signal=audio_signal, + audio_signal_length=audio_signal_length, + ) + val_metrics = self._get_aux_validation_evaluations(preds, targets, target_lens) + if isinstance(self.trainer.val_dataloaders, list) and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(val_metrics) + else: + self.validation_step_outputs.append(val_metrics) + return val_metrics + + def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): + if not outputs: + logging.warning(f"`outputs` is None; empty outputs for dataloader={dataloader_idx}") + return None + val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() + val_ats_loss_mean = torch.stack([x['val_ats_loss'] for x in outputs]).mean() + val_pil_loss_mean = torch.stack([x['val_pil_loss'] for x in outputs]).mean() + val_f1_acc_mean = torch.stack([x['val_f1_acc'] for x in outputs]).mean() + val_precision_mean = torch.stack([x['val_precision'] for x in outputs]).mean() + val_recall_mean = torch.stack([x['val_recall'] for x in outputs]).mean() + val_f1_acc_ats_mean = torch.stack([x['val_f1_acc_ats'] for x in outputs]).mean() + + self._reset_valid_metrics() + + multi_val_metrics = { + 'val_loss': val_loss_mean, + 'val_ats_loss': val_ats_loss_mean, + 'val_pil_loss': val_pil_loss_mean, + 'val_f1_acc': val_f1_acc_mean, + 'val_precision': val_precision_mean, + 'val_recall': val_recall_mean, + 'val_f1_acc_ats': val_f1_acc_ats_mean, + } + return {'log': multi_val_metrics} + + def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target_lens): + """ + Compute auxiliary validation evaluations including losses and metrics. + + This function calculates various losses and metrics for the validation process, + including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) + based evaluations. + + Args: + preds (torch.Tensor): Predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + targets (torch.Tensor): Ground truth speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + target_lens (torch.Tensor): Lengths of target sequences. + Shape: (batch_size,) + + Returns: + dict: A dictionary containing the following validation metrics + """ + targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + self._accuracy_test(preds, targets_pil, target_lens) + f1_acc, precision, recall = self._accuracy_test.compute() + self.batch_f1_accs_list.append(f1_acc) + self.batch_precision_list.append(precision) + self.batch_recall_list.append(recall) + logging.info(f"batch {batch_idx}: f1_acc={f1_acc}, precision={precision}, recall={recall}") + + self._accuracy_test_ats(preds, targets_ats, target_lens) + f1_acc_ats, precision_ats, recall_ats = self._accuracy_test_ats.compute() + self.batch_f1_accs_ats_list.append(f1_acc_ats) + logging.info(f"batch {batch_idx}: f1_acc_ats={f1_acc_ats}, precision_ats={precision_ats}, recall_ats={recall_ats}") + + self._accuracy_test.reset() + self._accuracy_test_ats.reset() + + def test_batch(self,): + """ + Perform batch testing on the model. + + This method iterates through the test data loader, making predictions for each batch, + and calculates various evaluation metrics. It handles both single and multi-sample batches. + """ + self.preds_total_list, self.batch_f1_accs_list, self.batch_precision_list, self.batch_recall_list, self.batch_f1_accs_ats_list = [], [], [], [], [] + + with torch.no_grad(): + for batch_idx, batch in enumerate(tqdm(self._test_dl)): + audio_signal, audio_signal_length, targets, target_lens = batch + audio_signal = audio_signal.to(self.device) + audio_signal_length = audio_signal_length.to(self.device) + preds = self.forward( + audio_signal=audio_signal, + audio_signal_length=audio_signal_length, + ) + preds = preds.detach().to('cpu') + if preds.shape[0] == 1: # batch size = 1 + self.preds_total_list.append(preds) + else: + self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0])) + torch.cuda.empty_cache() + self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens) + + logging.info(f"Batch F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_list))}") + logging.info(f"Batch Precision MEAN: {torch.mean(torch.tensor(self.batch_precision_list))}") + logging.info(f"Batch Recall MEAN: {torch.mean(torch.tensor(self.batch_recall_list))}") + logging.info(f"Batch ATS F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_ats_list))}") + + def diarize(self,): + raise NotImplementedError diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py new file mode 100644 index 000000000000..a1d34e1f7480 --- /dev/null +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -0,0 +1,1231 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import copy +import math +import random +import logging +import itertools +from copy import deepcopy +import concurrent.futures +from cytoolz import groupby +from collections import defaultdict +from typing import Dict, Optional, Tuple, List + +import numpy as np +import soundfile +from tqdm import tqdm +from scipy.stats import norm + +import torch.utils.data +from lhotse.cut.set import mix +from lhotse.cut import CutSet, MixedCut, MonoCut, MixTrack +from lhotse import SupervisionSet, SupervisionSegment, dill_enabled, AudioSource, Recording +from lhotse.utils import uuid4 + +def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres:float = 0.5) -> torch.Tensor: + """ + Finds the first nonzero value in the matrix, discretizing it to the specified maximum capacity. + + Args: + mat (Tensor): A torch tensor representing the matrix. + max_cap_val (int): The maximum capacity to which the matrix values will be discretized. + thres (float): The threshold value for discretizing the matrix values. + + Returns: + mask_max_indices (Tensor): A torch tensor representing the discretized matrix with the first nonzero value in each row. + """ + # Discretize the matrix to the specified maximum capacity + labels_discrete = mat.clone() + labels_discrete[labels_discrete < thres] = 0 + labels_discrete[labels_discrete >= thres] = 1 + + # non zero values mask + non_zero_mask = labels_discrete != 0 + # operations on the mask to find first nonzero values in the rows + mask_max_values, mask_max_indices = torch.max(non_zero_mask, dim=1) + # if the max-mask is zero, there is no nonzero value in the row + mask_max_indices[mask_max_values == 0] = max_cap_val + return mask_max_indices + +def find_best_permutation(match_score: torch.Tensor, speaker_permutations: torch.Tensor) -> torch.Tensor: + """ + Finds the best permutation indices based on the match score. + + Args: + match_score (torch.Tensor): A tensor containing the match scores for each permutation. + Shape: (batch_size, num_permutations) + speaker_permutations (torch.Tensor): A tensor containing all possible speaker permutations. + Shape: (num_permutations, num_speakers) + + Returns: + torch.Tensor: A tensor containing the best permutation indices for each batch. + Shape: (batch_size, num_speakers) + """ + batch_best_perm = torch.argmax(match_score, axis=1) + rep_speaker_permutations = speaker_permutations.repeat(batch_best_perm.shape[0], 1).to(match_score.device) + perm_size = speaker_permutations.shape[0] + global_inds_vec = torch.arange(0, perm_size * batch_best_perm.shape[0], perm_size).to(batch_best_perm.device) + batch_best_perm + return rep_speaker_permutations[global_inds_vec.to(rep_speaker_permutations.device), :] + +def reconstruct_labels(labels: torch.Tensor, batch_perm_inds: torch.Tensor) -> torch.Tensor: + """ + Reconstructs the labels using the best permutation indices with matrix operations. + + Args: + labels (torch.Tensor): A tensor containing the original labels. + Shape: (batch_size, num_frames, num_speakers) + batch_perm_inds (torch.Tensor): A tensor containing the best permutation indices for each batch. + Shape: (batch_size, num_speakers) + + Returns: + torch.Tensor: A tensor containing the reconstructed labels using the best permutation indices. + Shape: (batch_size, num_frames, num_speakers) + """ + # Expanding batch_perm_inds to align with labels dimensions + batch_size, num_frames, num_speakers = labels.shape + batch_perm_inds_exp = batch_perm_inds.unsqueeze(1).expand(-1, num_frames, -1) + + # Reconstructing the labels using advanced indexing + reconstructed_labels = torch.gather(labels, 2, batch_perm_inds_exp) + return reconstructed_labels + +def get_ats_targets( + labels: torch.Tensor, + preds: torch.Tensor, + speaker_permutations: torch.Tensor, + thres: float = 0.5, + tolerance: float = 0 +) -> torch.Tensor: + """ + Sorts labels and predictions to get the optimal of all arrival-time ordered permutations. + + Args: + labels (torch.Tensor): A tensor containing the original labels. + Shape: (batch_size, num_frames, num_speakers) + preds (torch.Tensor): A tensor containing the predictions. + Shape: (batch_size, num_frames, num_speakers) + speaker_permutations (torch.Tensor): A tensor containing all possible speaker permutations. + Shape: (num_permutations, num_speakers) + thres (float): The threshold value for discretizing the matrix values. Default is 0.5. + tolerance (float): The tolerance for comparing the first speech frame indices. Default is 0. + + Returns: + torch.Tensor: A tensor containing the reconstructed labels using the best permutation indices. + Shape: (batch_size, num_frames, num_speakers) + """ + # Find the first nonzero frame index for each speaker in each batch + nonzero_ind = find_first_nonzero(mat=labels, max_cap_val=labels.shape[1], thres=thres) # (batch_size, num_speakers) + + # Sort the first nonzero frame indices for arrival-time ordering + sorted_values = torch.sort(nonzero_ind)[0] # (batch_size, num_speakers) + perm_size = speaker_permutations.shape[0] # Scalar value (num_permutations) + permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_frames, num_permutations, num_speakers) + permed_nonzero_ind = find_first_nonzero(mat=permed_labels, max_cap_val=labels.shape[1]) # (batch_size, num_permutations, num_speakers) + + # Compare the first frame indices of sorted labels with those of the permuted labels using tolerance + perm_compare = torch.abs(sorted_values.unsqueeze(1) - permed_nonzero_ind) <= tolerance # (batch_size, num_permutations, num_speakers) + perm_mask = torch.all(perm_compare, dim=2).float() # (batch_size, num_permutations) + preds_rep = torch.unsqueeze(preds, 2).repeat(1, 1, perm_size, 1) # Exapnd the preds: (batch_size, num_frames, num_permutations, num_speakers) + + # Compute the match score for each permutation by comparing permuted labels with preds + match_score = torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) * perm_mask # (batch_size, num_permutations) + batch_perm_inds = find_best_permutation(match_score, speaker_permutations) # (batch_size, num_speakers) + max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_frames, num_speakers) + return max_score_permed_labels # (batch_size, num_frames, num_speakers) + +def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutations: torch.Tensor) -> torch.Tensor: + """ + Sorts labels and predictions to get the optimal permutation based on the match score. + + Args: + labels (torch.Tensor): A tensor containing the ground truth labels. + Shape: (batch_size, num_speakers, num_classes) + preds (torch.Tensor): A tensor containing the predicted values. + Shape: (batch_size, num_speakers, num_classes) + speaker_permutations (torch.Tensor): A tensor containing all possible speaker permutations. + Shape: (num_permutations, num_speakers) + + Returns: + torch.Tensor: A tensor of permuted labels that best match the predictions. + Shape: (batch_size, num_speakers, num_classes) + """ + perm_size = speaker_permutations.shape[0] # Scalar value (num_permutations) + permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_classes, num_permutations, num_speakers) + # Repeat preds to match permutations for comparison + preds_rep = torch.unsqueeze(preds, 2).repeat(1, 1, speaker_permutations.shape[0], 1) # (batch_size, num_speakers, num_permutations, num_classes) + match_score = torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) # (batch_size, num_permutations) + batch_perm_inds = find_best_permutation(match_score, speaker_permutations) # (batch_size, num_speakers) + # Reconstruct labels based on the best permutation for each batch + max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_speakers, num_classes) + return max_score_permed_labels # (batch_size, num_speakers, num_classes) + +def apply_spk_mapping(diar_preds: torch.Tensor, spk_mappings: torch.Tensor) -> torch.Tensor: + """ + Applies a speaker mapping to diar predictions. + + Args: + diar_preds (Tensor): The diar predictions tensor. + Dimension: (batch_size, num_frames, num_speakers) + spk_mappings (Tensor): The speaker mappings tensor. + Dimension: (batch_size, num_speakers) + + Returns: + permuted_diar_preds (Tensor): The permuted diar predictions tensor with the given speaker mappings. + """ + expanded_mappings = spk_mappings.unsqueeze(1).expand(-1, diar_preds.size(1), -1) + permuted_diar_preds = torch.gather(diar_preds, 2, expanded_mappings) + return permuted_diar_preds + +def shuffle_spk_mapping(cuts: list, num_speakers: int, shuffle_spk_mapping: bool = False, pattern= r'<\|spltoken\d+\|>') -> Tuple[CutSet, torch.Tensor]: + """ + Applies a shuffle mapping to speaker text labels in the cuts. + Example: + Original cut.text: + "<|spltoken0|> we do shuffle <|spltoken1|> and map speakers <|spltoken0|> yes <|spltoken2|> we keep dimensions" + Speaker Mapping: [3, 0, 1, 2] + Shuffled cut.text: + "<|spltoken3|> we do shuffle <|spltoken0|> and map speakers <|spltoken3|> yes <|spltoken1|> we keep dimensions" + + Args: + cuts (List[MonoCut, MixedCut]): A list of Cut instances. + num_speakers (int): The total number of speakers. + shuffle_spk_mapping (bool): Whether to shuffle the speaker mappings. + pattern (str): A regular expression pattern for speaker tokens. + + Returns: + cuts (list): The updated CutSet with shuffled speaker mappings. + spk_mappings (Tensor): + If shuffle_speaker_mapping is True, shuffled speaker mappings in batch. + If shuffle_speaker_mapping is False, speaker mappings in batch is not permuted and returns torch.arange() values. + """ + batch_size = len(cuts) + if shuffle_spk_mapping: + permuted_indices = torch.rand(batch_size, num_speakers).argsort(dim=1) + spk_mappings = torch.gather(torch.arange(num_speakers).repeat(batch_size, 1), 1, permuted_indices) + str_pattern = pattern.replace("\\", '') + left_str, right_str = str_pattern.split('d+')[0], str_pattern.split('d+')[1] + for idx, cut in enumerate(cuts): + word_list = [] + for word in deepcopy(cut.text).split(): + if len(re.findall(pattern, word)) > 0: + spk_token_int = int(word.replace(left_str,'').replace(right_str, '')) + new_spk = spk_mappings[idx][spk_token_int] + word_list.append(f'{left_str}{new_spk}{right_str}') + else: + word_list.append(word) + cuts[idx].supervisions[0].text = ' '.join(word_list) + else: + spk_mappings = torch.arange(num_speakers).unsqueeze(0).repeat(batch_size, 1) + return cuts, spk_mappings + +def find_segments_from_rttm( + recording_id: str, + rttms, + start_after: float, + end_before: float, + adjust_offset: bool=True, + tolerance: float=0.001): + """ + Finds segments from the given rttm file. + This function is designed to replace rttm + + Args: + recording_id (str): The recording ID in string format. + rttms (SupervisionSet): The SupervisionSet instance. + start_after (float): The start time after which segments are selected. + end_before (float): The end time before which segments are selected. + adjust_offset (bool): Whether to adjust the offset of the segments. + tolerance (float): The tolerance for time matching. 0.001 by default. + + Returns: + segments (List[SupervisionSegment]): A list of SupervisionSegment instances. + """ + segment_by_recording_id = rttms._segments_by_recording_id + if segment_by_recording_id is None: + from cytoolz import groupby + segment_by_recording_id = groupby(lambda seg: seg.recording_id, rttms) + + return [ + # We only modify the offset - the duration remains the same, as we're only shifting the segment + # relative to the Cut's start, and not truncating anything. + segment.with_offset(-start_after) if adjust_offset else segment + for segment in segment_by_recording_id.get(recording_id, []) + if segment.start < end_before + tolerance + and segment.end > start_after + tolerance + ] + +def speaker_to_target( + a_cut, + num_speakers: int = 4, + num_sample_per_mel_frame: int = 160, + num_mel_frame_per_asr_frame: int = 8, + spk_tar_all_zero: bool = False, + boundary_segments: bool = False, + soft_label: bool = False, + ignore_num_spk_mismatch: bool = True, + soft_thres: float = 0.5, + ): + ''' + Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape (num_speaker, hidden_length) + This function is needed for speaker diarization with ASR model trainings. + + Args: + a_cut (MonoCut, MixedCut): Lhotse Cut instance which is MonoCut or MixedCut instance. + num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default + num_sample_per_mel_frame (int): number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) + num_mel_frame_per_asr_frame (int): encoder subsampling_factor, 8 by default + spk_tar_all_zero (Tensor): set to True gives all zero "mask" + boundary_segments (bool): set to True to include segments containing the boundary of the cut, False by default for multi-speaker ASR training + soft_label (bool): set to True to use soft label that enables values in [0, 1] range, False by default and leads to binary labels. + ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. + + Returns: + mask (Tensor): speaker mask with shape (num_speaker, hidden_lenght) + ''' + # get cut-related segments from rttms + # basename = os.path.basename(a_cut.rttm_filepath).replace('.rttm', '') + if isinstance(a_cut, MixedCut): + cut_list = [track.cut for track in a_cut.tracks if isinstance(track.cut, MonoCut)] + offsets = [track.offset for track in a_cut.tracks if isinstance(track.cut, MonoCut)] + elif isinstance(a_cut, MonoCut): + cut_list = [a_cut] + offsets = [0] + else: + raise ValueError(f"Unsupported cut type type{a_cut}: only MixedCut and MonoCut are supported") + + segments_total = [] + for i, cut in enumerate(cut_list): + rttms = SupervisionSet.from_rttm(cut.rttm_filepath) + if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included + segments_iterator = find_segments_from_rttm(recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0) + else: # segments with seg_start > total_start and seg_end < total_end are included + segments_iterator = rttms.find(recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True) + + for seg in segments_iterator: + if seg.start < 0: + seg.duration += seg.start + seg.start = 0 + if seg.end > cut.duration: + seg.duration -= seg.end - cut.duration + seg.start += offsets[i] + segments_total.append(seg) + + # apply arrival time sorting to the existing segments + segments_total.sort(key = lambda rttm_sup: rttm_sup.start) + + seen = set() + seen_add = seen.add + speaker_ats = [s.speaker for s in segments_total if not (s.speaker in seen or seen_add(s.speaker))] + + speaker_to_idx_map = { + spk: idx + for idx, spk in enumerate(speaker_ats) + } + if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers + raise ValueError(f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers {num_speakers}") + + # initialize mask matrices (num_speaker, encoder_hidden_len) + feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default + num_samples = get_hidden_length_from_sample_length(a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame) + if spk_tar_all_zero: + frame_mask = torch.zeros((num_samples, num_speakers)) + else: + frame_mask = get_mask_from_segments(segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch) + soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) + + if soft_label: + mask = soft_mask + else: + mask = (soft_mask > soft_thres).float() + + return mask + +def get_mask_from_segments(segments: list, a_cut, speaker_to_idx_map: torch.Tensor, num_speakers: int =4, feat_per_sec: int=100, ignore_num_spk_mismatch: bool = False): + """ + Generate mask matrix from segments list. + This function is needed for speaker diarization with ASR model trainings. + + Args: + segments: A list of Lhotse Supervision segments iterator. + cut (MonoCut, MixedCut): Lhotse MonoCut or MixedCut instance. + speaker_to_idx_map (dict): A dictionary mapping speaker names to indices. + num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default + feat_per_sec (int): number of frames per second, 100 by default, 0.01s frame rate + ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. + + Returns: + mask (Tensor): A numpy array of shape (num_speakers, encoder_hidden_len). + Dimension: (num_speakers, num_frames) + """ + # get targets with 0.01s frame rate + num_samples = round(a_cut.duration * feat_per_sec) + mask = torch.zeros((num_samples, num_speakers)) + for rttm_sup in segments: + speaker_idx = speaker_to_idx_map[rttm_sup.speaker] + if speaker_idx >= num_speakers: + if ignore_num_spk_mismatch: + continue + else: + raise ValueError(f"Speaker Index {speaker_idx} exceeds the max index: {num_speakers-1}") + stt = max(rttm_sup.start, 0) + ent = min(rttm_sup.end, a_cut.duration) + stf = int(stt * feat_per_sec) + enf = int(ent * feat_per_sec) + mask[stf:enf, speaker_idx] = 1.0 + return mask + +def get_soft_mask(feat_level_target, num_samples, stride): + """ + Get soft mask from feat_level_target with stride. + This function is needed for speaker diarization with ASR model trainings. + + Args: + feat_level_target (Tensor): A numpy array of shape (num_frames, num_speakers). + Dimension: (num_frames, num_speakers) + num_sample (int): The total number of samples. + stride (int): The stride for the mask. + """ + + num_speakers = feat_level_target.shape[1] + mask = torch.zeros(num_samples, num_speakers) + + for index in range(num_samples): + if index == 0: + seg_stt_feat = 0 + else: + seg_stt_feat = stride * index - 1 - int(stride / 2) + if index == num_samples - 1: + seg_end_feat = feat_level_target.shape[0] + else: + seg_end_feat = stride * index - 1 + int(stride / 2) + mask[index] = torch.mean(feat_level_target[seg_stt_feat:seg_end_feat+1, :], axis=0) + return mask + +def get_hidden_length_from_sample_length( + num_samples: int, + num_sample_per_mel_frame: int = 160, + num_mel_frame_per_asr_frame: int = 8 +) -> int: + """ + Calculate the hidden length from the given number of samples. + This function is needed for speaker diarization with ASR model trainings. + + This function computes the number of frames required for a given number of audio samples, + considering the number of samples per mel frame and the number of mel frames per ASR frame. + + Parameters: + num_samples (int): The total number of audio samples. + num_sample_per_mel_frame (int, optional): The number of samples per mel frame. Default is 160. + num_mel_frame_per_asr_frame (int, optional): The number of mel frames per ASR frame. Default is 8. + + Returns: + hidden_length (int): The calculated hidden length in terms of the number of frames. + """ + mel_frame_count = math.ceil((num_samples + 1) / num_sample_per_mel_frame) + hidden_length = math.ceil(mel_frame_count / num_mel_frame_per_asr_frame) + return int(hidden_length) + +class ConcatenationMeetingSimulator(): + """ + This simulator concatenates the segments from different/same sessions to create a + multi-speaker meeting. + """ + + def __init__( + self, + intra_session_concat_prob: float|List[float] = [0, 1.0, 0.5, 0.2], + data_type: str = "msasr", + min_duration: float = 30.0, + max_duration: float = 40.0, + max_num_speakers: int = 4, + speaker_count_distribution: List[float] = [0, 2, 3, 4], + skip_long_segments: bool = True, + valid_dataset_ids: List[str] = [], + ): + """ + :param intra_session_concat_prob: the probability of concatenating segments from the same + session. [Default: 1] + :param data_type: the type of data to simulate. Either 'msasr' or 'diar'. If 'msasr', + the transcripts are included in the simulation,and the boundary segments are + not included. [Default: 'msasr'] + :param max_duration: the maximum duration of the simulated meeting. [Default: 40.0] + """ + super().__init__() + if isinstance(intra_session_concat_prob, float): + self.intra_session_concat_prob = [intra_session_concat_prob] * (max_num_speakers) + elif len(intra_session_concat_prob) == max_num_speakers: + self.intra_session_concat_prob = intra_session_concat_prob + else: + raise ValueError(f"intra_session_concat_prob must be either a float or a list of floats, but got {intra_session_concat_prob}") + if data_type not in ["msasr", "diar"]: + raise ValueError("data_type must be either 'msasr' or 'diar', but got {data_type}") + self.data_type = data_type + self.min_duration = min_duration + self.max_duration = max_duration + self.max_num_speakers = max_num_speakers + self.speaker_count_distribution = speaker_count_distribution + assert len(speaker_count_distribution) == max_num_speakers, f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" + + if skip_long_segments: + self.skip_duration = max_duration / 2 + else: + self.skip_duration = max_duration + + self.valid_dataset_ids = valid_dataset_ids + + def fit(self, cuts) -> CutSet: + """ + Read the manifest file and return a CutSet object. + Each line in the manifest file should be a JSON object representing a segment. + """ + + self.id2cut = {} + self.sess2cut_ids = defaultdict(list) + self.sess2spks = defaultdict(set) + self.data2sess_ids = defaultdict(list) + self.spk2cut_ids = defaultdict(list) + self.data2num_spk2cut_ids = {} + self.sess2num_spk2cut_ids = {} + self.num_spk2cut_ids = {i+1:[] for i in range(self.max_num_speakers)} + for i, cut in tqdm(enumerate(cuts), desc="Reading segments", ncols=100, total=len(cuts)): + if cut.duration > self.skip_duration: + continue + if not hasattr(cut, 'dataset_id') or cut.dataset_id is None: + continue + if self.valid_dataset_ids and cut.dataset_id not in self.valid_dataset_ids: + continue + if cut.dataset_id not in self.data2num_spk2cut_ids: + self.data2num_spk2cut_ids[cut.dataset_id] = defaultdict(list) + if cut.recording_id not in self.sess2num_spk2cut_ids: + self.sess2num_spk2cut_ids[cut.recording_id] = defaultdict(list) + + speakers = cut.global_speaker_ids + if self.data_type == "msasr": + speaker_tokens = set(re.findall(r'<\|spltoken\d+\|>', cut.text)) + if len(speakers) != len(speaker_tokens): + # Lhotse automatically fixes the max duration of the cut, + # resulting in the mismatch of the number of speakers + # and speaker tokens for the last segment + # TODO: need to fix the issue in Lhotse that automatically fixes the max duration + continue + for spk in speakers: + self.spk2cut_ids[spk].append(cut.id) + self.sess2spks[cut.recording_id] = self.sess2spks[cut.recording_id].union(speakers) + + self.id2cut[cut.id] = cut + self.sess2cut_ids[cut.recording_id].append(cut.id) + self.data2num_spk2cut_ids[cut.dataset_id][len(speakers)].append(cut.id) + self.sess2num_spk2cut_ids[cut.recording_id][len(speakers)].append(cut.id) + self.num_spk2cut_ids[len(speakers)].append(cut.id) + if cut.recording_id not in self.data2sess_ids[cut.dataset_id]: + self.data2sess_ids[cut.dataset_id].append(cut.recording_id) + + self.cut_ids = list(self.id2cut.keys()) + self.num_spk2sess_ids = groupby(lambda x: len(self.sess2spks[x]), self.sess2spks.keys()) + + self.data2global_speaker = { + dataset_id: True for dataset_id in self.data2sess_ids.keys() + } + + def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> MixedCut: + + db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data + + if is_intra_session_concat: + # intra-dataset and intra-session concatenation + tracks, num_speakers = self.get_intra_session_tracks(n_speakers, db_norm=db_norm) + + else: + # intra-dataset but inter-session concatenation + tracks, num_speakers = self.get_inter_session_tracks(n_speakers, db_norm=db_norm) + + cut = MixedCut(id='concat_' + '_'.join([track.cut.id for track in tracks]), tracks=tracks) + if self.data_type == "msasr": + cut = self.reorder_spk_mapping(cut) + + assert self.min_duration <= cut.duration <= self.max_duration, f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" + assert n_speakers == num_speakers, f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" + + return cut + + def get_intra_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + """ + Get the tracks for the MixedCut object. + """ + session_id = random.choice(self.num_spk2sess_ids[n_speakers]) + + total_duration = 0.0 + total_spk_set = set() + tracks = [] + while True: + cut = self.id2cut[random.choice(self.sess2cut_ids[session_id])] + tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=total_duration)) + total_spk_set = total_spk_set.union(cut.global_speaker_ids) + total_duration += cut.duration + + # break condition + if total_duration >= self.min_duration: + if total_duration > self.max_duration: # exceed the maximum duration, starting over + total_duration = 0.0 + total_spk_set = set() + tracks = [] + session_id = random.choice(self.num_spk2sess_ids[n_speakers]) + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + break + else: + total_duration = 0.0 + total_spk_set = set() + tracks = [] + session_id = random.choice(self.num_spk2sess_ids[n_speakers]) + + return tracks, len(total_spk_set) + + def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + """ + Get the tracks for the MixedCut object. + """ + sample_cut = self.id2cut[random.choice(self.cut_ids)] + dataset_id = sample_cut.dataset_id + n_spk_list = [n_spk for n_spk, cut_ids in self.data2num_spk2cut_ids[dataset_id].items() if len(cut_ids) > 0] + sum_spk_list = set([i + j for i in n_spk_list for j in n_spk_list]) + + if min(sum_spk_list) > n_speakers: + raise ValueError(f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers.") + + n_spk_left = n_speakers + total_duration = 0.0 + total_spk_set = set() + tracks = [] + num_spk2cut_ids = self.data2num_spk2cut_ids[dataset_id] + while True: + #if n_spk_left == n_speakers: # for more speakers cases + # n_spk = random.choice([n_spk for n_spk in n_spk_list if n_spk < n_spk_left]) + if n_spk_left >= 2: + n_spk = 2 + else: + # n_spk = random.choice([n_spk for n_spk in n_spk_list if n_spk <= n_spk_left]) + n_spk = 1 + + while True: + cut = self.id2cut[random.choice(num_spk2cut_ids[n_spk])] + spks = set(cut.global_speaker_ids) + if not spks.intersection(total_spk_set): + break + + tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=total_duration)) + total_duration += cut.duration + n_spk_left -= n_spk + total_spk_set = total_spk_set.union(spks) + + # break condition + + if total_duration >= self.min_duration: + if total_duration > self.max_duration or len(total_spk_set) < n_speakers: # exceed the maximum duration, starting over + total_duration = 0.0 + n_spk_left = n_speakers + total_spk_set = set() + tracks = [] + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + break + else: + if len(total_spk_set) == n_speakers: # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers + total_duration = 0.0 + n_spk_left = n_speakers + total_spk_set = set() + tracks = [] + + return tracks, len(total_spk_set) + + def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> str: + """ + Concatenate the texts of the input cuts. + + """ + global_spk_mapping = {} + str_pattern = pattern.replace("\\", '') + left_str, right_str = str_pattern.split('d+') + for i, track in enumerate(cut.tracks): + local_inverse_spk_mapping = {} + local_spk_mapping = {} + for speaker in track.cut.global_speaker_ids: + if speaker not in global_spk_mapping: + global_spk_mapping[speaker] = len(global_spk_mapping) + if speaker not in local_spk_mapping: + local_spk_mapping[speaker] = len(local_spk_mapping) + local_inverse_spk_mapping[len(local_inverse_spk_mapping)] = speaker + + if i != 0: + text = '' + for word in track.cut.text.split(): + if len(re.findall(pattern, word)) > 0: + local_spk_idx = int(word.replace(left_str,'').replace(right_str, '')) + spk = local_inverse_spk_mapping[local_spk_idx] + global_spk_idx = global_spk_mapping[spk] + text += f'{left_str}{global_spk_idx}{right_str}' + else: + text += ' ' + word + track.cut.supervisions[0].text = text + cut.supervisions[i].text = text + else: + cut.supervisions[0].text = track.cut.text + # TODO: need to check the last speaker of last track and the first speaker of the current track + # if they are the same, we need to remove the the speaker token from the current track for segment-level + # Do not need to remove the speaker token for word-level + + return cut + + def apply_speaker_distribution(self, num_meetings: int, speaker_count_distribution) -> Dict[int, int]: + """ + Balance the speaker distribution for the simulated meetings. + Args: + num_meetings: The total number of simulated meetings. + speaker_count_distribution: The speaker count distribution for the simulated meetings. + For each number of speakers, calculate the number of meetings needed to balance the distribution. + """ + + total_spk = sum(speaker_count_distribution) + num_speakers2num_meetings = {} + for i_spk in range(self.max_num_speakers): + num_speakers2num_meetings[i_spk+1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) + + return num_speakers2num_meetings + + + @dill_enabled(True) + def simulate(self, + cuts: CutSet, + num_meetings: int = 10000, + seed: int = 0, + num_jobs: int = 1, + ) -> CutSet: + random.seed(seed) + + self.fit(cuts) + + + num_speakers2num_meetings = self.apply_speaker_distribution(num_meetings, self.speaker_count_distribution) + logging.warn(f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}.") + num_speakers2num_meetings[1] = 0 # skip 1-speaker samples + logging.warn(f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.') + + # Step 0: Calculate the number of intra-session and inter-session concatentation samples + n_spks = [k for k, v in self.num_spk2cut_ids.items() if len(v) > 0] + valid_sim_n_spks = set([i+j for i in n_spks for j in n_spks]) # valid number of speakers for inter-session samples + n_spk2n_intra_mt, n_spk2n_inter_mt = {i+1:0 for i in range(self.max_num_speakers)}, {i+1:0 for i in range(self.max_num_speakers)} + for n_spk, n_mt in num_speakers2num_meetings.items(): + logging.warn(f"=="*16 + f"{n_spk}-speaker" + "=="*16) + if n_mt <= 0: + logging.warning(f"No concatentation samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") + continue + n_intra_mt = int(n_mt * self.intra_session_concat_prob[n_spk-1]) + n_inter_mt = n_mt - n_intra_mt + if n_spk in self.num_spk2sess_ids: + logging.warn(f"Will be genrating {n_intra_mt} {n_spk}-speaker intra-session concatentation samples.") + n_spk2n_intra_mt[n_spk] = n_intra_mt + else: + logging.warning(f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") + n_spk2n_intra_mt[n_spk] = 0 + n_inter_mt = n_mt + if n_spk in valid_sim_n_spks: + logging.warn(f"Will be genrating {n_inter_mt} {n_spk}-speaker inter-session concatentation samples.") + n_spk2n_inter_mt[n_spk] = n_inter_mt + else: + logging.warning(f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") + if n_spk2n_intra_mt[n_spk] != 0: + n_spk2n_intra_mt[n_spk] = n_mt + logging.warn(f"Will be genrating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatentation samples instead.") + else: + logging.warning(f"No samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") + logging.warn(f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatentation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatentation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""") + # Step 1: intra-session + num_intra_meetings = 0 + intra_mixtures = [] + logging.info(f"Simulating intra-session concatentation samples.") + for n_spk, n_mt in n_spk2n_intra_mt.items(): + if n_mt <= 0: + continue + + for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker intra-session mixtures", ncols=128): + intra_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=True)) + num_intra_meetings += n_mt + logging.info(f"Finished simulating intra-session concatentation samples. Total number of intra-session concatentation samples: {num_intra_meetings}") + + # Steo 2: inter-session + logging.info(f"Simulating inter-session concatentation samples.") + + num_inter_meetings = 0 + inter_mixtures = [] + for n_spk, n_mt in n_spk2n_inter_mt.items(): + if n_mt <= 0: + continue + + for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker inter-session mixtures", ncols=128): + inter_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=False)) + num_inter_meetings += n_mt + logging.info(f"Finished simulating inter-session concatentation samples. Total number of inter-session concatentation samples: {num_inter_meetings}") + + if num_inter_meetings + num_intra_meetings == 0: + logging.warning(f"No samples are generated. Probably the duration of the segments is not within the range of min {self.min_duration//2} and max {self.max_duration//2}, or the speaker count distribution is not correctly set.") + + + # Multi-processing gets slower, TODO + # else: + # futures = [] + # for n_spk, n_mt in num_speakers2num_meetings.items(): + # tp = concurrent.futures.ProcessPoolExecutor(max_workers=num_jobs) + # futures.extend([tp.submit(self._create_mixture, n_spk) for _ in range(n_mt)]) + # pbar = tqdm(total=num_meetings, desc=f"Simulating mixtures", unit="line", ncols=128) + # count = 0 + # for f in concurrent.futures.as_completed(futures): + # count += 1 + # pbar.update() + # mixtures.append(f.result()) + # tp.shutdown() + # pbar.close() + + return CutSet.from_cuts(intra_mixtures + inter_mixtures) + + +class MixMeetingSimulator(): + """ + This simulator Mix the segments from different/same sessions to create a + multi-speaker meeting. + """ + + def __init__( + self, + intra_session_mix_prob: float|List[float] = [0, 0, 0, 0], + data_type: str = "msasr", + min_duration: float = 80.0, + max_duration: float = 100.0, + max_num_speakers: int = 4, + speaker_count_distribution: List[float] = [0, 0, 0.1, 4], + valid_dataset_ids: List[str] = [], + ): + """ + :param intra_session_mix_prob: the probability of concatenating segments from the same + session. [Default: 1] + :param data_type: the type of data to simulate. Either 'msasr' or 'diar'. If 'msasr', + the transcripts are included in the simulation,and the boundary segments are + not included. [Default: 'msasr'] + :param max_duration: the maximum duration of the simulated meeting. [Default: 40.0] + """ + super().__init__() + if isinstance(intra_session_mix_prob, float): + self.intra_session_mix_prob = [intra_session_mix_prob] * (max_num_speakers) + elif len(intra_session_mix_prob) == max_num_speakers: + self.intra_session_mix_prob = intra_session_mix_prob + else: + raise ValueError(f"intra_session_mix_prob must be either a float or a list of floats, but got {intra_session_mix_prob}") + if data_type not in ["msasr", "diar"]: + raise ValueError("data_type must be either 'msasr' or 'diar', but got {data_type}") + self.data_type = data_type + self.min_duration = min_duration + self.max_duration = max_duration + self.max_num_speakers = max_num_speakers + self.speaker_count_distribution = speaker_count_distribution + self.valid_dataset_ids = valid_dataset_ids + assert len(speaker_count_distribution) == max_num_speakers, f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" + + def fit(self, cuts) -> CutSet: + """ + Read the manifest file and return a CutSet object. + Each line in the manifest file should be a JSON object representing a segment. + """ + + self.id2cut = {} + self.sess2cut_ids = defaultdict(list) + self.sess2spks = defaultdict(set) + self.data2sess_ids = defaultdict(list) + self.spk2cut_ids = defaultdict(list) + self.data2num_spk2cut_ids = {} + self.sess2num_spk2cut_ids = {} + self.num_spk2cut_ids = {i+1:[] for i in range(self.max_num_speakers)} + for i, cut in tqdm(enumerate(cuts), desc="Reading segments", ncols=100, total=len(cuts)): + if not self.min_duration <= cut.duration <= self.max_duration: + continue + if not hasattr(cut, 'dataset_id') or cut.dataset_id is None: + continue + if self.valid_dataset_ids and cut.dataset_id not in self.valid_dataset_ids: + continue + if cut.dataset_id not in self.data2num_spk2cut_ids: + self.data2num_spk2cut_ids[cut.dataset_id] = defaultdict(list) + if cut.recording_id not in self.sess2num_spk2cut_ids: + self.sess2num_spk2cut_ids[cut.recording_id] = defaultdict(list) + + speakers = cut.global_speaker_ids + if self.data_type == "msasr": + speaker_tokens = set(re.findall(r'<\|spltoken\d+\|>', cut.text)) + if len(speakers) != len(speaker_tokens): + # Lhotse automatically fixes the max duration of the cut, + # resulting in the mismatch of the number of speakers + # and speaker tokens for the last segment + # TODO: need to fix the issue in Lhotse that automatically fixes the max duration + continue + for spk in speakers: + self.spk2cut_ids[spk].append(cut.id) + self.sess2spks[cut.recording_id] = self.sess2spks[cut.recording_id].union(speakers) + + self.id2cut[cut.id] = cut + self.sess2cut_ids[cut.recording_id].append(cut.id) + self.data2num_spk2cut_ids[cut.dataset_id][len(speakers)].append(cut.id) + self.sess2num_spk2cut_ids[cut.recording_id][len(speakers)].append(cut.id) + self.num_spk2cut_ids[len(speakers)].append(cut.id) + if cut.recording_id not in self.data2sess_ids[cut.dataset_id]: + self.data2sess_ids[cut.dataset_id].append(cut.recording_id) + + self.cut_ids = list(self.id2cut.keys()) + self.num_spk2sess_ids = groupby(lambda x: len(self.sess2spks[x]), self.sess2spks.keys()) + + self.data2global_speaker = { + dataset_id: True for dataset_id in self.data2sess_ids.keys() + } + + def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> MixedCut: + + db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data + + if is_intra_session_concat: + # intra-dataset and intra-session concatenation + tracks, num_speakers = self.get_intra_session_tracks(n_speakers, db_norm=db_norm) + + else: + # intra-dataset but inter-session concatenation + tracks, num_speakers = self.get_inter_session_tracks(n_speakers, db_norm=db_norm) + + cut = MixedCut(id='mix_' + '_'.join([track.cut.id for track in tracks]), tracks=tracks) + if self.data_type == "msasr": + cut = self.reorder_spk_mapping(cut) + + assert self.min_duration <= cut.duration <= self.max_duration, f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" + assert n_speakers == num_speakers, f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" + + return cut + + def get_intra_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + """ + Get the tracks for the MixedCut object. + """ + session_id = random.choice(self.num_spk2sess_ids[n_speakers]) + + total_spk_set = set() + tracks = [] + while True: + cut = self.id2cut[random.choice(self.sess2cut_ids[session_id])] + tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0)) + total_spk_set = total_spk_set.union(cut.global_speaker_ids) + total_duration = max(total_duration, cut.duration) + + # break condition + if total_duration >= self.min_duration: + if total_duration > self.max_duration: # exceed the maximum duration, starting over + total_duration = 0.0 + total_spk_set = set() + tracks = [] + session_id = random.choice(self.num_spk2sess_ids[n_speakers]) + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + break + else: + total_duration = 0.0 + total_spk_set = set() + tracks = [] + session_id = random.choice(self.num_spk2sess_ids[n_speakers]) + + return tracks, len(total_spk_set) + + def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + """ + Get the tracks for the MixedCut object. + """ + sample_cut = self.id2cut[random.choice(self.cut_ids)] + dataset_id = sample_cut.dataset_id + n_spk_list = [n_spk for n_spk, cut_ids in self.data2num_spk2cut_ids[dataset_id].items() if len(cut_ids) > 0] + sum_spk_list = set([i + j for i in n_spk_list for j in n_spk_list]) + + if min(sum_spk_list) > n_speakers: + raise ValueError(f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers.") + + n_spk_left = n_speakers + total_duration = 0.0 + total_spk_set = set() + tracks = [] + num_spk2cut_ids = self.data2num_spk2cut_ids[dataset_id] + while True: + if n_spk_left >= 2: + n_spk = 2 + else: + # n_spk = random.choice([n_spk for n_spk in n_spk_list if n_spk <= n_spk_left]) + n_spk = 1 + + while True: + cut = self.id2cut[random.choice(num_spk2cut_ids[n_spk])] + spks = set(cut.global_speaker_ids) + if not spks.intersection(total_spk_set): + break + + tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0)) + total_duration = max(total_duration, cut.duration) + n_spk_left -= n_spk + total_spk_set = total_spk_set.union(spks) + + # break condition + + if total_duration >= self.min_duration: + if total_duration > self.max_duration or len(tracks) > 2: # exceed the maximum duration, starting over + total_duration = 0.0 + n_spk_left = n_speakers + total_spk_set = set() + tracks = [] + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + break + else: + if len(total_spk_set) == n_speakers: # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers + total_duration = 0.0 + n_spk_left = n_speakers + total_spk_set = set() + tracks = [] + + return tracks, len(total_spk_set) + + def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> str: + """ + Concatenate the texts of the input cuts. + + """ + global_spk_mapping = {} + str_pattern = pattern.replace("\\", '') + left_str, right_str = str_pattern.split('d+') + for i, track in enumerate(cut.tracks): + local_inverse_spk_mapping = {} + local_spk_mapping = {} + for speaker in track.cut.global_speaker_ids: + if speaker not in global_spk_mapping: + global_spk_mapping[speaker] = len(global_spk_mapping) + if speaker not in local_spk_mapping: + local_spk_mapping[speaker] = len(local_spk_mapping) + local_inverse_spk_mapping[len(local_inverse_spk_mapping)] = speaker + + if i != 0: + text = '' + for word in track.cut.text.split(): + if len(re.findall(pattern, word)) > 0: + local_spk_idx = int(word.replace(left_str,'').replace(right_str, '')) + spk = local_inverse_spk_mapping[local_spk_idx] + global_spk_idx = global_spk_mapping[spk] + text += f'{left_str}{global_spk_idx}{right_str}' + else: + text += ' ' + word + track.cut.supervisions[0].text = text + cut.supervisions[i].text = text + else: + cut.supervisions[0].text = track.cut.text + # TODO: need to check the last speaker of last track and the first speaker of the current track + # if they are the same, we need to remove the the speaker token from the current track for segment-level + # Do not need to remove the speaker token for word-level + + return cut + + def apply_speaker_distribution(self, num_meetings: int, speaker_count_distribution) -> Dict[int, int]: + """ + Balance the speaker distribution for the simulated meetings. + Args: + num_meetings: The total number of simulated meetings. + speaker_count_distribution: The speaker count distribution for the simulated meetings. + For each number of speakers, calculate the number of meetings needed to balance the distribution. + """ + + total_spk = sum(speaker_count_distribution) + num_speakers2num_meetings = {} + for i_spk in range(self.max_num_speakers): + num_speakers2num_meetings[i_spk+1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) + + return num_speakers2num_meetings + + + @dill_enabled(True) + def simulate(self, + cuts: CutSet, + num_meetings: int = 10000, + seed: int = 0, + num_jobs: int = 1, + ) -> CutSet: + random.seed(seed) + + self.fit(cuts) + + num_speakers2num_meetings = self.apply_speaker_distribution(num_meetings, self.speaker_count_distribution) + logging.warn(f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}.") + num_speakers2num_meetings[1] = 0 # skip 1-speaker samples + logging.warn(f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.') + + # Step 0: Calculate the number of intra-session and inter-session concatentation samples + n_spks = [k for k, v in self.num_spk2cut_ids.items() if len(v) > 0] + valid_sim_n_spks = set([i+j for i in n_spks for j in n_spks]) # valid number of speakers for inter-session samples + n_spk2n_intra_mt, n_spk2n_inter_mt = {i+1:0 for i in range(self.max_num_speakers)}, {i+1:0 for i in range(self.max_num_speakers)} + for n_spk, n_mt in num_speakers2num_meetings.items(): + logging.warn(f"=="*16 + f"{n_spk}-speaker" + "=="*16) + if n_mt <= 0: + logging.warning(f"No intra-session concatentation samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") + continue + n_intra_mt = int(n_mt * self.intra_session_mix_prob[n_spk-1]) + n_inter_mt = n_mt - n_intra_mt + if n_spk in self.num_spk2sess_ids: + logging.warn(f"Will be genrating {n_intra_mt} {n_spk}-speaker intra-session concatentation samples.") + n_spk2n_intra_mt[n_spk] = n_intra_mt + else: + logging.warning(f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") + n_spk2n_intra_mt[n_spk] = 0 + n_inter_mt = n_mt + if n_spk in valid_sim_n_spks: + logging.warn(f"Will be genrating {n_inter_mt} {n_spk}-speaker inter-session concatentation samples.") + n_spk2n_inter_mt[n_spk] = n_inter_mt + else: + logging.warning(f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") + if n_spk2n_intra_mt[n_spk] != 0: + n_spk2n_intra_mt[n_spk] = n_mt + logging.warn(f"Will be genrating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatentation samples instead.") + else: + logging.warning(f"No samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") + logging.warn(f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatentation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatentation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""") + # Step 1: intra-session + num_intra_meetings = 0 + intra_mixtures = [] + logging.info(f"Simulating intra-session concatentation samples.") + for n_spk, n_mt in n_spk2n_intra_mt.items(): + if n_mt <= 0: + continue + + for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker intra-session mixtures", ncols=128): + intra_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=True)) + num_intra_meetings += n_mt + logging.info(f"Finished simulating intra-session concatentation samples. Total number of intra-session concatentation samples: {num_intra_meetings}") + + # Steo 2: inter-session + logging.info(f"Simulating inter-session concatentation samples.") + + num_inter_meetings = 0 + inter_mixtures = [] + for n_spk, n_mt in n_spk2n_inter_mt.items(): + if n_mt <= 0: + continue + + for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker inter-session mixtures", ncols=128): + inter_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=False)) + num_inter_meetings += n_mt + logging.info(f"Finished simulating inter-session concatentation samples. Total number of inter-session concatentation samples: {num_inter_meetings}") + + if num_inter_meetings + num_intra_meetings == 0: + logging.warning(f"No samples are generated. Probably the duration of the segments is not within the range of min {self.min_duration} and max {self.max_duration}, or the speaker count distribution is not correctly set.") + + return CutSet.from_cuts(intra_mixtures + inter_mixtures) + +class LibriSpeechMixSimulator(): + + def __init__( + self, + min_duration: float = 80.0, + max_duration: float = 100.0, + n_mix_speakers: List[int] = [1, 2, 3], + speaker_count_distribution: List[float] = [1, 1, 1], + ): + """ + :param min_duration: the minimum duration of the simulated meeting. [Default: 80.0] + :param max_duration: the maximum duration of the simulated meeting. [Default: 100.0] + """ + super().__init__() + self.min_duration = min_duration + self.max_duration = max_duration + self.n_mix_speakers = n_mix_speakers + self.speaker_count_distribution = speaker_count_distribution + assert len(speaker_count_distribution) == len(n_mix_speakers), f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {len(n_mix_speakers)}" + + def fit(self, cuts) -> CutSet: + pass + + def simulate(self, + cuts: CutSet, + num_meetings: int = 10000, + seed: int = 0, + num_jobs: int = 1, + ) -> CutSet: + random.seed(seed) + + cut_set = [] + for n_speakers, n_mt in zip(self.n_mix_speakers, self.speaker_count_distribution): + if n_mt <= 0: + continue + for i in tqdm(range(n_mt), desc=f"Simulating {n_speakers}-speaker mixtures", ncols=128): + cut_set.append(self._create_mixture(n_speakers=n_speakers)) + return CutSet.from_cuts(cut_set) + +class LibriSpeechMixGenerator(): + def __init__(self): + pass + + def generate(self, cuts): + cut_set = [] + for cut in tqdm(cuts): + offsets = cut.delays + durations = cut.durations + wavs = cut.wavs + texts = cut.texts + speakers = cut.speakers + + tracks = [] + for i, (offset, duration, wav, text, speaker) in enumerate(zip(offsets, durations, wavs, texts, speakers)): + wav_dur = soundfile.info(wav).duration + wav_samples = soundfile.info(wav).frames + custom = { + 'speaker': speaker, + 'text': text, + } + cut_1spk = MonoCut( + id=wav.split('/')[-1].replace('.wav', ''), + start=0, + duration=duration, + channel=0, + supervisions=[], + recording=Recording( + id=wav.split('/')[-1].replace('.wav', ''), + sources=[ + AudioSource( + type='file', + channels=[0], + source=wav + ) + ], + sampling_rate=16000, + num_samples=wav_samples, + duration=wav_dur + ), + custom=custom + ) + + tracks.append(MixTrack(cut=cut_1spk, type=type(cut_1spk), offset=offset)) + sup = SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0, + duration=offset+wav_dur, + text=cut.text, + ) + tracks[0].cut.supervisions.append(sup) + cut_multi_spk = MixedCut(id=cut.id, tracks=tracks) + + cut_set.append(cut_multi_spk) + + return CutSet.from_cuts(cut_set) \ No newline at end of file diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index b16ac50e4d56..144ae405de52 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -16,8 +16,7 @@ import json import os from itertools import combinations -from typing import Any, Callable, Dict, Iterable, List, Optional, Union - +from typing import Any, Dict, Iterable, List, Optional, Union import numpy as np import pandas as pd @@ -311,7 +310,7 @@ def __init__( class ASRAudioText(AudioText): """`AudioText` collector from asr structured json files.""" - def __init__(self, manifests_files: Union[str, List[str]], parse_func: Optional[Callable] = None, *args, **kwargs): + def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): """Parse lists of audio files, durations and transcripts texts. Args: @@ -334,9 +333,8 @@ def __init__(self, manifests_files: Union[str, List[str]], parse_func: Optional[ [], [], ) - speakers, orig_srs, token_labels, langs = [], [], [], [] - for item in manifest.item_iter(manifests_files, parse_func=parse_func): + for item in manifest.item_iter(manifests_files): ids.append(item['id']) audio_files.append(item['audio_file']) durations.append(item['duration']) @@ -1244,6 +1242,190 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: ) return item +class EndtoEndDiarizationLabel(_Collection): + """List of diarization audio-label correspondence with preprocessing.""" + + OUTPUT_TYPE = collections.namedtuple( + typename='DiarizationLabelEntity', + field_names='audio_file uniq_id duration rttm_file offset', + ) + + def __init__( + self, + audio_files: List[str], + uniq_ids: List[str], + durations: List[float], + rttm_files: List[str], + offsets: List[float], + max_number: Optional[int] = None, + do_sort_by_duration: bool = False, + index_by_file_id: bool = False, + ): + """ + Instantiates audio-label manifest with filters and preprocessing. + + This method initializes the EndtoEndDiarizationLabel object by processing the input data + and applying optional filters and sorting. + + Args: + audio_files (List[str]): List of audio file paths. + uniq_ids (List[str]): List of unique identifiers for each audio file. + durations (List[float]): List of float durations for each audio file. + rttm_files (List[str]): List of RTTM path strings (Groundtruth diarization annotation file). + offsets (List[float]): List of offsets or None for each audio file. + max_number (Optional[int]): Maximum number of samples to collect. Defaults to None. + do_sort_by_duration (bool): If True, sort samples list by duration. Defaults to False. + index_by_file_id (bool): If True, saves a mapping from filename base (ID) to index in data. Defaults to False. + + """ + if index_by_file_id: + self.mapping = {} + output_type = self.OUTPUT_TYPE + data, duration_filtered = [], 0.0 + + zipped_items = zip( + audio_files, uniq_ids, durations, rttm_files, offsets + ) + for ( + audio_file, + uniq_id, + duration, + rttm_file, + offset, + ) in zipped_items: + + if duration is None: + duration = 0 + + data.append( + output_type( + audio_file, + uniq_id, + duration, + rttm_file, + offset, + ) + ) + + if index_by_file_id: + if isinstance(audio_file, list): + if len(audio_file) == 0: + raise ValueError(f"Empty audio file list: {audio_file}") + audio_file_name = sorted(audio_file)[0] + else: + audio_file_name = audio_file + file_id, _ = os.path.splitext(os.path.basename(audio_file)) + self.mapping[file_id] = len(data) - 1 + + # Max number of entities filter. + if len(data) == max_number: + break + + if do_sort_by_duration: + if index_by_file_id: + logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + else: + data.sort(key=lambda entity: entity.duration) + + logging.info( + "Filtered duration for loading collection is %f.", duration_filtered, + ) + logging.info(f"Total {len(data)} session files loaded accounting to # {len(audio_files)} audio clips") + + super().__init__(data) + + +class EndtoEndDiarizationSpeechLabel(EndtoEndDiarizationLabel): + """`DiarizationLabel` diarization data sample collector from structured json files.""" + + def __init__( + self, + manifests_files: Union[str, List[str]], + round_digits=2, + *args, + **kwargs, + ): + """ + Parse lists of audio files, durations, RTTM (Diarization annotation) files. + Since diarization model infers only two speakers, speaker pairs are generated + from the total number of speakers in the session. + + Args: + manifest_filepath (str): + Path to input manifest json files. + round_digit (int): + Number of digits to be rounded. + *args: Args to pass to `SpeechLabel` constructor. + **kwargs: Kwargs to pass to `SpeechLabel` constructor. + """ + self.round_digits = round_digits + audio_files, uniq_ids, durations, rttm_files, offsets = ( + [], + [], + [], + [], + [], + ) + + for item in manifest.item_iter(manifests_files, parse_func=self.__parse_item_rttm): + # Training mode + rttm_labels = [] + with open(item['rttm_file'], 'r') as f: + for index, rttm_line in enumerate(f.readlines()): + rttm = rttm_line.strip().split() + start = round(float(rttm[3]), round_digits) + end = round(float(rttm[4]), round_digits) + round(float(rttm[3]), round_digits) + speaker = rttm[7] + rttm_labels.append('{} {} {}'.format(start, end, speaker)) + audio_files.append(item['audio_file']) + uniq_ids.append(item['uniq_id']) + durations.append(item['duration']) + rttm_files.append(item['rttm_file']) + offsets.append(item['offset']) + + super().__init__( + audio_files, + uniq_ids, + durations, + rttm_files, + offsets, + *args, + **kwargs, + ) + + def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: + """Parse each rttm file and save it to in Dict format""" + item = json.loads(line) + if 'audio_filename' in item: + item['audio_file'] = item.pop('audio_filename') + elif 'audio_filepath' in item: + item['audio_file'] = item.pop('audio_filepath') + else: + raise ValueError( + f"Manifest file has invalid json line " f"structure: {line} without proper audio file key." + ) + if isinstance(item['audio_file'], list): + item['audio_file'] = [os.path.expanduser(audio_file_path) for audio_file_path in item['audio_file']] + else: + item['audio_file'] = os.path.expanduser(item['audio_file']) + + if not isinstance(item['audio_file'], list): + if 'uniq_id' not in item: + item['uniq_id'] = os.path.splitext(os.path.basename(item['audio_file']))[0] + elif 'uniq_id' not in item: + raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper uniq_id key.") + + if 'duration' not in item: + raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper duration key.") + item = dict( + audio_file=item['audio_file'], + uniq_id=item['uniq_id'], + duration=item['duration'], + rttm_file=item['rttm_filepath'], + offset=item.get('offset', None), + ) + return item + class Audio(_Collection): """Prepare a list of all audio items, filtered by duration.""" From 29143251c67bde5ef38aafa321bbd11a840bc1f2 Mon Sep 17 00:00:00 2001 From: taejinp Date: Thu, 14 Nov 2024 00:01:08 -0800 Subject: [PATCH 02/47] Tested all unit-test files Signed-off-by: taejinp --- .../neural_diarizer/e2e_diarize_speech.py | 386 ++++++++++++++++++ nemo/collections/asr/metrics/der.py | 41 +- .../asr/metrics/multi_binary_acc.py | 51 ++- nemo/collections/asr/models/__init__.py | 7 +- .../asr/models/sortformer_diar_models.py | 3 +- .../asr/modules/sortformer_modules.py | 111 +++++ .../asr/parts/utils/speaker_utils.py | 163 +++++++- nemo/collections/asr/parts/utils/vad_utils.py | 136 +++--- 8 files changed, 780 insertions(+), 118 deletions(-) create mode 100644 examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py create mode 100644 nemo/collections/asr/modules/sortformer_modules.py diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py new file mode 100644 index 000000000000..98f2ee10e523 --- /dev/null +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -0,0 +1,386 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +python $BASEPATH/neural_diarizer/sortformer_diarization.py \ + model_path=/path/to/sortformer_model.nemo \ + batch_size=4 \ + dataset_manifest=/path/to/diarization_path_to_manifest.json + +""" +import pytorch_lightning as pl +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything + +from nemo.collections.asr.models import SortformerEncLabelModel +from nemo.core.config import hydra_runner +from nemo.collections.asr.metrics.der import score_labels +from hydra.core.config_store import ConfigStore + +import os +import yaml +from dataclasses import dataclass, is_dataclass +from typing import Optional, Union, List, Tuple, Dict + +from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map, timestamps_to_pyannote_object +from nemo.collections.asr.parts.utils.vad_utils import ts_vad_post_processing + +from tqdm import tqdm +import torch +import logging +import optuna +import tempfile + +seed_everything(42) +torch.backends.cudnn.deterministic = True + +@dataclass +class PostProcessingParams: + window_length_in_sec: float = 0.15 + shift_length_in_sec: float = 0.01 + smoothing: bool = False + overlap: float = 0.5 + onset: float = 0.5 + offset: float = 0.5 + pad_onset: float = 0.0 + pad_offset: float = 0.0 + min_duration_on: float = 0.0 + min_duration_off: float = 0.0 + filter_speech_first: bool = True + +@dataclass +class DiarizationConfig: + # Required configs + model_path: Optional[str] = None # Path to a .nemo file + pretrained_name: Optional[str] = None # Name of a pretrained model + audio_dir: Optional[str] = None # Path to a directory which contains audio files + dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest + + postprocessing_yaml: Optional[str] = None # Path to a yaml file for postprocessing configurations + no_der: bool = False + out_rttm_dir: Optional[str] = None + + # General configs + session_len_sec: float = -1 # End-to-end diarization session length in seconds + batch_size: int = 4 + num_workers: int = 0 + random_seed: Optional[int] = None # seed number going to be used in seed_everything() + bypass_postprocessing: bool = True # If True, postprocessing will be bypassed + + # Eval Settings: (0.25, False) should be default setting for sortformer eval. + collar: float = 0.25 # Collar in seconds for DER calculation + ignore_overlap: bool = False # If True, DER will be calculated only for non-overlapping segments + + # If `cuda` is a negative number, inference will be on CPU only. + cuda: Optional[int] = None + matmul_precision: str = "highest" # Literal["highest", "high", "medium"] + + # Optuna Config + launch_pp_optim: bool = False # If True, launch optimization process for postprocessing parameters + optuna_study_name: str = "optim_postprocessing" + optuna_temp_dir: str = "/tmp/optuna" + optuna_storage: str = f"sqlite:///{optuna_study_name}.db" + optuna_log_file: str = f"{optuna_study_name}.log" + optuna_n_trials: int = 100000 + +def load_postprocessing_from_yaml(postprocessing_yaml): + """ + Load postprocessing parameters from a YAML file. + + Args: + postprocessing_yaml (str): + Path to a YAML file for postprocessing configurations. + + Returns: + postprocessing_params (dataclass): + Postprocessing parameters loaded from the YAML file. + """ + # Add PostProcessingParams as a field + postprocessing_params = OmegaConf.structured(PostProcessingParams()) + if postprocessing_yaml is None: + logging.info(f"No postprocessing YAML file has been provided. Default postprocessing configurations will be applied.") + else: + # Load postprocessing params from the provided YAML file + with open(postprocessing_yaml, 'r') as file: + yaml_params = yaml.safe_load(file)['parameters'] + # Update the postprocessing_params with the loaded values + logging.info(f"Postprocessing YAML file '{postprocessing_yaml}' has been loaded.") + for key, value in yaml_params.items(): + if hasattr(postprocessing_params, key): + setattr(postprocessing_params, key, value) + return postprocessing_params + +def optuna_suggest_params(postprocessing_cfg: PostProcessingParams, trial: optuna.Trial) -> PostProcessingParams: + """ + Suggests hyperparameters for postprocessing using Optuna. + + Args: + postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. + trial (optuna.Trial): The Optuna trial object used to suggest hyperparameters. + + Returns: + PostProcessingParams: The updated postprocessing configuration with suggested hyperparameters. + """ + postprocessing_cfg.onset = trial.suggest_float("onset", 0.4, 0.8, step=0.01) + postprocessing_cfg.offset = trial.suggest_float("offset", 0.4, 0.9, step=0.01) + postprocessing_cfg.pad_onset = trial.suggest_float("pad_onset", 0.1, 0.5, step=0.01) + postprocessing_cfg.pad_offset = trial.suggest_float("pad_offset", 0.0, 0.2, step=0.01) + postprocessing_cfg.min_duration_on = trial.suggest_float("min_duration_on", 0.0, 0.75, step=0.01) + postprocessing_cfg.min_duration_off = trial.suggest_float("min_duration_off", 0.0, 0.75, step=0.01) + return postprocessing_cfg + +def get_tensor_path(cfg: DiarizationConfig) -> str: + """ + Constructs the file path for saving or loading prediction tensors based on the configuration. + + Args: + cfg (DiarizationConfig): The configuration object containing model and dataset details. + + Returns: + str: The constructed file path for the prediction tensor. + """ + tensor_filename = os.path.basename(cfg.dataset_manifest).replace("manifest.", "").replace(".json", "") + model_base_path = os.path.dirname(cfg.model_path) + model_id = os.path.basename(cfg.model_path).replace(".ckpt", "").replace(".nemo", "") + bpath = f"{model_base_path}/pred_tensors" + if not os.path.exists(bpath): + os.makedirs(bpath) + tensor_path = f"{bpath}/__{model_id}__{tensor_filename}.pt" + return tensor_path + +def diarization_objective( + trial: optuna.Trial, + postprocessing_cfg: PostProcessingParams, + temp_out_dir: str, + infer_audio_rttm_dict: Dict[str, Dict[str, str]], + diar_model_preds_total_list: List[torch.Tensor], + collar: float = 0.25, + ignore_overlap: bool = False +) -> float: + """ + Objective function for Optuna hyperparameter optimization in speaker diarization. + + This function evaluates the diarization performance using a set of postprocessing parameters + suggested by Optuna. It converts prediction matrices to time-stamp segments, scores the + diarization results, and returns the Diarization Error Rate (DER) as the optimization metric. + + Args: + trial (optuna.Trial): The Optuna trial object used to suggest hyperparameters. + postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. + temp_out_dir (str): Temporary directory for storing intermediate outputs. + infer_audio_rttm_dict (Dict[str, Dict[str, str]]): Dictionary containing audio file paths, + offsets, durations, and RTTM file paths. + diar_model_preds_total_list (List[torch.Tensor]): List of prediction matrices containing + sigmoid values for each speaker. Dimension: [(1, frames, num_speakers), ..., (1, frames, num_speakers)] + collar (float, optional): Collar in seconds for DER calculation. Defaults to 0.25. + ignore_overlap (bool, optional): If True, DER will be calculated only for non-overlapping segments. + Defaults to False. + + Returns: + float: The Diarization Error Rate (DER) for the given set of postprocessing parameters. + """ + with tempfile.TemporaryDirectory(dir=temp_out_dir, prefix="Diar_PostProcessing_") as local_temp_out_dir: + if trial is not None: + postprocessing_cfg = optuna_suggest_params(postprocessing_cfg, trial) + all_hyps, all_refs, all_uems = convert_pred_mat_to_segments(audio_rttm_map_dict=infer_audio_rttm_dict, + postprocessing_cfg=postprocessing_cfg, + batch_preds_list=diar_model_preds_total_list, + unit_10ms_frame_count=8, + bypass_postprocessing=False) + metric, mapping_dict, itemized_errors = score_labels(AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=collar, + ignore_overlap=ignore_overlap + ) + der = abs(metric) + return der + +def run_optuna_hyperparam_search( + cfg: DiarizationConfig, # type: DiarizationConfig + postprocessing_cfg: PostProcessingParams, + infer_audio_rttm_dict: Dict[str, Dict[str, str]], + preds_list: List[torch.Tensor], + temp_out_dir: str, + ): + worker_function = lambda trial: diarization_objective( + trial=trial, + postprocessing_cfg=postprocessing_cfg, + temp_out_dir=temp_out_dir, + infer_audio_rttm_dict=infer_audio_rttm_dict, + diar_model_preds_total_list=preds_list, + collar=cfg.collar, + ) + study = optuna.create_study( + direction="minimize", + study_name=cfg.optuna_study_name, + storage=cfg.optuna_storage, + load_if_exists=True + ) + logger = logging.getLogger() + logger.setLevel(logging.INFO) # Setup the root logger. + if cfg.optuna_log_file is not None: + logger.addHandler(logging.FileHandler(cfg.optuna_log_file, mode="a")) + logger.addHandler(logging.StreamHandler()) + optuna.logging.enable_propagation() # Propagate logs to the root logger. + study.optimize(worker_function, n_trials=cfg.optuna_n_trials) + + +def convert_pred_mat_to_segments( + audio_rttm_map_dict: Dict[str, Dict[str, str]], + postprocessing_cfg, + batch_preds_list: List[torch.Tensor], + unit_10ms_frame_count:int = 8, + bypass_postprocessing: bool = False, + out_rttm_dir: str | None = None, + ): + """ + Convert prediction matrix to time-stamp segments. + + Args: + audio_rttm_map_dict (dict): dictionary of audio file path, offset, duration and RTTM filepath. + batch_preds_list (List[torch.Tensor]): list of prediction matrices containing sigmoid values for each speaker. + Dimension: [(1, frames, num_speakers), ..., (1, frames, num_speakers)] + unit_10ms_frame_count (int, optional): number of 10ms segments in a frame. Defaults to 8. + bypass_postprocessing (bool, optional): if True, postprocessing will be bypassed. Defaults to False. + + Returns: + all_hypothesis (list): list of pyannote objects for each audio file. + all_reference (list): list of pyannote objects for each audio file. + all_uems (list): list of pyannote objects for each audio file. + """ + batch_pred_ts_segs, all_hypothesis, all_reference, all_uems = [], [], [], [] + cfg_vad_params = OmegaConf.structured(postprocessing_cfg) + for sample_idx, (uniq_id, audio_rttm_values) in tqdm(enumerate(audio_rttm_map_dict.items()), total=len(audio_rttm_map_dict), desc="Running post-processing"): + spk_ts = [] + offset, duration = audio_rttm_values['offset'], audio_rttm_values['duration'] + speaker_assign_mat = batch_preds_list[sample_idx].squeeze(dim=0) + speaker_timestamps = [[] for _ in range(speaker_assign_mat.shape[-1])] + for spk_id in range(speaker_assign_mat.shape[-1]): + ts_mat = ts_vad_post_processing(speaker_assign_mat[:, spk_id], + cfg_vad_params=cfg_vad_params, + unit_10ms_frame_count=unit_10ms_frame_count, + bypass_postprocessing=bypass_postprocessing) + ts_mat = ts_mat + offset + ts_mat = torch.clamp(ts_mat, min=offset, max=(offset + duration)) + ts_seg_list = ts_mat.tolist() + speaker_timestamps[spk_id].extend(ts_seg_list) + spk_ts.append(ts_seg_list) + all_hypothesis, all_reference, all_uems = timestamps_to_pyannote_object(speaker_timestamps, + uniq_id, + audio_rttm_values, + all_hypothesis, + all_reference, + all_uems, + out_rttm_dir, + ) + batch_pred_ts_segs.append(spk_ts) + return all_hypothesis, all_reference, all_uems + +@hydra_runner(config_name="DiarizationConfig", schema=DiarizationConfig) +def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: + for key in cfg: + cfg[key] = None if cfg[key] == 'None' else cfg[key] + + if is_dataclass(cfg): + cfg = OmegaConf.structured(cfg) + + if cfg.random_seed: + pl.seed_everything(cfg.random_seed) + + if cfg.model_path is None and cfg.pretrained_name is None: + raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!") + if cfg.audio_dir is None and cfg.dataset_manifest is None: + raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!") + + # setup GPU + torch.set_float32_matmul_precision(cfg.matmul_precision) + if cfg.cuda is None: + if torch.cuda.is_available(): + device = [0] # use 0th CUDA device + accelerator = 'gpu' + map_location = torch.device('cuda:0') + else: + device = 1 + accelerator = 'cpu' + map_location = torch.device('cpu') + else: + device = [cfg.cuda] + accelerator = 'gpu' + map_location = torch.device(f'cuda:{cfg.cuda}') + + if cfg.model_path.endswith(".ckpt"): + diar_model = SortformerEncLabelModel.load_from_checkpoint(checkpoint_path=cfg.model_path, map_location=map_location, strict=False) + elif cfg.model_path.endswith(".nemo"): + diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.model_path, map_location=map_location) + else: + raise ValueError("cfg.model_path must end with.ckpt or.nemo!") + + diar_model._cfg.test_ds.session_len_sec = cfg.session_len_sec + trainer = pl.Trainer(devices=device, accelerator=accelerator) + diar_model.set_trainer(trainer) + + diar_model = diar_model.eval() + diar_model._cfg.test_ds.manifest_filepath = cfg.dataset_manifest + infer_audio_rttm_dict = audio_rttm_map(cfg.dataset_manifest) + diar_model._cfg.test_ds.batch_size = cfg.batch_size + + # Model setup for inference + diar_model._cfg.test_ds.num_workers = cfg.num_workers + diar_model.setup_test_data(test_data_config=diar_model._cfg.test_ds) + + postprocessing_cfg = load_postprocessing_from_yaml(cfg.postprocessing_yaml) + tensor_path = get_tensor_path(cfg) + + if os.path.exists(tensor_path): + logging.info(f"A saved prediction tensor has been found. Loading the saved prediction tensors from {tensor_path}...") + diar_model_preds_total_list = torch.load(tensor_path) + else: + logging.info(f"No saved prediction tensors found. Running inference on the dataset...") + diar_model.test_batch() + diar_model_preds_total_list = diar_model.preds_total_list + torch.save(diar_model.preds_total_list, tensor_path) + + if cfg.launch_pp_optim: + # Launch a hyperparameter optimization process if launch_pp_optim is True + run_optuna_hyperparam_search(cfg=cfg, + postprocessing_cfg=postprocessing_cfg, + infer_audio_rttm_dict=infer_audio_rttm_dict, + preds_list=diar_model_preds_total_list, + temp_out_dir=cfg.optuna_temp_dir) + + # Evaluation + if not cfg.no_der: + if cfg.out_rttm_dir is not None and not os.path.exists(cfg.out_rttm_dir): + os.mkdir(cfg.out_rttm_dir) + all_hyps, all_refs, all_uems = convert_pred_mat_to_segments(infer_audio_rttm_dict, + postprocessing_cfg=postprocessing_cfg, + batch_preds_list=diar_model_preds_total_list, + unit_10ms_frame_count=8, + bypass_postprocessing=cfg.bypass_postprocessing, + out_rttm_dir=cfg.out_rttm_dir + ) + logging.info(f"Evaluating the model on the {len(diar_model_preds_total_list)} audio segments...") + metric, mapping_dict, itemized_errors = score_labels(AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=cfg.collar, + ignore_overlap=cfg.ignore_overlap + ) + logging.info(f"PostProcessingParams: {postprocessing_cfg}") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/nemo/collections/asr/metrics/der.py b/nemo/collections/asr/metrics/der.py index fc5cded970d0..16f62bbe9e4c 100644 --- a/nemo/collections/asr/metrics/der.py +++ b/nemo/collections/asr/metrics/der.py @@ -130,7 +130,13 @@ def uem_timeline_from_file(uem_file, uniq_name=''): def score_labels( - AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ignore_overlap=True, verbose: bool = True + AUDIO_RTTM_MAP, + all_reference, + all_hypothesis, + all_uem: List[List[float]]=None, + collar:float=0.25, + ignore_overlap: bool=True, + verbose: bool = True ) -> Optional[Tuple[DiarizationErrorRate, Dict]]: """ Calculate DER, CER, FA and MISS rate from hypotheses and references. Hypothesis results are @@ -157,26 +163,41 @@ def score_labels( if len(all_reference) == len(all_hypothesis): metric = DiarizationErrorRate(collar=2 * collar, skip_overlap=ignore_overlap) - mapping_dict = {} - for (reference, hypothesis) in zip(all_reference, all_hypothesis): + mapping_dict, correct_spk_count = {}, 0 + for idx, (reference, hypothesis) in enumerate(zip(all_reference, all_hypothesis)): ref_key, ref_labels = reference _, hyp_labels = hypothesis - uem = AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) - if uem is not None: - uem = uem_timeline_from_file(uem_file=uem, uniq_name=ref_key) - metric(ref_labels, hyp_labels, uem=uem, detailed=True) + if len(ref_labels.labels()) == len(hyp_labels.labels()): + correct_spk_count += 1 + if verbose and len(ref_labels.labels()) != len(hyp_labels.labels()): + logging.info(f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}") + uem_obj = None + if all_uem is not None: + metric(ref_labels, hyp_labels, uem=all_uem[idx], detailed=True) + elif AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) is not None: + uem_file = AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) + uem_obj = uem_timeline_from_file(uem_file=uem_file, uniq_name=ref_key) + metric(ref_labels, hyp_labels, uem=uem_obj, detailed=True) + else: + metric(ref_labels, hyp_labels, detailed=True) mapping_dict[ref_key] = metric.optimal_mapping(ref_labels, hyp_labels) + spk_count_acc = correct_spk_count / len(all_reference) DER = abs(metric) + if metric['total'] == 0: + raise ValueError(f"Total evaluation time is 0. Abort.") CER = metric['confusion'] / metric['total'] FA = metric['false alarm'] / metric['total'] MISS = metric['missed detection'] / metric['total'] + itemized_errors = (DER, CER, FA, MISS) + if verbose: + # logging.info(f"\n{metric.report()}") + pass logging.info( - "Cumulative Results for collar {} sec and ignore_overlap {}: \n FA: {:.4f}\t MISS {:.4f}\t \ - Diarization ER: {:.4f}\t, Confusion ER:{:.4f}".format( - collar, ignore_overlap, FA, MISS, DER, CER + "Cumulative Results for collar {} sec and ignore_overlap {}: \n| FA: {:.4f} | MISS: {:.4f} | CER: {:.4f} | DER: {:.4f} | Spk. Count Acc. {:.4f}\n".format( + collar, ignore_overlap, FA, MISS, CER, DER, spk_count_acc ) ) diff --git a/nemo/collections/asr/metrics/multi_binary_acc.py b/nemo/collections/asr/metrics/multi_binary_acc.py index 8cc21c53ad82..72781143208b 100644 --- a/nemo/collections/asr/metrics/multi_binary_acc.py +++ b/nemo/collections/asr/metrics/multi_binary_acc.py @@ -68,18 +68,19 @@ def on_validation_epoch_end(self): f1_score (torch.Tensor): F1 score calculated from the predicted value and binarized target values. """ - full_state_update = False def __init__(self, dist_sync_on_step=False): super().__init__(dist_sync_on_step=dist_sync_on_step) - self.total_correct_counts = 0 - self.total_sample_counts = 0 - self.true_positive_count = 0 - self.false_positive_count = 0 - self.false_negative_count = 0 - - def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor) -> torch.Tensor: + self.add_state("total_correct_counts", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("total_sample_counts", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("true_positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("false_positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("false_negative_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.eps = 1e-6 + + def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor, cumulative=False) -> torch.Tensor: with torch.no_grad(): preds_list = [preds[k, : signal_lengths[k], :] for k in range(preds.shape[0])] targets_list = [targets[k, : signal_lengths[k], :] for k in range(targets.shape[0])] @@ -91,22 +92,30 @@ def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: tor self.positive = self.preds.round().bool() == 1 self.negative = self.preds.round().bool() == 0 - self.positive_count = torch.sum(self.preds.round().bool() == True) - self.true_positive_count += torch.sum(torch.logical_and(self.true, self.positive)) - self.false_positive_count += torch.sum(torch.logical_and(self.false, self.positive)) - self.false_negative_count += torch.sum(torch.logical_and(self.false, self.negative)) - - self.total_correct_counts += torch.sum(self.preds.round().bool() == self.targets.round().bool()) - self.total_sample_counts += torch.prod(torch.tensor(self.targets.shape)) + if cumulative: + self.positive_count += torch.sum(self.preds.round().bool() == True) + self.true_positive_count += torch.sum(torch.logical_and(self.true, self.positive)) + self.false_positive_count += torch.sum(torch.logical_and(self.false, self.positive)) + self.false_negative_count += torch.sum(torch.logical_and(self.false, self.negative)) + self.total_correct_counts += torch.sum(self.preds.round().bool() == self.targets.round().bool()) + self.total_sample_counts += torch.prod(torch.tensor(self.targets.shape)) + else: + self.positive_count = torch.sum(self.preds.round().bool() == True) + self.true_positive_count = torch.sum(torch.logical_and(self.true, self.positive)) + self.false_positive_count = torch.sum(torch.logical_and(self.false, self.positive)) + self.false_negative_count = torch.sum(torch.logical_and(self.false, self.negative)) + self.total_correct_counts = torch.sum(self.preds.round().bool() == self.targets.round().bool()) + self.total_sample_counts = torch.prod(torch.tensor(self.targets.shape)) def compute(self): """ Compute F1 score from the accumulated values. Return -1 if the F1 score is NaN. """ - self.precision = self.true_positive_count / (self.true_positive_count + self.false_positive_count) - self.recall = self.true_positive_count / (self.true_positive_count + self.false_negative_count) - self.f1_score = 2 * self.precision * self.recall / (self.precision + self.recall) - if torch.isnan(self.f1_score): + precision = self.true_positive_count / (self.true_positive_count + self.false_positive_count + self.eps) + recall = self.true_positive_count / (self.true_positive_count + self.false_negative_count + self.eps) + f1_score = (2 * precision * recall / (precision + recall + self.eps)).detach().clone() + + if torch.isnan(f1_score): logging.warn("self.f1_score contains NaN value. Returning -1 instead of NaN value.") - self.f1_score = -1 - return self.f1_score + f1_score = -1 + return f1_score.float(), precision.float(), recall.float() diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index e4a1342b9c36..31194d8849f0 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -19,6 +19,7 @@ EncDecClassificationModel, EncDecFrameClassificationModel, ) +from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel @@ -35,9 +36,5 @@ from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel -from nemo.collections.asr.models.ssl_models import ( - EncDecDenoiseMaskedTokenPredModel, - EncDecMaskedTokenPredModel, - SpeechEncDecSelfSupervisedModel, -) +from nemo.collections.asr.models.ssl_models import SpeechEncDecSelfSupervisedModel from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index c389f0eb627f..50cdf6214d5b 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -555,7 +555,8 @@ def test_batch(self,): self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0])) torch.cuda.empty_cache() self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens) - + # except: + # import ipdb; ipdb.set_trace() logging.info(f"Batch F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_list))}") logging.info(f"Batch Precision MEAN: {torch.mean(torch.tensor(self.batch_precision_list))}") logging.info(f"Batch Recall MEAN: {torch.mean(torch.tensor(self.batch_recall_list))}") diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py new file mode 100644 index 000000000000..823cf98590e7 --- /dev/null +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from nemo.core.classes.exportable import Exportable +from nemo.core.classes.module import NeuralModule +from nemo.core.neural_types import EncodedRepresentation, LengthsType, NeuralType, SpectrogramType +from nemo.core.neural_types.elements import ProbsType + +__all__ = ['SortformerModules'] + + +class SortformerModules(NeuralModule, Exportable): + """ + Multi-scale Diarization Decoder (MSDD) for overlap-aware diarization and improved diarization accuracy from clustering diarizer. + Based on the paper: Taejin Park et. al, "Multi-scale Speaker Diarization with Dynamic Scale Weighting", Interspeech 2022. + Arxiv version: https://arxiv.org/pdf/2203.15974.pdf + + Args: + num_spks (int): + Max number of speakers that are processed by the model. In `MSDD_module`, `num_spks=2` for pairwise inference. + hidden_size (int): + Number of hidden units in sequence models and intermediate layers. + num_lstm_layers (int): + Number of the stacked LSTM layers. + dropout_rate (float): + Dropout rate for linear layers, CNN and LSTM. + tf_d_model (int): + Dimension of the embedding vectors. + scale_n (int): + Number of scales in multi-scale system. + clamp_max (float): + Maximum value for limiting the scale weight values. + conv_repeat (int): + Number of CNN layers after the first CNN layer. + weighting_scheme (str): + Name of the methods for estimating the scale weights. + context_vector_type (str): + If 'cos_sim', cosine similarity values are used for the input of the sequence models. + If 'elem_prod', element-wise product values are used for the input of the sequence models. + """ + def init_weights(self, m): + if type(m) == nn.Linear: + torch.nn.init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + def __init__( + self, + num_spks: int = 4, + hidden_size: int = 192, + dropout_rate: float = 0.5, + fc_d_model: int = 512, + tf_d_model: int = 192, + ): + super().__init__() + self.fc_d_model = fc_d_model + self.tf_d_model = tf_d_model + self.hidden_size = tf_d_model + self.unit_n_spks: int = num_spks + self.hidden_to_spks = nn.Linear(2 * self.hidden_size, self.unit_n_spks) + self.first_hidden_to_hidden = nn.Linear(self.hidden_size, self.hidden_size) + self.single_hidden_to_spks = nn.Linear(self.hidden_size, self.unit_n_spks) + self.dropout = nn.Dropout(dropout_rate) + self.encoder_proj = nn.Linear(self.fc_d_model, self.tf_d_model) + + def length_to_mask(self, context_embs): + """ + Convert length values to encoder mask input tensor. + + Args: + lengths (torch.Tensor): tensor containing lengths of sequences + max_len (int): maximum sequence length + + Returns: + mask (torch.Tensor): tensor of shape (batch_size, max_len) containing 0's + in the padded region and 1's elsewhere + """ + lengths = torch.tensor([context_embs.shape[1]] * context_embs.shape[0]) + batch_size = context_embs.shape[0] + max_len=context_embs.shape[1] + # create a tensor with the shape (batch_size, 1) filled with ones + row_vector = torch.arange(max_len).unsqueeze(0).expand(batch_size, -1).to(lengths.device) + # create a tensor with the shape (batch_size, max_len) filled with lengths + length_matrix = lengths.unsqueeze(1).expand(-1, max_len).to(lengths.device) + # create a mask by comparing the row vector and length matrix + mask = row_vector < length_matrix + return mask.float().to(context_embs.device) + + def forward_speaker_sigmoids(self, hidden_out): + hidden_out = self.dropout(F.relu(hidden_out)) + hidden_out = self.first_hidden_to_hidden(hidden_out) + hidden_out = self.dropout(F.relu(hidden_out)) + spk_preds = self.single_hidden_to_spks(hidden_out) + preds = nn.Sigmoid()(spk_preds) + return preds diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 5d3a0bf4274e..80b3e1f918b8 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -21,10 +21,11 @@ from typing import Dict, List, Tuple, Union import numpy as np -import omegaconf +from omegaconf import OmegaConf +from omegaconf.listconfig import ListConfig import soundfile as sf import torch -from pyannote.core import Annotation, Segment +from pyannote.core import Annotation, Segment, Timeline from tqdm import tqdm from nemo.collections.asr.data.audio_to_label import repeat_signal @@ -108,7 +109,10 @@ def audio_rttm_map(manifest, attach_dur=False): if attach_dur: uniqname = get_uniq_id_with_dur(meta) else: - uniqname = get_uniqname_from_filepath(filepath=meta['audio_filepath']) + if "uniq_id" in dic.keys(): + uniqname = dic['uniq_id'] + else: + uniqname = get_uniqname_from_filepath(filepath=meta['audio_filepath']) if uniqname not in AUDIO_RTTM_MAP: AUDIO_RTTM_MAP[uniqname] = meta @@ -144,7 +148,7 @@ def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_ """ check_float_config = [isinstance(var, float) for var in (window_lengths_in_sec, shift_lengths_in_sec)] check_list_config = [ - isinstance(var, (omegaconf.listconfig.ListConfig, list, tuple)) + isinstance(var, (ListConfig, list, tuple)) for var in (window_lengths_in_sec, shift_lengths_in_sec, multiscale_weights) ] if all(check_list_config) or all(check_float_config): @@ -928,28 +932,61 @@ def segments_manifest_to_subsegments_manifest( return subsegments_manifest_file -def get_subsegments(offset: float, window: float, shift: float, duration: float) -> List[List[float]]: - """ - Return subsegments from a segment of audio file +def get_subsegments( + offset: float, + window: float, + shift: float, + duration: float, + min_subsegment_duration: float = 0.01, + decimals: int = 2, + use_asr_style_frame_count: bool = False, + sample_rate: int = 16000, + feat_per_sec: int = 100, + ) -> List[List[float]]: + """ + Return subsegments from a segment of audio file. + + Example: + (window, shift) = 1.5, 0.75 + Segment: [12.05, 14.45] + Subsegments: [[12.05, 13.55], [12.8, 14.3], [13.55, 14.45], [14.3, 14.45]] + Args: - offset (float): start time of audio segment - window (float): window length for segments to subsegments length - shift (float): hop length for subsegments shift - duration (float): duration of segment + offset (float): Start time of audio segment + window (float): Window length for segments to subsegments length + shift (float): Hop length for subsegments shift + duration (float): Duration of segment + min_subsegment_duration (float): Exclude subsegments smaller than this duration value + decimals (int): Number of decimal places to round to + use_asr_style_frame_count (bool): If True, use asr style frame count to generate subsegments. + For example, if duration is 10 secs and frame_shift is 0.08 secs, + it results in (10/0.08)+1 = 125 + 1 frames. + Returns: subsegments (List[tuple[float, float]]): subsegments generated for the segments as list of tuple of start and duration of each subsegment """ - subsegments: List[List[float]] = [] + subsegments: List[List[float]] = [] start = offset slice_end = start + duration - base = math.ceil((duration - window) / shift) - slices = 1 if base < 0 else base + 1 - for slice_id in range(slices): - end = start + window - if end > slice_end: - end = slice_end - subsegments.append([start, end - start]) - start = offset + (slice_id + 1) * shift + if min_subsegment_duration <= duration < shift: + slices = 1 + elif use_asr_style_frame_count is True: + num_feat_frames = np.ceil((1+duration*sample_rate)/int(sample_rate/feat_per_sec)).astype(int) + slices = np.ceil(num_feat_frames/int(feat_per_sec*shift)).astype(int) + slice_end = start + shift * slices + else: + slices = np.ceil(1+ (duration-window)/shift).astype(int) + if slices == 1: + if min(duration, window) >= min_subsegment_duration: + subsegments.append([start, min(duration, window)]) + elif slices > 0: # What if slcies = 0 ? + start_col = torch.arange(offset, slice_end, shift)[:slices] + dur_col = window * torch.ones(slices) + dur_col = torch.min(slice_end*torch.ones_like(start_col)- start_col, window * torch.ones_like(start_col)) + dur_col = torch.round(dur_col, decimals=decimals) + valid_mask = dur_col >= min_subsegment_duration + valid_subsegments = torch.stack([start_col[valid_mask], dur_col[valid_mask]], dim=1) + subsegments = valid_subsegments.tolist() return subsegments @@ -1000,6 +1037,15 @@ def tensor_to_list(range_tensor: torch.Tensor) -> List[List[float]]: return [[float(range_tensor[k][0]), float(range_tensor[k][1])] for k in range(range_tensor.shape[0])] +def generate_diarization_output_lines(speaker_timestamps, model_spk_num): + speaker_lines_total = [] + for spk_idx in range(model_spk_num): + ts_invervals = speaker_timestamps[spk_idx] + merged_ts_intervals = merge_float_intervals(ts_invervals) + for ts_interval in merged_ts_intervals: + speaker_lines_total.extend([f"{ts_interval[0]:.3f} {ts_interval[1]:.3f} speaker_{int(spk_idx)}"]) + return speaker_lines_total + def get_speech_labels_for_update( frame_start: float, buffer_end: float, @@ -1580,6 +1626,83 @@ def make_rttm_with_overlap( return all_reference, all_hypothesis +def timestamps_to_pyannote_object(speaker_timestamps: List[Tuple[float, float]], + uniq_id: str, + audio_rttm_values: Dict[str, str], + all_hypothesis: List[Tuple[str, Timeline]], + all_reference: List[Tuple[str, Timeline]], + all_uems: List[Tuple[str, Timeline]], + out_rttm_dir: str | None + ): + """ + Convert speaker timestamps to pyannote.core.Timeline object. + + Args: + speaker_timestamps (List[Tuple[float, float]]): + Timestamps of each speaker: start time and end time of each speaker. + uniq_id (str): + Unique ID of each speaker. + audio_rttm_values (Dict[str, str]): + Dictionary of manifest values. + all_hypothesis (List[Tuple[str, pyannote.core.Timeline]]): + List of hypothesis in pyannote.core.Timeline object. + all_reference (List[Tuple[str, pyannote.core.Timeline]]): + List of reference in pyannote.core.Timeline object. + all_uems (List[Tuple[str, pyannote.core.Timeline]]): + List of uems in pyannote.core.Timeline object. + out_rttm_dir (str | None): + Directory to save RTTMs + + Returns: + all_hypothesis (List[Tuple[str, pyannote.core.Timeline]]): + List of hypothesis in pyannote.core.Timeline object with an added Timeline object. + all_reference (List[Tuple[str, pyannote.core.Timeline]]): + List of reference in pyannote.core.Timeline object with an added Timeline object. + all_uems (List[Tuple[str, pyannote.core.Timeline]]): + List of uems in pyannote.core.Timeline object with an added Timeline object. + """ + offset, dur = float(audio_rttm_values.get('offset', None)), float(audio_rttm_values.get('duration', None)) + hyp_labels = generate_diarization_output_lines(speaker_timestamps=speaker_timestamps, model_spk_num=len(speaker_timestamps)) + hypothesis = labels_to_pyannote_object(hyp_labels, uniq_name=uniq_id) + if out_rttm_dir is not None and os.path.exists(out_rttm_dir): + with open(f'{out_rttm_dir}/{uniq_id}.rttm','w') as f: + hypothesis.write_rttm(f) + all_hypothesis.append([uniq_id, hypothesis]) + rttm_file = audio_rttm_values.get('rttm_filepath', None) + if rttm_file is not None and os.path.exists(rttm_file): + uem_lines = [[offset, dur+offset]] + org_ref_labels = rttm_to_labels(rttm_file) + ref_labels = org_ref_labels + reference = labels_to_pyannote_object(ref_labels, uniq_name=uniq_id) + uem_obj = get_uem_object(uem_lines, uniq_id=uniq_id) + all_uems.append(uem_obj) + all_reference.append([uniq_id, reference]) + return all_hypothesis, all_reference, all_uems + +def get_uem_object(uem_lines: List[List[float]], uniq_id: str): + """ + Generate pyannote timeline segments for uem file. + + file format + UNIQ_SPEAKER_ID CHANNEL START_TIME END_TIME + + Args: + uem_lines (list): list of session ID and start, end times. + Example: + [[0.0, 30.41], [60.04, 165.83]] + uniq_id (str): Unique session ID. + + Returns: + timeline (pyannote.core.Timeline): pyannote timeline object. + """ + timeline = Timeline(uri=uniq_id) + for uem_stt_end in uem_lines: + start_time, end_time = uem_stt_end + timeline.add(Segment(float(start_time), float(end_time))) + return timeline + + + def embedding_normalize(embs, use_std=False, eps=1e-10): """ Mean and l2 length normalize the input speaker embeddings diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index aea04b8cafcf..192c42375dca 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -23,31 +23,23 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Union +import IPython.display as ipd import librosa import matplotlib.pyplot as plt import numpy as np import pandas as pd import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pyannote.core import Annotation, Segment from pyannote.metrics import detection from sklearn.metrics import roc_auc_score from sklearn.model_selection import ParameterGrid from tqdm import tqdm - +from nemo.collections.asr.parts.utils.speaker_utils import timestamps_to_pyannote_object from nemo.collections.asr.models import EncDecClassificationModel, EncDecFrameClassificationModel from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging -HAVE_IPYTHON = False -try: - import IPython.display as ipd - - HAVE_IPYTHON = True -except: - HAVE_IPYTHON = False - - """ This file contains all the utility functions required for voice activity detection. """ @@ -74,8 +66,7 @@ def prepare_manifest(config: dict) -> str: input_list = config['input'] else: raise ValueError( - "The input for manifest preparation would either be a string of the filepath to \ - manifest or a list of {'audio_filepath': i, 'offset': 0, 'duration': null} " + "The input for manifest preparation would either be a string of the filepath to manifest or a list of {'audio_filepath': i, 'offset': 0, 'duration': null} " ) args_func = { @@ -204,8 +195,7 @@ def write_vad_infer_manifest(file: dict, args_func: dict) -> list: def get_vad_stream_status(data: list) -> list: """ - Generate a list of status for each snippet in manifest. - A snippet should be in single, start, next or end status. + Generate a list of status for each snippet in manifest. A snippet should be in single, start, next or end status. Used for concatenating to full audio file. Args: data (list): list of filepath of audio snippet @@ -256,8 +246,7 @@ def generate_overlap_vad_seq( out_dir: str = None, ) -> str: """ - Generate predictions with overlapping input windows/segments. - Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows. + Generate predictions with overlapping input windows/segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows. Two common smoothing filters are supported: majority vote (median) and average (mean). This function uses multiprocessing to speed up. Args: @@ -321,8 +310,7 @@ def generate_overlap_vad_seq_per_tensor( frame: torch.Tensor, per_args: Dict[str, float], smoothing_method: str ) -> torch.Tensor: """ - Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) - to generate prediction with overlapping input window/segments + Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate prediction with overlapping input window/segments See description in generate_overlap_vad_seq. Use this for single instance pipeline. """ @@ -484,8 +472,7 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te Binarize predictions to speech and non-speech Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", \ - InterSpeech 2015. + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: @@ -498,8 +485,7 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te frame_length_in_sec (float): length of frame. Returns: - speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) \ - format. + speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format. """ frame_length_in_sec = per_args.get('frame_length_in_sec', 0.01) @@ -550,8 +536,7 @@ def remove_segments(original_segments: torch.Tensor, to_be_removed_segments: tor """ Remove speech segments list in to_be_removed_segments from original_segments. For example, - remove torch.Tensor([[start2, end2],[start4, end4]]) from torch.Tensor([[start1, end1],[start2, end2],\ - [start3, end3], [start4, end4]]), + remove torch.Tensor([[start2, end2],[start4, end4]]) from torch.Tensor([[start1, end1],[start2, end2],[start3, end3], [start4, end4]]), -> torch.Tensor([[start1, end1],[start3, end3]]) """ @@ -577,25 +562,21 @@ def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torc Filter out short non_speech and speech segments. Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", \ - InterSpeech 2015. + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: - speech_segments (torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], \ - [start2, end2]]) format. + speech_segments (torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format. per_args: min_duration_on (float): threshold for small non_speech deletion min_duration_off (float): threshold for short speech segment deletion - filter_speech_first (float): Whether to perform short speech segment deletion first. \ - Use 1.0 to represent True. + filter_speech_first (float): Whether to perform short speech segment deletion first. Use 1.0 to represent True. Returns: - speech_segments(torch.Tensor): A tensor of filtered speech segment in \ - torch.Tensor([[start1, end1], [start2, end2]]) format. + speech_segments(torch.Tensor): A tensor of filtered speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format. """ if speech_segments.shape == torch.Size([0]): return speech_segments - + min_duration_on = per_args.get('min_duration_on', 0.0) min_duration_off = per_args.get('min_duration_off', 0.0) filter_speech_first = per_args.get('filter_speech_first', 1.0) @@ -728,8 +709,7 @@ def generate_vad_segment_table( 17,18, speech Args: vad_pred_dir (str): directory of prediction files to be processed. - postprocessing_params (dict): dictionary of thresholds for prediction score. - See details in binarization and filtering. + postprocessing_params (dict): dictionary of thresholds for prediction score. See details in binarization and filtering. frame_length_in_sec (float): frame length. out_dir (str): output dir of generated table/csv file. num_workers(float): number of process for multiprocessing @@ -840,12 +820,10 @@ def vad_tune_threshold_on_dev( num_workers: int = 20, ) -> Tuple[dict, dict]: """ - Tune thresholds on dev set. Return best thresholds which gives the lowest detection error rate - (DetER) in thresholds. + Tune thresholds on dev set. Return best thresholds which gives the lowest detection error rate (DetER) in thresholds. Args: params (dict): dictionary of parameters to be tuned on. - vad_pred_method (str): suffix of prediction file. Use to locate file. - Should be either in "frame", "mean" or "median". + vad_pred_method (str): suffix of prediction file. Use to locate file. Should be either in "frame", "mean" or "median". groundtruth_RTTM_dir (str): directory of ground-truth rttm files or a file contains the paths of them. focus_metric (str): metrics we care most when tuning threshold. Should be either in "DetER", "FA", "MISS" frame_length_in_sec (float): frame length. @@ -936,8 +914,7 @@ def check_if_param_valid(params: dict) -> bool: for j in params[i]: if not j >= 0: raise ValueError( - "Invalid inputs! All float parameters except pad_onset and pad_offset should be \ - larger than 0!" + "Invalid inputs! All float parameters except pad_onset and pad_offset should be larger than 0!" ) if not (all(i <= 1 for i in params['onset']) and all(i <= 1 for i in params['offset'])): @@ -995,7 +972,7 @@ def plot( unit_frame_len: float = 0.01, label_repeat: int = 1, xticks_step: int = 5, -) -> "ipd.Audio": +) -> ipd.Audio: """ Plot Audio and/or VAD output and/or groundtruth labels for visualization Args: @@ -1009,13 +986,9 @@ def plot( threshold (float): threshold for prediction score (from 0 to 1). per_args(dict): a dict that stores the thresholds for postprocessing. unit_frame_len (float): unit frame length in seconds for VAD predictions. - label_repeat (int): repeat the label for this number of times to match different \ - frame lengths in preds and labels. + label_repeat (int): repeat the label for this number of times to match different frame lengths in preds and labels. xticks_step (int): step size for xticks. """ - if HAVE_IPYTHON is False: - raise ImportError("IPython is not installed. Please install IPython to use this function.") - plt.figure(figsize=[20, 2]) audio, sample_rate = librosa.load( @@ -1281,8 +1254,7 @@ def stitch_segmented_asr_output( fout.flush() logging.info( - f"Finish stitch segmented ASR output to {stitched_output_manifest}, \ - the speech segments info has been stored in directory {speech_segments_tensor_dir}" + f"Finish stitch segmented ASR output to {stitched_output_manifest}, the speech segments info has been stored in directory {speech_segments_tensor_dir}" ) return stitched_output_manifest @@ -1462,13 +1434,10 @@ def plot_sample_from_rttm( show: bool = True, offset: float = 0.0, unit_frame_len: float = 0.01, -) -> "ipd.Audio": +): """ Plot audio signal and frame-level labels from RTTM file """ - if HAVE_IPYTHON is False: - raise ImportError("IPython is not installed. Please install IPython to use this function.") - plt.figure(figsize=[20, 2]) audio, sample_rate = librosa.load(path=audio_file, sr=16000, mono=True, offset=offset, duration=max_duration) @@ -1503,9 +1472,8 @@ def plot_sample_from_rttm( def align_labels_to_frames(probs, labels, threshold=0.2): """ Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length (e.g., 20ms). - The threshold 0.2 is not important, since the actual ratio will always be close to an integer - unless using frame/label. lengths that are not multiples of each other - (e.g., 15ms frame length and 20ms label length), which is not valid. + The threshold 0.2 is not important, since the actual ratio will always be close to an integer unless using frame/label + lengths that are not multiples of each other (e.g., 15ms frame length and 20ms label length), which is not valid. The value 0.2 here is just for easier unit testing. Args: probs (List[float]): list of probabilities @@ -1543,13 +1511,11 @@ def align_labels_to_frames(probs, labels, threshold=0.2): ratio = frames_len / labels_len res = frames_len % labels_len if ceil(ratio) - ratio < threshold: - # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels to make it a - # multiple of 2, and discard the redundant labels + # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels to make it a multiple of 2, and discard the redundant labels labels = labels.repeat_interleave(ceil(ratio), dim=0).long().tolist() labels = labels[:frames_len] else: - # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels to make it a multiple of - # 2 and add additional labels + # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels to make it a multiple of 2 and add additional labels labels = labels.repeat_interleave(floor(ratio), dim=0).long().tolist() if res > 0: labels += labels[-res:] @@ -1743,3 +1709,51 @@ def frame_vad_eval_detection_error( auroc = roc_auc_score(y_true=all_labels, y_score=all_probs) report = metric.report(display=False) return auroc, report + + +def ts_vad_post_processing( + ts_vad_binary_vec: torch.Tensor, + cfg_vad_params: OmegaConf, + unit_10ms_frame_count: int=8, + bypass_postprocessing: bool = False + ): + """ + Post-processing on diarization results using VAD style post-processing methods. + These post-processing methods are inspired by the following paper: + Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). + + Args: + ts_vad_binary_vec (Tensor): + Sigmoid values of each frame and each speaker. + Dimension: (num_frames,) + cfg_vad_params (OmegaConf): + Configuration (omega config) of VAD parameters. + unit_10ms_frame_count (int, optional): + an integer indicating the number of 10ms frames in a unit. + For example, if unit_10ms_frame_count is 8, then each frame is 0.08 seconds. + bypass_postprocessing (bool, optional): + If True, diarization post-processing will be bypassed. + + Returns: + speech_segments (Tensor): + start and end of each speech segment. + Dimension: (num_segments, 2) + + Example: + tensor([[ 0.0000, 3.0400], + [ 6.0000, 6.0800], + ... + [587.3600, 591.0400], + [591.1200, 597.7600]]) + """ + ts_vad_binary_frames = torch.repeat_interleave(ts_vad_binary_vec, unit_10ms_frame_count) + if not bypass_postprocessing: + speech_segments = binarization(ts_vad_binary_frames, cfg_vad_params) + speech_segments = filtering(speech_segments, cfg_vad_params) + else: + cfg_vad_params.onset=0.5 + cfg_vad_params.offset=0.5 + cfg_vad_params.pad_onset=0.0 + cfg_vad_params.pad_offset=0.0 + speech_segments = binarization(ts_vad_binary_frames, cfg_vad_params) + return speech_segments \ No newline at end of file From 9a468ac82e68458ceb9c5882975d887362a00c07 Mon Sep 17 00:00:00 2001 From: taejinp Date: Thu, 14 Nov 2024 00:37:49 -0800 Subject: [PATCH 03/47] Name changes on yaml files and train example Signed-off-by: taejinp --- ...rtformer_diarizer_hybrid_loss_4spk-v1.yaml | 218 ++++++++++++++++++ ...rtformer_diar_4spk-v1_callhome-part1.yaml} | 1 - ...> sortformer_diar_4spk-v1_dihard-dev.yaml} | 2 +- .../sortformer_diar_encoder_infer.py | 132 ----------- ...oder_train.py => sortformer_diar_train.py} | 0 5 files changed, 219 insertions(+), 134 deletions(-) create mode 100644 examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml rename examples/speaker_tasks/diarization/conf/post_processing/{sortformer_diar_HL_callhome_part1.yaml => sortformer_diar_4spk-v1_callhome-part1.yaml} (85%) rename examples/speaker_tasks/diarization/conf/post_processing/{sortformer_diar_HL_dihard.yaml => sortformer_diar_4spk-v1_dihard-dev.yaml} (84%) delete mode 100644 examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py rename examples/speaker_tasks/diarization/neural_diarizer/{sortformer_diar_encoder_train.py => sortformer_diar_train.py} (100%) diff --git a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml new file mode 100644 index 000000000000..e44bae976729 --- /dev/null +++ b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml @@ -0,0 +1,218 @@ +# Sortformer Diarizer is an end-to-end speaker diarization model that is solely based on Transformer-encoder type of architecture. +# Model name convention for Sortformer Diarizer: sortformer_diarizer____loss.yaml +# (Example) `sortformer_diarizer_FC18_TF18_hybrid_loss.yaml` has 18 layers for FastConformer and 18 layers of Transformer. +# Sortformer Diarizer model checkpoint (.ckpt) and NeMo file (.nemo) contain Fast Conformer Encoder model (NEST Encoder) and the pre-trained NEST model is loaded along with the Transformer Encoder layers. +# Example: a manifest line for training +# {"audio_filepath": "/path/to/audio01.wav", "offset": 390.83, "duration": 90.00, "text": "-", "num_speakers": 2, "rttm_filepath": "/path/to/audio01.rttm"} +name: "SortFormerDiarizer" +sample_rate: 16000 +num_workers: 18 +batch_size: 8 + +model: + pil_weight: 0.5 + ats_weight: 0.5 + num_workers: ${num_workers} + fc_d_model: 512 + tf_d_model: 192 + max_num_of_spks: 4 # Number of speakers per model. This is currently fixed at 4. + session_len_sec: 90 + + train_ds: + manifest_filepath: ??? + sample_rate: ${sample_rate} + num_spks: ${model.max_num_of_spks} + session_len_sec: ${model.session_len_sec} + soft_label_thres: 0.5 + soft_targets: False + labels: null + batch_size: ${batch_size} + shuffle: True + num_workers: ${num_workers} + validation_mode: False + # lhotse config + use_lhotse: False + use_bucketing: True + num_buckets: 10 + bucket_duration_bins: [10, 20, 30, 40, 50, 60, 70, 80, 90] + pin_memory: True + min_duration: 80 + max_duration: 90 + batch_duration: 400 + quadratic_duration: 1200 + bucket_buffer_size: 20000 + shuffle_buffer_size: 10000 + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + validation_ds: + manifest_filepath: ??? + is_tarred: False + tarred_audio_filepaths: null + sample_rate: ${sample_rate} + num_spks: ${model.max_num_of_spks} + session_len_sec: ${model.session_len_sec} + soft_label_thres: 0.5 + soft_targets: False + labels: null + batch_size: ${batch_size} + shuffle: False + num_workers: ${num_workers} + validation_mode: True + # lhotse config + use_lhotse: False + use_bucketing: False + drop_last: False + pin_memory: True + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + test_ds: + manifest_filepath: null + is_tarred: False + tarred_audio_filepaths: null + sample_rate: 16000 + num_spks: ${model.max_num_of_spks} + session_len_sec: ${model.session_len_sec} + soft_label_thres: 0.5 + soft_targets: False + labels: null + batch_size: ${batch_size} + shuffle: False + seq_eval_mode: True + num_workers: ${num_workers} + validation_mode: True + # lhotse config + use_lhotse: False + use_bucketing: False + drop_last: False + pin_memory: True + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + normalize: "per_feature" + window_size: 0.025 + sample_rate: ${sample_rate} + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + + sortformer_modules: + _target_: nemo.collections.asr.modules.sortformer_modules.SortformerModules + num_spks: ${model.max_num_of_spks} # Number of speakers per model. This is currently fixed at 4. + dropout_rate: 0.5 # Dropout rate + fc_d_model: ${model.fc_d_model} + tf_d_model: ${model.tf_d_model} # Hidden layer size for linear layers in Sortformer Diarizer module + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 + n_layers: 18 + d_model: ${model.fc_d_model} + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + transformer_encoder: + _target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder + num_layers: 18 + hidden_size: ${model.tf_d_model} # Needs to be multiple of num_attention_heads + inner_size: 768 + num_attention_heads: 8 + attn_score_dropout: 0.5 + attn_layer_dropout: 0.5 + ffn_dropout: 0.5 + hidden_act: relu + pre_ln: False + pre_ln_final_layer_norm: True + + loss: + _target_: nemo.collections.asr.losses.bce_loss.BCELoss + weight: null # Weight for binary cross-entropy loss. Either `null` or list type input. (e.g. [0.5,0.5]) + reduction: mean + + lr: 0.0001 + optim: + name: adamw + lr: ${model.lr} + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + sched: + name: InverseSquareRootAnnealing + warmup_steps: 2500 + warmup_ratio: null + min_lr: 1e-06 + +trainer: + devices: 1 # number of gpus (devices) + accelerator: gpu + max_epochs: 800 + max_steps: -1 # computed at runtime if not set + num_nodes: 1 + strategy: ddp_find_unused_parameters_true # Could be "ddp" + accumulate_grad_batches: 1 + deterministic: True + enable_checkpointing: False + logger: False + log_every_n_steps: 1 # Interval of logging. + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + +exp_manager: + use_datetime_version: False + exp_dir: null + name: ${name} + resume_if_exists: True + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + resume_ignore_no_checkpoint: True + create_tensorboard_logger: True + create_checkpoint_callback: True + create_wandb_logger: False + checkpoint_callback_params: + monitor: "val_f1_acc" + mode: "max" + save_top_k: 9 + every_n_epochs: 1 + wandb_logger_kwargs: + resume: True + name: null + project: null \ No newline at end of file diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_callhome_part1.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml similarity index 85% rename from examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_callhome_part1.yaml rename to examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml index 6b960e2d5950..3733e1285b77 100644 --- a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_callhome_part1.yaml +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml @@ -3,7 +3,6 @@ # Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). # These parameters were optimized with with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656. # These parameters were optimized on the development split of DIHARD3 dataset. See https://arxiv.org/pdf/2012.01477. -# Trial 17903 finished with value: 0.10261257411949805 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.0, 'min_duration_on': 0.39, 'min_duration_off': 0.39}. Best is trial 17903 with value: 0.10261257411949805. # Trial 24682 finished with value: 0.10257785779242055 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.01, 'min_duration_on': 0.42, 'min_duration_off': 0.34}. Best is trial 24682 with value: 0.10257785779242055. parameters: window_length_in_sec: 0.0 # Not used diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_dihard.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml similarity index 84% rename from examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_dihard.yaml rename to examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml index bb9f362ad619..275bc86db4cd 100644 --- a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_dihard.yaml +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml @@ -3,7 +3,7 @@ # Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). # These parameters were optimized with with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656. # These parameters were optimized on CallHome Dataset from the NIST SRE 2000 Disc8, especially from the split2 specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/tree/master/egs/callhome_diarization/v2. -# Trial 180 finished with value: 0.12329626986650599 and parameters: {'onset': 0.56, 'offset': 0.81, 'pad_onset': 0.05, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.16}. Best is trial 180 with value: 0.12329626986650599. +# Trial 732 finished with value: 0.12171946949255649 and parameters: {'onset': 0.64, 'offset': 0.74, 'pad_onset': 0.06, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.15}. Best is trial 732 with value: 0.12171946949255649. parameters: window_length_in_sec: 0.0 # Not used shift_length_in_sec: 0.0 # Not used diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py deleted file mode 100644 index aafd2b2cb6ed..000000000000 --- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytorch_lightning as pl -from omegaconf import OmegaConf -from pytorch_lightning import seed_everything -import seaborn as sns -import numpy as np - -from nemo.collections.asr.models import SortformerEncLabelModel -from nemo.core.config import hydra_runner -from nemo.utils import logging -from nemo.utils.exp_manager import exp_manager -seed_everything(42) -import torch -import matplotlib.pyplot as plt -import seaborn as sns -from sklearn.manifold import TSNE -import pandas as pd -from nemo.collections.asr.data.audio_to_msdd_mock_label import generate_mock_embs - -def plot_enc_tsne(x, targets, memo): - # x = enc_states_list[-1].squeeze(0).cpu().detach().numpy() - tsne = TSNE(n_components=2, verbose=False, random_state=100) - zembs = tsne.fit_transform(x) - - # Step 1: Create a new column filled with 0.5 - new_column = torch.full((targets.size(0), 1), 0.5) - # Step 2: Concatenate the new column with the original tensor - updated_targets = torch.cat((new_column, targets), dim=1) - - df = pd.DataFrame() - df["y"] = updated_targets.argmax(dim=1).detach().cpu().numpy() - df["comp-1"] = zembs[:,0] - df["comp-2"] = zembs[:,1] - - # Plotting using seaborn - plt.figure(figsize=(10, 8)) - sns.scatterplot(x="comp-1", y="comp-2", hue=df.y.tolist(), - palette=sns.color_palette("hls", 10), - data=df).set(title="SortFormer HiddenState T-SNE projection") - - # Save the plot as a PNG file in the specified directory - plt.savefig(f'/home/taejinp/Downloads/tsne_plots/tsne_sortformer_plot_{memo}.png') - -def remove_speaker_models(ckpt_path): - ckpt_instance = torch.load(ckpt_path) - _state_dict = ckpt_instance['state_dict'] - - key_list = list(_state_dict.keys()) - for key in key_list: - if '_speaker_model.' in key or '_speaker_model_decoder.' in key: - # import ipdb; ipdb.set_trace() - del _state_dict[key] - - target_path = ckpt_path.replace('.ckpt', '.removed.ckpt') - torch.save(ckpt_instance, target_path) - return target_path - - -# @hydra_runner(config_path="../conf/neural_diarizer", config_name="msdd_5scl_15_05_50Povl_256x3x32x2.yaml") -def main(): - # logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') - # trainer = pl.Trainer(**cfg.trainer) - # exp_manager(trainer, cfg.get("exp_manager", None)) - # ckpt_path = "/disk_c/taejinp_backup/msdd_model_train/NVB_SFmr_MixMockEmbsTest/version_18_f0:84/checkpoints/e613.ckpt" - ckpt_path = "/disk_c/taejinp_backup/msdd_model_train/SFmr_MixMockEmbsTest/version_21/checkpoints/ep2255.ckpt" - target_path = remove_speaker_models(ckpt_path) - sortformer_model = SortformerEncLabelModel.load_from_checkpoint(checkpoint_path=target_path) - unit_len = 25 - targets = torch.eye(4,4).repeat_interleave(unit_len,1).t() - targets[:,2:] = 0 - # targets[:,3:] = 0 - targets = targets[:2*unit_len, :] - new_column = torch.full((targets.size(0), 1), 0.5) - updated_targets = torch.cat((new_column, targets), dim=1) - mock_embs, audio_signal_length, targets = generate_mock_embs(targets=targets, seed=315, - mock_emb_noise_std=0.03, - mock_emb_degree_of_freedom=4, - min_noise_std=0.01,) - mock_embs = mock_embs.unsqueeze(0) - audio_signal = mock_embs - - audio_signal, audio_signal_length, targets - - audio_signal = audio_signal.cuda() - ms_seg_counts = torch.tensor([]).cuda() - ms_seg_timestamps = torch.tensor([]).cuda() - scale_mapping = torch.tensor([]).cuda() - sortformer_model.alpha = 0.0 - - _preds_mean, preds_, attn_score_stack, enc_states_list, preds_list = sortformer_model.forward( - audio_signal=audio_signal, - audio_signal_length=audio_signal_length, - ms_seg_timestamps=ms_seg_timestamps, - ms_seg_counts=ms_seg_counts, - scale_mapping=scale_mapping, - temp_targets=targets, - ) - - audio_signal_np = audio_signal.squeeze(0).cpu().detach().numpy() - plot_enc_tsne(audio_signal_np, targets, memo=f'input', ) - for layer_c in range(len(enc_states_list)): - print(f"Plotting TSNE for layer {layer_c} ...") - x = enc_states_list[layer_c].squeeze(0).cpu().detach().numpy() - plot_enc_tsne(x, targets, memo=f'layer{layer_c}', ) - preds = preds_.squeeze(0).cpu().detach().numpy() - plot_enc_tsne(preds, targets, memo=f'preds', ) - _preds_mean = _preds_mean.squeeze(0).cpu().detach().numpy() - plot_enc_tsne(_preds_mean, targets, memo=f'preds_mean', ) - - # Optionally, you can also show the plot if desired - plt.show() - import ipdb; ipdb.set_trace() - - # msdd_model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer) - # trainer.fit(msdd_model) - - -if __name__ == '__main__': - main() diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py similarity index 100% rename from examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_train.py rename to examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py From 2f44fe1fd6526ba1b691ccdd57b3dadc22cef4b0 Mon Sep 17 00:00:00 2001 From: tango4j Date: Thu, 14 Nov 2024 09:08:01 +0000 Subject: [PATCH 04/47] Apply isort and black reformatting Signed-off-by: tango4j --- .../neural_diarizer/e2e_diarize_speech.py | 238 ++++--- .../neural_diarizer/sortformer_diar_train.py | 2 +- .../asr/data/audio_to_diar_label.py | 182 ++--- .../asr/data/audio_to_diar_label_lhotse.py | 25 +- nemo/collections/asr/metrics/der.py | 32 +- .../asr/metrics/multi_binary_acc.py | 7 +- nemo/collections/asr/models/__init__.py | 2 +- .../asr/models/sortformer_diar_models.py | 156 +++-- .../asr/modules/sortformer_modules.py | 7 +- .../asr/parts/utils/asr_multispeaker_utils.py | 636 +++++++++++------- .../asr/parts/utils/speaker_utils.py | 128 ++-- nemo/collections/asr/parts/utils/vad_utils.py | 41 +- .../common/parts/preprocessing/collections.py | 16 +- 13 files changed, 850 insertions(+), 622 deletions(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 98f2ee10e523..40ed9fab7a64 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -19,32 +19,31 @@ dataset_manifest=/path/to/diarization_path_to_manifest.json """ +import logging +import os +import tempfile +from dataclasses import dataclass, is_dataclass +from typing import Dict, List, Optional, Tuple, Union + +import optuna import pytorch_lightning as pl +import torch +import yaml +from hydra.core.config_store import ConfigStore from omegaconf import OmegaConf from pytorch_lightning import seed_everything +from tqdm import tqdm -from nemo.collections.asr.models import SortformerEncLabelModel -from nemo.core.config import hydra_runner from nemo.collections.asr.metrics.der import score_labels -from hydra.core.config_store import ConfigStore - -import os -import yaml -from dataclasses import dataclass, is_dataclass -from typing import Optional, Union, List, Tuple, Dict - +from nemo.collections.asr.models import SortformerEncLabelModel from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map, timestamps_to_pyannote_object from nemo.collections.asr.parts.utils.vad_utils import ts_vad_post_processing - -from tqdm import tqdm -import torch -import logging -import optuna -import tempfile +from nemo.core.config import hydra_runner seed_everything(42) torch.backends.cudnn.deterministic = True + @dataclass class PostProcessingParams: window_length_in_sec: float = 0.15 @@ -59,6 +58,7 @@ class PostProcessingParams: min_duration_off: float = 0.0 filter_speech_first: bool = True + @dataclass class DiarizationConfig: # Required configs @@ -66,50 +66,53 @@ class DiarizationConfig: pretrained_name: Optional[str] = None # Name of a pretrained model audio_dir: Optional[str] = None # Path to a directory which contains audio files dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest - + postprocessing_yaml: Optional[str] = None # Path to a yaml file for postprocessing configurations no_der: bool = False out_rttm_dir: Optional[str] = None - + # General configs - session_len_sec: float = -1 # End-to-end diarization session length in seconds + session_len_sec: float = -1 # End-to-end diarization session length in seconds batch_size: int = 4 num_workers: int = 0 random_seed: Optional[int] = None # seed number going to be used in seed_everything() - bypass_postprocessing: bool = True # If True, postprocessing will be bypassed - + bypass_postprocessing: bool = True # If True, postprocessing will be bypassed + # Eval Settings: (0.25, False) should be default setting for sortformer eval. - collar: float = 0.25 # Collar in seconds for DER calculation - ignore_overlap: bool = False # If True, DER will be calculated only for non-overlapping segments + collar: float = 0.25 # Collar in seconds for DER calculation + ignore_overlap: bool = False # If True, DER will be calculated only for non-overlapping segments # If `cuda` is a negative number, inference will be on CPU only. cuda: Optional[int] = None matmul_precision: str = "highest" # Literal["highest", "high", "medium"] # Optuna Config - launch_pp_optim: bool = False # If True, launch optimization process for postprocessing parameters + launch_pp_optim: bool = False # If True, launch optimization process for postprocessing parameters optuna_study_name: str = "optim_postprocessing" optuna_temp_dir: str = "/tmp/optuna" optuna_storage: str = f"sqlite:///{optuna_study_name}.db" optuna_log_file: str = f"{optuna_study_name}.log" optuna_n_trials: int = 100000 + def load_postprocessing_from_yaml(postprocessing_yaml): - """ + """ Load postprocessing parameters from a YAML file. Args: - postprocessing_yaml (str): + postprocessing_yaml (str): Path to a YAML file for postprocessing configurations. Returns: - postprocessing_params (dataclass): + postprocessing_params (dataclass): Postprocessing parameters loaded from the YAML file. """ # Add PostProcessingParams as a field postprocessing_params = OmegaConf.structured(PostProcessingParams()) if postprocessing_yaml is None: - logging.info(f"No postprocessing YAML file has been provided. Default postprocessing configurations will be applied.") + logging.info( + f"No postprocessing YAML file has been provided. Default postprocessing configurations will be applied." + ) else: # Load postprocessing params from the provided YAML file with open(postprocessing_yaml, 'r') as file: @@ -121,6 +124,7 @@ def load_postprocessing_from_yaml(postprocessing_yaml): setattr(postprocessing_params, key, value) return postprocessing_params + def optuna_suggest_params(postprocessing_cfg: PostProcessingParams, trial: optuna.Trial) -> PostProcessingParams: """ Suggests hyperparameters for postprocessing using Optuna. @@ -140,6 +144,7 @@ def optuna_suggest_params(postprocessing_cfg: PostProcessingParams, trial: optun postprocessing_cfg.min_duration_off = trial.suggest_float("min_duration_off", 0.0, 0.75, step=0.01) return postprocessing_cfg + def get_tensor_path(cfg: DiarizationConfig) -> str: """ Constructs the file path for saving or loading prediction tensors based on the configuration. @@ -159,20 +164,21 @@ def get_tensor_path(cfg: DiarizationConfig) -> str: tensor_path = f"{bpath}/__{model_id}__{tensor_filename}.pt" return tensor_path + def diarization_objective( - trial: optuna.Trial, - postprocessing_cfg: PostProcessingParams, - temp_out_dir: str, - infer_audio_rttm_dict: Dict[str, Dict[str, str]], - diar_model_preds_total_list: List[torch.Tensor], - collar: float = 0.25, - ignore_overlap: bool = False + trial: optuna.Trial, + postprocessing_cfg: PostProcessingParams, + temp_out_dir: str, + infer_audio_rttm_dict: Dict[str, Dict[str, str]], + diar_model_preds_total_list: List[torch.Tensor], + collar: float = 0.25, + ignore_overlap: bool = False, ) -> float: """ Objective function for Optuna hyperparameter optimization in speaker diarization. This function evaluates the diarization performance using a set of postprocessing parameters - suggested by Optuna. It converts prediction matrices to time-stamp segments, scores the + suggested by Optuna. It converts prediction matrices to time-stamp segments, scores the diarization results, and returns the Diarization Error Rate (DER) as the optimization metric. Args: @@ -192,42 +198,43 @@ def diarization_objective( """ with tempfile.TemporaryDirectory(dir=temp_out_dir, prefix="Diar_PostProcessing_") as local_temp_out_dir: if trial is not None: - postprocessing_cfg = optuna_suggest_params(postprocessing_cfg, trial) - all_hyps, all_refs, all_uems = convert_pred_mat_to_segments(audio_rttm_map_dict=infer_audio_rttm_dict, - postprocessing_cfg=postprocessing_cfg, - batch_preds_list=diar_model_preds_total_list, - unit_10ms_frame_count=8, - bypass_postprocessing=False) - metric, mapping_dict, itemized_errors = score_labels(AUDIO_RTTM_MAP=infer_audio_rttm_dict, - all_reference=all_refs, - all_hypothesis=all_hyps, - all_uem=all_uems, - collar=collar, - ignore_overlap=ignore_overlap - ) + postprocessing_cfg = optuna_suggest_params(postprocessing_cfg, trial) + all_hyps, all_refs, all_uems = convert_pred_mat_to_segments( + audio_rttm_map_dict=infer_audio_rttm_dict, + postprocessing_cfg=postprocessing_cfg, + batch_preds_list=diar_model_preds_total_list, + unit_10ms_frame_count=8, + bypass_postprocessing=False, + ) + metric, mapping_dict, itemized_errors = score_labels( + AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=collar, + ignore_overlap=ignore_overlap, + ) der = abs(metric) return der + def run_optuna_hyperparam_search( cfg: DiarizationConfig, # type: DiarizationConfig postprocessing_cfg: PostProcessingParams, - infer_audio_rttm_dict: Dict[str, Dict[str, str]], - preds_list: List[torch.Tensor], - temp_out_dir: str, - ): + infer_audio_rttm_dict: Dict[str, Dict[str, str]], + preds_list: List[torch.Tensor], + temp_out_dir: str, +): worker_function = lambda trial: diarization_objective( trial=trial, postprocessing_cfg=postprocessing_cfg, temp_out_dir=temp_out_dir, - infer_audio_rttm_dict=infer_audio_rttm_dict, + infer_audio_rttm_dict=infer_audio_rttm_dict, diar_model_preds_total_list=preds_list, collar=cfg.collar, ) study = optuna.create_study( - direction="minimize", - study_name=cfg.optuna_study_name, - storage=cfg.optuna_storage, - load_if_exists=True + direction="minimize", study_name=cfg.optuna_study_name, storage=cfg.optuna_storage, load_if_exists=True ) logger = logging.getLogger() logger.setLevel(logging.INFO) # Setup the root logger. @@ -235,17 +242,17 @@ def run_optuna_hyperparam_search( logger.addHandler(logging.FileHandler(cfg.optuna_log_file, mode="a")) logger.addHandler(logging.StreamHandler()) optuna.logging.enable_propagation() # Propagate logs to the root logger. - study.optimize(worker_function, n_trials=cfg.optuna_n_trials) + study.optimize(worker_function, n_trials=cfg.optuna_n_trials) def convert_pred_mat_to_segments( - audio_rttm_map_dict: Dict[str, Dict[str, str]], - postprocessing_cfg, - batch_preds_list: List[torch.Tensor], - unit_10ms_frame_count:int = 8, + audio_rttm_map_dict: Dict[str, Dict[str, str]], + postprocessing_cfg, + batch_preds_list: List[torch.Tensor], + unit_10ms_frame_count: int = 8, bypass_postprocessing: bool = False, out_rttm_dir: str | None = None, - ): +): """ Convert prediction matrix to time-stamp segments. @@ -263,32 +270,38 @@ def convert_pred_mat_to_segments( """ batch_pred_ts_segs, all_hypothesis, all_reference, all_uems = [], [], [], [] cfg_vad_params = OmegaConf.structured(postprocessing_cfg) - for sample_idx, (uniq_id, audio_rttm_values) in tqdm(enumerate(audio_rttm_map_dict.items()), total=len(audio_rttm_map_dict), desc="Running post-processing"): + for sample_idx, (uniq_id, audio_rttm_values) in tqdm( + enumerate(audio_rttm_map_dict.items()), total=len(audio_rttm_map_dict), desc="Running post-processing" + ): spk_ts = [] offset, duration = audio_rttm_values['offset'], audio_rttm_values['duration'] speaker_assign_mat = batch_preds_list[sample_idx].squeeze(dim=0) speaker_timestamps = [[] for _ in range(speaker_assign_mat.shape[-1])] for spk_id in range(speaker_assign_mat.shape[-1]): - ts_mat = ts_vad_post_processing(speaker_assign_mat[:, spk_id], - cfg_vad_params=cfg_vad_params, - unit_10ms_frame_count=unit_10ms_frame_count, - bypass_postprocessing=bypass_postprocessing) + ts_mat = ts_vad_post_processing( + speaker_assign_mat[:, spk_id], + cfg_vad_params=cfg_vad_params, + unit_10ms_frame_count=unit_10ms_frame_count, + bypass_postprocessing=bypass_postprocessing, + ) ts_mat = ts_mat + offset ts_mat = torch.clamp(ts_mat, min=offset, max=(offset + duration)) ts_seg_list = ts_mat.tolist() speaker_timestamps[spk_id].extend(ts_seg_list) spk_ts.append(ts_seg_list) - all_hypothesis, all_reference, all_uems = timestamps_to_pyannote_object(speaker_timestamps, - uniq_id, - audio_rttm_values, - all_hypothesis, - all_reference, - all_uems, - out_rttm_dir, - ) - batch_pred_ts_segs.append(spk_ts) + all_hypothesis, all_reference, all_uems = timestamps_to_pyannote_object( + speaker_timestamps, + uniq_id, + audio_rttm_values, + all_hypothesis, + all_reference, + all_uems, + out_rttm_dir, + ) + batch_pred_ts_segs.append(spk_ts) return all_hypothesis, all_reference, all_uems + @hydra_runner(config_name="DiarizationConfig", schema=DiarizationConfig) def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: for key in cfg: @@ -299,7 +312,7 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: if cfg.random_seed: pl.seed_everything(cfg.random_seed) - + if cfg.model_path is None and cfg.pretrained_name is None: raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!") if cfg.audio_dir is None and cfg.dataset_manifest is None: @@ -322,65 +335,74 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: map_location = torch.device(f'cuda:{cfg.cuda}') if cfg.model_path.endswith(".ckpt"): - diar_model = SortformerEncLabelModel.load_from_checkpoint(checkpoint_path=cfg.model_path, map_location=map_location, strict=False) + diar_model = SortformerEncLabelModel.load_from_checkpoint( + checkpoint_path=cfg.model_path, map_location=map_location, strict=False + ) elif cfg.model_path.endswith(".nemo"): diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.model_path, map_location=map_location) else: raise ValueError("cfg.model_path must end with.ckpt or.nemo!") - + diar_model._cfg.test_ds.session_len_sec = cfg.session_len_sec trainer = pl.Trainer(devices=device, accelerator=accelerator) diar_model.set_trainer(trainer) - + diar_model = diar_model.eval() diar_model._cfg.test_ds.manifest_filepath = cfg.dataset_manifest infer_audio_rttm_dict = audio_rttm_map(cfg.dataset_manifest) diar_model._cfg.test_ds.batch_size = cfg.batch_size - - # Model setup for inference + + # Model setup for inference diar_model._cfg.test_ds.num_workers = cfg.num_workers - diar_model.setup_test_data(test_data_config=diar_model._cfg.test_ds) - + diar_model.setup_test_data(test_data_config=diar_model._cfg.test_ds) + postprocessing_cfg = load_postprocessing_from_yaml(cfg.postprocessing_yaml) tensor_path = get_tensor_path(cfg) - + if os.path.exists(tensor_path): - logging.info(f"A saved prediction tensor has been found. Loading the saved prediction tensors from {tensor_path}...") + logging.info( + f"A saved prediction tensor has been found. Loading the saved prediction tensors from {tensor_path}..." + ) diar_model_preds_total_list = torch.load(tensor_path) else: logging.info(f"No saved prediction tensors found. Running inference on the dataset...") diar_model.test_batch() diar_model_preds_total_list = diar_model.preds_total_list torch.save(diar_model.preds_total_list, tensor_path) - + if cfg.launch_pp_optim: # Launch a hyperparameter optimization process if launch_pp_optim is True - run_optuna_hyperparam_search(cfg=cfg, - postprocessing_cfg=postprocessing_cfg, - infer_audio_rttm_dict=infer_audio_rttm_dict, - preds_list=diar_model_preds_total_list, - temp_out_dir=cfg.optuna_temp_dir) + run_optuna_hyperparam_search( + cfg=cfg, + postprocessing_cfg=postprocessing_cfg, + infer_audio_rttm_dict=infer_audio_rttm_dict, + preds_list=diar_model_preds_total_list, + temp_out_dir=cfg.optuna_temp_dir, + ) # Evaluation if not cfg.no_der: if cfg.out_rttm_dir is not None and not os.path.exists(cfg.out_rttm_dir): os.mkdir(cfg.out_rttm_dir) - all_hyps, all_refs, all_uems = convert_pred_mat_to_segments(infer_audio_rttm_dict, - postprocessing_cfg=postprocessing_cfg, - batch_preds_list=diar_model_preds_total_list, - unit_10ms_frame_count=8, - bypass_postprocessing=cfg.bypass_postprocessing, - out_rttm_dir=cfg.out_rttm_dir - ) + all_hyps, all_refs, all_uems = convert_pred_mat_to_segments( + infer_audio_rttm_dict, + postprocessing_cfg=postprocessing_cfg, + batch_preds_list=diar_model_preds_total_list, + unit_10ms_frame_count=8, + bypass_postprocessing=cfg.bypass_postprocessing, + out_rttm_dir=cfg.out_rttm_dir, + ) logging.info(f"Evaluating the model on the {len(diar_model_preds_total_list)} audio segments...") - metric, mapping_dict, itemized_errors = score_labels(AUDIO_RTTM_MAP=infer_audio_rttm_dict, - all_reference=all_refs, - all_hypothesis=all_hyps, - all_uem=all_uems, - collar=cfg.collar, - ignore_overlap=cfg.ignore_overlap - ) + metric, mapping_dict, itemized_errors = score_labels( + AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=cfg.collar, + ignore_overlap=cfg.ignore_overlap, + ) logging.info(f"PostProcessingParams: {postprocessing_cfg}") + if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py index fb350113d596..3ba0dbc3ed19 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py @@ -50,5 +50,5 @@ def main(cfg): if __name__ == '__main__': - + main() diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index ffad8e4fd072..b00338743a43 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -15,18 +15,23 @@ import os from collections import OrderedDict from statistics import mode -from typing import Dict, List, Tuple, Optional -import torch +from typing import Dict, List, Optional, Tuple + import numpy as np +import torch -from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat from nemo.collections.asr.parts.utils.asr_multispeaker_utils import find_first_nonzero -from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, prepare_split_data, get_subsegments -from nemo.collections.common.parts.preprocessing.collections import DiarizationSpeechLabel, EndtoEndDiarizationSpeechLabel +from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat +from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, get_subsegments, prepare_split_data +from nemo.collections.common.parts.preprocessing.collections import ( + DiarizationSpeechLabel, + EndtoEndDiarizationSpeechLabel, +) from nemo.core.classes import Dataset from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType, ProbsType from nemo.utils import logging + def get_scale_mapping_list(uniq_timestamps): """ Call get_argmin_mat function to find the index of the non-base-scale segment that is closest to the @@ -125,7 +130,7 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec, return None else: sorted_speakers = sorted(list(set(speaker_list))) - total_fr_len = int(max(end_list) * (10 ** round_digits)) + total_fr_len = int(max(end_list) * (10**round_digits)) spk_num = max(len(sorted_speakers), min_spks) speaker_mapping_dict = {rttm_key: x_int for x_int, rttm_key in enumerate(sorted_speakers)} fr_level_target = torch.zeros(total_fr_len, spk_num) @@ -141,27 +146,24 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec, def get_subsegments_to_timestamps( - subsegments: List[Tuple[float, float]], - feat_per_sec: int = 100, - max_end_ts: float=None, - decimals=2 - ): + subsegments: List[Tuple[float, float]], feat_per_sec: int = 100, max_end_ts: float = None, decimals=2 +): """ Convert subsegment timestamps to scale timestamps by multiplying with the feature rate and rounding. All `ts` related tensors are dimensioned as (N, 2), where N is the number of subsegments. Args: - subsegments (List[Tuple[float, float]]): + subsegments (List[Tuple[float, float]]): A list of tuples where each tuple contains the start and end times of a subsegment. - feat_per_sec (int, optional): + feat_per_sec (int, optional): The number of feature frames per second. Defaults to 100. - max_end_ts (float, optional): + max_end_ts (float, optional): The maximum end timestamp to clip the results. If None, no clipping is applied. Defaults to None. - decimals (int, optional): + decimals (int, optional): The number of decimal places to round the timestamps. Defaults to 2. Returns: - ts (torch.tensor): + ts (torch.tensor): A tensor containing the scaled and rounded timestamps for each subsegment. """ seg_ts = (torch.tensor(subsegments) * feat_per_sec).float() @@ -169,8 +171,9 @@ def get_subsegments_to_timestamps( ts = ts_round.long() ts[:, 1] = ts[:, 0] + ts[:, 1] if max_end_ts is not None: - ts = np.clip(ts, 0, int(max_end_ts*feat_per_sec)) - return ts + ts = np.clip(ts, 0, int(max_end_ts * feat_per_sec)) + return ts + def extract_frame_info_from_rttm(uniq_id, offset, duration, rttm_lines, round_digits=3): """ @@ -190,42 +193,43 @@ def extract_frame_info_from_rttm(uniq_id, offset, duration, rttm_lines, round_di rttm_stt, rttm_end = offset, offset + duration stt_list, end_list, speaker_list, speaker_set = [], [], [], [] sess_to_global_spkids = dict() - + for rttm_line in rttm_lines: start, end, speaker = convert_rttm_line(rttm_line) - + # Skip invalid RTTM lines where the start time is greater than the end time. if start > end: continue - + # Check if the RTTM segment overlaps with the specified segment of interest. if (end > rttm_stt and start < rttm_end) or (start < rttm_end and end > rttm_stt): # Adjust the start and end times to fit within the segment of interest. start, end = max(start, rttm_stt), min(end, rttm_end) else: continue - + # Round the start and end times to the specified number of decimal places. end_list.append(round(end, round_digits)) stt_list.append(round(start, round_digits)) - + # Assign a unique index to each speaker and maintain a mapping. if speaker not in speaker_set: speaker_set.append(speaker) speaker_list.append(speaker_set.index(speaker)) sess_to_global_spkids.update({speaker_set.index(speaker): speaker}) - + rttm_mat = (stt_list, end_list, speaker_list) return rttm_mat, sess_to_global_spkids + def get_frame_targets_from_rttm( - rttm_timestamps: list, - offset: float, - duration: float, - round_digits: int, - feat_per_sec: int, + rttm_timestamps: list, + offset: float, + duration: float, + round_digits: int, + feat_per_sec: int, max_spks: int, - ): +): """ Create a multi-dimensional vector sequence containing speaker timestamp information in RTTM. The unit-length is the frame shift length of the acoustic feature. The feature-level annotations @@ -249,15 +253,17 @@ def get_frame_targets_from_rttm( sorted_speakers = sorted(list(set(speaker_list))) total_fr_len = int(duration * feat_per_sec) if len(sorted_speakers) > max_spks: - logging.warning(f"Number of speakers in RTTM file {len(sorted_speakers)} exceeds the maximum number of speakers: {max_spks}! Only {max_spks} first speakers remain, and this will affect frame metrics!") - feat_level_target = torch.zeros(total_fr_len, max_spks) + logging.warning( + f"Number of speakers in RTTM file {len(sorted_speakers)} exceeds the maximum number of speakers: {max_spks}! Only {max_spks} first speakers remain, and this will affect frame metrics!" + ) + feat_level_target = torch.zeros(total_fr_len, max_spks) for count, (stt, end, spk_rttm_key) in enumerate(zip(stt_list, end_list, speaker_list)): if end < offset or stt > offset + duration: continue stt, end = max(offset, stt), min(offset + duration, end) spk = spk_rttm_key if spk < max_spks: - stt_fr, end_fr = int((stt - offset) * feat_per_sec), int((end - offset)* feat_per_sec) + stt_fr, end_fr = int((stt - offset) * feat_per_sec), int((end - offset) * feat_per_sec) feat_level_target[stt_fr:end_fr, spk] = 1 return feat_level_target @@ -337,7 +343,7 @@ def __init__( self.multiscale_args_dict = multiscale_args_dict self.emb_dir = emb_dir self.round_digits = 2 - self.decim = 10 ** self.round_digits + self.decim = 10**self.round_digits self.soft_label_thres = soft_label_thres self.pairwise_infer = pairwise_infer self.max_spks = 2 @@ -347,7 +353,10 @@ def __init__( self.global_rank = global_rank self.manifest_filepath = manifest_filepath self.multiscale_timestamp_dict = prepare_split_data( - self.manifest_filepath, self.emb_dir, self.multiscale_args_dict, self.global_rank, + self.manifest_filepath, + self.emb_dir, + self.multiscale_args_dict, + self.global_rank, ) def __len__(self): @@ -364,7 +373,7 @@ def assign_labels_to_longer_segs(self, uniq_id, base_scale_clus_label): Unique sample ID for training. base_scale_clus_label (torch.tensor): Tensor variable containing the speaker labels for the base-scale segments. - + Returns: per_scale_clus_label (torch.tensor): Tensor variable containing the speaker labels for each segment in each scale. @@ -415,7 +424,7 @@ def get_diar_target_labels(self, uniq_id, sample, fr_level_target): seg_target_list, base_clus_label = [], [] self.scale_n = len(self.multiscale_timestamp_dict[uniq_id]['scale_dict']) subseg_time_stamp_list = self.multiscale_timestamp_dict[uniq_id]["scale_dict"][self.scale_n - 1]["time_stamps"] - for (seg_stt, seg_end) in subseg_time_stamp_list: + for seg_stt, seg_end in subseg_time_stamp_list: seg_stt_fr, seg_end_fr = int(seg_stt * self.frame_per_sec), int(seg_end * self.frame_per_sec) soft_label_vec_sess = torch.sum(fr_level_target[seg_stt_fr:seg_end_fr, :], axis=0) / ( seg_end_fr - seg_stt_fr @@ -619,7 +628,7 @@ def __init__( self.emb_seq = emb_seq self.clus_label_dict = clus_label_dict self.round_digits = 2 - self.decim = 10 ** self.round_digits + self.decim = 10**self.round_digits self.frame_per_sec = int(1 / window_stride) self.soft_label_thres = soft_label_thres self.pairwise_infer = pairwise_infer @@ -685,7 +694,7 @@ def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target): return None else: seg_target_list = [] - for (seg_stt, seg_end, label_int) in self.clus_label_dict[uniq_id]: + for seg_stt, seg_end, label_int in self.clus_label_dict[uniq_id]: seg_stt_fr, seg_end_fr = int(seg_stt * self.frame_per_sec), int(seg_end * self.frame_per_sec) soft_label_vec = torch.sum(fr_level_target[seg_stt_fr:seg_end_fr, :], axis=0) / ( seg_end_fr - seg_stt_fr @@ -975,6 +984,7 @@ def __init__( def msdd_infer_collate_fn(self, batch): return _msdd_infer_collate_fn(self, batch) + class _AudioToSpeechE2ESpkDiarDataset(Dataset): """ Dataset class that loads a json file containing paths to audio files, @@ -1047,8 +1057,8 @@ def __init__( self.use_asr_style_frame_count = True self.soft_targets = soft_targets self.round_digits = 2 - self.floor_decimal = 10 ** self.round_digits - + self.floor_decimal = 10**self.round_digits + def __len__(self): return len(self.collection) @@ -1085,15 +1095,16 @@ def parse_rttm_for_targets_and_lens(self, uniq_id, rttm_file, offset, duration, rttm_lines = open(rttm_file).readlines() rttm_timestamps, sess_to_global_spkids = extract_frame_info_from_rttm(uniq_id, offset, duration, rttm_lines) - fr_level_target = get_frame_targets_from_rttm(rttm_timestamps=rttm_timestamps, - offset=offset, - duration=duration, - round_digits=self.round_digits, - feat_per_sec=self.feat_per_sec, - max_spks=self.max_spks) + fr_level_target = get_frame_targets_from_rttm( + rttm_timestamps=rttm_timestamps, + offset=offset, + duration=duration, + round_digits=self.round_digits, + feat_per_sec=self.feat_per_sec, + max_spks=self.max_spks, + ) - soft_target_seg = self.get_soft_targets_seg(feat_level_target=fr_level_target, - target_len=target_len) + soft_target_seg = self.get_soft_targets_seg(feat_level_target=fr_level_target, target_len=target_len) if self.soft_targets: step_target = soft_target_seg else: @@ -1128,15 +1139,15 @@ def get_soft_targets_seg(self, feat_level_target, target_len): seg_end_feat = feat_level_target.shape[0] else: seg_end_feat = stride * index - 1 + int(stride / 2) - targets[index] = torch.mean(feat_level_target[seg_stt_feat:seg_end_feat+1, :], axis=0) + targets[index] = torch.mean(feat_level_target[seg_stt_feat : seg_end_feat + 1, :], axis=0) return targets def get_segment_timestamps( self, - duration: float, - offset: float = 0, + duration: float, + offset: float = 0, sample_rate: int = 16000, - ): + ): """ Get start and end time of segments in each scale. @@ -1150,22 +1161,28 @@ def get_segment_timestamps( Number of segments for each scale. This information is used for reshaping embedding batch during forward propagation. """ - subsegments = get_subsegments(offset=offset, - window=round(self.diar_frame_length * 2, self.round_digits), - shift=self.diar_frame_length, - duration=duration, - min_subsegment_duration=self.min_subsegment_duration, - use_asr_style_frame_count=self.use_asr_style_frame_count, - sample_rate=sample_rate, - feat_per_sec=self.feat_per_sec, + subsegments = get_subsegments( + offset=offset, + window=round(self.diar_frame_length * 2, self.round_digits), + shift=self.diar_frame_length, + duration=duration, + min_subsegment_duration=self.min_subsegment_duration, + use_asr_style_frame_count=self.use_asr_style_frame_count, + sample_rate=sample_rate, + feat_per_sec=self.feat_per_sec, ) if self.use_asr_style_frame_count: - effective_dur = np.ceil((1+duration*sample_rate)/int(sample_rate/self.feat_per_sec)).astype(int)/self.feat_per_sec + effective_dur = ( + np.ceil((1 + duration * sample_rate) / int(sample_rate / self.feat_per_sec)).astype(int) + / self.feat_per_sec + ) else: - effective_dur = duration - ts_tensor = get_subsegments_to_timestamps(subsegments, self.feat_per_sec, decimals=2, max_end_ts=(offset+effective_dur)) + effective_dur = duration + ts_tensor = get_subsegments_to_timestamps( + subsegments, self.feat_per_sec, decimals=2, max_end_ts=(offset + effective_dur) + ) target_len = torch.tensor([ts_tensor.shape[0]]) - return target_len + return target_len def __getitem__(self, index): sample = self.collection[index] @@ -1179,24 +1196,25 @@ def __getitem__(self, index): uniq_id = self.get_uniq_id_with_range(sample) audio_signal = self.featurizer.process(sample.audio_file, offset=offset, duration=session_len_sec) - + # We should resolve the length mis-match from the round-off errors: `session_len_sec` and `audio_signal.shape[0]` - session_len_sec = np.floor(audio_signal.shape[0] / self.featurizer.sample_rate * self.floor_decimal)/self.floor_decimal - audio_signal = audio_signal[:round(self.featurizer.sample_rate*session_len_sec)] - + session_len_sec = ( + np.floor(audio_signal.shape[0] / self.featurizer.sample_rate * self.floor_decimal) / self.floor_decimal + ) + audio_signal = audio_signal[: round(self.featurizer.sample_rate * session_len_sec)] + audio_signal_length = torch.tensor(audio_signal.shape[0]).long() audio_signal, audio_signal_length = audio_signal.to('cpu'), audio_signal_length.to('cpu') target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate) - targets = self.parse_rttm_for_targets_and_lens(uniq_id=uniq_id, - rttm_file=sample.rttm_file, - offset=offset, - duration=session_len_sec, - target_len=target_len) + targets = self.parse_rttm_for_targets_and_lens( + uniq_id=uniq_id, rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len + ) return audio_signal, audio_signal_length, targets, target_len + def _eesd_train_collate_fn(self, batch): """ - Collate a batch of variables needed for training the end-to-end speaker diarization (EESD) model + Collate a batch of variables needed for training the end-to-end speaker diarization (EESD) model from raw waveforms to diarization labels. The following variables are included in the training/validation batch: Args: @@ -1249,24 +1267,25 @@ def _eesd_train_collate_fn(self, batch): targets = torch.stack(targets_list) return audio_signal, feature_length, targets, target_lens + class AudioToSpeechE2ESpkDiarDataset(_AudioToSpeechE2ESpkDiarDataset): """ Dataset class for loading a JSON file containing paths to audio files, RTTM (Rich Transcription Time Marked) files, and the number of speakers. - This class is designed for training or fine-tuning a speaker embedding + This class is designed for training or fine-tuning a speaker embedding extractor and diarization decoder simultaneously. The JSON manifest file should have entries in the following format: - + Example: { - "audio_filepath": "/path/to/audio_0.wav", + "audio_filepath": "/path/to/audio_0.wav", "num_speakers": 2, "rttm_filepath": "/path/to/diar_label_0.rttm" } ... { - "audio_filepath": "/path/to/audio_n.wav", + "audio_filepath": "/path/to/audio_n.wav", "num_speakers": 2, "rttm_filepath": "/path/to/diar_label_n.rttm" } @@ -1283,7 +1302,7 @@ class AudioToSpeechE2ESpkDiarDataset(_AudioToSpeechE2ESpkDiarDataset): featurizer: Instance of a featurizer for generating features from the raw waveform. window_stride (float): - Window stride (in seconds) for extracting acoustic features, used to calculate + Window stride (in seconds) for extracting acoustic features, used to calculate the number of feature frames. global_rank (int): Global rank of the current process (used for distributed training). @@ -1294,6 +1313,7 @@ class AudioToSpeechE2ESpkDiarDataset(_AudioToSpeechE2ESpkDiarDataset): eesd_train_collate_fn(batch): Collates a batch of data for end-to-end speaker diarization training. """ + def __init__( self, *, @@ -1318,4 +1338,4 @@ def __init__( ) def eesd_train_collate_fn(self, batch): - return _eesd_train_collate_fn(self, batch) \ No newline at end of file + return _eesd_train_collate_fn(self, batch) diff --git a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py index e223e4ef2a56..8d11c4c1167d 100644 --- a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py +++ b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py @@ -18,11 +18,12 @@ from lhotse.dataset import AudioSamples from lhotse.dataset.collation import collate_matrices -from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( - speaker_to_target, - get_hidden_length_from_sample_length, + get_hidden_length_from_sample_length, + speaker_to_target, ) +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType + class LhotseAudioToSpeechE2ESpkDiarDataset(torch.utils.data.Dataset): """ @@ -43,16 +44,18 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: 'target_length': NeuralType(tuple('B'), LengthsType()), 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), } - + def __init__(self, cfg): super().__init__() self.load_audio = AudioSamples(fault_tolerant=True) self.cfg = cfg self.num_speakers = self.cfg.get('num_speakers', 4) - self.num_sample_per_mel_frame = int(self.cfg.get('window_stride', 0.01) * self.cfg.get('sample_rate', 16000)) # 160 + self.num_sample_per_mel_frame = int( + self.cfg.get('window_stride', 0.01) * self.cfg.get('sample_rate', 16000) + ) # 160 self.num_mel_frame_per_target_frame = int(self.cfg.get('subsampling_factor', 8)) - self.spk_tar_all_zero = self.cfg.get('spk_tar_all_zero',False) - + self.spk_tar_all_zero = self.cfg.get('spk_tar_all_zero', False) + def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: audio, audio_lens, cuts = self.load_audio(cuts) speaker_activities = [] @@ -63,14 +66,16 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: num_sample_per_mel_frame=self.num_sample_per_mel_frame, num_mel_frame_per_asr_frame=self.num_mel_frame_per_target_frame, spk_tar_all_zero=self.spk_tar_all_zero, - boundary_segments=True + boundary_segments=True, ) speaker_activities.append(speaker_activity) targets = collate_matrices(speaker_activities).to(audio.dtype) target_lens_list = [] for audio_len in audio_lens: - target_fr_len = get_hidden_length_from_sample_length(audio_len, self.num_sample_per_mel_frame, self.num_mel_frame_per_target_frame) + target_fr_len = get_hidden_length_from_sample_length( + audio_len, self.num_sample_per_mel_frame, self.num_mel_frame_per_target_frame + ) target_lens_list.append([target_fr_len]) target_lens = torch.tensor(target_lens_list) - + return audio, audio_lens, targets, target_lens diff --git a/nemo/collections/asr/metrics/der.py b/nemo/collections/asr/metrics/der.py index 16f62bbe9e4c..000b839ceb46 100644 --- a/nemo/collections/asr/metrics/der.py +++ b/nemo/collections/asr/metrics/der.py @@ -36,12 +36,12 @@ def get_partial_ref_labels(pred_labels: List[str], ref_labels: List[str]) -> List[str]: """ - For evaluation of online diarization performance, generate partial reference labels + For evaluation of online diarization performance, generate partial reference labels from the last prediction time. Args: pred_labels (list[str]): list of partial prediction labels - ref_labels (list[str]): list of full reference labels + ref_labels (list[str]): list of full reference labels Returns: ref_labels_out (list[str]): list of partial reference labels @@ -84,8 +84,8 @@ def get_online_DER_stats( For evaluation of online diarization performance, add cumulative, average, and maximum DER/CER. Args: - DER (float): Diarization Error Rate from the start to the current point - CER (float): Confusion Error Rate from the start to the current point + DER (float): Diarization Error Rate from the start to the current point + CER (float): Confusion Error Rate from the start to the current point FA (float): False Alarm from the start to the current point MISS (float): Miss rate from the start to the current point diar_eval_count (int): Number of evaluation sessions @@ -130,13 +130,13 @@ def uem_timeline_from_file(uem_file, uniq_name=''): def score_labels( - AUDIO_RTTM_MAP, - all_reference, - all_hypothesis, - all_uem: List[List[float]]=None, - collar:float=0.25, - ignore_overlap: bool=True, - verbose: bool = True + AUDIO_RTTM_MAP, + all_reference, + all_hypothesis, + all_uem: List[List[float]] = None, + collar: float = 0.25, + ignore_overlap: bool = True, + verbose: bool = True, ) -> Optional[Tuple[DiarizationErrorRate, Dict]]: """ Calculate DER, CER, FA and MISS rate from hypotheses and references. Hypothesis results are @@ -170,7 +170,9 @@ def score_labels( if len(ref_labels.labels()) == len(hyp_labels.labels()): correct_spk_count += 1 if verbose and len(ref_labels.labels()) != len(hyp_labels.labels()): - logging.info(f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}") + logging.info( + f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}" + ) uem_obj = None if all_uem is not None: metric(ref_labels, hyp_labels, uem=all_uem[idx], detailed=True) @@ -189,7 +191,7 @@ def score_labels( CER = metric['confusion'] / metric['total'] FA = metric['false alarm'] / metric['total'] MISS = metric['missed detection'] / metric['total'] - + itemized_errors = (DER, CER, FA, MISS) if verbose: @@ -386,7 +388,7 @@ def calculate_session_cpWER( # Calculate WER for each speaker in hypothesis with reference # There are (number of hyp speakers) x (number of ref speakers) combinations lsa_wer_list = [] - for (spk_hyp_trans, spk_ref_trans) in all_pairs: + for spk_hyp_trans, spk_ref_trans in all_pairs: spk_wer = word_error_rate(hypotheses=[spk_hyp_trans], references=[spk_ref_trans]) lsa_wer_list.append(spk_wer) @@ -440,7 +442,7 @@ def concat_perm_word_error_rate( f"{len(spk_hypotheses)} and {len(spk_references)} correspondingly" ) cpWER_values, hyps_spk, refs_spk = [], [], [] - for (spk_hypothesis, spk_reference) in zip(spk_hypotheses, spk_references): + for spk_hypothesis, spk_reference in zip(spk_hypotheses, spk_references): cpWER, min_hypothesis, concat_reference = calculate_session_cpWER(spk_hypothesis, spk_reference) cpWER_values.append(cpWER) hyps_spk.append(min_hypothesis) diff --git a/nemo/collections/asr/metrics/multi_binary_acc.py b/nemo/collections/asr/metrics/multi_binary_acc.py index 72781143208b..13e57b43bb0b 100644 --- a/nemo/collections/asr/metrics/multi_binary_acc.py +++ b/nemo/collections/asr/metrics/multi_binary_acc.py @@ -68,6 +68,7 @@ def on_validation_epoch_end(self): f1_score (torch.Tensor): F1 score calculated from the predicted value and binarized target values. """ + full_state_update = False def __init__(self, dist_sync_on_step=False): @@ -80,7 +81,9 @@ def __init__(self, dist_sync_on_step=False): self.add_state("positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) self.eps = 1e-6 - def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor, cumulative=False) -> torch.Tensor: + def update( + self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor, cumulative=False + ) -> torch.Tensor: with torch.no_grad(): preds_list = [preds[k, : signal_lengths[k], :] for k in range(preds.shape[0])] targets_list = [targets[k, : signal_lengths[k], :] for k in range(targets.shape[0])] @@ -99,7 +102,7 @@ def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: tor self.false_negative_count += torch.sum(torch.logical_and(self.false, self.negative)) self.total_correct_counts += torch.sum(self.preds.round().bool() == self.targets.round().bool()) self.total_sample_counts += torch.prod(torch.tensor(self.targets.shape)) - else: + else: self.positive_count = torch.sum(self.preds.round().bool() == True) self.true_positive_count = torch.sum(torch.logical_and(self.true, self.positive)) self.false_positive_count = torch.sum(torch.logical_and(self.false, self.positive)) diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index 31194d8849f0..2573a7ac84b4 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -19,7 +19,6 @@ EncDecClassificationModel, EncDecFrameClassificationModel, ) -from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel @@ -36,5 +35,6 @@ from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel +from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.ssl_models import SpeechEncDecSelfSupervisedModel from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index 50cdf6214d5b..665f439b0ad0 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -12,28 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time import itertools import random -import torch +import time from collections import OrderedDict from typing import Dict, List, Optional, Union + +import torch from hydra.utils import instantiate from omegaconf import DictConfig from pytorch_lightning import Trainer from tqdm import tqdm -from nemo.core.classes import ModelPT -from nemo.core.classes.common import PretrainedModelInfo -from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType -from nemo.core.neural_types.elements import ProbsType -from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations -from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config -from nemo.collections.asr.data.audio_to_diar_label_lhotse import LhotseAudioToSpeechE2ESpkDiarDataset + from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset +from nemo.collections.asr.data.audio_to_diar_label_lhotse import LhotseAudioToSpeechE2ESpkDiarDataset from nemo.collections.asr.metrics.multi_binary_acc import MultiBinaryAccuracy from nemo.collections.asr.models.asr_model import ExportableEncDecModel from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer -from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_pil_targets, get_ats_targets +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_ats_targets, get_pil_targets +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.core.classes import ModelPT +from nemo.core.classes.common import PretrainedModelInfo +from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType +from nemo.core.neural_types.elements import ProbsType from nemo.utils import logging try: @@ -45,10 +47,12 @@ def autocast(enabled=None): yield -# torch.backends.cudnn.enabled = False + +# torch.backends.cudnn.enabled = False __all__ = ['SortformerEncLabelModel'] + class SortformerEncLabelModel(ModelPT, ExportableEncDecModel): """ Encoder class for Sortformer diarization model. @@ -80,7 +84,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): random.seed(42) self._trainer = trainer if trainer else None self._cfg = cfg - + if self._trainer: self.world_size = trainer.num_nodes * trainer.num_devices else: @@ -109,27 +113,27 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.streaming_mode = self._cfg.get("streaming_mode", False) self.save_hyperparameters("cfg") self._init_eval_metrics() - + speaker_inds = list(range(self._cfg.max_num_of_spks)) - self.speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) # Get all permutations - + self.speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) # Get all permutations + def _init_loss_weights(self): pil_weight = self._cfg.get("pil_weight", 0.0) ats_weight = self._cfg.get("ats_weight", 1.0) if pil_weight + ats_weight == 0: raise ValueError(f"weights for PIL {pil_weight} and ATS {ats_weight} cannot sum to 0") - self.pil_weight = pil_weight/(pil_weight + ats_weight) - self.ats_weight = ats_weight/(pil_weight + ats_weight) + self.pil_weight = pil_weight / (pil_weight + ats_weight) + self.ats_weight = ats_weight / (pil_weight + ats_weight) logging.info(f"Normalized weights for PIL {self.pil_weight} and ATS {self.ats_weight}") - + def _init_eval_metrics(self): - """ + """ If there is no label, then the evaluation metrics will be based on Permutation Invariant Loss (PIL). """ self._accuracy_test = MultiBinaryAccuracy() self._accuracy_train = MultiBinaryAccuracy() self._accuracy_valid = MultiBinaryAccuracy() - + self._accuracy_test_ats = MultiBinaryAccuracy() self._accuracy_train_ats = MultiBinaryAccuracy() self._accuracy_valid_ats = MultiBinaryAccuracy() @@ -137,11 +141,11 @@ def _init_eval_metrics(self): def _reset_train_metrics(self): self._accuracy_train.reset() self._accuracy_train_ats.reset() - + def _reset_valid_metrics(self): self._accuracy_valid.reset() self._accuracy_valid_ats.reset() - + def __setup_dataloader_from_config(self, config): # Switch to lhotse dataloader if specified in the config if config.get("use_lhotse"): @@ -168,7 +172,7 @@ def __setup_dataloader_from_config(self, config): global_rank = 0 time_flag = time.time() logging.info("AAB: Starting Dataloader Instance loading... Step A") - + dataset = AudioToSpeechE2ESpkDiarDataset( manifest_filepath=config.manifest_filepath, soft_label_thres=config.soft_label_thres, @@ -179,11 +183,13 @@ def __setup_dataloader_from_config(self, config): global_rank=global_rank, soft_targets=config.soft_targets if 'soft_targets' in config else False, ) - logging.info(f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader step B: {time.time() - time_flag}") + logging.info( + f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader step B: {time.time() - time_flag}" + ) self.data_collection = dataset.collection self.collate_ds = dataset - + dataloader_instance = torch.utils.data.DataLoader( dataset=dataset, batch_size=config.batch_size, @@ -195,15 +201,21 @@ def __setup_dataloader_from_config(self, config): ) logging.info(f"AAC: Dataloader Instance loading is done ETA Step B done: {time.time() - time_flag}") return dataloader_instance - + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): - self._train_dl = self.__setup_dataloader_from_config(config=train_data_config,) + self._train_dl = self.__setup_dataloader_from_config( + config=train_data_config, + ) def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]): - self._validation_dl = self.__setup_dataloader_from_config(config=val_data_layer_config,) - + self._validation_dl = self.__setup_dataloader_from_config( + config=val_data_layer_config, + ) + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): - self._test_dl = self.__setup_dataloader_from_config(config=test_data_config,) + self._test_dl = self.__setup_dataloader_from_config( + config=test_data_config, + ) def test_dataloader(self): if self._test_dl is not None: @@ -227,11 +239,11 @@ def output_types(self) -> Dict[str, NeuralType]: "preds": NeuralType(('B', 'T', 'C'), ProbsType()), } ) - + def frontend_encoder(self, processed_signal, processed_signal_length): - """ + """ Generate encoder outputs from frontend encoder. - + Args: process_signal (torch.Tensor): tensor containing audio-feature (mel spectrogram, mfcc, etc.) processed_signal_length (torch.Tensor): tensor containing lengths of audio signal in integers @@ -248,7 +260,7 @@ def frontend_encoder(self, processed_signal, processed_signal_length): emb_seq = emb_seq.transpose(1, 2) if self._cfg.encoder.d_model != self._cfg.tf_d_model: self.sortformer_modules.encoder_proj = self.sortformer_modules.encoder_proj.to(self.device) - emb_seq = self.sortformer_modules.encoder_proj(emb_seq) + emb_seq = self.sortformer_modules.encoder_proj(emb_seq) return emb_seq, emb_seq_length def forward_infer(self, emb_seq): @@ -258,7 +270,7 @@ def forward_infer(self, emb_seq): Args: emb_seq (torch.Tensor): tensor containing FastConformer encoder states (embedding vectors). Dimension: (batch_size, diar_frame_count, emb_dim) - + Returns: preds (torch.Tensor): Sorted tensor containing Sigmoid values for predicted speaker labels. Dimension: (batch_size, diar_frame_count, num_speakers) @@ -269,9 +281,9 @@ def forward_infer(self, emb_seq): trans_emb_seq = self.transformer_encoder(encoder_states=emb_seq, encoder_mask=encoder_mask) preds = self.sortformer_modules.forward_speaker_sigmoids(trans_emb_seq) return preds - + def process_signal(self, audio_signal, audio_signal_length): - """ + """ Extract audio features from time-series signal for further processing in the model. This function performs the following steps: @@ -293,43 +305,49 @@ def process_signal(self, audio_signal, audio_signal_length): Shape: (batch_size,) """ audio_signal = audio_signal.to(self.device) - audio_signal = (1/(audio_signal.max()+self.eps)) * audio_signal - processed_signal, processed_signal_length = self.preprocessor(input_signal=audio_signal, length=audio_signal_length) + audio_signal = (1 / (audio_signal.max() + self.eps)) * audio_signal + processed_signal, processed_signal_length = self.preprocessor( + input_signal=audio_signal, length=audio_signal_length + ) return processed_signal, processed_signal_length - + def forward( - self, - audio_signal, - audio_signal_length, + self, + audio_signal, + audio_signal_length, ): """ Forward pass for training and inference. - + Args: audio_signal (torch.Tensor): tensor containing audio waveform Dimension: (batch_size, num_samples) audio_signal_length (torch.Tensor): tensor containing lengths of audio waveforms Dimension: (batch_size,) - + Returns: preds (torch.Tensor): Sorted tensor containing predicted speaker labels Dimension: (batch_size, diar_frame_count, num_speakers) encoder_states_list (list): List containing total speaker memory for each step for debugging purposes Dimension: [(batch_size, diar_frame_count, inner dim), ] """ - processed_signal, processed_signal_length = self.process_signal(audio_signal=audio_signal, audio_signal_length=audio_signal_length) - processed_signal = processed_signal[:, :, :processed_signal_length.max()] + processed_signal, processed_signal_length = self.process_signal( + audio_signal=audio_signal, audio_signal_length=audio_signal_length + ) + processed_signal = processed_signal[:, :, : processed_signal_length.max()] if self._cfg.get("streaming_mode", False): raise NotImplementedError("Streaming mode is not implemented yet.") else: - emb_seq, _ = self.frontend_encoder(processed_signal=processed_signal, processed_signal_length=processed_signal_length) + emb_seq, _ = self.frontend_encoder( + processed_signal=processed_signal, processed_signal_length=processed_signal_length + ) preds = self.forward_infer(emb_seq) return preds - + def _get_aux_train_evaluations(self, preds, targets, target_lens): - """ + """ Compute auxiliary training evaluations including losses and metrics. - + This function calculates various losses and metrics for the training process, including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) based evaluations. @@ -366,7 +384,7 @@ def _get_aux_train_evaluations(self, preds, targets, target_lens): 'train_precision': train_precision, 'train_recall': train_recall, 'train_f1_acc_ats': train_f1_acc_ats, - } + } return train_metrics def training_step(self, batch: list) -> dict: @@ -392,7 +410,7 @@ def training_step(self, batch: list) -> dict: return {'loss': train_metrics['loss']} def _get_aux_validation_evaluations(self, preds, targets, target_lens): - """ + """ Compute auxiliary validation evaluations including losses and metrics. This function calculates various losses and metrics for the validation process, including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) @@ -482,7 +500,7 @@ def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): val_f1_acc_ats_mean = torch.stack([x['val_f1_acc_ats'] for x in outputs]).mean() self._reset_valid_metrics() - + multi_val_metrics = { 'val_loss': val_loss_mean, 'val_ats_loss': val_ats_loss_mean, @@ -495,9 +513,9 @@ def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): return {'log': multi_val_metrics} def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target_lens): - """ + """ Compute auxiliary validation evaluations including losses and metrics. - + This function calculates various losses and metrics for the validation process, including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) based evaluations. @@ -525,19 +543,29 @@ def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target self._accuracy_test_ats(preds, targets_ats, target_lens) f1_acc_ats, precision_ats, recall_ats = self._accuracy_test_ats.compute() self.batch_f1_accs_ats_list.append(f1_acc_ats) - logging.info(f"batch {batch_idx}: f1_acc_ats={f1_acc_ats}, precision_ats={precision_ats}, recall_ats={recall_ats}") + logging.info( + f"batch {batch_idx}: f1_acc_ats={f1_acc_ats}, precision_ats={precision_ats}, recall_ats={recall_ats}" + ) self._accuracy_test.reset() self._accuracy_test_ats.reset() - def test_batch(self,): - """ + def test_batch( + self, + ): + """ Perform batch testing on the model. - + This method iterates through the test data loader, making predictions for each batch, and calculates various evaluation metrics. It handles both single and multi-sample batches. """ - self.preds_total_list, self.batch_f1_accs_list, self.batch_precision_list, self.batch_recall_list, self.batch_f1_accs_ats_list = [], [], [], [], [] + ( + self.preds_total_list, + self.batch_f1_accs_list, + self.batch_precision_list, + self.batch_recall_list, + self.batch_f1_accs_ats_list, + ) = ([], [], [], [], []) with torch.no_grad(): for batch_idx, batch in enumerate(tqdm(self._test_dl)): @@ -549,7 +577,7 @@ def test_batch(self,): audio_signal_length=audio_signal_length, ) preds = preds.detach().to('cpu') - if preds.shape[0] == 1: # batch size = 1 + if preds.shape[0] == 1: # batch size = 1 self.preds_total_list.append(preds) else: self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0])) @@ -562,5 +590,7 @@ def test_batch(self,): logging.info(f"Batch Recall MEAN: {torch.mean(torch.tensor(self.batch_recall_list))}") logging.info(f"Batch ATS F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_ats_list))}") - def diarize(self,): + def diarize( + self, + ): raise NotImplementedError diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py index 823cf98590e7..6ed29d3e6a70 100644 --- a/nemo/collections/asr/modules/sortformer_modules.py +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -55,6 +55,7 @@ class SortformerModules(NeuralModule, Exportable): If 'cos_sim', cosine similarity values are used for the input of the sequence models. If 'elem_prod', element-wise product values are used for the input of the sequence models. """ + def init_weights(self, m): if type(m) == nn.Linear: torch.nn.init.xavier_uniform_(m.weight) @@ -91,9 +92,9 @@ def length_to_mask(self, context_embs): mask (torch.Tensor): tensor of shape (batch_size, max_len) containing 0's in the padded region and 1's elsewhere """ - lengths = torch.tensor([context_embs.shape[1]] * context_embs.shape[0]) + lengths = torch.tensor([context_embs.shape[1]] * context_embs.shape[0]) batch_size = context_embs.shape[0] - max_len=context_embs.shape[1] + max_len = context_embs.shape[1] # create a tensor with the shape (batch_size, 1) filled with ones row_vector = torch.arange(max_len).unsqueeze(0).expand(batch_size, -1).to(lengths.device) # create a tensor with the shape (batch_size, max_len) filled with lengths @@ -101,7 +102,7 @@ def length_to_mask(self, context_embs): # create a mask by comparing the row vector and length matrix mask = row_vector < length_matrix return mask.float().to(context_embs.device) - + def forward_speaker_sigmoids(self, hidden_out): hidden_out = self.dropout(F.relu(hidden_out)) hidden_out = self.first_hidden_to_hidden(hidden_out) diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index a1d34e1f7480..a52271a5e83b 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -12,39 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import re +import concurrent.futures import copy +import itertools +import logging import math +import os import random -import logging -import itertools -from copy import deepcopy -import concurrent.futures -from cytoolz import groupby +import re from collections import defaultdict -from typing import Dict, Optional, Tuple, List +from copy import deepcopy +from typing import Dict, List, Optional, Tuple import numpy as np import soundfile -from tqdm import tqdm -from scipy.stats import norm - import torch.utils.data +from cytoolz import groupby +from lhotse import AudioSource, Recording, SupervisionSegment, SupervisionSet, dill_enabled +from lhotse.cut import CutSet, MixedCut, MixTrack, MonoCut from lhotse.cut.set import mix -from lhotse.cut import CutSet, MixedCut, MonoCut, MixTrack -from lhotse import SupervisionSet, SupervisionSegment, dill_enabled, AudioSource, Recording from lhotse.utils import uuid4 +from scipy.stats import norm +from tqdm import tqdm -def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres:float = 0.5) -> torch.Tensor: - """ + +def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres: float = 0.5) -> torch.Tensor: + """ Finds the first nonzero value in the matrix, discretizing it to the specified maximum capacity. - + Args: mat (Tensor): A torch tensor representing the matrix. max_cap_val (int): The maximum capacity to which the matrix values will be discretized. thres (float): The threshold value for discretizing the matrix values. - + Returns: mask_max_indices (Tensor): A torch tensor representing the discretized matrix with the first nonzero value in each row. """ @@ -61,6 +61,7 @@ def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres:float = 0.5) -> mask_max_indices[mask_max_values == 0] = max_cap_val return mask_max_indices + def find_best_permutation(match_score: torch.Tensor, speaker_permutations: torch.Tensor) -> torch.Tensor: """ Finds the best permutation indices based on the match score. @@ -78,9 +79,12 @@ def find_best_permutation(match_score: torch.Tensor, speaker_permutations: torch batch_best_perm = torch.argmax(match_score, axis=1) rep_speaker_permutations = speaker_permutations.repeat(batch_best_perm.shape[0], 1).to(match_score.device) perm_size = speaker_permutations.shape[0] - global_inds_vec = torch.arange(0, perm_size * batch_best_perm.shape[0], perm_size).to(batch_best_perm.device) + batch_best_perm + global_inds_vec = ( + torch.arange(0, perm_size * batch_best_perm.shape[0], perm_size).to(batch_best_perm.device) + batch_best_perm + ) return rep_speaker_permutations[global_inds_vec.to(rep_speaker_permutations.device), :] + def reconstruct_labels(labels: torch.Tensor, batch_perm_inds: torch.Tensor) -> torch.Tensor: """ Reconstructs the labels using the best permutation indices with matrix operations. @@ -103,12 +107,13 @@ def reconstruct_labels(labels: torch.Tensor, batch_perm_inds: torch.Tensor) -> t reconstructed_labels = torch.gather(labels, 2, batch_perm_inds_exp) return reconstructed_labels + def get_ats_targets( - labels: torch.Tensor, - preds: torch.Tensor, - speaker_permutations: torch.Tensor, - thres: float = 0.5, - tolerance: float = 0 + labels: torch.Tensor, + preds: torch.Tensor, + speaker_permutations: torch.Tensor, + thres: float = 0.5, + tolerance: float = 0, ) -> torch.Tensor: """ Sorts labels and predictions to get the optimal of all arrival-time ordered permutations. @@ -128,25 +133,36 @@ def get_ats_targets( Shape: (batch_size, num_frames, num_speakers) """ # Find the first nonzero frame index for each speaker in each batch - nonzero_ind = find_first_nonzero(mat=labels, max_cap_val=labels.shape[1], thres=thres) # (batch_size, num_speakers) - + nonzero_ind = find_first_nonzero( + mat=labels, max_cap_val=labels.shape[1], thres=thres + ) # (batch_size, num_speakers) + # Sort the first nonzero frame indices for arrival-time ordering sorted_values = torch.sort(nonzero_ind)[0] # (batch_size, num_speakers) perm_size = speaker_permutations.shape[0] # Scalar value (num_permutations) permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_frames, num_permutations, num_speakers) - permed_nonzero_ind = find_first_nonzero(mat=permed_labels, max_cap_val=labels.shape[1]) # (batch_size, num_permutations, num_speakers) + permed_nonzero_ind = find_first_nonzero( + mat=permed_labels, max_cap_val=labels.shape[1] + ) # (batch_size, num_permutations, num_speakers) # Compare the first frame indices of sorted labels with those of the permuted labels using tolerance - perm_compare = torch.abs(sorted_values.unsqueeze(1) - permed_nonzero_ind) <= tolerance # (batch_size, num_permutations, num_speakers) + perm_compare = ( + torch.abs(sorted_values.unsqueeze(1) - permed_nonzero_ind) <= tolerance + ) # (batch_size, num_permutations, num_speakers) perm_mask = torch.all(perm_compare, dim=2).float() # (batch_size, num_permutations) - preds_rep = torch.unsqueeze(preds, 2).repeat(1, 1, perm_size, 1) # Exapnd the preds: (batch_size, num_frames, num_permutations, num_speakers) + preds_rep = torch.unsqueeze(preds, 2).repeat( + 1, 1, perm_size, 1 + ) # Exapnd the preds: (batch_size, num_frames, num_permutations, num_speakers) # Compute the match score for each permutation by comparing permuted labels with preds - match_score = torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) * perm_mask # (batch_size, num_permutations) + match_score = ( + torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) * perm_mask + ) # (batch_size, num_permutations) batch_perm_inds = find_best_permutation(match_score, speaker_permutations) # (batch_size, num_speakers) max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_frames, num_speakers) return max_score_permed_labels # (batch_size, num_frames, num_speakers) + def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutations: torch.Tensor) -> torch.Tensor: """ Sorts labels and predictions to get the optimal permutation based on the match score. @@ -166,23 +182,26 @@ def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutati perm_size = speaker_permutations.shape[0] # Scalar value (num_permutations) permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_classes, num_permutations, num_speakers) # Repeat preds to match permutations for comparison - preds_rep = torch.unsqueeze(preds, 2).repeat(1, 1, speaker_permutations.shape[0], 1) # (batch_size, num_speakers, num_permutations, num_classes) + preds_rep = torch.unsqueeze(preds, 2).repeat( + 1, 1, speaker_permutations.shape[0], 1 + ) # (batch_size, num_speakers, num_permutations, num_classes) match_score = torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) # (batch_size, num_permutations) batch_perm_inds = find_best_permutation(match_score, speaker_permutations) # (batch_size, num_speakers) # Reconstruct labels based on the best permutation for each batch max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_speakers, num_classes) return max_score_permed_labels # (batch_size, num_speakers, num_classes) + def apply_spk_mapping(diar_preds: torch.Tensor, spk_mappings: torch.Tensor) -> torch.Tensor: - """ + """ Applies a speaker mapping to diar predictions. Args: - diar_preds (Tensor): The diar predictions tensor. + diar_preds (Tensor): The diar predictions tensor. Dimension: (batch_size, num_frames, num_speakers) spk_mappings (Tensor): The speaker mappings tensor. Dimension: (batch_size, num_speakers) - + Returns: permuted_diar_preds (Tensor): The permuted diar predictions tensor with the given speaker mappings. """ @@ -190,15 +209,18 @@ def apply_spk_mapping(diar_preds: torch.Tensor, spk_mappings: torch.Tensor) -> t permuted_diar_preds = torch.gather(diar_preds, 2, expanded_mappings) return permuted_diar_preds -def shuffle_spk_mapping(cuts: list, num_speakers: int, shuffle_spk_mapping: bool = False, pattern= r'<\|spltoken\d+\|>') -> Tuple[CutSet, torch.Tensor]: - """ + +def shuffle_spk_mapping( + cuts: list, num_speakers: int, shuffle_spk_mapping: bool = False, pattern=r'<\|spltoken\d+\|>' +) -> Tuple[CutSet, torch.Tensor]: + """ Applies a shuffle mapping to speaker text labels in the cuts. Example: Original cut.text: - "<|spltoken0|> we do shuffle <|spltoken1|> and map speakers <|spltoken0|> yes <|spltoken2|> we keep dimensions" + "<|spltoken0|> we do shuffle <|spltoken1|> and map speakers <|spltoken0|> yes <|spltoken2|> we keep dimensions" Speaker Mapping: [3, 0, 1, 2] Shuffled cut.text: - "<|spltoken3|> we do shuffle <|spltoken0|> and map speakers <|spltoken3|> yes <|spltoken1|> we keep dimensions" + "<|spltoken3|> we do shuffle <|spltoken0|> and map speakers <|spltoken3|> yes <|spltoken1|> we keep dimensions" Args: cuts (List[MonoCut, MixedCut]): A list of Cut instances. @@ -208,11 +230,11 @@ def shuffle_spk_mapping(cuts: list, num_speakers: int, shuffle_spk_mapping: bool Returns: cuts (list): The updated CutSet with shuffled speaker mappings. - spk_mappings (Tensor): + spk_mappings (Tensor): If shuffle_speaker_mapping is True, shuffled speaker mappings in batch. If shuffle_speaker_mapping is False, speaker mappings in batch is not permuted and returns torch.arange() values. - """ - batch_size = len(cuts) + """ + batch_size = len(cuts) if shuffle_spk_mapping: permuted_indices = torch.rand(batch_size, num_speakers).argsort(dim=1) spk_mappings = torch.gather(torch.arange(num_speakers).repeat(batch_size, 1), 1, permuted_indices) @@ -220,9 +242,9 @@ def shuffle_spk_mapping(cuts: list, num_speakers: int, shuffle_spk_mapping: bool left_str, right_str = str_pattern.split('d+')[0], str_pattern.split('d+')[1] for idx, cut in enumerate(cuts): word_list = [] - for word in deepcopy(cut.text).split(): + for word in deepcopy(cut.text).split(): if len(re.findall(pattern, word)) > 0: - spk_token_int = int(word.replace(left_str,'').replace(right_str, '')) + spk_token_int = int(word.replace(left_str, '').replace(right_str, '')) new_spk = spk_mappings[idx][spk_token_int] word_list.append(f'{left_str}{new_spk}{right_str}') else: @@ -230,16 +252,18 @@ def shuffle_spk_mapping(cuts: list, num_speakers: int, shuffle_spk_mapping: bool cuts[idx].supervisions[0].text = ' '.join(word_list) else: spk_mappings = torch.arange(num_speakers).unsqueeze(0).repeat(batch_size, 1) - return cuts, spk_mappings + return cuts, spk_mappings + def find_segments_from_rttm( - recording_id: str, - rttms, - start_after: float, - end_before: float, - adjust_offset: bool=True, - tolerance: float=0.001): - """ + recording_id: str, + rttms, + start_after: float, + end_before: float, + adjust_offset: bool = True, + tolerance: float = 0.001, +): + """ Finds segments from the given rttm file. This function is designed to replace rttm @@ -250,35 +274,36 @@ def find_segments_from_rttm( end_before (float): The end time before which segments are selected. adjust_offset (bool): Whether to adjust the offset of the segments. tolerance (float): The tolerance for time matching. 0.001 by default. - + Returns: segments (List[SupervisionSegment]): A list of SupervisionSegment instances. """ segment_by_recording_id = rttms._segments_by_recording_id if segment_by_recording_id is None: from cytoolz import groupby + segment_by_recording_id = groupby(lambda seg: seg.recording_id, rttms) return [ - # We only modify the offset - the duration remains the same, as we're only shifting the segment - # relative to the Cut's start, and not truncating anything. - segment.with_offset(-start_after) if adjust_offset else segment - for segment in segment_by_recording_id.get(recording_id, []) - if segment.start < end_before + tolerance - and segment.end > start_after + tolerance - ] + # We only modify the offset - the duration remains the same, as we're only shifting the segment + # relative to the Cut's start, and not truncating anything. + segment.with_offset(-start_after) if adjust_offset else segment + for segment in segment_by_recording_id.get(recording_id, []) + if segment.start < end_before + tolerance and segment.end > start_after + tolerance + ] + def speaker_to_target( a_cut, - num_speakers: int = 4, - num_sample_per_mel_frame: int = 160, - num_mel_frame_per_asr_frame: int = 8, + num_speakers: int = 4, + num_sample_per_mel_frame: int = 160, + num_mel_frame_per_asr_frame: int = 8, spk_tar_all_zero: bool = False, boundary_segments: bool = False, soft_label: bool = False, ignore_num_spk_mismatch: bool = True, soft_thres: float = 0.5, - ): +): ''' Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape (num_speaker, hidden_length) This function is needed for speaker diarization with ASR model trainings. @@ -292,7 +317,7 @@ def speaker_to_target( boundary_segments (bool): set to True to include segments containing the boundary of the cut, False by default for multi-speaker ASR training soft_label (bool): set to True to use soft label that enables values in [0, 1] range, False by default and leads to binary labels. ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. - + Returns: mask (Tensor): speaker mask with shape (num_speaker, hidden_lenght) ''' @@ -306,14 +331,18 @@ def speaker_to_target( offsets = [0] else: raise ValueError(f"Unsupported cut type type{a_cut}: only MixedCut and MonoCut are supported") - + segments_total = [] for i, cut in enumerate(cut_list): rttms = SupervisionSet.from_rttm(cut.rttm_filepath) - if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included - segments_iterator = find_segments_from_rttm(recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0) - else: # segments with seg_start > total_start and seg_end < total_end are included - segments_iterator = rttms.find(recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True) + if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included + segments_iterator = find_segments_from_rttm( + recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0 + ) + else: # segments with seg_start > total_start and seg_end < total_end are included + segments_iterator = rttms.find( + recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True + ) for seg in segments_iterator: if seg.start < 0: @@ -323,28 +352,31 @@ def speaker_to_target( seg.duration -= seg.end - cut.duration seg.start += offsets[i] segments_total.append(seg) - + # apply arrival time sorting to the existing segments - segments_total.sort(key = lambda rttm_sup: rttm_sup.start) + segments_total.sort(key=lambda rttm_sup: rttm_sup.start) seen = set() seen_add = seen.add speaker_ats = [s.speaker for s in segments_total if not (s.speaker in seen or seen_add(s.speaker))] - - speaker_to_idx_map = { - spk: idx - for idx, spk in enumerate(speaker_ats) - } + + speaker_to_idx_map = {spk: idx for idx, spk in enumerate(speaker_ats)} if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers - raise ValueError(f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers {num_speakers}") - + raise ValueError( + f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers {num_speakers}" + ) + # initialize mask matrices (num_speaker, encoder_hidden_len) - feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default - num_samples = get_hidden_length_from_sample_length(a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame) - if spk_tar_all_zero: + feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default + num_samples = get_hidden_length_from_sample_length( + a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame + ) + if spk_tar_all_zero: frame_mask = torch.zeros((num_samples, num_speakers)) else: - frame_mask = get_mask_from_segments(segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch) + frame_mask = get_mask_from_segments( + segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch + ) soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) if soft_label: @@ -354,11 +386,19 @@ def speaker_to_target( return mask -def get_mask_from_segments(segments: list, a_cut, speaker_to_idx_map: torch.Tensor, num_speakers: int =4, feat_per_sec: int=100, ignore_num_spk_mismatch: bool = False): - """ + +def get_mask_from_segments( + segments: list, + a_cut, + speaker_to_idx_map: torch.Tensor, + num_speakers: int = 4, + feat_per_sec: int = 100, + ignore_num_spk_mismatch: bool = False, +): + """ Generate mask matrix from segments list. This function is needed for speaker diarization with ASR model trainings. - + Args: segments: A list of Lhotse Supervision segments iterator. cut (MonoCut, MixedCut): Lhotse MonoCut or MixedCut instance. @@ -366,13 +406,13 @@ def get_mask_from_segments(segments: list, a_cut, speaker_to_idx_map: torch.Tens num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default feat_per_sec (int): number of frames per second, 100 by default, 0.01s frame rate ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. - + Returns: mask (Tensor): A numpy array of shape (num_speakers, encoder_hidden_len). Dimension: (num_speakers, num_frames) """ # get targets with 0.01s frame rate - num_samples = round(a_cut.duration * feat_per_sec) + num_samples = round(a_cut.duration * feat_per_sec) mask = torch.zeros((num_samples, num_speakers)) for rttm_sup in segments: speaker_idx = speaker_to_idx_map[rttm_sup.speaker] @@ -388,17 +428,18 @@ def get_mask_from_segments(segments: list, a_cut, speaker_to_idx_map: torch.Tens mask[stf:enf, speaker_idx] = 1.0 return mask + def get_soft_mask(feat_level_target, num_samples, stride): """ Get soft mask from feat_level_target with stride. This function is needed for speaker diarization with ASR model trainings. - + Args: feat_level_target (Tensor): A numpy array of shape (num_frames, num_speakers). Dimension: (num_frames, num_speakers) num_sample (int): The total number of samples. stride (int): The stride for the mask. - """ + """ num_speakers = feat_level_target.shape[1] mask = torch.zeros(num_samples, num_speakers) @@ -412,15 +453,14 @@ def get_soft_mask(feat_level_target, num_samples, stride): seg_end_feat = feat_level_target.shape[0] else: seg_end_feat = stride * index - 1 + int(stride / 2) - mask[index] = torch.mean(feat_level_target[seg_stt_feat:seg_end_feat+1, :], axis=0) + mask[index] = torch.mean(feat_level_target[seg_stt_feat : seg_end_feat + 1, :], axis=0) return mask + def get_hidden_length_from_sample_length( - num_samples: int, - num_sample_per_mel_frame: int = 160, - num_mel_frame_per_asr_frame: int = 8 + num_samples: int, num_sample_per_mel_frame: int = 160, num_mel_frame_per_asr_frame: int = 8 ) -> int: - """ + """ Calculate the hidden length from the given number of samples. This function is needed for speaker diarization with ASR model trainings. @@ -439,15 +479,16 @@ def get_hidden_length_from_sample_length( hidden_length = math.ceil(mel_frame_count / num_mel_frame_per_asr_frame) return int(hidden_length) -class ConcatenationMeetingSimulator(): + +class ConcatenationMeetingSimulator: """ This simulator concatenates the segments from different/same sessions to create a - multi-speaker meeting. + multi-speaker meeting. """ def __init__( self, - intra_session_concat_prob: float|List[float] = [0, 1.0, 0.5, 0.2], + intra_session_concat_prob: float | List[float] = [0, 1.0, 0.5, 0.2], data_type: str = "msasr", min_duration: float = 30.0, max_duration: float = 40.0, @@ -460,7 +501,7 @@ def __init__( :param intra_session_concat_prob: the probability of concatenating segments from the same session. [Default: 1] :param data_type: the type of data to simulate. Either 'msasr' or 'diar'. If 'msasr', - the transcripts are included in the simulation,and the boundary segments are + the transcripts are included in the simulation,and the boundary segments are not included. [Default: 'msasr'] :param max_duration: the maximum duration of the simulated meeting. [Default: 40.0] """ @@ -470,7 +511,9 @@ def __init__( elif len(intra_session_concat_prob) == max_num_speakers: self.intra_session_concat_prob = intra_session_concat_prob else: - raise ValueError(f"intra_session_concat_prob must be either a float or a list of floats, but got {intra_session_concat_prob}") + raise ValueError( + f"intra_session_concat_prob must be either a float or a list of floats, but got {intra_session_concat_prob}" + ) if data_type not in ["msasr", "diar"]: raise ValueError("data_type must be either 'msasr' or 'diar', but got {data_type}") self.data_type = data_type @@ -478,7 +521,9 @@ def __init__( self.max_duration = max_duration self.max_num_speakers = max_num_speakers self.speaker_count_distribution = speaker_count_distribution - assert len(speaker_count_distribution) == max_num_speakers, f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" + assert ( + len(speaker_count_distribution) == max_num_speakers + ), f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" if skip_long_segments: self.skip_duration = max_duration / 2 @@ -489,7 +534,7 @@ def __init__( def fit(self, cuts) -> CutSet: """ - Read the manifest file and return a CutSet object. + Read the manifest file and return a CutSet object. Each line in the manifest file should be a JSON object representing a segment. """ @@ -500,7 +545,7 @@ def fit(self, cuts) -> CutSet: self.spk2cut_ids = defaultdict(list) self.data2num_spk2cut_ids = {} self.sess2num_spk2cut_ids = {} - self.num_spk2cut_ids = {i+1:[] for i in range(self.max_num_speakers)} + self.num_spk2cut_ids = {i + 1: [] for i in range(self.max_num_speakers)} for i, cut in tqdm(enumerate(cuts), desc="Reading segments", ncols=100, total=len(cuts)): if cut.duration > self.skip_duration: continue @@ -512,20 +557,20 @@ def fit(self, cuts) -> CutSet: self.data2num_spk2cut_ids[cut.dataset_id] = defaultdict(list) if cut.recording_id not in self.sess2num_spk2cut_ids: self.sess2num_spk2cut_ids[cut.recording_id] = defaultdict(list) - + speakers = cut.global_speaker_ids if self.data_type == "msasr": speaker_tokens = set(re.findall(r'<\|spltoken\d+\|>', cut.text)) - if len(speakers) != len(speaker_tokens): - # Lhotse automatically fixes the max duration of the cut, - # resulting in the mismatch of the number of speakers + if len(speakers) != len(speaker_tokens): + # Lhotse automatically fixes the max duration of the cut, + # resulting in the mismatch of the number of speakers # and speaker tokens for the last segment # TODO: need to fix the issue in Lhotse that automatically fixes the max duration continue for spk in speakers: self.spk2cut_ids[spk].append(cut.id) self.sess2spks[cut.recording_id] = self.sess2spks[cut.recording_id].union(speakers) - + self.id2cut[cut.id] = cut self.sess2cut_ids[cut.recording_id].append(cut.id) self.data2num_spk2cut_ids[cut.dataset_id][len(speakers)].append(cut.id) @@ -533,23 +578,21 @@ def fit(self, cuts) -> CutSet: self.num_spk2cut_ids[len(speakers)].append(cut.id) if cut.recording_id not in self.data2sess_ids[cut.dataset_id]: self.data2sess_ids[cut.dataset_id].append(cut.recording_id) - + self.cut_ids = list(self.id2cut.keys()) self.num_spk2sess_ids = groupby(lambda x: len(self.sess2spks[x]), self.sess2spks.keys()) - - self.data2global_speaker = { - dataset_id: True for dataset_id in self.data2sess_ids.keys() - } - + + self.data2global_speaker = {dataset_id: True for dataset_id in self.data2sess_ids.keys()} + def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> MixedCut: - db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data - + db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data + if is_intra_session_concat: # intra-dataset and intra-session concatenation tracks, num_speakers = self.get_intra_session_tracks(n_speakers, db_norm=db_norm) - else: + else: # intra-dataset but inter-session concatenation tracks, num_speakers = self.get_inter_session_tracks(n_speakers, db_norm=db_norm) @@ -557,44 +600,54 @@ def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> Mix if self.data_type == "msasr": cut = self.reorder_spk_mapping(cut) - assert self.min_duration <= cut.duration <= self.max_duration, f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" - assert n_speakers == num_speakers, f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" + assert ( + self.min_duration <= cut.duration <= self.max_duration + ), f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" + assert ( + n_speakers == num_speakers + ), f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" return cut - - def get_intra_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + + def get_intra_session_tracks(self, n_speakers: int = 4, db_norm: float = -25) -> List[MixTrack]: """ Get the tracks for the MixedCut object. """ session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - + total_duration = 0.0 total_spk_set = set() tracks = [] while True: cut = self.id2cut[random.choice(self.sess2cut_ids[session_id])] - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=total_duration)) + tracks.append( + MixTrack( + cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), + type=type(cut), + offset=total_duration, + ) + ) total_spk_set = total_spk_set.union(cut.global_speaker_ids) total_duration += cut.duration # break condition if total_duration >= self.min_duration: - if total_duration > self.max_duration: # exceed the maximum duration, starting over + if total_duration > self.max_duration: # exceed the maximum duration, starting over total_duration = 0.0 total_spk_set = set() tracks = [] session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break break else: total_duration = 0.0 total_spk_set = set() tracks = [] session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - + return tracks, len(total_spk_set) - def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + def get_inter_session_tracks(self, n_speakers: int = 4, db_norm: float = -25) -> List[MixTrack]: """ Get the tracks for the MixedCut object. """ @@ -604,7 +657,9 @@ def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> Lis sum_spk_list = set([i + j for i in n_spk_list for j in n_spk_list]) if min(sum_spk_list) > n_speakers: - raise ValueError(f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers.") + raise ValueError( + f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers." + ) n_spk_left = n_speakers total_duration = 0.0 @@ -612,7 +667,7 @@ def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> Lis tracks = [] num_spk2cut_ids = self.data2num_spk2cut_ids[dataset_id] while True: - #if n_spk_left == n_speakers: # for more speakers cases + # if n_spk_left == n_speakers: # for more speakers cases # n_spk = random.choice([n_spk for n_spk in n_spk_list if n_spk < n_spk_left]) if n_spk_left >= 2: n_spk = 2 @@ -626,34 +681,44 @@ def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> Lis if not spks.intersection(total_spk_set): break - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=total_duration)) + tracks.append( + MixTrack( + cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), + type=type(cut), + offset=total_duration, + ) + ) total_duration += cut.duration n_spk_left -= n_spk total_spk_set = total_spk_set.union(spks) # break condition - + if total_duration >= self.min_duration: - if total_duration > self.max_duration or len(total_spk_set) < n_speakers: # exceed the maximum duration, starting over + if ( + total_duration > self.max_duration or len(total_spk_set) < n_speakers + ): # exceed the maximum duration, starting over total_duration = 0.0 n_spk_left = n_speakers total_spk_set = set() tracks = [] - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break break else: - if len(total_spk_set) == n_speakers: # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers + if ( + len(total_spk_set) == n_speakers + ): # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers total_duration = 0.0 n_spk_left = n_speakers total_spk_set = set() tracks = [] - + return tracks, len(total_spk_set) - + def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> str: """ Concatenate the texts of the input cuts. - + """ global_spk_mapping = {} str_pattern = pattern.replace("\\", '') @@ -667,12 +732,12 @@ def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> st if speaker not in local_spk_mapping: local_spk_mapping[speaker] = len(local_spk_mapping) local_inverse_spk_mapping[len(local_inverse_spk_mapping)] = speaker - + if i != 0: text = '' - for word in track.cut.text.split(): + for word in track.cut.text.split(): if len(re.findall(pattern, word)) > 0: - local_spk_idx = int(word.replace(left_str,'').replace(right_str, '')) + local_spk_idx = int(word.replace(left_str, '').replace(right_str, '')) spk = local_inverse_spk_mapping[local_spk_idx] global_spk_idx = global_spk_mapping[spk] text += f'{left_str}{global_spk_idx}{right_str}' @@ -682,12 +747,12 @@ def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> st cut.supervisions[i].text = text else: cut.supervisions[0].text = track.cut.text - # TODO: need to check the last speaker of last track and the first speaker of the current track + # TODO: need to check the last speaker of last track and the first speaker of the current track # if they are the same, we need to remove the the speaker token from the current track for segment-level # Do not need to remove the speaker token for word-level - + return cut - + def apply_speaker_distribution(self, num_meetings: int, speaker_count_distribution) -> Dict[int, int]: """ Balance the speaker distribution for the simulated meetings. @@ -700,13 +765,13 @@ def apply_speaker_distribution(self, num_meetings: int, speaker_count_distributi total_spk = sum(speaker_count_distribution) num_speakers2num_meetings = {} for i_spk in range(self.max_num_speakers): - num_speakers2num_meetings[i_spk+1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) + num_speakers2num_meetings[i_spk + 1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) return num_speakers2num_meetings - - + @dill_enabled(True) - def simulate(self, + def simulate( + self, cuts: CutSet, num_meetings: int = 10000, seed: int = 0, @@ -715,42 +780,59 @@ def simulate(self, random.seed(seed) self.fit(cuts) - num_speakers2num_meetings = self.apply_speaker_distribution(num_meetings, self.speaker_count_distribution) - logging.warn(f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}.") - num_speakers2num_meetings[1] = 0 # skip 1-speaker samples - logging.warn(f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.') + logging.warn( + f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}." + ) + num_speakers2num_meetings[1] = 0 # skip 1-speaker samples + logging.warn( + f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.' + ) # Step 0: Calculate the number of intra-session and inter-session concatentation samples n_spks = [k for k, v in self.num_spk2cut_ids.items() if len(v) > 0] - valid_sim_n_spks = set([i+j for i in n_spks for j in n_spks]) # valid number of speakers for inter-session samples - n_spk2n_intra_mt, n_spk2n_inter_mt = {i+1:0 for i in range(self.max_num_speakers)}, {i+1:0 for i in range(self.max_num_speakers)} + valid_sim_n_spks = set( + [i + j for i in n_spks for j in n_spks] + ) # valid number of speakers for inter-session samples + n_spk2n_intra_mt, n_spk2n_inter_mt = {i + 1: 0 for i in range(self.max_num_speakers)}, { + i + 1: 0 for i in range(self.max_num_speakers) + } for n_spk, n_mt in num_speakers2num_meetings.items(): - logging.warn(f"=="*16 + f"{n_spk}-speaker" + "=="*16) + logging.warn(f"==" * 16 + f"{n_spk}-speaker" + "==" * 16) if n_mt <= 0: - logging.warning(f"No concatentation samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") + logging.warning( + f"No concatentation samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers." + ) continue - n_intra_mt = int(n_mt * self.intra_session_concat_prob[n_spk-1]) + n_intra_mt = int(n_mt * self.intra_session_concat_prob[n_spk - 1]) n_inter_mt = n_mt - n_intra_mt if n_spk in self.num_spk2sess_ids: logging.warn(f"Will be genrating {n_intra_mt} {n_spk}-speaker intra-session concatentation samples.") n_spk2n_intra_mt[n_spk] = n_intra_mt else: - logging.warning(f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") + logging.warning( + f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers." + ) n_spk2n_intra_mt[n_spk] = 0 n_inter_mt = n_mt if n_spk in valid_sim_n_spks: logging.warn(f"Will be genrating {n_inter_mt} {n_spk}-speaker inter-session concatentation samples.") n_spk2n_inter_mt[n_spk] = n_inter_mt else: - logging.warning(f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") + logging.warning( + f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers." + ) if n_spk2n_intra_mt[n_spk] != 0: n_spk2n_intra_mt[n_spk] = n_mt - logging.warn(f"Will be genrating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatentation samples instead.") + logging.warn( + f"Will be genrating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatentation samples instead." + ) else: logging.warning(f"No samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") - logging.warn(f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatentation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatentation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""") + logging.warn( + f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatentation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatentation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""" + ) # Step 1: intra-session num_intra_meetings = 0 intra_mixtures = [] @@ -762,25 +844,30 @@ def simulate(self, for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker intra-session mixtures", ncols=128): intra_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=True)) num_intra_meetings += n_mt - logging.info(f"Finished simulating intra-session concatentation samples. Total number of intra-session concatentation samples: {num_intra_meetings}") - + logging.info( + f"Finished simulating intra-session concatentation samples. Total number of intra-session concatentation samples: {num_intra_meetings}" + ) + # Steo 2: inter-session logging.info(f"Simulating inter-session concatentation samples.") - + num_inter_meetings = 0 inter_mixtures = [] for n_spk, n_mt in n_spk2n_inter_mt.items(): if n_mt <= 0: continue - + for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker inter-session mixtures", ncols=128): inter_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=False)) num_inter_meetings += n_mt - logging.info(f"Finished simulating inter-session concatentation samples. Total number of inter-session concatentation samples: {num_inter_meetings}") + logging.info( + f"Finished simulating inter-session concatentation samples. Total number of inter-session concatentation samples: {num_inter_meetings}" + ) if num_inter_meetings + num_intra_meetings == 0: - logging.warning(f"No samples are generated. Probably the duration of the segments is not within the range of min {self.min_duration//2} and max {self.max_duration//2}, or the speaker count distribution is not correctly set.") - + logging.warning( + f"No samples are generated. Probably the duration of the segments is not within the range of min {self.min_duration//2} and max {self.max_duration//2}, or the speaker count distribution is not correctly set." + ) # Multi-processing gets slower, TODO # else: @@ -788,7 +875,7 @@ def simulate(self, # for n_spk, n_mt in num_speakers2num_meetings.items(): # tp = concurrent.futures.ProcessPoolExecutor(max_workers=num_jobs) # futures.extend([tp.submit(self._create_mixture, n_spk) for _ in range(n_mt)]) - # pbar = tqdm(total=num_meetings, desc=f"Simulating mixtures", unit="line", ncols=128) + # pbar = tqdm(total=num_meetings, desc=f"Simulating mixtures", unit="line", ncols=128) # count = 0 # for f in concurrent.futures.as_completed(futures): # count += 1 @@ -798,17 +885,17 @@ def simulate(self, # pbar.close() return CutSet.from_cuts(intra_mixtures + inter_mixtures) - -class MixMeetingSimulator(): + +class MixMeetingSimulator: """ This simulator Mix the segments from different/same sessions to create a - multi-speaker meeting. + multi-speaker meeting. """ def __init__( self, - intra_session_mix_prob: float|List[float] = [0, 0, 0, 0], + intra_session_mix_prob: float | List[float] = [0, 0, 0, 0], data_type: str = "msasr", min_duration: float = 80.0, max_duration: float = 100.0, @@ -820,7 +907,7 @@ def __init__( :param intra_session_mix_prob: the probability of concatenating segments from the same session. [Default: 1] :param data_type: the type of data to simulate. Either 'msasr' or 'diar'. If 'msasr', - the transcripts are included in the simulation,and the boundary segments are + the transcripts are included in the simulation,and the boundary segments are not included. [Default: 'msasr'] :param max_duration: the maximum duration of the simulated meeting. [Default: 40.0] """ @@ -830,7 +917,9 @@ def __init__( elif len(intra_session_mix_prob) == max_num_speakers: self.intra_session_mix_prob = intra_session_mix_prob else: - raise ValueError(f"intra_session_mix_prob must be either a float or a list of floats, but got {intra_session_mix_prob}") + raise ValueError( + f"intra_session_mix_prob must be either a float or a list of floats, but got {intra_session_mix_prob}" + ) if data_type not in ["msasr", "diar"]: raise ValueError("data_type must be either 'msasr' or 'diar', but got {data_type}") self.data_type = data_type @@ -839,11 +928,13 @@ def __init__( self.max_num_speakers = max_num_speakers self.speaker_count_distribution = speaker_count_distribution self.valid_dataset_ids = valid_dataset_ids - assert len(speaker_count_distribution) == max_num_speakers, f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" + assert ( + len(speaker_count_distribution) == max_num_speakers + ), f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" def fit(self, cuts) -> CutSet: """ - Read the manifest file and return a CutSet object. + Read the manifest file and return a CutSet object. Each line in the manifest file should be a JSON object representing a segment. """ @@ -854,7 +945,7 @@ def fit(self, cuts) -> CutSet: self.spk2cut_ids = defaultdict(list) self.data2num_spk2cut_ids = {} self.sess2num_spk2cut_ids = {} - self.num_spk2cut_ids = {i+1:[] for i in range(self.max_num_speakers)} + self.num_spk2cut_ids = {i + 1: [] for i in range(self.max_num_speakers)} for i, cut in tqdm(enumerate(cuts), desc="Reading segments", ncols=100, total=len(cuts)): if not self.min_duration <= cut.duration <= self.max_duration: continue @@ -866,20 +957,20 @@ def fit(self, cuts) -> CutSet: self.data2num_spk2cut_ids[cut.dataset_id] = defaultdict(list) if cut.recording_id not in self.sess2num_spk2cut_ids: self.sess2num_spk2cut_ids[cut.recording_id] = defaultdict(list) - + speakers = cut.global_speaker_ids if self.data_type == "msasr": speaker_tokens = set(re.findall(r'<\|spltoken\d+\|>', cut.text)) - if len(speakers) != len(speaker_tokens): - # Lhotse automatically fixes the max duration of the cut, - # resulting in the mismatch of the number of speakers + if len(speakers) != len(speaker_tokens): + # Lhotse automatically fixes the max duration of the cut, + # resulting in the mismatch of the number of speakers # and speaker tokens for the last segment # TODO: need to fix the issue in Lhotse that automatically fixes the max duration continue for spk in speakers: self.spk2cut_ids[spk].append(cut.id) self.sess2spks[cut.recording_id] = self.sess2spks[cut.recording_id].union(speakers) - + self.id2cut[cut.id] = cut self.sess2cut_ids[cut.recording_id].append(cut.id) self.data2num_spk2cut_ids[cut.dataset_id][len(speakers)].append(cut.id) @@ -887,23 +978,21 @@ def fit(self, cuts) -> CutSet: self.num_spk2cut_ids[len(speakers)].append(cut.id) if cut.recording_id not in self.data2sess_ids[cut.dataset_id]: self.data2sess_ids[cut.dataset_id].append(cut.recording_id) - + self.cut_ids = list(self.id2cut.keys()) self.num_spk2sess_ids = groupby(lambda x: len(self.sess2spks[x]), self.sess2spks.keys()) - - self.data2global_speaker = { - dataset_id: True for dataset_id in self.data2sess_ids.keys() - } - + + self.data2global_speaker = {dataset_id: True for dataset_id in self.data2sess_ids.keys()} + def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> MixedCut: - db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data - + db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data + if is_intra_session_concat: # intra-dataset and intra-session concatenation tracks, num_speakers = self.get_intra_session_tracks(n_speakers, db_norm=db_norm) - else: + else: # intra-dataset but inter-session concatenation tracks, num_speakers = self.get_inter_session_tracks(n_speakers, db_norm=db_norm) @@ -911,43 +1000,51 @@ def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> Mix if self.data_type == "msasr": cut = self.reorder_spk_mapping(cut) - assert self.min_duration <= cut.duration <= self.max_duration, f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" - assert n_speakers == num_speakers, f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" + assert ( + self.min_duration <= cut.duration <= self.max_duration + ), f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" + assert ( + n_speakers == num_speakers + ), f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" return cut - - def get_intra_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + + def get_intra_session_tracks(self, n_speakers: int = 4, db_norm: float = -25) -> List[MixTrack]: """ Get the tracks for the MixedCut object. """ session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - + total_spk_set = set() tracks = [] while True: cut = self.id2cut[random.choice(self.sess2cut_ids[session_id])] - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0)) + tracks.append( + MixTrack( + cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0 + ) + ) total_spk_set = total_spk_set.union(cut.global_speaker_ids) total_duration = max(total_duration, cut.duration) # break condition if total_duration >= self.min_duration: - if total_duration > self.max_duration: # exceed the maximum duration, starting over + if total_duration > self.max_duration: # exceed the maximum duration, starting over total_duration = 0.0 total_spk_set = set() tracks = [] session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break break else: total_duration = 0.0 total_spk_set = set() tracks = [] session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - + return tracks, len(total_spk_set) - def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + def get_inter_session_tracks(self, n_speakers: int = 4, db_norm: float = -25) -> List[MixTrack]: """ Get the tracks for the MixedCut object. """ @@ -957,7 +1054,9 @@ def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> Lis sum_spk_list = set([i + j for i in n_spk_list for j in n_spk_list]) if min(sum_spk_list) > n_speakers: - raise ValueError(f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers.") + raise ValueError( + f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers." + ) n_spk_left = n_speakers total_duration = 0.0 @@ -977,34 +1076,40 @@ def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> Lis if not spks.intersection(total_spk_set): break - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0)) + tracks.append( + MixTrack( + cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0 + ) + ) total_duration = max(total_duration, cut.duration) n_spk_left -= n_spk total_spk_set = total_spk_set.union(spks) # break condition - + if total_duration >= self.min_duration: - if total_duration > self.max_duration or len(tracks) > 2: # exceed the maximum duration, starting over + if total_duration > self.max_duration or len(tracks) > 2: # exceed the maximum duration, starting over total_duration = 0.0 n_spk_left = n_speakers total_spk_set = set() tracks = [] - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break break else: - if len(total_spk_set) == n_speakers: # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers + if ( + len(total_spk_set) == n_speakers + ): # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers total_duration = 0.0 n_spk_left = n_speakers total_spk_set = set() tracks = [] - + return tracks, len(total_spk_set) - + def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> str: """ Concatenate the texts of the input cuts. - + """ global_spk_mapping = {} str_pattern = pattern.replace("\\", '') @@ -1018,12 +1123,12 @@ def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> st if speaker not in local_spk_mapping: local_spk_mapping[speaker] = len(local_spk_mapping) local_inverse_spk_mapping[len(local_inverse_spk_mapping)] = speaker - + if i != 0: text = '' - for word in track.cut.text.split(): + for word in track.cut.text.split(): if len(re.findall(pattern, word)) > 0: - local_spk_idx = int(word.replace(left_str,'').replace(right_str, '')) + local_spk_idx = int(word.replace(left_str, '').replace(right_str, '')) spk = local_inverse_spk_mapping[local_spk_idx] global_spk_idx = global_spk_mapping[spk] text += f'{left_str}{global_spk_idx}{right_str}' @@ -1033,12 +1138,12 @@ def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> st cut.supervisions[i].text = text else: cut.supervisions[0].text = track.cut.text - # TODO: need to check the last speaker of last track and the first speaker of the current track + # TODO: need to check the last speaker of last track and the first speaker of the current track # if they are the same, we need to remove the the speaker token from the current track for segment-level # Do not need to remove the speaker token for word-level - + return cut - + def apply_speaker_distribution(self, num_meetings: int, speaker_count_distribution) -> Dict[int, int]: """ Balance the speaker distribution for the simulated meetings. @@ -1051,13 +1156,13 @@ def apply_speaker_distribution(self, num_meetings: int, speaker_count_distributi total_spk = sum(speaker_count_distribution) num_speakers2num_meetings = {} for i_spk in range(self.max_num_speakers): - num_speakers2num_meetings[i_spk+1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) + num_speakers2num_meetings[i_spk + 1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) return num_speakers2num_meetings - - + @dill_enabled(True) - def simulate(self, + def simulate( + self, cuts: CutSet, num_meetings: int = 10000, seed: int = 0, @@ -1068,39 +1173,57 @@ def simulate(self, self.fit(cuts) num_speakers2num_meetings = self.apply_speaker_distribution(num_meetings, self.speaker_count_distribution) - logging.warn(f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}.") - num_speakers2num_meetings[1] = 0 # skip 1-speaker samples - logging.warn(f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.') + logging.warn( + f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}." + ) + num_speakers2num_meetings[1] = 0 # skip 1-speaker samples + logging.warn( + f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.' + ) # Step 0: Calculate the number of intra-session and inter-session concatentation samples n_spks = [k for k, v in self.num_spk2cut_ids.items() if len(v) > 0] - valid_sim_n_spks = set([i+j for i in n_spks for j in n_spks]) # valid number of speakers for inter-session samples - n_spk2n_intra_mt, n_spk2n_inter_mt = {i+1:0 for i in range(self.max_num_speakers)}, {i+1:0 for i in range(self.max_num_speakers)} + valid_sim_n_spks = set( + [i + j for i in n_spks for j in n_spks] + ) # valid number of speakers for inter-session samples + n_spk2n_intra_mt, n_spk2n_inter_mt = {i + 1: 0 for i in range(self.max_num_speakers)}, { + i + 1: 0 for i in range(self.max_num_speakers) + } for n_spk, n_mt in num_speakers2num_meetings.items(): - logging.warn(f"=="*16 + f"{n_spk}-speaker" + "=="*16) + logging.warn(f"==" * 16 + f"{n_spk}-speaker" + "==" * 16) if n_mt <= 0: - logging.warning(f"No intra-session concatentation samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") + logging.warning( + f"No intra-session concatentation samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers." + ) continue - n_intra_mt = int(n_mt * self.intra_session_mix_prob[n_spk-1]) + n_intra_mt = int(n_mt * self.intra_session_mix_prob[n_spk - 1]) n_inter_mt = n_mt - n_intra_mt if n_spk in self.num_spk2sess_ids: logging.warn(f"Will be genrating {n_intra_mt} {n_spk}-speaker intra-session concatentation samples.") n_spk2n_intra_mt[n_spk] = n_intra_mt else: - logging.warning(f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") + logging.warning( + f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers." + ) n_spk2n_intra_mt[n_spk] = 0 n_inter_mt = n_mt if n_spk in valid_sim_n_spks: logging.warn(f"Will be genrating {n_inter_mt} {n_spk}-speaker inter-session concatentation samples.") n_spk2n_inter_mt[n_spk] = n_inter_mt else: - logging.warning(f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") + logging.warning( + f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers." + ) if n_spk2n_intra_mt[n_spk] != 0: n_spk2n_intra_mt[n_spk] = n_mt - logging.warn(f"Will be genrating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatentation samples instead.") + logging.warn( + f"Will be genrating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatentation samples instead." + ) else: logging.warning(f"No samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") - logging.warn(f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatentation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatentation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""") + logging.warn( + f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatentation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatentation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""" + ) # Step 1: intra-session num_intra_meetings = 0 intra_mixtures = [] @@ -1112,28 +1235,35 @@ def simulate(self, for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker intra-session mixtures", ncols=128): intra_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=True)) num_intra_meetings += n_mt - logging.info(f"Finished simulating intra-session concatentation samples. Total number of intra-session concatentation samples: {num_intra_meetings}") - + logging.info( + f"Finished simulating intra-session concatentation samples. Total number of intra-session concatentation samples: {num_intra_meetings}" + ) + # Steo 2: inter-session logging.info(f"Simulating inter-session concatentation samples.") - + num_inter_meetings = 0 inter_mixtures = [] for n_spk, n_mt in n_spk2n_inter_mt.items(): if n_mt <= 0: continue - + for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker inter-session mixtures", ncols=128): inter_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=False)) num_inter_meetings += n_mt - logging.info(f"Finished simulating inter-session concatentation samples. Total number of inter-session concatentation samples: {num_inter_meetings}") + logging.info( + f"Finished simulating inter-session concatentation samples. Total number of inter-session concatentation samples: {num_inter_meetings}" + ) if num_inter_meetings + num_intra_meetings == 0: - logging.warning(f"No samples are generated. Probably the duration of the segments is not within the range of min {self.min_duration} and max {self.max_duration}, or the speaker count distribution is not correctly set.") + logging.warning( + f"No samples are generated. Probably the duration of the segments is not within the range of min {self.min_duration} and max {self.max_duration}, or the speaker count distribution is not correctly set." + ) return CutSet.from_cuts(intra_mixtures + inter_mixtures) -class LibriSpeechMixSimulator(): + +class LibriSpeechMixSimulator: def __init__( self, @@ -1151,12 +1281,15 @@ def __init__( self.max_duration = max_duration self.n_mix_speakers = n_mix_speakers self.speaker_count_distribution = speaker_count_distribution - assert len(speaker_count_distribution) == len(n_mix_speakers), f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {len(n_mix_speakers)}" + assert len(speaker_count_distribution) == len( + n_mix_speakers + ), f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {len(n_mix_speakers)}" def fit(self, cuts) -> CutSet: pass - def simulate(self, + def simulate( + self, cuts: CutSet, num_meetings: int = 10000, seed: int = 0, @@ -1172,7 +1305,8 @@ def simulate(self, cut_set.append(self._create_mixture(n_speakers=n_speakers)) return CutSet.from_cuts(cut_set) -class LibriSpeechMixGenerator(): + +class LibriSpeechMixGenerator: def __init__(self): pass @@ -1201,18 +1335,12 @@ def generate(self, cuts): supervisions=[], recording=Recording( id=wav.split('/')[-1].replace('.wav', ''), - sources=[ - AudioSource( - type='file', - channels=[0], - source=wav - ) - ], - sampling_rate=16000, + sources=[AudioSource(type='file', channels=[0], source=wav)], + sampling_rate=16000, num_samples=wav_samples, - duration=wav_dur + duration=wav_dur, ), - custom=custom + custom=custom, ) tracks.append(MixTrack(cut=cut_1spk, type=type(cut_1spk), offset=offset)) @@ -1220,12 +1348,12 @@ def generate(self, cuts): id=cut.id, recording_id=cut.recording_id, start=0, - duration=offset+wav_dur, + duration=offset + wav_dur, text=cut.text, ) tracks[0].cut.supervisions.append(sup) cut_multi_spk = MixedCut(id=cut.id, tracks=tracks) - + cut_set.append(cut_multi_spk) - - return CutSet.from_cuts(cut_set) \ No newline at end of file + + return CutSet.from_cuts(cut_set) diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 80b3e1f918b8..046f32c1d48f 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -21,10 +21,10 @@ from typing import Dict, List, Tuple, Union import numpy as np -from omegaconf import OmegaConf -from omegaconf.listconfig import ListConfig import soundfile as sf import torch +from omegaconf import OmegaConf +from omegaconf.listconfig import ListConfig from pyannote.core import Annotation, Segment, Timeline from tqdm import tqdm @@ -589,7 +589,7 @@ def write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, Number of decimals to round the offset and duration values. """ audio_path = AUDIO_RTTM_MAP[uniq_id]['audio_filepath'] - for (stt, end) in overlap_range_list: + for stt, end in overlap_range_list: meta = { "audio_filepath": audio_path, "offset": round(stt, decimals), @@ -749,14 +749,14 @@ def fl2int(x: float, decimals: int = 3) -> int: """ Convert floating point number to integer. """ - return torch.round(torch.tensor([x * (10 ** decimals)]), decimals=0).int().item() + return torch.round(torch.tensor([x * (10**decimals)]), decimals=0).int().item() def int2fl(x: int, decimals: int = 3) -> float: """ Convert integer to floating point number. """ - return torch.round(torch.tensor([x / (10 ** decimals)]), decimals=decimals).item() + return torch.round(torch.tensor([x / (10**decimals)]), decimals=decimals).item() def merge_float_intervals(ranges: List[List[float]], decimals: int = 5, margin: int = 2) -> List[List[float]]: @@ -902,9 +902,10 @@ def segments_manifest_to_subsegments_manifest( pwd = os.getcwd() subsegments_manifest_file = os.path.join(pwd, 'subsegments.json') - with open(segments_manifest_file, 'r') as segments_manifest, open( - subsegments_manifest_file, 'w' - ) as subsegments_manifest: + with ( + open(segments_manifest_file, 'r') as segments_manifest, + open(subsegments_manifest_file, 'w') as subsegments_manifest, + ): segments = segments_manifest.readlines() for segment in segments: segment = segment.strip() @@ -933,22 +934,22 @@ def segments_manifest_to_subsegments_manifest( def get_subsegments( - offset: float, - window: float, - shift: float, - duration: float, + offset: float, + window: float, + shift: float, + duration: float, min_subsegment_duration: float = 0.01, decimals: int = 2, use_asr_style_frame_count: bool = False, sample_rate: int = 16000, feat_per_sec: int = 100, - ) -> List[List[float]]: +) -> List[List[float]]: """ Return subsegments from a segment of audio file. - + Example: (window, shift) = 1.5, 0.75 - Segment: [12.05, 14.45] + Segment: [12.05, 14.45] Subsegments: [[12.05, 13.55], [12.8, 14.3], [13.55, 14.45], [14.3, 14.45]] Args: @@ -959,30 +960,30 @@ def get_subsegments( min_subsegment_duration (float): Exclude subsegments smaller than this duration value decimals (int): Number of decimal places to round to use_asr_style_frame_count (bool): If True, use asr style frame count to generate subsegments. - For example, if duration is 10 secs and frame_shift is 0.08 secs, + For example, if duration is 10 secs and frame_shift is 0.08 secs, it results in (10/0.08)+1 = 125 + 1 frames. - + Returns: subsegments (List[tuple[float, float]]): subsegments generated for the segments as list of tuple of start and duration of each subsegment """ - subsegments: List[List[float]] = [] + subsegments: List[List[float]] = [] start = offset slice_end = start + duration if min_subsegment_duration <= duration < shift: slices = 1 - elif use_asr_style_frame_count is True: - num_feat_frames = np.ceil((1+duration*sample_rate)/int(sample_rate/feat_per_sec)).astype(int) - slices = np.ceil(num_feat_frames/int(feat_per_sec*shift)).astype(int) + elif use_asr_style_frame_count is True: + num_feat_frames = np.ceil((1 + duration * sample_rate) / int(sample_rate / feat_per_sec)).astype(int) + slices = np.ceil(num_feat_frames / int(feat_per_sec * shift)).astype(int) slice_end = start + shift * slices else: - slices = np.ceil(1+ (duration-window)/shift).astype(int) + slices = np.ceil(1 + (duration - window) / shift).astype(int) if slices == 1: if min(duration, window) >= min_subsegment_duration: subsegments.append([start, min(duration, window)]) - elif slices > 0: # What if slcies = 0 ? + elif slices > 0: # What if slcies = 0 ? start_col = torch.arange(offset, slice_end, shift)[:slices] dur_col = window * torch.ones(slices) - dur_col = torch.min(slice_end*torch.ones_like(start_col)- start_col, window * torch.ones_like(start_col)) + dur_col = torch.min(slice_end * torch.ones_like(start_col) - start_col, window * torch.ones_like(start_col)) dur_col = torch.round(dur_col, decimals=decimals) valid_mask = dur_col >= min_subsegment_duration valid_subsegments = torch.stack([start_col[valid_mask], dur_col[valid_mask]], dim=1) @@ -990,7 +991,13 @@ def get_subsegments( return subsegments -def get_target_sig(sig, start_sec: float, end_sec: float, slice_length: int, sample_rate: int,) -> torch.Tensor: +def get_target_sig( + sig, + start_sec: float, + end_sec: float, + slice_length: int, + sample_rate: int, +) -> torch.Tensor: """ Extract time-series signal from the given audio buffer based on the start and end timestamps. @@ -1037,15 +1044,16 @@ def tensor_to_list(range_tensor: torch.Tensor) -> List[List[float]]: return [[float(range_tensor[k][0]), float(range_tensor[k][1])] for k in range(range_tensor.shape[0])] -def generate_diarization_output_lines(speaker_timestamps, model_spk_num): - speaker_lines_total = [] +def generate_diarization_output_lines(speaker_timestamps, model_spk_num): + speaker_lines_total = [] for spk_idx in range(model_spk_num): ts_invervals = speaker_timestamps[spk_idx] merged_ts_intervals = merge_float_intervals(ts_invervals) for ts_interval in merged_ts_intervals: speaker_lines_total.extend([f"{ts_interval[0]:.3f} {ts_interval[1]:.3f} speaker_{int(spk_idx)}"]) return speaker_lines_total - + + def get_speech_labels_for_update( frame_start: float, buffer_end: float, @@ -1113,9 +1121,12 @@ def get_speech_labels_for_update( return speech_label_for_new_segments, cumulative_speech_labels -def get_new_cursor_for_update(frame_start: float, segment_range_ts: List[List[float]],) -> Tuple[float, int]: +def get_new_cursor_for_update( + frame_start: float, + segment_range_ts: List[List[float]], +) -> Tuple[float, int]: """ - Function for updating a cursor online speaker diarization. + Function for updating a cursor online speaker diarization. Remove the old segments that overlap with the new frame (self.frame_start) cursor_for_old_segments is set to the onset of the t_range popped lastly. @@ -1273,7 +1284,10 @@ def get_online_subsegments_from_buffer( range_t = [max(0, range_offs[0]), range_offs[1]] subsegments = get_subsegments( - offset=range_t[0], window=window, shift=shift, duration=(range_t[1] - range_t[0]), + offset=range_t[0], + window=window, + shift=shift, + duration=(range_t[1] - range_t[0]), ) ind_offset, sigs, ranges, inds = get_online_segments_from_slices( sig=audio_buffer, @@ -1444,8 +1458,7 @@ def generate_speaker_timestamps( def get_uniq_id_list_from_manifest(manifest_file: str): - """Retrieve `uniq_id` values from the given manifest_file and save the IDs to a list. - """ + """Retrieve `uniq_id` values from the given manifest_file and save the IDs to a list.""" uniq_id_list = [] with open(manifest_file, 'r', encoding='utf-8') as manifest: for i, line in enumerate(manifest.readlines()): @@ -1626,21 +1639,22 @@ def make_rttm_with_overlap( return all_reference, all_hypothesis -def timestamps_to_pyannote_object(speaker_timestamps: List[Tuple[float, float]], - uniq_id: str, - audio_rttm_values: Dict[str, str], - all_hypothesis: List[Tuple[str, Timeline]], - all_reference: List[Tuple[str, Timeline]], - all_uems: List[Tuple[str, Timeline]], - out_rttm_dir: str | None - ): - """ +def timestamps_to_pyannote_object( + speaker_timestamps: List[Tuple[float, float]], + uniq_id: str, + audio_rttm_values: Dict[str, str], + all_hypothesis: List[Tuple[str, Timeline]], + all_reference: List[Tuple[str, Timeline]], + all_uems: List[Tuple[str, Timeline]], + out_rttm_dir: str | None, +): + """ Convert speaker timestamps to pyannote.core.Timeline object. - + Args: - speaker_timestamps (List[Tuple[float, float]]): + speaker_timestamps (List[Tuple[float, float]]): Timestamps of each speaker: start time and end time of each speaker. - uniq_id (str): + uniq_id (str): Unique ID of each speaker. audio_rttm_values (Dict[str, str]): Dictionary of manifest values. @@ -1652,7 +1666,7 @@ def timestamps_to_pyannote_object(speaker_timestamps: List[Tuple[float, float]], List of uems in pyannote.core.Timeline object. out_rttm_dir (str | None): Directory to save RTTMs - + Returns: all_hypothesis (List[Tuple[str, pyannote.core.Timeline]]): List of hypothesis in pyannote.core.Timeline object with an added Timeline object. @@ -1662,47 +1676,49 @@ def timestamps_to_pyannote_object(speaker_timestamps: List[Tuple[float, float]], List of uems in pyannote.core.Timeline object with an added Timeline object. """ offset, dur = float(audio_rttm_values.get('offset', None)), float(audio_rttm_values.get('duration', None)) - hyp_labels = generate_diarization_output_lines(speaker_timestamps=speaker_timestamps, model_spk_num=len(speaker_timestamps)) + hyp_labels = generate_diarization_output_lines( + speaker_timestamps=speaker_timestamps, model_spk_num=len(speaker_timestamps) + ) hypothesis = labels_to_pyannote_object(hyp_labels, uniq_name=uniq_id) if out_rttm_dir is not None and os.path.exists(out_rttm_dir): - with open(f'{out_rttm_dir}/{uniq_id}.rttm','w') as f: + with open(f'{out_rttm_dir}/{uniq_id}.rttm', 'w') as f: hypothesis.write_rttm(f) all_hypothesis.append([uniq_id, hypothesis]) rttm_file = audio_rttm_values.get('rttm_filepath', None) if rttm_file is not None and os.path.exists(rttm_file): - uem_lines = [[offset, dur+offset]] + uem_lines = [[offset, dur + offset]] org_ref_labels = rttm_to_labels(rttm_file) ref_labels = org_ref_labels reference = labels_to_pyannote_object(ref_labels, uniq_name=uniq_id) uem_obj = get_uem_object(uem_lines, uniq_id=uniq_id) all_uems.append(uem_obj) all_reference.append([uniq_id, reference]) - return all_hypothesis, all_reference, all_uems - + return all_hypothesis, all_reference, all_uems + + def get_uem_object(uem_lines: List[List[float]], uniq_id: str): """ Generate pyannote timeline segments for uem file. - + file format UNIQ_SPEAKER_ID CHANNEL START_TIME END_TIME - + Args: uem_lines (list): list of session ID and start, end times. Example: [[0.0, 30.41], [60.04, 165.83]] uniq_id (str): Unique session ID. - + Returns: timeline (pyannote.core.Timeline): pyannote timeline object. """ timeline = Timeline(uri=uniq_id) for uem_stt_end in uem_lines: - start_time, end_time = uem_stt_end + start_time, end_time = uem_stt_end timeline.add(Segment(float(start_time), float(end_time))) return timeline - def embedding_normalize(embs, use_std=False, eps=1e-10): """ Mean and l2 length normalize the input speaker embeddings diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index 192c42375dca..f7374931bc45 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -35,8 +35,9 @@ from sklearn.metrics import roc_auc_score from sklearn.model_selection import ParameterGrid from tqdm import tqdm -from nemo.collections.asr.parts.utils.speaker_utils import timestamps_to_pyannote_object + from nemo.collections.asr.models import EncDecClassificationModel, EncDecFrameClassificationModel +from nemo.collections.asr.parts.utils.speaker_utils import timestamps_to_pyannote_object from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging @@ -576,7 +577,7 @@ def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torc """ if speech_segments.shape == torch.Size([0]): return speech_segments - + min_duration_on = per_args.get('min_duration_on', 0.0) min_duration_off = per_args.get('min_duration_off', 0.0) filter_speech_first = per_args.get('filter_speech_first', 1.0) @@ -1712,34 +1713,34 @@ def frame_vad_eval_detection_error( def ts_vad_post_processing( - ts_vad_binary_vec: torch.Tensor, - cfg_vad_params: OmegaConf, - unit_10ms_frame_count: int=8, - bypass_postprocessing: bool = False - ): + ts_vad_binary_vec: torch.Tensor, + cfg_vad_params: OmegaConf, + unit_10ms_frame_count: int = 8, + bypass_postprocessing: bool = False, +): """ Post-processing on diarization results using VAD style post-processing methods. These post-processing methods are inspired by the following paper: - Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). + Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). Args: - ts_vad_binary_vec (Tensor): + ts_vad_binary_vec (Tensor): Sigmoid values of each frame and each speaker. Dimension: (num_frames,) - cfg_vad_params (OmegaConf): + cfg_vad_params (OmegaConf): Configuration (omega config) of VAD parameters. - unit_10ms_frame_count (int, optional): + unit_10ms_frame_count (int, optional): an integer indicating the number of 10ms frames in a unit. For example, if unit_10ms_frame_count is 8, then each frame is 0.08 seconds. - bypass_postprocessing (bool, optional): + bypass_postprocessing (bool, optional): If True, diarization post-processing will be bypassed. Returns: - speech_segments (Tensor): + speech_segments (Tensor): start and end of each speech segment. Dimension: (num_segments, 2) - - Example: + + Example: tensor([[ 0.0000, 3.0400], [ 6.0000, 6.0800], ... @@ -1751,9 +1752,9 @@ def ts_vad_post_processing( speech_segments = binarization(ts_vad_binary_frames, cfg_vad_params) speech_segments = filtering(speech_segments, cfg_vad_params) else: - cfg_vad_params.onset=0.5 - cfg_vad_params.offset=0.5 - cfg_vad_params.pad_onset=0.0 - cfg_vad_params.pad_offset=0.0 + cfg_vad_params.onset = 0.5 + cfg_vad_params.offset = 0.5 + cfg_vad_params.pad_onset = 0.0 + cfg_vad_params.pad_offset = 0.0 speech_segments = binarization(ts_vad_binary_frames, cfg_vad_params) - return speech_segments \ No newline at end of file + return speech_segments diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 144ae405de52..632ec06bc647 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -1242,6 +1242,7 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: ) return item + class EndtoEndDiarizationLabel(_Collection): """List of diarization audio-label correspondence with preprocessing.""" @@ -1283,9 +1284,7 @@ def __init__( output_type = self.OUTPUT_TYPE data, duration_filtered = [], 0.0 - zipped_items = zip( - audio_files, uniq_ids, durations, rttm_files, offsets - ) + zipped_items = zip(audio_files, uniq_ids, durations, rttm_files, offsets) for ( audio_file, uniq_id, @@ -1328,7 +1327,8 @@ def __init__( data.sort(key=lambda entity: entity.duration) logging.info( - "Filtered duration for loading collection is %f.", duration_filtered, + "Filtered duration for loading collection is %f.", + duration_filtered, ) logging.info(f"Total {len(data)} session files loaded accounting to # {len(audio_files)} audio clips") @@ -1346,8 +1346,8 @@ def __init__( **kwargs, ): """ - Parse lists of audio files, durations, RTTM (Diarization annotation) files. - Since diarization model infers only two speakers, speaker pairs are generated + Parse lists of audio files, durations, RTTM (Diarization annotation) files. + Since diarization model infers only two speakers, speaker pairs are generated from the total number of speakers in the session. Args: @@ -1404,12 +1404,12 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: raise ValueError( f"Manifest file has invalid json line " f"structure: {line} without proper audio file key." ) - if isinstance(item['audio_file'], list): + if isinstance(item['audio_file'], list): item['audio_file'] = [os.path.expanduser(audio_file_path) for audio_file_path in item['audio_file']] else: item['audio_file'] = os.path.expanduser(item['audio_file']) - if not isinstance(item['audio_file'], list): + if not isinstance(item['audio_file'], list): if 'uniq_id' not in item: item['uniq_id'] = os.path.splitext(os.path.basename(item['audio_file']))[0] elif 'uniq_id' not in item: From 4ddc59bc0d8fc606d8452ddcfcc5e5a12a3ed9e0 Mon Sep 17 00:00:00 2001 From: taejinp Date: Thu, 14 Nov 2024 16:56:08 -0800 Subject: [PATCH 05/47] Reflecting comments and removing unnecessary parts for this PR Signed-off-by: taejinp --- ...rtformer_diarizer_hybrid_loss_4spk-v1.yaml | 35 +- ...ortformer_diar_4spk-v1_callhome-part1.yaml | 4 - .../sortformer_diar_4spk-v1_dihard-dev.yaml | 4 - .../neural_diarizer/e2e_diarize_speech.py | 35 +- nemo/collections/asr/models/__init__.py | 10 +- .../asr/models/sortformer_diar_models.py | 16 +- .../asr/modules/sortformer_modules.py | 16 +- .../asr/parts/utils/asr_multispeaker_utils.py | 1024 ++--------------- .../asr/parts/utils/speaker_utils.py | 6 +- nemo/collections/asr/parts/utils/vad_utils.py | 1 - 10 files changed, 122 insertions(+), 1029 deletions(-) diff --git a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml index e44bae976729..04409a4cd60a 100644 --- a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml +++ b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml @@ -1,6 +1,6 @@ -# Sortformer Diarizer is an end-to-end speaker diarization model that is solely based on Transformer-encoder type of architecture. -# Model name convention for Sortformer Diarizer: sortformer_diarizer____loss.yaml -# (Example) `sortformer_diarizer_FC18_TF18_hybrid_loss.yaml` has 18 layers for FastConformer and 18 layers of Transformer. +sortformer_diarizer_hybrid_loss_4spk-v1.yaml# Sortformer Diarizer is an end-to-end speaker diarization model that is solely based on Transformer-encoder type of architecture. +# Model name convention for Sortformer Diarizer: sortformer_diarizer___.yaml +# (Example) `sortformer_diarizer_hybrid_loss_4spk-v1.yaml`. # Sortformer Diarizer model checkpoint (.ckpt) and NeMo file (.nemo) contain Fast Conformer Encoder model (NEST Encoder) and the pre-trained NEST model is loaded along with the Transformer Encoder layers. # Example: a manifest line for training # {"audio_filepath": "/path/to/audio01.wav", "offset": 390.83, "duration": 90.00, "text": "-", "num_speakers": 2, "rttm_filepath": "/path/to/audio01.rttm"} @@ -10,21 +10,21 @@ num_workers: 18 batch_size: 8 model: - pil_weight: 0.5 - ats_weight: 0.5 - num_workers: ${num_workers} - fc_d_model: 512 - tf_d_model: 192 - max_num_of_spks: 4 # Number of speakers per model. This is currently fixed at 4. - session_len_sec: 90 + pil_weight: 0.5 # Weight for Permutation Invariant Loss (PIL) used in training the Sortformer diarizer model + ats_weight: 0.5 # Weight for Arrival Time Sort (ATS) loss in training the Sortformer diarizer model + num_workers: ${num_workers} # Number of workers for data loading + fc_d_model: 512 # Hidden dimension size of the Fast-conformer Encoder + tf_d_model: 192 # Hidden dimension size of the Transformer Encoder + max_num_of_spks: 4 # Maximum number of speakers per model; currently set to 4 + session_len_sec: 90 # Maximum session length in seconds train_ds: manifest_filepath: ??? sample_rate: ${sample_rate} num_spks: ${model.max_num_of_spks} session_len_sec: ${model.session_len_sec} - soft_label_thres: 0.5 - soft_targets: False + soft_label_thres: 0.5 # Threshold for binarizing target values; higher values make the model more conservative in predicting speaker activity. + soft_targets: False # If True, use continuous values as target values when calculating cross-entropy loss labels: null batch_size: ${batch_size} shuffle: True @@ -52,7 +52,7 @@ model: sample_rate: ${sample_rate} num_spks: ${model.max_num_of_spks} session_len_sec: ${model.session_len_sec} - soft_label_thres: 0.5 + soft_label_thres: 0.5 # A threshold value for setting up the binarized labels. The higher the more conservative the model becomes. soft_targets: False labels: null batch_size: ${batch_size} @@ -121,10 +121,8 @@ model: subsampling_factor: 8 # must be power of 2 for striding and vggnet subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model causal_downsampling: false - # Feed forward module's params ff_expansion_factor: 4 - # Multi-headed Attention Module's params self_attention_model: rel_pos # rel_pos or abs_pos n_heads: 8 # may need to be lower for smaller d_models @@ -134,19 +132,16 @@ model: xscaling: true # scales up the input embeddings by sqrt(d_model) untie_biases: true # unties the biases of the TransformerXL layers pos_emb_max_len: 5000 - # Convolution module's params conv_kernel_size: 9 conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) conv_context_size: null - - ### regularization + # Regularization dropout: 0.1 # The dropout used in most of the Conformer Modules dropout_pre_encoder: 0.1 # The dropout used before the encoder dropout_emb: 0.0 # The dropout used for embeddings dropout_att: 0.1 # The dropout for multi-headed attention modules - - # set to non-zero to enable stochastic depth + # Set to non-zero to enable stochastic depth stochastic_depth_drop_prob: 0.0 stochastic_depth_mode: linear # linear or uniform stochastic_depth_start_layer: 1 diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml index 3733e1285b77..ebed4a649730 100644 --- a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml @@ -5,10 +5,6 @@ # These parameters were optimized on the development split of DIHARD3 dataset. See https://arxiv.org/pdf/2012.01477. # Trial 24682 finished with value: 0.10257785779242055 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.01, 'min_duration_on': 0.42, 'min_duration_off': 0.34}. Best is trial 24682 with value: 0.10257785779242055. parameters: - window_length_in_sec: 0.0 # Not used - shift_length_in_sec: 0.0 # Not used - smoothing: False # Not used - overlap: 0.5 # Not used onset: 0.53 # Onset threshold for detecting the beginning and end of a speech offset: 0.49 # Offset threshold for detecting the end of a speech pad_onset: 0.23 # Adding durations before each speech segment diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml index 275bc86db4cd..9beaff6e3c7c 100644 --- a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml @@ -5,10 +5,6 @@ # These parameters were optimized on CallHome Dataset from the NIST SRE 2000 Disc8, especially from the split2 specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/tree/master/egs/callhome_diarization/v2. # Trial 732 finished with value: 0.12171946949255649 and parameters: {'onset': 0.64, 'offset': 0.74, 'pad_onset': 0.06, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.15}. Best is trial 732 with value: 0.12171946949255649. parameters: - window_length_in_sec: 0.0 # Not used - shift_length_in_sec: 0.0 # Not used - smoothing: False # Not used - overlap: 0.5 # Not used onset: 0.64 # Onset threshold for detecting the beginning and end of a speech offset: 0.74 # Offset threshold for detecting the end of a speech pad_onset: 0.06 # Adding durations before each speech segment diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 98f2ee10e523..a2dcd15dbb71 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -47,17 +47,12 @@ @dataclass class PostProcessingParams: - window_length_in_sec: float = 0.15 - shift_length_in_sec: float = 0.01 - smoothing: bool = False - overlap: float = 0.5 - onset: float = 0.5 - offset: float = 0.5 - pad_onset: float = 0.0 - pad_offset: float = 0.0 - min_duration_on: float = 0.0 - min_duration_off: float = 0.0 - filter_speech_first: bool = True + onset: float = 0.5 # Onset threshold for detecting the beginning and end of a speech + offset: float = 0.5 # Offset threshold for detecting the end of a speech + pad_onset: float = 0.0 # Adding durations before each speech segment + pad_offset: float = 0.0 # Adding durations after each speech segment + min_duration_on: float = 0.0 # Threshold for small non-speech deletion + min_duration_off: float = 0.0 # Threshold for short speech segment deletion @dataclass class DiarizationConfig: @@ -124,7 +119,9 @@ def load_postprocessing_from_yaml(postprocessing_yaml): def optuna_suggest_params(postprocessing_cfg: PostProcessingParams, trial: optuna.Trial) -> PostProcessingParams: """ Suggests hyperparameters for postprocessing using Optuna. - + See the following link for `trial` instance in Optuna framework. + https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial + Args: postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. trial (optuna.Trial): The Optuna trial object used to suggest hyperparameters. @@ -373,13 +370,13 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: out_rttm_dir=cfg.out_rttm_dir ) logging.info(f"Evaluating the model on the {len(diar_model_preds_total_list)} audio segments...") - metric, mapping_dict, itemized_errors = score_labels(AUDIO_RTTM_MAP=infer_audio_rttm_dict, - all_reference=all_refs, - all_hypothesis=all_hyps, - all_uem=all_uems, - collar=cfg.collar, - ignore_overlap=cfg.ignore_overlap - ) + score_labels(AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=cfg.collar, + ignore_overlap=cfg.ignore_overlap + ) logging.info(f"PostProcessingParams: {postprocessing_cfg}") if __name__ == '__main__': diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index 31194d8849f0..e85500593656 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -19,8 +19,8 @@ EncDecClassificationModel, EncDecFrameClassificationModel, ) -from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer +from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel @@ -36,5 +36,9 @@ from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel -from nemo.collections.asr.models.ssl_models import SpeechEncDecSelfSupervisedModel -from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE +from nemo.collections.asr.models.ssl_models import ( + EncDecDenoiseMaskedTokenPredModel, + EncDecMaskedTokenPredModel, + SpeechEncDecSelfSupervisedModel, +) +from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE \ No newline at end of file diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index 50cdf6214d5b..7b2b5cf17793 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -36,17 +36,6 @@ from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_pil_targets, get_ats_targets from nemo.utils import logging -try: - from torch.cuda.amp import autocast -except ImportError: - from contextlib import contextmanager - - @contextmanager - def autocast(enabled=None): - yield - -# torch.backends.cudnn.enabled = False - __all__ = ['SortformerEncLabelModel'] class SortformerEncLabelModel(ModelPT, ExportableEncDecModel): @@ -549,14 +538,13 @@ def test_batch(self,): audio_signal_length=audio_signal_length, ) preds = preds.detach().to('cpu') - if preds.shape[0] == 1: # batch size = 1 + if preds.shape[0] == 1: # If batch size is absolute 1 self.preds_total_list.append(preds) else: self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0])) torch.cuda.empty_cache() self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens) - # except: - # import ipdb; ipdb.set_trace() + logging.info(f"Batch F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_list))}") logging.info(f"Batch Precision MEAN: {torch.mean(torch.tensor(self.batch_precision_list))}") logging.info(f"Batch Recall MEAN: {torch.mean(torch.tensor(self.batch_recall_list))}") diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py index 823cf98590e7..1805327ab69b 100644 --- a/nemo/collections/asr/modules/sortformer_modules.py +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -20,7 +20,6 @@ from nemo.core.classes.exportable import Exportable from nemo.core.classes.module import NeuralModule -from nemo.core.neural_types import EncodedRepresentation, LengthsType, NeuralType, SpectrogramType from nemo.core.neural_types.elements import ProbsType __all__ = ['SortformerModules'] @@ -37,23 +36,12 @@ class SortformerModules(NeuralModule, Exportable): Max number of speakers that are processed by the model. In `MSDD_module`, `num_spks=2` for pairwise inference. hidden_size (int): Number of hidden units in sequence models and intermediate layers. - num_lstm_layers (int): - Number of the stacked LSTM layers. dropout_rate (float): Dropout rate for linear layers, CNN and LSTM. + fc_d_model (int): + Dimension of the embedding vectors. tf_d_model (int): Dimension of the embedding vectors. - scale_n (int): - Number of scales in multi-scale system. - clamp_max (float): - Maximum value for limiting the scale weight values. - conv_repeat (int): - Number of CNN layers after the first CNN layer. - weighting_scheme (str): - Name of the methods for estimating the scale weights. - context_vector_type (str): - If 'cos_sim', cosine similarity values are used for the input of the sequence models. - If 'elem_prod', element-wise product values are used for the input of the sequence models. """ def init_weights(self, m): if type(m) == nn.Linear: diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index a1d34e1f7480..fed55730e7f1 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -12,29 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import re -import copy import math -import random -import logging -import itertools -from copy import deepcopy -import concurrent.futures -from cytoolz import groupby -from collections import defaultdict -from typing import Dict, Optional, Tuple, List - -import numpy as np -import soundfile +import torch from tqdm import tqdm -from scipy.stats import norm - -import torch.utils.data -from lhotse.cut.set import mix -from lhotse.cut import CutSet, MixedCut, MonoCut, MixTrack -from lhotse import SupervisionSet, SupervisionSegment, dill_enabled, AudioSource, Recording -from lhotse.utils import uuid4 +from lhotse import SupervisionSet +from lhotse.cut import MixedCut, MonoCut def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres:float = 0.5) -> torch.Tensor: """ @@ -163,7 +145,6 @@ def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutati torch.Tensor: A tensor of permuted labels that best match the predictions. Shape: (batch_size, num_speakers, num_classes) """ - perm_size = speaker_permutations.shape[0] # Scalar value (num_permutations) permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_classes, num_permutations, num_speakers) # Repeat preds to match permutations for comparison preds_rep = torch.unsqueeze(preds, 2).repeat(1, 1, speaker_permutations.shape[0], 1) # (batch_size, num_speakers, num_permutations, num_classes) @@ -173,65 +154,6 @@ def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutati max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_speakers, num_classes) return max_score_permed_labels # (batch_size, num_speakers, num_classes) -def apply_spk_mapping(diar_preds: torch.Tensor, spk_mappings: torch.Tensor) -> torch.Tensor: - """ - Applies a speaker mapping to diar predictions. - - Args: - diar_preds (Tensor): The diar predictions tensor. - Dimension: (batch_size, num_frames, num_speakers) - spk_mappings (Tensor): The speaker mappings tensor. - Dimension: (batch_size, num_speakers) - - Returns: - permuted_diar_preds (Tensor): The permuted diar predictions tensor with the given speaker mappings. - """ - expanded_mappings = spk_mappings.unsqueeze(1).expand(-1, diar_preds.size(1), -1) - permuted_diar_preds = torch.gather(diar_preds, 2, expanded_mappings) - return permuted_diar_preds - -def shuffle_spk_mapping(cuts: list, num_speakers: int, shuffle_spk_mapping: bool = False, pattern= r'<\|spltoken\d+\|>') -> Tuple[CutSet, torch.Tensor]: - """ - Applies a shuffle mapping to speaker text labels in the cuts. - Example: - Original cut.text: - "<|spltoken0|> we do shuffle <|spltoken1|> and map speakers <|spltoken0|> yes <|spltoken2|> we keep dimensions" - Speaker Mapping: [3, 0, 1, 2] - Shuffled cut.text: - "<|spltoken3|> we do shuffle <|spltoken0|> and map speakers <|spltoken3|> yes <|spltoken1|> we keep dimensions" - - Args: - cuts (List[MonoCut, MixedCut]): A list of Cut instances. - num_speakers (int): The total number of speakers. - shuffle_spk_mapping (bool): Whether to shuffle the speaker mappings. - pattern (str): A regular expression pattern for speaker tokens. - - Returns: - cuts (list): The updated CutSet with shuffled speaker mappings. - spk_mappings (Tensor): - If shuffle_speaker_mapping is True, shuffled speaker mappings in batch. - If shuffle_speaker_mapping is False, speaker mappings in batch is not permuted and returns torch.arange() values. - """ - batch_size = len(cuts) - if shuffle_spk_mapping: - permuted_indices = torch.rand(batch_size, num_speakers).argsort(dim=1) - spk_mappings = torch.gather(torch.arange(num_speakers).repeat(batch_size, 1), 1, permuted_indices) - str_pattern = pattern.replace("\\", '') - left_str, right_str = str_pattern.split('d+')[0], str_pattern.split('d+')[1] - for idx, cut in enumerate(cuts): - word_list = [] - for word in deepcopy(cut.text).split(): - if len(re.findall(pattern, word)) > 0: - spk_token_int = int(word.replace(left_str,'').replace(right_str, '')) - new_spk = spk_mappings[idx][spk_token_int] - word_list.append(f'{left_str}{new_spk}{right_str}') - else: - word_list.append(word) - cuts[idx].supervisions[0].text = ' '.join(word_list) - else: - spk_mappings = torch.arange(num_speakers).unsqueeze(0).repeat(batch_size, 1) - return cuts, spk_mappings - def find_segments_from_rttm( recording_id: str, rttms, @@ -268,91 +190,6 @@ def find_segments_from_rttm( and segment.end > start_after + tolerance ] -def speaker_to_target( - a_cut, - num_speakers: int = 4, - num_sample_per_mel_frame: int = 160, - num_mel_frame_per_asr_frame: int = 8, - spk_tar_all_zero: bool = False, - boundary_segments: bool = False, - soft_label: bool = False, - ignore_num_spk_mismatch: bool = True, - soft_thres: float = 0.5, - ): - ''' - Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape (num_speaker, hidden_length) - This function is needed for speaker diarization with ASR model trainings. - - Args: - a_cut (MonoCut, MixedCut): Lhotse Cut instance which is MonoCut or MixedCut instance. - num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default - num_sample_per_mel_frame (int): number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) - num_mel_frame_per_asr_frame (int): encoder subsampling_factor, 8 by default - spk_tar_all_zero (Tensor): set to True gives all zero "mask" - boundary_segments (bool): set to True to include segments containing the boundary of the cut, False by default for multi-speaker ASR training - soft_label (bool): set to True to use soft label that enables values in [0, 1] range, False by default and leads to binary labels. - ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. - - Returns: - mask (Tensor): speaker mask with shape (num_speaker, hidden_lenght) - ''' - # get cut-related segments from rttms - # basename = os.path.basename(a_cut.rttm_filepath).replace('.rttm', '') - if isinstance(a_cut, MixedCut): - cut_list = [track.cut for track in a_cut.tracks if isinstance(track.cut, MonoCut)] - offsets = [track.offset for track in a_cut.tracks if isinstance(track.cut, MonoCut)] - elif isinstance(a_cut, MonoCut): - cut_list = [a_cut] - offsets = [0] - else: - raise ValueError(f"Unsupported cut type type{a_cut}: only MixedCut and MonoCut are supported") - - segments_total = [] - for i, cut in enumerate(cut_list): - rttms = SupervisionSet.from_rttm(cut.rttm_filepath) - if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included - segments_iterator = find_segments_from_rttm(recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0) - else: # segments with seg_start > total_start and seg_end < total_end are included - segments_iterator = rttms.find(recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True) - - for seg in segments_iterator: - if seg.start < 0: - seg.duration += seg.start - seg.start = 0 - if seg.end > cut.duration: - seg.duration -= seg.end - cut.duration - seg.start += offsets[i] - segments_total.append(seg) - - # apply arrival time sorting to the existing segments - segments_total.sort(key = lambda rttm_sup: rttm_sup.start) - - seen = set() - seen_add = seen.add - speaker_ats = [s.speaker for s in segments_total if not (s.speaker in seen or seen_add(s.speaker))] - - speaker_to_idx_map = { - spk: idx - for idx, spk in enumerate(speaker_ats) - } - if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers - raise ValueError(f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers {num_speakers}") - - # initialize mask matrices (num_speaker, encoder_hidden_len) - feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default - num_samples = get_hidden_length_from_sample_length(a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame) - if spk_tar_all_zero: - frame_mask = torch.zeros((num_samples, num_speakers)) - else: - frame_mask = get_mask_from_segments(segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch) - soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) - - if soft_label: - mask = soft_mask - else: - mask = (soft_mask > soft_thres).float() - - return mask def get_mask_from_segments(segments: list, a_cut, speaker_to_idx_map: torch.Tensor, num_speakers: int =4, feat_per_sec: int=100, ignore_num_spk_mismatch: bool = False): """ @@ -439,793 +276,88 @@ def get_hidden_length_from_sample_length( hidden_length = math.ceil(mel_frame_count / num_mel_frame_per_asr_frame) return int(hidden_length) -class ConcatenationMeetingSimulator(): - """ - This simulator concatenates the segments from different/same sessions to create a - multi-speaker meeting. - """ - - def __init__( - self, - intra_session_concat_prob: float|List[float] = [0, 1.0, 0.5, 0.2], - data_type: str = "msasr", - min_duration: float = 30.0, - max_duration: float = 40.0, - max_num_speakers: int = 4, - speaker_count_distribution: List[float] = [0, 2, 3, 4], - skip_long_segments: bool = True, - valid_dataset_ids: List[str] = [], - ): - """ - :param intra_session_concat_prob: the probability of concatenating segments from the same - session. [Default: 1] - :param data_type: the type of data to simulate. Either 'msasr' or 'diar'. If 'msasr', - the transcripts are included in the simulation,and the boundary segments are - not included. [Default: 'msasr'] - :param max_duration: the maximum duration of the simulated meeting. [Default: 40.0] - """ - super().__init__() - if isinstance(intra_session_concat_prob, float): - self.intra_session_concat_prob = [intra_session_concat_prob] * (max_num_speakers) - elif len(intra_session_concat_prob) == max_num_speakers: - self.intra_session_concat_prob = intra_session_concat_prob - else: - raise ValueError(f"intra_session_concat_prob must be either a float or a list of floats, but got {intra_session_concat_prob}") - if data_type not in ["msasr", "diar"]: - raise ValueError("data_type must be either 'msasr' or 'diar', but got {data_type}") - self.data_type = data_type - self.min_duration = min_duration - self.max_duration = max_duration - self.max_num_speakers = max_num_speakers - self.speaker_count_distribution = speaker_count_distribution - assert len(speaker_count_distribution) == max_num_speakers, f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" - - if skip_long_segments: - self.skip_duration = max_duration / 2 - else: - self.skip_duration = max_duration - - self.valid_dataset_ids = valid_dataset_ids - - def fit(self, cuts) -> CutSet: - """ - Read the manifest file and return a CutSet object. - Each line in the manifest file should be a JSON object representing a segment. - """ - - self.id2cut = {} - self.sess2cut_ids = defaultdict(list) - self.sess2spks = defaultdict(set) - self.data2sess_ids = defaultdict(list) - self.spk2cut_ids = defaultdict(list) - self.data2num_spk2cut_ids = {} - self.sess2num_spk2cut_ids = {} - self.num_spk2cut_ids = {i+1:[] for i in range(self.max_num_speakers)} - for i, cut in tqdm(enumerate(cuts), desc="Reading segments", ncols=100, total=len(cuts)): - if cut.duration > self.skip_duration: - continue - if not hasattr(cut, 'dataset_id') or cut.dataset_id is None: - continue - if self.valid_dataset_ids and cut.dataset_id not in self.valid_dataset_ids: - continue - if cut.dataset_id not in self.data2num_spk2cut_ids: - self.data2num_spk2cut_ids[cut.dataset_id] = defaultdict(list) - if cut.recording_id not in self.sess2num_spk2cut_ids: - self.sess2num_spk2cut_ids[cut.recording_id] = defaultdict(list) - - speakers = cut.global_speaker_ids - if self.data_type == "msasr": - speaker_tokens = set(re.findall(r'<\|spltoken\d+\|>', cut.text)) - if len(speakers) != len(speaker_tokens): - # Lhotse automatically fixes the max duration of the cut, - # resulting in the mismatch of the number of speakers - # and speaker tokens for the last segment - # TODO: need to fix the issue in Lhotse that automatically fixes the max duration - continue - for spk in speakers: - self.spk2cut_ids[spk].append(cut.id) - self.sess2spks[cut.recording_id] = self.sess2spks[cut.recording_id].union(speakers) - - self.id2cut[cut.id] = cut - self.sess2cut_ids[cut.recording_id].append(cut.id) - self.data2num_spk2cut_ids[cut.dataset_id][len(speakers)].append(cut.id) - self.sess2num_spk2cut_ids[cut.recording_id][len(speakers)].append(cut.id) - self.num_spk2cut_ids[len(speakers)].append(cut.id) - if cut.recording_id not in self.data2sess_ids[cut.dataset_id]: - self.data2sess_ids[cut.dataset_id].append(cut.recording_id) - - self.cut_ids = list(self.id2cut.keys()) - self.num_spk2sess_ids = groupby(lambda x: len(self.sess2spks[x]), self.sess2spks.keys()) - - self.data2global_speaker = { - dataset_id: True for dataset_id in self.data2sess_ids.keys() - } - - def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> MixedCut: - - db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data - - if is_intra_session_concat: - # intra-dataset and intra-session concatenation - tracks, num_speakers = self.get_intra_session_tracks(n_speakers, db_norm=db_norm) - - else: - # intra-dataset but inter-session concatenation - tracks, num_speakers = self.get_inter_session_tracks(n_speakers, db_norm=db_norm) - - cut = MixedCut(id='concat_' + '_'.join([track.cut.id for track in tracks]), tracks=tracks) - if self.data_type == "msasr": - cut = self.reorder_spk_mapping(cut) - - assert self.min_duration <= cut.duration <= self.max_duration, f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" - assert n_speakers == num_speakers, f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" - - return cut - - def get_intra_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: - """ - Get the tracks for the MixedCut object. - """ - session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - - total_duration = 0.0 - total_spk_set = set() - tracks = [] - while True: - cut = self.id2cut[random.choice(self.sess2cut_ids[session_id])] - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=total_duration)) - total_spk_set = total_spk_set.union(cut.global_speaker_ids) - total_duration += cut.duration - - # break condition - if total_duration >= self.min_duration: - if total_duration > self.max_duration: # exceed the maximum duration, starting over - total_duration = 0.0 - total_spk_set = set() - tracks = [] - session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break - break - else: - total_duration = 0.0 - total_spk_set = set() - tracks = [] - session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - - return tracks, len(total_spk_set) - - def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: - """ - Get the tracks for the MixedCut object. - """ - sample_cut = self.id2cut[random.choice(self.cut_ids)] - dataset_id = sample_cut.dataset_id - n_spk_list = [n_spk for n_spk, cut_ids in self.data2num_spk2cut_ids[dataset_id].items() if len(cut_ids) > 0] - sum_spk_list = set([i + j for i in n_spk_list for j in n_spk_list]) - - if min(sum_spk_list) > n_speakers: - raise ValueError(f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers.") - - n_spk_left = n_speakers - total_duration = 0.0 - total_spk_set = set() - tracks = [] - num_spk2cut_ids = self.data2num_spk2cut_ids[dataset_id] - while True: - #if n_spk_left == n_speakers: # for more speakers cases - # n_spk = random.choice([n_spk for n_spk in n_spk_list if n_spk < n_spk_left]) - if n_spk_left >= 2: - n_spk = 2 - else: - # n_spk = random.choice([n_spk for n_spk in n_spk_list if n_spk <= n_spk_left]) - n_spk = 1 - - while True: - cut = self.id2cut[random.choice(num_spk2cut_ids[n_spk])] - spks = set(cut.global_speaker_ids) - if not spks.intersection(total_spk_set): - break - - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=total_duration)) - total_duration += cut.duration - n_spk_left -= n_spk - total_spk_set = total_spk_set.union(spks) - - # break condition - - if total_duration >= self.min_duration: - if total_duration > self.max_duration or len(total_spk_set) < n_speakers: # exceed the maximum duration, starting over - total_duration = 0.0 - n_spk_left = n_speakers - total_spk_set = set() - tracks = [] - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break - break - else: - if len(total_spk_set) == n_speakers: # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers - total_duration = 0.0 - n_spk_left = n_speakers - total_spk_set = set() - tracks = [] - - return tracks, len(total_spk_set) - - def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> str: - """ - Concatenate the texts of the input cuts. - - """ - global_spk_mapping = {} - str_pattern = pattern.replace("\\", '') - left_str, right_str = str_pattern.split('d+') - for i, track in enumerate(cut.tracks): - local_inverse_spk_mapping = {} - local_spk_mapping = {} - for speaker in track.cut.global_speaker_ids: - if speaker not in global_spk_mapping: - global_spk_mapping[speaker] = len(global_spk_mapping) - if speaker not in local_spk_mapping: - local_spk_mapping[speaker] = len(local_spk_mapping) - local_inverse_spk_mapping[len(local_inverse_spk_mapping)] = speaker - - if i != 0: - text = '' - for word in track.cut.text.split(): - if len(re.findall(pattern, word)) > 0: - local_spk_idx = int(word.replace(left_str,'').replace(right_str, '')) - spk = local_inverse_spk_mapping[local_spk_idx] - global_spk_idx = global_spk_mapping[spk] - text += f'{left_str}{global_spk_idx}{right_str}' - else: - text += ' ' + word - track.cut.supervisions[0].text = text - cut.supervisions[i].text = text - else: - cut.supervisions[0].text = track.cut.text - # TODO: need to check the last speaker of last track and the first speaker of the current track - # if they are the same, we need to remove the the speaker token from the current track for segment-level - # Do not need to remove the speaker token for word-level - - return cut - - def apply_speaker_distribution(self, num_meetings: int, speaker_count_distribution) -> Dict[int, int]: - """ - Balance the speaker distribution for the simulated meetings. - Args: - num_meetings: The total number of simulated meetings. - speaker_count_distribution: The speaker count distribution for the simulated meetings. - For each number of speakers, calculate the number of meetings needed to balance the distribution. - """ - - total_spk = sum(speaker_count_distribution) - num_speakers2num_meetings = {} - for i_spk in range(self.max_num_speakers): - num_speakers2num_meetings[i_spk+1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) - - return num_speakers2num_meetings - - - @dill_enabled(True) - def simulate(self, - cuts: CutSet, - num_meetings: int = 10000, - seed: int = 0, - num_jobs: int = 1, - ) -> CutSet: - random.seed(seed) - - self.fit(cuts) - - - num_speakers2num_meetings = self.apply_speaker_distribution(num_meetings, self.speaker_count_distribution) - logging.warn(f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}.") - num_speakers2num_meetings[1] = 0 # skip 1-speaker samples - logging.warn(f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.') - - # Step 0: Calculate the number of intra-session and inter-session concatentation samples - n_spks = [k for k, v in self.num_spk2cut_ids.items() if len(v) > 0] - valid_sim_n_spks = set([i+j for i in n_spks for j in n_spks]) # valid number of speakers for inter-session samples - n_spk2n_intra_mt, n_spk2n_inter_mt = {i+1:0 for i in range(self.max_num_speakers)}, {i+1:0 for i in range(self.max_num_speakers)} - for n_spk, n_mt in num_speakers2num_meetings.items(): - logging.warn(f"=="*16 + f"{n_spk}-speaker" + "=="*16) - if n_mt <= 0: - logging.warning(f"No concatentation samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") - continue - n_intra_mt = int(n_mt * self.intra_session_concat_prob[n_spk-1]) - n_inter_mt = n_mt - n_intra_mt - if n_spk in self.num_spk2sess_ids: - logging.warn(f"Will be genrating {n_intra_mt} {n_spk}-speaker intra-session concatentation samples.") - n_spk2n_intra_mt[n_spk] = n_intra_mt - else: - logging.warning(f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") - n_spk2n_intra_mt[n_spk] = 0 - n_inter_mt = n_mt - if n_spk in valid_sim_n_spks: - logging.warn(f"Will be genrating {n_inter_mt} {n_spk}-speaker inter-session concatentation samples.") - n_spk2n_inter_mt[n_spk] = n_inter_mt - else: - logging.warning(f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") - if n_spk2n_intra_mt[n_spk] != 0: - n_spk2n_intra_mt[n_spk] = n_mt - logging.warn(f"Will be genrating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatentation samples instead.") - else: - logging.warning(f"No samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") - logging.warn(f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatentation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatentation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""") - # Step 1: intra-session - num_intra_meetings = 0 - intra_mixtures = [] - logging.info(f"Simulating intra-session concatentation samples.") - for n_spk, n_mt in n_spk2n_intra_mt.items(): - if n_mt <= 0: - continue - - for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker intra-session mixtures", ncols=128): - intra_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=True)) - num_intra_meetings += n_mt - logging.info(f"Finished simulating intra-session concatentation samples. Total number of intra-session concatentation samples: {num_intra_meetings}") - - # Steo 2: inter-session - logging.info(f"Simulating inter-session concatentation samples.") - - num_inter_meetings = 0 - inter_mixtures = [] - for n_spk, n_mt in n_spk2n_inter_mt.items(): - if n_mt <= 0: - continue - - for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker inter-session mixtures", ncols=128): - inter_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=False)) - num_inter_meetings += n_mt - logging.info(f"Finished simulating inter-session concatentation samples. Total number of inter-session concatentation samples: {num_inter_meetings}") - - if num_inter_meetings + num_intra_meetings == 0: - logging.warning(f"No samples are generated. Probably the duration of the segments is not within the range of min {self.min_duration//2} and max {self.max_duration//2}, or the speaker count distribution is not correctly set.") - - - # Multi-processing gets slower, TODO - # else: - # futures = [] - # for n_spk, n_mt in num_speakers2num_meetings.items(): - # tp = concurrent.futures.ProcessPoolExecutor(max_workers=num_jobs) - # futures.extend([tp.submit(self._create_mixture, n_spk) for _ in range(n_mt)]) - # pbar = tqdm(total=num_meetings, desc=f"Simulating mixtures", unit="line", ncols=128) - # count = 0 - # for f in concurrent.futures.as_completed(futures): - # count += 1 - # pbar.update() - # mixtures.append(f.result()) - # tp.shutdown() - # pbar.close() - - return CutSet.from_cuts(intra_mixtures + inter_mixtures) - - -class MixMeetingSimulator(): - """ - This simulator Mix the segments from different/same sessions to create a - multi-speaker meeting. - """ - - def __init__( - self, - intra_session_mix_prob: float|List[float] = [0, 0, 0, 0], - data_type: str = "msasr", - min_duration: float = 80.0, - max_duration: float = 100.0, - max_num_speakers: int = 4, - speaker_count_distribution: List[float] = [0, 0, 0.1, 4], - valid_dataset_ids: List[str] = [], +def speaker_to_target( + a_cut, + num_speakers: int = 4, + num_sample_per_mel_frame: int = 160, + num_mel_frame_per_asr_frame: int = 8, + spk_tar_all_zero: bool = False, + boundary_segments: bool = False, + soft_label: bool = False, + ignore_num_spk_mismatch: bool = True, + soft_thres: float = 0.5, ): - """ - :param intra_session_mix_prob: the probability of concatenating segments from the same - session. [Default: 1] - :param data_type: the type of data to simulate. Either 'msasr' or 'diar'. If 'msasr', - the transcripts are included in the simulation,and the boundary segments are - not included. [Default: 'msasr'] - :param max_duration: the maximum duration of the simulated meeting. [Default: 40.0] - """ - super().__init__() - if isinstance(intra_session_mix_prob, float): - self.intra_session_mix_prob = [intra_session_mix_prob] * (max_num_speakers) - elif len(intra_session_mix_prob) == max_num_speakers: - self.intra_session_mix_prob = intra_session_mix_prob - else: - raise ValueError(f"intra_session_mix_prob must be either a float or a list of floats, but got {intra_session_mix_prob}") - if data_type not in ["msasr", "diar"]: - raise ValueError("data_type must be either 'msasr' or 'diar', but got {data_type}") - self.data_type = data_type - self.min_duration = min_duration - self.max_duration = max_duration - self.max_num_speakers = max_num_speakers - self.speaker_count_distribution = speaker_count_distribution - self.valid_dataset_ids = valid_dataset_ids - assert len(speaker_count_distribution) == max_num_speakers, f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" - - def fit(self, cuts) -> CutSet: - """ - Read the manifest file and return a CutSet object. - Each line in the manifest file should be a JSON object representing a segment. - """ - - self.id2cut = {} - self.sess2cut_ids = defaultdict(list) - self.sess2spks = defaultdict(set) - self.data2sess_ids = defaultdict(list) - self.spk2cut_ids = defaultdict(list) - self.data2num_spk2cut_ids = {} - self.sess2num_spk2cut_ids = {} - self.num_spk2cut_ids = {i+1:[] for i in range(self.max_num_speakers)} - for i, cut in tqdm(enumerate(cuts), desc="Reading segments", ncols=100, total=len(cuts)): - if not self.min_duration <= cut.duration <= self.max_duration: - continue - if not hasattr(cut, 'dataset_id') or cut.dataset_id is None: - continue - if self.valid_dataset_ids and cut.dataset_id not in self.valid_dataset_ids: - continue - if cut.dataset_id not in self.data2num_spk2cut_ids: - self.data2num_spk2cut_ids[cut.dataset_id] = defaultdict(list) - if cut.recording_id not in self.sess2num_spk2cut_ids: - self.sess2num_spk2cut_ids[cut.recording_id] = defaultdict(list) - - speakers = cut.global_speaker_ids - if self.data_type == "msasr": - speaker_tokens = set(re.findall(r'<\|spltoken\d+\|>', cut.text)) - if len(speakers) != len(speaker_tokens): - # Lhotse automatically fixes the max duration of the cut, - # resulting in the mismatch of the number of speakers - # and speaker tokens for the last segment - # TODO: need to fix the issue in Lhotse that automatically fixes the max duration - continue - for spk in speakers: - self.spk2cut_ids[spk].append(cut.id) - self.sess2spks[cut.recording_id] = self.sess2spks[cut.recording_id].union(speakers) - - self.id2cut[cut.id] = cut - self.sess2cut_ids[cut.recording_id].append(cut.id) - self.data2num_spk2cut_ids[cut.dataset_id][len(speakers)].append(cut.id) - self.sess2num_spk2cut_ids[cut.recording_id][len(speakers)].append(cut.id) - self.num_spk2cut_ids[len(speakers)].append(cut.id) - if cut.recording_id not in self.data2sess_ids[cut.dataset_id]: - self.data2sess_ids[cut.dataset_id].append(cut.recording_id) - - self.cut_ids = list(self.id2cut.keys()) - self.num_spk2sess_ids = groupby(lambda x: len(self.sess2spks[x]), self.sess2spks.keys()) - - self.data2global_speaker = { - dataset_id: True for dataset_id in self.data2sess_ids.keys() - } - - def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> MixedCut: - - db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data - - if is_intra_session_concat: - # intra-dataset and intra-session concatenation - tracks, num_speakers = self.get_intra_session_tracks(n_speakers, db_norm=db_norm) - - else: - # intra-dataset but inter-session concatenation - tracks, num_speakers = self.get_inter_session_tracks(n_speakers, db_norm=db_norm) - - cut = MixedCut(id='mix_' + '_'.join([track.cut.id for track in tracks]), tracks=tracks) - if self.data_type == "msasr": - cut = self.reorder_spk_mapping(cut) - - assert self.min_duration <= cut.duration <= self.max_duration, f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" - assert n_speakers == num_speakers, f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" - - return cut - - def get_intra_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: - """ - Get the tracks for the MixedCut object. - """ - session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - - total_spk_set = set() - tracks = [] - while True: - cut = self.id2cut[random.choice(self.sess2cut_ids[session_id])] - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0)) - total_spk_set = total_spk_set.union(cut.global_speaker_ids) - total_duration = max(total_duration, cut.duration) - - # break condition - if total_duration >= self.min_duration: - if total_duration > self.max_duration: # exceed the maximum duration, starting over - total_duration = 0.0 - total_spk_set = set() - tracks = [] - session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break - break - else: - total_duration = 0.0 - total_spk_set = set() - tracks = [] - session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - - return tracks, len(total_spk_set) - - def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: - """ - Get the tracks for the MixedCut object. - """ - sample_cut = self.id2cut[random.choice(self.cut_ids)] - dataset_id = sample_cut.dataset_id - n_spk_list = [n_spk for n_spk, cut_ids in self.data2num_spk2cut_ids[dataset_id].items() if len(cut_ids) > 0] - sum_spk_list = set([i + j for i in n_spk_list for j in n_spk_list]) - - if min(sum_spk_list) > n_speakers: - raise ValueError(f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers.") - - n_spk_left = n_speakers - total_duration = 0.0 - total_spk_set = set() - tracks = [] - num_spk2cut_ids = self.data2num_spk2cut_ids[dataset_id] - while True: - if n_spk_left >= 2: - n_spk = 2 - else: - # n_spk = random.choice([n_spk for n_spk in n_spk_list if n_spk <= n_spk_left]) - n_spk = 1 - - while True: - cut = self.id2cut[random.choice(num_spk2cut_ids[n_spk])] - spks = set(cut.global_speaker_ids) - if not spks.intersection(total_spk_set): - break - - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0)) - total_duration = max(total_duration, cut.duration) - n_spk_left -= n_spk - total_spk_set = total_spk_set.union(spks) + ''' + Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape (num_speaker, hidden_length) + This function is needed for speaker diarization with ASR model trainings. - # break condition - - if total_duration >= self.min_duration: - if total_duration > self.max_duration or len(tracks) > 2: # exceed the maximum duration, starting over - total_duration = 0.0 - n_spk_left = n_speakers - total_spk_set = set() - tracks = [] - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break - break - else: - if len(total_spk_set) == n_speakers: # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers - total_duration = 0.0 - n_spk_left = n_speakers - total_spk_set = set() - tracks = [] - - return tracks, len(total_spk_set) + Args: + a_cut (MonoCut, MixedCut): Lhotse Cut instance which is MonoCut or MixedCut instance. + num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default + num_sample_per_mel_frame (int): number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) + num_mel_frame_per_asr_frame (int): encoder subsampling_factor, 8 by default + spk_tar_all_zero (Tensor): set to True gives all zero "mask" + boundary_segments (bool): set to True to include segments containing the boundary of the cut, False by default for multi-speaker ASR training + soft_label (bool): set to True to use soft label that enables values in [0, 1] range, False by default and leads to binary labels. + ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. - def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> str: - """ - Concatenate the texts of the input cuts. - - """ - global_spk_mapping = {} - str_pattern = pattern.replace("\\", '') - left_str, right_str = str_pattern.split('d+') - for i, track in enumerate(cut.tracks): - local_inverse_spk_mapping = {} - local_spk_mapping = {} - for speaker in track.cut.global_speaker_ids: - if speaker not in global_spk_mapping: - global_spk_mapping[speaker] = len(global_spk_mapping) - if speaker not in local_spk_mapping: - local_spk_mapping[speaker] = len(local_spk_mapping) - local_inverse_spk_mapping[len(local_inverse_spk_mapping)] = speaker - - if i != 0: - text = '' - for word in track.cut.text.split(): - if len(re.findall(pattern, word)) > 0: - local_spk_idx = int(word.replace(left_str,'').replace(right_str, '')) - spk = local_inverse_spk_mapping[local_spk_idx] - global_spk_idx = global_spk_mapping[spk] - text += f'{left_str}{global_spk_idx}{right_str}' - else: - text += ' ' + word - track.cut.supervisions[0].text = text - cut.supervisions[i].text = text - else: - cut.supervisions[0].text = track.cut.text - # TODO: need to check the last speaker of last track and the first speaker of the current track - # if they are the same, we need to remove the the speaker token from the current track for segment-level - # Do not need to remove the speaker token for word-level - - return cut + Returns: + mask (Tensor): speaker mask with shape (num_speaker, hidden_lenght) + ''' + # get cut-related segments from rttms + # basename = os.path.basename(a_cut.rttm_filepath).replace('.rttm', '') + if isinstance(a_cut, MixedCut): + cut_list = [track.cut for track in a_cut.tracks if isinstance(track.cut, MonoCut)] + offsets = [track.offset for track in a_cut.tracks if isinstance(track.cut, MonoCut)] + elif isinstance(a_cut, MonoCut): + cut_list = [a_cut] + offsets = [0] + else: + raise ValueError(f"Unsupported cut type type{a_cut}: only MixedCut and MonoCut are supported") - def apply_speaker_distribution(self, num_meetings: int, speaker_count_distribution) -> Dict[int, int]: - """ - Balance the speaker distribution for the simulated meetings. - Args: - num_meetings: The total number of simulated meetings. - speaker_count_distribution: The speaker count distribution for the simulated meetings. - For each number of speakers, calculate the number of meetings needed to balance the distribution. - """ - - total_spk = sum(speaker_count_distribution) - num_speakers2num_meetings = {} - for i_spk in range(self.max_num_speakers): - num_speakers2num_meetings[i_spk+1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) + segments_total = [] + for i, cut in enumerate(cut_list): + rttms = SupervisionSet.from_rttm(cut.rttm_filepath) + if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included + segments_iterator = find_segments_from_rttm(recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0) + else: # segments with seg_start > total_start and seg_end < total_end are included + segments_iterator = rttms.find(recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True) - return num_speakers2num_meetings - + for seg in segments_iterator: + if seg.start < 0: + seg.duration += seg.start + seg.start = 0 + if seg.end > cut.duration: + seg.duration -= seg.end - cut.duration + seg.start += offsets[i] + segments_total.append(seg) - @dill_enabled(True) - def simulate(self, - cuts: CutSet, - num_meetings: int = 10000, - seed: int = 0, - num_jobs: int = 1, - ) -> CutSet: - random.seed(seed) - - self.fit(cuts) - - num_speakers2num_meetings = self.apply_speaker_distribution(num_meetings, self.speaker_count_distribution) - logging.warn(f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}.") - num_speakers2num_meetings[1] = 0 # skip 1-speaker samples - logging.warn(f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.') - - # Step 0: Calculate the number of intra-session and inter-session concatentation samples - n_spks = [k for k, v in self.num_spk2cut_ids.items() if len(v) > 0] - valid_sim_n_spks = set([i+j for i in n_spks for j in n_spks]) # valid number of speakers for inter-session samples - n_spk2n_intra_mt, n_spk2n_inter_mt = {i+1:0 for i in range(self.max_num_speakers)}, {i+1:0 for i in range(self.max_num_speakers)} - for n_spk, n_mt in num_speakers2num_meetings.items(): - logging.warn(f"=="*16 + f"{n_spk}-speaker" + "=="*16) - if n_mt <= 0: - logging.warning(f"No intra-session concatentation samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") - continue - n_intra_mt = int(n_mt * self.intra_session_mix_prob[n_spk-1]) - n_inter_mt = n_mt - n_intra_mt - if n_spk in self.num_spk2sess_ids: - logging.warn(f"Will be genrating {n_intra_mt} {n_spk}-speaker intra-session concatentation samples.") - n_spk2n_intra_mt[n_spk] = n_intra_mt - else: - logging.warning(f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") - n_spk2n_intra_mt[n_spk] = 0 - n_inter_mt = n_mt - if n_spk in valid_sim_n_spks: - logging.warn(f"Will be genrating {n_inter_mt} {n_spk}-speaker inter-session concatentation samples.") - n_spk2n_inter_mt[n_spk] = n_inter_mt - else: - logging.warning(f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") - if n_spk2n_intra_mt[n_spk] != 0: - n_spk2n_intra_mt[n_spk] = n_mt - logging.warn(f"Will be genrating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatentation samples instead.") - else: - logging.warning(f"No samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") - logging.warn(f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatentation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatentation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""") - # Step 1: intra-session - num_intra_meetings = 0 - intra_mixtures = [] - logging.info(f"Simulating intra-session concatentation samples.") - for n_spk, n_mt in n_spk2n_intra_mt.items(): - if n_mt <= 0: - continue + # apply arrival time sorting to the existing segments + segments_total.sort(key = lambda rttm_sup: rttm_sup.start) - for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker intra-session mixtures", ncols=128): - intra_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=True)) - num_intra_meetings += n_mt - logging.info(f"Finished simulating intra-session concatentation samples. Total number of intra-session concatentation samples: {num_intra_meetings}") - - # Steo 2: inter-session - logging.info(f"Simulating inter-session concatentation samples.") + seen = set() + seen_add = seen.add + speaker_ats = [s.speaker for s in segments_total if not (s.speaker in seen or seen_add(s.speaker))] + + speaker_to_idx_map = { + spk: idx + for idx, spk in enumerate(speaker_ats) + } + if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers + raise ValueError(f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers {num_speakers}") - num_inter_meetings = 0 - inter_mixtures = [] - for n_spk, n_mt in n_spk2n_inter_mt.items(): - if n_mt <= 0: - continue - - for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker inter-session mixtures", ncols=128): - inter_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=False)) - num_inter_meetings += n_mt - logging.info(f"Finished simulating inter-session concatentation samples. Total number of inter-session concatentation samples: {num_inter_meetings}") - - if num_inter_meetings + num_intra_meetings == 0: - logging.warning(f"No samples are generated. Probably the duration of the segments is not within the range of min {self.min_duration} and max {self.max_duration}, or the speaker count distribution is not correctly set.") - - return CutSet.from_cuts(intra_mixtures + inter_mixtures) - -class LibriSpeechMixSimulator(): - - def __init__( - self, - min_duration: float = 80.0, - max_duration: float = 100.0, - n_mix_speakers: List[int] = [1, 2, 3], - speaker_count_distribution: List[float] = [1, 1, 1], - ): - """ - :param min_duration: the minimum duration of the simulated meeting. [Default: 80.0] - :param max_duration: the maximum duration of the simulated meeting. [Default: 100.0] - """ - super().__init__() - self.min_duration = min_duration - self.max_duration = max_duration - self.n_mix_speakers = n_mix_speakers - self.speaker_count_distribution = speaker_count_distribution - assert len(speaker_count_distribution) == len(n_mix_speakers), f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {len(n_mix_speakers)}" - - def fit(self, cuts) -> CutSet: - pass - - def simulate(self, - cuts: CutSet, - num_meetings: int = 10000, - seed: int = 0, - num_jobs: int = 1, - ) -> CutSet: - random.seed(seed) - - cut_set = [] - for n_speakers, n_mt in zip(self.n_mix_speakers, self.speaker_count_distribution): - if n_mt <= 0: - continue - for i in tqdm(range(n_mt), desc=f"Simulating {n_speakers}-speaker mixtures", ncols=128): - cut_set.append(self._create_mixture(n_speakers=n_speakers)) - return CutSet.from_cuts(cut_set) - -class LibriSpeechMixGenerator(): - def __init__(self): - pass - - def generate(self, cuts): - cut_set = [] - for cut in tqdm(cuts): - offsets = cut.delays - durations = cut.durations - wavs = cut.wavs - texts = cut.texts - speakers = cut.speakers + # initialize mask matrices (num_speaker, encoder_hidden_len) + feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default + num_samples = get_hidden_length_from_sample_length(a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame) + if spk_tar_all_zero: + frame_mask = torch.zeros((num_samples, num_speakers)) + else: + frame_mask = get_mask_from_segments(segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch) + soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) - tracks = [] - for i, (offset, duration, wav, text, speaker) in enumerate(zip(offsets, durations, wavs, texts, speakers)): - wav_dur = soundfile.info(wav).duration - wav_samples = soundfile.info(wav).frames - custom = { - 'speaker': speaker, - 'text': text, - } - cut_1spk = MonoCut( - id=wav.split('/')[-1].replace('.wav', ''), - start=0, - duration=duration, - channel=0, - supervisions=[], - recording=Recording( - id=wav.split('/')[-1].replace('.wav', ''), - sources=[ - AudioSource( - type='file', - channels=[0], - source=wav - ) - ], - sampling_rate=16000, - num_samples=wav_samples, - duration=wav_dur - ), - custom=custom - ) + if soft_label: + mask = soft_mask + else: + mask = (soft_mask > soft_thres).float() - tracks.append(MixTrack(cut=cut_1spk, type=type(cut_1spk), offset=offset)) - sup = SupervisionSegment( - id=cut.id, - recording_id=cut.recording_id, - start=0, - duration=offset+wav_dur, - text=cut.text, - ) - tracks[0].cut.supervisions.append(sup) - cut_multi_spk = MixedCut(id=cut.id, tracks=tracks) - - cut_set.append(cut_multi_spk) - - return CutSet.from_cuts(cut_set) \ No newline at end of file + return mask \ No newline at end of file diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 80b3e1f918b8..492041162cff 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -21,7 +21,6 @@ from typing import Dict, List, Tuple, Union import numpy as np -from omegaconf import OmegaConf from omegaconf.listconfig import ListConfig import soundfile as sf import torch @@ -981,9 +980,8 @@ def get_subsegments( subsegments.append([start, min(duration, window)]) elif slices > 0: # What if slcies = 0 ? start_col = torch.arange(offset, slice_end, shift)[:slices] - dur_col = window * torch.ones(slices) - dur_col = torch.min(slice_end*torch.ones_like(start_col)- start_col, window * torch.ones_like(start_col)) - dur_col = torch.round(dur_col, decimals=decimals) + dur_col_raw = torch.min(slice_end*torch.ones_like(start_col)- start_col, window * torch.ones_like(start_col)) + dur_col = torch.round(dur_col_raw, decimals=decimals) valid_mask = dur_col >= min_subsegment_duration valid_subsegments = torch.stack([start_col[valid_mask], dur_col[valid_mask]], dim=1) subsegments = valid_subsegments.tolist() diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index 192c42375dca..0ccfef9b9e8b 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -35,7 +35,6 @@ from sklearn.metrics import roc_auc_score from sklearn.model_selection import ParameterGrid from tqdm import tqdm -from nemo.collections.asr.parts.utils.speaker_utils import timestamps_to_pyannote_object from nemo.collections.asr.models import EncDecClassificationModel, EncDecFrameClassificationModel from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging From 40e9f95d879ba9bc1fed25d04e56806e6d24fdec Mon Sep 17 00:00:00 2001 From: tango4j Date: Fri, 15 Nov 2024 01:15:55 +0000 Subject: [PATCH 06/47] Apply isort and black reformatting Signed-off-by: tango4j --- .../neural_diarizer/e2e_diarize_speech.py | 17 +++--- nemo/collections/asr/models/__init__.py | 2 +- .../asr/models/sortformer_diar_models.py | 2 +- .../asr/parts/utils/asr_multispeaker_utils.py | 60 +++++++++++-------- .../asr/parts/utils/speaker_utils.py | 4 +- 5 files changed, 48 insertions(+), 37 deletions(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 5237b5c3c67b..72d7977840ce 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -125,7 +125,7 @@ def optuna_suggest_params(postprocessing_cfg: PostProcessingParams, trial: optun Suggests hyperparameters for postprocessing using Optuna. See the following link for `trial` instance in Optuna framework. https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial - + Args: postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. trial (optuna.Trial): The Optuna trial object used to suggest hyperparameters. @@ -390,13 +390,14 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: out_rttm_dir=cfg.out_rttm_dir, ) logging.info(f"Evaluating the model on the {len(diar_model_preds_total_list)} audio segments...") - score_labels(AUDIO_RTTM_MAP=infer_audio_rttm_dict, - all_reference=all_refs, - all_hypothesis=all_hyps, - all_uem=all_uems, - collar=cfg.collar, - ignore_overlap=cfg.ignore_overlap - ) + score_labels( + AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=cfg.collar, + ignore_overlap=cfg.ignore_overlap, + ) logging.info(f"PostProcessingParams: {postprocessing_cfg}") diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index f27828a6b11e..34dead15b33d 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -20,7 +20,6 @@ EncDecFrameClassificationModel, ) from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer -from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel @@ -36,6 +35,7 @@ from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel +from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.ssl_models import ( EncDecDenoiseMaskedTokenPredModel, EncDecMaskedTokenPredModel, diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index fd9d01f33f2b..939b03e7a5ac 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -571,7 +571,7 @@ def test_batch( self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0])) torch.cuda.empty_cache() self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens) - + logging.info(f"Batch F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_list))}") logging.info(f"Batch Precision MEAN: {torch.mean(torch.tensor(self.batch_precision_list))}") logging.info(f"Batch Recall MEAN: {torch.mean(torch.tensor(self.batch_recall_list))}") diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index 5e19b7abeb38..3f40f5cd3e39 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -13,10 +13,11 @@ # limitations under the License. import math + import torch -from tqdm import tqdm from lhotse import SupervisionSet from lhotse.cut import MixedCut, MonoCut +from tqdm import tqdm def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres: float = 0.5) -> torch.Tensor: @@ -173,6 +174,7 @@ def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutati max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_speakers, num_classes) return max_score_permed_labels # (batch_size, num_speakers, num_classes) + def find_segments_from_rttm( recording_id: str, rttms, @@ -211,8 +213,6 @@ def find_segments_from_rttm( ] - - def get_mask_from_segments( segments: list, a_cut, @@ -305,17 +305,18 @@ def get_hidden_length_from_sample_length( hidden_length = math.ceil(mel_frame_count / num_mel_frame_per_asr_frame) return int(hidden_length) + def speaker_to_target( a_cut, - num_speakers: int = 4, - num_sample_per_mel_frame: int = 160, - num_mel_frame_per_asr_frame: int = 8, + num_speakers: int = 4, + num_sample_per_mel_frame: int = 160, + num_mel_frame_per_asr_frame: int = 8, spk_tar_all_zero: bool = False, boundary_segments: bool = False, soft_label: bool = False, ignore_num_spk_mismatch: bool = True, soft_thres: float = 0.5, - ): +): ''' Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape (num_speaker, hidden_length) This function is needed for speaker diarization with ASR model trainings. @@ -329,7 +330,7 @@ def speaker_to_target( boundary_segments (bool): set to True to include segments containing the boundary of the cut, False by default for multi-speaker ASR training soft_label (bool): set to True to use soft label that enables values in [0, 1] range, False by default and leads to binary labels. ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. - + Returns: mask (Tensor): speaker mask with shape (num_speaker, hidden_lenght) ''' @@ -343,14 +344,18 @@ def speaker_to_target( offsets = [0] else: raise ValueError(f"Unsupported cut type type{a_cut}: only MixedCut and MonoCut are supported") - + segments_total = [] for i, cut in enumerate(cut_list): rttms = SupervisionSet.from_rttm(cut.rttm_filepath) - if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included - segments_iterator = find_segments_from_rttm(recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0) - else: # segments with seg_start > total_start and seg_end < total_end are included - segments_iterator = rttms.find(recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True) + if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included + segments_iterator = find_segments_from_rttm( + recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0 + ) + else: # segments with seg_start > total_start and seg_end < total_end are included + segments_iterator = rttms.find( + recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True + ) for seg in segments_iterator: if seg.start < 0: @@ -360,28 +365,31 @@ def speaker_to_target( seg.duration -= seg.end - cut.duration seg.start += offsets[i] segments_total.append(seg) - + # apply arrival time sorting to the existing segments - segments_total.sort(key = lambda rttm_sup: rttm_sup.start) + segments_total.sort(key=lambda rttm_sup: rttm_sup.start) seen = set() seen_add = seen.add speaker_ats = [s.speaker for s in segments_total if not (s.speaker in seen or seen_add(s.speaker))] - - speaker_to_idx_map = { - spk: idx - for idx, spk in enumerate(speaker_ats) - } + + speaker_to_idx_map = {spk: idx for idx, spk in enumerate(speaker_ats)} if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers - raise ValueError(f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers {num_speakers}") - + raise ValueError( + f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers {num_speakers}" + ) + # initialize mask matrices (num_speaker, encoder_hidden_len) - feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default - num_samples = get_hidden_length_from_sample_length(a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame) - if spk_tar_all_zero: + feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default + num_samples = get_hidden_length_from_sample_length( + a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame + ) + if spk_tar_all_zero: frame_mask = torch.zeros((num_samples, num_speakers)) else: - frame_mask = get_mask_from_segments(segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch) + frame_mask = get_mask_from_segments( + segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch + ) soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) if soft_label: diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 09b395a28f8d..87ad7eda59d9 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -982,7 +982,9 @@ def get_subsegments( subsegments.append([start, min(duration, window)]) elif slices > 0: # What if slcies = 0 ? start_col = torch.arange(offset, slice_end, shift)[:slices] - dur_col_raw = torch.min(slice_end*torch.ones_like(start_col)- start_col, window * torch.ones_like(start_col)) + dur_col_raw = torch.min( + slice_end * torch.ones_like(start_col) - start_col, window * torch.ones_like(start_col) + ) dur_col = torch.round(dur_col_raw, decimals=decimals) valid_mask = dur_col >= min_subsegment_duration valid_subsegments = torch.stack([start_col[valid_mask], dur_col[valid_mask]], dim=1) From f7f84bb386fc0d8c11012e6e246a5d058caa0a72 Mon Sep 17 00:00:00 2001 From: taejinp Date: Thu, 14 Nov 2024 17:53:31 -0800 Subject: [PATCH 07/47] Adding docstrings to reflect the PR comments Signed-off-by: taejinp --- .../asr/data/audio_to_diar_label.py | 19 +++++++++++++------ .../asr/parts/utils/asr_multispeaker_utils.py | 3 --- .../asr/parts/utils/speaker_utils.py | 5 ----- .../common/parts/preprocessing/collections.py | 3 --- 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index b00338743a43..f08788d2d231 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -149,18 +149,25 @@ def get_subsegments_to_timestamps( subsegments: List[Tuple[float, float]], feat_per_sec: int = 100, max_end_ts: float = None, decimals=2 ): """ - Convert subsegment timestamps to scale timestamps by multiplying with the feature rate and rounding. - All `ts` related tensors are dimensioned as (N, 2), where N is the number of subsegments. - + Convert subsegment timestamps to scale timestamps by multiplying with the feature rate (`feat_per_sec`) and rounding. + Segment is consisted of many subsegments and sugsegments are equivalent to `frames` in end-to-end speaker diarization models. + Args: subsegments (List[Tuple[float, float]]): - A list of tuples where each tuple contains the start and end times of a subsegment. + A list of tuples where each tuple contains the start and end times of a subsegment (frames in end-to-end models). + >>> subsegments = [[t0_start, t0_duration], [t1_start, t1_duration],..., [tN_start, tN_duration]] feat_per_sec (int, optional): The number of feature frames per second. Defaults to 100. max_end_ts (float, optional): The maximum end timestamp to clip the results. If None, no clipping is applied. Defaults to None. decimals (int, optional): The number of decimal places to round the timestamps. Defaults to 2. + + Example: + Segments starting from 0.0 and ending at 69.2 seconds. + If hop-length is 0.08 and the subsegment (frame) length is 0.16 seconds, + there are 864 = (69.2 - 0.16)/0.08 + 1 subsegments (frames in end-to-end models) in this segment. + >>> subsegments = [[[0.0, 0.16], [0.08, 0.16], ..., [69.04, 0.16], [69.12, 0.08]] Returns: ts (torch.tensor): @@ -175,7 +182,7 @@ def get_subsegments_to_timestamps( return ts -def extract_frame_info_from_rttm(uniq_id, offset, duration, rttm_lines, round_digits=3): +def extract_frame_info_from_rttm(offset, duration, rttm_lines, round_digits=3): """ Extracts RTTM lines containing speaker labels, start time, and end time for a given audio segment. @@ -1093,7 +1100,7 @@ def parse_rttm_for_targets_and_lens(self, uniq_id, rttm_file, offset, duration, [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]] """ rttm_lines = open(rttm_file).readlines() - rttm_timestamps, sess_to_global_spkids = extract_frame_info_from_rttm(uniq_id, offset, duration, rttm_lines) + rttm_timestamps, sess_to_global_spkids = extract_frame_info_from_rttm(offset, duration, rttm_lines) fr_level_target = get_frame_targets_from_rttm( rttm_timestamps=rttm_timestamps, diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index 5e19b7abeb38..785a7c41d32c 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -210,9 +210,6 @@ def find_segments_from_rttm( if segment.start < end_before + tolerance and segment.end > start_after + tolerance ] - - - def get_mask_from_segments( segments: list, a_cut, diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 09b395a28f8d..cd831f054106 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -33,11 +33,6 @@ from nemo.collections.asr.parts.utils.offline_clustering import SpeakerClustering, get_argmin_mat, split_input_data from nemo.utils import logging -""" -This file contains all the utility functions required for speaker embeddings part in diarization scripts -""" - - def get_uniqname_from_filepath(filepath): """ Return base name from provided filepath diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 632ec06bc647..8acd9fc08743 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -1310,9 +1310,6 @@ def __init__( if isinstance(audio_file, list): if len(audio_file) == 0: raise ValueError(f"Empty audio file list: {audio_file}") - audio_file_name = sorted(audio_file)[0] - else: - audio_file_name = audio_file file_id, _ = os.path.splitext(os.path.basename(audio_file)) self.mapping[file_id] = len(data) - 1 From 4134e2533c80c76fab190a7ac62436a24a0016a2 Mon Sep 17 00:00:00 2001 From: taejinp Date: Thu, 14 Nov 2024 17:56:03 -0800 Subject: [PATCH 08/47] removed the unused find_first_nonzero Signed-off-by: taejinp --- nemo/collections/asr/data/audio_to_diar_label.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index f08788d2d231..0a40d832eaf0 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -20,7 +20,6 @@ import numpy as np import torch -from nemo.collections.asr.parts.utils.asr_multispeaker_utils import find_first_nonzero from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, get_subsegments, prepare_split_data from nemo.collections.common.parts.preprocessing.collections import ( From 5dd4d4c6d2d4eee96f40a0b34b9139f59157d208 Mon Sep 17 00:00:00 2001 From: tango4j Date: Fri, 15 Nov 2024 01:56:16 +0000 Subject: [PATCH 09/47] Apply isort and black reformatting Signed-off-by: tango4j --- nemo/collections/asr/data/audio_to_diar_label.py | 8 ++++---- .../collections/asr/parts/utils/asr_multispeaker_utils.py | 1 + nemo/collections/asr/parts/utils/speaker_utils.py | 1 + 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index f08788d2d231..535ae3301173 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -151,7 +151,7 @@ def get_subsegments_to_timestamps( """ Convert subsegment timestamps to scale timestamps by multiplying with the feature rate (`feat_per_sec`) and rounding. Segment is consisted of many subsegments and sugsegments are equivalent to `frames` in end-to-end speaker diarization models. - + Args: subsegments (List[Tuple[float, float]]): A list of tuples where each tuple contains the start and end times of a subsegment (frames in end-to-end models). @@ -162,10 +162,10 @@ def get_subsegments_to_timestamps( The maximum end timestamp to clip the results. If None, no clipping is applied. Defaults to None. decimals (int, optional): The number of decimal places to round the timestamps. Defaults to 2. - - Example: + + Example: Segments starting from 0.0 and ending at 69.2 seconds. - If hop-length is 0.08 and the subsegment (frame) length is 0.16 seconds, + If hop-length is 0.08 and the subsegment (frame) length is 0.16 seconds, there are 864 = (69.2 - 0.16)/0.08 + 1 subsegments (frames in end-to-end models) in this segment. >>> subsegments = [[[0.0, 0.16], [0.08, 0.16], ..., [69.04, 0.16], [69.12, 0.08]] diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index 6412d88f4c0f..e945439bf8fa 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -210,6 +210,7 @@ def find_segments_from_rttm( if segment.start < end_before + tolerance and segment.end > start_after + tolerance ] + def get_mask_from_segments( segments: list, a_cut, diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index cb1244eef660..fd6f71dc0502 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -33,6 +33,7 @@ from nemo.collections.asr.parts.utils.offline_clustering import SpeakerClustering, get_argmin_mat, split_input_data from nemo.utils import logging + def get_uniqname_from_filepath(filepath): """ Return base name from provided filepath From 037f61e85a9739ae8cadcf63e40b0122f218de6c Mon Sep 17 00:00:00 2001 From: taejinp Date: Fri, 15 Nov 2024 12:33:06 -0800 Subject: [PATCH 10/47] Fixed all pylint issues Signed-off-by: taejinp --- ...rtformer_diarizer_hybrid_loss_4spk-v1.yaml | 2 +- .../neural_diarizer/e2e_diarize_speech.py | 3 +- .../neural_diarizer/sortformer_diar_train.py | 19 +-- .../asr/data/audio_to_diar_label.py | 96 +++++++---- nemo/collections/asr/metrics/der.py | 40 +++-- .../asr/metrics/multi_binary_acc.py | 12 ++ .../asr/models/sortformer_diar_models.py | 3 +- .../asr/modules/sortformer_modules.py | 1 - .../asr/parts/utils/speaker_utils.py | 152 +++++++++++------- nemo/collections/asr/parts/utils/vad_utils.py | 110 ++++++++----- .../common/parts/preprocessing/collections.py | 49 +++--- 11 files changed, 303 insertions(+), 184 deletions(-) diff --git a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml index 04409a4cd60a..4a6d8f242d36 100644 --- a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml +++ b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml @@ -1,4 +1,4 @@ -sortformer_diarizer_hybrid_loss_4spk-v1.yaml# Sortformer Diarizer is an end-to-end speaker diarization model that is solely based on Transformer-encoder type of architecture. +# Sortformer Diarizer is an end-to-end speaker diarization model that is solely based on Transformer-encoder type of architecture. # Model name convention for Sortformer Diarizer: sortformer_diarizer___.yaml # (Example) `sortformer_diarizer_hybrid_loss_4spk-v1.yaml`. # Sortformer Diarizer model checkpoint (.ckpt) and NeMo file (.nemo) contain Fast Conformer Encoder model (NEST Encoder) and the pre-trained NEST model is loaded along with the Transformer Encoder layers. diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 72d7977840ce..0f90e70eff80 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -23,13 +23,12 @@ import os import tempfile from dataclasses import dataclass, is_dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union import optuna import pytorch_lightning as pl import torch import yaml -from hydra.core.config_store import ConfigStore from omegaconf import OmegaConf from pytorch_lightning import seed_everything from tqdm import tqdm diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py index 3ba0dbc3ed19..75980d342c65 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,33 +22,26 @@ from nemo.utils.exp_manager import exp_manager """ -Example training session (single GPU training on telephonic datasets) +Example training session (single node training) -python ./multiscale_diar_decoder.py --config-path='../conf/neural_diarizer' --config-name='msdd_5scl_15_05_50Povl_256x3x32x2.yaml' \ +python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' --config-name='' \ trainer.devices=1 \ - model.base.diarizer.speaker_embeddings.model_path="titanet_large" \ model.train_ds.manifest_filepath="" \ model.validation_ds.manifest_filepath="" \ - model.train_ds.emb_dir="" \ - model.validation_ds.emb_dir="" \ exp_manager.name='sample_train' \ - exp_manager.exp_dir='./msdd_exp' + exp_manager.exp_dir='./sortformer_diar_train' """ seed_everything(42) - -@hydra_runner(config_path="../conf/neural_diarizer", config_name="msdd_5scl_15_05_50Povl_256x3x32x2.yaml") +@hydra_runner(config_path="../conf/neural_diarizer", config_name="sortformer_diarizer_hybrid_loss_4spk-v1.yaml") def main(cfg): logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) sortformer_model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer) - # Initialize the weights of the model from another model, if provided via config sortformer_model.maybe_init_from_pretrained_checkpoint(cfg) trainer.fit(sortformer_model) - if __name__ == '__main__': - - main() + main() \ No newline at end of file diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index b6b398743198..f47b5ca11f43 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -81,7 +81,8 @@ def extract_seg_info_from_rttm(rttm_lines, mapping_dict=None, target_spks=None): mapping_dict (dict): Mapping between the estimated speakers and the speakers in the ground-truth annotation. `mapping_dict` variable is only provided when the inference mode is running in sequence-eval mode. - Sequence eval mode uses the mapping between the estimated speakers and the speakers in ground-truth annotation. + Sequence eval mode uses the mapping between the estimated speakers and the speakers + in ground-truth annotation. Returns: rttm_tup (tuple): Tuple containing lists of start time, end time and speaker labels. @@ -113,12 +114,14 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec, Args: rttm_timestamps (list): List containing start and end time for each speaker segment label. - stt_list, end_list and speaker_list are contained. + `stt_list`, `end_list` and `speaker_list` are contained. frame_per_sec (int): - Number of feature frames per second. This quantity is determined by window_stride variable in preprocessing module. + Number of feature frames per second. This quantity is determined by + `window_stride` variable in preprocessing module. target_spks (tuple): - Speaker indices that are generated from combinations. If there are only one or two speakers, - only a single target_spks variable is generated. + Speaker indices that are generated from combinations. + If there are only one or two speakers, + only a single `target_spks` variable is generated. Returns: fr_level_target (torch.tensor): @@ -148,12 +151,14 @@ def get_subsegments_to_timestamps( subsegments: List[Tuple[float, float]], feat_per_sec: int = 100, max_end_ts: float = None, decimals=2 ): """ - Convert subsegment timestamps to scale timestamps by multiplying with the feature rate (`feat_per_sec`) and rounding. - Segment is consisted of many subsegments and sugsegments are equivalent to `frames` in end-to-end speaker diarization models. + Convert subsegment timestamps to scale timestamps by multiplying with the feature rate (`feat_per_sec`) + and rounding. Segment is consisted of many subsegments and sugsegments are equivalent to `frames` + in end-to-end speaker diarization models. Args: subsegments (List[Tuple[float, float]]): - A list of tuples where each tuple contains the start and end times of a subsegment (frames in end-to-end models). + A list of tuples where each tuple contains the start and end times of a subsegment + (frames in end-to-end models). >>> subsegments = [[t0_start, t0_duration], [t1_start, t1_duration],..., [tN_start, tN_duration]] feat_per_sec (int, optional): The number of feature frames per second. Defaults to 100. @@ -246,7 +251,8 @@ def get_frame_targets_from_rttm( List containing start and end time for each speaker segment label. stt_list, end_list and speaker_list are contained. feat_per_sec (int): - Number of feature frames per second. This quantity is determined by window_stride variable in preprocessing module. + Number of feature frames per second. + This quantity is determined by window_stride variable in preprocessing module. target_spks (tuple): Speaker indices that are generated from combinations. If there are only one or two speakers, only a single target_spks variable is generated. @@ -260,7 +266,8 @@ def get_frame_targets_from_rttm( total_fr_len = int(duration * feat_per_sec) if len(sorted_speakers) > max_spks: logging.warning( - f"Number of speakers in RTTM file {len(sorted_speakers)} exceeds the maximum number of speakers: {max_spks}! Only {max_spks} first speakers remain, and this will affect frame metrics!" + f"Number of speakers in RTTM file {len(sorted_speakers)} exceeds the maximum number of speakers: " + f"{max_spks}! Only {max_spks} first speakers remain, and this will affect frame metrics!" ) feat_level_target = torch.zeros(total_fr_len, max_spks) for count, (stt, end, spk_rttm_key) in enumerate(zip(stt_list, end_list, speaker_list)): @@ -408,15 +415,17 @@ def assign_labels_to_longer_segs(self, uniq_id, base_scale_clus_label): def get_diar_target_labels(self, uniq_id, sample, fr_level_target): """ - Convert frame-level diarization target variable into segment-level target variable. Since the granularity is reduced - from frame level (10ms) to segment level (100ms~500ms), we need a threshold value, `soft_label_thres`, which determines - the label of each segment based on the overlap between a segment range (start and end time) and the frame-level target variable. + Convert frame-level diarization target variable into segment-level target variable. + Since the granularity is reduced from frame level (10ms) to segment level (100ms~500ms), + we need a threshold value, `soft_label_thres`, which determines the label of each segment + based on the overlap between a segment range (start and end time) and the frame-level target variable. Args: uniq_id (str): Unique file ID that refers to an input audio file and corresponding RTTM (Annotation) file. sample: - `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath. + `DiarizationSpeechLabel` instance containing sample information such as + audio filepath and RTTM filepath. fr_level_target (torch.tensor): Tensor containing label for each feature-level frame. @@ -424,7 +433,8 @@ def get_diar_target_labels(self, uniq_id, sample, fr_level_target): seg_target (torch.tensor): Tensor containing binary speaker labels for base-scale segments. base_clus_label (torch.tensor): - Representative speaker label for each segment. This variable only has one speaker label for each base-scale segment. + Representative speaker label for each segment. This variable only has one speaker label + for each base-scale segment. -1 means that there is no corresponding speaker in the target_spks tuple. """ seg_target_list, base_clus_label = [], [] @@ -459,7 +469,8 @@ def parse_rttm_for_ms_targets(self, sample): Args: sample: - `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath. + `DiarizationSpeechLabel` instance containing sample information such as + audio filepath and RTTM filepath. target_spks (tuple): Speaker indices that are generated from combinations. If there are only one or two speakers, only a single target_spks tuple is generated. @@ -474,7 +485,8 @@ def parse_rttm_for_ms_targets(self, sample): multiscale embeddings to form an input matrix for the MSDD model. """ - rttm_lines = open(sample.rttm_file).readlines() + with open(sample.rttm_file, 'r') as file: + rttm_lines = file.readlines() uniq_id = self.get_uniq_id_with_range(sample) rttm_timestamps = extract_seg_info_from_rttm(rttm_lines) fr_level_target = assign_frame_level_spk_vector( @@ -579,7 +591,8 @@ class _AudioMSDDInferDataset(Dataset): emb_dict (dict): Dictionary containing cluster-average embeddings and speaker mapping information. emb_seq (dict): - Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps. + Dictionary containing multiscale speaker embedding sequence, + scale mapping and corresponding segment timestamps. clus_label_dict (dict): Subsegment-level (from base-scale) speaker labels from clustering results. soft_label_thres (float): @@ -678,9 +691,9 @@ def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target): """ Generate base-scale level binary diarization label from frame-level target matrix. For the given frame-level speaker target matrix fr_level_target, we count the number of frames that belong to each speaker and calculate - ratios for each speaker into the `soft_label_vec` variable. Finally, `soft_label_vec` variable is compared with `soft_label_thres` - to determine whether a label vector should contain 0 or 1 for each speaker bin. Note that seg_target variable has - dimension of (number of base-scale segments x 2) dimension. + ratios for each speaker into the `soft_label_vec` variable. Finally, `soft_label_vec` variable is compared + with `soft_label_thres` to determine whether a label vector should contain 0 or 1 for each speaker bin. + Note that seg_target variable has dimension of (number of base-scale segments x 2) dimension. Example of seg_target: [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]] @@ -726,7 +739,8 @@ def __getitem__(self, index): if avg_embs.shape[2] > self.max_spks: raise ValueError( - f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to self.max_num_speakers {self.max_spks}" + f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to " + f"self.max_num_speakers {self.max_spks}" ) feats = [] @@ -820,7 +834,8 @@ def _msdd_train_collate_fn(self, batch): def _msdd_infer_collate_fn(self, batch): """ - Collate batch of feats (speaker embeddings), feature lengths, target label sequences and cluster-average embeddings. + Collate batch of feats (speaker embeddings), feature lengths, target label sequences + and cluster-average embeddings. Args: batch (tuple): @@ -922,6 +937,7 @@ def __init__( ) def msdd_train_collate_fn(self, batch): + """Collate batch of audio features, feature lengths, target label sequences for training.""" return _msdd_train_collate_fn(self, batch) @@ -943,11 +959,13 @@ class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset): emb_dict (dict): Dictionary containing cluster-average embeddings and speaker mapping information. emb_seq (dict): - Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps. + Dictionary containing multiscale speaker embedding sequence, scale mapping + and corresponding segment timestamps. clus_label_dict (dict): Subsegment-level (from base-scale) speaker labels from clustering results. soft_label_thres (float): - Threshold that determines speaker labels of segments depending on the overlap with groundtruth speaker timestamps. + Threshold that determines speaker labels of segments depending on the overlap + with groundtruth speaker timestamps. featurizer: Featurizer instance for generating features from raw waveform. use_single_scale_clus (bool): @@ -955,11 +973,12 @@ class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset): seq_eval_mode (bool): If True, F1 score will be calculated for each speaker pair during inference mode. window_stride (float): - Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames. + Window stride for acoustic feature. This value is used for calculating the numbers of + feature-level frames. pairwise_infer (bool): - If True, this Dataset class operates in inference mode. In inference mode, a set of speakers in the input audio - is split into multiple pairs of speakers and speaker tuples (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then - fed into the MSDD to merge the individual results. + If True, this Dataset class operates in inference mode. In inference mode, a set of speakers + in the input audio is split into multiple pairs of speakers and speaker tuples + (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the MSDD to merge the individual results. """ def __init__( @@ -988,6 +1007,7 @@ def __init__( ) def msdd_infer_collate_fn(self, batch): + """Collate batch of audio features, feature lengths, target label sequences for inference.""" return _msdd_infer_collate_fn(self, batch) @@ -1089,7 +1109,7 @@ def get_uniq_id_with_range(self, sample, deci=3): uniq_id = f"{bare_uniq_id}_{offset}_{endtime}" return uniq_id - def parse_rttm_for_targets_and_lens(self, uniq_id, rttm_file, offset, duration, target_len): + def parse_rttm_for_targets_and_lens(self, rttm_file, offset, duration, target_len): """ Generate target tensor variable by extracting groundtruth diarization labels from an RTTM file. This function converts (start, end, speaker_id) format into base-scale (the finest scale) segment level @@ -1098,7 +1118,9 @@ def parse_rttm_for_targets_and_lens(self, uniq_id, rttm_file, offset, duration, Example of seg_target: [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]] """ - rttm_lines = open(rttm_file).readlines() + with open(rttm_file, 'r') as f: + rttm_lines = f.readlines() + rttm_timestamps, sess_to_global_spkids = extract_frame_info_from_rttm(offset, duration, rttm_lines) fr_level_target = get_frame_targets_from_rttm( @@ -1203,7 +1225,8 @@ def __getitem__(self, index): uniq_id = self.get_uniq_id_with_range(sample) audio_signal = self.featurizer.process(sample.audio_file, offset=offset, duration=session_len_sec) - # We should resolve the length mis-match from the round-off errors: `session_len_sec` and `audio_signal.shape[0]` + # We should resolve the length mis-match from the round-off errors between these two variables: + # `session_len_sec` and `audio_signal.shape[0]` session_len_sec = ( np.floor(audio_signal.shape[0] / self.featurizer.sample_rate * self.floor_decimal) / self.floor_decimal ) @@ -1213,7 +1236,7 @@ def __getitem__(self, index): audio_signal, audio_signal_length = audio_signal.to('cpu'), audio_signal_length.to('cpu') target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate) targets = self.parse_rttm_for_targets_and_lens( - uniq_id=uniq_id, rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len + rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len ) return audio_signal, audio_signal_length, targets, target_len @@ -1229,13 +1252,15 @@ def _eesd_train_collate_fn(self, batch): Returns: audio_signal (torch.Tensor): - A tensor containing the raw waveform samples (time series) loaded from the `audio_filepath` in the input manifest file. + A tensor containing the raw waveform samples (time series) loaded from the `audio_filepath` + in the input manifest file. feature_length (torch.Tensor): A tensor containing the lengths of the raw waveform samples. targets (torch.Tensor): Groundtruth speaker labels for the given input embedding sequence. target_lens (torch.Tensor): - A tensor containing the number of segments for each sample in the batch, necessary for reshaping inputs to the EESD model. + A tensor containing the number of segments for each sample in the batch, necessary for + reshaping inputs to the EESD model. """ packed_batch = list(zip(*batch)) audio_signal, feature_length, targets, target_len = packed_batch @@ -1344,4 +1369,5 @@ def __init__( ) def eesd_train_collate_fn(self, batch): + """Collate a batch of data for end-to-end speaker diarization training.""" return _eesd_train_collate_fn(self, batch) diff --git a/nemo/collections/asr/metrics/der.py b/nemo/collections/asr/metrics/der.py index 000b839ceb46..22c9a76b7fc9 100644 --- a/nemo/collections/asr/metrics/der.py +++ b/nemo/collections/asr/metrics/der.py @@ -123,7 +123,7 @@ def uem_timeline_from_file(uem_file, uniq_name=''): lines = f.readlines() for line in lines: line = line.strip() - speaker_id, channel, start_time, end_time = line.split() + _, _, start_time, end_time = line.split() timeline.add(Segment(float(start_time), float(end_time))) return timeline @@ -145,14 +145,21 @@ def score_labels( Args: - AUDIO_RTTM_MAP (dict): Dictionary containing information provided from manifestpath - all_reference (list[uniq_name,Annotation]): reference annotations for score calculation - all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation - verbose (bool): Warns if RTTM file is not found. + AUDIO_RTTM_MAP (dict): + Dictionary containing information provided from manifestpath + all_reference (list[uniq_name,Annotation]): + Reference annotations for score calculation + all_hypothesis (list[uniq_name,Annotation]): + Hypothesis annotations for score calculation + verbose (bool): + Warns if RTTM file is not found. Returns: - metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. This object contains detailed scores of each audiofile. - mapping (dict): Mapping dict containing the mapping speaker label for each audio input + metric (pyannote.DiarizationErrorRate): + Pyannote Diarization Error Rate metric object. + This object contains detailed scores of each audiofile. + mapping (dict): + Mapping dict containing the mapping speaker label for each audio input < Caveat > Unlike md-eval.pl, "no score" collar in pyannote.metrics is the maximum length of @@ -171,7 +178,8 @@ def score_labels( correct_spk_count += 1 if verbose and len(ref_labels.labels()) != len(hyp_labels.labels()): logging.info( - f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}" + f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, " + f"Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}" ) uem_obj = None if all_uem is not None: @@ -187,7 +195,7 @@ def score_labels( spk_count_acc = correct_spk_count / len(all_reference) DER = abs(metric) if metric['total'] == 0: - raise ValueError(f"Total evaluation time is 0. Abort.") + raise ValueError("Total evaluation time is 0. Abort.") CER = metric['confusion'] / metric['total'] FA = metric['false alarm'] / metric['total'] MISS = metric['missed detection'] / metric['total'] @@ -195,18 +203,18 @@ def score_labels( itemized_errors = (DER, CER, FA, MISS) if verbose: - # logging.info(f"\n{metric.report()}") - pass + logging.info(f"\n{metric.report()}") logging.info( - "Cumulative Results for collar {} sec and ignore_overlap {}: \n| FA: {:.4f} | MISS: {:.4f} | CER: {:.4f} | DER: {:.4f} | Spk. Count Acc. {:.4f}\n".format( - collar, ignore_overlap, FA, MISS, CER, DER, spk_count_acc - ) + f"Cumulative Results for collar {collar} sec and ignore_overlap {ignore_overlap}: \n" + f"| FA: {FA:.4f} | MISS: {MISS:.4f} | CER: {CER:.4f} | DER: {DER:.4f} | " + f"Spk. Count Acc. {spk_count_acc:.4f}\n" ) return metric, mapping_dict, itemized_errors elif verbose: logging.warning( - "Check if each ground truth RTTMs were present in the provided manifest file. Skipping calculation of Diariazation Error Rate" + "Check if each ground truth RTTMs were present in the provided manifest file. " + "Skipping calculation of Diariazation Error Rate" ) return None @@ -447,4 +455,4 @@ def concat_perm_word_error_rate( cpWER_values.append(cpWER) hyps_spk.append(min_hypothesis) refs_spk.append(concat_reference) - return cpWER_values, hyps_spk, refs_spk + return cpWER_values, hyps_spk, refs_spk \ No newline at end of file diff --git a/nemo/collections/asr/metrics/multi_binary_acc.py b/nemo/collections/asr/metrics/multi_binary_acc.py index 13e57b43bb0b..8ad09c842636 100644 --- a/nemo/collections/asr/metrics/multi_binary_acc.py +++ b/nemo/collections/asr/metrics/multi_binary_acc.py @@ -84,6 +84,18 @@ def __init__(self, dist_sync_on_step=False): def update( self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor, cumulative=False ) -> torch.Tensor: + """ + Update the metric with the given predictions, targets, and signal lengths to the metric instance. + + Args: + preds (torch.Tensor): Predicted values. + targets (torch.Tensor): Target values. + signal_lengths (torch.Tensor): Length of each sequence in the batch input. + cumulative (bool): Whether to accumulate the values over time. + + Returns: + f1_score (torch.Tensor): F1 score calculated from the predicted value and binarized target values. + """ with torch.no_grad(): preds_list = [preds[k, : signal_lengths[k], :] for k in range(preds.shape[0])] targets_list = [targets[k, : signal_lengths[k], :] for k in range(targets.shape[0])] diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index 939b03e7a5ac..e3c14dd77c65 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -208,7 +208,8 @@ def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): def test_dataloader(self): if self._test_dl is not None: return self._test_dl - + return None + @property def input_types(self) -> Optional[Dict[str, NeuralType]]: if hasattr(self.preprocessor, '_sample_rate'): diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py index fdbeee5235ea..e0b5b15094b6 100644 --- a/nemo/collections/asr/modules/sortformer_modules.py +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -20,7 +20,6 @@ from nemo.core.classes.exportable import Exportable from nemo.core.classes.module import NeuralModule -from nemo.core.neural_types.elements import ProbsType __all__ = ['SortformerModules'] diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index fd6f71dc0502..1e7dda91c9e7 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -14,7 +14,6 @@ import gc import json -import math import os import shutil from copy import deepcopy @@ -23,14 +22,13 @@ import numpy as np import soundfile as sf import torch -from omegaconf import OmegaConf from omegaconf.listconfig import ListConfig from pyannote.core import Annotation, Segment, Timeline from tqdm import tqdm from nemo.collections.asr.data.audio_to_label import repeat_signal from nemo.collections.asr.parts.utils.longform_clustering import LongFormSpeakerClustering -from nemo.collections.asr.parts.utils.offline_clustering import SpeakerClustering, get_argmin_mat, split_input_data +from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat, split_input_data from nemo.utils import logging @@ -78,10 +76,13 @@ def audio_rttm_map(manifest, attach_dur=False): """ This function creates AUDIO_RTTM_MAP which is used by all diarization components to extract embeddings, cluster and unify time stamps - Args: manifest file that contains keys audio_filepath, rttm_filepath if exists, text, num_speakers if known and uem_filepath if exists - - returns: - AUDIO_RTTM_MAP (dict) : A dictionary with keys of uniq id, which is being used to map audio files and corresponding rttm files + + Args: + manifest (str): Path to the manifest file + attach_dur (bool, optional): If True, attach duration information to the unique name. Defaults to False. + + Returns: + AUDIO_RTTM_MAP (dict) : Dictionary with unique names as keys and corresponding metadata as values. """ AUDIO_RTTM_MAP = {} @@ -114,10 +115,9 @@ def audio_rttm_map(manifest, attach_dur=False): AUDIO_RTTM_MAP[uniqname] = meta else: raise KeyError( - "file {} is already part of AUDIO_RTTM_MAP, it might be duplicated, Note: file basename must be unique".format( - meta['audio_filepath'] + f"file {meta['audio_filepath']} is already part of AUDIO_RTTM_MAP, it might be duplicated, " + "Note: file basename must be unique" ) - ) return AUDIO_RTTM_MAP @@ -247,7 +247,8 @@ def get_embs_and_timestamps(multiscale_embeddings_and_timestamps, multiscale_arg def get_timestamps(multiscale_timestamps, multiscale_args_dict): """ The timestamps in `multiscale_timestamps` dictionary are indexed by scale index. - This function rearranges the extracted speaker embedding and timestamps by unique ID to make the further processing more convenient. + This function rearranges the extracted speaker embedding and timestamps by unique ID + to make the further processing more convenient. Args: multiscale_timestamps (dict): @@ -441,13 +442,20 @@ def perform_clustering( 'embeddings' : Tensor containing embeddings. Dimensions:(# of embs) x (emb. dimension) 'timestamps' : Tensor containing ime stamps list for each audio recording 'multiscale_segment_counts' : Tensor containing the number of segments for each scale - AUDIO_RTTM_MAP (dict): AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path - out_rttm_dir (str): Path to write predicted rttms - clustering_params (dict): clustering parameters provided through config that contains max_num_speakers (int), - oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) and enhance_count_threshold (int) - use_torch_script (bool): Boolean that determines whether to use torch.jit.script for speaker clustering - device (torch.device): Device we are running on ('cpu', 'cuda'). - verbose (bool): Enable TQDM progress bar. + AUDIO_RTTM_MAP (dict): + AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path + out_rttm_dir (str): + Path to write predicted rttms + clustering_params (dict): + Clustering parameters provided through config that contains max_num_speakers (int), + oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) + and enhance_count_threshold (int). + use_torch_script (bool): + Boolean that determines whether to use torch.jit.script for speaker clustering + device (torch.device): + Device we are running on ('cpu', 'cuda'). + verbose (bool): + Enable TQDM progress bar. Returns: all_reference (list[uniq_name,Annotation]): reference annotations for score calculation @@ -614,10 +622,9 @@ def read_rttm_lines(rttm_file_path): lines = f.readlines() else: raise FileNotFoundError( - "Requested to construct manifest from rttm with oracle VAD option or from NeMo VAD but received filename as {}".format( - rttm_file_path + "Requested to construct manifest from rttm with oracle VAD option " + f"or from NeMo VAD but received filename as {rttm_file_path}" ) - ) return lines @@ -886,7 +893,8 @@ def segments_manifest_to_subsegments_manifest( Generate subsegments manifest from segments manifest file Args: segments_manifest file (str): path to segments manifest file, typically from VAD output - subsegments_manifest_file (str): path to output subsegments manifest file (default (None) : writes to current working directory) + subsegments_manifest_file (str): path to output subsegments manifest file + (default (None) : writes to current working directory) window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift min_subsegments_duration (float): exclude subsegments smaller than this duration value @@ -960,7 +968,8 @@ def get_subsegments( it results in (10/0.08)+1 = 125 + 1 frames. Returns: - subsegments (List[tuple[float, float]]): subsegments generated for the segments as list of tuple of start and duration of each subsegment + subsegments (List[tuple[float, float]]): subsegments generated for the segments as + list of tuple of start and duration of each subsegment """ subsegments: List[List[float]] = [] start = offset @@ -1041,7 +1050,25 @@ def tensor_to_list(range_tensor: torch.Tensor) -> List[List[float]]: return [[float(range_tensor[k][0]), float(range_tensor[k][1])] for k in range(range_tensor.shape[0])] -def generate_diarization_output_lines(speaker_timestamps, model_spk_num): +def generate_diarization_output_lines(speaker_timestamps: List[List[float]], model_spk_num: int) -> List[str]: + """ + Generate diarization output lines list from the speaker timestamps list by merging overlapping intervals. + + Args: + speaker_timestamps (list): + List containing the start and end time of the speech intervals for each speaker. + Example: + >>> speaker_timestamps = [[0.5, 3.12], [3.51, 7.26],... ] + model_spk_num (int): + Number of speakers in the model. + + Returns: + speaker_lines_total (list): + List containing the diarization output lines in the format: + "start_time end_time speaker_id" + Example: + >>> speaker_lines_total = ["0.5 3.12 speaker_0", "3.51 7.26 speaker_1",...] + """ speaker_lines_total = [] for spk_idx in range(model_spk_num): ts_invervals = speaker_timestamps[spk_idx] @@ -1334,20 +1361,22 @@ def get_scale_mapping_argmat(uniq_embs_and_timestamps: Dict[str, dict]) -> Dict[ def get_overlap_stamps(cont_stamps: List[str], ovl_spk_idx: List[str]): """ - Generate timestamps that include overlap speech. Overlap-including timestamps are created based on the segments that are - created for clustering diarizer. Overlap speech is assigned to the existing speech segments in `cont_stamps`. + Generate timestamps that include overlap speech. Overlap-including timestamps are created based on + the segments that are created for clustering diarizer. Overlap speech is assigned to the existing + speech segments in `cont_stamps`. Args: cont_stamps (list): - Non-overlapping (single speaker per segment) diarization output in string format. - Each line contains the start and end time of segments and corresponding speaker labels. + Non-overlapping (single speaker per segment) diarization output in string format. Each line + contains the start and end time of segments and corresponding speaker labels. ovl_spk_idx (list): - List containing segment index of the estimated overlapped speech. The start and end of segments are based on the - single-speaker (i.e., non-overlap-aware) RTTM generation. + List containing segment index of the estimated overlapped speech. The start and end of + segments are based on the single-speaker (i.e., non-overlap-aware) RTTM generation. + Returns: total_ovl_cont_list (list): - Rendered diarization output in string format. Each line contains the start and end time of segments and - corresponding speaker labels. This format is identical to `cont_stamps`. + Rendered diarization output in string format. Each line contains the start and end time of + segments and corresponding speaker labels. This format is identical to `cont_stamps`. """ ovl_spk_cont_list = [[] for _ in range(len(ovl_spk_idx))] for spk_idx in range(len(ovl_spk_idx)): @@ -1364,18 +1393,21 @@ def get_overlap_stamps(cont_stamps: List[str], ovl_spk_idx: List[str]): def get_adaptive_threshold(estimated_num_of_spks: int, min_threshold: float, overlap_infer_spk_limit: int): """ - This function controls the magnitude of the sigmoid threshold based on the estimated number of speakers. As the number of - speakers becomes larger, diarization error rate is very sensitive on overlap speech detection. This function linearly increases - the threshold in proportion to the estimated number of speakers so more confident overlap speech results are reflected when - the number of estimated speakers are relatively high. + This function controls the magnitude of the sigmoid threshold based on the estimated number of + speakers. As the number of speakers becomes larger, diarization error rate is very sensitive + to overlap speech detection. This function linearly increases the threshold in proportion to + the estimated number of speakers so more confident overlap speech results are reflected when + the number of estimated speakers is relatively high. Args: estimated_num_of_spks (int): Estimated number of speakers from the clustering result. min_threshold (float): - Sigmoid threshold value from the config file. This threshold value is minimum threshold value when `estimated_num_of_spks=2` + Sigmoid threshold value from the config file. This threshold value is the minimum + threshold when `estimated_num_of_spks=2`. overlap_infer_spk_limit (int): - If the `estimated_num_of_spks` is less then `overlap_infer_spk_limit`, overlap speech estimation is skipped. + If the `estimated_num_of_spks` is less than `overlap_infer_spk_limit`, overlap speech + estimation is skipped. Returns: adaptive_threshold (float): @@ -1390,37 +1422,41 @@ def get_adaptive_threshold(estimated_num_of_spks: int, min_threshold: float, ove def generate_speaker_timestamps( clus_labels: List[Union[float, int]], msdd_preds: List[torch.Tensor], **params ) -> Tuple[List[str], List[str]]: - ''' - Generate speaker timestamps from the segmentation information. If `use_clus_as_main=True`, use clustering result for main speaker - labels and use timestamps from the predicted sigmoid values. In this function, the main speaker labels in `maj_labels` exist for - every subsegment steps while overlap speaker labels in `ovl_labels` only exist for segments where overlap-speech is occuring. + """ + Generate speaker timestamps from the segmentation information. If `use_clus_as_main=True`, use + clustering result for main speaker labels and use timestamps from the predicted sigmoid values. + In this function, the main speaker labels in `maj_labels` exist for every subsegment step, while + overlap speaker labels in `ovl_labels` only exist for segments where overlap speech occurs. Args: clus_labels (list): List containing integer-valued speaker clustering results. msdd_preds (list): - List containing tensors of the predicted sigmoid values. - Each tensor has shape of: (Session length, estimated number of speakers). + List containing tensors of the predicted sigmoid values. Each tensor has shape of: + (Session length, estimated number of speakers). params: Parameters for generating RTTM output and evaluation. Parameters include: - infer_overlap (bool): If False, overlap-speech will not be detected. - use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. If False, only MSDD output - is used for constructing output RTTM files. + infer_overlap (bool): If False, overlap speech will not be detected. + use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. + If False, only MSDD output is used for constructing output + RTTM files. overlap_infer_spk_limit (int): Above this limit, overlap-speech detection is bypassed. - use_adaptive_thres (bool): Boolean that determines whehther to use adaptive_threshold depending on the estimated - number of speakers. + use_adaptive_thres (bool): Boolean that determines whether to use adaptive thresholds + depending on the estimated number of speakers. max_overlap_spks (int): Maximum number of overlap speakers detected. Default is 2. threshold (float): Sigmoid threshold for MSDD output. Returns: maj_labels (list): - List containing string-formated single-speaker speech segment timestamps and corresponding speaker labels. + List containing string-formatted single-speaker speech segment timestamps and corresponding + speaker labels. Example: [..., '551.685 552.77 speaker_1', '552.99 554.43 speaker_0', '554.97 558.19 speaker_0', ...] ovl_labels (list): - List containing string-formated additional overlapping speech segment timestamps and corresponding speaker labels. - Note that `ovl_labels` includes only overlapping speech that is not included in `maj_labels`. + List containing string-formatted additional overlapping speech segment timestamps and + corresponding speaker labels. Note that `ovl_labels` includes only overlapping speech that + is not included in `maj_labels`. Example: [..., '152.495 152.745 speaker_1', '372.71 373.085 speaker_0', '554.97 555.885 speaker_1', ...] - ''' + """ msdd_preds.squeeze(0) estimated_num_of_spks = msdd_preds.shape[-1] overlap_speaker_list = [[] for _ in range(estimated_num_of_spks)] @@ -1474,7 +1510,8 @@ def get_id_tup_dict(uniq_id_list: List[str], test_data_collection, preds_list: L uniq_id_list (list): List containing the `uniq_id` values. test_data_collection (collections.DiarizationLabelEntity): - Class instance that is containing session information such as targeted speaker indices, audio filepath and RTTM filepath. + Class instance that is containing session information such as targeted speaker indices, + audio filepath and RTTM filepath. preds_list (list): List containing tensors of predicted sigmoid values. @@ -1503,11 +1540,14 @@ def prepare_split_data(manifest_filepath, _out_dir, multiscale_args_dict, global Returns: multiscale_args_dict (dict): - - Dictionary containing two types of arguments: multi-scale weights and subsegment timestamps for each data sample. + - Dictionary containing two types of arguments: multi-scale weights and subsegment timestamps + for each data sample. - Each data sample has two keys: `multiscale_weights` and `scale_dict`. - `multiscale_weights` key contains a list containing multiscale weights. - `scale_dict` is indexed by integer keys which are scale index. - - Each data sample is indexed by using the following naming convention: `__` + - Each data sample is indexed by using the following naming convention: + `__` + Example: `fe_03_00106_mixed_626310_642300` """ speaker_dir = os.path.join(_out_dir, 'speaker_outputs') diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index cffcfd1ae5a1..0fbda543ca11 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -36,7 +36,6 @@ from sklearn.model_selection import ParameterGrid from tqdm import tqdm from nemo.collections.asr.models import EncDecClassificationModel, EncDecFrameClassificationModel -from nemo.collections.asr.parts.utils.speaker_utils import timestamps_to_pyannote_object from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging @@ -66,7 +65,8 @@ def prepare_manifest(config: dict) -> str: input_list = config['input'] else: raise ValueError( - "The input for manifest preparation would either be a string of the filepath to manifest or a list of {'audio_filepath': i, 'offset': 0, 'duration': null} " + "The input for manifest preparation would either be a string of the filepath to manifest " + "or a list of {'audio_filepath': i, 'offset': 0, 'duration': null}." ) args_func = { @@ -246,7 +246,8 @@ def generate_overlap_vad_seq( out_dir: str = None, ) -> str: """ - Generate predictions with overlapping input windows/segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows. + Generate predictions with overlapping input windows/segments. + Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows. Two common smoothing filters are supported: majority vote (median) and average (mean). This function uses multiprocessing to speed up. Args: @@ -310,8 +311,8 @@ def generate_overlap_vad_seq_per_tensor( frame: torch.Tensor, per_args: Dict[str, float], smoothing_method: str ) -> torch.Tensor: """ - Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate prediction with overlapping input window/segments - See description in generate_overlap_vad_seq. + Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate + prediction with overlapping input window/segments. See description in generate_overlap_vad_seq. Use this for single instance pipeline. """ # This function will be refactor for vectorization but this is okay for now @@ -472,7 +473,8 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te Binarize predictions to speech and non-speech Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice + Activity Detection", InterSpeech 2015. Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: @@ -485,7 +487,8 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te frame_length_in_sec (float): length of frame. Returns: - speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format. + speech_segments(torch.Tensor): A tensor of speech segment in the form of: + `torch.Tensor([[start1, end1], [start2, end2]])`. """ frame_length_in_sec = per_args.get('frame_length_in_sec', 0.01) @@ -535,10 +538,10 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te def remove_segments(original_segments: torch.Tensor, to_be_removed_segments: torch.Tensor) -> torch.Tensor: """ Remove speech segments list in to_be_removed_segments from original_segments. - For example, - remove torch.Tensor([[start2, end2],[start4, end4]]) from torch.Tensor([[start1, end1],[start2, end2],[start3, end3], [start4, end4]]), - -> - torch.Tensor([[start1, end1],[start3, end3]]) + (Example) Remove torch.Tensor([[start2, end2],[start4, end4]]) + from torch.Tensor([[start1, end1],[start2, end2],[start3, end3], [start4, end4]]), + -> + torch.Tensor([[start1, end1],[start3, end3]]) """ for y in to_be_removed_segments: original_segments = original_segments[original_segments.eq(y).all(dim=1).logical_not()] @@ -559,20 +562,30 @@ def get_gap_segments(segments: torch.Tensor) -> torch.Tensor: @torch.jit.script def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torch.Tensor: """ - Filter out short non_speech and speech segments. + Filter out short non-speech and speech segments. + + Reference: + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice + Activity Detection", InterSpeech 2015. + Implementation: + https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py - Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. - Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: - speech_segments (torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format. + speech_segments (torch.Tensor): + A tensor of speech segments in the format + torch.Tensor([[start1, end1], [start2, end2]]). per_args: - min_duration_on (float): threshold for small non_speech deletion - min_duration_off (float): threshold for short speech segment deletion - filter_speech_first (float): Whether to perform short speech segment deletion first. Use 1.0 to represent True. + min_duration_on (float): + Threshold for small non-speech deletion. + min_duration_off (float): + Threshold for short speech segment deletion. + filter_speech_first (float): + Whether to perform short speech segment deletion first. Use 1.0 to represent True. Returns: - speech_segments(torch.Tensor): A tensor of filtered speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format. + speech_segments (torch.Tensor): + A tensor of filtered speech segments in the format + torch.Tensor([[start1, end1], [start2, end2]]). """ if speech_segments.shape == torch.Size([0]): return speech_segments @@ -709,7 +722,8 @@ def generate_vad_segment_table( 17,18, speech Args: vad_pred_dir (str): directory of prediction files to be processed. - postprocessing_params (dict): dictionary of thresholds for prediction score. See details in binarization and filtering. + postprocessing_params (dict): dictionary of thresholds for prediction score. + See details in binarization and filtering. frame_length_in_sec (float): frame length. out_dir (str): output dir of generated table/csv file. num_workers(float): number of process for multiprocessing @@ -820,16 +834,19 @@ def vad_tune_threshold_on_dev( num_workers: int = 20, ) -> Tuple[dict, dict]: """ - Tune thresholds on dev set. Return best thresholds which gives the lowest detection error rate (DetER) in thresholds. + Tune thresholds on dev set. Return best thresholds which gives the lowest + detection error rate (DetER) in thresholds. + Args: params (dict): dictionary of parameters to be tuned on. - vad_pred_method (str): suffix of prediction file. Use to locate file. Should be either in "frame", "mean" or "median". - groundtruth_RTTM_dir (str): directory of ground-truth rttm files or a file contains the paths of them. - focus_metric (str): metrics we care most when tuning threshold. Should be either in "DetER", "FA", "MISS" - frame_length_in_sec (float): frame length. - num_workers (int): number of workers. + vad_pred_method (str): suffix of prediction file. Use to locate file. + Should be either in "frame", "mean" or "median". + groundtruth_RTTM_dir (str): Directory of ground-truth rttm files or a file contains the paths of them. + focus_metric (str): Metrics we care most when tuning threshold. Should be either in "DetER", "FA", "MISS" + frame_length_in_sec (float): Frame length. + num_workers (int): Number of workers. Returns: - best_threshold (float): threshold that gives lowest DetER. + best_threshold (float): Threshold that gives lowest DetER. """ min_score = 100 all_perf = {} @@ -986,7 +1003,8 @@ def plot( threshold (float): threshold for prediction score (from 0 to 1). per_args(dict): a dict that stores the thresholds for postprocessing. unit_frame_len (float): unit frame length in seconds for VAD predictions. - label_repeat (int): repeat the label for this number of times to match different frame lengths in preds and labels. + label_repeat (int): repeat the label for this number of times to match different + frame lengths in preds and labels. xticks_step (int): step size for xticks. """ plt.figure(figsize=[20, 2]) @@ -1254,7 +1272,8 @@ def stitch_segmented_asr_output( fout.flush() logging.info( - f"Finish stitch segmented ASR output to {stitched_output_manifest}, the speech segments info has been stored in directory {speech_segments_tensor_dir}" + f"Finish stitch segmented ASR output to {stitched_output_manifest}, " + f"the speech segments info has been stored in directory {speech_segments_tensor_dir}" ) return stitched_output_manifest @@ -1471,16 +1490,22 @@ def plot_sample_from_rttm( def align_labels_to_frames(probs, labels, threshold=0.2): """ - Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length (e.g., 20ms). - The threshold 0.2 is not important, since the actual ratio will always be close to an integer unless using frame/label - lengths that are not multiples of each other (e.g., 15ms frame length and 20ms label length), which is not valid. - The value 0.2 here is just for easier unit testing. + Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length + (e.g., 20ms). The threshold 0.2 is not critical, as the actual ratio will always be close to an + integer unless using frame/label lengths that are not multiples of each other (e.g., 15ms frame + length and 20ms label length), which is not valid. The value 0.2 is chosen for easier unit testing. + Args: - probs (List[float]): list of probabilities - labels (List[int]): list of labels - threshold (float): threshold for rounding ratio to integer + probs (List[float]): + List of probabilities. + labels (List[int]): + List of labels. + threshold (float): + Threshold for rounding the ratio to an integer. + Returns: - labels (List[int]): list of labels aligned to frames + labels (List[int]): + List of labels aligned to frames. """ frames_len = len(probs) labels_len = len(labels) @@ -1511,11 +1536,13 @@ def align_labels_to_frames(probs, labels, threshold=0.2): ratio = frames_len / labels_len res = frames_len % labels_len if ceil(ratio) - ratio < threshold: - # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels to make it a multiple of 2, and discard the redundant labels + # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels + # to make it a multiple of 2, and discard the redundant labels labels = labels.repeat_interleave(ceil(ratio), dim=0).long().tolist() labels = labels[:frames_len] else: - # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels to make it a multiple of 2 and add additional labels + # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels + # to make it a multiple of 2 and add additional labels labels = labels.repeat_interleave(floor(ratio), dim=0).long().tolist() if res > 0: labels += labels[-res:] @@ -1720,7 +1747,8 @@ def ts_vad_post_processing( """ Post-processing on diarization results using VAD style post-processing methods. These post-processing methods are inspired by the following paper: - Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). + Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: + a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). Args: ts_vad_binary_vec (Tensor): diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 8acd9fc08743..5edf1724dc2f 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -350,7 +350,9 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): class SpeechLLMAudioTextEntity(object): + """Class for SpeechLLM dataloader instance.""" def __init__(self, sid, audio_file, duration, context, answer, offset, speaker, orig_sr, lang) -> None: + """Initialize the AudioTextEntity for a SpeechLLM dataloader instance.""" self.id = sid self.audio_file = audio_file self.duration = duration @@ -642,7 +644,8 @@ def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: elif 'question' in item: # compatability with old manifests that uses 'question' as context key logging.warning( - f"Neither `{self.context_key}` is found nor `context_file` is set, but found `question` in item: {item}", + f"Neither `{self.context_key}` is found nor" + f"`context_file` is set, but found `question` in item: {item}", mode=logging_mode.ONCE, ) item['context'] = item.pop('question') @@ -739,7 +742,8 @@ def __init__( else: logging.info(f"Filtered duration for loading collection is {duration_filtered / 3600: .2f} hours.") logging.info( - f"Dataset successfully loaded with {len(data)} items and total duration provided from manifest is {total_duration / 3600: .2f} hours." + f"Dataset successfully loaded with {len(data)} items " + f"and total duration provided from manifest is {total_duration / 3600: .2f} hours." ) self.uniq_labels = sorted(set(map(lambda x: x.label, data))) @@ -880,13 +884,15 @@ def __init__( if len(data) == max_number: break - logging.info("# {} files loaded including # {} unique labels".format(len(data), len(self.uniq_labels))) + logging.info(f"# {len(data)} files loaded including # {len(self.uniq_labels)} unique labels") super().__init__(data) def relative_speaker_parser(self, seq_label): """Convert sequence of speaker labels to relative labels. Convert sequence of absolute speaker to sequence of relative speaker [E A C A E E C] -> [0 1 2 1 0 0 2] - In this seq of label , if label do not appear before, assign new relative labels len(pos); else reuse previous assigned relative labels. + In this seq of label , if label do not appear before, assign new relative labels len(pos); + else reuse previous assigned relative labels. + Args: seq_label (str): A string of a sequence of labels. @@ -923,10 +929,13 @@ def __init__( """Parse lists of feature files and sequences of labels. Args: - manifests_files: Either single string file or list of such - - manifests to yield items from. - max_number: Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. - index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; pass to `FeatureSequenceLabel` constructor. + manifests_files: + Either single string file or list of such manifests to yield items from. + max_number: + Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. + index_by_file_id: + If True, saves a mapping from filename base (ID) to index in data; + pass to `FeatureSequenceLabel` constructor. """ feature_files, seq_labels = [], [] @@ -1088,24 +1097,26 @@ def __init__( **kwargs, ): """ - Parse lists of audio files, durations, RTTM (Diarization annotation) files. Since diarization model infers only - two speakers, speaker pairs are generated from the total number of speakers in the session. + Parse lists of audio files, durations, RTTM (Diarization annotation) files. Since the diarization + model infers only two speakers, speaker pairs are generated from the total number of speakers in + the session. Args: manifest_filepath (str): - Path to input manifest json files. + Path to input manifest JSON files. emb_dict (Dict): Dictionary containing cluster-average embeddings and speaker mapping information. clus_label_dict (Dict): Segment-level speaker labels from clustering results. round_digit (int): - Number of digits to be rounded. + Number of digits to round. seq_eval_mode (bool): If True, F1 score will be calculated for each speaker pair during inference mode. pairwise_infer (bool): - If True, this dataset class operates in inference mode. In inference mode, a set of speakers in the input audio - is split into multiple pairs of speakers and speaker tuples (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then - fed into the diarization system to merge the individual results. + If True, this dataset class operates in inference mode. In inference mode, a set of + speakers in the input audio is split into multiple pairs of speakers and speaker tuples + (e.g., 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the diarization system to + merge the individual results. *args: Args to pass to `SpeechLabel` constructor. **kwargs: Kwargs to pass to `SpeechLabel` constructor. """ @@ -1244,7 +1255,7 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: class EndtoEndDiarizationLabel(_Collection): - """List of diarization audio-label correspondence with preprocessing.""" + """List of end-to-end diarization audio-label correspondence with preprocessing.""" OUTPUT_TYPE = collections.namedtuple( typename='DiarizationLabelEntity', @@ -1276,7 +1287,8 @@ def __init__( offsets (List[float]): List of offsets or None for each audio file. max_number (Optional[int]): Maximum number of samples to collect. Defaults to None. do_sort_by_duration (bool): If True, sort samples list by duration. Defaults to False. - index_by_file_id (bool): If True, saves a mapping from filename base (ID) to index in data. Defaults to False. + index_by_file_id (bool): If True, saves a mapping from filename base (ID) to index in data. + Defaults to False. """ if index_by_file_id: @@ -1694,7 +1706,8 @@ def __init__( manifests_files: Either single string file or list of such - manifests to yield items from. max_number: Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. - index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; pass to `FeatureSequenceLabel` constructor. + index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; + pass to `FeatureSequenceLabel` constructor. """ feature_files, labels, durations = [], [], [] From cb232682b9d142b086524d78d41cca10f6d55249 Mon Sep 17 00:00:00 2001 From: tango4j Date: Fri, 15 Nov 2024 20:34:41 +0000 Subject: [PATCH 11/47] Apply isort and black reformatting Signed-off-by: tango4j --- .../neural_diarizer/sortformer_diar_train.py | 4 +- .../asr/data/audio_to_diar_label.py | 52 ++++++------- nemo/collections/asr/metrics/der.py | 18 ++--- .../asr/metrics/multi_binary_acc.py | 6 +- .../asr/models/sortformer_diar_models.py | 4 +- .../asr/parts/utils/speaker_utils.py | 76 +++++++++---------- nemo/collections/asr/parts/utils/vad_utils.py | 58 +++++++------- .../common/parts/preprocessing/collections.py | 27 +++---- 8 files changed, 124 insertions(+), 121 deletions(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py index 75980d342c65..5231f4822886 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py @@ -34,6 +34,7 @@ seed_everything(42) + @hydra_runner(config_path="../conf/neural_diarizer", config_name="sortformer_diarizer_hybrid_loss_4spk-v1.yaml") def main(cfg): logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') @@ -43,5 +44,6 @@ def main(cfg): sortformer_model.maybe_init_from_pretrained_checkpoint(cfg) trainer.fit(sortformer_model) + if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index f47b5ca11f43..34aa0989e564 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -81,7 +81,7 @@ def extract_seg_info_from_rttm(rttm_lines, mapping_dict=None, target_spks=None): mapping_dict (dict): Mapping between the estimated speakers and the speakers in the ground-truth annotation. `mapping_dict` variable is only provided when the inference mode is running in sequence-eval mode. - Sequence eval mode uses the mapping between the estimated speakers and the speakers + Sequence eval mode uses the mapping between the estimated speakers and the speakers in ground-truth annotation. Returns: rttm_tup (tuple): @@ -116,10 +116,10 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec, List containing start and end time for each speaker segment label. `stt_list`, `end_list` and `speaker_list` are contained. frame_per_sec (int): - Number of feature frames per second. This quantity is determined by + Number of feature frames per second. This quantity is determined by `window_stride` variable in preprocessing module. target_spks (tuple): - Speaker indices that are generated from combinations. + Speaker indices that are generated from combinations. If there are only one or two speakers, only a single `target_spks` variable is generated. @@ -151,13 +151,13 @@ def get_subsegments_to_timestamps( subsegments: List[Tuple[float, float]], feat_per_sec: int = 100, max_end_ts: float = None, decimals=2 ): """ - Convert subsegment timestamps to scale timestamps by multiplying with the feature rate (`feat_per_sec`) - and rounding. Segment is consisted of many subsegments and sugsegments are equivalent to `frames` + Convert subsegment timestamps to scale timestamps by multiplying with the feature rate (`feat_per_sec`) + and rounding. Segment is consisted of many subsegments and sugsegments are equivalent to `frames` in end-to-end speaker diarization models. Args: subsegments (List[Tuple[float, float]]): - A list of tuples where each tuple contains the start and end times of a subsegment + A list of tuples where each tuple contains the start and end times of a subsegment (frames in end-to-end models). >>> subsegments = [[t0_start, t0_duration], [t1_start, t1_duration],..., [tN_start, tN_duration]] feat_per_sec (int, optional): @@ -251,7 +251,7 @@ def get_frame_targets_from_rttm( List containing start and end time for each speaker segment label. stt_list, end_list and speaker_list are contained. feat_per_sec (int): - Number of feature frames per second. + Number of feature frames per second. This quantity is determined by window_stride variable in preprocessing module. target_spks (tuple): Speaker indices that are generated from combinations. If there are only one or two speakers, @@ -415,16 +415,16 @@ def assign_labels_to_longer_segs(self, uniq_id, base_scale_clus_label): def get_diar_target_labels(self, uniq_id, sample, fr_level_target): """ - Convert frame-level diarization target variable into segment-level target variable. - Since the granularity is reduced from frame level (10ms) to segment level (100ms~500ms), - we need a threshold value, `soft_label_thres`, which determines the label of each segment + Convert frame-level diarization target variable into segment-level target variable. + Since the granularity is reduced from frame level (10ms) to segment level (100ms~500ms), + we need a threshold value, `soft_label_thres`, which determines the label of each segment based on the overlap between a segment range (start and end time) and the frame-level target variable. Args: uniq_id (str): Unique file ID that refers to an input audio file and corresponding RTTM (Annotation) file. sample: - `DiarizationSpeechLabel` instance containing sample information such as + `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath. fr_level_target (torch.tensor): Tensor containing label for each feature-level frame. @@ -433,7 +433,7 @@ def get_diar_target_labels(self, uniq_id, sample, fr_level_target): seg_target (torch.tensor): Tensor containing binary speaker labels for base-scale segments. base_clus_label (torch.tensor): - Representative speaker label for each segment. This variable only has one speaker label + Representative speaker label for each segment. This variable only has one speaker label for each base-scale segment. -1 means that there is no corresponding speaker in the target_spks tuple. """ @@ -469,7 +469,7 @@ def parse_rttm_for_ms_targets(self, sample): Args: sample: - `DiarizationSpeechLabel` instance containing sample information such as + `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath. target_spks (tuple): Speaker indices that are generated from combinations. If there are only one or two speakers, @@ -591,7 +591,7 @@ class _AudioMSDDInferDataset(Dataset): emb_dict (dict): Dictionary containing cluster-average embeddings and speaker mapping information. emb_seq (dict): - Dictionary containing multiscale speaker embedding sequence, + Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps. clus_label_dict (dict): Subsegment-level (from base-scale) speaker labels from clustering results. @@ -691,8 +691,8 @@ def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target): """ Generate base-scale level binary diarization label from frame-level target matrix. For the given frame-level speaker target matrix fr_level_target, we count the number of frames that belong to each speaker and calculate - ratios for each speaker into the `soft_label_vec` variable. Finally, `soft_label_vec` variable is compared - with `soft_label_thres` to determine whether a label vector should contain 0 or 1 for each speaker bin. + ratios for each speaker into the `soft_label_vec` variable. Finally, `soft_label_vec` variable is compared + with `soft_label_thres` to determine whether a label vector should contain 0 or 1 for each speaker bin. Note that seg_target variable has dimension of (number of base-scale segments x 2) dimension. Example of seg_target: @@ -739,7 +739,7 @@ def __getitem__(self, index): if avg_embs.shape[2] > self.max_spks: raise ValueError( - f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to " + f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to " f"self.max_num_speakers {self.max_spks}" ) @@ -834,7 +834,7 @@ def _msdd_train_collate_fn(self, batch): def _msdd_infer_collate_fn(self, batch): """ - Collate batch of feats (speaker embeddings), feature lengths, target label sequences + Collate batch of feats (speaker embeddings), feature lengths, target label sequences and cluster-average embeddings. Args: @@ -959,12 +959,12 @@ class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset): emb_dict (dict): Dictionary containing cluster-average embeddings and speaker mapping information. emb_seq (dict): - Dictionary containing multiscale speaker embedding sequence, scale mapping + Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps. clus_label_dict (dict): Subsegment-level (from base-scale) speaker labels from clustering results. soft_label_thres (float): - Threshold that determines speaker labels of segments depending on the overlap + Threshold that determines speaker labels of segments depending on the overlap with groundtruth speaker timestamps. featurizer: Featurizer instance for generating features from raw waveform. @@ -973,11 +973,11 @@ class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset): seq_eval_mode (bool): If True, F1 score will be calculated for each speaker pair during inference mode. window_stride (float): - Window stride for acoustic feature. This value is used for calculating the numbers of + Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames. pairwise_infer (bool): - If True, this Dataset class operates in inference mode. In inference mode, a set of speakers - in the input audio is split into multiple pairs of speakers and speaker tuples + If True, this Dataset class operates in inference mode. In inference mode, a set of speakers + in the input audio is split into multiple pairs of speakers and speaker tuples (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the MSDD to merge the individual results. """ @@ -1225,7 +1225,7 @@ def __getitem__(self, index): uniq_id = self.get_uniq_id_with_range(sample) audio_signal = self.featurizer.process(sample.audio_file, offset=offset, duration=session_len_sec) - # We should resolve the length mis-match from the round-off errors between these two variables: + # We should resolve the length mis-match from the round-off errors between these two variables: # `session_len_sec` and `audio_signal.shape[0]` session_len_sec = ( np.floor(audio_signal.shape[0] / self.featurizer.sample_rate * self.floor_decimal) / self.floor_decimal @@ -1252,14 +1252,14 @@ def _eesd_train_collate_fn(self, batch): Returns: audio_signal (torch.Tensor): - A tensor containing the raw waveform samples (time series) loaded from the `audio_filepath` + A tensor containing the raw waveform samples (time series) loaded from the `audio_filepath` in the input manifest file. feature_length (torch.Tensor): A tensor containing the lengths of the raw waveform samples. targets (torch.Tensor): Groundtruth speaker labels for the given input embedding sequence. target_lens (torch.Tensor): - A tensor containing the number of segments for each sample in the batch, necessary for + A tensor containing the number of segments for each sample in the batch, necessary for reshaping inputs to the EESD model. """ packed_batch = list(zip(*batch)) diff --git a/nemo/collections/asr/metrics/der.py b/nemo/collections/asr/metrics/der.py index 22c9a76b7fc9..7496f700341f 100644 --- a/nemo/collections/asr/metrics/der.py +++ b/nemo/collections/asr/metrics/der.py @@ -145,20 +145,20 @@ def score_labels( Args: - AUDIO_RTTM_MAP (dict): + AUDIO_RTTM_MAP (dict): Dictionary containing information provided from manifestpath - all_reference (list[uniq_name,Annotation]): + all_reference (list[uniq_name,Annotation]): Reference annotations for score calculation - all_hypothesis (list[uniq_name,Annotation]): + all_hypothesis (list[uniq_name,Annotation]): Hypothesis annotations for score calculation - verbose (bool): + verbose (bool): Warns if RTTM file is not found. Returns: - metric (pyannote.DiarizationErrorRate): - Pyannote Diarization Error Rate metric object. + metric (pyannote.DiarizationErrorRate): + Pyannote Diarization Error Rate metric object. This object contains detailed scores of each audiofile. - mapping (dict): + mapping (dict): Mapping dict containing the mapping speaker label for each audio input < Caveat > @@ -178,7 +178,7 @@ def score_labels( correct_spk_count += 1 if verbose and len(ref_labels.labels()) != len(hyp_labels.labels()): logging.info( - f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, " + f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, " f"Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}" ) uem_obj = None @@ -455,4 +455,4 @@ def concat_perm_word_error_rate( cpWER_values.append(cpWER) hyps_spk.append(min_hypothesis) refs_spk.append(concat_reference) - return cpWER_values, hyps_spk, refs_spk \ No newline at end of file + return cpWER_values, hyps_spk, refs_spk diff --git a/nemo/collections/asr/metrics/multi_binary_acc.py b/nemo/collections/asr/metrics/multi_binary_acc.py index 8ad09c842636..3a99769ebd25 100644 --- a/nemo/collections/asr/metrics/multi_binary_acc.py +++ b/nemo/collections/asr/metrics/multi_binary_acc.py @@ -84,15 +84,15 @@ def __init__(self, dist_sync_on_step=False): def update( self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor, cumulative=False ) -> torch.Tensor: - """ + """ Update the metric with the given predictions, targets, and signal lengths to the metric instance. - + Args: preds (torch.Tensor): Predicted values. targets (torch.Tensor): Target values. signal_lengths (torch.Tensor): Length of each sequence in the batch input. cumulative (bool): Whether to accumulate the values over time. - + Returns: f1_score (torch.Tensor): F1 score calculated from the predicted value and binarized target values. """ diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index e3c14dd77c65..2e15e095b77a 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -208,8 +208,8 @@ def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): def test_dataloader(self): if self._test_dl is not None: return self._test_dl - return None - + return None + @property def input_types(self) -> Optional[Dict[str, NeuralType]]: if hasattr(self.preprocessor, '_sample_rate'): diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 1e7dda91c9e7..15ec8a24a3bd 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -76,11 +76,11 @@ def audio_rttm_map(manifest, attach_dur=False): """ This function creates AUDIO_RTTM_MAP which is used by all diarization components to extract embeddings, cluster and unify time stamps - + Args: manifest (str): Path to the manifest file attach_dur (bool, optional): If True, attach duration information to the unique name. Defaults to False. - + Returns: AUDIO_RTTM_MAP (dict) : Dictionary with unique names as keys and corresponding metadata as values. """ @@ -117,7 +117,7 @@ def audio_rttm_map(manifest, attach_dur=False): raise KeyError( f"file {meta['audio_filepath']} is already part of AUDIO_RTTM_MAP, it might be duplicated, " "Note: file basename must be unique" - ) + ) return AUDIO_RTTM_MAP @@ -247,7 +247,7 @@ def get_embs_and_timestamps(multiscale_embeddings_and_timestamps, multiscale_arg def get_timestamps(multiscale_timestamps, multiscale_args_dict): """ The timestamps in `multiscale_timestamps` dictionary are indexed by scale index. - This function rearranges the extracted speaker embedding and timestamps by unique ID + This function rearranges the extracted speaker embedding and timestamps by unique ID to make the further processing more convenient. Args: @@ -442,19 +442,19 @@ def perform_clustering( 'embeddings' : Tensor containing embeddings. Dimensions:(# of embs) x (emb. dimension) 'timestamps' : Tensor containing ime stamps list for each audio recording 'multiscale_segment_counts' : Tensor containing the number of segments for each scale - AUDIO_RTTM_MAP (dict): + AUDIO_RTTM_MAP (dict): AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path - out_rttm_dir (str): + out_rttm_dir (str): Path to write predicted rttms - clustering_params (dict): - Clustering parameters provided through config that contains max_num_speakers (int), - oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) + clustering_params (dict): + Clustering parameters provided through config that contains max_num_speakers (int), + oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) and enhance_count_threshold (int). - use_torch_script (bool): + use_torch_script (bool): Boolean that determines whether to use torch.jit.script for speaker clustering - device (torch.device): + device (torch.device): Device we are running on ('cpu', 'cuda'). - verbose (bool): + verbose (bool): Enable TQDM progress bar. Returns: @@ -624,7 +624,7 @@ def read_rttm_lines(rttm_file_path): raise FileNotFoundError( "Requested to construct manifest from rttm with oracle VAD option " f"or from NeMo VAD but received filename as {rttm_file_path}" - ) + ) return lines @@ -893,7 +893,7 @@ def segments_manifest_to_subsegments_manifest( Generate subsegments manifest from segments manifest file Args: segments_manifest file (str): path to segments manifest file, typically from VAD output - subsegments_manifest_file (str): path to output subsegments manifest file + subsegments_manifest_file (str): path to output subsegments manifest file (default (None) : writes to current working directory) window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift @@ -968,7 +968,7 @@ def get_subsegments( it results in (10/0.08)+1 = 125 + 1 frames. Returns: - subsegments (List[tuple[float, float]]): subsegments generated for the segments as + subsegments (List[tuple[float, float]]): subsegments generated for the segments as list of tuple of start and duration of each subsegment """ subsegments: List[List[float]] = [] @@ -1051,9 +1051,9 @@ def tensor_to_list(range_tensor: torch.Tensor) -> List[List[float]]: def generate_diarization_output_lines(speaker_timestamps: List[List[float]], model_spk_num: int) -> List[str]: - """ + """ Generate diarization output lines list from the speaker timestamps list by merging overlapping intervals. - + Args: speaker_timestamps (list): List containing the start and end time of the speech intervals for each speaker. @@ -1061,7 +1061,7 @@ def generate_diarization_output_lines(speaker_timestamps: List[List[float]], mod >>> speaker_timestamps = [[0.5, 3.12], [3.51, 7.26],... ] model_spk_num (int): Number of speakers in the model. - + Returns: speaker_lines_total (list): List containing the diarization output lines in the format: @@ -1393,20 +1393,20 @@ def get_overlap_stamps(cont_stamps: List[str], ovl_spk_idx: List[str]): def get_adaptive_threshold(estimated_num_of_spks: int, min_threshold: float, overlap_infer_spk_limit: int): """ - This function controls the magnitude of the sigmoid threshold based on the estimated number of - speakers. As the number of speakers becomes larger, diarization error rate is very sensitive - to overlap speech detection. This function linearly increases the threshold in proportion to - the estimated number of speakers so more confident overlap speech results are reflected when + This function controls the magnitude of the sigmoid threshold based on the estimated number of + speakers. As the number of speakers becomes larger, diarization error rate is very sensitive + to overlap speech detection. This function linearly increases the threshold in proportion to + the estimated number of speakers so more confident overlap speech results are reflected when the number of estimated speakers is relatively high. Args: estimated_num_of_spks (int): Estimated number of speakers from the clustering result. min_threshold (float): - Sigmoid threshold value from the config file. This threshold value is the minimum + Sigmoid threshold value from the config file. This threshold value is the minimum threshold when `estimated_num_of_spks=2`. overlap_infer_spk_limit (int): - If the `estimated_num_of_spks` is less than `overlap_infer_spk_limit`, overlap speech + If the `estimated_num_of_spks` is less than `overlap_infer_spk_limit`, overlap speech estimation is skipped. Returns: @@ -1423,37 +1423,37 @@ def generate_speaker_timestamps( clus_labels: List[Union[float, int]], msdd_preds: List[torch.Tensor], **params ) -> Tuple[List[str], List[str]]: """ - Generate speaker timestamps from the segmentation information. If `use_clus_as_main=True`, use - clustering result for main speaker labels and use timestamps from the predicted sigmoid values. - In this function, the main speaker labels in `maj_labels` exist for every subsegment step, while + Generate speaker timestamps from the segmentation information. If `use_clus_as_main=True`, use + clustering result for main speaker labels and use timestamps from the predicted sigmoid values. + In this function, the main speaker labels in `maj_labels` exist for every subsegment step, while overlap speaker labels in `ovl_labels` only exist for segments where overlap speech occurs. Args: clus_labels (list): List containing integer-valued speaker clustering results. msdd_preds (list): - List containing tensors of the predicted sigmoid values. Each tensor has shape of: + List containing tensors of the predicted sigmoid values. Each tensor has shape of: (Session length, estimated number of speakers). params: Parameters for generating RTTM output and evaluation. Parameters include: infer_overlap (bool): If False, overlap speech will not be detected. - use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. - If False, only MSDD output is used for constructing output + use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. + If False, only MSDD output is used for constructing output RTTM files. overlap_infer_spk_limit (int): Above this limit, overlap-speech detection is bypassed. - use_adaptive_thres (bool): Boolean that determines whether to use adaptive thresholds + use_adaptive_thres (bool): Boolean that determines whether to use adaptive thresholds depending on the estimated number of speakers. max_overlap_spks (int): Maximum number of overlap speakers detected. Default is 2. threshold (float): Sigmoid threshold for MSDD output. Returns: maj_labels (list): - List containing string-formatted single-speaker speech segment timestamps and corresponding + List containing string-formatted single-speaker speech segment timestamps and corresponding speaker labels. Example: [..., '551.685 552.77 speaker_1', '552.99 554.43 speaker_0', '554.97 558.19 speaker_0', ...] ovl_labels (list): - List containing string-formatted additional overlapping speech segment timestamps and - corresponding speaker labels. Note that `ovl_labels` includes only overlapping speech that + List containing string-formatted additional overlapping speech segment timestamps and + corresponding speaker labels. Note that `ovl_labels` includes only overlapping speech that is not included in `maj_labels`. Example: [..., '152.495 152.745 speaker_1', '372.71 373.085 speaker_0', '554.97 555.885 speaker_1', ...] """ @@ -1510,7 +1510,7 @@ def get_id_tup_dict(uniq_id_list: List[str], test_data_collection, preds_list: L uniq_id_list (list): List containing the `uniq_id` values. test_data_collection (collections.DiarizationLabelEntity): - Class instance that is containing session information such as targeted speaker indices, + Class instance that is containing session information such as targeted speaker indices, audio filepath and RTTM filepath. preds_list (list): List containing tensors of predicted sigmoid values. @@ -1540,14 +1540,14 @@ def prepare_split_data(manifest_filepath, _out_dir, multiscale_args_dict, global Returns: multiscale_args_dict (dict): - - Dictionary containing two types of arguments: multi-scale weights and subsegment timestamps + - Dictionary containing two types of arguments: multi-scale weights and subsegment timestamps for each data sample. - Each data sample has two keys: `multiscale_weights` and `scale_dict`. - `multiscale_weights` key contains a list containing multiscale weights. - `scale_dict` is indexed by integer keys which are scale index. - - Each data sample is indexed by using the following naming convention: + - Each data sample is indexed by using the following naming convention: `__` - + Example: `fe_03_00106_mixed_626310_642300` """ speaker_dir = os.path.join(_out_dir, 'speaker_outputs') diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index 0fbda543ca11..83a811ee4adb 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -246,7 +246,7 @@ def generate_overlap_vad_seq( out_dir: str = None, ) -> str: """ - Generate predictions with overlapping input windows/segments. + Generate predictions with overlapping input windows/segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows. Two common smoothing filters are supported: majority vote (median) and average (mean). This function uses multiprocessing to speed up. @@ -311,7 +311,7 @@ def generate_overlap_vad_seq_per_tensor( frame: torch.Tensor, per_args: Dict[str, float], smoothing_method: str ) -> torch.Tensor: """ - Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate + Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate prediction with overlapping input window/segments. See description in generate_overlap_vad_seq. Use this for single instance pipeline. """ @@ -473,7 +473,7 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te Binarize predictions to speech and non-speech Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py @@ -488,7 +488,7 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te Returns: speech_segments(torch.Tensor): A tensor of speech segment in the form of: - `torch.Tensor([[start1, end1], [start2, end2]])`. + `torch.Tensor([[start1, end1], [start2, end2]])`. """ frame_length_in_sec = per_args.get('frame_length_in_sec', 0.01) @@ -538,7 +538,7 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te def remove_segments(original_segments: torch.Tensor, to_be_removed_segments: torch.Tensor) -> torch.Tensor: """ Remove speech segments list in to_be_removed_segments from original_segments. - (Example) Remove torch.Tensor([[start2, end2],[start4, end4]]) + (Example) Remove torch.Tensor([[start2, end2],[start4, end4]]) from torch.Tensor([[start1, end1],[start2, end2],[start3, end3], [start4, end4]]), -> torch.Tensor([[start1, end1],[start3, end3]]) @@ -565,26 +565,26 @@ def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torc Filter out short non-speech and speech segments. Reference: - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. - Implementation: + Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: - speech_segments (torch.Tensor): - A tensor of speech segments in the format + speech_segments (torch.Tensor): + A tensor of speech segments in the format torch.Tensor([[start1, end1], [start2, end2]]). per_args: - min_duration_on (float): + min_duration_on (float): Threshold for small non-speech deletion. - min_duration_off (float): + min_duration_off (float): Threshold for short speech segment deletion. - filter_speech_first (float): + filter_speech_first (float): Whether to perform short speech segment deletion first. Use 1.0 to represent True. Returns: - speech_segments (torch.Tensor): - A tensor of filtered speech segments in the format + speech_segments (torch.Tensor): + A tensor of filtered speech segments in the format torch.Tensor([[start1, end1], [start2, end2]]). """ if speech_segments.shape == torch.Size([0]): @@ -722,7 +722,7 @@ def generate_vad_segment_table( 17,18, speech Args: vad_pred_dir (str): directory of prediction files to be processed. - postprocessing_params (dict): dictionary of thresholds for prediction score. + postprocessing_params (dict): dictionary of thresholds for prediction score. See details in binarization and filtering. frame_length_in_sec (float): frame length. out_dir (str): output dir of generated table/csv file. @@ -834,12 +834,12 @@ def vad_tune_threshold_on_dev( num_workers: int = 20, ) -> Tuple[dict, dict]: """ - Tune thresholds on dev set. Return best thresholds which gives the lowest + Tune thresholds on dev set. Return best thresholds which gives the lowest detection error rate (DetER) in thresholds. - + Args: params (dict): dictionary of parameters to be tuned on. - vad_pred_method (str): suffix of prediction file. Use to locate file. + vad_pred_method (str): suffix of prediction file. Use to locate file. Should be either in "frame", "mean" or "median". groundtruth_RTTM_dir (str): Directory of ground-truth rttm files or a file contains the paths of them. focus_metric (str): Metrics we care most when tuning threshold. Should be either in "DetER", "FA", "MISS" @@ -1003,7 +1003,7 @@ def plot( threshold (float): threshold for prediction score (from 0 to 1). per_args(dict): a dict that stores the thresholds for postprocessing. unit_frame_len (float): unit frame length in seconds for VAD predictions. - label_repeat (int): repeat the label for this number of times to match different + label_repeat (int): repeat the label for this number of times to match different frame lengths in preds and labels. xticks_step (int): step size for xticks. """ @@ -1490,21 +1490,21 @@ def plot_sample_from_rttm( def align_labels_to_frames(probs, labels, threshold=0.2): """ - Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length - (e.g., 20ms). The threshold 0.2 is not critical, as the actual ratio will always be close to an - integer unless using frame/label lengths that are not multiples of each other (e.g., 15ms frame + Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length + (e.g., 20ms). The threshold 0.2 is not critical, as the actual ratio will always be close to an + integer unless using frame/label lengths that are not multiples of each other (e.g., 15ms frame length and 20ms label length), which is not valid. The value 0.2 is chosen for easier unit testing. Args: - probs (List[float]): + probs (List[float]): List of probabilities. - labels (List[int]): + labels (List[int]): List of labels. - threshold (float): + threshold (float): Threshold for rounding the ratio to an integer. Returns: - labels (List[int]): + labels (List[int]): List of labels aligned to frames. """ frames_len = len(probs) @@ -1536,12 +1536,12 @@ def align_labels_to_frames(probs, labels, threshold=0.2): ratio = frames_len / labels_len res = frames_len % labels_len if ceil(ratio) - ratio < threshold: - # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels + # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels # to make it a multiple of 2, and discard the redundant labels labels = labels.repeat_interleave(ceil(ratio), dim=0).long().tolist() labels = labels[:frames_len] else: - # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels + # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels # to make it a multiple of 2 and add additional labels labels = labels.repeat_interleave(floor(ratio), dim=0).long().tolist() if res > 0: @@ -1747,7 +1747,7 @@ def ts_vad_post_processing( """ Post-processing on diarization results using VAD style post-processing methods. These post-processing methods are inspired by the following paper: - Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: + Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). Args: diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index a4728c29ff06..b6db109afa58 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -477,6 +477,7 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): class SpeechLLMAudioTextEntity(object): """Class for SpeechLLM dataloader instance.""" + def __init__(self, sid, audio_file, duration, context, answer, offset, speaker, orig_sr, lang) -> None: """Initialize the AudioTextEntity for a SpeechLLM dataloader instance.""" self.id = sid @@ -1016,9 +1017,9 @@ def __init__( def relative_speaker_parser(self, seq_label): """Convert sequence of speaker labels to relative labels. Convert sequence of absolute speaker to sequence of relative speaker [E A C A E E C] -> [0 1 2 1 0 0 2] - In this seq of label , if label do not appear before, assign new relative labels len(pos); + In this seq of label , if label do not appear before, assign new relative labels len(pos); else reuse previous assigned relative labels. - + Args: seq_label (str): A string of a sequence of labels. @@ -1055,12 +1056,12 @@ def __init__( """Parse lists of feature files and sequences of labels. Args: - manifests_files: + manifests_files: Either single string file or list of such manifests to yield items from. - max_number: + max_number: Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. - index_by_file_id: - If True, saves a mapping from filename base (ID) to index in data; + index_by_file_id: + If True, saves a mapping from filename base (ID) to index in data; pass to `FeatureSequenceLabel` constructor. """ @@ -1223,8 +1224,8 @@ def __init__( **kwargs, ): """ - Parse lists of audio files, durations, RTTM (Diarization annotation) files. Since the diarization - model infers only two speakers, speaker pairs are generated from the total number of speakers in + Parse lists of audio files, durations, RTTM (Diarization annotation) files. Since the diarization + model infers only two speakers, speaker pairs are generated from the total number of speakers in the session. Args: @@ -1239,9 +1240,9 @@ def __init__( seq_eval_mode (bool): If True, F1 score will be calculated for each speaker pair during inference mode. pairwise_infer (bool): - If True, this dataset class operates in inference mode. In inference mode, a set of - speakers in the input audio is split into multiple pairs of speakers and speaker tuples - (e.g., 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the diarization system to + If True, this dataset class operates in inference mode. In inference mode, a set of + speakers in the input audio is split into multiple pairs of speakers and speaker tuples + (e.g., 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the diarization system to merge the individual results. *args: Args to pass to `SpeechLabel` constructor. **kwargs: Kwargs to pass to `SpeechLabel` constructor. @@ -1413,7 +1414,7 @@ def __init__( offsets (List[float]): List of offsets or None for each audio file. max_number (Optional[int]): Maximum number of samples to collect. Defaults to None. do_sort_by_duration (bool): If True, sort samples list by duration. Defaults to False. - index_by_file_id (bool): If True, saves a mapping from filename base (ID) to index in data. + index_by_file_id (bool): If True, saves a mapping from filename base (ID) to index in data. Defaults to False. """ @@ -1832,7 +1833,7 @@ def __init__( manifests_files: Either single string file or list of such - manifests to yield items from. max_number: Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. - index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; + index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; pass to `FeatureSequenceLabel` constructor. """ From 4a266b93a38384236f41a58c642170dd2c4fac03 Mon Sep 17 00:00:00 2001 From: taejinp Date: Fri, 15 Nov 2024 14:49:28 -0800 Subject: [PATCH 12/47] Resolving pylint issues Signed-off-by: taejinp --- .../neural_diarizer/e2e_diarize_speech.py | 22 ++++++++- .../neural_diarizer/sortformer_diar_train.py | 1 + .../asr/data/audio_to_diar_label.py | 2 +- .../asr/data/audio_to_diar_label_lhotse.py | 2 + .../asr/models/sortformer_diar_models.py | 6 ++- .../asr/modules/sortformer_modules.py | 45 ++++++++++-------- .../asr/parts/utils/asr_multispeaker_utils.py | 46 ++++++++++++------- .../common/parts/preprocessing/collections.py | 10 ++-- 8 files changed, 88 insertions(+), 46 deletions(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 0f90e70eff80..cb09b7df3100 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -45,6 +45,12 @@ @dataclass class PostProcessingParams: + """ + Postprocessing parameters for end-to-end speaker diarization models. + These parameters can significantly affect DER performance depending on the evaluation style and the dataset. + It is recommended to tune these parameters based on the evaluation style and the dataset + to achieve the desired DER performance. + """ onset: float = 0.5 # Onset threshold for detecting the beginning and end of a speech offset: float = 0.5 # Offset threshold for detecting the end of a speech pad_onset: float = 0.0 # Adding durations before each speech segment @@ -55,7 +61,7 @@ class PostProcessingParams: @dataclass class DiarizationConfig: - # Required configs + """Diarization configuration parameters for inference.""" model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model audio_dir: Optional[str] = None # Path to a directory which contains audio files @@ -221,6 +227,17 @@ def run_optuna_hyperparam_search( preds_list: List[torch.Tensor], temp_out_dir: str, ): + """ + Run Optuna hyperparameter optimization for speaker diarization. + + Args: + cfg (DiarizationConfig): The configuration object containing model and dataset details. + postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. + infer_audio_rttm_dict (dict): dictionary of audio file path, offset, duration and RTTM filepath. + preds_list (List[torch.Tensor]): list of prediction matrices containing sigmoid values for each speaker. + Dimension: [(1, frames, num_speakers), ..., (1, frames, num_speakers)] + temp_out_dir (str): temporary directory for storing intermediate outputs. + """ worker_function = lambda trial: diarization_objective( trial=trial, postprocessing_cfg=postprocessing_cfg, @@ -300,6 +317,7 @@ def convert_pred_mat_to_segments( @hydra_runner(config_name="DiarizationConfig", schema=DiarizationConfig) def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: + """Main function for end-to-end speaker diarization inference.""" for key in cfg: cfg[key] = None if cfg[key] == 'None' else cfg[key] diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py index 75980d342c65..bff2218e361f 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py @@ -36,6 +36,7 @@ @hydra_runner(config_path="../conf/neural_diarizer", config_name="sortformer_diarizer_hybrid_loss_4spk-v1.yaml") def main(cfg): + """Main function for training the sortformer diarizer model.""" logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index f47b5ca11f43..48454e310070 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py index 8d11c4c1167d..14723a398fe8 100644 --- a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py +++ b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py @@ -37,6 +37,8 @@ class LhotseAudioToSpeechE2ESpkDiarDataset(torch.utils.data.Dataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Define the output types of the dataset. + """ return { 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), 'a_sig_length': NeuralType(tuple('B'), LengthsType()), diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index e3c14dd77c65..eadabe642779 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -172,7 +172,8 @@ def __setup_dataloader_from_config(self, config): soft_targets=config.soft_targets if 'soft_targets' in config else False, ) logging.info( - f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader step B: {time.time() - time_flag}" + f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader" + f"step B: {time.time() - time_flag}" ) self.data_collection = dataset.collection @@ -581,4 +582,5 @@ def test_batch( def diarize( self, ): + """One-clieck runner function for diarization.""" raise NotImplementedError diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py index e0b5b15094b6..36b7438c9a92 100644 --- a/nemo/collections/asr/modules/sortformer_modules.py +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import OrderedDict - import torch import torch.nn as nn import torch.nn.functional as F @@ -26,24 +24,13 @@ class SortformerModules(NeuralModule, Exportable): """ - Multi-scale Diarization Decoder (MSDD) for overlap-aware diarization and improved diarization accuracy from clustering diarizer. - Based on the paper: Taejin Park et. al, "Multi-scale Speaker Diarization with Dynamic Scale Weighting", Interspeech 2022. - Arxiv version: https://arxiv.org/pdf/2203.15974.pdf - - Args: - num_spks (int): - Max number of speakers that are processed by the model. In `MSDD_module`, `num_spks=2` for pairwise inference. - hidden_size (int): - Number of hidden units in sequence models and intermediate layers. - dropout_rate (float): - Dropout rate for linear layers, CNN and LSTM. - fc_d_model (int): - Dimension of the embedding vectors. - tf_d_model (int): - Dimension of the embedding vectors. + A class including auxiliary functions for Sortformer models. + This class contains and will contain the following functions that performs streaming features, + and any neural layers that are not included in the NeMo neural modules (e.g. Transformer, Fast-Conformer). """ def init_weights(self, m): + """Init weights for linear layers.""" if type(m) == nn.Linear: torch.nn.init.xavier_uniform_(m.weight) m.bias.data.fill_(0.01) @@ -56,6 +43,19 @@ def __init__( fc_d_model: int = 512, tf_d_model: int = 192, ): + """ + Args: + num_spks (int): + Max number of speakers that are processed by the model. + hidden_size (int): + Number of hidden units in sequence models and intermediate layers. + dropout_rate (float): + Dropout rate for linear layers, CNN and LSTM. + fc_d_model (int): + Dimension of the embedding vectors. + tf_d_model (int): + Dimension of the embedding vectors. + """ super().__init__() self.fc_d_model = fc_d_model self.tf_d_model = tf_d_model @@ -91,6 +91,15 @@ def length_to_mask(self, context_embs): return mask.float().to(context_embs.device) def forward_speaker_sigmoids(self, hidden_out): + """ + A set of layers for predicting speaker probabilities with a sigmoid activation function. + + Args: + hidden_out (torch.Tensor): tensor of shape (batch_size, seq_len, hidden_size) + + Returns: + preds (torch.Tensor): tensor of shape (batch_size, num_spks) containing speaker probabilities + """ hidden_out = self.dropout(F.relu(hidden_out)) hidden_out = self.first_hidden_to_hidden(hidden_out) hidden_out = self.dropout(F.relu(hidden_out)) diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index e945439bf8fa..46bcf2f1a8c6 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -28,7 +28,8 @@ def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres: float = 0.5) -> thres (float): The threshold value for discretizing the matrix values. Returns: - mask_max_indices (Tensor): A torch tensor representing the discretized matrix with the first nonzero value in each row. + mask_max_indices (Tensor): A torch tensor representing the discretized matrix with the first + nonzero value in each row. """ # Discretize the matrix to the specified maximum capacity labels_discrete = mat.clone() @@ -229,7 +230,8 @@ def get_mask_from_segments( speaker_to_idx_map (dict): A dictionary mapping speaker names to indices. num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default feat_per_sec (int): number of frames per second, 100 by default, 0.01s frame rate - ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. + ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. + Will be removed in the future. Returns: mask (Tensor): A numpy array of shape (num_speakers, encoder_hidden_len). @@ -315,25 +317,34 @@ def speaker_to_target( ignore_num_spk_mismatch: bool = True, soft_thres: float = 0.5, ): - ''' - Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape (num_speaker, hidden_length) - This function is needed for speaker diarization with ASR model trainings. + """ + Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape + (num_speaker, hidden_length). This function is needed for speaker diarization with ASR model trainings. Args: - a_cut (MonoCut, MixedCut): Lhotse Cut instance which is MonoCut or MixedCut instance. - num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default - num_sample_per_mel_frame (int): number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) - num_mel_frame_per_asr_frame (int): encoder subsampling_factor, 8 by default - spk_tar_all_zero (Tensor): set to True gives all zero "mask" - boundary_segments (bool): set to True to include segments containing the boundary of the cut, False by default for multi-speaker ASR training - soft_label (bool): set to True to use soft label that enables values in [0, 1] range, False by default and leads to binary labels. - ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. + a_cut (MonoCut, MixedCut): + Lhotse Cut instance which is MonoCut or MixedCut instance. + num_speakers (int): + Max number of speakers for all cuts ("mask" dim0), 4 by default + num_sample_per_mel_frame (int): + Number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) + num_mel_frame_per_asr_frame (int): + Encoder subsampling_factor, 8 by default + spk_tar_all_zero (Tensor): + Set to True gives all zero "mask" + boundary_segments (bool): + Set to True to include segments containing the boundary of the cut, + False by default for multi-speaker ASR training + soft_label (bool): + Set to True to use soft label that enables values in [0, 1] range, + False by default and leads to binary labels. + ignore_num_spk_mismatch (bool): + This is a temporary solution to handle speaker mismatch. Will be removed in the future. Returns: - mask (Tensor): speaker mask with shape (num_speaker, hidden_lenght) - ''' + mask (Tensor): Speaker mask with shape (num_speaker, hidden_lenght) + """ # get cut-related segments from rttms - # basename = os.path.basename(a_cut.rttm_filepath).replace('.rttm', '') if isinstance(a_cut, MixedCut): cut_list = [track.cut for track in a_cut.tracks if isinstance(track.cut, MonoCut)] offsets = [track.offset for track in a_cut.tracks if isinstance(track.cut, MonoCut)] @@ -374,7 +385,8 @@ def speaker_to_target( speaker_to_idx_map = {spk: idx for idx, spk in enumerate(speaker_ats)} if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers raise ValueError( - f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers {num_speakers}" + f"Number of speakers {len(speaker_to_idx_map)} is larger than " + f"the maximum number of speakers {num_speakers}" ) # initialize mask matrices (num_speaker, encoder_hidden_len) diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index a4728c29ff06..13f9efe48a06 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -310,10 +310,9 @@ def __init__( class InstructionTuningAudioText(_Collection): """`AudioText` collector from asr structured json files.""" - OUTPUT_TYPE = collections.namedtuple( - typename='InstructionTuningText', - field_names='id context context_type context_duration question question_type answer answer_type answer_duration speaker', - ) + OUTPUT_TYPE = collections.namedtuple(typename='InstructionTuningText', + field_names=('id context context_type context_duration question ' + 'question_type answer answer_type answer_duration speaker'),) def __init__( self, @@ -559,7 +558,6 @@ def __init__( ): """Instantiates audio-context-answer manifest with filters and preprocessing. - Args: ids: List of examples positions. audio_files: List of audio files. @@ -1471,7 +1469,7 @@ def __init__( class EndtoEndDiarizationSpeechLabel(EndtoEndDiarizationLabel): - """`DiarizationLabel` diarization data sample collector from structured json files.""" + """End-to-end speaker diarization data sample collector from structured json files.""" def __init__( self, From 6e2225ef65552b6c6c557779bf3ee6b7b3f1f340 Mon Sep 17 00:00:00 2001 From: tango4j Date: Fri, 15 Nov 2024 22:50:57 +0000 Subject: [PATCH 13/47] Apply isort and black reformatting Signed-off-by: tango4j --- .../neural_diarizer/e2e_diarize_speech.py | 4 ++- .../asr/data/audio_to_diar_label_lhotse.py | 3 +-- .../asr/models/sortformer_diar_models.py | 2 +- .../asr/modules/sortformer_modules.py | 8 +++--- .../asr/parts/utils/asr_multispeaker_utils.py | 26 +++++++++---------- .../common/parts/preprocessing/collections.py | 10 ++++--- 6 files changed, 29 insertions(+), 24 deletions(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index cb09b7df3100..65ba0226988a 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -48,9 +48,10 @@ class PostProcessingParams: """ Postprocessing parameters for end-to-end speaker diarization models. These parameters can significantly affect DER performance depending on the evaluation style and the dataset. - It is recommended to tune these parameters based on the evaluation style and the dataset + It is recommended to tune these parameters based on the evaluation style and the dataset to achieve the desired DER performance. """ + onset: float = 0.5 # Onset threshold for detecting the beginning and end of a speech offset: float = 0.5 # Offset threshold for detecting the end of a speech pad_onset: float = 0.0 # Adding durations before each speech segment @@ -62,6 +63,7 @@ class PostProcessingParams: @dataclass class DiarizationConfig: """Diarization configuration parameters for inference.""" + model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model audio_dir: Optional[str] = None # Path to a directory which contains audio files diff --git a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py index 14723a398fe8..0839b63954f0 100644 --- a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py +++ b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py @@ -37,8 +37,7 @@ class LhotseAudioToSpeechE2ESpkDiarDataset(torch.utils.data.Dataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Define the output types of the dataset. - """ + """Define the output types of the dataset.""" return { 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), 'a_sig_length': NeuralType(tuple('B'), LengthsType()), diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index b54bfcc0d05c..5a3c8e354f1b 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -172,7 +172,7 @@ def __setup_dataloader_from_config(self, config): soft_targets=config.soft_targets if 'soft_targets' in config else False, ) logging.info( - f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader" + f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader" f"step B: {time.time() - time_flag}" ) diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py index 36b7438c9a92..193dae29c304 100644 --- a/nemo/collections/asr/modules/sortformer_modules.py +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -43,10 +43,10 @@ def __init__( fc_d_model: int = 512, tf_d_model: int = 192, ): - """ + """ Args: num_spks (int): - Max number of speakers that are processed by the model. + Max number of speakers that are processed by the model. hidden_size (int): Number of hidden units in sequence models and intermediate layers. dropout_rate (float): @@ -54,7 +54,7 @@ def __init__( fc_d_model (int): Dimension of the embedding vectors. tf_d_model (int): - Dimension of the embedding vectors. + Dimension of the embedding vectors. """ super().__init__() self.fc_d_model = fc_d_model @@ -93,7 +93,7 @@ def length_to_mask(self, context_embs): def forward_speaker_sigmoids(self, hidden_out): """ A set of layers for predicting speaker probabilities with a sigmoid activation function. - + Args: hidden_out (torch.Tensor): tensor of shape (batch_size, seq_len, hidden_size) diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index 46bcf2f1a8c6..eddfd3254adc 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -28,7 +28,7 @@ def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres: float = 0.5) -> thres (float): The threshold value for discretizing the matrix values. Returns: - mask_max_indices (Tensor): A torch tensor representing the discretized matrix with the first + mask_max_indices (Tensor): A torch tensor representing the discretized matrix with the first nonzero value in each row. """ # Discretize the matrix to the specified maximum capacity @@ -230,7 +230,7 @@ def get_mask_from_segments( speaker_to_idx_map (dict): A dictionary mapping speaker names to indices. num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default feat_per_sec (int): number of frames per second, 100 by default, 0.01s frame rate - ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. + ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. Returns: @@ -318,27 +318,27 @@ def speaker_to_target( soft_thres: float = 0.5, ): """ - Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape + Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape (num_speaker, hidden_length). This function is needed for speaker diarization with ASR model trainings. Args: - a_cut (MonoCut, MixedCut): + a_cut (MonoCut, MixedCut): Lhotse Cut instance which is MonoCut or MixedCut instance. - num_speakers (int): + num_speakers (int): Max number of speakers for all cuts ("mask" dim0), 4 by default - num_sample_per_mel_frame (int): + num_sample_per_mel_frame (int): Number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) - num_mel_frame_per_asr_frame (int): + num_mel_frame_per_asr_frame (int): Encoder subsampling_factor, 8 by default - spk_tar_all_zero (Tensor): + spk_tar_all_zero (Tensor): Set to True gives all zero "mask" - boundary_segments (bool): - Set to True to include segments containing the boundary of the cut, + boundary_segments (bool): + Set to True to include segments containing the boundary of the cut, False by default for multi-speaker ASR training - soft_label (bool): - Set to True to use soft label that enables values in [0, 1] range, + soft_label (bool): + Set to True to use soft label that enables values in [0, 1] range, False by default and leads to binary labels. - ignore_num_spk_mismatch (bool): + ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. Returns: diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index b3d96f17ce8f..5773ddf4b79b 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -310,9 +310,13 @@ def __init__( class InstructionTuningAudioText(_Collection): """`AudioText` collector from asr structured json files.""" - OUTPUT_TYPE = collections.namedtuple(typename='InstructionTuningText', - field_names=('id context context_type context_duration question ' - 'question_type answer answer_type answer_duration speaker'),) + OUTPUT_TYPE = collections.namedtuple( + typename='InstructionTuningText', + field_names=( + 'id context context_type context_duration question ' + 'question_type answer answer_type answer_duration speaker' + ), + ) def __init__( self, From ab93b176a3a21ded11dacbd9da10f720c170e063 Mon Sep 17 00:00:00 2001 From: taejinp Date: Fri, 15 Nov 2024 14:57:15 -0800 Subject: [PATCH 14/47] Removing unused varialbe in audio_to_diar_label.py Signed-off-by: taejinp --- nemo/collections/asr/data/audio_to_diar_label.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index 7217686d3168..568708dc8c7a 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -687,7 +687,7 @@ def parse_rttm_multiscale(self, sample): seg_target = self.get_diar_target_labels_from_fr_target(uniq_id, fr_level_target) return seg_target - def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target): + def get_diar_target_labels_from_fr_target(self, uniq_id: str, fr_level_target: torch.Tensor) -> torch.Tensor: """ Generate base-scale level binary diarization label from frame-level target matrix. For the given frame-level speaker target matrix fr_level_target, we count the number of frames that belong to each speaker and calculate @@ -1222,7 +1222,6 @@ def __getitem__(self, index): else: session_len_sec = min(sample.duration, self.session_len_sec) - uniq_id = self.get_uniq_id_with_range(sample) audio_signal = self.featurizer.process(sample.audio_file, offset=offset, duration=session_len_sec) # We should resolve the length mis-match from the round-off errors between these two variables: From 7dea01b4b14f83934717cb5504942d82ea35b169 Mon Sep 17 00:00:00 2001 From: taejinp Date: Fri, 15 Nov 2024 18:13:58 -0800 Subject: [PATCH 15/47] Fixed docstrings in training script Signed-off-by: taejinp --- .../diarization/neural_diarizer/sortformer_diar_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py index 1b4376c4f2c4..78c7acbaa6c2 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py @@ -24,7 +24,7 @@ """ Example training session (single node training) -python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' --config-name='' \ +python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' --config-name='sortformer_diarizer_hybrid_loss_4spk-v1.yaml' \ trainer.devices=1 \ model.train_ds.manifest_filepath="" \ model.validation_ds.manifest_filepath="" \ From 71d515faa608870f7da8be2297a4ae7b78d99f4d Mon Sep 17 00:00:00 2001 From: taejinp Date: Fri, 15 Nov 2024 18:18:21 -0800 Subject: [PATCH 16/47] Line-too-long issue from Pylint fixed Signed-off-by: taejinp --- .../diarization/neural_diarizer/sortformer_diar_train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py index 78c7acbaa6c2..8719b6463f70 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py @@ -24,7 +24,8 @@ """ Example training session (single node training) -python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' --config-name='sortformer_diarizer_hybrid_loss_4spk-v1.yaml' \ +python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' \ + --config-name='sortformer_diarizer_hybrid_loss_4spk-v1.yaml' \ trainer.devices=1 \ model.train_ds.manifest_filepath="" \ model.validation_ds.manifest_filepath="" \ From f2d5e36f2e25233cb5c4104eaa77619134990ce4 Mon Sep 17 00:00:00 2001 From: taejinp Date: Mon, 18 Nov 2024 18:28:28 -0800 Subject: [PATCH 17/47] Adding get_subsegments_scriptable to prevent jit.script error Signed-off-by: taejinp --- nemo/collections/asr/losses/__init__.py | 1 + nemo/collections/asr/losses/bce_loss.py | 80 ++++++++++++++++--- .../asr/parts/utils/speaker_utils.py | 45 +++++++++-- tests/collections/asr/test_diar_utils.py | 4 +- 4 files changed, 110 insertions(+), 20 deletions(-) diff --git a/nemo/collections/asr/losses/__init__.py b/nemo/collections/asr/losses/__init__.py index 756a071178d7..31bd0bae5d40 100644 --- a/nemo/collections/asr/losses/__init__.py +++ b/nemo/collections/asr/losses/__init__.py @@ -19,3 +19,4 @@ from nemo.collections.asr.losses.ssl_losses.ctc import CTCLossForSSL from nemo.collections.asr.losses.ssl_losses.mlm import MLMLoss, MultiMLMLoss from nemo.collections.asr.losses.ssl_losses.rnnt import RNNTLossForSSL +from nemo.collections.asr.losses.bce_loss import BCELoss \ No newline at end of file diff --git a/nemo/collections/asr/losses/bce_loss.py b/nemo/collections/asr/losses/bce_loss.py index 30e31b8610ec..89f292ddae60 100644 --- a/nemo/collections/asr/losses/bce_loss.py +++ b/nemo/collections/asr/losses/bce_loss.py @@ -33,7 +33,7 @@ def input_types(self): return { "probs": NeuralType(('B', 'T', 'C'), ProbsType()), 'labels': NeuralType(('B', 'T', 'C'), LabelsType()), - "signal_lengths": NeuralType(tuple('B'), LengthsType()), + "target_lens": NeuralType(('B', 'C'), LengthsType()), } @property @@ -43,31 +43,91 @@ def output_types(self): """ return {"loss": NeuralType(elements_type=LossType())} - def __init__(self, reduction='sum', alpha=1.0, weight=torch.tensor([0.5, 0.5])): + def __init__( + self, + reduction: str = 'mean', + alpha: float = 1.0, + weight: torch.Tensor = torch.tensor([0.1, 0.9]), + sorted_preds: bool = False, + sorted_loss: bool = False, + class_normalization: bool = False, + ): + """ + A custom loss function that supports class normalization, + weighted binary cross-entropy, and optional sorting. + + Args: + reduction (str): Specifies the reduction to apply to the output, + options are 'mean', 'sum', or 'none'. Default is 'mean'. + alpha (float): Scaling factor for loss (unused in this implementation). Default is 1.0. + weight (torch.Tensor): Class weights for the binary cross-entropy loss. Default is [0.1, 0.9]. + sorted_preds (bool): If True, assumes predictions are sorted. Default is False. + sorted_loss (bool): If True, sorts the loss before reduction. Default is False. + class_normalization (bool): If True, uses 'none' reduction for per-class loss. Default is False. + """ super().__init__() - self.reduction = reduction + self.class_normalization = class_normalization + if class_normalization: + self.reduction = 'none' + else: + self.reduction = 'mean' self.loss_weight = weight - self.loss_f = torch.nn.BCELoss(weight=self.loss_weight, reduction=self.reduction) + self.loss_f = torch.nn.BCELoss(reduction=self.reduction) + self.sorted_preds = sorted_preds + self.sorted_loss = sorted_loss + self.eps = 1e-6 @typecheck() - def forward(self, probs, labels, signal_lengths): + def forward(self, probs, labels, target_lens): """ - Calculate binary cross entropy loss based on probs, labels and signal_lengths variables. + Calculate binary cross entropy loss based on probs, labels and target_lens variables. Args: probs (torch.tensor) Predicted probability value which ranges from 0 to 1. Sigmoid output is expected. labels (torch.tensor) Groundtruth label for the predicted samples. - signal_lengths (torch.tensor): + target_lens (torch.tensor): The actual length of the sequence without zero-padding. Returns: loss (NeuralType) Binary cross entropy loss value. """ - probs_list = [probs[k, : signal_lengths[k], :] for k in range(probs.shape[0])] - targets_list = [labels[k, : signal_lengths[k], :] for k in range(labels.shape[0])] + probs_list = [probs[k, : target_lens[k], :] for k in range(probs.shape[0])] + targets_list = [labels[k, : target_lens[k], :] for k in range(labels.shape[0])] probs = torch.cat(probs_list, dim=0) labels = torch.cat(targets_list, dim=0) - return self.loss_f(probs, labels) + if self.class_normalization in ['class', 'class_binary', 'binary']: + if self.class_normalization in ['class', 'class_binary']: + # Normalize loss by number of classes + norm_weight = 1/(labels.sum(dim=0) + self.eps) + norm_weight_norm = norm_weight / norm_weight.sum() + norm_weight_norm2 = torch.clamp(norm_weight_norm, min=0.05, max=1.0) + norm_weight_norm2 = norm_weight_norm2 / norm_weight_norm2.max() + norm_weight = norm_weight_norm2[None, :].expand_as(labels).detach().clone() + else: + norm_weight = torch.ones_like(labels).detach().clone() + + if self.class_normalization in ['binary', 'class_binary']: + binary_weight = torch.ones_like(labels).detach().clone() + one_weight = (labels.sum() / (labels.shape[0]*labels.shape[1])).to(labels.device) + binary_weight[labels == 0] = one_weight + binary_weight[labels == 1] = 1 - one_weight + else: + binary_weight = torch.ones_like(labels).detach().clone() + + elif self.class_normalization == 'none' or not self.class_normalization: + binary_weight = torch.ones_like(labels).detach().clone() + norm_weight = torch.ones_like(labels).detach().clone() + + if self.reduction == 'sum': + return self.loss_f(probs, labels) + elif self.reduction == 'mean': + return self.loss_f(probs, labels).mean() + elif self.reduction == 'none': + if self.class_normalization in ['class', 'class_binary', 'binary']: + return (binary_weight * norm_weight * self.loss_f(probs, labels)).sum() + else: + return self.loss_f(probs, labels) + diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 15ec8a24a3bd..3e10ff564bae 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -19,6 +19,7 @@ from copy import deepcopy from typing import Dict, List, Tuple, Union +import math import numpy as np import soundfile as sf import torch @@ -915,7 +916,7 @@ def segments_manifest_to_subsegments_manifest( segment = segment.strip() dic = json.loads(segment) audio, offset, duration, label = dic['audio_filepath'], dic['offset'], dic['duration'], dic['label'] - subsegments = get_subsegments(offset=offset, window=window, shift=shift, duration=duration) + subsegments = get_subsegments_scriptable(offset=offset, window=window, shift=shift, duration=duration) if include_uniq_id and 'uniq_id' in dic: uniq_id = dic['uniq_id'] else: @@ -974,14 +975,14 @@ def get_subsegments( subsegments: List[List[float]] = [] start = offset slice_end = start + duration - if min_subsegment_duration <= duration < shift: + if min_subsegment_duration <= duration <= shift: slices = 1 - elif use_asr_style_frame_count is True: - num_feat_frames = np.ceil((1 + duration * sample_rate) / int(sample_rate / feat_per_sec)).astype(int) - slices = np.ceil(num_feat_frames / int(feat_per_sec * shift)).astype(int) + elif use_asr_style_frame_count is True: + num_feat_frames = np.ceil((1+duration*sample_rate)/int(sample_rate/feat_per_sec)).astype(int) + slices = np.ceil(num_feat_frames/int(feat_per_sec*shift)).astype(int) slice_end = start + shift * slices else: - slices = np.ceil(1 + (duration - window) / shift).astype(int) + slices = np.ceil(1+ (duration-window)/shift).astype(int) if slices == 1: if min(duration, window) >= min_subsegment_duration: subsegments.append([start, min(duration, window)]) @@ -996,6 +997,34 @@ def get_subsegments( subsegments = valid_subsegments.tolist() return subsegments +def get_subsegments_scriptable(offset: float, window: float, shift: float, duration: float) -> List[List[float]]: + """ + Return subsegments from a segment of audio file. + This function is inefficient since the segmentation is based on for-loop, + but this implementation makes this function torch-jit-scriptable. + + Args: + offset (float): start time of audio segment + window (float): window length for segments to subsegments length + shift (float): hop length for subsegments shift + duration (float): duration of segment + Returns: + subsegments (List[tuple[float, float]]): subsegments generated for the segments + as list of tuple of start and duration of each subsegment + """ + subsegments: List[List[float]] = [] + start = offset + slice_end = start + duration + base = math.ceil((duration - window) / shift) + slices = 1 if base < 0 else base + 1 + for slice_id in range(slices): + end = start + window + if end > slice_end: + end = slice_end + subsegments.append([start, end - start]) + start = offset + (slice_id + 1) * shift + return subsegments + def get_target_sig( sig, @@ -1307,7 +1336,7 @@ def get_online_subsegments_from_buffer( range_offs = [float(range_spl[0].item() - buffer_start), float(range_spl[1].item() - buffer_start)] range_t = [max(0, range_offs[0]), range_offs[1]] - subsegments = get_subsegments( + subsegments = get_subsegments_scriptable( offset=range_t[0], window=window, shift=shift, @@ -1811,7 +1840,7 @@ def run_online_segmentation( segment_indexes: List[int], window: float, shift: float, - ): + )-> Tuple[List[torch.Tensor], List[List[float]], List[int]]: """ Remove the old segments that overlap with the new frame (self.frame_start) cursor_for_old_segments is pointing at the onset of the t_range popped most recently. diff --git a/tests/collections/asr/test_diar_utils.py b/tests/collections/asr/test_diar_utils.py index f48292d27981..a72313923a66 100644 --- a/tests/collections/asr/test_diar_utils.py +++ b/tests/collections/asr/test_diar_utils.py @@ -48,7 +48,7 @@ get_online_subsegments_from_buffer, get_speech_labels_for_update, get_sub_range_list, - get_subsegments, + get_subsegments_scriptable, get_target_sig, int2fl, is_overlap, @@ -110,7 +110,7 @@ def generate_toy_data( random_orthogonal_embs = generate_orthogonal_embs(n_spks, perturb_sigma, emb_dim) for scale_idx, (window, shift) in enumerate(zip(ms_window, ms_shift)): for spk_idx, (offset, dur) in enumerate(spk_timestamps): - segments_stt_dur = get_subsegments(offset=offset, window=window, shift=shift, duration=dur) + segments_stt_dur = get_subsegments_scriptable(offset=offset, window=window, shift=shift, duration=dur) segments = [[x[0], x[0] + x[1]] for x in segments_stt_dur] emb_cent = random_orthogonal_embs[spk_idx, :] emb = emb_cent.tile((len(segments), 1)) + 0.1 * torch.rand(len(segments), emb_dim) From 9cca3e892117ff4e8bbbac7896c73aa7d8190ba8 Mon Sep 17 00:00:00 2001 From: tango4j Date: Tue, 19 Nov 2024 02:29:32 +0000 Subject: [PATCH 18/47] Apply isort and black reformatting Signed-off-by: tango4j --- nemo/collections/asr/losses/__init__.py | 2 +- nemo/collections/asr/losses/bce_loss.py | 20 ++--- .../asr/parts/utils/speaker_utils.py | 19 +++-- tests/collections/asr/test_diar_utils.py | 84 +++++++++++++++---- 4 files changed, 86 insertions(+), 39 deletions(-) diff --git a/nemo/collections/asr/losses/__init__.py b/nemo/collections/asr/losses/__init__.py index 31bd0bae5d40..f88bd49d1f7b 100644 --- a/nemo/collections/asr/losses/__init__.py +++ b/nemo/collections/asr/losses/__init__.py @@ -13,10 +13,10 @@ # limitations under the License. from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss +from nemo.collections.asr.losses.bce_loss import BCELoss from nemo.collections.asr.losses.ctc import CTCLoss from nemo.collections.asr.losses.lattice_losses import LatticeLoss from nemo.collections.asr.losses.ssl_losses.contrastive import ContrastiveLoss from nemo.collections.asr.losses.ssl_losses.ctc import CTCLossForSSL from nemo.collections.asr.losses.ssl_losses.mlm import MLMLoss, MultiMLMLoss from nemo.collections.asr.losses.ssl_losses.rnnt import RNNTLossForSSL -from nemo.collections.asr.losses.bce_loss import BCELoss \ No newline at end of file diff --git a/nemo/collections/asr/losses/bce_loss.py b/nemo/collections/asr/losses/bce_loss.py index 89f292ddae60..4299f1422891 100644 --- a/nemo/collections/asr/losses/bce_loss.py +++ b/nemo/collections/asr/losses/bce_loss.py @@ -28,8 +28,7 @@ class BCELoss(Loss, Typing): @property def input_types(self): - """Input types definitions for AnguarLoss. - """ + """Input types definitions for AnguarLoss.""" return { "probs": NeuralType(('B', 'T', 'C'), ProbsType()), 'labels': NeuralType(('B', 'T', 'C'), LabelsType()), @@ -57,7 +56,7 @@ def __init__( weighted binary cross-entropy, and optional sorting. Args: - reduction (str): Specifies the reduction to apply to the output, + reduction (str): Specifies the reduction to apply to the output, options are 'mean', 'sum', or 'none'. Default is 'mean'. alpha (float): Scaling factor for loss (unused in this implementation). Default is 1.0. weight (torch.Tensor): Class weights for the binary cross-entropy loss. Default is [0.1, 0.9]. @@ -101,26 +100,26 @@ def forward(self, probs, labels, target_lens): if self.class_normalization in ['class', 'class_binary', 'binary']: if self.class_normalization in ['class', 'class_binary']: # Normalize loss by number of classes - norm_weight = 1/(labels.sum(dim=0) + self.eps) + norm_weight = 1 / (labels.sum(dim=0) + self.eps) norm_weight_norm = norm_weight / norm_weight.sum() - norm_weight_norm2 = torch.clamp(norm_weight_norm, min=0.05, max=1.0) + norm_weight_norm2 = torch.clamp(norm_weight_norm, min=0.05, max=1.0) norm_weight_norm2 = norm_weight_norm2 / norm_weight_norm2.max() norm_weight = norm_weight_norm2[None, :].expand_as(labels).detach().clone() - else: + else: norm_weight = torch.ones_like(labels).detach().clone() if self.class_normalization in ['binary', 'class_binary']: binary_weight = torch.ones_like(labels).detach().clone() - one_weight = (labels.sum() / (labels.shape[0]*labels.shape[1])).to(labels.device) + one_weight = (labels.sum() / (labels.shape[0] * labels.shape[1])).to(labels.device) binary_weight[labels == 0] = one_weight binary_weight[labels == 1] = 1 - one_weight else: binary_weight = torch.ones_like(labels).detach().clone() - + elif self.class_normalization == 'none' or not self.class_normalization: - binary_weight = torch.ones_like(labels).detach().clone() + binary_weight = torch.ones_like(labels).detach().clone() norm_weight = torch.ones_like(labels).detach().clone() - + if self.reduction == 'sum': return self.loss_f(probs, labels) elif self.reduction == 'mean': @@ -130,4 +129,3 @@ def forward(self, probs, labels, target_lens): return (binary_weight * norm_weight * self.loss_f(probs, labels)).sum() else: return self.loss_f(probs, labels) - diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 3e10ff564bae..4df4d0f09e04 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -14,12 +14,12 @@ import gc import json +import math import os import shutil from copy import deepcopy from typing import Dict, List, Tuple, Union -import math import numpy as np import soundfile as sf import torch @@ -977,12 +977,12 @@ def get_subsegments( slice_end = start + duration if min_subsegment_duration <= duration <= shift: slices = 1 - elif use_asr_style_frame_count is True: - num_feat_frames = np.ceil((1+duration*sample_rate)/int(sample_rate/feat_per_sec)).astype(int) - slices = np.ceil(num_feat_frames/int(feat_per_sec*shift)).astype(int) + elif use_asr_style_frame_count is True: + num_feat_frames = np.ceil((1 + duration * sample_rate) / int(sample_rate / feat_per_sec)).astype(int) + slices = np.ceil(num_feat_frames / int(feat_per_sec * shift)).astype(int) slice_end = start + shift * slices else: - slices = np.ceil(1+ (duration-window)/shift).astype(int) + slices = np.ceil(1 + (duration - window) / shift).astype(int) if slices == 1: if min(duration, window) >= min_subsegment_duration: subsegments.append([start, min(duration, window)]) @@ -997,19 +997,20 @@ def get_subsegments( subsegments = valid_subsegments.tolist() return subsegments + def get_subsegments_scriptable(offset: float, window: float, shift: float, duration: float) -> List[List[float]]: """ Return subsegments from a segment of audio file. - This function is inefficient since the segmentation is based on for-loop, + This function is inefficient since the segmentation is based on for-loop, but this implementation makes this function torch-jit-scriptable. - + Args: offset (float): start time of audio segment window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift duration (float): duration of segment Returns: - subsegments (List[tuple[float, float]]): subsegments generated for the segments + subsegments (List[tuple[float, float]]): subsegments generated for the segments as list of tuple of start and duration of each subsegment """ subsegments: List[List[float]] = [] @@ -1840,7 +1841,7 @@ def run_online_segmentation( segment_indexes: List[int], window: float, shift: float, - )-> Tuple[List[torch.Tensor], List[List[float]], List[int]]: + ) -> Tuple[List[torch.Tensor], List[List[float]], List[int]]: """ Remove the old segments that overlap with the new frame (self.frame_start) cursor_for_old_segments is pointing at the onset of the t_range popped most recently. diff --git a/tests/collections/asr/test_diar_utils.py b/tests/collections/asr/test_diar_utils.py index a72313923a66..cb364675fcf4 100644 --- a/tests/collections/asr/test_diar_utils.py +++ b/tests/collections/asr/test_diar_utils.py @@ -82,8 +82,7 @@ def matrix(mat, use_tensor=True, dtype=torch.long): def generate_orthogonal_embs(total_spks, perturb_sigma, emb_dim): - """Generate a set of artificial orthogonal embedding vectors from random numbers - """ + """Generate a set of artificial orthogonal embedding vectors from random numbers""" gaus = torch.randn(emb_dim, emb_dim) _svd = torch.linalg.svd(gaus) orth = _svd[0] @ _svd[2] @@ -130,8 +129,7 @@ def generate_toy_data( class TestDiarizationSequneceUtilFunctions: - """Tests diarization and speaker-task related utils. - """ + """Tests diarization and speaker-task related utils.""" @pytest.mark.unit @pytest.mark.parametrize("Y", [[3, 3, 3, 4, 4, 5], [100, 100, 100, 104, 104, 1005]]) @@ -278,7 +276,10 @@ def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity): em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=10) em_s, ts_s = split_input_data(em, ts, mc) merged_embs, merged_clus_labels, _ = run_reducer( - pre_embs=em_s[-1], target_spk_idx=target_speaker_index, merge_quantity=merge_quantity, pre_clus_labels=gt, + pre_embs=em_s[-1], + target_spk_idx=target_speaker_index, + merge_quantity=merge_quantity, + pre_clus_labels=gt, ) assert (torch.sum(gt == target_speaker_index).item() - merge_quantity) == merged_clus_labels.shape[0] @@ -287,7 +288,11 @@ def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 70 + [1] * 32)]) @pytest.mark.parametrize("mspb", [25]) def test_merge_scheduler_2clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([3, 0])) @pytest.mark.unit @@ -295,7 +300,11 @@ def test_merge_scheduler_2clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 80 + [1] * 35 + [2] * 32)]) @pytest.mark.parametrize("mspb", [0, 25]) def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([3, 0, 0])) @pytest.mark.unit @@ -303,7 +312,11 @@ def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([2] * 70 + [0] * 32 + [1] * 27 + [3] * 3)]) @pytest.mark.parametrize("mspb", [3, 10]) def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([18, 13, 56, 0])) @pytest.mark.unit @@ -311,7 +324,11 @@ def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 5 + [1] * 4 + [2] * 3)]) @pytest.mark.parametrize("mspb", [0, 2]) def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([2, 1, 0])) @pytest.mark.unit @@ -319,7 +336,11 @@ def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 7 + [1] * 5 + [2] * 3 + [3] * 5)]) @pytest.mark.parametrize("mspb", [2]) def test_merge_scheduler_3clus_repeat(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([2, 0, 0, 0])) @@ -414,13 +435,21 @@ def test_is_overlap_false(self, rangeA, rangeB): @pytest.mark.parametrize("x", [1.0, 2.3456]) @pytest.mark.parametrize("decimals", [1, 2, 3, 4]) def test_fl2int(self, x, decimals): - assert fl2int(x, decimals) == round(x * 10 ** decimals, 0) + assert fl2int(x, decimals) == round(x * 10**decimals, 0) @pytest.mark.unit @pytest.mark.parametrize("x", [1234]) - @pytest.mark.parametrize("decimals", [1, 2, 3, 4,]) + @pytest.mark.parametrize( + "decimals", + [ + 1, + 2, + 3, + 4, + ], + ) def test_int2fl(self, x, decimals): - assert abs(int2fl(x, decimals) - round(x / (10 ** decimals), decimals)) < (10 ** -(decimals + 1)) + assert abs(int2fl(x, decimals) - round(x / (10**decimals), decimals)) < (10 ** -(decimals + 1)) @pytest.mark.unit def test_merge_float_intervals_edge_margin_test(self): @@ -462,7 +491,11 @@ def test_get_speech_labels_for_update(self): vad_timestamps = torch.tensor([[0.9600, 4.8400]]) cursor_for_old_segments = 1.0 speech_labels_for_update, cumulative_speech_labels = get_speech_labels_for_update( - frame_start, buffer_end, cumulative_speech_labels, vad_timestamps, cursor_for_old_segments, + frame_start, + buffer_end, + cumulative_speech_labels, + vad_timestamps, + cursor_for_old_segments, ) assert (speech_labels_for_update - torch.tensor([[1.0000, 3.7600]])).sum() < 1e-8 assert (cumulative_speech_labels - torch.tensor([[0.9600, 4.8400]])).sum() < 1e-8 @@ -532,7 +565,10 @@ def test_tensor_to_list(self, source_range_list): @pytest.mark.unit @pytest.mark.parametrize( "buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate", - [(0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000), (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000),], + [ + (0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000), + (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000), + ], ) def test_get_online_segments_from_slices( self, buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate @@ -665,7 +701,13 @@ def test_offline_speaker_clustering_cpu(self, n_spks, total_sec, SSV, perturb_si @pytest.mark.parametrize("SSV, enhanced_count_thres, min_samples_for_nmesc", [(5, 40, 6)]) @pytest.mark.parametrize("seed", [0]) def test_offline_speaker_clustering_very_short_cpu( - self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed, + self, + n_spks, + spk_dur, + SSV, + enhanced_count_thres, + min_samples_for_nmesc, + seed, ): em, ts, mc, mw, spk_ts, gt = generate_toy_data( n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed @@ -697,7 +739,13 @@ def test_offline_speaker_clustering_very_short_cpu( @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(1, 5, 40, 6)]) @pytest.mark.parametrize("seed", [0]) def test_offline_speaker_clustering_very_short_gpu( - self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed, + self, + n_spks, + spk_dur, + SSV, + enhanced_count_thres, + min_samples_for_nmesc, + seed, ): em, ts, mc, mw, spk_ts, gt = generate_toy_data( n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed @@ -908,7 +956,7 @@ def test_linear_sum_assignment_algorithm_cost_matrix(self, cost_matrix): Test the linear sum assignment algorithm with a cost matrix Compare with the scipy implementation and make sure the final cost is the same. - NOTE: There could be multiple solutions with the same cost in linear sum assignment problem. + NOTE: There could be multiple solutions with the same cost in linear sum assignment problem. This test only checks if the cost is the same. """ row_ind_nm, col_ind_nm = nemo_linear_sum_assignment(cost_matrix) From 008dcbd97768d8b691740c74f81ae81e1abd8ac0 Mon Sep 17 00:00:00 2001 From: taejinp Date: Tue, 19 Nov 2024 15:33:41 -0800 Subject: [PATCH 19/47] Addressed Code-QL issues Signed-off-by: taejinp --- nemo/collections/asr/losses/bce_loss.py | 11 +++++++---- nemo/collections/asr/parts/utils/speaker_utils.py | 10 ++++++---- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/nemo/collections/asr/losses/bce_loss.py b/nemo/collections/asr/losses/bce_loss.py index 89f292ddae60..f5ffb24d32a2 100644 --- a/nemo/collections/asr/losses/bce_loss.py +++ b/nemo/collections/asr/losses/bce_loss.py @@ -98,6 +98,8 @@ def forward(self, probs, labels, target_lens): targets_list = [labels[k, : target_lens[k], :] for k in range(labels.shape[0])] probs = torch.cat(probs_list, dim=0) labels = torch.cat(targets_list, dim=0) + norm_weight = torch.zeros_like(labels).detach().clone() + if self.class_normalization in ['class', 'class_binary', 'binary']: if self.class_normalization in ['class', 'class_binary']: # Normalize loss by number of classes @@ -122,12 +124,13 @@ def forward(self, probs, labels, target_lens): norm_weight = torch.ones_like(labels).detach().clone() if self.reduction == 'sum': - return self.loss_f(probs, labels) + loss = self.loss_f(probs, labels) elif self.reduction == 'mean': - return self.loss_f(probs, labels).mean() + loss = self.loss_f(probs, labels).mean() elif self.reduction == 'none': if self.class_normalization in ['class', 'class_binary', 'binary']: - return (binary_weight * norm_weight * self.loss_f(probs, labels)).sum() + loss = (binary_weight * norm_weight * self.loss_f(probs, labels)).sum() else: - return self.loss_f(probs, labels) + loss = self.loss_f(probs, labels) + return loss diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 3e10ff564bae..80453ecc6e09 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -999,9 +999,10 @@ def get_subsegments( def get_subsegments_scriptable(offset: float, window: float, shift: float, duration: float) -> List[List[float]]: """ - Return subsegments from a segment of audio file. - This function is inefficient since the segmentation is based on for-loop, - but this implementation makes this function torch-jit-scriptable. + This function returns subsegments from a segment of an audio file. + Although this implementation is inefficient due to the use of a for-loop for segmentation, + it is designed to be torch-jit-scriptable. + Use `get_subsegments` for a more efficient implementation. Args: offset (float): start time of audio segment @@ -1010,7 +1011,8 @@ def get_subsegments_scriptable(offset: float, window: float, shift: float, durat duration (float): duration of segment Returns: subsegments (List[tuple[float, float]]): subsegments generated for the segments - as list of tuple of start and duration of each subsegment + as list of tuple of start and duration of + each subsegment """ subsegments: List[List[float]] = [] start = offset From 045f3a2b791848479218e46c3f805445945ffcf2 Mon Sep 17 00:00:00 2001 From: taejinp Date: Tue, 19 Nov 2024 15:44:41 -0800 Subject: [PATCH 20/47] Resolved conflicts on bce_loss.py Signed-off-by: taejinp --- nemo/collections/asr/losses/bce_loss.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/nemo/collections/asr/losses/bce_loss.py b/nemo/collections/asr/losses/bce_loss.py index bca75b3f06b3..21190742e0b7 100644 --- a/nemo/collections/asr/losses/bce_loss.py +++ b/nemo/collections/asr/losses/bce_loss.py @@ -130,10 +130,6 @@ def forward(self, probs, labels, target_lens): if self.class_normalization in ['class', 'class_binary', 'binary']: loss = (binary_weight * norm_weight * self.loss_f(probs, labels)).sum() else: -<<<<<<< HEAD loss = self.loss_f(probs, labels) return loss -======= - return self.loss_f(probs, labels) ->>>>>>> 681fe3881c7029e104788ef621d020b8f94bd410 From 1dcf9ab40b02fea582790ffcdebf6585a8a257e4 Mon Sep 17 00:00:00 2001 From: tango4j Date: Tue, 19 Nov 2024 23:45:47 +0000 Subject: [PATCH 21/47] Apply isort and black reformatting Signed-off-by: tango4j --- nemo/collections/asr/losses/bce_loss.py | 3 +- .../asr/parts/utils/speaker_utils.py | 52 +++++++++---------- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/nemo/collections/asr/losses/bce_loss.py b/nemo/collections/asr/losses/bce_loss.py index 21190742e0b7..d2aa03319007 100644 --- a/nemo/collections/asr/losses/bce_loss.py +++ b/nemo/collections/asr/losses/bce_loss.py @@ -98,7 +98,7 @@ def forward(self, probs, labels, target_lens): probs = torch.cat(probs_list, dim=0) labels = torch.cat(targets_list, dim=0) norm_weight = torch.zeros_like(labels).detach().clone() - + if self.class_normalization in ['class', 'class_binary', 'binary']: if self.class_normalization in ['class', 'class_binary']: # Normalize loss by number of classes @@ -132,4 +132,3 @@ def forward(self, probs, labels, target_lens): else: loss = self.loss_f(probs, labels) return loss - diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 57eee5d87d63..14a336a97479 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -1000,32 +1000,32 @@ def get_subsegments( def get_subsegments_scriptable(offset: float, window: float, shift: float, duration: float) -> List[List[float]]: """ -<<<<<<< HEAD - This function returns subsegments from a segment of an audio file. - Although this implementation is inefficient due to the use of a for-loop for segmentation, - it is designed to be torch-jit-scriptable. - Use `get_subsegments` for a more efficient implementation. - -======= - Return subsegments from a segment of audio file. - This function is inefficient since the segmentation is based on for-loop, - but this implementation makes this function torch-jit-scriptable. - ->>>>>>> 681fe3881c7029e104788ef621d020b8f94bd410 - Args: - offset (float): start time of audio segment - window (float): window length for segments to subsegments length - shift (float): hop length for subsegments shift - duration (float): duration of segment - Returns: -<<<<<<< HEAD - subsegments (List[tuple[float, float]]): subsegments generated for the segments - as list of tuple of start and duration of - each subsegment -======= - subsegments (List[tuple[float, float]]): subsegments generated for the segments - as list of tuple of start and duration of each subsegment ->>>>>>> 681fe3881c7029e104788ef621d020b8f94bd410 + <<<<<<< HEAD + This function returns subsegments from a segment of an audio file. + Although this implementation is inefficient due to the use of a for-loop for segmentation, + it is designed to be torch-jit-scriptable. + Use `get_subsegments` for a more efficient implementation. + + ======= + Return subsegments from a segment of audio file. + This function is inefficient since the segmentation is based on for-loop, + but this implementation makes this function torch-jit-scriptable. + + >>>>>>> 681fe3881c7029e104788ef621d020b8f94bd410 + Args: + offset (float): start time of audio segment + window (float): window length for segments to subsegments length + shift (float): hop length for subsegments shift + duration (float): duration of segment + Returns: + <<<<<<< HEAD + subsegments (List[tuple[float, float]]): subsegments generated for the segments + as list of tuple of start and duration of + each subsegment + ======= + subsegments (List[tuple[float, float]]): subsegments generated for the segments + as list of tuple of start and duration of each subsegment + >>>>>>> 681fe3881c7029e104788ef621d020b8f94bd410 """ subsegments: List[List[float]] = [] start = offset From be8ac22f071c105721fff4e4137ec60f2713655b Mon Sep 17 00:00:00 2001 From: taejinp Date: Tue, 19 Nov 2024 15:55:15 -0800 Subject: [PATCH 22/47] Adding all the diarization reltated unit-tests Signed-off-by: taejinp --- .../speaker_tasks/test_diar_datasets.py | 112 ++ .../speaker_tasks/test_diar_label_models.py | 200 ++++ .../test_diar_lhotse_datasets.py | 157 +++ .../speaker_tasks/test_diar_metrics.py | 197 ++++ .../test_diar_neural_inference.py | 74 ++ .../test_diar_sortformer_models.py | 175 +++ .../utils/test_data_simul_utils.py | 545 +++++++++ .../speaker_tasks/utils/test_diar_utils.py | 1046 +++++++++++++++++ .../utils/test_multispeaker_utils.py | 320 +++++ .../speaker_tasks/utils/test_vad_utils.py | 126 ++ 10 files changed, 2952 insertions(+) create mode 100644 tests/collections/speaker_tasks/test_diar_datasets.py create mode 100644 tests/collections/speaker_tasks/test_diar_label_models.py create mode 100644 tests/collections/speaker_tasks/test_diar_lhotse_datasets.py create mode 100644 tests/collections/speaker_tasks/test_diar_metrics.py create mode 100644 tests/collections/speaker_tasks/test_diar_neural_inference.py create mode 100644 tests/collections/speaker_tasks/test_diar_sortformer_models.py create mode 100644 tests/collections/speaker_tasks/utils/test_data_simul_utils.py create mode 100644 tests/collections/speaker_tasks/utils/test_diar_utils.py create mode 100644 tests/collections/speaker_tasks/utils/test_multispeaker_utils.py create mode 100644 tests/collections/speaker_tasks/utils/test_vad_utils.py diff --git a/tests/collections/speaker_tasks/test_diar_datasets.py b/tests/collections/speaker_tasks/test_diar_datasets.py new file mode 100644 index 000000000000..915a54382f99 --- /dev/null +++ b/tests/collections/speaker_tasks/test_diar_datasets.py @@ -0,0 +1,112 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import filecmp +import json +import os +import shutil +import tempfile +from unittest import mock + +import numpy as np +import pytest +import soundfile as sf +import torch.cuda +from omegaconf import DictConfig, OmegaConf +from torch.utils.data import DataLoader + +from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.utils.speaker_utils import read_rttm_lines, get_offset_and_duration, get_vad_out_from_rttm_line + +def is_rttm_length_too_long(rttm_file_path, wav_len_in_sec): + """ + Check if the maximum RTTM duration exceeds the length of the provided audio file. + + Args: + rttm_file_path (str): Path to the RTTM file. + wav_len_in_sec (float): Length of the audio file in seconds. + + Returns: + bool: True if the maximum RTTM duration is less than or equal to the length of the audio file, False otherwise. + """ + rttm_lines = read_rttm_lines(rttm_file_path) + max_rttm_sec = 0 + for line in rttm_lines: + start, dur = get_vad_out_from_rttm_line(line) + max_rttm_sec = max(max_rttm_sec, start + dur) + return max_rttm_sec <= wav_len_in_sec + +class TestAudioToSpeechE2ESpkDiarDataset: + + @pytest.mark.unit + def test_e2e_speaker_diar_dataset(self, test_data_dir): + manifest_path = os.path.abspath(os.path.join(test_data_dir, 'asr/diarizer/lsm_val.json')) + + batch_size = 4 + num_samples = 8 + + device = 'gpu' if torch.cuda.is_available() else 'cpu' + data_dict_list = [] + with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f: + with open(manifest_path, 'r', encoding='utf-8') as mfile: + for ix, line in enumerate(mfile): + if ix >= num_samples: + break + + line = line.replace("tests/data/", test_data_dir + "/").replace("\n", "") + f.write(f"{line}\n") + data_dict = json.loads(line) + data_dict_list.append(data_dict) + + f.seek(0) + featurizer = WaveformFeaturizer(sample_rate=16000, int_values=False, augmentor=None) + + dataset = AudioToSpeechE2ESpkDiarDataset( + manifest_filepath=f.name, + soft_label_thres=0.5, + session_len_sec=90, + num_spks=4, + featurizer=featurizer, + window_stride=0.01, + global_rank=0, + soft_targets=False, + ) + + dataloader_instance = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=batch_size, + collate_fn=dataset.eesd_train_collate_fn, + drop_last=False, + shuffle=False, + num_workers=1, + pin_memory=False, + ) + assert len(dataloader_instance) == (num_samples / batch_size) # Check if the number of batches is correct + batch_counts = len(dataloader_instance) + + deviation_thres_rate = 0.01 # 1% deviation allowed + for batch_index, batch in enumerate(dataloader_instance): + if batch_index != batch_counts - 1: + assert len(batch) == batch_size, "Batch size does not match the expected value" + audio_signals, audio_signal_len, targets, target_lens = batch + for sample_index in range(audio_signals.shape[0]): + dataloader_audio_in_sec = audio_signal_len[sample_index].item() + data_dur_in_sec = abs(data_dict_list[batch_size*batch_index + sample_index]['duration'] * featurizer.sample_rate - dataloader_audio_in_sec) + assert data_dur_in_sec <= deviation_thres_rate * dataloader_audio_in_sec, "Duration deviation exceeds 1%" + assert not torch.isnan(audio_signals).any(), "audio_signals tensor contains NaN values" + assert not torch.isnan(audio_signal_len).any(), "audio_signal_len tensor contains NaN values" + assert not torch.isnan(targets).any(), "targets tensor contains NaN values" + assert not torch.isnan(target_lens).any(), "target_lens tensor contains NaN values" + \ No newline at end of file diff --git a/tests/collections/speaker_tasks/test_diar_label_models.py b/tests/collections/speaker_tasks/test_diar_label_models.py new file mode 100644 index 000000000000..cf073d9e85e2 --- /dev/null +++ b/tests/collections/speaker_tasks/test_diar_label_models.py @@ -0,0 +1,200 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from omegaconf import DictConfig + +from nemo.collections.asr.models import EncDecDiarLabelModel +from nemo.collections.asr.losses import BCELoss + +@pytest.fixture() +def msdd_model(): + + preprocessor = { + 'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', + 'params': {"features": 80, "window_size": 0.025, "window_stride": 0.01, "sample_rate": 16000,}, + } + + speaker_model_encoder = { + 'cls': 'nemo.collections.asr.modules.ConvASREncoder', + 'params': { + 'feat_in': 80, + 'activation': 'relu', + 'conv_mask': True, + 'jasper': [ + { + 'filters': 512, + 'repeat': 1, + 'kernel': [1], + 'stride': [1], + 'dilation': [1], + 'dropout': 0.0, + 'residual': False, + 'separable': False, + } + ], + }, + } + + speaker_model_decoder = { + 'cls': 'nemo.collections.asr.modules.SpeakerDecoder', + 'params': {'feat_in': 512, 'num_classes': 2, 'pool_mode': 'xvector', 'emb_sizes': [1024]}, + } + + speaker_model_cfg = DictConfig( + { + 'preprocessor': DictConfig(preprocessor), + 'encoder': DictConfig(speaker_model_encoder), + 'decoder': DictConfig(speaker_model_decoder), + } + ) + + msdd_module = { + 'cls': 'nemo.collections.asr.modules.MSDD_module', + 'params': { + "num_spks": 2, + "hidden_size": 256, + "num_lstm_layers": 3, + "dropout_rate": 0.5, + "cnn_output_ch": 32, + "conv_repeat": 2, + "emb_dim": 192, + "scale_n": 5, + "weighting_scheme": 'conv_scale_weight', + "context_vector_type": 'cos_sim', + }, + } + + loss = {'cls': 'nemo.collections.asr.losses.bce_loss.BCELoss', 'params': {"weight": None}} + + diarizer = { + 'out_dir': None, + 'oracle_vad': True, + "speaker_embeddings": { + "model_path": None, + "parameters": { + "window_length_in_sec": [1.5, 1.25, 1.0, 0.75, 0.5], + "shift_length_in_sec": [0.75, 0.625, 0.5, 0.375, 0.25], + "multiscale_weights": [1, 1, 1, 1, 1], + "save_embeddings": True, + }, + }, + } + + modelConfig = DictConfig( + { + 'msdd_module': DictConfig(msdd_module), + 'preprocessor': DictConfig(preprocessor), + 'diarizer': DictConfig(diarizer), + 'loss': DictConfig(loss), + 'max_num_of_spks': 2, + 'num_workers': 5, + 'emb_batch_size': 0, + 'soft_label_thres': 0.5, + 'scale_n': 5, + 'speaker_model_cfg': speaker_model_cfg, + } + ) + model = EncDecDiarLabelModel(cfg=modelConfig) + return model + + +class TestEncDecDiarLabelModel: + @pytest.mark.unit + def test_constructor(self, msdd_model): + diar_model = msdd_model.train() + assert diar_model.cfg.scale_n == len( + diar_model.cfg.diarizer.speaker_embeddings.parameters.window_length_in_sec + ) + assert diar_model.cfg.scale_n == len(diar_model.cfg.diarizer.speaker_embeddings.parameters.shift_length_in_sec) + assert diar_model.cfg.scale_n == len(diar_model.cfg.diarizer.speaker_embeddings.parameters.multiscale_weights) + assert diar_model.cfg.msdd_module.num_spks == diar_model.cfg.max_num_of_spks + # TODO: make proper config and assert correct number of weights + # Check to/from config_dict: + confdict = diar_model.to_config_dict() + instance2 = EncDecDiarLabelModel.from_config_dict(confdict) + assert isinstance(instance2, EncDecDiarLabelModel) + + @pytest.mark.unit + def test_forward_infer(self, msdd_model): + diar_model = msdd_model.eval() + + # batch_size 4, scale_n 5, length 25, emb_dim 192 + input_signal = torch.randn(size=(4, 25, 5, 192)) + input_signal_length = 25 * torch.ones(4, dtype=torch.int) + emb_vectors = torch.randn(size=(4, 5, 192, 2)) + targets = torch.randint(2, size=(4, 25, 2), dtype=torch.int) + + with torch.no_grad(): + # batch size 1 + preds_list, scale_weights_list = [], [] + for i in range(input_signal.size(0)): + preds, scale_weights = diar_model.forward_infer( + input_signal[i : i + 1], input_signal_length[i : i + 1], emb_vectors[i : i + 1], targets[i : i + 1] + ) + preds_list.append(preds) + scale_weights_list.append(scale_weights) + preds_instance = torch.cat(preds_list, 0) + scale_weights_instance = torch.cat(scale_weights_list, 0) + + # batch size 4 + preds_batch, scale_weights_batch = diar_model.forward_infer( + input_signal, input_signal_length, emb_vectors, targets + ) + + assert preds_instance.shape == preds_batch.shape + assert scale_weights_instance.shape == scale_weights_batch.shape + + diff = torch.mean(torch.abs(preds_instance - preds_batch)) + assert diff <= 1e-6 + diff = torch.max(torch.abs(preds_instance - preds_batch)) + assert diff <= 1e-6 + diff = torch.mean(torch.abs(scale_weights_instance - scale_weights_batch)) + assert diff <= 1e-6 + diff = torch.max(torch.abs(scale_weights_instance - scale_weights_batch)) + assert diff <= 1e-6 + +class TestBCELoss: + @pytest.mark.unit + @pytest.mark.parametrize( + "probs, labels, target_lens, reduction, expected_output", [ + ( + torch.tensor([[[0.5, 0.5], [0.5, 0.5]]], dtype=torch.float32), + torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32), + torch.tensor([[2]]), + "mean", + torch.tensor(0.693147, dtype=torch.float32) + ), + ( + torch.tensor([[[0.5, 0.5], [0.0, 1.0]]], dtype=torch.float32), + torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32), + torch.tensor([[1]]), + "mean", + torch.tensor(0.693147, dtype=torch.float32) + ), + ( + torch.tensor([[[0, 1], [1, 0]]], dtype=torch.float32), + torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32), + torch.tensor([[2]]), + "mean", + torch.tensor(100, dtype=torch.float32) + ) + ] + ) + def test_loss(self, probs, labels, target_lens, reduction, expected_output): + loss = BCELoss(reduction=reduction) + result = loss(probs=probs, labels=labels, target_lens=target_lens) + assert torch.allclose(result, expected_output), f"Expected {expected_output}, but got {result}" + diff --git a/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py b/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py new file mode 100644 index 000000000000..0aa676a6318e --- /dev/null +++ b/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py @@ -0,0 +1,157 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import tempfile +from unittest import mock +import pytest +import torch +import torch.cuda +from omegaconf import DictConfig +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.asr.data.audio_to_diar_label_lhotse import LhotseAudioToSpeechE2ESpkDiarDataset + +def get_train_ds_config(manifest_filepath, batch_size, num_workers) -> DictConfig: + return DictConfig({ + 'manifest_filepath': manifest_filepath, + 'sample_rate': 16000, + 'num_spks': 4, + 'session_len_sec': 90, + 'soft_label_thres': 0.5, + 'soft_targets': False, + 'labels': None, + 'batch_size': batch_size, + 'shuffle': True, + 'num_workers': num_workers, + 'validation_mode': False, + 'use_lhotse': True, + 'use_bucketing': True, + 'num_buckets': 10, + 'bucket_duration_bins': [10, 20, 30, 40, 50, 60, 70, 80, 90], + 'pin_memory': True, + 'min_duration': 80, + 'max_duration': 90, + 'batch_duration': 400, + 'quadratic_duration': 1200, + 'bucket_buffer_size': 20000, + 'shuffle_buffer_size': 10000, + 'window_stride': 0.01, + 'subsampling_factor': 8, + }) + +def get_validation_ds_config(manifest_filepath, batch_size, num_workers) -> DictConfig: + return DictConfig({ + 'manifest_filepath': manifest_filepath, + 'is_tarred': False, + 'tarred_audio_filepaths': None, + 'sample_rate': 16000, + 'num_spks': 4, + 'session_len_sec': 90, + 'soft_label_thres': 0.5, + 'soft_targets': False, + 'labels': None, + 'batch_size': batch_size, + 'shuffle': False, + 'seq_eval_mode': True, + 'num_workers': num_workers, + 'validation_mode': True, + 'use_lhotse': False, + 'use_bucketing': False, + 'drop_last': False, + 'pin_memory': True, + 'window_stride': 0.01, + 'subsampling_factor': 8, + }) + +def get_test_ds_config(manifest_filepath, batch_size, num_workers) -> DictConfig: + return DictConfig({ + 'manifest_filepath': manifest_filepath, + 'is_tarred': False, + 'tarred_audio_filepaths': None, + 'sample_rate': 16000, + 'num_spks': 4, + 'session_len_sec': 90, + 'soft_label_thres': 0.5, + 'soft_targets': False, + 'labels': None, + 'batch_size': batch_size, + 'shuffle': False, + 'seq_eval_mode': True, + 'num_workers': num_workers, + 'validation_mode': True, + 'use_lhotse': False, + 'use_bucketing': False, + 'drop_last': False, + 'pin_memory': True, + 'window_stride': 0.01, + 'subsampling_factor': 8, + }) + +class TestLhotseAudioToSpeechE2ESpkDiarDataset: + + + @pytest.mark.unit + @pytest.mark.parametrize( + "batch_size, num_workers, split", + [ + (4, 8, 'train'), # Example 1 + (4, 0, 'train'), # Example 2 + (2, 4, 'validation'), # Example 3 + (8, 2, 'test') # Example 4 + ] + ) + def test_e2e_speaker_diar_lhotse_dataset(self, test_data_dir, batch_size, num_workers, split): + manifest_path = os.path.abspath(os.path.join(test_data_dir, 'asr/diarizer/lsm_val.json')) + num_samples = 8 + device = 'gpu' if torch.cuda.is_available() else 'cpu' + data_dict_list = [] + with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f: + with open(manifest_path, 'r', encoding='utf-8') as mfile: + for ix, line in enumerate(mfile): + if ix >= num_samples: + break + + line = line.replace("tests/data/", test_data_dir + "/").replace("\n", "") + f.write(f"{line}\n") + data_dict = json.loads(line) + data_dict_list.append(data_dict) + + f.seek(0) + if split == 'train': + config = get_train_ds_config(manifest_filepath=f.name, batch_size=batch_size, num_workers=num_workers) + elif split == 'validation': + config = get_train_ds_config(manifest_filepath=f.name, batch_size=batch_size, num_workers=num_workers) + elif split == 'test': + config = get_test_ds_config(manifest_filepath=f.name, batch_size=batch_size, num_workers=num_workers) + + dataloader_instance = get_lhotse_dataloader_from_config( + config, + global_rank=0, + world_size=1, + dataset=LhotseAudioToSpeechE2ESpkDiarDataset(cfg=config), + ) + + deviation_thres_rate = 0.01 # 1% deviation allowed + for batch_index, batch in enumerate(dataloader_instance): + audio_signals, audio_signal_len, targets, target_lens = batch + for sample_index in range(audio_signals.shape[0]): + dataloader_audio_in_sec = audio_signal_len[sample_index].item() + data_dur_in_sec = abs(data_dict_list[batch_size*batch_index + sample_index]['duration'] * config.sample_rate - dataloader_audio_in_sec) + assert data_dur_in_sec <= deviation_thres_rate * dataloader_audio_in_sec, "Duration deviation exceeds 1%" + assert not torch.isnan(audio_signals).any(), "audio_signals tensor contains NaN values" + assert not torch.isnan(audio_signal_len).any(), "audio_signal_len tensor contains NaN values" + assert not torch.isnan(targets).any(), "targets tensor contains NaN values" + assert not torch.isnan(target_lens).any(), "target_lens tensor contains NaN values" + \ No newline at end of file diff --git a/tests/collections/speaker_tasks/test_diar_metrics.py b/tests/collections/speaker_tasks/test_diar_metrics.py new file mode 100644 index 000000000000..3ae6f6f6a3fa --- /dev/null +++ b/tests/collections/speaker_tasks/test_diar_metrics.py @@ -0,0 +1,197 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from itertools import permutations + +import pytest +import torch + +from nemo.collections.asr.metrics.der import ( + calculate_session_cpWER, + calculate_session_cpWER_bruteforce, + get_online_DER_stats, + get_partial_ref_labels, +) + + +def word_count(spk_transcript): + return sum([len(w.split()) for w in spk_transcript]) + + +def calculate_wer_count(_ins, _del, _sub, ref_word_count): + return (_ins + _del + _sub) / ref_word_count + + +def permuted_input_test(hyp, ref, calculated): + """ + Randomly permute the input to see if evaluation result stays the same. + """ + for hyp_permed in permutations(hyp): + cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp_permed, spk_reference=ref) + diff = torch.abs(torch.tensor(calculated - cpWER)) + assert diff <= 1e-6 + + +class TestConcatMinPermWordErrorRate: + """ + Tests for cpWER calculation. + """ + + @pytest.mark.unit + def test_cpwer_oneword(self): + hyp = ["oneword"] + ref = ["oneword"] + _ins, _del, _sub = 0, 0, 0 + cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref) + ref_word_count = word_count(ref) + calculated = calculate_wer_count(_ins, _del, _sub, ref_word_count) + diff = torch.abs(torch.tensor(calculated - cpWER)) + assert diff <= 1e-6 + permuted_input_test(hyp, ref, calculated) + cpWER_perm, hyp_min_perm, ref_str = calculate_session_cpWER_bruteforce(spk_hypothesis=hyp, spk_reference=ref) + diff = torch.abs(torch.tensor(cpWER_perm - cpWER)) + assert diff <= 1e-6 + + # Test with a substitution + hyp = ["wrongword"] + _ins, _del, _sub = 0, 0, 1 + cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref) + calculated = calculate_wer_count(_ins, _del, _sub, ref_word_count) + diff = torch.abs(torch.tensor(calculated - cpWER)) + assert diff <= 1e-6 + permuted_input_test(hyp, ref, calculated) + cpWER_perm, hyp_min_perm, ref_str = calculate_session_cpWER_bruteforce(spk_hypothesis=hyp, spk_reference=ref) + diff = torch.abs(torch.tensor(cpWER_perm - cpWER)) + assert diff <= 1e-6 + + @pytest.mark.unit + def test_cpwer_perfect(self): + hyp = ["ff", "aa bb cc", "dd ee"] + ref = ["aa bb cc", "dd ee", "ff"] + cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref) + calculated = 0 + diff = torch.abs(torch.tensor(calculated - cpWER)) + assert diff <= 1e-6 + permuted_input_test(hyp, ref, calculated) + + @pytest.mark.unit + def test_cpwer_spk_counfusion_and_asr_error(self): + hyp = ["aa bb c ff", "dd e ii jj kk", "hi"] + ref = ["aa bb cc ff", "dd ee gg jj kk", "hh ii"] + _ins, _del, _sub = 0, 1, 4 + cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref) + ref_word_count = word_count(ref) + calculated = calculate_wer_count(_ins, _del, _sub, ref_word_count) + diff = torch.abs(torch.tensor(calculated - cpWER)) + assert diff <= 1e-6 + permuted_input_test(hyp, ref, calculated) + cpWER_perm, hyp_min_perm, ref_str = calculate_session_cpWER_bruteforce(spk_hypothesis=hyp, spk_reference=ref) + diff = torch.abs(torch.tensor(cpWER_perm - cpWER)) + assert diff <= 1e-6 + + @pytest.mark.unit + def test_cpwer_undercount(self): + hyp = ["aa bb cc", "dd ee gg", "hh ii", "jj kk"] + ref = ["aa bb cc", "dd ee", "ff", "gg", "hh ii", "jj kk"] + _ins, _del, _sub = 0, 1, 0 + cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref) + ref_word_count = word_count(ref) + calculated = calculate_wer_count(_ins, _del, _sub, ref_word_count) + diff = torch.abs(torch.tensor(calculated - cpWER)) + assert diff <= 1e-6 + cpWER_perm, hyp_min_perm, ref_str = calculate_session_cpWER_bruteforce(spk_hypothesis=hyp, spk_reference=ref) + diff = torch.abs(torch.tensor(cpWER_perm - cpWER)) + assert diff <= 1e-6 + + @pytest.mark.unit + def test_cpwer_overcount(self): + hyp = ["aa bb cc", "dd ee gg hh", "ii jj kk"] + ref = ["aa bb cc", "dd ee ff gg hh ii jj kk"] + _ins, _del, _sub = 0, 1, 0 + cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref) + ref_word_count = word_count(ref) + calculated = calculate_wer_count(_ins, _del, _sub, ref_word_count) + diff = torch.abs(torch.tensor(calculated - cpWER)) + assert diff <= 1e-6 + cpWER_perm, hyp_min_perm, ref_str = calculate_session_cpWER_bruteforce(spk_hypothesis=hyp, spk_reference=ref) + diff = torch.abs(torch.tensor(cpWER_perm - cpWER)) + assert diff <= 1e-6 + + @pytest.mark.parametrize( + "pred_labels, ref_labels, expected_output", + [ + ([], [], []), + (["0.0 1.0 speaker1"], [], []), + (["0.0 1.0 speaker1"], ["0.0 1.5 speaker1"], ["0.0 1.0 speaker1"]), + (["0.1 0.4 speaker1", "0.5 1.0 speaker2"], ["0.0 1.5 speaker1"], ["0.0 1.0 speaker1"]), + ( + ["0.5 1.0 speaker2", "0.1 0.4 speaker1"], + ["0.0 1.5 speaker1"], + ["0.0 1.0 speaker1"], + ), # Order of prediction does not matter + ( + ["0.1 1.4 speaker1", "0.5 1.0 speaker2"], + ["0.0 1.5 speaker1"], + ["0.0 1.4 speaker1"], + ), # Overlapping prediction + ( + ["0.1 0.6 speaker1", "0.2 1.5 speaker2"], + ["0.5 1.0 speaker1", "1.01 2.0 speaker2"], + ["0.5 1.0 speaker1", "1.01 1.5 speaker2"], + ), + ( + ["0.0 2.0 speaker1"], + ["0.0 2.0 speaker1", "1.0 3.0 speaker2", "0.0 5.0 speaker3"], + ["0.0 2.0 speaker1", "1.0 2.0 speaker2", "0.0 2.0 speaker3"], + ), + ], + ) + def test_get_partial_ref_labels(self, pred_labels, ref_labels, expected_output): + assert get_partial_ref_labels(pred_labels, ref_labels) == expected_output + + @pytest.mark.parametrize( + "DER, CER, FA, MISS, diar_eval_count, der_stat_dict, deci, expected_der_dict, expected_der_stat_dict", + [ + ( + 0.3, + 0.1, + 0.05, + 0.15, + 1, + {"cum_DER": 0, "cum_CER": 0, "avg_DER": 0, "avg_CER": 0, "max_DER": 0, "max_CER": 0}, + 3, + {"DER": 30.0, "CER": 10.0, "FA": 5.0, "MISS": 15.0}, + {"cum_DER": 0.3, "cum_CER": 0.1, "avg_DER": 30.0, "avg_CER": 10.0, "max_DER": 30.0, "max_CER": 10.0}, + ), + ( + 0.1, + 0.2, + 0.03, + 0.07, + 2, + {"cum_DER": 0.3, "cum_CER": 0.3, "avg_DER": 15.0, "avg_CER": 15.0, "max_DER": 30.0, "max_CER": 10.0}, + 2, + {"DER": 10.0, "CER": 20.0, "FA": 3.0, "MISS": 7.0}, + {"cum_DER": 0.4, "cum_CER": 0.5, "avg_DER": 20.0, "avg_CER": 25.0, "max_DER": 30.0, "max_CER": 20.0}, + ), + ], + ) + def test_get_online_DER_stats( + self, DER, CER, FA, MISS, diar_eval_count, der_stat_dict, deci, expected_der_dict, expected_der_stat_dict + ): + actual_der_dict, actual_der_stat_dict = get_online_DER_stats( + DER, CER, FA, MISS, diar_eval_count, der_stat_dict, deci + ) + assert actual_der_dict == expected_der_dict + assert actual_der_stat_dict == expected_der_stat_dict diff --git a/tests/collections/speaker_tasks/test_diar_neural_inference.py b/tests/collections/speaker_tasks/test_diar_neural_inference.py new file mode 100644 index 000000000000..3218a631bda3 --- /dev/null +++ b/tests/collections/speaker_tasks/test_diar_neural_inference.py @@ -0,0 +1,74 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest +import torch + +from nemo.collections.asr.models.msdd_models import NeuralDiarizer + + +class TestNeuralDiarizerInference: + @pytest.mark.unit + @pytest.mark.parametrize( + "device", + [ + torch.device("cpu"), + pytest.param( + torch.device("cuda"), + marks=pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA required for test.',), + ), + ], + ) + @pytest.mark.parametrize("num_speakers", [None, 1]) + @pytest.mark.parametrize("max_num_speakers", [4]) + def test_msdd_diar_inference(self, tmpdir, test_data_dir, device, num_speakers, max_num_speakers): + """ + Test to ensure diarization inference works correctly. + - Ensures multiple audio files can be diarized sequentially + - Ensures both CPU/CUDA is supported + - Ensures that max speaker and num speaker are set correctly + - Ensures temporary directory is emptied at the end of diarization + - Sanity check to ensure outputs from diarization are reasonable + """ + audio_filenames = ['an22-flrp-b.wav', 'an90-fbbh-b.wav'] + audio_paths = [os.path.join(test_data_dir, "asr", "train", "an4", "wav", fp) for fp in audio_filenames] + + diarizer = NeuralDiarizer.from_pretrained(model_name='diar_msdd_telephonic').to(device) + + out_dir = os.path.join(tmpdir, 'diarize_inference/') + + assert diarizer.msdd_model.device.type == device.type + assert diarizer._speaker_model.device.type == device.type + for audio_path in audio_paths: + annotation = diarizer( + audio_path, num_speakers=num_speakers, max_speakers=max_num_speakers, out_dir=out_dir + ) + + # assert max speakers has been set up correctly + assert diarizer.clustering_embedding.clus_diar_model._cluster_params.max_num_speakers == max_num_speakers + + if num_speakers: + assert diarizer._cfg.diarizer.clustering.parameters.oracle_num_speakers + + # assert all temporary files are cleaned up + assert len(os.listdir(out_dir)) == 0 + + # assert only 1 speaker & segment + assert len(annotation.labels()) == 1 + assert len(list(annotation.itersegments())) == 1 + + # class TestSortformerDiarizerInference: + # TODO: This test can only be implemented once SortformerDiarizer model is uploaded. diff --git a/tests/collections/speaker_tasks/test_diar_sortformer_models.py b/tests/collections/speaker_tasks/test_diar_sortformer_models.py new file mode 100644 index 000000000000..6e59206df894 --- /dev/null +++ b/tests/collections/speaker_tasks/test_diar_sortformer_models.py @@ -0,0 +1,175 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from omegaconf import DictConfig + +from nemo.collections.asr.models import SortformerEncLabelModel + + +@pytest.fixture() +def sortformer_model(): + + batch_size = 4 + model = { + 'pil_weight': 0.5, + 'ats_weight': 0.5, + 'num_workers': 18, + 'fc_d_model': 512, + 'tf_d_model': 192, + 'max_num_of_spks': 4, + 'session_len_sec': 90, + } + + + preprocessor = { + '_target_': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', + 'normalize': 'per_feature', + 'window_size': 0.025, + 'sample_rate': 16000, + 'window_stride': 0.01, + 'window': 'hann', + 'features': 80, + 'n_fft': 512, + 'frame_splicing': 1, + 'dither': 0.00001, + } + + sortformer_modules = { + '_target_': 'nemo.collections.asr.modules.sortformer_modules.SortformerModules', + 'num_spks': model['max_num_of_spks'], + 'dropout_rate': 0.5, + 'fc_d_model': model['fc_d_model'], + 'tf_d_model': model['tf_d_model'], + } + + encoder = { + '_target_': 'nemo.collections.asr.modules.ConformerEncoder', + 'feat_in': preprocessor['features'], + 'feat_out': -1, + 'n_layers': 18, + 'd_model': model['fc_d_model'], + 'subsampling': 'dw_striding', + 'subsampling_factor': 8, + 'subsampling_conv_channels': 256, + 'causal_downsampling': False, + 'ff_expansion_factor': 4, + 'self_attention_model': 'rel_pos', + 'n_heads': 8, + 'att_context_size': [-1, -1], + 'att_context_style': 'regular', + 'xscaling': True, + 'untie_biases': True, + 'pos_emb_max_len': 5000, + 'conv_kernel_size': 9, + 'conv_norm_type': 'batch_norm', + 'conv_context_size': None, + 'dropout': 0.1, + 'dropout_pre_encoder': 0.1, + 'dropout_emb': 0.0, + 'dropout_att': 0.1, + 'stochastic_depth_drop_prob': 0.0, + 'stochastic_depth_mode': 'linear', + 'stochastic_depth_start_layer': 1, + } + + transformer_encoder = { + '_target_': 'nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder', + 'num_layers': 18, + 'hidden_size': model['tf_d_model'], + 'inner_size': 768, + 'num_attention_heads': 8, + 'attn_score_dropout': 0.5, + 'attn_layer_dropout': 0.5, + 'ffn_dropout': 0.5, + 'hidden_act': 'relu', + 'pre_ln': False, + 'pre_ln_final_layer_norm': True, + } + + loss = { + '_target_': 'nemo.collections.asr.losses.bce_loss.BCELoss', + 'weight': None, + 'reduction': 'mean', + } + + + modelConfig = DictConfig( + {'pil_weight': 0.5, + 'ats_weight': 0.5, + 'num_workers': 1, + 'fc_d_model': 512, + 'tf_d_model': 192, + 'max_num_of_spks': 4, + 'session_len_sec': 90, + 'encoder': DictConfig(encoder), + 'transformer_encoder': DictConfig(transformer_encoder), + 'sortformer_modules': DictConfig(sortformer_modules), + 'preprocessor': DictConfig(preprocessor), + 'loss': DictConfig(loss), + 'optim': { + 'optimizer': 'Adam', + 'lr': 0.001, + 'betas': (0.9, 0.98), + } + } + ) + model = SortformerEncLabelModel(cfg=modelConfig) + return model + + +class TestSortformerEncLabelModel: + @pytest.mark.unit + def test_constructor(self, sortformer_model): + sortformer_diar_model = sortformer_model.train() + confdict = sortformer_diar_model.to_config_dict() + instance2 = SortformerEncLabelModel.from_config_dict(confdict) + assert isinstance(instance2, SortformerEncLabelModel) + + @pytest.mark.unit + @pytest.mark.parametrize( + "batch_size, frame_length, sample_len", + [ + (4, 0.08, 16), # Example 1 + (2, 0.02, 32), # Example 2 + (1, 0.1, 20), # Example 3 + ] +) + def test_forward_infer(self, sortformer_model, batch_size, frame_length, sample_len, num_spks=4): + sortformer_diar_model = sortformer_model.eval() + confdict = sortformer_diar_model.to_config_dict() + sampling_rate = confdict['preprocessor']['sample_rate'] + target_frame_count = int(sample_len // frame_length) + input_signal = torch.randn(size=(batch_size, sample_len * sampling_rate)) + input_signal_length = (sample_len * sampling_rate) * torch.ones(batch_size, dtype=torch.int) + targets = torch.randint(2, size=(batch_size, target_frame_count, num_spks), dtype=torch.int) + target_len = target_frame_count * torch.ones(batch_size, dtype=torch.int) + + with torch.no_grad(): + # batch size 1 + preds_list = [] + for i in range(input_signal.size(0)): + preds= sortformer_diar_model.forward(input_signal[i : i + 1], input_signal_length[i : i + 1]) + preds_list.append(preds) + preds_instance = torch.cat(preds_list, 0) + + # batch size 4 + preds_batch = sortformer_diar_model.forward(input_signal, input_signal_length) + assert preds_instance.shape == preds_batch.shape + + diff = torch.mean(torch.abs(preds_instance - preds_batch)) + assert diff <= 1e-6 + diff = torch.max(torch.abs(preds_instance - preds_batch)) + assert diff <= 1e-6 diff --git a/tests/collections/speaker_tasks/utils/test_data_simul_utils.py b/tests/collections/speaker_tasks/utils/test_data_simul_utils.py new file mode 100644 index 000000000000..295b79c76d18 --- /dev/null +++ b/tests/collections/speaker_tasks/utils/test_data_simul_utils.py @@ -0,0 +1,545 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +import pytest +import torch +from omegaconf import DictConfig + +from nemo.collections.asr.parts.utils.data_simulation_utils import ( + DataAnnotator, + SpeechSampler, + add_silence_to_alignments, + binary_search_alignments, + get_cleaned_base_path, + get_split_points_in_alignments, + normalize_audio, + read_noise_manifest, +) +from nemo.collections.asr.parts.utils.manifest_utils import get_ctm_line + + +@pytest.fixture() +def annotator(): + cfg = get_data_simulation_configs() + return DataAnnotator(cfg) + + +@pytest.fixture() +def sampler(): + cfg = get_data_simulation_configs() + sampler = SpeechSampler(cfg) + # Must get session-wise randomized silence/overlap mean + sampler.get_session_overlap_mean() + sampler.get_session_silence_mean() + return sampler + + +def get_data_simulation_configs(): + config_dict = { + 'data_simulator': { + 'manifest_filepath': '???', + 'sr': 16000, + 'random_seed': 42, + 'multiprocessing_chunksize': 10000, + 'session_config': {'num_speakers': 4, 'num_sessions': 60, 'session_length': 600}, + 'session_params': { + 'max_audio_read_sec': 20, + 'sentence_length_params': [0.4, 0.05], + 'dominance_var': 0.11, + 'min_dominance': 0.05, + 'turn_prob': 0.875, + 'min_turn_prob': 0.5, + 'mean_silence': 0.15, + 'mean_silence_var': 0.01, + 'per_silence_var': 900, + 'per_silence_min': 0.0, + 'per_silence_max': -1, + 'mean_overlap': 0.1, + 'mean_overlap_var': 0.01, + 'per_overlap_var': 900, + 'per_overlap_min': 0.0, + 'per_overlap_max': -1, + 'start_window': True, + 'window_type': 'hamming', + 'window_size': 0.05, + 'start_buffer': 0.1, + 'split_buffer': 0.1, + 'release_buffer': 0.1, + 'normalize': True, + 'normalization_type': 'equal', + 'normalization_var': 0.1, + 'min_volume': 0.75, + 'max_volume': 1.25, + 'end_buffer': 0.5, + }, + 'outputs': { + 'output_dir': '???', + 'output_filename': 'multispeaker_session', + 'overwrite_output': True, + 'output_precision': 3, + }, + 'background_noise': { + 'add_bg': False, + 'background_manifest': None, + 'num_noise_files': 10, + 'snr': 60, + 'snr_min': None, + }, + 'segment_augmentor': { + 'add_seg_aug': False, + 'augmentor': {'gain': {'prob': 0.5, 'min_gain_dbfs': -10.0, 'max_gain_dbfs': 10.0},}, + }, + 'session_augmentor': { + 'add_sess_aug': False, + 'augmentor': {'white_noise': {'prob': 1.0, 'min_level': -90, 'max_level': -46},}, + }, + 'speaker_enforcement': {'enforce_num_speakers': True, 'enforce_time': [0.25, 0.75]}, + 'segment_manifest': {'window': 0.5, 'shift': 0.25, 'step_count': 50, 'deci': 3}, + } + } + return DictConfig(config_dict) + + +def generate_words_and_alignments(sample_index): + if sample_index == 0: + words = ['', 'hello', 'world'] + alignments = [0.5, 1.0, 1.5] + elif sample_index == 1: + words = ["", "stephanos", "dedalos", ""] + alignments = [0.51, 1.31, 2.04, 2.215] + elif sample_index == 2: + words = ['', 'hello', 'world', '', 'welcome', 'to', 'nemo', ''] + alignments = [0.5, 1.0, 1.5, 1.7, 1.8, 2.2, 2.7, 2.8] + else: + raise ValueError(f"sample_index {sample_index} not supported") + speaker_id = 'speaker_0' + return words, alignments, speaker_id + + +class TestGetCtmLine: + @pytest.mark.unit + @pytest.mark.parametrize("conf", [0, 1]) + def test_wrong_type_conf_values(self, conf): + # Test with wrong integer confidence values + with pytest.raises(ValueError): + result = get_ctm_line( + source="test_source", + channel=1, + start_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_of_token="lex", + speaker="speaker1", + ) + expected = f"test_source 1 0.12 0.46 word {conf} lex speaker1\n" + assert result == expected, f"Failed on valid conf value {conf}" + + @pytest.mark.unit + @pytest.mark.parametrize("conf", [0.0, 0.5, 1.0, 0.01, 0.99]) + def test_valid_conf_values(self, conf): + # Test with valid confidence values + output_precision = 2 + result = get_ctm_line( + source="test_source", + channel=1, + start_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_of_token="lex", + speaker="speaker1", + output_precision=output_precision, + ) + expected = "test_source 1 0.12 0.46 word" + f" {conf:.{output_precision}f} lex speaker1\n" + assert result == expected, f"Failed on valid conf value {conf}" + + @pytest.mark.unit + @pytest.mark.parametrize("conf", [-0.1, 1.1, 2, -1, 100, -100]) + def test_invalid_conf_ranges(self, conf): + # Test with invalid confidence values + with pytest.raises(ValueError): + get_ctm_line( + source="test_source", + channel=1, + start_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_of_token="lex", + speaker="speaker1", + ) + + @pytest.mark.unit + @pytest.mark.parametrize( + "start_time, duration, output_precision", + [(0.123, 0.456, 2), (1.0, 2.0, 1), (0.0, 0.0, 2), (0.01, 0.99, 3), (1.23, 4.56, 2)], + ) + def test_valid_start_time_duration_with_precision(self, start_time, duration, output_precision): + # Test with valid beginning time, duration values and output precision + confidence = 0.5 + result = get_ctm_line( + source="test_source", + channel=1, + start_time=start_time, + duration=duration, + token="word", + conf=confidence, + type_of_token="lex", + speaker="speaker1", + output_precision=output_precision, + ) + expected_start_time = ( + f"{start_time:.{output_precision}f}" # Adjusted to match the output format with precision + ) + expected_duration = f"{duration:.{output_precision}f}" # Adjusted to match the output format with precision + expected_confidence = ( + f"{confidence:.{output_precision}f}" # Adjusted to match the output format with precision + ) + expected = f"test_source 1 {expected_start_time} {expected_duration} word {expected_confidence} lex speaker1\n" + assert ( + result == expected + ), f"Failed on valid start_time {start_time}, duration {duration} with precision {output_precision}" + + @pytest.mark.unit + def test_valid_input(self): + # Test with completely valid inputs + result = get_ctm_line( + source="test_source", + channel=1, + start_time=0.123, + duration=0.456, + token="word", + conf=0.789, + type_of_token="lex", + speaker="speaker1", + ) + expected = "test_source 1 0.12 0.46 word 0.79 lex speaker1\n" + assert result == expected, "Failed on valid input" + + @pytest.mark.unit + @pytest.mark.parametrize( + "start_time, duration", + [ + ("not a float", 1.0), + (1.0, "not a float"), + (1, 2.0), # Integers should be converted to float + (2.0, 3), # Same as above + ], + ) + def test_invalid_types_for_time_duration(self, start_time, duration): + # Test with invalid types for start_time and duration + with pytest.raises(ValueError): + get_ctm_line( + source="test_source", + channel=1, + start_time=start_time, + duration=duration, + token="word", + conf=0.5, + type_of_token="lex", + speaker="speaker1", + ) + + @pytest.mark.unit + @pytest.mark.parametrize("conf", [-0.1, 1.1, "not a float"]) + def test_invalid_conf_values(self, conf): + # Test with invalid values for conf + with pytest.raises(ValueError): + get_ctm_line( + source="test_source", + channel=1, + start_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_of_token="lex", + speaker="speaker1", + ) + + @pytest.mark.unit + def test_default_values(self): + # Test with missing optional parameters + result = get_ctm_line( + source="test_source", + channel=None, + start_time=0.123, + duration=0.456, + token="word", + conf=None, + type_of_token=None, + speaker=None, + ) + expected = "test_source 1 0.12 0.46 word NA unknown NA\n" + assert result == expected, "Failed on default values" + + +class TestDataSimulatorUtils: + # TODO: add tests for all util functions + @pytest.mark.parametrize("max_audio_read_sec", [2.5, 3.5, 4.5]) + @pytest.mark.parametrize("min_alignment_count", [2, 3, 4]) + def test_binary_search_alignments(self, max_audio_read_sec, min_alignment_count): + inds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] + alignments = [0.5, 11.0, 11.5, 12.0, 13.0, 14.0, 14.5, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 30, 40.0] + offset_max = binary_search_alignments(inds, max_audio_read_sec, min_alignment_count, alignments) + assert max_audio_read_sec <= alignments[-1 * min_alignment_count] - alignments[inds[offset_max]] + + @pytest.mark.parametrize("sample_len", [100, 16000]) + @pytest.mark.parametrize("gain", [0.1, 0.5, 1.0, 2.0, 5.0]) + def test_normalize_audio(self, sample_len, gain): + array_raw = np.random.randn(sample_len) + array_input = torch.from_numpy(gain * array_raw / np.max(np.abs(array_raw))) + norm_array = normalize_audio(array_input) + assert torch.max(torch.abs(norm_array)) == 1.0 + assert torch.min(torch.abs(norm_array)) < 1.0 + + @pytest.mark.parametrize("output_dir", [os.path.join(os.getcwd(), "test_dir")]) + def test_get_cleaned_base_path(self, output_dir): + result_path = get_cleaned_base_path(output_dir, overwrite_output=True) + assert os.path.exists(result_path) and not os.path.isfile(result_path) + result_path = get_cleaned_base_path(output_dir, overwrite_output=False) + assert os.path.exists(result_path) and not os.path.isfile(result_path) + os.rmdir(result_path) + assert not os.path.exists(result_path) + + @pytest.mark.parametrize( + "words, alignments, answers", + [ + (['', 'hello', 'world'], [0.5, 1.0, 1.5], [[0, 16000.0]]), + ( + ['', 'hello', 'world', '', 'welcome', 'to', 'nemo', ''], + [0.27, 1.0, 1.7, 2.7, 2.8, 3.2, 3.7, 3.9], + [[0, (1.7 + 0.5) * 16000], [(2.7 - 0.5) * 16000, (3.9 - 0.27) * 16000]], + ), + ], + ) + @pytest.mark.parametrize("sr", [16000]) + @pytest.mark.parametrize("split_buffer", [0.5]) + @pytest.mark.parametrize("new_start", [0.0]) + def test_get_split_points_in_alignments(self, words, alignments, sr, new_start, split_buffer, answers): + sentence_audio_len = sr * (alignments[-1] - alignments[0]) + splits = get_split_points_in_alignments(words, alignments, split_buffer, sr, sentence_audio_len, new_start) + assert len(splits) == len(answers) + for k, interval in enumerate(splits): + assert abs(answers[k][0] - interval[0]) < 1e-4 + assert abs(answers[k][1] - interval[1]) < 1e-4 + + @pytest.mark.parametrize( + "alignments, words", [(['hello', 'world'], [1.0, 1.5]), (['', 'hello', 'world'], [0.0, 1.0, 1.5])] + ) + def test_add_silence_to_alignments(self, alignments, words): + """ + Test add_silence_to_alignments function. + """ + audio_manifest = { + 'audio_filepath': 'test.wav', + 'alignments': alignments, + 'words': words, + } + audio_manifest = add_silence_to_alignments(audio_manifest) + if words[0] == '': + assert audio_manifest['alignments'] == [0.0] + alignments + assert audio_manifest['words'] == [''] + words + else: + assert audio_manifest['alignments'] == alignments + assert audio_manifest['words'] == words + + +class TestDataAnnotator: + def test_init(self, annotator): + assert isinstance(annotator, DataAnnotator) + + def test_create_new_rttm_entry(self, annotator): + words, alignments, speaker_id = generate_words_and_alignments(sample_index=0) + start, end = alignments[0], alignments[-1] + rttm_list = annotator.create_new_rttm_entry( + words=words, alignments=alignments, start=start, end=end, speaker_id=speaker_id + ) + assert rttm_list[0] == f"{start} {end} {speaker_id}" + + def test_create_new_json_entry(self, annotator): + words, alignments, speaker_id = generate_words_and_alignments(sample_index=0) + start, end = alignments[0], alignments[-1] + test_wav_filename = '/path/to/test_wav_filename.wav' + test_rttm_filename = '/path/to/test_rttm_filename.rttm' + test_ctm_filename = '/path/to/test_ctm_filename.ctm' + text = " ".join(words) + + one_line_json_dict = annotator.create_new_json_entry( + text=text, + wav_filename=test_wav_filename, + start=start, + length=end - start, + speaker_id=speaker_id, + rttm_filepath=test_rttm_filename, + ctm_filepath=test_ctm_filename, + ) + start = round(float(start), annotator._params.data_simulator.outputs.output_precision) + length = round(float(end - start), annotator._params.data_simulator.outputs.output_precision) + meta = { + "audio_filepath": test_wav_filename, + "offset": start, + "duration": length, + "label": speaker_id, + "text": text, + "num_speakers": annotator._params.data_simulator.session_config.num_speakers, + "rttm_filepath": test_rttm_filename, + "ctm_filepath": test_ctm_filename, + "uem_filepath": None, + } + assert one_line_json_dict == meta + + def test_create_new_ctm_entry(self, annotator): + words, alignments, speaker_id = generate_words_and_alignments(sample_index=0) + session_name = 'test_session' + ctm_list = annotator.create_new_ctm_entry( + words=words, alignments=alignments, session_name=session_name, speaker_id=speaker_id, start=alignments[0] + ) + assert ctm_list[0] == ( + alignments[1], + get_ctm_line( + source=session_name, + channel="1", + start_time=alignments[1], + duration=float(alignments[1] - alignments[0]), + token=words[1], + conf=None, + type_of_token='lex', + speaker=speaker_id, + ), + ) + assert ctm_list[1] == ( + alignments[2], + get_ctm_line( + source=session_name, + channel="1", + start_time=alignments[2], + duration=float(alignments[2] - alignments[1]), + token=words[2], + conf=None, + type_of_token='lex', + speaker=speaker_id, + ), + ) + + +class TestSpeechSampler: + def test_init(self, sampler): + assert isinstance(sampler, SpeechSampler) + + def test_init_overlap_params(self, sampler): + sampler._init_overlap_params() + assert sampler.per_silence_min_len is not None + assert sampler.per_silence_max_len is not None + assert type(sampler.per_silence_min_len) == int + assert type(sampler.per_silence_max_len) == int + + def test_init_silence_params(self, sampler): + sampler._init_overlap_params() + assert sampler.per_overlap_min_len is not None + assert sampler.per_overlap_max_len is not None + assert type(sampler.per_overlap_min_len) == int + assert type(sampler.per_overlap_max_len) == int + + @pytest.mark.parametrize("mean", [0.1, 0.2, 0.3]) + @pytest.mark.parametrize("var", [0.05, 0.07]) + def test_get_session_silence_mean_pass(self, sampler, mean, var): + sampler.mean_silence = mean + sampler.mean_silence_var = var + sampled_silence_mean = sampler.get_session_silence_mean() + assert 0 <= sampled_silence_mean <= 1 + + @pytest.mark.parametrize("mean", [0.5]) + @pytest.mark.parametrize("var", [0.5, 0.6]) + def test_get_session_silence_mean_fail(self, sampler, mean, var): + """ + This test should raise `ValueError` because `mean_silence_var` + should be less than `mean_silence * (1 - mean_silence)`. + """ + sampler.mean_silence = mean + sampler.mean_silence_var = var + with pytest.raises(ValueError) as execinfo: + sampler.get_session_silence_mean() + assert "ValueError" in str(execinfo) and "mean_silence_var" in str(execinfo) + + @pytest.mark.parametrize("mean", [0.1, 0.2, 0.3]) + @pytest.mark.parametrize("var", [0.05, 0.07]) + def test_get_session_overlap_mean_pass(self, sampler, mean, var): + sampler.mean_overlap = mean + sampler.mean_overlap_var = var + sampled_overlap_mean = sampler.get_session_overlap_mean() + assert 0 <= sampled_overlap_mean <= 1 + + @pytest.mark.parametrize("mean", [0.4, 0.5]) + @pytest.mark.parametrize("var", [0.3, 0.8]) + def test_get_session_overlap_mean_fail(self, sampler, mean, var): + """ + This test should raise `ValueError` because `mean_overlap_var` + should be less than `mean_overlap * (1 - mean_overlap)`. + """ + sampler.mean_overlap = mean + sampler.mean_overlap_var = var + sampler._params = DictConfig(sampler._params) + with pytest.raises(ValueError) as execinfo: + sampler.get_session_overlap_mean() + assert "ValueError" in str(execinfo) and "mean_overlap_var" in str(execinfo) + + @pytest.mark.parametrize("non_silence_len_samples", [16000, 32000]) + @pytest.mark.parametrize("running_overlap_len_samples", [8000, 12000]) + def test_sample_from_overlap_model(self, sampler, non_silence_len_samples, running_overlap_len_samples): + sampler.get_session_overlap_mean() + sampler.running_overlap_len_samples = running_overlap_len_samples + overlap_amount = sampler.sample_from_overlap_model(non_silence_len_samples=non_silence_len_samples) + assert type(overlap_amount) == int + assert 0 <= overlap_amount + + @pytest.mark.parametrize("running_len_samples", [8000, 16000]) + @pytest.mark.parametrize("running_overlap_len_samples", [8000, 12000]) + def test_sample_from_silence_model(self, sampler, running_len_samples, running_overlap_len_samples): + sampler.get_session_silence_mean() + self.running_overlap_len_samples = running_overlap_len_samples + silence_amount = sampler.sample_from_silence_model(running_len_samples=running_len_samples) + assert type(silence_amount) == int + assert 0 <= silence_amount + + @pytest.mark.with_downloads() + @pytest.mark.parametrize("num_noise_files", [1, 2, 4]) + def test_sample_noise_manifest(self, sampler, num_noise_files, test_data_dir): + sampler.num_noise_files = num_noise_files + manifest_path = os.path.abspath(os.path.join(test_data_dir, 'asr/an4_val.json')) + noise_manifest = read_noise_manifest(add_bg=True, background_manifest=manifest_path) + sampled_noise_manifests = sampler.sample_noise_manifest(noise_manifest=noise_manifest) + assert len(sampled_noise_manifests) == num_noise_files + + @pytest.mark.parametrize("running_speech_len_samples", [32000, 64000]) + @pytest.mark.parametrize("running_overlap_len_samples", [16000, 32000]) + @pytest.mark.parametrize("running_len_samples", [64000, 96000]) + @pytest.mark.parametrize("non_silence_len_samples", [16000, 32000]) + def test_silence_vs_overlap_selector( + self, + sampler, + running_overlap_len_samples, + running_speech_len_samples, + running_len_samples, + non_silence_len_samples, + ): + sampler.running_overlap_len_samples = running_overlap_len_samples + sampler.running_speech_len_samples = running_speech_len_samples + add_overlap = sampler.silence_vs_overlap_selector( + running_len_samples=running_len_samples, non_silence_len_samples=non_silence_len_samples + ) + assert type(add_overlap) == bool diff --git a/tests/collections/speaker_tasks/utils/test_diar_utils.py b/tests/collections/speaker_tasks/utils/test_diar_utils.py new file mode 100644 index 000000000000..cd7e7f5b2a3b --- /dev/null +++ b/tests/collections/speaker_tasks/utils/test_diar_utils.py @@ -0,0 +1,1046 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +import pytest +import torch +from scipy.optimize import linear_sum_assignment as scipy_linear_sum_assignment +from typing import List, Tuple +import math + +from nemo.collections.asr.data.audio_to_label import repeat_signal +from nemo.collections.asr.parts.utils.longform_clustering import LongFormSpeakerClustering +from nemo.collections.asr.parts.utils.offline_clustering import ( + SpeakerClustering, + get_scale_interpolated_embs, + getCosAffinityMatrix, + getKneighborsConnections, + split_input_data, +) +from nemo.collections.asr.parts.utils.online_clustering import ( + OnlineSpeakerClustering, + get_closest_embeddings, + get_merge_quantity, + get_minimal_indices, + merge_vectors, + run_reducer, + stitch_cluster_labels, +) +from nemo.collections.asr.parts.utils.optimization_utils import LinearSumAssignmentSolver +from nemo.collections.asr.parts.utils.optimization_utils import linear_sum_assignment as nemo_linear_sum_assignment +from nemo.collections.asr.parts.utils.speaker_utils import ( + OnlineSegmentor, + check_ranges, + fl2int, + get_new_cursor_for_update, + get_online_segments_from_slices, + get_online_subsegments_from_buffer, + get_speech_labels_for_update, + get_sub_range_list, + get_subsegments, + get_subsegments_scriptable, + get_target_sig, + int2fl, + is_overlap, + merge_float_intervals, + merge_int_intervals, + tensor_to_list, +) + + +def check_range_values(target, source): + bool_list = [] + for tgt, src in zip(target, source): + for x, y in zip(src, tgt): + bool_list.append(abs(x - y) < 1e-6) + return all(bool_list) + + +def check_labels(target, source): + bool_list = [] + for x, y in zip(target, source): + bool_list.append(abs(x - y) < 1e-6) + return all(bool_list) + + +def matrix(mat, use_tensor=True, dtype=torch.long): + if use_tensor: + mat = torch.Tensor(mat).to(dtype) + else: + mat = np.array(mat) + return mat + +def __get_subsegments(offset: float, window: float, shift: float, duration: float) -> List[List[float]]: + """ + Return subsegments from a segment of audio file + Args: + offset (float): start time of audio segment + window (float): window length for segments to subsegments length + shift (float): hop length for subsegments shift + duration (float): duration of segment + Returns: + subsegments (List[tuple[float, float]]): subsegments generated for the segments as list of tuple of start and duration of each subsegment + """ + subsegments: List[List[float]] = [] + start = offset + slice_end = start + duration + base = math.ceil((duration - window) / shift) + slices = 1 if base < 0 else base + 1 + for slice_id in range(slices): + end = start + window + if end > slice_end: + end = slice_end + subsegments.append([start, end - start]) + start = offset + (slice_id + 1) * shift + return subsegments + + +def generate_orthogonal_embs(total_spks, perturb_sigma, emb_dim): + """Generate a set of artificial orthogonal embedding vectors from random numbers + """ + gaus = torch.randn(emb_dim, emb_dim) + _svd = torch.linalg.svd(gaus) + orth = _svd[0] @ _svd[2] + orth_embs = orth[:total_spks] + # Assert orthogonality + assert torch.abs(getCosAffinityMatrix(orth_embs) - torch.diag(torch.ones(total_spks))).sum() < 1e-4 + return orth_embs + + +def generate_toy_data( + n_spks=2, + spk_dur=3, + emb_dim=192, + perturb_sigma=0.0, + ms_window=[1.5, 1.0, 0.5], + ms_shift=[0.75, 0.5, 0.25], + torch_seed=0, +): + torch.manual_seed(torch_seed) + spk_timestamps = [(spk_dur * k, spk_dur) for k in range(n_spks)] + emb_list, seg_list = [], [] + multiscale_segment_counts = [0 for _ in range(len(ms_window))] + ground_truth = [] + random_orthogonal_embs = generate_orthogonal_embs(n_spks, perturb_sigma, emb_dim) + for scale_idx, (window, shift) in enumerate(zip(ms_window, ms_shift)): + for spk_idx, (offset, dur) in enumerate(spk_timestamps): + segments_stt_dur = get_subsegments(offset=offset, window=window, shift=shift, duration=dur) + segments = [[x[0], x[0] + x[1]] for x in segments_stt_dur] + emb_cent = random_orthogonal_embs[spk_idx, :] + emb = emb_cent.tile((len(segments), 1)) + 0.1 * torch.rand(len(segments), emb_dim) + seg_list.extend(segments) + emb_list.append(emb) + if emb.shape[0] == 0: + import ipdb; ipdb.set_trace() + multiscale_segment_counts[scale_idx] += emb.shape[0] + + if scale_idx == len(multiscale_segment_counts) - 1: + ground_truth.extend([spk_idx] * emb.shape[0]) + + emb_tensor = torch.concat(emb_list) + multiscale_segment_counts = torch.tensor(multiscale_segment_counts) + segm_tensor = torch.tensor(seg_list) + multiscale_weights = torch.ones(len(ms_window)).unsqueeze(0) + ground_truth = torch.tensor(ground_truth) + return emb_tensor, segm_tensor, multiscale_segment_counts, multiscale_weights, spk_timestamps, ground_truth + + +class TestDiarizationSequneceUtilFunctions: + """Tests diarization and speaker-task related utils. + """ + + @pytest.mark.unit + @pytest.mark.parametrize("Y", [[3, 3, 3, 4, 4, 5], [100, 100, 100, 104, 104, 1005]]) + @pytest.mark.parametrize("target", [[0, 0, 0, 1, 1, 2]]) + @pytest.mark.parametrize("offset", [1, 10]) + def test_minimal_index_ex2(self, Y, target, offset): + Y = torch.tensor(Y) + target = torch.tensor(target) + min_Y = get_minimal_indices(Y) + assert check_labels(target, min_Y) + min_Y = get_minimal_indices(Y + offset) + assert check_labels(target, min_Y) + + @pytest.mark.parametrize("Y", [[4, 0, 0, 5, 4, 5], [14, 12, 12, 19, 14, 19]]) + @pytest.mark.parametrize("target", [[1, 0, 0, 2, 1, 2]]) + @pytest.mark.parametrize("offset", [1, 10]) + def test_minimal_index_ex2(self, Y, target, offset): + Y = torch.tensor(Y) + target = torch.tensor(target) + min_Y = get_minimal_indices(Y) + assert check_labels(target, min_Y) + min_Y = get_minimal_indices(Y + offset) + assert check_labels(target, min_Y) + + @pytest.mark.unit + @pytest.mark.parametrize("N", [2, 4, 16, 64]) + def test_minimal_index_same(self, N): + Y = matrix([0] * N + [1] * N + [2] * N) + min_Y = get_minimal_indices(Y) + target = matrix([0] * N + [1] * N + [2] * N) + assert check_labels(target, min_Y) + + @pytest.mark.unit + @pytest.mark.parametrize("N", [2, 4, 16, 64]) + def test_stitch_cluster_labels_label_switch(self, N): + Y_old = matrix([0] * N) + Y_new = matrix([0] * N) + 1 + target = matrix([0] * N) + result = stitch_cluster_labels(Y_old, Y_new) + assert check_labels(target, result) + + @pytest.mark.unit + @pytest.mark.parametrize("N", [2, 4, 16, 64]) + def test_stitch_cluster_labels_label_many_to_one(self, N): + Y_old = matrix(np.arange(N).tolist()) + Y_new = matrix([0] * N) + target = matrix([0] * N) + result = stitch_cluster_labels(Y_old, Y_new) + assert check_labels(target, result) + + @pytest.mark.unit + @pytest.mark.parametrize("N", [2, 4, 16, 64]) + def test_stitch_cluster_labels_label_one_to_many(self, N): + Y_old = matrix(np.arange(N).tolist()) + Y_new = matrix([k for k in range(N)]) + target = matrix([k for k in range(N)]) + result = stitch_cluster_labels(Y_old, Y_new) + assert check_labels(target, result) + + @pytest.mark.unit + @pytest.mark.parametrize("N", [2, 4, 16, 64]) + def test_stitch_cluster_labels_one_label_replaced(self, N): + Y_old = matrix([0] * N + [1] * N + [2] * N) + Y_new = matrix([1] * N + [2] * N + [3] * N) + target = matrix([0] * N + [1] * N + [2] * N) + result = stitch_cluster_labels(Y_old, Y_new) + assert check_labels(target, result) + + @pytest.mark.unit + @pytest.mark.parametrize("N", [2, 4, 16, 64]) + def test_stitch_cluster_labels_confusion_error(self, N): + Y_old = matrix([0] * N + [1] * (N - 1) + [2] * (N + 1)) + Y_new = matrix([1] * N + [2] * N + [3] * N) + target = matrix([0] * N + [1] * N + [2] * N) + result = stitch_cluster_labels(Y_old, Y_new) + assert check_labels(target, result) + + @pytest.mark.unit + @pytest.mark.parametrize("N", [2, 256]) + def test_stitch_cluster_labels_speaker_more_speakers(self, N): + Y_old = matrix([0] * N + [1] * (N - 1) + [2] * (N + 1) + [0, 0, 0]) + Y_new = matrix([1] * N + [0] * N + [2] * N + [4, 5, 6]) + target = matrix([0] * N + [1] * N + [2] * N + [3, 4, 5]) + result = stitch_cluster_labels(Y_old, Y_new) + assert check_labels(target, result) + + @pytest.mark.unit + @pytest.mark.parametrize("N", [2, 256]) + def test_stitch_cluster_labels_speaker_longer_sequence(self, N): + Y_old = matrix([0] * N + [1] * N + [2] * N + [0, 0, 0] * N) + Y_new = matrix([1] * N + [2] * N + [0] * N + [1, 2, 3, 1, 2, 3] * N) + target = matrix([0] * N + [1] * N + [2] * N + [0, 1, 3, 0, 1, 3] * N) + result = stitch_cluster_labels(Y_old, Y_new) + assert check_labels(target, result) + + @pytest.mark.unit + @pytest.mark.parametrize("n_spks", [2, 3, 4, 5]) + @pytest.mark.parametrize("merge_quantity", [2, 3]) + def test_embedding_merger(self, n_spks, merge_quantity): + em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks, spk_dur=5, perturb_sigma=10) + em_s, ts_s = split_input_data(em, ts, mc) + target_speaker_index = 0 + pre_clus_labels = gt + ndx = torch.where(pre_clus_labels == target_speaker_index)[0] + pre_embs = em_s[-1] + affinity_mat = getCosAffinityMatrix(pre_embs) + cmat = affinity_mat[:, ndx][ndx, :] + # Check the dimension of the selected affinity values + assert cmat.shape[0] == cmat.shape[1] == torch.sum(pre_clus_labels == target_speaker_index).item() + index_2d, rest_inds = get_closest_embeddings(cmat, merge_quantity) + # Check the most closest affinity value + assert torch.max(cmat.sum(0)) == cmat.sum(0)[index_2d[0]] + spk_cluster_labels, emb_ndx = pre_clus_labels[ndx], pre_embs[ndx] + merged_embs, merged_clus_labels = merge_vectors(index_2d, emb_ndx, spk_cluster_labels) + # Check the number of merged embeddings and labels + assert (torch.sum(gt == target_speaker_index).item() - merge_quantity) == merged_clus_labels.shape[0] + + @pytest.mark.unit + @pytest.mark.parametrize("n_spks", [1, 8]) + @pytest.mark.parametrize("spk_dur", [0.2, 0.25, 0.5, 1, 10]) + def test_cosine_affinity_calculation(self, n_spks, spk_dur): + em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=spk_dur) + em_s, ts_s = split_input_data(em, ts, mc) + affinity_mat = getCosAffinityMatrix(em_s[-1]) + # affinity_mat should not contain any nan element + assert torch.any(torch.isnan(affinity_mat)) == False + + @pytest.mark.unit + @pytest.mark.parametrize("n_spks", [1, 8]) + @pytest.mark.parametrize("spk_dur", [0.2, 0.25, 0.5, 1, 10]) + def test_cosine_affinity_calculation_scale_interpol(self, n_spks, spk_dur): + em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=spk_dur) + em_s, ts_s = split_input_data(em, ts, mc) + embs, _ = get_scale_interpolated_embs(mw, em_s, ts_s) + affinity_mat = getCosAffinityMatrix(embs) + # affinity_mat should not contain any nan element + assert torch.any(torch.isnan(affinity_mat)) == False + + @pytest.mark.unit + @pytest.mark.parametrize("n_spks", [4, 5, 6]) + @pytest.mark.parametrize("target_speaker_index", [0, 1, 2]) + @pytest.mark.parametrize("merge_quantity", [2, 3]) + def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity): + em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=10) + em_s, ts_s = split_input_data(em, ts, mc) + merged_embs, merged_clus_labels, _ = run_reducer( + pre_embs=em_s[-1], target_spk_idx=target_speaker_index, merge_quantity=merge_quantity, pre_clus_labels=gt, + ) + assert (torch.sum(gt == target_speaker_index).item() - merge_quantity) == merged_clus_labels.shape[0] + + @pytest.mark.unit + @pytest.mark.parametrize("ntbr", [3]) + @pytest.mark.parametrize("pcl", [torch.tensor([0] * 70 + [1] * 32)]) + @pytest.mark.parametrize("mspb", [25]) + def test_merge_scheduler_2clus(self, ntbr, pcl, mspb): + class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + assert all(class_target_vol == torch.tensor([3, 0])) + + @pytest.mark.unit + @pytest.mark.parametrize("ntbr", [3]) + @pytest.mark.parametrize("pcl", [torch.tensor([0] * 80 + [1] * 35 + [2] * 32)]) + @pytest.mark.parametrize("mspb", [0, 25]) + def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): + class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + assert all(class_target_vol == torch.tensor([3, 0, 0])) + + @pytest.mark.unit + @pytest.mark.parametrize("ntbr", [132 - 45]) + @pytest.mark.parametrize("pcl", [torch.tensor([2] * 70 + [0] * 32 + [1] * 27 + [3] * 3)]) + @pytest.mark.parametrize("mspb", [3, 10]) + def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb): + class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + assert all(class_target_vol == torch.tensor([18, 13, 56, 0])) + + @pytest.mark.unit + @pytest.mark.parametrize("ntbr", [3]) + @pytest.mark.parametrize("pcl", [torch.tensor([0] * 5 + [1] * 4 + [2] * 3)]) + @pytest.mark.parametrize("mspb", [0, 2]) + def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): + class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + assert all(class_target_vol == torch.tensor([2, 1, 0])) + + @pytest.mark.unit + @pytest.mark.parametrize("ntbr", [2]) + @pytest.mark.parametrize("pcl", [torch.tensor([0] * 7 + [1] * 5 + [2] * 3 + [3] * 5)]) + @pytest.mark.parametrize("mspb", [2]) + def test_merge_scheduler_3clus_repeat(self, ntbr, pcl, mspb): + class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + assert all(class_target_vol == torch.tensor([2, 0, 0, 0])) + + +class TestClassExport: + @pytest.mark.unit + def test_online_segmentor_class_export(self): + _OnlineSegmentor = torch.jit.script(OnlineSegmentor) + online_segmentor = _OnlineSegmentor(sample_rate=16000) + assert isinstance(online_segmentor, OnlineSegmentor) + + @pytest.mark.unit + def test_online_segmentor_instance_export(self): + online_segmentor = OnlineSegmentor(sample_rate=16000) + online_segmentor = torch.jit.script(online_segmentor) + isinstance(online_segmentor, torch.jit._script.RecursiveScriptClass) + + @pytest.mark.unit + def test_online_speaker_clustering_instance_export(self): + online_clus = OnlineSpeakerClustering( + max_num_speakers=8, + max_rp_threshold=0.15, + sparse_search_volume=30, + history_buffer_size=150, + current_buffer_size=150, + cuda=True, + ) + online_clus = torch.jit.script(online_clus) + isinstance(online_clus, torch.jit._script.RecursiveScriptClass) + + @pytest.mark.unit + def test_online_speaker_clustering_instance_export(self): + offline_speaker_clustering = SpeakerClustering(maj_vote_spk_count=False, min_samples_for_nmesc=0, cuda=True) + offline_speaker_clustering = torch.jit.script(offline_speaker_clustering) + isinstance(offline_speaker_clustering, torch.jit._script.RecursiveScriptClass) + +class TestGetSubsegments: + @pytest.mark.unit + @pytest.mark.parametrize( + "offset, window, shift, duration, min_subsegment_duration, decimals, use_asr_style_frame_count, sample_rate, feat_per_sec, expected", + [ + (12.05, 1.5, 0.75, 2.4, 0.01, 2, False, 16000, 100, [[12.05, 1.5], [12.8, 1.5], [13.55, 0.9]]), + (0, 1.0, 0.5, 0.4, 0.01, 2, False, 16000, 100, [[0, 0.4]]), + (0, 2.0, 1.0, 1.5, 0.5, 2, False, 16000, 100, [[0, 1.5]]), + (10, 1.5, 0.75, 4.5, 0.5, 2, False, 16000, 100, [[10, 1.5], [10.75, 1.5], [11.5, 1.5], [12.25, 1.5], [13.0, 1.5]]), + (0, 1.5, 0.5, 0.3, 0.01, 2, True, 16000, 100, [[0, 0.3]]), + ], + ) + def test_get_subsegments( + self, + offset, + window, + shift, + duration, + min_subsegment_duration, + decimals, + use_asr_style_frame_count, + sample_rate, + feat_per_sec, + expected, + ): + + for is_scriptable in [True, False]: + if is_scriptable: + result = get_subsegments_scriptable( + offset=offset, + window=window, + shift=shift, + duration=duration, + ) + else: + result = get_subsegments( + offset=offset, + window=window, + shift=shift, + duration=duration, + min_subsegment_duration=min_subsegment_duration, + decimals=decimals, + use_asr_style_frame_count=use_asr_style_frame_count, + sample_rate=sample_rate, + feat_per_sec=feat_per_sec, + ) + result_round = [] + for subsegment in result: + result_round.append([round(x, decimals) for x in subsegment]) + assert result_round == expected + + @pytest.mark.unit + def test_min_subsegment_duration_filtering(self): + result = get_subsegments( + offset=0, + window=1.5, + shift=0.5, + duration=3, + min_subsegment_duration=2.0, + decimals=2, + use_asr_style_frame_count=False, + ) + expected = [] # Only subsegments meeting the duration filter should remain + assert result == expected + + @pytest.mark.unit + def test_zero_duration(self): + result = get_subsegments( + offset=0, + window=1.0, + shift=0.5, + duration=0, + min_subsegment_duration=0.01, + decimals=2, + use_asr_style_frame_count=False, + ) + assert result == [] + + @pytest.mark.unit + def test_edge_case_short_slice(self): + result = get_subsegments( + offset=0, + window=0.5, + shift=0.25, # Shift larger than duration + duration=0.25, + min_subsegment_duration=0.01, + decimals=2, + use_asr_style_frame_count=False, + ) + assert result == [[0.0, 0.25]] + + +class TestDiarizationSegmentationUtils: + """ + Test segmentation util functions + """ + + @pytest.mark.unit + @pytest.mark.parametrize( + "intervals", + [ + [[1, 4], [2, 6], [8, 10], [15, 18]], + [[8, 10], [15, 18], [2, 6], [1, 3]], + [[8, 10], [15, 18], [2, 6], [1, 3], [3, 5]], + [[8, 10], [8, 8], [15, 18], [2, 6], [1, 6], [2, 4]], + ], + ) + @pytest.mark.parametrize("target", [[[1, 6], [8, 10], [15, 18]]]) + def test_merge_int_intervals_ex1(self, intervals, target): + merged = merge_int_intervals(intervals) + assert check_range_values(target, merged) + + @pytest.mark.unit + @pytest.mark.parametrize( + "intervals", + [ + [[6, 8], [0, 9], [2, 4], [4, 7]], + [[0, 9], [6, 8], [4, 7], [2, 4]], + [[0, 4], [0, 0], [4, 9], [2, 4]], + [[6, 8], [2, 8], [0, 3], [3, 4], [4, 5], [5, 9]], + ], + ) + @pytest.mark.parametrize("target", [[[0, 9]]]) + def test_merge_int_intervals_ex2(self, intervals, target): + merged = merge_int_intervals(intervals) + assert check_range_values(target, merged) + + @pytest.mark.unit + @pytest.mark.parametrize("intervals", [[[0, 1], [1, 9]], [[0, 0], [0, 9]], [[0, 9], [0, 9]]]) + @pytest.mark.parametrize("target", [[[0, 9]]]) + def test_merge_int_intervals_edge_test(self, intervals, target): + merged = merge_int_intervals(intervals) + assert check_range_values(target, merged) + + @pytest.mark.unit + @pytest.mark.parametrize("rangeA", [[1.0, 2.0]]) + @pytest.mark.parametrize("rangeB", [[0.5, 1.5], [0.9999, 1.0001]]) + def test_is_overlap_true(self, rangeA, rangeB): + assert is_overlap(rangeA, rangeB) + + @pytest.mark.unit + @pytest.mark.parametrize("rangeA", [[1.0, 2.0]]) + @pytest.mark.parametrize("rangeB", [[2.0, 2.5], [-1.0, 1.00]]) + def test_is_overlap_false(self, rangeA, rangeB): + assert not is_overlap(rangeA, rangeB) + + @pytest.mark.unit + @pytest.mark.parametrize("x", [1.0, 2.3456]) + @pytest.mark.parametrize("decimals", [1, 2, 3, 4]) + def test_fl2int(self, x, decimals): + assert fl2int(x, decimals) == round(x * 10 ** decimals, 0) + + @pytest.mark.unit + @pytest.mark.parametrize("x", [1234]) + @pytest.mark.parametrize("decimals", [1, 2, 3, 4,]) + def test_int2fl(self, x, decimals): + assert abs(int2fl(x, decimals) - round(x / (10 ** decimals), decimals)) < (10 ** -(decimals + 1)) + + @pytest.mark.unit + def test_merge_float_intervals_edge_margin_test(self): + intervals = [[0.0, 1.0], [1.0, 2.0]] + + target_0 = [[0.0, 2.0]] + merged_0 = merge_float_intervals(intervals, margin=0) + assert check_range_values(target_0, merged_0) + + target_1 = [[0.0, 1.0], [1.0, 2.0]] + merged_1 = merge_float_intervals(intervals, margin=1) + assert check_range_values(target_1, merged_1) + + target_2 = [[0.0, 1.0], [1.0, 2.0]] + merged_2 = merge_float_intervals(intervals, margin=2) + assert check_range_values(target_2, merged_2) + + @pytest.mark.unit + @pytest.mark.parametrize( + "intervals", + [ + [[0.25, 1.7], [1.5, 3.0], [2.8, 5.0], [5.5, 10.0]], + [[0.25, 5.0], [5.5, 10.0], [1.5, 3.5]], + [[5.5, 8.05], [8.0, 10.0], [0.25, 5.0]], + [[0.25, 3.0], [1.5, 3.0], [5.5, 10.0], [2.8, 5.0]], + [[0.25, 1.7], [1.5, 3.0], [2.8, 5.0], [5.5, 10.0]], + ], + ) + @pytest.mark.parametrize("target", [[[0.25, 5.0], [5.5, 10.0]]]) + def test_merge_float_overlaps(self, intervals, target): + merged = merge_float_intervals(intervals) + assert check_range_values(target, merged) + + @pytest.mark.unit + def test_get_speech_labels_for_update(self): + frame_start = 3.0 + buffer_end = 6.0 + cumulative_speech_labels = torch.tensor([[0.0000, 3.7600]]) + vad_timestamps = torch.tensor([[0.9600, 4.8400]]) + cursor_for_old_segments = 1.0 + speech_labels_for_update, cumulative_speech_labels = get_speech_labels_for_update( + frame_start, buffer_end, cumulative_speech_labels, vad_timestamps, cursor_for_old_segments, + ) + assert (speech_labels_for_update - torch.tensor([[1.0000, 3.7600]])).sum() < 1e-8 + assert (cumulative_speech_labels - torch.tensor([[0.9600, 4.8400]])).sum() < 1e-8 + + # Check if the ranges are containing faulty values + assert check_ranges(speech_labels_for_update) + assert check_ranges(cumulative_speech_labels) + + @pytest.mark.unit + def test_get_online_subsegments_from_buffer(self): + torch.manual_seed(0) + sample_rate = 16000 + speech_labels_for_update = torch.Tensor([[0.0000, 3.7600]]) + audio_buffer = torch.randn(5 * sample_rate) + segment_indexes = [] + window = 2.0 + shift = 1.0 + slice_length = int(window * sample_rate) + range_target = [[0.0, 2.0], [1.0, 3.0], [2.0, 3.76]] + sigs_list, sig_rangel_list, sig_indexes = get_online_subsegments_from_buffer( + buffer_start=0.0, + buffer_end=5.0, + sample_rate=sample_rate, + speech_labels_for_update=speech_labels_for_update, + audio_buffer=audio_buffer, + segment_indexes=segment_indexes, + window=window, + shift=shift, + ) + assert check_range_values(target=range_target, source=sig_rangel_list) + for k, rg in enumerate(sig_rangel_list): + signal = get_target_sig(audio_buffer, rg[0], rg[1], slice_length, sample_rate) + if len(signal) < int(window * sample_rate): + signal = repeat_signal(signal, len(signal), slice_length) + assert len(signal) == int(slice_length), "Length mismatch" + assert (np.abs(signal - sigs_list[k])).sum() < 1e-8, "Audio stream mismatch" + assert (torch.tensor(sig_indexes) - torch.arange(len(range_target))).sum() < 1e-8, "Segment index mismatch" + + @pytest.mark.unit + @pytest.mark.parametrize("frame_start", [3.0]) + @pytest.mark.parametrize("segment_range_ts", [[[0.0, 2.0]]]) + @pytest.mark.parametrize("gt_cursor_for_old_segments", [3.0]) + @pytest.mark.parametrize("gt_cursor_index", [1]) + def test_get_new_cursor_for_update_mulsegs_ex1( + self, frame_start, segment_range_ts, gt_cursor_for_old_segments, gt_cursor_index + ): + cursor_for_old_segments, cursor_index = get_new_cursor_for_update(frame_start, segment_range_ts) + assert cursor_for_old_segments == gt_cursor_for_old_segments + assert cursor_index == gt_cursor_index + + @pytest.mark.unit + @pytest.mark.parametrize("target_range", [[1.0, 4.0]]) + @pytest.mark.parametrize( + "source_range_list", [[[2.0, 3.0], [3.0, 4.0]], [[0.0, 2.0], [3.0, 5.0]], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]] + ) + def get_sub_range_list(self, target_range, source_range_list): + sub_range_list = get_sub_range_list(target_range, source_range_list) + assert sub_range_list == [[2.0, 3.0], [3.0, 4.0]] + + @pytest.mark.unit + @pytest.mark.parametrize("source_range_list", [[[0.0, 2.0]], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]) + def test_tensor_to_list(self, source_range_list): + a_range_tensor = torch.tensor(source_range_list) + converted_list = tensor_to_list(a_range_tensor) + assert source_range_list == converted_list + + @pytest.mark.unit + @pytest.mark.parametrize( + "buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate", + [(0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000), (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000),], + ) + def test_get_online_segments_from_slices( + self, buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate + ): + sig = torch.randn(int(sample_rate * buffer_end)) + ind_offset, sigs_list, sig_rangel_list, sig_indexes = get_online_segments_from_slices( + sig, buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate + ) + assert ind_offset == 2 + assert len(sigs_list) == 2 + assert len(sig_rangel_list) == 2 + assert len(sig_indexes) == 2 + + +class TestClusteringUtilFunctions: + @pytest.mark.parametrize("p_value", [1, 5, 9]) + @pytest.mark.parametrize("N", [9, 20]) + @pytest.mark.parametrize("mask_method", ['binary', 'sigmoid', 'drop']) + def test_get_k_neighbors_connections(self, p_value: int, N: int, mask_method: str, seed=0): + torch.manual_seed(seed) + random_mat = torch.rand(N, N) + affinity_mat = 0.5 * (random_mat + random_mat.T) + affinity_mat = affinity_mat / torch.max(affinity_mat) + binarized_affinity_mat = getKneighborsConnections(affinity_mat, p_value, mask_method) + if mask_method == 'binary': + assert all(binarized_affinity_mat.sum(dim=0) == float(p_value)) + elif mask_method == 'sigmoid': + assert all(binarized_affinity_mat.sum(dim=0) <= float(p_value)) + elif mask_method == 'drop': + assert all(binarized_affinity_mat.sum(dim=0) <= float(p_value)) + + @pytest.mark.unit + @pytest.mark.parametrize("Y_aggr", [torch.tensor([0, 1, 0, 1])]) + @pytest.mark.parametrize("chunk_cluster_count, embeddings_per_chunk", [(2, 50)]) + @pytest.mark.parametrize("window_range_list", [[[0, 1], [1, 2], [2, 3], [3, 4]]]) + @pytest.mark.parametrize( + "absolute_merge_mapping", + [[[torch.tensor([]), torch.tensor([0, 2])], [torch.tensor([]), torch.tensor([1, 3])]]], + ) + @pytest.mark.parametrize("org_len", [4]) + def test_unpack_labels( + self, Y_aggr, window_range_list, absolute_merge_mapping, chunk_cluster_count, embeddings_per_chunk, org_len + ): + expected_result = Y_aggr + longform_speaker_clustering = LongFormSpeakerClustering(cuda=False) + output = longform_speaker_clustering.unpack_labels(Y_aggr, window_range_list, absolute_merge_mapping, org_len) + assert torch.equal(output, expected_result) + + +class TestSpeakerClustering: + """ + Test speaker clustering module + """ + + @pytest.mark.unit + @pytest.mark.parametrize("cuda", [True, False]) + def test_offline_clus_script_save_load(self, cuda): + exported_filename = 'speaker_clustering_script.pt' + speaker_clustering_python = SpeakerClustering(maj_vote_spk_count=False, cuda=cuda) + speaker_clustering_scripted_source = torch.jit.script(speaker_clustering_python) + torch.jit.save(speaker_clustering_scripted_source, exported_filename) + assert os.path.exists(exported_filename) + os.remove(exported_filename) + assert not os.path.exists(exported_filename) + + @pytest.mark.unit + @pytest.mark.parametrize("cuda", [True, False]) + def test_online_clus_script_save_load(self, cuda): + exported_filename = 'speaker_clustering_script.pt' + speaker_clustering_python = OnlineSpeakerClustering( + max_num_speakers=8, + max_rp_threshold=0.15, + sparse_search_volume=30, + history_buffer_size=150, + current_buffer_size=150, + cuda=cuda, + ) + speaker_clustering_scripted_source = torch.jit.script(speaker_clustering_python) + torch.jit.save(speaker_clustering_scripted_source, exported_filename) + assert os.path.exists(exported_filename) + os.remove(exported_filename) + assert not os.path.exists(exported_filename) + + @pytest.mark.run_only_on('GPU') + @pytest.mark.unit + @pytest.mark.parametrize("n_spks", [1, 2, 3, 4, 5, 6, 7]) + @pytest.mark.parametrize("total_sec, SSV, perturb_sigma, seed", [(30, 10, 0.1, 0)]) + @pytest.mark.parametrize("jit_script", [False, True]) + def test_offline_speaker_clustering(self, n_spks, total_sec, SSV, perturb_sigma, seed, jit_script, cuda=True): + spk_dur = total_sec / n_spks + em, ts, mc, mw, spk_ts, gt = generate_toy_data( + n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=perturb_sigma, torch_seed=seed + ) + offline_speaker_clustering = SpeakerClustering(maj_vote_spk_count=False, cuda=cuda) + assert isinstance(offline_speaker_clustering, SpeakerClustering) + if jit_script: + offline_speaker_clustering = torch.jit.script(offline_speaker_clustering) + + Y_out = offline_speaker_clustering.forward_infer( + embeddings_in_scales=em, + timestamps_in_scales=ts, + multiscale_segment_counts=mc, + multiscale_weights=mw, + oracle_num_speakers=-1, + max_num_speakers=8, + enhanced_count_thres=40, + sparse_search_volume=SSV, + max_rp_threshold=0.15, + fixed_thres=-1.0, + ) + permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out) + permuted_Y = permuted_Y.to(gt.device) + # mc[-1] is the number of base scale segments + assert len(set(permuted_Y.tolist())) == n_spks + assert Y_out.shape[0] == mc[-1] + assert all(permuted_Y == gt) + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + @pytest.mark.parametrize("n_spks", [1, 2, 3, 4, 5, 6, 7]) + @pytest.mark.parametrize("total_sec, SSV, perturb_sigma, seed", [(30, 10, 0.1, 0)]) + @pytest.mark.parametrize("jit_script", [False, True]) + def test_offline_speaker_clustering_cpu(self, n_spks, total_sec, SSV, perturb_sigma, seed, jit_script, cuda=False): + self.test_offline_speaker_clustering(n_spks, total_sec, SSV, perturb_sigma, seed, jit_script, cuda=cuda) + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + @pytest.mark.parametrize("n_spks", [1]) + @pytest.mark.parametrize("spk_dur", [0.25, 0.5, 0.75, 1, 1.5, 2]) + @pytest.mark.parametrize("SSV, enhanced_count_thres, min_samples_for_nmesc", [(5, 40, 6)]) + @pytest.mark.parametrize("seed", [0]) + def test_offline_speaker_clustering_very_short_cpu( + self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed, + ): + em, ts, mc, mw, spk_ts, gt = generate_toy_data( + n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed + ) + offline_speaker_clustering = SpeakerClustering(maj_vote_spk_count=False, min_samples_for_nmesc=0, cuda=False) + assert isinstance(offline_speaker_clustering, SpeakerClustering) + Y_out = offline_speaker_clustering.forward_infer( + embeddings_in_scales=em, + timestamps_in_scales=ts, + multiscale_segment_counts=mc, + multiscale_weights=mw, + oracle_num_speakers=-1, + max_num_speakers=8, + enhanced_count_thres=enhanced_count_thres, + sparse_search_volume=SSV, + max_rp_threshold=0.15, + fixed_thres=-1.0, + ) + permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out) + permuted_Y = permuted_Y.to(gt.device) + # mc[-1] is the number of base scale segments + assert len(set(permuted_Y.tolist())) == n_spks + assert Y_out.shape[0] == mc[-1] + assert all(permuted_Y == gt) + + @pytest.mark.run_only_on('GPU') + @pytest.mark.unit + @pytest.mark.parametrize("spk_dur", [0.25, 0.5, 0.75, 1, 2, 4]) + @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(1, 5, 40, 6)]) + @pytest.mark.parametrize("seed", [0]) + def test_offline_speaker_clustering_very_short_gpu( + self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed, + ): + em, ts, mc, mw, spk_ts, gt = generate_toy_data( + n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed + ) + offline_speaker_clustering = SpeakerClustering(maj_vote_spk_count=False, min_samples_for_nmesc=0, cuda=True) + assert isinstance(offline_speaker_clustering, SpeakerClustering) + Y_out = offline_speaker_clustering.forward_infer( + embeddings_in_scales=em, + timestamps_in_scales=ts, + multiscale_segment_counts=mc, + multiscale_weights=mw, + oracle_num_speakers=-1, + max_num_speakers=8, + enhanced_count_thres=enhanced_count_thres, + sparse_search_volume=SSV, + max_rp_threshold=0.15, + fixed_thres=-1.0, + ) + permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out) + permuted_Y = permuted_Y.to(gt.device) + # mc[-1] is the number of base scale segments + assert Y_out.shape[0] == mc[-1] + assert all(permuted_Y == gt) + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(2, 5, 40, 6)]) + @pytest.mark.parametrize("spk_dur, chunk_cluster_count, embeddings_per_chunk", [(120, 4, 50), (240, 4, 100)]) + @pytest.mark.parametrize("seed", [0]) + @pytest.mark.parametrize("jit_script", [False, True]) + def test_longform_speaker_clustering_cpu( + self, + n_spks, + spk_dur, + SSV, + enhanced_count_thres, + min_samples_for_nmesc, + chunk_cluster_count, + embeddings_per_chunk, + jit_script, + seed, + ): + em, ts, mc, mw, spk_ts, gt = generate_toy_data( + n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed + ) + longform_speaker_clustering = LongFormSpeakerClustering(cuda=False) + if jit_script: + longform_speaker_clustering = torch.jit.script(longform_speaker_clustering) + else: + assert isinstance(longform_speaker_clustering, LongFormSpeakerClustering) + Y_out = longform_speaker_clustering.forward_infer( + embeddings_in_scales=em, + timestamps_in_scales=ts, + multiscale_segment_counts=mc, + multiscale_weights=mw, + oracle_num_speakers=-1, + max_num_speakers=n_spks, + enhanced_count_thres=enhanced_count_thres, + sparse_search_volume=SSV, + max_rp_threshold=0.15, + fixed_thres=-1.0, + chunk_cluster_count=chunk_cluster_count, + embeddings_per_chunk=embeddings_per_chunk, + ) + permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out) + permuted_Y = permuted_Y.to(gt.device) + + # mc[-1] is the number of base scale segments + assert Y_out.shape[0] == mc[-1] + assert all(permuted_Y == gt) + + @pytest.mark.run_only_on('GPU') + @pytest.mark.unit + @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(2, 5, 40, 6)]) + @pytest.mark.parametrize("spk_dur, chunk_cluster_count, embeddings_per_chunk", [(120, 4, 50), (240, 4, 100)]) + @pytest.mark.parametrize("seed", [0]) + @pytest.mark.parametrize("jit_script", [False, True]) + def test_longform_speaker_clustering_gpu( + self, + n_spks, + spk_dur, + SSV, + enhanced_count_thres, + min_samples_for_nmesc, + chunk_cluster_count, + embeddings_per_chunk, + jit_script, + seed, + ): + em, ts, mc, mw, spk_ts, gt = generate_toy_data( + n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed + ) + longform_speaker_clustering = LongFormSpeakerClustering(cuda=True) + + if jit_script: + longform_speaker_clustering = torch.jit.script(longform_speaker_clustering) + else: + assert isinstance(longform_speaker_clustering, LongFormSpeakerClustering) + + Y_out = longform_speaker_clustering.forward_infer( + embeddings_in_scales=em, + timestamps_in_scales=ts, + multiscale_segment_counts=mc, + multiscale_weights=mw, + oracle_num_speakers=-1, + max_num_speakers=n_spks, + enhanced_count_thres=enhanced_count_thres, + sparse_search_volume=SSV, + max_rp_threshold=0.15, + fixed_thres=-1.0, + chunk_cluster_count=chunk_cluster_count, + embeddings_per_chunk=embeddings_per_chunk, + ) + permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out) + permuted_Y = permuted_Y.to(gt.device) + + # mc[-1] is the number of base scale segments + assert Y_out.shape[0] == mc[-1] + assert all(permuted_Y == gt) + + @pytest.mark.run_only_on('GPU') + @pytest.mark.unit + @pytest.mark.parametrize("n_spks", [1, 2, 3]) + @pytest.mark.parametrize("total_sec, buffer_size, sigma", [(30, 30, 0.1)]) + @pytest.mark.parametrize("seed", [0]) + @pytest.mark.parametrize("jit_script", [False, True]) + def test_online_speaker_clustering(self, n_spks, total_sec, buffer_size, sigma, seed, jit_script, cuda=True): + step_per_frame = 2 + spk_dur = total_sec / n_spks + em, ts, mc, _, _, gt = generate_toy_data(n_spks, spk_dur=spk_dur, perturb_sigma=sigma, torch_seed=seed) + em_s, ts_s = split_input_data(em, ts, mc) + + emb_gen = em_s[-1] + segment_indexes = ts_s[-1] + if cuda: + device = torch.cuda.current_device() + emb_gen, segment_indexes = emb_gen.to(device), segment_indexes.to(device) + + history_buffer_size = buffer_size + current_buffer_size = buffer_size + + online_clus = OnlineSpeakerClustering( + max_num_speakers=8, + max_rp_threshold=0.15, + sparse_search_volume=30, + history_buffer_size=history_buffer_size, + current_buffer_size=current_buffer_size, + cuda=cuda, + ) + if jit_script: + online_clus = torch.jit.script(online_clus) + + n_frames = int(emb_gen.shape[0] / step_per_frame) + evaluation_list = [] + + # Simulate online speaker clustering + for frame_index in range(n_frames): + curr_emb = emb_gen[0 : (frame_index + 1) * step_per_frame] + base_segment_indexes = torch.arange(curr_emb.shape[0]).to(curr_emb.device) + # Check history_buffer_size and history labels + assert ( + online_clus.history_embedding_buffer_emb.shape[0] <= history_buffer_size + ), "History buffer size error" + assert ( + online_clus.history_embedding_buffer_emb.shape[0] + == online_clus.history_embedding_buffer_label.shape[0] + ) + + # Call clustering function + merged_clus_labels = online_clus.forward_infer( + curr_emb=curr_emb, base_segment_indexes=base_segment_indexes, frame_index=frame_index, cuda=cuda + ) + + # Resolve permutations + assert len(merged_clus_labels) == (frame_index + 1) * step_per_frame + # Resolve permutation issue by using stitch_cluster_labels function + merged_clus_labels = merged_clus_labels.cpu() + merged_clus_labels = stitch_cluster_labels(Y_old=gt[: len(merged_clus_labels)], Y_new=merged_clus_labels) + evaluation_list.extend(list(merged_clus_labels == gt[: len(merged_clus_labels)])) + + assert online_clus.is_online + cumul_label_acc = sum(evaluation_list) / len(evaluation_list) + assert cumul_label_acc > 0.9 + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + @pytest.mark.parametrize("n_spks, total_sec, buffer_size, sigma, seed", [(3, 30, 30, 0.1, 0)]) + @pytest.mark.parametrize("jit_script", [False, True]) + def test_online_speaker_clustering_cpu(self, n_spks, total_sec, buffer_size, sigma, seed, jit_script, cuda=False): + self.test_online_speaker_clustering(n_spks, total_sec, buffer_size, sigma, seed, jit_script, cuda) + + +class TestLinearSumAssignmentAlgorithm: + @pytest.mark.unit + def test_lsa_solver_export_test(self): + cost_matrix = torch.randint(0, 10, (3, 3)) + solver = LinearSumAssignmentSolver(cost_matrix) + solver = torch.jit.script(solver) + assert isinstance(solver, torch.jit._script.RecursiveScriptClass) + + @pytest.mark.unit + @pytest.mark.parametrize( + "cost_matrix", + [torch.tensor([[7, 6, 2, 9, 2], [6, 2, 1, 3, 9], [5, 6, 8, 9, 5], [6, 8, 5, 8, 6], [9, 5, 6, 4, 7]])], + ) + def test_linear_sum_assignment_algorithm_cost_matrix(self, cost_matrix): + """ + Test the linear sum assignment algorithm with a cost matrix + + Compare with the scipy implementation and make sure the final cost is the same. + NOTE: There could be multiple solutions with the same cost in linear sum assignment problem. + This test only checks if the cost is the same. + """ + row_ind_nm, col_ind_nm = nemo_linear_sum_assignment(cost_matrix) + row_ind_sc, col_ind_sc = scipy_linear_sum_assignment(cost_matrix.cpu().numpy()) + cost_nm = sum(cost_matrix[row_ind_nm, col_ind_nm]) + cost_sc = sum(cost_matrix[row_ind_sc, col_ind_sc]) + assert cost_nm == cost_sc + + @pytest.mark.unit + @pytest.mark.parametrize("seed", [0, 1]) + @pytest.mark.parametrize("mat_size", [1, 2, 4, 8]) + def test_linear_sum_assignment_algorithm_random_matrix(self, seed, mat_size): + torch.manual_seed(seed) + cost_matrix = torch.randint(0, 10, (mat_size, mat_size)) + self.test_linear_sum_assignment_algorithm_cost_matrix(cost_matrix) diff --git a/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py b/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py new file mode 100644 index 000000000000..1c45603643d3 --- /dev/null +++ b/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py @@ -0,0 +1,320 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +import pytest +import torch +import itertools +from omegaconf import DictConfig + +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( + find_first_nonzero, + find_best_permutation, + reconstruct_labels, + get_ats_targets, + get_pil_targets, + get_hidden_length_from_sample_length, +) + +def reconstruct_labels_forloop(labels: torch.Tensor, batch_perm_inds: torch.Tensor) -> torch.Tensor: + """ + This is a for-loop implementation of reconstruct_labels built for testing purposes. + """ + # Expanding batch_perm_inds to align with labels dimensions + batch_size, num_frames, num_speakers = labels.shape + batch_perm_inds_exp = batch_perm_inds.unsqueeze(1).expand(-1, num_frames, -1) + + # Reconstructing the labels using advanced indexing + reconstructed_labels = torch.gather(labels, 2, batch_perm_inds_exp) + return reconstructed_labels + +class TestSortingUtils: + @pytest.mark.unit + @pytest.mark.parametrize( + "mat, max_cap_val, thres, expected", + [ + # Test 1: Basic case with clear first nonzero values + (torch.tensor([[0.1, 0.6, 0.0], [0.0, 0.0, 0.9]]), -1, 0.5, torch.tensor([1, 2])), + # Test 2: All elements are below threshold + (torch.tensor([[0.1, 0.2], [0.3, 0.4]]), -1, 0.5, torch.tensor([-1, -1])), + # Test 3: No nonzero elements, should return max_cap_val (-1) + (torch.tensor([[0.0, 0.0], [0.0, 0.0]]), -1, 0.5, torch.tensor([-1, -1])), + # Test 4: Large matrix with mixed values, some rows with all values below threshold + (torch.tensor([[0.1, 0.7, 0.3], [0.0, 0.0, 0.9], [0.5, 0.6, 0.7]]), -1, 0.5, torch.tensor([1, 2, 0])), + # Test 5: Single row matrix + (torch.tensor([[0.0, 0.0, 0.6]]), -1, 0.5, torch.tensor([2])), + # Test 6: Single column matrix + (torch.tensor([[0.1], [0.6], [0.0]]), -1, 0.5, torch.tensor([-1, 0, -1])), + # Test 7: One element matrix + (torch.tensor([[0.501]]), -1, 0.5, torch.tensor([0], dtype=torch.long)), + # Test 8: All values are zero, should return max_cap_val + (torch.tensor([[0.0, 0.0], [0.0, 0.0]]), -1, 0.5, torch.tensor([-1, -1])), + # Test 9: All values are above threshold + (torch.tensor([[0.6, 0.7], [0.8, 0.9]]), -1, 0.5, torch.tensor([0, 0])), + # Test 10: Custom max_cap_val different from default + (torch.tensor([[0.0, 0.0], [0.0, 0.0]]), 99, 0.5, torch.tensor([99, 99])), + # Test 11: Matrix with 101 columns, first nonzero value is towards the end + (torch.cat([torch.zeros(1, 100), torch.ones(1, 1)], dim=1), -1, 0.5, torch.tensor([100])), + # Test 12: Matrix with 1000 columns, all below threshold except one near the middle + (torch.cat([torch.zeros(1, 499), torch.tensor([[0.6]]), torch.zeros(1, 500)], dim=1), -1, 0.5, torch.tensor([499])), + + ] + ) + def test_find_first_nonzero(self, mat, max_cap_val, thres, expected): + result = find_first_nonzero(mat, max_cap_val, thres) + assert torch.equal(result, expected), f"Expected {expected} but got {result}" + + + @pytest.mark.unit + @pytest.mark.parametrize( + "match_score, speaker_permutations, expected", + [ + # Test 1: Simple case with batch size 1, clear best match + ( + torch.tensor([[0.1, 0.9, 0.2]]), # match_score (batch_size=1, num_permutations=3) + torch.tensor([[0, 1], [1, 0], [0, 1]]), # speaker_permutations (num_permutations=3, num_speakers=2) + torch.tensor([[1, 0]]) # expected best permutation for the batch + ), + # Test 2: Batch size 2, different best matches for each batch + ( + torch.tensor([[0.5, 0.3, 0.7], [0.2, 0.6, 0.4]]), # match_score (batch_size=2, num_permutations=3) + torch.tensor([[0, 1], [1, 0], [0, 1]]), # speaker_permutations + torch.tensor([[0, 1], [1, 0]]) # expected best permutations + ), + # Test 3: Larger number of speakers and permutations + ( + torch.tensor([[0.1, 0.4, 0.9, 0.5], [0.6, 0.3, 0.7, 0.2]]), # match_score (batch_size=2, num_permutations=4) + torch.tensor([[0, 1, 2], [1, 0, 2], [2, 1, 0], [1, 2, 0]]), # speaker_permutations (num_permutations=4, num_speakers=3) + torch.tensor([[2, 1, 0], [2, 1, 0]]) # expected best permutations + ), + # Test 4: All match scores are the same, should pick the first permutation (argmax behavior) + ( + torch.tensor([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]), # equal match_score across permutations + torch.tensor([[0, 1], [1, 0], [0, 1]]), # speaker_permutations + torch.tensor([[0, 1], [0, 1]]) # first permutation is chosen as tie-breaker + ), + # Test 5: Single speaker case (num_speakers = 1) + ( + torch.tensor([[0.8, 0.2]]), # match_score (batch_size=1, num_permutations=2) + torch.tensor([[0], [0]]), # speaker_permutations (num_permutations=2, num_speakers=1) + torch.tensor([[0]]) # expected best permutation + ), + # Test 6: Batch size 3, varying permutations + ( + torch.tensor([[0.3, 0.6], [0.4, 0.1], [0.2, 0.7]]), # match_score (batch_size=3, num_permutations=2) + torch.tensor([[0, 1], [1, 0]]), # speaker_permutations + torch.tensor([[1, 0], [0, 1], [1, 0]]) # expected best permutations for each batch + ), + ] + ) + def test_find_best_permutation(self, match_score, speaker_permutations, expected): + result = find_best_permutation(match_score, speaker_permutations) + assert torch.equal(result, expected), f"Expected {expected} but got {result}" + + + @pytest.mark.parametrize("batch_size, num_frames, num_speakers", [ + (2, 4, 3), # Original test case + (3, 5, 2), # More frames and speakers + (1, 6, 4), # Single batch with more frames and speakers + (5, 3, 5), # More batch size with equal frames and speakers + ]) + def test_reconstruct_labels_with_forloop_ver(self, batch_size, num_frames, num_speakers): + # Generate random labels and batch_perm_inds tensor for testing + labels = torch.rand(batch_size, num_frames, num_speakers) + batch_perm_inds = torch.stack([torch.randperm(num_speakers) for _ in range(batch_size)]) + + # Call both functions + result_matrix = reconstruct_labels(labels, batch_perm_inds) + result_forloop = reconstruct_labels_forloop(labels, batch_perm_inds) + + # Assert that both methods return the same result + assert torch.allclose(result_matrix, result_forloop), "The results are not equal!" + + @pytest.mark.parametrize("labels, batch_perm_inds, expected_output", [ + # Example 1: Small batch size with a few frames and speakers + ( + torch.tensor([ + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], # First batch + [[0.9, 0.8, 0.7], [0.6, 0.5, 0.4], [0.3, 0.2, 0.1]] # Second batch + ]), + torch.tensor([[2, 0, 1], [1, 2, 0]]), + torch.tensor([ + [[0.3, 0.1, 0.2], [0.6, 0.4, 0.5], [0.9, 0.7, 0.8]], # First batch reconstructed + [[0.8, 0.7, 0.9], [0.5, 0.4, 0.6], [0.2, 0.1, 0.3]] # Second batch reconstructed + ]) + ), + + # Example 2: batch_size = 1 with more frames and speakers + ( + torch.tensor([ + [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2], [1.3, 1.4, 1.5, 1.6]] + ]), + torch.tensor([[3, 0, 1, 2]]), + torch.tensor([ + [[0.4, 0.1, 0.2, 0.3], [0.8, 0.5, 0.6, 0.7], [1.2, 0.9, 1.0, 1.1], [1.6, 1.3, 1.4, 1.5]] + ]) + ), + + # Example 3: Larger batch size with fewer frames and speakers + ( + torch.tensor([ + [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], # First batch + [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]], # Second batch + [[1.3, 1.4], [1.5, 1.6], [1.7, 1.8]], # Third batch + [[1.9, 2.0], [2.1, 2.2], [2.3, 2.4]] # Fourth batch + ]), + torch.tensor([[1, 0], [0, 1], [1, 0], [0, 1]]), + torch.tensor([ + [[0.2, 0.1], [0.4, 0.3], [0.6, 0.5]], # First batch reconstructed + [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]], # Second batch unchanged + [[1.4, 1.3], [1.6, 1.5], [1.8, 1.7]], # Third batch reconstructed + [[1.9, 2.0], [2.1, 2.2], [2.3, 2.4]] # Fourth batch unchanged + ]) + ) + ]) + def test_reconstruct_labels(self, labels, batch_perm_inds, expected_output): + # Call the reconstruct_labels function + result = reconstruct_labels(labels, batch_perm_inds) + # Assert that the result matches the expected output + assert torch.allclose(result, expected_output), f"Expected {expected_output}, but got {result}" + + + +class TestTargetGenerators: + + @pytest.mark.parametrize("labels, preds, num_speakers, expected_output", [ + # Test 1: Basic case with simple permutations + ( + torch.tensor([ + [[0.9, 0.1, 0.0], [0.1, 0.8, 0.0], [0.0, 0.1, 0.9]], # Batch 1 + [[0.0, 0.0, 0.9], [0.0, 0.9, 0.1], [0.9, 0.1, 0.0]] # Batch 2 + ]), + torch.tensor([ + [[0.8, 0.2, 0.0], [0.2, 0.7, 0.0], [0.0, 0.1, 0.9]], # Batch 1 + [[0.0, 0.0, 0.8], [0.0, 0.8, 0.2], [0.9, 0.1, 0.0]] # Batch 2 + ]), + 3, # Number of speakers + torch.tensor([ + [[0.9, 0.1, 0.0], [0.1, 0.8, 0.0], [0.0, 0.1, 0.9]], # Expected labels for Batch 1 + [[0.9, 0.0, 0.0], [0.1, 0.9, 0.0], [0.0, 0.1, 0.9]] # Expected labels for Batch 2 + ]) + ), + + # Test 2: Ambiguous case + ( + torch.tensor([[[0.9, 0.8, 0.7], [0.2, 0.8, 0.7], [0.2, 0.3, 0.9]]]), # Labels + torch.tensor([[[0.6, 0.7, 0.2], [0.9, 0.4, 0.0], [0.1, 0.7, 0.1]]]), # Preds + 3, # Number of speakers + torch.tensor([[[0.8, 0.7, 0.9], [0.8, 0.7, 0.2], [0.3, 0.9, 0.2]]]) # Expected output + ), + + # Test 3: Ambiguous case + ( + torch.tensor([[[0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]]), # Labels + torch.tensor([[[0.6, 0.6, 0.1, 0.9], [0.7, 0.7, 0.2, 0.8], [0.4, 0.6, 0.2, 0.7], [0.1, 0.1, 0.1, 0.7]]]), # Preds + 4, # Number of speakers + torch.tensor([[[1, 1, 0, 0], [1, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 0]]]) # Expected output + ) + + ]) + def test_get_ats_targets(self, labels, preds, num_speakers, expected_output): + # Generate all permutations for the given number of speakers + speaker_inds = list(range(num_speakers)) + speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) + + # Call the function under test + result = get_ats_targets(labels, preds, speaker_permutations) + # Assert that the result matches the expected output + assert torch.allclose(result, expected_output), f"Expected {expected_output}, but got {result}" + + + @pytest.mark.unit + @pytest.mark.parametrize( + "labels, preds, num_speakers, expected_output", + [ + # Test 1: Basic case with simple permutations + ( + torch.tensor([[[1, 0], [0, 1]], [[1, 0], [0, 1]]]), # Labels (batch_size=2, num_speakers=2, num_classes=2) + torch.tensor([[[1, 0], [0, 1]], [[0, 1], [1, 0]]]), # Preds (batch_size=2, num_speakers=2, num_classes=2) + 2, # Number of speakers + torch.tensor([[[1, 0], [0, 1]], [[0, 1], [1, 0]]]) # expected max_score_permed_labels + ), + + # Test 2: Batch size 1 with more complex permutations + ( + torch.tensor([[[0.8, 0.2], [0.3, 0.7]]]), # Labels + torch.tensor([[[0.9, 0.1], [0.2, 0.8]]]), # Preds + 2, # Number of speakers + torch.tensor([[[0.8, 0.2], [0.3, 0.7]]]) # expected output (labels remain the same as preds are close) + ), + + # Test 3: Ambiguous case + ( + torch.tensor([[[0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]]), # Labels + torch.tensor([[[0.61, 0.6, 0.1, 0.9], [0.7, 0.7, 0.2, 0.8], [0.4, 0.6, 0.2, 0.7], [0.1, 0.1, 0.1, 0.7]]]), # Preds + 4, # Number of speakers + torch.tensor([[[1, 0, 0, 1], [1, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]]]) # Expected output + ) + ] + ) + def test_get_pil_targets(self, labels, preds, num_speakers, expected_output): + # Generate all permutations for the given number of speakers + speaker_inds = list(range(num_speakers)) + speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) + + result = get_pil_targets(labels, preds, speaker_permutations) + assert torch.equal(result, expected_output), f"Expected {expected_output} but got {result}" + + +class TestGetHiddenLengthFromSampleLength: + @pytest.mark.parametrize( + "num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame, expected_hidden_length", + [ + (160, 160, 8, 1), + (1280, 160, 8, 2), + (0, 160, 8, 1), + (159, 160, 8, 1), + (129, 100, 5, 1), + (300, 150, 3, 1), + ] + ) + def test_various_cases(self, num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame, expected_hidden_length): + result = get_hidden_length_from_sample_length(num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame) + assert result == expected_hidden_length + + def test_default_parameters(self): + assert get_hidden_length_from_sample_length(160) == 1 + assert get_hidden_length_from_sample_length(1280) == 2 + assert get_hidden_length_from_sample_length(0) == 1 + assert get_hidden_length_from_sample_length(159) == 1 + + def test_edge_cases(self): + assert get_hidden_length_from_sample_length(159, 160, 8) == 1 + assert get_hidden_length_from_sample_length(160, 160, 8) == 1 + assert get_hidden_length_from_sample_length(161, 160, 8) == 1 + assert get_hidden_length_from_sample_length(1279, 160, 8) == 1 + + def test_real_life_examples(self): + # The samples tried when this function was designed. + assert get_hidden_length_from_sample_length(160000) == 126 + assert get_hidden_length_from_sample_length(159999) == 125 + assert get_hidden_length_from_sample_length(158720) == 125 + assert get_hidden_length_from_sample_length(158719) == 124 + + assert get_hidden_length_from_sample_length(158880) == 125 + assert get_hidden_length_from_sample_length(158879) == 125 + assert get_hidden_length_from_sample_length(1600) == 2 + assert get_hidden_length_from_sample_length(1599) == 2 \ No newline at end of file diff --git a/tests/collections/speaker_tasks/utils/test_vad_utils.py b/tests/collections/speaker_tasks/utils/test_vad_utils.py new file mode 100644 index 000000000000..a7672e1aa43d --- /dev/null +++ b/tests/collections/speaker_tasks/utils/test_vad_utils.py @@ -0,0 +1,126 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +from pyannote.core import Annotation, Segment + +from nemo.collections.asr.parts.utils.vad_utils import ( + align_labels_to_frames, + convert_labels_to_speech_segments, + frame_vad_construct_pyannote_object_per_file, + get_frame_labels, + get_nonspeech_segments, + load_speech_overlap_segments_from_rttm, + load_speech_segments_from_rttm, + read_rttm_as_pyannote_object, +) + + +def get_simple_rttm_without_overlap(rttm_file="test1.rttm"): + line = "SPEAKER 1 0 2 speech \n" + speech_segments = [[0.0, 2.0]] + with open(rttm_file, "w") as f: + f.write(line) + return rttm_file, speech_segments + + +def get_simple_rttm_with_overlap(rttm_file="test2.rttm"): + speech_segments = [[0.0, 3.0]] + overlap_segments = [[1.0, 2.0]] + with open(rttm_file, "w") as f: + f.write("SPEAKER 1 0 2 speech \n") + f.write("SPEAKER 1 1 2 speech \n") + return rttm_file, speech_segments, overlap_segments + + +def get_simple_rttm_with_silence(rttm_file="test3.rttm"): + line = "SPEAKER 1 1 2 speech \n" + speech_segments = [[1.0, 2.0]] + silence_segments = [[0.0, 1.0]] + with open(rttm_file, "w") as f: + f.write(line) + return rttm_file, speech_segments, silence_segments + + +class TestVADUtils: + @pytest.mark.parametrize(["logits_len", "labels_len"], [(20, 10), (20, 11), (20, 9), (10, 21), (10, 19)]) + @pytest.mark.unit + def test_align_label_logits(self, logits_len, labels_len): + logits = np.arange(logits_len).tolist() + labels = np.arange(labels_len).tolist() + labels_new = align_labels_to_frames(probs=logits, labels=labels) + + assert len(labels_new) == len(logits) + + @pytest.mark.unit + def test_load_speech_segments_from_rttm(self, test_data_dir): + rttm_file, speech_segments = get_simple_rttm_without_overlap(test_data_dir + "/test1.rttm") + speech_segments_new = load_speech_segments_from_rttm(rttm_file) + assert speech_segments_new == speech_segments + + @pytest.mark.unit + def test_load_speech_overlap_segments_from_rttm(self, test_data_dir): + rttm_file, speech_segments, overlap_segments = get_simple_rttm_with_overlap(test_data_dir + "/test2.rttm") + speech_segments_new, overlap_segments_new = load_speech_overlap_segments_from_rttm(rttm_file) + assert speech_segments_new == speech_segments + assert overlap_segments_new == overlap_segments + + @pytest.mark.unit + def test_get_nonspeech_segments(self, test_data_dir): + rttm_file, speech_segments, silence_segments = get_simple_rttm_with_silence(test_data_dir + "/test3.rttm") + speech_segments_new = load_speech_segments_from_rttm(rttm_file) + silence_segments_new = get_nonspeech_segments(speech_segments_new) + assert silence_segments_new == silence_segments + + @pytest.mark.unit + def test_get_frame_labels(self, test_data_dir): + rttm_file, speech_segments = get_simple_rttm_without_overlap(test_data_dir + "/test4.rttm") + speech_segments_new = load_speech_segments_from_rttm(rttm_file) + frame_labels = get_frame_labels(speech_segments_new, 0.02, 0.0, 3.0, as_str=False) + assert frame_labels[0] == 1 + assert len(frame_labels) == 150 + + @pytest.mark.unit + def test_convert_labels_to_speech_segments(self, test_data_dir): + rttm_file, speech_segments = get_simple_rttm_without_overlap(test_data_dir + "/test5.rttm") + speech_segments_new = load_speech_segments_from_rttm(rttm_file) + frame_labels = get_frame_labels(speech_segments_new, 0.02, 0.0, 3.0, as_str=False) + speech_segments_new = convert_labels_to_speech_segments(frame_labels, 0.02) + assert speech_segments_new == speech_segments + + @pytest.mark.unit + def test_read_rttm_as_pyannote_object(self, test_data_dir): + rttm_file, speech_segments = get_simple_rttm_without_overlap(test_data_dir + "/test6.rttm") + pyannote_object = read_rttm_as_pyannote_object(rttm_file) + pyannote_object_gt = Annotation() + pyannote_object_gt[Segment(0.0, 2.0)] = 'speech' + assert pyannote_object == pyannote_object_gt + + @pytest.mark.unit + def test_frame_vad_construct_pyannote_object_per_file(self, test_data_dir): + rttm_file, speech_segments = get_simple_rttm_without_overlap(test_data_dir + "/test7.rttm") + # test for rttm input + ref, hyp = frame_vad_construct_pyannote_object_per_file(rttm_file, rttm_file) + pyannote_object_gt = Annotation() + pyannote_object_gt[Segment(0.0, 2.0)] = 'speech' + assert ref == hyp == pyannote_object_gt + + # test for list input + speech_segments = load_speech_segments_from_rttm(rttm_file) + frame_labels = get_frame_labels(speech_segments, 0.02, 0.0, 3.0, as_str=False) + speech_segments_new = convert_labels_to_speech_segments(frame_labels, 0.02) + assert speech_segments_new == speech_segments + ref, hyp = frame_vad_construct_pyannote_object_per_file(frame_labels, frame_labels, 0.02) + assert ref == hyp == pyannote_object_gt From ca44a6635351fa360c5c880866baab28304ea1b0 Mon Sep 17 00:00:00 2001 From: taejinp Date: Tue, 19 Nov 2024 18:33:48 -0800 Subject: [PATCH 23/47] Moving speaker task related unit test files to speaker_tasks folder Signed-off-by: taejinp --- .../collections/asr/test_diar_label_models.py | 167 --- tests/collections/asr/test_diar_metrics.py | 197 ---- .../asr/test_diar_neural_inference.py | 71 -- tests/collections/asr/test_diar_utils.py | 974 ------------------ .../test_speaker_label_models.py | 0 5 files changed, 1409 deletions(-) delete mode 100644 tests/collections/asr/test_diar_label_models.py delete mode 100644 tests/collections/asr/test_diar_metrics.py delete mode 100644 tests/collections/asr/test_diar_neural_inference.py delete mode 100644 tests/collections/asr/test_diar_utils.py rename tests/collections/{asr => speaker_tasks}/test_speaker_label_models.py (100%) diff --git a/tests/collections/asr/test_diar_label_models.py b/tests/collections/asr/test_diar_label_models.py deleted file mode 100644 index 2ed6177d3cb2..000000000000 --- a/tests/collections/asr/test_diar_label_models.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import torch -from omegaconf import DictConfig - -from nemo.collections.asr.models import EncDecDiarLabelModel - - -@pytest.fixture() -def msdd_model(): - - preprocessor = { - 'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', - 'params': {"features": 80, "window_size": 0.025, "window_stride": 0.01, "sample_rate": 16000,}, - } - - speaker_model_encoder = { - 'cls': 'nemo.collections.asr.modules.ConvASREncoder', - 'params': { - 'feat_in': 80, - 'activation': 'relu', - 'conv_mask': True, - 'jasper': [ - { - 'filters': 512, - 'repeat': 1, - 'kernel': [1], - 'stride': [1], - 'dilation': [1], - 'dropout': 0.0, - 'residual': False, - 'separable': False, - } - ], - }, - } - - speaker_model_decoder = { - 'cls': 'nemo.collections.asr.modules.SpeakerDecoder', - 'params': {'feat_in': 512, 'num_classes': 2, 'pool_mode': 'xvector', 'emb_sizes': [1024]}, - } - - speaker_model_cfg = DictConfig( - { - 'preprocessor': DictConfig(preprocessor), - 'encoder': DictConfig(speaker_model_encoder), - 'decoder': DictConfig(speaker_model_decoder), - } - ) - - msdd_module = { - 'cls': 'nemo.collections.asr.modules.MSDD_module', - 'params': { - "num_spks": 2, - "hidden_size": 256, - "num_lstm_layers": 3, - "dropout_rate": 0.5, - "cnn_output_ch": 32, - "conv_repeat": 2, - "emb_dim": 192, - "scale_n": 5, - "weighting_scheme": 'conv_scale_weight', - "context_vector_type": 'cos_sim', - }, - } - - loss = {'cls': 'nemo.collections.asr.losses.bce_loss.BCELoss', 'params': {"weight": None}} - - diarizer = { - 'out_dir': None, - 'oracle_vad': True, - "speaker_embeddings": { - "model_path": None, - "parameters": { - "window_length_in_sec": [1.5, 1.25, 1.0, 0.75, 0.5], - "shift_length_in_sec": [0.75, 0.625, 0.5, 0.375, 0.25], - "multiscale_weights": [1, 1, 1, 1, 1], - "save_embeddings": True, - }, - }, - } - - modelConfig = DictConfig( - { - 'msdd_module': DictConfig(msdd_module), - 'preprocessor': DictConfig(preprocessor), - 'diarizer': DictConfig(diarizer), - 'loss': DictConfig(loss), - 'max_num_of_spks': 2, - 'num_workers': 5, - 'emb_batch_size': 0, - 'soft_label_thres': 0.5, - 'scale_n': 5, - 'speaker_model_cfg': speaker_model_cfg, - } - ) - model = EncDecDiarLabelModel(cfg=modelConfig) - return model - - -class TestEncDecDiarLabelModel: - @pytest.mark.unit - def test_constructor(self, msdd_model): - diar_model = msdd_model.train() - assert diar_model.cfg.scale_n == len( - diar_model.cfg.diarizer.speaker_embeddings.parameters.window_length_in_sec - ) - assert diar_model.cfg.scale_n == len(diar_model.cfg.diarizer.speaker_embeddings.parameters.shift_length_in_sec) - assert diar_model.cfg.scale_n == len(diar_model.cfg.diarizer.speaker_embeddings.parameters.multiscale_weights) - assert diar_model.cfg.msdd_module.num_spks == diar_model.cfg.max_num_of_spks - # TODO: make proper config and assert correct number of weights - # Check to/from config_dict: - confdict = diar_model.to_config_dict() - instance2 = EncDecDiarLabelModel.from_config_dict(confdict) - assert isinstance(instance2, EncDecDiarLabelModel) - - @pytest.mark.unit - def test_forward_infer(self, msdd_model): - diar_model = msdd_model.eval() - - # batch_size 4, scale_n 5, length 25, emb_dim 192 - input_signal = torch.randn(size=(4, 25, 5, 192)) - input_signal_length = 25 * torch.ones(4, dtype=torch.int) - emb_vectors = torch.randn(size=(4, 5, 192, 2)) - targets = torch.randint(2, size=(4, 25, 2), dtype=torch.int) - - with torch.no_grad(): - # batch size 1 - preds_list, scale_weights_list = [], [] - for i in range(input_signal.size(0)): - preds, scale_weights = diar_model.forward_infer( - input_signal[i : i + 1], input_signal_length[i : i + 1], emb_vectors[i : i + 1], targets[i : i + 1] - ) - preds_list.append(preds) - scale_weights_list.append(scale_weights) - preds_instance = torch.cat(preds_list, 0) - scale_weights_instance = torch.cat(scale_weights_list, 0) - - # batch size 4 - preds_batch, scale_weights_batch = diar_model.forward_infer( - input_signal, input_signal_length, emb_vectors, targets - ) - - assert preds_instance.shape == preds_batch.shape - assert scale_weights_instance.shape == scale_weights_batch.shape - - diff = torch.mean(torch.abs(preds_instance - preds_batch)) - assert diff <= 1e-6 - diff = torch.max(torch.abs(preds_instance - preds_batch)) - assert diff <= 1e-6 - diff = torch.mean(torch.abs(scale_weights_instance - scale_weights_batch)) - assert diff <= 1e-6 - diff = torch.max(torch.abs(scale_weights_instance - scale_weights_batch)) - assert diff <= 1e-6 diff --git a/tests/collections/asr/test_diar_metrics.py b/tests/collections/asr/test_diar_metrics.py deleted file mode 100644 index 3ae6f6f6a3fa..000000000000 --- a/tests/collections/asr/test_diar_metrics.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from itertools import permutations - -import pytest -import torch - -from nemo.collections.asr.metrics.der import ( - calculate_session_cpWER, - calculate_session_cpWER_bruteforce, - get_online_DER_stats, - get_partial_ref_labels, -) - - -def word_count(spk_transcript): - return sum([len(w.split()) for w in spk_transcript]) - - -def calculate_wer_count(_ins, _del, _sub, ref_word_count): - return (_ins + _del + _sub) / ref_word_count - - -def permuted_input_test(hyp, ref, calculated): - """ - Randomly permute the input to see if evaluation result stays the same. - """ - for hyp_permed in permutations(hyp): - cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp_permed, spk_reference=ref) - diff = torch.abs(torch.tensor(calculated - cpWER)) - assert diff <= 1e-6 - - -class TestConcatMinPermWordErrorRate: - """ - Tests for cpWER calculation. - """ - - @pytest.mark.unit - def test_cpwer_oneword(self): - hyp = ["oneword"] - ref = ["oneword"] - _ins, _del, _sub = 0, 0, 0 - cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref) - ref_word_count = word_count(ref) - calculated = calculate_wer_count(_ins, _del, _sub, ref_word_count) - diff = torch.abs(torch.tensor(calculated - cpWER)) - assert diff <= 1e-6 - permuted_input_test(hyp, ref, calculated) - cpWER_perm, hyp_min_perm, ref_str = calculate_session_cpWER_bruteforce(spk_hypothesis=hyp, spk_reference=ref) - diff = torch.abs(torch.tensor(cpWER_perm - cpWER)) - assert diff <= 1e-6 - - # Test with a substitution - hyp = ["wrongword"] - _ins, _del, _sub = 0, 0, 1 - cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref) - calculated = calculate_wer_count(_ins, _del, _sub, ref_word_count) - diff = torch.abs(torch.tensor(calculated - cpWER)) - assert diff <= 1e-6 - permuted_input_test(hyp, ref, calculated) - cpWER_perm, hyp_min_perm, ref_str = calculate_session_cpWER_bruteforce(spk_hypothesis=hyp, spk_reference=ref) - diff = torch.abs(torch.tensor(cpWER_perm - cpWER)) - assert diff <= 1e-6 - - @pytest.mark.unit - def test_cpwer_perfect(self): - hyp = ["ff", "aa bb cc", "dd ee"] - ref = ["aa bb cc", "dd ee", "ff"] - cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref) - calculated = 0 - diff = torch.abs(torch.tensor(calculated - cpWER)) - assert diff <= 1e-6 - permuted_input_test(hyp, ref, calculated) - - @pytest.mark.unit - def test_cpwer_spk_counfusion_and_asr_error(self): - hyp = ["aa bb c ff", "dd e ii jj kk", "hi"] - ref = ["aa bb cc ff", "dd ee gg jj kk", "hh ii"] - _ins, _del, _sub = 0, 1, 4 - cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref) - ref_word_count = word_count(ref) - calculated = calculate_wer_count(_ins, _del, _sub, ref_word_count) - diff = torch.abs(torch.tensor(calculated - cpWER)) - assert diff <= 1e-6 - permuted_input_test(hyp, ref, calculated) - cpWER_perm, hyp_min_perm, ref_str = calculate_session_cpWER_bruteforce(spk_hypothesis=hyp, spk_reference=ref) - diff = torch.abs(torch.tensor(cpWER_perm - cpWER)) - assert diff <= 1e-6 - - @pytest.mark.unit - def test_cpwer_undercount(self): - hyp = ["aa bb cc", "dd ee gg", "hh ii", "jj kk"] - ref = ["aa bb cc", "dd ee", "ff", "gg", "hh ii", "jj kk"] - _ins, _del, _sub = 0, 1, 0 - cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref) - ref_word_count = word_count(ref) - calculated = calculate_wer_count(_ins, _del, _sub, ref_word_count) - diff = torch.abs(torch.tensor(calculated - cpWER)) - assert diff <= 1e-6 - cpWER_perm, hyp_min_perm, ref_str = calculate_session_cpWER_bruteforce(spk_hypothesis=hyp, spk_reference=ref) - diff = torch.abs(torch.tensor(cpWER_perm - cpWER)) - assert diff <= 1e-6 - - @pytest.mark.unit - def test_cpwer_overcount(self): - hyp = ["aa bb cc", "dd ee gg hh", "ii jj kk"] - ref = ["aa bb cc", "dd ee ff gg hh ii jj kk"] - _ins, _del, _sub = 0, 1, 0 - cpWER, hyp_min, ref_str = calculate_session_cpWER(spk_hypothesis=hyp, spk_reference=ref) - ref_word_count = word_count(ref) - calculated = calculate_wer_count(_ins, _del, _sub, ref_word_count) - diff = torch.abs(torch.tensor(calculated - cpWER)) - assert diff <= 1e-6 - cpWER_perm, hyp_min_perm, ref_str = calculate_session_cpWER_bruteforce(spk_hypothesis=hyp, spk_reference=ref) - diff = torch.abs(torch.tensor(cpWER_perm - cpWER)) - assert diff <= 1e-6 - - @pytest.mark.parametrize( - "pred_labels, ref_labels, expected_output", - [ - ([], [], []), - (["0.0 1.0 speaker1"], [], []), - (["0.0 1.0 speaker1"], ["0.0 1.5 speaker1"], ["0.0 1.0 speaker1"]), - (["0.1 0.4 speaker1", "0.5 1.0 speaker2"], ["0.0 1.5 speaker1"], ["0.0 1.0 speaker1"]), - ( - ["0.5 1.0 speaker2", "0.1 0.4 speaker1"], - ["0.0 1.5 speaker1"], - ["0.0 1.0 speaker1"], - ), # Order of prediction does not matter - ( - ["0.1 1.4 speaker1", "0.5 1.0 speaker2"], - ["0.0 1.5 speaker1"], - ["0.0 1.4 speaker1"], - ), # Overlapping prediction - ( - ["0.1 0.6 speaker1", "0.2 1.5 speaker2"], - ["0.5 1.0 speaker1", "1.01 2.0 speaker2"], - ["0.5 1.0 speaker1", "1.01 1.5 speaker2"], - ), - ( - ["0.0 2.0 speaker1"], - ["0.0 2.0 speaker1", "1.0 3.0 speaker2", "0.0 5.0 speaker3"], - ["0.0 2.0 speaker1", "1.0 2.0 speaker2", "0.0 2.0 speaker3"], - ), - ], - ) - def test_get_partial_ref_labels(self, pred_labels, ref_labels, expected_output): - assert get_partial_ref_labels(pred_labels, ref_labels) == expected_output - - @pytest.mark.parametrize( - "DER, CER, FA, MISS, diar_eval_count, der_stat_dict, deci, expected_der_dict, expected_der_stat_dict", - [ - ( - 0.3, - 0.1, - 0.05, - 0.15, - 1, - {"cum_DER": 0, "cum_CER": 0, "avg_DER": 0, "avg_CER": 0, "max_DER": 0, "max_CER": 0}, - 3, - {"DER": 30.0, "CER": 10.0, "FA": 5.0, "MISS": 15.0}, - {"cum_DER": 0.3, "cum_CER": 0.1, "avg_DER": 30.0, "avg_CER": 10.0, "max_DER": 30.0, "max_CER": 10.0}, - ), - ( - 0.1, - 0.2, - 0.03, - 0.07, - 2, - {"cum_DER": 0.3, "cum_CER": 0.3, "avg_DER": 15.0, "avg_CER": 15.0, "max_DER": 30.0, "max_CER": 10.0}, - 2, - {"DER": 10.0, "CER": 20.0, "FA": 3.0, "MISS": 7.0}, - {"cum_DER": 0.4, "cum_CER": 0.5, "avg_DER": 20.0, "avg_CER": 25.0, "max_DER": 30.0, "max_CER": 20.0}, - ), - ], - ) - def test_get_online_DER_stats( - self, DER, CER, FA, MISS, diar_eval_count, der_stat_dict, deci, expected_der_dict, expected_der_stat_dict - ): - actual_der_dict, actual_der_stat_dict = get_online_DER_stats( - DER, CER, FA, MISS, diar_eval_count, der_stat_dict, deci - ) - assert actual_der_dict == expected_der_dict - assert actual_der_stat_dict == expected_der_stat_dict diff --git a/tests/collections/asr/test_diar_neural_inference.py b/tests/collections/asr/test_diar_neural_inference.py deleted file mode 100644 index 076eac129293..000000000000 --- a/tests/collections/asr/test_diar_neural_inference.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import pytest -import torch - -from nemo.collections.asr.models.msdd_models import NeuralDiarizer - - -class TestNeuralDiarizerInference: - @pytest.mark.unit - @pytest.mark.parametrize( - "device", - [ - torch.device("cpu"), - pytest.param( - torch.device("cuda"), - marks=pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA required for test.',), - ), - ], - ) - @pytest.mark.parametrize("num_speakers", [None, 1]) - @pytest.mark.parametrize("max_num_speakers", [4]) - def test_diar_inference(self, tmpdir, test_data_dir, device, num_speakers, max_num_speakers): - """ - Test to ensure diarization inference works correctly. - - Ensures multiple audio files can be diarized sequentially - - Ensures both CPU/CUDA is supported - - Ensures that max speaker and num speaker are set correctly - - Ensures temporary directory is emptied at the end of diarization - - Sanity check to ensure outputs from diarization are reasonable - """ - audio_filenames = ['an22-flrp-b.wav', 'an90-fbbh-b.wav'] - audio_paths = [os.path.join(test_data_dir, "asr", "train", "an4", "wav", fp) for fp in audio_filenames] - - diarizer = NeuralDiarizer.from_pretrained(model_name='diar_msdd_telephonic').to(device) - - out_dir = os.path.join(tmpdir, 'diarize_inference/') - - assert diarizer.msdd_model.device.type == device.type - assert diarizer._speaker_model.device.type == device.type - for audio_path in audio_paths: - annotation = diarizer( - audio_path, num_speakers=num_speakers, max_speakers=max_num_speakers, out_dir=out_dir - ) - - # assert max speakers has been set up correctly - assert diarizer.clustering_embedding.clus_diar_model._cluster_params.max_num_speakers == max_num_speakers - - if num_speakers: - assert diarizer._cfg.diarizer.clustering.parameters.oracle_num_speakers - - # assert all temporary files are cleaned up - assert len(os.listdir(out_dir)) == 0 - - # assert only 1 speaker & segment - assert len(annotation.labels()) == 1 - assert len(list(annotation.itersegments())) == 1 diff --git a/tests/collections/asr/test_diar_utils.py b/tests/collections/asr/test_diar_utils.py deleted file mode 100644 index cb364675fcf4..000000000000 --- a/tests/collections/asr/test_diar_utils.py +++ /dev/null @@ -1,974 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import numpy as np -import pytest -import torch -from scipy.optimize import linear_sum_assignment as scipy_linear_sum_assignment - -from nemo.collections.asr.data.audio_to_label import repeat_signal -from nemo.collections.asr.parts.utils.longform_clustering import LongFormSpeakerClustering -from nemo.collections.asr.parts.utils.offline_clustering import ( - SpeakerClustering, - get_scale_interpolated_embs, - getCosAffinityMatrix, - getKneighborsConnections, - split_input_data, -) -from nemo.collections.asr.parts.utils.online_clustering import ( - OnlineSpeakerClustering, - get_closest_embeddings, - get_merge_quantity, - get_minimal_indices, - merge_vectors, - run_reducer, - stitch_cluster_labels, -) -from nemo.collections.asr.parts.utils.optimization_utils import LinearSumAssignmentSolver -from nemo.collections.asr.parts.utils.optimization_utils import linear_sum_assignment as nemo_linear_sum_assignment -from nemo.collections.asr.parts.utils.speaker_utils import ( - OnlineSegmentor, - check_ranges, - fl2int, - get_new_cursor_for_update, - get_online_segments_from_slices, - get_online_subsegments_from_buffer, - get_speech_labels_for_update, - get_sub_range_list, - get_subsegments_scriptable, - get_target_sig, - int2fl, - is_overlap, - merge_float_intervals, - merge_int_intervals, - tensor_to_list, -) - - -def check_range_values(target, source): - bool_list = [] - for tgt, src in zip(target, source): - for x, y in zip(src, tgt): - bool_list.append(abs(x - y) < 1e-6) - return all(bool_list) - - -def check_labels(target, source): - bool_list = [] - for x, y in zip(target, source): - bool_list.append(abs(x - y) < 1e-6) - return all(bool_list) - - -def matrix(mat, use_tensor=True, dtype=torch.long): - if use_tensor: - mat = torch.Tensor(mat).to(dtype) - else: - mat = np.array(mat) - return mat - - -def generate_orthogonal_embs(total_spks, perturb_sigma, emb_dim): - """Generate a set of artificial orthogonal embedding vectors from random numbers""" - gaus = torch.randn(emb_dim, emb_dim) - _svd = torch.linalg.svd(gaus) - orth = _svd[0] @ _svd[2] - orth_embs = orth[:total_spks] - # Assert orthogonality - assert torch.abs(getCosAffinityMatrix(orth_embs) - torch.diag(torch.ones(total_spks))).sum() < 1e-4 - return orth_embs - - -def generate_toy_data( - n_spks=2, - spk_dur=3, - emb_dim=192, - perturb_sigma=0.0, - ms_window=[1.5, 1.0, 0.5], - ms_shift=[0.75, 0.5, 0.25], - torch_seed=0, -): - torch.manual_seed(torch_seed) - spk_timestamps = [(spk_dur * k, spk_dur) for k in range(n_spks)] - emb_list, seg_list = [], [] - multiscale_segment_counts = [0 for _ in range(len(ms_window))] - ground_truth = [] - random_orthogonal_embs = generate_orthogonal_embs(n_spks, perturb_sigma, emb_dim) - for scale_idx, (window, shift) in enumerate(zip(ms_window, ms_shift)): - for spk_idx, (offset, dur) in enumerate(spk_timestamps): - segments_stt_dur = get_subsegments_scriptable(offset=offset, window=window, shift=shift, duration=dur) - segments = [[x[0], x[0] + x[1]] for x in segments_stt_dur] - emb_cent = random_orthogonal_embs[spk_idx, :] - emb = emb_cent.tile((len(segments), 1)) + 0.1 * torch.rand(len(segments), emb_dim) - seg_list.extend(segments) - emb_list.append(emb) - multiscale_segment_counts[scale_idx] += emb.shape[0] - - if scale_idx == len(multiscale_segment_counts) - 1: - ground_truth.extend([spk_idx] * emb.shape[0]) - - emb_tensor = torch.concat(emb_list) - multiscale_segment_counts = torch.tensor(multiscale_segment_counts) - segm_tensor = torch.tensor(seg_list) - multiscale_weights = torch.ones(len(ms_window)).unsqueeze(0) - ground_truth = torch.tensor(ground_truth) - return emb_tensor, segm_tensor, multiscale_segment_counts, multiscale_weights, spk_timestamps, ground_truth - - -class TestDiarizationSequneceUtilFunctions: - """Tests diarization and speaker-task related utils.""" - - @pytest.mark.unit - @pytest.mark.parametrize("Y", [[3, 3, 3, 4, 4, 5], [100, 100, 100, 104, 104, 1005]]) - @pytest.mark.parametrize("target", [[0, 0, 0, 1, 1, 2]]) - @pytest.mark.parametrize("offset", [1, 10]) - def test_minimal_index_ex2(self, Y, target, offset): - Y = torch.tensor(Y) - target = torch.tensor(target) - min_Y = get_minimal_indices(Y) - assert check_labels(target, min_Y) - min_Y = get_minimal_indices(Y + offset) - assert check_labels(target, min_Y) - - @pytest.mark.parametrize("Y", [[4, 0, 0, 5, 4, 5], [14, 12, 12, 19, 14, 19]]) - @pytest.mark.parametrize("target", [[1, 0, 0, 2, 1, 2]]) - @pytest.mark.parametrize("offset", [1, 10]) - def test_minimal_index_ex2(self, Y, target, offset): - Y = torch.tensor(Y) - target = torch.tensor(target) - min_Y = get_minimal_indices(Y) - assert check_labels(target, min_Y) - min_Y = get_minimal_indices(Y + offset) - assert check_labels(target, min_Y) - - @pytest.mark.unit - @pytest.mark.parametrize("N", [2, 4, 16, 64]) - def test_minimal_index_same(self, N): - Y = matrix([0] * N + [1] * N + [2] * N) - min_Y = get_minimal_indices(Y) - target = matrix([0] * N + [1] * N + [2] * N) - assert check_labels(target, min_Y) - - @pytest.mark.unit - @pytest.mark.parametrize("N", [2, 4, 16, 64]) - def test_stitch_cluster_labels_label_switch(self, N): - Y_old = matrix([0] * N) - Y_new = matrix([0] * N) + 1 - target = matrix([0] * N) - result = stitch_cluster_labels(Y_old, Y_new) - assert check_labels(target, result) - - @pytest.mark.unit - @pytest.mark.parametrize("N", [2, 4, 16, 64]) - def test_stitch_cluster_labels_label_many_to_one(self, N): - Y_old = matrix(np.arange(N).tolist()) - Y_new = matrix([0] * N) - target = matrix([0] * N) - result = stitch_cluster_labels(Y_old, Y_new) - assert check_labels(target, result) - - @pytest.mark.unit - @pytest.mark.parametrize("N", [2, 4, 16, 64]) - def test_stitch_cluster_labels_label_one_to_many(self, N): - Y_old = matrix(np.arange(N).tolist()) - Y_new = matrix([k for k in range(N)]) - target = matrix([k for k in range(N)]) - result = stitch_cluster_labels(Y_old, Y_new) - assert check_labels(target, result) - - @pytest.mark.unit - @pytest.mark.parametrize("N", [2, 4, 16, 64]) - def test_stitch_cluster_labels_one_label_replaced(self, N): - Y_old = matrix([0] * N + [1] * N + [2] * N) - Y_new = matrix([1] * N + [2] * N + [3] * N) - target = matrix([0] * N + [1] * N + [2] * N) - result = stitch_cluster_labels(Y_old, Y_new) - assert check_labels(target, result) - - @pytest.mark.unit - @pytest.mark.parametrize("N", [2, 4, 16, 64]) - def test_stitch_cluster_labels_confusion_error(self, N): - Y_old = matrix([0] * N + [1] * (N - 1) + [2] * (N + 1)) - Y_new = matrix([1] * N + [2] * N + [3] * N) - target = matrix([0] * N + [1] * N + [2] * N) - result = stitch_cluster_labels(Y_old, Y_new) - assert check_labels(target, result) - - @pytest.mark.unit - @pytest.mark.parametrize("N", [2, 256]) - def test_stitch_cluster_labels_speaker_more_speakers(self, N): - Y_old = matrix([0] * N + [1] * (N - 1) + [2] * (N + 1) + [0, 0, 0]) - Y_new = matrix([1] * N + [0] * N + [2] * N + [4, 5, 6]) - target = matrix([0] * N + [1] * N + [2] * N + [3, 4, 5]) - result = stitch_cluster_labels(Y_old, Y_new) - assert check_labels(target, result) - - @pytest.mark.unit - @pytest.mark.parametrize("N", [2, 256]) - def test_stitch_cluster_labels_speaker_longer_sequence(self, N): - Y_old = matrix([0] * N + [1] * N + [2] * N + [0, 0, 0] * N) - Y_new = matrix([1] * N + [2] * N + [0] * N + [1, 2, 3, 1, 2, 3] * N) - target = matrix([0] * N + [1] * N + [2] * N + [0, 1, 3, 0, 1, 3] * N) - result = stitch_cluster_labels(Y_old, Y_new) - assert check_labels(target, result) - - @pytest.mark.unit - @pytest.mark.parametrize("n_spks", [2, 3, 4, 5]) - @pytest.mark.parametrize("merge_quantity", [2, 3]) - def test_embedding_merger(self, n_spks, merge_quantity): - em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks, spk_dur=5, perturb_sigma=10) - em_s, ts_s = split_input_data(em, ts, mc) - target_speaker_index = 0 - pre_clus_labels = gt - ndx = torch.where(pre_clus_labels == target_speaker_index)[0] - pre_embs = em_s[-1] - affinity_mat = getCosAffinityMatrix(pre_embs) - cmat = affinity_mat[:, ndx][ndx, :] - # Check the dimension of the selected affinity values - assert cmat.shape[0] == cmat.shape[1] == torch.sum(pre_clus_labels == target_speaker_index).item() - index_2d, rest_inds = get_closest_embeddings(cmat, merge_quantity) - # Check the most closest affinity value - assert torch.max(cmat.sum(0)) == cmat.sum(0)[index_2d[0]] - spk_cluster_labels, emb_ndx = pre_clus_labels[ndx], pre_embs[ndx] - merged_embs, merged_clus_labels = merge_vectors(index_2d, emb_ndx, spk_cluster_labels) - # Check the number of merged embeddings and labels - assert (torch.sum(gt == target_speaker_index).item() - merge_quantity) == merged_clus_labels.shape[0] - - @pytest.mark.unit - @pytest.mark.parametrize("n_spks", [1, 8]) - @pytest.mark.parametrize("spk_dur", [0.2, 0.25, 0.5, 1, 10]) - def test_cosine_affinity_calculation(self, n_spks, spk_dur): - em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=spk_dur) - em_s, ts_s = split_input_data(em, ts, mc) - affinity_mat = getCosAffinityMatrix(em_s[-1]) - # affinity_mat should not contain any nan element - assert torch.any(torch.isnan(affinity_mat)) == False - - @pytest.mark.unit - @pytest.mark.parametrize("n_spks", [1, 8]) - @pytest.mark.parametrize("spk_dur", [0.2, 0.25, 0.5, 1, 10]) - def test_cosine_affinity_calculation_scale_interpol(self, n_spks, spk_dur): - em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=spk_dur) - em_s, ts_s = split_input_data(em, ts, mc) - embs, _ = get_scale_interpolated_embs(mw, em_s, ts_s) - affinity_mat = getCosAffinityMatrix(embs) - # affinity_mat should not contain any nan element - assert torch.any(torch.isnan(affinity_mat)) == False - - @pytest.mark.unit - @pytest.mark.parametrize("n_spks", [4, 5, 6]) - @pytest.mark.parametrize("target_speaker_index", [0, 1, 2]) - @pytest.mark.parametrize("merge_quantity", [2, 3]) - def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity): - em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=10) - em_s, ts_s = split_input_data(em, ts, mc) - merged_embs, merged_clus_labels, _ = run_reducer( - pre_embs=em_s[-1], - target_spk_idx=target_speaker_index, - merge_quantity=merge_quantity, - pre_clus_labels=gt, - ) - assert (torch.sum(gt == target_speaker_index).item() - merge_quantity) == merged_clus_labels.shape[0] - - @pytest.mark.unit - @pytest.mark.parametrize("ntbr", [3]) - @pytest.mark.parametrize("pcl", [torch.tensor([0] * 70 + [1] * 32)]) - @pytest.mark.parametrize("mspb", [25]) - def test_merge_scheduler_2clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity( - num_to_be_removed=ntbr, - pre_clus_labels=pcl, - min_count_per_cluster=mspb, - ) - assert all(class_target_vol == torch.tensor([3, 0])) - - @pytest.mark.unit - @pytest.mark.parametrize("ntbr", [3]) - @pytest.mark.parametrize("pcl", [torch.tensor([0] * 80 + [1] * 35 + [2] * 32)]) - @pytest.mark.parametrize("mspb", [0, 25]) - def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity( - num_to_be_removed=ntbr, - pre_clus_labels=pcl, - min_count_per_cluster=mspb, - ) - assert all(class_target_vol == torch.tensor([3, 0, 0])) - - @pytest.mark.unit - @pytest.mark.parametrize("ntbr", [132 - 45]) - @pytest.mark.parametrize("pcl", [torch.tensor([2] * 70 + [0] * 32 + [1] * 27 + [3] * 3)]) - @pytest.mark.parametrize("mspb", [3, 10]) - def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity( - num_to_be_removed=ntbr, - pre_clus_labels=pcl, - min_count_per_cluster=mspb, - ) - assert all(class_target_vol == torch.tensor([18, 13, 56, 0])) - - @pytest.mark.unit - @pytest.mark.parametrize("ntbr", [3]) - @pytest.mark.parametrize("pcl", [torch.tensor([0] * 5 + [1] * 4 + [2] * 3)]) - @pytest.mark.parametrize("mspb", [0, 2]) - def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity( - num_to_be_removed=ntbr, - pre_clus_labels=pcl, - min_count_per_cluster=mspb, - ) - assert all(class_target_vol == torch.tensor([2, 1, 0])) - - @pytest.mark.unit - @pytest.mark.parametrize("ntbr", [2]) - @pytest.mark.parametrize("pcl", [torch.tensor([0] * 7 + [1] * 5 + [2] * 3 + [3] * 5)]) - @pytest.mark.parametrize("mspb", [2]) - def test_merge_scheduler_3clus_repeat(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity( - num_to_be_removed=ntbr, - pre_clus_labels=pcl, - min_count_per_cluster=mspb, - ) - assert all(class_target_vol == torch.tensor([2, 0, 0, 0])) - - -class TestClassExport: - @pytest.mark.unit - def test_online_segmentor_class_export(self): - _OnlineSegmentor = torch.jit.script(OnlineSegmentor) - online_segmentor = _OnlineSegmentor(sample_rate=16000) - assert isinstance(online_segmentor, OnlineSegmentor) - - @pytest.mark.unit - def test_online_segmentor_instance_export(self): - online_segmentor = OnlineSegmentor(sample_rate=16000) - online_segmentor = torch.jit.script(online_segmentor) - isinstance(online_segmentor, torch.jit._script.RecursiveScriptClass) - - @pytest.mark.unit - def test_online_speaker_clustering_instance_export(self): - online_clus = OnlineSpeakerClustering( - max_num_speakers=8, - max_rp_threshold=0.15, - sparse_search_volume=30, - history_buffer_size=150, - current_buffer_size=150, - cuda=True, - ) - online_clus = torch.jit.script(online_clus) - isinstance(online_clus, torch.jit._script.RecursiveScriptClass) - - @pytest.mark.unit - def test_online_speaker_clustering_instance_export(self): - offline_speaker_clustering = SpeakerClustering(maj_vote_spk_count=False, min_samples_for_nmesc=0, cuda=True) - offline_speaker_clustering = torch.jit.script(offline_speaker_clustering) - isinstance(offline_speaker_clustering, torch.jit._script.RecursiveScriptClass) - - -class TestDiarizationSegmentationUtils: - """ - Test segmentation util functions - """ - - @pytest.mark.unit - @pytest.mark.parametrize( - "intervals", - [ - [[1, 4], [2, 6], [8, 10], [15, 18]], - [[8, 10], [15, 18], [2, 6], [1, 3]], - [[8, 10], [15, 18], [2, 6], [1, 3], [3, 5]], - [[8, 10], [8, 8], [15, 18], [2, 6], [1, 6], [2, 4]], - ], - ) - @pytest.mark.parametrize("target", [[[1, 6], [8, 10], [15, 18]]]) - def test_merge_int_intervals_ex1(self, intervals, target): - merged = merge_int_intervals(intervals) - assert check_range_values(target, merged) - - @pytest.mark.unit - @pytest.mark.parametrize( - "intervals", - [ - [[6, 8], [0, 9], [2, 4], [4, 7]], - [[0, 9], [6, 8], [4, 7], [2, 4]], - [[0, 4], [0, 0], [4, 9], [2, 4]], - [[6, 8], [2, 8], [0, 3], [3, 4], [4, 5], [5, 9]], - ], - ) - @pytest.mark.parametrize("target", [[[0, 9]]]) - def test_merge_int_intervals_ex2(self, intervals, target): - merged = merge_int_intervals(intervals) - assert check_range_values(target, merged) - - @pytest.mark.unit - @pytest.mark.parametrize("intervals", [[[0, 1], [1, 9]], [[0, 0], [0, 9]], [[0, 9], [0, 9]]]) - @pytest.mark.parametrize("target", [[[0, 9]]]) - def test_merge_int_intervals_edge_test(self, intervals, target): - merged = merge_int_intervals(intervals) - assert check_range_values(target, merged) - - @pytest.mark.unit - @pytest.mark.parametrize("rangeA", [[1.0, 2.0]]) - @pytest.mark.parametrize("rangeB", [[0.5, 1.5], [0.9999, 1.0001]]) - def test_is_overlap_true(self, rangeA, rangeB): - assert is_overlap(rangeA, rangeB) - - @pytest.mark.unit - @pytest.mark.parametrize("rangeA", [[1.0, 2.0]]) - @pytest.mark.parametrize("rangeB", [[2.0, 2.5], [-1.0, 1.00]]) - def test_is_overlap_false(self, rangeA, rangeB): - assert not is_overlap(rangeA, rangeB) - - @pytest.mark.unit - @pytest.mark.parametrize("x", [1.0, 2.3456]) - @pytest.mark.parametrize("decimals", [1, 2, 3, 4]) - def test_fl2int(self, x, decimals): - assert fl2int(x, decimals) == round(x * 10**decimals, 0) - - @pytest.mark.unit - @pytest.mark.parametrize("x", [1234]) - @pytest.mark.parametrize( - "decimals", - [ - 1, - 2, - 3, - 4, - ], - ) - def test_int2fl(self, x, decimals): - assert abs(int2fl(x, decimals) - round(x / (10**decimals), decimals)) < (10 ** -(decimals + 1)) - - @pytest.mark.unit - def test_merge_float_intervals_edge_margin_test(self): - intervals = [[0.0, 1.0], [1.0, 2.0]] - - target_0 = [[0.0, 2.0]] - merged_0 = merge_float_intervals(intervals, margin=0) - assert check_range_values(target_0, merged_0) - - target_1 = [[0.0, 1.0], [1.0, 2.0]] - merged_1 = merge_float_intervals(intervals, margin=1) - assert check_range_values(target_1, merged_1) - - target_2 = [[0.0, 1.0], [1.0, 2.0]] - merged_2 = merge_float_intervals(intervals, margin=2) - assert check_range_values(target_2, merged_2) - - @pytest.mark.unit - @pytest.mark.parametrize( - "intervals", - [ - [[0.25, 1.7], [1.5, 3.0], [2.8, 5.0], [5.5, 10.0]], - [[0.25, 5.0], [5.5, 10.0], [1.5, 3.5]], - [[5.5, 8.05], [8.0, 10.0], [0.25, 5.0]], - [[0.25, 3.0], [1.5, 3.0], [5.5, 10.0], [2.8, 5.0]], - [[0.25, 1.7], [1.5, 3.0], [2.8, 5.0], [5.5, 10.0]], - ], - ) - @pytest.mark.parametrize("target", [[[0.25, 5.0], [5.5, 10.0]]]) - def test_merge_float_overlaps(self, intervals, target): - merged = merge_float_intervals(intervals) - assert check_range_values(target, merged) - - @pytest.mark.unit - def test_get_speech_labels_for_update(self): - frame_start = 3.0 - buffer_end = 6.0 - cumulative_speech_labels = torch.tensor([[0.0000, 3.7600]]) - vad_timestamps = torch.tensor([[0.9600, 4.8400]]) - cursor_for_old_segments = 1.0 - speech_labels_for_update, cumulative_speech_labels = get_speech_labels_for_update( - frame_start, - buffer_end, - cumulative_speech_labels, - vad_timestamps, - cursor_for_old_segments, - ) - assert (speech_labels_for_update - torch.tensor([[1.0000, 3.7600]])).sum() < 1e-8 - assert (cumulative_speech_labels - torch.tensor([[0.9600, 4.8400]])).sum() < 1e-8 - - # Check if the ranges are containing faulty values - assert check_ranges(speech_labels_for_update) - assert check_ranges(cumulative_speech_labels) - - @pytest.mark.unit - def test_get_online_subsegments_from_buffer(self): - torch.manual_seed(0) - sample_rate = 16000 - speech_labels_for_update = torch.Tensor([[0.0000, 3.7600]]) - audio_buffer = torch.randn(5 * sample_rate) - segment_indexes = [] - window = 2.0 - shift = 1.0 - slice_length = int(window * sample_rate) - range_target = [[0.0, 2.0], [1.0, 3.0], [2.0, 3.76]] - sigs_list, sig_rangel_list, sig_indexes = get_online_subsegments_from_buffer( - buffer_start=0.0, - buffer_end=5.0, - sample_rate=sample_rate, - speech_labels_for_update=speech_labels_for_update, - audio_buffer=audio_buffer, - segment_indexes=segment_indexes, - window=window, - shift=shift, - ) - assert check_range_values(target=range_target, source=sig_rangel_list) - for k, rg in enumerate(sig_rangel_list): - signal = get_target_sig(audio_buffer, rg[0], rg[1], slice_length, sample_rate) - if len(signal) < int(window * sample_rate): - signal = repeat_signal(signal, len(signal), slice_length) - assert len(signal) == int(slice_length), "Length mismatch" - assert (np.abs(signal - sigs_list[k])).sum() < 1e-8, "Audio stream mismatch" - assert (torch.tensor(sig_indexes) - torch.arange(len(range_target))).sum() < 1e-8, "Segment index mismatch" - - @pytest.mark.unit - @pytest.mark.parametrize("frame_start", [3.0]) - @pytest.mark.parametrize("segment_range_ts", [[[0.0, 2.0]]]) - @pytest.mark.parametrize("gt_cursor_for_old_segments", [3.0]) - @pytest.mark.parametrize("gt_cursor_index", [1]) - def test_get_new_cursor_for_update_mulsegs_ex1( - self, frame_start, segment_range_ts, gt_cursor_for_old_segments, gt_cursor_index - ): - cursor_for_old_segments, cursor_index = get_new_cursor_for_update(frame_start, segment_range_ts) - assert cursor_for_old_segments == gt_cursor_for_old_segments - assert cursor_index == gt_cursor_index - - @pytest.mark.unit - @pytest.mark.parametrize("target_range", [[1.0, 4.0]]) - @pytest.mark.parametrize( - "source_range_list", [[[2.0, 3.0], [3.0, 4.0]], [[0.0, 2.0], [3.0, 5.0]], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]] - ) - def get_sub_range_list(self, target_range, source_range_list): - sub_range_list = get_sub_range_list(target_range, source_range_list) - assert sub_range_list == [[2.0, 3.0], [3.0, 4.0]] - - @pytest.mark.unit - @pytest.mark.parametrize("source_range_list", [[[0.0, 2.0]], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]) - def test_tensor_to_list(self, source_range_list): - a_range_tensor = torch.tensor(source_range_list) - converted_list = tensor_to_list(a_range_tensor) - assert source_range_list == converted_list - - @pytest.mark.unit - @pytest.mark.parametrize( - "buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate", - [ - (0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000), - (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000), - ], - ) - def test_get_online_segments_from_slices( - self, buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate - ): - sig = torch.randn(int(sample_rate * buffer_end)) - ind_offset, sigs_list, sig_rangel_list, sig_indexes = get_online_segments_from_slices( - sig, buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate - ) - assert ind_offset == 2 - assert len(sigs_list) == 2 - assert len(sig_rangel_list) == 2 - assert len(sig_indexes) == 2 - - -class TestClusteringUtilFunctions: - @pytest.mark.parametrize("p_value", [1, 5, 9]) - @pytest.mark.parametrize("N", [9, 20]) - @pytest.mark.parametrize("mask_method", ['binary', 'sigmoid', 'drop']) - def test_get_k_neighbors_connections(self, p_value: int, N: int, mask_method: str, seed=0): - torch.manual_seed(seed) - random_mat = torch.rand(N, N) - affinity_mat = 0.5 * (random_mat + random_mat.T) - affinity_mat = affinity_mat / torch.max(affinity_mat) - binarized_affinity_mat = getKneighborsConnections(affinity_mat, p_value, mask_method) - if mask_method == 'binary': - assert all(binarized_affinity_mat.sum(dim=0) == float(p_value)) - elif mask_method == 'sigmoid': - assert all(binarized_affinity_mat.sum(dim=0) <= float(p_value)) - elif mask_method == 'drop': - assert all(binarized_affinity_mat.sum(dim=0) <= float(p_value)) - - @pytest.mark.unit - @pytest.mark.parametrize("Y_aggr", [torch.tensor([0, 1, 0, 1])]) - @pytest.mark.parametrize("chunk_cluster_count, embeddings_per_chunk", [(2, 50)]) - @pytest.mark.parametrize("window_range_list", [[[0, 1], [1, 2], [2, 3], [3, 4]]]) - @pytest.mark.parametrize( - "absolute_merge_mapping", - [[[torch.tensor([]), torch.tensor([0, 2])], [torch.tensor([]), torch.tensor([1, 3])]]], - ) - @pytest.mark.parametrize("org_len", [4]) - def test_unpack_labels( - self, Y_aggr, window_range_list, absolute_merge_mapping, chunk_cluster_count, embeddings_per_chunk, org_len - ): - expected_result = Y_aggr - longform_speaker_clustering = LongFormSpeakerClustering(cuda=False) - output = longform_speaker_clustering.unpack_labels(Y_aggr, window_range_list, absolute_merge_mapping, org_len) - assert torch.equal(output, expected_result) - - -class TestSpeakerClustering: - """ - Test speaker clustering module - """ - - @pytest.mark.unit - @pytest.mark.parametrize("cuda", [True, False]) - def test_offline_clus_script_save_load(self, cuda): - exported_filename = 'speaker_clustering_script.pt' - speaker_clustering_python = SpeakerClustering(maj_vote_spk_count=False, cuda=cuda) - speaker_clustering_scripted_source = torch.jit.script(speaker_clustering_python) - torch.jit.save(speaker_clustering_scripted_source, exported_filename) - assert os.path.exists(exported_filename) - os.remove(exported_filename) - assert not os.path.exists(exported_filename) - - @pytest.mark.unit - @pytest.mark.parametrize("cuda", [True, False]) - def test_online_clus_script_save_load(self, cuda): - exported_filename = 'speaker_clustering_script.pt' - speaker_clustering_python = OnlineSpeakerClustering( - max_num_speakers=8, - max_rp_threshold=0.15, - sparse_search_volume=30, - history_buffer_size=150, - current_buffer_size=150, - cuda=cuda, - ) - speaker_clustering_scripted_source = torch.jit.script(speaker_clustering_python) - torch.jit.save(speaker_clustering_scripted_source, exported_filename) - assert os.path.exists(exported_filename) - os.remove(exported_filename) - assert not os.path.exists(exported_filename) - - @pytest.mark.run_only_on('GPU') - @pytest.mark.unit - @pytest.mark.parametrize("n_spks", [1, 2, 3, 4, 5, 6, 7]) - @pytest.mark.parametrize("total_sec, SSV, perturb_sigma, seed", [(30, 10, 0.1, 0)]) - @pytest.mark.parametrize("jit_script", [False, True]) - def test_offline_speaker_clustering(self, n_spks, total_sec, SSV, perturb_sigma, seed, jit_script, cuda=True): - spk_dur = total_sec / n_spks - em, ts, mc, mw, spk_ts, gt = generate_toy_data( - n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=perturb_sigma, torch_seed=seed - ) - offline_speaker_clustering = SpeakerClustering(maj_vote_spk_count=False, cuda=cuda) - assert isinstance(offline_speaker_clustering, SpeakerClustering) - if jit_script: - offline_speaker_clustering = torch.jit.script(offline_speaker_clustering) - - Y_out = offline_speaker_clustering.forward_infer( - embeddings_in_scales=em, - timestamps_in_scales=ts, - multiscale_segment_counts=mc, - multiscale_weights=mw, - oracle_num_speakers=-1, - max_num_speakers=8, - enhanced_count_thres=40, - sparse_search_volume=SSV, - max_rp_threshold=0.15, - fixed_thres=-1.0, - ) - permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out) - permuted_Y = permuted_Y.to(gt.device) - # mc[-1] is the number of base scale segments - assert len(set(permuted_Y.tolist())) == n_spks - assert Y_out.shape[0] == mc[-1] - assert all(permuted_Y == gt) - - @pytest.mark.run_only_on('CPU') - @pytest.mark.unit - @pytest.mark.parametrize("n_spks", [1, 2, 3, 4, 5, 6, 7]) - @pytest.mark.parametrize("total_sec, SSV, perturb_sigma, seed", [(30, 10, 0.1, 0)]) - @pytest.mark.parametrize("jit_script", [False, True]) - def test_offline_speaker_clustering_cpu(self, n_spks, total_sec, SSV, perturb_sigma, seed, jit_script, cuda=False): - self.test_offline_speaker_clustering(n_spks, total_sec, SSV, perturb_sigma, seed, jit_script, cuda=cuda) - - @pytest.mark.run_only_on('CPU') - @pytest.mark.unit - @pytest.mark.parametrize("n_spks", [1]) - @pytest.mark.parametrize("spk_dur", [0.25, 0.5, 0.75, 1, 1.5, 2]) - @pytest.mark.parametrize("SSV, enhanced_count_thres, min_samples_for_nmesc", [(5, 40, 6)]) - @pytest.mark.parametrize("seed", [0]) - def test_offline_speaker_clustering_very_short_cpu( - self, - n_spks, - spk_dur, - SSV, - enhanced_count_thres, - min_samples_for_nmesc, - seed, - ): - em, ts, mc, mw, spk_ts, gt = generate_toy_data( - n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed - ) - offline_speaker_clustering = SpeakerClustering(maj_vote_spk_count=False, min_samples_for_nmesc=0, cuda=False) - assert isinstance(offline_speaker_clustering, SpeakerClustering) - Y_out = offline_speaker_clustering.forward_infer( - embeddings_in_scales=em, - timestamps_in_scales=ts, - multiscale_segment_counts=mc, - multiscale_weights=mw, - oracle_num_speakers=-1, - max_num_speakers=8, - enhanced_count_thres=enhanced_count_thres, - sparse_search_volume=SSV, - max_rp_threshold=0.15, - fixed_thres=-1.0, - ) - permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out) - permuted_Y = permuted_Y.to(gt.device) - # mc[-1] is the number of base scale segments - assert len(set(permuted_Y.tolist())) == n_spks - assert Y_out.shape[0] == mc[-1] - assert all(permuted_Y == gt) - - @pytest.mark.run_only_on('GPU') - @pytest.mark.unit - @pytest.mark.parametrize("spk_dur", [0.25, 0.5, 0.75, 1, 2, 4]) - @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(1, 5, 40, 6)]) - @pytest.mark.parametrize("seed", [0]) - def test_offline_speaker_clustering_very_short_gpu( - self, - n_spks, - spk_dur, - SSV, - enhanced_count_thres, - min_samples_for_nmesc, - seed, - ): - em, ts, mc, mw, spk_ts, gt = generate_toy_data( - n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed - ) - offline_speaker_clustering = SpeakerClustering(maj_vote_spk_count=False, min_samples_for_nmesc=0, cuda=True) - assert isinstance(offline_speaker_clustering, SpeakerClustering) - Y_out = offline_speaker_clustering.forward_infer( - embeddings_in_scales=em, - timestamps_in_scales=ts, - multiscale_segment_counts=mc, - multiscale_weights=mw, - oracle_num_speakers=-1, - max_num_speakers=8, - enhanced_count_thres=enhanced_count_thres, - sparse_search_volume=SSV, - max_rp_threshold=0.15, - fixed_thres=-1.0, - ) - permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out) - permuted_Y = permuted_Y.to(gt.device) - # mc[-1] is the number of base scale segments - assert Y_out.shape[0] == mc[-1] - assert all(permuted_Y == gt) - - @pytest.mark.run_only_on('CPU') - @pytest.mark.unit - @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(2, 5, 40, 6)]) - @pytest.mark.parametrize("spk_dur, chunk_cluster_count, embeddings_per_chunk", [(120, 4, 50), (240, 4, 100)]) - @pytest.mark.parametrize("seed", [0]) - @pytest.mark.parametrize("jit_script", [False, True]) - def test_longform_speaker_clustering_cpu( - self, - n_spks, - spk_dur, - SSV, - enhanced_count_thres, - min_samples_for_nmesc, - chunk_cluster_count, - embeddings_per_chunk, - jit_script, - seed, - ): - em, ts, mc, mw, spk_ts, gt = generate_toy_data( - n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed - ) - longform_speaker_clustering = LongFormSpeakerClustering(cuda=False) - if jit_script: - longform_speaker_clustering = torch.jit.script(longform_speaker_clustering) - else: - assert isinstance(longform_speaker_clustering, LongFormSpeakerClustering) - Y_out = longform_speaker_clustering.forward_infer( - embeddings_in_scales=em, - timestamps_in_scales=ts, - multiscale_segment_counts=mc, - multiscale_weights=mw, - oracle_num_speakers=-1, - max_num_speakers=n_spks, - enhanced_count_thres=enhanced_count_thres, - sparse_search_volume=SSV, - max_rp_threshold=0.15, - fixed_thres=-1.0, - chunk_cluster_count=chunk_cluster_count, - embeddings_per_chunk=embeddings_per_chunk, - ) - permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out) - permuted_Y = permuted_Y.to(gt.device) - - # mc[-1] is the number of base scale segments - assert Y_out.shape[0] == mc[-1] - assert all(permuted_Y == gt) - - @pytest.mark.run_only_on('GPU') - @pytest.mark.unit - @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(2, 5, 40, 6)]) - @pytest.mark.parametrize("spk_dur, chunk_cluster_count, embeddings_per_chunk", [(120, 4, 50), (240, 4, 100)]) - @pytest.mark.parametrize("seed", [0]) - @pytest.mark.parametrize("jit_script", [False, True]) - def test_longform_speaker_clustering_gpu( - self, - n_spks, - spk_dur, - SSV, - enhanced_count_thres, - min_samples_for_nmesc, - chunk_cluster_count, - embeddings_per_chunk, - jit_script, - seed, - ): - em, ts, mc, mw, spk_ts, gt = generate_toy_data( - n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed - ) - longform_speaker_clustering = LongFormSpeakerClustering(cuda=True) - - if jit_script: - longform_speaker_clustering = torch.jit.script(longform_speaker_clustering) - else: - assert isinstance(longform_speaker_clustering, LongFormSpeakerClustering) - - Y_out = longform_speaker_clustering.forward_infer( - embeddings_in_scales=em, - timestamps_in_scales=ts, - multiscale_segment_counts=mc, - multiscale_weights=mw, - oracle_num_speakers=-1, - max_num_speakers=n_spks, - enhanced_count_thres=enhanced_count_thres, - sparse_search_volume=SSV, - max_rp_threshold=0.15, - fixed_thres=-1.0, - chunk_cluster_count=chunk_cluster_count, - embeddings_per_chunk=embeddings_per_chunk, - ) - permuted_Y = stitch_cluster_labels(Y_old=gt, Y_new=Y_out) - permuted_Y = permuted_Y.to(gt.device) - - # mc[-1] is the number of base scale segments - assert Y_out.shape[0] == mc[-1] - assert all(permuted_Y == gt) - - @pytest.mark.run_only_on('GPU') - @pytest.mark.unit - @pytest.mark.parametrize("n_spks", [1, 2, 3]) - @pytest.mark.parametrize("total_sec, buffer_size, sigma", [(30, 30, 0.1)]) - @pytest.mark.parametrize("seed", [0]) - @pytest.mark.parametrize("jit_script", [False, True]) - def test_online_speaker_clustering(self, n_spks, total_sec, buffer_size, sigma, seed, jit_script, cuda=True): - step_per_frame = 2 - spk_dur = total_sec / n_spks - em, ts, mc, _, _, gt = generate_toy_data(n_spks, spk_dur=spk_dur, perturb_sigma=sigma, torch_seed=seed) - em_s, ts_s = split_input_data(em, ts, mc) - - emb_gen = em_s[-1] - segment_indexes = ts_s[-1] - if cuda: - device = torch.cuda.current_device() - emb_gen, segment_indexes = emb_gen.to(device), segment_indexes.to(device) - - history_buffer_size = buffer_size - current_buffer_size = buffer_size - - online_clus = OnlineSpeakerClustering( - max_num_speakers=8, - max_rp_threshold=0.15, - sparse_search_volume=30, - history_buffer_size=history_buffer_size, - current_buffer_size=current_buffer_size, - cuda=cuda, - ) - if jit_script: - online_clus = torch.jit.script(online_clus) - - n_frames = int(emb_gen.shape[0] / step_per_frame) - evaluation_list = [] - - # Simulate online speaker clustering - for frame_index in range(n_frames): - curr_emb = emb_gen[0 : (frame_index + 1) * step_per_frame] - base_segment_indexes = torch.arange(curr_emb.shape[0]).to(curr_emb.device) - # Check history_buffer_size and history labels - assert ( - online_clus.history_embedding_buffer_emb.shape[0] <= history_buffer_size - ), "History buffer size error" - assert ( - online_clus.history_embedding_buffer_emb.shape[0] - == online_clus.history_embedding_buffer_label.shape[0] - ) - - # Call clustering function - merged_clus_labels = online_clus.forward_infer( - curr_emb=curr_emb, base_segment_indexes=base_segment_indexes, frame_index=frame_index, cuda=cuda - ) - - # Resolve permutations - assert len(merged_clus_labels) == (frame_index + 1) * step_per_frame - # Resolve permutation issue by using stitch_cluster_labels function - merged_clus_labels = merged_clus_labels.cpu() - merged_clus_labels = stitch_cluster_labels(Y_old=gt[: len(merged_clus_labels)], Y_new=merged_clus_labels) - evaluation_list.extend(list(merged_clus_labels == gt[: len(merged_clus_labels)])) - - assert online_clus.is_online - cumul_label_acc = sum(evaluation_list) / len(evaluation_list) - assert cumul_label_acc > 0.9 - - @pytest.mark.run_only_on('CPU') - @pytest.mark.unit - @pytest.mark.parametrize("n_spks, total_sec, buffer_size, sigma, seed", [(3, 30, 30, 0.1, 0)]) - @pytest.mark.parametrize("jit_script", [False, True]) - def test_online_speaker_clustering_cpu(self, n_spks, total_sec, buffer_size, sigma, seed, jit_script, cuda=False): - self.test_online_speaker_clustering(n_spks, total_sec, buffer_size, sigma, seed, jit_script, cuda) - - -class TestLinearSumAssignmentAlgorithm: - @pytest.mark.unit - def test_lsa_solver_export_test(self): - cost_matrix = torch.randint(0, 10, (3, 3)) - solver = LinearSumAssignmentSolver(cost_matrix) - solver = torch.jit.script(solver) - assert isinstance(solver, torch.jit._script.RecursiveScriptClass) - - @pytest.mark.unit - @pytest.mark.parametrize( - "cost_matrix", - [torch.tensor([[7, 6, 2, 9, 2], [6, 2, 1, 3, 9], [5, 6, 8, 9, 5], [6, 8, 5, 8, 6], [9, 5, 6, 4, 7]])], - ) - def test_linear_sum_assignment_algorithm_cost_matrix(self, cost_matrix): - """ - Test the linear sum assignment algorithm with a cost matrix - - Compare with the scipy implementation and make sure the final cost is the same. - NOTE: There could be multiple solutions with the same cost in linear sum assignment problem. - This test only checks if the cost is the same. - """ - row_ind_nm, col_ind_nm = nemo_linear_sum_assignment(cost_matrix) - row_ind_sc, col_ind_sc = scipy_linear_sum_assignment(cost_matrix.cpu().numpy()) - cost_nm = sum(cost_matrix[row_ind_nm, col_ind_nm]) - cost_sc = sum(cost_matrix[row_ind_sc, col_ind_sc]) - assert cost_nm == cost_sc - - @pytest.mark.unit - @pytest.mark.parametrize("seed", [0, 1]) - @pytest.mark.parametrize("mat_size", [1, 2, 4, 8]) - def test_linear_sum_assignment_algorithm_random_matrix(self, seed, mat_size): - torch.manual_seed(seed) - cost_matrix = torch.randint(0, 10, (mat_size, mat_size)) - self.test_linear_sum_assignment_algorithm_cost_matrix(cost_matrix) diff --git a/tests/collections/asr/test_speaker_label_models.py b/tests/collections/speaker_tasks/test_speaker_label_models.py similarity index 100% rename from tests/collections/asr/test_speaker_label_models.py rename to tests/collections/speaker_tasks/test_speaker_label_models.py From 136083152d24debf115e945ff3c64a6fdf0a2db3 Mon Sep 17 00:00:00 2001 From: taejinp Date: Tue, 19 Nov 2024 18:38:25 -0800 Subject: [PATCH 24/47] Fixed uninit variable issue in bce_loss.py spotted by codeQL Signed-off-by: taejinp --- nemo/collections/asr/losses/bce_loss.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/asr/losses/bce_loss.py b/nemo/collections/asr/losses/bce_loss.py index d2aa03319007..83f6b57c0203 100644 --- a/nemo/collections/asr/losses/bce_loss.py +++ b/nemo/collections/asr/losses/bce_loss.py @@ -98,6 +98,7 @@ def forward(self, probs, labels, target_lens): probs = torch.cat(probs_list, dim=0) labels = torch.cat(targets_list, dim=0) norm_weight = torch.zeros_like(labels).detach().clone() + loss = torch.tensor(0.0).to(labels.device) if self.class_normalization in ['class', 'class_binary', 'binary']: if self.class_normalization in ['class', 'class_binary']: From 553197a8a45f397901be902a7c643f268e77aae8 Mon Sep 17 00:00:00 2001 From: tango4j Date: Wed, 20 Nov 2024 02:43:55 +0000 Subject: [PATCH 25/47] Apply isort and black reformatting Signed-off-by: tango4j --- .../speaker_tasks/test_diar_datasets.py | 30 +- .../speaker_tasks/test_diar_label_models.py | 59 ++-- .../test_diar_lhotse_datasets.py | 189 ++++++------ .../test_diar_neural_inference.py | 9 +- .../test_diar_sortformer_models.py | 55 ++-- .../test_speaker_label_models.py | 12 +- .../utils/test_data_simul_utils.py | 12 +- .../speaker_tasks/utils/test_diar_utils.py | 109 +++++-- .../utils/test_multispeaker_utils.py | 282 ++++++++++-------- 9 files changed, 451 insertions(+), 306 deletions(-) diff --git a/tests/collections/speaker_tasks/test_diar_datasets.py b/tests/collections/speaker_tasks/test_diar_datasets.py index 915a54382f99..ea57653b9324 100644 --- a/tests/collections/speaker_tasks/test_diar_datasets.py +++ b/tests/collections/speaker_tasks/test_diar_datasets.py @@ -28,10 +28,15 @@ from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer -from nemo.collections.asr.parts.utils.speaker_utils import read_rttm_lines, get_offset_and_duration, get_vad_out_from_rttm_line +from nemo.collections.asr.parts.utils.speaker_utils import ( + get_offset_and_duration, + get_vad_out_from_rttm_line, + read_rttm_lines, +) -def is_rttm_length_too_long(rttm_file_path, wav_len_in_sec): - """ + +def is_rttm_length_too_long(rttm_file_path, wav_len_in_sec): + """ Check if the maximum RTTM duration exceeds the length of the provided audio file. Args: @@ -48,6 +53,7 @@ def is_rttm_length_too_long(rttm_file_path, wav_len_in_sec): max_rttm_sec = max(max_rttm_sec, start + dur) return max_rttm_sec <= wav_len_in_sec + class TestAudioToSpeechE2ESpkDiarDataset: @pytest.mark.unit @@ -56,7 +62,7 @@ def test_e2e_speaker_diar_dataset(self, test_data_dir): batch_size = 4 num_samples = 8 - + device = 'gpu' if torch.cuda.is_available() else 'cpu' data_dict_list = [] with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f: @@ -82,8 +88,8 @@ def test_e2e_speaker_diar_dataset(self, test_data_dir): window_stride=0.01, global_rank=0, soft_targets=False, - ) - + ) + dataloader_instance = torch.utils.data.DataLoader( dataset=dataset, batch_size=batch_size, @@ -95,7 +101,7 @@ def test_e2e_speaker_diar_dataset(self, test_data_dir): ) assert len(dataloader_instance) == (num_samples / batch_size) # Check if the number of batches is correct batch_counts = len(dataloader_instance) - + deviation_thres_rate = 0.01 # 1% deviation allowed for batch_index, batch in enumerate(dataloader_instance): if batch_index != batch_counts - 1: @@ -103,10 +109,14 @@ def test_e2e_speaker_diar_dataset(self, test_data_dir): audio_signals, audio_signal_len, targets, target_lens = batch for sample_index in range(audio_signals.shape[0]): dataloader_audio_in_sec = audio_signal_len[sample_index].item() - data_dur_in_sec = abs(data_dict_list[batch_size*batch_index + sample_index]['duration'] * featurizer.sample_rate - dataloader_audio_in_sec) - assert data_dur_in_sec <= deviation_thres_rate * dataloader_audio_in_sec, "Duration deviation exceeds 1%" + data_dur_in_sec = abs( + data_dict_list[batch_size * batch_index + sample_index]['duration'] * featurizer.sample_rate + - dataloader_audio_in_sec + ) + assert ( + data_dur_in_sec <= deviation_thres_rate * dataloader_audio_in_sec + ), "Duration deviation exceeds 1%" assert not torch.isnan(audio_signals).any(), "audio_signals tensor contains NaN values" assert not torch.isnan(audio_signal_len).any(), "audio_signal_len tensor contains NaN values" assert not torch.isnan(targets).any(), "targets tensor contains NaN values" assert not torch.isnan(target_lens).any(), "target_lens tensor contains NaN values" - \ No newline at end of file diff --git a/tests/collections/speaker_tasks/test_diar_label_models.py b/tests/collections/speaker_tasks/test_diar_label_models.py index cf073d9e85e2..ab1e255010b6 100644 --- a/tests/collections/speaker_tasks/test_diar_label_models.py +++ b/tests/collections/speaker_tasks/test_diar_label_models.py @@ -16,15 +16,21 @@ import torch from omegaconf import DictConfig -from nemo.collections.asr.models import EncDecDiarLabelModel from nemo.collections.asr.losses import BCELoss +from nemo.collections.asr.models import EncDecDiarLabelModel + @pytest.fixture() def msdd_model(): preprocessor = { 'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', - 'params': {"features": 80, "window_size": 0.025, "window_stride": 0.01, "sample_rate": 16000,}, + 'params': { + "features": 80, + "window_size": 0.025, + "window_stride": 0.01, + "sample_rate": 16000, + }, } speaker_model_encoder = { @@ -166,35 +172,36 @@ def test_forward_infer(self, msdd_model): diff = torch.max(torch.abs(scale_weights_instance - scale_weights_batch)) assert diff <= 1e-6 + class TestBCELoss: @pytest.mark.unit @pytest.mark.parametrize( - "probs, labels, target_lens, reduction, expected_output", [ - ( - torch.tensor([[[0.5, 0.5], [0.5, 0.5]]], dtype=torch.float32), - torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32), - torch.tensor([[2]]), - "mean", - torch.tensor(0.693147, dtype=torch.float32) - ), - ( - torch.tensor([[[0.5, 0.5], [0.0, 1.0]]], dtype=torch.float32), - torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32), - torch.tensor([[1]]), - "mean", - torch.tensor(0.693147, dtype=torch.float32) - ), - ( - torch.tensor([[[0, 1], [1, 0]]], dtype=torch.float32), - torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32), - torch.tensor([[2]]), - "mean", - torch.tensor(100, dtype=torch.float32) - ) - ] + "probs, labels, target_lens, reduction, expected_output", + [ + ( + torch.tensor([[[0.5, 0.5], [0.5, 0.5]]], dtype=torch.float32), + torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32), + torch.tensor([[2]]), + "mean", + torch.tensor(0.693147, dtype=torch.float32), + ), + ( + torch.tensor([[[0.5, 0.5], [0.0, 1.0]]], dtype=torch.float32), + torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32), + torch.tensor([[1]]), + "mean", + torch.tensor(0.693147, dtype=torch.float32), + ), + ( + torch.tensor([[[0, 1], [1, 0]]], dtype=torch.float32), + torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32), + torch.tensor([[2]]), + "mean", + torch.tensor(100, dtype=torch.float32), + ), + ], ) def test_loss(self, probs, labels, target_lens, reduction, expected_output): loss = BCELoss(reduction=reduction) result = loss(probs=probs, labels=labels, target_lens=target_lens) assert torch.allclose(result, expected_output), f"Expected {expected_output}, but got {result}" - diff --git a/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py b/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py index 0aa676a6318e..82c1e8877ca2 100644 --- a/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py +++ b/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py @@ -16,101 +16,112 @@ import os import tempfile from unittest import mock + import pytest import torch import torch.cuda from omegaconf import DictConfig -from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config + from nemo.collections.asr.data.audio_to_diar_label_lhotse import LhotseAudioToSpeechE2ESpkDiarDataset +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config + def get_train_ds_config(manifest_filepath, batch_size, num_workers) -> DictConfig: - return DictConfig({ - 'manifest_filepath': manifest_filepath, - 'sample_rate': 16000, - 'num_spks': 4, - 'session_len_sec': 90, - 'soft_label_thres': 0.5, - 'soft_targets': False, - 'labels': None, - 'batch_size': batch_size, - 'shuffle': True, - 'num_workers': num_workers, - 'validation_mode': False, - 'use_lhotse': True, - 'use_bucketing': True, - 'num_buckets': 10, - 'bucket_duration_bins': [10, 20, 30, 40, 50, 60, 70, 80, 90], - 'pin_memory': True, - 'min_duration': 80, - 'max_duration': 90, - 'batch_duration': 400, - 'quadratic_duration': 1200, - 'bucket_buffer_size': 20000, - 'shuffle_buffer_size': 10000, - 'window_stride': 0.01, - 'subsampling_factor': 8, - }) + return DictConfig( + { + 'manifest_filepath': manifest_filepath, + 'sample_rate': 16000, + 'num_spks': 4, + 'session_len_sec': 90, + 'soft_label_thres': 0.5, + 'soft_targets': False, + 'labels': None, + 'batch_size': batch_size, + 'shuffle': True, + 'num_workers': num_workers, + 'validation_mode': False, + 'use_lhotse': True, + 'use_bucketing': True, + 'num_buckets': 10, + 'bucket_duration_bins': [10, 20, 30, 40, 50, 60, 70, 80, 90], + 'pin_memory': True, + 'min_duration': 80, + 'max_duration': 90, + 'batch_duration': 400, + 'quadratic_duration': 1200, + 'bucket_buffer_size': 20000, + 'shuffle_buffer_size': 10000, + 'window_stride': 0.01, + 'subsampling_factor': 8, + } + ) + def get_validation_ds_config(manifest_filepath, batch_size, num_workers) -> DictConfig: - return DictConfig({ - 'manifest_filepath': manifest_filepath, - 'is_tarred': False, - 'tarred_audio_filepaths': None, - 'sample_rate': 16000, - 'num_spks': 4, - 'session_len_sec': 90, - 'soft_label_thres': 0.5, - 'soft_targets': False, - 'labels': None, - 'batch_size': batch_size, - 'shuffle': False, - 'seq_eval_mode': True, - 'num_workers': num_workers, - 'validation_mode': True, - 'use_lhotse': False, - 'use_bucketing': False, - 'drop_last': False, - 'pin_memory': True, - 'window_stride': 0.01, - 'subsampling_factor': 8, - }) + return DictConfig( + { + 'manifest_filepath': manifest_filepath, + 'is_tarred': False, + 'tarred_audio_filepaths': None, + 'sample_rate': 16000, + 'num_spks': 4, + 'session_len_sec': 90, + 'soft_label_thres': 0.5, + 'soft_targets': False, + 'labels': None, + 'batch_size': batch_size, + 'shuffle': False, + 'seq_eval_mode': True, + 'num_workers': num_workers, + 'validation_mode': True, + 'use_lhotse': False, + 'use_bucketing': False, + 'drop_last': False, + 'pin_memory': True, + 'window_stride': 0.01, + 'subsampling_factor': 8, + } + ) + def get_test_ds_config(manifest_filepath, batch_size, num_workers) -> DictConfig: - return DictConfig({ - 'manifest_filepath': manifest_filepath, - 'is_tarred': False, - 'tarred_audio_filepaths': None, - 'sample_rate': 16000, - 'num_spks': 4, - 'session_len_sec': 90, - 'soft_label_thres': 0.5, - 'soft_targets': False, - 'labels': None, - 'batch_size': batch_size, - 'shuffle': False, - 'seq_eval_mode': True, - 'num_workers': num_workers, - 'validation_mode': True, - 'use_lhotse': False, - 'use_bucketing': False, - 'drop_last': False, - 'pin_memory': True, - 'window_stride': 0.01, - 'subsampling_factor': 8, - }) + return DictConfig( + { + 'manifest_filepath': manifest_filepath, + 'is_tarred': False, + 'tarred_audio_filepaths': None, + 'sample_rate': 16000, + 'num_spks': 4, + 'session_len_sec': 90, + 'soft_label_thres': 0.5, + 'soft_targets': False, + 'labels': None, + 'batch_size': batch_size, + 'shuffle': False, + 'seq_eval_mode': True, + 'num_workers': num_workers, + 'validation_mode': True, + 'use_lhotse': False, + 'use_bucketing': False, + 'drop_last': False, + 'pin_memory': True, + 'window_stride': 0.01, + 'subsampling_factor': 8, + } + ) + class TestLhotseAudioToSpeechE2ESpkDiarDataset: - @pytest.mark.unit @pytest.mark.parametrize( - "batch_size, num_workers, split", - [ - (4, 8, 'train'), # Example 1 - (4, 0, 'train'), # Example 2 - (2, 4, 'validation'), # Example 3 - (8, 2, 'test') # Example 4 - ] + "batch_size, num_workers, split", + [ + (4, 8, 'train'), # Example 1 + (4, 0, 'train'), # Example 2 + (2, 4, 'validation'), # Example 3 + (8, 2, 'test'), # Example 4 + ], ) def test_e2e_speaker_diar_lhotse_dataset(self, test_data_dir, batch_size, num_workers, split): manifest_path = os.path.abspath(os.path.join(test_data_dir, 'asr/diarizer/lsm_val.json')) @@ -130,28 +141,32 @@ def test_e2e_speaker_diar_lhotse_dataset(self, test_data_dir, batch_size, num_wo f.seek(0) if split == 'train': - config = get_train_ds_config(manifest_filepath=f.name, batch_size=batch_size, num_workers=num_workers) + config = get_train_ds_config(manifest_filepath=f.name, batch_size=batch_size, num_workers=num_workers) elif split == 'validation': - config = get_train_ds_config(manifest_filepath=f.name, batch_size=batch_size, num_workers=num_workers) + config = get_train_ds_config(manifest_filepath=f.name, batch_size=batch_size, num_workers=num_workers) elif split == 'test': - config = get_test_ds_config(manifest_filepath=f.name, batch_size=batch_size, num_workers=num_workers) - + config = get_test_ds_config(manifest_filepath=f.name, batch_size=batch_size, num_workers=num_workers) + dataloader_instance = get_lhotse_dataloader_from_config( config, global_rank=0, world_size=1, dataset=LhotseAudioToSpeechE2ESpkDiarDataset(cfg=config), ) - + deviation_thres_rate = 0.01 # 1% deviation allowed for batch_index, batch in enumerate(dataloader_instance): audio_signals, audio_signal_len, targets, target_lens = batch for sample_index in range(audio_signals.shape[0]): dataloader_audio_in_sec = audio_signal_len[sample_index].item() - data_dur_in_sec = abs(data_dict_list[batch_size*batch_index + sample_index]['duration'] * config.sample_rate - dataloader_audio_in_sec) - assert data_dur_in_sec <= deviation_thres_rate * dataloader_audio_in_sec, "Duration deviation exceeds 1%" + data_dur_in_sec = abs( + data_dict_list[batch_size * batch_index + sample_index]['duration'] * config.sample_rate + - dataloader_audio_in_sec + ) + assert ( + data_dur_in_sec <= deviation_thres_rate * dataloader_audio_in_sec + ), "Duration deviation exceeds 1%" assert not torch.isnan(audio_signals).any(), "audio_signals tensor contains NaN values" assert not torch.isnan(audio_signal_len).any(), "audio_signal_len tensor contains NaN values" assert not torch.isnan(targets).any(), "targets tensor contains NaN values" assert not torch.isnan(target_lens).any(), "target_lens tensor contains NaN values" - \ No newline at end of file diff --git a/tests/collections/speaker_tasks/test_diar_neural_inference.py b/tests/collections/speaker_tasks/test_diar_neural_inference.py index 3218a631bda3..64c1196cd9a6 100644 --- a/tests/collections/speaker_tasks/test_diar_neural_inference.py +++ b/tests/collections/speaker_tasks/test_diar_neural_inference.py @@ -28,7 +28,10 @@ class TestNeuralDiarizerInference: torch.device("cpu"), pytest.param( torch.device("cuda"), - marks=pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA required for test.',), + marks=pytest.mark.skipif( + not torch.cuda.is_available(), + reason='CUDA required for test.', + ), ), ], ) @@ -69,6 +72,6 @@ def test_msdd_diar_inference(self, tmpdir, test_data_dir, device, num_speakers, # assert only 1 speaker & segment assert len(annotation.labels()) == 1 assert len(list(annotation.itersegments())) == 1 - + # class TestSortformerDiarizerInference: - # TODO: This test can only be implemented once SortformerDiarizer model is uploaded. + # TODO: This test can only be implemented once SortformerDiarizer model is uploaded. diff --git a/tests/collections/speaker_tasks/test_diar_sortformer_models.py b/tests/collections/speaker_tasks/test_diar_sortformer_models.py index 6e59206df894..db5c3962521b 100644 --- a/tests/collections/speaker_tasks/test_diar_sortformer_models.py +++ b/tests/collections/speaker_tasks/test_diar_sortformer_models.py @@ -32,8 +32,7 @@ def sortformer_model(): 'max_num_of_spks': 4, 'session_len_sec': 90, } - - + preprocessor = { '_target_': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'normalize': 'per_feature', @@ -104,27 +103,27 @@ def sortformer_model(): 'weight': None, 'reduction': 'mean', } - modelConfig = DictConfig( - {'pil_weight': 0.5, - 'ats_weight': 0.5, - 'num_workers': 1, - 'fc_d_model': 512, - 'tf_d_model': 192, - 'max_num_of_spks': 4, - 'session_len_sec': 90, - 'encoder': DictConfig(encoder), - 'transformer_encoder': DictConfig(transformer_encoder), - 'sortformer_modules': DictConfig(sortformer_modules), - 'preprocessor': DictConfig(preprocessor), - 'loss': DictConfig(loss), - 'optim': { - 'optimizer': 'Adam', - 'lr': 0.001, - 'betas': (0.9, 0.98), + { + 'pil_weight': 0.5, + 'ats_weight': 0.5, + 'num_workers': 1, + 'fc_d_model': 512, + 'tf_d_model': 192, + 'max_num_of_spks': 4, + 'session_len_sec': 90, + 'encoder': DictConfig(encoder), + 'transformer_encoder': DictConfig(transformer_encoder), + 'sortformer_modules': DictConfig(sortformer_modules), + 'preprocessor': DictConfig(preprocessor), + 'loss': DictConfig(loss), + 'optim': { + 'optimizer': 'Adam', + 'lr': 0.001, + 'betas': (0.9, 0.98), + }, } - } ) model = SortformerEncLabelModel(cfg=modelConfig) return model @@ -140,13 +139,13 @@ def test_constructor(self, sortformer_model): @pytest.mark.unit @pytest.mark.parametrize( - "batch_size, frame_length, sample_len", - [ - (4, 0.08, 16), # Example 1 - (2, 0.02, 32), # Example 2 - (1, 0.1, 20), # Example 3 - ] -) + "batch_size, frame_length, sample_len", + [ + (4, 0.08, 16), # Example 1 + (2, 0.02, 32), # Example 2 + (1, 0.1, 20), # Example 3 + ], + ) def test_forward_infer(self, sortformer_model, batch_size, frame_length, sample_len, num_spks=4): sortformer_diar_model = sortformer_model.eval() confdict = sortformer_diar_model.to_config_dict() @@ -161,7 +160,7 @@ def test_forward_infer(self, sortformer_model, batch_size, frame_length, sample_ # batch size 1 preds_list = [] for i in range(input_signal.size(0)): - preds= sortformer_diar_model.forward(input_signal[i : i + 1], input_signal_length[i : i + 1]) + preds = sortformer_diar_model.forward(input_signal[i : i + 1], input_signal_length[i : i + 1]) preds_list.append(preds) preds_instance = torch.cat(preds_list, 0) diff --git a/tests/collections/speaker_tasks/test_speaker_label_models.py b/tests/collections/speaker_tasks/test_speaker_label_models.py index 29b5c9eea643..81a051e32e66 100644 --- a/tests/collections/speaker_tasks/test_speaker_label_models.py +++ b/tests/collections/speaker_tasks/test_speaker_label_models.py @@ -96,7 +96,11 @@ def test_ecapa_enc_dec(self): } modelConfig = DictConfig( - {'preprocessor': DictConfig(preprocessor), 'encoder': DictConfig(encoder), 'decoder': DictConfig(decoder),} + { + 'preprocessor': DictConfig(preprocessor), + 'encoder': DictConfig(encoder), + 'decoder': DictConfig(decoder), + } ) speaker_model = EncDecSpeakerLabelModel(cfg=modelConfig) speaker_model.train() @@ -142,7 +146,11 @@ def test_titanet_enc_dec(self): } modelConfig = DictConfig( - {'preprocessor': DictConfig(preprocessor), 'encoder': DictConfig(encoder), 'decoder': DictConfig(decoder),} + { + 'preprocessor': DictConfig(preprocessor), + 'encoder': DictConfig(encoder), + 'decoder': DictConfig(decoder), + } ) speaker_model = EncDecSpeakerLabelModel(cfg=modelConfig) speaker_model.train() diff --git a/tests/collections/speaker_tasks/utils/test_data_simul_utils.py b/tests/collections/speaker_tasks/utils/test_data_simul_utils.py index 295b79c76d18..9a27820cdfa1 100644 --- a/tests/collections/speaker_tasks/utils/test_data_simul_utils.py +++ b/tests/collections/speaker_tasks/utils/test_data_simul_utils.py @@ -101,11 +101,15 @@ def get_data_simulation_configs(): }, 'segment_augmentor': { 'add_seg_aug': False, - 'augmentor': {'gain': {'prob': 0.5, 'min_gain_dbfs': -10.0, 'max_gain_dbfs': 10.0},}, + 'augmentor': { + 'gain': {'prob': 0.5, 'min_gain_dbfs': -10.0, 'max_gain_dbfs': 10.0}, + }, }, 'session_augmentor': { 'add_sess_aug': False, - 'augmentor': {'white_noise': {'prob': 1.0, 'min_level': -90, 'max_level': -46},}, + 'augmentor': { + 'white_noise': {'prob': 1.0, 'min_level': -90, 'max_level': -46}, + }, }, 'speaker_enforcement': {'enforce_num_speakers': True, 'enforce_time': [0.25, 0.75]}, 'segment_manifest': {'window': 0.5, 'shift': 0.25, 'step_count': 50, 'deci': 3}, @@ -467,7 +471,7 @@ def test_get_session_silence_mean_pass(self, sampler, mean, var): @pytest.mark.parametrize("var", [0.5, 0.6]) def test_get_session_silence_mean_fail(self, sampler, mean, var): """ - This test should raise `ValueError` because `mean_silence_var` + This test should raise `ValueError` because `mean_silence_var` should be less than `mean_silence * (1 - mean_silence)`. """ sampler.mean_silence = mean @@ -488,7 +492,7 @@ def test_get_session_overlap_mean_pass(self, sampler, mean, var): @pytest.mark.parametrize("var", [0.3, 0.8]) def test_get_session_overlap_mean_fail(self, sampler, mean, var): """ - This test should raise `ValueError` because `mean_overlap_var` + This test should raise `ValueError` because `mean_overlap_var` should be less than `mean_overlap * (1 - mean_overlap)`. """ sampler.mean_overlap = mean diff --git a/tests/collections/speaker_tasks/utils/test_diar_utils.py b/tests/collections/speaker_tasks/utils/test_diar_utils.py index cd7e7f5b2a3b..f70f8006e8f3 100644 --- a/tests/collections/speaker_tasks/utils/test_diar_utils.py +++ b/tests/collections/speaker_tasks/utils/test_diar_utils.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import os +from typing import List, Tuple import numpy as np import pytest import torch from scipy.optimize import linear_sum_assignment as scipy_linear_sum_assignment -from typing import List, Tuple -import math from nemo.collections.asr.data.audio_to_label import repeat_signal from nemo.collections.asr.parts.utils.longform_clustering import LongFormSpeakerClustering @@ -83,6 +83,7 @@ def matrix(mat, use_tensor=True, dtype=torch.long): mat = np.array(mat) return mat + def __get_subsegments(offset: float, window: float, shift: float, duration: float) -> List[List[float]]: """ Return subsegments from a segment of audio file @@ -109,8 +110,7 @@ def __get_subsegments(offset: float, window: float, shift: float, duration: floa def generate_orthogonal_embs(total_spks, perturb_sigma, emb_dim): - """Generate a set of artificial orthogonal embedding vectors from random numbers - """ + """Generate a set of artificial orthogonal embedding vectors from random numbers""" gaus = torch.randn(emb_dim, emb_dim) _svd = torch.linalg.svd(gaus) orth = _svd[0] @ _svd[2] @@ -144,7 +144,9 @@ def generate_toy_data( seg_list.extend(segments) emb_list.append(emb) if emb.shape[0] == 0: - import ipdb; ipdb.set_trace() + import ipdb + + ipdb.set_trace() multiscale_segment_counts[scale_idx] += emb.shape[0] if scale_idx == len(multiscale_segment_counts) - 1: @@ -159,8 +161,7 @@ def generate_toy_data( class TestDiarizationSequneceUtilFunctions: - """Tests diarization and speaker-task related utils. - """ + """Tests diarization and speaker-task related utils.""" @pytest.mark.unit @pytest.mark.parametrize("Y", [[3, 3, 3, 4, 4, 5], [100, 100, 100, 104, 104, 1005]]) @@ -307,7 +308,10 @@ def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity): em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=10) em_s, ts_s = split_input_data(em, ts, mc) merged_embs, merged_clus_labels, _ = run_reducer( - pre_embs=em_s[-1], target_spk_idx=target_speaker_index, merge_quantity=merge_quantity, pre_clus_labels=gt, + pre_embs=em_s[-1], + target_spk_idx=target_speaker_index, + merge_quantity=merge_quantity, + pre_clus_labels=gt, ) assert (torch.sum(gt == target_speaker_index).item() - merge_quantity) == merged_clus_labels.shape[0] @@ -316,7 +320,11 @@ def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 70 + [1] * 32)]) @pytest.mark.parametrize("mspb", [25]) def test_merge_scheduler_2clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([3, 0])) @pytest.mark.unit @@ -324,7 +332,11 @@ def test_merge_scheduler_2clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 80 + [1] * 35 + [2] * 32)]) @pytest.mark.parametrize("mspb", [0, 25]) def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([3, 0, 0])) @pytest.mark.unit @@ -332,7 +344,11 @@ def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([2] * 70 + [0] * 32 + [1] * 27 + [3] * 3)]) @pytest.mark.parametrize("mspb", [3, 10]) def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([18, 13, 56, 0])) @pytest.mark.unit @@ -340,7 +356,11 @@ def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 5 + [1] * 4 + [2] * 3)]) @pytest.mark.parametrize("mspb", [0, 2]) def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([2, 1, 0])) @pytest.mark.unit @@ -348,7 +368,11 @@ def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 7 + [1] * 5 + [2] * 3 + [3] * 5)]) @pytest.mark.parametrize("mspb", [2]) def test_merge_scheduler_3clus_repeat(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([2, 0, 0, 0])) @@ -384,6 +408,7 @@ def test_online_speaker_clustering_instance_export(self): offline_speaker_clustering = torch.jit.script(offline_speaker_clustering) isinstance(offline_speaker_clustering, torch.jit._script.RecursiveScriptClass) + class TestGetSubsegments: @pytest.mark.unit @pytest.mark.parametrize( @@ -392,7 +417,18 @@ class TestGetSubsegments: (12.05, 1.5, 0.75, 2.4, 0.01, 2, False, 16000, 100, [[12.05, 1.5], [12.8, 1.5], [13.55, 0.9]]), (0, 1.0, 0.5, 0.4, 0.01, 2, False, 16000, 100, [[0, 0.4]]), (0, 2.0, 1.0, 1.5, 0.5, 2, False, 16000, 100, [[0, 1.5]]), - (10, 1.5, 0.75, 4.5, 0.5, 2, False, 16000, 100, [[10, 1.5], [10.75, 1.5], [11.5, 1.5], [12.25, 1.5], [13.0, 1.5]]), + ( + 10, + 1.5, + 0.75, + 4.5, + 0.5, + 2, + False, + 16000, + 100, + [[10, 1.5], [10.75, 1.5], [11.5, 1.5], [12.25, 1.5], [13.0, 1.5]], + ), (0, 1.5, 0.5, 0.3, 0.01, 2, True, 16000, 100, [[0, 0.3]]), ], ) @@ -409,7 +445,7 @@ def test_get_subsegments( feat_per_sec, expected, ): - + for is_scriptable in [True, False]: if is_scriptable: result = get_subsegments_scriptable( @@ -534,13 +570,21 @@ def test_is_overlap_false(self, rangeA, rangeB): @pytest.mark.parametrize("x", [1.0, 2.3456]) @pytest.mark.parametrize("decimals", [1, 2, 3, 4]) def test_fl2int(self, x, decimals): - assert fl2int(x, decimals) == round(x * 10 ** decimals, 0) + assert fl2int(x, decimals) == round(x * 10**decimals, 0) @pytest.mark.unit @pytest.mark.parametrize("x", [1234]) - @pytest.mark.parametrize("decimals", [1, 2, 3, 4,]) + @pytest.mark.parametrize( + "decimals", + [ + 1, + 2, + 3, + 4, + ], + ) def test_int2fl(self, x, decimals): - assert abs(int2fl(x, decimals) - round(x / (10 ** decimals), decimals)) < (10 ** -(decimals + 1)) + assert abs(int2fl(x, decimals) - round(x / (10**decimals), decimals)) < (10 ** -(decimals + 1)) @pytest.mark.unit def test_merge_float_intervals_edge_margin_test(self): @@ -582,7 +626,11 @@ def test_get_speech_labels_for_update(self): vad_timestamps = torch.tensor([[0.9600, 4.8400]]) cursor_for_old_segments = 1.0 speech_labels_for_update, cumulative_speech_labels = get_speech_labels_for_update( - frame_start, buffer_end, cumulative_speech_labels, vad_timestamps, cursor_for_old_segments, + frame_start, + buffer_end, + cumulative_speech_labels, + vad_timestamps, + cursor_for_old_segments, ) assert (speech_labels_for_update - torch.tensor([[1.0000, 3.7600]])).sum() < 1e-8 assert (cumulative_speech_labels - torch.tensor([[0.9600, 4.8400]])).sum() < 1e-8 @@ -652,7 +700,10 @@ def test_tensor_to_list(self, source_range_list): @pytest.mark.unit @pytest.mark.parametrize( "buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate", - [(0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000), (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000),], + [ + (0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000), + (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000), + ], ) def test_get_online_segments_from_slices( self, buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate @@ -785,7 +836,13 @@ def test_offline_speaker_clustering_cpu(self, n_spks, total_sec, SSV, perturb_si @pytest.mark.parametrize("SSV, enhanced_count_thres, min_samples_for_nmesc", [(5, 40, 6)]) @pytest.mark.parametrize("seed", [0]) def test_offline_speaker_clustering_very_short_cpu( - self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed, + self, + n_spks, + spk_dur, + SSV, + enhanced_count_thres, + min_samples_for_nmesc, + seed, ): em, ts, mc, mw, spk_ts, gt = generate_toy_data( n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed @@ -817,7 +874,13 @@ def test_offline_speaker_clustering_very_short_cpu( @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(1, 5, 40, 6)]) @pytest.mark.parametrize("seed", [0]) def test_offline_speaker_clustering_very_short_gpu( - self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed, + self, + n_spks, + spk_dur, + SSV, + enhanced_count_thres, + min_samples_for_nmesc, + seed, ): em, ts, mc, mw, spk_ts, gt = generate_toy_data( n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed @@ -1028,7 +1091,7 @@ def test_linear_sum_assignment_algorithm_cost_matrix(self, cost_matrix): Test the linear sum assignment algorithm with a cost matrix Compare with the scipy implementation and make sure the final cost is the same. - NOTE: There could be multiple solutions with the same cost in linear sum assignment problem. + NOTE: There could be multiple solutions with the same cost in linear sum assignment problem. This test only checks if the cost is the same. """ row_ind_nm, col_ind_nm = nemo_linear_sum_assignment(cost_matrix) diff --git a/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py b/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py index 1c45603643d3..9000cb82b842 100644 --- a/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py +++ b/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py @@ -12,23 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import os import numpy as np import pytest import torch -import itertools from omegaconf import DictConfig from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( - find_first_nonzero, find_best_permutation, - reconstruct_labels, + find_first_nonzero, get_ats_targets, - get_pil_targets, get_hidden_length_from_sample_length, + get_pil_targets, + reconstruct_labels, ) + def reconstruct_labels_forloop(labels: torch.Tensor, batch_perm_inds: torch.Tensor) -> torch.Tensor: """ This is a for-loop implementation of reconstruct_labels built for testing purposes. @@ -41,6 +42,7 @@ def reconstruct_labels_forloop(labels: torch.Tensor, batch_perm_inds: torch.Tens reconstructed_labels = torch.gather(labels, 2, batch_perm_inds_exp) return reconstructed_labels + class TestSortingUtils: @pytest.mark.unit @pytest.mark.parametrize( @@ -69,15 +71,18 @@ class TestSortingUtils: # Test 11: Matrix with 101 columns, first nonzero value is towards the end (torch.cat([torch.zeros(1, 100), torch.ones(1, 1)], dim=1), -1, 0.5, torch.tensor([100])), # Test 12: Matrix with 1000 columns, all below threshold except one near the middle - (torch.cat([torch.zeros(1, 499), torch.tensor([[0.6]]), torch.zeros(1, 500)], dim=1), -1, 0.5, torch.tensor([499])), - - ] + ( + torch.cat([torch.zeros(1, 499), torch.tensor([[0.6]]), torch.zeros(1, 500)], dim=1), + -1, + 0.5, + torch.tensor([499]), + ), + ], ) def test_find_first_nonzero(self, mat, max_cap_val, thres, expected): result = find_first_nonzero(mat, max_cap_val, thres) assert torch.equal(result, expected), f"Expected {expected} but got {result}" - - + @pytest.mark.unit @pytest.mark.parametrize( "match_score, speaker_permutations, expected", @@ -86,51 +91,57 @@ def test_find_first_nonzero(self, mat, max_cap_val, thres, expected): ( torch.tensor([[0.1, 0.9, 0.2]]), # match_score (batch_size=1, num_permutations=3) torch.tensor([[0, 1], [1, 0], [0, 1]]), # speaker_permutations (num_permutations=3, num_speakers=2) - torch.tensor([[1, 0]]) # expected best permutation for the batch + torch.tensor([[1, 0]]), # expected best permutation for the batch ), # Test 2: Batch size 2, different best matches for each batch ( torch.tensor([[0.5, 0.3, 0.7], [0.2, 0.6, 0.4]]), # match_score (batch_size=2, num_permutations=3) torch.tensor([[0, 1], [1, 0], [0, 1]]), # speaker_permutations - torch.tensor([[0, 1], [1, 0]]) # expected best permutations + torch.tensor([[0, 1], [1, 0]]), # expected best permutations ), # Test 3: Larger number of speakers and permutations ( - torch.tensor([[0.1, 0.4, 0.9, 0.5], [0.6, 0.3, 0.7, 0.2]]), # match_score (batch_size=2, num_permutations=4) - torch.tensor([[0, 1, 2], [1, 0, 2], [2, 1, 0], [1, 2, 0]]), # speaker_permutations (num_permutations=4, num_speakers=3) - torch.tensor([[2, 1, 0], [2, 1, 0]]) # expected best permutations + torch.tensor( + [[0.1, 0.4, 0.9, 0.5], [0.6, 0.3, 0.7, 0.2]] + ), # match_score (batch_size=2, num_permutations=4) + torch.tensor( + [[0, 1, 2], [1, 0, 2], [2, 1, 0], [1, 2, 0]] + ), # speaker_permutations (num_permutations=4, num_speakers=3) + torch.tensor([[2, 1, 0], [2, 1, 0]]), # expected best permutations ), # Test 4: All match scores are the same, should pick the first permutation (argmax behavior) ( torch.tensor([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]), # equal match_score across permutations torch.tensor([[0, 1], [1, 0], [0, 1]]), # speaker_permutations - torch.tensor([[0, 1], [0, 1]]) # first permutation is chosen as tie-breaker + torch.tensor([[0, 1], [0, 1]]), # first permutation is chosen as tie-breaker ), # Test 5: Single speaker case (num_speakers = 1) ( torch.tensor([[0.8, 0.2]]), # match_score (batch_size=1, num_permutations=2) torch.tensor([[0], [0]]), # speaker_permutations (num_permutations=2, num_speakers=1) - torch.tensor([[0]]) # expected best permutation + torch.tensor([[0]]), # expected best permutation ), # Test 6: Batch size 3, varying permutations ( torch.tensor([[0.3, 0.6], [0.4, 0.1], [0.2, 0.7]]), # match_score (batch_size=3, num_permutations=2) torch.tensor([[0, 1], [1, 0]]), # speaker_permutations - torch.tensor([[1, 0], [0, 1], [1, 0]]) # expected best permutations for each batch + torch.tensor([[1, 0], [0, 1], [1, 0]]), # expected best permutations for each batch ), - ] + ], ) def test_find_best_permutation(self, match_score, speaker_permutations, expected): result = find_best_permutation(match_score, speaker_permutations) assert torch.equal(result, expected), f"Expected {expected} but got {result}" - - @pytest.mark.parametrize("batch_size, num_frames, num_speakers", [ - (2, 4, 3), # Original test case - (3, 5, 2), # More frames and speakers - (1, 6, 4), # Single batch with more frames and speakers - (5, 3, 5), # More batch size with equal frames and speakers - ]) + @pytest.mark.parametrize( + "batch_size, num_frames, num_speakers", + [ + (2, 4, 3), # Original test case + (3, 5, 2), # More frames and speakers + (1, 6, 4), # Single batch with more frames and speakers + (5, 3, 5), # More batch size with equal frames and speakers + ], + ) def test_reconstruct_labels_with_forloop_ver(self, batch_size, num_frames, num_speakers): # Generate random labels and batch_perm_inds tensor for testing labels = torch.rand(batch_size, num_frames, num_speakers) @@ -142,49 +153,58 @@ def test_reconstruct_labels_with_forloop_ver(self, batch_size, num_frames, num_s # Assert that both methods return the same result assert torch.allclose(result_matrix, result_forloop), "The results are not equal!" - - @pytest.mark.parametrize("labels, batch_perm_inds, expected_output", [ - # Example 1: Small batch size with a few frames and speakers - ( - torch.tensor([ - [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], # First batch - [[0.9, 0.8, 0.7], [0.6, 0.5, 0.4], [0.3, 0.2, 0.1]] # Second batch - ]), - torch.tensor([[2, 0, 1], [1, 2, 0]]), - torch.tensor([ - [[0.3, 0.1, 0.2], [0.6, 0.4, 0.5], [0.9, 0.7, 0.8]], # First batch reconstructed - [[0.8, 0.7, 0.9], [0.5, 0.4, 0.6], [0.2, 0.1, 0.3]] # Second batch reconstructed - ]) - ), - - # Example 2: batch_size = 1 with more frames and speakers - ( - torch.tensor([ - [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2], [1.3, 1.4, 1.5, 1.6]] - ]), - torch.tensor([[3, 0, 1, 2]]), - torch.tensor([ - [[0.4, 0.1, 0.2, 0.3], [0.8, 0.5, 0.6, 0.7], [1.2, 0.9, 1.0, 1.1], [1.6, 1.3, 1.4, 1.5]] - ]) - ), - - # Example 3: Larger batch size with fewer frames and speakers - ( - torch.tensor([ - [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], # First batch - [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]], # Second batch - [[1.3, 1.4], [1.5, 1.6], [1.7, 1.8]], # Third batch - [[1.9, 2.0], [2.1, 2.2], [2.3, 2.4]] # Fourth batch - ]), - torch.tensor([[1, 0], [0, 1], [1, 0], [0, 1]]), - torch.tensor([ - [[0.2, 0.1], [0.4, 0.3], [0.6, 0.5]], # First batch reconstructed - [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]], # Second batch unchanged - [[1.4, 1.3], [1.6, 1.5], [1.8, 1.7]], # Third batch reconstructed - [[1.9, 2.0], [2.1, 2.2], [2.3, 2.4]] # Fourth batch unchanged - ]) - ) - ]) + + @pytest.mark.parametrize( + "labels, batch_perm_inds, expected_output", + [ + # Example 1: Small batch size with a few frames and speakers + ( + torch.tensor( + [ + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], # First batch + [[0.9, 0.8, 0.7], [0.6, 0.5, 0.4], [0.3, 0.2, 0.1]], # Second batch + ] + ), + torch.tensor([[2, 0, 1], [1, 2, 0]]), + torch.tensor( + [ + [[0.3, 0.1, 0.2], [0.6, 0.4, 0.5], [0.9, 0.7, 0.8]], # First batch reconstructed + [[0.8, 0.7, 0.9], [0.5, 0.4, 0.6], [0.2, 0.1, 0.3]], # Second batch reconstructed + ] + ), + ), + # Example 2: batch_size = 1 with more frames and speakers + ( + torch.tensor( + [[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2], [1.3, 1.4, 1.5, 1.6]]] + ), + torch.tensor([[3, 0, 1, 2]]), + torch.tensor( + [[[0.4, 0.1, 0.2, 0.3], [0.8, 0.5, 0.6, 0.7], [1.2, 0.9, 1.0, 1.1], [1.6, 1.3, 1.4, 1.5]]] + ), + ), + # Example 3: Larger batch size with fewer frames and speakers + ( + torch.tensor( + [ + [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], # First batch + [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]], # Second batch + [[1.3, 1.4], [1.5, 1.6], [1.7, 1.8]], # Third batch + [[1.9, 2.0], [2.1, 2.2], [2.3, 2.4]], # Fourth batch + ] + ), + torch.tensor([[1, 0], [0, 1], [1, 0], [0, 1]]), + torch.tensor( + [ + [[0.2, 0.1], [0.4, 0.3], [0.6, 0.5]], # First batch reconstructed + [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]], # Second batch unchanged + [[1.4, 1.3], [1.6, 1.5], [1.8, 1.7]], # Third batch reconstructed + [[1.9, 2.0], [2.1, 2.2], [2.3, 2.4]], # Fourth batch unchanged + ] + ), + ), + ], + ) def test_reconstruct_labels(self, labels, batch_perm_inds, expected_output): # Call the reconstruct_labels function result = reconstruct_labels(labels, batch_perm_inds) @@ -192,44 +212,51 @@ def test_reconstruct_labels(self, labels, batch_perm_inds, expected_output): assert torch.allclose(result, expected_output), f"Expected {expected_output}, but got {result}" - class TestTargetGenerators: - @pytest.mark.parametrize("labels, preds, num_speakers, expected_output", [ - # Test 1: Basic case with simple permutations - ( - torch.tensor([ - [[0.9, 0.1, 0.0], [0.1, 0.8, 0.0], [0.0, 0.1, 0.9]], # Batch 1 - [[0.0, 0.0, 0.9], [0.0, 0.9, 0.1], [0.9, 0.1, 0.0]] # Batch 2 - ]), - torch.tensor([ - [[0.8, 0.2, 0.0], [0.2, 0.7, 0.0], [0.0, 0.1, 0.9]], # Batch 1 - [[0.0, 0.0, 0.8], [0.0, 0.8, 0.2], [0.9, 0.1, 0.0]] # Batch 2 - ]), - 3, # Number of speakers - torch.tensor([ - [[0.9, 0.1, 0.0], [0.1, 0.8, 0.0], [0.0, 0.1, 0.9]], # Expected labels for Batch 1 - [[0.9, 0.0, 0.0], [0.1, 0.9, 0.0], [0.0, 0.1, 0.9]] # Expected labels for Batch 2 - ]) - ), - - # Test 2: Ambiguous case - ( - torch.tensor([[[0.9, 0.8, 0.7], [0.2, 0.8, 0.7], [0.2, 0.3, 0.9]]]), # Labels - torch.tensor([[[0.6, 0.7, 0.2], [0.9, 0.4, 0.0], [0.1, 0.7, 0.1]]]), # Preds - 3, # Number of speakers - torch.tensor([[[0.8, 0.7, 0.9], [0.8, 0.7, 0.2], [0.3, 0.9, 0.2]]]) # Expected output - ), - - # Test 3: Ambiguous case - ( - torch.tensor([[[0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]]), # Labels - torch.tensor([[[0.6, 0.6, 0.1, 0.9], [0.7, 0.7, 0.2, 0.8], [0.4, 0.6, 0.2, 0.7], [0.1, 0.1, 0.1, 0.7]]]), # Preds - 4, # Number of speakers - torch.tensor([[[1, 1, 0, 0], [1, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 0]]]) # Expected output - ) - - ]) + @pytest.mark.parametrize( + "labels, preds, num_speakers, expected_output", + [ + # Test 1: Basic case with simple permutations + ( + torch.tensor( + [ + [[0.9, 0.1, 0.0], [0.1, 0.8, 0.0], [0.0, 0.1, 0.9]], # Batch 1 + [[0.0, 0.0, 0.9], [0.0, 0.9, 0.1], [0.9, 0.1, 0.0]], # Batch 2 + ] + ), + torch.tensor( + [ + [[0.8, 0.2, 0.0], [0.2, 0.7, 0.0], [0.0, 0.1, 0.9]], # Batch 1 + [[0.0, 0.0, 0.8], [0.0, 0.8, 0.2], [0.9, 0.1, 0.0]], # Batch 2 + ] + ), + 3, # Number of speakers + torch.tensor( + [ + [[0.9, 0.1, 0.0], [0.1, 0.8, 0.0], [0.0, 0.1, 0.9]], # Expected labels for Batch 1 + [[0.9, 0.0, 0.0], [0.1, 0.9, 0.0], [0.0, 0.1, 0.9]], # Expected labels for Batch 2 + ] + ), + ), + # Test 2: Ambiguous case + ( + torch.tensor([[[0.9, 0.8, 0.7], [0.2, 0.8, 0.7], [0.2, 0.3, 0.9]]]), # Labels + torch.tensor([[[0.6, 0.7, 0.2], [0.9, 0.4, 0.0], [0.1, 0.7, 0.1]]]), # Preds + 3, # Number of speakers + torch.tensor([[[0.8, 0.7, 0.9], [0.8, 0.7, 0.2], [0.3, 0.9, 0.2]]]), # Expected output + ), + # Test 3: Ambiguous case + ( + torch.tensor([[[0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]]), # Labels + torch.tensor( + [[[0.6, 0.6, 0.1, 0.9], [0.7, 0.7, 0.2, 0.8], [0.4, 0.6, 0.2, 0.7], [0.1, 0.1, 0.1, 0.7]]] + ), # Preds + 4, # Number of speakers + torch.tensor([[[1, 1, 0, 0], [1, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 0]]]), # Expected output + ), + ], + ) def test_get_ats_targets(self, labels, preds, num_speakers, expected_output): # Generate all permutations for the given number of speakers speaker_inds = list(range(num_speakers)) @@ -240,35 +267,40 @@ def test_get_ats_targets(self, labels, preds, num_speakers, expected_output): # Assert that the result matches the expected output assert torch.allclose(result, expected_output), f"Expected {expected_output}, but got {result}" - @pytest.mark.unit @pytest.mark.parametrize( "labels, preds, num_speakers, expected_output", [ # Test 1: Basic case with simple permutations ( - torch.tensor([[[1, 0], [0, 1]], [[1, 0], [0, 1]]]), # Labels (batch_size=2, num_speakers=2, num_classes=2) - torch.tensor([[[1, 0], [0, 1]], [[0, 1], [1, 0]]]), # Preds (batch_size=2, num_speakers=2, num_classes=2) - 2, # Number of speakers - torch.tensor([[[1, 0], [0, 1]], [[0, 1], [1, 0]]]) # expected max_score_permed_labels + torch.tensor( + [[[1, 0], [0, 1]], [[1, 0], [0, 1]]] + ), # Labels (batch_size=2, num_speakers=2, num_classes=2) + torch.tensor( + [[[1, 0], [0, 1]], [[0, 1], [1, 0]]] + ), # Preds (batch_size=2, num_speakers=2, num_classes=2) + 2, # Number of speakers + torch.tensor([[[1, 0], [0, 1]], [[0, 1], [1, 0]]]), # expected max_score_permed_labels ), - # Test 2: Batch size 1 with more complex permutations ( torch.tensor([[[0.8, 0.2], [0.3, 0.7]]]), # Labels torch.tensor([[[0.9, 0.1], [0.2, 0.8]]]), # Preds - 2, # Number of speakers - torch.tensor([[[0.8, 0.2], [0.3, 0.7]]]) # expected output (labels remain the same as preds are close) + 2, # Number of speakers + torch.tensor( + [[[0.8, 0.2], [0.3, 0.7]]] + ), # expected output (labels remain the same as preds are close) ), - # Test 3: Ambiguous case ( torch.tensor([[[0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]]), # Labels - torch.tensor([[[0.61, 0.6, 0.1, 0.9], [0.7, 0.7, 0.2, 0.8], [0.4, 0.6, 0.2, 0.7], [0.1, 0.1, 0.1, 0.7]]]), # Preds - 4, # Number of speakers - torch.tensor([[[1, 0, 0, 1], [1, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]]]) # Expected output - ) - ] + torch.tensor( + [[[0.61, 0.6, 0.1, 0.9], [0.7, 0.7, 0.2, 0.8], [0.4, 0.6, 0.2, 0.7], [0.1, 0.1, 0.1, 0.7]]] + ), # Preds + 4, # Number of speakers + torch.tensor([[[1, 0, 0, 1], [1, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]]]), # Expected output + ), + ], ) def test_get_pil_targets(self, labels, preds, num_speakers, expected_output): # Generate all permutations for the given number of speakers @@ -289,10 +321,14 @@ class TestGetHiddenLengthFromSampleLength: (159, 160, 8, 1), (129, 100, 5, 1), (300, 150, 3, 1), - ] + ], ) - def test_various_cases(self, num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame, expected_hidden_length): - result = get_hidden_length_from_sample_length(num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame) + def test_various_cases( + self, num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame, expected_hidden_length + ): + result = get_hidden_length_from_sample_length( + num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame + ) assert result == expected_hidden_length def test_default_parameters(self): @@ -313,8 +349,8 @@ def test_real_life_examples(self): assert get_hidden_length_from_sample_length(159999) == 125 assert get_hidden_length_from_sample_length(158720) == 125 assert get_hidden_length_from_sample_length(158719) == 124 - + assert get_hidden_length_from_sample_length(158880) == 125 assert get_hidden_length_from_sample_length(158879) == 125 assert get_hidden_length_from_sample_length(1600) == 2 - assert get_hidden_length_from_sample_length(1599) == 2 \ No newline at end of file + assert get_hidden_length_from_sample_length(1599) == 2 From c3c0b324ae65130642b259dcf0cbe65620ebc662 Mon Sep 17 00:00:00 2001 From: taejinp Date: Wed, 20 Nov 2024 16:03:15 -0800 Subject: [PATCH 26/47] Fixing code-QL issues Signed-off-by: taejinp --- .../collections/speaker_tasks/test_diar_datasets.py | 13 +++---------- .../speaker_tasks/test_diar_lhotse_datasets.py | 4 +++- .../speaker_tasks/test_diar_sortformer_models.py | 2 -- .../speaker_tasks/utils/test_diar_utils.py | 2 +- .../speaker_tasks/utils/test_multispeaker_utils.py | 4 ---- 5 files changed, 7 insertions(+), 18 deletions(-) diff --git a/tests/collections/speaker_tasks/test_diar_datasets.py b/tests/collections/speaker_tasks/test_diar_datasets.py index ea57653b9324..4c6426a97889 100644 --- a/tests/collections/speaker_tasks/test_diar_datasets.py +++ b/tests/collections/speaker_tasks/test_diar_datasets.py @@ -11,25 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy -import filecmp + import json import os -import shutil import tempfile -from unittest import mock - -import numpy as np import pytest -import soundfile as sf import torch.cuda -from omegaconf import DictConfig, OmegaConf -from torch.utils.data import DataLoader from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer from nemo.collections.asr.parts.utils.speaker_utils import ( - get_offset_and_duration, get_vad_out_from_rttm_line, read_rttm_lines, ) @@ -90,6 +81,8 @@ def test_e2e_speaker_diar_dataset(self, test_data_dir): soft_targets=False, ) + dataset = dataset.to(device) + dataloader_instance = torch.utils.data.DataLoader( dataset=dataset, batch_size=batch_size, diff --git a/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py b/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py index 82c1e8877ca2..8618a6129bf1 100644 --- a/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py +++ b/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py @@ -140,6 +140,7 @@ def test_e2e_speaker_diar_lhotse_dataset(self, test_data_dir, batch_size, num_wo data_dict_list.append(data_dict) f.seek(0) + config = None if split == 'train': config = get_train_ds_config(manifest_filepath=f.name, batch_size=batch_size, num_workers=num_workers) elif split == 'validation': @@ -153,7 +154,8 @@ def test_e2e_speaker_diar_lhotse_dataset(self, test_data_dir, batch_size, num_wo world_size=1, dataset=LhotseAudioToSpeechE2ESpkDiarDataset(cfg=config), ) - + dataloader_instance = dataloader_instance.to(device) + deviation_thres_rate = 0.01 # 1% deviation allowed for batch_index, batch in enumerate(dataloader_instance): audio_signals, audio_signal_len, targets, target_lens = batch diff --git a/tests/collections/speaker_tasks/test_diar_sortformer_models.py b/tests/collections/speaker_tasks/test_diar_sortformer_models.py index db5c3962521b..6966c56ade86 100644 --- a/tests/collections/speaker_tasks/test_diar_sortformer_models.py +++ b/tests/collections/speaker_tasks/test_diar_sortformer_models.py @@ -22,7 +22,6 @@ @pytest.fixture() def sortformer_model(): - batch_size = 4 model = { 'pil_weight': 0.5, 'ats_weight': 0.5, @@ -154,7 +153,6 @@ def test_forward_infer(self, sortformer_model, batch_size, frame_length, sample_ input_signal = torch.randn(size=(batch_size, sample_len * sampling_rate)) input_signal_length = (sample_len * sampling_rate) * torch.ones(batch_size, dtype=torch.int) targets = torch.randint(2, size=(batch_size, target_frame_count, num_spks), dtype=torch.int) - target_len = target_frame_count * torch.ones(batch_size, dtype=torch.int) with torch.no_grad(): # batch size 1 diff --git a/tests/collections/speaker_tasks/utils/test_diar_utils.py b/tests/collections/speaker_tasks/utils/test_diar_utils.py index f70f8006e8f3..a9d68447dd97 100644 --- a/tests/collections/speaker_tasks/utils/test_diar_utils.py +++ b/tests/collections/speaker_tasks/utils/test_diar_utils.py @@ -14,7 +14,7 @@ import math import os -from typing import List, Tuple +from typing import List import numpy as np import pytest diff --git a/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py b/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py index 9000cb82b842..2e01cf4b94da 100644 --- a/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py +++ b/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py @@ -13,12 +13,8 @@ # limitations under the License. import itertools -import os - -import numpy as np import pytest import torch -from omegaconf import DictConfig from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( find_best_permutation, From 631555d28cbb211d88e047cfae42822e5f4c5184 Mon Sep 17 00:00:00 2001 From: tango4j Date: Thu, 21 Nov 2024 00:04:15 +0000 Subject: [PATCH 27/47] Apply isort and black reformatting Signed-off-by: tango4j --- tests/collections/speaker_tasks/test_diar_datasets.py | 8 +++----- .../speaker_tasks/test_diar_lhotse_datasets.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/collections/speaker_tasks/test_diar_datasets.py b/tests/collections/speaker_tasks/test_diar_datasets.py index 4c6426a97889..b030fb46dd7c 100644 --- a/tests/collections/speaker_tasks/test_diar_datasets.py +++ b/tests/collections/speaker_tasks/test_diar_datasets.py @@ -15,15 +15,13 @@ import json import os import tempfile + import pytest import torch.cuda from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer -from nemo.collections.asr.parts.utils.speaker_utils import ( - get_vad_out_from_rttm_line, - read_rttm_lines, -) +from nemo.collections.asr.parts.utils.speaker_utils import get_vad_out_from_rttm_line, read_rttm_lines def is_rttm_length_too_long(rttm_file_path, wav_len_in_sec): @@ -82,7 +80,7 @@ def test_e2e_speaker_diar_dataset(self, test_data_dir): ) dataset = dataset.to(device) - + dataloader_instance = torch.utils.data.DataLoader( dataset=dataset, batch_size=batch_size, diff --git a/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py b/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py index 8618a6129bf1..1d48cc18b641 100644 --- a/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py +++ b/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py @@ -155,7 +155,7 @@ def test_e2e_speaker_diar_lhotse_dataset(self, test_data_dir, batch_size, num_wo dataset=LhotseAudioToSpeechE2ESpkDiarDataset(cfg=config), ) dataloader_instance = dataloader_instance.to(device) - + deviation_thres_rate = 0.01 # 1% deviation allowed for batch_index, batch in enumerate(dataloader_instance): audio_signals, audio_signal_len, targets, target_lens = batch From 6a3bb6246800e09bded9238baf9eba17d2f03b44 Mon Sep 17 00:00:00 2001 From: taejinp Date: Wed, 20 Nov 2024 17:47:24 -0800 Subject: [PATCH 28/47] Reflecting PR comments from weiqingw Signed-off-by: taejinp --- ...ortformer_diar_4spk-v1_callhome-part1.yaml | 4 +- ... sortformer_diar_4spk-v1_dihard3-dev.yaml} | 4 +- .../neural_diarizer/e2e_diarize_speech.py | 27 ++++++++---- .../asr/data/audio_to_diar_label_lhotse.py | 4 +- nemo/collections/asr/losses/bce_loss.py | 6 +-- .../asr/metrics/multi_binary_acc.py | 8 ++-- .../asr/models/sortformer_diar_models.py | 10 +---- .../asr/modules/sortformer_modules.py | 2 +- .../asr/parts/utils/asr_multispeaker_utils.py | 17 +++++--- .../asr/parts/utils/speaker_utils.py | 42 +++++++------------ .../common/parts/preprocessing/collections.py | 6 +-- 11 files changed, 63 insertions(+), 67 deletions(-) rename examples/speaker_tasks/diarization/conf/post_processing/{sortformer_diar_4spk-v1_dihard-dev.yaml => sortformer_diar_4spk-v1_dihard3-dev.yaml} (73%) diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml index ebed4a649730..59bd533632d8 100644 --- a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml @@ -1,8 +1,8 @@ # Postprocessing parameters for timestamp outputs from speaker diarization models. # This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper: # Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). -# These parameters were optimized with with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656. -# These parameters were optimized on the development split of DIHARD3 dataset. See https://arxiv.org/pdf/2012.01477. +# These parameters were optimized with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656. +# These parameters were optimized on CallHome Dataset from the NIST SRE 2000 Disc8, especially from the split2 specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/tree/master/egs/callhome_diarization/v2. # Trial 24682 finished with value: 0.10257785779242055 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.01, 'min_duration_on': 0.42, 'min_duration_off': 0.34}. Best is trial 24682 with value: 0.10257785779242055. parameters: onset: 0.53 # Onset threshold for detecting the beginning and end of a speech diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard3-dev.yaml similarity index 73% rename from examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml rename to examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard3-dev.yaml index 9beaff6e3c7c..ebf994c10f2e 100644 --- a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard3-dev.yaml @@ -1,8 +1,8 @@ # Postprocessing parameters for timestamp outputs from speaker diarization models. # This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper: # Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). -# These parameters were optimized with with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656. -# These parameters were optimized on CallHome Dataset from the NIST SRE 2000 Disc8, especially from the split2 specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/tree/master/egs/callhome_diarization/v2. +# These parameters were optimized with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656. +# These parameters were optimized on the development split of DIHARD3 dataset (See https://arxiv.org/pdf/2012.01477). # Trial 732 finished with value: 0.12171946949255649 and parameters: {'onset': 0.64, 'offset': 0.74, 'pad_onset': 0.06, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.15}. Best is trial 732 with value: 0.12171946949255649. parameters: onset: 0.64 # Onset threshold for detecting the beginning and end of a speech diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 65ba0226988a..28e2a94c7ffc 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -11,11 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -python $BASEPATH/neural_diarizer/sortformer_diarization.py \ - model_path=/path/to/sortformer_model.nemo \ - batch_size=4 \ + +""" +Usage: +End-to-end speaker diarization model can be specified by either "model_path" or "pretrained_name". +Data for diarization is fed through "dataset_manifest". +By default, post-processing is bypassed and only binarization is performed. +If you want to reproduce DER scores, you need to apply post-processing steps. +Use batch_size = 1 to have the longest inference window and the highest possible accuracy. + +python $BASEPATH/neural_diarizer/e2e_diarize_speech.py \ + model_path=/path/to/diar_sortformer_spk_v.nemo \ + batch_size=1 \ dataset_manifest=/path/to/diarization_path_to_manifest.json """ @@ -192,7 +200,7 @@ def diarization_objective( infer_audio_rttm_dict (Dict[str, Dict[str, str]]): Dictionary containing audio file paths, offsets, durations, and RTTM file paths. diar_model_preds_total_list (List[torch.Tensor]): List of prediction matrices containing - sigmoid values for each speaker. Dimension: [(1, frames, num_speakers), ..., (1, frames, num_speakers)] + sigmoid values for each speaker. Dimension: [(1, num_frames, num_speakers), ..., (1, num_frames, num_speakers)] collar (float, optional): Collar in seconds for DER calculation. Defaults to 0.25. ignore_overlap (bool, optional): If True, DER will be calculated only for non-overlapping segments. Defaults to False. @@ -237,7 +245,7 @@ def run_optuna_hyperparam_search( postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. infer_audio_rttm_dict (dict): dictionary of audio file path, offset, duration and RTTM filepath. preds_list (List[torch.Tensor]): list of prediction matrices containing sigmoid values for each speaker. - Dimension: [(1, frames, num_speakers), ..., (1, frames, num_speakers)] + Dimension: [(1, num_frames, num_speakers), ..., (1, num_frames, num_speakers)] temp_out_dir (str): temporary directory for storing intermediate outputs. """ worker_function = lambda trial: diarization_objective( @@ -274,7 +282,7 @@ def convert_pred_mat_to_segments( Args: audio_rttm_map_dict (dict): dictionary of audio file path, offset, duration and RTTM filepath. batch_preds_list (List[torch.Tensor]): list of prediction matrices containing sigmoid values for each speaker. - Dimension: [(1, frames, num_speakers), ..., (1, frames, num_speakers)] + Dimension: [(1, num_frames, num_speakers), ..., (1, num_frames, num_speakers)] unit_10ms_frame_count (int, optional): number of 10ms segments in a frame. Defaults to 8. bypass_postprocessing (bool, optional): if True, postprocessing will be bypassed. Defaults to False. @@ -285,8 +293,9 @@ def convert_pred_mat_to_segments( """ batch_pred_ts_segs, all_hypothesis, all_reference, all_uems = [], [], [], [] cfg_vad_params = OmegaConf.structured(postprocessing_cfg) + pp_message = "Bypass PP, Running Binarization" if bypass_postprocessing else "Running post-processing" for sample_idx, (uniq_id, audio_rttm_values) in tqdm( - enumerate(audio_rttm_map_dict.items()), total=len(audio_rttm_map_dict), desc="Running post-processing" + enumerate(audio_rttm_map_dict.items()), total=len(audio_rttm_map_dict), desc=pp_message ): spk_ts = [] offset, duration = audio_rttm_values['offset'], audio_rttm_values['duration'] @@ -385,7 +394,7 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: diar_model.test_batch() diar_model_preds_total_list = diar_model.preds_total_list torch.save(diar_model.preds_total_list, tensor_path) - + if cfg.launch_pp_optim: # Launch a hyperparameter optimization process if launch_pp_optim is True run_optuna_hyperparam_search( diff --git a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py index 0839b63954f0..9e974f027b5e 100644 --- a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py +++ b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py @@ -27,7 +27,7 @@ class LhotseAudioToSpeechE2ESpkDiarDataset(torch.utils.data.Dataset): """ - This dataset is based on diarization datasets from audio_to_eesd_label.py. + This dataset is a Lhotse version of diarization dataset in audio_to_diar_label.py. Unlike native NeMo datasets, Lhotse dataset defines only the mapping from a CutSet (meta-data) to a mini-batch with PyTorch tensors. Specifically, it performs tokenization, I/O, augmentation, and feature extraction (if any). @@ -53,7 +53,7 @@ def __init__(self, cfg): self.num_speakers = self.cfg.get('num_speakers', 4) self.num_sample_per_mel_frame = int( self.cfg.get('window_stride', 0.01) * self.cfg.get('sample_rate', 16000) - ) # 160 + ) # 160 samples for every 1ms by default self.num_mel_frame_per_target_frame = int(self.cfg.get('subsampling_factor', 8)) self.spk_tar_all_zero = self.cfg.get('spk_tar_all_zero', False) diff --git a/nemo/collections/asr/losses/bce_loss.py b/nemo/collections/asr/losses/bce_loss.py index 83f6b57c0203..b24e08bae76c 100644 --- a/nemo/collections/asr/losses/bce_loss.py +++ b/nemo/collections/asr/losses/bce_loss.py @@ -105,9 +105,9 @@ def forward(self, probs, labels, target_lens): # Normalize loss by number of classes norm_weight = 1 / (labels.sum(dim=0) + self.eps) norm_weight_norm = norm_weight / norm_weight.sum() - norm_weight_norm2 = torch.clamp(norm_weight_norm, min=0.05, max=1.0) - norm_weight_norm2 = norm_weight_norm2 / norm_weight_norm2.max() - norm_weight = norm_weight_norm2[None, :].expand_as(labels).detach().clone() + norm_weight_norm = torch.clamp(norm_weight_norm, min=0.05, max=1.0) + norm_weight_norm = norm_weight_norm / norm_weight_norm.max() + norm_weight = norm_weight_norm[None, :].expand_as(labels).detach().clone() else: norm_weight = torch.ones_like(labels).detach().clone() diff --git a/nemo/collections/asr/metrics/multi_binary_acc.py b/nemo/collections/asr/metrics/multi_binary_acc.py index 3a99769ebd25..fda6891fe85f 100644 --- a/nemo/collections/asr/metrics/multi_binary_acc.py +++ b/nemo/collections/asr/metrics/multi_binary_acc.py @@ -92,9 +92,6 @@ def update( targets (torch.Tensor): Target values. signal_lengths (torch.Tensor): Length of each sequence in the batch input. cumulative (bool): Whether to accumulate the values over time. - - Returns: - f1_score (torch.Tensor): F1 score calculated from the predicted value and binarized target values. """ with torch.no_grad(): preds_list = [preds[k, : signal_lengths[k], :] for k in range(preds.shape[0])] @@ -125,6 +122,11 @@ def update( def compute(self): """ Compute F1 score from the accumulated values. Return -1 if the F1 score is NaN. + + Returns: + f1_score (torch.Tensor): F1 score calculated from the accumulated values. + precision (torch.Tensor): Precision calculated from the accumulated values. + recall (torch.Tensor): Recall calculated from the accumulated values. """ precision = self.true_positive_count / (self.true_positive_count + self.false_positive_count + self.eps) recall = self.true_positive_count / (self.true_positive_count + self.false_negative_count + self.eps) diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index 5a3c8e354f1b..fd6c3b1ac127 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -235,7 +235,7 @@ def frontend_encoder(self, processed_signal, processed_signal_length): Generate encoder outputs from frontend encoder. Args: - process_signal (torch.Tensor): tensor containing audio-feature (mel spectrogram, mfcc, etc.) + processed_signal (torch.Tensor): tensor containing audio-feature (mel spectrogram, mfcc, etc.) processed_signal_length (torch.Tensor): tensor containing lengths of audio signal in integers Returns: @@ -264,8 +264,6 @@ def forward_infer(self, emb_seq): Returns: preds (torch.Tensor): Sorted tensor containing Sigmoid values for predicted speaker labels. Dimension: (batch_size, diar_frame_count, num_speakers) - encoder_states_list (list): List containing total speaker memory for each step for debugging purposes - Dimension: [(batch_size, diar_frame_count, inner dim), ... ] """ encoder_mask = self.sortformer_modules.length_to_mask(emb_seq) trans_emb_seq = self.transformer_encoder(encoder_states=emb_seq, encoder_mask=encoder_mask) @@ -318,8 +316,6 @@ def forward( Returns: preds (torch.Tensor): Sorted tensor containing predicted speaker labels Dimension: (batch_size, diar_frame_count, num_speakers) - encoder_states_list (list): List containing total speaker memory for each step for debugging purposes - Dimension: [(batch_size, diar_frame_count, inner dim), ] """ processed_signal, processed_signal_length = self.process_signal( audio_signal=audio_signal, audio_signal_length=audio_signal_length @@ -387,7 +383,6 @@ def training_step(self, batch: list) -> dict: - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. - targets (torch.Tensor): The target labels for the batch. - target_lens (torch.Tensor): The length of each target sequence in the batch. - batch_idx (int): The index of the current batch. Returns: (dict): A dictionary containing the 'loss' key with the calculated loss value. @@ -517,9 +512,6 @@ def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target Shape: (batch_size, diar_frame_count, num_speakers) target_lens (torch.Tensor): Lengths of target sequences. Shape: (batch_size,) - - Returns: - dict: A dictionary containing the following validation metrics """ targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py index 193dae29c304..d99bf3b93e38 100644 --- a/nemo/collections/asr/modules/sortformer_modules.py +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -98,7 +98,7 @@ def forward_speaker_sigmoids(self, hidden_out): hidden_out (torch.Tensor): tensor of shape (batch_size, seq_len, hidden_size) Returns: - preds (torch.Tensor): tensor of shape (batch_size, num_spks) containing speaker probabilities + preds (torch.Tensor): tensor of shape (batch_size, seq_len, num_spks) containing speaker probabilities """ hidden_out = self.dropout(F.relu(hidden_out)) hidden_out = self.first_hidden_to_hidden(hidden_out) diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index eddfd3254adc..8b96e627d97a 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -14,6 +14,7 @@ import math import torch +from typing import Optional, Union from lhotse import SupervisionSet from lhotse.cut import MixedCut, MonoCut @@ -176,7 +177,7 @@ def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutati def find_segments_from_rttm( recording_id: str, - rttms, + rttms: SupervisionSet, start_after: float, end_before: float, adjust_offset: bool = True, @@ -214,7 +215,7 @@ def find_segments_from_rttm( def get_mask_from_segments( segments: list, - a_cut, + a_cut: Optional[Union[MonoCut, MixedCut]], speaker_to_idx_map: torch.Tensor, num_speakers: int = 4, feat_per_sec: int = 100, @@ -255,7 +256,7 @@ def get_mask_from_segments( return mask -def get_soft_mask(feat_level_target, num_samples, stride): +def get_soft_mask(feat_level_target, num_frames, stride): """ Get soft mask from feat_level_target with stride. This function is needed for speaker diarization with ASR model trainings. @@ -265,17 +266,21 @@ def get_soft_mask(feat_level_target, num_samples, stride): Dimension: (num_frames, num_speakers) num_sample (int): The total number of samples. stride (int): The stride for the mask. + + Returns: + mask: The soft mask of shape (num_frames, num_speakers). + Dimension: (num_frames, num_speakers) """ num_speakers = feat_level_target.shape[1] - mask = torch.zeros(num_samples, num_speakers) + mask = torch.zeros(num_frames, num_speakers) - for index in range(num_samples): + for index in range(num_frames): if index == 0: seg_stt_feat = 0 else: seg_stt_feat = stride * index - 1 - int(stride / 2) - if index == num_samples - 1: + if index == num_frames - 1: seg_end_feat = feat_level_target.shape[0] else: seg_end_feat = stride * index - 1 + int(stride / 2) diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 14a336a97479..99c8b5850a2c 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -970,7 +970,7 @@ def get_subsegments( Returns: subsegments (List[tuple[float, float]]): subsegments generated for the segments as - list of tuple of start and duration of each subsegment + list of tuple of start and duration of each subsegment """ subsegments: List[List[float]] = [] start = offset @@ -1000,32 +1000,20 @@ def get_subsegments( def get_subsegments_scriptable(offset: float, window: float, shift: float, duration: float) -> List[List[float]]: """ - <<<<<<< HEAD - This function returns subsegments from a segment of an audio file. - Although this implementation is inefficient due to the use of a for-loop for segmentation, - it is designed to be torch-jit-scriptable. - Use `get_subsegments` for a more efficient implementation. - - ======= - Return subsegments from a segment of audio file. - This function is inefficient since the segmentation is based on for-loop, - but this implementation makes this function torch-jit-scriptable. - - >>>>>>> 681fe3881c7029e104788ef621d020b8f94bd410 - Args: - offset (float): start time of audio segment - window (float): window length for segments to subsegments length - shift (float): hop length for subsegments shift - duration (float): duration of segment - Returns: - <<<<<<< HEAD - subsegments (List[tuple[float, float]]): subsegments generated for the segments - as list of tuple of start and duration of - each subsegment - ======= - subsegments (List[tuple[float, float]]): subsegments generated for the segments - as list of tuple of start and duration of each subsegment - >>>>>>> 681fe3881c7029e104788ef621d020b8f94bd410 + This function returns subsegments from a segment of an audio file. + Although this implementation is inefficient due to the use of a for-loop for segmentation, + it is designed to be torch-jit-scriptable. + Use `get_subsegments` for a more efficient implementation. + + Args: + offset (float): start time of audio segment + window (float): window length for segments to subsegments length + shift (float): hop length for subsegments shift + duration (float): duration of segment + Returns: + subsegments (List[tuple[float, float]]): subsegments generated for the segments + as list of tuple of start and duration of + each subsegment """ subsegments: List[List[float]] = [] start = offset diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 5773ddf4b79b..3684b176cc22 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -1219,7 +1219,7 @@ def __init__( manifests_files: Union[str, List[str]], emb_dict: Dict, clus_label_dict: Dict, - round_digit=2, + round_digits: int=2, seq_eval_mode=False, pairwise_infer=False, *args, @@ -1249,7 +1249,7 @@ def __init__( *args: Args to pass to `SpeechLabel` constructor. **kwargs: Kwargs to pass to `SpeechLabel` constructor. """ - self.round_digit = round_digit + self.round_digits = round_digits self.emb_dict = emb_dict self.clus_label_dict = clus_label_dict self.seq_eval_mode = seq_eval_mode @@ -1479,7 +1479,7 @@ class EndtoEndDiarizationSpeechLabel(EndtoEndDiarizationLabel): def __init__( self, manifests_files: Union[str, List[str]], - round_digits=2, + round_digits: int =2, *args, **kwargs, ): From b8a49ea8a2bb2aa5eea742c7b9b9d4266790a313 Mon Sep 17 00:00:00 2001 From: tango4j Date: Thu, 21 Nov 2024 01:48:36 +0000 Subject: [PATCH 29/47] Apply isort and black reformatting Signed-off-by: tango4j --- .../diarization/neural_diarizer/e2e_diarize_speech.py | 2 +- nemo/collections/asr/data/audio_to_diar_label_lhotse.py | 2 +- nemo/collections/asr/metrics/multi_binary_acc.py | 2 +- .../collections/asr/parts/utils/asr_multispeaker_utils.py | 5 +++-- nemo/collections/asr/parts/utils/speaker_utils.py | 8 ++++---- .../collections/common/parts/preprocessing/collections.py | 4 ++-- 6 files changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 28e2a94c7ffc..4a4c7628c3fc 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -394,7 +394,7 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: diar_model.test_batch() diar_model_preds_total_list = diar_model.preds_total_list torch.save(diar_model.preds_total_list, tensor_path) - + if cfg.launch_pp_optim: # Launch a hyperparameter optimization process if launch_pp_optim is True run_optuna_hyperparam_search( diff --git a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py index 9e974f027b5e..927e3887de78 100644 --- a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py +++ b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py @@ -53,7 +53,7 @@ def __init__(self, cfg): self.num_speakers = self.cfg.get('num_speakers', 4) self.num_sample_per_mel_frame = int( self.cfg.get('window_stride', 0.01) * self.cfg.get('sample_rate', 16000) - ) # 160 samples for every 1ms by default + ) # 160 samples for every 1ms by default self.num_mel_frame_per_target_frame = int(self.cfg.get('subsampling_factor', 8)) self.spk_tar_all_zero = self.cfg.get('spk_tar_all_zero', False) diff --git a/nemo/collections/asr/metrics/multi_binary_acc.py b/nemo/collections/asr/metrics/multi_binary_acc.py index fda6891fe85f..7b2b9148a74e 100644 --- a/nemo/collections/asr/metrics/multi_binary_acc.py +++ b/nemo/collections/asr/metrics/multi_binary_acc.py @@ -122,7 +122,7 @@ def update( def compute(self): """ Compute F1 score from the accumulated values. Return -1 if the F1 score is NaN. - + Returns: f1_score (torch.Tensor): F1 score calculated from the accumulated values. precision (torch.Tensor): Precision calculated from the accumulated values. diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index 8b96e627d97a..66cfcc75f49f 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -13,8 +13,9 @@ # limitations under the License. import math -import torch from typing import Optional, Union + +import torch from lhotse import SupervisionSet from lhotse.cut import MixedCut, MonoCut @@ -266,7 +267,7 @@ def get_soft_mask(feat_level_target, num_frames, stride): Dimension: (num_frames, num_speakers) num_sample (int): The total number of samples. stride (int): The stride for the mask. - + Returns: mask: The soft mask of shape (num_frames, num_speakers). Dimension: (num_frames, num_speakers) diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 99c8b5850a2c..223916e60a76 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -1001,18 +1001,18 @@ def get_subsegments( def get_subsegments_scriptable(offset: float, window: float, shift: float, duration: float) -> List[List[float]]: """ This function returns subsegments from a segment of an audio file. - Although this implementation is inefficient due to the use of a for-loop for segmentation, + Although this implementation is inefficient due to the use of a for-loop for segmentation, it is designed to be torch-jit-scriptable. Use `get_subsegments` for a more efficient implementation. - + Args: offset (float): start time of audio segment window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift duration (float): duration of segment Returns: - subsegments (List[tuple[float, float]]): subsegments generated for the segments - as list of tuple of start and duration of + subsegments (List[tuple[float, float]]): subsegments generated for the segments + as list of tuple of start and duration of each subsegment """ subsegments: List[List[float]] = [] diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 3684b176cc22..6cc31d328e23 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -1219,7 +1219,7 @@ def __init__( manifests_files: Union[str, List[str]], emb_dict: Dict, clus_label_dict: Dict, - round_digits: int=2, + round_digits: int = 2, seq_eval_mode=False, pairwise_infer=False, *args, @@ -1479,7 +1479,7 @@ class EndtoEndDiarizationSpeechLabel(EndtoEndDiarizationLabel): def __init__( self, manifests_files: Union[str, List[str]], - round_digits: int =2, + round_digits: int = 2, *args, **kwargs, ): From 6198a20b4dcb0f252c101b05d8a9ea933e700fd2 Mon Sep 17 00:00:00 2001 From: taejinp Date: Wed, 20 Nov 2024 17:51:43 -0800 Subject: [PATCH 30/47] Line too long pylint issue resolved in e2e_diarize_speech.py Signed-off-by: taejinp --- .../diarization/neural_diarizer/e2e_diarize_speech.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 28e2a94c7ffc..522b7d53ee0c 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -200,7 +200,8 @@ def diarization_objective( infer_audio_rttm_dict (Dict[str, Dict[str, str]]): Dictionary containing audio file paths, offsets, durations, and RTTM file paths. diar_model_preds_total_list (List[torch.Tensor]): List of prediction matrices containing - sigmoid values for each speaker. Dimension: [(1, num_frames, num_speakers), ..., (1, num_frames, num_speakers)] + sigmoid values for each speaker. + Dimension: [(1, num_frames, num_speakers), ..., (1, num_frames, num_speakers)] collar (float, optional): Collar in seconds for DER calculation. Defaults to 0.25. ignore_overlap (bool, optional): If True, DER will be calculated only for non-overlapping segments. Defaults to False. From 07c424224176a0aee681a2eedfdda227b5cd3bad Mon Sep 17 00:00:00 2001 From: tango4j Date: Thu, 21 Nov 2024 01:52:53 +0000 Subject: [PATCH 31/47] Apply isort and black reformatting Signed-off-by: tango4j --- .../diarization/neural_diarizer/e2e_diarize_speech.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 52681d3368fe..4f76d0e24b8a 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -200,7 +200,7 @@ def diarization_objective( infer_audio_rttm_dict (Dict[str, Dict[str, str]]): Dictionary containing audio file paths, offsets, durations, and RTTM file paths. diar_model_preds_total_list (List[torch.Tensor]): List of prediction matrices containing - sigmoid values for each speaker. + sigmoid values for each speaker. Dimension: [(1, num_frames, num_speakers), ..., (1, num_frames, num_speakers)] collar (float, optional): Collar in seconds for DER calculation. Defaults to 0.25. ignore_overlap (bool, optional): If True, DER will be calculated only for non-overlapping segments. From 9feb013b047051cd3c068beff2b47ccae9da30ac Mon Sep 17 00:00:00 2001 From: taejinp Date: Wed, 20 Nov 2024 17:53:41 -0800 Subject: [PATCH 32/47] Resovled unused variable issue in model test Signed-off-by: taejinp --- tests/collections/speaker_tasks/test_diar_sortformer_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/collections/speaker_tasks/test_diar_sortformer_models.py b/tests/collections/speaker_tasks/test_diar_sortformer_models.py index 6966c56ade86..76df22668441 100644 --- a/tests/collections/speaker_tasks/test_diar_sortformer_models.py +++ b/tests/collections/speaker_tasks/test_diar_sortformer_models.py @@ -152,7 +152,6 @@ def test_forward_infer(self, sortformer_model, batch_size, frame_length, sample_ target_frame_count = int(sample_len // frame_length) input_signal = torch.randn(size=(batch_size, sample_len * sampling_rate)) input_signal_length = (sample_len * sampling_rate) * torch.ones(batch_size, dtype=torch.int) - targets = torch.randint(2, size=(batch_size, target_frame_count, num_spks), dtype=torch.int) with torch.no_grad(): # batch size 1 From fa111556978bf188c0e3a1364e3d6d8b456ee20a Mon Sep 17 00:00:00 2001 From: taejinp Date: Thu, 21 Nov 2024 13:48:12 -0800 Subject: [PATCH 33/47] Reflecting the comment on Nov 21st 2024. Signed-off-by: taejinp --- ...rtformer_diarizer_hybrid_loss_4spk-v1.yaml | 10 +++---- ...ortformer_diar_4spk-v1_callhome-part1.yaml | 2 +- .../neural_diarizer/e2e_diarize_speech.py | 27 ++++++++++++------- .../neural_diarizer/sortformer_diar_train.py | 5 +++- .../asr/models/sortformer_diar_models.py | 21 ++++++--------- requirements/requirements_asr.txt | 1 + 6 files changed, 36 insertions(+), 30 deletions(-) diff --git a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml index 4a6d8f242d36..38346531b95a 100644 --- a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml +++ b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml @@ -12,17 +12,15 @@ batch_size: 8 model: pil_weight: 0.5 # Weight for Permutation Invariant Loss (PIL) used in training the Sortformer diarizer model ats_weight: 0.5 # Weight for Arrival Time Sort (ATS) loss in training the Sortformer diarizer model - num_workers: ${num_workers} # Number of workers for data loading fc_d_model: 512 # Hidden dimension size of the Fast-conformer Encoder tf_d_model: 192 # Hidden dimension size of the Transformer Encoder max_num_of_spks: 4 # Maximum number of speakers per model; currently set to 4 - session_len_sec: 90 # Maximum session length in seconds train_ds: manifest_filepath: ??? sample_rate: ${sample_rate} num_spks: ${model.max_num_of_spks} - session_len_sec: ${model.session_len_sec} + session_len_sec: 90 # Maximum session length in seconds soft_label_thres: 0.5 # Threshold for binarizing target values; higher values make the model more conservative in predicting speaker activity. soft_targets: False # If True, use continuous values as target values when calculating cross-entropy loss labels: null @@ -36,7 +34,7 @@ model: num_buckets: 10 bucket_duration_bins: [10, 20, 30, 40, 50, 60, 70, 80, 90] pin_memory: True - min_duration: 80 + min_duration: 10 max_duration: 90 batch_duration: 400 quadratic_duration: 1200 @@ -51,7 +49,7 @@ model: tarred_audio_filepaths: null sample_rate: ${sample_rate} num_spks: ${model.max_num_of_spks} - session_len_sec: ${model.session_len_sec} + session_len_sec: 90 # Maximum session length in seconds soft_label_thres: 0.5 # A threshold value for setting up the binarized labels. The higher the more conservative the model becomes. soft_targets: False labels: null @@ -73,7 +71,7 @@ model: tarred_audio_filepaths: null sample_rate: 16000 num_spks: ${model.max_num_of_spks} - session_len_sec: ${model.session_len_sec} + session_len_sec: 90 # Maximum session length in seconds soft_label_thres: 0.5 soft_targets: False labels: null diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml index 59bd533632d8..9b7a9701c4f2 100644 --- a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml @@ -2,7 +2,7 @@ # This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper: # Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). # These parameters were optimized with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656. -# These parameters were optimized on CallHome Dataset from the NIST SRE 2000 Disc8, especially from the split2 specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/tree/master/egs/callhome_diarization/v2. +# These parameters were optimized on CallHome Dataset from the NIST SRE 2000 Disc8, especially from the part1 (callhome1) specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/blob/master/egs/callhome_diarization/v2/run.sh # Trial 24682 finished with value: 0.10257785779242055 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.01, 'min_duration_on': 0.42, 'min_duration_off': 0.34}. Best is trial 24682 with value: 0.10257785779242055. parameters: onset: 0.53 # Onset threshold for detecting the beginning and end of a speech diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 4f76d0e24b8a..6eb767dd4ccb 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -14,17 +14,26 @@ """ -Usage: -End-to-end speaker diarization model can be specified by either "model_path" or "pretrained_name". -Data for diarization is fed through "dataset_manifest". -By default, post-processing is bypassed and only binarization is performed. -If you want to reproduce DER scores, you need to apply post-processing steps. +This script provides an inference and evaluation script for end-to-end speaker diarization models. +The performance of the diarization model is measured using the Diarization Error Rate (DER). +If you want to evaluate its performance, the manifest JSON file should contain the corresponding RTTM +(Rich Transcription Time Marked) file. +Please refer to the NeMo Library Documentation for more details on data preparation for diarization inference: +https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit +/asr/speaker_diarization/datasets.html#data-preparation-for-inference + +Usage for diarization inference: + +The end-to-end speaker diarization model can be specified by either "model_path" or "pretrained_name". +Data for diarization is fed through the "dataset_manifest". +By default, post-processing is bypassed, and only binarization is performed. +If you want to reproduce DER scores reported on NeMo model cards, you need to apply post-processing steps. Use batch_size = 1 to have the longest inference window and the highest possible accuracy. python $BASEPATH/neural_diarizer/e2e_diarize_speech.py \ - model_path=/path/to/diar_sortformer_spk_v.nemo \ + model_path=/path/to/diar_sortformer_4spk_v1.nemo \ batch_size=1 \ - dataset_manifest=/path/to/diarization_path_to_manifest.json + dataset_manifest=/path/to/diarization_manifest.json """ import logging @@ -34,7 +43,7 @@ from typing import Dict, List, Optional, Union import optuna -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import yaml from omegaconf import OmegaConf @@ -105,7 +114,7 @@ class DiarizationConfig: optuna_n_trials: int = 100000 -def load_postprocessing_from_yaml(postprocessing_yaml): +def load_postprocessing_from_yaml(postprocessing_yaml: PostProcessingParams=None) -> PostProcessingParams: """ Load postprocessing parameters from a YAML file. diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py index 8719b6463f70..ecb43c92f489 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from pytorch_lightning import seed_everything @@ -46,6 +46,9 @@ def main(cfg): sortformer_model.maybe_init_from_pretrained_checkpoint(cfg) trainer.fit(sortformer_model) + if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: + if sortformer_model.prepare_test(trainer): + trainer.test(sortformer_model) if __name__ == '__main__': main() diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index fd6c3b1ac127..3e9b0ac0de0c 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -90,9 +90,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): else: self.spec_augmentation = None - self.encoder = SortformerEncLabelModel.from_config_dict(self._cfg.encoder) - self.sortformer_modules = SortformerEncLabelModel.from_config_dict(self._cfg.sortformer_modules) - self.transformer_encoder = SortformerEncLabelModel.from_config_dict(self._cfg.transformer_encoder) + self.encoder = SortformerEncLabelModel.from_config_dict(self._cfg.encoder).to(self.device) + self.sortformer_modules = SortformerEncLabelModel.from_config_dict(self._cfg.sortformer_modules).to(self.device) + self.transformer_encoder = SortformerEncLabelModel.from_config_dict(self._cfg.transformer_encoder).to(self.device) + if self._cfg.encoder.d_model != self._cfg.tf_d_model: + self.sortformer_modules.encoder_proj = self.sortformer_modules.encoder_proj.to(self.device) + else: + self.sortformer_modules.encoder_proj = None self._init_loss_weights() self.eps = 1e-3 @@ -158,8 +162,6 @@ def __setup_dataloader_from_config(self, config): global_rank = self._trainer.global_rank else: global_rank = 0 - time_flag = time.time() - logging.info("AAB: Starting Dataloader Instance loading... Step A") dataset = AudioToSpeechE2ESpkDiarDataset( manifest_filepath=config.manifest_filepath, @@ -171,10 +173,6 @@ def __setup_dataloader_from_config(self, config): global_rank=global_rank, soft_targets=config.soft_targets if 'soft_targets' in config else False, ) - logging.info( - f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader" - f"step B: {time.time() - time_flag}" - ) self.data_collection = dataset.collection self.collate_ds = dataset @@ -188,7 +186,6 @@ def __setup_dataloader_from_config(self, config): num_workers=config.get('num_workers', 1), pin_memory=config.get('pin_memory', False), ) - logging.info(f"AAC: Dataloader Instance loading is done ETA Step B done: {time.time() - time_flag}") return dataloader_instance def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): @@ -245,11 +242,9 @@ def frontend_encoder(self, processed_signal, processed_signal_length): # Spec augment is not applied during evaluation/testing if self.spec_augmentation is not None and self.training: processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) - self.encoder = self.encoder.to(self.device) emb_seq, emb_seq_length = self.encoder(audio_signal=processed_signal, length=processed_signal_length) emb_seq = emb_seq.transpose(1, 2) - if self._cfg.encoder.d_model != self._cfg.tf_d_model: - self.sortformer_modules.encoder_proj = self.sortformer_modules.encoder_proj.to(self.device) + if self.sortformer_modules.encoder_proj is not None: emb_seq = self.sortformer_modules.encoder_proj(emb_seq) return emb_seq, emb_seq_length diff --git a/requirements/requirements_asr.txt b/requirements/requirements_asr.txt index d28b3f7980a7..783f7a483dc5 100644 --- a/requirements/requirements_asr.txt +++ b/requirements/requirements_asr.txt @@ -8,6 +8,7 @@ kaldiio lhotse>=1.26.0 librosa>=0.10.2 marshmallow +optuna packaging pyannote.core pyannote.metrics From b5878cc634122dcc7510489a7d1f93e94c3004dd Mon Sep 17 00:00:00 2001 From: tango4j Date: Thu, 21 Nov 2024 21:49:41 +0000 Subject: [PATCH 34/47] Apply isort and black reformatting Signed-off-by: tango4j --- .../diarization/neural_diarizer/e2e_diarize_speech.py | 4 ++-- .../diarization/neural_diarizer/sortformer_diar_train.py | 1 + nemo/collections/asr/models/sortformer_diar_models.py | 8 ++++++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 6eb767dd4ccb..27c06d09cc30 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -42,8 +42,8 @@ from dataclasses import dataclass, is_dataclass from typing import Dict, List, Optional, Union -import optuna import lightning.pytorch as pl +import optuna import torch import yaml from omegaconf import OmegaConf @@ -114,7 +114,7 @@ class DiarizationConfig: optuna_n_trials: int = 100000 -def load_postprocessing_from_yaml(postprocessing_yaml: PostProcessingParams=None) -> PostProcessingParams: +def load_postprocessing_from_yaml(postprocessing_yaml: PostProcessingParams = None) -> PostProcessingParams: """ Load postprocessing parameters from a YAML file. diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py index ecb43c92f489..ab6e418b1072 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py @@ -50,5 +50,6 @@ def main(cfg): if sortformer_model.prepare_test(trainer): trainer.test(sortformer_model) + if __name__ == '__main__': main() diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index 3e9b0ac0de0c..4d37fca32ddd 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -91,8 +91,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.spec_augmentation = None self.encoder = SortformerEncLabelModel.from_config_dict(self._cfg.encoder).to(self.device) - self.sortformer_modules = SortformerEncLabelModel.from_config_dict(self._cfg.sortformer_modules).to(self.device) - self.transformer_encoder = SortformerEncLabelModel.from_config_dict(self._cfg.transformer_encoder).to(self.device) + self.sortformer_modules = SortformerEncLabelModel.from_config_dict(self._cfg.sortformer_modules).to( + self.device + ) + self.transformer_encoder = SortformerEncLabelModel.from_config_dict(self._cfg.transformer_encoder).to( + self.device + ) if self._cfg.encoder.d_model != self._cfg.tf_d_model: self.sortformer_modules.encoder_proj = self.sortformer_modules.encoder_proj.to(self.device) else: From 7898697e2a3bef881832ee54bc9952916e7c89d4 Mon Sep 17 00:00:00 2001 From: taejinp Date: Thu, 21 Nov 2024 13:55:19 -0800 Subject: [PATCH 35/47] Unused variable import time Signed-off-by: taejinp --- nemo/collections/asr/models/sortformer_diar_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index 3e9b0ac0de0c..579a589298f5 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -14,7 +14,6 @@ import itertools import random -import time from collections import OrderedDict from typing import Dict, List, Optional, Union From e4006cf3e991b4bf1beff7151f76a457aaa80288 Mon Sep 17 00:00:00 2001 From: taejinp Date: Thu, 21 Nov 2024 17:47:14 -0800 Subject: [PATCH 36/47] Adding docstrings to score_labels() function in der.py Signed-off-by: taejinp --- nemo/collections/asr/metrics/der.py | 34 +++++++++++++++-------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/nemo/collections/asr/metrics/der.py b/nemo/collections/asr/metrics/der.py index 7496f700341f..ee9e9b36424f 100644 --- a/nemo/collections/asr/metrics/der.py +++ b/nemo/collections/asr/metrics/der.py @@ -131,8 +131,8 @@ def uem_timeline_from_file(uem_file, uniq_name=''): def score_labels( AUDIO_RTTM_MAP, - all_reference, - all_hypothesis, + all_reference: list, + all_hypothesis: list, all_uem: List[List[float]] = None, collar: float = 0.25, ignore_overlap: bool = True, @@ -143,23 +143,25 @@ def score_labels( coming from Pyannote-formatted speaker diarization results and References are coming from Pyannote-formatted RTTM data. - Args: - AUDIO_RTTM_MAP (dict): - Dictionary containing information provided from manifestpath - all_reference (list[uniq_name,Annotation]): - Reference annotations for score calculation - all_hypothesis (list[uniq_name,Annotation]): - Hypothesis annotations for score calculation - verbose (bool): - Warns if RTTM file is not found. + AUDIO_RTTM_MAP (dict): Dictionary containing information provided from manifestpath + all_reference (list[uniq_name,Annotation]): reference annotations for score calculation + all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation + all_uem (list[list[float]]): List of UEM segments for each audio file. If UEM file is not provided, + it will be read from manifestpath + collar (float): Length of collar (in seconds) for diarization error rate calculation + ignore_overlap (bool): If True, overlapping segments in reference and hypothesis will be ignored + verbose (bool): If True, warning messages will be printed Returns: - metric (pyannote.DiarizationErrorRate): - Pyannote Diarization Error Rate metric object. - This object contains detailed scores of each audiofile. - mapping (dict): - Mapping dict containing the mapping speaker label for each audio input + metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. + This object contains detailed scores of each audiofile. + mapping (dict): Mapping dict containing the mapping speaker label for each audio input + itemized_errors (tuple): Tuple containing (DER, CER, FA, MISS) for each audio file. + - DER: Diarization Error Rate, which is sum of all three errors, CER + FA + MISS. + - CER: Confusion Error Rate, which is sum of all errors + - FA: False Alarm Rate, which is the number of false alarm segments + - MISS: Missed Detection Rate, which is the number of missed detection segments < Caveat > Unlike md-eval.pl, "no score" collar in pyannote.metrics is the maximum length of From ca480eb338f8986b718e402965e99b3dc3c0ad82 Mon Sep 17 00:00:00 2001 From: tango4j Date: Fri, 22 Nov 2024 01:49:02 +0000 Subject: [PATCH 37/47] Apply isort and black reformatting Signed-off-by: tango4j --- nemo/collections/asr/metrics/der.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/asr/metrics/der.py b/nemo/collections/asr/metrics/der.py index ee9e9b36424f..c8dec24eaaca 100644 --- a/nemo/collections/asr/metrics/der.py +++ b/nemo/collections/asr/metrics/der.py @@ -147,14 +147,14 @@ def score_labels( AUDIO_RTTM_MAP (dict): Dictionary containing information provided from manifestpath all_reference (list[uniq_name,Annotation]): reference annotations for score calculation all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation - all_uem (list[list[float]]): List of UEM segments for each audio file. If UEM file is not provided, + all_uem (list[list[float]]): List of UEM segments for each audio file. If UEM file is not provided, it will be read from manifestpath collar (float): Length of collar (in seconds) for diarization error rate calculation ignore_overlap (bool): If True, overlapping segments in reference and hypothesis will be ignored verbose (bool): If True, warning messages will be printed Returns: - metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. + metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. This object contains detailed scores of each audiofile. mapping (dict): Mapping dict containing the mapping speaker label for each audio input itemized_errors (tuple): Tuple containing (DER, CER, FA, MISS) for each audio file. From af0483223c3096bdcfb9ca99d09f0b25139373b3 Mon Sep 17 00:00:00 2001 From: taejinp Date: Fri, 22 Nov 2024 11:07:37 -0800 Subject: [PATCH 38/47] Reflecting comments on YAML files and model file variable changes. Signed-off-by: taejinp --- ...rtformer_diarizer_hybrid_loss_4spk-v1.yaml | 24 ++++++++++--------- .../asr/models/sortformer_diar_models.py | 23 ++++++++++-------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml index 38346531b95a..66cfc5fd1b61 100644 --- a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml +++ b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml @@ -1,24 +1,26 @@ # Sortformer Diarizer is an end-to-end speaker diarization model that is solely based on Transformer-encoder type of architecture. -# Model name convention for Sortformer Diarizer: sortformer_diarizer___.yaml +# Model name convention for Sortformer Diarizer: sortformer_diarizer__-.yaml # (Example) `sortformer_diarizer_hybrid_loss_4spk-v1.yaml`. # Sortformer Diarizer model checkpoint (.ckpt) and NeMo file (.nemo) contain Fast Conformer Encoder model (NEST Encoder) and the pre-trained NEST model is loaded along with the Transformer Encoder layers. # Example: a manifest line for training # {"audio_filepath": "/path/to/audio01.wav", "offset": 390.83, "duration": 90.00, "text": "-", "num_speakers": 2, "rttm_filepath": "/path/to/audio01.rttm"} name: "SortFormerDiarizer" -sample_rate: 16000 num_workers: 18 batch_size: 8 model: + sample_rate: 16000 pil_weight: 0.5 # Weight for Permutation Invariant Loss (PIL) used in training the Sortformer diarizer model ats_weight: 0.5 # Weight for Arrival Time Sort (ATS) loss in training the Sortformer diarizer model - fc_d_model: 512 # Hidden dimension size of the Fast-conformer Encoder - tf_d_model: 192 # Hidden dimension size of the Transformer Encoder max_num_of_spks: 4 # Maximum number of speakers per model; currently set to 4 + model_defaults: + fc_d_model: 512 # Hidden dimension size of the Fast-conformer Encoder + tf_d_model: 192 # Hidden dimension size of the Transformer Encoder + train_ds: manifest_filepath: ??? - sample_rate: ${sample_rate} + sample_rate: ${model.sample_rate} num_spks: ${model.max_num_of_spks} session_len_sec: 90 # Maximum session length in seconds soft_label_thres: 0.5 # Threshold for binarizing target values; higher values make the model more conservative in predicting speaker activity. @@ -47,7 +49,7 @@ model: manifest_filepath: ??? is_tarred: False tarred_audio_filepaths: null - sample_rate: ${sample_rate} + sample_rate: ${model.sample_rate} num_spks: ${model.max_num_of_spks} session_len_sec: 90 # Maximum session length in seconds soft_label_thres: 0.5 # A threshold value for setting up the binarized labels. The higher the more conservative the model becomes. @@ -92,7 +94,7 @@ model: _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor normalize: "per_feature" window_size: 0.025 - sample_rate: ${sample_rate} + sample_rate: ${model.sample_rate} window_stride: 0.01 window: "hann" features: 80 @@ -104,15 +106,15 @@ model: _target_: nemo.collections.asr.modules.sortformer_modules.SortformerModules num_spks: ${model.max_num_of_spks} # Number of speakers per model. This is currently fixed at 4. dropout_rate: 0.5 # Dropout rate - fc_d_model: ${model.fc_d_model} - tf_d_model: ${model.tf_d_model} # Hidden layer size for linear layers in Sortformer Diarizer module + fc_d_model: ${model.model_defaults.fc_d_model} + tf_d_model: ${model.model_defaults.tf_d_model} # Hidden layer size for linear layers in Sortformer Diarizer module encoder: _target_: nemo.collections.asr.modules.ConformerEncoder feat_in: ${model.preprocessor.features} feat_out: -1 n_layers: 18 - d_model: ${model.fc_d_model} + d_model: ${model.model_defaults.fc_d_model} # Sub-sampling parameters subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding @@ -147,7 +149,7 @@ model: transformer_encoder: _target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder num_layers: 18 - hidden_size: ${model.tf_d_model} # Needs to be multiple of num_attention_heads + hidden_size: ${model.model_defaults.tf_d_model} # Needs to be multiple of num_attention_heads inner_size: 768 num_attention_heads: 8 attn_score_dropout: 0.5 diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index cb4398433ae9..96d61768b80d 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -96,7 +96,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.transformer_encoder = SortformerEncLabelModel.from_config_dict(self._cfg.transformer_encoder).to( self.device ) - if self._cfg.encoder.d_model != self._cfg.tf_d_model: + if self._cfg.encoder.d_model != self._cfg.model_defaults.tf_d_model: self.sortformer_modules.encoder_proj = self.sortformer_modules.encoder_proj.to(self.device) else: self.sortformer_modules.encoder_proj = None @@ -328,12 +328,12 @@ def forward( preds = self.forward_infer(emb_seq) return preds - def _get_aux_train_evaluations(self, preds, targets, target_lens): + def _get_aux_train_evaluations(self, preds, targets, target_lens) -> dict: """ Compute auxiliary training evaluations including losses and metrics. This function calculates various losses and metrics for the training process, - including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) + including Arrival Time Sort (ATS) Loss and Permutation Invariant Loss (PIL) based evaluations. Args: @@ -392,11 +392,12 @@ def training_step(self, batch: list) -> dict: self.log_dict(train_metrics, sync_dist=True, on_step=True, on_epoch=False, logger=True) return {'loss': train_metrics['loss']} - def _get_aux_validation_evaluations(self, preds, targets, target_lens): + def _get_aux_validation_evaluations(self, preds, targets, target_lens) -> dict: """ Compute auxiliary validation evaluations including losses and metrics. - This function calculates various losses and metrics for the validation process, - including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) + + This function calculates various losses and metrics for the training process, + including Arrival Time Sort (ATS) Loss and Permutation Invariant Loss (PIL) based evaluations. Args: @@ -408,7 +409,7 @@ def _get_aux_validation_evaluations(self, preds, targets, target_lens): Shape: (batch_size,) Returns: - dict: A dictionary containing the following validation metrics + val_metrics (dict): A dictionary containing the following validation metrics """ targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) @@ -498,9 +499,9 @@ def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target_lens): """ Compute auxiliary validation evaluations including losses and metrics. - - This function calculates various losses and metrics for the validation process, - including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) + + This function calculates various losses and metrics for the training process, + including Arrival Time Sort (ATS) Loss and Permutation Invariant Loss (PIL) based evaluations. Args: @@ -573,4 +574,6 @@ def diarize( self, ): """One-clieck runner function for diarization.""" + # TODO: A direct one-click runner function that generates + # speaker labels from audio file path lists. raise NotImplementedError From edbe15918cf28d45f0c91892045dd02bf00aa113 Mon Sep 17 00:00:00 2001 From: tango4j Date: Fri, 22 Nov 2024 19:09:23 +0000 Subject: [PATCH 39/47] Apply isort and black reformatting Signed-off-by: tango4j --- nemo/collections/asr/models/sortformer_diar_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index 96d61768b80d..f6b0eab4c895 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -395,7 +395,7 @@ def training_step(self, batch: list) -> dict: def _get_aux_validation_evaluations(self, preds, targets, target_lens) -> dict: """ Compute auxiliary validation evaluations including losses and metrics. - + This function calculates various losses and metrics for the training process, including Arrival Time Sort (ATS) Loss and Permutation Invariant Loss (PIL) based evaluations. @@ -499,7 +499,7 @@ def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target_lens): """ Compute auxiliary validation evaluations including losses and metrics. - + This function calculates various losses and metrics for the training process, including Arrival Time Sort (ATS) Loss and Permutation Invariant Loss (PIL) based evaluations. @@ -574,6 +574,6 @@ def diarize( self, ): """One-clieck runner function for diarization.""" - # TODO: A direct one-click runner function that generates + # TODO: A direct one-click runner function that generates # speaker labels from audio file path lists. raise NotImplementedError From 8365a052ef77ed39844540b28d3a0bdaae579e84 Mon Sep 17 00:00:00 2001 From: taejinp Date: Fri, 22 Nov 2024 14:56:47 -0800 Subject: [PATCH 40/47] Added get_subsegments_scriptable for legacy get_subsegment functions Signed-off-by: taejinp --- .../asr/parts/utils/manifest_utils.py | 4 +- tests/collections/asr/test_diar_utils.py | 84 ++++--------------- 2 files changed, 20 insertions(+), 68 deletions(-) diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index e9f91045c9a2..55c8e3dbb8c3 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -24,7 +24,7 @@ from nemo.collections.asr.parts.utils.speaker_utils import ( audio_rttm_map, - get_subsegments, + get_subsegments_scriptable, get_uniqname_from_filepath, rttm_to_labels, segments_manifest_to_subsegments_manifest, @@ -179,7 +179,7 @@ def get_subsegment_dict(subsegments_manifest_file: str, window: float, shift: fl segment = segment.strip() dic = json.loads(segment) audio, offset, duration, label = dic['audio_filepath'], dic['offset'], dic['duration'], dic['label'] - subsegments = get_subsegments(offset=offset, window=window, shift=shift, duration=duration) + subsegments = get_subsegments_scriptable(offset=offset, window=window, shift=shift, duration=duration) if dic['uniq_id'] is not None: uniq_id = dic['uniq_id'] else: diff --git a/tests/collections/asr/test_diar_utils.py b/tests/collections/asr/test_diar_utils.py index cb364675fcf4..a72313923a66 100644 --- a/tests/collections/asr/test_diar_utils.py +++ b/tests/collections/asr/test_diar_utils.py @@ -82,7 +82,8 @@ def matrix(mat, use_tensor=True, dtype=torch.long): def generate_orthogonal_embs(total_spks, perturb_sigma, emb_dim): - """Generate a set of artificial orthogonal embedding vectors from random numbers""" + """Generate a set of artificial orthogonal embedding vectors from random numbers + """ gaus = torch.randn(emb_dim, emb_dim) _svd = torch.linalg.svd(gaus) orth = _svd[0] @ _svd[2] @@ -129,7 +130,8 @@ def generate_toy_data( class TestDiarizationSequneceUtilFunctions: - """Tests diarization and speaker-task related utils.""" + """Tests diarization and speaker-task related utils. + """ @pytest.mark.unit @pytest.mark.parametrize("Y", [[3, 3, 3, 4, 4, 5], [100, 100, 100, 104, 104, 1005]]) @@ -276,10 +278,7 @@ def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity): em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=10) em_s, ts_s = split_input_data(em, ts, mc) merged_embs, merged_clus_labels, _ = run_reducer( - pre_embs=em_s[-1], - target_spk_idx=target_speaker_index, - merge_quantity=merge_quantity, - pre_clus_labels=gt, + pre_embs=em_s[-1], target_spk_idx=target_speaker_index, merge_quantity=merge_quantity, pre_clus_labels=gt, ) assert (torch.sum(gt == target_speaker_index).item() - merge_quantity) == merged_clus_labels.shape[0] @@ -288,11 +287,7 @@ def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 70 + [1] * 32)]) @pytest.mark.parametrize("mspb", [25]) def test_merge_scheduler_2clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity( - num_to_be_removed=ntbr, - pre_clus_labels=pcl, - min_count_per_cluster=mspb, - ) + class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) assert all(class_target_vol == torch.tensor([3, 0])) @pytest.mark.unit @@ -300,11 +295,7 @@ def test_merge_scheduler_2clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 80 + [1] * 35 + [2] * 32)]) @pytest.mark.parametrize("mspb", [0, 25]) def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity( - num_to_be_removed=ntbr, - pre_clus_labels=pcl, - min_count_per_cluster=mspb, - ) + class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) assert all(class_target_vol == torch.tensor([3, 0, 0])) @pytest.mark.unit @@ -312,11 +303,7 @@ def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([2] * 70 + [0] * 32 + [1] * 27 + [3] * 3)]) @pytest.mark.parametrize("mspb", [3, 10]) def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity( - num_to_be_removed=ntbr, - pre_clus_labels=pcl, - min_count_per_cluster=mspb, - ) + class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) assert all(class_target_vol == torch.tensor([18, 13, 56, 0])) @pytest.mark.unit @@ -324,11 +311,7 @@ def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 5 + [1] * 4 + [2] * 3)]) @pytest.mark.parametrize("mspb", [0, 2]) def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity( - num_to_be_removed=ntbr, - pre_clus_labels=pcl, - min_count_per_cluster=mspb, - ) + class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) assert all(class_target_vol == torch.tensor([2, 1, 0])) @pytest.mark.unit @@ -336,11 +319,7 @@ def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 7 + [1] * 5 + [2] * 3 + [3] * 5)]) @pytest.mark.parametrize("mspb", [2]) def test_merge_scheduler_3clus_repeat(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity( - num_to_be_removed=ntbr, - pre_clus_labels=pcl, - min_count_per_cluster=mspb, - ) + class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) assert all(class_target_vol == torch.tensor([2, 0, 0, 0])) @@ -435,21 +414,13 @@ def test_is_overlap_false(self, rangeA, rangeB): @pytest.mark.parametrize("x", [1.0, 2.3456]) @pytest.mark.parametrize("decimals", [1, 2, 3, 4]) def test_fl2int(self, x, decimals): - assert fl2int(x, decimals) == round(x * 10**decimals, 0) + assert fl2int(x, decimals) == round(x * 10 ** decimals, 0) @pytest.mark.unit @pytest.mark.parametrize("x", [1234]) - @pytest.mark.parametrize( - "decimals", - [ - 1, - 2, - 3, - 4, - ], - ) + @pytest.mark.parametrize("decimals", [1, 2, 3, 4,]) def test_int2fl(self, x, decimals): - assert abs(int2fl(x, decimals) - round(x / (10**decimals), decimals)) < (10 ** -(decimals + 1)) + assert abs(int2fl(x, decimals) - round(x / (10 ** decimals), decimals)) < (10 ** -(decimals + 1)) @pytest.mark.unit def test_merge_float_intervals_edge_margin_test(self): @@ -491,11 +462,7 @@ def test_get_speech_labels_for_update(self): vad_timestamps = torch.tensor([[0.9600, 4.8400]]) cursor_for_old_segments = 1.0 speech_labels_for_update, cumulative_speech_labels = get_speech_labels_for_update( - frame_start, - buffer_end, - cumulative_speech_labels, - vad_timestamps, - cursor_for_old_segments, + frame_start, buffer_end, cumulative_speech_labels, vad_timestamps, cursor_for_old_segments, ) assert (speech_labels_for_update - torch.tensor([[1.0000, 3.7600]])).sum() < 1e-8 assert (cumulative_speech_labels - torch.tensor([[0.9600, 4.8400]])).sum() < 1e-8 @@ -565,10 +532,7 @@ def test_tensor_to_list(self, source_range_list): @pytest.mark.unit @pytest.mark.parametrize( "buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate", - [ - (0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000), - (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000), - ], + [(0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000), (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000),], ) def test_get_online_segments_from_slices( self, buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate @@ -701,13 +665,7 @@ def test_offline_speaker_clustering_cpu(self, n_spks, total_sec, SSV, perturb_si @pytest.mark.parametrize("SSV, enhanced_count_thres, min_samples_for_nmesc", [(5, 40, 6)]) @pytest.mark.parametrize("seed", [0]) def test_offline_speaker_clustering_very_short_cpu( - self, - n_spks, - spk_dur, - SSV, - enhanced_count_thres, - min_samples_for_nmesc, - seed, + self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed, ): em, ts, mc, mw, spk_ts, gt = generate_toy_data( n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed @@ -739,13 +697,7 @@ def test_offline_speaker_clustering_very_short_cpu( @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(1, 5, 40, 6)]) @pytest.mark.parametrize("seed", [0]) def test_offline_speaker_clustering_very_short_gpu( - self, - n_spks, - spk_dur, - SSV, - enhanced_count_thres, - min_samples_for_nmesc, - seed, + self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed, ): em, ts, mc, mw, spk_ts, gt = generate_toy_data( n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed @@ -956,7 +908,7 @@ def test_linear_sum_assignment_algorithm_cost_matrix(self, cost_matrix): Test the linear sum assignment algorithm with a cost matrix Compare with the scipy implementation and make sure the final cost is the same. - NOTE: There could be multiple solutions with the same cost in linear sum assignment problem. + NOTE: There could be multiple solutions with the same cost in linear sum assignment problem. This test only checks if the cost is the same. """ row_ind_nm, col_ind_nm = nemo_linear_sum_assignment(cost_matrix) From 86315db0bfa2edee49000e36cdc9bb7fedb1fa8d Mon Sep 17 00:00:00 2001 From: tango4j Date: Fri, 22 Nov 2024 22:58:55 +0000 Subject: [PATCH 41/47] Apply isort and black reformatting Signed-off-by: tango4j --- .../asr/parts/utils/manifest_utils.py | 20 +++-- tests/collections/asr/test_diar_utils.py | 84 +++++++++++++++---- 2 files changed, 78 insertions(+), 26 deletions(-) diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index 55c8e3dbb8c3..e05108c509ff 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -67,11 +67,11 @@ def get_ctm_line( ) -> str: """ Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. - - CTM Format: + + CTM Format: - - Reference: + + Reference: https://web.archive.org/web/20170119114252/http://www.itl.nist.gov/iad/mig/tests/rt/2009/docs/rt09-meeting-eval-plan-v2.pdf Args: @@ -80,11 +80,11 @@ def get_ctm_line( start_time (float): is the begin time of the word, which we refer to as `start_time` in NeMo. duration (float): is duration of the word token (str): Token or word for the current entry - conf (float): is a floating point number between 0 (no confidence) and 1 (certainty). A value of “NA” is used (in CTM format data) - when no confidence is computed and in the reference data. + conf (float): is a floating point number between 0 (no confidence) and 1 (certainty). A value of “NA” is used (in CTM format data) + when no confidence is computed and in the reference data. type_of_token (str): is the token type. The legal values of are “lex”, “frag”, “fp”, “un-lex”, “for-lex”, “non-lex”, “misc”, or “noscore” speaker (str): is a string identifier for the speaker who uttered the token. This should be “null” for non-speech tokens and “unknown” when - the speaker has not been determined. + the speaker has not been determined. NA_token (str, optional): A token for . Defaults to ''. output_precision (int, optional): The precision of the output floating point number. Defaults to 3. @@ -368,7 +368,11 @@ def create_segment_manifest( segments_manifest_file = write_rttm2manifest(AUDIO_RTTM_MAP, segment_manifest_path, deci) subsegments_manifest_file = subsegment_manifest_path segments_manifest_to_subsegments_manifest( - segments_manifest_file, subsegments_manifest_file, window, shift, min_subsegment_duration, + segments_manifest_file, + subsegments_manifest_file, + window, + shift, + min_subsegment_duration, ) subsegments_dict = get_subsegment_dict(subsegments_manifest_file, window, shift, deci) write_truncated_subsegments(input_manifest_dict, subsegments_dict, output_manifest_path, step_count, deci) diff --git a/tests/collections/asr/test_diar_utils.py b/tests/collections/asr/test_diar_utils.py index a72313923a66..cb364675fcf4 100644 --- a/tests/collections/asr/test_diar_utils.py +++ b/tests/collections/asr/test_diar_utils.py @@ -82,8 +82,7 @@ def matrix(mat, use_tensor=True, dtype=torch.long): def generate_orthogonal_embs(total_spks, perturb_sigma, emb_dim): - """Generate a set of artificial orthogonal embedding vectors from random numbers - """ + """Generate a set of artificial orthogonal embedding vectors from random numbers""" gaus = torch.randn(emb_dim, emb_dim) _svd = torch.linalg.svd(gaus) orth = _svd[0] @ _svd[2] @@ -130,8 +129,7 @@ def generate_toy_data( class TestDiarizationSequneceUtilFunctions: - """Tests diarization and speaker-task related utils. - """ + """Tests diarization and speaker-task related utils.""" @pytest.mark.unit @pytest.mark.parametrize("Y", [[3, 3, 3, 4, 4, 5], [100, 100, 100, 104, 104, 1005]]) @@ -278,7 +276,10 @@ def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity): em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=10) em_s, ts_s = split_input_data(em, ts, mc) merged_embs, merged_clus_labels, _ = run_reducer( - pre_embs=em_s[-1], target_spk_idx=target_speaker_index, merge_quantity=merge_quantity, pre_clus_labels=gt, + pre_embs=em_s[-1], + target_spk_idx=target_speaker_index, + merge_quantity=merge_quantity, + pre_clus_labels=gt, ) assert (torch.sum(gt == target_speaker_index).item() - merge_quantity) == merged_clus_labels.shape[0] @@ -287,7 +288,11 @@ def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 70 + [1] * 32)]) @pytest.mark.parametrize("mspb", [25]) def test_merge_scheduler_2clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([3, 0])) @pytest.mark.unit @@ -295,7 +300,11 @@ def test_merge_scheduler_2clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 80 + [1] * 35 + [2] * 32)]) @pytest.mark.parametrize("mspb", [0, 25]) def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([3, 0, 0])) @pytest.mark.unit @@ -303,7 +312,11 @@ def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([2] * 70 + [0] * 32 + [1] * 27 + [3] * 3)]) @pytest.mark.parametrize("mspb", [3, 10]) def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([18, 13, 56, 0])) @pytest.mark.unit @@ -311,7 +324,11 @@ def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 5 + [1] * 4 + [2] * 3)]) @pytest.mark.parametrize("mspb", [0, 2]) def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([2, 1, 0])) @pytest.mark.unit @@ -319,7 +336,11 @@ def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 7 + [1] * 5 + [2] * 3 + [3] * 5)]) @pytest.mark.parametrize("mspb", [2]) def test_merge_scheduler_3clus_repeat(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([2, 0, 0, 0])) @@ -414,13 +435,21 @@ def test_is_overlap_false(self, rangeA, rangeB): @pytest.mark.parametrize("x", [1.0, 2.3456]) @pytest.mark.parametrize("decimals", [1, 2, 3, 4]) def test_fl2int(self, x, decimals): - assert fl2int(x, decimals) == round(x * 10 ** decimals, 0) + assert fl2int(x, decimals) == round(x * 10**decimals, 0) @pytest.mark.unit @pytest.mark.parametrize("x", [1234]) - @pytest.mark.parametrize("decimals", [1, 2, 3, 4,]) + @pytest.mark.parametrize( + "decimals", + [ + 1, + 2, + 3, + 4, + ], + ) def test_int2fl(self, x, decimals): - assert abs(int2fl(x, decimals) - round(x / (10 ** decimals), decimals)) < (10 ** -(decimals + 1)) + assert abs(int2fl(x, decimals) - round(x / (10**decimals), decimals)) < (10 ** -(decimals + 1)) @pytest.mark.unit def test_merge_float_intervals_edge_margin_test(self): @@ -462,7 +491,11 @@ def test_get_speech_labels_for_update(self): vad_timestamps = torch.tensor([[0.9600, 4.8400]]) cursor_for_old_segments = 1.0 speech_labels_for_update, cumulative_speech_labels = get_speech_labels_for_update( - frame_start, buffer_end, cumulative_speech_labels, vad_timestamps, cursor_for_old_segments, + frame_start, + buffer_end, + cumulative_speech_labels, + vad_timestamps, + cursor_for_old_segments, ) assert (speech_labels_for_update - torch.tensor([[1.0000, 3.7600]])).sum() < 1e-8 assert (cumulative_speech_labels - torch.tensor([[0.9600, 4.8400]])).sum() < 1e-8 @@ -532,7 +565,10 @@ def test_tensor_to_list(self, source_range_list): @pytest.mark.unit @pytest.mark.parametrize( "buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate", - [(0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000), (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000),], + [ + (0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000), + (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000), + ], ) def test_get_online_segments_from_slices( self, buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate @@ -665,7 +701,13 @@ def test_offline_speaker_clustering_cpu(self, n_spks, total_sec, SSV, perturb_si @pytest.mark.parametrize("SSV, enhanced_count_thres, min_samples_for_nmesc", [(5, 40, 6)]) @pytest.mark.parametrize("seed", [0]) def test_offline_speaker_clustering_very_short_cpu( - self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed, + self, + n_spks, + spk_dur, + SSV, + enhanced_count_thres, + min_samples_for_nmesc, + seed, ): em, ts, mc, mw, spk_ts, gt = generate_toy_data( n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed @@ -697,7 +739,13 @@ def test_offline_speaker_clustering_very_short_cpu( @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(1, 5, 40, 6)]) @pytest.mark.parametrize("seed", [0]) def test_offline_speaker_clustering_very_short_gpu( - self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed, + self, + n_spks, + spk_dur, + SSV, + enhanced_count_thres, + min_samples_for_nmesc, + seed, ): em, ts, mc, mw, spk_ts, gt = generate_toy_data( n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed @@ -908,7 +956,7 @@ def test_linear_sum_assignment_algorithm_cost_matrix(self, cost_matrix): Test the linear sum assignment algorithm with a cost matrix Compare with the scipy implementation and make sure the final cost is the same. - NOTE: There could be multiple solutions with the same cost in linear sum assignment problem. + NOTE: There could be multiple solutions with the same cost in linear sum assignment problem. This test only checks if the cost is the same. """ row_ind_nm, col_ind_nm = nemo_linear_sum_assignment(cost_matrix) From 07f791afa39bc6105175668ea3e13b678f534504 Mon Sep 17 00:00:00 2001 From: taejinp Date: Fri, 22 Nov 2024 15:01:16 -0800 Subject: [PATCH 42/47] Resolved line too long pylint issues Signed-off-by: taejinp --- .../asr/parts/utils/manifest_utils.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index 55c8e3dbb8c3..2a6914e6b357 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -66,13 +66,15 @@ def get_ctm_line( output_precision: int = 2, ) -> str: """ - Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. + Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in + `Rich Transcription Meeting Eval Plan: RT09` document. CTM Format: Reference: - https://web.archive.org/web/20170119114252/http://www.itl.nist.gov/iad/mig/tests/rt/2009/docs/rt09-meeting-eval-plan-v2.pdf + https://web.archive.org/web/20170119114252/ + http://www.itl.nist.gov/iad/mig/tests/rt/2009/docs/rt09-meeting-eval-plan-v2.pdf Args: source (str): is name of the source file, session name or utterance ID @@ -80,11 +82,14 @@ def get_ctm_line( start_time (float): is the begin time of the word, which we refer to as `start_time` in NeMo. duration (float): is duration of the word token (str): Token or word for the current entry - conf (float): is a floating point number between 0 (no confidence) and 1 (certainty). A value of “NA” is used (in CTM format data) + conf (float): is a floating point number between 0 (no confidence) and 1 (certainty). + A value of “NA” is used (in CTM format data) when no confidence is computed and in the reference data. - type_of_token (str): is the token type. The legal values of are “lex”, “frag”, “fp”, “un-lex”, “for-lex”, “non-lex”, “misc”, or “noscore” - speaker (str): is a string identifier for the speaker who uttered the token. This should be “null” for non-speech tokens and “unknown” when - the speaker has not been determined. + type_of_token (str): is the token type. The legal values of are + “lex”, “frag”, “fp”, “un-lex”, “for-lex”, “non-lex”, “misc”, or “noscore” + speaker (str): is a string identifier for the speaker who uttered the token. + This should be “null” for non-speech tokens and “unknown” when + the speaker has not been determined. NA_token (str, optional): A token for . Defaults to ''. output_precision (int, optional): The precision of the output floating point number. Defaults to 3. @@ -505,7 +510,9 @@ def write_manifest(output_path: Union[Path, str], target_manifest: List[dict], e Args: output_path (str or Path): Path to output manifest file target_manifest (list): List of manifest file entries - ensure_ascii (bool): default is True, meaning the output is guaranteed to have all incoming non-ASCII characters escaped. If ensure_ascii is false, these characters will be output as-is. + ensure_ascii (bool): default is True, meaning the output is guaranteed to have all incoming + non-ASCII characters escaped. If ensure_ascii is false, these characters + will be output as-is. """ with open(output_path, "w", encoding="utf-8") as outfile: for tgt in target_manifest: From 30f1159906a0b1c5f08e0afc4516470bc5477003 Mon Sep 17 00:00:00 2001 From: tango4j Date: Fri, 22 Nov 2024 23:03:55 +0000 Subject: [PATCH 43/47] Apply isort and black reformatting Signed-off-by: tango4j --- .../asr/parts/utils/manifest_utils.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index 41f2ccfbe755..418f95832f48 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -66,13 +66,13 @@ def get_ctm_line( output_precision: int = 2, ) -> str: """ - Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in + Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. - - CTM Format: + + CTM Format: - - Reference: + + Reference: https://web.archive.org/web/20170119114252/ http://www.itl.nist.gov/iad/mig/tests/rt/2009/docs/rt09-meeting-eval-plan-v2.pdf @@ -82,14 +82,14 @@ def get_ctm_line( start_time (float): is the begin time of the word, which we refer to as `start_time` in NeMo. duration (float): is duration of the word token (str): Token or word for the current entry - conf (float): is a floating point number between 0 (no confidence) and 1 (certainty). + conf (float): is a floating point number between 0 (no confidence) and 1 (certainty). A value of “NA” is used (in CTM format data) - when no confidence is computed and in the reference data. - type_of_token (str): is the token type. The legal values of are + when no confidence is computed and in the reference data. + type_of_token (str): is the token type. The legal values of are “lex”, “frag”, “fp”, “un-lex”, “for-lex”, “non-lex”, “misc”, or “noscore” - speaker (str): is a string identifier for the speaker who uttered the token. + speaker (str): is a string identifier for the speaker who uttered the token. This should be “null” for non-speech tokens and “unknown” when - the speaker has not been determined. + the speaker has not been determined. NA_token (str, optional): A token for . Defaults to ''. output_precision (int, optional): The precision of the output floating point number. Defaults to 3. @@ -514,8 +514,8 @@ def write_manifest(output_path: Union[Path, str], target_manifest: List[dict], e Args: output_path (str or Path): Path to output manifest file target_manifest (list): List of manifest file entries - ensure_ascii (bool): default is True, meaning the output is guaranteed to have all incoming - non-ASCII characters escaped. If ensure_ascii is false, these characters + ensure_ascii (bool): default is True, meaning the output is guaranteed to have all incoming + non-ASCII characters escaped. If ensure_ascii is false, these characters will be output as-is. """ with open(output_path, "w", encoding="utf-8") as outfile: From 7ec3b1f49a6d13edb8bb7d9905598c01c9b97130 Mon Sep 17 00:00:00 2001 From: taejinp Date: Mon, 25 Nov 2024 21:36:58 -0800 Subject: [PATCH 44/47] Added training and inference CI-tests Signed-off-by: taejinp --- .github/workflows/cicd-main.yml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index a4b2baa59550..cdea00520e11 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -816,6 +816,33 @@ jobs: +trainer.fast_dev_run=True \ exp_manager.exp_dir=/tmp/speaker_diarization_results + L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure-gpus-1 + SCRIPT: | + python examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py \ + trainer.devices="[0]" \ + batch_size=3 \ + model.train_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_train/eesd_train_tiny.json \ + model.validation_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \ + exp_manager.exp_dir=/tmp/speaker_diarization_results \ + +trainer.fast_dev_run=True + + L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + python examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py \ + model_path=/home/TestData/an4_diarizer/diar_sortformer_4spk-v1-tiny.nemo \ + dataset_manifest=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \ + batch_size=1 + L2_Speaker_dev_run_Speech_to_Label: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -4517,6 +4544,8 @@ jobs: - L2_Speech_to_Text_EMA - L2_Speaker_dev_run_Speaker_Recognition - L2_Speaker_dev_run_Speaker_Diarization + - L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer + - L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference - L2_Speaker_dev_run_Speech_to_Label - L2_Speaker_dev_run_Speaker_Diarization_with_ASR_Inference - L2_Speaker_dev_run_Clustering_Diarizer_Inference From 0eb260e81409667b629a1d43b6ad6a0751d33453 Mon Sep 17 00:00:00 2001 From: taejinp Date: Tue, 26 Nov 2024 09:09:53 -0800 Subject: [PATCH 45/47] Added the missing parse_func in preprocessing/collections.py Signed-off-by: taejinp --- .../diarization/neural_diarizer/e2e_diarize_speech.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 27c06d09cc30..1767a16cbe02 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -92,7 +92,7 @@ class DiarizationConfig: # General configs session_len_sec: float = -1 # End-to-end diarization session length in seconds - batch_size: int = 4 + batch_size: int = 1 num_workers: int = 0 random_seed: Optional[int] = None # seed number going to be used in seed_everything() bypass_postprocessing: bool = True # If True, postprocessing will be bypassed From 37d42409870810def90c0f08020d12e28fb62bb2 Mon Sep 17 00:00:00 2001 From: taejinp Date: Tue, 26 Nov 2024 09:15:30 -0800 Subject: [PATCH 46/47] Adding the missing parse_func in preprocessing/collections.py Signed-off-by: taejinp --- .../collections/common/parts/preprocessing/collections.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 6cc31d328e23..d54c807f2637 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -16,7 +16,8 @@ import json import os from itertools import combinations -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Union + import numpy as np import pandas as pd @@ -439,7 +440,7 @@ def _get_len(self, field_type, data, duration_data): class ASRAudioText(AudioText): """`AudioText` collector from asr structured json files.""" - def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): + def __init__(self, manifests_files: Union[str, List[str]], parse_func: Optional[Callable] = None, *args, **kwargs): """Parse lists of audio files, durations and transcripts texts. Args: @@ -462,8 +463,9 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): [], [], ) + speakers, orig_srs, token_labels, langs = [], [], [], [] - for item in manifest.item_iter(manifests_files): + for item in manifest.item_iter(manifests_files, parse_func=parse_func): ids.append(item['id']) audio_files.append(item['audio_file']) durations.append(item['duration']) From bde68879cf6caea49521c751c7789d0c3c96b9c1 Mon Sep 17 00:00:00 2001 From: taejinp Date: Tue, 26 Nov 2024 09:50:15 -0800 Subject: [PATCH 47/47] Fixed an indentation error Signed-off-by: taejinp --- .github/workflows/cicd-main.yml | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index cdea00520e11..df0372316a06 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -823,13 +823,13 @@ jobs: with: RUNNER: self-hosted-azure-gpus-1 SCRIPT: | - python examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py \ - trainer.devices="[0]" \ - batch_size=3 \ - model.train_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_train/eesd_train_tiny.json \ - model.validation_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \ - exp_manager.exp_dir=/tmp/speaker_diarization_results \ - +trainer.fast_dev_run=True + python examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py \ + trainer.devices="[0]" \ + batch_size=3 \ + model.train_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_train/eesd_train_tiny.json \ + model.validation_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \ + exp_manager.exp_dir=/tmp/speaker_diarization_results \ + +trainer.fast_dev_run=True L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference: needs: [cicd-test-container-setup] @@ -838,10 +838,10 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - python examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py \ - model_path=/home/TestData/an4_diarizer/diar_sortformer_4spk-v1-tiny.nemo \ - dataset_manifest=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \ - batch_size=1 + python examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py \ + model_path=/home/TestData/an4_diarizer/diar_sortformer_4spk-v1-tiny.nemo \ + dataset_manifest=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \ + batch_size=1 L2_Speaker_dev_run_Speech_to_Label: needs: [cicd-test-container-setup]