Loss functions module #704

Merged · 13 commits · Jun 20, 2024
9 changes: 7 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.5.55"
version = "0.5.56"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -16,6 +16,7 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
@@ -27,6 +28,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

[weakdeps]
@@ -89,6 +91,7 @@ Functors = "0.4.10"
GPUArraysCore = "0.1.6"
LinearAlgebra = "1.10"
Logging = "1.10"
LossFunctions = "0.11.1"
LuxCore = "0.1.14"
LuxDeviceUtils = "0.1.22"
LuxLib = "0.3.23"
@@ -99,6 +102,7 @@ MacroTools = "0.5.13"
Markdown = "1.10"
NCCL = "0.1.1"
OhMyThreads = "0.5.1"
OneHotArrays = "0.2.5"
Optimisers = "0.3"
Pkg = "1.10"
PrecompileTools = "1.2"
@@ -130,6 +134,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
@@ -142,4 +147,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ComponentArrays", "Documenter", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "ForwardDiff", "Logging", "LuxTestUtils", "MLUtils", "Optimisers", "Pkg", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
test = ["Aqua", "ComponentArrays", "Documenter", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "ForwardDiff", "Logging", "LuxTestUtils", "MLUtils", "OneHotArrays", "Optimisers", "Pkg", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
4 changes: 1 addition & 3 deletions docs/make.jl
@@ -52,7 +52,6 @@ pages = [
"api/Lux/distributed_utils.md",
],
"Accelerator Support" => [
"api/Accelerator_Support/LuxCUDA.md",
"api/Accelerator_Support/LuxDeviceUtils.md"
],
"Building Blocks" => [
@@ -82,8 +81,7 @@ makedocs(; sitename="Lux.jl Documentation",
authors="Avik Pal et al.",
clean=true,
doctest=false, # We test it in the CI, no need to run it here
modules=[Lux, LuxCore, LuxLib, WeightInitializers,
Boltz, LuxTestUtils, LuxDeviceUtils, LuxCUDA],
modules=[Lux, LuxCore, LuxLib, WeightInitializers, Boltz, LuxTestUtils, LuxDeviceUtils],
linkcheck=true,
repo="https://github.com/LuxDL/Lux.jl/blob/{commit}{path}#{line}",
format=DocumenterVitepress.MarkdownVitepress(;
2 changes: 0 additions & 2 deletions docs/src/.vitepress/config.mts
@@ -80,7 +80,6 @@ export default defineConfig({
},
{
text: 'Accelerator Support', items: [
{ text: 'LuxCUDA', link: '/api/Accelerator_Support/LuxCUDA' },
{ text: 'LuxDeviceUtils', link: '/api/Accelerator_Support/LuxDeviceUtils' }
]
},
@@ -196,7 +195,6 @@
},
{
text: 'Accelerator Support', collapsed: false, items: [
{ text: 'LuxCUDA', link: '/api/Accelerator_Support/LuxCUDA' },
{ text: 'LuxDeviceUtils', link: '/api/Accelerator_Support/LuxDeviceUtils' }]
},
{
16 changes: 0 additions & 16 deletions docs/src/api/Accelerator_Support/LuxCUDA.md

This file was deleted.

2 changes: 1 addition & 1 deletion docs/src/api/Lux/autodiff.md
@@ -14,7 +14,7 @@ Lux. Additionally, we provide some convenience functions for working with AD.
| [`ReverseDiff.jl`](https://github.com/JuliaDiff/ReverseDiff.jl) | ✔️ | ❌ | ❌ | Tier II |
| [`Tracker.jl`](https://github.com/FluxML/Tracker.jl) | ✔️ | ✔️ | ❌ | Tier II |
| [`Enzyme.jl`](https://github.com/EnzymeAD/Enzyme.jl) | ✔️ | ❓[^q] | ❓[^q] | Tier II |
| [`Tapir.jl`](https://github.com/withbayes/Tapir.jl) | ❓[^q] | ❓[^q] | ❌ | Tier IV |
| [`Tapir.jl`](https://github.com/withbayes/Tapir.jl) | ❓[^q] | | ❌ | Tier IV |
| [`Diffractor.jl`](https://github.com/JuliaDiff/Diffractor.jl) | ❓[^q] | ❓[^q] | ❓[^q] | Tier IV |

[^q]: This feature is supported downstream, but we don't extensively test it to ensure
2 changes: 1 addition & 1 deletion docs/src/api/Lux/contrib.md
@@ -20,7 +20,7 @@ All features listed on this page are **experimental** which means:
Pages = ["contrib.md"]
```

## Training
## [Training](@id Training-API)

Helper Functions making it easier to train `Lux.jl` models.

51 changes: 43 additions & 8 deletions docs/src/api/Lux/utilities.md
@@ -6,17 +6,38 @@
Pages = ["utilities.md"]
```

## Device Management / Data Transfer
## Loss Functions

```@docs
Lux.cpu
Lux.gpu
```
Loss function objects take two forms of inputs:

1. $\hat{y}$ and $y$, where $\hat{y}$ is the predicted output and $y$ is the target output.
2. `model`, `ps`, `st`, `(x, y)`, where `model` is the model, `ps` are the parameters,
   `st` are the states, and `(x, y)` is the input and target pair. In this form the loss
   function returns the loss, the updated states, and an empty named tuple, which makes it
   compatible with the [Experimental Training API](@ref Training-API). Both calling forms
   are sketched below.
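
A minimal sketch of both calling forms (the layer, data, and RNG setup here are made up
for illustration; `MSELoss` is one of the losses documented further down):

```julia
using Lux, Random

rng = Random.default_rng()
model = Dense(4 => 2)
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 4, 8)   # inputs
y = randn(rng, Float32, 2, 8)   # targets

lossfn = MSELoss()

# Form 1: operate directly on (ŷ, y) and return a scalar loss
ŷ, _ = model(x, ps, st)
loss_direct = lossfn(ŷ, y)

# Form 2: operate on (model, ps, st, (x, y)) and return (loss, updated states, empty stats)
loss_train, st_new, stats = lossfn(model, ps, st, (x, y))
```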

!!! warning

For detailed API documentation on Data Transfer check out the
[LuxDeviceUtils.jl](@ref LuxDeviceUtils-API)
When using a ChainRules.jl-compatible AD backend (such as Zygote), gradients are only
computed with respect to the inputs; any gradients with respect to the targets are dropped.

```@docs
GenericLossFunction
BinaryCrossEntropyLoss
BinaryFocalLoss
CrossEntropyLoss
DiceCoeffLoss
FocalLoss
HingeLoss
HuberLoss
KLDivergenceLoss
MAELoss
MSELoss
MSLELoss
PoissonLoss
SiameseContrastiveLoss
SquaredHingeLoss
```
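
A quick sketch of a few of these losses in the direct `(ŷ, y)` form (the arrays below are
made up; the `logits=Val(true)` keyword is the one used in the updated examples later in
this PR):

```julia
using Lux

ŷ = Float32[2.0 0.1; -0.5 1.5]   # raw model outputs (2 classes, 2 samples)
y = Float32[1.0 0.0; 0.0 1.0]    # one-hot style targets

MAELoss()(ŷ, y)                             # mean absolute error
MSELoss()(ŷ, y)                             # mean squared error
CrossEntropyLoss(; logits=Val(true))(ŷ, y)  # cross entropy computed from raw logits
```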

## Weight Initialization

@@ -31,6 +52,8 @@ Lux.gpu
Lux.foldl_init
Lux.istraining
Lux.multigate
Lux.xlogy
Lux.xlogx
```

## Updating Floating Point Precision
@@ -56,8 +79,20 @@ StatefulLuxLayer
@compact
```

## Truncated Stacktraces
## Truncated Stacktraces (Deprecated)

```@docs
Lux.disable_stacktrace_truncation!
```

## Device Management / Data Transfer (Deprecated)

```@docs
Lux.cpu
Lux.gpu
```

!!! warning

For detailed API documentation on Data Transfer check out the
[LuxDeviceUtils.jl](@ref LuxDeviceUtils-API)
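
Since `Lux.cpu` / `Lux.gpu` are deprecated, here is a minimal sketch of the LuxDeviceUtils
device API that replaces them (`gpu_device()` is the same helper the updated examples below
already use; the toy model is made up for illustration):

```julia
using Lux, LuxDeviceUtils, Random

model = Dense(2 => 2)
ps, st = Lux.setup(Random.default_rng(), model)

dev  = gpu_device()   # selects a GPU device if a functional backend is loaded, else the CPU
cdev = cpu_device()

ps_dev = ps |> dev    # move parameters to the selected device
ps_cpu = ps_dev |> cdev
```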
5 changes: 3 additions & 2 deletions examples/Basics/Project.toml
@@ -6,14 +6,15 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ComponentArrays = "0.13, 0.14, 0.15"
ComponentArrays = "0.15"
ForwardDiff = "0.10"
Literate = "2"
Lux = "0.5"
Lux = "0.5.56"
LuxCUDA = "0.2, 0.3"
Optimisers = "0.2, 0.3"
Zygote = "0.6"
48 changes: 20 additions & 28 deletions examples/Basics/main.jl
@@ -294,34 +294,26 @@ println("x shape: ", size(x_samples), "; y shape: ", size(y_samples))
# [Optimisers.jl](https://github.com/FluxML/Optimisers.jl). We will use Stochastic Gradient
# Descent (SGD) with a learning rate of `0.01`.

using Optimisers

opt = Optimisers.Descent(0.01f0)

# Initialize the initial state of the optimiser
opt_state = Optimisers.setup(opt, ps)
using Optimisers, Printf

# Define the loss function
function sse(model, ps, st, X, y)
y_pred, st_new = model(X, ps, st)
return sum(abs2, y_pred .- y), st_new
end
sse(weight, bias, X, y) = sum(abs2, weight * X .+ bias .- y)
loss_function(ps, X, y) = sse(model, ps, st, X, y)

println("Loss Value with ground true parameters: ", sse(W, b, x_samples, y_samples))

for i in 1:100
## In actual code, don't use globals. But here I will simply for the sake of
## demonstration
global ps, st, opt_state
## Compute the gradient using the pullback API to update the states
(loss, st), pb_f = Zygote.pullback(loss_function, ps, x_samples, y_samples)
## We pass nothing as the seed for `st`, since we don't want to propagate any gradient
## for st
gs = pb_f((one(loss), nothing))[1]
## Update model parameters
## `Optimisers.update` can be used if mutation is not desired
opt_state, ps = Optimisers.update!(opt_state, ps, gs)
(i % 10 == 1 || i == 100) && println(lazy"Loss Value after $i iterations: $loss")
lossfn = MSELoss()

println("Loss Value with ground true parameters: ", lossfn(W * x_samples .+ b, y_samples))

# We will train the model using our training API.
function train_model!(model, ps, st, opt, nepochs::Int)
tstate = Lux.Experimental.TrainState(model, ps, st, opt)
for i in 1:nepochs
grads, loss, _, tstate = Lux.Experimental.single_train_step!(
AutoZygote(), lossfn, (x_samples, y_samples), tstate)
if i % 1000 == 1 || i == nepochs
@printf "Loss Value after %6d iterations: %.8f\n" i loss
end
end
return tstate.model, tstate.parameters, tstate.states
end

model, ps, st = train_model!(model, ps, st, Descent(0.01f0), 10000)

println("Loss Value after training: ", lossfn(first(model(x_samples, ps, st)), y_samples))
9 changes: 2 additions & 7 deletions examples/ConvMixer/main.jl
@@ -55,13 +55,6 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
#! format: on
end

logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))

function loss(model, ps, st, (x, y))
y_pred, st = model(x, ps, st)
return logitcrossentropy(y_pred, y), st, (;)
end

function accuracy(model, ps, st, dataloader; dev=gpu_device())
total_correct, total = 0, 0
st = Lux.testmode(st)
@@ -94,6 +87,8 @@ end
lr_schedule = linear_interpolation(
[0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0])

loss = CrossEntropyLoss(; logits=Val(true))

for epoch in 1:epochs
stime = time()
lr = 0
6 changes: 4 additions & 2 deletions examples/DDIM/main.jl
@@ -291,10 +291,12 @@ function preprocess_image(image::Matrix{<:RGB}, image_size::Int)
return apply(CenterResizeCrop((image_size, image_size)), Image(image)) |> itemdata
end

const maeloss = MAELoss()

function loss_function(model, ps, st, data)
(noises, images, pred_noises, pred_images), st = Lux.apply(model, data, ps, st)
noise_loss = mean(abs, noises .- pred_noises)
image_loss = mean(abs, images .- pred_images)
noise_loss = maeloss(noises, pred_noises)
image_loss = maeloss(images, pred_images)
return noise_loss, st, (; image_loss, noise_loss)
end

7 changes: 4 additions & 3 deletions examples/GravitationalWaveForm/main.jl
@@ -22,7 +22,7 @@ CUDA.allowscalar(false)

# We need a very crude 2-body path. Assume the 1-body motion is a newtonian 2-body position
# vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$, $r_2$ (e.g. Theoretical
# Mechanics of Particles and Continua 4.3)
# Mechanics of Particles and Continua 4.3)

function one2two(path, m₁, m₂)
M = m₁ + m₂
@@ -290,11 +290,12 @@ end

# Next, we define the objective (loss) function to be minimized when training the neural
# differential equations.
const mseloss = MSELoss()

function loss(θ)
pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
loss = sum(abs2, waveform .- pred_waveform)
return loss, pred_waveform
return mseloss(waveform, pred_waveform), pred_waveform
end

# Warmup the loss function
9 changes: 2 additions & 7 deletions examples/HyperNet/main.jl
@@ -57,12 +57,7 @@ function create_model()
end

# ## Define Utility Functions
logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))

function loss(model, ps, st, (data_idx, x, y))
y_pred, st = model((data_idx, x), ps, st)
return logitcrossentropy(y_pred, y), st, (;)
end
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader, data_idx, gdev=gpu_device())
total_correct, total = 0, 0
@@ -101,7 +96,7 @@ function train()
x = x |> dev
y = y |> dev
(_, _, _, train_state) = Lux.Experimental.single_train_step!(
AutoZygote(), loss, (data_idx, x, y), train_state)
AutoZygote(), loss, ((data_idx, x), y), train_state)
end
ttime = time() - stime

9 changes: 2 additions & 7 deletions examples/NeuralODE/main.jl
@@ -105,12 +105,7 @@ function create_model(model_fn=NeuralODE; dev=gpu_device(), use_named_tuple::Boo
end

# ## Define Utility Functions
logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))

function loss(model, ps, st, (x, y))
y_pred, st = model(x, ps, st)
return logitcrossentropy(y_pred, y), st, (;)
end
const logitcrossentropy = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader; dev=gpu_device())
total_correct, total = 0, 0
@@ -143,7 +138,7 @@ function train(model_function; cpu::Bool=false, kwargs...)
x = dev(x)
y = dev(y)
_, _, _, tstate = Lux.Experimental.single_train_step!(
AutoZygote(), loss, (x, y), tstate)
AutoZygote(), logitcrossentropy, (x, y), tstate)
end
ttime = time() - stime

9 changes: 3 additions & 6 deletions examples/PolynomialFitting/main.jl
@@ -51,12 +51,9 @@ opt = Adam(0.03f0)

# We will use the `Lux.Training` API so we need to ensure that our loss function takes 4
# inputs -- model, parameters, states and data. The function must return 3 values -- loss,
# updated_state, and any computed statistics.
function loss_function(model, ps, st, data)
y_pred, st = Lux.apply(model, data[1], ps, st)
mse_loss = mean(abs2, y_pred .- data[2])
return mse_loss, st, ()
end
# updated_state, and any computed statistics. This is already satisfied by the loss
# functions provided by Lux.
const loss_function = MSELoss()

# ## Training
