diff --git a/src/api/layernorm.jl b/src/api/layernorm.jl index c374a6e1..eb147d30 100644 --- a/src/api/layernorm.jl +++ b/src/api/layernorm.jl @@ -1,6 +1,6 @@ @doc doc""" - layernorm(x, scale, bias, σ = identity, dims=Colon(), - epsilon = eps(eltype(x)) ^ (5 / 7)) + layernorm(x::AbstractArray{xT, N}, scale, bias, σ = identity, dims=1:(N - 1), + epsilon = eps(eltype(x)) ^ (5 / 7)) where {xT, N} Layer Normalization. For details see [1]. @@ -18,17 +18,13 @@ and applies the activation function `σ` elementwise to `y`. - `scale`: Scale factor (``\gamma``) (can be `nothing`) - `bias`: Bias factor (``\beta``) (can be `nothing`) - `σ`: Activation function (default: `identity`) - - `dims`: Dimensions along which the mean and std of `x` is computed (default: `Colon()`). - If `nothing` is passed, the dims are inferred based on the dimensions of scale and - bias. For example, if `x` is `N` dimensional and `scale` and `bias` are `M` - dimensional, then the dims will be `1:(N - M)`. + - `dims`: Dimensions along which the mean and std of `x` is computed. If `nothing` is + passed, the dims are inferred based on the dimensions of scale and bias. For example, + if `x` is `N` dimensional and `scale` and `bias` are `M` dimensional, then the dims + will be `1:(N - M)`. - `epsilon`: Value added to the denominator for numerical stability (default: `eps(eltype(x)) ^ (5 / 7)`) -!!! danger "Default `dims` to be changed in v1" - - By default, `dims` will exclude the batch dimension. - ## Returns Normalized Array of same size as `x`. @@ -38,9 +34,9 @@ Normalized Array of same size as `x`. [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016). """ -function layernorm(x::AbstractArray{xT}, scale::Optional{<:AbstractArray}, - bias::Optional{<:AbstractArray}, σ::F=identity, dims=Colon(), - epsilon::Real=default_epsilon(x)) where {F, xT} +function layernorm(x::AbstractArray{xT, N}, scale::Optional{<:AbstractArray}, + bias::Optional{<:AbstractArray}, σ::F=identity, dims=1:(N - 1), + epsilon::Real=default_epsilon(x)) where {F, xT, N} return layernorm_impl( x, scale, bias, select_fastest_activation(σ, x, scale, bias), dims, epsilon) end diff --git a/src/impl/Impl.jl b/src/impl/Impl.jl index fd2a128e..7a040456 100644 --- a/src/impl/Impl.jl +++ b/src/impl/Impl.jl @@ -29,7 +29,7 @@ using NNlib: NNlib, ConvDims using ..LuxLib: Optional, Numeric, ∂∅, internal_operation_mode, AbstractInternalArrayOpMode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp using ..Utils: Utils, NotaNumber, batchview, concrete_bias_act_output_eltype, contiguous, - copy_drop_gradients, depwarn, eltype_mismatch, expand_batchdim, + copy_drop_gradients, eltype_mismatch, expand_batchdim, maybe_reduce_BLAS_threads, ofeltype_array, only_derivative, remove_tracking, reset_BLAS_threads, run_ka_kernel, safe_eltype, safe_vec, safe_warning, unsafe_known, @enzyme_alternative