Skip to content

Commit

Permalink
renamed references for init_params to initial_params
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Oct 1, 2023
1 parent 6dbba56 commit e234ccc
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 37 deletions.
6 changes: 3 additions & 3 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ Common keyword arguments for regular and parallel sampling are:
The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref).

There is no "official" way for providing initial parameter values yet.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `initial_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
- `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`.

Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.

Expand Down
62 changes: 31 additions & 31 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

# initial parameters
chain = sample(
MyModel(), MySampler(), 3; progress=false, init_params=(b=3.2, a=-1.8)
MyModel(), MySampler(), 3; progress=false, initial_params=(b=3.2, a=-1.8)
)
@test chain[1].a == -1.8
@test chain[1].b == 3.2
Expand Down Expand Up @@ -163,57 +163,57 @@

# initial parameters
nchains = 100
init_params = [(b=randn(), a=rand()) for _ in 1:nchains]
initial_params = [(b=randn(), a=rand()) for _ in 1:nchains]
chains = sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
nchains;
progress=false,
init_params=init_params,
initial_params=initial_params,
)
@test length(chains) == nchains
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
(chain, params) in zip(chains, initial_params)
)

init_params = (a=randn(), b=rand())
initial_params = (a=randn(), b=rand())
chains = sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains),
initial_params=Iterators.repeated(initial_params, nchains),
)
@test length(chains) == nchains
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
chain[1].a == initial_params.a && chain[1].b == initial_params.b for chain in chains
)

# Too many `init_params`
# Too many `initial_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains + 1),
initial_params=Iterators.repeated(initial_params, nchains + 1),
)

# Too few `init_params`
# Too few `initial_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains - 1),
initial_params=Iterators.repeated(initial_params, nchains - 1),
)
end

Expand Down Expand Up @@ -298,57 +298,57 @@

# initial parameters
nchains = 100
init_params = [(a=randn(), b=rand()) for _ in 1:nchains]
initial_params = [(a=randn(), b=rand()) for _ in 1:nchains]
chains = sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
nchains;
progress=false,
init_params=init_params,
initial_params=initial_params,
)
@test length(chains) == nchains
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
(chain, params) in zip(chains, initial_params)
)

init_params = (b=randn(), a=rand())
initial_params = (b=randn(), a=rand())
chains = sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains),
initial_params=Iterators.repeated(initial_params, nchains),
)
@test length(chains) == nchains
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
chain[1].a == initial_params.a && chain[1].b == initial_params.b for chain in chains
)

# Too many `init_params`
# Too many `initial_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains + 1),
initial_params=Iterators.repeated(initial_params, nchains + 1),
)

# Too few `init_params`
# Too few `initial_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains - 1),
initial_params=Iterators.repeated(initial_params, nchains - 1),
)

# Remove workers
Expand Down Expand Up @@ -407,57 +407,57 @@

# initial parameters
nchains = 100
init_params = [(a=rand(), b=randn()) for _ in 1:nchains]
initial_params = [(a=rand(), b=randn()) for _ in 1:nchains]
chains = sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
nchains;
progress=false,
init_params=init_params,
initial_params=initial_params,
)
@test length(chains) == nchains
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
(chain, params) in zip(chains, initial_params)
)

init_params = (b=rand(), a=randn())
initial_params = (b=rand(), a=randn())
chains = sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains),
initial_params=Iterators.repeated(initial_params, nchains),
)
@test length(chains) == nchains
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
chain[1].a == initial_params.a && chain[1].b == initial_params.b for chain in chains
)

# Too many `init_params`
# Too many `initial_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains + 1),
initial_params=Iterators.repeated(initial_params, nchains + 1),
)

# Too few `init_params`
# Too few `initial_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains - 1),
initial_params=Iterators.repeated(initial_params, nchains - 1),
)
end

Expand Down
6 changes: 3 additions & 3 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ function AbstractMCMC.step(
sampler::MySampler,
state::Union{Nothing,Integer}=nothing;
loggers=false,
init_params=nothing,
initial_params=nothing,
kwargs...,
)
# sample `a` is missing in the first step if not provided
a, b = if state === nothing && init_params !== nothing
init_params.a, init_params.b
a, b = if state === nothing && initial_params !== nothing
initial_params.a, initial_params.b
else
(state === nothing ? missing : rand(rng)), randn(rng)
end
Expand Down

0 comments on commit e234ccc

Please sign in to comment.