Skip to content

Commit

Permalink
broadcast MultipleSetters instead of map
Browse files Browse the repository at this point in the history
fixes SciML#82
  • Loading branch information
hexaeder committed Jun 3, 2024
1 parent c9c7b6c commit 426cdf9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ struct MultipleSetters{S} <: AbstractSetIndexer
end

function (ms::MultipleSetters)(prob, val)
map((s!, v) -> s!(prob, v), ms.setters, val)
broadcast((s!, v) -> s!(prob, v), ms.setters, val)
end

for (t1, t2) in [
Expand Down
14 changes: 10 additions & 4 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ for pType in [Vector, Tuple]
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
([:a, :b], p[1:2], 42, true),
]
get = getp(sys, sym)
set! = setp(sys, sym)
Expand All @@ -58,13 +59,13 @@ for pType in [Vector, Tuple]
end
@test fi.counter[] == 1

@test get(fi) == newval
@test all(get(fi) .== newval)
set!(fi, oldval)
@test get(fi) == oldval
@test fi.counter[] == 2

fi.ps[sym] = newval
@test get(fi) == newval
@test all(get(fi) .== newval)
@test fi.counter[] == 3
fi.ps[sym] = oldval
@test get(fi) == oldval
Expand All @@ -79,7 +80,7 @@ for pType in [Vector, Tuple]
else
set!(p, newval)
end
@test get(p) == newval
@test all(get(p) .== newval)
set!(p, oldval)
@test get(p) == oldval
@test fi.counter[] == 4
Expand All @@ -99,6 +100,11 @@ for pType in [Vector, Tuple]
end
end

# check throws if setp dimensions do not match
fi = FakeIntegrator(sys, [1.0, 2.0, 3.0], Ref(0))
@test_throws DimensionMismatch setp(fi, 1:2)(fi, [-1.0, -2.0, -3.0])
@test_throws DimensionMismatch setp(fi, 1:3)(fi, [-1.0, -2.0])

struct FakeSolution
sys::SymbolCache
u::Vector{Vector{Float64}}
Expand Down

0 comments on commit 426cdf9

Please sign in to comment.