Skip to content

Commit

Permalink
reapply formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
ArnoStrouwen committed Apr 11, 2024
1 parent 063783b commit 98eb805
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 130 deletions.
2 changes: 2 additions & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
style = "sciml"
format_markdown = true
format_docstrings = true
17 changes: 8 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@ Pkg.add("SymbolicIndexingInterface")

The symbolic indexing interface has 2 levels:

1. The user level. At the user level, a modeler or engineer simply uses terms from a
domain-specific language (DSL) inside of SciML functionality and will receive the requested
values. For example, if a DSL defines a symbol `x`, then `sol[x]` returns the solution
value(s) for `x`.
2. The DSL system structure level. This is the structure which defines the symbolic indexing
for a given problem/solution. DSLs can tag a constructed problem/solution with this
object in order to endow the SciML tools with the ability to index symbolically according
to the definitions the DSL writer wants.

1. The user level. At the user level, a modeler or engineer simply uses terms from a
domain-specific language (DSL) inside of SciML functionality and will receive the requested
values. For example, if a DSL defines a symbol `x`, then `sol[x]` returns the solution
value(s) for `x`.
2. The DSL system structure level. This is the structure which defines the symbolic indexing
for a given problem/solution. DSLs can tag a constructed problem/solution with this
object in order to endow the SciML tools with the ability to index symbolically according
to the definitions the DSL writer wants.

## Example

Expand Down
195 changes: 99 additions & 96 deletions docs/src/complete_sii.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ This tutorial will show how to define the entire Symbolic Indexing Interface on

```julia
struct ExampleSystem
state_index::Dict{Symbol,Int}
parameter_index::Dict{Symbol,Int}
independent_variable::Union{Symbol,Nothing}
defaults::Dict{Symbol, Float64}
# mapping from observed variable to Expr to calculate its value
observed::Dict{Symbol,Expr}
state_index::Dict{Symbol, Int}
parameter_index::Dict{Symbol, Int}
independent_variable::Union{Symbol, Nothing}
# mapping from observed variable to Expr to calculate its value
observed::Dict{Symbol, Expr}
end
```

Expand All @@ -25,58 +24,58 @@ These are the simple functions which describe how to turn symbols into indices.

```julia
function SymbolicIndexingInterface.is_variable(sys::ExampleSystem, sym)
haskey(sys.state_index, sym)
haskey(sys.state_index, sym)
end

function SymbolicIndexingInterface.variable_index(sys::ExampleSystem, sym)
get(sys.state_index, sym, nothing)
get(sys.state_index, sym, nothing)
end

function SymbolicIndexingInterface.variable_symbols(sys::ExampleSystem)
collect(keys(sys.state_index))
collect(keys(sys.state_index))
end

function SymbolicIndexingInterface.is_parameter(sys::ExampleSystem, sym)
haskey(sys.parameter_index, sym)
haskey(sys.parameter_index, sym)
end

function SymbolicIndexingInterface.parameter_index(sys::ExampleSystem, sym)
get(sys.parameter_index, sym, nothing)
get(sys.parameter_index, sym, nothing)
end

function SymbolicIndexingInterface.parameter_symbols(sys::ExampleSystem)
collect(keys(sys.parameter_index))
collect(keys(sys.parameter_index))
end

function SymbolicIndexingInterface.is_independent_variable(sys::ExampleSystem, sym)
# note we have to check separately for `nothing`, otherwise
# `is_independent_variable(p, nothing)` would return `true`.
sys.independent_variable !== nothing && sym === sys.independent_variable
# note we have to check separately for `nothing`, otherwise
# `is_independent_variable(p, nothing)` would return `true`.
sys.independent_variable !== nothing && sym === sys.independent_variable
end

function SymbolicIndexingInterface.independent_variable_symbols(sys::ExampleSystem)
sys.independent_variable === nothing ? [] : [sys.independent_variable]
sys.independent_variable === nothing ? [] : [sys.independent_variable]
end

function SymbolicIndexingInterface.is_time_dependent(sys::ExampleSystem)
sys.independent_variable !== nothing
sys.independent_variable !== nothing
end

SymbolicIndexingInterface.constant_structure(::ExampleSystem) = true

function SymbolicIndexingInterface.all_solvable_symbols(sys::ExampleSystem)
return vcat(
collect(keys(sys.state_index)),
collect(keys(sys.observed)),
)
return vcat(
collect(keys(sys.state_index)),
collect(keys(sys.observed))
)
end

function SymbolicIndexingInterface.all_symbols(sys::ExampleSystem)
return vcat(
all_solvable_symbols(sys),
collect(keys(sys.parameter_index)),
sys.independent_variable === nothing ? Symbol[] : sys.independent_variable
)
return vcat(
all_solvable_symbols(sys),
collect(keys(sys.parameter_index)),
sys.independent_variable === nothing ? Symbol[] : sys.independent_variable
)
end

function SymbolicIndexingInterface.default_values(sys::ExampleSystem)
Expand All @@ -95,36 +94,38 @@ RuntimeGeneratedFunctions.init(@__MODULE__)

# this type accepts `Expr` for observed expressions involving state/parameter/observed
# variables
SymbolicIndexingInterface.is_observed(sys::ExampleSystem, sym) = sym isa Expr || sym isa Symbol && haskey(sys.observed, sym)
function SymbolicIndexingInterface.is_observed(sys::ExampleSystem, sym)
sym isa Expr || sym isa Symbol && haskey(sys.observed, sym)
end

function SymbolicIndexingInterface.observed(sys::ExampleSystem, sym::Expr)
# generate a function with the appropriate signature
if is_time_dependent(sys)
fn_expr = :(
function gen(u, p, t)
# assign a variable for each state symbol it's value in u
$([:($var = u[$idx]) for (var, idx) in pairs(sys.state_index)]...)
# assign a variable for each parameter symbol it's value in p
$([:($var = p[$idx]) for (var, idx) in pairs(sys.parameter_index)]...)
# assign a variable for the independent variable
$(sys.independent_variable) = t
# return the value of the expression
return $sym
end
)
else
fn_expr = :(
function gen(u, p)
# assign a variable for each state symbol it's value in u
$([:($var = u[$idx]) for (var, idx) in pairs(sys.state_index)]...)
# assign a variable for each parameter symbol it's value in p
$([:($var = p[$idx]) for (var, idx) in pairs(sys.parameter_index)]...)
# return the value of the expression
return $sym
end
)
end
return @RuntimeGeneratedFunction(fn_expr)
# generate a function with the appropriate signature
if is_time_dependent(sys)
fn_expr = :(
function gen(u, p, t)
# assign a variable for each state symbol it's value in u
$([:($var = u[$idx]) for (var, idx) in pairs(sys.state_index)]...)
# assign a variable for each parameter symbol it's value in p
$([:($var = p[$idx]) for (var, idx) in pairs(sys.parameter_index)]...)
# assign a variable for the independent variable
$(sys.independent_variable) = t
# return the value of the expression
return $sym
end
)
else
fn_expr = :(
function gen(u, p)
# assign a variable for each state symbol it's value in u
$([:($var = u[$idx]) for (var, idx) in pairs(sys.state_index)]...)
# assign a variable for each parameter symbol it's value in p
$([:($var = p[$idx]) for (var, idx) in pairs(sys.parameter_index)]...)
# return the value of the expression
return $sym
end
)
end
return @RuntimeGeneratedFunction(fn_expr)
end
```

Expand All @@ -136,16 +137,17 @@ defined to always return `false`, and `observed` does not need to be implemented
Note that the method definitions are all assuming `constant_structure(p) == true`.

In case `constant_structure(p) == false`, the following methods would change:
- `constant_structure(::ExampleSystem) = false`
- `variable_index(sys::ExampleSystem, sym)` would become
`variable_index(sys::ExampleSystem, sym i)` where `i` is the time index at which
the index of `sym` is required.
- `variable_symbols(sys::ExampleSystem)` would become
`variable_symbols(sys::ExampleSystem, i)` where `i` is the time index at which
the variable symbols are required.
- `observed(sys::ExampleSystem, sym)` would become
`observed(sys::ExampleSystem, sym, i)` where `i` is either the time index at which
the index of `sym` is required or a `Vector` of state symbols at the current time index.

- `constant_structure(::ExampleSystem) = false`
- `variable_index(sys::ExampleSystem, sym)` would become
`variable_index(sys::ExampleSystem, sym i)` where `i` is the time index at which
the index of `sym` is required.
- `variable_symbols(sys::ExampleSystem)` would become
`variable_symbols(sys::ExampleSystem, i)` where `i` is the time index at which
the variable symbols are required.
- `observed(sys::ExampleSystem, sym)` would become
`observed(sys::ExampleSystem, sym, i)` where `i` is either the time index at which
the index of `sym` is required or a `Vector` of state symbols at the current time index.

## Optional methods

Expand All @@ -163,7 +165,7 @@ them is not necessary.

```julia
function SymbolicIndexingInterface.parameter_values(sys::ExampleSystem)
sys.p
sys.p
end
```

Expand All @@ -179,10 +181,10 @@ Consider the following `ExampleIntegrator`

```julia
mutable struct ExampleIntegrator
u::Vector{Float64}
p::Vector{Float64}
t::Float64
sys::ExampleSystem
u::Vector{Float64}
p::Vector{Float64}
t::Float64
sys::ExampleSystem
end

# define a fallback for the interface methods
Expand All @@ -193,6 +195,7 @@ SymbolicIndexingInterface.current_time(sys::ExampleIntegrator) = sys.t
```

Then the following example would work:

```julia
sys = ExampleSystem(Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t, Dict())
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, sys)
Expand All @@ -215,10 +218,10 @@ the [`Timeseries`](@ref) trait. The type would then return a timeseries from

```julia
struct ExampleSolution
u::Vector{Vector{Float64}}
t::Vector{Float64}
p::Vector{Float64}
sys::ExampleSystem
u::Vector{Vector{Float64}}
t::Vector{Float64}
p::Vector{Float64}
sys::ExampleSystem
end

# define a fallback for the interface methods
Expand All @@ -233,6 +236,7 @@ SymbolicIndexingInterface.current_time(sol::ExampleSolution) = sol.t
```

Then the following example would work:

```julia
# using the same system that the ExampleIntegrator used
sol = ExampleSolution([[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]], [4.0, 5.0], [6.0, 7.0], sys)
Expand Down Expand Up @@ -262,8 +266,8 @@ follows:

```julia
function SymbolicIndexingInterface.set_state!(integrator::ExampleIntegrator, val, idx)
integrator.u[idx] = val
integrator.u_modified = true
integrator.u[idx] = val
integrator.u_modified = true
end
```

Expand All @@ -279,24 +283,25 @@ performed for a bulk parameter update.
# The `ParameterIndexingProxy`

[`ParameterIndexingProxy`](@ref) is a wrapper around another type which implements the
interface and allows using [`getp`](@ref) and [`setp`](@ref) to get and set parameter
interface and allows using [`getp`](@ref) and [`setp`](@ref) to get and set parameter
values. This allows for a cleaner interface for parameter indexing. Consider the
following example for `ExampleIntegrator`:

```julia
function Base.getproperty(obj::ExampleIntegrator, sym::Symbol)
if sym === :ps
return ParameterIndexingProxy(obj)
else
return getfield(obj, sym)
end
if sym === :ps
return ParameterIndexingProxy(obj)
else
return getfield(obj, sym)
end
end
```

This enables the following API:

```julia
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t)
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0,
Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t)

integrator.ps[:a] # 4.0
getp(integrator, :a)(integrator) # functionally the same as above
Expand All @@ -310,25 +315,25 @@ setp(integrator, :b)(integrator, 3.0) # functionally the same as above
The `SymbolicTypeTrait` is used to identify values that can act as symbolic variables. It
has three variants:

- [`NotSymbolic`](@ref) for quantities that are not symbolic. This is the default for all
types.
- [`ScalarSymbolic`](@ref) for quantities that are symbolic, and represent a single
logical value.
- [`ArraySymbolic`](@ref) for quantities that are symbolic, and represent an array of
values. Types implementing this trait must return an array of `ScalarSymbolic` variables
of the appropriate size and dimensions when `collect`ed.
- [`NotSymbolic`](@ref) for quantities that are not symbolic. This is the default for all
types.
- [`ScalarSymbolic`](@ref) for quantities that are symbolic, and represent a single
logical value.
- [`ArraySymbolic`](@ref) for quantities that are symbolic, and represent an array of
values. Types implementing this trait must return an array of `ScalarSymbolic` variables
of the appropriate size and dimensions when `collect`ed.

The trait is implemented through the [`symbolic_type`](@ref) function. Consider the following
example types:

```julia
struct MySym
name::Symbol
name::Symbol
end

struct MySymArr{N}
name::Symbol
size::NTuple{N,Int}
name::Symbol
size::NTuple{N, Int}
end
```

Expand All @@ -343,10 +348,8 @@ SymbolicIndexingInterface.symbolic_type(::Type{<:MySymArr}) = ArraySymbolic()
SymbolicIndexingInterface.hasname(::MySymArr) = true
SymbolicIndexingInterface.getname(sym::MySymArr) = sym.name
function Base.collect(sym::MySymArr)
[
MySym(Symbol(sym.name, :_, join(idxs, "_")))
for idxs in Iterators.product(Base.OneTo.(sym.size)...)
]
[MySym(Symbol(sym.name, :_, join(idxs, "_")))
for idxs in Iterators.product(Base.OneTo.(sym.size)...)]
end
```

Expand Down
Loading

0 comments on commit 98eb805

Please sign in to comment.