diff --git a/src/impl/dropout.jl b/src/impl/dropout.jl index 473b6a35..320eafbc 100644 --- a/src/impl/dropout.jl +++ b/src/impl/dropout.jl @@ -13,16 +13,8 @@ function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, p::T, end function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, - p::T, ::True, ::False, invp::T, dims) where {T} - if dropout_shape(x, dims) != size(mask) - depwarn( - "`update_mask` is `Val(false)` but `mask` is not of the same size \ - as `LuxLib.dropout_shape(x, dims)`. This has been deprecated and \ - will be removed in the next release. Set `update_mask` to \ - `Val(true)` to avoid this.", :dropout) - mask, rngₙ = generate_dropout_mask(rng, x, p, invp, dims) - return dropout_dot_mul(x, mask), mask, rngₙ - end + ::T, ::True, ::False, invp::T, dims) where {T} + check_dropout_mask_shape_mismatch(x, mask, dims) return dropout_dot_mul(x, mask), mask, rng end @@ -31,6 +23,13 @@ function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, return (x, mask, rng) end +function check_dropout_mask_shape_mismatch(x::AbstractArray, mask::AbstractArray, dims) + @assert dropout_shape(x, dims)==size(mask) "`mask` is not of the same size as `LuxLib.dropout_shape(x, dims)`." + return nothing +end + +CRC.@non_differentiable check_dropout_mask_shape_mismatch(::Any...) + ## alpha_dropout function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::True) where {T} α = T(-1.7580993408473766) diff --git a/test/common_ops/dropout_tests.jl b/test/common_ops/dropout_tests.jl index e8b637df..f7f2368b 100644 --- a/test/common_ops/dropout_tests.jl +++ b/test/common_ops/dropout_tests.jl @@ -42,8 +42,6 @@ end @testitem "Dropout with Preset Mask" tags=[:other_ops] setup=[SharedTestSetup] begin - Enzyme.API.runtimeActivity!(true) # TODO: remove in 1.0 after deprecation - using Statistics rng = StableRNG(12345) @@ -100,8 +98,7 @@ end __f = (x, mask) -> sum(first(dropout( StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) - # Branching based on runtime values - @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true + @test @inferred(Zygote.gradient(__f, x, mask)) isa Any __f = let rng = rng, mask = mask x -> sum(first(dropout( @@ -115,35 +112,6 @@ end rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType - # Try using mask if possible (not possible!!) - @test @inferred(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any - - y, mask_, rng_ = dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) - - @test y isa aType{T, length(x_shape)} - @test size(y) == x_shape - @test mask_ isa aType{T, length(x_shape)} - @test size(mask_) == x_shape - @test rng != rng_ - @test mask != mask_ - - __f = (x, mask) -> sum(first(dropout( - StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) - # Branching based on runtime activity - @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true - - __f = let rng = rng, mask = mask - x -> sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) - end - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, - soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) - - @jet sum(first(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) # Testing Mode @test @inferred(dropout( rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any