Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use LogDensityProblems instead of gradient_logp #1877

Merged
merged 8 commits into from
Aug 26, 2022

Conversation

devmotion
Copy link
Member

@devmotion devmotion commented Aug 17, 2022

The PR requires tpapp/LogDensityProblems.jl#86

The PR removes gradient_logp and instead implements the LogDensityProblems.jl interface for Turing.LogDensityFunction. It seemed to work on simple examples (haven't run the full test suite yet) and allows us to re-use existing implementations and functionalities in LogDensityProblems, such as compilation of tapes for ReverseDiff and support of Enzyme (see comment below). That simplifies the code significantly and e.g. let's us remove the Requires block for ReverseDiff (taken care of in LogDensityProblems) and Memoization.

Probably in the future we might want to move LogDensityFunction and its implementation of the LogDensityProblems interface to DynamicPPL. And we might want to move the AD backend types (TrackerAD, ForwardDiffAD etc.) and their (to be improved?) API to a lightweight package (TuringADCore?) such that AdvancedVI, AdvancedHMC, Bijectors etc. could also use it instead of duplicating the implementation.


Enzyme support can be added by something like

struct EnzymeAD <: ADBackend end

ADBackend(::Val{:enzyme}) = EnzymeAD
function _setadbackend(::Val{:enzyme})
    ADBACKEND[] = :enzyme
end

function LogDensityProblems.ADgradient(::EnzymeAD, ℓ::Turing.LogDensityFunction)
    return LogDensityProblems.ADgradient(Val(:Enzyme), ℓ)
end

Unfortunately, it seems it does not even work on simple examples such as

@model function model()
    m ~ Normal(0, 1)
    s ~ InverseGamma()
    x ~ Normal(m, s)
end
sample(model() | (; x=0.5), NUTS{Turing.EnzymeAD}(), 10)

Intentionally the error messages and stacktraces are not included here since it is not part of this PR.

@torfjelde
Copy link
Member

Though I really appreciate the compatibility with Enzyme, I'm a bit uncertain about adoping LogDensityProblems 😕 It's somewhat rigid in what it expects/wants, and we're taking on a fair bit of features that we don't need (e.g. TransformVariables.jl).

Is this really the best way we can get support for Enzyme?

I do agree with you about generalizing/separating out the AD-stuff though. I was hoping that we at some point could just be using AbstractDifferentiation.jl, but AFAIK this won't fix the issue of having to use Requires.jl.

@devmotion
Copy link
Member Author

I think I disagree quite a bit 😄

Maybe the title of the PR and the Enzyme were both a bit misleading. Of course, we could implement support for Enzyme by basically the example above and copying the code from LogDensityProblems. But Enzyme support is not the motivation here (apart from that I think it might be better to improve our code before adding support for another AD) as currently it does not work anyway on the simple examples I tried.

The change is purely internal and does not require any major rewrite of Turing. It's just replacing the current gradient_logp with logdensity_and_gradient(ADgradient(...)). We have to create the gradient functions anyway, currently that's done in gen_... but with this PR it only requires ADgradient(...) which is arguably a bit nicer and in particular creates something we can dispatch on. Also LogExpFunctions already exists and let's us retrieve the model, sampler, varinfo, and context, so it mainly bundles these but does not limit us internally.

Generally, the LogDensityProblems interface is very simple and general: it's just logdensity, logdensity_and_gradient, capabilities, and dimension. The dimension we already computed in the grad_logp code (but had no function for it), logdensity and logdensity_and_gradient we need anyway, and the capabilities are a simple trait used to distinguish between log density functions for which logdensity_and_gradient can be computed and for which not. LogDensityProblems adds support for TransformVariables in form of a log density function type that deals with these transformations but we don't use it and don't have to use it. The main part of the package is really the API and the AD wrappers. I think TransformVariables is also a lightweight dependency as we depend on all of its dependencies anyway: https://github.com/tpapp/TransformVariables.jl/blob/master/Project.toml Apart from TransformVariables, the same is true for LogDensityProblems as well it seems: https://github.com/tpapp/LogDensityProblems.jl/blob/master/Project.toml

For quite some time I thought as well that we might want to switch to AbstractDifferentiation. However, I became more and more convinced that - apart from stability and implementation issues - it is the wrong level to target from our side. We don't just want to differentiate a function but rather we have to build a function from the log density function of the model for which we can evaluate the primal and the gradient efficiently. In particular, when we want to compile tapes or cache stuff it's mandatory to wrap the user/model-provided log density function when initializing the sampling procedure. But that's exactly what ADgradient is designed for - it takes the log density function and wraps it in a type that allows us to perform all these optimizations (eg by storing the compiled tape in the case of ReverseDiff, if desired). BTW we're also not limited by the default implementations in LogDensityProblems for ADgradient but - as already done in this PR - can specialize it for our LogDensityFunction and the AD backend types that we support (e.g. by building custom gradient configs with our custom tags in the ForwardDiff case). It's rather that I think in the long run LogDensityProblems might want to get rid of the Requires block and move to AbstractDifferentiation, once it is a bit more stable and mature and does not use Requires itself anymore.

So the short summary of my longer reply is probably: I think the design of LogDensityProblems, and in particular of ADgradient, matches quite well what we need in Turing, feels like the right level of abstraction (compared with AbstractDifferentiation), and allows us to simplify our code and re-use existing functionality in the ecosystem.

@yebai
Copy link
Member

yebai commented Aug 19, 2022

I like it because it simplifies current code and adds features. I tend to agree with @torfjelde and am concerned that we don't influence future maintenance of LogDensityProblems.

@tpapp Would it make sense chance that LogDensityProblems become part of JuliaMath or JuliaStats (see, e.g. DensityInterface) so the community can keep looking after it?

Cc @oschulz @phipsgabler who are the developers of DensityInterface

@tpapp
Copy link
Contributor

tpapp commented Aug 19, 2022

I am happy to transfer LogDensityProblems.jl.

@phipsgabler
Copy link
Member

I like the idea and trust David's judgement that it LogDensityProblems fits our needs. The only concern I have is that currently it has an empty intersection with DensityInterface, which I preferred as the requirement for AbstractPPL, but was considered not mature enough to serve as base for LogDensityInterface when it was begun (cf. tpapp/LogDensityProblems.jl#78). We would have two different kinds of interface, especially logdensity and logdensityof...

So my question is, how to best reconcile this?

@devmotion
Copy link
Member Author

My initial guess would be: LogDensityProblems could implement the DensityInterface interface by deprecating logdensity instead of DensityInterface.logdensityof and defining the trait IsDensity for the log density functions it owns. However, I don't think LogDensityProblems can or should become obsolete by DensityInterface and I think for us it's not sufficient to work with DensityInterface alone: we need something like logdensity_and_gradient as well, and we want something that allows us to bundle "regular" log density functions with their gradient for different AD backends (like ADgradient). Both are not provided by DensityInterface.

@oschulz
Copy link

oschulz commented Aug 20, 2022

I agree with @devmotion - DensityInterface.jl was essentially designed for this use case (with input from some Turing devs like @phipsgabler ).

(Also, Distributions and MeasureBase/MeasureInterface already support DensityInterface.)

@oschulz
Copy link

oschulz commented Aug 20, 2022

CC @cscherrer

@oschulz
Copy link

oschulz commented Aug 20, 2022

We also have AbstractDifferentiation.jl now - maybe combining that with DensityInterface.jl will be already be sufficient?

@devmotion
Copy link
Member Author

We also have AbstractDifferentiation.jl now - maybe combining that with DensityInterface.jl will be already be sufficient?

In Turing? Or LogDensityProblems? I think that it is not sufficient for us in Turing as we need the wrappers of logdensity function + gradient that are provided by LogDensityProblems. Quoted from above:

For quite some time I thought as well that we might want to switch to AbstractDifferentiation. However, I became more and more convinced that - apart from stability and implementation issues - it is the wrong level to target from our side. We don't just want to differentiate a function but rather we have to build a function from the log density function of the model for which we can evaluate the primal and the gradient efficiently. In particular, when we want to compile tapes or cache stuff it's mandatory to wrap the user/model-provided log density function when initializing the sampling procedure. But that's exactly what ADgradient is designed for - it takes the log density function and wraps it in a type that allows us to perform all these optimizations (eg by storing the compiled tape in the case of ReverseDiff, if desired). BTW we're also not limited by the default implementations in LogDensityProblems for ADgradient but - as already done in this PR - can specialize it for our LogDensityFunction and the AD backend types that we support (e.g. by building custom gradient configs with our custom tags in the ForwardDiff case). It's rather that I think in the long run LogDensityProblems might want to get rid of the Requires block and move to AbstractDifferentiation, once it is a bit more stable and mature and does not use Requires itself anymore.

@oschulz
Copy link

oschulz commented Aug 20, 2022

In Turing? Or LogDensityProblems? I think that it is not sufficient for us in Turing as we need the wrappers of logdensity function + gradient that are provided by LogDensityProblems.

Shouldn't something like this work?

params -> AD.value_and_gradient(ad_backend, logdensityof(posterior), params)

For example (with an MvNormal acting as a dummy posterior):

julia> using DensityInterface, AbstractDifferentiation, Zygote, Distributions
julia> d = MvNormal([1.2 0.5; 0.5 2.1])
julia> AD.value_and_gradient(AD.ZygoteBackend(), logdensityof(d), rand(d))
(-2.4361998489144887, ([0.5451883211220292, 0.032870160010889764],))

@devmotion
Copy link
Member Author

Surely you can use AbstractDifferentiation to get primal and gradient of a function for some backend (even though personally I don't think it's mature and efficient enough for general adoption yet), but it's not sufficient for our use cases. The best example is tape compilation with ReverseDiff: we want to perform and store optimizations when we initialize the sampling procedure. Hence just calling value_and_gradient with the logdensity function of the model every time we want the primal and gradient is not sufficient. And when we do that we exactly end up with something like https://github.com/tpapp/LogDensityProblems.jl/blob/d3ea2615ba3fd3e269c7948f02678d3ab8906e6d/src/AD_ReverseDiff.jl#L7-L10.

Hence to me it seems AbstractDifferentiation is too low-level for our purposes. But I assume for at least some backends LogDensityProblems could probably use the AbstractDifferentiation API at some point. Currently, I don't see a clear benefit though since it would not remove the Requires dependency and the implementation in LogDensityProblems should be as efficient as possible whereas I'm not fully convinced that's the case with AbstractDifferentiation yet.

@oschulz
Copy link

oschulz commented Aug 20, 2022

@tpapp could LogDensityProblems support DensityInterface? DensityInterface.DensityKind would enable LogDensityProblems if a given object supports DensityInterface and take advantage of it.

@oschulz
Copy link

oschulz commented Aug 20, 2022

The best example is tape compilation with ReverseDiff: we want to perform and store optimizations when we initialize the sampling procedure

Hm, that's a useful thing in general. Maybe we should lobby for AbstractDifferentiation to add a caching mechanism like this?

@devmotion
Copy link
Member Author

Maybe we should lobby for AbstractDifferentiation to add a caching mechanism like this?

There's already an issue, so I think people are aware of it: JuliaDiff/AbstractDifferentiation.jl#41 It's not clear though how it could be done.

@oschulz
Copy link

oschulz commented Aug 20, 2022

There's already an issue,

Ah, thanks!

@devmotion
Copy link
Member Author

Regarding this PR here: The Tracker test errors are caused by Julia 1.8, it broke some stuff in Tracker (I've seen also test failures in e.g. AbstractDifferentiation and DiffRules). FluxML/Tracker.jl#125

@tpapp
Copy link
Contributor

tpapp commented Aug 21, 2022

@oschulz could LogDensityProblems support DensityInterface?

Possibly, but I do not fully understand what the question is here and what that would involve. Can you please provide an example, eg what a user or implementor would need?

@devmotion I assume for at least some backends LogDensityProblems could probably use the AbstractDifferentiation API at some point.

Yes, that's the intention, once the latter stabilizes.

Currently LogDensityProblems comprises two things:

  1. an abstract API for log densities and their gradients, which would properly belong in their own tiny interface package,
  2. AD glue.

The reason I haven't split this into two packages is that I am really hoping that AbstractDifferentiation will take care of the AD glue in the long run, at which point it will be removed from the package.

In any case, I am happy to extend LogDensityProblems, transfer it and/or add people as maintainers, and PRs are welcome as always.

@oschulz
Copy link

oschulz commented Aug 21, 2022

Possibly, but I do not fully understand what the question is here and what that would involve. Can you please provide an example, eg what a user or implementor would need?

We built DensityInterface (with input from devs from Distributions, Turing, and other packages) so that people can define density-like objects using a super-lightweight dependency. Code that uses density-like objects (algorithm code) will typically need more deps, of course, like AD and so on.

So I think it would be nice if LogDensityProblems (which is closer to the algorithmic side since it also handles AD and specific transformations) would support DensityInterface densities. I have to admit I'm not quite sure how, but maybe we could figure something out together? In principle default implementation of LogDensityProblems.capabilities would seem the ideal pleace to use DensityInterface.DensityKind to query and object, but capabilities works on types and DensityKind works on instances. And DensityInterface doesn't provide dimensionality information (it is also intended for discrete cases and so on), would that be mandatory or could it be optional for LogDensityProblems?

@phipsgabler
Copy link
Member

phipsgabler commented Aug 23, 2022

I think our "problem" isn't technically difficult -- I have created a PoC port of LogDensityProblems to use DensityInterface here (diff view). Rather, it is of "social" nature: it requires messing with the interface of an existing package, deviating from the original considerations of the package author and surprising the existing user base, both of which make me hesitate.

If anything, we all need to have a discussion about how to reconcise the interfaces -- not everything is immediately clear (especially DensityKind, and expectations about usage of functions vs. dedicated callable objects).

@devmotion
Copy link
Member Author

I'd have some comments but I guess it would be better to discuss these things in an issue over at LogDensityProblems? (IIRC there was already an issue which is closed now.) There's nothing user-facing in this PR and no new API, so I think in principal it does not require any changes to or discussions of interfaces.

@tpapp
Copy link
Contributor

tpapp commented Aug 23, 2022

I agree with @devmotion, but want to reiterate that I am happy to add anything to LogDensityProblems that helps with this PR.

@torfjelde
Copy link
Member

I think I disagree quite a bit

Appreciate the write-up @devmotion!

And I'm also quite happy with the "interface" that is provided from LogDensityProblems.jl. One of my main concerns is that we'd be depending on TransformVariables.jl but wouldn't be using it, but, I as you pointed out, it doesn't seem like a particularly heavy dependency. If we're happy with this, I'm happy with taking it on as a dep:)

@oschulz
Copy link

oschulz commented Aug 24, 2022

we'd be depending on TransformVariables.jl

Can this be replace by depending on ChangesOfVariables.jl (much lighter), since TransformVariables.jl supports it now?

@devmotion
Copy link
Member Author

It does not seem completely trivial, I faced some problems: tpapp/LogDensityProblems.jl#88

@coveralls
Copy link

Pull Request Test Coverage Report for Build 2920168242

  • 52 of 54 (96.3%) changed or added relevant lines in 6 files are covered.
  • 3 unchanged lines in 3 files lost coverage.
  • Overall coverage increased (+1.2%) to 82.219%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/Turing.jl 3 5 60.0%
Files with Coverage Reduction New Missed Lines %
src/essential/ad.jl 1 84.29%
src/inference/gibbs.jl 1 97.33%
src/Turing.jl 1 81.25%
Totals Coverage Status
Change from base Build 2836425564: 1.2%
Covered Lines: 1156
Relevant Lines: 1406

💛 - Coveralls

@codecov
Copy link

codecov bot commented Aug 24, 2022

Codecov Report

Merging #1877 (005b6e0) into master (c0c8bc2) will increase coverage by 1.18%.
The diff coverage is 96.29%.

@@            Coverage Diff             @@
##           master    #1877      +/-   ##
==========================================
+ Coverage   81.03%   82.21%   +1.18%     
==========================================
  Files          24       21       -3     
  Lines        1466     1406      -60     
==========================================
- Hits         1188     1156      -32     
+ Misses        278      250      -28     
Impacted Files Coverage Δ
src/inference/Inference.jl 84.55% <ø> (+0.68%) ⬆️
src/Turing.jl 81.25% <60.00%> (-18.75%) ⬇️
src/contrib/inference/dynamichmc.jl 100.00% <100.00%> (ø)
src/contrib/inference/sghmc.jl 98.50% <100.00%> (+0.09%) ⬆️
src/essential/ad.jl 84.28% <100.00%> (+2.09%) ⬆️
src/inference/hmc.jl 78.06% <100.00%> (ø)
src/modes/ModeEstimation.jl 82.78% <100.00%> (ø)
src/inference/gibbs.jl 97.33% <0.00%> (ø)
... and 4 more

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@tpapp
Copy link
Contributor

tpapp commented Aug 25, 2022

In response to the discussion above, I have excised TransformVariables.jl as a dependency from LogDensityProblems.jl. See https://github.com/tpapp/TransformedLogDensities.jl/ (where it ended up) and tpapp/LogDensityProblems.jl#89 (which will be merged once the first package is registered and stuff is cleaned up).

Copy link
Member

@yebai yebai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this PR brings some excellent improvement. Many thanks, @phipsgabler, @devmotion and @tpapp!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants