make_zero! can fail on some immutable functions #1661

Open
ChrisRackauckas opened this issue Jul 22, 2024 · 6 comments

Comments

@ChrisRackauckas
Contributor

Example from SciML/SciMLSensitivity.jl#1067:

```julia
using OrdinaryDiffEq, Zygote, SciMLSensitivity

N0 = [0.0] # initial population
p = [100.0, 50.0] # [steady-state population, M]
tspan = (0.0, 10.0) # integration time
f(D, u, p, t) = (D[1] = p[1] - u[1]) # system
prob = ODEProblem(f, N0, tspan, p)

# at time tinject we inject M = p[2] cells
tinject = 8.0
condition(u, t, integrator) = t == tinject
affect(integrator) = integrator.u[1] += integrator.p[2]
cb = DiscreteCallback(condition, affect)

function loss(p)
    _prob = remake(prob, p = p)
    _sol = solve(_prob, Tsit5(); callback = cb,
        abstol = 1e-14, reltol = 1e-14, tstops = [tinject],
        sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()))
    _sol.u[end][1]
end

gZy = Zygote.gradient(loss, p)[1]
```

Throws:

ERROR: setfield!: immutable struct of type #136#140 cannot be changed
Stacktrace:
  [1] make_zero!
    @ ~/.julia/packages/Enzyme/SiyIj/src/compiler.jl:1601 [inlined]
  [2] make_zero!
    @ ~/.julia/packages/Enzyme/SiyIj/src/compiler.jl:1576 [inlined]
  [3] _vecjacobian!(dλ::SubArray{…}, y::Vector{…}, λ::SubArray{…}, p::Vector{…}, t::Float64, S::SciMLSensitivity.CallbackSensitivityFunction{…}, isautojacvec::EnzymeVJP, dgrad::SubArray{…}, dy::SubArray{…}, W::Nothing)
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/derivative_wrappers.jl:710
  [4] #vecjacobian!#18
    @ ~/.julia/dev/SciMLSensitivity/src/derivative_wrappers.jl:232 [inlined]
  [5] vecjacobian!
    @ ~/.julia/dev/SciMLSensitivity/src/derivative_wrappers.jl:229 [inlined]
  [6] (::SciMLSensitivity.var"#affect!#272"{…})(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/callback_tracking.jl:339
  [7] #111
    @ ~/.julia/packages/DiffEqCallbacks/9fKPq/src/preset_time.jl:58 [inlined]
  [8] apply_discrete_callback!
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/callbacks.jl:613 [inlined]
  [9] apply_discrete_callback! (repeats 2 times)
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/callbacks.jl:635 [inlined]
 [10] apply_discrete_callback!
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/callbacks.jl:628 [inlined]
 [11] handle_callbacks!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/integrators/integrator_utils.jl:349
 [12] _loopfooter!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/integrators/integrator_utils.jl:254
 [13] loopfooter!
    @ ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/integrators/integrator_utils.jl:207 [inlined]
 [14] solve!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/solve.jl:558
 [15] #__solve#560
    @ ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/solve.jl:7 [inlined]
 [16] __solve
    @ ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/solve.jl:1 [inlined]
 [17] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:612
 [18] solve_call
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:569 [inlined]
 [19] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1080 [inlined]
 [20] solve_up
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1066 [inlined]
 [21] #solve#51
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003 [inlined]
 [22] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::BacksolveAdjoint{…}, alg::Tsit5{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Nothing, callback::CallbackSet{…}, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/sensitivity_interface.jl:448
 [23] _adjoint_sensitivities
    @ ~/.julia/dev/SciMLSensitivity/src/sensitivity_interface.jl:405 [inlined]
 [24] #adjoint_sensitivities#63
    @ ~/.julia/dev/SciMLSensitivity/src/sensitivity_interface.jl:401 [inlined]
 [25] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#310"{…})(Δ::ODESolution{…})
    @ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/concrete_solve.jl:619
 [26] ZBack
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
 [27] (::Zygote.var"#kw_zpullback#53"{…})(dy::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
 [28] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [29] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [30] #solve#51
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003 [inlined]
 [31] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [32] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [33] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [34] solve
    @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:993 [inlined]
 [35] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [36] loss
    @ ~/Desktop/test.jl:84 [inlined]
 [37] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [38] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [39] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
 [40] top-level scope
    @ ~/Desktop/test.jl:91
Some type information was truncated. Use `show(err)` to see complete types.

I haven't been able to isolate it any further.

@wsmoses
Member

wsmoses commented Jul 22, 2024

Unfortunately I don't think this is an error we can resolve here (depending on the type).

You can't update an immutable type, so doing an in-place update doesn't make sense.
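To see where the `setfield!` error in the trace comes from: a Julia closure is an instance of a compiler-generated immutable struct whose fields are the captured variables. The minimal sketch below is plain Julia with no Enzyme involved, and the names `buf`, `c`, `s` and `g` are illustrative. It shows that the closure object itself rejects `setfield!`, while a captured array is a separate mutable object that can still be zeroed in place through the closure's field.

```julia
# A closure capturing an array `c` and a scalar `s` (illustrative names).
buf = [1.0, 2.0]
g = let c = buf, s = 3.0
    x -> s .* c .* x
end

# The closure object is an immutable struct, so writing one of its fields is
# rejected with "setfield!: immutable struct of type ... cannot be changed".
println(ismutable(g))
err = try
    setfield!(g, :c, zeros(2))
    nothing
catch e
    e
end
println(err)

# The captured array itself is mutable, so it can be zeroed in place
# by reaching it through the (immutable) closure's field.
fill!(getfield(g, :c), 0.0)
println(buf)
```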

@ChrisRackauckas
Contributor Author

I would expect make_zero! on a function with no enclosed data to simply be a no-op.

@wsmoses
Member

wsmoses commented Jul 22, 2024 via email

@ChrisRackauckas
Contributor Author

If the data cannot be mutated, then it can safely be skipped, though. The use case here is that make_zero(f) is used to make the shadow function, and then Duplicated(f, make_zero!(df)) is used to reset the caches each time, for safety in the shadow data. But if the caches cannot be mutated, then they don't need to be zeroed.

@wsmoses
Member

wsmoses commented Jul 22, 2024

make_zero! internally uses a routine that indeed only updates the mutable parts. Presently, however, we define the semantics of make_zero! as zeroing all differentiable data and erroring otherwise (as it does here). If you were to pass df as a Duplicated here, you would get the wrong answer if it weren't fully zeroed.
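Why a partially zeroed shadow gives a wrong answer: in reverse mode, gradients are accumulated into the shadow of a Duplicated argument rather than written over it. The toy pullback below is hand-written plain Julia, not Enzyme, for y = sum(c .* x), whose gradient with respect to the buffer c is x. It assumes the same accumulate-into-shadow convention and shows that reusing the shadow without zeroing it double-counts.

```julia
# Hand-written pullback for y = sum(c .* x) with respect to c: dy/dc = x.
# Like a reverse-mode AD shadow, `dc` is accumulated into, not overwritten.
function toy_pullback!(dc::Vector{Float64}, x::Vector{Float64}, dy::Float64)
    dc .+= x .* dy
    return dc
end

x = [1.0, 2.0]
dc = zeros(2)                  # shadow of c, allocated once (cf. make_zero)

toy_pullback!(dc, x, 1.0)      # first pass: correct gradient
first_pass = copy(dc)

toy_pullback!(dc, x, 1.0)      # reuse WITHOUT zeroing: stale values remain
stale = copy(dc)

fill!(dc, 0.0)                 # zero the shadow in place (cf. make_zero!)
toy_pullback!(dc, x, 1.0)      # reuse after zeroing: correct again
println((first_pass, stale, dc))
```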

@ChrisRackauckas
Contributor Author

Okay, then we're missing a utility for handling functions generically and correctly. With duplicated functions I want to set the shadow to zero before reusing it, unless the values are not writable (because, of course, that means the last pass hasn't changed them).
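A minimal sketch of the kind of utility described here, assuming the shadow is a plain closure whose differentiable state is either captured float arrays (writable) or captured float scalars (not writable). `zero_captured!` is a hypothetical name and is not part of Enzyme. Instead of erroring, it zeroes what it can and reports whether any read-only float data is still nonzero, so a caller could choose between the "skip" behavior asked for here and the strict make_zero! semantics.

```julia
"""
    zero_captured!(f) -> Bool

Hypothetical helper (not an Enzyme API). Zeroes, in place, every
floating-point array captured by the closure `f`. Captured float scalars are
stored inside the immutable closure object and cannot be written, so they are
only checked. Returns `true` only if every captured float is zero afterwards.
Other captured values (boxed variables, nested structs, ...) are ignored in
this sketch.
"""
function zero_captured!(f::Function)
    all_zero = true
    for i in 1:nfields(f)
        v = getfield(f, i)
        if v isa AbstractArray{<:AbstractFloat}
            fill!(v, zero(eltype(v)))   # writable buffer: zero it in place
        elseif v isa AbstractFloat
            all_zero &= iszero(v)       # read-only scalar: can only check it
        end
    end
    return all_zero
end

# Captured array: zeroed in place, so the shadow is fully reset.
buf = [1.0, 2.0]
g = let c = buf
    x -> c .* x
end
ok_g = zero_captured!(g)

# Captured nonzero scalar: cannot be reset, so the helper reports `false`.
h = let s = 3.0
    x -> s * x
end
ok_h = zero_captured!(h)
```

A closure with no captured data has no fields, so for it the helper is a no-op that returns `true`.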

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants