Bam group by speedup #164

Merged
merged 8 commits on Sep 14, 2023
Changes from all commits
156 changes: 103 additions & 53 deletions cassiopeia/preprocess/UMI_utils.py
@@ -5,7 +5,7 @@
"""
import os

-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Generator, List, Optional, Tuple, Union
from typing_extensions import Literal

import array
@@ -101,11 +101,50 @@ def detect_cell_bc_tag(bam_fp: str) -> str:
    return raw_tag


+def group_bam_by_key(
+    sorted_fn: str,
+    sort_key: Callable[[pysam.AlignedSegment], Union[str, Tuple[str, ...]]],
+    n_threads: int = 1,
+) -> Generator[
+    Tuple[Union[str, Tuple[str, ...]], List[pysam.AlignedSegment]], None, None
+]:
+    """Given a sorted BAM, yield groups of alignments.
+
+    Note:
+        The BAM must already be sorted by the key used to group.
+
+    Args:
+        sorted_fn: Path to the sorted BAM.
+        sort_key: Function that yields a key, given a single alignment.
+        n_threads: Number of threads to use.
+
+    Returns:
+        A generator yielding tuples of two elements, where the first element
+        is the grouping key and the second is a list of alignments with that
+        key.
+    """
+    with pysam.AlignmentFile(
+        sorted_fn, check_sq=False, threads=n_threads
+    ) as bam:
+        current_key = None
+        current_als = []
+        for read in bam:
+            read_key = sort_key(read)
+            if current_key != read_key:
+                if current_als:
+                    yield current_key, current_als
+                current_als = []
+                current_key = read_key
+            current_als.append(read)
+        if current_als:
+            yield current_key, current_als

def sort_bam(
    bam_fp: str,
    sorted_fn: str,
-    sort_key: Callable[[pysam.AlignedSegment], str] = sort_key,
+    sort_key: Callable[
+        [pysam.AlignedSegment], Union[str, Tuple[str, ...]]
+    ] = sort_key,
    filter_func: Callable[[pysam.AlignedSegment], str] = filter_func,
+    n_threads: int = 1,
) -> (int, int):
"""Sorts aligned segments (representing a read) in the BAM file according
to a specified key.
Expand All @@ -119,53 +158,61 @@ def sort_bam(
        sort_key: A function specifying the key by which to sort the aligned sequences.
        filter_func: A function specifying the key by which to filter out
            irrelevant sequences.
+        n_threads: Number of threads to use.

    Returns:
        The max read length and the total number of relevant reads sorted.
    """
    Path(sorted_fn).parent.mkdir(exist_ok=True)

-    bam_fh = pysam.AlignmentFile(str(bam_fp), check_sq=False)
-
-    relevant = filter(filter_func, bam_fh)
-
-    max_read_length = 0
-    total_reads_out = 0
-
-    chunk_fns = []
-
-    for i, chunk in enumerate(utilities.chunks(relevant, 10000000)):
-        suffix = ".{:06d}.bam".format(i)
-        chunk_fn = Path(sorted_fn).with_suffix(suffix)
-        sorted_chunk = sorted(chunk, key=sort_key)
-
-        with pysam.AlignmentFile(str(chunk_fn), "wb", template=bam_fh) as fh:
-            for al in sorted_chunk:
-                max_read_length = max(max_read_length, al.query_length)
-                total_reads_out += 1
-                fh.write(al)
-        chunk_fns.append(chunk_fn)
-
-    chunk_fhs = [
-        pysam.AlignmentFile(str(fn), check_header=False, check_sq=False)
-        for fn in chunk_fns
-    ]
-
-    with pysam.AlignmentFile(str(sorted_fn), "wb", template=bam_fh) as fh:
-        merged_chunks = heapq.merge(*chunk_fhs, key=sort_key)
-
-        merged_chunks = progress(
-            merged_chunks, total=total_reads_out, desc="Merging sorted chunks"
-        )
-
-        for al in merged_chunks:
-            fh.write(al)
-
-    for fh in chunk_fhs:
-        fh.close()
-
-    for fn in chunk_fns:
-        fn.unlink()
+    with pysam.AlignmentFile(
+        bam_fp, check_sq=False, threads=n_threads
+    ) as bam_fh:
+        relevant = filter(filter_func, bam_fh)
+
+        max_read_length = 0
+        total_reads_out = 0
+
+        chunk_fns = []
+
+        for i, chunk in enumerate(utilities.chunks(relevant, 10000000)):
+            suffix = ".{:06d}.bam".format(i)
+            chunk_fn = Path(sorted_fn).with_suffix(suffix)
+            sorted_chunk = sorted(chunk, key=sort_key)
+
+            with pysam.AlignmentFile(
+                chunk_fn, "wb", template=bam_fh, threads=n_threads
+            ) as fh:
+                for al in sorted_chunk:
+                    max_read_length = max(max_read_length, al.query_length)
+                    total_reads_out += 1
+                    fh.write(al)
+            chunk_fns.append(chunk_fn)
+
+        chunk_fhs = [
+            pysam.AlignmentFile(str(fn), check_header=False, check_sq=False)
+            for fn in chunk_fns
+        ]
+
+        with pysam.AlignmentFile(
+            sorted_fn, "wb", template=bam_fh, threads=n_threads
+        ) as fh:
+            merged_chunks = heapq.merge(*chunk_fhs, key=sort_key)
+
+            merged_chunks = progress(
+                merged_chunks,
+                total=total_reads_out,
+                desc="Merging sorted chunks",
+            )
+
+            for al in merged_chunks:
+                fh.write(al)
+
+        for fh in chunk_fhs:
+            fh.close()
+
+        for fn in chunk_fns:
+            fn.unlink()

    return max_read_length, total_reads_out
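
The chunked sort above is an external merge sort: reads are filtered, split into 10,000,000-read chunks, each chunk is sorted in memory and written to its own BAM, and heapq.merge then lazily merges the already-sorted chunks into the final file. A toy sketch of the same pattern on integers (values and chunk size are illustrative only):

import heapq

data = [5, 3, 8, 1, 9, 2, 7, 4, 6, 0]
chunk_size = 3  # stands in for the 10,000,000-read chunks above

# Sort each chunk independently; in the real code each chunk becomes a BAM.
chunks = [
    sorted(data[i : i + chunk_size]) for i in range(0, len(data), chunk_size)
]

# heapq.merge lazily merges already-sorted iterables into one sorted stream,
# which is why sort_bam can pass key=sort_key just as sorted() does.
merged = list(heapq.merge(*chunks))
assert merged == sorted(data)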

@@ -177,7 +224,7 @@ def form_collapsed_clusters(
    max_indels: int,
    cell_key: Callable[[pysam.AlignedSegment], str] = cell_key,
    UMI_key: Callable[[pysam.AlignedSegment], str] = UMI_key,
-    method: Literal["cutoff", "bayesian"] = "cutoff",
+    method: Literal["cutoff", "likelihood"] = "cutoff",
    n_threads: int = 1,
):
"""Aggregates together aligned segments (reads) that share UMIs if their
Expand Down Expand Up @@ -217,14 +264,16 @@ def form_collapsed_clusters(
            most probable at each position.
+        n_threads: Number of threads to use.
    """
-
-    sorted_als = pysam.AlignmentFile(sorted_fn, check_sq=False)
-
+    cellBC_UMI_func = lambda al: (cell_key(al), UMI_key(al))
    cellBC_UMIs = set()
    max_read_length = 0
-    for al in sorted_als:
-        cellBC_UMIs.add((cell_key(al), UMI_key(al)))
-        max_read_length = max(max_read_length, al.query_length)
+    with pysam.AlignmentFile(
+        sorted_fn, check_sq=False, threads=n_threads
+    ) as sorted_als:
+        bam_header = str(sorted_als.header)
+        for al in sorted_als:
+            cellBC_UMIs.add(cellBC_UMI_func(al))
+            max_read_length = max(max_read_length, al.query_length)

    # Raise warning when max_hq_mismatches / max_read_length > 0.5
    if max_hq_mismatches / max_read_length > 0.5:
@@ -235,18 +284,18 @@
            PreprocessWarning,
        )

-    # Read in the AlignmentFile again as iterating over it in the previous for
-    # loop has destructively removed all alignments from the file object
-    sorted_als = pysam.AlignmentFile(sorted_fn, check_sq=False)
-    cell_groups = utilities.group_by(sorted_als, cell_key)

    # Helper function so that we can use joblib to parallelize the computation
    def cluster_group(cell_BC, UMI, UMI_group, header_text):
        header = pysam.AlignmentHeader.from_text(header_text)
        UMI_group = [
            pysam.AlignedSegment.fromstring(s, header) for s in UMI_group
        ]
-        if method == "cutoff":
+        # Very unlikely, but for very deeply sequenced libraries or libraries
+        # that require many cycles of PCR amplification, some UMIs may have
+        # upwards of 10k+ reads. Likelihood-based consensus calling isn't
+        # meant to deal with such high read counts, so we fall back to the
+        # cutoff method.
+        if method == "cutoff" or len(UMI_group) > 10000:
            clusters = form_clusters(
                UMI_group, max_read_length, max_hq_mismatches
            )
@@ -305,10 +354,11 @@ def cluster_group(cell_BC, UMI, UMI_group, header_text):
            cell_BC,
            UMI,
            [aln.to_string() for aln in UMI_group],
-            str(sorted_als.header)
+            bam_header,
        )
-        for cell_BC, cell_group in cell_groups
-        for UMI, UMI_group in utilities.group_by(cell_group, UMI_key)
+        for (cell_BC, UMI), UMI_group in group_bam_by_key(
+            sorted_fn, cellBC_UMI_func, n_threads=n_threads
+        )
    )

    with pysam.AlignmentFile(
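
A note on the serialization in cluster_group above: pysam.AlignedSegment objects cannot be pickled, so to hand UMI groups to joblib workers the PR converts each alignment to a SAM-format string with to_string() and rebuilds it in the worker from the header text. A minimal sketch of that round trip (the file name is a placeholder):

import pysam

with pysam.AlignmentFile("sorted.bam", check_sq=False) as bam:
    header_text = str(bam.header)  # a plain string, safe to pickle
    records = [al.to_string() for al in bam]  # SAM lines, safe to pickle

# Inside a worker: rebuild the header first, then the alignments.
header = pysam.AlignmentHeader.from_text(header_text)
alignments = [pysam.AlignedSegment.fromstring(s, header) for s in records]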
11 changes: 11 additions & 0 deletions cassiopeia/preprocess/cassiopeia_preprocess.py
@@ -121,6 +121,17 @@ def main():
            )
            continue

+        # If intBC correction was performed, don't correct in the
+        # filter_molecule_table step
+        if stage == "filter_molecule_table" and pipeline_parameters[stage].get(
+            "whitelist"
+        ):
+            logger.warning(
+                "intBC whitelist was provided. "
+                "Turning off intBC correction in `filter_molecule_table` stage."
+            )
+            pipeline_parameters[stage]["intbc_dist_thresh"] = -1
+
        procedure = STAGES[stage]
        data = procedure(data, **pipeline_parameters[stage])

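A toy run of the guard above in isolation (the whitelist path is a placeholder; -1 is the sentinel the diff uses to disable intBC correction):

pipeline_parameters = {
    "filter_molecule_table": {"whitelist": "intbc_whitelist.txt"}
}

stage = "filter_molecule_table"
if stage == "filter_molecule_table" and pipeline_parameters[stage].get(
    "whitelist"
):
    # Mirrors the diff: a negative threshold turns off intBC correction.
    pipeline_parameters[stage]["intbc_dist_thresh"] = -1

print(pipeline_parameters[stage])
# {'whitelist': 'intbc_whitelist.txt', 'intbc_dist_thresh': -1}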
3 changes: 2 additions & 1 deletion cassiopeia/preprocess/pipeline.py
@@ -290,13 +290,14 @@ def collapse_umis(
logger.info(f"Using BAM tag `{cell_bc_tag}` as cell barcodes")

max_read_length, total_reads_out = UMI_utils.sort_bam(
bam_fp,
str(bam_fp),
str(sorted_file_name),
sort_key=lambda al: (
al.get_tag(cell_bc_tag),
al.get_tag(BAM_CONSTANTS["UMI_TAG"]),
),
filter_func=lambda al: al.has_tag(cell_bc_tag),
n_threads=n_threads,
)
logger.info("Sorted bam directory saved to " + str(sorted_file_name))
logger.info("Max read length of " + str(max_read_length))
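The composite sort_key works because Python compares tuples lexicographically, which is also the ordering heapq.merge relies on when merging chunks: reads sort first by cell barcode and then, within a barcode, by UMI. A quick check:

# Toy (cell barcode, UMI) pairs; real keys come from al.get_tag(...).
reads = [("ACGT", "TT"), ("AAAA", "GG"), ("ACGT", "AA")]

print(sorted(reads))
# [('AAAA', 'GG'), ('ACGT', 'AA'), ('ACGT', 'TT')]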
6 changes: 4 additions & 2 deletions test/preprocess_tests/collapse_umi_test.py
@@ -30,7 +30,9 @@ def setUp(self):
".collapsed.bam"
)

_, _ = UMI_utils.sort_bam(self.test_file, str(self.sorted_file_name))
_, _ = UMI_utils.sort_bam(
str(self.test_file), str(self.sorted_file_name)
)

        UMI_utils.form_collapsed_clusters(
            str(self.sorted_file_name),
@@ -67,7 +69,7 @@ def setUp(self):
        )

        _, _ = UMI_utils.sort_bam(
-            self.uncorrected_test_file,
+            str(self.uncorrected_test_file),
            str(self.uncorrected_sorted_file_name),
            sort_key=lambda al: (al.get_tag("CR"), al.get_tag("UR")),
            filter_func=lambda al: al.has_tag("CR"),