[WIP] Branch length estimation #50

Open
wants to merge 68 commits into
base: refactor
Changes from 15 commits
Commits
68 commits
83d6287
Add codebase skeleton for branch length estimation
sprillo Dec 26, 2020
93fcd99
Add branch length estimation, lineage simulator, and phylogeny simulator
sprillo Dec 26, 2020
c52b474
Create APIs and refactor
sprillo Dec 27, 2020
dfb7321
Move plotting code out of tree.py
sprillo Dec 27, 2020
55a0bec
Forgot to add files
sprillo Dec 27, 2020
4437e48
Add IIDExponentialBLEGridSearchCV with test
sprillo Dec 27, 2020
af23207
Stop returning log-likelihood in IIDExponentialBLE.estimate_branch_le…
sprillo Dec 27, 2020
8a92290
tuple border case bugfix
sprillo Dec 27, 2020
b04ef59
small bugfix
sprillo Dec 27, 2020
7ee49ac
IIDExponentialBLEGridSearchCV should check for solver errors in IIDEx…
sprillo Dec 28, 2020
b3f57b1
reorder imports
sprillo Dec 28, 2020
b781c23
Rename minimum_edge_length to minimum_branch_length
sprillo Dec 28, 2020
ef27e11
Rename T to tree
sprillo Dec 29, 2020
7fb4d38
docs
sprillo Dec 29, 2020
5d8f0ce
Run black
sprillo Dec 29, 2020
cf019e9
More Tree boilerplate
sprillo Jan 5, 2021
b50186e
Create branch_length_estimator package
sprillo Jan 5, 2021
1e19f65
Break up branch_length_estimator.py into smaller modules
sprillo Jan 5, 2021
32f5624
Add IIDExponentialPosteriorMeanBLE with tests
sprillo Jan 5, 2021
cb4bcd6
bugfix joint computation
sprillo Jan 5, 2021
5fddb6f
More testing
sprillo Jan 5, 2021
549b34e
Add to cython
sprillo Jan 5, 2021
a9fbf18
A simple birth process
sprillo Jan 6, 2021
1cdd402
more testing
sprillo Jan 6, 2021
23fec63
Posterior calibration test
sprillo Jan 6, 2021
73a52e1
Doc test. Enable slowtests.
sprillo Jan 6, 2021
a3b87a9
Multiprocessing in grid search (for IIDExponentialPosteriorMeanBLEGri…
sprillo Jan 6, 2021
e4cb173
Add joint log lokelihood computation classmethod to IIDExponentialPos…
sprillo Jan 6, 2021
aac3977
More numerical tests, better docs, better names
sprillo Jan 7, 2021
ac0994f
assert
sprillo Jan 7, 2021
62ff0c3
One more test, this time on data from the DREAM challenge.
sprillo Jan 7, 2021
ae0e9d6
Bugfix IIDExponentialBLE.log_likelihood returning np.nan instead of -…
sprillo Jan 7, 2021
cb571fd
Avoid cp.log(0)
sprillo Jan 7, 2021
d5dc2b6
Test that single processor & multiprocessing work
sprillo Jan 7, 2021
3d27e7a
Make the minimum branch length be in terms of the tree height
sprillo Jan 7, 2021
563c853
Allow choosing how to format branch lengths in tree newick representa…
sprillo Jan 9, 2021
0db9e8b
CV grid plotting
sprillo Jan 11, 2021
634e39b
Make up(.) include the division event
sprillo Jan 11, 2021
a8b87bb
Comments, remove unused methods from Tree
sprillo Jan 15, 2021
4debf9f
Merge branch 'refactor' into branch-length-estimation
sprillo Jan 20, 2021
3b69f5c
Use CassiopeiaTree for branch length estimation
sprillo Jan 21, 2021
5119e8b
Merge branch 'refactor' into branch-length-estimation
sprillo Jan 29, 2021
4d0a3b6
Remove duplicated code
sprillo Jan 29, 2021
c419804
check
sprillo Feb 4, 2021
ee797b5
Easy optimization of my bayesian estimator DP: only visit states with…
sprillo Feb 7, 2021
d36e554
Make some methods private
sprillo Feb 7, 2021
86c0ad0
Address some TODOs
sprillo Feb 7, 2021
dd4f0bf
Add c++ implementation of Bayesian estimator
sprillo Feb 8, 2021
569f8c0
Increase cpp bounds
sprillo Feb 8, 2021
90a2155
bounds
sprillo Feb 8, 2021
b33bb62
bounds
sprillo Feb 8, 2021
a24fff3
More nitro
sprillo Feb 8, 2021
dcce54d
requirements
sprillo Feb 9, 2021
505c802
requirements
sprillo Feb 9, 2021
e4c1dc7
requirements
sprillo Feb 9, 2021
452a978
Merge branch 'refactor' into branch-length-estimation
sprillo Feb 12, 2021
7f292eb
Resolve multifurcations
sprillo Feb 13, 2021
b398ea4
Forgot to add file
sprillo Feb 13, 2021
22fb385
Add TumorWithAFitSubclone
sprillo Feb 14, 2021
6f4eefd
Some goodies
sprillo Feb 14, 2021
cf94b2a
Resolving of multifurcations, and cell subsampler
sprillo Feb 17, 2021
ec7c03b
black
sprillo Feb 17, 2021
8baab4c
Enhance UniformCellSubsampler
sprillo Feb 17, 2021
4b5039a
Merge branch 'refactor' into branch-length-estimation
sprillo Feb 17, 2021
a1372c2
Remove print
sprillo Feb 23, 2021
ec87a1d
Merge branch 'refactor' into branch-length-estimation
sprillo Feb 24, 2021
d1fc030
Merge branch 'refactor' into branch-length-estimation
sprillo Feb 27, 2021
3e38683
Merge branch 'refactor' into branch-length-estimation
sprillo Mar 6, 2021
15 changes: 15 additions & 0 deletions cassiopeia/tools/__init__.py
@@ -0,0 +1,15 @@
from .branch_length_estimator import (
    BranchLengthEstimator,
    IIDExponentialBLE,
    IIDExponentialBLEGridSearchCV,
)
from .lineage_simulator import (
    LineageSimulator,
    PerfectBinaryTree,
    PerfectBinaryTreeWithRootBranch,
)
from .lineage_tracing_simulator import (
    LineageTracingSimulator,
    IIDExponentialLineageTracer,
)
from .tree import Tree
325 changes: 325 additions & 0 deletions cassiopeia/tools/branch_length_estimator.py
@@ -0,0 +1,325 @@
import abc
import copy
from typing import List, Tuple

import cvxpy as cp
Collaborator

We'll need to add this to the setup.py, my tests break because this wasn't installed during make install

Collaborator Author

Agreed!

QQ: I noticed many of the requirements have a minimum version requirement, e.g. numpy > 1.17, is there a reason for that? I was thinking to just add cvxpy without a version requirement.

Collaborator

Having version requirements is helpful for a couple of reasons -

(1) Sometimes, APIs are different between versions (e.g., networkx changes the way it calls some functionality in versions <2 and >2, I believe)
(2) Sometimes, earlier versions of packages throw warnings (this is especially relevant if, for example, a particular package relies on another package that has updated and is deprecating some feature). Requiring a minimal version (or maximal version, for that matter) may help you get rid of warnings being thrown around.

In general, it's likely that having an outdated version of a particular package will be fine and the worst thing that will happen is that a warning might be thrown. However, I think it's good for software reproducibility so that when a new user comes along they can create a very similar environment to the developer. Hope this makes sense!
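
A minimal sketch of how the dependency could be declared, assuming the project lists its requirements in setup.py as described above. Only the `numpy>1.17` bound is quoted from this thread; the unpinned `cvxpy` entry mirrors the author's first suggestion, and no minimum cvxpy version is implied here.

```python
# Hypothetical excerpt from setup.py; only the numpy bound is quoted from the thread.
install_requires = [
    "numpy>1.17",  # lower bound mentioned in the discussion above
    "cvxpy",  # unpinned for now; a lower bound can be added once one is chosen
]

# The list would then be handed to setuptools via
# setup(..., install_requires=install_requires).
```

Adding a lower bound later only means editing the string, for example to the oldest cvxpy release the tests are known to pass on.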

import numpy as np

from .tree import Tree


class BranchLengthEstimator(abc.ABC):
    r"""
    Abstract base class for all branch length estimators.

    A BranchLengthEstimator implements a method estimate_branch_lengths which,
    given a Tree with lineage tracing character vectors at the leaves (and
    possibly at the internal nodes too), estimates the branch lengths of the
    tree.
    """

    @abc.abstractmethod
    def estimate_branch_lengths(self, tree: Tree) -> None:
mattjones315 marked this conversation as resolved.
        r"""
        Estimates the branch lengths of the tree.

        Annotates the tree's nodes with their estimated age, and
        the tree's branches with their estimated lengths. Operates on the tree
        in-place.

        Args:
            tree: The tree for which to estimate branch lengths.
        """


class IIDExponentialBLE(BranchLengthEstimator):
    r"""
    A simple branch length estimator that assumes that the characters evolve IID
    over the phylogeny with the same cutting rate.

    This estimator requires that the ancestral states are provided.

    The optimization problem is a special kind of convex program called an
    exponential cone program:
    https://docs.mosek.com/modeling-cookbook/expo.html
    Because it is a convex program, it can be readily solved.

    Args:
        minimum_branch_length: Estimated branch lengths will be constrained to
            have at least this length.
mattjones315 marked this conversation as resolved.
        l2_regularization: Consecutive branches will be regularized to have
            similar length via an L2 penalty whose weight is given by
            l2_regularization.
        verbose: Verbosity level.

    Attributes:
        log_likelihood: The log-likelihood of the training data under the
            estimated model.
        log_loss: The log-loss of the training data under the estimated model.
            This is the log likelihood plus the regularization terms.
Collaborator

typo -> "likelihood"

Collaborator Author

Thanks!

"""

    def __init__(
        self,
        minimum_branch_length: float = 0,
        l2_regularization: float = 0,
        verbose: bool = False,
    ):
        self.minimum_branch_length = minimum_branch_length
        self.l2_regularization = l2_regularization
        self.verbose = verbose

    def estimate_branch_lengths(self, tree: Tree) -> None:
        r"""
        See base class. The only caveat is that this method raises if it fails
Collaborator

incomplete sentence: "this method raises [an error?]..."

Collaborator Author

I've always used "raises" for short to mean that it "raises an error", but it might not be proper english.

Collaborator

Ah - you might have more software engineering lingo than I do! If it's not too much of a problem, let's be more verbose here in the docstring.

Can you also add a Raises section to the docstring (ala Args and Returns)?

        to solve the underlying optimization problem for any reason.

        Raises:
            cp.error.SolverError
        """
        # Extract parameters
        minimum_branch_length = self.minimum_branch_length
        l2_regularization = self.l2_regularization
        verbose = self.verbose

        # # Wrap the networkx DiGraph for goodies.
        # tree = Tree(tree)
mattjones315 marked this conversation as resolved.

        # # # # # Create variables of the optimization problem # # # # #
        r_X_t_variables = dict(
            [
                (node_id, cp.Variable(name=f"r_X_t_{node_id}"))
                for node_id in tree.nodes()
            ]
        )
        time_increases_constraints = [
            r_X_t_variables[parent]
            >= r_X_t_variables[child] + minimum_branch_length
            for (parent, child) in tree.edges()
        ]
        leaves_have_age_0_constraints = [
            r_X_t_variables[leaf] == 0 for leaf in tree.leaves()
Collaborator

This appears to be an ultrametric condition - this is okay for our purposes right now, but what if you were attempting to do this for a tree where this did not hold? Do we have a separate estimator without this condition?

Collaborator Author

Yes, this is the ultrametric condition. We don't have an estimator without this condition. As of today I don't think we need it, but if we do I guess we can just add this option to this estimator as an argument?

Collaborator

I think it's totally fine to keep this condition in there. It might be worthwhile to add this condition, somehow, to the name of the estimator like SimpleUltrametricBLE ( I know this is super long, but does that idea make sense?). That way, the estimator can be extended later into another class that does not use the Ultrametric condition.
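
To make the ultrametric condition concrete, here is a small sketch that checks whether every leaf sits at the same distance from the root. It uses plain dictionaries rather than the project's Tree class; the function name `is_ultrametric` and the `children` / `edge_length` layout are assumptions made for illustration only.

```python
from typing import Dict, Hashable, List, Tuple


def is_ultrametric(
    children: Dict[Hashable, List[Hashable]],
    edge_length: Dict[Tuple[Hashable, Hashable], float],
    root: Hashable,
    tol: float = 1e-9,
) -> bool:
    """Return True if all leaves are equally far from the root (within tol)."""
    leaf_depths = []
    stack = [(root, 0.0)]
    while stack:
        node, depth = stack.pop()
        kids = children.get(node, [])
        if not kids:
            # A node without children is a leaf: record its root distance.
            leaf_depths.append(depth)
        for kid in kids:
            stack.append((kid, depth + edge_length[(node, kid)]))
    return max(leaf_depths) - min(leaf_depths) <= tol


# Toy tree: root -> a, b and a -> c, d.
children = {"root": ["a", "b"], "a": ["c", "d"]}
balanced = {
    ("root", "a"): 1.0,
    ("root", "b"): 3.0,
    ("a", "c"): 2.0,
    ("a", "d"): 2.0,
}
# Shortening one leaf branch breaks the condition.
skewed = {**balanced, ("a", "d"): 0.5}
```

Setting every leaf's age to 0, as the constraint in this diff does, forces exactly this property on the estimated branch lengths.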

        ]
        non_negative_r_X_t_constraints = [
            r_X_t >= 0 for r_X_t in r_X_t_variables.values()
        ]
        all_constraints = (
            time_increases_constraints
            + leaves_have_age_0_constraints
            + non_negative_r_X_t_constraints
        )

        # # # # # Compute the log-likelihood # # # # #
        log_likelihood = 0

        # Because all rates are equal, the number of cuts in each node is a
        # sufficient statistic. This makes the solver WAY faster!
        for (parent, child) in tree.edges():
            edge_length = r_X_t_variables[parent] - r_X_t_variables[child]
            # TODO: hardcoded '0' here...
mattjones315 marked this conversation as resolved.
            zeros_parent = tree.get_state(parent).count("0")
            zeros_child = tree.get_state(child).count("0")
Collaborator

Looks like we're assuming throughout here that the internal nodes have states. This isn't always the case, but we can always infer back the states at internal nodes.

Should you have a check at the top that the states of internal nodes exist?

Collaborator Author

Yes, as per the docstring: "This estimator requires that the ancestral states are provided." The estimator doesn't sanity-check the input tree currently, but we can do that. (In general, there's a bunch of sanity checks that could be added in many places and aren't there because it wasn't my focus when coding.)

As a note, the estimator currently does not have the responsibility of imputing ancestral states if they are not provided.

Collaborator

Totally makes sense! When this code is deployed, let's make sure that an error is thrown if ancestral states are not given.
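
One possible shape for that check, as a sketch only. It relies on the `nodes()` and `get_state()` methods used in this diff; the helper name `check_ancestral_states`, the `ToyTree` stand-in, and the assumption that a missing state shows up as `None` or an empty string are all hypothetical.

```python
def check_ancestral_states(tree) -> None:
    """Raise ValueError if any node of the tree has no character state.

    Assumes a missing state is represented as None or an empty string.
    """
    missing = [node for node in tree.nodes() if not tree.get_state(node)]
    if missing:
        raise ValueError(
            f"Ancestral states are required but missing for nodes: {missing}"
        )


class ToyTree:
    """Stand-in for the real Tree class, only for this example."""

    def __init__(self, states):
        self._states = states

    def nodes(self):
        return list(self._states)

    def get_state(self, node):
        return self._states[node]


complete = ToyTree({"root": "00", "leaf_1": "01", "leaf_2": "00"})
incomplete = ToyTree({"root": "", "leaf_1": "01", "leaf_2": "00"})
```

Calling the helper at the top of `estimate_branch_lengths` would turn a confusing failure deep in the solver setup into an immediate, readable error.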

            new_cuts_child = zeros_parent - zeros_child
            assert new_cuts_child >= 0
mattjones315 marked this conversation as resolved.
            # Add log-lik for characters that didn't get cut
            log_likelihood += zeros_child * (-edge_length)
            # Add log-lik for characters that got cut
            log_likelihood += new_cuts_child * cp.log(1 - cp.exp(-edge_length))

        # # # # # Add regularization # # # # #

        l2_penalty = 0
        for (parent, child) in tree.edges():
            for child_of_child in tree.children(child):
                edge_length_above = (
                    r_X_t_variables[parent] - r_X_t_variables[child]
                )
                edge_length_below = (
                    r_X_t_variables[child] - r_X_t_variables[child_of_child]
                )
                l2_penalty += (edge_length_above - edge_length_below) ** 2
        l2_penalty *= l2_regularization

        # # # # # Solve the problem # # # # #

        obj = cp.Maximize(log_likelihood - l2_penalty)
        prob = cp.Problem(obj, all_constraints)

        f_star = prob.solve(solver="ECOS", verbose=verbose)

        # # # # # Populate the tree with the estimated branch lengths # # # # #

        for node in tree.nodes():
            tree.set_age(node, age=r_X_t_variables[node].value)

        for (parent, child) in tree.edges():
            new_edge_length = (
                r_X_t_variables[parent].value - r_X_t_variables[child].value
            )
            tree.set_edge_length(parent, child, length=new_edge_length)

        self.log_likelihood = log_likelihood.value
        self.log_loss = f_star

    @classmethod
    def log_likelihood(cls, tree: Tree) -> float:
        r"""
        The log-likelihood of the given tree under the model.
        """
Collaborator

Looks like this is the log likelihood of the branch length model. Should we have a method in the tree class that gives the likelihood of the tree, over all parameters? (internal nodes, branch lengths, etc)

If we do, why not just delegate this function call to tree.log_likelihood?

Collaborator Author

There are many possible models for the data on the tree (although currently the IIDExponential is the only one), just like there are many tree topology reconstruction algorithms for character matrices (the different Cassiopeia solvers). So, just like character matrices don't have a solve method that gives the tree topology, I don't think trees should have a log_likelihood method that gives their likelihood. Instead, another object takes care of that: Just as CassiopeiaSolver takes care of providing tree topologies, BranchLengthEstimator provides tree data likelihood under specific models. This was the picture in my mind when designing this at least.

Collaborator

Your point totally makes sense. I guess there are two different likelihoods that are worth discussing.

The first is the likelihood of the model that's used in the branch length estimation procedure that depends on the cutting model proposed (i.e., you don't take into account things like the state that arises).

The second is a full likelihood of the tree across all parameters and observations. We can use something like Felsenstein's algorithm to compute this tree likelihood, given a tree topology, branch lengths, and a stochastic transition matrix. Maybe this functionality should be in the tree class, but not the BLE model likelihood.
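
For reference, a compact sketch of the second kind of likelihood: Felsenstein's pruning recursion for a single character, written in plain Python. Nothing here is part of this PR; the function name `pruning_likelihood`, the dictionary-based tree layout, and the toy two-state irreversible model at rate 1 are assumptions chosen to keep the example small.

```python
import math


def pruning_likelihood(children, root, leaf_state, transition, root_prior):
    """Likelihood of one character's leaf states, summed over internal states.

    children:   dict mapping each internal node to a list of its children
    leaf_state: dict mapping each leaf to its observed state index
    transition: dict mapping (parent, child) to a matrix P (list of rows),
                with P[i][j] = Pr(child in state j | parent in state i)
    root_prior: list of state probabilities at the root
    """
    n_states = len(root_prior)

    def partial(node):
        # partial(node)[i] = Pr(observed leaves below node | node in state i)
        if node in leaf_state:
            return [1.0 if s == leaf_state[node] else 0.0 for s in range(n_states)]
        vec = [1.0] * n_states
        for child in children[node]:
            below = partial(child)
            P = transition[(node, child)]
            for i in range(n_states):
                vec[i] *= sum(P[i][j] * below[j] for j in range(n_states))
        return vec

    return sum(p * v for p, v in zip(root_prior, partial(root)))


# Toy model: state 0 = uncut, state 1 = cut. Cuts are irreversible and occur
# at rate 1, so a site stays uncut along a branch of length t w.p. exp(-t).
t = 1.0
P = [[math.exp(-t), 1.0 - math.exp(-t)], [0.0, 1.0]]
likelihood = pruning_likelihood(
    children={"root": ["a", "b"]},
    root="root",
    leaf_state={"a": 0, "b": 1},
    transition={("root", "a"): P, ("root", "b"): P},
    root_prior=[1.0, 0.0],  # the root starts uncut
)
```

On this two-leaf tree the result is exp(-1) * (1 - exp(-1)), which matches the cutting-model terms used by the estimator because there are no unobserved internal states to sum over.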

        log_likelihood = 0.0
        for (parent, child) in tree.edges():
            edge_length = tree.get_age(parent) - tree.get_age(child)
            # TODO: hardcoded '0' here...
            zeros_parent = tree.get_state(parent).count("0")
            zeros_child = tree.get_state(child).count("0")
            new_cuts_child = zeros_parent - zeros_child
            assert new_cuts_child >= 0
            # Add log-lik for characters that didn't get cut
            log_likelihood += zeros_child * (-edge_length)
            # Add log-lik for characters that got cut
            if edge_length < 1e-8 and new_cuts_child > 0:
                return -np.inf
            log_likelihood += new_cuts_child * np.log(1 - np.exp(-edge_length))
        return log_likelihood


class IIDExponentialBLEGridSearchCV(BranchLengthEstimator):
Collaborator

I typically like to have one file per class. It's easier for me to track things and I think it makes things easier to review in the future. What do you think?

Collaborator Author

Sounds good!

I used to follow your pattern until very recently, only because I started to read about people's thoughts on this and it seems people tend to prefer the several-related-classes-per-module way:


(No need to read these, just here for reference)
https://softwareengineering.stackexchange.com/questions/209982/is-it-considered-pythonic-to-have-multiple-classes-defined-in-the-same-file
https://stackoverflow.com/questions/106896/how-many-classes-should-i-put-in-one-file
https://www.reddit.com/r/Python/comments/41haw4/why_do_python_style_conventions_prefer_multiple/


However, our classes have very small interfaces and large implementations, so it does seem like the one-file-per-class you propose would be more suitable? Citing one of the answers above:

If your concrete implementations have very different internal concerns, your single file accumulates all those concerns. For example, implementations with non-overlapping dependencies make your single file depend on the union of all those dependencies.

So, it might sometimes be reasonable to consider the sub-classes' coupling to their dependencies outweighs their coupling to the interface (or conversely, the concern of implementing an interface is weaker than the concerns internal to that implementation).

As a specific example, take a generic database interface. Concrete implementations using an in-memory DB, an SQL RDBMS and a web query respectively may have nothing in common apart from the interface, and forcing everyone who wants the lightweight in-memory version to also import an SQL library is nasty.

What do you think about having a tools/branch_length_estimation package with one file per class? (In this case BranchLengthEstimator.py, IIDExponentialBLE.py, IIDExponentialBLEGridSearchCV.py).

Collaborator

I see the advantage of using the several-classes-per-file if the API and dependencies are all very similar, but yes I feel that the models we'll be developing with the BranchLengthEstimator subclasses will all vary quite a bit in their implementations. For example, the bayesian estimators, I predict, will require quite a bit of extra implementation that's not overlapping with the maximum likelihood based estimators.

So, maybe it would be worthwhile to keep everything separate. I don't think this will become unwieldy in the future, and in my opinion helps readability and interpretability of the codebase.

    r"""
    Like IIDExponentialBLE but with automatic tuning of hyperparameters.

    This class fits the hyperparameters of IIDExponentialBLE based on
    character-level held-out log-likelihood. It leaves out one character at a
    time, fitting the data on all the remaining characters. Thus, the number
    of models trained by this class is #characters * grid size.

    Args:
        minimum_branch_lengths: The grid of minimum_branch_length to use.
        l2_regularizations: The grid of l2_regularization to use.
        verbose: Verbosity level.
    """

    def __init__(
        self,
        minimum_branch_lengths: Tuple[float] = (0,),
        l2_regularizations: Tuple[float] = (0,),
Collaborator

if this is a grid of parameters to use, is there a reason why we're passing Tuples?

I feel like there must be a publicly available hyperparameter tuning library that allows you to pass ranges of hyperparameters. That might be more useful here - thoughts?

Collaborator Author

Re: tuples, isn't this the natural parameterization of a grid? E.g. minimum_branch_lengths = (0, 1), l2_regularizations = (0, 10, 100, 1000). The grid is given by all possible combinations of both parameters.

Re: hyperparam tuning library: Yes! I agree! I was peeking a bit into sklearn to see if we could reuse anything. I just wanted to get CV up and running for experimentation purposes, so I went the fast route of just coding it myself. We should look into a generic and reusable implementation.

Collaborator

Sounds good! I had a total brainfart and was thinking you were using tuples that were of length 2. No reason why you can't use Tuples here rather than lists.

But yes, let's investigate hyperparameter tuning libraries as I imagine they'll do a lot of good stuff about exploring large parameter regimes intelligently. For example, Google has Vizier (https://cloud.google.com/ai-platform/optimizer/docs/overview) that not only explores large parameter spaces but actually optimizes some function at the same time so that the new experiments you do from your grid of parameters are likely improvements.
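
As a small illustration of the tuple parameterization discussed above, the grid is just the Cartesian product of the two tuples. The values are the ones quoted in this thread; the helper `best_hyperparameters` and the toy `score` function are hypothetical stand-ins for the cross-validated log-likelihood.

```python
import itertools

minimum_branch_lengths = (0, 1)
l2_regularizations = (0, 10, 100, 1000)

# Every (minimum_branch_length, l2_regularization) combination: 2 * 4 = 8 points.
grid = list(itertools.product(minimum_branch_lengths, l2_regularizations))


def best_hyperparameters(grid, score):
    """Return the grid point with the highest score."""
    return max(grid, key=lambda params: score(*params))


def score(minimum_branch_length, l2_regularization):
    # Toy objective that peaks at (1, 10); it only stands in for a real
    # held-out log-likelihood.
    return -(abs(minimum_branch_length - 1) + abs(l2_regularization - 10))


best = best_hyperparameters(grid, score)
```

A dedicated tuning library would replace the exhaustive `max` over the grid with a smarter search, but the inputs would still be the same two ranges.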

        verbose: bool = False,
    ):
        self.minimum_branch_lengths = minimum_branch_lengths
        self.l2_regularizations = l2_regularizations
        self.verbose = verbose

    def estimate_branch_lengths(self, tree: Tree) -> None:
        r"""
        See base class. The only caveat is that this method raises if it fails
        to solve the underlying optimization problem for any reason.

        Raises:
            cp.error.SolverError
        """
        # Extract parameters
        minimum_branch_lengths = self.minimum_branch_lengths
        l2_regularizations = self.l2_regularizations
        verbose = self.verbose

        held_out_log_likelihoods = []  # type: List[Tuple[float, List]]
        for minimum_branch_length in minimum_branch_lengths:
            for l2_regularization in l2_regularizations:
                cv_log_likelihood = self._cv_log_likelihood(
                    tree=tree,
                    minimum_branch_length=minimum_branch_length,
                    l2_regularization=l2_regularization,
                )
                held_out_log_likelihoods.append(
                    (
                        cv_log_likelihood,
                        [minimum_branch_length, l2_regularization],
                    )
                )

        # Refit model on full dataset with the best hyperparameters
        held_out_log_likelihoods.sort(reverse=True)
        (
            best_minimum_branch_length,
            best_l2_regularization,
        ) = held_out_log_likelihoods[0][1]
        if verbose:
            print(
                f"Refitting full model with:\n"
                f"minimum_branch_length={best_minimum_branch_length}\n"
                f"l2_regularization={best_l2_regularization}"
            )
        final_model = IIDExponentialBLE(
            minimum_branch_length=best_minimum_branch_length,
            l2_regularization=best_l2_regularization,
        )
        final_model.estimate_branch_lengths(tree)
        self.minimum_branch_length = best_minimum_branch_length
        self.l2_regularization = best_l2_regularization
        self.log_likelihood = final_model.log_likelihood
        self.log_loss = final_model.log_loss

    def _cv_log_likelihood(
        self, tree: Tree, minimum_branch_length: float, l2_regularization: float
    ) -> float:
        r"""
        Given the tree and the parameters of the model, returns the
        cross-validated log-likelihood of the model. This is done by holding out
        one character at a time, fitting the model on the remaining characters,
        and evaluating the log-likelihood on the held-out character. As a
        consequence, #character models are fit by this method. The mean held-out
        log-likelihood over the #character folds is returned.
        """
        verbose = self.verbose
        if verbose:
            print(
                f"Cross-validating hyperparameters:"
                f"\nminimum_branch_length={minimum_branch_length}"
                f"\nl2_regularizations={l2_regularization}"
            )
        n_characters = tree.num_characters()
        log_likelihood_folds = np.zeros(shape=(n_characters))
        for held_out_character_idx in range(n_characters):
            tree_train, tree_valid = self._cv_split(
                tree=tree, held_out_character_idx=held_out_character_idx
            )
            try:
                IIDExponentialBLE(
                    minimum_branch_length=minimum_branch_length,
                    l2_regularization=l2_regularization,
                ).estimate_branch_lengths(tree_train)
                tree_valid.copy_branch_lengths(tree_other=tree_train)
                held_out_log_likelihood = IIDExponentialBLE.log_likelihood(
                    tree_valid
                )
            except cp.error.SolverError:
                held_out_log_likelihood = -np.inf
mattjones315 marked this conversation as resolved.
            log_likelihood_folds[
                held_out_character_idx
            ] = held_out_log_likelihood
        if verbose:
            print(f"log_likelihood_folds = {log_likelihood_folds}")
            print(
                f"mean log_likelihood_folds = "
                f"{np.mean(log_likelihood_folds)}"
            )
        return np.mean(log_likelihood_folds)

    def _cv_split(
        self, tree: Tree, held_out_character_idx: int
    ) -> Tuple[Tree, Tree]:
        r"""
        Creates a training and a cross validation tree by hiding the
        character at position held_out_character_idx.
        """
        tree_train = copy.deepcopy(tree)
        tree_valid = copy.deepcopy(tree)
        for node in tree.nodes():
            state = tree_train.get_state(node)
            train_state = (
                state[:held_out_character_idx]
                + state[(held_out_character_idx + 1) :]
            )
            valid_data = state[held_out_character_idx]
            tree_train.set_state(node, train_state)
            tree_valid.set_state(node, valid_data)
        return tree_train, tree_valid
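
To see what the leave-one-character-out split in `_cv_split` does to a single node, here is the same slicing applied to one state string. This assumes states are strings with one symbol per character, which is consistent with the `.count("0")` calls in this diff; the helper name `split_state` and the example state `"0120"` are made up for illustration.

```python
def split_state(state: str, held_out_character_idx: int):
    """Split one node's state into (training characters, held-out character)."""
    train_state = (
        state[:held_out_character_idx] + state[held_out_character_idx + 1:]
    )
    valid_state = state[held_out_character_idx]
    return train_state, valid_state


# One fold per character position of a hypothetical 4-character state.
folds = [split_state("0120", idx) for idx in range(4)]
```

For `"0120"` the four folds are `("120", "0")`, `("020", "1")`, `("010", "2")` and `("012", "0")`, so each fold trains on three characters and validates on the remaining one.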