Skip to content

Commit

Permalink
f32 to Float32
Browse files Browse the repository at this point in the history
  • Loading branch information
billera committed May 1, 2024
1 parent f180e91 commit 34c58a7
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/grads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,9 @@ end


function pre_softmax_aijh(qh,kh,T,qhp,khp, bij, gamma_h)
w_C = f32(sqrt(2/(9*size(qhp,3))))
dim_scale = f32(1/sqrt(size(qh,1)))
w_L = f32(1f0/sqrt(3f0))
w_C = Float32(sqrt(2f0/(9f0*size(qhp,3))))
dim_scale = Float32(1f0/sqrt(size(qh,1)))
w_L = Float32(1f0/sqrt(3f0))

w_L.*(dim_scale.*qhTkh(qh,kh) .+ bij .- w_C/2 .* gamma_h .* dropdims(diff_sum_glob(T,qhp,khp),dims=(1,3)))
end
Expand Down

0 comments on commit 34c58a7

Please sign in to comment.