diff --git a/src/lib/array.jl b/src/lib/array.jl index e7fcf8a5e..08bee6bf2 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -63,6 +63,11 @@ end circshift(A, shifts), Δ -> (circshift(Δ, map(-, shifts)), nothing) end +@adjoint function reverse(x::AbstractArray, args...; kwargs...) + _reverse(t) = reverse(t, args...; kwargs...) + _reverse(x), Δ->(_reverse(Δ), map(_->nothing, args)...) +end + @adjoint permutedims(xs) = permutedims(xs), Δ -> (permutedims(Δ),) @adjoint permutedims(xs::AbstractVector) = permutedims(xs), Δ -> (vec(permutedims(Δ)),) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 2bf4e5610..e4ad93301 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -129,6 +129,11 @@ end @test gradtest(x -> meanpool(x, pdims), x) end +@test gradtest(x -> reverse(x), rand(17)) +@test gradtest(x -> reverse(x, 8), rand(17)) +@test gradtest(x -> reverse(x, 8, 13), rand(17)) +@test gradtest(x -> reverse(x, dims=2), rand(17, 42)) + @test gradtest(x -> permutedims(x), rand(2)) @test gradtest(x -> permutedims(x), rand(2,3)) @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))