Skip to content

Commit

Permalink
Adding DrugBank ID and ATC code (not just name); including DrugBank d…
Browse files Browse the repository at this point in the history
…ata; incrementing version to 1.2.0
  • Loading branch information
swansonk14 committed Dec 11, 2023
1 parent 3053513 commit be7b1fd
Show file tree
Hide file tree
Showing 8 changed files with 2,649 additions and 33 deletions.
2 changes: 1 addition & 1 deletion admet_ai/_version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Contains the version information for ADMET-AI."""
# major, minor, patch
version_info = 1, 1, 0
version_info = 1, 2, 0

# Nice string for the version
__version__ = ".".join(map(str, version_info))
12 changes: 8 additions & 4 deletions admet_ai/admet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from collections import defaultdict
from multiprocessing import Pool
from pathlib import Path
from typing import Iterable

import numpy as np
import pandas as pd
Expand All @@ -24,7 +23,12 @@
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from admet_ai.constants import DEFAULT_DRUGBANK_PATH, DEFAULT_MODELS_DIR
from admet_ai.constants import (
DEFAULT_DRUGBANK_PATH,
DEFAULT_MODELS_DIR,
DRUGBANK_ATC_NAME_PREFIX,
DRUGBANK_DELIMITER,
)
from admet_ai.physchem import compute_physicochemical_properties


Expand Down Expand Up @@ -82,10 +86,10 @@ def __init__(
# Map ATC codes to all indices of the drugbank with that ATC code
atc_code_to_drugbank_indices = defaultdict(set)
for atc_column in [
column for column in self.drugbank.columns if column.startswith("atc_")
column for column in self.drugbank.columns if column.startswith(DRUGBANK_ATC_NAME_PREFIX)
]:
for index, atc_codes in self.drugbank[atc_column].dropna().items():
for atc_code in atc_codes.split(";"):
for atc_code in atc_codes.split(DRUGBANK_DELIMITER):
atc_code_to_drugbank_indices[atc_code.lower()].add(index)

# Save ATC code to indices mapping to global variable and convert set to sorted list
Expand Down
6 changes: 4 additions & 2 deletions admet_ai/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
# Paths to data and models
with resources.path("admet_ai", "resources") as resources_dir:
DEFAULT_ADMET_PATH = resources_dir / "data" / "admet.csv"
# TODO: update DrugBank path once it's added to the repo
# DEFAULT_DRUGBANK_PATH = None
DEFAULT_DRUGBANK_PATH = resources_dir / "data" / "drugbank_approved.csv"
DEFAULT_MODELS_DIR = resources_dir / "models"

# DrugBank columns
DRUGBANK_ATC_NAME_PREFIX = "atc_name"
DRUGBANK_DELIMITER = ";"
2,580 changes: 2,580 additions & 0 deletions admet_ai/resources/data/drugbank_approved.csv

Large diffs are not rendered by default.

17 changes: 11 additions & 6 deletions admet_ai/web/app/drugbank.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from functools import lru_cache

import matplotlib
import numpy as np
import pandas as pd
from scipy.stats import percentileofscore

from admet_ai.constants import DRUGBANK_ATC_NAME_PREFIX, DRUGBANK_DELIMITER
from admet_ai.web.app import app
from admet_ai.web.app.admet_info import get_admet_id_to_name

Expand All @@ -28,10 +27,12 @@ def load_drugbank() -> None:
# Map ATC codes to all indices of the DRUGBANK_DF with that ATC code
atc_code_to_drugbank_indices = defaultdict(set)
for atc_column in [
column for column in DRUGBANK_DF.columns if column.startswith("atc_")
column
for column in DRUGBANK_DF.columns
if column.startswith(DRUGBANK_ATC_NAME_PREFIX)
]:
for index, atc_codes in DRUGBANK_DF[atc_column].dropna().items():
for atc_code in atc_codes.split(";"):
for atc_code in atc_codes.split(DRUGBANK_DELIMITER):
atc_code_to_drugbank_indices[atc_code.lower()].add(index)

# Save ATC code to indices mapping to global variable and convert set to sorted list
Expand Down Expand Up @@ -76,9 +77,13 @@ def get_drugbank_unique_atc_codes() -> list[str]:
{
atc_code.lower()
for atc_column in [
column for column in DRUGBANK_DF.columns if column.startswith("atc_")
column
for column in DRUGBANK_DF.columns
if column.startswith(DRUGBANK_ATC_NAME_PREFIX)
]
for atc_codes in DRUGBANK_DF[atc_column].dropna().str.split(";")
for atc_codes in DRUGBANK_DF[atc_column]
.dropna()
.str.split(DRUGBANK_DELIMITER)
for atc_code in atc_codes
}
)
Expand Down
54 changes: 40 additions & 14 deletions scripts/get_drugbank_approved.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from rdkit import Chem
from tqdm import tqdm

from admet_ai.constants import DRUGBANK_ATC_NAME_PREFIX, DRUGBANK_DELIMITER
from tdc_constants import (
DRUGBANK_ATC_DELIMITER,
DRUGBANK_ATC_PREFIX,
DRUGBANK_ATC_CODE_COLUMN,
DRUGBANK_ID_COLUMN,
DRUGBANK_NAME_COLUMN,
DRUGBANK_SMILES_COLUMN,
)
Expand All @@ -27,12 +28,26 @@ def get_approved_smiles_from_drugbank(data_path: Path, save_path: Path) -> None:
drugbank = ET.parse(data_path).getroot()
drugs = list(drugbank)

approved_smiles = []
approved_names = []
approved_atcs = []
approved_ids = []
approved_smiles = []
approved_atc_codes = []
approved_atc_names = []

# Loop through drugs to find approved drugs and get their SMILES
for drug in tqdm(drugs):
# Get DrugBank ID
drugbank_ids = drug.findall("db:drugbank-id", DRUGBANK_NAMESPACES)
drugbank_ids = tuple(
drugbank_id.text
for drugbank_id in drugbank_ids
if drugbank_id.text.startswith("DB")
)

# DrugBank ID length validation
if len(drugbank_ids) == 0:
raise ValueError("DrugBank ID missing")

# Get groups to determine approval status
groups_list = drug.findall("db:groups", DRUGBANK_NAMESPACES)

Expand Down Expand Up @@ -104,32 +119,43 @@ def get_approved_smiles_from_drugbank(data_path: Path, save_path: Path) -> None:
# Get ATC codes
atc_codes = atcs_list[0].findall("db:atc-code", DRUGBANK_NAMESPACES)

# Get unique ATC codes
unique_atc_codes = set()
# Get unique ATC info
drug_unique_atc_codes = set()
drug_level_to_unique_atc_names = {level: set() for level in range(1, 5)}
for atc_code in atc_codes:
atc_levels = atc_code.findall("db:level", DRUGBANK_NAMESPACES)
atc_levels = atc_code.findall("db:level", DRUGBANK_NAMESPACES)[::-1]

if len(atc_levels) != 4:
raise ValueError("ATC code does not have 4 levels")

unique_atc_codes.add(tuple(atc_levels[-i].text for i in range(1, 5)))
drug_unique_atc_codes.add(atc_levels[-1].get("code"))

for level in range(1, 5):
drug_level_to_unique_atc_names[level].add(atc_levels[level - 1].text)

# Add info for approved drug
approved_smiles.append(smiles)
approved_names.append(name)
approved_atcs.append(sorted(unique_atc_codes))
approved_ids.append(drugbank_ids)
approved_smiles.append(smiles)
approved_atc_codes.append(drug_unique_atc_codes)
approved_atc_names.append(drug_level_to_unique_atc_names)

# Create dataset of approved drugs, drop duplicates, and sort
data = pd.DataFrame(
{
DRUGBANK_NAME_COLUMN: approved_names,
DRUGBANK_ID_COLUMN: [DRUGBANK_DELIMITER.join(ids) for ids in approved_ids],
DRUGBANK_SMILES_COLUMN: approved_smiles,
DRUGBANK_ATC_CODE_COLUMN: [
DRUGBANK_DELIMITER.join(sorted(atc_codes))
for atc_codes in approved_atc_codes
],
**{
f"{DRUGBANK_ATC_PREFIX}_{i + 1}": [
DRUGBANK_ATC_DELIMITER.join(atc_code[i] for atc_code in atc_codes)
for atc_codes in approved_atcs
f"{DRUGBANK_ATC_NAME_PREFIX}_{level}": [
DRUGBANK_DELIMITER.join(sorted(level_to_atc_names[level]))
for level_to_atc_names in approved_atc_names
]
for i in range(4)
for level in range(1, 5)
},
}
)
Expand Down
7 changes: 3 additions & 4 deletions scripts/plot_drugbank_approved.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import seaborn as sns
from tqdm import trange

from tdc_constants import DRUGBANK_ATC_DELIMITER, DRUGBANK_ATC_PREFIX
from admet_ai.constants import DRUGBANK_ATC_NAME_PREFIX, DRUGBANK_DELIMITER

FIGSIZE = (18, 14)
matplotlib.rcParams["font.size"] = 28
Expand All @@ -34,11 +34,11 @@ def plot_drugbank_approved(
# Plot distribution of ATC codes at each level
for level in trange(1, 5, desc="ATC levels"):
# Compute ATC code counts at this level and keep only the top k
atc_column = f"{DRUGBANK_ATC_PREFIX}_{level}"
atc_column = f"{DRUGBANK_ATC_NAME_PREFIX}_{level}"
atc_code_counts = Counter(
atc_code
for atc_list in data[atc_column].dropna()
for atc_code in atc_list.split(DRUGBANK_ATC_DELIMITER)
for atc_code in atc_list.split(DRUGBANK_DELIMITER)
)
atc_code_df = pd.DataFrame.from_dict(
atc_code_counts, orient="index", columns=["count"]
Expand All @@ -55,7 +55,6 @@ def plot_drugbank_approved(

# Remove y-axis label and change font size of y-axis tick labels
ax.set_ylabel("")
# ax.set_yticklabels(ax.get_yticklabels(), fontsize=12)

# Add x-axis label
ax.set_xlabel("Count")
Expand Down
4 changes: 2 additions & 2 deletions scripts/tdc_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
ADMET_GROUP_SMILES_COLUMN = "Drug"
ADMET_GROUP_TARGET_COLUMN = "Y"

DRUGBANK_ID_COLUMN = "id"
DRUGBANK_NAME_COLUMN = "name"
DRUGBANK_SMILES_COLUMN = "smiles"
DRUGBANK_ATC_PREFIX = "atc"
DRUGBANK_ATC_DELIMITER = ";"
DRUGBANK_ATC_CODE_COLUMN = "atc"

0 comments on commit be7b1fd

Please sign in to comment.