Skip to content

Commit

Permalink
use sum(abs2)
Browse files Browse the repository at this point in the history
  • Loading branch information
billera committed Apr 18, 2024
1 parent 064958c commit 08704c3
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ function (ipa::Union{IPCrossA, IPA})(TiL::Tuple{AbstractArray,AbstractArray}, si
Tvhp = T_R3(vhp, rot_TiL, translate_TiL)

diffs_glob = unsqueeze(Tqhp, dims = 5) .- unsqueeze(Tkhp, dims = 4)
sum_norms_glob = reshape(sum(diffs_glob.^2, dims = [1,3]),N_head,N_frames_R,N_frames_L,:) #Sum over points for each head
sum_norms_glob = reshape(sum(abs2, diffs_glob, dims = [1,3]),N_head,N_frames_R,N_frames_L,:) #Sum over points for each head


att_arg = reshape(dim_scale .* qhTkh .- w_C/2 .* gamma_h .* sum_norms_glob,(N_head,N_frames_R,N_frames_L, :))
Expand Down Expand Up @@ -241,7 +241,7 @@ function (ipa::Union{IPCrossA, IPA})(TiL::Tuple{AbstractArray,AbstractArray}, si

#ohp_r were in the global frame, so we put those back in the recipient local
ohp = T_R3_inv(ohp_r, rot_TiR, translate_TiR)
normed_ohp = sqrt.(sum(ohp.^2,dims = 1) .+ Typ(0.000001f0)) #Adding eps
normed_ohp = sqrt.(sum(abs2, ohp,dims = 1) .+ Typ(0.000001f0)) #Adding eps

catty = vcat(
reshape(oh, N_head*c, N_frames_R,:),
Expand Down

0 comments on commit 08704c3

Please sign in to comment.