
Commit

Merge pull request #431 from LuxDL/ap/neuralode_reversediff
Ensure ReverseDiff and Gauss Adjoint are also tested
avik-pal authored Oct 23, 2023
2 parents 793ee5b + 4714956 commit ffbb4ba
Showing 2 changed files with 42 additions and 24 deletions.
examples/NeuralODE/Project.toml (2 changes: 1 addition & 1 deletion)
@@ -24,5 +24,5 @@ MLUtils = "0.2, 0.3, 0.4"
 OneHotArrays = "0.1, 0.2"
 Optimisers = "0.2, 0.3"
 OrdinaryDiffEq = "6"
-SciMLSensitivity = "7"
+SciMLSensitivity = "7.45"
 Zygote = "0.6"
examples/NeuralODE/main.jl (64 changes: 41 additions & 23 deletions)
@@ -58,33 +58,38 @@ function NeuralODE(model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f
     return NeuralODE(model, solver, sensealg, tspan, kwargs)
 end
 
+# OrdinaryDiffEq.jl can deal with non-Vector inputs. However, certain discrete
+# sensitivities like `ReverseDiffAdjoint` can't handle non-Vector inputs, so we
+# convert the input and output of the ODE solver to a `Vector`.
 function (n::NeuralODE)(x, ps, st)
     function dudt(u, p, t)
-        u_, st = n.model(u, p, st)
-        return u_
+        u_, st = n.model(reshape(u, size(x)), p, st)
+        return vec(u_)
     end
-    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
+    prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
     return solve(prob, n.solver; sensealg=n.sensealg, n.kwargs...), st
 end
 
-diffeqsol_to_array(x::ODESolution) = last(x.u)
+@views diffeqsol_to_array(l::Int, x::ODESolution) = reshape(last(x.u), (l, :))
+@views diffeqsol_to_array(l::Int, x::AbstractMatrix) = reshape(x[:, end], (l, :))
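The flatten/restore pattern introduced above can be illustrated standalone. A minimal sketch with a toy right-hand side standing in for `n.model` (the names and toy dynamics here are hypothetical, not part of the diff):

```julia
using OrdinaryDiffEq

x = rand(Float32, 20, 3)              # (features, batch), as the layer would receive it
u0 = vec(x)                           # `ReverseDiffAdjoint` needs a `Vector` state

function toy_dudt(u, p, t)
    u_mat = reshape(u, size(x))       # restore the (features, batch) shape
    du = p .* u_mat                   # toy stand-in for `n.model(u_mat, p, st)`
    return vec(du)                    # flatten again before handing back to the solver
end

prob = ODEProblem{false}(toy_dudt, u0, (0.0f0, 1.0f0), -0.5f0)
sol = solve(prob, Tsit5(); save_everystep=false)
size(reshape(last(sol.u), size(x)))   # (20, 3), exactly what `diffeqsol_to_array` recovers
```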

 # ## Create and Initialize the Neural ODE Layer
-function create_model(model_fn=NeuralODE)
+function create_model(model_fn=NeuralODE; dev=gpu_device(), use_named_tuple::Bool=false,
+        sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()))
     ## Construct the Neural ODE Model
     model = Chain(FlattenLayer(),
-        Dense(784, 20, tanh),
-        model_fn(Chain(Dense(20, 10, tanh), Dense(10, 10, tanh), Dense(10, 20, tanh));
-            save_everystep=false, reltol=1.0f-3, abstol=1.0f-3, save_start=false),
-        diffeqsol_to_array,
-        Dense(20, 10))
+        Dense(784 => 20, tanh),
+        model_fn(Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
+            save_everystep=false, reltol=1.0f-3, abstol=1.0f-3, save_start=false,
+            sensealg),
+        Base.Fix1(diffeqsol_to_array, 20),
+        Dense(20 => 10))
 
     rng = Random.default_rng()
     Random.seed!(rng, 0)
 
     ps, st = Lux.setup(rng, model)
-    dev = gpu_device()
-    ps = ComponentArray(ps) |> dev
+    ps = (use_named_tuple ? ps : ComponentArray(ps)) |> dev
     st = st |> dev
 
     return model, ps, st
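Note that the `Chain` now uses `Base.Fix1(diffeqsol_to_array, 20)` rather than a bare one-argument function. `Base.Fix1(f, a)` is a concrete callable equivalent to `x -> f(a, x)`; a minimal sketch of the equivalence, with a hypothetical helper:

```julia
scale_all(l, x) = l .* x                 # hypothetical two-argument helper
g = Base.Fix1(scale_all, 2)              # fixes the first argument to 2
g([1, 2, 3]) == scale_all(2, [1, 2, 3])  # true
```

Fixing `l = 20` this way lets the same two-argument `diffeqsol_to_array` serve as the one-argument layer the `Chain` expects.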
@@ -98,35 +98,33 @@ function loss(x, y, model, ps, st)
     return logitcrossentropy(y_pred, y), st
 end
 
-function accuracy(model, ps, st, dataloader)
+function accuracy(model, ps, st, dataloader; dev=gpu_device())
     total_correct, total = 0, 0
     st = Lux.testmode(st)
     cpu_dev = cpu_device()
-    gpu_dev = gpu_device()
     for (x, y) in dataloader
         target_class = onecold(y)
-        predicted_class = onecold(cpu_dev(first(model(gpu_dev(x), ps, st))))
+        predicted_class = onecold(cpu_dev(first(model(dev(x), ps, st))))
         total_correct += sum(target_class .== predicted_class)
         total += length(target_class)
     end
     return total_correct / total
 end

 # ## Training
-function train(model_function)
-    model, ps, st = create_model(model_function)
+function train(model_function; cpu::Bool=false, kwargs...)
+    dev = cpu ? cpu_device() : gpu_device()
+    model, ps, st = create_model(model_function; dev, kwargs...)
 
     ## Training
     train_dataloader, test_dataloader = loadmnist(128, 0.9)
 
     opt = Adam(0.001f0)
     st_opt = Optimisers.setup(opt, ps)
 
-    dev = gpu_device()
-
     ### Warmup the Model
-    img, lab = dev(train_dataloader.data[1][:, :, :, 1:1]),
-        dev(train_dataloader.data[2][:, 1:1])
+    img = dev(train_dataloader.data[1][:, :, :, 1:1])
+    lab = dev(train_dataloader.data[2][:, 1:1])
     loss(img, lab, model, ps, st)
     (l, _), back = pullback(p -> loss(img, lab, model, p, st), ps)
     back((one(l), nothing))
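The warmup block calls `pullback` (from Zygote) directly and seeds the returned closure. A minimal standalone sketch of that pattern, with a toy scalar loss in place of `loss(img, lab, model, p, st)`:

```julia
using Zygote

f(p) = sum(abs2, p)               # toy scalar loss
p0 = rand(Float32, 4)

l, back = Zygote.pullback(f, p0)  # primal value plus a vector-Jacobian-product closure
grad, = back(one(l))              # seed with 1 to recover dl/dp
```

The `(one(l), nothing)` seed in the diff reflects that `loss` returns a `(value, st)` tuple, with no cotangent propagated for the state.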
@@ -146,13 +146,28 @@ function train(model_function)
         ttime = time() - stime
 
         println("[$epoch/$nepochs] \t Time $(round(ttime; digits=2))s \t Training Accuracy: " *
-                "$(round(accuracy(model, ps, st, train_dataloader) * 100; digits=2))% \t " *
-                "Test Accuracy: $(round(accuracy(model, ps, st, test_dataloader) * 100; digits=2))%")
+                "$(round(accuracy(model, ps, st, train_dataloader; dev) * 100; digits=2))% \t " *
+                "Test Accuracy: $(round(accuracy(model, ps, st, test_dataloader; dev) * 100; digits=2))%")
     end
 end
 
 train(NeuralODE)
 
+# We can also change the sensealg and train the model! `GaussAdjoint` allows you to use
+# an arbitrary parameter structure, not just a flat vector (`ComponentArray`).
+
+train(NeuralODE; sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), use_named_tuple=true)
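A hypothetical way to see the difference in parameter containers, using `create_model` as defined above (a CPU device is chosen here only so the sketch runs anywhere):

```julia
_, ps_ca, _ = create_model(NeuralODE; dev=cpu_device())                        # flat vector
_, ps_nt, _ = create_model(NeuralODE; dev=cpu_device(), use_named_tuple=true)  # nested structure
ps_ca isa ComponentArray  # true
ps_nt isa NamedTuple      # true, which `GaussAdjoint` can consume directly per the note above
```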

+# But remember that some AD backends, like `ReverseDiff`, are not GPU compatible.
+# For a model of this size, training is significantly faster on the CPU than on the GPU.
+
+train(NeuralODE; sensealg=InterpolatingAdjoint(; autojacvec=ReverseDiffVJP()), cpu=true)
+
+# For completeness, let's also test out discrete sensitivities!
+
+train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true)
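For reference, the four training runs above could be consolidated into one hypothetical loop (same defaults assumed, not part of the diff):

```julia
for kw in (
    (;),  # default: InterpolatingAdjoint + ZygoteVJP on GPU
    (; sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), use_named_tuple=true),
    (; sensealg=InterpolatingAdjoint(; autojacvec=ReverseDiffVJP()), cpu=true),
    (; sensealg=ReverseDiffAdjoint(), cpu=true),
)
    train(NeuralODE; kw...)
end
```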

 # ## Alternate Implementation using Stateful Layer
 
 # Starting `v0.5.5`, Lux provides a `Lux.Experimental.StatefulLuxLayer` which can be used
