From b88b3410afe85111c303245fdb694d3a479f0131 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 15 Oct 2022 16:57:22 +0200 Subject: [PATCH 1/7] rrule for stack --- src/ChainRules.jl | 6 ++++++ src/rulesets/Base/array.jl | 13 +++++++++++++ test/rulesets/Base/array.jl | 9 +++++++++ 3 files changed, 28 insertions(+) diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 9f63eeb11..aacd22844 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -22,6 +22,12 @@ import ChainRulesCore: rrule, frule # Experimental: using ChainRulesCore: derivatives_given_output +if isdefined(Base, :stack) + using Base: stack +else + using Compat: stack +end + # numbers that we know commute under multiplication const CommutativeMulNumber = Union{Real,Complex} diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 2461c5561..8cf58a042 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -610,3 +610,16 @@ function _extrema_dims(x, dims) end return y, extrema_pullback_dims end + +##### +##### `stack` +##### + +function rrule(::typeof(stack), xs; dims::Union{Integer, Colon} = :) + dims = dims === Colon() ? ndims(first(xs)) + 1 : dims + function stack_pullback(Δ) + dy = unthunk(Δ) + return (NoTangent(), [copy(selectdim(dy, dims, i)) for i in 1:size(dy, dims)]) + end + return stack(xs; dims), stack_pullback +end diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 921c81534..f39888fa5 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -416,3 +416,12 @@ end B = hcat(A[:,:,1], A[:,:,1]) @test extrema(B, dims=2) == rrule(extrema, B, dims=2)[1] end + +@testset "stack" begin + xs = [rand(3, 4), rand(3, 4)] + + test_rrule(stack, xs, check_inferred=false) + test_rrule(stack, xs, fkwargs=(dims=1,), check_inferred=false) + test_rrule(stack, xs, fkwargs=(dims=2,), check_inferred=false) + test_rrule(stack, xs, fkwargs=(dims=3,), check_inferred=false) +end From f017b2581c41e130135bd0f049bb75566811831c Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 15 Oct 2022 17:20:34 +0200 Subject: [PATCH 2/7] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b32ccd742..29f0bd570 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.44.7" +version = "1.45.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 1c2fcea30752932d6fb5e239b962d33782710098 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 17 Oct 2022 06:39:00 +0200 Subject: [PATCH 3/7] extend rrule to muldim containers --- Project.toml | 2 +- src/rulesets/Base/array.jl | 49 ++++++++++++++++++++++++++++++++----- test/rulesets/Base/array.jl | 8 ++++++ 3 files changed, 52 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 29f0bd570..c6f8c9543 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,7 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Adapt = "3.4.0" ChainRulesCore = "1.15.3" ChainRulesTestUtils = "1.5" -Compat = "3.42.0, 4" +Compat = "3.46, 4.2" FiniteDifferences = "0.12.20" GPUArraysCore = "0.1.0" IrrationalConstants = "0.1.1" diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 8cf58a042..1519eb57e 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -615,11 +615,48 @@ end ##### `stack` ##### -function rrule(::typeof(stack), xs; dims::Union{Integer, Colon} = :) - dims = dims === Colon() ? ndims(first(xs)) + 1 : dims +# function rrule(::typeof(stack), xs; dims::Union{Integer, Colon} = :) +# dims = dims === Colon() ? ndims(first(xs)) + 1 : dims +# function stack_pullback(Δ) +# dy = unthunk(Δ) +# return (NoTangent(), [copy(selectdim(dy, dims, i)) for i in 1:size(dy, dims)]) +# end +# return stack(xs; dims), stack_pullback +# end + + +function frule((_, ẋ), ::typeof(stack), x; dims::Union{Integer, Colon} = :) + return stack(x; dims), stack(ẋ; dims) +end + +# Other iterable X also allowed, maybe this should be wider? +function rrule(::typeof(stack), X::AbstractArray; dims::Union{Integer, Colon} = :) + Y = stack(X; dims) + sdims = if dims isa Colon + N = ndims(Y) - ndims(X) + X isa AbstractVector ? ndims(Y) : ntuple(i -> i + N, ndims(X)) + else + dims + end + project = ProjectTo(X) function stack_pullback(Δ) - dy = unthunk(Δ) - return (NoTangent(), [copy(selectdim(dy, dims, i)) for i in 1:size(dy, dims)]) + dY = unthunk(Δ) + dY isa NoTangent && return (NoTangent(), NoTangent()) + dY isa ZeroTangent && return (NoTangent(), ZeroTangent()) + dX = collect(eachslice(unthunk(dY); dims = sdims)) + return (NoTangent(), project(dX)) end - return stack(xs; dims), stack_pullback -end + return Y, stack_pullback +end + +# # This wants #671, but ought to work with Zygote already? +# function rrule(config::RuleConfig, ::typeof(stack), f, args...; dims::Union{Integer, Colon} = :) +# y, unmap = rrule_via_ad(config, map, f, args...) +# z, unstack = rrule(stack, y) +# function stack_pullback_f(dz) +# _, dy = unstack(dz) +# _, df, dargs... = unmap(dy) +# return (NoTangent(), df, dargs...) +# end +# return z, stack_pullback_f +# end \ No newline at end of file diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index f39888fa5..8e65b1337 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -418,10 +418,18 @@ end end @testset "stack" begin + # vector container xs = [rand(3, 4), rand(3, 4)] test_rrule(stack, xs, check_inferred=false) test_rrule(stack, xs, fkwargs=(dims=1,), check_inferred=false) test_rrule(stack, xs, fkwargs=(dims=2,), check_inferred=false) test_rrule(stack, xs, fkwargs=(dims=3,), check_inferred=false) + + # multidimensional container + xs = [(1,2,3) (4,5,6); (7,8,9) (10,11,12)] + + test_rrule(stack, xs, check_inferred=false) + test_rrule(stack, xs, fkwargs=(dims=1,), check_inferred=false) + test_rrule(stack, xs, fkwargs=(dims=2,), check_inferred=false) end From bee487e551149642e39fba39b81962c022db056d Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 21 Oct 2022 21:38:37 -0400 Subject: [PATCH 4/7] hope you don't mind me committing these --- src/rulesets/Base/array.jl | 14 +------------- test/rulesets/Base/array.jl | 21 +++++++++++++++++---- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 1519eb57e..c7eb23739 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -644,19 +644,7 @@ function rrule(::typeof(stack), X::AbstractArray; dims::Union{Integer, Colon} = dY isa NoTangent && return (NoTangent(), NoTangent()) dY isa ZeroTangent && return (NoTangent(), ZeroTangent()) dX = collect(eachslice(unthunk(dY); dims = sdims)) - return (NoTangent(), project(dX)) + return (NoTangent(), project(reshape(dX, project.axes))) end return Y, stack_pullback end - -# # This wants #671, but ought to work with Zygote already? -# function rrule(config::RuleConfig, ::typeof(stack), f, args...; dims::Union{Integer, Colon} = :) -# y, unmap = rrule_via_ad(config, map, f, args...) -# z, unstack = rrule(stack, y) -# function stack_pullback_f(dz) -# _, dy = unstack(dz) -# _, df, dargs... = unmap(dy) -# return (NoTangent(), df, dargs...) -# end -# return z, stack_pullback_f -# end \ No newline at end of file diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 8e65b1337..afdeb4cb9 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -420,6 +420,8 @@ end @testset "stack" begin # vector container xs = [rand(3, 4), rand(3, 4)] + test_frule(stack, xs) + test_frule(stack, xs; fkwargs=(dims=1,)) test_rrule(stack, xs, check_inferred=false) test_rrule(stack, xs, fkwargs=(dims=1,), check_inferred=false) @@ -427,9 +429,20 @@ end test_rrule(stack, xs, fkwargs=(dims=3,), check_inferred=false) # multidimensional container - xs = [(1,2,3) (4,5,6); (7,8,9) (10,11,12)] + ms = [rand(2,3) for _ in 1:4, _ in 1:5]; - test_rrule(stack, xs, check_inferred=false) - test_rrule(stack, xs, fkwargs=(dims=1,), check_inferred=false) - test_rrule(stack, xs, fkwargs=(dims=2,), check_inferred=false) + if VERSION > v"1.9-" # this needs new eachslice, not yet in Compat + test_rrule(stack, ms, check_inferred=false) + end + test_rrule(stack, ms, fkwargs=(dims=1,), check_inferred=false) + test_rrule(stack, ms, fkwargs=(dims=3,), check_inferred=false) + + # non-array inner objects + ts = [Tuple(rand(3)) for _ in 1:4, _ in 1:2]; + + if VERSION > v"1.9-" + test_rrule(stack, ts, check_inferred=false) + end + test_rrule(stack, ts, fkwargs=(dims=1,), check_inferred=false) + test_rrule(stack, ts, fkwargs=(dims=2,), check_inferred=false) end From 9e6ac8a4767a3e5b43e8d205461c7db9d5f69dcc Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 23 Oct 2022 12:22:50 +0200 Subject: [PATCH 5/7] import stack in tests --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 71444f388..a9f25c55c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,7 @@ using Test, ChainRulesCore, ChainRulesTestUtils using Adapt using Base.Broadcast: broadcastable using ChainRules +using ChainRules: stack using ChainRulesCore using ChainRulesTestUtils using ChainRulesTestUtils: rand_tangent, _fdm From 080c5f41156504a2a9d535be577d1b43e5bae2a5 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 23 Oct 2022 12:25:44 +0200 Subject: [PATCH 6/7] cleanup --- src/rulesets/Base/array.jl | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index c7eb23739..5f795ef95 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -615,16 +615,6 @@ end ##### `stack` ##### -# function rrule(::typeof(stack), xs; dims::Union{Integer, Colon} = :) -# dims = dims === Colon() ? ndims(first(xs)) + 1 : dims -# function stack_pullback(Δ) -# dy = unthunk(Δ) -# return (NoTangent(), [copy(selectdim(dy, dims, i)) for i in 1:size(dy, dims)]) -# end -# return stack(xs; dims), stack_pullback -# end - - function frule((_, ẋ), ::typeof(stack), x; dims::Union{Integer, Colon} = :) return stack(x; dims), stack(ẋ; dims) end From 5d28c45ec4154d62faff41215548eea5465ba70c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 10 Nov 2022 22:35:47 -0500 Subject: [PATCH 7/7] Apply3 suggestions --- src/ChainRules.jl | 6 +----- src/rulesets/Base/array.jl | 5 ++--- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/ChainRules.jl b/src/ChainRules.jl index aacd22844..28e73c166 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -22,11 +22,7 @@ import ChainRulesCore: rrule, frule # Experimental: using ChainRulesCore: derivatives_given_output -if isdefined(Base, :stack) - using Base: stack -else - using Compat: stack -end +using Compat: stack # numbers that we know commute under multiplication const CommutativeMulNumber = Union{Real,Complex} diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 5f795ef95..4ae424151 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -631,9 +631,8 @@ function rrule(::typeof(stack), X::AbstractArray; dims::Union{Integer, Colon} = project = ProjectTo(X) function stack_pullback(Δ) dY = unthunk(Δ) - dY isa NoTangent && return (NoTangent(), NoTangent()) - dY isa ZeroTangent && return (NoTangent(), ZeroTangent()) - dX = collect(eachslice(unthunk(dY); dims = sdims)) + dY isa AbstractZero && return (NoTangent(), dY) + dX = collect(eachslice(dY; dims = sdims)) return (NoTangent(), project(reshape(dX, project.axes))) end return Y, stack_pullback