Skip to content

Commit

Permalink
Merge pull request #9 from MurrellGroup/numerical-patch-1
Browse files Browse the repository at this point in the history
Numerical stability tweaks
  • Loading branch information
murrellb authored Apr 10, 2024
2 parents 1b6320a + 22a5a5c commit 95eede6
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ function (ipa::Union{IPCrossA, IPA})(TiL::Tuple{AbstractArray,AbstractArray}, si
N_frames_L = size(siL,2)
N_frames_R = size(siR,2)

gamma_h = min.(softplus(l.gamma_h),1f2)
gamma_h = softplus(clamp.(l.gamma_h,Typ(-100), Typ(100))) #Clamping

w_C = Typ(sqrt(2/(9*N_query_points)))
dim_scale = Typ(1/sqrt(c))
Expand Down Expand Up @@ -189,7 +189,7 @@ function (ipa::Union{IPCrossA, IPA})(TiL::Tuple{AbstractArray,AbstractArray}, si

#ohp_r were in the global frame, so we put those back in the recipient local
ohp = T_R3_inv(ohp_r, rot_TiR, translate_TiR)
normed_ohp = sqrt.(sum(ohp.^2,dims = 1))
normed_ohp = sqrt.(sum(ohp.^2,dims = 1) .+ Typ(0.000001f0)) #Adding eps

catty = vcat(
reshape(oh, N_head*c, N_frames_R,:),
Expand Down

0 comments on commit 95eede6

Please sign in to comment.