From 507b456b5ade170f520bd381403cd4b925d440a9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 8 Jan 2025 18:01:24 +0530 Subject: [PATCH] feat: allow inbounds getters and setters --- src/parameter_indexing.jl | 41 +- src/state_indexing.jl | 36 +- src/value_provider_interface.jl | 22 + test/parameter_indexing_test.jl | 1008 +++++++++++++++++-------------- test/runtests.jl | 45 ++ test/state_indexing_test.jl | 646 +++++++++++--------- 6 files changed, 1060 insertions(+), 738 deletions(-) diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index de21a38..6cdf83c 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -34,11 +34,19 @@ apply: parameter values, and can be accessed at specific indices in the timeseries. - A mix of timeseries and non-timeseries parameters: The function can _only_ be used on non-timeseries objects and will return the value of each parameter at in the object. + +# Keyword Arguments + +- `inbounds`: Whether to wrap the returned function in `@inbounds`. """ -function getp(sys, p) +function getp(sys, p; inbounds = false) symtype = symbolic_type(p) elsymtype = symbolic_type(eltype(p)) - _getp(sys, symtype, elsymtype, p) + getter = _getp(sys, symtype, elsymtype, p) + if inbounds + getter = InboundsWrapper(getter) + end + return getter end struct GetParameterIndex{I} <: AbstractParameterGetIndexer @@ -659,15 +667,22 @@ Requires that the value provider implement [`parameter_values`](@ref) and the re collection be a mutable reference to the parameter object. In case `parameter_values` cannot return such a mutable reference, or additional actions need to be performed when updating parameters, [`set_parameter!`](@ref) must be implemented. + +# Keyword Arguments + +- `inbounds`: Whether to wrap the function in `@inbounds`. """ -function setp(sys, p; run_hook = true) +function setp(sys, p; run_hook = true, inbounds = false) symtype = symbolic_type(p) elsymtype = symbolic_type(eltype(p)) - return if run_hook - return ParameterHookWrapper(_setp(sys, symtype, elsymtype, p), p) - else - _setp(sys, symtype, elsymtype, p) + setter = _setp(sys, symtype, elsymtype, p) + if run_hook + setter = ParameterHookWrapper(setter, p) + end + if inbounds + setter = InboundsWrapper(setter) end + return setter end struct SetParameterIndex{I} <: AbstractSetIndexer @@ -723,11 +738,19 @@ the types of values stored, and leverages [`remake_buffer`](@ref). Note that `sy an index, a symbolic variable, or an array/tuple of the aforementioned. Requires that the value provider implement `parameter_values` and `remake_buffer`. + +# Keyword Arguments + +- `inbounds`: Whether to wrap the returned function in `@inbounds`. """ -function setp_oop(indp, sym) +function setp_oop(indp, sym; inbounds = false) symtype = symbolic_type(sym) elsymtype = symbolic_type(eltype(sym)) - return _setp_oop(indp, symtype, elsymtype, sym) + setter = _setp_oop(indp, symtype, elsymtype, sym) + if inbounds + setter = InboundsWrapper(setter) + end + return setter end function _setp_oop(indp, ::NotSymbolic, ::NotSymbolic, sym) diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 6e1d3d0..fead750 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -24,11 +24,19 @@ relying on the above functions. If the value provider is a parameter timeseries object, the same rules apply as [`getp`](@ref). The difference here is that `sym` may also contain non-parameter symbols, and the values are always returned corresponding to the state timeseries. + +# Keyword Arguments + +- `inbounds`: whether to wrap the returned function in an `@inbounds`. """ -function getsym(sys, sym) +function getsym(sys, sym; inbounds = false) symtype = symbolic_type(sym) elsymtype = symbolic_type(eltype(sym)) - _getsym(sys, symtype, elsymtype, sym) + getter = _getsym(sys, symtype, elsymtype, sym) + if inbounds + getter = InboundsWrapper(getter) + end + return getter end struct GetStateIndex{I} <: AbstractStateGetIndexer @@ -322,11 +330,19 @@ collection be a mutable reference to the state vector in the value provider. Alt if this is not possible or additional actions need to be performed when updating state, [`set_state!`](@ref) can be defined. This function does not work on types for which [`is_timeseries`](@ref) is [`Timeseries`](@ref). + +# Keyword Arguments + +- `inbounds`: Whether to wrap the returned function in an `@inbounds`. """ -function setsym(sys, sym) +function setsym(sys, sym; inbounds = false) symtype = symbolic_type(sym) elsymtype = symbolic_type(eltype(sym)) - _setsym(sys, symtype, elsymtype, sym) + setter = _setsym(sys, symtype, elsymtype, sym) + if inbounds + setter = InboundsWrapper(setter) + end + return setter end struct SetStateIndex{I} <: AbstractSetIndexer @@ -390,11 +406,19 @@ array/tuple of the aforementioned. All entries `s` in `sym` must satisfy `is_var or `is_parameter(indp, s)`. Requires that the value provider implement `state_values`, `parameter_values` and `remake_buffer`. + +# Keyword Arguments + +- `inbounds`: Whether to wrap the returned function in `@inbounds`. """ -function setsym_oop(indp, sym) +function setsym_oop(indp, sym; inbounds = false) symtype = symbolic_type(sym) elsymtype = symbolic_type(eltype(sym)) - return _setsym_oop(indp, symtype, elsymtype, sym) + setter = _setsym_oop(indp, symtype, elsymtype, sym) + if inbounds + setter = InboundsWrapper(setter) + end + return setter end struct FullSetter{S, P, I, J} diff --git a/src/value_provider_interface.jl b/src/value_provider_interface.jl index cd52ae1..b8ceff4 100644 --- a/src/value_provider_interface.jl +++ b/src/value_provider_interface.jl @@ -260,6 +260,28 @@ function _root_indp(indp) end end +""" + struct InboundsWrapper + +Utility struct to wrap a callable in `@inbounds`. +""" +struct InboundsWrapper{F} + fn::F +end + +is_indexer_timeseries(::Type{InboundsWrapper{F}}) where {F} = is_indexer_timeseries(F) +indexer_timeseries_index(iw::InboundsWrapper) = indexer_timeseries_index(iw.fn) +as_timeseries_indexer(iw::InboundsWrapper) = InboundsWrapper(as_timeseries_indexer(iw.fn)) +function as_not_timeseries_indexer(iw::InboundsWrapper) + InboundsWrapper(as_not_timeseries_indexer(iw.fn)) +end + +function (ig::InboundsWrapper)(args...) + return @inbounds begin + ig.fn(args...) + end +end + ########### # Errors ########### diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 26eaaf5..4a4c59c 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -5,6 +5,9 @@ using SymbolicIndexingInterface: IndexerOnlyTimeseries, IndexerNotTimeseries, In ParameterTimeseriesValueIndexMismatchError, MixedParameterTimeseriesIndexError using Test +import ..CheckboundsCountedArray +import ..maybe_CheckboundsCountedArray as maybe_CCA +import ..test_no_boundschecks arr = [1.0, 2.0, 3.0] @test parameter_values(arr) == arr @@ -28,181 +31,229 @@ SymbolicIndexingInterface.current_time(fp::FakeIntegrator) = fp.t function SymbolicIndexingInterface.finalize_parameters_hook!(fi::FakeIntegrator, p) fi.counter[] += 1 end +function test_no_boundschecks(fi::FakeIntegrator) + test_no_boundschecks(fi.p) +end -for sys in [ - SymbolCache([:x, :y, :z], [:a, :b, :c, :d], [:t]), - SymbolCache([:x, :y, :z], - [:a, :b, :c, :d], - [:t], - timeseries_parameters = Dict( - :b => ParameterTimeseriesIndex(1, 1), :c => ParameterTimeseriesIndex(2, 1))) -] - has_ts = sys.timeseries_parameters !== nothing - for pType in [Vector, Tuple] - p = [1.0, 2.0, 3.0, 4.0] - fi = FakeIntegrator(sys, pType(copy(p)), 9.0, Ref(0)) - new_p = [4.0, 5.0, 6.0, 7.0] - for (sym, oldval, newval, check_inference) in [ - (:a, p[1], new_p[1], true), - (1, p[1], new_p[1], true), - ([:a, :b], p[1:2], new_p[1:2], !has_ts), - (1:2, p[1:2], new_p[1:2], true), - ((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true), - ([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), - ([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), - ((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), - ((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), - ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), - ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), - ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), - ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)] - get = getp(sys, sym) - set! = setp(sys, sym) - if check_inference - @inferred get(fi) - end - @test get(fi) == fi.ps[sym] - @test get(fi) == oldval - - if pType === Tuple - @test_throws MethodError set!(fi, newval) - continue +@testset "FakeIntegrator: inbounds = $inbounds" for inbounds in [false, true] + for sys in [ + SymbolCache([:x, :y, :z], [:a, :b, :c, :d], [:t]), + SymbolCache([:x, :y, :z], + [:a, :b, :c, :d], + [:t], + timeseries_parameters = Dict( + :b => ParameterTimeseriesIndex(1, 1), :c => ParameterTimeseriesIndex(2, 1))) + ] + has_ts = sys.timeseries_parameters !== nothing + for pType in [Vector, Tuple] + p = [1.0, 2.0, 3.0, 4.0] + _p = pType(copy(p)) + if pType == Vector + _p = maybe_CCA(_p, inbounds) end + fi = FakeIntegrator(sys, _p, 9.0, Ref(0)) + new_p = [4.0, 5.0, 6.0, 7.0] + for (sym, oldval, newval, check_inference) in [ + (:a, p[1], new_p[1], true), + (1, p[1], new_p[1], true), + ([:a, :b], p[1:2], new_p[1:2], !has_ts), + (1:2, p[1:2], new_p[1:2], true), + ((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true), + ([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), + ([:a, (:b, :c)], [p[1], (p[2], p[3])], + [new_p[1], (new_p[2], new_p[3])], false), + ((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), + ((:a, (:b, :c)), (p[1], (p[2], p[3])), + (new_p[1], (new_p[2], new_p[3])), true), + ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), + ([1, (:b, :c)], [p[1], (p[2], p[3])], + [new_p[1], (new_p[2], new_p[3])], false), + ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), + ((1, (:b, :c)), (p[1], (p[2], p[3])), + (new_p[1], (new_p[2], new_p[3])), true)] + get = getp(sys, sym; inbounds) + set! = setp(sys, sym; inbounds) + if check_inference + @inferred get(fi) + end + @test get(fi) == fi.ps[sym] + @test get(fi) == oldval - @test fi.counter[] == 0 - if check_inference - @inferred set!(fi, newval) - else - set!(fi, newval) - end - @test fi.counter[] == 1 + if pType === Tuple + @test_throws MethodError set!(fi, newval) + continue + end - @test get(fi) == newval - set!(fi, oldval) - @test get(fi) == oldval - @test fi.counter[] == 2 + @test fi.counter[] == 0 + if check_inference + @inferred set!(fi, newval) + else + set!(fi, newval) + end + @test fi.counter[] == 1 - fi.ps[sym] = newval - @test get(fi) == newval - @test fi.counter[] == 3 - fi.ps[sym] = oldval - @test get(fi) == oldval - @test fi.counter[] == 4 + @test get(fi) == newval + set!(fi, oldval) + @test get(fi) == oldval + @test fi.counter[] == 2 - if check_inference - @inferred get(p) - end - @test get(p) == oldval - if check_inference - @inferred set!(p, newval) - else - set!(p, newval) - end - @test get(p) == newval - set!(p, oldval) - @test get(p) == oldval - @test fi.counter[] == 4 - fi.counter[] = 0 - end + fi.ps[sym] = newval + @test get(fi) == newval + @test fi.counter[] == 3 + fi.ps[sym] = oldval + @test get(fi) == oldval + @test fi.counter[] == 4 - for (sym, val, check_inference) in [ - ([:a, :b, :c, :d], p, true), - ([:c, :a], p[[3, 1]], !has_ts), - ((:b, :a), Tuple(p[[2, 1]]), true), - ((1, :c), Tuple(p[[1, 3]]), true), - (:(a + b + t), p[1] + p[2] + fi.t, true), - ([:(a + b + t), :c], [p[1] + p[2] + fi.t, p[3]], true), - ((:(a + b + t), :c), (p[1] + p[2] + fi.t, p[3]), true) - ] - get = getp(sys, sym) - if check_inference - @inferred get(fi) - end - @test get(fi) == val - if sym isa Union{Array, Tuple} - buffer = zeros(length(sym)) if check_inference - @inferred get(buffer, fi) + @inferred get(p) + end + @test get(p) == oldval + if check_inference + @inferred set!(p, newval) else - get(buffer, fi) + set!(p, newval) + end + @test get(p) == newval + set!(p, oldval) + @test get(p) == oldval + @test fi.counter[] == 4 + fi.counter[] = 0 + + if inbounds + test_no_boundschecks(fi) end - @test buffer == collect(val) end - end - for (sym, val, check_inference) in [ - (:(a + b), p[1] + p[2], true), - ([:(a + b), :(a * b)], [p[1] + p[2], p[1] * p[2]], true), - ((:(a + b), :(a * b)), (p[1] + p[2], p[1] * p[2]), true), - ([:(a + c), :(a + b)], [p[1] + p[3], p[1] + p[2]], true) - ] - get = getp(sys, sym) - if check_inference - @inferred get(parameter_values(fi)) + for (sym, val, check_inference) in [ + ([:a, :b, :c, :d], p, true), + ([:c, :a], p[[3, 1]], !has_ts), + ((:b, :a), Tuple(p[[2, 1]]), true), + ((1, :c), Tuple(p[[1, 3]]), true), + (:(a + b + t), p[1] + p[2] + fi.t, true), + ([:(a + b + t), :c], [p[1] + p[2] + fi.t, p[3]], true), + ((:(a + b + t), :c), (p[1] + p[2] + fi.t, p[3]), true) + ] + get = getp(sys, sym; inbounds) + if check_inference + @inferred get(fi) + end + @test get(fi) == val + if sym isa Union{Array, Tuple} + buffer = zeros(length(sym)) + if check_inference + @inferred get(buffer, fi) + else + get(buffer, fi) + end + @test buffer == collect(val) + end + + if inbounds + test_no_boundschecks(fi) + end end - @test get(parameter_values(fi)) == val - if sym isa Union{Array, Tuple} - buffer = zeros(length(sym)) + + for (sym, val, check_inference) in [ + (:(a + b), p[1] + p[2], true), + ([:(a + b), :(a * b)], [p[1] + p[2], p[1] * p[2]], true), + ((:(a + b), :(a * b)), (p[1] + p[2], p[1] * p[2]), true), + ([:(a + c), :(a + b)], [p[1] + p[3], p[1] + p[2]], true) + ] + get = getp(sys, sym; inbounds) if check_inference - @inferred get(buffer, parameter_values(fi)) - else - get(buffer, parameter_values(fi)) + @inferred get(parameter_values(fi)) + end + @test get(parameter_values(fi)) == val + if sym isa Union{Array, Tuple} + buffer = zeros(length(sym)) + if check_inference + @inferred get(buffer, parameter_values(fi)) + else + get(buffer, parameter_values(fi)) + end + @test buffer == collect(val) + end + + if inbounds + test_no_boundschecks(fi) end - @test buffer == collect(val) end - end - for sym in [ - :(a + t), - [:(a + t), :(a * b)], - (:(a + t), :(a * b)) - ] - get = getp(sys, sym) - @test_throws MethodError get(parameter_values(fi)) - if sym isa Union{Array, Tuple} - @test_throws MethodError get(zeros(length(sym)), parameter_values(fi)) + for sym in [ + :(a + t), + [:(a + t), :(a * b)], + (:(a + t), :(a * b)) + ] + get = getp(sys, sym; inbounds) + @test_throws MethodError get(parameter_values(fi)) + if sym isa Union{Array, Tuple} + @test_throws MethodError get(zeros(length(sym)), parameter_values(fi)) + end + if inbounds + test_no_boundschecks(fi) + end + end + + getter = getp(sys, []; inbounds) + @test getter(fi) == [] + getter = getp(sys, (); inbounds) + @test getter(fi) == () + + if inbounds + test_no_boundschecks(fi) + end + + for (sym, val) in [ + (:a, 1.0f1), + (1, 1.0f1), + ([:a, :b], [1.0f1, 2.0f1]), + ((:b, :c), (2.0f1, 3.0f1)) + ] + setter = setp_oop(fi, sym; inbounds) + newp = setter(fi, val) + getter = getp(sys, sym; inbounds) + @test getter(newp) == val + + if inbounds + test_no_boundschecks(fi) + end end end + end - getter = getp(sys, []) + let + sc = SymbolCache(nothing, nothing, :t) + fi = FakeIntegrator(sc, nothing, 0.0, Ref(0)) + getter = getp(sc, []) @test getter(fi) == [] - getter = getp(sys, ()) + getter = getp(sc, ()) @test getter(fi) == () - - for (sym, val) in [ - (:a, 1.0f1), - (1, 1.0f1), - ([:a, :b], [1.0f1, 2.0f1]), - ((:b, :c), (2.0f1, 3.0f1)) - ] - setter = setp_oop(fi, sym) - newp = setter(fi, val) - getter = getp(sys, sym) - @test getter(newp) == val - end end end -let - sc = SymbolCache(nothing, nothing, :t) - fi = FakeIntegrator(sc, nothing, 0.0, Ref(0)) - getter = getp(sc, []) - @test getter(fi) == [] - getter = getp(sc, ()) - @test getter(fi) == () +struct MyDiffEqArray{ + T <: AbstractVector{Float64}, U <: AbstractVector{<:AbstractVector{Float64}}} + t::T + u::U end -struct MyDiffEqArray - t::Vector{Float64} - u::Vector{Vector{Float64}} -end SymbolicIndexingInterface.current_time(mda::MyDiffEqArray) = mda.t SymbolicIndexingInterface.state_values(mda::MyDiffEqArray) = mda.u -SymbolicIndexingInterface.is_timeseries(::Type{MyDiffEqArray}) = Timeseries() +SymbolicIndexingInterface.is_timeseries(::Type{<:MyDiffEqArray}) = Timeseries() + +function test_no_boundschecks(mda::MyDiffEqArray) + test_no_boundschecks(mda.t) + test_no_boundschecks(mda.u) + if mda.u isa CheckboundsCountedArray + for buf in mda.u.array + test_no_boundschecks(buf) + end + end +end -struct MyParameterObject - p::Vector{Float64} - disc_idxs::Vector{Vector{Int}} +struct MyParameterObject{P, D} + p::P + disc_idxs::D end SymbolicIndexingInterface.parameter_values(mpo::MyParameterObject) = mpo.p @@ -214,14 +265,24 @@ function SymbolicIndexingInterface.with_updated_parameter_timeseries_values( return mpo end +function test_no_boundschecks(mpo::MyParameterObject) + test_no_boundschecks(mpo.p) + test_no_boundschecks(mpo.disc_idxs) + if mpo.disc_idxs isa CheckboundsCountedArray + for buf in mpo.disc_idxs.array + test_no_boundschecks(buf) + end + end +end + Base.getindex(mpo::MyParameterObject, i) = mpo.p[i] -struct FakeSolution +struct FakeSolution{U, T, P <: MyParameterObject, PT <: ParameterTimeseriesCollection} sys::SymbolCache - u::Vector{Vector{Float64}} - t::Vector{Float64} - p::MyParameterObject - p_ts::ParameterTimeseriesCollection{Vector{MyDiffEqArray}, MyParameterObject} + u::U + t::T + p::P + p_ts::PT end function Base.getproperty(fs::FakeSolution, s::Symbol) @@ -233,327 +294,378 @@ SymbolicIndexingInterface.symbolic_container(fs::FakeSolution) = fs.sys SymbolicIndexingInterface.parameter_values(fs::FakeSolution) = fs.p SymbolicIndexingInterface.parameter_values(fs::FakeSolution, i) = fs.p[i] SymbolicIndexingInterface.get_parameter_timeseries_collection(fs::FakeSolution) = fs.p_ts -SymbolicIndexingInterface.is_timeseries(::Type{FakeSolution}) = Timeseries() -SymbolicIndexingInterface.is_parameter_timeseries(::Type{FakeSolution}) = Timeseries() +SymbolicIndexingInterface.is_timeseries(::Type{<:FakeSolution}) = Timeseries() +SymbolicIndexingInterface.is_parameter_timeseries(::Type{<:FakeSolution}) = Timeseries() + +function test_no_boundschecks(fs::FakeSolution) + test_no_boundschecks(fs.u) + if fs.u isa CheckboundsCountedArray + for buf in fs.u.array + test_no_boundschecks(buf) + end + end + test_no_boundschecks(fs.t) + test_no_boundschecks(fs.p) + test_no_boundschecks(fs.p_ts) +end + sys = SymbolCache([:x, :y, :z], [:a, :b, :c, :d], :t; timeseries_parameters = Dict( :b => ParameterTimeseriesIndex(1, 1), :c => ParameterTimeseriesIndex(2, 1))) -b_timeseries = MyDiffEqArray(collect(0:0.1:0.9), [[2.5i] for i in 1:10]) -c_timeseries = MyDiffEqArray(collect(0:0.25:0.9), [[3.5i] for i in 1:4]) -p = MyParameterObject( - [20.0, b_timeseries.u[end][1], c_timeseries.u[end][1], 30.0], [[2], [3]]) -fs = FakeSolution( - sys, - [i * ones(3) for i in 1:5], - [0.2i for i in 1:5], - p, - ParameterTimeseriesCollection([b_timeseries, c_timeseries], deepcopy(p)) -) -aval = fs.p[1] -bval = getindex.(b_timeseries.u) -cval = getindex.(c_timeseries.u) -dval = fs.p[4] -bidx = timeseries_parameter_index(sys, :b) -cidx = timeseries_parameter_index(sys, :c) - -# IndexerNotTimeseries -for (sym, val, buffer, check_inference) in [ - (:a, aval, nothing, true), - (1, aval, nothing, true), - ([:a, :d], [aval, dval], zeros(2), true), - ((:a, :d), (aval, dval), zeros(2), true), - ([1, 4], [aval, dval], zeros(2), true), - ((1, 4), (aval, dval), zeros(2), true), - ([:a, 4], [aval, dval], zeros(2), true), - ((:a, 4), (aval, dval), zeros(2), true), - (:(a + d), aval + dval, nothing, true), - ([:(a + d), :(a * d)], [aval + dval, aval * dval], zeros(2), true), - ((:(a + d), :(a * d)), (aval + dval, aval * dval), zeros(2), true) -] - getter = getp(fs, sym) - @test is_indexer_timeseries(getter) isa IndexerNotTimeseries - test_inplace = buffer !== nothing - is_observed = sym isa Expr || - sym isa Union{AbstractArray, Tuple} && any(x -> x isa Expr, sym) - if check_inference - @inferred getter(fs) + +@testset "FakeSolution: inbounds = $inbounds" for inbounds in [false, true] + b_timeseries = MyDiffEqArray(maybe_CCA(collect(0:0.1:0.9), inbounds), + maybe_CCA([maybe_CCA([2.5i], inbounds) for i in 1:10], inbounds)) + c_timeseries = MyDiffEqArray(maybe_CCA(collect(0:0.25:0.9), inbounds), + maybe_CCA([maybe_CCA([3.5i], inbounds) for i in 1:4], inbounds)) + p = MyParameterObject( + maybe_CCA([20.0, b_timeseries.u[end][1], c_timeseries.u[end][1], 30.0], inbounds), maybe_CCA( + [maybe_CCA([2], inbounds), maybe_CCA([3], inbounds)], inbounds)) + fs = FakeSolution( + sys, + maybe_CCA([maybe_CCA(i * ones(3), inbounds) for i in 1:5], inbounds), + maybe_CCA([0.2i for i in 1:5], inbounds), + p, + ParameterTimeseriesCollection( + maybe_CCA([b_timeseries, c_timeseries], inbounds), deepcopy(p)) + ) + aval = @inbounds fs.p[1] + bval = @inbounds getindex.(b_timeseries.u) + cval = @inbounds getindex.(c_timeseries.u) + dval = @inbounds fs.p[4] + bidx = timeseries_parameter_index(sys, :b) + cidx = timeseries_parameter_index(sys, :c) + # IndexerNotTimeseries + for (sym, val, buffer, check_inference) in [ + (:a, aval, nothing, true), + (1, aval, nothing, true), + ([:a, :d], [aval, dval], zeros(2), true), + ((:a, :d), (aval, dval), zeros(2), true), + ([1, 4], [aval, dval], zeros(2), true), + ((1, 4), (aval, dval), zeros(2), true), + ([:a, 4], [aval, dval], zeros(2), true), + ((:a, 4), (aval, dval), zeros(2), true), + (:(a + d), aval + dval, nothing, true), + ([:(a + d), :(a * d)], [aval + dval, aval * dval], zeros(2), true), + ((:(a + d), :(a * d)), (aval + dval, aval * dval), zeros(2), true) + ] + getter = getp(fs, sym; inbounds) + @test is_indexer_timeseries(getter) isa IndexerNotTimeseries + test_inplace = buffer !== nothing + is_observed = sym isa Expr || + sym isa Union{AbstractArray, Tuple} && any(x -> x isa Expr, sym) + if check_inference + @inferred getter(fs) + if !is_observed + @inferred getter(parameter_values(fs)) + end + if test_inplace + @inferred getter(deepcopy(buffer), fs) + if !is_observed + @inferred getter(deepcopy(buffer), parameter_values(fs)) + end + end + end + @test getter(fs) == val if !is_observed - @inferred getter(parameter_values(fs)) + @test getter(parameter_values(fs)) == val end if test_inplace - @inferred getter(deepcopy(buffer), fs) - if !is_observed - @inferred getter(deepcopy(buffer), parameter_values(fs)) + target = collect(val) + valps = is_observed ? (fs,) : (fs, parameter_values(fs)) + for valp in valps + tmp = deepcopy(buffer) + getter(tmp, valp) + @test tmp == target end end - end - @test getter(fs) == val - if !is_observed - @test getter(parameter_values(fs)) == val - end - if test_inplace - target = collect(val) - valps = is_observed ? (fs,) : (fs, parameter_values(fs)) - for valp in valps - tmp = deepcopy(buffer) - getter(tmp, valp) - @test tmp == target + for subidx in [1, CartesianIndex(1), :, rand(Bool, 4), rand(1:4, 3), 1:2] + @test getter(fs, subidx) == val + if test_inplace + tmp = deepcopy(buffer) + getter(tmp, fs, subidx) + @test tmp == collect(val) + end end - end - for subidx in [1, CartesianIndex(1), :, rand(Bool, 4), rand(1:4, 3), 1:2] - @test getter(fs, subidx) == val - if test_inplace - tmp = deepcopy(buffer) - getter(tmp, fs, subidx) - @test tmp == collect(val) + if inbounds + test_no_boundschecks(fs) end end -end -# IndexerBoth -for (sym, timeseries_index, val, buffer, check_inference) in [ - (:b, 1, bval, zeros(length(bval)), true), - ([:a, :b], 1, vcat.(aval, bval), map(_ -> zeros(2), bval), false), - ((:a, :b), 1, tuple.(aval, bval), map(_ -> zeros(2), bval), true), - ([1, :b], 1, vcat.(aval, bval), map(_ -> zeros(2), bval), false), - ((1, :b), 1, tuple.(aval, bval), map(_ -> zeros(2), bval), true), - ([:b, :b], 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), - ((:b, :b), 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), - (:(a + b), 1, bval .+ aval, zeros(length(bval)), true), - ([:(a + b), :a], 1, vcat.(bval .+ aval, aval), map(_ -> zeros(2), bval), true), - ((:(a + b), :a), 1, tuple.(bval .+ aval, aval), map(_ -> zeros(2), bval), true), - ([:(a + b), :b], 1, vcat.(bval .+ aval, bval), map(_ -> zeros(2), bval), true), - ((:(a + b), :b), 1, tuple.(bval .+ aval, bval), map(_ -> zeros(2), bval), true) -] - getter = getp(sys, sym) - @test is_indexer_timeseries(getter) isa IndexerBoth - @test indexer_timeseries_index(getter) == timeseries_index - isobs = sym isa Union{AbstractArray, Tuple} ? any(Base.Fix1(is_observed, sys), sym) : - is_observed(sys, sym) - - if check_inference - @inferred getter(fs) - @inferred getter(deepcopy(buffer), fs) + # IndexerBoth + for (sym, timeseries_index, val, buffer, check_inference) in [ + (:b, 1, bval, zeros(length(bval)), true), + ([:a, :b], 1, vcat.(aval, bval), map(_ -> zeros(2), bval), false), + ((:a, :b), 1, tuple.(aval, bval), map(_ -> zeros(2), bval), true), + ([1, :b], 1, vcat.(aval, bval), map(_ -> zeros(2), bval), false), + ((1, :b), 1, tuple.(aval, bval), map(_ -> zeros(2), bval), true), + ([:b, :b], 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((:b, :b), 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), + (:(a + b), 1, bval .+ aval, zeros(length(bval)), true), + ([:(a + b), :a], 1, vcat.(bval .+ aval, aval), map(_ -> zeros(2), bval), true), + ((:(a + b), :a), 1, tuple.(bval .+ aval, aval), map(_ -> zeros(2), bval), true), + ([:(a + b), :b], 1, vcat.(bval .+ aval, bval), map(_ -> zeros(2), bval), true), + ((:(a + b), :b), 1, tuple.(bval .+ aval, bval), map(_ -> zeros(2), bval), true) + ] + getter = getp(sys, sym; inbounds) + @test is_indexer_timeseries(getter) isa IndexerBoth + @test indexer_timeseries_index(getter) == timeseries_index + isobs = sym isa Union{AbstractArray, Tuple} ? + any(Base.Fix1(is_observed, sys), sym) : + is_observed(sys, sym) + + if check_inference + @inferred getter(fs) + @inferred getter(deepcopy(buffer), fs) + if !isobs + @inferred getter(parameter_values(fs)) + if !(eltype(val) <: Number) + @inferred getter(deepcopy(buffer[1]), parameter_values(fs)) + end + end + end + + @test getter(fs) == val + if eltype(val) <: Number + target = val + else + target = collect.(val) + end + tmp = deepcopy(buffer) + getter(tmp, fs) + @test tmp == target + if !isobs - @inferred getter(parameter_values(fs)) + @test getter(parameter_values(fs)) == val[end] if !(eltype(val) <: Number) - @inferred getter(deepcopy(buffer[1]), parameter_values(fs)) + target = collect(val[end]) + tmp = deepcopy(buffer)[end] + getter(tmp, parameter_values(fs)) + @test tmp == target end end - end - - @test getter(fs) == val - if eltype(val) <: Number - target = val - else - target = collect.(val) - end - tmp = deepcopy(buffer) - getter(tmp, fs) - @test tmp == target - - if !isobs - @test getter(parameter_values(fs)) == val[end] - if !(eltype(val) <: Number) - target = collect(val[end]) - tmp = deepcopy(buffer)[end] - getter(tmp, parameter_values(fs)) + if inbounds + test_no_boundschecks(fs) + end + for subidx in [ + 1, CartesianIndex(1), :, rand(Bool, length(val)), rand(eachindex(val), 3), 1:2] + if check_inference + @inferred getter(fs, subidx) + if !isa(val[subidx], Number) + @inferred getter(deepcopy(buffer[subidx]), fs, subidx) + end + end + @test getter(fs, subidx) == val[subidx] + tmp = deepcopy(buffer[subidx]) + if val[subidx] isa Number + continue + end + target = val[subidx] + if eltype(target) <: Number + target = collect(target) + else + target = collect.(target) + end + getter(tmp, fs, subidx) @test tmp == target + + if inbounds + test_no_boundschecks(fs) + end end end - for subidx in [ - 1, CartesianIndex(1), :, rand(Bool, length(val)), rand(eachindex(val), 3), 1:2] + + # IndexerOnlyTimeseries + for (sym, timeseries_index, val, buffer, check_inference) in [ + (bidx, 1, bval, zeros(length(bval)), true), + ([bidx, :b], 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((bidx, :b), 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), + ([bidx, bidx], 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((bidx, bidx), 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true) + ] + getter = getp(sys, sym; inbounds) + @test is_indexer_timeseries(getter) isa IndexerOnlyTimeseries + @test indexer_timeseries_index(getter) == timeseries_index + + isscalar = eltype(val) <: Number + if check_inference - @inferred getter(fs, subidx) - if !isa(val[subidx], Number) - @inferred getter(deepcopy(buffer[subidx]), fs, subidx) - end - end - @test getter(fs, subidx) == val[subidx] - tmp = deepcopy(buffer[subidx]) - if val[subidx] isa Number - continue + @inferred getter(fs) + @inferred getter(deepcopy(buffer), fs) end - target = val[subidx] - if eltype(target) <: Number - target = collect(target) + + @test getter(fs) == val + target = if isscalar + val else - target = collect.(target) + collect.(val) end - getter(tmp, fs, subidx) + tmp = deepcopy(buffer) + getter(tmp, fs) @test tmp == target - end -end -# IndexerOnlyTimeseries -for (sym, timeseries_index, val, buffer, check_inference) in [ - (bidx, 1, bval, zeros(length(bval)), true), - ([bidx, :b], 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), - ((bidx, :b), 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), - ([bidx, bidx], 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), - ((bidx, bidx), 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true) -] - getter = getp(sys, sym) - @test is_indexer_timeseries(getter) isa IndexerOnlyTimeseries - @test indexer_timeseries_index(getter) == timeseries_index - - isscalar = eltype(val) <: Number - - if check_inference - @inferred getter(fs) - @inferred getter(deepcopy(buffer), fs) - end + @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter(parameter_values(fs)) + @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter( + [], parameter_values(fs)) - @test getter(fs) == val - target = if isscalar - val - else - collect.(val) - end - tmp = deepcopy(buffer) - getter(tmp, fs) - @test tmp == target - - @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter(parameter_values(fs)) - @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter( - [], parameter_values(fs)) - - for subidx in [ - 1, CartesianIndex(1), :, rand(Bool, length(val)), rand(eachindex(val), 3), 1:2] - if check_inference - @inferred getter(fs, subidx) - if !isa(val[subidx], Number) - @inferred getter(deepcopy(buffer[subidx]), fs, subidx) + if inbounds + test_no_boundschecks(fs) + end + for subidx in [ + 1, CartesianIndex(1), :, rand(Bool, length(val)), rand(eachindex(val), 3), 1:2] + if check_inference + @inferred getter(fs, subidx) + if !isa(val[subidx], Number) + @inferred getter(deepcopy(buffer[subidx]), fs, subidx) + end + end + @test getter(fs, subidx) == val[subidx] + if val[subidx] isa Number + continue + end + tmp = deepcopy(buffer[subidx]) + target = val[subidx] + if eltype(target) <: Number + target = collect(target) + else + target = collect.(target) + end + getter(tmp, fs, subidx) + @test tmp == target + if inbounds + test_no_boundschecks(fs) end end - @test getter(fs, subidx) == val[subidx] - if val[subidx] isa Number - continue + end + + # IndexerMixedTimeseries + for sym in [ + [:a, :b, :c], + (:a, :b, :c), + :(b + c), + [:(a + b), :c], + (:(a + b), :c) + ] + getter = getp(sys, sym; inbounds) + @test_throws MixedParameterTimeseriesIndexError getter(fs) + @test_throws MixedParameterTimeseriesIndexError getter([], fs) + for subidx in [1, CartesianIndex(1), :, rand(Bool, 4), rand(1:4, 3), 1:2] + @test_throws MixedParameterTimeseriesIndexError getter(fs, subidx) + @test_throws MixedParameterTimeseriesIndexError getter([], fs, subidx) end - tmp = deepcopy(buffer[subidx]) - target = val[subidx] - if eltype(target) <: Number - target = collect(target) - else - target = collect.(target) + if inbounds + test_no_boundschecks(fs) end - getter(tmp, fs, subidx) - @test tmp == target end -end -# IndexerMixedTimeseries -for sym in [ - [:a, :b, :c], - (:a, :b, :c), - :(b + c), - [:(a + b), :c], - (:(a + b), :c) -] - getter = getp(sys, sym) - @test_throws MixedParameterTimeseriesIndexError getter(fs) - @test_throws MixedParameterTimeseriesIndexError getter([], fs) - for subidx in [1, CartesianIndex(1), :, rand(Bool, 4), rand(1:4, 3), 1:2] - @test_throws MixedParameterTimeseriesIndexError getter(fs, subidx) - @test_throws MixedParameterTimeseriesIndexError getter([], fs, subidx) + for sym in [[:a, bidx], (:a, bidx), [1, bidx], (1, bidx)] + @test_throws ArgumentError getp(sys, sym; inbounds) end -end -for sym in [[:a, bidx], (:a, bidx), [1, bidx], (1, bidx)] - @test_throws ArgumentError getp(sys, sym) -end + for (sym, val) in [([:b, :c], [bval[end], cval[end]]) + ((:b, :c), (bval[end], cval[end]))] + getter = getp(sys, sym; inbounds) + @test is_indexer_timeseries(getter) == IndexerMixedTimeseries() + @test_throws MixedParameterTimeseriesIndexError getter(fs) + @test getter(parameter_values(fs)) == val + if inbounds + test_no_boundschecks(fs) + end + end -for (sym, val) in [([:b, :c], [bval[end], cval[end]]) - ((:b, :c), (bval[end], cval[end]))] - getter = getp(sys, sym) - @test is_indexer_timeseries(getter) == IndexerMixedTimeseries() - @test_throws MixedParameterTimeseriesIndexError getter(fs) - @test getter(parameter_values(fs)) == val -end + xval = @inbounds getindex.(fs.u, 1) + + for (sym, val_is_timeseries, val, check_inference) in [ + (:a, false, aval, true), + ([:a, :d], false, [aval, dval], true), + ((:a, :d), false, (aval, dval), true), + (:b, true, bval, true), + ([:a, :b], true, vcat.(aval, bval), false), + ((:a, :b), true, tuple.(aval, bval), true), + ([:a, :x], true, vcat.(aval, xval), false), + ((:a, :x), true, tuple.(aval, xval), true), + (:(2b), true, 2 .* bval, true), + ([:a, :(2b)], true, vcat.(aval, 2 .* bval), true), + ((:a, :(2b)), true, tuple.(aval, 2 .* bval), true) + ] + getter = getsym(sys, sym; inbounds) + if check_inference + @inferred getter(fs) + end + @test getter(fs) == val -xval = getindex.(fs.u, 1) - -for (sym, val_is_timeseries, val, check_inference) in [ - (:a, false, aval, true), - ([:a, :d], false, [aval, dval], true), - ((:a, :d), false, (aval, dval), true), - (:b, true, bval, true), - ([:a, :b], true, vcat.(aval, bval), false), - ((:a, :b), true, tuple.(aval, bval), true), - ([:a, :x], true, vcat.(aval, xval), false), - ((:a, :x), true, tuple.(aval, xval), true), - (:(2b), true, 2 .* bval, true), - ([:a, :(2b)], true, vcat.(aval, 2 .* bval), true), - ((:a, :(2b)), true, tuple.(aval, 2 .* bval), true) -] - getter = getsym(sys, sym) - if check_inference - @inferred getter(fs) + reference = val_is_timeseries ? val : xval + for subidx in [ + 1, CartesianIndex(2), :, rand(Bool, length(reference)), + rand(eachindex(reference), 3), 1:2] + if check_inference + @inferred getter(fs, subidx) + end + target = if val_is_timeseries + val[subidx] + else + val + end + @test getter(fs, subidx) == target + end + if inbounds + test_no_boundschecks(fs) + end end - @test getter(fs) == val - reference = val_is_timeseries ? val : xval - for subidx in [ - 1, CartesianIndex(2), :, rand(Bool, length(reference)), - rand(eachindex(reference), 3), 1:2] + temp_state = @inbounds ProblemState(; u = fs.u[1], + p = with_updated_parameter_timeseries_values( + sys, parameter_values(fs), 1 => fs.p_ts[1, 1], 2 => fs.p_ts[2, 1]), + t = fs.t[1]) + _xval = @inbounds temp_state.u[1] + _bval = @inbounds bval[1] + _cval = @inbounds cval[1] + for (sym, val, check_inference) in [ + ([:x, :b], [_xval, _bval], false), + ((:x, :c), (_xval, _cval), true), + (:(x + b), _xval + _bval, true), + ([:(2b), :(3x)], [2_bval, 3_xval], true), + ((:(2b), :(3x)), (2_bval, 3_xval), true) + ] + getter = getsym(sys, sym; inbounds) + @test_throws MixedParameterTimeseriesIndexError getter(fs) + for subidx in [1, CartesianIndex(2), :, rand(Bool, 4), rand(1:4, 3), 1:2] + @test_throws MixedParameterTimeseriesIndexError getter(fs, subidx) + end if check_inference - @inferred getter(fs, subidx) + @inferred getter(temp_state) end - target = if val_is_timeseries - val[subidx] - else - val + @test getter(temp_state) == val + if inbounds + test_no_boundschecks(temp_state) end - @test getter(fs, subidx) == target end -end -temp_state = ProblemState(; u = fs.u[1], - p = with_updated_parameter_timeseries_values( - sys, parameter_values(fs), 1 => fs.p_ts[1, 1], 2 => fs.p_ts[2, 1]), - t = fs.t[1]) -_xval = temp_state.u[1] -_bval = bval[1] -_cval = cval[1] -for (sym, val, check_inference) in [ - ([:x, :b], [_xval, _bval], false), - ((:x, :c), (_xval, _cval), true), - (:(x + b), _xval + _bval, true), - ([:(2b), :(3x)], [2_bval, 3_xval], true), - ((:(2b), :(3x)), (2_bval, 3_xval), true) -] - getter = getsym(sys, sym) - @test_throws MixedParameterTimeseriesIndexError getter(fs) - for subidx in [1, CartesianIndex(2), :, rand(Bool, 4), rand(1:4, 3), 1:2] - @test_throws MixedParameterTimeseriesIndexError getter(fs, subidx) - end - if check_inference - @inferred getter(temp_state) + for sym in [ + :err, + [:err, :b], + (:err, :b) + ] + @test_throws ErrorException getp(sys, sym) end - @test getter(temp_state) == val -end - -for sym in [ - :err, - [:err, :b], - (:err, :b) -] - @test_throws ErrorException getp(sys, sym) -end -let fs = fs, sys = sys - getter = getp(sys, []) - @test getter(fs) == [] - getter = getp(sys, ()) - @test getter(fs) == () + let fs = fs, sys = sys + getter = getp(sys, []) + @test getter(fs) == [] + getter = getp(sys, ()) + @test getter(fs) == () + if inbounds + test_no_boundschecks(fs) + end + end end -struct FakeNoTimeSolution +struct FakeNoTimeSolution{U <: AbstractVector{Float64}, P <: AbstractVector{Float64}} sys::SymbolCache - u::Vector{Float64} - p::Vector{Float64} + u::U + p::P end SymbolicIndexingInterface.state_values(fs::FakeNoTimeSolution) = fs.u @@ -561,28 +673,38 @@ SymbolicIndexingInterface.symbolic_container(fs::FakeNoTimeSolution) = fs.sys SymbolicIndexingInterface.parameter_values(fs::FakeNoTimeSolution) = fs.p SymbolicIndexingInterface.parameter_values(fs::FakeNoTimeSolution, i) = fs.p[i] -sys = SymbolCache([:x, :y, :z], [:a, :b, :c]) -u = [1.0, 2.0, 3.0] -p = [10.0, 20.0, 30.0] -fs = FakeNoTimeSolution(sys, u, p) - -for (sym, val, check_inference) in [ - (:a, p[1], true), - ([:a, :b], p[1:2], true), - ((:c, :b), (p[3], p[2]), true), - (:(a + b), p[1] + p[2], true), - ([:(a + b), :c], [p[1] + p[2], p[3]], true), - ((:(a + b), :c), (p[1] + p[2], p[3]), true) -] - getter = getp(sys, sym) - if check_inference - @inferred getter(fs) - end - @test getter(fs) == val +function test_no_boundschecks(fs::FakeNoTimeSolution) + test_no_boundschecks(fs.u) + test_no_boundschecks(fs.p) +end - if sym isa Union{Array, Tuple} - buffer = zeros(length(sym)) - @inferred getter(buffer, fs) - @test buffer == collect(val) +@testset "FakeNoTimeSolution: inbounds = $inbounds" for inbounds in [false, true] + sys = SymbolCache([:x, :y, :z], [:a, :b, :c]) + u = [1.0, 2.0, 3.0] + p = [10.0, 20.0, 30.0] + fs = FakeNoTimeSolution(sys, maybe_CCA(u, inbounds), maybe_CCA(p, inbounds)) + + for (sym, val, check_inference) in [ + (:a, p[1], true), + ([:a, :b], p[1:2], true), + ((:c, :b), (p[3], p[2]), true), + (:(a + b), p[1] + p[2], true), + ([:(a + b), :c], [p[1] + p[2], p[3]], true), + ((:(a + b), :c), (p[1] + p[2], p[3]), true) + ] + getter = getp(sys, sym; inbounds) + if check_inference + @inferred getter(fs) + end + @test getter(fs) == val + + if sym isa Union{Array, Tuple} + buffer = zeros(length(sym)) + @inferred getter(buffer, fs) + @test buffer == collect(val) + end + if inbounds + test_no_boundschecks(fs) + end end end diff --git a/test/runtests.jl b/test/runtests.jl index 13550e2..530ddab 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,51 @@ function activate_downstream_env() Pkg.instantiate() end +mutable struct CheckboundsCountedArray{T, N, A <: AbstractArray{T, N}} <: + AbstractArray{T, N} + array::A + count::Int +end + +CheckboundsCountedArray(arr) = CheckboundsCountedArray(arr, 0) +Base.@propagate_inbounds Base.getindex(arr::CheckboundsCountedArray, args...) = getindex( + arr.array, args...) +Base.@propagate_inbounds Base.setindex!(arr::CheckboundsCountedArray, args...) = setindex!( + arr.array, args...) +Base.size(arr::CheckboundsCountedArray) = size(arr.array) +Base.length(arr::CheckboundsCountedArray) = length(arr.array) +function Base.checkbounds(arr::CheckboundsCountedArray, args...) + arr.count += 1 + checkbounds(arr.array, args...) +end +function Base.checkbounds(::Type{Bool}, arr::CheckboundsCountedArray, args...) + arr.count += 1 + checkbounds(arr.array, args...) +end + +function Base.copy(arr::CheckboundsCountedArray) + return CheckboundsCountedArray(copy(arr.array), arr.count) +end + +function test_no_boundschecks(arr::CheckboundsCountedArray) + @test arr.count == 0 +end +function test_no_boundschecks(p::ProblemState) + test_no_boundschecks(p.u) + test_no_boundschecks(p.p) +end +function test_no_boundschecks(ptc::ParameterTimeseriesCollection) + test_no_boundschecks(ptc.collection) + arr = ptc.collection isa CheckboundsCountedArray ? ptc.collection.array : ptc.collection + for buf in ptc.collection + test_no_boundschecks(buf) + end + test_no_boundschecks(ptc.paramcache) +end +test_no_boundschecks(_) = nothing + +maybe_CheckboundsCountedArray(arr, inbounds) = inbounds ? CheckboundsCountedArray(arr) : arr + if GROUP == "All" || GROUP == "Core" @safetestset "Quality Assurance" begin @time include("qa.jl") diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index 4537276..0385cc7 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -1,5 +1,8 @@ using SymbolicIndexingInterface using SymbolicIndexingInterface: NotVariableOrParameter +import ..CheckboundsCountedArray +import ..maybe_CheckboundsCountedArray as maybe_CCA +import ..test_no_boundschecks struct FakeIntegrator{S, U, P, T} sys::S @@ -21,146 +24,176 @@ sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) u = [1.0, 2.0, 3.0] p = [11.0, 12.0, 13.0] t = 0.5 -fi = FakeIntegrator(sys, copy(u), copy(p), t) -# checking inference for non-concretely typed arrays will always fail -for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true) - (:y, u[2], 4.0, true) - (:z, u[3], 4.0, true) - (1, u[1], 4.0, true) - ([:x, :y], u[1:2], 4ones(2), true) - ([1, 2], u[1:2], 4ones(2), true) - ((:z, :y), (u[3], u[2]), (4.0, 5.0), true) - ((3, 2), (u[3], u[2]), (4.0, 5.0), true) - ([:x, [:y, :z]], [u[1], u[2:3]], - [4.0, [5.0, 6.0]], false) - ([:x, 2:3], [u[1], u[2:3]], - [4.0, [5.0, 6.0]], false) - ([:x, (:y, :z)], [u[1], (u[2], u[3])], - [4.0, (5.0, 6.0)], false) - ([:x, Tuple(2:3)], [u[1], (u[2], u[3])], - [4.0, (5.0, 6.0)], false) - ([:x, [:y], (:z,)], [u[1], [u[2]], (u[3],)], - [4.0, [5.0], (6.0,)], false) - ([:x, [:y], (3,)], [u[1], [u[2]], (u[3],)], - [4.0, [5.0], (6.0,)], false) - ((:x, [:y, :z]), (u[1], u[2:3]), - (4.0, [5.0, 6.0]), true) - ((:x, (:y, :z)), (u[1], (u[2], u[3])), - (4.0, (5.0, 6.0)), true) - ((1, (:y, :z)), (u[1], (u[2], u[3])), - (4.0, (5.0, 6.0)), true) - ((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), - (4.0, [5.0], (6.0,)), true)] - get = getsym(sys, sym) - set! = setsym(sys, sym) - if check_inference - @inferred get(fi) - end - @test get(fi) == val - if check_inference - @inferred set!(fi, newval) - else - set!(fi, newval) - end - @test get(fi) == newval - new_states = copy(state_values(fi)) +function test_no_boundschecks(fi::FakeIntegrator) + test_no_boundschecks(fi.u) + test_no_boundschecks(fi.p) +end - set!(fi, val) - @test get(fi) == val +@testset "FakeIntegrator: inbounds = $inbounds" for inbounds in [false, true] + fi = FakeIntegrator(sys, maybe_CCA(copy(u), inbounds), maybe_CCA(copy(p), inbounds), t) + # checking inference for non-concretely typed arrays will always fail + for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true) + (:y, u[2], 4.0, true) + (:z, u[3], 4.0, true) + (1, u[1], 4.0, true) + ([:x, :y], u[1:2], 4ones(2), true) + ([1, 2], u[1:2], 4ones(2), true) + ((:z, :y), (u[3], u[2]), (4.0, 5.0), true) + ((3, 2), (u[3], u[2]), (4.0, 5.0), true) + ([:x, [:y, :z]], [u[1], u[2:3]], + [4.0, [5.0, 6.0]], false) + ([:x, 2:3], [u[1], u[2:3]], + [4.0, [5.0, 6.0]], false) + ([:x, (:y, :z)], [u[1], (u[2], u[3])], + [4.0, (5.0, 6.0)], false) + ([:x, Tuple(2:3)], [u[1], (u[2], u[3])], + [4.0, (5.0, 6.0)], false) + ([:x, [:y], (:z,)], + [u[1], [u[2]], (u[3],)], + [4.0, [5.0], (6.0,)], false) + ([:x, [:y], (3,)], + [u[1], [u[2]], (u[3],)], + [4.0, [5.0], (6.0,)], false) + ((:x, [:y, :z]), (u[1], u[2:3]), + (4.0, [5.0, 6.0]), true) + ((:x, (:y, :z)), (u[1], (u[2], u[3])), + (4.0, (5.0, 6.0)), true) + ((1, (:y, :z)), (u[1], (u[2], u[3])), + (4.0, (5.0, 6.0)), true) + ((:x, [:y], (:z,)), + (u[1], [u[2]], (u[3],)), + (4.0, [5.0], (6.0,)), true)] + get = getsym(sys, sym; inbounds) + set! = setsym(sys, sym; inbounds) + if check_inference + @inferred get(fi) + end + @test get(fi) == val + if check_inference + @inferred set!(fi, newval) + else + set!(fi, newval) + end + @test get(fi) == newval - if check_inference - @inferred get(u) - end - @test get(u) == val - if check_inference - @inferred set!(u, newval) - else - set!(u, newval) - end - @test get(u) == newval - set!(u, val) - @test get(u) == val + new_states = copy(state_values(fi)) - if sym isa Union{Vector, Tuple} && any(x -> x isa Union{AbstractArray, Tuple}, sym) - continue - end + set!(fi, val) + @test get(fi) == val - setter = setsym_oop(sys, sym) - svals, pvals = setter(fi, newval) - @test svals ≈ new_states - @test pvals == parameter_values(fi) -end + if check_inference + @inferred get(u) + end + @test get(u) == val + if check_inference + @inferred set!(u, newval) + else + set!(u, newval) + end + @test get(u) == newval + set!(u, val) + @test get(u) == val -for (sym, val, check_inference) in [ - (:(x + y), u[1] + u[2], true), - ([:(x + y), :z], [u[1] + u[2], u[3]], false), - ((:(x + y), :(z + y)), (u[1] + u[2], u[2] + u[3]), false) -] - get = getsym(sys, sym) - if check_inference - @inferred get(fi) - end - @test get(fi) == val -end + if sym isa Union{Vector, Tuple} && any(x -> x isa Union{AbstractArray, Tuple}, sym) + continue + end -let fi = fi, sys = sys - getter = getsym(sys, []) - @test getter(fi) == [] - getter = getsym(sys, ()) - @test getter(fi) == () - sc = SymbolCache(nothing, [:a, :b], :t) - fi = FakeIntegrator(sys, nothing, [1.0, 2.0], 3.0) - getter = getsym(sc, []) - @test getter(fi) == [] - getter = getsym(sc, ()) - @test getter(fi) == () -end + setter = setsym_oop(sys, sym; inbounds) + svals, pvals = setter(fi, newval) + @test svals ≈ new_states + @test pvals == parameter_values(fi) -for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true) - (:b, p[2], 5.0, true) - (:c, p[3], 6.0, true) - ([:a, :b], p[1:2], [4.0, 5.0], true) - ((:c, :b), (p[3], p[2]), (6.0, 5.0), true) - ([:x, :a], [u[1], p[1]], [4.0, 5.0], false) - ((:y, :b), (u[2], p[2]), (5.0, 6.0), true)] - get = getsym(fi, sym) - set! = setsym(fi, sym) - if check_inference - @inferred get(fi) + if inbounds + test_no_boundschecks(fi) + end end - @test get(fi) == oldval - if check_inference - @inferred set!(fi, newval) - else - set!(fi, newval) + + for (sym, val, check_inference) in [ + (:(x + y), u[1] + u[2], true), + ([:(x + y), :z], [u[1] + u[2], u[3]], false), + ((:(x + y), :(z + y)), (u[1] + u[2], u[2] + u[3]), false) + ] + get = getsym(sys, sym; inbounds) + if check_inference + @inferred get(fi) + end + @test get(fi) == val + if inbounds + test_no_boundschecks(fi) + end end - @test get(fi) == newval - newu = copy(state_values(fi)) - newp = copy(parameter_values(fi)) + let fi = fi, sys = sys + getter = getsym(sys, []; inbounds) + @test getter(fi) == [] + getter = getsym(sys, (); inbounds) + @test getter(fi) == () + sc = SymbolCache(nothing, [:a, :b], :t) + fi = FakeIntegrator(sys, nothing, [1.0, 2.0], 3.0) + getter = getsym(sc, []; inbounds) + @test getter(fi) == [] + getter = getsym(sc, (); inbounds) + @test getter(fi) == () + if inbounds + test_no_boundschecks(fi) + end + end - set!(fi, oldval) - @test get(fi) == oldval + for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true) + (:b, p[2], 5.0, true) + (:c, p[3], 6.0, true) + ([:a, :b], p[1:2], [4.0, 5.0], true) + ((:c, :b), (p[3], p[2]), + (6.0, 5.0), true) + ([:x, :a], [u[1], p[1]], + [4.0, 5.0], false) + ((:y, :b), (u[2], p[2]), + (5.0, 6.0), true)] + get = getsym(fi, sym; inbounds) + set! = setsym(fi, sym; inbounds) + if check_inference + @inferred get(fi) + end + @test get(fi) == oldval + if check_inference + @inferred set!(fi, newval) + else + set!(fi, newval) + end + @test get(fi) == newval - oop_setter = setsym_oop(sys, sym) - uvals, pvals = oop_setter(fi, newval) - @test uvals ≈ newu - @test pvals ≈ newp -end + newu = copy(state_values(fi)) + newp = copy(parameter_values(fi)) + + set!(fi, oldval) + @test get(fi) == oldval -for (sym, val, check_inference) in [ - (:t, t, true), - ([:x, :a, :t], [u[1], p[1], t], false), - ((:x, :a, :t), (u[1], p[1], t), false) -] - get = getsym(fi, sym) - if check_inference - @inferred get(fi) + oop_setter = setsym_oop(sys, sym; inbounds) + uvals, pvals = oop_setter(fi, newval) + @test uvals ≈ newu + @test pvals ≈ newp + if inbounds + test_no_boundschecks(fi) + end end - @test get(fi) == val - @test_throws NotVariableOrParameter setsym_oop(fi, sym) + for (sym, val, check_inference) in [ + (:t, t, true), + ([:x, :a, :t], [u[1], p[1], t], false), + ((:x, :a, :t), (u[1], p[1], t), false) + ] + get = getsym(fi, sym; inbounds) + if check_inference + @inferred get(fi) + end + @test get(fi) == val + + @test_throws NotVariableOrParameter setsym_oop(fi, sym) + if inbounds + test_no_boundschecks(fi) + end + end end struct FakeSolution{S, U, P, T} @@ -180,140 +213,176 @@ SymbolicIndexingInterface.state_values(fp::FakeSolution) = fp.u SymbolicIndexingInterface.parameter_values(fp::FakeSolution) = fp.p SymbolicIndexingInterface.current_time(fp::FakeSolution) = fp.t -sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) -u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]] -t = [1.5, 2.0, 2.3, 4.0] -sol = FakeSolution(sys, u, p, t) - -xvals = getindex.(sol.u, 1) -yvals = getindex.(sol.u, 2) -zvals = getindex.(sol.u, 3) - -for (sym, ans, check_inference) in [(:x, xvals, true) - (:y, yvals, true) - (:z, zvals, true) - (1, xvals, true) - ([:x, :y], vcat.(xvals, yvals), true) - (1:2, vcat.(xvals, yvals), true) - ([:x, 2], vcat.(xvals, yvals), true) - ((:z, :y), tuple.(zvals, yvals), true) - ((3, 2), tuple.(zvals, yvals), true) - ([:x, [:y, :z]], - vcat.(xvals, [[x] for x in vcat.(yvals, zvals)]), - false) - ([:x, (:y, :z)], - vcat.(xvals, tuple.(yvals, zvals)), false) - ([1, (:y, :z)], - vcat.(xvals, tuple.(yvals, zvals)), false) - ([:x, [:y, :z], (:x, :z)], - vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], - tuple.(xvals, zvals)), - false) - ([:x, [:y, 3], (1, :z)], - vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], - tuple.(xvals, zvals)), - false) - ((:x, [:y, :z]), - tuple.(xvals, vcat.(yvals, zvals)), true) - ((:x, (:y, :z)), - tuple.(xvals, tuple.(yvals, zvals)), true) - ((:x, [:y, :z], (:z, :y)), - tuple.(xvals, vcat.(yvals, zvals), - tuple.(zvals, yvals)), - true) - ([:x, :a], vcat.(xvals, p[1]), false) - ((:y, :b), tuple.(yvals, p[2]), true) - (:t, t, true) - ([:x, :a, :t], vcat.(xvals, p[1], t), false) - ((:x, :a, :t), tuple.(xvals, p[1], t), true)] - get = getsym(sys, sym) - if check_inference - @inferred get(sol) +function test_no_boundschecks(fs::FakeSolution) + test_no_boundschecks(fs.u) + if fs.u isa CheckboundsCountedArray + arr = fs.u.array + else + arr = fs.u end - @test get(sol) == ans - for i in [rand(eachindex(u)), CartesianIndex(1), :, - rand(Bool, length(u)), rand(eachindex(u), 3), 1:3] - if check_inference - @inferred get(sol, i) - end - @test get(sol, i) == ans[i] + for buf in arr + test_no_boundschecks(buf) end + test_no_boundschecks(fs.p) + test_no_boundschecks(fs.t) end -for (sym, val, check_inference) in [ - (:(x + y), xvals .+ yvals, true), - ([:(x + y), :z], vcat.(xvals .+ yvals, zvals), false), - ((:(x + y), :(z + y)), tuple.(xvals .+ yvals, yvals .+ zvals), false) -] - get = getsym(sys, sym) - if check_inference - @inferred get(sol) +@testset "FakeSolution: inbounds = $inbounds" for inbounds in [false, true] + sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) + u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]] + p = [11.0, 12.0, 13.0] + t = [1.5, 2.0, 2.3, 4.0] + xvals = getindex.(u, 1) + yvals = getindex.(u, 2) + zvals = getindex.(u, 3) + sol = FakeSolution(sys, maybe_CCA(maybe_CCA.(u, inbounds), inbounds), + maybe_CCA(p, inbounds), maybe_CCA(t, inbounds)) + + for (sym, ans, check_inference) in [(:x, xvals, true) + (:y, yvals, true) + (:z, zvals, true) + (1, xvals, true) + ([:x, :y], vcat.(xvals, yvals), true) + (1:2, vcat.(xvals, yvals), true) + ([:x, 2], vcat.(xvals, yvals), true) + ((:z, :y), tuple.(zvals, yvals), true) + ((3, 2), tuple.(zvals, yvals), true) + ([:x, [:y, :z]], + vcat.( + xvals, [[x] for x in vcat.(yvals, zvals)]), + false) + ([:x, (:y, :z)], + vcat.(xvals, tuple.(yvals, zvals)), false) + ([1, (:y, :z)], + vcat.(xvals, tuple.(yvals, zvals)), false) + ([:x, [:y, :z], (:x, :z)], + vcat.( + xvals, [[x] for x in vcat.(yvals, zvals)], + tuple.(xvals, zvals)), + false) + ([:x, [:y, 3], (1, :z)], + vcat.( + xvals, [[x] for x in vcat.(yvals, zvals)], + tuple.(xvals, zvals)), + false) + ((:x, [:y, :z]), + tuple.(xvals, vcat.(yvals, zvals)), true) + ((:x, (:y, :z)), + tuple.(xvals, tuple.(yvals, zvals)), true) + ((:x, [:y, :z], (:z, :y)), + tuple.(xvals, vcat.(yvals, zvals), + tuple.(zvals, yvals)), + true) + ([:x, :a], vcat.(xvals, p[1]), false) + ((:y, :b), tuple.(yvals, p[2]), true) + (:t, t, true) + ([:x, :a, :t], vcat.(xvals, p[1], t), false) + ((:x, :a, :t), tuple.(xvals, p[1], t), true)] + get = getsym(sys, sym) + if check_inference + @inferred get(sol) + end + @test get(sol) == ans + for i in [rand(eachindex(u)), CartesianIndex(1), :, + rand(Bool, length(u)), rand(eachindex(u), 3), 1:3] + if check_inference + @inferred get(sol, i) + end + @test get(sol, i) == ans[i] + end + if inbounds + test_no_boundschecks(sol) + end end - @test get(sol) == val - for i in [rand(eachindex(u)), CartesianIndex(1), :, - rand(Bool, length(u)), rand(eachindex(u), 3), 1:3] + + for (sym, val, check_inference) in [ + (:(x + y), xvals .+ yvals, true), + ([:(x + y), :z], vcat.(xvals .+ yvals, zvals), false), + ((:(x + y), :(z + y)), tuple.(xvals .+ yvals, yvals .+ zvals), false) + ] + get = getsym(sys, sym) if check_inference - @inferred get(sol, i) + @inferred get(sol) + end + @test get(sol) == val + for i in [rand(eachindex(u)), CartesianIndex(1), :, + rand(Bool, length(u)), rand(eachindex(u), 3), 1:3] + if check_inference + @inferred get(sol, i) + end + @test get(sol, i) == val[i] + end + if inbounds + test_no_boundschecks(sol) end - @test get(sol, i) == val[i] end -end -for (sym, val) in [(:a, p[1]) - (:b, p[2]) - (:c, p[3]) - ([:a, :b], p[1:2]) - ((:c, :b), (p[3], p[2]))] - get = getsym(sys, sym) - @inferred get(sol) - @test get(sol) == val -end + for (sym, val) in [(:a, p[1]) + (:b, p[2]) + (:c, p[3]) + ([:a, :b], p[1:2]) + ((:c, :b), (p[3], p[2]))] + get = getsym(sys, sym) + @inferred get(sol) + @test get(sol) == val + if inbounds + test_no_boundschecks(sol) + end + end -let sol = sol, sys = sys - getter = getsym(sys, []) - @test getter(sol) == [[] for _ in 1:length(sol.t)] - getter = getsym(sys, ()) - @test getter(sol) == [() for _ in 1:length(sol.t)] - sc = SymbolCache(nothing, [:a, :b], :t) - sol = FakeSolution(sys, [], [1.0, 2.0], []) - getter = getsym(sc, []) - @test getter(sol) == [] - getter = getsym(sc, ()) - @test getter(sol) == [] -end + let sol = sol, sys = sys + getter = getsym(sys, []) + @test getter(sol) == [[] for _ in 1:length(sol.t)] + getter = getsym(sys, ()) + @test getter(sol) == [() for _ in 1:length(sol.t)] + sc = SymbolCache(nothing, [:a, :b], :t) + sol = FakeSolution(sys, maybe_CCA([], inbounds), + maybe_CCA([1.0, 2.0], inbounds), maybe_CCA([], inbounds)) + getter = getsym(sc, []) + @test getter(sol) == [] + getter = getsym(sc, ()) + @test getter(sol) == [] + if inbounds + test_no_boundschecks(sol) + end + end -sys = SymbolCache([:x, :y, :z], [:a, :b, :c]) -u = [1.0, 2.0, 3.0] -p = [10.0, 20.0, 30.0] -fs = FakeSolution(sys, u, p, nothing) -@test is_timeseries(fs) == NotTimeseries() - -for (sym, val, check_inference) in [ - (:x, u[1], true), - (1, u[1], true), - ([:x, :y], u[1:2], true), - ((:x, :y), Tuple(u[1:2]), true), - (1:2, u[1:2], true), - ([:x, 2], u[1:2], true), - ((:x, 2), Tuple(u[1:2]), true), - ([1, 2], u[1:2], true), - ((1, 2), Tuple(u[1:2]), true), - (:a, p[1], true), - ([:a, :b], p[1:2], true), - ((:a, :b), Tuple(p[1:2]), true), - ([:x, :a], [u[1], p[1]], false), - ((:x, :a), (u[1], p[1]), true), - ([1, :a], [u[1], p[1]], false), - ((1, :a), (u[1], p[1]), true), - (:(x + y + a + b), u[1] + u[2] + p[1] + p[2], true), - ([:(x + a), :(y + b)], [u[1] + p[1], u[2] + p[2]], true), - ((:(x + a), :(y + b)), (u[1] + p[1], u[2] + p[2]), true) -] - getter = getsym(sys, sym) - if check_inference - @inferred getter(fs) + sys = SymbolCache([:x, :y, :z], [:a, :b, :c]) + u = [1.0, 2.0, 3.0] + p = [10.0, 20.0, 30.0] + fs = FakeSolution(sys, maybe_CCA(u, inbounds), maybe_CCA(p, inbounds), nothing) + @test is_timeseries(fs) == NotTimeseries() + + for (sym, val, check_inference) in [ + (:x, u[1], true), + (1, u[1], true), + ([:x, :y], u[1:2], true), + ((:x, :y), Tuple(u[1:2]), true), + (1:2, u[1:2], true), + ([:x, 2], u[1:2], true), + ((:x, 2), Tuple(u[1:2]), true), + ([1, 2], u[1:2], true), + ((1, 2), Tuple(u[1:2]), true), + (:a, p[1], true), + ([:a, :b], p[1:2], true), + ((:a, :b), Tuple(p[1:2]), true), + ([:x, :a], [u[1], p[1]], false), + ((:x, :a), (u[1], p[1]), true), + ([1, :a], [u[1], p[1]], false), + ((1, :a), (u[1], p[1]), true), + (:(x + y + a + b), u[1] + u[2] + p[1] + p[2], true), + ([:(x + a), :(y + b)], [u[1] + p[1], u[2] + p[2]], true), + ((:(x + a), :(y + b)), (u[1] + p[1], u[2] + p[2]), true) + ] + getter = getsym(sys, sym; inbounds) + if check_inference + @inferred getter(fs) + end + @test getter(fs) == val + if inbounds + test_no_boundschecks(sol) + end end - @test getter(fs) == val end struct NonMarkovianWrapper{S <: SymbolCache} @@ -340,44 +409,61 @@ u = [u0 .* i for i in 1:11] p = [10.0, 20.0, 30.0] ts = 0.0:0.1:1.0 -fi = FakeIntegrator(sys, u0, p, ts[1]) -fs = FakeSolution(sys, u, p, ts) -getter = getsym(sys, :(x + y)) -@test getter(fi) ≈ 2.8 -@test getter(fs) ≈ [3.0i + 2(ts[i] - 0.1) for i in 1:11] -@test getter(fs, 1) ≈ 2.8 - -pstate = ProblemState(; u = u0, p = p, t = ts[1], h = t -> t .* ones(length(u0))) -@test getter(pstate) ≈ 2.8 - struct TupleObservedWrapper{S} sys::S end SymbolicIndexingInterface.symbolic_container(t::TupleObservedWrapper) = t.sys SymbolicIndexingInterface.supports_tuple_observed(::TupleObservedWrapper) = true -@testset "Tuple observed" begin - sc = SymbolCache([:x, :y, :z], [:a, :b, :c]) - sys = TupleObservedWrapper(sc) - ps = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3]) - getter = getsym(sys, (:(x + y), :(y + z))) - @test all(getter(ps) .≈ (3.0, 5.0)) - @test getter(ps) isa Tuple - @test_nowarn @inferred getter(ps) - getter = getsym(sys, (:(a + b), :(b + c))) - @test all(getter(ps) .≈ (0.3, 0.5)) - @test getter(ps) isa Tuple - @test_nowarn @inferred getter(ps) - - sc = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) - sys = TupleObservedWrapper(sc) - ps = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.1) - getter = getsym(sys, (:(x + y), :(y + t))) - @test all(getter(ps) .≈ (3.0, 2.1)) - @test getter(ps) isa Tuple - @test_nowarn @inferred getter(ps) - getter = getsym(sys, (:(a + b), :(b + c))) - @test all(getter(ps) .≈ (0.3, 0.5)) - @test getter(ps) isa Tuple - @test_nowarn @inferred getter(ps) +@testset "NonMarkovian: inbounds = $inbounds" for inbounds in [false, true] + fi = FakeIntegrator(sys, maybe_CCA(u0, inbounds), maybe_CCA(p, inbounds), ts[1]) + fs = FakeSolution(sys, maybe_CCA(maybe_CCA.(u, inbounds), inbounds), + maybe_CCA(p, inbounds), maybe_CCA(ts, inbounds)) + getter = getsym(sys, :(x + y); inbounds) + @test getter(fi) ≈ 2.8 + @test getter(fs) ≈ [3.0i + 2(ts[i] - 0.1) for i in 1:11] + @test getter(fs, 1) ≈ 2.8 + if inbounds + test_no_boundschecks(fi) + test_no_boundschecks(fs) + end + + pstate = ProblemState(; u = maybe_CCA(u0, inbounds), p = maybe_CCA(p, inbounds), + t = ts[1], h = t -> t .* ones(length(u0))) + @test getter(pstate) ≈ 2.8 + + if inbounds + test_no_boundschecks(pstate) + end + + @testset "Tuple observed" begin + sc = SymbolCache([:x, :y, :z], [:a, :b, :c]) + sys = TupleObservedWrapper(sc) + ps = ProblemState(; u = maybe_CCA([1.0, 2.0, 3.0], inbounds), + p = maybe_CCA([0.1, 0.2, 0.3], inbounds)) + getter = getsym(sys, (:(x + y), :(y + z)); inbounds) + @test all(getter(ps) .≈ (3.0, 5.0)) + @test getter(ps) isa Tuple + @test_nowarn @inferred getter(ps) + getter = getsym(sys, (:(a + b), :(b + c)); inbounds) + @test all(getter(ps) .≈ (0.3, 0.5)) + @test getter(ps) isa Tuple + @test_nowarn @inferred getter(ps) + if inbounds + test_no_boundschecks(ps) + end + + sc = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) + sys = TupleObservedWrapper(sc) + ps = ProblemState(; u = maybe_CCA([1.0, 2.0, 3.0], inbounds), + p = maybe_CCA([0.1, 0.2, 0.3], inbounds), t = 0.1) + getter = getsym(sys, (:(x + y), :(y + t)); inbounds) + @test all(getter(ps) .≈ (3.0, 2.1)) + @test getter(ps) isa Tuple + @test_nowarn @inferred getter(ps) + getter = getsym(sys, (:(a + b), :(b + c)); inbounds) + @test all(getter(ps) .≈ (0.3, 0.5)) + @test getter(ps) isa Tuple + @test_nowarn @inferred getter(ps) + end end