diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index c385e87ea..988555f6b 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -257,7 +257,7 @@ end function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F} @debug "rrule(map, f, arrays...)" f - z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...)|>tup2, x, ys...) + z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...), x, ys...) function map_pullback_2(dz) df, dxy... = unzip_map_reversed(|>, unthunk(dz), backs) return (NoTangent(), ProjectTo(f)(sum(df)), map(_unmap_pad, (x, ys...), dxy)...) diff --git a/src/rulesets/Base/iterators.jl b/src/rulesets/Base/iterators.jl index f264b53d0..1ec727a36 100644 --- a/src/rulesets/Base/iterators.jl +++ b/src/rulesets/Base/iterators.jl @@ -1,5 +1,3 @@ -tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor, https://github.com/JuliaDiff/Diffractor.jl/pull/86 - ##### ##### Comprehension: Iterators.map ##### @@ -8,7 +6,7 @@ tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor, https://github.com/Julia function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) where {G<:Base.Generator} @debug "collect generator" - ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter) + ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x), gen.iter) proj_f = ProjectTo(gen.f) proj_iter = ProjectTo(gen.iter) function generator_pullback(dys_raw) @@ -28,8 +26,8 @@ ChainRulesCore.rrule(::Type{<:Iterators.ProductIterator}, iters) = Iterators.Pro Yota.grad(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3) Diffractor.gradient(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3) - Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) -Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) + Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: all field arrays must have same shape +Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: type Array has no field iterators Yota.grad(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) Diffractor.gradient(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) # fails internally @@ -44,11 +42,10 @@ Diffractor.gradient(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3 @btime Yota.grad($(rand(1000))) do xs sum(abs2, [sqrt(x) for x in xs]) end -# Yota min 1.134 ms, mean 1.207 ms (22017 allocations, 548.50 KiB) -# Diffractor min 936.708 μs, mean 1.020 ms (18028 allocations, 611.25 KiB) -# without unzip_map min 734.292 μs, mean 810.341 μs (13063 allocations, 517.97 KiB) +# Yota min 759.000 μs, mean 800.754 μs (22041 allocations, 549.62 KiB) +# Diffractor min 559.000 μs, mean 622.464 μs (18051 allocations, 612.34 KiB) -# Zygote min 6.117 μs, mean 11.287 μs (24 allocations, 40.31 KiB) +# Zygote min 3.198 μs, mean 6.849 μs (20 allocations, 40.11 KiB) @btime Yota.grad($(rand(1000)), $(rand(1000))) do xs, ys @@ -57,11 +54,10 @@ end end sum(abs2, zs) end -# Yota + CR: min 2.643 ms, mean 2.781 ms (35011 allocations, 915.19 KiB) -# Diffractor + CR: min 1.184 ms, mean 1.285 ms (23026 allocations, 775.09 KiB) -# without unzip_map min 947.084 μs, mean 1.036 ms (18062 allocations, 697.86 KiB) +# Yota + CR: min 1.598 ms, mean 1.691 ms (38030 allocations, 978.75 KiB) +# Diffractor + CR: min 767.250 μs, mean 847.640 μs (26045 allocations, 838.66 KiB) -# Zygote: min 21.291 μs, mean 36.456 μs (26 allocations, 79.59 KiB) +# Zygote: min 13.417 μs, mean 22.896 μs (26 allocations, 79.59 KiB) -- 100x faster =#