From 790b51308175268a3e5237e36aac40d509b3d344 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 29 Aug 2024 09:19:05 -0400
Subject: [PATCH] fix!: remove dropout branching based on size

---
 src/impl/dropout.jl              | 12 ++---------
 test/common_ops/dropout_tests.jl | 34 +-------------------------------
 2 files changed, 3 insertions(+), 43 deletions(-)

diff --git a/src/impl/dropout.jl b/src/impl/dropout.jl
index 05276f86..c477087c 100644
--- a/src/impl/dropout.jl
+++ b/src/impl/dropout.jl
@@ -13,16 +13,8 @@ function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, p::T,
 end
 
 function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray,
-        p::T, ::True, ::False, invp::T, dims) where {T}
-    if dropout_shape(x, dims) != size(mask)
-        Utils.depwarn(
-            "`update_mask` is `Val(false)` but `mask` is not of the same size \
-             as `LuxLib.dropout_shape(x, dims)`. This has been deprecated and \
-             will be removed in the next release. Set `update_mask` to \
-             `Val(true)` to avoid this.", :dropout)
-        mask, rngₙ = generate_dropout_mask(rng, x, p, invp, dims)
-        return dropout_dot_mul(x, mask), mask, rngₙ
-    end
+        ::T, ::True, ::False, invp::T, dims) where {T}
+    @assert dropout_shape(x, dims)==size(mask) "`mask` is not of the same size as `LuxLib.dropout_shape(x, dims)`."
     return dropout_dot_mul(x, mask), mask, rng
 end
 
diff --git a/test/common_ops/dropout_tests.jl b/test/common_ops/dropout_tests.jl
index e8b637df..f7f2368b 100644
--- a/test/common_ops/dropout_tests.jl
+++ b/test/common_ops/dropout_tests.jl
@@ -42,8 +42,6 @@ end
 
 @testitem "Dropout with Preset Mask" tags=[:other_ops] setup=[SharedTestSetup] begin
-    Enzyme.API.runtimeActivity!(true) # TODO: remove in 1.0 after deprecation
-
     using Statistics
 
     rng = StableRNG(12345)
 
@@ -100,8 +98,7 @@ end
             __f = (x, mask) -> sum(first(dropout(
                 StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon())))
-            # Branching based on runtime values
-            @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true
+            @test @inferred(Zygote.gradient(__f, x, mask)) isa Any
 
             __f = let rng = rng, mask = mask
                 x -> sum(first(dropout(
                     rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
@@ -115,35 +112,6 @@ end
                 rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
 
             mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType
-
-            # Try using mask if possible (not possible!!)
-            @test @inferred(dropout(
-                rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any
-
-            y, mask_, rng_ = dropout(
-                rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())
-
-            @test y isa aType{T, length(x_shape)}
-            @test size(y) == x_shape
-            @test mask_ isa aType{T, length(x_shape)}
-            @test size(mask_) == x_shape
-            @test rng != rng_
-            @test mask != mask_
-
-            __f = (x, mask) -> sum(first(dropout(
-                StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon())))
-            # Branching based on runtime activity
-            @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true
-
-            __f = let rng = rng, mask = mask
-                x -> sum(first(dropout(
-                    rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
-            end
-            test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-                soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
-                broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
-            @jet sum(first(dropout(
-                rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
             # Testing Mode
             @test @inferred(dropout(
                 rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any
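
Below is a minimal usage sketch of the behavior this patch implies; it is not part of the patch. It assumes the public `dropout` entry point accepts the same eight-argument call pattern used in the tests above (rng, x, mask, p, training, update_mask, invp, dims); the `StableRNGs` import and the concrete array sizes are illustrative.

    using LuxLib, StableRNGs

    rng = StableRNG(0)
    x = rand(rng, Float32, 2, 3)

    # Mask with the same size as `x` over all dims (dims = Colon()): used as-is,
    # mirroring the `Val(true), Val(false)` (training, no mask update) test calls.
    mask = rand(Float32, 2, 3)
    y, mask_out, rng_out = dropout(
        rng, x, mask, 0.5f0, Val(true), Val(false), 2.0f0, Colon())

    # Mismatched mask: previously regenerated with a depwarn; after this patch it is
    # expected to trip the `@assert` on `dropout_shape(x, dims) == size(mask)`.
    bad_mask = rand(Float32, 2, 13)
    # dropout(rng, x, bad_mask, 0.5f0, Val(true), Val(false), 2.0f0, Colon())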