-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: relativistic HMC * chore: move codes to experimental folder * chore: commit unsaved change * chore: warp experimental codes into a seperate module * chore: remove new the module * chore: create experimental tests and set up CI * Moved all experimental code including tests into research folder. (#304) * Moved all experimental code into research folder. * Minor changes in deps. * Improve tests for experimental code. * Change CI name. * Fix path for research includes. * More fixes to includes * Minor fixes. * Update research/src/relativistic_hmc.jl * Apply suggestions from code review * Create README.md * Update research/src/relativistic_hmc.jl * Update relativistic_hmc.jl Co-authored-by: Kai Xu <[email protected]> Co-authored-by: Kai Xu <[email protected]> Co-authored-by: Hong Ge <[email protected]>
- Loading branch information
1 parent
403c7e5
commit 6a55a3f
Showing
7 changed files
with
177 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
name: TestExperimentalFeatures | ||
|
||
on: | ||
push: | ||
branches: | ||
- master | ||
pull_request: | ||
|
||
jobs: | ||
test: | ||
runs-on: ${{ matrix.os }} | ||
strategy: | ||
matrix: | ||
version: | ||
- '1' | ||
os: | ||
- ubuntu-latest | ||
- macOS-latest | ||
- windows-latest | ||
arch: | ||
- x86 | ||
- x64 | ||
exclude: | ||
- os: ubuntu-latest | ||
arch: x86 | ||
- os: macOS-latest | ||
arch: x86 | ||
- os: windows-latest | ||
arch: x86 | ||
steps: | ||
- uses: actions/checkout@v2 | ||
- uses: julia-actions/setup-julia@v1 | ||
with: | ||
version: ${{ matrix.version }} | ||
arch: ${{ matrix.arch }} | ||
- uses: julia-actions/julia-buildpkg@latest | ||
- name: Run integration tests | ||
uses: julia-actions/julia-runtest@latest | ||
env: | ||
AHMC_TEST_GROUP: Experimental |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
[deps] | ||
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" | ||
AdaptiveRejectionSampling = "c75e803d-635f-53bd-ab7d-544e482d8c75" | ||
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" | ||
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" | ||
InplaceOps = "505f98c9-085e-5b2c-8e89-488be7bf1f34" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Requires = "ae029012-a4dd-5104-9daa-d747884805df" | ||
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | ||
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" | ||
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" | ||
|
||
[compat] | ||
AbstractMCMC = "3.2, 4" | ||
ArgCheck = "1, 2" | ||
DocStringExtensions = "0.8, 0.9" | ||
InplaceOps = "0.3" | ||
ProgressMeter = "1" | ||
Requires = "0.5, 1" | ||
Setfield = "0.7, 0.8, 1" | ||
StatsBase = "0.31, 0.32, 0.33" | ||
StatsFuns = "0.8, 0.9, 1" | ||
UnPack = "1" | ||
julia = "1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
In order to use algorithms in this folder, please navigate to the AdvancedHMC folder and run | ||
|
||
|
||
``` | ||
] activate research/ | ||
] develop src/ | ||
] instantiate | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
using AdvancedHMC | ||
import AdvancedHMC: ∂H∂r, neg_energy, AbstractKinetic | ||
import Random: AbstractRNG | ||
|
||
struct RelativisticKinetic{T} <: AbstractKinetic | ||
"Mass" | ||
m::T | ||
"Speed of light" | ||
c::T | ||
end | ||
|
||
|
||
function ∂H∂r(h::Hamiltonian{<:UnitEuclideanMetric,<:RelativisticKinetic}, r::AbstractVecOrMat) | ||
mass = h.kinetic.m .* sqrt.(r.^2 ./ (h.kinetic.m.^2 * h.kinetic.c.^2) .+ 1) | ||
return r ./ mass | ||
end | ||
function ∂H∂r(h::Hamiltonian{<:DiagEuclideanMetric,<:RelativisticKinetic}, r::AbstractVecOrMat) | ||
r = h.metric.sqrtM⁻¹ .* r | ||
mass = h.kinetic.m .* sqrt.(r.^2 ./ (h.kinetic.m.^2 * h.kinetic.c.^2) .+ 1) | ||
retval = r ./ mass # red part of (15) | ||
return h.metric.sqrtM⁻¹ .* retval # (15) | ||
end | ||
|
||
|
||
function neg_energy( | ||
h::Hamiltonian{<:UnitEuclideanMetric,<:RelativisticKinetic}, | ||
r::T, | ||
θ::T | ||
) where {T<:AbstractVector} | ||
return -sum(h.kinetic.m .* h.kinetic.c.^2 .* sqrt.(r.^2 ./ (h.kinetic.m.^2 .* h.kinetic.c.^2) .+ 1)) | ||
end | ||
|
||
function neg_energy( | ||
h::Hamiltonian{<:DiagEuclideanMetric,<:RelativisticKinetic}, | ||
r::T, | ||
θ::T | ||
) where {T<:AbstractVector} | ||
r = h.metric.sqrtM⁻¹ .* r | ||
return -sum(h.kinetic.m .* h.kinetic.c.^2 .* sqrt.(r.^2 ./ (h.kinetic.m.^2 .* h.kinetic.c.^2) .+ 1)) | ||
end | ||
|
||
|
||
using AdaptiveRejectionSampling: RejectionSampler, run_sampler! | ||
|
||
# TODO Support AbstractVector{<:AbstractRNG} | ||
function _rand( | ||
rng::AbstractRNG, | ||
metric::UnitEuclideanMetric{T}, | ||
kinetic::RelativisticKinetic{T}, | ||
) where {T} | ||
h_temp = Hamiltonian(metric, kinetic, identity, identity) | ||
densityfunc = x -> exp(neg_energy(h_temp, [x], [x])) | ||
sampler = RejectionSampler(densityfunc, (-Inf, Inf); max_segments=5) | ||
sz = size(metric) | ||
r = run_sampler!(rng, sampler, prod(sz)) | ||
r = reshape(r, sz) | ||
return r | ||
end | ||
|
||
# TODO Support AbstractVector{<:AbstractRNG} | ||
function _rand( | ||
rng::AbstractRNG, | ||
metric::DiagEuclideanMetric{T}, | ||
kinetic::RelativisticKinetic{T}, | ||
) where {T} | ||
r = _rand(rng, UnitEuclideanMetric(size(metric)), kinetic) | ||
# p' = A p where A = sqrtM | ||
r ./= metric.sqrtM⁻¹ | ||
return r | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
using ReTest, AdvancedHMC | ||
using AdvancedHMC | ||
using LinearAlgebra: dot | ||
|
||
@testset "Hamiltonian" begin | ||
f = x -> dot(x, x) | ||
g = x -> 2x | ||
metric = UnitEuclideanMetric(10) | ||
h = Hamiltonian(metric, RelativisticKinetic(1.0, 1.0), f, g) | ||
@test h.kinetic isa RelativisticKinetic | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
using ReTest | ||
|
||
# include the source code for relativistic HMC | ||
include("../src/relativistic_hmc.jl") | ||
|
||
# include the tests for relativistic HMC | ||
include("relativistic_hmc.jl") | ||
|
||
@main function runtests(patterns...; dry::Bool=false) | ||
retest(patterns...; dry=dry, verbose=Inf) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters