From 6d16ceff8695fd55c685a7ed7be494e5851f4d90 Mon Sep 17 00:00:00 2001 From: billera Date: Wed, 1 May 2024 16:27:16 +0200 Subject: [PATCH] general type aijh --- src/grads.jl | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/grads.jl b/src/grads.jl index addc1c6..cfe2716 100644 --- a/src/grads.jl +++ b/src/grads.jl @@ -100,7 +100,6 @@ function ChainRulesCore.rrule(::typeof(T_R3_inv), A, R, t; dims = 1) batch_size = size(A)[3:end] R2 = reshape(R, size(R,1), size(R,2), :) Δ2 = reshape(Δ, size(Δ,1), size(Δ,2), :) - #@show size(R2), size(Δ2) ΔA = batched_mul(R2, Δ2) reshape(ΔA, size(ΔA, 1), size(ΔA, 2), batch_size...) end @@ -160,7 +159,6 @@ function ChainRulesCore.rrule(::typeof(diff_sum_glob), T, q, k) Tq, Tk = reshape(Tq, bs...), reshape(Tk, bs...) pair_diffs, pair_diffs_pullback = rrule(pair_diff, Tq, Tk, dims = 4) sabs2, sabs2_pullback = rrule(sumabs2, pair_diffs, dims = [1,3]) - #@show size(sabs2) function diff_sum_glob_pullback(_Δ) # Our applications always use these for now, so no thunk since we easily save some compute by sharing ops @@ -250,11 +248,11 @@ function ChainRulesCore.rrule(::typeof(softmax1), x; dims = 1) end -function pre_softmax_aijh(qh,kh,T,qhp,khp, bij, gamma_h) - w_C = Float32(sqrt(2f0/(9f0*size(qhp,3)))) - dim_scale = Float32(1f0/sqrt(size(qh,1))) - w_L = Float32(1f0/sqrt(3f0)) - - w_L.*(dim_scale.*qhTkh(qh,kh) .+ bij .- w_C/2 .* gamma_h .* dropdims(diff_sum_glob(T,qhp,khp),dims=(1,3))) +function pre_softmax_aijh(qh::AbstractArray{T},kh::AbstractArray{T},Ti,qhp::AbstractArray{T},khp::AbstractArray{T}, bij::AbstractArray{T}, gamma_h::AbstractArray{T}) where T + w_C = T(sqrt(2f0/(9f0*size(qhp,3)))) + dim_scale = T(1f0/sqrt(size(qh,1))) + w_L = T(1f0/sqrt(3f0)) + + w_L.*(dim_scale.*qhTkh(qh,kh) .+ bij .- w_C/2 .* gamma_h .* dropdims(diff_sum_glob(Ti,qhp,khp),dims=(1,3))) end