From cf2cbc219c79371e8d7e1ae6896c5367ec02b9c0 Mon Sep 17 00:00:00 2001 From: colganwi Date: Wed, 18 Oct 2023 14:15:08 -0400 Subject: [PATCH] faster leaf removal --- cassiopeia/data/CassiopeiaTree.py | 41 ++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/cassiopeia/data/CassiopeiaTree.py b/cassiopeia/data/CassiopeiaTree.py index 8df6e866..482a4bbe 100755 --- a/cassiopeia/data/CassiopeiaTree.py +++ b/cassiopeia/data/CassiopeiaTree.py @@ -19,6 +19,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union import ete3 +import heapq import networkx as nx import numpy as np import pandas as pd @@ -1595,7 +1596,7 @@ def remove_leaves_and_prune_lineages( Removes the specified leaves and all ancestors of those leaves that are no longer the ancestor of any of the remaining leaves. In the context of a phylogeny, this prunes the lineage of all nodes no longer relevant - to observed samples. Additionally, maintains consistency with the + to observed samples. Additionally, maintains consistency with the updated tree by removing the node from all leaf data. Args: @@ -1614,19 +1615,35 @@ def remove_leaves_and_prune_lineages( for n in nodes: if not self.is_leaf(n): raise CassiopeiaTreeError("A specified node is not a leaf.") + + # Keep track of nodes to check and their depths + nodes_to_check = set() + nodes_depth_queue = [] + # Remove leaves from the tree for n in nodes: - if len(self.nodes) == 1: - self.__remove_node(n) - else: - curr_parent = self.parent(n) - self.__remove_node(n) - while len(self.children(curr_parent)) < 1 and not self.is_root( - curr_parent - ): - next_parent = self.parent(curr_parent) - self.__remove_node(curr_parent) - curr_parent = next_parent + parent = next(self.__network.predecessors(n)) + if parent not in nodes_to_check: + parent_time = self.__network.nodes[parent]["time"] + heapq.heappush(nodes_depth_queue, (parent_time, parent)) + nodes_to_check.add(parent) + self.__network.remove_node(n) + + # Check nodes with a children removed from bottom to top + while len(nodes_to_check) > 0: + _, n = heapq.heappop(nodes_depth_queue) + nodes_to_check.remove(n) + if n == self.root: + continue + children = list(self.__network.successors(n)) + # Remove nodes with no children + if len(children) == 0: + parent = next(self.__network.predecessors(n)) + parent_time = self.__network.nodes[parent]["time"] + self.__network.remove_node(n) + if parent not in nodes_to_check: + heapq.heappush(nodes_depth_queue, (parent_time, parent)) + nodes_to_check.add(parent) # Remove all removed nodes from data fields # This function will also clear the cache