Skip to content

Commit

Permalink
broadcast MultipleSetters instead of map
Browse files Browse the repository at this point in the history
fixes SciML#82
  • Loading branch information
hexaeder committed Jun 14, 2024
1 parent e7dd822 commit cefe510
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ struct MultipleSetters{S} <: AbstractSetIndexer
end

function (ms::MultipleSetters)(prob, val)
map((s!, v) -> s!(prob, v), ms.setters, val)
broadcast((s!, v) -> s!(prob, v), ms.setters, val)
end

for (t1, t2) in [
Expand Down
15 changes: 11 additions & 4 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ for sys in [
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
([:a, :b], p[1:2], 42, true),
]
get = getp(sys, sym)
set! = setp(sys, sym)
Expand All @@ -77,13 +78,13 @@ for sys in [
end
@test fi.counter[] == 1

@test get(fi) == newval
@test all(get(fi) .== newval)
set!(fi, oldval)
@test get(fi) == oldval
@test fi.counter[] == 2

fi.ps[sym] = newval
@test get(fi) == newval
@test all(get(fi) .== newval)
@test fi.counter[] == 3
fi.ps[sym] = oldval
@test get(fi) == oldval
Expand All @@ -98,7 +99,7 @@ for sys in [
else
set!(p, newval)
end
@test get(p) == newval
@test all(get(p) .== newval)
set!(p, oldval)
@test get(p) == oldval
@test fi.counter[] == 4
Expand Down Expand Up @@ -150,6 +151,12 @@ end

Base.getindex(mpo::MyParameterObject, i) = mpo.p[i]

# check throws if setp dimensions do not match
sys = SymbolCache([:x, :y, :z], [:a, :b, :c, :d], [:t])
fi = FakeIntegrator(sys, [1.0, 2.0, 3.0], 0.0, Ref(0))
@test_throws DimensionMismatch setp(fi, 1:2)(fi, [-1.0, -2.0, -3.0])
@test_throws DimensionMismatch setp(fi, 1:3)(fi, [-1.0, -2.0])

struct FakeSolution
sys::SymbolCache
u::Vector{Vector{Float64}}
Expand Down

0 comments on commit cefe510

Please sign in to comment.