Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Investigate fixing AD issues #151

Open
wants to merge 5 commits into
base: new_API
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ version = "1.0.0"
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
GeoStats = "dcc97b0b-8ce5-5539-9008-bb190f959ef6"
Expand All @@ -20,6 +22,7 @@ Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Muninn = "4b816528-16ba-4e32-9a2e-3c1bc2049d3a"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Expand Down Expand Up @@ -48,7 +51,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
BenchmarkTools = "1.3.2"
ChainRules = "1.50"
Downloads = "1"
Flux = "0.13, 0.14"
GR = "0.71, 0.72, 0.73"
Huginn = "0.6"
IJulia = "1.2"
Expand Down
8 changes: 7 additions & 1 deletion src/ODINN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using IterTools: ncycle
using Zygote
using ChainRules: @ignore_derivatives
using Base: @kwdef
using Flux
using Flux, Lux, ComponentArrays
using Tullio
using Infiltrator, Cthulhu
using Plots, PlotThemes
Expand All @@ -31,6 +31,12 @@ using Downloads
using TimerOutputs
using GeoStats
using ImageFiltering
using EnzymeCore

# This is equivalent to `@ignore_derivatives`
EnzymeCore.EnzymeRules.inactive(::typeof(Huginn.define_callback_steps), args...) = nothing
EnzymeCore.EnzymeRules.inactive(::typeof(MB_timestep!), args...) = nothing
EnzymeCore.EnzymeRules.inactive(::typeof(apply_MB_mask!), args...) = nothing

# ##############################################
# ############ PARAMETERS ###############
Expand Down
22 changes: 0 additions & 22 deletions src/helpers/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,28 +217,6 @@ function generate_batches(batch_size, UA, gdirs, gdir_refs, tspan::Tuple; shuffl
return train_loader
end


"""
get_NN()

Generates a neural network.
"""
function get_NN(θ_trained)
UA = Chain(
Dense(1,3, x->softplus.(x)),
Dense(3,10, x->softplus.(x)),
Dense(10,3, x->softplus.(x)),
Dense(3,1, sigmoid_A)
)
UA = Flux.f64(UA)
# See if parameters need to be retrained or not
θ, UA_f = Flux.destructure(UA)
if !isempty(θ_trained)
θ = θ_trained
end
return UA_f, θ
end

function get_NN_inversion(θ_trained, target)
if target == "D"
U, θ = get_NN_inversion_D(θ_trained)
Expand Down
23 changes: 11 additions & 12 deletions src/models/machine_learning/ML_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,18 @@ get_NN()
Generates a neural network.
"""
function get_NN(θ_trained)
UA = Flux.Chain(
Dense(1,3, x->softplus.(x)),
Dense(3,10, x->softplus.(x)),
Dense(10,3, x->softplus.(x)),
Dense(3,1, sigmoid_A)
UA = Lux.Chain(
Lux.Dense(1,3, x->Lux.softplus.(x)),
Lux.Dense(3,10, x->Lux.softplus.(x)),
Lux.Dense(10,3, x->Lux.softplus.(x)),
Lux.Dense(3,1, sigmoid_A)
)
UA = Flux.f64(UA)
# See if parameters need to be retrained or not
θ, UA_f = Flux.destructure(UA)
θ, st = Lux.setup(Random.default_rng(), UA)
θ = ComponentArray{Float64}(θ)
if !isnothing(θ_trained)
θ = θ_trained
end
return UA, θ, UA_f
return UA, θ, st
end

"""
Expand All @@ -26,7 +25,7 @@ end
Predicts the value of A with a neural network based on the long-term air temperature.
"""
function predict_A̅(U, temp)
return U(temp) .* 1e-18
return U(temp)[1] .* 1e-18
end

function sigmoid_A(x)
Expand Down Expand Up @@ -83,12 +82,12 @@ function build_D_features(H::Matrix, temp, ∇S)
∇S_flat = ∇S[inn1(H) .!= 0.0] # flatten
H_flat = H[H .!= 0.0] # flatten
T_flat = repeat(temp,length(H_flat))
X = Flux.normalise(hcat(H_flat,T_flat,∇S_flat))' # build feature matrix
X = Lux.normalise(hcat(H_flat,T_flat,∇S_flat))' # build feature matrix

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lux doesn't have a normalise function

return X
end

function build_D_features(H::Float64, temp::Float64, ∇S::Float64)
X = Flux.normalise(hcat([H],[temp],[∇S]))' # build feature matrix
X = Lux.normalise(hcat([H],[temp],[∇S]))' # build feature matrix
return X
end

Expand Down
21 changes: 11 additions & 10 deletions src/models/machine_learning/MLmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,35 +27,36 @@ function Model(;
return model
end

mutable struct NN{F <: AbstractFloat} <: MLmodel
architecture::Flux.Chain
NN_f::Optimisers.Restructure
θ::Vector{F}
mutable struct NN{T1, T2, T3} <: MLmodel
architecture::T1
st::T2
θ::T3
end
(f::NN)(u) = f.architecture(u, f.θ, f.st)

"""
NN(params::Parameters;
architecture::Union{Flux.Chain, Nothing} = nothing,
architecture::Union{Lux.Chain, Nothing} = nothing,
θ::Union{Vector{AbstractFloat}, Nothing} = nothing)

Feed-forward neural network.

Keyword arguments
=================
- `architecture`: `Flux.Chain` neural network architecture
- `architecture`: `Lux.Chain` neural network architecture
- `θ`: Neural network parameters
"""
function NN(params::Sleipnir.Parameters;
architecture::Union{Flux.Chain, Nothing} = nothing,
θ::Union{Vector{F}, Nothing} = nothing) where {F <: AbstractFloat}
architecture::Union{Lux.Chain, Nothing} = nothing,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be a Chain, won't AbstractExplicitLayer work?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will, I'm just making small changes to isolate the real AD issues though.

θ::Union{ComponentArray{F}, Nothing} = nothing) where {F <: AbstractFloat}

if isnothing(architecture)
architecture, θ, NN_f = get_NN(θ)
architecture, θ, st = get_NN(θ)
end

# Build the simulation parameters based on input values
ft = params.simulation.float_type
neural_net = NN{ft}(architecture, NN_f, θ)
neural_net = NN(architecture, st, θ)

return neural_net
end
4 changes: 2 additions & 2 deletions src/parameters/Hyperparameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export Hyperparameters
current_epoch::I
current_minibatch::I
loss_history::Vector{F}
optimizer::Union{Optim.FirstOrderOptimizer, Flux.Optimise.AbstractOptimiser, Optimisers.AbstractRule}
optimizer::Union{Optim.FirstOrderOptimizer, Optimisers.AbstractRule}
loss_epoch::F
epochs::I
batch_size::I
Expand Down Expand Up @@ -33,7 +33,7 @@ function Hyperparameters(;
current_epoch::Int64 = 1,
current_minibatch::Int64 = 1,
loss_history::Vector{Float64} = zeros(Float64, 0),
optimizer::Union{Optim.FirstOrderOptimizer, Flux.Optimise.AbstractOptimiser, Optimisers.AbstractRule} = BFGS(initial_stepnorm=0.001),
optimizer::Union{Optim.FirstOrderOptimizer, Optimisers.AbstractRule} = BFGS(initial_stepnorm=0.001),
loss_epoch::Float64 = 0.0,
epochs::Int64 = 50,
batch_size::Int64 = 15
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function train_UDE!(simulation::FunctionalInversion)
train_batches = generate_batches(simulation)
θ = simulation.model.machine_learning.θ

optf = OptimizationFunction((θ, _, batch_ids, rgi_ids)->loss_iceflow(θ, batch_ids, simulation), Optimization.AutoReverseDiff())
optf = OptimizationFunction((θ, _, batch_ids, rgi_ids)->loss_iceflow(θ, batch_ids, simulation), Optimization.AutoZygote())
optprob = OptimizationProblem(optf, θ)

if simulation.parameters.UDE.target == "A"
Expand Down Expand Up @@ -127,15 +127,13 @@ function batch_iceflow_UDE(θ, simulation::FunctionalInversion, batch_id::I) whe
# Initialize glacier ice flow model
initialize_iceflow_model(model.iceflow[batch_id], batch_id, glacier, params)

params.solver.tstops = @ignore_derivatives Huginn.define_callback_steps(params.simulation.tspan, params.solver.step)
params.solver.tstops = Huginn.define_callback_steps(params.simulation.tspan, params.solver.step)
stop_condition(u,t,integrator) = Sleipnir.stop_condition_tstops(u,t,integrator, params.solver.tstops) #closure
function action!(integrator)
if params.simulation.use_MB
# Compute mass balance
@ignore_derivatives begin
MB_timestep!(model, glacier, params.solver.step, integrator.t; batch_id = batch_id)
apply_MB_mask!(integrator.u, glacier, model.iceflow[batch_id])
end
MB_timestep!(model, glacier, params.solver.step, integrator.t; batch_id = batch_id)
apply_MB_mask!(integrator.u, glacier, model.iceflow[batch_id])
end
# Apply parametrization
apply_UDE_parametrization!(θ, simulation, integrator, batch_id)
Expand Down Expand Up @@ -218,7 +216,7 @@ end

function apply_UDE_parametrization!(θ, simulation::FunctionalInversion, integrator, batch_id::I) where {I <: Integer}
# We load the ML model with the parameters
U = simulation.model.machine_learning.NN_f(θ)
U = NN(simulation.model.machine_learning.architecture, simulation.model.machine_learning.st, convert(typeof(simulation.model.machine_learning.θ),θ))
# We generate the ML parametrization based on the target
if simulation.parameters.UDE.target == "A"
A = predict_A̅(U, [mean(simulation.glaciers[batch_id].climate.longterm_temps)])[1]
Expand All @@ -244,7 +242,7 @@ callback_plots_A = function (θ, l, simulation) # callback function to observe t
p = sortperm(avg_temps)
avg_temps = avg_temps[p]
# We load the ML model with the parameters
U = simulation.model.machine_learning.NN_f(θ)
U = NN(simulation.model.machine_learning.architecture, simulation.model.machine_learning.st, convert(typeof(simulation.model.machine_learning.θ),θ))
pred_A = predict_A̅(U, collect(-23.0:1.0:0.0)')
pred_A = Float64[pred_A...] # flatten
true_A = A_fake(avg_temps, true)
Expand Down