This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit 26a745f
fix: looped dropout implementation on CPU
avik-pal committed Sep 5, 2024
1 parent 1afc1c7 commit 26a745f
Showing 2 changed files with 17 additions and 22 deletions.
7 changes: 2 additions & 5 deletions src/impl/dropout.jl
@@ -205,11 +205,8 @@ end
 dropout_dot_mul(x::AbstractArray, mask::AbstractArray) = x .* mask
 
 function CRC.rrule(::typeof(dropout_dot_mul), x::AbstractArray, mask::AbstractArray)
-    res = dropout_dot_mul(x, mask) # size(res) == size(x)
-    𝒫x = CRC.ProjectTo(x)
     ∇dropout_dot_mul = @closure Δ -> begin
-        ∂x = 𝒫x(dropout_dot_mul(Δ, mask))
-        return ∂∅, ∂x, ∂∅
+        return ∂∅, (CRC.ProjectTo(x))(dropout_dot_mul(Δ, mask)), ∂∅
     end
-    return res, ∇dropout_dot_mul
+    return dropout_dot_mul(x, mask), ∇dropout_dot_mul
 end
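For readers skimming the diff: the rewritten rrule keeps the same contract (the forward pass returns x .* mask, and the pullback sends Δ to Δ .* mask projected onto x, with no tangents for the function or the mask); it simply inlines the temporaries. Below is a self-contained sketch of the same pattern that can be checked outside the package with only ChainRulesCore and Zygote. `CRC`, `∂∅`, and `@closure` in the diff are presumably package-internal aliases for ChainRulesCore, `NoTangent()`, and FastClosures.@closure, and `mul_mask` here is a stand-in name, not the library function.

using ChainRulesCore, Zygote

# Stand-in for dropout_dot_mul: elementwise multiplication by a precomputed mask.
mul_mask(x::AbstractArray, mask::AbstractArray) = x .* mask

function ChainRulesCore.rrule(::typeof(mul_mask), x::AbstractArray, mask::AbstractArray)
    # d(x .* mask)/dx = mask, so the pullback returns Δ .* mask, projected back onto x.
    # No tangent is produced for the function itself or for the mask.
    ∇mul_mask = Δ -> (NoTangent(), ProjectTo(x)(Δ .* mask), NoTangent())
    return mul_mask(x, mask), ∇mul_mask
end

x = randn(Float32, 4)
mask = Float32.(rand(4) .> 0.5)
Zygote.gradient(x -> sum(mul_mask(x, mask)), x)   # ≈ (mask,)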
32 changes: 15 additions & 17 deletions test/common_ops/dropout_tests.jl
@@ -4,7 +4,7 @@
 @testset "$mode" for (mode, aType, ongpu) in MODES
     @testset "$T, $x_shape, $dims" for T in (Float16, Float32, Float64),
         x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)),
-        dims in (Colon(), 1, (1, 2))
+        dims in (:, 1, (1, 2))
 
         x = randn(rng, T, x_shape) |> aType
 
@@ -55,10 +55,10 @@ end
 
         # Update mask
         @test @inferred(dropout(
-            rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())) isa Any
+            rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)) isa Any
 
         y, mask_, rng_ = dropout(
-            rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())
+            rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)
 
         @test y isa aType{T, length(x_shape)}
         @test size(y) == x_shape
@@ -68,26 +68,25 @@ end
         @test mask != mask_
 
         __f = (x, mask) -> sum(first(dropout(
-            StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, Colon())))
+            StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, :)))
         @test @inferred(Zygote.gradient(__f, x, mask)) isa Any
 
-        __f = let rng = rng, mask = mask
-            x -> sum(first(dropout(
-                rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())))
+        __f = let rng = rng, mask = mask, p = T(0.5), invp = T(2)
+            x -> sum(first(dropout(rng, x, mask, p, Val(true), Val(true), invp, :)))
         end
         test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
             soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
             broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
 
         @jet sum(first(dropout(
-            rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())))
+            rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)))
 
         # Try using mask if possible (possible!!)
         @test @inferred(dropout(
-            rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any
+            rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)) isa Any
 
         y, mask_, rng_ = dropout(
-            rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())
+            rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)
 
         @test y isa aType{T, length(x_shape)}
         @test size(y) == x_shape
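A note on the `let` change that appears in this hunk and the next: `mask` is reassigned later in the same test block (`mask = rand(T, ...)`), so capturing it directly in `__f` would force Julia to box the capture; the existing `let rng = rng, mask = mask` rebinding avoids that, and the commit now also binds `p = T(0.5)` and `invp = T(2)` there while collapsing the call onto one line. The boxing issue is the standard "performance of captured variables" caveat from the Julia manual; a minimal illustration, independent of LuxLib:

# A variable that is reassigned after a closure captures it is lowered to a Core.Box,
# which makes the closure's field abstractly typed.
function boxed_capture()
    s = 1.0
    f = () -> s + 1   # captures `s`, which is reassigned below, so it gets boxed
    s = 2.0
    return f
end

# Rebinding with `let` captures a fresh local that is assigned exactly once,
# so the closure stores a concrete Float64 instead.
function unboxed_capture()
    s = 1.0
    f = let s = s
        () -> s + 1
    end
    s = 2.0
    return f
end

fieldtypes(typeof(boxed_capture()))    # (Core.Box,)
fieldtypes(typeof(unboxed_capture()))  # (Float64,)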
@@ -97,27 +96,26 @@ end
         @test mask == mask_
 
         __f = (x, mask) -> sum(first(dropout(
-            StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon())))
+            StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, :)))
         @test @inferred(Zygote.gradient(__f, x, mask)) isa Any
 
-        __f = let rng = rng, mask = mask
-            x -> sum(first(dropout(
-                rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
+        __f = let rng = rng, mask = mask, p = T(0.5), invp = T(2)
+            x -> sum(first(dropout(rng, x, mask, p, Val(true), Val(false), invp, :)))
         end
         test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
             soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
             broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
 
         @jet sum(first(dropout(
-            rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
+            rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)))
         mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType
 
         # Testing Mode
         @test @inferred(dropout(
-            rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any
+            rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)) isa Any
 
         y, mask_, rng_ = dropout(
-            rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())
+            rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)
 
         @test y isa aType{T, length(x_shape)}
         @test size(y) == x_shape
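The test-side edits are behaviorally no-ops: `:` is simply the `Colon()` singleton, so every rewritten call dispatches to exactly the same method. For reference, the positional pattern exercised throughout is dropout(rng, x, mask, p, training, update_mask, invp, dims), where the names `training` and `update_mask` are my reading of the surrounding test comments rather than documented argument names. A quick standalone check, assuming the tests' usual `using LuxLib, StableRNGs` setup:

using LuxLib, StableRNGs, Test

@test Colon() === (:)   # `:` is the Colon() singleton, so swapping the spelling changes nothing

rng = StableRNG(0)
x = randn(rng, Float32, 2, 3)
mask = rand(rng, Float32, 2, 3)

# Same call as the "Update mask" branch above, written with the shorter `:` spelling.
y, mask_, rng_ = dropout(rng, x, mask, 0.5f0, Val(true), Val(true), 2.0f0, :)
@test size(y) == size(x)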
