diff --git a/big_scape/cli/__init__.py b/big_scape/cli/__init__.py index d132bb36..aaf62a45 100644 --- a/big_scape/cli/__init__.py +++ b/big_scape/cli/__init__.py @@ -6,10 +6,12 @@ validate_binning_cluster_workflow, validate_binning_query_workflow, validate_alignment_mode, - validate_includelist, + validate_includelist_all, + validate_includelist_any, validate_gcf_cutoffs, validate_filter_gbk, validate_pfam_path, + validate_domain_include_list, validate_classify, validate_output_dir, validate_query_record, @@ -21,10 +23,12 @@ "validate_binning_cluster_workflow", "validate_binning_query_workflow", "validate_alignment_mode", - "validate_includelist", + "validate_includelist_all", + "validate_includelist_any", "validate_gcf_cutoffs", "validate_filter_gbk", "validate_pfam_path", + "validate_domain_include_list", "validate_classify", "validate_output_dir", "validate_query_record", diff --git a/big_scape/cli/cli_common_options.py b/big_scape/cli/cli_common_options.py index 9c1bf0a4..60fc5879 100644 --- a/big_scape/cli/cli_common_options.py +++ b/big_scape/cli/cli_common_options.py @@ -12,7 +12,8 @@ validate_not_empty_dir, validate_input_mode, validate_alignment_mode, - validate_includelist, + validate_includelist_all, + validate_includelist_any, validate_gcf_cutoffs, validate_filter_gbk, validate_record_type, @@ -218,13 +219,28 @@ def common_cluster_query(fn): ), click.option( # TODO: implement - "--domain_includelist_path", + "--domain_includelist_all_path", type=click.Path( exists=True, dir_okay=False, file_okay=True, path_type=Path ), - callback=validate_includelist, + callback=validate_includelist_all, help=( - "Path to txt file with Pfam accessions. Only BGCs containing " + "Path to txt file with Pfam accessions. Only BGCs containing all " + "the listed accessions will be analysed. In this file, each " + "line contains a single Pfam accession (with an optional comment," + " separated by a tab). Lines starting with '#' are ignored. 
Pfam " + "accessions are case-sensitive." + ), + ), + click.option( + # TODO: implement + "--domain_includelist_any_path", + type=click.Path( + exists=True, dir_okay=False, file_okay=True, path_type=Path + ), + callback=validate_includelist_any, + help=( + "Path to txt file with Pfam accessions. Only BGCs containing any of " "the listed accessions will be analysed. In this file, each " "line contains a single Pfam accession (with an optional comment," " separated by a tab). Lines starting with '#' are ignored. Pfam " diff --git a/big_scape/cli/cli_validations.py b/big_scape/cli/cli_validations.py index 92fe42de..a9aa04d4 100644 --- a/big_scape/cli/cli_validations.py +++ b/big_scape/cli/cli_validations.py @@ -193,7 +193,7 @@ def validate_filter_gbk(ctx, param, filter_str) -> list[str]: # hmmer parameters -def validate_includelist(ctx, param, domain_includelist_path) -> None: +def validate_includelist(ctx, param, domain_includelist_path): """Validate the path to the domain include list and return a list of domain accession strings contained within this file @@ -207,17 +207,27 @@ def validate_includelist(ctx, param, domain_includelist_path) -> None: if not domain_includelist_path.exists(): logging.error("domain_includelist file does not exist!") - raise InvalidArgumentError("--domain_includelist", domain_includelist_path) + raise InvalidArgumentError( + "--domain_includelist_all/any_path", domain_includelist_path + ) with domain_includelist_path.open(encoding="utf-8") as domain_includelist_file: lines = domain_includelist_file.readlines() - lines = [line.strip() for line in lines] + pfams = [] + + for line in lines: + line = line.strip() + # skip blank lines and '#' comment lines, as promised by the CLI help + if not line or line.startswith("#"): + continue + elemts = line.split("\t") + pfams.append(elemts[0].strip()) # expect Pfam accessions, i.e. 
PF00001 or PF00001.10 lines_valid = map( lambda string: string.startswith("PF") and len(string) in range(7, 11), - lines, + pfams, ) if not all(lines_valid): @@ -228,7 +238,37 @@ def validate_includelist(ctx, param, domain_includelist_path) -> None: "Invalid Pfam accession(s) found in file %s", domain_includelist_path ) - ctx.params["domain_includelist"] = lines + return pfams + + +def validate_includelist_all(ctx, param, domain_includelist_all_path) -> None: + """Validate the 'all' domain include list file and store the parsed + accessions in ctx.params["domain_includelist_all"] + + Returns: + None: the accessions are stored on the click context, not returned + """ + + pfams = validate_includelist(ctx, param, domain_includelist_all_path) + + ctx.params["domain_includelist_all"] = pfams + + return None + + +def validate_includelist_any(ctx, param, domain_includelist_any_path) -> None: + """Validate the 'any' domain include list file and store the parsed + accessions in ctx.params["domain_includelist_any"] + + Returns: + None: the accessions are stored on the click context, not returned + """ + + pfams = validate_includelist(ctx, param, domain_includelist_any_path) + + ctx.params["domain_includelist_any"] = pfams + + return None # workflow validations @@ -334,6 +374,23 @@ def validate_pfam_path(ctx) -> None: ) +def validate_domain_include_list(ctx) -> None: + """Raise an error if both domain includelists are given at the same time""" + + if ( + ctx.obj["domain_includelist_all"] is not None + and ctx.obj["domain_includelist_any"] is not None + ): + logging.error( + "You have selected both all and any domain_includelist options. " + "Please select only one of the two at a time." + ) + raise click.UsageError( + "You have selected both all and any domain_includelist options. " + "Please select only one of the two at a time." 
+ ) + + def validate_record_type(ctx, _, record_type) -> Optional[bs_enums.genbank.RECORD_TYPE]: """Validates whether a region_type is provided when running classify""" valid_types = {mode.value: mode for mode in bs_enums.genbank.RECORD_TYPE} diff --git a/big_scape/cli/cluster_cli.py b/big_scape/cli/cluster_cli.py index dd7f64e4..ec45a9ef 100644 --- a/big_scape/cli/cluster_cli.py +++ b/big_scape/cli/cluster_cli.py @@ -13,6 +13,7 @@ validate_output_paths, validate_binning_cluster_workflow, validate_pfam_path, + validate_domain_include_list, set_start, ) @@ -58,6 +59,7 @@ def cluster(ctx, *args, **kwargs): # workflow validations validate_binning_cluster_workflow(ctx) validate_pfam_path(ctx) + validate_domain_include_list(ctx) validate_output_paths(ctx) # set start time and run label diff --git a/big_scape/comparison/comparable_region.py b/big_scape/comparison/comparable_region.py index 627fb5b3..76593648 100644 --- a/big_scape/comparison/comparable_region.py +++ b/big_scape/comparison/comparable_region.py @@ -92,8 +92,8 @@ def inflate_a(self, record_a: BGCRecord) -> None: comparable region has already been inflated, so this method should only be called once. """ - a_cds_list = record_a.get_cds() - a_cds_with_domains = record_a.get_cds_with_domains() + a_cds_list = list(record_a.get_cds()) + a_cds_with_domains = list(record_a.get_cds_with_domains()) lcs_a_start_orf_num = a_cds_with_domains[self.lcs_a_start].orf_num lcs_a_stop_orf_num = a_cds_with_domains[self.lcs_a_stop - 1].orf_num a_start_orf_num = a_cds_with_domains[self.a_start].orf_num @@ -121,8 +121,8 @@ def inflate_b(self, record_b: BGCRecord) -> None: comparable region has already been inflated, so this method should only be called once. 
""" - b_cds_list = record_b.get_cds() - b_cds_with_domains = record_b.get_cds_with_domains() + b_cds_list = list(record_b.get_cds()) + b_cds_with_domains = list(record_b.get_cds_with_domains()) if self.reverse: b_cds_list = b_cds_list[::-1] @@ -246,7 +246,9 @@ def cds_range_contains_biosynthetic( if end_inclusive: stop += 1 - for cds in record.get_cds_with_domains(reverse=reverse)[cds_start:stop]: + cds_list = list(record.get_cds_with_domains(reverse=reverse)) + + for cds in cds_list[cds_start:stop]: if cds.gene_kind is None: continue diff --git a/big_scape/comparison/extend.py b/big_scape/comparison/extend.py index e66f36c0..2c96dedf 100644 --- a/big_scape/comparison/extend.py +++ b/big_scape/comparison/extend.py @@ -21,12 +21,12 @@ def reset(pair: RecordPair) -> None: pair.comparable_region.a_start = 0 pair.comparable_region.b_start = 0 - pair.comparable_region.a_stop = len(pair.record_a.get_cds_with_domains()) - pair.comparable_region.b_stop = len(pair.record_b.get_cds_with_domains()) + pair.comparable_region.a_stop = len(list(pair.record_a.get_cds_with_domains())) + pair.comparable_region.b_stop = len(list(pair.record_b.get_cds_with_domains())) pair.comparable_region.domain_a_start = 0 pair.comparable_region.domain_b_start = 0 - pair.comparable_region.domain_a_stop = len(pair.record_a.get_hsps()) - pair.comparable_region.domain_b_stop = len(pair.record_b.get_hsps()) + pair.comparable_region.domain_a_stop = len(list(pair.record_a.get_hsps())) + pair.comparable_region.domain_b_stop = len(list(pair.record_b.get_hsps())) pair.comparable_region.reverse = False @@ -108,8 +108,8 @@ def extend( # get the cds lists # TODO: base extend on all domains in case of protoclusters, allow extend beyond # protocluster border - a_domains = pair.record_a.get_hsps() - b_domains = pair.record_b.get_hsps() + a_domains = list(pair.record_a.get_hsps()) + b_domains = list(pair.record_b.get_hsps()) a_max_dist = math.floor(len(a_domains) * max_match_dist_perc) b_max_dist = 
math.floor(len(b_domains) * max_match_dist_perc) diff --git a/big_scape/comparison/lcs.py b/big_scape/comparison/lcs.py index 9b2bcec0..76d5f4ac 100644 --- a/big_scape/comparison/lcs.py +++ b/big_scape/comparison/lcs.py @@ -171,8 +171,8 @@ def find_domain_lcs_region( logging.debug("region lcs") # these are regions, so we can get the full range of CDS - a_cds = pair.record_a.get_cds_with_domains() - b_cds = pair.record_b.get_cds_with_domains() + a_cds = list(pair.record_a.get_cds_with_domains()) + b_cds = list(pair.record_b.get_cds_with_domains()) # working on domains, not cds a_domains: list[bs_hmm.HSP] = [] @@ -431,8 +431,8 @@ def find_domain_lcs_protocluster( if not isinstance(pair.record_b, bs_genbank.ProtoCluster): raise TypeError("record_b must be a protocluster") - a_cds = pair.record_a.get_cds_with_domains() - b_cds = pair.record_b.get_cds_with_domains() + a_cds = list(pair.record_a.get_cds_with_domains()) + b_cds = list(pair.record_b.get_cds_with_domains()) # working on domains, not cds a_domains: list[bs_hmm.HSP] = [] diff --git a/big_scape/comparison/record_pair.py b/big_scape/comparison/record_pair.py index 272a2068..565e164a 100644 --- a/big_scape/comparison/record_pair.py +++ b/big_scape/comparison/record_pair.py @@ -31,10 +31,10 @@ def __init__(self, record_a: BGCRecord, record_b: BGCRecord): raise ValueError("Region in pair has no parent GBK!") # comparable regions start "deflated", meaning only CDS with domains - a_len = len(record_a.get_cds_with_domains()) - b_len = len(record_b.get_cds_with_domains()) - a_domain_len = len(record_a.get_hsps()) - b_domain_len = len(record_b.get_hsps()) + a_len = len(list(record_a.get_cds_with_domains())) + b_len = len(list(record_b.get_cds_with_domains())) + a_domain_len = len(list(record_a.get_hsps())) + b_domain_len = len(list(record_b.get_hsps())) self.comparable_region: ComparableRegion = ComparableRegion( 0, a_len, 0, b_len, 0, a_domain_len, 0, b_domain_len, False @@ -128,10 +128,11 @@ def get_domain_lists( 
reverse = self.comparable_region.reverse - a_cds_list = self.record_a.get_cds_with_domains()[a_start:a_stop] - b_cds_list = self.record_b.get_cds_with_domains(reverse=reverse)[ - b_start:b_stop - ] + cds_list_a = list(self.record_a.get_cds_with_domains()) + a_cds_list = cds_list_a[a_start:a_stop] + + cds_list_b = list(self.record_b.get_cds_with_domains(reverse=reverse)) + b_cds_list = cds_list_b[b_start:b_stop] a_domain_list: list[HSP] = [] for a_cds in a_cds_list: @@ -199,9 +200,9 @@ def log_comparable_region(self, label="<") -> None: # pragma: no cover if logging.getLogger().level > logging.DEBUG: return - a_cds_list = self.record_a.get_cds_with_domains() - b_cds_list = self.record_b.get_cds_with_domains( - reverse=self.comparable_region.reverse + a_cds_list = list(self.record_a.get_cds_with_domains()) + b_cds_list = list( + self.record_b.get_cds_with_domains(reverse=self.comparable_region.reverse) ) b_start = self.comparable_region.b_start diff --git a/big_scape/genbank/bgc_record.py b/big_scape/genbank/bgc_record.py index fb89c8d8..3a676e0b 100644 --- a/big_scape/genbank/bgc_record.py +++ b/big_scape/genbank/bgc_record.py @@ -4,7 +4,7 @@ # from python from __future__ import annotations -from typing import Optional, Sequence, TYPE_CHECKING +from typing import Optional, Sequence, TYPE_CHECKING, Generator import logging # from dependencies @@ -69,7 +69,7 @@ def __init__( # for networking self._families: dict[float, int] = {} - def get_cds(self, return_all=False, reverse=False) -> list[CDS]: + def get_cds(self, return_all=False, reverse=False) -> Generator[CDS, None, None]: """Get a list of CDS that lie within the coordinates specified in this region from the parent GBK class @@ -93,15 +93,15 @@ def get_cds(self, return_all=False, reverse=False) -> list[CDS]: if return_all: # TODO: I don't like this solution. 
maybe go back to the more difficult one if reverse: - return list(reverse(self.parent_gbk.genes)) + yield from reversed(self.parent_gbk.genes) + return - return list(self.parent_gbk.genes) + yield from self.parent_gbk.genes + return if self.nt_start is None or self.nt_stop is None: raise ValueError("Cannot CDS from region with no position information") - record_cds: list[CDS] = [] - if reverse: step = -1 else: @@ -114,11 +114,11 @@ def get_cds(self, return_all=False, reverse=False) -> list[CDS]: if cds.nt_stop > self.nt_stop: continue - record_cds.append(cds) - - return record_cds + yield cds - def get_cds_with_domains(self, return_all=False, reverse=False) -> list[CDS]: + def get_cds_with_domains( + self, return_all=False, reverse=False + ) -> Generator[CDS, None, None]: """Get a list of CDS that lie within the coordinates specified in this region from the parent GBK class @@ -137,43 +137,12 @@ def get_cds_with_domains(self, return_all=False, reverse=False) -> list[CDS]: if self.parent_gbk is None: raise ValueError("BGCRegion does not have a parent") - parent_gbk_cds: list[CDS] = self.parent_gbk.genes - - if return_all: - # TODO: I don't like this solution. 
maybe go back to the more difficult one - if reverse: - return [ - cds for cds in reversed(self.parent_gbk.genes) if len(cds.hsps) > 0 - ] - - return [cds for cds in self.parent_gbk.genes if len(cds.hsps) > 0] - - if self.nt_start is None or self.nt_stop is None: - raise ValueError("Cannot CDS from region with no position information") - - record_cds: list[CDS] = [] - - if reverse: - step = -1 - else: - step = 1 - - for cds in parent_gbk_cds[::step]: - if len(cds.hsps) == 0: - continue - - if cds.nt_start < self.nt_start: - continue - - if cds.nt_stop > self.nt_stop: - continue - - record_cds.append(cds) - - return record_cds + for cds in self.get_cds(return_all, reverse): + if len(cds.hsps) > 0: + yield cds - def get_hsps(self, return_all=False) -> list[HSP]: - """Get a list of all hsps in this region + def get_hsps(self, return_all=False) -> Generator[HSP, None, None]: + """Get a generator of all hsps in this region Args: return_all (bool): If set to true, returns all HSP regardless of coordinate @@ -182,14 +151,13 @@ def get_hsps(self, return_all=False) -> list[HSP]: Returns: list[HSP]: List of all hsps in this region """ - domains: list[HSP] = [] + for cds in self.get_cds_with_domains(return_all=return_all): if len(cds.hsps) > 0: if cds.strand == 1: - domains.extend(cds.hsps) + yield from cds.hsps elif cds.strand == -1: - domains.extend(cds.hsps[::-1]) - return domains + yield from cds.hsps[::-1] def get_cds_start_stop(self) -> tuple[int, int]: """Get cds ORF number of record start and stop with respect to full region diff --git a/big_scape/genbank/region.py b/big_scape/genbank/region.py index 6c404b58..25ad46b7 100644 --- a/big_scape/genbank/region.py +++ b/big_scape/genbank/region.py @@ -3,7 +3,7 @@ # from python from __future__ import annotations import logging -from typing import Dict, Optional, TYPE_CHECKING +from typing import Dict, Optional, TYPE_CHECKING, Generator # from dependencies from Bio.SeqFeature import SeqFeature @@ -241,7 +241,9 @@ def 
parse_full_region(cls, record: SeqRecord, parent_gbk: GBK) -> Region: region.cand_clusters = {} return region - def get_cds_with_domains(self, return_all=True, reverse=False) -> list[CDS]: + def get_cds_with_domains( + self, return_all=True, reverse=False + ) -> Generator[CDS, None, None]: return super().get_cds_with_domains(return_all, reverse) def __repr__(self): diff --git a/big_scape/output/html_template/overview_html b/big_scape/output/html_template/overview_html index 8b48263e..4f61255c 100644 --- a/big_scape/output/html_template/overview_html +++ b/big_scape/output/html_template/overview_html @@ -84,6 +84,14 @@ Include Singletons: +
+ Domain Filter: + +
+
+ Domains to Include: + +

Input Data

@@ -220,6 +228,8 @@ $("#weights").html(run_data["weights"]); $("#alignment_mode").html(run_data["alignment_mode"]); $("#include_singletons").html(run_data["include_singletons"]); + $("#domain_filter").html(run_data["domain_filter"]); + $("#domains_to_include").html(run_data["domains_to_include"]); // input information $("#total_accession").html(run_data["input"]["accession"].length); $("#total_bgc").html(run_data["input"]["bgc"].length); diff --git a/big_scape/output/legacy_output.py b/big_scape/output/legacy_output.py index e7892a82..e45fee23 100644 --- a/big_scape/output/legacy_output.py +++ b/big_scape/output/legacy_output.py @@ -243,8 +243,18 @@ def generate_run_data_js( "bgc": [], }, "networks": [], + "domain_filter": "NA", + "domains_to_include": "NA", } + if run["domain_includelist_all"] is not None: + run_data["domain_filter"] = "All" + run_data["domains_to_include"] = ", ".join(run["domain_includelist_all"]) + + if run["domain_includelist_any"] is not None: + run_data["domain_filter"] = "Any" + run_data["domains_to_include"] = ", ".join(run["domain_includelist_any"]) + # these are mostly index dictionaries needed for certain fields members: dict[GBK, int] = {} genomes: dict[str, int] = {} @@ -911,7 +921,7 @@ def generate_bs_family_alignment( if bgc_db_id is None: raise AttributeError("Record has no database id!") - bgc_domains = bgc_record.get_hsps() + bgc_domains = list(bgc_record.get_hsps()) if bgc_db_id == family_db_id: aln.append([[dom_num, 0] for dom_num in range(len(bgc_domains))]) @@ -921,7 +931,7 @@ def generate_bs_family_alignment( # should they end up in the same family we can try to align them elif bgc_gbk == fam_record.parent_gbk: a_start, b_start, reverse = align_subrecords( - fam_record.get_hsps(), bgc_domains + list(fam_record.get_hsps()), bgc_domains ) else: @@ -934,7 +944,7 @@ def generate_bs_family_alignment( a_start, b_start, reverse = adjust_lcs_to_family_reference( lcs_data, family_db_id, - len(fam_record.get_hsps()), + 
len(list(fam_record.get_hsps())), len(bgc_domains), ) diff --git a/big_scape/run_bigscape.py b/big_scape/run_bigscape.py index 62596dba..13b52a97 100644 --- a/big_scape/run_bigscape.py +++ b/big_scape/run_bigscape.py @@ -23,6 +23,7 @@ write_record_annotations_file, write_full_network_file, ) +from big_scape.utility import domain_includelist_filter import big_scape.file_input as bs_files @@ -166,6 +167,17 @@ def signal_handler(sig, frame): # TODO: idea: use sqlite to set distances of 1.0 for all pairs that have no domains # in common + # DOMAIN INCLUSION LIST FILTER + if run["domain_includelist_all"] or run["domain_includelist_any"]: + logging.info("Filtering records by domain_includelist") + + all_bgc_records = domain_includelist_filter(run, all_bgc_records) + + logging.info( + "Continuing with %i filtered records", + len(all_bgc_records), + ) + # DISTANCE GENERATION # mix diff --git a/big_scape/utility/__init__.py b/big_scape/utility/__init__.py index 638b1e56..2c89c9f5 100644 --- a/big_scape/utility/__init__.py +++ b/big_scape/utility/__init__.py @@ -1,7 +1,5 @@ """Contains utility functions for various applications""" from .multiprocess import start_processes, worker_method +from .filters import domain_includelist_filter -__all__ = [ - "start_processes", - "worker_method", -] +__all__ = ["start_processes", "worker_method", "domain_includelist_filter"] diff --git a/big_scape/utility/filters.py b/big_scape/utility/filters.py new file mode 100644 index 00000000..2c111e4c --- /dev/null +++ b/big_scape/utility/filters.py @@ -0,0 +1,38 @@ +"""Contains helper functions for filtering""" + +# from other modules +import big_scape.genbank as bs_gbk + + +def domain_includelist_filter(run: dict, all_bgc_records: list[bs_gbk.BGCRecord]): + include_domain_accessions = [] + domainlist_bgc_records = [] + + if run["domain_includelist_any"] is not None: + include_domain_accessions = run["domain_includelist_any"] + + for record in all_bgc_records: + record_hsps = record.get_hsps() 
+ + result = any(hsp.domain in include_domain_accessions for hsp in record_hsps) + + if result: + domainlist_bgc_records.append(record) + + return domainlist_bgc_records + + if run["domain_includelist_all"] is not None: + include_domain_accessions = run["domain_includelist_all"] + + for record in all_bgc_records: + record_domains = set(hsp.domain for hsp in record.get_hsps()) + # TODO: is this more or less efficient than using a record_domains generator? + + result = all( + domain in record_domains for domain in include_domain_accessions + ) + + if result: + domainlist_bgc_records.append(record) + + return domainlist_bgc_records diff --git a/test/comparison/test_extend.py b/test/comparison/test_extend.py index 66ad19ff..3fdfb1bc 100644 --- a/test/comparison/test_extend.py +++ b/test/comparison/test_extend.py @@ -174,9 +174,9 @@ def test_reset_region(self): 0, len(cds_b), 0, - len(record_a.get_hsps()), + len(list(record_a.get_hsps())), 0, - len(record_b.get_hsps()), + len(list(record_b.get_hsps())), False, ) @@ -206,9 +206,9 @@ def test_reset_protocluster(self): 0, len(cds_b), 0, - len(record_a.get_hsps()), + len(list(record_a.get_hsps())), 0, - len(record_b.get_hsps()), + len(list(record_b.get_hsps())), False, ) diff --git a/test/genbank/test_bgc_record.py b/test/genbank/test_bgc_record.py index 86c0c77d..7d463e3c 100644 --- a/test/genbank/test_bgc_record.py +++ b/test/genbank/test_bgc_record.py @@ -4,10 +4,14 @@ # from python from unittest import TestCase +from pathlib import Path # from other modules from big_scape.genbank import GBK, BGCRecord, CDS, Region from big_scape.hmm import HSP +from big_scape.utility import domain_includelist_filter +from big_scape import enums as bs_enums +from big_scape.file_input.load_files import get_all_bgc_records class TestBGCRecord(TestCase): @@ -40,7 +44,7 @@ def test_get_cds(self): expected_cds_count = 8 # TODO: use __len__ - actual_cds_count = len(record.get_cds()) + actual_cds_count = len(list(record.get_cds())) 
self.assertEqual(expected_cds_count, actual_cds_count) @@ -70,7 +74,7 @@ def test_get_cds_all(self): expected_cds_count = 10 # TODO: use __len__ - actual_cds_count = len(record.get_cds(True)) + actual_cds_count = len(list(record.get_cds(True))) self.assertEqual(expected_cds_count, actual_cds_count) @@ -98,6 +102,96 @@ def test_get_hsps(self): self.assertEqual(expected_domains, actual_domains) + def test_domainincludelist_filter_all(self): + """Tests whether the domain_includelist_filter all + correctly filters out records that do not contain all domains""" + + run = { + "input_dir": Path("test/test_data/alt_valid_gbk_input/"), + "input_mode": bs_enums.INPUT_MODE.RECURSIVE, + "include_gbk": None, + "exclude_gbk": None, + "cds_overlap_cutoff": None, + "cores": None, + "classify": False, + "legacy_classify": False, + "record_type": bs_enums.genbank.RECORD_TYPE.PROTO_CLUSTER, + "domain_includelist_all": ["PF00001", "PF00002"], + "domain_includelist_any": None, + } + + gbk_1 = GBK("", "", "") + gbk_1.region = Region(gbk_1, 0, 0, 100, False, "") + cds_1 = CDS(10, 90) + cds_1.strand = 1 + gbk_1.genes.append(cds_1) + domains_1 = ["PF00001", "PF00002", "PF00003", "PF00004", "PF00005"] + for domain in domains_1: + cds_1.hsps.append(HSP(cds_1, domain, 100, 0, 30)) + + gbk_2 = GBK("", "", "") + gbk_2.region = Region(gbk_2, 0, 0, 100, False, "") + cds_2 = CDS(10, 90) + cds_2.strand = 1 + gbk_2.genes.append(cds_2) + domains_2 = ["PF00001", "PF00003", "PF00004", "PF00005"] + for domain in domains_2: + cds_2.hsps.append(HSP(cds_2, domain, 100, 0, 30)) + + gbks = [gbk_1, gbk_2] + + all_bgc_records = get_all_bgc_records(run, gbks) + + domainlist_bgc_records = domain_includelist_filter(run, all_bgc_records) + + expected_records = get_all_bgc_records(run, [gbk_1]) + + self.assertEqual(expected_records, domainlist_bgc_records) + + def test_domainincludelist_filter_any(self): + """Tests whether the domain_includelist_filter any + correctly filters out records that do not 
contain any of the listed 
domains""" + + run = { + "input_dir": Path("test/test_data/alt_valid_gbk_input/"), + "input_mode": bs_enums.INPUT_MODE.RECURSIVE, + "include_gbk": None, + "exclude_gbk": None, + "cds_overlap_cutoff": None, + "cores": None, + "classify": False, + "legacy_classify": False, + "record_type": bs_enums.genbank.RECORD_TYPE.PROTO_CLUSTER, + "domain_includelist_all": None, + "domain_includelist_any": ["PF00001", "PF00002"], + } + + gbk_1 = GBK("", "", "") + gbk_1.region = Region(gbk_1, 0, 0, 100, False, "") + cds_1 = CDS(10, 90) + cds_1.strand = 1 + gbk_1.genes.append(cds_1) + domains_1 = ["PF00002", "PF00003", "PF00004", "PF00005"] + for domain in domains_1: + cds_1.hsps.append(HSP(cds_1, domain, 100, 0, 30)) + + gbk_2 = GBK("", "", "") + gbk_2.region = Region(gbk_2, 0, 0, 100, False, "") + cds_2 = CDS(10, 90) + cds_2.strand = 1 + gbk_2.genes.append(cds_2) + domains_2 = ["PF00001", "PF00003", "PF00004", "PF00005"] + for domain in domains_2: + cds_2.hsps.append(HSP(cds_2, domain, 100, 0, 30)) + + gbks = [gbk_1, gbk_2] + + all_bgc_records = get_all_bgc_records(run, gbks) + + domainlist_bgc_records = domain_includelist_filter(run, all_bgc_records) + + self.assertEqual(all_bgc_records, domainlist_bgc_records) + def test_get_cds_with_domains(self): """Tests whether the test_get_cds_with_domains method correctly retrieves a subset of CDS containing only domains @@ -122,7 +216,7 @@ def test_get_cds_with_domains(self): region = BGCRecord(gbk, 0, 0, 100, False, "") expected_cds_count = 4 - actual_cds_count = len(region.get_cds_with_domains()) + actual_cds_count = len(list(region.get_cds_with_domains())) self.assertEqual(expected_cds_count, actual_cds_count)