diff --git a/src/grads.jl b/src/grads.jl
index b27ace2..966bfea 100644
--- a/src/grads.jl
+++ b/src/grads.jl
@@ -83,7 +83,7 @@ function ChainRulesCore.rrule(::typeof(T_R3_inv), A, R, t; dims = 1)
         batch_size = size(A)[3:end]
         R2 = reshape(R, size(R,1), size(R,2), :)
         Δ2 = reshape(Δ, size(Δ,1), size(Δ,2), :)
-        @show size(R2), size(Δ2)
+        #@show size(R2), size(Δ2)
         ΔA = batched_mul(R2, Δ2)
         reshape(ΔA, size(ΔA, 1), size(ΔA, 2), batch_size...)
     end
@@ -143,7 +143,7 @@ function ChainRulesCore.rrule(::typeof(diff_sum_glob), T, q, k)
     Tq, Tk = reshape(Tq, bs...), reshape(Tk, bs...)
     pair_diffs, pair_diffs_pullback = rrule(pair_diff, Tq, Tk, dims = 4)
     sabs2, sabs2_pullback = rrule(sumabs2, pair_diffs, dims = [1,3])
-    @show size(sabs2)
+    #@show size(sabs2)

     function diff_sum_glob_pullback(_Δ)
         # Our applications always use these for now, so no thunk since we easily save some compute by sharing ops
@@ -239,4 +239,5 @@
 function pre_softmax_aijh(qh,kh,T,qhp,khp, bij, gamma_h)
     w_L = Float32(1f0/sqrt(3f0))
     w_L.*(dim_scale.*qhTkh(qh,kh) .+ bij .- w_C/2 .* gamma_h .* dropdims(diff_sum_glob(T,qhp,khp),dims=(1,3)))
-end
\ No newline at end of file
+end
+