Replace internal AD backend types with ADTypes (#2047)
* Replace internal AD backend types with ADTypes

* Remove upstreamed functionality

* Update ADBackend code

* Formatting

* Update Project.toml

* Update src/essential/ad.jl

* Fix tests

* Switch ADType version to 0.1.5 for CI testing

* A few fixes

* Update Project.toml

* Make ad type a field

* Improve docs and tests

* small fix

* Fix test errors

---------

Co-authored-by: Xianda Sun <[email protected]>
devmotion and sunxd3 authored Nov 16, 2023
1 parent d4a7975 commit 6649f10
Showing 12 changed files with 175 additions and 169 deletions.
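The user-facing effect of the change: AD backends are now selected with ADTypes structs (AutoForwardDiff, AutoReverseDiff, and so on) instead of Turing's internal ForwardDiffAD/TrackerAD/ZygoteAD/ReverseDiffAD types. A minimal usage sketch, assuming the Hamiltonian samplers accept the same `adtype` keyword that the updated DynamicNUTS constructor below uses (those sampler constructors sit in files not shown in this excerpt):

```julia
using Turing, ADTypes

@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

# Pass an ADTypes object to the sampler (assumed `adtype` keyword, mirroring DynamicNUTS below).
chain = sample(demo(1.5), NUTS(; adtype=AutoForwardDiff()), 1000)

# The symbol-based global default still works and now resolves to an ADTypes object:
Turing.setadbackend(:forwarddiff)  # ADBackend() now returns AutoForwardDiff(; chunksize=CHUNKSIZE[])
```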
6 changes: 4 additions & 2 deletions Project.toml
@@ -1,8 +1,9 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.29.4"
+version = "0.30.0"

 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
@@ -44,6 +45,7 @@ TuringDynamicHMCExt = "DynamicHMC"
 TuringOptimExt = "Optim"

 [compat]
+ADTypes = "0.2"
 AbstractMCMC = "4, 5"
 AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
 AdvancedMH = "0.8"
@@ -61,7 +63,7 @@ EllipticalSliceSampling = "0.5, 1, 2"
 ForwardDiff = "0.10.3"
 Libtask = "0.7, 0.8"
 LogDensityProblems = "2"
-LogDensityProblemsAD = "1.4"
+LogDensityProblemsAD = "1.7.0"
 MCMCChains = "5, 6"
 NamedArrays = "0.9, 0.10"
 Optim = "1"
18 changes: 11 additions & 7 deletions ext/TuringDynamicHMCExt.jl
@@ -8,12 +8,12 @@ if isdefined(Base, :get_extension)
     import DynamicHMC
     using Turing
     using Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
-    using Turing.Inference: LogDensityProblemsAD, TYPEDFIELDS
+    using Turing.Inference: ADTypes, LogDensityProblemsAD, TYPEDFIELDS
 else
     import ..DynamicHMC
     using ..Turing
     using ..Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
-    using ..Turing.Inference: LogDensityProblemsAD, TYPEDFIELDS
+    using ..Turing.Inference: ADTypes, LogDensityProblemsAD, TYPEDFIELDS
 end

 """
@@ -26,14 +26,18 @@ To use it, make sure you have DynamicHMC package (version >= 2) loaded:
 using DynamicHMC
 ```
 """
-struct DynamicNUTS{AD,space,T<:DynamicHMC.NUTS} <: Turing.Inference.Hamiltonian{AD}
+struct DynamicNUTS{AD,space,T<:DynamicHMC.NUTS} <: Turing.Inference.Hamiltonian
     sampler::T
+    adtype::AD
 end

-DynamicNUTS(args...) = DynamicNUTS{Turing.ADBackend()}(args...)
-DynamicNUTS{AD}(spl::DynamicHMC.NUTS, space::Tuple) where AD = DynamicNUTS{AD, space, typeof(spl)}(spl)
-DynamicNUTS{AD}(spl::DynamicHMC.NUTS) where AD = DynamicNUTS{AD}(spl, ())
-DynamicNUTS{AD}() where AD = DynamicNUTS{AD}(DynamicHMC.NUTS())
+function DynamicNUTS(
+    spl::DynamicHMC.NUTS = DynamicHMC.NUTS(),
+    space::Tuple = ();
+    adtype::ADTypes.AbstractADType = Turing.ADBackend()
+)
+    return DynamicNUTS{typeof(adtype),space,typeof(spl)}(spl, adtype)
+end
 Turing.externalsampler(spl::DynamicHMC.NUTS) = DynamicNUTS(spl)

 DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space
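For reference, a usage sketch of the new constructor (the model is illustrative only): the AD backend is now passed as a value through the `adtype` keyword rather than baked into the type as `DynamicNUTS{ForwardDiffAD{chunk}}()`.

```julia
using Turing, DynamicHMC, ADTypes

@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
end

spl = DynamicNUTS()  # adtype defaults to Turing.ADBackend(), i.e. the global default backend

# Explicit backend, replacing the old DynamicNUTS{ForwardDiffAD{40}}() style:
spl_fd = DynamicNUTS(DynamicHMC.NUTS(); adtype=AutoForwardDiff(; chunksize=40))

chain = sample(gdemo(1.5), spl_fd, 1000)
```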
9 changes: 5 additions & 4 deletions src/essential/Essential.jl
@@ -11,6 +11,7 @@ using Bijectors: PDMatDistribution
 using AdvancedVI
 using StatsFuns: logsumexp, softmax
 @reexport using DynamicPPL
+using ADTypes: ADTypes, AutoForwardDiff, AutoTracker, AutoReverseDiff, AutoZygote

 import AdvancedPS
 import LogDensityProblems
@@ -40,10 +41,10 @@ export @model,
     ADBackend,
     setadbackend,
     setadsafe,
-    ForwardDiffAD,
-    TrackerAD,
-    ZygoteAD,
-    ReverseDiffAD,
+    AutoForwardDiff,
+    AutoTracker,
+    AutoZygote,
+    AutoReverseDiff,
     value,
     CHUNKSIZE,
     ADBACKEND,
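The renamed exports map onto the removed internal types roughly as follows; chunk size and ReverseDiff compilation move from type parameters to constructor keywords. An illustrative correspondence:

```julia
using ADTypes  # these names are also exported from Turing.Essential after this change

AutoForwardDiff(; chunksize=40)   # was ForwardDiffAD{40}
AutoTracker()                     # was TrackerAD
AutoZygote()                      # was ZygoteAD
AutoReverseDiff(; compile=true)   # was ReverseDiffAD{true}
```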
47 changes: 12 additions & 35 deletions src/essential/ad.jl
@@ -36,21 +36,10 @@ function setchunksize(chunk_size::Int)
     AdvancedVI.setchunksize(chunk_size)
 end

-abstract type ADBackend end
-struct ForwardDiffAD{chunk,standardtag} <: ADBackend end
+getchunksize(::AutoForwardDiff{chunk}) where {chunk} = chunk

-# Use standard tag if not specified otherwise
-ForwardDiffAD{N}() where {N} = ForwardDiffAD{N,true}()
-
-getchunksize(::ForwardDiffAD{chunk}) where chunk = chunk
-
-standardtag(::ForwardDiffAD{<:Any,true}) = true
-standardtag(::ForwardDiffAD) = false
-
-struct TrackerAD <: ADBackend end
-struct ZygoteAD <: ADBackend end
-
-struct ReverseDiffAD{cache} <: ADBackend end
+standardtag(::AutoForwardDiff{<:Any,Nothing}) = true
+standardtag(::AutoForwardDiff) = false

 const RDCache = Ref(false)

@@ -63,10 +52,10 @@ getrdcache() = RDCache[]
 ADBackend() = ADBackend(ADBACKEND[])
 ADBackend(T::Symbol) = ADBackend(Val(T))

-ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]}
-ADBackend(::Val{:tracker}) = TrackerAD
-ADBackend(::Val{:zygote}) = ZygoteAD
-ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()}
+ADBackend(::Val{:forwarddiff}) = AutoForwardDiff(; chunksize=CHUNKSIZE[])
+ADBackend(::Val{:tracker}) = AutoTracker()
+ADBackend(::Val{:zygote}) = AutoZygote()
+ADBackend(::Val{:reversediff}) = AutoReverseDiff(; compile=getrdcache())

 ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")

@@ -76,18 +65,18 @@ ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")
 Find the autodifferentiation backend of the algorithm `alg`.
 """
 getADbackend(spl::Sampler) = getADbackend(spl.alg)
-getADbackend(::SampleFromPrior) = ADBackend()()
+getADbackend(::SampleFromPrior) = ADBackend()
 getADbackend(ctx::DynamicPPL.SamplingContext) = getADbackend(ctx.sampler)
 getADbackend(ctx::DynamicPPL.AbstractContext) = getADbackend(DynamicPPL.NodeTrait(ctx), ctx)

-getADbackend(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = ADBackend()()
+getADbackend(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = ADBackend()
 getADbackend(::DynamicPPL.IsParent, ctx::DynamicPPL.AbstractContext) = getADbackend(DynamicPPL.childcontext(ctx))

 function LogDensityProblemsAD.ADgradient(ℓ::Turing.LogDensityFunction)
     return LogDensityProblemsAD.ADgradient(getADbackend(ℓ.context), ℓ)
 end

-function LogDensityProblemsAD.ADgradient(ad::ForwardDiffAD, ℓ::Turing.LogDensityFunction)
+function LogDensityProblemsAD.ADgradient(ad::AutoForwardDiff, ℓ::Turing.LogDensityFunction)
     θ = DynamicPPL.getparams(ℓ)
     f = Base.Fix1(LogDensityProblems.logdensity, ℓ)

@@ -107,20 +96,8 @@ function LogDensityProblemsAD.ADgradient(ad::ForwardDiffAD, ℓ::Turing.LogDensityFunction)
     return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x = θ)
 end

-function LogDensityProblemsAD.ADgradient(::TrackerAD, ℓ::Turing.LogDensityFunction)
-    return LogDensityProblemsAD.ADgradient(Val(:Tracker), ℓ)
-end
-
-function LogDensityProblemsAD.ADgradient(::ZygoteAD, ℓ::Turing.LogDensityFunction)
-    return LogDensityProblemsAD.ADgradient(Val(:Zygote), ℓ)
-end
-
-for cache in (:true, :false)
-    @eval begin
-        function LogDensityProblemsAD.ADgradient(::ReverseDiffAD{$cache}, ℓ::Turing.LogDensityFunction)
-            return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile=Val($cache), x=DynamicPPL.getparams(ℓ))
-        end
-    end
+function LogDensityProblemsAD.ADgradient(ad::AutoReverseDiff, ℓ::Turing.LogDensityFunction)
+    return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile=Val(ad.compile), x=DynamicPPL.getparams(ℓ))
 end

 function verifygrad(grad::AbstractVector{<:Real})
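In short, the symbol-keyed ADBackend factory now returns ADTypes structs, and the ADgradient methods dispatch on those structs instead of on Turing's own backend types. A small sketch of the factory behaviour described by the hunks above:

```julia
using Turing

Turing.setadbackend(:forwarddiff)
Turing.ADBackend()              # AutoForwardDiff(; chunksize=CHUNKSIZE[])
Turing.ADBackend(:zygote)       # AutoZygote()
Turing.ADBackend(:reversediff)  # AutoReverseDiff(; compile=getrdcache())

# LogDensityProblemsAD.ADgradient then dispatches on the concrete struct; for example,
# the AutoReverseDiff method forwards ad.compile instead of reading a type parameter.
```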
9 changes: 5 additions & 4 deletions src/mcmc/Inference.jl
@@ -23,6 +23,7 @@ using DocStringExtensions: TYPEDEF, TYPEDFIELDS
 using DataStructures: OrderedSet
 using Setfield: Setfield

+import ADTypes
 import AbstractMCMC
 import AdvancedHMC; const AHMC = AdvancedHMC
 import AdvancedMH; const AMH = AdvancedMH
@@ -74,10 +75,10 @@ export InferenceAlgorithm,
 abstract type AbstractAdapter end
 abstract type InferenceAlgorithm end
 abstract type ParticleInference <: InferenceAlgorithm end
-abstract type Hamiltonian{AD} <: InferenceAlgorithm end
-abstract type StaticHamiltonian{AD} <: Hamiltonian{AD} end
-abstract type AdaptiveHamiltonian{AD} <: Hamiltonian{AD} end
-getADbackend(::Hamiltonian{AD}) where AD = AD()
+abstract type Hamiltonian <: InferenceAlgorithm end
+abstract type StaticHamiltonian <: Hamiltonian end
+abstract type AdaptiveHamiltonian <: Hamiltonian end
+getADbackend(alg::Hamiltonian) = alg.adtype

 """
     ExternalSampler{S<:AbstractSampler}
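With the type parameter gone, getADbackend simply reads the new `adtype` field. A sketch, assuming HMC/NUTS gained the same `adtype` keyword as DynamicNUTS above (their constructors live in files not shown here):

```julia
using Turing, ADTypes

alg = NUTS(; adtype=AutoReverseDiff(; compile=false))  # assumed keyword, mirroring DynamicNUTS
Turing.Inference.getADbackend(alg)                     # returns the stored AutoReverseDiff instance
```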
(Diff for the remaining changed files is not shown in this excerpt.)
