Skip to content

Commit

Permalink
Merge branch 'main' into Stack-based-refactor-pt.-2
Browse files Browse the repository at this point in the history
  • Loading branch information
nossleinad authored Aug 5, 2024
2 parents 894271f + 540bdf2 commit bd0961d
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 6 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MolecularEvolution"
uuid = "9f975960-e239-4209-8aa0-3d3ad5a82892"
authors = ["Ben Murrell <[email protected]> and contributors"]
version = "0.1.0"
version = "0.2.1"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -10,7 +10,8 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
julia = "1.6"
Distributions = "0.25"
LinearAlgebra = "1"
Requires = "1.3"
StatsBase = "0.34"
julia = "1.6"
2 changes: 1 addition & 1 deletion src/core/nodes/AbstractTreeNode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ function ladderize!(tree::T) where {T<:AbstractTreeNode}
end
end

# Creates a dictionary of all the child counts (including the node itself) which can then be used by ladderize to sort the nodes
# Creates a dictionary of all the child counts which can then be used by ladderize to sort the nodes
function countchildren(tree::T) where {T<:AbstractTreeNode}
# Initialize the dictionary to store the number of children for each node
children_count = Dict{T, Int}()
Expand Down
1 change: 1 addition & 0 deletions src/models/discrete_models/discrete_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ include("codon_models.jl")
include("GeneralCTMC.jl")
include("DiagonalizedCTMC.jl")
include("PiQ.jl")
include("interpolated_discrete_model.jl")
include("utils/utils.jl")
142 changes: 142 additions & 0 deletions src/models/discrete_models/interpolated_discrete_model.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#InterpolatedDiscreteModel works by storing a number of P matrices, and the "t" values to which they correspond
#For a requested t, the returned P matrix is interpolated between it's two neighbours

function check_eq_P(P)
return maximum(std(P,dims = 1))
end

mutable struct InterpolatedDiscreteModel <: DiscreteStateModel
tvec::Vector{Float64}
Pvec::Array{Float64,3} #Now a tensor. Pvec[:,:,i] is the ith P matrix.
#Generator must be something that returns a P matrix with T as the only argument
function InterpolatedDiscreteModel(siz::Int64, generator, tvec::Vector{Float64})
@assert tvec[1] == 0.0
@assert issorted(tvec)
Pvec = zeros(siz,siz,length(tvec))
for (i,t) in enumerate(tvec)
Pvec[:,:,i] .= generator(t)
end
cp = check_eq_P(Pvec[:,:,end])
if cp > 10^-9
@warn "Max std dev of last P matrix is $(cp). Far from equilibrium - extend range"
end
new(tvec,Pvec)
end
#Alternative constructor where you directly specify the tensor of P matrices for each "t"
function InterpolatedDiscreteModel(Pvec::Array{Float64,3}, tvec::Vector{Float64})
@assert tvec[1] == 0.0
@assert issorted(tvec)
cp = check_eq_P(Pvec[:,:,end])
if cp > 10^-9
@warn "Max std dev of last P matrix is $(cp). Far from equilibrium - extend range"
end
new(tvec,Pvec)
end
end


#The keys in d must be the boundaries of the ranges for which we would index into
function range_index(v::Float64,ts::Vector{Float64})
p = searchsortedlast(ts,v) # index of the last key less than or equal to v
return p,p+1
end

function interp_weight(tup,p)
@assert tup[1] <= p && p <= tup[2]
w = (p - tup[1]) ./ (tup[2] - tup[1])
return 1-w,w
end

#We could maybe hang a destintion "matrix" on the model, and this would store the
#interpolated P to that, saving new allocations.
function matrix_interpolate(t_query,ts,Ps)
inds = range_index(t_query,ts)
if inds[2] > length(ts)
return Ps[:,:,end]
end
w = interp_weight((ts[inds[1]],ts[inds[2]]),t_query)
approxP = w[1].*Ps[:,:,inds[1]] .+ w[2].*Ps[:,:,inds[2]]
return approxP
end

function backward!(
dest::DiscretePartition,
source::DiscretePartition,
model::InterpolatedDiscreteModel,
node::FelNode)
P = matrix_interpolate(node.branchlength, model.tvec, model.Pvec)
mul!(dest.state, P, source.state)
dest.scaling .= source.scaling
end

function forward!(
dest::DiscretePartition,
source::DiscretePartition,
model::InterpolatedDiscreteModel,
node::FelNode)
P = matrix_interpolate(node.branchlength, model.tvec, model.Pvec)
dest.state .= (source.state'*P)'
dest.scaling .= source.scaling
end

function eq_freq(model::InterpolatedDiscreteModel)
model.Pvec[1,:,end]
end

#step: Higher numbers mean smaller jumps
#cap: After this many points it starts doubling
function t_sequence(t::Float64,n::Int64; step = 2 ,cap = n - 10)
ts = zeros(n)
ts[1] = 0.0 #Note setting the first one to the zero
ts[2] = t
c = 2
for i in 3:n
ts[i] = ts[i-1]+ts[c]
if mod(i,step)==0
c += 1
elseif i > cap
c = i
end
end
return ts
end

function matrix_sequence(Q::Array{Float64,2},t::Float64,n::Int64; step = 2 ,cap = n - 10)
P = exp(Q .* t)
Ps = zeros(size(P)[1],size(P)[2],n) #Big stack of matrices
Ps[:,:,1] .= Diagonal(ones(size(P)[1])) #Note setting the first one to the identity
Ps[:,:,2] .= P
c = 2
for i in 3:n
Ps[:,:,i] .= Ps[:,:,i-1]*Ps[:,:,c]
if mod(i,step)==0
c += 1
elseif i > cap
c = i
end
end
return Ps
end

#This will take an existing InterpolatedDiscreteModel, and effectively scale all the "t" values in e^Qt.
function rescale!(m::InterpolatedDiscreteModel, factor::Float64)
m.tvec .= m.tvec ./ factor
end


#This is literally just a single P matrix. Maybe some uses, but likely for testing speed bounds
mutable struct PModel <: DiscreteStateModel
P::Array{Float64,2}
end
function backward!(
dest::DiscretePartition, source::DiscretePartition,
model::PModel, node::FelNode)
mul!(dest.state, model.P, source.state)
dest.scaling .= source.scaling
end
function forward!(
dest::DiscretePartition, source::DiscretePartition,
model::PModel, node::FelNode)
dest.state .= (source.state'*model.P)'
dest.scaling .= source.scaling
end
10 changes: 7 additions & 3 deletions src/utils/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ function populate_tree!(
data;
init_all_messages = true,
tolerate_missing = 1, #0 = error if missing; 1 = warn and set to missing data; 2 = set to missing data
leaf_name_transform = x -> x
)
if init_all_messages
internal_message_init!(tree, starting_message)
Expand All @@ -90,8 +91,8 @@ function populate_tree!(
end
name_dic = Dict(zip(names, 1:length(names)))
for n in getleaflist(tree)
if haskey(name_dic, n.name)
populate_message!(n.message, data[name_dic[n.name]])
if haskey(name_dic, leaf_name_transform(n.name))
populate_message!(n.message, data[name_dic[leaf_name_transform(n.name)]])
else
warn_str = n.name * " on tree but not found in names."
if tolerate_missing == 0
Expand All @@ -107,7 +108,7 @@ end


"""
populate_tree!(tree::FelNode, starting_message, names, data; init_all_messages = true, tolerate_missing = 1)
populate_tree!(tree::FelNode, starting_message, names, data; init_all_messages = true, tolerate_missing = 1, leaf_name_transform = x -> x)
Takes a tree, and a `starting_message` (which will serve as the memory template for populating messages all over the tree).
`starting_message` can be a message (ie. a vector of Partitions), but will also work with a single Partition (although the tree)
Expand All @@ -117,6 +118,7 @@ When a leaf on the tree has a name that doesn't match anything in `names`, then
- `tolerate_missing = 0`, an error will be thrown
- `tolerate_missing = 1`, a warning will be thrown, and the message will be set to the uninformative message (requires identity!(::Partition) to be defined)
- `tolerate_missing = 2`, the message will be set to the uninformative message, without warnings (requires identity!(::Partition) to be defined)
A renaming function that can eg. strip tags from the tree when matching leaf names with `names` can be passed to `leaf_name_transform`
"""
function populate_tree!(
tree::FelNode,
Expand All @@ -125,6 +127,7 @@ function populate_tree!(
data;
init_all_messages = true,
tolerate_missing = 1,
leaf_name_transform = x -> x
)
populate_tree!(
tree,
Expand All @@ -133,6 +136,7 @@ function populate_tree!(
data,
init_all_messages = init_all_messages,
tolerate_missing = tolerate_missing,
leaf_name_transform = leaf_name_transform
)
end

Expand Down

0 comments on commit bd0961d

Please sign in to comment.