From f522d16fd83156b2d7a82b97a683eda766a4e111 Mon Sep 17 00:00:00 2001 From: Alexander Matthew Payne Date: Sat, 27 Apr 2024 16:15:27 -0400 Subject: [PATCH] write main mcss clustering loop --- examples/chemical_series_clustering.ipynb | 204 +++++++++++++++++++++- harbor/clustering/hierarchical.py | 204 +++++++++++++++++++--- harbor/plotting/ligands.py | 79 +++++++++ harbor/similarity/mcss.py | 18 ++ 4 files changed, 474 insertions(+), 31 deletions(-) diff --git a/examples/chemical_series_clustering.ipynb b/examples/chemical_series_clustering.ipynb index 2d5b018..29ed411 100644 --- a/examples/chemical_series_clustering.ipynb +++ b/examples/chemical_series_clustering.ipynb @@ -66,7 +66,7 @@ "outputs": [], "source": [ "from rdkit import Chem\n", - "mols = Chem.SDMolSupplier(mypath)" + "mols = Chem.SDMolSupplier(str(mypath))" ] }, { @@ -93,7 +93,7 @@ "metadata": {}, "outputs": [], "source": [ - "# define the grid to show the scafffolds\n", + "# define the grid to show the scaffolds\n", "grid = mols2grid.display(mols)" ] }, @@ -106,12 +106,206 @@ "grid" ] }, + { + "cell_type": "markdown", + "source": [ + "# MCSS-based Clustering" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "from harbor.clustering.hierarchical import ClusterResults, ClusterCenter, HeirarchicalClustering\n", + "from openeye import oechem" + ], + "metadata": { + "collapsed": false + }, + "execution_count": null + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "mol: Chem.Mol = mols[0]\n", + "mol.GetPropsAsDict()" + ], + "metadata": { + "collapsed": false + }, + "execution_count": null + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "oemols = []\n", + "mol_ids = []\n", + "for rdkit_mol in mols[:20]:\n", + " smiles = Chem.MolToSmiles(rdkit_mol)\n", + " properties = rdkit_mol.GetPropsAsDict()\n", + " mol_ids.append(properties[\"Compound_ID\"])\n", + " mol = oechem.OEMol()\n", + " oechem.OESmilesToMol(mol, smiles)\n", + " oemols.append(mol)" + ], + "metadata": { + "collapsed": false + }, + "execution_count": null + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "from harbor.clustering import hierarchical as h\n", + "from importlib import reload\n", + "reload(h)" + ], + "metadata": { + "collapsed": false + }, + "execution_count": null + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "clusterer = h.HeirarchicalClustering(molecules=oemols, mol_ids=mol_ids)" + ], + "metadata": { + "collapsed": false + }, + "execution_count": null + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "clusters = clusterer.cluster(max_iterations=10)" + ], + "metadata": { + "collapsed": false + }, + "execution_count": null + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "len(clusters)" + ], + "metadata": { + "collapsed": false + }, + "execution_count": null + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "def get_descendents(cluster):\n", + " descendents = []\n", + " for child in cluster.children:\n", + " if isinstance(child, str):\n", + " descendents.append(cluster)\n", + " else:\n", + " descendents.extend(get_descendents(child))\n", + " return descendents" + ], + "metadata": { + "collapsed": false + }, + "execution_count": null + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "from harbor.plotting import ligands as l\n", + "reload(l)" + ], + "metadata": { + "collapsed": false + }, + "execution_count": null + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "ids_found = []\n", + "for cluster_id, cluster in clusters.items():\n", + " print(f\"Cluster {cluster_id}\")\n", + " descendents = get_descendents(cluster)\n", + " print(f\"Children: {len(descendents)}\")\n", + " l.plot_ligands_with_mcs(filename=f\"cluster_{cluster_id}.png\", mols=[desc.repr for desc in descendents], mcs_mol=cluster.repr)\n", + " ids_found.extend([desc.children[0] for desc in descendents])" + ], + "metadata": { + "collapsed": false + }, + "execution_count": null + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "set(ids_found)" + ], + "metadata": { + "collapsed": false + }, + "execution_count": null + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "set(mol_ids) - set(ids_found)" + ], + "metadata": { + "collapsed": false + }, + "execution_count": null + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "def get_row_col(i, max_cols, zero_indexed=True):\n", + " row = i // max_cols + (0 if zero_indexed else 1)\n", + " col = i % max_cols + (0 if zero_indexed else 1)\n", + " return row, col" + ], + "metadata": { + "collapsed": false + }, + "execution_count": null + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "for i in range(6):\n", + " print(get_row_col(i, 4, zero_indexed=False))" + ], + "metadata": { + "collapsed": false + }, + "execution_count": null + }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, "outputs": [], - "source": [] + "source": [], + "metadata": { + "collapsed": false + } } ], "metadata": { diff --git a/harbor/clustering/hierarchical.py b/harbor/clustering/hierarchical.py index b30c72e..ea88978 100644 --- a/harbor/clustering/hierarchical.py +++ b/harbor/clustering/hierarchical.py @@ -1,47 +1,50 @@ from harbor.similarity.mcss import get_mcs_mol, get_n_to_n_mcs from openeye import oechem -from pydantic import Field, BaseModel +from pydantic import Field, BaseModel, model_validator +from typing import Union import numpy as np -class Cluster(BaseModel): +class ClusterCenter(BaseModel): class Config: arbitrary_types_allowed = True cluster_id: str = Field(..., description="An id") - children: list[str] = Field(..., description="Children") + children: list[Union[str, "ClusterCenter"]] = Field(..., description="Children") repr: oechem.OEMol height: int = Field(..., description="maximum number of layers above 0") @classmethod - def from_mol(cls, layer_id, mol: oechem.OEMol) -> "Cluster": + def from_mol(cls, layer_id, mol: oechem.OEMol) -> "ClusterCenter": return cls( cluster_id=f"{0}_{layer_id}", repr=mol, - children=[f"{mol.GetTitle()}"], + children=[f"{layer_id}"], height=0, ) @classmethod - def from_clusters(cls, layer_id, cluster1: "Cluster", cluster2: "Cluster"): - height = max(cluster1.height, cluster2.height) + 1 + def from_clusters( + cls, height, layer_id, cluster1: "ClusterCenter", cluster2: "ClusterCenter" + ): repr = get_mcs_mol(cluster1.repr, cluster2.repr) return cls( cluster_id=f"{height}_{layer_id}", - children=[cluster1.cluster_id, cluster2.cluster_id], + children=[cluster1, cluster2], repr=repr, height=height, ) class ClusterResults(BaseModel): - new: list[Cluster] = Field(..., description="Newly formed clusters") - singles: list[Cluster] = Field( - ..., description="Clusters which are out for this round" + new: list[ClusterCenter] = Field(..., description="Newly formed clusters") + singles: list[ClusterCenter] = Field( + ..., description="Cluster centers which are out for this round" ) - outliers: list[Cluster] = Field( + outliers: list[ClusterCenter] = Field( ..., - description="Cluster centers further than the cutoff from any other molecule and therefore should be ignored for the rest of the clustering", + description="Cluster centers further than the cutoff from any other molecule " + "and therefore should be ignored for the rest of the clustering", ) @@ -49,35 +52,72 @@ def get_clusters_from_mcs_matrix( matrix: np.ndarray, clusters, cutoff, + height: int = Field(..., description="maximum number of layers above 0"), ) -> ClusterResults: """ Get all pairs for which the maximum is reciprocal and is greater than """ clusters = np.array(clusters) + print([cluster.cluster_id for cluster in clusters]) - potential_match = np.argsort(matrix)[:, -2] - max_mcs = np.sort(matrix)[:, -2] - potential_match = np.array(potential_match, dtype="object") + # set the diagonal to 0 + np.fill_diagonal(matrix, 0) + print(matrix) + + potential_match = np.argsort(matrix)[:, -1] + + duplicates = np.where(np.bincount(potential_match) > 1)[0] + max_mcs = np.sort(matrix)[:, -1] + print(potential_match) + print(max_mcs) pairs = [] singles = [] outliers = [] + ignore = [] for i in range(len(potential_match)): + print(i) if max_mcs[i] < cutoff: outliers.append(i) continue - if np.isnan(potential_match[i]): - continue + j = potential_match[i] - if np.isnan(potential_match[j]): + if i in ignore: + print(i) continue - if i == potential_match[j]: - pairs.append((i, j)) - potential_match[j] = np.nan - else: + if j in singles: singles.append(i) + continue + if j in duplicates: + # pick a single pair to combine, and remove all others + all_matches = np.where(potential_match == j)[0] + print(i, j, all_matches) + + # for each reciprocal match, find the one with the highest mcs + best_match = all_matches[np.argmax(max_mcs[all_matches])] + if best_match == i: + pairs.append((i, j)) + ignore.extend([i, j]) + else: + singles.append(i) + else: + if potential_match[j] == i: + pairs.append((i, j)) + ignore.extend([i, j]) + else: + singles.append(i) + ignore.append(i) + print("Pairs", pairs) + print("Singles", singles) + print("Outliers", outliers) + print("Ignore", ignore) + print("Pairs", pairs) + print("Singles", singles) + print("Outliers", outliers) + print("Ignore", ignore) + new = [ - Cluster.from_clusters(i, clusters[j], clusters[k]) + ClusterCenter.from_clusters(height, i, clusters[j], clusters[k]) for i, (j, k) in enumerate(pairs) ] return ClusterResults( @@ -89,11 +129,123 @@ def get_clusters_from_mcs_matrix( def get_clusters_from_pairs(clusters, pairs): return [ - Cluster.from_clusters(i, clusters[j], clusters[k]) + ClusterCenter.from_clusters(i, clusters[j], clusters[k]) for i, (j, k) in enumerate(pairs) ] -def mcs_wrapper(clusters: list[Cluster]): +def mcs_wrapper(clusters: list[ClusterCenter]): mols = [cluster.repr for cluster in clusters] return get_n_to_n_mcs(mols) + + +class HeirarchicalClustering(BaseModel): + """ + A class to run and process heirarchical clustering for molecules. + """ + + molecules: list[oechem.OEMol] = Field( + ..., description="A list of molecules to cluster" + ) + mol_ids: list[str] = Field(description="Molecule IDs") + + class Config: + arbitrary_types_allowed = True + + from pydantic import field_validator, model_validator + + @model_validator(mode="before") + def check_lengths_match(cls, values): + mol_ids = values["mol_ids"] + molecules = values["molecules"] + + if not mol_ids: + mol_ids = range(len(molecules)) + + if not len(mol_ids) == len(molecules): + raise ValueError( + f"Length of mol_ids ({len(mol_ids)}) does not match molecules ({len(molecules)}" + ) + return values + + @property + def num_mols(self): + return len(self.molecules) + + def cluster(self, max_iterations: int = 10, cutoff: int = 12): + + # Make initial clusters + clusters = [ + ClusterCenter.from_mol(mol_id, mol) + for mol_id, mol in zip(self.mol_ids, self.molecules) + ] + + # keep track of the molecule outliers + cluster_records = {} + + # main clustering loop + i = 0 + while i <= max_iterations: + i += 1 + # generate n x n matrix of MCS values + mcs_matrix = mcs_wrapper(clusters) + + # get new cluster centers + results = get_clusters_from_mcs_matrix( + mcs_matrix, clusters, cutoff=cutoff, height=i + ) + + print("New clusters") + print( + [ + child.cluster_id if isinstance(child, ClusterCenter) else child + for cluster in results.new + for child in cluster.children + ] + ) + + print("Singles") + print( + [ + child.cluster_id if isinstance(child, ClusterCenter) else child + for cluster in results.singles + for child in cluster.children + ] + ) + + print("Outliers") + print( + [ + child.cluster_id if isinstance(child, ClusterCenter) else child + for cluster in results.outliers + for child in cluster.children + ] + ) + + # update clusters with the new cluster centers and the previously unmatched singles + clusters = results.new + results.singles + + cluster_records.update( + {cluster.cluster_id: cluster for cluster in results.outliers} + ) + if len(results.singles) == 0: + cluster_records.update( + {cluster.cluster_id: cluster for cluster in results.new} + ) + break + return cluster_records + + +class HierarchicalClusteringResults(BaseModel): + """ + A class to store the results of hierarchical clustering. + """ + + clusters: list[ClusterCenter] = Field(..., description="Cluster centers") + outliers: list[ClusterCenter] = Field( + ..., + description="Molecules different enough from the rest that they were excluded by the cutoff", + ) + clusterer: HeirarchicalClustering = Field(..., description="The clusterer used") + cutoff: int = Field(..., description="The cutoff used for clustering") + max_iterations: int = Field(..., description="The maximum number of iterations") diff --git a/harbor/plotting/ligands.py b/harbor/plotting/ligands.py index b573221..327530f 100644 --- a/harbor/plotting/ligands.py +++ b/harbor/plotting/ligands.py @@ -82,3 +82,82 @@ def plot_aligned_ligands( fitcell = grid.GetCell(row, col) oedepict.OERenderMolecule(fitcell, fitdisp) oedepict.OEWriteImage(filename, image) + + +def get_mcs_from_mcs_mol(mcs_mol: oechem.OEMol): + # Prep MCS + atomexpr = ( + oechem.OEExprOpts_Aromaticity + | oechem.OEExprOpts_AtomicNumber + | oechem.OEExprOpts_FormalCharge + | oechem.OEExprOpts_RingMember + ) + bondexpr = oechem.OEExprOpts_Aromaticity | oechem.OEExprOpts_BondOrder + + # create maximum common substructure object + pattern_query = oechem.OEQMol(mcs_mol) + pattern_query.BuildExpressions(atomexpr, bondexpr) + mcss = oechem.OEMCSSearch(pattern_query) + mcss.SetMCSFunc(oechem.OEMCSMaxAtoms()) + return mcss + + +def plot_ligands_with_mcs( + filename: str, + mcs_mol: oechem.OEMol, + mols=list[oechem.OEMol], + max_width: int = 4, + quantum_width=150, + quantum_height=200, +): + n_ligands = len(mols) + # Prepare image + cols = min(max_width, n_ligands) + rows = int(np.ceil(n_ligands / max_width)) + print(rows, cols) + image = oedepict.OEImage(quantum_width * cols, quantum_height * rows) + grid = oedepict.OEImageGrid(image, rows, cols) + opts = oedepict.OE2DMolDisplayOptions( + grid.GetCellWidth(), grid.GetCellHeight(), oedepict.OEScale_AutoScale + ) + opts.SetTitleLocation(oedepict.OETitleLocation_Bottom) + opts.SetHydrogenStyle(oedepict.OEHydrogenStyle_Hidden) + + largest_mol = max(mols, key=lambda x: x.NumAtoms()) + + refscale = oedepict.OEGetMoleculeScale(largest_mol, opts) + oedepict.OEPrepareDepiction(largest_mol) + refdisp = oedepict.OE2DMolDisplay(largest_mol, opts) + refcell = grid.GetCell(1, 1) + oedepict.OERenderMolecule(refcell, refdisp) + + mcss = get_mcs_from_mcs_mol(mcs_mol) + + for i, fitmol in enumerate(mols): + row, col = get_row_col(i, max_width, zero_indexed=False) + print(row, col) + + alignres = oedepict.OEPrepareAlignedDepiction(fitmol, mcss) + + if not alignres.IsValid(): + oedepict.OEPrepareDepiction(fitmol) + opts.SetScale(refscale) + fitdisp = oedepict.OE2DMolDisplay(fitmol, opts) + if alignres.IsValid(): + fitabset = oechem.OEAtomBondSet( + alignres.GetTargetAtoms(), alignres.GetTargetBonds() + ) + oedepict.OEAddHighlighting( + fitdisp, + oechem.OEBlueTint, + oedepict.OEHighlightStyle_BallAndStick, + fitabset, + ) + + oedepict.OEWriteImage(filename, image) + + +def get_row_col(i, max_cols, zero_indexed=True): + row = i // max_cols + (0 if zero_indexed else 1) + col = i % max_cols + (0 if zero_indexed else 1) + return row, col diff --git a/harbor/similarity/mcss.py b/harbor/similarity/mcss.py index b4e0ba8..1160257 100644 --- a/harbor/similarity/mcss.py +++ b/harbor/similarity/mcss.py @@ -87,3 +87,21 @@ def get_mcs_mol(mol1: oechem.OEMol, mol2: oechem.OEMol): except StopIteration: raise RuntimeError return core_fragment + + +def get_mcs_from_mcs_mol(mcs_mol: oechem.OEMol): + # Prep MCS + atomexpr = ( + oechem.OEExprOpts_Aromaticity + | oechem.OEExprOpts_AtomicNumber + | oechem.OEExprOpts_FormalCharge + | oechem.OEExprOpts_RingMember + ) + bondexpr = oechem.OEExprOpts_Aromaticity | oechem.OEExprOpts_BondOrder + + # create maximum common substructure object + pattern_query = oechem.OEQMol(mcs_mol) + pattern_query.BuildExpressions(atomexpr, bondexpr) + mcss = oechem.OEMCSSearch(pattern_query) + mcss.SetMCSFunc(oechem.OEMCSMaxAtoms()) + return mcss