Skip to content

Commit

Permalink
doublet detection
Browse files Browse the repository at this point in the history
  • Loading branch information
colganwi committed Sep 19, 2023
1 parent 5432ffd commit 24efa19
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 63 deletions.
50 changes: 49 additions & 1 deletion cassiopeia/preprocess/doublet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Invoked through pipeline.py and supports the filter_molecule_table and
call_lineage_groups functions.
"""
from typing import Dict, Set, Tuple
from typing import Dict, Set, Tuple, List

import pandas as pd

Expand Down Expand Up @@ -168,3 +168,51 @@ def filter_inter_doublets(at: pd.DataFrame, rule: float = 0.35) -> pd.DataFrame:
n_cells = at["cellBC"].nunique()
logger.debug(f"Filtered {n_filtered} inter-doublets of {n_cells} cells")
return at[at["cellBC"].isin(passing_cellBCs)]


def filter_doublet_lg_sets(
    PIV: pd.DataFrame,
    master_LGs: List[int],
    master_intBCs: Dict[int, List[str]],
    shared_prop_thresh: float = 0.8,
) -> Tuple[List[int], Dict[int, List[str]]]:
    """Filters out lineage groups that are likely doublets.

    A lineage group is treated as a doublet when more than
    `shared_prop_thresh` of its intBCs are covered by the union of the intBC
    sets of two other lineage groups. The two "other" groups may be the same
    group, so a group that is almost entirely contained in a single other
    group is removed as well. Groups are visited from the highest lineage
    group label to the lowest, and a removed group is no longer available as
    a partner for the groups visited after it.

    Lineage groups may be keyed either by their scalar label (``1``) or by a
    one-element tuple (``(1,)``) in `master_LGs` / `master_intBCs`; both
    forms are accepted. Labels found in `PIV` that have no entry in
    `master_intBCs` are skipped.

    Args:
        PIV: A pivot table of cellBC-intBC-allele groups with a `lineageGrp`
            column
        master_LGs: A list of lineage groups to be filtered. Modified in
            place.
        master_intBCs: A dictionary that has mappings from the lineage group
            to the set of intBCs being used for reconstruction. Modified in
            place.
        shared_prop_thresh: The proportion of a lineage group's intBCs that
            must be exceeded by the overlap with a pair of other groups for
            the group to be removed

    Returns:
        A filtered list of lineage groups and a filtered dictionary of intBCs
        for each lineage group (the same objects that were passed in)
    """
    # Highest label first. `tolist()` yields plain Python scalars so that
    # dictionary lookups and log messages do not depend on numpy types.
    candidate_lgs = sorted(
        PIV["lineageGrp"].dropna().unique().tolist(), reverse=True
    )

    for lg in candidate_lgs:
        # Accept both the tuple-keyed and the scalar-keyed representation.
        key = next((k for k in ((lg,), lg) if k in master_intBCs), None)
        if key is None:
            continue

        lg_intBC = set(master_intBCs[key])
        min_shared = len(lg_intBC) * shared_prop_thresh
        others = [other for other in master_LGs if other != key]

        doublet_pair = None
        for idx, lg_i in enumerate(others):
            intBC_i = set(master_intBCs[lg_i])
            # The union is symmetric, so each unordered pair (including the
            # pair of a group with itself) only needs to be checked once.
            for lg_j in others[idx:]:
                pair_intBC = intBC_i.union(master_intBCs[lg_j])
                if len(pair_intBC.intersection(lg_intBC)) > min_shared:
                    doublet_pair = (lg_i, lg_j)
                    break
            if doublet_pair is not None:
                break

        if doublet_pair is None:
            continue

        master_intBCs.pop(key)
        if key in master_LGs:
            master_LGs.remove(key)
        logger.debug(
            f"Filtered lineage group {key} as a doublet"
            f" of {doublet_pair[0]} and {doublet_pair[1]}"
        )

    return master_LGs, master_intBCs
116 changes: 67 additions & 49 deletions cassiopeia/preprocess/lineage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from typing import Dict, List, Tuple

from collections import Counter
from matplotlib import colors, colorbar
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -209,14 +210,14 @@ def filter_intbcs_lg_sets(
def score_lineage_kinships(
PIV: pd.DataFrame,
master_LGs: List[int],
master_intBCs: Dict[int, pd.DataFrame],
master_intBCs: Dict[int, List[str]],
) -> pd.DataFrame:
"""Identifies which lineage group each cell should belong to.
"""Calculates lineage group kinship for each cell.
Given a set of cells and a set of lineage groups with their intBCs sets,
identifies which lineage group each cell has the most kinship with. Kinship
is defined as the total UMI count of intBCs shared between the cell and the
intBC set of a lineage group.
calculates lineage group kinship for each cell. Kinship is defined as the
UMI count of intBCs shared between the cell and the intBC set of a lineage
group normalized by the total UMI count of the cell.
Args:
PIV: A pivot table of cells labled with lineage group assignments
Expand All @@ -226,8 +227,7 @@ def score_lineage_kinships(
Returns:
A DataFrame that contains the lineage group for each cell with the
greatest kinship
A DataFrame that contains the lineage group kinship for each cell.
"""

dfLG2intBC = pd.DataFrame()
Expand Down Expand Up @@ -259,67 +259,85 @@ def score_lineage_kinships(

# Matrix math
dfCellBC2LG = subPIVOT.dot(dfLG2intBC.T)
max_kinship = dfCellBC2LG.max(axis=1)

max_kinship_ind = dfCellBC2LG.apply(lambda x: np.argmax(x), axis=1)
max_kinship_frame = max_kinship.to_frame()

max_kinship_LG = pd.concat(
[max_kinship_frame, max_kinship_ind + 1], axis=1, sort=True
)
max_kinship_LG.columns = ["maxOverlap", "lineageGrp"]

return max_kinship_LG
return dfCellBC2LG


def annotate_lineage_groups(
    at: pd.DataFrame,
    kinship_scores: pd.DataFrame,
    doublet_kinship_thresh: float = 0.8,
) -> pd.DataFrame:
    """Assigns cells in the allele table to a lineage group.

    Takes in an allele table and a DataFrame of kinship scores for each cell
    which is used to assign cells to lineage groups. If a cell's highest
    kinship score is at least `doublet_kinship_thresh` (or no threshold is
    given, or only one lineage group exists), the cell is assigned to the
    lineage group with the highest kinship score. Otherwise the cell is
    marked as a doublet of the two lineage groups with the highest kinship
    scores. Lineage groups are then renumbered from 1 in order of decreasing
    number of assigned cells (a doublet counts towards both of its groups).

    The `lineageGrp` column is written as strings: ``"1"`` for a singlet,
    ``"(1, 2)"`` for a doublet (smaller group number first) and ``"0"`` for
    cells that have no kinship scores.

    Args:
        at: An allele table of cellBC-UMI-allele groups. The `lineageGrp`
            column is added to this DataFrame in place.
        kinship_scores: A DataFrame with lineage kinship scores for each
            cell (rows are cellBCs, columns are lineage groups), see
            documentation of score_lineage_kinships
        doublet_kinship_thresh: A float between 0 and 1 specifying the
            minimum kinship score a cell needs to be assigned to a single
            lineage group. A falsy value disables doublet identification.

    Returns:
        Original allele table with annotated lineage group assignments for
        cells
    """
    if doublet_kinship_thresh:
        logger.info(
            "Identifying inter-lineage group doublets with"
            f" kinship threshold {doublet_kinship_thresh}..."
        )

    # Assign cells to lineage groups using kinship scores
    cellBC2LG = {}
    n_doublets = 0
    for cellBC, scores in kinship_scores.iterrows():
        sorted_scores = scores.sort_values(ascending=False)
        if sorted_scores.empty:
            # No lineage groups to choose from; the cell stays unassigned.
            continue
        # Use positional access: the index holds lineage group labels, so
        # `sorted_scores[0]` would be a label lookup, not "the top score".
        is_doublet = (
            bool(doublet_kinship_thresh)
            and len(sorted_scores) > 1
            and sorted_scores.iloc[0] < doublet_kinship_thresh
        )
        if is_doublet:
            cellBC2LG[cellBC] = [sorted_scores.index[0], sorted_scores.index[1]]
            n_doublets += 1
        else:
            cellBC2LG[cellBC] = [sorted_scores.index[0]]

    if doublet_kinship_thresh:
        n_cells = len(cellBC2LG)
        logger.debug(
            f"Identified {n_doublets} inter-group doublets"
            f" out of {n_cells} cells"
        )

    # Rename lineage groups based on size
    lg_counts = Counter(lg for lgs in cellBC2LG.values() for lg in lgs)
    rename_lg = {
        lg: rank
        for rank, (lg, _) in enumerate(lg_counts.most_common(), start=1)
    }

    # Build the final string label of every cell. Formatting the labels here
    # (rather than casting the mapped column) keeps singlets as "1" instead
    # of "1.0" when unassigned cells would force a float column.
    cell_labels = {}
    for cellBC, lgs in cellBC2LG.items():
        renamed = sorted(rename_lg[lg] for lg in lgs)
        if len(renamed) == 1:
            cell_labels[cellBC] = str(renamed[0])
        else:
            cell_labels[cellBC] = str(tuple(renamed))

    # Add lineageGrp column to allele table; unassigned cells get "0"
    at["lineageGrp"] = at["cellBC"].map(
        lambda cellBC: cell_labels.get(cellBC, "0")
    )

    return at


def filter_intbcs_final_lineages(
Expand Down
71 changes: 58 additions & 13 deletions cassiopeia/preprocess/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,7 +1015,9 @@ def call_lineage_groups(
min_umi_per_intbc: int = 1,
min_cluster_prop: float = 0.005,
min_intbc_thresh: float = 0.05,
inter_doublet_threshold: float = 0.35,
inter_doublet_threshold: float = None,
doublet_kinship_thresh: float = 0.75,
keep_doublets: bool = False,
kinship_thresh: float = 0.25,
plot: bool = False,
) -> pd.DataFrame:
Expand Down Expand Up @@ -1059,6 +1061,15 @@ def call_lineage_groups(
inter_doublet_threshold: The threshold specifying the minimum proportion
of kinship a cell shares with its assigned lineage group out of all
lineage groups for it to be retained during doublet filtering
doublet_kinship_thresh: The threshold specifying the minimum kinship a
cell needs to have with a lineage group in order to be considered a
singlet. Cells with kinship scores below this threshold will be
filtered out or marked as doublets depending on the value of
`keep_doublets`.
keep_doublets: Whether or not to keep doublets in the allele table. If
True, doublets will appear in the allele table with a lineage group
of (lg1, lg2). If False, doublets will be removed from the
allele table.
kinship_thresh: The threshold specifying the minimum proportion of
intBCs shared between a cell and the intBC set of a lineage group
needed to assign that cell to that lineage group in putative
Expand All @@ -1068,16 +1079,29 @@ def call_lineage_groups(
Returns:
None, saves output allele table to file.
"""

if inter_doublet_threshold:
logger.warning(
"Doublet filtering with the inter_doublet_threshold parameter is"
" depreciated and will be removed in Cassiopeia 2.1.0. Please use"
" the doublet_kinship_thresh parameter instead."
)

logger.info(
f"{input_df.shape[0]} UMIs (rows), with {input_df.shape[1]} attributes (columns)"
)
logger.info(str(len(input_df["cellBC"].unique())) + " Cells")

if min_umi_per_intbc > 1:
logger.info(f"Filtering out intBCs with less than "
f"{min_umi_per_intbc} UMIs...")
input_df = input_df.groupby(['cellBC',"intBC"]).filter(
lambda x: len(x) >= min_umi_per_intbc)

# Create a pivot_table
piv = pd.pivot_table(
input_df, index="cellBC", columns="intBC", values="UMI", aggfunc="count"
)
# Filter out intBCs with fewer than min_umi_per_intbc UMIs
piv[piv < min_umi_per_intbc] = np.nan

# Normalize by total UMIs per cell
Expand Down Expand Up @@ -1108,33 +1132,54 @@ def call_lineage_groups(
piv_assigned, min_intbc_thresh=min_intbc_thresh
)

logger.info(
"Redefining lineage groups by removing doublet groups..."
)
master_LGs, master_intBCs = doublet_utils.filter_doublet_lg_sets(
piv_assigned, master_LGs, master_intBCs
)

logger.info("Reassigning cells to refined lineage groups by kinship...")
kinship_scores = lineage_utils.score_lineage_kinships(
piv_assigned, master_LGs, master_intBCs
)

logger.info("Annotating alignment table with refined lineage groups...")
allele_table = lineage_utils.annotate_lineage_groups(
input_df, kinship_scores, master_intBCs
input_df, kinship_scores, doublet_kinship_thresh=doublet_kinship_thresh,
)
if inter_doublet_threshold:

if doublet_kinship_thresh:
if not keep_doublets:
logger.info("Filtering out inter-lineage group doublets with"
f" kinship threshold {doublet_kinship_thresh}...")
allele_table = allele_table[
~allele_table["lineageGrp"].str.startswith("(")]
if inter_doublet_threshold:
logger.warning(
"Ignoring inter_doublet_threshold parameter since"
" doublet_kinship_thresh is set."
)
elif inter_doublet_threshold:
logger.info(
f"Filtering out inter-lineage group doublets with proportion {inter_doublet_threshold}..."
f"Filtering out inter-lineage group doublets with"
f" doublet threshold {inter_doublet_threshold}..."
)
allele_table = doublet_utils.filter_inter_doublets(
allele_table, rule=inter_doublet_threshold
allele_table, rule=inter_doublet_threshold,
keep_doublets=keep_doublets
)

logger.info(
"Filtering out low proportion intBCs in finalized lineage groups..."
)
filtered_lgs = lineage_utils.filter_intbcs_final_lineages(
allele_table, min_intbc_thresh=min_intbc_thresh
)
#filtered_lgs = lineage_utils.filter_intbcs_final_lineages(
# allele_table, min_intbc_thresh=min_intbc_thresh
#)

allele_table = lineage_utils.filtered_lineage_group_to_allele_table(
filtered_lgs
)
#allele_table = lineage_utils.filtered_lineage_group_to_allele_table(
# filtered_lgs
#)

logger.debug("Final lineage group assignments:")
for n, g in allele_table.groupby(["lineageGrp"]):
Expand All @@ -1146,7 +1191,7 @@ def call_lineage_groups(
min_umi_per_cell=int(min_umi_per_cell),
min_avg_reads_per_umi=min_avg_reads_per_umi,
)
allele_table["lineageGrp"] = allele_table["lineageGrp"].astype(int)
allele_table["lineageGrp"] = allele_table["lineageGrp"].astype(str)

if plot:
logger.info("Producing Plots...")
Expand Down

0 comments on commit 24efa19

Please sign in to comment.