feat: bidirectional RNN + debugging RNNs (#708)
* feat: add BidirectionalRNN
* fix: avoid lazy reverse
* fix: handle reverse for operator overloading AD
* fix: allow debug modes for recurrent layers
* fix: soa/aos handling for multigate

---------

Co-authored-by: Avik Pal <[email protected]>
1 parent 803e660, commit d9aa5a6
Showing 11 changed files with 200 additions and 20 deletions.
    @@ -52,6 +52,7 @@ LSTMCell
     RNNCell
     Recurrence
     StatefulRecurrentCell
    +BidirectionalRNN
     ```

     ## Linear Layers
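For readers unfamiliar with the layer that this hunk adds to the documented list, the following is a minimal, language-agnostic sketch of what a bidirectional RNN computes. It is not the Lux.jl implementation: `run_rnn`, `bidirectional`, and the toy `running_sum` cell are hypothetical names invented for illustration, and pairing the two outputs stands in for whatever merge the layer actually uses. It does show the reversal step that the "avoid lazy reverse" and "handle reverse" fixes in the commit message are concerned with, on the assumption that the backward pass consumes an eagerly reversed copy of the input.

```python
# Toy illustration only: NOT the Lux.jl implementation.
from typing import Callable, List, Tuple

# A "cell" here is any function (input, hidden) -> new hidden state.
Cell = Callable[[float, float], float]


def run_rnn(cell: Cell, xs: List[float], h0: float = 0.0) -> List[float]:
    """Apply the cell step by step and collect the hidden state at each step."""
    h = h0
    outputs = []
    for x in xs:
        h = cell(x, h)
        outputs.append(h)
    return outputs


def bidirectional(forward_cell: Cell, backward_cell: Cell,
                  xs: List[float]) -> List[Tuple[float, float]]:
    """Run one cell left to right and another right to left, then merge."""
    fwd = run_rnn(forward_cell, xs)
    # Materialize the reversed input eagerly instead of using a lazy view.
    bwd = run_rnn(backward_cell, list(reversed(xs)))
    # Reverse the backward outputs so index t refers to the same time step.
    bwd.reverse()
    # Merge per time step; pairing is a stand-in for e.g. concatenation.
    return list(zip(fwd, bwd))


def running_sum(x: float, h: float) -> float:
    """Hypothetical toy cell: the hidden state is the sum of inputs seen."""
    return h + x


merged = bidirectional(running_sum, running_sum, [1.0, 2.0, 3.0])
print(merged)
```

At each time step the first element summarizes the inputs up to that step and the second summarizes the inputs from that step to the end, which is the property a bidirectional layer is meant to provide.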
Benchmark Results
| Benchmark suite | Current | Previous | Ratio |
|---|---|---|---|
| Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) | 3643 ns | 3668.125 ns | 0.99 |
| Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) | 7176 ns | 7265.333333333333 ns | 0.99 |
| Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) | 20829 ns | 21330 ns | 0.98 |
| Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) | 9808.2 ns | 9776.4 ns | 1.00 |
| Dense(2 => 2)/cpu/reverse/Flux/(2, 128) | 8944 ns | 9087 ns | 0.98 |
| Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) | 4470.875 ns | 4558.625 ns | 0.98 |
| Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) | 1152.5869565217392 ns | 1163.4855072463768 ns | 0.99 |
| Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) | 1112.0065789473683 ns | 1119.0544871794873 ns | 0.99 |
| Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) | 1169.9583333333333 ns | 1180.2686567164178 ns | 0.99 |
| Dense(2 => 2)/cpu/forward/Flux/(2, 128) | 1773.142857142857 ns | 1766.6949152542372 ns | 1.00 |
| Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) | 179.55182072829132 ns | 178.72408963585434 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) | 17172 ns | 17182 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) | 16852 ns | 17062 ns | 0.99 |
| Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) | 36989 ns | 37560 ns | 0.98 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) | 28217.5 ns | 28904 ns | 0.98 |
| Dense(20 => 20)/cpu/reverse/Flux/(20, 128) | 19907 ns | 21470 ns | 0.93 |
| Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) | 17743 ns | 17412 ns | 1.02 |
| Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) | 4329.571428571428 ns | 4353.857142857143 ns | 0.99 |
| Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) | 3848.375 ns | 3851 ns | 1.00 |
| Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) | 3965 ns | 3946.125 ns | 1.00 |
| Dense(20 => 20)/cpu/forward/Flux/(20, 128) | 4887.857142857143 ns | 4869.142857142857 ns | 1.00 |
| Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) | 1664.1 ns | 1652.1 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) | 38592071 ns | 48481192 ns | 0.80 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) | 57598840.5 ns | 57589911 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) | 70928337 ns | 112051236 ns | 0.63 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) | 88908288 ns | 107350078 ns | 0.83 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) | 72637494 ns | 107753804 ns | 0.67 |
| Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) | 11652507 ns | 11633974 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) | 6961815 ns | 8368822 ns | 0.83 |
| Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) | 7118503 ns | 6994384 ns | 1.02 |
| Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) | 7056081 ns | 6961674 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) | 10185966 ns | 18289304 ns | 0.56 |
| Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) | 6384646.5 ns | 6377896.5 ns | 1.00 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) | 693015046 ns | 708146792 ns | 0.98 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) | 2564503502 ns | 2538464789 ns | 1.01 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) | 141044477 ns | 130550753 ns | 1.08 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) | 755933134 ns | 940215170 ns | 0.80 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) | 2887970897 ns | 3222185770 ns | 0.90 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) | 209654834 ns | 200541349 ns | 1.05 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 16) | 652305217 ns | 725094261.5 ns | 0.90 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 64) | 2592391657 ns | 2700272370 ns | 0.96 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 2) | 123578859 ns | 133034897.5 ns | 0.93 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) | 175231208.5 ns | 174289604.5 ns | 1.01 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) | 652645593.5 ns | 656945864.5 ns | 0.99 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) | 45509251 ns | 45333461 ns | 1.00 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) | 164808995.5 ns | 164456981 ns | 1.00 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) | 643429956 ns | 639772875 ns | 1.01 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) | 30232427 ns | 30105099 ns | 1.00 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 16) | 187781212 ns | 230482052 ns | 0.81 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 64) | 711012026.5 ns | 896237645 ns | 0.79 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 2) | 36478557.5 ns | 39999991 ns | 0.91 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) | 1238395972.5 ns | 1229029098.5 ns | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) | 1864792034.5 ns | 1857882986.5 ns | 1.00 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) | 2335459386 ns | 2500640456 ns | 0.93 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) | 2510833048 ns | 2705526315 ns | 0.93 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) | 1834047291 ns | 1857087780.5 ns | 0.99 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) | 327114030.5 ns | 354871769 ns | 0.92 |
| Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) | 326897983 ns | 321331574 ns | 1.02 |
| Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) | 322709878 ns | 319443241 ns | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) | 426753786 ns | 365388358 ns | 1.17 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) | 11697996.5 ns | 11707792 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) | 17846230 ns | 17872220 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) | 19079864 ns | 19026670 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) | 23775806 ns | 23710928 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) | 17783235 ns | 17883894 ns | 0.99 |
| Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) | 1158115.5 ns | 1151164 ns | 1.01 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) | 2115024 ns | 2521988.5 ns | 0.84 |
| Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) | 2128008 ns | 2058845.5 ns | 1.03 |
| Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) | 2080563.5 ns | 2039299 ns | 1.02 |
| Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) | 2067390 ns | 2075005.5 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) | 198766.5 ns | 196078 ns | 1.01 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) | 293619 ns | 291156 ns | 1.01 |
| Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) | 265135.5 ns | 264076 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) | 365413.5 ns | 364333.5 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) | 408393 ns | 406242 ns | 1.01 |
| Dense(200 => 200)/cpu/reverse/Flux/(200, 128) | 275595 ns | 273453 ns | 1.01 |
| Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) | 408263 ns | 406201 ns | 1.01 |
| Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) | 83396 ns | 83697 ns | 1.00 |
| Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) | 81152 ns | 81232 ns | 1.00 |
| Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) | 82895 ns | 81733 ns | 1.01 |
| Dense(200 => 200)/cpu/forward/Flux/(200, 128) | 86902 ns | 86647.5 ns | 1.00 |
| Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) | 104425 ns | 104416 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) | 197729486 ns | 192747057.5 ns | 1.03 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) | 326930373 ns | 327202184 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) | 398584711 ns | 449880668 ns | 0.89 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) | 443775723.5 ns | 474869215 ns | 0.93 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) | 347290902 ns | 412132513.5 ns | 0.84 |
| Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) | 322621922 ns | 322865631.5 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) | 44383622 ns | 51483926 ns | 0.86 |
| Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) | 44477990 ns | 43968482 ns | 1.01 |
| Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) | 44018296 ns | 43749705 ns | 1.01 |
| Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) | 53487940 ns | 70756328 ns | 0.76 |
| Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) | 27983719 ns | 28106570 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) | 18934728.5 ns | 18818603 ns | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) | 19606172 ns | 19497998 ns | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) | 23436883.5 ns | 23422262 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) | 24235917 ns | 24150485 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) | 19708092.5 ns | 19687573 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) | 6514418 ns | 6494414 ns | 1.00 |
| Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) | 6534616 ns | 6523056 ns | 1.00 |
| Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) | 6507885 ns | 6473309 ns | 1.01 |
| Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) | 6515966 ns | 6503950 ns | 1.00 |
This comment was automatically generated by workflow using github-action-benchmark.
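
The third value reported for each benchmark is consistent with the first timing divided by the second, rounded to two decimals. Assuming the first timing belongs to this commit and the second to the run it is compared against, a value below 1 means this commit was faster. The sketch below recomputes that value for three entries copied from the results; `ratio` is a hypothetical helper name, not part of github-action-benchmark.

```python
# Sketch under stated assumptions: the first timing is the current run and
# the second is the previous run; `ratio` is a hypothetical helper.

def ratio(current_ns: float, previous_ns: float) -> float:
    """Return current / previous rounded to two decimals."""
    return round(current_ns / previous_ns, 2)


# (first timing, second timing) in nanoseconds, copied from the results.
samples = {
    "Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128)":
        (3643, 3668.125),
    "Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128)":
        (10185966, 18289304),
    "Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128)":
        (426753786, 365388358),
}

ratios = {name: ratio(cur, prev) for name, (cur, prev) in samples.items()}
for name, value in ratios.items():
    direction = "faster" if value < 1 else "slower or unchanged"
    print(f"{name}: {value:.2f} ({direction})")
```

The recomputed values match the reported 0.99, 0.56, and 1.17 for these three entries.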