diff --git a/big_scape/benchmarking/benchmark_metrics.py b/big_scape/benchmarking/benchmark_metrics.py index 36a5ab3f..bb97ee6f 100644 --- a/big_scape/benchmarking/benchmark_metrics.py +++ b/big_scape/benchmarking/benchmark_metrics.py @@ -80,7 +80,12 @@ def calculate_v_measure(self) -> tuple[float, float, float]: curated_fams = list(self.curated_labels.values()) computed_fams = list(self.computed_labels.values()) - return homogeneity_completeness_v_measure(curated_fams, computed_fams) + metrics: tuple[float, float, float] = homogeneity_completeness_v_measure( + curated_fams, + computed_fams, + ) + + return metrics def calculate_purity(self) -> dict[str, float]: """Calculate purity P of each computed GCF diff --git a/big_scape/benchmarking/benchmark_output.py b/big_scape/benchmarking/benchmark_output.py index cb176776..d346117a 100644 --- a/big_scape/benchmarking/benchmark_output.py +++ b/big_scape/benchmarking/benchmark_output.py @@ -172,7 +172,7 @@ def plot_per_cutoff(self, metrics: dict[str, dict[str, Any]]) -> None: cutoffs_fl = list(map(float, cutoffs)) fig = plt.figure() - ax = fig.gca() + ax = fig.gca() # type: ignore h = ax.plot( cutoffs_fl, @@ -225,8 +225,8 @@ def plot_conf_matrix_heatmap( matrix_data: contains confusion matrix, row labels and column labels """ matrix, row_lab, col_lab = matrix_data - plt.imshow(matrix, cmap="binary", interpolation=None) - ax = plt.gca() + plt.imshow(matrix, cmap="binary", interpolation=None) # type: ignore + ax = plt.gca() # type: ignore ax.set_xticks( range(len(col_lab)), labels=col_lab, diff --git a/big_scape/cli/__init__.py b/big_scape/cli/__init__.py index b1cc8582..d132bb36 100644 --- a/big_scape/cli/__init__.py +++ b/big_scape/cli/__init__.py @@ -5,7 +5,6 @@ validate_input_mode, validate_binning_cluster_workflow, validate_binning_query_workflow, - validate_skip_hmmscan, validate_alignment_mode, validate_includelist, validate_gcf_cutoffs, @@ -21,7 +20,6 @@ "validate_input_mode", "validate_binning_cluster_workflow", "validate_binning_query_workflow", - "validate_skip_hmmscan", "validate_alignment_mode", "validate_includelist", "validate_gcf_cutoffs", diff --git a/big_scape/cli/benchmark_cli.py b/big_scape/cli/benchmark_cli.py index 30bd87be..e68984fa 100644 --- a/big_scape/cli/benchmark_cli.py +++ b/big_scape/cli/benchmark_cli.py @@ -1,6 +1,7 @@ """ Click parameters for the BiG-SCAPE Benchmark CLI command """ # from python +from typing import no_type_check import click from pathlib import Path @@ -13,6 +14,7 @@ from .cli_validations import set_start, validate_output_paths +@no_type_check # BiG-SCAPE benchmark mode @click.command() @common_all @@ -26,6 +28,7 @@ "a run output to these assignments." ), ) +@no_type_check @click.option( "--BiG_dir", type=click.Path(exists=True, file_okay=False, path_type=Path), diff --git a/big_scape/cli/cli_validations.py b/big_scape/cli/cli_validations.py index 25ff9f27..d9bc3ba8 100644 --- a/big_scape/cli/cli_validations.py +++ b/big_scape/cli/cli_validations.py @@ -36,7 +36,7 @@ def set_start(param_dict) -> None: # meta parameter validations -def validate_profiling(ctx, param, profiling) -> bool: +def validate_profiling(ctx, param, profiling: bool) -> bool: """Checks whether multithreading is possible, and therefore whether profiling can happen""" if profiling and platform.system() == "Darwin": @@ -48,7 +48,7 @@ def validate_profiling(ctx, param, profiling) -> bool: # input parameter validations -def validate_not_empty_dir(ctx, param, dir) -> Path: +def validate_not_empty_dir(ctx, param, dir: Path) -> Path: """Validates that a given directory is not empty. Raises a BadParameter""" @@ -57,6 +57,7 @@ def validate_not_empty_dir(ctx, param, dir) -> Path: if len(contents) == 0: logging.error(f"{dir}/ directory is empty!") raise click.BadParameter(f"{dir}/ directory is empty!") + return dir @@ -75,7 +76,7 @@ def validate_input_mode(ctx, param, input_mode) -> Optional[bs_enums.INPUT_MODE] return None -def validate_query_bgc(ctx, param, query_bgc_path) -> Path: +def validate_query_bgc(ctx, param, query_bgc_path: Path) -> Path: """Raises an InvalidArgumentError if the query bgc path does not exist""" if query_bgc_path.suffix != ".gbk": @@ -88,7 +89,7 @@ def validate_query_bgc(ctx, param, query_bgc_path) -> Path: # output parameter validations -def validate_output_dir(ctx, param, output_dir) -> Path: +def validate_output_dir(ctx, param, output_dir: Path) -> Path: """Validates that output directory exists""" if not output_dir.exists(): @@ -131,21 +132,23 @@ def validate_output_paths(ctx) -> None: # comparison validations -def validate_classify(ctx, param, classify) -> Optional[bs_enums.CLASSIFY_MODE]: +def validate_classify( + ctx, param, classify: Optional[bs_enums.CLASSIFY_MODE] +) -> Optional[bs_enums.CLASSIFY_MODE]: """Validates whether the classification type is set, and if not sets the parameter to False""" # check if the property matches one of the enum values valid_modes = [mode.value for mode in bs_enums.CLASSIFY_MODE] + if not classify: + return None + for mode in valid_modes: if classify == mode: return bs_enums.CLASSIFY_MODE[mode.upper()] - if classify is None: - classify = False - - return classify + raise InvalidArgumentError("--classify", classify) def validate_alignment_mode( @@ -187,7 +190,9 @@ def validate_gcf_cutoffs(ctx, param, gcf_cutoffs) -> list[float]: def validate_filter_gbk(ctx, param, filter_str) -> list[str]: """Validates and formats the filter string and returns a list of strings""" - return filter_str.split(",") + split_str: list[str] = filter_str.split(",") + + return split_str # hmmer parameters @@ -317,23 +322,6 @@ def validate_binning_query_workflow(ctx) -> None: ) -def validate_skip_hmmscan(ctx) -> None: - """Validates whether a BiG-SCAPE db exists when running skip_hmm, which - requires already processed gbk files and hence a DB in output""" - - if ctx.obj["skip_hmmscan"] and ctx.obj["db_path"] is None: - logging.error( - "Missing option '--db_path'." - "BiG-SCAPE database has not been given, skip_hmmscan requires " - "a DB of already processed gbk files." - ) - raise click.UsageError( - "Missing option '--db_path'." - "BiG-SCAPE database has not been given, skip_hmmscan requires " - "a DB of already processed gbk files." - ) - - def validate_pfam_path(ctx) -> None: """Validates whether a BiG-SCAPE db exists when pfam_path is not provided, which requires already processed gbk files and hence a DB in output""" diff --git a/big_scape/cli/cluster_cli.py b/big_scape/cli/cluster_cli.py index 956c9f18..eb3ab2c4 100644 --- a/big_scape/cli/cluster_cli.py +++ b/big_scape/cli/cluster_cli.py @@ -12,7 +12,6 @@ from .cli_validations import ( validate_output_paths, validate_binning_cluster_workflow, - validate_skip_hmmscan, validate_pfam_path, set_start, ) @@ -58,7 +57,6 @@ def cluster(ctx, *args, **kwargs): # workflow validations validate_binning_cluster_workflow(ctx) - validate_skip_hmmscan(ctx) validate_pfam_path(ctx) validate_output_paths(ctx) diff --git a/big_scape/cli/query_cli.py b/big_scape/cli/query_cli.py index 383d0d01..310dd9c9 100644 --- a/big_scape/cli/query_cli.py +++ b/big_scape/cli/query_cli.py @@ -1,6 +1,7 @@ """ Click parameters for the BiG-SCAPE Query CLI command """ # from python +from typing import no_type_check import click from pathlib import Path @@ -12,7 +13,6 @@ from .cli_common_options import common_all, common_cluster_query from .cli_validations import ( validate_output_paths, - validate_skip_hmmscan, validate_query_bgc, validate_pfam_path, set_start, @@ -21,6 +21,7 @@ ) +@no_type_check @click.command() @common_all @common_cluster_query diff --git a/big_scape/comparison/binning.py b/big_scape/comparison/binning.py index 503651c1..1d0f0498 100644 --- a/big_scape/comparison/binning.py +++ b/big_scape/comparison/binning.py @@ -14,7 +14,7 @@ from __future__ import annotations import logging from itertools import combinations -from typing import Generator, Iterator, Optional +from typing import Generator, Iterator, Optional, cast from sqlalchemy import and_, select, func, or_ # from other modules @@ -29,6 +29,7 @@ ) from big_scape.enums import SOURCE_TYPE, CLASSIFY_MODE, RECORD_TYPE +# from this module import big_scape.comparison as bs_comparison @@ -126,9 +127,8 @@ def num_pairs(self) -> int: # records from the same gbk will not be compared -> will not be a pair. Use # the database to find how many subrecords come from the same genbank, i.e. # how many pairs should be removed - if not DB.metadata: - raise RuntimeError("DB metadata is None!") - record_table = DB.metadata.tables["bgc_record"] + + record_table = DB.get_table("bgc_record") # find a collection of gbks with more than one subrecord member_table = ( @@ -179,10 +179,7 @@ def cull_singletons(self, cutoff: float, ref_only: bool = False): RuntimeError: DB.metadata is None """ - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - distance_table = DB.metadata.tables["distance"] + distance_table = DB.get_table("distance") # get all distances/edges in the table for the records in this bin and # with distances below the cutoff @@ -203,16 +200,25 @@ def cull_singletons(self, cutoff: float, ref_only: bool = False): if ref_only: singleton_record_ids = self.record_ids - edge_record_ids - self.source_records = [ - record + + new_records: list[BGCRecord] = [] + for record in self.source_records: + if record.parent_gbk is None: + raise ValueError("Region in bin has no parent gbk!") + in_edge = record._db_id in edge_record_ids + in_singletons = record._db_id in singleton_record_ids + not_reference = record.parent_gbk.source_type != SOURCE_TYPE.REFERENCE + + if in_edge or (in_singletons and not_reference): + new_records.append(record) + + self.source_records = new_records + + self.record_ids = { + record._db_id for record in self.source_records - if (record._db_id in edge_record_ids) - or ( - record._db_id in singleton_record_ids - and record.parent_gbk.source_type != SOURCE_TYPE.REFERENCE - ) - ] - self.record_ids = {record._db_id for record in self.source_records} + if record._db_id is not None + } else: self.record_ids = edge_record_ids @@ -315,17 +321,14 @@ class MissingRecordPairGenerator(RecordPairGenerator): already in the database """ - def __init__(self, pair_generator): + def __init__(self, pair_generator: RecordPairGenerator): super().__init__( pair_generator.label, pair_generator.edge_param_id, pair_generator.weights ) self.bin: RecordPairGenerator = pair_generator def num_pairs(self) -> int: - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - distance_table = DB.metadata.tables["distance"] + distance_table = DB.get_table("distance") # get all region._db_id in the bin where the record_a_id and record_b_id are in the # bin @@ -337,7 +340,7 @@ def num_pairs(self) -> int: ) # get count - existing_distance_count = DB.execute(select_statement).scalar_one() + existing_distance_count: int = DB.execute(select_statement).scalar_one() # subtract from expected number of distances return self.bin.num_pairs() - existing_distance_count @@ -355,10 +358,8 @@ def generate_pairs( Yields: Generator[tuple[int, int]]: Generator for record pairs in this bin """ - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - distance_table = DB.metadata.tables["distance"] + distance_table = DB.get_table("distance") # get all region._db_id in the bin select_statement = ( @@ -619,7 +620,7 @@ def num_pairs(self) -> int: ) ) - existing_distance_count = DB.execute(select_statement).scalar_one() + existing_distance_count = cast(int, DB.execute(select_statement).scalar_one()) return self.bin.num_pairs() - existing_distance_count diff --git a/big_scape/comparison/utility.py b/big_scape/comparison/utility.py index 68bde596..2fea0523 100644 --- a/big_scape/comparison/utility.py +++ b/big_scape/comparison/utility.py @@ -3,9 +3,10 @@ # from python import logging import sqlite3 +from typing import Optional # from dependencies -from sqlalchemy import insert, select +from sqlalchemy import Row, insert, select # from other modules from big_scape.data import DB @@ -41,10 +42,7 @@ def save_edge_to_db( # save the comparison data to the database - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - distance_table = DB.metadata.tables["distance"] + distance_table = DB.get_table("distance") # save the entry to the database statement = insert(distance_table).values( @@ -156,7 +154,7 @@ def save_edges_to_db( # if not DB.metadata: # raise RuntimeError("DB.metadata is None") -# distance_table = DB.metadata.tables["distance"] +# distance_table = DB.get_table("distance") # distance_query = distance_table.select().where( # distance_table.c.record_a_id.in_(region_ids) # & distance_table.c.record_b_id.in_(region_ids) @@ -197,22 +195,22 @@ def get_edge_param_id(run, weights) -> int: int: id of the edge param entry """ - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - alignment_mode = run["alignment_mode"] - edge_param_id = edge_params_query(alignment_mode, weights) + edge_param_id_res = edge_params_query(alignment_mode, weights) + + edge_param_id: int = edge_param_id_res[0] if edge_param_id_res else None if edge_param_id is None: - edge_param_id = edge_params_insert(alignment_mode, weights) + edge_param_id_res = edge_params_insert(alignment_mode, weights) + edge_param_id = edge_param_id_res[0] - logging.debug("Edge params id: %d", edge_param_id[0]) + logging.debug("Edge params id: %d", edge_param_id) - return edge_param_id[0] + return edge_param_id -def edge_params_query(alignment_mode, weights): +def edge_params_query(alignment_mode, weights) -> Optional[Row]: """Create and run a query for edge params Args: @@ -226,10 +224,8 @@ def edge_params_query(alignment_mode, weights): Row | None: cursor result """ - if not DB.metadata: - raise RuntimeError("DB.metadata is None") + edge_params_table = DB.get_table("edge_params") - edge_params_table = DB.metadata.tables["edge_params"] edge_params_query = ( select(edge_params_table.c.id) .where(edge_params_table.c.alignment_mode == alignment_mode.name) @@ -277,12 +273,20 @@ def get_edge_weight(edge_param_id: int) -> str: if DB.metadata is None: raise RuntimeError("DB.metadata is None") - edge_params_table = DB.metadata.tables["edge_params"] + edge_params_table = DB.get_table("edge_params") edge_weight_query = select(edge_params_table.c.weights).where( edge_params_table.c.id == edge_param_id ) - weights = DB.execute(edge_weight_query).fetchone()[0] + row = DB.execute(edge_weight_query).fetchone() + + if row is None: + raise RuntimeError("No edge weights found") + + weights = row[0] + + if not isinstance(weights, str): + raise TypeError(f"Unexpected type for weights: {type(weights)}") return weights diff --git a/big_scape/comparison/workflow.py b/big_scape/comparison/workflow.py index 6d909d5a..2e37d2e4 100644 --- a/big_scape/comparison/workflow.py +++ b/big_scape/comparison/workflow.py @@ -14,8 +14,9 @@ from concurrent.futures import ProcessPoolExecutor, Future import platform from threading import Event, Condition -from typing import Generator, Callable, Optional, TypeVar, Union +from typing import Generator, Callable, Optional, TypeVar, Union, cast, no_type_check from math import ceil + from .record_pair import RecordPair # from dependencies @@ -120,7 +121,9 @@ def generate_edges( # prepare a process pool logging.debug("Using %d cores", cores) - pair_data: Union[tuple[int, int], tuple[BGCRecord, BGCRecord]] + pair_data: Generator[ + Union[tuple[int, int], tuple[BGCRecord, BGCRecord]], None, None + ] if platform.system() == "Darwin": logging.debug( "Running on %s: sending full records", @@ -366,6 +369,11 @@ def expand_pair(pair: RecordPair) -> bool: return False +# TODO: mypy is annoying here, partially because it's correct and partially because +# mypy is just annoying and dumb. I'm done being all safe and secure with this function +# refactor the entire thing to be slightly different depending on whether you're giving +# it full records or just ids, and then mypy will be happy +@no_type_check def calculate_scores_pair( data: tuple[ list[Union[tuple[int, int], tuple[BGCRecord, BGCRecord]]], @@ -375,8 +383,8 @@ def calculate_scores_pair( ] ) -> list[ tuple[ - Optional[int], - Optional[int], + int, + int, float, float, float, @@ -387,6 +395,11 @@ def calculate_scores_pair( ]: # pragma no cover """Calculate the scores for a list of pairs + Note that the input data can be in the form of either a list of database ids or a + list of full BGCRecord objects. This is because the function is designed to be run + in parallel, and the database ids are used to fetch the minimal information needed + to perform the distance calculations. This only works on non-mac systems. + Args: data (tuple[list[tuple[int, int]], str, str]): list of pairs, alignment mode, bin label @@ -396,21 +409,26 @@ def calculate_scores_pair( int, int, bool, str,]]: list of scores for each pair in the order as the input data list, including lcs and extension coordinates """ - data, alignment_mode, edge_param_id, weights_label = data + pairs, alignment_mode, edge_param_id, weights_label = data # convert database ids to minimal record objects - if isinstance(data[0][0], int): - pair_ids = data + pair_ids: list[tuple[int, int]] = [] + + # if first is int, assume all are ints + if isinstance(pairs[0][0], int): + pair_ids.extend(pairs) records = fetch_records_from_database(pair_ids) else: pair_ids = [] records = {} - for pair in data: - pair_ids.append((pair[0]._db_id, pair[1]._db_id)) - records[pair[0]._db_id] = pair[0] - records[pair[1]._db_id] = pair[1] + for record_pair in pairs: + pair_ids.append((record_pair[0]._db_id, record_pair[1]._db_id)) + records[record_pair[0]._db_id] = record_pair[0] + records[record_pair[1]._db_id] = record_pair[1] - results = [] + results: list[ + tuple[int, int, float, float, float, float, int, bs_comparison.ComparableRegion] + ] = [] # only relevant for when not working on mac -> records are fetched # from db and not passed as full objects @@ -436,8 +454,8 @@ def calculate_scores_pair( if jaccard == 0.0: results.append( ( - pair.record_a._db_id, - pair.record_b._db_id, + id_a, + id_b, 1.0, 0.0, 0.0, @@ -483,8 +501,8 @@ def calculate_scores_pair( results.append( ( - pair.record_a._db_id, - pair.record_b._db_id, + id_a, + id_b, distance, jaccard, adjacency, @@ -515,11 +533,11 @@ def fetch_records_from_database(pairs: list[tuple[int, int]]) -> dict[int, BGCRe if DB.metadata is None: raise RuntimeError("DB metadata is None!") - gbk_table = DB.metadata.tables["gbk"] - record_table = DB.metadata.tables["bgc_record"] - cds_table = DB.metadata.tables["cds"] - hsp_table = DB.metadata.tables["hsp"] - algn_table = DB.metadata.tables["hsp_alignment"] + gbk_table = DB.get_table("gbk") + record_table = DB.get_table("bgc_record") + cds_table = DB.get_table("cds") + hsp_table = DB.get_table("hsp") + algn_table = DB.get_table("hsp_alignment") # gather minimally needed information for distance calculation and object creation query_statement = ( diff --git a/big_scape/data/partial_task.py b/big_scape/data/partial_task.py index 21d3f43d..f7c9b146 100644 --- a/big_scape/data/partial_task.py +++ b/big_scape/data/partial_task.py @@ -61,10 +61,7 @@ def get_input_data_state(gbks: list[GBK]) -> bs_enums.INPUT_TASK: if distance_count == 0: return bs_enums.INPUT_TASK.NO_DATA - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - gbk_table = DB.metadata.tables["gbk"] + gbk_table = DB.get_table("gbk") # get set of gbks in database db_gbk_rows = DB.execute(gbk_table.select()).all() @@ -100,10 +97,7 @@ def get_missing_gbks(gbks: list[GBK]) -> list[GBK]: # dictionary of gbk path to gbk object gbk_dict = {str(gbk.hash): gbk for gbk in gbks} - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - gbk_table = DB.metadata.tables["gbk"] + gbk_table = DB.get_table("gbk") # get set of gbks in database db_gbk_rows = DB.execute(gbk_table.select()).all() @@ -160,10 +154,7 @@ def get_cds_to_scan(gbks: list[GBK]) -> list[CDS]: # get a list of database cds_ids that are present in the cds_scanned table - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - scanned_cds_table = DB.metadata.tables["scanned_cds"] + scanned_cds_table = DB.get_table("scanned_cds") select_query = select(scanned_cds_table.c.cds_id) scanned_cds_ids = set(DB.execute(select_query)) @@ -206,19 +197,13 @@ def get_comparison_data_state(gbks: list[GBK]) -> bs_enums.COMPARISON_TASK: # check if all record ids are present in the comparison region ids - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - bgc_record_table = DB.metadata.tables["bgc_record"] + bgc_record_table = DB.get_table("bgc_record") select_statement = select(bgc_record_table.c.id) record_ids = set(DB.execute(select_statement).fetchall()) - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - distance_table = DB.metadata.tables["distance"] + distance_table = DB.get_table("distance") select_statement = select(distance_table.c.record_a_id).distinct() @@ -232,38 +217,3 @@ def get_comparison_data_state(gbks: list[GBK]) -> bs_enums.COMPARISON_TASK: return bs_enums.COMPARISON_TASK.NEW_DATA return bs_enums.COMPARISON_TASK.ALL_DONE - - -# TODO: does not seem to be used -def get_missing_distances( - pair_generator: RecordPairGenerator, -) -> Generator[tuple[Optional[int], Optional[int]], None, None]: - """Get a generator of BGCPairs that are missing from a network - - Args: - network (BSNetwork): network to check - bin (BGCBin): bin to check - - Yields: - Generator[BGCPair]: generator of BGCPairs that are missing from the network - """ - - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - distance_table = DB.metadata.tables["distance"] - - # get all region._db_id in the bin - select_statement = ( - select(distance_table.c.record_a_id, distance_table.c.record_b_id) - .where(distance_table.c.record_a_id.in_(pair_generator.record_ids)) - .where(distance_table.c.record_b_id.in_(pair_generator.record_ids)) - ) - - # generate a set of tuples of region id pairs - existing_distances = set(DB.execute(select_statement).fetchall()) - - for pair in pair_generator.generate_pairs(): - # if the pair is not in the set of existing distances, yield it - if pair not in existing_distances and pair[::-1] not in existing_distances: - yield pair diff --git a/big_scape/data/sqlite.py b/big_scape/data/sqlite.py index c388daaf..b1ecb226 100644 --- a/big_scape/data/sqlite.py +++ b/big_scape/data/sqlite.py @@ -14,6 +14,7 @@ Select, Insert, CursorResult, + Table, create_engine, func, select, @@ -205,8 +206,9 @@ def load_from_disk(db_path: Path) -> None: DB.reflect() - page_count = raw_file_connection.execute("PRAGMA page_count;") - page_count = page_count.fetchone()[0] + page_count: float = raw_file_connection.execute( + "PRAGMA page_count;" + ).fetchone()[0] with tqdm.tqdm(total=page_count, unit="page", desc="Loading database") as t: @@ -287,14 +289,14 @@ def get_table_row_count(table_name: str) -> int: if not DB.opened(): raise DBClosedError() - if not DB.metadata: - raise RuntimeError("DB.metadata is None") + table_metadata = DB.get_table(table_name) - table_metadata = DB.metadata.tables[table_name] - return DB.execute( + row_count: int = DB.execute( select(func.count("*")).select_from(table_metadata) ).scalar_one() + return row_count + @staticmethod def get_table_row_batch( table_name: str, batch_size=100 @@ -312,10 +314,7 @@ def get_table_row_batch( if not DB.opened(): raise DBClosedError() - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - table_metadata = DB.metadata.tables[table_name] + table_metadata = DB.get_table(table_name) table_select = select(table_metadata) table_select = table_select.execution_options(stream_results=True) @@ -326,6 +325,16 @@ def get_table_row_batch( break yield tuple(rows) + @staticmethod + def get_table(table_name: str) -> Table: + if not DB.opened(): + raise DBClosedError() + + if not DB.metadata: + raise RuntimeError("DB.metadata is None") + + return DB.metadata.tables[table_name] + def read_schema(path: Path) -> list[str]: """Read an .sql schema from a file""" diff --git a/big_scape/distances/adjacency.py b/big_scape/distances/adjacency.py index 0e00b25e..73517211 100644 --- a/big_scape/distances/adjacency.py +++ b/big_scape/distances/adjacency.py @@ -1,13 +1,12 @@ """Contains code to calculate adjacency indexes of set pairs""" - # from other modules import logging from big_scape.comparison.record_pair import RecordPair from big_scape.hmm import HSP -def calc_ai_lists(list_a: list, list_b: list): +def calc_ai_lists(list_a: list, list_b: list) -> float: """Calculate the adjacency index of two lists, which is the Jaccard index of sets of neighbouring items in two sorted lists. diff --git a/big_scape/genbank/bgc_record.py b/big_scape/genbank/bgc_record.py index dc7333fe..13bac31c 100644 --- a/big_scape/genbank/bgc_record.py +++ b/big_scape/genbank/bgc_record.py @@ -62,6 +62,7 @@ def __init__( self.nt_stop = nt_stop self.product = product self.merged: bool = False + self.merged_number: Optional[str] = None # for database operations self._db_id: Optional[int] = None @@ -235,19 +236,15 @@ def save_record( query. Defaults to True. """ - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - bgc_record_table = DB.metadata.tables["bgc_record"] + bgc_record_table = DB.get_table("bgc_record") if not hasattr(self, "category"): self.category: Optional[str] = None - # why - if hasattr(self, "merged_number"): + if self.merged_number is not None: number = self.merged_number else: - number = self.number + number = str(self.number) contig_edge = None if self.contig_edge is not None: @@ -333,7 +330,7 @@ def parse_products(feature: SeqFeature) -> str: logging.error("product qualifier not found in feature!") raise InvalidGBKError() - products = feature.qualifiers["product"] + products: list[str] = feature.qualifiers["product"] # single product? just return it if len(products) == 1: diff --git a/big_scape/genbank/candidate_cluster.py b/big_scape/genbank/candidate_cluster.py index 01dca6f2..924ecbc9 100644 --- a/big_scape/genbank/candidate_cluster.py +++ b/big_scape/genbank/candidate_cluster.py @@ -171,10 +171,7 @@ def load_all(region_dict: dict[int, Region]): ids as keys. Used for reassembling the hierarchy """ - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - record_table = DB.metadata.tables["bgc_record"] + record_table = DB.get_table("bgc_record") candidate_cluster_select_query = ( record_table.select() diff --git a/big_scape/genbank/cds.py b/big_scape/genbank/cds.py index 7fa691db..7b00ae0a 100644 --- a/big_scape/genbank/cds.py +++ b/big_scape/genbank/cds.py @@ -129,7 +129,7 @@ def save(self, commit=True): if self.parent_gbk is not None and self.parent_gbk._db_id is not None: parent_gbk_id = self.parent_gbk._db_id - cds_table = DB.metadata.tables["cds"] + cds_table = DB.get_table("cds") insert_query = ( cds_table.insert() .returning(cds_table.c.id) @@ -331,10 +331,7 @@ def load_all(gbk_dict: dict[int, GBK]) -> None: as keys. Used for parenting """ - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - cds_table = DB.metadata.tables["cds"] + cds_table = DB.get_table("cds") region_select_query = ( cds_table.select() diff --git a/big_scape/genbank/gbk.py b/big_scape/genbank/gbk.py index a5285e56..072d38c1 100644 --- a/big_scape/genbank/gbk.py +++ b/big_scape/genbank/gbk.py @@ -147,16 +147,14 @@ def save(self, commit=True) -> None: this returns the id of the GBK row Arguments: - commit: commit immediately after executing the insert query""" - - if not DB.metadata: - raise RuntimeError("DB.metadata is None") + commit: commit immediately after executing the insert query + """ organism = self.metadata["organism"] taxonomy = self.metadata["taxonomy"] description = self.metadata["description"] - gbk_table = DB.metadata.tables["gbk"] + gbk_table = DB.get_table("gbk") insert_query = ( gbk_table.insert() .prefix_with("OR REPLACE") @@ -210,10 +208,7 @@ def load_all() -> list[GBK]: list[GBK]: _description_ """ - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - gbk_table = DB.metadata.tables["gbk"] + gbk_table = DB.get_table("gbk") select_query = ( gbk_table.select() @@ -263,10 +258,7 @@ def load_many(input_gbks: list[GBK]) -> list[GBK]: input_gbk_hashes = [gbk.hash for gbk in input_gbks] - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - gbk_table = DB.metadata.tables["gbk"] + gbk_table = DB.get_table("gbk") select_query = ( gbk_table.select() .add_columns( @@ -315,7 +307,7 @@ def get_as_version(gbk_seq_record: SeqRecord) -> str: """ try: - as_version = gbk_seq_record.annotations["structured_comment"][ + as_version: str = gbk_seq_record.annotations["structured_comment"][ "antiSMASH-Data" ]["Version"] except KeyError: diff --git a/big_scape/genbank/proto_cluster.py b/big_scape/genbank/proto_cluster.py index 8f1041f2..809d84f1 100644 --- a/big_scape/genbank/proto_cluster.py +++ b/big_scape/genbank/proto_cluster.py @@ -186,10 +186,7 @@ def load_all(candidate_cluster_dict: dict[int, CandidateCluster]): objects with database ids as keys. Used for reassembling the hierarchy """ - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - record_table = DB.metadata.tables["bgc_record"] + record_table = DB.get_table("bgc_record") protocluster_select_query = ( record_table.select() @@ -250,9 +247,9 @@ def load_all(candidate_cluster_dict: dict[int, CandidateCluster]): new_proto_cluster._db_id = result.id # add to parent CandidateCluster protocluster dict - parent_candidate_cluster.proto_clusters[ - result.record_number - ] = new_proto_cluster + parent_candidate_cluster.proto_clusters[result.record_number] = ( + new_proto_cluster + ) # add to dictionary protocluster_dict[result.id] = new_proto_cluster diff --git a/big_scape/genbank/proto_core.py b/big_scape/genbank/proto_core.py index 72372ef2..d2cc3bf2 100644 --- a/big_scape/genbank/proto_core.py +++ b/big_scape/genbank/proto_core.py @@ -117,10 +117,7 @@ def load_all(protocluster_dict: dict[int, ProtoCluster]): as keys. Used for parenting """ - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - record_table = DB.metadata.tables["bgc_record"] + record_table = DB.get_table("bgc_record") region_select_query = ( record_table.select() diff --git a/big_scape/genbank/region.py b/big_scape/genbank/region.py index ec9307cd..c9928beb 100644 --- a/big_scape/genbank/region.py +++ b/big_scape/genbank/region.py @@ -259,10 +259,7 @@ def load_all(gbk_dict: dict[int, GBK]) -> None: keys. Used for reassembling the hierarchy """ - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - record_table = DB.metadata.tables["bgc_record"] + record_table = DB.get_table("bgc_record") region_select_query = ( record_table.select() diff --git a/big_scape/hmm/hmmer.py b/big_scape/hmm/hmmer.py index 48cf985f..dcd74ffd 100644 --- a/big_scape/hmm/hmmer.py +++ b/big_scape/hmm/hmmer.py @@ -173,7 +173,7 @@ def set_hmm_scanned(cds: CDS) -> None: cds (CDS): cds to update state for """ - state_table = DB.metadata.tables["scanned_cds"] + state_table = DB.get_table("scanned_cds") upignore_statement = ( insert(state_table).prefix_with("OR IGNORE").values(cds_id=cds._db_id) diff --git a/big_scape/hmm/hsp.py b/big_scape/hmm/hsp.py index 557acba1..2824ace8 100644 --- a/big_scape/hmm/hsp.py +++ b/big_scape/hmm/hsp.py @@ -41,10 +41,7 @@ def save(self, commit=True) -> None: if self.cds is not None and self.cds._db_id is not None: parent_cds_id = self.cds._db_id - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - hsp_table = DB.metadata.tables["hsp"] + hsp_table = DB.get_table("hsp") insert_query = ( hsp_table.insert() .returning(hsp_table.c.id) @@ -113,12 +110,12 @@ def __eq__(self, __o: object) -> bool: # special case if we are comparing this to a string if isinstance(__o, str): # we do not care about version numbers in this comparison, so strip it - return self.domain == __o + return self.domain == __o # type: ignore if not isinstance(__o, HSP): raise NotImplementedError() - return __o.domain == self.domain + return __o.domain == self.domain # type: ignore def __hash__(self) -> int: return hash(self.domain) @@ -148,11 +145,8 @@ def load_all(cds_list: list[CDS]) -> None: """ cds_dict = {cds._db_id: cds for cds in cds_list} - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - hsp_table = DB.metadata.tables["hsp"] - hsp_alignment_table = DB.metadata.tables["hsp_alignment"] + hsp_table = DB.get_table("hsp") + hsp_alignment_table = DB.get_table("hsp_alignment") hsp_select_query = ( hsp_table.select() @@ -245,10 +239,7 @@ def save(self, commit=True) -> None: if self.hsp is not None and self.hsp._db_id is not None: parent_hsp_id = self.hsp._db_id - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - hsp_align_table = DB.metadata.tables["hsp_alignment"] + hsp_align_table = DB.get_table("hsp_alignment") insert_query = hsp_align_table.insert().values( hsp_id=parent_hsp_id, alignment=self.align_string, diff --git a/big_scape/network/families.py b/big_scape/network/families.py index 524c70b9..3d1203bf 100644 --- a/big_scape/network/families.py +++ b/big_scape/network/families.py @@ -102,7 +102,7 @@ def get_cc_edge_weight_std(connected_component) -> float: """ edge_weights = [edge[2] for edge in connected_component] - edge_std = np.std(edge_weights) + edge_std: float = np.std(edge_weights) edge_std = round(edge_std, 2) return edge_std @@ -203,8 +203,8 @@ def save_to_db(regions_families): regions_families (list[tuple[int, int, float]]): list of (region_id, family, cutoff) tuples """ - family_table = DB.metadata.tables["family"] - bgc_record_family_table = DB.metadata.tables["bgc_record_family"] + family_table = DB.get_table("family") + bgc_record_family_table = DB.get_table("bgc_record_family") for region_id, family, cutoff, bin_label in regions_families: # obtain unique family id if present @@ -241,8 +241,8 @@ def save_to_db(regions_families): def reset_db_family_tables(): """Clear previous family assignments from database""" - DB.execute(DB.metadata.tables["bgc_record_family"].delete()) - DB.execute(DB.metadata.tables["family"].delete()) + DB.execute(DB.get_table("bgc_record_family").delete()) + DB.execute(DB.get_table("family").delete()) def save_singletons(record_type: RECORD_TYPE, cutoff: float, bin_label: str) -> None: @@ -258,9 +258,9 @@ def save_singletons(record_type: RECORD_TYPE, cutoff: float, bin_label: str) -> if DB.metadata is None: raise RuntimeError("DB metadata is None!") - family_table = DB.metadata.tables["family"] - bgc_record_family_table = DB.metadata.tables["bgc_record_family"] - record_table = DB.metadata.tables["bgc_record"] + family_table = DB.get_table("family") + bgc_record_family_table = DB.get_table("bgc_record_family") + record_table = DB.get_table("bgc_record") singleton_query = ( select(record_table.c.id) diff --git a/big_scape/network/network.py b/big_scape/network/network.py index ec73dd23..768d4c93 100644 --- a/big_scape/network/network.py +++ b/big_scape/network/network.py @@ -44,7 +44,7 @@ def get_connected_components( Yields: Generator[list[tuple[int, int, float, float, float, float, int]], None, None]: - a generator yielding a list of edges for each connected component + a generator yielding a list of edges for each connected component """ # create a temporary table with the records to include @@ -64,9 +64,9 @@ def get_connected_components( logging.info(f"Found {len(cc_ids)} connected components") - if DB.metadata is None: - raise RuntimeError("DB.metadata is None") - distance_table = DB.metadata.tables["distance"] + distance_table = DB.get_table("distance") + + cc_table = DB.get_table("connected_component") # return connected components per connected component id # cc_ids will be repeated accross cutoffs and bins, so @@ -83,23 +83,19 @@ def get_connected_components( ).where( and_( distance_table.c.record_a_id.in_( - select(DB.metadata.tables["connected_component"].c.record_id).where( - DB.metadata.tables["connected_component"].c.id == cc_id, - DB.metadata.tables["connected_component"].c.cutoff == cutoff, - DB.metadata.tables["connected_component"].c.edge_param_id - == edge_param_id, - DB.metadata.tables["connected_component"].c.bin_label - == bin.label, + select(cc_table.c.record_id).where( + cc_table.c.id == cc_id, + cc_table.c.cutoff == cutoff, + cc_table.c.edge_param_id == edge_param_id, + cc_table.c.bin_label == bin.label, ) ), distance_table.c.record_b_id.in_( - select(DB.metadata.tables["connected_component"].c.record_id).where( - DB.metadata.tables["connected_component"].c.id == cc_id, - DB.metadata.tables["connected_component"].c.cutoff == cutoff, - DB.metadata.tables["connected_component"].c.edge_param_id - == edge_param_id, - DB.metadata.tables["connected_component"].c.bin_label - == bin.label, + select(cc_table.c.record_id).where( + cc_table.c.id == cc_id, + cc_table.c.cutoff == cutoff, + cc_table.c.edge_param_id == edge_param_id, + cc_table.c.bin_label == bin.label, ) ), distance_table.c.edge_param_id == edge_param_id, @@ -144,9 +140,7 @@ def generate_connected_components( seed_record (BGCRecord, optional): a seed record to start the connected component from. """ - if DB.metadata is None: - raise RuntimeError("DB.metadata is None") - distance_table = DB.metadata.tables["distance"] + distance_table = DB.get_table("distance") edge = get_random_edge( cutoff, edge_param_id, bin_label, temp_record_table, seed_record @@ -163,8 +157,14 @@ def generate_connected_components( seen = set() if DB.engine is None: - raise RuntimeError("DB.engine is None") - cursor = DB.engine.raw_connection().driver_connection.cursor() + raise RuntimeError("DB engine is None") + + driver = DB.engine.raw_connection().driver_connection + + if driver is None: + raise RuntimeError("Driver is None") + + cursor = driver.cursor() edge_count_query = select(func.count(distance_table.c.record_a_id)).where( and_( @@ -172,7 +172,13 @@ def generate_connected_components( distance_table.c.edge_param_id == edge_param_id, ) ) - num_edges = DB.execute(edge_count_query).fetchone()[0] + + result = DB.execute(edge_count_query).fetchone() + + if result is None: + raise ValueError("No result from query") + + num_edges = result[0] with tqdm.tqdm(total=num_edges, desc="Generating connected components") as t: while len(edges) > 0: @@ -241,17 +247,15 @@ def has_missing_cc_assignments( Args: cutoff (float): the distance cutoff edge_param_id (int): the edge parameter id - temp_table (Table, optional): a temporary table with the records to include in the connected - component. Defaults to None. + temp_table (Table, optional): a temporary table with the records to include + in the connected component. Defaults to None. Returns: bool: True if there are missing connected component assignments, False otherwise """ - if DB.metadata is None: - raise RuntimeError("DB.metadata is None") - distance_table = DB.metadata.tables["distance"] - cc_table = DB.metadata.tables["connected_component"] + distance_table = DB.get_table("distance") + cc_table = DB.get_table("connected_component") select_statement = ( select(func.count(distinct(distance_table.c.record_a_id))) @@ -269,7 +273,12 @@ def has_missing_cc_assignments( distance_table.c.record_a_id.in_(select(temp_record_table.c.record_id)) ) - num_missing = DB.execute(select_statement).fetchone()[0] + result = DB.execute(select_statement).fetchone() + + if result is None: + raise ValueError("No result from query") + + num_missing: int = result[0] return num_missing > 0 @@ -285,15 +294,13 @@ def get_connected_component_ids( Args: cutoff (float): the distance cutoff edge_param_id (int): the edge parameter id - temp_record_table (Table, optional): a temporary table with the records to include in the - connected component. Defaults to None. + temp_record_table (Table, optional): a temporary table with the records to + include in the connected component. Defaults to None. Returns: list[int]: a list of connected component ids """ - if DB.metadata is None: - raise RuntimeError("DB.metadata is None") - cc_table = DB.metadata.tables["connected_component"] + cc_table = DB.get_table("connected_component") select_statement = ( select(cc_table.c.id) .distinct() @@ -311,10 +318,10 @@ def get_connected_component_ids( cc_table.c.record_id.in_(select(temp_record_table.c.record_id)) ) - cc_ids = DB.execute(select_statement).fetchall() - # returned as tuples, convert to list - cc_ids = [cc_id[0] for cc_id in cc_ids] + cc_ids: list[int] = cast( + list[int], [cc_id[0] for cc_id in DB.execute(select_statement).fetchall()] + ) return cc_ids @@ -335,16 +342,15 @@ def get_random_edge( Args: cutoff: the distance cutoff edge_param_id: the edge parameter id - temp_record_table (Table, optional): a temporary table with the records to include in the - connected component. Defaults to None. + temp_record_table (Table, optional): a temporary table with the records to + include in the connected component. Defaults to None. Returns: Optional[tuple[int, int]]: a tuple with the record ids of the edge or None """ - if DB.metadata is None: - raise RuntimeError("DB.metadata is None") - distance_table = DB.metadata.tables["distance"] - cc_table = DB.metadata.tables["connected_component"] + + distance_table = DB.get_table("distance") + cc_table = DB.get_table("connected_component") random_edge_query = ( # select edge as just record ids @@ -384,6 +390,7 @@ def get_random_edge( cc_table.c.bin_label == bin_label, ) ), + # and where the edge has a distance less than the cutoff and the edge param id is the same ) if temp_record_table is not None: @@ -394,7 +401,9 @@ def get_random_edge( ) ) - edge = DB.execute(random_edge_query).fetchone() + edge: Optional[tuple[int, int]] = cast( + Optional[tuple[int, int]], DB.execute(random_edge_query).fetchone() + ) return edge @@ -421,10 +430,9 @@ def get_cc_edges( Returns: Optional[tuple[int, int]]: a tuple with the record ids of the edge or none """ - if DB.metadata is None: - raise RuntimeError("DB.metadata is None") - distance_table = DB.metadata.tables["distance"] - cc_table = DB.metadata.tables["connected_component"] + + distance_table = DB.get_table("distance") + cc_table = DB.get_table("connected_component") cc_edge_query = select( distance_table.c.record_a_id, distance_table.c.record_b_id @@ -457,7 +465,7 @@ def get_cc_edges( ) ) - edges = DB.execute(cc_edge_query).fetchall() + edges = cast(list[tuple[int, int]], DB.execute(cc_edge_query).fetchall()) return edges @@ -477,9 +485,7 @@ def get_edge( # fetch an edge from the database - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - distance_table = DB.metadata.tables["distance"] + distance_table = DB.get_table("distance") select_statment = ( select( distance_table.c.record_a_id, @@ -515,10 +521,7 @@ def get_edges( # fetch edges from the database - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - distance_table = DB.metadata.tables["distance"] + distance_table = DB.get_table("distance") select_statement = ( select( distance_table.c.record_a_id, @@ -593,20 +596,27 @@ def create_temp_record_table(include_records: list[BGCRecord]) -> Table: # generate a short random string temp_table_name = "temp_" + "".join(random.choices(string.ascii_lowercase, k=10)) + bgc_record_table = DB.get_table("bgc_record") - temp_table = Table( + Table( temp_table_name, DB.metadata, Column( "record_id", Integer, - ForeignKey(DB.metadata.tables["bgc_record"].c.id), + ForeignKey(bgc_record_table.c.id), primary_key=True, nullable=False, ), prefixes=["TEMPORARY"], ) + if DB.metadata is None: + raise RuntimeError( + "DB metadata is None. This should never happen at this point " + "but it means that the database was somehow uninitialized" + ) + DB.metadata.create_all(DB.engine) # create_temp_table = f""" @@ -642,7 +652,7 @@ def create_temp_record_table(include_records: list[BGCRecord]) -> Table: def reset_db_connected_components_table(): """Removes any data from the connected component table""" - DB.execute(DB.metadata.tables["connected_component"].delete()) + DB.execute(DB.get_table("connected_component").delete()) def reference_only_connected_component(connected_component, bgc_records) -> bool: @@ -689,7 +699,7 @@ def get_connected_component_id(connected_component, cutoff, edge_param_id) -> in record_id = connected_component[0][0] - cc_table = DB.metadata.tables["connected_component"] + cc_table = DB.get_table("connected_component") select_statement = ( select(cc_table.c.id) @@ -704,9 +714,9 @@ def get_connected_component_id(connected_component, cutoff, edge_param_id) -> in .limit(1) ) - cc_ids = DB.execute(select_statement).fetchone() + cc_id = int(DB.execute(select_statement).fetchone()[0]) - return cc_ids[0] + return cc_id def remove_connected_component(connected_component, cutoff, edge_param_id) -> None: @@ -717,7 +727,7 @@ def remove_connected_component(connected_component, cutoff, edge_param_id) -> No cc_id = get_connected_component_id(connected_component, cutoff, edge_param_id) - cc_table = DB.metadata.tables["connected_component"] + cc_table = DB.get_table("connected_component") delete_statement = delete(cc_table).where( cc_table.c.id == cc_id, diff --git a/big_scape/network/utility.py b/big_scape/network/utility.py index ab3a9dfd..dc53ae80 100644 --- a/big_scape/network/utility.py +++ b/big_scape/network/utility.py @@ -15,6 +15,7 @@ def sim_matrix_from_graph(graph: nx.Graph, edge_property: str) -> np.ndarray: Returns: ndarray: _description_ """ + matrix: np.ndarray matrix = nx.to_numpy_array(graph, weight=edge_property, nonedge=1.0) # have to convert from distances to similarity matrix = 1 - matrix @@ -60,7 +61,7 @@ def edge_list_to_adj_list( return adj_list -def adj_list_to_sim_matrix(adj_list: dict[int, dict[int, float]]) -> np.ndarray: +def adj_list_to_sim_matrix(adj_list: dict[int, dict[int, float]]) -> list[list[float]]: """Return a similarity matrix from an adjacency list Adjacency list is expected to be a dictionary of dictionaries, where the keys of @@ -78,6 +79,7 @@ def adj_list_to_sim_matrix(adj_list: dict[int, dict[int, float]]) -> np.ndarray: np.ndarray: similarity matrix """ # set up the matrix + matrix: np.ndarray matrix = np.zeros((len(adj_list), len(adj_list))) # set up a dictionary to map region ids to matrix indices @@ -92,7 +94,9 @@ def adj_list_to_sim_matrix(adj_list: dict[int, dict[int, float]]) -> np.ndarray: b_matrix_idx = region_to_index[record_b] matrix[a_matrix_idx][b_matrix_idx] = 1 - adj_list[record_a][record_b] - return matrix.tolist() + sim_matrix_list: list[list[float]] = matrix.tolist() + + return sim_matrix_list def edge_list_to_sim_matrix( diff --git a/big_scape/output/legacy_output.py b/big_scape/output/legacy_output.py index cb9514f3..23f596da 100644 --- a/big_scape/output/legacy_output.py +++ b/big_scape/output/legacy_output.py @@ -1,4 +1,8 @@ -"""Contains functions to mimic legacy output as seen in BiG-SCAPE 1.0""" +"""Contains functions to mimic legacy output as seen in BiG-SCAPE 1.0 + +NOTE: here be dragons. This code is long, complex, and contains a whole bunch of +hacky workarounds to make the output of BiG-SCAPE 2.0 look like the output of BiG-SCAPE 1.0. +""" # from python from itertools import repeat @@ -389,7 +393,7 @@ def read_bigscape_results_js(bigscape_results_js_path: Path) -> list[Any]: # last one has a semicolon at the end. remove it lines[-1] = lines[-1][:1] - return json.loads("".join(lines)) + return json.loads("".join(lines)) # type: ignore def read_run_data_js(run_data_js_path: Path) -> dict[str, Any]: @@ -419,7 +423,7 @@ def read_run_data_js(run_data_js_path: Path) -> dict[str, Any]: # last line after that has a semicolon at the end. remove the semicolon lines[-1] = lines[-1][:1] - return json.loads("".join(lines)) + return json.loads("".join(lines)) # type: ignore def generate_bigscape_results_js(output_dir: Path, label: str, cutoff: float) -> None: @@ -760,7 +764,7 @@ def fetch_lcs_from_db(a_id: int, b_id: int, edge_param_id: int) -> dict[str, Any """ if DB.metadata is None: raise RuntimeError("Database metadata is None!") - dist_table = DB.metadata.tables["distance"] + dist_table = DB.get_table("distance") select_query = ( dist_table.select() .where(dist_table.c.record_a_id != dist_table.c.record_b_id) @@ -1045,12 +1049,9 @@ def generate_bs_families_members( # get a dictionary of node id to family id node_family = {} - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - # get all families from the database - family_table = DB.metadata.tables["family"] - bgc_families_table = DB.metadata.tables["bgc_record_family"] + family_table = DB.get_table("family") + bgc_families_table = DB.get_table("bgc_record_family") select_statement = ( select( @@ -1316,10 +1317,8 @@ def write_record_annotations_file(run, cutoff, all_bgc_records) -> None: cutoff_path = output_files_root / f"{label}_c{cutoff}" record_annotations_path = cutoff_path / "record_annotations.tsv" - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - bgc_record_table = DB.metadata.tables["bgc_record"] - gbk_table = DB.metadata.tables["gbk"] + bgc_record_table = DB.get_table("bgc_record") + gbk_table = DB.get_table("gbk") record_categories = {} for record in all_bgc_records: @@ -1408,13 +1407,10 @@ def write_clustering_file(run, cutoff, pair_generator) -> None: pair_generator_path = cutoff_path / pair_generator.label clustering_file_path = pair_generator_path / f"{bin_label}_clustering_c{cutoff}.tsv" - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - gbk_table = DB.metadata.tables["gbk"] - bgc_record_table = DB.metadata.tables["bgc_record"] - family_table = DB.metadata.tables["family"] - rec_fam_table = DB.metadata.tables["bgc_record_family"] + gbk_table = DB.get_table("gbk") + bgc_record_table = DB.get_table("bgc_record") + family_table = DB.get_table("family") + rec_fam_table = DB.get_table("bgc_record_family") record_ids = pair_generator.record_ids select_statement = ( @@ -1592,26 +1588,7 @@ def write_cutoff_network_file( def get_cutoff_edgelist( run: dict, cutoff: float, pair_generator: RecordPairGenerator -) -> set[ - tuple[ - str, - str, - int, - str, - str, - int, - float, - float, - float, - float, - str, - int, - int, - int, - int, - str, - ] -]: +) -> set[Any]: """Generate the network egdelist for a given bin with edges above the cutoff Args: @@ -1621,16 +1598,13 @@ def get_cutoff_edgelist( RuntimeError: no database present Returns: - set: edgelist + set: edgelist. A set of a long tuple each corresponding to a row in the distance table """ - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - distance_table = DB.metadata.tables["distance"] - gbk_table = DB.metadata.tables["gbk"] - bgc_record_table = DB.metadata.tables["bgc_record"] - edge_params_table = DB.metadata.tables["edge_params"] + distance_table = DB.get_table("distance") + gbk_table = DB.get_table("gbk") + bgc_record_table = DB.get_table("bgc_record") + edge_params_table = DB.get_table("edge_params") bgc_record_a = alias(bgc_record_table) bgc_record_b = alias(bgc_record_table) @@ -1699,28 +1673,7 @@ def write_full_network_file(run: dict, all_bgc_records: list[BGCRecord]) -> None write_network_file(full_network_file_path, edgelist) -def get_full_network_edgelist( - run: dict, all_bgc_records: list -) -> set[ - tuple[ - str, - str, - int, - str, - str, - int, - float, - float, - float, - float, - str, - int, - int, - int, - int, - str, - ] -]: +def get_full_network_edgelist(run: dict, all_bgc_records: list) -> set[Any]: """Get all edges for the pairs of records in this run, for the weights relevant to this run Args: @@ -1730,7 +1683,7 @@ def get_full_network_edgelist( RuntimeError: no database present Returns: - set: edgelist + set: edgelist. A set of a long tuple each corresponding to a row in the distance table """ legacy_weights = [ @@ -1753,13 +1706,10 @@ def get_full_network_edgelist( record_ids = [record._db_id for record in all_bgc_records] - if not DB.metadata: - raise RuntimeError("DB.metadata is None") - - distance_table = DB.metadata.tables["distance"] - gbk_table = DB.metadata.tables["gbk"] - bgc_record_table = DB.metadata.tables["bgc_record"] - edge_params_table = DB.metadata.tables["edge_params"] + distance_table = DB.get_table("distance") + gbk_table = DB.get_table("gbk") + bgc_record_table = DB.get_table("bgc_record") + edge_params_table = DB.get_table("edge_params") bgc_record_a = alias(bgc_record_table) bgc_record_b = alias(bgc_record_table) diff --git a/big_scape/trees/newick_tree.py b/big_scape/trees/newick_tree.py index df826d14..a057682a 100644 --- a/big_scape/trees/newick_tree.py +++ b/big_scape/trees/newick_tree.py @@ -101,7 +101,8 @@ def process_newick_tree(tree_file: Path) -> str: # Noticed this could happen if the sequences are exactly # the same and all distances == 0 logging.debug("Unable to root at midpoint") - return tree.format("newick") + newick_tree: str = tree.format("newick") + return newick_tree def find_tree_domains( diff --git a/dev-requirements.txt b/dev-requirements.txt index d5e2d099..6b6cb91d 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -19,3 +19,6 @@ types-psutil networkx-stubs data-science-types types-tqdm + +# mypy +mypy==1.9.0 diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..47c59b32 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,4 @@ +[mypy] +warn_return_any = True +files=bigscape.py +ignore_missing_imports = True