Skip to content

Commit

Permalink
Merge pull request #228 from YosefLab/leaf-subsample-speedup
Browse files Browse the repository at this point in the history
Leaf Subsampler Speedup
  • Loading branch information
mattjones315 authored Oct 19, 2023
2 parents d53e6d9 + cf2cbc2 commit 2ca07df
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions cassiopeia/data/CassiopeiaTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import ete3
import heapq
import networkx as nx
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1595,7 +1596,7 @@ def remove_leaves_and_prune_lineages(
Removes the specified leaves and all ancestors of those leaves that are
no longer the ancestor of any of the remaining leaves. In the context
of a phylogeny, this prunes the lineage of all nodes no longer relevant
to observed samples. Additionally, maintains consistency with the
to observed samples. Additionally, maintains consistency with the
updated tree by removing the node from all leaf data.
Args:
Expand All @@ -1614,19 +1615,35 @@ def remove_leaves_and_prune_lineages(
for n in nodes:
if not self.is_leaf(n):
raise CassiopeiaTreeError("A specified node is not a leaf.")

# Keep track of nodes to check and their depths
nodes_to_check = set()
nodes_depth_queue = []

# Remove leaves from the tree
for n in nodes:
if len(self.nodes) == 1:
self.__remove_node(n)
else:
curr_parent = self.parent(n)
self.__remove_node(n)
while len(self.children(curr_parent)) < 1 and not self.is_root(
curr_parent
):
next_parent = self.parent(curr_parent)
self.__remove_node(curr_parent)
curr_parent = next_parent
parent = next(self.__network.predecessors(n))
if parent not in nodes_to_check:
parent_time = self.__network.nodes[parent]["time"]
heapq.heappush(nodes_depth_queue, (parent_time, parent))
nodes_to_check.add(parent)
self.__network.remove_node(n)

# Check nodes that had a child removed, from bottom to top
while len(nodes_to_check) > 0:
_, n = heapq.heappop(nodes_depth_queue)
nodes_to_check.remove(n)
if n == self.root:
continue
children = list(self.__network.successors(n))
# Remove nodes with no children
if len(children) == 0:
parent = next(self.__network.predecessors(n))
parent_time = self.__network.nodes[parent]["time"]
self.__network.remove_node(n)
if parent not in nodes_to_check:
heapq.heappush(nodes_depth_queue, (parent_time, parent))
nodes_to_check.add(parent)

# Remove all removed nodes from data fields
# This function will also clear the cache
Expand Down

0 comments on commit 2ca07df

Please sign in to comment.