Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: create initialization systems for all problem types #3253

Merged
merged 18 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
219aee3
refactor: add guesses to `SDESystem`, `NonlinearSystem`, `JumpSystem`
AayushSabharwal Dec 1, 2024
af8cd67
feat: support arbitrary systems in `generate_initializesystem`
AayushSabharwal Dec 1, 2024
3928194
refactor: use `initialization_data` in SciMLFunction constructors
AayushSabharwal Dec 2, 2024
4044317
fix: don't build initializeprob for initializeprob
AayushSabharwal Dec 2, 2024
180b978
feat: build initialization system for all system types in `process_Sc…
AayushSabharwal Dec 2, 2024
c9c613f
fix: retain system data on `structural_simplify` of `SDESystem`
AayushSabharwal Dec 2, 2024
4d5daa3
fix: pass `t` to `process_SciMLProblem` in `SDEProblem`
AayushSabharwal Dec 3, 2024
8d09409
feat: support arbitrary systems in `remake_initialization_data`
AayushSabharwal Dec 3, 2024
def207b
fix: fix type promotion bug in `remake_buffer`
AayushSabharwal Dec 3, 2024
d971b18
test: test initialization on `SDEProblem`, `DDEProblem`, `SDDEProblem`
AayushSabharwal Dec 3, 2024
39bb59c
fix: handle integer `u0` in `DDEProblem`
AayushSabharwal Dec 6, 2024
2f2e625
feat: enable creating `InitializationProblem` for non-`AbstractODESys…
AayushSabharwal Dec 6, 2024
671b93f
fix: filter kwargs in `SDEProblem`
AayushSabharwal Dec 6, 2024
9987da0
test: test initialization on `NonlinearProblem` and `NonlinearLeastSq…
AayushSabharwal Dec 6, 2024
51eeeeb
fix: store and propagate `initialization_eqs` provided to Problem
AayushSabharwal Dec 6, 2024
96f8d5d
build: bump compats
AayushSabharwal Dec 14, 2024
0a881e7
fix: better handle reconstructing initializeprob with new types
AayushSabharwal Dec 16, 2024
2e07200
test: fix incorrect initial values in tests
AayushSabharwal Dec 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ ConstructionBase = "1"
DataInterpolations = "6.4"
DataStructures = "0.17, 0.18"
DeepDiffs = "1"
DelayDiffEq = "5.50"
DiffEqBase = "6.157"
DiffEqCallbacks = "2.16, 3, 4"
DiffEqNoiseProcess = "5"
Expand Down Expand Up @@ -117,7 +118,7 @@ Libdl = "1"
LinearAlgebra = "1"
MLStyle = "0.4.17"
NaNMath = "0.3, 1"
NonlinearSolve = "3.14, 4"
NonlinearSolve = "4.3"
OffsetArrays = "1"
OrderedCollections = "1"
OrdinaryDiffEq = "6.82.0"
Expand All @@ -129,15 +130,17 @@ RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SCCNonlinearSolve = "1.0.0"
SciMLBase = "2.66"
SciMLBase = "2.68.1"
SciMLStructures = "1.0"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
SimpleNonlinearSolve = "0.1.0, 1, 2"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.3.35"
StochasticDiffEq = "6.72.1"
StochasticDelayDiffEq = "1.8.1"
SymbolicIndexingInterface = "0.3.36"
SymbolicUtils = "3.7"
Symbolics = "6.19"
URIs = "1"
Expand Down
61 changes: 28 additions & 33 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,10 +359,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
sparsity = false,
analytic = nothing,
split_idxs = nothing,
initializeprob = nothing,
update_initializeprob! = nothing,
initializeprobmap = nothing,
initializeprobpmap = nothing,
initialization_data = nothing,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
Expand Down Expand Up @@ -463,10 +460,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
observed = observedfun,
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
analytic = analytic,
initializeprob = initializeprob,
update_initializeprob! = update_initializeprob!,
initializeprobmap = initializeprobmap,
initializeprobpmap = initializeprobpmap)
initialization_data)
end

"""
Expand Down Expand Up @@ -496,10 +490,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
sparse = false, simplify = false,
eval_module = @__MODULE__,
checkbounds = false,
initializeprob = nothing,
initializeprobmap = nothing,
initializeprobpmap = nothing,
update_initializeprob! = nothing,
initialization_data = nothing,
kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
Expand Down Expand Up @@ -547,15 +538,12 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
nothing
end

DAEFunction{iip}(f,
DAEFunction{iip}(f;
sys = sys,
jac = _jac === nothing ? nothing : _jac,
jac_prototype = jac_prototype,
observed = observedfun,
initializeprob = initializeprob,
initializeprobmap = initializeprobmap,
initializeprobpmap = initializeprobpmap,
update_initializeprob! = update_initializeprob!)
initialization_data)
end

function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
Expand All @@ -567,6 +555,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
eval_expression = false,
eval_module = @__MODULE__,
checkbounds = false,
initialization_data = nothing,
kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `DDEFunction`")
Expand All @@ -579,7 +568,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
f(u, h, p, t) = f_oop(u, h, p, t)
f(du, u, h, p, t) = f_iip(du, u, h, p, t)

DDEFunction{iip}(f, sys = sys)
DDEFunction{iip}(f; sys = sys, initialization_data)
end

function DiffEqBase.SDDEFunction(sys::AbstractODESystem, args...; kwargs...)
Expand All @@ -591,6 +580,7 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
eval_expression = false,
eval_module = @__MODULE__,
checkbounds = false,
initialization_data = nothing,
kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `SDDEFunction`")
Expand All @@ -609,7 +599,7 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
g(u, h, p, t) = g_oop(u, h, p, t)
g(du, u, h, p, t) = g_iip(du, u, h, p, t)

SDDEFunction{iip}(f, g, sys = sys)
SDDEFunction{iip}(f, g; sys = sys, initialization_data)
end

"""
Expand Down Expand Up @@ -933,7 +923,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module)
h(p, t) = h_oop(p, t)
h(p::MTKParameters, t) = h_oop(p..., t)
u0 = h(p, tspan[1])
u0 = float.(h(p, tspan[1]))
if u0 !== nothing
u0 = u0_constructor(u0)
end
Expand Down Expand Up @@ -1257,23 +1247,23 @@ Generates a NonlinearProblem or NonlinearLeastSquaresProblem from an ODESystem
which represents the initialization, i.e. the calculation of the consistent
initial conditions for the given DAE.
"""
function InitializationProblem(sys::AbstractODESystem, args...; kwargs...)
function InitializationProblem(sys::AbstractSystem, args...; kwargs...)
InitializationProblem{true}(sys, args...; kwargs...)
end

function InitializationProblem(sys::AbstractODESystem, t,
function InitializationProblem(sys::AbstractSystem, t,
u0map::StaticArray,
args...;
kwargs...)
InitializationProblem{false, SciMLBase.FullSpecialize}(
sys, t, u0map, args...; kwargs...)
end

function InitializationProblem{true}(sys::AbstractODESystem, args...; kwargs...)
function InitializationProblem{true}(sys::AbstractSystem, args...; kwargs...)
InitializationProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end

function InitializationProblem{false}(sys::AbstractODESystem, args...; kwargs...)
function InitializationProblem{false}(sys::AbstractSystem, args...; kwargs...)
InitializationProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

Expand All @@ -1292,8 +1282,8 @@ function Base.showerror(io::IO, e::IncompleteInitializationError)
println(io, e.uninit)
end

function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
t::Number, u0map = [],
function InitializationProblem{iip, specialize}(sys::AbstractSystem,
t, u0map = [],
parammap = DiffEqBase.NullParameters();
guesses = [],
check_length = true,
Expand All @@ -1320,6 +1310,11 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
pmap = parammap, guesses, extra_metadata = (; use_scc)); fully_determined)
end

meta = get_metadata(isys)
if meta isa InitializationSystemMetadata
@set! isys.metadata.oop_reconstruct_u0_p = ReconstructInitializeprob(sys, isys)
end

ts = get_tearing_state(isys)
unassigned_vars = StructuralTransformations.singular_check(ts)
if warn_initialize_determined && !isempty(unassigned_vars)
Expand Down Expand Up @@ -1357,13 +1352,13 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
@warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. $(scc_message)To suppress this warning pass warn_initialize_determined = false. To make this warning into an error, pass fully_determined = true"
end

parammap = parammap isa DiffEqBase.NullParameters || isempty(parammap) ?
[get_iv(sys) => t] :
merge(todict(parammap), Dict(get_iv(sys) => t))
parammap = Dict(k => v for (k, v) in parammap if v !== missing)
if isempty(u0map)
u0map = Dict()
parammap = recursive_unwrap(anydict(parammap))
if t !== nothing
parammap[get_iv(sys)] = t
end
filter!(kvp -> kvp[2] !== missing, parammap)

u0map = to_varmap(u0map, unknowns(sys))
if isempty(guesses)
guesses = Dict()
end
Expand Down Expand Up @@ -1405,5 +1400,5 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
else
NonlinearLeastSquaresProblem
end
TProb(isys, u0map, parammap; kwargs...)
TProb(isys, u0map, parammap; kwargs..., build_initializeprob = false)
end
27 changes: 7 additions & 20 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,29 +256,16 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
:ODESystem, force = true)
end
defaults = Dict{Any, Any}(todict(defaults))
guesses = Dict{Any, Any}(todict(guesses))
var_to_name = Dict()
process_variables!(var_to_name, defaults, dvs′)
process_variables!(var_to_name, defaults, ps′)
process_variables!(var_to_name, defaults, [eq.lhs for eq in parameter_dependencies])
process_variables!(var_to_name, defaults, [eq.rhs for eq in parameter_dependencies])
process_variables!(var_to_name, defaults, guesses, dvs′)
process_variables!(var_to_name, defaults, guesses, ps′)
process_variables!(
var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies])
process_variables!(
var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies])
defaults = Dict{Any, Any}(value(k) => value(v)
for (k, v) in pairs(defaults) if v !== nothing)

sysdvsguesses = [ModelingToolkit.getguess(st) for st in dvs′]
hasaguess = findall(!isnothing, sysdvsguesses)
var_guesses = dvs′[hasaguess] .=> sysdvsguesses[hasaguess]
sysdvsguesses = isempty(var_guesses) ? Dict() : todict(var_guesses)
syspsguesses = [ModelingToolkit.getguess(st) for st in ps′]
hasaguess = findall(!isnothing, syspsguesses)
ps_guesses = ps′[hasaguess] .=> syspsguesses[hasaguess]
syspsguesses = isempty(ps_guesses) ? Dict() : todict(ps_guesses)
syspdepguesses = [ModelingToolkit.getguess(eq.lhs) for eq in parameter_dependencies]
hasaguess = findall(!isnothing, syspdepguesses)
pdep_guesses = [eq.lhs for eq in parameter_dependencies][hasaguess] .=>
syspdepguesses[hasaguess]
syspdepguesses = isempty(pdep_guesses) ? Dict() : todict(pdep_guesses)

guesses = merge(sysdvsguesses, syspsguesses, syspdepguesses, todict(guesses))
guesses = Dict{Any, Any}(value(k) => value(v)
for (k, v) in pairs(guesses) if v !== nothing)

Expand Down
62 changes: 44 additions & 18 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,19 @@ struct SDESystem <: AbstractODESystem
"""
defaults::Dict
"""
The guesses to use as the initial conditions for the
initialization system.
"""
guesses::Dict
"""
The system for performing the initialization.
"""
initializesystem::Union{Nothing, NonlinearSystem}
"""
Extra equations to be enforced during the initialization sequence.
"""
initialization_eqs::Vector{Equation}
"""
Type of the system.
"""
connector_type::Any
Expand Down Expand Up @@ -144,9 +157,8 @@ struct SDESystem <: AbstractODESystem
isscheduled::Bool

function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
tgrad,
jac,
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, connector_type,
tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults,
guesses, initializesystem, initialization_eqs, connector_type,
cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing,
complete = false, index_cache = nothing, parent = nothing, is_scalar_noise = false,
is_dde = false,
Expand All @@ -171,9 +183,9 @@ struct SDESystem <: AbstractODESystem
check_units(u, deqs, neqs)
end
new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
ctrl_jac,
Wfact, Wfact_t, name, description, systems,
defaults, connector_type, cevents, devents,
ctrl_jac, Wfact, Wfact_t, name, description, systems,
defaults, guesses, initializesystem, initialization_eqs, connector_type, cevents,
devents,
parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise,
is_dde, isscheduled)
end
Expand All @@ -187,6 +199,9 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
default_u0 = Dict(),
default_p = Dict(),
defaults = _merge(Dict(default_u0), Dict(default_p)),
guesses = Dict(),
initializesystem = nothing,
initialization_eqs = Equation[],
name = nothing,
description = "",
connector_type = nothing,
Expand All @@ -207,6 +222,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
dvs′ = value.(dvs)
ps′ = value.(ps)
ctrl′ = value.(controls)
parameter_dependencies, ps′ = process_parameter_dependencies(
parameter_dependencies, ps′)

sysnames = nameof.(systems)
if length(unique(sysnames)) != length(sysnames)
Expand All @@ -217,13 +234,21 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
:SDESystem, force = true)
end
defaults = todict(defaults)
defaults = Dict(value(k) => value(v)
for (k, v) in pairs(defaults) if value(v) !== nothing)

defaults = Dict{Any, Any}(todict(defaults))
guesses = Dict{Any, Any}(todict(guesses))
var_to_name = Dict()
process_variables!(var_to_name, defaults, dvs′)
process_variables!(var_to_name, defaults, ps′)
process_variables!(var_to_name, defaults, guesses, dvs′)
process_variables!(var_to_name, defaults, guesses, ps′)
process_variables!(
var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies])
process_variables!(
var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies])
defaults = Dict{Any, Any}(value(k) => value(v)
for (k, v) in pairs(defaults) if v !== nothing)
guesses = Dict{Any, Any}(value(k) => value(v)
for (k, v) in pairs(guesses) if v !== nothing)

isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))

tgrad = RefValue(EMPTY_TGRAD)
Expand All @@ -233,14 +258,13 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
Wfact_t = RefValue(EMPTY_JAC)
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
parameter_dependencies, ps′ = process_parameter_dependencies(
parameter_dependencies, ps′)
if is_dde === nothing
is_dde = _check_if_dde(deqs, iv′, systems)
end
SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, connector_type,
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
initializesystem, initialization_eqs, connector_type,
cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata,
complete, index_cache, parent, is_scalar_noise, is_dde; checks = checks)
end
Expand Down Expand Up @@ -520,7 +544,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
version = nothing, tgrad = false, sparse = false,
jac = false, Wfact = false, eval_expression = false,
eval_module = @__MODULE__,
checkbounds = false,
checkbounds = false, initialization_data = nothing,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`")
Expand Down Expand Up @@ -591,13 +615,13 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)

SDEFunction{iip, specialize}(f, g,
SDEFunction{iip, specialize}(f, g;
sys = sys,
jac = _jac === nothing ? nothing : _jac,
tgrad = _tgrad === nothing ? nothing : _tgrad,
Wfact = _Wfact === nothing ? nothing : _Wfact,
Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t,
mass_matrix = _M,
mass_matrix = _M, initialization_data,
observed = observedfun)
end

Expand Down Expand Up @@ -714,7 +738,7 @@ function DiffEqBase.SDEProblem{iip, specialize}(
end
f, u0, p = process_SciMLProblem(
SDEFunction{iip, specialize}, sys, u0map, parammap; check_length,
kwargs...)
t = tspan === nothing ? nothing : tspan[1], kwargs...)
cbs = process_events(sys; callback, kwargs...)
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))

Expand All @@ -736,6 +760,8 @@ function DiffEqBase.SDEProblem{iip, specialize}(
noise = nothing
end

kwargs = filter_kwargs(kwargs)

SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
noise_rate_prototype = noise_rate_prototype, kwargs...)
end
Expand Down
Loading
Loading