[WIP] Branch length estimation #50
base: refactor
@@ -0,0 +1,15 @@
from .branch_length_estimator import (
    BranchLengthEstimator,
    IIDExponentialBLE,
    IIDExponentialBLEGridSearchCV,
)
from .lineage_simulator import (
    LineageSimulator,
    PerfectBinaryTree,
    PerfectBinaryTreeWithRootBranch,
)
from .lineage_tracing_simulator import (
    LineageTracingSimulator,
    IIDExponentialLineageTracer,
)
from .tree import Tree
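A minimal usage sketch of the API exported here, for orientation. The `cassiopeia.tools` import path and the `Tree` construction are assumptions (neither is shown in this diff); the estimator calls mirror the interfaces defined in the next file:

```python
# Hypothetical import path; only the class names are confirmed by this diff.
from cassiopeia.tools import IIDExponentialBLE, Tree

tree = Tree(topology)  # hypothetical: `topology` is a networkx DiGraph with
                       # character states annotated at the nodes
estimator = IIDExponentialBLE(minimum_branch_length=0.01, l2_regularization=1.0)
estimator.estimate_branch_lengths(tree)  # annotates ages and lengths in-place
print(estimator.log_likelihood, estimator.log_loss)
```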
@@ -0,0 +1,325 @@
import abc
import copy
from typing import List, Tuple

import cvxpy as cp
import numpy as np

from .tree import Tree


class BranchLengthEstimator(abc.ABC):
    r"""
    Abstract base class for all branch length estimators.

    A BranchLengthEstimator implements a method estimate_branch_lengths which,
    given a Tree with lineage tracing character vectors at the leaves (and
    possibly at the internal nodes too), estimates the branch lengths of the
    tree.
    """

    @abc.abstractmethod
    def estimate_branch_lengths(self, tree: Tree) -> None:
r""" | ||
Estimates the branch lengths of the tree. | ||
|
||
Annotates the tree's nodes with their estimated age, and | ||
the tree's branches with their estiamted lengths. Operates on the tree | ||
in-place. | ||
|
||
Args: | ||
tree: The tree for which to estimate branch lengths. | ||
""" | ||
|
||
|
||
class IIDExponentialBLE(BranchLengthEstimator):
    r"""
    A simple branch length estimator that assumes that the characters evolve
    IID over the phylogeny with the same cutting rate.

    This estimator requires that the ancestral states are provided.

    The optimization problem is a special kind of convex program called an
    exponential cone program:
    https://docs.mosek.com/modeling-cookbook/expo.html
    Because it is a convex program, it can be readily solved.

    Args:
        minimum_branch_length: Estimated branch lengths will be constrained to
            have at least this length.
        l2_regularization: Consecutive branches will be regularized to have
            similar length via an L2 penalty whose weight is given by
            l2_regularization.
        verbose: Verbosity level.

    Attributes:
        log_likelihood: The log-likelihood of the training data under the
            estimated model.
        log_loss: The log-loss of the training data under the estimated model.
            This is the log likelihood plus the regularization terms.
    """

Reviewer: typo -> "likelihood"

Author: Thanks!

    def __init__(
        self,
        minimum_branch_length: float = 0,
        l2_regularization: float = 0,
        verbose: bool = False,
    ):
        self.minimum_branch_length = minimum_branch_length
        self.l2_regularization = l2_regularization
        self.verbose = verbose

    def estimate_branch_lengths(self, tree: Tree) -> None:
        r"""
        See base class. The only caveat is that this method raises an error
        (specifically, a cp.error.SolverError) if it fails to solve the
        underlying optimization problem for any reason.

        Raises:
            cp.error.SolverError
        """

Reviewer: incomplete sentence: "this method raises [an error?]..."

Author: I've always used "raises" for short to mean that it "raises an error", but it might not be proper English.

Reviewer: Ah - you might have more software engineering lingo than I do! If it's not too much of a problem, let's be more verbose here in the docstring. Can you also add a …
        # Extract parameters
        minimum_branch_length = self.minimum_branch_length
        l2_regularization = self.l2_regularization
        verbose = self.verbose

        # # Wrap the networkx DiGraph for goodies.
        # tree = Tree(tree)

        # # # # # Create variables of the optimization problem # # # # #
        r_X_t_variables = dict(
            [
                (node_id, cp.Variable(name=f"r_X_t_{node_id}"))
                for node_id in tree.nodes()
            ]
        )
        time_increases_constraints = [
            r_X_t_variables[parent]
            >= r_X_t_variables[child] + minimum_branch_length
            for (parent, child) in tree.edges()
        ]
        leaves_have_age_0_constraints = [
            r_X_t_variables[leaf] == 0 for leaf in tree.leaves()
        ]

Reviewer: This appears to be an ultrametric condition - this is okay for our purposes right now, but what if you were attempting to do this for a tree where this did not hold? Do we have a separate estimator without this condition?

Author: Yes, this is the ultrametric condition. We don't have an estimator without this condition. As of today I don't think we need it, but if we do I guess we can just add this option to this estimator as an argument?

Reviewer: I think it's totally fine to keep this condition in there. It might be worthwhile to add this condition, somehow, to the name of the estimator, like …
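A hedged sketch of the option floated in this thread, not part of this PR; the `ultrametric` flag and the helper name are hypothetical:

```python
# Hypothetical: gate the leaves-have-age-0 (ultrametric) constraints behind
# a constructor flag, so a non-ultrametric variant can leave leaf ages free.
def _leaf_age_constraints(self, tree, r_X_t_variables):
    if self.ultrametric:  # hypothetical constructor flag
        return [r_X_t_variables[leaf] == 0 for leaf in tree.leaves()]
    return []  # leaf ages constrained only by non-negativity
```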
        non_negative_r_X_t_constraints = [
            r_X_t >= 0 for r_X_t in r_X_t_variables.values()
        ]
        all_constraints = (
            time_increases_constraints
            + leaves_have_age_0_constraints
            + non_negative_r_X_t_constraints
        )

        # # # # # Compute the log-likelihood # # # # #
        log_likelihood = 0

        # Because all rates are equal, the number of cuts in each node is a
        # sufficient statistic. This makes the solver WAY faster!
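        # (With unit cutting rate, a character that is uncut at the parent
        # survives an edge of length t uncut with probability exp(-t) and is
        # cut with probability 1 - exp(-t). Characters are IID, so each edge
        # contributes
        #     zeros_child * (-t) + new_cuts_child * log(1 - exp(-t))
        # to the log-likelihood; it depends on the states only through these
        # counts.)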
        for (parent, child) in tree.edges():
            edge_length = r_X_t_variables[parent] - r_X_t_variables[child]
            # TODO: hardcoded '0' here...
            zeros_parent = tree.get_state(parent).count("0")
            zeros_child = tree.get_state(child).count("0")

Reviewer: Looks like we're assuming throughout here that the internal nodes have states. This isn't always the case, but we can always infer back the states at internal nodes. Should you have a check at the top that the states of internal nodes exist?

Author: Yes, as per the docstring: "This estimator requires that the ancestral states are provided." As a note, the estimator currently does not have the responsibility of imputing ancestral states if they are not provided.

Reviewer: Totally makes sense! When this code is deployed, let's make sure that an error is thrown if ancestral states are not given.
            new_cuts_child = zeros_parent - zeros_child
            assert new_cuts_child >= 0
            # Add log-lik for characters that didn't get cut
            log_likelihood += zeros_child * (-edge_length)
            # Add log-lik for characters that got cut
            log_likelihood += new_cuts_child * cp.log(1 - cp.exp(-edge_length))

        # # # # # Add regularization # # # # #

        l2_penalty = 0
        for (parent, child) in tree.edges():
            for child_of_child in tree.children(child):
                edge_length_above = (
                    r_X_t_variables[parent] - r_X_t_variables[child]
                )
                edge_length_below = (
                    r_X_t_variables[child] - r_X_t_variables[child_of_child]
                )
                l2_penalty += (edge_length_above - edge_length_below) ** 2
        l2_penalty *= l2_regularization

        # # # # # Solve the problem # # # # #

        obj = cp.Maximize(log_likelihood - l2_penalty)
        prob = cp.Problem(obj, all_constraints)

        f_star = prob.solve(solver="ECOS", verbose=verbose)

        # # # # # Populate the tree with the estimated branch lengths # # # # #

        for node in tree.nodes():
            tree.set_age(node, age=r_X_t_variables[node].value)

        for (parent, child) in tree.edges():
            new_edge_length = (
                r_X_t_variables[parent].value - r_X_t_variables[child].value
            )
            tree.set_edge_length(parent, child, length=new_edge_length)

        self.log_likelihood = log_likelihood.value
        self.log_loss = f_star

    @classmethod
    def log_likelihood(cls, tree: Tree) -> float:
        r"""
        The log-likelihood of the given tree under the model.
        """
        log_likelihood = 0.0
        for (parent, child) in tree.edges():
            edge_length = tree.get_age(parent) - tree.get_age(child)
            # TODO: hardcoded '0' here...
            zeros_parent = tree.get_state(parent).count("0")
            zeros_child = tree.get_state(child).count("0")
            new_cuts_child = zeros_parent - zeros_child
            assert new_cuts_child >= 0
            # Add log-lik for characters that didn't get cut
            log_likelihood += zeros_child * (-edge_length)
            # Add log-lik for characters that got cut
            if edge_length < 1e-8 and new_cuts_child > 0:
                return -np.inf
            log_likelihood += new_cuts_child * np.log(1 - np.exp(-edge_length))
        return log_likelihood

Reviewer: Looks like this is the log likelihood of the branch length model. Should we have a method in the tree class that gives the likelihood of the tree, over all parameters (internal nodes, branch lengths, etc.)? If we do, why not just delegate this function call to …

Author: There are many possible models for the data on the tree (although currently the IIDExponential is the only one), just like there are many tree topology reconstruction algorithms for character matrices (the different Cassiopeia solvers). So, just like character matrices don't have a …

Reviewer: Your point totally makes sense. I guess there are two different likelihoods that are worth discussing. The first is the likelihood of the model that's used in the branch length estimation procedure, which depends on the cutting model proposed (i.e., you don't take into account things like the state that arises). The second is a full likelihood of the tree across all parameters and observations. We can use something like Felsenstein's algorithm to compute this tree likelihood, given a tree topology, branch lengths, and a stochastic transition matrix. Maybe this functionality should be in the tree class, but not the BLE model likelihood.
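A hedged sketch of the "second likelihood" discussed in this thread, via Felsenstein's pruning algorithm. It assumes single-character integer states, a hypothetical `transition_matrix(t)` kernel and `tree.root()` accessor, and a `root_prior` over states; `children`, `get_age`, and `get_state` mirror the accessors used elsewhere in this file:

```python
import numpy as np

def felsenstein_log_likelihood(tree, transition_matrix, root_prior, num_states):
    # partial(v)[s] = P(observed leaf states below v | node v is in state s),
    # computed bottom-up by dynamic programming over the tree.
    def partial(v):
        children = tree.children(v)
        if len(children) == 0:  # leaf: point mass on the observed state
            vec = np.zeros(num_states)
            vec[int(tree.get_state(v))] = 1.0
            return vec
        vec = np.ones(num_states)
        for c in children:
            t = tree.get_age(v) - tree.get_age(c)  # branch length
            P = transition_matrix(t)  # hypothetical: P[s, s'] over length t
            vec *= P @ partial(c)  # sum over child states, per parent state
        return vec

    return float(np.log(root_prior @ partial(tree.root())))
```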


class IIDExponentialBLEGridSearchCV(BranchLengthEstimator):
    r"""
    Like IIDExponentialBLE but with automatic tuning of hyperparameters.

    This class fits the hyperparameters of IIDExponentialBLE based on
    character-level held-out log-likelihood. It leaves out one character at a
    time, fitting the data on all the remaining characters. Thus, the number
    of models trained by this class is #characters * grid size.

    Args:
        minimum_branch_lengths: The grid of minimum_branch_length to use.
        l2_regularizations: The grid of l2_regularization to use.
        verbose: Verbosity level.
    """

Reviewer: I typically like to have one file per class. It's easier for me to track things and I think it makes things easier to review in the future. What do you think?

Author: Sounds good! I used to follow your pattern until very recently, only because I started to read about people's thoughts on this, and it seems people tend to prefer the several-related-classes-per-module way (no need to read the references, they're just here for context). However, our classes have very small interfaces and large implementations, so it does seem like the one-file-per-class approach you propose would be more suitable? Citing one of the answers: "If your concrete implementations have very different internal concerns, your single file accumulates all those concerns. For example, implementations with non-overlapping dependencies make your single file depend on the union of all those dependencies. So, it might sometimes be reasonable to consider the sub-classes' coupling to their dependencies outweighs their coupling to the interface (or conversely, the concern of implementing an interface is weaker than the concerns internal to that implementation). As a specific example, take a generic database interface. Concrete implementations using an in-memory DB, an SQL RDBMS and a web query respectively may have nothing in common apart from the interface, and forcing everyone who wants the lightweight in-memory version to also import an SQL library is nasty." What do you think about having a …

Reviewer: I see the advantage of using the several-classes-per-file approach if the API and dependencies are all very similar, but yes, I feel that the models we'll be developing with the BranchLengthEstimator subclasses will all vary quite a bit in their implementations. For example, the bayesian estimators, I predict, will require quite a bit of extra implementation that's not overlapping with the maximum likelihood based estimators. So, maybe it would be worthwhile to keep everything separate. I don't think this will become unwieldy in the future, and in my opinion it helps readability and interpretability of the codebase.

    def __init__(
        self,
        minimum_branch_lengths: Tuple[float, ...] = (0,),
        l2_regularizations: Tuple[float, ...] = (0,),
        verbose: bool = False,
    ):
        self.minimum_branch_lengths = minimum_branch_lengths
        self.l2_regularizations = l2_regularizations
        self.verbose = verbose

Reviewer: If this is a grid of parameters to use, is there a reason why we're passing Tuples? I feel like there must be a publicly available hyperparameter tuning library that allows you to pass ranges of hyperparameters. That might be more useful here - thoughts?

Author: Re: tuples, isn't this the natural parameterization of a grid? E.g. … Re: hyperparameter tuning library: Yes! I agree! I was peeking a bit into sklearn to see if we could reuse anything. I just wanted to get CV up and running for experimentation purposes, so I went the fast route of just coding it myself. We should look into a generic and reusable implementation.

Reviewer: Sounds good! I had a total brainfart and was thinking you were using tuples that were of length 2. No reason why you can't use Tuples here rather than lists. But yes, let's investigate hyperparameter tuning libraries, as I imagine they'll do a lot of good stuff about exploring large parameter regimes intelligently. For example, Google has …
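For reference, the kind of off-the-shelf grid utility this thread points toward; a hedged illustration using scikit-learn's `ParameterGrid`, not something this PR adopts:

```python
from sklearn.model_selection import ParameterGrid

# Each entry of the grid is a dict of hyperparameters; the grid covers the
# full cross product of the listed values.
grid = ParameterGrid(
    {
        "minimum_branch_length": [0, 0.01, 0.1],
        "l2_regularization": [0, 1.0, 10.0],
    }
)
for params in grid:
    model = IIDExponentialBLE(**params)
    # ... fit and score each configuration, keeping the best one
```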

    def estimate_branch_lengths(self, tree: Tree) -> None:
        r"""
        See base class. The only caveat is that this method raises an error
        (specifically, a cp.error.SolverError) if it fails to solve the
        underlying optimization problem for any reason.

        Raises:
            cp.error.SolverError
        """
        # Extract parameters
        minimum_branch_lengths = self.minimum_branch_lengths
        l2_regularizations = self.l2_regularizations
        verbose = self.verbose

        held_out_log_likelihoods = []  # type: List[Tuple[float, List]]
        for minimum_branch_length in minimum_branch_lengths:
            for l2_regularization in l2_regularizations:
                cv_log_likelihood = self._cv_log_likelihood(
                    tree=tree,
                    minimum_branch_length=minimum_branch_length,
                    l2_regularization=l2_regularization,
                )
                held_out_log_likelihoods.append(
                    (
                        cv_log_likelihood,
                        [minimum_branch_length, l2_regularization],
                    )
                )

        # Refit model on full dataset with the best hyperparameters
        held_out_log_likelihoods.sort(reverse=True)
        (
            best_minimum_branch_length,
            best_l2_regularization,
        ) = held_out_log_likelihoods[0][1]
        if verbose:
            print(
                f"Refitting full model with:\n"
                f"minimum_branch_length={best_minimum_branch_length}\n"
                f"l2_regularization={best_l2_regularization}"
            )
        final_model = IIDExponentialBLE(
            minimum_branch_length=best_minimum_branch_length,
            l2_regularization=best_l2_regularization,
        )
        final_model.estimate_branch_lengths(tree)
        self.minimum_branch_length = best_minimum_branch_length
        self.l2_regularization = best_l2_regularization
        self.log_likelihood = final_model.log_likelihood
        self.log_loss = final_model.log_loss

    def _cv_log_likelihood(
        self, tree: Tree, minimum_branch_length: float, l2_regularization: float
    ) -> float:
        r"""
        Given the tree and the parameters of the model, returns the
        cross-validated log-likelihood of the model. This is done by holding
        out one character at a time, fitting the model on the remaining
        characters, and evaluating the log-likelihood on the held-out
        character. As a consequence, #characters models are fit by this
        method. The mean held-out log-likelihood over the #characters folds
        is returned.
        """
        verbose = self.verbose
        if verbose:
            print(
                f"Cross-validating hyperparameters:"
                f"\nminimum_branch_length={minimum_branch_length}"
                f"\nl2_regularization={l2_regularization}"
            )
        n_characters = tree.num_characters()
        log_likelihood_folds = np.zeros(shape=(n_characters,))
        for held_out_character_idx in range(n_characters):
            tree_train, tree_valid = self._cv_split(
                tree=tree, held_out_character_idx=held_out_character_idx
            )
            try:
                IIDExponentialBLE(
                    minimum_branch_length=minimum_branch_length,
                    l2_regularization=l2_regularization,
                ).estimate_branch_lengths(tree_train)
                tree_valid.copy_branch_lengths(tree_other=tree_train)
                held_out_log_likelihood = IIDExponentialBLE.log_likelihood(
                    tree_valid
                )
            except cp.error.SolverError:
                held_out_log_likelihood = -np.inf

            log_likelihood_folds[
                held_out_character_idx
            ] = held_out_log_likelihood
        if verbose:
            print(f"log_likelihood_folds = {log_likelihood_folds}")
            print(
                f"mean log_likelihood_folds = "
                f"{np.mean(log_likelihood_folds)}"
            )
        return np.mean(log_likelihood_folds)

    def _cv_split(
        self, tree: Tree, held_out_character_idx: int
    ) -> Tuple[Tree, Tree]:
        r"""
        Creates a training and a cross validation tree by hiding the
        character at position held_out_character_idx.
        """
        tree_train = copy.deepcopy(tree)
        tree_valid = copy.deepcopy(tree)
        for node in tree.nodes():
            state = tree_train.get_state(node)
            train_state = (
                state[:held_out_character_idx]
                + state[(held_out_character_idx + 1) :]
            )
            valid_data = state[held_out_character_idx]
            tree_train.set_state(node, train_state)
            tree_valid.set_state(node, valid_data)
        return tree_train, tree_valid
Reviewer: We'll need to add this to the setup.py; my tests break because this wasn't installed during `make install`.

Author: Agreed! QQ: I noticed many of the requirements have a minimum version requirement, e.g. numpy > 1.17 - is there a reason for that? I was thinking to just add cvxpy without a version requirement.

Reviewer: Having version requirements is helpful for a couple of reasons:
(1) Sometimes, APIs are different between versions (e.g., networkx changes the way it calls some functionality in versions <2 and >2, I believe).
(2) Sometimes, earlier versions of packages throw warnings (this is especially relevant if, for example, a particular package relies on another package that has updated and is deprecating some feature). Requiring a minimal version (or maximal version, for that matter) may help you get rid of warnings being thrown around.
In general, it's likely that having an outdated version of a particular package will be fine, and the worst thing that will happen is that a warning might be thrown. However, I think it's good for software reproducibility, so that when a new user comes along they can create a very similar environment to the developer. Hope this makes sense!
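A hedged sketch of the setup.py change being requested here. The package name and the version floors for cvxpy and networkx are illustrative assumptions; pin whatever versions are actually tested:

```python
from setuptools import setup

setup(
    name="cassiopeia",  # assumed package name
    install_requires=[
        "networkx>=2.0",  # the <2 vs >=2 API split mentioned above
        "numpy>=1.17",
        "cvxpy>=1.1",  # assumed floor; choose the version actually tested
    ],
)
```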