Skip to content

Commit

Permalink
Merge pull request #119 from TuringLang/torfjelde/initial-state
Browse files Browse the repository at this point in the history
Allow specification of initial state for `sample`
  • Loading branch information
torfjelde authored Oct 24, 2023
2 parents d521815 + 3ed5314 commit 8d45ff4
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 68 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ version = "4.5.0"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
Expand All @@ -21,6 +22,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
[compat]
BangBang = "0.3.19"
ConsoleProgressMonitor = "0.1"
FillArrays = "1"
LogDensityProblems = "2"
LoggingExtras = "0.4, 0.5, 1"
ProgressLogging = "0.1"
Expand Down
8 changes: 5 additions & 3 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,16 @@ Common keyword arguments for regular and parallel sampling are:
where `sample` is the most recent sample of the Markov chain and `state` and `iteration` are the current state and iteration of the sampler
- `discard_initial` (default: `0`): number of initial samples that are discarded
- `thinning` (default: `1`): factor by which to thin samples.
- `initial_state` (default: `nothing`): if `initial_state !== nothing`, the first call to [`AbstractMCMC.step`](@ref)
is passed `initial_state` as the `state` argument.

!!! info
The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref).

There is no "official" way to provide initial parameter values yet.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `initial_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
- `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`.

Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.

Expand Down
1 change: 1 addition & 0 deletions src/AbstractMCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using ProgressLogging: ProgressLogging
using StatsBase: StatsBase
using TerminalLoggers: TerminalLoggers
using Transducers: Transducers
using FillArrays: FillArrays

using Distributed: Distributed
using Logging: Logging
Expand Down
100 changes: 72 additions & 28 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ function mcmcsample(
discard_initial=0,
thinning=1,
chain_type::Type=Any,
initial_state=nothing,
kwargs...,
)
# Check the number of requested samples.
Expand All @@ -122,7 +123,11 @@ function mcmcsample(
end

# Obtain the initial sample and state.
sample, state = step(rng, model, sampler; kwargs...)
sample, state = if initial_state === nothing
step(rng, model, sampler; kwargs...)
else
step(rng, model, sampler, initial_state; kwargs...)
end

# Discard initial samples.
for i in 1:discard_initial
Expand Down Expand Up @@ -211,6 +216,7 @@ function mcmcsample(
callback=nothing,
discard_initial=0,
thinning=1,
initial_state=nothing,
kwargs...,
)

Expand All @@ -220,7 +226,11 @@ function mcmcsample(

@ifwithprogresslogger progress name = progressname begin
# Obtain the initial sample and state.
sample, state = step(rng, model, sampler; kwargs...)
sample, state = if initial_state === nothing
    # No initial state supplied: take the initial step without a state argument.
    step(rng, model, sampler; kwargs...)
else
    # Pass the user-supplied `initial_state`. `state` cannot be used here because
    # it is the variable this very statement assigns and is not yet defined.
    step(rng, model, sampler, initial_state; kwargs...)
end

# Discard initial samples.
for _ in 1:discard_initial
Expand Down Expand Up @@ -288,7 +298,8 @@ function mcmcsample(
nchains::Integer;
progress=PROGRESS[],
progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)",
init_params=nothing,
initial_params=nothing,
initial_state=nothing,
kwargs...,
)
# Check if actually multiple threads are used.
Expand All @@ -312,8 +323,9 @@ function mcmcsample(
# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)
# Ensure that initial parameters and states are `nothing` or of the correct length
check_initial_params(initial_params, nchains)
check_initial_state(initial_state, nchains)

# Set up a chains vector.
chains = Vector{Any}(undef, nchains)
Expand Down Expand Up @@ -364,10 +376,15 @@ function mcmcsample(
_sampler,
N;
progress=false,
init_params=if init_params === nothing
initial_params=if initial_params === nothing
nothing
else
initial_params[chainidx]
end,
initial_state=if initial_state === nothing
nothing
else
init_params[chainidx]
initial_state[chainidx]
end,
kwargs...,
)
Expand Down Expand Up @@ -397,7 +414,8 @@ function mcmcsample(
nchains::Integer;
progress=PROGRESS[],
progressname="Sampling ($(Distributed.nworkers()) processes)",
init_params=nothing,
initial_params=nothing,
initial_state=nothing,
kwargs...,
)
# Check if actually multiple processes are used.
Expand All @@ -410,8 +428,14 @@ function mcmcsample(
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)
# Ensure that initial parameters and states are `nothing` or of the correct length
check_initial_params(initial_params, nchains)
check_initial_state(initial_state, nchains)

_initial_params =
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
_initial_state =
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)
Expand Down Expand Up @@ -448,7 +472,7 @@ function mcmcsample(

Distributed.@async begin
try
function sample_chain(seed, init_params=nothing)
function sample_chain(seed, initial_params, initial_state)
# Seed a new random number generator with the pre-made seed.
Random.seed!(rng, seed)

Expand All @@ -459,7 +483,8 @@ function mcmcsample(
sampler,
N;
progress=false,
init_params=init_params,
initial_params=initial_params,
initial_state=initial_state,
kwargs...,
)

Expand All @@ -469,11 +494,9 @@ function mcmcsample(
# Return the new chain.
return chain
end
chains = if init_params === nothing
Distributed.pmap(sample_chain, pool, seeds)
else
Distributed.pmap(sample_chain, pool, seeds, init_params)
end
chains = Distributed.pmap(
sample_chain, pool, seeds, _initial_params, _initial_state
)
finally
# Stop updating the progress bar.
progress && put!(channel, false)
Expand All @@ -494,22 +517,29 @@ function mcmcsample(
N::Integer,
nchains::Integer;
progressname="Sampling",
init_params=nothing,
initial_params=nothing,
initial_state=nothing,
kwargs...,
)
# Check if the number of chains is larger than the number of samples
if nchains > N
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)
# Ensure that initial parameters and states are `nothing` or of the correct length
check_initial_params(initial_params, nchains)
check_initial_state(initial_state, nchains)

_initial_params =
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
_initial_state =
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

# Sample the chains.
function sample_chain(i, seed, init_params=nothing)
function sample_chain(i, seed, initial_params, initial_state)
# Seed a new random number generator with the pre-made seed.
Random.seed!(rng, seed)

Expand All @@ -520,16 +550,13 @@ function mcmcsample(
sampler,
N;
progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"),
init_params=init_params,
initial_params=initial_params,
initial_state=initial_state,
kwargs...,
)
end

chains = if init_params === nothing
map(sample_chain, 1:nchains, seeds)
else
map(sample_chain, 1:nchains, seeds, init_params)
end
chains = map(sample_chain, 1:nchains, seeds, _initial_params, _initial_state)

# Concatenate the chains together.
return chainsstack(tighten_eltype(chains))
Expand All @@ -543,7 +570,6 @@ tighten_eltype(x::Vector{Any}) = map(identity, x)
"initial parameters must be specified as a vector of length equal to the number of chains or `nothing`",
),
)

check_initial_params(::Nothing, n) = nothing
function check_initial_params(x::AbstractArray, n)
if length(x) != n
Expand All @@ -556,3 +582,21 @@ function check_initial_params(x::AbstractArray, n)

return nothing
end

"""
    check_initial_state(x, n)

Validate the initial states `x` supplied for sampling `n` chains.

Return `nothing` if `x === nothing` or if `x` is an `AbstractArray` with
`length(x) == n`.

# Throws
- `ArgumentError`: if `x` is an `AbstractArray` whose length differs from `n`, or if
  `x` is neither `nothing` nor an `AbstractArray`.
"""
function check_initial_state(@nospecialize(x), n)
    # Fallback for unsupported inputs. `@nospecialize` is applied to the argument
    # (its documented position) so that this method is guaranteed to be defined.
    throw(
        ArgumentError(
            "initial states must be specified as a vector of length equal to the number of chains or `nothing`",
        ),
    )
end
check_initial_state(::Nothing, n) = nothing
function check_initial_state(x::AbstractArray, n)
    if length(x) != n
        throw(
            ArgumentError(
                "incorrect number of initial states (expected $n, received $(length(x)))"
            ),
        )
    end

    return nothing
end
Loading

2 comments on commit 8d45ff4

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/94028

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v4.5.0 -m "<description of version>" 8d45ff49780e1aee2f02ad568eb81908f85980b1
git push origin v4.5.0

Please sign in to comment.