From 03b6aff24a0815dc941b23fe691c6f1c51bb405c Mon Sep 17 00:00:00 2001 From: nossleinad Date: Sun, 27 Oct 2024 19:08:48 +0100 Subject: [PATCH 1/4] Add felsenstein_roundtrip and root_optim --- src/MolecularEvolution.jl | 2 + src/core/algorithms/algorithms.jl | 1 + src/core/algorithms/felsenstein.jl | 55 ++++++++ src/core/algorithms/root_optim.jl | 133 ++++++++++++++++++ src/models/compound_models/alwaysup.jl | 26 ++++ src/models/compound_models/compound_models.jl | 3 +- test/partition_selection.jl | 12 ++ 7 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 src/core/algorithms/root_optim.jl create mode 100644 src/models/compound_models/alwaysup.jl diff --git a/src/MolecularEvolution.jl b/src/MolecularEvolution.jl index f9648db..4aeb35e 100644 --- a/src/MolecularEvolution.jl +++ b/src/MolecularEvolution.jl @@ -102,6 +102,7 @@ export combine!, felsenstein!, felsenstein_down!, + felsenstein_roundtrip!, sample_down!, #endpoint_conditioned_sample_down!, log_likelihood!, @@ -120,6 +121,7 @@ export reroot!, nni_optim!, branchlength_optim!, + root_optim!, metropolis_sample, copy_tree, diff --git a/src/core/algorithms/algorithms.jl b/src/core/algorithms/algorithms.jl index 9a634a8..6ef9e81 100644 --- a/src/core/algorithms/algorithms.jl +++ b/src/core/algorithms/algorithms.jl @@ -2,6 +2,7 @@ include("felsenstein.jl") include("branchlength_optim.jl") include("lls.jl") include("nni_optim.jl") +include("root_optim.jl") include("ancestors.jl") include("generative.jl") diff --git a/src/core/algorithms/felsenstein.jl b/src/core/algorithms/felsenstein.jl index 9fe72b3..d030d43 100644 --- a/src/core/algorithms/felsenstein.jl +++ b/src/core/algorithms/felsenstein.jl @@ -164,3 +164,58 @@ function felsenstein_down!( temp_message = temp_message, ) end + +""" + felsenstein_roundtrip!(tree::FelNode, models; partition_list = 1:length(tree.message), temp_message = copy_message(tree.message[partition_list])) + +Should usually be called on the root of the tree. First propagates Felsenstein pass up from the tips to the root, +then propagates Felsenstein pass down from the root to the tips, with the direction of time reversed (i.e. forward! = backward!). +**This is useful when searching for the optimal root** (see [`root_optim!`](@ref)). +models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or +a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another. +partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over. +""" +function felsenstein_roundtrip!( + tree::FelNode, + models; + partition_list = 1:length(tree.message), + temp_message = copy_message(tree.message[partition_list]), +) + parent_message = tree.parent_message[partition_list] #Store the parent message + tree.parent_message[partition_list] .= temp_message + identity!.(tree.parent_message[partition_list]) #Set the parent message to identity + + always_up_models(n::FelNode) = AlwaysUpModel.(models(n)) + felsenstein!(tree, always_up_models, partition_list = partition_list) + felsenstein_down!(tree, always_up_models, partition_list = partition_list) + + tree.parent_message[partition_list] .= parent_message #Restore the parent message +end + +function felsenstein_roundtrip!( + tree::FelNode, + models::Vector{<:BranchModel}; + partition_list = 1:length(tree.message), + temp_message = copy_message(tree.message[partition_list]), +) + felsenstein_roundtrip!( + tree, + x -> models, + partition_list = partition_list, + temp_message = temp_message, + ) +end + +function felsenstein_roundtrip!( + tree::FelNode, + model::BranchModel; + partition_list = 1:length(tree.message), + temp_message = copy_message(tree.message[partition_list]), +) + felsenstein_roundtrip!( + tree, + x -> [model], + partition_list = partition_list, + temp_message = temp_message, + ) +end \ No newline at end of file diff --git a/src/core/algorithms/root_optim.jl b/src/core/algorithms/root_optim.jl new file mode 100644 index 0000000..870b67e --- /dev/null +++ b/src/core/algorithms/root_optim.jl @@ -0,0 +1,133 @@ +#Assume that felsenstein_roundtrip! has been called +#Compute the log likelihood of observations below this root-candidate +function root_LL_below!( + dest::Vector{<:Partition}, + temp::Vector{<:Partition}, + dist_above_node::Real, + node::FelNode, + model_list::Vector{<:BranchModel}; + partition_list = 1:length(tree.message) +) + @assert 0.0 <= dist_above_node < node.branchlength || dist_above_node == node.branchlength == 0.0 #if dist_above_node == node.branchlength != 0.0, then it's node.parent with 0.0 dist_above_child that should be called + branchlength = node.branchlength + for (p, part) in enumerate(partition_list) + node.branchlength = dist_above_node + backward!(dest[p], node.message[part], model_list[part], node) + node.branchlength = branchlength - dist_above_node + backward!(temp[p], node.parent_message[part], model_list[part], node) + combine!(dest[p], temp[p]) + end + node.branchlength = branchlength +end + +function steal_messages!(new_root::FelNode, old_root::FelNode) + new_root.message = old_root.message + new_root.parent_message = old_root.parent_message + new_root.child_messages = old_root.child_messages +end + +""" + root_optim!(tree::FelNode, models; ) + +Optimizes the root position and root state of a tree. Returns the new, optimal root node. +models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or +a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another. + +# Keyword Arguments +- `partition_list=1:length(tree.message)`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize branch lengths with all models, the default option). +- `starting_message_modifier!=(objective, starting_message::Vector{<:Partition}) -> starting_message`: (can either be) an optimizer (or a sampler) of the root state. +- `starting_message0::Vector{<:Partition}=copy_message(tree.parent_message[partition_list])`: the initial starting message used by `starting_message_modifier!`. +- `K=10`: the number of equidistant root-candidate points along a branch. (only to be used in the frequentist framework!?) +""" +function root_optim!( + tree::FelNode, + models; + partition_list = 1:length(tree.message), + starting_message_modifier! = (objective, starting_message0::Vector{<:Partition}) -> starting_message0, + starting_message0::Vector{<:Partition} = copy_message(tree.parent_message[partition_list]), + K = 10 #Number of root-candidate points on a branch +) + #Initialize some messages + node_message = copy_message(tree.parent_message[partition_list]) + node_starting_message = copy_message(tree.parent_message[partition_list]) + temp_message = copy_message(tree.parent_message[partition_list]) + + #Initialize the fallback optimum + opt_root = tree + opt_dist = 0.0 + opt_LL = log_likelihood(tree, models, partition_list = partition_list) + opt_starting_message = copy_message(tree.parent_message[partition_list]) + + #Do most of the message passing + felsenstein_roundtrip!(tree, models, partition_list = partition_list, temp_message = temp_message) + + #Optimize the root position + root state + nodelist = getnodelist(tree) + for node in nodelist + copy_partition_to!.(node_starting_message, starting_message0) + model_list = models(node) + for dist_above_node in unique(range(0.0, node.branchlength, K + 1)[1:end-1]) + # unique() to avoid recomputations + #Compute the log likelihood of observations below this root-candidate... + root_LL_below!( + node_message, + temp_message, + dist_above_node, + node, + model_list, + partition_list = partition_list + ) + function objective(starting_message::Vector{<:Partition}) + for p = 1:length(partition_list) + #... combine it with a root state... + combine!(temp_message[p], [starting_message[p], node_message[p]], true) + end + #... and get the total log likelihood. + return sum(total_LL.(temp_message)) + end + node_starting_message = starting_message_modifier!(objective, node_starting_message) + #Reuse this as the starting_message0 for the next iteration + LL = objective(node_starting_message) + if LL > opt_LL + opt_root, opt_dist, opt_LL = node, dist_above_node, LL + copy_partition_to!.(opt_starting_message, node_starting_message) + end + end + end + new_root = opt_root == tree ? tree : reroot!(opt_root, dist_above_child = opt_dist) #Maybe reroot! should take care of this? + steal_messages!(new_root, tree) + new_root.parent_message[partition_list] .= opt_starting_message + return new_root +end + +root_optim!( + tree::FelNode, + models::Vector{<:BranchModel}; + partition_list = 1:length(tree.message), + starting_message_modifier! = (objective, starting_message0::Vector{<:Partition}) -> starting_message0, + starting_message0::Vector{<:Partition} = copy_message(tree.parent_message[partition_list]), + K = 10 #Number of root-candidate points on a branch +) = root_optim!( + tree, + x -> models, + partition_list = partition_list, + starting_message_modifier! = starting_message_modifier!, + starting_message0 = starting_message0, + K = K + ) + +root_optim!( + tree::FelNode, + model::BranchModel; + partition_list = 1:length(tree.message), + starting_message_modifier! = (objective, starting_message0::Vector{<:Partition}) -> starting_message0, + starting_message0::Vector{<:Partition} = copy_message(tree.parent_message[partition_list]), + K = 10 #Number of root-candidate points on a branch +) = root_optim!( + tree, + x -> [model], + partition_list = partition_list, + starting_message_modifier! = starting_message_modifier!, + starting_message0 = starting_message0, + K = K + ) \ No newline at end of file diff --git a/src/models/compound_models/alwaysup.jl b/src/models/compound_models/alwaysup.jl new file mode 100644 index 0000000..34f75ff --- /dev/null +++ b/src/models/compound_models/alwaysup.jl @@ -0,0 +1,26 @@ +export AlwaysUpModel +mutable struct AlwaysUpModel{T} <: BranchModel where {T <: BranchModel} + model::T +end + +function backward!( + dest::Partition, + source::Partition, + model::AlwaysUpModel, + node::FelNode, +) + backward!(dest, source, model.model, node) +end + +function forward!( + dest::Partition, + source::Partition, + model::AlwaysUpModel, + node::FelNode, +) + backward!(dest, source, model.model, node) +end + +function eq_freq(model::AlwaysUpModel) + return eq_freq(model.model) +end \ No newline at end of file diff --git a/src/models/compound_models/compound_models.jl b/src/models/compound_models/compound_models.jl index d0c1b0c..c0f096b 100644 --- a/src/models/compound_models/compound_models.jl +++ b/src/models/compound_models/compound_models.jl @@ -1,4 +1,5 @@ include("swm.jl") include("bwm.jl") include("cat.jl") -include("covarion.jl") \ No newline at end of file +include("covarion.jl") +include("alwaysup.jl") \ No newline at end of file diff --git a/test/partition_selection.jl b/test/partition_selection.jl index fb7fbb9..1041d81 100644 --- a/test/partition_selection.jl +++ b/test/partition_selection.jl @@ -72,4 +72,16 @@ begin nni_optim!(tree, bm_models) nni_optim!(tree, x -> bm_models, partition_list = [2]) nni_optim!(tree, x -> bm_models) + + felsenstein_roundtrip!(tree, bm_models, partition_list = [1]) + felsenstein_roundtrip!(tree, bm_models, partition_list = [2]) + felsenstein_roundtrip!(tree, bm_models) + felsenstein_roundtrip!(tree, x -> bm_models, partition_list = [2]) + felsenstein_roundtrip!(tree, x -> bm_models) + + tree = root_optim!(tree, bm_models, partition_list = [1]) + tree = root_optim!(tree, bm_models, partition_list = [2]) + tree = root_optim!(tree, bm_models) + tree = root_optim!(tree, x -> bm_models, partition_list = [2]) + tree = root_optim!(tree, x -> bm_models) end From d8393fc57fa81968861deab3e042b726e466b8f1 Mon Sep 17 00:00:00 2001 From: nossleinad Date: Sun, 27 Oct 2024 20:00:17 +0100 Subject: [PATCH 2/4] Update docstring --- src/core/algorithms/root_optim.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/algorithms/root_optim.jl b/src/core/algorithms/root_optim.jl index 870b67e..fa3e191 100644 --- a/src/core/algorithms/root_optim.jl +++ b/src/core/algorithms/root_optim.jl @@ -34,8 +34,8 @@ models can either be a single model (if the messages on the tree contain just on a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another. # Keyword Arguments -- `partition_list=1:length(tree.message)`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize branch lengths with all models, the default option). -- `starting_message_modifier!=(objective, starting_message::Vector{<:Partition}) -> starting_message`: (can either be) an optimizer (or a sampler) of the root state. +- `partition_list=1:length(tree.message)`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize root position and root state with all models, the default option). +- `starting_message_modifier!=(objective, starting_message0::Vector{<:Partition}) -> starting_message0`: (can either be) an optimizer (or a sampler) of the root state. `objective` returns the log likelihood of a root state (and implicitly, a root position). - `starting_message0::Vector{<:Partition}=copy_message(tree.parent_message[partition_list])`: the initial starting message used by `starting_message_modifier!`. - `K=10`: the number of equidistant root-candidate points along a branch. (only to be used in the frequentist framework!?) """ From 8e2826591c6bb095603538dfff4081aeaabc2d01 Mon Sep 17 00:00:00 2001 From: nossleinad Date: Wed, 30 Oct 2024 22:50:54 +0100 Subject: [PATCH 3/4] Make user-defined optimization more flexible --- src/core/algorithms/root_optim.jl | 47 +++++++++++++------------------ 1 file changed, 19 insertions(+), 28 deletions(-) diff --git a/src/core/algorithms/root_optim.jl b/src/core/algorithms/root_optim.jl index fa3e191..d88631e 100644 --- a/src/core/algorithms/root_optim.jl +++ b/src/core/algorithms/root_optim.jl @@ -26,6 +26,14 @@ function steal_messages!(new_root::FelNode, old_root::FelNode) new_root.child_messages = old_root.child_messages end +function default_root_LL_wrapper(parent_message::Vector{<:Partition}) + function root_LL!(message::Vector{<:Partition}) + combine!.(message, parent_message) + return parent_message, sum(total_LL.(message)) + end + return root_LL! +end + """ root_optim!(tree::FelNode, models; ) @@ -35,36 +43,32 @@ a function that takes a node, and returns a Vector{<:BranchModel} if you need th # Keyword Arguments - `partition_list=1:length(tree.message)`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize root position and root state with all models, the default option). -- `starting_message_modifier!=(objective, starting_message0::Vector{<:Partition}) -> starting_message0`: (can either be) an optimizer (or a sampler) of the root state. `objective` returns the log likelihood of a root state (and implicitly, a root position). -- `starting_message0::Vector{<:Partition}=copy_message(tree.parent_message[partition_list])`: the initial starting message used by `starting_message_modifier!`. +- `root_LL!=default_root_LL_wrapper(tree.parent_message[partition_list])`: a function that takes a message and returns a (optimal) parent message and LL (log likelihood). The default option uses the constant `tree.parent_message[partition_list]` as parent message for all root-candidates. - `K=10`: the number of equidistant root-candidate points along a branch. (only to be used in the frequentist framework!?) """ function root_optim!( tree::FelNode, models; partition_list = 1:length(tree.message), - starting_message_modifier! = (objective, starting_message0::Vector{<:Partition}) -> starting_message0, - starting_message0::Vector{<:Partition} = copy_message(tree.parent_message[partition_list]), + root_LL! = default_root_LL_wrapper(tree.parent_message[partition_list]), K = 10 #Number of root-candidate points on a branch ) #Initialize some messages node_message = copy_message(tree.parent_message[partition_list]) - node_starting_message = copy_message(tree.parent_message[partition_list]) temp_message = copy_message(tree.parent_message[partition_list]) + #Do most of the message passing + felsenstein_roundtrip!(tree, models, partition_list = partition_list, temp_message = temp_message) + #Initialize the fallback optimum opt_root = tree opt_dist = 0.0 opt_LL = log_likelihood(tree, models, partition_list = partition_list) opt_starting_message = copy_message(tree.parent_message[partition_list]) - #Do most of the message passing - felsenstein_roundtrip!(tree, models, partition_list = partition_list, temp_message = temp_message) - #Optimize the root position + root state nodelist = getnodelist(tree) for node in nodelist - copy_partition_to!.(node_starting_message, starting_message0) model_list = models(node) for dist_above_node in unique(range(0.0, node.branchlength, K + 1)[1:end-1]) # unique() to avoid recomputations @@ -77,17 +81,8 @@ function root_optim!( model_list, partition_list = partition_list ) - function objective(starting_message::Vector{<:Partition}) - for p = 1:length(partition_list) - #... combine it with a root state... - combine!(temp_message[p], [starting_message[p], node_message[p]], true) - end - #... and get the total log likelihood. - return sum(total_LL.(temp_message)) - end - node_starting_message = starting_message_modifier!(objective, node_starting_message) - #Reuse this as the starting_message0 for the next iteration - LL = objective(node_starting_message) + node_starting_message, LL = root_LL!(node_message) + #TODO: enable root sampling if LL > opt_LL opt_root, opt_dist, opt_LL = node, dist_above_node, LL copy_partition_to!.(opt_starting_message, node_starting_message) @@ -104,15 +99,13 @@ root_optim!( tree::FelNode, models::Vector{<:BranchModel}; partition_list = 1:length(tree.message), - starting_message_modifier! = (objective, starting_message0::Vector{<:Partition}) -> starting_message0, - starting_message0::Vector{<:Partition} = copy_message(tree.parent_message[partition_list]), + root_LL! = default_root_LL_wrapper(tree.parent_message[partition_list]), K = 10 #Number of root-candidate points on a branch ) = root_optim!( tree, x -> models, partition_list = partition_list, - starting_message_modifier! = starting_message_modifier!, - starting_message0 = starting_message0, + root_LL! = root_LL!, K = K ) @@ -120,14 +113,12 @@ root_optim!( tree::FelNode, model::BranchModel; partition_list = 1:length(tree.message), - starting_message_modifier! = (objective, starting_message0::Vector{<:Partition}) -> starting_message0, - starting_message0::Vector{<:Partition} = copy_message(tree.parent_message[partition_list]), + root_LL! = default_root_LL_wrapper(tree.parent_message[partition_list]), K = 10 #Number of root-candidate points on a branch ) = root_optim!( tree, x -> [model], partition_list = partition_list, - starting_message_modifier! = starting_message_modifier!, - starting_message0 = starting_message0, + root_LL! = root_LL!, K = K ) \ No newline at end of file From 369ad629ec15049c06b715e01c71f45a8b56bb5e Mon Sep 17 00:00:00 2001 From: nossleinad Date: Wed, 30 Oct 2024 23:12:12 +0100 Subject: [PATCH 4/4] Make compat with 1.6 --- src/core/algorithms/root_optim.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/algorithms/root_optim.jl b/src/core/algorithms/root_optim.jl index d88631e..701ddef 100644 --- a/src/core/algorithms/root_optim.jl +++ b/src/core/algorithms/root_optim.jl @@ -70,7 +70,7 @@ function root_optim!( nodelist = getnodelist(tree) for node in nodelist model_list = models(node) - for dist_above_node in unique(range(0.0, node.branchlength, K + 1)[1:end-1]) + for dist_above_node in unique(range(0.0, node.branchlength, length=K+1)[1:end-1]) # unique() to avoid recomputations #Compute the log likelihood of observations below this root-candidate... root_LL_below!(