Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow inbounds getters and setters #113

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,19 @@ apply:
parameter values, and can be accessed at specific indices in the timeseries.
- A mix of timeseries and non-timeseries parameters: The function can _only_ be used on
non-timeseries objects and will return the value of each parameter at in the object.

# Keyword Arguments

- `inbounds`: Whether to wrap the returned function in `@inbounds`.
"""
function getp(sys, p)
function getp(sys, p; inbounds = false)
symtype = symbolic_type(p)
elsymtype = symbolic_type(eltype(p))
_getp(sys, symtype, elsymtype, p)
getter = _getp(sys, symtype, elsymtype, p)
if inbounds
getter = InboundsWrapper(getter)
end
return getter
end

struct GetParameterIndex{I} <: AbstractParameterGetIndexer
Expand Down Expand Up @@ -659,15 +667,22 @@ Requires that the value provider implement [`parameter_values`](@ref) and the re
collection be a mutable reference to the parameter object. In case `parameter_values`
cannot return such a mutable reference, or additional actions need to be performed when
updating parameters, [`set_parameter!`](@ref) must be implemented.

# Keyword Arguments

- `inbounds`: Whether to wrap the function in `@inbounds`.
"""
function setp(sys, p; run_hook = true)
function setp(sys, p; run_hook = true, inbounds = false)
symtype = symbolic_type(p)
elsymtype = symbolic_type(eltype(p))
return if run_hook
return ParameterHookWrapper(_setp(sys, symtype, elsymtype, p), p)
else
_setp(sys, symtype, elsymtype, p)
setter = _setp(sys, symtype, elsymtype, p)
if run_hook
setter = ParameterHookWrapper(setter, p)
end
if inbounds
setter = InboundsWrapper(setter)
end
return setter
end

struct SetParameterIndex{I} <: AbstractSetIndexer
Expand Down Expand Up @@ -723,11 +738,19 @@ the types of values stored, and leverages [`remake_buffer`](@ref). Note that `sy
an index, a symbolic variable, or an array/tuple of the aforementioned.

Requires that the value provider implement `parameter_values` and `remake_buffer`.

# Keyword Arguments

- `inbounds`: Whether to wrap the returned function in `@inbounds`.
"""
function setp_oop(indp, sym)
function setp_oop(indp, sym; inbounds = false)
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))
return _setp_oop(indp, symtype, elsymtype, sym)
setter = _setp_oop(indp, symtype, elsymtype, sym)
if inbounds
setter = InboundsWrapper(setter)
end
return setter
end

function _setp_oop(indp, ::NotSymbolic, ::NotSymbolic, sym)
Expand Down
36 changes: 30 additions & 6 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,19 @@ relying on the above functions.
If the value provider is a parameter timeseries object, the same rules apply as
[`getp`](@ref). The difference here is that `sym` may also contain non-parameter symbols,
and the values are always returned corresponding to the state timeseries.

# Keyword Arguments

- `inbounds`: whether to wrap the returned function in an `@inbounds`.
"""
function getsym(sys, sym)
function getsym(sys, sym; inbounds = false)
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))
_getsym(sys, symtype, elsymtype, sym)
getter = _getsym(sys, symtype, elsymtype, sym)
if inbounds
getter = InboundsWrapper(getter)
end
return getter
end

struct GetStateIndex{I} <: AbstractStateGetIndexer
Expand Down Expand Up @@ -322,11 +330,19 @@ collection be a mutable reference to the state vector in the value provider. Alt
if this is not possible or additional actions need to be performed when updating state,
[`set_state!`](@ref) can be defined. This function does not work on types for which
[`is_timeseries`](@ref) is [`Timeseries`](@ref).

# Keyword Arguments

- `inbounds`: Whether to wrap the returned function in an `@inbounds`.
"""
function setsym(sys, sym)
function setsym(sys, sym; inbounds = false)
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))
_setsym(sys, symtype, elsymtype, sym)
setter = _setsym(sys, symtype, elsymtype, sym)
if inbounds
setter = InboundsWrapper(setter)
end
return setter
end

struct SetStateIndex{I} <: AbstractSetIndexer
Expand Down Expand Up @@ -390,11 +406,19 @@ array/tuple of the aforementioned. All entries `s` in `sym` must satisfy `is_var
or `is_parameter(indp, s)`.

Requires that the value provider implement `state_values`, `parameter_values` and `remake_buffer`.

# Keyword Arguments

- `inbounds`: Whether to wrap the returned function in `@inbounds`.
"""
function setsym_oop(indp, sym)
function setsym_oop(indp, sym; inbounds = false)
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))
return _setsym_oop(indp, symtype, elsymtype, sym)
setter = _setsym_oop(indp, symtype, elsymtype, sym)
if inbounds
setter = InboundsWrapper(setter)
end
return setter
end

struct FullSetter{S, P, I, J}
Expand Down
22 changes: 22 additions & 0 deletions src/value_provider_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,28 @@ function _root_indp(indp)
end
end

"""
struct InboundsWrapper

Utility struct to wrap a callable in `@inbounds`.
"""
struct InboundsWrapper{F}
fn::F
end

is_indexer_timeseries(::Type{InboundsWrapper{F}}) where {F} = is_indexer_timeseries(F)
indexer_timeseries_index(iw::InboundsWrapper) = indexer_timeseries_index(iw.fn)
as_timeseries_indexer(iw::InboundsWrapper) = InboundsWrapper(as_timeseries_indexer(iw.fn))
function as_not_timeseries_indexer(iw::InboundsWrapper)
InboundsWrapper(as_not_timeseries_indexer(iw.fn))
end

function (ig::InboundsWrapper)(args...)
return @inbounds begin
ig.fn(args...)
end
end

###########
# Errors
###########
Expand Down
Loading
Loading