diff --git a/cassiopeia/preprocess/lineage_utils.py b/cassiopeia/preprocess/lineage_utils.py index 26938259..a25162c3 100755 --- a/cassiopeia/preprocess/lineage_utils.py +++ b/cassiopeia/preprocess/lineage_utils.py @@ -49,26 +49,29 @@ def assign_lineage_groups( piv_assigned = pd.DataFrame() # Loop for iteratively assigning LGs - prev_clust_size = np.inf + remaining_intBCs = pivot_in.columns.tolist() i = 0 - while prev_clust_size > min_clust_size: + while len(remaining_intBCs) > 0 and pivot_in.shape[0] > 0: # run function - piv_lg, piv_nolg = find_top_lg( + piv_lg, piv_nolg, intBC_top = find_top_lg( pivot_in, i, min_intbc_prop=min_intbc_thresh, kinship_thresh=kinship_thresh, + intbcs = remaining_intBCs, ) - # append returned objects to output variable - piv_assigned = pd.concat([piv_assigned, piv_lg], sort=True) - + # if lineage group larger than min_clust_size + if piv_lg.shape[0] > min_clust_size: + # append returned objects to output variable + piv_assigned = pd.concat([piv_assigned, piv_lg], sort=True) + i += 1 + # update pivot_in by removing assigned alignments pivot_in = piv_nolg - prev_clust_size = piv_lg.shape[0] - - i += 1 + # remove intBC_top from remaining_intBCs + remaining_intBCs.remove(intBC_top) return piv_assigned @@ -78,6 +81,7 @@ def find_top_lg( iteration: int, min_intbc_prop: float = 0.2, kinship_thresh: float = 0.2, + intbcs = List[str], ) -> Tuple[pd.DataFrame, pd.DataFrame]: """Algorithm to creates lineage groups from a pivot table of UMI counts @@ -98,6 +102,7 @@ def find_top_lg( the most frequent intBC kinship_thresh: Determines the proportion of intBCs that a cell needs to share with the cluster in order to included in that cluster + intbcs: A list of intBCs to consider for seeding the lineage group Returns: A pivot table of cells labled with lineage group assignments, and a @@ -105,7 +110,7 @@ def find_top_lg( """ # Calculate sum of observed intBCs, identify top intBC - intBC_sums = PIVOT_in.sum(0).sort_values(ascending=False) + intBC_sums = PIVOT_in.loc[:,intbcs].sum(0).sort_values(ascending=False) intBC_top = intBC_sums.index[0] # Take subset of PIVOT table that contain cells that have the top intBC @@ -146,11 +151,12 @@ def find_top_lg( PIV_LG["lineageGrp"] = iteration + 1 # Print statements - logger.debug( - f"LG {iteration+1} Assignment: {PIV_LG.shape[0]} cells assigned" - ) + if PIV_LG.shape[0] > 0: + logger.debug( + f"LG {iteration+1} Assignment: {PIV_LG.shape[0]} cells assigned" + ) - return PIV_LG, PIV_noLG + return PIV_LG, PIV_noLG, intBC_top def filter_intbcs_lg_sets(