
Loss functions module #704

Merged: avik-pal merged 13 commits into main from ap/loss_functions on Jun 20, 2024

Conversation

@avik-pal (Member) commented Jun 16, 2024

This has been a popular request from users.

We deviate from Flux's API of providing plain loss functions and instead follow the PyTorch style (as in LossFunctions.jl) of constructing a loss object and then calling it. This lets us fix the common keyword arguments at construction time.
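As an illustration of the construct-then-call pattern described above, here is a minimal, self-contained sketch. The type name `MSELoss` and the `agg` keyword are placeholders chosen for this example; the exact API added in this PR may differ.

```julia
using Statistics: mean

# Hypothetical loss type illustrating the PyTorch/LossFunctions.jl-style API:
# the common keyword arguments (here, the aggregation function) are fixed when
# the loss object is constructed, and the object is then called like a function.
struct MSELoss{A}
    agg::A
end
MSELoss(; agg=mean) = MSELoss(agg)

# Calling the object computes the loss using the stored settings.
(l::MSELoss)(ŷ, y) = l.agg(abs2.(ŷ .- y))

# Usage: construct once, then call like a plain function.
loss = MSELoss()                  # defaults to mean aggregation
ŷ, y = rand(Float32, 4), rand(Float32, 4)
loss(ŷ, y)                        # scalar MSE
sum_loss = MSELoss(; agg=sum)     # same loss, different fixed kwarg
```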

TODO

  • xlogx / xlogy implementations (see the sketch after this list)
  • Common Loss Functions
  • Add loss functions to the API documentation
  • Reuse code from LossFunctions.jl
  • Add tests from https://github.com/FluxML/Flux.jl/blob/master/test/losses.jl
  • Additional Tests
    • @test_adjoint
    • @jet
  • Add functionality in the Experimental module to wrap the loss functions into an API compatible with the training API (see the wrapper sketch after this list)
    • Add docs for this
    • Update existing examples to use this
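For the xlogx / xlogy item, the usual trick is to return a zero of the right type when the first argument is zero, instead of the NaN that `0 * log(0)` would otherwise produce. A minimal sketch follows; it mirrors the common LogExpFunctions/Flux approach, and the implementations actually merged here may differ.

```julia
# Numerically safe x * log(x) and x * log(y): return zero when x == 0 instead
# of the NaN produced by 0 * (-Inf).
function xlogx(x)
    result = x * log(x)
    return ifelse(iszero(x), zero(result), result)
end

function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x) && !isnan(y), zero(result), result)
end

xlogx(0.0)           # 0.0 rather than NaN
xlogy(0.0f0, 0.0f0)  # 0.0f0, useful inside cross-entropy terms like y * log(ŷ)
```

For the Experimental training-API wrapper, the general shape is to close over a callable loss and produce an objective function of the form `objective(model, ps, st, data) -> (loss, st, stats)`. That signature is an assumption based on how Lux training objectives are usually written; the helper actually added in this PR may be named and structured differently.

```julia
# Hedged sketch: wrap a callable loss (like the MSELoss object above) into an
# objective function for a training loop that expects
# `objective(model, ps, st, data) -> (loss, st, stats)`.
function wrap_loss(loss_fn)
    function objective(model, ps, st, (x, y))
        ŷ, st_new = model(x, ps, st)   # Lux layers return (output, updated_state)
        return loss_fn(ŷ, y), st_new, (;)
    end
    return objective
end

# Usage: `objective = wrap_loss(MSELoss())`, then pass `objective` to the training step.
```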


codecov bot commented Jun 16, 2024

Codecov Report

Attention: Patch coverage is 94.48276% with 8 lines in your changes missing coverage. Please review.

Project coverage is 87.70%. Comparing base (f1b8c12) to head (514414e).

| Files | Patch % | Lines |
| --- | --- | --- |
| src/utils.jl | 85.41% | 7 Missing ⚠️ |
| src/helpers/losses.jl | 98.61% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #704      +/-   ##
==========================================
+ Coverage   87.27%   87.70%   +0.42%     
==========================================
  Files          50       51       +1     
  Lines        2522     2667     +145     
==========================================
+ Hits         2201     2339     +138     
- Misses        321      328       +7     


@github-actions (bot) left a comment

Benchmark Results

Benchmark suite Current: 514414e Previous: f1b8c12 Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3631.875 ns 3683.125 ns 0.99
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7158.416666666666 ns 7288.666666666667 ns 0.98
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 21069 ns 20909 ns 1.01
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9812.2 ns 9847.3 ns 1.00
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 9049.5 ns 9238.375 ns 0.98
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4484.625 ns 4527.125 ns 0.99
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1157.127659574468 ns 1168.5407407407408 ns 0.99
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1116.4415584415585 ns 1176.1526717557251 ns 0.95
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1183.610294117647 ns 1186.4857142857143 ns 1.00
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1779.7796610169491 ns 1782.859375 ns 1.00
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 179.87760778859527 ns 179.37413073713492 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17323 ns 17342 ns 1.00
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 16862 ns 17022 ns 0.99
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 38976 ns 37380 ns 1.04
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 29114 ns 29484.5 ns 0.99
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 20047 ns 21770 ns 0.92
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17363 ns 17477.5 ns 0.99
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4336.714285714285 ns 4316.571428571428 ns 1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3858.625 ns 3864.625 ns 1.00
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3947.375 ns 3923.5 ns 1.01
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4824.714285714285 ns 4809 ns 1.00
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1657.1 ns 1660.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 40731666.5 ns 39311146 ns 1.04
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 58221071 ns 57818439 ns 1.01
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 80486416 ns 70725143 ns 1.14
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 92289031.5 ns 89020101 ns 1.04
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 78152776.5 ns 72846612 ns 1.07
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11765992 ns 12056878.5 ns 0.98
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 17979283 ns 17802524.5 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7047975 ns 7028063 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 7004301 ns 7000092.5 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 11596864 ns 9924699 ns 1.17
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6405955 ns 6389608 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 735355046 ns 737562829 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2573465965 ns 2545549640 ns 1.01
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 134117036 ns 146821325 ns 0.91
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 947130506 ns 868615027 ns 1.09
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 3290476025 ns 3064060217 ns 1.07
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 213218863.5 ns 219512795 ns 0.97
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 676460181 ns 685678726 ns 0.99
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2904583601 ns 2574375943 ns 1.13
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 142754278 ns 127147427 ns 1.12
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 174721427.5 ns 171884482 ns 1.02
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 654617121.5 ns 650293250.5 ns 1.01
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 34640207.5 ns 34511836 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 164818954 ns 164391167.5 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 643255585 ns 634653416 ns 1.01
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 30219827 ns 29977086.5 ns 1.01
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 226063665 ns 185946798 ns 1.22
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 848679144.5 ns 765662897.5 ns 1.11
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 38429907 ns 35241726.5 ns 1.09
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1245926200.5 ns 1245538918.5 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1878249003.5 ns 1864879281 ns 1.01
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2518204632 ns 2293551179 ns 1.10
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2568126221 ns 2516850614 ns 1.02
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1926214731 ns 1882887952.5 ns 1.02
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 562169008 ns 561045265 ns 1.00
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 323670224 ns 326179109 ns 0.99
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 324197610 ns 323271956 ns 1.00
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 444007505.5 ns 349888101 ns 1.27
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11901404 ns 11973548 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 17929687 ns 17858872 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19146791.5 ns 19168560 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23901598 ns 23865197 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 17924815.5 ns 17866720 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1158931 ns 1158234 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 5803386 ns 5814007 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2048048 ns 2054540.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2028681 ns 2037248 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2074240 ns 2078324 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 197480 ns 202510.5 ns 0.98
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 291206 ns 293437.5 ns 0.99
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 265427 ns 266057.5 ns 1.00
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 364433 ns 365572 ns 1.00
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 407153 ns 407804 ns 1.00
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 273222 ns 275034 ns 0.99
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 405900 ns 411080 ns 0.99
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 83236 ns 83504 ns 1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 80932 ns 81180.5 ns 1.00
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 81433 ns 81631 ns 1.00
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 86492 ns 86775.5 ns 1.00
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104586 ns 104563 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 208181727 ns 203633792 ns 1.02
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 328707483 ns 328082047.5 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 422656671 ns 399733123 ns 1.06
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 481664801 ns 429567326 ns 1.12
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 391857473 ns 375921768 ns 1.04
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 325505169.5 ns 328704380 ns 0.99
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 101123238 ns 101203246 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 43706316.5 ns 43990642 ns 0.99
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 43416719 ns 43821294.5 ns 0.99
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 67806135 ns 53275150 ns 1.27
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28142932.5 ns 28607335 ns 0.98
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 18738802 ns 19166105 ns 0.98
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19518289 ns 19549447.5 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23182711 ns 23387251 ns 0.99
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24062627 ns 24155491 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19576555 ns 19735654 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6499610.5 ns 6562123 ns 0.99
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6520696.5 ns 6547446.5 ns 1.00
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6495478.5 ns 6511687 ns 1.00
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6489931 ns 6536680 ns 0.99

This comment was automatically generated by a workflow using github-action-benchmark.

@avik-pal linked an issue on Jun 16, 2024 that may be closed by this pull request.
@avik-pal force-pushed the ap/loss_functions branch 2 times, most recently from 68e1427 to c5489e1 (June 16, 2024 23:48)
src/losses/Losses.jl — review thread (outdated, resolved)
src/losses/utils.jl — review thread (outdated, resolved)
@avik-pal force-pushed the ap/loss_functions branch 2 times, most recently from 846501f to aa7497c (June 17, 2024 05:45)
@avik-pal force-pushed the ap/loss_functions branch from 762fd7a to 94f609e (June 19, 2024 04:11)
@avik-pal force-pushed the ap/loss_functions branch from 94f609e to 2009f7b (June 19, 2024 04:14)
@avik-pal force-pushed the ap/loss_functions branch 2 times, most recently from 077675f to 9f853a6 (June 19, 2024 07:01)
@avik-pal force-pushed the ap/loss_functions branch 2 times, most recently from b405093 to 3a43218 (June 19, 2024 16:29)
@avik-pal force-pushed the ap/loss_functions branch 2 times, most recently from 0ecf985 to 8119743 (June 19, 2024 22:05)
@avik-pal force-pushed the ap/loss_functions branch from 8119743 to c10aa78 (June 19, 2024 22:20)
@avik-pal force-pushed the ap/loss_functions branch from 478b931 to 9e94890 (June 20, 2024 02:14)
@avik-pal force-pushed the ap/loss_functions branch from 9e94890 to 638a176 (June 20, 2024 02:51)
@avik-pal force-pushed the ap/loss_functions branch from b73c193 to d1c8448 (June 20, 2024 03:28)
@avik-pal force-pushed the ap/loss_functions branch 2 times, most recently from d42594d to 97a290b (June 20, 2024 04:45)
@avik-pal force-pushed the ap/loss_functions branch from 97a290b to 514414e (June 20, 2024 04:57)
@avik-pal merged commit f7b9539 into main on Jun 20, 2024
51 of 53 checks passed
@avik-pal deleted the ap/loss_functions branch on June 20, 2024 05:47
Successfully merging this pull request may close these issues:

Documentation Request: Have a section about Loss Functions