Skip to content

Commit

Permalink
Setting strict chemprop version to avoid import errors
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed May 5, 2024
1 parent 29c0db3 commit f051c55
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 27 deletions.
File renamed without changes.
36 changes: 20 additions & 16 deletions admet_ai/web/app/drugbank.py → admet_ai/drugbank.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,36 @@
"""Defines functions for the DrugBank approved reference set."""
from collections import defaultdict
from functools import lru_cache
from pathlib import Path

import matplotlib
import pandas as pd

from admet_ai.admet_info import get_admet_id_to_name
from admet_ai.constants import (
DEFAULT_DRUGBANK_PATH,
DRUGBANK_ATC_NAME_PREFIX,
DRUGBANK_ATC_PREFIX,
DRUGBANK_DELIMITER,
DRUGBANK_ID_COLUMN,
DRUGBANK_NAME_COLUMN,
DRUGBANK_SMILES_COLUMN,
)
from admet_ai.web.app import app
from admet_ai.web.app.admet_info import get_admet_id_to_name

matplotlib.use("Agg")


DRUGBANK_DF = pd.DataFrame()
ATC_CODE_TO_DRUGBANK_INDICES: dict[str, list[int]] = {}


def load_drugbank() -> None:
"""Loads the reference set of DrugBank approved molecules with their model predictions."""
def load_drugbank(drugbank_path: Path = DEFAULT_DRUGBANK_PATH) -> None:
"""Loads the reference set of DrugBank approved molecules with their model predictions.
:param drugbank_path: The path to the DrugBank reference set.
"""
# Set up global variables
global DRUGBANK_DF, ATC_CODE_TO_DRUGBANK_INDICES

# Load DrugBank DataFrame
DRUGBANK_DF = pd.read_csv(app.config["DRUGBANK_PATH"])
DRUGBANK_DF = pd.read_csv(drugbank_path)

# Map ATC codes to all indices of the DRUGBANK_DF with that ATC code
atc_code_to_drugbank_indices = defaultdict(set)
Expand All @@ -55,6 +56,9 @@ def get_drugbank(atc_code: str | None = None) -> pd.DataFrame:
:param atc_code: The ATC code to filter by. If None or 'all', returns the entire DrugBank.
:return: A DataFrame containing the DrugBank reference set, optionally filtered by ATC code.
"""
if DRUGBANK_DF.empty:
load_drugbank()

if atc_code is None:
return DRUGBANK_DF

Expand All @@ -80,17 +84,17 @@ def get_drugbank_unique_atc_codes() -> list[str]:
:return: A list of unique ATC codes in the DrugBank reference set.
"""
drugbank = get_drugbank()

return sorted(
{
atc_code.lower()
for atc_column in [
column
for column in DRUGBANK_DF.columns
for column in drugbank.columns
if column.startswith(DRUGBANK_ATC_NAME_PREFIX)
]
for atc_codes in DRUGBANK_DF[atc_column]
.dropna()
.str.split(DRUGBANK_DELIMITER)
for atc_codes in drugbank[atc_column].dropna().str.split(DRUGBANK_DELIMITER)
for atc_code in atc_codes
}
)
Expand All @@ -102,16 +106,16 @@ def get_drugbank_tasks_ids() -> list[str]:
:return: A list of tasks (properties) predicted in the DrugBank reference set.
"""
drugbank = get_drugbank()

non_task_columns = [
DRUGBANK_ID_COLUMN,
DRUGBANK_NAME_COLUMN,
DRUGBANK_SMILES_COLUMN,
] + [
column
for column in DRUGBANK_DF.columns
if column.startswith(DRUGBANK_ATC_PREFIX)
column for column in drugbank.columns if column.startswith(DRUGBANK_ATC_PREFIX)
]
task_columns = set(DRUGBANK_DF.columns) - set(non_task_columns)
task_columns = set(drugbank.columns) - set(non_task_columns)
drugbank_task_ids = sorted(task_columns)

return drugbank_task_ids
Expand Down
24 changes: 14 additions & 10 deletions admet_ai/web/app/plot.py → admet_ai/plot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Defines functions for plotting for the ADMET-AI website."""
"""Defines functions for ADMET-AI plots."""
import re
from io import BytesIO

Expand All @@ -13,14 +13,21 @@
get_admet_id_to_units,
get_admet_name_to_id,
)
from admet_ai.web.app.drugbank import get_drugbank
from admet_ai.web.app.utils import string_to_latex_sup


SVG_WIDTH_PATTERN = re.compile(r"width=['\"]\d+(\.\d+)?[a-z]+['\"]")
SVG_HEIGHT_PATTERN = re.compile(r"height=['\"]\d+(\.\d+)?[a-z]+['\"]")


def string_to_latex_sup(string: str) -> str:
    """Rewrites caret-style exponents in a string as LaTeX superscripts.

    Every occurrence of ``^`` immediately followed by one or more digits is
    replaced by the same digits wrapped in ``$^{...}$``. A caret that is not
    directly followed by a digit (e.g. ``^-6`` or ``^x``) is left untouched.

    :param string: A string, possibly containing exponents written as ``^<digits>``.
    :return: The string with each such exponent converted to a LaTeX superscript.
    """
    # A literal caret followed by a captured run of digits.
    exponent_pattern = re.compile(r"\^(\d+)")

    # \1 re-inserts the captured digits inside the LaTeX superscript braces.
    return exponent_pattern.sub(r"$^{\1}$", string)


def replace_svg_dimensions(svg_content: str) -> str:
"""Replace the SVG width and height with 100%.
Expand All @@ -36,17 +43,17 @@ def replace_svg_dimensions(svg_content: str) -> str:

def plot_drugbank_reference(
preds_df: pd.DataFrame,
drugbank_df: pd.DataFrame,
x_property_name: str | None = None,
y_property_name: str | None = None,
atc_code: str | None = None,
max_molecule_num: int | None = None,
) -> str:
"""Creates a 2D scatter plot of the DrugBank reference set vs the new set of molecules on two properties.
:param preds_df: A DataFrame containing the predictions on the new molecules.
:param drugbank_df: A DataFrame containing the DrugBank reference set.
:param x_property_name: The name of the property to plot on the x-axis.
:param y_property_name: The name of the property to plot on the y-axis.
:param atc_code: The ATC code to filter the DrugBank reference set by.
:param max_molecule_num: If provided, will display molecule numbers up to this number.
:return: A string containing the SVG of the plot.
"""
Expand All @@ -57,18 +64,15 @@ def plot_drugbank_reference(
if y_property_name is None:
y_property_name = "Clinical Toxicity"

# Get DrugBank reference, optionally filtered ATC code
drugbank = get_drugbank(atc_code=atc_code)

# Map property names to IDs
admet_name_to_id = get_admet_name_to_id()
x_property_id = admet_name_to_id[x_property_name]
y_property_id = admet_name_to_id[y_property_name]

# Scatter plot of DrugBank molecules with histogram marginals
sns.jointplot(
x=drugbank[x_property_id],
y=drugbank[y_property_id],
x=drugbank_df[x_property_id],
y=drugbank_df[y_property_id],
kind="scatter",
marginal_kws=dict(bins=50, fill=True),
label="DrugBank Reference",
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
},
install_requires=[
"chemfunc>=1.0.4",
"chemprop>=1.6.1",
"chemprop==1.6.1",
"numpy",
"pandas",
"rdkit>=2023.3.3",
Expand Down

0 comments on commit f051c55

Please sign in to comment.