Sensible Defaults #184

Closed
willtebbutt opened this issue Feb 27, 2020 · 2 comments

@willtebbutt
Member

willtebbutt commented Feb 27, 2020

While the example in the README is important in that it exposes all of the useful bits of AdvancedHMC, it would be nice if it were possible to define a couple of sensible defaults, e.g.

```julia
vanilla_hmc_sampler = HMC(step_size, n_steps)
reasonably_sensible_nuts_sampler = NUTS(step_size)
```

I think this kind of thing would be a helpful addition to the package: it would make it easier for a new user to get up and running with their first sampler, and for an expert to prototype before optimising.

In particular, it would make the examples in Stheno.jl's documentation a lot less verbose.

edit:

I think the above would likely tie a sampler and an adapter together. The sampling API would be something like:

```julia
samples, stats = sample(h, vanilla_hmc_sampler, n_samples_to_draw; progress=true)
```

and by default it would maybe spend `n_samples_to_draw` steps burning in / adapting, then draw the samples.
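To make the proposal concrete, here is a minimal plain-Julia sketch of what an `HMC(step_size, n_steps)`-style default could bundle: a fixed step size, a fixed number of leapfrog steps, an identity mass matrix, and `n_samples_to_draw` burn-in steps before the kept draws. The names `VanillaHMC`, `hmc_transition` and `sample_with_burnin` are hypothetical and not part of AdvancedHMC; unlike the proposal, this sketch only burns in and does not adapt the step size or mass matrix.

```julia
using LinearAlgebra, Random

# Hypothetical "sensible default" sampler: fixed step size and number of
# leapfrog steps, identity mass matrix. This is NOT AdvancedHMC's API.
struct VanillaHMC
    step_size::Float64
    n_steps::Int
end

# One HMC transition. `logdensity_and_grad(θ)` must return
# (log density up to a constant, gradient) at θ.
function hmc_transition(rng::AbstractRNG, logdensity_and_grad, s::VanillaHMC,
                        θ::Vector{Float64})
    ϵ = s.step_size
    p = randn(rng, length(θ))                # fresh momentum ~ N(0, I)
    logp0, g = logdensity_and_grad(θ)
    H0 = -logp0 + dot(p, p) / 2              # Hamiltonian at the start

    θnew, pnew = copy(θ), copy(p)
    logp_new = logp0
    pnew .+= (ϵ / 2) .* g                    # leapfrog: initial half momentum step
    for i in 1:s.n_steps
        θnew .+= ϵ .* pnew                   # full position step
        logp_new, g = logdensity_and_grad(θnew)
        # full momentum step, except a half step after the last position update
        pnew .+= (i < s.n_steps ? ϵ : ϵ / 2) .* g
    end
    H1 = -logp_new + dot(pnew, pnew) / 2     # Hamiltonian at the end

    # Metropolis correction: accept with probability min(1, exp(H0 - H1))
    return log(rand(rng)) < H0 - H1 ? θnew : θ
end

# Hypothetical wrapper mirroring the proposed `sample(h, sampler, n)`:
# discard `n_samples` burn-in draws, then keep `n_samples` draws.
function sample_with_burnin(rng::AbstractRNG, logdensity_and_grad, s::VanillaHMC,
                            θ0::Vector{Float64}, n_samples::Int)
    θ = copy(θ0)
    for _ in 1:n_samples                     # burn-in, results discarded
        θ = hmc_transition(rng, logdensity_and_grad, s, θ)
    end
    samples = Vector{Vector{Float64}}(undef, n_samples)
    for i in 1:n_samples
        θ = hmc_transition(rng, logdensity_and_grad, s, θ)
        samples[i] = θ
    end
    return samples
end

# Usage on a standard 2-D Gaussian target: (log density, gradient)
gaussian(θ) = (-dot(θ, θ) / 2, -θ)
rng = MersenneTwister(1)
samples = sample_with_burnin(rng, gaussian, VanillaHMC(0.1, 20), zeros(2), 2000)
```

The step size and number of steps still have to be chosen by hand here, which is the part a default sampler paired with an adaptor would take off the user's plate.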

@xukai92 xukai92 added this to the Release v0.3 milestone Aug 2, 2020
@yebai yebai removed this from the Release v0.3 milestone Nov 13, 2022
@scheidan
Contributor

scheidan commented Mar 6, 2023

I'd like to bump this issue. I'm planning to use AdvancedHMC for a teaching exercise (outside of Turing), so I really need a much simpler, user-friendly interface.

For my own use I wrote the wrapper function below (more or less directly from the README). Is something like this in scope for this package? It would make it much easier for users who don't use Turing to benefit from your work.

````julia
using AdvancedHMC

"""
# NUTS sampler as used by Stan
Based on the AdvancedHMC documentation.

```Julia
advancedHMC_NUTS(lp::Function, ∇lp::Function, θ_init;
                 n_samples::Int=1000, n_adapts::Int=n_samples÷2)
```

## Arguments
- `lp`: log density to sample from (up to a constant)
- `∇lp`: function that computes the gradient of `lp`
- `θ_init`: initial parameter vector
- `n_samples::Int=1000`: number of samples
- `n_adapts::Int=n_samples÷2`: length of adaptation

### Return Value
A named tuple with fields:
- `samples`: array containing the samples
- `stats`: various statistics from the sampler. See the AdvancedHMC documentation.
"""
function advancedHMC_NUTS(lp::Function, ∇lp::Function, θ_init;
                          n_samples::Int=1000, n_adapts::Int=n_samples÷2)

    # Define a Hamiltonian system
    D = length(θ_init)   # number of parameters
    metric = DiagEuclideanMetric(D)

    # Use an AD framework, or (as here) supply a function that returns
    # the log density and its gradient together
    hamiltonian = Hamiltonian(metric, lp, θr -> (lp(θr), ∇lp(θr)))

    # Define a leapfrog solver, with initial step size chosen heuristically
    initial_ϵ = find_good_stepsize(hamiltonian, θ_init)
    integrator = Leapfrog(initial_ϵ)

    # Define an HMC sampler, with the following components
    #   - multinomial sampling scheme,
    #   - generalised No-U-Turn criteria, and
    #   - windowed adaptation for step size and diagonal mass matrix
    kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
    adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))

    # -- run sampler
    samples, stats = sample(hamiltonian, kernel, θ_init, n_samples,
                            adaptor, n_adapts; progress=true)

    return (samples=samples, stats=stats)

end
````
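Because the wrapper above takes the log density and its gradient as two separate functions, a mismatch between them is easy to introduce and hard to spot from the sampler output. The following is a minimal, dependency-free sketch of a central finite-difference check; the Gaussian target, `σ`, and the `fd_gradient` helper are illustrative assumptions, not part of AdvancedHMC. The final call is commented out because it needs AdvancedHMC and the wrapper to be loaded.

```julia
# Hypothetical target: isotropic Gaussian with standard deviation σ,
# log density known only up to an additive constant.
σ = 2.0
lp(θ) = -sum(abs2, θ) / (2σ^2)
∇lp(θ) = -θ ./ σ^2

# Central finite-difference approximation of the gradient of f at θ.
function fd_gradient(f, θ; h=1e-6)
    g = similar(θ)
    for i in eachindex(θ)
        e = zeros(length(θ))
        e[i] = h
        g[i] = (f(θ .+ e) - f(θ .- e)) / (2h)
    end
    return g
end

θ_init = [0.5, -1.0, 2.0]
# Largest discrepancy between the hand-written and numerical gradients;
# it should be tiny if ∇lp is correct.
max_err = maximum(abs.(fd_gradient(lp, θ_init) .- ∇lp(θ_init)))

# With AdvancedHMC loaded and the wrapper defined, one would then call:
# samples, stats = advancedHMC_NUTS(lp, ∇lp, θ_init; n_samples=2000)
```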

@yebai
Member

yebai commented Aug 2, 2023

@JaimeRZP's recent work likely fixes this. See e.g. #323

@yebai yebai closed this as completed Aug 2, 2023
4 participants