From 63d7b14fb5c6d0a0d956644f6597a71e0e9af475 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 14 Jan 2022 23:38:26 -0500 Subject: [PATCH 01/10] minimal change foldl -> mapfoldl_impl --- src/rulesets/Base/mapreduce.jl | 23 ++++++++------ test/rulesets/Base/mapreduce.jl | 54 ++++++++++++++++++++------------- 2 files changed, 46 insertions(+), 31 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index a83f72cc7..c45a2a4b1 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -427,11 +427,12 @@ end # to carry intermediate results along creates arrays of tuples which could be avoided; using a # loop can be a few times faster. Note also that it does not return a gradient for `init`. +# Maybe that's a problem. Let's move the rule to `mapfoldr_impl(f, op, init, itr)`, where it's easier? + function rrule( - config::RuleConfig{>:HasReverseMode}, ::typeof(foldl), op::G, x::Union{AbstractArray, Tuple}; - init=_InitialValue() + config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple}; ) where {G} - list, start = if init === _InitialValue() + list, start = if init === _INIT _drop1(x), first(x) else # Case with init keyword is simpler to understand first! @@ -455,11 +456,12 @@ function rrule( end dop = sum(first, trio) dx = map(last, _reverse1(trio)) - if init === _InitialValue() + if init === _INIT # `hobbits` is one short dx = _vcat1(trio[end][2], dx) end - return (NoTangent(), dop, project(_reshape1(dx, axe))) + d_init = @not_implemented "gradient for foldl does not at present include init, sorry" + return (NoTangent(), NoTangent(), dop, d_init, project(_reshape1(dx, axe))) end return y, unfoldl end @@ -484,7 +486,8 @@ _reverse1(x::Tuple) = reverse(x) _drop1(x::Tuple) = Base.tail(x) _zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N) -struct _InitialValue end # Old versions don't have `Base._InitialValue` +# struct _InitialValue end # Old versions don't have `Base._InitialValue` +const _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple() _vcat1(x, ys::AbstractVector) = vcat(x, ys) _vcat1(x::AbstractArray, ys::AbstractVector) = vcat([x], ys) @@ -505,7 +508,7 @@ _no_tuple_tangent(dx) = dx function rrule( config::RuleConfig{>:HasReverseMode}, ::typeof(accumulate), op::G, x::Union{AbstractArray, Tuple}; - init=_InitialValue(), dims=nothing + init=_INIT, dims=nothing ) where {G} isnothing(dims) || dims == 1 && x isa Base.AbstractVecOrTuple || throw( "accumulate(op, x; dims) is not currently supported by ChainRules, sorry" @@ -513,7 +516,7 @@ function rrule( # gradient(x -> sum(accumulate(/, x, dims=1)), rand(3,4)) # ERROR: Mutating arrays is not supported ) - list, start = if init === _InitialValue() + list, start = if init === _INIT _drop1(x), first(x) else x, init @@ -522,7 +525,7 @@ function rrule( c, back = rrule_via_ad(config, op, a, b) end y = map(first, hobbits) - if init === _InitialValue() + if init === _INIT # `hobbits` is one short, and first one doesn't invoke `op` y = _vcat1(first(x), y) end @@ -543,7 +546,7 @@ function rrule( end dop = sum(first, trio) dx = map(last, _reverse1(trio)) - if init == _InitialValue() + if init == _INIT # `hobbits` is one short, and the first one is weird dx = _vcat1(trio[end][2] + dy_plain[1], dx) end diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 89f41c933..f458e38e1 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -213,60 +213,72 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end end # prod @testset "foldl(f, ::Array)" begin + # `foldl(op, itr; init)` goes to `mapfoldr_impl(identity, op, init, itr)`. The rule is + # now attached there, as this is the simplest way to handle `init` keyword. + @eval using Base: mapfoldl_impl + @eval _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple() + # Simple - y1, b1 = rrule(CFG, foldl, *, [1, 2, 3]; init=1) + y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, [1, 2, 3]) @test y1 == 6 - b1(7) == (NoTangent(), NoTangent(), [42, 21, 14]) + @test b1(7)[1:3] == (NoTangent(), NoTangent(), NoTangent()) + @test b1(7)[4] isa ChainRulesCore.NotImplemented + @test b1(7)[5] == [42, 21, 14] - y2, b2 = rrule(CFG, foldl, *, [1 2; 0 4]) # without init, needs vcat + y2, b2 = rrule(CFG, mapfoldl_impl, identity, *, _INIT, [1 2; 0 4]) # without init, needs vcat @test y2 == 0 - b2(8) == (NoTangent(), NoTangent(), [0 0; 64 0]) # matrix, needs reshape + @test b2(8)[5] == [0 0; 64 0] # matrix, needs reshape # Test execution order c5 = Counter() - y5, b5 = rrule(CFG, foldl, c5, [5, 7, 11]) + y5, b5 = rrule(CFG, mapfoldl_impl, identity, c5, _INIT, [5, 7, 11]) @test c5 == Counter(2) @test y5 == ((5 + 7)*1 + 11)*2 == foldl(Counter(), [5, 7, 11]) - @test b5(1) == (NoTangent(), NoTangent(), [12*32, 12*42, 22]) + @test b5(1)[5] == [12*32, 12*42, 22] @test c5 == Counter(42) c6 = Counter() - y6, b6 = rrule(CFG, foldl, c6, [5, 7, 11], init=3) + y6, b6 = rrule(CFG, mapfoldl_impl, identity, c6, 3, [5, 7, 11]) @test c6 == Counter(3) @test y6 == (((3 + 5)*1 + 7)*2 + 11)*3 == foldl(Counter(), [5, 7, 11], init=3) - @test b6(1) == (NoTangent(), NoTangent(), [63*33*13, 43*13, 23]) + @test b6(1)[5] == [63*33*13, 43*13, 23] @test c6 == Counter(63) # Test gradient of function - y7, b7 = rrule(CFG, foldl, Multiplier(3), [5, 7, 11]) + y7, b7 = rrule(CFG, mapfoldl_impl, identity, Multiplier(3), _INIT, [5, 7, 11]) @test y7 == foldl((x,y)->x*y*3, [5, 7, 11]) - @test b7(1) == (NoTangent(), Tangent{Multiplier{Int}}(x = 2310,), [693, 495, 315]) + b7_1 = b7(1) + @test b7_1[3] == Tangent{Multiplier{Int}}(x = 2310,) + @test b7_1[5] == [693, 495, 315] - y8, b8 = rrule(CFG, foldl, Multiplier(13), [5, 7, 11], init=3) + y8, b8 = rrule(CFG, mapfoldl_impl, identity, Multiplier(13), 3, [5, 7, 11]) @test y8 == 2_537_535 == foldl((x,y)->x*y*13, [5, 7, 11], init=3) - @test b8(1) == (NoTangent(), Tangent{Multiplier{Int}}(x = 585585,), [507507, 362505, 230685]) + b8_1 = b8(1) + @test b8_1[3] == Tangent{Multiplier{Int}}(x = 585585,) + @test b8_1[5] == [507507, 362505, 230685] # To find these numbers: # ForwardDiff.derivative(z -> foldl((x,y)->x*y*z, [5,7,11], init=3), 13) # ForwardDiff.gradient(z -> foldl((x,y)->x*y*13, z, init=3), [5,7,11]) |> string # Finite differencing - test_rrule(foldl, /, 1 .+ rand(3,4)) - test_rrule(foldl, *, rand(ComplexF64,3,4); fkwargs=(; init=rand(ComplexF64))) - test_rrule(foldl, +, rand(ComplexF64,7); fkwargs=(; init=rand(ComplexF64))) - test_rrule(foldl, max, rand(3); fkwargs=(; init=999)) + test_rrule(mapfoldl_impl, identity, /, _INIT, 1 .+ rand(3,4)) + test_rrule(mapfoldl_impl, identity, *, rand(ComplexF64), rand(ComplexF64,3,4)) + test_rrule(mapfoldl_impl, identity, +, rand(ComplexF64), rand(ComplexF64,7)) + test_rrule(mapfoldl_impl, identity, max, 999, rand(3)) end @testset "foldl(f, ::Tuple)" begin y1, b1 = rrule(CFG, foldl, *, (1,2,3); init=1) + y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, (1,2,3)) @test y1 == 6 - b1(7) == (NoTangent(), NoTangent(), Tangent{NTuple{3,Int}}(42, 21, 14)) + @test b1(7)[5] == Tangent{NTuple{3,Int}}(42, 21, 14) - y2, b2 = rrule(CFG, foldl, *, (1, 2, 0, 4)) + y2, b2 = rrule(CFG, mapfoldl_impl, identity, *, _INIT, (1, 2, 0, 4)) @test y2 == 0 - b2(8) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(0, 0, 64, 0)) + @test b2(8)[5] == Tangent{NTuple{4,Int}}(0, 0, 64, 0) # Finite differencing - test_rrule(foldl, /, Tuple(1 .+ rand(5))) - test_rrule(foldl, *, Tuple(rand(ComplexF64, 5))) + test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5))) + test_rrule(mapfoldl_impl, identity, *, _INIT, Tuple(rand(ComplexF64, 5))) end end From 78b823007647e8375d1d08e692b8170169d30234 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 15 Jan 2022 09:23:57 -0500 Subject: [PATCH 02/10] change accumulate -> _accumulate! --- src/rulesets/Base/mapreduce.jl | 36 ++++++++++++++----------- test/rulesets/Base/mapreduce.jl | 48 +++++++++++++++++++++++++-------- 2 files changed, 57 insertions(+), 27 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index c45a2a4b1..553f0bcfb 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -506,34 +506,35 @@ _no_tuple_tangent(dx) = dx # Like `foldl` this by definition works in order, so it makes sense to allow stateful `f`. +# Also like `foldl`, the version with a keyword `init` can't easily be given a gradient. +# Move it down to: `_accumulate!(op, B, A::AbstractVector, dims::Nothing, init::Nothing)` + function rrule( - config::RuleConfig{>:HasReverseMode}, ::typeof(accumulate), op::G, x::Union{AbstractArray, Tuple}; - init=_INIT, dims=nothing + config::RuleConfig{>:HasReverseMode}, ::typeof(Base._accumulate!), op::G, y, x::AbstractVector, dims::Nothing, init, ) where {G} - isnothing(dims) || dims == 1 && x isa Base.AbstractVecOrTuple || throw( - "accumulate(op, x; dims) is not currently supported by ChainRules, sorry" - # It's not supported by AD either, so no point calling back, and no regression: - # gradient(x -> sum(accumulate(/, x, dims=1)), rand(3,4)) - # ERROR: Mutating arrays is not supported - ) - list, start = if init === _INIT + + list, start = if init === nothing _drop1(x), first(x) else - x, init + x, something(init) end hobbits = accumulate(list; init = (start, nothing)) do (a, _), b c, back = rrule_via_ad(config, op, a, b) end - y = map(first, hobbits) - if init === _INIT + # y = map(first, hobbits) + if init === nothing # `hobbits` is one short, and first one doesn't invoke `op` - y = _vcat1(first(x), y) + # y = _vcat1(first(x), y) + y[1] = first(x) + map!(first, @view(y[2:end]), hobbits) + else + map!(first, y, hobbits) end axe = axes(x) project = ProjectTo(x) function decumulate(dy) dy_plain = _no_tuple_tangent(unthunk(dy)) - rev_list = if init === _InitialValue() + rev_list = if init === nothing # Here we rely on `zip` to stop early. Begin explicit with _reverse1(_drop1(...)) # gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{" _zip2(_reverse1(hobbits), _reverse1(dy_plain)) @@ -546,11 +547,14 @@ function rrule( end dop = sum(first, trio) dx = map(last, _reverse1(trio)) - if init == _INIT + if init == nothing # `hobbits` is one short, and the first one is weird dx = _vcat1(trio[end][2] + dy_plain[1], dx) end - return (NoTangent(), dop, project(_reshape1(dx, axe))) + dy = @not_implemented "no gradient for `B` in `accumulate!(f, B, A)`, the rule intends to support `accumulate` only" + d_init_not = @not_implemented "gradient for accumulate does not at present include init, sorry" + d_init = init === nothing ? NoTangent() : Tangent{typeof(init)}(; value = d_init_not) + return (NoTangent(), dop, dy, project(_reshape1(dx, axe)), NoTangent(), d_init) end return _reshape1(y, axe), decumulate end diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index f458e38e1..d0480213e 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -216,7 +216,7 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end # `foldl(op, itr; init)` goes to `mapfoldr_impl(identity, op, init, itr)`. The rule is # now attached there, as this is the simplest way to handle `init` keyword. @eval using Base: mapfoldl_impl - @eval _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple() + _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple() # Simple y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, [1, 2, 3]) @@ -336,10 +336,17 @@ end end # cumprod @testset "accumulate(f, ::Array)" begin + # `accumulate(f, A; init)` goes to `_accumulate!(op, B, A, dims::Nothing, init::Nothing)`. + # The rule is now attached there, as this is the simplest way to handle `init` keyword. + @eval using Base: _accumulate! + # Simple - y1, b1 = rrule(CFG, accumulate, *, [1, 2, 3, 4]; init=1) + y1, b1 = rrule(CFG, _accumulate!, *, [0, 0, 0, 0], [1, 2, 3, 4], nothing, Some(1)) @test y1 == [1, 2, 6, 24] - @test b1([1, 1, 1, 1]) == (NoTangent(), NoTangent(), [33, 16, 10, 6]) + @test b1([1, 1, 1, 1])[3] isa ChainRulesCore.NotImplemented + @test b1([1, 1, 1, 1])[4] == [33, 16, 10, 6] + @test b1([1, 1, 1, 1])[6] isa Tangent{Some{Int64}} + @test b1([1, 1, 1, 1])[6].value isa ChainRulesCore.NotImplemented y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4]) @test y2 ≈ accumulate(/, [1 2; 3 4]) @@ -347,25 +354,27 @@ end # Test execution order c3 = Counter() - y3, b3 = rrule(CFG, accumulate, c3, [5, 7, 11]; init=3) + y3, b3 = rrule(CFG, _accumulate!, c3, [0, 0, 0], [5, 7, 11], nothing, Some(3)) @test c3 == Counter(3) @test y3 == [8, 30, 123] == accumulate(Counter(), [5, 7, 11]; init=3) - @test b3([1, 1, 1]) == (NoTangent(), NoTangent(), [29169, 602, 23]) # the 23 is clear! + @test b3([1, 1, 1])[4] == [29169, 602, 23] # the 23 is clear! c4 = Counter() - y4, b4 = rrule(CFG, accumulate, c4, [5, 7, 11]) + y4, b4 = rrule(CFG, _accumulate!, c4, [0, 0, 0], [5, 7, 11], nothing, nothing) @test c4 == Counter(2) @test y4 == [5, (5+7)*1, ((5+7)*1 + 11)*2] == accumulate(Counter(), [5, 7, 11]) - @test b4([1, 1, 1]) == (NoTangent(), NoTangent(), [417, 42*(1 + 12), 22]) + @test b4([1, 1, 1])[4] == [417, 42*(1 + 12), 22] # Test gradient of function - y7, b7 = rrule(CFG, accumulate, Multiplier(3), [5, 7, 11]) + y7, b7 = rrule(CFG, _accumulate!, Multiplier(3), [0, 0, 0], [5, 7, 11], nothing, nothing) @test y7 == accumulate((x,y)->x*y*3, [5, 7, 11]) - @test b7([1, 1, 1]) == (NoTangent(), Tangent{Multiplier{Int}}(x = 2345,), [715, 510, 315]) + @test b7([1, 1, 1])[2] == Tangent{Multiplier{Int}}(; x = 2345,) + @test b7([1, 1, 1])[4] == [715, 510, 315] - y8, b8 = rrule(CFG, accumulate, Multiplier(13), [5, 7, 11], init=3) + y8, b8 = rrule(CFG, _accumulate!, Multiplier(13), [0, 0, 0], [5, 7, 11], nothing, Some(3)) @test y8 == [195, 17745, 2537535] == accumulate((x,y)->x*y*13, [5, 7, 11], init=3) - @test b8([1, 1, 1]) == (NoTangent(), Tangent{Multiplier{Int}}(x = 588330,), [511095, 365040, 230685]) + @test b8([1, 1, 1])[2] == Tangent{Multiplier{Int}}(; x = 588330,) + @test b8([1, 1, 1])[4] == [511095, 365040, 230685] # To find these numbers: # ForwardDiff.derivative(z -> sum(accumulate((x,y)->x*y*z, [5,7,11], init=3)), 13) # ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string @@ -384,5 +393,22 @@ end # Finite differencing test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand())) test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false) + + test_rrule(_accumulate!, *, randn(5) ⊢ NoTangent(), randn(5), nothing, nothing) + test_rrule(_accumulate!, /, randn(5) ⊢ NoTangent(), randn(5), nothing, Some(1 + rand())) + # if VERSION >= v"1.5" + # test_rrule(accumulate, /, 1 .+ rand(3, 4)) + # test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) + # end end + # VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin + # # Simple + # y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1) + # @test y1 == (1, 2, 6, 24) + # @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6)) + + # # Finite differencing + # test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand())) + # test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false) + # end end From 6ca472661e86a046d5c883e2e2858c8230c0d3d3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 18 Jul 2022 21:03:05 -0400 Subject: [PATCH 03/10] separate rule for foldl(::Tuple) --- src/rulesets/Base/mapreduce.jl | 71 +++++++++++++++++++++++++++++---- test/rulesets/Base/mapreduce.jl | 32 +++++++++++++-- 2 files changed, 91 insertions(+), 12 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 553f0bcfb..6ec389f8d 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -417,17 +417,73 @@ end end ##### -##### `foldl` +##### +##### `foldl(f, ::Tuple)` ##### # `foldl` guarantees to execute `f` in order, left to right. So it makes sense even when -# this `f` is stateful, in which case the gradient must be calculated in the reverse order. +# this `f` is stateful, in which case the gradient must be calculated in the reverse order. + +# The rule is attached to `Base.mapfoldl_impl` because this gets the `init` keyword as an argument, +# which is handled below. For tuples, `reduce` also comes here. + +function rrule( + config::RuleConfig{>:HasReverseMode}, + ::typeof(Base.mapfoldl_impl), + ::typeof(identity), + op::G, + init::Base._InitialValue, + x::Tuple; + ) where {G} + hobbits = accumulate(Base.tail(x); init=(first(x), nothing)) do (a, _), b + # Here `a` is what we would normally cary forward, and `_` ignores + # the previous iteration's pullback function (needed later), + # while `b` is the fresh input from `list` as usual. + c, back = rrule_via_ad(config, op, a, b) + # We don't really need to store every `c`, last one is `foldl` output. + # (The name, BTW, is because "there and back again" is the subtitle of Tolkien's book.) + end + y = first(last(hobbits)) + project = ProjectTo(x) + function foldl_pullback_tuple(dy) + trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) + ds, da, db = back(dc) + # Don't need to store every `da`, need one for the next iteration + the last. + end + dop = sum(first, trio) + dx = (trio[end][2], reverse(map(last, trio))...) + return (NoTangent(), NoTangent(), ProjectTo(op)(dop), NoTangent(), project(dx)) + end + return y, foldl_pullback_tuple +end + +function rrule( + config::RuleConfig{>:HasReverseMode}, + ::typeof(Base.mapfoldl_impl), + ::typeof(identity), + op::G, + init, + x::Tuple; + ) where {G} + # Treat `init` by simply appending it to the `x`: + y, back = rrule(config, Base.mapfoldl_impl, identity, op, Base._InitialValue(), (init, x...)) + project_x = ProjectTo(x) + project_in = ProjectTo(init) + function foldl_pullback_tuple_init(dy) + _, _, dop, _, dxplus = back(dy) + return (NoTangent(), NoTangent(), dop, project_in(first(dxplus)), project_x(Base.tail(dxplus))) + end + return y, foldl_pullback_tuple_init +end -# The implementation aims to be efficient for both tuples and arrays, although using accumulate -# to carry intermediate results along creates arrays of tuples which could be avoided; using a -# loop can be a few times faster. Note also that it does not return a gradient for `init`. +##### +##### `foldl(f, ::Array)` +##### -# Maybe that's a problem. Let's move the rule to `mapfoldr_impl(f, op, init, itr)`, where it's easier? +# The implementation was originally for both tuples and arrays, although using accumulate +# to carry intermediate results along creates arrays of tuples which could be avoided. +# Using a loop can be a few times faster, this should be replaced. +# Note also that it does not return a gradient for `init`. function rrule( config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple}; @@ -486,8 +542,7 @@ _reverse1(x::Tuple) = reverse(x) _drop1(x::Tuple) = Base.tail(x) _zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N) -# struct _InitialValue end # Old versions don't have `Base._InitialValue` -const _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple() +const _INIT = Base._InitialValue() _vcat1(x, ys::AbstractVector) = vcat(x, ys) _vcat1(x::AbstractArray, ys::AbstractVector) = vcat([x], ys) diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index d0480213e..5080e94a1 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -2,6 +2,11 @@ Base.sum(xs::AbstractArray, weights::AbstractArray) = dot(xs, weights) struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end +const CFG = ChainRulesTestUtils.ADviaRuleConfig() + +using Base: mapfoldl_impl, _accumulate! # for foldl & accumulate rules +const _INIT = Base._InitialValue() + @testset "Reductions" begin @testset "sum(::Tuple)" begin test_frule(sum, Tuple(rand(5))) @@ -215,8 +220,6 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end @testset "foldl(f, ::Array)" begin # `foldl(op, itr; init)` goes to `mapfoldr_impl(identity, op, init, itr)`. The rule is # now attached there, as this is the simplest way to handle `init` keyword. - @eval using Base: mapfoldl_impl - _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple() # Simple y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, [1, 2, 3]) @@ -267,7 +270,6 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end test_rrule(mapfoldl_impl, identity, max, 999, rand(3)) end @testset "foldl(f, ::Tuple)" begin - y1, b1 = rrule(CFG, foldl, *, (1,2,3); init=1) y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, (1,2,3)) @test y1 == 6 @test b1(7)[5] == Tangent{NTuple{3,Int}}(42, 21, 14) @@ -275,10 +277,32 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end y2, b2 = rrule(CFG, mapfoldl_impl, identity, *, _INIT, (1, 2, 0, 4)) @test y2 == 0 @test b2(8)[5] == Tangent{NTuple{4,Int}}(0, 0, 64, 0) + + # Test execution order + c5 = Counter() + y5, b5 = rrule(CFG, mapfoldl_impl, identity, c5, _INIT, (5, 7, 11)) + @test c5 == Counter(2) + @test y5 == ((5 + 7)*1 + 11)*2 == foldl(Counter(), (5, 7, 11)) + @test collect(b5(1)[5]) == [12*32, 12*42, 22] + @test c5 == Counter(42) + + c6 = Counter() + y6, b6 = rrule(CFG, mapfoldl_impl, identity, c6, 3, (5, 7, 11)) + @test c6 == Counter(3) + @test y6 == (((3 + 5)*1 + 7)*2 + 11)*3 == foldl(Counter(), (5, 7, 11), init=3) + @test collect(b6(1)[5]) == [63*33*13, 43*13, 23] + @test c6 == Counter(63) + + # Test gradient of function + y7, b7 = rrule(CFG, mapfoldl_impl, identity, Multiplier(3), _INIT, (5, 7, 11)) + @test y7 == foldl((x,y)->x*y*3, (5, 7, 11)) + b7_1 = b7(1) + @test b7_1[3] == Tangent{Multiplier{Int}}(x = 2310,) + @test collect(b7_1[5]) == [693, 495, 315] # Finite differencing test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5))) - test_rrule(mapfoldl_impl, identity, *, _INIT, Tuple(rand(ComplexF64, 5))) + test_rrule(mapfoldl_impl, identity, *, 1+rand(), Tuple(rand(ComplexF64, 5))) end end From e17b0f1e60d97bed8c42689b45d9916d826b55a7 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 18 Jul 2022 21:03:26 -0400 Subject: [PATCH 04/10] fix accumulate tests --- test/rulesets/Base/mapreduce.jl | 45 ++++++++------------------------- 1 file changed, 10 insertions(+), 35 deletions(-) diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 5080e94a1..699d91a15 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -359,10 +359,9 @@ end end end # cumprod - @testset "accumulate(f, ::Array)" begin + @testset "accumulate(f, ::Vector)" begin # `accumulate(f, A; init)` goes to `_accumulate!(op, B, A, dims::Nothing, init::Nothing)`. # The rule is now attached there, as this is the simplest way to handle `init` keyword. - @eval using Base: _accumulate! # Simple y1, b1 = rrule(CFG, _accumulate!, *, [0, 0, 0, 0], [1, 2, 3, 4], nothing, Some(1)) @@ -372,9 +371,9 @@ end @test b1([1, 1, 1, 1])[6] isa Tangent{Some{Int64}} @test b1([1, 1, 1, 1])[6].value isa ChainRulesCore.NotImplemented - y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4]) - @test y2 ≈ accumulate(/, [1 2; 3 4]) - @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6 + # y2, b2 = rrule(CFG, _accumulate!, /, [0 0; 0 0], [1 2; 3 4], :, nothing) + # @test y2 ≈ accumulate(/, [1 2; 3 4.0]) + # @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6 # Test execution order c3 = Counter() @@ -404,35 +403,11 @@ end # ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string # Finite differencing - test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand())) - test_rrule(accumulate, /, 1 .+ rand(3, 4)) - test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) + # test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand())) + test_rrule(_accumulate!, *, randn(5) ⊢ NoTangent(), randn(5), nothing, Some(rand())) + # test_rrule(accumulate, /, 1 .+ rand(3, 4)) + test_rrule(_accumulate!, /, randn(4) ⊢ NoTangent(), 1 .+ rand(4), nothing, nothing) + # test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) + test_rrule(_accumulate!, ^, randn(6) ⊢ NoTangent(), 1 .+ rand(6), nothing, Some(rand())) end - @testset "accumulate(f, ::Tuple)" begin - # Simple - y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1) - @test y1 == (1, 2, 6, 24) - @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6)) - - # Finite differencing - test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand())) - test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false) - - test_rrule(_accumulate!, *, randn(5) ⊢ NoTangent(), randn(5), nothing, nothing) - test_rrule(_accumulate!, /, randn(5) ⊢ NoTangent(), randn(5), nothing, Some(1 + rand())) - # if VERSION >= v"1.5" - # test_rrule(accumulate, /, 1 .+ rand(3, 4)) - # test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) - # end - end - # VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin - # # Simple - # y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1) - # @test y1 == (1, 2, 6, 24) - # @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6)) - - # # Finite differencing - # test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand())) - # test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false) - # end end From c6985748cbcad055c91020d0c2b3a5c5851e03ac Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 18 Jul 2022 21:03:43 -0400 Subject: [PATCH 05/10] simple rule for mapfoldl --- src/rulesets/Base/mapreduce.jl | 19 +++++++++++++++++++ test/rulesets/Base/mapreduce.jl | 5 +++++ 2 files changed, 24 insertions(+) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 6ec389f8d..e72be275c 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -417,6 +417,25 @@ end end ##### +##### `mapfoldl(f, g, ::Tuple)` +##### + +# For tuples there should be no harm in handling `map` first. +# This will also catch `mapreduce`. + +function rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), f::F, op::G, init, x::Tuple; + ) where {F,G} + y, backmap = rrule(cfg, map, f, x) + z, backred = rrule(cfg, Base.mapfoldl_impl, identity, op, init, y) + function mapfoldl_pullback_tuple(dz) + _, _, dop, dinit, dy = backred(dz) + _, df, dx = backmap(dy) + return (NoTangent(), df, dop, dinit, dx) + end + return z, mapfoldl_pullback_tuple +end + ##### ##### `foldl(f, ::Tuple)` ##### diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 699d91a15..80bc58bc6 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -304,6 +304,11 @@ const _INIT = Base._InitialValue() test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5))) test_rrule(mapfoldl_impl, identity, *, 1+rand(), Tuple(rand(ComplexF64, 5))) end + @testset "mapfoldl(f, g, ::Tuple)" begin + test_rrule(mapfoldl_impl, cbrt, /, _INIT, Tuple(1 .+ rand(5)), check_inferred=false) + test_rrule(mapfoldl_impl, abs2, *, 1+rand(), Tuple(rand(ComplexF64, 5)), check_inferred=false) + # TODO make the `map(f, ::Tuple)` rule infer better! + end end @testset "Accumulations" begin From 1961acaf5581f4ecedae0fb206103144ce702a18 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 18 Jul 2022 23:13:22 -0400 Subject: [PATCH 06/10] tidy up --- src/rulesets/Base/mapreduce.jl | 96 +++++++++++----------------------- 1 file changed, 31 insertions(+), 65 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index e72be275c..fa8c1c576 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -465,7 +465,7 @@ function rrule( y = first(last(hobbits)) project = ProjectTo(x) function foldl_pullback_tuple(dy) - trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) + trio = accumulate(reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) ds, da, db = back(dc) # Don't need to store every `da`, need one for the next iteration + the last. end @@ -501,78 +501,43 @@ end # The implementation was originally for both tuples and arrays, although using accumulate # to carry intermediate results along creates arrays of tuples which could be avoided. -# Using a loop can be a few times faster, this should be replaced. -# Note also that it does not return a gradient for `init`. +# Using a loop can be a few times faster, this should be replaced: +# https://github.com/FluxML/Zygote.jl/issues/644#issuecomment-628762305 + +# Note also that it does not return a gradient for `init`, now marked `@not_implemented`. function rrule( - config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple}; + config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple}; ) where {G} - list, start = if init === _INIT - _drop1(x), first(x) + start, list = if init === Base._InitialValue() + Iterators.peel(x) else # Case with init keyword is simpler to understand first! - _reshape1(x, :), init # (vec is for Julia 1.0, accumulate is fussy) + init, x end - hobbits = accumulate(list; init=(start, nothing)) do (a,_), b - # Here `a` is what we would normally cary forward, and `_` ignores - # the previous iteration's pullback function (needed later), - # while `b` is the fresh input from `list` as usual. - c, back = rrule_via_ad(config, op, a, b) # LHS is just documentation here! - # We don't really need to store every `c`, last one is `foldl` output. - # (The name, BTW, is because "there and back again" is the subtitle of Tolkien's book.) + hobbits = accumulate(list; init=(start, nothing)) do (a, _), b + c, back = rrule_via_ad(config, op, a, b) end y = first(last(hobbits)) axe = axes(x) project = ProjectTo(x) function unfoldl(dy) - trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) + trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) ds, da, db = back(dc) - # Don't need to store every `da`, need one for the next iteration + maybe last end dop = sum(first, trio) - dx = map(last, _reverse1(trio)) - if init === _INIT - # `hobbits` is one short + dx = map(last, Iterators.reverse(trio)) + if init === Base._InitialValue() # `hobbits` is one short dx = _vcat1(trio[end][2], dx) end d_init = @not_implemented "gradient for foldl does not at present include init, sorry" - return (NoTangent(), NoTangent(), dop, d_init, project(_reshape1(dx, axe))) + return (NoTangent(), NoTangent(), dop, d_init, project(reshape(dx, axe))) end return y, unfoldl end - -##### -##### Iterator-or-Tuple functions -##### - -# This zoo of underscore functions helps `foldl` & `accumulate` handle both tuples and arrays, -# and also provides some alternatives for versions of Julia where iterators weren't supported. -# Inspired by `Base._reverse`, used in defn of `foldr`. - -# To support 2nd derivatives, some may need their own gradient rules. And _drop1 should perhaps -# be replaced by _peel1 like Iterators.peel - -_reverse1(x) = Iterators.reverse(x) -_drop1(x) = Iterators.drop(x, 1) -_zip2(x, y) = zip(x, y) # for `accumulate`, below - -_reverse1(x::Tuple) = reverse(x) -_drop1(x::Tuple) = Base.tail(x) -_zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N) - -const _INIT = Base._InitialValue() - _vcat1(x, ys::AbstractVector) = vcat(x, ys) _vcat1(x::AbstractArray, ys::AbstractVector) = vcat([x], ys) -_vcat1(x, ys::Tuple) = (x, ys...) - -_reshape1(x::AbstractArray, axe) = reshape(x, axe) -_reshape1(x::Tuple, axe) = x - -_no_tuple_tangent(dx::Tangent) = ChainRulesCore.backing(dx) -_no_tuple_tangent(dx) = dx - ##### ##### `accumulate` @@ -584,13 +549,18 @@ _no_tuple_tangent(dx) = dx # Move it down to: `_accumulate!(op, B, A::AbstractVector, dims::Nothing, init::Nothing)` function rrule( - config::RuleConfig{>:HasReverseMode}, ::typeof(Base._accumulate!), op::G, y, x::AbstractVector, dims::Nothing, init, + config::RuleConfig{>:HasReverseMode}, + ::typeof(Base._accumulate!), + op::G, y::AbstractVector, + x::AbstractVector, + dims::Nothing, + init, ) where {G} - list, start = if init === nothing - _drop1(x), first(x) + start, list = if init === nothing + Iterators.peel(x) else - x, something(init) + something(init), x end hobbits = accumulate(list; init = (start, nothing)) do (a, _), b c, back = rrule_via_ad(config, op, a, b) @@ -607,20 +577,16 @@ function rrule( axe = axes(x) project = ProjectTo(x) function decumulate(dy) - dy_plain = _no_tuple_tangent(unthunk(dy)) - rev_list = if init === nothing - # Here we rely on `zip` to stop early. Begin explicit with _reverse1(_drop1(...)) - # gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{" - _zip2(_reverse1(hobbits), _reverse1(dy_plain)) - else - _zip2(_reverse1(hobbits), _reverse1(dy_plain)) - end + dy_plain = unthunk(dy) + rev_list = zip(Iterators.reverse(hobbits), Iterators.reverse(dy_plain)) + # Here we rely on `zip` to stop early when init === nothing. Begin explicit with Iterators.reverse(Iterators.drop(..., 1)) + # gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{" trio = accumulate(rev_list; init=(0, ZeroTangent(), 0)) do (_, dc, _), ((_, back), dz) ds, da, db = back(dc + dz) # Don't need to store every 'da', but need for next iteration, and the last one. end dop = sum(first, trio) - dx = map(last, _reverse1(trio)) + dx = map(last, Iterators.reverse(trio)) if init == nothing # `hobbits` is one short, and the first one is weird dx = _vcat1(trio[end][2] + dy_plain[1], dx) @@ -628,7 +594,7 @@ function rrule( dy = @not_implemented "no gradient for `B` in `accumulate!(f, B, A)`, the rule intends to support `accumulate` only" d_init_not = @not_implemented "gradient for accumulate does not at present include init, sorry" d_init = init === nothing ? NoTangent() : Tangent{typeof(init)}(; value = d_init_not) - return (NoTangent(), dop, dy, project(_reshape1(dx, axe)), NoTangent(), d_init) + return (NoTangent(), dop, dy, project(reshape(dx, axe)), NoTangent(), d_init) end - return _reshape1(y, axe), decumulate + return reshape(y, axe), decumulate end From 1b8e121072352d649b7f6c4aaa80bbefb75f721d Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 29 Aug 2022 17:45:26 -0400 Subject: [PATCH 07/10] fix https://github.com/JuliaDiff/ChainRules.jl/issues/672 --- src/rulesets/Base/mapreduce.jl | 15 ++++++++++++++- test/rulesets/Base/mapreduce.jl | 4 ++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index fa8c1c576..83ed130c4 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -420,11 +420,13 @@ end ##### `mapfoldl(f, g, ::Tuple)` ##### +using Base: mapfoldl_impl + # For tuples there should be no harm in handling `map` first. # This will also catch `mapreduce`. function rrule( - cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), f::F, op::G, init, x::Tuple; + cfg::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), f::F, op::G, init, x::Tuple; ) where {F,G} y, backmap = rrule(cfg, map, f, x) z, backred = rrule(cfg, Base.mapfoldl_impl, identity, op, init, y) @@ -436,6 +438,11 @@ function rrule( return z, mapfoldl_pullback_tuple end +function rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), f, op, init, x::Tuple{}) + foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent()) + return init, foldl_pullback_empty +end + ##### ##### `foldl(f, ::Tuple)` ##### @@ -495,6 +502,12 @@ function rrule( return y, foldl_pullback_tuple_init end +# Base.tail doesn't work on (), trivial case: +function rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), ::typeof(identity), op, init, x::Tuple{}) + foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent()) + return init, foldl_pullback_empty +end + ##### ##### `foldl(f, ::Array)` ##### diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 80bc58bc6..bd3031ed5 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -303,6 +303,10 @@ const _INIT = Base._InitialValue() # Finite differencing test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5))) test_rrule(mapfoldl_impl, identity, *, 1+rand(), Tuple(rand(ComplexF64, 5))) + + # Trivial case + test_rrule(mapfoldl_impl, identity, /, 2pi, ()) + test_rrule(mapfoldl_impl, sqrt, /, 2pi, ()) end @testset "mapfoldl(f, g, ::Tuple)" begin test_rrule(mapfoldl_impl, cbrt, /, _INIT, Tuple(1 .+ rand(5)), check_inferred=false) From 7b28e1084788fc41f4e27d429abcc38d1c06734c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 29 Aug 2022 20:33:44 -0400 Subject: [PATCH 08/10] avoid ambiguities --- src/rulesets/Base/mapreduce.jl | 15 ++++----------- test/rulesets/Base/mapreduce.jl | 1 - 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 83ed130c4..80bca8cdf 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -438,11 +438,6 @@ function rrule( return z, mapfoldl_pullback_tuple end -function rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), f, op, init, x::Tuple{}) - foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent()) - return init, foldl_pullback_empty -end - ##### ##### `foldl(f, ::Tuple)` ##### @@ -491,6 +486,10 @@ function rrule( init, x::Tuple; ) where {G} + # Trivial case handled here to avoid ambiguities (and necc. because of Base.tail below) + foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent()) + isempty(x) && return init, foldl_pullback_empty + # Treat `init` by simply appending it to the `x`: y, back = rrule(config, Base.mapfoldl_impl, identity, op, Base._InitialValue(), (init, x...)) project_x = ProjectTo(x) @@ -502,12 +501,6 @@ function rrule( return y, foldl_pullback_tuple_init end -# Base.tail doesn't work on (), trivial case: -function rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), ::typeof(identity), op, init, x::Tuple{}) - foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent()) - return init, foldl_pullback_empty -end - ##### ##### `foldl(f, ::Array)` ##### diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index bd3031ed5..494639f80 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -306,7 +306,6 @@ const _INIT = Base._InitialValue() # Trivial case test_rrule(mapfoldl_impl, identity, /, 2pi, ()) - test_rrule(mapfoldl_impl, sqrt, /, 2pi, ()) end @testset "mapfoldl(f, g, ::Tuple)" begin test_rrule(mapfoldl_impl, cbrt, /, _INIT, Tuple(1 .+ rand(5)), check_inferred=false) From dd4f952e8be8746c191591b13bb48093ed3e82a4 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 4 Sep 2022 09:33:33 -0400 Subject: [PATCH 09/10] avoid int64 --- test/rulesets/Base/mapreduce.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 494639f80..dc350bc2f 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -376,7 +376,7 @@ end @test y1 == [1, 2, 6, 24] @test b1([1, 1, 1, 1])[3] isa ChainRulesCore.NotImplemented @test b1([1, 1, 1, 1])[4] == [33, 16, 10, 6] - @test b1([1, 1, 1, 1])[6] isa Tangent{Some{Int64}} + @test b1([1, 1, 1, 1])[6] isa Tangent{Some{Int}} @test b1([1, 1, 1, 1])[6].value isa ChainRulesCore.NotImplemented # y2, b2 = rrule(CFG, _accumulate!, /, [0 0; 0 0], [1 2; 3 4], :, nothing) From f18f61db25f35f14143fdb15bc8330140c974bf0 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 11 Jul 2023 23:00:37 -0400 Subject: [PATCH 10/10] rm 2 commented lines Co-authored-by: Frames White --- src/rulesets/Base/mapreduce.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 80bca8cdf..3e013df12 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -571,10 +571,8 @@ function rrule( hobbits = accumulate(list; init = (start, nothing)) do (a, _), b c, back = rrule_via_ad(config, op, a, b) end - # y = map(first, hobbits) if init === nothing # `hobbits` is one short, and first one doesn't invoke `op` - # y = _vcat1(first(x), y) y[1] = first(x) map!(first, @view(y[2:end]), hobbits) else