
Commit e47e8ba
fix!: change the default layernorm dims
avik-pal committed Aug 30, 2024
1 parent e2c942c commit e47e8ba
Showing 2 changed files with 10 additions and 14 deletions.
22 changes: 9 additions & 13 deletions src/api/layernorm.jl
@@ -1,6 +1,6 @@
 @doc doc"""
-    layernorm(x, scale, bias, σ = identity, dims=Colon(),
-        epsilon = eps(eltype(x)) ^ (5 / 7))
+    layernorm(x::AbstractArray{xT, N}, scale, bias, σ = identity, dims=1:(N - 1),
+        epsilon = eps(eltype(x)) ^ (5 / 7)) where {xT, N}
 
 Layer Normalization. For details see [1].
@@ -18,17 +18,13 @@ and applies the activation function `σ` elementwise to `y`.
 - `scale`: Scale factor (``\gamma``) (can be `nothing`)
 - `bias`: Bias factor (``\beta``) (can be `nothing`)
 - `σ`: Activation function (default: `identity`)
-- `dims`: Dimensions along which the mean and std of `x` is computed (default: `Colon()`).
-  If `nothing` is passed, the dims are inferred based on the dimensions of scale and
-  bias. For example, if `x` is `N` dimensional and `scale` and `bias` are `M`
-  dimensional, then the dims will be `1:(N - M)`.
+- `dims`: Dimensions along which the mean and std of `x` is computed. If `nothing` is
+  passed, the dims are inferred based on the dimensions of scale and bias. For example,
+  if `x` is `N` dimensional and `scale` and `bias` are `M` dimensional, then the dims
+  will be `1:(N - M)`.
 - `epsilon`: Value added to the denominator for numerical stability
   (default: `eps(eltype(x)) ^ (5 / 7)`)
-
-!!! danger "Default `dims` to be changed in v1"
-
-    By default, `dims` will exclude the batch dimension.
 
 ## Returns
 
 Normalized Array of same size as `x`.
@@ -38,9 +34,9 @@ Normalized Array of same size as `x`.
 [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv
     preprint arXiv:1607.06450 (2016).
 """
-function layernorm(x::AbstractArray{xT}, scale::Optional{<:AbstractArray},
-        bias::Optional{<:AbstractArray}, σ::F=identity, dims=Colon(),
-        epsilon::Real=default_epsilon(x)) where {F, xT}
+function layernorm(x::AbstractArray{xT, N}, scale::Optional{<:AbstractArray},
+        bias::Optional{<:AbstractArray}, σ::F=identity, dims=1:(N - 1),
+        epsilon::Real=default_epsilon(x)) where {F, xT, N}
     return layernorm_impl(
         x, scale, bias, select_fastest_activation(σ, x, scale, bias), dims, epsilon)
 end
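The practical effect of the new default, as a minimal sketch (not part of the commit; it assumes `LuxLib.layernorm` is in scope and uses illustrative array sizes):

```julia
using LuxLib, Statistics

x = randn(Float32, 4, 3, 2)  # last dimension treated as the batch

# New default: dims = 1:(N - 1), so each batch slice is normalized
# independently over the first N - 1 dimensions.
y_new = layernorm(x, nothing, nothing)
@assert isapprox(mean(y_new[:, :, 1]), 0; atol=1e-4)

# The old behavior remains reachable by passing dims explicitly:
# Colon() computes a single mean/std over the entire array,
# mixing statistics across batch elements.
y_old = layernorm(x, nothing, nothing, identity, Colon())
@assert isapprox(mean(y_old), 0; atol=1e-4)
```

Because the default moved from `Colon()` to `1:(N - 1)`, callers that relied on whole-array statistics must now pass `dims=Colon()` explicitly; that behavioral break is what the `fix!` marker in the commit message signals.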
2 changes: 1 addition & 1 deletion src/impl/Impl.jl
@@ -29,7 +29,7 @@ using NNlib: NNlib, ConvDims
 using ..LuxLib: Optional, Numeric, ∂∅, internal_operation_mode, AbstractInternalArrayOpMode,
                 GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp
 using ..Utils: Utils, NotaNumber, batchview, concrete_bias_act_output_eltype, contiguous,
-               copy_drop_gradients, depwarn, eltype_mismatch, expand_batchdim,
+               copy_drop_gradients, eltype_mismatch, expand_batchdim,
                maybe_reduce_BLAS_threads, ofeltype_array, only_derivative, remove_tracking,
                reset_BLAS_threads, run_ka_kernel, safe_eltype, safe_vec, safe_warning,
                unsafe_known, @enzyme_alternative
