Common interface for rule definition #644

Open
gdalle opened this issue Nov 27, 2024 · 5 comments
Labels
backend Related to one or more autodiff backends core Related to the core utilities of the package

Comments

@gdalle
Member

gdalle commented Nov 27, 2024

It would be nice to have a universal translator between:

At the moment, DI.DifferentiateWith is a partial answer, but

  • it only works for one argument, number or array
  • it only defines rules for ForwardDiff and ChainRules (but that's not a fundamental limitation, just a question of time)
  • it requires users to replace every f with DifferentiateWith(f, other_backend)
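For readers unfamiliar with the current workaround, here is a minimal sketch of the DifferentiateWith workflow. It assumes a recent DifferentiationInterface release with its ForwardDiff extension loaded; FiniteDiff as the substitute backend and the function `sumsq_buffered` are illustrative choices, not part of the issue.

```julia
using DifferentiationInterface   # exports DifferentiateWith and re-exports ADTypes backends such as AutoFiniteDiff
import ForwardDiff, FiniteDiff

# Illustrative function that ForwardDiff cannot trace on its own:
# the Float64 buffer cannot hold ForwardDiff.Dual numbers.
function sumsq_buffered(x::AbstractVector)
    buf = Vector{Float64}(undef, 1)
    buf[1] = sum(abs2, x)
    return buf[1]
end

# Wrap the function so that, when ForwardDiff calls it, the derivative
# is computed by the substitute backend (FiniteDiff here) instead.
f2 = DifferentiateWith(sumsq_buffered, AutoFiniteDiff())

g = ForwardDiff.gradient(f2, [3.0, 5.0])   # approximately [6.0, 10.0], i.e. 2x
```

This also illustrates the third bullet: every call site has to use `f2` instead of `sumsq_buffered`.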
@gdalle gdalle added backend Related to one or more autodiff backends core Related to the core utilities of the package labels Nov 27, 2024
@gdalle
Member Author

gdalle commented Nov 27, 2024

@antoine-levitt feel free to add comments here

@willtebbutt I'd love to hear your thoughts

@willtebbutt
Member

willtebbutt commented Nov 27, 2024

I think we ought to be able to do a reasonable job of translating between Mooncake tangents / fdata / rdata and Enzyme's Duplicated / Active system, because they're both quite precisely specified. There are some differences, for example, if the primal value is a Vector{Int}, then the Mooncake tangent type is Vector{Mooncake.NoTangent} while the Enzyme type would (I believe) be another Vector{Int}, or might require that you make it a Const.

Re ChainRules, we'd have to make some choices, because the conversion is fundamentally ambiguous (ChainRules permits you to represent the tangent of anything with anything). That being said, I have had to attempt this for the ChainRules integration in Mooncake (see here), and it's currently rather unsatisfying and incomplete.

This is all to say that I doubt a truly universal translator is possible, but we ought to be able to identify a set of types for which the conversion is possible, and provide reasonably informative error messages if someone attempts to do something we don't know how to handle.
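The idea of an explicitly supported set of conversions with informative errors for everything else can be sketched in plain Julia. Everything here is hypothetical: `NoTangentStub` stands in for a "no tangent" marker such as Mooncake.NoTangent, `to_shadow` is an invented name, and mapping integer data to zeros is an arbitrary design choice rather than what Enzyme actually does.

```julia
# Hypothetical stand-in for a "no tangent" marker (in the spirit of Mooncake.NoTangent).
struct NoTangentStub end

# Supported conversions: floats and float arrays map to (a copy of) themselves.
to_shadow(t::Float64) = t
to_shadow(t::Array{Float64}) = copy(t)

# Integer primal data carries no derivative information. Here we choose to map
# an array of NoTangentStub to zeros of the primal integer type (an assumption).
to_shadow(t::Array{NoTangentStub}, ::Type{T}) where {T<:Integer} = zeros(T, size(t))

# Anything else: fail loudly with an explanation instead of guessing.
to_shadow(t, args...) = throw(ArgumentError(
    "no known conversion for tangent of type $(typeof(t)); supported: Float64, " *
    "Array{Float64}, and Array{NoTangentStub} together with an Integer primal type"))
```

The design choice is to enumerate conversions one method at a time, so that unsupported cases hit the single fallback and produce an error that names the offending type.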

@yebai

yebai commented Nov 28, 2024

ReverseDiff and ForwardDiff are likely the most used autograd backends for Turing.jl, so we would like to keep supporting them even after Mooncake / Enzyme become more stable. Meanwhile, we will likely deprecate Zygote / Tracker in favour of Mooncake.

Unfortunately, supporting ReverseDiff and ForwardDiff means we must maintain and add many extra rules (see, e.g. DistributionsAD.jl, Bijectors.jl, and DynamicPPL.jl). I'd like to see an option to use Mooncake / Enzyme to define rules for ReverseDiff / ForwardDiff straightforwardly in the near future.

EDIT: I just noticed it is already possible to define ForwardDiff rules via Enzyme / Mooncake using DI.DifferentiateWith. Can similar functionality be added for ReverseDiff?

EDIT 2: it is often okay to manually replace f with DifferentiateWith(f, backend), so we don't necessarily need a new (universal) rule system. But please hook DifferentiateWith(f, backend) into more autograd libraries, as has been done for ForwardDiff / Zygote.

@antoine-levitt

antoine-levitt commented Nov 28, 2024

Yeah, from a user point of view it doesn't really matter what rule system we use, but we'd like to use just one. That unfortunately looks quite tricky.

In my application I want to differentiate with respect to things that are hidden in structs, e.g. struct A{T} a::T end; f(x::A) = x.a^2. The appropriate ForwardDiff overload is f(x::A{<:Dual}), but there's really no way for a rule system to know this automatically, so I guess some amount of manual adaptation is needed; see e.g. ForwardDiffChainRules, which requires you to write something like @ForwardDiff_frule f(x::A{Dual}). Is there anything else that would prevent a UniversalChainRules package (bad name, I know) where I would write @universal_rrule f(x)... and that would do @rrule and also tell Enzyme/whatever to import it?
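To see why the overload has to name the struct explicitly, here is a self-contained sketch with a toy dual number. `MyDual` is a hypothetical stand-in for ForwardDiff.Dual, and the `rule_calls` counter exists only to make visible which method dispatch selects.

```julia
# Hypothetical toy dual number standing in for ForwardDiff.Dual (illustration only).
struct MyDual
    val::Float64   # primal value
    der::Float64   # derivative (tangent) part
end

struct A{T}
    a::T
end

# Generic definition from the comment above.
f(x::A) = x.a^2

# Counts how often the hand-written rule is hit, to make dispatch visible.
const rule_calls = Ref(0)

# Hand-written forward rule, selected only because the signature names A{<:MyDual}:
# nothing in f(x::A) tells a rule system that the differentiable input lives in field `a`.
function f(x::A{<:MyDual})
    rule_calls[] += 1
    v = x.a.val
    return MyDual(v^2, 2v * x.a.der)   # chain rule: d(v^2) = 2v * dv
end

f(A(3.0))                    # plain path: 9.0
y = f(A(MyDual(3.0, 1.0)))   # rule path: MyDual(9.0, 6.0)
```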

@wsmoses

wsmoses commented Nov 28, 2024

You can just do Enzyme.@import_rrule / @import_frule and it'll import whatever chainrule is defined for that.

So your "universal_rrule" macro for the moment would be something like

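macro universal_rrule(args...)
  esc(quote
    @rrule $(args...)                # placeholder for whatever defines the ChainRules rrule
    Enzyme.@import_rrule $(args...)  # then import that rrule into Enzyme
  end)
end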

That said, I really don't think this belongs here; it should probably go in ChainRules or something similar.
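The "write once, register in several systems" idea behind such a macro can be illustrated without any AD package. The two dictionaries below are hypothetical stand-ins for rule tables and `@universal_rule` is an invented name; real integration would instead go through `ChainRulesCore.rrule` and `Enzyme.@import_rrule`.

```julia
# Hypothetical rule tables standing in for two AD systems (illustration only).
const SYSTEM_A_RULES = Dict{Symbol,Function}()
const SYSTEM_B_RULES = Dict{Symbol,Function}()

# Write a rule once; the macro expands into one registration per system.
macro universal_rule(fname::Symbol, rule)
    key = QuoteNode(fname)
    return quote
        local r = $(esc(rule))
        $SYSTEM_A_RULES[$key] = r
        $SYSTEM_B_RULES[$key] = r
        r
    end
end

# Reverse rule for square(x) = x^2: returns the primal value and a pullback.
@universal_rule(square, x -> (x^2, dy -> 2x * dy))
```

Both tables end up holding the same closure, so the two "systems" cannot drift apart; the hard part raised in this thread, translating tangent types between systems, is deliberately left out.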

5 participants