Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hotfix/extend strat db #179

Merged
merged 3 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions big_scape/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
validate_binning_cluster_workflow,
validate_binning_query_workflow,
validate_alignment_mode,
validate_extend_strategy,
validate_includelist_all,
validate_includelist_any,
validate_gcf_cutoffs,
Expand All @@ -23,6 +24,7 @@
"validate_binning_cluster_workflow",
"validate_binning_query_workflow",
"validate_alignment_mode",
"validate_extend_strategy",
"validate_includelist_all",
"validate_includelist_any",
"validate_gcf_cutoffs",
Expand Down
2 changes: 2 additions & 0 deletions big_scape/cli/cli_common_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
validate_not_empty_dir,
validate_input_mode,
validate_alignment_mode,
validate_extend_strategy,
validate_includelist_all,
validate_includelist_any,
validate_gcf_cutoffs,
Expand Down Expand Up @@ -307,6 +308,7 @@ def common_cluster_query(fn):
"--extend_strategy",
type=click.Choice(["legacy", "greedy"]),
default="legacy",
callback=validate_extend_strategy,
help="Strategy to extend BGCs. 'legacy' will use the original BiG-SCAPE extension strategy, "
"while 'greedy' will use a new greedy extension strategy. (default: legacy).",
),
Expand Down
12 changes: 12 additions & 0 deletions big_scape/cli/cli_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,18 @@ def validate_alignment_mode(
return None


def validate_extend_strategy(
    ctx, param, extend_strategy
) -> Optional[bs_enums.EXTEND_STRATEGY]:
    """Convert the passed extend strategy string into its enum member

    Used as the click callback of the --extend_strategy option.

    Args:
        ctx: click context, not used here
        param: click parameter, not used here
        extend_strategy: raw value passed for the option

    Returns:
        Optional[bs_enums.EXTEND_STRATEGY]: the member whose value equals
        extend_strategy, or None if no member matches
    """
    # lazily walk the enum and resolve the first member whose value matches;
    # the member is looked up by its upper-cased value, i.e. by name
    matching_members = (
        bs_enums.EXTEND_STRATEGY[candidate.value.upper()]
        for candidate in bs_enums.EXTEND_STRATEGY
        if extend_strategy == candidate.value
    )
    return next(matching_members, None)


def validate_gcf_cutoffs(ctx, param, gcf_cutoffs) -> list[float]:
"""Validates range and formats into correct list[float] format"""

Expand Down
33 changes: 25 additions & 8 deletions big_scape/comparison/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sqlalchemy import insert, select

# from other modules
import big_scape.enums as bs_enums
from big_scape.data import DB

# from this module
Expand Down Expand Up @@ -183,7 +184,7 @@ def save_edges_to_db(
# yield pair, distance, jaccard, adjacency, dss, edge_param_id


def get_edge_param_id(run, weights) -> int:
def get_edge_param_id(run: dict, weights: str) -> int:
"""get edge params id if available, else create a new one

Args:
Expand All @@ -201,23 +202,29 @@ def get_edge_param_id(run, weights) -> int:
raise RuntimeError("DB.metadata is None")

alignment_mode = run["alignment_mode"]
extend_strategy = run["extend_strategy"]

edge_param_id = edge_params_query(alignment_mode, weights)
edge_param_id = edge_params_query(alignment_mode, weights, extend_strategy)

if edge_param_id is None:
edge_param_id = edge_params_insert(alignment_mode, weights)
edge_param_id = edge_params_insert(alignment_mode, weights, extend_strategy)

logging.debug("Edge params id: %d", edge_param_id[0])

return edge_param_id[0]


def edge_params_query(alignment_mode, weights):
def edge_params_query(
alignment_mode: bs_enums.ALIGNMENT_MODE,
weights: str,
extend_strategy: bs_enums.EXTEND_STRATEGY,
):
"""Create and run a query for edge params

Args:
alignment_mode (enum): global, glocal or auto
weights (str): weights category, i.e. "mix"
extend_strategy (enum): legacy, greedy

Raises:
RuntimeError: no database
Expand All @@ -233,6 +240,7 @@ def edge_params_query(alignment_mode, weights):
edge_params_query = (
select(edge_params_table.c.id)
.where(edge_params_table.c.alignment_mode == alignment_mode.name)
.where(edge_params_table.c.extend_strategy == extend_strategy.name)
.where(edge_params_table.c.weights == weights)
)

Expand All @@ -241,12 +249,17 @@ def edge_params_query(alignment_mode, weights):
return edge_param_id


def edge_params_insert(alignment_mode, weights):
def edge_params_insert(
alignment_mode: bs_enums.ALIGNMENT_MODE,
weights: str,
extend_strategy: bs_enums.EXTEND_STRATEGY,
):
"""Insert an edge param entry into the database

Args:
alignment_mode (_type_): global, glocal or auto
weights (_type_): weights category, i.e. "mix"
alignment_mode (enum): global, glocal or auto
weights (str): weights category, i.e. "mix"
extend_strategy (enum): legacy, greedy

Raises:
RuntimeError: no database
Expand All @@ -261,7 +274,11 @@ def edge_params_insert(alignment_mode, weights):
edge_params_table = DB.metadata.tables["edge_params"]
edge_params_insert = (
edge_params_table.insert()
.values(alignment_mode=alignment_mode.name, weights=weights)
.values(
alignment_mode=alignment_mode.name,
weights=weights,
extend_strategy=extend_strategy.name,
)
.returning(edge_params_table.c.id)
.compile()
)
Expand Down
27 changes: 13 additions & 14 deletions big_scape/comparison/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from .record_pair import RecordPair

# from dependencies
import click
from sqlalchemy import select

# from other modules
Expand Down Expand Up @@ -108,6 +107,7 @@ def get_batch_size(cores: int, desired_batch_size: int, num_items: int):
def generate_edges(
pair_generator: RecordPairGenerator,
alignment_mode: bs_enums.ALIGNMENT_MODE,
extend_strategy: bs_enums.EXTEND_STRATEGY,
cores: int,
max_queue_length: int,
callback: Optional[Callable] = None,
Expand Down Expand Up @@ -184,6 +184,7 @@ def on_complete(future: Future):
(
batch,
alignment_mode,
extend_strategy,
pair_generator.edge_param_id,
pair_generator.weights,
),
Expand Down Expand Up @@ -294,7 +295,11 @@ def do_lcs_pair(pair: RecordPair) -> bool: # pragma no cover
return False


def expand_pair(pair: RecordPair, alignment_mode: bs_enums.ALIGNMENT_MODE) -> bool:
def expand_pair(
pair: RecordPair,
alignment_mode: bs_enums.ALIGNMENT_MODE,
extend_strategy: bs_enums.EXTEND_STRATEGY,
) -> bool:
"""Expand the pair

Args:
Expand All @@ -304,22 +309,15 @@ def expand_pair(pair: RecordPair, alignment_mode: bs_enums.ALIGNMENT_MODE) -> bo
Returns:
bool: True if the pair was extended, False if it was not
"""
# TODO: true arg means silent if there is no context. this is done so that unit
# tests don't complain. remove this and mock the context in unit tests instead
click_context = click.get_current_context(silent=True)

if not click_context:
raise RuntimeError("No click context found")

if click_context.obj["extend_strategy"] == "legacy":
if extend_strategy == bs_enums.EXTEND_STRATEGY.LEGACY:
extend(
pair,
BigscapeConfig.EXPAND_MATCH_SCORE,
BigscapeConfig.EXPAND_MISMATCH_SCORE,
BigscapeConfig.EXPAND_GAP_SCORE,
BigscapeConfig.EXPAND_MAX_MATCH_PERC,
)
if click_context.obj["extend_strategy"] == "greedy":
if extend_strategy == bs_enums.EXTEND_STRATEGY.GREEDY:
extend_greedy(pair)

# after local expansion, additionally expand shortest arms in glocal/auto
Expand Down Expand Up @@ -395,6 +393,7 @@ def calculate_scores_pair(
data: tuple[
list[Union[tuple[int, int], tuple[BGCRecord, BGCRecord]]],
bs_enums.ALIGNMENT_MODE,
bs_enums.EXTEND_STRATEGY,
int,
str,
]
Expand All @@ -414,14 +413,14 @@ def calculate_scores_pair(

Args:
data (tuple[list[tuple[int, int]], str, str]): list of pairs, alignment mode,
bin label
extend_strategy, edge_param_id, bin label

Returns:
list[tuple[int, int, float, float, float, float, int, int, int, int, int, int,
int, int, bool, str,]]: list of scores for each pair in the
order as the input data list, including lcs and extension coordinates
"""
data, alignment_mode, edge_param_id, weights_label = data
data, alignment_mode, extend_strategy, edge_param_id, weights_label = data

# convert database ids to minimal record objects
if isinstance(data[0][0], int):
Expand Down Expand Up @@ -491,7 +490,7 @@ def calculate_scores_pair(
):
needs_expand = do_lcs_pair(pair)
if needs_expand:
expand_pair(pair, alignment_mode)
expand_pair(pair, alignment_mode, extend_strategy)

if weights_label not in LEGACY_WEIGHTS:
bin_weights = LEGACY_WEIGHTS["mix"]["weights"]
Expand Down
3 changes: 2 additions & 1 deletion big_scape/data/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ CREATE TABLE IF NOT EXISTS edge_params (
id INTEGER PRIMARY KEY AUTOINCREMENT,
weights TEXT NOT NULL,
alignment_mode TEXT NOT NULL,
extend_strategy TEXT NOT NULL,
UNIQUE(id),
UNIQUE(weights, alignment_mode)
UNIQUE(weights, alignment_mode, extend_strategy)
);

CREATE INDEX IF NOT EXISTS record_id_index ON bgc_record(id);
Expand Down
1 change: 1 addition & 0 deletions big_scape/distances/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def callback(edges):
bs_comparison.generate_edges(
missing_edge_bin,
run["alignment_mode"],
run["extend_strategy"],
run["cores"],
run["cores"] * 2,
callback,
Expand Down
1 change: 1 addition & 0 deletions big_scape/distances/legacy_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def callback(edges):
bs_comparison.generate_edges(
missing_edge_bin,
run["alignment_mode"],
run["extend_strategy"],
run["cores"],
run["cores"] * 2,
callback,
Expand Down
1 change: 1 addition & 0 deletions big_scape/distances/mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def callback(edges):
bs_comparison.generate_edges(
missing_edge_bin,
run["alignment_mode"],
run["extend_strategy"],
run["cores"],
run["cores"] * 2,
callback,
Expand Down
1 change: 1 addition & 0 deletions big_scape/distances/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def callback(edges):
bs_comparison.generate_edges(
bin,
run["alignment_mode"],
run["extend_strategy"],
run["cores"],
run["cores"] * 2,
callback,
Expand Down
10 changes: 9 additions & 1 deletion big_scape/enums/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
"""Module containing code related to enums"""

from .input_parameters import INPUT_MODE
from .source_type import SOURCE_TYPE
from .partial_task import TASK, INPUT_TASK, HMM_TASK, COMPARISON_TASK
from .comparison import ALIGNMENT_MODE, LCS_MODE, COMPARISON_MODE, CLASSIFY_MODE
from .comparison import (
ALIGNMENT_MODE,
EXTEND_STRATEGY,
LCS_MODE,
COMPARISON_MODE,
CLASSIFY_MODE,
)
from .genbank import RECORD_TYPE

__all__ = [
Expand All @@ -13,6 +20,7 @@
"HMM_TASK",
"COMPARISON_TASK",
"ALIGNMENT_MODE",
"EXTEND_STRATEGY",
"LCS_MODE",
"COMPARISON_MODE",
"CLASSIFY_MODE",
Expand Down
5 changes: 5 additions & 0 deletions big_scape/enums/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ class ALIGNMENT_MODE(Enum):
AUTO = "auto"


class EXTEND_STRATEGY(Enum):
    """Strategies available for extending a record pair during comparison"""

    # the original BiG-SCAPE extension strategy
    LEGACY = "legacy"
    # the newer greedy extension strategy
    GREEDY = "greedy"


class COMPARISON_MODE(Enum):
CDS = "cds"
DOMAIN = "domain"
Expand Down
7 changes: 6 additions & 1 deletion test/comparison/test_binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
RecordPairGenerator,
ConnectedComponentPairGenerator,
QueryRecordPairGenerator,
QueryMissingRecordPairGenerator,
save_edge_to_db,
get_record_category,
get_legacy_weights_from_category,
Expand Down Expand Up @@ -325,6 +324,7 @@ def test_connected_component_pair_generator(self):

run = {
"alignment_mode": bs_enums.ALIGNMENT_MODE.AUTO,
"extend_strategy": bs_enums.EXTEND_STRATEGY.LEGACY,
"legacy_weights": True,
"classify": bs_enums.CLASSIFY_MODE.CATEGORY,
}
Expand Down Expand Up @@ -657,6 +657,7 @@ def test_mix_iter(self):
run = {
"record_type": bs_enums.RECORD_TYPE.REGION,
"alignment_mode": bs_enums.ALIGNMENT_MODE.AUTO,
"extend_strategy": bs_enums.EXTEND_STRATEGY.LEGACY,
}

new_bin = generate_mix_bin(bgc_list, run)
Expand Down Expand Up @@ -903,6 +904,7 @@ def test_as_class_bin_generator(self):

run_category_weights = {
"alignment_mode": bs_enums.ALIGNMENT_MODE.AUTO,
"extend_strategy": bs_enums.EXTEND_STRATEGY.LEGACY,
"legacy_weights": True,
"classify": bs_enums.CLASSIFY_MODE.CATEGORY,
"hybrids_off": False,
Expand All @@ -921,6 +923,7 @@ def test_get_edge_params_id_insert(self):
bs_data.DB.create_in_mem()
run = {
"alignment_mode": bs_enums.ALIGNMENT_MODE.AUTO,
"extend_strategy": bs_enums.EXTEND_STRATEGY.LEGACY,
"legacy_weights": True,
"classify": bs_enums.CLASSIFY_MODE.CATEGORY,
}
Expand All @@ -937,6 +940,7 @@ def test_get_edge_params_id_fetch(self):
bs_data.DB.create_in_mem()
run = {
"alignment_mode": bs_enums.ALIGNMENT_MODE.AUTO,
"extend_strategy": bs_enums.EXTEND_STRATEGY.LEGACY,
"legacy_weights": True,
"classify": bs_enums.CLASSIFY_MODE.CATEGORY,
}
Expand All @@ -955,6 +959,7 @@ def test_get_edge_weight(self):
bs_data.DB.create_in_mem()
run = {
"alignment_mode": bs_enums.ALIGNMENT_MODE.AUTO,
"extend_strategy": bs_enums.EXTEND_STRATEGY.LEGACY,
"legacy_weights": True,
"classify": bs_enums.CLASSIFY_MODE.CATEGORY,
}
Expand Down
Loading
Loading