Update examples using the new losses
avik-pal committed Jun 20, 2024
1 parent ce2ec21 commit 514414e
Showing 10 changed files with 35 additions and 54 deletions.
2 changes: 1 addition & 1 deletion docs/src/api/Lux/contrib.md
@@ -20,7 +20,7 @@ All features listed on this page are **experimental** which means:
Pages = ["contrib.md"]
```

## Training
## [Training](@id Training-API)

Helper functions that make it easier to train `Lux.jl` models.

8 changes: 8 additions & 0 deletions docs/src/api/Lux/utilities.md
@@ -8,6 +8,14 @@ Pages = ["utilities.md"]

## Loss Functions

Loss function objects accept two forms of inputs (both forms are sketched below):

1. $\hat{y}$ and $y$, where $\hat{y}$ is the predicted output and $y$ is the target output.
2. `model`, `ps`, `st`, `(x, y)` where `model` is the model, `ps` are the parameters,
   `st` are the states and `(x, y)` is the input-target pair. Then it returns the
   loss, updated states, and an empty named tuple. This makes them compatible with the
   [Experimental Training API](@ref Training-API).
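A minimal sketch of both calling conventions, assuming a toy `Dense` model and random data (the model, shapes, and variable names here are illustrative, not from the docs above):

```julia
using Lux, Random

loss = MSELoss()

# Form 1: (ŷ, y) -> scalar loss
ŷ, y = randn(Float32, 4, 3), randn(Float32, 4, 3)
loss(ŷ, y)

# Form 2: (model, ps, st, (x, y)) -> (loss, updated st, empty named tuple)
model = Dense(4 => 4)
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 4, 3)
l, st_, stats = loss(model, ps, st, (x, y))
```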

!!! warning

When using ChainRules.jl compatible AD (like Zygote), we only compute the gradients
9 changes: 2 additions & 7 deletions examples/ConvMixer/main.jl
@@ -55,13 +55,6 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
#! format: on
end

logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))

function loss(model, ps, st, (x, y))
y_pred, st = model(x, ps, st)
return logitcrossentropy(y_pred, y), st, (;)
end

function accuracy(model, ps, st, dataloader; dev=gpu_device())
total_correct, total = 0, 0
st = Lux.testmode(st)
@@ -94,6 +87,8 @@ end
lr_schedule = linear_interpolation(
[0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0])

loss = CrossEntropyLoss(; logits=Val(true))
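Elsewhere in this commit the same pattern passes the loss object straight to `single_train_step!`; a sketch with assumed variable names, mirroring the calls updated in the other examples:

```julia
# Sketch (assumed names): the loss object is passed directly to the Training API.
_, loss_val, _, train_state = Lux.Experimental.single_train_step!(
    AutoZygote(), loss, (x, y), train_state)
```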

for epoch in 1:epochs
stime = time()
lr = 0
6 changes: 4 additions & 2 deletions examples/DDIM/main.jl
@@ -291,10 +291,12 @@ function preprocess_image(image::Matrix{<:RGB}, image_size::Int)
return apply(CenterResizeCrop((image_size, image_size)), Image(image)) |> itemdata
end

const maeloss = MAELoss()

function loss_function(model, ps, st, data)
(noises, images, pred_noises, pred_images), st = Lux.apply(model, data, ps, st)
noise_loss = mean(abs, noises .- pred_noises)
image_loss = mean(abs, images .- pred_images)
noise_loss = maeloss(noises, pred_noises)
image_loss = maeloss(images, pred_images)
return noise_loss, st, (; image_loss, noise_loss)
end

7 changes: 4 additions & 3 deletions examples/GravitationalWaveForm/main.jl
@@ -22,7 +22,7 @@ CUDA.allowscalar(false)

# We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position
# vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$, $r_2$ (e.g. Theoretical
# Mechanics of Particles and Continua 4.3)

function one2two(path, m₁, m₂)
M = m₁ + m₂
@@ -290,11 +290,12 @@ end

# Next, we define the objective (loss) function to be minimized when training the neural
# differential equations.
const mseloss = MSELoss()

function loss(θ)
pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
loss = sum(abs2, waveform .- pred_waveform)
return loss, pred_waveform
return mseloss(waveform, pred_waveform), pred_waveform
end

# Warmup the loss function
9 changes: 2 additions & 7 deletions examples/HyperNet/main.jl
@@ -57,12 +57,7 @@ function create_model()
end

# ## Define Utility Functions
logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))

function loss(model, ps, st, (data_idx, x, y))
y_pred, st = model((data_idx, x), ps, st)
return logitcrossentropy(y_pred, y), st, (;)
end
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader, data_idx, gdev=gpu_device())
total_correct, total = 0, 0
@@ -101,7 +96,7 @@ function train()
x = x |> dev
y = y |> dev
(_, _, _, train_state) = Lux.Experimental.single_train_step!(
AutoZygote(), loss, (data_idx, x, y), train_state)
AutoZygote(), loss, ((data_idx, x), y), train_state)
end
ttime = time() - stime

9 changes: 2 additions & 7 deletions examples/NeuralODE/main.jl
@@ -105,12 +105,7 @@ function create_model(model_fn=NeuralODE; dev=gpu_device(), use_named_tuple::Boo
end

# ## Define Utility Functions
logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))

function loss(model, ps, st, (x, y))
y_pred, st = model(x, ps, st)
return logitcrossentropy(y_pred, y), st, (;)
end
const logitcrossentropy = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader; dev=gpu_device())
total_correct, total = 0, 0
@@ -143,7 +138,7 @@ function train(model_function; cpu::Bool=false, kwargs...)
x = dev(x)
y = dev(y)
_, _, _, tstate = Lux.Experimental.single_train_step!(
AutoZygote(), loss, (x, y), tstate)
AutoZygote(), logitcrossentropy, (x, y), tstate)
end
ttime = time() - stime

9 changes: 3 additions & 6 deletions examples/PolynomialFitting/main.jl
@@ -51,12 +51,9 @@ opt = Adam(0.03f0)

# We will use the `Lux.Training` API so we need to ensure that our loss function takes 4
# inputs -- model, parameters, states and data. The function must return 3 values -- loss,
# updated_state, and any computed statistics.
function loss_function(model, ps, st, data)
y_pred, st = Lux.apply(model, data[1], ps, st)
mse_loss = mean(abs2, y_pred .- data[2])
return mse_loss, st, ()
end
# updated_state, and any computed statistics. This is already satisfied by the loss
# functions provided by Lux.
const loss_function = MSELoss()
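For comparison, a hand-written loss satisfying the same 4-input/3-output contract would look like this sketch (essentially the code this commit removes; `mean` comes from `Statistics`):

```julia
using Statistics: mean

# Hand-rolled equivalent of MSELoss() under the Training API contract:
function custom_mse(model, ps, st, data)
    y_pred, st_ = Lux.apply(model, data[1], ps, st)
    return mean(abs2, y_pred .- data[2]), st_, (;)  # loss, updated state, empty stats
end
```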

# ## Training

7 changes: 1 addition & 6 deletions examples/SimpleChains/main.jl
@@ -44,12 +44,7 @@ adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1)))
simple_chains_model = adaptor(lux_model)

# ## Helper Functions
logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))

function loss(model, ps, st, (x, y))
y_pred, st = model(x, ps, st)
return logitcrossentropy(y_pred, y), st, (;)
end
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader)
total_correct, total = 0, 0
23 changes: 8 additions & 15 deletions examples/SimpleRNN/main.jl
@@ -116,20 +116,12 @@ end
# Now let's define the binarycrossentropy loss. Typically it is recommended to use
# `logitbinarycrossentropy` since it is more numerically stable, but for the sake of
# simplicity we will use `binarycrossentropy`.

function xlogy(x, y)
result = x * log(y)
return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
y_pred = y_pred .+ eps(eltype(y_pred))
return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
y_pred, st = model(x, ps, st)
return binarycrossentropy(y_pred, y), st, (; y_pred=y_pred)
ŷ, st_ = model(x, ps, st)
loss = lossfn(ŷ, y)
return loss, st_, (; y_pred=ŷ)
end
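Should numerical stability become a concern, a logit-stable variant is presumably available through the same keyword this commit uses for `CrossEntropyLoss` (an assumption worth verifying against the Lux docs); the model would then emit raw logits instead of sigmoid probabilities:

```julia
# Assumed: logit variant mirroring CrossEntropyLoss(; logits=Val(true)).
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))
```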

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
@@ -156,7 +148,7 @@ function main(model_type)
y = y |> dev

(_, loss, _, train_state) = Lux.Experimental.single_train_step!(
AutoZygote(), compute_loss, (x, y), train_state)
AutoZygote(), lossfn, (x, y), train_state)

@printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
end
@@ -166,8 +158,9 @@ function main(model_type)
for (x, y) in val_loader
x = x |> dev
y = y |> dev
loss, st_, ret = compute_loss(model, train_state.parameters, st_, (x, y))
acc = accuracy(ret.y_pred, y)
ŷ, st_ = model(x, train_state.parameters, st_)
loss = lossfn(ŷ, y)
acc = accuracy(ŷ, y)
@printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
end
end
