diff --git a/cassiopeia/preprocess/UMI_utils.py b/cassiopeia/preprocess/UMI_utils.py
index cc8268bf..05ffbd0a 100755
--- a/cassiopeia/preprocess/UMI_utils.py
+++ b/cassiopeia/preprocess/UMI_utils.py
@@ -5,7 +5,7 @@
 """
 
 import os
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Generator, List, Optional, Tuple, Union
 from typing_extensions import Literal
 
 import array
@@ -101,11 +101,50 @@ def detect_cell_bc_tag(bam_fp: str) -> str:
     return raw_tag
 
 
+def group_bam_by_key(
+    sorted_fn: str,
+    sort_key: Callable[[pysam.AlignedSegment], Union[str, Tuple[str, ...]]],
+    n_threads: int = 1,
+) -> Generator[
+    Tuple[Union[str, Tuple[str, ...]], List[pysam.AlignedSegment]], None, None
+]:
+    """Given a sorted BAM, yield groups of alignments.
+
+    Note:
+        The BAM must already be sorted by the key used to group.
+
+    Args:
+        sorted_fn: Path to the sorted BAM.
+        sort_key: Function that yields a key, given a single alignment.
+        n_threads: Number of threads to use.
+
+    Returns:
+        A generator yielding tuples of two elements, where the first element
+        is the grouping key and the second is a list of alignments with that
+        key.
+    """
+    with pysam.AlignmentFile(
+        sorted_fn, check_sq=False, threads=n_threads
+    ) as bam:
+        current_key = None
+        current_als = []
+        for read in bam:
+            read_key = sort_key(read)
+            if current_key != read_key:
+                if current_als:
+                    yield current_key, current_als
+                current_als = []
+                current_key = read_key
+            current_als.append(read)
+        if current_als:
+            yield current_key, current_als
+
+
 def sort_bam(
     bam_fp: str,
     sorted_fn: str,
-    sort_key: Callable[[pysam.AlignedSegment], str] = sort_key,
+    sort_key: Callable[
+        [pysam.AlignedSegment], Union[str, Tuple[str, ...]]
+    ] = sort_key,
     filter_func: Callable[[pysam.AlignedSegment], str] = filter_func,
+    n_threads: int = 1,
 ) -> (int, int):
     """Sorts aligned segments (representing a read) in the BAM file according
     to a specified key.
@@ -119,53 +158,61 @@
         sort_key: A function specifying the key by which to sort the aligned
             sequences.
         filter_func: A function specifying the key by which to filter out
             irrelevant sequences.
+        n_threads: Number of threads to use.
 
     Returns:
         The max read length and the total number of relevant reads sorted.
""" Path(sorted_fn).parent.mkdir(exist_ok=True) - bam_fh = pysam.AlignmentFile(str(bam_fp), check_sq=False) - - relevant = filter(filter_func, bam_fh) - - max_read_length = 0 - total_reads_out = 0 + with pysam.AlignmentFile( + bam_fp, check_sq=False, threads=n_threads + ) as bam_fh: + relevant = filter(filter_func, bam_fh) + + max_read_length = 0 + total_reads_out = 0 + + chunk_fns = [] + + for i, chunk in enumerate(utilities.chunks(relevant, 10000000)): + suffix = ".{:06d}.bam".format(i) + chunk_fn = Path(sorted_fn).with_suffix(suffix) + sorted_chunk = sorted(chunk, key=sort_key) + + with pysam.AlignmentFile( + chunk_fn, "wb", template=bam_fh, threads=n_threads + ) as fh: + for al in sorted_chunk: + max_read_length = max(max_read_length, al.query_length) + total_reads_out += 1 + fh.write(al) + chunk_fns.append(chunk_fn) + + chunk_fhs = [ + pysam.AlignmentFile(str(fn), check_header=False, check_sq=False) + for fn in chunk_fns + ] - chunk_fns = [] + with pysam.AlignmentFile( + sorted_fn, "wb", template=bam_fh, threads=n_threads + ) as fh: + merged_chunks = heapq.merge(*chunk_fhs, key=sort_key) - for i, chunk in enumerate(utilities.chunks(relevant, 10000000)): - suffix = ".{:06d}.bam".format(i) - chunk_fn = Path(sorted_fn).with_suffix(suffix) - sorted_chunk = sorted(chunk, key=sort_key) + merged_chunks = progress( + merged_chunks, + total=total_reads_out, + desc="Merging sorted chunks", + ) - with pysam.AlignmentFile(str(chunk_fn), "wb", template=bam_fh) as fh: - for al in sorted_chunk: - max_read_length = max(max_read_length, al.query_length) - total_reads_out += 1 + for al in merged_chunks: fh.write(al) - chunk_fns.append(chunk_fn) - chunk_fhs = [ - pysam.AlignmentFile(str(fn), check_header=False, check_sq=False) - for fn in chunk_fns - ] - - with pysam.AlignmentFile(str(sorted_fn), "wb", template=bam_fh) as fh: - merged_chunks = heapq.merge(*chunk_fhs, key=sort_key) - - merged_chunks = progress( - merged_chunks, total=total_reads_out, desc="Merging sorted chunks" - ) + for fh in chunk_fhs: + fh.close() - for al in merged_chunks: - fh.write(al) - - for fh in chunk_fhs: - fh.close() - - for fn in chunk_fns: - fn.unlink() + for fn in chunk_fns: + fn.unlink() return max_read_length, total_reads_out @@ -177,7 +224,7 @@ def form_collapsed_clusters( max_indels: int, cell_key: Callable[[pysam.AlignedSegment], str] = cell_key, UMI_key: Callable[[pysam.AlignedSegment], str] = UMI_key, - method: Literal["cutoff", "bayesian"] = "cutoff", + method: Literal["cutoff", "likelihood"] = "cutoff", n_threads: int = 1, ): """Aggregates together aligned segments (reads) that share UMIs if their @@ -217,14 +264,16 @@ def form_collapsed_clusters( most probable at each position. n_threads: Number of threads to use. 
""" - - sorted_als = pysam.AlignmentFile(sorted_fn, check_sq=False) - + cellBC_UMI_func = lambda al: (cell_key(al), UMI_key(al)) cellBC_UMIs = set() max_read_length = 0 - for al in sorted_als: - cellBC_UMIs.add((cell_key(al), UMI_key(al))) - max_read_length = max(max_read_length, al.query_length) + with pysam.AlignmentFile( + sorted_fn, check_sq=False, threads=n_threads + ) as sorted_als: + bam_header = str(sorted_als.header) + for al in sorted_als: + cellBC_UMIs.add(cellBC_UMI_func(al)) + max_read_length = max(max_read_length, al.query_length) # Raise warning when max_hq_mismatches / max_read_length > 0.5 if max_hq_mismatches / max_read_length > 0.5: @@ -235,18 +284,18 @@ def form_collapsed_clusters( PreprocessWarning, ) - # Read in the AlignmentFile again as iterating over it in the previous for - # loop has destructively removed all alignments from the file object - sorted_als = pysam.AlignmentFile(sorted_fn, check_sq=False) - cell_groups = utilities.group_by(sorted_als, cell_key) - # Helper function so that we can use joblib to parallelize the computation def cluster_group(cell_BC, UMI, UMI_group, header_text): header = pysam.AlignmentHeader.from_text(header_text) UMI_group = [ pysam.AlignedSegment.fromstring(s, header) for s in UMI_group ] - if method == "cutoff": + # Very unlikely, but for very deeply sequenced libraries or libraries + # that requires many cycles of PCR amplification, some UMIs may have + # upwards of 10k+ reads. Likelihood-based consensus calling isn't + # meant to deal with such high read counts, so we will fall back to + # the cutoff method. + if method == "cutoff" or len(UMI_group) > 10000: clusters = form_clusters( UMI_group, max_read_length, max_hq_mismatches ) @@ -305,10 +354,11 @@ def cluster_group(cell_BC, UMI, UMI_group, header_text): cell_BC, UMI, [aln.to_string() for aln in UMI_group], - str(sorted_als.header) + bam_header, + ) + for (cell_BC, UMI), UMI_group in group_bam_by_key( + sorted_fn, cellBC_UMI_func, n_threads=n_threads ) - for cell_BC, cell_group in cell_groups - for UMI, UMI_group in utilities.group_by(cell_group, UMI_key) ) with pysam.AlignmentFile( diff --git a/cassiopeia/preprocess/cassiopeia_preprocess.py b/cassiopeia/preprocess/cassiopeia_preprocess.py index abcc6dc6..afb9b12d 100755 --- a/cassiopeia/preprocess/cassiopeia_preprocess.py +++ b/cassiopeia/preprocess/cassiopeia_preprocess.py @@ -121,6 +121,17 @@ def main(): ) continue + # If intBC correction was performed, don't correct in the + # filter_molecule_table step + if stage == "filter_molecule_table" and pipeline_parameters[stage].get( + "whitelist" + ): + logger.warning( + "intBC whitelist was provided. " + "Turning off intBC correction in `filter_molecule_table` stage." 
+            )
+            pipeline_parameters[stage]["intbc_dist_thresh"] = -1
+
         procedure = STAGES[stage]
         data = procedure(data, **pipeline_parameters[stage])
 
diff --git a/cassiopeia/preprocess/pipeline.py b/cassiopeia/preprocess/pipeline.py
index 9a04f88f..7349c171 100755
--- a/cassiopeia/preprocess/pipeline.py
+++ b/cassiopeia/preprocess/pipeline.py
@@ -290,13 +290,14 @@ def collapse_umis(
     logger.info(f"Using BAM tag `{cell_bc_tag}` as cell barcodes")
 
     max_read_length, total_reads_out = UMI_utils.sort_bam(
-        bam_fp,
+        str(bam_fp),
         str(sorted_file_name),
         sort_key=lambda al: (
             al.get_tag(cell_bc_tag),
             al.get_tag(BAM_CONSTANTS["UMI_TAG"]),
         ),
         filter_func=lambda al: al.has_tag(cell_bc_tag),
+        n_threads=n_threads,
     )
     logger.info("Sorted bam directory saved to " + str(sorted_file_name))
     logger.info("Max read length of " + str(max_read_length))
diff --git a/test/preprocess_tests/collapse_umi_test.py b/test/preprocess_tests/collapse_umi_test.py
index d3829301..2590f0bd 100755
--- a/test/preprocess_tests/collapse_umi_test.py
+++ b/test/preprocess_tests/collapse_umi_test.py
@@ -30,7 +30,9 @@ def setUp(self):
             ".collapsed.bam"
         )
 
-        _, _ = UMI_utils.sort_bam(self.test_file, str(self.sorted_file_name))
+        _, _ = UMI_utils.sort_bam(
+            str(self.test_file), str(self.sorted_file_name)
+        )
 
         UMI_utils.form_collapsed_clusters(
             str(self.sorted_file_name),
@@ -67,7 +69,7 @@ def setUp(self):
         )
 
         _, _ = UMI_utils.sort_bam(
-            self.uncorrected_test_file,
+            str(self.uncorrected_test_file),
             str(self.uncorrected_sorted_file_name),
             sort_key=lambda al: (al.get_tag("CR"), al.get_tag("UR")),
             filter_func=lambda al: al.has_tag("CR"),
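
For reviewers, a minimal usage sketch of how the new `sort_bam` / `group_bam_by_key` pair is intended to fit together. The file paths and the `CB`/`UR` tag names below are placeholders, not part of this change; the only requirement is that the BAM is sorted by the same key that is later used to group:

    from cassiopeia.preprocess import UMI_utils

    def key(al):
        # Group reads by (cell barcode, UMI); CB/UR are placeholder tags.
        return al.get_tag("CB"), al.get_tag("UR")

    # Sort once by the grouping key, keeping only reads with a cell barcode.
    max_read_length, total_reads = UMI_utils.sort_bam(
        "input.bam",          # placeholder input path
        "input.sorted.bam",   # placeholder output path
        sort_key=key,
        filter_func=lambda al: al.has_tag("CB"),
        n_threads=4,
    )

    # Stream the key-sorted BAM; each iteration yields one (cellBC, UMI)
    # group together with all of its alignments.
    for (cell_bc, umi), reads in UMI_utils.group_bam_by_key(
        "input.sorted.bam", key, n_threads=4
    ):
        print(cell_bc, umi, len(reads))

This replaces the previous pattern of re-opening the sorted BAM and nesting `utilities.group_by` over cell groups and then UMI groups: since both `heapq.merge` in `sort_bam` and the grouping generator stream, peak memory is bounded by the in-memory chunk size and the largest single (cellBC, UMI) group rather than by the whole file.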