Skip to content

Commit

Permalink
chore: unpack kinetic.jl (#303)
Browse files Browse the repository at this point in the history
* chore: unpack kinetic.jl

* chore: cherry pick more changes from #300

Co-authored-by: Kai Xu <[email protected]>
  • Loading branch information
xukai92 and Kai Xu authored Nov 29, 2022
1 parent 11380fc commit 403c7e5
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 15 deletions.
8 changes: 7 additions & 1 deletion src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ include("utilities.jl")
# r: momentum variables
# z: phase point / a pair of θ and r

include("kinetic.jl")
# TODO Move it back to hamiltonian.jl after the rand interface is updated
abstract type AbstractKinetic end

struct GaussianKinetic <: AbstractKinetic end

export GaussianKinetic

include("metric.jl")
export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric

Expand Down
18 changes: 10 additions & 8 deletions src/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ function ∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat)
return DualValue(res[1], -res[2])
end

∂H∂r(h::Hamiltonian{<:UnitEuclideanMetric}, r::AbstractVecOrMat) = copy(r)
∂H∂r(h::Hamiltonian{<:DiagEuclideanMetric}, r::AbstractVecOrMat) = h.metric.M⁻¹ .* r
∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric}, r::AbstractVecOrMat) = h.metric.M⁻¹ * r
∂H∂r(h::Hamiltonian{<:UnitEuclideanMetric, <:GaussianKinetic}, r::AbstractVecOrMat) = copy(r)
∂H∂r(h::Hamiltonian{<:DiagEuclideanMetric, <:GaussianKinetic}, r::AbstractVecOrMat) = h.metric.M⁻¹ .* r
∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric, <:GaussianKinetic}, r::AbstractVecOrMat) = h.metric.M⁻¹ * r

struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat}, V<:DualValue}
θ::T # Position variables / model parameters.
Expand Down Expand Up @@ -109,32 +109,34 @@ neg_energy(z::PhasePoint) = z.ℓπ.value + z.ℓκ.value

neg_energy(h::Hamiltonian, θ::AbstractVecOrMat) = h.ℓπ(θ)

# GaussianKinetic

neg_energy(
h::Hamiltonian{<:UnitEuclideanMetric},
h::Hamiltonian{<:UnitEuclideanMetric, <:GaussianKinetic},
r::T,
θ::T
) where {T<:AbstractVector} = -sum(abs2, r) / 2

neg_energy(
h::Hamiltonian{<:UnitEuclideanMetric},
h::Hamiltonian{<:UnitEuclideanMetric, <:GaussianKinetic},
r::T,
θ::T
) where {T<:AbstractMatrix} = -vec(sum(abs2, r; dims=1)) / 2

neg_energy(
h::Hamiltonian{<:DiagEuclideanMetric},
h::Hamiltonian{<:DiagEuclideanMetric, <:GaussianKinetic},
r::T,
θ::T
) where {T<:AbstractVector} = -sum(abs2.(r) .* h.metric.M⁻¹) / 2

neg_energy(
h::Hamiltonian{<:DiagEuclideanMetric},
h::Hamiltonian{<:DiagEuclideanMetric, <:GaussianKinetic},
r::T,
θ::T
) where {T<:AbstractMatrix} = -vec(sum(abs2.(r) .* h.metric.M⁻¹; dims=1) ) / 2

function neg_energy(
h::Hamiltonian{<:DenseEuclideanMetric},
h::Hamiltonian{<:DenseEuclideanMetric, <:GaussianKinetic},
r::T,
θ::T
) where {T<:AbstractVecOrMat}
Expand Down
3 changes: 0 additions & 3 deletions src/kinetic.jl

This file was deleted.

6 changes: 3 additions & 3 deletions src/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ getname(m::T) where {T<:AbstractMetric} = getname(T)
function _rand(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
metric::UnitEuclideanMetric{T},
kinetic,
kinetic::GaussianKinetic,
) where {T}
r = randn(rng, T, size(metric)...)
return r
Expand All @@ -96,7 +96,7 @@ end
function _rand(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
metric::DiagEuclideanMetric{T},
kinetic,
kinetic::GaussianKinetic,
) where {T}
r = randn(rng, T, size(metric)...)
r ./= metric.sqrtM⁻¹
Expand All @@ -106,7 +106,7 @@ end
function _rand(
rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
metric::DenseEuclideanMetric{T},
kinetic,
kinetic::GaussianKinetic,
) where {T}
r = randn(rng, T, size(metric)...)
ldiv!(metric.cholM⁻¹, r)
Expand Down

0 comments on commit 403c7e5

Please sign in to comment.