Loss functions module #704
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #704 +/- ##
==========================================
+ Coverage 87.27% 87.70% +0.42%
==========================================
Files 50 51 +1
Lines 2522 2667 +145
==========================================
+ Hits 2201 2339 +138
- Misses 321 328 +7
Benchmark Results
| Benchmark suite | Current: 514414e | Previous: f1b8c12 | Ratio |
|---|---|---|---|
| Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) | 3631.875 ns | 3683.125 ns | 0.99 |
| Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) | 7158.416666666666 ns | 7288.666666666667 ns | 0.98 |
| Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) | 21069 ns | 20909 ns | 1.01 |
| Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) | 9812.2 ns | 9847.3 ns | 1.00 |
| Dense(2 => 2)/cpu/reverse/Flux/(2, 128) | 9049.5 ns | 9238.375 ns | 0.98 |
| Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) | 4484.625 ns | 4527.125 ns | 0.99 |
| Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) | 1157.127659574468 ns | 1168.5407407407408 ns | 0.99 |
| Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) | 1116.4415584415585 ns | 1176.1526717557251 ns | 0.95 |
| Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) | 1183.610294117647 ns | 1186.4857142857143 ns | 1.00 |
| Dense(2 => 2)/cpu/forward/Flux/(2, 128) | 1779.7796610169491 ns | 1782.859375 ns | 1.00 |
| Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) | 179.87760778859527 ns | 179.37413073713492 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) | 17323 ns | 17342 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) | 16862 ns | 17022 ns | 0.99 |
| Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) | 38976 ns | 37380 ns | 1.04 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) | 29114 ns | 29484.5 ns | 0.99 |
| Dense(20 => 20)/cpu/reverse/Flux/(20, 128) | 20047 ns | 21770 ns | 0.92 |
| Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) | 17363 ns | 17477.5 ns | 0.99 |
| Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) | 4336.714285714285 ns | 4316.571428571428 ns | 1.00 |
| Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) | 3858.625 ns | 3864.625 ns | 1.00 |
| Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) | 3947.375 ns | 3923.5 ns | 1.01 |
| Dense(20 => 20)/cpu/forward/Flux/(20, 128) | 4824.714285714285 ns | 4809 ns | 1.00 |
| Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) | 1657.1 ns | 1660.1 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) | 40731666.5 ns | 39311146 ns | 1.04 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) | 58221071 ns | 57818439 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) | 80486416 ns | 70725143 ns | 1.14 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) | 92289031.5 ns | 89020101 ns | 1.04 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) | 78152776.5 ns | 72846612 ns | 1.07 |
| Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) | 11765992 ns | 12056878.5 ns | 0.98 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) | 17979283 ns | 17802524.5 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) | 7047975 ns | 7028063 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) | 7004301 ns | 7000092.5 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) | 11596864 ns | 9924699 ns | 1.17 |
| Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) | 6405955 ns | 6389608 ns | 1.00 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) | 735355046 ns | 737562829 ns | 1.00 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) | 2573465965 ns | 2545549640 ns | 1.01 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) | 134117036 ns | 146821325 ns | 0.91 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) | 947130506 ns | 868615027 ns | 1.09 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) | 3290476025 ns | 3064060217 ns | 1.07 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) | 213218863.5 ns | 219512795 ns | 0.97 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 16) | 676460181 ns | 685678726 ns | 0.99 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 64) | 2904583601 ns | 2574375943 ns | 1.13 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 2) | 142754278 ns | 127147427 ns | 1.12 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) | 174721427.5 ns | 171884482 ns | 1.02 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) | 654617121.5 ns | 650293250.5 ns | 1.01 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) | 34640207.5 ns | 34511836 ns | 1.00 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) | 164818954 ns | 164391167.5 ns | 1.00 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) | 643255585 ns | 634653416 ns | 1.01 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) | 30219827 ns | 29977086.5 ns | 1.01 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 16) | 226063665 ns | 185946798 ns | 1.22 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 64) | 848679144.5 ns | 765662897.5 ns | 1.11 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 2) | 38429907 ns | 35241726.5 ns | 1.09 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) | 1245926200.5 ns | 1245538918.5 ns | 1.00 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) | 1878249003.5 ns | 1864879281 ns | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) | 2518204632 ns | 2293551179 ns | 1.10 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) | 2568126221 ns | 2516850614 ns | 1.02 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) | 1926214731 ns | 1882887952.5 ns | 1.02 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) | 562169008 ns | 561045265 ns | 1.00 |
| Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) | 323670224 ns | 326179109 ns | 0.99 |
| Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) | 324197610 ns | 323271956 ns | 1.00 |
| Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) | 444007505.5 ns | 349888101 ns | 1.27 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) | 11901404 ns | 11973548 ns | 0.99 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) | 17929687 ns | 17858872 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) | 19146791.5 ns | 19168560 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) | 23901598 ns | 23865197 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) | 17924815.5 ns | 17866720 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) | 1158931 ns | 1158234 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) | 5803386 ns | 5814007 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) | 2048048 ns | 2054540.5 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) | 2028681 ns | 2037248 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) | 2074240 ns | 2078324 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) | 197480 ns | 202510.5 ns | 0.98 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) | 291206 ns | 293437.5 ns | 0.99 |
| Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) | 265427 ns | 266057.5 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) | 364433 ns | 365572 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) | 407153 ns | 407804 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/Flux/(200, 128) | 273222 ns | 275034 ns | 0.99 |
| Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) | 405900 ns | 411080 ns | 0.99 |
| Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) | 83236 ns | 83504 ns | 1.00 |
| Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) | 80932 ns | 81180.5 ns | 1.00 |
| Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) | 81433 ns | 81631 ns | 1.00 |
| Dense(200 => 200)/cpu/forward/Flux/(200, 128) | 86492 ns | 86775.5 ns | 1.00 |
| Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) | 104586 ns | 104563 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) | 208181727 ns | 203633792 ns | 1.02 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) | 328707483 ns | 328082047.5 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) | 422656671 ns | 399733123 ns | 1.06 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) | 481664801 ns | 429567326 ns | 1.12 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) | 391857473 ns | 375921768 ns | 1.04 |
| Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) | 325505169.5 ns | 328704380 ns | 0.99 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) | 101123238 ns | 101203246 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) | 43706316.5 ns | 43990642 ns | 0.99 |
| Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) | 43416719 ns | 43821294.5 ns | 0.99 |
| Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) | 67806135 ns | 53275150 ns | 1.27 |
| Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) | 28142932.5 ns | 28607335 ns | 0.98 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) | 18738802 ns | 19166105 ns | 0.98 |
| Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) | 19518289 ns | 19549447.5 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) | 23182711 ns | 23387251 ns | 0.99 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) | 24062627 ns | 24155491 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) | 19576555 ns | 19735654 ns | 0.99 |
| Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) | 6499610.5 ns | 6562123 ns | 0.99 |
| Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) | 6520696.5 ns | 6547446.5 ns | 1.00 |
| Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) | 6495478.5 ns | 6511687 ns | 1.00 |
| Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) | 6489931 ns | 6536680 ns | 0.99 |
This comment was automatically generated by workflow using github-action-benchmark.
Users have asked for this frequently.
We deviate from Flux's API of providing plain functions and instead follow the PyTorch style of constructing a loss object and then calling it (like LossFunctions.jl). This lets us fix the common keyword arguments once, at construction time.

TODO
- xlogx / xlogy implementations
- @test_adjoint
- @jet
- Experimental module to wrap the loss functions into an API compatible with the training API.
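
To illustrate the construct-then-call design described above, here is a minimal Julia sketch. It is not the code from this PR: the names `MSELossSketch`, `CrossEntropySketch`, `xlogy_sketch`, and the `agg` keyword are hypothetical and chosen only for illustration. It assumes the shared keyword argument is an aggregation function and that `xlogy` follows the usual convention of returning zero when `x` is zero.

```julia
using Statistics: mean

# Hypothetical helper: x * log(y), defined as 0 when x == 0 so that
# terms such as 0 * log(0) contribute 0 instead of NaN.
function xlogy_sketch(x::Number, y::Number)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

# Hypothetical loss object. The shared keyword argument (here only `agg`)
# is fixed once, when the object is constructed.
struct MSELossSketch{F}
    agg::F
end
MSELossSketch(; agg=mean) = MSELossSketch(agg)

# Calling the object computes the loss with the stored settings.
(loss::MSELossSketch)(ypred, y) = loss.agg(abs2.(ypred .- y))

# A second hypothetical loss that relies on the xlogy helper.
# Classes are assumed to lie along the first dimension.
struct CrossEntropySketch{F}
    agg::F
end
CrossEntropySketch(; agg=mean) = CrossEntropySketch(agg)

function (loss::CrossEntropySketch)(ypred, y)
    per_sample = -sum(xlogy_sketch.(y, ypred); dims=1)
    return loss.agg(per_sample)
end

# Usage: construct once, then call like a function.
mse_mean = MSELossSketch()
mse_sum = MSELossSketch(; agg=sum)
println(mse_mean([1.0, 2.0], [1.0, 4.0]))
println(mse_sum([1.0, 2.0], [1.0, 4.0]))
```

With plain functions, every call site would have to repeat `agg=sum`. With the object style, the setting travels with the loss value, which also makes it easy to pass a configured loss into a training loop.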