Skip to content

Commit

Permalink
Merge #515
Browse files Browse the repository at this point in the history
515: Fix Flux.flip by providing an adjoint for Base.reverse r=dhairyagandhi96 a=tanhevg

The main motivation behind this PR is to address various issues concerning `Flux.flip()` (used mainly for bRNNs), e.g.  FluxML/Flux.jl#962, FluxML/Flux.jl#990 and FluxML/model-zoo#179

Co-authored-by: Evgeny Tankhilevich <[email protected]>
  • Loading branch information
bors[bot] and Evgeny Tankhilevich authored Feb 25, 2020
2 parents 94441dd + 8f7da4d commit 701135b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ end
circshift(A, shifts), Δ -> (circshift(Δ, map(-, shifts)), nothing)
end

@adjoint function reverse(x::AbstractArray, args...; kwargs...)
_reverse(t) = reverse(t, args...; kwargs...)
_reverse(x), Δ->(_reverse(Δ), map(_->nothing, args)...)
end

@adjoint permutedims(xs) = permutedims(xs), Δ -> (permutedims(Δ),)

@adjoint permutedims(xs::AbstractVector) = permutedims(xs), Δ -> (vec(permutedims(Δ)),)
Expand Down
5 changes: 5 additions & 0 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ end
@test gradtest(x -> meanpool(x, pdims), x)
end

@test gradtest(x -> reverse(x), rand(17))
@test gradtest(x -> reverse(x, 8), rand(17))
@test gradtest(x -> reverse(x, 8, 13), rand(17))
@test gradtest(x -> reverse(x, dims=2), rand(17, 42))

@test gradtest(x -> permutedims(x), rand(2))
@test gradtest(x -> permutedims(x), rand(2,3))
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
Expand Down

0 comments on commit 701135b

Please sign in to comment.