diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 00000000..1e72b507 --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style="blue" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 37cc617e..06ff8ad9 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: version: - - '1.0' + - '1.6' - '1' - nightly os: @@ -31,7 +31,7 @@ jobs: arch: x86 - os: macOS-latest arch: x86 - - version: '1.0' + - version: '1.6' num_threads: 2 include: - version: '1' @@ -45,16 +45,9 @@ jobs: with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v1 - env: - cache-name: cache-artifacts + - uses: julia-actions/cache@v1 with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- + cache-packages: "false" # caching Conda.jl causes precompilation error - uses: julia-actions/julia-buildpkg@latest - uses: julia-actions/julia-runtest@latest env: diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index f5da2d24..23e85888 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -3,16 +3,35 @@ on: schedule: - cron: 0 0 * * * workflow_dispatch: +permissions: + contents: write + pull-requests: write jobs: CompatHelper: runs-on: ubuntu-latest steps: + - name: Check if Julia is already available in the PATH + id: julia_in_path + run: which julia + continue-on-error: true + - name: Install Julia, but only if it is not already available in the PATH + uses: julia-actions/setup-julia@v1 + with: + version: '1' + arch: ${{ runner.arch }} + if: steps.julia_in_path.outcome != 'success' + - name: "Add the General registry via Git" + run: | + import Pkg + ENV["JULIA_PKG_SERVER"] = "" + Pkg.Registry.add("General") + shell: julia 
--color=yes {0} - name: "Install CompatHelper" run: | import Pkg name = "CompatHelper" uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" - version = "2" + version = "3" Pkg.add(; name, uuid, version) shell: julia --color=yes {0} - name: "Run CompatHelper" diff --git a/.github/workflows/Docs.yml b/.github/workflows/Docs.yml index e47f9389..afa6ee8b 100644 --- a/.github/workflows/Docs.yml +++ b/.github/workflows/Docs.yml @@ -12,13 +12,12 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@latest + - uses: julia-actions/setup-julia@v1 with: version: '1' - - name: Install dependencies - run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' - - name: Build and deploy + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-docdeploy@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key - run: julia --project=docs/ docs/make.jl + JULIA_DEBUG: Documenter # Print `@debug` statements (https://github.com/JuliaDocs/Documenter.jl/issues/955) diff --git a/.github/workflows/Format.yml b/.github/workflows/Format.yml new file mode 100644 index 00000000..ec14da16 --- /dev/null +++ b/.github/workflows/Format.yml @@ -0,0 +1,31 @@ +name: Format + +on: + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. 
+ group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@latest + with: + version: 1 + - name: Format code + run: | + using Pkg + Pkg.add(; name="JuliaFormatter", uuid="98e50ef6-434e-11e9-1051-2b60c6c9e899") + using JuliaFormatter + format("."; verbose=true) + shell: julia --color=yes {0} + - uses: reviewdog/action-suggester@v1 + if: github.event_name == 'pull_request' + with: + tool_name: JuliaFormatter + fail_on_error: true diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index cd3b3658..2e9d6bcf 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -14,6 +14,7 @@ jobs: fail-fast: false matrix: package: + - {user: TuringLang, repo: AdvancedHMC.jl} - {user: TuringLang, repo: AdvancedMH.jl} - {user: TuringLang, repo: EllipticalSliceSampling.jl} - {user: TuringLang, repo: MCMCChains.jl} diff --git a/.gitignore b/.gitignore index dfa313a1..83d89f72 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,4 @@ *.jl.*.cov *.jl.mem deps/deps.jl -/Manifest.toml \ No newline at end of file +Manifest.toml \ No newline at end of file diff --git a/Project.toml b/Project.toml index 7ccac490..90117048 100644 --- a/Project.toml +++ b/Project.toml @@ -3,12 +3,13 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." 
-version = "3.2.1" +version = "4.5.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" @@ -20,18 +21,19 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] BangBang = "0.3.19" ConsoleProgressMonitor = "0.1" -LoggingExtras = "0.4" +LogDensityProblems = "2" +LoggingExtras = "0.4, 0.5, 1" ProgressLogging = "0.1" -StatsBase = "0.32, 0.33" +StatsBase = "0.32, 0.33, 0.34" TerminalLoggers = "0.1" Transducers = "0.4.30" -julia = "1" +julia = "1.6" [extras] -Atom = "c52e3926-4ff0-5f6e-af25-54175e0327b1" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Atom", "IJulia", "Statistics", "Test"] +test = ["FillArrays", "IJulia", "Statistics", "Test"] diff --git a/README.md b/README.md index a2d40c34..ee186269 100644 --- a/README.md +++ b/README.md @@ -8,3 +8,4 @@ Abstract types and interfaces for Markov chain Monte Carlo methods. 
[![IntegrationTest](https://github.com/TuringLang/AbstractMCMC.jl/workflows/IntegrationTest/badge.svg?branch=master)](https://github.com/TuringLang/AbstractMCMC.jl/actions?query=workflow%3AIntegrationTest+branch%3Amaster) [![Codecov](https://codecov.io/gh/TuringLang/AbstractMCMC.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/TuringLang/AbstractMCMC.jl) [![Coveralls](https://coveralls.io/repos/github/TuringLang/AbstractMCMC.jl/badge.svg?branch=master)](https://coveralls.io/github/TuringLang/AbstractMCMC.jl?branch=master) +[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) diff --git a/docs/Manifest.toml b/docs/Manifest.toml deleted file mode 100644 index 014aac99..00000000 --- a/docs/Manifest.toml +++ /dev/null @@ -1,374 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -[[AbstractMCMC]] -deps = ["BangBang", "ConsoleProgressMonitor", "Distributed", "Logging", "LoggingExtras", "ProgressLogging", "Random", "StatsBase", "TerminalLoggers", "Transducers"] -path = ".." 
-uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" -version = "3.2.1" - -[[AbstractTrees]] -git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5" -uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.3.4" - -[[Adapt]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.3.1" - -[[ArgCheck]] -git-tree-sha1 = "dedbbb2ddb876f899585c4ec4433265e3017215a" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.1.0" - -[[ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" - -[[Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[BangBang]] -deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] -git-tree-sha1 = "e239020994123f08905052b9603b4ca14f8c5807" -uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.3.31" - -[[Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[Baselet]] -git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" -uuid = "9718e550-a3fa-408a-8086-8db961cd8217" -version = "0.1.1" - -[[Compat]] -deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.31.0" - -[[CompositionsBase]] -git-tree-sha1 = "f3955eb38944e5dd0fabf8ca1e267d94941d34a5" -uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.0" - -[[ConsoleProgressMonitor]] -deps = ["Logging", "ProgressMeter"] -git-tree-sha1 = "3ab7b2136722890b9af903859afcf457fa3059e8" -uuid = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" -version = "0.1.2" - -[[ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f74e9d5388b8620b4cee35d4c5a618dd4dc547f4" 
-uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.3.0" - -[[DataAPI]] -git-tree-sha1 = "ee400abb2298bd13bfc3df1c412ed228061a2385" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.7.0" - -[[DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.9" - -[[DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[DefineSingletons]] -git-tree-sha1 = "77b4ca280084423b728662fe040e5ff8819347c5" -uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" -version = "0.1.1" - -[[DelimitedFiles]] -deps = ["Mmap"] -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" - -[[Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.5" - -[[Documenter]] -deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "47f13b6305ab195edb73c86815962d84e31b0f48" -uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.27.3" - -[[Downloads]] -deps = ["ArgTools", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" - -[[Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[IOCapture]] -deps = ["Logging", "Random"] -git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a" -uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.2" - -[[InitialValues]] -git-tree-sha1 = "26c8832afd63ac558b98a823265856670d898b6c" -uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" -version = "0.2.10" - 
-[[InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.1" - -[[LeftChildRightSiblingTrees]] -deps = ["AbstractTrees"] -git-tree-sha1 = "71be1eb5ad19cb4f61fa8c73395c0338fd092ae0" -uuid = "1d6d02ad-be62-4b6b-8a6d-2f90e265016e" -version = "0.1.2" - -[[LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" - -[[LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" - -[[LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" - -[[Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[LinearAlgebra]] -deps = ["Libdl"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[LoggingExtras]] -deps = ["Dates", "Logging"] -git-tree-sha1 = "dfeda1c1130990428720de0024d4516b1902ce98" -uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "0.4.7" - -[[MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.6" - -[[Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" - -[[MicroCollections]] -deps = ["BangBang", "Setfield"] -git-tree-sha1 = "e991b6a9d38091c4a0d7cd051fcb57c05f98ac03" -uuid = 
"128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.1.0" - -[[Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "4ea90bd5d3985ae1f9a908bd4500ae88921c5ce7" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.0.0" - -[[Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" - -[[NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" - -[[OrderedCollections]] -git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.4.1" - -[[Parsers]] -deps = ["Dates"] -git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "1.1.0" - -[[Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" - -[[Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[ProgressLogging]] -deps = ["Logging", "SHA", "UUIDs"] -git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" -uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" -version = "0.1.4" - -[[ProgressMeter]] -deps = ["Distributed", "Printf"] -git-tree-sha1 = "afadeba63d90ff223a6a48d2009434ecee2ec9e8" -uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "1.7.1" - -[[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[Random]] -deps = ["Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.1.3" - -[[SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" - -[[Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "Requires"] 
-git-tree-sha1 = "d5640fc570fb1b6c54512f0bd3853866bd298b3e" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "0.7.0" - -[[SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - -[[Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.0.0" - -[[SparseArrays]] -deps = ["LinearAlgebra", "Random"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[SplittablesBase]] -deps = ["Setfield", "Test"] -git-tree-sha1 = "edef25a158db82f4940720ebada14a60ef6c4232" -uuid = "171d559e-b47b-412a-8079-5efa626c420e" -version = "0.1.13" - -[[Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[StatsAPI]] -git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.0.0" - -[[StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "2f6792d523d7448bbe2fec99eca9218f06cc746d" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.8" - -[[TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" - -[[TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"] -git-tree-sha1 = "8ed4a3ea724dac32670b062be3ef1c1de6773ae8" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.4.4" - -[[Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" - -[[TerminalLoggers]] -deps = ["LeftChildRightSiblingTrees", "Logging", 
"Markdown", "Printf", "ProgressLogging", "UUIDs"] -git-tree-sha1 = "d620a061cb2a56930b52bdf5cf908a5c4fa8e76a" -uuid = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" -version = "0.1.4" - -[[Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[Transducers]] -deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] -git-tree-sha1 = "34f27ac221cb53317ab6df196f9ed145077231ff" -uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.65" - -[[UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" - -[[ZygoteRules]] -deps = ["MacroTools"] -git-tree-sha1 = "9e7a1e8ca60b742e508a315c17eef5211e7fbfd7" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.1" - -[[nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" - -[[p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" diff --git a/docs/Project.toml b/docs/Project.toml index 69dcc9d0..f74dfb58 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,9 +1,7 @@ [deps] -AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] -AbstractMCMC = "3" -Documenter = "0.27" +Documenter = "1" julia = "1" diff --git a/docs/make.jl b/docs/make.jl index e0fa16e9..9395d2a0 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,33 +1,15 @@ -using Documenter - -# Print `@debug` statements (https://github.com/JuliaDocs/Documenter.jl/issues/955) -if haskey(ENV, "GITHUB_ACTIONS") - ENV["JULIA_DEBUG"] = "Documenter" -end - using AbstractMCMC +using Documenter 
using Random -DocMeta.setdocmeta!( - AbstractMCMC, - :DocTestSetup, - :(using AbstractMCMC); - recursive=true, -) +DocMeta.setdocmeta!(AbstractMCMC, :DocTestSetup, :(using AbstractMCMC); recursive=true) makedocs(; sitename="AbstractMCMC", format=Documenter.HTML(), modules=[AbstractMCMC], - pages=[ - "Home" => "index.md", - "api.md", - "design.md", - ], - strict=true, + pages=["Home" => "index.md", "api.md", "design.md"], checkdocs=:exports, ) -deploydocs(; - repo="github.com/TuringLang/AbstractMCMC.jl.git", push_preview=true -) +deploydocs(; repo="github.com/TuringLang/AbstractMCMC.jl.git", push_preview=true) diff --git a/docs/src/api.md b/docs/src/api.md index 6be52d6d..648a87b8 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -2,23 +2,39 @@ AbstractMCMC defines an interface for sampling Markov chains. +## Model + +```@docs +AbstractMCMC.AbstractModel +AbstractMCMC.LogDensityModel +``` + +## Sampler + +```@docs +AbstractMCMC.AbstractSampler +``` + ## Sampling a single chain ```@docs -AbstractMCMC.sample(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler, ::Integer) AbstractMCMC.sample(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler, ::Any) +AbstractMCMC.sample(::AbstractRNG, ::Any, ::AbstractMCMC.AbstractSampler, ::Any) + ``` ### Iterator ```@docs AbstractMCMC.steps(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler) +AbstractMCMC.steps(::AbstractRNG, ::Any, ::AbstractMCMC.AbstractSampler) ``` ### Transducer ```@docs AbstractMCMC.Sample(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler) +AbstractMCMC.Sample(::AbstractRNG, ::Any, ::AbstractMCMC.AbstractSampler) ``` ## Sampling multiple chains in parallel @@ -32,6 +48,14 @@ AbstractMCMC.sample( ::Integer, ::Integer, ) +AbstractMCMC.sample( + ::AbstractRNG, + ::Any, + ::AbstractMCMC.AbstractSampler, + ::AbstractMCMC.AbstractMCMCEnsemble, + ::Integer, + ::Integer, +) ``` Two algorithms are provided for parallel 
sampling with multiple threads and multiple processes, and one allows for the user to sample multiple chains in serial (no parallelization): @@ -43,16 +67,23 @@ AbstractMCMC.MCMCSerial ## Common keyword arguments -Common keyword arguments for regular and parallel sampling (not supported by the iterator and transducer) -are: +Common keyword arguments for regular and parallel sampling are: - `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging - `chain_type` (default: `Any`): determines the type of the returned chain - `callback` (default: `nothing`): if `callback !== nothing`, then - `callback(rng, model, sampler, sample, iteration)` is called after every sampling step, - where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration + `callback(rng, model, sampler, sample, state, iteration)` is called after every sampling step, + where `sample` is the most recent sample of the Markov chain and `state` and `iteration` are the current state and iteration of the sampler - `discard_initial` (default: `0`): number of initial samples that are discarded - `thinning` (default: `1`): factor by which to thin samples. +!!! info + The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref). + +There is no "official" way for providing initial parameter values yet. +However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain. 
+To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): +- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`. + Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`. ```@docs diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index b1dc6b7d..07960440 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -1,16 +1,17 @@ module AbstractMCMC -import BangBang -import ConsoleProgressMonitor -import LoggingExtras -import ProgressLogging -import StatsBase -import TerminalLoggers -import Transducers - -import Distributed -import Logging -import Random +using BangBang: BangBang +using ConsoleProgressMonitor: ConsoleProgressMonitor +using LogDensityProblems: LogDensityProblems +using LoggingExtras: LoggingExtras +using ProgressLogging: ProgressLogging +using StatsBase: StatsBase +using TerminalLoggers: TerminalLoggers +using Transducers: Transducers + +using Distributed: Distributed +using Logging: Logging +using Random: Random # Reexport sample using StatsBase: sample @@ -71,7 +72,6 @@ processes. 
""" struct MCMCDistributed <: AbstractMCMCEnsemble end - """ MCMCSerial @@ -115,6 +115,6 @@ include("interface.jl") include("sample.jl") include("stepper.jl") include("transducer.jl") -include("deprecations.jl") +include("logdensityproblems.jl") end # module AbstractMCMC diff --git a/src/deprecations.jl b/src/deprecations.jl deleted file mode 100644 index 1cc93d12..00000000 --- a/src/deprecations.jl +++ /dev/null @@ -1,2 +0,0 @@ -# Deprecate the old name AbstractMCMCParallel in favor of AbstractMCMCEnsemble -Base.@deprecate_binding AbstractMCMCParallel AbstractMCMCEnsemble false \ No newline at end of file diff --git a/src/interface.jl b/src/interface.jl index 7b3daefb..928a933d 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -30,24 +30,31 @@ be specified with the `chain_type` argument. By default, this method returns `samples`. """ function bundle_samples( + samples, model::AbstractModel, sampler::AbstractSampler, state, ::Type{T}; kwargs... +) where {T} + # dispatch to internal method for default implementations to fix + # method ambiguity issues (see #120) + return _bundle_samples(samples, model, sampler, state, T; kwargs...) +end + +function _bundle_samples( samples, - ::AbstractModel, - ::AbstractSampler, - ::Any, + @nospecialize(::AbstractModel), + @nospecialize(::AbstractSampler), + @nospecialize(::Any), ::Type; - kwargs... + kwargs..., ) return samples end - -function bundle_samples( +function _bundle_samples( samples::Vector, - ::AbstractModel, - ::AbstractSampler, - ::Any, + @nospecialize(::AbstractModel), + @nospecialize(::AbstractSampler), + @nospecialize(::Any), ::Type{Vector{T}}; - kwargs... -) where T + kwargs..., +) where {T} return map(samples) do sample convert(T, sample) end @@ -74,24 +81,13 @@ sample is `sample`. The method can be called with and without a predefined number `N` of samples. """ -function samples( - sample, - ::AbstractModel, - ::AbstractSampler, - N::Integer; - kwargs... 
-) +function samples(sample, ::AbstractModel, ::AbstractSampler, N::Integer; kwargs...) ts = Vector{typeof(sample)}(undef, 0) sizehint!(ts, N) return ts end -function samples( - sample, - ::AbstractModel, - ::AbstractSampler; - kwargs... -) +function samples(sample, ::AbstractModel, ::AbstractSampler; kwargs...) return Vector{typeof(sample)}(undef, 0) end @@ -113,7 +109,7 @@ function save!!( ::AbstractModel, ::AbstractSampler, N::Integer; - kwargs... + kwargs..., ) s = BangBang.push!!(samples, sample) s !== samples && sizehint!(s, N) @@ -121,27 +117,15 @@ function save!!( end function save!!( - samples, - sample, - iteration::Integer, - ::AbstractModel, - ::AbstractSampler; - kwargs... + samples, sample, iteration::Integer, ::AbstractModel, ::AbstractSampler; kwargs... ) return BangBang.push!!(samples, sample) end # Deprecations Base.@deprecate transitions( - transition, - model::AbstractModel, - sampler::AbstractSampler, - N::Integer; - kwargs... + transition, model::AbstractModel, sampler::AbstractSampler, N::Integer; kwargs... ) samples(transition, model, sampler, N; kwargs...) false Base.@deprecate transitions( - transition, - model::AbstractModel, - sampler::AbstractSampler; - kwargs... + transition, model::AbstractModel, sampler::AbstractSampler; kwargs... ) samples(transition, model, sampler; kwargs...) false diff --git a/src/logdensityproblems.jl b/src/logdensityproblems.jl new file mode 100644 index 00000000..f15f656a --- /dev/null +++ b/src/logdensityproblems.jl @@ -0,0 +1,119 @@ +""" + LogDensityModel <: AbstractMCMC.AbstractModel + +Wrapper around something that implements the LogDensityProblems.jl interface. + +Note that this does _not_ implement the LogDensityProblems.jl interface itself, +but it is simply useful for indicating to the `sample` and other `AbstractMCMC` methods +that the wrapped object implements the LogDensityProblems.jl interface. + +# Fields +- `logdensity`: The object that implements the LogDensityProblems.jl interface. 
+""" +struct LogDensityModel{L} <: AbstractModel + logdensity::L + function LogDensityModel{L}(logdensity::L) where {L} + if LogDensityProblems.capabilities(logdensity) === nothing + throw( + ArgumentError( + "The log density function does not support the LogDensityProblems.jl interface", + ), + ) + end + return new{L}(logdensity) + end +end + +LogDensityModel(logdensity::L) where {L} = LogDensityModel{L}(logdensity) + +# Fallbacks: Wrap log density function in a model +""" + sample( + rng::Random.AbstractRNG=Random.default_rng(), + logdensity, + sampler::AbstractSampler, + N_or_isdone; + kwargs..., + ) + +Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `sample` with the resulting model instead of `logdensity`. + +The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface. +""" +function StatsBase.sample( + rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler, N_or_isdone; kwargs... +) + return StatsBase.sample(rng, _model(logdensity), sampler, N_or_isdone; kwargs...) +end + +""" + sample( + rng::Random.AbstractRNG=Random.default_rng(), + logdensity, + sampler::AbstractSampler, + parallel::AbstractMCMCEnsemble, + N::Integer, + nchains::Integer; + kwargs..., + ) + +Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `sample` with the resulting model instead of `logdensity`. + +The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface. +""" +function StatsBase.sample( + rng::Random.AbstractRNG, + logdensity, + sampler::AbstractSampler, + parallel::AbstractMCMCEnsemble, + N::Integer, + nchains::Integer; + kwargs..., +) + return StatsBase.sample( + rng, _model(logdensity), sampler, parallel, N, nchains; kwargs... 
+ ) +end + +""" + steps( + rng::Random.AbstractRNG=Random.default_rng(), + logdensity, + sampler::AbstractSampler; + kwargs..., + ) + +Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `steps` with the resulting model instead of `logdensity`. + +The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface. +""" +function steps(rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler; kwargs...) + return steps(rng, _model(logdensity), sampler; kwargs...) +end + +""" + Sample( + rng::Random.AbstractRNG=Random.default_rng(), + logdensity, + sampler::AbstractSampler; + kwargs..., + ) + +Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `Sample` with the resulting model instead of `logdensity`. + +The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface. +""" +function Sample(rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler; kwargs...) + return Sample(rng, _model(logdensity), sampler; kwargs...) +end + +function _model(logdensity) + if LogDensityProblems.capabilities(logdensity) === nothing + throw( + ArgumentError( + "the log density function does not support the LogDensityProblems.jl interface. Please implement the interface or provide a model of type `AbstractMCMC.AbstractModel`", + ), + ) + end + return LogDensityModel(logdensity) +end diff --git a/src/logging.jl b/src/logging.jl index a550c532..04c41187 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -2,19 +2,21 @@ # and add a custom progress logger if the current logger does not seem to be able to handle # progress logs macro ifwithprogresslogger(progress, exprs...) - return quote - if $progress - if $hasprogresslevel($Logging.current_logger()) - $ProgressLogging.@withprogress $(exprs...) 
- else - $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do + return esc( + quote + if $progress + if $hasprogresslevel($Logging.current_logger()) $ProgressLogging.@withprogress $(exprs...) + else + $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do + $ProgressLogging.@withprogress $(exprs...) + end end + else + $(exprs[end]) end - else - $(exprs[end]) - end - end |> esc + end, + ) end # improved checks? @@ -31,13 +33,14 @@ function with_progresslogger(f, _module, logger) log._module !== _module || log.level != ProgressLogging.ProgressLevel end - Logging.with_logger(f, LoggingExtras.TeeLogger(logger1, logger2)) + return Logging.with_logger(f, LoggingExtras.TeeLogger(logger1, logger2)) end function progresslogger() # detect if code is running under IJulia since TerminalLogger does not work with IJulia # https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia - if (Sys.iswindows() && VERSION < v"1.5.3") || (isdefined(Main, :IJulia) && Main.IJulia.inited) + if (Sys.iswindows() && VERSION < v"1.5.3") || + (isdefined(Main, :IJulia) && Main.IJulia.inited) return ConsoleProgressMonitor.ProgressLogger() else return TerminalLoggers.TerminalLogger() diff --git a/src/sample.jl b/src/sample.jl index df76caf0..6c9c32ae 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -13,65 +13,67 @@ function setprogress!(progress::Bool) end function StatsBase.sample( - model::AbstractModel, - sampler::AbstractSampler, - arg; - kwargs... + model_or_logdensity, sampler::AbstractSampler, N_or_isdone; kwargs... ) - return StatsBase.sample(Random.GLOBAL_RNG, model, sampler, arg; kwargs...) + return StatsBase.sample( + Random.default_rng(), model_or_logdensity, sampler, N_or_isdone; kwargs... + ) end """ - sample([rng, ]model, sampler, N; kwargs...) - -Return `N` samples from the `model` with the Markov chain Monte Carlo `sampler`. 
-""" -function StatsBase.sample( - rng::Random.AbstractRNG, - model::AbstractModel, - sampler::AbstractSampler, - N::Integer; - kwargs... -) - return mcmcsample(rng, model, sampler, N; kwargs...) -end + sample( + rng::Random.AbatractRNG=Random.default_rng(), + model::AbstractModel, + sampler::AbstractSampler, + N_or_isdone; + kwargs..., + ) -""" - sample([rng, ]model, sampler, isdone; kwargs...) +Sample from the `model` with the Markov chain Monte Carlo `sampler` and return the samples. -Sample from the `model` with the Markov chain Monte Carlo `sampler` until a -convergence criterion `isdone` returns `true`, and return the samples. +If `N_or_isdone` is an `Integer`, exactly `N_or_isdone` samples are returned. -The function `isdone` has the signature +Otherwise, sampling is performed until a convergence criterion `N_or_isdone` returns `true`. +The convergence criterion has to be a function with the signature ```julia -isdone(rng, model, sampler, samples, iteration; kwargs...) +isdone(rng, model, sampler, samples, state, iteration; kwargs...) ``` -and should return `true` when sampling should end, and `false` otherwise. +where `state` and `iteration` are the current state and iteration of the sampler, respectively. +It should return `true` when sampling should end, and `false` otherwise. """ function StatsBase.sample( rng::Random.AbstractRNG, model::AbstractModel, sampler::AbstractSampler, - isdone; - kwargs... + N_or_isdone; + kwargs..., ) - return mcmcsample(rng, model, sampler, isdone; kwargs...) + return mcmcsample(rng, model, sampler, N_or_isdone; kwargs...) end function StatsBase.sample( - model::AbstractModel, + model_or_logdensity, sampler::AbstractSampler, parallel::AbstractMCMCEnsemble, N::Integer, nchains::Integer; - kwargs... + kwargs..., ) - return StatsBase.sample(Random.GLOBAL_RNG, model, sampler, parallel, N, nchains; - kwargs...) + return StatsBase.sample( + Random.default_rng(), model_or_logdensity, sampler, parallel, N, nchains; kwargs... 
+ ) end """ - sample([rng, ]model, sampler, parallel, N, nchains; kwargs...) + sample( + rng::Random.AbstractRNG=Random.default_rng(), + model::AbstractModel, + sampler::AbstractSampler, + parallel::AbstractMCMCEnsemble, + N::Integer, + nchains::Integer; + kwargs..., + ) Sample `nchains` Monte Carlo Markov chains from the `model` with the `sampler` in parallel using the `parallel` algorithm, and combine them into a single chain. @@ -83,7 +85,7 @@ function StatsBase.sample( parallel::AbstractMCMCEnsemble, N::Integer, nchains::Integer; - kwargs... + kwargs..., ) return mcmcsample(rng, model, sampler, parallel, N, nchains; kwargs...) end @@ -95,13 +97,13 @@ function mcmcsample( model::AbstractModel, sampler::AbstractSampler, N::Integer; - progress = PROGRESS[], - progressname = "Sampling", - callback = nothing, - discard_initial = 0, - thinning = 1, + progress=PROGRESS[], + progressname="Sampling", + callback=nothing, + discard_initial=0, + thinning=1, chain_type::Type=Any, - kwargs... + kwargs..., ) # Check the number of requested samples. N > 0 || error("the number of samples must be ≥ 1") @@ -111,7 +113,7 @@ function mcmcsample( start = time() local state - @ifwithprogresslogger progress name=progressname begin + @ifwithprogresslogger progress name = progressname begin # Determine threshold values for progress logging # (one update per 0.5% of progress) if progress @@ -123,10 +125,10 @@ function mcmcsample( sample, state = step(rng, model, sampler; kwargs...) # Discard initial samples. - for i in 1:(discard_initial - 1) + for i in 1:discard_initial # Update the progress bar. if progress && i >= next_update - ProgressLogging.@logprogress i/Ntotal + ProgressLogging.@logprogress i / Ntotal next_update = i + threshold end @@ -166,7 +168,8 @@ function mcmcsample( sample, state = step(rng, model, sampler, state; kwargs...) # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) 
+ callback === nothing || + callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. samples = save!!(samples, sample, i, model, sampler, N; kwargs...) @@ -185,15 +188,15 @@ function mcmcsample( stats = SamplingStats(start, stop, duration) return bundle_samples( - samples, - model, + samples, + model, sampler, state, chain_type; stats=stats, discard_initial=discard_initial, thinning=thinning, - kwargs... + kwargs..., ) end @@ -203,24 +206,24 @@ function mcmcsample( sampler::AbstractSampler, isdone; chain_type::Type=Any, - progress = PROGRESS[], - progressname = "Convergence sampling", - callback = nothing, - discard_initial = 0, - thinning = 1, - kwargs... + progress=PROGRESS[], + progressname="Convergence sampling", + callback=nothing, + discard_initial=0, + thinning=1, + kwargs..., ) # Start the timer start = time() local state - @ifwithprogresslogger progress name=progressname begin + @ifwithprogresslogger progress name = progressname begin # Obtain the initial sample and state. sample, state = step(rng, model, sampler; kwargs...) # Discard initial samples. - for _ in 2:discard_initial + for _ in 1:discard_initial # Obtain the next sample and state. sample, state = step(rng, model, sampler, state; kwargs...) end @@ -246,7 +249,8 @@ function mcmcsample( sample, state = step(rng, model, sampler, state; kwargs...) # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + callback === nothing || + callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. samples = save!!(samples, sample, i, model, sampler; kwargs...) @@ -263,15 +267,15 @@ function mcmcsample( # Wrap the samples up. return bundle_samples( - samples, + samples, model, - sampler, - state, - chain_type; + sampler, + state, + chain_type; stats=stats, discard_initial=discard_initial, thinning=thinning, - kwargs... 
+ kwargs..., ) end @@ -282,9 +286,10 @@ function mcmcsample( ::MCMCThreads, N::Integer, nchains::Integer; - progress = PROGRESS[], - progressname = "Sampling ($(min(nchains, Threads.nthreads())) threads)", - kwargs... + progress=PROGRESS[], + progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)", + init_params=nothing, + kwargs..., ) # Check if actually multiple threads are used. if Threads.nthreads() == 1 @@ -297,10 +302,9 @@ function mcmcsample( end # Copy the random number generator, model, and sample for each thread - # NOTE: As of May 17, 2020, this relies on Julia's thread scheduling functionality - # that distributes a for loop into equal-sized blocks and allocates them - # to each thread. If this changes, we may need to rethink things here. - interval = 1:min(nchains, Threads.nthreads()) + nchunks = min(nchains, Threads.nthreads()) + chunksize = cld(nchains, nchunks) + interval = 1:nchunks rngs = [deepcopy(rng) for _ in interval] models = [deepcopy(model) for _ in interval] samplers = [deepcopy(sampler) for _ in interval] @@ -308,10 +312,13 @@ function mcmcsample( # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) + # Ensure that initial parameters are `nothing` or of the correct length + check_initial_params(init_params, nchains) + # Set up a chains vector. chains = Vector{Any}(undef, nchains) - @ifwithprogresslogger progress name=progressname begin + @ifwithprogresslogger progress name = progressname begin # Create a channel for progress logging. 
if progress channel = Channel{Bool}(length(interval)) @@ -330,7 +337,7 @@ function mcmcsample( while take!(channel) progresschains += 1 if progresschains >= nextprogresschains - ProgressLogging.@logprogress progresschains/nchains + ProgressLogging.@logprogress progresschains / nchains nextprogresschains = progresschains + threshold end end @@ -339,20 +346,35 @@ function mcmcsample( Distributed.@async begin try - Threads.@threads for i in 1:nchains - # Obtain the ID of the current thread. - id = Threads.threadid() - - # Seed the thread-specific random number generator with the pre-made seed. - subrng = rngs[id] - Random.seed!(subrng, seeds[i]) - - # Sample a chain and save it to the vector. - chains[i] = StatsBase.sample(subrng, models[id], samplers[id], N; - progress = false, kwargs...) - - # Update the progress bar. - progress && put!(channel, true) + Distributed.@sync for (i, _rng, _model, _sampler) in + zip(1:nchunks, rngs, models, samplers) + chainidxs = if i == nchunks + ((i - 1) * chunksize + 1):nchains + else + ((i - 1) * chunksize + 1):(i * chunksize) + end + Threads.@spawn for chainidx in chainidxs + # Seed the chunk-specific random number generator with the pre-made seed. + Random.seed!(_rng, seeds[chainidx]) + + # Sample a chain and save it to the vector. + chains[chainidx] = StatsBase.sample( + _rng, + _model, + _sampler, + N; + progress=false, + init_params=if init_params === nothing + nothing + else + init_params[chainidx] + end, + kwargs..., + ) + + # Update the progress bar. + progress && put!(channel, true) + end end finally # Stop updating the progress bar. @@ -373,9 +395,10 @@ function mcmcsample( ::MCMCDistributed, N::Integer, nchains::Integer; - progress = PROGRESS[], - progressname = "Sampling ($(Distributed.nworkers()) processes)", - kwargs... + progress=PROGRESS[], + progressname="Sampling ($(Distributed.nworkers()) processes)", + init_params=nothing, + kwargs..., ) # Check if actually multiple processes are used. 
if Distributed.nworkers() == 1 @@ -387,6 +410,9 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end + # Ensure that initial parameters are `nothing` or of the correct length + check_initial_params(init_params, nchains) + # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -394,7 +420,7 @@ function mcmcsample( pool = Distributed.CachingPool(Distributed.workers()) local chains - @ifwithprogresslogger progress name=progressname begin + @ifwithprogresslogger progress name = progressname begin # Create a channel for progress logging. if progress channel = Distributed.RemoteChannel(() -> Channel{Bool}(Distributed.nworkers())) @@ -413,7 +439,7 @@ function mcmcsample( while take!(channel) progresschains += 1 if progresschains >= nextprogresschains - ProgressLogging.@logprogress progresschains/nchains + ProgressLogging.@logprogress progresschains / nchains nextprogresschains = progresschains + threshold end end @@ -422,13 +448,20 @@ function mcmcsample( Distributed.@async begin try - chains = Distributed.pmap(pool, seeds) do seed + function sample_chain(seed, init_params=nothing) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) # Sample a chain. - chain = StatsBase.sample(rng, model, sampler, N; - progress = false, kwargs...) + chain = StatsBase.sample( + rng, + model, + sampler, + N; + progress=false, + init_params=init_params, + kwargs..., + ) # Update the progress bar. progress && put!(channel, true) @@ -436,6 +469,11 @@ function mcmcsample( # Return the new chain. return chain end + chains = if init_params === nothing + Distributed.pmap(sample_chain, pool, seeds) + else + Distributed.pmap(sample_chain, pool, seeds, init_params) + end finally # Stop updating the progress bar. 
progress && put!(channel, false) @@ -455,20 +493,43 @@ function mcmcsample( ::MCMCSerial, N::Integer, nchains::Integer; - progressname = "Sampling", - kwargs... + progressname="Sampling", + init_params=nothing, + kwargs..., ) # Check if the number of chains is larger than the number of samples if nchains > N @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end + # Ensure that initial parameters are `nothing` or of the correct length + check_initial_params(init_params, nchains) + + # Create a seed for each chain using the provided random number generator. + seeds = rand(rng, UInt, nchains) + # Sample the chains. - chains = map( - i -> StatsBase.sample(rng, model, sampler, N; progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"), - kwargs...), - 1:nchains - ) + function sample_chain(i, seed, init_params=nothing) + # Seed a new random number generator with the pre-made seed. + Random.seed!(rng, seed) + + # Sample a chain. + return StatsBase.sample( + rng, + model, + sampler, + N; + progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"), + init_params=init_params, + kwargs..., + ) + end + + chains = if init_params === nothing + map(sample_chain, 1:nchains, seeds) + else + map(sample_chain, 1:nchains, seeds, init_params) + end # Concatenate the chains together. 
return chainsstack(tighten_eltype(chains)) @@ -476,3 +537,22 @@ end tighten_eltype(x) = x tighten_eltype(x::Vector{Any}) = map(identity, x) + +@nospecialize check_initial_params(x, n) = throw( + ArgumentError( + "initial parameters must be specified as a vector of length equal to the number of chains or `nothing`", + ), +) + +check_initial_params(::Nothing, n) = nothing +function check_initial_params(x::AbstractArray, n) + if length(x) != n + throw( + ArgumentError( + "incorrect number of initial parameters (expected $n, received $(length(x))" + ), + ) + end + + return nothing +end diff --git a/src/samplingstats.jl b/src/samplingstats.jl index dea2b653..c5820dff 100644 --- a/src/samplingstats.jl +++ b/src/samplingstats.jl @@ -13,4 +13,4 @@ struct SamplingStats start::Float64 stop::Float64 duration::Float64 -end \ No newline at end of file +end diff --git a/src/stepper.jl b/src/stepper.jl index 34391851..a71826cb 100644 --- a/src/stepper.jl +++ b/src/stepper.jl @@ -5,24 +5,53 @@ struct Stepper{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} kwargs::K end -Base.iterate(stp::Stepper) = step(stp.rng, stp.model, stp.sampler; stp.kwargs...) +# Initial sample. +function Base.iterate(stp::Stepper) + # Unpack iterator. + rng = stp.rng + model = stp.model + sampler = stp.sampler + kwargs = stp.kwargs + discard_initial = get(kwargs, :discard_initial, 0)::Int + + # Start sampling algorithm and discard initial samples if desired. + sample, state = step(rng, model, sampler; kwargs...) + for _ in 1:discard_initial + sample, state = step(rng, model, sampler, state; kwargs...) + end + return sample, state +end + +# Subsequent samples. function Base.iterate(stp::Stepper, state) - return step(stp.rng, stp.model, stp.sampler, state; stp.kwargs...) + # Unpack iterator. + rng = stp.rng + model = stp.model + sampler = stp.sampler + kwargs = stp.kwargs + thinning = get(kwargs, :thinning, 1)::Int + + # Return next sample, possibly after thinning the chain if desired. 
+ for _ in 1:(thinning - 1) + _, state = step(rng, model, sampler, state; kwargs...) + end + return step(rng, model, sampler, state; kwargs...) end Base.IteratorSize(::Type{<:Stepper}) = Base.IsInfinite() Base.IteratorEltype(::Type{<:Stepper}) = Base.EltypeUnknown() -function steps( - model::AbstractModel, - sampler::AbstractSampler; - kwargs... -) - return steps(Random.GLOBAL_RNG, model, sampler; kwargs...) +function steps(model_or_logdensity, sampler::AbstractSampler; kwargs...) + return steps(Random.default_rng(), model_or_logdensity, sampler; kwargs...) end """ - steps([rng, ]model, sampler; kwargs...) + steps( + rng::Random.AbstractRNG=Random.default_rng(), + model::AbstractModel, + sampler::AbstractSampler; + kwargs..., + ) Create an iterator that returns samples from the `model` with the Markov chain Monte Carlo `sampler`. @@ -46,10 +75,7 @@ true ``` """ function steps( - rng::Random.AbstractRNG, - model::AbstractModel, - sampler::AbstractSampler; - kwargs... + rng::Random.AbstractRNG, model::AbstractModel, sampler::AbstractSampler; kwargs... ) return Stepper(rng, model, sampler, kwargs) end diff --git a/src/transducer.jl b/src/transducer.jl index 7aca51e0..63bff3fd 100644 --- a/src/transducer.jl +++ b/src/transducer.jl @@ -1,16 +1,22 @@ -struct Sample{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} <: Transducers.Transducer +struct Sample{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} <: + Transducers.Transducer rng::A model::M sampler::S kwargs::K end -function Sample(model::AbstractModel, sampler::AbstractSampler; kwargs...) - return Sample(Random.GLOBAL_RNG, model, sampler; kwargs...) +function Sample(model_or_logdensity, sampler::AbstractSampler; kwargs...) + return Sample(Random.default_rng(), model_or_logdensity, sampler; kwargs...) end """ - Sample([rng, ]model, sampler; kwargs...) 
+ Sample( + rng::Random.AbstractRNG=Random.default_rng(), + model::AbstractModel, + sampler::AbstractSampler; + kwargs..., + ) Create a transducer that returns samples from the `model` with the Markov chain Monte Carlo `sampler`. @@ -34,32 +40,63 @@ true ``` """ function Sample( - rng::Random.AbstractRNG, - model::AbstractModel, - sampler::AbstractSampler; - kwargs... + rng::Random.AbstractRNG, model::AbstractModel, sampler::AbstractSampler; kwargs... ) return Sample(rng, model, sampler, kwargs) end +# Initial sample. function Transducers.start(rf::Transducers.R_{<:Sample}, result) - sampler = Transducers.xform(rf) + # Unpack transducer. + td = Transducers.xform(rf) + rng = td.rng + model = td.model + sampler = td.sampler + kwargs = td.kwargs + discard_initial = get(kwargs, :discard_initial, 0)::Int + + # Start sampling algorithm and discard initial samples if desired. + sample, state = step(rng, model, sampler; kwargs...) + for _ in 1:discard_initial + sample, state = step(rng, model, sampler, state; kwargs...) + end + return Transducers.wrap( - rf, - step(sampler.rng, sampler.model, sampler.sampler; sampler.kwargs...), - Transducers.start(Transducers.inner(rf), result), + rf, (sample, state), Transducers.start(Transducers.inner(rf), result) ) end +# Subsequent samples. function Transducers.next(rf::Transducers.R_{<:Sample}, result, input) - t = Transducers.xform(rf) - Transducers.wrapping(rf, result) do (sample, state), iresult - iresult2 = Transducers.next(Transducers.inner(rf), iresult, sample) - return step(t.rng, t.model, t.sampler, state; t.kwargs...), iresult2 + # Unpack transducer. 
+ td = Transducers.xform(rf) + rng = td.rng + model = td.model + sampler = td.sampler + kwargs = td.kwargs + thinning = get(kwargs, :thinning, 1)::Int + + let rng = rng, + model = model, + sampler = sampler, + kwargs = kwargs, + thinning = thinning, + inner_rf = Transducers.inner(rf) + + Transducers.wrapping(rf, result) do (sample, state), iresult + iresult2 = Transducers.next(inner_rf, iresult, sample) + + # Perform thinning if desired. + for _ in 1:(thinning - 1) + _, state = step(rng, model, sampler, state; kwargs...) + end + + return step(rng, model, sampler, state; kwargs...), iresult2 + end end end function Transducers.complete(rf::Transducers.R_{Sample}, result) - _private_state, inner_result = Transducers.unwrap(rf, result) + _, inner_result = Transducers.unwrap(rf, result) return Transducers.complete(Transducers.inner(rf), inner_result) end diff --git a/test/deprecations.jl b/test/deprecations.jl deleted file mode 100644 index f866668c..00000000 --- a/test/deprecations.jl +++ /dev/null @@ -1,4 +0,0 @@ -@testset "deprecations.jl" begin - @test_deprecated AbstractMCMC.transitions(MySample(1, 2.0), MyModel(), MySampler()) - @test_deprecated AbstractMCMC.transitions(MySample(1, 2.0), MyModel(), MySampler(), 3) -end \ No newline at end of file diff --git a/test/logdensityproblems.jl b/test/logdensityproblems.jl new file mode 100644 index 00000000..181d2645 --- /dev/null +++ b/test/logdensityproblems.jl @@ -0,0 +1,90 @@ +@testset "logdensityproblems.jl" begin + # Add worker processes. + # Memory requirements on Windows are ~4x larger than on Linux, hence number of processes is reduced + # See, e.g., https://github.com/JuliaLang/julia/issues/40766 and https://github.com/JuliaLang/Pkg.jl/pull/2366 + pids = addprocs(Sys.iswindows() ? div(Sys.CPU_THREADS::Int, 2) : Sys.CPU_THREADS::Int) + + # Load all required packages (`utils.jl` needs LogDensityProblems, Logging, and Random). 
+ @everywhere begin + using AbstractMCMC + using AbstractMCMC: sample + using LogDensityProblems + + using Logging + using Random + include("utils.jl") + end + + @testset "LogDensityModel" begin + ℓ = MyLogDensity(10) + model = @inferred AbstractMCMC.LogDensityModel(ℓ) + @test model isa AbstractMCMC.LogDensityModel{MyLogDensity} + @test model.logdensity === ℓ + + @test_throws ArgumentError AbstractMCMC.LogDensityModel(mylogdensity) + end + + @testset "fallback for log densities" begin + # Sample with log density + dim = 10 + ℓ = MyLogDensity(dim) + Random.seed!(1234) + N = 1_000 + samples = sample(ℓ, MySampler(), N) + + # Samples are of the correct dimension and log density values are correct + @test length(samples) == N + @test all(length(x.a) == dim for x in samples) + @test all(x.b ≈ LogDensityProblems.logdensity(ℓ, x.a) for x in samples) + + # Same chain as if LogDensityModel is used explicitly + Random.seed!(1234) + samples2 = sample(AbstractMCMC.LogDensityModel(ℓ), MySampler(), N) + @test length(samples2) == N + @test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples2)) + + # Same chain if sampling is performed with convergence criterion + Random.seed!(1234) + isdone(rng, model, sampler, state, samples, iteration; kwargs...) 
= iteration > N + samples3 = sample(ℓ, MySampler(), isdone) + @test length(samples3) == N + @test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples3)) + + # Same chain if sampling is performed with iterator + Random.seed!(1234) + samples4 = collect(Iterators.take(AbstractMCMC.steps(ℓ, MySampler()), N)) + @test length(samples4) == N + @test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples4)) + + # Same chain if sampling is performed with transducer + Random.seed!(1234) + xf = AbstractMCMC.Sample(ℓ, MySampler()) + samples5 = collect(xf(1:N)) + @test length(samples5) == N + @test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples5)) + + # Parallel sampling + for alg in (MCMCSerial(), MCMCDistributed(), MCMCThreads()) + chains = sample(ℓ, MySampler(), alg, N, 2) + @test length(chains) == 2 + samples = vcat(chains[1], chains[2]) + @test length(samples) == 2 * N + @test all(length(x.a) == dim for x in samples) + @test all(x.b ≈ LogDensityProblems.logdensity(ℓ, x.a) for x in samples) + end + + # Log density has to satisfy the LogDensityProblems interface + @test_throws ArgumentError sample(mylogdensity, MySampler(), N) + @test_throws ArgumentError sample(mylogdensity, MySampler(), isdone) + @test_throws ArgumentError sample(mylogdensity, MySampler(), MCMCSerial(), N, 2) + @test_throws ArgumentError sample(mylogdensity, MySampler(), MCMCThreads(), N, 2) + @test_throws ArgumentError sample( + mylogdensity, MySampler(), MCMCDistributed(), N, 2 + ) + @test_throws ArgumentError AbstractMCMC.steps(mylogdensity, MySampler()) + @test_throws ArgumentError AbstractMCMC.Sample(mylogdensity, MySampler()) + end + + # Remove workers + rmprocs(pids...) 
+end diff --git a/test/runtests.jl b/test/runtests.jl index c3f108e1..909ae8b3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,13 +1,14 @@ using AbstractMCMC -using Atom.Progress: JunoProgressLogger using ConsoleProgressMonitor: ProgressLogger using IJulia +using LogDensityProblems using LoggingExtras: TeeLogger, EarlyFilteredLogger using TerminalLoggers: TerminalLogger +using FillArrays: FillArrays using Transducers using Distributed -import Logging +using Logging: Logging using Random using Statistics using Test @@ -22,5 +23,5 @@ include("utils.jl") include("sample.jl") include("stepper.jl") include("transducer.jl") - include("deprecations.jl") + include("logdensityproblems.jl") end diff --git a/test/sample.jl b/test/sample.jl index 6e876d48..22f4b26d 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -5,12 +5,13 @@ Random.seed!(1234) N = 1_000 - chain = sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) + chain = sample(MyModel(), MySampler(), N; loggers=true) @test length(LOGGERS) == 1 logger = first(LOGGERS) @test logger isa TeeLogger - @test logger.loggers[1].logger isa (Sys.iswindows() && VERSION < v"1.5.3" ? ProgressLogger : TerminalLogger) + @test logger.loggers[1].logger isa + (Sys.iswindows() && VERSION < v"1.5.3" ? 
ProgressLogger : TerminalLogger) @test logger.loggers[2].logger === CURRENT_LOGGER @test Logging.current_logger() === CURRENT_LOGGER @@ -20,26 +21,17 @@ # test some statistical properties tail_chain = @view chain[2:end] - @test mean(x.a for x in tail_chain) ≈ 0.5 atol=6e-2 - @test var(x.a for x in tail_chain) ≈ 1 / 12 atol=5e-3 - @test mean(x.b for x in tail_chain) ≈ 0.0 atol=5e-2 - @test var(x.b for x in tail_chain) ≈ 1 atol=6e-2 - end - - @testset "Juno" begin - empty!(LOGGERS) - - Random.seed!(1234) - N = 10 - - logger = JunoProgressLogger() - Logging.with_logger(logger) do - sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) - end - - @test length(LOGGERS) == 1 - @test first(LOGGERS) === logger - @test Logging.current_logger() === CURRENT_LOGGER + @test mean(x.a for x in tail_chain) ≈ 0.5 atol = 6e-2 + @test var(x.a for x in tail_chain) ≈ 1 / 12 atol = 5e-3 + @test mean(x.b for x in tail_chain) ≈ 0.0 atol = 5e-2 + @test var(x.b for x in tail_chain) ≈ 1 atol = 6e-2 + + # initial parameters + chain = sample( + MyModel(), MySampler(), 3; progress=false, init_params=(b=3.2, a=-1.8) + ) + @test chain[1].a == -1.8 + @test chain[1].b == 3.2 end @testset "IJulia" begin @@ -52,7 +44,7 @@ Random.seed!(1234) N = 10 - sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) + sample(MyModel(), MySampler(), N; loggers=true) @test length(LOGGERS) == 1 logger = first(LOGGERS) @@ -74,7 +66,7 @@ logger = Logging.ConsoleLogger(stderr, Logging.LogLevel(-1)) Logging.with_logger(logger) do - sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) + sample(MyModel(), MySampler(), N; loggers=true) end @test length(LOGGERS) == 1 @@ -84,98 +76,169 @@ @testset "Suppress output" begin logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), 100; progress = false, sleepy = true) + sample(MyModel(), MySampler(), 100; progress=false) end @test all(l.level > Logging.LogLevel(-1) for l in logs) # disable progress logging 
globally - @test !(@test_logs (:info, "progress logging is disabled globally") AbstractMCMC.setprogress!(false)) + @test !(@test_logs (:info, "progress logging is disabled globally") AbstractMCMC.setprogress!( + false + )) @test !AbstractMCMC.PROGRESS[] logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), 100; sleepy = true) + sample(MyModel(), MySampler(), 100) end @test all(l.level > Logging.LogLevel(-1) for l in logs) # enable progress logging globally - @test (@test_logs (:info, "progress logging is enabled globally") AbstractMCMC.setprogress!(true)) + @test (@test_logs (:info, "progress logging is enabled globally") AbstractMCMC.setprogress!( + true + )) @test AbstractMCMC.PROGRESS[] end end - if VERSION ≥ v"1.3" - @testset "Multithreaded sampling" begin - if Threads.nthreads() == 1 - warnregex = r"^Only a single thread available" - @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(), - 10, 10) - end + @testset "Multithreaded sampling" begin + if Threads.nthreads() == 1 + warnregex = r"^Only a single thread available" + @test_logs (:warn, warnregex) sample( + MyModel(), MySampler(), MCMCThreads(), 10, 10 + ) + end - # No dedicated chains type - N = 10_000 - chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000) - @test chains isa Vector{<:Vector{<:MySample}} - @test length(chains) == 1000 - @test all(length(x) == N for x in chains) + # No dedicated chains type + N = 10_000 + chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000) + @test chains isa Vector{<:Vector{<:MySample}} + @test length(chains) == 1000 + @test all(length(x) == N for x in chains) - Random.seed!(1234) - chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; - chain_type = MyChain) + Random.seed!(1234) + chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; chain_type=MyChain) - # test output type and size - @test chains isa Vector{<:MyChain} - @test length(chains) == 1000 - @test all(x -> 
length(x.as) == length(x.bs) == N, chains) + # test output type and size + @test chains isa Vector{<:MyChain} + @test length(chains) == 1000 + @test all(x -> length(x.as) == length(x.bs) == N, chains) + @test all(ismissing(x.as[1]) for x in chains) - # test some statistical properties - @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) - @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) + # test some statistical properties + @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) + @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains) - # test reproducibility - Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; - chain_type = MyChain) + # test reproducibility + Random.seed!(1234) + chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; chain_type=MyChain) + @test all(ismissing(x.as[1]) for x in chains2) + @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N) + @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + # Unexpected order of arguments. + str = "Number of chains (10) is greater than number of samples per chain (5)" + @test_logs (:warn, str) match_mode = :any sample( + MyModel(), MySampler(), MCMCThreads(), 5, 10; chain_type=MyChain + ) - # Unexpected order of arguments. 
- str = "Number of chains (10) is greater than number of samples per chain (5)" - @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), - MCMCThreads(), 5, 10; - chain_type = MyChain) + # Suppress output. + logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do + sample( + MyModel(), + MySampler(), + MCMCThreads(), + 10_000, + 1000; + progress=false, + chain_type=MyChain, + ) + end + @test all(l.level > Logging.LogLevel(-1) for l in logs) - # Suppress output. - logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000; - progress = false, chain_type = MyChain) - end - @test all(l.level > Logging.LogLevel(-1) for l in logs) - - # Smoke test for nchains < nthreads - if Threads.nthreads() == 2 - sample(MyModel(), MySampler(), MCMCThreads(), N, 1) - end + # Smoke test for nchains < nthreads + if Threads.nthreads() == 2 + sample(MyModel(), MySampler(), MCMCThreads(), N, 1) end + + # initial parameters + nchains = 100 + init_params = [(b=randn(), a=rand()) for _ in 1:nchains] + chains = sample( + MyModel(), + MySampler(), + MCMCThreads(), + 3, + nchains; + progress=false, + init_params=init_params, + ) + @test length(chains) == nchains + @test all( + chain[1].a == params.a && chain[1].b == params.b for + (chain, params) in zip(chains, init_params) + ) + + init_params = (a=randn(), b=rand()) + chains = sample( + MyModel(), + MySampler(), + MCMCThreads(), + 3, + nchains; + progress=false, + init_params=FillArrays.Fill(init_params, nchains), + ) + @test length(chains) == nchains + @test all( + chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + ) + + # Too many `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCThreads(), + 3, + nchains; + progress=false, + init_params=FillArrays.Fill(init_params, nchains + 1), + ) + + # Too few `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + 
MCMCThreads(), + 3, + nchains; + progress=false, + init_params=FillArrays.Fill(init_params, nchains - 1), + ) end @testset "Multicore sampling" begin if nworkers() == 1 warnregex = r"^Only a single process available" - @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCDistributed(), - 10, 10; chain_type = MyChain) + @test_logs (:warn, warnregex) sample( + MyModel(), MySampler(), MCMCDistributed(), 10, 10; chain_type=MyChain + ) end # Add worker processes. - addprocs() + # Memory requirements on Windows are ~4x larger than on Linux, hence number of processes is reduced + # See, e.g., https://github.com/JuliaLang/julia/issues/40766 and https://github.com/JuliaLang/Pkg.jl/pull/2366 + pids = addprocs( + Sys.iswindows() ? div(Sys.CPU_THREADS::Int, 2) : Sys.CPU_THREADS::Int + ) - # Load all required packages (`interface.jl` needs Random). + # Load all required packages (`utils.jl` needs LogDensityProblems, Logging, and Random). @everywhere begin using AbstractMCMC using AbstractMCMC: sample + using LogDensityProblems + using Logging using Random include("utils.jl") end @@ -188,12 +251,13 @@ @test all(length(x) == N for x in chains) Random.seed!(1234) - chains = sample(MyModel(), MySampler(), MCMCDistributed(), N, 1000; - chain_type = MyChain) + chains = sample( + MyModel(), MySampler(), MCMCDistributed(), N, 1000; chain_type=MyChain + ) # Test output type and size. 
@test chains isa Vector{<:MyChain} - @test all(c.as[1] === missing for c in chains) + @test all(ismissing(c.as[1]) for c in chains) @test length(chains) == 1000 @test all(x -> length(x.as) == length(x.bs) == N, chains) @@ -201,28 +265,94 @@ @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains) # Test reproducibility. Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCDistributed(), N, 1000; - chain_type = MyChain) - - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + chains2 = sample( + MyModel(), MySampler(), MCMCDistributed(), N, 1000; chain_type=MyChain + ) + @test all(ismissing(c.as[1]) for c in chains2) + @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N) + @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) # Unexpected order of arguments. str = "Number of chains (10) is greater than number of samples per chain (5)" - @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), - MCMCDistributed(), 5, 10; - chain_type = MyChain) + @test_logs (:warn, str) match_mode = :any sample( + MyModel(), MySampler(), MCMCDistributed(), 5, 10; chain_type=MyChain + ) # Suppress output. 
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), MCMCDistributed(), 10_000, 100; - progress = false, chain_type = MyChain) + sample( + MyModel(), + MySampler(), + MCMCDistributed(), + 10_000, + 100; + progress=false, + chain_type=MyChain, + ) end @test all(l.level > Logging.LogLevel(-1) for l in logs) + + # initial parameters + nchains = 100 + init_params = [(a=randn(), b=rand()) for _ in 1:nchains] + chains = sample( + MyModel(), + MySampler(), + MCMCDistributed(), + 3, + nchains; + progress=false, + init_params=init_params, + ) + @test length(chains) == nchains + @test all( + chain[1].a == params.a && chain[1].b == params.b for + (chain, params) in zip(chains, init_params) + ) + + init_params = (b=randn(), a=rand()) + chains = sample( + MyModel(), + MySampler(), + MCMCDistributed(), + 3, + nchains; + progress=false, + init_params=FillArrays.Fill(init_params, nchains), + ) + @test length(chains) == nchains + @test all( + chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + ) + + # Too many `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCDistributed(), + 3, + nchains; + progress=false, + init_params=FillArrays.Fill(init_params, nchains + 1), + ) + + # Too few `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCDistributed(), + 3, + nchains; + progress=false, + init_params=FillArrays.Fill(init_params, nchains - 1), + ) + + # Remove workers + rmprocs(pids...) end @testset "Serial sampling" begin @@ -234,12 +364,11 @@ @test all(length(x) == N for x in chains) Random.seed!(1234) - chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; - chain_type = MyChain) + chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) # Test output type and size. 
@test chains isa Vector{<:MyChain} - @test all(c.as[1] === missing for c in chains) + @test all(ismissing(c.as[1]) for c in chains) @test length(chains) == 1000 @test all(x -> length(x.as) == length(x.bs) == N, chains) @@ -247,89 +376,254 @@ @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains) # Test reproducibility. Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; - chain_type = MyChain) - - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) + @test all(ismissing(c.as[1]) for c in chains2) + @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N) + @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) # Unexpected order of arguments. str = "Number of chains (10) is greater than number of samples per chain (5)" - @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), - MCMCSerial(), 5, 10; - chain_type = MyChain) + @test_logs (:warn, str) match_mode = :any sample( + MyModel(), MySampler(), MCMCSerial(), 5, 10; chain_type=MyChain + ) # Suppress output. 
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), MCMCSerial(), 10_000, 100; - progress = false, chain_type = MyChain) + sample( + MyModel(), + MySampler(), + MCMCSerial(), + 10_000, + 100; + progress=false, + chain_type=MyChain, + ) end @test all(l.level > Logging.LogLevel(-1) for l in logs) + + # initial parameters + nchains = 100 + init_params = [(a=rand(), b=randn()) for _ in 1:nchains] + chains = sample( + MyModel(), + MySampler(), + MCMCSerial(), + 3, + nchains; + progress=false, + init_params=init_params, + ) + @test length(chains) == nchains + @test all( + chain[1].a == params.a && chain[1].b == params.b for + (chain, params) in zip(chains, init_params) + ) + + init_params = (b=rand(), a=randn()) + chains = sample( + MyModel(), + MySampler(), + MCMCSerial(), + 3, + nchains; + progress=false, + init_params=FillArrays.Fill(init_params, nchains), + ) + @test length(chains) == nchains + @test all( + chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + ) + + # Too many `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCSerial(), + 3, + nchains; + progress=false, + init_params=FillArrays.Fill(init_params, nchains + 1), + ) + + # Too few `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCSerial(), + 3, + nchains; + progress=false, + init_params=FillArrays.Fill(init_params, nchains - 1), + ) + end + + @testset "Ensemble sampling: Reproducibility" begin + N = 1_000 + nchains = 10 + + # Serial sampling + Random.seed!(1234) + chains_serial = sample( + MyModel(), + MySampler(), + MCMCSerial(), + N, + nchains; + progress=false, + chain_type=MyChain, + ) + @test all(ismissing(c.as[1]) for c in chains_serial) + + # Multi-threaded sampling + Random.seed!(1234) + chains_threads = sample( + MyModel(), + MySampler(), + MCMCThreads(), + N, + nchains; + progress=false, + chain_type=MyChain, + ) + @test all(ismissing(c.as[1]) for c in 
chains_threads) + @test all( + c1.as[i] == c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads), + i in 2:N + ) + @test all( + c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads), + i in 1:N + ) + + # Multi-core sampling + Random.seed!(1234) + chains_distributed = sample( + MyModel(), + MySampler(), + MCMCDistributed(), + N, + nchains; + progress=false, + chain_type=MyChain, + ) + @test all(ismissing(c.as[1]) for c in chains_distributed) + @test all( + c1.as[i] == c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed), + i in 2:N + ) + @test all( + c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed), + i in 1:N + ) end @testset "Chain constructors" begin - chain1 = sample(MyModel(), MySampler(), 100; sleepy = true) - chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain) + chain1 = sample(MyModel(), MySampler(), 100) + chain2 = sample(MyModel(), MySampler(), 100; chain_type=MyChain) @test chain1 isa Vector{<:MySample} @test chain2 isa MyChain end @testset "Sample stats" begin - chain = sample(MyModel(), MySampler(), 1000; chain_type = MyChain) - - @test chain.stats.stop > chain.stats.start + chain = sample(MyModel(), MySampler(), 1000; chain_type=MyChain) + + @test chain.stats.stop >= chain.stats.start @test chain.stats.duration == chain.stats.stop - chain.stats.start end @testset "Discard initial samples" begin - chain = sample(MyModel(), MySampler(), 100; sleepy = true, discard_initial = 50) - @test length(chain) == 100 + # Create a chain and discard initial samples. + Random.seed!(1234) + N = 100 + discard_initial = 50 + chain = sample(MyModel(), MySampler(), N; discard_initial=discard_initial) + @test length(chain) == N @test !ismissing(chain[1].a) + + # Repeat sampling without discarding initial samples. + # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here. 
+ # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258 + Random.seed!(1234) + ref_chain = sample( + MyModel(), MySampler(), N + discard_initial; progress=VERSION < v"1.6" + ) + @test all(chain[i].a == ref_chain[i + discard_initial].a for i in 1:N) + @test all(chain[i].b == ref_chain[i + discard_initial].b for i in 1:N) end @testset "Thin chain by a factor of `thinning`" begin # Run a thinned chain with `N` samples thinned by factor of `thinning`. - Random.seed!(1234) + Random.seed!(100) N = 100 thinning = 3 - chain = sample(MyModel(), MySampler(), N; sleepy = true, thinning = thinning) + chain = sample(MyModel(), MySampler(), N; thinning=thinning) @test length(chain) == N @test ismissing(chain[1].a) # Repeat sampling without thinning. - Random.seed!(1234) - ref_chain = sample(MyModel(), MySampler(), N * thinning; sleepy = true) - @test all(chain[i].a === ref_chain[(i - 1) * thinning + 1].a for i in 1:N) + # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here. + # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258 + Random.seed!(100) + ref_chain = sample(MyModel(), MySampler(), N * thinning; progress=VERSION < v"1.6") + @test all(chain[i].a == ref_chain[(i - 1) * thinning + 1].a for i in 2:N) + @test all(chain[i].b == ref_chain[(i - 1) * thinning + 1].b for i in 1:N) end - @testset "Sample without predetermined N" begin Random.seed!(1234) chain = sample(MyModel(), MySampler()) bmean = mean(x.b for x in chain) @test ismissing(chain[1].a) - @test abs(bmean) <= 0.001 && length(chain) < 10_000 + @test abs(bmean) <= 0.001 || length(chain) == 10_000 # Discard initial samples. 
- chain = sample(MyModel(), MySampler(); discard_initial = 50) + Random.seed!(1234) + discard_initial = 50 + chain = sample(MyModel(), MySampler(); discard_initial=discard_initial) bmean = mean(x.b for x in chain) @test !ismissing(chain[1].a) - @test abs(bmean) <= 0.001 && length(chain) < 10_000 + @test abs(bmean) <= 0.001 || length(chain) == 10_000 + + # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here. + # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258 + Random.seed!(1234) + N = length(chain) + ref_chain = sample( + MyModel(), + MySampler(), + N; + discard_initial=discard_initial, + progress=VERSION < v"1.6", + ) + @test all(chain[i].a == ref_chain[i].a for i in 1:N) + @test all(chain[i].b == ref_chain[i].b for i in 1:N) # Thin chain by a factor of `thinning`. - chain = sample(MyModel(), MySampler(); thinning = 3) + Random.seed!(1234) + thinning = 3 + chain = sample(MyModel(), MySampler(); thinning=thinning) bmean = mean(x.b for x in chain) @test ismissing(chain[1].a) - @test abs(bmean) <= 0.001 && length(chain) < 10_000 + @test abs(bmean) <= 0.001 || length(chain) == 10_000 + + # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here. 
+ # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258 + Random.seed!(1234) + N = length(chain) + ref_chain = sample( + MyModel(), MySampler(), N; thinning=thinning, progress=VERSION < v"1.6" + ) + @test all(chain[i].a == ref_chain[i].a for i in 2:N) + @test all(chain[i].b == ref_chain[i].b for i in 1:N) end @testset "Sample vector of `NamedTuple`s" begin - chain = sample(MyModel(), MySampler(), 1_000; chain_type = Vector{NamedTuple}) + chain = sample(MyModel(), MySampler(), 1_000; chain_type=Vector{NamedTuple}) # Check output type @test chain isa Vector{<:NamedTuple} @test length(chain) == 1_000 @@ -337,15 +631,17 @@ # Check some statistical properties @test ismissing(chain[1].a) - @test mean(x.a for x in view(chain, 2:1_000)) ≈ 0.5 atol=6e-2 - @test var(x.a for x in view(chain, 2:1_000)) ≈ 1 / 12 atol=1e-2 - @test mean(x.b for x in chain) ≈ 0 atol=0.1 - @test var(x.b for x in chain) ≈ 1 atol=0.15 + @test mean(x.a for x in view(chain, 2:1_000)) ≈ 0.5 atol = 6e-2 + @test var(x.a for x in view(chain, 2:1_000)) ≈ 1 / 12 atol = 1e-2 + @test mean(x.b for x in chain) ≈ 0 atol = 0.11 + @test var(x.b for x in chain) ≈ 1 atol = 0.15 end - + @testset "Testing callbacks" begin - function count_iterations(rng, model, sampler, sample, state, i; iter_array, kwargs...) - push!(iter_array, i) + function count_iterations( + rng, model, sampler, sample, state, i; iter_array, kwargs... 
+ ) + return push!(iter_array, i) end N = 100 it_array = Float64[] @@ -354,7 +650,9 @@ # sampling without predetermined N it_array = Float64[] - chain = sample(MyModel(), MySampler(); callback=count_iterations, iter_array=it_array) + chain = sample( + MyModel(), MySampler(); callback=count_iterations, iter_array=it_array + ) @test it_array == collect(1:size(chain, 1)) end end diff --git a/test/stepper.jl b/test/stepper.jl index f3a4b599..80143344 100644 --- a/test/stepper.jl +++ b/test/stepper.jl @@ -5,7 +5,7 @@ bs = [] iter = AbstractMCMC.steps(MyModel(), MySampler()) - iter = AbstractMCMC.steps(MyModel(), MySampler(); a = 1.0) # `a` shouldn't do anything + iter = AbstractMCMC.steps(MyModel(), MySampler(); a=1.0) # `a` shouldn't do anything for (count, t) in enumerate(iter) if count >= 1000 @@ -21,12 +21,55 @@ @test length(as) == length(bs) == 998 - @test mean(as) ≈ 0.5 atol=1e-2 - @test var(as) ≈ 1 / 12 atol=5e-3 - @test mean(bs) ≈ 0.0 atol=5e-2 - @test var(bs) ≈ 1 atol=5e-2 + @test mean(as) ≈ 0.5 atol = 2e-2 + @test var(as) ≈ 1 / 12 atol = 5e-3 + @test mean(bs) ≈ 0.0 atol = 5e-2 + @test var(bs) ≈ 1 atol = 5e-2 @test Base.IteratorSize(iter) == Base.IsInfinite() @test Base.IteratorEltype(iter) == Base.EltypeUnknown() end + + @testset "Discard initial samples" begin + # Create a chain of `N` samples after discarding some initial samples. + Random.seed!(1234) + N = 50 + discard_initial = 10 + iter = AbstractMCMC.steps(MyModel(), MySampler(); discard_initial=discard_initial) + as = [] + bs = [] + for t in Iterators.take(iter, N) + push!(as, t.a) + push!(bs, t.b) + end + + # Repeat sampling with `sample`. + Random.seed!(1234) + chain = sample( + MyModel(), MySampler(), N; discard_initial=discard_initial, progress=false + ) + @test all(as[i] == chain[i].a for i in 1:N) + @test all(bs[i] == chain[i].b for i in 1:N) + end + + @testset "Thin chain by a factor of `thinning`" begin + # Create a thinned chain with a thinning factor of `thinning`. 
+ Random.seed!(1234) + N = 50 + thinning = 3 + iter = AbstractMCMC.steps(MyModel(), MySampler(); thinning=thinning) + as = [] + bs = [] + for t in Iterators.take(iter, N) + push!(as, t.a) + push!(bs, t.b) + end + + # Repeat sampling with `sample`. + Random.seed!(1234) + chain = sample(MyModel(), MySampler(), N; thinning=thinning, progress=false) + @test as[1] === chain[1].a === missing + @test all(as[i] == chain[i].a for i in 2:N) + @test all(bs[i] == chain[i].b for i in 1:N) + end end diff --git a/test/transducer.jl b/test/transducer.jl index 2b363e27..b161151c 100644 --- a/test/transducer.jl +++ b/test/transducer.jl @@ -5,9 +5,8 @@ N = 1_000 local chain Logging.with_logger(TerminalLogger()) do - xf = AbstractMCMC.Sample(MyModel(), MySampler(); - sleepy = true, logger = true) - chain = withprogress(1:N; interval=1e-3) |> xf |> collect + xf = AbstractMCMC.Sample(MyModel(), MySampler(); sleepy=true, logger=true) + chain = collect(xf(withprogress(1:N; interval=1e-3))) end # test output type and size @@ -16,15 +15,15 @@ # test some statistical properties tail_chain = @view chain[2:end] - @test mean(x.a for x in tail_chain) ≈ 0.5 atol=6e-2 - @test var(x.a for x in tail_chain) ≈ 1 / 12 atol=5e-3 - @test mean(x.b for x in tail_chain) ≈ 0.0 atol=5e-2 - @test var(x.b for x in tail_chain) ≈ 1 atol=6e-2 + @test mean(x.a for x in tail_chain) ≈ 0.5 atol = 6e-2 + @test var(x.a for x in tail_chain) ≈ 1 / 12 atol = 5e-3 + @test mean(x.b for x in tail_chain) ≈ 0.0 atol = 5e-2 + @test var(x.b for x in tail_chain) ≈ 1 atol = 6e-2 end @testset "drop" begin xf = AbstractMCMC.Sample(MyModel(), MySampler()) - chain = 1:10 |> xf |> Drop(1) |> collect + chain = collect(Drop(1)(xf(1:10))) @test chain isa Vector{MySample{Float64,Float64}} @test length(chain) == 9 end @@ -37,7 +36,7 @@ OfType(MySample{Float64,Float64}), Map(x -> (x.a, x.b)), ) - as, bs = foldl(xf, 1:999; init = (Float64[], Float64[])) do (as, bs), (a, b) + as, bs = foldl(xf, 1:999; init=(Float64[], Float64[])) do (as, bs), 
(a, b) push!(as, a) push!(bs, b) as, bs @@ -45,9 +44,56 @@ @test length(as) == length(bs) == 998 - @test mean(as) ≈ 0.5 atol=1e-2 - @test var(as) ≈ 1 / 12 atol=5e-3 - @test mean(bs) ≈ 0.0 atol=5e-2 - @test var(bs) ≈ 1 atol=5e-2 + @test mean(as) ≈ 0.5 atol = 2e-2 + @test var(as) ≈ 1 / 12 atol = 5e-3 + @test mean(bs) ≈ 0.0 atol = 5e-2 + @test var(bs) ≈ 1 atol = 5e-2 + end + + @testset "Discard initial samples" begin + # Create a chain of `N` samples after discarding some initial samples. + Random.seed!(1234) + N = 50 + discard_initial = 10 + xf = opcompose( + AbstractMCMC.Sample(MyModel(), MySampler(); discard_initial=discard_initial), + Map(x -> (x.a, x.b)), + ) + as, bs = foldl(xf, 1:N; init=([], [])) do (as, bs), (a, b) + push!(as, a) + push!(bs, b) + as, bs + end + + # Repeat sampling with `sample`. + Random.seed!(1234) + chain = sample( + MyModel(), MySampler(), N; discard_initial=discard_initial, progress=false + ) + @test all(as[i] == chain[i].a for i in 1:N) + @test all(bs[i] == chain[i].b for i in 1:N) + end + + @testset "Thin chain by a factor of `thinning`" begin + # Create a thinned chain with a thinning factor of `thinning`. + Random.seed!(1234) + N = 50 + thinning = 3 + xf = opcompose( + AbstractMCMC.Sample(MyModel(), MySampler(); thinning=thinning), + Map(x -> (x.a, x.b)), + ) + as, bs = foldl(xf, 1:N; init=([], [])) do (as, bs), (a, b) + push!(as, a) + push!(bs, b) + as, bs + end + + # Repeat sampling with `sample`. 
+ Random.seed!(1234) + chain = sample(MyModel(), MySampler(), N; thinning=thinning, progress=false) + @test as[1] === chain[1].a === missing + @test all(as[i] == chain[i].a for i in 2:N) + @test all(bs[i] == chain[i].b for i in 1:N) end end diff --git a/test/utils.jl b/test/utils.jl index f6ac9d27..f69fcdab 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -20,17 +20,19 @@ function AbstractMCMC.step( rng::AbstractRNG, model::MyModel, sampler::MySampler, - state::Union{Nothing,Integer} = nothing; - sleepy = false, - loggers = false, - kwargs... + state::Union{Nothing,Integer}=nothing; + loggers=false, + init_params=nothing, + kwargs..., ) - # sample `a` is missing in the first step - a = state === nothing ? missing : rand(rng) - b = randn(rng) + # sample `a` is missing in the first step if not provided + a, b = if state === nothing && init_params !== nothing + init_params.a, init_params.b + else + (state === nothing ? missing : rand(rng)), randn(rng) + end loggers && push!(LOGGERS, Logging.current_logger()) - sleepy && sleep(0.001) _state = state === nothing ? 1 : state + 1 @@ -43,8 +45,8 @@ function AbstractMCMC.bundle_samples( sampler::MySampler, ::Any, ::Type{MyChain}; - stats = nothing, - kwargs... + stats=nothing, + kwargs..., ) as = [t.a for t in samples] bs = [t.b for t in samples] @@ -59,24 +61,55 @@ function isdone( samples, state, iteration::Int; - kwargs... + kwargs..., ) # Calculate the mean of x.b. bmean = mean(x.b for x in samples) - return abs(bmean) <= 0.001 || iteration >= 10_000 || state >= 10_000 + return abs(bmean) <= 0.001 || iteration > 10_000 end # Set a default convergence function. function AbstractMCMC.sample(model, sampler::MySampler; kwargs...) - return sample(Random.GLOBAL_RNG, model, sampler, isdone; kwargs...) + return sample(Random.default_rng(), model, sampler, isdone; kwargs...) end function AbstractMCMC.chainscat( - chain::Union{MyChain,Vector{<:MyChain}}, - chains::Union{MyChain,Vector{<:MyChain}}... 
+ chain::Union{MyChain,Vector{<:MyChain}}, chains::Union{MyChain,Vector{<:MyChain}}... ) return vcat(chain, chains...) end # Conversion to NamedTuple -Base.convert(::Type{NamedTuple}, x::MySample) = (a = x.a, b = x.b) +Base.convert(::Type{NamedTuple}, x::MySample) = (a=x.a, b=x.b) + +# Gaussian log density (without additive constants) +# Without LogDensityProblems.jl interface +mylogdensity(x) = -sum(abs2, x) / 2 + +# With LogDensityProblems.jl interface +struct MyLogDensity + dim::Int +end +LogDensityProblems.logdensity(::MyLogDensity, x) = mylogdensity(x) +LogDensityProblems.dimension(m::MyLogDensity) = m.dim +function LogDensityProblems.capabilities(::Type{MyLogDensity}) + return LogDensityProblems.LogDensityOrder{0}() +end + +# Define "sampling" +function AbstractMCMC.step( + rng::AbstractRNG, + model::AbstractMCMC.LogDensityModel{MyLogDensity}, + ::MySampler, + state::Union{Nothing,Integer}=nothing; + kwargs..., +) + # Sample from multivariate normal distribution + ℓ = model.logdensity + dim = LogDensityProblems.dimension(ℓ) + θ = randn(rng, dim) + logdensity_θ = LogDensityProblems.logdensity(ℓ, θ) + + _state = state === nothing ? 1 : state + 1 + return MySample(θ, logdensity_θ), _state +end