1- tup2 (x) = Tuple {Any,Any} (x) # temp fix for Diffractor, https://github.com/JuliaDiff/Diffractor.jl/pull/86
2-
31# ####
42# #### Comprehension: Iterators.map
53# ####
@@ -8,7 +6,7 @@ tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor, https://github.com/Julia
86
97function rrule (cfg:: RuleConfig{>:HasReverseMode} , :: typeof (collect), gen:: G ) where {G<: Base.Generator }
108 @debug " collect generator"
11- ys, backs = unzip_map (x -> rrule_via_ad (cfg, gen. f, x)|> tup2 , gen. iter)
9+ ys, backs = unzip_map (x -> rrule_via_ad (cfg, gen. f, x), gen. iter)
1210 proj_f = ProjectTo (gen. f)
1311 proj_iter = ProjectTo (gen. iter)
1412 function generator_pullback (dys_raw)
@@ -28,8 +26,8 @@ ChainRulesCore.rrule(::Type{<:Iterators.ProductIterator}, iters) = Iterators.Pro
2826 Yota.grad(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3)
2927Diffractor.gradient(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3)
3028
31- Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5])
32- Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5])
29+ Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: all field arrays must have same shape
30+ Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: type Array has no field iterators
3331
3432 Yota.grad(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3)
3533Diffractor.gradient(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) # fails internally
@@ -44,11 +42,10 @@ Diffractor.gradient(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3
4442@btime Yota.grad($(rand(1000))) do xs
4543 sum(abs2, [sqrt(x) for x in xs])
4644end
47- # Yota min 1.134 ms, mean 1.207 ms (22017 allocations, 548.50 KiB)
48- # Diffractor min 936.708 μs, mean 1.020 ms (18028 allocations, 611.25 KiB)
49- # without unzip_map min 734.292 μs, mean 810.341 μs (13063 allocations, 517.97 KiB)
45+ # Yota min 759.000 μs, mean 800.754 μs (22041 allocations, 549.62 KiB)
46+ # Diffractor min 559.000 μs, mean 622.464 μs (18051 allocations, 612.34 KiB)
5047
51- # Zygote min 6.117 μs, mean 11.287 μs (24 allocations, 40.31 KiB)
48+ # Zygote min 3.198 μs, mean 6.849 μs (20 allocations, 40.11 KiB)
5249
5350
5451@btime Yota.grad($(rand(1000)), $(rand(1000))) do xs, ys
5754 end
5855 sum(abs2, zs)
5956end
60- # Yota + CR: min 2.643 ms, mean 2.781 ms (35011 allocations, 915.19 KiB)
61- # Diffractor + CR: min 1.184 ms, mean 1.285 ms (23026 allocations, 775.09 KiB)
62- # without unzip_map min 947.084 μs, mean 1.036 ms (18062 allocations, 697.86 KiB)
57+ # Yota + CR: min 1.598 ms, mean 1.691 ms (38030 allocations, 978.75 KiB)
58+ # Diffractor + CR: min 767.250 μs, mean 847.640 μs (26045 allocations, 838.66 KiB)
6359
64- # Zygote: min 21.291 μs, mean 36.456 μs (26 allocations, 79.59 KiB)
60+ # Zygote: min 13.417 μs, mean 22.896 μs (26 allocations, 79.59 KiB) -- 100x faster
6561
6662
6763=#
0 commit comments