Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DirectCR-RSSA #346

Open
wants to merge 43 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
7861bc7
Write RSSACR-Direct but it's incorrect. I will turn it into CR-RSSA.
Vilin97 Sep 10, 2023
e36f466
Fix the main part of the SSA code. Time to clean up.
Vilin97 Sep 11, 2023
5ce7d25
Rename to `DirectCRRSSA`.
Vilin97 Sep 12, 2023
2a8bfa7
Fix a docstring.
Vilin97 Sep 12, 2023
0d96e80
Shorten a function.
Vilin97 Sep 12, 2023
07873e8
Uncomment tests in `ABC.jl`
Vilin97 Sep 12, 2023
7fce14e
Delete comment.
Vilin97 Sep 12, 2023
74346b2
Add DirectCRRSSA to the diffusion test.
Vilin97 Sep 12, 2023
c5f93dc
Add `DirectCRRSSA` to `ABC.jl`.
Vilin97 Sep 12, 2023
46f0256
Merge branch 'master'
Vilin97 Sep 12, 2023
fd3ea1e
Shorten `getindex`.
Vilin97 Sep 12, 2023
107068f
Remove the low bound on site rates, as it is not used.
Vilin97 Sep 12, 2023
e1ed533
Add an `@inbounds`.
Vilin97 Sep 12, 2023
3290d30
Add `AbstractMatrix` back in.
Vilin97 Sep 12, 2023
af28404
Remove another change to shorten the PR.
Vilin97 Sep 12, 2023
f5bdc99
Shorten a function.
Vilin97 Sep 12, 2023
3757606
Swap order of functions.
Vilin97 Sep 12, 2023
22e8ece
Remove typos from `ABC.jl`.
Vilin97 Sep 12, 2023
289a9da
Fix test.
Vilin97 Sep 12, 2023
2c0d324
Merge branch 'master' of https://github.com/SciML/JumpProcesses.jl in…
Vilin97 Sep 14, 2023
0be2fd0
Write RSSACR-Direct but it's incorrect. I will turn it into CR-RSSA.
Vilin97 Sep 10, 2023
675c140
Fix the main part of the SSA code. Time to clean up.
Vilin97 Sep 11, 2023
551af91
Rename to `DirectCRRSSA`.
Vilin97 Sep 12, 2023
5d02237
Fix a docstring.
Vilin97 Sep 12, 2023
1a24b08
Shorten a function.
Vilin97 Sep 12, 2023
4701171
Uncomment tests in `ABC.jl`
Vilin97 Sep 12, 2023
4a7e562
Delete comment.
Vilin97 Sep 12, 2023
d3bda54
Add DirectCRRSSA to the diffusion test.
Vilin97 Sep 12, 2023
d146280
Add `DirectCRRSSA` to `ABC.jl`.
Vilin97 Sep 12, 2023
ab6a9f6
Shorten `getindex`.
Vilin97 Sep 12, 2023
97a7974
Remove the low bound on site rates, as it is not used.
Vilin97 Sep 12, 2023
2b3a7a7
Add an `@inbounds`.
Vilin97 Sep 12, 2023
7c7a039
Add `AbstractMatrix` back in.
Vilin97 Sep 12, 2023
92880db
Remove another change to shorten the PR.
Vilin97 Sep 12, 2023
681be86
Shorten a function.
Vilin97 Sep 12, 2023
6ed6e1d
Swap order of functions.
Vilin97 Sep 12, 2023
0ccc09a
Remove typos from `ABC.jl`.
Vilin97 Sep 12, 2023
355778d
Fix test.
Vilin97 Sep 12, 2023
558e3e0
Merge branch 'RSSACRDirect' of https://github.com/SciML/JumpProcesses…
Vilin97 Aug 21, 2024
9148dc0
Merge branch 'master' into RSSACRDirect
Vilin97 Aug 21, 2024
1039d2e
Merge branches 'RSSACRDirect' and 'RSSACRDirect' of https://github.co…
Vilin97 Aug 21, 2024
2568790
Merge branch 'master' into RSSACRDirect
Vilin97 Aug 21, 2024
88d22b7
Address comments.
Vilin97 Aug 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ include("spatial/bracketing.jl")

include("spatial/nsm.jl")
include("spatial/directcrdirect.jl")
include("spatial/directcrrssa.jl")

include("aggregators/aggregated_api.jl")

Expand Down Expand Up @@ -101,6 +102,6 @@ export ExtendedJumpArray
export CartesianGrid, CartesianGridRej
export SpatialMassActionJump
export outdegree, num_sites, neighbors
export NSM, DirectCRDirect
export NSM, DirectCRDirect, DirectCRRSSA

end # module
3 changes: 3 additions & 0 deletions src/aggregators/aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108
"""
struct DirectCRDirect <: AbstractAggregatorAlgorithm end

struct DirectCRRSSA <: AbstractAggregatorAlgorithm end

const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(),
FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve())

Expand Down Expand Up @@ -187,3 +189,4 @@ supports_variablerates(aggregator::Coevolve) = true
is_spatial(aggregator::AbstractAggregatorAlgorithm) = false
is_spatial(aggregator::NSM) = true
is_spatial(aggregator::DirectCRDirect) = true
is_spatial(aggregator::DirectCRRSSA) = true
49 changes: 32 additions & 17 deletions src/spatial/bracketing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@ struct LowHigh{T}
low::T
high::T

LowHigh(low::T, high::T) where {T} = new{T}(deepcopy(low), deepcopy(high))
LowHigh(pair::Tuple{T,T}) where {T} = new{T}(pair[1], pair[2])
LowHigh(low_and_high::T) where {T} = new{T}(low_and_high, deepcopy(low_and_high))
function LowHigh(low::T, high::T; do_copy = true) where {T}
if do_copy
return new{T}(deepcopy(low), deepcopy(high))
else
return new{T}(low, high)
end
end
LowHigh(pair::Tuple{T,T}; kwargs...) where {T} = LowHigh(pair[1], pair[2]; kwargs...)
LowHigh(low_and_high::T; kwargs...) where {T} = LowHigh(low_and_high, low_and_high; kwargs...)
Comment on lines +8 to +16
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does switching to using a branch have performance implications? We create these structures a lot right in the scalar case right? I'm not sure if Julia will remove it during compilation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if it's not optimized away, it shouldn't be expensive, right? I am not good at reading low level code but it seems that it does get optimized away in the end.

# Old struct without branching
struct LowHighOld{T}
    low::T
    high::T

    function LowHighOld(low::T, high::T) where {T}
        new{T}(deepcopy(low), deepcopy(high))
    end
end

# New struct with branching
struct LowHighNew{T}
    low::T
    high::T

    function LowHighNew(low::T, high::T; do_copy=true) where {T}
        if do_copy
            return new{T}(deepcopy(low), deepcopy(high))
        else
            return new{T}(low, high)
        end
    end
end

# Test values
low_value = 10
high_value = 20

# Code introspection to see if the branch is removed in scalar cases
@code_llvm debuginfo=:none LowHighOld(low_value, high_value)

@code_llvm debuginfo=:none LowHighNew(low_value, high_value; do_copy=true)

The output of the code introspection is as follows:

define void @julia_LowHighOld_2364([2 x i64]* noalias nocapture noundef nonnull sret([2 x i64]) 
align 8 dereferenceable(16) %0, {}* noundef nonnull readonly %1, i64 signext %2, i64 signext %3) #0 {
top:
  %newstruct.sroa.0.0..sroa_idx = getelementptr inbounds [2 x i64], [2 x i64]* %0, i64 0, i64 0 
  store i64 %2, i64* %newstruct.sroa.0.0..sroa_idx, align 8
  %newstruct.sroa.2.0..sroa_idx1 = getelementptr inbounds [2 x i64], [2 x i64]* %0, i64 0, i64 1  store i64 %3, i64* %newstruct.sroa.2.0..sroa_idx1, align 8
  ret void
}

define void @julia_LowHighNew_2366([2 x i64]* noalias nocapture noundef nonnull sret([2 x i64]) 
align 8 dereferenceable(16) %0, [1 x i8]* nocapture noundef nonnull readonly align 1 dereferenceable(1) %1, {}* noundef nonnull readonly %2, i64 signext %3, i64 signext %4) #0 {
top:
  %.sroa.025.0..sroa_idx = getelementptr inbounds [2 x i64], [2 x i64]* %0, i64 0, i64 0        
  store i64 %3, i64* %.sroa.025.0..sroa_idx, align 8
  %.sroa.2.0..sroa_idx26 = getelementptr inbounds [2 x i64], [2 x i64]* %0, i64 0, i64 1
  store i64 %4, i64* %.sroa.2.0..sroa_idx26, align 8
  ret void
}

end

function Base.show(io::IO, ::MIME"text/plain", low_high::LowHigh)
Expand All @@ -16,28 +22,32 @@ function Base.show(io::IO, ::MIME"text/plain", low_high::LowHigh)
end

@inline function update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix)
@inbounds for (i, uval) in enumerate(u)
u_low_high[i] = LowHigh(get_spec_brackets(bracket_data, i, uval))
num_species, num_sites = size(u)
update_u_brackets!(u_low_high, bracket_data, u, 1:num_species, 1:num_sites)
end

@inline function update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix, species_vec, sites)
@inbounds for site in sites
for species in species_vec
u_low_high[species, site] = LowHigh(get_spec_brackets(bracket_data, species, u[species, site]))
end
end
nothing
end

### convenience functions for LowHigh ###
function setindex!(low_high::LowHigh, val::LowHigh, i)
low_high.low[i] = val.low
low_high.high[i] = val.high
val
function is_inside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where {M}
return u_low_high.low[species, site] < u[species, site] < u_low_high.high[species, site]
end

function getindex(low_high::LowHigh, i)
return LowHigh(low_high.low[i], low_high.high[i])
### convenience functions for LowHigh ###
function setindex!(low_high::LowHigh{A}, val::LowHigh, i...) where {A <: AbstractArray}
low_high.low[i...] = val.low
low_high.high[i...] = val.high
val
end
getindex(low_high::LowHigh{A}, i) where {A <: AbstractArray} = LowHigh(low_high.low[i], low_high.high[i])

function total_site_rate(rx_rates::LowHigh, hop_rates::LowHigh, site)
return LowHigh(
total_site_rate(rx_rates.low, hop_rates.low, site),
total_site_rate(rx_rates.high, hop_rates.high, site))
end
get_majumps(rx_rates::LowHigh{R}) where {R <: RxRates} = get_majumps(rx_rates.low)

function update_rx_rates!(rx_rates::LowHigh, rxs, u_low_high, integrator, site)
update_rx_rates!(rx_rates.low, rxs, u_low_high.low, integrator, site)
Expand All @@ -48,3 +58,8 @@ function update_hop_rates!(hop_rates::LowHigh, species, u_low_high, site, spatia
update_hop_rates!(hop_rates.low, species, u_low_high.low, site, spatial_system)
update_hop_rates!(hop_rates.high, species, u_low_high.high, site, spatial_system)
end

function reset!(low_high::LowHigh)
reset!(low_high.low)
reset!(low_high.high)
end
5 changes: 2 additions & 3 deletions src/spatial/directcrdirect.jl
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was a very hard-to-find bug in the ordering of arguments in generate_jumps!. I fixed it here and in NSM.jl.

Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
const MINJUMPRATE = 2.0^exponent(1e-12)

#NOTE state vector u is a matrix. u[i,j] is species i, site j
#NOTE hopping_constants is a matrix. hopping_constants[i,j] is species i, site j
mutable struct DirectCRDirectJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR,
VJMAP, JVMAP, SS, U <: PriorityTable,
W <: Function} <:
Expand Down Expand Up @@ -107,12 +106,12 @@ end
function initialize!(p::DirectCRDirectJumpAggregation, integrator, u, params, t)
p.end_time = integrator.sol.prob.tspan[2]
fill_rates_and_get_times!(p, integrator, t)
generate_jumps!(p, integrator, params, u, t)
generate_jumps!(p, integrator, u, params, t)
nothing
end

# calculate the next jump / jump time
function generate_jumps!(p::DirectCRDirectJumpAggregation, integrator, params, u, t)
function generate_jumps!(p::DirectCRDirectJumpAggregation, integrator, u, params, t)
p.next_jump_time = t + randexp(p.rng) / p.rt.gsum
p.next_jump_time >= p.end_time && return nothing
site = sample(p.rt, p.site_rates, p.rng)
Expand Down
265 changes: 265 additions & 0 deletions src/spatial/directcrrssa.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
# site chosen with DirectCR, rx or hop chosen with RSSA

############################ DirectCRRSSA ###################################
const MINJUMPRATE = 2.0^exponent(1e-12)

#NOTE state vector u is a matrix. u[i,j] is species i, site j
mutable struct DirectCRRSSAJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR,
VJMAP, JVMAP, SS, U <: PriorityTable, S, F1, F2} <:
AbstractSSAJumpAggregator{T, S, F1, F2, RNG}
next_jump::SpatialJump{J}
prev_jump::SpatialJump{J}
next_jump_time::T
end_time::T
bracket_data::BD
u_low_high::LowHigh{M} # species bracketing
rx_rates::LowHigh{RX}
hop_rates::LowHigh{HOP}
site_rates_high::Vector{T} # we do not need site_rates_low
save_positions::Tuple{Bool, Bool}
rng::RNG
dep_gr::DEPGR #dep graph is same for each site
vartojumps_map::VJMAP #vartojumps_map is same for each site
jumptovars_map::JVMAP #jumptovars_map is same for each site
spatial_system::SS
numspecies::Int #number of species
rt::U
rates::F1 # legacy, not used
affects!::F2 # legacy, not used
end

function DirectCRRSSAJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD,
u_low_high::LowHigh{M}, rx_rates::LowHigh{RX},
hop_rates::LowHigh{HOP}, site_rates_high::Vector{T},
sps::Tuple{Bool, Bool}, rng::RNG, spatial_system::SS;
num_specs, minrate = convert(T, MINJUMPRATE),
vartojumps_map = nothing, jumptovars_map = nothing,
dep_graph = nothing,
kwargs...) where {J, T, BD, RX, HOP, RNG, SS, M}

# a dependency graph is needed
if dep_graph === nothing
dg = make_dependency_graph(num_specs, get_majumps(rx_rates))
else
dg = dep_graph
# make sure each jump depends on itself
add_self_dependencies!(dg)
end

# a species-to-reactions graph is needed
if vartojumps_map === nothing
vtoj_map = var_to_jumps_map(num_specs, get_majumps(rx_rates))
else
vtoj_map = vartojumps_map
end

if jumptovars_map === nothing
jtov_map = jump_to_vars_map(get_majumps(rx_rates))
else
jtov_map = jumptovars_map
end

# mapping from jump rate to group id
minexponent = exponent(minrate)

# use the largest power of two that is <= the passed in minrate
minrate = 2.0^minexponent
ratetogroup = rate -> priortogid(rate, minexponent)

# construct an empty initial priority table -- we'll reset this in init
rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate)

DirectCRRSSAJumpAggregation{
T,
BD,
M,
RNG,
J,
RX,
HOP,
typeof(dg),
typeof(vtoj_map),
typeof(jtov_map),
SS,
typeof(rt),
Nothing,
Nothing,
Nothing,
}(nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates_high, sps, rng, dg,
vtoj_map, jtov_map, spatial_system, num_specs, rt, nothing, nothing)
end

############################# Required Functions ##############################
# creating the JumpAggregation structure (function wrapper-based constant jumps)
function aggregate(aggregator::DirectCRRSSA, starting_state, p, t, end_time,
constant_jumps, ma_jumps, save_positions, rng; hopping_constants,
spatial_system, bracket_data = nothing, kwargs...)
T = typeof(end_time)
num_species = size(starting_state, 1)
majumps = ma_jumps
if majumps === nothing
majumps = MassActionJump(Vector{T}(),
Vector{Vector{Pair{Int, Int}}}(),
Vector{Vector{Pair{Int, Int}}}())
end

next_jump = SpatialJump{Int}(typemax(Int), typemax(Int), typemax(Int)) #a placeholder
next_jump_time = typemax(T)
rx_rates = LowHigh(RxRates(num_sites(spatial_system), majumps),
RxRates(num_sites(spatial_system), majumps);
do_copy = false) # do not copy ma_jumps
hop_rates = LowHigh(HopRates(hopping_constants, spatial_system),
HopRates(hopping_constants, spatial_system);
do_copy = false) # do not copy hopping_constants
site_rates_high = zeros(T, num_sites(spatial_system))
bd = (bracket_data === nothing) ? BracketData{T, eltype(starting_state)}() :
bracket_data
u_low_high = LowHigh(starting_state)

DirectCRRSSAJumpAggregation(next_jump, next_jump_time, end_time, bd, u_low_high,
rx_rates, hop_rates,
site_rates_high, save_positions, rng, spatial_system;
num_specs = num_species, kwargs...)
end

# set up a new simulation and calculate the first jump / jump time
function initialize!(p::DirectCRRSSAJumpAggregation, integrator, u, params, t)
p.end_time = integrator.sol.prob.tspan[2]
fill_rates_and_get_times!(p, integrator, t)
generate_jumps!(p, integrator, u, params, t)
nothing
end

# calculate the next jump / jump time
function generate_jumps!(p::DirectCRRSSAJumpAggregation, integrator, u, params, t)
@unpack rng, rt, site_rates_high, rx_rates, hop_rates, spatial_system = p
time_delta = zero(t)
while true
site = sample(rt, site_rates_high, rng)
jump = sample_jump_direct(rx_rates.high, hop_rates.high, site, spatial_system, rng)
time_delta += randexp(rng)
if accept_jump(p, u, jump)
p.next_jump_time = t + time_delta / groupsum(rt)
p.next_jump = jump
break
end
Comment on lines +139 to +145
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this handle if there is no next jump?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest, I do not rembmer any of the logic by now. Do we have a test for what you are asking? If not, we should, for all SSAs.

end
nothing
end

# execute one jump, changing the system state
function execute_jumps!(p::DirectCRRSSAJumpAggregation, integrator, u, params, t,
affects!)
update_state!(p, integrator)
update_dependent_rates!(p, integrator, t)
nothing
end

######################## SSA specific helper routines ########################
# Return true if site is accepted.
function accept_jump(p, u, jump)
if is_hop(p, jump)
return accept_hop(p, u, jump)
else
return accept_rx(p, u, jump)
end
end
Vilin97 marked this conversation as resolved.
Show resolved Hide resolved

function accept_hop(p, u, jump)
@unpack hop_rates, spatial_system, rng = p
species, site = jump.jidx, jump.src
acceptance_threshold = rand(rng) * hop_rate(hop_rates.high, species, site)
if hop_rate(hop_rates.low, species, site) > acceptance_threshold
return true
else
# compute the real rate. Could have used hop_rates.high as well.
isaacsas marked this conversation as resolved.
Show resolved Hide resolved
real_rate = evalhoprate(hop_rates.low, u, species, site, spatial_system)
return real_rate > acceptance_threshold
end
end

function accept_rx(p, u, jump)
@unpack rx_rates, rng = p
rx, site = reaction_id_from_jump(p, jump), jump.src
acceptance_threshold = rand(rng) * rx_rate(rx_rates.high, rx, site)
if rx_rate(rx_rates.low, rx, site) > acceptance_threshold
return true
else
# compute the real rate. Could have used rx_rates.high as well.
isaacsas marked this conversation as resolved.
Show resolved Hide resolved
real_rate = evalrxrate(rx_rates.low, u, rx, site)
return real_rate > acceptance_threshold
end
end

"""
fill_rates_and_get_times!(aggregation::DirectCRRSSAJumpAggregation, u, t)

reset all stucts, reevaluate all rates, repopulate the priority table
"""
function fill_rates_and_get_times!(aggregation::DirectCRRSSAJumpAggregation, integrator, t)
@unpack bracket_data, u_low_high, spatial_system, rx_rates, hop_rates, site_rates_high, rt = aggregation
u = integrator.u
update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix)

reset!(rx_rates)
reset!(hop_rates)
fill!(site_rates_high, zero(eltype(site_rates_high)))

rxs = 1:num_rxs(rx_rates.low)
species = 1:(aggregation.numspecies)

for site in 1:num_sites(spatial_system)
update_rx_rates!(rx_rates, rxs, u_low_high, integrator, site)
update_hop_rates!(hop_rates, species, u_low_high, site, spatial_system)
site_rates_high[site] = total_site_rate(rx_rates.high, hop_rates.high, site)
end

# setup PriorityTable
reset!(rt)
for (pid, priority) in enumerate(site_rates_high)
insert!(rt, pid, priority)
end
nothing
end

"""
update_dependent_rates!(p, integrator, t)

recalculate jump rates for jumps that depend on the just executed jump (p.prev_jump)
"""
function update_dependent_rates!(p::DirectCRRSSAJumpAggregation, integrator, t)
jump = p.prev_jump
if is_hop(p, jump)
update_brackets!(p, integrator, jump.jidx, (jump.src, jump.dst))
else
update_brackets!(p, integrator, p.jumptovars_map[reaction_id_from_jump(p, jump)], jump.src)
end
end

function update_brackets!(p, integrator, species_to_update, sites_to_update)
@unpack rx_rates, hop_rates, site_rates_high, u_low_high, bracket_data, vartojumps_map, spatial_system = p
u = integrator.u
for site in sites_to_update, species in species_to_update
Vilin97 marked this conversation as resolved.
Show resolved Hide resolved
if !is_inside_brackets(u_low_high, u, species, site)
update_u_brackets!(u_low_high, bracket_data, u, species, site)
update_rx_rates!(rx_rates,
vartojumps_map[species],
u_low_high,
integrator,
site)
update_hop_rates!(hop_rates, species, u_low_high, site, spatial_system)

oldrate = site_rates_high[site]
site_rates_high[site] = total_site_rate(rx_rates.high, hop_rates.high, site)
update!(p.rt, site, oldrate, site_rates_high[site])
end
end
nothing
end

"""
num_constant_rate_jumps(aggregator::DirectCRRSSAJumpAggregation)

number of constant rate jumps
"""
num_constant_rate_jumps(aggregator::DirectCRRSSAJumpAggregation) = 0
Loading