diff --git a/src/grads.jl b/src/grads.jl
index 8814055..a4e0958 100644
--- a/src/grads.jl
+++ b/src/grads.jl
@@ -129,7 +129,7 @@ function _T_R3_inv_no_rrule(mat, rot,trans)
     rotated_mat = batched_mul(rotc,matc .- trans)
     return reshape(rotated_mat,size_mat)
 end
-
+#=
 function diff_sum_glob(T, q, k)
     bs = size(q)
     qresh = reshape(q, size(q,1), size(q,2)*size(q,3), size(q,4),size(q,5))
@@ -146,8 +146,8 @@ function _diff_sum_glob_no_rrule(T,q,k)
     Tq, Tk = _T_R3_no_rrule(qresh,T[1],T[2]),_T_R3_no_rrule(kresh,T[1],T[2])
     Tq, Tk = reshape(Tq, bs...), reshape(Tk, bs...)
     diffs = _sumabs2_no_rrule(_pair_diff_no_rrule(Tq, Tk, dims = 4),dims=[1,3])
-end
-
+end=#
+#=
 # not implemented grad with respect to T here, as is not needed in any applications for now
 # this rrule provides ≈ 2x memory improvement by computing query and key grads simultaneously
 function ChainRulesCore.rrule(::typeof(diff_sum_glob), T, q, k)
@@ -211,7 +211,7 @@ function ChainRulesCore.rrule(::typeof(qhTkh), q, k)
     end
     return qhTkh, qhTkh_pullback
 end
-
+=#
 """
     softmax1(x, dims = 1)
 
diff --git a/src/layers.jl b/src/layers.jl
index fa19374..b84cacf 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -297,8 +297,8 @@ function ipa_customgrad(ipa::Union{IPCrossA, IPA}, Ti::Tuple{AbstractArray,Abstr
         ohp_r = reshape(sum(broadcast_att_ohp.*broadcast_tvhp,dims=5),3,N_head*N_point_values,N_frames_R,:)
     end
     #ohp_r were in the global frame, so we put those back in the recipient local
-    ohp = T_R3_inv(ohp_r, rot_TiR, translate_TiR)
-    normed_ohp = L2norm(ohp, eps = Typ(0.000001f0)) #Adding eps
+    ohp = _T_R3_inv_no_rrule(ohp_r, rot_TiR, translate_TiR)
+    normed_ohp = sqrt.(sumabs2(ohp, dims = 1) .+ Typ(0.000001f0)) #Adding eps
     catty = vcat(
         reshape(oh, N_head*c, N_frames_R,:),
         reshape(ohp, 3*N_head*N_point_values, N_frames_R,:),
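
Notes on the changes above (not part of the patch):

The `#= ... =#` block comments in `src/grads.jl` disable `diff_sum_glob`, `_diff_sum_glob_no_rrule`, and the custom `ChainRulesCore.rrule`s for `diff_sum_glob` and `qhTkh`, so the AD backend falls back to differentiating the primal code directly. For reference, a minimal sketch of the fused-pullback pattern those rrules used, where both input gradients come from one pullback over shared intermediates; `pairscore` is a made-up stand-in, not a package function:

```julia
using ChainRulesCore

# Illustrative stand-in for a q/k scoring function.
pairscore(q, k) = sum(q .* k)

function ChainRulesCore.rrule(::typeof(pairscore), q, k)
    y = pairscore(q, k)
    function pairscore_pullback(Δy)
        # Query and key grads come out of one pullback, reusing the same
        # intermediates; this sharing is the source of the ≈2x memory
        # saving noted in the commented-out rrule.
        Δq = Δy .* k
        Δk = Δy .* q
        return (NoTangent(), Δq, Δk)
    end
    return y, pairscore_pullback
end
```

In `src/layers.jl`, the recipient-frame transform now calls the `_no_rrule` variant directly, and the `L2norm` helper call is replaced by an inlined eps-stabilized norm. A small sketch of what the new expression computes, assuming `sumabs2(x; dims)` behaves like `sum(abs2, x; dims)` (the helper definition below is an assumption, not the package's):

```julia
# Assumed equivalent of the package's sumabs2 helper.
sumabs2(x; dims) = sum(abs2, x; dims = dims)

x = randn(Float32, 3, 8, 2)   # (xyz, N_head*N_point_values, N_frames_R)
ϵ = 1.0f-6                    # matches Typ(0.000001f0) in the diff

normed = sqrt.(sumabs2(x, dims = 1) .+ ϵ)          # inlined form
@assert normed ≈ sqrt.(sum(x .^ 2, dims = 1) .+ ϵ) # spelled-out L2 norm

# The eps keeps sqrt differentiable when a point vector is exactly zero,
# where the gradient of the plain L2 norm is undefined (NaN under AD).
```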