-
Notifications
You must be signed in to change notification settings - Fork 43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Convinience constructors #325
Conversation
src/abstractmcmc.jl
Outdated
ϵ = find_good_stepsize(rng, hamiltonian, init_params) | ||
@info string("Found initial step size ", ϵ) | ||
end | ||
integrator = eval(spl.integrator_method) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I meant more something like
const SYMBOL_TO_INTEGRATOR_TYPE = Dict(:leapfrog => Leapfrog)
function determine_integrator(integrator::Symbol)
if !haskey(SYMBOL_TO_INTEGRATOR_TYPE, integrator)
error("Integrator $integrator not supported.")
end
return SYMBOL_TO_INTEGRATOR_TYPE[integrator]
end
# If it's the "constructor" of an integrator or instantance of an integrator, do nothing.
determine_integrator(x::AbstractIntegrator) = x
determine_integrator(x::Type{<:AbstractIntegrator}) = x
determine_integrator(x) = error("Integrator $x not supported.")
This doesn't require usage of eval
(which evaluates in the global scope of the current module, i.e. if you someone attempts to naively pass in a type form their own package, then that will break), while still allowing the user to pass in a simple Symbol
or explicitly the integrator type.
With the above, the this line becomes
integrator = eval(spl.integrator_method) | |
integrator = determine_integrator(spl.integrator) |
(assuming you've also renamed the fields as I've suggested above)
Similarly should be done for the metrics.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Uncertain if we need the definition for x::AbstractIntegrator
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can also rename the above to determine_integrator_constuctor
, to be more explicit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And should we be using ConstructionBase.constructorof
here? Basically, if someone does NUTS(; metric=DiagEuclideanMetric{Float64})
but then we try to use this as a constructor for eltype
Float32
, things will be break.
EDIT: This is generally only going to happen if the argument is consturcted programmatically, in which case I guess they caller can just use ConstructionBase.constructorof
themselves.
src/abstractmcmc.jl
Outdated
init_params, | ||
) | ||
# rerturns a dummy integrator | ||
return Leapfrog(0.0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will this mess up the type when we call find_initial_stepsize
? F.ex. will this result in Float64
even if the rest of the sampler has been specified to be Float32
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It shouldn't because it is not used anywhere when a HMCSampler
is provided.
I have made it return a AbstractIntegrator
now which should be more general.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found a more elegant solution!
I now simply returns the integrator provided by the user, even though it is technically not used anywhere.
src/abstractmcmc.jl
Outdated
function make_metric(spl::Union{HMC,NUTS,HMCDA}, logdensity) | ||
d = LogDensityProblems.dimension(logdensity) | ||
metric = eval(spl.metric_type) | ||
return metric(d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will also not preserve the type unfortuantely; this will always result in Float64
😕
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Float64
for what?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Solved!
hmcda = HMCDA(1000, 0.8, 1.0) | ||
|
||
# Check that everything is initalized correctly | ||
@testset "Constructors" begin |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A couple of notes on the tests:
- We should test different element types (in particular
Float32
andFloat64
). - I'd recommend taking a single step so we execute the initialization procedure, and then instead check the resulting values from that (e.g. do we get the correct
eltype
, is the step-size the correct one / if we chose automatically, is it different non-zero?)
Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: Tor Erlend Fjelde <[email protected]>
src/abstractmcmc.jl
Outdated
function get_type_of_spl(spl::AbstractHMCSampler) | ||
T = collect(typeof(spl).parameters)[1] | ||
return T | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems too hacky for my taste. According to a quick check, (unsurprisingly) it also can't be inferred. It seems this could be fixed by removing collect
- but IMO the correct approach would be to define dispatches such as get_type_of_spl(::NUTS{T}) = T
- or more generically define abstract type AbstractHMCSampler{T<:Real} <: AbstractSampler end
and define NUTS{T<:Real} <: AbstractHMCSampler{T} ...
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks David! I was trying to get something like this to work for 1h yesterday until I exasperated.
src/abstractmcmc.jl
Outdated
const SYMBOL_TO_INTEGRATOR_TYPE = Dict( | ||
:leapfrog => Leapfrog, | ||
:jitterleapfro => JitteredLeapfrog, | ||
:temperedleapfrog => TemperedLeapfrog, | ||
) | ||
|
||
function determine_integrator_constructor(integrator::Symbol) | ||
if !haskey(SYMBOL_TO_INTEGRATOR_TYPE, integrator) | ||
error("Integrator $integrator not supported.") | ||
end | ||
|
||
return SYMBOL_TO_INTEGRATOR_TYPE[integrator] | ||
end | ||
|
||
# If it's the "constructor" of an integrator or instantance of an integrator, do nothing. | ||
determine_integrator_constructor(x::AbstractIntegrator) = x | ||
determine_integrator_constructor(x::Type{<:AbstractIntegrator}) = x | ||
determine_integrator_constructor(x) = error("Integrator $x not supported.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use a different design? It seems quite weird... Why do we need to retrieve the constructor? Couldn't we just use something like get_integrator(Val(spl.integrator), eps)
where we dispatch on the first argument?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok let me try!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we still want to keep the look up table design but I have disentangled make_step_size
from make_integrator
.
@torfjelde ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tend to agree with @devmotion here that a multi-dispatching approach is cleaner and more extensible. For example, the following works:
using AdvancedHMC
get_integrator(i...) = error("Integrator $(typeof(i)) not supported.")
get_integrator(i::Val{:leapfrog}, eps) = Leapfrog(eps)
# It is quite easy to add a new integrator by overloading `get_integrator`
get_integrator(i::Val{:jitteredleapfrog}, eps, jitter) = JitteredLeapfrog(eps, jitter)
get_integrator(i::Val{:temperedleapfrog}, eps, α) = TemperedLeapfrog(eps, α)
julia> get_integrator(Val(:leapfrog), 0.)
Leapfrog(ϵ=0.0)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this solution a lot!! I have given the same treatment to the integrator and the metric.
There's a little bit of nesting going on here but basically if the user provides an AbstractMetric
/AbstractIntegrator
the code simly returns it. If on the other hand, the user provides a symbol it gets wrapped in Val(::Symbol)
and then we do your solution.
This allows the user to use the NUTS constructor with their own metric/integrator without having to build the whole sampler themselves which is something I always wanted to code into these constructors.
make_integrator(spl::HMCSampler, ϵ::Real) = spl.κ.τ.integrator
make_integrator(spl::AbstractHMCSampler, ϵ::Real) = make_integrator(spl.integrator, ϵ)
make_integrator(i::AbstractIntegrator, ϵ::Real) = i
make_integrator(i::Type{<:AbstractIntegrator}, ϵ::Real) = i
make_integrator(i::Symbol, ϵ::Real) = make_integrator(Val(i), ϵ)
make_integrator(i...) = error("Integrator $(typeof(i)) not supported.")
make_integrator(i::Val{:leapfrog}, ϵ::Real) = Leapfrog(ϵ)
make_integrator(i::Val{:jitteredleapfrog}, ϵ::Real) = JitteredLeapfrog(ϵ)
make_integrator(i::Val{:temperedleapfrog}, ϵ::Real) = TemperedLeapfrog(ϵ)
#########
make_metric(i...) = error("Metric $(typeof(i)) not supported.")
make_metric(i::Symbol, T::Type, d::Int) = make_metric(Val(i), T, d)
make_metric(i::AbstractMetric, T::Type, d::Int) = i
make_metric(i::Type{AbstractMetric}, T::Type, d::Int) = i
make_metric(i::Val{:diagonal}, T::Type, d::Int) = DiagEuclideanMetric(T, d)
make_metric(i::Val{:unit}, T::Type, d::Int) = UnitEuclideanMetric(T, d)
make_metric(i::Val{:dense}, T::Type, d::Int) = DenseEuclideanMetric(T, d)
metric_type::D | ||
end | ||
|
||
function NUTS( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
HMCSampler
still uses @kwdef
it seems?
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…ang/AdvancedHMC.jl into convinience_constructors
make_integrator(i::AbstractIntegrator, ϵ::Real) = i | ||
make_integrator(i::Type{<:AbstractIntegrator}, ϵ::Real) = i |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are both needed?
make_integrator(i::AbstractIntegrator, ϵ::Real) = i | ||
make_integrator(i::Type{<:AbstractIntegrator}, ϵ::Real) = i | ||
make_integrator(i::Symbol, ϵ::Real) = make_integrator(Val(i), ϵ) | ||
make_integrator(i...) = error("Integrator $(typeof(i)) not supported.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This won't work as desired IIUC:
julia> f(x...) = typeof(x)
f (generic function with 1 method)
julia> f(3)
Tuple{Int64}
julia> f(1.0, "a")
Tuple{Float64, String}
IMO one could simply use
make_integrator(@nospecialize(i), ::Real) = error("Integrator $i not supported.")
using AdvancedHMC, AbstractMCMC, Random | ||
include("common.jl") | ||
|
||
# Initalize samplers |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you maybe make two changes:
- Make each of the sampler types have their own testset inside of the testset below? Makes it a bit clearer what's going on.
- For each, just iterate over the different types we're testing (each having it's own testset).
@test hmcda_32.init_ϵ == 0.0f0 | ||
end | ||
|
||
@testset "First step" begin |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to above: separere into different testsets:)
Fix #314. Fix #313