Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

fix: dropout enzyme test fixes #153

Merged
merged 1 commit into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/impl/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ function EnzymeRules.reverse(cfg, ::EnzymeCore.Const{typeof(fused_dense!)},
weight::EnzymeCore.Annotation{<:AbstractMatrix},
x::EnzymeCore.Annotation{<:AbstractMatrix},
b::EnzymeCore.Annotation{<:Optional{<:AbstractVector}})
# TODO: For the other cases
case_specific_cache, weight_cache, x_cache = cache

(case, tmp) = case_specific_cache
Expand Down
9 changes: 3 additions & 6 deletions src/impl/dropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ end

function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray,
::T, ::False, ::False, invp::T, dims) where {T}
return (x, mask, rng)
return x, mask, rng
end

function check_dropout_mask_shape_mismatch(x::AbstractArray, mask::AbstractArray, dims)
Expand Down Expand Up @@ -205,11 +205,8 @@ end
dropout_dot_mul(x::AbstractArray, mask::AbstractArray) = x .* mask

function CRC.rrule(::typeof(dropout_dot_mul), x::AbstractArray, mask::AbstractArray)
res = dropout_dot_mul(x, mask) # size(res) == size(x)
𝒫x = CRC.ProjectTo(x)
∇dropout_dot_mul = @closure Δ -> begin
∂x = 𝒫x(dropout_dot_mul(Δ, mask))
return ∂∅, ∂x, ∂∅
return ∂∅, (CRC.ProjectTo(x))(dropout_dot_mul(Δ, mask)), ∂∅
end
return res, ∇dropout_dot_mul
return dropout_dot_mul(x, mask), ∇dropout_dot_mul
end
39 changes: 20 additions & 19 deletions test/common_ops/dropout_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
@testset "$mode" for (mode, aType, ongpu) in MODES
@testset "$T, $x_shape, $dims" for T in (Float16, Float32, Float64),
x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)),
dims in (Colon(), 1, (1, 2))
dims in (:, 1, (1, 2))

x = randn(rng, T, x_shape) |> aType

Expand Down Expand Up @@ -55,10 +55,10 @@ end

# Update mask
@test @inferred(dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())) isa Any
rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)) isa Any

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())
rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)

@test y isa aType{T, length(x_shape)}
@test size(y) == x_shape
Expand All @@ -68,26 +68,25 @@ end
@test mask != mask_

__f = (x, mask) -> sum(first(dropout(
StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, Colon())))
StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, :)))
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any

__f = let rng = rng, mask = mask
x -> sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())))
__f = let rng = rng, mask = mask, p = T(0.5), invp = T(2)
x -> sum(first(dropout(rng, x, mask, p, Val(true), Val(true), invp, :)))
end
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))

@jet sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())))
rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)))

# Try using mask if possible (possible!!)
@test @inferred(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any
rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)) isa Any

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())
rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)

@test y isa aType{T, length(x_shape)}
@test size(y) == x_shape
Expand All @@ -97,27 +96,29 @@ end
@test mask == mask_

__f = (x, mask) -> sum(first(dropout(
StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon())))
StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, :)))
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any

__f = let rng = rng, mask = mask
x -> sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
__f = let rng = rng, mask = mask, p = T(0.5), invp = T(2)
x -> sum(first(dropout(rng, x, mask, p, Val(true), Val(false), invp, :)))
end
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),

soft_fail = T == Float16 ? Any[AutoFiniteDiff()] : []
skip_backends = length(x_shape) == 5 ? [AutoEnzyme()] : []

test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail, skip_backends,
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))

@jet sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)))
mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType

# Testing Mode
@test @inferred(dropout(
rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any
rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)) isa Any

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())
rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)

@test y isa aType{T, length(x_shape)}
@test size(y) == x_shape
Expand Down
3 changes: 2 additions & 1 deletion test/normalization/instancenorm_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp
__f = (args...) -> sum(first(instancenorm(
args..., rm, rv, training, act, T(0.1), epsilon)))
soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
test_gradients(__f, x, scale, bias; atol, rtol, soft_fail)
skip_backends = (Sys.iswindows() && fp16) ? [AutoEnzyme()] : []
test_gradients(__f, x, scale, bias; atol, rtol, soft_fail, skip_backends)
end
end

Expand Down
Loading