Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
avoid ambiguities
Browse files Browse the repository at this point in the history
mcabbott committed Aug 30, 2022
1 parent c0a1aeb commit 6333879
Showing 2 changed files with 4 additions and 12 deletions.
15 changes: 4 additions & 11 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -438,11 +438,6 @@ function rrule(
return z, mapfoldl_pullback_tuple
end

function rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), f, op, init, x::Tuple{})
foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent())
return init, foldl_pullback_empty
end

#####
##### `foldl(f, ::Tuple)`
#####
@@ -491,6 +486,10 @@ function rrule(
init,
x::Tuple;
) where {G}
# Trivial case handled here to avoid ambiguities (and necc. because of Base.tail below)
foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent())
isempty(x) && return init, foldl_pullback_empty

# Treat `init` by simply appending it to the `x`:
y, back = rrule(config, Base.mapfoldl_impl, identity, op, Base._InitialValue(), (init, x...))
project_x = ProjectTo(x)
@@ -502,12 +501,6 @@ function rrule(
return y, foldl_pullback_tuple_init
end

# Base.tail doesn't work on (), trivial case:
function rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), ::typeof(identity), op, init, x::Tuple{})
foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent())
return init, foldl_pullback_empty
end

#####
##### `foldl(f, ::Array)`
#####
1 change: 0 additions & 1 deletion test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -306,7 +306,6 @@ const _INIT = Base._InitialValue()

# Trivial case
test_rrule(mapfoldl_impl, identity, /, 2pi, ())
test_rrule(mapfoldl_impl, sqrt, /, 2pi, ())
end
@testset "mapfoldl(f, g, ::Tuple)" begin
test_rrule(mapfoldl_impl, cbrt, /, _INIT, Tuple(1 .+ rand(5)), check_inferred=false)

0 comments on commit 6333879

Please sign in to comment.