diff --git a/src/layers/mul_layer.jl b/src/layers/mul_layer.jl index 983505cc..02804a40 100644 --- a/src/layers/mul_layer.jl +++ b/src/layers/mul_layer.jl @@ -31,5 +31,5 @@ function LuxCore.outputsize(m::MulLayer) end @inline function (m::MulLayer)(x::AbstractVecOrMat, ps::Any, st::NamedTuple) - return Lux.apply_activation(m.activation, Octavian.matmul(ps.weight, x)), st + m.activation.(Octavian.matmul(ps.weight, x)), st end diff --git a/src/layers/planar_layer.jl b/src/layers/planar_layer.jl index 3f43412d..6b087abe 100644 --- a/src/layers/planar_layer.jl +++ b/src/layers/planar_layer.jl @@ -71,34 +71,33 @@ function LuxCore.outputsize(m::PlanarLayer) end @inline function (m::PlanarLayer{true})(z::AbstractVector, ps::Any, st::NamedTuple) - ps.u * Lux.apply_bias_activation(m.activation, LinearAlgebra.dot(ps.w, z), only(ps.b)), - st + ps.u * m.activation.(LinearAlgebra.dot(ps.w, z) + only(ps.b)), st end @inline function (m::PlanarLayer{true})(z::AbstractMatrix, ps::Any, st::NamedTuple) - ps.u * Lux.apply_bias_activation(m.activation, transpose(ps.w) * z, only(ps.b)), st + ps.u * m.activation.(muladd(transpose(ps.w), z, only(ps.b))), st end @inline function (m::PlanarLayer{false})(z::AbstractVector, ps::Any, st::NamedTuple) - ps.u * Lux.apply_activation(m.activation, LinearAlgebra.dot(ps.w, z)), st + ps.u * m.activation.(LinearAlgebra.dot(ps.w, z)), st end @inline function (m::PlanarLayer{false})(z::AbstractMatrix, ps::Any, st::NamedTuple) - ps.u * Lux.apply_activation(m.activation, transpose(ps.w) * z), st + ps.u * m.activation.(transpose(ps.w) * z), st end @inline function pl_h(m::PlanarLayer{true}, z::AbstractVector, ps::Any, st::NamedTuple) - Lux.apply_bias_activation(m.activation, LinearAlgebra.dot(ps.w, z), only(ps.b)), st + m.activation.(LinearAlgebra.dot(ps.w, z) + only(ps.b)), st end @inline function pl_h(m::PlanarLayer{true}, z::AbstractMatrix, ps::Any, st::NamedTuple) - Lux.apply_bias_activation(m.activation, transpose(ps.w) * z, only(ps.b)), st + m.activation.(muladd(transpose(ps.w), z, only(ps.b))), st end @inline function pl_h(m::PlanarLayer{false}, z::AbstractVector, ps::Any, st::NamedTuple) - Lux.apply_activation(m.activation, LinearAlgebra.dot(ps.w, z)), st + m.activation.(LinearAlgebra.dot(ps.w, z)), st end @inline function pl_h(m::PlanarLayer{false}, z::AbstractMatrix, ps::Any, st::NamedTuple) - Lux.apply_activation(m.activation, transpose(ps.w) * z), st + m.activation.(transpose(ps.w) * z), st end