-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #704 from LuxDL/ap/loss_functions
Loss functions module
- Loading branch information
Showing
22 changed files
with
1,260 additions
and
122 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "Lux" | ||
uuid = "b2108857-7c20-44ae-9111-449ecde12c47" | ||
authors = ["Avik Pal <[email protected]> and contributors"] | ||
version = "0.5.55" | ||
version = "0.5.56" | ||
|
||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
|
@@ -16,6 +16,7 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" | |
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" | ||
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" | ||
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" | ||
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" | ||
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" | ||
|
@@ -27,6 +28,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" | |
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Reexport = "189a3867-3050-52da-a836-e630ba90ab69" | ||
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" | ||
|
||
[weakdeps] | ||
|
@@ -89,6 +91,7 @@ Functors = "0.4.10" | |
GPUArraysCore = "0.1.6" | ||
LinearAlgebra = "1.10" | ||
Logging = "1.10" | ||
LossFunctions = "0.11.1" | ||
LuxCore = "0.1.14" | ||
LuxDeviceUtils = "0.1.22" | ||
LuxLib = "0.3.23" | ||
|
@@ -99,6 +102,7 @@ MacroTools = "0.5.13" | |
Markdown = "1.10" | ||
NCCL = "0.1.1" | ||
OhMyThreads = "0.5.1" | ||
OneHotArrays = "0.2.5" | ||
Optimisers = "0.3" | ||
Pkg = "1.10" | ||
PrecompileTools = "1.2" | ||
|
@@ -130,6 +134,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | |
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" | ||
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" | ||
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" | ||
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" | ||
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" | ||
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" | ||
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" | ||
|
@@ -142,4 +147,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" | |
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[targets] | ||
test = ["Aqua", "ComponentArrays", "Documenter", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "ForwardDiff", "Logging", "LuxTestUtils", "MLUtils", "Optimisers", "Pkg", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"] | ||
test = ["Aqua", "ComponentArrays", "Documenter", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "ForwardDiff", "Logging", "LuxTestUtils", "MLUtils", "OneHotArrays", "Optimisers", "Pkg", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
f7b9539
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Benchmark Results
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128)
3864.625
ns3683.125
ns1.05
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128)
7197.4
ns7288.666666666667
ns0.99
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128)
20859
ns20909
ns1.00
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128)
9690.5
ns9847.3
ns0.98
Dense(2 => 2)/cpu/reverse/Flux/(2, 128)
8960.6
ns9238.375
ns0.97
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128)
4472.125
ns4527.125
ns0.99
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128)
1152.2189781021898
ns1168.5407407407408
ns0.99
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128)
1166.719696969697
ns1176.1526717557251
ns0.99
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128)
1178.348148148148
ns1186.4857142857143
ns0.99
Dense(2 => 2)/cpu/forward/Flux/(2, 128)
1791.0849056603774
ns1782.859375
ns1.00
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128)
179.68464730290455
ns179.37413073713492
ns1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128)
17282
ns17342
ns1.00
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128)
16802
ns17022
ns0.99
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128)
38952
ns37380
ns1.04
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128)
29225
ns29484.5
ns0.99
Dense(20 => 20)/cpu/reverse/Flux/(20, 128)
19767
ns21770
ns0.91
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128)
17187.5
ns17477.5
ns0.98
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128)
4339.571428571428
ns4316.571428571428
ns1.01
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128)
3863.375
ns3864.625
ns1.00
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128)
3932.375
ns3923.5
ns1.00
Dense(20 => 20)/cpu/forward/Flux/(20, 128)
4904.857142857143
ns4809
ns1.02
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128)
1662.1
ns1660.1
ns1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128)
46506004.5
ns39311146
ns1.18
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128)
57802889
ns57818439
ns1.00
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128)
110407401
ns70725143
ns1.56
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128)
106701454
ns89020101
ns1.20
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128)
107083940.5
ns72846612
ns1.47
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128)
11944033.5
ns12056878.5
ns0.99
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128)
17826363
ns17802524.5
ns1.00
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128)
7025065.5
ns7028063
ns1.00
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128)
6982537
ns7000092.5
ns1.00
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128)
18426786.5
ns9924699
ns1.86
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128)
6380584
ns6389608
ns1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16)
730954388
ns737562829
ns0.99
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64)
2547466756
ns2545549640
ns1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2)
143281222.5
ns146821325
ns0.98
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16)
907944434
ns868615027
ns1.05
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64)
3408151568
ns3064060217
ns1.11
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2)
232691070
ns219512795
ns1.06
vgg16/cpu/reverse/Flux/(32, 32, 3, 16)
701405563.5
ns685678726
ns1.02
vgg16/cpu/reverse/Flux/(32, 32, 3, 64)
2751192398
ns2574375943
ns1.07
vgg16/cpu/reverse/Flux/(32, 32, 3, 2)
150091812
ns127147427
ns1.18
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16)
174202338.5
ns171884482
ns1.01
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64)
651482390
ns650293250.5
ns1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2)
45292469
ns34511836
ns1.31
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16)
164217014
ns164391167.5
ns1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64)
640363054
ns634653416
ns1.01
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2)
30140396.5
ns29977086.5
ns1.01
vgg16/cpu/forward/Flux/(32, 32, 3, 16)
210450508
ns185946798
ns1.13
vgg16/cpu/forward/Flux/(32, 32, 3, 64)
812675116
ns765662897.5
ns1.06
vgg16/cpu/forward/Flux/(32, 32, 3, 2)
37414452
ns35241726.5
ns1.06
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128)
1319067230.5
ns1245538918.5
ns1.06
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128)
1855157553
ns1864879281
ns0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128)
2389725378
ns2293551179
ns1.04
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128)
2576211037
ns2516850614
ns1.02
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128)
1979708272
ns1882887952.5
ns1.05
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128)
555898503
ns561045265
ns0.99
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128)
316191735
ns326179109
ns0.97
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128)
314504058
ns323271956
ns0.97
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128)
465777211
ns349888101
ns1.33
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128)
11830345.5
ns11973548
ns0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128)
18031514
ns17858872
ns1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128)
19168195
ns19168560
ns1.00
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128)
23938562
ns23865197
ns1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128)
18050494
ns17866720
ns1.01
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128)
1154605
ns1158234
ns1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128)
5872019
ns5814007
ns1.01
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128)
2056057
ns2054540.5
ns1.00
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128)
2037893
ns2037248
ns1.00
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128)
2074361
ns2078324
ns1.00
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128)
200504
ns202510.5
ns0.99
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128)
292576
ns293437.5
ns1.00
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128)
265335
ns266057.5
ns1.00
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128)
365001
ns365572
ns1.00
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128)
408391.5
ns407804
ns1.00
Dense(200 => 200)/cpu/reverse/Flux/(200, 128)
383161
ns275034
ns1.39
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128)
407973
ns411080
ns0.99
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128)
83536
ns83504
ns1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128)
81562
ns81180.5
ns1.00
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128)
81833
ns81631
ns1.00
Dense(200 => 200)/cpu/forward/Flux/(200, 128)
86140
ns86775.5
ns0.99
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128)
104836
ns104563
ns1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128)
207388132
ns203633792
ns1.02
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128)
324041362
ns328082047.5
ns0.99
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128)
435241619.5
ns399733123
ns1.09
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128)
489451804
ns429567326
ns1.14
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128)
406672717.5
ns375921768
ns1.08
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128)
320495841
ns328704380
ns0.98
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128)
101394771.5
ns101203246
ns1.00
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128)
43915846
ns43990642
ns1.00
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128)
43768021
ns43821294.5
ns1.00
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128)
70277761.5
ns53275150
ns1.32
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128)
29141934
ns28607335
ns1.02
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128)
18724535.5
ns19166105
ns0.98
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128)
19456819
ns19549447.5
ns1.00
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128)
22919966.5
ns23387251
ns0.98
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128)
23955208
ns24155491
ns0.99
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128)
19535589
ns19735654
ns0.99
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128)
6537723
ns6562123
ns1.00
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128)
6514811
ns6547446.5
ns1.00
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128)
6477278
ns6511687
ns0.99
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128)
6514184
ns6536680
ns1.00
This comment was automatically generated by workflow using github-action-benchmark.