Skip to content

Commit

Permalink
mod out L2norm, T_R3_inv
Browse files Browse the repository at this point in the history
  • Loading branch information
billera committed May 4, 2024
1 parent 42a95fd commit 76d435f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions src/grads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function _T_R3_inv_no_rrule(mat, rot,trans)
rotated_mat = batched_mul(rotc,matc .- trans)
return reshape(rotated_mat,size_mat)
end

#=
function diff_sum_glob(T, q, k)
bs = size(q)
qresh = reshape(q, size(q,1), size(q,2)*size(q,3), size(q,4),size(q,5))
Expand All @@ -146,8 +146,8 @@ function _diff_sum_glob_no_rrule(T,q,k)
Tq, Tk = _T_R3_no_rrule(qresh,T[1],T[2]),_T_R3_no_rrule(kresh,T[1],T[2])
Tq, Tk = reshape(Tq, bs...), reshape(Tk, bs...)
diffs = _sumabs2_no_rrule(_pair_diff_no_rrule(Tq, Tk, dims = 4),dims=[1,3])
end

end=#
#=
# not implemented grad with respect to T here, as is not needed in any applications for now
# this rrule provides ≈ 2x memory improvement by computing query and key grads simultaneously
function ChainRulesCore.rrule(::typeof(diff_sum_glob), T, q, k)
Expand Down Expand Up @@ -211,7 +211,7 @@ function ChainRulesCore.rrule(::typeof(qhTkh), q, k)
end
return qhTkh, qhTkh_pullback
end

=#
"""
softmax1(x, dims = 1)
Expand Down
4 changes: 2 additions & 2 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,8 @@ function ipa_customgrad(ipa::Union{IPCrossA, IPA}, Ti::Tuple{AbstractArray,Abstr
ohp_r = reshape(sum(broadcast_att_ohp.*broadcast_tvhp,dims=5),3,N_head*N_point_values,N_frames_R,:)
end
#ohp_r were in the global frame, so we put those back in the recipient local
ohp = T_R3_inv(ohp_r, rot_TiR, translate_TiR)
normed_ohp = L2norm(ohp, eps = Typ(0.000001f0)) #Adding eps
ohp = _T_R3_inv_no_rrule(ohp_r, rot_TiR, translate_TiR)
normed_ohp = sqrt.(sumabs2(ohp, dims = 1) .+ Typ(0.000001f0)) #Adding eps
catty = vcat(
reshape(oh, N_head*c, N_frames_R,:),
reshape(ohp, 3*N_head*N_point_values, N_frames_R,:),
Expand Down

0 comments on commit 76d435f

Please sign in to comment.