Skip to content

Commit

Permalink
Go with full traversal
Browse files Browse the repository at this point in the history
  • Loading branch information
nossleinad committed Aug 21, 2024
1 parent bd0961d commit 6d31c6b
Showing 1 changed file with 1 addition and 103 deletions.
104 changes: 1 addition & 103 deletions src/core/algorithms/nni_optim.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
#=
About clades getting skipped:
- the iterative implementation perfectly mimics the recursive one (they can both skip clades)
- some nnis can lead to some clades not getting optimized and some getting optimized multiple times
- I could push "every other" during first down and use lastind to know if a clade's been visisted, if a sibling clade's not been visited, I'll simply not fel-up yet but continue down
-
- Sanity checks: compare switch_LL with log_likelihood! of deepcopied tree with said switch
full_traversal passed the sanity check
=#

#After a do_nni, we have to update parent_message if we want to continue down (assume that temp_message is the forwarded parent.parent_message)
function update_parent_message!(
node::FelNode,
Expand All @@ -29,7 +19,7 @@ function update_parent_message!(
end
end

function nni_optim_full_traversal!(
function nni_optim!(
temp_messages::Vector{Vector{T}},
tree::FelNode,
models,
Expand Down Expand Up @@ -147,97 +137,6 @@ function nni_optim_full_traversal!(
end
end

function nni_optim!(
temp_messages::Vector{Vector{T}},
tree::FelNode,
models,
partition_list;
acc_rule = (x, y) -> x > y,
traversal = Iterators.reverse
) where {T <: Partition}

#Consider a NamedTuple/struct
stack = [(pop!(temp_messages), tree, 1, 1, true, true)]
while !isempty(stack)
temp_message, node, ind, lastind, first, down = pop!(stack)
#We start out with a regular downward pass...
#(except for some extra bookkeeping to track if node is visited for the first time)
#-------------------
if isleafnode(node)
push!(temp_messages, temp_message)
continue
end
if down
if first
model_list = models(node)
for part in partition_list
forward!(
temp_message[part],
node.parent_message[part],
model_list[part],
node,
)
end
@assert length(node.children) <= 2
#Temp must be constant between iterations for a node during down...
child_iter = traversal(1:length(node.children))
lastind = Base.first(child_iter) #(which is why we track the last child to be visited during down)
push!(stack, (Vector{T}(), node, ind, lastind, false, false)) #... but not up
for i = child_iter #Iterative reverse <=> Recursive non-reverse, also optimal for lazysort!??
push!(stack, (temp_message, node, i, lastind, false, true))
end
end
if !first
sib_inds = sibling_inds(node.children[ind])
for part in partition_list
combine!(
(node.children[ind]).parent_message[part],
[mess[part] for mess in node.child_messages[sib_inds]],
true,
)
combine!(
(node.children[ind]).parent_message[part],
[temp_message[part]],
false,
)
end
#But calling nni_optim! recursively... (the iterative equivalent)
push!(stack, (safepop!(temp_messages, temp_message), node.children[ind], ind, lastind, true, true)) #first + down combination => safepop!
ind == lastind && push!(temp_messages, temp_message) #We no longer need constant temp
end
end
if !down
#Then combine node.child_messages into node.message...
for part in partition_list
combine!(node.message[part], [mess[part] for mess in node.child_messages], true)
end
#But now we need to optimize the current node, and then prop back up to set your parents children message correctly.
#-------------------
if !isroot(node)
temp_message = pop!(temp_messages)
model_list = models(node)
nnid, exceed_sib, exceed_child = do_nni(
node,
temp_message,
models;
partition_list = partition_list,
acc_rule = acc_rule,
)
for part in partition_list
combine!(node.message[part], [mess[part] for mess in node.child_messages], true)
backward!(node.parent.child_messages[ind][part], node.message[part], model_list[part], node)
combine!(
node.parent.message[part],
[mess[part] for mess in node.parent.child_messages],
true,
)
end
push!(temp_messages, temp_message)
end
end
end
end

#Unsure if this is the best choice to handle the model,models, and model_func stuff.
function nni_optim!(
temp_messages::Vector{Vector{T}},
Expand Down Expand Up @@ -415,7 +314,6 @@ function nni_optim!(
partition_list = 1:length(tree.message)
end

#Need to decide here between nni_optim and nni_optim_full_traversal
nni_optim!(
temp_messages,
tree,
Expand Down

0 comments on commit 6d31c6b

Please sign in to comment.