Skip to content

Commit

Permalink
Refactoring plotting code to enable local plotting of drugbank refere…
Browse files Browse the repository at this point in the history
…nce and radial plots
  • Loading branch information
swansonk14 committed May 5, 2024
1 parent f051c55 commit a64d978
Show file tree
Hide file tree
Showing 10 changed files with 194 additions and 73 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,9 @@ admet_web
```

Then navigate to http://127.0.0.1:5000 to view the website.

### Analysis plots

The DrugBank reference plot and radial plots displayed on the ADMET-AI website can be generated locally using the
`scripts/plot_drugbank_reference.py` and `scripts/plot_radial_summaries.py` scripts, respectively. Both scripts
take as input a CSV file with ADMET-AI predictions along with other parameters.
24 changes: 21 additions & 3 deletions admet_ai/admet_info.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Defines functions for ADMET info."""
from pathlib import Path

import pandas as pd

from admet_ai.web.app import app
from admet_ai.constants import DEFAULT_ADMET_PATH


ADMET_DF = pd.DataFrame()
Expand All @@ -10,13 +12,13 @@
ADMET_ID_TO_UNITS: dict[str, str] = {}


def load_admet_info() -> None:
def load_admet_info(admet_path: Path = DEFAULT_ADMET_PATH) -> None:
"""Loads the ADMET info."""
# Set up global variables
global ADMET_DF, ADMET_ID_TO_NAME, ADMET_ID_TO_UNITS, ADMET_NAME_TO_ID

# Load ADMET info DataFrame
ADMET_DF = pd.read_csv(app.config["ADMET_PATH"])
ADMET_DF = pd.read_csv(admet_path)

# Map ADMET IDs to names and vice versa
ADMET_ID_TO_NAME = dict(zip(ADMET_DF["id"], ADMET_DF["name"]))
Expand All @@ -26,6 +28,19 @@ def load_admet_info() -> None:
ADMET_ID_TO_UNITS = dict(zip(ADMET_DF["id"], ADMET_DF["units"]))


def lazy_load_admet_info(func: callable) -> callable:
    """Decorator to lazily load the ADMET info.

    The decorated function is called unchanged, but the ADMET info is first loaded with its
    default arguments if the module-level ADMET DataFrame is still empty.

    :param func: The function that requires the ADMET info to be loaded.
    :return: A wrapper around the function that loads the ADMET info on demand before calling it.
    """
    # Imported locally so that this function does not depend on the module's import section
    from functools import wraps

    # Preserve the name and docstring of the decorated function on the wrapper
    @wraps(func)
    def wrapper(*args, **kwargs):
        # An empty DataFrame is taken to mean that the ADMET info has not been loaded yet
        if ADMET_DF.empty:
            load_admet_info()

        return func(*args, **kwargs)

    return wrapper


@lazy_load_admet_info
def get_admet_info() -> pd.DataFrame:
"""Get the ADMET info.
Expand All @@ -34,6 +49,7 @@ def get_admet_info() -> pd.DataFrame:
return ADMET_DF


@lazy_load_admet_info
def get_admet_id_to_name() -> dict[str, str]:
"""Get the ADMET ID to name mapping.
Expand All @@ -42,6 +58,7 @@ def get_admet_id_to_name() -> dict[str, str]:
return ADMET_ID_TO_NAME


@lazy_load_admet_info
def get_admet_name_to_id() -> dict[str, str]:
"""Get the ADMET name to ID mapping.
Expand All @@ -50,6 +67,7 @@ def get_admet_name_to_id() -> dict[str, str]:
return ADMET_NAME_TO_ID


@lazy_load_admet_info
def get_admet_id_to_units() -> dict[str, str]:
"""Get the ADMET ID to units mapping.
Expand Down
57 changes: 18 additions & 39 deletions admet_ai/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,12 @@
from rdkit import Chem
from rdkit.Chem.Draw.rdMolDraw2D import MolDraw2DSVG

from admet_ai.web.app.admet_info import (
from admet_ai.admet_info import (
get_admet_id_to_units,
get_admet_name_to_id,
)


SVG_WIDTH_PATTERN = re.compile(r"width=['\"]\d+(\.\d+)?[a-z]+['\"]")
SVG_HEIGHT_PATTERN = re.compile(r"height=['\"]\d+(\.\d+)?[a-z]+['\"]")


def string_to_latex_sup(string: str) -> str:
"""Converts a string with an exponential to LaTeX superscript.
Expand All @@ -28,34 +24,23 @@ def string_to_latex_sup(string: str) -> str:
return re.sub(r"\^(\d+)", r"$^{\1}$", string)


def replace_svg_dimensions(svg_content: str) -> str:
    """Set the width and height attributes of an SVG to 100%.

    :param svg_content: The SVG content.
    :return: The SVG content with the width and height replaced with 100%.
    """
    # Substitute the width first and then the height on the result
    width_adjusted = SVG_WIDTH_PATTERN.sub('width="100%"', svg_content)

    return SVG_HEIGHT_PATTERN.sub('height="100%"', width_adjusted)


def plot_drugbank_reference(
preds_df: pd.DataFrame,
drugbank_df: pd.DataFrame,
x_property_name: str | None = None,
y_property_name: str | None = None,
max_molecule_num: int | None = None,
) -> str:
image_type: str = "svg",
) -> bytes:
"""Creates a 2D scatter plot of the DrugBank reference set vs the new set of molecules on two properties.
:param preds_df: A DataFrame containing the predictions on the new molecules.
:param drugbank_df: A DataFrame containing the DrugBank reference set.
:param x_property_name: The name of the property to plot on the x-axis.
:param y_property_name: The name of the property to plot on the y-axis.
:param max_molecule_num: If provided, will display molecule numbers up to this number.
:return: A string containing the SVG of the plot.
:param image_type: The image type for the plot (e.g., svg).
:return: Bytes containing the plot.
"""
# Set default values
if x_property_name is None:
Expand Down Expand Up @@ -130,26 +115,26 @@ def plot_drugbank_reference(

# Save plot to an in-memory buffer in the requested image format
buf = BytesIO()
plt.savefig(buf, format="svg", bbox_inches="tight")
plt.savefig(buf, format=image_type, bbox_inches="tight")
plt.close()
buf.seek(0)
drugbank_svg = buf.getvalue().decode("utf-8")

# Set the SVG width and height to 100%
drugbank_svg = replace_svg_dimensions(drugbank_svg)
plot = buf.getvalue()

return drugbank_svg
return plot


def plot_radial_summary(
property_id_to_percentile: dict[str, float], percentile_suffix: str = "",
) -> str:
property_id_to_percentile: dict[str, float],
percentile_suffix: str = "",
image_type: str = "svg",
) -> bytes:
"""Creates a radial plot summary of important properties of a molecule in terms of DrugBank approved percentiles.
:param property_id_to_percentile: A dictionary mapping property IDs to their DrugBank approved percentiles.
Keys are the property name along with the percentile_suffix.
:param percentile_suffix: The suffix to add to the property names to get the DrugBank approved percentiles.
:return: A string containing the SVG of the plot.
:param image_type: The image type for the plot (e.g., svg).
:return: Bytes containing the plot.
"""
# Set max percentile
max_percentile = 100
Expand Down Expand Up @@ -230,17 +215,14 @@ def plot_radial_summary(
# Ensure no text labels are cut off
plt.tight_layout()

# Save plot as svg to pass to frontend
# Save plot
buf = BytesIO()
plt.savefig(buf, format="svg")
plt.savefig(buf, format=image_type)
plt.close()
buf.seek(0)
radial_svg = buf.getvalue().decode("utf-8")
plot = buf.getvalue()

# Set the SVG width and height to 100%
radial_svg = replace_svg_dimensions(radial_svg)

return radial_svg
return plot


def plot_molecule_svg(mol: str | Chem.Mol) -> str:
Expand All @@ -259,7 +241,4 @@ def plot_molecule_svg(mol: str | Chem.Mol) -> str:
d.FinishDrawing()
smiles_svg = d.GetDrawingText()

# Set the SVG width and height to 100%
smiles_svg = replace_svg_dimensions(smiles_svg)

return smiles_svg
15 changes: 14 additions & 1 deletion admet_ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import pandas as pd


def load_and_preprocess_data(data_path: Path, smiles_column: str = "smiles") -> pd.DataFrame:
def load_and_preprocess_data(
data_path: Path, smiles_column: str = "smiles"
) -> pd.DataFrame:
"""Preprocess a dataset of molecules by removing missing SMILES and setting the SMILES as the index.
:param data_path: Path to a CSV file containing a dataset of molecules.
Expand All @@ -29,3 +31,14 @@ def load_and_preprocess_data(data_path: Path, smiles_column: str = "smiles") ->
data.set_index(smiles_column, inplace=True)

return data


def get_drugbank_suffix(atc_code: str | None) -> str:
"""Gets the DrugBank percentile suffix for the given ATC code.
:param atc_code: The ATC code.
"""
if atc_code is None:
return "drugbank_approved_percentile"

return f"drugbank_approved_{atc_code}_percentile"
2 changes: 0 additions & 2 deletions admet_ai/web/app/config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""Sets the config parameters for the ADMET-AI Flask app object."""
from admet_ai.constants import (
DEFAULT_ADMET_PATH,
DEFAULT_DRUGBANK_PATH,
DEFAULT_MODELS_DIR,
)


MODELS_DIR = DEFAULT_MODELS_DIR
ADMET_PATH = DEFAULT_ADMET_PATH
DRUGBANK_PATH = DEFAULT_DRUGBANK_PATH
LOW_PERFORMANCE_THRESHOLD = 0.6
NUM_WORKERS = 0
27 changes: 12 additions & 15 deletions admet_ai/web/app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from admet_ai.web.app import app


SVG_WIDTH_PATTERN = re.compile(r"width=['\"]\d+(\.\d+)?[a-z]+['\"]")
SVG_HEIGHT_PATTERN = re.compile(r"height=['\"]\d+(\.\d+)?[a-z]+['\"]")


def get_smiles_from_request() -> tuple[list[str] | None, str | None]:
"""Gets SMILES from a request.
Expand Down Expand Up @@ -77,21 +81,14 @@ def string_to_html_sup(string: str) -> str:
return re.sub(r"\^(-?\d+)", r"<sup>\1</sup>", string)


def string_to_latex_sup(string: str) -> str:
    """Converts a string with an exponential to LaTeX superscript.

    :param string: A string.
    :return: The string with an exponential in LaTeX superscript.
    """
    # Allow an optional minus sign so that negative exponents (e.g., 10^-6) are converted too,
    # consistent with the pattern used by string_to_html_sup
    return re.sub(r"\^(-?\d+)", r"$^{\1}$", string)


def get_drugbank_suffix(atc_code: str | None) -> str:
"""Gets the DrugBank percentile suffix for the given ATC code.
def replace_svg_dimensions(svg_content: str) -> str:
"""Replace the SVG width and height with 100%.
:param atc_code: The ATC code.
:param svg_content: The SVG content.
:return: The SVG content with the width and height replaced with 100%.
"""
if atc_code is None:
return "drugbank_approved_percentile"
# Replacing the width and height with 100%
svg_content = SVG_WIDTH_PATTERN.sub('width="100%"', svg_content)
svg_content = SVG_HEIGHT_PATTERN.sub('height="100%"', svg_content)

return f"drugbank_approved_{atc_code}_percentile"
return svg_content
28 changes: 17 additions & 11 deletions admet_ai/web/app/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,29 @@
)

from admet_ai._version import __version__
from admet_ai.web.app import app
from admet_ai.web.app.admet_info import get_admet_info
from admet_ai.web.app.drugbank import (
from admet_ai.admet_info import get_admet_info
from admet_ai.drugbank import (
get_drugbank,
get_drugbank_size,
get_drugbank_task_names,
get_drugbank_unique_atc_codes,
)
from admet_ai.web.app.models import get_admet_model
from admet_ai.web.app.plot import (
from admet_ai.plot import (
plot_drugbank_reference,
plot_molecule_svg,
plot_radial_summary,
)
from admet_ai.utils import get_drugbank_suffix
from admet_ai.web.app import app
from admet_ai.web.app.models import get_admet_model
from admet_ai.web.app.storage import (
get_user_preds,
set_user_preds,
update_user_activity,
)
from admet_ai.web.app.utils import (
get_drugbank_suffix,
get_smiles_from_request,
replace_svg_dimensions,
smiles_to_mols,
string_to_html_sup,
)
Expand Down Expand Up @@ -132,26 +134,29 @@ def index() -> str:
# Create DrugBank reference plot
drugbank_plot_svg = plot_drugbank_reference(
preds_df=all_preds,
drugbank_df=get_drugbank(atc_code=session.get("atc_code")),
x_property_name=session.get("drugbank_x_task_name"),
y_property_name=session.get("drugbank_y_task_name"),
atc_code=session.get("atc_code"),
max_molecule_num=app.config["MAX_VISIBLE_MOLECULES"],
)
).decode("utf-8")
drugbank_plot_svg = replace_svg_dimensions(drugbank_plot_svg)

# Get maximum number of molecules to display
num_display_molecules = min(len(all_smiles), app.config["MAX_VISIBLE_MOLECULES"])

# Create molecule SVG images
mol_svgs = [plot_molecule_svg(mol) for mol in mols[:num_display_molecules]]
mol_svgs = [replace_svg_dimensions(plot) for plot in mol_svgs]

# Create molecule radial plots for DrugBank approved percentiles
radial_svgs = [
plot_radial_summary(
property_id_to_percentile=smiles_to_property_id_to_pred[smiles],
percentile_suffix=get_drugbank_suffix(session.get("atc_code")),
)
).decode("utf-8")
for smiles in all_smiles[:num_display_molecules]
]
radial_svgs = [replace_svg_dimensions(plot) for plot in radial_svgs]

return render(
predicted=True,
Expand Down Expand Up @@ -206,11 +211,12 @@ def drugbank_plot() -> Response:
# Create DrugBank reference plot with ATC code
drugbank_plot_svg = plot_drugbank_reference(
preds_df=get_user_preds(session["user_id"]),
drugbank_df=get_drugbank(atc_code=session.get("atc_code")),
x_property_name=session["drugbank_x_task_name"],
y_property_name=session["drugbank_y_task_name"],
atc_code=session.get("atc_code"),
max_molecule_num=app.config["MAX_VISIBLE_MOLECULES"],
)
).decode("utf-8")
drugbank_plot_svg = replace_svg_dimensions(drugbank_plot_svg)

return jsonify({"svg": drugbank_plot_svg})

Expand Down
7 changes: 5 additions & 2 deletions admet_ai/web/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
from datetime import timedelta
from threading import Thread

import matplotlib
from tap import tapify

from admet_ai.admet_info import load_admet_info
from admet_ai.drugbank import load_drugbank
from admet_ai.web.app import app
from admet_ai.web.app.admet_info import load_admet_info
from admet_ai.web.app.drugbank import load_drugbank
from admet_ai.web.app.models import load_admet_model
from admet_ai.web.app.storage import cleanup_storage

matplotlib.use("Agg")


def setup_web(
secret_key: str = "".join(
Expand Down
Loading

0 comments on commit a64d978

Please sign in to comment.