Skip to content

Commit

Permalink
fix: fix observed variable adjoint
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 7, 2024
1 parent 9b96078 commit 6184c69
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,12 @@ end
if is_observed(VA, sym)
f = observed(VA, sym)
p = parameter_values(VA)
tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
tunables, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
u = state_values(VA)
t = current_time(VA)
y, back = Zygote.pullback(u, tunables) do u, tunables
f.(u, Ref(tunables), t)
_p = repack(tunables)
f.(u, Ref(_p), t)
end
gs = back(Δ)
(gs[1], nothing)
Expand Down

0 comments on commit 6184c69

Please sign in to comment.