Skip to content

Commit

Permalink
Version 1.0.0 push
Browse files Browse the repository at this point in the history
  • Loading branch information
SGenheden committed Feb 11, 2022
1 parent 8e1779d commit ec507d0
Show file tree
Hide file tree
Showing 11 changed files with 416 additions and 47 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# CHANGELOG

## Version 1.0.0 - 2022-02-11

- Generalization of featurization for TED calculations
- Utility routines for route property calculations

## Version 0.2.1 - 2021-12-21

### Trivial changes
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,5 @@ The software is licensed under the MIT license (see LICENSE file), and is free a

## References

1. Genheden S, Engkvist O, Bjerrum E (2020) Clustering of synthetic routes using tree edit distance. ChemRxiv. Preprint. https://doi.org/10.26434/chemrxiv.13372475.v1
2. Genheden S, Engkvist O, Bjerrum E (2021) Fast prediction of distances between synthetic routes with deep learning. ChemRxiv. Preprint. https://doi.org/10.26434/chemrxiv.14778150.v1
1. Genheden S, Engkvist O, Bjerrum E (2021) Clustering of synthetic routes using tree edit distance. J. Chem. Inf. Model. 61:3899–3907 [https://doi.org/10.1021/acs.jcim.1c00232](https://doi.org/10.1021/acs.jcim.1c00232)
2. Genheden S, Engkvist O, Bjerrum E (2022) Fast prediction of distances between synthetic routes with deep learning. Mach. Learn. Sci. Technol. 3:015018 [https://doi.org/10.1088/2632-2153/ac4a91](https://doi.org/10.1088/2632-2153/ac4a91)
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
project = "route-distances"
copyright = "2021, Molecular AI group"
author = "Molecular AI group"
release = "0.2.1"
release = "1.0.0"

extensions = [
"sphinx.ext.autodoc",
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "route-distances"
version = "0.2.1"
version = "1.0.0"
description = "Models for calculating distances between synthesis routes"
authors = ["Genheden, Samuel <[email protected]>"]
license = "MIT"
Expand Down
62 changes: 31 additions & 31 deletions route_distances/ted/reactiontree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@
import itertools
import math
from copy import deepcopy
from typing import List, Union, Iterable, Tuple, Dict, Any
from typing import List, Union, Iterable, Tuple, Callable, Optional
from logging import getLogger

import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from apted import APTED as Apted

from route_distances.ted.utils import TreeContent, AptedConfig
from route_distances.ted.utils import (
TreeContent,
AptedConfig,
StandardFingerprintFactory,
)
from route_distances.validation import validate_dict
from route_distances.utils.type_utils import StrDict

Expand All @@ -30,6 +32,8 @@ class ReactionTreeWrapper:
:param reaction_tree: the reaction tree to wrap
:param content: the content of the route to consider in the distance calculation
:param exhaustive_limit: if the number of possible ordered trees are below this limit create them all
:param fp_factory: the factory of the fingerprint, Morgan fingerprint for molecules and reactions by default
:param dist_func: the distance function to use when renaming nodes
"""

_index_permutations = {
Expand All @@ -41,8 +45,8 @@ def __init__(
reaction_tree: StrDict,
content: Union[str, TreeContent] = TreeContent.MOLECULES,
exhaustive_limit: int = 20,
fp_radius: int = 2,
fp_nbits: int = 2048,
fp_factory: Callable[[StrDict, Optional[StrDict]], None] = None,
dist_func: Callable[[np.ndarray, np.ndarray], float] = None,
) -> None:
validate_dict(reaction_tree)
single_node_tree = not bool(reaction_tree.get("children", []))
Expand All @@ -56,11 +60,11 @@ def __init__(
self._content = TreeContent(content)
self._base_tree = deepcopy(reaction_tree)

self._fp_params = (fp_radius, fp_nbits)
self._add_mol_fingerprints(self._base_tree)
self._fp_factory = fp_factory or StandardFingerprintFactory()
self._add_fingerprints(self._base_tree)

if self._content != TreeContent.MOLECULES and not single_node_tree:
self._add_rxn_fingerprint(self._base_tree["children"][0], self._base_tree)
self._add_fingerprints(self._base_tree["children"][0], self._base_tree)

if self._content == TreeContent.MOLECULES:
self._base_tree = self._remove_children_nodes(self._base_tree)
Expand All @@ -78,6 +82,8 @@ def __init__(
else:
self._trees.append(self._base_tree)

self._dist_func = dist_func

@property
def info(self) -> StrDict:
"""Return a dictionary with internal information about the wrapper"""
Expand Down Expand Up @@ -158,31 +164,24 @@ def distance_to_with_sorting(self, other: "ReactionTreeWrapper") -> float:
:param other: another tree to calculate distance to
:return: the distance
"""
config = AptedConfig(sort_children=True)
config = AptedConfig(sort_children=True, dist_func=self._dist_func)
return Apted(self.first_tree, other.first_tree, config).compute_edit_distance()

def _add_mol_fingerprints(self, tree: Dict[str, Any]) -> None:
mol = Chem.MolFromSmiles(tree["smiles"])
rd_fp = AllChem.GetMorganFingerprintAsBitVect(mol, *self._fp_params)
tree["fingerprint"] = np.zeros((1,), dtype=np.int8)
DataStructs.ConvertToNumpyArray(rd_fp, tree["fingerprint"])
def _add_fingerprints(self, tree: StrDict, parent: StrDict = None) -> None:
if "fingerprint" not in tree:
try:
self._fp_factory(tree, parent)
except ValueError:
pass
if "fingerprint" not in tree:
tree["fingerprint"] = []
tree["sort_key"] = "".join(f"{digit}" for digit in tree["fingerprint"])
if "children" not in tree:
tree["children"] = []

for child in tree["children"]:
for grandchild in child["children"]:
self._add_mol_fingerprints(grandchild)

def _add_rxn_fingerprint(self, node: StrDict, parent: StrDict) -> None:
node["fingerprint"] = parent["fingerprint"].copy()
for reactant in node["children"]:
node["fingerprint"] -= reactant["fingerprint"]
node["sort_key"] = "".join(f"{digit}" for digit in node["fingerprint"])

for child in node["children"]:
for grandchild in child.get("children", []):
self._add_rxn_fingerprint(grandchild, child)
self._add_fingerprints(grandchild, child)

def _create_all_trees(self) -> None:
self._trees = []
Expand Down Expand Up @@ -212,7 +211,7 @@ def _distance_iter_exhaustive(self, other: "ReactionTreeWrapper") -> _FloatItera
self._logger.debug(
f"APTED: Exhaustive search. {len(self.trees)} {len(other.trees)}"
)
config = AptedConfig(randomize=False)
config = AptedConfig(randomize=False, dist_func=self._dist_func)
for tree1, tree2 in itertools.product(self.trees, other.trees):
yield Apted(tree1, tree2, config).compute_edit_distance()

Expand All @@ -222,10 +221,10 @@ def _distance_iter_random(
self._logger.debug(
f"APTED: Heuristic search. {len(self.trees)} {len(other.trees)}"
)
config = AptedConfig(randomize=False)
config = AptedConfig(randomize=False, dist_func=self._dist_func)
yield Apted(self.first_tree, other.first_tree, config).compute_edit_distance()

config = AptedConfig(randomize=True)
config = AptedConfig(randomize=True, dist_func=self._dist_func)
for _ in range(ntimes):
yield Apted(
self.first_tree, other.first_tree, config
Expand All @@ -244,7 +243,7 @@ def _distance_iter_semi_exhaustive(
first_wrapper = other
second_wrapper = self

config = AptedConfig(randomize=False)
config = AptedConfig(randomize=False, dist_func=self._dist_func)
for tree1 in first_wrapper.trees:
yield Apted(
tree1, second_wrapper.first_tree, config
Expand Down Expand Up @@ -279,7 +278,8 @@ def _recurse_tree(node):
def _make_base_copy(node: StrDict) -> StrDict:
return {
"type": node["type"],
"smiles": node["smiles"],
"smiles": node.get("smiles", ""),
"metadata": node.get("metadata"),
"fingerprint": node["fingerprint"],
"sort_key": node["sort_key"],
"children": [],
Expand Down
63 changes: 60 additions & 3 deletions route_distances/ted/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
from enum import Enum
from operator import itemgetter

import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from apted import Config as BaseAptedConfig
from scipy.spatial.distance import jaccard as jaccard_dist

from route_distances.utils.type_utils import StrDict
from route_distances.utils.type_utils import StrDict, Callable


class TreeContent(str, Enum):
Expand All @@ -27,20 +30,27 @@ class AptedConfig(BaseAptedConfig):
:param randomize: if True, the children will be shuffled
:param sort_children: if True, the children will be sorted
:param dist_func: the distance function used for renaming nodes, Jaccard by default
"""

def __init__(self, randomize: bool = False, sort_children: bool = False) -> None:
def __init__(
self,
randomize: bool = False,
sort_children: bool = False,
dist_func: Callable[[np.ndarray, np.ndarray], float] = None,
) -> None:
super().__init__()
self._randomize = randomize
self._sort_children = sort_children
self._dist_func = dist_func or jaccard_dist

def rename(self, node1: StrDict, node2: StrDict) -> float:
if node1["type"] != node2["type"]:
return 1

fp1 = node1["fingerprint"]
fp2 = node2["fingerprint"]
return jaccard_dist(fp1, fp2)
return self._dist_func(fp1, fp2)

def children(self, node: StrDict) -> List[StrDict]:
if self._sort_children:
Expand All @@ -50,3 +60,50 @@ def children(self, node: StrDict) -> List[StrDict]:
children = list(node["children"])
random.shuffle(children)
return children


class StandardFingerprintFactory:
    """
    Calculate Morgan fingerprint for molecules, and difference fingerprints for reactions

    The factory is called with a tree node (and, for reaction nodes, the parent
    molecule node) and stores its result in the node itself under the
    ``"fingerprint"`` and ``"sort_key"`` keys. A node that already has a
    ``"fingerprint"`` keeps it; only its sort key is recomputed.

    :param radius: the radius of the fingerprint
    :param nbits: the fingerprint lengths
    """

    def __init__(self, radius: int = 2, nbits: int = 2048) -> None:
        # Kept as a tuple so it can be splatted into the RDKit call.
        self._fp_params = (radius, nbits)

    def __call__(self, tree: StrDict, parent: StrDict = None) -> None:
        """
        Add fingerprints to the given node and the nodes of the same kind below it

        :param tree: the node to featurize, dispatched on its ``"type"`` key
        :param parent: the parent node, required when `tree` is a reaction node
        :raises ValueError: if `tree` is a reaction node and no parent is given,
            or if the SMILES of a molecule node cannot be parsed
        """
        if tree["type"] != "reaction":
            self._add_mol_fingerprints(tree)
            return

        if parent is None:
            raise ValueError(
                "Must specify parent when making Morgan fingerprints for reaction nodes"
            )
        self._add_rxn_fingerprint(tree, parent)

    @staticmethod
    def _sort_key(fingerprint) -> str:
        # Concatenate the string form of every fingerprint element.
        return "".join(f"{digit}" for digit in fingerprint)

    def _add_mol_fingerprints(self, tree: StrDict) -> None:
        """
        Add a Morgan bit-vector fingerprint to a molecule node and, recursively,
        to the molecule nodes two levels further down (children of its children)

        :param tree: the molecule node, must have a ``"smiles"`` key unless it
            already has a ``"fingerprint"``
        :raises ValueError: if the SMILES cannot be parsed by RDKit
        """
        if "fingerprint" not in tree:
            smiles = tree["smiles"]
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # MolFromSmiles signals a parse failure by returning None; fail
                # with a ValueError instead of passing None on to the
                # fingerprint call, where it surfaces as an opaque argument error.
                raise ValueError(
                    f"Unable to create molecule from SMILES: {smiles!r}"
                )
            rd_fp = AllChem.GetMorganFingerprintAsBitVect(mol, *self._fp_params)
            # Placeholder array that ConvertToNumpyArray fills in place.
            tree["fingerprint"] = np.zeros((1,), dtype=np.int8)
            DataStructs.ConvertToNumpyArray(rd_fp, tree["fingerprint"])
        tree["sort_key"] = self._sort_key(tree["fingerprint"])
        if "children" not in tree:
            tree["children"] = []

        for child in tree["children"]:
            for grandchild in child["children"]:
                self._add_mol_fingerprints(grandchild)

    def _add_rxn_fingerprint(self, node: StrDict, parent: StrDict) -> None:
        """
        Add a difference fingerprint to a reaction node: the fingerprint of the
        parent minus the fingerprints of all children of the node. Reaction nodes
        two levels further down are processed recursively.

        The parent and the children of the node must already have a
        ``"fingerprint"``.

        :param node: the reaction node
        :param parent: the parent (product) node of the reaction
        """
        if "fingerprint" not in node:
            # Work on a copy so that the parent's own fingerprint is not modified.
            difference = parent["fingerprint"].copy()
            for reactant in node["children"]:
                difference -= reactant["fingerprint"]
            node["fingerprint"] = difference
        node["sort_key"] = self._sort_key(node["fingerprint"])

        for child in node["children"]:
            for grandchild in child.get("children", []):
                self._add_rxn_fingerprint(grandchild, child)
34 changes: 25 additions & 9 deletions route_distances/tools/cluster_aizynth_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from __future__ import annotations
import argparse
import warnings
import functools
import time
import math
from typing import List

import pandas as pd
from tqdm import tqdm

import route_distances.lstm.defaults as defaults
from route_distances.route_distances import route_distances_calculator
from route_distances.clustering import ClusteringHelper
from route_distances.utils.type_utils import RouteDistancesCalculator
Expand All @@ -19,9 +20,12 @@ def _get_args() -> argparse.Namespace:
"Tool to calculate pairwise distances for AiZynthFinder output"
)
parser.add_argument("--files", nargs="+", required=True)
parser.add_argument("--fp_size", type=int, default=defaults.FP_SIZE)
parser.add_argument("--lstm_size", type=int, default=defaults.LSTM_SIZE)
parser.add_argument("--model", required=True)
parser.add_argument("--only_clustering", action="store_true", default=False)
parser.add_argument("--nclusters", type=int, default=None)
parser.add_argument("--min_density", type=int, default=None)
parser.add_argument("--output", default="finder_output_dist.hdf5")
return parser.parse_args()

Expand Down Expand Up @@ -51,12 +55,21 @@ def _calc_distances(row: pd.Series, calculator: RouteDistancesCalculator) -> pd.
return pd.Series(dict_)


def _do_clustering(row: pd.Series, nclusters: int) -> pd.Series:
def _do_clustering(
row: pd.Series, nclusters: int, min_density: int = None
) -> pd.Series:
if row.distance_matrix == [[0.0]] or len(row.trees) < 3:
return pd.Series({"cluster_labels": [], "cluster_time": 0})

if min_density is None:
max_clusters = min(len(row.trees), 10)
else:
max_clusters = int(math.ceil(len(row.trees) / min_density))

time0 = time.perf_counter_ns()
labels = ClusteringHelper.cluster(row.distance_matrix, nclusters).tolist()
labels = ClusteringHelper.cluster(
row.distance_matrix, nclusters, max_clusters=max_clusters
).tolist()
cluster_time = (time.perf_counter_ns() - time0) * 1e-9
return pd.Series({"cluster_labels": labels, "cluster_time": cluster_time})

Expand All @@ -76,21 +89,24 @@ def main() -> None:
calculator = route_distances_calculator(
"lstm",
model_path=args.model,
fp_size=args.fp_size,
lstm_size=args.lstm_size,
)

if not args.only_clustering:
func = functools.partial(
_calc_distances, calculator=calculator
)
dist_data = data.progress_apply(func, axis=1)
dist_data = data.progress_apply(_calc_distances, axis=1, calculator=calculator)
data = data.assign(
distance_matrix=dist_data.distance_matrix,
distances_time=dist_data.distances_time,
)

if args.nclusters is not None:
func = functools.partial(_do_clustering, nclusters=args.nclusters)
cluster_data = data.progress_apply(func, axis=1)
cluster_data = data.progress_apply(
_do_clustering,
axis=1,
nclusters=args.nclusters,
min_density=args.min_density,
)
data = data.assign(
cluster_labels=cluster_data.cluster_labels,
cluster_time=cluster_data.cluster_time,
Expand Down
Loading

0 comments on commit ec507d0

Please sign in to comment.