diff --git a/src/plinder/eval/docking/make_plots.py b/src/plinder/eval/docking/make_plots.py index ad31d359..d66ce72b 100644 --- a/src/plinder/eval/docking/make_plots.py +++ b/src/plinder/eval/docking/make_plots.py @@ -43,7 +43,8 @@ class EvaluationResults: def from_scores_and_data_files( cls, score_file: Path, data_file: Path, output_dir: Path, top_n: int = 10 ) -> "EvaluationResults": - scores_df = pd.read_parquet(score_file) + # use only one score per system when plotting aggregated scores + scores_df = pd.read_parquet(score_file).drop_duplicates("reference") data_df = pd.read_parquet(data_file) merged_df = scores_df[scores_df["reference"].isin(data_df["system_id"])].merge( data_df, left_on="reference", right_on="system_id", how="left" diff --git a/src/plinder/eval/docking/stratify_test_set.py b/src/plinder/eval/docking/stratify_test_set.py index 29cf170a..49bf1ec3 100644 --- a/src/plinder/eval/docking/stratify_test_set.py +++ b/src/plinder/eval/docking/stratify_test_set.py @@ -59,7 +59,11 @@ def compute_protein_max_similarities( def compute_ligand_max_similarities( - df: pd.DataFrame, train_label: str, test_label: str, output_file: Path + df: pd.DataFrame, + split_label: str, + train_label: str, + test_label: str, + output_file: Path, ) -> None: if "fp" not in df.columns: smiles_fp_dict = { @@ -68,10 +72,10 @@ def compute_ligand_max_similarities( } df["fp"] = df["ligand_rdkit_canonical_smiles"].map(smiles_fp_dict) - df_test = df.loc[df["split"] == test_label][["system_id", "fp"]].copy() + df_test = df.loc[df[split_label] == test_label][["system_id", "fp"]].copy() df_test["tanimoto_similarity_max"] = smallmolecules.tanimoto_maxsim_matrix( - df.loc[df["split"] == train_label]["fp"].to_list(), + df.loc[df[split_label] == train_label]["fp"].to_list(), df_test["fp"].to_list(), ) df_test.drop("fp", axis=1).groupby("system_id").agg("max").reset_index().to_parquet( @@ -83,6 +87,7 @@ def compute_ligand_max_similarities( class StratifiedTestSet: split_df: pd.DataFrame output_dir: Path + split_label: str = "split" train_label: str = "train" test_label: str = "test" similarity_thresholds: dict[str, int] = field( @@ -127,6 +132,7 @@ def from_split( cls, split_file: Path, output_dir: Path, + split_label: str = "split", train_label: str = "train", test_label: str = "test", overwrite: bool = False, @@ -135,13 +141,14 @@ def from_split( split_df = pd.read_csv(split_file) else: split_df = pd.read_parquet(split_file) - assert all(x in split_df.columns for x in ["split", "system_id"]) - split_df = split_df[split_df["split"].isin([train_label, test_label])][ - ["system_id", "split"] + assert all(x in split_df.columns for x in [split_label, "system_id"]) + split_df = split_df[split_df[split_label].isin([train_label, test_label])][ + [split_label, "system_id"] ].reset_index(drop=True) data = cls( split_df=split_df, output_dir=output_dir, + split_label=split_label, train_label=train_label, test_label=test_label, ) @@ -180,8 +187,16 @@ def compute_train_test_max_similarity( self, overwrite: bool = False ) -> pd.DataFrame: left, right = ( - set(self.split_df[self.split_df["split"] == self.train_label]["system_id"]), - set(self.split_df[self.split_df["split"] == self.test_label]["system_id"]), + set( + self.split_df[self.split_df[self.split_label] == self.train_label][ + "system_id" + ] + ), + set( + self.split_df[self.split_df[self.split_label] == self.test_label][ + "system_id" + ] + ), ) LOG.info( f"compute_train_test_max_similarity: Found {len(left)} train and {len(right)} test systems" @@ -204,6 +219,7 @@ def compute_train_test_max_similarity( df = df.merge(self.split_df, on="system_id", how="left") compute_ligand_max_similarities( df, + self.split_label, self.train_label, self.test_label, self.get_filename(metric), @@ -250,9 +266,9 @@ def assign_test_set_quality(self) -> None: "system_id", "in", set( - self.split_df[self.split_df["split"] == self.test_label][ - "system_id" - ] + self.split_df[ + self.split_df[self.split_label] == self.test_label + ]["system_id"] ), ) ], # type: ignore @@ -300,6 +316,12 @@ def stratify_cmd(args: list[str] | None = None) -> None: type=Path, help="Path to output folder where similarity and stratification data are saved", ) + parser.add_argument( + "--split_label", + type=str, + default="split", + help="split= is used to get split systems", + ) parser.add_argument( "--train_label", type=str, @@ -325,6 +347,7 @@ def stratify_cmd(args: list[str] | None = None) -> None: StratifiedTestSet.from_split( split_file=Path(ns.split_file), output_dir=Path(ns.output_dir), + split_label=ns.split_label, train_label=ns.train_label, test_label=ns.test_label, overwrite=ns.overwrite, diff --git a/src/plinder/eval/docking/utils.py b/src/plinder/eval/docking/utils.py index d911c06f..73e98ba9 100644 --- a/src/plinder/eval/docking/utils.py +++ b/src/plinder/eval/docking/utils.py @@ -2,6 +2,7 @@ # Distributed under the terms of the Apache License 2.0 from __future__ import annotations +import copy from dataclasses import dataclass, field from pathlib import Path from typing import Any @@ -84,10 +85,18 @@ def from_files( entity = io.LoadMMCIF(receptor_file.as_posix(), fault_tolerant=True) else: entity = io.LoadPDB(receptor_file.as_posix(), fault_tolerant=True) - ligand_views = [ - io.LoadEntity(str(ligand_sdf_file), format="sdf").Select("ele != H") - for ligand_sdf_file in ligand_files - ] + + ligand_views = [] + for i, ligand_file in enumerate(ligand_files): + ligand_entity = io.LoadEntity(str(ligand_file), format="sdf") + + # rename ligand chain to have different chain names for each ligand + # this is necessary for ost to not complain about duplicate chain names + editor = ligand_entity.EditXCS() + editor.RenameChain(list(ligand_entity.chains)[0], f"LIG_{i}") + ligand_entity = ligand_entity.Select("ele != H") + ligand_views.append(ligand_entity) + return cls( name=name, receptor_file=receptor_file, @@ -137,8 +146,8 @@ def from_model_files( ---------- model_file : Path The path to the model file. - model_ligand_sdf_files : list[str | Path] - The list of ligand SDF files. + model_ligand_files : list[Path] + The list of ligand SDF files OR a path to directory with SDF files. reference : PlinderSystem The reference system. score_protein : bool, default=False @@ -419,7 +428,9 @@ def get_average_posebusters(self) -> dict[str, list[Any]]: avg_scores[f"posebusters_{key}"] = None return avg_scores - def summarize_scores(self) -> dict[str, Any]: + def summarize_scores(self) -> dict[str, dict[str, Any]]: + if self.ligand_scores is None: + return {} scores: dict[str, Any] = dict( model=self.model.name, reference=self.reference.name, @@ -432,21 +443,7 @@ def summarize_scores(self) -> dict[str, Any]: fraction_model_ligands_mapped=self.num_mapped_model_ligands / self.model.num_ligands, ) - score_list = ["lddt_pli", "lddt_lp", "bisy_rmsd"] - for score_name in score_list: - ( - scores[f"{score_name}_ave"], - scores[f"{score_name}_wave"], - ) = self.get_average_ligand_scores(score_name) - if self.score_posebusters: - scores.update(self.get_average_posebusters()) if self.score_protein and self.protein_scores is not None: - scores["fraction_reference_proteins_mapped"] = ( - self.num_mapped_reference_proteins / self.reference.num_proteins - ) - scores["fraction_model_proteins_mapped"] = ( - self.num_mapped_proteins / self.model.num_proteins - ) scores["lddt"] = self.protein_scores.lddt scores["bb_lddt"] = self.protein_scores.bb_lddt per_chain_lddt = list(self.protein_scores.per_chain_lddt.values()) @@ -455,4 +452,21 @@ def summarize_scores(self) -> dict[str, Any]: scores["per_chain_bb_lddt_ave"] = np.mean(per_chain_bb_lddt) if self.protein_scores.score_oligo: scores.update(self.protein_scores.oligomer_scores) - return scores + # aggregated shared_scores to add + score_list = ["lddt_pli", "lddt_lp", "bisy_rmsd"] + for score_name in score_list: + ( + scores[f"{score_name}_ave"], + scores[f"{score_name}_wave"], + ) = self.get_average_ligand_scores(score_name) + if self.score_posebusters: + scores.update(self.get_average_posebusters()) + # individual ligand scores + per_lig_scores: dict[str, dict[str, Any]] = {} + for ligand_score in self.ligand_scores: + per_lig_scores[ligand_score.chain] = copy.deepcopy(scores) + for score_name in score_list: + per_lig_scores[ligand_score.chain][score_name] = ligand_score.scores[ + score_name + ] + return per_lig_scores diff --git a/src/plinder/eval/docking/write_scores.py b/src/plinder/eval/docking/write_scores.py index 48d7262f..8e94ee2c 100644 --- a/src/plinder/eval/docking/write_scores.py +++ b/src/plinder/eval/docking/write_scores.py @@ -7,7 +7,6 @@ from pathlib import Path from typing import Any -import numpy as np import ost import pandas as pd @@ -23,12 +22,12 @@ def evaluate( model_system_id: str, reference_system_id: str, receptor_file: Path | str, - ligand_files: list[Path | str], + ligand_file_list: list[Path], predictions_dir: Path | None = None, flexible: bool = False, posebusters: bool = False, posebusters_full: bool = False, -) -> dict[str, Any]: +) -> dict[str, dict[str, Any]]: """ Evaluate a single receptor - ligand pair @@ -40,8 +39,8 @@ def evaluate( The PLINDER systemID of the reference system receptor_file: Path The path to the receptor CIF/PDB file - ligand_files: list[Path] - The path to the ligand SDF files + ligand_file_list: list[Path] + The list of Paths to the ligand SDF files OR directory with .sdf files predictions_dir: Path | None The path to the directory containing the predictions (used if receptor/ligand files are not provided) flexible: bool @@ -53,17 +52,25 @@ def evaluate( """ reference_system = PlinderSystem(system_id=reference_system_id) receptor_file = Path(receptor_file) - ligand_file_paths = [Path(ligand_file) for ligand_file in ligand_files] + if not receptor_file.exists(): if predictions_dir is not None and (predictions_dir / receptor_file).exists(): receptor_file = predictions_dir / receptor_file if not receptor_file.exists(): raise FileNotFoundError(f"Receptor file {receptor_file} could not be found") - if not all(ligand_file.exists() for ligand_file in ligand_file_paths): - if predictions_dir is not None: - ligand_file_paths = [ - predictions_dir / ligand_file for ligand_file in ligand_file_paths - ] + + ligand_file_paths = [] + for ligand_file in ligand_file_list: + if ligand_file.exists(): + if ligand_file.is_dir(): + ligand_file_paths += list(ligand_file.glob("*.sdf")) + elif ligand_file.suffix == ".sdf": + ligand_file_paths.append(ligand_file) + elif ligand_file is not None: + if predictions_dir is not None and (predictions_dir / ligand_file).exists(): + ligand_file = predictions_dir / ligand_file + ligand_file_paths.append(ligand_file) + assert ligand_file_paths is not None and all( ligand_file.exists() for ligand_file in ligand_file_paths ), f"Ligand files {ligand_file_paths} could not be found" @@ -113,11 +120,14 @@ def write_scores_as_json( return reference_system = PlinderSystem(system_id=scorer_input.reference_system_id) receptor_file = None - if scorer_input.receptor_file is not None and not np.isnan( - scorer_input.receptor_file + if ( + scorer_input.receptor_file is not None + and not str(scorer_input.receptor_file) == "nan" ): receptor_file = Path(scorer_input.receptor_file) - ligand_file = Path(scorer_input.ligand_file) + # single ligand file path, directory + # TODO: or some concatenated list of paths with some separator + ligand_file_path = [Path(scorer_input.ligand_file)] if receptor_file is not None and not receptor_file.exists(): if ( @@ -128,20 +138,22 @@ def write_scores_as_json( else: assert reference_system.receptor_cif is not None receptor_file = Path(reference_system.receptor_cif) - else: + + elif receptor_file is None: + LOG.warning( + "No receptor file provided, using reference receptor (RIGID REDOCKING!!)" + ) receptor_file = Path(reference_system.receptor_cif) - if ligand_file is not None and not ligand_file.exists(): - if predictions_dir is not None and (predictions_dir / ligand_file).exists(): - ligand_file = predictions_dir / ligand_file + elif not receptor_file.exists(): + raise FileNotFoundError(f"Receptor file {receptor_file} could not be found") + assert receptor_file is not None - assert ligand_file is not None + assert ligand_file_path is not None scores = evaluate( scorer_input.id, scorer_input.reference_system_id, receptor_file, - [ - ligand_file, - ], # TODO: change to accept multi-ligand prediction input + ligand_file_path, predictions_dir, flexible, posebusters, @@ -227,8 +239,20 @@ def score_test_set( LOG.error( f"score_test_set: Error loading scores file {json_file}: {e}" ) - s["rank"] = json_file.stem - scores.append(s) + for ligand_chain_name, ligand_scores in s.items(): + if isinstance(ligand_scores, dict): + scores_dict = { + "chain": ligand_chain_name, + "rank": json_file.stem, + } + scores_dict.update( + { + k: v + for k, v in ligand_scores.items() + if k != "ligand_scores" + } + ) + scores.append(scores_dict) pd.DataFrame(scores).to_parquet(output_file, index=False) diff --git a/tests/test_eval.py b/tests/test_eval.py index 519b9429..635182e5 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -62,46 +62,51 @@ def test_single_protein_single_ligand_scoring( score_posebusters=True, ).summarize_scores() true_scores = { - "model": "1a3b__1__1.B__1.D", - "reference": "1a3b__1__1.B__1.D", - "num_reference_ligands": 1, - "num_model_ligands": 1, - "num_reference_proteins": 1, - "num_model_proteins": 1, - "fraction_reference_ligands_mapped": 1.0, - "fraction_model_ligands_mapped": 1.0, - "lddt_pli_ave": 0.8581504702194357, - "lddt_pli_wave": 0.8581504702194357, - "bisy_rmsd_ave": 1.6171839155715722, - "bisy_rmsd_wave": 1.6171839155715722, - "lddt_lp_ave": 1.0, - "lddt_lp_wave": 1.0, - "posebusters_mol_pred_loaded": True, - "posebusters_mol_cond_loaded": True, - "posebusters_sanitization": True, - "posebusters_all_atoms_connected": True, - "posebusters_bond_lengths": True, - "posebusters_bond_angles": True, - "posebusters_internal_steric_clash": True, - "posebusters_aromatic_ring_flatness": True, - "posebusters_double_bond_flatness": True, - "posebusters_internal_energy": True, - "posebusters_protein-ligand_maximum_distance": True, - "posebusters_minimum_distance_to_protein": False, - "posebusters_minimum_distance_to_organic_cofactors": True, - "posebusters_minimum_distance_to_inorganic_cofactors": True, - "posebusters_minimum_distance_to_waters": True, - "posebusters_volume_overlap_with_protein": True, - "posebusters_volume_overlap_with_organic_cofactors": True, - "posebusters_volume_overlap_with_inorganic_cofactors": True, - "posebusters_volume_overlap_with_waters": True, - "fraction_reference_proteins_mapped": 1.0, - "fraction_model_proteins_mapped": 1.0, - "lddt": 1.0, - "bb_lddt": 1.0, - "per_chain_lddt_ave": 0.9960159362549801, - "per_chain_bb_lddt_ave": 1.0, + "LIG_0": { + "model": "1a3b__1__1.B__1.D", + "reference": "1a3b__1__1.B__1.D", + "num_reference_ligands": 1, + "num_model_ligands": 1, + "num_reference_proteins": 1, + "num_model_proteins": 1, + "fraction_reference_ligands_mapped": 1.0, + "fraction_model_ligands_mapped": 1.0, + "lddt": 1.0, + "bb_lddt": 1.0, + "per_chain_lddt_ave": 0.9960159362549801, + "per_chain_bb_lddt_ave": 1.0, + "lddt_pli_ave": 0.8581504702194357, + "lddt_pli_wave": 0.8581504702194357, + "lddt_lp_ave": 1.0, + "lddt_lp_wave": 1.0, + "bisy_rmsd_ave": 1.6171839155715722, + "bisy_rmsd_wave": 1.6171839155715722, + "posebusters_mol_pred_loaded": 1.0, + "posebusters_mol_cond_loaded": 1.0, + "posebusters_sanitization": 1.0, + "posebusters_inchi_convertible": 1.0, + "posebusters_all_atoms_connected": 1.0, + "posebusters_bond_lengths": 1.0, + "posebusters_bond_angles": 1.0, + "posebusters_internal_steric_clash": 1.0, + "posebusters_aromatic_ring_flatness": 1.0, + "posebusters_double_bond_flatness": 1.0, + "posebusters_internal_energy": 1.0, + "posebusters_protein-ligand_maximum_distance": 1.0, + "posebusters_minimum_distance_to_protein": 0.0, + "posebusters_minimum_distance_to_organic_cofactors": 1.0, + "posebusters_minimum_distance_to_inorganic_cofactors": 1.0, + "posebusters_minimum_distance_to_waters": 1.0, + "posebusters_volume_overlap_with_protein": 1.0, + "posebusters_volume_overlap_with_organic_cofactors": 1.0, + "posebusters_volume_overlap_with_inorganic_cofactors": 1.0, + "posebusters_volume_overlap_with_waters": 1.0, + "lddt_pli": 0.8581504702194357, + "lddt_lp": 1.0, + "bisy_rmsd": 1.6171839155715722, + } } + for k in true_scores: assert k in scores if type(true_scores[k]) == float: @@ -125,49 +130,53 @@ def test_multi_protein_single_ligand_scoring( score_posebusters=True, ).summarize_scores() true_scores = { - "model": "1ai5__1__1.A_1.B__1.D", - "reference": "1ai5__1__1.A_1.B__1.D", - "num_reference_ligands": 1, - "num_model_ligands": 1, - "num_reference_proteins": 2, - "num_model_proteins": 2, - "fraction_reference_ligands_mapped": 1.0, - "fraction_model_ligands_mapped": 1.0, - "lddt_pli_ave": 0.5106951871657754, - "lddt_pli_wave": 0.5106951871657754, - "bisy_rmsd_ave": 3.6651428915654645, - "bisy_rmsd_wave": 3.6651428915654645, - "lddt_lp_ave": 1.0, - "lddt_lp_wave": 1.0, - "posebusters_mol_pred_loaded": True, - "posebusters_mol_cond_loaded": True, - "posebusters_sanitization": True, - "posebusters_all_atoms_connected": True, - "posebusters_bond_lengths": True, - "posebusters_bond_angles": True, - "posebusters_internal_steric_clash": True, - "posebusters_aromatic_ring_flatness": True, - "posebusters_double_bond_flatness": True, - "posebusters_internal_energy": True, - "posebusters_protein-ligand_maximum_distance": True, - "posebusters_minimum_distance_to_protein": False, - "posebusters_minimum_distance_to_organic_cofactors": True, - "posebusters_minimum_distance_to_inorganic_cofactors": True, - "posebusters_minimum_distance_to_waters": True, - "posebusters_volume_overlap_with_protein": False, - "posebusters_volume_overlap_with_organic_cofactors": True, - "posebusters_volume_overlap_with_inorganic_cofactors": True, - "posebusters_volume_overlap_with_waters": True, - "fraction_reference_proteins_mapped": 1.0, - "fraction_model_proteins_mapped": 1.0, - "lddt": 1.0, - "bb_lddt": 1.0, - "per_chain_lddt_ave": 0.9991023339317774, - "per_chain_bb_lddt_ave": 1.0, - "qs_global": 1.0, - "qs_best": 1.0, - "dockq_wave": 1.0, - "dockq_ave": 1.0, + "LIG_0": { + "model": "1ai5__1__1.A_1.B__1.D", + "reference": "1ai5__1__1.A_1.B__1.D", + "num_reference_ligands": 1, + "num_model_ligands": 1, + "num_reference_proteins": 2, + "num_model_proteins": 2, + "fraction_reference_ligands_mapped": 1.0, + "fraction_model_ligands_mapped": 1.0, + "lddt": 1.0, + "bb_lddt": 1.0, + "per_chain_lddt_ave": 0.9991023339317774, + "per_chain_bb_lddt_ave": 1.0, + "qs_global": 1.0, + "qs_best": 1.0, + "dockq_wave": 1.0, + "dockq_ave": 1.0, + "lddt_pli_ave": 0.5106951871657754, + "lddt_pli_wave": 0.5106951871657754, + "lddt_lp_ave": 1.0, + "lddt_lp_wave": 1.0, + "bisy_rmsd_ave": 3.6651428915654645, + "bisy_rmsd_wave": 3.6651428915654645, + "posebusters_mol_pred_loaded": 1.0, + "posebusters_mol_cond_loaded": 1.0, + "posebusters_sanitization": 1.0, + "posebusters_inchi_convertible": 1.0, + "posebusters_all_atoms_connected": 1.0, + "posebusters_bond_lengths": 1.0, + "posebusters_bond_angles": 1.0, + "posebusters_internal_steric_clash": 1.0, + "posebusters_aromatic_ring_flatness": 1.0, + "posebusters_double_bond_flatness": 1.0, + "posebusters_internal_energy": 1.0, + "posebusters_protein-ligand_maximum_distance": 1.0, + "posebusters_minimum_distance_to_protein": 0.0, + "posebusters_minimum_distance_to_organic_cofactors": 1.0, + "posebusters_minimum_distance_to_inorganic_cofactors": 1.0, + "posebusters_minimum_distance_to_waters": 1.0, + "posebusters_volume_overlap_with_protein": 0.0, + "posebusters_volume_overlap_with_organic_cofactors": 1.0, + "posebusters_volume_overlap_with_inorganic_cofactors": 1.0, + "posebusters_volume_overlap_with_waters": 1.0, + "lddt_pli": 0.5106951871657754, + "lddt_lp": 1.0, + "bisy_rmsd": 3.6651428915654645, + } } for k in true_scores: assert k in scores @@ -179,6 +188,9 @@ def test_multi_protein_single_ligand_scoring( assert scores[k] == true_scores[k], f"{k}: {scores[k]} != {true_scores[k]}" +# TODO: add multiligand test! + + @pytest.fixture def prediction_csv(read_plinder_eval_mount, mock_cpl_eval, tmp_path): csv = f"""\ @@ -210,7 +222,7 @@ def test_evaluate_stratify_plot_cmds(prediction_csv, mock_cpl_eval): score_df = pd.read_parquet(f"{prediction_csv.parent}/scores.parquet") assert np.allclose( - score_df.sort_values(by="reference").bisy_rmsd_wave.to_list(), + score_df.sort_values(by="reference").bisy_rmsd.to_list(), [1.617184, 3.665143], )