Skip to content

Commit

Permalink
general type aijh
Browse files Browse the repository at this point in the history
  • Loading branch information
billera committed May 1, 2024
1 parent 34c58a7 commit 6d16cef
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions src/grads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ function ChainRulesCore.rrule(::typeof(T_R3_inv), A, R, t; dims = 1)
batch_size = size(A)[3:end]
R2 = reshape(R, size(R,1), size(R,2), :)
Δ2 = reshape(Δ, size(Δ,1), size(Δ,2), :)
#@show size(R2), size(Δ2)
ΔA = batched_mul(R2, Δ2)
reshape(ΔA, size(ΔA, 1), size(ΔA, 2), batch_size...)
end
Expand Down Expand Up @@ -160,7 +159,6 @@ function ChainRulesCore.rrule(::typeof(diff_sum_glob), T, q, k)
Tq, Tk = reshape(Tq, bs...), reshape(Tk, bs...)
pair_diffs, pair_diffs_pullback = rrule(pair_diff, Tq, Tk, dims = 4)
sabs2, sabs2_pullback = rrule(sumabs2, pair_diffs, dims = [1,3])
#@show size(sabs2)

function diff_sum_glob_pullback(_Δ)
# Our applications always use these for now, so no thunk since we easily save some compute by sharing ops
Expand Down Expand Up @@ -250,11 +248,11 @@ function ChainRulesCore.rrule(::typeof(softmax1), x; dims = 1)
end


function pre_softmax_aijh(qh,kh,T,qhp,khp, bij, gamma_h)
w_C = Float32(sqrt(2f0/(9f0*size(qhp,3))))
dim_scale = Float32(1f0/sqrt(size(qh,1)))
w_L = Float32(1f0/sqrt(3f0))

w_L.*(dim_scale.*qhTkh(qh,kh) .+ bij .- w_C/2 .* gamma_h .* dropdims(diff_sum_glob(T,qhp,khp),dims=(1,3)))
function pre_softmax_aijh(qh::AbstractArray{T},kh::AbstractArray{T},Ti,qhp::AbstractArray{T},khp::AbstractArray{T}, bij::AbstractArray{T}, gamma_h::AbstractArray{T}) where T
w_C = T(sqrt(2f0/(9f0*size(qhp,3))))
dim_scale = T(1f0/sqrt(size(qh,1)))
w_L = T(1f0/sqrt(3f0))
w_L.*(dim_scale.*qhTkh(qh,kh) .+ bij .- w_C/2 .* gamma_h .* dropdims(diff_sum_glob(Ti,qhp,khp),dims=(1,3)))
end

0 comments on commit 6d16cef

Please sign in to comment.