multi ligand eval (#74)
* multi ligand eval

* fix: correctly handle receptor file

* feat: add custom split column name

* fix: correctly handle receptor file

* chore: lint

* rename ligand chains (#84)

* rename ligand chains

* linting

* save 1 ligand score per row

* chore: fix type and tests

* feat: add back aggregate scores

---------

Co-authored-by: OleinikovasV <[email protected]>

---------

Co-authored-by: Ninjani <[email protected]>
Co-authored-by: Thomas Castiglione <[email protected]>
3 people authored Nov 14, 2024
1 parent 9f139c9 commit ba47d9b
Showing 5 changed files with 215 additions and 141 deletions.
3 changes: 2 additions & 1 deletion src/plinder/eval/docking/make_plots.py
@@ -43,7 +43,8 @@ class EvaluationResults:
     def from_scores_and_data_files(
         cls, score_file: Path, data_file: Path, output_dir: Path, top_n: int = 10
     ) -> "EvaluationResults":
-        scores_df = pd.read_parquet(score_file)
+        # use only one score per system when plotting aggregated scores
+        scores_df = pd.read_parquet(score_file).drop_duplicates("reference")
         data_df = pd.read_parquet(data_file)
         merged_df = scores_df[scores_df["reference"].isin(data_df["system_id"])].merge(
             data_df, left_on="reference", right_on="system_id", how="left"
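Reviewer note: since scores are now written one row per ligand (see write_scores.py below), a multi-ligand system repeats its aggregate `*_ave` columns across rows, and `drop_duplicates("reference")` keeps a single row per system for the aggregate plots. A minimal pandas sketch of the effect, with hypothetical values:

```python
import pandas as pd

# hypothetical per-ligand score rows: "sysA" has two ligands, so its
# system-level aggregate (lddt_pli_ave) appears twice before deduplication
scores_df = pd.DataFrame(
    {
        "reference": ["sysA", "sysA", "sysB"],
        "chain": ["LIG_0", "LIG_1", "LIG_0"],
        "lddt_pli_ave": [0.8, 0.8, 0.6],
    }
)

# keep one row per system, as in from_scores_and_data_files above
print(scores_df.drop_duplicates("reference"))
```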
45 changes: 34 additions & 11 deletions src/plinder/eval/docking/stratify_test_set.py
@@ -59,7 +59,11 @@ def compute_protein_max_similarities(


 def compute_ligand_max_similarities(
-    df: pd.DataFrame, train_label: str, test_label: str, output_file: Path
+    df: pd.DataFrame,
+    split_label: str,
+    train_label: str,
+    test_label: str,
+    output_file: Path,
 ) -> None:
     if "fp" not in df.columns:
         smiles_fp_dict = {
@@ -68,10 +72,10 @@
         }
         df["fp"] = df["ligand_rdkit_canonical_smiles"].map(smiles_fp_dict)

-    df_test = df.loc[df["split"] == test_label][["system_id", "fp"]].copy()
+    df_test = df.loc[df[split_label] == test_label][["system_id", "fp"]].copy()

     df_test["tanimoto_similarity_max"] = smallmolecules.tanimoto_maxsim_matrix(
-        df.loc[df["split"] == train_label]["fp"].to_list(),
+        df.loc[df[split_label] == train_label]["fp"].to_list(),
         df_test["fp"].to_list(),
     )
     df_test.drop("fp", axis=1).groupby("system_id").agg("max").reset_index().to_parquet(
@@ -83,6 +87,7 @@
 class StratifiedTestSet:
     split_df: pd.DataFrame
     output_dir: Path
+    split_label: str = "split"
     train_label: str = "train"
     test_label: str = "test"
     similarity_thresholds: dict[str, int] = field(
@@ -127,6 +132,7 @@ def from_split(
         cls,
         split_file: Path,
         output_dir: Path,
+        split_label: str = "split",
         train_label: str = "train",
         test_label: str = "test",
         overwrite: bool = False,
@@ -135,13 +141,14 @@ def from_split(
             split_df = pd.read_csv(split_file)
         else:
             split_df = pd.read_parquet(split_file)
-        assert all(x in split_df.columns for x in ["split", "system_id"])
-        split_df = split_df[split_df["split"].isin([train_label, test_label])][
-            ["system_id", "split"]
+        assert all(x in split_df.columns for x in [split_label, "system_id"])
+        split_df = split_df[split_df[split_label].isin([train_label, test_label])][
+            [split_label, "system_id"]
         ].reset_index(drop=True)
         data = cls(
             split_df=split_df,
             output_dir=output_dir,
+            split_label=split_label,
             train_label=train_label,
             test_label=test_label,
         )
@@ -180,8 +187,16 @@ def compute_train_test_max_similarity(
         self, overwrite: bool = False
     ) -> pd.DataFrame:
         left, right = (
-            set(self.split_df[self.split_df["split"] == self.train_label]["system_id"]),
-            set(self.split_df[self.split_df["split"] == self.test_label]["system_id"]),
+            set(
+                self.split_df[self.split_df[self.split_label] == self.train_label][
+                    "system_id"
+                ]
+            ),
+            set(
+                self.split_df[self.split_df[self.split_label] == self.test_label][
+                    "system_id"
+                ]
+            ),
         )
         LOG.info(
             f"compute_train_test_max_similarity: Found {len(left)} train and {len(right)} test systems"
@@ -204,6 +219,7 @@ def compute_train_test_max_similarity(
             df = df.merge(self.split_df, on="system_id", how="left")
             compute_ligand_max_similarities(
                 df,
+                self.split_label,
                 self.train_label,
                 self.test_label,
                 self.get_filename(metric),
@@ -250,9 +266,9 @@ def assign_test_set_quality(self) -> None:
                     "system_id",
                     "in",
                     set(
-                        self.split_df[self.split_df["split"] == self.test_label][
-                            "system_id"
-                        ]
+                        self.split_df[
+                            self.split_df[self.split_label] == self.test_label
+                        ]["system_id"]
                     ),
                 )
             ],  # type: ignore
@@ -300,6 +316,12 @@ def stratify_cmd(args: list[str] | None = None) -> None:
         type=Path,
         help="Path to output folder where similarity and stratification data are saved",
    )
+    parser.add_argument(
+        "--split_label",
+        type=str,
+        default="split",
+        help="split=<split_label> is used to get split systems",
+    )
     parser.add_argument(
         "--train_label",
         type=str,
@@ -325,6 +347,7 @@ def stratify_cmd(args: list[str] | None = None) -> None:
     StratifiedTestSet.from_split(
         split_file=Path(ns.split_file),
         output_dir=Path(ns.output_dir),
+        split_label=ns.split_label,
         train_label=ns.train_label,
         test_label=ns.test_label,
         overwrite=ns.overwrite,
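Reviewer note: with the new `split_label` argument, the split file's column no longer has to be literally named "split". A sketch of the Python entry point under that assumption (the file path and column name below are hypothetical):

```python
from pathlib import Path

from plinder.eval.docking.stratify_test_set import StratifiedTestSet

# "custom_split" is a hypothetical column holding "train"/"test" labels
StratifiedTestSet.from_split(
    split_file=Path("splits/my_split.parquet"),
    output_dir=Path("stratification/"),
    split_label="custom_split",  # defaults to "split" when omitted
    train_label="train",
    test_label="test",
)
```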
58 changes: 36 additions & 22 deletions src/plinder/eval/docking/utils.py
@@ -2,6 +2,7 @@
 # Distributed under the terms of the Apache License 2.0
 from __future__ import annotations

+import copy
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any
@@ -84,10 +85,18 @@ def from_files(
             entity = io.LoadMMCIF(receptor_file.as_posix(), fault_tolerant=True)
         else:
             entity = io.LoadPDB(receptor_file.as_posix(), fault_tolerant=True)
-        ligand_views = [
-            io.LoadEntity(str(ligand_sdf_file), format="sdf").Select("ele != H")
-            for ligand_sdf_file in ligand_files
-        ]
+
+        ligand_views = []
+        for i, ligand_file in enumerate(ligand_files):
+            ligand_entity = io.LoadEntity(str(ligand_file), format="sdf")
+
+            # rename ligand chain to have different chain names for each ligand
+            # this is necessary for ost to not complain about duplicate chain names
+            editor = ligand_entity.EditXCS()
+            editor.RenameChain(list(ligand_entity.chains)[0], f"LIG_{i}")
+            ligand_entity = ligand_entity.Select("ele != H")
+            ligand_views.append(ligand_entity)
+
         return cls(
             name=name,
             receptor_file=receptor_file,
@@ -137,8 +146,8 @@ def from_model_files(
         ----------
         model_file : Path
             The path to the model file.
-        model_ligand_sdf_files : list[str | Path]
-            The list of ligand SDF files.
+        model_ligand_files : list[Path]
+            The list of ligand SDF files OR a path to directory with SDF files.
         reference : PlinderSystem
             The reference system.
         score_protein : bool, default=False
avg_scores[f"posebusters_{key}"] = None
return avg_scores

def summarize_scores(self) -> dict[str, Any]:
def summarize_scores(self) -> dict[str, dict[str, Any]]:
if self.ligand_scores is None:
return {}
scores: dict[str, Any] = dict(
model=self.model.name,
reference=self.reference.name,
@@ -432,21 +443,7 @@ def summarize_scores(self) -> dict[str, Any]:
             fraction_model_ligands_mapped=self.num_mapped_model_ligands
             / self.model.num_ligands,
         )
-        score_list = ["lddt_pli", "lddt_lp", "bisy_rmsd"]
-        for score_name in score_list:
-            (
-                scores[f"{score_name}_ave"],
-                scores[f"{score_name}_wave"],
-            ) = self.get_average_ligand_scores(score_name)
-        if self.score_posebusters:
-            scores.update(self.get_average_posebusters())
         if self.score_protein and self.protein_scores is not None:
             scores["fraction_reference_proteins_mapped"] = (
                 self.num_mapped_reference_proteins / self.reference.num_proteins
             )
             scores["fraction_model_proteins_mapped"] = (
                 self.num_mapped_proteins / self.model.num_proteins
             )
             scores["lddt"] = self.protein_scores.lddt
             scores["bb_lddt"] = self.protein_scores.bb_lddt
             per_chain_lddt = list(self.protein_scores.per_chain_lddt.values())
@@ -455,4 +452,21 @@ def summarize_scores(self) -> dict[str, Any]:
             scores["per_chain_bb_lddt_ave"] = np.mean(per_chain_bb_lddt)
             if self.protein_scores.score_oligo:
                 scores.update(self.protein_scores.oligomer_scores)
-        return scores
+        # aggregated shared_scores to add
+        score_list = ["lddt_pli", "lddt_lp", "bisy_rmsd"]
+        for score_name in score_list:
+            (
+                scores[f"{score_name}_ave"],
+                scores[f"{score_name}_wave"],
+            ) = self.get_average_ligand_scores(score_name)
+        if self.score_posebusters:
+            scores.update(self.get_average_posebusters())
+        # individual ligand scores
+        per_lig_scores: dict[str, dict[str, Any]] = {}
+        for ligand_score in self.ligand_scores:
+            per_lig_scores[ligand_score.chain] = copy.deepcopy(scores)
+            for score_name in score_list:
+                per_lig_scores[ligand_score.chain][score_name] = ligand_score.scores[
+                    score_name
+                ]
+        return per_lig_scores
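Reviewer note: summarize_scores now returns one entry per ligand chain instead of a single flat dict; each entry carries a deep copy of the shared aggregate scores (`*_ave`, `*_wave`, posebusters, protein scores) plus that ligand's own lddt_pli, lddt_lp, and bisy_rmsd. A sketch of the assumed output shape and how a caller might consume it (the instance name and values are illustrative):

```python
# assumed shape of summarize_scores() after this change:
# {
#     "LIG_0": {"model": ..., "lddt_pli_ave": 0.7, "lddt_pli": 0.9, ...},
#     "LIG_1": {"model": ..., "lddt_pli_ave": 0.7, "lddt_pli": 0.5, ...},
# }
per_lig_scores = model_scores.summarize_scores()  # hypothetical scorer instance
for chain, chain_scores in per_lig_scores.items():
    print(chain, chain_scores["lddt_pli"], chain_scores["lddt_pli_ave"])
```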
72 changes: 48 additions & 24 deletions src/plinder/eval/docking/write_scores.py
@@ -7,7 +7,6 @@
 from pathlib import Path
 from typing import Any

-import numpy as np
 import ost
 import pandas as pd

@@ -23,12 +22,12 @@ def evaluate(
     model_system_id: str,
     reference_system_id: str,
     receptor_file: Path | str,
-    ligand_files: list[Path | str],
+    ligand_file_list: list[Path],
     predictions_dir: Path | None = None,
     flexible: bool = False,
     posebusters: bool = False,
     posebusters_full: bool = False,
-) -> dict[str, Any]:
+) -> dict[str, dict[str, Any]]:
     """
     Evaluate a single receptor - ligand pair
@@ -40,8 +39,8 @@
         The PLINDER systemID of the reference system
     receptor_file: Path
         The path to the receptor CIF/PDB file
-    ligand_files: list[Path]
-        The path to the ligand SDF files
+    ligand_file_list: list[Path]
+        The list of Paths to the ligand SDF files OR directory with .sdf files
     predictions_dir: Path | None
         The path to the directory containing the predictions (used if receptor/ligand files are not provided)
     flexible: bool
@@ -53,17 +52,25 @@
     """
     reference_system = PlinderSystem(system_id=reference_system_id)
     receptor_file = Path(receptor_file)
-    ligand_file_paths = [Path(ligand_file) for ligand_file in ligand_files]

     if not receptor_file.exists():
         if predictions_dir is not None and (predictions_dir / receptor_file).exists():
             receptor_file = predictions_dir / receptor_file
     if not receptor_file.exists():
         raise FileNotFoundError(f"Receptor file {receptor_file} could not be found")
-    if not all(ligand_file.exists() for ligand_file in ligand_file_paths):
-        if predictions_dir is not None:
-            ligand_file_paths = [
-                predictions_dir / ligand_file for ligand_file in ligand_file_paths
-            ]
+
+    ligand_file_paths = []
+    for ligand_file in ligand_file_list:
+        if ligand_file.exists():
+            if ligand_file.is_dir():
+                ligand_file_paths += list(ligand_file.glob("*.sdf"))
+            elif ligand_file.suffix == ".sdf":
+                ligand_file_paths.append(ligand_file)
+        elif ligand_file is not None:
+            if predictions_dir is not None and (predictions_dir / ligand_file).exists():
+                ligand_file = predictions_dir / ligand_file
+            ligand_file_paths.append(ligand_file)
+
     assert ligand_file_paths is not None and all(
         ligand_file.exists() for ligand_file in ligand_file_paths
     ), f"Ligand files {ligand_file_paths} could not be found"
@@ -113,11 +120,14 @@ def write_scores_as_json(
         return
     reference_system = PlinderSystem(system_id=scorer_input.reference_system_id)
     receptor_file = None
-    if scorer_input.receptor_file is not None and not np.isnan(
-        scorer_input.receptor_file
+    if (
+        scorer_input.receptor_file is not None
+        and not str(scorer_input.receptor_file) == "nan"
     ):
         receptor_file = Path(scorer_input.receptor_file)
-    ligand_file = Path(scorer_input.ligand_file)
+    # single ligand file path, directory
+    # TODO: or some concatenated list of paths with some separator
+    ligand_file_path = [Path(scorer_input.ligand_file)]

     if receptor_file is not None and not receptor_file.exists():
         if (
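Reviewer note: the `np.isnan` check was replaced because it raises a TypeError when the receptor path arrives as a string; comparing against the string "nan" covers both a genuine missing value and a real path. A small illustration:

```python
import numpy as np

# np.isnan only accepts numeric input; a string path blows up
try:
    np.isnan("receptor.cif")
except TypeError:
    print("np.isnan rejects string paths")

# the string comparison handles both cases
print(str(float("nan")) == "nan")    # True  -> treated as missing
print(str("receptor.cif") == "nan")  # False -> treated as a usable path
```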
@@ -128,20 +138,22 @@
         else:
             assert reference_system.receptor_cif is not None
             receptor_file = Path(reference_system.receptor_cif)
-    else:
+
+    elif receptor_file is None:
         LOG.warning(
             "No receptor file provided, using reference receptor (RIGID REDOCKING!!)"
         )
         receptor_file = Path(reference_system.receptor_cif)
-    if ligand_file is not None and not ligand_file.exists():
-        if predictions_dir is not None and (predictions_dir / ligand_file).exists():
-            ligand_file = predictions_dir / ligand_file
+    elif not receptor_file.exists():
+        raise FileNotFoundError(f"Receptor file {receptor_file} could not be found")

     assert receptor_file is not None
-    assert ligand_file is not None
+    assert ligand_file_path is not None
     scores = evaluate(
         scorer_input.id,
         scorer_input.reference_system_id,
         receptor_file,
-        [
-            ligand_file,
-        ],  # TODO: change to accept multi-ligand prediction input
+        ligand_file_path,
         predictions_dir,
         flexible,
         posebusters,
@@ -227,8 +239,20 @@
                     LOG.error(
                         f"score_test_set: Error loading scores file {json_file}: {e}"
                     )
-                s["rank"] = json_file.stem
-                scores.append(s)
+                for ligand_chain_name, ligand_scores in s.items():
+                    if isinstance(ligand_scores, dict):
+                        scores_dict = {
+                            "chain": ligand_chain_name,
+                            "rank": json_file.stem,
+                        }
+                        scores_dict.update(
+                            {
+                                k: v
+                                for k, v in ligand_scores.items()
+                                if k != "ligand_scores"
+                            }
+                        )
+                        scores.append(scores_dict)
     pd.DataFrame(scores).to_parquet(output_file, index=False)
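Reviewer note: with summarize_scores keyed by ligand chain, each JSON file now yields one parquet row per ligand, tagged with its `chain` and the file stem as `rank`. A sketch of reading the result back (the output path is hypothetical; column names follow the code above):

```python
import pandas as pd

# one row per (reference, chain, rank) after score_test_set
df = pd.read_parquet("scores.parquet")  # hypothetical output_file
print(df[["reference", "chain", "rank", "lddt_pli"]].head())
```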

