diff --git a/big_scape/cli/__init__.py b/big_scape/cli/__init__.py index d132bb36..aaf62a45 100644 --- a/big_scape/cli/__init__.py +++ b/big_scape/cli/__init__.py @@ -6,10 +6,12 @@ validate_binning_cluster_workflow, validate_binning_query_workflow, validate_alignment_mode, - validate_includelist, + validate_includelist_all, + validate_includelist_any, validate_gcf_cutoffs, validate_filter_gbk, validate_pfam_path, + validate_domain_include_list, validate_classify, validate_output_dir, validate_query_record, @@ -21,10 +23,12 @@ "validate_binning_cluster_workflow", "validate_binning_query_workflow", "validate_alignment_mode", - "validate_includelist", + "validate_includelist_all", + "validate_includelist_any", "validate_gcf_cutoffs", "validate_filter_gbk", "validate_pfam_path", + "validate_domain_include_list", "validate_classify", "validate_output_dir", "validate_query_record", diff --git a/big_scape/cli/cli_common_options.py b/big_scape/cli/cli_common_options.py index 9c1bf0a4..60fc5879 100644 --- a/big_scape/cli/cli_common_options.py +++ b/big_scape/cli/cli_common_options.py @@ -12,7 +12,8 @@ validate_not_empty_dir, validate_input_mode, validate_alignment_mode, - validate_includelist, + validate_includelist_all, + validate_includelist_any, validate_gcf_cutoffs, validate_filter_gbk, validate_record_type, @@ -218,13 +219,28 @@ def common_cluster_query(fn): ), click.option( # TODO: implement - "--domain_includelist_path", + "--domain_includelist_all_path", type=click.Path( exists=True, dir_okay=False, file_okay=True, path_type=Path ), - callback=validate_includelist, + callback=validate_includelist_all, help=( - "Path to txt file with Pfam accessions. Only BGCs containing " + "Path to txt file with Pfam accessions. Only BGCs containing all " + "the listed accessions will be analysed. In this file, each " + "line contains a single Pfam accession (with an optional comment," + " separated by a tab). Lines starting with '#' are ignored. 
Pfam " + "accessions are case-sensitive." + ), + ), + click.option( + # TODO: implement + "--domain_includelist_any_path", + type=click.Path( + exists=True, dir_okay=False, file_okay=True, path_type=Path + ), + callback=validate_includelist_any, + help=( + "Path to txt file with Pfam accessions. Only BGCs containing any of " "the listed accessions will be analysed. In this file, each " "line contains a single Pfam accession (with an optional comment," " separated by a tab). Lines starting with '#' are ignored. Pfam " diff --git a/big_scape/cli/cli_validations.py b/big_scape/cli/cli_validations.py index 92fe42de..a9aa04d4 100644 --- a/big_scape/cli/cli_validations.py +++ b/big_scape/cli/cli_validations.py @@ -193,7 +193,7 @@ def validate_filter_gbk(ctx, param, filter_str) -> list[str]: # hmmer parameters -def validate_includelist(ctx, param, domain_includelist_path) -> None: +def validate_includelist(ctx, param, domain_includelist_path): """Validate the path to the domain include list and return a list of domain accession strings contained within this file @@ -207,17 +207,24 @@ def validate_includelist(ctx, param, domain_includelist_path) -> None: if not domain_includelist_path.exists(): logging.error("domain_includelist file does not exist!") - raise InvalidArgumentError("--domain_includelist", domain_includelist_path) + raise InvalidArgumentError( + "--domain_includelist_all/any_path", domain_includelist_path + ) with domain_includelist_path.open(encoding="utf-8") as domain_includelist_file: lines = domain_includelist_file.readlines() - lines = [line.strip() for line in lines] + pfams = [] + for line in lines: + line = line.strip() + if not line or line.startswith("#"):  # blank and comment lines are ignored + continue + pfams.append(line.split("\t")[0].strip()) # expect Pfam accessions, i.e. 
PF00001 or PF00001.10 lines_valid = map( lambda string: string.startswith("PF") and len(string) in range(7, 11), - lines, + pfams, ) if not all(lines_valid): @@ -228,7 +235,37 @@ def validate_includelist(ctx, param, domain_includelist_path) -> None: "Invalid Pfam accession(s) found in file %s", domain_includelist_path ) - ctx.params["domain_includelist"] = lines + return pfams + + +def validate_includelist_all(ctx, param, domain_includelist_all_path) -> None: + """Validate the 'all' domain include list file and store the Pfam accessions + it contains in ctx.params["domain_includelist_all"] + + Returns: + None: the accession list is stored on the click context, not returned + """ + + pfams = validate_includelist(ctx, param, domain_includelist_all_path) + + ctx.params["domain_includelist_all"] = pfams + + return None + + +def validate_includelist_any(ctx, param, domain_includelist_any_path) -> None: + """Validate the 'any' domain include list file and store the Pfam accessions + it contains in ctx.params["domain_includelist_any"] + + Returns: + None: the accession list is stored on the click context, not returned + """ + + pfams = validate_includelist(ctx, param, domain_includelist_any_path) + + ctx.params["domain_includelist_any"] = pfams + + return None # workflow validations @@ -334,6 +371,23 @@ def validate_pfam_path(ctx) -> None: ) +def validate_domain_include_list(ctx) -> None: + """Raise an error if both domain includelists are given at the same time""" + + if ( + ctx.obj["domain_includelist_all_path"] + and ctx.obj["domain_includelist_any_path"] + ): + logging.error( + "You have selected both all and any domain_includelist options. " + "Please select only one of the two at a time." + ) + raise click.UsageError( + "You have selected both all and any domain_includelist options. " + "Please select only one of the two at a time." 
+ ) + + def validate_record_type(ctx, _, record_type) -> Optional[bs_enums.genbank.RECORD_TYPE]: """Validates whether a region_type is provided when running classify""" valid_types = {mode.value: mode for mode in bs_enums.genbank.RECORD_TYPE} diff --git a/big_scape/cli/cluster_cli.py b/big_scape/cli/cluster_cli.py index dd7f64e4..ec45a9ef 100644 --- a/big_scape/cli/cluster_cli.py +++ b/big_scape/cli/cluster_cli.py @@ -13,6 +13,7 @@ validate_output_paths, validate_binning_cluster_workflow, validate_pfam_path, + validate_domain_include_list, set_start, ) @@ -58,6 +59,7 @@ def cluster(ctx, *args, **kwargs): # workflow validations validate_binning_cluster_workflow(ctx) validate_pfam_path(ctx) + validate_domain_include_list(ctx) validate_output_paths(ctx) # set start time and run label diff --git a/big_scape/comparison/comparable_region.py b/big_scape/comparison/comparable_region.py index 627fb5b3..76593648 100644 --- a/big_scape/comparison/comparable_region.py +++ b/big_scape/comparison/comparable_region.py @@ -92,8 +92,8 @@ def inflate_a(self, record_a: BGCRecord) -> None: comparable region has already been inflated, so this method should only be called once. """ - a_cds_list = record_a.get_cds() - a_cds_with_domains = record_a.get_cds_with_domains() + a_cds_list = list(record_a.get_cds()) + a_cds_with_domains = list(record_a.get_cds_with_domains()) lcs_a_start_orf_num = a_cds_with_domains[self.lcs_a_start].orf_num lcs_a_stop_orf_num = a_cds_with_domains[self.lcs_a_stop - 1].orf_num a_start_orf_num = a_cds_with_domains[self.a_start].orf_num @@ -121,8 +121,8 @@ def inflate_b(self, record_b: BGCRecord) -> None: comparable region has already been inflated, so this method should only be called once. 
""" - b_cds_list = record_b.get_cds() - b_cds_with_domains = record_b.get_cds_with_domains() + b_cds_list = list(record_b.get_cds()) + b_cds_with_domains = list(record_b.get_cds_with_domains()) if self.reverse: b_cds_list = b_cds_list[::-1] @@ -246,7 +246,9 @@ def cds_range_contains_biosynthetic( if end_inclusive: stop += 1 - for cds in record.get_cds_with_domains(reverse=reverse)[cds_start:stop]: + cds_list = list(record.get_cds_with_domains(reverse=reverse)) + + for cds in cds_list[cds_start:stop]: if cds.gene_kind is None: continue diff --git a/big_scape/comparison/extend.py b/big_scape/comparison/extend.py index e66f36c0..2c96dedf 100644 --- a/big_scape/comparison/extend.py +++ b/big_scape/comparison/extend.py @@ -21,12 +21,12 @@ def reset(pair: RecordPair) -> None: pair.comparable_region.a_start = 0 pair.comparable_region.b_start = 0 - pair.comparable_region.a_stop = len(pair.record_a.get_cds_with_domains()) - pair.comparable_region.b_stop = len(pair.record_b.get_cds_with_domains()) + pair.comparable_region.a_stop = len(list(pair.record_a.get_cds_with_domains())) + pair.comparable_region.b_stop = len(list(pair.record_b.get_cds_with_domains())) pair.comparable_region.domain_a_start = 0 pair.comparable_region.domain_b_start = 0 - pair.comparable_region.domain_a_stop = len(pair.record_a.get_hsps()) - pair.comparable_region.domain_b_stop = len(pair.record_b.get_hsps()) + pair.comparable_region.domain_a_stop = len(list(pair.record_a.get_hsps())) + pair.comparable_region.domain_b_stop = len(list(pair.record_b.get_hsps())) pair.comparable_region.reverse = False @@ -108,8 +108,8 @@ def extend( # get the cds lists # TODO: base extend on all domains in case of protoclusters, allow extend beyond # protocluster border - a_domains = pair.record_a.get_hsps() - b_domains = pair.record_b.get_hsps() + a_domains = list(pair.record_a.get_hsps()) + b_domains = list(pair.record_b.get_hsps()) a_max_dist = math.floor(len(a_domains) * max_match_dist_perc) b_max_dist = 
math.floor(len(b_domains) * max_match_dist_perc) diff --git a/big_scape/comparison/lcs.py b/big_scape/comparison/lcs.py index 9b2bcec0..76d5f4ac 100644 --- a/big_scape/comparison/lcs.py +++ b/big_scape/comparison/lcs.py @@ -171,8 +171,8 @@ def find_domain_lcs_region( logging.debug("region lcs") # these are regions, so we can get the full range of CDS - a_cds = pair.record_a.get_cds_with_domains() - b_cds = pair.record_b.get_cds_with_domains() + a_cds = list(pair.record_a.get_cds_with_domains()) + b_cds = list(pair.record_b.get_cds_with_domains()) # working on domains, not cds a_domains: list[bs_hmm.HSP] = [] @@ -431,8 +431,8 @@ def find_domain_lcs_protocluster( if not isinstance(pair.record_b, bs_genbank.ProtoCluster): raise TypeError("record_b must be a protocluster") - a_cds = pair.record_a.get_cds_with_domains() - b_cds = pair.record_b.get_cds_with_domains() + a_cds = list(pair.record_a.get_cds_with_domains()) + b_cds = list(pair.record_b.get_cds_with_domains()) # working on domains, not cds a_domains: list[bs_hmm.HSP] = [] diff --git a/big_scape/comparison/record_pair.py b/big_scape/comparison/record_pair.py index 272a2068..565e164a 100644 --- a/big_scape/comparison/record_pair.py +++ b/big_scape/comparison/record_pair.py @@ -31,10 +31,10 @@ def __init__(self, record_a: BGCRecord, record_b: BGCRecord): raise ValueError("Region in pair has no parent GBK!") # comparable regions start "deflated", meaning only CDS with domains - a_len = len(record_a.get_cds_with_domains()) - b_len = len(record_b.get_cds_with_domains()) - a_domain_len = len(record_a.get_hsps()) - b_domain_len = len(record_b.get_hsps()) + a_len = len(list(record_a.get_cds_with_domains())) + b_len = len(list(record_b.get_cds_with_domains())) + a_domain_len = len(list(record_a.get_hsps())) + b_domain_len = len(list(record_b.get_hsps())) self.comparable_region: ComparableRegion = ComparableRegion( 0, a_len, 0, b_len, 0, a_domain_len, 0, b_domain_len, False @@ -128,10 +128,11 @@ def get_domain_lists( 
reverse = self.comparable_region.reverse - a_cds_list = self.record_a.get_cds_with_domains()[a_start:a_stop] - b_cds_list = self.record_b.get_cds_with_domains(reverse=reverse)[ - b_start:b_stop - ] + cds_list_a = list(self.record_a.get_cds_with_domains()) + a_cds_list = cds_list_a[a_start:a_stop] + + cds_list_b = list(self.record_b.get_cds_with_domains(reverse=reverse)) + b_cds_list = cds_list_b[b_start:b_stop] a_domain_list: list[HSP] = [] for a_cds in a_cds_list: @@ -199,9 +200,9 @@ def log_comparable_region(self, label="<") -> None: # pragma: no cover if logging.getLogger().level > logging.DEBUG: return - a_cds_list = self.record_a.get_cds_with_domains() - b_cds_list = self.record_b.get_cds_with_domains( - reverse=self.comparable_region.reverse + a_cds_list = list(self.record_a.get_cds_with_domains()) + b_cds_list = list( + self.record_b.get_cds_with_domains(reverse=self.comparable_region.reverse) ) b_start = self.comparable_region.b_start diff --git a/big_scape/genbank/bgc_record.py b/big_scape/genbank/bgc_record.py index fb89c8d8..3a676e0b 100644 --- a/big_scape/genbank/bgc_record.py +++ b/big_scape/genbank/bgc_record.py @@ -4,7 +4,7 @@ # from python from __future__ import annotations -from typing import Optional, Sequence, TYPE_CHECKING +from typing import Optional, Sequence, TYPE_CHECKING, Generator import logging # from dependencies @@ -69,7 +69,7 @@ def __init__( # for networking self._families: dict[float, int] = {} - def get_cds(self, return_all=False, reverse=False) -> list[CDS]: + def get_cds(self, return_all=False, reverse=False) -> Generator[CDS, None, None]: """Get a list of CDS that lie within the coordinates specified in this region from the parent GBK class @@ -93,15 +93,14 @@ def get_cds(self, return_all=False, reverse=False) -> list[CDS]: if return_all: # TODO: I don't like this solution. 
maybe go back to the more difficult one if reverse: - return list(reverse(self.parent_gbk.genes)) + yield from reversed(self.parent_gbk.genes)  # not `return`: generator + return - - return list(self.parent_gbk.genes) + yield from self.parent_gbk.genes + return if self.nt_start is None or self.nt_stop is None: raise ValueError("Cannot CDS from region with no position information") - record_cds: list[CDS] = [] - if reverse: step = -1 else: @@ -114,11 +113,11 @@ def get_cds(self, return_all=False, reverse=False) -> list[CDS]: if cds.nt_stop > self.nt_stop: continue - record_cds.append(cds) - - return record_cds + yield cds - def get_cds_with_domains(self, return_all=False, reverse=False) -> list[CDS]: + def get_cds_with_domains( + self, return_all=False, reverse=False + ) -> Generator[CDS, None, None]: """Get a list of CDS that lie within the coordinates specified in this region from the parent GBK class @@ -137,43 +136,12 @@ def get_cds_with_domains(self, return_all=False, reverse=False) -> list[CDS]: if self.parent_gbk is None: raise ValueError("BGCRegion does not have a parent") - parent_gbk_cds: list[CDS] = self.parent_gbk.genes - - if return_all: - # TODO: I don't like this solution. 
maybe go back to the more difficult one - if reverse: - return [ - cds for cds in reversed(self.parent_gbk.genes) if len(cds.hsps) > 0 - ] - - return [cds for cds in self.parent_gbk.genes if len(cds.hsps) > 0] - - if self.nt_start is None or self.nt_stop is None: - raise ValueError("Cannot CDS from region with no position information") - - record_cds: list[CDS] = [] - - if reverse: - step = -1 - else: - step = 1 - - for cds in parent_gbk_cds[::step]: - if len(cds.hsps) == 0: - continue - - if cds.nt_start < self.nt_start: - continue - - if cds.nt_stop > self.nt_stop: - continue - - record_cds.append(cds) - - return record_cds + for cds in self.get_cds(return_all, reverse): + if len(cds.hsps) > 0: + yield cds - def get_hsps(self, return_all=False) -> list[HSP]: - """Get a list of all hsps in this region + def get_hsps(self, return_all=False) -> Generator[HSP, None, None]: + """Get a generator of all hsps in this region Args: return_all (bool): If set to true, returns all HSP regardless of coordinate @@ -182,14 +150,13 @@ def get_hsps(self, return_all=False) -> list[HSP]: Returns: list[HSP]: List of all hsps in this region """ - domains: list[HSP] = [] + for cds in self.get_cds_with_domains(return_all=return_all): if len(cds.hsps) > 0: if cds.strand == 1: - domains.extend(cds.hsps) + yield from cds.hsps elif cds.strand == -1: - domains.extend(cds.hsps[::-1]) - return domains + yield from cds.hsps[::-1] def get_cds_start_stop(self) -> tuple[int, int]: """Get cds ORF number of record start and stop with respect to full region diff --git a/big_scape/genbank/region.py b/big_scape/genbank/region.py index 6c404b58..25ad46b7 100644 --- a/big_scape/genbank/region.py +++ b/big_scape/genbank/region.py @@ -3,7 +3,7 @@ # from python from __future__ import annotations import logging -from typing import Dict, Optional, TYPE_CHECKING +from typing import Dict, Optional, TYPE_CHECKING, Generator # from dependencies from Bio.SeqFeature import SeqFeature @@ -241,7 +241,9 @@ def 
parse_full_region(cls, record: SeqRecord, parent_gbk: GBK) -> Region: region.cand_clusters = {} return region - def get_cds_with_domains(self, return_all=True, reverse=False) -> list[CDS]: + def get_cds_with_domains( + self, return_all=True, reverse=False + ) -> Generator[CDS, None, None]: return super().get_cds_with_domains(return_all, reverse) def __repr__(self): diff --git a/big_scape/output/html_template/overview_html b/big_scape/output/html_template/overview_html index 8b48263e..4f61255c 100644 --- a/big_scape/output/html_template/overview_html +++ b/big_scape/output/html_template/overview_html @@ -84,6 +84,14 @@ Include Singletons: +
+