From 439e4709d6e897c4b9e0541c2336fe2fcf5d111f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 20 Jan 2022 23:23:20 +0100 Subject: [PATCH 01/54] Fix signature of `isdone` in docstring (#91) * Fix signature of `isdone` in docstring * Update sample.jl --- src/sample.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index df76caf0..9d48df39 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -44,9 +44,10 @@ convergence criterion `isdone` returns `true`, and return the samples. The function `isdone` has the signature ```julia -isdone(rng, model, sampler, samples, iteration; kwargs...) +isdone(rng, model, sampler, samples, state, iteration; kwargs...) ``` -and should return `true` when sampling should end, and `false` otherwise. +where `state` and `iteration` are the current state and iteration of the sampler, respectively. +It should return `true` when sampling should end, and `false` otherwise. """ function StatsBase.sample( rng::Random.AbstractRNG, From fe972e8ee2b091070866e113c4aca6f849653027 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 21 Jan 2022 00:04:01 +0100 Subject: [PATCH 02/54] Fix test error (#93) * Fix test error * Fix another tolerance --- test/stepper.jl | 2 +- test/transducer.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/stepper.jl b/test/stepper.jl index f3a4b599..bc75d637 100644 --- a/test/stepper.jl +++ b/test/stepper.jl @@ -21,7 +21,7 @@ @test length(as) == length(bs) == 998 - @test mean(as) ≈ 0.5 atol=1e-2 + @test mean(as) ≈ 0.5 atol=2e-2 @test var(as) ≈ 1 / 12 atol=5e-3 @test mean(bs) ≈ 0.0 atol=5e-2 @test var(bs) ≈ 1 atol=5e-2 diff --git a/test/transducer.jl b/test/transducer.jl index 2b363e27..910f9d70 100644 --- a/test/transducer.jl +++ b/test/transducer.jl @@ -45,7 +45,7 @@ @test length(as) == length(bs) == 998 - @test mean(as) ≈ 0.5 atol=1e-2 + @test mean(as) ≈ 0.5 atol=2e-2 @test var(as) ≈ 1 / 12 atol=5e-3 @test mean(bs) ≈ 0.0 atol=5e-2 @test var(bs) ≈ 1 atol=5e-2 From 7aa0e8c194a1a42f330544e877f78798a01ffe60 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 19 Feb 2022 05:09:18 +0100 Subject: [PATCH 03/54] Remove use of `threadid` --- src/sample.jl | 43 ++++++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 9d48df39..aefb67b4 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -298,16 +298,15 @@ function mcmcsample( end # Copy the random number generator, model, and sample for each thread - # NOTE: As of May 17, 2020, this relies on Julia's thread scheduling functionality - # that distributes a for loop into equal-sized blocks and allocates them - # to each thread. If this changes, we may need to rethink things here. + nchunks = min(nchains, Threads.nthreads()) + chunksize = cld(nchains, nchunks) interval = 1:min(nchains, Threads.nthreads()) rngs = [deepcopy(rng) for _ in interval] models = [deepcopy(model) for _ in interval] samplers = [deepcopy(sampler) for _ in interval] - # Create a seed for each chain using the provided random number generator. - seeds = rand(rng, UInt, nchains) + # Create a seed for each chunk using the provided random number generator. + seeds = rand(rng, UInt, nchunks) # Set up a chains vector. chains = Vector{Any}(undef, nchains) @@ -340,20 +339,26 @@ function mcmcsample( Distributed.@async begin try - Threads.@threads for i in 1:nchains - # Obtain the ID of the current thread. - id = Threads.threadid() - - # Seed the thread-specific random number generator with the pre-made seed. - subrng = rngs[id] - Random.seed!(subrng, seeds[i]) - - # Sample a chain and save it to the vector. - chains[i] = StatsBase.sample(subrng, models[id], samplers[id], N; - progress = false, kwargs...) - - # Update the progress bar. - progress && put!(channel, true) + for (i, _rng, seed, _model, _sampler) in zip(1:nchunks, rngs, seeds, models, samplers) + Threads.@spawn begin + chainidxs = if i == nchunks + ((i - 1) * chunksize + 1):nchains + else + ((i - 1) * chunksize + 1):(i * chunksize) + end + + # Seed the chunk-specific random number generator with the pre-made seed. + Random.seed!(_rng, seed) + + for chainidx in chainidxs + # Sample a chain and save it to the vector. + chains[i] = StatsBase.sample(_rng, _model, _sampler, N; + progress = false, kwargs...) + + # Update the progress bar. + progress && put!(channel, true) + end + end end finally # Stop updating the progress bar. From f300c711ae5cd5f8415530155260232dcab10042 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 19 Feb 2022 05:11:21 +0100 Subject: [PATCH 04/54] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7ccac490..0299dad9 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "3.2.1" +version = "3.2.2" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From 459d8f97111e19fe7939c87b31eafc224c2141e4 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 19 Feb 2022 05:21:15 +0100 Subject: [PATCH 05/54] Fix issue --- src/sample.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index aefb67b4..fb075106 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -341,19 +341,19 @@ function mcmcsample( try for (i, _rng, seed, _model, _sampler) in zip(1:nchunks, rngs, seeds, models, samplers) Threads.@spawn begin + # Seed the chunk-specific random number generator with the pre-made seed. + Random.seed!(_rng, seed) + chainidxs = if i == nchunks ((i - 1) * chunksize + 1):nchains else ((i - 1) * chunksize + 1):(i * chunksize) end - # Seed the chunk-specific random number generator with the pre-made seed. - Random.seed!(_rng, seed) - for chainidx in chainidxs # Sample a chain and save it to the vector. - chains[i] = StatsBase.sample(_rng, _model, _sampler, N; - progress = false, kwargs...) + chains[chainidx] = StatsBase.sample(_rng, _model, _sampler, N; + progress = false, kwargs...) # Update the progress bar. progress && put!(channel, true) From 9d89f77834e5f31e18b10b79d9a74b01ec820565 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 19 Feb 2022 05:49:09 +0100 Subject: [PATCH 06/54] Drop support for Julia < 1.3 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 0299dad9..7884c6f7 100644 --- a/Project.toml +++ b/Project.toml @@ -25,7 +25,7 @@ ProgressLogging = "0.1" StatsBase = "0.32, 0.33" TerminalLoggers = "0.1" Transducers = "0.4.30" -julia = "1" +julia = "1.3" [extras] Atom = "c52e3926-4ff0-5f6e-af25-54175e0327b1" From b53c961d79ca0a28e6b435d69bb0ebf75e1108d9 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 19 Feb 2022 05:51:16 +0100 Subject: [PATCH 07/54] Update CI --- .github/workflows/CI.yml | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 37cc617e..069ef82b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: version: - - '1.0' + - '1.3' - '1' - nightly os: @@ -31,7 +31,7 @@ jobs: arch: x86 - os: macOS-latest arch: x86 - - version: '1.0' + - version: '1.3' num_threads: 2 include: - version: '1' @@ -45,16 +45,7 @@ jobs: with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v1 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- + - uses: julia-actions/cache@v1 - uses: julia-actions/julia-buildpkg@latest - uses: julia-actions/julia-runtest@latest env: From bb7ced2bd6ffcef34d0440c5688d3dd5c6dbe97d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 19 Feb 2022 09:03:55 +0100 Subject: [PATCH 08/54] Update sample.jl --- src/sample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index fb075106..e783760f 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -339,7 +339,7 @@ function mcmcsample( Distributed.@async begin try - for (i, _rng, seed, _model, _sampler) in zip(1:nchunks, rngs, seeds, models, samplers) + Distributed.@sync for (i, _rng, seed, _model, _sampler) in zip(1:nchunks, rngs, seeds, models, samplers) Threads.@spawn begin # Seed the chunk-specific random number generator with the pre-made seed. Random.seed!(_rng, seed) From 6a7370621772453232f2024c2bf271e1fdafceb9 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 21 Feb 2022 18:34:32 +0100 Subject: [PATCH 09/54] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7884c6f7..794deb24 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "3.2.2" +version = "3.3.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From 05f91cf88a641c73ec9555fb780be5c2453156e1 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 21 Feb 2022 22:33:52 +0100 Subject: [PATCH 10/54] Fix reproducibility of ensemble sampling --- src/sample.jl | 53 +++++++++++++++++++++++++++------------------------ 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index e783760f..e1ff1be7 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -305,8 +305,8 @@ function mcmcsample( models = [deepcopy(model) for _ in interval] samplers = [deepcopy(sampler) for _ in interval] - # Create a seed for each chunk using the provided random number generator. - seeds = rand(rng, UInt, nchunks) + # Create a seed for each chain using the provided random number generator. + seeds = rand(rng, UInt, nchains) # Set up a chains vector. chains = Vector{Any}(undef, nchains) @@ -339,25 +339,22 @@ function mcmcsample( Distributed.@async begin try - Distributed.@sync for (i, _rng, seed, _model, _sampler) in zip(1:nchunks, rngs, seeds, models, samplers) - Threads.@spawn begin + Distributed.@sync for (i, _rng, _model, _sampler) in zip(1:nchunks, rngs, models, samplers) + chainidxs = if i == nchunks + ((i - 1) * chunksize + 1):nchains + else + ((i - 1) * chunksize + 1):(i * chunksize) + end + Threads.@spawn for chainidx in chainidxs # Seed the chunk-specific random number generator with the pre-made seed. - Random.seed!(_rng, seed) - - chainidxs = if i == nchunks - ((i - 1) * chunksize + 1):nchains - else - ((i - 1) * chunksize + 1):(i * chunksize) - end - - for chainidx in chainidxs - # Sample a chain and save it to the vector. - chains[chainidx] = StatsBase.sample(_rng, _model, _sampler, N; - progress = false, kwargs...) - - # Update the progress bar. - progress && put!(channel, true) - end + Random.seed!(_rng, seeds[chainidx]) + + # Sample a chain and save it to the vector. + chains[chainidx] = StatsBase.sample(_rng, _model, _sampler, N; + progress = false, kwargs...) + + # Update the progress bar. + progress && put!(channel, true) end end finally @@ -469,12 +466,18 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end + # Create a seed for each chain using the provided random number generator. + seeds = rand(rng, UInt, nchains) + # Sample the chains. - chains = map( - i -> StatsBase.sample(rng, model, sampler, N; progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"), - kwargs...), - 1:nchains - ) + chains = map(enumerate(seeds)) do (i, seed) + Random.seed!(rng, seed) + return StatsBase.sample( + rng, model, sampler, N; + progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"), + kwargs..., + ) + end # Concatenate the chains together. return chainsstack(tighten_eltype(chains)) From dca3aa09a4171c7460966747628c694bf8e08c64 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 21 Feb 2022 22:34:35 +0100 Subject: [PATCH 11/54] Add tests --- test/sample.jl | 120 ++++++++++++++++++++++++++++++------------------- 1 file changed, 74 insertions(+), 46 deletions(-) diff --git a/test/sample.jl b/test/sample.jl index 6e876d48..d8a04712 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -103,61 +103,59 @@ end end - if VERSION ≥ v"1.3" - @testset "Multithreaded sampling" begin - if Threads.nthreads() == 1 - warnregex = r"^Only a single thread available" - @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(), - 10, 10) - end + @testset "Multithreaded sampling" begin + if Threads.nthreads() == 1 + warnregex = r"^Only a single thread available" + @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(), + 10, 10) + end - # No dedicated chains type - N = 10_000 - chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000) - @test chains isa Vector{<:Vector{<:MySample}} - @test length(chains) == 1000 - @test all(length(x) == N for x in chains) + # No dedicated chains type + N = 10_000 + chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000) + @test chains isa Vector{<:Vector{<:MySample}} + @test length(chains) == 1000 + @test all(length(x) == N for x in chains) - Random.seed!(1234) - chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; - chain_type = MyChain) + Random.seed!(1234) + chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; + chain_type = MyChain) - # test output type and size - @test chains isa Vector{<:MyChain} - @test length(chains) == 1000 - @test all(x -> length(x.as) == length(x.bs) == N, chains) + # test output type and size + @test chains isa Vector{<:MyChain} + @test length(chains) == 1000 + @test all(x -> length(x.as) == length(x.bs) == N, chains) - # test some statistical properties - @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) - @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) + # test some statistical properties + @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) + @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) - # test reproducibility - Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; - chain_type = MyChain) + # test reproducibility + Random.seed!(1234) + chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; + chain_type = MyChain) - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - # Unexpected order of arguments. - str = "Number of chains (10) is greater than number of samples per chain (5)" - @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), - MCMCThreads(), 5, 10; - chain_type = MyChain) + # Unexpected order of arguments. + str = "Number of chains (10) is greater than number of samples per chain (5)" + @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), + MCMCThreads(), 5, 10; + chain_type = MyChain) - # Suppress output. - logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000; - progress = false, chain_type = MyChain) - end - @test all(l.level > Logging.LogLevel(-1) for l in logs) + # Suppress output. + logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do + sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000; + progress = false, chain_type = MyChain) + end + @test all(l.level > Logging.LogLevel(-1) for l in logs) - # Smoke test for nchains < nthreads - if Threads.nthreads() == 2 - sample(MyModel(), MySampler(), MCMCThreads(), N, 1) - end + # Smoke test for nchains < nthreads + if Threads.nthreads() == 2 + sample(MyModel(), MySampler(), MCMCThreads(), N, 1) end end @@ -271,6 +269,36 @@ @test all(l.level > Logging.LogLevel(-1) for l in logs) end + @testset "Ensemble sampling: Reproducibility" begin + N = 1_000 + nchains = 10 + + # Serial sampling + Random.seed!(1234) + chains_serial = sample( + MyModel(), MySampler(), MCMCSerial(), N, nchains; + progress=false, chain_type=MyChain + ) + + # Multi-threaded sampling + Random.seed!(1234) + chains_threads = sample( + MyModel(), MySampler(), MCMCThreads(), N, nchains; + progress=false, chain_type=MyChain + ) + @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N) + @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N) + + # Multi-core sampling + Random.seed!(1234) + chains_distributed = sample( + MyModel(), MySampler(), MCMCDistributed(), N, nchains; + progress=false, chain_type=MyChain + ) + @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N) + @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N) + end + @testset "Chain constructors" begin chain1 = sample(MyModel(), MySampler(), 100; sleepy = true) chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain) From 25d8c391322aa037e0398210c8e56df1da1d8e3d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 21 Feb 2022 22:34:42 +0100 Subject: [PATCH 12/54] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 794deb24..7d3a2ca9 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "3.3.0" +version = "3.3.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From d24e6b04bc3596e933eafd2b5ad0942b55081426 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 21 Feb 2022 22:56:26 +0100 Subject: [PATCH 13/54] Fix CI --- .github/workflows/CI.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 069ef82b..d5b1273a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -46,6 +46,8 @@ jobs: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - uses: julia-actions/cache@v1 + with: + cache-packages: "false" # caching Conda.jl causes precompilation error - uses: julia-actions/julia-buildpkg@latest - uses: julia-actions/julia-runtest@latest env: From 20ffc79d51e3ca2bc0bb75efa2366389c7fc08f1 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 21 Feb 2022 23:34:00 +0100 Subject: [PATCH 14/54] Increase tolerances --- test/sample.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/sample.jl b/test/sample.jl index d8a04712..407330d2 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -130,7 +130,7 @@ @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains) # test reproducibility Random.seed!(1234) @@ -199,7 +199,7 @@ @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains) # Test reproducibility. Random.seed!(1234) @@ -245,7 +245,7 @@ @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains) # Test reproducibility. Random.seed!(1234) From 22299e27d56530b367593f2e595cac8e659f9362 Mon Sep 17 00:00:00 2001 From: Cameron Pfiffer Date: Tue, 22 Feb 2022 11:05:42 -0800 Subject: [PATCH 15/54] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7d3a2ca9..bf356aa1 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "3.3.1" +version = "3.3.2" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From 93284f02fc4b2e6544e30134f6a506d4b9b1b5a7 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 23 Feb 2022 00:41:34 +0100 Subject: [PATCH 16/54] Fix version number (#98) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index bf356aa1..7d3a2ca9 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "3.3.2" +version = "3.3.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From acfadfab3013c3ae051dafa856ccd5137617265b Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 2 Mar 2022 22:23:11 +0100 Subject: [PATCH 17/54] Fix test errors and simplify docs (#99) * Fix Windows test errors (hopefully) * Simplify building docs * Update .gitignore * Fix another test error * Fix another test error --- .github/workflows/Docs.yml | 9 +- .gitignore | 2 +- docs/Manifest.toml | 374 ------------------------------------- docs/Project.toml | 2 - docs/make.jl | 8 +- test/sample.jl | 12 +- test/utils.jl | 2 +- 7 files changed, 14 insertions(+), 395 deletions(-) delete mode 100644 docs/Manifest.toml diff --git a/.github/workflows/Docs.yml b/.github/workflows/Docs.yml index e47f9389..afa6ee8b 100644 --- a/.github/workflows/Docs.yml +++ b/.github/workflows/Docs.yml @@ -12,13 +12,12 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@latest + - uses: julia-actions/setup-julia@v1 with: version: '1' - - name: Install dependencies - run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' - - name: Build and deploy + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-docdeploy@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key - run: julia --project=docs/ docs/make.jl + JULIA_DEBUG: Documenter # Print `@debug` statements (https://github.com/JuliaDocs/Documenter.jl/issues/955) diff --git a/.gitignore b/.gitignore index dfa313a1..83d89f72 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,4 @@ *.jl.*.cov *.jl.mem deps/deps.jl -/Manifest.toml \ No newline at end of file +Manifest.toml \ No newline at end of file diff --git a/docs/Manifest.toml b/docs/Manifest.toml deleted file mode 100644 index 014aac99..00000000 --- a/docs/Manifest.toml +++ /dev/null @@ -1,374 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -[[AbstractMCMC]] -deps = ["BangBang", "ConsoleProgressMonitor", "Distributed", "Logging", "LoggingExtras", "ProgressLogging", "Random", "StatsBase", "TerminalLoggers", "Transducers"] -path = ".." -uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" -version = "3.2.1" - -[[AbstractTrees]] -git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5" -uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.3.4" - -[[Adapt]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.3.1" - -[[ArgCheck]] -git-tree-sha1 = "dedbbb2ddb876f899585c4ec4433265e3017215a" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.1.0" - -[[ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" - -[[Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[BangBang]] -deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] -git-tree-sha1 = "e239020994123f08905052b9603b4ca14f8c5807" -uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.3.31" - -[[Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[Baselet]] -git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" -uuid = "9718e550-a3fa-408a-8086-8db961cd8217" -version = "0.1.1" - -[[Compat]] -deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.31.0" - -[[CompositionsBase]] -git-tree-sha1 = "f3955eb38944e5dd0fabf8ca1e267d94941d34a5" -uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.0" - -[[ConsoleProgressMonitor]] -deps = ["Logging", "ProgressMeter"] -git-tree-sha1 = "3ab7b2136722890b9af903859afcf457fa3059e8" -uuid = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" -version = "0.1.2" - -[[ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f74e9d5388b8620b4cee35d4c5a618dd4dc547f4" -uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.3.0" - -[[DataAPI]] -git-tree-sha1 = "ee400abb2298bd13bfc3df1c412ed228061a2385" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.7.0" - -[[DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.9" - -[[DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[DefineSingletons]] -git-tree-sha1 = "77b4ca280084423b728662fe040e5ff8819347c5" -uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" -version = "0.1.1" - -[[DelimitedFiles]] -deps = ["Mmap"] -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" - -[[Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.5" - -[[Documenter]] -deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "47f13b6305ab195edb73c86815962d84e31b0f48" -uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.27.3" - -[[Downloads]] -deps = ["ArgTools", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" - -[[Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[IOCapture]] -deps = ["Logging", "Random"] -git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a" -uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.2" - -[[InitialValues]] -git-tree-sha1 = "26c8832afd63ac558b98a823265856670d898b6c" -uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" -version = "0.2.10" - -[[InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.1" - -[[LeftChildRightSiblingTrees]] -deps = ["AbstractTrees"] -git-tree-sha1 = "71be1eb5ad19cb4f61fa8c73395c0338fd092ae0" -uuid = "1d6d02ad-be62-4b6b-8a6d-2f90e265016e" -version = "0.1.2" - -[[LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" - -[[LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" - -[[LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" - -[[Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[LinearAlgebra]] -deps = ["Libdl"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[LoggingExtras]] -deps = ["Dates", "Logging"] -git-tree-sha1 = "dfeda1c1130990428720de0024d4516b1902ce98" -uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "0.4.7" - -[[MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.6" - -[[Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" - -[[MicroCollections]] -deps = ["BangBang", "Setfield"] -git-tree-sha1 = "e991b6a9d38091c4a0d7cd051fcb57c05f98ac03" -uuid = "128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.1.0" - -[[Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "4ea90bd5d3985ae1f9a908bd4500ae88921c5ce7" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.0.0" - -[[Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" - -[[NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" - -[[OrderedCollections]] -git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.4.1" - -[[Parsers]] -deps = ["Dates"] -git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "1.1.0" - -[[Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" - -[[Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[ProgressLogging]] -deps = ["Logging", "SHA", "UUIDs"] -git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" -uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" -version = "0.1.4" - -[[ProgressMeter]] -deps = ["Distributed", "Printf"] -git-tree-sha1 = "afadeba63d90ff223a6a48d2009434ecee2ec9e8" -uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "1.7.1" - -[[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[Random]] -deps = ["Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.1.3" - -[[SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" - -[[Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "Requires"] -git-tree-sha1 = "d5640fc570fb1b6c54512f0bd3853866bd298b3e" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "0.7.0" - -[[SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - -[[Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.0.0" - -[[SparseArrays]] -deps = ["LinearAlgebra", "Random"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[SplittablesBase]] -deps = ["Setfield", "Test"] -git-tree-sha1 = "edef25a158db82f4940720ebada14a60ef6c4232" -uuid = "171d559e-b47b-412a-8079-5efa626c420e" -version = "0.1.13" - -[[Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[StatsAPI]] -git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.0.0" - -[[StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "2f6792d523d7448bbe2fec99eca9218f06cc746d" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.8" - -[[TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" - -[[TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"] -git-tree-sha1 = "8ed4a3ea724dac32670b062be3ef1c1de6773ae8" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.4.4" - -[[Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" - -[[TerminalLoggers]] -deps = ["LeftChildRightSiblingTrees", "Logging", "Markdown", "Printf", "ProgressLogging", "UUIDs"] -git-tree-sha1 = "d620a061cb2a56930b52bdf5cf908a5c4fa8e76a" -uuid = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" -version = "0.1.4" - -[[Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[Transducers]] -deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] -git-tree-sha1 = "34f27ac221cb53317ab6df196f9ed145077231ff" -uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.65" - -[[UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" - -[[ZygoteRules]] -deps = ["MacroTools"] -git-tree-sha1 = "9e7a1e8ca60b742e508a315c17eef5211e7fbfd7" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.1" - -[[nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" - -[[p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" diff --git a/docs/Project.toml b/docs/Project.toml index 69dcc9d0..555443ab 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,9 +1,7 @@ [deps] -AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] -AbstractMCMC = "3" Documenter = "0.27" julia = "1" diff --git a/docs/make.jl b/docs/make.jl index e0fa16e9..67978d1e 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,11 +1,5 @@ -using Documenter - -# Print `@debug` statements (https://github.com/JuliaDocs/Documenter.jl/issues/955) -if haskey(ENV, "GITHUB_ACTIONS") - ENV["JULIA_DEBUG"] = "Documenter" -end - using AbstractMCMC +using Documenter using Random DocMeta.setdocmeta!( diff --git a/test/sample.jl b/test/sample.jl index 407330d2..19242629 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -167,7 +167,9 @@ end # Add worker processes. - addprocs() + # Memory requirements on Windows are ~4x larger than on Linux, hence number of processes is reduced + # See, e.g., https://github.com/JuliaLang/julia/issues/40766 and https://github.com/JuliaLang/Pkg.jl/pull/2366 + addprocs(Sys.iswindows() ? div(Sys.CPU_THREADS::Int, 2) : Sys.CPU_THREADS::Int) # Load all required packages (`interface.jl` needs Random). @everywhere begin @@ -310,7 +312,7 @@ @testset "Sample stats" begin chain = sample(MyModel(), MySampler(), 1000; chain_type = MyChain) - @test chain.stats.stop > chain.stats.start + @test chain.stats.stop >= chain.stats.start @test chain.stats.duration == chain.stats.stop - chain.stats.start end @@ -341,19 +343,19 @@ chain = sample(MyModel(), MySampler()) bmean = mean(x.b for x in chain) @test ismissing(chain[1].a) - @test abs(bmean) <= 0.001 && length(chain) < 10_000 + @test abs(bmean) <= 0.001 || length(chain) == 10_000 # Discard initial samples. chain = sample(MyModel(), MySampler(); discard_initial = 50) bmean = mean(x.b for x in chain) @test !ismissing(chain[1].a) - @test abs(bmean) <= 0.001 && length(chain) < 10_000 + @test abs(bmean) <= 0.001 || length(chain) == 10_000 # Thin chain by a factor of `thinning`. chain = sample(MyModel(), MySampler(); thinning = 3) bmean = mean(x.b for x in chain) @test ismissing(chain[1].a) - @test abs(bmean) <= 0.001 && length(chain) < 10_000 + @test abs(bmean) <= 0.001 || length(chain) == 10_000 end @testset "Sample vector of `NamedTuple`s" begin diff --git a/test/utils.jl b/test/utils.jl index f6ac9d27..cd3543b7 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -63,7 +63,7 @@ function isdone( ) # Calculate the mean of x.b. bmean = mean(x.b for x in samples) - return abs(bmean) <= 0.001 || iteration >= 10_000 || state >= 10_000 + return abs(bmean) <= 0.001 || iteration > 10_000 end # Set a default convergence function. From 4994a79735b3f4518bd5c62dd062c1f2dd16ebae Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 5 Mar 2022 14:23:48 +0100 Subject: [PATCH 18/54] Use Blue style (#100) --- .JuliaFormatter.toml | 1 + .github/workflows/Format.yml | 31 ++++++ README.md | 1 + docs/make.jl | 17 +-- src/AbstractMCMC.jl | 23 ++-- src/deprecations.jl | 2 +- src/interface.jl | 51 ++------- src/logging.jl | 27 +++-- src/sample.jl | 116 ++++++++++---------- src/samplingstats.jl | 2 +- src/stepper.jl | 11 +- src/transducer.jl | 8 +- test/deprecations.jl | 2 +- test/runtests.jl | 2 +- test/sample.jl | 206 ++++++++++++++++++++++------------- test/stepper.jl | 10 +- test/transducer.jl | 25 ++--- test/utils.jl | 19 ++-- 18 files changed, 297 insertions(+), 257 deletions(-) create mode 100644 .JuliaFormatter.toml create mode 100644 .github/workflows/Format.yml diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 00000000..1e72b507 --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style="blue" diff --git a/.github/workflows/Format.yml b/.github/workflows/Format.yml new file mode 100644 index 00000000..ec14da16 --- /dev/null +++ b/.github/workflows/Format.yml @@ -0,0 +1,31 @@ +name: Format + +on: + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@latest + with: + version: 1 + - name: Format code + run: | + using Pkg + Pkg.add(; name="JuliaFormatter", uuid="98e50ef6-434e-11e9-1051-2b60c6c9e899") + using JuliaFormatter + format("."; verbose=true) + shell: julia --color=yes {0} + - uses: reviewdog/action-suggester@v1 + if: github.event_name == 'pull_request' + with: + tool_name: JuliaFormatter + fail_on_error: true diff --git a/README.md b/README.md index a2d40c34..ee186269 100644 --- a/README.md +++ b/README.md @@ -8,3 +8,4 @@ Abstract types and interfaces for Markov chain Monte Carlo methods. [![IntegrationTest](https://github.com/TuringLang/AbstractMCMC.jl/workflows/IntegrationTest/badge.svg?branch=master)](https://github.com/TuringLang/AbstractMCMC.jl/actions?query=workflow%3AIntegrationTest+branch%3Amaster) [![Codecov](https://codecov.io/gh/TuringLang/AbstractMCMC.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/TuringLang/AbstractMCMC.jl) [![Coveralls](https://coveralls.io/repos/github/TuringLang/AbstractMCMC.jl/badge.svg?branch=master)](https://coveralls.io/github/TuringLang/AbstractMCMC.jl?branch=master) +[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) diff --git a/docs/make.jl b/docs/make.jl index 67978d1e..66d7619c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,26 +2,15 @@ using AbstractMCMC using Documenter using Random -DocMeta.setdocmeta!( - AbstractMCMC, - :DocTestSetup, - :(using AbstractMCMC); - recursive=true, -) +DocMeta.setdocmeta!(AbstractMCMC, :DocTestSetup, :(using AbstractMCMC); recursive=true) makedocs(; sitename="AbstractMCMC", format=Documenter.HTML(), modules=[AbstractMCMC], - pages=[ - "Home" => "index.md", - "api.md", - "design.md", - ], + pages=["Home" => "index.md", "api.md", "design.md"], strict=true, checkdocs=:exports, ) -deploydocs(; - repo="github.com/TuringLang/AbstractMCMC.jl.git", push_preview=true -) +deploydocs(; repo="github.com/TuringLang/AbstractMCMC.jl.git", push_preview=true) diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index ef23cb51..686924a8 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -1,16 +1,16 @@ module AbstractMCMC -import BangBang -import ConsoleProgressMonitor -import LoggingExtras -import ProgressLogging -import StatsBase -import TerminalLoggers -import Transducers - -import Distributed -import Logging -import Random +using BangBang: BangBang +using ConsoleProgressMonitor: ConsoleProgressMonitor +using LoggingExtras: LoggingExtras +using ProgressLogging: ProgressLogging +using StatsBase: StatsBase +using TerminalLoggers: TerminalLoggers +using Transducers: Transducers + +using Distributed: Distributed +using Logging: Logging +using Random: Random # Reexport sample using StatsBase: sample @@ -71,7 +71,6 @@ processes. """ struct MCMCDistributed <: AbstractMCMCEnsemble end - """ MCMCSerial diff --git a/src/deprecations.jl b/src/deprecations.jl index 1cc93d12..128f16d1 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -1,2 +1,2 @@ # Deprecate the old name AbstractMCMCParallel in favor of AbstractMCMCEnsemble -Base.@deprecate_binding AbstractMCMCParallel AbstractMCMCEnsemble false \ No newline at end of file +Base.@deprecate_binding AbstractMCMCParallel AbstractMCMCEnsemble false diff --git a/src/interface.jl b/src/interface.jl index 7b3daefb..eaecb492 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -30,24 +30,14 @@ be specified with the `chain_type` argument. By default, this method returns `samples`. """ function bundle_samples( - samples, - ::AbstractModel, - ::AbstractSampler, - ::Any, - ::Type; - kwargs... + samples, ::AbstractModel, ::AbstractSampler, ::Any, ::Type; kwargs... ) return samples end function bundle_samples( - samples::Vector, - ::AbstractModel, - ::AbstractSampler, - ::Any, - ::Type{Vector{T}}; - kwargs... -) where T + samples::Vector, ::AbstractModel, ::AbstractSampler, ::Any, ::Type{Vector{T}}; kwargs... +) where {T} return map(samples) do sample convert(T, sample) end @@ -74,24 +64,13 @@ sample is `sample`. The method can be called with and without a predefined number `N` of samples. """ -function samples( - sample, - ::AbstractModel, - ::AbstractSampler, - N::Integer; - kwargs... -) +function samples(sample, ::AbstractModel, ::AbstractSampler, N::Integer; kwargs...) ts = Vector{typeof(sample)}(undef, 0) sizehint!(ts, N) return ts end -function samples( - sample, - ::AbstractModel, - ::AbstractSampler; - kwargs... -) +function samples(sample, ::AbstractModel, ::AbstractSampler; kwargs...) return Vector{typeof(sample)}(undef, 0) end @@ -113,7 +92,7 @@ function save!!( ::AbstractModel, ::AbstractSampler, N::Integer; - kwargs... + kwargs..., ) s = BangBang.push!!(samples, sample) s !== samples && sizehint!(s, N) @@ -121,27 +100,15 @@ function save!!( end function save!!( - samples, - sample, - iteration::Integer, - ::AbstractModel, - ::AbstractSampler; - kwargs... + samples, sample, iteration::Integer, ::AbstractModel, ::AbstractSampler; kwargs... ) return BangBang.push!!(samples, sample) end # Deprecations Base.@deprecate transitions( - transition, - model::AbstractModel, - sampler::AbstractSampler, - N::Integer; - kwargs... + transition, model::AbstractModel, sampler::AbstractSampler, N::Integer; kwargs... ) samples(transition, model, sampler, N; kwargs...) false Base.@deprecate transitions( - transition, - model::AbstractModel, - sampler::AbstractSampler; - kwargs... + transition, model::AbstractModel, sampler::AbstractSampler; kwargs... ) samples(transition, model, sampler; kwargs...) false diff --git a/src/logging.jl b/src/logging.jl index a550c532..04c41187 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -2,19 +2,21 @@ # and add a custom progress logger if the current logger does not seem to be able to handle # progress logs macro ifwithprogresslogger(progress, exprs...) - return quote - if $progress - if $hasprogresslevel($Logging.current_logger()) - $ProgressLogging.@withprogress $(exprs...) - else - $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do + return esc( + quote + if $progress + if $hasprogresslevel($Logging.current_logger()) $ProgressLogging.@withprogress $(exprs...) + else + $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do + $ProgressLogging.@withprogress $(exprs...) + end end + else + $(exprs[end]) end - else - $(exprs[end]) - end - end |> esc + end, + ) end # improved checks? @@ -31,13 +33,14 @@ function with_progresslogger(f, _module, logger) log._module !== _module || log.level != ProgressLogging.ProgressLevel end - Logging.with_logger(f, LoggingExtras.TeeLogger(logger1, logger2)) + return Logging.with_logger(f, LoggingExtras.TeeLogger(logger1, logger2)) end function progresslogger() # detect if code is running under IJulia since TerminalLogger does not work with IJulia # https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia - if (Sys.iswindows() && VERSION < v"1.5.3") || (isdefined(Main, :IJulia) && Main.IJulia.inited) + if (Sys.iswindows() && VERSION < v"1.5.3") || + (isdefined(Main, :IJulia) && Main.IJulia.inited) return ConsoleProgressMonitor.ProgressLogger() else return TerminalLoggers.TerminalLogger() diff --git a/src/sample.jl b/src/sample.jl index e1ff1be7..b6fad3fe 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -12,12 +12,7 @@ function setprogress!(progress::Bool) return progress end -function StatsBase.sample( - model::AbstractModel, - sampler::AbstractSampler, - arg; - kwargs... -) +function StatsBase.sample(model::AbstractModel, sampler::AbstractSampler, arg; kwargs...) return StatsBase.sample(Random.GLOBAL_RNG, model, sampler, arg; kwargs...) end @@ -31,7 +26,7 @@ function StatsBase.sample( model::AbstractModel, sampler::AbstractSampler, N::Integer; - kwargs... + kwargs..., ) return mcmcsample(rng, model, sampler, N; kwargs...) end @@ -54,7 +49,7 @@ function StatsBase.sample( model::AbstractModel, sampler::AbstractSampler, isdone; - kwargs... + kwargs..., ) return mcmcsample(rng, model, sampler, isdone; kwargs...) end @@ -65,10 +60,11 @@ function StatsBase.sample( parallel::AbstractMCMCEnsemble, N::Integer, nchains::Integer; - kwargs... + kwargs..., ) - return StatsBase.sample(Random.GLOBAL_RNG, model, sampler, parallel, N, nchains; - kwargs...) + return StatsBase.sample( + Random.GLOBAL_RNG, model, sampler, parallel, N, nchains; kwargs... + ) end """ @@ -84,7 +80,7 @@ function StatsBase.sample( parallel::AbstractMCMCEnsemble, N::Integer, nchains::Integer; - kwargs... + kwargs..., ) return mcmcsample(rng, model, sampler, parallel, N, nchains; kwargs...) end @@ -96,13 +92,13 @@ function mcmcsample( model::AbstractModel, sampler::AbstractSampler, N::Integer; - progress = PROGRESS[], - progressname = "Sampling", - callback = nothing, - discard_initial = 0, - thinning = 1, + progress=PROGRESS[], + progressname="Sampling", + callback=nothing, + discard_initial=0, + thinning=1, chain_type::Type=Any, - kwargs... + kwargs..., ) # Check the number of requested samples. N > 0 || error("the number of samples must be ≥ 1") @@ -112,7 +108,7 @@ function mcmcsample( start = time() local state - @ifwithprogresslogger progress name=progressname begin + @ifwithprogresslogger progress name = progressname begin # Determine threshold values for progress logging # (one update per 0.5% of progress) if progress @@ -127,7 +123,7 @@ function mcmcsample( for i in 1:(discard_initial - 1) # Update the progress bar. if progress && i >= next_update - ProgressLogging.@logprogress i/Ntotal + ProgressLogging.@logprogress i / Ntotal next_update = i + threshold end @@ -167,7 +163,8 @@ function mcmcsample( sample, state = step(rng, model, sampler, state; kwargs...) # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + callback === nothing || + callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. samples = save!!(samples, sample, i, model, sampler, N; kwargs...) @@ -186,15 +183,15 @@ function mcmcsample( stats = SamplingStats(start, stop, duration) return bundle_samples( - samples, - model, + samples, + model, sampler, state, chain_type; stats=stats, discard_initial=discard_initial, thinning=thinning, - kwargs... + kwargs..., ) end @@ -204,19 +201,19 @@ function mcmcsample( sampler::AbstractSampler, isdone; chain_type::Type=Any, - progress = PROGRESS[], - progressname = "Convergence sampling", - callback = nothing, - discard_initial = 0, - thinning = 1, - kwargs... + progress=PROGRESS[], + progressname="Convergence sampling", + callback=nothing, + discard_initial=0, + thinning=1, + kwargs..., ) # Start the timer start = time() local state - @ifwithprogresslogger progress name=progressname begin + @ifwithprogresslogger progress name = progressname begin # Obtain the initial sample and state. sample, state = step(rng, model, sampler; kwargs...) @@ -247,7 +244,8 @@ function mcmcsample( sample, state = step(rng, model, sampler, state; kwargs...) # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + callback === nothing || + callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. samples = save!!(samples, sample, i, model, sampler; kwargs...) @@ -264,15 +262,15 @@ function mcmcsample( # Wrap the samples up. return bundle_samples( - samples, + samples, model, - sampler, - state, - chain_type; + sampler, + state, + chain_type; stats=stats, discard_initial=discard_initial, thinning=thinning, - kwargs... + kwargs..., ) end @@ -283,9 +281,9 @@ function mcmcsample( ::MCMCThreads, N::Integer, nchains::Integer; - progress = PROGRESS[], - progressname = "Sampling ($(min(nchains, Threads.nthreads())) threads)", - kwargs... + progress=PROGRESS[], + progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)", + kwargs..., ) # Check if actually multiple threads are used. if Threads.nthreads() == 1 @@ -311,7 +309,7 @@ function mcmcsample( # Set up a chains vector. chains = Vector{Any}(undef, nchains) - @ifwithprogresslogger progress name=progressname begin + @ifwithprogresslogger progress name = progressname begin # Create a channel for progress logging. if progress channel = Channel{Bool}(length(interval)) @@ -330,7 +328,7 @@ function mcmcsample( while take!(channel) progresschains += 1 if progresschains >= nextprogresschains - ProgressLogging.@logprogress progresschains/nchains + ProgressLogging.@logprogress progresschains / nchains nextprogresschains = progresschains + threshold end end @@ -339,7 +337,8 @@ function mcmcsample( Distributed.@async begin try - Distributed.@sync for (i, _rng, _model, _sampler) in zip(1:nchunks, rngs, models, samplers) + Distributed.@sync for (i, _rng, _model, _sampler) in + zip(1:nchunks, rngs, models, samplers) chainidxs = if i == nchunks ((i - 1) * chunksize + 1):nchains else @@ -350,8 +349,9 @@ function mcmcsample( Random.seed!(_rng, seeds[chainidx]) # Sample a chain and save it to the vector. - chains[chainidx] = StatsBase.sample(_rng, _model, _sampler, N; - progress = false, kwargs...) + chains[chainidx] = StatsBase.sample( + _rng, _model, _sampler, N; progress=false, kwargs... + ) # Update the progress bar. progress && put!(channel, true) @@ -376,9 +376,9 @@ function mcmcsample( ::MCMCDistributed, N::Integer, nchains::Integer; - progress = PROGRESS[], - progressname = "Sampling ($(Distributed.nworkers()) processes)", - kwargs... + progress=PROGRESS[], + progressname="Sampling ($(Distributed.nworkers()) processes)", + kwargs..., ) # Check if actually multiple processes are used. if Distributed.nworkers() == 1 @@ -397,7 +397,7 @@ function mcmcsample( pool = Distributed.CachingPool(Distributed.workers()) local chains - @ifwithprogresslogger progress name=progressname begin + @ifwithprogresslogger progress name = progressname begin # Create a channel for progress logging. if progress channel = Distributed.RemoteChannel(() -> Channel{Bool}(Distributed.nworkers())) @@ -416,7 +416,7 @@ function mcmcsample( while take!(channel) progresschains += 1 if progresschains >= nextprogresschains - ProgressLogging.@logprogress progresschains/nchains + ProgressLogging.@logprogress progresschains / nchains nextprogresschains = progresschains + threshold end end @@ -430,8 +430,9 @@ function mcmcsample( Random.seed!(rng, seed) # Sample a chain. - chain = StatsBase.sample(rng, model, sampler, N; - progress = false, kwargs...) + chain = StatsBase.sample( + rng, model, sampler, N; progress=false, kwargs... + ) # Update the progress bar. progress && put!(channel, true) @@ -458,8 +459,8 @@ function mcmcsample( ::MCMCSerial, N::Integer, nchains::Integer; - progressname = "Sampling", - kwargs... + progressname="Sampling", + kwargs..., ) # Check if the number of chains is larger than the number of samples if nchains > N @@ -473,8 +474,11 @@ function mcmcsample( chains = map(enumerate(seeds)) do (i, seed) Random.seed!(rng, seed) return StatsBase.sample( - rng, model, sampler, N; - progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"), + rng, + model, + sampler, + N; + progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"), kwargs..., ) end diff --git a/src/samplingstats.jl b/src/samplingstats.jl index dea2b653..c5820dff 100644 --- a/src/samplingstats.jl +++ b/src/samplingstats.jl @@ -13,4 +13,4 @@ struct SamplingStats start::Float64 stop::Float64 duration::Float64 -end \ No newline at end of file +end diff --git a/src/stepper.jl b/src/stepper.jl index 34391851..18867c58 100644 --- a/src/stepper.jl +++ b/src/stepper.jl @@ -13,11 +13,7 @@ end Base.IteratorSize(::Type{<:Stepper}) = Base.IsInfinite() Base.IteratorEltype(::Type{<:Stepper}) = Base.EltypeUnknown() -function steps( - model::AbstractModel, - sampler::AbstractSampler; - kwargs... -) +function steps(model::AbstractModel, sampler::AbstractSampler; kwargs...) return steps(Random.GLOBAL_RNG, model, sampler; kwargs...) end @@ -46,10 +42,7 @@ true ``` """ function steps( - rng::Random.AbstractRNG, - model::AbstractModel, - sampler::AbstractSampler; - kwargs... + rng::Random.AbstractRNG, model::AbstractModel, sampler::AbstractSampler; kwargs... ) return Stepper(rng, model, sampler, kwargs) end diff --git a/src/transducer.jl b/src/transducer.jl index 7aca51e0..51f9b358 100644 --- a/src/transducer.jl +++ b/src/transducer.jl @@ -1,4 +1,5 @@ -struct Sample{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} <: Transducers.Transducer +struct Sample{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} <: + Transducers.Transducer rng::A model::M sampler::S @@ -34,10 +35,7 @@ true ``` """ function Sample( - rng::Random.AbstractRNG, - model::AbstractModel, - sampler::AbstractSampler; - kwargs... + rng::Random.AbstractRNG, model::AbstractModel, sampler::AbstractSampler; kwargs... ) return Sample(rng, model, sampler, kwargs) end diff --git a/test/deprecations.jl b/test/deprecations.jl index f866668c..dd53cb42 100644 --- a/test/deprecations.jl +++ b/test/deprecations.jl @@ -1,4 +1,4 @@ @testset "deprecations.jl" begin @test_deprecated AbstractMCMC.transitions(MySample(1, 2.0), MyModel(), MySampler()) @test_deprecated AbstractMCMC.transitions(MySample(1, 2.0), MyModel(), MySampler(), 3) -end \ No newline at end of file +end diff --git a/test/runtests.jl b/test/runtests.jl index c3f108e1..e8f09589 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,7 +7,7 @@ using TerminalLoggers: TerminalLogger using Transducers using Distributed -import Logging +using Logging: Logging using Random using Statistics using Test diff --git a/test/sample.jl b/test/sample.jl index 19242629..debb2238 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -5,12 +5,13 @@ Random.seed!(1234) N = 1_000 - chain = sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) + chain = sample(MyModel(), MySampler(), N; sleepy=true, loggers=true) @test length(LOGGERS) == 1 logger = first(LOGGERS) @test logger isa TeeLogger - @test logger.loggers[1].logger isa (Sys.iswindows() && VERSION < v"1.5.3" ? ProgressLogger : TerminalLogger) + @test logger.loggers[1].logger isa + (Sys.iswindows() && VERSION < v"1.5.3" ? ProgressLogger : TerminalLogger) @test logger.loggers[2].logger === CURRENT_LOGGER @test Logging.current_logger() === CURRENT_LOGGER @@ -20,10 +21,10 @@ # test some statistical properties tail_chain = @view chain[2:end] - @test mean(x.a for x in tail_chain) ≈ 0.5 atol=6e-2 - @test var(x.a for x in tail_chain) ≈ 1 / 12 atol=5e-3 - @test mean(x.b for x in tail_chain) ≈ 0.0 atol=5e-2 - @test var(x.b for x in tail_chain) ≈ 1 atol=6e-2 + @test mean(x.a for x in tail_chain) ≈ 0.5 atol = 6e-2 + @test var(x.a for x in tail_chain) ≈ 1 / 12 atol = 5e-3 + @test mean(x.b for x in tail_chain) ≈ 0.0 atol = 5e-2 + @test var(x.b for x in tail_chain) ≈ 1 atol = 6e-2 end @testset "Juno" begin @@ -34,7 +35,7 @@ logger = JunoProgressLogger() Logging.with_logger(logger) do - sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) + sample(MyModel(), MySampler(), N; sleepy=true, loggers=true) end @test length(LOGGERS) == 1 @@ -52,7 +53,7 @@ Random.seed!(1234) N = 10 - sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) + sample(MyModel(), MySampler(), N; sleepy=true, loggers=true) @test length(LOGGERS) == 1 logger = first(LOGGERS) @@ -74,7 +75,7 @@ logger = Logging.ConsoleLogger(stderr, Logging.LogLevel(-1)) Logging.with_logger(logger) do - sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) + sample(MyModel(), MySampler(), N; sleepy=true, loggers=true) end @test length(LOGGERS) == 1 @@ -84,21 +85,25 @@ @testset "Suppress output" begin logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), 100; progress = false, sleepy = true) + sample(MyModel(), MySampler(), 100; progress=false, sleepy=true) end @test all(l.level > Logging.LogLevel(-1) for l in logs) # disable progress logging globally - @test !(@test_logs (:info, "progress logging is disabled globally") AbstractMCMC.setprogress!(false)) + @test !(@test_logs (:info, "progress logging is disabled globally") AbstractMCMC.setprogress!( + false + )) @test !AbstractMCMC.PROGRESS[] logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), 100; sleepy = true) + sample(MyModel(), MySampler(), 100; sleepy=true) end @test all(l.level > Logging.LogLevel(-1) for l in logs) # enable progress logging globally - @test (@test_logs (:info, "progress logging is enabled globally") AbstractMCMC.setprogress!(true)) + @test (@test_logs (:info, "progress logging is enabled globally") AbstractMCMC.setprogress!( + true + )) @test AbstractMCMC.PROGRESS[] end end @@ -106,8 +111,9 @@ @testset "Multithreaded sampling" begin if Threads.nthreads() == 1 warnregex = r"^Only a single thread available" - @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(), - 10, 10) + @test_logs (:warn, warnregex) sample( + MyModel(), MySampler(), MCMCThreads(), 10, 10 + ) end # No dedicated chains type @@ -118,8 +124,7 @@ @test all(length(x) == N for x in chains) Random.seed!(1234) - chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; - chain_type = MyChain) + chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; chain_type=MyChain) # test output type and size @test chains isa Vector{<:MyChain} @@ -134,36 +139,43 @@ # test reproducibility Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; - chain_type = MyChain) + chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; chain_type=MyChain) @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) # Unexpected order of arguments. str = "Number of chains (10) is greater than number of samples per chain (5)" - @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), - MCMCThreads(), 5, 10; - chain_type = MyChain) + @test_logs (:warn, str) match_mode = :any sample( + MyModel(), MySampler(), MCMCThreads(), 5, 10; chain_type=MyChain + ) # Suppress output. logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000; - progress = false, chain_type = MyChain) + sample( + MyModel(), + MySampler(), + MCMCThreads(), + 10_000, + 1000; + progress=false, + chain_type=MyChain, + ) end @test all(l.level > Logging.LogLevel(-1) for l in logs) - + # Smoke test for nchains < nthreads if Threads.nthreads() == 2 - sample(MyModel(), MySampler(), MCMCThreads(), N, 1) + sample(MyModel(), MySampler(), MCMCThreads(), N, 1) end end @testset "Multicore sampling" begin if nworkers() == 1 warnregex = r"^Only a single process available" - @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCDistributed(), - 10, 10; chain_type = MyChain) + @test_logs (:warn, warnregex) sample( + MyModel(), MySampler(), MCMCDistributed(), 10, 10; chain_type=MyChain + ) end # Add worker processes. @@ -188,8 +200,9 @@ @test all(length(x) == N for x in chains) Random.seed!(1234) - chains = sample(MyModel(), MySampler(), MCMCDistributed(), N, 1000; - chain_type = MyChain) + chains = sample( + MyModel(), MySampler(), MCMCDistributed(), N, 1000; chain_type=MyChain + ) # Test output type and size. @test chains isa Vector{<:MyChain} @@ -205,22 +218,30 @@ # Test reproducibility. Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCDistributed(), N, 1000; - chain_type = MyChain) + chains2 = sample( + MyModel(), MySampler(), MCMCDistributed(), N, 1000; chain_type=MyChain + ) @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) # Unexpected order of arguments. str = "Number of chains (10) is greater than number of samples per chain (5)" - @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), - MCMCDistributed(), 5, 10; - chain_type = MyChain) + @test_logs (:warn, str) match_mode = :any sample( + MyModel(), MySampler(), MCMCDistributed(), 5, 10; chain_type=MyChain + ) # Suppress output. logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), MCMCDistributed(), 10_000, 100; - progress = false, chain_type = MyChain) + sample( + MyModel(), + MySampler(), + MCMCDistributed(), + 10_000, + 100; + progress=false, + chain_type=MyChain, + ) end @test all(l.level > Logging.LogLevel(-1) for l in logs) end @@ -234,8 +255,7 @@ @test all(length(x) == N for x in chains) Random.seed!(1234) - chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; - chain_type = MyChain) + chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) # Test output type and size. @test chains isa Vector{<:MyChain} @@ -251,22 +271,28 @@ # Test reproducibility. Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; - chain_type = MyChain) + chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) # Unexpected order of arguments. str = "Number of chains (10) is greater than number of samples per chain (5)" - @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), - MCMCSerial(), 5, 10; - chain_type = MyChain) + @test_logs (:warn, str) match_mode = :any sample( + MyModel(), MySampler(), MCMCSerial(), 5, 10; chain_type=MyChain + ) # Suppress output. logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), MCMCSerial(), 10_000, 100; - progress = false, chain_type = MyChain) + sample( + MyModel(), + MySampler(), + MCMCSerial(), + 10_000, + 100; + progress=false, + chain_type=MyChain, + ) end @test all(l.level > Logging.LogLevel(-1) for l in logs) end @@ -278,46 +304,73 @@ # Serial sampling Random.seed!(1234) chains_serial = sample( - MyModel(), MySampler(), MCMCSerial(), N, nchains; - progress=false, chain_type=MyChain + MyModel(), + MySampler(), + MCMCSerial(), + N, + nchains; + progress=false, + chain_type=MyChain, ) # Multi-threaded sampling Random.seed!(1234) chains_threads = sample( - MyModel(), MySampler(), MCMCThreads(), N, nchains; - progress=false, chain_type=MyChain + MyModel(), + MySampler(), + MCMCThreads(), + N, + nchains; + progress=false, + chain_type=MyChain, + ) + @test all( + c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads), + i in 1:N + ) + @test all( + c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads), + i in 1:N ) - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N) # Multi-core sampling Random.seed!(1234) chains_distributed = sample( - MyModel(), MySampler(), MCMCDistributed(), N, nchains; - progress=false, chain_type=MyChain + MyModel(), + MySampler(), + MCMCDistributed(), + N, + nchains; + progress=false, + chain_type=MyChain, + ) + @test all( + c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed), + i in 1:N + ) + @test all( + c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed), + i in 1:N ) - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N) end @testset "Chain constructors" begin - chain1 = sample(MyModel(), MySampler(), 100; sleepy = true) - chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain) + chain1 = sample(MyModel(), MySampler(), 100; sleepy=true) + chain2 = sample(MyModel(), MySampler(), 100; sleepy=true, chain_type=MyChain) @test chain1 isa Vector{<:MySample} @test chain2 isa MyChain end @testset "Sample stats" begin - chain = sample(MyModel(), MySampler(), 1000; chain_type = MyChain) - + chain = sample(MyModel(), MySampler(), 1000; chain_type=MyChain) + @test chain.stats.stop >= chain.stats.start @test chain.stats.duration == chain.stats.stop - chain.stats.start end @testset "Discard initial samples" begin - chain = sample(MyModel(), MySampler(), 100; sleepy = true, discard_initial = 50) + chain = sample(MyModel(), MySampler(), 100; sleepy=true, discard_initial=50) @test length(chain) == 100 @test !ismissing(chain[1].a) end @@ -327,17 +380,16 @@ Random.seed!(1234) N = 100 thinning = 3 - chain = sample(MyModel(), MySampler(), N; sleepy = true, thinning = thinning) + chain = sample(MyModel(), MySampler(), N; sleepy=true, thinning=thinning) @test length(chain) == N @test ismissing(chain[1].a) # Repeat sampling without thinning. Random.seed!(1234) - ref_chain = sample(MyModel(), MySampler(), N * thinning; sleepy = true) + ref_chain = sample(MyModel(), MySampler(), N * thinning; sleepy=true) @test all(chain[i].a === ref_chain[(i - 1) * thinning + 1].a for i in 1:N) end - @testset "Sample without predetermined N" begin Random.seed!(1234) chain = sample(MyModel(), MySampler()) @@ -346,20 +398,20 @@ @test abs(bmean) <= 0.001 || length(chain) == 10_000 # Discard initial samples. - chain = sample(MyModel(), MySampler(); discard_initial = 50) + chain = sample(MyModel(), MySampler(); discard_initial=50) bmean = mean(x.b for x in chain) @test !ismissing(chain[1].a) @test abs(bmean) <= 0.001 || length(chain) == 10_000 # Thin chain by a factor of `thinning`. - chain = sample(MyModel(), MySampler(); thinning = 3) + chain = sample(MyModel(), MySampler(); thinning=3) bmean = mean(x.b for x in chain) @test ismissing(chain[1].a) @test abs(bmean) <= 0.001 || length(chain) == 10_000 end @testset "Sample vector of `NamedTuple`s" begin - chain = sample(MyModel(), MySampler(), 1_000; chain_type = Vector{NamedTuple}) + chain = sample(MyModel(), MySampler(), 1_000; chain_type=Vector{NamedTuple}) # Check output type @test chain isa Vector{<:NamedTuple} @test length(chain) == 1_000 @@ -367,15 +419,17 @@ # Check some statistical properties @test ismissing(chain[1].a) - @test mean(x.a for x in view(chain, 2:1_000)) ≈ 0.5 atol=6e-2 - @test var(x.a for x in view(chain, 2:1_000)) ≈ 1 / 12 atol=1e-2 - @test mean(x.b for x in chain) ≈ 0 atol=0.1 - @test var(x.b for x in chain) ≈ 1 atol=0.15 + @test mean(x.a for x in view(chain, 2:1_000)) ≈ 0.5 atol = 6e-2 + @test var(x.a for x in view(chain, 2:1_000)) ≈ 1 / 12 atol = 1e-2 + @test mean(x.b for x in chain) ≈ 0 atol = 0.1 + @test var(x.b for x in chain) ≈ 1 atol = 0.15 end - + @testset "Testing callbacks" begin - function count_iterations(rng, model, sampler, sample, state, i; iter_array, kwargs...) - push!(iter_array, i) + function count_iterations( + rng, model, sampler, sample, state, i; iter_array, kwargs... + ) + return push!(iter_array, i) end N = 100 it_array = Float64[] @@ -384,7 +438,9 @@ # sampling without predetermined N it_array = Float64[] - chain = sample(MyModel(), MySampler(); callback=count_iterations, iter_array=it_array) + chain = sample( + MyModel(), MySampler(); callback=count_iterations, iter_array=it_array + ) @test it_array == collect(1:size(chain, 1)) end end diff --git a/test/stepper.jl b/test/stepper.jl index bc75d637..1b570557 100644 --- a/test/stepper.jl +++ b/test/stepper.jl @@ -5,7 +5,7 @@ bs = [] iter = AbstractMCMC.steps(MyModel(), MySampler()) - iter = AbstractMCMC.steps(MyModel(), MySampler(); a = 1.0) # `a` shouldn't do anything + iter = AbstractMCMC.steps(MyModel(), MySampler(); a=1.0) # `a` shouldn't do anything for (count, t) in enumerate(iter) if count >= 1000 @@ -21,10 +21,10 @@ @test length(as) == length(bs) == 998 - @test mean(as) ≈ 0.5 atol=2e-2 - @test var(as) ≈ 1 / 12 atol=5e-3 - @test mean(bs) ≈ 0.0 atol=5e-2 - @test var(bs) ≈ 1 atol=5e-2 + @test mean(as) ≈ 0.5 atol = 2e-2 + @test var(as) ≈ 1 / 12 atol = 5e-3 + @test mean(bs) ≈ 0.0 atol = 5e-2 + @test var(bs) ≈ 1 atol = 5e-2 @test Base.IteratorSize(iter) == Base.IsInfinite() @test Base.IteratorEltype(iter) == Base.EltypeUnknown() diff --git a/test/transducer.jl b/test/transducer.jl index 910f9d70..f9e1a049 100644 --- a/test/transducer.jl +++ b/test/transducer.jl @@ -5,9 +5,8 @@ N = 1_000 local chain Logging.with_logger(TerminalLogger()) do - xf = AbstractMCMC.Sample(MyModel(), MySampler(); - sleepy = true, logger = true) - chain = withprogress(1:N; interval=1e-3) |> xf |> collect + xf = AbstractMCMC.Sample(MyModel(), MySampler(); sleepy=true, logger=true) + chain = collect(xf(withprogress(1:N; interval=1e-3))) end # test output type and size @@ -16,15 +15,15 @@ # test some statistical properties tail_chain = @view chain[2:end] - @test mean(x.a for x in tail_chain) ≈ 0.5 atol=6e-2 - @test var(x.a for x in tail_chain) ≈ 1 / 12 atol=5e-3 - @test mean(x.b for x in tail_chain) ≈ 0.0 atol=5e-2 - @test var(x.b for x in tail_chain) ≈ 1 atol=6e-2 + @test mean(x.a for x in tail_chain) ≈ 0.5 atol = 6e-2 + @test var(x.a for x in tail_chain) ≈ 1 / 12 atol = 5e-3 + @test mean(x.b for x in tail_chain) ≈ 0.0 atol = 5e-2 + @test var(x.b for x in tail_chain) ≈ 1 atol = 6e-2 end @testset "drop" begin xf = AbstractMCMC.Sample(MyModel(), MySampler()) - chain = 1:10 |> xf |> Drop(1) |> collect + chain = collect(Drop(1)(xf(1:10))) @test chain isa Vector{MySample{Float64,Float64}} @test length(chain) == 9 end @@ -37,7 +36,7 @@ OfType(MySample{Float64,Float64}), Map(x -> (x.a, x.b)), ) - as, bs = foldl(xf, 1:999; init = (Float64[], Float64[])) do (as, bs), (a, b) + as, bs = foldl(xf, 1:999; init=(Float64[], Float64[])) do (as, bs), (a, b) push!(as, a) push!(bs, b) as, bs @@ -45,9 +44,9 @@ @test length(as) == length(bs) == 998 - @test mean(as) ≈ 0.5 atol=2e-2 - @test var(as) ≈ 1 / 12 atol=5e-3 - @test mean(bs) ≈ 0.0 atol=5e-2 - @test var(bs) ≈ 1 atol=5e-2 + @test mean(as) ≈ 0.5 atol = 2e-2 + @test var(as) ≈ 1 / 12 atol = 5e-3 + @test mean(bs) ≈ 0.0 atol = 5e-2 + @test var(bs) ≈ 1 atol = 5e-2 end end diff --git a/test/utils.jl b/test/utils.jl index cd3543b7..32474639 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -20,10 +20,10 @@ function AbstractMCMC.step( rng::AbstractRNG, model::MyModel, sampler::MySampler, - state::Union{Nothing,Integer} = nothing; - sleepy = false, - loggers = false, - kwargs... + state::Union{Nothing,Integer}=nothing; + sleepy=false, + loggers=false, + kwargs..., ) # sample `a` is missing in the first step a = state === nothing ? missing : rand(rng) @@ -43,8 +43,8 @@ function AbstractMCMC.bundle_samples( sampler::MySampler, ::Any, ::Type{MyChain}; - stats = nothing, - kwargs... + stats=nothing, + kwargs..., ) as = [t.a for t in samples] bs = [t.b for t in samples] @@ -59,7 +59,7 @@ function isdone( samples, state, iteration::Int; - kwargs... + kwargs..., ) # Calculate the mean of x.b. bmean = mean(x.b for x in samples) @@ -72,11 +72,10 @@ function AbstractMCMC.sample(model, sampler::MySampler; kwargs...) end function AbstractMCMC.chainscat( - chain::Union{MyChain,Vector{<:MyChain}}, - chains::Union{MyChain,Vector{<:MyChain}}... + chain::Union{MyChain,Vector{<:MyChain}}, chains::Union{MyChain,Vector{<:MyChain}}... ) return vcat(chain, chains...) end # Conversion to NamedTuple -Base.convert(::Type{NamedTuple}, x::MySample) = (a = x.a, b = x.b) +Base.convert(::Type{NamedTuple}, x::MySample) = (a=x.a, b=x.b) From 3de7393b8b8e76330f53505b27d2b928ef178681 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 7 Mar 2022 08:39:09 +0100 Subject: [PATCH 19/54] Support `init_params` in ensemble methods (#94) * Support `init_params` in ensemble methods * Fix typo * Fix typo * Add documentation * Support `Iterators.Repeated` * Breaking release * Fix and simplify docs setup * Remove deprecations * Reduce tasks on Windows * Generalize to arbitrary collections * Use Blue style --- Project.toml | 2 +- docs/src/api.md | 5 +++ src/AbstractMCMC.jl | 1 - src/deprecations.jl | 2 - src/sample.jl | 76 ++++++++++++++++++++++++++++--- test/deprecations.jl | 4 -- test/runtests.jl | 1 - test/sample.jl | 103 +++++++++++++++++++++++++++++++++++++++++++ test/utils.jl | 10 +++-- 9 files changed, 187 insertions(+), 17 deletions(-) delete mode 100644 src/deprecations.jl delete mode 100644 test/deprecations.jl diff --git a/Project.toml b/Project.toml index 7d3a2ca9..9690fd8d 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "3.3.1" +version = "4.0.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/docs/src/api.md b/docs/src/api.md index 8dcf55f4..c7451cc5 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -53,6 +53,11 @@ are: - `discard_initial` (default: `0`): number of initial samples that are discarded - `thinning` (default: `1`): factor by which to thin samples. +There is no "official" way for providing initial parameter values yet. +However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain. +To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): +- `init_params` (default: `nothing`): if set to `init_params !== nothing`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = Iterators.repeated(x)` or `init_params = FillArrays.Fill(x, N)`. + Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`. ```@docs diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 686924a8..3e8e2ff2 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -84,6 +84,5 @@ include("interface.jl") include("sample.jl") include("stepper.jl") include("transducer.jl") -include("deprecations.jl") end # module AbstractMCMC diff --git a/src/deprecations.jl b/src/deprecations.jl deleted file mode 100644 index 128f16d1..00000000 --- a/src/deprecations.jl +++ /dev/null @@ -1,2 +0,0 @@ -# Deprecate the old name AbstractMCMCParallel in favor of AbstractMCMCEnsemble -Base.@deprecate_binding AbstractMCMCParallel AbstractMCMCEnsemble false diff --git a/src/sample.jl b/src/sample.jl index b6fad3fe..01548470 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -283,6 +283,7 @@ function mcmcsample( nchains::Integer; progress=PROGRESS[], progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)", + init_params=nothing, kwargs..., ) # Check if actually multiple threads are used. @@ -298,7 +299,7 @@ function mcmcsample( # Copy the random number generator, model, and sample for each thread nchunks = min(nchains, Threads.nthreads()) chunksize = cld(nchains, nchunks) - interval = 1:min(nchains, Threads.nthreads()) + interval = 1:nchunks rngs = [deepcopy(rng) for _ in interval] models = [deepcopy(model) for _ in interval] samplers = [deepcopy(sampler) for _ in interval] @@ -306,6 +307,9 @@ function mcmcsample( # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) + # Ensure that initial parameters are `nothing` or indexable + _init_params = _first_or_nothing(init_params, nchains) + # Set up a chains vector. chains = Vector{Any}(undef, nchains) @@ -350,7 +354,17 @@ function mcmcsample( # Sample a chain and save it to the vector. chains[chainidx] = StatsBase.sample( - _rng, _model, _sampler, N; progress=false, kwargs... + _rng, + _model, + _sampler, + N; + progress=false, + init_params=if _init_params === nothing + nothing + else + _init_params[chainidx] + end, + kwargs..., ) # Update the progress bar. @@ -378,6 +392,7 @@ function mcmcsample( nchains::Integer; progress=PROGRESS[], progressname="Sampling ($(Distributed.nworkers()) processes)", + init_params=nothing, kwargs..., ) # Check if actually multiple processes are used. @@ -425,13 +440,19 @@ function mcmcsample( Distributed.@async begin try - chains = Distributed.pmap(pool, seeds) do seed + function sample_chain(seed, init_params=nothing) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) # Sample a chain. chain = StatsBase.sample( - rng, model, sampler, N; progress=false, kwargs... + rng, + model, + sampler, + N; + progress=false, + init_params=init_params, + kwargs..., ) # Update the progress bar. @@ -440,6 +461,11 @@ function mcmcsample( # Return the new chain. return chain end + chains = if init_params === nothing + Distributed.pmap(sample_chain, pool, seeds) + else + Distributed.pmap(sample_chain, pool, seeds, init_params) + end finally # Stop updating the progress bar. progress && put!(channel, false) @@ -460,6 +486,7 @@ function mcmcsample( N::Integer, nchains::Integer; progressname="Sampling", + init_params=nothing, kwargs..., ) # Check if the number of chains is larger than the number of samples @@ -471,21 +498,60 @@ function mcmcsample( seeds = rand(rng, UInt, nchains) # Sample the chains. - chains = map(enumerate(seeds)) do (i, seed) + function sample_chain(i, seed, init_params=nothing) + # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) + + # Sample a chain. return StatsBase.sample( rng, model, sampler, N; progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"), + init_params=init_params, kwargs..., ) end + chains = if init_params === nothing + map(sample_chain, 1:nchains, seeds) + else + map(sample_chain, 1:nchains, seeds, init_params) + end + # Concatenate the chains together. return chainsstack(tighten_eltype(chains)) end tighten_eltype(x) = x tighten_eltype(x::Vector{Any}) = map(identity, x) + +""" + _first_or_nothing(x, n::Int) + +Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`. + +If `x !== nothing`, then `x` has to contain at least `n` elements. +""" +function _first_or_nothing(x, n::Int) + y = _first(x, n) + length(y) == n || throw( + ArgumentError("not enough initial parameters (expected $n, received $(length(y))"), + ) + return y +end +_first_or_nothing(::Nothing, ::Int) = nothing + +# `first(x, n::Int)` requires Julia 1.6 +function _first(x, n::Int) + @static if VERSION >= v"1.6.0-DEV.431" + first(x, n) + else + if x isa AbstractVector + @inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))] + else + collect(Iterators.take(x, n)) + end + end +end diff --git a/test/deprecations.jl b/test/deprecations.jl deleted file mode 100644 index dd53cb42..00000000 --- a/test/deprecations.jl +++ /dev/null @@ -1,4 +0,0 @@ -@testset "deprecations.jl" begin - @test_deprecated AbstractMCMC.transitions(MySample(1, 2.0), MyModel(), MySampler()) - @test_deprecated AbstractMCMC.transitions(MySample(1, 2.0), MyModel(), MySampler(), 3) -end diff --git a/test/runtests.jl b/test/runtests.jl index e8f09589..3baef78c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,5 +22,4 @@ include("utils.jl") include("sample.jl") include("stepper.jl") include("transducer.jl") - include("deprecations.jl") end diff --git a/test/sample.jl b/test/sample.jl index debb2238..f5a69c12 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -25,6 +25,13 @@ @test var(x.a for x in tail_chain) ≈ 1 / 12 atol = 5e-3 @test mean(x.b for x in tail_chain) ≈ 0.0 atol = 5e-2 @test var(x.b for x in tail_chain) ≈ 1 atol = 6e-2 + + # initial parameters + chain = sample( + MyModel(), MySampler(), 3; progress=false, init_params=(b=3.2, a=-1.8) + ) + @test chain[1].a == -1.8 + @test chain[1].b == 3.2 end @testset "Juno" begin @@ -168,6 +175,38 @@ if Threads.nthreads() == 2 sample(MyModel(), MySampler(), MCMCThreads(), N, 1) end + + # initial parameters + init_params = [(b=randn(), a=rand()) for _ in 1:100] + chains = sample( + MyModel(), + MySampler(), + MCMCThreads(), + 3, + 100; + progress=false, + init_params=init_params, + ) + @test length(chains) == 100 + @test all( + chain[1].a == params.a && chain[1].b == params.b for + (chain, params) in zip(chains, init_params) + ) + + init_params = (a=randn(), b=rand()) + chains = sample( + MyModel(), + MySampler(), + MCMCThreads(), + 3, + 100; + progress=false, + init_params=Iterators.repeated(init_params), + ) + @test length(chains) == 100 + @test all( + chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + ) end @testset "Multicore sampling" begin @@ -244,6 +283,38 @@ ) end @test all(l.level > Logging.LogLevel(-1) for l in logs) + + # initial parameters + init_params = [(a=randn(), b=rand()) for _ in 1:100] + chains = sample( + MyModel(), + MySampler(), + MCMCDistributed(), + 3, + 100; + progress=false, + init_params=init_params, + ) + @test length(chains) == 100 + @test all( + chain[1].a == params.a && chain[1].b == params.b for + (chain, params) in zip(chains, init_params) + ) + + init_params = (b=randn(), a=rand()) + chains = sample( + MyModel(), + MySampler(), + MCMCDistributed(), + 3, + 100; + progress=false, + init_params=Iterators.repeated(init_params), + ) + @test length(chains) == 100 + @test all( + chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + ) end @testset "Serial sampling" begin @@ -295,6 +366,38 @@ ) end @test all(l.level > Logging.LogLevel(-1) for l in logs) + + # initial parameters + init_params = [(a=rand(), b=randn()) for _ in 1:100] + chains = sample( + MyModel(), + MySampler(), + MCMCSerial(), + 3, + 100; + progress=false, + init_params=init_params, + ) + @test length(chains) == 100 + @test all( + chain[1].a == params.a && chain[1].b == params.b for + (chain, params) in zip(chains, init_params) + ) + + init_params = (b=rand(), a=randn()) + chains = sample( + MyModel(), + MySampler(), + MCMCSerial(), + 3, + 100; + progress=false, + init_params=Iterators.repeated(init_params), + ) + @test length(chains) == 100 + @test all( + chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + ) end @testset "Ensemble sampling: Reproducibility" begin diff --git a/test/utils.jl b/test/utils.jl index 32474639..67ba1481 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -23,11 +23,15 @@ function AbstractMCMC.step( state::Union{Nothing,Integer}=nothing; sleepy=false, loggers=false, + init_params=nothing, kwargs..., ) - # sample `a` is missing in the first step - a = state === nothing ? missing : rand(rng) - b = randn(rng) + # sample `a` is missing in the first step if not provided + a, b = if state === nothing && init_params !== nothing + init_params.a, init_params.b + else + (state === nothing ? missing : rand(rng)), randn(rng) + end loggers && push!(LOGGERS, Logging.current_logger()) sleepy && sleep(0.001) From 650d9e1f23ddb9254bc98cabd7c25cf54c669092 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sat, 28 May 2022 11:12:24 +0200 Subject: [PATCH 20/54] CompatHelper: bump compat for "LoggingExtras" to "0.5" (#101) * CompatHelper: bump compat for "LoggingExtras" to "0.5" * Update Project.toml Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 9690fd8d..6dce9f2c 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.0.0" +version = "4.1.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" @@ -20,7 +20,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] BangBang = "0.3.19" ConsoleProgressMonitor = "0.1" -LoggingExtras = "0.4" +LoggingExtras = "0.4, 0.5" ProgressLogging = "0.1" StatsBase = "0.32, 0.33" TerminalLoggers = "0.1" From 8d7f22f5a047a16b6870ebb15c0090331db8dcaa Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 1 Jun 2022 13:13:30 +0200 Subject: [PATCH 21/54] Fix `discard_initial`, and add support for `discard_initial` and `thinning` to iterator and transducer (#102) * Fix `discard_initial`, and add support for `discard_initial` and `thinning` to iterator and transducer * Fix test errors on Julia < 1.6 * Only enable progress logging on Julia < 1.6 * Use different seed * Update api.md * Update api.md * Update sample.jl * Use `==` instead of `===` --- Project.toml | 2 +- docs/src/api.md | 6 ++- src/sample.jl | 4 +- src/stepper.jl | 32 +++++++++++++- src/transducer.jl | 52 +++++++++++++++++++---- test/sample.jl | 101 +++++++++++++++++++++++++++++++++------------ test/stepper.jl | 42 +++++++++++++++++++ test/transducer.jl | 46 +++++++++++++++++++++ 8 files changed, 243 insertions(+), 42 deletions(-) diff --git a/Project.toml b/Project.toml index 6dce9f2c..69a7fb83 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.1.0" +version = "4.1.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/docs/src/api.md b/docs/src/api.md index c7451cc5..9ce28805 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -43,8 +43,7 @@ AbstractMCMC.MCMCSerial ## Common keyword arguments -Common keyword arguments for regular and parallel sampling (not supported by the iterator and transducer) -are: +Common keyword arguments for regular and parallel sampling are: - `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging - `chain_type` (default: `Any`): determines the type of the returned chain - `callback` (default: `nothing`): if `callback !== nothing`, then @@ -53,6 +52,9 @@ are: - `discard_initial` (default: `0`): number of initial samples that are discarded - `thinning` (default: `1`): factor by which to thin samples. +!!! info + The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref). + There is no "official" way for providing initial parameter values yet. However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain. To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): diff --git a/src/sample.jl b/src/sample.jl index 01548470..3b578020 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -120,7 +120,7 @@ function mcmcsample( sample, state = step(rng, model, sampler; kwargs...) # Discard initial samples. - for i in 1:(discard_initial - 1) + for i in 1:discard_initial # Update the progress bar. if progress && i >= next_update ProgressLogging.@logprogress i / Ntotal @@ -218,7 +218,7 @@ function mcmcsample( sample, state = step(rng, model, sampler; kwargs...) # Discard initial samples. - for _ in 2:discard_initial + for _ in 1:discard_initial # Obtain the next sample and state. sample, state = step(rng, model, sampler, state; kwargs...) end diff --git a/src/stepper.jl b/src/stepper.jl index 18867c58..e7c97eed 100644 --- a/src/stepper.jl +++ b/src/stepper.jl @@ -5,9 +5,37 @@ struct Stepper{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} kwargs::K end -Base.iterate(stp::Stepper) = step(stp.rng, stp.model, stp.sampler; stp.kwargs...) +# Initial sample. +function Base.iterate(stp::Stepper) + # Unpack iterator. + rng = stp.rng + model = stp.model + sampler = stp.sampler + kwargs = stp.kwargs + discard_initial = get(kwargs, :discard_initial, 0)::Int + + # Start sampling algorithm and discard initial samples if desired. + sample, state = step(rng, model, sampler; kwargs...) + for _ in 1:discard_initial + sample, state = step(rng, model, sampler, state; kwargs...) + end + return sample, state +end + +# Subsequent samples. function Base.iterate(stp::Stepper, state) - return step(stp.rng, stp.model, stp.sampler, state; stp.kwargs...) + # Unpack iterator. + rng = stp.rng + model = stp.model + sampler = stp.sampler + kwargs = stp.kwargs + thinning = get(kwargs, :thinning, 1)::Int + + # Return next sample, possibly after thinning the chain if desired. + for _ in 1:(thinning - 1) + _, state = step(rng, model, sampler, state; kwargs...) + end + return step(rng, model, sampler, state; kwargs...) end Base.IteratorSize(::Type{<:Stepper}) = Base.IsInfinite() diff --git a/src/transducer.jl b/src/transducer.jl index 51f9b358..42df6dba 100644 --- a/src/transducer.jl +++ b/src/transducer.jl @@ -40,24 +40,58 @@ function Sample( return Sample(rng, model, sampler, kwargs) end +# Initial sample. function Transducers.start(rf::Transducers.R_{<:Sample}, result) - sampler = Transducers.xform(rf) + # Unpack transducer. + td = Transducers.xform(rf) + rng = td.rng + model = td.model + sampler = td.sampler + kwargs = td.kwargs + discard_initial = get(kwargs, :discard_initial, 0)::Int + + # Start sampling algorithm and discard initial samples if desired. + sample, state = step(rng, model, sampler; kwargs...) + for _ in 1:discard_initial + sample, state = step(rng, model, sampler, state; kwargs...) + end + return Transducers.wrap( - rf, - step(sampler.rng, sampler.model, sampler.sampler; sampler.kwargs...), - Transducers.start(Transducers.inner(rf), result), + rf, (sample, state), Transducers.start(Transducers.inner(rf), result) ) end +# Subsequent samples. function Transducers.next(rf::Transducers.R_{<:Sample}, result, input) - t = Transducers.xform(rf) - Transducers.wrapping(rf, result) do (sample, state), iresult - iresult2 = Transducers.next(Transducers.inner(rf), iresult, sample) - return step(t.rng, t.model, t.sampler, state; t.kwargs...), iresult2 + # Unpack transducer. + td = Transducers.xform(rf) + rng = td.rng + model = td.model + sampler = td.sampler + kwargs = td.kwargs + thinning = get(kwargs, :thinning, 1)::Int + + let rng = rng, + model = model, + sampler = sampler, + kwargs = kwargs, + thinning = thinning, + inner_rf = Transducers.inner(rf) + + Transducers.wrapping(rf, result) do (sample, state), iresult + iresult2 = Transducers.next(inner_rf, iresult, sample) + + # Perform thinning if desired. + for _ in 1:(thinning - 1) + _, state = step(rng, model, sampler, state; kwargs...) + end + + return step(rng, model, sampler, state; kwargs...), iresult2 + end end end function Transducers.complete(rf::Transducers.R_{Sample}, result) - _private_state, inner_result = Transducers.unwrap(rf, result) + _, inner_result = Transducers.unwrap(rf, result) return Transducers.complete(Transducers.inner(rf), inner_result) end diff --git a/test/sample.jl b/test/sample.jl index f5a69c12..cf080321 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -137,6 +137,7 @@ @test chains isa Vector{<:MyChain} @test length(chains) == 1000 @test all(x -> length(x.as) == length(x.bs) == N, chains) + @test all(ismissing(x.as[1]) for x in chains) # test some statistical properties @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) @@ -147,9 +148,9 @@ # test reproducibility Random.seed!(1234) chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; chain_type=MyChain) - - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + @test all(ismissing(x.as[1]) for x in chains2) + @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N) + @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) # Unexpected order of arguments. str = "Number of chains (10) is greater than number of samples per chain (5)" @@ -245,7 +246,7 @@ # Test output type and size. @test chains isa Vector{<:MyChain} - @test all(c.as[1] === missing for c in chains) + @test all(ismissing(c.as[1]) for c in chains) @test length(chains) == 1000 @test all(x -> length(x.as) == length(x.bs) == N, chains) @@ -260,9 +261,9 @@ chains2 = sample( MyModel(), MySampler(), MCMCDistributed(), N, 1000; chain_type=MyChain ) - - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + @test all(ismissing(c.as[1]) for c in chains2) + @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N) + @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) # Unexpected order of arguments. str = "Number of chains (10) is greater than number of samples per chain (5)" @@ -330,7 +331,7 @@ # Test output type and size. @test chains isa Vector{<:MyChain} - @test all(c.as[1] === missing for c in chains) + @test all(ismissing(c.as[1]) for c in chains) @test length(chains) == 1000 @test all(x -> length(x.as) == length(x.bs) == N, chains) @@ -343,9 +344,9 @@ # Test reproducibility. Random.seed!(1234) chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) - - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + @test all(ismissing(c.as[1]) for c in chains2) + @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N) + @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) # Unexpected order of arguments. str = "Number of chains (10) is greater than number of samples per chain (5)" @@ -415,6 +416,7 @@ progress=false, chain_type=MyChain, ) + @test all(ismissing(c.as[1]) for c in chains_serial) # Multi-threaded sampling Random.seed!(1234) @@ -427,12 +429,13 @@ progress=false, chain_type=MyChain, ) + @test all(ismissing(c.as[1]) for c in chains_threads) @test all( - c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads), - i in 1:N + c1.as[i] == c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads), + i in 2:N ) @test all( - c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads), + c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N ) @@ -447,12 +450,13 @@ progress=false, chain_type=MyChain, ) + @test all(ismissing(c.as[1]) for c in chains_distributed) @test all( - c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed), - i in 1:N + c1.as[i] == c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed), + i in 2:N ) @test all( - c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed), + c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N ) end @@ -473,24 +477,41 @@ end @testset "Discard initial samples" begin - chain = sample(MyModel(), MySampler(), 100; sleepy=true, discard_initial=50) - @test length(chain) == 100 + # Create a chain and discard initial samples. + Random.seed!(1234) + N = 100 + discard_initial = 50 + chain = sample(MyModel(), MySampler(), N; discard_initial=discard_initial) + @test length(chain) == N @test !ismissing(chain[1].a) + + # Repeat sampling without discarding initial samples. + # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here. + # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258 + Random.seed!(1234) + ref_chain = sample( + MyModel(), MySampler(), N + discard_initial; progress=VERSION < v"1.6" + ) + @test all(chain[i].a == ref_chain[i + discard_initial].a for i in 1:N) + @test all(chain[i].b == ref_chain[i + discard_initial].b for i in 1:N) end @testset "Thin chain by a factor of `thinning`" begin # Run a thinned chain with `N` samples thinned by factor of `thinning`. - Random.seed!(1234) + Random.seed!(100) N = 100 thinning = 3 - chain = sample(MyModel(), MySampler(), N; sleepy=true, thinning=thinning) + chain = sample(MyModel(), MySampler(), N; thinning=thinning) @test length(chain) == N @test ismissing(chain[1].a) # Repeat sampling without thinning. - Random.seed!(1234) - ref_chain = sample(MyModel(), MySampler(), N * thinning; sleepy=true) - @test all(chain[i].a === ref_chain[(i - 1) * thinning + 1].a for i in 1:N) + # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here. + # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258 + Random.seed!(100) + ref_chain = sample(MyModel(), MySampler(), N * thinning; progress=VERSION < v"1.6") + @test all(chain[i].a == ref_chain[(i - 1) * thinning + 1].a for i in 2:N) + @test all(chain[i].b == ref_chain[(i - 1) * thinning + 1].b for i in 1:N) end @testset "Sample without predetermined N" begin @@ -501,16 +522,44 @@ @test abs(bmean) <= 0.001 || length(chain) == 10_000 # Discard initial samples. - chain = sample(MyModel(), MySampler(); discard_initial=50) + Random.seed!(1234) + discard_initial = 50 + chain = sample(MyModel(), MySampler(); discard_initial=discard_initial) bmean = mean(x.b for x in chain) @test !ismissing(chain[1].a) @test abs(bmean) <= 0.001 || length(chain) == 10_000 + # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here. + # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258 + Random.seed!(1234) + N = length(chain) + ref_chain = sample( + MyModel(), + MySampler(), + N; + discard_initial=discard_initial, + progress=VERSION < v"1.6", + ) + @test all(chain[i].a == ref_chain[i].a for i in 1:N) + @test all(chain[i].b == ref_chain[i].b for i in 1:N) + # Thin chain by a factor of `thinning`. - chain = sample(MyModel(), MySampler(); thinning=3) + Random.seed!(1234) + thinning = 3 + chain = sample(MyModel(), MySampler(); thinning=thinning) bmean = mean(x.b for x in chain) @test ismissing(chain[1].a) @test abs(bmean) <= 0.001 || length(chain) == 10_000 + + # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here. + # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258 + Random.seed!(1234) + N = length(chain) + ref_chain = sample( + MyModel(), MySampler(), N; thinning=thinning, progress=VERSION < v"1.6" + ) + @test all(chain[i].a == ref_chain[i].a for i in 2:N) + @test all(chain[i].b == ref_chain[i].b for i in 1:N) end @testset "Sample vector of `NamedTuple`s" begin diff --git a/test/stepper.jl b/test/stepper.jl index 1b570557..bc0ea8b2 100644 --- a/test/stepper.jl +++ b/test/stepper.jl @@ -29,4 +29,46 @@ @test Base.IteratorSize(iter) == Base.IsInfinite() @test Base.IteratorEltype(iter) == Base.EltypeUnknown() end + + @testset "Discard initial samples" begin + # Create a chain of `N` samples after discarding some initial samples. + Random.seed!(1234) + N = 50 + discard_initial = 10 + iter = AbstractMCMC.steps(MyModel(), MySampler(); discard_initial=discard_initial) + as = [] + bs = [] + for t in Iterators.take(iter, N) + push!(as, t.a) + push!(bs, t.b) + end + + # Repeat sampling with `sample`. + Random.seed!(1234) + chain = sample( + MyModel(), MySampler(), N; discard_initial=discard_initial, progress=false + ) + @test all(as[i] === chain[i].a for i in 1:N) + @test all(bs[i] === chain[i].b for i in 1:N) + end + + @testset "Thin chain by a factor of `thinning`" begin + # Create a thinned chain with a thinning factor of `thinning`. + Random.seed!(1234) + N = 50 + thinning = 3 + iter = AbstractMCMC.steps(MyModel(), MySampler(); thinning=thinning) + as = [] + bs = [] + for t in Iterators.take(iter, N) + push!(as, t.a) + push!(bs, t.b) + end + + # Repeat sampling with `sample`. + Random.seed!(1234) + chain = sample(MyModel(), MySampler(), N; thinning=thinning, progress=false) + @test all(as[i] === chain[i].a for i in 1:N) + @test all(bs[i] === chain[i].b for i in 1:N) + end end diff --git a/test/transducer.jl b/test/transducer.jl index f9e1a049..c534ac90 100644 --- a/test/transducer.jl +++ b/test/transducer.jl @@ -49,4 +49,50 @@ @test mean(bs) ≈ 0.0 atol = 5e-2 @test var(bs) ≈ 1 atol = 5e-2 end + + @testset "Discard initial samples" begin + # Create a chain of `N` samples after discarding some initial samples. + Random.seed!(1234) + N = 50 + discard_initial = 10 + xf = opcompose( + AbstractMCMC.Sample(MyModel(), MySampler(); discard_initial=discard_initial), + Map(x -> (x.a, x.b)), + ) + as, bs = foldl(xf, 1:N; init=([], [])) do (as, bs), (a, b) + push!(as, a) + push!(bs, b) + as, bs + end + + # Repeat sampling with `sample`. + Random.seed!(1234) + chain = sample( + MyModel(), MySampler(), N; discard_initial=discard_initial, progress=false + ) + @test all(as[i] === chain[i].a for i in 1:N) + @test all(bs[i] === chain[i].b for i in 1:N) + end + + @testset "Thin chain by a factor of `thinning`" begin + # Create a thinned chain with a thinning factor of `thinning`. + Random.seed!(1234) + N = 50 + thinning = 3 + xf = opcompose( + AbstractMCMC.Sample(MyModel(), MySampler(); thinning=thinning), + Map(x -> (x.a, x.b)), + ) + as, bs = foldl(xf, 1:N; init=([], [])) do (as, bs), (a, b) + push!(as, a) + push!(bs, b) + as, bs + end + + # Repeat sampling with `sample`. + Random.seed!(1234) + chain = sample(MyModel(), MySampler(), N; thinning=thinning, progress=false) + @test all(as[i] === chain[i].a for i in 1:N) + @test all(bs[i] === chain[i].b for i in 1:N) + end end From df8cb0e370d76d456591c342490c94ad1a925a80 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 1 Jun 2022 13:25:47 +0200 Subject: [PATCH 22/54] Update transducer.jl --- test/transducer.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/transducer.jl b/test/transducer.jl index c534ac90..b161151c 100644 --- a/test/transducer.jl +++ b/test/transducer.jl @@ -70,8 +70,8 @@ chain = sample( MyModel(), MySampler(), N; discard_initial=discard_initial, progress=false ) - @test all(as[i] === chain[i].a for i in 1:N) - @test all(bs[i] === chain[i].b for i in 1:N) + @test all(as[i] == chain[i].a for i in 1:N) + @test all(bs[i] == chain[i].b for i in 1:N) end @testset "Thin chain by a factor of `thinning`" begin @@ -92,7 +92,8 @@ # Repeat sampling with `sample`. Random.seed!(1234) chain = sample(MyModel(), MySampler(), N; thinning=thinning, progress=false) - @test all(as[i] === chain[i].a for i in 1:N) - @test all(bs[i] === chain[i].b for i in 1:N) + @test as[1] === chain[1].a === missing + @test all(as[i] == chain[i].a for i in 2:N) + @test all(bs[i] == chain[i].b for i in 1:N) end end From aa82c2430bea33c7ac55787d971ccdfc7604e44e Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 1 Jun 2022 13:26:47 +0200 Subject: [PATCH 23/54] Update stepper.jl --- test/stepper.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/stepper.jl b/test/stepper.jl index bc0ea8b2..80143344 100644 --- a/test/stepper.jl +++ b/test/stepper.jl @@ -48,8 +48,8 @@ chain = sample( MyModel(), MySampler(), N; discard_initial=discard_initial, progress=false ) - @test all(as[i] === chain[i].a for i in 1:N) - @test all(bs[i] === chain[i].b for i in 1:N) + @test all(as[i] == chain[i].a for i in 1:N) + @test all(bs[i] == chain[i].b for i in 1:N) end @testset "Thin chain by a factor of `thinning`" begin @@ -68,7 +68,8 @@ # Repeat sampling with `sample`. Random.seed!(1234) chain = sample(MyModel(), MySampler(), N; thinning=thinning, progress=false) - @test all(as[i] === chain[i].a for i in 1:N) - @test all(bs[i] === chain[i].b for i in 1:N) + @test as[1] === chain[1].a === missing + @test all(as[i] == chain[i].a for i in 2:N) + @test all(bs[i] == chain[i].b for i in 1:N) end end From 346c6a27b6d9c3df99229a35e570c9bbc62043a5 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 1 Jun 2022 16:40:36 +0200 Subject: [PATCH 24/54] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 69a7fb83..fa2220b7 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.1.1" +version = "4.1.2" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From a9ba4d3b1c0314393532dbd792befa55750d9a0f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 8 Jun 2022 10:42:17 +0200 Subject: [PATCH 25/54] Replace `GLOBAL_RNG` with `default_rng()` (#104) * Replace `GLOBAL_RNG` with `default_rng()` * Update utils.jl * Update sample.jl --- Project.toml | 2 +- src/sample.jl | 4 ++-- src/stepper.jl | 2 +- src/transducer.jl | 2 +- test/sample.jl | 16 ++++++++-------- test/utils.jl | 4 +--- 6 files changed, 14 insertions(+), 16 deletions(-) diff --git a/Project.toml b/Project.toml index fa2220b7..028a8bb6 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.1.2" +version = "4.1.3" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/src/sample.jl b/src/sample.jl index 3b578020..c6b0112f 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -13,7 +13,7 @@ function setprogress!(progress::Bool) end function StatsBase.sample(model::AbstractModel, sampler::AbstractSampler, arg; kwargs...) - return StatsBase.sample(Random.GLOBAL_RNG, model, sampler, arg; kwargs...) + return StatsBase.sample(Random.default_rng(), model, sampler, arg; kwargs...) end """ @@ -63,7 +63,7 @@ function StatsBase.sample( kwargs..., ) return StatsBase.sample( - Random.GLOBAL_RNG, model, sampler, parallel, N, nchains; kwargs... + Random.default_rng(), model, sampler, parallel, N, nchains; kwargs... ) end diff --git a/src/stepper.jl b/src/stepper.jl index e7c97eed..68059926 100644 --- a/src/stepper.jl +++ b/src/stepper.jl @@ -42,7 +42,7 @@ Base.IteratorSize(::Type{<:Stepper}) = Base.IsInfinite() Base.IteratorEltype(::Type{<:Stepper}) = Base.EltypeUnknown() function steps(model::AbstractModel, sampler::AbstractSampler; kwargs...) - return steps(Random.GLOBAL_RNG, model, sampler; kwargs...) + return steps(Random.default_rng(), model, sampler; kwargs...) end """ diff --git a/src/transducer.jl b/src/transducer.jl index 42df6dba..46d36d91 100644 --- a/src/transducer.jl +++ b/src/transducer.jl @@ -7,7 +7,7 @@ struct Sample{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} <: end function Sample(model::AbstractModel, sampler::AbstractSampler; kwargs...) - return Sample(Random.GLOBAL_RNG, model, sampler; kwargs...) + return Sample(Random.default_rng(), model, sampler; kwargs...) end """ diff --git a/test/sample.jl b/test/sample.jl index cf080321..97bf5a5e 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -5,7 +5,7 @@ Random.seed!(1234) N = 1_000 - chain = sample(MyModel(), MySampler(), N; sleepy=true, loggers=true) + chain = sample(MyModel(), MySampler(), N; loggers=true) @test length(LOGGERS) == 1 logger = first(LOGGERS) @@ -42,7 +42,7 @@ logger = JunoProgressLogger() Logging.with_logger(logger) do - sample(MyModel(), MySampler(), N; sleepy=true, loggers=true) + sample(MyModel(), MySampler(), N; loggers=true) end @test length(LOGGERS) == 1 @@ -60,7 +60,7 @@ Random.seed!(1234) N = 10 - sample(MyModel(), MySampler(), N; sleepy=true, loggers=true) + sample(MyModel(), MySampler(), N; loggers=true) @test length(LOGGERS) == 1 logger = first(LOGGERS) @@ -82,7 +82,7 @@ logger = Logging.ConsoleLogger(stderr, Logging.LogLevel(-1)) Logging.with_logger(logger) do - sample(MyModel(), MySampler(), N; sleepy=true, loggers=true) + sample(MyModel(), MySampler(), N; loggers=true) end @test length(LOGGERS) == 1 @@ -92,7 +92,7 @@ @testset "Suppress output" begin logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), 100; progress=false, sleepy=true) + sample(MyModel(), MySampler(), 100; progress=false) end @test all(l.level > Logging.LogLevel(-1) for l in logs) @@ -103,7 +103,7 @@ @test !AbstractMCMC.PROGRESS[] logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), 100; sleepy=true) + sample(MyModel(), MySampler(), 100) end @test all(l.level > Logging.LogLevel(-1) for l in logs) @@ -462,8 +462,8 @@ end @testset "Chain constructors" begin - chain1 = sample(MyModel(), MySampler(), 100; sleepy=true) - chain2 = sample(MyModel(), MySampler(), 100; sleepy=true, chain_type=MyChain) + chain1 = sample(MyModel(), MySampler(), 100) + chain2 = sample(MyModel(), MySampler(), 100; chain_type=MyChain) @test chain1 isa Vector{<:MySample} @test chain2 isa MyChain diff --git a/test/utils.jl b/test/utils.jl index 67ba1481..e2eedcb4 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -21,7 +21,6 @@ function AbstractMCMC.step( model::MyModel, sampler::MySampler, state::Union{Nothing,Integer}=nothing; - sleepy=false, loggers=false, init_params=nothing, kwargs..., @@ -34,7 +33,6 @@ function AbstractMCMC.step( end loggers && push!(LOGGERS, Logging.current_logger()) - sleepy && sleep(0.001) _state = state === nothing ? 1 : state + 1 @@ -72,7 +70,7 @@ end # Set a default convergence function. function AbstractMCMC.sample(model, sampler::MySampler; kwargs...) - return sample(Random.GLOBAL_RNG, model, sampler, isdone; kwargs...) + return sample(Random.default_rng(), model, sampler, isdone; kwargs...) end function AbstractMCMC.chainscat( From 6ef1dcbf954261ab506f52df46fb2454bbde85cf Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 26 Aug 2022 10:51:12 +0200 Subject: [PATCH 26/54] Update CompatHelper.yml (#106) --- .github/workflows/CompatHelper.yml | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index f5da2d24..23e85888 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -3,16 +3,35 @@ on: schedule: - cron: 0 0 * * * workflow_dispatch: +permissions: + contents: write + pull-requests: write jobs: CompatHelper: runs-on: ubuntu-latest steps: + - name: Check if Julia is already available in the PATH + id: julia_in_path + run: which julia + continue-on-error: true + - name: Install Julia, but only if it is not already available in the PATH + uses: julia-actions/setup-julia@v1 + with: + version: '1' + arch: ${{ runner.arch }} + if: steps.julia_in_path.outcome != 'success' + - name: "Add the General registry via Git" + run: | + import Pkg + ENV["JULIA_PKG_SERVER"] = "" + Pkg.Registry.add("General") + shell: julia --color=yes {0} - name: "Install CompatHelper" run: | import Pkg name = "CompatHelper" uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" - version = "2" + version = "3" Pkg.add(; name, uuid, version) shell: julia --color=yes {0} - name: "Run CompatHelper" From 18ea3f156e93177ac3bc0c6504232a334ad7ee83 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Dec 2022 14:42:10 +0000 Subject: [PATCH 27/54] Add AbstractModel implementation of LogDensityProblems interface (#110) * added AbstractModel implementing the interface of LogDensityProblems * version bump and add LogDensityProblems as a dep * Update Project.toml Co-authored-by: David Widmann * removed forwarding of methods for LogDensityModel * Apply suggestions from code review Co-authored-by: David Widmann * bump julia compat to 1.6 Co-authored-by: David Widmann --- .github/workflows/CI.yml | 4 ++-- Project.toml | 4 ++-- src/AbstractMCMC.jl | 1 + src/logdensityproblems.jl | 15 +++++++++++++++ 4 files changed, 20 insertions(+), 4 deletions(-) create mode 100644 src/logdensityproblems.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d5b1273a..06ff8ad9 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: version: - - '1.3' + - '1.6' - '1' - nightly os: @@ -31,7 +31,7 @@ jobs: arch: x86 - os: macOS-latest arch: x86 - - version: '1.3' + - version: '1.6' num_threads: 2 include: - version: '1' diff --git a/Project.toml b/Project.toml index 028a8bb6..4053aa83 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.1.3" +version = "4.2" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" @@ -25,7 +25,7 @@ ProgressLogging = "0.1" StatsBase = "0.32, 0.33" TerminalLoggers = "0.1" Transducers = "0.4.30" -julia = "1.3" +julia = "1.6" [extras] Atom = "c52e3926-4ff0-5f6e-af25-54175e0327b1" diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 3e8e2ff2..44f56a9c 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -84,5 +84,6 @@ include("interface.jl") include("sample.jl") include("stepper.jl") include("transducer.jl") +include("logdensityproblems.jl") end # module AbstractMCMC diff --git a/src/logdensityproblems.jl b/src/logdensityproblems.jl new file mode 100644 index 00000000..98615cde --- /dev/null +++ b/src/logdensityproblems.jl @@ -0,0 +1,15 @@ +""" + LogDensityModel <: AbstractMCMC.AbstractModel + +Wrapper around something that implements the LogDensityProblem.jl interface. + +Note that this does _not_ implement the LogDensityProblems.jl interface itself, +but it simply useful for indicating to the `sample` and other `AbstractMCMC` methods +that the wrapped object implements the LogDensityProblems.jl interface. + +# Fields +- `logdensity`: The object that implements the LogDensityProblems.jl interface. +""" +struct LogDensityModel{L} <: AbstractModel + logdensity::L +end From 50cdf04980ebed131d19768453a2ff3226dc068f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 29 Dec 2022 17:47:20 +0100 Subject: [PATCH 28/54] Check if log density supports LogDensityProblems (#111) * Check if log density supports LogDensityProblems * Load LogDensityProblems * Update Project.toml * Update src/logdensityproblems.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Add AdvancedHMC to downstream tests Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .github/workflows/IntegrationTest.yml | 1 + Project.toml | 4 +++- src/AbstractMCMC.jl | 1 + src/logdensityproblems.jl | 12 ++++++++++++ 4 files changed, 17 insertions(+), 1 deletion(-) diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index cd3b3658..2e9d6bcf 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -14,6 +14,7 @@ jobs: fail-fast: false matrix: package: + - {user: TuringLang, repo: AdvancedHMC.jl} - {user: TuringLang, repo: AdvancedMH.jl} - {user: TuringLang, repo: EllipticalSliceSampling.jl} - {user: TuringLang, repo: MCMCChains.jl} diff --git a/Project.toml b/Project.toml index 4053aa83..99ef46a1 100644 --- a/Project.toml +++ b/Project.toml @@ -3,12 +3,13 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.2" +version = "4.2.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" @@ -20,6 +21,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] BangBang = "0.3.19" ConsoleProgressMonitor = "0.1" +LogDensityProblems = "2" LoggingExtras = "0.4, 0.5" ProgressLogging = "0.1" StatsBase = "0.32, 0.33" diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 44f56a9c..64f20f97 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -2,6 +2,7 @@ module AbstractMCMC using BangBang: BangBang using ConsoleProgressMonitor: ConsoleProgressMonitor +using LogDensityProblems: LogDensityProblems using LoggingExtras: LoggingExtras using ProgressLogging: ProgressLogging using StatsBase: StatsBase diff --git a/src/logdensityproblems.jl b/src/logdensityproblems.jl index 98615cde..54db36bb 100644 --- a/src/logdensityproblems.jl +++ b/src/logdensityproblems.jl @@ -12,4 +12,16 @@ that the wrapped object implements the LogDensityProblems.jl interface. """ struct LogDensityModel{L} <: AbstractModel logdensity::L + function LogDensityModel{L}(logdensity::L) where {L} + if LogDensityProblems.capabilities(logdensity) === nothing + throw( + ArgumentError( + "The log density function does not support the LogDensityProblems.jl interface", + ), + ) + end + return new{L}(logdensity) + end end + +LogDensityModel(logdensity::L) where {L} = LogDensityModel{L}(logdensity) From 2d31f092d2f7230c5a32a552a2ae3e9a8f958277 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 9 Jan 2023 23:15:08 +0100 Subject: [PATCH 29/54] CompatHelper: bump compat for LoggingExtras to 1, (keep existing compat) (#114) * CompatHelper: bump compat for LoggingExtras to 1, (keep existing compat) * Update Project.toml Co-authored-by: CompatHelper Julia Co-authored-by: David Widmann --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 99ef46a1..2af00074 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.2.1" +version = "4.3.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" @@ -22,7 +22,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" BangBang = "0.3.19" ConsoleProgressMonitor = "0.1" LogDensityProblems = "2" -LoggingExtras = "0.4, 0.5" +LoggingExtras = "0.4, 0.5, 1" ProgressLogging = "0.1" StatsBase = "0.32, 0.33" TerminalLoggers = "0.1" From 33487da76d9874adb7bee1b0509d0a3172580c9a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 10 Jan 2023 13:51:27 +0100 Subject: [PATCH 30/54] Support log density functions as models (#113) * Update sample.jl * Update sample.jl * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update api.md * Update stepper.jl * Update transducer.jl * Update api.md * Update src/stepper.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/transducer.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Project.toml * Update src/sample.jl Co-authored-by: Tor Erlend Fjelde * Reorganize fallbacks * Add tests * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Project.toml * Define utilities on all workers * Update test/sample.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Tor Erlend Fjelde --- Project.toml | 2 +- docs/src/api.md | 26 ++++++++++- src/logdensityproblems.jl | 92 ++++++++++++++++++++++++++++++++++++++ src/sample.jl | 53 ++++++++++++---------- src/stepper.jl | 11 +++-- src/transducer.jl | 11 +++-- test/logdensityproblems.jl | 90 +++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 + test/sample.jl | 11 ++++- test/utils.jl | 32 +++++++++++++ 10 files changed, 296 insertions(+), 34 deletions(-) create mode 100644 test/logdensityproblems.jl diff --git a/Project.toml b/Project.toml index 2af00074..a6bf3e65 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.3.0" +version = "4.4.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/docs/src/api.md b/docs/src/api.md index 9ce28805..52c2c2e1 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -2,23 +2,39 @@ AbstractMCMC defines an interface for sampling Markov chains. +## Model + +```@docs +AbstractMCMC.AbstractModel +AbstractMCMC.LogDensityModel +``` + +## Sampler + +```@docs +AbstractMCMC.AbstractSampler +``` + ## Sampling a single chain ```@docs -AbstractMCMC.sample(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler, ::Integer) AbstractMCMC.sample(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler, ::Any) +AbstractMCMC.sample(::AbstractRNG, ::Any, ::AbstractMCMC.AbstractSampler, ::Any) + ``` ### Iterator ```@docs AbstractMCMC.steps(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler) +AbstractMCMC.steps(::AbstractRNG, ::Any, ::AbstractMCMC.AbstractSampler) ``` ### Transducer ```@docs AbstractMCMC.Sample(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler) +AbstractMCMC.Sample(::AbstractRNG, ::Any, ::AbstractMCMC.AbstractSampler) ``` ## Sampling multiple chains in parallel @@ -32,6 +48,14 @@ AbstractMCMC.sample( ::Integer, ::Integer, ) +AbstractMCMC.sample( + ::AbstractRNG, + ::Any, + ::AbstractMCMC.AbstractSampler, + ::AbstractMCMC.AbstractMCMCEnsemble, + ::Integer, + ::Integer, +) ``` Two algorithms are provided for parallel sampling with multiple threads and multiple processes, and one allows for the user to sample multiple chains in serial (no parallelization): diff --git a/src/logdensityproblems.jl b/src/logdensityproblems.jl index 54db36bb..f15f656a 100644 --- a/src/logdensityproblems.jl +++ b/src/logdensityproblems.jl @@ -25,3 +25,95 @@ struct LogDensityModel{L} <: AbstractModel end LogDensityModel(logdensity::L) where {L} = LogDensityModel{L}(logdensity) + +# Fallbacks: Wrap log density function in a model +""" + sample( + rng::Random.AbstractRNG=Random.default_rng(), + logdensity, + sampler::AbstractSampler, + N_or_isdone; + kwargs..., + ) + +Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `sample` with the resulting model instead of `logdensity`. + +The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface. +""" +function StatsBase.sample( + rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler, N_or_isdone; kwargs... +) + return StatsBase.sample(rng, _model(logdensity), sampler, N_or_isdone; kwargs...) +end + +""" + sample( + rng::Random.AbstractRNG=Random.default_rng(), + logdensity, + sampler::AbstractSampler, + parallel::AbstractMCMCEnsemble, + N::Integer, + nchains::Integer; + kwargs..., + ) + +Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `sample` with the resulting model instead of `logdensity`. + +The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface. +""" +function StatsBase.sample( + rng::Random.AbstractRNG, + logdensity, + sampler::AbstractSampler, + parallel::AbstractMCMCEnsemble, + N::Integer, + nchains::Integer; + kwargs..., +) + return StatsBase.sample( + rng, _model(logdensity), sampler, parallel, N, nchains; kwargs... + ) +end + +""" + steps( + rng::Random.AbstractRNG=Random.default_rng(), + logdensity, + sampler::AbstractSampler; + kwargs..., + ) + +Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `steps` with the resulting model instead of `logdensity`. + +The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface. +""" +function steps(rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler; kwargs...) + return steps(rng, _model(logdensity), sampler; kwargs...) +end + +""" + Sample( + rng::Random.AbstractRNG=Random.default_rng(), + logdensity, + sampler::AbstractSampler; + kwargs..., + ) + +Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `Sample` with the resulting model instead of `logdensity`. + +The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface. +""" +function Sample(rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler; kwargs...) + return Sample(rng, _model(logdensity), sampler; kwargs...) +end + +function _model(logdensity) + if LogDensityProblems.capabilities(logdensity) === nothing + throw( + ArgumentError( + "the log density function does not support the LogDensityProblems.jl interface. Please implement the interface or provide a model of type `AbstractMCMC.AbstractModel`", + ), + ) + end + return LogDensityModel(logdensity) +end diff --git a/src/sample.jl b/src/sample.jl index c6b0112f..dc951ca2 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -12,32 +12,29 @@ function setprogress!(progress::Bool) return progress end -function StatsBase.sample(model::AbstractModel, sampler::AbstractSampler, arg; kwargs...) - return StatsBase.sample(Random.default_rng(), model, sampler, arg; kwargs...) -end - -""" - sample([rng, ]model, sampler, N; kwargs...) - -Return `N` samples from the `model` with the Markov chain Monte Carlo `sampler`. -""" function StatsBase.sample( - rng::Random.AbstractRNG, - model::AbstractModel, - sampler::AbstractSampler, - N::Integer; - kwargs..., + model_or_logdensity, sampler::AbstractSampler, N_or_isdone; kwargs... ) - return mcmcsample(rng, model, sampler, N; kwargs...) + return StatsBase.sample( + Random.default_rng(), model_or_logdensity, sampler, N_or_isdone; kwargs... + ) end """ - sample([rng, ]model, sampler, isdone; kwargs...) + sample( + rng::Random.AbatractRNG=Random.default_rng(), + model::AbstractModel, + sampler::AbstractSampler, + N_or_isdone; + kwargs..., + ) + +Sample from the `model` with the Markov chain Monte Carlo `sampler` and return the samples. -Sample from the `model` with the Markov chain Monte Carlo `sampler` until a -convergence criterion `isdone` returns `true`, and return the samples. +If `N_or_isdone` is an `Integer`, exactly `N_or_isdone` samples are returned. -The function `isdone` has the signature +Otherwise, sampling is performed until a convergence criterion `N_or_isdone` returns `true`. +The convergence criterion has to be a function with the signature ```julia isdone(rng, model, sampler, samples, state, iteration; kwargs...) ``` @@ -48,14 +45,14 @@ function StatsBase.sample( rng::Random.AbstractRNG, model::AbstractModel, sampler::AbstractSampler, - isdone; + N_or_isdone; kwargs..., ) - return mcmcsample(rng, model, sampler, isdone; kwargs...) + return mcmcsample(rng, model, sampler, N_or_isdone; kwargs...) end function StatsBase.sample( - model::AbstractModel, + model_or_logdensity, sampler::AbstractSampler, parallel::AbstractMCMCEnsemble, N::Integer, @@ -63,12 +60,20 @@ function StatsBase.sample( kwargs..., ) return StatsBase.sample( - Random.default_rng(), model, sampler, parallel, N, nchains; kwargs... + Random.default_rng(), model_or_logdensity, sampler, parallel, N, nchains; kwargs... ) end """ - sample([rng, ]model, sampler, parallel, N, nchains; kwargs...) + sample( + rng::Random.AbstractRNG=Random.default_rng(), + model::AbstractModel, + sampler::AbstractSampler, + parallel::AbstractMCMCEnsemble, + N::Integer, + nchains::Integer; + kwargs..., + ) Sample `nchains` Monte Carlo Markov chains from the `model` with the `sampler` in parallel using the `parallel` algorithm, and combine them into a single chain. diff --git a/src/stepper.jl b/src/stepper.jl index 68059926..a71826cb 100644 --- a/src/stepper.jl +++ b/src/stepper.jl @@ -41,12 +41,17 @@ end Base.IteratorSize(::Type{<:Stepper}) = Base.IsInfinite() Base.IteratorEltype(::Type{<:Stepper}) = Base.EltypeUnknown() -function steps(model::AbstractModel, sampler::AbstractSampler; kwargs...) - return steps(Random.default_rng(), model, sampler; kwargs...) +function steps(model_or_logdensity, sampler::AbstractSampler; kwargs...) + return steps(Random.default_rng(), model_or_logdensity, sampler; kwargs...) end """ - steps([rng, ]model, sampler; kwargs...) + steps( + rng::Random.AbstractRNG=Random.default_rng(), + model::AbstractModel, + sampler::AbstractSampler; + kwargs..., + ) Create an iterator that returns samples from the `model` with the Markov chain Monte Carlo `sampler`. diff --git a/src/transducer.jl b/src/transducer.jl index 46d36d91..63bff3fd 100644 --- a/src/transducer.jl +++ b/src/transducer.jl @@ -6,12 +6,17 @@ struct Sample{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} <: kwargs::K end -function Sample(model::AbstractModel, sampler::AbstractSampler; kwargs...) - return Sample(Random.default_rng(), model, sampler; kwargs...) +function Sample(model_or_logdensity, sampler::AbstractSampler; kwargs...) + return Sample(Random.default_rng(), model_or_logdensity, sampler; kwargs...) end """ - Sample([rng, ]model, sampler; kwargs...) + Sample( + rng::Random.AbstractRNG=Random.default_rng(), + model::AbstractModel, + sampler::AbstractSampler; + kwargs..., + ) Create a transducer that returns samples from the `model` with the Markov chain Monte Carlo `sampler`. diff --git a/test/logdensityproblems.jl b/test/logdensityproblems.jl new file mode 100644 index 00000000..181d2645 --- /dev/null +++ b/test/logdensityproblems.jl @@ -0,0 +1,90 @@ +@testset "logdensityproblems.jl" begin + # Add worker processes. + # Memory requirements on Windows are ~4x larger than on Linux, hence number of processes is reduced + # See, e.g., https://github.com/JuliaLang/julia/issues/40766 and https://github.com/JuliaLang/Pkg.jl/pull/2366 + pids = addprocs(Sys.iswindows() ? div(Sys.CPU_THREADS::Int, 2) : Sys.CPU_THREADS::Int) + + # Load all required packages (`utils.jl` needs LogDensityProblems, Logging, and Random). + @everywhere begin + using AbstractMCMC + using AbstractMCMC: sample + using LogDensityProblems + + using Logging + using Random + include("utils.jl") + end + + @testset "LogDensityModel" begin + ℓ = MyLogDensity(10) + model = @inferred AbstractMCMC.LogDensityModel(ℓ) + @test model isa AbstractMCMC.LogDensityModel{MyLogDensity} + @test model.logdensity === ℓ + + @test_throws ArgumentError AbstractMCMC.LogDensityModel(mylogdensity) + end + + @testset "fallback for log densities" begin + # Sample with log density + dim = 10 + ℓ = MyLogDensity(dim) + Random.seed!(1234) + N = 1_000 + samples = sample(ℓ, MySampler(), N) + + # Samples are of the correct dimension and log density values are correct + @test length(samples) == N + @test all(length(x.a) == dim for x in samples) + @test all(x.b ≈ LogDensityProblems.logdensity(ℓ, x.a) for x in samples) + + # Same chain as if LogDensityModel is used explicitly + Random.seed!(1234) + samples2 = sample(AbstractMCMC.LogDensityModel(ℓ), MySampler(), N) + @test length(samples2) == N + @test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples2)) + + # Same chain if sampling is performed with convergence criterion + Random.seed!(1234) + isdone(rng, model, sampler, state, samples, iteration; kwargs...) = iteration > N + samples3 = sample(ℓ, MySampler(), isdone) + @test length(samples3) == N + @test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples3)) + + # Same chain if sampling is performed with iterator + Random.seed!(1234) + samples4 = collect(Iterators.take(AbstractMCMC.steps(ℓ, MySampler()), N)) + @test length(samples4) == N + @test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples4)) + + # Same chain if sampling is performed with transducer + Random.seed!(1234) + xf = AbstractMCMC.Sample(ℓ, MySampler()) + samples5 = collect(xf(1:N)) + @test length(samples5) == N + @test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples5)) + + # Parallel sampling + for alg in (MCMCSerial(), MCMCDistributed(), MCMCThreads()) + chains = sample(ℓ, MySampler(), alg, N, 2) + @test length(chains) == 2 + samples = vcat(chains[1], chains[2]) + @test length(samples) == 2 * N + @test all(length(x.a) == dim for x in samples) + @test all(x.b ≈ LogDensityProblems.logdensity(ℓ, x.a) for x in samples) + end + + # Log density has to satisfy the LogDensityProblems interface + @test_throws ArgumentError sample(mylogdensity, MySampler(), N) + @test_throws ArgumentError sample(mylogdensity, MySampler(), isdone) + @test_throws ArgumentError sample(mylogdensity, MySampler(), MCMCSerial(), N, 2) + @test_throws ArgumentError sample(mylogdensity, MySampler(), MCMCThreads(), N, 2) + @test_throws ArgumentError sample( + mylogdensity, MySampler(), MCMCDistributed(), N, 2 + ) + @test_throws ArgumentError AbstractMCMC.steps(mylogdensity, MySampler()) + @test_throws ArgumentError AbstractMCMC.Sample(mylogdensity, MySampler()) + end + + # Remove workers + rmprocs(pids...) +end diff --git a/test/runtests.jl b/test/runtests.jl index 3baef78c..0b002b21 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using AbstractMCMC using Atom.Progress: JunoProgressLogger using ConsoleProgressMonitor: ProgressLogger using IJulia +using LogDensityProblems using LoggingExtras: TeeLogger, EarlyFilteredLogger using TerminalLoggers: TerminalLogger using Transducers @@ -22,4 +23,5 @@ include("utils.jl") include("sample.jl") include("stepper.jl") include("transducer.jl") + include("logdensityproblems.jl") end diff --git a/test/sample.jl b/test/sample.jl index 97bf5a5e..7ced7f0c 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -221,13 +221,17 @@ # Add worker processes. # Memory requirements on Windows are ~4x larger than on Linux, hence number of processes is reduced # See, e.g., https://github.com/JuliaLang/julia/issues/40766 and https://github.com/JuliaLang/Pkg.jl/pull/2366 - addprocs(Sys.iswindows() ? div(Sys.CPU_THREADS::Int, 2) : Sys.CPU_THREADS::Int) + pids = addprocs( + Sys.iswindows() ? div(Sys.CPU_THREADS::Int, 2) : Sys.CPU_THREADS::Int + ) - # Load all required packages (`interface.jl` needs Random). + # Load all required packages (`utils.jl` needs LogDensityProblems, Logging, and Random). @everywhere begin using AbstractMCMC using AbstractMCMC: sample + using LogDensityProblems + using Logging using Random include("utils.jl") end @@ -316,6 +320,9 @@ @test all( chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) + + # Remove workers + rmprocs(pids...) end @testset "Serial sampling" begin diff --git a/test/utils.jl b/test/utils.jl index e2eedcb4..f69fcdab 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -81,3 +81,35 @@ end # Conversion to NamedTuple Base.convert(::Type{NamedTuple}, x::MySample) = (a=x.a, b=x.b) + +# Gaussian log density (without additive constants) +# Without LogDensityProblems.jl interface +mylogdensity(x) = -sum(abs2, x) / 2 + +# With LogDensityProblems.jl interface +struct MyLogDensity + dim::Int +end +LogDensityProblems.logdensity(::MyLogDensity, x) = mylogdensity(x) +LogDensityProblems.dimension(m::MyLogDensity) = m.dim +function LogDensityProblems.capabilities(::Type{MyLogDensity}) + return LogDensityProblems.LogDensityOrder{0}() +end + +# Define "sampling" +function AbstractMCMC.step( + rng::AbstractRNG, + model::AbstractMCMC.LogDensityModel{MyLogDensity}, + ::MySampler, + state::Union{Nothing,Integer}=nothing; + kwargs..., +) + # Sample from multivariate normal distribution + ℓ = model.logdensity + dim = LogDensityProblems.dimension(ℓ) + θ = randn(rng, dim) + logdensity_θ = LogDensityProblems.logdensity(ℓ, θ) + + _state = state === nothing ? 1 : state + 1 + return MySample(θ, logdensity_θ), _state +end From 3ca7b942e3e4782f21230ff4bf540cac590a08e7 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 10 Jan 2023 13:53:49 +0100 Subject: [PATCH 31/54] Do not test Atom/Juno --- Project.toml | 3 +-- test/runtests.jl | 1 - test/sample.jl | 16 ---------------- 3 files changed, 1 insertion(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index a6bf3e65..40494911 100644 --- a/Project.toml +++ b/Project.toml @@ -30,10 +30,9 @@ Transducers = "0.4.30" julia = "1.6" [extras] -Atom = "c52e3926-4ff0-5f6e-af25-54175e0327b1" IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Atom", "IJulia", "Statistics", "Test"] +test = ["IJulia", "Statistics", "Test"] diff --git a/test/runtests.jl b/test/runtests.jl index 0b002b21..75aac0f1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,4 @@ using AbstractMCMC -using Atom.Progress: JunoProgressLogger using ConsoleProgressMonitor: ProgressLogger using IJulia using LogDensityProblems diff --git a/test/sample.jl b/test/sample.jl index 7ced7f0c..fcd3ab13 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -34,22 +34,6 @@ @test chain[1].b == 3.2 end - @testset "Juno" begin - empty!(LOGGERS) - - Random.seed!(1234) - N = 10 - - logger = JunoProgressLogger() - Logging.with_logger(logger) do - sample(MyModel(), MySampler(), N; loggers=true) - end - - @test length(LOGGERS) == 1 - @test first(LOGGERS) === logger - @test Logging.current_logger() === CURRENT_LOGGER - end - @testset "IJulia" begin # emulate running IJulia kernel @eval IJulia begin From 35ef3c00154d3395877027a380dcd82d36e14a62 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 2 May 2023 14:11:46 +0100 Subject: [PATCH 32/54] CompatHelper: bump compat for StatsBase to 0.34, (keep existing compat) (#121) Co-authored-by: CompatHelper Julia --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 40494911..ea99aa40 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,7 @@ ConsoleProgressMonitor = "0.1" LogDensityProblems = "2" LoggingExtras = "0.4, 0.5, 1" ProgressLogging = "0.1" -StatsBase = "0.32, 0.33" +StatsBase = "0.32, 0.33, 0.34" TerminalLoggers = "0.1" Transducers = "0.4.30" julia = "1.6" From e149f9f18453e605c31efc0990c0d107055ff576 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 2 May 2023 14:12:31 +0100 Subject: [PATCH 33/54] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ea99aa40..b92f6578 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.4.0" +version = "4.4.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From d7c549fe41a80c1f164423c7ac458425535f624b Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 20 Jun 2023 20:36:51 +0200 Subject: [PATCH 34/54] Fix method ambiguity issues (#123) * Fix method ambiguity issues * Fix format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update tolerances --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Project.toml | 2 +- src/interface.jl | 25 +++++++++++++++++++++---- test/sample.jl | 2 +- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index b92f6578..7960600c 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.4.1" +version = "4.4.2" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/src/interface.jl b/src/interface.jl index eaecb492..928a933d 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -30,13 +30,30 @@ be specified with the `chain_type` argument. By default, this method returns `samples`. """ function bundle_samples( - samples, ::AbstractModel, ::AbstractSampler, ::Any, ::Type; kwargs... + samples, model::AbstractModel, sampler::AbstractSampler, state, ::Type{T}; kwargs... +) where {T} + # dispatch to internal method for default implementations to fix + # method ambiguity issues (see #120) + return _bundle_samples(samples, model, sampler, state, T; kwargs...) +end + +function _bundle_samples( + samples, + @nospecialize(::AbstractModel), + @nospecialize(::AbstractSampler), + @nospecialize(::Any), + ::Type; + kwargs..., ) return samples end - -function bundle_samples( - samples::Vector, ::AbstractModel, ::AbstractSampler, ::Any, ::Type{Vector{T}}; kwargs... +function _bundle_samples( + samples::Vector, + @nospecialize(::AbstractModel), + @nospecialize(::AbstractSampler), + @nospecialize(::Any), + ::Type{Vector{T}}; + kwargs..., ) where {T} return map(samples) do sample convert(T, sample) diff --git a/test/sample.jl b/test/sample.jl index fcd3ab13..261cc1ef 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -564,7 +564,7 @@ @test ismissing(chain[1].a) @test mean(x.a for x in view(chain, 2:1_000)) ≈ 0.5 atol = 6e-2 @test var(x.a for x in view(chain, 2:1_000)) ≈ 1 / 12 atol = 1e-2 - @test mean(x.b for x in chain) ≈ 0 atol = 0.1 + @test mean(x.b for x in chain) ≈ 0 atol = 0.11 @test var(x.b for x in chain) ≈ 1 atol = 0.15 end From 6f5ac5a58c0e4c0b6850d02f8746984396b3ef26 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 14 Sep 2023 00:36:43 +0100 Subject: [PATCH 35/54] use _init_parmas for MCMCThreads and MCMCDistributed too --- src/sample.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index dc951ca2..b949087f 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -410,6 +410,9 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end + # Ensure that initial parameters are `nothing` or indexable + _init_params = _first_or_nothing(init_params, nchains) + # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -466,10 +469,10 @@ function mcmcsample( # Return the new chain. return chain end - chains = if init_params === nothing + chains = if _init_params === nothing Distributed.pmap(sample_chain, pool, seeds) else - Distributed.pmap(sample_chain, pool, seeds, init_params) + Distributed.pmap(sample_chain, pool, seeds, _init_params) end finally # Stop updating the progress bar. @@ -499,6 +502,9 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end + # Ensure that initial parameters are `nothing` or indexable + _init_params = _first_or_nothing(init_params, nchains) + # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -519,10 +525,10 @@ function mcmcsample( ) end - chains = if init_params === nothing + chains = if _init_params === nothing map(sample_chain, 1:nchains, seeds) else - map(sample_chain, 1:nchains, seeds, init_params) + map(sample_chain, 1:nchains, seeds, _init_params) end # Concatenate the chains together. From 880852148366870e0ccef4ef9c21d6d0a9d9408b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 14 Sep 2023 00:37:37 +0100 Subject: [PATCH 36/54] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7960600c..96ef5cba 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.4.2" +version = "4.4.3" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From e2a04ba2561e51bddca2bf0100aeb12190d587c4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 14 Sep 2023 00:55:35 +0100 Subject: [PATCH 37/54] added some tests for too many and too few init params --- test/sample.jl | 105 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 87 insertions(+), 18 deletions(-) diff --git a/test/sample.jl b/test/sample.jl index 261cc1ef..73f9fa96 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -162,17 +162,18 @@ end # initial parameters - init_params = [(b=randn(), a=rand()) for _ in 1:100] + nchains = 100 + init_params = [(b=randn(), a=rand()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), MCMCThreads(), 3, - 100; + nchains; progress=false, init_params=init_params, ) - @test length(chains) == 100 + @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for (chain, params) in zip(chains, init_params) @@ -184,14 +185,36 @@ MySampler(), MCMCThreads(), 3, - 100; + nchains; progress=false, - init_params=Iterators.repeated(init_params), + init_params=Iterators.repeated(init_params, nchains), ) - @test length(chains) == 100 + @test length(chains) == nchains @test all( chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) + + # Too many `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCThreads(), + 3, + nchains; + progress=false, + init_params=Iterators.repeated(init_params, nchains + 1), + ) + + # Too few `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCThreads(), + 3, + nchains; + progress=false, + init_params=Iterators.repeated(init_params, nchains - 1), + ) end @testset "Multicore sampling" begin @@ -274,17 +297,18 @@ @test all(l.level > Logging.LogLevel(-1) for l in logs) # initial parameters - init_params = [(a=randn(), b=rand()) for _ in 1:100] + nchains = 100 + init_params = [(a=randn(), b=rand()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), MCMCDistributed(), 3, - 100; + nchains; progress=false, init_params=init_params, ) - @test length(chains) == 100 + @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for (chain, params) in zip(chains, init_params) @@ -296,15 +320,37 @@ MySampler(), MCMCDistributed(), 3, - 100; + nchains; progress=false, - init_params=Iterators.repeated(init_params), + init_params=Iterators.repeated(init_params, nchains), ) - @test length(chains) == 100 + @test length(chains) == nchains @test all( chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) + # Too many `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCThrMCMCDistributedeads(), + 3, + nchains; + progress=false, + init_params=Iterators.repeated(init_params, nchains + 1), + ) + + # Too few `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCDistributed(), + 3, + nchains; + progress=false, + init_params=Iterators.repeated(init_params, nchains - 1), + ) + # Remove workers rmprocs(pids...) end @@ -360,17 +406,18 @@ @test all(l.level > Logging.LogLevel(-1) for l in logs) # initial parameters - init_params = [(a=rand(), b=randn()) for _ in 1:100] + nchains = 100 + init_params = [(a=rand(), b=randn()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), MCMCSerial(), 3, - 100; + nchains; progress=false, init_params=init_params, ) - @test length(chains) == 100 + @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for (chain, params) in zip(chains, init_params) @@ -382,14 +429,36 @@ MySampler(), MCMCSerial(), 3, - 100; + nchains; progress=false, - init_params=Iterators.repeated(init_params), + init_params=Iterators.repeated(init_params, nchains), ) - @test length(chains) == 100 + @test length(chains) == nchains @test all( chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) + + # Too many `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCSerial(), + 3, + nchains; + progress=false, + init_params=Iterators.repeated(init_params, nchains + 1), + ) + + # Too few `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCSerial(), + 3, + nchains; + progress=false, + init_params=Iterators.repeated(init_params, nchains - 1), + ) end @testset "Ensemble sampling: Reproducibility" begin From 2e6e23d6370fd5c7be6f09607403be44bba82667 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 14 Sep 2023 01:04:45 +0100 Subject: [PATCH 38/54] fixed typo in tests --- test/sample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/sample.jl b/test/sample.jl index 73f9fa96..a41b3228 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -333,7 +333,7 @@ @test_throws ArgumentError sample( MyModel(), MySampler(), - MCMCThrMCMCDistributedeads(), + MCMCDistributed(), 3, nchains; progress=false, From e897b8ac852e07fc6607d586cbf761919ddc9437 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 14 Sep 2023 01:26:26 +0100 Subject: [PATCH 39/54] remove _first_or_nothing and just check if init_params is of the right length --- src/sample.jl | 52 +++++++++++++++------------------------------------ 1 file changed, 15 insertions(+), 37 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index b949087f..733a4e85 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -312,8 +312,8 @@ function mcmcsample( # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) - # Ensure that initial parameters are `nothing` or indexable - _init_params = _first_or_nothing(init_params, nchains) + # Ensure that initial parameters are `nothing` or of the correct length + check_initial_params(init_params, nchains) # Set up a chains vector. chains = Vector{Any}(undef, nchains) @@ -364,10 +364,10 @@ function mcmcsample( _sampler, N; progress=false, - init_params=if _init_params === nothing + init_params=if init_params === nothing nothing else - _init_params[chainidx] + init_params[chainidx] end, kwargs..., ) @@ -410,8 +410,8 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end - # Ensure that initial parameters are `nothing` or indexable - _init_params = _first_or_nothing(init_params, nchains) + # Ensure that initial parameters are `nothing` or of the correct length + check_initial_params(init_params, nchains) # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -469,10 +469,10 @@ function mcmcsample( # Return the new chain. return chain end - chains = if _init_params === nothing + chains = if init_params === nothing Distributed.pmap(sample_chain, pool, seeds) else - Distributed.pmap(sample_chain, pool, seeds, _init_params) + Distributed.pmap(sample_chain, pool, seeds, init_params) end finally # Stop updating the progress bar. @@ -502,8 +502,8 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end - # Ensure that initial parameters are `nothing` or indexable - _init_params = _first_or_nothing(init_params, nchains) + # Ensure that initial parameters are `nothing` or of the correct length + check_initial_params(init_params, nchains) # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -525,10 +525,10 @@ function mcmcsample( ) end - chains = if _init_params === nothing + chains = if init_params === nothing map(sample_chain, 1:nchains, seeds) else - map(sample_chain, 1:nchains, seeds, _init_params) + map(sample_chain, 1:nchains, seeds, init_params) end # Concatenate the chains together. @@ -538,31 +538,9 @@ end tighten_eltype(x) = x tighten_eltype(x::Vector{Any}) = map(identity, x) -""" - _first_or_nothing(x, n::Int) - -Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`. - -If `x !== nothing`, then `x` has to contain at least `n` elements. -""" -function _first_or_nothing(x, n::Int) - y = _first(x, n) - length(y) == n || throw( +check_initial_params(x::Nothing, n::Int) = nothing +function check_initial_params(x, n::Int) + length(x) == n || throw( ArgumentError("not enough initial parameters (expected $n, received $(length(y))"), ) - return y -end -_first_or_nothing(::Nothing, ::Int) = nothing - -# `first(x, n::Int)` requires Julia 1.6 -function _first(x, n::Int) - @static if VERSION >= v"1.6.0-DEV.431" - first(x, n) - else - if x isa AbstractVector - @inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))] - else - collect(Iterators.take(x, n)) - end - end end From 33f9bb79ce853ebd5374391b9172725268e0698e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 14 Sep 2023 01:27:49 +0100 Subject: [PATCH 40/54] fix typo in error message --- src/sample.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 733a4e85..e3d0dbbc 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -540,7 +540,9 @@ tighten_eltype(x::Vector{Any}) = map(identity, x) check_initial_params(x::Nothing, n::Int) = nothing function check_initial_params(x, n::Int) - length(x) == n || throw( - ArgumentError("not enough initial parameters (expected $n, received $(length(y))"), - ) + if length(x) != n + throw( + ArgumentError("not enough initial parameters (expected $n, received $(length(x))"), + ) + end end From 85deddb65224cbcb2a143117477b6d2286547342 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 14 Sep 2023 01:28:42 +0100 Subject: [PATCH 41/54] bump minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 96ef5cba..4e00a5f1 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.4.3" +version = "4.5.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From c985383a7d5488f1eb0698d8809eece6610abd09 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 14 Sep 2023 01:46:20 +0100 Subject: [PATCH 42/54] use `collect` for init_params when using MCMCThreads --- src/sample.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index e3d0dbbc..8dd01b16 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -314,6 +314,8 @@ function mcmcsample( # Ensure that initial parameters are `nothing` or of the correct length check_initial_params(init_params, nchains) + # We will use `getindex` later so we need to `collect`. + _init_params = collect(init_params) # Set up a chains vector. chains = Vector{Any}(undef, nchains) @@ -364,10 +366,10 @@ function mcmcsample( _sampler, N; progress=false, - init_params=if init_params === nothing + init_params=if _init_params === nothing nothing else - init_params[chainidx] + _init_params[chainidx] end, kwargs..., ) From fc3dd211e4bc04767ba02f8f4c3507d0dd013ea9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 14 Sep 2023 01:47:33 +0100 Subject: [PATCH 43/54] check if init_params is nothing before collect --- src/sample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index 8dd01b16..fc746558 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -315,7 +315,7 @@ function mcmcsample( # Ensure that initial parameters are `nothing` or of the correct length check_initial_params(init_params, nchains) # We will use `getindex` later so we need to `collect`. - _init_params = collect(init_params) + _init_params = init_params !== nothing ? collect(init_params) : nothing # Set up a chains vector. chains = Vector{Any}(undef, nchains) From d2e2a9b5e5e08d45b624412aa0fa1fdd2b322c58 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 14 Sep 2023 02:26:47 +0100 Subject: [PATCH 44/54] Update src/sample.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/sample.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index fc746558..1176df67 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -544,7 +544,9 @@ check_initial_params(x::Nothing, n::Int) = nothing function check_initial_params(x, n::Int) if length(x) != n throw( - ArgumentError("not enough initial parameters (expected $n, received $(length(x))"), + ArgumentError( + "not enough initial parameters (expected $n, received $(length(x))" + ), ) end end From 93a4f1beb15643ada39293a9559bb0e0ba410091 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 14 Sep 2023 02:27:16 +0100 Subject: [PATCH 45/54] fuxed error message --- src/sample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index fc746558..a772ec96 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -544,7 +544,7 @@ check_initial_params(x::Nothing, n::Int) = nothing function check_initial_params(x, n::Int) if length(x) != n throw( - ArgumentError("not enough initial parameters (expected $n, received $(length(x))"), + ArgumentError("incorrect number of initial parameters (expected $n, received $(length(x))"), ) end end From 1c589e8d4e054dd3b30b4a6cdb7bc25577ae8dce Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 13:47:19 +0200 Subject: [PATCH 46/54] CompatHelper: bump compat for Documenter to 1 for package docs, (keep existing compat) (#127) * CompatHelper: bump compat for Documenter to 1 for package docs, (keep existing compat) * Remove outdated keyword argument --------- Co-authored-by: CompatHelper Julia Co-authored-by: David Widmann --- docs/Project.toml | 2 +- docs/make.jl | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 555443ab..f74dfb58 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,5 +3,5 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] -Documenter = "0.27" +Documenter = "1" julia = "1" diff --git a/docs/make.jl b/docs/make.jl index 66d7619c..9395d2a0 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -9,7 +9,6 @@ makedocs(; format=Documenter.HTML(), modules=[AbstractMCMC], pages=["Home" => "index.md", "api.md", "design.md"], - strict=true, checkdocs=:exports, ) From caeade2abe60b6803201cd341f7d62595465f6b2 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 27 Sep 2023 13:55:41 +0200 Subject: [PATCH 47/54] Update callback signature in docs (#130) --- docs/src/api.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 52c2c2e1..629e4c37 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -71,8 +71,8 @@ Common keyword arguments for regular and parallel sampling are: - `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging - `chain_type` (default: `Any`): determines the type of the returned chain - `callback` (default: `nothing`): if `callback !== nothing`, then - `callback(rng, model, sampler, sample, iteration)` is called after every sampling step, - where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration + `callback(rng, model, sampler, sample, state, iteration)` is called after every sampling step, + where `sample` is the most recent sample of the Markov chain and `state` and `iteration` are the current state and iteration of the sampler - `discard_initial` (default: `0`): number of initial samples that are discarded - `thinning` (default: `1`): factor by which to thin samples. From 3e8798700e6e89186a311aa0c8af50b89117519a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Oct 2023 17:57:20 +0100 Subject: [PATCH 48/54] require init_params to be a vector of length equal to the nubmer of chains or nothing --- src/sample.jl | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index a0b021af..6c9c32ae 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -314,8 +314,6 @@ function mcmcsample( # Ensure that initial parameters are `nothing` or of the correct length check_initial_params(init_params, nchains) - # We will use `getindex` later so we need to `collect`. - _init_params = init_params !== nothing ? collect(init_params) : nothing # Set up a chains vector. chains = Vector{Any}(undef, nchains) @@ -366,10 +364,10 @@ function mcmcsample( _sampler, N; progress=false, - init_params=if _init_params === nothing + init_params=if init_params === nothing nothing else - _init_params[chainidx] + init_params[chainidx] end, kwargs..., ) @@ -540,8 +538,14 @@ end tighten_eltype(x) = x tighten_eltype(x::Vector{Any}) = map(identity, x) -check_initial_params(x::Nothing, n::Int) = nothing -function check_initial_params(x, n::Int) +@nospecialize check_initial_params(x, n) = throw( + ArgumentError( + "initial parameters must be specified as a vector of length equal to the number of chains or `nothing`", + ), +) + +check_initial_params(::Nothing, n) = nothing +function check_initial_params(x::AbstractArray, n) if length(x) != n throw( ArgumentError( @@ -549,4 +553,6 @@ function check_initial_params(x, n::Int) ), ) end + + return nothing end From 3c1ed50d4a6bc3fa3b802c9ab5707c723978ff4f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Oct 2023 18:00:08 +0100 Subject: [PATCH 49/54] updated docs --- docs/src/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 52c2c2e1..d89b078a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -82,7 +82,7 @@ Common keyword arguments for regular and parallel sampling are: There is no "official" way for providing initial parameter values yet. However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain. To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): -- `init_params` (default: `nothing`): if set to `init_params !== nothing`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = Iterators.repeated(x)` or `init_params = FillArrays.Fill(x, N)`. +- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`. Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`. From ae6562fa2921c7985613c4a09c472becc6ed543d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Oct 2023 18:00:29 +0100 Subject: [PATCH 50/54] bump minor version since this will be potentially breaking --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4e00a5f1..02068817 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.5.0" +version = "4.6.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From e1bb661b1c089341b1139286fc44eb0c280bfded Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Oct 2023 18:31:58 +0100 Subject: [PATCH 51/54] replaced usages of Iterators.repeated with FillArrays.Fill --- test/sample.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/sample.jl b/test/sample.jl index a41b3228..22f4b26d 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -187,7 +187,7 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains), + init_params=FillArrays.Fill(init_params, nchains), ) @test length(chains) == nchains @test all( @@ -202,7 +202,7 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains + 1), + init_params=FillArrays.Fill(init_params, nchains + 1), ) # Too few `init_params` @@ -213,7 +213,7 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains - 1), + init_params=FillArrays.Fill(init_params, nchains - 1), ) end @@ -322,7 +322,7 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains), + init_params=FillArrays.Fill(init_params, nchains), ) @test length(chains) == nchains @test all( @@ -337,7 +337,7 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains + 1), + init_params=FillArrays.Fill(init_params, nchains + 1), ) # Too few `init_params` @@ -348,7 +348,7 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains - 1), + init_params=FillArrays.Fill(init_params, nchains - 1), ) # Remove workers @@ -431,7 +431,7 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains), + init_params=FillArrays.Fill(init_params, nchains), ) @test length(chains) == nchains @test all( @@ -446,7 +446,7 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains + 1), + init_params=FillArrays.Fill(init_params, nchains + 1), ) # Too few `init_params` @@ -457,7 +457,7 @@ 3, nchains; progress=false, - init_params=Iterators.repeated(init_params, nchains - 1), + init_params=FillArrays.Fill(init_params, nchains - 1), ) end From eaf9d501f60e106e340568e8387e74a19935d65e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Oct 2023 18:33:33 +0100 Subject: [PATCH 52/54] added FillArrays as a test dep --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 02068817..40f90c96 100644 --- a/Project.toml +++ b/Project.toml @@ -30,9 +30,10 @@ Transducers = "0.4.30" julia = "1.6" [extras] +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["IJulia", "Statistics", "Test"] +test = ["FillArrays", "IJulia", "Statistics", "Test"] From 4dbcb3fd01b1111d3cfcd234ae79bf750c1c541e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 2 Oct 2023 11:35:02 +0100 Subject: [PATCH 53/54] fixed runtests --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 75aac0f1..909ae8b3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using IJulia using LogDensityProblems using LoggingExtras: TeeLogger, EarlyFilteredLogger using TerminalLoggers: TerminalLogger +using FillArrays: FillArrays using Transducers using Distributed From d5218159232bc3b035ad9c789e874ac68b5643d5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 2 Oct 2023 15:20:40 +0100 Subject: [PATCH 54/54] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 40f90c96..90117048 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.6.0" +version = "4.5.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"