Skip to content

Commit

Permalink
Merge pull request #704 from LuxDL/ap/loss_functions
Browse files Browse the repository at this point in the history
Loss functions module
  • Loading branch information
avik-pal authored Jun 20, 2024
2 parents f1b8c12 + 514414e commit f7b9539
Show file tree
Hide file tree
Showing 22 changed files with 1,260 additions and 122 deletions.
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.5.55"
version = "0.5.56"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -16,6 +16,7 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
Expand All @@ -27,6 +28,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

[weakdeps]
Expand Down Expand Up @@ -89,6 +91,7 @@ Functors = "0.4.10"
GPUArraysCore = "0.1.6"
LinearAlgebra = "1.10"
Logging = "1.10"
LossFunctions = "0.11.1"
LuxCore = "0.1.14"
LuxDeviceUtils = "0.1.22"
LuxLib = "0.3.23"
Expand All @@ -99,6 +102,7 @@ MacroTools = "0.5.13"
Markdown = "1.10"
NCCL = "0.1.1"
OhMyThreads = "0.5.1"
OneHotArrays = "0.2.5"
Optimisers = "0.3"
Pkg = "1.10"
PrecompileTools = "1.2"
Expand Down Expand Up @@ -130,6 +134,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
Expand All @@ -142,4 +147,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ComponentArrays", "Documenter", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "ForwardDiff", "Logging", "LuxTestUtils", "MLUtils", "Optimisers", "Pkg", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
test = ["Aqua", "ComponentArrays", "Documenter", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "ForwardDiff", "Logging", "LuxTestUtils", "MLUtils", "OneHotArrays", "Optimisers", "Pkg", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
4 changes: 1 addition & 3 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ pages = [
"api/Lux/distributed_utils.md",
],
"Accelerator Support" => [
"api/Accelerator_Support/LuxCUDA.md",
"api/Accelerator_Support/LuxDeviceUtils.md"
],
"Building Blocks" => [
Expand Down Expand Up @@ -82,8 +81,7 @@ makedocs(; sitename="Lux.jl Documentation",
authors="Avik Pal et al.",
clean=true,
doctest=false, # We test it in the CI, no need to run it here
modules=[Lux, LuxCore, LuxLib, WeightInitializers,
Boltz, LuxTestUtils, LuxDeviceUtils, LuxCUDA],
modules=[Lux, LuxCore, LuxLib, WeightInitializers, Boltz, LuxTestUtils, LuxDeviceUtils],
linkcheck=true,
repo="https://github.com/LuxDL/Lux.jl/blob/{commit}{path}#{line}",
format=DocumenterVitepress.MarkdownVitepress(;
Expand Down
2 changes: 0 additions & 2 deletions docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ export default defineConfig({
},
{
text: 'Accelerator Support', items: [
{ text: 'LuxCUDA', link: '/api/Accelerator_Support/LuxCUDA' },
{ text: 'LuxDeviceUtils', link: '/api/Accelerator_Support/LuxDeviceUtils' }
]
},
Expand Down Expand Up @@ -196,7 +195,6 @@ export default defineConfig({
},
{
text: 'Accelerator Support', collapsed: false, items: [
{ text: 'LuxCUDA', link: '/api/Accelerator_Support/LuxCUDA' },
{ text: 'LuxDeviceUtils', link: '/api/Accelerator_Support/LuxDeviceUtils' }]
},
{
Expand Down
16 changes: 0 additions & 16 deletions docs/src/api/Accelerator_Support/LuxCUDA.md

This file was deleted.

2 changes: 1 addition & 1 deletion docs/src/api/Lux/autodiff.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Lux. Additionally, we provide some convenience functions for working with AD.
| [`ReverseDiff.jl`](https://github.com/JuliaDiff/ReverseDiff.jl) | ✔️ ||| Tier II |
| [`Tracker.jl`](https://github.com/FluxML/Tracker.jl) | ✔️ | ✔️ || Tier II |
| [`Enzyme.jl`](https://github.com/EnzymeAD/Enzyme.jl) | ✔️ |[^q] |[^q] | Tier II |
| [`Tapir.jl`](https://github.com/withbayes/Tapir.jl) |[^q] | [^q] || Tier IV |
| [`Tapir.jl`](https://github.com/withbayes/Tapir.jl) |[^q] | || Tier IV |
| [`Diffractor.jl`](https://github.com/JuliaDiff/Diffractor.jl) |[^q] |[^q] |[^q] | Tier IV |

[^q]: This feature is supported downstream, but we don't extensively test it to ensure
Expand Down
2 changes: 1 addition & 1 deletion docs/src/api/Lux/contrib.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ All features listed on this page are **experimental** which means:
Pages = ["contrib.md"]
```

## Training
## [Training](@id Training-API)

Helper Functions making it easier to train `Lux.jl` models.

Expand Down
51 changes: 43 additions & 8 deletions docs/src/api/Lux/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,38 @@
Pages = ["utilities.md"]
```

## Device Management / Data Transfer
## Loss Functions

```@docs
Lux.cpu
Lux.gpu
```
Loss Functions Objects take 2 forms of inputs:

1. $y\hat$ and $y$ where $y\hat$ is the predicted output and $y$ is the target output.
2. `model`, `ps`, `st`, `(x, y)` where `model` is the model, `ps` are the parameters,
`st` are the states and `(x, y)` are the input and target pair. Then it returns the
loss, updated states, and an empty named tuple. This makes them compatible with the
[Experimental Training API](@ref Training-API).

!!! warning

For detailed API documentation on Data Transfer check out the
[LuxDeviceUtils.jl](@ref LuxDeviceUtils-API)
When using ChainRules.jl compatible AD (like Zygote), we only compute the gradients
wrt the inputs and drop any gradients wrt the targets.

```@docs
GenericLossFunction
BinaryCrossEntropyLoss
BinaryFocalLoss
CrossEntropyLoss
DiceCoeffLoss
FocalLoss
HingeLoss
HuberLoss
KLDivergenceLoss
MAELoss
MSELoss
MSLELoss
PoissonLoss
SiameseContrastiveLoss
SquaredHingeLoss
```

## Weight Initialization

Expand All @@ -31,6 +52,8 @@ Lux.gpu
Lux.foldl_init
Lux.istraining
Lux.multigate
Lux.xlogy
Lux.xlogx
```

## Updating Floating Point Precision
Expand All @@ -56,8 +79,20 @@ StatefulLuxLayer
@compact
```

## Truncated Stacktraces
## Truncated Stacktraces (Deprecated)

```@docs
Lux.disable_stacktrace_truncation!
```

## Device Management / Data Transfer (Deprecated)

```@docs
Lux.cpu
Lux.gpu
```

!!! warning

For detailed API documentation on Data Transfer check out the
[LuxDeviceUtils.jl](@ref LuxDeviceUtils-API)
5 changes: 3 additions & 2 deletions examples/Basics/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ComponentArrays = "0.13, 0.14, 0.15"
ComponentArrays = "0.15"
ForwardDiff = "0.10"
Literate = "2"
Lux = "0.5"
Lux = "0.5.56"
LuxCUDA = "0.2, 0.3"
Optimisers = "0.2, 0.3"
Zygote = "0.6"
48 changes: 20 additions & 28 deletions examples/Basics/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,34 +294,26 @@ println("x shape: ", size(x_samples), "; y shape: ", size(y_samples))
# [Optimisers.jl](https://github.com/FluxML/Optimisers.jl). We will use Stochastic Gradient
# Descent (SGD) with a learning rate of `0.01`.

using Optimisers

opt = Optimisers.Descent(0.01f0)

# Initialize the initial state of the optimiser
opt_state = Optimisers.setup(opt, ps)
using Optimisers, Printf

# Define the loss function
function sse(model, ps, st, X, y)
y_pred, st_new = model(X, ps, st)
return sum(abs2, y_pred .- y), st_new
end
sse(weight, bias, X, y) = sum(abs2, weight * X .+ bias .- y)
loss_function(ps, X, y) = sse(model, ps, st, X, y)

println("Loss Value with ground true parameters: ", sse(W, b, x_samples, y_samples))

for i in 1:100
## In actual code, don't use globals. But here I will simply for the sake of
## demonstration
global ps, st, opt_state
## Compute the gradient using the pullback API to update the states
(loss, st), pb_f = Zygote.pullback(loss_function, ps, x_samples, y_samples)
## We pass nothing as the seed for `st`, since we don't want to propagate any gradient
## for st
gs = pb_f((one(loss), nothing))[1]
## Update model parameters
## `Optimisers.update` can be used if mutation is not desired
opt_state, ps = Optimisers.update!(opt_state, ps, gs)
(i % 10 == 1 || i == 100) && println(lazy"Loss Value after $i iterations: $loss")
lossfn = MSELoss()

println("Loss Value with ground true parameters: ", lossfn(W * x_samples .+ b, y_samples))

# We will train the model using our training API.
function train_model!(model, ps, st, opt, nepochs::Int)
tstate = Lux.Experimental.TrainState(model, ps, st, opt)
for i in 1:nepochs
grads, loss, _, tstate = Lux.Experimental.single_train_step!(
AutoZygote(), lossfn, (x_samples, y_samples), tstate)
if i % 1000 == 1 || i == nepochs
@printf "Loss Value after %6d iterations: %.8f\n" i loss
end
end
return tstate.model, tstate.parameters, tstate.states
end

model, ps, st = train_model!(model, ps, st, Descent(0.01f0), 10000)

println("Loss Value after training: ", lossfn(first(model(x_samples, ps, st)), y_samples))
9 changes: 2 additions & 7 deletions examples/ConvMixer/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,6 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
#! format: on
end

logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))

function loss(model, ps, st, (x, y))
y_pred, st = model(x, ps, st)
return logitcrossentropy(y_pred, y), st, (;)
end

function accuracy(model, ps, st, dataloader; dev=gpu_device())
total_correct, total = 0, 0
st = Lux.testmode(st)
Expand Down Expand Up @@ -94,6 +87,8 @@ end
lr_schedule = linear_interpolation(
[0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0])

loss = CrossEntropyLoss(; logits=Val(true))

for epoch in 1:epochs
stime = time()
lr = 0
Expand Down
6 changes: 4 additions & 2 deletions examples/DDIM/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,12 @@ function preprocess_image(image::Matrix{<:RGB}, image_size::Int)
return apply(CenterResizeCrop((image_size, image_size)), Image(image)) |> itemdata
end

const maeloss = MAELoss()

function loss_function(model, ps, st, data)
(noises, images, pred_noises, pred_images), st = Lux.apply(model, data, ps, st)
noise_loss = mean(abs, noises .- pred_noises)
image_loss = mean(abs, images .- pred_images)
noise_loss = maeloss(noises, pred_noises)
image_loss = maeloss(images, pred_images)
return noise_loss, st, (; image_loss, noise_loss)
end

Expand Down
7 changes: 4 additions & 3 deletions examples/GravitationalWaveForm/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ CUDA.allowscalar(false)

# We need a very crude 2-body path. Assume the 1-body motion is a newtonian 2-body position
# vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$, $r_2$ (e.g. Theoretical
# Mechanics of Particles and Continua 4.3)
# Mechanics of Particles and Continua 4.3)

function one2two(path, m₁, m₂)
M = m₁ + m₂
Expand Down Expand Up @@ -290,11 +290,12 @@ end

# Next, we define the objective (loss) function to be minimized when training the neural
# differential equations.
const mseloss = MSELoss()

function loss(θ)
pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
loss = sum(abs2, waveform .- pred_waveform)
return loss, pred_waveform
return mseloss(waveform, pred_waveform), pred_waveform
end

# Warmup the loss function
Expand Down
9 changes: 2 additions & 7 deletions examples/HyperNet/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,7 @@ function create_model()
end

# ## Define Utility Functions
logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))

function loss(model, ps, st, (data_idx, x, y))
y_pred, st = model((data_idx, x), ps, st)
return logitcrossentropy(y_pred, y), st, (;)
end
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader, data_idx, gdev=gpu_device())
total_correct, total = 0, 0
Expand Down Expand Up @@ -101,7 +96,7 @@ function train()
x = x |> dev
y = y |> dev
(_, _, _, train_state) = Lux.Experimental.single_train_step!(
AutoZygote(), loss, (data_idx, x, y), train_state)
AutoZygote(), loss, ((data_idx, x), y), train_state)
end
ttime = time() - stime

Expand Down
9 changes: 2 additions & 7 deletions examples/NeuralODE/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,7 @@ function create_model(model_fn=NeuralODE; dev=gpu_device(), use_named_tuple::Boo
end

# ## Define Utility Functions
logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))

function loss(model, ps, st, (x, y))
y_pred, st = model(x, ps, st)
return logitcrossentropy(y_pred, y), st, (;)
end
const logitcrossentropy = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader; dev=gpu_device())
total_correct, total = 0, 0
Expand Down Expand Up @@ -143,7 +138,7 @@ function train(model_function; cpu::Bool=false, kwargs...)
x = dev(x)
y = dev(y)
_, _, _, tstate = Lux.Experimental.single_train_step!(
AutoZygote(), loss, (x, y), tstate)
AutoZygote(), logitcrossentropy, (x, y), tstate)
end
ttime = time() - stime

Expand Down
9 changes: 3 additions & 6 deletions examples/PolynomialFitting/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,9 @@ opt = Adam(0.03f0)

# We will use the `Lux.Training` API so we need to ensure that our loss function takes 4
# inputs -- model, parameters, states and data. The function must return 3 values -- loss,
# updated_state, and any computed statistics.
function loss_function(model, ps, st, data)
y_pred, st = Lux.apply(model, data[1], ps, st)
mse_loss = mean(abs2, y_pred .- data[2])
return mse_loss, st, ()
end
# updated_state, and any computed statistics. This is already satisfied by the loss
# functions provided by Lux.
const loss_function = MSELoss()

# ## Training

Expand Down
Loading

1 comment on commit f7b9539

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: f7b9539 Previous: f1b8c12 Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3864.625 ns 3683.125 ns 1.05
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7197.4 ns 7288.666666666667 ns 0.99
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20859 ns 20909 ns 1.00
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9690.5 ns 9847.3 ns 0.98
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8960.6 ns 9238.375 ns 0.97
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4472.125 ns 4527.125 ns 0.99
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1152.2189781021898 ns 1168.5407407407408 ns 0.99
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1166.719696969697 ns 1176.1526717557251 ns 0.99
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1178.348148148148 ns 1186.4857142857143 ns 0.99
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1791.0849056603774 ns 1782.859375 ns 1.00
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 179.68464730290455 ns 179.37413073713492 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17282 ns 17342 ns 1.00
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 16802 ns 17022 ns 0.99
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 38952 ns 37380 ns 1.04
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 29225 ns 29484.5 ns 0.99
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 19767 ns 21770 ns 0.91
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17187.5 ns 17477.5 ns 0.98
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4339.571428571428 ns 4316.571428571428 ns 1.01
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3863.375 ns 3864.625 ns 1.00
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3932.375 ns 3923.5 ns 1.00
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4904.857142857143 ns 4809 ns 1.02
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1662.1 ns 1660.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 46506004.5 ns 39311146 ns 1.18
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 57802889 ns 57818439 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 110407401 ns 70725143 ns 1.56
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 106701454 ns 89020101 ns 1.20
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 107083940.5 ns 72846612 ns 1.47
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11944033.5 ns 12056878.5 ns 0.99
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 17826363 ns 17802524.5 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7025065.5 ns 7028063 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 6982537 ns 7000092.5 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 18426786.5 ns 9924699 ns 1.86
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6380584 ns 6389608 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 730954388 ns 737562829 ns 0.99
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2547466756 ns 2545549640 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 143281222.5 ns 146821325 ns 0.98
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 907944434 ns 868615027 ns 1.05
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 3408151568 ns 3064060217 ns 1.11
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 232691070 ns 219512795 ns 1.06
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 701405563.5 ns 685678726 ns 1.02
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2751192398 ns 2574375943 ns 1.07
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 150091812 ns 127147427 ns 1.18
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 174202338.5 ns 171884482 ns 1.01
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 651482390 ns 650293250.5 ns 1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 45292469 ns 34511836 ns 1.31
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 164217014 ns 164391167.5 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 640363054 ns 634653416 ns 1.01
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 30140396.5 ns 29977086.5 ns 1.01
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 210450508 ns 185946798 ns 1.13
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 812675116 ns 765662897.5 ns 1.06
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 37414452 ns 35241726.5 ns 1.06
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1319067230.5 ns 1245538918.5 ns 1.06
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1855157553 ns 1864879281 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2389725378 ns 2293551179 ns 1.04
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2576211037 ns 2516850614 ns 1.02
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1979708272 ns 1882887952.5 ns 1.05
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 555898503 ns 561045265 ns 0.99
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 316191735 ns 326179109 ns 0.97
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 314504058 ns 323271956 ns 0.97
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 465777211 ns 349888101 ns 1.33
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11830345.5 ns 11973548 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 18031514 ns 17858872 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19168195 ns 19168560 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23938562 ns 23865197 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 18050494 ns 17866720 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1154605 ns 1158234 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 5872019 ns 5814007 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2056057 ns 2054540.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2037893 ns 2037248 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2074361 ns 2078324 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 200504 ns 202510.5 ns 0.99
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 292576 ns 293437.5 ns 1.00
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 265335 ns 266057.5 ns 1.00
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 365001 ns 365572 ns 1.00
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 408391.5 ns 407804 ns 1.00
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 383161 ns 275034 ns 1.39
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 407973 ns 411080 ns 0.99
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 83536 ns 83504 ns 1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 81562 ns 81180.5 ns 1.00
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 81833 ns 81631 ns 1.00
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 86140 ns 86775.5 ns 0.99
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104836 ns 104563 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 207388132 ns 203633792 ns 1.02
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 324041362 ns 328082047.5 ns 0.99
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 435241619.5 ns 399733123 ns 1.09
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 489451804 ns 429567326 ns 1.14
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 406672717.5 ns 375921768 ns 1.08
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 320495841 ns 328704380 ns 0.98
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 101394771.5 ns 101203246 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 43915846 ns 43990642 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 43768021 ns 43821294.5 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 70277761.5 ns 53275150 ns 1.32
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 29141934 ns 28607335 ns 1.02
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 18724535.5 ns 19166105 ns 0.98
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19456819 ns 19549447.5 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 22919966.5 ns 23387251 ns 0.98
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 23955208 ns 24155491 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19535589 ns 19735654 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6537723 ns 6562123 ns 1.00
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6514811 ns 6547446.5 ns 1.00
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6477278 ns 6511687 ns 0.99
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6514184 ns 6536680 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.