diff --git a/cassiopeia/data/CassiopeiaTree.py b/cassiopeia/data/CassiopeiaTree.py index af5523fb..6eb23b50 100644 --- a/cassiopeia/data/CassiopeiaTree.py +++ b/cassiopeia/data/CassiopeiaTree.py @@ -391,10 +391,31 @@ def internal_nodes(self) -> List[str]: if "internal_nodes" not in self.__cache: self.__cache["internal_nodes"] = [ - n for n in self.__network if self.__network.out_degree(n) > 1 + n for n in self.__network if self.__network.out_degree(n) > 0 ] return self.__cache["internal_nodes"][:] + @property + def non_root_internal_nodes(self) -> List[str]: + """Returns internal nodes in tree (excluding the root). + + Returns: + The internal nodes of the tree that are not the root (i.e. all + nodes not at the leaves, and not the root) + + Raises: + CassiopeiaTreeError if the tree has not been initialized. + """ + self.__check_network_initialized() + + if "non_root_internal_nodes" not in self.__cache: + res = [ + n for n in self.__network if self.__network.out_degree(n) > 0 + ] + res.remove(self.root) + self.__cache["non_root_internal_nodes"] = res + return self.__cache["non_root_internal_nodes"][:] + @property def nodes(self) -> List[str]: """Returns all nodes in tree. @@ -465,12 +486,16 @@ def is_internal_node(self, node: str) -> bool: self.__check_network_initialized() return self.__network.out_degree(node) > 0 - def reconstruct_ancestral_characters(self) -> None: + def reconstruct_ancestral_characters(self, zero_the_root: bool = False): """Reconstruct ancestral character states. Reconstructs ancestral states (i.e., those character states in the internal nodes) using the Camin-Sokal parsimony criterion (i.e., irreversibility). Operates on the tree in place. + + Args: + zero_the_root: If True, the root will be forced to have unmutated + chracters. """ self.__check_network_initialized() @@ -484,6 +509,9 @@ def reconstruct_ancestral_characters(self) -> None: ) self.__set_character_states(n, reconstructed) + if zero_the_root: + self.__set_character_states(self.root, [0] * self.n_character) + def parent(self, node: str) -> str: """Gets the parent of a node. @@ -535,23 +563,25 @@ def set_time(self, node: str, new_time: float) -> None: """ self.__check_network_initialized() - parent = self.parent(node) - if new_time < self.get_time(parent): - raise CassiopeiaTreeError( - "New age is less than the age of the parent." - ) + if node != self.root: + parent = self.parent(node) + if new_time < self.get_time(parent): + raise CassiopeiaTreeError( + "New time is less than the time of the parent." + ) for child in self.children(node): if new_time > self.get_time(child): raise CassiopeiaTreeError( - "New age is greater than than a child." + "New time is greater than than a child." ) self.__network.nodes[node]["time"] = new_time - self.__network[parent][node]["length"] = new_time - self.get_time( - parent - ) + if node != self.root: + self.__network[parent][node]["length"] = new_time - self.get_time( + parent + ) for child in self.children(node): self.__network[node][child]["length"] = ( self.get_time(child) - new_time @@ -712,6 +742,8 @@ def get_character_states(self, node: str) -> List[int]: Returns: The full character state array of the specified node. """ + if node not in self.__network.nodes: + raise CassiopeiaTreeError(f"Node {node} does not exist!") return self.__network.nodes[node]["character_states"][:] def depth_first_traverse_nodes( @@ -817,6 +849,8 @@ def get_mutations_along_edge( Returns a list of tuples (character, state) of mutations that occur along an edge. Characters are 0-indexed. 
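A minimal usage sketch of the accessors added above (the toy topology, character matrix, and expected outputs are illustrative, not taken from this PR; it assumes leaf states are initialized from the character matrix at construction, as done by BLEMultifurcationWrapper later in this diff):

import networkx as nx
import pandas as pd
from cassiopeia.data import CassiopeiaTree

topology = nx.DiGraph()
topology.add_edges_from([("root", "anc"), ("anc", "a"), ("anc", "b")])
character_matrix = pd.DataFrame([[1, 0], [1, 2]], index=["a", "b"])
tree = CassiopeiaTree(character_matrix=character_matrix, tree=topology)
tree.reconstruct_ancestral_characters(zero_the_root=True)
print(tree.non_root_internal_nodes)                        # ["anc"]
print(tree.get_mutations_along_edge("root", "anc"))        # [(0, 1)] under Camin-Sokal parsimony
print(tree.get_number_of_mutated_characters_in_node("b"))  # 2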
+ WARNING: A character dropout event will also be considered a mutation! + Args: parent: parent in tree child: child in tree @@ -844,6 +878,26 @@ def get_mutations_along_edge( return mutations + def get_number_of_mutations_along_edge( + self, parent: str, child: str + ) -> int: + return len(self.get_mutations_along_edge(parent, child)) + + def get_number_of_unmutated_characters_in_node( + self, node: str + ) -> int: + states = self.get_character_states(node) + return states.count(0) + + def get_number_of_mutated_characters_in_node( + self, node: str + ) -> int: + r""" + WARNING: dropped out characters will be considered as mutated too! + """ + return self.n_character -\ + self.get_number_of_unmutated_characters_in_node(node) + def relabel_nodes(self, relabel_map: Dict[str, str]) -> None: """Relabels the nodes in the tree. @@ -954,3 +1008,23 @@ def compute_dissimilarity_map( ) self.set_dissimilarity_map(dissimilarity_map) + + def scale_to_unit_length(self) -> None: + r""" + Scales the tree to have unit length. I.e. the longest path from root to + leaf will have length 1 after the scaling. + """ + times = {} + max_time = max(self.get_times().values()) + for node in self.nodes: + times[node] = self.get_time(node) / max_time + self.set_times(times) + + +def resolve_multifurcations(tree: CassiopeiaTree) -> None: + r""" + Resolves the multifurcations of the CassiopeiaTree inplace. + """ + binary_topology = utilities.resolve_multifurcations_networkx( + tree.get_tree_topology()) + tree.populate_tree(binary_topology) diff --git a/cassiopeia/data/__init__.py b/cassiopeia/data/__init__.py index 69ba29f2..33c7d542 100644 --- a/cassiopeia/data/__init__.py +++ b/cassiopeia/data/__init__.py @@ -1,7 +1,8 @@ """Top level for data.""" -from .CassiopeiaTree import CassiopeiaTree +from .CassiopeiaTree import CassiopeiaTree, resolve_multifurcations from .utilities import ( + resolve_multifurcations_networkx, sample_bootstrap_allele_tables, sample_bootstrap_character_matrices, to_newick, diff --git a/cassiopeia/data/utilities.py b/cassiopeia/data/utilities.py index ddd9ad1a..3d08cf22 100644 --- a/cassiopeia/data/utilities.py +++ b/cassiopeia/data/utilities.py @@ -1,6 +1,8 @@ """ General utilities for the datasets encountered in Cassiopeia. """ +import copy +from queue import PriorityQueue from typing import Callable, Dict, List, Optional, Tuple import ete3 @@ -318,3 +320,96 @@ def sample_bootstrap_allele_tables( ) return bootstrap_samples + + +def resolve_multifurcations_networkx(tree: nx.DiGraph) -> nx.DiGraph: + r""" + Given a tree represented by a networkx DiGraph, it resolves + multifurcations. The tree is NOT modified in-place. + The root is made to have only one children, as in a real-life tumor + (the founding cell never divides immediately!) + """ + tree = copy.deepcopy(tree) + node_names = set([n for n in tree]) + root = [n for n in tree if tree.in_degree(n) == 0][0] + subtree_sizes = {} + _dfs_subtree_sizes(tree, subtree_sizes, root) + assert len(subtree_sizes) == len([n for n in tree]) + + # First make the root have degree 1. 
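A hedged sketch (toy node names; the helper invents intermediate names such as "0-child" and "0-child-coalesce-0") of how resolve_multifurcations_networkx is expected to behave on a small multifurcating tree:

import networkx as nx
from cassiopeia.data import resolve_multifurcations_networkx

multifurcating = nx.DiGraph()
multifurcating.add_edges_from([("0", "1"), ("0", "2"), ("0", "3"), ("0", "4")])
binary = resolve_multifurcations_networkx(multifurcating)
# The result is binary: every node has out-degree 0 or 2, except the root,
# which is given a single child so that the founding cell does not divide immediately.
assert len(binary.nodes) == len(binary.edges) + 1
assert all(binary.out_degree(n) in (0, 2) for n in binary.nodes if n != "0")
assert binary.out_degree("0") == 1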
+ if tree.out_degree(root) >= 2: + children = list(tree.successors(root)) + assert len(children) == tree.out_degree(root) + # First remove the edges from the root + tree.remove_edges_from([(root, child) for child in children]) + # Now create the intermediate node and add edges back + root_child = f"{root}-child" + if root_child in node_names: + raise RuntimeError("Node name already exists!") + tree.add_edge(root, root_child) + tree.add_edges_from([(root_child, child) for child in children]) + + def _dfs_resolve_multifurcations(tree, v): + children = list(tree.successors(v)) + if len(children) >= 3: + # Must resolve the multifurcation + _resolve_multifurcation(tree, v, subtree_sizes, node_names) + for child in children: + _dfs_resolve_multifurcations(tree, child) + + _dfs_resolve_multifurcations(tree, root) + # Check that the tree is binary + if not (len(tree.nodes) == len(tree.edges) + 1): + raise RuntimeError("Failed to binarize tree") + return tree + + +def _resolve_multifurcation(tree, v, subtree_sizes, node_names): + r""" + node_names is used to make sure we don't create a node name that already + exists. + """ + children = list(tree.successors(v)) + n_children = len(children) + assert n_children >= 3 + + # Remove all edges from v to its children + tree.remove_edges_from([(v, child) for child in children]) + + # Create the new binary structure + queue = PriorityQueue() + for child in children: + queue.put((subtree_sizes[child], child)) + + for i in range(n_children - 2): + # Coalesce two smallest subtrees + subtree_1_size, subtree_1_root = queue.get() + subtree_2_size, subtree_2_root = queue.get() + assert subtree_1_size <= subtree_2_size + coalesced_tree_size = subtree_1_size + subtree_2_size + 1 + coalesced_tree_root = f"{v}-coalesce-{i}" + if coalesced_tree_root in node_names: + raise RuntimeError("Node name already exists!") + # For debugging: + # print(f"Coalescing {subtree_1_root} (sz {subtree_1_size}) and" + # f" {subtree_2_root} (sz {subtree_2_size})") + tree.add_edges_from( + [ + (coalesced_tree_root, subtree_1_root), + (coalesced_tree_root, subtree_2_root), + ] + ) + queue.put((coalesced_tree_size, coalesced_tree_root)) + # Hang the two subtrees obtained to v + subtree_1_size, subtree_1_root = queue.get() + subtree_2_size, subtree_2_root = queue.get() + assert subtree_1_size <= subtree_2_size + tree.add_edges_from([(v, subtree_1_root), (v, subtree_2_root)]) + + +def _dfs_subtree_sizes(tree, subtree_sizes, v) -> int: + res = 1 + for child in tree.successors(v): + res += _dfs_subtree_sizes(tree, subtree_sizes, child) + subtree_sizes[v] = res + return res diff --git a/cassiopeia/solver/ResolveMultifurcationsWrapper.py b/cassiopeia/solver/ResolveMultifurcationsWrapper.py new file mode 100644 index 00000000..87b29ea5 --- /dev/null +++ b/cassiopeia/solver/ResolveMultifurcationsWrapper.py @@ -0,0 +1,27 @@ +import copy +from cassiopeia.data import CassiopeiaTree, resolve_multifurcations +from cassiopeia.solver import CassiopeiaSolver + + +# https://stackoverflow.com/questions/1443129/completely-wrap-an-object-in-python +# permalink: https://stackoverflow.com/a/1445289 +class ResolveMultifurcationsWrapper(CassiopeiaSolver.CassiopeiaSolver): + r""" + Wraps a CassiopeiaSolver. + The wrapped solver is used to solve for the tree topology, after which + the multifurcations are resolved. 
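A hedged usage sketch of the wrapper (the choice of inner solver and its default constructor are assumptions, not prescribed by this PR): the wrapped solver infers the topology, after which the multifurcations are resolved in place on the same CassiopeiaTree.

from cassiopeia.solver import ResolveMultifurcationsWrapper, VanillaGreedySolver

solver = ResolveMultifurcationsWrapper(VanillaGreedySolver())
# solver.solve(tree)  # `tree`: a CassiopeiaTree with a character matrix; its topology becomes binary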
+ """ + + def __init__(self, solver: CassiopeiaSolver.CassiopeiaSolver): + solver = copy.deepcopy(solver) + self.__class__ = type( + solver.__class__.__name__, + (self.__class__, solver.__class__), + {}, + ) + self.__dict__ = solver.__dict__ + self.__solver = solver + + def solve(self, tree: CassiopeiaTree) -> None: + self.__solver.solve(tree) + resolve_multifurcations(tree) diff --git a/cassiopeia/solver/StringifyNodeNamesWrapper.py b/cassiopeia/solver/StringifyNodeNamesWrapper.py new file mode 100644 index 00000000..71d519e4 --- /dev/null +++ b/cassiopeia/solver/StringifyNodeNamesWrapper.py @@ -0,0 +1,33 @@ +import copy +from cassiopeia.data import CassiopeiaTree +from cassiopeia.solver import CassiopeiaSolver + + +# https://stackoverflow.com/questions/1443129/completely-wrap-an-object-in-python +# permalink: https://stackoverflow.com/a/1445289 +class StringifyNodeNamesWrapper(CassiopeiaSolver.CassiopeiaSolver): + r""" + Wraps a CassiopeiaSolver. + The wrapped solver is used to solve for the tree topology, after which + the node names are cast to string. This is because some solvers create + nodes with integer IDs, and this can break downstream code. + """ + def __init__(self, solver: CassiopeiaSolver.CassiopeiaSolver): + solver = copy.deepcopy(solver) + self.__class__ = type( + solver.__class__.__name__, + (self.__class__, solver.__class__), + {}, + ) + self.__dict__ = solver.__dict__ + self.__solver = solver + + def solve(self, tree: CassiopeiaTree) -> None: + self.__solver.solve(tree) + relabel_map = {node: 'internal-' + str(node) for node in + tree.internal_nodes} + num_nodes_before = len(tree.nodes) + tree.relabel_nodes(relabel_map) + num_nodes_after = len(tree.nodes) + if num_nodes_before != num_nodes_after: + raise RuntimeError("There was a colision stringifying node names.") diff --git a/cassiopeia/solver/__init__.py b/cassiopeia/solver/__init__.py index bf8d2640..14a7c101 100755 --- a/cassiopeia/solver/__init__.py +++ b/cassiopeia/solver/__init__.py @@ -10,7 +10,9 @@ from .MaxCutSolver import MaxCutSolver from .NeighborJoiningSolver import NeighborJoiningSolver from .solver_utilities import collapse_tree, collapse_unifurcations +from .ResolveMultifurcationsWrapper import ResolveMultifurcationsWrapper from .SpectralGreedySolver import SpectralGreedySolver from .SpectralSolver import SpectralSolver +from .StringifyNodeNamesWrapper import StringifyNodeNamesWrapper from .VanillaGreedySolver import VanillaGreedySolver from . 
import dissimilarity_functions as dissimilarity diff --git a/cassiopeia/tools/__init__.py b/cassiopeia/tools/__init__.py new file mode 100644 index 00000000..44dcb6b8 --- /dev/null +++ b/cassiopeia/tools/__init__.py @@ -0,0 +1,20 @@ +from .branch_length_estimator import ( + BranchLengthEstimator, + BLEMultifurcationWrapper, + IIDExponentialBLE, + IIDExponentialBLEGridSearchCV, + IIDExponentialPosteriorMeanBLE, + IIDExponentialPosteriorMeanBLEGridSearchCV, +) +from .lineage_simulator import ( + BirthProcess, + LineageSimulator, + PerfectBinaryTree, + PerfectBinaryTreeWithRootBranch, + TumorWithAFitSubclone, +) +from .lineage_tracing_simulator import ( + LineageTracingSimulator, + IIDExponentialLineageTracer, +) +from .cell_subsampler import CellSubsampler, UniformCellSubsampler diff --git a/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py b/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py new file mode 100644 index 00000000..4c1b8ece --- /dev/null +++ b/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py @@ -0,0 +1,44 @@ +import copy +from cassiopeia.data import CassiopeiaTree, resolve_multifurcations_networkx +from .BranchLengthEstimator import BranchLengthEstimator + + +# https://stackoverflow.com/questions/1443129/completely-wrap-an-object-in-python +# permalink: https://stackoverflow.com/a/1445289 +class BLEMultifurcationWrapper(BranchLengthEstimator): + r""" + Wraps a BranchLengthEstimator. + When estimating branch lengths: + 1) the tree topology is first copied out + 2) then multifurcations in the tree topology are resolved into a binary + structure, + 3) then branch lengths are estimated on this binary topology + 4) finally, the node ages are copied back onto the original tree. + Maximum Parsimony will be used to reconstruct the ancestral states. + """ + + def __init__(self, ble_model: BranchLengthEstimator): + ble_model = copy.deepcopy(ble_model) + self.__class__ = type( + ble_model.__class__.__name__, + (self.__class__, ble_model.__class__), + {}, + ) + self.__dict__ = ble_model.__dict__ + self.__ble_model = ble_model + + def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: + binary_topology = resolve_multifurcations_networkx( + tree.get_tree_topology() + ) + # For debugging: + # print(f"binary_topology = {binary_topology.__dict__}") + tree_binary = CassiopeiaTree( + character_matrix=tree.get_current_character_matrix(), + tree=binary_topology, + ) + tree_binary.reconstruct_ancestral_characters(zero_the_root=True) + self.__ble_model.estimate_branch_lengths(tree_binary) + # Copy the times from the binary tree onto the original tree + times = dict([(v, tree_binary.get_time(v)) for v in tree.nodes]) + tree.set_times(times) diff --git a/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py b/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py new file mode 100644 index 00000000..96bf76af --- /dev/null +++ b/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py @@ -0,0 +1,33 @@ +import abc + +from cassiopeia.data import CassiopeiaTree + + +class BranchLengthEstimatorError(Exception): + """An Exception class for the BranchLengthEstimator class.""" + + pass + + +class BranchLengthEstimator(abc.ABC): + r""" + Abstract base class for all branch length estimators. 
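The wrappers in this PR all rely on the object-wrapping idiom from the linked Stack Overflow answer. A self-contained toy illustration (the class names here are invented, not part of the Cassiopeia API): the wrapper re-parents its own class so isinstance checks against the wrapped class still pass, shares the wrapped object's __dict__, and forwards calls explicitly.

import copy

class Wrapped:
    def greet(self) -> str:
        return "hello"

class Wrapper:
    def __init__(self, inner):
        inner = copy.deepcopy(inner)
        self.__class__ = type(
            inner.__class__.__name__, (self.__class__, inner.__class__), {}
        )
        self.__dict__ = inner.__dict__
        self.__inner = inner  # name-mangled, so it does not collide with inner's attributes

    def greet(self) -> str:
        # Forward to the wrapped object, adding behavior around the call.
        return self.__inner.greet().upper()

w = Wrapper(Wrapped())
assert isinstance(w, Wrapped)  # the dynamic re-parenting preserves isinstance checks
assert w.greet() == "HELLO"    # the wrapper's override forwards to the wrapped object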
+ + A BranchLengthEstimator implements a method estimate_branch_lengths which, + given a Tree with lineage tracing character vectors at the leaves (and + possibly at the internal nodes too), estimates the branch lengths of the + tree. + """ + + @abc.abstractmethod + def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: + r""" + Estimates the branch lengths of the tree. + + Annotates the tree's nodes with their estimated age, and + the tree's branches with their estimated lengths. Operates on the tree + in-place. + + Args: + tree: The tree for which to estimate branch lengths. + """ diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py new file mode 100644 index 00000000..41644a0a --- /dev/null +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py @@ -0,0 +1,359 @@ +import multiprocessing +from typing import List, Optional, Tuple + +import cvxpy as cp +import numpy as np + +from cassiopeia.data import CassiopeiaTree +from .BranchLengthEstimator import BranchLengthEstimator +from . import utils + + +class IIDExponentialBLE(BranchLengthEstimator): + r""" + A simple branch length estimator that assumes that the characters evolve IID + over the phylogeny with the same cutting rate. + + This estimator requires that the ancestral states are provided. + + The optimization problem is a special kind of convex program called an + exponential cone program: + https://docs.mosek.com/modeling-cookbook/expo.html + Because it is a convex program, it can be readily solved. + + Args: + minimum_branch_length: Estimated branch lengths will be constrained to + have at least length THIS MULTIPLE OF THE TREE HEIGHT. If this is + greater than 1.0 / [height of the tree] (where the height + is measured in terms of the greatest number of edges of any lineage) + then all edges will have length 0, so be careful! + l2_regularization: Consecutive branches will be regularized to have + similar length via an L2 penalty whose weight is given by + l2_regularization. + verbose: Verbosity level. + + Attributes: + log_likelihood: The log-likelihood of the training data under the + estimated model. + log_loss: The log-loss of the training data under the estimated model. + This is the log likelihood plus the regularization terms. + """ + + def __init__( + self, + minimum_branch_length: float = 0, + l2_regularization: float = 0, + verbose: bool = False, + ): + self.minimum_branch_length = minimum_branch_length + self.l2_regularization = l2_regularization + self.verbose = verbose + + def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: + r""" + See base class. The only caveat is that this method raises if it fails + to solve the underlying optimization problem for any reason. + + Raises: + cp.error.SolverError + """ + # Extract parameters + minimum_branch_length = self.minimum_branch_length + l2_regularization = self.l2_regularization + verbose = self.verbose + + # # Wrap the networkx DiGraph for goodies. 
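A hedged sketch (the estimator name and its depth rule are invented for illustration) of the minimal contract an implementation of the abstract base class satisfies: annotate node times in place via set_times.

from cassiopeia.data import CassiopeiaTree
from cassiopeia.tools import BranchLengthEstimator

class UnitDepthBLE(BranchLengthEstimator):
    """Toy estimator: each node's time is its edge-depth from the root."""

    def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None:
        times = {tree.root: 0.0}
        # Parents are visited before their children in a depth-first edge traversal.
        for (parent, child) in tree.depth_first_traverse_edges():
            times[child] = times[parent] + 1.0
        tree.set_times(times)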
+ # tree = Tree(tree) + + # # # # # Create variables of the optimization problem # # # # # + r_X_t_variables = dict( + [ + (node_id, cp.Variable(name=f"r_X_t_{node_id}")) + for node_id in tree.nodes + ] + ) + a_leaf = tree.leaves[0] + root = tree.root + root_has_time_0_constraint = [r_X_t_variables[root] == 0] + time_increases_constraints = [ + r_X_t_variables[child] + >= r_X_t_variables[parent] + + minimum_branch_length * r_X_t_variables[a_leaf] + for (parent, child) in tree.edges + ] + leaves_have_same_time_constraints = [ + r_X_t_variables[leaf] == r_X_t_variables[a_leaf] + for leaf in tree.leaves + if leaf != a_leaf + ] + non_negative_r_X_t_constraints = [ + r_X_t >= 0 for r_X_t in r_X_t_variables.values() + ] + all_constraints = ( + root_has_time_0_constraint + + time_increases_constraints + + leaves_have_same_time_constraints + + non_negative_r_X_t_constraints + ) + + # # # # # Compute the log-likelihood # # # # # + log_likelihood = 0 + + # Because all rates are equal, the number of cuts in each node is a + # sufficient statistic. This makes the solver WAY faster! + for (parent, child) in tree.edges: + edge_length = r_X_t_variables[child] - r_X_t_variables[parent] + # TODO: hardcoded '0' here... + zeros_parent = tree.get_number_of_unmutated_characters_in_node( + parent + ) + zeros_child = tree.get_number_of_unmutated_characters_in_node(child) + new_cuts_child = zeros_parent - zeros_child + assert new_cuts_child >= 0 + # Add log-lik for characters that didn't get cut + log_likelihood += zeros_child * (-edge_length) + # Add log-lik for characters that got cut + log_likelihood += new_cuts_child * cp.log( + 1 - cp.exp(-edge_length - 1e-8) + ) + + # # # # # Add regularization # # # # # + + l2_penalty = 0 + for (parent, child) in tree.edges: + for child_of_child in tree.children(child): + edge_length_above = ( + r_X_t_variables[child] - r_X_t_variables[parent] + ) + edge_length_below = ( + r_X_t_variables[child_of_child] - r_X_t_variables[child] + ) + l2_penalty += (edge_length_above - edge_length_below) ** 2 + l2_penalty *= l2_regularization + + # # # # # Solve the problem # # # # # + + obj = cp.Maximize(log_likelihood - l2_penalty) + prob = cp.Problem(obj, all_constraints) + + f_star = prob.solve(solver="ECOS", verbose=verbose) + + # # # # # Populate the tree with the estimated branch lengths # # # # # + + times = {node: r_X_t_variables[node].value for node in tree.nodes} + # We smooth out epsilons that might make a parent's time greater + # than its child + for (parent, child) in tree.depth_first_traverse_edges(): + times[child] = max(times[parent], times[child]) + tree.set_times(times) + + log_likelihood = log_likelihood.value + log_loss = f_star + if np.isnan(log_likelihood): + log_likelihood = -np.inf + if np.isnan(log_loss): + log_loss = -np.inf + self.log_likelihood = log_likelihood + self.log_loss = log_loss + + @classmethod + def log_likelihood(self, tree: CassiopeiaTree) -> float: + r""" + The log-likelihood of the given tree under the model. 
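A minimal, self-contained cvxpy sketch (toy counts, not the Cassiopeia API) of the per-edge term the convex program above maximizes: with z characters still uncut in the child and m newly cut characters, an edge of length t contributes -z*t + m*log(1 - exp(-t)), which is concave in t and is handled via the exponential cone.

import cvxpy as cp
import numpy as np

z, m = 5, 3  # uncut characters in the child, new cuts along the edge (toy values)
t = cp.Variable(nonneg=True)
log_lik = z * (-t) + m * cp.log(1 - cp.exp(-t - 1e-8))
cp.Problem(cp.Maximize(log_lik)).solve()
print(t.value, np.log((z + m) / z))  # the solution matches the closed form log((z + m) / z), about 0.47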
+ """ + log_likelihood = 0.0 + for (parent, child) in tree.edges: + edge_length = tree.get_branch_length(parent, child) + n_mutated = tree.get_number_of_mutations_along_edge(parent, child) + n_nonmutated = tree.get_number_of_unmutated_characters_in_node( + child + ) + assert n_mutated >= 0 and n_nonmutated >= 0 + # Add log-lik for characters that didn't get cut + log_likelihood += n_nonmutated * (-edge_length) + # Add log-lik for characters that got cut + if n_mutated > 0: + if edge_length < 1e-8: + return -np.inf + log_likelihood += n_mutated * np.log(1 - np.exp(-edge_length)) + assert not np.isnan(log_likelihood) + return log_likelihood + + +class IIDExponentialBLEGridSearchCV(BranchLengthEstimator): + r""" + Like IIDExponentialBLE but with automatic tuning of hyperparameters. + + This class fits the hyperparameters of IIDExponentialBLE based on + character-level held-out log-likelihood. It leaves out one character at a + time, fitting the data on all the remaining characters. Thus, the number + of models trained by this class is #characters * grid size. + + Args: + minimum_branch_lengths: The grid of minimum_branch_length to use. + l2_regularizations: The grid of l2_regularization to use. + verbose: Verbosity level. + """ + + def __init__( + self, + minimum_branch_lengths: Tuple[float] = (0,), + l2_regularizations: Tuple[float] = (0,), + processes: int = 6, + verbose: bool = False, + ): + self.minimum_branch_lengths = minimum_branch_lengths + self.l2_regularizations = l2_regularizations + self.processes = processes + self.verbose = verbose + + def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: + r""" + See base class. The only caveat is that this method raises if it fails + to solve the underlying optimization problem for any reason. 
+ + Raises: + cp.error.SolverError + """ + # Extract parameters + minimum_branch_lengths = self.minimum_branch_lengths + l2_regularizations = self.l2_regularizations + verbose = self.verbose + + held_out_log_likelihoods = [] # type: List[Tuple[float, List]] + grid = np.zeros( + shape=(len(minimum_branch_lengths), len(l2_regularizations)) + ) + for i, minimum_branch_length in enumerate(minimum_branch_lengths): + for j, l2_regularization in enumerate(l2_regularizations): + cv_log_likelihood = self._cv_log_likelihood( + tree=tree, + minimum_branch_length=minimum_branch_length, + l2_regularization=l2_regularization, + ) + held_out_log_likelihoods.append( + ( + cv_log_likelihood, + [minimum_branch_length, l2_regularization], + ) + ) + grid[i, j] = cv_log_likelihood + + # Refit model on full dataset with the best hyperparameters + held_out_log_likelihoods.sort(reverse=True) + ( + best_minimum_branch_length, + best_l2_regularization, + ) = held_out_log_likelihoods[0][1] + if verbose: + print( + f"Refitting full model with:\n" + f"minimum_branch_length={best_minimum_branch_length}\n" + f"l2_regularization={best_l2_regularization}" + ) + final_model = IIDExponentialBLE( + minimum_branch_length=best_minimum_branch_length, + l2_regularization=best_l2_regularization, + ) + final_model.estimate_branch_lengths(tree) + self.minimum_branch_length = best_minimum_branch_length + self.l2_regularization = best_l2_regularization + self.log_likelihood = final_model.log_likelihood + self.log_loss = final_model.log_loss + self.grid = grid + + def _cv_log_likelihood( + self, + tree: CassiopeiaTree, + minimum_branch_length: float, + l2_regularization: float, + ) -> float: + r""" + Given the tree and the parameters of the model, returns the + cross-validated log-likelihood of the model. This is done by holding out + one character at a time, fitting the model on the remaining characters, + and evaluating the log-likelihood on the held-out character. As a + consequence, #character models are fit by this method. The mean held-out + log-likelihood over the #character folds is returned. + """ + verbose = self.verbose + processes = self.processes + if verbose: + print( + f"Cross-validating hyperparameters:" + f"\nminimum_branch_length={minimum_branch_length}" + f"\nl2_regularizations={l2_regularization}" + ) + n_characters = tree.n_character + params = [] + for held_out_character_idx in range(n_characters): + train_tree, valid_tree = self._cv_split( + tree=tree, held_out_character_idx=held_out_character_idx + ) + model = IIDExponentialBLE( + minimum_branch_length=minimum_branch_length, + l2_regularization=l2_regularization, + ) + params.append((model, train_tree, valid_tree)) + with multiprocessing.Pool(processes=processes) as pool: + map_fn = pool.map if processes > 1 else map + log_likelihood_folds = list(map_fn(_fit_model, params)) + if verbose: + print(f"log_likelihood_folds = {log_likelihood_folds}") + return np.mean(np.array(log_likelihood_folds)) + + def _cv_split( + self, tree: CassiopeiaTree, held_out_character_idx: int + ) -> Tuple[CassiopeiaTree, CassiopeiaTree]: + r""" + Creates a training and a cross validation tree by hiding the + character at position held_out_character_idx. 
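A tiny illustration (toy character states) of the split performed by _cv_split below: the character at held_out_character_idx goes to the validation tree and the remaining characters go to the training tree.

state = [1, 0, 2, 0]
held_out_character_idx = 2
train_state = state[:held_out_character_idx] + state[(held_out_character_idx + 1):]
valid_state = [state[held_out_character_idx]]
assert train_state == [1, 0, 0] and valid_state == [2]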
+ """ + tree_topology = tree.get_tree_topology() + train_states = {} + valid_states = {} + for node in tree.nodes: + state = tree.get_character_states(node) + train_state = ( + state[:held_out_character_idx] + + state[(held_out_character_idx + 1) :] + ) + valid_state = [state[held_out_character_idx]] + train_states[node] = train_state + valid_states[node] = valid_state + train_tree = CassiopeiaTree(tree=tree_topology) + valid_tree = CassiopeiaTree(tree=tree_topology) + train_tree.initialize_all_character_states(train_states) + valid_tree.initialize_all_character_states(valid_states) + return train_tree, valid_tree + + def plot_grid( + self, figure_file: Optional[str] = None, show_plot: bool = True + ): + utils.plot_grid( + grid=self.grid, + yticklabels=self.minimum_branch_lengths, + xticklabels=self.l2_regularizations, + ylabel=r"Minimum Branch Length ($\epsilon$)", + xlabel=r"l2 Regularization ($\lambda$)", + figure_file=figure_file, + show_plot=show_plot, + ) + + +def _fit_model(args): + r""" + This is used by IIDExponentialBLEGridSearchCV to + parallelize the CV folds. It must be defined here (at the top level of + the module) for multiprocessing to be able to pickle it. (This is why + coverage misses it) + """ + model, train_tree, valid_tree = args + assert valid_tree.n_character == 1 + try: + model.estimate_branch_lengths(train_tree) + valid_tree.set_times(train_tree.get_times()) + held_out_log_likelihood = IIDExponentialBLE.log_likelihood(valid_tree) + except cp.error.SolverError: + held_out_log_likelihood = -np.inf + return held_out_log_likelihood diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp new file mode 100644 index 00000000..39eeee02 --- /dev/null +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp @@ -0,0 +1,553 @@ +#include +#include +#include +#include +#include +#include +#include + +#define forn(i, n) for(int i = 0; i < int(n); i++) +#define forall(i,c) for(typeof((c).begin()) i = (c).begin();i != (c).end();i++) + +using namespace std; +const int maxN = 8192; +const int maxK = 63; +const int maxT = 501; +const float INF = 1e16; +float _down_cache[maxN][maxT + 1][maxK + 1]; +float _up_cache[maxN][maxT + 1][maxK + 1]; +float log_joints[maxN][maxT + 1]; +float posteriors[maxN][maxT + 1]; +float posterior_means[maxN]; + +string input_dir = ""; +string output_dir = ""; +int N = -1; +vector children[maxN]; +int root = -1; +int is_internal_node[maxN]; +int get_number_of_mutated_characters_in_node[maxN]; +vector non_root_internal_nodes; +vector leaves; +int parent[maxN]; +int is_leaf[maxN]; +int K = -1; +int T = -1; +int enforce_parsimony = -1; +float r; +float lam; + +float logsumexp(const vector & lls){ + float mx = -INF; + for(auto ll: lls){ + mx = max(mx, ll); + } + float res = 0.0; + for(auto ll: lls){ + res += exp(ll - mx); + } + res = log(res) + mx; + return res; +} + +int _compatible_with_observed_data(int x, int observed_cuts){ + if(enforce_parsimony) + return x == observed_cuts; + else + return x <= observed_cuts; +} + +bool _state_is_valid(int v, int t, int x){ + if(v == root) + return x == 0; + int p = parent[v]; + int cuts_v = get_number_of_mutated_characters_in_node[v]; + int cuts_p = get_number_of_mutated_characters_in_node[p]; + if(enforce_parsimony){ + return cuts_p <= x && x <= cuts_v; + } else { + return x <= cuts_v; + } +} + + +void read_N(){ + ifstream fin(input_dir + "/N.txt"); + if(!fin.good()){ + cerr << "N 
input file not found" << endl; + exit(1); + } + fin >> N; + if(N == -1){ + cerr << "N input corrupted" << endl; + exit(1); + } + if(N > maxN){ + cerr << "N larger than maxN" << endl; + exit(1); + } +} + +void read_children(){ + ifstream fin(input_dir + "/children.txt"); + if(!fin.good()){ + cerr << "children input file not found" << endl; + exit(1); + } + int lines_read = 0; + int v; + while(fin >> v){ + lines_read++; + int n_children; + fin >> n_children; + for(int i = 0; i < n_children; i++){ + int c; + fin >> c; + children[v].push_back(c); + } + } + if(lines_read != N){ + cerr << "children input corrupted" << endl; + exit(1); + } +} + +void read_root(){ + ifstream fin(input_dir + "/root.txt"); + fin >> root; + if(root == -1){ + cerr << "N input corrupted" << endl; + exit(1); + } +} + +void read_is_internal_node(){ + ifstream fin(input_dir + "/is_internal_node.txt"); + int lines_read = 0; + int v; + while(fin >> v){ + lines_read++; + fin >> is_internal_node[v]; + } + if(lines_read != N){ + cerr << "is_internal_node input corrupted" << endl; + exit(1); + } +} + +void read_get_number_of_mutated_characters_in_node(){ + ifstream fin(input_dir + "/get_number_of_mutated_characters_in_node.txt"); + int lines_read = 0; + int v; + while(fin >> v){ + lines_read++; + fin >> get_number_of_mutated_characters_in_node[v]; + } + if(lines_read != N){ + cerr << "get_number_of_mutated_characters_in_node input corrupted" << endl; + exit(1); + } +} + +void read_non_root_internal_nodes(){ + ifstream fin(input_dir + "/non_root_internal_nodes.txt"); + int v; + while(fin >> v){ + non_root_internal_nodes.push_back(v); + } +} + +void read_leaves(){ + ifstream fin(input_dir + "/leaves.txt"); + int v; + while(fin >> v){ + leaves.push_back(v); + } + if(leaves.size() == 0){ + cerr << "leaves input corrupted" << endl; + exit(1); + } +} + +void read_parent(){ + ifstream fin(input_dir + "/parent.txt"); + int lines_read = 0; + int v; + while(fin >> v){ + lines_read++; + fin >> parent[v]; + } + if(lines_read != N - 1){ + cerr << "parent input corrupted" << endl; + exit(1); + } +} + +void read_is_leaf(){ + ifstream fin(input_dir + "/is_leaf.txt"); + int lines_read = 0; + int v; + while(fin >> v){ + lines_read++; + fin >> is_leaf[v]; + } + if(lines_read != N){ + cerr << "is_leaf input corrupted" << endl; + exit(1); + } +} + +void read_K(){ + ifstream fin(input_dir + "/K.txt"); + fin >> K; + if(K == -1){ + cerr << "K input corrupted" << endl; + exit(1); + } + if(K > maxK){ + cerr << "K larger than maxK" << endl; + exit(1); + } +} + +void read_T(){ + // T is the discretization level. + ifstream fin(input_dir + "/T.txt"); + fin >> T; + if(T == -1){ + cerr << "T input corrupted" << endl; + exit(1); + } + if(T > maxT){ + cerr << "T larger than maxT" << endl; + exit(1); + } +} + +void read_enforce_parsimony(){ + ifstream fin(input_dir + "/enforce_parsimony.txt"); + fin >> enforce_parsimony; + if(enforce_parsimony == -1){ + cerr << "enforce_parsimony input corrupted" << endl; + exit(1); + } +} + +void read_r(){ + ifstream fin(input_dir + "/r.txt"); + fin >> r; + if(r == -1){ + cerr << "r input corrupted" << endl; + exit(1); + } +} + +void read_lam(){ + ifstream fin(input_dir + "/lam.txt"); + fin >> lam; + if(lam == -1){ + cerr << "lam input corrupted" << endl; + exit(1); + } +} + +float down(int v, int t, int x){ + // Avoid doing anything at all for invalid states. 
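A hedged numeric sketch in Python (toy rates) of the stability check enforced a few lines below and mirrored in the Python implementation: with step dt = 1/T, the per-step no-event probability 1 - lam*dt - K*r*dt must stay positive, which lower-bounds the discretization level T.

lam, r, K = 1.0, 0.5, 10      # birth rate, per-character mutation rate, number of characters
min_T = int(lam + K * r) + 1  # smallest integer T with 1 - (lam + K*r)/T > 0
assert 1.0 - (lam + K * r) / min_T > 0
print(min_T)                  # 7 for these toy values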
+ if(!_state_is_valid(v, t, x)){ + return -INF; + } + if(_down_cache[v][t][x] < 1.0){ + return _down_cache[v][t][x]; + } + // Pull out params + float dt = 1.0 / T; + assert(v != root); + assert(0 <= t && t <= T); + assert(0 <= x && x <= K); + if(!(1.0 - lam * dt - K * r * dt > 0)){ + cerr << "Please choose a bigger discretization_level." << endl; + exit(1); + } + float log_likelihood = 0.0; + if(t == T){ + // Base case + if ( + is_leaf[v] + && (x == get_number_of_mutated_characters_in_node[v]) + ){ + log_likelihood = 0.0; + } else { + log_likelihood = -INF; + } + } + else{ + // Recursion. + vector log_likelihoods_cases; + // Case 1: Nothing happens + log_likelihoods_cases.push_back( + log(1.0 - lam * dt - (K - x) * r * dt) + + down(v, t + 1, x) + ); + // Case 2: One character mutates. + if(x + 1 <= K){ + log_likelihoods_cases.push_back( + log((K - x) * r * dt) + down(v, t + 1, x + 1) + ); + } + // Case 3: Cell divides + // The number of cuts at this state must match the ground truth. + if ( + _compatible_with_observed_data( + x, get_number_of_mutated_characters_in_node[v] + ) + && (!is_leaf[v]) + ){ + float ll = 0.0; + forn(i, children[v].size()){ + int child = children[v][i]; + ll += down(child, t + 1, x);// If we want to ignore missing data, we just have to replace x by x+gone_missing(p->v). I.e. dropped out characters become free mutations. + } + ll += log(lam * dt); + log_likelihoods_cases.push_back(ll); + } + log_likelihood = logsumexp(log_likelihoods_cases); + } + _down_cache[v][t][x] = log_likelihood; + return log_likelihood; +} + +float up(int v, int t, int x){ + // Avoid doing anything at all for invalid states. + if(!_state_is_valid(v, t, x)) + return -INF; + if(_up_cache[v][t][x] < 1.0){ + return _up_cache[v][t][x]; + } + // Pull out params + float dt = 1.0 / T; + assert(0 <= v && v < N); + assert(0 <= t && t <= T); + assert(0 <= x && x <= K); + if(!(1.0 - lam * dt - K * r * dt > 0)){ + cerr << "Please choose a bigger discretization_level." << endl; + exit(1); + } + float log_likelihood = 0.0; + if(v == root){ + // Base case: we reached the root of the tree. + if((t == 0) && (x == get_number_of_mutated_characters_in_node[v])) + log_likelihood = 0.0; + else + log_likelihood = -INF; + } else if(t == 0){ + // Base case: we reached the start of the process, but we're not yet + // at the root. + assert(v != root); + log_likelihood = -INF; + } else { + // Recursion. + vector log_likelihoods_cases; + // Case 1: Nothing happened + log_likelihoods_cases.push_back( + log(1.0 - lam * dt - (K - x) * r * dt) + up(v, t - 1, x) + ); + // Case 2: Mutation happened + if(x - 1 >= 0){ + log_likelihoods_cases.push_back( + log((K - (x - 1)) * r * dt) + up(v, t - 1, x - 1) + ); + } + // Case 3: A cell division happened + if(v != root){ + int p = parent[v]; + if(_compatible_with_observed_data( + x, get_number_of_mutated_characters_in_node[p] // If we want to ignore missing data, we just have to replace x by x-gone_missing(p->v). I.e. dropped out characters become free mutations. + )){ + vector siblings; + for(auto u: children[p]) + if(u != v) + siblings.push_back(u); + float ll = log(lam * dt) + up(p, t - 1, x); // If we want to ignore missing data, we just have to replace x by x-gone_missing(p->v). I.e. dropped out characters become free mutations. + for(auto u: siblings){ + ll += down(u, t, x); // If we want to ignore missing data, we just have to replace x by cuts(p)+gone_missing(p->u). I.e. dropped out characters become free mutations. 
+ } + log_likelihoods_cases.push_back(ll); + } + } + log_likelihood = logsumexp(log_likelihoods_cases); + } + _up_cache[v][t][x] = log_likelihood; + return log_likelihood; +} + +void write_down(){ + ofstream fout(output_dir + "/down.txt"); + string res = ""; + forn(v, N){ + if(v == root) continue; + forn(t, T + 1){ + forn(x, K + 1){ + if(_state_is_valid(v, t, x)){ + res += to_string(v) + " " + to_string(t) + " " + to_string(x) + " " + to_string(down(v, t, x)) + "\n"; + } + } + } + } + fout << res; +} + +void write_up(){ + ofstream fout(output_dir + "/up.txt"); + string res = ""; + forn(v, N){ + forn(t, T + 1){ + forn(x, K + 1){ + if(_state_is_valid(v, t, x)){ + res += to_string(v) + " " + to_string(t) + " " + to_string(x) + " " + to_string(up(v, t, x)) + "\n"; + } + } + } + } + fout << res; +} + +void write_log_likelihood(){ + ofstream fout(output_dir + "/log_likelihood.txt"); + float log_likelihood = 0; + for(auto child_of_root: children[root]){ + log_likelihood += down(child_of_root, 0, 0); + } + fout << log_likelihood; +} + +float _compute_log_joint(int v, int t){ + if(!(is_internal_node[v] and v != root)){ + cerr << "_compute_log_joint received invalid inputs" << endl; + exit(1); + } + vector valid_num_cuts; + if(enforce_parsimony){ + valid_num_cuts.push_back(get_number_of_mutated_characters_in_node[v]); + } else { + for(int x = 0; x <= get_number_of_mutated_characters_in_node[v]; x++){ + valid_num_cuts.push_back(x); + } + } + vector ll_for_xs; + for(auto x: valid_num_cuts){ + float ll_for_x = up(v, t, x); + for(auto u: children[v]){ + ll_for_x += down(u, t, x); + } + ll_for_xs.push_back(ll_for_x); + } + return logsumexp(ll_for_xs); +} + +void _write_out_log_joints(){ + ofstream fout(output_dir + "/log_joints.txt"); + string res = ""; + for(auto v: non_root_internal_nodes){ + res += to_string(v); + for(int t = 0; t <= T; t++){ + res += " " + to_string(log_joints[v][t]); + } + res += "\n"; + } + fout << res; +} + +void _write_out_posteriors(){ + // NOTE: copy-pasta of _write_out_log_joints) + ofstream fout(output_dir + "/posteriors.txt"); + string res = ""; + for(auto v: non_root_internal_nodes){ + res += to_string(v); + for(int t = 0; t <= T; t++){ + res += " " + to_string(posteriors[v][t]); + } + res += "\n"; + } + fout << res; +} + +void _write_out_posterior_means(){ + ofstream fout(output_dir + "/posterior_means.txt"); + string res = ""; + for(auto v: non_root_internal_nodes){ + res += to_string(v); + res += " " + to_string(posterior_means[v]); + res += "\n"; + } + fout << res; +} + +void write_posteriors(){ + // mimmicks _compute_posteriors of the python implementation. + for(auto v: non_root_internal_nodes){ + // Compute the log_joints. + for(int t = 0; t <= T; t++){ + log_joints[v][t] = _compute_log_joint(v, t); + } + + // Compute the posteriors + float mx = -INF; + for(int t = 0; t <= T; t++){ + mx = max(mx, log_joints[v][t]); + } + for(int t = 0; t <= T; t++){ + posteriors[v][t] = exp(log_joints[v][t] - mx); + } + // Normalize posteriors + float tot_sum = 0.0; + for(int t = 0; t <= T; t++){ + tot_sum += posteriors[v][t]; + } + for(int t = 0; t <= T; t++){ + posteriors[v][t] /= tot_sum; + } + + // Compute the posterior means. 
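A numpy sketch (toy log-joint values) of the normalization and posterior-mean step performed here, mirroring _compute_posteriors in the Python implementation: shift by the maximum for numerical stability, exponentiate, normalize over the time grid, and divide the mean grid index by the discretization level.

import numpy as np

T = 4                                                   # discretization level (toy)
log_joint = np.array([-50.0, -3.0, -1.0, -2.0, -60.0])  # log P(t_v = t/T, X, T) on the grid
posterior = np.exp(log_joint - log_joint.max())
posterior /= posterior.sum()
posterior_mean = (posterior * np.arange(T + 1)).sum() / T
print(posterior, posterior_mean)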
+ posterior_means[v] = 0.0; + for(int t = 0; t <= T; t++){ + posterior_means[v] += posteriors[v][t] * t; + } + posterior_means[v] /= float(T); + } + + // Write out the log_joints, posteriors, and posterior_means + _write_out_log_joints(); + _write_out_posteriors(); + _write_out_posterior_means(); +} + + +int main(int argc, char *argv[]){ + if(argc != 3){ + cerr << "Need intput_dir and output_dir arguments, " << argc - 1 << " provided." << endl; + exit(1); + } + input_dir = string(argv[1]); + output_dir = string(argv[2]); + read_N(); + read_children(); + read_root(); + read_is_internal_node(); + read_get_number_of_mutated_characters_in_node(); + read_non_root_internal_nodes(); + read_leaves(); + read_parent(); + read_is_leaf(); + read_K(); + read_T(); + read_enforce_parsimony(); + read_r(); + read_lam(); + + forn(v, N) forn(t, T + 1) forn(k, K + 1) _down_cache[v][t][k] = _up_cache[v][t][k] = 1.0; + write_down(); + write_up(); + write_log_likelihood(); + write_posteriors(); + return 0; +} diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py new file mode 100644 index 00000000..7eccb4ec --- /dev/null +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -0,0 +1,876 @@ +import multiprocessing +import os +import subprocess +import tempfile +import time +from copy import deepcopy +from typing import List, Optional, Tuple + +import numpy as np +from scipy import integrate +from scipy.special import binom, logsumexp + +from cassiopeia.data import CassiopeiaTree + +from . import utils +from .BranchLengthEstimator import ( + BranchLengthEstimator, + BranchLengthEstimatorError, +) + + +class IIDExponentialPosteriorMeanBLE(BranchLengthEstimator): + r""" + TODO: Update to match my technical write-up. + + Same model as IIDExponentialBLE but computes the posterior mean instead + of the MLE. The phylogeny model is chosen to be a birth process. + + This estimator requires that the ancestral states are provided. + + TODO: Use numpy autograd to optimize the hyperparams? (Empirical Bayes) + + We compute the posterior means using a forward-backward-style algorithm + (DP on a tree). + + Args: TODO + + Attributes: TODO + + """ + + def __init__( + self, + mutation_rate: float, + birth_rate: float, + discretization_level: int, + enforce_parsimony: bool = True, + use_cpp_implementation: bool = False, + debug_cpp_implementation: bool = False, + verbose: bool = False, + ) -> None: + # TODO: If we use autograd, we can tune the hyperparams with gradient + # descent? + self.mutation_rate = mutation_rate + # TODO: Is there some easy heuristic way to set this to a reasonable + # value and thus avoid grid searching it / optimizing it? + self.birth_rate = birth_rate + self.discretization_level = discretization_level + self.enforce_parsimony = enforce_parsimony + self.use_cpp_implementation = use_cpp_implementation + self.debug_cpp_implementation = debug_cpp_implementation + self.verbose = verbose + + def _compute_log_likelihood(self): + tree = self.tree + log_likelihood = 0 + # TODO: Should I also add a division event when the root has multiple + # children? (If not, the joint we are computing won't integrate to 1; + # on the other hand, this is a constant multiplicative term that doesn't + # affect inference. 
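A hedged usage sketch (the rates and discretization level are illustrative) of the posterior-mean estimator defined here; use_cpp_implementation=True additionally assumes the C++ helper above has been compiled to an executable named IIDExponentialPosteriorMeanBLE next to this module.

from cassiopeia.tools import IIDExponentialPosteriorMeanBLE

ble = IIDExponentialPosteriorMeanBLE(
    mutation_rate=0.5,
    birth_rate=1.0,
    discretization_level=500,
    enforce_parsimony=True,
    use_cpp_implementation=False,
)
# ble.estimate_branch_lengths(tree)  # annotates node times in place
# ble.posteriors[node]               # posterior over the discretized time grid, per non-root internal node
# ble.posterior_means[node]          # posterior-mean time used to set the branch lengths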
+ for child_of_root in tree.children(tree.root): + log_likelihood += self.down(child_of_root, 0, 0) + self.log_likelihood = log_likelihood + + def _compute_log_joint(self, v, t): + r""" + P(t_v = t, X, T). + Depending on whether we are enforcing parsimony or not, we consider + different possible number of cuts for v. + """ + tree = self.tree + assert tree.is_internal_node(v) and v != tree.root + enforce_parsimony = self.enforce_parsimony + children = tree.children(v) + if enforce_parsimony: + valid_num_cuts = [tree.get_number_of_mutated_characters_in_node(v)] + else: + valid_num_cuts = range( + tree.get_number_of_mutated_characters_in_node(v) + 1 + ) + ll_for_xs = [] + for x in valid_num_cuts: + ll_for_xs.append( + sum([self.down(u, t, x) for u in children]) + self.up(v, t, x) + ) + return logsumexp(ll_for_xs) + + def _compute_posteriors(self): + tree = self.tree + discretization_level = self.discretization_level + log_joints = {} # log P(t_v = t, X, T) + posteriors = {} # P(t_v = t | X, T) + posterior_means = {} # E[t_v = t | X, T] + for v in tree.non_root_internal_nodes: + # Compute the posterior for this node + log_joint = np.zeros(shape=(discretization_level + 1,)) + for t in range(discretization_level + 1): + log_joint[t] = self._compute_log_joint(v, t) + log_joints[v] = log_joint.copy() + posterior = np.exp(log_joint - log_joint.max()) + posterior /= np.sum(posterior) + posteriors[v] = posterior + posterior_means[v] = ( + posterior * np.array(range(discretization_level + 1)) + ).sum() / discretization_level + self.log_joints = log_joints + self.posteriors = posteriors + self.posterior_means = posterior_means + + def _populate_branch_lengths(self): + tree = self.tree + posterior_means = self.posterior_means + times = {} + for node in tree.non_root_internal_nodes: + times[node] = posterior_means[node] + times[tree.root] = 0.0 + for leaf in tree.leaves: + times[leaf] = 1.0 + tree.set_times(times) + + def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: + r""" + See base class. + """ + self._down_cache = {} + self._up_cache = {} + self.tree = tree + verbose = self.verbose + if self.debug_cpp_implementation: + # Write out true dp values to check by eye against c++ + # implementation values. + self._write_out_dps() + if self.use_cpp_implementation: + time_cpp_start = time.time() + if self.debug_cpp_implementation: + # Use a directory that won't go away. + self._populate_attributes_with_cpp_implementation( + tmp_dir=os.getcwd() + "/tmp" + ) + else: + # Use a temporary directory. 
+ with tempfile.TemporaryDirectory() as tmp_dir: + self._populate_attributes_with_cpp_implementation(tmp_dir) + time_cpp_end = time.time() + if verbose: + print(f"time_cpp = {time_cpp_end - time_cpp_start}") + else: + time_compute_log_likelihood_start = time.time() + self._compute_log_likelihood() + time_compute_log_likelihood_end = time.time() + if verbose: + print( + f"time_compute_log_likelihood (dp_down) = {time_compute_log_likelihood_end - time_compute_log_likelihood_start}" + ) + time_compute_posteriors_start = time.time() + self._compute_posteriors() + time_compute_posteriors_end = time.time() + if verbose: + print( + f"time_compute_posteriors (dp_up) = {time_compute_posteriors_end - time_compute_posteriors_start}" + ) + time_populate_branch_lengths_start = time.time() + self._populate_branch_lengths() + time_populate_branch_lengths_end = time.time() + if verbose: + print( + f"time_populate_branch_lengths = {time_populate_branch_lengths_end - time_populate_branch_lengths_start}" + ) + + def _write_out_dps(self): + r""" + For debugging the c++ implementation: + This writes out the down and up values of the correct python + implementation to the files tmp/down_true.txt + and + tmp/up_true.txt + respectively. + Compare these against tmp/down.txt and tmp/up.txt, which are the values + computed by the c++ implementation. + """ + tree = self.tree + N = len(tree.nodes) + T = self.discretization_level + K = tree.n_character + id_to_node = dict(zip(range(len(tree.nodes)), tree.nodes)) + + if not os.path.exists("tmp"): + os.mkdir("tmp") + + res = "" + for v_id in range(N): + v = id_to_node[v_id] + if v == tree.root: + continue + for t in range(T + 1): + for x in range(K + 1): + if self._state_is_valid(v, t, x): + res += ( + str(v_id) + + " " + + str(t) + + " " + + str(x) + + " " + + str(self.down(v, t, x)) + + "\n" + ) + with open("tmp/down_true.txt", "w") as fout: + fout.write(res) + + res = "" + for v_id in range(N): + v = id_to_node[v_id] + for t in range(T + 1): + for x in range(K + 1): + if self._state_is_valid(v, t, x): + res += ( + str(v_id) + + " " + + str(t) + + " " + + str(x) + + " " + + str(self.up(v, t, x)) + + "\n" + ) + with open("tmp/up_true.txt", "w") as fout: + fout.write(res) + + def _write_out_list_of_lists(self, lls: List[List[int]], filename: str): + res = "" + for l in lls: + for i, x in enumerate(l): + if i: + res += " " + res += str(x) + res += "\n" + with open(filename, "w") as file: + file.write(res) + + def _populate_attributes_with_cpp_implementation(self, tmp_dir): + r""" + A cpp implementation is run to compute up and down caches, which is + the computational bottleneck. The other attributes such as the + log-likelihood, and the posteriors, are also populated because + even these trivial computations are too slow in vanilla python. + Looking forward, a cython implementation will hopefully be the + best way forward. + To remove anything that has to do with the cpp implementation, you just + have to remove this function (and the gates around it). + I.e., this python implementation is loosely coupled to the cpp call: we + just have to remove the call to this method to turn it off, and all + other code will work just fine. This is because all that this method + does is *warm up the cache* with values computed from the cpp + subprocess, and the caching process is totally transparent to the + other methods of the class. + """ + # First extract the relevant information from the tree and serialize it. 
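A hedged sketch (toy node ids) of the whitespace-separated format produced by _write_out_list_of_lists and consumed by the C++ readers above; for example, children.txt has one row per node: the node id, the number of children, then the child ids.

children_rows = [
    [0, 2, 1, 2],  # node 0 has two children: 1 and 2
    [1, 0],        # node 1 is a leaf
    [2, 0],        # node 2 is a leaf
]
with open("children.txt", "w") as fout:
    fout.write("\n".join(" ".join(str(x) for x in row) for row in children_rows) + "\n")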
+ tree = self.tree + node_to_id = dict(zip(tree.nodes, range(len(tree.nodes)))) + id_to_node = dict(zip(range(len(tree.nodes)), tree.nodes)) + + N = [[len(tree.nodes)]] + if not os.path.exists(tmp_dir): + os.mkdir(tmp_dir) + self._write_out_list_of_lists(N, f"{tmp_dir}/N.txt") + + children = [ + [node_to_id[v]] + + [len(tree.children(v))] + + [node_to_id[c] for c in tree.children(v)] + for v in tree.nodes + ] + self._write_out_list_of_lists(children, f"{tmp_dir}/children.txt") + + root = [[node_to_id[tree.root]]] + self._write_out_list_of_lists(root, f"{tmp_dir}/root.txt") + + is_internal_node = [ + [node_to_id[v], 1 * tree.is_internal_node(v)] for v in tree.nodes + ] + self._write_out_list_of_lists( + is_internal_node, f"{tmp_dir}/is_internal_node.txt" + ) + + get_number_of_mutated_characters_in_node = [ + [node_to_id[v], tree.get_number_of_mutated_characters_in_node(v)] + for v in tree.nodes + ] + self._write_out_list_of_lists( + get_number_of_mutated_characters_in_node, + f"{tmp_dir}/get_number_of_mutated_characters_in_node.txt", + ) + + non_root_internal_nodes = [ + [node_to_id[v]] for v in tree.non_root_internal_nodes + ] + self._write_out_list_of_lists( + non_root_internal_nodes, f"{tmp_dir}/non_root_internal_nodes.txt" + ) + + leaves = [[node_to_id[v]] for v in tree.leaves] + self._write_out_list_of_lists(leaves, f"{tmp_dir}/leaves.txt") + + parent = [ + [node_to_id[v], node_to_id[tree.parent(v)]] + for v in tree.nodes + if v != tree.root + ] + self._write_out_list_of_lists(parent, f"{tmp_dir}/parent.txt") + + K = [[tree.n_character]] + self._write_out_list_of_lists(K, f"{tmp_dir}/K.txt") + + T = [[self.discretization_level]] + self._write_out_list_of_lists(T, f"{tmp_dir}/T.txt") + + enforce_parsimony = [[1 * self.enforce_parsimony]] + self._write_out_list_of_lists( + enforce_parsimony, f"{tmp_dir}/enforce_parsimony.txt" + ) + + r = [[self.mutation_rate]] + self._write_out_list_of_lists(r, f"{tmp_dir}/r.txt") + + lam = [[self.birth_rate]] + self._write_out_list_of_lists(lam, f"{tmp_dir}/lam.txt") + + is_leaf = [[node_to_id[v], 1 * tree.is_leaf(v)] for v in tree.nodes] + self._write_out_list_of_lists(is_leaf, f"{tmp_dir}/is_leaf.txt") + + # Run the c++ implementation + try: + # os.system('IIDExponentialPosteriorMeanBLE') + subprocess.run( + [ + "./IIDExponentialPosteriorMeanBLE", + f"{tmp_dir}", + f"{tmp_dir}", + ], + check=True, + cwd=os.path.dirname(__file__), + ) + except subprocess.CalledProcessError: + raise BranchLengthEstimatorError( + "Couldn't run c++ implementation," + " or c++ implementation started running and errored." 
+ ) + + # Load the c++ implementation results into the cache + with open(f"{tmp_dir}/down.txt", "r") as fin: + for line in fin: + v, t, x, ll = line.split(" ") + v, t, x = int(v), int(t), int(x) + ll = float(ll) + self._down_cache[(id_to_node[v], t, x)] = ll + with open(f"{tmp_dir}/up.txt", "r") as fin: + for line in fin: + v, t, x, ll = line.split(" ") + v, t, x = int(v), int(t), int(x) + ll = float(ll) + self._up_cache[(id_to_node[v], t, x)] = ll + + discretization_level = self.discretization_level + + # Load the log_likelihood + with open(f"{tmp_dir}/log_likelihood.txt", "r") as fin: + self.log_likelihood = float(fin.read()) + + # Load the posteriors + log_joints = {} # log P(t_v = t, X, T) + with open(f"{tmp_dir}/log_joints.txt", "r") as fin: + for line in fin: + vals = line.split(" ") + assert len(vals) == discretization_level + 2 + v_id = int(vals[0]) + log_joint = np.zeros(shape=(discretization_level + 1,)) + for i, val in enumerate(vals[1:]): + log_joint[i] = float(val) + log_joints[id_to_node[v_id]] = log_joint + + posteriors = {} # P(t_v = t | X, T) + with open(f"{tmp_dir}/posteriors.txt", "r") as fin: + for line in fin: + vals = line.split(" ") + assert len(vals) == discretization_level + 2 + v_id = int(vals[0]) + posterior = np.zeros(shape=(discretization_level + 1,)) + for i, val in enumerate(vals[1:]): + posterior[i] = float(val) + posteriors[id_to_node[v_id]] = posterior + + posterior_means = {} # E[t_v = t | X, T] + with open(f"{tmp_dir}/posterior_means.txt", "r") as fin: + for line in fin: + v_id, val = line.split(" ") + v_id = int(v_id) + val = float(val) + posterior_means[id_to_node[v_id]] = val + + self.log_joints = log_joints + self.posteriors = posteriors + self.posterior_means = posterior_means + + def _compatible_with_observed_data(self, x, observed_cuts) -> bool: + if self.enforce_parsimony: + return x == observed_cuts + else: + return x <= observed_cuts + + def _state_is_valid(self, v, t, x) -> bool: + r""" + Used to optimize the DP by avoiding states with 0 probability. + The number of mutations should be between those of v and its parent. + """ + tree = self.tree + if v == tree.root: + return x == 0 + p = tree.parent(v) + cuts_v = tree.get_number_of_mutated_characters_in_node(v) + cuts_p = tree.get_number_of_mutated_characters_in_node(p) + if self.enforce_parsimony: + return cuts_p <= x <= cuts_v + else: + return x <= cuts_v + + def up(self, v, t, x) -> float: + r""" + TODO: Rename this _up? + log P(X_up(b(v)), T_up(b(v)), t \in t_b(v), X_b(v)(t) = x) + TODO: Update to match my technical write-up. + """ + # Avoid doing anything at all for invalid states. + if not self._state_is_valid(v, t, x): + return -np.inf + if (v, t, x) in self._up_cache: # TODO: Use arrays? + # TODO: Use a decorator instead of a hand-made cache? + return self._up_cache[(v, t, x)] + if self.use_cpp_implementation and not self.debug_cpp_implementation: + raise ValueError( + f"Bug in cpp implementation: State up({(v, t, x)})" + f" was not populated." + ) + # Pull out params + r = self.mutation_rate + lam = self.birth_rate + dt = 1.0 / self.discretization_level + K = self.tree.n_character + tree = self.tree + assert 0 <= t <= self.discretization_level + assert 0 <= x <= K + if not (1.0 - lam * dt - K * r * dt > 0): + raise ValueError("Please choose a bigger discretization_level.") + log_likelihood = 0.0 + if v == tree.root: # Base case: we reached the root of the tree. 
+ if t == 0 and x == tree.get_number_of_mutated_characters_in_node(v): + log_likelihood = 0.0 + else: + log_likelihood = -np.inf + elif t == 0: + # Base case: we reached the start of the process, but we're not yet + # at the root. + assert v != tree.root + log_likelihood = -np.inf + else: # Recursion. + log_likelihoods_cases = [] + # Case 1: Nothing happened + log_likelihoods_cases.append( + np.log(1.0 - lam * dt - (K - x) * r * dt) + self.up(v, t - 1, x) + ) + # Case 2: Mutation happened + if x - 1 >= 0: + log_likelihoods_cases.append( + np.log((K - (x - 1)) * r * dt) + self.up(v, t - 1, x - 1) + ) + # Case 3: A cell division happened + if v != tree.root: + p = tree.parent(v) + if self._compatible_with_observed_data( + x, + tree.get_number_of_mutated_characters_in_node( + p + ), # If we want to ignore missing data, we just have to replace x by x-gone_missing(p->v). I.e. dropped out characters become free mutations. + ): + siblings = [u for u in tree.children(p) if u != v] + ll = ( + np.log(lam * dt) + + self.up( + p, t - 1, x + ) # If we want to ignore missing data, we just have to replace x by x-gone_missing(p->v). I.e. dropped out characters become free mutations. + + sum( + [self.down(u, t, x) for u in siblings] + ) # If we want to ignore missing data, we just have to replace x by cuts(p)+gone_missing(p->u). I.e. dropped out characters become free mutations. + ) + log_likelihoods_cases.append(ll) + log_likelihood = logsumexp(log_likelihoods_cases) + self._up_cache[(v, t, x)] = log_likelihood + return log_likelihood + + def down(self, v, t, x) -> float: + r""" + TODO: Rename this _down? + log P(X_down(v), T_down(v) | t_v = t, X_v = x) + TODO: Update to match my technical write-up. + """ + # Avoid doing anything at all for invalid states. + if not self._state_is_valid(v, t, x): + return -np.inf + if (v, t, x) in self._down_cache: + # TODO: Use a decorator instead of a hand-made cache? + return self._down_cache[(v, t, x)] + if self.use_cpp_implementation and not self.debug_cpp_implementation: + raise ValueError( + f"Bug in cpp implementation: State " + f"down({(v, t, x)}) was not populated." + ) + # Pull out params + discretization_level = self.discretization_level + r = self.mutation_rate + lam = self.birth_rate + dt = 1.0 / self.discretization_level + K = self.tree.n_character + tree = self.tree + assert v != tree.root + assert 0 <= t <= self.discretization_level + assert 0 <= x <= K + if not (1.0 - lam * dt - K * r * dt > 0): + raise ValueError("Please choose a bigger discretization_level.") + log_likelihood = 0.0 + if t == discretization_level: # Base case + if tree.is_leaf( + v + ) and x == tree.get_number_of_mutated_characters_in_node(v): + log_likelihood = 0.0 + else: + log_likelihood = -np.inf + else: # Recursion. + log_likelihoods_cases = [] + # Case 1: Nothing happens + log_likelihoods_cases.append( + np.log(1.0 - lam * dt - (K - x) * r * dt) + + self.down(v, t + 1, x) + ) + # Case 2: One character mutates. + if x + 1 <= K: + log_likelihoods_cases.append( + np.log((K - x) * r * dt) + self.down(v, t + 1, x + 1) + ) + # Case 3: Cell divides + # The number of cuts at this state must match the ground truth. + if self._compatible_with_observed_data( + x, tree.get_number_of_mutated_characters_in_node(v) + ) and not tree.is_leaf(v): + ll = sum( + [ + self.down(child, t + 1, x) for child in tree.children(v) + ] # If we want to ignore missing data, we just have to replace x by x+gone_missing(p->v). I.e. dropped out characters become free mutations. 
+ ) + np.log(lam * dt) + log_likelihoods_cases.append(ll) + log_likelihood = logsumexp(log_likelihoods_cases) + self._down_cache[(v, t, x)] = log_likelihood + return log_likelihood + + @classmethod + def exact_log_full_joint( + self, tree: CassiopeiaTree, mutation_rate: float, birth_rate: float + ) -> float: + r""" + log P(T, X, branch_lengths), i.e. the full joint log likelihood given + both character vectors _and_ branch lengths. + """ + tree = deepcopy(tree) + ll = 0.0 + lam = birth_rate + r = mutation_rate + lg = np.log + e = np.exp + b = binom + for (p, c) in tree.edges: + t = tree.get_branch_length(p, c) + # Birth process likelihood + ll += -t * lam + if c not in tree.leaves: + ll += lg(lam) + # Mutation process likelihood + cuts = tree.get_number_of_mutations_along_edge(p, c) + uncuts = tree.get_number_of_unmutated_characters_in_node(c) + ll += ( + (-t * r) * uncuts + + lg(1 - e(-t * r)) * cuts + + lg(b(cuts + uncuts, cuts)) + ) + return ll + + @classmethod + def numerical_log_likelihood( + self, + tree: CassiopeiaTree, + mutation_rate: float, + birth_rate: float, + epsrel: float = 0.01, + ): + r""" + log P(T, X), i.e. the marginal log likelihood given _only_ tree + topology and character vectors (including those of internal nodes). + It is computed with a grid. + """ + + tree = deepcopy(tree) + + def f(*args): + times_list = args + times = {} + for node, time in list( + zip(tree.non_root_internal_nodes, times_list) + ): + times[node] = time + times[tree.root] = 0 + for leaf in tree.leaves: + times[leaf] = 1.0 + for (p, c) in tree.edges: + if times[p] >= times[c]: + return 0.0 + tree.set_times(times) + return np.exp( + IIDExponentialPosteriorMeanBLE.exact_log_full_joint( + tree=tree, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + ) + ) + + res = np.log( + integrate.nquad( + f, + [[0, 1]] * len(tree.non_root_internal_nodes), + opts={"epsrel": epsrel}, + )[0] + ) + assert not np.isnan(res) + return res + + @classmethod + def numerical_log_joint( + self, + tree: CassiopeiaTree, + node, + mutation_rate: float, + birth_rate: float, + discretization_level: int, + epsrel: float = 0.01, + ): + r""" + log P(t_node = t, X, T) for each t in the interval [0, 1] discretized + to the level discretization_level + """ + res = np.zeros(shape=(discretization_level + 1,)) + other_nodes = [n for n in tree.non_root_internal_nodes if n != node] + node_time = -1 + + tree = deepcopy(tree) + + def f(*args): + times_list = args + times = {} + times[node] = node_time + assert len(other_nodes) == len(times_list) + for other_node, time in list(zip(other_nodes, times_list)): + times[other_node] = time + times[tree.root] = 0 + for leaf in tree.leaves: + times[leaf] = 1.0 + for (p, c) in tree.edges: + if times[p] >= times[c]: + return 0.0 + tree.set_times(times) + return np.exp( + IIDExponentialPosteriorMeanBLE.exact_log_full_joint( + tree=tree, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + ) + ) + + for i in range(discretization_level + 1): + node_time = i / discretization_level + if len(other_nodes) == 0: + # There is nothing to integrate over. 
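+                # Only 'node' has a free time here, so the joint density is
+                # evaluated directly at node_time; subtracting
+                # log(discretization_level) converts the density into
+                # probability mass on the time grid.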
+ times = {} + times[tree.root] = 0 + for leaf in tree.leaves: + times[leaf] = 1.0 + times[node] = node_time + tree.set_times(times) + res[i] = self.exact_log_full_joint( + tree=tree, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + ) + res[i] -= np.log(discretization_level) + else: + res[i] = ( + np.log( + integrate.nquad( + f, + [[0, 1]] * (len(tree.non_root_internal_nodes) - 1), + opts={"epsrel": epsrel}, + )[0] + ) + - np.log(discretization_level) + ) + assert not np.isnan(res[i]) + + return res + + @classmethod + def numerical_posterior( + self, + tree: CassiopeiaTree, + node, + mutation_rate: float, + birth_rate: float, + discretization_level: int, + epsrel: float = 0.01, + ): + numerical_log_joint = self.numerical_log_joint( + tree=tree, + node=node, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + epsrel=epsrel, + ) + numerical_posterior = np.exp( + numerical_log_joint - numerical_log_joint.max() + ) + numerical_posterior /= numerical_posterior.sum() + return numerical_posterior + + +def _fit_model(model_and_tree): + r""" + This is used by IIDExponentialPosteriorMeanBLEGridSearchCV to + parallelize the grid search. It must be defined here (at the top level of + the module) for multiprocessing to be able to pickle it. (This is why + coverage misses it) + """ + model, tree = model_and_tree + model.estimate_branch_lengths(tree) + return model.log_likelihood + + +class IIDExponentialPosteriorMeanBLEGridSearchCV(BranchLengthEstimator): + r""" + Like IIDExponentialPosteriorMeanBLE but with automatic tuning of + hyperparameters. + + This class fits the hyperparameters of IIDExponentialPosteriorMeanBLE based + on data log-likelihood. I.e. is performs empirical Bayes. + + Args: TODO + """ + + def __init__( + self, + mutation_rates: Tuple[float] = (0,), + birth_rates: Tuple[float] = (0,), + discretization_level: int = 1000, + enforce_parsimony: bool = True, + use_cpp_implementation: bool = False, + processes: int = 6, + verbose: bool = False, + ): + self.mutation_rates = mutation_rates + self.birth_rates = birth_rates + self.discretization_level = discretization_level + self.enforce_parsimony = enforce_parsimony + self.use_cpp_implementation = use_cpp_implementation + self.processes = processes + self.verbose = verbose + + def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: + r""" + See base class. 
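+
+        Fits one IIDExponentialPosteriorMeanBLE per (mutation_rate,
+        birth_rate) pair in the grid (in parallel when processes > 1),
+        records the grid of data log-likelihoods, and refits a final model
+        with the best-scoring pair, whose posteriors and log-likelihood are
+        then exposed on this object.
+
+        A minimal usage sketch (the rates below are illustrative only):
+            model = IIDExponentialPosteriorMeanBLEGridSearchCV(
+                mutation_rates=(0.5, 1.0),
+                birth_rates=(0.5, 1.0),
+                discretization_level=200,
+            )
+            model.estimate_branch_lengths(tree)
+            # model.mutation_rate and model.birth_rate hold the selected
+            # rates; model.grid holds the log-likelihood surface.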
+ """ + mutation_rates = self.mutation_rates + birth_rates = self.birth_rates + discretization_level = self.discretization_level + enforce_parsimony = self.enforce_parsimony + use_cpp_implementation = self.use_cpp_implementation + processes = self.processes + verbose = self.verbose + + lls = [] + grid = np.zeros(shape=(len(mutation_rates), len(birth_rates))) + models = [] + mutation_and_birth_rates = [] + ijs = [] + for i, mutation_rate in enumerate(mutation_rates): + for j, birth_rate in enumerate(birth_rates): + if self.verbose: + print( + f"Fitting model with:\n" + f"mutation_rate={mutation_rate}\n" + f"birth_rate={birth_rate}" + ) + models.append( + IIDExponentialPosteriorMeanBLE( + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + enforce_parsimony=enforce_parsimony, + use_cpp_implementation=use_cpp_implementation, + ) + ) + mutation_and_birth_rates.append((mutation_rate, birth_rate)) + ijs.append((i, j)) + with multiprocessing.Pool(processes=processes) as pool: + map_fn = pool.map if processes > 1 else map + lls = list( + map_fn( + _fit_model, + zip(models, [deepcopy(tree) for _ in range(len(models))]), + ) + ) + lls_and_rates = list(zip(lls, mutation_and_birth_rates)) + for ll, (i, j) in list(zip(lls, ijs)): + grid[i, j] = ll + lls_and_rates.sort(reverse=True) + (best_mutation_rate, best_birth_rate,) = lls_and_rates[ + 0 + ][1] + if verbose: + print( + f"Refitting model with:\n" + f"best_mutation_rate={best_mutation_rate}\n" + f"best_birth_rate={best_birth_rate}" + ) + final_model = IIDExponentialPosteriorMeanBLE( + mutation_rate=best_mutation_rate, + birth_rate=best_birth_rate, + discretization_level=discretization_level, + enforce_parsimony=enforce_parsimony, + use_cpp_implementation=use_cpp_implementation, + ) + final_model.estimate_branch_lengths(tree) + self.mutation_rate = best_mutation_rate + self.birth_rate = best_birth_rate + self.log_likelihood = final_model.log_likelihood + self.log_joints = final_model.log_joints + self.posteriors = final_model.posteriors + self.posterior_means = final_model.posterior_means + self.grid = grid + + def plot_grid( + self, figure_file: Optional[str] = None, show_plot: bool = True + ): + utils.plot_grid( + grid=self.grid, + yticklabels=self.mutation_rates, + xticklabels=self.birth_rates, + ylabel=r"Mutation Rate ($r$)", + xlabel=r"Birth Rate ($\lambda$)", + figure_file=figure_file, + show_plot=show_plot, + ) diff --git a/cassiopeia/tools/branch_length_estimator/__init__.py b/cassiopeia/tools/branch_length_estimator/__init__.py new file mode 100644 index 00000000..f0c15231 --- /dev/null +++ b/cassiopeia/tools/branch_length_estimator/__init__.py @@ -0,0 +1,7 @@ +from .BLEMultifurcationWrapper import BLEMultifurcationWrapper +from .BranchLengthEstimator import BranchLengthEstimator +from .IIDExponentialBLE import IIDExponentialBLE, IIDExponentialBLEGridSearchCV +from .IIDExponentialPosteriorMeanBLE import ( + IIDExponentialPosteriorMeanBLE, + IIDExponentialPosteriorMeanBLEGridSearchCV, +) diff --git a/cassiopeia/tools/branch_length_estimator/utils.py b/cassiopeia/tools/branch_length_estimator/utils.py new file mode 100644 index 00000000..5a3406e5 --- /dev/null +++ b/cassiopeia/tools/branch_length_estimator/utils.py @@ -0,0 +1,27 @@ +import numpy as np +import seaborn as sns +import matplotlib.pyplot as plt +from typing import Optional + + +def plot_grid( + grid, + yticklabels, + xticklabels, + ylabel, + xlabel, + figure_file: Optional[str], + show_plot: str = True, +) -> None: + sns.heatmap( + 
grid, + yticklabels=yticklabels, + xticklabels=xticklabels, + mask=np.isneginf(grid), + ) + plt.ylabel(ylabel) + plt.xlabel(xlabel) + if figure_file: + plt.savefig(fname=figure_file) + if show_plot: + plt.show() diff --git a/cassiopeia/tools/cell_subsampler.py b/cassiopeia/tools/cell_subsampler.py new file mode 100644 index 00000000..1c6f5fd9 --- /dev/null +++ b/cassiopeia/tools/cell_subsampler.py @@ -0,0 +1,119 @@ +import abc +import networkx as nx +import numpy as np +from typing import Optional + +from cassiopeia.data import CassiopeiaTree + + +class CellSubsamplerError(Exception): + """An Exception class for the CellSubsampler class.""" + + pass + + +class CellSubsampler(abc.ABC): + r""" + Abstract base class for all cell samplers. + + A CellSubsampler implements a method 'subsample' which, given a Tree, + returns a second Tree which is the result of subsampling cells + (i.e. leafs) of the tree. Only the tree topology will be created for the + new tree. + """ + + @abc.abstractmethod + def subsample(self, tree: CassiopeiaTree) -> CassiopeiaTree: + r""" + Returns a new CassiopeiaTree which is the result of subsampling + the cells in the original CassiopeiaTree. + + Args: + tree: The tree for which to subsample leaves. + """ + + +class UniformCellSubsampler(CellSubsampler): + def __init__( + self, ratio: Optional[float] = None, n_cells: Optional[int] = None + ): + r""" + Samples 'ratio' of the leaves, rounded down, uniformly at random. + """ + if ratio is None and n_cells is None: + raise CellSubsamplerError( + "At least one of 'ratio' and 'n_cells' " "must be specified." + ) + if ratio is not None and n_cells is not None: + raise CellSubsamplerError( + "Exactly one of 'ratio' and 'n_cells'" "must be specified." + ) + self.__ratio = ratio + self.__n_cells = n_cells + + def subsample(self, tree: CassiopeiaTree) -> CassiopeiaTree: + ratio = self.__ratio + n_cells = self.__n_cells + n_subsample = ( + n_cells if n_cells is not None else int(tree.n_cell * ratio) + ) + if n_subsample == 0: + raise CellSubsamplerError( + "ratio too low: no cells would be " "sampled." + ) + + # First determine which nodes are part of the induced subgraph. 
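+        # A node belongs to the induced subtree iff at least one sampled
+        # leaf lies below it; its 'induced degree' counts the children
+        # through which sampled leaves are reachable. Nodes of induced
+        # degree 1 are bypassed further down so that unary chains collapse
+        # into single edges.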
+ leaf_keep_idx = np.random.choice( + range(tree.n_cell), n_subsample, replace=False + ) + leaves_in_induced_subtree = [tree.leaves[i] for i in leaf_keep_idx] + induced_subtree_degs = dict( + [(leaf, 0) for leaf in leaves_in_induced_subtree] + ) + + nodes_in_induced_subtree = set(leaves_in_induced_subtree) + for node in tree.depth_first_traverse_nodes(postorder=True): + children = tree.children(node) + induced_subtree_deg = sum( + [child in nodes_in_induced_subtree for child in children] + ) + if induced_subtree_deg > 0: + nodes_in_induced_subtree.add(node) + induced_subtree_degs[node] = induced_subtree_deg + + # For debugging: + # print(f"leaves_in_induced_subtree = {leaves_in_induced_subtree}") + # print(f"nodes_in_induced_subtree = {nodes_in_induced_subtree}") + # print(f"induced_subtree_degs = {induced_subtree_degs}") + nodes = [] + edges = [] + up = {} + for node in tree.depth_first_traverse_nodes(postorder=False): + if node == tree.root: + nodes.append(node) + up[node] = node + continue + + if node not in nodes_in_induced_subtree: + continue + + if induced_subtree_degs[tree.parent(node)] >= 2: + up[node] = tree.parent(node) + else: + up[node] = up[tree.parent(node)] + + if ( + induced_subtree_degs[node] >= 2 + or induced_subtree_degs[node] == 0 + ): + nodes.append(node) + edges.append((up[node], node)) + subtree_topology = nx.DiGraph() + subtree_topology.add_nodes_from(nodes) + subtree_topology.add_edges_from(edges) + res = CassiopeiaTree( + tree=subtree_topology, + ) + # Copy times over + res.set_times(dict([(node, tree.get_time(node)) for node in res.nodes])) + return res diff --git a/cassiopeia/tools/lineage_simulator.py b/cassiopeia/tools/lineage_simulator.py new file mode 100644 index 00000000..6a070ef1 --- /dev/null +++ b/cassiopeia/tools/lineage_simulator.py @@ -0,0 +1,284 @@ +import abc +from typing import List + +import networkx as nx +import numpy as np +from queue import Queue + +from cassiopeia.data import CassiopeiaTree + + +class LineageSimulator(abc.ABC): + r""" + Abstract base class for lineage simulators. + + A LineageSimulator implements the method simulate_lineage that generates a + lineage tree (i.e. a phylogeny, in the form of a Tree). + """ + + @abc.abstractmethod + def simulate_lineage(self) -> CassiopeiaTree: + r""" + Simulates a lineage tree, i.e. a Tree with branch lengths and age + specified for each node. Additional information such as cell fitness, + etc. might be specified by more complex simulators. + """ + + +class PerfectBinaryTree(LineageSimulator): + r""" + Generates a perfect binary tree with given branch lengths at each depth. + + Args: + generation_branch_lengths: The branches at depth d in the tree will have + length generation_branch_lengths[d] + """ + + def __init__(self, generation_branch_lengths: List[float]): + self.generation_branch_lengths = generation_branch_lengths[:] + + def simulate_lineage(self) -> CassiopeiaTree: + r""" + See base class. 
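+
+        Nodes are labeled 0, ..., 2 ** (n_generations + 1) - 2 in heap
+        order, so node i has children 2 * i + 1 and 2 * i + 2, and each
+        node's time is the sum of the branch lengths on its path from the
+        root.
+
+        For example (illustrative values), PerfectBinaryTree([1.0, 0.5])
+        produces 7 nodes, with the root at time 0, its two children at
+        time 1.0 and the four leaves at time 1.5.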
+ """ + generation_branch_lengths = self.generation_branch_lengths + n_generations = len(generation_branch_lengths) + tree = nx.DiGraph() + tree.add_nodes_from(range(2 ** (n_generations + 1) - 1)) + edges = [ + (int((child - 1) / 2), child) + for child in range(1, 2 ** (n_generations + 1) - 1) + ] + node_generation = [] + for i in range(n_generations + 1): + node_generation += [i] * 2 ** i + tree.add_edges_from(edges) + for (parent, child) in edges: + parent_generation = node_generation[parent] + branch_length = generation_branch_lengths[parent_generation] + tree.edges[parent, child]["length"] = branch_length + tree.nodes[0]["age"] = sum(generation_branch_lengths) + for child in range(1, 2 ** (n_generations + 1) - 1): + child_generation = node_generation[child] + branch_length = generation_branch_lengths[child_generation - 1] + tree.nodes[child]["age"] = ( + tree.nodes[int((child - 1) / 2)]["age"] - branch_length + ) + times = {} + for node in tree.nodes: + times[node] = tree.nodes[0]["age"] - tree.nodes[node]["age"] + res = CassiopeiaTree(tree=tree) + res.set_times(times) + return res + + +class PerfectBinaryTreeWithRootBranch(LineageSimulator): + r""" + Generates a perfect binary tree *hanging from a branch*, with given branch + lengths at each depth. + + Args: + generation_branch_lengths: The branches at depth d in the tree will have + length generation_branch_lengths[d] + """ + + def __init__(self, generation_branch_lengths: List[float]): + self.generation_branch_lengths = generation_branch_lengths + + def simulate_lineage(self) -> CassiopeiaTree: + r""" + See base class. + """ + # generation_branch_lengths = self.generation_branch_lengths + generation_branch_lengths = self.generation_branch_lengths + n_generations = len(generation_branch_lengths) + tree = nx.DiGraph() + tree.add_nodes_from(range(2 ** n_generations)) + edges = [ + (int(child / 2), child) for child in range(1, 2 ** n_generations) + ] + tree.add_edges_from(edges) + node_generation = [0] + for i in range(n_generations): + node_generation += [i + 1] * 2 ** i + for (parent, child) in edges: + parent_generation = node_generation[parent] + branch_length = generation_branch_lengths[parent_generation] + tree.edges[parent, child]["length"] = branch_length + tree.nodes[0]["age"] = sum(generation_branch_lengths) + for child in range(1, 2 ** n_generations): + child_generation = node_generation[child] + branch_length = generation_branch_lengths[child_generation - 1] + tree.nodes[child]["age"] = ( + tree.nodes[int(child / 2)]["age"] - branch_length + ) + times = {} + for node in tree.nodes: + times[node] = tree.nodes[0]["age"] - tree.nodes[node]["age"] + res = CassiopeiaTree(tree=tree) + res.set_times(times) + return res + + +class BirthProcess(LineageSimulator): + r""" + A Birth Process with exponential holding times. + + Args: + birth_rate: Birth rate of the process + tree_depth: Depth of the simulated tree + """ + + def __init__(self, birth_rate: float, tree_depth: float): + self.birth_rate = birth_rate + self.tree_depth = tree_depth + + def simulate_lineage(self) -> CassiopeiaTree: + r""" + See base class. + """ + tree_depth = self.tree_depth + birth_rate = self.birth_rate + node_age = {} + node_age[0] = tree_depth + live_nodes = [1] + edges = [(0, 1)] + t = 0 + last_node_id = 1 + while t < tree_depth: + num_live_nodes = len(live_nodes) + # Wait till next node divides. 
+ waiting_time = np.random.exponential( + 1.0 / (birth_rate * num_live_nodes) + ) + when_node_divides = t + waiting_time + del waiting_time + if when_node_divides >= tree_depth: + # The simulation has ended. + for node in live_nodes: + node_age[node] = 0 + del live_nodes + break + # Choose which node divides uniformly at random + node_that_divides = live_nodes[ + np.random.randint(low=0, high=num_live_nodes) + ] + # Remove the node that divided and add its two children + live_nodes.remove(node_that_divides) + left_child_id = last_node_id + 1 + right_child_id = last_node_id + 2 + last_node_id += 2 + live_nodes += [left_child_id, right_child_id] + edges += [ + (node_that_divides, left_child_id), + (node_that_divides, right_child_id), + ] + node_age[node_that_divides] = tree_depth - when_node_divides + t = when_node_divides + tree_nx = nx.DiGraph() + tree_nx.add_nodes_from(range(last_node_id + 1)) + tree_nx.add_edges_from(edges) + times = {} + for node in tree_nx.nodes: + times[node] = node_age[0] - node_age[node] + tree = CassiopeiaTree(tree=tree_nx) + tree.set_times(times) + return tree + + +class TumorWithAFitSubclone(LineageSimulator): + r""" + TODO + + Args: + branch_length: TODO + TODO + """ + + def __init__( + self, + branch_length: float, + branch_length_fit: float, + experiment_duration: float, + generations_until_fit_subclone: int, + ): + self.branch_length = branch_length + self.branch_length_fit = branch_length_fit + self.experiment_duration = experiment_duration + self.generations_until_fit_subclone = generations_until_fit_subclone + + def simulate_lineage(self) -> CassiopeiaTree: + r""" + See base class. + """ + branch_length = self.branch_length + branch_length_fit = self.branch_length_fit + experiment_duration = self.experiment_duration + generations_until_fit_subclone = self.generations_until_fit_subclone + + def node_name_generator(): + i = 0 + while True: + yield str(i) + i += 1 + + tree = nx.DiGraph() # This is what will get populated. + + names = node_name_generator() + q = Queue() # (node, time, fitness, generation) + times = {} + + root = next(names) + "_unfit" + tree.add_node(root) + times[root] = 0.0 + + root_child = next(names) + "_unfit" + tree.add_edge(root, root_child) + q.put((root_child, 0.0, "unfit", 0)) + subclone_started = False + while not q.empty(): + # Pop next node + (node, time, node_fitness, generation) = q.get() + time_till_division = ( + branch_length if node_fitness == "unfit" else branch_length_fit + ) + time_of_division = time + time_till_division + if time_of_division >= experiment_duration: + # Not enough time left for the cell to divide. + times[node] = experiment_duration + continue + # Create children, add edges to them, and push children to the + # queue. 
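+            # Record the division time, then spawn the two children. The
+            # first division that reaches generations_until_fit_subclone
+            # promotes its left child to the fit subclone; every other
+            # child inherits the parent's fitness.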
+ times[node] = time_of_division + left_child_fitness = node_fitness + right_child_fitness = node_fitness + if ( + not subclone_started + and generation + 1 == generations_until_fit_subclone + ): + # Start the subclone + subclone_started = True + left_child_fitness = "fit" + left_child = next(names) + "_" + left_child_fitness + right_child = next(names) + "_" + right_child_fitness + tree.add_nodes_from([left_child, right_child]) + tree.add_edges_from([(node, left_child), (node, right_child)]) + q.put( + ( + left_child, + time_of_division, + left_child_fitness, + generation + 1, + ) + ) + q.put( + ( + right_child, + time_of_division, + right_child_fitness, + generation + 1, + ) + ) + res = CassiopeiaTree(tree=tree) + res.set_times(times) + return res diff --git a/cassiopeia/tools/lineage_tracing_simulator.py b/cassiopeia/tools/lineage_tracing_simulator.py new file mode 100644 index 00000000..c03051a9 --- /dev/null +++ b/cassiopeia/tools/lineage_tracing_simulator.py @@ -0,0 +1,80 @@ +import abc + +import numpy as np + +from cassiopeia.data import CassiopeiaTree + + +class LineageTracingSimulator(abc.ABC): + r""" + Abstract base class for all lineage tracing simulators. + + A LineageTracingSimulator implements a method overlay_lineage_tracing_data + which overlays lineage tracing data (i.e. character vectors) on top of the + tree. These are stored as the node's state. + """ + + @abc.abstractmethod + def overlay_lineage_tracing_data(self, tree: CassiopeiaTree) -> None: + r""" + Annotates the tree's nodes with lineage tracing character vectors. + These are stored as the node's state. (Operates on the tree in-place.) + + Args: + tree: The tree to overlay lineage tracing data on. + """ + + +class IIDExponentialLineageTracer(LineageTracingSimulator): + r""" + Characters evolve IID over the lineage, with the same given mutation rate. + + Args: + mutation_rate: The mutation rate of each character (same for all). + num_characters: The number of characters. + """ + + def __init__(self, mutation_rate: float, num_characters: float): + self.mutation_rate = mutation_rate + self.num_characters = num_characters + + def overlay_lineage_tracing_data(self, tree: CassiopeiaTree) -> None: + r""" + See base class. + """ + num_characters = self.num_characters + mutation_rate = self.mutation_rate + states = {} + + def dfs(node: str, tree: CassiopeiaTree): + node_state = states[node] + for child in tree.children(node): + # Compute the state of the child + child_state = [] + edge_length = tree.get_branch_length(node, child) + # print(f"{node} -> {child}, length {edge_length}") + assert edge_length >= 0 + for i in range(num_characters): + # See what happens to character i + if node_state[i] != 0: + # The character has already mutated; there in nothing + # to do + child_state += [node_state[i]] + continue + else: + # Determine if the character will mutate. 
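+                        # The character mutates on this edge iff its
+                        # Exponential(mutation_rate) waiting time is shorter
+                        # than the branch length, an event of probability
+                        # 1 - exp(-mutation_rate * edge_length).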
+ mutates = ( + np.random.exponential(1.0 / mutation_rate) + < edge_length + ) + if mutates: + child_state += [1] + else: + child_state += [0] + states[child] = child_state + dfs(child, tree) + + root = tree.root + states[root] = [0] * num_characters + dfs(root, tree) + tree.initialize_all_character_states(states) diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..e446d0a1 --- /dev/null +++ b/conftest.py @@ -0,0 +1,21 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--runslow", action="store_true", default=False, help="run slow tests" + ) + + +def pytest_configure(config): + config.addinivalue_line("markers", "slow: mark test as slow to run") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--runslow"): + # --runslow given in cli: do not skip slow tests + return + skip_slow = pytest.mark.skip(reason="need --runslow option to run") + for item in items: + if "slow" in item.keywords: + item.add_marker(skip_slow) diff --git a/docs/requirements.txt b/docs/requirements.txt index 6b535e2d..45eb1c7e 100755 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -22,4 +22,6 @@ pydata-sphinx-theme>=0.4.0 python-Levenshtein pathlib typing_extensions; python_version < '3.8' - +cvxpy +parameterized +seaborn diff --git a/setup.py b/setup.py index 549c106b..b1e0717e 100755 --- a/setup.py +++ b/setup.py @@ -32,7 +32,10 @@ 'nbconvert >= 5.4.0', 'nbformat >= 4.4.0', 'hits', - 'scikit-bio >= 0.5.6' + 'scikit-bio >= 0.5.6', + 'cvxpy', + 'parameterized', + 'seaborn', ] @@ -42,6 +45,7 @@ # files to wrap with cython to_cythonize = [Extension("cassiopeia.preprocess.doublet_utils", ["cassiopeia/preprocess/doublet_utils.pyx"]), + Extension("cassiopeia.tools.branch_length_estimator.IIDExponentialPosteriorMeanBLE", ["cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py"]), Extension("cassiopeia.preprocess.map_utils", ["cassiopeia/preprocess/map_utils.pyx"]), Extension("cassiopeia.preprocess.collapse_cython", ["cassiopeia/preprocess/collapse_cython.pyx"]), Extension("cassiopeia.solver.ilp_solver_utilities", ["cassiopeia/solver/ilp_solver_utilities.pyx"])] diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py new file mode 100644 index 00000000..626d537e --- /dev/null +++ b/test/tools_tests/branch_length_estimator_test.py @@ -0,0 +1,1125 @@ +import itertools +import multiprocessing +import unittest +from copy import deepcopy + +import networkx as nx +import numpy as np +import pytest +from parameterized import parameterized + +from cassiopeia.data import CassiopeiaTree +from cassiopeia.tools import (BirthProcess, BLEMultifurcationWrapper, + IIDExponentialBLE, + IIDExponentialBLEGridSearchCV, + IIDExponentialLineageTracer, + IIDExponentialPosteriorMeanBLE, + IIDExponentialPosteriorMeanBLEGridSearchCV) + + +class TestIIDExponentialBLE(unittest.TestCase): + def test_no_mutations(self): + r""" + Tree topology is just a branch 0->1. 
+ There is one unmutated character i.e.: + root [state = '0'] + | + v + child [state = '0'] + This is thus the simplest possible example of no mutations, and the MLE + branch length should be 0 + """ + tree = nx.DiGraph() + tree.add_node("0"), tree.add_node("1") + tree.add_edge("0", "1") + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [0]} + ) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal(tree.get_branch_length("0", "1"), 0.0) + np.testing.assert_almost_equal(tree.get_time("0"), 0.0) + np.testing.assert_almost_equal(tree.get_time("1"), 0.0) + np.testing.assert_almost_equal(log_likelihood, 0.0) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_saturation(self): + r""" + Tree topology is just a branch 0->1. + There is one mutated character i.e.: + root [state = '0'] + | + v + child [state = '1'] + This is thus the simplest possible example of saturation, and the MLE + branch length should be infinity (>15 for all practical purposes) + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1"]) + tree.add_edge("0", "1") + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [1]} + ) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + assert tree.get_branch_length("0", "1") > 15.0 + assert tree.get_time("1") > 15.0 + np.testing.assert_almost_equal(tree.get_time("0"), 0.0) + np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=5) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_hand_solvable_problem_1(self): + r""" + Tree topology is just a branch 0->1. + There is one mutated character and one unmutated character, i.e.: + root [state = '00'] + | + v + child [state = '01'] + The solution can be verified by hand. The optimization problem is: + min_{r * t0} log(exp(-r * t0)) + log(1 - exp(-r * t0)) + The solution is r * t0 = ln(2) ~ 0.693 + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1"]) + tree.add_edge("0", "1") + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0], + "1": [0, 1]} + ) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal( + tree.get_branch_length("0", "1"), np.log(2), decimal=3 + ) + np.testing.assert_almost_equal(tree.get_time("1"), np.log(2), decimal=3) + np.testing.assert_almost_equal(tree.get_time("0"), 0.0) + np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_hand_solvable_problem_2(self): + r""" + Tree topology is just a branch 0->1. + There are two mutated characters and one unmutated character, i.e.: + root [state = '000'] + | + v + child [state = '011'] + The solution can be verified by hand. 
The optimization problem is: + min_{r * t0} log(exp(-r * t0)) + 2 * log(1 - exp(-r * t0)) + The solution is r * t0 = ln(3) ~ 1.098 + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1"]) + tree.add_edge("0", "1") + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0, 0], + "1": [0, 1, 1]} + ) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal( + tree.get_branch_length("0", "1"), np.log(3), decimal=3 + ) + np.testing.assert_almost_equal(tree.get_time("1"), np.log(3), decimal=3) + np.testing.assert_almost_equal(tree.get_time("0"), 0.0) + np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_hand_solvable_problem_3(self): + r""" + Tree topology is just a branch 0->1. + There are two unmutated characters and one mutated character, i.e.: + root [state = '000'] + | + v + child [state = '001'] + The solution can be verified by hand. The optimization problem is: + min_{r * t0} 2 * log(exp(-r * t0)) + log(1 - exp(-r * t0)) + The solution is r * t0 = ln(1.5) ~ 0.405 + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1"]) + tree.add_edge("0", "1") + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0, 0], + "1": [0, 0, 1]} + ) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal( + tree.get_branch_length("0", "1"), np.log(1.5), decimal=3 + ) + np.testing.assert_almost_equal(tree.get_time("1"), np.log(1.5), decimal=3) + np.testing.assert_almost_equal(tree.get_time("0"), 0.0) + np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_small_tree_with_no_mutations(self): + r""" + Perfect binary tree with no mutations: Should give edges of length 0 + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]) + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0, 0, 0], + "1": [0, 0, 0, 0], + "2": [0, 0, 0, 0], + "3": [0, 0, 0, 0], + "4": [0, 0, 0, 0], + "5": [0, 0, 0, 0], + "6": [0, 0, 0, 0]} + ) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + for edge in tree.edges: + np.testing.assert_almost_equal( + tree.get_branch_length(*edge), 0, decimal=3 + ) + np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_small_tree_with_one_mutation(self): + r""" + Perfect binary tree with one mutation at a node 6: Should give very short + edges 1->3,1->4,0->2 and very long edges 0->1,2->5,2->6. 
+ The problem can be solved by hand: it trivially reduces to a 1-dimensional + problem: + min_{r * t0} 2 * log(exp(-r * t0)) + log(1 - exp(-r * t0)) + The solution is r * t0 = ln(1.5) ~ 0.405 + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [0], + "2": [0], + "3": [0], + "4": [0], + "5": [0], + "6": [1]} + ) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal(tree.get_branch_length("0", "1"), 0.405, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("0", "2"), 0.0, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("1", "3"), 0.0, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("1", "4"), 0.0, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("2", "5"), 0.405, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("2", "6"), 0.405, decimal=3) + np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_small_tree_with_saturation(self): + r""" + Perfect binary tree with saturation. The edges which saturate should thus + have length infinity (>15 for all practical purposes) + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [0], + "2": [1], + "3": [1], + "4": [1], + "5": [1], + "6": [1]} + ) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + assert tree.get_branch_length("0", "2") > 15.0 + assert tree.get_branch_length("1", "3") > 15.0 + assert tree.get_branch_length("1", "4") > 15.0 + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_small_tree_regression(self): + r""" + Regression test. Cannot be solved by hand. We just check that this solution + never changes. 
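+        The branch lengths and log-likelihood asserted below were produced
+        by an earlier run of the estimator and serve purely as a regression
+        baseline.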
+ """ + # Perfect binary tree with normal amount of mutations on each edge + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0, 0, 0, 0, 0, 0, 0, 0], + "1": [1, 0, 0, 0, 0, 0, 0, 0, 0], + "2": [0, 0, 0, 0, 0, 6, 0, 0, 0], + "3": [1, 2, 0, 0, 0, 0, 0, 0, 0], + "4": [1, 0, 3, 0, 0, 0, 0, 0, 0], + "5": [0, 0, 0, 0, 5, 6, 7, 0, 0], + "6": [0, 0, 0, 4, 0, 6, 0, 8, 9]} + ) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal(tree.get_branch_length("0", "1"), 0.203, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("0", "2"), 0.082, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("1", "3"), 0.175, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("1", "4"), 0.175, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("2", "5"), 0.295, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("2", "6"), 0.295, decimal=3) + np.testing.assert_almost_equal(log_likelihood, -22.689, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_small_symmetric_tree(self): + r""" + Symmetric tree should have equal length edges. + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0, 0], + "1": [1, 0, 0], + "2": [1, 0, 0], + "3": [1, 1, 0], + "4": [1, 1, 0], + "5": [1, 1, 0], + "6": [1, 1, 0]} + ) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal( + tree.get_branch_length("0", "1"), tree.get_branch_length("0", "2") + ) + np.testing.assert_almost_equal( + tree.get_branch_length("1", "3"), tree.get_branch_length("1", "4") + ) + np.testing.assert_almost_equal( + tree.get_branch_length("1", "4"), tree.get_branch_length("2", "5") + ) + np.testing.assert_almost_equal( + tree.get_branch_length("2", "5"), tree.get_branch_length("2", "6") + ) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_small_tree_with_infinite_legs(self): + r""" + Perfect binary tree with saturated leaves. 
The first level of the tree + should be normal (can be solved by hand, solution is log(2)), + the branches for the leaves should be infinity (>15 for all practical + purposes) + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0], + "1": [1, 0], + "2": [1, 0], + "3": [1, 1], + "4": [1, 1], + "5": [1, 1], + "6": [1, 1]} + ) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal(tree.get_branch_length("0", "1"), 0.693, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("0", "2"), 0.693, decimal=3) + assert tree.get_branch_length("1", "3") > 15 + assert tree.get_branch_length("1", "4") > 15 + assert tree.get_branch_length("2", "5") > 15 + assert tree.get_branch_length("2", "6") > 15 + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_on_simulated_data(self): + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.set_times( + {"0": 0, + "1": 0.1, + "2": 0.9, + "3": 1.0, + "4": 1.0, + "5": 1.0, + "6": 1.0} + ) + np.random.seed(1) + IIDExponentialLineageTracer( + mutation_rate=1.0, num_characters=100 + ).overlay_lineage_tracing_data(tree) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + assert 0.05 < tree.get_time("1") < 0.15 + assert 0.8 < tree.get_time("2") < 1.0 + assert 0.9 < tree.get_time("3") < 1.1 + assert 0.9 < tree.get_time("4") < 1.1 + assert 0.9 < tree.get_time("5") < 1.1 + assert 0.9 < tree.get_time("6") < 1.1 + np.testing.assert_almost_equal(tree.get_time("0"), 0) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_subtree_collapses_when_no_mutations(self): + r""" + A subtree with no mutations should collapse to 0. 
It reduces the problem to + the same as in 'test_hand_solvable_problem_1' + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4"]), + tree.add_edges_from([("0", "1"), ("1", "2"), ("1", "3"), ("0", "4")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [1], + "2": [1], + "3": [1], + "4": [0]} + ) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal( + tree.get_branch_length("0", "1"), np.log(2), decimal=3 + ) + np.testing.assert_almost_equal(tree.get_branch_length("1", "2"), 0.0, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("1", "3"), 0.0, decimal=3) + np.testing.assert_almost_equal( + tree.get_branch_length("0", "4"), np.log(2), decimal=3 + ) + np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_minimum_branch_length(self): + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4"]) + tree.add_edges_from([("0", "1"), ("0", "2"), ("0", "3"), ("2", "4")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_character_states_at_leaves( + {"1": [1], + "3": [1], + "4": [1]} + ) + tree.reconstruct_ancestral_characters(zero_the_root=True) + # Too large a minimum_branch_length + model = IIDExponentialBLE(minimum_branch_length=0.6) + model.estimate_branch_lengths(tree) + for node in tree.nodes: + print(f"{node} = {tree.get_time(node)}") + assert model.log_likelihood == -np.inf + # An okay minimum_branch_length + model = IIDExponentialBLE(minimum_branch_length=0.4) + model.estimate_branch_lengths(tree) + assert model.log_likelihood != -np.inf + + +class TestIIDExponentialBLEGridSearchCV(unittest.TestCase): + def test_IIDExponentialBLEGridSearchCV_smoke(self): + r""" + Just want to see that it runs in both single and multiprocessor mode + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1"]), + tree.add_edges_from([("0", "1")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [1]}, + ) + for processes in [1, 2]: + model = IIDExponentialBLEGridSearchCV( + minimum_branch_lengths=(1.0,), + l2_regularizations=(1.0,), + verbose=True, + processes=processes, + ) + model.estimate_branch_lengths(tree) + + def test_IIDExponentialBLEGridSearchCV(self): + r""" + We make sure to test a tree for which no regularization produces + a likelihood of 0. 
+ """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6", "7"]), + tree.add_edges_from( + [("0", "1"), ("1", "2"), ("1", "3"), ("2", "4"), ("2", "5"), ("3", "6"), + ("3", "7")] + ) + tree = CassiopeiaTree(tree=tree) + tree.initialize_character_states_at_leaves( + {"4": [1, 1, 0], + "5": [1, 1, 0], + "6": [1, 0, 0], + "7": [1, 0, 0]}, + ) + tree.reconstruct_ancestral_characters(zero_the_root=True) + model = IIDExponentialBLEGridSearchCV( + minimum_branch_lengths=(0, 0.2, 4.0), + l2_regularizations=(0.0, 2.0, 4.0), + verbose=True, + processes=6, + ) + model.estimate_branch_lengths(tree) + print(model.grid) + assert model.grid[0, 0] == -np.inf + + # import seaborn as sns + # import matplotlib.pyplot as plt + # sns.heatmap( + # model.grid, + # yticklabels=model.minimum_branch_lengths, + # xticklabels=model.l2_regularizations, + # mask=np.isneginf(model.grid), + # ) + # plt.ylabel("minimum_branch_length") + # plt.xlabel("l2_regularization") + # plt.show() + + np.testing.assert_almost_equal(model.minimum_branch_length, 0.2) + np.testing.assert_almost_equal(model.l2_regularization, 2.0) + + +def get_z_scores( + repetition, + birth_rate_true, + mutation_rate_true, + birth_rate_model, + mutation_rate_model, + num_characters, +): + r""" + This function is at the global scope because it needs to be pickled + for parallelization. + """ + np.random.seed(repetition) + tree = BirthProcess( + birth_rate=birth_rate_true, tree_depth=1.0 + ).simulate_lineage() + tree_true = deepcopy(tree) + IIDExponentialLineageTracer( + mutation_rate=mutation_rate_true, num_characters=num_characters + ).overlay_lineage_tracing_data(tree) + discretization_level = 100 + model = IIDExponentialPosteriorMeanBLE( + birth_rate=birth_rate_model, + mutation_rate=mutation_rate_model, + discretization_level=discretization_level, + use_cpp_implementation=True + ) + model.estimate_branch_lengths(tree) + z_scores = [] + if len(tree.non_root_internal_nodes) > 0: + for node in [np.random.choice(tree.non_root_internal_nodes)]: + true_age = tree_true.get_time(node) + z_score = model.posteriors[node][ + : int(true_age * discretization_level) + ].sum() + z_scores.append(z_score) + return z_scores + + +def get_z_scores_under_true_model(repetition): + r""" + This function is at the global scope because it needs to be pickled + for parallelization. + """ + return get_z_scores( + repetition, + birth_rate_true=0.8, + mutation_rate_true=1.2, + birth_rate_model=0.8, + mutation_rate_model=1.2, + num_characters=3, + ) + + +def get_z_scores_under_misspecified_model(repetition): + r""" + This function is at the global scope because it needs to be pickled + for parallelization. + """ + return get_z_scores( + repetition, + birth_rate_true=0.4, + mutation_rate_true=0.6, + birth_rate_model=0.8, + mutation_rate_model=1.2, + num_characters=3, + ) + + +class TestIIDExponentialPosteriorMeanBLE(unittest.TestCase): + @parameterized.expand([("cpp", True), ("no_cpp", False)]) + def test_IIDExponentialPosteriorMeanBLE(self, name, use_cpp_implementation): + r""" + For a small tree with only one internal node, the likelihood of the data, + and the posterior age of the internal node, can be computed easily in + closed form. We check the theoretical values against those obtained from + our model. 
+ """ + from scipy.special import logsumexp + + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3"]) + tree.add_edges_from([("0", "1"), ("1", "2"), ("1", "3")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0, 0, 0, 0, 0, 0, 0, 0], + "1": [0, 1, 0, 0, 0, 0, 1, 1, 0], + "2": [0, 1, 0, 1, 1, 0, 1, 1, 1], + "3": [0, 1, 1, 1, 0, 0, 1, 1, 1]}, + ) + + mutation_rate = 0.3 + birth_rate = 0.8 + discretization_level = 200 + model = IIDExponentialPosteriorMeanBLE( + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + use_cpp_implementation=use_cpp_implementation + ) + + model.estimate_branch_lengths(tree) + print(f"{model.log_likelihood} = model.log_likelihood") + + # Test the model log likelihood vs its computation from the joint of the + # age of vertex 1. + model_log_joints = model.log_joints[ + "1" + ] # log P(t_1 = t, X, T) where t_1 is the age of the first node. + model_log_likelihood_2 = logsumexp(model_log_joints) + print(f"{model_log_likelihood_2} = {model_log_likelihood_2}") + np.testing.assert_approx_equal( + model.log_likelihood, model_log_likelihood_2, significant=3 + ) + + # Test the model log likelihood vs its computation from the leaf nodes. + for leaf in ["2", "3"]: + model_log_likelihood_up = model.up( + leaf, discretization_level, tree.get_number_of_mutated_characters_in_node(leaf) + ) - np.log(birth_rate * 1.0 / discretization_level) + print(f"{model_log_likelihood_up} = model_log_likelihood_up") + np.testing.assert_approx_equal( + model.log_likelihood, model_log_likelihood_up, significant=3 + ) + + # Test the model log likelihood against its numerical computation + numerical_log_likelihood = ( + IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( + tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + ) + ) + print(f"{numerical_log_likelihood} = numerical_log_likelihood") + np.testing.assert_approx_equal( + model.log_likelihood, numerical_log_likelihood, significant=3 + ) + + # Test the _whole_ array of log joints P(t_v = t, X, T) against its + # numerical computation + numerical_log_joint = IIDExponentialPosteriorMeanBLE.numerical_log_joint( + tree=tree, + node="1", + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + np.testing.assert_array_almost_equal( + model.log_joints["1"][50:-50], numerical_log_joint[50:-50], decimal=1 + ) + + # Test the model posterior against its numerical posterior + numerical_posterior = IIDExponentialPosteriorMeanBLE.numerical_posterior( + tree=tree, + node="1", + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + # import matplotlib.pyplot as plt + # plt.plot(model.posteriors[1]) + # plt.show() + # plt.plot(numerical_posterior) + # plt.show() + total_variation = np.sum(np.abs(model.posteriors["1"] - numerical_posterior)) + assert total_variation < 0.03 + + # Test the posterior mean against the numerical posterior mean. 
+ numerical_posterior_mean = np.sum( + numerical_posterior + * np.array(range(discretization_level + 1)) + / discretization_level + ) + posterior_mean = tree.get_time("1") + np.testing.assert_approx_equal( + posterior_mean, numerical_posterior_mean, significant=2 + ) + + @parameterized.expand([("cpp", True), ("no_cpp", False)]) + def test_IIDExponentialPosteriorMeanBLE_2(self, name, use_cpp_implementation): + r""" + We run the Bayesian estimator on a small tree with all different leaves, + and then check that: + - The likelihood of the data, computed from all of the leaves, is the same. + - The posteriors of the internal node ages matches their numerical + counterpart. + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0], + "1": [0, 0], + "2": [1, 0], + "3": [0, 0], + "4": [0, 1], + "5": [1, 0], + "6": [1, 1]}, + ) + + mutation_rate = 0.625 + birth_rate = 0.75 + discretization_level = 100 + model = IIDExponentialPosteriorMeanBLE( + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + use_cpp_implementation=use_cpp_implementation + ) + + model.estimate_branch_lengths(tree) + print(model.log_likelihood) + + # Test the model log likelihood against its numerical computation + numerical_log_likelihood = ( + IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( + tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + ) + ) + np.testing.assert_approx_equal( + model.log_likelihood, numerical_log_likelihood, significant=3 + ) + + # Check that the likelihood computed from each leaf node is correct. + for leaf in tree.leaves: + model_log_likelihood_up = model.up( + leaf, discretization_level, tree.get_number_of_mutated_characters_in_node(leaf) + ) - np.log(birth_rate * 1.0 / discretization_level) + print(model_log_likelihood_up) + np.testing.assert_approx_equal( + model.log_likelihood, model_log_likelihood_up, significant=3 + ) + + model_log_likelihood_up_wrong = model.up( + leaf, discretization_level, (tree.get_number_of_mutated_characters_in_node(leaf) + 1) % 2 + ) + with pytest.raises(AssertionError): + np.testing.assert_approx_equal( + model.log_likelihood, + model_log_likelihood_up_wrong, + significant=3, + ) + + # Check that the posterior ages of the nodes are correct. + for node in tree.non_root_internal_nodes: + numerical_log_joint = ( + IIDExponentialPosteriorMeanBLE.numerical_log_joint( + tree=tree, + node=node, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + ) + np.testing.assert_array_almost_equal( + model.log_joints[node][25:-25], + numerical_log_joint[25:-25], + decimal=1, + ) + + # Test the model posterior against its numerical posterior. 
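+            # The numerical posterior is the normalized exponential of the
+            # numerical log-joint; agreement with the model's posterior is
+            # measured in total-variation distance.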
+ numerical_posterior = np.exp( + numerical_log_joint - numerical_log_joint.max() + ) + numerical_posterior /= numerical_posterior.sum() + # import matplotlib.pyplot as plt + # plt.plot(model.posteriors[node]) + # plt.show() + # plt.plot(analytical_posterior) + # plt.show() + total_variation = np.sum( + np.abs(model.posteriors[node] - numerical_posterior) + ) + assert total_variation < 0.03 + + @pytest.mark.slow + @parameterized.expand([("cpp", True), ("no_cpp", False)]) + def test_IIDExponentialPosteriorMeanBLE_3(self, name, use_cpp_implementation): + r""" + Same as test_IIDExponentialPosteriorMeanBLE_2 but with a weirder topology. + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6", "7"]), + tree.add_edges_from( + [("0", "1"), ("1", "2"), ("1", "3"), ("2", "4"), ("2", "5"), ("2", "6"), + ("0", "7")] + ) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0], + "1": [0, 0], + "2": [1, 0], + "3": [1, 1], + "4": [1, 0], + "5": [1, 0], + "6": [1, 1], + "7": [0, 0]}, + ) + + mutation_rate = 0.625 + birth_rate = 0.75 + discretization_level = 100 + model = IIDExponentialPosteriorMeanBLE( + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + use_cpp_implementation=use_cpp_implementation + ) + + model.estimate_branch_lengths(tree) + print(model.log_likelihood) + + # Test the model log likelihood against its numerical computation + numerical_log_likelihood = ( + IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( + tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + ) + ) + np.testing.assert_approx_equal( + model.log_likelihood, numerical_log_likelihood, significant=3 + ) + + # Check that the likelihood computed from each leaf node is correct. + for leaf in tree.leaves: + model_log_likelihood_up = model.up( + leaf, discretization_level, tree.get_number_of_mutated_characters_in_node(leaf) + ) - np.log(birth_rate * 1.0 / discretization_level) + print(model_log_likelihood_up) + np.testing.assert_approx_equal( + model.log_likelihood, model_log_likelihood_up, significant=2 + ) + + # Check that the posterior ages of the nodes are correct. + for node in tree.non_root_internal_nodes: + numerical_log_joint = ( + IIDExponentialPosteriorMeanBLE.numerical_log_joint( + tree=tree, + node=node, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + ) + np.testing.assert_array_almost_equal( + model.log_joints[node][25:-25], + numerical_log_joint[25:-25], + decimal=1, + ) + + # Test the model posterior against its numerical posterior. + numerical_posterior = np.exp( + numerical_log_joint - numerical_log_joint.max() + ) + numerical_posterior /= numerical_posterior.sum() + # import matplotlib.pyplot as plt + # plt.plot(model.posteriors[node]) + # plt.show() + # plt.plot(numerical_posterior) + # plt.show() + total_variation = np.sum( + np.abs(model.posteriors[node] - numerical_posterior) + ) + assert total_variation < 0.03 + + @pytest.mark.slow + @parameterized.expand([("cpp", True), ("no_cpp", False)]) + def test_IIDExponentialPosteriorMeanBLE_DREAM_subC1(self, name, use_cpp_implementation): + r""" + A tree from the DREAM subchallenge 1, verified analytically. 
+ """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_character_states_at_leaves( + {"3": [2, 0, 1, 1, 0, 0, 0, 1, 1, 1], + "4": [2, 0, 1, 1, 0, 1, 0, 1, 1, 1], + "5": [2, 0, 1, 1, 0, 1, 0, 1, 1, 1], + "6": [2, 0, 1, 1, 0, 1, 0, 1, 1, 1]}, + ) + tree.reconstruct_ancestral_characters(zero_the_root=True) + + mutation_rate = 0.6 + birth_rate = 0.8 + discretization_level = 500 + model = IIDExponentialPosteriorMeanBLE( + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + use_cpp_implementation=use_cpp_implementation + ) + + model.estimate_branch_lengths(tree) + print(model.log_likelihood) + + # Test the model log likelihood against its numerical computation + numerical_log_likelihood = ( + IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( + tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + ) + ) + np.testing.assert_approx_equal( + model.log_likelihood, numerical_log_likelihood, significant=3 + ) + + # Check that the likelihood computed from each leaf node is correct. + for leaf in tree.leaves: + model_log_likelihood_up = model.up( + leaf, discretization_level, tree.get_number_of_mutated_characters_in_node(leaf) + ) - np.log(birth_rate * 1.0 / discretization_level) + print(model_log_likelihood_up) + np.testing.assert_approx_equal( + model.log_likelihood, model_log_likelihood_up, significant=3 + ) + + # Check that the posterior ages of the nodes are correct. + for node in tree.non_root_internal_nodes: + numerical_log_joint = ( + IIDExponentialPosteriorMeanBLE.numerical_log_joint( + tree=tree, + node=node, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + ) + mean_error = np.mean( + np.abs(model.log_joints[node][25:-25] - numerical_log_joint[25:-25]) + / np.abs(numerical_log_joint[25:-25]) + ) + assert mean_error < 0.03 + + # Test the model posterior against its numerical posterior. + numerical_posterior = np.exp( + numerical_log_joint - numerical_log_joint.max() + ) + numerical_posterior /= numerical_posterior.sum() + # import matplotlib.pyplot as plt + # plt.plot(model.posteriors[node]) + # plt.show() + # plt.plot(numerical_posterior) + # plt.show() + total_variation = np.sum( + np.abs(model.posteriors[node] - numerical_posterior) + ) + assert total_variation < 0.05 + + @pytest.mark.slow + def test_IIDExponentialPosteriorMeanBLE_posterior_calibration(self): + r""" + Under the true model, the Z scores should be ~Unif[0, 1] + Under the wrong model, the Z scores should not be ~Unif[0, 1] + This test is slow because we need to make many repetitions to get + enough statistical power for the test to be meaningful. + We use p-values computed from the Hoeffding bound. + TODO: There might be a more powerful test, e.g. Kolmogorov–Smirnov? + (This would mean we need less repetitions and can make the test faster.) + This test uses the c++ implementation to be faster. 
+ """ + repetitions = 1000 + + # Under the true model, the Z scores should be ~Unif[0, 1] + with multiprocessing.Pool(processes=6) as pool: + z_scores = pool.map(get_z_scores_under_true_model, range(repetitions)) + z_scores = np.array(list(itertools.chain(*z_scores))) + mean_z_score = z_scores.mean() + p_value = 2 * np.exp(-2 * repetitions * (mean_z_score - 0.5) ** 2) + print(f"p_value under true model = {p_value}") + assert p_value > 0.01 + # import matplotlib.pyplot as plt + # plt.hist(z_scores, bins=10) + # plt.show() + + # Under the wrong model, the Z scores should not be ~Unif[0, 1] + with multiprocessing.Pool(processes=6) as pool: + z_scores = pool.map( + get_z_scores_under_misspecified_model, range(repetitions) + ) + z_scores = np.array(list(itertools.chain(*z_scores))) + mean_z_score = z_scores.mean() + p_value = 2 * np.exp(-2 * repetitions * (mean_z_score - 0.5) ** 2) + print(f"p_value under misspecified model = {p_value}") + assert p_value < 0.01 + # import matplotlib.pyplot as plt + # plt.hist(z_scores, bins=10) + # plt.show() + + +class TestIIDExponentialPosteriorMeanBLEGridSeachCV(unittest.TestCase): + @parameterized.expand([("cpp", True), ("no_cpp", False)]) + def test_IIDExponentialPosteriorMeanBLEGridSeachCV_smoke(self, name, use_cpp_implementation): + r""" + Just want to see that it runs in both single and multiprocessor mode + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1"]), + tree.add_edges_from([("0", "1")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [1]} + ) + for processes in [1, 2]: + model = IIDExponentialPosteriorMeanBLEGridSearchCV( + mutation_rates=(0.5,), + birth_rates=(1.5,), + discretization_level=5, + verbose=True, + use_cpp_implementation=use_cpp_implementation + ) + model.estimate_branch_lengths(tree) + + @parameterized.expand([("cpp", True), ("no_cpp", False)]) + def test_IIDExponentialPosteriorMeanBLEGridSeachCV(self, name, use_cpp_implementation): + r""" + We just check that the grid search estimator does its job on a small grid. 
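+ The log likelihood at the selected grid point is compared against a
+ numerical computation, and the selected hyperparameters
+ (mutation_rate = 0.75, birth_rate = 0.5) and the posterior mean time of
+ node "1" are checked against precomputed reference values.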
+ """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4"]), + tree.add_edges_from([("0", "1"), ("1", "2"), ("1", "3"), ("0", "4")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [1], + "2": [1], + "3": [1], + "4": [0]} + ) + + discretization_level = 100 + mutation_rates = (0.625, 0.750, 0.875) + birth_rates = (0.25, 0.50, 0.75) + model = IIDExponentialPosteriorMeanBLEGridSearchCV( + mutation_rates=mutation_rates, + birth_rates=birth_rates, + discretization_level=discretization_level, + verbose=True, + use_cpp_implementation=use_cpp_implementation + ) + + # Test the model log likelihood against its numerical computation + model.estimate_branch_lengths(tree) + numerical_log_likelihood = ( + IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( + tree=tree, + mutation_rate=model.mutation_rate, + birth_rate=model.birth_rate, + ) + ) + np.testing.assert_approx_equal( + model.log_likelihood, numerical_log_likelihood, significant=3 + ) + + # import matplotlib.pyplot as plt + # import seaborn as sns + # sns.heatmap( + # model.grid, + # yticklabels=mutation_rates, + # xticklabels=birth_rates + # ) + # plt.ylabel('Mutation Rate') + # plt.xlabel('Birth Rate') + # plt.show() + + np.testing.assert_almost_equal(model.mutation_rate, 0.75) + np.testing.assert_almost_equal(model.birth_rate, 0.5) + np.testing.assert_almost_equal(model.posterior_means["1"], 0.6815, decimal=3) + + +class TestBLEMultifurcationWrapper(unittest.TestCase): + def test_smoke(self): + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3"]) + tree.add_edges_from([("0", "1"), ("0", "2"), ("0", "3")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0], + "1": [0, 1], + "2": [0, 1], + "3": [0, 1]} + ) + model = BLEMultifurcationWrapper(IIDExponentialBLE()) + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal( + tree.get_branch_length("0", "1"), np.log(2), decimal=3 + ) + np.testing.assert_almost_equal( + tree.get_branch_length("0", "2"), np.log(2), decimal=3 + ) + np.testing.assert_almost_equal( + tree.get_branch_length("0", "3"), np.log(2), decimal=3 + ) + np.testing.assert_almost_equal(tree.get_time("1"), np.log(2), decimal=3) + np.testing.assert_almost_equal(tree.get_time("2"), np.log(2), decimal=3) + np.testing.assert_almost_equal(tree.get_time("3"), np.log(2), decimal=3) + np.testing.assert_almost_equal(tree.get_time("0"), 0.0) + np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + # The tree topology that the estimator sees is different from the + # one in the final phylogeny, thus the lik will be different! 
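+ # In the final multifurcating phylogeny, each of the three edges has
+ # length log(2) and exactly one of its two characters mutated, so under
+ # the IID exponential model each edge should contribute
+ # -log(2) + log(1 - exp(-log(2))) = log(1/4) ~ -1.386 to the log
+ # likelihood, hence the factor of 3 below.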
+ np.testing.assert_almost_equal(log_likelihood * 3, log_likelihood_2, decimal=3) diff --git a/test/tools_tests/lineage_simulator_test.py b/test/tools_tests/lineage_simulator_test.py new file mode 100644 index 00000000..c3142689 --- /dev/null +++ b/test/tools_tests/lineage_simulator_test.py @@ -0,0 +1,127 @@ +import pytest +import unittest + +import numpy as np + +from cassiopeia.tools import ( + PerfectBinaryTree, + PerfectBinaryTreeWithRootBranch, + BirthProcess, + TumorWithAFitSubclone, +) + + +class TestPerfectBinaryTree(unittest.TestCase): + def test_PerfectBinaryTree(self): + tree = PerfectBinaryTree( + generation_branch_lengths=[2, 3] + ).simulate_lineage() + newick = tree.get_newick() + assert newick == "((3,4),(5,6));" + self.assertDictEqual( + tree.get_times(), {0: 0, 1: 2, 2: 2, 3: 5, 4: 5, 5: 5, 6: 5} + ) + + +class TestPerfectBinaryTreeWithRootBranch(unittest.TestCase): + def test_PerfectBinaryTreeWithRootBranch(self): + tree = PerfectBinaryTreeWithRootBranch( + generation_branch_lengths=[2, 3, 4] + ).simulate_lineage() + newick = tree.get_newick() + assert newick == "(((4,5),(6,7)));" + self.assertDictEqual( + tree.get_times(), {0: 0, 1: 2, 2: 5, 3: 5, 4: 9, 5: 9, 6: 9, 7: 9} + ) + + +class TestBirthProcess(unittest.TestCase): + @pytest.mark.slow + def test_BirthProcess(self): + r""" + Generate tree, then choose a random lineage can count how many nodes + are on the lineage. This is the number of times the process triggered + on that lineage. + + Also, the probability that a tree with only one internal node is + obtained is e^-lam * (1 - e^-lam) where lam is the birth rate, so we + also check this. + """ + np.random.seed(1) + birth_rate = 0.6 + intensities = [] + repetitions = 10000 + topology_hits = 0 + + def num_ancestors(tree, node: int) -> int: + r""" + Number of ancestors of a node. Terribly inefficient implementation. + """ + res = 0 + root = tree.root + while node != root: + node = tree.parent(node) + res += 1 + return res + + for _ in range(repetitions): + tree_true = BirthProcess( + birth_rate=birth_rate, tree_depth=1.0 + ).simulate_lineage() + if len(tree_true.nodes) == 4: + topology_hits += 1 + leaf = np.random.choice(tree_true.leaves) + n_leaves = len(tree_true.leaves) + n_hits = num_ancestors(tree_true, leaf) - 1 + intensity = n_leaves / 2 ** n_hits * n_hits + intensities.append(intensity) + # Check that the probability of the topology matches + empirical_topology_prob = topology_hits / repetitions + theoretical_topology_prob = np.exp(-birth_rate) * ( + 1.0 - np.exp(-birth_rate) + ) + assert ( + np.abs(empirical_topology_prob - theoretical_topology_prob) < 0.02 + ) + inferred_birth_rate = np.array(intensities).mean() + print(f"{birth_rate} == {inferred_birth_rate}") + assert np.abs(birth_rate - inferred_birth_rate) < 0.05 + + +class TestTumorWithAFitSubclone(unittest.TestCase): + def test_TumorWithAFitSubclone(self): + r""" + Small test that can be drawn by hand. + Checks that the generated phylogeny is correct. 
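+ With branch_length=1.0 the unfit lineage advances in steps of 1.0; after
+ generations_until_fit_subclone=1 the fit subclone appears and divides
+ every branch_length_fit=0.5; stopping at experiment_duration=2.0 leaves
+ 3_unfit, 4_fit and 5_fit as the leaves.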
+ """ + tree = TumorWithAFitSubclone( + branch_length=1, + branch_length_fit=0.5, + experiment_duration=2, + generations_until_fit_subclone=1, + ).simulate_lineage() + self.assertListEqual( + tree.nodes, + ["0_unfit", "1_unfit", "2_fit", "3_unfit", "4_fit", "5_fit"], + ) + self.assertListEqual( + tree.edges, + [ + ("0_unfit", "1_unfit"), + ("1_unfit", "2_fit"), + ("1_unfit", "3_unfit"), + ("2_fit", "4_fit"), + ("2_fit", "5_fit"), + ], + ) + self.assertDictEqual( + tree.get_times(), + { + "0_unfit": 0.0, + "1_unfit": 1.0, + "2_fit": 1.5, + "3_unfit": 2.0, + "4_fit": 2.0, + "5_fit": 2.0, + }, + ) diff --git a/test/tools_tests/lineage_tracing_simulator_test.py b/test/tools_tests/lineage_tracing_simulator_test.py new file mode 100644 index 00000000..afb004dd --- /dev/null +++ b/test/tools_tests/lineage_tracing_simulator_test.py @@ -0,0 +1,35 @@ +import unittest + +import networkx as nx +import numpy as np + +from cassiopeia.data import CassiopeiaTree +from cassiopeia.tools import IIDExponentialLineageTracer + + +class Test(unittest.TestCase): + def test_smoke(self): + r""" + Just tests that lineage_tracing_simulator runs + """ + np.random.seed(1) + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from( + [ + ("0", "1"), + ("0", "2"), + ("1", "3"), + ("1", "4"), + ("2", "5"), + ("2", "6"), + ] + ) + tree = CassiopeiaTree(tree=tree) + tree.set_times( + {"0": 0, "1": 0.1, "2": 0.9, "3": 1.0, "4": 1.0, "5": 1.0, "6": 1.0} + ) + np.random.seed(1) + IIDExponentialLineageTracer( + mutation_rate=1.0, num_characters=10 + ).overlay_lineage_tracing_data(tree)