Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GaussAdjoint with callbacks #1060

Open
wants to merge 5 commits into
base: master
Choose a base branch
from

Conversation

ChrisRackauckas
Copy link
Member

No description provided.

@Vaibhavdixit02
Copy link
Member

Ah, this was what #1034 was trying to address lol

@ChrisRackauckas
Copy link
Member Author

That would've been good to know 😅

@ChrisRackauckas ChrisRackauckas force-pushed the gaussadjoint_callbacks branch from 07d6bdf to 33a57d2 Compare June 7, 2024 01:51
src/concrete_solve.jl Outdated Show resolved Hide resolved
@ChrisRackauckas
Copy link
Member Author

Bump?

@frankschae frankschae force-pushed the gaussadjoint_callbacks branch from f248d52 to ed0f98f Compare June 29, 2024 05:34
@frankschae
Copy link
Member

hmm ... it's probably close but it currently fails for the callback where the correction should be 0:

g(sol) = sum(sol)
function dg!(out, u, p, t, i)
    (out .= 1)
end
@testset "callbacks with no effect" begin
    condition(u, t, integrator) = t == 5
    affect!(integrator) = integrator.u[1] += 0.0
    cb = DiscreteCallback(condition, affect!, save_positions = (false, false))
    tstops = [5.0]
    test_discrete_callback(cb, tstops, g, dg!)
end

I think it's the line:

if sensealg isa GaussAdjoint
            @assert integrator.f.f isa ODEGaussAdjointSensitivityFunction
            @show integrator.f.f.integrating_cb.affect!.integrand_values.integrand dgrad
            integrator.f.f.integrating_cb.affect!.integrand_values.integrand .= dgrad

in callback_tracking.jl, whch for the example above prints:

integrator.f.f.integrating_cb.affect!.integrand_values.integrand = [4.318834073956617, -23.50595577533703, 3.507883353046969, -26.70725798446173]
dgrad = [-0.0, -0.0, -0.0, -0.0]

@ChrisRackauckas
Copy link
Member Author

Is that with the VJPs as Enzyme or ReverseDiff? IIUC it always defaults to ReverseDiff right now?

@ChrisRackauckas
Copy link
Member Author

The incorrect values I think stem from missing make_zero!s, but it's currently waiting on an Enzyme tag from @wsmoses before #1067 finishes and then this can get retested.

@ChrisRackauckas
Copy link
Member Author

@jClugstor did your PR look at GaussAdjoint?

@jClugstor
Copy link
Contributor

No, but there is a '@test_broken' for 'GaussAdjoint' with a callback in that PR, so if callbacks get fixed that will need to change.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants