-
Notifications
You must be signed in to change notification settings - Fork 299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Prevent large values in conv module in wav2vec2_module.py in SSL recipe #1593
base: master
Are you sure you want to change the base?
Conversation
I'm moving the conversation here since the previous PR was closed. I ran @yfyeung 's training command using the merged k2ssl codes, with the batch-size and world-size adjusted to fit my environment. However, training crashed at epoch 31. I then implemented @danpovey 's recommended changes, and reran training from the checkpoint at epoch 26. This time, training crashed at epoch 30. The inf_check for this epoch 30 are as below. Infinity occurred at the forward pass of I have the diagnostics too for the start of epoch 30, but the file is too big to attach on Github. I can send it through email too if you'd like to take a look at it. Also, I'm wondering if I should rerun training from scratch using the recommended changes of penalizing large abs_value, or change the |
Hi, batch size is crucial for SSL. When batch size decreases, gradient noise becomes very large, which has a bad impact on half-precision and convergence. |
I see. This is my training command: python zipformer/pretrain.py \
--world-size 4 \
--num-epochs 100 \
--start-epoch 30 \
--use-fp16 1 \
--exp-dir zipformer/exp3/pretrain \
--manifest-dir data/raw \
--full-libri 1 \
--max-duration 300 \
--accum-grad 4 \
--do-normalize 0 \
--mask-prob 0.8 \
--dropout-input 0.1 \
--dropout-features 0.1 \
--feature-grad-mult 0.1 \
--untie-final-proj 1 \
--num-encoder-layers "2,2,3,4,3,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--encoder-dim "192,256,448,768,448,192" \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--base-lr 0.045 A little explanation about how I decided my batch size:
Thank you so much! I will have a look! |
The current gradient accumulation mechanism simulates multi-GPU setup. You can simulate my setup using 4 GPUs with IMO, keeping 8 GPUs and the same |
OK, this error is different from the error you got before. It's the grads that are infinite, not the activations:
Note that these are the grads after aggregating over the batch. Too-large grads on the first input convolution layer when training with fp16 are a problem I have noticed before. Probably the most certain fix would be to insert a ScaleGrad module to scale the grad down at that point during the backprop. E.g. insert it after conv_layers.0, into the nn.Sequential's list or something. See how I have used that in the zipformer recipe, to solve a similar issue. BTW, my approach with the zipformer recipe has been to fix instability or crashes one by one like this, as they appear, in the hope that after we address all the failure modes the recipe should be quite robust. |
Thank you for your guidance! By comparing if is_layer_norm:
return nn.Sequential(
make_conv(),
ScaleGrad(0.5),
nn.Dropout(p=dropout),
nn.Sequential(
TransposeLast(),
Fp32LayerNorm(dim, elementwise_affine=True),
TransposeLast(),
),
nn.GELU(),
)
elif is_group_norm:
return nn.Sequential(
make_conv(),
ScaleGrad(0.5),
nn.Dropout(p=dropout),
Fp32GroupNorm(dim, dim, affine=True),
nn.GELU(),
)
else:
return nn.Sequential(
make_conv(),
ScaleGrad(0.5),
nn.Dropout(p=dropout),
nn.GELU(),
) Hopefully the anonymity period ends soon. Meanwhile, I will continue the debugging on my end. I will report again if I find something that works for my setup. |
No description provided.