Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace old Gibbs sampler with the experimental one. #2328

Merged
merged 81 commits into from
Dec 19, 2024

Conversation

mhauru
Copy link
Member

@mhauru mhauru commented Sep 23, 2024

Closes #2318.

Work in progress.

Copy link

codecov bot commented Sep 23, 2024

Codecov Report

Attention: Patch coverage is 87.45098% with 32 lines in your changes missing coverage. Please review.

Project coverage is 85.39%. Comparing base (2707d12) to head (96f8dd4).
Report is 1 commits behind head on master.

Files with missing lines Patch % Lines
src/mcmc/gibbs.jl 87.67% 26 Missing ⚠️
src/mcmc/repeat_sampler.jl 80.00% 4 Missing ⚠️
src/mcmc/Inference.jl 75.00% 1 Missing ⚠️
src/mcmc/particle_mcmc.jl 66.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2328      +/-   ##
==========================================
- Coverage   86.30%   85.39%   -0.92%     
==========================================
  Files          22       21       -1     
  Lines        1577     1588      +11     
==========================================
- Hits         1361     1356       -5     
- Misses        216      232      +16     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@coveralls
Copy link

coveralls commented Sep 23, 2024

Pull Request Test Coverage Report for Build 12400670488

Details

  • 223 of 255 (87.45%) changed or added relevant lines in 11 files are covered.
  • 13 unchanged lines in 3 files lost coverage.
  • Overall coverage increased (+6.9%) to 85.39%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/mcmc/Inference.jl 3 4 75.0%
src/mcmc/particle_mcmc.jl 2 3 66.67%
src/mcmc/repeat_sampler.jl 16 20 80.0%
src/mcmc/gibbs.jl 185 211 87.68%
Files with Coverage Reduction New Missed Lines %
src/mcmc/Inference.jl 1 86.39%
src/mcmc/ess.jl 1 94.64%
src/mcmc/particle_mcmc.jl 11 86.75%
Totals Coverage Status
Change from base Build 12397554649: 6.9%
Covered Lines: 1356
Relevant Lines: 1588

💛 - Coveralls

HISTORY.md Outdated Show resolved Hide resolved
@mhauru
Copy link
Member Author

mhauru commented Sep 26, 2024

@torfjelde, if you have a moment to take a look at the one remaining test failure, would be interested in your thoughts. We are sampling for a model with two vector variables, m and z, and we seem to somehow end up with a case where there's a VarInfo with only z in it, but the sampler is looking for m too. I wonder if it's something about the interaction between particle sampling with Libtask and how the new Gibbs does things with the local varinfos. The test that fails is this one:

    @testset "dynamic model" begin
        @model function imm(y, alpha, ::Type{M}=Vector{Float64}) where {M}
            N = length(y)
            rpm = DirichletProcess(alpha)

            z = zeros(Int, N)
            cluster_counts = zeros(Int, N)
            fill!(cluster_counts, 0)

            for i in 1:N
                z[i] ~ ChineseRestaurantProcess(rpm, cluster_counts)
                cluster_counts[z[i]] += 1
            end

            Kmax = findlast(!iszero, cluster_counts)
            m = M(undef, Kmax)
            for k in 1:Kmax
                m[k] ~ Normal(1.0, 1.0)
            end
        end
        model = imm(Random.randn(100), 1.0)
        # https://github.com/TuringLang/Turing.jl/issues/1725
        # sample(model, Gibbs(MH(:z), HMC(0.01, 4, :m)), 100);
        sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100)
    end

@torfjelde
Copy link
Member

Will have a look at this in a bit @mhauru (just need to do some grocery shopping 😬 )

@mhauru
Copy link
Member Author

mhauru commented Sep 26, 2024

Collecting links to old relevant PRs so I don't have to look for them again: #2231, #2099

@torfjelde
Copy link
Member

Think I found the error: if the number of m increases, say, from length(m) = 2 to length(m) = 3 during the PG step, then the lines

if has_conditioned_gibbs(context, vn)
value = get_conditioned_gibbs(context, vn)
return value, logpdf(right, value), vi
end
# Otherwise, falls back to the default behavior.
return DynamicPPL.tilde_assume(
rng, DynamicPPL.childcontext(context), sampler, right, vn, vi
)

doesn't hit the Gibbs branch since @varname(m[3]) is not present in the GibbsContext 😕

@torfjelde
Copy link
Member

doesn't hit the Gibbs branch since @varname(m[3]) is not present in the GibbsContext

I'm a bit uncertain how we should best handle this @yebai @mhauru

The first partially viable idea that comes to mind is to subset the varinfo to make sure that it only contains the correct variables. If we do this, then m[3] will just be "ignored" (in the varinfos) until we're actually sampling the m variables, in which case it would be captured correctly.

But this would not quite be equivalent to the current implementation of Gibbs, which, AFAIK, keeps the very first occurence of m around rather than resampling everytime. And naively, I would expect this to be incorrect.

Another way is to explicitly add the varinfos to the GibbsContext itself, and then, when we encounter a value that should in fact go into a different varinfo, we add it there. But this has a few issues:

  1. Requires the VarInfo to be mutable.
  2. Requires the VarInfo to have a container that can keep the new incoming value m[3].
  3. Implementation of Gibbs does end up being more complicated than the current approach. However, it might be worth it.

Thoughts?

@yebai
Copy link
Member

yebai commented Sep 27, 2024

Another way is to explicitly add the varinfos to the GibbsContext itself, and then, when we encounter a value that should in fact go into a different varinfo, we add it there. But this has a few issues:

Requires the VarInfo to be mutable.
Requires the VarInfo to have a container that can keep the new incoming value m[3].
Implementation of Gibbs does end up being more complicated than the current approach. However, it might be worth it.

I lean towards the above approach and (maybe later) provide explicit APIs to inference algorithms. This will enable us to handle reversible jumps (varying model dimensions) in MCMC more flexibly. At the moment, this is only possible in particle Gibbs; if it happens in HMC/MH, inference will likely fail (silently)

EDIT: we can keep VarInfos immutable by default, and requires inference developers to hook into specific APIs to mutate VarInfos.

@torfjelde
Copy link
Member

This does however complicate the new Gibbs sampling procedure quite drastically 😕

And it makes me bring up a question I really didn't think I'd be asking: is it then actually preferable to the current Gibbs with keeping it all in a single VarInfo with a flag to specify whether it should be sampled or not? 😬

I guess we should first have a go at implementing this for the new Gibbs and then we can see 👍

Another point to add to the conversation that @mhauru brought to my attention the other day: we also want to support stuff like Gibbs(@varname(m) => NUTS(), @varname(m) => HMC()), i.e. multiple samplers targeting the same variables. This adds a few "complications" (beyond addressing the growing model problem discussed above):

  1. Need to determine which varinfo to pick from varinfos based on the varnames present / targeted.
  2. A naive implementation will result in duplicated entries in varinfos. We can however address this if we really feel like it's worth it, so probably a non-issue atm.

So all in all, immediate things we need to address with Gibbs:

  1. Support changing dimensions.
  2. Support picking a varinfo to condition on based on the varnames present rather than based on ===.

@mhauru
Copy link
Member Author

mhauru commented Oct 10, 2024

I've been trying to think of a way to fix this, that would also fix the problem where different Gibbs subsamplers can't sample the same variables (e.g. you can't first sample x and y using one sampler, and then y and z with a different one). My best thought at the moment is the following design:

  1. There is only one, global VarInfo, call it vi.
  2. make_conditional takes that vi and a list of VarNames that the current subsampler samples. It hijacks the tilde pipeline to condition all other variables to their current values in vi.
  3. vi may have some variables linked, some not.
  4. Every time we call a subsampler we can hand it vi as the VarInfo. It won’t mess with any of the variables it’s not supposed to touch, because the tilde pipeline hijack from point 2.

Point 3. is maybe undesirable, but I think it’s minor compared to all the Selector/gibbsid stuff, which we would still get rid of.

The only problem I see with this is combining the local state from the previous iteration of the current subsampler with the global vi. Somehow we would need to join up-to-date information from the global vi with state-information from the previous iteration, specific to this subsampler. The right way to do this depends on the state, which is a different type of object for different subsamplers. EDIT: Actually, maybe this is okay, because we seem to already assume that every state object has a field called state.vi , we could just reset that.

The great benefit of sticking to one, global VarInfo is never having to worry about moving data between the local VarInfos. That would have to happen in both cases, when a new variable is introduced by one sampler (the failing test in this PR) and when two samplers sample the same variable. It sounds like a pain to implement.

@mhauru
Copy link
Member Author

mhauru commented Oct 10, 2024

I can imagine two different philosophies to implementing a Gibbs sampler:

  1. Every subsampler is doing its own sampling process on a low-dimensional model (a conditioned version of the full model), independent of the others. The logprobability function it's sampling from just keeps changing between iterations, because the other variables change and thus the conditioned model changes, but otherwise it's blind to the existence of the variables it isn't sampling. This is what the new Gibbs sampler does.
  2. Every subsampler is working with the same, full model, with all the variables, but only makes the changes to a subset of those variables. It still "sees" the whole model. This is what the old Gibbs sampler did.

My above proposal would essentially be doing 2., but using code that's very much like the new sampler, where the information about which sampler modifies which variables is in the sampler/GibbsContext, and not in VarInfo like it was in the old Gibbs.

The reason I'm leaning towards 2. is that 1. seems to run to some fundamental issues in cases where either

  • Variables appear and disappear based on values of other variables,
  • Two samplers want to modify the value of the same variable.

Both of those situations quite deeply violate the idea that the different subsamplers can operate mostly independently of each other.

Any thoughts very welcome, I'm still very much trying to understand the landscape of the problem.

@yebai
Copy link
Member

yebai commented Oct 10, 2024

Thanks, @mhauru, for the excellent summary of the problem and proposals. Storing conditioned variables in a context, like GibbsContext as you suggested, is very sensible. The consequence is that VarInfo and Context will have overlapped model parameters, e.g. conditioned variables will be found in both VarInfo and Context, which is fine.

In addition, it's worth mentioning that we currently have two mechanisms for passing observations to models, i.e.

(1) via model arguments, e.g. gdemo(x, y).
(2) via condition API, e.g. condition(model, (x=1,y=2)).

Among these options, (1) will hardcode observation information directly in the model while (2) stores them in a context. You could look at the DynamicPPL codebase for a more detailed picture of how it works. We want to unify these options, perhaps towards using (2) only.

This Gibbs refactoring could be an excellent starting point for a design_notes repo to record these thoughts and discussions.

@torfjelde
Copy link
Member

Every subsampler is working with the same, full model, with all the variables, but only makes the changes to a subset of those variables. It still "sees" the whole model. This is what the old Gibbs sampler did.

Overall, I'm also in favour of this @mhauru 👍 I think your reasoning is solid here.

The only other "option" I'm seeing is to keep track of which variables correpond to which varinfos (with each varinfo only containing the relevant information), but then we're effectively just re-implementing a lot of the functionality that is already provided in varinfo 😕

The only "issue" is that this does mean we have to support this "link / transform only part of the varinfo, which does mean we need something "equivalent" to all the getindex(varinfo, sampler) stuff that we've been trying to move away from (since we need a way to extract the vectorized part relevant only for the specific sampler we're going to use in that particular step) 😕

Doulby however, I think we can make this much nicer than the current approach by simply making all these getindex(varinfo, sampler) instead take the relevant varnames instead of the samplers themselves, which should make it all less painful.

But yeah, don't see how we can take approach (1) in a "nice" way, and so I'm also in favour of just trying to make (2) as painless as possible to maintain.

@mhauru
Copy link
Member Author

mhauru commented Oct 11, 2024

Thanks for the comments both, this is very helpful.

Doulby however, I think we can make this much nicer than the current approach by simply making all these getindex(varinfo, sampler) instead take the relevant varnames instead of the samplers themselves, which should make it all less painful.

Yeah, I think this is the way to go.

@mhauru
Copy link
Member Author

mhauru commented Nov 29, 2024

I'm done making the changes I had in mind. I may still experiment with some performance improvements, but not sure if any will make it in here. I'll also try to reduce the iteration counts in some tests to make them faster, the only CI failure is because one job just timed out at 6h.

Since both Tor and I seem to be happy, I'm gonna ping others in case they want to take a look: @penelopeysm, @willtebbutt, @sunxd3, @yebai. I think we can rely on @torfjelde giving an expert review, everyone else can judge for themselves how thorough a look they want to take, but I think everyone should be at least aware that this, somewhat major, change is happening. If you want to give this PR a review but haven't yet had time, self-request a review and we'll make sure to wait before merging.

For help in reviewing: This PR does a few things:

  1. Deletes the old src/mcmc/gibbs.jl, and the related src/mcmc/gibbs_conditional.jl.
  2. Moves src/experimental/gibbs.jl to be the new src/mcmc/gibbs.jl, and merges test/experimental/gibbs.jl and test/mcmc/gibbs.jl.
  3. Makes a lot of edits to the experimental/new Gibbs to accommodate dynamic models and some other things.
  4. Adds more, new tests to test/mcmc/gibbs.jl.
  5. Introduces RepeatSampler and its tests. This has to be done in the same PR because the old Gibbs had repeat functionality built-in, whereas the new Gibbs doesn't.
  6. Makes a bunch of small changes to various samplers to accommodate the new Gibbs.

Points 4-6 one can reviewed like usual, as a diff of a few hundred lines. Points 2-3 I think are better viewed as a new Gibbs sampler from scratch. The changes in point 3 are so extensive that reading it as a diff doesn't make much sense unless you know the old code really well.

@penelopeysm
Copy link
Member

I'm happy to take a look next week, but doubt I'll get to it today as my head is already several layers deep in DynamicPPL stuff 😄

@mhauru
Copy link
Member Author

mhauru commented Nov 29, 2024

I managed to decrease the iteration counts on a lot of the heaviest tests, the total runtime should be reduced substantially now. They seem to still pass somewhat robustly, i.e. I tried at least two random seeds.

Also did some quick checks of performance overheads, and the previous large overheads are gone in my example cases. Now, rather than being e.g. 100-500% slower than the old Gibbs we are more like 0-50% slower. This for models dominated by overheads from outside model evaluation, i.e. fast models where performance is not a big deal.

@mhauru
Copy link
Member Author

mhauru commented Dec 2, 2024

The Mooncake stack overflows are something @willtebbutt is aware of and knows the reason for, so we can ignore them for now. Would still hold off from merging until they are fixed.

@yebai
Copy link
Member

yebai commented Dec 16, 2024

@penelopeysm, can you help resolve the merge conflicts so we can try to merge this before the new year?

@penelopeysm
Copy link
Member

@yebai Sure! Are we happy otherwise with the PR, i.e. if conflicts are fixed and CI passes we can merge?

@yebai
Copy link
Member

yebai commented Dec 16, 2024

I think so.

@penelopeysm penelopeysm self-assigned this Dec 17, 2024
@penelopeysm
Copy link
Member

penelopeysm commented Dec 19, 2024

@yebai CI pretty much passes fine, apart from:

  • Some numerical tests fail on x86 by a fairly small amount. I can't quite tell why – as far as I can tell, everything has been seeded correctly.

One of them is in the Gibbs tests, on the dynamic Chinese restaurant process model. This test is slightly dubious anyway imo

# The below are regression tests. The values we are comparing against are from
# running the above model on the "old" Gibbs sampler that was in place still on
# 2024-11-20. The model was run 5 times with 10_000 samples each time. The values
# to compare to are the mean of those 5 runs, atol is roughly estimated from the
# standard deviation of those 5 runs.
# TODO(mhauru) Could we do something smarter here? Maybe a dynamic model for which
# the posterior is analytically known? Doing 10_000 samples to run the test suite
# is not ideal
# Issue ref: https://github.com/TuringLang/Turing.jl/issues/2402
@test isapprox(mean(num_ms), 8.6087; atol=0.8)
@test isapprox(std(num_ms), 1.8865; atol=0.02)

dynamic model: Test Failed at /home/runner/work/Turing.jl/Turing.jl/test/mcmc/gibbs.jl:484
  Expression: isapprox(mean(num_ms), 8.6087; atol = 0.8)
   Evaluated: isapprox(9.8377, 8.6087; atol = 0.8)

The other one is in ESS:

MoGtest_default with CSMC + ESS: Test Failed at /home/runner/work/Turing.jl/Turing.jl/test/test_utils/numerical_tests.jl:55
  Expression: ≈(E, val, atol = atol, rtol = rtol)
   Evaluated: 3.88348278598424 ≈ 4.0 (atol=0.1, rtol=0.0)
  • There's some weird behaviour in that Gibbs test suite runs much, much slower on 1.11 than 1.10. It doesn't affect the outcome though.

Personally I don't think that either of these are serious enough to prevent us from merging this PR. I reckon that both should be tracked via new issues. If you agree, feel free to hit the button 😄

@yebai
Copy link
Member

yebai commented Dec 19, 2024

There's some weird behaviour in that Gibbs test suite runs much, much slower on 1.11 than 1.10. It doesn't affect the outcome though.

This is likely a Libtask issue on Julia 1.11. Hopefully, we will resolve this in #2427. cc @willtebbutt

EDIT: it is slightly odd that Gibbs runs faster on the master branch for Julia 1.11 branch before this PR.

@penelopeysm can you open issues to track the other minor numerical issues on X86? This is likely due to an insufficient number of MCMC iterations.

@yebai yebai merged commit 9e5467a into master Dec 19, 2024
59 of 62 checks passed
@yebai yebai deleted the mhauru/change-gibbs-sampler branch December 19, 2024 11:02
@yebai
Copy link
Member

yebai commented Dec 19, 2024

Many thanks to @mhauru, @torfjelde, @penelopeysm, and all who helped!

@sunxd3
Copy link
Member

sunxd3 commented Dec 19, 2024

🎉🎉

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Remove old Gibbs sampler, make the experimental one the default
6 participants