From 6d31c6b85eb2a9c663cfd5fe6c0026f85b2cf369 Mon Sep 17 00:00:00 2001 From: Maximilian Danielsson Date: Wed, 21 Aug 2024 12:31:41 +0200 Subject: [PATCH] Go with full traversal --- src/core/algorithms/nni_optim.jl | 104 +------------------------------ 1 file changed, 1 insertion(+), 103 deletions(-) diff --git a/src/core/algorithms/nni_optim.jl b/src/core/algorithms/nni_optim.jl index 6e8cc84..d75d814 100644 --- a/src/core/algorithms/nni_optim.jl +++ b/src/core/algorithms/nni_optim.jl @@ -1,13 +1,3 @@ -#= -About clades getting skipped: -- the iterative implementation perfectly mimics the recursive one (they can both skip clades) -- some nnis can lead to some clades not getting optimized and some getting optimized multiple times -- I could push "every other" during first down and use lastind to know if a clade's been visisted, if a sibling clade's not been visited, I'll simply not fel-up yet but continue down -- -- Sanity checks: compare switch_LL with log_likelihood! of deepcopied tree with said switch -full_traversal passed the sanity check -=# - #After a do_nni, we have to update parent_message if we want to continue down (assume that temp_message is the forwarded parent.parent_message) function update_parent_message!( node::FelNode, @@ -29,7 +19,7 @@ function update_parent_message!( end end -function nni_optim_full_traversal!( +function nni_optim!( temp_messages::Vector{Vector{T}}, tree::FelNode, models, @@ -147,97 +137,6 @@ function nni_optim_full_traversal!( end end -function nni_optim!( - temp_messages::Vector{Vector{T}}, - tree::FelNode, - models, - partition_list; - acc_rule = (x, y) -> x > y, - traversal = Iterators.reverse -) where {T <: Partition} - - #Consider a NamedTuple/struct - stack = [(pop!(temp_messages), tree, 1, 1, true, true)] - while !isempty(stack) - temp_message, node, ind, lastind, first, down = pop!(stack) - #We start out with a regular downward pass... - #(except for some extra bookkeeping to track if node is visited for the first time) - #------------------- - if isleafnode(node) - push!(temp_messages, temp_message) - continue - end - if down - if first - model_list = models(node) - for part in partition_list - forward!( - temp_message[part], - node.parent_message[part], - model_list[part], - node, - ) - end - @assert length(node.children) <= 2 - #Temp must be constant between iterations for a node during down... - child_iter = traversal(1:length(node.children)) - lastind = Base.first(child_iter) #(which is why we track the last child to be visited during down) - push!(stack, (Vector{T}(), node, ind, lastind, false, false)) #... but not up - for i = child_iter #Iterative reverse <=> Recursive non-reverse, also optimal for lazysort!?? - push!(stack, (temp_message, node, i, lastind, false, true)) - end - end - if !first - sib_inds = sibling_inds(node.children[ind]) - for part in partition_list - combine!( - (node.children[ind]).parent_message[part], - [mess[part] for mess in node.child_messages[sib_inds]], - true, - ) - combine!( - (node.children[ind]).parent_message[part], - [temp_message[part]], - false, - ) - end - #But calling nni_optim! recursively... (the iterative equivalent) - push!(stack, (safepop!(temp_messages, temp_message), node.children[ind], ind, lastind, true, true)) #first + down combination => safepop! - ind == lastind && push!(temp_messages, temp_message) #We no longer need constant temp - end - end - if !down - #Then combine node.child_messages into node.message... - for part in partition_list - combine!(node.message[part], [mess[part] for mess in node.child_messages], true) - end - #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. - #------------------- - if !isroot(node) - temp_message = pop!(temp_messages) - model_list = models(node) - nnid, exceed_sib, exceed_child = do_nni( - node, - temp_message, - models; - partition_list = partition_list, - acc_rule = acc_rule, - ) - for part in partition_list - combine!(node.message[part], [mess[part] for mess in node.child_messages], true) - backward!(node.parent.child_messages[ind][part], node.message[part], model_list[part], node) - combine!( - node.parent.message[part], - [mess[part] for mess in node.parent.child_messages], - true, - ) - end - push!(temp_messages, temp_message) - end - end - end -end - #Unsure if this is the best choice to handle the model,models, and model_func stuff. function nni_optim!( temp_messages::Vector{Vector{T}}, @@ -415,7 +314,6 @@ function nni_optim!( partition_list = 1:length(tree.message) end - #Need to decide here between nni_optim and nni_optim_full_traversal nni_optim!( temp_messages, tree,