diff --git a/examples/NeuralODE/Project.toml b/examples/NeuralODE/Project.toml
index f6691ac876..2f342f8dbf 100644
--- a/examples/NeuralODE/Project.toml
+++ b/examples/NeuralODE/Project.toml
@@ -24,5 +24,5 @@ MLUtils = "0.2, 0.3, 0.4"
 OneHotArrays = "0.1, 0.2"
 Optimisers = "0.2, 0.3"
 OrdinaryDiffEq = "6"
-SciMLSensitivity = "7"
+SciMLSensitivity = "7.45"
 Zygote = "0.6"
diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl
index ce3f29ad97..01b1421406 100644
--- a/examples/NeuralODE/main.jl
+++ b/examples/NeuralODE/main.jl
@@ -58,33 +58,38 @@ function NeuralODE(model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f
     return NeuralODE(model, solver, sensealg, tspan, kwargs)
 end

+# OrdinaryDiffEq.jl can deal with non-Vector inputs! However, certain discrete
+# sensitivities like `ReverseDiffAdjoint` can't handle non-Vector inputs. Hence, we need
+# to convert the input and output of the ODE solver to a Vector.
 function (n::NeuralODE)(x, ps, st)
     function dudt(u, p, t)
-        u_, st = n.model(u, p, st)
-        return u_
+        u_, st = n.model(reshape(u, size(x)), p, st)
+        return vec(u_)
     end
-    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
+    prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
     return solve(prob, n.solver; sensealg=n.sensealg, n.kwargs...), st
 end

-diffeqsol_to_array(x::ODESolution) = last(x.u)
+@views diffeqsol_to_array(l::Int, x::ODESolution) = reshape(last(x.u), (l, :))
+@views diffeqsol_to_array(l::Int, x::AbstractMatrix) = reshape(x[:, end], (l, :))

 # ## Create and Initialize the Neural ODE Layer

-function create_model(model_fn=NeuralODE)
+function create_model(model_fn=NeuralODE; dev=gpu_device(), use_named_tuple::Bool=false,
+        sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()))
     ## Construct the Neural ODE Model
     model = Chain(FlattenLayer(),
-        Dense(784, 20, tanh),
-        model_fn(Chain(Dense(20, 10, tanh), Dense(10, 10, tanh), Dense(10, 20, tanh));
-            save_everystep=false, reltol=1.0f-3, abstol=1.0f-3, save_start=false),
-        diffeqsol_to_array,
-        Dense(20, 10))
+        Dense(784 => 20, tanh),
+        model_fn(Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
+            save_everystep=false, reltol=1.0f-3, abstol=1.0f-3, save_start=false,
+            sensealg),
+        Base.Fix1(diffeqsol_to_array, 20),
+        Dense(20 => 10))

     rng = Random.default_rng()
     Random.seed!(rng, 0)

     ps, st = Lux.setup(rng, model)
-    dev = gpu_device()
-    ps = ComponentArray(ps) |> dev
+    ps = (use_named_tuple ? ps : ComponentArray(ps)) |> dev
     st = st |> dev

     return model, ps, st
@@ -98,14 +103,13 @@ function loss(x, y, model, ps, st)
     return logitcrossentropy(y_pred, y), st
 end

-function accuracy(model, ps, st, dataloader)
+function accuracy(model, ps, st, dataloader; dev=gpu_device())
     total_correct, total = 0, 0
     st = Lux.testmode(st)
     cpu_dev = cpu_device()
-    gpu_dev = gpu_device()
     for (x, y) in dataloader
         target_class = onecold(y)
-        predicted_class = onecold(cpu_dev(first(model(gpu_dev(x), ps, st))))
+        predicted_class = onecold(cpu_dev(first(model(dev(x), ps, st))))
         total_correct += sum(target_class .== predicted_class)
         total += length(target_class)
     end
@@ -113,8 +117,9 @@
 end

 # ## Training
-function train(model_function)
-    model, ps, st = create_model(model_function)
+function train(model_function; cpu::Bool=false, kwargs...)
+    dev = cpu ? cpu_device() : gpu_device()
+    model, ps, st = create_model(model_function; dev, kwargs...)

     ## Training
     train_dataloader, test_dataloader = loadmnist(128, 0.9)
@@ -122,11 +127,9 @@ function train(model_function)
     opt = Adam(0.001f0)
     st_opt = Optimisers.setup(opt, ps)

-    dev = gpu_device()
-
     ### Warmup the Model
-    img, lab = dev(train_dataloader.data[1][:, :, :, 1:1]),
-        dev(train_dataloader.data[2][:, 1:1])
+    img = dev(train_dataloader.data[1][:, :, :, 1:1])
+    lab = dev(train_dataloader.data[2][:, 1:1])
     loss(img, lab, model, ps, st)
     (l, _), back = pullback(p -> loss(img, lab, model, p, st), ps)
     back((one(l), nothing))
@@ -146,13 +149,28 @@ function train(model_function)
         ttime = time() - stime

         println("[$epoch/$nepochs] \t Time $(round(ttime; digits=2))s \t Training Accuracy: " *
-                "$(round(accuracy(model, ps, st, train_dataloader) * 100; digits=2))% \t " *
-                "Test Accuracy: $(round(accuracy(model, ps, st, test_dataloader) * 100; digits=2))%")
+                "$(round(accuracy(model, ps, st, train_dataloader; dev) * 100; digits=2))% \t " *
+                "Test Accuracy: $(round(accuracy(model, ps, st, test_dataloader; dev) * 100; digits=2))%")
     end
 end

 train(NeuralODE)

+# We can also change the sensealg and train the model! `GaussAdjoint` allows you to use
+# an arbitrary parameter structure, not just a flat vector (`ComponentArray`).
+
+train(NeuralODE; sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), use_named_tuple=true)
+
+# But remember that some AD backends, like `ReverseDiff`, are not GPU compatible.
+# For a model of this size, you will notice that training is significantly faster on
+# the CPU than on the GPU.
+
+train(NeuralODE; sensealg=InterpolatingAdjoint(; autojacvec=ReverseDiffVJP()), cpu=true)
+
+# For completeness, let's also test out discrete sensitivities!
+
+train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true)
+
 # ## Alternate Implementation using Stateful Layer

 # Starting `v0.5.5`, Lux provides a `Lux.Experimental.StatefulLuxLayer` which can be used
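For reference, a minimal standalone sketch of the `vec`/`reshape` round trip this diff adds to `(n::NeuralODE)(x, ps, st)` — assuming only OrdinaryDiffEq, SciMLSensitivity, and Zygote as pinned in the Project.toml above, with a toy linear `dudt` standing in for the Lux model (the names `solve_flat` and `A` are illustrative, not part of the example):

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

x = rand(Float32, 4, 3)  # matrix-valued "input" (features × batch)
A = rand(Float32, 4, 4)  # toy parameters standing in for the neural network

function solve_flat(p)
    # discrete adjoints such as `ReverseDiffAdjoint` need the ODE state to be a Vector,
    # so we flatten the state with `vec` and restore its shape inside `dudt`
    dudt(u, p_, t) = vec(p_ * reshape(u, size(x)))
    prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), (0.0f0, 1.0f0), p)
    sol = solve(prob, Tsit5(); sensealg=ReverseDiffAdjoint(), save_everystep=false)
    return reshape(last(sol.u), size(x))  # same recovery step as `diffeqsol_to_array`
end

solve_flat(A)                                        # forward solve on the flat state
grad, = Zygote.gradient(p -> sum(solve_flat(p)), A)  # discrete adjoint via ReverseDiff
```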
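Similarly, a sketch of why `use_named_tuple=true` is paired with `GaussAdjoint` above: continuous adjoints such as `GaussAdjoint` (available from SciMLSensitivity 7.45, which presumably motivates the compat bump) can differentiate with respect to structured parameters directly, so no `ComponentArray` flattening is needed. The toy NamedTuple parameters and dynamics below are illustrative only:

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

ps = (; W=rand(Float32, 4, 4), b=rand(Float32, 4))  # structured, non-flat parameters
u0 = rand(Float32, 4)

dudt(u, p, t) = p.W * u .+ p.b
prob = ODEProblem{false}(ODEFunction{false}(dudt), u0, (0.0f0, 1.0f0), ps)

function loss(p)
    # override the problem's parameters via the `p` keyword of `solve`
    sol = solve(prob, Tsit5(); p, sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()),
        save_everystep=false)
    return sum(last(sol.u))
end

g, = Zygote.gradient(loss, ps)  # the gradient mirrors the NamedTuple structure of `ps`
```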