Skip to content

Commit

Permalink
Merge pull request #808 from AayushSabharwal/as/fix-tests
Browse files Browse the repository at this point in the history
fix: use `split = false` system for remake autodiff tests
  • Loading branch information
ChrisRackauckas authored Oct 7, 2024
2 parents b064f2b + 6184c69 commit 70ae7fb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
5 changes: 3 additions & 2 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,12 @@ end
if is_observed(VA, sym)
f = observed(VA, sym)
p = parameter_values(VA)
tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
tunables, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
u = state_values(VA)
t = current_time(VA)
y, back = Zygote.pullback(u, tunables) do u, tunables
f.(u, Ref(tunables), t)
_p = repack(tunables)
f.(u, Ref(_p), t)
end
gs = back(Δ)
(gs[1], nothing)
Expand Down
6 changes: 3 additions & 3 deletions test/downstream/remake_autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ function lotka_volterra(; name = name)
end

@named lotka_volterra_sys = lotka_volterra()
lotka_volterra_sys = structural_simplify(lotka_volterra_sys)
lotka_volterra_sys = structural_simplify(lotka_volterra_sys, split = false)
prob = ODEProblem(lotka_volterra_sys, [], (0.0, 10.0), [])
sol = solve(prob, Tsit5(), reltol = 1e-6, abstol = 1e-6)
u0 = [1.0 1.0]
p = [1.5 1.0 1.0 1.0]
u0 = [1.0, 1.0]
p = [1.5, 1.0, 1.0, 1.0]

function sum_of_solution(u0, p)
_prob = remake(prob, u0 = u0, p = p)
Expand Down

0 comments on commit 70ae7fb

Please sign in to comment.