Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow specification of initial state for sample #119

Merged
merged 17 commits into from
Oct 24, 2023

Conversation

torfjelde
Copy link
Member

@torfjelde torfjelde commented Mar 13, 2023

This seems convenient to have, e.g. for resuming sampling, running special warm-up procedures.

EDIT: Note that this is now dependent on #126

@codecov
Copy link

codecov bot commented Mar 13, 2023

Codecov Report

Attention: 2 lines in your changes are missing coverage. Please review.

Comparison is base (4dbcb3f) 97.37% compared to head (3ed5314) 96.87%.
Report is 4 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #119      +/-   ##
==========================================
- Coverage   97.37%   96.87%   -0.51%     
==========================================
  Files           8        8              
  Lines         305      320      +15     
==========================================
+ Hits          297      310      +13     
- Misses          8       10       +2     
Files Coverage Δ
src/AbstractMCMC.jl 100.00% <ø> (ø)
src/sample.jl 95.87% <93.10%> (-0.78%) ⬇️

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@torfjelde torfjelde requested a review from devmotion March 13, 2023 22:33
src/sample.jl Outdated Show resolved Hide resolved
Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Can you add tests of all new lines?
  2. The names init_params and initial_state seem inconsistent, can we either use init or initial in both cases?

src/sample.jl Outdated
Comment on lines 492 to 494
chains = Distributed.pmap(
sample_chain, pool, seeds, _init_params, _initial_state
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work? I don't think pmap broadcasts its arguments?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't have to, right? _init_params will either be fill(nothing, nchains) or it will be _initial_state which should also be a vector of the correct length (I should add a check for this though)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yeah, you created these arrays. The motivation for the branch here was to avoid allocating such arrays.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, am aware 👍 But the branching becomes a bit more annoying if we also have to do this for init_state, so I figured just allocating was a better. I can make it a conditional if you prefer 👍

@torfjelde
Copy link
Member Author

Can you add tests of all new lines?

Will do (but tomorrow; am about to sleep)!

The names init_params and initial_state seem inconsistent, can we either use init or initial in both cases?

Agreed. Hmm, I guess I'll make it init_state then? Personally prefer the initial_* but given that one already exists, I guess it makes sense to add this (though this init_params isn't even offiically supported, right?)

@devmotion
Copy link
Member

though this init_params isn't even offiically supported, right?

It's officially supported: It's clearly documented and used in downstream packages such as EllipticalSliceSampling, DynamicPPL, and Turing.

@torfjelde
Copy link
Member Author

It's officially supported: It's clearly documented and used in downstream packages such as EllipticalSliceSampling, DynamicPPL, and Turing.

Yes, that I'm fully aware of! Just remembered seeing the following in the docs:

There is no "official" way for providing initial parameter values yet. However, multiple packages such as EllipticalSliceSampling.jl and AdvancedMH.jl support an init_params keyword argument for setting the initial values when sampling a single chain. To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, we decided to support init_params in the default implementations of the ensemble methods:

Which I took to mean "we haven't really made an explicit decision on how to support initial parameters, but because so many downstream packages use init_params, we stay compatible with it".

But nonetheless, I'll change it to init_state then 👍

@devmotion
Copy link
Member

Generally, I think I'd prefer initial but yeah, it involves more changes 🤷 I'm also not a big fan of the name params, maybe initial_sample would be better?

Since #120 requires a breaking release anyway, we could also include more breaking changes. Otherwise we could deprecate the keyword argument, maybe something like the following could work:

function f(...; init_params=nothing, initial_params=init_params)
    if init_params !== nothing
        if initial_params !== init_params
            throw(ArgumentError("..."))
        end
        Base.deprecate("....", f)
    end
    ...
end

@torfjelde
Copy link
Member Author

torfjelde commented Sep 1, 2023

@devmotion I just came across this again; given that we decided to scrap #120 in the end, what do you think of at least merging this? Being able to specify the initial state would be quite useful.

EDIT: Just remember you're on vacation! No need to rush this:)

test/sample.jl Outdated Show resolved Hide resolved
test/sample.jl Outdated Show resolved Hide resolved
test/sample.jl Outdated Show resolved Hide resolved
@torfjelde
Copy link
Member Author

I decided to just rip the band-aid off and go with renaming init_params to initial_params right away.

If we don't, we end up in a somewhat awkward scenario where we have to pass along both init_params and initial_params, which might also just break downstream step.

@devmotion
Copy link
Member

Can we merge #126 first and then rebase this PR on it? Or maybe even directly rebase the PR on #126 and adjust the base branch of the PR?

@torfjelde
Copy link
Member Author

Most certainly 👍

@torfjelde torfjelde force-pushed the torfjelde/initial-state branch from bf70831 to 00054ef Compare October 2, 2023 10:40
@torfjelde torfjelde changed the base branch from master to torfjelde/init-params-fix October 2, 2023 10:41
@torfjelde
Copy link
Member Author

Rebased and change base for PR.

I'll add some tests for the intial_state stuff sometime today 👍

@torfjelde torfjelde changed the base branch from torfjelde/init-params-fix to master October 2, 2023 14:18
@torfjelde torfjelde force-pushed the torfjelde/initial-state branch from 00054ef to ca4f4b9 Compare October 2, 2023 14:23
src/sample.jl Outdated Show resolved Hide resolved
src/sample.jl Outdated Show resolved Hide resolved
test/sample.jl Outdated Show resolved Hide resolved
test/sample.jl Outdated Show resolved Hide resolved
test/sample.jl Outdated Show resolved Hide resolved
test/sample.jl Outdated Show resolved Hide resolved
test/sample.jl Outdated Show resolved Hide resolved
test/sample.jl Outdated Show resolved Hide resolved
test/sample.jl Outdated Show resolved Hide resolved
test/sample.jl Outdated Show resolved Hide resolved
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@devmotion
Copy link
Member

Seems the PR breaks EllipticalSliceSampling? 🤔 Interesting, even though it won't cause any breakage in practice if we tag a breaking release.

@torfjelde
Copy link
Member Author

Could it maybe be something related to init_params -> initial_params?

But yes, I'll bump the major version so won't break anything in practice:)

@devmotion
Copy link
Member

Yeah, I assume that's the reason: https://github.com/TuringLang/EllipticalSliceSampling.jl/blob/ca4babb2baba9008805bc8234a6fd182119e57dc/src/abstractmcmc.jl#L25 https://github.com/TuringLang/EllipticalSliceSampling.jl/blob/ca4babb2baba9008805bc8234a6fd182119e57dc/test/simple.jl#L40 I wonder though why other packages that currently use init_params (AdvancedMH IIRC? Turing?) are not affected by this change. Maybe they are missing tests for init_params/initial_params?

@@ -9,6 +9,7 @@ version = "4.5.0"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means it could be removed from the [extras] section below. But do we actually have to depend on FillArrays? I think we should just forward the user-input or nothing, but not build any arrays explicitly?

Comment on lines +435 to +438
_initial_params =
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
_initial_state =
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be nice to avoid these FillArrays. Maybe we could

  • move the function below to a callable struct (could be shared between all ensemble algorithms maybe?)
  • pass initial_params and initial_state to the constructor as well but only use it to define type parameters that allow us to distinguish between the four possible cases (no initial params and no initial state, only initial state, only initial params, and both initial params and state)
  • define the function of the callable struct depending on the type parameters, forwarding the versions with only the seed or only the seed and one additional argument to the three-argument version

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is adding in callable structs here really an improvement? 😕 I agree it's more efficient, but it seems like this will be quite a bit more complex + the efficiency doesn't really matter here, right?

Copy link
Member

@devmotion devmotion Oct 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A callable struct should generally be better for the compiler than a closure, shouldn't it? Regardless of whether we change or add arguments as in this PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep! But is this performance critical code? And it seems to be me that we'll need a callable struct for each scenario?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤷 Are you sure? To me it seems one struct is sufficient - both the multithreaded and the multicore version seem to use the same inner structure, and in the serial case we could set channel = nothing. If needed we could also dispatch on the type of the algorithm to handle minor differences in the function call.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me it seems one struct is sufficient

To clarify, I don't question whether we can have a single callable struct with different call implementations; I meant more that it seems it won't be as simple as just doing

struct SampleFunc
    # ...
end

function (f::SampleFunc)(args...)
    # ...
end

multithreaded and the multicore version seem to use the same inner structure

But if we put initial_params and initial_state in the callable struct, then we'll need to pmap, etc. over a range containing the corresponding indices, no? Which seems like it would lead to more allocations than the current impl using Fill(nothing, nchains)?

Or am I misunderstanding what you mean here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just bumping this discussion:) Would be nice to get a version of this PR merged.

@torfjelde
Copy link
Member Author

Maybe they are missing tests for init_params/initial_params?

Very likely 😕

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's go with the FillArrays dependency for now.

@torfjelde torfjelde merged commit 8d45ff4 into master Oct 24, 2023
@delete-merged-branch delete-merged-branch bot deleted the torfjelde/initial-state branch October 24, 2023 15:26
torfjelde added a commit that referenced this pull request Oct 24, 2023
torfjelde added a commit that referenced this pull request Oct 24, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants