Skip to content

Commit

Permalink
add EKP adaptive timesteppers
Browse files Browse the repository at this point in the history
  • Loading branch information
costachris committed Oct 31, 2023
1 parent 1838a32 commit b018e2e
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 32 deletions.
6 changes: 6 additions & 0 deletions src/Diagnostics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ Elements:
- `mse_full_var` :: Variance estimate of MSE(`g_full`, `y_full`), empirical (EKI/EKS) or quadrature (UKI).
- `mse_full_nn_mean` :: MSE(`g_full`, `y_full`) of particle closest to the mean in parameter space. The mean in parameter space is the solution to the particle-based inversion.
- `failures` :: Number of particle failures per iteration. If the calibration is run with the "high_loss" failure handler, this diagnostic will not capture the failures due to parameter mapping.
- `timestep` :: EKP timestep in current iteration.
- `nn_mean_index` :: Particle index of the nearest neighbor to the ensemble mean in parameter space. This index is used to construct `..._nn_mean` metrics.
"""
function io_dictionary_metrics()
Expand All @@ -315,6 +316,7 @@ function io_dictionary_metrics()
"mse_full_var" => (; dims = ("iteration",), group = "metrics", type = Float64),
"mse_full_nn_mean" => (; dims = ("iteration",), group = "metrics", type = Float64),
"failures" => (; dims = ("iteration",), group = "metrics", type = Int16),
"timestep" => (; dims = ("iteration",), group = "metrics", type = Float64),
"nn_mean_index" => (; dims = ("iteration",), group = "metrics", type = Int16),
)
return io_dict
Expand All @@ -340,6 +342,9 @@ function io_dictionary_metrics(ekp::EnsembleKalmanProcess, mse_full::Vector{FT})
# Get loss at nearest_to_mean point
loss_nn_mean = loss[nn_mean]

# get timestep in latest iteration
timestep = deepcopy(ekp.Δt[end])

# Filter NaNs
loss_filt = filter(!isnan, loss)
mse_filt = filter(!isnan, mse_full)
Expand All @@ -357,6 +362,7 @@ function io_dictionary_metrics(ekp::EnsembleKalmanProcess, mse_full::Vector{FT})
"mse_full_var" => Base.setindex(orig_dict["mse_full_var"], mse_full_var, :field),
"mse_full_nn_mean" => Base.setindex(orig_dict["mse_full_nn_mean"], mse_full_nn_mean, :field),
"failures" => Base.setindex(orig_dict["failures"], failures, :field),
"timestep" => Base.setindex(orig_dict["timestep"], timestep, :field),
"nn_mean_index" => Base.setindex(orig_dict["nn_mean_index"], nn_mean, :field),
)
return io_dict
Expand Down
30 changes: 27 additions & 3 deletions src/KalmanProcessUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Utils for the construction and handling of Kalman Process structs.
module KalmanProcessUtils

export generate_ekp,
generate_tekp, get_sparse_indices, get_regularized_indices, get_Δt, PiecewiseConstantDecay, PiecewiseConstantGrowth
generate_tekp, get_sparse_indices, get_regularized_indices, get_Δt, modify_field, update_scheduler!, PiecewiseConstantDecay, PiecewiseConstantGrowth

using LinearAlgebra
using Statistics
Expand Down Expand Up @@ -86,6 +86,7 @@ get_Δt(lrs::PiecewiseConstantGrowth, iteration::IT) where {IT <: Int} = lrs.Δt
localizer::LocalizationMethod = NoLocalization(),
outdir_path::String = pwd(),
to_file::Bool = true,
verbose::Bool = true,
) where {T}
Generates, and possibly writes to file, an EnsembleKalmanProcess
Expand All @@ -99,6 +100,7 @@ Inputs:
- localizer :: Covariance localization method.
- outdir_path :: Output path.
- to_file :: Whether to write the serialized prior to a JLD2 file.
- verbose :: Whether to use verbose EKP object
Output:
- The generated EnsembleKalmanProcess.
Expand All @@ -109,8 +111,10 @@ function generate_ekp(
u::Union{Matrix{T}, T} = nothing;
failure_handler::String = "ignore_failures",
localizer::LocalizationMethod = NoLocalization(),
scheduler = DefaultScheduler(),
outdir_path::String = pwd(),
to_file::Bool = true,
verbose::Bool = true,
) where {T}

@assert isa(process, Unscented) || !isnothing(u) "Incorrect EKP constructor."
Expand All @@ -121,10 +125,12 @@ function generate_ekp(
fh = IgnoreFailures()
end

kwargs = Dict(:failure_handler_method => fh, :localization_method => localizer, :verbose => true)
kwargs = Dict(:failure_handler_method => fh, :localization_method => localizer, :verbose => verbose, :scheduler => scheduler)
ekp =
isnothing(u) ? EnsembleKalmanProcess(ref_stats.y, ref_stats.Γ, process; kwargs...) :
EnsembleKalmanProcess(u, ref_stats.y, ref_stats.Γ, process; kwargs...)


if to_file
jldsave(ekobj_path(outdir_path, 1); ekp)
end
Expand All @@ -142,6 +148,7 @@ end
localizer::LocalizationMethod = NoLocalization(),
outdir_path::String = pwd(),
to_file::Bool = true,
verbose::Bool = true,
) where {T, R}
Generates, and possibly writes to file, a Tikhonov EnsembleKalmanProcess
Expand All @@ -164,6 +171,7 @@ Inputs:
- localizer :: Covariance localization method.
- outdir_path :: Output path.
- to_file :: Whether to write the serialized prior to a JLD2 file.
- verbose :: Whether to use verbose EKP object
Output:
- The generated augmented EnsembleKalmanProcess.
Expand All @@ -176,8 +184,10 @@ function generate_tekp(
l2_reg::Union{Dict{String, Vector{R}}, R} = nothing,
failure_handler::String = "ignore_failures",
localizer::LocalizationMethod = NoLocalization(),
scheduler = DefaultScheduler(),
outdir_path::String = pwd(),
to_file::Bool = true,
verbose::Bool = true,
) where {T, R}

@assert isa(process, Unscented) || !isnothing(u) "Incorrect TEKP constructor."
Expand Down Expand Up @@ -219,7 +229,7 @@ function generate_tekp(
Γ_aug_list = [ref_stats.Γ, Array(Γ_θ)]
Γ_aug = cat(Γ_aug_list..., dims = (1, 2))

kwargs = Dict(:failure_handler_method => fh, :localization_method => localizer, :verbose => true)
kwargs = Dict(:failure_handler_method => fh, :localization_method => localizer, :verbose => verbose, :scheduler => scheduler)
ekp =
isnothing(u) ? EnsembleKalmanProcess(y_aug, Γ_aug, process; kwargs...) :
EnsembleKalmanProcess(u, y_aug, Γ_aug, process; kwargs...)
Expand All @@ -244,5 +254,19 @@ end
"Returns the indices of parameters to be regularized, given the l2 regularization configuration dictionary."
function get_regularized_indices(l2_config::Dict)
    # Delegates to `flat_dict_keys_where`, passing `above_eps` as the selection
    # criterion (presumably a predicate on the l2 regularization value — confirm
    # against its definition).
    return flat_dict_keys_where(l2_config, above_eps)
end

"Return new EKP object with `field_name` overridden by `new_value`"
function modify_field(ekp::EnsembleKalmanProcess, field_name::Symbol, new_value)
    # Rebuild `ekp` field by field, substituting `new_value` for `field_name`.
    # `ekp` itself is not mutated; every other field value is handed to the
    # constructor as-is (not copied), so the result shares those values with `ekp`.
    #
    # Throws `ArgumentError` if `field_name` is not a field of `EnsembleKalmanProcess`.
    fields = fieldnames(EnsembleKalmanProcess)

    # Fail loudly on a misspelled field name: without this check nothing below
    # would match and an unmodified clone of `ekp` would be returned silently.
    if !(field_name in fields)
        throw(
            ArgumentError(
                "EnsembleKalmanProcess has no field `$(field_name)`; valid fields are $(fields).",
            ),
        )
    end

    # `map` over the tuple of field names yields a tuple, so the splat below does
    # not go through an intermediate `Vector{Any}`.
    field_values = map(f -> f === field_name ? new_value : getfield(ekp, f), fields)

    # Relies on a positional constructor accepting all fields in declaration order.
    return EnsembleKalmanProcess(field_values...)
end

"Update inv_sqrt_noise of EKP object when using DMC"
function update_scheduler!(ekp, iteration)
    # Nothing to do unless the scheduler is a `DataMisfitController` and we are
    # past the first iteration.
    sched = ekp.scheduler
    if !(sched isa DataMisfitController && iteration > 1)
        return nothing
    end

    # Append inv(sqrt(.)) of the observation noise covariance, after it has been
    # passed through `posdef_correct`, to the scheduler's `inv_sqrt_noise` list.
    Γ_corrected = posdef_correct(ekp.obs_noise_cov)
    return push!(sched.inv_sqrt_noise, inv(sqrt(Γ_corrected)))
end

end # module
40 changes: 31 additions & 9 deletions src/Pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ function init_calibration(config::Dict{Any, Any}; mode::String = "hpc", job_id::
Δt_scheduler = get_entry(proc_config, "Δt", 1.0)
Δt = get_Δt(Δt_scheduler, 1)

scheduler = get_entry(proc_config, "scheduler", nothing)

if !isnothing(scheduler)
Δt = nothing
end

augmented = get_entry(proc_config, "augmented", false)
failure_handler = get_entry(proc_config, "failure_handler", "high_loss")
localizer = get_entry(proc_config, "localizer", NoLocalization())
Expand Down Expand Up @@ -99,7 +105,6 @@ function init_calibration(config::Dict{Any, Any}; mode::String = "hpc", job_id::
ref_stats,
outdir_root,
algo_name,
Δt,
n_param,
N_ens,
N_iter,
Expand All @@ -114,7 +119,7 @@ function init_calibration(config::Dict{Any, Any}; mode::String = "hpc", job_id::
prior_μ = nothing
end
priors = construct_priors(params, outdir_path = outdir_path, unconstrained_σ = unc_σ, prior_mean = prior_μ)
ekp_kwargs = Dict(:outdir_path => outdir_path, :failure_handler => failure_handler, :localizer => localizer)
ekp_kwargs = Dict(:outdir_path => outdir_path, :failure_handler => failure_handler, :localizer => localizer, :verbose => true, :scheduler => scheduler)
# parameters are sampled in unconstrained space
if algo_name in ["Inversion", "Sampler", "SparseInversion"]
if algo_name == "Inversion"
Expand Down Expand Up @@ -197,7 +202,6 @@ function create_output_dir(
ref_stats::ReferenceStatistics,
outdir_root::String,
algo_name::String,
Δt::FT,
n_param::IT,
N_ens::IT,
N_iter::IT,
Expand All @@ -211,7 +215,7 @@ function create_output_dir(
suffix = randstring(3) # ensure output folder is unique
outdir_path = joinpath(
outdir_root,
"results_$(algo_name)_dt_$(Δt)_p$(n_param)_e$(N_ens)_i$(N_iter)_$(d)_$(typeof(y_ref_type))_$(now)_$(suffix)",
"results_$(algo_name)_p$(n_param)_e$(N_ens)_i$(N_iter)_$(d)_$(typeof(y_ref_type))_$(now)_$(suffix)",
)
@info "Name of outdir path for this EKP is: $outdir_path"
mkpath(outdir_path)
Expand Down Expand Up @@ -365,6 +369,12 @@ function ek_update(
Δt_scheduler = get_entry(proc_config, "Δt", 1.0)
Δt = get_Δt(Δt_scheduler, iteration)

scheduler = ekobj.scheduler

if !isnothing(scheduler)
Δt = nothing
end

deterministic_forward_map = get_entry(proc_config, "noisy_obs", false)
augmented = get_entry(proc_config, "augmented", false)
param_map = get_entry(config["prior"], "param_map", HelperFuncs.do_nothing_param_map()) # do-nothing param map by default
Expand Down Expand Up @@ -413,10 +423,18 @@ function ek_update(
update_minibatch_inverse_problem(ref_model_batch, ekobj, priors, batch_size, outdir_path, config)
rm(joinpath(outdir_path, "ref_model_batch.jld2"))
write_ref_model_batch(ref_model_batch, outdir_path = outdir_path)
# if ekp-native scheduler used, keep full Δt history
# this is needed because `update_minibatch_inverse_problem` creates a new ekp object and erases Δt history
if !isnothing(scheduler)
ekp = modify_field(ekp, :Δt, deepcopy(ekobj.Δt))
end

else
ekp = ekobj
end

update_scheduler!(ekp, iteration)

# Write to file new EKP and ModelEvaluators
jldsave(ekobj_path(outdir_path, iteration + 1); ekp)
write_model_evaluators(ekp, priors, param_map, ref_models, ref_stats, outdir_path, iteration, batch_indices)
Expand All @@ -426,8 +444,13 @@ function ek_update(
reg_config = config["regularization"]
update_validation(val_config, reg_config, ekobj, priors, param_map, versions, outdir_path, iteration)
end
end

else
# If final iteration, update saved ekp object for current iteration
ekp = ekobj
update_scheduler!(ekp, iteration)
jldsave(ekobj_path(outdir_path, iteration); ekp)
end

# Clean up
for version in versions
Expand Down Expand Up @@ -482,7 +505,6 @@ function restart_calibration(

reg_config = config["regularization"]
kwargs_ref_stats = get_ref_stats_kwargs(ref_config, reg_config)

val_config = get(config, "validation", nothing)

# Prepare updated EKP and ReferenceModelBatch if minibatching.
Expand Down Expand Up @@ -766,14 +788,14 @@ function update_minibatch_inverse_problem(

augmented = get_entry(proc_config, "augmented", false)
failure_handler = get_entry(proc_config, "failure_handler", "high_loss")

scheduler = deepcopy(ekp_old.scheduler)
localizer = get_entry(proc_config, "localizer", NoLocalization())
l2_reg = get_entry(reg_config, "l2_reg", nothing)
kwargs_ref_stats = get_ref_stats_kwargs(ref_config, reg_config)
ref_stats = ReferenceStatistics(ref_models; kwargs_ref_stats...)
process = ekp_old.process

ekp_kwargs = Dict(:outdir_path => outdir_path, :failure_handler => failure_handler, :localizer => localizer)

ekp_kwargs = Dict(:outdir_path => outdir_path, :failure_handler => failure_handler, :localizer => localizer, :verbose => true, :scheduler => scheduler)
if isa(process, Unscented)
# Reconstruct UKI using regularization toward the prior
algo = Unscented(
Expand Down
2 changes: 1 addition & 1 deletion test/Pipeline/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ end

function get_process_config()
config = Dict()
config["N_iter"] = 2
config["N_iter"] = 5
config["N_ens"] = 5
config["algorithm"] = "Inversion" # "Sampler", "Unscented"
return config
Expand Down
Loading

0 comments on commit b018e2e

Please sign in to comment.