Skip to content

Commit

Permalink
Make user-defined optimization more flexible
Browse files Browse the repository at this point in the history
  • Loading branch information
nossleinad committed Oct 30, 2024
1 parent d8393fc commit 8e28265
Showing 1 changed file with 19 additions and 28 deletions.
47 changes: 19 additions & 28 deletions src/core/algorithms/root_optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ function steal_messages!(new_root::FelNode, old_root::FelNode)
new_root.child_messages = old_root.child_messages
end

function default_root_LL_wrapper(parent_message::Vector{<:Partition})
function root_LL!(message::Vector{<:Partition})
combine!.(message, parent_message)
return parent_message, sum(total_LL.(message))
end
return root_LL!
end

"""
root_optim!(tree::FelNode, models; <keyword arguments>)
Expand All @@ -35,36 +43,32 @@ a function that takes a node, and returns a Vector{<:BranchModel} if you need th
# Keyword Arguments
- `partition_list=1:length(tree.message)`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize root position and root state with all models, the default option).
- `starting_message_modifier!=(objective, starting_message0::Vector{<:Partition}) -> starting_message0`: (can either be) an optimizer (or a sampler) of the root state. `objective` returns the log likelihood of a root state (and implicitly, a root position).
- `starting_message0::Vector{<:Partition}=copy_message(tree.parent_message[partition_list])`: the initial starting message used by `starting_message_modifier!`.
- `root_LL!=default_root_LL_wrapper(tree.parent_message[partition_list])`: a function that takes a message and returns a (optimal) parent message and LL (log likelihood). The default option uses the constant `tree.parent_message[partition_list]` as parent message for all root-candidates.
- `K=10`: the number of equidistant root-candidate points along a branch. (only to be used in the frequentist framework!?)
"""
function root_optim!(
tree::FelNode,
models;
partition_list = 1:length(tree.message),
starting_message_modifier! = (objective, starting_message0::Vector{<:Partition}) -> starting_message0,
starting_message0::Vector{<:Partition} = copy_message(tree.parent_message[partition_list]),
root_LL! = default_root_LL_wrapper(tree.parent_message[partition_list]),
K = 10 #Number of root-candidate points on a branch
)
#Initialize some messages
node_message = copy_message(tree.parent_message[partition_list])
node_starting_message = copy_message(tree.parent_message[partition_list])
temp_message = copy_message(tree.parent_message[partition_list])

#Do most of the message passing
felsenstein_roundtrip!(tree, models, partition_list = partition_list, temp_message = temp_message)

#Initialize the fallback optimum
opt_root = tree
opt_dist = 0.0
opt_LL = log_likelihood(tree, models, partition_list = partition_list)
opt_starting_message = copy_message(tree.parent_message[partition_list])

#Do most of the message passing
felsenstein_roundtrip!(tree, models, partition_list = partition_list, temp_message = temp_message)

#Optimize the root position + root state
nodelist = getnodelist(tree)
for node in nodelist
copy_partition_to!.(node_starting_message, starting_message0)
model_list = models(node)
for dist_above_node in unique(range(0.0, node.branchlength, K + 1)[1:end-1])
# unique() to avoid recomputations
Expand All @@ -77,17 +81,8 @@ function root_optim!(
model_list,
partition_list = partition_list
)
function objective(starting_message::Vector{<:Partition})
for p = 1:length(partition_list)
#... combine it with a root state...
combine!(temp_message[p], [starting_message[p], node_message[p]], true)
end
#... and get the total log likelihood.
return sum(total_LL.(temp_message))
end
node_starting_message = starting_message_modifier!(objective, node_starting_message)
#Reuse this as the starting_message0 for the next iteration
LL = objective(node_starting_message)
node_starting_message, LL = root_LL!(node_message)
#TODO: enable root sampling
if LL > opt_LL
opt_root, opt_dist, opt_LL = node, dist_above_node, LL
copy_partition_to!.(opt_starting_message, node_starting_message)
Expand All @@ -104,30 +99,26 @@ root_optim!(
tree::FelNode,
models::Vector{<:BranchModel};
partition_list = 1:length(tree.message),
starting_message_modifier! = (objective, starting_message0::Vector{<:Partition}) -> starting_message0,
starting_message0::Vector{<:Partition} = copy_message(tree.parent_message[partition_list]),
root_LL! = default_root_LL_wrapper(tree.parent_message[partition_list]),
K = 10 #Number of root-candidate points on a branch
) = root_optim!(
tree,
x -> models,
partition_list = partition_list,
starting_message_modifier! = starting_message_modifier!,
starting_message0 = starting_message0,
root_LL! = root_LL!,
K = K
)

root_optim!(
tree::FelNode,
model::BranchModel;
partition_list = 1:length(tree.message),
starting_message_modifier! = (objective, starting_message0::Vector{<:Partition}) -> starting_message0,
starting_message0::Vector{<:Partition} = copy_message(tree.parent_message[partition_list]),
root_LL! = default_root_LL_wrapper(tree.parent_message[partition_list]),
K = 10 #Number of root-candidate points on a branch
) = root_optim!(
tree,
x -> [model],
partition_list = partition_list,
starting_message_modifier! = starting_message_modifier!,
starting_message0 = starting_message0,
root_LL! = root_LL!,
K = K
)

0 comments on commit 8e28265

Please sign in to comment.