Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convinience constructors #325

Merged
merged 109 commits into from
Jul 20, 2023
Merged

Convinience constructors #325

merged 109 commits into from
Jul 20, 2023

Conversation

JaimeRZP
Copy link
Member

@JaimeRZP JaimeRZP commented Jun 28, 2023

Fix #314. Fix #313

src/AdvancedHMC.jl Outdated Show resolved Hide resolved
src/abstractmcmc.jl Outdated Show resolved Hide resolved
src/abstractmcmc.jl Outdated Show resolved Hide resolved
src/abstractmcmc.jl Outdated Show resolved Hide resolved
src/constructors.jl Outdated Show resolved Hide resolved
src/constructors.jl Outdated Show resolved Hide resolved
src/constructors.jl Outdated Show resolved Hide resolved
src/constructors.jl Outdated Show resolved Hide resolved
src/constructors.jl Outdated Show resolved Hide resolved
src/sampler.jl Outdated Show resolved Hide resolved
test/sampler.jl Outdated Show resolved Hide resolved
ϵ = find_good_stepsize(rng, hamiltonian, init_params)
@info string("Found initial step size ", ϵ)
end
integrator = eval(spl.integrator_method)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant more something like

const SYMBOL_TO_INTEGRATOR_TYPE = Dict(:leapfrog => Leapfrog)

function determine_integrator(integrator::Symbol)
    if !haskey(SYMBOL_TO_INTEGRATOR_TYPE, integrator)
        error("Integrator $integrator not supported.")
    end

    return SYMBOL_TO_INTEGRATOR_TYPE[integrator]
end

# If it's the "constructor" of an integrator or instantance of an integrator, do nothing.
determine_integrator(x::AbstractIntegrator) = x
determine_integrator(x::Type{<:AbstractIntegrator}) = x

determine_integrator(x) = error("Integrator $x not supported.")

This doesn't require usage of eval (which evaluates in the global scope of the current module, i.e. if you someone attempts to naively pass in a type form their own package, then that will break), while still allowing the user to pass in a simple Symbol or explicitly the integrator type.

With the above, the this line becomes

Suggested change
integrator = eval(spl.integrator_method)
integrator = determine_integrator(spl.integrator)

(assuming you've also renamed the fields as I've suggested above)

Similarly should be done for the metrics.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uncertain if we need the definition for x::AbstractIntegrator

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can also rename the above to determine_integrator_constuctor, to be more explicit.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And should we be using ConstructionBase.constructorof here? Basically, if someone does NUTS(; metric=DiagEuclideanMetric{Float64}) but then we try to use this as a constructor for eltype Float32, things will be break.

EDIT: This is generally only going to happen if the argument is consturcted programmatically, in which case I guess they caller can just use ConstructionBase.constructorof themselves.

src/constructors.jl Outdated Show resolved Hide resolved
src/constructors.jl Outdated Show resolved Hide resolved
src/constructors.jl Outdated Show resolved Hide resolved
src/abstractmcmc.jl Outdated Show resolved Hide resolved
init_params,
)
# rerturns a dummy integrator
return Leapfrog(0.0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this mess up the type when we call find_initial_stepsize? F.ex. will this result in Float64 even if the rest of the sampler has been specified to be Float32?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't because it is not used anywhere when a HMCSampler is provided.
I have made it return a AbstractIntegrator now which should be more general.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found a more elegant solution!
I now simply returns the integrator provided by the user, even though it is technically not used anywhere.

function make_metric(spl::Union{HMC,NUTS,HMCDA}, logdensity)
d = LogDensityProblems.dimension(logdensity)
metric = eval(spl.metric_type)
return metric(d)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will also not preserve the type unfortuantely; this will always result in Float64 😕

Copy link
Member Author

@JaimeRZP JaimeRZP Jul 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Float64 for what?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Solved!

src/constructors.jl Outdated Show resolved Hide resolved
hmcda = HMCDA(1000, 0.8, 1.0)

# Check that everything is initalized correctly
@testset "Constructors" begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple of notes on the tests:

  • We should test different element types (in particular Float32 and Float64).
  • I'd recommend taking a single step so we execute the initialization procedure, and then instead check the resulting values from that (e.g. do we get the correct eltype, is the step-size the correct one / if we chose automatically, is it different non-zero?)

JaimeRZP and others added 2 commits July 19, 2023 10:40
Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: Tor Erlend Fjelde <[email protected]>
src/constructors.jl Outdated Show resolved Hide resolved
src/constructors.jl Outdated Show resolved Hide resolved
test/constructors.jl Outdated Show resolved Hide resolved
src/AdvancedHMC.jl Outdated Show resolved Hide resolved
src/constructors.jl Outdated Show resolved Hide resolved
src/constructors.jl Outdated Show resolved Hide resolved
src/constructors.jl Outdated Show resolved Hide resolved
src/constructors.jl Outdated Show resolved Hide resolved
src/constructors.jl Outdated Show resolved Hide resolved
src/constructors.jl Outdated Show resolved Hide resolved
test/constructors.jl Outdated Show resolved Hide resolved
test/constructors.jl Outdated Show resolved Hide resolved
test/constructors.jl Outdated Show resolved Hide resolved
Comment on lines 247 to 250
function get_type_of_spl(spl::AbstractHMCSampler)
T = collect(typeof(spl).parameters)[1]
return T
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems too hacky for my taste. According to a quick check, (unsurprisingly) it also can't be inferred. It seems this could be fixed by removing collect - but IMO the correct approach would be to define dispatches such as get_type_of_spl(::NUTS{T}) = T - or more generically define abstract type AbstractHMCSampler{T<:Real} <: AbstractSampler end and define NUTS{T<:Real} <: AbstractHMCSampler{T} ....

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks David! I was trying to get something like this to work for 1h yesterday until I exasperated.

Comment on lines 252 to 269
const SYMBOL_TO_INTEGRATOR_TYPE = Dict(
:leapfrog => Leapfrog,
:jitterleapfro => JitteredLeapfrog,
:temperedleapfrog => TemperedLeapfrog,
)

function determine_integrator_constructor(integrator::Symbol)
if !haskey(SYMBOL_TO_INTEGRATOR_TYPE, integrator)
error("Integrator $integrator not supported.")
end

return SYMBOL_TO_INTEGRATOR_TYPE[integrator]
end

# If it's the "constructor" of an integrator or instantance of an integrator, do nothing.
determine_integrator_constructor(x::AbstractIntegrator) = x
determine_integrator_constructor(x::Type{<:AbstractIntegrator}) = x
determine_integrator_constructor(x) = error("Integrator $x not supported.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use a different design? It seems quite weird... Why do we need to retrieve the constructor? Couldn't we just use something like get_integrator(Val(spl.integrator), eps) where we dispatch on the first argument?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok let me try!

Copy link
Member Author

@JaimeRZP JaimeRZP Jul 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we still want to keep the look up table design but I have disentangled make_step_size from make_integrator.
@torfjelde ?

Copy link
Member

@yebai yebai Jul 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tend to agree with @devmotion here that a multi-dispatching approach is cleaner and more extensible. For example, the following works:

using AdvancedHMC
get_integrator(i...) = error("Integrator $(typeof(i)) not supported.")
get_integrator(i::Val{:leapfrog}, eps) = Leapfrog(eps)

# It is quite easy to add a new integrator by overloading `get_integrator`
get_integrator(i::Val{:jitteredleapfrog}, eps, jitter) = JitteredLeapfrog(eps, jitter)
get_integrator(i::Val{:temperedleapfrog}, eps, α) = TemperedLeapfrog(eps, α)

julia> get_integrator(Val(:leapfrog), 0.)
Leapfrog=0.0)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this solution a lot!! I have given the same treatment to the integrator and the metric.
There's a little bit of nesting going on here but basically if the user provides an AbstractMetric/AbstractIntegrator the code simly returns it. If on the other hand, the user provides a symbol it gets wrapped in Val(::Symbol) and then we do your solution.

This allows the user to use the NUTS constructor with their own metric/integrator without having to build the whole sampler themselves which is something I always wanted to code into these constructors.

make_integrator(spl::HMCSampler, ϵ::Real) = spl.κ.τ.integrator
make_integrator(spl::AbstractHMCSampler, ϵ::Real) = make_integrator(spl.integrator, ϵ)
make_integrator(i::AbstractIntegrator, ϵ::Real) = i
make_integrator(i::Type{<:AbstractIntegrator}, ϵ::Real) = i
make_integrator(i::Symbol, ϵ::Real) = make_integrator(Val(i), ϵ)
make_integrator(i...) = error("Integrator $(typeof(i)) not supported.")
make_integrator(i::Val{:leapfrog}, ϵ::Real) = Leapfrog(ϵ)
make_integrator(i::Val{:jitteredleapfrog}, ϵ::Real) = JitteredLeapfrog(ϵ)
make_integrator(i::Val{:temperedleapfrog}, ϵ::Real) = TemperedLeapfrog(ϵ)

#########

make_metric(i...) = error("Metric $(typeof(i)) not supported.")
make_metric(i::Symbol, T::Type, d::Int) = make_metric(Val(i), T, d)
make_metric(i::AbstractMetric, T::Type, d::Int) = i
make_metric(i::Type{AbstractMetric}, T::Type, d::Int) = i
make_metric(i::Val{:diagonal}, T::Type, d::Int) = DiagEuclideanMetric(T, d)
make_metric(i::Val{:unit}, T::Type, d::Int) = UnitEuclideanMetric(T, d)
make_metric(i::Val{:dense}, T::Type, d::Int) = DenseEuclideanMetric(T, d)

metric_type::D
end

function NUTS(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HMCSampler still uses @kwdef it seems?

@yebai yebai merged commit a7edfa9 into master Jul 20, 2023
@delete-merged-branch delete-merged-branch bot deleted the convinience_constructors branch July 20, 2023 18:43
Comment on lines +291 to +292
make_integrator(i::AbstractIntegrator, ϵ::Real) = i
make_integrator(i::Type{<:AbstractIntegrator}, ϵ::Real) = i
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are both needed?

make_integrator(i::AbstractIntegrator, ϵ::Real) = i
make_integrator(i::Type{<:AbstractIntegrator}, ϵ::Real) = i
make_integrator(i::Symbol, ϵ::Real) = make_integrator(Val(i), ϵ)
make_integrator(i...) = error("Integrator $(typeof(i)) not supported.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't work as desired IIUC:

julia> f(x...) = typeof(x)
f (generic function with 1 method)

julia> f(3)
Tuple{Int64}

julia> f(1.0, "a")
Tuple{Float64, String}

IMO one could simply use

make_integrator(@nospecialize(i), ::Real) = error("Integrator $i not supported.")

@yebai yebai mentioned this pull request Jul 25, 2023
using AdvancedHMC, AbstractMCMC, Random
include("common.jl")

# Initalize samplers
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you maybe make two changes:

  • Make each of the sampler types have their own testset inside of the testset below? Makes it a bit clearer what's going on.
  • For each, just iterate over the different types we're testing (each having it's own testset).

@test hmcda_32.init_ϵ == 0.0f0
end

@testset "First step" begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above: separere into different testsets:)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Make HMCState stores rng in AbstractMCMC interface. More friendly default sample interface
4 participants