From d1dfde2c9deaccdd281be8a7990c1e53ecb0fcb2 Mon Sep 17 00:00:00 2001 From: Teo Date: Sun, 23 Jun 2024 00:20:11 +0900 Subject: [PATCH 1/2] cast grad_scale in whiten to float --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e7c3f4ab12..60a348010f 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1032,7 +1032,7 @@ def backward(ctx, x_grad: Tensor): w.prob = w.max_prob metric.backward() penalty_grad = x_detached.grad - scale = w.grad_scale * ( + scale = float(w.grad_scale) * ( x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) ) @@ -1074,7 +1074,7 @@ def __init__( super(Whiten, self).__init__() assert num_groups >= 1 assert float(whitening_limit) >= 1 - assert grad_scale >= 0 + assert float(grad_scale) >= 0 self.num_groups = num_groups self.whitening_limit = whitening_limit self.grad_scale = grad_scale From c14fd3872dc7f84e0e37dcb9d48eec9efdec0929 Mon Sep 17 00:00:00 2001 From: Teo Date: Sun, 23 Jun 2024 00:22:43 +0900 Subject: [PATCH 2/2] fix cast in zipformer_lora --- egs/librispeech/ASR/zipformer_lora/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py index 3149db9f3c..8d7aa80275 100644 --- a/egs/librispeech/ASR/zipformer_lora/scaling.py +++ b/egs/librispeech/ASR/zipformer_lora/scaling.py @@ -1137,7 +1137,7 @@ def backward(ctx, x_grad: Tensor): w.prob = w.max_prob metric.backward() penalty_grad = x_detached.grad - scale = w.grad_scale * ( + scale = float(w.grad_scale) * ( x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) ) @@ -1179,7 +1179,7 @@ def __init__( super(Whiten, self).__init__() assert num_groups >= 1 assert float(whitening_limit) >= 1 - assert grad_scale >= 0 + assert float(grad_scale) >= 0 self.num_groups = num_groups self.whitening_limit = whitening_limit self.grad_scale = grad_scale