From a634bf6f1ae38908b7a21551207e768ae174f2df Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 17 Aug 2022 23:26:04 +0200 Subject: [PATCH 1/8] Use LogDensityProblems instead of `gradient_logp` --- Project.toml | 5 +- docs/src/using-turing/autodiff.md | 2 +- src/Turing.jl | 9 ++ src/contrib/inference/dynamichmc.jl | 41 +++------ src/contrib/inference/sghmc.jl | 22 +++-- src/essential/Essential.jl | 15 +--- src/essential/ad.jl | 126 ++++++++-------------------- src/essential/compat/reversediff.jl | 80 ------------------ src/inference/Inference.jl | 4 +- src/inference/hmc.jl | 27 +++--- src/modes/ModeEstimation.jl | 11 +-- test/Project.toml | 2 - test/essential/ad.jl | 28 ++++--- test/runtests.jl | 3 +- test/skipped/unit_test_helper.jl | 6 +- test/test_utils/ad_utils.jl | 8 +- 16 files changed, 121 insertions(+), 268 deletions(-) delete mode 100644 src/essential/compat/reversediff.jl diff --git a/Project.toml b/Project.toml index f390512ca..95e806677 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,6 @@ AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -20,6 +19,7 @@ EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -43,7 +43,6 @@ AdvancedVI = "0.1" BangBang = "0.3" Bijectors = "0.8, 0.9, 0.10" DataStructures = "0.18" -DiffResults = "1" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" @@ -51,6 +50,7 @@ DynamicPPL = "0.20" EllipticalSliceSampling = "0.5, 1" ForwardDiff = "0.10.3" Libtask = "0.6.7, 0.7" +LogDensityProblems = "0.11" MCMCChains = "5" NamedArrays = "0.9" Reexport = "0.2, 1" @@ -60,5 +60,4 @@ SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2" StatsBase = "0.32, 0.33" StatsFuns = "0.8, 0.9, 1" Tracker = "0.2.3" -ZygoteRules = "0.2" julia = "1.6" diff --git a/docs/src/using-turing/autodiff.md b/docs/src/using-turing/autodiff.md index dd9e2976d..44eff6133 100644 --- a/docs/src/using-turing/autodiff.md +++ b/docs/src/using-turing/autodiff.md @@ -10,7 +10,7 @@ title: Automatic Differentiation Turing supports four packages of automatic differentiation (AD) in the back end during sampling. The default AD backend is [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) for forward-mode AD. Three reverse-mode AD backends are also supported, namely [Tracker](https://github.com/FluxML/Tracker.jl), [Zygote](https://github.com/FluxML/Zygote.jl) and [ReverseDiff](https://github.com/JuliaDiff/ReverseDiff.jl). `Zygote` and `ReverseDiff` are supported optionally if explicitly loaded by the user with `using Zygote` or `using ReverseDiff` next to `using Turing`. 
-To switch between the different AD backends, one can call function `Turing.setadbackend(backend_sym)`, where `backend_sym` can be `:forwarddiff` (`ForwardDiff`), `:tracker` (`Tracker`), `:zygote` (`Zygote`) or `:reversediff` (`ReverseDiff.jl`). When using `ReverseDiff`, to compile the tape only once and cache it for later use, the user needs to load [Memoization.jl](https://github.com/marius311/Memoization.jl) first with `using Memoization` then call `Turing.setrdcache(true)`. However, note that the use of caching in certain types of models can lead to incorrect results and/or errors. Models for which the compiled tape can be safely cached are models with fixed size loops and no run-time if statements. Compile-time if statements are fine. To empty the cache, you can call `Turing.emptyrdcache()`. +To switch between the different AD backends, one can call function `Turing.setadbackend(backend_sym)`, where `backend_sym` can be `:forwarddiff` (`ForwardDiff`), `:tracker` (`Tracker`), `:zygote` (`Zygote`) or `:reversediff` (`ReverseDiff.jl`). When using `ReverseDiff`, to compile the tape only once and cache it for later use, the user has to call `Turing.setrdcache(true)`. However, note that the use of caching in certain types of models can lead to incorrect results and/or errors. Models for which the compiled tape can be safely cached are models with fixed size loops and no run-time if statements. Compile-time if statements are fine. ## Compositional Sampling with Differing AD Modes diff --git a/src/Turing.jl b/src/Turing.jl index 86eeb9be3..fd9f1b861 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -9,6 +9,7 @@ using Tracker: Tracker import AdvancedVI import DynamicPPL: getspace, NoDist, NamedDist +import LogDensityProblems import Random const PROGRESS = Ref(true) @@ -37,12 +38,20 @@ function (f::LogDensityFunction)(θ::AbstractVector) return getlogp(last(DynamicPPL.evaluate!!(f.model, VarInfo(f.varinfo, f.sampler, θ), f.sampler, f.context))) end +# LogDensityProblems interface +LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) = f(θ) +LogDensityProblems.dimension(f::LogDensityFunction) = length(f.varinfo[f.sampler]) +function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) + return LogDensityProblems.LogDensityOrder{0}() +end + # Standard tag: Improves stacktraces # Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ struct TuringTag end # Allow Turing tag in gradient etc. calls of the log density function ForwardDiff.checktag(::Type{ForwardDiff.Tag{TuringTag, V}}, ::LogDensityFunction, ::AbstractArray{V}) where {V} = true +ForwardDiff.checktag(::Type{ForwardDiff.Tag{TuringTag, V}}, ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:LogDensityFunction}, ::AbstractArray{V}) where {V} = true # Random probability measures. include("stdlib/distributions.jl") diff --git a/src/contrib/inference/dynamichmc.jl b/src/contrib/inference/dynamichmc.jl index 6a9207217..552fb6e98 100644 --- a/src/contrib/inference/dynamichmc.jl +++ b/src/contrib/inference/dynamichmc.jl @@ -19,25 +19,6 @@ DynamicNUTS{AD}(space::Symbol...) 
where AD = DynamicNUTS{AD, space}() DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space -# Only define traits for `DynamicNUTS` sampler to avoid type piracy and surprises -# TODO: Implement generally with `LogDensityProblems` -const DynamicHMCLogDensity{M<:Model,S<:Sampler{<:DynamicNUTS},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext} - -function DynamicHMC.dimension(ℓ::DynamicHMCLogDensity) - return length(ℓ.varinfo[ℓ.sampler]) -end - -function DynamicHMC.capabilities(::Type{<:DynamicHMCLogDensity}) - return DynamicHMC.LogDensityOrder{1}() -end - -function DynamicHMC.logdensity_and_gradient( - ℓ::DynamicHMCLogDensity, - x::AbstractVector, -) - return gradient_logp(x, ℓ.varinfo, ℓ.model, ℓ.sampler, ℓ.context) -end - """ DynamicNUTSState @@ -46,9 +27,10 @@ State of the [`DynamicNUTS`](@ref) sampler. # Fields $(TYPEDFIELDS) """ -struct DynamicNUTSState{V<:AbstractVarInfo,C,M,S} +struct DynamicNUTSState{L,V<:AbstractVarInfo,C,M,S} + logdensity::L vi::V - "Cache of sample, log density, and gradient of log density." + "Cache of sample, log density, and gradient of log density evaluation." cache::C metric::M stepsize::S @@ -56,15 +38,15 @@ end # Implement interface of `Gibbs` sampler function gibbs_state( - model::Model, + ::Model, spl::Sampler{<:DynamicNUTS}, state::DynamicNUTSState, varinfo::AbstractVarInfo, ) # Update the previous evaluation. - ℓ = Turing.LogDensityFunction(varinfo, model, spl, DynamicPPL.DefaultContext()) + ℓ = state.logdensity Q = DynamicHMC.evaluate_ℓ(ℓ, varinfo[spl]) - return DynamicNUTSState(varinfo, Q, state.metric, state.stepsize) + return DynamicNUTSState(ℓ, varinfo, Q, state.metric, state.stepsize) end DynamicPPL.initialsampler(::Sampler{<:DynamicNUTS}) = SampleFromUniform() @@ -82,10 +64,13 @@ function DynamicPPL.initialstep( model(rng, vi, spl) end + # Define log-density function. + ℓ = LogDensityProblems.ADgradient(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())) + # Perform initial step. results = DynamicHMC.mcmc_keep_warmup( rng, - Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()), + ℓ, 0; initialization = (q = vi[spl],), reporter = DynamicHMC.NoProgressReport(), @@ -99,7 +84,7 @@ function DynamicPPL.initialstep( # Create first sample and state. sample = Transition(vi) - state = DynamicNUTSState(vi, Q, steps.H.κ, steps.ϵ) + state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ) return sample, state end @@ -113,7 +98,7 @@ function AbstractMCMC.step( ) # Compute next sample. vi = state.vi - ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()) + ℓ = state.logdensity steps = DynamicHMC.mcmc_steps( rng, DynamicHMC.NUTS(), @@ -129,7 +114,7 @@ function AbstractMCMC.step( # Create next sample and state. sample = Transition(vi) - newstate = DynamicNUTSState(vi, Q, state.metric, state.stepsize) + newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize) return sample, newstate end diff --git a/src/contrib/inference/sghmc.jl b/src/contrib/inference/sghmc.jl index d026b3656..0cb42c8a4 100644 --- a/src/contrib/inference/sghmc.jl +++ b/src/contrib/inference/sghmc.jl @@ -41,7 +41,8 @@ function SGHMC{AD}( return SGHMC{AD,space,typeof(_learning_rate)}(_learning_rate, _momentum_decay) end -struct SGHMCState{V<:AbstractVarInfo, T<:AbstractVector{<:Real}} +struct SGHMCState{L,V<:AbstractVarInfo, T<:AbstractVector{<:Real}} + logdensity::L vi::V velocity::T end @@ -61,7 +62,8 @@ function DynamicPPL.initialstep( # Compute initial sample and state. 
sample = Transition(vi) - state = SGHMCState(vi, zero(vi[spl])) + ℓ = LogDensityProblems.ADgradient(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())) + state = SGHMCState(ℓ, vi, zero(vi[spl])) return sample, state end @@ -74,9 +76,10 @@ function AbstractMCMC.step( kwargs... ) # Compute gradient of log density. + ℓ = state.logdensity vi = state.vi θ = vi[spl] - _, grad = gradient_logp(θ, vi, model, spl) + grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ)) # Update latent variables and velocity according to # equation (15) of Chen et al. (2014) @@ -92,7 +95,7 @@ function AbstractMCMC.step( # Compute next sample and state. sample = Transition(vi) - newstate = SGHMCState(vi, newv) + newstate = SGHMCState(ℓ, vi, newv) return sample, newstate end @@ -191,7 +194,8 @@ metadata(t::SGLDTransition) = (lp = t.lp, SGLD_stepsize = t.stepsize) DynamicPPL.getlogp(t::SGLDTransition) = t.lp -struct SGLDState{V<:AbstractVarInfo} +struct SGLDState{L,V<:AbstractVarInfo} + logdensity::L vi::V step::Int end @@ -211,7 +215,8 @@ function DynamicPPL.initialstep( # Create first sample and state. sample = SGLDTransition(vi, zero(spl.alg.stepsize(0))) - state = SGLDState(vi, 1) + ℓ = LogDensityProblems.ADgradient(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())) + state = SGLDState(ℓ, vi, 1) return sample, state end @@ -224,9 +229,10 @@ function AbstractMCMC.step( kwargs... ) # Perform gradient step. + ℓ = state.logdensity vi = state.vi θ = vi[spl] - _, grad = gradient_logp(θ, vi, model, spl) + grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ)) step = state.step stepsize = spl.alg.stepsize(step) θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ)) @@ -237,7 +243,7 @@ function AbstractMCMC.step( # Compute next sample and state. 
sample = SGLDTransition(vi, stepsize) - newstate = SGLDState(vi, state.step + 1) + newstate = SGLDState(ℓ, vi, state.step + 1) return sample, newstate end diff --git a/src/essential/Essential.jl b/src/essential/Essential.jl index d18a6e7ab..df0a9b5ac 100644 --- a/src/essential/Essential.jl +++ b/src/essential/Essential.jl @@ -13,22 +13,13 @@ import Bijectors: link, invlink using AdvancedVI using StatsFuns: logsumexp, softmax @reexport using DynamicPPL -using Requires import AdvancedPS -import DiffResults -import ZygoteRules +import LogDensityProblems include("container.jl") include("ad.jl") -function __init__() - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - include("compat/reversediff.jl") - export ReverseDiffAD, getrdcache, setrdcache, emptyrdcache - end -end - export @model, @varname, generate_observe, @@ -53,11 +44,13 @@ export @model, ForwardDiffAD, TrackerAD, ZygoteAD, + ReverseDiffAD, value, - gradient_logp, CHUNKSIZE, ADBACKEND, setchunksize, + setrdcache, + getrdcache, verifygrad, @logprob_str, @prob_str diff --git a/src/essential/ad.jl b/src/essential/ad.jl index 0e4005395..b56ce0140 100644 --- a/src/essential/ad.jl +++ b/src/essential/ad.jl @@ -18,6 +18,9 @@ end function _setadbackend(::Val{:zygote}) ADBACKEND[] = :zygote end +function _setadbackend(::Val{:reversediff}) + ADBACKEND[] = :reversediff +end const ADSAFE = Ref(false) function setadsafe(switch::Bool) @@ -39,9 +42,7 @@ struct ForwardDiffAD{chunk,standardtag} <: ADBackend end # Use standard tag if not specified otherwise ForwardDiffAD{N}() where {N} = ForwardDiffAD{N,true}() -getchunksize(::Type{<:ForwardDiffAD{chunk}}) where chunk = chunk -getchunksize(::Type{<:Sampler{Talg}}) where Talg = getchunksize(Talg) -getchunksize(::Type{SampleFromPrior}) = CHUNKSIZE[] +getchunksize(::ForwardDiffAD{chunk}) where chunk = chunk standardtag(::ForwardDiffAD{<:Any,true}) = true standardtag(::ForwardDiffAD) = false @@ -49,12 +50,24 @@ standardtag(::ForwardDiffAD) = false struct TrackerAD <: ADBackend end struct ZygoteAD <: ADBackend end +struct ReverseDiffAD{cache} <: ADBackend end + +const RDCache = Ref(false) + +setrdcache(b::Bool) = setrdcache(Val(b)) +setrdcache(::Val{false}) = RDCache[] = false +setrdcache(::Val{true}) = RDCache[] = true + +getrdcache() = RDCache[] + ADBackend() = ADBackend(ADBACKEND[]) ADBackend(T::Symbol) = ADBackend(Val(T)) ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]} ADBackend(::Val{:tracker}) = TrackerAD ADBackend(::Val{:zygote}) = ZygoteAD +ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()} + ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.") """ @@ -63,54 +76,15 @@ ADBackend(::Val) = error("The requested AD backend is not available. Make sure t Find the autodifferentiation backend of the algorithm `alg`. """ getADbackend(spl::Sampler) = getADbackend(spl.alg) -getADbackend(spl::SampleFromPrior) = ADBackend()() +getADbackend(::SampleFromPrior) = ADBackend()() -""" - gradient_logp( - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler, - ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() - ) - -Computes the value of the log joint of `θ` and its gradient for the model -specified by `(vi, sampler, model)` using whichever automatic differentation -tool is currently active. 
-""" -function gradient_logp( - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler, - ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() -) - return gradient_logp(getADbackend(sampler), θ, vi, model, sampler, ctx) +function LogDensityProblems.ADgradient(ℓ::Turing.LogDensityFunction) + return LogDensityProblems.ADgradient(getADbackend(ℓ.sampler), ℓ) end -""" -gradient_logp( - backend::ADBackend, - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler = SampleFromPrior(), - ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() -) - -Compute the value of the log joint of `θ` and its gradient for the model -specified by `(vi, sampler, model)` using `backend` for AD, e.g. `ForwardDiffAD{N}()` uses `ForwardDiff.jl` with chunk size `N`, `TrackerAD()` uses `Tracker.jl` and `ZygoteAD()` uses `Zygote.jl`. -""" -function gradient_logp( - ad::ForwardDiffAD, - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() -) - # Define log density function. - f = Turing.LogDensityFunction(vi, model, sampler, context) +function LogDensityProblems.ADgradient(ad::ForwardDiffAD, ℓ::Turing.LogDensityFunction) + θ = ℓ.varinfo[ℓ.sampler] + f = Base.Fix1(LogDensityProblems.logdensity, ℓ) # Define configuration for ForwardDiff. tag = if standardtag(ad) @@ -118,58 +92,30 @@ function gradient_logp( else ForwardDiff.Tag(f, eltype(θ)) end - chunk_size = getchunksize(typeof(ad)) + chunk_size = getchunksize(ad) config = if chunk_size == 0 ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(θ), tag) else ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size), tag) end - # Obtain both value and gradient of the log density function. - out = DiffResults.GradientResult(θ) - ForwardDiff.gradient!(out, f, θ, config) - logp = DiffResults.value(out) - ∂logp∂θ = DiffResults.gradient(out) + return LogDensityProblems.ADgradient(Val(:ForwardDiff), ℓ; gradientconfig=config) +end - return logp, ∂logp∂θ +function LogDensityProblems.ADgradient(::TrackerAD, ℓ::Turing.LogDensityFunction) + return LogDensityProblems.ADgradient(Val(:Tracker), ℓ) end -function gradient_logp( - ::TrackerAD, - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler = SampleFromPrior(), - context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() -) - # Define log density function. - f = Turing.LogDensityFunction(vi, model, sampler, context) - - # Compute forward pass and pullback. - l_tracked, ȳ = Tracker.forward(f, θ) - - # Remove tracking info. - l::typeof(getlogp(vi)) = Tracker.data(l_tracked) - ∂l∂θ::typeof(θ) = Tracker.data(only(ȳ(1))) - - return l, ∂l∂θ + +function LogDensityProblems.ADgradient(::ZygoteAD, ℓ::Turing.LogDensityFunction) + return LogDensityProblems.ADgradient(Val(:Zygote), ℓ) end -function gradient_logp( - backend::ZygoteAD, - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler = SampleFromPrior(), - context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() -) - # Define log density function. - f = Turing.LogDensityFunction(vi, model, sampler, context) - - # Compute forward pass and pullback. 
- l::typeof(getlogp(vi)), ȳ = ZygoteRules.pullback(f, θ) - ∂l∂θ::typeof(θ) = only(ȳ(1)) - - return l, ∂l∂θ +for cache in (:true, :false) + @eval begin + function LogDensityProblems.ADgradient(::ReverseDiffAD{$cache}, ℓ::Turing.LogDensityFunction) + return LogDensityProblems.ADgradient(Val(:ReverseDiff), ℓ; compile=Val($cache)) + end + end end function verifygrad(grad::AbstractVector{<:Real}) diff --git a/src/essential/compat/reversediff.jl b/src/essential/compat/reversediff.jl deleted file mode 100644 index cc077c5e0..000000000 --- a/src/essential/compat/reversediff.jl +++ /dev/null @@ -1,80 +0,0 @@ -using .ReverseDiff: compile, GradientTape - -struct ReverseDiffAD{cache} <: ADBackend end -const RDCache = Ref(false) -setrdcache(b::Bool) = setrdcache(Val(b)) -setrdcache(::Val{false}) = RDCache[] = false -setrdcache(::Val) = throw("Memoization.jl is not loaded. Please load it before setting the cache to true.") -function emptyrdcache end - -getrdcache() = RDCache[] -ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()} -function _setadbackend(::Val{:reversediff}) - ADBACKEND[] = :reversediff -end - -function gradient_logp( - backend::ReverseDiffAD{false}, - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler = SampleFromPrior(), - context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() -) - # Define log density function. - f = Turing.LogDensityFunction(vi, model, sampler, context) - - # Obtain both value and gradient of the log density function. - tp, result = taperesult(f, θ) - ReverseDiff.gradient!(result, tp, θ) - logp = DiffResults.value(result) - ∂logp∂θ = DiffResults.gradient(result) - - return logp, ∂logp∂θ -end - -tape(f, x) = GradientTape(f, x) -taperesult(f, x) = (tape(f, x), DiffResults.GradientResult(x)) - -@require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin - setrdcache(::Val{true}) = RDCache[] = true - function emptyrdcache() - Memoization.empty_cache!(memoized_taperesult) - return - end - - function gradient_logp( - backend::ReverseDiffAD{true}, - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler = SampleFromPrior(), - context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() - ) - # Define log density function. - f = Turing.LogDensityFunction(vi, model, sampler, context) - - # Obtain both value and gradient of the log density function. 
- ctp, result = memoized_taperesult(f, θ) - ReverseDiff.gradient!(result, ctp, θ) - logp = DiffResults.value(result) - ∂logp∂θ = DiffResults.gradient(result) - - return logp, ∂logp∂θ - end - - # This makes sure we generate a single tape per Turing model and sampler - struct RDTapeKey{F, Tx} - f::F - x::Tx - end - function Memoization._get!(f, d::Dict, keys::Tuple{Tuple{RDTapeKey}, Any}) - key = keys[1][1] - return Memoization._get!(f, d, (key.f, typeof(key.x), size(key.x), Threads.threadid())) - end - memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x)) - Memoization.@memoize Dict function memoized_taperesult(k::RDTapeKey) - return compiledtape(k.f, k.x), DiffResults.GradientResult(k.x) - end - compiledtape(f, x) = compile(GradientTape(f, x)) -end diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 8dfe8306a..2024cff2c 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -28,8 +28,9 @@ import AdvancedHMC; const AHMC = AdvancedHMC import AdvancedMH; const AMH = AdvancedMH import AdvancedPS import BangBang -import ..Essential: getchunksize, getADbackend +import ..Essential: getADbackend import EllipticalSliceSampling +import LogDensityProblems import Random import MCMCChains import StatsBase: predict @@ -76,7 +77,6 @@ abstract type Hamiltonian{AD} <: InferenceAlgorithm end abstract type StaticHamiltonian{AD} <: Hamiltonian{AD} end abstract type AdaptiveHamiltonian{AD} <: Hamiltonian{AD} end -getchunksize(::Type{<:Hamiltonian{AD}}) where AD = getchunksize(AD) getADbackend(::Hamiltonian{AD}) where AD = AD() # Algorithm for sampling from the prior diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 34274a32e..c9583754d 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -159,8 +159,11 @@ function DynamicPPL.initialstep( # Create a Hamiltonian. metricT = getmetricT(spl.alg) metric = metricT(length(theta)) - ∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model) - logπ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()) + ℓ = LogDensityProblems.ADgradient( + Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()) + ) + logπ = Base.Fix1(LogDensityProblems.logdensity, ℓ) + ∂logπ∂θ(x) = LogDensityProblems.logdensity_and_gradient(ℓ, x) hamiltonian = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ) # Compute phase point z. @@ -262,8 +265,11 @@ end function get_hamiltonian(model, spl, vi, state, n) metric = gen_metric(n, spl, state) - ℓπ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()) - ∂ℓπ∂θ = gen_∂logπ∂θ(vi, spl, model) + ℓ = LogDensityProblems.ADgradient( + Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()) + ) + ℓπ = Base.Fix1(LogDensityProblems.logdensity, ℓ) + ∂ℓπ∂θ = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ℓ) return AHMC.Hamiltonian(metric, ℓπ, ∂ℓπ∂θ) end @@ -422,19 +428,6 @@ end getstepsize(sampler::Sampler{<:Hamiltonian}, state) = sampler.alg.ϵ getstepsize(sampler::Sampler{<:AdaptiveHamiltonian}, state) = AHMC.getϵ(state.adaptor) -""" - gen_∂logπ∂θ(vi, spl::Sampler, model) - -Generate a function that takes a vector of reals `θ` and compute the logpdf and -gradient at `θ` for the model specified by `(vi, spl, model)`. 
-""" -function gen_∂logπ∂θ(vi, spl::Sampler, model) - function ∂logπ∂θ(x) - return gradient_logp(x, vi, model, spl) - end - return ∂logπ∂θ -end - gen_metric(dim::Int, spl::Sampler{<:Hamiltonian}, state) = AHMC.UnitEuclideanMetric(dim) function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state) return AHMC.renew(state.hamiltonian.metric, AHMC.getM⁻¹(state.adaptor.pc)) diff --git a/src/modes/ModeEstimation.jl b/src/modes/ModeEstimation.jl index 7efb3be7e..8fb1faed5 100644 --- a/src/modes/ModeEstimation.jl +++ b/src/modes/ModeEstimation.jl @@ -107,14 +107,9 @@ end function (f::OptimLogDensity)(F, G, z) if G !== nothing # Calculate negative log joint and its gradient. - sampler = f.sampler - neglogp, ∇neglogp = Turing.gradient_logp( - z, - DynamicPPL.VarInfo(f.varinfo, sampler, z), - f.model, - sampler, - f.context, - ) + # TODO: Make OptimLogDensity already an LogDensityProblems.ADgradient? Allow to specify AD? + ℓ = LogDensityProblems.ADgradient(f) + neglogp, ∇neglogp = LogDensityProblems.logdensity_and_gradient(ℓ, z) # Save the gradient to the pre-allocated array. copyto!(G, ∇neglogp) diff --git a/test/Project.toml b/test/Project.toml index 450ad8b04..13eb9d3b3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,7 +13,6 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" @@ -45,7 +44,6 @@ DynamicPPL = "0.20" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12" MCMCChains = "5" -Memoization = "0.1.4" NamedArrays = "0.9.4" Optim = "0.22, 1.0" Optimization = "3.5" diff --git a/test/essential/ad.jl b/test/essential/ad.jl index c36006c94..dcae69cb8 100644 --- a/test/essential/ad.jl +++ b/test/essential/ad.jl @@ -27,11 +27,19 @@ _x = [_m, _s] grad_FWAD = sort(g(_x)) + ℓ = LogDensityFunction(vi, ad_test_f, SampleFromPrior(), DynamicPPL.DefaultContext()) x = map(x->Float64(x), vi[SampleFromPrior()]) - ∇E1 = gradient_logp(TrackerAD(), x, vi, ad_test_f)[2] + + trackerℓ = LogDensityProblems.ADgradient(TrackerAD(), ℓ) + @test trackerℓ isa LogDensityProblems.TrackerGradientLogDensity + @test trackerℓ.ℓ === ℓ + ∇E1 = LogDensityProblems.logdensity_and_gradient(trackerℓ, x)[2] @test sort(∇E1) ≈ grad_FWAD atol=1e-9 - ∇E2 = gradient_logp(ZygoteAD(), x, vi, ad_test_f)[2] + zygoteℓ = LogDensityProblems.ADgradient(ZygoteAD(), ℓ) + @test zygoteℓ isa LogDensityProblems.ZygoteGradientLogDensity + @test zygoteℓ.ℓ === ℓ + ∇E2 = LogDensityProblems.logdensity_and_gradient(zygoteℓ, x)[2] @test sort(∇E2) ≈ grad_FWAD atol=1e-9 end @turing_testset "general AD tests" begin @@ -71,19 +79,13 @@ Turing.setadbackend(:tracker) sample(dir(), HMC(0.01, 1), 1000); Turing.setadbackend(:zygote) - sample(dir(), HMC(0.01, 1), 1000); + sample(dir(), HMC(0.01, 1), 1000) Turing.setadbackend(:reversediff) Turing.setrdcache(false) - sample(dir(), HMC(0.01, 1), 1000); + sample(dir(), HMC(0.01, 1), 1000) Turing.setrdcache(true) - sample(dir(), HMC(0.01, 1), 1000); - caches = Memoization.find_caches(Turing.Essential.memoized_taperesult) - @test length(caches) == 1 - @test !isempty(first(values(caches))) - Turing.emptyrdcache() - caches = Memoization.find_caches(Turing.Essential.memoized_taperesult) - @test length(caches) == 1 - @test 
isempty(first(values(caches))) + sample(dir(), HMC(0.01, 1), 1000) + Turing.setrdcache(false) end # FIXME: For some reasons PDMatDistribution AD tests fail with ReverseDiff @testset "PDMatDistribution AD" begin @@ -163,7 +165,7 @@ @test mean(Array(chn[:sigma])) ≈ std(data) atol=0.5 end - Turing.emptyrdcache() + Turing.setrdcache(false) end @testset "chunksize" begin diff --git a/test/runtests.jl b/test/runtests.jl index 5e0c1f0e2..f5e3a9289 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,7 +7,6 @@ using DistributionsAD using FiniteDifferences using ForwardDiff using MCMCChains -using Memoization using NamedArrays using Optim using Optimization @@ -38,7 +37,7 @@ using MCMCChains: Chains using StatsFuns: binomlogpdf, logistic, logsumexp using TimerOutputs: TimerOutputs, @timeit using Turing: BinomialLogit, ForwardDiffAD, Sampler, SampleFromPrior, NUTS, TrackerAD, - Variational, ZygoteAD, getspace, gradient_logp + Variational, ZygoteAD, getspace using Turing.Essential: TuringDenseMvNormal, TuringDiagMvNormal using Turing.Variational: TruncatedADAGrad, DecayedADAGrad, AdvancedVI diff --git a/test/skipped/unit_test_helper.jl b/test/skipped/unit_test_helper.jl index 0a5c52789..9aa43e54e 100644 --- a/test/skipped/unit_test_helper.jl +++ b/test/skipped/unit_test_helper.jl @@ -8,9 +8,13 @@ function test_grad(turing_model, grad_f; trans=Dict()) end d = length(vi.vals) @testset "Gradient using random inputs" begin + ℓ = LogDensityProblems.ADgradient( + TrackerAD(), + LogDensityFunction(vi, model_f, SampleFromPrior(), DynamicPPL.DefaultContext()), + ) for _ = 1:10000 theta = rand(d) - @test Turing.gradient_logp(TrackerAD(), theta, vi, model_f) == grad_f(theta)[2] + @test LogDensityProblems.logdensity_and_gradient(ℓ, theta) == grad_f(theta)[2] end end end diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index 55fe375c4..74c79c0a2 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/test_utils/ad_utils.jl @@ -90,10 +90,14 @@ function test_model_ad(model, f, syms::Vector{Symbol}) # Call ForwardDiff's AD directly. grad_FWAD = sort(ForwardDiff.gradient(f, x)) - # Compare with `gradient_logp`. + # Compare with `logdensity_and_gradient`. 
z = vi[SampleFromPrior()] for chunksize in (0, 1, 10), standardtag in (true, false, 0, 3) - l, ∇E = gradient_logp(ForwardDiffAD{chunksize, standardtag}(), z, vi, model) + ℓ = LogDensityProblems.ADgradient( + ForwardDiffAD{chunksize, standardtag}(), + LogDensityFunction(vi, model, SampleFromPrior(), DynamicPPL.DefaultContext()), + ) + l, ∇E = LogDensityProblems.logdensity_and_gradient(ℓ, z) # Compare result @test l ≈ logp From f97631bab0baf7ea28c5411bb62933314af4595e Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 20 Aug 2022 11:19:47 +0200 Subject: [PATCH 2/8] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 95e806677..25dfc0dba 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.21.10" +version = "0.21.11" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -50,7 +50,7 @@ DynamicPPL = "0.20" EllipticalSliceSampling = "0.5, 1" ForwardDiff = "0.10.3" Libtask = "0.6.7, 0.7" -LogDensityProblems = "0.11" +LogDensityProblems = "0.12" MCMCChains = "5" NamedArrays = "0.9" Reexport = "0.2, 1" From b163bd42530fb3e48a4854e834ff90dfc65e5e12 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 20 Aug 2022 11:20:36 +0200 Subject: [PATCH 3/8] Update Project.toml --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 25dfc0dba..53be72a2a 100644 --- a/Project.toml +++ b/Project.toml @@ -32,7 +32,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "4" From 24b6b87b24f57400d2c240432217c325176d7576 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 20 Aug 2022 20:14:41 +0200 Subject: [PATCH 4/8] Import LogDensityProblems in submodule --- src/modes/ModeEstimation.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/modes/ModeEstimation.jl b/src/modes/ModeEstimation.jl index 8fb1faed5..184da87f1 100644 --- a/src/modes/ModeEstimation.jl +++ b/src/modes/ModeEstimation.jl @@ -10,6 +10,8 @@ using DynamicPPL: Model, AbstractContext, VarInfo, VarName, _getindex, getsym, getfield, setorder!, get_and_set_val!, istrans +import LogDensityProblems + export constrained_space, MAP, MLE, From a26f35f1f5d98a9f5790c2699fcb8697ba7aeb60 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 23 Aug 2022 16:10:29 +0200 Subject: [PATCH 5/8] Fix Gibbs sampling with DynamicHMC --- src/contrib/inference/dynamichmc.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/contrib/inference/dynamichmc.jl b/src/contrib/inference/dynamichmc.jl index 552fb6e98..fd7c5a6e3 100644 --- a/src/contrib/inference/dynamichmc.jl +++ b/src/contrib/inference/dynamichmc.jl @@ -38,13 +38,13 @@ end # Implement interface of `Gibbs` sampler function gibbs_state( - ::Model, + model::Model, spl::Sampler{<:DynamicNUTS}, state::DynamicNUTSState, varinfo::AbstractVarInfo, ) - # Update the previous evaluation. - ℓ = state.logdensity + # Update the log density function and its cached evaluation. 
+ ℓ = LogDensityProblems.ADgradient(Turing.LogDensityFunction(varinfo, model, spl, DynamicPPL.DefaultContext())) Q = DynamicHMC.evaluate_ℓ(ℓ, varinfo[spl]) return DynamicNUTSState(ℓ, varinfo, Q, state.metric, state.stepsize) end From 1925a15213a666cf3695dd5b1904e12c32fec2f2 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 23 Aug 2022 22:21:38 +0200 Subject: [PATCH 6/8] Update ad.jl --- test/essential/ad.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/essential/ad.jl b/test/essential/ad.jl index dcae69cb8..1620b50e4 100644 --- a/test/essential/ad.jl +++ b/test/essential/ad.jl @@ -27,7 +27,7 @@ _x = [_m, _s] grad_FWAD = sort(g(_x)) - ℓ = LogDensityFunction(vi, ad_test_f, SampleFromPrior(), DynamicPPL.DefaultContext()) + ℓ = Turing.LogDensityFunction(vi, ad_test_f, SampleFromPrior(), DynamicPPL.DefaultContext()) x = map(x->Float64(x), vi[SampleFromPrior()]) trackerℓ = LogDensityProblems.ADgradient(TrackerAD(), ℓ) From b9ac695040d43a53c01c586ecd627bf1f491520f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 24 Aug 2022 15:03:22 +0200 Subject: [PATCH 7/8] Add LogDensityProblems test dependency --- test/Project.toml | 2 ++ test/runtests.jl | 2 ++ 2 files changed, 4 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index 13eb9d3b3..dfaf97a50 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,6 +12,7 @@ DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673" Optim = "429524aa-4258-5aef-a3af-852621145aeb" @@ -43,6 +44,7 @@ DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.20" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12" +LogDensityProblems = "0.12" MCMCChains = "5" NamedArrays = "0.9.4" Optim = "0.22, 1.0" diff --git a/test/runtests.jl b/test/runtests.jl index f5e3a9289..46e8f2647 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,6 +41,8 @@ using Turing: BinomialLogit, ForwardDiffAD, Sampler, SampleFromPrior, NUTS, Trac using Turing.Essential: TuringDenseMvNormal, TuringDiagMvNormal using Turing.Variational: TruncatedADAGrad, DecayedADAGrad, AdvancedVI +import LogDensityProblems + setprogress!(false) include(pkgdir(Turing)*"/test/test_utils/AllUtils.jl") From 005b6e0b0213afc165d46de6d60f303bd894664d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 24 Aug 2022 17:03:05 +0200 Subject: [PATCH 8/8] Qualify `LogDensityFunction` --- test/skipped/unit_test_helper.jl | 2 +- test/test_utils/ad_utils.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/skipped/unit_test_helper.jl b/test/skipped/unit_test_helper.jl index 9aa43e54e..b66002ff6 100644 --- a/test/skipped/unit_test_helper.jl +++ b/test/skipped/unit_test_helper.jl @@ -10,7 +10,7 @@ function test_grad(turing_model, grad_f; trans=Dict()) @testset "Gradient using random inputs" begin ℓ = LogDensityProblems.ADgradient( TrackerAD(), - LogDensityFunction(vi, model_f, SampleFromPrior(), DynamicPPL.DefaultContext()), + Turing.LogDensityFunction(vi, model_f, SampleFromPrior(), DynamicPPL.DefaultContext()), ) for _ = 1:10000 theta = rand(d) diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index 74c79c0a2..7b323c749 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/test_utils/ad_utils.jl @@ -95,7 
+95,7 @@ function test_model_ad(model, f, syms::Vector{Symbol}) for chunksize in (0, 1, 10), standardtag in (true, false, 0, 3) ℓ = LogDensityProblems.ADgradient( ForwardDiffAD{chunksize, standardtag}(), - LogDensityFunction(vi, model, SampleFromPrior(), DynamicPPL.DefaultContext()), + Turing.LogDensityFunction(vi, model, SampleFromPrior(), DynamicPPL.DefaultContext()), ) l, ∇E = LogDensityProblems.logdensity_and_gradient(ℓ, z)
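
A minimal sketch of how the interface introduced by these patches is used in place of the removed `gradient_logp`. The toy `gdemo` model and its data are illustrative assumptions; the `Turing.LogDensityFunction` constructor and the `LogDensityProblems.logdensity`, `dimension`, `ADgradient`, and `logdensity_and_gradient` calls mirror the code added above in src/Turing.jl, src/essential/ad.jl, and test/essential/ad.jl.

    using Turing
    import LogDensityProblems

    # Toy model (an assumption, for illustration only).
    @model function gdemo(x)
        s² ~ InverseGamma(2, 3)
        m ~ Normal(0, sqrt(s²))
        x .~ Normal(m, sqrt(s²))
    end

    model = gdemo([1.5, 2.0])
    vi = DynamicPPL.VarInfo(model)
    spl = DynamicPPL.SampleFromPrior()

    # Wrap the model evaluation as a log-density problem.
    ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
    θ = vi[spl]
    LogDensityProblems.dimension(ℓ)      # number of sampled parameters
    LogDensityProblems.logdensity(ℓ, θ)  # log joint density at θ

    # Gradients: wrap the problem with an AD backend. Without an explicit
    # backend argument, it is taken from the sampler (here the global
    # default, ForwardDiff).
    ∇ℓ = LogDensityProblems.ADgradient(ℓ)
    logp, grad = LogDensityProblems.logdensity_and_gradient(∇ℓ, θ)

    # A backend can also be requested explicitly, e.g. Tracker:
    ∇ℓ_tracker = LogDensityProblems.ADgradient(Turing.Essential.TrackerAD(), ℓ)
    LogDensityProblems.logdensity_and_gradient(∇ℓ_tracker, θ)

    # ReverseDiff with a cached compiled tape no longer requires Memoization.jl;
    # load ReverseDiff and toggle the cache directly (see the autodiff docs change above):
    #   using ReverseDiff
    #   Turing.setadbackend(:reversediff)
    #   Turing.setrdcache(true)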