Chore/mypy #139

Closed · wants to merge 16 commits
big_scape/benchmarking/benchmark_metrics.py (7 changes: 6 additions & 1 deletion)

@@ -80,7 +80,12 @@ def calculate_v_measure(self) -> tuple[float, float, float]:
curated_fams = list(self.curated_labels.values())
computed_fams = list(self.computed_labels.values())

- return homogeneity_completeness_v_measure(curated_fams, computed_fams)
+ metrics: tuple[float, float, float] = homogeneity_completeness_v_measure(
+     curated_fams,
+     computed_fams,
+ )
+
+ return metrics

def calculate_purity(self) -> dict[str, float]:
"""Calculate purity P of each computed GCF
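The v-measure change above routes the sklearn return value through an explicitly annotated local. A plausible reading (an assumption, since the diff carries no commentary): homogeneity_completeness_v_measure has no usable type stub, so mypy sees Any, and the annotated intermediate pins the declared tuple type. A minimal self-contained sketch of the pattern:

from typing import Any


def untyped_library_call() -> Any:  # stand-in for the un-stubbed sklearn function
    return (1.0, 1.0, 1.0)


def calculate() -> tuple[float, float, float]:
    # without the annotation, mypy's --warn-return-any would flag the return
    metrics: tuple[float, float, float] = untyped_library_call()
    return metrics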
big_scape/benchmarking/benchmark_output.py (6 changes: 3 additions & 3 deletions)

@@ -172,7 +172,7 @@ def plot_per_cutoff(self, metrics: dict[str, dict[str, Any]]) -> None:
cutoffs_fl = list(map(float, cutoffs))

fig = plt.figure()
- ax = fig.gca()
+ ax = fig.gca()  # type: ignore

h = ax.plot(
cutoffs_fl,
@@ -225,8 +225,8 @@ def plot_conf_matrix_heatmap(
matrix_data: contains confusion matrix, row labels and column labels
"""
matrix, row_lab, col_lab = matrix_data
- plt.imshow(matrix, cmap="binary", interpolation=None)
- ax = plt.gca()
+ plt.imshow(matrix, cmap="binary", interpolation=None)  # type: ignore
+ ax = plt.gca()  # type: ignore
ax.set_xticks(
range(len(col_lab)),
labels=col_lab,
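The two ignores above are bare `# type: ignore` comments, which mute every error on their lines; presumably matplotlib's incomplete type hints trigger them. mypy also accepts error-code-scoped ignores, sketched below with an illustrative code that is not taken from this PR:

def legacy():  # deliberately unannotated
    return 1


a = legacy()  # type: ignore[no-untyped-call]  # scoped: other errors still surface
b = legacy()  # type: ignore  # bare: hides every error on this line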
big_scape/cli/__init__.py (2 changes: 0 additions & 2 deletions)

@@ -5,7 +5,6 @@
validate_input_mode,
validate_binning_cluster_workflow,
validate_binning_query_workflow,
- validate_skip_hmmscan,
validate_alignment_mode,
validate_includelist,
validate_gcf_cutoffs,
@@ -21,7 +20,6 @@
"validate_input_mode",
"validate_binning_cluster_workflow",
"validate_binning_query_workflow",
"validate_skip_hmmscan",
"validate_alignment_mode",
"validate_includelist",
"validate_gcf_cutoffs",
big_scape/cli/benchmark_cli.py (3 changes: 3 additions & 0 deletions)

@@ -1,6 +1,7 @@
""" Click parameters for the BiG-SCAPE Benchmark CLI command """

# from python
+ from typing import no_type_check
import click
from pathlib import Path

@@ -13,6 +14,7 @@
from .cli_validations import set_start, validate_output_paths


+ @no_type_check
# BiG-SCAPE benchmark mode
@click.command()
@common_all
@@ -26,6 +28,7 @@
"a run output to these assignments."
),
)
+ @no_type_check
@click.option(
"--BiG_dir",
type=click.Path(exists=True, file_okay=False, path_type=Path),
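typing.no_type_check (imported above) tells the type checker to skip the decorated object entirely, a common stopgap for heavily decorated Click commands while annotations are retrofitted. A minimal sketch:

from typing import no_type_check


@no_type_check
def handler(ctx, param, value):
    # mypy skips this body entirely, so unannotated Click plumbing
    # does not produce errors here
    return value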
big_scape/cli/cli_validations.py (42 changes: 15 additions & 27 deletions)

@@ -36,7 +36,7 @@ def set_start(param_dict) -> None:
# meta parameter validations


- def validate_profiling(ctx, param, profiling) -> bool:
+ def validate_profiling(ctx, param, profiling: bool) -> bool:
"""Checks whether multithreading is possible, and therefore whether profiling can happen"""

if profiling and platform.system() == "Darwin":
@@ -48,7 +48,7 @@ def validate_profiling(ctx, param, profiling) -> bool:
# input parameter validations


- def validate_not_empty_dir(ctx, param, dir) -> Path:
+ def validate_not_empty_dir(ctx, param, dir: Path) -> Path:
"""Validates that a given directory is not empty.
Raises a BadParameter"""

@@ -57,6 +57,7 @@ def validate_not_empty_dir(ctx, param, dir) -> Path:
if len(contents) == 0:
logging.error(f"{dir}/ directory is empty!")
raise click.BadParameter(f"{dir}/ directory is empty!")

return dir


@@ -75,7 +76,7 @@ def validate_input_mode(ctx, param, input_mode) -> Optional[bs_enums.INPUT_MODE]
return None


- def validate_query_bgc(ctx, param, query_bgc_path) -> Path:
+ def validate_query_bgc(ctx, param, query_bgc_path: Path) -> Path:
"""Raises an InvalidArgumentError if the query bgc path does not exist"""

if query_bgc_path.suffix != ".gbk":
@@ -88,7 +89,7 @@ def validate_query_bgc(ctx, param, query_bgc_path) -> Path:
# output parameter validations


- def validate_output_dir(ctx, param, output_dir) -> Path:
+ def validate_output_dir(ctx, param, output_dir: Path) -> Path:
"""Validates that output directory exists"""

if not output_dir.exists():
@@ -131,21 +132,23 @@ def validate_output_paths(ctx) -> None:
# comparison validations


- def validate_classify(ctx, param, classify) -> Optional[bs_enums.CLASSIFY_MODE]:
+ def validate_classify(
+     ctx, param, classify: Optional[bs_enums.CLASSIFY_MODE]
+ ) -> Optional[bs_enums.CLASSIFY_MODE]:
"""Validates whether the classification type is set, and if not
sets the parameter to False"""

# check if the property matches one of the enum values
valid_modes = [mode.value for mode in bs_enums.CLASSIFY_MODE]

+ if not classify:
+     return None

for mode in valid_modes:
if classify == mode:
return bs_enums.CLASSIFY_MODE[mode.upper()]

- if classify is None:
-     classify = False
-
- return classify
+ raise InvalidArgumentError("--classify", classify)


def validate_alignment_mode(
@@ -187,7 +190,9 @@ def validate_gcf_cutoffs(ctx, param, gcf_cutoffs) -> list[float]:
def validate_filter_gbk(ctx, param, filter_str) -> list[str]:
"""Validates and formats the filter string and returns a list of strings"""

- return filter_str.split(",")
+ split_str: list[str] = filter_str.split(",")
+
+ return split_str


# hmmer parameters
@@ -317,23 +322,6 @@ def validate_binning_query_workflow(ctx) -> None:
)


- def validate_skip_hmmscan(ctx) -> None:
-     """Validates whether a BiG-SCAPE db exists when running skip_hmm, which
-     requires already processed gbk files and hence a DB in output"""
-
-     if ctx.obj["skip_hmmscan"] and ctx.obj["db_path"] is None:
-         logging.error(
-             "Missing option '--db_path'."
-             "BiG-SCAPE database has not been given, skip_hmmscan requires "
-             "a DB of already processed gbk files."
-         )
-         raise click.UsageError(
-             "Missing option '--db_path'."
-             "BiG-SCAPE database has not been given, skip_hmmscan requires "
-             "a DB of already processed gbk files."
-         )


def validate_pfam_path(ctx) -> None:
"""Validates whether a BiG-SCAPE db exists when pfam_path is not provided,
which requires already processed gbk files and hence a DB in output"""
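The reworked validate_classify above replaces the old fall-through (classify = False) with an explicit InvalidArgumentError when no enum value matches. A generic Click-callback sketch of that validate-or-raise shape, with a hypothetical enum standing in for bs_enums.CLASSIFY_MODE:

import enum
from typing import Optional

import click


class Mode(enum.Enum):  # hypothetical stand-in for bs_enums.CLASSIFY_MODE
    CLASS = "class"
    CATEGORY = "category"


def validate_mode(ctx, param, value: Optional[str]) -> Optional[Mode]:
    if not value:
        return None  # option not supplied: nothing to validate
    try:
        return Mode(value)  # match against the enum's values
    except ValueError:
        raise click.BadParameter(f"unknown mode: {value!r}")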
big_scape/cli/cluster_cli.py (2 changes: 0 additions & 2 deletions)

@@ -12,7 +12,6 @@
from .cli_validations import (
validate_output_paths,
validate_binning_cluster_workflow,
- validate_skip_hmmscan,
validate_pfam_path,
set_start,
)
@@ -58,7 +57,6 @@ def cluster(ctx, *args, **kwargs):

# workflow validations
validate_binning_cluster_workflow(ctx)
- validate_skip_hmmscan(ctx)
validate_pfam_path(ctx)
validate_output_paths(ctx)

big_scape/cli/query_cli.py (3 changes: 2 additions & 1 deletion)

@@ -1,6 +1,7 @@
""" Click parameters for the BiG-SCAPE Query CLI command """

# from python
+ from typing import no_type_check
import click
from pathlib import Path

@@ -12,7 +13,6 @@
from .cli_common_options import common_all, common_cluster_query
from .cli_validations import (
validate_output_paths,
- validate_skip_hmmscan,
validate_query_bgc,
validate_pfam_path,
set_start,
@@ -21,6 +21,7 @@
)


+ @no_type_check
@click.command()
@common_all
@common_cluster_query
big_scape/comparison/binning.py (55 changes: 28 additions & 27 deletions)

@@ -14,7 +14,7 @@
from __future__ import annotations
import logging
from itertools import combinations
- from typing import Generator, Iterator, Optional
+ from typing import Generator, Iterator, Optional, cast
from sqlalchemy import and_, select, func, or_

# from other modules
@@ -29,6 +29,7 @@
)
from big_scape.enums import SOURCE_TYPE, CLASSIFY_MODE, RECORD_TYPE

+ # from this module
import big_scape.comparison as bs_comparison


@@ -126,9 +127,8 @@ def num_pairs(self) -> int:
# records from the same gbk will not be compared -> will not be a pair. Use
# the database to find how many subrecords come from the same genbank, i.e.
# how many pairs should be removed
- if not DB.metadata:
-     raise RuntimeError("DB metadata is None!")
- record_table = DB.metadata.tables["bgc_record"]
+ record_table = DB.get_table("bgc_record")

# find a collection of gbks with more than one subrecord
member_table = (
@@ -179,10 +179,7 @@ def cull_singletons(self, cutoff: float, ref_only: bool = False):
RuntimeError: DB.metadata is None
"""

- if not DB.metadata:
-     raise RuntimeError("DB.metadata is None")
-
- distance_table = DB.metadata.tables["distance"]
+ distance_table = DB.get_table("distance")

# get all distances/edges in the table for the records in this bin and
# with distances below the cutoff
Expand All @@ -203,16 +200,25 @@ def cull_singletons(self, cutoff: float, ref_only: bool = False):

if ref_only:
singleton_record_ids = self.record_ids - edge_record_ids
- self.source_records = [
-     record
-     for record in self.source_records
-     if (record._db_id in edge_record_ids)
-     or (
-         record._db_id in singleton_record_ids
-         and record.parent_gbk.source_type != SOURCE_TYPE.REFERENCE
-     )
- ]
- self.record_ids = {record._db_id for record in self.source_records}
+ new_records: list[BGCRecord] = []
+ for record in self.source_records:
+     if record.parent_gbk is None:
+         raise ValueError("Region in bin has no parent gbk!")
+     in_edge = record._db_id in edge_record_ids
+     in_singletons = record._db_id in singleton_record_ids
+     not_reference = record.parent_gbk.source_type != SOURCE_TYPE.REFERENCE
+
+     if in_edge or (in_singletons and not_reference):
+         new_records.append(record)
+
+ self.source_records = new_records
+
+ self.record_ids = {
+     record._db_id
+     for record in self.source_records
+     if record._db_id is not None
+ }

else:
self.record_ids = edge_record_ids
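The num_pairs comment earlier in this file subtracts same-GenBank pairs from the total pair count before comparison. A toy worked example of that arithmetic (illustrative data only):

from itertools import combinations

# (parent_gbk, record) pairs; gbk1 contributes two records
records = [("gbk1", "r1"), ("gbk1", "r2"), ("gbk2", "r3"), ("gbk3", "r4")]

all_pairs = len(records) * (len(records) - 1) // 2  # C(4, 2) = 6
same_gbk = sum(1 for a, b in combinations(records, 2) if a[0] == b[0])  # 1 pair
print(all_pairs - same_gbk)  # 5 pairs actually need a distance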
@@ -315,17 +321,14 @@ class MissingRecordPairGenerator(RecordPairGenerator):
already in the database
"""

- def __init__(self, pair_generator):
+ def __init__(self, pair_generator: RecordPairGenerator):
super().__init__(
pair_generator.label, pair_generator.edge_param_id, pair_generator.weights
)
self.bin: RecordPairGenerator = pair_generator

def num_pairs(self) -> int:
- if not DB.metadata:
-     raise RuntimeError("DB.metadata is None")
-
- distance_table = DB.metadata.tables["distance"]
+ distance_table = DB.get_table("distance")

# get all region._db_id in the bin where the record_a_id and record_b_id are in the
# bin
Expand All @@ -337,7 +340,7 @@ def num_pairs(self) -> int:
)

# get count
- existing_distance_count = DB.execute(select_statement).scalar_one()
+ existing_distance_count: int = DB.execute(select_statement).scalar_one()

# subtract from expected number of distances
return self.bin.num_pairs() - existing_distance_count
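DB.get_table replaces the per-call-site DB.metadata guard used before this PR. The helper itself is not shown in this diff, so the following is only an assumed sketch of what it centralizes:

from typing import Optional

from sqlalchemy import MetaData, Table


class DB:
    metadata: Optional[MetaData] = None  # populated once the schema is loaded

    @classmethod
    def get_table(cls, name: str) -> Table:
        # one shared guard instead of "if not DB.metadata: raise ..." everywhere
        if cls.metadata is None:
            raise RuntimeError("DB.metadata is None")
        return cls.metadata.tables[name]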
@@ -355,10 +358,8 @@ def generate_pairs(
Yields:
Generator[tuple[int, int]]: Generator for record pairs in this bin
"""
- if not DB.metadata:
-     raise RuntimeError("DB.metadata is None")
-
- distance_table = DB.metadata.tables["distance"]
+ distance_table = DB.get_table("distance")

# get all region._db_id in the bin
select_statement = (
@@ -619,7 +620,7 @@ def num_pairs(self) -> int:
)
)

- existing_distance_count = DB.execute(select_statement).scalar_one()
+ existing_distance_count = cast(int, DB.execute(select_statement).scalar_one())

return self.bin.num_pairs() - existing_distance_count
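This file now narrows the Any returned by scalar_one() in two interchangeable ways: an annotated assignment (in MissingRecordPairGenerator.num_pairs) and typing.cast (in the later num_pairs just above). A stand-in sketch of the difference:

from typing import Any, cast


def scalar_one() -> Any:  # stand-in for DB.execute(...).scalar_one()
    return 42


count_a: int = scalar_one()        # annotated assignment: mypy trusts the hint
count_b = cast(int, scalar_one())  # cast: same effect, also no runtime check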

Expand Down
Loading
Loading