From 83d628792f7d8be7423d4f90897a3f5a9bf5c231 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Fri, 25 Dec 2020 16:28:31 -0800 Subject: [PATCH 01/61] Add codebase skeleton for branch length estimation --- cassiopeia/tools/__init__.py | 0 cassiopeia/tools/branch_length_estimation.py | 0 test/tools_tests/branch_length_estimation_test.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 cassiopeia/tools/__init__.py create mode 100644 cassiopeia/tools/branch_length_estimation.py create mode 100644 test/tools_tests/branch_length_estimation_test.py diff --git a/cassiopeia/tools/__init__.py b/cassiopeia/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cassiopeia/tools/branch_length_estimation.py b/cassiopeia/tools/branch_length_estimation.py new file mode 100644 index 00000000..e69de29b diff --git a/test/tools_tests/branch_length_estimation_test.py b/test/tools_tests/branch_length_estimation_test.py new file mode 100644 index 00000000..e69de29b From 93fcd990f24d0c90e18eebc6e0a1c8bf1f21d81c Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Fri, 25 Dec 2020 20:26:47 -0800 Subject: [PATCH 02/61] Add branch length estimation, lineage simulator, and phylogeny simulator --- cassiopeia/tools/__init__.py | 2 + cassiopeia/tools/branch_length_estimation.py | 128 +++++++ cassiopeia/tools/lineage_simulator.py | 39 ++ cassiopeia/tools/phylogeny_simulator.py | 56 +++ cassiopeia/tools/tree.py | 311 ++++++++++++++++ .../branch_length_estimation_test.py | 340 ++++++++++++++++++ .../lineage_tracing_simulator_test.py | 24 ++ test/tools_tests/phylogeny_simulator_test.py | 18 + test/tools_tests/tree_test.py | 130 +++++++ 9 files changed, 1048 insertions(+) create mode 100644 cassiopeia/tools/lineage_simulator.py create mode 100644 cassiopeia/tools/phylogeny_simulator.py create mode 100644 cassiopeia/tools/tree.py create mode 100644 
test/tools_tests/lineage_tracing_simulator_test.py create mode 100644 test/tools_tests/phylogeny_simulator_test.py create mode 100644 test/tools_tests/tree_test.py diff --git a/cassiopeia/tools/__init__.py b/cassiopeia/tools/__init__.py index e69de29b..0e144d30 100644 --- a/cassiopeia/tools/__init__.py +++ b/cassiopeia/tools/__init__.py @@ -0,0 +1,2 @@ +from .branch_length_estimation import PoissonConvexBLE +from .lineage_simulator import lineage_tracing_simulator diff --git a/cassiopeia/tools/branch_length_estimation.py b/cassiopeia/tools/branch_length_estimation.py index e69de29b..9cfbd938 100644 --- a/cassiopeia/tools/branch_length_estimation.py +++ b/cassiopeia/tools/branch_length_estimation.py @@ -0,0 +1,128 @@ +import abc +import cvxpy as cp +from .tree import Tree + + +class BranchLengthEstimator(abc.ABC): + r""" + Abstract base class for all branch length estimators. + """ + @abc.abstractmethod + def estimate_branch_lengths(self, tree: Tree) -> None: + r""" + Annotates the tree's nodes with their estimated age, and + the tree's branches with their lengths. + Operates on the tree in-place. + + Args: + tree: The tree for which to estimate branch lengths. + """ + + +class PoissonConvexBLE(BranchLengthEstimator): + r""" + A simple branch length estimator that assumes that the characters evolve IID + over the phylogeny with the same cutting rate. + + Maximum Parsimony is used to impute the ancestral states first. Doing so + leads to a convex optimization problem. + """ + def __init__( + self, + minimum_edge_length: float = 0, # TODO: minimum_branch_length? + l2_regularization: float = 0, + verbose: bool = False + ): + self.minimum_edge_length = minimum_edge_length + self.l2_regularization = l2_regularization + self.verbose = verbose + + def estimate_branch_lengths(self, tree: Tree) -> None: + r""" + Estimates branch lengths for T. 
+ + This is in fact an exponential cone program, which is a special kind of + convex problem: + https://docs.mosek.com/modeling-cookbook/expo.html + + Args: + tree: The tree for which to estimate branch lengths. + + Returns: + The log-likelihood under the model for the computed branch lengths. + """ + # Extract parameters + minimum_edge_length = self.minimum_edge_length + l2_regularization = self.l2_regularization + verbose = self.verbose + + # # Wrap the networkx DiGraph for goodies. + # T = Tree(tree) + T = tree + + # # # # # Create variables of the optimization problem # # # # # + r_X_t_variables = dict([(node_id, cp.Variable(name=f'r_X_t_{node_id}')) + for node_id in T.nodes()]) + time_increases_constraints = [ + r_X_t_variables[parent] + >= r_X_t_variables[child] + minimum_edge_length + for (parent, child) in T.edges() + ] + leaves_have_age_0_constraints =\ + [r_X_t_variables[leaf] == 0 for leaf in T.leaves()] + non_negative_r_X_t_constraints =\ + [r_X_t >= 0 for r_X_t in r_X_t_variables.values()] + all_constraints =\ + time_increases_constraints + \ + leaves_have_age_0_constraints + \ + non_negative_r_X_t_constraints + + # # # # # Compute the log-likelihood # # # # # + log_likelihood = 0 + + # Because all rates are equal, the number of cuts in each node is a + # sufficient statistic. This makes the solver WAY faster! + for (parent, child) in T.edges(): + edge_length = r_X_t_variables[parent] - r_X_t_variables[child] + zeros_parent = T.get_state(parent).count('0') # TODO: '0'... + zeros_child = T.get_state(child).count('0') # TODO: '0'... 
+ new_cuts_child = zeros_parent - zeros_child + assert(new_cuts_child >= 0) + # Add log-lik for characters that didn't get cut + log_likelihood += zeros_child * (-edge_length) + # Add log-lik for characters that got cut + log_likelihood += new_cuts_child * cp.log(1 - cp.exp(-edge_length)) + + # # # # # Add regularization # # # # # + + l2_penalty = 0 + for (parent, child) in T.edges(): + for child_of_child in T.children(child): + edge_length_above =\ + r_X_t_variables[parent] - r_X_t_variables[child] + edge_length_below =\ + r_X_t_variables[child] - r_X_t_variables[child_of_child] + l2_penalty += (edge_length_above - edge_length_below) ** 2 + l2_penalty *= l2_regularization + + # # # # # Solve the problem # # # # # + + obj = cp.Maximize(log_likelihood - l2_penalty) + prob = cp.Problem(obj, all_constraints) + + f_star = prob.solve(solver='ECOS', verbose=verbose) + + # # # # # Populate the tree with the estimated branch lengths # # # # # + + for node in T.nodes(): + T.set_age(node, age=r_X_t_variables[node].value) + + for (parent, child) in T.edges(): + new_edge_length =\ + r_X_t_variables[parent].value - r_X_t_variables[child].value + T.set_edge_length( + parent, + child, + length=new_edge_length) + + return f_star diff --git a/cassiopeia/tools/lineage_simulator.py b/cassiopeia/tools/lineage_simulator.py new file mode 100644 index 00000000..84cd37c2 --- /dev/null +++ b/cassiopeia/tools/lineage_simulator.py @@ -0,0 +1,39 @@ +import numpy as np +from .tree import Tree + + +def lineage_tracing_simulator( + T: Tree, + mutation_rate: float, + num_characters: float +) -> None: + r""" + Populates the phylogenetic tree T with lineage tracing characters. 
+ """ + def dfs(node: int, T: Tree): + node_state = T.get_state(node) + for child in T.children(node): + # Compute the state of the child + child_state = '' + edge_length = T.get_age(node) - T.get_age(child) + # print(f"{node} -> {child}, length {edge_length}") + assert(edge_length >= 0) + for i in range(num_characters): + # See what happens to character i + if node_state[i] != '0': + # The character has already mutated; there is nothing to do + child_state += node_state[i] + continue + else: + # Determine if the character will mutate. + mutates =\ + np.random.exponential(1.0 / mutation_rate) < edge_length + if mutates: + child_state += '1' + else: + child_state += '0' + T.set_state(child, child_state) + dfs(child, T) + root = T.root() + T.set_state(root, '0' * num_characters) + dfs(root, T) diff --git a/cassiopeia/tools/phylogeny_simulator.py b/cassiopeia/tools/phylogeny_simulator.py new file mode 100644 index 00000000..b9c6f438 --- /dev/null +++ b/cassiopeia/tools/phylogeny_simulator.py @@ -0,0 +1,56 @@ +from .tree import Tree +import networkx as nx + +from typing import List + + +def generate_perfect_binary_tree( + generation_branch_lengths: List[float] +) -> Tree: + r""" + See test for doc. 
+ """ + n_generations = len(generation_branch_lengths) + T = nx.DiGraph() + T.add_nodes_from(range(2 ** (n_generations + 1) - 1)) + edges = [(int((child - 1) / 2), child) for child in range(1, 2 ** (n_generations + 1) - 1)] + node_generation = [] + for i in range(n_generations + 1): + node_generation += [i] * 2 ** i + T.add_edges_from(edges) + for (parent, child) in edges: + parent_generation = node_generation[parent] + branch_length = generation_branch_lengths[parent_generation] + T.edges[parent, child]["length"] = branch_length + T.nodes[0]["age"] = sum(generation_branch_lengths) + for child in range(1, 2 ** (n_generations + 1) - 1): + child_generation = node_generation[child] + branch_length = generation_branch_lengths[child_generation - 1] + T.nodes[child]["age"] = T.nodes[int((child - 1) / 2)]["age"] - branch_length + return Tree(T) + + +def generate_perfect_binary_tree_with_root_branch( + generation_branch_lengths: List[float] +) -> Tree: + r""" + See test for doc. + """ + n_generations = len(generation_branch_lengths) + T = nx.DiGraph() + T.add_nodes_from(range(2 ** n_generations)) + edges = [(int(child / 2), child) for child in range(1, 2 ** n_generations)] + T.add_edges_from(edges) + node_generation = [0] + for i in range(n_generations): + node_generation += [i + 1] * 2 ** i + for (parent, child) in edges: + parent_generation = node_generation[parent] + branch_length = generation_branch_lengths[parent_generation] + T.edges[parent, child]["length"] = branch_length + T.nodes[0]["age"] = sum(generation_branch_lengths) + for child in range(1, 2 ** n_generations): + child_generation = node_generation[child] + branch_length = generation_branch_lengths[child_generation - 1] + T.nodes[child]["age"] = T.nodes[int(child / 2)]["age"] - branch_length + return Tree(T) diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py new file mode 100644 index 00000000..dc876d90 --- /dev/null +++ b/cassiopeia/tools/tree.py @@ -0,0 +1,311 @@ +from ete3 import Tree as 
ETEtree +from itolapi import Itol, ItolExport +import networkx as nx + +from typing import List, Optional, Tuple + + +def upload_to_itol_and_export_figure( + newick_tree: str, + apiKey: str, + projectName: str, + tree_name_in_iTOL: str, + figure_file: str, + plot_grid_line_scale: Optional[float] = None, + horizontal_scale_factor: Optional[int] = None, + bootstrap_display: bool = True, + verbose: bool = True +): + r""" + TODO: Can use 'metadata_source' to select what data to show!!! + For all export parameters see: https://itol.embl.de/help.cgi#export + :param tree_name_in_iTOL: The name of the uploaded tree in iTOL + :param figure_file: Name of the file where the tree figure will be exported to. + :param plot_grid_line_scale: If provided, the distance between lines on the grid. + :param verbose: Verbosity + """ + # Write out newick tree to file + tree_to_plot_file = "tree_to_plot.tree" + with open(tree_to_plot_file, "w") as file: + file.write(newick_tree) + + # Upload newick tree + itol_uploader = Itol() + itol_uploader.add_file("tree_to_plot.tree") + itol_uploader.params["treeName"] = tree_name_in_iTOL + itol_uploader.params["APIkey"] = apiKey + itol_uploader.params["projectName"] = projectName + good_upload = itol_uploader.upload() + if not good_upload: + print("There was an error:" + itol_uploader.comm.upload_output) + if verbose: + print("iTOL output: " + str(itol_uploader.comm.upload_output)) + print("Tree Web Page URL: " + itol_uploader.get_webpage()) + print("Warnings: " + str(itol_uploader.comm.warnings)) + tree_id = itol_uploader.comm.tree_id + + # Export tree. 
See https://itol.embl.de/help.cgi#export for all parameters + itol_exporter = ItolExport() + itol_exporter.set_export_param_value("tree", tree_id) + itol_exporter.set_export_param_value( + "format", figure_file.split(".")[-1] + ) # ['png', 'svg', 'eps', 'ps', 'pdf', 'nexus', 'newick'] + itol_exporter.set_export_param_value("display_mode", 1) # rectangular tree + itol_exporter.set_export_param_value("label_display", 1) # Possible values: 0 or 1 (0=hide labels, 1=show labels) + if plot_grid_line_scale is not None: + itol_exporter.set_export_param_value("internal_scale", 1) + itol_exporter.set_export_param_value("internalScale1", plot_grid_line_scale) + itol_exporter.set_export_param_value("internalScale2", plot_grid_line_scale) + itol_exporter.set_export_param_value("internalScale1Dashed", 1) + itol_exporter.set_export_param_value("internalScale2Dashed", 1) + else: + itol_exporter.set_export_param_value("tree_scale", 0) + if horizontal_scale_factor is not None: + itol_exporter.set_export_param_value( + "horizontal_scale_factor", horizontal_scale_factor + ) # doesnt actually scale the artboard + if bootstrap_display: + itol_exporter.set_export_param_value("bootstrap_display", 1) + itol_exporter.set_export_param_value("bootstrap_type", 2) + itol_exporter.set_export_param_value("bootstrap_label_size", 18) + # itol_exporter.set_export_param_value("bootstrap_label_position", 20) + # itol_exporter.set_export_param_value("bootstrap_symbol_position", 20) + # itol_exporter.set_export_param_value("bootstrap_label_sci", 1) + # itol_exporter.set_export_param_value("bootstrap_slider_min", -1) + # itol_exporter.set_export_param_value("bootstrap_symbol_position", 0) + # itol_exporter.set_export_param_value("branchlength_display", 1) + # itol_exporter.set_export_param_value("branchlength_label_rounding", 1) + # itol_exporter.set_export_param_value("branchlength_label_age", 1) + # itol_exporter.set_export_param_value("internalScale1Label", 1) + # 
itol_exporter.set_export_param_value("newick_format", "ID") + # itol_exporter.set_export_param_value("internalScale1Label", 1) + itol_exporter.set_export_param_value("leaf_sorting", 1) # Possible values: 1 or 2 (1=normal sorting, 2=no sorting) + print(f"Exporting tree to {figure_file}") + itol_exporter.export(figure_file) + + # Cleanup + # os.remove("tree_to_plot.tree") + + +def create_networkx_DiGraph_from_newick_file(file_path: str) -> nx.DiGraph: + def newick_to_network( + newick_filepath, + f=1 + ): + """ + Given a file path to a newick file, convert to a directed graph. + + :param newick_filepath: + File path to a newick text file + :param f: + Parameter to be passed to Ete3 while reading in the newick file. (Default 1) + :return: a networkx file of the tree + """ + + G = nx.DiGraph() # the new graph + tree = ETEtree(newick_filepath, format=f) + + # Create dict from ete3 node to cassiopeia.Node + # NOTE(sprillo): Instead of mapping to a Cassiopeia node, we'll map to a string (just the node name) + e2cass = {} + edge_lengths = {} + internal_node_id = 0 + for n in tree.traverse("postorder"): + node_name = '' + if n.name == '': + # print(f"Node without name, is internal.") + node_name = 'state-node-' + str(internal_node_id) + internal_node_id += 1 + else: + node_name = n.name + e2cass[n] = node_name + G.add_node(node_name) + edge_lengths[node_name] = n._dist + + for p in tree.traverse("postorder"): + pn = e2cass[p] + for c in p.children: + cn = e2cass[c] + G.add_edge(pn, cn) + G.edges[pn, cn]["length"] = edge_lengths[cn] + return G + + T = newick_to_network(file_path) + return T + + +class Tree(): + r""" + networkx.Digraph wrapper to isolate networkx dependency and add custom tree + methods. 
+ """ + def __init__(self, T: nx.DiGraph): + self.T = T + + def root(self) -> int: + T = self.T + root = [n for n in T if T.in_degree(n) == 0][0] + return root + + def leaves(self) -> List[int]: + T = self.T + leaves = [n for n in T if T.out_degree(n) == 0 and T.in_degree(n) == 1] + return leaves + + def internal_nodes(self) -> List[int]: + T = self.T + return [n for n in T if n != self.root() and n not in self.leaves()] + + def non_root_nodes(self) -> List[int]: + return self.leaves() + self.internal_nodes() + + def nodes(self): + T = self.T + return list(T.nodes()) + + def num_characters(self) -> int: + return len(self.T.nodes[0]["characters"]) + + def get_state(self, node: int) -> str: + T = self.T + return T.nodes[node]["characters"] + + def set_state(self, node: int, state: str) -> None: + T = self.T + T.nodes[node]["characters"] = state + + def set_states(self, node_state_list: List[Tuple[int, str]]) -> None: + for (node, state) in node_state_list: + self.set_state(node, state) + + def get_age(self, node: int) -> float: + T = self.T + return T.nodes[node]["age"] + + def set_age(self, node: int, age: float) -> None: + T = self.T + T.nodes[node]["age"] = age + + def edges(self) -> List[Tuple[int, int]]: + """List of (parent, child) tuples""" + T = self.T + return list(T.edges) + + def get_edge_length(self, parent: int, child: int) -> float: + T = self.T + assert parent in T + assert child in T[parent] + return T.edges[parent, child]["length"] + + def set_edge_length(self, parent: int, child: int, length: float) -> None: + T = self.T + assert parent in T + assert child in T[parent] + T.edges[parent, child]["length"] = length + + def set_edge_lengths( + self, + parent_child_and_length_list: List[Tuple[int, int, float]]) -> None: + for (parent, child, length) in parent_child_and_length_list: + self.set_edge_length(parent, child, length) + + def children(self, node: int) -> List[int]: + T = self.T + return list(T.adj[node]) + + def to_newick_tree_format( + self, + 
print_node_names: bool = True, + print_internal_nodes: bool = False, + append_state_to_node_name: bool = False, + print_pct_of_mutated_characters_along_edge: bool = False, + add_N_to_node_id: bool = False + ) -> str: + r""" + Converts tree into Newick tree format for viewing in e.g. ITOL. + Arguments: + print_internal_nodes: If True, prints the names of internal + nodes too. + print_pct_of_mutated_characters_along_edge: Self-explanatory + """ + leaves = self.leaves() + + def format_node(v: int): + node_id_prefix = '' if not add_N_to_node_id else 'N' + node_id = '' if not print_node_names else str(v) + node_suffix =\ + '' if not append_state_to_node_name\ + else '_' + str(self.get_state(v)) + return node_id_prefix + node_id + node_suffix + + def subtree_newick_representation(v: int) -> str: + if len(self.children(v)) == 0: + return format_node(v) + subtrees_newick = [] + for child in self.children(v): + edge_length = self.get_edge_length(v, child) + if child in leaves: + subtree_newick = subtree_newick_representation(child) + else: + subtree_newick =\ + '(' + subtree_newick_representation(child) + ')' + if print_internal_nodes: + subtree_newick += format_node(child) + # Add edge length + subtree_newick = subtree_newick + ':' + str(edge_length) + if print_pct_of_mutated_characters_along_edge: + # Also add number of mutations + number_of_unmutated_characters_in_parent =\ + self.get_state(v).count('0') + number_of_mutations_along_edge =\ + self.get_state(v).count('0')\ + - self.get_state(child).count('0') + pct_of_mutated_characters_along_edge =\ + number_of_mutations_along_edge /\ + (number_of_unmutated_characters_in_parent + 1e-100) + subtree_newick = subtree_newick +\ + "[&&NHX:muts="\ + f"{self._fmt(pct_of_mutated_characters_along_edge)}]" + subtrees_newick.append(subtree_newick) + newick = ','.join(subtrees_newick) + return newick + + root = self.root() + res = '(' + subtree_newick_representation(root) + ')' + if print_internal_nodes: + res += format_node(root) + 
res += ');' + return res + + def _fmt(self, x: float): + return '%.2f' % x + + def reconstruct_ancestral_states(self): + r""" + Reconstructs ancestral states with maximum parsimony. + """ + root = self.root() + + def dfs(v: int) -> None: + children = self.children(v) + n_children = len(children) + if n_children == 0: + return + for child in children: + dfs(child) + children_states = [self.get_state(child) for child in children] + n_characters = len(children_states[0]) + state = '' + for character_id in range(n_characters): + states_for_this_character =\ + set([children_states[i][character_id] + for i in range(n_children)]) + if len(states_for_this_character) == 1: + state += states_for_this_character.pop() + else: + state += '0' + self.set_state(v, state) + if v == root: + # Reset state to all zeros! + self.set_state(v, '0' * n_characters) + dfs(root) diff --git a/test/tools_tests/branch_length_estimation_test.py b/test/tools_tests/branch_length_estimation_test.py index e69de29b..71c6137d 100644 --- a/test/tools_tests/branch_length_estimation_test.py +++ b/test/tools_tests/branch_length_estimation_test.py @@ -0,0 +1,340 @@ +import networkx as nx +import numpy as np + +from cassiopeia.tools import PoissonConvexBLE +from cassiopeia.tools.lineage_simulator import lineage_tracing_simulator +from cassiopeia.tools.tree import Tree + + +def estimate_branch_lengths(T): + estimator = PoissonConvexBLE() + return estimator.estimate_branch_lengths(T) + + +def test_no_mutations(): + r""" + Tree topology is just a branch 0->1. 
+ There is one unmutated character i.e.: + root [state = '0'] + | + v + child [state = '0'] + This is thus the simplest possible example of no mutations, and the MLE + branch length should be 0 + """ + T = nx.DiGraph() + T.add_node(0), T.add_node(1) + T.add_edge(0, 1) + T.nodes[0]["characters"] = '0' + T.nodes[1]["characters"] = '0' + T = Tree(T) + log_likelihood = estimate_branch_lengths(T) + np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.0) + np.testing.assert_almost_equal(T.get_age(0), 0.0) + np.testing.assert_almost_equal(T.get_age(1), 0.0) + np.testing.assert_almost_equal(log_likelihood, 0.0) + + +def test_saturation(): + r""" + Tree topology is just a branch 0->1. + There is one mutated character i.e.: + root [state = '0'] + | + v + child [state = '1'] + This is thus the simplest possible example of saturation, and the MLE + branch length should be infinity (>15 for all practical purposes) + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1]) + T.add_edge(0, 1) + T.nodes[0]["characters"] = '0' + T.nodes[1]["characters"] = '1' + T = Tree(T) + log_likelihood = estimate_branch_lengths(T) + assert(T.get_edge_length(0, 1) > 15.0) + assert(T.get_age(0) > 15.0) + np.testing.assert_almost_equal(T.get_age(1), 0.0) + np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=5) + + +def test_hand_solvable_problem_1(): + r""" + Tree topology is just a branch 0->1. + There is one mutated character and one unmutated character, i.e.: + root [state = '00'] + | + v + child [state = '01'] + The solution can be verified by hand. 
The optimization problem is: + max_{r * t0} log(exp(-r * t0)) + log(1 - exp(-r * t0)) + The solution is r * t0 = ln(2) ~ 0.693 + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1]) + T.add_edge(0, 1) + T.nodes[0]["characters"] = '00' + T.nodes[1]["characters"] = '01' + T = Tree(T) + log_likelihood = estimate_branch_lengths(T) + np.testing.assert_almost_equal( + T.get_edge_length(0, 1), np.log(2), decimal=3) + np.testing.assert_almost_equal(T.get_age(0), np.log(2), decimal=3) + np.testing.assert_almost_equal(T.get_age(1), 0.0) + np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) + + +def test_hand_solvable_problem_2(): + r""" + Tree topology is just a branch 0->1. + There are two mutated characters and one unmutated character, i.e.: + root [state = '000'] + | + v + child [state = '011'] + The solution can be verified by hand. The optimization problem is: + max_{r * t0} log(exp(-r * t0)) + 2 * log(1 - exp(-r * t0)) + The solution is r * t0 = ln(3) ~ 1.098 + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1]) + T.add_edge(0, 1) + T.nodes[0]["characters"] = '000' + T.nodes[1]["characters"] = '011' + T = Tree(T) + log_likelihood = estimate_branch_lengths(T) + np.testing.assert_almost_equal( + T.get_edge_length(0, 1), np.log(3), decimal=3) + np.testing.assert_almost_equal(T.get_age(0), np.log(3), decimal=3) + np.testing.assert_almost_equal(T.get_age(1), 0.0) + np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) + + +def test_hand_solvable_problem_3(): + r""" + Tree topology is just a branch 0->1. + There are two unmutated characters and one mutated character, i.e.: + root [state = '000'] + | + v + child [state = '001'] + The solution can be verified by hand. 
The optimization problem is: + max_{r * t0} 2 * log(exp(-r * t0)) + log(1 - exp(-r * t0)) + The solution is r * t0 = ln(1.5) ~ 0.405 + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1]) + T.add_edge(0, 1) + T.nodes[0]["characters"] = '000' + T.nodes[1]["characters"] = '001' + T = Tree(T) + log_likelihood = estimate_branch_lengths(T) + np.testing.assert_almost_equal( + T.get_edge_length(0, 1), np.log(1.5), decimal=3) + np.testing.assert_almost_equal(T.get_age(0), np.log(1.5), decimal=3) + np.testing.assert_almost_equal(T.get_age(1), 0.0) + np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) + + +def test_small_tree_with_no_mutations(): + r""" + Perfect binary tree with no mutations: Should give edges of length 0 + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]) + T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + T.nodes[0]["characters"] = '0000' + T.nodes[1]["characters"] = '0000' + T.nodes[2]["characters"] = '0000' + T.nodes[3]["characters"] = '0000' + T.nodes[4]["characters"] = '0000' + T.nodes[5]["characters"] = '0000' + T.nodes[6]["characters"] = '0000' + T = Tree(T) + log_likelihood = estimate_branch_lengths(T) + for edge in T.edges(): + np.testing.assert_almost_equal(T.get_edge_length(*edge), 0, decimal=3) + np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=3) + + +def test_small_tree_with_one_mutation(): + r""" + Perfect binary tree with one mutation at node 6: Should give very short + edges 1->3,1->4,0->2 and very long edges 0->1,2->5,2->6. 
+ The problem can be solved by hand: it trivially reduces to a 1-dimensional + problem: + max_{r * t0} 2 * log(exp(-r * t0)) + log(1 - exp(-r * t0)) + The solution is r * t0 = ln(1.5) ~ 0.405 + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + T.nodes[0]["characters"] = '0' + T.nodes[1]["characters"] = '0' + T.nodes[2]["characters"] = '0' + T.nodes[3]["characters"] = '0' + T.nodes[4]["characters"] = '0' + T.nodes[5]["characters"] = '0' + T.nodes[6]["characters"] = '1' + T = Tree(T) + log_likelihood = estimate_branch_lengths(T) + np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.405, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.0, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.0, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(1, 4), 0.0, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(2, 5), 0.405, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(2, 6), 0.405, decimal=3) + np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) + + +def test_small_tree_with_saturation(): + r""" + Perfect binary tree with saturation. The edges which saturate should thus + have length infinity (>15 for all practical purposes) + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + T.nodes[0]["characters"] = '0' + T.nodes[1]["characters"] = '0' + T.nodes[2]["characters"] = '1' + T.nodes[3]["characters"] = '1' + T.nodes[4]["characters"] = '1' + T.nodes[5]["characters"] = '1' + T.nodes[6]["characters"] = '1' + T = Tree(T) + _ = estimate_branch_lengths(T) + assert(T.get_edge_length(0, 2) > 15.0) + assert(T.get_edge_length(1, 3) > 15.0) + assert(T.get_edge_length(1, 4) > 15.0) + + +def test_small_tree_regression(): + r""" + Regression test. Cannot be solved by hand. We just check that this solution + never changes. 
+ """ + # Perfect binary tree with normal amount of mutations on each edge + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + T.nodes[0]["characters"] = '000000000' + T.nodes[1]["characters"] = '100000000' + T.nodes[2]["characters"] = '000006000' + T.nodes[3]["characters"] = '120000000' + T.nodes[4]["characters"] = '103000000' + T.nodes[5]["characters"] = '000056700' + T.nodes[6]["characters"] = '000406089' + T = Tree(T) + log_likelihood = estimate_branch_lengths(T) + np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.203, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.082, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.175, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(1, 4), 0.175, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(2, 5), 0.295, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(2, 6), 0.295, decimal=3) + np.testing.assert_almost_equal(log_likelihood, -22.689, decimal=3) + + +def test_small_symmetric_tree(): + r""" + Symmetric tree should have equal length edges. 
+ """ + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + T.nodes[0]["characters"] = '000' + T.nodes[1]["characters"] = '100' + T.nodes[2]["characters"] = '100' + T.nodes[3]["characters"] = '110' + T.nodes[4]["characters"] = '110' + T.nodes[5]["characters"] = '110' + T.nodes[6]["characters"] = '110' + T = Tree(T) + _ = estimate_branch_lengths(T) + np.testing.assert_almost_equal( + T.get_edge_length(0, 1), T.get_edge_length(0, 2)) + np.testing.assert_almost_equal( + T.get_edge_length(1, 3), T.get_edge_length(1, 4)) + np.testing.assert_almost_equal( + T.get_edge_length(1, 4), T.get_edge_length(2, 5)) + np.testing.assert_almost_equal( + T.get_edge_length(2, 5), T.get_edge_length(2, 6)) + + +def test_small_tree_with_infinite_legs(): + r""" + Perfect binary tree with saturated leaves. The first level of the tree + should be normal (can be solved by hand, solution is log(2)), + the branches for the leaves should be infinity (>15 for all practical + purposes) + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + T.nodes[0]["characters"] = '00' + T.nodes[1]["characters"] = '10' + T.nodes[2]["characters"] = '10' + T.nodes[3]["characters"] = '11' + T.nodes[4]["characters"] = '11' + T.nodes[5]["characters"] = '11' + T.nodes[6]["characters"] = '11' + T = Tree(T) + _ = estimate_branch_lengths(T) + np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.693, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.693, decimal=3) + assert(T.get_edge_length(1, 3) > 15) + assert(T.get_edge_length(1, 4) > 15) + assert(T.get_edge_length(2, 5) > 15) + assert(T.get_edge_length(2, 6) > 15) + + +def test_on_simulated_data(): + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + T.nodes[0]["age"] = 1 + T.nodes[1]["age"] = 0.9 + 
T.nodes[2]["age"] = 0.1 + T.nodes[3]["age"] = 0 + T.nodes[4]["age"] = 0 + T.nodes[5]["age"] = 0 + T.nodes[6]["age"] = 0 + np.random.seed(1) + T = Tree(T) + lineage_tracing_simulator(T, mutation_rate=1.0, num_characters=100) + for node in T.nodes(): + T.set_age(node, -1) + estimate_branch_lengths(T) + assert(0.9 < T.get_age(0) < 1.1) + assert(0.8 < T.get_age(1) < 1.0) + assert(0.05 < T.get_age(2) < 0.15) + np.testing.assert_almost_equal(T.get_age(3), 0) + np.testing.assert_almost_equal(T.get_age(4), 0) + np.testing.assert_almost_equal(T.get_age(5), 0) + np.testing.assert_almost_equal(T.get_age(6), 0) + + +def test_subtree_collapses_when_no_mutations(): + r""" + A subtree with no mutations should collapse to 0. It reduces the problem to + the same as in 'test_hand_solvable_problem_1' + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4]), + T.add_edges_from([(0, 1), (1, 2), (1, 3), (0, 4)]) + T.nodes[0]["characters"] = '0' + T.nodes[1]["characters"] = '1' + T.nodes[2]["characters"] = '1' + T.nodes[3]["characters"] = '1' + T.nodes[4]["characters"] = '0' + T = Tree(T) + log_likelihood = estimate_branch_lengths(T) + np.testing.assert_almost_equal( + T.get_edge_length(0, 1), np.log(2), decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(1, 2), 0.0, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.0, decimal=3) + np.testing.assert_almost_equal( + T.get_edge_length(0, 4), np.log(2), decimal=3) + np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) diff --git a/test/tools_tests/lineage_tracing_simulator_test.py b/test/tools_tests/lineage_tracing_simulator_test.py new file mode 100644 index 00000000..c91997a3 --- /dev/null +++ b/test/tools_tests/lineage_tracing_simulator_test.py @@ -0,0 +1,24 @@ +import numpy as np +import networkx as nx + +from cassiopeia.tools.lineage_simulator import lineage_tracing_simulator +from cassiopeia.tools.tree import Tree + + +def test_smoke(): + r""" + Just tests that lineage_tracing_simulator 
runs + """ + np.random.seed(1) + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + T.nodes[0]["age"] = 1 + T.nodes[1]["age"] = 0.9 + T.nodes[2]["age"] = 0.1 + T.nodes[3]["age"] = 0 + T.nodes[4]["age"] = 0 + T.nodes[5]["age"] = 0 + T.nodes[6]["age"] = 0 + T = Tree(T) + lineage_tracing_simulator(T, mutation_rate=1.0, num_characters=10) diff --git a/test/tools_tests/phylogeny_simulator_test.py b/test/tools_tests/phylogeny_simulator_test.py new file mode 100644 index 00000000..7088d90a --- /dev/null +++ b/test/tools_tests/phylogeny_simulator_test.py @@ -0,0 +1,18 @@ +from cassiopeia.tools.phylogeny_simulator import generate_perfect_binary_tree,\ + generate_perfect_binary_tree_with_root_branch + + +def test_generate_perfect_binary_tree_with_fixed_lengths(): + T = generate_perfect_binary_tree( + generation_branch_lengths=[2, 3] + ) + newick = T.to_newick_tree_format(print_internal_nodes=True) + assert(newick == '((3:3,4:3)1:2,(5:3,6:3)2:2)0);') + + +def test_generate_perfect_binary_tree_with_fixed_lengths_with_root_branch(): + T = generate_perfect_binary_tree_with_root_branch( + generation_branch_lengths=[2, 3, 4], + ) + newick = T.to_newick_tree_format(print_internal_nodes=True) + assert(newick == '(((4:4,5:4)2:3,(6:4,7:4)3:3)1:2)0);') diff --git a/test/tools_tests/tree_test.py b/test/tools_tests/tree_test.py new file mode 100644 index 00000000..9f48cca2 --- /dev/null +++ b/test/tools_tests/tree_test.py @@ -0,0 +1,130 @@ +import networkx as nx +from cassiopeia.tools.tree import Tree + + +def test_to_newick_tree_format(): + r""" + Example tree based off https://itol.embl.de/help.cgi#upload . 
+ The most basic newick example should give: + (2:0.5,(4:0.3,5:0.4):0.2):0.1); + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5]) + T.add_edges_from([(0, 1), (1, 2), (1, 3), (3, 4), (3, 5)]) + T = Tree(T) + T.set_edge_lengths( + [(0, 1, 0.1), + (1, 2, 0.5), + (1, 3, 0.2), + (3, 4, 0.3), + (3, 5, 0.4)] + ) + T.set_states( + [(0, '0000000000'), + (1, '1000000000'), + (2, '1111000000'), + (3, '1110000000'), + (4, '1110000111'), + (5, '1110111111')] + ) + res = T.to_newick_tree_format(print_internal_nodes=False) + assert(res == "((2:0.5,(4:0.3,5:0.4):0.2):0.1));") + res = T.to_newick_tree_format( + print_node_names=False, + print_internal_nodes=True, + append_state_to_node_name=True) + assert(res == "((_1111000000:0.5,(_1110000111:0.3,_1110111111:0.4)" + "_1110000000:0.2)_1000000000:0.1)_0000000000);") + res = T.to_newick_tree_format(print_internal_nodes=True) + assert(res == "((2:0.5,(4:0.3,5:0.4)3:0.2)1:0.1)0);") + res = T.to_newick_tree_format(print_node_names=False) + assert(res == "((:0.5,(:0.3,:0.4):0.2):0.1));") + res = T.to_newick_tree_format( + print_internal_nodes=True, + add_N_to_node_id=True) + assert(res == "((N2:0.5,(N4:0.3,N5:0.4)N3:0.2)N1:0.1)N0);") + res = T.to_newick_tree_format( + print_internal_nodes=True, + append_state_to_node_name=True, + add_N_to_node_id=True) + assert(res == "((N2_1111000000:0.5,(N4_1110000111:0.3,N5_1110111111:0.4)" + "N3_1110000000:0.2)N1_1000000000:0.1)N0_0000000000);") + res = T.to_newick_tree_format( + print_internal_nodes=True, + print_pct_of_mutated_characters_along_edge=True, + add_N_to_node_id=True) + assert(res == "((N2:0.5[&&NHX:muts=0.33],(N4:0.3[&&NHX:muts=0.43]," + "N5:0.4[&&NHX:muts=0.86])N3:0.2[&&NHX:muts=0.22])" + "N1:0.1[&&NHX:muts=0.10])N0);") + + +def test_reconstruct_ancestral_states(): + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) + T.add_edges_from([(10, 11), + (11, 13), + (13, 0), (13, 1), + (11, 14), + (14, 2), (14, 3), + (10, 12), + (12, 
15), + (15, 4), (15, 5), + (12, 16), + (16, 6), (16, 7), (16, 8), (16, 9)]) + T = Tree(T) + T.set_states( + [(0, '01101110100'), + (1, '01211111111'), + (2, '01322121111'), + (3, '01432122111'), + (4, '01541232111'), + (5, '01651233111'), + (6, '01763243111'), + (7, '01873240111'), + (8, '01983240111'), + (9, '01093240010'), + ] + ) + T.reconstruct_ancestral_states() + assert(T.get_state(10) == '00000000000') + assert(T.get_state(11) == '01000100100') + assert(T.get_state(13) == '01001110100') + assert(T.get_state(14) == '01002120111') + assert(T.get_state(12) == '01000200010') + assert(T.get_state(15) == '01001230111') + assert(T.get_state(16) == '01003240010') + + +def test_reconstruct_ancestral_states_DREAM_challenge_tree_25(): + T = nx.DiGraph() + T.add_nodes_from(list(range(21))) + T.add_edges_from([(9, 8), (8, 10), (8, 7), (7, 11), (7, 12), (9, 6), + (6, 2), (2, 0), (0, 13), (0, 14), (2, 1), (1, 15), + (1, 16), (6, 5), (5, 3), (3, 17), (3, 18), (5, 4), + (4, 19), (4, 20)]) + T = Tree(T) + T.set_states( + [(10, '0022100000'), + (11, '0022100000'), + (12, '0022100000'), + (13, '2012000220'), + (14, '2012000200'), + (15, '2012000100'), + (16, '2012000100'), + (17, '0001110220'), + (18, '0001110220'), + (19, '0000210220'), + (20, '0000210220'), + ] + ) + T.reconstruct_ancestral_states() + assert(T.get_state(7) == '0022100000') + assert(T.get_state(8) == '0022100000') + assert(T.get_state(0) == '2012000200') + assert(T.get_state(1) == '2012000100') + assert(T.get_state(2) == '2012000000') + assert(T.get_state(3) == '0001110220') + assert(T.get_state(4) == '0000210220') + assert(T.get_state(5) == '0000010220') + assert(T.get_state(6) == '0000000000') + assert(T.get_state(9) == '0000000000') From c52b47466a7de999f27de37541bd6c571b3d33e9 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sat, 26 Dec 2020 18:41:58 -0800 Subject: [PATCH 03/61] Create APIs and refactor --- cassiopeia/tools/__init__.py | 2 - 
cassiopeia/tools/branch_length_estimation.py | 128 ------- cassiopeia/tools/lineage_simulator.py | 112 ++++-- cassiopeia/tools/phylogeny_simulator.py | 56 --- cassiopeia/tools/tree.py | 26 +- .../branch_length_estimation_test.py | 340 ------------------ .../lineage_tracing_simulator_test.py | 6 +- test/tools_tests/phylogeny_simulator_test.py | 18 - 8 files changed, 100 insertions(+), 588 deletions(-) delete mode 100644 cassiopeia/tools/branch_length_estimation.py delete mode 100644 cassiopeia/tools/phylogeny_simulator.py delete mode 100644 test/tools_tests/branch_length_estimation_test.py delete mode 100644 test/tools_tests/phylogeny_simulator_test.py diff --git a/cassiopeia/tools/__init__.py b/cassiopeia/tools/__init__.py index 0e144d30..e69de29b 100644 --- a/cassiopeia/tools/__init__.py +++ b/cassiopeia/tools/__init__.py @@ -1,2 +0,0 @@ -from .branch_length_estimation import PoissonConvexBLE -from .lineage_simulator import lineage_tracing_simulator diff --git a/cassiopeia/tools/branch_length_estimation.py b/cassiopeia/tools/branch_length_estimation.py deleted file mode 100644 index 9cfbd938..00000000 --- a/cassiopeia/tools/branch_length_estimation.py +++ /dev/null @@ -1,128 +0,0 @@ -import abc -import cvxpy as cp -from .tree import Tree - - -class BranchLengthEstimator(abc.ABC): - r""" - Abstract base class for all branch length estimators. - """ - @abc.abstractmethod - def estimate_branch_lengths(self, tree: Tree) -> None: - r""" - Annotates the tree's nodes with their estimated age, and - the tree's branches with their lengths. - Operates on the tree in-place. - - Args: - tree: The tree for which to estimate branch lengths. - """ - - -class PoissonConvexBLE(BranchLengthEstimator): - r""" - A simple branch length estimator that assumes that the characters evolve IID - over the phylogeny with the same cutting rate. - - Maximum Parsinomy is used to impute the ancestral states first. Doing so - leads to a convex optimization problem. 
- """ - def __init__( - self, - minimum_edge_length: float = 0, # TODO: minimum_branch_length? - l2_regularization: float = 0, - verbose: bool = False - ): - self.minimum_edge_length = minimum_edge_length - self.l2_regularization = l2_regularization - self.verbose = verbose - - def estimate_branch_lengths(self, tree: Tree) -> None: - r""" - Estimates branch lengths for T. - - This is in fact an exponential cone program, which is a special kind of - convex problem: - https://docs.mosek.com/modeling-cookbook/expo.html - - Args: - tree: The tree for which to estimate branch lengths. - - Returns: - The log-likelihood under the model for the computed branch lengths. - """ - # Extract parameters - minimum_edge_length = self.minimum_edge_length - l2_regularization = self.l2_regularization - verbose = self.verbose - - # # Wrap the networkx DiGraph for goodies. - # T = Tree(tree) - T = tree - - # # # # # Create variables of the optimization problem # # # # # - r_X_t_variables = dict([(node_id, cp.Variable(name=f'r_X_t_{node_id}')) - for node_id in T.nodes()]) - time_increases_constraints = [ - r_X_t_variables[parent] - >= r_X_t_variables[child] + minimum_edge_length - for (parent, child) in T.edges() - ] - leaves_have_age_0_constraints =\ - [r_X_t_variables[leaf] == 0 for leaf in T.leaves()] - non_negative_r_X_t_constraints =\ - [r_X_t >= 0 for r_X_t in r_X_t_variables.values()] - all_constraints =\ - time_increases_constraints + \ - leaves_have_age_0_constraints + \ - non_negative_r_X_t_constraints - - # # # # # Compute the log-likelihood # # # # # - log_likelihood = 0 - - # Because all rates are equal, the number of cuts in each node is a - # sufficient statistic. This makes the solver WAY faster! - for (parent, child) in T.edges(): - edge_length = r_X_t_variables[parent] - r_X_t_variables[child] - zeros_parent = T.get_state(parent).count('0') # TODO: '0'... - zeros_child = T.get_state(child).count('0') # TODO: '0'... 
- new_cuts_child = zeros_parent - zeros_child - assert(new_cuts_child >= 0) - # Add log-lik for characters that didn't get cut - log_likelihood += zeros_child * (-edge_length) - # Add log-lik for characters that got cut - log_likelihood += new_cuts_child * cp.log(1 - cp.exp(-edge_length)) - - # # # # # Add regularization # # # # # - - l2_penalty = 0 - for (parent, child) in T.edges(): - for child_of_child in T.children(child): - edge_length_above =\ - r_X_t_variables[parent] - r_X_t_variables[child] - edge_length_below =\ - r_X_t_variables[child] - r_X_t_variables[child_of_child] - l2_penalty += (edge_length_above - edge_length_below) ** 2 - l2_penalty *= l2_regularization - - # # # # # Solve the problem # # # # # - - obj = cp.Maximize(log_likelihood - l2_penalty) - prob = cp.Problem(obj, all_constraints) - - f_star = prob.solve(solver='ECOS', verbose=verbose) - - # # # # # Populate the tree with the estimated branch lengths # # # # # - - for node in T.nodes(): - T.set_age(node, age=r_X_t_variables[node].value) - - for (parent, child) in T.edges(): - new_edge_length =\ - r_X_t_variables[parent].value - r_X_t_variables[child].value - T.set_edge_length( - parent, - child, - length=new_edge_length) - - return f_star diff --git a/cassiopeia/tools/lineage_simulator.py b/cassiopeia/tools/lineage_simulator.py index 84cd37c2..73958d6b 100644 --- a/cassiopeia/tools/lineage_simulator.py +++ b/cassiopeia/tools/lineage_simulator.py @@ -1,39 +1,83 @@ -import numpy as np +import abc from .tree import Tree +import networkx as nx +from typing import List -def lineage_tracing_simulator( - T: Tree, - mutation_rate: float, - num_characters: float -) -> None: + +class LineageSimulator(abc.ABC): r""" - Populates the phylogenetic tree T with lineage tracing characters. + Abstract base class for lineage simulators. 
""" - def dfs(node: int, T: Tree): - node_state = T.get_state(node) - for child in T.children(node): - # Compute the state of the child - child_state = '' - edge_length = T.get_age(node) - T.get_age(child) - # print(f"{node} -> {child}, length {edge_length}") - assert(edge_length >= 0) - for i in range(num_characters): - # See what happens to character i - if node_state[i] != '0': - # The character has already mutated; there in nothing to do - child_state += node_state[i] - continue - else: - # Determine if the character will mutate. - mutates =\ - np.random.exponential(1.0 / mutation_rate) < edge_length - if mutates: - child_state += '1' - else: - child_state += '0' - T.set_state(child, child_state) - dfs(child, T) - root = T.root() - T.set_state(root, '0' * num_characters) - dfs(root, T) + @abc.abstractmethod + def simulate_lineage(self) -> Tree: + r"""Simulates a ground truth lineage""" + + +class PerfectBinaryTree(LineageSimulator): + def __init__( + self, + generation_branch_lengths: List[float] + ): + self.generation_branch_lengths = generation_branch_lengths[:] + + def simulate_lineage(self) -> Tree: + r""" + See test for doc. 
+ """ + generation_branch_lengths = self.generation_branch_lengths + n_generations = len(generation_branch_lengths) + T = nx.DiGraph() + T.add_nodes_from(range(2 ** (n_generations + 1) - 1)) + edges = [(int((child - 1) / 2), child) + for child in range(1, 2 ** (n_generations + 1) - 1)] + node_generation = [] + for i in range(n_generations + 1): + node_generation += [i] * 2 ** i + T.add_edges_from(edges) + for (parent, child) in edges: + parent_generation = node_generation[parent] + branch_length = generation_branch_lengths[parent_generation] + T.edges[parent, child]["length"] = branch_length + T.nodes[0]["age"] = sum(generation_branch_lengths) + for child in range(1, 2 ** (n_generations + 1) - 1): + child_generation = node_generation[child] + branch_length = generation_branch_lengths[child_generation - 1] + T.nodes[child]["age"] =\ + T.nodes[int((child - 1) / 2)]["age"] - branch_length + return Tree(T) + + +class PerfectBinaryTreeWithRootBranch(LineageSimulator): + def __init__( + self, + generation_branch_lengths: List[float] + ): + self.generation_branch_lengths = generation_branch_lengths + + def simulate_lineage(self) -> Tree: + r""" + See test for doc. 
+ """ + # generation_branch_lengths = self.generation_branch_lengths + generation_branch_lengths = self.generation_branch_lengths + n_generations = len(generation_branch_lengths) + T = nx.DiGraph() + T.add_nodes_from(range(2 ** n_generations)) + edges = [(int(child / 2), child) + for child in range(1, 2 ** n_generations)] + T.add_edges_from(edges) + node_generation = [0] + for i in range(n_generations): + node_generation += [i + 1] * 2 ** i + for (parent, child) in edges: + parent_generation = node_generation[parent] + branch_length = generation_branch_lengths[parent_generation] + T.edges[parent, child]["length"] = branch_length + T.nodes[0]["age"] = sum(generation_branch_lengths) + for child in range(1, 2 ** n_generations): + child_generation = node_generation[child] + branch_length = generation_branch_lengths[child_generation - 1] + T.nodes[child]["age"] =\ + T.nodes[int(child / 2)]["age"] - branch_length + return Tree(T) diff --git a/cassiopeia/tools/phylogeny_simulator.py b/cassiopeia/tools/phylogeny_simulator.py deleted file mode 100644 index b9c6f438..00000000 --- a/cassiopeia/tools/phylogeny_simulator.py +++ /dev/null @@ -1,56 +0,0 @@ -from .tree import Tree -import networkx as nx - -from typing import List - - -def generate_perfect_binary_tree( - generation_branch_lengths: List[float] -) -> Tree: - r""" - See test for doc. 
- """ - n_generations = len(generation_branch_lengths) - T = nx.DiGraph() - T.add_nodes_from(range(2 ** (n_generations + 1) - 1)) - edges = [(int((child - 1) / 2), child) for child in range(1, 2 ** (n_generations + 1) - 1)] - node_generation = [] - for i in range(n_generations + 1): - node_generation += [i] * 2 ** i - T.add_edges_from(edges) - for (parent, child) in edges: - parent_generation = node_generation[parent] - branch_length = generation_branch_lengths[parent_generation] - T.edges[parent, child]["length"] = branch_length - T.nodes[0]["age"] = sum(generation_branch_lengths) - for child in range(1, 2 ** (n_generations + 1) - 1): - child_generation = node_generation[child] - branch_length = generation_branch_lengths[child_generation - 1] - T.nodes[child]["age"] = T.nodes[int((child - 1) / 2)]["age"] - branch_length - return Tree(T) - - -def generate_perfect_binary_tree_with_root_branch( - generation_branch_lengths: List[float] -) -> Tree: - r""" - See test for doc. - """ - n_generations = len(generation_branch_lengths) - T = nx.DiGraph() - T.add_nodes_from(range(2 ** n_generations)) - edges = [(int(child / 2), child) for child in range(1, 2 ** n_generations)] - T.add_edges_from(edges) - node_generation = [0] - for i in range(n_generations): - node_generation += [i + 1] * 2 ** i - for (parent, child) in edges: - parent_generation = node_generation[parent] - branch_length = generation_branch_lengths[parent_generation] - T.edges[parent, child]["length"] = branch_length - T.nodes[0]["age"] = sum(generation_branch_lengths) - for child in range(1, 2 ** n_generations): - child_generation = node_generation[child] - branch_length = generation_branch_lengths[child_generation - 1] - T.nodes[child]["age"] = T.nodes[int(child / 2)]["age"] - branch_length - return Tree(T) diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index dc876d90..8b6de5b1 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -17,11 +17,14 @@ def 
upload_to_itol_and_export_figure( verbose: bool = True ): r""" + TODO: Change doc to Google style TODO: Can use 'metadata_source' to select what data to show!!! For all export parameters see: https://itol.embl.de/help.cgi#export :param tree_name_in_iTOL: The name of the uploaded tree in iTOL - :param figure_file: Name of the file where the tree figure will be exported to. - :param plot_grid_line_scale: If provided, the distance between lines on the grid. + :param figure_file: Name of the file where the tree figure will be exported + to. + :param plot_grid_line_scale: If provided, the distance between lines on the + grid. :param verbose: Verbosity """ # Write out newick tree to file @@ -51,11 +54,14 @@ def upload_to_itol_and_export_figure( "format", figure_file.split(".")[-1] ) # ['png', 'svg', 'eps', 'ps', 'pdf', 'nexus', 'newick'] itol_exporter.set_export_param_value("display_mode", 1) # rectangular tree - itol_exporter.set_export_param_value("label_display", 1) # Possible values: 0 or 1 (0=hide labels, 1=show labels) + # Possible values: 0 or 1 (0=hide labels, 1=show labels) + itol_exporter.set_export_param_value("label_display", 1) if plot_grid_line_scale is not None: itol_exporter.set_export_param_value("internal_scale", 1) - itol_exporter.set_export_param_value("internalScale1", plot_grid_line_scale) - itol_exporter.set_export_param_value("internalScale2", plot_grid_line_scale) + itol_exporter.set_export_param_value( + "internalScale1", plot_grid_line_scale) + itol_exporter.set_export_param_value( + "internalScale2", plot_grid_line_scale) itol_exporter.set_export_param_value("internalScale1Dashed", 1) itol_exporter.set_export_param_value("internalScale2Dashed", 1) else: @@ -79,7 +85,8 @@ def upload_to_itol_and_export_figure( # itol_exporter.set_export_param_value("internalScale1Label", 1) # itol_exporter.set_export_param_value("newick_format", "ID") # itol_exporter.set_export_param_value("internalScale1Label", 1) - 
itol_exporter.set_export_param_value("leaf_sorting", 1) # Possible values: 1 or 2 (1=normal sorting, 2=no sorting) + # Possible values: 1 or 2 (1=normal sorting, 2=no sorting) + itol_exporter.set_export_param_value("leaf_sorting", 1) print(f"Exporting tree to {figure_file}") itol_exporter.export(figure_file) @@ -93,12 +100,14 @@ def newick_to_network( f=1 ): """ + TODO: Am I even using this? Given a file path to a newick file, convert to a directed graph. :param newick_filepath: File path to a newick text file :param f: - Parameter to be passed to Ete3 while reading in the newick file. (Default 1) + Parameter to be passed to Ete3 while reading in the newick file. + (Default 1) :return: a networkx file of the tree """ @@ -106,7 +115,8 @@ def newick_to_network( tree = ETEtree(newick_filepath, format=f) # Create dict from ete3 node to cassiopeia.Node - # NOTE(sprillo): Instead of mapping to a Cassiopeia node, we'll map to a string (just the node name) + # NOTE(sprillo): Instead of mapping to a Cassiopeia node, we'll map to + # a string (just the node name) e2cass = {} edge_lengths = {} internal_node_id = 0 diff --git a/test/tools_tests/branch_length_estimation_test.py b/test/tools_tests/branch_length_estimation_test.py deleted file mode 100644 index 71c6137d..00000000 --- a/test/tools_tests/branch_length_estimation_test.py +++ /dev/null @@ -1,340 +0,0 @@ -import networkx as nx -import numpy as np - -from cassiopeia.tools import PoissonConvexBLE -from cassiopeia.tools.lineage_simulator import lineage_tracing_simulator -from cassiopeia.tools.tree import Tree - - -def estimate_branch_lengths(T): - estimator = PoissonConvexBLE() - return estimator.estimate_branch_lengths(T) - - -def test_no_mutations(): - r""" - Tree topology is just a branch 0->1. 
- There is one unmutated character i.e.: - root [state = '0'] - | - v - child [state = '0'] - This is thus the simplest possible example of no mutations, and the MLE - branch length should be 0 - """ - T = nx.DiGraph() - T.add_node(0), T.add_node(1) - T.add_edge(0, 1) - T.nodes[0]["characters"] = '0' - T.nodes[1]["characters"] = '0' - T = Tree(T) - log_likelihood = estimate_branch_lengths(T) - np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.0) - np.testing.assert_almost_equal(T.get_age(0), 0.0) - np.testing.assert_almost_equal(T.get_age(1), 0.0) - np.testing.assert_almost_equal(log_likelihood, 0.0) - - -def test_saturation(): - r""" - Tree topology is just a branch 0->1. - There is one mutated character i.e.: - root [state = '0'] - | - v - child [state = '1'] - This is thus the simplest possible example of saturation, and the MLE - branch length should be infinity (>15 for all practical purposes) - """ - T = nx.DiGraph() - T.add_nodes_from([0, 1]) - T.add_edge(0, 1) - T.nodes[0]["characters"] = '0' - T.nodes[1]["characters"] = '1' - T = Tree(T) - log_likelihood = estimate_branch_lengths(T) - assert(T.get_edge_length(0, 1) > 15.0) - assert(T.get_age(0) > 15.0) - np.testing.assert_almost_equal(T.get_age(1), 0.0) - np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=5) - - -def test_hand_solvable_problem_1(): - r""" - Tree topology is just a branch 0->1. - There is one mutated character and one unmutated character, i.e.: - root [state = '00'] - | - v - child [state = '01'] - The solution can be verified by hand. 
The optimization problem is: - min_{r * t0} log(exp(-r * t0)) + log(1 - exp(-r * t0)) - The solution is r * t0 = ln(2) ~ 0.693 - """ - T = nx.DiGraph() - T.add_nodes_from([0, 1]) - T.add_edge(0, 1) - T.nodes[0]["characters"] = '00' - T.nodes[1]["characters"] = '01' - T = Tree(T) - log_likelihood = estimate_branch_lengths(T) - np.testing.assert_almost_equal( - T.get_edge_length(0, 1), np.log(2), decimal=3) - np.testing.assert_almost_equal(T.get_age(0), np.log(2), decimal=3) - np.testing.assert_almost_equal(T.get_age(1), 0.0) - np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) - - -def test_hand_solvable_problem_2(): - r""" - Tree topology is just a branch 0->1. - There are two mutated characters and one unmutated character, i.e.: - root [state = '000'] - | - v - child [state = '011'] - The solution can be verified by hand. The optimization problem is: - min_{r * t0} log(exp(-r * t0)) + 2 * log(1 - exp(-r * t0)) - The solution is r * t0 = ln(3) ~ 1.098 - """ - T = nx.DiGraph() - T.add_nodes_from([0, 1]) - T.add_edge(0, 1) - T.nodes[0]["characters"] = '000' - T.nodes[1]["characters"] = '011' - T = Tree(T) - log_likelihood = estimate_branch_lengths(T) - np.testing.assert_almost_equal( - T.get_edge_length(0, 1), np.log(3), decimal=3) - np.testing.assert_almost_equal(T.get_age(0), np.log(3), decimal=3) - np.testing.assert_almost_equal(T.get_age(1), 0.0) - np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) - - -def test_hand_solvable_problem_3(): - r""" - Tree topology is just a branch 0->1. - There are two unmutated characters and one mutated character, i.e.: - root [state = '000'] - | - v - child [state = '001'] - The solution can be verified by hand. 
The optimization problem is: - min_{r * t0} 2 * log(exp(-r * t0)) + log(1 - exp(-r * t0)) - The solution is r * t0 = ln(1.5) ~ 0.405 - """ - T = nx.DiGraph() - T.add_nodes_from([0, 1]) - T.add_edge(0, 1) - T.nodes[0]["characters"] = '000' - T.nodes[1]["characters"] = '001' - T = Tree(T) - log_likelihood = estimate_branch_lengths(T) - np.testing.assert_almost_equal( - T.get_edge_length(0, 1), np.log(1.5), decimal=3) - np.testing.assert_almost_equal(T.get_age(0), np.log(1.5), decimal=3) - np.testing.assert_almost_equal(T.get_age(1), 0.0) - np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) - - -def test_small_tree_with_no_mutations(): - r""" - Perfect binary tree with no mutations: Should give edges of length 0 - """ - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]) - T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - T.nodes[0]["characters"] = '0000' - T.nodes[1]["characters"] = '0000' - T.nodes[2]["characters"] = '0000' - T.nodes[3]["characters"] = '0000' - T.nodes[4]["characters"] = '0000' - T.nodes[5]["characters"] = '0000' - T.nodes[6]["characters"] = '0000' - T = Tree(T) - log_likelihood = estimate_branch_lengths(T) - for edge in T.edges(): - np.testing.assert_almost_equal(T.get_edge_length(*edge), 0, decimal=3) - np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=3) - - -def test_small_tree_with_one_mutation(): - r""" - Perfect binary tree with one mutation at a node 6: Should give very short - edges 1->3,1->4,0->2 and very long edges 0->1,2->5,2->6. 
- The problem can be solved by hand: it trivially reduces to a 1-dimensional - problem: - min_{r * t0} 2 * log(exp(-r * t0)) + log(1 - exp(-r * t0)) - The solution is r * t0 = ln(1.5) ~ 0.405 - """ - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - T.nodes[0]["characters"] = '0' - T.nodes[1]["characters"] = '0' - T.nodes[2]["characters"] = '0' - T.nodes[3]["characters"] = '0' - T.nodes[4]["characters"] = '0' - T.nodes[5]["characters"] = '0' - T.nodes[6]["characters"] = '1' - T = Tree(T) - log_likelihood = estimate_branch_lengths(T) - np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.405, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.0, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.0, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(1, 4), 0.0, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(2, 5), 0.405, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(2, 6), 0.405, decimal=3) - np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) - - -def test_small_tree_with_saturation(): - r""" - Perfect binary tree with saturation. The edges which saturate should thus - have length infinity (>15 for all practical purposes) - """ - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - T.nodes[0]["characters"] = '0' - T.nodes[1]["characters"] = '0' - T.nodes[2]["characters"] = '1' - T.nodes[3]["characters"] = '1' - T.nodes[4]["characters"] = '1' - T.nodes[5]["characters"] = '1' - T.nodes[6]["characters"] = '1' - T = Tree(T) - _ = estimate_branch_lengths(T) - assert(T.get_edge_length(0, 2) > 15.0) - assert(T.get_edge_length(1, 3) > 15.0) - assert(T.get_edge_length(1, 4) > 15.0) - - -def test_small_tree_regression(): - r""" - Regression test. Cannot be solved by hand. We just check that this solution - never changes. 
- """ - # Perfect binary tree with normal amount of mutations on each edge - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - T.nodes[0]["characters"] = '000000000' - T.nodes[1]["characters"] = '100000000' - T.nodes[2]["characters"] = '000006000' - T.nodes[3]["characters"] = '120000000' - T.nodes[4]["characters"] = '103000000' - T.nodes[5]["characters"] = '000056700' - T.nodes[6]["characters"] = '000406089' - T = Tree(T) - log_likelihood = estimate_branch_lengths(T) - np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.203, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.082, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.175, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(1, 4), 0.175, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(2, 5), 0.295, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(2, 6), 0.295, decimal=3) - np.testing.assert_almost_equal(log_likelihood, -22.689, decimal=3) - - -def test_small_symmetric_tree(): - r""" - Symmetric tree should have equal length edges. 
- """ - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - T.nodes[0]["characters"] = '000' - T.nodes[1]["characters"] = '100' - T.nodes[2]["characters"] = '100' - T.nodes[3]["characters"] = '110' - T.nodes[4]["characters"] = '110' - T.nodes[5]["characters"] = '110' - T.nodes[6]["characters"] = '110' - T = Tree(T) - _ = estimate_branch_lengths(T) - np.testing.assert_almost_equal( - T.get_edge_length(0, 1), T.get_edge_length(0, 2)) - np.testing.assert_almost_equal( - T.get_edge_length(1, 3), T.get_edge_length(1, 4)) - np.testing.assert_almost_equal( - T.get_edge_length(1, 4), T.get_edge_length(2, 5)) - np.testing.assert_almost_equal( - T.get_edge_length(2, 5), T.get_edge_length(2, 6)) - - -def test_small_tree_with_infinite_legs(): - r""" - Perfect binary tree with saturated leaves. The first level of the tree - should be normal (can be solved by hand, solution is log(2)), - the branches for the leaves should be infinity (>15 for all practical - purposes) - """ - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - T.nodes[0]["characters"] = '00' - T.nodes[1]["characters"] = '10' - T.nodes[2]["characters"] = '10' - T.nodes[3]["characters"] = '11' - T.nodes[4]["characters"] = '11' - T.nodes[5]["characters"] = '11' - T.nodes[6]["characters"] = '11' - T = Tree(T) - _ = estimate_branch_lengths(T) - np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.693, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.693, decimal=3) - assert(T.get_edge_length(1, 3) > 15) - assert(T.get_edge_length(1, 4) > 15) - assert(T.get_edge_length(2, 5) > 15) - assert(T.get_edge_length(2, 6) > 15) - - -def test_on_simulated_data(): - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - T.nodes[0]["age"] = 1 - T.nodes[1]["age"] = 0.9 - 
T.nodes[2]["age"] = 0.1 - T.nodes[3]["age"] = 0 - T.nodes[4]["age"] = 0 - T.nodes[5]["age"] = 0 - T.nodes[6]["age"] = 0 - np.random.seed(1) - T = Tree(T) - lineage_tracing_simulator(T, mutation_rate=1.0, num_characters=100) - for node in T.nodes(): - T.set_age(node, -1) - estimate_branch_lengths(T) - assert(0.9 < T.get_age(0) < 1.1) - assert(0.8 < T.get_age(1) < 1.0) - assert(0.05 < T.get_age(2) < 0.15) - np.testing.assert_almost_equal(T.get_age(3), 0) - np.testing.assert_almost_equal(T.get_age(4), 0) - np.testing.assert_almost_equal(T.get_age(5), 0) - np.testing.assert_almost_equal(T.get_age(6), 0) - - -def test_subtree_collapses_when_no_mutations(): - r""" - A subtree with no mutations should collapse to 0. It reduces the problem to - the same as in 'test_hand_solvable_problem_1' - """ - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4]), - T.add_edges_from([(0, 1), (1, 2), (1, 3), (0, 4)]) - T.nodes[0]["characters"] = '0' - T.nodes[1]["characters"] = '1' - T.nodes[2]["characters"] = '1' - T.nodes[3]["characters"] = '1' - T.nodes[4]["characters"] = '0' - T = Tree(T) - log_likelihood = estimate_branch_lengths(T) - np.testing.assert_almost_equal( - T.get_edge_length(0, 1), np.log(2), decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(1, 2), 0.0, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.0, decimal=3) - np.testing.assert_almost_equal( - T.get_edge_length(0, 4), np.log(2), decimal=3) - np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) diff --git a/test/tools_tests/lineage_tracing_simulator_test.py b/test/tools_tests/lineage_tracing_simulator_test.py index c91997a3..eec38849 100644 --- a/test/tools_tests/lineage_tracing_simulator_test.py +++ b/test/tools_tests/lineage_tracing_simulator_test.py @@ -1,7 +1,8 @@ import numpy as np import networkx as nx -from cassiopeia.tools.lineage_simulator import lineage_tracing_simulator +from cassiopeia.tools.lineage_tracing_simulator import\ + IIDExponentialLineageTracer 
from cassiopeia.tools.tree import Tree @@ -21,4 +22,5 @@ def test_smoke(): T.nodes[5]["age"] = 0 T.nodes[6]["age"] = 0 T = Tree(T) - lineage_tracing_simulator(T, mutation_rate=1.0, num_characters=10) + IIDExponentialLineageTracer(mutation_rate=1.0, num_characters=10)\ + .overlay_lineage_tracing_data(T) diff --git a/test/tools_tests/phylogeny_simulator_test.py b/test/tools_tests/phylogeny_simulator_test.py deleted file mode 100644 index 7088d90a..00000000 --- a/test/tools_tests/phylogeny_simulator_test.py +++ /dev/null @@ -1,18 +0,0 @@ -from cassiopeia.tools.phylogeny_simulator import generate_perfect_binary_tree,\ - generate_perfect_binary_tree_with_root_branch - - -def test_generate_perfect_binary_tree_with_fixed_lengths(): - T = generate_perfect_binary_tree( - generation_branch_lengths=[2, 3] - ) - newick = T.to_newick_tree_format(print_internal_nodes=True) - assert(newick == '((3:3,4:3)1:2,(5:3,6:3)2:2)0);') - - -def test_generate_perfect_binary_tree_with_fixed_lengths_with_root_branch(): - T = generate_perfect_binary_tree_with_root_branch( - generation_branch_lengths=[2, 3, 4], - ) - newick = T.to_newick_tree_format(print_internal_nodes=True) - assert(newick == '(((4:4,5:4)2:3,(6:4,7:4)3:3)1:2)0);') From dfb7321191203e5e3c35b42211e850d6cc587d29 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sat, 26 Dec 2020 18:51:06 -0800 Subject: [PATCH 04/61] Move plotting code out of tree.py --- cassiopeia/tools/lineage_simulator.py | 4 +- cassiopeia/tools/tree.py | 144 +----------------- .../lineage_tracing_simulator_test.py | 2 +- test/tools_tests/tree_test.py | 1 + 4 files changed, 5 insertions(+), 146 deletions(-) diff --git a/cassiopeia/tools/lineage_simulator.py b/cassiopeia/tools/lineage_simulator.py index 73958d6b..840e73a6 100644 --- a/cassiopeia/tools/lineage_simulator.py +++ b/cassiopeia/tools/lineage_simulator.py @@ -1,9 +1,9 @@ import abc -from .tree import Tree import networkx as nx - from typing import List 
+from .tree import Tree + class LineageSimulator(abc.ABC): r""" diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index 8b6de5b1..e0330982 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -1,147 +1,5 @@ -from ete3 import Tree as ETEtree -from itolapi import Itol, ItolExport import networkx as nx - -from typing import List, Optional, Tuple - - -def upload_to_itol_and_export_figure( - newick_tree: str, - apiKey: str, - projectName: str, - tree_name_in_iTOL: str, - figure_file: str, - plot_grid_line_scale: Optional[float] = None, - horizontal_scale_factor: Optional[int] = None, - bootstrap_display: bool = True, - verbose: bool = True -): - r""" - TODO: Change doc to Google style - TODO: Can use 'metadata_source' to select what data to show!!! - For all export parameters see: https://itol.embl.de/help.cgi#export - :param tree_name_in_iTOL: The name of the uploaded tree in iTOL - :param figure_file: Name of the file where the tree figure will be exported - to. - :param plot_grid_line_scale: If provided, the distance between lines on the - grid. - :param verbose: Verbosity - """ - # Write out newick tree to file - tree_to_plot_file = "tree_to_plot.tree" - with open(tree_to_plot_file, "w") as file: - file.write(newick_tree) - - # Upload newick tree - itol_uploader = Itol() - itol_uploader.add_file("tree_to_plot.tree") - itol_uploader.params["treeName"] = tree_name_in_iTOL - itol_uploader.params["APIkey"] = apiKey - itol_uploader.params["projectName"] = projectName - good_upload = itol_uploader.upload() - if not good_upload: - print("There was an error:" + itol_uploader.comm.upload_output) - if verbose: - print("iTOL output: " + str(itol_uploader.comm.upload_output)) - print("Tree Web Page URL: " + itol_uploader.get_webpage()) - print("Warnings: " + str(itol_uploader.comm.warnings)) - tree_id = itol_uploader.comm.tree_id - - # Export tree. 
See https://itol.embl.de/help.cgi#export for all parameters - itol_exporter = ItolExport() - itol_exporter.set_export_param_value("tree", tree_id) - itol_exporter.set_export_param_value( - "format", figure_file.split(".")[-1] - ) # ['png', 'svg', 'eps', 'ps', 'pdf', 'nexus', 'newick'] - itol_exporter.set_export_param_value("display_mode", 1) # rectangular tree - # Possible values: 0 or 1 (0=hide labels, 1=show labels) - itol_exporter.set_export_param_value("label_display", 1) - if plot_grid_line_scale is not None: - itol_exporter.set_export_param_value("internal_scale", 1) - itol_exporter.set_export_param_value( - "internalScale1", plot_grid_line_scale) - itol_exporter.set_export_param_value( - "internalScale2", plot_grid_line_scale) - itol_exporter.set_export_param_value("internalScale1Dashed", 1) - itol_exporter.set_export_param_value("internalScale2Dashed", 1) - else: - itol_exporter.set_export_param_value("tree_scale", 0) - if horizontal_scale_factor is not None: - itol_exporter.set_export_param_value( - "horizontal_scale_factor", horizontal_scale_factor - ) # doesnt actually scale the artboard - if bootstrap_display: - itol_exporter.set_export_param_value("bootstrap_display", 1) - itol_exporter.set_export_param_value("bootstrap_type", 2) - itol_exporter.set_export_param_value("bootstrap_label_size", 18) - # itol_exporter.set_export_param_value("bootstrap_label_position", 20) - # itol_exporter.set_export_param_value("bootstrap_symbol_position", 20) - # itol_exporter.set_export_param_value("bootstrap_label_sci", 1) - # itol_exporter.set_export_param_value("bootstrap_slider_min", -1) - # itol_exporter.set_export_param_value("bootstrap_symbol_position", 0) - # itol_exporter.set_export_param_value("branchlength_display", 1) - # itol_exporter.set_export_param_value("branchlength_label_rounding", 1) - # itol_exporter.set_export_param_value("branchlength_label_age", 1) - # itol_exporter.set_export_param_value("internalScale1Label", 1) - # 
itol_exporter.set_export_param_value("newick_format", "ID") - # itol_exporter.set_export_param_value("internalScale1Label", 1) - # Possible values: 1 or 2 (1=normal sorting, 2=no sorting) - itol_exporter.set_export_param_value("leaf_sorting", 1) - print(f"Exporting tree to {figure_file}") - itol_exporter.export(figure_file) - - # Cleanup - # os.remove("tree_to_plot.tree") - - -def create_networkx_DiGraph_from_newick_file(file_path: str) -> nx.DiGraph: - def newick_to_network( - newick_filepath, - f=1 - ): - """ - TODO: Am I even using this? - Given a file path to a newick file, convert to a directed graph. - - :param newick_filepath: - File path to a newick text file - :param f: - Parameter to be passed to Ete3 while reading in the newick file. - (Default 1) - :return: a networkx file of the tree - """ - - G = nx.DiGraph() # the new graph - tree = ETEtree(newick_filepath, format=f) - - # Create dict from ete3 node to cassiopeia.Node - # NOTE(sprillo): Instead of mapping to a Cassiopeia node, we'll map to - # a string (just the node name) - e2cass = {} - edge_lengths = {} - internal_node_id = 0 - for n in tree.traverse("postorder"): - node_name = '' - if n.name == '': - # print(f"Node without name, is internal.") - node_name = 'state-node-' + str(internal_node_id) - internal_node_id += 1 - else: - node_name = n.name - e2cass[n] = node_name - G.add_node(node_name) - edge_lengths[node_name] = n._dist - - for p in tree.traverse("postorder"): - pn = e2cass[p] - for c in p.children: - cn = e2cass[c] - G.add_edge(pn, cn) - G.edges[pn, cn]["length"] = edge_lengths[cn] - return G - - T = newick_to_network(file_path) - return T +from typing import List, Tuple class Tree(): diff --git a/test/tools_tests/lineage_tracing_simulator_test.py b/test/tools_tests/lineage_tracing_simulator_test.py index eec38849..122f8710 100644 --- a/test/tools_tests/lineage_tracing_simulator_test.py +++ b/test/tools_tests/lineage_tracing_simulator_test.py @@ -1,5 +1,5 @@ -import numpy as np import 
networkx as nx +import numpy as np from cassiopeia.tools.lineage_tracing_simulator import\ IIDExponentialLineageTracer diff --git a/test/tools_tests/tree_test.py b/test/tools_tests/tree_test.py index 9f48cca2..03d43113 100644 --- a/test/tools_tests/tree_test.py +++ b/test/tools_tests/tree_test.py @@ -1,4 +1,5 @@ import networkx as nx + from cassiopeia.tools.tree import Tree From 55a0bec26ce27ccd991731e5fd3cd937867d6378 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sat, 26 Dec 2020 19:36:07 -0800 Subject: [PATCH 05/61] Forgot to add files --- cassiopeia/tools/branch_length_estimator.py | 138 +++++++ cassiopeia/tools/lineage_tracing_simulator.py | 66 ++++ .../branch_length_estimator_test.py | 337 ++++++++++++++++++ test/tools_tests/lineage_simulator_test.py | 15 + 4 files changed, 556 insertions(+) create mode 100644 cassiopeia/tools/branch_length_estimator.py create mode 100644 cassiopeia/tools/lineage_tracing_simulator.py create mode 100644 test/tools_tests/branch_length_estimator_test.py create mode 100644 test/tools_tests/lineage_simulator_test.py diff --git a/cassiopeia/tools/branch_length_estimator.py b/cassiopeia/tools/branch_length_estimator.py new file mode 100644 index 00000000..8a477ca3 --- /dev/null +++ b/cassiopeia/tools/branch_length_estimator.py @@ -0,0 +1,138 @@ +import abc +import cvxpy as cp + +from .tree import Tree + + +class BranchLengthEstimator(abc.ABC): + r""" + Abstract base class for all branch length estimators. + """ + @abc.abstractmethod + def estimate_branch_lengths(self, tree: Tree) -> None: + r""" + Annotates the tree's nodes with their estimated age, and + the tree's branches with their lengths. + Operates on the tree in-place. + + Args: + tree: The tree for which to estimate branch lengths. + """ + + +class PoissonConvexBLE(BranchLengthEstimator): + r""" + A simple branch length estimator that assumes that the characters evolve IID + over the phylogeny with the same cutting rate. 
+ + Maximum Parsinomy is used to impute the ancestral states first. Doing so + leads to a convex optimization problem. + """ + def __init__( + self, + minimum_edge_length: float = 0, # TODO: minimum_branch_length? + l2_regularization: float = 0, + verbose: bool = False + ): + self.minimum_edge_length = minimum_edge_length + self.l2_regularization = l2_regularization + self.verbose = verbose + + def estimate_branch_lengths(self, tree: Tree) -> float: + r""" + TODO: This shouldn't return the log-likelihood according to the API. + What should we do about this? Maybe let's look at sklearn? + + Estimates branch lengths for the given tree. + + This is in fact an exponential cone program, which is a special kind of + convex problem: + https://docs.mosek.com/modeling-cookbook/expo.html + + Args: + tree: The tree for which to estimate branch lengths. + + Returns: + The log-likelihood under the model for the computed branch lengths. + """ + # Extract parameters + minimum_edge_length = self.minimum_edge_length + l2_regularization = self.l2_regularization + verbose = self.verbose + + # # Wrap the networkx DiGraph for goodies. + # T = Tree(tree) + T = tree + + # # # # # Create variables of the optimization problem # # # # # + r_X_t_variables = dict([(node_id, cp.Variable(name=f'r_X_t_{node_id}')) + for node_id in T.nodes()]) + time_increases_constraints = [ + r_X_t_variables[parent] + >= r_X_t_variables[child] + minimum_edge_length + for (parent, child) in T.edges() + ] + leaves_have_age_0_constraints =\ + [r_X_t_variables[leaf] == 0 for leaf in T.leaves()] + non_negative_r_X_t_constraints =\ + [r_X_t >= 0 for r_X_t in r_X_t_variables.values()] + all_constraints =\ + time_increases_constraints + \ + leaves_have_age_0_constraints + \ + non_negative_r_X_t_constraints + + # # # # # Compute the log-likelihood # # # # # + log_likelihood = 0 + + # Because all rates are equal, the number of cuts in each node is a + # sufficient statistic. This makes the solver WAY faster! 
+ for (parent, child) in T.edges(): + edge_length = r_X_t_variables[parent] - r_X_t_variables[child] + zeros_parent = T.get_state(parent).count('0') # TODO: '0'... + zeros_child = T.get_state(child).count('0') # TODO: '0'... + new_cuts_child = zeros_parent - zeros_child + assert(new_cuts_child >= 0) + # Add log-lik for characters that didn't get cut + log_likelihood += zeros_child * (-edge_length) + # Add log-lik for characters that got cut + log_likelihood += new_cuts_child * cp.log(1 - cp.exp(-edge_length)) + + # # # # # Add regularization # # # # # + + l2_penalty = 0 + for (parent, child) in T.edges(): + for child_of_child in T.children(child): + edge_length_above =\ + r_X_t_variables[parent] - r_X_t_variables[child] + edge_length_below =\ + r_X_t_variables[child] - r_X_t_variables[child_of_child] + l2_penalty += (edge_length_above - edge_length_below) ** 2 + l2_penalty *= l2_regularization + + # # # # # Solve the problem # # # # # + + obj = cp.Maximize(log_likelihood - l2_penalty) + prob = cp.Problem(obj, all_constraints) + + f_star = prob.solve(solver='ECOS', verbose=verbose) + + # # # # # Populate the tree with the estimated branch lengths # # # # # + + for node in T.nodes(): + T.set_age(node, age=r_X_t_variables[node].value) + + for (parent, child) in T.edges(): + new_edge_length =\ + r_X_t_variables[parent].value - r_X_t_variables[child].value + T.set_edge_length( + parent, + child, + length=new_edge_length) + + return f_star + + def score(self, tree: Tree) -> float: + r""" + The log-likelihood of the given data under the model + """ + raise NotImplementedError() diff --git a/cassiopeia/tools/lineage_tracing_simulator.py b/cassiopeia/tools/lineage_tracing_simulator.py new file mode 100644 index 00000000..4338ac0e --- /dev/null +++ b/cassiopeia/tools/lineage_tracing_simulator.py @@ -0,0 +1,66 @@ +import abc +import numpy as np + +from .tree import Tree + + +class LineageTracingSimulator(abc.ABC): + r""" + Abstract base class for all lineage tracing 
simulators. + """ + @abc.abstractmethod + def overlay_lineage_tracing_data(self, tree: Tree) -> None: + r""" + Annotates the tree's nodes with lineage tracing character vectors. + Operates on the tree in-place. + """ + + +class IIDExponentialLineageTracer(): + r""" + Characters evolve IID over the lineage, with exponential rates. + """ + def __init__( + self, + mutation_rate: float, + num_characters: float + ): + self.mutation_rate = mutation_rate + self.num_characters = num_characters + + def overlay_lineage_tracing_data(self, T: Tree) -> None: + r""" + Populates the phylogenetic tree T with lineage tracing characters. + """ + num_characters = self.num_characters + mutation_rate = self.mutation_rate + + def dfs(node: int, T: Tree): + node_state = T.get_state(node) + for child in T.children(node): + # Compute the state of the child + child_state = '' + edge_length = T.get_age(node) - T.get_age(child) + # print(f"{node} -> {child}, length {edge_length}") + assert(edge_length >= 0) + for i in range(num_characters): + # See what happens to character i + if node_state[i] != '0': + # The character has already mutated; there in nothing + # to do + child_state += node_state[i] + continue + else: + # Determine if the character will mutate. 
+ mutates =\ + np.random.exponential(1.0 / mutation_rate)\ + < edge_length + if mutates: + child_state += '1' + else: + child_state += '0' + T.set_state(child, child_state) + dfs(child, T) + root = T.root() + T.set_state(root, '0' * num_characters) + dfs(root, T) diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py new file mode 100644 index 00000000..66e3d5a1 --- /dev/null +++ b/test/tools_tests/branch_length_estimator_test.py @@ -0,0 +1,337 @@ +import networkx as nx +import numpy as np + +from cassiopeia.tools.branch_length_estimator import PoissonConvexBLE +from cassiopeia.tools.lineage_tracing_simulator import\ + IIDExponentialLineageTracer +from cassiopeia.tools.tree import Tree + + +def test_no_mutations(): + r""" + Tree topology is just a branch 0->1. + There is one unmutated character i.e.: + root [state = '0'] + | + v + child [state = '0'] + This is thus the simplest possible example of no mutations, and the MLE + branch length should be 0 + """ + T = nx.DiGraph() + T.add_node(0), T.add_node(1) + T.add_edge(0, 1) + T.nodes[0]["characters"] = '0' + T.nodes[1]["characters"] = '0' + T = Tree(T) + log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.0) + np.testing.assert_almost_equal(T.get_age(0), 0.0) + np.testing.assert_almost_equal(T.get_age(1), 0.0) + np.testing.assert_almost_equal(log_likelihood, 0.0) + + +def test_saturation(): + r""" + Tree topology is just a branch 0->1. 
+ There is one mutated character i.e.: + root [state = '0'] + | + v + child [state = '1'] + This is thus the simplest possible example of saturation, and the MLE + branch length should be infinity (>15 for all practical purposes) + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1]) + T.add_edge(0, 1) + T.nodes[0]["characters"] = '0' + T.nodes[1]["characters"] = '1' + T = Tree(T) + log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + assert(T.get_edge_length(0, 1) > 15.0) + assert(T.get_age(0) > 15.0) + np.testing.assert_almost_equal(T.get_age(1), 0.0) + np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=5) + + +def test_hand_solvable_problem_1(): + r""" + Tree topology is just a branch 0->1. + There is one mutated character and one unmutated character, i.e.: + root [state = '00'] + | + v + child [state = '01'] + The solution can be verified by hand. The optimization problem is: + min_{r * t0} log(exp(-r * t0)) + log(1 - exp(-r * t0)) + The solution is r * t0 = ln(2) ~ 0.693 + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1]) + T.add_edge(0, 1) + T.nodes[0]["characters"] = '00' + T.nodes[1]["characters"] = '01' + T = Tree(T) + log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + np.testing.assert_almost_equal( + T.get_edge_length(0, 1), np.log(2), decimal=3) + np.testing.assert_almost_equal(T.get_age(0), np.log(2), decimal=3) + np.testing.assert_almost_equal(T.get_age(1), 0.0) + np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) + + +def test_hand_solvable_problem_2(): + r""" + Tree topology is just a branch 0->1. + There are two mutated characters and one unmutated character, i.e.: + root [state = '000'] + | + v + child [state = '011'] + The solution can be verified by hand. 
The optimization problem is: + min_{r * t0} log(exp(-r * t0)) + 2 * log(1 - exp(-r * t0)) + The solution is r * t0 = ln(3) ~ 1.098 + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1]) + T.add_edge(0, 1) + T.nodes[0]["characters"] = '000' + T.nodes[1]["characters"] = '011' + T = Tree(T) + log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + np.testing.assert_almost_equal( + T.get_edge_length(0, 1), np.log(3), decimal=3) + np.testing.assert_almost_equal(T.get_age(0), np.log(3), decimal=3) + np.testing.assert_almost_equal(T.get_age(1), 0.0) + np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) + + +def test_hand_solvable_problem_3(): + r""" + Tree topology is just a branch 0->1. + There are two unmutated characters and one mutated character, i.e.: + root [state = '000'] + | + v + child [state = '001'] + The solution can be verified by hand. The optimization problem is: + min_{r * t0} 2 * log(exp(-r * t0)) + log(1 - exp(-r * t0)) + The solution is r * t0 = ln(1.5) ~ 0.405 + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1]) + T.add_edge(0, 1) + T.nodes[0]["characters"] = '000' + T.nodes[1]["characters"] = '001' + T = Tree(T) + log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + np.testing.assert_almost_equal( + T.get_edge_length(0, 1), np.log(1.5), decimal=3) + np.testing.assert_almost_equal(T.get_age(0), np.log(1.5), decimal=3) + np.testing.assert_almost_equal(T.get_age(1), 0.0) + np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) + + +def test_small_tree_with_no_mutations(): + r""" + Perfect binary tree with no mutations: Should give edges of length 0 + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]) + T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + T.nodes[0]["characters"] = '0000' + T.nodes[1]["characters"] = '0000' + T.nodes[2]["characters"] = '0000' + T.nodes[3]["characters"] = '0000' + T.nodes[4]["characters"] = '0000' + T.nodes[5]["characters"] = '0000' + 
T.nodes[6]["characters"] = '0000' + T = Tree(T) + log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + for edge in T.edges(): + np.testing.assert_almost_equal(T.get_edge_length(*edge), 0, decimal=3) + np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=3) + + +def test_small_tree_with_one_mutation(): + r""" + Perfect binary tree with one mutation at a node 6: Should give very short + edges 1->3,1->4,0->2 and very long edges 0->1,2->5,2->6. + The problem can be solved by hand: it trivially reduces to a 1-dimensional + problem: + min_{r * t0} 2 * log(exp(-r * t0)) + log(1 - exp(-r * t0)) + The solution is r * t0 = ln(1.5) ~ 0.405 + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + T.nodes[0]["characters"] = '0' + T.nodes[1]["characters"] = '0' + T.nodes[2]["characters"] = '0' + T.nodes[3]["characters"] = '0' + T.nodes[4]["characters"] = '0' + T.nodes[5]["characters"] = '0' + T.nodes[6]["characters"] = '1' + T = Tree(T) + log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.405, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.0, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.0, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(1, 4), 0.0, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(2, 5), 0.405, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(2, 6), 0.405, decimal=3) + np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) + + +def test_small_tree_with_saturation(): + r""" + Perfect binary tree with saturation. 
The edges which saturate should thus + have length infinity (>15 for all practical purposes) + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + T.nodes[0]["characters"] = '0' + T.nodes[1]["characters"] = '0' + T.nodes[2]["characters"] = '1' + T.nodes[3]["characters"] = '1' + T.nodes[4]["characters"] = '1' + T.nodes[5]["characters"] = '1' + T.nodes[6]["characters"] = '1' + T = Tree(T) + _ = PoissonConvexBLE().estimate_branch_lengths(T) + assert(T.get_edge_length(0, 2) > 15.0) + assert(T.get_edge_length(1, 3) > 15.0) + assert(T.get_edge_length(1, 4) > 15.0) + + +def test_small_tree_regression(): + r""" + Regression test. Cannot be solved by hand. We just check that this solution + never changes. + """ + # Perfect binary tree with normal amount of mutations on each edge + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + T.nodes[0]["characters"] = '000000000' + T.nodes[1]["characters"] = '100000000' + T.nodes[2]["characters"] = '000006000' + T.nodes[3]["characters"] = '120000000' + T.nodes[4]["characters"] = '103000000' + T.nodes[5]["characters"] = '000056700' + T.nodes[6]["characters"] = '000406089' + T = Tree(T) + log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.203, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.082, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.175, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(1, 4), 0.175, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(2, 5), 0.295, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(2, 6), 0.295, decimal=3) + np.testing.assert_almost_equal(log_likelihood, -22.689, decimal=3) + + +def test_small_symmetric_tree(): + r""" + Symmetric tree should have equal length edges. 
+ """ + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + T.nodes[0]["characters"] = '000' + T.nodes[1]["characters"] = '100' + T.nodes[2]["characters"] = '100' + T.nodes[3]["characters"] = '110' + T.nodes[4]["characters"] = '110' + T.nodes[5]["characters"] = '110' + T.nodes[6]["characters"] = '110' + T = Tree(T) + _ = PoissonConvexBLE().estimate_branch_lengths(T) + np.testing.assert_almost_equal( + T.get_edge_length(0, 1), T.get_edge_length(0, 2)) + np.testing.assert_almost_equal( + T.get_edge_length(1, 3), T.get_edge_length(1, 4)) + np.testing.assert_almost_equal( + T.get_edge_length(1, 4), T.get_edge_length(2, 5)) + np.testing.assert_almost_equal( + T.get_edge_length(2, 5), T.get_edge_length(2, 6)) + + +def test_small_tree_with_infinite_legs(): + r""" + Perfect binary tree with saturated leaves. The first level of the tree + should be normal (can be solved by hand, solution is log(2)), + the branches for the leaves should be infinity (>15 for all practical + purposes) + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + T.nodes[0]["characters"] = '00' + T.nodes[1]["characters"] = '10' + T.nodes[2]["characters"] = '10' + T.nodes[3]["characters"] = '11' + T.nodes[4]["characters"] = '11' + T.nodes[5]["characters"] = '11' + T.nodes[6]["characters"] = '11' + T = Tree(T) + _ = PoissonConvexBLE().estimate_branch_lengths(T) + np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.693, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.693, decimal=3) + assert(T.get_edge_length(1, 3) > 15) + assert(T.get_edge_length(1, 4) > 15) + assert(T.get_edge_length(2, 5) > 15) + assert(T.get_edge_length(2, 6) > 15) + + +def test_on_simulated_data(): + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + 
T.nodes[0]["age"] = 1 + T.nodes[1]["age"] = 0.9 + T.nodes[2]["age"] = 0.1 + T.nodes[3]["age"] = 0 + T.nodes[4]["age"] = 0 + T.nodes[5]["age"] = 0 + T.nodes[6]["age"] = 0 + np.random.seed(1) + T = Tree(T) + IIDExponentialLineageTracer(mutation_rate=1.0, num_characters=100)\ + .overlay_lineage_tracing_data(T) + for node in T.nodes(): + T.set_age(node, -1) + PoissonConvexBLE().estimate_branch_lengths(T) + assert(0.9 < T.get_age(0) < 1.1) + assert(0.8 < T.get_age(1) < 1.0) + assert(0.05 < T.get_age(2) < 0.15) + np.testing.assert_almost_equal(T.get_age(3), 0) + np.testing.assert_almost_equal(T.get_age(4), 0) + np.testing.assert_almost_equal(T.get_age(5), 0) + np.testing.assert_almost_equal(T.get_age(6), 0) + + +def test_subtree_collapses_when_no_mutations(): + r""" + A subtree with no mutations should collapse to 0. It reduces the problem to + the same as in 'test_hand_solvable_problem_1' + """ + T = nx.DiGraph() + T.add_nodes_from([0, 1, 2, 3, 4]), + T.add_edges_from([(0, 1), (1, 2), (1, 3), (0, 4)]) + T.nodes[0]["characters"] = '0' + T.nodes[1]["characters"] = '1' + T.nodes[2]["characters"] = '1' + T.nodes[3]["characters"] = '1' + T.nodes[4]["characters"] = '0' + T = Tree(T) + log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + np.testing.assert_almost_equal( + T.get_edge_length(0, 1), np.log(2), decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(1, 2), 0.0, decimal=3) + np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.0, decimal=3) + np.testing.assert_almost_equal( + T.get_edge_length(0, 4), np.log(2), decimal=3) + np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) diff --git a/test/tools_tests/lineage_simulator_test.py b/test/tools_tests/lineage_simulator_test.py new file mode 100644 index 00000000..52be7208 --- /dev/null +++ b/test/tools_tests/lineage_simulator_test.py @@ -0,0 +1,15 @@ +from cassiopeia.tools.lineage_simulator import PerfectBinaryTree,\ + PerfectBinaryTreeWithRootBranch + + +def 
test_PerfectBinaryTree(): + T = PerfectBinaryTree(generation_branch_lengths=[2, 3]).simulate_lineage() + newick = T.to_newick_tree_format(print_internal_nodes=True) + assert(newick == '((3:3,4:3)1:2,(5:3,6:3)2:2)0);') + + +def test_PerfectBinaryTreeWithRootBranch(): + T = PerfectBinaryTreeWithRootBranch(generation_branch_lengths=[2, 3, 4])\ + .simulate_lineage() + newick = T.to_newick_tree_format(print_internal_nodes=True) + assert(newick == '(((4:4,5:4)2:3,(6:4,7:4)3:3)1:2)0);') From 4437e48546caade5d1a5b491d9976343aaa9ed78 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sat, 26 Dec 2020 23:49:25 -0800 Subject: [PATCH 06/61] Add IIDExponentialBLEGridSearchCV with test --- cassiopeia/tools/branch_length_estimator.py | 139 +++++++++++++++++- cassiopeia/tools/tree.py | 25 ++++ .../branch_length_estimator_test.py | 72 +++++++-- 3 files changed, 217 insertions(+), 19 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator.py b/cassiopeia/tools/branch_length_estimator.py index 8a477ca3..6954eda6 100644 --- a/cassiopeia/tools/branch_length_estimator.py +++ b/cassiopeia/tools/branch_length_estimator.py @@ -1,6 +1,8 @@ import abc +import copy import cvxpy as cp - +import numpy as np +from typing import List, Tuple from .tree import Tree @@ -20,7 +22,7 @@ def estimate_branch_lengths(self, tree: Tree) -> None: """ -class PoissonConvexBLE(BranchLengthEstimator): +class IIDExponentialBLE(BranchLengthEstimator): r""" A simple branch length estimator that assumes that the characters evolve IID over the phylogeny with the same cutting rate. @@ -129,10 +131,137 @@ def estimate_branch_lengths(self, tree: Tree) -> float: child, length=new_edge_length) + self.log_likelihood = log_likelihood.value + self.log_loss = f_star + return f_star - def score(self, tree: Tree) -> float: + @classmethod + def log_likelihood(self, T: Tree) -> float: + r""" + The log-likelihood under the model. 
+ """ + log_likelihood = 0.0 + for (parent, child) in T.edges(): + edge_length = T.get_age(parent) - T.get_age(child) + zeros_parent = T.get_state(parent).count('0') # TODO: hardcoded '0' + zeros_child = T.get_state(child).count('0') # TODO: hardcoded '0' + new_cuts_child = zeros_parent - zeros_child + assert(new_cuts_child >= 0) + # Add log-lik for characters that didn't get cut + log_likelihood += zeros_child * (-edge_length) + # Add log-lik for characters that got cut + if edge_length < 1e-8 and new_cuts_child > 0: + return -np.inf + log_likelihood += new_cuts_child * np.log(1 - np.exp(-edge_length)) + return log_likelihood + + +class IIDExponentialBLEGridSearchCV(BranchLengthEstimator): + r""" + Cross-validated version of IIDExponentialBLE which fits the hyperparameters + based on character-level held-out log-likelihood. + """ + def __init__( + self, + minimum_edge_lengths: Tuple[float] = (0), + l2_regularizations: Tuple[float] = (0), + verbose: bool = False + ): + self.minimum_edge_lengths = minimum_edge_lengths + self.l2_regularizations = l2_regularizations + self.verbose = verbose + + def estimate_branch_lengths(self, T: Tree) -> None: + r""" + TODO + """ + # Extract parameters + minimum_edge_lengths = self.minimum_edge_lengths + l2_regularizations = self.l2_regularizations + verbose = self.verbose + + held_out_log_likelihoods = [] # type: List[Tuple[float, List]] + for minimum_edge_length in minimum_edge_lengths: + for l2_regularization in l2_regularizations: + cv_log_likelihood = self._cv_log_likelihood( + T=T, + minimum_edge_length=minimum_edge_length, + l2_regularization=l2_regularization) + held_out_log_likelihoods.append( + (cv_log_likelihood, + [minimum_edge_length, + l2_regularization]) + ) + + # Refit model on full dataset with the best hyperparameters + held_out_log_likelihoods.sort(reverse=True) + best_minimum_edge_length, best_l2_regularization =\ + held_out_log_likelihoods[0][1] + if verbose: + print(f"Refitting full model with:\n" + 
f"minimum_edge_length={best_minimum_edge_length}\n" + f"l2_regularization={best_l2_regularization}") + log_likelihood = IIDExponentialBLE( + minimum_edge_length=best_minimum_edge_length, + l2_regularization=best_l2_regularization + ).estimate_branch_lengths(T) + self.minimum_edge_length = best_minimum_edge_length + self.l2_regularization = best_l2_regularization + self.log_likelihood = log_likelihood + + def _cv_log_likelihood( + self, + T: Tree, + minimum_edge_length: float, + l2_regularization: float + ) -> float: + verbose = self.verbose + if verbose: + print(f"Cross-validating hyperparameters:" + f"\nminimum_edge_length={minimum_edge_length}" + f"\nl2_regularizations={l2_regularization}") + n_characters = T.num_characters() + log_likelihood_folds = np.zeros(shape=(n_characters)) + for held_out_character_idx in range(n_characters): + T_train, T_valid =\ + self._cv_split( + T, + held_out_character_idx=held_out_character_idx + ) + IIDExponentialBLE( + minimum_edge_length=minimum_edge_length, + l2_regularization=l2_regularization + ).estimate_branch_lengths(T_train) + T_valid.copy_branch_lengths(T_other=T_train) + held_out_log_likelihood =\ + IIDExponentialBLE.log_likelihood(T_valid) + log_likelihood_folds[held_out_character_idx] =\ + held_out_log_likelihood + if verbose: + print(f"log_likelihood_folds = {log_likelihood_folds}") + print(f"mean log_likelihood_folds = " + f"{np.mean(log_likelihood_folds)}") + return np.mean(log_likelihood_folds) + + def _cv_split( + self, + T: Tree, + held_out_character_idx: int + ) -> Tuple[Tree, Tree]: r""" - The log-likelihood of the given data under the model + Creates a training and a cross validation tree by hiding the + character at position held_out_character_idx. 
""" - raise NotImplementedError() + T_train = copy.deepcopy(T) + T_valid = copy.deepcopy(T) + for node in T.nodes(): + state = T_train.get_state(node) + train_state =\ + state[:held_out_character_idx]\ + + state[(held_out_character_idx + 1):] + valid_data =\ + state[held_out_character_idx] + T_train.set_state(node, train_state) + T_valid.set_state(node, valid_data) + return T_train, T_valid diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index e0330982..d507cbd4 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -177,3 +177,28 @@ def dfs(v: int) -> None: # Reset state to all zeros! self.set_state(v, '0' * n_characters) dfs(root) + + def copy_branch_lengths(self, T_other): + r""" + Copies the branch lengths of T_other onto self + """ + assert(self.nodes() == T_other.nodes()) + assert(self.edges() == T_other.edges()) + + for node in self.nodes(): + new_age = T_other.get_age(node) + self.set_age(node, age=new_age) + + for (parent, child) in self.edges(): + new_edge_length =\ + T_other.get_age(parent) - T_other.get_age(child) + self.set_edge_length( + parent, + child, + length=new_edge_length) + + def print_edges(self): + for (parent, child) in self.edges(): + print(f"{parent}[{self.get_state(parent)}] -> " + f"{child}[{self.get_state(child)}]: " + f"{self.get_edge_length(parent, child)}") diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index 66e3d5a1..a677bcae 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -1,7 +1,8 @@ import networkx as nx import numpy as np -from cassiopeia.tools.branch_length_estimator import PoissonConvexBLE +from cassiopeia.tools.branch_length_estimator import IIDExponentialBLE,\ + IIDExponentialBLEGridSearchCV from cassiopeia.tools.lineage_tracing_simulator import\ IIDExponentialLineageTracer from cassiopeia.tools.tree import Tree @@ -24,11 +25,13 @@ def 
test_no_mutations(): T.nodes[0]["characters"] = '0' T.nodes[1]["characters"] = '0' T = Tree(T) - log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.0) np.testing.assert_almost_equal(T.get_age(0), 0.0) np.testing.assert_almost_equal(T.get_age(1), 0.0) np.testing.assert_almost_equal(log_likelihood, 0.0) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) def test_saturation(): @@ -48,11 +51,13 @@ def test_saturation(): T.nodes[0]["characters"] = '0' T.nodes[1]["characters"] = '1' T = Tree(T) - log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) assert(T.get_edge_length(0, 1) > 15.0) assert(T.get_age(0) > 15.0) np.testing.assert_almost_equal(T.get_age(1), 0.0) np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=5) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) def test_hand_solvable_problem_1(): @@ -73,12 +78,14 @@ def test_hand_solvable_problem_1(): T.nodes[0]["characters"] = '00' T.nodes[1]["characters"] = '01' T = Tree(T) - log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) np.testing.assert_almost_equal( T.get_edge_length(0, 1), np.log(2), decimal=3) np.testing.assert_almost_equal(T.get_age(0), np.log(2), decimal=3) np.testing.assert_almost_equal(T.get_age(1), 0.0) np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) def test_hand_solvable_problem_2(): @@ -99,12 +106,14 @@ def test_hand_solvable_problem_2(): T.nodes[0]["characters"] = '000' 
T.nodes[1]["characters"] = '011' T = Tree(T) - log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) np.testing.assert_almost_equal( T.get_edge_length(0, 1), np.log(3), decimal=3) np.testing.assert_almost_equal(T.get_age(0), np.log(3), decimal=3) np.testing.assert_almost_equal(T.get_age(1), 0.0) np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) def test_hand_solvable_problem_3(): @@ -125,12 +134,14 @@ def test_hand_solvable_problem_3(): T.nodes[0]["characters"] = '000' T.nodes[1]["characters"] = '001' T = Tree(T) - log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) np.testing.assert_almost_equal( T.get_edge_length(0, 1), np.log(1.5), decimal=3) np.testing.assert_almost_equal(T.get_age(0), np.log(1.5), decimal=3) np.testing.assert_almost_equal(T.get_age(1), 0.0) np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) def test_small_tree_with_no_mutations(): @@ -148,10 +159,12 @@ def test_small_tree_with_no_mutations(): T.nodes[5]["characters"] = '0000' T.nodes[6]["characters"] = '0000' T = Tree(T) - log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) for edge in T.edges(): np.testing.assert_almost_equal(T.get_edge_length(*edge), 0, decimal=3) np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) def test_small_tree_with_one_mutation(): @@ -174,7 +187,7 @@ def test_small_tree_with_one_mutation(): 
T.nodes[5]["characters"] = '0' T.nodes[6]["characters"] = '1' T = Tree(T) - log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.405, decimal=3) np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.0, decimal=3) np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.0, decimal=3) @@ -182,6 +195,8 @@ def test_small_tree_with_one_mutation(): np.testing.assert_almost_equal(T.get_edge_length(2, 5), 0.405, decimal=3) np.testing.assert_almost_equal(T.get_edge_length(2, 6), 0.405, decimal=3) np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) def test_small_tree_with_saturation(): @@ -200,10 +215,12 @@ def test_small_tree_with_saturation(): T.nodes[5]["characters"] = '1' T.nodes[6]["characters"] = '1' T = Tree(T) - _ = PoissonConvexBLE().estimate_branch_lengths(T) + log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) assert(T.get_edge_length(0, 2) > 15.0) assert(T.get_edge_length(1, 3) > 15.0) assert(T.get_edge_length(1, 4) > 15.0) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) def test_small_tree_regression(): @@ -223,7 +240,7 @@ def test_small_tree_regression(): T.nodes[5]["characters"] = '000056700' T.nodes[6]["characters"] = '000406089' T = Tree(T) - log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.203, decimal=3) np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.082, decimal=3) np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.175, decimal=3) @@ -231,6 +248,8 @@ def test_small_tree_regression(): 
np.testing.assert_almost_equal(T.get_edge_length(2, 5), 0.295, decimal=3) np.testing.assert_almost_equal(T.get_edge_length(2, 6), 0.295, decimal=3) np.testing.assert_almost_equal(log_likelihood, -22.689, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) def test_small_symmetric_tree(): @@ -248,7 +267,7 @@ def test_small_symmetric_tree(): T.nodes[5]["characters"] = '110' T.nodes[6]["characters"] = '110' T = Tree(T) - _ = PoissonConvexBLE().estimate_branch_lengths(T) + log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) np.testing.assert_almost_equal( T.get_edge_length(0, 1), T.get_edge_length(0, 2)) np.testing.assert_almost_equal( @@ -257,6 +276,8 @@ def test_small_symmetric_tree(): T.get_edge_length(1, 4), T.get_edge_length(2, 5)) np.testing.assert_almost_equal( T.get_edge_length(2, 5), T.get_edge_length(2, 6)) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) def test_small_tree_with_infinite_legs(): @@ -277,13 +298,15 @@ def test_small_tree_with_infinite_legs(): T.nodes[5]["characters"] = '11' T.nodes[6]["characters"] = '11' T = Tree(T) - _ = PoissonConvexBLE().estimate_branch_lengths(T) + log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.693, decimal=3) np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.693, decimal=3) assert(T.get_edge_length(1, 3) > 15) assert(T.get_edge_length(1, 4) > 15) assert(T.get_edge_length(2, 5) > 15) assert(T.get_edge_length(2, 6) > 15) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) def test_on_simulated_data(): @@ -303,7 +326,7 @@ def test_on_simulated_data(): .overlay_lineage_tracing_data(T) for node in T.nodes(): T.set_age(node, -1) - 
PoissonConvexBLE().estimate_branch_lengths(T) + log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) assert(0.9 < T.get_age(0) < 1.1) assert(0.8 < T.get_age(1) < 1.0) assert(0.05 < T.get_age(2) < 0.15) @@ -311,6 +334,8 @@ def test_on_simulated_data(): np.testing.assert_almost_equal(T.get_age(4), 0) np.testing.assert_almost_equal(T.get_age(5), 0) np.testing.assert_almost_equal(T.get_age(6), 0) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) def test_subtree_collapses_when_no_mutations(): @@ -327,7 +352,7 @@ def test_subtree_collapses_when_no_mutations(): T.nodes[3]["characters"] = '1' T.nodes[4]["characters"] = '0' T = Tree(T) - log_likelihood = PoissonConvexBLE().estimate_branch_lengths(T) + log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) np.testing.assert_almost_equal( T.get_edge_length(0, 1), np.log(2), decimal=3) np.testing.assert_almost_equal(T.get_edge_length(1, 2), 0.0, decimal=3) @@ -335,3 +360,22 @@ def test_subtree_collapses_when_no_mutations(): np.testing.assert_almost_equal( T.get_edge_length(0, 4), np.log(2), decimal=3) np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + +def test_IIDExponentialBLEGridSearchCV(): + T = nx.DiGraph() + T.add_nodes_from([0, 1]), + T.add_edges_from([(0, 1)]) + T.nodes[0]["characters"] = '000' + T.nodes[1]["characters"] = '001' + T = Tree(T) + model = IIDExponentialBLEGridSearchCV( + minimum_edge_lengths=(0, 1.0, 3.0), + l2_regularizations=(0, ), + verbose=True + ) + model.estimate_branch_lengths(T) + minimum_edge_length = model.minimum_edge_length + np.testing.assert_almost_equal(minimum_edge_length, 1.0) From af232075f673cad07b65a1f5aedf130a7cbf3086 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sun, 27 Dec 
2020 15:01:11 -0800 Subject: [PATCH 07/61] Stop returning log-likelihood in IIDExponentialBLE.estimate_branch_lengths to conform to API --- cassiopeia/tools/branch_length_estimator.py | 15 +++--- .../branch_length_estimator_test.py | 52 ++++++++++++++----- 2 files changed, 45 insertions(+), 22 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator.py b/cassiopeia/tools/branch_length_estimator.py index 6954eda6..9fdfdc77 100644 --- a/cassiopeia/tools/branch_length_estimator.py +++ b/cassiopeia/tools/branch_length_estimator.py @@ -40,11 +40,8 @@ def __init__( self.l2_regularization = l2_regularization self.verbose = verbose - def estimate_branch_lengths(self, tree: Tree) -> float: + def estimate_branch_lengths(self, tree: Tree) -> None: r""" - TODO: This shouldn't return the log-likelihood according to the API. - What should we do about this? Maybe let's look at sklearn? - Estimates branch lengths for the given tree. This is in fact an exponential cone program, which is a special kind of @@ -134,8 +131,6 @@ def estimate_branch_lengths(self, tree: Tree) -> float: self.log_likelihood = log_likelihood.value self.log_loss = f_star - return f_star - @classmethod def log_likelihood(self, T: Tree) -> float: r""" @@ -202,13 +197,15 @@ def estimate_branch_lengths(self, T: Tree) -> None: print(f"Refitting full model with:\n" f"minimum_edge_length={best_minimum_edge_length}\n" f"l2_regularization={best_l2_regularization}") - log_likelihood = IIDExponentialBLE( + final_model = IIDExponentialBLE( minimum_edge_length=best_minimum_edge_length, l2_regularization=best_l2_regularization - ).estimate_branch_lengths(T) + ) + final_model.estimate_branch_lengths(T) self.minimum_edge_length = best_minimum_edge_length self.l2_regularization = best_l2_regularization - self.log_likelihood = log_likelihood + self.log_likelihood = final_model.log_likelihood + self.log_loss = final_model.log_loss def _cv_log_likelihood( self, diff --git 
a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index a677bcae..f8ce62e0 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -25,7 +25,9 @@ def test_no_mutations(): T.nodes[0]["characters"] = '0' T.nodes[1]["characters"] = '0' T = Tree(T) - log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) + model = IIDExponentialBLE() + model.estimate_branch_lengths(T) + log_likelihood = model.log_likelihood np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.0) np.testing.assert_almost_equal(T.get_age(0), 0.0) np.testing.assert_almost_equal(T.get_age(1), 0.0) @@ -51,7 +53,9 @@ def test_saturation(): T.nodes[0]["characters"] = '0' T.nodes[1]["characters"] = '1' T = Tree(T) - log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) + model = IIDExponentialBLE() + model.estimate_branch_lengths(T) + log_likelihood = model.log_likelihood assert(T.get_edge_length(0, 1) > 15.0) assert(T.get_age(0) > 15.0) np.testing.assert_almost_equal(T.get_age(1), 0.0) @@ -78,7 +82,9 @@ def test_hand_solvable_problem_1(): T.nodes[0]["characters"] = '00' T.nodes[1]["characters"] = '01' T = Tree(T) - log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) + model = IIDExponentialBLE() + model.estimate_branch_lengths(T) + log_likelihood = model.log_likelihood np.testing.assert_almost_equal( T.get_edge_length(0, 1), np.log(2), decimal=3) np.testing.assert_almost_equal(T.get_age(0), np.log(2), decimal=3) @@ -106,7 +112,9 @@ def test_hand_solvable_problem_2(): T.nodes[0]["characters"] = '000' T.nodes[1]["characters"] = '011' T = Tree(T) - log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) + model = IIDExponentialBLE() + model.estimate_branch_lengths(T) + log_likelihood = model.log_likelihood np.testing.assert_almost_equal( T.get_edge_length(0, 1), np.log(3), decimal=3) np.testing.assert_almost_equal(T.get_age(0), np.log(3), decimal=3) @@ 
-134,7 +142,9 @@ def test_hand_solvable_problem_3(): T.nodes[0]["characters"] = '000' T.nodes[1]["characters"] = '001' T = Tree(T) - log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) + model = IIDExponentialBLE() + model.estimate_branch_lengths(T) + log_likelihood = model.log_likelihood np.testing.assert_almost_equal( T.get_edge_length(0, 1), np.log(1.5), decimal=3) np.testing.assert_almost_equal(T.get_age(0), np.log(1.5), decimal=3) @@ -159,7 +169,9 @@ def test_small_tree_with_no_mutations(): T.nodes[5]["characters"] = '0000' T.nodes[6]["characters"] = '0000' T = Tree(T) - log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) + model = IIDExponentialBLE() + model.estimate_branch_lengths(T) + log_likelihood = model.log_likelihood for edge in T.edges(): np.testing.assert_almost_equal(T.get_edge_length(*edge), 0, decimal=3) np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=3) @@ -187,7 +199,9 @@ def test_small_tree_with_one_mutation(): T.nodes[5]["characters"] = '0' T.nodes[6]["characters"] = '1' T = Tree(T) - log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) + model = IIDExponentialBLE() + model.estimate_branch_lengths(T) + log_likelihood = model.log_likelihood np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.405, decimal=3) np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.0, decimal=3) np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.0, decimal=3) @@ -215,7 +229,9 @@ def test_small_tree_with_saturation(): T.nodes[5]["characters"] = '1' T.nodes[6]["characters"] = '1' T = Tree(T) - log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) + model = IIDExponentialBLE() + model.estimate_branch_lengths(T) + log_likelihood = model.log_likelihood assert(T.get_edge_length(0, 2) > 15.0) assert(T.get_edge_length(1, 3) > 15.0) assert(T.get_edge_length(1, 4) > 15.0) @@ -240,7 +256,9 @@ def test_small_tree_regression(): T.nodes[5]["characters"] = '000056700' T.nodes[6]["characters"] = '000406089' T 
= Tree(T) - log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) + model = IIDExponentialBLE() + model.estimate_branch_lengths(T) + log_likelihood = model.log_likelihood np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.203, decimal=3) np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.082, decimal=3) np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.175, decimal=3) @@ -267,7 +285,9 @@ def test_small_symmetric_tree(): T.nodes[5]["characters"] = '110' T.nodes[6]["characters"] = '110' T = Tree(T) - log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) + model = IIDExponentialBLE() + model.estimate_branch_lengths(T) + log_likelihood = model.log_likelihood np.testing.assert_almost_equal( T.get_edge_length(0, 1), T.get_edge_length(0, 2)) np.testing.assert_almost_equal( @@ -298,7 +318,9 @@ def test_small_tree_with_infinite_legs(): T.nodes[5]["characters"] = '11' T.nodes[6]["characters"] = '11' T = Tree(T) - log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) + model = IIDExponentialBLE() + model.estimate_branch_lengths(T) + log_likelihood = model.log_likelihood np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.693, decimal=3) np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.693, decimal=3) assert(T.get_edge_length(1, 3) > 15) @@ -326,7 +348,9 @@ def test_on_simulated_data(): .overlay_lineage_tracing_data(T) for node in T.nodes(): T.set_age(node, -1) - log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) + model = IIDExponentialBLE() + model.estimate_branch_lengths(T) + log_likelihood = model.log_likelihood assert(0.9 < T.get_age(0) < 1.1) assert(0.8 < T.get_age(1) < 1.0) assert(0.05 < T.get_age(2) < 0.15) @@ -352,7 +376,9 @@ def test_subtree_collapses_when_no_mutations(): T.nodes[3]["characters"] = '1' T.nodes[4]["characters"] = '0' T = Tree(T) - log_likelihood = IIDExponentialBLE().estimate_branch_lengths(T) + model = IIDExponentialBLE() + model.estimate_branch_lengths(T) + 
log_likelihood = model.log_likelihood np.testing.assert_almost_equal( T.get_edge_length(0, 1), np.log(2), decimal=3) np.testing.assert_almost_equal(T.get_edge_length(1, 2), 0.0, decimal=3) From 8a92290719d3598a0b0a55af33b4941c0df22184 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sun, 27 Dec 2020 15:18:32 -0800 Subject: [PATCH 08/61] tuple border case bugfix --- cassiopeia/tools/branch_length_estimator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator.py b/cassiopeia/tools/branch_length_estimator.py index 9fdfdc77..2b791f21 100644 --- a/cassiopeia/tools/branch_length_estimator.py +++ b/cassiopeia/tools/branch_length_estimator.py @@ -159,8 +159,8 @@ class IIDExponentialBLEGridSearchCV(BranchLengthEstimator): """ def __init__( self, - minimum_edge_lengths: Tuple[float] = (0), - l2_regularizations: Tuple[float] = (0), + minimum_edge_lengths: Tuple[float] = (0, ), + l2_regularizations: Tuple[float] = (0, ), verbose: bool = False ): self.minimum_edge_lengths = minimum_edge_lengths From b04ef59c061940ab92a1baa993807972c5ca5391 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sun, 27 Dec 2020 15:53:51 -0800 Subject: [PATCH 09/61] small bugfix --- cassiopeia/tools/tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index d507cbd4..36ce0117 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -32,7 +32,7 @@ def nodes(self): return list(T.nodes()) def num_characters(self) -> int: - return len(self.T.nodes[0]["characters"]) + return len(self.T.nodes[self.root()]["characters"]) def get_state(self, node: int) -> str: T = self.T From 7ee49ac80e2a1cfd49b15919b51be0156a1c910f Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Mon, 28 Dec 2020 12:37:49 -0800 Subject: [PATCH 10/61] 
IIDExponentialBLEGridSearchCV should check for solver errors in IIDExponentialBLE --- cassiopeia/tools/branch_length_estimator.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator.py b/cassiopeia/tools/branch_length_estimator.py index 2b791f21..5c6d0c41 100644 --- a/cassiopeia/tools/branch_length_estimator.py +++ b/cassiopeia/tools/branch_length_estimator.py @@ -226,13 +226,16 @@ def _cv_log_likelihood( T, held_out_character_idx=held_out_character_idx ) - IIDExponentialBLE( - minimum_edge_length=minimum_edge_length, - l2_regularization=l2_regularization - ).estimate_branch_lengths(T_train) - T_valid.copy_branch_lengths(T_other=T_train) - held_out_log_likelihood =\ - IIDExponentialBLE.log_likelihood(T_valid) + try: + IIDExponentialBLE( + minimum_edge_length=minimum_edge_length, + l2_regularization=l2_regularization + ).estimate_branch_lengths(T_train) + T_valid.copy_branch_lengths(T_other=T_train) + held_out_log_likelihood =\ + IIDExponentialBLE.log_likelihood(T_valid) + except cp.error.SolverError: + held_out_log_likelihood = -np.inf log_likelihood_folds[held_out_character_idx] =\ held_out_log_likelihood if verbose: From b3f57b1b706bdcfb0b77d8b55b3dba850d52ea78 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Mon, 28 Dec 2020 13:39:44 -0800 Subject: [PATCH 11/61] reorder imports --- cassiopeia/tools/__init__.py | 15 +++++++++++++++ cassiopeia/tools/branch_length_estimator.py | 4 +++- cassiopeia/tools/lineage_simulator.py | 3 ++- cassiopeia/tools/lineage_tracing_simulator.py | 3 ++- cassiopeia/tools/tree.py | 3 ++- test/tools_tests/branch_length_estimator_test.py | 7 ++----- test/tools_tests/lineage_simulator_test.py | 3 +-- .../tools_tests/lineage_tracing_simulator_test.py | 4 +--- 8 files changed, 28 insertions(+), 14 deletions(-) diff --git a/cassiopeia/tools/__init__.py b/cassiopeia/tools/__init__.py index e69de29b..8076c272 100644 --- 
a/cassiopeia/tools/__init__.py +++ b/cassiopeia/tools/__init__.py @@ -0,0 +1,15 @@ +from .branch_length_estimator import ( + BranchLengthEstimator, + IIDExponentialBLE, + IIDExponentialBLEGridSearchCV +) +from .lineage_simulator import ( + LineageSimulator, + PerfectBinaryTree, + PerfectBinaryTreeWithRootBranch +) +from .lineage_tracing_simulator import ( + LineageTracingSimulator, + IIDExponentialLineageTracer +) +from .tree import Tree diff --git a/cassiopeia/tools/branch_length_estimator.py b/cassiopeia/tools/branch_length_estimator.py index 5c6d0c41..59621945 100644 --- a/cassiopeia/tools/branch_length_estimator.py +++ b/cassiopeia/tools/branch_length_estimator.py @@ -1,8 +1,10 @@ import abc import copy +from typing import List, Tuple + import cvxpy as cp import numpy as np -from typing import List, Tuple + from .tree import Tree diff --git a/cassiopeia/tools/lineage_simulator.py b/cassiopeia/tools/lineage_simulator.py index 840e73a6..119e5bf1 100644 --- a/cassiopeia/tools/lineage_simulator.py +++ b/cassiopeia/tools/lineage_simulator.py @@ -1,7 +1,8 @@ import abc -import networkx as nx from typing import List +import networkx as nx + from .tree import Tree diff --git a/cassiopeia/tools/lineage_tracing_simulator.py b/cassiopeia/tools/lineage_tracing_simulator.py index 4338ac0e..777b66c1 100644 --- a/cassiopeia/tools/lineage_tracing_simulator.py +++ b/cassiopeia/tools/lineage_tracing_simulator.py @@ -1,4 +1,5 @@ import abc + import numpy as np from .tree import Tree @@ -18,7 +19,7 @@ def overlay_lineage_tracing_data(self, tree: Tree) -> None: class IIDExponentialLineageTracer(): r""" - Characters evolve IID over the lineage, with exponential rates. + Characters evolve IID over the lineage, with the same rate. 
""" def __init__( self, diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index 36ce0117..f4095117 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -1,6 +1,7 @@ -import networkx as nx from typing import List, Tuple +import networkx as nx + class Tree(): r""" diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index f8ce62e0..9b936516 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -1,11 +1,8 @@ import networkx as nx import numpy as np -from cassiopeia.tools.branch_length_estimator import IIDExponentialBLE,\ - IIDExponentialBLEGridSearchCV -from cassiopeia.tools.lineage_tracing_simulator import\ - IIDExponentialLineageTracer -from cassiopeia.tools.tree import Tree +from cassiopeia.tools import (IIDExponentialBLE, IIDExponentialBLEGridSearchCV, + IIDExponentialLineageTracer, Tree) def test_no_mutations(): diff --git a/test/tools_tests/lineage_simulator_test.py b/test/tools_tests/lineage_simulator_test.py index 52be7208..cf7ae578 100644 --- a/test/tools_tests/lineage_simulator_test.py +++ b/test/tools_tests/lineage_simulator_test.py @@ -1,5 +1,4 @@ -from cassiopeia.tools.lineage_simulator import PerfectBinaryTree,\ - PerfectBinaryTreeWithRootBranch +from cassiopeia.tools import PerfectBinaryTree, PerfectBinaryTreeWithRootBranch def test_PerfectBinaryTree(): diff --git a/test/tools_tests/lineage_tracing_simulator_test.py b/test/tools_tests/lineage_tracing_simulator_test.py index 122f8710..f526aa40 100644 --- a/test/tools_tests/lineage_tracing_simulator_test.py +++ b/test/tools_tests/lineage_tracing_simulator_test.py @@ -1,9 +1,7 @@ import networkx as nx import numpy as np -from cassiopeia.tools.lineage_tracing_simulator import\ - IIDExponentialLineageTracer -from cassiopeia.tools.tree import Tree +from cassiopeia.tools import IIDExponentialLineageTracer, Tree def test_smoke(): From 
b781c235a17056f82244d5446c701e5e0970d0bd Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Mon, 28 Dec 2020 15:24:13 -0800 Subject: [PATCH 12/61] Rename minimum_edge_length to minimum_branch_length --- cassiopeia/tools/branch_length_estimator.py | 34 +++++++++---------- .../branch_length_estimator_test.py | 6 ++-- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator.py b/cassiopeia/tools/branch_length_estimator.py index 59621945..368c1e11 100644 --- a/cassiopeia/tools/branch_length_estimator.py +++ b/cassiopeia/tools/branch_length_estimator.py @@ -34,11 +34,11 @@ class IIDExponentialBLE(BranchLengthEstimator): """ def __init__( self, - minimum_edge_length: float = 0, # TODO: minimum_branch_length? + minimum_branch_length: float = 0, l2_regularization: float = 0, verbose: bool = False ): - self.minimum_edge_length = minimum_edge_length + self.minimum_branch_length = minimum_branch_length self.l2_regularization = l2_regularization self.verbose = verbose @@ -57,7 +57,7 @@ def estimate_branch_lengths(self, tree: Tree) -> None: The log-likelihood under the model for the computed branch lengths. 
""" # Extract parameters - minimum_edge_length = self.minimum_edge_length + minimum_branch_length = self.minimum_branch_length l2_regularization = self.l2_regularization verbose = self.verbose @@ -70,7 +70,7 @@ def estimate_branch_lengths(self, tree: Tree) -> None: for node_id in T.nodes()]) time_increases_constraints = [ r_X_t_variables[parent] - >= r_X_t_variables[child] + minimum_edge_length + >= r_X_t_variables[child] + minimum_branch_length for (parent, child) in T.edges() ] leaves_have_age_0_constraints =\ @@ -161,11 +161,11 @@ class IIDExponentialBLEGridSearchCV(BranchLengthEstimator): """ def __init__( self, - minimum_edge_lengths: Tuple[float] = (0, ), + minimum_branch_lengths: Tuple[float] = (0, ), l2_regularizations: Tuple[float] = (0, ), verbose: bool = False ): - self.minimum_edge_lengths = minimum_edge_lengths + self.minimum_branch_lengths = minimum_branch_lengths self.l2_regularizations = l2_regularizations self.verbose = verbose @@ -174,37 +174,37 @@ def estimate_branch_lengths(self, T: Tree) -> None: TODO """ # Extract parameters - minimum_edge_lengths = self.minimum_edge_lengths + minimum_branch_lengths = self.minimum_branch_lengths l2_regularizations = self.l2_regularizations verbose = self.verbose held_out_log_likelihoods = [] # type: List[Tuple[float, List]] - for minimum_edge_length in minimum_edge_lengths: + for minimum_branch_length in minimum_branch_lengths: for l2_regularization in l2_regularizations: cv_log_likelihood = self._cv_log_likelihood( T=T, - minimum_edge_length=minimum_edge_length, + minimum_branch_length=minimum_branch_length, l2_regularization=l2_regularization) held_out_log_likelihoods.append( (cv_log_likelihood, - [minimum_edge_length, + [minimum_branch_length, l2_regularization]) ) # Refit model on full dataset with the best hyperparameters held_out_log_likelihoods.sort(reverse=True) - best_minimum_edge_length, best_l2_regularization =\ + best_minimum_branch_length, best_l2_regularization =\ held_out_log_likelihoods[0][1] 
if verbose: print(f"Refitting full model with:\n" - f"minimum_edge_length={best_minimum_edge_length}\n" + f"minimum_branch_length={best_minimum_branch_length}\n" f"l2_regularization={best_l2_regularization}") final_model = IIDExponentialBLE( - minimum_edge_length=best_minimum_edge_length, + minimum_branch_length=best_minimum_branch_length, l2_regularization=best_l2_regularization ) final_model.estimate_branch_lengths(T) - self.minimum_edge_length = best_minimum_edge_length + self.minimum_branch_length = best_minimum_branch_length self.l2_regularization = best_l2_regularization self.log_likelihood = final_model.log_likelihood self.log_loss = final_model.log_loss @@ -212,13 +212,13 @@ def estimate_branch_lengths(self, T: Tree) -> None: def _cv_log_likelihood( self, T: Tree, - minimum_edge_length: float, + minimum_branch_length: float, l2_regularization: float ) -> float: verbose = self.verbose if verbose: print(f"Cross-validating hyperparameters:" - f"\nminimum_edge_length={minimum_edge_length}" + f"\nminimum_branch_length={minimum_branch_length}" f"\nl2_regularizations={l2_regularization}") n_characters = T.num_characters() log_likelihood_folds = np.zeros(shape=(n_characters)) @@ -230,7 +230,7 @@ def _cv_log_likelihood( ) try: IIDExponentialBLE( - minimum_edge_length=minimum_edge_length, + minimum_branch_length=minimum_branch_length, l2_regularization=l2_regularization ).estimate_branch_lengths(T_train) T_valid.copy_branch_lengths(T_other=T_train) diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index 9b936516..27ae910b 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -395,10 +395,10 @@ def test_IIDExponentialBLEGridSearchCV(): T.nodes[1]["characters"] = '001' T = Tree(T) model = IIDExponentialBLEGridSearchCV( - minimum_edge_lengths=(0, 1.0, 3.0), + minimum_branch_lengths=(0, 1.0, 3.0), l2_regularizations=(0, ), verbose=True ) 
model.estimate_branch_lengths(T) - minimum_edge_length = model.minimum_edge_length - np.testing.assert_almost_equal(minimum_edge_length, 1.0) + minimum_branch_length = model.minimum_branch_length + np.testing.assert_almost_equal(minimum_branch_length, 1.0) From ef27e11eff4820f1097ce59632ce538e609dd7c6 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Mon, 28 Dec 2020 17:17:44 -0800 Subject: [PATCH 13/61] Rename T to tree --- cassiopeia/tools/branch_length_estimator.py | 123 +++--- cassiopeia/tools/lineage_simulator.py | 32 +- cassiopeia/tools/lineage_tracing_simulator.py | 20 +- cassiopeia/tools/tree.py | 76 ++-- .../branch_length_estimator_test.py | 411 +++++++++--------- test/tools_tests/lineage_simulator_test.py | 9 +- .../lineage_tracing_simulator_test.py | 24 +- test/tools_tests/tree_test.py | 108 ++--- 8 files changed, 413 insertions(+), 390 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator.py b/cassiopeia/tools/branch_length_estimator.py index 368c1e11..501c14d3 100644 --- a/cassiopeia/tools/branch_length_estimator.py +++ b/cassiopeia/tools/branch_length_estimator.py @@ -11,13 +11,20 @@ class BranchLengthEstimator(abc.ABC): r""" Abstract base class for all branch length estimators. + + A BranchLengthEstimator implements a method estimate_branch_lengths which, + given a Tree with lineage tracing character vectors at the leaves (and + possibly at the internal nodes too), estimates the branch lengths of the + tree. """ @abc.abstractmethod def estimate_branch_lengths(self, tree: Tree) -> None: r""" + Estimates the branch lengths of the tree. + Annotates the tree's nodes with their estimated age, and - the tree's branches with their lengths. - Operates on the tree in-place. + the tree's branches with their estiamted lengths. Operates on the tree + in-place. Args: tree: The tree for which to estimate branch lengths. 
@@ -29,8 +36,26 @@ class IIDExponentialBLE(BranchLengthEstimator): A simple branch length estimator that assumes that the characters evolve IID over the phylogeny with the same cutting rate. - Maximum Parsinomy is used to impute the ancestral states first. Doing so - leads to a convex optimization problem. + This estimator requires that the ancestral states are provided. + + The optimization problem is a special kind of convex program called an + exponential cone program: + https://docs.mosek.com/modeling-cookbook/expo.html + Because it is a convex program, it can be readily solved. + + Args: + minimum_branch_length: Estimated branch lengths will be constrained to + have at least this lenght. + l2_regularization: Consecutive branches will be regularized to have + similar length via an L2 penalty whose weight is given by + l2_regularization. + verbose: Verbosity level. + + Attributes: + log_likelihood: The log-likelihood of the training data under the + estimated model. + log_loss: The log-loss of the training data under the estimated model. + This is the log likhelihood plus the regularization terms. """ def __init__( self, @@ -44,17 +69,11 @@ def __init__( def estimate_branch_lengths(self, tree: Tree) -> None: r""" - Estimates branch lengths for the given tree. + See base class. Only caveat is that this method raises if it fails to + solve the underlying optimization problem for any reason. - This is in fact an exponential cone program, which is a special kind of - convex problem: - https://docs.mosek.com/modeling-cookbook/expo.html - - Args: - tree: The tree for which to estimate branch lengths. - - Returns: - The log-likelihood under the model for the computed branch lengths. + Raises: + cp.error.SolverError """ # Extract parameters minimum_branch_length = self.minimum_branch_length @@ -62,19 +81,18 @@ def estimate_branch_lengths(self, tree: Tree) -> None: verbose = self.verbose # # Wrap the networkx DiGraph for goodies. 
- # T = Tree(tree) - T = tree + # tree = Tree(tree) # # # # # Create variables of the optimization problem # # # # # r_X_t_variables = dict([(node_id, cp.Variable(name=f'r_X_t_{node_id}')) - for node_id in T.nodes()]) + for node_id in tree.nodes()]) time_increases_constraints = [ r_X_t_variables[parent] >= r_X_t_variables[child] + minimum_branch_length - for (parent, child) in T.edges() + for (parent, child) in tree.edges() ] leaves_have_age_0_constraints =\ - [r_X_t_variables[leaf] == 0 for leaf in T.leaves()] + [r_X_t_variables[leaf] == 0 for leaf in tree.leaves()] non_negative_r_X_t_constraints =\ [r_X_t >= 0 for r_X_t in r_X_t_variables.values()] all_constraints =\ @@ -87,10 +105,10 @@ def estimate_branch_lengths(self, tree: Tree) -> None: # Because all rates are equal, the number of cuts in each node is a # sufficient statistic. This makes the solver WAY faster! - for (parent, child) in T.edges(): + for (parent, child) in tree.edges(): edge_length = r_X_t_variables[parent] - r_X_t_variables[child] - zeros_parent = T.get_state(parent).count('0') # TODO: '0'... - zeros_child = T.get_state(child).count('0') # TODO: '0'... + zeros_parent = tree.get_state(parent).count('0') # TODO: '0'... + zeros_child = tree.get_state(child).count('0') # TODO: '0'... 
new_cuts_child = zeros_parent - zeros_child assert(new_cuts_child >= 0) # Add log-lik for characters that didn't get cut @@ -101,8 +119,8 @@ def estimate_branch_lengths(self, tree: Tree) -> None: # # # # # Add regularization # # # # # l2_penalty = 0 - for (parent, child) in T.edges(): - for child_of_child in T.children(child): + for (parent, child) in tree.edges(): + for child_of_child in tree.children(child): edge_length_above =\ r_X_t_variables[parent] - r_X_t_variables[child] edge_length_below =\ @@ -119,13 +137,13 @@ def estimate_branch_lengths(self, tree: Tree) -> None: # # # # # Populate the tree with the estimated branch lengths # # # # # - for node in T.nodes(): - T.set_age(node, age=r_X_t_variables[node].value) + for node in tree.nodes(): + tree.set_age(node, age=r_X_t_variables[node].value) - for (parent, child) in T.edges(): + for (parent, child) in tree.edges(): new_edge_length =\ r_X_t_variables[parent].value - r_X_t_variables[child].value - T.set_edge_length( + tree.set_edge_length( parent, child, length=new_edge_length) @@ -134,15 +152,16 @@ def estimate_branch_lengths(self, tree: Tree) -> None: self.log_loss = f_star @classmethod - def log_likelihood(self, T: Tree) -> float: + def log_likelihood(self, tree: Tree) -> float: r""" - The log-likelihood under the model. + The log-likelihood of the given tree under the model. """ log_likelihood = 0.0 - for (parent, child) in T.edges(): - edge_length = T.get_age(parent) - T.get_age(child) - zeros_parent = T.get_state(parent).count('0') # TODO: hardcoded '0' - zeros_child = T.get_state(child).count('0') # TODO: hardcoded '0' + for (parent, child) in tree.edges(): + edge_length = tree.get_age(parent) - tree.get_age(child) + # TODO: hardcoded '0' here... 
+ zeros_parent = tree.get_state(parent).count('0') + zeros_child = tree.get_state(child).count('0') new_cuts_child = zeros_parent - zeros_child assert(new_cuts_child >= 0) # Add log-lik for characters that didn't get cut @@ -169,7 +188,7 @@ def __init__( self.l2_regularizations = l2_regularizations self.verbose = verbose - def estimate_branch_lengths(self, T: Tree) -> None: + def estimate_branch_lengths(self, tree: Tree) -> None: r""" TODO """ @@ -182,7 +201,7 @@ def estimate_branch_lengths(self, T: Tree) -> None: for minimum_branch_length in minimum_branch_lengths: for l2_regularization in l2_regularizations: cv_log_likelihood = self._cv_log_likelihood( - T=T, + tree=tree, minimum_branch_length=minimum_branch_length, l2_regularization=l2_regularization) held_out_log_likelihoods.append( @@ -203,7 +222,7 @@ def estimate_branch_lengths(self, T: Tree) -> None: minimum_branch_length=best_minimum_branch_length, l2_regularization=best_l2_regularization ) - final_model.estimate_branch_lengths(T) + final_model.estimate_branch_lengths(tree) self.minimum_branch_length = best_minimum_branch_length self.l2_regularization = best_l2_regularization self.log_likelihood = final_model.log_likelihood @@ -211,7 +230,7 @@ def estimate_branch_lengths(self, T: Tree) -> None: def _cv_log_likelihood( self, - T: Tree, + tree: Tree, minimum_branch_length: float, l2_regularization: float ) -> float: @@ -220,22 +239,22 @@ def _cv_log_likelihood( print(f"Cross-validating hyperparameters:" f"\nminimum_branch_length={minimum_branch_length}" f"\nl2_regularizations={l2_regularization}") - n_characters = T.num_characters() + n_characters = tree.num_characters() log_likelihood_folds = np.zeros(shape=(n_characters)) for held_out_character_idx in range(n_characters): - T_train, T_valid =\ + tree_train, tree_valid =\ self._cv_split( - T, + tree=tree, held_out_character_idx=held_out_character_idx ) try: IIDExponentialBLE( minimum_branch_length=minimum_branch_length, l2_regularization=l2_regularization - 
).estimate_branch_lengths(T_train) - T_valid.copy_branch_lengths(T_other=T_train) + ).estimate_branch_lengths(tree_train) + tree_valid.copy_branch_lengths(tree_other=tree_train) held_out_log_likelihood =\ - IIDExponentialBLE.log_likelihood(T_valid) + IIDExponentialBLE.log_likelihood(tree_valid) except cp.error.SolverError: held_out_log_likelihood = -np.inf log_likelihood_folds[held_out_character_idx] =\ @@ -248,22 +267,22 @@ def _cv_log_likelihood( def _cv_split( self, - T: Tree, + tree: Tree, held_out_character_idx: int ) -> Tuple[Tree, Tree]: r""" Creates a training and a cross validation tree by hiding the character at position held_out_character_idx. """ - T_train = copy.deepcopy(T) - T_valid = copy.deepcopy(T) - for node in T.nodes(): - state = T_train.get_state(node) + tree_train = copy.deepcopy(tree) + tree_valid = copy.deepcopy(tree) + for node in tree.nodes(): + state = tree_train.get_state(node) train_state =\ state[:held_out_character_idx]\ + state[(held_out_character_idx + 1):] valid_data =\ state[held_out_character_idx] - T_train.set_state(node, train_state) - T_valid.set_state(node, valid_data) - return T_train, T_valid + tree_train.set_state(node, train_state) + tree_valid.set_state(node, valid_data) + return tree_train, tree_valid diff --git a/cassiopeia/tools/lineage_simulator.py b/cassiopeia/tools/lineage_simulator.py index 119e5bf1..62e0a831 100644 --- a/cassiopeia/tools/lineage_simulator.py +++ b/cassiopeia/tools/lineage_simulator.py @@ -28,25 +28,25 @@ def simulate_lineage(self) -> Tree: """ generation_branch_lengths = self.generation_branch_lengths n_generations = len(generation_branch_lengths) - T = nx.DiGraph() - T.add_nodes_from(range(2 ** (n_generations + 1) - 1)) + tree = nx.DiGraph() + tree.add_nodes_from(range(2 ** (n_generations + 1) - 1)) edges = [(int((child - 1) / 2), child) for child in range(1, 2 ** (n_generations + 1) - 1)] node_generation = [] for i in range(n_generations + 1): node_generation += [i] * 2 ** i - 
T.add_edges_from(edges) + tree.add_edges_from(edges) for (parent, child) in edges: parent_generation = node_generation[parent] branch_length = generation_branch_lengths[parent_generation] - T.edges[parent, child]["length"] = branch_length - T.nodes[0]["age"] = sum(generation_branch_lengths) + tree.edges[parent, child]["length"] = branch_length + tree.nodes[0]["age"] = sum(generation_branch_lengths) for child in range(1, 2 ** (n_generations + 1) - 1): child_generation = node_generation[child] branch_length = generation_branch_lengths[child_generation - 1] - T.nodes[child]["age"] =\ - T.nodes[int((child - 1) / 2)]["age"] - branch_length - return Tree(T) + tree.nodes[child]["age"] =\ + tree.nodes[int((child - 1) / 2)]["age"] - branch_length + return Tree(tree) class PerfectBinaryTreeWithRootBranch(LineageSimulator): @@ -63,22 +63,22 @@ def simulate_lineage(self) -> Tree: # generation_branch_lengths = self.generation_branch_lengths generation_branch_lengths = self.generation_branch_lengths n_generations = len(generation_branch_lengths) - T = nx.DiGraph() - T.add_nodes_from(range(2 ** n_generations)) + tree = nx.DiGraph() + tree.add_nodes_from(range(2 ** n_generations)) edges = [(int(child / 2), child) for child in range(1, 2 ** n_generations)] - T.add_edges_from(edges) + tree.add_edges_from(edges) node_generation = [0] for i in range(n_generations): node_generation += [i + 1] * 2 ** i for (parent, child) in edges: parent_generation = node_generation[parent] branch_length = generation_branch_lengths[parent_generation] - T.edges[parent, child]["length"] = branch_length - T.nodes[0]["age"] = sum(generation_branch_lengths) + tree.edges[parent, child]["length"] = branch_length + tree.nodes[0]["age"] = sum(generation_branch_lengths) for child in range(1, 2 ** n_generations): child_generation = node_generation[child] branch_length = generation_branch_lengths[child_generation - 1] - T.nodes[child]["age"] =\ - T.nodes[int(child / 2)]["age"] - branch_length - return Tree(T) + 
tree.nodes[child]["age"] =\ + tree.nodes[int(child / 2)]["age"] - branch_length + return Tree(tree) diff --git a/cassiopeia/tools/lineage_tracing_simulator.py b/cassiopeia/tools/lineage_tracing_simulator.py index 777b66c1..c5df5643 100644 --- a/cassiopeia/tools/lineage_tracing_simulator.py +++ b/cassiopeia/tools/lineage_tracing_simulator.py @@ -29,19 +29,19 @@ def __init__( self.mutation_rate = mutation_rate self.num_characters = num_characters - def overlay_lineage_tracing_data(self, T: Tree) -> None: + def overlay_lineage_tracing_data(self, tree: Tree) -> None: r""" Populates the phylogenetic tree T with lineage tracing characters. """ num_characters = self.num_characters mutation_rate = self.mutation_rate - def dfs(node: int, T: Tree): - node_state = T.get_state(node) - for child in T.children(node): + def dfs(node: int, tree: Tree): + node_state = tree.get_state(node) + for child in tree.children(node): # Compute the state of the child child_state = '' - edge_length = T.get_age(node) - T.get_age(child) + edge_length = tree.get_age(node) - tree.get_age(child) # print(f"{node} -> {child}, length {edge_length}") assert(edge_length >= 0) for i in range(num_characters): @@ -60,8 +60,8 @@ def dfs(node: int, T: Tree): child_state += '1' else: child_state += '0' - T.set_state(child, child_state) - dfs(child, T) - root = T.root() - T.set_state(root, '0' * num_characters) - dfs(root, T) + tree.set_state(child, child_state) + dfs(child, tree) + root = tree.root() + tree.set_state(root, '0' * num_characters) + dfs(root, tree) diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index f4095117..58701787 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -8,69 +8,71 @@ class Tree(): networkx.Digraph wrapper to isolate networkx dependency and add custom tree methods. 
""" - def __init__(self, T: nx.DiGraph): - self.T = T + def __init__(self, tree: nx.DiGraph): + self.tree = tree def root(self) -> int: - T = self.T - root = [n for n in T if T.in_degree(n) == 0][0] + tree = self.tree + root = [n for n in tree if tree.in_degree(n) == 0][0] return root def leaves(self) -> List[int]: - T = self.T - leaves = [n for n in T if T.out_degree(n) == 0 and T.in_degree(n) == 1] + tree = self.tree + leaves = [n for n in tree + if tree.out_degree(n) == 0 + and tree.in_degree(n) == 1] return leaves def internal_nodes(self) -> List[int]: - T = self.T - return [n for n in T if n != self.root() and n not in self.leaves()] + tree = self.tree + return [n for n in tree if n != self.root() and n not in self.leaves()] def non_root_nodes(self) -> List[int]: return self.leaves() + self.internal_nodes() def nodes(self): - T = self.T - return list(T.nodes()) + tree = self.tree + return list(tree.nodes()) def num_characters(self) -> int: - return len(self.T.nodes[self.root()]["characters"]) + return len(self.tree.nodes[self.root()]["characters"]) def get_state(self, node: int) -> str: - T = self.T - return T.nodes[node]["characters"] + tree = self.tree + return tree.nodes[node]["characters"] def set_state(self, node: int, state: str) -> None: - T = self.T - T.nodes[node]["characters"] = state + tree = self.tree + tree.nodes[node]["characters"] = state def set_states(self, node_state_list: List[Tuple[int, str]]) -> None: for (node, state) in node_state_list: self.set_state(node, state) def get_age(self, node: int) -> float: - T = self.T - return T.nodes[node]["age"] + tree = self.tree + return tree.nodes[node]["age"] def set_age(self, node: int, age: float) -> None: - T = self.T - T.nodes[node]["age"] = age + tree = self.tree + tree.nodes[node]["age"] = age def edges(self) -> List[Tuple[int, int]]: """List of (parent, child) tuples""" - T = self.T - return list(T.edges) + tree = self.tree + return list(tree.edges) def get_edge_length(self, parent: int, child: 
int) -> float: - T = self.T - assert parent in T - assert child in T[parent] - return T.edges[parent, child]["length"] + tree = self.tree + assert parent in tree + assert child in tree[parent] + return tree.edges[parent, child]["length"] def set_edge_length(self, parent: int, child: int, length: float) -> None: - T = self.T - assert parent in T - assert child in T[parent] - T.edges[parent, child]["length"] = length + tree = self.tree + assert parent in tree + assert child in tree[parent] + tree.edges[parent, child]["length"] = length def set_edge_lengths( self, @@ -79,8 +81,8 @@ def set_edge_lengths( self.set_edge_length(parent, child, length) def children(self, node: int) -> List[int]: - T = self.T - return list(T.adj[node]) + tree = self.tree + return list(tree.adj[node]) def to_newick_tree_format( self, @@ -179,20 +181,20 @@ def dfs(v: int) -> None: self.set_state(v, '0' * n_characters) dfs(root) - def copy_branch_lengths(self, T_other): + def copy_branch_lengths(self, tree_other): r""" - Copies the branch lengths of T_other onto self + Copies the branch lengths of tree_other onto self """ - assert(self.nodes() == T_other.nodes()) - assert(self.edges() == T_other.edges()) + assert(self.nodes() == tree_other.nodes()) + assert(self.edges() == tree_other.edges()) for node in self.nodes(): - new_age = T_other.get_age(node) + new_age = tree_other.get_age(node) self.set_age(node, age=new_age) for (parent, child) in self.edges(): new_edge_length =\ - T_other.get_age(parent) - T_other.get_age(child) + tree_other.get_age(parent) - tree_other.get_age(child) self.set_edge_length( parent, child, diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index 27ae910b..b56d754e 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -16,20 +16,20 @@ def test_no_mutations(): This is thus the simplest possible example of no mutations, and the MLE branch length 
should be 0 """ - T = nx.DiGraph() - T.add_node(0), T.add_node(1) - T.add_edge(0, 1) - T.nodes[0]["characters"] = '0' - T.nodes[1]["characters"] = '0' - T = Tree(T) + tree = nx.DiGraph() + tree.add_node(0), tree.add_node(1) + tree.add_edge(0, 1) + tree.nodes[0]["characters"] = '0' + tree.nodes[1]["characters"] = '0' + tree = Tree(tree) model = IIDExponentialBLE() - model.estimate_branch_lengths(T) + model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood - np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.0) - np.testing.assert_almost_equal(T.get_age(0), 0.0) - np.testing.assert_almost_equal(T.get_age(1), 0.0) + np.testing.assert_almost_equal(tree.get_edge_length(0, 1), 0.0) + np.testing.assert_almost_equal(tree.get_age(0), 0.0) + np.testing.assert_almost_equal(tree.get_age(1), 0.0) np.testing.assert_almost_equal(log_likelihood, 0.0) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -44,20 +44,20 @@ def test_saturation(): This is thus the simplest possible example of saturation, and the MLE branch length should be infinity (>15 for all practical purposes) """ - T = nx.DiGraph() - T.add_nodes_from([0, 1]) - T.add_edge(0, 1) - T.nodes[0]["characters"] = '0' - T.nodes[1]["characters"] = '1' - T = Tree(T) + tree = nx.DiGraph() + tree.add_nodes_from([0, 1]) + tree.add_edge(0, 1) + tree.nodes[0]["characters"] = '0' + tree.nodes[1]["characters"] = '1' + tree = Tree(tree) model = IIDExponentialBLE() - model.estimate_branch_lengths(T) + model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood - assert(T.get_edge_length(0, 1) > 15.0) - assert(T.get_age(0) > 15.0) - np.testing.assert_almost_equal(T.get_age(1), 0.0) + assert(tree.get_edge_length(0, 1) > 15.0) + assert(tree.get_age(0) > 15.0) + np.testing.assert_almost_equal(tree.get_age(1), 0.0) np.testing.assert_almost_equal(log_likelihood, 
0.0, decimal=5) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -73,21 +73,21 @@ def test_hand_solvable_problem_1(): min_{r * t0} log(exp(-r * t0)) + log(1 - exp(-r * t0)) The solution is r * t0 = ln(2) ~ 0.693 """ - T = nx.DiGraph() - T.add_nodes_from([0, 1]) - T.add_edge(0, 1) - T.nodes[0]["characters"] = '00' - T.nodes[1]["characters"] = '01' - T = Tree(T) + tree = nx.DiGraph() + tree.add_nodes_from([0, 1]) + tree.add_edge(0, 1) + tree.nodes[0]["characters"] = '00' + tree.nodes[1]["characters"] = '01' + tree = Tree(tree) model = IIDExponentialBLE() - model.estimate_branch_lengths(T) + model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood np.testing.assert_almost_equal( - T.get_edge_length(0, 1), np.log(2), decimal=3) - np.testing.assert_almost_equal(T.get_age(0), np.log(2), decimal=3) - np.testing.assert_almost_equal(T.get_age(1), 0.0) + tree.get_edge_length(0, 1), np.log(2), decimal=3) + np.testing.assert_almost_equal(tree.get_age(0), np.log(2), decimal=3) + np.testing.assert_almost_equal(tree.get_age(1), 0.0) np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -103,21 +103,21 @@ def test_hand_solvable_problem_2(): min_{r * t0} log(exp(-r * t0)) + 2 * log(1 - exp(-r * t0)) The solution is r * t0 = ln(3) ~ 1.098 """ - T = nx.DiGraph() - T.add_nodes_from([0, 1]) - T.add_edge(0, 1) - T.nodes[0]["characters"] = '000' - T.nodes[1]["characters"] = '011' - T = Tree(T) + tree = nx.DiGraph() + tree.add_nodes_from([0, 1]) + tree.add_edge(0, 1) + tree.nodes[0]["characters"] = '000' + tree.nodes[1]["characters"] = '011' + tree = Tree(tree) model = IIDExponentialBLE() - model.estimate_branch_lengths(T) 
+ model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood np.testing.assert_almost_equal( - T.get_edge_length(0, 1), np.log(3), decimal=3) - np.testing.assert_almost_equal(T.get_age(0), np.log(3), decimal=3) - np.testing.assert_almost_equal(T.get_age(1), 0.0) + tree.get_edge_length(0, 1), np.log(3), decimal=3) + np.testing.assert_almost_equal(tree.get_age(0), np.log(3), decimal=3) + np.testing.assert_almost_equal(tree.get_age(1), 0.0) np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -133,21 +133,21 @@ def test_hand_solvable_problem_3(): min_{r * t0} 2 * log(exp(-r * t0)) + log(1 - exp(-r * t0)) The solution is r * t0 = ln(1.5) ~ 0.405 """ - T = nx.DiGraph() - T.add_nodes_from([0, 1]) - T.add_edge(0, 1) - T.nodes[0]["characters"] = '000' - T.nodes[1]["characters"] = '001' - T = Tree(T) + tree = nx.DiGraph() + tree.add_nodes_from([0, 1]) + tree.add_edge(0, 1) + tree.nodes[0]["characters"] = '000' + tree.nodes[1]["characters"] = '001' + tree = Tree(tree) model = IIDExponentialBLE() - model.estimate_branch_lengths(T) + model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood np.testing.assert_almost_equal( - T.get_edge_length(0, 1), np.log(1.5), decimal=3) - np.testing.assert_almost_equal(T.get_age(0), np.log(1.5), decimal=3) - np.testing.assert_almost_equal(T.get_age(1), 0.0) + tree.get_edge_length(0, 1), np.log(1.5), decimal=3) + np.testing.assert_almost_equal(tree.get_age(0), np.log(1.5), decimal=3) + np.testing.assert_almost_equal(tree.get_age(1), 0.0) np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -155,24 
+155,25 @@ def test_small_tree_with_no_mutations(): r""" Perfect binary tree with no mutations: Should give edges of length 0 """ - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]) - T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - T.nodes[0]["characters"] = '0000' - T.nodes[1]["characters"] = '0000' - T.nodes[2]["characters"] = '0000' - T.nodes[3]["characters"] = '0000' - T.nodes[4]["characters"] = '0000' - T.nodes[5]["characters"] = '0000' - T.nodes[6]["characters"] = '0000' - T = Tree(T) + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]) + tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + tree.nodes[0]["characters"] = '0000' + tree.nodes[1]["characters"] = '0000' + tree.nodes[2]["characters"] = '0000' + tree.nodes[3]["characters"] = '0000' + tree.nodes[4]["characters"] = '0000' + tree.nodes[5]["characters"] = '0000' + tree.nodes[6]["characters"] = '0000' + tree = Tree(tree) model = IIDExponentialBLE() - model.estimate_branch_lengths(T) + model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood - for edge in T.edges(): - np.testing.assert_almost_equal(T.get_edge_length(*edge), 0, decimal=3) + for edge in tree.edges(): + np.testing.assert_almost_equal( + tree.get_edge_length(*edge), 0, decimal=3) np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=3) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -185,28 +186,28 @@ def test_small_tree_with_one_mutation(): min_{r * t0} 2 * log(exp(-r * t0)) + log(1 - exp(-r * t0)) The solution is r * t0 = ln(1.5) ~ 0.405 """ - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - T.nodes[0]["characters"] = '0' - T.nodes[1]["characters"] = '0' - T.nodes[2]["characters"] = '0' - T.nodes[3]["characters"] = '0' - 
T.nodes[4]["characters"] = '0' - T.nodes[5]["characters"] = '0' - T.nodes[6]["characters"] = '1' - T = Tree(T) + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + tree.nodes[0]["characters"] = '0' + tree.nodes[1]["characters"] = '0' + tree.nodes[2]["characters"] = '0' + tree.nodes[3]["characters"] = '0' + tree.nodes[4]["characters"] = '0' + tree.nodes[5]["characters"] = '0' + tree.nodes[6]["characters"] = '1' + tree = Tree(tree) model = IIDExponentialBLE() - model.estimate_branch_lengths(T) + model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood - np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.405, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.0, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.0, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(1, 4), 0.0, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(2, 5), 0.405, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(2, 6), 0.405, decimal=3) + np.testing.assert_almost_equal(tree.get_edge_length(0, 1), 0.405, decimal=3) + np.testing.assert_almost_equal(tree.get_edge_length(0, 2), 0.0, decimal=3) + np.testing.assert_almost_equal(tree.get_edge_length(1, 3), 0.0, decimal=3) + np.testing.assert_almost_equal(tree.get_edge_length(1, 4), 0.0, decimal=3) + np.testing.assert_almost_equal(tree.get_edge_length(2, 5), 0.405, decimal=3) + np.testing.assert_almost_equal(tree.get_edge_length(2, 6), 0.405, decimal=3) np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -215,24 +216,24 @@ def test_small_tree_with_saturation(): Perfect binary tree with saturation. 
The edges which saturate should thus have length infinity (>15 for all practical purposes) """ - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - T.nodes[0]["characters"] = '0' - T.nodes[1]["characters"] = '0' - T.nodes[2]["characters"] = '1' - T.nodes[3]["characters"] = '1' - T.nodes[4]["characters"] = '1' - T.nodes[5]["characters"] = '1' - T.nodes[6]["characters"] = '1' - T = Tree(T) + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + tree.nodes[0]["characters"] = '0' + tree.nodes[1]["characters"] = '0' + tree.nodes[2]["characters"] = '1' + tree.nodes[3]["characters"] = '1' + tree.nodes[4]["characters"] = '1' + tree.nodes[5]["characters"] = '1' + tree.nodes[6]["characters"] = '1' + tree = Tree(tree) model = IIDExponentialBLE() - model.estimate_branch_lengths(T) + model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood - assert(T.get_edge_length(0, 2) > 15.0) - assert(T.get_edge_length(1, 3) > 15.0) - assert(T.get_edge_length(1, 4) > 15.0) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + assert(tree.get_edge_length(0, 2) > 15.0) + assert(tree.get_edge_length(1, 3) > 15.0) + assert(tree.get_edge_length(1, 4) > 15.0) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -242,28 +243,28 @@ def test_small_tree_regression(): never changes. 
""" # Perfect binary tree with normal amount of mutations on each edge - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - T.nodes[0]["characters"] = '000000000' - T.nodes[1]["characters"] = '100000000' - T.nodes[2]["characters"] = '000006000' - T.nodes[3]["characters"] = '120000000' - T.nodes[4]["characters"] = '103000000' - T.nodes[5]["characters"] = '000056700' - T.nodes[6]["characters"] = '000406089' - T = Tree(T) + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + tree.nodes[0]["characters"] = '000000000' + tree.nodes[1]["characters"] = '100000000' + tree.nodes[2]["characters"] = '000006000' + tree.nodes[3]["characters"] = '120000000' + tree.nodes[4]["characters"] = '103000000' + tree.nodes[5]["characters"] = '000056700' + tree.nodes[6]["characters"] = '000406089' + tree = Tree(tree) model = IIDExponentialBLE() - model.estimate_branch_lengths(T) + model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood - np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.203, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.082, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.175, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(1, 4), 0.175, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(2, 5), 0.295, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(2, 6), 0.295, decimal=3) + np.testing.assert_almost_equal(tree.get_edge_length(0, 1), 0.203, decimal=3) + np.testing.assert_almost_equal(tree.get_edge_length(0, 2), 0.082, decimal=3) + np.testing.assert_almost_equal(tree.get_edge_length(1, 3), 0.175, decimal=3) + np.testing.assert_almost_equal(tree.get_edge_length(1, 4), 0.175, decimal=3) + np.testing.assert_almost_equal(tree.get_edge_length(2, 5), 0.295, decimal=3) + 
np.testing.assert_almost_equal(tree.get_edge_length(2, 6), 0.295, decimal=3) np.testing.assert_almost_equal(log_likelihood, -22.689, decimal=3) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -271,29 +272,29 @@ def test_small_symmetric_tree(): r""" Symmetric tree should have equal length edges. """ - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - T.nodes[0]["characters"] = '000' - T.nodes[1]["characters"] = '100' - T.nodes[2]["characters"] = '100' - T.nodes[3]["characters"] = '110' - T.nodes[4]["characters"] = '110' - T.nodes[5]["characters"] = '110' - T.nodes[6]["characters"] = '110' - T = Tree(T) + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + tree.nodes[0]["characters"] = '000' + tree.nodes[1]["characters"] = '100' + tree.nodes[2]["characters"] = '100' + tree.nodes[3]["characters"] = '110' + tree.nodes[4]["characters"] = '110' + tree.nodes[5]["characters"] = '110' + tree.nodes[6]["characters"] = '110' + tree = Tree(tree) model = IIDExponentialBLE() - model.estimate_branch_lengths(T) + model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood np.testing.assert_almost_equal( - T.get_edge_length(0, 1), T.get_edge_length(0, 2)) + tree.get_edge_length(0, 1), tree.get_edge_length(0, 2)) np.testing.assert_almost_equal( - T.get_edge_length(1, 3), T.get_edge_length(1, 4)) + tree.get_edge_length(1, 3), tree.get_edge_length(1, 4)) np.testing.assert_almost_equal( - T.get_edge_length(1, 4), T.get_edge_length(2, 5)) + tree.get_edge_length(1, 4), tree.get_edge_length(2, 5)) np.testing.assert_almost_equal( - T.get_edge_length(2, 5), T.get_edge_length(2, 6)) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + tree.get_edge_length(2, 5), 
tree.get_edge_length(2, 6)) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -304,58 +305,58 @@ def test_small_tree_with_infinite_legs(): the branches for the leaves should be infinity (>15 for all practical purposes) """ - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - T.nodes[0]["characters"] = '00' - T.nodes[1]["characters"] = '10' - T.nodes[2]["characters"] = '10' - T.nodes[3]["characters"] = '11' - T.nodes[4]["characters"] = '11' - T.nodes[5]["characters"] = '11' - T.nodes[6]["characters"] = '11' - T = Tree(T) + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + tree.nodes[0]["characters"] = '00' + tree.nodes[1]["characters"] = '10' + tree.nodes[2]["characters"] = '10' + tree.nodes[3]["characters"] = '11' + tree.nodes[4]["characters"] = '11' + tree.nodes[5]["characters"] = '11' + tree.nodes[6]["characters"] = '11' + tree = Tree(tree) model = IIDExponentialBLE() - model.estimate_branch_lengths(T) + model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood - np.testing.assert_almost_equal(T.get_edge_length(0, 1), 0.693, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(0, 2), 0.693, decimal=3) - assert(T.get_edge_length(1, 3) > 15) - assert(T.get_edge_length(1, 4) > 15) - assert(T.get_edge_length(2, 5) > 15) - assert(T.get_edge_length(2, 6) > 15) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + np.testing.assert_almost_equal(tree.get_edge_length(0, 1), 0.693, decimal=3) + np.testing.assert_almost_equal(tree.get_edge_length(0, 2), 0.693, decimal=3) + assert(tree.get_edge_length(1, 3) > 15) + assert(tree.get_edge_length(1, 4) > 15) + assert(tree.get_edge_length(2, 5) > 15) + assert(tree.get_edge_length(2, 6) > 15) + log_likelihood_2 = 
IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) def test_on_simulated_data(): - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - T.nodes[0]["age"] = 1 - T.nodes[1]["age"] = 0.9 - T.nodes[2]["age"] = 0.1 - T.nodes[3]["age"] = 0 - T.nodes[4]["age"] = 0 - T.nodes[5]["age"] = 0 - T.nodes[6]["age"] = 0 + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + tree.nodes[0]["age"] = 1 + tree.nodes[1]["age"] = 0.9 + tree.nodes[2]["age"] = 0.1 + tree.nodes[3]["age"] = 0 + tree.nodes[4]["age"] = 0 + tree.nodes[5]["age"] = 0 + tree.nodes[6]["age"] = 0 np.random.seed(1) - T = Tree(T) + tree = Tree(tree) IIDExponentialLineageTracer(mutation_rate=1.0, num_characters=100)\ - .overlay_lineage_tracing_data(T) - for node in T.nodes(): - T.set_age(node, -1) + .overlay_lineage_tracing_data(tree) + for node in tree.nodes(): + tree.set_age(node, -1) model = IIDExponentialBLE() - model.estimate_branch_lengths(T) + model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood - assert(0.9 < T.get_age(0) < 1.1) - assert(0.8 < T.get_age(1) < 1.0) - assert(0.05 < T.get_age(2) < 0.15) - np.testing.assert_almost_equal(T.get_age(3), 0) - np.testing.assert_almost_equal(T.get_age(4), 0) - np.testing.assert_almost_equal(T.get_age(5), 0) - np.testing.assert_almost_equal(T.get_age(6), 0) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + assert(0.9 < tree.get_age(0) < 1.1) + assert(0.8 < tree.get_age(1) < 1.0) + assert(0.05 < tree.get_age(2) < 0.15) + np.testing.assert_almost_equal(tree.get_age(3), 0) + np.testing.assert_almost_equal(tree.get_age(4), 0) + np.testing.assert_almost_equal(tree.get_age(5), 0) + np.testing.assert_almost_equal(tree.get_age(6), 0) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) 
np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -364,41 +365,41 @@ def test_subtree_collapses_when_no_mutations(): A subtree with no mutations should collapse to 0. It reduces the problem to the same as in 'test_hand_solvable_problem_1' """ - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4]), - T.add_edges_from([(0, 1), (1, 2), (1, 3), (0, 4)]) - T.nodes[0]["characters"] = '0' - T.nodes[1]["characters"] = '1' - T.nodes[2]["characters"] = '1' - T.nodes[3]["characters"] = '1' - T.nodes[4]["characters"] = '0' - T = Tree(T) + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4]), + tree.add_edges_from([(0, 1), (1, 2), (1, 3), (0, 4)]) + tree.nodes[0]["characters"] = '0' + tree.nodes[1]["characters"] = '1' + tree.nodes[2]["characters"] = '1' + tree.nodes[3]["characters"] = '1' + tree.nodes[4]["characters"] = '0' + tree = Tree(tree) model = IIDExponentialBLE() - model.estimate_branch_lengths(T) + model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood np.testing.assert_almost_equal( - T.get_edge_length(0, 1), np.log(2), decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(1, 2), 0.0, decimal=3) - np.testing.assert_almost_equal(T.get_edge_length(1, 3), 0.0, decimal=3) + tree.get_edge_length(0, 1), np.log(2), decimal=3) + np.testing.assert_almost_equal(tree.get_edge_length(1, 2), 0.0, decimal=3) + np.testing.assert_almost_equal(tree.get_edge_length(1, 3), 0.0, decimal=3) np.testing.assert_almost_equal( - T.get_edge_length(0, 4), np.log(2), decimal=3) + tree.get_edge_length(0, 4), np.log(2), decimal=3) np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(T) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) def test_IIDExponentialBLEGridSearchCV(): - T = nx.DiGraph() - T.add_nodes_from([0, 1]), - T.add_edges_from([(0, 1)]) - T.nodes[0]["characters"] = '000' - 
T.nodes[1]["characters"] = '001' - T = Tree(T) + tree = nx.DiGraph() + tree.add_nodes_from([0, 1]), + tree.add_edges_from([(0, 1)]) + tree.nodes[0]["characters"] = '000' + tree.nodes[1]["characters"] = '001' + tree = Tree(tree) model = IIDExponentialBLEGridSearchCV( minimum_branch_lengths=(0, 1.0, 3.0), l2_regularizations=(0, ), verbose=True ) - model.estimate_branch_lengths(T) + model.estimate_branch_lengths(tree) minimum_branch_length = model.minimum_branch_length np.testing.assert_almost_equal(minimum_branch_length, 1.0) diff --git a/test/tools_tests/lineage_simulator_test.py b/test/tools_tests/lineage_simulator_test.py index cf7ae578..8e4fb77a 100644 --- a/test/tools_tests/lineage_simulator_test.py +++ b/test/tools_tests/lineage_simulator_test.py @@ -2,13 +2,14 @@ def test_PerfectBinaryTree(): - T = PerfectBinaryTree(generation_branch_lengths=[2, 3]).simulate_lineage() - newick = T.to_newick_tree_format(print_internal_nodes=True) + tree = PerfectBinaryTree(generation_branch_lengths=[2, 3])\ + .simulate_lineage() + newick = tree.to_newick_tree_format(print_internal_nodes=True) assert(newick == '((3:3,4:3)1:2,(5:3,6:3)2:2)0);') def test_PerfectBinaryTreeWithRootBranch(): - T = PerfectBinaryTreeWithRootBranch(generation_branch_lengths=[2, 3, 4])\ + tree = PerfectBinaryTreeWithRootBranch(generation_branch_lengths=[2, 3, 4])\ .simulate_lineage() - newick = T.to_newick_tree_format(print_internal_nodes=True) + newick = tree.to_newick_tree_format(print_internal_nodes=True) assert(newick == '(((4:4,5:4)2:3,(6:4,7:4)3:3)1:2)0);') diff --git a/test/tools_tests/lineage_tracing_simulator_test.py b/test/tools_tests/lineage_tracing_simulator_test.py index f526aa40..b9ffb04f 100644 --- a/test/tools_tests/lineage_tracing_simulator_test.py +++ b/test/tools_tests/lineage_tracing_simulator_test.py @@ -9,16 +9,16 @@ def test_smoke(): Just tests that lineage_tracing_simulator runs """ np.random.seed(1) - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - 
T.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - T.nodes[0]["age"] = 1 - T.nodes[1]["age"] = 0.9 - T.nodes[2]["age"] = 0.1 - T.nodes[3]["age"] = 0 - T.nodes[4]["age"] = 0 - T.nodes[5]["age"] = 0 - T.nodes[6]["age"] = 0 - T = Tree(T) + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + tree.nodes[0]["age"] = 1 + tree.nodes[1]["age"] = 0.9 + tree.nodes[2]["age"] = 0.1 + tree.nodes[3]["age"] = 0 + tree.nodes[4]["age"] = 0 + tree.nodes[5]["age"] = 0 + tree.nodes[6]["age"] = 0 + tree = Tree(tree) IIDExponentialLineageTracer(mutation_rate=1.0, num_characters=10)\ - .overlay_lineage_tracing_data(T) + .overlay_lineage_tracing_data(tree) diff --git a/test/tools_tests/tree_test.py b/test/tools_tests/tree_test.py index 03d43113..8101f093 100644 --- a/test/tools_tests/tree_test.py +++ b/test/tools_tests/tree_test.py @@ -9,18 +9,18 @@ def test_to_newick_tree_format(): The most basic newick example should give: (2:0.5,(4:0.3,5:0.4):0.2):0.1); """ - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5]) - T.add_edges_from([(0, 1), (1, 2), (1, 3), (3, 4), (3, 5)]) - T = Tree(T) - T.set_edge_lengths( + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4, 5]) + tree.add_edges_from([(0, 1), (1, 2), (1, 3), (3, 4), (3, 5)]) + tree = Tree(tree) + tree.set_edge_lengths( [(0, 1, 0.1), (1, 2, 0.5), (1, 3, 0.2), (3, 4, 0.3), (3, 5, 0.4)] ) - T.set_states( + tree.set_states( [(0, '0000000000'), (1, '1000000000'), (2, '1111000000'), @@ -28,29 +28,29 @@ def test_to_newick_tree_format(): (4, '1110000111'), (5, '1110111111')] ) - res = T.to_newick_tree_format(print_internal_nodes=False) + res = tree.to_newick_tree_format(print_internal_nodes=False) assert(res == "((2:0.5,(4:0.3,5:0.4):0.2):0.1));") - res = T.to_newick_tree_format( + res = tree.to_newick_tree_format( print_node_names=False, print_internal_nodes=True, append_state_to_node_name=True) assert(res == 
"((_1111000000:0.5,(_1110000111:0.3,_1110111111:0.4)" "_1110000000:0.2)_1000000000:0.1)_0000000000);") - res = T.to_newick_tree_format(print_internal_nodes=True) + res = tree.to_newick_tree_format(print_internal_nodes=True) assert(res == "((2:0.5,(4:0.3,5:0.4)3:0.2)1:0.1)0);") - res = T.to_newick_tree_format(print_node_names=False) + res = tree.to_newick_tree_format(print_node_names=False) assert(res == "((:0.5,(:0.3,:0.4):0.2):0.1));") - res = T.to_newick_tree_format( + res = tree.to_newick_tree_format( print_internal_nodes=True, add_N_to_node_id=True) assert(res == "((N2:0.5,(N4:0.3,N5:0.4)N3:0.2)N1:0.1)N0);") - res = T.to_newick_tree_format( + res = tree.to_newick_tree_format( print_internal_nodes=True, append_state_to_node_name=True, add_N_to_node_id=True) assert(res == "((N2_1111000000:0.5,(N4_1110000111:0.3,N5_1110111111:0.4)" "N3_1110000000:0.2)N1_1000000000:0.1)N0_0000000000);") - res = T.to_newick_tree_format( + res = tree.to_newick_tree_format( print_internal_nodes=True, print_pct_of_mutated_characters_along_edge=True, add_N_to_node_id=True) @@ -60,20 +60,20 @@ def test_to_newick_tree_format(): def test_reconstruct_ancestral_states(): - T = nx.DiGraph() - T.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) - T.add_edges_from([(10, 11), - (11, 13), - (13, 0), (13, 1), - (11, 14), - (14, 2), (14, 3), - (10, 12), - (12, 15), - (15, 4), (15, 5), - (12, 16), - (16, 6), (16, 7), (16, 8), (16, 9)]) - T = Tree(T) - T.set_states( + tree = nx.DiGraph() + tree.add_nodes_from(list(range(17))) + tree.add_edges_from([(10, 11), + (11, 13), + (13, 0), (13, 1), + (11, 14), + (14, 2), (14, 3), + (10, 12), + (12, 15), + (15, 4), (15, 5), + (12, 16), + (16, 6), (16, 7), (16, 8), (16, 9)]) + tree = Tree(tree) + tree.set_states( [(0, '01101110100'), (1, '01211111111'), (2, '01322121111'), @@ -86,25 +86,25 @@ def test_reconstruct_ancestral_states(): (9, '01093240010'), ] ) - T.reconstruct_ancestral_states() - assert(T.get_state(10) == '00000000000') - 
assert(T.get_state(11) == '01000100100') - assert(T.get_state(13) == '01001110100') - assert(T.get_state(14) == '01002120111') - assert(T.get_state(12) == '01000200010') - assert(T.get_state(15) == '01001230111') - assert(T.get_state(16) == '01003240010') + tree.reconstruct_ancestral_states() + assert(tree.get_state(10) == '00000000000') + assert(tree.get_state(11) == '01000100100') + assert(tree.get_state(13) == '01001110100') + assert(tree.get_state(14) == '01002120111') + assert(tree.get_state(12) == '01000200010') + assert(tree.get_state(15) == '01001230111') + assert(tree.get_state(16) == '01003240010') def test_reconstruct_ancestral_states_DREAM_challenge_tree_25(): - T = nx.DiGraph() - T.add_nodes_from(list(range(21))) - T.add_edges_from([(9, 8), (8, 10), (8, 7), (7, 11), (7, 12), (9, 6), - (6, 2), (2, 0), (0, 13), (0, 14), (2, 1), (1, 15), - (1, 16), (6, 5), (5, 3), (3, 17), (3, 18), (5, 4), - (4, 19), (4, 20)]) - T = Tree(T) - T.set_states( + tree = nx.DiGraph() + tree.add_nodes_from(list(range(21))) + tree.add_edges_from([(9, 8), (8, 10), (8, 7), (7, 11), (7, 12), (9, 6), + (6, 2), (2, 0), (0, 13), (0, 14), (2, 1), (1, 15), + (1, 16), (6, 5), (5, 3), (3, 17), (3, 18), (5, 4), + (4, 19), (4, 20)]) + tree = Tree(tree) + tree.set_states( [(10, '0022100000'), (11, '0022100000'), (12, '0022100000'), @@ -118,14 +118,14 @@ def test_reconstruct_ancestral_states_DREAM_challenge_tree_25(): (20, '0000210220'), ] ) - T.reconstruct_ancestral_states() - assert(T.get_state(7) == '0022100000') - assert(T.get_state(8) == '0022100000') - assert(T.get_state(0) == '2012000200') - assert(T.get_state(1) == '2012000100') - assert(T.get_state(2) == '2012000000') - assert(T.get_state(3) == '0001110220') - assert(T.get_state(4) == '0000210220') - assert(T.get_state(5) == '0000010220') - assert(T.get_state(6) == '0000000000') - assert(T.get_state(9) == '0000000000') + tree.reconstruct_ancestral_states() + assert(tree.get_state(7) == '0022100000') + assert(tree.get_state(8) == 
'0022100000') + assert(tree.get_state(0) == '2012000200') + assert(tree.get_state(1) == '2012000100') + assert(tree.get_state(2) == '2012000000') + assert(tree.get_state(3) == '0001110220') + assert(tree.get_state(4) == '0000210220') + assert(tree.get_state(5) == '0000010220') + assert(tree.get_state(6) == '0000000000') + assert(tree.get_state(9) == '0000000000') From 7fb4d388c108ad453e28429fa3f354fd87f546a8 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Mon, 28 Dec 2020 18:06:44 -0800 Subject: [PATCH 14/61] docs --- cassiopeia/tools/branch_length_estimator.py | 31 ++++++++++++++++--- cassiopeia/tools/lineage_simulator.py | 28 +++++++++++++++-- cassiopeia/tools/lineage_tracing_simulator.py | 19 +++++++++--- cassiopeia/tools/tree.py | 15 ++++++--- 4 files changed, 77 insertions(+), 16 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator.py b/cassiopeia/tools/branch_length_estimator.py index 501c14d3..aa6802e2 100644 --- a/cassiopeia/tools/branch_length_estimator.py +++ b/cassiopeia/tools/branch_length_estimator.py @@ -69,8 +69,8 @@ def __init__( def estimate_branch_lengths(self, tree: Tree) -> None: r""" - See base class. Only caveat is that this method raises if it fails to - solve the underlying optimization problem for any reason. + See base class. The only caveat is that this method raises if it fails + to solve the underlying optimization problem for any reason. Raises: cp.error.SolverError @@ -175,8 +175,17 @@ def log_likelihood(self, tree: Tree) -> float: class IIDExponentialBLEGridSearchCV(BranchLengthEstimator): r""" - Cross-validated version of IIDExponentialBLE which fits the hyperparameters - based on character-level held-out log-likelihood. + Like IIDExponentialBLE but with automatic tuning of hyperparameters. + + This class fits the hyperparameters of IIDExponentialBLE based on + character-level held-out log-likelihood. 
It leaves out one character at a + time, fitting the data on all the remaining characters. Thus, the number + of models trained by this class is #characters * grid size. + + Args: + minimum_branch_lengths: The grid of minimum_branch_length to use. + l2_regularizations: The grid of l2_regularization to use. + verbose: Verbosity level. """ def __init__( self, @@ -190,7 +199,11 @@ def __init__( def estimate_branch_lengths(self, tree: Tree) -> None: r""" - TODO + See base class. The only caveat is that this method raises if it fails + to solve the underlying optimization problem for any reason. + + Raises: + cp.error.SolverError """ # Extract parameters minimum_branch_lengths = self.minimum_branch_lengths @@ -234,6 +247,14 @@ def _cv_log_likelihood( minimum_branch_length: float, l2_regularization: float ) -> float: + r""" + Given the tree and the parameters of the model, returns the + cross-validated log-likelihood of the model. This is done by holding out + one character at a time, fitting the model on the remaining characters, + and evaluating the log-likelihood on the held-out character. As a + consequence, #character models are fit by this method. The mean held-out + log-likelihood over the #character folds is returned. + """ verbose = self.verbose if verbose: print(f"Cross-validating hyperparameters:" diff --git a/cassiopeia/tools/lineage_simulator.py b/cassiopeia/tools/lineage_simulator.py index 62e0a831..e0b30ae3 100644 --- a/cassiopeia/tools/lineage_simulator.py +++ b/cassiopeia/tools/lineage_simulator.py @@ -9,13 +9,27 @@ class LineageSimulator(abc.ABC): r""" Abstract base class for lineage simulators. + + A LineageSimulator implements the method simulate_lineage that generates a + lineage tree (i.e. a phylogeny, in the form of a Tree). """ @abc.abstractmethod def simulate_lineage(self) -> Tree: - r"""Simulates a ground truth lineage""" + r""" + Simulates a lineage tree, i.e. a Tree with branch lengths and age + specified for each node. 
Additional information such as cell fitness, + etc. might be specified by more complex simulators. + """ class PerfectBinaryTree(LineageSimulator): + r""" + Generates a perfect binary tree with given branch lengths at each depth. + + Args: + generation_branch_lengths: The branches at depth d in the tree will have + length generation_branch_lengths[d] + """ def __init__( self, generation_branch_lengths: List[float] @@ -24,7 +38,7 @@ def __init__( def simulate_lineage(self) -> Tree: r""" - See test for doc. + See base class. """ generation_branch_lengths = self.generation_branch_lengths n_generations = len(generation_branch_lengths) @@ -50,6 +64,14 @@ def simulate_lineage(self) -> Tree: class PerfectBinaryTreeWithRootBranch(LineageSimulator): + r""" + Generates a perfect binary tree *hanging from a branch*, with given branch + lengths at each depth. + + Args: + generation_branch_lengths: The branches at depth d in the tree will have + length generation_branch_lengths[d] + """ def __init__( self, generation_branch_lengths: List[float] @@ -58,7 +80,7 @@ def __init__( def simulate_lineage(self) -> Tree: r""" - See test for doc. + See base class. """ # generation_branch_lengths = self.generation_branch_lengths generation_branch_lengths = self.generation_branch_lengths diff --git a/cassiopeia/tools/lineage_tracing_simulator.py b/cassiopeia/tools/lineage_tracing_simulator.py index c5df5643..51b430b7 100644 --- a/cassiopeia/tools/lineage_tracing_simulator.py +++ b/cassiopeia/tools/lineage_tracing_simulator.py @@ -8,18 +8,29 @@ class LineageTracingSimulator(abc.ABC): r""" Abstract base class for all lineage tracing simulators. + + A LineageTracingSimulator implements a method overlay_lineage_tracing_data + which overlays lineage tracing data (i.e. character vectors) on top of the + tree. These are stored as the node's state. 
""" @abc.abstractmethod def overlay_lineage_tracing_data(self, tree: Tree) -> None: r""" Annotates the tree's nodes with lineage tracing character vectors. - Operates on the tree in-place. + These are stored as the node's state. (Operates on the tree in-place.) + + Args: + tree: The tree to overlay lineage tracing data on. """ -class IIDExponentialLineageTracer(): +class IIDExponentialLineageTracer(LineageTracingSimulator): r""" - Characters evolve IID over the lineage, with the same rate. + Characters evolve IID over the lineage, with the same given mutation rate. + + Args: + mutation_rate: The mutation rate of each character (same for all). + num_characters: The number of characters. """ def __init__( self, @@ -31,7 +42,7 @@ def __init__( def overlay_lineage_tracing_data(self, tree: Tree) -> None: r""" - Populates the phylogenetic tree T with lineage tracing characters. + See base class. """ num_characters = self.num_characters mutation_rate = self.mutation_rate diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index 58701787..eb66030b 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -5,8 +5,13 @@ class Tree(): r""" - networkx.Digraph wrapper to isolate networkx dependency and add custom tree - methods. + A phylogenetic tree for holding data from lineages and lineage tracing + experiments. + + (Currently implemented as a light wrapper over networkx.DiGraph) + + Args: + tree: The networkx.DiGraph from which to create the tree. """ def __init__(self, tree: nx.DiGraph): self.tree = tree @@ -93,11 +98,13 @@ def to_newick_tree_format( add_N_to_node_id: bool = False ) -> str: r""" - Converts tree into Newick tree format for viewing in e.g. ITOL. - Arguments: + Converts tree into Newick tree format. + + Args: print_internal_nodes: If True, prints the names of internal nodes too. 
print_pct_of_mutated_characters_along_edge: Self-explanatory + TODO """ leaves = self.leaves() From 5d8f0ce41d588372ba648c8fd1a4fe8646438417 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Mon, 28 Dec 2020 18:19:44 -0800 Subject: [PATCH 15/61] Run black --- cassiopeia/tools/__init__.py | 6 +- cassiopeia/tools/branch_length_estimator.py | 146 +++++++------ cassiopeia/tools/lineage_simulator.py | 30 +-- cassiopeia/tools/lineage_tracing_simulator.py | 26 +-- cassiopeia/tools/tree.py | 105 +++++---- .../branch_length_estimator_test.py | 189 ++++++++-------- test/tools_tests/lineage_simulator_test.py | 14 +- .../lineage_tracing_simulator_test.py | 5 +- test/tools_tests/tree_test.py | 201 +++++++++++------- 9 files changed, 405 insertions(+), 317 deletions(-) diff --git a/cassiopeia/tools/__init__.py b/cassiopeia/tools/__init__.py index 8076c272..f7d36a54 100644 --- a/cassiopeia/tools/__init__.py +++ b/cassiopeia/tools/__init__.py @@ -1,15 +1,15 @@ from .branch_length_estimator import ( BranchLengthEstimator, IIDExponentialBLE, - IIDExponentialBLEGridSearchCV + IIDExponentialBLEGridSearchCV, ) from .lineage_simulator import ( LineageSimulator, PerfectBinaryTree, - PerfectBinaryTreeWithRootBranch + PerfectBinaryTreeWithRootBranch, ) from .lineage_tracing_simulator import ( LineageTracingSimulator, - IIDExponentialLineageTracer + IIDExponentialLineageTracer, ) from .tree import Tree diff --git a/cassiopeia/tools/branch_length_estimator.py b/cassiopeia/tools/branch_length_estimator.py index aa6802e2..79ac4a7b 100644 --- a/cassiopeia/tools/branch_length_estimator.py +++ b/cassiopeia/tools/branch_length_estimator.py @@ -17,6 +17,7 @@ class BranchLengthEstimator(abc.ABC): possibly at the internal nodes too), estimates the branch lengths of the tree. 
""" + @abc.abstractmethod def estimate_branch_lengths(self, tree: Tree) -> None: r""" @@ -57,11 +58,12 @@ class IIDExponentialBLE(BranchLengthEstimator): log_loss: The log-loss of the training data under the estimated model. This is the log likhelihood plus the regularization terms. """ + def __init__( self, minimum_branch_length: float = 0, l2_regularization: float = 0, - verbose: bool = False + verbose: bool = False, ): self.minimum_branch_length = minimum_branch_length self.l2_regularization = l2_regularization @@ -84,21 +86,28 @@ def estimate_branch_lengths(self, tree: Tree) -> None: # tree = Tree(tree) # # # # # Create variables of the optimization problem # # # # # - r_X_t_variables = dict([(node_id, cp.Variable(name=f'r_X_t_{node_id}')) - for node_id in tree.nodes()]) + r_X_t_variables = dict( + [ + (node_id, cp.Variable(name=f"r_X_t_{node_id}")) + for node_id in tree.nodes() + ] + ) time_increases_constraints = [ r_X_t_variables[parent] >= r_X_t_variables[child] + minimum_branch_length for (parent, child) in tree.edges() ] - leaves_have_age_0_constraints =\ - [r_X_t_variables[leaf] == 0 for leaf in tree.leaves()] - non_negative_r_X_t_constraints =\ - [r_X_t >= 0 for r_X_t in r_X_t_variables.values()] - all_constraints =\ - time_increases_constraints + \ - leaves_have_age_0_constraints + \ - non_negative_r_X_t_constraints + leaves_have_age_0_constraints = [ + r_X_t_variables[leaf] == 0 for leaf in tree.leaves() + ] + non_negative_r_X_t_constraints = [ + r_X_t >= 0 for r_X_t in r_X_t_variables.values() + ] + all_constraints = ( + time_increases_constraints + + leaves_have_age_0_constraints + + non_negative_r_X_t_constraints + ) # # # # # Compute the log-likelihood # # # # # log_likelihood = 0 @@ -107,10 +116,11 @@ def estimate_branch_lengths(self, tree: Tree) -> None: # sufficient statistic. This makes the solver WAY faster! 
for (parent, child) in tree.edges(): edge_length = r_X_t_variables[parent] - r_X_t_variables[child] - zeros_parent = tree.get_state(parent).count('0') # TODO: '0'... - zeros_child = tree.get_state(child).count('0') # TODO: '0'... + # TODO: hardcoded '0' here... + zeros_parent = tree.get_state(parent).count("0") + zeros_child = tree.get_state(child).count("0") new_cuts_child = zeros_parent - zeros_child - assert(new_cuts_child >= 0) + assert new_cuts_child >= 0 # Add log-lik for characters that didn't get cut log_likelihood += zeros_child * (-edge_length) # Add log-lik for characters that got cut @@ -121,10 +131,12 @@ def estimate_branch_lengths(self, tree: Tree) -> None: l2_penalty = 0 for (parent, child) in tree.edges(): for child_of_child in tree.children(child): - edge_length_above =\ + edge_length_above = ( r_X_t_variables[parent] - r_X_t_variables[child] - edge_length_below =\ + ) + edge_length_below = ( r_X_t_variables[child] - r_X_t_variables[child_of_child] + ) l2_penalty += (edge_length_above - edge_length_below) ** 2 l2_penalty *= l2_regularization @@ -133,7 +145,7 @@ def estimate_branch_lengths(self, tree: Tree) -> None: obj = cp.Maximize(log_likelihood - l2_penalty) prob = cp.Problem(obj, all_constraints) - f_star = prob.solve(solver='ECOS', verbose=verbose) + f_star = prob.solve(solver="ECOS", verbose=verbose) # # # # # Populate the tree with the estimated branch lengths # # # # # @@ -141,12 +153,10 @@ def estimate_branch_lengths(self, tree: Tree) -> None: tree.set_age(node, age=r_X_t_variables[node].value) for (parent, child) in tree.edges(): - new_edge_length =\ + new_edge_length = ( r_X_t_variables[parent].value - r_X_t_variables[child].value - tree.set_edge_length( - parent, - child, - length=new_edge_length) + ) + tree.set_edge_length(parent, child, length=new_edge_length) self.log_likelihood = log_likelihood.value self.log_loss = f_star @@ -160,10 +170,10 @@ def log_likelihood(self, tree: Tree) -> float: for (parent, child) in tree.edges(): 
edge_length = tree.get_age(parent) - tree.get_age(child) # TODO: hardcoded '0' here... - zeros_parent = tree.get_state(parent).count('0') - zeros_child = tree.get_state(child).count('0') + zeros_parent = tree.get_state(parent).count("0") + zeros_child = tree.get_state(child).count("0") new_cuts_child = zeros_parent - zeros_child - assert(new_cuts_child >= 0) + assert new_cuts_child >= 0 # Add log-lik for characters that didn't get cut log_likelihood += zeros_child * (-edge_length) # Add log-lik for characters that got cut @@ -187,11 +197,12 @@ class IIDExponentialBLEGridSearchCV(BranchLengthEstimator): l2_regularizations: The grid of l2_regularization to use. verbose: Verbosity level. """ + def __init__( self, - minimum_branch_lengths: Tuple[float] = (0, ), - l2_regularizations: Tuple[float] = (0, ), - verbose: bool = False + minimum_branch_lengths: Tuple[float] = (0,), + l2_regularizations: Tuple[float] = (0,), + verbose: bool = False, ): self.minimum_branch_lengths = minimum_branch_lengths self.l2_regularizations = l2_regularizations @@ -216,24 +227,30 @@ def estimate_branch_lengths(self, tree: Tree) -> None: cv_log_likelihood = self._cv_log_likelihood( tree=tree, minimum_branch_length=minimum_branch_length, - l2_regularization=l2_regularization) + l2_regularization=l2_regularization, + ) held_out_log_likelihoods.append( - (cv_log_likelihood, - [minimum_branch_length, - l2_regularization]) + ( + cv_log_likelihood, + [minimum_branch_length, l2_regularization], + ) ) # Refit model on full dataset with the best hyperparameters held_out_log_likelihoods.sort(reverse=True) - best_minimum_branch_length, best_l2_regularization =\ - held_out_log_likelihoods[0][1] + ( + best_minimum_branch_length, + best_l2_regularization, + ) = held_out_log_likelihoods[0][1] if verbose: - print(f"Refitting full model with:\n" - f"minimum_branch_length={best_minimum_branch_length}\n" - f"l2_regularization={best_l2_regularization}") + print( + f"Refitting full model with:\n" + 
f"minimum_branch_length={best_minimum_branch_length}\n" + f"l2_regularization={best_l2_regularization}" + ) final_model = IIDExponentialBLE( minimum_branch_length=best_minimum_branch_length, - l2_regularization=best_l2_regularization + l2_regularization=best_l2_regularization, ) final_model.estimate_branch_lengths(tree) self.minimum_branch_length = best_minimum_branch_length @@ -242,10 +259,7 @@ def estimate_branch_lengths(self, tree: Tree) -> None: self.log_loss = final_model.log_loss def _cv_log_likelihood( - self, - tree: Tree, - minimum_branch_length: float, - l2_regularization: float + self, tree: Tree, minimum_branch_length: float, l2_regularization: float ) -> float: r""" Given the tree and the parameters of the model, returns the @@ -257,39 +271,41 @@ def _cv_log_likelihood( """ verbose = self.verbose if verbose: - print(f"Cross-validating hyperparameters:" - f"\nminimum_branch_length={minimum_branch_length}" - f"\nl2_regularizations={l2_regularization}") + print( + f"Cross-validating hyperparameters:" + f"\nminimum_branch_length={minimum_branch_length}" + f"\nl2_regularizations={l2_regularization}" + ) n_characters = tree.num_characters() log_likelihood_folds = np.zeros(shape=(n_characters)) for held_out_character_idx in range(n_characters): - tree_train, tree_valid =\ - self._cv_split( - tree=tree, - held_out_character_idx=held_out_character_idx - ) + tree_train, tree_valid = self._cv_split( + tree=tree, held_out_character_idx=held_out_character_idx + ) try: IIDExponentialBLE( minimum_branch_length=minimum_branch_length, - l2_regularization=l2_regularization + l2_regularization=l2_regularization, ).estimate_branch_lengths(tree_train) tree_valid.copy_branch_lengths(tree_other=tree_train) - held_out_log_likelihood =\ - IIDExponentialBLE.log_likelihood(tree_valid) + held_out_log_likelihood = IIDExponentialBLE.log_likelihood( + tree_valid + ) except cp.error.SolverError: held_out_log_likelihood = -np.inf - log_likelihood_folds[held_out_character_idx] =\ - 
held_out_log_likelihood + log_likelihood_folds[ + held_out_character_idx + ] = held_out_log_likelihood if verbose: print(f"log_likelihood_folds = {log_likelihood_folds}") - print(f"mean log_likelihood_folds = " - f"{np.mean(log_likelihood_folds)}") + print( + f"mean log_likelihood_folds = " + f"{np.mean(log_likelihood_folds)}" + ) return np.mean(log_likelihood_folds) def _cv_split( - self, - tree: Tree, - held_out_character_idx: int + self, tree: Tree, held_out_character_idx: int ) -> Tuple[Tree, Tree]: r""" Creates a training and a cross validation tree by hiding the @@ -299,11 +315,11 @@ def _cv_split( tree_valid = copy.deepcopy(tree) for node in tree.nodes(): state = tree_train.get_state(node) - train_state =\ - state[:held_out_character_idx]\ - + state[(held_out_character_idx + 1):] - valid_data =\ - state[held_out_character_idx] + train_state = ( + state[:held_out_character_idx] + + state[(held_out_character_idx + 1) :] + ) + valid_data = state[held_out_character_idx] tree_train.set_state(node, train_state) tree_valid.set_state(node, valid_data) return tree_train, tree_valid diff --git a/cassiopeia/tools/lineage_simulator.py b/cassiopeia/tools/lineage_simulator.py index e0b30ae3..0aa6e762 100644 --- a/cassiopeia/tools/lineage_simulator.py +++ b/cassiopeia/tools/lineage_simulator.py @@ -13,6 +13,7 @@ class LineageSimulator(abc.ABC): A LineageSimulator implements the method simulate_lineage that generates a lineage tree (i.e. a phylogeny, in the form of a Tree). 
""" + @abc.abstractmethod def simulate_lineage(self) -> Tree: r""" @@ -30,10 +31,8 @@ class PerfectBinaryTree(LineageSimulator): generation_branch_lengths: The branches at depth d in the tree will have length generation_branch_lengths[d] """ - def __init__( - self, - generation_branch_lengths: List[float] - ): + + def __init__(self, generation_branch_lengths: List[float]): self.generation_branch_lengths = generation_branch_lengths[:] def simulate_lineage(self) -> Tree: @@ -44,8 +43,10 @@ def simulate_lineage(self) -> Tree: n_generations = len(generation_branch_lengths) tree = nx.DiGraph() tree.add_nodes_from(range(2 ** (n_generations + 1) - 1)) - edges = [(int((child - 1) / 2), child) - for child in range(1, 2 ** (n_generations + 1) - 1)] + edges = [ + (int((child - 1) / 2), child) + for child in range(1, 2 ** (n_generations + 1) - 1) + ] node_generation = [] for i in range(n_generations + 1): node_generation += [i] * 2 ** i @@ -58,8 +59,9 @@ def simulate_lineage(self) -> Tree: for child in range(1, 2 ** (n_generations + 1) - 1): child_generation = node_generation[child] branch_length = generation_branch_lengths[child_generation - 1] - tree.nodes[child]["age"] =\ + tree.nodes[child]["age"] = ( tree.nodes[int((child - 1) / 2)]["age"] - branch_length + ) return Tree(tree) @@ -72,10 +74,8 @@ class PerfectBinaryTreeWithRootBranch(LineageSimulator): generation_branch_lengths: The branches at depth d in the tree will have length generation_branch_lengths[d] """ - def __init__( - self, - generation_branch_lengths: List[float] - ): + + def __init__(self, generation_branch_lengths: List[float]): self.generation_branch_lengths = generation_branch_lengths def simulate_lineage(self) -> Tree: @@ -87,8 +87,9 @@ def simulate_lineage(self) -> Tree: n_generations = len(generation_branch_lengths) tree = nx.DiGraph() tree.add_nodes_from(range(2 ** n_generations)) - edges = [(int(child / 2), child) - for child in range(1, 2 ** n_generations)] + edges = [ + (int(child / 2), child) for 
child in range(1, 2 ** n_generations) + ] tree.add_edges_from(edges) node_generation = [0] for i in range(n_generations): @@ -101,6 +102,7 @@ def simulate_lineage(self) -> Tree: for child in range(1, 2 ** n_generations): child_generation = node_generation[child] branch_length = generation_branch_lengths[child_generation - 1] - tree.nodes[child]["age"] =\ + tree.nodes[child]["age"] = ( tree.nodes[int(child / 2)]["age"] - branch_length + ) return Tree(tree) diff --git a/cassiopeia/tools/lineage_tracing_simulator.py b/cassiopeia/tools/lineage_tracing_simulator.py index 51b430b7..fb01055e 100644 --- a/cassiopeia/tools/lineage_tracing_simulator.py +++ b/cassiopeia/tools/lineage_tracing_simulator.py @@ -13,6 +13,7 @@ class LineageTracingSimulator(abc.ABC): which overlays lineage tracing data (i.e. character vectors) on top of the tree. These are stored as the node's state. """ + @abc.abstractmethod def overlay_lineage_tracing_data(self, tree: Tree) -> None: r""" @@ -32,11 +33,8 @@ class IIDExponentialLineageTracer(LineageTracingSimulator): mutation_rate: The mutation rate of each character (same for all). num_characters: The number of characters. """ - def __init__( - self, - mutation_rate: float, - num_characters: float - ): + + def __init__(self, mutation_rate: float, num_characters: float): self.mutation_rate = mutation_rate self.num_characters = num_characters @@ -51,28 +49,30 @@ def dfs(node: int, tree: Tree): node_state = tree.get_state(node) for child in tree.children(node): # Compute the state of the child - child_state = '' + child_state = "" edge_length = tree.get_age(node) - tree.get_age(child) # print(f"{node} -> {child}, length {edge_length}") - assert(edge_length >= 0) + assert edge_length >= 0 for i in range(num_characters): # See what happens to character i - if node_state[i] != '0': + if node_state[i] != "0": # The character has already mutated; there in nothing # to do child_state += node_state[i] continue else: # Determine if the character will mutate. 
- mutates =\ - np.random.exponential(1.0 / mutation_rate)\ + mutates = ( + np.random.exponential(1.0 / mutation_rate) < edge_length + ) if mutates: - child_state += '1' + child_state += "1" else: - child_state += '0' + child_state += "0" tree.set_state(child, child_state) dfs(child, tree) + root = tree.root() - tree.set_state(root, '0' * num_characters) + tree.set_state(root, "0" * num_characters) dfs(root, tree) diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index eb66030b..ef5ab2b9 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -3,7 +3,7 @@ import networkx as nx -class Tree(): +class Tree: r""" A phylogenetic tree for holding data from lineages and lineage tracing experiments. @@ -13,6 +13,7 @@ class Tree(): Args: tree: The networkx.DiGraph from which to create the tree. """ + def __init__(self, tree: nx.DiGraph): self.tree = tree @@ -23,9 +24,11 @@ def root(self) -> int: def leaves(self) -> List[int]: tree = self.tree - leaves = [n for n in tree - if tree.out_degree(n) == 0 - and tree.in_degree(n) == 1] + leaves = [ + n + for n in tree + if tree.out_degree(n) == 0 and tree.in_degree(n) == 1 + ] return leaves def internal_nodes(self) -> List[int]: @@ -80,8 +83,8 @@ def set_edge_length(self, parent: int, child: int, length: float) -> None: tree.edges[parent, child]["length"] = length def set_edge_lengths( - self, - parent_child_and_length_list: List[Tuple[int, int, float]]) -> None: + self, parent_child_and_length_list: List[Tuple[int, int, float]] + ) -> None: for (parent, child, length) in parent_child_and_length_list: self.set_edge_length(parent, child, length) @@ -95,7 +98,7 @@ def to_newick_tree_format( print_internal_nodes: bool = False, append_state_to_node_name: bool = False, print_pct_of_mutated_characters_along_edge: bool = False, - add_N_to_node_id: bool = False + add_N_to_node_id: bool = False, ) -> str: r""" Converts tree into Newick tree format. 
@@ -109,11 +112,13 @@ def to_newick_tree_format( leaves = self.leaves() def format_node(v: int): - node_id_prefix = '' if not add_N_to_node_id else 'N' - node_id = '' if not print_node_names else str(v) - node_suffix =\ - '' if not append_state_to_node_name\ - else '_' + str(self.get_state(v)) + node_id_prefix = "" if not add_N_to_node_id else "N" + node_id = "" if not print_node_names else str(v) + node_suffix = ( + "" + if not append_state_to_node_name + else "_" + str(self.get_state(v)) + ) return node_id_prefix + node_id + node_suffix def subtree_newick_representation(v: int) -> str: @@ -125,38 +130,42 @@ def subtree_newick_representation(v: int) -> str: if child in leaves: subtree_newick = subtree_newick_representation(child) else: - subtree_newick =\ - '(' + subtree_newick_representation(child) + ')' + subtree_newick = ( + "(" + subtree_newick_representation(child) + ")" + ) if print_internal_nodes: subtree_newick += format_node(child) # Add edge length - subtree_newick = subtree_newick + ':' + str(edge_length) + subtree_newick = subtree_newick + ":" + str(edge_length) if print_pct_of_mutated_characters_along_edge: # Also add number of mutations - number_of_unmutated_characters_in_parent =\ - self.get_state(v).count('0') - number_of_mutations_along_edge =\ - self.get_state(v).count('0')\ - - self.get_state(child).count('0') - pct_of_mutated_characters_along_edge =\ - number_of_mutations_along_edge /\ - (number_of_unmutated_characters_in_parent + 1e-100) - subtree_newick = subtree_newick +\ - "[&&NHX:muts="\ + number_of_unmutated_characters_in_parent = self.get_state( + v + ).count("0") + number_of_mutations_along_edge = self.get_state(v).count( + "0" + ) - self.get_state(child).count("0") + pct_of_mutated_characters_along_edge = ( + number_of_mutations_along_edge + / (number_of_unmutated_characters_in_parent + 1e-100) + ) + subtree_newick = ( + subtree_newick + "[&&NHX:muts=" f"{self._fmt(pct_of_mutated_characters_along_edge)}]" + ) 
subtrees_newick.append(subtree_newick) - newick = ','.join(subtrees_newick) + newick = ",".join(subtrees_newick) return newick root = self.root() - res = '(' + subtree_newick_representation(root) + ')' + res = "(" + subtree_newick_representation(root) + ")" if print_internal_nodes: res += format_node(root) - res += ');' + res += ");" return res def _fmt(self, x: float): - return '%.2f' % x + return "%.2f" % x def reconstruct_ancestral_states(self): r""" @@ -173,42 +182,46 @@ def dfs(v: int) -> None: dfs(child) children_states = [self.get_state(child) for child in children] n_characters = len(children_states[0]) - state = '' + state = "" for character_id in range(n_characters): - states_for_this_character =\ - set([children_states[i][character_id] - for i in range(n_children)]) + states_for_this_character = set( + [ + children_states[i][character_id] + for i in range(n_children) + ] + ) if len(states_for_this_character) == 1: state += states_for_this_character.pop() else: - state += '0' + state += "0" self.set_state(v, state) if v == root: # Reset state to all zeros! 
- self.set_state(v, '0' * n_characters) + self.set_state(v, "0" * n_characters) + dfs(root) def copy_branch_lengths(self, tree_other): r""" Copies the branch lengths of tree_other onto self """ - assert(self.nodes() == tree_other.nodes()) - assert(self.edges() == tree_other.edges()) + assert self.nodes() == tree_other.nodes() + assert self.edges() == tree_other.edges() for node in self.nodes(): new_age = tree_other.get_age(node) self.set_age(node, age=new_age) for (parent, child) in self.edges(): - new_edge_length =\ - tree_other.get_age(parent) - tree_other.get_age(child) - self.set_edge_length( - parent, - child, - length=new_edge_length) + new_edge_length = tree_other.get_age(parent) - tree_other.get_age( + child + ) + self.set_edge_length(parent, child, length=new_edge_length) def print_edges(self): for (parent, child) in self.edges(): - print(f"{parent}[{self.get_state(parent)}] -> " - f"{child}[{self.get_state(child)}]: " - f"{self.get_edge_length(parent, child)}") + print( + f"{parent}[{self.get_state(parent)}] -> " + f"{child}[{self.get_state(child)}]: " + f"{self.get_edge_length(parent, child)}" + ) diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index b56d754e..b97c6a64 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -1,8 +1,12 @@ import networkx as nx import numpy as np -from cassiopeia.tools import (IIDExponentialBLE, IIDExponentialBLEGridSearchCV, - IIDExponentialLineageTracer, Tree) +from cassiopeia.tools import ( + IIDExponentialBLE, + IIDExponentialBLEGridSearchCV, + IIDExponentialLineageTracer, + Tree, +) def test_no_mutations(): @@ -19,8 +23,8 @@ def test_no_mutations(): tree = nx.DiGraph() tree.add_node(0), tree.add_node(1) tree.add_edge(0, 1) - tree.nodes[0]["characters"] = '0' - tree.nodes[1]["characters"] = '0' + tree.nodes[0]["characters"] = "0" + tree.nodes[1]["characters"] = "0" tree = Tree(tree) model = 
IIDExponentialBLE() model.estimate_branch_lengths(tree) @@ -47,14 +51,14 @@ def test_saturation(): tree = nx.DiGraph() tree.add_nodes_from([0, 1]) tree.add_edge(0, 1) - tree.nodes[0]["characters"] = '0' - tree.nodes[1]["characters"] = '1' + tree.nodes[0]["characters"] = "0" + tree.nodes[1]["characters"] = "1" tree = Tree(tree) model = IIDExponentialBLE() model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood - assert(tree.get_edge_length(0, 1) > 15.0) - assert(tree.get_age(0) > 15.0) + assert tree.get_edge_length(0, 1) > 15.0 + assert tree.get_age(0) > 15.0 np.testing.assert_almost_equal(tree.get_age(1), 0.0) np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=5) log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) @@ -76,14 +80,15 @@ def test_hand_solvable_problem_1(): tree = nx.DiGraph() tree.add_nodes_from([0, 1]) tree.add_edge(0, 1) - tree.nodes[0]["characters"] = '00' - tree.nodes[1]["characters"] = '01' + tree.nodes[0]["characters"] = "00" + tree.nodes[1]["characters"] = "01" tree = Tree(tree) model = IIDExponentialBLE() model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood np.testing.assert_almost_equal( - tree.get_edge_length(0, 1), np.log(2), decimal=3) + tree.get_edge_length(0, 1), np.log(2), decimal=3 + ) np.testing.assert_almost_equal(tree.get_age(0), np.log(2), decimal=3) np.testing.assert_almost_equal(tree.get_age(1), 0.0) np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) @@ -106,14 +111,15 @@ def test_hand_solvable_problem_2(): tree = nx.DiGraph() tree.add_nodes_from([0, 1]) tree.add_edge(0, 1) - tree.nodes[0]["characters"] = '000' - tree.nodes[1]["characters"] = '011' + tree.nodes[0]["characters"] = "000" + tree.nodes[1]["characters"] = "011" tree = Tree(tree) model = IIDExponentialBLE() model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood np.testing.assert_almost_equal( - tree.get_edge_length(0, 1), np.log(3), decimal=3) + tree.get_edge_length(0, 1), np.log(3), 
decimal=3 + ) np.testing.assert_almost_equal(tree.get_age(0), np.log(3), decimal=3) np.testing.assert_almost_equal(tree.get_age(1), 0.0) np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) @@ -136,14 +142,15 @@ def test_hand_solvable_problem_3(): tree = nx.DiGraph() tree.add_nodes_from([0, 1]) tree.add_edge(0, 1) - tree.nodes[0]["characters"] = '000' - tree.nodes[1]["characters"] = '001' + tree.nodes[0]["characters"] = "000" + tree.nodes[1]["characters"] = "001" tree = Tree(tree) model = IIDExponentialBLE() model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood np.testing.assert_almost_equal( - tree.get_edge_length(0, 1), np.log(1.5), decimal=3) + tree.get_edge_length(0, 1), np.log(1.5), decimal=3 + ) np.testing.assert_almost_equal(tree.get_age(0), np.log(1.5), decimal=3) np.testing.assert_almost_equal(tree.get_age(1), 0.0) np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) @@ -158,20 +165,21 @@ def test_small_tree_with_no_mutations(): tree = nx.DiGraph() tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]) tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[0]["characters"] = '0000' - tree.nodes[1]["characters"] = '0000' - tree.nodes[2]["characters"] = '0000' - tree.nodes[3]["characters"] = '0000' - tree.nodes[4]["characters"] = '0000' - tree.nodes[5]["characters"] = '0000' - tree.nodes[6]["characters"] = '0000' + tree.nodes[0]["characters"] = "0000" + tree.nodes[1]["characters"] = "0000" + tree.nodes[2]["characters"] = "0000" + tree.nodes[3]["characters"] = "0000" + tree.nodes[4]["characters"] = "0000" + tree.nodes[5]["characters"] = "0000" + tree.nodes[6]["characters"] = "0000" tree = Tree(tree) model = IIDExponentialBLE() model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood for edge in tree.edges(): np.testing.assert_almost_equal( - tree.get_edge_length(*edge), 0, decimal=3) + tree.get_edge_length(*edge), 0, decimal=3 + ) np.testing.assert_almost_equal(log_likelihood, 0.0, 
decimal=3) log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -189,13 +197,13 @@ def test_small_tree_with_one_mutation(): tree = nx.DiGraph() tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[0]["characters"] = '0' - tree.nodes[1]["characters"] = '0' - tree.nodes[2]["characters"] = '0' - tree.nodes[3]["characters"] = '0' - tree.nodes[4]["characters"] = '0' - tree.nodes[5]["characters"] = '0' - tree.nodes[6]["characters"] = '1' + tree.nodes[0]["characters"] = "0" + tree.nodes[1]["characters"] = "0" + tree.nodes[2]["characters"] = "0" + tree.nodes[3]["characters"] = "0" + tree.nodes[4]["characters"] = "0" + tree.nodes[5]["characters"] = "0" + tree.nodes[6]["characters"] = "1" tree = Tree(tree) model = IIDExponentialBLE() model.estimate_branch_lengths(tree) @@ -219,20 +227,20 @@ def test_small_tree_with_saturation(): tree = nx.DiGraph() tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[0]["characters"] = '0' - tree.nodes[1]["characters"] = '0' - tree.nodes[2]["characters"] = '1' - tree.nodes[3]["characters"] = '1' - tree.nodes[4]["characters"] = '1' - tree.nodes[5]["characters"] = '1' - tree.nodes[6]["characters"] = '1' + tree.nodes[0]["characters"] = "0" + tree.nodes[1]["characters"] = "0" + tree.nodes[2]["characters"] = "1" + tree.nodes[3]["characters"] = "1" + tree.nodes[4]["characters"] = "1" + tree.nodes[5]["characters"] = "1" + tree.nodes[6]["characters"] = "1" tree = Tree(tree) model = IIDExponentialBLE() model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood - assert(tree.get_edge_length(0, 2) > 15.0) - assert(tree.get_edge_length(1, 3) > 15.0) - assert(tree.get_edge_length(1, 4) > 15.0) + assert tree.get_edge_length(0, 2) > 15.0 + assert tree.get_edge_length(1, 3) > 15.0 + assert tree.get_edge_length(1, 
4) > 15.0 log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -246,13 +254,13 @@ def test_small_tree_regression(): tree = nx.DiGraph() tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[0]["characters"] = '000000000' - tree.nodes[1]["characters"] = '100000000' - tree.nodes[2]["characters"] = '000006000' - tree.nodes[3]["characters"] = '120000000' - tree.nodes[4]["characters"] = '103000000' - tree.nodes[5]["characters"] = '000056700' - tree.nodes[6]["characters"] = '000406089' + tree.nodes[0]["characters"] = "000000000" + tree.nodes[1]["characters"] = "100000000" + tree.nodes[2]["characters"] = "000006000" + tree.nodes[3]["characters"] = "120000000" + tree.nodes[4]["characters"] = "103000000" + tree.nodes[5]["characters"] = "000056700" + tree.nodes[6]["characters"] = "000406089" tree = Tree(tree) model = IIDExponentialBLE() model.estimate_branch_lengths(tree) @@ -275,25 +283,29 @@ def test_small_symmetric_tree(): tree = nx.DiGraph() tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[0]["characters"] = '000' - tree.nodes[1]["characters"] = '100' - tree.nodes[2]["characters"] = '100' - tree.nodes[3]["characters"] = '110' - tree.nodes[4]["characters"] = '110' - tree.nodes[5]["characters"] = '110' - tree.nodes[6]["characters"] = '110' + tree.nodes[0]["characters"] = "000" + tree.nodes[1]["characters"] = "100" + tree.nodes[2]["characters"] = "100" + tree.nodes[3]["characters"] = "110" + tree.nodes[4]["characters"] = "110" + tree.nodes[5]["characters"] = "110" + tree.nodes[6]["characters"] = "110" tree = Tree(tree) model = IIDExponentialBLE() model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood np.testing.assert_almost_equal( - tree.get_edge_length(0, 1), tree.get_edge_length(0, 2)) + tree.get_edge_length(0, 1), 
tree.get_edge_length(0, 2) + ) np.testing.assert_almost_equal( - tree.get_edge_length(1, 3), tree.get_edge_length(1, 4)) + tree.get_edge_length(1, 3), tree.get_edge_length(1, 4) + ) np.testing.assert_almost_equal( - tree.get_edge_length(1, 4), tree.get_edge_length(2, 5)) + tree.get_edge_length(1, 4), tree.get_edge_length(2, 5) + ) np.testing.assert_almost_equal( - tree.get_edge_length(2, 5), tree.get_edge_length(2, 6)) + tree.get_edge_length(2, 5), tree.get_edge_length(2, 6) + ) log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -308,23 +320,23 @@ def test_small_tree_with_infinite_legs(): tree = nx.DiGraph() tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[0]["characters"] = '00' - tree.nodes[1]["characters"] = '10' - tree.nodes[2]["characters"] = '10' - tree.nodes[3]["characters"] = '11' - tree.nodes[4]["characters"] = '11' - tree.nodes[5]["characters"] = '11' - tree.nodes[6]["characters"] = '11' + tree.nodes[0]["characters"] = "00" + tree.nodes[1]["characters"] = "10" + tree.nodes[2]["characters"] = "10" + tree.nodes[3]["characters"] = "11" + tree.nodes[4]["characters"] = "11" + tree.nodes[5]["characters"] = "11" + tree.nodes[6]["characters"] = "11" tree = Tree(tree) model = IIDExponentialBLE() model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood np.testing.assert_almost_equal(tree.get_edge_length(0, 1), 0.693, decimal=3) np.testing.assert_almost_equal(tree.get_edge_length(0, 2), 0.693, decimal=3) - assert(tree.get_edge_length(1, 3) > 15) - assert(tree.get_edge_length(1, 4) > 15) - assert(tree.get_edge_length(2, 5) > 15) - assert(tree.get_edge_length(2, 6) > 15) + assert tree.get_edge_length(1, 3) > 15 + assert tree.get_edge_length(1, 4) > 15 + assert tree.get_edge_length(2, 5) > 15 + assert tree.get_edge_length(2, 6) > 15 log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) 
np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -342,16 +354,17 @@ def test_on_simulated_data(): tree.nodes[6]["age"] = 0 np.random.seed(1) tree = Tree(tree) - IIDExponentialLineageTracer(mutation_rate=1.0, num_characters=100)\ - .overlay_lineage_tracing_data(tree) + IIDExponentialLineageTracer( + mutation_rate=1.0, num_characters=100 + ).overlay_lineage_tracing_data(tree) for node in tree.nodes(): tree.set_age(node, -1) model = IIDExponentialBLE() model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood - assert(0.9 < tree.get_age(0) < 1.1) - assert(0.8 < tree.get_age(1) < 1.0) - assert(0.05 < tree.get_age(2) < 0.15) + assert 0.9 < tree.get_age(0) < 1.1 + assert 0.8 < tree.get_age(1) < 1.0 + assert 0.05 < tree.get_age(2) < 0.15 np.testing.assert_almost_equal(tree.get_age(3), 0) np.testing.assert_almost_equal(tree.get_age(4), 0) np.testing.assert_almost_equal(tree.get_age(5), 0) @@ -368,21 +381,23 @@ def test_subtree_collapses_when_no_mutations(): tree = nx.DiGraph() tree.add_nodes_from([0, 1, 2, 3, 4]), tree.add_edges_from([(0, 1), (1, 2), (1, 3), (0, 4)]) - tree.nodes[0]["characters"] = '0' - tree.nodes[1]["characters"] = '1' - tree.nodes[2]["characters"] = '1' - tree.nodes[3]["characters"] = '1' - tree.nodes[4]["characters"] = '0' + tree.nodes[0]["characters"] = "0" + tree.nodes[1]["characters"] = "1" + tree.nodes[2]["characters"] = "1" + tree.nodes[3]["characters"] = "1" + tree.nodes[4]["characters"] = "0" tree = Tree(tree) model = IIDExponentialBLE() model.estimate_branch_lengths(tree) log_likelihood = model.log_likelihood np.testing.assert_almost_equal( - tree.get_edge_length(0, 1), np.log(2), decimal=3) + tree.get_edge_length(0, 1), np.log(2), decimal=3 + ) np.testing.assert_almost_equal(tree.get_edge_length(1, 2), 0.0, decimal=3) np.testing.assert_almost_equal(tree.get_edge_length(1, 3), 0.0, decimal=3) np.testing.assert_almost_equal( - tree.get_edge_length(0, 4), np.log(2), decimal=3) + 
tree.get_edge_length(0, 4), np.log(2), decimal=3 + ) np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) @@ -392,13 +407,13 @@ def test_IIDExponentialBLEGridSearchCV(): tree = nx.DiGraph() tree.add_nodes_from([0, 1]), tree.add_edges_from([(0, 1)]) - tree.nodes[0]["characters"] = '000' - tree.nodes[1]["characters"] = '001' + tree.nodes[0]["characters"] = "000" + tree.nodes[1]["characters"] = "001" tree = Tree(tree) model = IIDExponentialBLEGridSearchCV( minimum_branch_lengths=(0, 1.0, 3.0), - l2_regularizations=(0, ), - verbose=True + l2_regularizations=(0,), + verbose=True, ) model.estimate_branch_lengths(tree) minimum_branch_length = model.minimum_branch_length diff --git a/test/tools_tests/lineage_simulator_test.py b/test/tools_tests/lineage_simulator_test.py index 8e4fb77a..d74b1f49 100644 --- a/test/tools_tests/lineage_simulator_test.py +++ b/test/tools_tests/lineage_simulator_test.py @@ -2,14 +2,16 @@ def test_PerfectBinaryTree(): - tree = PerfectBinaryTree(generation_branch_lengths=[2, 3])\ - .simulate_lineage() + tree = PerfectBinaryTree( + generation_branch_lengths=[2, 3] + ).simulate_lineage() newick = tree.to_newick_tree_format(print_internal_nodes=True) - assert(newick == '((3:3,4:3)1:2,(5:3,6:3)2:2)0);') + assert newick == "((3:3,4:3)1:2,(5:3,6:3)2:2)0);" def test_PerfectBinaryTreeWithRootBranch(): - tree = PerfectBinaryTreeWithRootBranch(generation_branch_lengths=[2, 3, 4])\ - .simulate_lineage() + tree = PerfectBinaryTreeWithRootBranch( + generation_branch_lengths=[2, 3, 4] + ).simulate_lineage() newick = tree.to_newick_tree_format(print_internal_nodes=True) - assert(newick == '(((4:4,5:4)2:3,(6:4,7:4)3:3)1:2)0);') + assert newick == "(((4:4,5:4)2:3,(6:4,7:4)3:3)1:2)0);" diff --git a/test/tools_tests/lineage_tracing_simulator_test.py b/test/tools_tests/lineage_tracing_simulator_test.py index 
b9ffb04f..4d2bab73 100644 --- a/test/tools_tests/lineage_tracing_simulator_test.py +++ b/test/tools_tests/lineage_tracing_simulator_test.py @@ -20,5 +20,6 @@ def test_smoke(): tree.nodes[5]["age"] = 0 tree.nodes[6]["age"] = 0 tree = Tree(tree) - IIDExponentialLineageTracer(mutation_rate=1.0, num_characters=10)\ - .overlay_lineage_tracing_data(tree) + IIDExponentialLineageTracer( + mutation_rate=1.0, num_characters=10 + ).overlay_lineage_tracing_data(tree) diff --git a/test/tools_tests/tree_test.py b/test/tools_tests/tree_test.py index 8101f093..21c09b4a 100644 --- a/test/tools_tests/tree_test.py +++ b/test/tools_tests/tree_test.py @@ -14,118 +14,157 @@ def test_to_newick_tree_format(): tree.add_edges_from([(0, 1), (1, 2), (1, 3), (3, 4), (3, 5)]) tree = Tree(tree) tree.set_edge_lengths( - [(0, 1, 0.1), - (1, 2, 0.5), - (1, 3, 0.2), - (3, 4, 0.3), - (3, 5, 0.4)] + [(0, 1, 0.1), (1, 2, 0.5), (1, 3, 0.2), (3, 4, 0.3), (3, 5, 0.4)] ) tree.set_states( - [(0, '0000000000'), - (1, '1000000000'), - (2, '1111000000'), - (3, '1110000000'), - (4, '1110000111'), - (5, '1110111111')] + [ + (0, "0000000000"), + (1, "1000000000"), + (2, "1111000000"), + (3, "1110000000"), + (4, "1110000111"), + (5, "1110111111"), + ] ) res = tree.to_newick_tree_format(print_internal_nodes=False) - assert(res == "((2:0.5,(4:0.3,5:0.4):0.2):0.1));") + assert res == "((2:0.5,(4:0.3,5:0.4):0.2):0.1));" res = tree.to_newick_tree_format( print_node_names=False, print_internal_nodes=True, - append_state_to_node_name=True) - assert(res == "((_1111000000:0.5,(_1110000111:0.3,_1110111111:0.4)" - "_1110000000:0.2)_1000000000:0.1)_0000000000);") + append_state_to_node_name=True, + ) + assert ( + res == "((_1111000000:0.5,(_1110000111:0.3,_1110111111:0.4)" + "_1110000000:0.2)_1000000000:0.1)_0000000000);" + ) res = tree.to_newick_tree_format(print_internal_nodes=True) - assert(res == "((2:0.5,(4:0.3,5:0.4)3:0.2)1:0.1)0);") + assert res == "((2:0.5,(4:0.3,5:0.4)3:0.2)1:0.1)0);" res = 
tree.to_newick_tree_format(print_node_names=False) - assert(res == "((:0.5,(:0.3,:0.4):0.2):0.1));") + assert res == "((:0.5,(:0.3,:0.4):0.2):0.1));" res = tree.to_newick_tree_format( - print_internal_nodes=True, - add_N_to_node_id=True) - assert(res == "((N2:0.5,(N4:0.3,N5:0.4)N3:0.2)N1:0.1)N0);") + print_internal_nodes=True, add_N_to_node_id=True + ) + assert res == "((N2:0.5,(N4:0.3,N5:0.4)N3:0.2)N1:0.1)N0);" res = tree.to_newick_tree_format( print_internal_nodes=True, append_state_to_node_name=True, - add_N_to_node_id=True) - assert(res == "((N2_1111000000:0.5,(N4_1110000111:0.3,N5_1110111111:0.4)" - "N3_1110000000:0.2)N1_1000000000:0.1)N0_0000000000);") + add_N_to_node_id=True, + ) + assert ( + res == "((N2_1111000000:0.5,(N4_1110000111:0.3,N5_1110111111:0.4)" + "N3_1110000000:0.2)N1_1000000000:0.1)N0_0000000000);" + ) res = tree.to_newick_tree_format( print_internal_nodes=True, print_pct_of_mutated_characters_along_edge=True, - add_N_to_node_id=True) - assert(res == "((N2:0.5[&&NHX:muts=0.33],(N4:0.3[&&NHX:muts=0.43]," - "N5:0.4[&&NHX:muts=0.86])N3:0.2[&&NHX:muts=0.22])" - "N1:0.1[&&NHX:muts=0.10])N0);") + add_N_to_node_id=True, + ) + assert ( + res == "((N2:0.5[&&NHX:muts=0.33],(N4:0.3[&&NHX:muts=0.43]," + "N5:0.4[&&NHX:muts=0.86])N3:0.2[&&NHX:muts=0.22])" + "N1:0.1[&&NHX:muts=0.10])N0);" + ) def test_reconstruct_ancestral_states(): tree = nx.DiGraph() tree.add_nodes_from(list(range(17))) - tree.add_edges_from([(10, 11), - (11, 13), - (13, 0), (13, 1), - (11, 14), - (14, 2), (14, 3), - (10, 12), - (12, 15), - (15, 4), (15, 5), - (12, 16), - (16, 6), (16, 7), (16, 8), (16, 9)]) + tree.add_edges_from( + [ + (10, 11), + (11, 13), + (13, 0), + (13, 1), + (11, 14), + (14, 2), + (14, 3), + (10, 12), + (12, 15), + (15, 4), + (15, 5), + (12, 16), + (16, 6), + (16, 7), + (16, 8), + (16, 9), + ] + ) tree = Tree(tree) tree.set_states( - [(0, '01101110100'), - (1, '01211111111'), - (2, '01322121111'), - (3, '01432122111'), - (4, '01541232111'), - (5, '01651233111'), - 
(6, '01763243111'), - (7, '01873240111'), - (8, '01983240111'), - (9, '01093240010'), - ] + [ + (0, "01101110100"), + (1, "01211111111"), + (2, "01322121111"), + (3, "01432122111"), + (4, "01541232111"), + (5, "01651233111"), + (6, "01763243111"), + (7, "01873240111"), + (8, "01983240111"), + (9, "01093240010"), + ] ) tree.reconstruct_ancestral_states() - assert(tree.get_state(10) == '00000000000') - assert(tree.get_state(11) == '01000100100') - assert(tree.get_state(13) == '01001110100') - assert(tree.get_state(14) == '01002120111') - assert(tree.get_state(12) == '01000200010') - assert(tree.get_state(15) == '01001230111') - assert(tree.get_state(16) == '01003240010') + assert tree.get_state(10) == "00000000000" + assert tree.get_state(11) == "01000100100" + assert tree.get_state(13) == "01001110100" + assert tree.get_state(14) == "01002120111" + assert tree.get_state(12) == "01000200010" + assert tree.get_state(15) == "01001230111" + assert tree.get_state(16) == "01003240010" def test_reconstruct_ancestral_states_DREAM_challenge_tree_25(): tree = nx.DiGraph() tree.add_nodes_from(list(range(21))) - tree.add_edges_from([(9, 8), (8, 10), (8, 7), (7, 11), (7, 12), (9, 6), - (6, 2), (2, 0), (0, 13), (0, 14), (2, 1), (1, 15), - (1, 16), (6, 5), (5, 3), (3, 17), (3, 18), (5, 4), - (4, 19), (4, 20)]) + tree.add_edges_from( + [ + (9, 8), + (8, 10), + (8, 7), + (7, 11), + (7, 12), + (9, 6), + (6, 2), + (2, 0), + (0, 13), + (0, 14), + (2, 1), + (1, 15), + (1, 16), + (6, 5), + (5, 3), + (3, 17), + (3, 18), + (5, 4), + (4, 19), + (4, 20), + ] + ) tree = Tree(tree) tree.set_states( - [(10, '0022100000'), - (11, '0022100000'), - (12, '0022100000'), - (13, '2012000220'), - (14, '2012000200'), - (15, '2012000100'), - (16, '2012000100'), - (17, '0001110220'), - (18, '0001110220'), - (19, '0000210220'), - (20, '0000210220'), - ] + [ + (10, "0022100000"), + (11, "0022100000"), + (12, "0022100000"), + (13, "2012000220"), + (14, "2012000200"), + (15, "2012000100"), + (16, 
"2012000100"), + (17, "0001110220"), + (18, "0001110220"), + (19, "0000210220"), + (20, "0000210220"), + ] ) tree.reconstruct_ancestral_states() - assert(tree.get_state(7) == '0022100000') - assert(tree.get_state(8) == '0022100000') - assert(tree.get_state(0) == '2012000200') - assert(tree.get_state(1) == '2012000100') - assert(tree.get_state(2) == '2012000000') - assert(tree.get_state(3) == '0001110220') - assert(tree.get_state(4) == '0000210220') - assert(tree.get_state(5) == '0000010220') - assert(tree.get_state(6) == '0000000000') - assert(tree.get_state(9) == '0000000000') + assert tree.get_state(7) == "0022100000" + assert tree.get_state(8) == "0022100000" + assert tree.get_state(0) == "2012000200" + assert tree.get_state(1) == "2012000100" + assert tree.get_state(2) == "2012000000" + assert tree.get_state(3) == "0001110220" + assert tree.get_state(4) == "0000210220" + assert tree.get_state(5) == "0000010220" + assert tree.get_state(6) == "0000000000" + assert tree.get_state(9) == "0000000000" From cf019e9a9056d78372bb390fee87ee7fdca99331 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Mon, 4 Jan 2021 19:43:34 -0800 Subject: [PATCH 16/61] More Tree boilerplate --- cassiopeia/tools/tree.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index ef5ab2b9..56e136cf 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -225,3 +225,15 @@ def print_edges(self): f"{child}[{self.get_state(child)}]: " f"{self.get_edge_length(parent, child)}" ) + + def num_cuts(self, v: int) -> int: + # TODO: Hardcoded '0'... 
+ res = self.num_characters() - self.get_state(v).count('0') + return res + + def parent(self, v: int) -> int: + if v == self.root(): + raise ValueError("Asked for parent of root node!") + incident_edges_at_v = [edge for edge in self.edges() if edge[1] == v] + assert(len(incident_edges_at_v) == 1) + return incident_edges_at_v[0][0] From b50186e6ef99984aa3d60ad1dccac1b3e20ed17d Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Mon, 4 Jan 2021 20:39:25 -0800 Subject: [PATCH 17/61] Create branch_length_estimator package --- cassiopeia/tools/branch_length_estimator/__init__.py | 1 + .../{ => branch_length_estimator}/branch_length_estimator.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 cassiopeia/tools/branch_length_estimator/__init__.py rename cassiopeia/tools/{ => branch_length_estimator}/branch_length_estimator.py (99%) diff --git a/cassiopeia/tools/branch_length_estimator/__init__.py b/cassiopeia/tools/branch_length_estimator/__init__.py new file mode 100644 index 00000000..742de204 --- /dev/null +++ b/cassiopeia/tools/branch_length_estimator/__init__.py @@ -0,0 +1 @@ +from .branch_length_estimator import BranchLengthEstimator, IIDExponentialBLE, IIDExponentialBLEGridSearchCV diff --git a/cassiopeia/tools/branch_length_estimator.py b/cassiopeia/tools/branch_length_estimator/branch_length_estimator.py similarity index 99% rename from cassiopeia/tools/branch_length_estimator.py rename to cassiopeia/tools/branch_length_estimator/branch_length_estimator.py index 79ac4a7b..a2256d3e 100644 --- a/cassiopeia/tools/branch_length_estimator.py +++ b/cassiopeia/tools/branch_length_estimator/branch_length_estimator.py @@ -5,7 +5,7 @@ import cvxpy as cp import numpy as np -from .tree import Tree +from ..tree import Tree class BranchLengthEstimator(abc.ABC): From 1e19f658c85ced53ff4bc2571911d79594a3c52b Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: 
Mon, 4 Jan 2021 20:42:24 -0800 Subject: [PATCH 18/61] Break up branch_length_estimator.py into smaller modules --- .../BranchLengthEstimator.py | 27 +++++++++++++++++ ...ngth_estimator.py => IIDExponentialBLE.py} | 30 ++----------------- .../tools/branch_length_estimator/__init__.py | 3 +- cassiopeia/tools/tree.py | 4 +-- 4 files changed, 34 insertions(+), 30 deletions(-) create mode 100644 cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py rename cassiopeia/tools/branch_length_estimator/{branch_length_estimator.py => IIDExponentialBLE.py} (92%) diff --git a/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py b/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py new file mode 100644 index 00000000..18907f26 --- /dev/null +++ b/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py @@ -0,0 +1,27 @@ +import abc + +from ..tree import Tree + + +class BranchLengthEstimator(abc.ABC): + r""" + Abstract base class for all branch length estimators. + + A BranchLengthEstimator implements a method estimate_branch_lengths which, + given a Tree with lineage tracing character vectors at the leaves (and + possibly at the internal nodes too), estimates the branch lengths of the + tree. + """ + + @abc.abstractmethod + def estimate_branch_lengths(self, tree: Tree) -> None: + r""" + Estimates the branch lengths of the tree. + + Annotates the tree's nodes with their estimated age, and + the tree's branches with their estimated lengths. Operates on the tree + in-place. + + Args: + tree: The tree for which to estimate branch lengths. 
+ """ diff --git a/cassiopeia/tools/branch_length_estimator/branch_length_estimator.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py similarity index 92% rename from cassiopeia/tools/branch_length_estimator/branch_length_estimator.py rename to cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py index a2256d3e..a8c2c8bd 100644 --- a/cassiopeia/tools/branch_length_estimator/branch_length_estimator.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py @@ -1,4 +1,3 @@ -import abc import copy from typing import List, Tuple @@ -6,30 +5,7 @@ import numpy as np from ..tree import Tree - - -class BranchLengthEstimator(abc.ABC): - r""" - Abstract base class for all branch length estimators. - - A BranchLengthEstimator implements a method estimate_branch_lengths which, - given a Tree with lineage tracing character vectors at the leaves (and - possibly at the internal nodes too), estimates the branch lengths of the - tree. - """ - - @abc.abstractmethod - def estimate_branch_lengths(self, tree: Tree) -> None: - r""" - Estimates the branch lengths of the tree. - - Annotates the tree's nodes with their estimated age, and - the tree's branches with their estiamted lengths. Operates on the tree - in-place. - - Args: - tree: The tree for which to estimate branch lengths. - """ +from .BranchLengthEstimator import BranchLengthEstimator class IIDExponentialBLE(BranchLengthEstimator): @@ -46,7 +22,7 @@ class IIDExponentialBLE(BranchLengthEstimator): Args: minimum_branch_length: Estimated branch lengths will be constrained to - have at least this lenght. + have at least this length. l2_regularization: Consecutive branches will be regularized to have similar length via an L2 penalty whose weight is given by l2_regularization. @@ -56,7 +32,7 @@ class IIDExponentialBLE(BranchLengthEstimator): log_likelihood: The log-likelihood of the training data under the estimated model. log_loss: The log-loss of the training data under the estimated model. 
- This is the log likhelihood plus the regularization terms. + This is the log likelihood plus the regularization terms. """ def __init__( diff --git a/cassiopeia/tools/branch_length_estimator/__init__.py b/cassiopeia/tools/branch_length_estimator/__init__.py index 742de204..b222a9f1 100644 --- a/cassiopeia/tools/branch_length_estimator/__init__.py +++ b/cassiopeia/tools/branch_length_estimator/__init__.py @@ -1 +1,2 @@ -from .branch_length_estimator import BranchLengthEstimator, IIDExponentialBLE, IIDExponentialBLEGridSearchCV +from .BranchLengthEstimator import BranchLengthEstimator +from .IIDExponentialBLE import IIDExponentialBLE, IIDExponentialBLEGridSearchCV diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index 56e136cf..78cd262d 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -228,12 +228,12 @@ def print_edges(self): def num_cuts(self, v: int) -> int: # TODO: Hardcoded '0'... - res = self.num_characters() - self.get_state(v).count('0') + res = self.num_characters() - self.get_state(v).count("0") return res def parent(self, v: int) -> int: if v == self.root(): raise ValueError("Asked for parent of root node!") incident_edges_at_v = [edge for edge in self.edges() if edge[1] == v] - assert(len(incident_edges_at_v) == 1) + assert len(incident_edges_at_v) == 1 return incident_edges_at_v[0][0] From 32f56248dc6ad53b2ef571ec6b93fb0f724fdde2 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Mon, 4 Jan 2021 21:06:35 -0800 Subject: [PATCH 19/61] Add IIDExponentialPosteriorMeanBLE with tests --- cassiopeia/tools/__init__.py | 2 + .../IIDExponentialPosteriorMeanBLE.py | 275 ++++++++++++++++++ .../tools/branch_length_estimator/__init__.py | 4 + .../branch_length_estimator_test.py | 167 +++++++++++ 4 files changed, 448 insertions(+) create mode 100644 cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py diff --git a/cassiopeia/tools/__init__.py 
b/cassiopeia/tools/__init__.py index f7d36a54..a9f47eca 100644 --- a/cassiopeia/tools/__init__.py +++ b/cassiopeia/tools/__init__.py @@ -2,6 +2,8 @@ BranchLengthEstimator, IIDExponentialBLE, IIDExponentialBLEGridSearchCV, + IIDExponentialPosteriorMeanBLE, + IIDExponentialPosteriorMeanBLEGridSearchCV, ) from .lineage_simulator import ( LineageSimulator, diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py new file mode 100644 index 00000000..cf2ec48a --- /dev/null +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -0,0 +1,275 @@ +from typing import Tuple + +import numpy as np +from scipy.special import logsumexp + +from .BranchLengthEstimator import BranchLengthEstimator +from ..tree import Tree + + +class IIDExponentialPosteriorMeanBLE(BranchLengthEstimator): + r""" + Same model as IIDExponentialBLE but computes the posterior mean instead + of the MLE. The phylogeny model is chosen to be a birth process. + + This estimator requires that the ancestral states are provided. + TODO: Allow for two versions: one where the number of mutations of each + node must match exactly, and one where it must be upper bounded by the + number of mutations seen. (I believe the latter should ameliorate + subtree collapse further.) + + TODO: Use numpy autograd to do optimize the hyperparams? (Empirical Bayes) + + We compute the posterior means using a forward-backward-style algorithm + (DP on a tree). + + Args: + mutation_rate: TODO + birth_rate: TODO + discretization_level: TODO + verbose: Verbosity level. TODO + + Attributes: TODO + + """ + + def __init__( + self, mutation_rate: float, birth_rate: float, discretization_level: int + ) -> None: + # TODO: If we use autograd, we can tune the hyperparams with gradient + # descent. 
+ self.mutation_rate = mutation_rate + # TODO: Is there some easy heuristic way to set this to a reasonable + # value and thus avoid grid searching it / optimizing it? + self.birth_rate = birth_rate + self.discretization_level = discretization_level + + def estimate_branch_lengths(self, tree: Tree) -> None: + r""" + See base class. + """ + discretization_level = self.discretization_level + self.down_cache = {} # TODO: Rename to _down_cache + self.up_cache = {} # TODO: Rename to _up_cache + self.tree = tree + log_likelihood = 0 + # TODO: Should I also add a division event when the root has multiple + # children? + for child_of_root in tree.children(tree.root()): + log_likelihood += self.down(child_of_root, discretization_level, 0) + self.log_likelihood = log_likelihood + # # # # # Compute Posteriors # # # # # + posteriors = {} + log_posteriors = {} + posterior_means = {} + for v in tree.internal_nodes(): + # Compute the posterior for this node + posterior = np.zeros(shape=(discretization_level + 1,)) + for t in range(discretization_level + 1): + posterior[t] = self.down(v, t, tree.num_cuts(v)) + self.up( + v, t, tree.num_cuts(v) + ) + posterior -= np.max(posterior) + log_posteriors[v] = posterior.copy() + posterior = np.exp(posterior) + posterior /= np.sum(posterior) + posteriors[v] = posterior + posterior_means[v] = ( + posterior * np.array(range(discretization_level + 1)) + ).sum() / discretization_level + self.posteriors = posteriors + self.log_posteriors = log_posteriors + self.posterior_means = posterior_means + # # # # # Populate the tree with the estimated branch lengths # # # # # + for node in tree.internal_nodes(): + tree.set_age(node, age=posterior_means[node]) + tree.set_age(tree.root(), age=1.0) + for leaf in tree.leaves(): + tree.set_age(leaf, age=0.0) + + for (parent, child) in tree.edges(): + new_edge_length = tree.get_age(parent) - tree.get_age(child) + tree.set_edge_length(parent, child, length=new_edge_length) + + def up(self, v, t, x) -> float: + 
r""" + TODO: Rename this _up. + log P(X_up(b(v)), T_up(b(v)), t \in t_b(v), X_b(v)(t) = x) + """ + if (v, t, x) in self.up_cache: + # TODO: Use a decorator instead of a hand-made cache + return self.up_cache[(v, t, x)] + # Pull out params + r = self.mutation_rate + lam = self.birth_rate + dt = 1.0 / self.discretization_level + K = self.tree.num_characters() + tree = self.tree + discretization_level = self.discretization_level + assert 0 <= t <= self.discretization_level + assert 0 <= x <= K + log_likelihood = 0.0 + if v == tree.root(): # Base case: we reached the root of the tree. + if t == discretization_level and x == tree.num_cuts(v): + log_likelihood = 0.0 + else: + log_likelihood = -np.inf + elif t == discretization_level: + # Base case: we reached the start of the process, but we're not yet + # at the root. + assert v != tree.root() + log_likelihood = -np.inf + else: # Recursion. + log_likelihoods_cases = [] + # Case 1: Nothing happened + log_likelihoods_cases.append( + np.log(1.0 - lam * dt - (K - x) * r * dt) + self.up(v, t + 1, x) + ) + # Case 2: Mutation happened + if x - 1 >= 0: + log_likelihoods_cases.append( + np.log((K - (x - 1)) * r * dt) + self.up(v, t + 1, x - 1) + ) + # Case 3: A cell division happened + if v != tree.root(): + p = tree.parent(v) + if x == tree.num_cuts(p): + siblings = [u for u in tree.children(p) if u != v] + ll = ( + np.log(lam * dt) + + self.up(p, t + 1, x) + + sum([self.down(u, t, x) for u in siblings]) + ) + if p == tree.root(): # The branch start is for free! + ll -= np.log(lam * dt) + log_likelihoods_cases.append(ll) + log_likelihood = logsumexp(log_likelihoods_cases) + self.up_cache[(v, t, x)] = log_likelihood + return log_likelihood + + def down(self, v, t, x) -> float: + r""" + TODO: Rename this _down. 
+ log P(X_down(v), T_down(v) | t_v = t, X_v = x) + """ + if (v, t, x) in self.down_cache: + # TODO: Use a decorator instead of a hand-made cache + return self.down_cache[(v, t, x)] + # Pull out params + r = self.mutation_rate + lam = self.birth_rate + dt = 1.0 / self.discretization_level + K = self.tree.num_characters() + tree = self.tree + assert v != tree.root() + assert 0 <= t <= self.discretization_level + assert 0 <= x <= K + log_likelihood = 0.0 + if t == 0: # Base case + if v in tree.leaves() and x == tree.num_cuts(v): + log_likelihood = 0.0 + else: + log_likelihood = -np.inf + else: # Recursion. + log_likelihoods_cases = [] + # Case 1: Nothing happens + log_likelihoods_cases.append( + np.log(1.0 - lam * dt - (K - x) * r * dt) + + self.down(v, t - 1, x) + ) + # Case 2: One character mutates. + if x + 1 <= K: + log_likelihoods_cases.append( + np.log((K - x) * r * dt) + self.down(v, t - 1, x + 1) + ) + # Case 3: Cell divides + # The number of cuts at this state must match the ground truth. + # TODO: Allow for weak match at internal nodes and exact match at + # leaves. + if x == tree.num_cuts(v) and v not in tree.leaves(): + ll = sum( + [self.down(child, t - 1, x) for child in tree.children(v)] + ) + np.log(lam * dt) + log_likelihoods_cases.append(ll) + log_likelihood = logsumexp(log_likelihoods_cases) + self.down_cache[(v, t, x)] = log_likelihood + return log_likelihood + + +class IIDExponentialPosteriorMeanBLEGridSearchCV(BranchLengthEstimator): + r""" + Like IIDExponentialPosteriorMeanBLE but with automatic tuning of + hyperparameters. + + This class fits the hyperparameters of IIDExponentialPosteriorMeanBLE based + on data log-likelihood. I.e. is performs empirical Bayes. + + Args: + mutation_rates: TODO + birth_rate: TODO + discretization_level: TODO + verbose: Verbosity level. 
TODO + """ + + def __init__( + self, + mutation_rates: Tuple[float] = (0,), + birth_rates: Tuple[float] = (0,), + discretization_level: int = 1000, + verbose: bool = False, + ): + self.mutation_rates = mutation_rates + self.birth_rates = birth_rates + self.discretization_level = discretization_level + self.verbose = verbose + + def estimate_branch_lengths(self, tree: Tree) -> None: + r""" + See base class. + """ + mutation_rates = self.mutation_rates + birth_rates = self.birth_rates + discretization_level = self.discretization_level + verbose = self.verbose + lls = [] + grid = np.zeros(shape=(len(mutation_rates), len(birth_rates))) + for i, mutation_rate in enumerate(mutation_rates): + for j, birth_rate in enumerate(birth_rates): + if self.verbose: + print( + f"Fitting model with:\n" + f"best_mutation_rate={mutation_rate}\n" + f"best_birth_rate={birth_rate}" + ) + model = IIDExponentialPosteriorMeanBLE( + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + model.estimate_branch_lengths(tree) + ll = model.log_likelihood + lls.append((ll, (mutation_rate, birth_rate))) + grid[i, j] = ll + lls.sort(reverse=True) + (best_mutation_rate, best_birth_rate,) = lls[ + 0 + ][1] + if verbose: + print( + f"Refitting model with:\n" + f"best_mutation_rate={best_mutation_rate}\n" + f"best_birth_rate={best_birth_rate}" + ) + final_model = IIDExponentialPosteriorMeanBLE( + mutation_rate=best_mutation_rate, + birth_rate=best_birth_rate, + discretization_level=discretization_level, + ) + final_model.estimate_branch_lengths(tree) + self.mutation_rate = best_mutation_rate + self.birth_rate = best_birth_rate + self.log_likelihood = final_model.log_likelihood + self.posteriors = final_model.posteriors + self.log_posteriors = final_model.log_posteriors + self.posterior_means = final_model.posterior_means + self.grid = grid diff --git a/cassiopeia/tools/branch_length_estimator/__init__.py 
b/cassiopeia/tools/branch_length_estimator/__init__.py index b222a9f1..5346e0b8 100644 --- a/cassiopeia/tools/branch_length_estimator/__init__.py +++ b/cassiopeia/tools/branch_length_estimator/__init__.py @@ -1,2 +1,6 @@ from .BranchLengthEstimator import BranchLengthEstimator from .IIDExponentialBLE import IIDExponentialBLE, IIDExponentialBLEGridSearchCV +from .IIDExponentialPosteriorMeanBLE import ( + IIDExponentialPosteriorMeanBLE, + IIDExponentialPosteriorMeanBLEGridSearchCV, +) diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index b97c6a64..68cc439c 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -5,6 +5,8 @@ IIDExponentialBLE, IIDExponentialBLEGridSearchCV, IIDExponentialLineageTracer, + IIDExponentialPosteriorMeanBLE, + IIDExponentialPosteriorMeanBLEGridSearchCV, Tree, ) @@ -418,3 +420,168 @@ def test_IIDExponentialBLEGridSearchCV(): model.estimate_branch_lengths(tree) minimum_branch_length = model.minimum_branch_length np.testing.assert_almost_equal(minimum_branch_length, 1.0) + + +def test_IIDExponentialPosteriorMeanBLE(): + # Make a test case out of this! 
+ + from scipy.special import binom + from scipy.special import logsumexp + + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3]) + tree.add_edges_from([(0, 1), (1, 2), (1, 3)]) + tree.nodes[0]["characters"] = "000000000" + tree.nodes[1]["characters"] = "010000110" + tree.nodes[2]["characters"] = "010110111" + tree.nodes[3]["characters"] = "011100111" + tree = Tree(tree) + + mutation_rate = 0.3 + birth_rate = 0.7 + discretization_level = 100 + model = IIDExponentialPosteriorMeanBLE( + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + + model.estimate_branch_lengths(tree) + + print(model.log_likelihood) + + def cuts(parent, child): + zeros_parent = tree.get_state(parent).count("0") + zeros_child = tree.get_state(child).count("0") + new_cuts_child = zeros_parent - zeros_child + return new_cuts_child + + def uncuts(parent, child): + zeros_child = tree.get_state(child).count("0") + return zeros_child + + def analytical_log_joint(t): + t = 1.0 - t + if t == 0 or t == 1: + return -np.inf + e = np.exp + lg = np.log + lam = birth_rate + r = mutation_rate + res = 0.0 + res += ( + lg(lam) + -t * lam + -2 * (1.0 - t) * lam + ) # Tree topology likelihood + res += -t * r * uncuts(0, 1) + lg(1.0 - e(-t * r)) * cuts( + 0, 1 + ) # 0->1 edge likelihood + res += -(1.0 - t) * r * uncuts(1, 2) + lg( + 1.0 - e(-(1.0 - t) * r) + ) * cuts( + 1, 2 + ) # 1->2 edge likelihood + res += -(1.0 - t) * r * uncuts(1, 3) + lg( + 1.0 - e(-(1.0 - t) * r) + ) * cuts( + 1, 3 + ) # 1->3 edge likelihood + return res + + step = 2000 + analytical_log_likelihood = ( + logsumexp( + [ + analytical_log_joint(t) + for t in np.arange(1.0 / step, 1.0 - 1.0 / step, 1.0 / step) + ] + ) + - np.log(step) + + np.log(binom(cuts(0, 1) + uncuts(0, 1), cuts(0, 1))) + + np.log(binom(cuts(1, 2) + uncuts(1, 2), cuts(1, 2))) + + np.log(binom(cuts(1, 3) + uncuts(1, 3), cuts(1, 3))) + ) + + print(analytical_log_likelihood) + + np.testing.assert_approx_equal( + 
model.log_likelihood, analytical_log_likelihood, significant=3 + ) + + leaf = 2 + model_log_likelihood_up = model.up(leaf, 0, tree.num_cuts(leaf)) + print(model_log_likelihood_up) + np.testing.assert_approx_equal( + model.log_likelihood, model_log_likelihood_up, significant=3 + ) + + # import matplotlib.pyplot as plt + + # plt.plot(model.posteriors[1]) + # plt.show() + # print(model.posterior_means[1]) + + # Analytical posterior + analytical_posterior = np.array( + [ + analytical_log_joint(t) + for t in np.array(range(discretization_level + 1)) + / discretization_level + ] + ) + analytical_posterior -= analytical_posterior.max() + analytical_posterior = np.exp(analytical_posterior) + analytical_posterior /= analytical_posterior.sum() + # plt.plot(analytical_posterior) + # plt.show() + for i in range(discretization_level + 1): + np.testing.assert_almost_equal( + analytical_posterior[i], model.posteriors[1][i], decimal=2 + ) + + +def test_IIDExponentialPosteriorMeanBLEGridSeachCV(): + # This is same tree as test_subtree_collapses_when_no_mutations. Should no + # longer collapse! 
+ tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4]), + tree.add_edges_from([(0, 1), (1, 2), (1, 3), (0, 4)]) + tree.nodes[0]["characters"] = "0" + tree.nodes[1]["characters"] = "1" + tree.nodes[2]["characters"] = "1" + tree.nodes[3]["characters"] = "1" + tree.nodes[4]["characters"] = "0" + tree = Tree(tree) + + discretization_level = 100 + mutation_rates = (0.625, 0.750, 0.875) + birth_rates = (0.25, 0.50, 0.75) + model = IIDExponentialPosteriorMeanBLEGridSearchCV( + mutation_rates=mutation_rates, + birth_rates=birth_rates, + discretization_level=discretization_level, + verbose=True, + ) + + model.estimate_branch_lengths(tree) + + # import seaborn as sns + + # import matplotlib.pyplot as plt + + # sns.heatmap( + # model.grid, + # yticklabels=mutation_rates, + # xticklabels=birth_rates + # ) + # plt.ylabel('Mutation Rate') + # plt.xlabel('Birth Rate') + # plt.show() + + # import matplotlib.pyplot as plt + # plt.plot(model.posteriors[1]) + # plt.show() + # print(model.posterior_means[1]) + + np.testing.assert_almost_equal(model.posterior_means[1], 0.5006, decimal=3) + np.testing.assert_almost_equal(model.mutation_rate, 0.75) + np.testing.assert_almost_equal(model.birth_rate, 0.5) From cb4bcd6c3aa6528b15f95d0484c1d04a0246b4f3 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Mon, 4 Jan 2021 23:15:48 -0800 Subject: [PATCH 20/61] bugfix joint computation --- .../IIDExponentialPosteriorMeanBLE.py | 26 ++-- .../branch_length_estimator_test.py | 140 +++++++++++++----- 2 files changed, 117 insertions(+), 49 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index cf2ec48a..82182bbf 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -59,26 +59,30 @@ def 
estimate_branch_lengths(self, tree: Tree) -> None: log_likelihood += self.down(child_of_root, discretization_level, 0) self.log_likelihood = log_likelihood # # # # # Compute Posteriors # # # # # - posteriors = {} - log_posteriors = {} - posterior_means = {} + log_joints = {} # log P(t_v = t, X, T) + posteriors = {} # P(t_v = t | X, T) + posterior_means = {} # E[t_v = t | X, T] + lam = self.birth_rate + dt = 1.0 / discretization_level for v in tree.internal_nodes(): # Compute the posterior for this node - posterior = np.zeros(shape=(discretization_level + 1,)) + log_joint = np.zeros(shape=(discretization_level + 1,)) for t in range(discretization_level + 1): - posterior[t] = self.down(v, t, tree.num_cuts(v)) + self.up( - v, t, tree.num_cuts(v) + children = tree.children(v) + log_joint[t] = ( + sum([self.down(u, t, tree.num_cuts(v)) for u in children]) + + self.up(v, t, tree.num_cuts(v)) + + np.log(lam * dt) ) - posterior -= np.max(posterior) - log_posteriors[v] = posterior.copy() - posterior = np.exp(posterior) + log_joints[v] = log_joint.copy() + posterior = np.exp(log_joint - log_joint.max()) posterior /= np.sum(posterior) posteriors[v] = posterior posterior_means[v] = ( posterior * np.array(range(discretization_level + 1)) ).sum() / discretization_level + self.log_joints = log_joints self.posteriors = posteriors - self.log_posteriors = log_posteriors self.posterior_means = posterior_means # # # # # Populate the tree with the estimated branch lengths # # # # # for node in tree.internal_nodes(): @@ -269,7 +273,7 @@ def estimate_branch_lengths(self, tree: Tree) -> None: self.mutation_rate = best_mutation_rate self.birth_rate = best_birth_rate self.log_likelihood = final_model.log_likelihood + self.log_joints = final_model.log_joints self.posteriors = final_model.posteriors - self.log_posteriors = final_model.log_posteriors self.posterior_means = final_model.posterior_means self.grid = grid diff --git a/test/tools_tests/branch_length_estimator_test.py 
b/test/tools_tests/branch_length_estimator_test.py index 68cc439c..78fd9ff8 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -1,5 +1,6 @@ import networkx as nx import numpy as np +import pytest from cassiopeia.tools import ( IIDExponentialBLE, @@ -423,8 +424,9 @@ def test_IIDExponentialBLEGridSearchCV(): def test_IIDExponentialPosteriorMeanBLE(): - # Make a test case out of this! - + r""" + TODO + """ from scipy.special import binom from scipy.special import logsumexp @@ -438,18 +440,14 @@ def test_IIDExponentialPosteriorMeanBLE(): tree = Tree(tree) mutation_rate = 0.3 - birth_rate = 0.7 - discretization_level = 100 + birth_rate = 0.8 + discretization_level = 200 model = IIDExponentialPosteriorMeanBLE( mutation_rate=mutation_rate, birth_rate=birth_rate, discretization_level=discretization_level, ) - model.estimate_branch_lengths(tree) - - print(model.log_likelihood) - def cuts(parent, child): zeros_parent = tree.get_state(parent).count("0") zeros_child = tree.get_state(child).count("0") @@ -461,6 +459,9 @@ def uncuts(parent, child): return zeros_child def analytical_log_joint(t): + r""" + when node 1 has age t, i.e. hangs at distance 1.0 - t from the root. + """ t = 1.0 - t if t == 0 or t == 1: return -np.inf @@ -485,35 +486,57 @@ def analytical_log_joint(t): ) * cuts( 1, 3 ) # 1->3 edge likelihood - return res - - step = 2000 - analytical_log_likelihood = ( - logsumexp( - [ - analytical_log_joint(t) - for t in np.arange(1.0 / step, 1.0 - 1.0 / step, 1.0 / step) - ] + # Adjust by the grid size so we don't overestimate the bucket's + # probability. 
+ res -= np.log(discretization_level) + # Finally, we need to account for repetitions + res += ( + np.log(binom(cuts(0, 1) + uncuts(0, 1), cuts(0, 1))) + + np.log(binom(cuts(1, 2) + uncuts(1, 2), cuts(1, 2))) + + np.log(binom(cuts(1, 3) + uncuts(1, 3), cuts(1, 3))) ) - - np.log(step) - + np.log(binom(cuts(0, 1) + uncuts(0, 1), cuts(0, 1))) - + np.log(binom(cuts(1, 2) + uncuts(1, 2), cuts(1, 2))) - + np.log(binom(cuts(1, 3) + uncuts(1, 3), cuts(1, 3))) - ) - - print(analytical_log_likelihood) + return res + model.estimate_branch_lengths(tree) + print(f"{model.log_likelihood} = model.log_likelihood") + + # Test the model log likelihood vs its computation from the joint of the + # age of vertex 1. + model_log_joints = model.log_joints[ + 1 + ] # P(t_1 = t, X, T) where t_1 is the age of the first node. + model_log_likelihood_2 = logsumexp(model_log_joints) + print(f"{model_log_likelihood_2} = {model_log_likelihood_2}") np.testing.assert_approx_equal( - model.log_likelihood, analytical_log_likelihood, significant=3 + model.log_likelihood, model_log_likelihood_2, significant=3 ) + # Test the model log likelihood vs its computation from a leaf node. 
leaf = 2 model_log_likelihood_up = model.up(leaf, 0, tree.num_cuts(leaf)) - print(model_log_likelihood_up) + print(f"{model_log_likelihood_up} = model_log_likelihood_up") np.testing.assert_approx_equal( model.log_likelihood, model_log_likelihood_up, significant=3 ) + # Test the model log likelihood against its analytic computation + analytical_log_joints = np.array( + [ + analytical_log_joint(t) + for t in np.array(range(discretization_level + 1)) + / discretization_level + ] + ) + analytical_log_likelihood = logsumexp(analytical_log_joints) + print(f"{analytical_log_likelihood} = analytical_log_likelihood") + np.testing.assert_approx_equal( + model.log_likelihood, analytical_log_likelihood, significant=3 + ) + + np.testing.assert_array_almost_equal( + analytical_log_joints[50:150], model.log_joints[1][50:150], decimal=1 + ) + # import matplotlib.pyplot as plt # plt.plot(model.posteriors[1]) @@ -521,23 +544,64 @@ def analytical_log_joint(t): # print(model.posterior_means[1]) # Analytical posterior - analytical_posterior = np.array( - [ - analytical_log_joint(t) - for t in np.array(range(discretization_level + 1)) - / discretization_level - ] + analytical_posterior = np.exp( + analytical_log_joints - analytical_log_joints.max() ) - analytical_posterior -= analytical_posterior.max() - analytical_posterior = np.exp(analytical_posterior) analytical_posterior /= analytical_posterior.sum() # plt.plot(analytical_posterior) # plt.show() - for i in range(discretization_level + 1): - np.testing.assert_almost_equal( - analytical_posterior[i], model.posteriors[1][i], decimal=2 + total_variation = np.sum(np.abs(analytical_posterior - model.posteriors[1])) + assert total_variation < 0.03 + + +def test_IIDExponentialPosteriorMeanBLE_2(): + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + tree.nodes[0]["characters"] = "00" + tree.nodes[1]["characters"] = "00" + tree.nodes[2]["characters"] = 
"10" + tree.nodes[3]["characters"] = "00" + tree.nodes[4]["characters"] = "01" + tree.nodes[5]["characters"] = "10" + tree.nodes[6]["characters"] = "11" + tree = Tree(tree) + + mutation_rate = 0.625 + birth_rate = 0.75 + discretization_level = 100 + model = IIDExponentialPosteriorMeanBLE( + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + + model.estimate_branch_lengths(tree) + print(model.log_likelihood) + + for leaf in tree.leaves(): + model_log_likelihood_up = model.up(leaf, 0, tree.num_cuts(leaf)) + print(model_log_likelihood_up) + np.testing.assert_approx_equal( + model.log_likelihood, model_log_likelihood_up, significant=3 ) + model_log_likelihood_up_wrong = model.up( + leaf, 0, (tree.num_cuts(leaf) + 1) % 2 + ) + with pytest.raises(AssertionError): + np.testing.assert_approx_equal( + model.log_likelihood, + model_log_likelihood_up_wrong, + significant=3, + ) + + # import matplotlib.pyplot as plt + + # plt.plot(model.posteriors[1]) + # plt.show() + # print(model.posterior_means[1]) + def test_IIDExponentialPosteriorMeanBLEGridSeachCV(): # This is same tree as test_subtree_collapses_when_no_mutations. 
Should no @@ -582,6 +646,6 @@ def test_IIDExponentialPosteriorMeanBLEGridSeachCV(): # plt.show() # print(model.posterior_means[1]) - np.testing.assert_almost_equal(model.posterior_means[1], 0.5006, decimal=3) + np.testing.assert_almost_equal(model.posterior_means[1], 0.3184, decimal=3) np.testing.assert_almost_equal(model.mutation_rate, 0.75) np.testing.assert_almost_equal(model.birth_rate, 0.5) From 5fddb6fc9c6c2235a22f96b4989ee0699aecd12a Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Mon, 4 Jan 2021 23:19:05 -0800 Subject: [PATCH 21/61] More testing --- test/tools_tests/branch_length_estimator_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index 78fd9ff8..529aea9d 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -553,6 +553,16 @@ def analytical_log_joint(t): total_variation = np.sum(np.abs(analytical_posterior - model.posteriors[1])) assert total_variation < 0.03 + analytical_posterior_mean = np.sum( + analytical_posterior + * np.array(range(discretization_level + 1)) + / discretization_level + ) + posterior_mean = tree.get_age(1) + np.testing.assert_approx_equal( + posterior_mean, analytical_posterior_mean, significant=2 + ) + def test_IIDExponentialPosteriorMeanBLE_2(): tree = nx.DiGraph() From 549b34e36e80ec1fce559323b422f90ad1480de8 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Tue, 5 Jan 2021 00:02:30 -0800 Subject: [PATCH 22/61] Add to cython --- .../branch_length_estimator/IIDExponentialPosteriorMeanBLE.py | 4 ++-- setup.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index 82182bbf..bdb671e5 
100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -242,8 +242,8 @@ def estimate_branch_lengths(self, tree: Tree) -> None: if self.verbose: print( f"Fitting model with:\n" - f"best_mutation_rate={mutation_rate}\n" - f"best_birth_rate={birth_rate}" + f"mutation_rate={mutation_rate}\n" + f"birth_rate={birth_rate}" ) model = IIDExponentialPosteriorMeanBLE( mutation_rate=mutation_rate, diff --git a/setup.py b/setup.py index 549c106b..a9d3e7ef 100755 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ # files to wrap with cython to_cythonize = [Extension("cassiopeia.preprocess.doublet_utils", ["cassiopeia/preprocess/doublet_utils.pyx"]), + Extension("cassiopeia.tools.branch_length_estimator.IIDExponentialPosteriorMeanBLE", ["cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py"]), Extension("cassiopeia.preprocess.map_utils", ["cassiopeia/preprocess/map_utils.pyx"]), Extension("cassiopeia.preprocess.collapse_cython", ["cassiopeia/preprocess/collapse_cython.pyx"]), Extension("cassiopeia.solver.ilp_solver_utilities", ["cassiopeia/solver/ilp_solver_utilities.pyx"])] From a9fbf181e69566a116f84152a702f2f242a114c0 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Tue, 5 Jan 2021 18:20:27 -0800 Subject: [PATCH 23/61] A simple birth process --- cassiopeia/tools/__init__.py | 1 + cassiopeia/tools/lineage_simulator.py | 66 ++++++++++++++++++++++ cassiopeia/tools/tree.py | 29 ++++++++++ test/tools_tests/lineage_simulator_test.py | 31 +++++++++- 4 files changed, 126 insertions(+), 1 deletion(-) diff --git a/cassiopeia/tools/__init__.py b/cassiopeia/tools/__init__.py index a9f47eca..9d1b577d 100644 --- a/cassiopeia/tools/__init__.py +++ b/cassiopeia/tools/__init__.py @@ -6,6 +6,7 @@ IIDExponentialPosteriorMeanBLEGridSearchCV, ) from .lineage_simulator import ( + BirthProcess, LineageSimulator, 
PerfectBinaryTree, PerfectBinaryTreeWithRootBranch, diff --git a/cassiopeia/tools/lineage_simulator.py b/cassiopeia/tools/lineage_simulator.py index 0aa6e762..9e3efddf 100644 --- a/cassiopeia/tools/lineage_simulator.py +++ b/cassiopeia/tools/lineage_simulator.py @@ -2,6 +2,7 @@ from typing import List import networkx as nx +import numpy as np from .tree import Tree @@ -106,3 +107,68 @@ def simulate_lineage(self) -> Tree: tree.nodes[int(child / 2)]["age"] - branch_length ) return Tree(tree) + + +class BirthProcess(LineageSimulator): + r""" + A Birth Process with exponential holding times. + + Args: + birth_rate: Birth rate of the process + tree_depth: Depth of the simulated tree + """ + + def __init__(self, birth_rate: float, tree_depth: float): + self.birth_rate = birth_rate + self.tree_depth = tree_depth + + def simulate_lineage(self) -> Tree: + r""" + See base class. + """ + tree_depth = self.tree_depth + birth_rate = self.birth_rate + node_age = {} + node_age[0] = tree_depth + live_nodes = [1] + edges = [(0, 1)] + t = 0 + last_node_id = 1 + while t < tree_depth: + num_live_nodes = len(live_nodes) + # Wait till next node divides. + waiting_time = np.random.exponential( + 1.0 / (birth_rate * num_live_nodes) + ) + when_node_divides = t + waiting_time + del waiting_time + if when_node_divides >= tree_depth: + # The simulation has ended. 
+ for node in live_nodes: + node_age[node] = 0 + del live_nodes + break + # Choose which node divides uniformly at random + node_that_divides = live_nodes[ + np.random.randint(low=0, high=num_live_nodes) + ] + # Remove the node that divided and add its two children + live_nodes.remove(node_that_divides) + left_child_id = last_node_id + 1 + right_child_id = last_node_id + 2 + last_node_id += 2 + live_nodes += [left_child_id, right_child_id] + edges += [ + (node_that_divides, left_child_id), + (node_that_divides, right_child_id), + ] + node_age[node_that_divides] = tree_depth - when_node_divides + t = when_node_divides + tree_nx = nx.DiGraph() + tree_nx.add_nodes_from(range(last_node_id + 1)) + tree_nx.add_edges_from(edges) + tree = Tree(tree_nx) + for node in tree.nodes(): + tree.set_age(node, node_age[node]) + tree.set_edge_length_from_node_ages() + return tree diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index 78cd262d..4a302bff 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -237,3 +237,32 @@ def parent(self, v: int) -> int: incident_edges_at_v = [edge for edge in self.edges() if edge[1] == v] assert len(incident_edges_at_v) == 1 return incident_edges_at_v[0][0] + + def set_edge_length_from_node_ages(self) -> None: + r""" + Sets the edge lengths to match the node ages. + """ + for (parent, child) in self.edges(): + self.set_edge_length( + parent, child, self.get_age(parent) - self.get_age(child) + ) + + def length(self) -> float: + r""" + Total length of the tree + """ + res = 0 + for (parent, child) in self.edges(): + res += self.get_edge_length(parent, child) + return res + + def num_ancestors(self, node: int) -> int: + r""" + Number of ancestors of a node. Terribly inefficient implementation. 
+ """ + res = 0 + root = self.root() + while node != root: + node = self.parent(node) + res += 1 + return res diff --git a/test/tools_tests/lineage_simulator_test.py b/test/tools_tests/lineage_simulator_test.py index d74b1f49..9a5218e0 100644 --- a/test/tools_tests/lineage_simulator_test.py +++ b/test/tools_tests/lineage_simulator_test.py @@ -1,4 +1,10 @@ -from cassiopeia.tools import PerfectBinaryTree, PerfectBinaryTreeWithRootBranch +import numpy as np + +from cassiopeia.tools import ( + PerfectBinaryTree, + PerfectBinaryTreeWithRootBranch, + BirthProcess, +) def test_PerfectBinaryTree(): @@ -15,3 +21,26 @@ def test_PerfectBinaryTreeWithRootBranch(): ).simulate_lineage() newick = tree.to_newick_tree_format(print_internal_nodes=True) assert newick == "(((4:4,5:4)2:3,(6:4,7:4)3:3)1:2)0);" + + +def test_BirthProcess(): + r""" + Generate tree, then choose a random lineage can count how many nodes are on + the lineage. This is the number of times the process triggered on that + lineage. + """ + np.random.seed(1) + birth_rate = 0.6 + intensities = [] + for _ in range(10000): + tree_true = BirthProcess( + birth_rate=birth_rate, tree_depth=1.0 + ).simulate_lineage() + leaf = np.random.choice(tree_true.leaves()) + n_leaves = len(tree_true.leaves()) + n_hits = tree_true.num_ancestors(leaf) - 1 + intensity = n_leaves / 2 ** n_hits * n_hits + intensities.append(intensity) + inferred_birth_rate = np.array(intensities).mean() + print(f"{birth_rate} == {inferred_birth_rate}") + assert np.abs(birth_rate - inferred_birth_rate) < 0.05 From 1cdd4025b7c7e9cf791844ebb7edd65bfabb0348 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Tue, 5 Jan 2021 19:05:49 -0800 Subject: [PATCH 24/61] more testing --- test/tools_tests/lineage_simulator_test.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/test/tools_tests/lineage_simulator_test.py b/test/tools_tests/lineage_simulator_test.py index 9a5218e0..903dea41 
100644 --- a/test/tools_tests/lineage_simulator_test.py +++ b/test/tools_tests/lineage_simulator_test.py @@ -28,19 +28,32 @@ def test_BirthProcess(): Generate tree, then choose a random lineage can count how many nodes are on the lineage. This is the number of times the process triggered on that lineage. + + Also, the probability that a tree with only one internal node is obtained is + e^-lam * (1 - e^-lam) where lam is the birth rate, so we also check this. """ np.random.seed(1) birth_rate = 0.6 intensities = [] - for _ in range(10000): + repetitions = 10000 + topology_hits = 0 + for _ in range(repetitions): tree_true = BirthProcess( birth_rate=birth_rate, tree_depth=1.0 ).simulate_lineage() + if len(tree_true.nodes()) == 4: + topology_hits += 1 leaf = np.random.choice(tree_true.leaves()) n_leaves = len(tree_true.leaves()) n_hits = tree_true.num_ancestors(leaf) - 1 intensity = n_leaves / 2 ** n_hits * n_hits intensities.append(intensity) + # Check that the probability of the topology matches + empirical_topology_prob = topology_hits / repetitions + theoretical_topology_prob = np.exp(-birth_rate) * ( + 1.0 - np.exp(-birth_rate) + ) + assert np.abs(empirical_topology_prob - theoretical_topology_prob) < 0.02 inferred_birth_rate = np.array(intensities).mean() print(f"{birth_rate} == {inferred_birth_rate}") assert np.abs(birth_rate - inferred_birth_rate) < 0.05 From 23fec63c14a87cc1d255f80a75e3c9d7684e43fe Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Tue, 5 Jan 2021 22:38:53 -0800 Subject: [PATCH 25/61] Posterior calibration test --- .../branch_length_estimator_test.py | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index 529aea9d..6371d555 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -11,6 +11,19 @@ Tree, ) +from copy 
import deepcopy +import itertools +import matplotlib.pyplot as plt +import multiprocessing +import numpy as np + +from cassiopeia.tools import ( + Tree, + BirthProcess, + IIDExponentialLineageTracer, + IIDExponentialPosteriorMeanBLE, +) + def test_no_mutations(): r""" @@ -659,3 +672,84 @@ def test_IIDExponentialPosteriorMeanBLEGridSeachCV(): np.testing.assert_almost_equal(model.posterior_means[1], 0.3184, decimal=3) np.testing.assert_almost_equal(model.mutation_rate, 0.75) np.testing.assert_almost_equal(model.birth_rate, 0.5) + + +def get_z_scores( + repetition, + birth_rate_true, + mutation_rate_true, + birth_rate_model, + mutation_rate_model, + num_characters, +): + np.random.seed(repetition) + tree = BirthProcess( + birth_rate=birth_rate_true, tree_depth=1.0 + ).simulate_lineage() + tree_true = deepcopy(tree) + IIDExponentialLineageTracer( + mutation_rate=mutation_rate_true, num_characters=num_characters + ).overlay_lineage_tracing_data(tree) + discretization_level = 100 + model = IIDExponentialPosteriorMeanBLE( + birth_rate=birth_rate_model, + mutation_rate=mutation_rate_model, + discretization_level=discretization_level, + ) + model.estimate_branch_lengths(tree) + z_scores = [] + if len(tree.internal_nodes()) > 0: + for node in [np.random.choice(tree.internal_nodes())]: + true_age = tree_true.get_age(node) + z_score = model.posteriors[node][ + : int(true_age * discretization_level) + ].sum() + z_scores.append(z_score) + return z_scores + + +def get_z_scores_under_true_model(repetition): + return get_z_scores( + repetition, + birth_rate_true=0.8, + mutation_rate_true=1.2, + birth_rate_model=0.8, + mutation_rate_model=1.2, + num_characters=3, + ) + + +def get_z_scores_under_misspecified_model(repetition): + return get_z_scores( + repetition, + birth_rate_true=0.4, + mutation_rate_true=0.6, + birth_rate_model=0.8, + mutation_rate_model=1.2, + num_characters=3, + ) + + +@pytest.mark.slow +def test_IIDExponentialPosteriorMeanBLE_posterior_calibration(): + 
repetitions = 1000 + + # Under the true model, the Z scores should be ~Unif[0, 1] + with multiprocessing.Pool(processes=6) as pool: + z_scores = pool.map(get_z_scores_under_true_model, range(repetitions)) + z_scores = np.array(list(itertools.chain(*z_scores))) + mean_z_score = z_scores.mean() + p_value = 2 * np.exp(-2 * repetitions * (mean_z_score - 0.5) ** 2) + print(f"p_value = {p_value}") + assert p_value > 0.01 + + # Under the wrong model, the Z scores should not be ~Unif[0, 1] + with multiprocessing.Pool(processes=6) as pool: + z_scores = pool.map( + get_z_scores_under_misspecified_model, range(repetitions) + ) + z_scores = np.array(list(itertools.chain(*z_scores))) + mean_z_score = z_scores.mean() + p_value = 2 * np.exp(-2 * repetitions * (mean_z_score - 0.5) ** 2) + print(f"p_value = {p_value}") + assert p_value < 0.01 From 73a52e1e24e792a41094bf19aedcb56334339509 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Tue, 5 Jan 2021 22:47:32 -0800 Subject: [PATCH 26/61] Doc test. Enable slowtests. 
--- conftest.py | 21 ++++ .../branch_length_estimator_test.py | 110 +++++++++--------- 2 files changed, 77 insertions(+), 54 deletions(-) create mode 100644 conftest.py diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..e446d0a1 --- /dev/null +++ b/conftest.py @@ -0,0 +1,21 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--runslow", action="store_true", default=False, help="run slow tests" + ) + + +def pytest_configure(config): + config.addinivalue_line("markers", "slow: mark test as slow to run") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--runslow"): + # --runslow given in cli: do not skip slow tests + return + skip_slow = pytest.mark.skip(reason="need --runslow option to run") + for item in items: + if "slow" in item.keywords: + item.add_marker(skip_slow) diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index 6371d555..0fddb33a 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -1,8 +1,13 @@ +import itertools +import multiprocessing +from copy import deepcopy + import networkx as nx import numpy as np import pytest from cassiopeia.tools import ( + BirthProcess, IIDExponentialBLE, IIDExponentialBLEGridSearchCV, IIDExponentialLineageTracer, @@ -11,19 +16,6 @@ Tree, ) -from copy import deepcopy -import itertools -import matplotlib.pyplot as plt -import multiprocessing -import numpy as np - -from cassiopeia.tools import ( - Tree, - BirthProcess, - IIDExponentialLineageTracer, - IIDExponentialPosteriorMeanBLE, -) - def test_no_mutations(): r""" @@ -438,10 +430,12 @@ def test_IIDExponentialBLEGridSearchCV(): def test_IIDExponentialPosteriorMeanBLE(): r""" - TODO + For a small tree with only one internal node, the likelihood of the data, + and the posterior age of the internal node, can be computed easily in + closed form. 
We check the theoretical values against those obtained from + our model. """ - from scipy.special import binom - from scipy.special import logsumexp + from scipy.special import binom, logsumexp tree = nx.DiGraph() tree.add_nodes_from([0, 1, 2, 3]) @@ -471,11 +465,13 @@ def uncuts(parent, child): zeros_child = tree.get_state(child).count("0") return zeros_child - def analytical_log_joint(t): + def analytical_log_joint(age): r""" - when node 1 has age t, i.e. hangs at distance 1.0 - t from the root. + Here t is the age of the internal node. """ - t = 1.0 - t + # Originally I did the math using the distance t of the internal node + # from the root, which is t = 1 - age. + t = 1.0 - age if t == 0 or t == 1: return -np.inf e = np.exp @@ -517,20 +513,20 @@ def analytical_log_joint(t): # age of vertex 1. model_log_joints = model.log_joints[ 1 - ] # P(t_1 = t, X, T) where t_1 is the age of the first node. + ] # log P(t_1 = t, X, T) where t_1 is the age of the first node. model_log_likelihood_2 = logsumexp(model_log_joints) print(f"{model_log_likelihood_2} = {model_log_likelihood_2}") np.testing.assert_approx_equal( model.log_likelihood, model_log_likelihood_2, significant=3 ) - # Test the model log likelihood vs its computation from a leaf node. - leaf = 2 - model_log_likelihood_up = model.up(leaf, 0, tree.num_cuts(leaf)) - print(f"{model_log_likelihood_up} = model_log_likelihood_up") - np.testing.assert_approx_equal( - model.log_likelihood, model_log_likelihood_up, significant=3 - ) + # Test the model log likelihood vs its computation from the leaf nodes. 
+ for leaf in [2, 3]: + model_log_likelihood_up = model.up(leaf, 0, tree.num_cuts(leaf)) + print(f"{model_log_likelihood_up} = model_log_likelihood_up") + np.testing.assert_approx_equal( + model.log_likelihood, model_log_likelihood_up, significant=3 + ) # Test the model log likelihood against its analytic computation analytical_log_joints = np.array( @@ -545,27 +541,25 @@ def analytical_log_joint(t): np.testing.assert_approx_equal( model.log_likelihood, analytical_log_likelihood, significant=3 ) - + # Test the _whole_ array of log joints P(t_v = t, X, T) np.testing.assert_array_almost_equal( - analytical_log_joints[50:150], model.log_joints[1][50:150], decimal=1 + analytical_log_joints[50:-50], model.log_joints[1][50:-50], decimal=1 ) - # import matplotlib.pyplot as plt - - # plt.plot(model.posteriors[1]) - # plt.show() - # print(model.posterior_means[1]) - - # Analytical posterior + # Test the model posterior against its analytic posterior analytical_posterior = np.exp( analytical_log_joints - analytical_log_joints.max() ) analytical_posterior /= analytical_posterior.sum() + # import matplotlib.pyplot as plt + # plt.plot(model.posteriors[1]) + # plt.show() # plt.plot(analytical_posterior) # plt.show() total_variation = np.sum(np.abs(analytical_posterior - model.posteriors[1])) assert total_variation < 0.03 + # Test the posterior mean against the analytical posterior mean. analytical_posterior_mean = np.sum( analytical_posterior * np.array(range(discretization_level + 1)) @@ -578,6 +572,11 @@ def analytical_log_joint(t): def test_IIDExponentialPosteriorMeanBLE_2(): + r""" + We run the Bayesian estimator on a small tree with all different leaves, + and then check that the likelihood of the data, computed from all of the + leaves, is the same. 
+ """ tree = nx.DiGraph() tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) @@ -619,16 +618,11 @@ def test_IIDExponentialPosteriorMeanBLE_2(): significant=3, ) - # import matplotlib.pyplot as plt - - # plt.plot(model.posteriors[1]) - # plt.show() - # print(model.posterior_means[1]) - def test_IIDExponentialPosteriorMeanBLEGridSeachCV(): - # This is same tree as test_subtree_collapses_when_no_mutations. Should no - # longer collapse! + r""" + We just check that the grid search estimator does its job on a small grid. + """ tree = nx.DiGraph() tree.add_nodes_from([0, 1, 2, 3, 4]), tree.add_edges_from([(0, 1), (1, 2), (1, 3), (0, 4)]) @@ -651,10 +645,8 @@ def test_IIDExponentialPosteriorMeanBLEGridSeachCV(): model.estimate_branch_lengths(tree) - # import seaborn as sns - # import matplotlib.pyplot as plt - + # import seaborn as sns # sns.heatmap( # model.grid, # yticklabels=mutation_rates, @@ -664,11 +656,6 @@ def test_IIDExponentialPosteriorMeanBLEGridSeachCV(): # plt.xlabel('Birth Rate') # plt.show() - # import matplotlib.pyplot as plt - # plt.plot(model.posteriors[1]) - # plt.show() - # print(model.posterior_means[1]) - np.testing.assert_almost_equal(model.posterior_means[1], 0.3184, decimal=3) np.testing.assert_almost_equal(model.mutation_rate, 0.75) np.testing.assert_almost_equal(model.birth_rate, 0.5) @@ -732,6 +719,15 @@ def get_z_scores_under_misspecified_model(repetition): @pytest.mark.slow def test_IIDExponentialPosteriorMeanBLE_posterior_calibration(): + r""" + Under the true model, the Z scores should be ~Unif[0, 1] + Under the wrong model, the Z scores should not be ~Unif[0, 1] + This test is slow because we need to make many repetitions to get + enough statistical power for the test to be meaningful. + We use p-values computed from the Hoeffding bound. + TODO: There might be a more powerful test, e.g. Kolmogorov–Smirnov? 
+ (This would mean we need less repetitions and can make the test faster.) + """ repetitions = 1000 # Under the true model, the Z scores should be ~Unif[0, 1] @@ -740,8 +736,11 @@ def test_IIDExponentialPosteriorMeanBLE_posterior_calibration(): z_scores = np.array(list(itertools.chain(*z_scores))) mean_z_score = z_scores.mean() p_value = 2 * np.exp(-2 * repetitions * (mean_z_score - 0.5) ** 2) - print(f"p_value = {p_value}") + print(f"p_value under true model = {p_value}") assert p_value > 0.01 + # import matplotlib.pyplot as plt + # plt.hist(z_scores, bins=10) + # plt.show() # Under the wrong model, the Z scores should not be ~Unif[0, 1] with multiprocessing.Pool(processes=6) as pool: @@ -751,5 +750,8 @@ def test_IIDExponentialPosteriorMeanBLE_posterior_calibration(): z_scores = np.array(list(itertools.chain(*z_scores))) mean_z_score = z_scores.mean() p_value = 2 * np.exp(-2 * repetitions * (mean_z_score - 0.5) ** 2) - print(f"p_value = {p_value}") + print(f"p_value under misspecified model = {p_value}") assert p_value < 0.01 + # import matplotlib.pyplot as plt + # plt.hist(z_scores, bins=10) + # plt.show() From a3b87a9b85c05ca77b778724cb4d904986ba40df Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Wed, 6 Jan 2021 10:54:09 -0800 Subject: [PATCH 27/61] Multiprocessing in grid search (for IIDExponentialPosteriorMeanBLEGridSearchCV only...). Allow breaking parsimony in IIDExponentialPosteriorMeanBLE. 
--- .../IIDExponentialPosteriorMeanBLE.py | 155 ++++++++++++++---- .../branch_length_estimator_test.py | 4 +- 2 files changed, 123 insertions(+), 36 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index bdb671e5..f7f3bb51 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -1,5 +1,7 @@ from typing import Tuple +from copy import deepcopy +import multiprocessing import numpy as np from scipy.special import logsumexp @@ -27,6 +29,7 @@ class IIDExponentialPosteriorMeanBLE(BranchLengthEstimator): mutation_rate: TODO birth_rate: TODO discretization_level: TODO + enforce_parsimony: TODO verbose: Verbosity level. TODO Attributes: TODO @@ -34,46 +37,69 @@ class IIDExponentialPosteriorMeanBLE(BranchLengthEstimator): """ def __init__( - self, mutation_rate: float, birth_rate: float, discretization_level: int + self, + mutation_rate: float, + birth_rate: float, + discretization_level: int, + enforce_parsimony: bool = True, ) -> None: # TODO: If we use autograd, we can tune the hyperparams with gradient - # descent. + # descent? self.mutation_rate = mutation_rate # TODO: Is there some easy heuristic way to set this to a reasonable # value and thus avoid grid searching it / optimizing it? self.birth_rate = birth_rate self.discretization_level = discretization_level + self.enforce_parsimony = enforce_parsimony - def estimate_branch_lengths(self, tree: Tree) -> None: - r""" - See base class. - """ + def _compute_log_likelihood(self): + tree = self.tree discretization_level = self.discretization_level - self.down_cache = {} # TODO: Rename to _down_cache - self.up_cache = {} # TODO: Rename to _up_cache - self.tree = tree log_likelihood = 0 # TODO: Should I also add a division event when the root has multiple - # children? + # children? 
(If not, the joint we are computing won't integrate to 1; + # on the other hand, this is a constant multiplicative term that doesn't + # affect inference. for child_of_root in tree.children(tree.root()): log_likelihood += self.down(child_of_root, discretization_level, 0) self.log_likelihood = log_likelihood - # # # # # Compute Posteriors # # # # # + + def _compute_log_joint(self, v, t): + r""" + P(t_v = t, X, T). + Dependind on whether we are enforcing parsimony or not, we consider + different possible number of cuts for v. + """ + discretization_level = self.discretization_level + tree = self.tree + lam = self.birth_rate + enforce_parsimony = self.enforce_parsimony + dt = 1.0 / discretization_level + children = tree.children(v) + if enforce_parsimony: + valid_num_cuts = [tree.num_cuts(v)] + else: + valid_num_cuts = range(tree.num_cuts(v) + 1) + ll_for_x = [] + for x in valid_num_cuts: + ll_for_x.append( + sum([self.down(u, t, x) for u in children]) + + self.up(v, t, x) + + np.log(lam * dt) + ) + return logsumexp(ll_for_x) + + def _compute_posteriors(self): + tree = self.tree + discretization_level = self.discretization_level log_joints = {} # log P(t_v = t, X, T) posteriors = {} # P(t_v = t | X, T) posterior_means = {} # E[t_v = t | X, T] - lam = self.birth_rate - dt = 1.0 / discretization_level for v in tree.internal_nodes(): # Compute the posterior for this node log_joint = np.zeros(shape=(discretization_level + 1,)) for t in range(discretization_level + 1): - children = tree.children(v) - log_joint[t] = ( - sum([self.down(u, t, tree.num_cuts(v)) for u in children]) - + self.up(v, t, tree.num_cuts(v)) - + np.log(lam * dt) - ) + log_joint[t] = self._compute_log_joint(v, t) log_joints[v] = log_joint.copy() posterior = np.exp(log_joint - log_joint.max()) posterior /= np.sum(posterior) @@ -84,20 +110,40 @@ def estimate_branch_lengths(self, tree: Tree) -> None: self.log_joints = log_joints self.posteriors = posteriors self.posterior_means = posterior_means - # # # # 
# Populate the tree with the estimated branch lengths # # # # # + + def _populate_branch_lengths(self): + tree = self.tree + posterior_means = self.posterior_means for node in tree.internal_nodes(): tree.set_age(node, age=posterior_means[node]) tree.set_age(tree.root(), age=1.0) for leaf in tree.leaves(): tree.set_age(leaf, age=0.0) - for (parent, child) in tree.edges(): new_edge_length = tree.get_age(parent) - tree.get_age(child) tree.set_edge_length(parent, child, length=new_edge_length) + def estimate_branch_lengths(self, tree: Tree) -> None: + r""" + See base class. + """ + self.down_cache = {} # TODO: Rename to _down_cache + self.up_cache = {} # TODO: Rename to _up_cache + self.tree = tree + self._compute_log_likelihood() + self._compute_posteriors() + self._populate_branch_lengths() + + def compatible_with_observed_data(self, x, observed_cuts) -> bool: + # TODO: Make method private + if self.enforce_parsimony: + return x == observed_cuts + else: + return x <= observed_cuts + def up(self, v, t, x) -> float: r""" - TODO: Rename this _up. + TODO: Rename this _up? log P(X_up(b(v)), T_up(b(v)), t \in t_b(v), X_b(v)(t) = x) """ if (v, t, x) in self.up_cache: @@ -114,6 +160,7 @@ def up(self, v, t, x) -> float: assert 0 <= x <= K log_likelihood = 0.0 if v == tree.root(): # Base case: we reached the root of the tree. + # TODO: 'tree.root()' is O(n). We should have O(1) method. if t == discretization_level and x == tree.num_cuts(v): log_likelihood = 0.0 else: @@ -136,8 +183,9 @@ def up(self, v, t, x) -> float: ) # Case 3: A cell division happened if v != tree.root(): + # TODO: 'tree.root()' is O(n). We should have O(1) method. p = tree.parent(v) - if x == tree.num_cuts(p): + if self.compatible_with_observed_data(x, tree.num_cuts(p)): siblings = [u for u in tree.children(p) if u != v] ll = ( np.log(lam * dt) @@ -145,6 +193,8 @@ def up(self, v, t, x) -> float: + sum([self.down(u, t, x) for u in siblings]) ) if p == tree.root(): # The branch start is for free! 
+ # TODO: 'tree.root()' is O(n). We should have O(1) + # method. ll -= np.log(lam * dt) log_likelihoods_cases.append(ll) log_likelihood = logsumexp(log_likelihoods_cases) @@ -153,7 +203,7 @@ def up(self, v, t, x) -> float: def down(self, v, t, x) -> float: r""" - TODO: Rename this _down. + TODO: Rename this _down? log P(X_down(v), T_down(v) | t_v = t, X_v = x) """ if (v, t, x) in self.down_cache: @@ -171,6 +221,7 @@ def down(self, v, t, x) -> float: log_likelihood = 0.0 if t == 0: # Base case if v in tree.leaves() and x == tree.num_cuts(v): + # TODO: 'v not in tree.leaves()' is O(n). We should have O(1) check. log_likelihood = 0.0 else: log_likelihood = -np.inf @@ -190,7 +241,11 @@ def down(self, v, t, x) -> float: # The number of cuts at this state must match the ground truth. # TODO: Allow for weak match at internal nodes and exact match at # leaves. - if x == tree.num_cuts(v) and v not in tree.leaves(): + if ( + self.compatible_with_observed_data(x, tree.num_cuts(v)) + and v not in tree.leaves() + ): + # TODO: 'v not in tree.leaves()' is O(n). We should have O(1) check. ll = sum( [self.down(child, t - 1, x) for child in tree.children(v)] ) + np.log(lam * dt) @@ -200,6 +255,17 @@ def down(self, v, t, x) -> float: return log_likelihood +def _fit_model(model_and_tree): + r""" + This is used by IIDExponentialPosteriorMeanBLEGridSearchCV to + parallelize the grid search. It must be defined here (at the top level of + the module) for multiprocessing to be able to pickle it. 
+ """ + model, tree = model_and_tree + model.estimate_branch_lengths(tree) + return model.log_likelihood + + class IIDExponentialPosteriorMeanBLEGridSearchCV(BranchLengthEstimator): r""" Like IIDExponentialPosteriorMeanBLE but with automatic tuning of @@ -220,11 +286,15 @@ def __init__( mutation_rates: Tuple[float] = (0,), birth_rates: Tuple[float] = (0,), discretization_level: int = 1000, + enforce_parsimony: bool = True, + processes: int = 6, verbose: bool = False, ): self.mutation_rates = mutation_rates self.birth_rates = birth_rates self.discretization_level = discretization_level + self.enforce_parsimony = enforce_parsimony + self.processes = processes self.verbose = verbose def estimate_branch_lengths(self, tree: Tree) -> None: @@ -234,9 +304,14 @@ def estimate_branch_lengths(self, tree: Tree) -> None: mutation_rates = self.mutation_rates birth_rates = self.birth_rates discretization_level = self.discretization_level + enforce_parsimony = self.enforce_parsimony + processes = self.processes verbose = self.verbose lls = [] grid = np.zeros(shape=(len(mutation_rates), len(birth_rates))) + models = [] + mutation_and_birth_rates = [] + ijs = [] for i, mutation_rate in enumerate(mutation_rates): for j, birth_rate in enumerate(birth_rates): if self.verbose: @@ -245,17 +320,26 @@ def estimate_branch_lengths(self, tree: Tree) -> None: f"mutation_rate={mutation_rate}\n" f"birth_rate={birth_rate}" ) - model = IIDExponentialPosteriorMeanBLE( - mutation_rate=mutation_rate, - birth_rate=birth_rate, - discretization_level=discretization_level, + models.append( + IIDExponentialPosteriorMeanBLE( + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + enforce_parsimony=enforce_parsimony, + ) ) - model.estimate_branch_lengths(tree) - ll = model.log_likelihood - lls.append((ll, (mutation_rate, birth_rate))) - grid[i, j] = ll - lls.sort(reverse=True) - (best_mutation_rate, best_birth_rate,) = lls[ + 
mutation_and_birth_rates.append((mutation_rate, birth_rate)) + ijs.append((i, j)) + with multiprocessing.Pool(processes=processes) as pool: + lls = pool.map( + _fit_model, + zip(models, [deepcopy(tree) for _ in range(len(models))]), + ) + lls_and_rates = list(zip(lls, mutation_and_birth_rates)) + for ll, (i, j) in list(zip(lls, ijs)): + grid[i, j] = ll + lls_and_rates.sort(reverse=True) + (best_mutation_rate, best_birth_rate,) = lls_and_rates[ 0 ][1] if verbose: @@ -268,6 +352,7 @@ def estimate_branch_lengths(self, tree: Tree) -> None: mutation_rate=best_mutation_rate, birth_rate=best_birth_rate, discretization_level=discretization_level, + enforce_parsimony=enforce_parsimony, ) final_model.estimate_branch_lengths(tree) self.mutation_rate = best_mutation_rate diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index 0fddb33a..f8ac8f77 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -434,6 +434,8 @@ def test_IIDExponentialPosteriorMeanBLE(): and the posterior age of the internal node, can be computed easily in closed form. We check the theoretical values against those obtained from our model. + TODO: Add a test with a tree with 2 internal nodes and check the model + against the 2D numerical integral. 
""" from scipy.special import binom, logsumexp @@ -656,9 +658,9 @@ def test_IIDExponentialPosteriorMeanBLEGridSeachCV(): # plt.xlabel('Birth Rate') # plt.show() - np.testing.assert_almost_equal(model.posterior_means[1], 0.3184, decimal=3) np.testing.assert_almost_equal(model.mutation_rate, 0.75) np.testing.assert_almost_equal(model.birth_rate, 0.5) + np.testing.assert_almost_equal(model.posterior_means[1], 0.3184, decimal=3) def get_z_scores( From e4cb1730446c4563bd007578830dda5008d56eb9 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Wed, 6 Jan 2021 12:44:52 -0800 Subject: [PATCH 28/61] Add joint log lokelihood computation classmethod to IIDExponentialPosteriorMeanBLE --- .../IIDExponentialPosteriorMeanBLE.py | 38 ++++++++++++-- cassiopeia/tools/tree.py | 19 +++++-- .../branch_length_estimator_test.py | 49 +++---------------- 3 files changed, 57 insertions(+), 49 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index f7f3bb51..ee474a46 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -3,7 +3,7 @@ from copy import deepcopy import multiprocessing import numpy as np -from scipy.special import logsumexp +from scipy.special import binom, logsumexp from .BranchLengthEstimator import BranchLengthEstimator from ..tree import Tree @@ -221,7 +221,8 @@ def down(self, v, t, x) -> float: log_likelihood = 0.0 if t == 0: # Base case if v in tree.leaves() and x == tree.num_cuts(v): - # TODO: 'v not in tree.leaves()' is O(n). We should have O(1) check. + # TODO: 'v not in tree.leaves()' is O(n). We should have O(1) + # check. 
log_likelihood = 0.0 else: log_likelihood = -np.inf @@ -245,7 +246,8 @@ def down(self, v, t, x) -> float: self.compatible_with_observed_data(x, tree.num_cuts(v)) and v not in tree.leaves() ): - # TODO: 'v not in tree.leaves()' is O(n). We should have O(1) check. + # TODO: 'v not in tree.leaves()' is O(n). We should have O(1) + # check. ll = sum( [self.down(child, t - 1, x) for child in tree.children(v)] ) + np.log(lam * dt) @@ -254,6 +256,36 @@ def down(self, v, t, x) -> float: self.down_cache[(v, t, x)] = log_likelihood return log_likelihood + @classmethod + def joint_log_likelihood( + self, tree: Tree, mutation_rate: float, birth_rate: float + ) -> float: + r""" + log P(T, X, branch_lengths), i.e. the log likelihood given both + character vectors _and_ branch lengths. + """ + ll = 0.0 + lam = birth_rate + r = mutation_rate + lg = np.log + e = np.exp + b = binom + for (p, c) in tree.edges(): + t = tree.get_edge_length(p, c) + # Birth process likelihood + ll += -t * lam + if c not in tree.leaves(): + ll += lg(lam) + # Mutation process likelihood + cuts = tree.number_of_mutations_along_edge(p, c) + uncuts = tree.number_of_nonmutations_along_edge(p, c) + ll += ( + (-t * r) * uncuts + + lg(1 - e(-t * r)) * cuts + + lg(b(cuts + uncuts, cuts)) + ) + return ll + def _fit_model(model_and_tree): r""" diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index 4a302bff..69003088 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -142,11 +142,8 @@ def subtree_newick_representation(v: int) -> str: number_of_unmutated_characters_in_parent = self.get_state( v ).count("0") - number_of_mutations_along_edge = self.get_state(v).count( - "0" - ) - self.get_state(child).count("0") pct_of_mutated_characters_along_edge = ( - number_of_mutations_along_edge + self.number_of_mutations_along_edge(v, child) / (number_of_unmutated_characters_in_parent + 1e-100) ) subtree_newick = ( @@ -266,3 +263,17 @@ def num_ancestors(self, node: int) -> int: node = 
self.parent(node) res += 1 return res + + def number_of_mutations_along_edge(self, parent, child): + return self.get_state(parent).count("0") - self.get_state(child).count( + "0" + ) + + def number_of_nonmutations_along_edge(self, parent, child): + return self.get_state(child).count("0") + + def num_uncut(self, v): + return self.get_state(v).count("0") + + def num_cut(self, v): + return self.get_state(v).count("1") diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index f8ac8f77..a759b0fe 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -437,7 +437,7 @@ def test_IIDExponentialPosteriorMeanBLE(): TODO: Add a test with a tree with 2 internal nodes and check the model against the 2D numerical integral. """ - from scipy.special import binom, logsumexp + from scipy.special import logsumexp tree = nx.DiGraph() tree.add_nodes_from([0, 1, 2, 3]) @@ -457,16 +457,6 @@ def test_IIDExponentialPosteriorMeanBLE(): discretization_level=discretization_level, ) - def cuts(parent, child): - zeros_parent = tree.get_state(parent).count("0") - zeros_child = tree.get_state(child).count("0") - new_cuts_child = zeros_parent - zeros_child - return new_cuts_child - - def uncuts(parent, child): - zeros_child = tree.get_state(child).count("0") - return zeros_child - def analytical_log_joint(age): r""" Here t is the age of the internal node. 
@@ -476,37 +466,12 @@ def analytical_log_joint(age): t = 1.0 - age if t == 0 or t == 1: return -np.inf - e = np.exp - lg = np.log - lam = birth_rate - r = mutation_rate - res = 0.0 - res += ( - lg(lam) + -t * lam + -2 * (1.0 - t) * lam - ) # Tree topology likelihood - res += -t * r * uncuts(0, 1) + lg(1.0 - e(-t * r)) * cuts( - 0, 1 - ) # 0->1 edge likelihood - res += -(1.0 - t) * r * uncuts(1, 2) + lg( - 1.0 - e(-(1.0 - t) * r) - ) * cuts( - 1, 2 - ) # 1->2 edge likelihood - res += -(1.0 - t) * r * uncuts(1, 3) + lg( - 1.0 - e(-(1.0 - t) * r) - ) * cuts( - 1, 3 - ) # 1->3 edge likelihood - # Adjust by the grid size so we don't overestimate the bucket's - # probability. - res -= np.log(discretization_level) - # Finally, we need to account for repetitions - res += ( - np.log(binom(cuts(0, 1) + uncuts(0, 1), cuts(0, 1))) - + np.log(binom(cuts(1, 2) + uncuts(1, 2), cuts(1, 2))) - + np.log(binom(cuts(1, 3) + uncuts(1, 3), cuts(1, 3))) + tree_copy = deepcopy(tree) + tree_copy.set_age(1, age) + tree_copy.set_edge_length_from_node_ages() + return IIDExponentialPosteriorMeanBLE.joint_log_likelihood( + tree=tree_copy, mutation_rate=mutation_rate, birth_rate=birth_rate ) - return res model.estimate_branch_lengths(tree) print(f"{model.log_likelihood} = model.log_likelihood") @@ -533,7 +498,7 @@ def analytical_log_joint(age): # Test the model log likelihood against its analytic computation analytical_log_joints = np.array( [ - analytical_log_joint(t) + analytical_log_joint(t) - np.log(discretization_level) for t in np.array(range(discretization_level + 1)) / discretization_level ] From aac39777282b01af622ea218044d36ff1452d6de Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Wed, 6 Jan 2021 16:11:16 -0800 Subject: [PATCH 29/61] More numerical tests, better docs, better names --- .../IIDExponentialPosteriorMeanBLE.py | 140 +++++++++++- cassiopeia/tools/lineage_simulator.py | 2 +- cassiopeia/tools/tree.py | 2 +- 
.../branch_length_estimator_test.py | 214 ++++++++++++++---- 4 files changed, 308 insertions(+), 50 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index ee474a46..7ccc5e5d 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -1,12 +1,13 @@ +import multiprocessing +from copy import deepcopy from typing import Tuple -from copy import deepcopy -import multiprocessing import numpy as np +from scipy import integrate from scipy.special import binom, logsumexp -from .BranchLengthEstimator import BranchLengthEstimator from ..tree import Tree +from .BranchLengthEstimator import BranchLengthEstimator class IIDExponentialPosteriorMeanBLE(BranchLengthEstimator): @@ -257,13 +258,15 @@ def down(self, v, t, x) -> float: return log_likelihood @classmethod - def joint_log_likelihood( + def exact_log_full_joint( self, tree: Tree, mutation_rate: float, birth_rate: float ) -> float: r""" - log P(T, X, branch_lengths), i.e. the log likelihood given both - character vectors _and_ branch lengths. + log P(T, X, branch_lengths), i.e. the full joint log likelihood given + both character vectors _and_ branch lengths. """ + tree = deepcopy(tree) + tree.set_edge_lengths_from_node_ages() ll = 0.0 lam = birth_rate r = mutation_rate @@ -286,6 +289,131 @@ def joint_log_likelihood( ) return ll + @classmethod + def numerical_log_likelihood( + self, + tree: Tree, + mutation_rate: float, + birth_rate: float, + epsrel: float = 0.01, + ): + r""" + log P(T, X), i.e. the marginal log likelihood given _only_ tree + topology and character vectors (including those of internal nodes). + It is computed with a grid. 
+ """ + + tree = deepcopy(tree) + + def f(*args): + ages = args + for node, age in list(zip(tree.internal_nodes(), ages)): + tree.set_age(node, age) + for (p, c) in tree.edges(): + if tree.get_age(p) <= tree.get_age(c): + return 0.0 + tree.set_edge_lengths_from_node_ages() + return np.exp( + IIDExponentialPosteriorMeanBLE.exact_log_full_joint( + tree=tree, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + ) + ) + + return np.log( + integrate.nquad( + f, + [[0, 1]] * len(tree.internal_nodes()), + opts={"epsrel": epsrel}, + )[0] + ) + + @classmethod + def numerical_log_joint( + self, + tree: Tree, + node, + mutation_rate: float, + birth_rate: float, + discretization_level: int, + epsrel: float = 0.01, + ): + r""" + log P(t_node = t, X, T) for each t in the interval [0, 1] discretized + to the level discretization_level + """ + res = np.zeros(shape=(discretization_level + 1,)) + other_nodes = [n for n in tree.internal_nodes() if n != node] + + tree = deepcopy(tree) + + def f(*args): + ages = args + for other_node, age in list(zip(other_nodes, ages)): + tree.set_age(other_node, age) + for (p, c) in tree.edges(): + if tree.get_age(p) <= tree.get_age(c): + return 0.0 + tree.set_edge_lengths_from_node_ages() + return np.exp( + IIDExponentialPosteriorMeanBLE.exact_log_full_joint( + tree=tree, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + ) + ) + + for i in range(discretization_level + 1): + node_age = i / discretization_level + tree.set_age(node, node_age) + tree.set_edge_lengths_from_node_ages() + if len(other_nodes) == 0: + # There is nothing to integrate over. 
+ res[i] = self.exact_log_full_joint( + tree=tree, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + ) + res[i] -= np.log(discretization_level) + else: + res[i] = ( + np.log( + integrate.nquad( + f, + [[0, 1]] * (len(tree.internal_nodes()) - 1), + opts={"epsrel": epsrel}, + )[0] + ) + - np.log(discretization_level) + ) + + return res + + @classmethod + def numerical_posterior( + self, + tree: Tree, + node, + mutation_rate: float, + birth_rate: float, + discretization_level: int, + epsrel: float = 0.01, + ): + numerical_log_joint = self.numerical_log_joint( + tree=tree, + node=node, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + epsrel=epsrel, + ) + numerical_posterior = np.exp( + numerical_log_joint - numerical_log_joint.max() + ) + numerical_posterior /= numerical_posterior.sum() + return numerical_posterior + def _fit_model(model_and_tree): r""" diff --git a/cassiopeia/tools/lineage_simulator.py b/cassiopeia/tools/lineage_simulator.py index 9e3efddf..a5b792e0 100644 --- a/cassiopeia/tools/lineage_simulator.py +++ b/cassiopeia/tools/lineage_simulator.py @@ -170,5 +170,5 @@ def simulate_lineage(self) -> Tree: tree = Tree(tree_nx) for node in tree.nodes(): tree.set_age(node, node_age[node]) - tree.set_edge_length_from_node_ages() + tree.set_edge_lengths_from_node_ages() return tree diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index 69003088..732b1eb6 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -235,7 +235,7 @@ def parent(self, v: int) -> int: assert len(incident_edges_at_v) == 1 return incident_edges_at_v[0][0] - def set_edge_length_from_node_ages(self) -> None: + def set_edge_lengths_from_node_ages(self) -> None: r""" Sets the edge lengths to match the node ages. 
""" diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index a759b0fe..03b7ed41 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -434,8 +434,6 @@ def test_IIDExponentialPosteriorMeanBLE(): and the posterior age of the internal node, can be computed easily in closed form. We check the theoretical values against those obtained from our model. - TODO: Add a test with a tree with 2 internal nodes and check the model - against the 2D numerical integral. """ from scipy.special import logsumexp @@ -457,22 +455,6 @@ def test_IIDExponentialPosteriorMeanBLE(): discretization_level=discretization_level, ) - def analytical_log_joint(age): - r""" - Here t is the age of the internal node. - """ - # Originally I did the math using the distance t of the internal node - # from the root, which is t = 1 - age. - t = 1.0 - age - if t == 0 or t == 1: - return -np.inf - tree_copy = deepcopy(tree) - tree_copy.set_age(1, age) - tree_copy.set_edge_length_from_node_ages() - return IIDExponentialPosteriorMeanBLE.joint_log_likelihood( - tree=tree_copy, mutation_rate=mutation_rate, birth_rate=birth_rate - ) - model.estimate_branch_lengths(tree) print(f"{model.log_likelihood} = model.log_likelihood") @@ -495,54 +477,65 @@ def analytical_log_joint(age): model.log_likelihood, model_log_likelihood_up, significant=3 ) - # Test the model log likelihood against its analytic computation - analytical_log_joints = np.array( - [ - analytical_log_joint(t) - np.log(discretization_level) - for t in np.array(range(discretization_level + 1)) - / discretization_level - ] + # Test the model log likelihood against its numerical computation + numerical_log_likelihood = ( + IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( + tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + ) ) - analytical_log_likelihood = logsumexp(analytical_log_joints) - 
print(f"{analytical_log_likelihood} = analytical_log_likelihood") + print(f"{numerical_log_likelihood} = numerical_log_likelihood") np.testing.assert_approx_equal( - model.log_likelihood, analytical_log_likelihood, significant=3 + model.log_likelihood, numerical_log_likelihood, significant=3 + ) + + # Test the _whole_ array of log joints P(t_v = t, X, T) against its + # numerical computation + numerical_log_joint = IIDExponentialPosteriorMeanBLE.numerical_log_joint( + tree=tree, + node=1, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, ) - # Test the _whole_ array of log joints P(t_v = t, X, T) np.testing.assert_array_almost_equal( - analytical_log_joints[50:-50], model.log_joints[1][50:-50], decimal=1 + model.log_joints[1][50:-50], numerical_log_joint[50:-50], decimal=1 ) - # Test the model posterior against its analytic posterior - analytical_posterior = np.exp( - analytical_log_joints - analytical_log_joints.max() + # Test the model posterior against its numerical posterior + numerical_posterior = IIDExponentialPosteriorMeanBLE.numerical_posterior( + tree=tree, + node=1, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, ) - analytical_posterior /= analytical_posterior.sum() # import matplotlib.pyplot as plt # plt.plot(model.posteriors[1]) # plt.show() - # plt.plot(analytical_posterior) + # plt.plot(numerical_posterior) # plt.show() - total_variation = np.sum(np.abs(analytical_posterior - model.posteriors[1])) + total_variation = np.sum(np.abs(model.posteriors[1] - numerical_posterior)) assert total_variation < 0.03 - # Test the posterior mean against the analytical posterior mean. - analytical_posterior_mean = np.sum( - analytical_posterior + # Test the posterior mean against the numerical posterior mean. 
+ numerical_posterior_mean = np.sum( + numerical_posterior * np.array(range(discretization_level + 1)) / discretization_level ) posterior_mean = tree.get_age(1) np.testing.assert_approx_equal( - posterior_mean, analytical_posterior_mean, significant=2 + posterior_mean, numerical_posterior_mean, significant=2 ) def test_IIDExponentialPosteriorMeanBLE_2(): r""" We run the Bayesian estimator on a small tree with all different leaves, - and then check that the likelihood of the data, computed from all of the - leaves, is the same. + and then check that: + - The likelihood of the data, computed from all of the leaves, is the same. + - The posteriors of the internal node ages matches their numerical + counterpart. """ tree = nx.DiGraph() tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), @@ -568,6 +561,17 @@ def test_IIDExponentialPosteriorMeanBLE_2(): model.estimate_branch_lengths(tree) print(model.log_likelihood) + # Test the model log likelihood against its numerical computation + numerical_log_likelihood = ( + IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( + tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + ) + ) + np.testing.assert_approx_equal( + model.log_likelihood, numerical_log_likelihood, significant=3 + ) + + # Check that the likelihood computed from each leaf node is correct. for leaf in tree.leaves(): model_log_likelihood_up = model.up(leaf, 0, tree.num_cuts(leaf)) print(model_log_likelihood_up) @@ -585,6 +589,121 @@ def test_IIDExponentialPosteriorMeanBLE_2(): significant=3, ) + # Check that the posterior ages of the nodes are correct. 
+ for node in tree.internal_nodes(): + numerical_log_joint = ( + IIDExponentialPosteriorMeanBLE.numerical_log_joint( + tree=tree, + node=node, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + ) + np.testing.assert_array_almost_equal( + model.log_joints[node][25:-25], + numerical_log_joint[25:-25], + decimal=1, + ) + + # Test the model posterior against its numerical posterior. + numerical_posterior = np.exp( + numerical_log_joint - numerical_log_joint.max() + ) + numerical_posterior /= numerical_posterior.sum() + # import matplotlib.pyplot as plt + # plt.plot(model.posteriors[node]) + # plt.show() + # plt.plot(analytical_posterior) + # plt.show() + total_variation = np.sum( + np.abs(model.posteriors[node] - numerical_posterior) + ) + assert total_variation < 0.03 + + +@pytest.mark.slow +def test_IIDExponentialPosteriorMeanBLE_3(): + r""" + Same as test_IIDExponentialPosteriorMeanBLE_2 but with a weirder topology. + """ + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7]), + tree.add_edges_from( + [(0, 1), (1, 2), (1, 3), (2, 4), (2, 5), (2, 6), (0, 7)] + ) + tree.nodes[0]["characters"] = "00" + tree.nodes[1]["characters"] = "00" + tree.nodes[2]["characters"] = "10" + tree.nodes[3]["characters"] = "11" + tree.nodes[4]["characters"] = "10" + tree.nodes[5]["characters"] = "10" + tree.nodes[6]["characters"] = "11" + tree.nodes[7]["characters"] = "00" + tree = Tree(tree) + + mutation_rate = 0.625 + birth_rate = 0.75 + discretization_level = 100 + model = IIDExponentialPosteriorMeanBLE( + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + + model.estimate_branch_lengths(tree) + print(model.log_likelihood) + + # Test the model log likelihood against its numerical computation + numerical_log_likelihood = ( + IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( + tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + ) + ) + 
np.testing.assert_approx_equal( + model.log_likelihood, numerical_log_likelihood, significant=3 + ) + + # Check that the likelihood computed from each leaf node is correct. + for leaf in tree.leaves(): + model_log_likelihood_up = model.up(leaf, 0, tree.num_cuts(leaf)) + print(model_log_likelihood_up) + np.testing.assert_approx_equal( + model.log_likelihood, model_log_likelihood_up, significant=2 + ) + + # Check that the posterior ages of the nodes are correct. + for node in tree.internal_nodes(): + numerical_log_joint = ( + IIDExponentialPosteriorMeanBLE.numerical_log_joint( + tree=tree, + node=node, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + ) + np.testing.assert_array_almost_equal( + model.log_joints[node][25:-25], + numerical_log_joint[25:-25], + decimal=1, + ) + + # Test the model posterior against its numerical posterior. + numerical_posterior = np.exp( + numerical_log_joint - numerical_log_joint.max() + ) + numerical_posterior /= numerical_posterior.sum() + # import matplotlib.pyplot as plt + # plt.plot(model.posteriors[node]) + # plt.show() + # plt.plot(numerical_posterior) + # plt.show() + total_variation = np.sum( + np.abs(model.posteriors[node] - numerical_posterior) + ) + assert total_variation < 0.03 + def test_IIDExponentialPosteriorMeanBLEGridSeachCV(): r""" @@ -610,7 +729,18 @@ def test_IIDExponentialPosteriorMeanBLEGridSeachCV(): verbose=True, ) + # Test the model log likelihood against its numerical computation model.estimate_branch_lengths(tree) + numerical_log_likelihood = ( + IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( + tree=tree, + mutation_rate=model.mutation_rate, + birth_rate=model.birth_rate, + ) + ) + np.testing.assert_approx_equal( + model.log_likelihood, numerical_log_likelihood, significant=3 + ) # import matplotlib.pyplot as plt # import seaborn as sns From ac0994f89812043ad2841ad97bdeecc15312810f Mon Sep 17 00:00:00 2001 From: Sebastian Prillo 
<10426884+sprillo@users.noreply.github.com> Date: Wed, 6 Jan 2021 16:28:06 -0800 Subject: [PATCH 30/61] assert --- .../IIDExponentialPosteriorMeanBLE.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index 7ccc5e5d..17372210 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -68,11 +68,12 @@ def _compute_log_likelihood(self): def _compute_log_joint(self, v, t): r""" P(t_v = t, X, T). - Dependind on whether we are enforcing parsimony or not, we consider + Depending on whether we are enforcing parsimony or not, we consider different possible number of cuts for v. """ discretization_level = self.discretization_level tree = self.tree + assert v in tree.internal_nodes() lam = self.birth_rate enforce_parsimony = self.enforce_parsimony dt = 1.0 / discretization_level @@ -419,7 +420,8 @@ def _fit_model(model_and_tree): r""" This is used by IIDExponentialPosteriorMeanBLEGridSearchCV to parallelize the grid search. It must be defined here (at the top level of - the module) for multiprocessing to be able to pickle it. + the module) for multiprocessing to be able to pickle it. (This is why + coverage misses it) """ model, tree = model_and_tree model.estimate_branch_lengths(tree) From 62ff0c3c8a7a65b44825eb04f3ed5f086945c99a Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Wed, 6 Jan 2021 17:08:08 -0800 Subject: [PATCH 31/61] One more test, this time on data from the DREAM challenge. 
--- .../branch_length_estimator_test.py | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index 03b7ed41..788e9437 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -705,6 +705,86 @@ def test_IIDExponentialPosteriorMeanBLE_3(): assert total_variation < 0.03 +@pytest.mark.slow +def test_IIDExponentialPosteriorMeanBLE_DREAM_subC1(): + r""" + A tree from the DREAM subchallenge 1, verified analytically. + """ + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), + tree.add_edges_from( + [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)] + ) + tree.nodes[3]["characters"] = "2011000111" + tree.nodes[4]["characters"] = "2011010111" + tree.nodes[5]["characters"] = "2011010111" + tree.nodes[6]["characters"] = "2011010111" + tree = Tree(tree) + tree.reconstruct_ancestral_states() + + mutation_rate = 0.6 + birth_rate = 0.8 + discretization_level = 500 + model = IIDExponentialPosteriorMeanBLE( + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + + model.estimate_branch_lengths(tree) + print(model.log_likelihood) + + # Test the model log likelihood against its numerical computation + numerical_log_likelihood = ( + IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( + tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + ) + ) + np.testing.assert_approx_equal( + model.log_likelihood, numerical_log_likelihood, significant=3 + ) + + # Check that the likelihood computed from each leaf node is correct. + for leaf in tree.leaves(): + model_log_likelihood_up = model.up(leaf, 0, tree.num_cuts(leaf)) + print(model_log_likelihood_up) + np.testing.assert_approx_equal( + model.log_likelihood, model_log_likelihood_up, significant=3 + ) + + # Check that the posterior ages of the nodes are correct. 
+ for node in tree.internal_nodes(): + numerical_log_joint = ( + IIDExponentialPosteriorMeanBLE.numerical_log_joint( + tree=tree, + node=node, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + ) + mean_error = np.mean(np.abs( + model.log_joints[node][25:-25] - + numerical_log_joint[25:-25]) / np.abs(numerical_log_joint[25:-25]) + ) + assert mean_error < 0.03 + + # Test the model posterior against its numerical posterior. + numerical_posterior = np.exp( + numerical_log_joint - numerical_log_joint.max() + ) + numerical_posterior /= numerical_posterior.sum() + # import matplotlib.pyplot as plt + # plt.plot(model.posteriors[node]) + # plt.show() + # plt.plot(numerical_posterior) + # plt.show() + total_variation = np.sum( + np.abs(model.posteriors[node] - numerical_posterior) + ) + assert total_variation < 0.05 + + def test_IIDExponentialPosteriorMeanBLEGridSeachCV(): r""" We just check that the grid search estimator does its job on a small grid. 
From ae0e9d678e518009845d89f7857c2ce530f360aa Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Wed, 6 Jan 2021 22:49:01 -0800 Subject: [PATCH 32/61] Bugfix IIDExponentialBLE.log_likelihood returning np.nan instead of -np.inf, which broke IIDExponentialBLEGridSearchCV --- .../IIDExponentialBLE.py | 81 +++++++++++-------- .../IIDExponentialPosteriorMeanBLE.py | 10 ++- .../branch_length_estimator_test.py | 51 ++++++++---- 3 files changed, 95 insertions(+), 47 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py index a8c2c8bd..7164a106 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py @@ -1,3 +1,4 @@ +import multiprocessing import copy from typing import List, Tuple @@ -145,17 +146,17 @@ def log_likelihood(self, tree: Tree) -> float: log_likelihood = 0.0 for (parent, child) in tree.edges(): edge_length = tree.get_age(parent) - tree.get_age(child) - # TODO: hardcoded '0' here... 
- zeros_parent = tree.get_state(parent).count("0") - zeros_child = tree.get_state(child).count("0") - new_cuts_child = zeros_parent - zeros_child - assert new_cuts_child >= 0 + n_nonmutated = tree.number_of_nonmutations_along_edge(parent, child) + n_mutated = tree.number_of_mutations_along_edge(parent, child) + assert n_mutated >= 0 and n_nonmutated >= 0 # Add log-lik for characters that didn't get cut - log_likelihood += zeros_child * (-edge_length) + log_likelihood += n_nonmutated * (-edge_length) # Add log-lik for characters that got cut - if edge_length < 1e-8 and new_cuts_child > 0: - return -np.inf - log_likelihood += new_cuts_child * np.log(1 - np.exp(-edge_length)) + if n_mutated > 0: + if edge_length < 1e-8: + return -np.inf + log_likelihood += n_mutated * np.log(1 - np.exp(-edge_length)) + assert not np.isnan(log_likelihood) return log_likelihood @@ -178,10 +179,12 @@ def __init__( self, minimum_branch_lengths: Tuple[float] = (0,), l2_regularizations: Tuple[float] = (0,), + processes: int = 6, verbose: bool = False, ): self.minimum_branch_lengths = minimum_branch_lengths self.l2_regularizations = l2_regularizations + self.processes = processes self.verbose = verbose def estimate_branch_lengths(self, tree: Tree) -> None: @@ -198,8 +201,11 @@ def estimate_branch_lengths(self, tree: Tree) -> None: verbose = self.verbose held_out_log_likelihoods = [] # type: List[Tuple[float, List]] - for minimum_branch_length in minimum_branch_lengths: - for l2_regularization in l2_regularizations: + grid = np.zeros( + shape=(len(minimum_branch_lengths), len(l2_regularizations)) + ) + for i, minimum_branch_length in enumerate(minimum_branch_lengths): + for j, l2_regularization in enumerate(l2_regularizations): cv_log_likelihood = self._cv_log_likelihood( tree=tree, minimum_branch_length=minimum_branch_length, @@ -211,6 +217,7 @@ def estimate_branch_lengths(self, tree: Tree) -> None: [minimum_branch_length, l2_regularization], ) ) + grid[i, j] = cv_log_likelihood # Refit 
model on full dataset with the best hyperparameters held_out_log_likelihoods.sort(reverse=True) @@ -233,6 +240,7 @@ def estimate_branch_lengths(self, tree: Tree) -> None: self.l2_regularization = best_l2_regularization self.log_likelihood = final_model.log_likelihood self.log_loss = final_model.log_loss + self.grid = grid def _cv_log_likelihood( self, tree: Tree, minimum_branch_length: float, l2_regularization: float @@ -246,6 +254,7 @@ def _cv_log_likelihood( log-likelihood over the #character folds is returned. """ verbose = self.verbose + processes = self.processes if verbose: print( f"Cross-validating hyperparameters:" @@ -253,32 +262,22 @@ def _cv_log_likelihood( f"\nl2_regularizations={l2_regularization}" ) n_characters = tree.num_characters() - log_likelihood_folds = np.zeros(shape=(n_characters)) + params = [] for held_out_character_idx in range(n_characters): tree_train, tree_valid = self._cv_split( tree=tree, held_out_character_idx=held_out_character_idx ) - try: - IIDExponentialBLE( - minimum_branch_length=minimum_branch_length, - l2_regularization=l2_regularization, - ).estimate_branch_lengths(tree_train) - tree_valid.copy_branch_lengths(tree_other=tree_train) - held_out_log_likelihood = IIDExponentialBLE.log_likelihood( - tree_valid - ) - except cp.error.SolverError: - held_out_log_likelihood = -np.inf - log_likelihood_folds[ - held_out_character_idx - ] = held_out_log_likelihood - if verbose: - print(f"log_likelihood_folds = {log_likelihood_folds}") - print( - f"mean log_likelihood_folds = " - f"{np.mean(log_likelihood_folds)}" + model = IIDExponentialBLE( + minimum_branch_length=minimum_branch_length, + l2_regularization=l2_regularization, ) - return np.mean(log_likelihood_folds) + params.append((model, tree_train, tree_valid)) + if processes > 1: + with multiprocessing.Pool(processes=processes) as pool: + log_likelihood_folds = pool.map(_fit_model, params) + else: + log_likelihood_folds = list(map(_fit_model, params)) + return 
np.mean(np.array(log_likelihood_folds)) def _cv_split( self, tree: Tree, held_out_character_idx: int @@ -299,3 +298,21 @@ def _cv_split( tree_train.set_state(node, train_state) tree_valid.set_state(node, valid_data) return tree_train, tree_valid + + +def _fit_model(args): + r""" + This is used by IIDExponentialBLEGridSearchCV to + parallelize the CV folds. It must be defined here (at the top level of + the module) for multiprocessing to be able to pickle it. (This is why + coverage misses it) + """ + model, tree_train, tree_valid = args + assert tree_valid.num_characters() == 1 + try: + model.estimate_branch_lengths(tree_train) + tree_valid.copy_branch_lengths(tree_other=tree_train) + held_out_log_likelihood = IIDExponentialBLE.log_likelihood(tree_valid) + except cp.error.SolverError: + held_out_log_likelihood = -np.inf + return held_out_log_likelihood diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index 17372210..c42c2272 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -160,6 +160,8 @@ def up(self, v, t, x) -> float: discretization_level = self.discretization_level assert 0 <= t <= self.discretization_level assert 0 <= x <= K + if not (1.0 - lam * dt - K * r * dt > 0): + raise ValueError("Please choose a bigger discretization_level.") log_likelihood = 0.0 if v == tree.root(): # Base case: we reached the root of the tree. # TODO: 'tree.root()' is O(n). We should have O(1) method. 
@@ -220,6 +222,8 @@ def down(self, v, t, x) -> float: assert v != tree.root() assert 0 <= t <= self.discretization_level assert 0 <= x <= K + if not (1.0 - lam * dt - K * r * dt > 0): + raise ValueError("Please choose a bigger discretization_level.") log_likelihood = 0.0 if t == 0: # Base case if v in tree.leaves() and x == tree.num_cuts(v): @@ -322,13 +326,15 @@ def f(*args): ) ) - return np.log( + res = np.log( integrate.nquad( f, [[0, 1]] * len(tree.internal_nodes()), opts={"epsrel": epsrel}, )[0] ) + assert not np.isnan(res) + return res @classmethod def numerical_log_joint( @@ -388,6 +394,7 @@ def f(*args): ) - np.log(discretization_level) ) + assert not np.isnan(res[i]) return res @@ -469,6 +476,7 @@ def estimate_branch_lengths(self, tree: Tree) -> None: enforce_parsimony = self.enforce_parsimony processes = self.processes verbose = self.verbose + lls = [] grid = np.zeros(shape=(len(mutation_rates), len(birth_rates))) models = [] diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index 788e9437..7fb3f3a0 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -412,20 +412,45 @@ def test_subtree_collapses_when_no_mutations(): def test_IIDExponentialBLEGridSearchCV(): + r""" + We make sure to test a tree for which no regularization produces + a likelihood of 0. 
+ """ tree = nx.DiGraph() - tree.add_nodes_from([0, 1]), - tree.add_edges_from([(0, 1)]) - tree.nodes[0]["characters"] = "000" - tree.nodes[1]["characters"] = "001" + tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7]), + tree.add_edges_from( + [(0, 1), (1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)] + ) + tree.nodes[4]["characters"] = "110" + tree.nodes[5]["characters"] = "110" + tree.nodes[6]["characters"] = "100" + tree.nodes[7]["characters"] = "100" tree = Tree(tree) + tree.reconstruct_ancestral_states() model = IIDExponentialBLEGridSearchCV( - minimum_branch_lengths=(0, 1.0, 3.0), - l2_regularizations=(0,), + minimum_branch_lengths=(0, 0.2, 4.0), + l2_regularizations=(0.0, 2.0, 4.0), verbose=True, + processes=1, ) model.estimate_branch_lengths(tree) - minimum_branch_length = model.minimum_branch_length - np.testing.assert_almost_equal(minimum_branch_length, 1.0) + print(model.grid) + assert model.grid[0, 0] == -np.inf + + # import seaborn as sns + # import matplotlib.pyplot as plt + # sns.heatmap( + # model.grid, + # yticklabels=model.minimum_branch_lengths, + # xticklabels=model.l2_regularizations, + # mask=np.isneginf(model.grid), + # ) + # plt.ylabel("minimum_branch_length") + # plt.xlabel("l2_regularization") + # plt.show() + + np.testing.assert_almost_equal(model.minimum_branch_length, 0.2) + np.testing.assert_almost_equal(model.l2_regularization, 2.0) def test_IIDExponentialPosteriorMeanBLE(): @@ -712,9 +737,7 @@ def test_IIDExponentialPosteriorMeanBLE_DREAM_subC1(): """ tree = nx.DiGraph() tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - tree.add_edges_from( - [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)] - ) + tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) tree.nodes[3]["characters"] = "2011000111" tree.nodes[4]["characters"] = "2011010111" tree.nodes[5]["characters"] = "2011010111" @@ -763,9 +786,9 @@ def test_IIDExponentialPosteriorMeanBLE_DREAM_subC1(): discretization_level=discretization_level, ) ) - mean_error = np.mean(np.abs( - 
model.log_joints[node][25:-25] - - numerical_log_joint[25:-25]) / np.abs(numerical_log_joint[25:-25]) + mean_error = np.mean( + np.abs(model.log_joints[node][25:-25] - numerical_log_joint[25:-25]) + / np.abs(numerical_log_joint[25:-25]) ) assert mean_error < 0.03 From cb571fd6a18abef8c44b067e2416da77ee39d4c0 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Wed, 6 Jan 2021 23:02:30 -0800 Subject: [PATCH 33/61] Avoid cp.log(0) --- cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py index 7164a106..da0b9f51 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py @@ -101,7 +101,7 @@ def estimate_branch_lengths(self, tree: Tree) -> None: # Add log-lik for characters that didn't get cut log_likelihood += zeros_child * (-edge_length) # Add log-lik for characters that got cut - log_likelihood += new_cuts_child * cp.log(1 - cp.exp(-edge_length)) + log_likelihood += new_cuts_child * cp.log(1 - cp.exp(-edge_length - 1e-8)) # # # # # Add regularization # # # # # From d5dc2b600ffd1b938d1b369c7a100528feef140d Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Wed, 6 Jan 2021 23:28:17 -0800 Subject: [PATCH 34/61] Test that single processor & multiprocessing work --- .../IIDExponentialBLE.py | 12 +++--- .../IIDExponentialPosteriorMeanBLE.py | 9 ++-- .../branch_length_estimator_test.py | 42 ++++++++++++++++++- 3 files changed, 53 insertions(+), 10 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py index da0b9f51..a985f0ab 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py +++ 
b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py @@ -101,7 +101,9 @@ def estimate_branch_lengths(self, tree: Tree) -> None: # Add log-lik for characters that didn't get cut log_likelihood += zeros_child * (-edge_length) # Add log-lik for characters that got cut - log_likelihood += new_cuts_child * cp.log(1 - cp.exp(-edge_length - 1e-8)) + log_likelihood += new_cuts_child * cp.log( + 1 - cp.exp(-edge_length - 1e-8) + ) # # # # # Add regularization # # # # # @@ -272,11 +274,9 @@ def _cv_log_likelihood( l2_regularization=l2_regularization, ) params.append((model, tree_train, tree_valid)) - if processes > 1: - with multiprocessing.Pool(processes=processes) as pool: - log_likelihood_folds = pool.map(_fit_model, params) - else: - log_likelihood_folds = list(map(_fit_model, params)) + with multiprocessing.Pool(processes=processes) as pool: + map_fn = pool.map if processes > 1 else map + log_likelihood_folds = list(map_fn(_fit_model, params)) return np.mean(np.array(log_likelihood_folds)) def _cv_split( diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index c42c2272..0f71c2a6 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -501,9 +501,12 @@ def estimate_branch_lengths(self, tree: Tree) -> None: mutation_and_birth_rates.append((mutation_rate, birth_rate)) ijs.append((i, j)) with multiprocessing.Pool(processes=processes) as pool: - lls = pool.map( - _fit_model, - zip(models, [deepcopy(tree) for _ in range(len(models))]), + map_fn = pool.map if processes > 1 else map + lls = list( + map_fn( + _fit_model, + zip(models, [deepcopy(tree) for _ in range(len(models))]), + ) ) lls_and_rates = list(zip(lls, mutation_and_birth_rates)) for ll, (i, j) in list(zip(lls, ijs)): diff --git a/test/tools_tests/branch_length_estimator_test.py 
b/test/tools_tests/branch_length_estimator_test.py index 7fb3f3a0..b2cff576 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -411,6 +411,26 @@ def test_subtree_collapses_when_no_mutations(): np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) +def test_IIDExponentialBLEGridSearchCV_smoke(): + r""" + Just want to see that it runs in both single and multiprocessor mode + """ + tree = nx.DiGraph() + tree.add_nodes_from([0, 1]), + tree.add_edges_from([(0, 1)]) + tree.nodes[0]["characters"] = "0" + tree.nodes[1]["characters"] = "1" + tree = Tree(tree) + for processes in [1, 2]: + model = IIDExponentialBLEGridSearchCV( + minimum_branch_lengths=(1.0,), + l2_regularizations=(1.0,), + verbose=True, + processes=processes, + ) + model.estimate_branch_lengths(tree) + + def test_IIDExponentialBLEGridSearchCV(): r""" We make sure to test a tree for which no regularization produces @@ -431,7 +451,7 @@ def test_IIDExponentialBLEGridSearchCV(): minimum_branch_lengths=(0, 0.2, 4.0), l2_regularizations=(0.0, 2.0, 4.0), verbose=True, - processes=1, + processes=6, ) model.estimate_branch_lengths(tree) print(model.grid) @@ -808,6 +828,26 @@ def test_IIDExponentialPosteriorMeanBLE_DREAM_subC1(): assert total_variation < 0.05 +def test_IIDExponentialPosteriorMeanBLEGridSeachCV_smoke(): + r""" + Just want to see that it runs in both single and multiprocessor mode + """ + tree = nx.DiGraph() + tree.add_nodes_from([0, 1]), + tree.add_edges_from([(0, 1)]) + tree.nodes[0]["characters"] = "0" + tree.nodes[1]["characters"] = "1" + tree = Tree(tree) + for processes in [1, 2]: + model = IIDExponentialPosteriorMeanBLEGridSearchCV( + mutation_rates=(0.5,), + birth_rates=(1.5,), + discretization_level=5, + verbose=True, + ) + model.estimate_branch_lengths(tree) + + def test_IIDExponentialPosteriorMeanBLEGridSeachCV(): r""" We just check that the grid search estimator does its job on a small grid. 
From 3d27e7a322aaccff81f59474d2cb3d547befb727 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Thu, 7 Jan 2021 00:09:47 -0800 Subject: [PATCH 35/61] Make the minimum branch length be in terms of the tree height --- .../IIDExponentialBLE.py | 19 ++++++++++++++---- cassiopeia/tools/tree.py | 15 ++++++++++++++ .../branch_length_estimator_test.py | 20 +++++++++++++++++++ test/tools_tests/tree_test.py | 8 ++++++++ 4 files changed, 58 insertions(+), 4 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py index a985f0ab..651f3c27 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py @@ -23,7 +23,10 @@ class IIDExponentialBLE(BranchLengthEstimator): Args: minimum_branch_length: Estimated branch lengths will be constrained to - have at least this length. + have at least length THIS MULTIPLE OF THE TREE HEIGHT. If this is + greater than 1.0 / [height of the tree] (where the height + is measured in terms of the greatest number of edges of any lineage) + then all edges will have length 0, so be careful! l2_regularization: Consecutive branches will be regularized to have similar length via an L2 penalty whose weight is given by l2_regularization. 
@@ -69,9 +72,11 @@ def estimate_branch_lengths(self, tree: Tree) -> None: for node_id in tree.nodes() ] ) + root = tree.root() time_increases_constraints = [ r_X_t_variables[parent] - >= r_X_t_variables[child] + minimum_branch_length + >= r_X_t_variables[child] + + minimum_branch_length * r_X_t_variables[root] for (parent, child) in tree.edges() ] leaves_have_age_0_constraints = [ @@ -137,8 +142,14 @@ def estimate_branch_lengths(self, tree: Tree) -> None: ) tree.set_edge_length(parent, child, length=new_edge_length) - self.log_likelihood = log_likelihood.value - self.log_loss = f_star + log_likelihood = log_likelihood.value + log_loss = f_star + if np.isnan(log_likelihood): + log_likelihood = -np.inf + if np.isnan(log_loss): + log_loss = -np.inf + self.log_likelihood = log_likelihood + self.log_loss = log_loss @classmethod def log_likelihood(self, tree: Tree) -> float: diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index 732b1eb6..1bad4da9 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -277,3 +277,18 @@ def num_uncut(self, v): def num_cut(self, v): return self.get_state(v).count("1") + + def depth(self) -> int: + r""" + Depth of the tree. + E.g. the tree 0 -> 1 has depth 1. 
+ """ + + def dfs(v): + res = 0 + for child in self.children(v): + res = max(res, dfs(child) + 1) + return res + + res = dfs(self.root()) + return res diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index b2cff576..d4a4c70a 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -411,6 +411,26 @@ def test_subtree_collapses_when_no_mutations(): np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) +def test_minimum_branch_length(): + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4]) + tree.add_edges_from([(0, 1), (0, 2), (0, 3), (2, 4)]) + tree = Tree(tree) + for node in tree.nodes(): + tree.set_state(node, "1") + tree.reconstruct_ancestral_states() + # Too large a minimum_branch_length + model = IIDExponentialBLE(minimum_branch_length=0.6) + model.estimate_branch_lengths(tree) + for node in tree.nodes(): + print(f"{node} = {tree.get_age(node)}") + assert model.log_likelihood == -np.inf + # An okay minimum_branch_length + model = IIDExponentialBLE(minimum_branch_length=0.4) + model.estimate_branch_lengths(tree) + assert model.log_likelihood != -np.inf + + def test_IIDExponentialBLEGridSearchCV_smoke(): r""" Just want to see that it runs in both single and multiprocessor mode diff --git a/test/tools_tests/tree_test.py b/test/tools_tests/tree_test.py index 21c09b4a..54a7671d 100644 --- a/test/tools_tests/tree_test.py +++ b/test/tools_tests/tree_test.py @@ -168,3 +168,11 @@ def test_reconstruct_ancestral_states_DREAM_challenge_tree_25(): assert tree.get_state(5) == "0000010220" assert tree.get_state(6) == "0000000000" assert tree.get_state(9) == "0000000000" + + +def test_depth(): + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4]) + tree.add_edges_from([(0, 1), (0, 2), (0, 3), (2, 4)]) + tree = Tree(tree) + assert tree.depth() == 2 From 563c853262fe183e73ec6e1fac220a3cd1d6db92 Mon Sep 17 00:00:00 2001 
From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Fri, 8 Jan 2021 16:22:36 -0800 Subject: [PATCH 36/61] Allow choosing how to format branch lengths in tree newick representation --- cassiopeia/tools/tree.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index 1bad4da9..60074f68 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Optional, Tuple import networkx as nx @@ -99,6 +99,7 @@ def to_newick_tree_format( append_state_to_node_name: bool = False, print_pct_of_mutated_characters_along_edge: bool = False, add_N_to_node_id: bool = False, + fmt_branch_lengths: str = "%s", ) -> str: r""" Converts tree into Newick tree format. @@ -136,7 +137,9 @@ def subtree_newick_representation(v: int) -> str: if print_internal_nodes: subtree_newick += format_node(child) # Add edge length - subtree_newick = subtree_newick + ":" + str(edge_length) + subtree_newick = ( + subtree_newick + ":" + (fmt_branch_lengths % edge_length) + ) if print_pct_of_mutated_characters_along_edge: # Also add number of mutations number_of_unmutated_characters_in_parent = self.get_state( @@ -292,3 +295,11 @@ def dfs(v): res = dfs(self.root()) return res + + def scale(self, factor: float): + r""" + The branch lengths of the tree are all scaled by this factor + """ + for node in self.nodes(): + self.set_age(node, factor * self.get_age(node)) + self.set_edge_lengths_from_node_ages() From 0db9e8b6b27bbbb5b3e2dab7ea386069419387fd Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sun, 10 Jan 2021 16:51:38 -0800 Subject: [PATCH 37/61] CV grid plotting --- .../IIDExponentialBLE.py | 20 +++++++++++++- .../IIDExponentialPosteriorMeanBLE.py | 18 ++++++++++++- .../tools/branch_length_estimator/utils.py | 27 +++++++++++++++++++ 3 files changed, 63 insertions(+), 2 
deletions(-) create mode 100644 cassiopeia/tools/branch_length_estimator/utils.py diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py index 651f3c27..f26ea302 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py @@ -1,12 +1,13 @@ import multiprocessing import copy -from typing import List, Tuple +from typing import List, Optional, Tuple import cvxpy as cp import numpy as np from ..tree import Tree from .BranchLengthEstimator import BranchLengthEstimator +from . import utils class IIDExponentialBLE(BranchLengthEstimator): @@ -288,6 +289,8 @@ def _cv_log_likelihood( with multiprocessing.Pool(processes=processes) as pool: map_fn = pool.map if processes > 1 else map log_likelihood_folds = list(map_fn(_fit_model, params)) + if verbose: + print(f"log_likelihood_folds = {log_likelihood_folds}") return np.mean(np.array(log_likelihood_folds)) def _cv_split( @@ -310,6 +313,21 @@ def _cv_split( tree_valid.set_state(node, valid_data) return tree_train, tree_valid + def plot_grid( + self, + figure_file: Optional[str] = None, + show_plot: bool = True + ): + utils.plot_grid( + grid=self.grid, + yticklabels=self.minimum_branch_lengths, + xticklabels=self.l2_regularizations, + ylabel=r"Minimum Branch Length ($\epsilon$)", + xlabel=r"l2 Regularization ($\lambda$)", + figure_file=figure_file, + show_plot=show_plot, + ) + def _fit_model(args): r""" diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index 0f71c2a6..e0570985 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -1,6 +1,6 @@ import multiprocessing from copy import deepcopy -from typing import Tuple +from typing import 
Optional, Tuple import numpy as np from scipy import integrate @@ -8,6 +8,7 @@ from ..tree import Tree from .BranchLengthEstimator import BranchLengthEstimator +from . import utils class IIDExponentialPosteriorMeanBLE(BranchLengthEstimator): @@ -535,3 +536,18 @@ def estimate_branch_lengths(self, tree: Tree) -> None: self.posteriors = final_model.posteriors self.posterior_means = final_model.posterior_means self.grid = grid + + def plot_grid( + self, + figure_file: Optional[str] = None, + show_plot: bool = True + ): + utils.plot_grid( + grid=self.grid, + yticklabels=self.mutation_rates, + xticklabels=self.birth_rates, + ylabel=r"Mutation Rate ($r$)", + xlabel=r"Birth Rate ($\lambda$)", + figure_file=figure_file, + show_plot=show_plot, + ) diff --git a/cassiopeia/tools/branch_length_estimator/utils.py b/cassiopeia/tools/branch_length_estimator/utils.py new file mode 100644 index 00000000..deda78a6 --- /dev/null +++ b/cassiopeia/tools/branch_length_estimator/utils.py @@ -0,0 +1,27 @@ +import numpy as np +import seaborn as sns +import matplotlib.pyplot as plt +from typing import Optional + + +def plot_grid( + grid, + yticklabels, + xticklabels, + ylabel, + xlabel, + figure_file: Optional[str], + show_plot: str = True +) -> None: + sns.heatmap( + grid, + yticklabels=yticklabels, + xticklabels=xticklabels, + mask=np.isneginf(grid), + ) + plt.ylabel(ylabel) + plt.xlabel(xlabel) + if figure_file: + plt.savefig(fname=figure_file) + if show_plot: + plt.show() From 634e39b0c0955aeeb509a66eb987b808bb4fc7e9 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sun, 10 Jan 2021 17:04:29 -0800 Subject: [PATCH 38/61] Make up(.) 
include the division event --- .../branch_length_estimator/IIDExponentialBLE.py | 4 +--- .../IIDExponentialPosteriorMeanBLE.py | 12 ++---------- .../tools/branch_length_estimator/utils.py | 2 +- test/tools_tests/branch_length_estimator_test.py | 16 ++++++++++++---- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py index f26ea302..b69276ca 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py @@ -314,9 +314,7 @@ def _cv_split( return tree_train, tree_valid def plot_grid( - self, - figure_file: Optional[str] = None, - show_plot: bool = True + self, figure_file: Optional[str] = None, show_plot: bool = True ): utils.plot_grid( grid=self.grid, diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index e0570985..97012fb5 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -86,9 +86,7 @@ def _compute_log_joint(self, v, t): ll_for_x = [] for x in valid_num_cuts: ll_for_x.append( - sum([self.down(u, t, x) for u in children]) - + self.up(v, t, x) - + np.log(lam * dt) + sum([self.down(u, t, x) for u in children]) + self.up(v, t, x) ) return logsumexp(ll_for_x) @@ -197,10 +195,6 @@ def up(self, v, t, x) -> float: + self.up(p, t + 1, x) + sum([self.down(u, t, x) for u in siblings]) ) - if p == tree.root(): # The branch start is for free! - # TODO: 'tree.root()' is O(n). We should have O(1) - # method. 
- ll -= np.log(lam * dt) log_likelihoods_cases.append(ll) log_likelihood = logsumexp(log_likelihoods_cases) self.up_cache[(v, t, x)] = log_likelihood @@ -538,9 +532,7 @@ def estimate_branch_lengths(self, tree: Tree) -> None: self.grid = grid def plot_grid( - self, - figure_file: Optional[str] = None, - show_plot: bool = True + self, figure_file: Optional[str] = None, show_plot: bool = True ): utils.plot_grid( grid=self.grid, diff --git a/cassiopeia/tools/branch_length_estimator/utils.py b/cassiopeia/tools/branch_length_estimator/utils.py index deda78a6..5a3406e5 100644 --- a/cassiopeia/tools/branch_length_estimator/utils.py +++ b/cassiopeia/tools/branch_length_estimator/utils.py @@ -11,7 +11,7 @@ def plot_grid( ylabel, xlabel, figure_file: Optional[str], - show_plot: str = True + show_plot: str = True, ) -> None: sns.heatmap( grid, diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index d4a4c70a..17f509f8 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -536,7 +536,9 @@ def test_IIDExponentialPosteriorMeanBLE(): # Test the model log likelihood vs its computation from the leaf nodes. for leaf in [2, 3]: - model_log_likelihood_up = model.up(leaf, 0, tree.num_cuts(leaf)) + model_log_likelihood_up = model.up( + leaf, 0, tree.num_cuts(leaf) + ) - np.log(birth_rate * 1.0 / discretization_level) print(f"{model_log_likelihood_up} = model_log_likelihood_up") np.testing.assert_approx_equal( model.log_likelihood, model_log_likelihood_up, significant=3 @@ -638,7 +640,9 @@ def test_IIDExponentialPosteriorMeanBLE_2(): # Check that the likelihood computed from each leaf node is correct. 
for leaf in tree.leaves(): - model_log_likelihood_up = model.up(leaf, 0, tree.num_cuts(leaf)) + model_log_likelihood_up = model.up( + leaf, 0, tree.num_cuts(leaf) + ) - np.log(birth_rate * 1.0 / discretization_level) print(model_log_likelihood_up) np.testing.assert_approx_equal( model.log_likelihood, model_log_likelihood_up, significant=3 @@ -731,7 +735,9 @@ def test_IIDExponentialPosteriorMeanBLE_3(): # Check that the likelihood computed from each leaf node is correct. for leaf in tree.leaves(): - model_log_likelihood_up = model.up(leaf, 0, tree.num_cuts(leaf)) + model_log_likelihood_up = model.up( + leaf, 0, tree.num_cuts(leaf) + ) - np.log(birth_rate * 1.0 / discretization_level) print(model_log_likelihood_up) np.testing.assert_approx_equal( model.log_likelihood, model_log_likelihood_up, significant=2 @@ -809,7 +815,9 @@ def test_IIDExponentialPosteriorMeanBLE_DREAM_subC1(): # Check that the likelihood computed from each leaf node is correct. for leaf in tree.leaves(): - model_log_likelihood_up = model.up(leaf, 0, tree.num_cuts(leaf)) + model_log_likelihood_up = model.up( + leaf, 0, tree.num_cuts(leaf) + ) - np.log(birth_rate * 1.0 / discretization_level) print(model_log_likelihood_up) np.testing.assert_approx_equal( model.log_likelihood, model_log_likelihood_up, significant=3 From a8b87bb05946e82e9f17431667d5e2883e41eebf Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Thu, 14 Jan 2021 17:17:27 -0800 Subject: [PATCH 39/61] Comments, remove unused methods from Tree --- .../IIDExponentialPosteriorMeanBLE.py | 6 +++- cassiopeia/tools/tree.py | 31 ++++++++++--------- test/tools_tests/tree_test.py | 10 ++++++ 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index 97012fb5..811876f6 100644 --- 
a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -13,6 +13,8 @@ class IIDExponentialPosteriorMeanBLE(BranchLengthEstimator): r""" + TODO: Update to match my technical write-up. + Same model as IIDExponentialBLE but computes the posterior mean instead of the MLE. The phylogeny model is chosen to be a birth process. @@ -146,8 +148,9 @@ def up(self, v, t, x) -> float: r""" TODO: Rename this _up? log P(X_up(b(v)), T_up(b(v)), t \in t_b(v), X_b(v)(t) = x) + TODO: Update to match my technical write-up. """ - if (v, t, x) in self.up_cache: + if (v, t, x) in self.up_cache: # TODO: Use arrays # TODO: Use a decorator instead of a hand-made cache return self.up_cache[(v, t, x)] # Pull out params @@ -204,6 +207,7 @@ def down(self, v, t, x) -> float: r""" TODO: Rename this _down? log P(X_down(v), T_down(v) | t_v = t, X_v = x) + TODO: Update to match my technical write-up. """ if (v, t, x) in self.down_cache: # TODO: Use a decorator instead of a hand-made cache diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py index 60074f68..0e06c703 100644 --- a/cassiopeia/tools/tree.py +++ b/cassiopeia/tools/tree.py @@ -35,9 +35,6 @@ def internal_nodes(self) -> List[int]: tree = self.tree return [n for n in tree if n != self.root() and n not in self.leaves()] - def non_root_nodes(self) -> List[int]: - return self.leaves() + self.internal_nodes() - def nodes(self): tree = self.tree return list(tree.nodes()) @@ -247,15 +244,6 @@ def set_edge_lengths_from_node_ages(self) -> None: parent, child, self.get_age(parent) - self.get_age(child) ) - def length(self) -> float: - r""" - Total length of the tree - """ - res = 0 - for (parent, child) in self.edges(): - res += self.get_edge_length(parent, child) - return res - def num_ancestors(self, node: int) -> int: r""" Number of ancestors of a node. Terribly inefficient implementation. 
@@ -278,9 +266,6 @@ def number_of_nonmutations_along_edge(self, parent, child): def num_uncut(self, v): return self.get_state(v).count("0") - def num_cut(self, v): - return self.get_state(v).count("1") - def depth(self) -> int: r""" Depth of the tree. @@ -303,3 +288,19 @@ def scale(self, factor: float): for node in self.nodes(): self.set_age(node, factor * self.get_age(node)) self.set_edge_lengths_from_node_ages() + + def __str__(self): + def node_str(p, v): + res = "" + if p is not None: + res += f"({self.number_of_mutations_along_edge(p, v)})" + res += self.get_state(v) + return res + + def dfs(p, v, depth) -> List[str]: + res = ["\t" * depth + node_str(p, v) + "\n"] + for c in self.children(v): + res += dfs(v, c, depth + 1) + return res + + return "".join(dfs(None, self.root(), 0)) diff --git a/test/tools_tests/tree_test.py b/test/tools_tests/tree_test.py index 54a7671d..29917716 100644 --- a/test/tools_tests/tree_test.py +++ b/test/tools_tests/tree_test.py @@ -176,3 +176,13 @@ def test_depth(): tree.add_edges_from([(0, 1), (0, 2), (0, 3), (2, 4)]) tree = Tree(tree) assert tree.depth() == 2 + + +def test_str(): + tree = nx.DiGraph() + tree.add_nodes_from([0, 1, 2, 3, 4]) + tree.add_edges_from([(0, 1), (0, 2), (0, 3), (2, 4)]) + tree = Tree(tree) + tree.set_states([(0, "00"), (1, "01"), (2, "00"), (3, "10"), (4, "11")]) + res = str(tree) + assert res == "00\n\t(1)01\n\t(0)00\n\t\t(2)11\n\t(1)10\n" From 3b69f5c044eff6f76edca21dd9a80e33c614b007 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Thu, 21 Jan 2021 12:28:59 -0800 Subject: [PATCH 40/61] Use CassiopeiaTree for branch length estimation --- cassiopeia/data/CassiopeiaTree.py | 126 +- .../BranchLengthEstimator.py | 4 +- .../IIDExponentialBLE.py | 119 +- .../IIDExponentialPosteriorMeanBLE.py | 159 +- cassiopeia/tools/lineage_simulator.py | 33 +- cassiopeia/tools/lineage_tracing_simulator.py | 30 +- .../branch_length_estimator_test.py | 1895 +++++++++-------- 
test/tools_tests/lineage_simulator_test.py | 125 +- .../lineage_tracing_simulator_test.py | 50 +- 9 files changed, 1388 insertions(+), 1153 deletions(-) diff --git a/cassiopeia/data/CassiopeiaTree.py b/cassiopeia/data/CassiopeiaTree.py index 662c5e35..8b144aa8 100644 --- a/cassiopeia/data/CassiopeiaTree.py +++ b/cassiopeia/data/CassiopeiaTree.py @@ -14,6 +14,7 @@ This object can be passed to any CassiopeiaSolver subclass as well as any analysis module, like a branch length estimator or rate matrix estimator """ +import copy import ete3 import networkx as nx import numpy as np @@ -106,6 +107,7 @@ def __init__( self.__cache = {} if tree is not None: + tree = copy.deepcopy(tree) self.populate_tree(tree) def populate_tree(self, tree: Union[str, ete3.Tree, nx.DiGraph]): @@ -328,7 +330,7 @@ def leaves(self) -> List[str]: self.__cache["leaves"] = [ n for n in self.__network if self.__network.out_degree(n) == 0 ] - return self.__cache["leaves"] + return self.__cache["leaves"][:] @property def internal_nodes(self) -> List[str]: @@ -345,9 +347,31 @@ def internal_nodes(self) -> List[str]: if "internal_nodes" not in self.__cache: self.__cache["internal_nodes"] = [ - n for n in self.__network if self.__network.out_degree(n) > 1 + n for n in self.__network if self.__network.out_degree(n) > 0 ] - return self.__cache["internal_nodes"] + return self.__cache["internal_nodes"][:] + + @property + def non_root_internal_nodes(self) -> List[str]: + """Returns internal nodes in tree (excluding the root). + + Returns: + The internal nodes of the tree that are not the root (i.e. all + nodes not at the leaves, and not the root) + + Raises: + CassiopeiaTreeError if the tree has not been initialized. 
+ """ + if self.__network is None: + raise CassiopeiaTreeError("Tree is not initialized.") + + if "non_root_internal_nodes" not in self.__cache: + res = [ + n for n in self.__network if self.__network.out_degree(n) > 0 + ] + res.remove(self.root) + self.__cache["non_root_internal_nodes"] = res + return self.__cache["non_root_internal_nodes"][:] @property def nodes(self) -> List[str]: @@ -364,7 +388,7 @@ def nodes(self) -> List[str]: if "nodes" not in self.__cache: self.__cache["nodes"] = [n for n in self.__network] - return self.__cache["nodes"] + return self.__cache["nodes"][:] @property def edges(self) -> List[Tuple[str, str]]: @@ -381,7 +405,7 @@ def edges(self) -> List[Tuple[str, str]]: if "edges" not in self.__cache: self.__cache["edges"] = [(u, v) for (u, v) in self.__network.edges] - return self.__cache["edges"] + return self.__cache["edges"][:] def is_leaf(self, node: str) -> bool: """Returns whether or not the node is a leaf. @@ -424,12 +448,16 @@ def is_internal_node(self, node: str) -> bool: raise CassiopeiaTreeError("Tree is not initialized.") return self.__network.out_degree(node) > 0 - def reconstruct_ancestral_characters(self): + def reconstruct_ancestral_characters(self, zero_the_root: bool = False): """Reconstruct ancestral character states. Reconstructs ancestral states (i.e., those character states in the internal nodes) using the Camin-Sokal parsimony criterion (i.e., irreversibility). Operates on the tree in place. + + Args: + zero_the_root: If True, the root will be forced to have unmutated + chracters. """ if self.__network is None: raise CassiopeiaTreeError("Tree is not initialized.") @@ -444,6 +472,9 @@ def reconstruct_ancestral_characters(self): ) self.__set_character_states(n, reconstructed) + if zero_the_root: + self.__set_character_states(self.root, [0] * self.n_character) + def parent(self, node: str) -> str: """Gets the parent of a node. 
@@ -498,28 +529,59 @@ def set_time(self, node: str, new_time: float) -> None: if self.__network is None: raise CassiopeiaTreeError("Tree is not initialized") - parent = self.parent(node) - if new_time < self.get_time(parent): - raise CassiopeiaTreeError( - "New age is less than the age of the parent." - ) + if node != self.root: + parent = self.parent(node) + if new_time < self.get_time(parent): + raise CassiopeiaTreeError( + "New time is less than the time of the parent." + ) for child in self.children(node): if new_time > self.get_time(child): raise CassiopeiaTreeError( - "New age is greater than than a child." + "New time is greater than than a child." ) self.__network.nodes[node]["time"] = new_time - self.__network[parent][node]["length"] = new_time - self.get_time( - parent - ) + if node != self.root: + self.__network[parent][node]["length"] = new_time - self.get_time( + parent + ) for child in self.children(node): self.__network[node][child]["length"] = ( self.get_time(child) - new_time ) + def set_times(self, time_dict: Dict[str, float]) -> None: + """Sets the time of all nodes in the tree. + + Args: + time_dict: Dictionary mapping nodes to their time. + + Raises: + CassiopeiaTreeError if the tree is not initialized, if the time + of any parent is greater than that of a child. + """ + if self.__network is None: + raise CassiopeiaTreeError("Tree is not initialized") + + # TODO: Check that the keys of time_dict match exactly the nodes in the + # tree and raise otherwise? + # Currently, if nodes are missing in time_dict, code below blows up. If + # extra nodes are present, they are ignored. 
+ + for (parent, child) in self.edges: + time_parent = time_dict[parent] + time_child = time_dict[child] + if time_parent > time_child: + raise CassiopeiaTreeError( + "Time of parent greater than that of child: " + f"{time_parent} > {time_child}") + self.__network[parent][child]["length"] = time_child - time_parent + for node, time in time_dict.items(): + self.__network.nodes[node]["time"] = time + def get_time(self, node: str) -> float: """Gets the time of a node. @@ -534,6 +596,17 @@ def get_time(self, node: str) -> float: return self.__network.nodes[node]["time"] + def get_times(self) -> Dict[str, float]: + """Gets the times of all nodes. + + Raises: + CassiopeiaTreeError if the tree has not been initialized. + """ + if self.__network is None: + raise CassiopeiaTreeError("Tree is not initialized.") + + return dict([(node, self.get_time(node)) for node in self.nodes]) + def set_branch_length(self, parent: str, child: str, length: float): """Sets the length of a branch. @@ -763,6 +836,23 @@ def get_mutations_along_edge( return mutations + def get_number_of_mutations_along_edge( + self, parent: str, child: str + ) -> int: + return len(self.get_mutations_along_edge(parent, child)) + + def get_number_of_unmutated_characters_in_node( + self, node: str + ) -> int: + states = self.get_character_states(node) + return states.count(0) + + def get_number_of_mutated_characters_in_node( + self, node: str + ) -> int: + return self.n_character -\ + self.get_number_of_unmutated_characters_in_node(node) + def relabel_nodes(self, relabel_map: Dict[str, str]): """Relabels the nodes in the tree. @@ -782,3 +872,9 @@ def relabel_nodes(self, relabel_map: Dict[str, str]): # reset cache because we've changed names self.__cache = {} + + def get_tree_topology(self) -> nx.DiGraph: + r""" + Returns the underlying tree topology as a networkx DiGraph. 
+ """ + return copy.deepcopy(self.__network) diff --git a/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py b/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py index 18907f26..33f8bddd 100644 --- a/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py +++ b/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py @@ -1,6 +1,6 @@ import abc -from ..tree import Tree +from cassiopeia.data import CassiopeiaTree class BranchLengthEstimator(abc.ABC): @@ -14,7 +14,7 @@ class BranchLengthEstimator(abc.ABC): """ @abc.abstractmethod - def estimate_branch_lengths(self, tree: Tree) -> None: + def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: r""" Estimates the branch lengths of the tree. diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py index b69276ca..23eb66da 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py @@ -5,7 +5,7 @@ import cvxpy as cp import numpy as np -from ..tree import Tree +from cassiopeia.data import CassiopeiaTree from .BranchLengthEstimator import BranchLengthEstimator from . import utils @@ -50,7 +50,7 @@ def __init__( self.l2_regularization = l2_regularization self.verbose = verbose - def estimate_branch_lengths(self, tree: Tree) -> None: + def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: r""" See base class. The only caveat is that this method raises if it fails to solve the underlying optimization problem for any reason. 
@@ -70,25 +70,30 @@ def estimate_branch_lengths(self, tree: Tree) -> None: r_X_t_variables = dict( [ (node_id, cp.Variable(name=f"r_X_t_{node_id}")) - for node_id in tree.nodes() + for node_id in tree.nodes ] ) - root = tree.root() + a_leaf = tree.leaves[0] + root = tree.root + root_has_time_0_constraint = [r_X_t_variables[root] == 0] time_increases_constraints = [ - r_X_t_variables[parent] - >= r_X_t_variables[child] - + minimum_branch_length * r_X_t_variables[root] - for (parent, child) in tree.edges() + r_X_t_variables[child] + >= r_X_t_variables[parent] + + minimum_branch_length * r_X_t_variables[a_leaf] + for (parent, child) in tree.edges ] - leaves_have_age_0_constraints = [ - r_X_t_variables[leaf] == 0 for leaf in tree.leaves() + leaves_have_same_time_constraints = [ + r_X_t_variables[leaf] == r_X_t_variables[a_leaf] + for leaf in tree.leaves + if leaf != a_leaf ] non_negative_r_X_t_constraints = [ r_X_t >= 0 for r_X_t in r_X_t_variables.values() ] all_constraints = ( - time_increases_constraints - + leaves_have_age_0_constraints + root_has_time_0_constraint + + time_increases_constraints + + leaves_have_same_time_constraints + non_negative_r_X_t_constraints ) @@ -97,11 +102,13 @@ def estimate_branch_lengths(self, tree: Tree) -> None: # Because all rates are equal, the number of cuts in each node is a # sufficient statistic. This makes the solver WAY faster! - for (parent, child) in tree.edges(): - edge_length = r_X_t_variables[parent] - r_X_t_variables[child] + for (parent, child) in tree.edges: + edge_length = r_X_t_variables[child] - r_X_t_variables[parent] # TODO: hardcoded '0' here... 
- zeros_parent = tree.get_state(parent).count("0") - zeros_child = tree.get_state(child).count("0") + zeros_parent = tree.get_number_of_unmutated_characters_in_node( + parent + ) + zeros_child = tree.get_number_of_unmutated_characters_in_node(child) new_cuts_child = zeros_parent - zeros_child assert new_cuts_child >= 0 # Add log-lik for characters that didn't get cut @@ -114,13 +121,13 @@ def estimate_branch_lengths(self, tree: Tree) -> None: # # # # # Add regularization # # # # # l2_penalty = 0 - for (parent, child) in tree.edges(): + for (parent, child) in tree.edges: for child_of_child in tree.children(child): edge_length_above = ( - r_X_t_variables[parent] - r_X_t_variables[child] + r_X_t_variables[child] - r_X_t_variables[parent] ) edge_length_below = ( - r_X_t_variables[child] - r_X_t_variables[child_of_child] + r_X_t_variables[child_of_child] - r_X_t_variables[child] ) l2_penalty += (edge_length_above - edge_length_below) ** 2 l2_penalty *= l2_regularization @@ -134,14 +141,12 @@ def estimate_branch_lengths(self, tree: Tree) -> None: # # # # # Populate the tree with the estimated branch lengths # # # # # - for node in tree.nodes(): - tree.set_age(node, age=r_X_t_variables[node].value) - - for (parent, child) in tree.edges(): - new_edge_length = ( - r_X_t_variables[parent].value - r_X_t_variables[child].value - ) - tree.set_edge_length(parent, child, length=new_edge_length) + times = {node: r_X_t_variables[node].value for node in tree.nodes} + # We smooth out epsilons that might make a parent's time greater + # than its child + for (parent, child) in tree.depth_first_traverse_edges(): + times[child] = max(times[parent], times[child]) + tree.set_times(times) log_likelihood = log_likelihood.value log_loss = f_star @@ -153,15 +158,17 @@ def estimate_branch_lengths(self, tree: Tree) -> None: self.log_loss = log_loss @classmethod - def log_likelihood(self, tree: Tree) -> float: + def log_likelihood(self, tree: CassiopeiaTree) -> float: r""" The log-likelihood of 
the given tree under the model. """ log_likelihood = 0.0 - for (parent, child) in tree.edges(): - edge_length = tree.get_age(parent) - tree.get_age(child) - n_nonmutated = tree.number_of_nonmutations_along_edge(parent, child) - n_mutated = tree.number_of_mutations_along_edge(parent, child) + for (parent, child) in tree.edges: + edge_length = tree.get_branch_length(parent, child) + n_mutated = tree.get_number_of_mutations_along_edge(parent, child) + n_nonmutated = tree.get_number_of_unmutated_characters_in_node( + child + ) assert n_mutated >= 0 and n_nonmutated >= 0 # Add log-lik for characters that didn't get cut log_likelihood += n_nonmutated * (-edge_length) @@ -201,7 +208,7 @@ def __init__( self.processes = processes self.verbose = verbose - def estimate_branch_lengths(self, tree: Tree) -> None: + def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: r""" See base class. The only caveat is that this method raises if it fails to solve the underlying optimization problem for any reason. 
@@ -257,7 +264,10 @@ def estimate_branch_lengths(self, tree: Tree) -> None: self.grid = grid def _cv_log_likelihood( - self, tree: Tree, minimum_branch_length: float, l2_regularization: float + self, + tree: CassiopeiaTree, + minimum_branch_length: float, + l2_regularization: float, ) -> float: r""" Given the tree and the parameters of the model, returns the @@ -275,17 +285,17 @@ def _cv_log_likelihood( f"\nminimum_branch_length={minimum_branch_length}" f"\nl2_regularizations={l2_regularization}" ) - n_characters = tree.num_characters() + n_characters = tree.n_character params = [] for held_out_character_idx in range(n_characters): - tree_train, tree_valid = self._cv_split( + train_tree, valid_tree = self._cv_split( tree=tree, held_out_character_idx=held_out_character_idx ) model = IIDExponentialBLE( minimum_branch_length=minimum_branch_length, l2_regularization=l2_regularization, ) - params.append((model, tree_train, tree_valid)) + params.append((model, train_tree, valid_tree)) with multiprocessing.Pool(processes=processes) as pool: map_fn = pool.map if processes > 1 else map log_likelihood_folds = list(map_fn(_fit_model, params)) @@ -294,24 +304,29 @@ def _cv_log_likelihood( return np.mean(np.array(log_likelihood_folds)) def _cv_split( - self, tree: Tree, held_out_character_idx: int - ) -> Tuple[Tree, Tree]: + self, tree: CassiopeiaTree, held_out_character_idx: int + ) -> Tuple[CassiopeiaTree, CassiopeiaTree]: r""" Creates a training and a cross validation tree by hiding the character at position held_out_character_idx. 
""" - tree_train = copy.deepcopy(tree) - tree_valid = copy.deepcopy(tree) - for node in tree.nodes(): - state = tree_train.get_state(node) + tree_topology = tree.get_tree_topology() + train_states = {} + valid_states = {} + for node in tree.nodes: + state = tree.get_character_states(node) train_state = ( state[:held_out_character_idx] + state[(held_out_character_idx + 1) :] ) - valid_data = state[held_out_character_idx] - tree_train.set_state(node, train_state) - tree_valid.set_state(node, valid_data) - return tree_train, tree_valid + valid_state = [state[held_out_character_idx]] + train_states[node] = train_state + valid_states[node] = valid_state + train_tree = CassiopeiaTree(tree=tree_topology) + valid_tree = CassiopeiaTree(tree=tree_topology) + train_tree.initialize_all_character_states(train_states) + valid_tree.initialize_all_character_states(valid_states) + return train_tree, valid_tree def plot_grid( self, figure_file: Optional[str] = None, show_plot: bool = True @@ -334,12 +349,12 @@ def _fit_model(args): the module) for multiprocessing to be able to pickle it. 
(This is why coverage misses it) """ - model, tree_train, tree_valid = args - assert tree_valid.num_characters() == 1 + model, train_tree, valid_tree = args + assert valid_tree.n_character == 1 try: - model.estimate_branch_lengths(tree_train) - tree_valid.copy_branch_lengths(tree_other=tree_train) - held_out_log_likelihood = IIDExponentialBLE.log_likelihood(tree_valid) + model.estimate_branch_lengths(train_tree) + valid_tree.set_times(train_tree.get_times()) + held_out_log_likelihood = IIDExponentialBLE.log_likelihood(valid_tree) except cp.error.SolverError: held_out_log_likelihood = -np.inf return held_out_log_likelihood diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index 811876f6..8f478116 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -6,7 +6,7 @@ from scipy import integrate from scipy.special import binom, logsumexp -from ..tree import Tree +from cassiopeia.data import CassiopeiaTree from .BranchLengthEstimator import BranchLengthEstimator from . import utils @@ -64,8 +64,8 @@ def _compute_log_likelihood(self): # children? (If not, the joint we are computing won't integrate to 1; # on the other hand, this is a constant multiplicative term that doesn't # affect inference. - for child_of_root in tree.children(tree.root()): - log_likelihood += self.down(child_of_root, discretization_level, 0) + for child_of_root in tree.children(tree.root): + log_likelihood += self.down(child_of_root, 0, 0) self.log_likelihood = log_likelihood def _compute_log_joint(self, v, t): @@ -74,17 +74,16 @@ def _compute_log_joint(self, v, t): Depending on whether we are enforcing parsimony or not, we consider different possible number of cuts for v. 
""" - discretization_level = self.discretization_level tree = self.tree - assert v in tree.internal_nodes() - lam = self.birth_rate + assert tree.is_internal_node(v) and v != tree.root enforce_parsimony = self.enforce_parsimony - dt = 1.0 / discretization_level children = tree.children(v) if enforce_parsimony: - valid_num_cuts = [tree.num_cuts(v)] + valid_num_cuts = [tree.get_number_of_mutated_characters_in_node(v)] else: - valid_num_cuts = range(tree.num_cuts(v) + 1) + valid_num_cuts = range( + tree.get_number_of_mutated_characters_in_node(v) + 1 + ) ll_for_x = [] for x in valid_num_cuts: ll_for_x.append( @@ -98,7 +97,7 @@ def _compute_posteriors(self): log_joints = {} # log P(t_v = t, X, T) posteriors = {} # P(t_v = t | X, T) posterior_means = {} # E[t_v = t | X, T] - for v in tree.internal_nodes(): + for v in tree.non_root_internal_nodes: # Compute the posterior for this node log_joint = np.zeros(shape=(discretization_level + 1,)) for t in range(discretization_level + 1): @@ -117,16 +116,15 @@ def _compute_posteriors(self): def _populate_branch_lengths(self): tree = self.tree posterior_means = self.posterior_means - for node in tree.internal_nodes(): - tree.set_age(node, age=posterior_means[node]) - tree.set_age(tree.root(), age=1.0) - for leaf in tree.leaves(): - tree.set_age(leaf, age=0.0) - for (parent, child) in tree.edges(): - new_edge_length = tree.get_age(parent) - tree.get_age(child) - tree.set_edge_length(parent, child, length=new_edge_length) - - def estimate_branch_lengths(self, tree: Tree) -> None: + times = {} + for node in tree.non_root_internal_nodes: + times[node] = posterior_means[node] + times[tree.root] = 0.0 + for leaf in tree.leaves: + times[leaf] = 1.0 + tree.set_times(times) + + def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: r""" See base class. 
""" @@ -157,7 +155,7 @@ def up(self, v, t, x) -> float: r = self.mutation_rate lam = self.birth_rate dt = 1.0 / self.discretization_level - K = self.tree.num_characters() + K = self.tree.n_character tree = self.tree discretization_level = self.discretization_level assert 0 <= t <= self.discretization_level @@ -165,37 +163,38 @@ def up(self, v, t, x) -> float: if not (1.0 - lam * dt - K * r * dt > 0): raise ValueError("Please choose a bigger discretization_level.") log_likelihood = 0.0 - if v == tree.root(): # Base case: we reached the root of the tree. - # TODO: 'tree.root()' is O(n). We should have O(1) method. - if t == discretization_level and x == tree.num_cuts(v): + if v == tree.root: # Base case: we reached the root of the tree. + if t == 0 and x == tree.get_number_of_mutated_characters_in_node(v): log_likelihood = 0.0 else: log_likelihood = -np.inf - elif t == discretization_level: + elif t == 0: # Base case: we reached the start of the process, but we're not yet # at the root. - assert v != tree.root() + assert v != tree.root log_likelihood = -np.inf else: # Recursion. log_likelihoods_cases = [] # Case 1: Nothing happened log_likelihoods_cases.append( - np.log(1.0 - lam * dt - (K - x) * r * dt) + self.up(v, t + 1, x) + np.log(1.0 - lam * dt - (K - x) * r * dt) + self.up(v, t - 1, x) ) # Case 2: Mutation happened if x - 1 >= 0: log_likelihoods_cases.append( - np.log((K - (x - 1)) * r * dt) + self.up(v, t + 1, x - 1) + np.log((K - (x - 1)) * r * dt) + self.up(v, t - 1, x - 1) ) # Case 3: A cell division happened - if v != tree.root(): + if v != tree.root: # TODO: 'tree.root()' is O(n). We should have O(1) method. 
p = tree.parent(v) - if self.compatible_with_observed_data(x, tree.num_cuts(p)): + if self.compatible_with_observed_data( + x, tree.get_number_of_mutated_characters_in_node(p) + ): siblings = [u for u in tree.children(p) if u != v] ll = ( np.log(lam * dt) - + self.up(p, t + 1, x) + + self.up(p, t - 1, x) + sum([self.down(u, t, x) for u in siblings]) ) log_likelihoods_cases.append(ll) @@ -213,19 +212,23 @@ def down(self, v, t, x) -> float: # TODO: Use a decorator instead of a hand-made cache return self.down_cache[(v, t, x)] # Pull out params + discretization_level = self.discretization_level r = self.mutation_rate lam = self.birth_rate dt = 1.0 / self.discretization_level - K = self.tree.num_characters() + K = self.tree.n_character tree = self.tree - assert v != tree.root() + assert v != tree.root assert 0 <= t <= self.discretization_level assert 0 <= x <= K if not (1.0 - lam * dt - K * r * dt > 0): raise ValueError("Please choose a bigger discretization_level.") log_likelihood = 0.0 - if t == 0: # Base case - if v in tree.leaves() and x == tree.num_cuts(v): + if t == discretization_level: # Base case + if ( + v in tree.leaves + and x == tree.get_number_of_mutated_characters_in_node(v) + ): # TODO: 'v not in tree.leaves()' is O(n). We should have O(1) # check. log_likelihood = 0.0 @@ -236,25 +239,27 @@ def down(self, v, t, x) -> float: # Case 1: Nothing happens log_likelihoods_cases.append( np.log(1.0 - lam * dt - (K - x) * r * dt) - + self.down(v, t - 1, x) + + self.down(v, t + 1, x) ) # Case 2: One character mutates. if x + 1 <= K: log_likelihoods_cases.append( - np.log((K - x) * r * dt) + self.down(v, t - 1, x + 1) + np.log((K - x) * r * dt) + self.down(v, t + 1, x + 1) ) # Case 3: Cell divides # The number of cuts at this state must match the ground truth. # TODO: Allow for weak match at internal nodes and exact match at # leaves. 
if ( - self.compatible_with_observed_data(x, tree.num_cuts(v)) - and v not in tree.leaves() + self.compatible_with_observed_data( + x, tree.get_number_of_mutated_characters_in_node(v) + ) + and v not in tree.leaves ): # TODO: 'v not in tree.leaves()' is O(n). We should have O(1) # check. ll = sum( - [self.down(child, t - 1, x) for child in tree.children(v)] + [self.down(child, t + 1, x) for child in tree.children(v)] ) + np.log(lam * dt) log_likelihoods_cases.append(ll) log_likelihood = logsumexp(log_likelihoods_cases) @@ -263,29 +268,28 @@ def down(self, v, t, x) -> float: @classmethod def exact_log_full_joint( - self, tree: Tree, mutation_rate: float, birth_rate: float + self, tree: CassiopeiaTree, mutation_rate: float, birth_rate: float ) -> float: r""" log P(T, X, branch_lengths), i.e. the full joint log likelihood given both character vectors _and_ branch lengths. """ tree = deepcopy(tree) - tree.set_edge_lengths_from_node_ages() ll = 0.0 lam = birth_rate r = mutation_rate lg = np.log e = np.exp b = binom - for (p, c) in tree.edges(): - t = tree.get_edge_length(p, c) + for (p, c) in tree.edges: + t = tree.get_branch_length(p, c) # Birth process likelihood ll += -t * lam - if c not in tree.leaves(): + if c not in tree.leaves: ll += lg(lam) # Mutation process likelihood - cuts = tree.number_of_mutations_along_edge(p, c) - uncuts = tree.number_of_nonmutations_along_edge(p, c) + cuts = tree.get_number_of_mutations_along_edge(p, c) + uncuts = tree.get_number_of_unmutated_characters_in_node(c) ll += ( (-t * r) * uncuts + lg(1 - e(-t * r)) * cuts @@ -296,7 +300,7 @@ def exact_log_full_joint( @classmethod def numerical_log_likelihood( self, - tree: Tree, + tree: CassiopeiaTree, mutation_rate: float, birth_rate: float, epsrel: float = 0.01, @@ -310,13 +314,19 @@ def numerical_log_likelihood( tree = deepcopy(tree) def f(*args): - ages = args - for node, age in list(zip(tree.internal_nodes(), ages)): - tree.set_age(node, age) - for (p, c) in tree.edges(): - if 
tree.get_age(p) <= tree.get_age(c): + times_list = args + times = {} + for node, time in list( + zip(tree.non_root_internal_nodes, times_list) + ): + times[node] = time + times[tree.root] = 0 + for leaf in tree.leaves: + times[leaf] = 1.0 + for (p, c) in tree.edges: + if times[p] >= times[c]: return 0.0 - tree.set_edge_lengths_from_node_ages() + tree.set_times(times) return np.exp( IIDExponentialPosteriorMeanBLE.exact_log_full_joint( tree=tree, @@ -328,7 +338,7 @@ def f(*args): res = np.log( integrate.nquad( f, - [[0, 1]] * len(tree.internal_nodes()), + [[0, 1]] * len(tree.non_root_internal_nodes), opts={"epsrel": epsrel}, )[0] ) @@ -338,7 +348,7 @@ def f(*args): @classmethod def numerical_log_joint( self, - tree: Tree, + tree: CassiopeiaTree, node, mutation_rate: float, birth_rate: float, @@ -350,18 +360,25 @@ def numerical_log_joint( to the level discretization_level """ res = np.zeros(shape=(discretization_level + 1,)) - other_nodes = [n for n in tree.internal_nodes() if n != node] + other_nodes = [n for n in tree.non_root_internal_nodes if n != node] + node_time = -1 tree = deepcopy(tree) def f(*args): - ages = args - for other_node, age in list(zip(other_nodes, ages)): - tree.set_age(other_node, age) - for (p, c) in tree.edges(): - if tree.get_age(p) <= tree.get_age(c): + times_list = args + times = {} + times[node] = node_time + assert len(other_nodes) == len(times_list) + for other_node, time in list(zip(other_nodes, times_list)): + times[other_node] = time + times[tree.root] = 0 + for leaf in tree.leaves: + times[leaf] = 1.0 + for (p, c) in tree.edges: + if times[p] >= times[c]: return 0.0 - tree.set_edge_lengths_from_node_ages() + tree.set_times(times) return np.exp( IIDExponentialPosteriorMeanBLE.exact_log_full_joint( tree=tree, @@ -371,11 +388,15 @@ def f(*args): ) for i in range(discretization_level + 1): - node_age = i / discretization_level - tree.set_age(node, node_age) - tree.set_edge_lengths_from_node_ages() + node_time = i / discretization_level 
if len(other_nodes) == 0: # There is nothing to integrate over. + times = {} + times[tree.root] = 0 + for leaf in tree.leaves: + times[leaf] = 1.0 + times[node] = node_time + tree.set_times(times) res[i] = self.exact_log_full_joint( tree=tree, mutation_rate=mutation_rate, @@ -387,7 +408,7 @@ def f(*args): np.log( integrate.nquad( f, - [[0, 1]] * (len(tree.internal_nodes()) - 1), + [[0, 1]] * (len(tree.non_root_internal_nodes) - 1), opts={"epsrel": epsrel}, )[0] ) @@ -400,7 +421,7 @@ def f(*args): @classmethod def numerical_posterior( self, - tree: Tree, + tree: CassiopeiaTree, node, mutation_rate: float, birth_rate: float, @@ -465,7 +486,7 @@ def __init__( self.processes = processes self.verbose = verbose - def estimate_branch_lengths(self, tree: Tree) -> None: + def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: r""" See base class. """ diff --git a/cassiopeia/tools/lineage_simulator.py b/cassiopeia/tools/lineage_simulator.py index a5b792e0..9eb4d1f6 100644 --- a/cassiopeia/tools/lineage_simulator.py +++ b/cassiopeia/tools/lineage_simulator.py @@ -4,7 +4,7 @@ import networkx as nx import numpy as np -from .tree import Tree +from cassiopeia.data import CassiopeiaTree class LineageSimulator(abc.ABC): @@ -16,7 +16,7 @@ class LineageSimulator(abc.ABC): """ @abc.abstractmethod - def simulate_lineage(self) -> Tree: + def simulate_lineage(self) -> CassiopeiaTree: r""" Simulates a lineage tree, i.e. a Tree with branch lengths and age specified for each node. Additional information such as cell fitness, @@ -36,7 +36,7 @@ class PerfectBinaryTree(LineageSimulator): def __init__(self, generation_branch_lengths: List[float]): self.generation_branch_lengths = generation_branch_lengths[:] - def simulate_lineage(self) -> Tree: + def simulate_lineage(self) -> CassiopeiaTree: r""" See base class. 
""" @@ -63,7 +63,12 @@ def simulate_lineage(self) -> Tree: tree.nodes[child]["age"] = ( tree.nodes[int((child - 1) / 2)]["age"] - branch_length ) - return Tree(tree) + times = {} + for node in tree.nodes: + times[node] = tree.nodes[0]["age"] - tree.nodes[node]["age"] + res = CassiopeiaTree(tree=tree) + res.set_times(times) + return res class PerfectBinaryTreeWithRootBranch(LineageSimulator): @@ -79,7 +84,7 @@ class PerfectBinaryTreeWithRootBranch(LineageSimulator): def __init__(self, generation_branch_lengths: List[float]): self.generation_branch_lengths = generation_branch_lengths - def simulate_lineage(self) -> Tree: + def simulate_lineage(self) -> CassiopeiaTree: r""" See base class. """ @@ -106,7 +111,12 @@ def simulate_lineage(self) -> Tree: tree.nodes[child]["age"] = ( tree.nodes[int(child / 2)]["age"] - branch_length ) - return Tree(tree) + times = {} + for node in tree.nodes: + times[node] = tree.nodes[0]["age"] - tree.nodes[node]["age"] + res = CassiopeiaTree(tree=tree) + res.set_times(times) + return res class BirthProcess(LineageSimulator): @@ -122,7 +132,7 @@ def __init__(self, birth_rate: float, tree_depth: float): self.birth_rate = birth_rate self.tree_depth = tree_depth - def simulate_lineage(self) -> Tree: + def simulate_lineage(self) -> CassiopeiaTree: r""" See base class. 
""" @@ -167,8 +177,9 @@ def simulate_lineage(self) -> Tree: tree_nx = nx.DiGraph() tree_nx.add_nodes_from(range(last_node_id + 1)) tree_nx.add_edges_from(edges) - tree = Tree(tree_nx) - for node in tree.nodes(): - tree.set_age(node, node_age[node]) - tree.set_edge_lengths_from_node_ages() + times = {} + for node in tree_nx.nodes: + times[node] = node_age[0] - node_age[node] + tree = CassiopeiaTree(tree=tree_nx) + tree.set_times(times) return tree diff --git a/cassiopeia/tools/lineage_tracing_simulator.py b/cassiopeia/tools/lineage_tracing_simulator.py index fb01055e..c03051a9 100644 --- a/cassiopeia/tools/lineage_tracing_simulator.py +++ b/cassiopeia/tools/lineage_tracing_simulator.py @@ -2,7 +2,7 @@ import numpy as np -from .tree import Tree +from cassiopeia.data import CassiopeiaTree class LineageTracingSimulator(abc.ABC): @@ -15,7 +15,7 @@ class LineageTracingSimulator(abc.ABC): """ @abc.abstractmethod - def overlay_lineage_tracing_data(self, tree: Tree) -> None: + def overlay_lineage_tracing_data(self, tree: CassiopeiaTree) -> None: r""" Annotates the tree's nodes with lineage tracing character vectors. These are stored as the node's state. (Operates on the tree in-place.) @@ -38,27 +38,28 @@ def __init__(self, mutation_rate: float, num_characters: float): self.mutation_rate = mutation_rate self.num_characters = num_characters - def overlay_lineage_tracing_data(self, tree: Tree) -> None: + def overlay_lineage_tracing_data(self, tree: CassiopeiaTree) -> None: r""" See base class. 
""" num_characters = self.num_characters mutation_rate = self.mutation_rate + states = {} - def dfs(node: int, tree: Tree): - node_state = tree.get_state(node) + def dfs(node: str, tree: CassiopeiaTree): + node_state = states[node] for child in tree.children(node): # Compute the state of the child - child_state = "" - edge_length = tree.get_age(node) - tree.get_age(child) + child_state = [] + edge_length = tree.get_branch_length(node, child) # print(f"{node} -> {child}, length {edge_length}") assert edge_length >= 0 for i in range(num_characters): # See what happens to character i - if node_state[i] != "0": + if node_state[i] != 0: # The character has already mutated; there in nothing # to do - child_state += node_state[i] + child_state += [node_state[i]] continue else: # Determine if the character will mutate. @@ -67,12 +68,13 @@ def dfs(node: int, tree: Tree): < edge_length ) if mutates: - child_state += "1" + child_state += [1] else: - child_state += "0" - tree.set_state(child, child_state) + child_state += [0] + states[child] = child_state dfs(child, tree) - root = tree.root() - tree.set_state(root, "0" * num_characters) + root = tree.root + states[root] = [0] * num_characters dfs(root, tree) + tree.initialize_all_character_states(states) diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index 17f509f8..bc9f234f 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -5,6 +5,9 @@ import networkx as nx import numpy as np import pytest +import unittest + +from cassiopeia.data import CassiopeiaTree from cassiopeia.tools import ( BirthProcess, @@ -13,920 +16,510 @@ IIDExponentialLineageTracer, IIDExponentialPosteriorMeanBLE, IIDExponentialPosteriorMeanBLEGridSearchCV, - Tree, ) -def test_no_mutations(): - r""" - Tree topology is just a branch 0->1. 
- There is one unmutated character i.e.: - root [state = '0'] - | - v - child [state = '0'] - This is thus the simplest possible example of no mutations, and the MLE - branch length should be 0 - """ - tree = nx.DiGraph() - tree.add_node(0), tree.add_node(1) - tree.add_edge(0, 1) - tree.nodes[0]["characters"] = "0" - tree.nodes[1]["characters"] = "0" - tree = Tree(tree) - model = IIDExponentialBLE() - model.estimate_branch_lengths(tree) - log_likelihood = model.log_likelihood - np.testing.assert_almost_equal(tree.get_edge_length(0, 1), 0.0) - np.testing.assert_almost_equal(tree.get_age(0), 0.0) - np.testing.assert_almost_equal(tree.get_age(1), 0.0) - np.testing.assert_almost_equal(log_likelihood, 0.0) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) - np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) - - -def test_saturation(): - r""" - Tree topology is just a branch 0->1. - There is one mutated character i.e.: - root [state = '0'] - | - v - child [state = '1'] - This is thus the simplest possible example of saturation, and the MLE - branch length should be infinity (>15 for all practical purposes) - """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1]) - tree.add_edge(0, 1) - tree.nodes[0]["characters"] = "0" - tree.nodes[1]["characters"] = "1" - tree = Tree(tree) - model = IIDExponentialBLE() - model.estimate_branch_lengths(tree) - log_likelihood = model.log_likelihood - assert tree.get_edge_length(0, 1) > 15.0 - assert tree.get_age(0) > 15.0 - np.testing.assert_almost_equal(tree.get_age(1), 0.0) - np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=5) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) - np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) - - -def test_hand_solvable_problem_1(): - r""" - Tree topology is just a branch 0->1. 
- There is one mutated character and one unmutated character, i.e.: - root [state = '00'] - | - v - child [state = '01'] - The solution can be verified by hand. The optimization problem is: - min_{r * t0} log(exp(-r * t0)) + log(1 - exp(-r * t0)) - The solution is r * t0 = ln(2) ~ 0.693 - """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1]) - tree.add_edge(0, 1) - tree.nodes[0]["characters"] = "00" - tree.nodes[1]["characters"] = "01" - tree = Tree(tree) - model = IIDExponentialBLE() - model.estimate_branch_lengths(tree) - log_likelihood = model.log_likelihood - np.testing.assert_almost_equal( - tree.get_edge_length(0, 1), np.log(2), decimal=3 - ) - np.testing.assert_almost_equal(tree.get_age(0), np.log(2), decimal=3) - np.testing.assert_almost_equal(tree.get_age(1), 0.0) - np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) - np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) - - -def test_hand_solvable_problem_2(): - r""" - Tree topology is just a branch 0->1. - There are two mutated characters and one unmutated character, i.e.: - root [state = '000'] - | - v - child [state = '011'] - The solution can be verified by hand. 
The optimization problem is: - min_{r * t0} log(exp(-r * t0)) + 2 * log(1 - exp(-r * t0)) - The solution is r * t0 = ln(3) ~ 1.098 - """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1]) - tree.add_edge(0, 1) - tree.nodes[0]["characters"] = "000" - tree.nodes[1]["characters"] = "011" - tree = Tree(tree) - model = IIDExponentialBLE() - model.estimate_branch_lengths(tree) - log_likelihood = model.log_likelihood - np.testing.assert_almost_equal( - tree.get_edge_length(0, 1), np.log(3), decimal=3 - ) - np.testing.assert_almost_equal(tree.get_age(0), np.log(3), decimal=3) - np.testing.assert_almost_equal(tree.get_age(1), 0.0) - np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) - np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) - - -def test_hand_solvable_problem_3(): - r""" - Tree topology is just a branch 0->1. - There are two unmutated characters and one mutated character, i.e.: - root [state = '000'] - | - v - child [state = '001'] - The solution can be verified by hand. 
The optimization problem is: - min_{r * t0} 2 * log(exp(-r * t0)) + log(1 - exp(-r * t0)) - The solution is r * t0 = ln(1.5) ~ 0.405 - """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1]) - tree.add_edge(0, 1) - tree.nodes[0]["characters"] = "000" - tree.nodes[1]["characters"] = "001" - tree = Tree(tree) - model = IIDExponentialBLE() - model.estimate_branch_lengths(tree) - log_likelihood = model.log_likelihood - np.testing.assert_almost_equal( - tree.get_edge_length(0, 1), np.log(1.5), decimal=3 - ) - np.testing.assert_almost_equal(tree.get_age(0), np.log(1.5), decimal=3) - np.testing.assert_almost_equal(tree.get_age(1), 0.0) - np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) - np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) - - -def test_small_tree_with_no_mutations(): - r""" - Perfect binary tree with no mutations: Should give edges of length 0 - """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]) - tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[0]["characters"] = "0000" - tree.nodes[1]["characters"] = "0000" - tree.nodes[2]["characters"] = "0000" - tree.nodes[3]["characters"] = "0000" - tree.nodes[4]["characters"] = "0000" - tree.nodes[5]["characters"] = "0000" - tree.nodes[6]["characters"] = "0000" - tree = Tree(tree) - model = IIDExponentialBLE() - model.estimate_branch_lengths(tree) - log_likelihood = model.log_likelihood - for edge in tree.edges(): - np.testing.assert_almost_equal( - tree.get_edge_length(*edge), 0, decimal=3 +class TestIIDExponentialBLE(unittest.TestCase): + def test_no_mutations(self): + r""" + Tree topology is just a branch 0->1. 
+ There is one unmutated character i.e.: + root [state = '0'] + | + v + child [state = '0'] + This is thus the simplest possible example of no mutations, and the MLE + branch length should be 0 + """ + tree = nx.DiGraph() + tree.add_node("0"), tree.add_node("1") + tree.add_edge("0", "1") + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [0]} ) - np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=3) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) - np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) - - -def test_small_tree_with_one_mutation(): - r""" - Perfect binary tree with one mutation at a node 6: Should give very short - edges 1->3,1->4,0->2 and very long edges 0->1,2->5,2->6. - The problem can be solved by hand: it trivially reduces to a 1-dimensional - problem: - min_{r * t0} 2 * log(exp(-r * t0)) + log(1 - exp(-r * t0)) - The solution is r * t0 = ln(1.5) ~ 0.405 - """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[0]["characters"] = "0" - tree.nodes[1]["characters"] = "0" - tree.nodes[2]["characters"] = "0" - tree.nodes[3]["characters"] = "0" - tree.nodes[4]["characters"] = "0" - tree.nodes[5]["characters"] = "0" - tree.nodes[6]["characters"] = "1" - tree = Tree(tree) - model = IIDExponentialBLE() - model.estimate_branch_lengths(tree) - log_likelihood = model.log_likelihood - np.testing.assert_almost_equal(tree.get_edge_length(0, 1), 0.405, decimal=3) - np.testing.assert_almost_equal(tree.get_edge_length(0, 2), 0.0, decimal=3) - np.testing.assert_almost_equal(tree.get_edge_length(1, 3), 0.0, decimal=3) - np.testing.assert_almost_equal(tree.get_edge_length(1, 4), 0.0, decimal=3) - np.testing.assert_almost_equal(tree.get_edge_length(2, 5), 0.405, decimal=3) - np.testing.assert_almost_equal(tree.get_edge_length(2, 6), 0.405, decimal=3) - 
np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) - np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) - - -def test_small_tree_with_saturation(): - r""" - Perfect binary tree with saturation. The edges which saturate should thus - have length infinity (>15 for all practical purposes) - """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[0]["characters"] = "0" - tree.nodes[1]["characters"] = "0" - tree.nodes[2]["characters"] = "1" - tree.nodes[3]["characters"] = "1" - tree.nodes[4]["characters"] = "1" - tree.nodes[5]["characters"] = "1" - tree.nodes[6]["characters"] = "1" - tree = Tree(tree) - model = IIDExponentialBLE() - model.estimate_branch_lengths(tree) - log_likelihood = model.log_likelihood - assert tree.get_edge_length(0, 2) > 15.0 - assert tree.get_edge_length(1, 3) > 15.0 - assert tree.get_edge_length(1, 4) > 15.0 - log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) - np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) - - -def test_small_tree_regression(): - r""" - Regression test. Cannot be solved by hand. We just check that this solution - never changes. 
- """ - # Perfect binary tree with normal amount of mutations on each edge - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[0]["characters"] = "000000000" - tree.nodes[1]["characters"] = "100000000" - tree.nodes[2]["characters"] = "000006000" - tree.nodes[3]["characters"] = "120000000" - tree.nodes[4]["characters"] = "103000000" - tree.nodes[5]["characters"] = "000056700" - tree.nodes[6]["characters"] = "000406089" - tree = Tree(tree) - model = IIDExponentialBLE() - model.estimate_branch_lengths(tree) - log_likelihood = model.log_likelihood - np.testing.assert_almost_equal(tree.get_edge_length(0, 1), 0.203, decimal=3) - np.testing.assert_almost_equal(tree.get_edge_length(0, 2), 0.082, decimal=3) - np.testing.assert_almost_equal(tree.get_edge_length(1, 3), 0.175, decimal=3) - np.testing.assert_almost_equal(tree.get_edge_length(1, 4), 0.175, decimal=3) - np.testing.assert_almost_equal(tree.get_edge_length(2, 5), 0.295, decimal=3) - np.testing.assert_almost_equal(tree.get_edge_length(2, 6), 0.295, decimal=3) - np.testing.assert_almost_equal(log_likelihood, -22.689, decimal=3) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) - np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) - - -def test_small_symmetric_tree(): - r""" - Symmetric tree should have equal length edges. 
- """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[0]["characters"] = "000" - tree.nodes[1]["characters"] = "100" - tree.nodes[2]["characters"] = "100" - tree.nodes[3]["characters"] = "110" - tree.nodes[4]["characters"] = "110" - tree.nodes[5]["characters"] = "110" - tree.nodes[6]["characters"] = "110" - tree = Tree(tree) - model = IIDExponentialBLE() - model.estimate_branch_lengths(tree) - log_likelihood = model.log_likelihood - np.testing.assert_almost_equal( - tree.get_edge_length(0, 1), tree.get_edge_length(0, 2) - ) - np.testing.assert_almost_equal( - tree.get_edge_length(1, 3), tree.get_edge_length(1, 4) - ) - np.testing.assert_almost_equal( - tree.get_edge_length(1, 4), tree.get_edge_length(2, 5) - ) - np.testing.assert_almost_equal( - tree.get_edge_length(2, 5), tree.get_edge_length(2, 6) - ) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) - np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) - - -def test_small_tree_with_infinite_legs(): - r""" - Perfect binary tree with saturated leaves. 
The first level of the tree - should be normal (can be solved by hand, solution is log(2)), - the branches for the leaves should be infinity (>15 for all practical - purposes) - """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[0]["characters"] = "00" - tree.nodes[1]["characters"] = "10" - tree.nodes[2]["characters"] = "10" - tree.nodes[3]["characters"] = "11" - tree.nodes[4]["characters"] = "11" - tree.nodes[5]["characters"] = "11" - tree.nodes[6]["characters"] = "11" - tree = Tree(tree) - model = IIDExponentialBLE() - model.estimate_branch_lengths(tree) - log_likelihood = model.log_likelihood - np.testing.assert_almost_equal(tree.get_edge_length(0, 1), 0.693, decimal=3) - np.testing.assert_almost_equal(tree.get_edge_length(0, 2), 0.693, decimal=3) - assert tree.get_edge_length(1, 3) > 15 - assert tree.get_edge_length(1, 4) > 15 - assert tree.get_edge_length(2, 5) > 15 - assert tree.get_edge_length(2, 6) > 15 - log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) - np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) - - -def test_on_simulated_data(): - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[0]["age"] = 1 - tree.nodes[1]["age"] = 0.9 - tree.nodes[2]["age"] = 0.1 - tree.nodes[3]["age"] = 0 - tree.nodes[4]["age"] = 0 - tree.nodes[5]["age"] = 0 - tree.nodes[6]["age"] = 0 - np.random.seed(1) - tree = Tree(tree) - IIDExponentialLineageTracer( - mutation_rate=1.0, num_characters=100 - ).overlay_lineage_tracing_data(tree) - for node in tree.nodes(): - tree.set_age(node, -1) - model = IIDExponentialBLE() - model.estimate_branch_lengths(tree) - log_likelihood = model.log_likelihood - assert 0.9 < tree.get_age(0) < 1.1 - assert 0.8 < tree.get_age(1) < 1.0 - assert 0.05 < tree.get_age(2) < 0.15 - 
np.testing.assert_almost_equal(tree.get_age(3), 0) - np.testing.assert_almost_equal(tree.get_age(4), 0) - np.testing.assert_almost_equal(tree.get_age(5), 0) - np.testing.assert_almost_equal(tree.get_age(6), 0) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) - np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) - - -def test_subtree_collapses_when_no_mutations(): - r""" - A subtree with no mutations should collapse to 0. It reduces the problem to - the same as in 'test_hand_solvable_problem_1' - """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4]), - tree.add_edges_from([(0, 1), (1, 2), (1, 3), (0, 4)]) - tree.nodes[0]["characters"] = "0" - tree.nodes[1]["characters"] = "1" - tree.nodes[2]["characters"] = "1" - tree.nodes[3]["characters"] = "1" - tree.nodes[4]["characters"] = "0" - tree = Tree(tree) - model = IIDExponentialBLE() - model.estimate_branch_lengths(tree) - log_likelihood = model.log_likelihood - np.testing.assert_almost_equal( - tree.get_edge_length(0, 1), np.log(2), decimal=3 - ) - np.testing.assert_almost_equal(tree.get_edge_length(1, 2), 0.0, decimal=3) - np.testing.assert_almost_equal(tree.get_edge_length(1, 3), 0.0, decimal=3) - np.testing.assert_almost_equal( - tree.get_edge_length(0, 4), np.log(2), decimal=3 - ) - np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) - log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) - np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) - - -def test_minimum_branch_length(): - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4]) - tree.add_edges_from([(0, 1), (0, 2), (0, 3), (2, 4)]) - tree = Tree(tree) - for node in tree.nodes(): - tree.set_state(node, "1") - tree.reconstruct_ancestral_states() - # Too large a minimum_branch_length - model = IIDExponentialBLE(minimum_branch_length=0.6) - model.estimate_branch_lengths(tree) - for node in tree.nodes(): - print(f"{node} = {tree.get_age(node)}") - assert 
model.log_likelihood == -np.inf - # An okay minimum_branch_length - model = IIDExponentialBLE(minimum_branch_length=0.4) - model.estimate_branch_lengths(tree) - assert model.log_likelihood != -np.inf - - -def test_IIDExponentialBLEGridSearchCV_smoke(): - r""" - Just want to see that it runs in both single and multiprocessor mode - """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1]), - tree.add_edges_from([(0, 1)]) - tree.nodes[0]["characters"] = "0" - tree.nodes[1]["characters"] = "1" - tree = Tree(tree) - for processes in [1, 2]: - model = IIDExponentialBLEGridSearchCV( - minimum_branch_lengths=(1.0,), - l2_regularizations=(1.0,), - verbose=True, - processes=processes, + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal(tree.get_branch_length("0", "1"), 0.0) + np.testing.assert_almost_equal(tree.get_time("0"), 0.0) + np.testing.assert_almost_equal(tree.get_time("1"), 0.0) + np.testing.assert_almost_equal(log_likelihood, 0.0) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_saturation(self): + r""" + Tree topology is just a branch 0->1. + There is one mutated character i.e.: + root [state = '0'] + | + v + child [state = '1'] + This is thus the simplest possible example of saturation, and the MLE + branch length should be infinity (>15 for all practical purposes) + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1"]) + tree.add_edge("0", "1") + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [1]} ) + model = IIDExponentialBLE() model.estimate_branch_lengths(tree) - - -def test_IIDExponentialBLEGridSearchCV(): - r""" - We make sure to test a tree for which no regularization produces - a likelihood of 0. 
- """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7]), - tree.add_edges_from( - [(0, 1), (1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)] - ) - tree.nodes[4]["characters"] = "110" - tree.nodes[5]["characters"] = "110" - tree.nodes[6]["characters"] = "100" - tree.nodes[7]["characters"] = "100" - tree = Tree(tree) - tree.reconstruct_ancestral_states() - model = IIDExponentialBLEGridSearchCV( - minimum_branch_lengths=(0, 0.2, 4.0), - l2_regularizations=(0.0, 2.0, 4.0), - verbose=True, - processes=6, - ) - model.estimate_branch_lengths(tree) - print(model.grid) - assert model.grid[0, 0] == -np.inf - - # import seaborn as sns - # import matplotlib.pyplot as plt - # sns.heatmap( - # model.grid, - # yticklabels=model.minimum_branch_lengths, - # xticklabels=model.l2_regularizations, - # mask=np.isneginf(model.grid), - # ) - # plt.ylabel("minimum_branch_length") - # plt.xlabel("l2_regularization") - # plt.show() - - np.testing.assert_almost_equal(model.minimum_branch_length, 0.2) - np.testing.assert_almost_equal(model.l2_regularization, 2.0) - - -def test_IIDExponentialPosteriorMeanBLE(): - r""" - For a small tree with only one internal node, the likelihood of the data, - and the posterior age of the internal node, can be computed easily in - closed form. We check the theoretical values against those obtained from - our model. 
- """ - from scipy.special import logsumexp - - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3]) - tree.add_edges_from([(0, 1), (1, 2), (1, 3)]) - tree.nodes[0]["characters"] = "000000000" - tree.nodes[1]["characters"] = "010000110" - tree.nodes[2]["characters"] = "010110111" - tree.nodes[3]["characters"] = "011100111" - tree = Tree(tree) - - mutation_rate = 0.3 - birth_rate = 0.8 - discretization_level = 200 - model = IIDExponentialPosteriorMeanBLE( - mutation_rate=mutation_rate, - birth_rate=birth_rate, - discretization_level=discretization_level, - ) - - model.estimate_branch_lengths(tree) - print(f"{model.log_likelihood} = model.log_likelihood") - - # Test the model log likelihood vs its computation from the joint of the - # age of vertex 1. - model_log_joints = model.log_joints[ - 1 - ] # log P(t_1 = t, X, T) where t_1 is the age of the first node. - model_log_likelihood_2 = logsumexp(model_log_joints) - print(f"{model_log_likelihood_2} = {model_log_likelihood_2}") - np.testing.assert_approx_equal( - model.log_likelihood, model_log_likelihood_2, significant=3 - ) - - # Test the model log likelihood vs its computation from the leaf nodes. - for leaf in [2, 3]: - model_log_likelihood_up = model.up( - leaf, 0, tree.num_cuts(leaf) - ) - np.log(birth_rate * 1.0 / discretization_level) - print(f"{model_log_likelihood_up} = model_log_likelihood_up") - np.testing.assert_approx_equal( - model.log_likelihood, model_log_likelihood_up, significant=3 + log_likelihood = model.log_likelihood + assert tree.get_branch_length("0", "1") > 15.0 + assert tree.get_time("1") > 15.0 + np.testing.assert_almost_equal(tree.get_time("0"), 0.0) + np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=5) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_hand_solvable_problem_1(self): + r""" + Tree topology is just a branch 0->1. 
+ There is one mutated character and one unmutated character, i.e.: + root [state = '00'] + | + v + child [state = '01'] + The solution can be verified by hand. The optimization problem is: + min_{r * t0} log(exp(-r * t0)) + log(1 - exp(-r * t0)) + The solution is r * t0 = ln(2) ~ 0.693 + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1"]) + tree.add_edge("0", "1") + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0], + "1": [0, 1]} ) - - # Test the model log likelihood against its numerical computation - numerical_log_likelihood = ( - IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( - tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal( + tree.get_branch_length("0", "1"), np.log(2), decimal=3 ) - ) - print(f"{numerical_log_likelihood} = numerical_log_likelihood") - np.testing.assert_approx_equal( - model.log_likelihood, numerical_log_likelihood, significant=3 - ) - - # Test the _whole_ array of log joints P(t_v = t, X, T) against its - # numerical computation - numerical_log_joint = IIDExponentialPosteriorMeanBLE.numerical_log_joint( - tree=tree, - node=1, - mutation_rate=mutation_rate, - birth_rate=birth_rate, - discretization_level=discretization_level, - ) - np.testing.assert_array_almost_equal( - model.log_joints[1][50:-50], numerical_log_joint[50:-50], decimal=1 - ) - - # Test the model posterior against its numerical posterior - numerical_posterior = IIDExponentialPosteriorMeanBLE.numerical_posterior( - tree=tree, - node=1, - mutation_rate=mutation_rate, - birth_rate=birth_rate, - discretization_level=discretization_level, - ) - # import matplotlib.pyplot as plt - # plt.plot(model.posteriors[1]) - # plt.show() - # plt.plot(numerical_posterior) - # plt.show() - total_variation = np.sum(np.abs(model.posteriors[1] - numerical_posterior)) - assert 
total_variation < 0.03 - - # Test the posterior mean against the numerical posterior mean. - numerical_posterior_mean = np.sum( - numerical_posterior - * np.array(range(discretization_level + 1)) - / discretization_level - ) - posterior_mean = tree.get_age(1) - np.testing.assert_approx_equal( - posterior_mean, numerical_posterior_mean, significant=2 - ) - - -def test_IIDExponentialPosteriorMeanBLE_2(): - r""" - We run the Bayesian estimator on a small tree with all different leaves, - and then check that: - - The likelihood of the data, computed from all of the leaves, is the same. - - The posteriors of the internal node ages matches their numerical - counterpart. - """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[0]["characters"] = "00" - tree.nodes[1]["characters"] = "00" - tree.nodes[2]["characters"] = "10" - tree.nodes[3]["characters"] = "00" - tree.nodes[4]["characters"] = "01" - tree.nodes[5]["characters"] = "10" - tree.nodes[6]["characters"] = "11" - tree = Tree(tree) - - mutation_rate = 0.625 - birth_rate = 0.75 - discretization_level = 100 - model = IIDExponentialPosteriorMeanBLE( - mutation_rate=mutation_rate, - birth_rate=birth_rate, - discretization_level=discretization_level, - ) - - model.estimate_branch_lengths(tree) - print(model.log_likelihood) - - # Test the model log likelihood against its numerical computation - numerical_log_likelihood = ( - IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( - tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + np.testing.assert_almost_equal(tree.get_time("1"), np.log(2), decimal=3) + np.testing.assert_almost_equal(tree.get_time("0"), 0.0) + np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_hand_solvable_problem_2(self): + 
r""" + Tree topology is just a branch 0->1. + There are two mutated characters and one unmutated character, i.e.: + root [state = '000'] + | + v + child [state = '011'] + The solution can be verified by hand. The optimization problem is: + min_{r * t0} log(exp(-r * t0)) + 2 * log(1 - exp(-r * t0)) + The solution is r * t0 = ln(3) ~ 1.098 + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1"]) + tree.add_edge("0", "1") + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0, 0], + "1": [0, 1, 1]} ) - ) - np.testing.assert_approx_equal( - model.log_likelihood, numerical_log_likelihood, significant=3 - ) - - # Check that the likelihood computed from each leaf node is correct. - for leaf in tree.leaves(): - model_log_likelihood_up = model.up( - leaf, 0, tree.num_cuts(leaf) - ) - np.log(birth_rate * 1.0 / discretization_level) - print(model_log_likelihood_up) - np.testing.assert_approx_equal( - model.log_likelihood, model_log_likelihood_up, significant=3 + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal( + tree.get_branch_length("0", "1"), np.log(3), decimal=3 ) - - model_log_likelihood_up_wrong = model.up( - leaf, 0, (tree.num_cuts(leaf) + 1) % 2 + np.testing.assert_almost_equal(tree.get_time("1"), np.log(3), decimal=3) + np.testing.assert_almost_equal(tree.get_time("0"), 0.0) + np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_hand_solvable_problem_3(self): + r""" + Tree topology is just a branch 0->1. + There are two unmutated characters and one mutated character, i.e.: + root [state = '000'] + | + v + child [state = '001'] + The solution can be verified by hand. 
The optimization problem is: + min_{r * t0} 2 * log(exp(-r * t0)) + log(1 - exp(-r * t0)) + The solution is r * t0 = ln(1.5) ~ 0.405 + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1"]) + tree.add_edge("0", "1") + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0, 0], + "1": [0, 0, 1]} ) - with pytest.raises(AssertionError): - np.testing.assert_approx_equal( - model.log_likelihood, - model_log_likelihood_up_wrong, - significant=3, - ) - - # Check that the posterior ages of the nodes are correct. - for node in tree.internal_nodes(): - numerical_log_joint = ( - IIDExponentialPosteriorMeanBLE.numerical_log_joint( - tree=tree, - node=node, - mutation_rate=mutation_rate, - birth_rate=birth_rate, - discretization_level=discretization_level, + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal( + tree.get_branch_length("0", "1"), np.log(1.5), decimal=3 + ) + np.testing.assert_almost_equal(tree.get_time("1"), np.log(1.5), decimal=3) + np.testing.assert_almost_equal(tree.get_time("0"), 0.0) + np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_small_tree_with_no_mutations(self): + r""" + Perfect binary tree with no mutations: Should give edges of length 0 + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]) + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0, 0, 0], + "1": [0, 0, 0, 0], + "2": [0, 0, 0, 0], + "3": [0, 0, 0, 0], + "4": [0, 0, 0, 0], + "5": [0, 0, 0, 0], + "6": [0, 0, 0, 0]} + ) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + for edge in 
tree.edges: + np.testing.assert_almost_equal( + tree.get_branch_length(*edge), 0, decimal=3 ) + np.testing.assert_almost_equal(log_likelihood, 0.0, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_small_tree_with_one_mutation(self): + r""" + Perfect binary tree with one mutation at a node 6: Should give very short + edges 1->3,1->4,0->2 and very long edges 0->1,2->5,2->6. + The problem can be solved by hand: it trivially reduces to a 1-dimensional + problem: + min_{r * t0} 2 * log(exp(-r * t0)) + log(1 - exp(-r * t0)) + The solution is r * t0 = ln(1.5) ~ 0.405 + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [0], + "2": [0], + "3": [0], + "4": [0], + "5": [0], + "6": [1]} ) - np.testing.assert_array_almost_equal( - model.log_joints[node][25:-25], - numerical_log_joint[25:-25], - decimal=1, + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal(tree.get_branch_length("0", "1"), 0.405, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("0", "2"), 0.0, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("1", "3"), 0.0, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("1", "4"), 0.0, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("2", "5"), 0.405, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("2", "6"), 0.405, decimal=3) + np.testing.assert_almost_equal(log_likelihood, -1.909, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def 
test_small_tree_with_saturation(self): + r""" + Perfect binary tree with saturation. The edges which saturate should thus + have length infinity (>15 for all practical purposes) + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [0], + "2": [1], + "3": [1], + "4": [1], + "5": [1], + "6": [1]} ) - - # Test the model posterior against its numerical posterior. - numerical_posterior = np.exp( - numerical_log_joint - numerical_log_joint.max() + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + assert tree.get_branch_length("0", "2") > 15.0 + assert tree.get_branch_length("1", "3") > 15.0 + assert tree.get_branch_length("1", "4") > 15.0 + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_small_tree_regression(self): + r""" + Regression test. Cannot be solved by hand. We just check that this solution + never changes. 
+ """ + # Perfect binary tree with normal amount of mutations on each edge + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0, 0, 0, 0, 0, 0, 0, 0], + "1": [1, 0, 0, 0, 0, 0, 0, 0, 0], + "2": [0, 0, 0, 0, 0, 6, 0, 0, 0], + "3": [1, 2, 0, 0, 0, 0, 0, 0, 0], + "4": [1, 0, 3, 0, 0, 0, 0, 0, 0], + "5": [0, 0, 0, 0, 5, 6, 7, 0, 0], + "6": [0, 0, 0, 4, 0, 6, 0, 8, 9]} ) - numerical_posterior /= numerical_posterior.sum() - # import matplotlib.pyplot as plt - # plt.plot(model.posteriors[node]) - # plt.show() - # plt.plot(analytical_posterior) - # plt.show() - total_variation = np.sum( - np.abs(model.posteriors[node] - numerical_posterior) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal(tree.get_branch_length("0", "1"), 0.203, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("0", "2"), 0.082, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("1", "3"), 0.175, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("1", "4"), 0.175, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("2", "5"), 0.295, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("2", "6"), 0.295, decimal=3) + np.testing.assert_almost_equal(log_likelihood, -22.689, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_small_symmetric_tree(self): + r""" + Symmetric tree should have equal length edges. 
+ """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0, 0], + "1": [1, 0, 0], + "2": [1, 0, 0], + "3": [1, 1, 0], + "4": [1, 1, 0], + "5": [1, 1, 0], + "6": [1, 1, 0]} ) - assert total_variation < 0.03 - - -@pytest.mark.slow -def test_IIDExponentialPosteriorMeanBLE_3(): - r""" - Same as test_IIDExponentialPosteriorMeanBLE_2 but with a weirder topology. - """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6, 7]), - tree.add_edges_from( - [(0, 1), (1, 2), (1, 3), (2, 4), (2, 5), (2, 6), (0, 7)] - ) - tree.nodes[0]["characters"] = "00" - tree.nodes[1]["characters"] = "00" - tree.nodes[2]["characters"] = "10" - tree.nodes[3]["characters"] = "11" - tree.nodes[4]["characters"] = "10" - tree.nodes[5]["characters"] = "10" - tree.nodes[6]["characters"] = "11" - tree.nodes[7]["characters"] = "00" - tree = Tree(tree) - - mutation_rate = 0.625 - birth_rate = 0.75 - discretization_level = 100 - model = IIDExponentialPosteriorMeanBLE( - mutation_rate=mutation_rate, - birth_rate=birth_rate, - discretization_level=discretization_level, - ) - - model.estimate_branch_lengths(tree) - print(model.log_likelihood) - - # Test the model log likelihood against its numerical computation - numerical_log_likelihood = ( - IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( - tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal( + tree.get_branch_length("0", "1"), tree.get_branch_length("0", "2") ) - ) - np.testing.assert_approx_equal( - model.log_likelihood, numerical_log_likelihood, significant=3 - ) - - # Check that the likelihood computed from each leaf node is correct. 
- for leaf in tree.leaves(): - model_log_likelihood_up = model.up( - leaf, 0, tree.num_cuts(leaf) - ) - np.log(birth_rate * 1.0 / discretization_level) - print(model_log_likelihood_up) - np.testing.assert_approx_equal( - model.log_likelihood, model_log_likelihood_up, significant=2 + np.testing.assert_almost_equal( + tree.get_branch_length("1", "3"), tree.get_branch_length("1", "4") ) - - # Check that the posterior ages of the nodes are correct. - for node in tree.internal_nodes(): - numerical_log_joint = ( - IIDExponentialPosteriorMeanBLE.numerical_log_joint( - tree=tree, - node=node, - mutation_rate=mutation_rate, - birth_rate=birth_rate, - discretization_level=discretization_level, - ) + np.testing.assert_almost_equal( + tree.get_branch_length("1", "4"), tree.get_branch_length("2", "5") ) - np.testing.assert_array_almost_equal( - model.log_joints[node][25:-25], - numerical_log_joint[25:-25], - decimal=1, + np.testing.assert_almost_equal( + tree.get_branch_length("2", "5"), tree.get_branch_length("2", "6") ) - - # Test the model posterior against its numerical posterior. - numerical_posterior = np.exp( - numerical_log_joint - numerical_log_joint.max() + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_small_tree_with_infinite_legs(self): + r""" + Perfect binary tree with saturated leaves. 
The first level of the tree + should be normal (can be solved by hand, solution is log(2)), + the branches for the leaves should be infinity (>15 for all practical + purposes) + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0], + "1": [1, 0], + "2": [1, 0], + "3": [1, 1], + "4": [1, 1], + "5": [1, 1], + "6": [1, 1]} ) - numerical_posterior /= numerical_posterior.sum() - # import matplotlib.pyplot as plt - # plt.plot(model.posteriors[node]) - # plt.show() - # plt.plot(numerical_posterior) - # plt.show() - total_variation = np.sum( - np.abs(model.posteriors[node] - numerical_posterior) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal(tree.get_branch_length("0", "1"), 0.693, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("0", "2"), 0.693, decimal=3) + assert tree.get_branch_length("1", "3") > 15 + assert tree.get_branch_length("1", "4") > 15 + assert tree.get_branch_length("2", "5") > 15 + assert tree.get_branch_length("2", "6") > 15 + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_on_simulated_data(self): + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.set_times( + {"0": 0, + "1": 0.1, + "2": 0.9, + "3": 1.0, + "4": 1.0, + "5": 1.0, + "6": 1.0} ) - assert total_variation < 0.03 - - -@pytest.mark.slow -def test_IIDExponentialPosteriorMeanBLE_DREAM_subC1(): - r""" - A tree from the DREAM subchallenge 1, verified analytically. 
- """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[3]["characters"] = "2011000111" - tree.nodes[4]["characters"] = "2011010111" - tree.nodes[5]["characters"] = "2011010111" - tree.nodes[6]["characters"] = "2011010111" - tree = Tree(tree) - tree.reconstruct_ancestral_states() - - mutation_rate = 0.6 - birth_rate = 0.8 - discretization_level = 500 - model = IIDExponentialPosteriorMeanBLE( - mutation_rate=mutation_rate, - birth_rate=birth_rate, - discretization_level=discretization_level, - ) - - model.estimate_branch_lengths(tree) - print(model.log_likelihood) - - # Test the model log likelihood against its numerical computation - numerical_log_likelihood = ( - IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( - tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + np.random.seed(1) + IIDExponentialLineageTracer( + mutation_rate=1.0, num_characters=100 + ).overlay_lineage_tracing_data(tree) + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + assert 0.05 < tree.get_time("1") < 0.15 + assert 0.8 < tree.get_time("2") < 1.0 + assert 0.9 < tree.get_time("3") < 1.1 + assert 0.9 < tree.get_time("4") < 1.1 + assert 0.9 < tree.get_time("5") < 1.1 + assert 0.9 < tree.get_time("6") < 1.1 + np.testing.assert_almost_equal(tree.get_time("0"), 0) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_subtree_collapses_when_no_mutations(self): + r""" + A subtree with no mutations should collapse to 0. 
It reduces the problem to + the same as in 'test_hand_solvable_problem_1' + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4"]), + tree.add_edges_from([("0", "1"), ("1", "2"), ("1", "3"), ("0", "4")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [1], + "2": [1], + "3": [1], + "4": [0]} ) - ) - np.testing.assert_approx_equal( - model.log_likelihood, numerical_log_likelihood, significant=3 - ) - - # Check that the likelihood computed from each leaf node is correct. - for leaf in tree.leaves(): - model_log_likelihood_up = model.up( - leaf, 0, tree.num_cuts(leaf) - ) - np.log(birth_rate * 1.0 / discretization_level) - print(model_log_likelihood_up) - np.testing.assert_approx_equal( - model.log_likelihood, model_log_likelihood_up, significant=3 + model = IIDExponentialBLE() + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal( + tree.get_branch_length("0", "1"), np.log(2), decimal=3 ) - - # Check that the posterior ages of the nodes are correct. 
- for node in tree.internal_nodes(): - numerical_log_joint = ( - IIDExponentialPosteriorMeanBLE.numerical_log_joint( - tree=tree, - node=node, - mutation_rate=mutation_rate, - birth_rate=birth_rate, - discretization_level=discretization_level, - ) + np.testing.assert_almost_equal(tree.get_branch_length("1", "2"), 0.0, decimal=3) + np.testing.assert_almost_equal(tree.get_branch_length("1", "3"), 0.0, decimal=3) + np.testing.assert_almost_equal( + tree.get_branch_length("0", "4"), np.log(2), decimal=3 ) - mean_error = np.mean( - np.abs(model.log_joints[node][25:-25] - numerical_log_joint[25:-25]) - / np.abs(numerical_log_joint[25:-25]) + np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + + def test_minimum_branch_length(self): + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4"]) + tree.add_edges_from([("0", "1"), ("0", "2"), ("0", "3"), ("2", "4")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_character_states_at_leaves( + {"1": [1], + "3": [1], + "4": [1]} ) - assert mean_error < 0.03 - - # Test the model posterior against its numerical posterior. 
- numerical_posterior = np.exp( - numerical_log_joint - numerical_log_joint.max() + tree.reconstruct_ancestral_characters(zero_the_root=True) + # Too large a minimum_branch_length + model = IIDExponentialBLE(minimum_branch_length=0.6) + model.estimate_branch_lengths(tree) + for node in tree.nodes: + print(f"{node} = {tree.get_time(node)}") + assert model.log_likelihood == -np.inf + # An okay minimum_branch_length + model = IIDExponentialBLE(minimum_branch_length=0.4) + model.estimate_branch_lengths(tree) + assert model.log_likelihood != -np.inf + + +class TestIIDExponentialBLEGridSearchCV(unittest.TestCase): + def test_IIDExponentialBLEGridSearchCV_smoke(self): + r""" + Just want to see that it runs in both single and multiprocessor mode + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1"]), + tree.add_edges_from([("0", "1")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [1]}, ) - numerical_posterior /= numerical_posterior.sum() - # import matplotlib.pyplot as plt - # plt.plot(model.posteriors[node]) - # plt.show() - # plt.plot(numerical_posterior) - # plt.show() - total_variation = np.sum( - np.abs(model.posteriors[node] - numerical_posterior) + for processes in [1, 2]: + model = IIDExponentialBLEGridSearchCV( + minimum_branch_lengths=(1.0,), + l2_regularizations=(1.0,), + verbose=True, + processes=processes, + ) + model.estimate_branch_lengths(tree) + + def test_IIDExponentialBLEGridSearchCV(self): + r""" + We make sure to test a tree for which no regularization produces + a likelihood of 0. 
+ """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6", "7"]), + tree.add_edges_from( + [("0", "1"), ("1", "2"), ("1", "3"), ("2", "4"), ("2", "5"), ("3", "6"), + ("3", "7")] ) - assert total_variation < 0.05 - - -def test_IIDExponentialPosteriorMeanBLEGridSeachCV_smoke(): - r""" - Just want to see that it runs in both single and multiprocessor mode - """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1]), - tree.add_edges_from([(0, 1)]) - tree.nodes[0]["characters"] = "0" - tree.nodes[1]["characters"] = "1" - tree = Tree(tree) - for processes in [1, 2]: - model = IIDExponentialPosteriorMeanBLEGridSearchCV( - mutation_rates=(0.5,), - birth_rates=(1.5,), - discretization_level=5, + tree = CassiopeiaTree(tree=tree) + tree.initialize_character_states_at_leaves( + {"4": [1, 1, 0], + "5": [1, 1, 0], + "6": [1, 0, 0], + "7": [1, 0, 0]}, + ) + tree.reconstruct_ancestral_characters(zero_the_root=True) + model = IIDExponentialBLEGridSearchCV( + minimum_branch_lengths=(0, 0.2, 4.0), + l2_regularizations=(0.0, 2.0, 4.0), verbose=True, + processes=6, ) model.estimate_branch_lengths(tree) + print(model.grid) + assert model.grid[0, 0] == -np.inf + # import seaborn as sns + # import matplotlib.pyplot as plt + # sns.heatmap( + # model.grid, + # yticklabels=model.minimum_branch_lengths, + # xticklabels=model.l2_regularizations, + # mask=np.isneginf(model.grid), + # ) + # plt.ylabel("minimum_branch_length") + # plt.xlabel("l2_regularization") + # plt.show() -def test_IIDExponentialPosteriorMeanBLEGridSeachCV(): - r""" - We just check that the grid search estimator does its job on a small grid. 
- """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4]), - tree.add_edges_from([(0, 1), (1, 2), (1, 3), (0, 4)]) - tree.nodes[0]["characters"] = "0" - tree.nodes[1]["characters"] = "1" - tree.nodes[2]["characters"] = "1" - tree.nodes[3]["characters"] = "1" - tree.nodes[4]["characters"] = "0" - tree = Tree(tree) - - discretization_level = 100 - mutation_rates = (0.625, 0.750, 0.875) - birth_rates = (0.25, 0.50, 0.75) - model = IIDExponentialPosteriorMeanBLEGridSearchCV( - mutation_rates=mutation_rates, - birth_rates=birth_rates, - discretization_level=discretization_level, - verbose=True, - ) - - # Test the model log likelihood against its numerical computation - model.estimate_branch_lengths(tree) - numerical_log_likelihood = ( - IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( - tree=tree, - mutation_rate=model.mutation_rate, - birth_rate=model.birth_rate, - ) - ) - np.testing.assert_approx_equal( - model.log_likelihood, numerical_log_likelihood, significant=3 - ) - - # import matplotlib.pyplot as plt - # import seaborn as sns - # sns.heatmap( - # model.grid, - # yticklabels=mutation_rates, - # xticklabels=birth_rates - # ) - # plt.ylabel('Mutation Rate') - # plt.xlabel('Birth Rate') - # plt.show() - - np.testing.assert_almost_equal(model.mutation_rate, 0.75) - np.testing.assert_almost_equal(model.birth_rate, 0.5) - np.testing.assert_almost_equal(model.posterior_means[1], 0.3184, decimal=3) + np.testing.assert_almost_equal(model.minimum_branch_length, 0.2) + np.testing.assert_almost_equal(model.l2_regularization, 2.0) def get_z_scores( @@ -937,6 +530,10 @@ def get_z_scores( mutation_rate_model, num_characters, ): + r""" + This function is at the global scope because it needs to be pickled + for parallelization. 
+ """ np.random.seed(repetition) tree = BirthProcess( birth_rate=birth_rate_true, tree_depth=1.0 @@ -953,9 +550,9 @@ def get_z_scores( ) model.estimate_branch_lengths(tree) z_scores = [] - if len(tree.internal_nodes()) > 0: - for node in [np.random.choice(tree.internal_nodes())]: - true_age = tree_true.get_age(node) + if len(tree.non_root_internal_nodes) > 0: + for node in [np.random.choice(tree.non_root_internal_nodes)]: + true_age = tree_true.get_time(node) z_score = model.posteriors[node][ : int(true_age * discretization_level) ].sum() @@ -964,6 +561,10 @@ def get_z_scores( def get_z_scores_under_true_model(repetition): + r""" + This function is at the global scope because it needs to be pickled + for parallelization. + """ return get_z_scores( repetition, birth_rate_true=0.8, @@ -975,6 +576,10 @@ def get_z_scores_under_true_model(repetition): def get_z_scores_under_misspecified_model(repetition): + r""" + This function is at the global scope because it needs to be pickled + for parallelization. + """ return get_z_scores( repetition, birth_rate_true=0.4, @@ -985,41 +590,489 @@ def get_z_scores_under_misspecified_model(repetition): ) -@pytest.mark.slow -def test_IIDExponentialPosteriorMeanBLE_posterior_calibration(): - r""" - Under the true model, the Z scores should be ~Unif[0, 1] - Under the wrong model, the Z scores should not be ~Unif[0, 1] - This test is slow because we need to make many repetitions to get - enough statistical power for the test to be meaningful. - We use p-values computed from the Hoeffding bound. - TODO: There might be a more powerful test, e.g. Kolmogorov–Smirnov? - (This would mean we need less repetitions and can make the test faster.) 
- """ - repetitions = 1000 - - # Under the true model, the Z scores should be ~Unif[0, 1] - with multiprocessing.Pool(processes=6) as pool: - z_scores = pool.map(get_z_scores_under_true_model, range(repetitions)) - z_scores = np.array(list(itertools.chain(*z_scores))) - mean_z_score = z_scores.mean() - p_value = 2 * np.exp(-2 * repetitions * (mean_z_score - 0.5) ** 2) - print(f"p_value under true model = {p_value}") - assert p_value > 0.01 - # import matplotlib.pyplot as plt - # plt.hist(z_scores, bins=10) - # plt.show() - - # Under the wrong model, the Z scores should not be ~Unif[0, 1] - with multiprocessing.Pool(processes=6) as pool: - z_scores = pool.map( - get_z_scores_under_misspecified_model, range(repetitions) +class TestIIDExponentialPosteriorMeanBLE(unittest.TestCase): + def test_IIDExponentialPosteriorMeanBLE(self): + r""" + For a small tree with only one internal node, the likelihood of the data, + and the posterior age of the internal node, can be computed easily in + closed form. We check the theoretical values against those obtained from + our model. + """ + from scipy.special import logsumexp + + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3"]) + tree.add_edges_from([("0", "1"), ("1", "2"), ("1", "3")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0, 0, 0, 0, 0, 0, 0, 0], + "1": [0, 1, 0, 0, 0, 0, 1, 1, 0], + "2": [0, 1, 0, 1, 1, 0, 1, 1, 1], + "3": [0, 1, 1, 1, 0, 0, 1, 1, 1]}, + ) + + mutation_rate = 0.3 + birth_rate = 0.8 + discretization_level = 200 + model = IIDExponentialPosteriorMeanBLE( + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + + model.estimate_branch_lengths(tree) + print(f"{model.log_likelihood} = model.log_likelihood") + + # Test the model log likelihood vs its computation from the joint of the + # age of vertex 1. 
+ model_log_joints = model.log_joints[ + "1" + ] # log P(t_1 = t, X, T) where t_1 is the age of the first node. + model_log_likelihood_2 = logsumexp(model_log_joints) + print(f"{model_log_likelihood_2} = {model_log_likelihood_2}") + np.testing.assert_approx_equal( + model.log_likelihood, model_log_likelihood_2, significant=3 + ) + + # Test the model log likelihood vs its computation from the leaf nodes. + for leaf in ["2", "3"]: + model_log_likelihood_up = model.up( + leaf, discretization_level, tree.get_number_of_mutated_characters_in_node(leaf) + ) - np.log(birth_rate * 1.0 / discretization_level) + print(f"{model_log_likelihood_up} = model_log_likelihood_up") + np.testing.assert_approx_equal( + model.log_likelihood, model_log_likelihood_up, significant=3 + ) + + # Test the model log likelihood against its numerical computation + numerical_log_likelihood = ( + IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( + tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + ) + ) + print(f"{numerical_log_likelihood} = numerical_log_likelihood") + np.testing.assert_approx_equal( + model.log_likelihood, numerical_log_likelihood, significant=3 + ) + + # Test the _whole_ array of log joints P(t_v = t, X, T) against its + # numerical computation + numerical_log_joint = IIDExponentialPosteriorMeanBLE.numerical_log_joint( + tree=tree, + node="1", + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + np.testing.assert_array_almost_equal( + model.log_joints["1"][50:-50], numerical_log_joint[50:-50], decimal=1 + ) + + # Test the model posterior against its numerical posterior + numerical_posterior = IIDExponentialPosteriorMeanBLE.numerical_posterior( + tree=tree, + node="1", + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + # import matplotlib.pyplot as plt + # plt.plot(model.posteriors[1]) + # plt.show() + # plt.plot(numerical_posterior) + # plt.show() + 
total_variation = np.sum(np.abs(model.posteriors["1"] - numerical_posterior)) + assert total_variation < 0.03 + + # Test the posterior mean against the numerical posterior mean. + numerical_posterior_mean = np.sum( + numerical_posterior + * np.array(range(discretization_level + 1)) + / discretization_level ) - z_scores = np.array(list(itertools.chain(*z_scores))) - mean_z_score = z_scores.mean() - p_value = 2 * np.exp(-2 * repetitions * (mean_z_score - 0.5) ** 2) - print(f"p_value under misspecified model = {p_value}") - assert p_value < 0.01 - # import matplotlib.pyplot as plt - # plt.hist(z_scores, bins=10) - # plt.show() + posterior_mean = tree.get_time("1") + np.testing.assert_approx_equal( + posterior_mean, numerical_posterior_mean, significant=2 + ) + + def test_IIDExponentialPosteriorMeanBLE_2(self): + r""" + We run the Bayesian estimator on a small tree with all different leaves, + and then check that: + - The likelihood of the data, computed from all of the leaves, is the same. + - The posteriors of the internal node ages matches their numerical + counterpart. 
+ """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0], + "1": [0, 0], + "2": [1, 0], + "3": [0, 0], + "4": [0, 1], + "5": [1, 0], + "6": [1, 1]}, + ) + + mutation_rate = 0.625 + birth_rate = 0.75 + discretization_level = 100 + model = IIDExponentialPosteriorMeanBLE( + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + + model.estimate_branch_lengths(tree) + print(model.log_likelihood) + + # Test the model log likelihood against its numerical computation + numerical_log_likelihood = ( + IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( + tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + ) + ) + np.testing.assert_approx_equal( + model.log_likelihood, numerical_log_likelihood, significant=3 + ) + + # Check that the likelihood computed from each leaf node is correct. + for leaf in tree.leaves: + model_log_likelihood_up = model.up( + leaf, discretization_level, tree.get_number_of_mutated_characters_in_node(leaf) + ) - np.log(birth_rate * 1.0 / discretization_level) + print(model_log_likelihood_up) + np.testing.assert_approx_equal( + model.log_likelihood, model_log_likelihood_up, significant=3 + ) + + model_log_likelihood_up_wrong = model.up( + leaf, discretization_level, (tree.get_number_of_mutated_characters_in_node(leaf) + 1) % 2 + ) + with pytest.raises(AssertionError): + np.testing.assert_approx_equal( + model.log_likelihood, + model_log_likelihood_up_wrong, + significant=3, + ) + + # Check that the posterior ages of the nodes are correct. 
+ for node in tree.non_root_internal_nodes: + numerical_log_joint = ( + IIDExponentialPosteriorMeanBLE.numerical_log_joint( + tree=tree, + node=node, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + ) + np.testing.assert_array_almost_equal( + model.log_joints[node][25:-25], + numerical_log_joint[25:-25], + decimal=1, + ) + + # Test the model posterior against its numerical posterior. + numerical_posterior = np.exp( + numerical_log_joint - numerical_log_joint.max() + ) + numerical_posterior /= numerical_posterior.sum() + # import matplotlib.pyplot as plt + # plt.plot(model.posteriors[node]) + # plt.show() + # plt.plot(analytical_posterior) + # plt.show() + total_variation = np.sum( + np.abs(model.posteriors[node] - numerical_posterior) + ) + assert total_variation < 0.03 + + @pytest.mark.slow + def test_IIDExponentialPosteriorMeanBLE_3(self): + r""" + Same as test_IIDExponentialPosteriorMeanBLE_2 but with a weirder topology. + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6", "7"]), + tree.add_edges_from( + [("0", "1"), ("1", "2"), ("1", "3"), ("2", "4"), ("2", "5"), ("2", "6"), + ("0", "7")] + ) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0], + "1": [0, 0], + "2": [1, 0], + "3": [1, 1], + "4": [1, 0], + "5": [1, 0], + "6": [1, 1], + "7": [0, 0]}, + ) + + mutation_rate = 0.625 + birth_rate = 0.75 + discretization_level = 100 + model = IIDExponentialPosteriorMeanBLE( + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + + model.estimate_branch_lengths(tree) + print(model.log_likelihood) + + # Test the model log likelihood against its numerical computation + numerical_log_likelihood = ( + IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( + tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + ) + ) + np.testing.assert_approx_equal( + model.log_likelihood, 
numerical_log_likelihood, significant=3 + ) + + # Check that the likelihood computed from each leaf node is correct. + for leaf in tree.leaves: + model_log_likelihood_up = model.up( + leaf, discretization_level, tree.get_number_of_mutated_characters_in_node(leaf) + ) - np.log(birth_rate * 1.0 / discretization_level) + print(model_log_likelihood_up) + np.testing.assert_approx_equal( + model.log_likelihood, model_log_likelihood_up, significant=2 + ) + + # Check that the posterior ages of the nodes are correct. + for node in tree.non_root_internal_nodes: + numerical_log_joint = ( + IIDExponentialPosteriorMeanBLE.numerical_log_joint( + tree=tree, + node=node, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + ) + np.testing.assert_array_almost_equal( + model.log_joints[node][25:-25], + numerical_log_joint[25:-25], + decimal=1, + ) + + # Test the model posterior against its numerical posterior. + numerical_posterior = np.exp( + numerical_log_joint - numerical_log_joint.max() + ) + numerical_posterior /= numerical_posterior.sum() + # import matplotlib.pyplot as plt + # plt.plot(model.posteriors[node]) + # plt.show() + # plt.plot(numerical_posterior) + # plt.show() + total_variation = np.sum( + np.abs(model.posteriors[node] - numerical_posterior) + ) + assert total_variation < 0.03 + + @pytest.mark.slow + def test_IIDExponentialPosteriorMeanBLE_DREAM_subC1(self): + r""" + A tree from the DREAM subchallenge 1, verified analytically. 
+ """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from([("0", "1"), ("0", "2"), ("1", "3"), ("1", "4"), + ("2", "5"), ("2", "6")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_character_states_at_leaves( + {"3": [2, 0, 1, 1, 0, 0, 0, 1, 1, 1], + "4": [2, 0, 1, 1, 0, 1, 0, 1, 1, 1], + "5": [2, 0, 1, 1, 0, 1, 0, 1, 1, 1], + "6": [2, 0, 1, 1, 0, 1, 0, 1, 1, 1]}, + ) + tree.reconstruct_ancestral_characters(zero_the_root=True) + + mutation_rate = 0.6 + birth_rate = 0.8 + discretization_level = 500 + model = IIDExponentialPosteriorMeanBLE( + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + + model.estimate_branch_lengths(tree) + print(model.log_likelihood) + + # Test the model log likelihood against its numerical computation + numerical_log_likelihood = ( + IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( + tree=tree, mutation_rate=mutation_rate, birth_rate=birth_rate + ) + ) + np.testing.assert_approx_equal( + model.log_likelihood, numerical_log_likelihood, significant=3 + ) + + # Check that the likelihood computed from each leaf node is correct. + for leaf in tree.leaves: + model_log_likelihood_up = model.up( + leaf, discretization_level, tree.get_number_of_mutated_characters_in_node(leaf) + ) - np.log(birth_rate * 1.0 / discretization_level) + print(model_log_likelihood_up) + np.testing.assert_approx_equal( + model.log_likelihood, model_log_likelihood_up, significant=3 + ) + + # Check that the posterior ages of the nodes are correct. 
+ for node in tree.non_root_internal_nodes: + numerical_log_joint = ( + IIDExponentialPosteriorMeanBLE.numerical_log_joint( + tree=tree, + node=node, + mutation_rate=mutation_rate, + birth_rate=birth_rate, + discretization_level=discretization_level, + ) + ) + mean_error = np.mean( + np.abs(model.log_joints[node][25:-25] - numerical_log_joint[25:-25]) + / np.abs(numerical_log_joint[25:-25]) + ) + assert mean_error < 0.03 + + # Test the model posterior against its numerical posterior. + numerical_posterior = np.exp( + numerical_log_joint - numerical_log_joint.max() + ) + numerical_posterior /= numerical_posterior.sum() + # import matplotlib.pyplot as plt + # plt.plot(model.posteriors[node]) + # plt.show() + # plt.plot(numerical_posterior) + # plt.show() + total_variation = np.sum( + np.abs(model.posteriors[node] - numerical_posterior) + ) + assert total_variation < 0.05 + + @pytest.mark.slow + def test_IIDExponentialPosteriorMeanBLE_posterior_calibration(self): + r""" + Under the true model, the Z scores should be ~Unif[0, 1] + Under the wrong model, the Z scores should not be ~Unif[0, 1] + This test is slow because we need to make many repetitions to get + enough statistical power for the test to be meaningful. + We use p-values computed from the Hoeffding bound. + TODO: There might be a more powerful test, e.g. Kolmogorov–Smirnov? + (This would mean we need less repetitions and can make the test faster.) 
+ """ + repetitions = 1000 + + # Under the true model, the Z scores should be ~Unif[0, 1] + with multiprocessing.Pool(processes=6) as pool: + z_scores = pool.map(get_z_scores_under_true_model, range(repetitions)) + z_scores = np.array(list(itertools.chain(*z_scores))) + mean_z_score = z_scores.mean() + p_value = 2 * np.exp(-2 * repetitions * (mean_z_score - 0.5) ** 2) + print(f"p_value under true model = {p_value}") + assert p_value > 0.01 + # import matplotlib.pyplot as plt + # plt.hist(z_scores, bins=10) + # plt.show() + + # Under the wrong model, the Z scores should not be ~Unif[0, 1] + with multiprocessing.Pool(processes=6) as pool: + z_scores = pool.map( + get_z_scores_under_misspecified_model, range(repetitions) + ) + z_scores = np.array(list(itertools.chain(*z_scores))) + mean_z_score = z_scores.mean() + p_value = 2 * np.exp(-2 * repetitions * (mean_z_score - 0.5) ** 2) + print(f"p_value under misspecified model = {p_value}") + assert p_value < 0.01 + # import matplotlib.pyplot as plt + # plt.hist(z_scores, bins=10) + # plt.show() + + +class TestIIDExponentialPosteriorMeanBLEGridSeachCV(unittest.TestCase): + def test_IIDExponentialPosteriorMeanBLEGridSeachCV_smoke(self): + r""" + Just want to see that it runs in both single and multiprocessor mode + """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1"]), + tree.add_edges_from([("0", "1")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [1]} + ) + for processes in [1, 2]: + model = IIDExponentialPosteriorMeanBLEGridSearchCV( + mutation_rates=(0.5,), + birth_rates=(1.5,), + discretization_level=5, + verbose=True, + ) + model.estimate_branch_lengths(tree) + + def test_IIDExponentialPosteriorMeanBLEGridSeachCV(self): + r""" + We just check that the grid search estimator does its job on a small grid. 
+ """ + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4"]), + tree.add_edges_from([("0", "1"), ("1", "2"), ("1", "3"), ("0", "4")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0], + "1": [1], + "2": [1], + "3": [1], + "4": [0]} + ) + + discretization_level = 100 + mutation_rates = (0.625, 0.750, 0.875) + birth_rates = (0.25, 0.50, 0.75) + model = IIDExponentialPosteriorMeanBLEGridSearchCV( + mutation_rates=mutation_rates, + birth_rates=birth_rates, + discretization_level=discretization_level, + verbose=True, + ) + + # Test the model log likelihood against its numerical computation + model.estimate_branch_lengths(tree) + numerical_log_likelihood = ( + IIDExponentialPosteriorMeanBLE.numerical_log_likelihood( + tree=tree, + mutation_rate=model.mutation_rate, + birth_rate=model.birth_rate, + ) + ) + np.testing.assert_approx_equal( + model.log_likelihood, numerical_log_likelihood, significant=3 + ) + + # import matplotlib.pyplot as plt + # import seaborn as sns + # sns.heatmap( + # model.grid, + # yticklabels=mutation_rates, + # xticklabels=birth_rates + # ) + # plt.ylabel('Mutation Rate') + # plt.xlabel('Birth Rate') + # plt.show() + + np.testing.assert_almost_equal(model.mutation_rate, 0.75) + np.testing.assert_almost_equal(model.birth_rate, 0.5) + np.testing.assert_almost_equal(model.posterior_means["1"], 0.6815, decimal=3) diff --git a/test/tools_tests/lineage_simulator_test.py b/test/tools_tests/lineage_simulator_test.py index 903dea41..6e3067f2 100644 --- a/test/tools_tests/lineage_simulator_test.py +++ b/test/tools_tests/lineage_simulator_test.py @@ -1,3 +1,6 @@ +import pytest +import unittest + import numpy as np from cassiopeia.tools import ( @@ -7,53 +10,77 @@ ) -def test_PerfectBinaryTree(): - tree = PerfectBinaryTree( - generation_branch_lengths=[2, 3] - ).simulate_lineage() - newick = tree.to_newick_tree_format(print_internal_nodes=True) - assert newick == "((3:3,4:3)1:2,(5:3,6:3)2:2)0);" - - -def 
test_PerfectBinaryTreeWithRootBranch(): - tree = PerfectBinaryTreeWithRootBranch( - generation_branch_lengths=[2, 3, 4] - ).simulate_lineage() - newick = tree.to_newick_tree_format(print_internal_nodes=True) - assert newick == "(((4:4,5:4)2:3,(6:4,7:4)3:3)1:2)0);" - - -def test_BirthProcess(): - r""" - Generate tree, then choose a random lineage can count how many nodes are on - the lineage. This is the number of times the process triggered on that - lineage. - - Also, the probability that a tree with only one internal node is obtained is - e^-lam * (1 - e^-lam) where lam is the birth rate, so we also check this. - """ - np.random.seed(1) - birth_rate = 0.6 - intensities = [] - repetitions = 10000 - topology_hits = 0 - for _ in range(repetitions): - tree_true = BirthProcess( - birth_rate=birth_rate, tree_depth=1.0 +class TestPerfectBinaryTree(unittest.TestCase): + def test_PerfectBinaryTree(self): + tree = PerfectBinaryTree( + generation_branch_lengths=[2, 3] ).simulate_lineage() - if len(tree_true.nodes()) == 4: - topology_hits += 1 - leaf = np.random.choice(tree_true.leaves()) - n_leaves = len(tree_true.leaves()) - n_hits = tree_true.num_ancestors(leaf) - 1 - intensity = n_leaves / 2 ** n_hits * n_hits - intensities.append(intensity) - # Check that the probability of the topology matches - empirical_topology_prob = topology_hits / repetitions - theoretical_topology_prob = np.exp(-birth_rate) * ( - 1.0 - np.exp(-birth_rate) - ) - assert np.abs(empirical_topology_prob - theoretical_topology_prob) < 0.02 - inferred_birth_rate = np.array(intensities).mean() - print(f"{birth_rate} == {inferred_birth_rate}") - assert np.abs(birth_rate - inferred_birth_rate) < 0.05 + newick = tree.get_newick() + assert newick == "((3,4),(5,6));" + self.assertDictEqual( + tree.get_times(), {0: 0, 1: 2, 2: 2, 3: 5, 4: 5, 5: 5, 6: 5} + ) + + +class TestPerfectBinaryTreeWithRootBranch(unittest.TestCase): + def test_PerfectBinaryTreeWithRootBranch(self): + tree = 
PerfectBinaryTreeWithRootBranch( + generation_branch_lengths=[2, 3, 4] + ).simulate_lineage() + newick = tree.get_newick() + assert newick == "(((4,5),(6,7)));" + self.assertDictEqual( + tree.get_times(), {0: 0, 1: 2, 2: 5, 3: 5, 4: 9, 5: 9, 6: 9, 7: 9} + ) + + +class TestBirthProcess(unittest.TestCase): + @pytest.mark.slow + def test_BirthProcess(self): + r""" + Generate tree, then choose a random lineage can count how many nodes are on + the lineage. This is the number of times the process triggered on that + lineage. + + Also, the probability that a tree with only one internal node is obtained is + e^-lam * (1 - e^-lam) where lam is the birth rate, so we also check this. + """ + np.random.seed(1) + birth_rate = 0.6 + intensities = [] + repetitions = 10000 + topology_hits = 0 + + def num_ancestors(tree, node: int) -> int: + r""" + Number of ancestors of a node. Terribly inefficient implementation. + """ + res = 0 + root = tree.root + while node != root: + node = tree.parent(node) + res += 1 + return res + + for _ in range(repetitions): + tree_true = BirthProcess( + birth_rate=birth_rate, tree_depth=1.0 + ).simulate_lineage() + if len(tree_true.nodes) == 4: + topology_hits += 1 + leaf = np.random.choice(tree_true.leaves) + n_leaves = len(tree_true.leaves) + n_hits = num_ancestors(tree_true, leaf) - 1 + intensity = n_leaves / 2 ** n_hits * n_hits + intensities.append(intensity) + # Check that the probability of the topology matches + empirical_topology_prob = topology_hits / repetitions + theoretical_topology_prob = np.exp(-birth_rate) * ( + 1.0 - np.exp(-birth_rate) + ) + assert ( + np.abs(empirical_topology_prob - theoretical_topology_prob) < 0.02 + ) + inferred_birth_rate = np.array(intensities).mean() + print(f"{birth_rate} == {inferred_birth_rate}") + assert np.abs(birth_rate - inferred_birth_rate) < 0.05 diff --git a/test/tools_tests/lineage_tracing_simulator_test.py b/test/tools_tests/lineage_tracing_simulator_test.py index 4d2bab73..afb004dd 100644 --- 
a/test/tools_tests/lineage_tracing_simulator_test.py +++ b/test/tools_tests/lineage_tracing_simulator_test.py @@ -1,25 +1,35 @@ +import unittest + import networkx as nx import numpy as np -from cassiopeia.tools import IIDExponentialLineageTracer, Tree +from cassiopeia.data import CassiopeiaTree +from cassiopeia.tools import IIDExponentialLineageTracer -def test_smoke(): - r""" - Just tests that lineage_tracing_simulator runs - """ - np.random.seed(1) - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4, 5, 6]), - tree.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) - tree.nodes[0]["age"] = 1 - tree.nodes[1]["age"] = 0.9 - tree.nodes[2]["age"] = 0.1 - tree.nodes[3]["age"] = 0 - tree.nodes[4]["age"] = 0 - tree.nodes[5]["age"] = 0 - tree.nodes[6]["age"] = 0 - tree = Tree(tree) - IIDExponentialLineageTracer( - mutation_rate=1.0, num_characters=10 - ).overlay_lineage_tracing_data(tree) +class Test(unittest.TestCase): + def test_smoke(self): + r""" + Just tests that lineage_tracing_simulator runs + """ + np.random.seed(1) + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3", "4", "5", "6"]), + tree.add_edges_from( + [ + ("0", "1"), + ("0", "2"), + ("1", "3"), + ("1", "4"), + ("2", "5"), + ("2", "6"), + ] + ) + tree = CassiopeiaTree(tree=tree) + tree.set_times( + {"0": 0, "1": 0.1, "2": 0.9, "3": 1.0, "4": 1.0, "5": 1.0, "6": 1.0} + ) + np.random.seed(1) + IIDExponentialLineageTracer( + mutation_rate=1.0, num_characters=10 + ).overlay_lineage_tracing_data(tree) From 4d0a3b6a25b24f325f18601843cdf86adea229dd Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Fri, 29 Jan 2021 11:57:14 -0800 Subject: [PATCH 41/61] Remove duplicated code --- cassiopeia/data/CassiopeiaTree.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/cassiopeia/data/CassiopeiaTree.py b/cassiopeia/data/CassiopeiaTree.py index b1283173..7df338a0 100644 --- a/cassiopeia/data/CassiopeiaTree.py +++ 
b/cassiopeia/data/CassiopeiaTree.py @@ -364,8 +364,7 @@ def non_root_internal_nodes(self) -> List[str]: Raises: CassiopeiaTreeError if the tree has not been initialized. """ - if self.__network is None: - raise CassiopeiaTreeError("Tree is not initialized.") + self.__check_network_initialized() if "non_root_internal_nodes" not in self.__cache: res = [ @@ -556,8 +555,7 @@ def set_times(self, time_dict: Dict[str, float]) -> None: CassiopeiaTreeError if the tree is not initialized, if the time of any parent is greater than that of a child. """ - if self.__network is None: - raise CassiopeiaTreeError("Tree is not initialized") + self.__check_network_initialized() # TODO: Check that the keys of time_dict match exactly the nodes in the # tree and raise otherwise? @@ -594,8 +592,7 @@ def get_times(self) -> Dict[str, float]: Raises: CassiopeiaTreeError if the tree has not been initialized. """ - if self.__network is None: - raise CassiopeiaTreeError("Tree is not initialized.") + self.__check_network_initialized() return dict([(node, self.get_time(node)) for node in self.nodes]) From c419804a2b0fffc8cf5d12334b5f340dfae9149e Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Thu, 4 Feb 2021 12:04:21 -0800 Subject: [PATCH 42/61] check --- cassiopeia/data/CassiopeiaTree.py | 2 + cassiopeia/tools/__init__.py | 1 - cassiopeia/tools/tree.py | 306 ------------------------------ test/tools_tests/tree_test.py | 188 ------------------ 4 files changed, 2 insertions(+), 495 deletions(-) delete mode 100644 cassiopeia/tools/tree.py delete mode 100644 test/tools_tests/tree_test.py diff --git a/cassiopeia/data/CassiopeiaTree.py b/cassiopeia/data/CassiopeiaTree.py index 7df338a0..1c1ce671 100644 --- a/cassiopeia/data/CassiopeiaTree.py +++ b/cassiopeia/data/CassiopeiaTree.py @@ -691,6 +691,8 @@ def get_character_states(self, node: str) -> List[int]: Returns: The full character state array of the specified node. 
""" + if node not in self.__network.nodes: + raise CassiopeiaTreeError(f"Node {node} does not exist!") return self.__network.nodes[node]["character_states"][:] def depth_first_traverse_nodes( diff --git a/cassiopeia/tools/__init__.py b/cassiopeia/tools/__init__.py index 9d1b577d..ff6ecdf0 100644 --- a/cassiopeia/tools/__init__.py +++ b/cassiopeia/tools/__init__.py @@ -15,4 +15,3 @@ LineageTracingSimulator, IIDExponentialLineageTracer, ) -from .tree import Tree diff --git a/cassiopeia/tools/tree.py b/cassiopeia/tools/tree.py deleted file mode 100644 index 0e06c703..00000000 --- a/cassiopeia/tools/tree.py +++ /dev/null @@ -1,306 +0,0 @@ -from typing import List, Optional, Tuple - -import networkx as nx - - -class Tree: - r""" - A phylogenetic tree for holding data from lineages and lineage tracing - experiments. - - (Currently implemented as a light wrapper over networkx.DiGraph) - - Args: - tree: The networkx.DiGraph from which to create the tree. - """ - - def __init__(self, tree: nx.DiGraph): - self.tree = tree - - def root(self) -> int: - tree = self.tree - root = [n for n in tree if tree.in_degree(n) == 0][0] - return root - - def leaves(self) -> List[int]: - tree = self.tree - leaves = [ - n - for n in tree - if tree.out_degree(n) == 0 and tree.in_degree(n) == 1 - ] - return leaves - - def internal_nodes(self) -> List[int]: - tree = self.tree - return [n for n in tree if n != self.root() and n not in self.leaves()] - - def nodes(self): - tree = self.tree - return list(tree.nodes()) - - def num_characters(self) -> int: - return len(self.tree.nodes[self.root()]["characters"]) - - def get_state(self, node: int) -> str: - tree = self.tree - return tree.nodes[node]["characters"] - - def set_state(self, node: int, state: str) -> None: - tree = self.tree - tree.nodes[node]["characters"] = state - - def set_states(self, node_state_list: List[Tuple[int, str]]) -> None: - for (node, state) in node_state_list: - self.set_state(node, state) - - def get_age(self, node: int) 
-> float: - tree = self.tree - return tree.nodes[node]["age"] - - def set_age(self, node: int, age: float) -> None: - tree = self.tree - tree.nodes[node]["age"] = age - - def edges(self) -> List[Tuple[int, int]]: - """List of (parent, child) tuples""" - tree = self.tree - return list(tree.edges) - - def get_edge_length(self, parent: int, child: int) -> float: - tree = self.tree - assert parent in tree - assert child in tree[parent] - return tree.edges[parent, child]["length"] - - def set_edge_length(self, parent: int, child: int, length: float) -> None: - tree = self.tree - assert parent in tree - assert child in tree[parent] - tree.edges[parent, child]["length"] = length - - def set_edge_lengths( - self, parent_child_and_length_list: List[Tuple[int, int, float]] - ) -> None: - for (parent, child, length) in parent_child_and_length_list: - self.set_edge_length(parent, child, length) - - def children(self, node: int) -> List[int]: - tree = self.tree - return list(tree.adj[node]) - - def to_newick_tree_format( - self, - print_node_names: bool = True, - print_internal_nodes: bool = False, - append_state_to_node_name: bool = False, - print_pct_of_mutated_characters_along_edge: bool = False, - add_N_to_node_id: bool = False, - fmt_branch_lengths: str = "%s", - ) -> str: - r""" - Converts tree into Newick tree format. - - Args: - print_internal_nodes: If True, prints the names of internal - nodes too. 
- print_pct_of_mutated_characters_along_edge: Self-explanatory - TODO - """ - leaves = self.leaves() - - def format_node(v: int): - node_id_prefix = "" if not add_N_to_node_id else "N" - node_id = "" if not print_node_names else str(v) - node_suffix = ( - "" - if not append_state_to_node_name - else "_" + str(self.get_state(v)) - ) - return node_id_prefix + node_id + node_suffix - - def subtree_newick_representation(v: int) -> str: - if len(self.children(v)) == 0: - return format_node(v) - subtrees_newick = [] - for child in self.children(v): - edge_length = self.get_edge_length(v, child) - if child in leaves: - subtree_newick = subtree_newick_representation(child) - else: - subtree_newick = ( - "(" + subtree_newick_representation(child) + ")" - ) - if print_internal_nodes: - subtree_newick += format_node(child) - # Add edge length - subtree_newick = ( - subtree_newick + ":" + (fmt_branch_lengths % edge_length) - ) - if print_pct_of_mutated_characters_along_edge: - # Also add number of mutations - number_of_unmutated_characters_in_parent = self.get_state( - v - ).count("0") - pct_of_mutated_characters_along_edge = ( - self.number_of_mutations_along_edge(v, child) - / (number_of_unmutated_characters_in_parent + 1e-100) - ) - subtree_newick = ( - subtree_newick + "[&&NHX:muts=" - f"{self._fmt(pct_of_mutated_characters_along_edge)}]" - ) - subtrees_newick.append(subtree_newick) - newick = ",".join(subtrees_newick) - return newick - - root = self.root() - res = "(" + subtree_newick_representation(root) + ")" - if print_internal_nodes: - res += format_node(root) - res += ");" - return res - - def _fmt(self, x: float): - return "%.2f" % x - - def reconstruct_ancestral_states(self): - r""" - Reconstructs ancestral states with maximum parsimony. 
- """ - root = self.root() - - def dfs(v: int) -> None: - children = self.children(v) - n_children = len(children) - if n_children == 0: - return - for child in children: - dfs(child) - children_states = [self.get_state(child) for child in children] - n_characters = len(children_states[0]) - state = "" - for character_id in range(n_characters): - states_for_this_character = set( - [ - children_states[i][character_id] - for i in range(n_children) - ] - ) - if len(states_for_this_character) == 1: - state += states_for_this_character.pop() - else: - state += "0" - self.set_state(v, state) - if v == root: - # Reset state to all zeros! - self.set_state(v, "0" * n_characters) - - dfs(root) - - def copy_branch_lengths(self, tree_other): - r""" - Copies the branch lengths of tree_other onto self - """ - assert self.nodes() == tree_other.nodes() - assert self.edges() == tree_other.edges() - - for node in self.nodes(): - new_age = tree_other.get_age(node) - self.set_age(node, age=new_age) - - for (parent, child) in self.edges(): - new_edge_length = tree_other.get_age(parent) - tree_other.get_age( - child - ) - self.set_edge_length(parent, child, length=new_edge_length) - - def print_edges(self): - for (parent, child) in self.edges(): - print( - f"{parent}[{self.get_state(parent)}] -> " - f"{child}[{self.get_state(child)}]: " - f"{self.get_edge_length(parent, child)}" - ) - - def num_cuts(self, v: int) -> int: - # TODO: Hardcoded '0'... - res = self.num_characters() - self.get_state(v).count("0") - return res - - def parent(self, v: int) -> int: - if v == self.root(): - raise ValueError("Asked for parent of root node!") - incident_edges_at_v = [edge for edge in self.edges() if edge[1] == v] - assert len(incident_edges_at_v) == 1 - return incident_edges_at_v[0][0] - - def set_edge_lengths_from_node_ages(self) -> None: - r""" - Sets the edge lengths to match the node ages. 
- """ - for (parent, child) in self.edges(): - self.set_edge_length( - parent, child, self.get_age(parent) - self.get_age(child) - ) - - def num_ancestors(self, node: int) -> int: - r""" - Number of ancestors of a node. Terribly inefficient implementation. - """ - res = 0 - root = self.root() - while node != root: - node = self.parent(node) - res += 1 - return res - - def number_of_mutations_along_edge(self, parent, child): - return self.get_state(parent).count("0") - self.get_state(child).count( - "0" - ) - - def number_of_nonmutations_along_edge(self, parent, child): - return self.get_state(child).count("0") - - def num_uncut(self, v): - return self.get_state(v).count("0") - - def depth(self) -> int: - r""" - Depth of the tree. - E.g. the tree 0 -> 1 has depth 1. - """ - - def dfs(v): - res = 0 - for child in self.children(v): - res = max(res, dfs(child) + 1) - return res - - res = dfs(self.root()) - return res - - def scale(self, factor: float): - r""" - The branch lengths of the tree are all scaled by this factor - """ - for node in self.nodes(): - self.set_age(node, factor * self.get_age(node)) - self.set_edge_lengths_from_node_ages() - - def __str__(self): - def node_str(p, v): - res = "" - if p is not None: - res += f"({self.number_of_mutations_along_edge(p, v)})" - res += self.get_state(v) - return res - - def dfs(p, v, depth) -> List[str]: - res = ["\t" * depth + node_str(p, v) + "\n"] - for c in self.children(v): - res += dfs(v, c, depth + 1) - return res - - return "".join(dfs(None, self.root(), 0)) diff --git a/test/tools_tests/tree_test.py b/test/tools_tests/tree_test.py deleted file mode 100644 index 29917716..00000000 --- a/test/tools_tests/tree_test.py +++ /dev/null @@ -1,188 +0,0 @@ -import networkx as nx - -from cassiopeia.tools.tree import Tree - - -def test_to_newick_tree_format(): - r""" - Example tree based off https://itol.embl.de/help.cgi#upload . 
- The most basic newick example should give: - (2:0.5,(4:0.3,5:0.4):0.2):0.1); - """ - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4, 5]) - tree.add_edges_from([(0, 1), (1, 2), (1, 3), (3, 4), (3, 5)]) - tree = Tree(tree) - tree.set_edge_lengths( - [(0, 1, 0.1), (1, 2, 0.5), (1, 3, 0.2), (3, 4, 0.3), (3, 5, 0.4)] - ) - tree.set_states( - [ - (0, "0000000000"), - (1, "1000000000"), - (2, "1111000000"), - (3, "1110000000"), - (4, "1110000111"), - (5, "1110111111"), - ] - ) - res = tree.to_newick_tree_format(print_internal_nodes=False) - assert res == "((2:0.5,(4:0.3,5:0.4):0.2):0.1));" - res = tree.to_newick_tree_format( - print_node_names=False, - print_internal_nodes=True, - append_state_to_node_name=True, - ) - assert ( - res == "((_1111000000:0.5,(_1110000111:0.3,_1110111111:0.4)" - "_1110000000:0.2)_1000000000:0.1)_0000000000);" - ) - res = tree.to_newick_tree_format(print_internal_nodes=True) - assert res == "((2:0.5,(4:0.3,5:0.4)3:0.2)1:0.1)0);" - res = tree.to_newick_tree_format(print_node_names=False) - assert res == "((:0.5,(:0.3,:0.4):0.2):0.1));" - res = tree.to_newick_tree_format( - print_internal_nodes=True, add_N_to_node_id=True - ) - assert res == "((N2:0.5,(N4:0.3,N5:0.4)N3:0.2)N1:0.1)N0);" - res = tree.to_newick_tree_format( - print_internal_nodes=True, - append_state_to_node_name=True, - add_N_to_node_id=True, - ) - assert ( - res == "((N2_1111000000:0.5,(N4_1110000111:0.3,N5_1110111111:0.4)" - "N3_1110000000:0.2)N1_1000000000:0.1)N0_0000000000);" - ) - res = tree.to_newick_tree_format( - print_internal_nodes=True, - print_pct_of_mutated_characters_along_edge=True, - add_N_to_node_id=True, - ) - assert ( - res == "((N2:0.5[&&NHX:muts=0.33],(N4:0.3[&&NHX:muts=0.43]," - "N5:0.4[&&NHX:muts=0.86])N3:0.2[&&NHX:muts=0.22])" - "N1:0.1[&&NHX:muts=0.10])N0);" - ) - - -def test_reconstruct_ancestral_states(): - tree = nx.DiGraph() - tree.add_nodes_from(list(range(17))) - tree.add_edges_from( - [ - (10, 11), - (11, 13), - (13, 0), - (13, 1), - (11, 
14), - (14, 2), - (14, 3), - (10, 12), - (12, 15), - (15, 4), - (15, 5), - (12, 16), - (16, 6), - (16, 7), - (16, 8), - (16, 9), - ] - ) - tree = Tree(tree) - tree.set_states( - [ - (0, "01101110100"), - (1, "01211111111"), - (2, "01322121111"), - (3, "01432122111"), - (4, "01541232111"), - (5, "01651233111"), - (6, "01763243111"), - (7, "01873240111"), - (8, "01983240111"), - (9, "01093240010"), - ] - ) - tree.reconstruct_ancestral_states() - assert tree.get_state(10) == "00000000000" - assert tree.get_state(11) == "01000100100" - assert tree.get_state(13) == "01001110100" - assert tree.get_state(14) == "01002120111" - assert tree.get_state(12) == "01000200010" - assert tree.get_state(15) == "01001230111" - assert tree.get_state(16) == "01003240010" - - -def test_reconstruct_ancestral_states_DREAM_challenge_tree_25(): - tree = nx.DiGraph() - tree.add_nodes_from(list(range(21))) - tree.add_edges_from( - [ - (9, 8), - (8, 10), - (8, 7), - (7, 11), - (7, 12), - (9, 6), - (6, 2), - (2, 0), - (0, 13), - (0, 14), - (2, 1), - (1, 15), - (1, 16), - (6, 5), - (5, 3), - (3, 17), - (3, 18), - (5, 4), - (4, 19), - (4, 20), - ] - ) - tree = Tree(tree) - tree.set_states( - [ - (10, "0022100000"), - (11, "0022100000"), - (12, "0022100000"), - (13, "2012000220"), - (14, "2012000200"), - (15, "2012000100"), - (16, "2012000100"), - (17, "0001110220"), - (18, "0001110220"), - (19, "0000210220"), - (20, "0000210220"), - ] - ) - tree.reconstruct_ancestral_states() - assert tree.get_state(7) == "0022100000" - assert tree.get_state(8) == "0022100000" - assert tree.get_state(0) == "2012000200" - assert tree.get_state(1) == "2012000100" - assert tree.get_state(2) == "2012000000" - assert tree.get_state(3) == "0001110220" - assert tree.get_state(4) == "0000210220" - assert tree.get_state(5) == "0000010220" - assert tree.get_state(6) == "0000000000" - assert tree.get_state(9) == "0000000000" - - -def test_depth(): - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4]) - 
tree.add_edges_from([(0, 1), (0, 2), (0, 3), (2, 4)]) - tree = Tree(tree) - assert tree.depth() == 2 - - -def test_str(): - tree = nx.DiGraph() - tree.add_nodes_from([0, 1, 2, 3, 4]) - tree.add_edges_from([(0, 1), (0, 2), (0, 3), (2, 4)]) - tree = Tree(tree) - tree.set_states([(0, "00"), (1, "01"), (2, "00"), (3, "10"), (4, "11")]) - res = str(tree) - assert res == "00\n\t(1)01\n\t(0)00\n\t\t(2)11\n\t(1)10\n" From ee797b58ede62fdc7077a784dcfeef00122e318c Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sat, 6 Feb 2021 22:42:58 -0800 Subject: [PATCH 43/61] Easy optimization of my bayesian estimator DP: only visit states with cuts_p <= x <= cuts_v --- .../IIDExponentialPosteriorMeanBLE.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index 8f478116..5ced3e67 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -142,12 +142,31 @@ def compatible_with_observed_data(self, x, observed_cuts) -> bool: else: return x <= observed_cuts + def state_is_valid(self, v, t, x) -> bool: + r""" + Used to optimize the DP by avoiding states with 0 probability. + The number of mutations should be between those of v and its parent. + """ + tree = self.tree + if v == tree.root: + return x == 0 + p = tree.parent(v) + cuts_v = tree.get_number_of_mutated_characters_in_node(v) + cuts_p = tree.get_number_of_mutated_characters_in_node(p) + if self.enforce_parsimony: + return cuts_p <= x <= cuts_v + else: + return x <= cuts_v + def up(self, v, t, x) -> float: r""" TODO: Rename this _up? log P(X_up(b(v)), T_up(b(v)), t \in t_b(v), X_b(v)(t) = x) TODO: Update to match my technical write-up. """ + # Avoid doing anything at all for invalid states. 
+ if not self.state_is_valid(v, t, x): + return -np.inf if (v, t, x) in self.up_cache: # TODO: Use arrays # TODO: Use a decorator instead of a hand-made cache return self.up_cache[(v, t, x)] @@ -208,6 +227,9 @@ def down(self, v, t, x) -> float: log P(X_down(v), T_down(v) | t_v = t, X_v = x) TODO: Update to match my technical write-up. """ + # Avoid doing anything at all for invalid states. + if not self.state_is_valid(v, t, x): + return -np.inf if (v, t, x) in self.down_cache: # TODO: Use a decorator instead of a hand-made cache return self.down_cache[(v, t, x)] From d36e5543f1c682fd09240db1e9d231b19b4acd7c Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sun, 7 Feb 2021 10:16:37 -0800 Subject: [PATCH 44/61] Make some methods private --- .../IIDExponentialPosteriorMeanBLE.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index 5ced3e67..a9821d0b 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -135,14 +135,14 @@ def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: self._compute_posteriors() self._populate_branch_lengths() - def compatible_with_observed_data(self, x, observed_cuts) -> bool: + def _compatible_with_observed_data(self, x, observed_cuts) -> bool: # TODO: Make method private if self.enforce_parsimony: return x == observed_cuts else: return x <= observed_cuts - def state_is_valid(self, v, t, x) -> bool: + def _state_is_valid(self, v, t, x) -> bool: r""" Used to optimize the DP by avoiding states with 0 probability. The number of mutations should be between those of v and its parent. @@ -165,7 +165,7 @@ def up(self, v, t, x) -> float: TODO: Update to match my technical write-up. 
""" # Avoid doing anything at all for invalid states. - if not self.state_is_valid(v, t, x): + if not self._state_is_valid(v, t, x): return -np.inf if (v, t, x) in self.up_cache: # TODO: Use arrays # TODO: Use a decorator instead of a hand-made cache @@ -207,7 +207,7 @@ def up(self, v, t, x) -> float: if v != tree.root: # TODO: 'tree.root()' is O(n). We should have O(1) method. p = tree.parent(v) - if self.compatible_with_observed_data( + if self._compatible_with_observed_data( x, tree.get_number_of_mutated_characters_in_node(p) ): siblings = [u for u in tree.children(p) if u != v] @@ -228,7 +228,7 @@ def down(self, v, t, x) -> float: TODO: Update to match my technical write-up. """ # Avoid doing anything at all for invalid states. - if not self.state_is_valid(v, t, x): + if not self._state_is_valid(v, t, x): return -np.inf if (v, t, x) in self.down_cache: # TODO: Use a decorator instead of a hand-made cache @@ -273,7 +273,7 @@ def down(self, v, t, x) -> float: # TODO: Allow for weak match at internal nodes and exact match at # leaves. if ( - self.compatible_with_observed_data( + self._compatible_with_observed_data( x, tree.get_number_of_mutated_characters_in_node(v) ) and v not in tree.leaves From 86c0ad024187a721c84c5f4d6a6fa83e716492c1 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sun, 7 Feb 2021 12:31:12 -0800 Subject: [PATCH 45/61] Address some TODOs --- cassiopeia/data/CassiopeiaTree.py | 5 + .../IIDExponentialPosteriorMeanBLE.py | 94 +++++++++++-------- 2 files changed, 58 insertions(+), 41 deletions(-) diff --git a/cassiopeia/data/CassiopeiaTree.py b/cassiopeia/data/CassiopeiaTree.py index 1c1ce671..dbf8a19b 100644 --- a/cassiopeia/data/CassiopeiaTree.py +++ b/cassiopeia/data/CassiopeiaTree.py @@ -794,6 +794,8 @@ def get_mutations_along_edge( Returns a list of tuples (character, state) of mutations that occur along an edge. Characters are 0-indexed. 
+ WARNING: A character dropout event will also be considered a mutation! + Args: parent: parent in tree child: child in tree @@ -835,6 +837,9 @@ def get_number_of_unmutated_characters_in_node( def get_number_of_mutated_characters_in_node( self, node: str ) -> int: + r""" + WARNING: dropped out characters will be considered as mutated too! + """ return self.n_character -\ self.get_number_of_unmutated_characters_in_node(node) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index a9821d0b..80cffd55 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -1,5 +1,6 @@ import multiprocessing from copy import deepcopy +import time from typing import Optional, Tuple import numpy as np @@ -19,22 +20,13 @@ class IIDExponentialPosteriorMeanBLE(BranchLengthEstimator): of the MLE. The phylogeny model is chosen to be a birth process. This estimator requires that the ancestral states are provided. - TODO: Allow for two versions: one where the number of mutations of each - node must match exactly, and one where it must be upper bounded by the - number of mutations seen. (I believe the latter should ameliorate - subtree collapse further.) - TODO: Use numpy autograd to do optimize the hyperparams? (Empirical Bayes) + TODO: Use numpy autograd to optimize the hyperparams? (Empirical Bayes) We compute the posterior means using a forward-backward-style algorithm (DP on a tree). - Args: - mutation_rate: TODO - birth_rate: TODO - discretization_level: TODO - enforce_parsimony: TODO - verbose: Verbosity level. 
TODO + Args: TODO Attributes: TODO @@ -46,6 +38,7 @@ def __init__( birth_rate: float, discretization_level: int, enforce_parsimony: bool = True, + use_cpp_implementation: bool = False ) -> None: # TODO: If we use autograd, we can tune the hyperparams with gradient # descent? @@ -55,10 +48,10 @@ def __init__( self.birth_rate = birth_rate self.discretization_level = discretization_level self.enforce_parsimony = enforce_parsimony + self.use_cpp_implementation = use_cpp_implementation def _compute_log_likelihood(self): tree = self.tree - discretization_level = self.discretization_level log_likelihood = 0 # TODO: Should I also add a division event when the root has multiple # children? (If not, the joint we are computing won't integrate to 1; @@ -128,15 +121,39 @@ def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: r""" See base class. """ - self.down_cache = {} # TODO: Rename to _down_cache - self.up_cache = {} # TODO: Rename to _up_cache + self._down_cache = {} + self._up_cache = {} self.tree = tree + if self.use_cpp_implementation: + time_cpp_start = time.time() + self._populate_cache_with_cpp_implementation() + time_cpp_end = time.time() + print(f"time_cpp = {time_cpp_end - time_cpp_start}") + time_compute_log_likelihood_start = time.time() self._compute_log_likelihood() + time_compute_log_likelihood_end = time.time() + print(f"time_compute_log_likelihood (dp_down) = {time_compute_log_likelihood_end - time_compute_log_likelihood_start}") + time_compute_posteriors_start = time.time() self._compute_posteriors() + time_compute_posteriors_end = time.time() + print(f"time_compute_posteriors (dp_up) = {time_compute_posteriors_end - time_compute_posteriors_start}") + time_populate_branch_lengths_start = time.time() self._populate_branch_lengths() + time_populate_branch_lengths_end = time.time() + print(f"time_populate_branch_lengths = {time_populate_branch_lengths_end - time_populate_branch_lengths_start}") + + def 
_populate_cache_with_cpp_implementation(self): + r""" + A cpp implementation is run to compute up and down caches, which is + the computational bottleneck. + """ + # First extract the relevant information from the tree. + # Serialize the tree information + # Run the c++ implementation + # Read the cache values + pass def _compatible_with_observed_data(self, x, observed_cuts) -> bool: - # TODO: Make method private if self.enforce_parsimony: return x == observed_cuts else: @@ -167,9 +184,12 @@ def up(self, v, t, x) -> float: # Avoid doing anything at all for invalid states. if not self._state_is_valid(v, t, x): return -np.inf - if (v, t, x) in self.up_cache: # TODO: Use arrays - # TODO: Use a decorator instead of a hand-made cache - return self.up_cache[(v, t, x)] + if (v, t, x) in self._up_cache: # TODO: Use arrays? + # TODO: Use a decorator instead of a hand-made cache? + return self._up_cache[(v, t, x)] + if self.use_cpp_implementation: + raise ValueError(f"Bug in cpp implementation: State up({(v, t, x)})" + f" was not populated.") # Pull out params r = self.mutation_rate lam = self.birth_rate @@ -205,20 +225,19 @@ def up(self, v, t, x) -> float: ) # Case 3: A cell division happened if v != tree.root: - # TODO: 'tree.root()' is O(n). We should have O(1) method. p = tree.parent(v) if self._compatible_with_observed_data( - x, tree.get_number_of_mutated_characters_in_node(p) + x, tree.get_number_of_mutated_characters_in_node(p) # If we want to ignore missing data, we just have to replace x by x-gone_missing(p->v). I.e. dropped out characters become free mutations. ): siblings = [u for u in tree.children(p) if u != v] ll = ( np.log(lam * dt) - + self.up(p, t - 1, x) - + sum([self.down(u, t, x) for u in siblings]) + + self.up(p, t - 1, x) # If we want to ignore missing data, we just have to replace x by x-gone_missing(p->v). I.e. dropped out characters become free mutations. 
+ + sum([self.down(u, t, x) for u in siblings]) # If we want to ignore missing data, we just have to replace x by cuts(p)+gone_missing(p->u). I.e. dropped out characters become free mutations. ) log_likelihoods_cases.append(ll) log_likelihood = logsumexp(log_likelihoods_cases) - self.up_cache[(v, t, x)] = log_likelihood + self._up_cache[(v, t, x)] = log_likelihood return log_likelihood def down(self, v, t, x) -> float: @@ -230,9 +249,12 @@ def down(self, v, t, x) -> float: # Avoid doing anything at all for invalid states. if not self._state_is_valid(v, t, x): return -np.inf - if (v, t, x) in self.down_cache: - # TODO: Use a decorator instead of a hand-made cache - return self.down_cache[(v, t, x)] + if (v, t, x) in self._down_cache: + # TODO: Use a decorator instead of a hand-made cache? + return self._down_cache[(v, t, x)] + if self.use_cpp_implementation: + raise ValueError(f"Bug in cpp implementation: State " + f"down({(v, t, x)}) was not populated.") # Pull out params discretization_level = self.discretization_level r = self.mutation_rate @@ -248,11 +270,9 @@ def down(self, v, t, x) -> float: log_likelihood = 0.0 if t == discretization_level: # Base case if ( - v in tree.leaves + tree.is_leaf(v) and x == tree.get_number_of_mutated_characters_in_node(v) ): - # TODO: 'v not in tree.leaves()' is O(n). We should have O(1) - # check. log_likelihood = 0.0 else: log_likelihood = -np.inf @@ -270,22 +290,18 @@ def down(self, v, t, x) -> float: ) # Case 3: Cell divides # The number of cuts at this state must match the ground truth. - # TODO: Allow for weak match at internal nodes and exact match at - # leaves. if ( self._compatible_with_observed_data( x, tree.get_number_of_mutated_characters_in_node(v) ) - and v not in tree.leaves + and not tree.is_leaf(v) ): - # TODO: 'v not in tree.leaves()' is O(n). We should have O(1) - # check. 
ll = sum( - [self.down(child, t + 1, x) for child in tree.children(v)] + [self.down(child, t + 1, x) for child in tree.children(v)] # If we want to ignore missing data, we just have to replace x by x+gone_missing(p->v). I.e. dropped out characters become free mutations. ) + np.log(lam * dt) log_likelihoods_cases.append(ll) log_likelihood = logsumexp(log_likelihoods_cases) - self.down_cache[(v, t, x)] = log_likelihood + self._down_cache[(v, t, x)] = log_likelihood return log_likelihood @classmethod @@ -485,11 +501,7 @@ class IIDExponentialPosteriorMeanBLEGridSearchCV(BranchLengthEstimator): This class fits the hyperparameters of IIDExponentialPosteriorMeanBLE based on data log-likelihood. I.e. is performs empirical Bayes. - Args: - mutation_rates: TODO - birth_rate: TODO - discretization_level: TODO - verbose: Verbosity level. TODO + Args: TODO """ def __init__( From dd4f0bf72e0caf07bda4d79c0c4d01621fab34e2 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sun, 7 Feb 2021 19:24:15 -0800 Subject: [PATCH 46/61] Add c++ implementation of Bayesian estimator --- .../BranchLengthEstimator.py | 6 + .../IIDExponentialPosteriorMeanBLE.cpp | 438 ++++++++++++++++++ .../IIDExponentialPosteriorMeanBLE.py | 286 ++++++++++-- docs/requirements.txt | 2 +- .../branch_length_estimator_test.py | 43 +- 5 files changed, 724 insertions(+), 51 deletions(-) create mode 100644 cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp diff --git a/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py b/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py index 33f8bddd..cd94456e 100644 --- a/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py +++ b/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py @@ -3,6 +3,12 @@ from cassiopeia.data import CassiopeiaTree +class BranchLengthEstimatorError(Exception): + """An Exception class for the CassiopeiaTree class.""" + + pass + + 
class BranchLengthEstimator(abc.ABC): r""" Abstract base class for all branch length estimators. diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp new file mode 100644 index 00000000..c2eaaf3d --- /dev/null +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp @@ -0,0 +1,438 @@ +#include +#include +#include +#include +#include +#include +#include + +#define forn(i, n) for(int i = 0; i < int(n); i++) +#define forall(i,c) for(typeof((c).begin()) i = (c).begin();i != (c).end();i++) + +using namespace std; +const int maxN = 1024; +const int maxK = 128; +const int maxT = 512; +const float INF = 1e16; +float _down_cache[maxN][maxT][maxK]; +float _up_cache[maxN][maxT][maxK]; + +string input_dir = ""; +string output_dir = ""; +int N = -1; +vector children[maxN]; +int root = -1; +int is_internal_node[maxN]; +int get_number_of_mutated_characters_in_node[maxN]; +vector non_root_internal_nodes; +vector leaves; +int parent[maxN]; +int is_leaf[maxN]; +int K = -1; +int T = -1; +int enforce_parsimony = -1; +float r; +float lam; + +float logsumexp(const vector & lls){ + float mx = -INF; + for(auto ll: lls){ + mx = max(mx, ll); + } + float res = 0.0; + for(auto ll: lls){ + res += exp(ll - mx); + } + res = log(res) + mx; + return res; +} + +int _compatible_with_observed_data(int x, int observed_cuts){ + if(enforce_parsimony) + return x == observed_cuts; + else + return x <= observed_cuts; +} + +bool _state_is_valid(int v, int t, int x){ + if(v == root) + return x == 0; + int p = parent[v]; + int cuts_v = get_number_of_mutated_characters_in_node[v]; + int cuts_p = get_number_of_mutated_characters_in_node[p]; + if(enforce_parsimony){ + return cuts_p <= x && x <= cuts_v; + } else { + return x <= cuts_v; + } +} + + +void read_N(){ + ifstream fin(input_dir + "/N.txt"); + if(!fin.good()){ + cerr << "N input file not found" << endl; + exit(1); + } 
+ fin >> N; + if(N == -1){ + cerr << "N input corrupted" << endl; + exit(1); + } + if(N >= maxN - 10){ + cerr << "N larger than maxN" << endl; + exit(1); + } +} + +void read_children(){ + ifstream fin(input_dir + "/children.txt"); + if(!fin.good()){ + cerr << "children input file not found" << endl; + exit(1); + } + int lines_read = 0; + int v; + while(fin >> v){ + lines_read++; + int n_children; + fin >> n_children; + for(int i = 0; i < n_children; i++){ + int c; + fin >> c; + children[v].push_back(c); + } + } + if(lines_read != N){ + cerr << "children input corrupted" << endl; + exit(1); + } +} + +void read_root(){ + ifstream fin(input_dir + "/root.txt"); + fin >> root; + if(root == -1){ + cerr << "N input corrupted" << endl; + exit(1); + } +} + +void read_is_internal_node(){ + ifstream fin(input_dir + "/is_internal_node.txt"); + int lines_read = 0; + int v; + while(fin >> v){ + lines_read++; + fin >> is_internal_node[v]; + } + if(lines_read != N){ + cerr << "is_internal_node input corrupted" << endl; + exit(1); + } +} + +void read_get_number_of_mutated_characters_in_node(){ + ifstream fin(input_dir + "/get_number_of_mutated_characters_in_node.txt"); + int lines_read = 0; + int v; + while(fin >> v){ + lines_read++; + fin >> get_number_of_mutated_characters_in_node[v]; + } + if(lines_read != N){ + cerr << "get_number_of_mutated_characters_in_node input corrupted" << endl; + exit(1); + } +} + +void read_non_root_internal_nodes(){ + ifstream fin(input_dir + "/non_root_internal_nodes.txt"); + int v; + while(fin >> v){ + non_root_internal_nodes.push_back(v); + } +} + +void read_leaves(){ + ifstream fin(input_dir + "/leaves.txt"); + int v; + while(fin >> v){ + leaves.push_back(v); + } + if(leaves.size() == 0){ + cerr << "leaves input corrupted" << endl; + exit(1); + } +} + +void read_parent(){ + ifstream fin(input_dir + "/parent.txt"); + int lines_read = 0; + int v; + while(fin >> v){ + lines_read++; + fin >> parent[v]; + } + if(lines_read != N - 1){ + cerr << "parent 
input corrupted" << endl; + exit(1); + } +} + +void read_is_leaf(){ + ifstream fin(input_dir + "/is_leaf.txt"); + int lines_read = 0; + int v; + while(fin >> v){ + lines_read++; + fin >> is_leaf[v]; + } + if(lines_read != N){ + cerr << "is_leaf input corrupted" << endl; + exit(1); + } +} + +void read_K(){ + ifstream fin(input_dir + "/K.txt"); + fin >> K; + if(K == -1){ + cerr << "K input corrupted" << endl; + exit(1); + } + if(K >= maxK - 10){ + cerr << "K larger than maxK" << endl; + exit(1); + } +} + +void read_T(){ + // T is the discretization level. + ifstream fin(input_dir + "/T.txt"); + fin >> T; + if(T == -1){ + cerr << "T input corrupted" << endl; + exit(1); + } + if(T >= maxT - 10){ + cerr << "T larger than maxT" << endl; + exit(1); + } +} + +void read_enforce_parsimony(){ + ifstream fin(input_dir + "/enforce_parsimony.txt"); + fin >> enforce_parsimony; + if(enforce_parsimony == -1){ + cerr << "enforce_parsimony input corrupted" << endl; + exit(1); + } +} + +void read_r(){ + ifstream fin(input_dir + "/r.txt"); + fin >> r; + if(r == -1){ + cerr << "r input corrupted" << endl; + exit(1); + } +} + +void read_lam(){ + ifstream fin(input_dir + "/lam.txt"); + fin >> lam; + if(lam == -1){ + cerr << "lam input corrupted" << endl; + exit(1); + } +} + +float down(int v, int t, int x){ + // Avoid doing anything at all for invalid states. + if(!_state_is_valid(v, t, x)){ + return -INF; + } + if(_down_cache[v][t][x] < 1.0){ + return _down_cache[v][t][x]; + } + // Pull out params + float dt = 1.0 / T; + assert(v != root); + assert(0 <= t && t <= T); + assert(0 <= x && x <= K); + if(!(1.0 - lam * dt - K * r * dt > 0)){ + cerr << "Please choose a bigger discretization_level." << endl; + exit(1); + } + float log_likelihood = 0.0; + if(t == T){ + // Base case + if ( + is_leaf[v] + && (x == get_number_of_mutated_characters_in_node[v]) + ){ + log_likelihood = 0.0; + } else { + log_likelihood = -INF; + } + } + else{ + // Recursion. 
+ vector log_likelihoods_cases; + // Case 1: Nothing happens + log_likelihoods_cases.push_back( + log(1.0 - lam * dt - (K - x) * r * dt) + + down(v, t + 1, x) + ); + // Case 2: One character mutates. + if(x + 1 <= K){ + log_likelihoods_cases.push_back( + log((K - x) * r * dt) + down(v, t + 1, x + 1) + ); + } + // Case 3: Cell divides + // The number of cuts at this state must match the ground truth. + if ( + _compatible_with_observed_data( + x, get_number_of_mutated_characters_in_node[v] + ) + && (!is_leaf[v]) + ){ + float ll = 0.0; + forn(i, children[v].size()){ + int child = children[v][i]; + ll += down(child, t + 1, x);// If we want to ignore missing data, we just have to replace x by x+gone_missing(p->v). I.e. dropped out characters become free mutations. + } + ll += log(lam * dt); + log_likelihoods_cases.push_back(ll); + } + log_likelihood = logsumexp(log_likelihoods_cases); + } + _down_cache[v][t][x] = log_likelihood; + return log_likelihood; +} + +float up(int v, int t, int x){ + // Avoid doing anything at all for invalid states. + if(!_state_is_valid(v, t, x)) + return -INF; + if(_up_cache[v][t][x] < 1.0){ + return _up_cache[v][t][x]; + } + // Pull out params + float dt = 1.0 / T; + assert(0 <= v && v < N); + assert(0 <= t && t <= T); + assert(0 <= x && x <= K); + if(!(1.0 - lam * dt - K * r * dt > 0)){ + cerr << "Please choose a bigger discretization_level." << endl; + exit(1); + } + float log_likelihood = 0.0; + if(v == root){ + // Base case: we reached the root of the tree. + if((t == 0) && (x == get_number_of_mutated_characters_in_node[v])) + log_likelihood = 0.0; + else + log_likelihood = -INF; + } else if(t == 0){ + // Base case: we reached the start of the process, but we're not yet + // at the root. + assert(v != root); + log_likelihood = -INF; + } else { + // Recursion. 
+ vector log_likelihoods_cases; + // Case 1: Nothing happened + log_likelihoods_cases.push_back( + log(1.0 - lam * dt - (K - x) * r * dt) + up(v, t - 1, x) + ); + // Case 2: Mutation happened + if(x - 1 >= 0){ + log_likelihoods_cases.push_back( + log((K - (x - 1)) * r * dt) + up(v, t - 1, x - 1) + ); + } + // Case 3: A cell division happened + if(v != root){ + int p = parent[v]; + if(_compatible_with_observed_data( + x, get_number_of_mutated_characters_in_node[p] // If we want to ignore missing data, we just have to replace x by x-gone_missing(p->v). I.e. dropped out characters become free mutations. + )){ + vector siblings; + for(auto u: children[p]) + if(u != v) + siblings.push_back(u); + float ll = log(lam * dt) + up(p, t - 1, x); // If we want to ignore missing data, we just have to replace x by x-gone_missing(p->v). I.e. dropped out characters become free mutations. + for(auto u: siblings){ + ll += down(u, t, x); // If we want to ignore missing data, we just have to replace x by cuts(p)+gone_missing(p->u). I.e. dropped out characters become free mutations. 
+ } + log_likelihoods_cases.push_back(ll); + } + } + log_likelihood = logsumexp(log_likelihoods_cases); + } + _up_cache[v][t][x] = log_likelihood; + return log_likelihood; +} + +void write_down(){ + ofstream fout(output_dir + "/down.txt"); + string res = ""; + forn(v, N){ + if(v == root) continue; + forn(t, T + 1){ + forn(x, K + 1){ + if(_state_is_valid(v, t, x)){ + res += to_string(v) + " " + to_string(t) + " " + to_string(x) + " " + to_string(down(v, t, x)) + "\n"; + } + } + } + } + fout << res; +} + +void write_up(){ + ofstream fout(output_dir + "/up.txt"); + string res = ""; + forn(v, N){ + forn(t, T + 1){ + forn(x, K + 1){ + if(_state_is_valid(v, t, x)){ + res += to_string(v) + " " + to_string(t) + " " + to_string(x) + " " + to_string(up(v, t, x)) + "\n"; + } + } + } + } + fout << res; +} + + +int main(int argc, char *argv[]){ + if(argc != 3){ + cerr << "Need intput_dir and output_dir arguments, " << argc - 1 << " provided." << endl; + exit(1); + } + input_dir = string(argv[1]); + output_dir = string(argv[2]); + read_N(); + read_children(); + read_root(); + read_is_internal_node(); + read_get_number_of_mutated_characters_in_node(); + read_non_root_internal_nodes(); + read_leaves(); + read_parent(); + read_is_leaf(); + read_K(); + read_T(); + read_enforce_parsimony(); + read_r(); + read_lam(); + + forn(v, N) forn(t, T + 1) forn(k, K + 1) _down_cache[v][t][k] = _up_cache[v][t][k] = 1.0; + write_down(); + write_up(); + return 0; +} diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index 80cffd55..793e76a1 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -1,15 +1,22 @@ import multiprocessing -from copy import deepcopy +import os +import subprocess +import tempfile import time -from typing import Optional, Tuple +from copy import 
deepcopy +from typing import List, Optional, Tuple import numpy as np from scipy import integrate from scipy.special import binom, logsumexp from cassiopeia.data import CassiopeiaTree -from .BranchLengthEstimator import BranchLengthEstimator + from . import utils +from .BranchLengthEstimator import ( + BranchLengthEstimator, + BranchLengthEstimatorError, +) class IIDExponentialPosteriorMeanBLE(BranchLengthEstimator): @@ -38,7 +45,8 @@ def __init__( birth_rate: float, discretization_level: int, enforce_parsimony: bool = True, - use_cpp_implementation: bool = False + use_cpp_implementation: bool = False, + debug_cpp_implementation: bool = False, ) -> None: # TODO: If we use autograd, we can tune the hyperparams with gradient # descent? @@ -49,6 +57,7 @@ def __init__( self.discretization_level = discretization_level self.enforce_parsimony = enforce_parsimony self.use_cpp_implementation = use_cpp_implementation + self.debug_cpp_implementation = debug_cpp_implementation def _compute_log_likelihood(self): tree = self.tree @@ -124,34 +133,230 @@ def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: self._down_cache = {} self._up_cache = {} self.tree = tree + if self.debug_cpp_implementation: + # Write out true dp values to check by eye against c++ + # implementation values. + self._write_out_dps() if self.use_cpp_implementation: time_cpp_start = time.time() - self._populate_cache_with_cpp_implementation() + if self.debug_cpp_implementation: + # Use a directory that won't go away. + self._populate_cache_with_cpp_implementation( + tmp_dir=os.getcwd() + "/tmp" + ) + else: + # Use a temporary directory. 
+ with tempfile.TemporaryDirectory() as tmp_dir: + self._populate_cache_with_cpp_implementation(tmp_dir) time_cpp_end = time.time() print(f"time_cpp = {time_cpp_end - time_cpp_start}") time_compute_log_likelihood_start = time.time() self._compute_log_likelihood() time_compute_log_likelihood_end = time.time() - print(f"time_compute_log_likelihood (dp_down) = {time_compute_log_likelihood_end - time_compute_log_likelihood_start}") + print( + f"time_compute_log_likelihood (dp_down) = {time_compute_log_likelihood_end - time_compute_log_likelihood_start}" + ) time_compute_posteriors_start = time.time() self._compute_posteriors() time_compute_posteriors_end = time.time() - print(f"time_compute_posteriors (dp_up) = {time_compute_posteriors_end - time_compute_posteriors_start}") + print( + f"time_compute_posteriors (dp_up) = {time_compute_posteriors_end - time_compute_posteriors_start}" + ) time_populate_branch_lengths_start = time.time() self._populate_branch_lengths() time_populate_branch_lengths_end = time.time() - print(f"time_populate_branch_lengths = {time_populate_branch_lengths_end - time_populate_branch_lengths_start}") + print( + f"time_populate_branch_lengths = {time_populate_branch_lengths_end - time_populate_branch_lengths_start}" + ) - def _populate_cache_with_cpp_implementation(self): + def _write_out_dps(self): + r""" + For debugging the c++ implementation: + This writes out the down and up values of the correct python implementation to the files + tmp/down_true.txt + and + tmp/up_true.txt + respectively. + Compare these against tmp/down.txt and tmp/up.txt, which are the values computed by the c++ implementation. 
+ """ + tree = self.tree + N = len(tree.nodes) + T = self.discretization_level + K = tree.n_character + id_to_node = dict(zip(range(len(tree.nodes)), tree.nodes)) + + if not os.path.exists("tmp"): + os.mkdir("tmp") + + res = "" + for v_id in range(N): + v = id_to_node[v_id] + if v == tree.root: + continue + for t in range(T + 1): + for x in range(K + 1): + if self._state_is_valid(v, t, x): + res += ( + str(v_id) + + " " + + str(t) + + " " + + str(x) + + " " + + str(self.down(v, t, x)) + + "\n" + ) + with open("tmp/down_true.txt", "w") as fout: + fout.write(res) + + res = "" + for v_id in range(N): + v = id_to_node[v_id] + for t in range(T + 1): + for x in range(K + 1): + if self._state_is_valid(v, t, x): + res += ( + str(v_id) + + " " + + str(t) + + " " + + str(x) + + " " + + str(self.up(v, t, x)) + + "\n" + ) + with open("tmp/up_true.txt", "w") as fout: + fout.write(res) + + def _write_out_list_of_lists(self, lls: List[List[int]], filename: str): + res = "" + for l in lls: + for i, x in enumerate(l): + if i: + res += " " + res += str(x) + res += "\n" + with open(filename, "w") as file: + file.write(res) + + def _populate_cache_with_cpp_implementation(self, tmp_dir): r""" A cpp implementation is run to compute up and down caches, which is the computational bottleneck. + To remove anything that has to do with the cpp implementation, you just + have to remove this function (and the gates around it). + I.e., this python implementation is loosely coupled to the cpp call: we + just have to remove the call to this method to turn it off, and all + other code will work just fine. This is because all that this method + does is *warm up the cache* with values computed from the cpp + subprocess, and the caching process is totally transparent to the + other methods of the class. """ - # First extract the relevant information from the tree. - # Serialize the tree information + # First extract the relevant information from the tree and serialize it. 
+ tree = self.tree + node_to_id = dict(zip(tree.nodes, range(len(tree.nodes)))) + id_to_node = dict(zip(range(len(tree.nodes)), tree.nodes)) + + N = [[len(tree.nodes)]] + if not os.path.exists(tmp_dir): + os.mkdir(tmp_dir) + self._write_out_list_of_lists(N, f"{tmp_dir}/N.txt") + + children = [ + [node_to_id[v]] + + [len(tree.children(v))] + + [node_to_id[c] for c in tree.children(v)] + for v in tree.nodes + ] + self._write_out_list_of_lists(children, f"{tmp_dir}/children.txt") + + root = [[node_to_id[tree.root]]] + self._write_out_list_of_lists(root, f"{tmp_dir}/root.txt") + + is_internal_node = [ + [node_to_id[v], 1 * tree.is_internal_node(v)] for v in tree.nodes + ] + self._write_out_list_of_lists( + is_internal_node, f"{tmp_dir}/is_internal_node.txt" + ) + + get_number_of_mutated_characters_in_node = [ + [node_to_id[v], tree.get_number_of_mutated_characters_in_node(v)] + for v in tree.nodes + ] + self._write_out_list_of_lists( + get_number_of_mutated_characters_in_node, + f"{tmp_dir}/get_number_of_mutated_characters_in_node.txt", + ) + + non_root_internal_nodes = [ + [node_to_id[v]] for v in tree.non_root_internal_nodes + ] + self._write_out_list_of_lists( + non_root_internal_nodes, f"{tmp_dir}/non_root_internal_nodes.txt" + ) + + leaves = [[node_to_id[v]] for v in tree.leaves] + self._write_out_list_of_lists(leaves, f"{tmp_dir}/leaves.txt") + + parent = [ + [node_to_id[v], node_to_id[tree.parent(v)]] + for v in tree.nodes + if v != tree.root + ] + self._write_out_list_of_lists(parent, f"{tmp_dir}/parent.txt") + + K = [[tree.n_character]] + self._write_out_list_of_lists(K, f"{tmp_dir}/K.txt") + + T = [[self.discretization_level]] + self._write_out_list_of_lists(T, f"{tmp_dir}/T.txt") + + enforce_parsimony = [[1 * self.enforce_parsimony]] + self._write_out_list_of_lists( + enforce_parsimony, f"{tmp_dir}/enforce_parsimony.txt" + ) + + r = [[self.mutation_rate]] + self._write_out_list_of_lists(r, f"{tmp_dir}/r.txt") + + lam = [[self.birth_rate]] + 
self._write_out_list_of_lists(lam, f"{tmp_dir}/lam.txt") + + is_leaf = [[node_to_id[v], 1 * tree.is_leaf(v)] for v in tree.nodes] + self._write_out_list_of_lists(is_leaf, f"{tmp_dir}/is_leaf.txt") + # Run the c++ implementation - # Read the cache values - pass + try: + # os.system('IIDExponentialPosteriorMeanBLE') + subprocess.run( + [ + "./IIDExponentialPosteriorMeanBLE", + f"{tmp_dir}", + f"{tmp_dir}", + ], + check=True, + cwd=os.path.dirname(__file__), + ) + except subprocess.CalledProcessError: + raise BranchLengthEstimatorError( + "Couldn't run c++ implementation," + " or c++ implementation started running and errored." + ) + + # Load the c++ implementation results into the cache + with open(f"{tmp_dir}/down.txt", "r") as fin: + for line in fin: + v, t, x, ll = line.split(" ") + v, t, x = int(v), int(t), int(x) + ll = float(ll) + self._down_cache[(id_to_node[v], t, x)] = ll + with open(f"{tmp_dir}/up.txt", "r") as fin: + for line in fin: + v, t, x, ll = line.split(" ") + v, t, x = int(v), int(t), int(x) + ll = float(ll) + self._up_cache[(id_to_node[v], t, x)] = ll def _compatible_with_observed_data(self, x, observed_cuts) -> bool: if self.enforce_parsimony: @@ -187,16 +392,17 @@ def up(self, v, t, x) -> float: if (v, t, x) in self._up_cache: # TODO: Use arrays? # TODO: Use a decorator instead of a hand-made cache? return self._up_cache[(v, t, x)] - if self.use_cpp_implementation: - raise ValueError(f"Bug in cpp implementation: State up({(v, t, x)})" - f" was not populated.") + if self.use_cpp_implementation and not self.debug_cpp_implementation: + raise ValueError( + f"Bug in cpp implementation: State up({(v, t, x)})" + f" was not populated." 
+ ) # Pull out params r = self.mutation_rate lam = self.birth_rate dt = 1.0 / self.discretization_level K = self.tree.n_character tree = self.tree - discretization_level = self.discretization_level assert 0 <= t <= self.discretization_level assert 0 <= x <= K if not (1.0 - lam * dt - K * r * dt > 0): @@ -227,13 +433,20 @@ def up(self, v, t, x) -> float: if v != tree.root: p = tree.parent(v) if self._compatible_with_observed_data( - x, tree.get_number_of_mutated_characters_in_node(p) # If we want to ignore missing data, we just have to replace x by x-gone_missing(p->v). I.e. dropped out characters become free mutations. + x, + tree.get_number_of_mutated_characters_in_node( + p + ), # If we want to ignore missing data, we just have to replace x by x-gone_missing(p->v). I.e. dropped out characters become free mutations. ): siblings = [u for u in tree.children(p) if u != v] ll = ( np.log(lam * dt) - + self.up(p, t - 1, x) # If we want to ignore missing data, we just have to replace x by x-gone_missing(p->v). I.e. dropped out characters become free mutations. - + sum([self.down(u, t, x) for u in siblings]) # If we want to ignore missing data, we just have to replace x by cuts(p)+gone_missing(p->u). I.e. dropped out characters become free mutations. + + self.up( + p, t - 1, x + ) # If we want to ignore missing data, we just have to replace x by x-gone_missing(p->v). I.e. dropped out characters become free mutations. + + sum( + [self.down(u, t, x) for u in siblings] + ) # If we want to ignore missing data, we just have to replace x by cuts(p)+gone_missing(p->u). I.e. dropped out characters become free mutations. ) log_likelihoods_cases.append(ll) log_likelihood = logsumexp(log_likelihoods_cases) @@ -252,9 +465,11 @@ def down(self, v, t, x) -> float: if (v, t, x) in self._down_cache: # TODO: Use a decorator instead of a hand-made cache? 
return self._down_cache[(v, t, x)] - if self.use_cpp_implementation: - raise ValueError(f"Bug in cpp implementation: State " - f"down({(v, t, x)}) was not populated.") + if self.use_cpp_implementation and not self.debug_cpp_implementation: + raise ValueError( + f"Bug in cpp implementation: State " + f"down({(v, t, x)}) was not populated." + ) # Pull out params discretization_level = self.discretization_level r = self.mutation_rate @@ -269,10 +484,9 @@ def down(self, v, t, x) -> float: raise ValueError("Please choose a bigger discretization_level.") log_likelihood = 0.0 if t == discretization_level: # Base case - if ( - tree.is_leaf(v) - and x == tree.get_number_of_mutated_characters_in_node(v) - ): + if tree.is_leaf( + v + ) and x == tree.get_number_of_mutated_characters_in_node(v): log_likelihood = 0.0 else: log_likelihood = -np.inf @@ -290,14 +504,13 @@ def down(self, v, t, x) -> float: ) # Case 3: Cell divides # The number of cuts at this state must match the ground truth. - if ( - self._compatible_with_observed_data( - x, tree.get_number_of_mutated_characters_in_node(v) - ) - and not tree.is_leaf(v) - ): + if self._compatible_with_observed_data( + x, tree.get_number_of_mutated_characters_in_node(v) + ) and not tree.is_leaf(v): ll = sum( - [self.down(child, t + 1, x) for child in tree.children(v)] # If we want to ignore missing data, we just have to replace x by x+gone_missing(p->v). I.e. dropped out characters become free mutations. + [ + self.down(child, t + 1, x) for child in tree.children(v) + ] # If we want to ignore missing data, we just have to replace x by x+gone_missing(p->v). I.e. dropped out characters become free mutations. 
) + np.log(lam * dt) log_likelihoods_cases.append(ll) log_likelihood = logsumexp(log_likelihoods_cases) @@ -510,6 +723,7 @@ def __init__( birth_rates: Tuple[float] = (0,), discretization_level: int = 1000, enforce_parsimony: bool = True, + use_cpp_implementation: bool = False, processes: int = 6, verbose: bool = False, ): @@ -517,6 +731,7 @@ def __init__( self.birth_rates = birth_rates self.discretization_level = discretization_level self.enforce_parsimony = enforce_parsimony + self.use_cpp_implementation = use_cpp_implementation self.processes = processes self.verbose = verbose @@ -528,6 +743,7 @@ def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: birth_rates = self.birth_rates discretization_level = self.discretization_level enforce_parsimony = self.enforce_parsimony + use_cpp_implementation = self.use_cpp_implementation processes = self.processes verbose = self.verbose @@ -550,6 +766,7 @@ def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: birth_rate=birth_rate, discretization_level=discretization_level, enforce_parsimony=enforce_parsimony, + use_cpp_implementation=use_cpp_implementation, ) ) mutation_and_birth_rates.append((mutation_rate, birth_rate)) @@ -580,6 +797,7 @@ def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: birth_rate=best_birth_rate, discretization_level=discretization_level, enforce_parsimony=enforce_parsimony, + use_cpp_implementation=use_cpp_implementation, ) final_model.estimate_branch_lengths(tree) self.mutation_rate = best_mutation_rate diff --git a/docs/requirements.txt b/docs/requirements.txt index 6b535e2d..dcbdb1e1 100755 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -22,4 +22,4 @@ pydata-sphinx-theme>=0.4.0 python-Levenshtein pathlib typing_extensions; python_version < '3.8' - +parameterized diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index bc9f234f..d91ee912 100644 --- 
a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -1,22 +1,19 @@ import itertools import multiprocessing +import unittest from copy import deepcopy import networkx as nx import numpy as np import pytest -import unittest +from parameterized import parameterized from cassiopeia.data import CassiopeiaTree - -from cassiopeia.tools import ( - BirthProcess, - IIDExponentialBLE, - IIDExponentialBLEGridSearchCV, - IIDExponentialLineageTracer, - IIDExponentialPosteriorMeanBLE, - IIDExponentialPosteriorMeanBLEGridSearchCV, -) +from cassiopeia.tools import (BirthProcess, IIDExponentialBLE, + IIDExponentialBLEGridSearchCV, + IIDExponentialLineageTracer, + IIDExponentialPosteriorMeanBLE, + IIDExponentialPosteriorMeanBLEGridSearchCV) class TestIIDExponentialBLE(unittest.TestCase): @@ -547,6 +544,7 @@ def get_z_scores( birth_rate=birth_rate_model, mutation_rate=mutation_rate_model, discretization_level=discretization_level, + use_cpp_implementation=True ) model.estimate_branch_lengths(tree) z_scores = [] @@ -591,7 +589,8 @@ def get_z_scores_under_misspecified_model(repetition): class TestIIDExponentialPosteriorMeanBLE(unittest.TestCase): - def test_IIDExponentialPosteriorMeanBLE(self): + @parameterized.expand([("cpp", True), ("no_cpp", False)]) + def test_IIDExponentialPosteriorMeanBLE(self, name, use_cpp_implementation): r""" For a small tree with only one internal node, the likelihood of the data, and the posterior age of the internal node, can be computed easily in @@ -618,6 +617,7 @@ def test_IIDExponentialPosteriorMeanBLE(self): mutation_rate=mutation_rate, birth_rate=birth_rate, discretization_level=discretization_level, + use_cpp_implementation=use_cpp_implementation ) model.estimate_branch_lengths(tree) @@ -695,7 +695,8 @@ def test_IIDExponentialPosteriorMeanBLE(self): posterior_mean, numerical_posterior_mean, significant=2 ) - def test_IIDExponentialPosteriorMeanBLE_2(self): + @parameterized.expand([("cpp", 
True), ("no_cpp", False)]) + def test_IIDExponentialPosteriorMeanBLE_2(self, name, use_cpp_implementation): r""" We run the Bayesian estimator on a small tree with all different leaves, and then check that: @@ -725,6 +726,7 @@ def test_IIDExponentialPosteriorMeanBLE_2(self): mutation_rate=mutation_rate, birth_rate=birth_rate, discretization_level=discretization_level, + use_cpp_implementation=use_cpp_implementation ) model.estimate_branch_lengths(tree) @@ -793,7 +795,8 @@ def test_IIDExponentialPosteriorMeanBLE_2(self): assert total_variation < 0.03 @pytest.mark.slow - def test_IIDExponentialPosteriorMeanBLE_3(self): + @parameterized.expand([("cpp", True), ("no_cpp", False)]) + def test_IIDExponentialPosteriorMeanBLE_3(self, name, use_cpp_implementation): r""" Same as test_IIDExponentialPosteriorMeanBLE_2 but with a weirder topology. """ @@ -822,6 +825,7 @@ def test_IIDExponentialPosteriorMeanBLE_3(self): mutation_rate=mutation_rate, birth_rate=birth_rate, discretization_level=discretization_level, + use_cpp_implementation=use_cpp_implementation ) model.estimate_branch_lengths(tree) @@ -880,7 +884,8 @@ def test_IIDExponentialPosteriorMeanBLE_3(self): assert total_variation < 0.03 @pytest.mark.slow - def test_IIDExponentialPosteriorMeanBLE_DREAM_subC1(self): + @parameterized.expand([("cpp", True), ("no_cpp", False)]) + def test_IIDExponentialPosteriorMeanBLE_DREAM_subC1(self, name, use_cpp_implementation): r""" A tree from the DREAM subchallenge 1, verified analytically. """ @@ -904,6 +909,7 @@ def test_IIDExponentialPosteriorMeanBLE_DREAM_subC1(self): mutation_rate=mutation_rate, birth_rate=birth_rate, discretization_level=discretization_level, + use_cpp_implementation=use_cpp_implementation ) model.estimate_branch_lengths(tree) @@ -971,6 +977,7 @@ def test_IIDExponentialPosteriorMeanBLE_posterior_calibration(self): We use p-values computed from the Hoeffding bound. TODO: There might be a more powerful test, e.g. Kolmogorov–Smirnov? 
(This would mean we need less repetitions and can make the test faster.) + This test uses the c++ implementation to be faster. """ repetitions = 1000 @@ -1002,7 +1009,8 @@ def test_IIDExponentialPosteriorMeanBLE_posterior_calibration(self): class TestIIDExponentialPosteriorMeanBLEGridSeachCV(unittest.TestCase): - def test_IIDExponentialPosteriorMeanBLEGridSeachCV_smoke(self): + @parameterized.expand([("cpp", True), ("no_cpp", False)]) + def test_IIDExponentialPosteriorMeanBLEGridSeachCV_smoke(self, name, use_cpp_implementation): r""" Just want to see that it runs in both single and multiprocessor mode """ @@ -1020,10 +1028,12 @@ def test_IIDExponentialPosteriorMeanBLEGridSeachCV_smoke(self): birth_rates=(1.5,), discretization_level=5, verbose=True, + use_cpp_implementation=use_cpp_implementation ) model.estimate_branch_lengths(tree) - def test_IIDExponentialPosteriorMeanBLEGridSeachCV(self): + @parameterized.expand([("cpp", True), ("no_cpp", False)]) + def test_IIDExponentialPosteriorMeanBLEGridSeachCV(self, name, use_cpp_implementation): r""" We just check that the grid search estimator does its job on a small grid. 
""" @@ -1047,6 +1057,7 @@ def test_IIDExponentialPosteriorMeanBLEGridSeachCV(self): birth_rates=birth_rates, discretization_level=discretization_level, verbose=True, + use_cpp_implementation=use_cpp_implementation ) # Test the model log likelihood against its numerical computation From 569f8c0f415a152afed31f26f49e7a117dc8133a Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sun, 7 Feb 2021 20:14:56 -0800 Subject: [PATCH 47/61] Increase cpp bounds --- .../IIDExponentialPosteriorMeanBLE.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp index c2eaaf3d..48143b4a 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp @@ -10,8 +10,8 @@ #define forall(i,c) for(typeof((c).begin()) i = (c).begin();i != (c).end();i++) using namespace std; -const int maxN = 1024; -const int maxK = 128; +const int maxN = 10001; +const int maxK = 101; const int maxT = 512; const float INF = 1e16; float _down_cache[maxN][maxT][maxK]; From 90a2155ecd5a1cfc5a01c4ada20ba17ab364aa33 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sun, 7 Feb 2021 20:18:38 -0800 Subject: [PATCH 48/61] bounds --- .../IIDExponentialPosteriorMeanBLE.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp index 48143b4a..8f0186e2 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp @@ -79,7 +79,7 @@ void read_N(){ cerr << "N input corrupted" << endl; 
exit(1); } - if(N >= maxN - 10){ + if(N >= maxN){ cerr << "N larger than maxN" << endl; exit(1); } @@ -201,7 +201,7 @@ void read_K(){ cerr << "K input corrupted" << endl; exit(1); } - if(K >= maxK - 10){ + if(K >= maxK - 1){ cerr << "K larger than maxK" << endl; exit(1); } @@ -215,7 +215,7 @@ void read_T(){ cerr << "T input corrupted" << endl; exit(1); } - if(T >= maxT - 10){ + if(T >= maxT - 1){ cerr << "T larger than maxT" << endl; exit(1); } From b33bb62c38b2720f195ea1a323c22c4ed1d732ae Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sun, 7 Feb 2021 20:21:50 -0800 Subject: [PATCH 49/61] bounds --- .../IIDExponentialPosteriorMeanBLE.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp index 8f0186e2..0054a56b 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp @@ -10,12 +10,12 @@ #define forall(i,c) for(typeof((c).begin()) i = (c).begin();i != (c).end();i++) using namespace std; -const int maxN = 10001; -const int maxK = 101; +const int maxN = 10000; +const int maxK = 100; const int maxT = 512; const float INF = 1e16; -float _down_cache[maxN][maxT][maxK]; -float _up_cache[maxN][maxT][maxK]; +float _down_cache[maxN][maxT + 1][maxK + 1]; +float _up_cache[maxN][maxT + 1][maxK + 1]; string input_dir = ""; string output_dir = ""; @@ -79,7 +79,7 @@ void read_N(){ cerr << "N input corrupted" << endl; exit(1); } - if(N >= maxN){ + if(N > maxN){ cerr << "N larger than maxN" << endl; exit(1); } @@ -201,7 +201,7 @@ void read_K(){ cerr << "K input corrupted" << endl; exit(1); } - if(K >= maxK - 1){ + if(K > maxK){ cerr << "K larger than maxK" << endl; exit(1); } @@ -215,7 +215,7 @@ void read_T(){ cerr << "T input corrupted" << 
endl; exit(1); } - if(T >= maxT - 1){ + if(T > maxT){ cerr << "T larger than maxT" << endl; exit(1); } From a24fff3cac70dcabd7046818d1f3ab577a26bf34 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Mon, 8 Feb 2021 12:31:06 -0800 Subject: [PATCH 50/61] More nitro --- .../IIDExponentialPosteriorMeanBLE.cpp | 121 +++++++++++++++++- .../IIDExponentialPosteriorMeanBLE.py | 84 +++++++++--- 2 files changed, 183 insertions(+), 22 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp index 0054a56b..39eeee02 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.cpp @@ -10,12 +10,15 @@ #define forall(i,c) for(typeof((c).begin()) i = (c).begin();i != (c).end();i++) using namespace std; -const int maxN = 10000; -const int maxK = 100; -const int maxT = 512; +const int maxN = 8192; +const int maxK = 63; +const int maxT = 501; const float INF = 1e16; float _down_cache[maxN][maxT + 1][maxK + 1]; float _up_cache[maxN][maxT + 1][maxK + 1]; +float log_joints[maxN][maxT + 1]; +float posteriors[maxN][maxT + 1]; +float posterior_means[maxN]; string input_dir = ""; string output_dir = ""; @@ -408,6 +411,116 @@ void write_up(){ fout << res; } +void write_log_likelihood(){ + ofstream fout(output_dir + "/log_likelihood.txt"); + float log_likelihood = 0; + for(auto child_of_root: children[root]){ + log_likelihood += down(child_of_root, 0, 0); + } + fout << log_likelihood; +} + +float _compute_log_joint(int v, int t){ + if(!(is_internal_node[v] and v != root)){ + cerr << "_compute_log_joint received invalid inputs" << endl; + exit(1); + } + vector valid_num_cuts; + if(enforce_parsimony){ + valid_num_cuts.push_back(get_number_of_mutated_characters_in_node[v]); + } else { + for(int x = 0; x <= 
get_number_of_mutated_characters_in_node[v]; x++){ + valid_num_cuts.push_back(x); + } + } + vector ll_for_xs; + for(auto x: valid_num_cuts){ + float ll_for_x = up(v, t, x); + for(auto u: children[v]){ + ll_for_x += down(u, t, x); + } + ll_for_xs.push_back(ll_for_x); + } + return logsumexp(ll_for_xs); +} + +void _write_out_log_joints(){ + ofstream fout(output_dir + "/log_joints.txt"); + string res = ""; + for(auto v: non_root_internal_nodes){ + res += to_string(v); + for(int t = 0; t <= T; t++){ + res += " " + to_string(log_joints[v][t]); + } + res += "\n"; + } + fout << res; +} + +void _write_out_posteriors(){ + // NOTE: copy-pasta of _write_out_log_joints) + ofstream fout(output_dir + "/posteriors.txt"); + string res = ""; + for(auto v: non_root_internal_nodes){ + res += to_string(v); + for(int t = 0; t <= T; t++){ + res += " " + to_string(posteriors[v][t]); + } + res += "\n"; + } + fout << res; +} + +void _write_out_posterior_means(){ + ofstream fout(output_dir + "/posterior_means.txt"); + string res = ""; + for(auto v: non_root_internal_nodes){ + res += to_string(v); + res += " " + to_string(posterior_means[v]); + res += "\n"; + } + fout << res; +} + +void write_posteriors(){ + // mimmicks _compute_posteriors of the python implementation. + for(auto v: non_root_internal_nodes){ + // Compute the log_joints. + for(int t = 0; t <= T; t++){ + log_joints[v][t] = _compute_log_joint(v, t); + } + + // Compute the posteriors + float mx = -INF; + for(int t = 0; t <= T; t++){ + mx = max(mx, log_joints[v][t]); + } + for(int t = 0; t <= T; t++){ + posteriors[v][t] = exp(log_joints[v][t] - mx); + } + // Normalize posteriors + float tot_sum = 0.0; + for(int t = 0; t <= T; t++){ + tot_sum += posteriors[v][t]; + } + for(int t = 0; t <= T; t++){ + posteriors[v][t] /= tot_sum; + } + + // Compute the posterior means. 
+ posterior_means[v] = 0.0; + for(int t = 0; t <= T; t++){ + posterior_means[v] += posteriors[v][t] * t; + } + posterior_means[v] /= float(T); + } + + // Write out the log_joints, posteriors, and posterior_means + _write_out_log_joints(); + _write_out_posteriors(); + _write_out_posterior_means(); +} + int main(int argc, char *argv[]){ if(argc != 3){ @@ -434,5 +547,7 @@ int main(int argc, char *argv[]){ forn(v, N) forn(t, T + 1) forn(k, K + 1) _down_cache[v][t][k] = _up_cache[v][t][k] = 1.0; write_down(); write_up(); + write_log_likelihood(); + write_posteriors(); return 0; } diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index 793e76a1..8f0d5f5f 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -86,12 +86,12 @@ def _compute_log_joint(self, v, t): valid_num_cuts = range( tree.get_number_of_mutated_characters_in_node(v) + 1 ) - ll_for_x = [] + ll_for_xs = [] for x in valid_num_cuts: - ll_for_x.append( + ll_for_xs.append( sum([self.down(u, t, x) for u in children]) + self.up(v, t, x) ) - return logsumexp(ll_for_x) + return logsumexp(ll_for_xs) def _compute_posteriors(self): tree = self.tree @@ -141,27 +141,28 @@ def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: time_cpp_start = time.time() if self.debug_cpp_implementation: # Use a directory that won't go away. - self._populate_cache_with_cpp_implementation( + self._populate_attributes_with_cpp_implementation( tmp_dir=os.getcwd() + "/tmp" ) else: # Use a temporary directory. 
with tempfile.TemporaryDirectory() as tmp_dir: - self._populate_cache_with_cpp_implementation(tmp_dir) + self._populate_attributes_with_cpp_implementation(tmp_dir) time_cpp_end = time.time() print(f"time_cpp = {time_cpp_end - time_cpp_start}") - time_compute_log_likelihood_start = time.time() - self._compute_log_likelihood() - time_compute_log_likelihood_end = time.time() - print( - f"time_compute_log_likelihood (dp_down) = {time_compute_log_likelihood_end - time_compute_log_likelihood_start}" - ) - time_compute_posteriors_start = time.time() - self._compute_posteriors() - time_compute_posteriors_end = time.time() - print( - f"time_compute_posteriors (dp_up) = {time_compute_posteriors_end - time_compute_posteriors_start}" - ) + else: + time_compute_log_likelihood_start = time.time() + self._compute_log_likelihood() + time_compute_log_likelihood_end = time.time() + print( + f"time_compute_log_likelihood (dp_down) = {time_compute_log_likelihood_end - time_compute_log_likelihood_start}" + ) + time_compute_posteriors_start = time.time() + self._compute_posteriors() + time_compute_posteriors_end = time.time() + print( + f"time_compute_posteriors (dp_up) = {time_compute_posteriors_end - time_compute_posteriors_start}" + ) time_populate_branch_lengths_start = time.time() self._populate_branch_lengths() time_populate_branch_lengths_end = time.time() @@ -239,10 +240,14 @@ def _write_out_list_of_lists(self, lls: List[List[int]], filename: str): with open(filename, "w") as file: file.write(res) - def _populate_cache_with_cpp_implementation(self, tmp_dir): + def _populate_attributes_with_cpp_implementation(self, tmp_dir): r""" A cpp implementation is run to compute up and down caches, which is - the computational bottleneck. + the computational bottleneck. The other attributes such as the + log-likelihood, and the posteriors, are also populated because + even these trivial computations are too slow in vanilla python. 
+ Looking forward, a cython implementation will hopefully be the + best way forward. To remove anything that has to do with the cpp implementation, you just have to remove this function (and the gates around it). I.e., this python implementation is loosely coupled to the cpp call: we @@ -358,6 +363,47 @@ def _populate_cache_with_cpp_implementation(self, tmp_dir): ll = float(ll) self._up_cache[(id_to_node[v], t, x)] = ll + discretization_level = self.discretization_level + + # Load the log_likelihood + with open(f"{tmp_dir}/log_likelihood.txt", "r") as fin: + self.log_likelihood = float(fin.read()) + + # Load the posteriors + log_joints = {} # log P(t_v = t, X, T) + with open(f"{tmp_dir}/log_joints.txt", "r") as fin: + for line in fin: + vals = line.split(" ") + assert(len(vals) == discretization_level + 2) + v_id = int(vals[0]) + log_joint = np.zeros(shape=(discretization_level + 1,)) + for i, val in enumerate(vals[1:]): + log_joint[i] = float(val) + log_joints[id_to_node[v_id]] = log_joint + + posteriors = {} # P(t_v = t | X, T) + with open(f"{tmp_dir}/posteriors.txt", "r") as fin: + for line in fin: + vals = line.split(" ") + assert(len(vals) == discretization_level + 2) + v_id = int(vals[0]) + posterior = np.zeros(shape=(discretization_level + 1,)) + for i, val in enumerate(vals[1:]): + posterior[i] = float(val) + posteriors[id_to_node[v_id]] = posterior + + posterior_means = {} # E[t_v = t | X, T] + with open(f"{tmp_dir}/posterior_means.txt", "r") as fin: + for line in fin: + v_id, val = line.split(" ") + v_id = int(v_id) + val = float(val) + posterior_means[id_to_node[v_id]] = val + + self.log_joints = log_joints + self.posteriors = posteriors + self.posterior_means = posterior_means + def _compatible_with_observed_data(self, x, observed_cuts) -> bool: if self.enforce_parsimony: return x == observed_cuts From dcce54d501bff3ebdc57f5776c44515f28b3af88 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Tue, 9 Feb 
2021 08:49:55 -0800 Subject: [PATCH 51/61] requirements --- docs/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/requirements.txt b/docs/requirements.txt index dcbdb1e1..5fc40d16 100755 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -22,4 +22,5 @@ pydata-sphinx-theme>=0.4.0 python-Levenshtein pathlib typing_extensions; python_version < '3.8' +cvxpy parameterized From 505c80225631467aed82cd9a4b84ebe80b60b2c2 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Tue, 9 Feb 2021 08:52:44 -0800 Subject: [PATCH 52/61] requirements --- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a9d3e7ef..941a2f21 100755 --- a/setup.py +++ b/setup.py @@ -32,7 +32,9 @@ 'nbconvert >= 5.4.0', 'nbformat >= 4.4.0', 'hits', - 'scikit-bio >= 0.5.6' + 'scikit-bio >= 0.5.6', + 'cvxpy', + 'parameterized', ] From e4c1dc7f1b47522ab82e95d7490ec207842d4625 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Tue, 9 Feb 2021 08:54:26 -0800 Subject: [PATCH 53/61] requirements --- docs/requirements.txt | 1 + setup.py | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/requirements.txt b/docs/requirements.txt index 5fc40d16..45eb1c7e 100755 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -24,3 +24,4 @@ pathlib typing_extensions; python_version < '3.8' cvxpy parameterized +seaborn diff --git a/setup.py b/setup.py index 941a2f21..b1e0717e 100755 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ 'scikit-bio >= 0.5.6', 'cvxpy', 'parameterized', + 'seaborn', ] From 7f292ebc4b241aa48ee025f2a790317f364aa3ef Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Fri, 12 Feb 2021 19:33:22 -0800 Subject: [PATCH 54/61] Resolve multifurcations --- cassiopeia/tools/__init__.py | 1 + .../IIDExponentialPosteriorMeanBLE.py | 4 +-- .../tools/branch_length_estimator/__init__.py | 1 + 
.../branch_length_estimator_test.py | 36 ++++++++++++++++++- 4 files changed, 39 insertions(+), 3 deletions(-) diff --git a/cassiopeia/tools/__init__.py b/cassiopeia/tools/__init__.py index ff6ecdf0..164c9791 100644 --- a/cassiopeia/tools/__init__.py +++ b/cassiopeia/tools/__init__.py @@ -1,5 +1,6 @@ from .branch_length_estimator import ( BranchLengthEstimator, + BLEMultifurcationWrapper, IIDExponentialBLE, IIDExponentialBLEGridSearchCV, IIDExponentialPosteriorMeanBLE, diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index 8f0d5f5f..c9421139 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -374,7 +374,7 @@ def _populate_attributes_with_cpp_implementation(self, tmp_dir): with open(f"{tmp_dir}/log_joints.txt", "r") as fin: for line in fin: vals = line.split(" ") - assert(len(vals) == discretization_level + 2) + assert len(vals) == discretization_level + 2 v_id = int(vals[0]) log_joint = np.zeros(shape=(discretization_level + 1,)) for i, val in enumerate(vals[1:]): @@ -385,7 +385,7 @@ def _populate_attributes_with_cpp_implementation(self, tmp_dir): with open(f"{tmp_dir}/posteriors.txt", "r") as fin: for line in fin: vals = line.split(" ") - assert(len(vals) == discretization_level + 2) + assert len(vals) == discretization_level + 2 v_id = int(vals[0]) posterior = np.zeros(shape=(discretization_level + 1,)) for i, val in enumerate(vals[1:]): diff --git a/cassiopeia/tools/branch_length_estimator/__init__.py b/cassiopeia/tools/branch_length_estimator/__init__.py index 5346e0b8..f0c15231 100644 --- a/cassiopeia/tools/branch_length_estimator/__init__.py +++ b/cassiopeia/tools/branch_length_estimator/__init__.py @@ -1,3 +1,4 @@ +from .BLEMultifurcationWrapper import BLEMultifurcationWrapper from .BranchLengthEstimator import 
BranchLengthEstimator from .IIDExponentialBLE import IIDExponentialBLE, IIDExponentialBLEGridSearchCV from .IIDExponentialPosteriorMeanBLE import ( diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index d91ee912..08cde817 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -9,7 +9,8 @@ from parameterized import parameterized from cassiopeia.data import CassiopeiaTree -from cassiopeia.tools import (BirthProcess, IIDExponentialBLE, +from cassiopeia.tools import (BirthProcess, BLEMultifurcationWrapper, + IIDExponentialBLE, IIDExponentialBLEGridSearchCV, IIDExponentialLineageTracer, IIDExponentialPosteriorMeanBLE, @@ -1087,3 +1088,36 @@ def test_IIDExponentialPosteriorMeanBLEGridSeachCV(self, name, use_cpp_implement np.testing.assert_almost_equal(model.mutation_rate, 0.75) np.testing.assert_almost_equal(model.birth_rate, 0.5) np.testing.assert_almost_equal(model.posterior_means["1"], 0.6815, decimal=3) + + +class TestBLEMultifurcationWrapper(unittest.TestCase): + def test_smoke(self): + tree = nx.DiGraph() + tree.add_nodes_from(["0", "1", "2", "3"]) + tree.add_edges_from([("0", "1"), ("0", "2"), ("0", "3")]) + tree = CassiopeiaTree(tree=tree) + tree.initialize_all_character_states( + {"0": [0, 0], + "1": [0, 1], + "2": [0, 1], + "3": [0, 1]} + ) + model = BLEMultifurcationWrapper(IIDExponentialBLE()) + model.estimate_branch_lengths(tree) + log_likelihood = model.log_likelihood + np.testing.assert_almost_equal( + tree.get_branch_length("0", "1"), np.log(2), decimal=3 + ) + np.testing.assert_almost_equal( + tree.get_branch_length("0", "2"), np.log(2), decimal=3 + ) + np.testing.assert_almost_equal( + tree.get_branch_length("0", "3"), np.log(2), decimal=3 + ) + np.testing.assert_almost_equal(tree.get_time("1"), np.log(2), decimal=3) + np.testing.assert_almost_equal(tree.get_time("2"), np.log(2), decimal=3) + 
np.testing.assert_almost_equal(tree.get_time("3"), np.log(2), decimal=3) + np.testing.assert_almost_equal(tree.get_time("0"), 0.0) + np.testing.assert_almost_equal(log_likelihood, -1.386 * 3, decimal=3) + log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) + np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) From b398ea43f9f72a439572e5177915b5d8575cd5c3 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Fri, 12 Feb 2021 19:42:54 -0800 Subject: [PATCH 55/61] Forgot to add file --- .../BLEMultifurcationWrapper.py | 140 ++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py diff --git a/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py b/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py new file mode 100644 index 00000000..2e480211 --- /dev/null +++ b/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py @@ -0,0 +1,140 @@ +import copy +from queue import PriorityQueue +from cassiopeia.data import CassiopeiaTree +import networkx as nx +from .BranchLengthEstimator import ( + BranchLengthEstimator, + BranchLengthEstimatorError, +) + + +# https://stackoverflow.com/questions/1443129/completely-wrap-an-object-in-python +# permalink: https://stackoverflow.com/a/1445289 +class BLEMultifurcationWrapper(BranchLengthEstimator): + r""" + Wraps a BranchLengthEstimator. + When estimating branch lengths: + 1) the tree topology is first copied out + 2) then multifurcations in the tree topology are resolved into a binary + structure, + 3) then branch lengths are estimated on this binary topology + 4) finally, the node ages are copied back onto the original tree. + Maximum Parsimony will be used to reconstruct the ancestral states. 
+ """ + + def __init__(self, ble_model: BranchLengthEstimator): + ble_model = copy.deepcopy(ble_model) + self.__class__ = type( + ble_model.__class__.__name__, + (self.__class__, ble_model.__class__), + {}, + ) + self.__dict__ = ble_model.__dict__ + self.__ble_model = ble_model + + def estimate_branch_lengths(self, tree: CassiopeiaTree): + binary_topology = binarize_topology(tree.get_tree_topology()) + # For debugging: + # print(f"binary_topology = {binary_topology.__dict__}") + tree_binary = CassiopeiaTree( + character_matrix=tree.get_current_character_matrix(), + tree=binary_topology, + ) + tree_binary.reconstruct_ancestral_characters(zero_the_root=True) + self.__ble_model.estimate_branch_lengths(tree_binary) + # Copy the times from the binary tree onto the original tree + times = dict([(v, tree_binary.get_time(v)) for v in tree.nodes]) + tree.set_times(times) + + +def binarize_topology(tree: nx.DiGraph) -> nx.DiGraph: + r""" + Given a tree represented by a networkx DiGraph, it resolves + multifurcations. The tree is NOT modified in-place. + The root is made to have only one children, as in a real-life tumor + (the founding cell never divides immediately!) + """ + tree = copy.deepcopy(tree) + node_names = set([n for n in tree]) + root = [n for n in tree if tree.in_degree(n) == 0][0] + subtree_sizes = {} + _dfs_subtree_sizes(tree, subtree_sizes, root) + assert len(subtree_sizes) == len([n for n in tree]) + + # First make the root have degree 1. 
+ if tree.out_degree(root) >= 2: + children = list(tree.successors(root)) + assert len(children) == tree.out_degree(root) + # First remove the edges from the root + tree.remove_edges_from([(root, child) for child in children]) + # Now create the intermediate node and add edges back + root_child = f"{root}-child" + if root_child in node_names: + raise BranchLengthEstimatorError("Node name already exists!") + tree.add_edge(root, root_child) + tree.add_edges_from([(root_child, child) for child in children]) + + def _dfs_resolve_multifurcations(tree, v): + children = list(tree.successors(v)) + if len(children) >= 3: + # Must resolve the multifurcation + _resolve_multifurcation(tree, v, subtree_sizes, node_names) + for child in children: + _dfs_resolve_multifurcations(tree, child) + + _dfs_resolve_multifurcations(tree, root) + # Check that the tree is binary + if not (len(tree.nodes) == len(tree.edges) + 1): + raise BranchLengthEstimatorError("Failed to binarize tree") + return tree + + +def _resolve_multifurcation(tree, v, subtree_sizes, node_names): + r""" + node_names is used to make sure we don't create a node name that already + exists. 
+ """ + children = list(tree.successors(v)) + n_children = len(children) + assert n_children >= 3 + + # Remove all edges from v to its children + tree.remove_edges_from([(v, child) for child in children]) + + # Create the new binary structure + queue = PriorityQueue() + for child in children: + queue.put((subtree_sizes[child], child)) + + for i in range(n_children - 2): + # Coalesce two smallest subtrees + subtree_1_size, subtree_1_root = queue.get() + subtree_2_size, subtree_2_root = queue.get() + assert subtree_1_size <= subtree_2_size + coalesced_tree_size = subtree_1_size + subtree_2_size + 1 + coalesced_tree_root = f"{v}-coalesce-{i}" + if coalesced_tree_root in node_names: + raise BranchLengthEstimatorError("Node name already exists!") + # For debugging: + # print(f"Coalescing {subtree_1_root} (sz {subtree_1_size}) and" + # f" {subtree_2_root} (sz {subtree_2_size})") + tree.add_edges_from( + [ + (coalesced_tree_root, subtree_1_root), + (coalesced_tree_root, subtree_2_root), + ] + ) + queue.put((coalesced_tree_size, coalesced_tree_root)) + # Hang the two subtrees obtained to v + subtree_1_size, subtree_1_root = queue.get() + subtree_2_size, subtree_2_root = queue.get() + assert subtree_1_size <= subtree_2_size + tree.add_edges_from([(v, subtree_1_root), (v, subtree_2_root)]) + + +def _dfs_subtree_sizes(tree, subtree_sizes, v) -> int: + res = 1 + for child in tree.successors(v): + res += _dfs_subtree_sizes(tree, subtree_sizes, child) + subtree_sizes[v] = res + return res From 22fb3853531ee336ae5c3a6fd65b91980686c508 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sat, 13 Feb 2021 17:20:38 -0800 Subject: [PATCH 56/61] Add TumorWithAFitSubclone --- cassiopeia/tools/__init__.py | 1 + cassiopeia/tools/lineage_simulator.py | 99 ++++++++++++++++++++++ test/tools_tests/lineage_simulator_test.py | 40 +++++++++ 3 files changed, 140 insertions(+) diff --git a/cassiopeia/tools/__init__.py b/cassiopeia/tools/__init__.py 
index 164c9791..84cfae28 100644 --- a/cassiopeia/tools/__init__.py +++ b/cassiopeia/tools/__init__.py @@ -11,6 +11,7 @@ LineageSimulator, PerfectBinaryTree, PerfectBinaryTreeWithRootBranch, + TumorWithAFitSubclone, ) from .lineage_tracing_simulator import ( LineageTracingSimulator, diff --git a/cassiopeia/tools/lineage_simulator.py b/cassiopeia/tools/lineage_simulator.py index 9eb4d1f6..6a070ef1 100644 --- a/cassiopeia/tools/lineage_simulator.py +++ b/cassiopeia/tools/lineage_simulator.py @@ -3,6 +3,7 @@ import networkx as nx import numpy as np +from queue import Queue from cassiopeia.data import CassiopeiaTree @@ -183,3 +184,101 @@ def simulate_lineage(self) -> CassiopeiaTree: tree = CassiopeiaTree(tree=tree_nx) tree.set_times(times) return tree + + +class TumorWithAFitSubclone(LineageSimulator): + r""" + TODO + + Args: + branch_length: TODO + TODO + """ + + def __init__( + self, + branch_length: float, + branch_length_fit: float, + experiment_duration: float, + generations_until_fit_subclone: int, + ): + self.branch_length = branch_length + self.branch_length_fit = branch_length_fit + self.experiment_duration = experiment_duration + self.generations_until_fit_subclone = generations_until_fit_subclone + + def simulate_lineage(self) -> CassiopeiaTree: + r""" + See base class. + """ + branch_length = self.branch_length + branch_length_fit = self.branch_length_fit + experiment_duration = self.experiment_duration + generations_until_fit_subclone = self.generations_until_fit_subclone + + def node_name_generator(): + i = 0 + while True: + yield str(i) + i += 1 + + tree = nx.DiGraph() # This is what will get populated. 
+ + names = node_name_generator() + q = Queue() # (node, time, fitness, generation) + times = {} + + root = next(names) + "_unfit" + tree.add_node(root) + times[root] = 0.0 + + root_child = next(names) + "_unfit" + tree.add_edge(root, root_child) + q.put((root_child, 0.0, "unfit", 0)) + subclone_started = False + while not q.empty(): + # Pop next node + (node, time, node_fitness, generation) = q.get() + time_till_division = ( + branch_length if node_fitness == "unfit" else branch_length_fit + ) + time_of_division = time + time_till_division + if time_of_division >= experiment_duration: + # Not enough time left for the cell to divide. + times[node] = experiment_duration + continue + # Create children, add edges to them, and push children to the + # queue. + times[node] = time_of_division + left_child_fitness = node_fitness + right_child_fitness = node_fitness + if ( + not subclone_started + and generation + 1 == generations_until_fit_subclone + ): + # Start the subclone + subclone_started = True + left_child_fitness = "fit" + left_child = next(names) + "_" + left_child_fitness + right_child = next(names) + "_" + right_child_fitness + tree.add_nodes_from([left_child, right_child]) + tree.add_edges_from([(node, left_child), (node, right_child)]) + q.put( + ( + left_child, + time_of_division, + left_child_fitness, + generation + 1, + ) + ) + q.put( + ( + right_child, + time_of_division, + right_child_fitness, + generation + 1, + ) + ) + res = CassiopeiaTree(tree=tree) + res.set_times(times) + return res diff --git a/test/tools_tests/lineage_simulator_test.py b/test/tools_tests/lineage_simulator_test.py index 6e3067f2..a178bf60 100644 --- a/test/tools_tests/lineage_simulator_test.py +++ b/test/tools_tests/lineage_simulator_test.py @@ -7,6 +7,7 @@ PerfectBinaryTree, PerfectBinaryTreeWithRootBranch, BirthProcess, + TumorWithAFitSubclone, ) @@ -84,3 +85,42 @@ def num_ancestors(tree, node: int) -> int: inferred_birth_rate = np.array(intensities).mean() print(f"{birth_rate} 
== {inferred_birth_rate}") assert np.abs(birth_rate - inferred_birth_rate) < 0.05 + + +class TestTumorWithAFitSubclone(unittest.TestCase): + def test_TumorWithAFitSubclone(self): + r""" + Small test that can be drawn by hand. + Checks that the generated phylogeny is correct. + """ + tree = TumorWithAFitSubclone( + branch_length=1, + branch_length_fit=0.5, + experiment_duration=2, + generations_until_fit_subclone=1, + ).simulate_lineage() + self.assertListEqual( + tree.nodes, + ["0_unfit", "1_unfit", "2_fit", "3_unfit", "4_fit", "5_fit"], + ) + self.assertListEqual( + tree.edges, + [ + ("0_unfit", "1_unfit"), + ("1_unfit", "2_fit"), + ("1_unfit", "3_unfit"), + ("2_fit", "4_fit"), + ("2_fit", "5_fit"), + ], + ) + self.assertDictEqual( + tree.get_times(), + { + "0_unfit": 0.0, + "1_unfit": 1.0, + "2_fit": 1.5, + "3_unfit": 2.0, + "4_fit": 2.0, + "5_fit": 2.0, + }, + ) From 6f4eefd044aea58a599b9ffe634931a1434149dd Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Sat, 13 Feb 2021 22:06:24 -0800 Subject: [PATCH 57/61] Some goodies --- cassiopeia/data/CassiopeiaTree.py | 11 ++++++++ .../BLEMultifurcationWrapper.py | 7 +++++ .../IIDExponentialPosteriorMeanBLE.py | 27 ++++++++++++------- 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/cassiopeia/data/CassiopeiaTree.py b/cassiopeia/data/CassiopeiaTree.py index d5b768ea..011d831d 100644 --- a/cassiopeia/data/CassiopeiaTree.py +++ b/cassiopeia/data/CassiopeiaTree.py @@ -997,3 +997,14 @@ def compute_dissimilarity_map( ) self.set_dissimilarity_map(dissimilarity_map) + + def scale_to_unit_length(self) -> None: + r""" + Scales the tree to have unit length. I.e. the longest path from root to + leaf will have length 1 after the scaling. 
+ """ + times = {} + max_time = max(self.get_times().values()) + for node in self.nodes: + times[node] = self.get_time(node) / max_time + self.set_times(times) diff --git a/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py b/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py index 2e480211..f01f5f6f 100644 --- a/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py +++ b/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py @@ -89,6 +89,13 @@ def _dfs_resolve_multifurcations(tree, v): return tree +def binarize_cassiopeia_tree(tree: CassiopeiaTree) -> CassiopeiaTree: + return CassiopeiaTree( + character_matrix=tree.get_current_character_matrix(), + tree=binarize_topology(tree.get_tree_topology()) + ) + + def _resolve_multifurcation(tree, v, subtree_sizes, node_names): r""" node_names is used to make sure we don't create a node name that already diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index c9421139..cbca5412 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -47,6 +47,7 @@ def __init__( enforce_parsimony: bool = True, use_cpp_implementation: bool = False, debug_cpp_implementation: bool = False, + verbose: bool = False ) -> None: # TODO: If we use autograd, we can tune the hyperparams with gradient # descent? 
@@ -58,6 +59,7 @@ def __init__( self.enforce_parsimony = enforce_parsimony self.use_cpp_implementation = use_cpp_implementation self.debug_cpp_implementation = debug_cpp_implementation + self.verbose = verbose def _compute_log_likelihood(self): tree = self.tree @@ -133,6 +135,7 @@ def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: self._down_cache = {} self._up_cache = {} self.tree = tree + verbose = self.verbose if self.debug_cpp_implementation: # Write out true dp values to check by eye against c++ # implementation values. @@ -149,26 +152,30 @@ def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: with tempfile.TemporaryDirectory() as tmp_dir: self._populate_attributes_with_cpp_implementation(tmp_dir) time_cpp_end = time.time() - print(f"time_cpp = {time_cpp_end - time_cpp_start}") + if verbose: + print(f"time_cpp = {time_cpp_end - time_cpp_start}") else: time_compute_log_likelihood_start = time.time() self._compute_log_likelihood() time_compute_log_likelihood_end = time.time() - print( - f"time_compute_log_likelihood (dp_down) = {time_compute_log_likelihood_end - time_compute_log_likelihood_start}" - ) + if verbose: + print( + f"time_compute_log_likelihood (dp_down) = {time_compute_log_likelihood_end - time_compute_log_likelihood_start}" + ) time_compute_posteriors_start = time.time() self._compute_posteriors() time_compute_posteriors_end = time.time() - print( - f"time_compute_posteriors (dp_up) = {time_compute_posteriors_end - time_compute_posteriors_start}" - ) + if verbose: + print( + f"time_compute_posteriors (dp_up) = {time_compute_posteriors_end - time_compute_posteriors_start}" + ) time_populate_branch_lengths_start = time.time() self._populate_branch_lengths() time_populate_branch_lengths_end = time.time() - print( - f"time_populate_branch_lengths = {time_populate_branch_lengths_end - time_populate_branch_lengths_start}" - ) + if verbose: + print( + f"time_populate_branch_lengths = {time_populate_branch_lengths_end - 
time_populate_branch_lengths_start}" + ) def _write_out_dps(self): r""" From cf94b2a6d1ecb7cfdbf27fa2a2ea22e6662503fa Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Tue, 16 Feb 2021 22:22:26 -0800 Subject: [PATCH 58/61] Resolving of multifurcations, and cell subsampler --- cassiopeia/data/CassiopeiaTree.py | 9 ++ cassiopeia/data/__init__.py | 4 +- cassiopeia/data/utilities.py | 95 ++++++++++++++ .../solver/ResolveMultifurcationsWrapper.py | 27 ++++ .../solver/StringifyNodeNamesWrapper.py | 33 +++++ cassiopeia/solver/__init__.py | 2 + cassiopeia/tools/__init__.py | 1 + .../BLEMultifurcationWrapper.py | 116 +----------------- .../BranchLengthEstimator.py | 2 +- .../IIDExponentialBLE.py | 3 +- .../IIDExponentialPosteriorMeanBLE.py | 7 +- cassiopeia/tools/cell_subsampler.py | 95 ++++++++++++++ .../branch_length_estimator_test.py | 6 +- test/tools_tests/lineage_simulator_test.py | 11 +- 14 files changed, 286 insertions(+), 125 deletions(-) create mode 100644 cassiopeia/solver/ResolveMultifurcationsWrapper.py create mode 100644 cassiopeia/solver/StringifyNodeNamesWrapper.py create mode 100644 cassiopeia/tools/cell_subsampler.py diff --git a/cassiopeia/data/CassiopeiaTree.py b/cassiopeia/data/CassiopeiaTree.py index 011d831d..19b5a995 100644 --- a/cassiopeia/data/CassiopeiaTree.py +++ b/cassiopeia/data/CassiopeiaTree.py @@ -1008,3 +1008,12 @@ def scale_to_unit_length(self) -> None: for node in self.nodes: times[node] = self.get_time(node) / max_time self.set_times(times) + + +def resolve_multifurcations(tree: CassiopeiaTree) -> None: + r""" + Resolves the multifurcations of the CassiopeiaTree inplace. 
+ """ + binary_topology = utilities.resolve_multifurcations_networkx( + tree.get_tree_topology()) + tree.populate_tree(binary_topology) diff --git a/cassiopeia/data/__init__.py b/cassiopeia/data/__init__.py index 2e9c1edf..0bda6e1c 100644 --- a/cassiopeia/data/__init__.py +++ b/cassiopeia/data/__init__.py @@ -1,4 +1,4 @@ """Top level for data.""" -from .CassiopeiaTree import CassiopeiaTree -from .utilities import to_newick +from .CassiopeiaTree import CassiopeiaTree, resolve_multifurcations +from .utilities import to_newick, resolve_multifurcations_networkx diff --git a/cassiopeia/data/utilities.py b/cassiopeia/data/utilities.py index 84637396..580d1e92 100644 --- a/cassiopeia/data/utilities.py +++ b/cassiopeia/data/utilities.py @@ -1,6 +1,8 @@ """ General utilities for the datasets encountered in Cassiopeia. """ +import copy +from queue import PriorityQueue from typing import Callable, Dict, List, Optional import ete3 @@ -158,3 +160,96 @@ def _compute_dissimilarity_map(): return dm return _compute_dissimilarity_map() + + +def resolve_multifurcations_networkx(tree: nx.DiGraph) -> nx.DiGraph: + r""" + Given a tree represented by a networkx DiGraph, it resolves + multifurcations. The tree is NOT modified in-place. + The root is made to have only one children, as in a real-life tumor + (the founding cell never divides immediately!) + """ + tree = copy.deepcopy(tree) + node_names = set([n for n in tree]) + root = [n for n in tree if tree.in_degree(n) == 0][0] + subtree_sizes = {} + _dfs_subtree_sizes(tree, subtree_sizes, root) + assert len(subtree_sizes) == len([n for n in tree]) + + # First make the root have degree 1. 
+ if tree.out_degree(root) >= 2: + children = list(tree.successors(root)) + assert len(children) == tree.out_degree(root) + # First remove the edges from the root + tree.remove_edges_from([(root, child) for child in children]) + # Now create the intermediate node and add edges back + root_child = f"{root}-child" + if root_child in node_names: + raise RuntimeError("Node name already exists!") + tree.add_edge(root, root_child) + tree.add_edges_from([(root_child, child) for child in children]) + + def _dfs_resolve_multifurcations(tree, v): + children = list(tree.successors(v)) + if len(children) >= 3: + # Must resolve the multifurcation + _resolve_multifurcation(tree, v, subtree_sizes, node_names) + for child in children: + _dfs_resolve_multifurcations(tree, child) + + _dfs_resolve_multifurcations(tree, root) + # Check that the tree is binary + if not (len(tree.nodes) == len(tree.edges) + 1): + raise RuntimeError("Failed to binarize tree") + return tree + + +def _resolve_multifurcation(tree, v, subtree_sizes, node_names): + r""" + node_names is used to make sure we don't create a node name that already + exists. 
+ """ + children = list(tree.successors(v)) + n_children = len(children) + assert n_children >= 3 + + # Remove all edges from v to its children + tree.remove_edges_from([(v, child) for child in children]) + + # Create the new binary structure + queue = PriorityQueue() + for child in children: + queue.put((subtree_sizes[child], child)) + + for i in range(n_children - 2): + # Coalesce two smallest subtrees + subtree_1_size, subtree_1_root = queue.get() + subtree_2_size, subtree_2_root = queue.get() + assert subtree_1_size <= subtree_2_size + coalesced_tree_size = subtree_1_size + subtree_2_size + 1 + coalesced_tree_root = f"{v}-coalesce-{i}" + if coalesced_tree_root in node_names: + raise RuntimeError("Node name already exists!") + # For debugging: + # print(f"Coalescing {subtree_1_root} (sz {subtree_1_size}) and" + # f" {subtree_2_root} (sz {subtree_2_size})") + tree.add_edges_from( + [ + (coalesced_tree_root, subtree_1_root), + (coalesced_tree_root, subtree_2_root), + ] + ) + queue.put((coalesced_tree_size, coalesced_tree_root)) + # Hang the two subtrees obtained to v + subtree_1_size, subtree_1_root = queue.get() + subtree_2_size, subtree_2_root = queue.get() + assert subtree_1_size <= subtree_2_size + tree.add_edges_from([(v, subtree_1_root), (v, subtree_2_root)]) + + +def _dfs_subtree_sizes(tree, subtree_sizes, v) -> int: + res = 1 + for child in tree.successors(v): + res += _dfs_subtree_sizes(tree, subtree_sizes, child) + subtree_sizes[v] = res + return res diff --git a/cassiopeia/solver/ResolveMultifurcationsWrapper.py b/cassiopeia/solver/ResolveMultifurcationsWrapper.py new file mode 100644 index 00000000..87b29ea5 --- /dev/null +++ b/cassiopeia/solver/ResolveMultifurcationsWrapper.py @@ -0,0 +1,27 @@ +import copy +from cassiopeia.data import CassiopeiaTree, resolve_multifurcations +from cassiopeia.solver import CassiopeiaSolver + + +# https://stackoverflow.com/questions/1443129/completely-wrap-an-object-in-python +# permalink: 
https://stackoverflow.com/a/1445289 +class ResolveMultifurcationsWrapper(CassiopeiaSolver.CassiopeiaSolver): + r""" + Wraps a CassiopeiaSolver. + The wrapped solver is used to solve for the tree topology, after which + the multifurcations are resolved. + """ + + def __init__(self, solver: CassiopeiaSolver.CassiopeiaSolver): + solver = copy.deepcopy(solver) + self.__class__ = type( + solver.__class__.__name__, + (self.__class__, solver.__class__), + {}, + ) + self.__dict__ = solver.__dict__ + self.__solver = solver + + def solve(self, tree: CassiopeiaTree) -> None: + self.__solver.solve(tree) + resolve_multifurcations(tree) diff --git a/cassiopeia/solver/StringifyNodeNamesWrapper.py b/cassiopeia/solver/StringifyNodeNamesWrapper.py new file mode 100644 index 00000000..71d519e4 --- /dev/null +++ b/cassiopeia/solver/StringifyNodeNamesWrapper.py @@ -0,0 +1,33 @@ +import copy +from cassiopeia.data import CassiopeiaTree +from cassiopeia.solver import CassiopeiaSolver + + +# https://stackoverflow.com/questions/1443129/completely-wrap-an-object-in-python +# permalink: https://stackoverflow.com/a/1445289 +class StringifyNodeNamesWrapper(CassiopeiaSolver.CassiopeiaSolver): + r""" + Wraps a CassiopeiaSolver. + The wrapped solver is used to solve for the tree topology, after which + the node names are cast to string. This is because some solvers create + nodes with integer IDs, and this can break downstream code. 
+ """ + def __init__(self, solver: CassiopeiaSolver.CassiopeiaSolver): + solver = copy.deepcopy(solver) + self.__class__ = type( + solver.__class__.__name__, + (self.__class__, solver.__class__), + {}, + ) + self.__dict__ = solver.__dict__ + self.__solver = solver + + def solve(self, tree: CassiopeiaTree) -> None: + self.__solver.solve(tree) + relabel_map = {node: 'internal-' + str(node) for node in + tree.internal_nodes} + num_nodes_before = len(tree.nodes) + tree.relabel_nodes(relabel_map) + num_nodes_after = len(tree.nodes) + if num_nodes_before != num_nodes_after: + raise RuntimeError("There was a colision stringifying node names.") diff --git a/cassiopeia/solver/__init__.py b/cassiopeia/solver/__init__.py index bf8d2640..14a7c101 100755 --- a/cassiopeia/solver/__init__.py +++ b/cassiopeia/solver/__init__.py @@ -10,7 +10,9 @@ from .MaxCutSolver import MaxCutSolver from .NeighborJoiningSolver import NeighborJoiningSolver from .solver_utilities import collapse_tree, collapse_unifurcations +from .ResolveMultifurcationsWrapper import ResolveMultifurcationsWrapper from .SpectralGreedySolver import SpectralGreedySolver from .SpectralSolver import SpectralSolver +from .StringifyNodeNamesWrapper import StringifyNodeNamesWrapper from .VanillaGreedySolver import VanillaGreedySolver from . 
import dissimilarity_functions as dissimilarity diff --git a/cassiopeia/tools/__init__.py b/cassiopeia/tools/__init__.py index 84cfae28..44dcb6b8 100644 --- a/cassiopeia/tools/__init__.py +++ b/cassiopeia/tools/__init__.py @@ -17,3 +17,4 @@ LineageTracingSimulator, IIDExponentialLineageTracer, ) +from .cell_subsampler import CellSubsampler, UniformCellSubsampler diff --git a/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py b/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py index f01f5f6f..95365ca5 100644 --- a/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py +++ b/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py @@ -1,11 +1,6 @@ import copy -from queue import PriorityQueue -from cassiopeia.data import CassiopeiaTree -import networkx as nx -from .BranchLengthEstimator import ( - BranchLengthEstimator, - BranchLengthEstimatorError, -) +from cassiopeia.data import CassiopeiaTree, resolve_multifurcations_networkx +from .BranchLengthEstimator import BranchLengthEstimator # https://stackoverflow.com/questions/1443129/completely-wrap-an-object-in-python @@ -32,10 +27,11 @@ def __init__(self, ble_model: BranchLengthEstimator): self.__dict__ = ble_model.__dict__ self.__ble_model = ble_model - def estimate_branch_lengths(self, tree: CassiopeiaTree): - binary_topology = binarize_topology(tree.get_tree_topology()) + def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: + binary_topology = resolve_multifurcations_networkx( + tree.get_tree_topology()) # For debugging: - # print(f"binary_topology = {binary_topology.__dict__}") + print(f"binary_topology = {binary_topology.__dict__}") tree_binary = CassiopeiaTree( character_matrix=tree.get_current_character_matrix(), tree=binary_topology, @@ -45,103 +41,3 @@ def estimate_branch_lengths(self, tree: CassiopeiaTree): # Copy the times from the binary tree onto the original tree times = dict([(v, tree_binary.get_time(v)) for v in tree.nodes]) 
tree.set_times(times) - - -def binarize_topology(tree: nx.DiGraph) -> nx.DiGraph: - r""" - Given a tree represented by a networkx DiGraph, it resolves - multifurcations. The tree is NOT modified in-place. - The root is made to have only one children, as in a real-life tumor - (the founding cell never divides immediately!) - """ - tree = copy.deepcopy(tree) - node_names = set([n for n in tree]) - root = [n for n in tree if tree.in_degree(n) == 0][0] - subtree_sizes = {} - _dfs_subtree_sizes(tree, subtree_sizes, root) - assert len(subtree_sizes) == len([n for n in tree]) - - # First make the root have degree 1. - if tree.out_degree(root) >= 2: - children = list(tree.successors(root)) - assert len(children) == tree.out_degree(root) - # First remove the edges from the root - tree.remove_edges_from([(root, child) for child in children]) - # Now create the intermediate node and add edges back - root_child = f"{root}-child" - if root_child in node_names: - raise BranchLengthEstimatorError("Node name already exists!") - tree.add_edge(root, root_child) - tree.add_edges_from([(root_child, child) for child in children]) - - def _dfs_resolve_multifurcations(tree, v): - children = list(tree.successors(v)) - if len(children) >= 3: - # Must resolve the multifurcation - _resolve_multifurcation(tree, v, subtree_sizes, node_names) - for child in children: - _dfs_resolve_multifurcations(tree, child) - - _dfs_resolve_multifurcations(tree, root) - # Check that the tree is binary - if not (len(tree.nodes) == len(tree.edges) + 1): - raise BranchLengthEstimatorError("Failed to binarize tree") - return tree - - -def binarize_cassiopeia_tree(tree: CassiopeiaTree) -> CassiopeiaTree: - return CassiopeiaTree( - character_matrix=tree.get_current_character_matrix(), - tree=binarize_topology(tree.get_tree_topology()) - ) - - -def _resolve_multifurcation(tree, v, subtree_sizes, node_names): - r""" - node_names is used to make sure we don't create a node name that already - exists. 
- """ - children = list(tree.successors(v)) - n_children = len(children) - assert n_children >= 3 - - # Remove all edges from v to its children - tree.remove_edges_from([(v, child) for child in children]) - - # Create the new binary structure - queue = PriorityQueue() - for child in children: - queue.put((subtree_sizes[child], child)) - - for i in range(n_children - 2): - # Coalesce two smallest subtrees - subtree_1_size, subtree_1_root = queue.get() - subtree_2_size, subtree_2_root = queue.get() - assert subtree_1_size <= subtree_2_size - coalesced_tree_size = subtree_1_size + subtree_2_size + 1 - coalesced_tree_root = f"{v}-coalesce-{i}" - if coalesced_tree_root in node_names: - raise BranchLengthEstimatorError("Node name already exists!") - # For debugging: - # print(f"Coalescing {subtree_1_root} (sz {subtree_1_size}) and" - # f" {subtree_2_root} (sz {subtree_2_size})") - tree.add_edges_from( - [ - (coalesced_tree_root, subtree_1_root), - (coalesced_tree_root, subtree_2_root), - ] - ) - queue.put((coalesced_tree_size, coalesced_tree_root)) - # Hang the two subtrees obtained to v - subtree_1_size, subtree_1_root = queue.get() - subtree_2_size, subtree_2_root = queue.get() - assert subtree_1_size <= subtree_2_size - tree.add_edges_from([(v, subtree_1_root), (v, subtree_2_root)]) - - -def _dfs_subtree_sizes(tree, subtree_sizes, v) -> int: - res = 1 - for child in tree.successors(v): - res += _dfs_subtree_sizes(tree, subtree_sizes, child) - subtree_sizes[v] = res - return res diff --git a/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py b/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py index cd94456e..96bf76af 100644 --- a/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py +++ b/cassiopeia/tools/branch_length_estimator/BranchLengthEstimator.py @@ -4,7 +4,7 @@ class BranchLengthEstimatorError(Exception): - """An Exception class for the CassiopeiaTree class.""" + """An Exception class for the BranchLengthEstimator 
class.""" pass diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py index 23eb66da..512e16f3 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py @@ -1,5 +1,4 @@ import multiprocessing -import copy from typing import List, Optional, Tuple import cvxpy as cp @@ -317,7 +316,7 @@ def _cv_split( state = tree.get_character_states(node) train_state = ( state[:held_out_character_idx] - + state[(held_out_character_idx + 1) :] + + state[(held_out_character_idx + 1):] ) valid_state = [state[held_out_character_idx]] train_states[node] = train_state diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index cbca5412..d97f974f 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -180,12 +180,13 @@ def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: def _write_out_dps(self): r""" For debugging the c++ implementation: - This writes out the down and up values of the correct python implementation to the files - tmp/down_true.txt + This writes out the down and up values of the correct python + implementation to the files tmp/down_true.txt and tmp/up_true.txt respectively. - Compare these against tmp/down.txt and tmp/up.txt, which are the values computed by the c++ implementation. + Compare these against tmp/down.txt and tmp/up.txt, which are the values + computed by the c++ implementation. 
""" tree = self.tree N = len(tree.nodes) diff --git a/cassiopeia/tools/cell_subsampler.py b/cassiopeia/tools/cell_subsampler.py new file mode 100644 index 00000000..8fc8b1ec --- /dev/null +++ b/cassiopeia/tools/cell_subsampler.py @@ -0,0 +1,95 @@ +import abc +import networkx as nx +import numpy as np + +from cassiopeia.data import CassiopeiaTree + + +class CellSubsamplerError(Exception): + """An Exception class for the CellSubsampler class.""" + + pass + + +class CellSubsampler(abc.ABC): + r""" + Abstract base class for all cell samplers. + + A CellSubsampler implements a method 'subsample' which, given a Tree, + returns a second Tree which is the result of subsampling cells + (i.e. leafs) of the tree. Only the tree topology will be created for the + new tree. + """ + + @abc.abstractmethod + def subsample(self, tree: CassiopeiaTree) -> CassiopeiaTree: + r""" + Returns a new CassiopeiaTree which is the result of subsampling + the cells in the original CassiopeiaTree. + + Args: + tree: The tree for which to subsample leaves. + """ + + +class UniformCellSubsampler(CellSubsampler): + def __init__(self, ratio: float): + r""" + Samples 'ratio' of the leaves, rounded down, uniformly at random. + """ + self.__ratio = ratio + + def subsample(self, tree: CassiopeiaTree) -> CassiopeiaTree: + ratio = self.__ratio + n_subsample = int(tree.n_cell * ratio) + if n_subsample == 0: + raise CellSubsamplerError("ratio too low: no cells would be " + "sampled.") + + # First determine which nodes are part of the induced subgraph. 
+ leaf_keep_idx = np.random.choice(range(tree.n_cell), n_subsample, + replace=False) + leaves_in_induced_subtree = [tree.leaves[i] for i in leaf_keep_idx] + induced_subtree_degs = dict([(leaf, 0) for leaf in leaves_in_induced_subtree]) + + nodes_in_induced_subtree = set(leaves_in_induced_subtree) + for node in tree.depth_first_traverse_nodes(postorder=True): + children = tree.children(node) + induced_subtree_deg = sum([child in nodes_in_induced_subtree for child in children]) + if induced_subtree_deg > 0: + nodes_in_induced_subtree.add(node) + induced_subtree_degs[node] = induced_subtree_deg + + # For debugging: + # print(f"leaves_in_induced_subtree = {leaves_in_induced_subtree}") + # print(f"nodes_in_induced_subtree = {nodes_in_induced_subtree}") + # print(f"induced_subtree_degs = {induced_subtree_degs}") + nodes = [] + edges = [] + up = {} + for node in tree.depth_first_traverse_nodes(postorder=False): + if node == tree.root: + nodes.append(node) + up[node] = node + continue + + if node not in nodes_in_induced_subtree: + continue + + if induced_subtree_degs[tree.parent(node)] >= 2: + up[node] = tree.parent(node) + else: + up[node] = up[tree.parent(node)] + + if induced_subtree_degs[node] >= 2 or induced_subtree_degs[node] == 0: + nodes.append(node) + edges.append((up[node], node)) + subtree_topology = nx.DiGraph() + subtree_topology.add_nodes_from(nodes) + subtree_topology.add_edges_from(edges) + res = CassiopeiaTree( + tree=subtree_topology, + ) + # Copy times over + res.set_times(dict([(node, tree.get_time(node)) for node in res.nodes])) + return res diff --git a/test/tools_tests/branch_length_estimator_test.py b/test/tools_tests/branch_length_estimator_test.py index 08cde817..626d537e 100644 --- a/test/tools_tests/branch_length_estimator_test.py +++ b/test/tools_tests/branch_length_estimator_test.py @@ -1118,6 +1118,8 @@ def test_smoke(self): np.testing.assert_almost_equal(tree.get_time("2"), np.log(2), decimal=3) 
np.testing.assert_almost_equal(tree.get_time("3"), np.log(2), decimal=3) np.testing.assert_almost_equal(tree.get_time("0"), 0.0) - np.testing.assert_almost_equal(log_likelihood, -1.386 * 3, decimal=3) + np.testing.assert_almost_equal(log_likelihood, -1.386, decimal=3) log_likelihood_2 = IIDExponentialBLE.log_likelihood(tree) - np.testing.assert_almost_equal(log_likelihood, log_likelihood_2, decimal=3) + # The tree topology that the estimator sees is different from the + # one in the final phylogeny, thus the lik will be different! + np.testing.assert_almost_equal(log_likelihood * 3, log_likelihood_2, decimal=3) diff --git a/test/tools_tests/lineage_simulator_test.py b/test/tools_tests/lineage_simulator_test.py index a178bf60..c3142689 100644 --- a/test/tools_tests/lineage_simulator_test.py +++ b/test/tools_tests/lineage_simulator_test.py @@ -39,12 +39,13 @@ class TestBirthProcess(unittest.TestCase): @pytest.mark.slow def test_BirthProcess(self): r""" - Generate tree, then choose a random lineage can count how many nodes are on - the lineage. This is the number of times the process triggered on that - lineage. + Generate tree, then choose a random lineage can count how many nodes + are on the lineage. This is the number of times the process triggered + on that lineage. - Also, the probability that a tree with only one internal node is obtained is - e^-lam * (1 - e^-lam) where lam is the birth rate, so we also check this. + Also, the probability that a tree with only one internal node is + obtained is e^-lam * (1 - e^-lam) where lam is the birth rate, so we + also check this. 
""" np.random.seed(1) birth_rate = 0.6 From ec7c03bc62cab76897b181f32c1388ef3643e209 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Tue, 16 Feb 2021 22:25:52 -0800 Subject: [PATCH 59/61] black --- .../BLEMultifurcationWrapper.py | 3 ++- .../IIDExponentialBLE.py | 2 +- .../IIDExponentialPosteriorMeanBLE.py | 2 +- cassiopeia/tools/cell_subsampler.py | 23 +++++++++++++------ 4 files changed, 20 insertions(+), 10 deletions(-) diff --git a/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py b/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py index 95365ca5..f96dc129 100644 --- a/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py +++ b/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py @@ -29,7 +29,8 @@ def __init__(self, ble_model: BranchLengthEstimator): def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: binary_topology = resolve_multifurcations_networkx( - tree.get_tree_topology()) + tree.get_tree_topology() + ) # For debugging: print(f"binary_topology = {binary_topology.__dict__}") tree_binary = CassiopeiaTree( diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py index 512e16f3..41644a0a 100644 --- a/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialBLE.py @@ -316,7 +316,7 @@ def _cv_split( state = tree.get_character_states(node) train_state = ( state[:held_out_character_idx] - + state[(held_out_character_idx + 1):] + + state[(held_out_character_idx + 1) :] ) valid_state = [state[held_out_character_idx]] train_states[node] = train_state diff --git a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py index d97f974f..7eccb4ec 100644 --- 
a/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py +++ b/cassiopeia/tools/branch_length_estimator/IIDExponentialPosteriorMeanBLE.py @@ -47,7 +47,7 @@ def __init__( enforce_parsimony: bool = True, use_cpp_implementation: bool = False, debug_cpp_implementation: bool = False, - verbose: bool = False + verbose: bool = False, ) -> None: # TODO: If we use autograd, we can tune the hyperparams with gradient # descent? diff --git a/cassiopeia/tools/cell_subsampler.py b/cassiopeia/tools/cell_subsampler.py index 8fc8b1ec..93cd81d4 100644 --- a/cassiopeia/tools/cell_subsampler.py +++ b/cassiopeia/tools/cell_subsampler.py @@ -43,19 +43,25 @@ def subsample(self, tree: CassiopeiaTree) -> CassiopeiaTree: ratio = self.__ratio n_subsample = int(tree.n_cell * ratio) if n_subsample == 0: - raise CellSubsamplerError("ratio too low: no cells would be " - "sampled.") + raise CellSubsamplerError( + "ratio too low: no cells would be " "sampled." + ) # First determine which nodes are part of the induced subgraph. 
- leaf_keep_idx = np.random.choice(range(tree.n_cell), n_subsample, - replace=False) + leaf_keep_idx = np.random.choice( + range(tree.n_cell), n_subsample, replace=False + ) leaves_in_induced_subtree = [tree.leaves[i] for i in leaf_keep_idx] - induced_subtree_degs = dict([(leaf, 0) for leaf in leaves_in_induced_subtree]) + induced_subtree_degs = dict( + [(leaf, 0) for leaf in leaves_in_induced_subtree] + ) nodes_in_induced_subtree = set(leaves_in_induced_subtree) for node in tree.depth_first_traverse_nodes(postorder=True): children = tree.children(node) - induced_subtree_deg = sum([child in nodes_in_induced_subtree for child in children]) + induced_subtree_deg = sum( + [child in nodes_in_induced_subtree for child in children] + ) if induced_subtree_deg > 0: nodes_in_induced_subtree.add(node) induced_subtree_degs[node] = induced_subtree_deg @@ -81,7 +87,10 @@ def subsample(self, tree: CassiopeiaTree) -> CassiopeiaTree: else: up[node] = up[tree.parent(node)] - if induced_subtree_degs[node] >= 2 or induced_subtree_degs[node] == 0: + if ( + induced_subtree_degs[node] >= 2 + or induced_subtree_degs[node] == 0 + ): nodes.append(node) edges.append((up[node], node)) subtree_topology = nx.DiGraph() From 8baab4cb84f6d6d0a11c4ba097f08e0a2dca7978 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Wed, 17 Feb 2021 11:12:02 -0800 Subject: [PATCH 60/61] Enhance UniformCellSubsampler --- cassiopeia/tools/cell_subsampler.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/cassiopeia/tools/cell_subsampler.py b/cassiopeia/tools/cell_subsampler.py index 93cd81d4..1c6f5fd9 100644 --- a/cassiopeia/tools/cell_subsampler.py +++ b/cassiopeia/tools/cell_subsampler.py @@ -1,6 +1,7 @@ import abc import networkx as nx import numpy as np +from typing import Optional from cassiopeia.data import CassiopeiaTree @@ -33,15 +34,29 @@ def subsample(self, tree: CassiopeiaTree) -> CassiopeiaTree: class 
UniformCellSubsampler(CellSubsampler): - def __init__(self, ratio: float): + def __init__( + self, ratio: Optional[float] = None, n_cells: Optional[int] = None + ): r""" Samples 'ratio' of the leaves, rounded down, uniformly at random. """ + if ratio is None and n_cells is None: + raise CellSubsamplerError( + "At least one of 'ratio' and 'n_cells' " "must be specified." + ) + if ratio is not None and n_cells is not None: + raise CellSubsamplerError( + "Exactly one of 'ratio' and 'n_cells'" "must be specified." + ) self.__ratio = ratio + self.__n_cells = n_cells def subsample(self, tree: CassiopeiaTree) -> CassiopeiaTree: ratio = self.__ratio - n_subsample = int(tree.n_cell * ratio) + n_cells = self.__n_cells + n_subsample = ( + n_cells if n_cells is not None else int(tree.n_cell * ratio) + ) if n_subsample == 0: raise CellSubsamplerError( "ratio too low: no cells would be " "sampled." From a1372c259414734cf3fb4d921f272ead904a89e9 Mon Sep 17 00:00:00 2001 From: Sebastian Prillo <10426884+sprillo@users.noreply.github.com> Date: Tue, 23 Feb 2021 14:47:12 -0800 Subject: [PATCH 61/61] Remove print --- .../tools/branch_length_estimator/BLEMultifurcationWrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py b/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py index f96dc129..4c1b8ece 100644 --- a/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py +++ b/cassiopeia/tools/branch_length_estimator/BLEMultifurcationWrapper.py @@ -32,7 +32,7 @@ def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None: tree.get_tree_topology() ) # For debugging: - print(f"binary_topology = {binary_topology.__dict__}") + # print(f"binary_topology = {binary_topology.__dict__}") tree_binary = CassiopeiaTree( character_matrix=tree.get_current_character_matrix(), tree=binary_topology,