From 18cdd35caaccf03e53cd8143829f1c7767621019 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 13 Sep 2020 19:03:16 +0200 Subject: [PATCH] Make `setadbackend` more consistent + use ZygoteRules (#1401) --- Project.toml | 6 ++-- src/core/Core.jl | 9 +++--- src/core/ad.jl | 52 ++++++++++++++++++++++++---------- src/core/compat/reversediff.jl | 2 +- src/core/compat/zygote.jl | 31 -------------------- src/core/deprecations.jl | 9 ++++++ 6 files changed, 56 insertions(+), 53 deletions(-) delete mode 100644 src/core/compat/zygote.jl create mode 100644 src/core/deprecations.jl diff --git a/Project.toml b/Project.toml index e157430cc..541bd001c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.14.1" +version = "0.14.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -28,6 +28,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "1" @@ -38,7 +39,7 @@ Bijectors = "0.8" Distributions = "0.23.3" DistributionsAD = "0.6" DocStringExtensions = "0.8" -DynamicPPL = "0.9" +DynamicPPL = "0.9.1" EllipticalSliceSampling = "0.2, 0.3" ForwardDiff = "0.10.3" Libtask = "0.4" @@ -51,4 +52,5 @@ SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10" StatsBase = "0.32, 0.33" StatsFuns = "0.8, 0.9" Tracker = "0.2.3" +ZygoteRules = "0.2" julia = "1.3" diff --git a/src/core/Core.jl b/src/core/Core.jl index f3177de17..9ce136bf7 100644 --- a/src/core/Core.jl +++ b/src/core/Core.jl @@ -15,13 +15,13 @@ using StatsFuns: logsumexp, softmax @reexport using DynamicPPL using Requires +import ZygoteRules + include("container.jl") include("ad.jl") +include("deprecations.jl") + function __init__() - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("compat/zygote.jl") - export ZygoteAD - end @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin include("compat/reversediff.jl") export ReverseDiffAD, getrdcache, setrdcache, emptyrdcache @@ -51,6 +51,7 @@ export @model, setadsafe, ForwardDiffAD, TrackerAD, + ZygoteAD, value, gradient_logp, CHUNKSIZE, diff --git a/src/core/ad.jl b/src/core/ad.jl index 1e4f55068..dc68b867d 100644 --- a/src/core/ad.jl +++ b/src/core/ad.jl @@ -2,27 +2,23 @@ # Global variables/constants # ############################## const ADBACKEND = Ref(:forwarddiff) -function setadbackend(backend_sym::Symbol) - setadbackend(Val(backend_sym)) - AdvancedVI.setadbackend(Val(backend_sym)) - Bijectors.setadbackend(Val(backend_sym)) +setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym)) +function setadbackend(backend::Val) + _setadbackend(backend) + AdvancedVI.setadbackend(backend) + Bijectors.setadbackend(backend) end -function setadbackend(::Val{:forward_diff}) - Base.depwarn("`Turing.setadbackend(:forward_diff)` is deprecated. Please use `Turing.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend) - setadbackend(Val(:forwarddiff)) -end -function setadbackend(::Val{:forwarddiff}) + +function _setadbackend(::Val{:forwarddiff}) CHUNKSIZE[] == 0 && setchunksize(40) ADBACKEND[] = :forwarddiff end - -function setadbackend(::Val{:reverse_diff}) - Base.depwarn("`Turing.setadbackend(:reverse_diff)` is deprecated. Please use `Turing.setadbackend(:tracker)` to use `Tracker` or `Turing.setadbackend(:reversediff)` to use `ReverseDiff`. To use `ReverseDiff`, please make sure it is loaded separately with `using ReverseDiff`.", :setadbackend) - setadbackend(Val(:tracker)) -end -function setadbackend(::Val{:tracker}) +function _setadbackend(::Val{:tracker}) ADBACKEND[] = :tracker end +function _setadbackend(::Val{:zygote}) + ADBACKEND[] = :zygote +end const ADSAFE = Ref(false) function setadsafe(switch::Bool) @@ -46,12 +42,14 @@ getchunksize(::Type{<:Sampler{Talg}}) where Talg = getchunksize(Talg) getchunksize(::Type{SampleFromPrior}) = CHUNKSIZE[] struct TrackerAD <: ADBackend end +struct ZygoteAD <: ADBackend end ADBackend() = ADBackend(ADBACKEND[]) ADBackend(T::Symbol) = ADBackend(Val(T)) ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]} ADBackend(::Val{:tracker}) = TrackerAD +ADBackend(::Val{:zygote}) = ZygoteAD ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.") """ @@ -149,6 +147,30 @@ function gradient_logp( return l, ∂l∂θ end +function gradient_logp( + backend::ZygoteAD, + θ::AbstractVector{<:Real}, + vi::VarInfo, + model::Model, + sampler::AbstractSampler = SampleFromPrior(), + context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() +) + T = typeof(getlogp(vi)) + + # Specify objective function. + function f(θ) + new_vi = VarInfo(vi, sampler, θ) + model(new_vi, sampler) + return getlogp(new_vi) + end + + # Compute forward and reverse passes. + l::T, ȳ = ZygoteRules.pullback(f, θ) + ∂l∂θ::typeof(θ) = ȳ(1)[1] + + return l, ∂l∂θ +end + function verifygrad(grad::AbstractVector{<:Real}) if any(isnan, grad) || any(isinf, grad) @warn("Numerical error in gradients. Rejecting current proposal...") diff --git a/src/core/compat/reversediff.jl b/src/core/compat/reversediff.jl index 400663628..5e7666b0e 100644 --- a/src/core/compat/reversediff.jl +++ b/src/core/compat/reversediff.jl @@ -10,7 +10,7 @@ function emptyrdcache end getrdcache() = RDCache[] ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()} -function setadbackend(::Val{:reversediff}) +function _setadbackend(::Val{:reversediff}) ADBACKEND[] = :reversediff end diff --git a/src/core/compat/zygote.jl b/src/core/compat/zygote.jl deleted file mode 100644 index 1f1be4647..000000000 --- a/src/core/compat/zygote.jl +++ /dev/null @@ -1,31 +0,0 @@ -struct ZygoteAD <: ADBackend end -ADBackend(::Val{:zygote}) = ZygoteAD -function setadbackend(::Val{:zygote}) - ADBACKEND[] = :zygote -end - -function gradient_logp( - backend::ZygoteAD, - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler = SampleFromPrior(), - context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() -) - T = typeof(getlogp(vi)) - - # Specify objective function. - function f(θ) - new_vi = VarInfo(vi, sampler, θ) - model(new_vi, sampler) - return getlogp(new_vi) - end - - # Compute forward and reverse passes. - l::T, ȳ = Zygote.pullback(f, θ) - ∂l∂θ::typeof(θ) = ȳ(1)[1] - - return l, ∂l∂θ -end - -Zygote.@nograd DynamicPPL.updategid! diff --git a/src/core/deprecations.jl b/src/core/deprecations.jl new file mode 100644 index 000000000..444e05c40 --- /dev/null +++ b/src/core/deprecations.jl @@ -0,0 +1,9 @@ +function setadbackend(::Val{:forward_diff}) + Base.depwarn("`Turing.setadbackend(:forward_diff)` is deprecated. Please use `Turing.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend) + setadbackend(Val(:forwarddiff)) +end + +function setadbackend(::Val{:reverse_diff}) + Base.depwarn("`Turing.setadbackend(:reverse_diff)` is deprecated. Please use `Turing.setadbackend(:tracker)` to use `Tracker` or `Turing.setadbackend(:reversediff)` to use `ReverseDiff`. To use `ReverseDiff`, please make sure it is loaded separately with `using ReverseDiff`.", :setadbackend) + setadbackend(Val(:tracker)) +end