From 33487da76d9874adb7bee1b0509d0a3172580c9a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 10 Jan 2023 13:51:27 +0100 Subject: [PATCH] Support log density functions as models (#113) * Update sample.jl * Update sample.jl * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update api.md * Update stepper.jl * Update transducer.jl * Update api.md * Update src/stepper.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/transducer.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Project.toml * Update src/sample.jl Co-authored-by: Tor Erlend Fjelde * Reorganize fallbacks * Add tests * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Project.toml * Define utilities on all workers * Update test/sample.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Tor Erlend Fjelde --- Project.toml | 2 +- docs/src/api.md | 26 ++++++++++- src/logdensityproblems.jl | 92 ++++++++++++++++++++++++++++++++++++++ src/sample.jl | 53 ++++++++++++---------- src/stepper.jl | 11 +++-- src/transducer.jl | 11 +++-- test/logdensityproblems.jl | 90 +++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 + test/sample.jl | 11 ++++- test/utils.jl | 32 +++++++++++++ 10 files changed, 296 insertions(+), 34 deletions(-) create mode 100644 test/logdensityproblems.jl diff --git a/Project.toml b/Project.toml index 2af00074..a6bf3e65 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC 
methods." -version = "4.3.0" +version = "4.4.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/docs/src/api.md b/docs/src/api.md index 9ce28805..52c2c2e1 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -2,23 +2,39 @@ AbstractMCMC defines an interface for sampling Markov chains. +## Model + +```@docs +AbstractMCMC.AbstractModel +AbstractMCMC.LogDensityModel +``` + +## Sampler + +```@docs +AbstractMCMC.AbstractSampler +``` + ## Sampling a single chain ```@docs -AbstractMCMC.sample(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler, ::Integer) AbstractMCMC.sample(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler, ::Any) +AbstractMCMC.sample(::AbstractRNG, ::Any, ::AbstractMCMC.AbstractSampler, ::Any) + ``` ### Iterator ```@docs AbstractMCMC.steps(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler) +AbstractMCMC.steps(::AbstractRNG, ::Any, ::AbstractMCMC.AbstractSampler) ``` ### Transducer ```@docs AbstractMCMC.Sample(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler) +AbstractMCMC.Sample(::AbstractRNG, ::Any, ::AbstractMCMC.AbstractSampler) ``` ## Sampling multiple chains in parallel @@ -32,6 +48,14 @@ AbstractMCMC.sample( ::Integer, ::Integer, ) +AbstractMCMC.sample( + ::AbstractRNG, + ::Any, + ::AbstractMCMC.AbstractSampler, + ::AbstractMCMC.AbstractMCMCEnsemble, + ::Integer, + ::Integer, +) ``` Two algorithms are provided for parallel sampling with multiple threads and multiple processes, and one allows for the user to sample multiple chains in serial (no parallelization): diff --git a/src/logdensityproblems.jl b/src/logdensityproblems.jl index 54db36bb..f15f656a 100644 --- a/src/logdensityproblems.jl +++ b/src/logdensityproblems.jl @@ -25,3 +25,95 @@ struct LogDensityModel{L} <: AbstractModel end LogDensityModel(logdensity::L) where {L} = LogDensityModel{L}(logdensity) + +# Fallbacks: Wrap log density function in a model +""" + sample( 
+ rng::Random.AbstractRNG=Random.default_rng(), + logdensity, + sampler::AbstractSampler, + N_or_isdone; + kwargs..., + ) + +Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `sample` with the resulting model instead of `logdensity`. + +The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface. +""" +function StatsBase.sample( + rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler, N_or_isdone; kwargs... +) + return StatsBase.sample(rng, _model(logdensity), sampler, N_or_isdone; kwargs...) +end + +""" + sample( + rng::Random.AbstractRNG=Random.default_rng(), + logdensity, + sampler::AbstractSampler, + parallel::AbstractMCMCEnsemble, + N::Integer, + nchains::Integer; + kwargs..., + ) + +Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `sample` with the resulting model instead of `logdensity`. + +The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface. +""" +function StatsBase.sample( + rng::Random.AbstractRNG, + logdensity, + sampler::AbstractSampler, + parallel::AbstractMCMCEnsemble, + N::Integer, + nchains::Integer; + kwargs..., +) + return StatsBase.sample( + rng, _model(logdensity), sampler, parallel, N, nchains; kwargs... + ) +end + +""" + steps( + rng::Random.AbstractRNG=Random.default_rng(), + logdensity, + sampler::AbstractSampler; + kwargs..., + ) + +Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `steps` with the resulting model instead of `logdensity`. + +The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface. +""" +function steps(rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler; kwargs...) + return steps(rng, _model(logdensity), sampler; kwargs...) 
+end + +""" + Sample( + rng::Random.AbstractRNG=Random.default_rng(), + logdensity, + sampler::AbstractSampler; + kwargs..., + ) + +Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `Sample` with the resulting model instead of `logdensity`. + +The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface. +""" +function Sample(rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler; kwargs...) + return Sample(rng, _model(logdensity), sampler; kwargs...) +end + +function _model(logdensity) + if LogDensityProblems.capabilities(logdensity) === nothing + throw( + ArgumentError( + "the log density function does not support the LogDensityProblems.jl interface. Please implement the interface or provide a model of type `AbstractMCMC.AbstractModel`", + ), + ) + end + return LogDensityModel(logdensity) +end diff --git a/src/sample.jl b/src/sample.jl index c6b0112f..dc951ca2 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -12,32 +12,29 @@ function setprogress!(progress::Bool) return progress end -function StatsBase.sample(model::AbstractModel, sampler::AbstractSampler, arg; kwargs...) - return StatsBase.sample(Random.default_rng(), model, sampler, arg; kwargs...) -end - -""" - sample([rng, ]model, sampler, N; kwargs...) - -Return `N` samples from the `model` with the Markov chain Monte Carlo `sampler`. -""" function StatsBase.sample( - rng::Random.AbstractRNG, - model::AbstractModel, - sampler::AbstractSampler, - N::Integer; - kwargs..., + model_or_logdensity, sampler::AbstractSampler, N_or_isdone; kwargs... ) - return mcmcsample(rng, model, sampler, N; kwargs...) + return StatsBase.sample( + Random.default_rng(), model_or_logdensity, sampler, N_or_isdone; kwargs... + ) end """ - sample([rng, ]model, sampler, isdone; kwargs...) 
+ sample( + rng::Random.AbstractRNG=Random.default_rng(), + model::AbstractModel, + sampler::AbstractSampler, + N_or_isdone; + kwargs..., + ) + +Sample from the `model` with the Markov chain Monte Carlo `sampler` and return the samples. -Sample from the `model` with the Markov chain Monte Carlo `sampler` until a -convergence criterion `isdone` returns `true`, and return the samples. +If `N_or_isdone` is an `Integer`, exactly `N_or_isdone` samples are returned. -The function `isdone` has the signature +Otherwise, sampling is performed until a convergence criterion `N_or_isdone` returns `true`. +The convergence criterion has to be a function with the signature ```julia isdone(rng, model, sampler, samples, state, iteration; kwargs...) ``` @@ -48,14 +45,14 @@ function StatsBase.sample( rng::Random.AbstractRNG, model::AbstractModel, sampler::AbstractSampler, - isdone; + N_or_isdone; kwargs..., ) - return mcmcsample(rng, model, sampler, isdone; kwargs...) + return mcmcsample(rng, model, sampler, N_or_isdone; kwargs...) end function StatsBase.sample( - model::AbstractModel, + model_or_logdensity, sampler::AbstractSampler, parallel::AbstractMCMCEnsemble, N::Integer, @@ -63,12 +60,20 @@ function StatsBase.sample( kwargs..., ) return StatsBase.sample( - Random.default_rng(), model, sampler, parallel, N, nchains; kwargs... + Random.default_rng(), model_or_logdensity, sampler, parallel, N, nchains; kwargs... ) end """ - sample([rng, ]model, sampler, parallel, N, nchains; kwargs...) + sample( + rng::Random.AbstractRNG=Random.default_rng(), + model::AbstractModel, + sampler::AbstractSampler, + parallel::AbstractMCMCEnsemble, + N::Integer, + nchains::Integer; + kwargs..., + ) Sample `nchains` Monte Carlo Markov chains from the `model` with the `sampler` in parallel using the `parallel` algorithm, and combine them into a single chain. 
diff --git a/src/stepper.jl b/src/stepper.jl index 68059926..a71826cb 100644 --- a/src/stepper.jl +++ b/src/stepper.jl @@ -41,12 +41,17 @@ end Base.IteratorSize(::Type{<:Stepper}) = Base.IsInfinite() Base.IteratorEltype(::Type{<:Stepper}) = Base.EltypeUnknown() -function steps(model::AbstractModel, sampler::AbstractSampler; kwargs...) - return steps(Random.default_rng(), model, sampler; kwargs...) +function steps(model_or_logdensity, sampler::AbstractSampler; kwargs...) + return steps(Random.default_rng(), model_or_logdensity, sampler; kwargs...) end """ - steps([rng, ]model, sampler; kwargs...) + steps( + rng::Random.AbstractRNG=Random.default_rng(), + model::AbstractModel, + sampler::AbstractSampler; + kwargs..., + ) Create an iterator that returns samples from the `model` with the Markov chain Monte Carlo `sampler`. diff --git a/src/transducer.jl b/src/transducer.jl index 46d36d91..63bff3fd 100644 --- a/src/transducer.jl +++ b/src/transducer.jl @@ -6,12 +6,17 @@ struct Sample{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} <: kwargs::K end -function Sample(model::AbstractModel, sampler::AbstractSampler; kwargs...) - return Sample(Random.default_rng(), model, sampler; kwargs...) +function Sample(model_or_logdensity, sampler::AbstractSampler; kwargs...) + return Sample(Random.default_rng(), model_or_logdensity, sampler; kwargs...) end """ - Sample([rng, ]model, sampler; kwargs...) + Sample( + rng::Random.AbstractRNG=Random.default_rng(), + model::AbstractModel, + sampler::AbstractSampler; + kwargs..., + ) Create a transducer that returns samples from the `model` with the Markov chain Monte Carlo `sampler`. diff --git a/test/logdensityproblems.jl b/test/logdensityproblems.jl new file mode 100644 index 00000000..181d2645 --- /dev/null +++ b/test/logdensityproblems.jl @@ -0,0 +1,90 @@ +@testset "logdensityproblems.jl" begin + # Add worker processes. 
+ # Memory requirements on Windows are ~4x larger than on Linux, hence number of processes is reduced + # See, e.g., https://github.com/JuliaLang/julia/issues/40766 and https://github.com/JuliaLang/Pkg.jl/pull/2366 + pids = addprocs(Sys.iswindows() ? div(Sys.CPU_THREADS::Int, 2) : Sys.CPU_THREADS::Int) + + # Load all required packages (`utils.jl` needs LogDensityProblems, Logging, and Random). + @everywhere begin + using AbstractMCMC + using AbstractMCMC: sample + using LogDensityProblems + + using Logging + using Random + include("utils.jl") + end + + @testset "LogDensityModel" begin + ℓ = MyLogDensity(10) + model = @inferred AbstractMCMC.LogDensityModel(ℓ) + @test model isa AbstractMCMC.LogDensityModel{MyLogDensity} + @test model.logdensity === ℓ + + @test_throws ArgumentError AbstractMCMC.LogDensityModel(mylogdensity) + end + + @testset "fallback for log densities" begin + # Sample with log density + dim = 10 + ℓ = MyLogDensity(dim) + Random.seed!(1234) + N = 1_000 + samples = sample(ℓ, MySampler(), N) + + # Samples are of the correct dimension and log density values are correct + @test length(samples) == N + @test all(length(x.a) == dim for x in samples) + @test all(x.b ≈ LogDensityProblems.logdensity(ℓ, x.a) for x in samples) + + # Same chain as if LogDensityModel is used explicitly + Random.seed!(1234) + samples2 = sample(AbstractMCMC.LogDensityModel(ℓ), MySampler(), N) + @test length(samples2) == N + @test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples2)) + + # Same chain if sampling is performed with convergence criterion + Random.seed!(1234) + isdone(rng, model, sampler, state, samples, iteration; kwargs...) 
= iteration > N + samples3 = sample(ℓ, MySampler(), isdone) + @test length(samples3) == N + @test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples3)) + + # Same chain if sampling is performed with iterator + Random.seed!(1234) + samples4 = collect(Iterators.take(AbstractMCMC.steps(ℓ, MySampler()), N)) + @test length(samples4) == N + @test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples4)) + + # Same chain if sampling is performed with transducer + Random.seed!(1234) + xf = AbstractMCMC.Sample(ℓ, MySampler()) + samples5 = collect(xf(1:N)) + @test length(samples5) == N + @test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples5)) + + # Parallel sampling + for alg in (MCMCSerial(), MCMCDistributed(), MCMCThreads()) + chains = sample(ℓ, MySampler(), alg, N, 2) + @test length(chains) == 2 + samples = vcat(chains[1], chains[2]) + @test length(samples) == 2 * N + @test all(length(x.a) == dim for x in samples) + @test all(x.b ≈ LogDensityProblems.logdensity(ℓ, x.a) for x in samples) + end + + # Log density has to satisfy the LogDensityProblems interface + @test_throws ArgumentError sample(mylogdensity, MySampler(), N) + @test_throws ArgumentError sample(mylogdensity, MySampler(), isdone) + @test_throws ArgumentError sample(mylogdensity, MySampler(), MCMCSerial(), N, 2) + @test_throws ArgumentError sample(mylogdensity, MySampler(), MCMCThreads(), N, 2) + @test_throws ArgumentError sample( + mylogdensity, MySampler(), MCMCDistributed(), N, 2 + ) + @test_throws ArgumentError AbstractMCMC.steps(mylogdensity, MySampler()) + @test_throws ArgumentError AbstractMCMC.Sample(mylogdensity, MySampler()) + end + + # Remove workers + rmprocs(pids...) 
+end diff --git a/test/runtests.jl b/test/runtests.jl index 3baef78c..0b002b21 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using AbstractMCMC using Atom.Progress: JunoProgressLogger using ConsoleProgressMonitor: ProgressLogger using IJulia +using LogDensityProblems using LoggingExtras: TeeLogger, EarlyFilteredLogger using TerminalLoggers: TerminalLogger using Transducers @@ -22,4 +23,5 @@ include("utils.jl") include("sample.jl") include("stepper.jl") include("transducer.jl") + include("logdensityproblems.jl") end diff --git a/test/sample.jl b/test/sample.jl index 97bf5a5e..7ced7f0c 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -221,13 +221,17 @@ # Add worker processes. # Memory requirements on Windows are ~4x larger than on Linux, hence number of processes is reduced # See, e.g., https://github.com/JuliaLang/julia/issues/40766 and https://github.com/JuliaLang/Pkg.jl/pull/2366 - addprocs(Sys.iswindows() ? div(Sys.CPU_THREADS::Int, 2) : Sys.CPU_THREADS::Int) + pids = addprocs( + Sys.iswindows() ? div(Sys.CPU_THREADS::Int, 2) : Sys.CPU_THREADS::Int + ) - # Load all required packages (`interface.jl` needs Random). + # Load all required packages (`utils.jl` needs LogDensityProblems, Logging, and Random). @everywhere begin using AbstractMCMC using AbstractMCMC: sample + using LogDensityProblems + using Logging using Random include("utils.jl") end @@ -316,6 +320,9 @@ @test all( chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) + + # Remove workers + rmprocs(pids...) 
end @testset "Serial sampling" begin diff --git a/test/utils.jl b/test/utils.jl index e2eedcb4..f69fcdab 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -81,3 +81,35 @@ end # Conversion to NamedTuple Base.convert(::Type{NamedTuple}, x::MySample) = (a=x.a, b=x.b) + +# Gaussian log density (without additive constants) +# Without LogDensityProblems.jl interface +mylogdensity(x) = -sum(abs2, x) / 2 + +# With LogDensityProblems.jl interface +struct MyLogDensity + dim::Int +end +LogDensityProblems.logdensity(::MyLogDensity, x) = mylogdensity(x) +LogDensityProblems.dimension(m::MyLogDensity) = m.dim +function LogDensityProblems.capabilities(::Type{MyLogDensity}) + return LogDensityProblems.LogDensityOrder{0}() +end + +# Define "sampling" +function AbstractMCMC.step( + rng::AbstractRNG, + model::AbstractMCMC.LogDensityModel{MyLogDensity}, + ::MySampler, + state::Union{Nothing,Integer}=nothing; + kwargs..., +) + # Sample from multivariate normal distribution + ℓ = model.logdensity + dim = LogDensityProblems.dimension(ℓ) + θ = randn(rng, dim) + logdensity_θ = LogDensityProblems.logdensity(ℓ, θ) + + _state = state === nothing ? 1 : state + 1 + return MySample(θ, logdensity_θ), _state +end