Testing use of rrule #1103
So the difficulty here is that the vjp used in the adjoint rule of the differential equation is almost always (at this point) computed by a different AD engine from the one the user is using. Right now users tend to gravitate towards Zygote, and almost all ODE definitions are mutating. So the user uses Zygote, it hits that chain rule, and it goes to SciMLSensitivity.jl, which defines the full adjoint and has the AD dependencies, and which would then normally slap Enzyme in there for the vjp. If they haven't added SciMLSensitivity.jl, they get an error: https://github.com/SciML/DiffEqBase.jl/blob/master/src/solve.jl#L1593-L1597.

Because of that, the real test suite is in SciMLSensitivity.jl: this package just shuttles over to there, and the rule lives here so that we can give an informative error message. Our current tests on "alternative AD frontends", i.e. non-Zygote ones, are in https://github.com/SciML/SciMLSensitivity.jl/blob/master/test/alternative_ad_frontend.jl. That's not exactly comprehensive, but the vast majority of the work in the package is on the adjoint definitions rather than on the AD front end, and so the front end has tended to work. It's a little bit of a mess that these tests don't live with DiffEqBase, but since most of the work is on adjoints, the front-end parts in DiffEqBase don't tend to change much. This situation will hopefully improve with JuliaLang/julia#55516, but I digress.

From what I see in your set of front-end tests, one of the big things you're missing is flexing the solution interface a bit more.
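A minimal sketch of the kind of front-end check described above, assuming recent versions of OrdinaryDiffEq, SciMLSensitivity, Zygote and ForwardDiff. It is not the SciMLSensitivity.jl test suite: the helper names (`solve_with`, `reference_grad`, `check_frontend`, `zygote_grad`), the chosen `sensealg`s and the tolerances are hypothetical. It solves a mutating Lotka-Volterra problem, passes `p` as a keyword to `solve` (so reverse-mode AD goes through the `solve_up` rrule), and compares reverse-mode gradients against ForwardDiff for losses that read the solution in different ways (`Array(sol)` and `sol[1, :]`, which assume the solution-indexing adjoints shipped with SciMLBase/RecursiveArrayTools).

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote, ForwardDiff

# In-place (mutating) Lotka-Volterra right-hand side, the common case for ODE definitions.
function lotka_volterra!(du, u, p, t)
    α, β, δ, γ = p
    du[1] = α * u[1] - β * u[1] * u[2]
    du[2] = -δ * u[2] + γ * u[1] * u[2]
    return nothing
end

const P0 = [1.5, 1.0, 3.0, 1.0]
const PROB = ODEProblem(lotka_volterra!, [1.0, 1.0], (0.0, 10.0), P0)

# Hypothetical losses, each reading the solution object through a different interface.
const LOSSES = [
    ("Array(sol)", sol -> sum(abs2, Array(sol))),
    ("sol[1, :]", sol -> sum(sol[1, :])),
]

# A few adjoint choices; each selects a different reverse-mode strategy in SciMLSensitivity.
const SENSEALGS = [InterpolatingAdjoint(), QuadratureAdjoint(), ForwardDiffSensitivity()]

# Hypothetical helper: passing `p` through `solve` routes reverse-mode AD via the `solve_up` rrule.
solve_with(p, sensealg) = solve(PROB, Tsit5(); p = p, sensealg = sensealg,
                                saveat = 0.5, abstol = 1e-10, reltol = 1e-10)

# Reference gradient: forward-mode AD differentiates through the solver directly.
reference_grad(loss) = ForwardDiff.gradient(p -> loss(solve_with(p, nothing)), P0)

# Compare a reverse-mode front end against the reference for every (loss, sensealg) pair.
# `grad_fn(f, x)` must return the gradient of the scalar function `f` at `x`.
function check_frontend(grad_fn; rtol = 1e-3)
    results = Dict{Tuple{String, String}, Bool}()
    for (name, loss) in LOSSES
        g_ref = reference_grad(loss)
        for alg in SENSEALGS
            g = grad_fn(p -> loss(solve_with(p, alg)), P0)
            results[(name, string(nameof(typeof(alg))))] = isapprox(g, g_ref; rtol = rtol)
        end
    end
    return results
end

# Zygote as the front end; `only` unwraps the one-element tuple returned by `Zygote.gradient`.
zygote_grad(f, x) = only(Zygote.gradient(f, x))

results = check_frontend(zygote_grad)
```

To exercise a different AD front end, pass any function with the same `grad_fn(f, x)` shape in place of `zygote_grad`; the losses and `sensealg`s stay the same, so only the front-end wrapper changes.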
Thanks for your help with this @ChrisRackauckas. I've added some more tests to the Mooncake PR now, which I think cover all of the bits you've mentioned above. I've also replied to your comment on the PR with a couple of additional questions.
Question❓
This package defines an extension for ChainRulesCore, in which an `rrule` for `solve_up` is defined: DiffEqBase.jl/ext/DiffEqBaseChainRulesCoreExt.jl, line 22 in 9de748d.
I am attempting to make use of this rule in Mooncake in compintell/Mooncake.jl#320, to ensure that I can use Mooncake to differentiate functions which solve differential equations internally. I am trying to test that I've implemented my wrapper around this rule correctly (see https://github.com/compintell/Mooncake.jl/pull/320/files#r1835360913). However, I have no idea whether I've got a good collection of tests: you'll see that I really just test that the function `build_and_solve` (which solves the Lotka-Volterra equations) can be differentiated using Mooncake for a variety of `sensealg`s.

My question: how and where is the `rrule` for `solve_up` tested inside this package, and is there an existing set of tests that I can either copy over or take inspiration from, in order to check that my wrapper for the `rrule` works in all cases?