
fix: modify the dropout testing
avik-pal committed Sep 5, 2024
1 parent 1afc1c7 commit 8b54f89
Showing 4 changed files with 25 additions and 27 deletions.
src/impl/dense.jl (1 change: 0 additions & 1 deletion)
@@ -151,7 +151,6 @@ function EnzymeRules.reverse(cfg, ::EnzymeCore.Const{typeof(fused_dense!)},
weight::EnzymeCore.Annotation{<:AbstractMatrix},
x::EnzymeCore.Annotation{<:AbstractMatrix},
b::EnzymeCore.Annotation{<:Optional{<:AbstractVector}})
-# TODO: For the other cases
case_specific_cache, weight_cache, x_cache = cache

(case, tmp) = case_specific_cache
src/impl/dropout.jl (9 changes: 3 additions & 6 deletions)
@@ -20,7 +20,7 @@ end

function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray,
::T, ::False, ::False, invp::T, dims) where {T}
-return (x, mask, rng)
+return x, mask, rng
end

function check_dropout_mask_shape_mismatch(x::AbstractArray, mask::AbstractArray, dims)
@@ -205,11 +205,8 @@ end
dropout_dot_mul(x::AbstractArray, mask::AbstractArray) = x .* mask

function CRC.rrule(::typeof(dropout_dot_mul), x::AbstractArray, mask::AbstractArray)
-res = dropout_dot_mul(x, mask) # size(res) == size(x)
-𝒫x = CRC.ProjectTo(x)
∇dropout_dot_mul = @closure Δ -> begin
-∂x = 𝒫x(dropout_dot_mul(Δ, mask))
-return ∂∅, ∂x, ∂∅
+return ∂∅, (CRC.ProjectTo(x))(dropout_dot_mul(Δ, mask)), ∂∅
end
-return res, ∇dropout_dot_mul
+return dropout_dot_mul(x, mask), ∇dropout_dot_mul
end
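
The `rrule` rewrite above is behavior-preserving: it drops the intermediate bindings (`res`, `𝒫x`, `∂x`) but still returns the same primal and pullback. Since `dropout_dot_mul(x, mask) = x .* mask`, the cotangent for `x` is `Δ .* mask` projected back onto `x`, while the function and the mask receive no tangent. A minimal self-contained sketch of the same pattern, assuming `CRC` aliases ChainRulesCore, `∂∅` is `CRC.NoTangent()`, and `@closure` comes from FastClosures.jl (how such symbols are conventionally defined; not shown on this page):

using ChainRulesCore, FastClosures
const CRC = ChainRulesCore
const ∂∅ = CRC.NoTangent()

# Elementwise application of a dropout mask, as in the diff.
dropout_dot_mul(x::AbstractArray, mask::AbstractArray) = x .* mask

function CRC.rrule(::typeof(dropout_dot_mul), x::AbstractArray, mask::AbstractArray)
    ∇dropout_dot_mul = @closure Δ -> begin
        # d(x .* mask)/dx = mask, so the cotangent for x is Δ .* mask;
        # the mask itself is treated as non-differentiable.
        return ∂∅, CRC.ProjectTo(x)(dropout_dot_mul(Δ, mask)), ∂∅
    end
    return dropout_dot_mul(x, mask), ∇dropout_dot_mul
end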
test/common_ops/dropout_tests.jl (39 changes: 20 additions & 19 deletions)
@@ -4,7 +4,7 @@
@testset "$mode" for (mode, aType, ongpu) in MODES
@testset "$T, $x_shape, $dims" for T in (Float16, Float32, Float64),
x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)),
-dims in (Colon(), 1, (1, 2))
+dims in (:, 1, (1, 2))

x = randn(rng, T, x_shape) |> aType

@@ -55,10 +55,10 @@ end

# Update mask
@test @inferred(dropout(
-rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())) isa Any
+rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)) isa Any

y, mask_, rng_ = dropout(
-rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())
+rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)

@test y isa aType{T, length(x_shape)}
@test size(y) == x_shape
@@ -68,26 +68,25 @@ end
@test mask != mask_

__f = (x, mask) -> sum(first(dropout(
-StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, Colon())))
+StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, :)))
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any

-__f = let rng = rng, mask = mask
-x -> sum(first(dropout(
-rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())))
+__f = let rng = rng, mask = mask, p = T(0.5), invp = T(2)
+x -> sum(first(dropout(rng, x, mask, p, Val(true), Val(true), invp, :)))
end
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))

@jet sum(first(dropout(
-rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())))
+rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)))

# Try using mask if possible (possible!!)
@test @inferred(dropout(
-rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any
+rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)) isa Any

y, mask_, rng_ = dropout(
-rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())
+rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)

@test y isa aType{T, length(x_shape)}
@test size(y) == x_shape
@@ -97,27 +96,29 @@ end
@test mask == mask_

__f = (x, mask) -> sum(first(dropout(
-StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon())))
+StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, :)))
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any

-__f = let rng = rng, mask = mask
-x -> sum(first(dropout(
-rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
+__f = let rng = rng, mask = mask, p = T(0.5), invp = T(2)
+x -> sum(first(dropout(rng, x, mask, p, Val(true), Val(false), invp, :)))
end
-test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
+
+soft_fail = T == Float16 ? Any[AutoFiniteDiff()] : []
+skip_backends = length(x_shape) == 5 ? [AutoEnzyme()] : []
+
+test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail, skip_backends,
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))

@jet sum(first(dropout(
-rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
+rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)))
mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType

# Testing Mode
@test @inferred(dropout(
-rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any
+rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)) isa Any

y, mask_, rng_ = dropout(
-rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())
+rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)

@test y isa aType{T, length(x_shape)}
@test size(y) == x_shape
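
Throughout this test file the calls swap the spelled-out `Colon()` for the literal `:` (the two are the same object: `(:) === Colon()`), and hoist `p` and `invp` into the `let` captures. As a usage sketch of the call pattern being exercised — assuming the package is LuxLib.jl (inferred from the file paths; it is not named on this page) and taking the argument meanings from the tests rather than from documented API:

using LuxLib, StableRNGs

rng = StableRNG(0)
x = randn(rng, Float32, 2, 3)
mask = rand(rng, Float32, 2, 3)   # pre-existing mask that may be reused

# Training pass that regenerates the mask (second Val is the "update mask" flag),
# with p = 0.5, invp = 1/p = 2, and `:` selecting all dims for the mask.
y, new_mask, new_rng = dropout(rng, x, mask, 0.5f0, Val(true), Val(true), 2.0f0, :)

# Training pass that reuses the supplied mask instead of drawing a new one.
y2, reused_mask, _ = dropout(new_rng, x, new_mask, 0.5f0, Val(true), Val(false), 2.0f0, :)
reused_mask == new_mask   # true, mirroring `@test mask == mask_` above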
test/normalization/instancenorm_tests.jl (3 changes: 2 additions & 1 deletion)
@@ -66,7 +66,8 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp
__f = (args...) -> sum(first(instancenorm(
args..., rm, rv, training, act, T(0.1), epsilon)))
soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
-test_gradients(__f, x, scale, bias; atol, rtol, soft_fail)
+skip_backends = (Sys.iswindows() && fp16) ? [AutoEnzyme()] : []
+test_gradients(__f, x, scale, bias; atol, rtol, soft_fail, skip_backends)
end
end

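
Both test files now precompute `soft_fail` and `skip_backends` and pass them with Julia's keyword shorthand, where `f(; atol, rtol, soft_fail)` means `f(; atol=atol, rtol=rtol, soft_fail=soft_fail)`. A small sketch of that pattern with a stand-in function — `fake_test_gradients` below is not the real `test_gradients` helper (which presumably comes from LuxTestUtils.jl); it only echoes the keywords it receives, and the backend types come from ADTypes.jl:

using ADTypes: AutoEnzyme, AutoFiniteDiff

# Stand-in that just reports which keyword arguments it was given.
fake_test_gradients(f, xs...; atol, rtol, soft_fail=[], skip_backends=[]) =
    (; atol, rtol, soft_fail, skip_backends)

T = Float16
fp16 = T == Float16
atol = rtol = 1.0f-3

soft_fail = fp16 ? [AutoFiniteDiff()] : []                        # tolerate finite-diff noise at Float16
skip_backends = (Sys.iswindows() && fp16) ? [AutoEnzyme()] : []   # mirror the Windows + Float16 skip above

# Shorthand keywords: identical to atol=atol, rtol=rtol, soft_fail=soft_fail, skip_backends=skip_backends.
fake_test_gradients(sum, rand(Float16, 3); atol, rtol, soft_fail, skip_backends)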
