diff --git a/src/layers.jl b/src/layers.jl index 8f63fab..588c8b2 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -199,7 +199,7 @@ function (ipa::Union{IPCrossA, IPA})(TiL::Tuple{AbstractArray,AbstractArray}, si Tvhp = T_R3(vhp, rot_TiL, translate_TiL) diffs_glob = unsqueeze(Tqhp, dims = 5) .- unsqueeze(Tkhp, dims = 4) - sum_norms_glob = reshape(sum(diffs_glob.^2, dims = [1,3]),N_head,N_frames_R,N_frames_L,:) #Sum over points for each head + sum_norms_glob = reshape(sum(abs2, diffs_glob, dims = [1,3]),N_head,N_frames_R,N_frames_L,:) #Sum over points for each head att_arg = reshape(dim_scale .* qhTkh .- w_C/2 .* gamma_h .* sum_norms_glob,(N_head,N_frames_R,N_frames_L, :)) @@ -241,7 +241,7 @@ function (ipa::Union{IPCrossA, IPA})(TiL::Tuple{AbstractArray,AbstractArray}, si #ohp_r were in the global frame, so we put those back in the recipient local ohp = T_R3_inv(ohp_r, rot_TiR, translate_TiR) - normed_ohp = sqrt.(sum(ohp.^2,dims = 1) .+ Typ(0.000001f0)) #Adding eps + normed_ohp = sqrt.(sum(abs2, ohp,dims = 1) .+ Typ(0.000001f0)) #Adding eps catty = vcat( reshape(oh, N_head*c, N_frames_R,:),