From 22a5a5cdf87798a5768cfa920560caab4e1f5480 Mon Sep 17 00:00:00 2001 From: Ben Murrell Date: Wed, 10 Apr 2024 11:58:22 +0200 Subject: [PATCH] Numerical stability tweaks --- src/layers.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers.jl b/src/layers.jl index 9bb4fbc..f0dbab4 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -128,7 +128,7 @@ function (ipa::Union{IPCrossA, IPA})(TiL::Tuple{AbstractArray,AbstractArray}, si N_frames_L = size(siL,2) N_frames_R = size(siR,2) - gamma_h = min.(softplus(l.gamma_h),1f2) + gamma_h = softplus(clamp.(l.gamma_h,Typ(-100), Typ(100))) #Clamping w_C = Typ(sqrt(2/(9*N_query_points))) dim_scale = Typ(1/sqrt(c)) @@ -189,7 +189,7 @@ function (ipa::Union{IPCrossA, IPA})(TiL::Tuple{AbstractArray,AbstractArray}, si #ohp_r were in the global frame, so we put those back in the recipient local ohp = T_R3_inv(ohp_r, rot_TiR, translate_TiR) - normed_ohp = sqrt.(sum(ohp.^2,dims = 1)) + normed_ohp = sqrt.(sum(ohp.^2,dims = 1) .+ Typ(0.000001f0)) #Adding eps catty = vcat( reshape(oh, N_head*c, N_frames_R,:),