diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 6ec389f8d..e72be275c 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -417,6 +417,25 @@ end end ##### +##### `mapfoldl(f, g, ::Tuple)` +##### + +# For tuples there should be no harm in handling `map` first. +# This will also catch `mapreduce`. + +function rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), f::F, op::G, init, x::Tuple; + ) where {F,G} + y, backmap = rrule(cfg, map, f, x) + z, backred = rrule(cfg, Base.mapfoldl_impl, identity, op, init, y) + function mapfoldl_pullback_tuple(dz) + _, _, dop, dinit, dy = backred(dz) + _, df, dx = backmap(dy) + return (NoTangent(), df, dop, dinit, dx) + end + return z, mapfoldl_pullback_tuple +end + ##### ##### `foldl(f, ::Tuple)` ##### diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 699d91a15..80bc58bc6 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -304,6 +304,11 @@ const _INIT = Base._InitialValue() test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5))) test_rrule(mapfoldl_impl, identity, *, 1+rand(), Tuple(rand(ComplexF64, 5))) end + @testset "mapfoldl(f, g, ::Tuple)" begin + test_rrule(mapfoldl_impl, cbrt, /, _INIT, Tuple(1 .+ rand(5)), check_inferred=false) + test_rrule(mapfoldl_impl, abs2, *, 1+rand(), Tuple(rand(ComplexF64, 5)), check_inferred=false) + # TODO make the `map(f, ::Tuple)` rule infer better! + end end @testset "Accumulations" begin