Skip to content

Commit

Permalink
remove @show
Browse files: browse the repository at this point in the history
  • Loading branch information
billera committed Apr 29, 2024
1 parent 313d8ca commit 80c9933
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/grads.jl
Original file line number | Diff line number | Diff line change
Expand Up @@ -83,7 +83,7 @@ function ChainRulesCore.rrule(::typeof(T_R3_inv), A, R, t; dims = 1)
batch_size = size(A)[3:end]
R2 = reshape(R, size(R,1), size(R,2), :)
Δ2 = reshape(Δ, size(Δ,1), size(Δ,2), :)
- @show size(R2), size(Δ2)
+ #@show size(R2), size(Δ2)
ΔA = batched_mul(R2, Δ2)
reshape(ΔA, size(ΔA, 1), size(ΔA, 2), batch_size...)
end
Expand Down Expand Up @@ -143,7 +143,7 @@ function ChainRulesCore.rrule(::typeof(diff_sum_glob), T, q, k)
Tq, Tk = reshape(Tq, bs...), reshape(Tk, bs...)
pair_diffs, pair_diffs_pullback = rrule(pair_diff, Tq, Tk, dims = 4)
sabs2, sabs2_pullback = rrule(sumabs2, pair_diffs, dims = [1,3])
- @show size(sabs2)
+ #@show size(sabs2)

function diff_sum_glob_pullback(_Δ)
# Our applications always use these for now, so no thunk since we easily save some compute by sharing ops
Expand Down Expand Up @@ -239,4 +239,5 @@ function pre_softmax_aijh(qh,kh,T,qhp,khp, bij, gamma_h)
w_L = Float32(1f0/sqrt(3f0))

w_L.*(dim_scale.*qhTkh(qh,kh) .+ bij .- w_C/2 .* gamma_h .* dropdims(diff_sum_glob(T,qhp,khp),dims=(1,3)))
- end
+ end

0 comments on commit 80c9933

Please sign in to comment.