diff --git a/docs/src/api.md b/docs/src/api.md index ad2f430a..226943ca 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -82,9 +82,9 @@ Common keyword arguments for regular and parallel sampling are: The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref). There is no "official" way for providing initial parameter values yet. -However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain. -To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): -- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`. +However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `initial_params` keyword argument for setting the initial values when sampling a single chain. +To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): +- `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`. Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`. diff --git a/test/sample.jl b/test/sample.jl index a41b3228..ad44054a 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -28,7 +28,7 @@ # initial parameters chain = sample( - MyModel(), MySampler(), 3; progress=false, init_params=(b=3.2, a=-1.8) + MyModel(), MySampler(), 3; progress=false, initial_params=(b=3.2, a=-1.8) ) @test chain[1].a == -1.8 @test chain[1].b == 3.2 @@ -163,7 +163,7 @@ # initial parameters nchains = 100 - init_params = [(b=randn(), a=rand()) for _ in 1:nchains] + initial_params = [(b=randn(), a=rand()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), @@ -171,15 +171,15 @@ 3, nchains; progress=false, - init_params=init_params, + initial_params=initial_params, ) @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for - (chain, params) in zip(chains, init_params) + (chain, params) in zip(chains, initial_params) ) - init_params = (a=randn(), b=rand()) + initial_params = (a=randn(), b=rand()) chains = sample( MyModel(), MySampler(), @@ -187,14 +187,14 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains), + initial_params=Iterators.repeated(initial_params, nchains), ) @test length(chains) == nchains @test all( - chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + chain[1].a == initial_params.a && chain[1].b == initial_params.b for chain in chains ) - # Too many `init_params` + # Too many `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -202,10 +202,10 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains + 1), + initial_params=Iterators.repeated(initial_params, nchains + 1), ) - # Too few `init_params` + # Too few `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -213,7 +213,7 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains - 1), + initial_params=Iterators.repeated(initial_params, nchains - 1), ) end @@ -298,7 +298,7 @@ # initial parameters nchains = 100 - init_params = [(a=randn(), b=rand()) for _ in 1:nchains] + initial_params = [(a=randn(), b=rand()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), @@ -306,15 +306,15 @@ 3, nchains; progress=false, - init_params=init_params, + initial_params=initial_params, ) @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for - (chain, params) in zip(chains, init_params) + (chain, params) in zip(chains, initial_params) ) - init_params = (b=randn(), a=rand()) + initial_params = (b=randn(), a=rand()) chains = sample( MyModel(), MySampler(), @@ -322,14 +322,14 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains), + initial_params=Iterators.repeated(initial_params, nchains), ) @test length(chains) == nchains @test all( - chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + chain[1].a == initial_params.a && chain[1].b == initial_params.b for chain in chains ) - # Too many `init_params` + # Too many `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -337,10 +337,10 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains + 1), + initial_params=Iterators.repeated(initial_params, nchains + 1), ) - # Too few `init_params` + # Too few `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -348,7 +348,7 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains - 1), + initial_params=Iterators.repeated(initial_params, nchains - 1), ) # Remove workers @@ -407,7 +407,7 @@ # initial parameters nchains = 100 - init_params = [(a=rand(), b=randn()) for _ in 1:nchains] + initial_params = [(a=rand(), b=randn()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), @@ -415,15 +415,15 @@ 3, nchains; progress=false, - init_params=init_params, + initial_params=initial_params, ) @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for - (chain, params) in zip(chains, init_params) + (chain, params) in zip(chains, initial_params) ) - init_params = (b=rand(), a=randn()) + initial_params = (b=rand(), a=randn()) chains = sample( MyModel(), MySampler(), @@ -431,14 +431,14 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains), + initial_params=Iterators.repeated(initial_params, nchains), ) @test length(chains) == nchains @test all( - chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + chain[1].a == initial_params.a && chain[1].b == initial_params.b for chain in chains ) - # Too many `init_params` + # Too many `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -446,10 +446,10 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains + 1), + initial_params=Iterators.repeated(initial_params, nchains + 1), ) - # Too few `init_params` + # Too few `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -457,7 +457,7 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains - 1), + initial_params=Iterators.repeated(initial_params, nchains - 1), ) end diff --git a/test/utils.jl b/test/utils.jl index f69fcdab..1e29a473 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -22,12 +22,12 @@ function AbstractMCMC.step( sampler::MySampler, state::Union{Nothing,Integer}=nothing; loggers=false, - init_params=nothing, + initial_params=nothing, kwargs..., ) # sample `a` is missing in the first step if not provided - a, b = if state === nothing && init_params !== nothing - init_params.a, init_params.b + a, b = if state === nothing && initial_params !== nothing + initial_params.a, initial_params.b else (state === nothing ? missing : rand(rng)), randn(rng) end