From 6aad0523a7e122c4a08d93aaf2c7eabce6524d34 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 25 Sep 2024 16:47:05 -0400 Subject: [PATCH] fix: rollback custom gelu implementation --- Project.toml | 2 +- src/impl/activation.jl | 38 -------------------------------------- 2 files changed, 1 insertion(+), 39 deletions(-) diff --git a/Project.toml b/Project.toml index 536aae51..d1e4779f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.3.0" +version = "1.3.1" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/impl/activation.jl b/src/impl/activation.jl index 8f39cf65..dfd1d0c9 100644 --- a/src/impl/activation.jl +++ b/src/impl/activation.jl @@ -153,7 +153,6 @@ CRC.@non_differentiable select_fastest_activation(::Any...) module SLEEFActivations using ChainRulesCore: ChainRulesCore -using EnzymeCore: EnzymeCore, EnzymeRules using NNlib: NNlib using SLEEFPirates: SLEEFPirates @@ -164,32 +163,16 @@ const CRC = ChainRulesCore sigmoid_fast(x::Number) = SLEEFPirates.sigmoid_fast(x) softplus(x::Number) = SLEEFPirates.softplus(x) logsigmoid(x::Number) = -softplus(-x) -gelu(x::Number) = SLEEFPirates.gelu(x) swish(x::Number) = Base.FastMath.mul_fast(x, sigmoid_fast(x)) lisht(x::Number) = Base.FastMath.mul_fast(x, tanh_fast(x)) tanh(x::Number) = SLEEFPirates.tanh(x) tanh_fast(x::Number) = SLEEFPirates.tanh_fast(x) -const gelu_λ = √(2 / π) -const gelu_2λ = √(8 / π) - -function ∇gelu(x::Number) - α = oftype(x, 0.044715) - α2 = oftype(x, 0.08943) - λλ = oftype(x, gelu_2λ) - x2 = Base.FastMath.mul_fast(x, x) - t = muladd(x2, α, one(x)) - Ω = sigmoid_fast(λλ * x * t) - dσ = conj(Ω * (1 - Ω)) - return muladd(dσ * λλ * muladd(x2, α2, t), x, Ω) -end - for (f, dfdx) in [ #! format: off (:sigmoid_fast, :(conj(Base.FastMath.mul_fast(Ω, Base.FastMath.sub_fast(1, Ω))))), (:softplus, :(sigmoid_fast(x))), (:logsigmoid, :(sigmoid_fast(-x))), - (:gelu, :(∇gelu(x))), (:swish, :(Base.FastMath.add_fast(Ω, Base.FastMath.mul_fast(sigmoid_fast(x), Base.FastMath.sub_fast(1, Ω))))), (:lisht, :(Base.FastMath.add_fast(x, Base.FastMath.mul_fast(tanh_fast(x), Base.FastMath.sub_fast(1, Ω))))), (:tanh, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))), @@ -210,26 +193,6 @@ for (f, dfdx) in [ end end -# Enzyme works for all of these except `gelu`. -# See https://github.com/EnzymeAD/Enzyme.jl/issues/1671 -function EnzymeRules.augmented_primal( - cfg::EnzymeRules.RevConfigWidth{1}, func::EnzymeCore.Const{typeof(gelu)}, - ::Type{<:EnzymeCore.Active}, x::EnzymeCore.Active{<:Number}) - primal = EnzymeRules.needs_primal(cfg) ? func.val(x.val) : nothing - return EnzymeRules.AugmentedReturn(primal, nothing, nothing) -end - -function EnzymeRules.reverse( - ::EnzymeRules.RevConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu)}, - dret::EnzymeCore.Active, ::Nothing, x::EnzymeCore.Active{<:Number}) - return (dret.val * ∇gelu(x.val),) -end - -function EnzymeRules.forward(::EnzymeRules.FwdConfig, ::EnzymeCore.Const{typeof(gelu)}, - ::Type{<:EnzymeCore.Duplicated}, x::EnzymeCore.Duplicated{<:Number}) - return EnzymeCore.Duplicated(gelu(x.val), x.dval * ∇gelu(x.val)) -end - fast_act(f::F, ::Type{T}) where {F, T} = f fast_act(f::F, ::Type{Float32}) where {F} = fast_act(f) @@ -238,7 +201,6 @@ for (fbase, ffast) in [ (NNlib.sigmoid_fast, sigmoid_fast), (NNlib.softplus, softplus), (NNlib.logsigmoid, logsigmoid), - (NNlib.gelu, gelu), (NNlib.swish, swish), (NNlib.lisht, lisht), (Base.tanh, tanh),