diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 988555f6b..0d1529e5c 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -249,13 +249,17 @@ end ##### function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray) where {F} - # y, back = rrule_via_ad(cfg, Broadcast.broadcasted, f, x) # Yota likes this one - # return Broadcast.materialize(y), back - y, back = rrule_via_ad(cfg, broadcast, f, x) # but testing like this one + # Here map agrees with broadcast, which has a meta-rule with 4 different paths, so this should be fast: + y, back = rrule_via_ad(cfg, broadcast, f, x) return y, back end function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F} + if all(==(size(x)), map(size, ys)) + # Here too map agrees with broadcast; maybe this size check could be more elegant? + y, back = rrule_via_ad(cfg, broadcast, f, x, ys...) + return y, back + end @debug "rrule(map, f, arrays...)" f z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...), x, ys...) function map_pullback_2(dz) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 760f059e9..f9b30e2a4 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -121,26 +121,6 @@ end # Path 4: The most generic, save all the pullbacks. Can be 1000x slower. # While broadcast makes no guarantee about order of calls, it's cheap to reverse the iteration. 
-#= - -julia> Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), [1,2,3.0]) -┌ Debug: split broadcasting generic -│ f = #69 (generic function with 1 method) -│ N = 1 -└ @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:126 -(14.0, (ZeroTangent(), [2.0, 4.0, 6.0])) - -julia> ENV["JULIA_DEBUG"] = nothing - -julia> @btime Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), $(rand(1000))); - min 1.321 ms, mean 1.434 ms (23010 allocations, 594.66 KiB) # with unzip_map, as before - min 1.279 ms, mean 1.393 ms (23029 allocations, 595.73 KiB) # with unzip_map_reversed - -julia> @btime Yota.grad(xs -> sum(abs2, abs.(xs)), $(randn(1000))); # Debug: split broadcasting derivative - min 2.144 μs, mean 6.620 μs (6 allocations, 23.88 KiB) - -=# - function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} @debug("split broadcasting generic", f, N) ys3, backs = unzip_broadcast(args...) do a... diff --git a/src/rulesets/Base/iterators.jl b/src/rulesets/Base/iterators.jl index 1ec727a36..63df6a3d6 100644 --- a/src/rulesets/Base/iterators.jl +++ b/src/rulesets/Base/iterators.jl @@ -17,57 +17,10 @@ function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) whe ys, generator_pullback end -# Needed for Yota, but shouldn't these be automatic? 
-ChainRulesCore.rrule(::Type{<:Base.Generator}, f, iter) = Base.Generator(f, iter), dy -> (NoTangent(), dy.f, dy.iter) -ChainRulesCore.rrule(::Type{<:Iterators.ProductIterator}, iters) = Iterators.ProductIterator(iters), dy -> (NoTangent(), dy.iterators) - -#= - - Yota.grad(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3) -Diffractor.gradient(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3) - - Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: all field arrays must have same shape -Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: type Array has no field iterators - - Yota.grad(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) -Diffractor.gradient(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) # fails internally - - Yota.grad(xs -> sum(abs, [sin(x/y) for (x,y) in zip(xs, 1:2)]), [1,2,3]pi/3) -Diffractor.gradient(xs -> sum(abs, [sin(x/y) for (x,y) in zip(xs, 1:2)]), [1,2,3]pi/3) - - Yota.grad(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3) -Diffractor.gradient(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3) - - -@btime Yota.grad($(rand(1000))) do xs - sum(abs2, [sqrt(x) for x in xs]) -end -# Yota min 759.000 μs, mean 800.754 μs (22041 allocations, 549.62 KiB) -# Diffractor min 559.000 μs, mean 622.464 μs (18051 allocations, 612.34 KiB) - -# Zygote min 3.198 μs, mean 6.849 μs (20 allocations, 40.11 KiB) - - -@btime Yota.grad($(rand(1000)), $(rand(1000))) do xs, ys - zs = map(xs, ys) do x, y - atan(x/y) - end - sum(abs2, zs) -end -# Yota + CR: min 1.598 ms, mean 1.691 ms (38030 allocations, 978.75 KiB) -# Diffractor + CR: min 767.250 μs, mean 847.640 μs (26045 allocations, 838.66 KiB) - -# Zygote: min 13.417 μs, mean 22.896 μs (26 allocations, 79.59 KiB) -- 100x faster - - -=# - - ##### ##### `zip` ##### - function rrule(::typeof(zip), xs::AbstractArray...) 
function zip_pullback(dy) @debug "zip array pullback" summary(dy) @@ -94,8 +47,6 @@ function _unmap_pad(x::AbstractArray, dx::AbstractArray) @debug "_unmap_pad is extending gradient" length(x) == length(dx) i1 = firstindex(x) ∇getindex(x, vec(dx), i1:i1+length(dx)-1) - # dx2 = vcat(vec(dx), similar(x, ZeroTangent, length(x) - length(dx))) - # ProjectTo(x)(reshape(dx2, axes(x))) end end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index a7c166b0f..830a794dd 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -251,6 +251,8 @@ @testset "map(f, ::Array, ::Array)" begin test_rrule(map, +, [1.0, 2.0], [3.0, 4.0], check_inferred=false) # NoTangent does not match Union{NoTangent, ZeroTangent} + test_rrule(map, atan, [1 2; 3.0 4.0], [4 5; 6 7.0], check_inferred=false) # same shape => just broadcast + test_rrule(map, /, [1.0, 2.0], [3.0, 4.0, 5.0], check_inferred=false) test_rrule(map, atan, [1, 2, 3.0], [4 5; 6 7.0], check_inferred=false) diff --git a/test/rulesets/Base/iterators.jl b/test/rulesets/Base/iterators.jl index d9060d985..8e22c8ba8 100644 --- a/test/rulesets/Base/iterators.jl +++ b/test/rulesets/Base/iterators.jl @@ -10,7 +10,7 @@ y2, bk2 = rrule(CFG, collect, Iterators.map(Counter(), [11, 12, 13.0])) @test y2 == map(Counter(), 11:13) - @test bk2(ones(3))[2].iter == [93, 83, 73] + @test bk2(ones(3))[2].iter == [33, 23, 13] end end @@ -23,4 +23,4 @@ end test_rrule(collect∘zip, rand(3), rand(5)) test_rrule(collect∘zip, rand(3,2), rand(5)) end -end \ No newline at end of file +end