Skip to content

Commit

Permalink
Merge pull request #3324 from AayushSabharwal/as/scc-fix
Browse files Browse the repository at this point in the history
feat: support caching of different types of subexpressions in `SCCNonlinearProblem`
  • Loading branch information
ChrisRackauckas authored Jan 16, 2025
2 parents 083a639 + 6ce9e85 commit f7f0221
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 38 deletions.
3 changes: 2 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ import SCCNonlinearSolve
using Reexport
using RecursiveArrayTools
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
import BlockArrays: BlockedArray, Block, blocksize, blocksizes
import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, blockpush!,
undef_blocks, blocks
import CommonSolve
import EnumX

Expand Down
2 changes: 1 addition & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1919,7 +1919,7 @@ function Base.show(
nrows > 0 && hint && print(io, " see hierarchy($name)")
for i in 1:nrows
sub = subs[i]
name = String(nameof(sub))
local name = String(nameof(sub))
print(io, "\n ", name)
desc = description(sub)
if !isempty(desc)
Expand Down
118 changes: 91 additions & 27 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -573,29 +573,37 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...)
end

const TypeT = Union{DataType, UnionAll}

struct CacheWriter{F}
fn::F
end

function (cw::CacheWriter)(p, sols)
cw.fn(p.caches[1], sols, p...)
cw.fn(p.caches, sols, p...)
end

function CacheWriter(sys::AbstractSystem, exprs, solsyms, obseqs::Vector{Equation};
function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation};
eval_expression = false, eval_module = @__MODULE__)
ps = parameters(sys)
rps = reorder_parameters(sys, ps)
obs_assigns = [eq.lhs eq.rhs for eq in obseqs]
cmap, cs = get_cmap(sys)
cmap_assigns = [eq.lhs eq.rhs for eq in cmap]

outsyms = [Symbol(:out, i) for i in eachindex(buffer_types)]
body = map(eachindex(buffer_types), buffer_types) do i, T
Symbol(:tmp, i) SetArray(true, :(out[$i]), get(exprs, T, []))
end
fn = Func(
[:out, DestructuredArgs(DestructuredArgs.(solsyms)),
DestructuredArgs.(rps)...],
[],
SetArray(true, :out, exprs)
Let(body, :())
) |> wrap_assignments(false, obs_assigns)[2] |>
wrap_parameter_dependencies(sys, false)[2] |>
wrap_array_vars(sys, exprs; dvs = nothing, inputs = [])[2] |>
wrap_array_vars(sys, []; dvs = nothing, inputs = [])[2] |>
wrap_assignments(false, cmap_assigns)[2] |> toexpr
return CacheWriter(eval_or_rgf(fn; eval_expression, eval_module))
end
Expand Down Expand Up @@ -677,8 +685,17 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,

explicitfuns = []
nlfuns = []
prevobsidxs = Int[]
cachesize = 0
prevobsidxs = BlockArray(undef_blocks, Vector{Int}, Int[])
# Cache buffer types and corresponding sizes. Stored as a pair of arrays instead of a
# dict to maintain a consistent order of buffers across SCCs
cachetypes = TypeT[]
cachesizes = Int[]
# explicitfun! related information for each SCC
# We need to compute buffer sizes before doing any codegen
scc_cachevars = Dict{TypeT, Vector{Any}}[]
scc_cacheexprs = Dict{TypeT, Vector{Any}}[]
scc_eqs = Vector{Equation}[]
scc_obs = Vector{Equation}[]
for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
# subset unknowns and equations
_dvs = dvs[vscc]
Expand All @@ -690,11 +707,10 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
_obs = obs[obsidxs]

# get all subexpressions in the RHS which we can precompute in the cache
# precomputed subexpressions should not contain `banned_vars`
banned_vars = Set{Any}(vcat(_dvs, getproperty.(_obs, (:lhs,))))
for var in banned_vars
iscall(var) || continue
operation(var) === getindex || continue
push!(banned_vars, arguments(var)[1])
filter!(banned_vars) do var
symbolic_type(var) != ArraySymbolic() || all(x -> var[i] in banned_vars, eachindex(var))
end
state = Dict()
for i in eachindex(_obs)
Expand All @@ -706,37 +722,85 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
_eqs[i].rhs, banned_vars, state)
end

# cached variables and their corresponding expressions
cachevars = Any[obs[i].lhs for i in prevobsidxs]
cacheexprs = Any[obs[i].lhs for i in prevobsidxs]
# map from symtype to cached variables and their expressions
cachevars = Dict{Union{DataType, UnionAll}, Vector{Any}}()
cacheexprs = Dict{Union{DataType, UnionAll}, Vector{Any}}()
# observed of previous SCCs are in the cache
# NOTE: When we get proper CSE, we can substitute these
# and then use `subexpressions_not_involving_vars!`
for i in prevobsidxs
T = symtype(obs[i].lhs)
buf = get!(() -> Any[], cachevars, T)
push!(buf, obs[i].lhs)

buf = get!(() -> Any[], cacheexprs, T)
push!(buf, obs[i].lhs)
end

for (k, v) in state
push!(cachevars, unwrap(v))
push!(cacheexprs, unwrap(k))
k = unwrap(k)
v = unwrap(v)
T = symtype(k)
buf = get!(() -> Any[], cachevars, T)
push!(buf, v)
buf = get!(() -> Any[], cacheexprs, T)
push!(buf, k)
end
cachesize = max(cachesize, length(cachevars))

# update the sizes of cache buffers
for (T, buf) in cachevars
idx = findfirst(isequal(T), cachetypes)
if idx === nothing
push!(cachetypes, T)
push!(cachesizes, 0)
idx = lastindex(cachetypes)
end
cachesizes[idx] = max(cachesizes[idx], length(buf))
end

push!(scc_cachevars, cachevars)
push!(scc_cacheexprs, cacheexprs)
push!(scc_eqs, _eqs)
push!(scc_obs, _obs)
blockpush!(prevobsidxs, obsidxs)
end

for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
_dvs = dvs[vscc]
_eqs = scc_eqs[i]
_prevobsidxs = reduce(vcat, blocks(prevobsidxs)[1:(i - 1)]; init = Int[])
_obs = scc_obs[i]
cachevars = scc_cachevars[i]
cacheexprs = scc_cacheexprs[i]

if isempty(cachevars)
push!(explicitfuns, Returns(nothing))
else
solsyms = getindex.((dvs,), view(var_sccs, 1:(i - 1)))
push!(explicitfuns,
CacheWriter(sys, cacheexprs, solsyms, obs[prevobsidxs];
CacheWriter(sys, cachetypes, cacheexprs, solsyms, obs[_prevobsidxs];
eval_expression, eval_module))
end

cachebufsyms = Tuple(map(cachetypes) do T
get(cachevars, T, [])
end)
f = SCCNonlinearFunction{iip}(
sys, _eqs, _dvs, _obs, (cachevars,); eval_expression, eval_module, kwargs...)
sys, _eqs, _dvs, _obs, cachebufsyms; eval_expression, eval_module, kwargs...)
push!(nlfuns, f)
append!(cachevars, _dvs)
append!(cacheexprs, _dvs)
for i in obsidxs
push!(cachevars, obs[i].lhs)
push!(cacheexprs, obs[i].rhs)
end
append!(prevobsidxs, obsidxs)
end

if cachesize != 0
p = rebuild_with_caches(p, BufferTemplate(eltype(u0), cachesize))
if !isempty(cachetypes)
templates = map(cachetypes, cachesizes) do T, n
# Real refers to `eltype(u0)`
if T == Real
T = eltype(u0)
elseif T <: Array && eltype(T) == Real
T = Array{eltype(u0), ndims(T)}
end
BufferTemplate(T, n)
end
p = rebuild_with_caches(p, templates...)
end

subprobs = []
Expand Down
27 changes: 18 additions & 9 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1108,23 +1108,33 @@ returns the modified `expr`.
"""
function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any})
expr = unwrap(expr)
symbolic_type(expr) == NotSymbolic() && return expr
if symbolic_type(expr) == NotSymbolic()
if is_array_of_symbolics(expr)
return map(expr) do el
subexpressions_not_involving_vars!(el, vars, state)
end
end
return expr
end
any(isequal(expr), vars) && return expr
iscall(expr) || return expr
is_variable_floatingpoint(expr) || return expr
symtype(expr) <: Union{Real, AbstractArray{<:Real}} || return expr
Symbolics.shape(expr) == Symbolics.Unknown() && return expr
haskey(state, expr) && return state[expr]
vs = ModelingToolkit.vars(expr)
intersect!(vs, vars)
if isempty(vs)
op = operation(expr)
args = arguments(expr)
# if this is a `getindex` and the getindex-ed value is a `Sym`
# or it is not a called parameter
# OR
# none of `vars` are involved in `expr`
if op === getindex && (issym(args[1]) || !iscalledparameter(args[1])) ||
(vs = ModelingToolkit.vars(expr); intersect!(vs, vars); isempty(vs))
sym = gensym(:subexpr)
stype = symtype(expr)
var = similar_variable(expr, sym)
state[expr] = var
return var
end
op = operation(expr)
args = arguments(expr)

if (op == (+) || op == (*)) && symbolic_type(expr) !== ArraySymbolic()
indep_args = []
dep_args = []
Expand All @@ -1143,7 +1153,6 @@ function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any})
return op(indep_term, dep_term)
end
newargs = map(args) do arg
symbolic_type(arg) != NotSymbolic() || is_array_of_symbolics(arg) || return arg
subexpressions_not_involving_vars!(arg, vars, state)
end
return maketerm(typeof(expr), op, newargs, metadata(expr))
Expand Down
92 changes: 92 additions & 0 deletions test/scc_nonlinear_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,95 @@ end
@test SciMLBase.successful_retcode(sccsol)
@test val[] == 1
end

import ModelingToolkitStandardLibrary.Blocks as B
import ModelingToolkitStandardLibrary.Mechanical.Translational as T
import ModelingToolkitStandardLibrary.Hydraulic.IsothermalCompressible as IC

@testset "Caching of subexpressions of different types" begin
liquid_pressure(rho, rho_0, bulk) = (rho / rho_0 - 1) * bulk
gas_pressure(rho, rho_0, p_gas, rho_gas) = rho * ((0 - p_gas) / (rho_0 - rho_gas))
full_pressure(rho, rho_0, bulk, p_gas, rho_gas) = ifelse(
rho >= rho_0, liquid_pressure(rho, rho_0, bulk),
gas_pressure(rho, rho_0, p_gas, rho_gas))

@component function Volume(;
#parameters
area,
direction = +1,
x_int,
name)
pars = @parameters begin
area = area
x_int = x_int
rho_0 = 1000
bulk = 1e9
p_gas = -1000
rho_gas = 1
end

vars = @variables begin
x(t) = x_int
dx(t), [guess = 0]
p(t), [guess = 0]
f(t), [guess = 0]
rho(t), [guess = 0]
m(t), [guess = 0]
dm(t), [guess = 0]
end

systems = @named begin
port = IC.HydraulicPort()
flange = T.MechanicalPort()
end

eqs = [
# connectors
port.p ~ p
port.dm ~ dm
flange.v * direction ~ dx
flange.f * direction ~ -f

# differentials
D(x) ~ dx
D(m) ~ dm

# physics
p ~ full_pressure(rho, rho_0, bulk, p_gas, rho_gas)
f ~ p * area
m ~ rho * x * area]

return ODESystem(eqs, t, vars, pars; name, systems)
end

systems = @named begin
fluid = IC.HydraulicFluid(; bulk_modulus = 1e9)

src1 = IC.Pressure(;)
src2 = IC.Pressure(;)

vol1 = Volume(; area = 0.01, direction = +1, x_int = 0.1)
vol2 = Volume(; area = 0.01, direction = +1, x_int = 0.1)

mass = T.Mass(; m = 10)

sin1 = B.Sine(; frequency = 0.5, amplitude = +0.5e5, offset = 10e5)
sin2 = B.Sine(; frequency = 0.5, amplitude = -0.5e5, offset = 10e5)
end

eqs = [connect(fluid, src1.port)
connect(fluid, src2.port)
connect(src1.port, vol1.port)
connect(src2.port, vol2.port)
connect(vol1.flange, mass.flange, vol2.flange)
connect(src1.p, sin1.output)
connect(src2.p, sin2.output)]

initialization_eqs = [mass.s ~ 0.0
mass.v ~ 0.0]

@mtkbuild sys = ODESystem(eqs, t, [], []; systems, initialization_eqs)
prob = ODEProblem(sys, [], (0, 5))
sol = solve(prob)
@test SciMLBase.successful_retcode(sol)
end

0 comments on commit f7f0221

Please sign in to comment.