Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sortformer Diarizer 4spk v1 model PR Part 2: Unit-tests for Sortformer Diarizer. #11336

Draft
wants to merge 100 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
100 commits
Select commit Hold shift + click to select a range
e69ec8e
Adding the first pr files models and dataset
tango4j Nov 14, 2024
2914325
Tested all unit-test files
tango4j Nov 14, 2024
9a468ac
Name changes on yaml files and train example
tango4j Nov 14, 2024
a910d30
Merge branch 'main' into sortformer/pr_01
tango4j Nov 14, 2024
2f44fe1
Apply isort and black reformatting
tango4j Nov 14, 2024
4ddc59b
Reflecting comments and removing unnecessary parts for this PR
tango4j Nov 15, 2024
43d95f0
Resolved conflicts
tango4j Nov 15, 2024
40e9f95
Apply isort and black reformatting
tango4j Nov 15, 2024
f7f84bb
Adding docstrings to reflect the PR comments
tango4j Nov 15, 2024
95acd79
Resolved the new conflict
tango4j Nov 15, 2024
919f4da
Merge branch 'main' into sortformer/pr_01
tango4j Nov 15, 2024
4134e25
removed the unused find_first_nonzero
tango4j Nov 15, 2024
d3432e5
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 15, 2024
5dd4d4c
Apply isort and black reformatting
tango4j Nov 15, 2024
ca5eea3
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 15, 2024
9d493c0
Merge branch 'main' into sortformer/pr_01
tango4j Nov 15, 2024
037f61e
Fixed all pylint issues
tango4j Nov 15, 2024
a8bc048
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 15, 2024
cb23268
Apply isort and black reformatting
tango4j Nov 15, 2024
4a266b9
Resolving pylint issues
tango4j Nov 15, 2024
5e4e9c8
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 15, 2024
c31c60c
Merge branch 'main' into sortformer/pr_01
tango4j Nov 15, 2024
6e2225e
Apply isort and black reformatting
tango4j Nov 15, 2024
ab93b17
Removing unused varialbe in audio_to_diar_label.py
tango4j Nov 15, 2024
4f3ee66
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 15, 2024
3f24b82
Merge branch 'main' into sortformer/pr_01
tango4j Nov 16, 2024
f49e107
Merge branch 'main' into sortformer/pr_01
tango4j Nov 16, 2024
7dea01b
Fixed docstrings in training script
tango4j Nov 16, 2024
2a99d53
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 16, 2024
71d515f
Line-too-long issue from Pylint fixed
tango4j Nov 16, 2024
9b7b93e
Merge branch 'main' into sortformer/pr_01
tango4j Nov 18, 2024
f2d5e36
Adding get_subsegments_scriptable to prevent jit.script error
tango4j Nov 19, 2024
9cca3e8
Apply isort and black reformatting
tango4j Nov 19, 2024
731caa8
Merge branch 'main' into sortformer/pr_01
tango4j Nov 19, 2024
681fe38
Merge branch 'main' into sortformer/pr_01
tango4j Nov 19, 2024
008dcbd
Addressed Code-QL issues
tango4j Nov 19, 2024
d89ed91
Addressed Code-QL issues and resolved conflicts
tango4j Nov 19, 2024
045f3a2
Resolved conflicts on bce_loss.py
tango4j Nov 19, 2024
1dcf9ab
Apply isort and black reformatting
tango4j Nov 19, 2024
be8ac22
Adding all the diarization reltated unit-tests
tango4j Nov 19, 2024
ca44a66
Moving speaker task related unit test files to speaker_tasks folder
tango4j Nov 20, 2024
1360831
Fixed uninit variable issue in bce_loss.py spotted by codeQL
tango4j Nov 20, 2024
553197a
Apply isort and black reformatting
tango4j Nov 20, 2024
7893e75
Merge branch 'main' into sortformer/pr_01
tango4j Nov 20, 2024
f7fced9
Merge branch 'main' into sortformer/pr_01
tango4j Nov 20, 2024
87af813
Merge branch 'main' into sortformer/pr_02
tango4j Nov 20, 2024
734dfd8
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j Nov 20, 2024
c3c0b32
Fixing code-QL issues
tango4j Nov 21, 2024
631555d
Apply isort and black reformatting
tango4j Nov 21, 2024
99ee5cc
Merge branch 'main' into sortformer/pr_02
tango4j Nov 21, 2024
9371ed0
Merge branch 'main' into sortformer/pr_01
tango4j Nov 21, 2024
6a3bb62
Reflecting PR comments from weiqingw
tango4j Nov 21, 2024
4e0327c
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 21, 2024
b8a49ea
Apply isort and black reformatting
tango4j Nov 21, 2024
6198a20
Line too long pylint issue resolved in e2e_diarize_speech.py
tango4j Nov 21, 2024
e4b0154
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 21, 2024
07c4242
Apply isort and black reformatting
tango4j Nov 21, 2024
9feb013
Resovled unused variable issue in model test
tango4j Nov 21, 2024
7496a0d
Merge branch 'sortformer/pr_02' of https://github.com/tango4j/NeMo in…
tango4j Nov 21, 2024
db90424
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j Nov 21, 2024
0eeaf06
Merge branch 'main' into sortformer/pr_01
tango4j Nov 21, 2024
fa11155
Reflecting the comment on Nov 21st 2024.
tango4j Nov 21, 2024
b5878cc
Apply isort and black reformatting
tango4j Nov 21, 2024
bfe36e7
Merge branch 'main' into sortformer/pr_01
tango4j Nov 21, 2024
7898697
Unused variable import time
tango4j Nov 21, 2024
e167dba
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 21, 2024
8712278
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j Nov 21, 2024
1bb89d5
Merge branch 'main' into sortformer/pr_01
tango4j Nov 21, 2024
e4006cf
Adding docstrings to score_labels() function in der.py
tango4j Nov 22, 2024
a92e4e6
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 22, 2024
ca480eb
Apply isort and black reformatting
tango4j Nov 22, 2024
5ea9d7d
Merge branch 'main' into sortformer/pr_01
tango4j Nov 22, 2024
1b091c8
Merge branch 'main' into sortformer/pr_01
tango4j Nov 22, 2024
af04832
Reflecting comments on YAML files and model file variable changes.
tango4j Nov 22, 2024
a4367a3
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 22, 2024
edbe159
Apply isort and black reformatting
tango4j Nov 22, 2024
b47579b
Merge branch 'main' into sortformer/pr_01
tango4j Nov 22, 2024
8365a05
Added get_subsegments_scriptable for legacy get_subsegment functions
tango4j Nov 22, 2024
f2250a0
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 22, 2024
5275fb5
Merge branch 'main' into sortformer/pr_01
tango4j Nov 22, 2024
86315db
Apply isort and black reformatting
tango4j Nov 22, 2024
07f791a
Resolved line too long pylint issues
tango4j Nov 22, 2024
2b23136
Resolved line too long pylint issues and merged main
tango4j Nov 22, 2024
30f1159
Apply isort and black reformatting
tango4j Nov 22, 2024
f9a9884
Merge branch 'main' into sortformer/pr_01
tango4j Nov 23, 2024
0e50abf
Merge branch 'main' into sortformer/pr_01
tango4j Nov 24, 2024
f232a40
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j Nov 25, 2024
6fd3076
Merge branch 'main' into sortformer/pr_01
tango4j Nov 26, 2024
7ec3b1f
Added training and inference CI-tests
tango4j Nov 26, 2024
0eb260e
Added the missing parse_func in preprocessing/collections.py
tango4j Nov 26, 2024
895b4ed
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 26, 2024
0d6ebc7
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j Nov 26, 2024
37d4240
Adding the missing parse_func in preprocessing/collections.py
tango4j Nov 26, 2024
01085ab
Merge branch 'main' into sortformer/pr_01
tango4j Nov 26, 2024
03c425b
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j Nov 26, 2024
bde6887
Fixed an indentation error
tango4j Nov 26, 2024
3f378f6
Merge branch 'main' into sortformer/pr_02
tango4j Nov 26, 2024
024a391
Merge branch 'main' into sortformer/pr_02
tango4j Nov 26, 2024
81b751e
Merge branch 'main' into sortformer/pr_02
tango4j Nov 27, 2024
f7029d7
Merge branch 'main' into sortformer/pr_02
tango4j Nov 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .github/workflows/cicd-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,33 @@ jobs:
+trainer.fast_dev_run=True \
exp_manager.exp_dir=/tmp/speaker_diarization_results

L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure-gpus-1
SCRIPT: |
python examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py \
trainer.devices="[0]" \
batch_size=3 \
model.train_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_train/eesd_train_tiny.json \
model.validation_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \
exp_manager.exp_dir=/tmp/speaker_diarization_results \
+trainer.fast_dev_run=True

L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
python examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py \
model_path=/home/TestData/an4_diarizer/diar_sortformer_4spk-v1-tiny.nemo \
dataset_manifest=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \
batch_size=1

L2_Speaker_dev_run_Speech_to_Label:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
Expand Down Expand Up @@ -4517,6 +4544,8 @@ jobs:
- L2_Speech_to_Text_EMA
- L2_Speaker_dev_run_Speaker_Recognition
- L2_Speaker_dev_run_Speaker_Diarization
- L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer
- L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference
- L2_Speaker_dev_run_Speech_to_Label
- L2_Speaker_dev_run_Speaker_Diarization_with_ASR_Inference
- L2_Speaker_dev_run_Clustering_Diarizer_Inference
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
# Sortformer Diarizer is an end-to-end speaker diarization model that is solely based on Transformer-encoder type of architecture.
# Model name convention for Sortformer Diarizer: sortformer_diarizer_<loss_type>_<speaker count limit>-<version>.yaml
# (Example) `sortformer_diarizer_hybrid_loss_4spk-v1.yaml`.
# Sortformer Diarizer model checkpoint (.ckpt) and NeMo file (.nemo) contain Fast Conformer Encoder model (NEST Encoder) and the pre-trained NEST model is loaded along with the Transformer Encoder layers.
# Example: a manifest line for training
# {"audio_filepath": "/path/to/audio01.wav", "offset": 390.83, "duration": 90.00, "text": "-", "num_speakers": 2, "rttm_filepath": "/path/to/audio01.rttm"}
name: "SortFormerDiarizer"
num_workers: 18
batch_size: 8

model:
sample_rate: 16000
pil_weight: 0.5 # Weight for Permutation Invariant Loss (PIL) used in training the Sortformer diarizer model
ats_weight: 0.5 # Weight for Arrival Time Sort (ATS) loss in training the Sortformer diarizer model
max_num_of_spks: 4 # Maximum number of speakers per model; currently set to 4

model_defaults:
fc_d_model: 512 # Hidden dimension size of the Fast-conformer Encoder
tf_d_model: 192 # Hidden dimension size of the Transformer Encoder

train_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
num_spks: ${model.max_num_of_spks}
session_len_sec: 90 # Maximum session length in seconds
soft_label_thres: 0.5 # Threshold for binarizing target values; higher values make the model more conservative in predicting speaker activity.
soft_targets: False # If True, use continuous values as target values when calculating cross-entropy loss
labels: null
batch_size: ${batch_size}
shuffle: True
num_workers: ${num_workers}
validation_mode: False
# lhotse config
use_lhotse: False
use_bucketing: True
num_buckets: 10
bucket_duration_bins: [10, 20, 30, 40, 50, 60, 70, 80, 90]
pin_memory: True
min_duration: 10
max_duration: 90
batch_duration: 400
quadratic_duration: 1200
bucket_buffer_size: 20000
shuffle_buffer_size: 10000
window_stride: ${model.preprocessor.window_stride}
subsampling_factor: ${model.encoder.subsampling_factor}

validation_ds:
manifest_filepath: ???
is_tarred: False
tarred_audio_filepaths: null
sample_rate: ${model.sample_rate}
num_spks: ${model.max_num_of_spks}
session_len_sec: 90 # Maximum session length in seconds
soft_label_thres: 0.5 # A threshold value for setting up the binarized labels. The higher the more conservative the model becomes.
soft_targets: False
labels: null
batch_size: ${batch_size}
shuffle: False
num_workers: ${num_workers}
validation_mode: True
# lhotse config
use_lhotse: False
use_bucketing: False
drop_last: False
pin_memory: True
window_stride: ${model.preprocessor.window_stride}
subsampling_factor: ${model.encoder.subsampling_factor}

test_ds:
manifest_filepath: null
is_tarred: False
tarred_audio_filepaths: null
sample_rate: 16000
num_spks: ${model.max_num_of_spks}
session_len_sec: 90 # Maximum session length in seconds
soft_label_thres: 0.5
soft_targets: False
labels: null
batch_size: ${batch_size}
shuffle: False
seq_eval_mode: True
num_workers: ${num_workers}
validation_mode: True
# lhotse config
use_lhotse: False
use_bucketing: False
drop_last: False
pin_memory: True
window_stride: ${model.preprocessor.window_stride}
subsampling_factor: ${model.encoder.subsampling_factor}

preprocessor:
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
normalize: "per_feature"
window_size: 0.025
sample_rate: ${model.sample_rate}
window_stride: 0.01
window: "hann"
features: 80
n_fft: 512
frame_splicing: 1
dither: 0.00001

sortformer_modules:
_target_: nemo.collections.asr.modules.sortformer_modules.SortformerModules
num_spks: ${model.max_num_of_spks} # Number of speakers per model. This is currently fixed at 4.
dropout_rate: 0.5 # Dropout rate
fc_d_model: ${model.model_defaults.fc_d_model}
tf_d_model: ${model.model_defaults.tf_d_model} # Hidden layer size for linear layers in Sortformer Diarizer module

encoder:
_target_: nemo.collections.asr.modules.ConformerEncoder
feat_in: ${model.preprocessor.features}
feat_out: -1
n_layers: 18
d_model: ${model.model_defaults.fc_d_model}

# Sub-sampling parameters
subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding
subsampling_factor: 8 # must be power of 2 for striding and vggnet
subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model
causal_downsampling: false
# Feed forward module's params
ff_expansion_factor: 4
# Multi-headed Attention Module's params
self_attention_model: rel_pos # rel_pos or abs_pos
n_heads: 8 # may need to be lower for smaller d_models
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
att_context_size: [-1, -1] # -1 means unlimited context
att_context_style: regular # regular or chunked_limited
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000
# Convolution module's params
conv_kernel_size: 9
conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
conv_context_size: null
# Regularization
dropout: 0.1 # The dropout used in most of the Conformer Modules
dropout_pre_encoder: 0.1 # The dropout used before the encoder
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules
# Set to non-zero to enable stochastic depth
stochastic_depth_drop_prob: 0.0
stochastic_depth_mode: linear # linear or uniform
stochastic_depth_start_layer: 1

transformer_encoder:
_target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder
num_layers: 18
hidden_size: ${model.model_defaults.tf_d_model} # Needs to be multiple of num_attention_heads
inner_size: 768
num_attention_heads: 8
attn_score_dropout: 0.5
attn_layer_dropout: 0.5
ffn_dropout: 0.5
hidden_act: relu
pre_ln: False
pre_ln_final_layer_norm: True

loss:
_target_: nemo.collections.asr.losses.bce_loss.BCELoss
weight: null # Weight for binary cross-entropy loss. Either `null` or list type input. (e.g. [0.5,0.5])
reduction: mean

lr: 0.0001
optim:
name: adamw
lr: ${model.lr}
# optimizer arguments
betas: [0.9, 0.98]
weight_decay: 1e-3

sched:
name: InverseSquareRootAnnealing
warmup_steps: 2500
warmup_ratio: null
min_lr: 1e-06

trainer:
devices: 1 # number of gpus (devices)
accelerator: gpu
max_epochs: 800
max_steps: -1 # computed at runtime if not set
num_nodes: 1
strategy: ddp_find_unused_parameters_true # Could be "ddp"
accumulate_grad_batches: 1
deterministic: True
enable_checkpointing: False
logger: False
log_every_n_steps: 1 # Interval of logging.
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations

exp_manager:
use_datetime_version: False
exp_dir: null
name: ${name}
resume_if_exists: True
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
resume_ignore_no_checkpoint: True
create_tensorboard_logger: True
create_checkpoint_callback: True
create_wandb_logger: False
checkpoint_callback_params:
monitor: "val_f1_acc"
mode: "max"
save_top_k: 9
every_n_epochs: 1
wandb_logger_kwargs:
resume: True
name: null
project: null
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Postprocessing parameters for timestamp outputs from speaker diarization models.
# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper:
# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020).
# These parameters were optimized with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656.
# These parameters were optimized on CallHome Dataset from the NIST SRE 2000 Disc8, especially from the part1 (callhome1) specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/blob/master/egs/callhome_diarization/v2/run.sh
# Trial 24682 finished with value: 0.10257785779242055 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.01, 'min_duration_on': 0.42, 'min_duration_off': 0.34}. Best is trial 24682 with value: 0.10257785779242055.
parameters:
onset: 0.53 # Onset threshold for detecting the beginning and end of a speech
offset: 0.49 # Offset threshold for detecting the end of a speech
pad_onset: 0.23 # Adding durations before each speech segment
pad_offset: 0.01 # Adding durations after each speech segment
min_duration_on: 0.42 # Threshold for small non-speech deletion
min_duration_off: 0.34 # Threshold for short speech segment deletion
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Postprocessing parameters for timestamp outputs from speaker diarization models.
# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper:
# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020).
# These parameters were optimized with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656.
# These parameters were optimized on the development split of DIHARD3 dataset (See https://arxiv.org/pdf/2012.01477).
# Trial 732 finished with value: 0.12171946949255649 and parameters: {'onset': 0.64, 'offset': 0.74, 'pad_onset': 0.06, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.15}. Best is trial 732 with value: 0.12171946949255649.
parameters:
onset: 0.64 # Onset threshold for detecting the beginning and end of a speech
offset: 0.74 # Offset threshold for detecting the end of a speech
pad_onset: 0.06 # Adding durations before each speech segment
pad_offset: 0.0 # Adding durations after each speech segment
min_duration_on: 0.1 # Threshold for small non-speech deletion
min_duration_off: 0.15 # Threshold for short speech segment deletion
Loading
Loading