From 0987f5f13668def2cabe32ae3fba636a0e9260c9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 9 Mar 2023 15:45:21 +0000 Subject: [PATCH 01/35] added step_warmup which is can be overloaded when convenient --- Project.toml | 2 +- src/interface.jl | 13 +++++++++++++ src/sample.jl | 2 +- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 40494911..38064eee 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.4.0" +version = "4.4.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/src/interface.jl b/src/interface.jl index eaecb492..bda6dfa7 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -56,6 +56,19 @@ current `state` of the sampler. """ function step end +""" + step_warmup(rng, model, sampler[, state; kwargs...]) + +Return a 2-tuple of the next sample and the next state of the MCMC `sampler` for `model`. + +When sampling using [`sample`](@ref), this takes the place of [`step`](@ref) in the first `discard_initial` +number of iterations. This is useful if the sampler has a "warmup"-stage initial stage +that is different from the standard iteration. + +By default, this simply calls [`step`](@ref.) +""" +step_warmup(rng, model, sampler, state; kwargs...) = step(rng, model, sampler, state; kwargs...) + """ samples(sample, model, sampler[, N; kwargs...]) diff --git a/src/sample.jl b/src/sample.jl index dc951ca2..89f2868b 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -133,7 +133,7 @@ function mcmcsample( end # Obtain the next sample and state. - sample, state = step(rng, model, sampler, state; kwargs...) + sample, state = step_warmup(rng, model, sampler, state; kwargs...) end # Run callback. From 30c9f123cd3458ced0b0e6e6f1797230630d5ef2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 9 Mar 2023 16:00:40 +0000 Subject: [PATCH 02/35] added step_warmup to docs --- docs/src/design.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/src/design.md b/docs/src/design.md index 0cc524a3..1512ff5e 100644 --- a/docs/src/design.md +++ b/docs/src/design.md @@ -63,6 +63,14 @@ the sampling step of the inference method. AbstractMCMC.step ``` +If one also has some special handling of the warmup-stage of sampling, then this can be specified by overloading + +```@docs +AbstractMCMC.step_warmup +``` + +Note that this is optional; by default it simply calls [`AbstractMCMC.step`](@ref) from above. + ## Collecting samples !!! note From 7faa73fe8fa8c23cfcb4dac95d663fe3645782e0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 9 Mar 2023 16:22:18 +0000 Subject: [PATCH 03/35] Update src/interface.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/interface.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index bda6dfa7..66ad7a1d 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -67,7 +67,9 @@ that is different from the standard iteration. By default, this simply calls [`step`](@ref.) """ -step_warmup(rng, model, sampler, state; kwargs...) = step(rng, model, sampler, state; kwargs...) +function step_warmup(rng, model, sampler, state; kwargs...) + return step(rng, model, sampler, state; kwargs...) +end """ samples(sample, model, sampler[, N; kwargs...]) From bd0bdc7dfefc02bf6d7509b07739f6478e590fef Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 06:45:02 +0000 Subject: [PATCH 04/35] introduce new kwarg `num_warmup` to `sample` which uses `step_warmup` --- src/interface.jl | 7 ++++--- src/sample.jl | 21 +++++++++++++++++---- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index bda6dfa7..adb8ffb6 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -61,9 +61,10 @@ function step end Return a 2-tuple of the next sample and the next state of the MCMC `sampler` for `model`. -When sampling using [`sample`](@ref), this takes the place of [`step`](@ref) in the first `discard_initial` -number of iterations. This is useful if the sampler has a "warmup"-stage initial stage -that is different from the standard iteration. +When sampling using [`sample`](@ref), this takes the place of [`step`](@ref) in the first +`num_warmup` number of iterations, as specified by the `num_warmup` keyword to [`sample`](@ref). +This is useful if the sampler has a "warmup"-stage initial stage that is different from the +standard iteration. By default, this simply calls [`step`](@ref.) """ diff --git a/src/sample.jl b/src/sample.jl index 89f2868b..d9fae998 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -100,14 +100,15 @@ function mcmcsample( progress=PROGRESS[], progressname="Sampling", callback=nothing, - discard_initial=0, + num_warmup=0, + discard_initial=num_warmup, thinning=1, chain_type::Type=Any, kwargs..., ) # Check the number of requested samples. N > 0 || error("the number of samples must be ≥ 1") - Ntotal = thinning * (N - 1) + discard_initial + 1 + Ntotal = thinning * (N - 1) + discard_initial + num_warmup + 1 # Start the timer start = time() @@ -125,7 +126,7 @@ function mcmcsample( sample, state = step(rng, model, sampler; kwargs...) # Discard initial samples. - for i in 1:discard_initial + for i in 1:num_warmup # Update the progress bar. if progress && i >= next_update ProgressLogging.@logprogress i / Ntotal @@ -136,6 +137,18 @@ function mcmcsample( sample, state = step_warmup(rng, model, sampler, state; kwargs...) end + # Discard initial samples. + for i in 1:discard_initial + # Update the progress bar. + if progress && i >= next_update + ProgressLogging.@logprogress i / Ntotal + next_update = i + threshold + end + + # Obtain the next sample and state. + sample, state = step(rng, model, sampler, state; kwargs...) + end + # Run callback. callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) @@ -144,7 +157,7 @@ function mcmcsample( samples = save!!(samples, sample, 1, model, sampler, N; kwargs...) # Update the progress bar. - itotal = 1 + discard_initial + itotal = 1 + num_warmup + discard_initial if progress && itotal >= next_update ProgressLogging.@logprogress itotal / Ntotal next_update = itotal + threshold From c620cca78f88d6e6b080cf674a5ad96ff09ebcd1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 06:47:37 +0000 Subject: [PATCH 05/35] updated docs --- docs/src/design.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/design.md b/docs/src/design.md index 1512ff5e..cc2651ac 100644 --- a/docs/src/design.md +++ b/docs/src/design.md @@ -69,6 +69,7 @@ If one also has some special handling of the warmup-stage of sampling, then this AbstractMCMC.step_warmup ``` +which will be used for the first `num_warmup`, as specified as a keyword argument to [`AbstractMCMC.sample`](@ref). Note that this is optional; by default it simply calls [`AbstractMCMC.step`](@ref) from above. ## Collecting samples From 572a286b3968d2c23dbc61a27d3d8cc49374ed64 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 07:42:25 +0000 Subject: [PATCH 06/35] allow combination of discard_initial and num_warmup --- src/sample.jl | 142 +++++++++++++++++++++++++++++++++++++------------- 1 file changed, 107 insertions(+), 35 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index d9fae998..4932c4aa 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -110,6 +110,12 @@ function mcmcsample( N > 0 || error("the number of samples must be ≥ 1") Ntotal = thinning * (N - 1) + discard_initial + num_warmup + 1 + # Determine how many samples to drop from `num_warmup` and the + # main sampling process before we start saving samples. + discard_from_warmup = min(num_warmup, discard_initial) + keep_from_warmup = num_warmup - discard_from_warmup + discard_from_sample = max(discard_initial - discard_from_warmup, 0) + # Start the timer start = time() local state @@ -125,46 +131,73 @@ function mcmcsample( # Obtain the initial sample and state. sample, state = step(rng, model, sampler; kwargs...) - # Discard initial samples. - for i in 1:num_warmup - # Update the progress bar. - if progress && i >= next_update - ProgressLogging.@logprogress i / Ntotal - next_update = i + threshold - end - + # Warmup sampling. + for _ = 1:discard_from_warmup # Obtain the next sample and state. sample, state = step_warmup(rng, model, sampler, state; kwargs...) end - # Discard initial samples. - for i in 1:discard_initial - # Update the progress bar. - if progress && i >= next_update - ProgressLogging.@logprogress i / Ntotal - next_update = i + threshold + i = 1 + if keep_from_warmup > 0 + # Run callback. + callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + + # Save the sample. + samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) + samples = save!!(samples, sample, i, model, sampler; kwargs...) + + # Step through remainder of warmup iterations and save. + i += 1 + for _ in (discard_from_warmup + 1):num_warmup + # Update the progress bar. + if progress && i >= next_update + ProgressLogging.@logprogress i / Ntotal + next_update = i + threshold + end + + # Obtain the next sample and state. + sample, state = step_warmup(rng, model, sampler, state; kwargs...) + + # Run callback. + callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + + # Save the sample. + samples = save!!(samples, sample, i, model, sampler; kwargs...) + i += 1 end + else + # Discard additional initial samples, if needed. + for _ in 1:discard_from_sample + # Update the progress bar. + if progress && i >= next_update + ProgressLogging.@logprogress i / Ntotal + next_update = i + threshold + end - # Obtain the next sample and state. - sample, state = step(rng, model, sampler, state; kwargs...) - end + # Obtain the next sample and state. + sample, state = step(rng, model, sampler, state; kwargs...) + end - # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) + # Run callback. + callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) - # Save the sample. - samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...) - samples = save!!(samples, sample, 1, model, sampler, N; kwargs...) + # Save the sample. + samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) + samples = save!!(samples, sample, i, model, sampler; kwargs...) + + # Increment iteration number. + i += 1 + end # Update the progress bar. - itotal = 1 + num_warmup + discard_initial + itotal = i if progress && itotal >= next_update ProgressLogging.@logprogress itotal / Ntotal next_update = itotal + threshold end # Step through the sampler. - for i in 2:N + while i ≤ N # Discard thinned samples. for _ in 1:(thinning - 1) # Obtain the next sample and state. @@ -187,6 +220,9 @@ function mcmcsample( # Save the sample. samples = save!!(samples, sample, i, model, sampler, N; kwargs...) + # Increment iteration counter. + i += 1 + # Update the progress bar. if progress && (itotal += 1) >= next_update ProgressLogging.@logprogress itotal / Ntotal @@ -222,10 +258,16 @@ function mcmcsample( progress=PROGRESS[], progressname="Convergence sampling", callback=nothing, - discard_initial=0, + num_warmup=0, + discard_initial=num_warmup, thinning=1, kwargs..., ) + # Determine how many samples to drop from `num_warmup` and the + # main sampling process before we start saving samples. + discard_from_warmup = min(num_warmup, discard_initial) + keep_from_warmup = num_warmup - discard_from_warmup + discard_from_sample = max(discard_initial - discard_from_warmup, 0) # Start the timer start = time() @@ -235,21 +277,51 @@ function mcmcsample( # Obtain the initial sample and state. sample, state = step(rng, model, sampler; kwargs...) - # Discard initial samples. - for _ in 1:discard_initial + # Warmup sampling. + for _ = 1:discard_from_warmup # Obtain the next sample and state. - sample, state = step(rng, model, sampler, state; kwargs...) + sample, state = step_warmup(rng, model, sampler, state; kwargs...) end - # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) + i = 1 + if keep_from_warmup > 0 + # Run callback. + callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) - # Save the sample. - samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) - samples = save!!(samples, sample, 1, model, sampler; kwargs...) + # Save the sample. + samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) + samples = save!!(samples, sample, i, model, sampler; kwargs...) + + # Step through remainder of warmup iterations and save. + i += 1 + for _ in (discard_from_warmup + 1):num_warmup + # Obtain the next sample and state. + sample, state = step_warmup(rng, model, sampler, state; kwargs...) - # Step through the sampler until stopping. - i = 2 + # Run callback. + callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + + # Save the sample. + samples = save!!(samples, sample, i, model, sampler; kwargs...) + i += 1 + end + else + # Discard additional initial samples, if needed. + for _ in 1:discard_from_sample + # Obtain the next sample and state. + sample, state = step(rng, model, sampler, state; kwargs...) + end + + # Run callback. + callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + + # Save the sample. + samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) + samples = save!!(samples, sample, i, model, sampler; kwargs...) + + # Increment iteration number. + i += 1 + end while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...) # Discard thinned samples. From 6b842ee4cbc93db7652a1c490db4cb57b551e465 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 07:43:21 +0000 Subject: [PATCH 07/35] added docstring for mcmcsample --- src/sample.jl | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index 4932c4aa..dc3b774d 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -91,7 +91,29 @@ function StatsBase.sample( end # Default implementations of regular and parallel sampling. - +""" + mcmcsample(rng, model, sampler, N_or_is_done; kwargs...) + +Default implementation of `sample` for a `model` and `sampler`. + +# Arguments +- `rng::Random.AbstractRNG`: the random number generator to use. +- `model::AbstractModel`: the model to sample from. +- `sampler::AbstractSampler`: the sampler to use. +- `N::Integer`: the number of samples to draw. + +# Keyword arguments +- `progress`: whether to display a progress bar. Defaults to `true`. +- `progressname`: the name of the progress bar. Defaults to `"Sampling"`. +- `callback`: a function that is called after each [`AbstractMCMC.step`](@ref). + Defaults to `nothing`. +- `num_warmup`: number of warmup samples to draw. Defaults to `0`. +- `discard_initial`: number of initial samples to discard. Defaults to `num_warmup`. +- `thinning`: number of samples to discard between samples. Defaults to `1`. +- `chain_type`: the type to pass to [`AbstractMCMC.bundle_samples`](@ref) at the + end of sampling to wrap up the resulting samples nicely. Defaults to `Any`. +- `kwargs...`: Additional keyword arguments to pass on to [`AbstractMCMC.step`](@ref). +""" function mcmcsample( rng::Random.AbstractRNG, model::AbstractModel, From 04417730b3c32d2d1e3067d32a4a61175986c3ef Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 07:48:35 +0000 Subject: [PATCH 08/35] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/sample.jl | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index dc3b774d..9ad7100b 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -154,7 +154,7 @@ function mcmcsample( sample, state = step(rng, model, sampler; kwargs...) # Warmup sampling. - for _ = 1:discard_from_warmup + for _ in 1:discard_from_warmup # Obtain the next sample and state. sample, state = step_warmup(rng, model, sampler, state; kwargs...) end @@ -162,7 +162,8 @@ function mcmcsample( i = 1 if keep_from_warmup > 0 # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + callback === nothing || + callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) @@ -181,7 +182,8 @@ function mcmcsample( sample, state = step_warmup(rng, model, sampler, state; kwargs...) # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + callback === nothing || + callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. samples = save!!(samples, sample, i, model, sampler; kwargs...) @@ -201,7 +203,8 @@ function mcmcsample( end # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + callback === nothing || + callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) @@ -300,7 +303,7 @@ function mcmcsample( sample, state = step(rng, model, sampler; kwargs...) # Warmup sampling. - for _ = 1:discard_from_warmup + for _ in 1:discard_from_warmup # Obtain the next sample and state. sample, state = step_warmup(rng, model, sampler, state; kwargs...) end @@ -308,7 +311,8 @@ function mcmcsample( i = 1 if keep_from_warmup > 0 # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + callback === nothing || + callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) @@ -321,7 +325,8 @@ function mcmcsample( sample, state = step_warmup(rng, model, sampler, state; kwargs...) # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + callback === nothing || + callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. samples = save!!(samples, sample, i, model, sampler; kwargs...) @@ -335,7 +340,8 @@ function mcmcsample( end # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + callback === nothing || + callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) From ea369fff03cee333595377e2047dd7c9e0f9e877 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 09:03:31 +0000 Subject: [PATCH 09/35] Apply suggestions from code review Co-authored-by: David Widmann --- src/sample.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 9ad7100b..3257d70d 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -122,8 +122,8 @@ function mcmcsample( progress=PROGRESS[], progressname="Sampling", callback=nothing, - num_warmup=0, - discard_initial=num_warmup, + num_warmup::Int=0, + discard_initial::Int=num_warmup, thinning=1, chain_type::Type=Any, kwargs..., From 8e0ca5322725256ffa97f7f1f57d8cbaaab56e08 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 09:07:35 +0000 Subject: [PATCH 10/35] Update src/sample.jl Co-authored-by: David Widmann --- src/sample.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index 3257d70d..078f8641 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -130,7 +130,10 @@ function mcmcsample( ) # Check the number of requested samples. N > 0 || error("the number of samples must be ≥ 1") - Ntotal = thinning * (N - 1) + discard_initial + num_warmup + 1 + discard_initial >= 0 || throw(ArgumentError("number of discarded samples must be non-negative")) + num_warmup >= 0 || throw(ArgumentError("number of warm-up samples must be non-negative")) + Ntotal = thinning * (N - 1) + discard_initial + 1 + Ntotal >= num_warmup || throw(ArgumentError("number of warm-up samples exceeds the total number of samples")) # Determine how many samples to drop from `num_warmup` and the # main sampling process before we start saving samples. From 6877978a276989ca37aabe47a2a0d436ef107b39 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 09:11:25 +0000 Subject: [PATCH 11/35] removed docstring and deferred description of keyword arguments to the docs --- src/sample.jl | 33 ++++++++++----------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index dc3b774d..ea12e816 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -40,6 +40,11 @@ isdone(rng, model, sampler, samples, state, iteration; kwargs...) ``` where `state` and `iteration` are the current state and iteration of the sampler, respectively. It should return `true` when sampling should end, and `false` otherwise. + +# Keyword arguments + +See https://turinglang.org/AbstractMCMC.jl/dev/api/#Common-keyword-arguments for common keyword +arguments. """ function StatsBase.sample( rng::Random.AbstractRNG, @@ -77,6 +82,11 @@ end Sample `nchains` Monte Carlo Markov chains from the `model` with the `sampler` in parallel using the `parallel` algorithm, and combine them into a single chain. + +# Keyword arguments + +See https://turinglang.org/AbstractMCMC.jl/dev/api/#Common-keyword-arguments for common keyword +arguments. """ function StatsBase.sample( rng::Random.AbstractRNG, @@ -91,29 +101,6 @@ function StatsBase.sample( end # Default implementations of regular and parallel sampling. -""" - mcmcsample(rng, model, sampler, N_or_is_done; kwargs...) - -Default implementation of `sample` for a `model` and `sampler`. - -# Arguments -- `rng::Random.AbstractRNG`: the random number generator to use. -- `model::AbstractModel`: the model to sample from. -- `sampler::AbstractSampler`: the sampler to use. -- `N::Integer`: the number of samples to draw. - -# Keyword arguments -- `progress`: whether to display a progress bar. Defaults to `true`. -- `progressname`: the name of the progress bar. Defaults to `"Sampling"`. -- `callback`: a function that is called after each [`AbstractMCMC.step`](@ref). - Defaults to `nothing`. -- `num_warmup`: number of warmup samples to draw. Defaults to `0`. -- `discard_initial`: number of initial samples to discard. Defaults to `num_warmup`. -- `thinning`: number of samples to discard between samples. Defaults to `1`. -- `chain_type`: the type to pass to [`AbstractMCMC.bundle_samples`](@ref) at the - end of sampling to wrap up the resulting samples nicely. Defaults to `Any`. -- `kwargs...`: Additional keyword arguments to pass on to [`AbstractMCMC.step`](@ref). -""" function mcmcsample( rng::Random.AbstractRNG, model::AbstractModel, From ddc5254a28d75bfedf46ae229c21ffc0091173d6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 09:13:37 +0000 Subject: [PATCH 12/35] Update src/sample.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/sample.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index b62bad3d..71948ca0 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -117,8 +117,10 @@ function mcmcsample( ) # Check the number of requested samples. N > 0 || error("the number of samples must be ≥ 1") - discard_initial >= 0 || throw(ArgumentError("number of discarded samples must be non-negative")) - num_warmup >= 0 || throw(ArgumentError("number of warm-up samples must be non-negative")) + discard_initial >= 0 || + throw(ArgumentError("number of discarded samples must be non-negative")) + num_warmup >= 0 || + throw(ArgumentError("number of warm-up samples must be non-negative")) Ntotal = thinning * (N - 1) + discard_initial + 1 Ntotal >= num_warmup || throw(ArgumentError("number of warm-up samples exceeds the total number of samples")) From ffbd32fce8c7103b6634baa233dd447dcdd64bbb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 09:13:45 +0000 Subject: [PATCH 13/35] Update src/sample.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/sample.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index 71948ca0..35ddc27a 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -122,7 +122,9 @@ function mcmcsample( num_warmup >= 0 || throw(ArgumentError("number of warm-up samples must be non-negative")) Ntotal = thinning * (N - 1) + discard_initial + 1 - Ntotal >= num_warmup || throw(ArgumentError("number of warm-up samples exceeds the total number of samples")) + Ntotal >= num_warmup || throw( + ArgumentError("number of warm-up samples exceeds the total number of samples") + ) # Determine how many samples to drop from `num_warmup` and the # main sampling process before we start saving samples. From 87480ffcb6f6d1a2d79238605e4f593d78c90af1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 09:18:05 +0000 Subject: [PATCH 14/35] added num_warmup to common keyword arguments docs --- docs/src/api.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 52c2c2e1..21b7c2ff 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -73,7 +73,11 @@ Common keyword arguments for regular and parallel sampling are: - `callback` (default: `nothing`): if `callback !== nothing`, then `callback(rng, model, sampler, sample, iteration)` is called after every sampling step, where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration -- `discard_initial` (default: `0`): number of initial samples that are discarded +- `num_warmup` (default: `0`): number of "warm-up" steps to take before the first "regular" step, + i.e. number of times to call [`AbstractMCMC.step_warmup`](@ref) before the first call to + [`AbstractMCMC.step`](@ref). +- `discard_initial` (default: `num_warmup`): number of initial samples that are discarded. Note that + if `discard_initial < num_warmup`, warm-up samples will also be included in the resulting samples. - `thinning` (default: `1`): factor by which to thin samples. !!! info From 76f2f2322b08d52593cc261cfbd40fcb43cb7fe1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 09:25:04 +0000 Subject: [PATCH 15/35] also allow step_warmup for the initial step --- src/interface.jl | 5 ++--- src/sample.jl | 12 ++++++++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 08b5241e..5d121f67 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -68,9 +68,8 @@ standard iteration. By default, this simply calls [`step`](@ref.) """ -function step_warmup(rng, model, sampler, state; kwargs...) - return step(rng, model, sampler, state; kwargs...) -end +step_warmup(rng, model, sampler; kwargs...) = step(rng, model, sampler; kwargs...) +step_warmup(rng, model, sampler, state; kwargs...) = step(rng, model, sampler, state; kwargs...) """ samples(sample, model, sampler[, N; kwargs...]) diff --git a/src/sample.jl b/src/sample.jl index b62bad3d..aecc3212 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -141,7 +141,11 @@ function mcmcsample( end # Obtain the initial sample and state. - sample, state = step(rng, model, sampler; kwargs...) + sample, state = if num_warmup > 0 + step_warmup(rng, model, sampler; kwargs...) + else + step(rng, model, sampler; kwargs...) + end # Warmup sampling. for _ in 1:discard_from_warmup @@ -290,7 +294,11 @@ function mcmcsample( @ifwithprogresslogger progress name = progressname begin # Obtain the initial sample and state. - sample, state = step(rng, model, sampler; kwargs...) + sample, state = if num_warmup > 0 + step_warmup(rng, model, sampler; kwargs...) + else + step(rng, model, sampler; kwargs...) + end # Warmup sampling. for _ in 1:discard_from_warmup From ef09c192239ff2aae9520e967f59d0ee712765dd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 10:09:17 +0000 Subject: [PATCH 16/35] simplify logic for discarding fffinitial samples --- src/sample.jl | 112 +++++++++++++++++--------------------------------- 1 file changed, 38 insertions(+), 74 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 5a35799e..115b0020 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -130,7 +130,6 @@ function mcmcsample( # main sampling process before we start saving samples. discard_from_warmup = min(num_warmup, discard_initial) keep_from_warmup = num_warmup - discard_from_warmup - discard_from_sample = max(discard_initial - discard_from_warmup, 0) # Start the timer start = time() @@ -152,63 +151,43 @@ function mcmcsample( end # Warmup sampling. - for _ in 1:discard_from_warmup + for j in 1:discard_initial # Obtain the next sample and state. - sample, state = step_warmup(rng, model, sampler, state; kwargs...) + sample, state = if j ≤ num_warmup + step_warmup(rng, model, sampler, state; kwargs...) + else + step(rng, model, sampler, state; kwargs...) + end end + # Initialize iteration counter. i = 1 - if keep_from_warmup > 0 - # Run callback. - callback === nothing || - callback(rng, model, sampler, sample, state, i; kwargs...) - # Save the sample. - samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) - samples = save!!(samples, sample, i, model, sampler; kwargs...) + # Run callback. + callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) - # Step through remainder of warmup iterations and save. - i += 1 - for _ in (discard_from_warmup + 1):num_warmup - # Update the progress bar. - if progress && i >= next_update - ProgressLogging.@logprogress i / Ntotal - next_update = i + threshold - end + # Save the sample. + samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) + samples = save!!(samples, sample, i, model, sampler; kwargs...) - # Obtain the next sample and state. - sample, state = step_warmup(rng, model, sampler, state; kwargs...) - - # Run callback. - callback === nothing || - callback(rng, model, sampler, sample, state, i; kwargs...) - - # Save the sample. - samples = save!!(samples, sample, i, model, sampler; kwargs...) - i += 1 + # Step through remainder of warmup iterations and save. + i += 1 + for _ in 1:keep_from_warmup + # Update the progress bar. + if progress && i >= next_update + ProgressLogging.@logprogress i / Ntotal + next_update = i + threshold end - else - # Discard additional initial samples, if needed. - for _ in 1:discard_from_sample - # Update the progress bar. - if progress && i >= next_update - ProgressLogging.@logprogress i / Ntotal - next_update = i + threshold - end - # Obtain the next sample and state. - sample, state = step(rng, model, sampler, state; kwargs...) - end + # Obtain the next sample and state. + sample, state = step_warmup(rng, model, sampler, state; kwargs...) # Run callback. callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. - samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) samples = save!!(samples, sample, i, model, sampler; kwargs...) - - # Increment iteration number. i += 1 end @@ -290,7 +269,6 @@ function mcmcsample( # main sampling process before we start saving samples. discard_from_warmup = min(num_warmup, discard_initial) keep_from_warmup = num_warmup - discard_from_warmup - discard_from_sample = max(discard_initial - discard_from_warmup, 0) # Start the timer start = time() @@ -305,51 +283,37 @@ function mcmcsample( end # Warmup sampling. - for _ in 1:discard_from_warmup + for j in 1:discard_initial # Obtain the next sample and state. - sample, state = step_warmup(rng, model, sampler, state; kwargs...) + sample, state = if j ≤ num_warmup + step_warmup(rng, model, sampler, state; kwargs...) + else + step(rng, model, sampler, state; kwargs...) + end end + # Initialize iteration counter. i = 1 - if keep_from_warmup > 0 - # Run callback. - callback === nothing || - callback(rng, model, sampler, sample, state, i; kwargs...) - - # Save the sample. - samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) - samples = save!!(samples, sample, i, model, sampler; kwargs...) - # Step through remainder of warmup iterations and save. - i += 1 - for _ in (discard_from_warmup + 1):num_warmup - # Obtain the next sample and state. - sample, state = step_warmup(rng, model, sampler, state; kwargs...) + # Run callback. + callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) - # Run callback. - callback === nothing || - callback(rng, model, sampler, sample, state, i; kwargs...) + # Save the sample. + samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) + samples = save!!(samples, sample, i, model, sampler; kwargs...) - # Save the sample. - samples = save!!(samples, sample, i, model, sampler; kwargs...) - i += 1 - end - else - # Discard additional initial samples, if needed. - for _ in 1:discard_from_sample - # Obtain the next sample and state. - sample, state = step(rng, model, sampler, state; kwargs...) - end + # Step through remainder of warmup iterations and save. + i += 1 + for _ in 1:keep_from_warmup + # Obtain the next sample and state. + sample, state = step_warmup(rng, model, sampler, state; kwargs...) # Run callback. callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. - samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) samples = save!!(samples, sample, i, model, sampler; kwargs...) - - # Increment iteration number. i += 1 end From 49b8406115fcd7473bfb6198de3a76c07785f692 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 11:57:03 +0000 Subject: [PATCH 17/35] Apply suggestions from code review Co-authored-by: David Widmann --- src/sample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index 115b0020..22c5a911 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -150,7 +150,7 @@ function mcmcsample( step(rng, model, sampler; kwargs...) end - # Warmup sampling. + # Discard initial samples. for j in 1:discard_initial # Obtain the next sample and state. sample, state = if j ≤ num_warmup From f005746ec160070665598e45daf1d8b495482d93 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 12:06:15 +0000 Subject: [PATCH 18/35] also report progress for the discarded samples --- src/sample.jl | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 115b0020..24a13763 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -150,8 +150,21 @@ function mcmcsample( step(rng, model, sampler; kwargs...) end + # Update the progress bar. + itotal = 1 + if progress && itotal >= next_update + ProgressLogging.@logprogress itotal / Ntotal + next_update = itotal + threshold + end + # Warmup sampling. for j in 1:discard_initial + # Update the progress bar. + if progress && (itotal += 1) >= next_update + ProgressLogging.@logprogress itotal / Ntotal + next_update = itotal + threshold + end + # Obtain the next sample and state. sample, state = if j ≤ num_warmup step_warmup(rng, model, sampler, state; kwargs...) @@ -173,12 +186,6 @@ function mcmcsample( # Step through remainder of warmup iterations and save. i += 1 for _ in 1:keep_from_warmup - # Update the progress bar. - if progress && i >= next_update - ProgressLogging.@logprogress i / Ntotal - next_update = i + threshold - end - # Obtain the next sample and state. sample, state = step_warmup(rng, model, sampler, state; kwargs...) @@ -189,13 +196,12 @@ function mcmcsample( # Save the sample. samples = save!!(samples, sample, i, model, sampler; kwargs...) i += 1 - end - # Update the progress bar. - itotal = i - if progress && itotal >= next_update - ProgressLogging.@logprogress itotal / Ntotal - next_update = itotal + threshold + # Update progress bar. + if progress && (itotal += 1) >= next_update + ProgressLogging.@logprogress itotal / Ntotal + next_update = itotal + threshold + end end # Step through the sampler. From ff00e6e1e41cb5f0679b63d8f4a798922bce8dbe Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 13:48:02 +0000 Subject: [PATCH 19/35] Apply suggestions from code review Co-authored-by: David Widmann --- docs/src/design.md | 2 +- src/interface.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/src/design.md b/docs/src/design.md index cc2651ac..f5becb45 100644 --- a/docs/src/design.md +++ b/docs/src/design.md @@ -69,7 +69,7 @@ If one also has some special handling of the warmup-stage of sampling, then this AbstractMCMC.step_warmup ``` -which will be used for the first `num_warmup`, as specified as a keyword argument to [`AbstractMCMC.sample`](@ref). +which will be used for the first `num_warmup` iterations, as specified as a keyword argument to [`AbstractMCMC.sample`](@ref). Note that this is optional; by default it simply calls [`AbstractMCMC.step`](@ref) from above. ## Collecting samples diff --git a/src/interface.jl b/src/interface.jl index 5d121f67..d790b0b0 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -61,12 +61,12 @@ function step end Return a 2-tuple of the next sample and the next state of the MCMC `sampler` for `model`. -When sampling using [`sample`](@ref), this takes the place of [`step`](@ref) in the first +When sampling using [`sample`](@ref), this takes the place of [`AbstractMCMC.step`](@ref) in the first `num_warmup` number of iterations, as specified by the `num_warmup` keyword to [`sample`](@ref). -This is useful if the sampler has a "warmup"-stage initial stage that is different from the +This is useful if the sampler has an initial "warmup"-stage that is different from the standard iteration. -By default, this simply calls [`step`](@ref.) +By default, this simply calls [`AbstractMCMC.step`](@ref). """ step_warmup(rng, model, sampler; kwargs...) = step(rng, model, sampler; kwargs...) step_warmup(rng, model, sampler, state; kwargs...) = step(rng, model, sampler, state; kwargs...) From 7ce9f6b9a72a84493938e55be27d54aaafcb9060 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Mar 2023 13:54:48 +0000 Subject: [PATCH 20/35] move progress-report to end of for-loop for discard samples --- src/sample.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 9fc4643c..16b2f865 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -159,18 +159,18 @@ function mcmcsample( # Discard initial samples. for j in 1:discard_initial - # Update the progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal - next_update = itotal + threshold - end - # Obtain the next sample and state. sample, state = if j ≤ num_warmup step_warmup(rng, model, sampler, state; kwargs...) else step(rng, model, sampler, state; kwargs...) end + + # Update the progress bar. + if progress && (itotal += 1) >= next_update + ProgressLogging.@logprogress itotal / Ntotal + next_update = itotal + threshold + end end # Initialize iteration counter. From 3a217b2525bfe303d537406aac760c605af33e53 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 13 Mar 2023 22:04:43 +0000 Subject: [PATCH 21/35] move step_warmup to the inner while loops too --- src/sample.jl | 58 +++++++++++++++++++-------------------------------- 1 file changed, 21 insertions(+), 37 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 16b2f865..e712551d 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -185,31 +185,17 @@ function mcmcsample( # Step through remainder of warmup iterations and save. i += 1 - for _ in 1:keep_from_warmup - # Obtain the next sample and state. - sample, state = step_warmup(rng, model, sampler, state; kwargs...) - - # Run callback. - callback === nothing || - callback(rng, model, sampler, sample, state, i; kwargs...) - - # Save the sample. - samples = save!!(samples, sample, i, model, sampler; kwargs...) - i += 1 - - # Update progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal - next_update = itotal + threshold - end - end # Step through the sampler. while i ≤ N # Discard thinned samples. for _ in 1:(thinning - 1) # Obtain the next sample and state. - sample, state = step(rng, model, sampler, state; kwargs...) + sample, state = if i ≤ keep_from_warmup + step_warmup(rng, model, sampler, state; kwargs...) + else + step(rng, model, sampler, state; kwargs...) + end # Update progress bar. if progress && (itotal += 1) >= next_update @@ -219,7 +205,12 @@ function mcmcsample( end # Obtain the next sample and state. - sample, state = step(rng, model, sampler, state; kwargs...) + sample, state = if i ≤ keep_from_warmup + step_warmup(rng, model, sampler, state; kwargs...) + else + step(rng, model, sampler, state; kwargs...) + end + # Run callback. callback === nothing || @@ -308,30 +299,23 @@ function mcmcsample( samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) samples = save!!(samples, sample, i, model, sampler; kwargs...) - # Step through remainder of warmup iterations and save. - i += 1 - for _ in 1:keep_from_warmup - # Obtain the next sample and state. - sample, state = step_warmup(rng, model, sampler, state; kwargs...) - - # Run callback. - callback === nothing || - callback(rng, model, sampler, sample, state, i; kwargs...) - - # Save the sample. - samples = save!!(samples, sample, i, model, sampler; kwargs...) - i += 1 - end - while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...) # Discard thinned samples. for _ in 1:(thinning - 1) # Obtain the next sample and state. - sample, state = step(rng, model, sampler, state; kwargs...) + sample, state = if i ≤ keep_from_warmup + step_warmup(rng, model, sampler, state; kwargs...) + else + step(rng, model, sampler, state; kwargs...) + end end # Obtain the next sample and state. - sample, state = step(rng, model, sampler, state; kwargs...) + sample, state = if i ≤ keep_from_warmup + step_warmup(rng, model, sampler, state; kwargs...) + else + step(rng, model, sampler, state; kwargs...) + end # Run callback. callback === nothing || From de9bb2cb6c6b03460373d17517e938f42e257642 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 13 Mar 2023 22:19:59 +0000 Subject: [PATCH 22/35] Update src/sample.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/sample.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index e712551d..19ef2eca 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -211,7 +211,6 @@ function mcmcsample( step(rng, model, sampler, state; kwargs...) end - # Run callback. callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) From 85d938fbd2f0fea208640545a309b1a0e093ed9e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 19 Apr 2023 08:56:23 +0100 Subject: [PATCH 23/35] Apply suggestions from code review Co-authored-by: David Widmann --- src/sample.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 19ef2eca..7b47fb77 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -180,8 +180,8 @@ function mcmcsample( callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. - samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) - samples = save!!(samples, sample, i, model, sampler; kwargs...) + samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...) + samples = save!!(samples, sample, i, model, sampler, N; kwargs...) # Step through remainder of warmup iterations and save. i += 1 @@ -278,10 +278,10 @@ function mcmcsample( step(rng, model, sampler; kwargs...) end - # Warmup sampling. + # Discard initial samples. for j in 1:discard_initial # Obtain the next sample and state. - sample, state = if j ≤ num_warmup + sample, state = if j ≤ discard_num_warmup step_warmup(rng, model, sampler, state; kwargs...) else step(rng, model, sampler, state; kwargs...) @@ -297,6 +297,8 @@ function mcmcsample( # Save the sample. samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) samples = save!!(samples, sample, i, model, sampler; kwargs...) + + i += 1 while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...) # Discard thinned samples. From 0a667a49014c37153371e20869952671d86a8ca9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 19 Apr 2023 08:57:45 +0100 Subject: [PATCH 24/35] reverted to for-loop --- src/sample.jl | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 7b47fb77..28bc6c6e 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -173,21 +173,15 @@ function mcmcsample( end end - # Initialize iteration counter. - i = 1 - # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) # Save the sample. samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...) - samples = save!!(samples, sample, i, model, sampler, N; kwargs...) - - # Step through remainder of warmup iterations and save. - i += 1 + samples = save!!(samples, sample, 1, model, sampler, N; kwargs...) # Step through the sampler. - while i ≤ N + for i = 2:N # Discard thinned samples. for _ in 1:(thinning - 1) # Obtain the next sample and state. @@ -288,18 +282,14 @@ function mcmcsample( end end - # Initialize iteration counter. - i = 1 - # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) # Save the sample. samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) - samples = save!!(samples, sample, i, model, sampler; kwargs...) + samples = save!!(samples, sample, 1, model, sampler; kwargs...) - i += 1 - + i = 2 while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...) # Discard thinned samples. for _ in 1:(thinning - 1) From 91f5a10396599a619cfcae661252f41f4ea8ac7e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 19 Apr 2023 08:58:46 +0100 Subject: [PATCH 25/35] Update src/sample.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/sample.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index 28bc6c6e..df2a8bea 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -288,7 +288,6 @@ function mcmcsample( # Save the sample. samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) samples = save!!(samples, sample, 1, model, sampler; kwargs...) - i = 2 while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...) # Discard thinned samples. From 7603171c535673b3c38a2ed9eca1277e5541ca05 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 19 Apr 2023 09:07:41 +0100 Subject: [PATCH 26/35] added accidentanly removed comment --- src/sample.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sample.jl b/src/sample.jl index df2a8bea..ca9409c9 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -288,6 +288,8 @@ function mcmcsample( # Save the sample. samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) samples = save!!(samples, sample, 1, model, sampler; kwargs...) + + # Step through the sampler until stopping. i = 2 while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...) # Discard thinned samples. From ef68d04350f8df6d0da8f616720e5675cac5cdf9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 19 Apr 2023 09:58:22 +0100 Subject: [PATCH 27/35] Update src/sample.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/sample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index ca9409c9..68aae8f6 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -181,7 +181,7 @@ function mcmcsample( samples = save!!(samples, sample, 1, model, sampler, N; kwargs...) # Step through the sampler. - for i = 2:N + for i in 2:N # Discard thinned samples. for _ in 1:(thinning - 1) # Obtain the next sample and state. From 0ea293a03f2ba4222f034426cde2ab92df2df0ce Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 26 Oct 2023 08:27:35 +0100 Subject: [PATCH 28/35] fixed formatting --- src/interface.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index c42bb08b..b58ced99 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -86,7 +86,9 @@ standard iteration. By default, this simply calls [`AbstractMCMC.step`](@ref). """ step_warmup(rng, model, sampler; kwargs...) = step(rng, model, sampler; kwargs...) -step_warmup(rng, model, sampler, state; kwargs...) = step(rng, model, sampler, state; kwargs...) +function step_warmup(rng, model, sampler, state; kwargs...) + return step(rng, model, sampler, state; kwargs...) +end """ samples(sample, model, sampler[, N; kwargs...]) From 6e8f88e70303b1c3c5d163fe02c3bf33acebc3f5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 26 Oct 2023 08:39:27 +0100 Subject: [PATCH 29/35] fix typo --- src/sample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index b3af2102..8e99ae7e 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -293,7 +293,7 @@ function mcmcsample( # Discard initial samples. for j in 1:discard_initial # Obtain the next sample and state. - sample, state = if j ≤ discard_num_warmup + sample, state = if j ≤ discard_from_warmup step_warmup(rng, model, sampler, state; kwargs...) else step(rng, model, sampler, state; kwargs...) From 3b4f6dbe91ecf0eed30834287d5058a39560a890 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 4 Oct 2024 11:50:22 +0100 Subject: [PATCH 30/35] Apply suggestions from code review Co-authored-by: David Widmann --- src/sample.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index e59f0e32..7227029f 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -224,9 +224,6 @@ function mcmcsample( # Save the sample. samples = save!!(samples, sample, i, model, sampler, N; kwargs...) - # Increment iteration counter. - i += 1 - # Update the progress bar. if progress && (itotal += 1) >= next_update ProgressLogging.@logprogress itotal / Ntotal @@ -296,7 +293,7 @@ function mcmcsample( # Discard initial samples. for j in 1:discard_initial # Obtain the next sample and state. - sample, state = if j ≤ discard_from_warmup + sample, state = if j ≤ num_warmup step_warmup(rng, model, sampler, state; kwargs...) else step(rng, model, sampler, state; kwargs...) From f9142a6a205144b79adba27e8b7ca258157d5b4f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 4 Oct 2024 12:07:49 +0100 Subject: [PATCH 31/35] Added testing of warmup steps --- test/sample.jl | 39 +++++++++++++++++++++++++++++++++++++++ test/utils.jl | 18 ++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/test/sample.jl b/test/sample.jl index dcc87526..7599bd79 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -575,6 +575,45 @@ @test all(chain[i].b == ref_chain[i + discard_initial].b for i in 1:N) end + @testset "Warm-up steps" begin + # Create a chain and discard initial samples. + Random.seed!(1234) + N = 100 + num_warmup = 50 + + # Everything should be discarded here. + chain = sample(MyModel(), MySampler(), N; num_warmup=num_warmup) + @test length(chain) == N + @test !ismissing(chain[1].a) + + # Repeat sampling without discarding initial samples. + # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here. + # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258 + Random.seed!(1234) + ref_chain = sample( + MyModel(), MySampler(), N + num_warmup; progress=VERSION < v"1.6" + ) + @test all(chain[i].a == ref_chain[i + num_warmup].a for i in 1:N) + @test all(chain[i].b == ref_chain[i + num_warmup].b for i in 1:N) + + # Some other stuff. + Random.seed!(1234) + discard_initial = 10 + chain_warmup = sample( + MyModel(), + MySampler(), + N; + num_warmup=num_warmup, + discard_initial=discard_initial, + ) + @test length(chain_warmup) == N + @test all(chain_warmup[i].a == ref_chain[i + discard_initial].a for i in 1:N) + # Check that the first `num_warmup - discard_initial` samples are warmup samples. + @test all( + chain_warmup[i].is_warmup == (i <= num_warmup - discard_initial) for i in 1:N + ) + end + @testset "Thin chain by a factor of `thinning`" begin # Run a thinned chain with `N` samples thinned by factor of `thinning`. Random.seed!(100) diff --git a/test/utils.jl b/test/utils.jl index 1e29a473..b041b3a7 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -3,8 +3,11 @@ struct MyModel <: AbstractMCMC.AbstractModel end struct MySample{A,B} a::A b::B + is_warmup::Bool end +MySample(a, b) = MySample(a, b, false) + struct MySampler <: AbstractMCMC.AbstractSampler end struct AnotherSampler <: AbstractMCMC.AbstractSampler end @@ -16,6 +19,21 @@ end MyChain(a, b) = MyChain(a, b, NamedTuple()) +function AbstractMCMC.step_warmup( + rng::AbstractRNG, + model::MyModel, + sampler::MySampler, + state::Union{Nothing,Integer}=nothing; + loggers=false, + initial_params=nothing, + kwargs..., +) + transition, state = AbstractMCMC.step( + rng, model, sampler, state; loggers, initial_params, kwargs... + ) + return MySample(transition.a, transition.b, true), state +end + function AbstractMCMC.step( rng::AbstractRNG, model::MyModel, From 295fdc1b6a2a1580e9e5b0dfb14df4b6ffce0418 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 4 Oct 2024 12:10:41 +0100 Subject: [PATCH 32/35] Added checks as @devmotion requested --- src/sample.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/sample.jl b/src/sample.jl index 7227029f..e339ce37 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -265,6 +265,13 @@ function mcmcsample( initial_state=nothing, kwargs..., ) + # Check the number of requested samples. + N > 0 || error("the number of samples must be ≥ 1") + discard_initial >= 0 || + throw(ArgumentError("number of discarded samples must be non-negative")) + num_warmup >= 0 || + throw(ArgumentError("number of warm-up samples must be non-negative")) + # Determine how many samples to drop from `num_warmup` and the # main sampling process before we start saving samples. discard_from_warmup = min(num_warmup, discard_initial) From e6acb1f24b9890660fba4672a7f28abfce0721c4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 4 Oct 2024 12:11:59 +0100 Subject: [PATCH 33/35] Removed unintended change in previous commit --- src/sample.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index e339ce37..6e21f180 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -266,7 +266,6 @@ function mcmcsample( kwargs..., ) # Check the number of requested samples. - N > 0 || error("the number of samples must be ≥ 1") discard_initial >= 0 || throw(ArgumentError("number of discarded samples must be non-negative")) num_warmup >= 0 || From 2e9fa5cc3239e9f20586c3609dc433a3d0b0d690 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 4 Oct 2024 12:12:25 +0100 Subject: [PATCH 34/35] Bumped patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 60215cec..c57082ec 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.3.0" +version = "5.3.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From 366fceb0adc9081e15c7ad27f3bb19b9523e1ce6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 4 Oct 2024 12:12:50 +0100 Subject: [PATCH 35/35] Bump minor version instead of patch version since this is a new feature --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c57082ec..f57b1ff0 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.3.1" +version = "5.4.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"