diff --git a/ext/LuxLibReverseDiffExt.jl b/ext/LuxLibReverseDiffExt.jl index 6f56b279..4e15e0ab 100644 --- a/ext/LuxLibReverseDiffExt.jl +++ b/ext/LuxLibReverseDiffExt.jl @@ -58,6 +58,10 @@ Utils.remove_tracking(x::TrackedArray) = ReverseDiff.value(x) Utils.remove_tracking(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x) Utils.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = Utils.remove_tracking(T) +Utils.within_gradient(::TrackedReal) = True() +Utils.within_gradient(::TrackedArray) = True() +Utils.within_gradient(::AbstractArray{<:TrackedReal}) = True() + # Traits extensions Traits.is_tracked(::Type{<:TrackedReal}) = True() diff --git a/ext/LuxLibTrackerExt.jl b/ext/LuxLibTrackerExt.jl index e02c25f8..fa9ffd34 100644 --- a/ext/LuxLibTrackerExt.jl +++ b/ext/LuxLibTrackerExt.jl @@ -93,6 +93,10 @@ Utils.remove_tracking(x::TrackedArray) = Tracker.data(x) Utils.remove_tracking(x::AbstractArray{<:TrackedReal}) = Tracker.data.(x) Utils.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = Utils.remove_tracking(T) +Utils.within_gradient(::TrackedReal) = True() +Utils.within_gradient(::TrackedArray) = True() +Utils.within_gradient(::AbstractArray{<:TrackedReal}) = True() + # Traits extensions Traits.is_tracked(::Type{<:TrackedReal}) = True() diff --git a/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl index c2468e72..77e59d3e 100644 --- a/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl +++ b/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -29,9 +29,6 @@ end function CRC.rrule( ::typeof(Impl.batchnorm_cudnn), γ, β, x, rμ, rσ², m, ϵ, training::StaticBool) - # TODO: Transition this to an error in the future - unsafe_known(training) || - @warn "`training=Val(false)` but gradient was called." maxlog=1 y, xμ, xσ⁻² = Impl.batchnorm_cudnn(γ, β, x, rμ, rσ², m, ϵ, training) 𝒫x, 𝒫γ, 𝒫β = CRC.ProjectTo(x), CRC.ProjectTo(γ), CRC.ProjectTo(β) ∇batchnorm_cudnn = @closure Δ -> begin diff --git a/src/api/API.jl b/src/api/API.jl index e353c9b2..d222d92e 100644 --- a/src/api/API.jl +++ b/src/api/API.jl @@ -8,10 +8,12 @@ using Static: Static, StaticBool, static using ..LuxLib: Optional using ..Impl: Impl, select_fastest_activation -using ..Utils: default_epsilon, expand_batchdim, remove_tracking +using ..Utils: default_epsilon, expand_batchdim, remove_tracking, static_training_mode const CRC = ChainRulesCore +const TrainingType = Union{Val{true}, Val{false}, StaticBool, Nothing} + # The names are aliased so we define constants for them for op in (:batched_matmul, :batchnorm, :bias_activation, :bias_activation!!, :dropout, :alpha_dropout, :groupnorm, :instancenorm, :layernorm, diff --git a/src/api/batchnorm.jl b/src/api/batchnorm.jl index 3f55c387..05964f0c 100644 --- a/src/api/batchnorm.jl +++ b/src/api/batchnorm.jl @@ -1,5 +1,5 @@ @doc doc""" - batchnorm(x, scale, bias, running_mean, running_var, training::Union{Val, StaticBool}, + batchnorm(x, scale, bias, running_mean, running_var, training, σ=identity, momentum = 0.1f0, epsilon = eps(eltype(x)) ^ (5 // 7)) Batch Normalization. For details see [1]. @@ -15,7 +15,9 @@ accordingly. - `bias`: Bias factor (``\beta``) (can be `nothing`) - `running_mean`: Running mean (can be `nothing`) - `running_var`: Running variance (can be `nothing`) - - `training`: Set to `Val(true)` if running in training mode + - `training`: Set to `Val(true)` or `True()` if running in training mode. 
Can be set to + `nothing` to automatically determine if the function is being called within an autodiff + context - `σ`: Activation function (default: `identity`) - `momentum`: Momentum for updating running mean and variance (default: `0.1f0`) - `epsilon`: Value added to the denominator for numerical stability @@ -34,11 +36,11 @@ mean and variance. """ function batchnorm(x::AbstractArray{T, N}, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, - rσ²::Optional{<:AbstractVector}, training::Union{Val, StaticBool}, - act::F=identity, momentum::Real=0.1f0, - epsilon::Real=default_epsilon(x)) where {F, T, N} + rσ²::Optional{<:AbstractVector}, training::TrainingType, act::F=identity, + momentum::Real=0.1f0, epsilon::Real=default_epsilon(x)) where {F, T, N} σ = select_fastest_activation(act, x, γ, β, rμ, rσ²) y, rμ, rσ² = batchnorm_impl( - x, γ, β, rμ, rσ², static(training), σ, momentum, epsilon) + x, γ, β, rμ, rσ², static_training_mode(training, x, γ, β, rμ, rσ²), + σ, momentum, epsilon) return y, (; running_mean=remove_tracking(rμ), running_var=remove_tracking(rσ²)) end diff --git a/src/api/dropout.jl b/src/api/dropout.jl index b8e0d6ff..3d4e4c6d 100644 --- a/src/api/dropout.jl +++ b/src/api/dropout.jl @@ -1,7 +1,7 @@ """ - dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}, invp, dims) - dropout(rng::AbstractRNG, x, mask, p, training::Union{Val, StaticBool}, - update_mask::Union{Val, StaticBool}, invp, dims) + dropout(rng::AbstractRNG, x, p, training, invp, dims) + dropout(rng::AbstractRNG, x, mask, p, training, update_mask::Union{Val, StaticBool}, + invp, dims) Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1]. @@ -11,10 +11,11 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see - `x`: Input Array - `mask`: Dropout Mask. If not used then it is constructed automatically - `p`: Probability of an element to be dropped out - - `Val(training)`: If `true` then dropout is applied on `x` with probability `p` along - `dims`. Else, `x` is returned - - `Val(update_mask)`: If `true` then the mask is generated and used. Else, the `mask` - provided is directly used + - `training`: Set to `Val(true)` or `True()` if running in training mode. Can be set to + `nothing` to automatically determine if the function is being called within an autodiff + context + - `update_mask`: If `Val(true)` or `True()` then the mask is generated and used. Else, the + `mask` provided is directly used - `invp`: Inverse multiplied to the mask. Calculated as `invp = 1 / (1 - p)`. ## Returns @@ -28,20 +29,20 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. 
""" -function dropout(rng::AbstractRNG, x::AbstractArray, p::T, - training::Union{Val, StaticBool}, invp::T, dims) where {T} - return dropout_impl(rng, x, p, static(training), invp, dims) +function dropout(rng::AbstractRNG, x::AbstractArray, p::T, training::TrainingType, invp::T, + dims) where {T} + return dropout_impl(rng, x, p, static_training_mode(training, x), invp, dims) end function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, - p::T, training::Union{Val, StaticBool}, - update_mask::Union{Val, StaticBool}, invp::T, dims) where {T} - return dropout_impl(rng, x, mask, p, static(training), static(update_mask), invp, dims) + p::T, training::TrainingType, update_mask::TrainingType, invp::T, dims) where {T} + return dropout_impl(rng, x, mask, p, static_training_mode(training, x), + static(update_mask), invp, dims) end """ - alpha_dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}) - alpha_dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}, α, A, B) + alpha_dropout(rng::AbstractRNG, x, p, training) + alpha_dropout(rng::AbstractRNG, x, p, training, α, A, B) Alpha Dropout: Dropout ensuring that the mean and variance of the output remains same as the input. For details see [1]. Use the second call signature to avoid recomputing the constants @@ -52,8 +53,9 @@ for a fixed dropout probability. - `rng`: Random number generator - `x`: Input Array - `p`: Probability of an element to be dropped out - - `Val(training)`: If `true` then dropout is applied on `x` with probability `p`. Else, - `x` is returned + - `training`: Set to `Val(true)` or `True()` if running in training mode. Can be set to + `nothing` to automatically determine if the function is being called within an autodiff + context` - `α`: `-1.7580993408473766`. Computed at limit x tends to infinity, `selu(x) = -λβ = α` - `A`: Scaling factor for the mean - `B`: Scaling factor for the variance @@ -68,12 +70,11 @@ for a fixed dropout probability. [1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural information processing systems 30 (2017). """ -function alpha_dropout( - rng::AbstractRNG, x::AbstractArray, p, training::Union{Val, StaticBool}) - return alpha_dropout_impl(rng, x, p, static(training)) +function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, training::TrainingType) + return alpha_dropout_impl(rng, x, p, static_training_mode(training, x)) end function alpha_dropout( - rng::AbstractRNG, x::AbstractArray, p, training::Union{Val, StaticBool}, α, A, B) - return alpha_dropout_impl(rng, x, p, static(training), α, A, B) + rng::AbstractRNG, x::AbstractArray, p, training::TrainingType, α, A, B) + return alpha_dropout_impl(rng, x, p, static_training_mode(training, x), α, A, B) end diff --git a/src/api/instancenorm.jl b/src/api/instancenorm.jl index e06d7bc8..1ee4e7a2 100644 --- a/src/api/instancenorm.jl +++ b/src/api/instancenorm.jl @@ -1,5 +1,5 @@ @doc doc""" - instancenorm(x, scale, bias, training::Union{Val, StaticBool}, σ = identity, + instancenorm(x, scale, bias, training, σ = identity, epsilon = eps(eltype(x)) ^ (5 // 7)) Instance Normalization. For details see [1]. @@ -16,7 +16,9 @@ accordingly. - `σ`: Activation function (default: `identity`) - `epsilon`: Value added to the denominator for numerical stability (default: `eps(eltype(x)) ^ (5 / 7)`) - - `training`: Set to `Val(true)` if running in training mode + - `training`: Set to `Val(true)` or `True()` if running in training mode. 
Can be set to + `nothing` to automatically determine if the function is being called within an autodiff + context ## Returns @@ -29,13 +31,13 @@ mean and variance. missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). """ function instancenorm(x::AbstractArray, scale::Optional{<:AbstractVector}, - bias::Optional{<:AbstractVector}, training::Union{Val, StaticBool}=Val(false), + bias::Optional{<:AbstractVector}, training::TrainingType, σ::F=identity, epsilon::Real=default_epsilon(x)) where {F} assert_valid_instancenorm_arguments(x) - σ′ = select_fastest_activation(σ, x, scale, bias) - y, xμ, xσ² = instancenorm_impl( - x, nothing, nothing, scale, bias, static(training), nothing, epsilon, σ′) + y, xμ, xσ² = instancenorm_impl(x, nothing, nothing, scale, bias, + static_training_mode(training, x, scale, bias), nothing, epsilon, + select_fastest_activation(σ, x, scale, bias)) return y, (; running_mean=xμ, running_var=xσ²) end diff --git a/src/utils.jl b/src/utils.jl index 90e9e563..c5d18bca 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -8,7 +8,7 @@ using KernelAbstractions: KernelAbstractions using LinearAlgebra: LinearAlgebra, BLAS using MLDataDevices: get_device_type, CPUDevice using NNlib: NNlib -using Static: Static, False, True +using Static: Static, StaticBool, False, True, static using StaticArraysCore: SVector, SMatrix using ..LuxLib: Optional, ∂∅ @@ -231,4 +231,54 @@ end return end +within_gradient_vararg(args...) = unrolled_any(within_gradient, args) + +within_gradient(_) = False() +within_gradient(::ForwardDiff.Dual) = True() +within_gradient(::AbstractArray{<:ForwardDiff.Dual}) = True() + +CRC.rrule(::typeof(within_gradient), x) = True(), _ -> (∂∅, ∂∅) + +static_training_mode(::Nothing, args...) = within_gradient_vararg(args...) + +function static_training_mode( + training::Union{Bool, Val{true}, Val{false}, StaticBool}, args...) + return static_training_mode_check( + training, static(training), within_gradient_vararg(args...)) +end + +function CRC.rrule(::typeof(static_training_mode), ::Nothing, args...) + return True(), _ -> ntuple(Returns(∂∅), length(args) + 2) +end + +function CRC.rrule(::typeof(static_training_mode), + training::Union{Bool, Val{true}, Val{false}, StaticBool}, args...) + res = static_training_mode_check(training, static(training), True()) + return res, _ -> ntuple(Returns(∂∅), length(args) + 2) +end + +static_training_mode_check(_, ::True, ::True) = True() +static_training_mode_check(_, ::False, ::False) = False() + +function static_training_mode_check(training, ::True, ::False) + @warn "`training` is set to `$(training)` but is not being used within an autodiff \ + call (gradient, jacobian, etc...). This will be slow. If you are using a \ + `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. \ + Reliance on this behavior is discouraged, and is not guaranteed by Semantic \ + Versioning, and might be removed without a deprecation cycle. It is recommended \ + to fix this issue in your code. \n\n\ + If you are using Enzyme.jl, then you can ignore this warning." maxlog=1 + return True() +end + +function static_training_mode_check(training, ::False, ::True) + @warn "`training` is set to `$(training)` but is being used within an autodiff call \ + (gradient, jacobian, etc...). This might lead to incorrect results. If you are \ + using a `Lux.jl` model, set it to training mode using \ + `LuxCore.trainmode`." maxlog=1 + return False() +end + +CRC.@non_differentiable static_training_mode_check(::Any...) + end
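
A quick usage sketch of the new `training = nothing` mode (not part of the patch; Zygote and the variable names below are only illustrative assumptions). With `nothing`, `static_training_mode` infers the mode from the calling context: via `within_gradient` for ForwardDiff/Tracker/ReverseDiff inputs, and via the added `rrule` for ChainRules-based reverse-mode ADs.

```julia
using LuxLib, Random, Zygote

rng = Random.default_rng()
x   = randn(rng, Float32, 4, 4, 3, 2)      # W × H × C × N
γ, β    = ones(Float32, 3), zeros(Float32, 3)
rμ, rσ² = zeros(Float32, 3), ones(Float32, 3)

# Outside an AD call, `nothing` resolves to inference mode (`False()`):
y, stats = batchnorm(x, γ, β, rμ, rσ², nothing)

# Inside an AD call, the `static_training_mode` rrule resolves the same
# `nothing` to training mode (`True()`), so batch statistics are used:
∇γ = only(Zygote.gradient(γ) do γ
    sum(first(batchnorm(x, γ, β, rμ, rσ², nothing)))
end)

# Explicit modes still work; a mismatch with the detected context now
# triggers the one-time `@warn` from `static_training_mode_check`:
_ = dropout(rng, x, 0.5f0, Val(true), 2.0f0, :)
```

Note that Enzyme.jl is not detected by either mechanism, which is why the new warning text explicitly tells Enzyme users to ignore it; the explicit `Val(true)` / `Val(false)` (and `True()` / `False()`) forms continue to be accepted unchanged.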