diff --git a/Project.toml b/Project.toml index 90117048..a0ba0825 100644 --- a/Project.toml +++ b/Project.toml @@ -3,12 +3,13 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.5.0" +version = "5.0.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" @@ -21,6 +22,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] BangBang = "0.3.19" ConsoleProgressMonitor = "0.1" +FillArrays = "1" LogDensityProblems = "2" LoggingExtras = "0.4, 0.5, 1" ProgressLogging = "0.1" diff --git a/docs/src/api.md b/docs/src/api.md index 1af4fdf8..f6dd1cf8 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -79,14 +79,16 @@ Common keyword arguments for regular and parallel sampling are: - `discard_initial` (default: `num_warmup`): number of initial samples that are discarded. Note that if `discard_initial < num_warmup`, warm-up samples will also be included in the resulting samples. - `thinning` (default: `1`): factor by which to thin samples. +- `initial_state` (default: `nothing`): if `initial_state !== nothing`, the first call to [`AbstractMCMC.step`](@ref) + is passed `initial_state` as the `state` argument. !!! info The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref). There is no "official" way for providing initial parameter values yet. 
-However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain. -To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): -- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`. +However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `initial_params` keyword argument for setting the initial values when sampling a single chain. +To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): +- `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`. Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`. 
diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 64f20f97..dc464d42 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -8,6 +8,7 @@ using ProgressLogging: ProgressLogging using StatsBase: StatsBase using TerminalLoggers: TerminalLoggers using Transducers: Transducers +using FillArrays: FillArrays using Distributed: Distributed using Logging: Logging diff --git a/src/sample.jl b/src/sample.jl index 35f551ad..b3af2102 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -113,6 +113,7 @@ function mcmcsample( discard_initial::Int=num_warmup, thinning=1, chain_type::Type=Any, + initial_state=nothing, kwargs..., ) # Check the number of requested samples. @@ -145,9 +146,17 @@ function mcmcsample( # Obtain the initial sample and state. sample, state = if num_warmup > 0 - step_warmup(rng, model, sampler; kwargs...) + if initial_state === nothing + step_warmup(rng, model, sampler; kwargs...) + else + step_warmup(rng, model, sampler, initial_state; kwargs...) + end else - step(rng, model, sampler; kwargs...) + if initial_state === nothing + step(rng, model, sampler; kwargs...) + else + step(rng, model, sampler, initial_state; kwargs...) + end end # Update the progress bar. @@ -253,6 +262,7 @@ function mcmcsample( num_warmup=0, discard_initial=num_warmup, thinning=1, + initial_state=nothing, kwargs..., ) # Determine how many samples to drop from `num_warmup` and the @@ -267,9 +277,17 @@ function mcmcsample( @ifwithprogresslogger progress name = progressname begin # Obtain the initial sample and state. sample, state = if num_warmup > 0 - step_warmup(rng, model, sampler; kwargs...) + if initial_state === nothing + step_warmup(rng, model, sampler; kwargs...) + else + step_warmup(rng, model, sampler, initial_state; kwargs...) + end else - step(rng, model, sampler; kwargs...) + if initial_state === nothing + step(rng, model, sampler; kwargs...) + else + step(rng, model, sampler, initial_state; kwargs...) + end end # Discard initial samples. 
@@ -349,7 +367,8 @@ function mcmcsample( nchains::Integer; progress=PROGRESS[], progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)", - init_params=nothing, + initial_params=nothing, + initial_state=nothing, kwargs..., ) # Check if actually multiple threads are used. @@ -373,8 +392,9 @@ function mcmcsample( # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) - # Ensure that initial parameters are `nothing` or of the correct length - check_initial_params(init_params, nchains) + # Ensure that initial parameters and states are `nothing` or of the correct length + check_initial_params(initial_params, nchains) + check_initial_state(initial_state, nchains) # Set up a chains vector. chains = Vector{Any}(undef, nchains) @@ -425,10 +445,15 @@ function mcmcsample( _sampler, N; progress=false, - init_params=if init_params === nothing + initial_params=if initial_params === nothing nothing else - init_params[chainidx] + initial_params[chainidx] + end, + initial_state=if initial_state === nothing + nothing + else + initial_state[chainidx] end, kwargs..., ) @@ -458,7 +483,8 @@ function mcmcsample( nchains::Integer; progress=PROGRESS[], progressname="Sampling ($(Distributed.nworkers()) processes)", - init_params=nothing, + initial_params=nothing, + initial_state=nothing, kwargs..., ) # Check if actually multiple processes are used. @@ -471,8 +497,14 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end - # Ensure that initial parameters are `nothing` or of the correct length - check_initial_params(init_params, nchains) + # Ensure that initial parameters and states are `nothing` or of the correct length + check_initial_params(initial_params, nchains) + check_initial_state(initial_state, nchains) + + _initial_params = + initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params + _initial_state = + initial_state === nothing ? 
FillArrays.Fill(nothing, nchains) : initial_state # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -509,7 +541,7 @@ function mcmcsample( Distributed.@async begin try - function sample_chain(seed, init_params=nothing) + function sample_chain(seed, initial_params, initial_state) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) @@ -520,7 +552,8 @@ function mcmcsample( sampler, N; progress=false, - init_params=init_params, + initial_params=initial_params, + initial_state=initial_state, kwargs..., ) @@ -530,11 +563,9 @@ function mcmcsample( # Return the new chain. return chain end - chains = if init_params === nothing - Distributed.pmap(sample_chain, pool, seeds) - else - Distributed.pmap(sample_chain, pool, seeds, init_params) - end + chains = Distributed.pmap( + sample_chain, pool, seeds, _initial_params, _initial_state + ) finally # Stop updating the progress bar. progress && put!(channel, false) @@ -555,7 +586,8 @@ function mcmcsample( N::Integer, nchains::Integer; progressname="Sampling", - init_params=nothing, + initial_params=nothing, + initial_state=nothing, kwargs..., ) # Check if the number of chains is larger than the number of samples @@ -563,14 +595,20 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end - # Ensure that initial parameters are `nothing` or of the correct length - check_initial_params(init_params, nchains) + # Ensure that initial parameters and states are `nothing` or of the correct length + check_initial_params(initial_params, nchains) + check_initial_state(initial_state, nchains) + + _initial_params = + initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params + _initial_state = + initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state # Create a seed for each chain using the provided random number generator. 
seeds = rand(rng, UInt, nchains) # Sample the chains. - function sample_chain(i, seed, init_params=nothing) + function sample_chain(i, seed, initial_params, initial_state) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) @@ -581,16 +619,13 @@ function mcmcsample( sampler, N; progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"), - init_params=init_params, + initial_params=initial_params, + initial_state=initial_state, kwargs..., ) end - chains = if init_params === nothing - map(sample_chain, 1:nchains, seeds) - else - map(sample_chain, 1:nchains, seeds, init_params) - end + chains = map(sample_chain, 1:nchains, seeds, _initial_params, _initial_state) # Concatenate the chains together. return chainsstack(tighten_eltype(chains)) @@ -604,7 +639,6 @@ tighten_eltype(x::Vector{Any}) = map(identity, x) "initial parameters must be specified as a vector of length equal to the number of chains or `nothing`", ), ) - check_initial_params(::Nothing, n) = nothing function check_initial_params(x::AbstractArray, n) if length(x) != n @@ -617,3 +651,21 @@ function check_initial_params(x::AbstractArray, n) return nothing end + +@nospecialize check_initial_state(x, n) = throw( + ArgumentError( + "initial states must be specified as a vector of length equal to the number of chains or `nothing`", + ), +) +check_initial_state(::Nothing, n) = nothing +function check_initial_state(x::AbstractArray, n) + if length(x) != n + throw( + ArgumentError( + "incorrect number of initial states (expected $n, received $(length(x)))" + ), + ) + end + + return nothing +end diff --git a/test/sample.jl b/test/sample.jl index 22f4b26d..dcc87526 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -28,7 +28,7 @@ # initial parameters chain = sample( - MyModel(), MySampler(), 3; progress=false, init_params=(b=3.2, a=-1.8) + MyModel(), MySampler(), 3; progress=false, initial_params=(b=3.2, a=-1.8) ) @test chain[1].a == -1.8 @test chain[1].b == 3.2 @@ -163,7 
+163,7 @@ # initial parameters nchains = 100 - init_params = [(b=randn(), a=rand()) for _ in 1:nchains] + initial_params = [(b=randn(), a=rand()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), @@ -171,15 +171,15 @@ 3, nchains; progress=false, - init_params=init_params, + initial_params=initial_params, ) @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for - (chain, params) in zip(chains, init_params) + (chain, params) in zip(chains, initial_params) ) - init_params = (a=randn(), b=rand()) + initial_params = (a=randn(), b=rand()) chains = sample( MyModel(), MySampler(), @@ -187,14 +187,15 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains), + initial_params=FillArrays.Fill(initial_params, nchains), ) @test length(chains) == nchains @test all( - chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + chain[1].a == initial_params.a && chain[1].b == initial_params.b for + chain in chains ) - # Too many `init_params` + # Too many `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -202,10 +203,10 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains + 1), + initial_params=FillArrays.Fill(initial_params, nchains + 1), ) - # Too few `init_params` + # Too few `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -213,7 +214,7 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains - 1), + initial_params=FillArrays.Fill(initial_params, nchains - 1), ) end @@ -298,7 +299,7 @@ # initial parameters nchains = 100 - init_params = [(a=randn(), b=rand()) for _ in 1:nchains] + initial_params = [(a=randn(), b=rand()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), @@ -306,15 +307,15 @@ 3, nchains; progress=false, - init_params=init_params, + initial_params=initial_params, ) @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == 
params.b for - (chain, params) in zip(chains, init_params) + (chain, params) in zip(chains, initial_params) ) - init_params = (b=randn(), a=rand()) + initial_params = (b=randn(), a=rand()) chains = sample( MyModel(), MySampler(), @@ -322,14 +323,15 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains), + initial_params=FillArrays.Fill(initial_params, nchains), ) @test length(chains) == nchains @test all( - chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + chain[1].a == initial_params.a && chain[1].b == initial_params.b for + chain in chains ) - # Too many `init_params` + # Too many `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -337,10 +339,10 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains + 1), + initial_params=FillArrays.Fill(initial_params, nchains + 1), ) - # Too few `init_params` + # Too few `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -348,7 +350,7 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains - 1), + initial_params=FillArrays.Fill(initial_params, nchains - 1), ) # Remove workers @@ -358,13 +360,21 @@ @testset "Serial sampling" begin # No dedicated chains type N = 10_000 - chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000) + chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; progress=false) @test chains isa Vector{<:Vector{<:MySample}} @test length(chains) == 1000 @test all(length(x) == N for x in chains) Random.seed!(1234) - chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) + chains = sample( + MyModel(), + MySampler(), + MCMCSerial(), + N, + 1000; + chain_type=MyChain, + progress=false, + ) # Test output type and size. @test chains isa Vector{<:MyChain} @@ -380,7 +390,15 @@ # Test reproducibility. 
Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) + chains2 = sample( + MyModel(), + MySampler(), + MCMCSerial(), + N, + 1000; + chain_type=MyChain, + progress=false, + ) @test all(ismissing(c.as[1]) for c in chains2) @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N) @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) @@ -407,7 +425,7 @@ # initial parameters nchains = 100 - init_params = [(a=rand(), b=randn()) for _ in 1:nchains] + initial_params = [(a=rand(), b=randn()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), @@ -415,15 +433,15 @@ 3, nchains; progress=false, - init_params=init_params, + initial_params=initial_params, ) @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for - (chain, params) in zip(chains, init_params) + (chain, params) in zip(chains, initial_params) ) - init_params = (b=rand(), a=randn()) + initial_params = (b=rand(), a=randn()) chains = sample( MyModel(), MySampler(), @@ -431,14 +449,15 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains), + initial_params=FillArrays.Fill(initial_params, nchains), ) @test length(chains) == nchains @test all( - chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + chain[1].a == initial_params.a && chain[1].b == initial_params.b for + chain in chains ) - # Too many `init_params` + # Too many `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -446,10 +465,10 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains + 1), + initial_params=FillArrays.Fill(initial_params, nchains + 1), ) - # Too few `init_params` + # Too few `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -457,7 +476,7 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains - 1), + 
initial_params=FillArrays.Fill(initial_params, nchains - 1), ) end @@ -655,4 +674,63 @@ ) @test it_array == collect(1:size(chain, 1)) end + + @testset "Providing initial state" begin + function record_state( + rng, model, sampler, sample, state, i; states_channel, kwargs... + ) + return put!(states_channel, state) + end + + initial_state = 10 + + @testset "sample" begin + n = 10 + states_channel = Channel{Int}(n) + chain = sample( + MyModel(), + MySampler(), + n; + initial_state=initial_state, + callback=record_state, + states_channel=states_channel, + ) + + # Extract the states. + states = [take!(states_channel) for _ in 1:n] + @test length(states) == n + for i in 1:n + @test states[i] == initial_state + i + end + end + + @testset "sample with $mode" for mode in + [MCMCSerial(), MCMCThreads(), MCMCDistributed()] + nchains = 4 + initial_state = 10 + states_channel = if mode === MCMCDistributed() + # Need to use `RemoteChannel` for this. + RemoteChannel(() -> Channel{Int}(nchains)) + else + Channel{Int}(nchains) + end + chain = sample( + MyModel(), + MySampler(), + mode, + 1, + nchains; + initial_state=FillArrays.Fill(initial_state, nchains), + callback=record_state, + states_channel=states_channel, + ) + + # Extract the states. + states = [take!(states_channel) for _ in 1:nchains] + @test length(states) == nchains + for i in 1:nchains + @test states[i] == initial_state + 1 + end + end + end end diff --git a/test/utils.jl b/test/utils.jl index f69fcdab..1e29a473 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -22,12 +22,12 @@ function AbstractMCMC.step( sampler::MySampler, state::Union{Nothing,Integer}=nothing; loggers=false, - init_params=nothing, + initial_params=nothing, kwargs..., ) # sample `a` is missing in the first step if not provided - a, b = if state === nothing && init_params !== nothing - init_params.a, init_params.b + a, b = if state === nothing && initial_params !== nothing + initial_params.a, initial_params.b else (state === nothing ? 
missing : rand(rng)), randn(rng) end