Skip to content

Commit

Permalink
Merge branch 'master' into torfjelde/step-warmup
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Oct 25, 2023
2 parents 25afc66 + dfb33b5 commit 1886fa8
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 71 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probabilistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "4.5.0"
version = "5.0.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
Expand All @@ -21,6 +22,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
[compat]
BangBang = "0.3.19"
ConsoleProgressMonitor = "0.1"
FillArrays = "1"
LogDensityProblems = "2"
LoggingExtras = "0.4, 0.5, 1"
ProgressLogging = "0.1"
Expand Down
8 changes: 5 additions & 3 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,16 @@ Common keyword arguments for regular and parallel sampling are:
- `discard_initial` (default: `num_warmup`): number of initial samples that are discarded. Note that
if `discard_initial < num_warmup`, warm-up samples will also be included in the resulting samples.
- `thinning` (default: `1`): factor by which to thin samples.
- `initial_state` (default: `nothing`): if `initial_state !== nothing`, the first call to [`AbstractMCMC.step`](@ref)
is passed `initial_state` as the `state` argument.

!!! info
The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref).

There is no "official" way to provide initial parameter values yet.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `initial_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
- `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`.

Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.

Expand Down
1 change: 1 addition & 0 deletions src/AbstractMCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using ProgressLogging: ProgressLogging
using StatsBase: StatsBase
using TerminalLoggers: TerminalLoggers
using Transducers: Transducers
using FillArrays: FillArrays

using Distributed: Distributed
using Logging: Logging
Expand Down
112 changes: 82 additions & 30 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ function mcmcsample(
discard_initial::Int=num_warmup,
thinning=1,
chain_type::Type=Any,
initial_state=nothing,
kwargs...,
)
# Check the number of requested samples.
Expand Down Expand Up @@ -145,9 +146,17 @@ function mcmcsample(

# Obtain the initial sample and state.
sample, state = if num_warmup > 0
step_warmup(rng, model, sampler; kwargs...)
if initial_state === nothing
step_warmup(rng, model, sampler; kwargs...)
else
step_warmup(rng, model, sampler, initial_state; kwargs...)
end
else
step(rng, model, sampler; kwargs...)
if initial_state === nothing
step(rng, model, sampler; kwargs...)
else
step(rng, model, sampler, initial_state; kwargs...)
end
end

# Update the progress bar.
Expand Down Expand Up @@ -253,6 +262,7 @@ function mcmcsample(
num_warmup=0,
discard_initial=num_warmup,
thinning=1,
initial_state=nothing,
kwargs...,
)
# Determine how many samples to drop from `num_warmup` and the
Expand All @@ -267,9 +277,17 @@ function mcmcsample(
@ifwithprogresslogger progress name = progressname begin
# Obtain the initial sample and state.
sample, state = if num_warmup > 0
step_warmup(rng, model, sampler; kwargs...)
if initial_state === nothing
step_warmup(rng, model, sampler; kwargs...)
else
step_warmup(rng, model, sampler, initial_state; kwargs...)
end
else
step(rng, model, sampler; kwargs...)
if initial_state === nothing
step(rng, model, sampler; kwargs...)
else
step(rng, model, sampler, initial_state; kwargs...)
end
end

# Discard initial samples.
Expand Down Expand Up @@ -349,7 +367,8 @@ function mcmcsample(
nchains::Integer;
progress=PROGRESS[],
progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)",
init_params=nothing,
initial_params=nothing,
initial_state=nothing,
kwargs...,
)
# Check if actually multiple threads are used.
Expand All @@ -373,8 +392,9 @@ function mcmcsample(
# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)
# Ensure that initial parameters and states are `nothing` or of the correct length
check_initial_params(initial_params, nchains)
check_initial_state(initial_state, nchains)

# Set up a chains vector.
chains = Vector{Any}(undef, nchains)
Expand Down Expand Up @@ -425,10 +445,15 @@ function mcmcsample(
_sampler,
N;
progress=false,
init_params=if init_params === nothing
initial_params=if initial_params === nothing
nothing
else
init_params[chainidx]
initial_params[chainidx]
end,
initial_state=if initial_state === nothing
nothing
else
initial_state[chainidx]
end,
kwargs...,
)
Expand Down Expand Up @@ -458,7 +483,8 @@ function mcmcsample(
nchains::Integer;
progress=PROGRESS[],
progressname="Sampling ($(Distributed.nworkers()) processes)",
init_params=nothing,
initial_params=nothing,
initial_state=nothing,
kwargs...,
)
# Check if actually multiple processes are used.
Expand All @@ -471,8 +497,14 @@ function mcmcsample(
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)
# Ensure that initial parameters and states are `nothing` or of the correct length
check_initial_params(initial_params, nchains)
check_initial_state(initial_state, nchains)

_initial_params =
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
_initial_state =
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)
Expand Down Expand Up @@ -509,7 +541,7 @@ function mcmcsample(

Distributed.@async begin
try
function sample_chain(seed, init_params=nothing)
function sample_chain(seed, initial_params, initial_state)
# Seed a new random number generator with the pre-made seed.
Random.seed!(rng, seed)

Expand All @@ -520,7 +552,8 @@ function mcmcsample(
sampler,
N;
progress=false,
init_params=init_params,
initial_params=initial_params,
initial_state=initial_state,
kwargs...,
)

Expand All @@ -530,11 +563,9 @@ function mcmcsample(
# Return the new chain.
return chain
end
chains = if init_params === nothing
Distributed.pmap(sample_chain, pool, seeds)
else
Distributed.pmap(sample_chain, pool, seeds, init_params)
end
chains = Distributed.pmap(
sample_chain, pool, seeds, _initial_params, _initial_state
)
finally
# Stop updating the progress bar.
progress && put!(channel, false)
Expand All @@ -555,22 +586,29 @@ function mcmcsample(
N::Integer,
nchains::Integer;
progressname="Sampling",
init_params=nothing,
initial_params=nothing,
initial_state=nothing,
kwargs...,
)
# Check if the number of chains is larger than the number of samples
if nchains > N
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)
# Ensure that initial parameters and states are `nothing` or of the correct length
check_initial_params(initial_params, nchains)
check_initial_state(initial_state, nchains)

_initial_params =
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
_initial_state =
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

# Sample the chains.
function sample_chain(i, seed, init_params=nothing)
function sample_chain(i, seed, initial_params, initial_state)
# Seed a new random number generator with the pre-made seed.
Random.seed!(rng, seed)

Expand All @@ -581,16 +619,13 @@ function mcmcsample(
sampler,
N;
progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"),
init_params=init_params,
initial_params=initial_params,
initial_state=initial_state,
kwargs...,
)
end

chains = if init_params === nothing
map(sample_chain, 1:nchains, seeds)
else
map(sample_chain, 1:nchains, seeds, init_params)
end
chains = map(sample_chain, 1:nchains, seeds, _initial_params, _initial_state)

# Concatenate the chains together.
return chainsstack(tighten_eltype(chains))
Expand All @@ -604,7 +639,6 @@ tighten_eltype(x::Vector{Any}) = map(identity, x)
"initial parameters must be specified as a vector of length equal to the number of chains or `nothing`",
),
)

check_initial_params(::Nothing, n) = nothing
function check_initial_params(x::AbstractArray, n)
if length(x) != n
Expand All @@ -617,3 +651,21 @@ function check_initial_params(x::AbstractArray, n)

return nothing
end

# Validate the `initial_state` keyword argument used when sampling multiple chains.
#
# `x` is the user-supplied collection of initial states and `n` is the number of chains.
# Each method returns `nothing` when `x` is acceptable and throws an `ArgumentError`
# otherwise.

# Fallback for every argument that is neither `nothing` nor an `AbstractArray`.
@nospecialize check_initial_state(x, n) = throw(
    ArgumentError(
        "initial states must be specified as a vector of length equal to the number of chains or `nothing`",
    ),
)

# `nothing` means that no initial states were provided, so there is nothing to check.
check_initial_state(::Nothing, n) = nothing

# An array of initial states must contain exactly one entry per chain.
function check_initial_state(x::AbstractArray, n)
    if length(x) != n
        throw(
            ArgumentError(
                # The parenthetical is closed here; it was previously left unbalanced.
                "incorrect number of initial states (expected $n, received $(length(x)))"
            ),
        )
    end

    return nothing
end
Loading

0 comments on commit 1886fa8

Please sign in to comment.