Skip to content

Commit

Permalink
Updated the test and made some minor adjustments based on the feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Theodor Björk authored and Theodor Björk committed Sep 1, 2024
1 parent fbf74a2 commit 409c5ba
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 21 deletions.
11 changes: 7 additions & 4 deletions src/bayes/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
midpoint_rooting=false,
)
Samples tree topologies from a posterior distribution. felsenstein! should be called on the initial tree before calling this function.
Samples tree topologies from a posterior distribution.
# Arguments
- `initial_tree`: An initial topology with (important!) the leaves populated with data, for the likelihood calculation.
- `initial_tree`: An initial tree topology with the leaves populated with data, for the likelihood calculation.
- `models`: A list of branch models.
- `num_of_samples`: The number of tree samples drawn from the posterior.
- `bl_sampler`: Sampler used to drawn branchlengths from the posterior.
Expand All @@ -22,8 +22,11 @@ Samples tree topologies from a posterior distribution. felsenstein! should be ca
- `collect_LLs`: Specifies if the function should return the log-likelihoods of the trees.
- `midpoint_rooting`: Specifies whether the drawn samples should be midpoint rerooted (Important! Should only be used for time-reversible branch models starting in equilibrium).
!!! note
The leaves of the initial tree should be populated with data and felsenstein! should be called on the initial tree before calling this function.
# Returns
- `samples`: The trees drawn from the posterior.
- `samples`: The trees drawn from the posterior. Returns shallow tree copies, which needs to be repopulated before running felsenstein! etc.
- `sample_LLs`: The associated log-likelihoods of the tree (optional).
"""
function metropolis_sample(
Expand All @@ -50,7 +53,7 @@ function metropolis_sample(
for i=1:iterations

# Updates the tree topolgy and branchlengths.
nni_optim!(tree, x -> models, nni_selection_rule = softmax_sampler)
nni_optim!(tree, x -> models, selection_rule = softmax_sampler)
branchlength_optim!(tree, x -> models, bl_modifier = bl_sampler)

if (i-burn_in) % sample_interval == 0 && i > burn_in
Expand Down
4 changes: 2 additions & 2 deletions src/core/algorithms/branchlength_optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ function branchlength_optim!(
temp_message = pop!(temp_messages)
model_list = models(node)
fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list)
bl = univariate_modifier(fun, bl_modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_value=node.branchlength)
bl = univariate_modifier(fun, bl_modifier; a=0+tol, b=1-tol, tol=tol, transform=unit_transform, curr_value=node.branchlength)
if fun(bl) > fun(node.branchlength) || !(bl_modifier isa UnivariateOpt)
node.branchlength = bl
end
Expand All @@ -110,7 +110,7 @@ function branchlength_optim!(
#-------------------
model_list = models(node)
fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list)
bl = univariate_modifier(fun, bl_modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_value=node.branchlength)
bl = univariate_modifier(fun, bl_modifier; a=0+tol, b=1-tol, tol=tol, transform=unit_transform, curr_value=node.branchlength)
if fun(bl) > fun(node.branchlength) || !(bl_modifier isa UnivariateOpt)
node.branchlength = bl
end
Expand Down
22 changes: 11 additions & 11 deletions src/core/algorithms/nni_optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function nni_optim!(
tree::FelNode,
models,
partition_list;
nni_selection_rule = (x) -> argmax(x),
selection_rule = x -> argmax(x),
traversal = Iterators.reverse
) where {T <: Partition}

Expand Down Expand Up @@ -94,7 +94,7 @@ function nni_optim!(
temp_message,
models;
partition_list = partition_list,
nni_selection_rule = nni_selection_rule,
selection_rule = selection_rule,
)
if nnid && last(last(stack)) #We nnid a sibling that hasn't been visited (then, down would be true in the next iter)...
#... and now we want to continue down the nnid sibling (now a child to node)
Expand Down Expand Up @@ -143,15 +143,15 @@ function nni_optim!(
tree::FelNode,
models::Vector{<:BranchModel},
partition_list;
nni_selection_rule = (x) -> argmax(x),
selection_rule = x -> argmax(x),
traversal = Iterators.reverse,
) where {T <: Partition}
nni_optim!(
temp_messages,
tree,
x -> models,
partition_list,
nni_selection_rule = nni_selection_rule,
selection_rule = selection_rule,
traversal = traversal,
)
end
Expand All @@ -160,7 +160,7 @@ function nni_optim!(
tree::FelNode,
model::BranchModel,
partition_list;
nni_selection_rule = (x) -> argmax(x),
selection_rule = x -> argmax(x),
traversal = Iterators.reverse,

) where {T <: Partition}
Expand All @@ -169,7 +169,7 @@ function nni_optim!(
tree,
x -> [model],
partition_list,
nni_selection_rule = nni_selection_rule,
selection_rule = selection_rule,
traversal = traversal,
)
end
Expand All @@ -179,7 +179,7 @@ function do_nni(
temp_message,
models::F;
partition_list = 1:length(node.message),
nni_selection_rule = (x) -> argmax(x),
selection_rule = x -> argmax(x),
) where {F<:Function}
if length(node.children) == 0 || node.parent === nothing
return false
Expand Down Expand Up @@ -260,7 +260,7 @@ function do_nni(
end
end

sampled_config_ind = nni_selection_rule(nni_LLs)
sampled_config_ind = selection_rule(nni_LLs)
change = sampled_config_ind != 1
(sampled_sib_ind, sampled_child_ind) = nni_configs[sampled_config_ind]

Expand Down Expand Up @@ -295,7 +295,7 @@ a function that takes a node, and returns a Vector{<:BranchModel} if you need th
# Keyword Arguments
- `partition_list=nothing`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize tree topology with all models, the default option).
- `nni_selection_rule = (x) -> argmax(x)`: a function that takes the current and proposed log likelihoods and selects a nni configuration. Note that the current log likelihood is stored at x[1].
- `selection_rule = x -> argmax(x)`: a function that takes the current and proposed log likelihoods and selects a nni configuration. Note that the current log likelihood is stored at x[1].
- `sort_tree=false`: determines if a [`lazysort!`](@ref) will be performed, which can reduce the amount of temporary messages that has to be initialized.
- `traversal=Iterators.reverse`: a function that determines the traversal, permutes an iterable.
- `shuffle=false`: do a randomly shuffled traversal, overrides `traversal`.
Expand All @@ -304,7 +304,7 @@ function nni_optim!(
tree::FelNode,
models;
partition_list = nothing,
nni_selection_rule = (x) -> argmax(x),
selection_rule = x -> argmax(x),
sort_tree = false,
traversal = Iterators.reverse,
shuffle = false
Expand All @@ -321,7 +321,7 @@ function nni_optim!(
tree,
models,
partition_list,
nni_selection_rule = nni_selection_rule,
selection_rule = selection_rule,
traversal = shuffle ? x -> sample(x, length(x), replace=false) : traversal
)
end
2 changes: 1 addition & 1 deletion src/core/nodes/FelNode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ end
"""
function copy_tree(root::FelNode, shallow_copy=false)
Returns a untangled copy of the a tree. Optionally, the flag `shallow_copy` can be used to obtained a copy of the tree with only the names and branchlengths.
Returns an untangled copy of the tree. Optionally, the flag `shallow_copy` can be used to obtain a copy of the tree with only the names and branchlengths.
"""
function copy_tree(root::FelNode, shallow_copy=false)

Expand Down
2 changes: 1 addition & 1 deletion src/utils/simple_optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct GoldenSectionOpt <: UnivariateOpt end
struct BrentsMethodOpt <: UnivariateOpt end

function univariate_modifier(fun, modifier::UnivariateOpt; a=0, b=1, transform=unit_transform, tol=10e-5, kwargs...)
return univariate_maximize(fun, a + tol, b - tol, unit_transform, modifier, tol)
return univariate_maximize(fun, a, b, unit_transform, modifier, tol)
end

"""
Expand Down
2 changes: 1 addition & 1 deletion src/utils/simple_sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ end
"""
BranchlengthSampler
A type that allows you to specify a additive proposal function in the log domain and a prior distrubution over the log of the branchlengths. It also holds the acceptance ratio acc_ratio (acc_ratio[1] stores the number of accepts, and acc_ratio[1] stores the number of rejects).
A type that allows you to specify a additive proposal function in the log domain and a prior distrubution over the log of the branchlengths. It also holds the acceptance ratio acc_ratio (acc_ratio[1] stores the number of accepts, and acc_ratio[2] stores the number of rejects).
"""
struct BranchlengthSampler <: UnivariateSampler
acc_ratio::Vector{Int}
Expand Down
2 changes: 1 addition & 1 deletion test/partition_selection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ begin
branchlength_optim!(tree, bm_models, partition_list = [1])
branchlength_optim!(tree, bm_models, partition_list = [2])
branchlength_optim!(tree, bm_models)
branchlength_optim!(tree, bm_models, bl_optimizer=BrentsMethodOpt())
branchlength_optim!(tree, bm_models, bl_modifier=BrentsMethodOpt())
branchlength_optim!(tree, x -> bm_models, partition_list = [2])
branchlength_optim!(tree, x -> bm_models)

Expand Down

0 comments on commit 409c5ba

Please sign in to comment.