Merge branch 'master' into dw/enzyme
yebai authored May 29, 2024
2 parents 0385250 + 3c6149f commit 24cc3a9
Showing 26 changed files with 1,204 additions and 155 deletions.
16 changes: 9 additions & 7 deletions Project.toml
@@ -1,10 +1,11 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.30.7"
version = "0.32.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc"
@@ -24,12 +25,12 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
@@ -45,32 +46,33 @@ TuringDynamicHMCExt = "DynamicHMC"
TuringOptimExt = "Optim"

[compat]
ADTypes = "0.2"
ADTypes = "0.2, 1"
AbstractMCMC = "5.2"
Accessors = "0.1"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
AdvancedMH = "0.8"
AdvancedPS = "0.5.4"
AdvancedPS = "0.6.0"
AdvancedVI = "0.2"
BangBang = "0.3"
BangBang = "0.4"
Bijectors = "0.13.6"
DataStructures = "0.18"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.24.7"
DynamicPPL = "0.27.1"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3"
Libtask = "0.7, 0.8"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
MCMCChains = "5, 6"
NamedArrays = "0.9, 0.10"
OrderedCollections = "1"
Optim = "1"
Reexport = "0.2, 1"
Requires = "0.5, 1.0"
SciMLBase = "1.37.1, 2"
Setfield = "0.8, 1"
SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2"
Statistics = "1.6"
StatsAPI = "1.6"
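For context on the version bump above, a minimal sketch of pinning this release with Pkg; the version string is taken from the Project.toml hunk, everything else is just standard Pkg usage:

```julia
# Minimal sketch: install/pin the release produced by this diff.
# The version string comes from the Project.toml hunk above.
using Pkg

Pkg.add(name="Turing", version="0.32.1")  # add Turing at the bumped version
Pkg.status("Turing")                      # confirm the resolved version
```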
2 changes: 1 addition & 1 deletion README.md
@@ -10,7 +10,7 @@

Turing's home page, with links to everything you'll need to use Turing, is available at:

https://turinglang.org/dev/docs/using-turing/get-started
https://turinglang.org/docs/


## What's changed recently?
2 changes: 1 addition & 1 deletion ext/TuringDynamicHMCExt.jl
@@ -34,7 +34,7 @@ end
function DynamicNUTS(
spl::DynamicHMC.NUTS = DynamicHMC.NUTS(),
space::Tuple = ();
adtype::ADTypes.AbstractADType = ADTypes.AutoForwardDiff(; chunksize=0)
adtype::ADTypes.AbstractADType = Turing.DEFAULT_ADTYPE
)
return DynamicNUTS{typeof(adtype),space,typeof(spl)}(spl, adtype)
end
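With this change, `DynamicNUTS()` falls back to `Turing.DEFAULT_ADTYPE` (plain `AutoForwardDiff()`) instead of `AutoForwardDiff(; chunksize=0)`. A minimal usage sketch; the toy model below is illustrative and not part of the commit:

```julia
# Sketch only: the demo model is illustrative, not from the commit.
using Turing, DynamicHMC

@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

sample(demo(1.5), DynamicNUTS(), 1_000)                                        # picks up Turing.DEFAULT_ADTYPE
sample(demo(1.5), DynamicNUTS(; adtype=AutoForwardDiff(; chunksize=4)), 1_000) # explicit override still works
```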
26 changes: 12 additions & 14 deletions ext/TuringOptimExt.jl
@@ -2,11 +2,11 @@ module TuringOptimExt

if isdefined(Base, :get_extension)
import Turing
import Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Setfield, Statistics, StatsAPI, StatsBase
import Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Accessors, Statistics, StatsAPI, StatsBase
import Optim
else
import ..Turing
import ..Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Setfield, Statistics, StatsAPI, StatsBase
import ..Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Accessors, Statistics, StatsAPI, StatsBase
import ..Optim
end

@@ -80,7 +80,7 @@ function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff
# Hessian is computed with respect to the untransformed parameters.
linked = DynamicPPL.istrans(m.f.varinfo)
if linked
Setfield.@set! m.f.varinfo = DynamicPPL.invlink!!(m.f.varinfo, m.f.model)
m = Accessors.@set m.f.varinfo = DynamicPPL.invlink!!(m.f.varinfo, m.f.model)
end

# Calculate the Hessian, which is the information matrix because the negative of the log likelihood was optimized
@@ -89,7 +89,7 @@

# Link it back if we invlinked it.
if linked
Setfield.@set! m.f.varinfo = DynamicPPL.link!!(m.f.varinfo, m.f.model)
m = Accessors.@set m.f.varinfo = DynamicPPL.link!!(m.f.varinfo, m.f.model)
end

return NamedArrays.NamedArray(info, (varnames, varnames))
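The hunks above replace `Setfield.@set!` with `Accessors.@set`. A self-contained sketch of the pattern; the `Wrapper` struct is illustrative, not Turing's type. The key difference is that `@set!` rebinds the variable for you, while `@set` only returns the updated copy, which is why the new code assigns the result back to `m`/`f` explicitly.

```julia
# Illustrative only: a stand-in struct, not Turing's ModeResult / LogDensityFunction.
using Accessors

struct Wrapper
    varinfo::Vector{Float64}
end

w = Wrapper([1.0, 2.0])

# Before (Setfield):  Setfield.@set! w.varinfo = [3.0, 4.0]   # rebinds `w` itself
# After (Accessors): `@set` returns a new value, so it is assigned explicitly:
w = Accessors.@set w.varinfo = [3.0, 4.0]
```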
@@ -227,8 +227,8 @@ function _optimize(
)
# Convert the initial values, since it is assumed that users provide them
# in the constrained space.
Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
Setfield.@set! f.varinfo = DynamicPPL.link!!(f.varinfo, model)
f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
f = Accessors.@set f.varinfo = DynamicPPL.link(f.varinfo, model)
init_vals = DynamicPPL.getparams(f)

# Optimize!
@@ -241,17 +241,15 @@

# Get the VarInfo at the MLE/MAP point, and run the model to ensure
# correct dimensionality.
Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
Setfield.@set! f.varinfo = DynamicPPL.invlink!!(f.varinfo, model)
f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
f = Accessors.@set f.varinfo = DynamicPPL.invlink(f.varinfo, model)
vals = DynamicPPL.getparams(f)
Setfield.@set! f.varinfo = DynamicPPL.link!!(f.varinfo, model)
f = Accessors.@set f.varinfo = DynamicPPL.link(f.varinfo, model)

# Make one transition to get the parameter names.
ts = [Turing.Inference.Transition(
Turing.Inference.getparams(model, f.varinfo),
DynamicPPL.getlogp(f.varinfo)
)]
varnames = map(Symbol, first(Turing.Inference._params_to_array(model, ts)))
vns_vals_iter = Turing.Inference.getparams(model, f.varinfo)
varnames = map(Symbol ∘ first, vns_vals_iter)
vals = map(last, vns_vals_iter)

# Store the parameters and their names in an array.
vmat = NamedArrays.NamedArray(vals, varnames)
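The rewritten block above derives parameter names and values from a single name–value iterator instead of building a throwaway `Transition`. A generic sketch of that split; the iterator contents are made up, while in the extension they come from `Turing.Inference.getparams(model, f.varinfo)`:

```julia
# Illustrative stand-in for the name–value pairs returned by getparams.
vns_vals_iter = [(:m, 0.5), (:s, 1.2)]

varnames = map(Symbol ∘ first, vns_vals_iter)  # [:m, :s]
vals = map(last, vns_vals_iter)                # [0.5, 1.2]
```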
25 changes: 22 additions & 3 deletions src/Turing.jl
@@ -11,13 +11,19 @@ using DynamicPPL: DynamicPPL, LogDensityFunction
import DynamicPPL: getspace, NoDist, NamedDist
import LogDensityProblems
import NamedArrays
import Setfield
import Accessors
import StatsAPI
import StatsBase

using Accessors: Accessors

import Printf
import Random

using ADTypes: ADTypes

const DEFAULT_ADTYPE = ADTypes.AutoForwardDiff()

const PROGRESS = Ref(true)

# TODO: remove `PROGRESS` and this function in favour of `AbstractMCMC.PROGRESS`
@@ -48,14 +54,17 @@ using .Variational
include("optimisation/Optimisation.jl")
using .Optimisation

include("experimental/Experimental.jl")
include("deprecated.jl") # to be removed in the next minor version release

###########
# Exports #
###########
# `using` statements for stuff to re-export
using DynamicPPL: pointwise_loglikelihoods, generated_quantities, logprior, logjoint
using DynamicPPL: pointwise_loglikelihoods, generated_quantities, logprior, logjoint, condition, decondition, fix, unfix, conditioned
using StatsBase: predict
using Bijectors: ordered
using OrderedCollections: OrderedDict

# Turing essentials - modelling macros and inference algorithms
export @model, # modelling
@@ -98,6 +107,7 @@ export @model, # modelling
AutoReverseDiff,
AutoZygote,
AutoTracker,
AutoTapir,

setprogress!, # debugging

@@ -107,10 +117,10 @@ export @model, # modelling
BernoulliLogit, # Part of Distributions >= 0.25.77
OrderedLogistic,
LogPoisson,
NamedDist,
filldist,
arraydist,

NamedDist, # Exports from DynamicPPL
predict,
pointwise_loglikelihoods,
elementwise_loglikelihoods,
@@ -119,6 +129,15 @@
logjoint,
LogDensityFunction,

condition,
decondition,
fix,
unfix,
conditioned,
OrderedDict,

ordered, # Exports from Bijectors

constrained_space, # optimisation interface
MAP,
MLE,
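Several DynamicPPL and Bijectors names are now re-exported directly from Turing (`condition`, `decondition`, `fix`, `unfix`, `conditioned`, `ordered`, plus `OrderedDict`). A minimal sketch of using them through Turing; the toy model and values are illustrative, not from the commit:

```julia
# Sketch only: demo model and values are illustrative.
using Turing
using LinearAlgebra: I

@model function demo()
    s ~ ordered(MvNormal(zeros(2), I))  # `ordered` re-exported from Bijectors
    m ~ Normal(0, 1)
end

model = demo()
cond_model = condition(model, (; m=0.5))   # condition on m (re-exported from DynamicPPL)
conditioned(cond_model)                    # NamedTuple of the conditioned values
plain_model = decondition(cond_model)      # undo the conditioning
fixed_model = fix(model, (; m=0.0))        # fix m to a constant (not treated as an observation)
unfixed_model = unfix(fixed_model)
```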
4 changes: 2 additions & 2 deletions src/deprecated.jl
@@ -1,12 +1,12 @@
export setadbackend, setchunksize, setadsafe

function setadbackend(::Union{Symbol, Val})
Base.depwarn("`ADBACKEND` and `setbackend` are deprecated. Please specify the chunk size directly in the sampler constructor, e.g., `HMC(0.1, 5; adtype=AutoForwardDiff(; chunksize=0))`.\n This function has no effects.", :setbackend; force=true)
Base.depwarn("`ADBACKEND` and `setbackend` are deprecated. Please specify the chunk size directly in the sampler constructor, e.g., `HMC(0.1, 5; adtype=AutoForwardDiff())`.\n This function has no effects.", :setbackend; force=true)
nothing
end

function setchunksize(::Int)
Base.depwarn("`CHUNKSIZE` and `setchunksize` are deprecated. Please specify the chunk size directly in the sampler constructor, e.g., `HMC(0.1, 5; adtype=AutoForwardDiff(; chunksize=0))`.\n This function has no effects.", :setchunksize; force=true)
Base.depwarn("`CHUNKSIZE` and `setchunksize` are deprecated. Please specify the chunk size directly in the sampler constructor, e.g., `HMC(0.1, 5; adtype=AutoForwardDiff())`.\n This function has no effects.", :setchunksize; force=true)
nothing
end
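The updated messages point users at per-sampler `adtype` keywords instead of the removed global setters. A short sketch of the replacements; the step sizes and leapfrog counts are illustrative:

```julia
# Sketch only: sampler settings are illustrative.
using Turing

# Previously: setadbackend(:forwarddiff); setchunksize(8)
# Now the AD configuration is passed per sampler via `adtype`:
HMC(0.1, 5; adtype=AutoForwardDiff())               # default ForwardDiff, automatic chunking
HMC(0.1, 5; adtype=AutoForwardDiff(; chunksize=8))  # explicit chunk size, if still wanted
NUTS(; adtype=AutoReverseDiff())                    # different backend (load ReverseDiff before sampling)
```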

16 changes: 16 additions & 0 deletions src/experimental/Experimental.jl
@@ -0,0 +1,16 @@
module Experimental

using Random: Random
using AbstractMCMC: AbstractMCMC
using DynamicPPL: DynamicPPL, VarName
using Accessors: Accessors

using DocStringExtensions: TYPEDFIELDS
using Distributions

using ..Turing: Turing
using ..Turing.Inference: gibbs_rerun, InferenceAlgorithm

include("gibbs.jl")

end
(Diffs for the remaining changed files are not shown.)
