diff --git a/CHANGELOG.md b/CHANGELOG.md index 1acd20d..c762268 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # CHANGELOG +## Version 1.0.0 - 2022-02-11 + +- Generalization of featurization for TED calculations +- Utility routines for route property calculations + ## Version 0.2.1 - 2021-12-21 ### Trivial changes diff --git a/README.md b/README.md index 4076adc..6c4245a 100644 --- a/README.md +++ b/README.md @@ -100,5 +100,5 @@ The software is licensed under the MIT license (see LICENSE file), and is free a ## References -1. Genheden S, Engkvist O, Bjerrum E (2020) Clustering of synthetic routes using tree edit distance. ChemRxiv. Preprint. https://doi.org/10.26434/chemrxiv.13372475.v1 -2. Genheden S, Engkvist O, Bjerrum E (2021) Fast prediction of distances between synthetic routes with deep learning. ChemRxiv. Preprint. https://doi.org/10.26434/chemrxiv.14778150.v1 +1. Genheden S, Engkvist O, Bjerrum E (2021) Clustering of synthetic routes using tree edit distance. J. Chem. Inf. Model. 61:3899–3907 [https://doi.org/10.1021/acs.jcim.1c00232](https://doi.org/10.1021/acs.jcim.1c00232) +2. Genheden S, Engkvist O, Bjerrum E (2022) Fast prediction of distances between synthetic routes with deep learning. Mach. Learn. Sci. Technol. 
3:015018 [https://doi.org/10.1088/2632-2153/ac4a91](https://doi.org/10.1088/2632-2153/ac4a91) diff --git a/docs/conf.py b/docs/conf.py index 82830fb..91ec349 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -6,7 +6,7 @@ project = "route-distances" copyright = "2021, Molecular AI group" author = "Molecular AI group" -release = "0.2.1" +release = "1.0.0" extensions = [ "sphinx.ext.autodoc", diff --git a/pyproject.toml b/pyproject.toml index 03b441f..4f303a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "route-distances" -version = "0.2.1" +version = "1.0.0" description = "Models for calculating distances between synthesis routes" authors = ["Genheden, Samuel "] license = "MIT" diff --git a/route_distances/ted/reactiontree.py b/route_distances/ted/reactiontree.py index f7935c7..337fe27 100644 --- a/route_distances/ted/reactiontree.py +++ b/route_distances/ted/reactiontree.py @@ -7,15 +7,17 @@ import itertools import math from copy import deepcopy -from typing import List, Union, Iterable, Tuple, Dict, Any +from typing import List, Union, Iterable, Tuple, Callable, Optional from logging import getLogger import numpy as np -from rdkit import Chem, DataStructs -from rdkit.Chem import AllChem from apted import APTED as Apted -from route_distances.ted.utils import TreeContent, AptedConfig +from route_distances.ted.utils import ( + TreeContent, + AptedConfig, + StandardFingerprintFactory, +) from route_distances.validation import validate_dict from route_distances.utils.type_utils import StrDict @@ -30,6 +32,8 @@ class ReactionTreeWrapper: :param reaction_tree: the reaction tree to wrap :param content: the content of the route to consider in the distance calculation :param exhaustive_limit: if the number of possible ordered trees are below this limit create them all + :param fp_factory: the factory of the fingerprint, Morgan fingerprint for molecules and reactions by default + :param dist_func: the distance function to use when 
renaming nodes """ _index_permutations = { @@ -41,8 +45,8 @@ def __init__( reaction_tree: StrDict, content: Union[str, TreeContent] = TreeContent.MOLECULES, exhaustive_limit: int = 20, - fp_radius: int = 2, - fp_nbits: int = 2048, + fp_factory: Callable[[StrDict, Optional[StrDict]], None] = None, + dist_func: Callable[[np.ndarray, np.ndarray], float] = None, ) -> None: validate_dict(reaction_tree) single_node_tree = not bool(reaction_tree.get("children", [])) @@ -56,11 +60,11 @@ def __init__( self._content = TreeContent(content) self._base_tree = deepcopy(reaction_tree) - self._fp_params = (fp_radius, fp_nbits) - self._add_mol_fingerprints(self._base_tree) + self._fp_factory = fp_factory or StandardFingerprintFactory() + self._add_fingerprints(self._base_tree) if self._content != TreeContent.MOLECULES and not single_node_tree: - self._add_rxn_fingerprint(self._base_tree["children"][0], self._base_tree) + self._add_fingerprints(self._base_tree["children"][0], self._base_tree) if self._content == TreeContent.MOLECULES: self._base_tree = self._remove_children_nodes(self._base_tree) @@ -78,6 +82,8 @@ def __init__( else: self._trees.append(self._base_tree) + self._dist_func = dist_func + @property def info(self) -> StrDict: """Return a dictionary with internal information about the wrapper""" @@ -158,31 +164,24 @@ def distance_to_with_sorting(self, other: "ReactionTreeWrapper") -> float: :param other: another tree to calculate distance to :return: the distance """ - config = AptedConfig(sort_children=True) + config = AptedConfig(sort_children=True, dist_func=self._dist_func) return Apted(self.first_tree, other.first_tree, config).compute_edit_distance() - def _add_mol_fingerprints(self, tree: Dict[str, Any]) -> None: - mol = Chem.MolFromSmiles(tree["smiles"]) - rd_fp = AllChem.GetMorganFingerprintAsBitVect(mol, *self._fp_params) - tree["fingerprint"] = np.zeros((1,), dtype=np.int8) - DataStructs.ConvertToNumpyArray(rd_fp, tree["fingerprint"]) + def 
_add_fingerprints(self, tree: StrDict, parent: StrDict = None) -> None: + if "fingerprint" not in tree: + try: + self._fp_factory(tree, parent) + except ValueError: + pass + if "fingerprint" not in tree: + tree["fingerprint"] = [] tree["sort_key"] = "".join(f"{digit}" for digit in tree["fingerprint"]) if "children" not in tree: tree["children"] = [] for child in tree["children"]: for grandchild in child["children"]: - self._add_mol_fingerprints(grandchild) - - def _add_rxn_fingerprint(self, node: StrDict, parent: StrDict) -> None: - node["fingerprint"] = parent["fingerprint"].copy() - for reactant in node["children"]: - node["fingerprint"] -= reactant["fingerprint"] - node["sort_key"] = "".join(f"{digit}" for digit in node["fingerprint"]) - - for child in node["children"]: - for grandchild in child.get("children", []): - self._add_rxn_fingerprint(grandchild, child) + self._add_fingerprints(grandchild, child) def _create_all_trees(self) -> None: self._trees = [] @@ -212,7 +211,7 @@ def _distance_iter_exhaustive(self, other: "ReactionTreeWrapper") -> _FloatItera self._logger.debug( f"APTED: Exhaustive search. {len(self.trees)} {len(other.trees)}" ) - config = AptedConfig(randomize=False) + config = AptedConfig(randomize=False, dist_func=self._dist_func) for tree1, tree2 in itertools.product(self.trees, other.trees): yield Apted(tree1, tree2, config).compute_edit_distance() @@ -222,10 +221,10 @@ def _distance_iter_random( self._logger.debug( f"APTED: Heuristic search. 
{len(self.trees)} {len(other.trees)}" ) - config = AptedConfig(randomize=False) + config = AptedConfig(randomize=False, dist_func=self._dist_func) yield Apted(self.first_tree, other.first_tree, config).compute_edit_distance() - config = AptedConfig(randomize=True) + config = AptedConfig(randomize=True, dist_func=self._dist_func) for _ in range(ntimes): yield Apted( self.first_tree, other.first_tree, config @@ -244,7 +243,7 @@ def _distance_iter_semi_exhaustive( first_wrapper = other second_wrapper = self - config = AptedConfig(randomize=False) + config = AptedConfig(randomize=False, dist_func=self._dist_func) for tree1 in first_wrapper.trees: yield Apted( tree1, second_wrapper.first_tree, config @@ -279,7 +278,8 @@ def _recurse_tree(node): def _make_base_copy(node: StrDict) -> StrDict: return { "type": node["type"], - "smiles": node["smiles"], + "smiles": node.get("smiles", ""), + "metadata": node.get("metadata"), "fingerprint": node["fingerprint"], "sort_key": node["sort_key"], "children": [], diff --git a/route_distances/ted/utils.py b/route_distances/ted/utils.py index bcb1607..112d343 100644 --- a/route_distances/ted/utils.py +++ b/route_distances/ted/utils.py @@ -5,10 +5,13 @@ from enum import Enum from operator import itemgetter +import numpy as np +from rdkit import Chem, DataStructs +from rdkit.Chem import AllChem from apted import Config as BaseAptedConfig from scipy.spatial.distance import jaccard as jaccard_dist -from route_distances.utils.type_utils import StrDict +from route_distances.utils.type_utils import StrDict, Callable class TreeContent(str, Enum): @@ -27,12 +30,19 @@ class AptedConfig(BaseAptedConfig): :param randomize: if True, the children will be shuffled :param sort_children: if True, the children will be sorted + :param dist_func: the distance function used for renaming nodes, Jaccard by default """ - def __init__(self, randomize: bool = False, sort_children: bool = False) -> None: + def __init__( + self, + randomize: bool = False, + 
sort_children: bool = False, + dist_func: Callable[[np.ndarray, np.ndarray], float] = None, + ) -> None: super().__init__() self._randomize = randomize self._sort_children = sort_children + self._dist_func = dist_func or jaccard_dist def rename(self, node1: StrDict, node2: StrDict) -> float: if node1["type"] != node2["type"]: @@ -40,7 +50,7 @@ def rename(self, node1: StrDict, node2: StrDict) -> float: fp1 = node1["fingerprint"] fp2 = node2["fingerprint"] - return jaccard_dist(fp1, fp2) + return self._dist_func(fp1, fp2) def children(self, node: StrDict) -> List[StrDict]: if self._sort_children: @@ -50,3 +60,50 @@ def children(self, node: StrDict) -> List[StrDict]: children = list(node["children"]) random.shuffle(children) return children + + +class StandardFingerprintFactory: + """ + Calculate Morgan fingerprint for molecules, and difference fingerprints for reactions + + :param radius: the radius of the fingerprint + :param nbits: the fingerprint lengths + """ + + def __init__(self, radius: int = 2, nbits: int = 2048) -> None: + self._fp_params = (radius, nbits) + + def __call__(self, tree: StrDict, parent: StrDict = None) -> None: + if tree["type"] == "reaction": + if parent is None: + raise ValueError( + "Must specify parent when making Morgan fingerprints for reaction nodes" + ) + self._add_rxn_fingerprint(tree, parent) + else: + self._add_mol_fingerprints(tree) + + def _add_mol_fingerprints(self, tree: StrDict) -> None: + if "fingerprint" not in tree: + mol = Chem.MolFromSmiles(tree["smiles"]) + rd_fp = AllChem.GetMorganFingerprintAsBitVect(mol, *self._fp_params) + tree["fingerprint"] = np.zeros((1,), dtype=np.int8) + DataStructs.ConvertToNumpyArray(rd_fp, tree["fingerprint"]) + tree["sort_key"] = "".join(f"{digit}" for digit in tree["fingerprint"]) + if "children" not in tree: + tree["children"] = [] + + for child in tree["children"]: + for grandchild in child["children"]: + self._add_mol_fingerprints(grandchild) + + def _add_rxn_fingerprint(self, node: 
StrDict, parent: StrDict) -> None: + if "fingerprint" not in node: + node["fingerprint"] = parent["fingerprint"].copy() + for reactant in node["children"]: + node["fingerprint"] -= reactant["fingerprint"] + node["sort_key"] = "".join(f"{digit}" for digit in node["fingerprint"]) + + for child in node["children"]: + for grandchild in child.get("children", []): + self._add_rxn_fingerprint(grandchild, child) diff --git a/route_distances/tools/cluster_aizynth_output.py b/route_distances/tools/cluster_aizynth_output.py index 51aadb5..0ee784b 100644 --- a/route_distances/tools/cluster_aizynth_output.py +++ b/route_distances/tools/cluster_aizynth_output.py @@ -2,13 +2,14 @@ from __future__ import annotations import argparse import warnings -import functools import time +import math from typing import List import pandas as pd from tqdm import tqdm +import route_distances.lstm.defaults as defaults from route_distances.route_distances import route_distances_calculator from route_distances.clustering import ClusteringHelper from route_distances.utils.type_utils import RouteDistancesCalculator @@ -19,9 +20,12 @@ def _get_args() -> argparse.Namespace: "Tool to calculate pairwise distances for AiZynthFinder output" ) parser.add_argument("--files", nargs="+", required=True) + parser.add_argument("--fp_size", type=int, default=defaults.FP_SIZE) + parser.add_argument("--lstm_size", type=int, default=defaults.LSTM_SIZE) parser.add_argument("--model", required=True) parser.add_argument("--only_clustering", action="store_true", default=False) parser.add_argument("--nclusters", type=int, default=None) + parser.add_argument("--min_density", type=int, default=None) parser.add_argument("--output", default="finder_output_dist.hdf5") return parser.parse_args() @@ -51,12 +55,21 @@ def _calc_distances(row: pd.Series, calculator: RouteDistancesCalculator) -> pd. 
return pd.Series(dict_) -def _do_clustering(row: pd.Series, nclusters: int) -> pd.Series: +def _do_clustering( + row: pd.Series, nclusters: int, min_density: int = None +) -> pd.Series: if row.distance_matrix == [[0.0]] or len(row.trees) < 3: return pd.Series({"cluster_labels": [], "cluster_time": 0}) + if min_density is None: + max_clusters = min(len(row.trees), 10) + else: + max_clusters = int(math.ceil(len(row.trees) / min_density)) + time0 = time.perf_counter_ns() - labels = ClusteringHelper.cluster(row.distance_matrix, nclusters).tolist() + labels = ClusteringHelper.cluster( + row.distance_matrix, nclusters, max_clusters=max_clusters + ).tolist() cluster_time = (time.perf_counter_ns() - time0) * 1e-9 return pd.Series({"cluster_labels": labels, "cluster_time": cluster_time}) @@ -76,21 +89,24 @@ def main() -> None: calculator = route_distances_calculator( "lstm", model_path=args.model, + fp_size=args.fp_size, + lstm_size=args.lstm_size, ) if not args.only_clustering: - func = functools.partial( - _calc_distances, calculator=calculator - ) - dist_data = data.progress_apply(func, axis=1) + dist_data = data.progress_apply(_calc_distances, axis=1, calculator=calculator) data = data.assign( distance_matrix=dist_data.distance_matrix, distances_time=dist_data.distances_time, ) if args.nclusters is not None: - func = functools.partial(_do_clustering, nclusters=args.nclusters) - cluster_data = data.progress_apply(func, axis=1) + cluster_data = data.progress_apply( + _do_clustering, + axis=1, + nclusters=args.nclusters, + min_density=args.min_density, + ) data = data.assign( cluster_labels=cluster_data.cluster_labels, cluster_time=cluster_data.cluster_time, diff --git a/route_distances/utils/routes.py b/route_distances/utils/routes.py new file mode 100644 index 0000000..3df1087 --- /dev/null +++ b/route_distances/utils/routes.py @@ -0,0 +1,144 @@ +""" Module containing helper routines for routes """ +from typing import Dict, Any, Set, List, Tuple + +import numpy as np + 
+from route_distances.utils.type_utils import StrDict + + +def calc_depth(tree_dict: StrDict, depth: int = 0) -> int: + """ + Calculate the depth of a route, recursively + + :param tree_dict: the route + :param depth: the current depth, don't specify for route + """ + children = tree_dict.get("children", []) + if children: + return max(calc_depth(child, depth + 1) for child in children) + return depth + + +def calc_llr(tree_dict: StrDict) -> int: + """ + Calculate the longest linear route for a synthetic route + + :param tree_dict: the route + """ + return calc_depth(tree_dict) // 2 + + +def extract_leaves( + tree_dict: StrDict, +) -> Set[str]: + """ + Extract a set with the SMILES of all the leaf nodes, i.e. + starting material + + :param tree_dict: the route + :return: a set of SMILES strings + """ + + def traverse(tree_dict: StrDict, leaves: Set[str]) -> None: + children = tree_dict.get("children", []) + if children: + for child in children: + traverse(child, leaves) + else: + leaves.add(tree_dict["smiles"]) + + leaves: Set[str] = set() + traverse(tree_dict, leaves) + return leaves + + +def is_solved(route: StrDict) -> bool: + """ + Find if a route is solved, i.e. if all starting material + is in stock. + + To be accurate, each molecule node needs to have an extra + boolean property called `in_stock`. + + :param route: the route to analyze + """ + + def find_leaves_not_in_stock(tree_dict: StrDict) -> None: + children = tree_dict.get("children", []) + if not children and not tree_dict.get("in_stock", True): + raise ValueError(f"child not in stock {tree_dict}") + elif children: + for child in children: + find_leaves_not_in_stock(child) + + try: + find_leaves_not_in_stock(route) + except ValueError: + return False + return True + + +def route_score( + tree_dict: StrDict, + mol_costs: Dict[bool, float] = None, + average_yield: float = 0.8, + reaction_cost: float = 1.0, +) -> float: + """ + Calculate the score of route using the method from + (Badowski et al. Chem Sci. 2019, 10, 4640). 
+ + The reaction cost is constant and the yield is an average yield. + The starting materials are assigned a cost based on whether they are in + stock or not. By default starting material in stock is assigned a + cost of 1 and starting material not in stock is assigned a cost of 10. + + To be accurate, each molecule node needs to have an extra + boolean property called `in_stock`. + + :param tree_dict: the route to analyze + :param mol_costs: the starting material cost, keyed by the `in_stock` flag + :param average_yield: the average yield, defaults to 0.8 + :param reaction_cost: the reaction cost, defaults to 1.0 + :return: the computed cost + """ + mol_cost = mol_costs or {True: 1, False: 10} + + reactions = tree_dict.get("children", []) + if not reactions: + return mol_cost[tree_dict.get("in_stock", True)] + + params = (mol_cost, average_yield, reaction_cost) # propagate to sub-routes + children = reactions[0]["children"] + child_sum = sum(1 / average_yield * route_score(child, *params) for child in children) + return reaction_cost + child_sum + + +def route_scorer(routes: List[StrDict]) -> Tuple[List[StrDict], List[float]]: + """ + Score and sort a list of routes. + Returns a tuple of the sorted routes and their costs. + + :param routes: the routes to score + :return: the sorted routes and their costs + """ + scores = np.asarray([route_score(route) for route in routes]) + sorted_idx = np.argsort(scores) + routes = [routes[idx] for idx in sorted_idx] + return routes, scores[sorted_idx].tolist() + + +def route_ranks(scores: List[float]) -> List[int]: + """ + Compute the rank of route scores. 
Rank starts at 1 + + :param scores: the route scores, sorted (only neighbours are compared) + :return: a list of ranks for each route + """ + ranks = [1] if scores else [] # no scores, no ranks + for idx in range(1, len(scores)): + if abs(scores[idx] - scores[idx - 1]) < 1e-8: + ranks.append(ranks[idx - 1]) + else: + ranks.append(ranks[idx - 1] + 1) + return ranks diff --git a/tests/conftest.py b/tests/conftest.py index c8231cd..8d5dc78 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,8 @@ def wrapper(filename, index=0): trees = json.load(fileobj) if isinstance(trees, dict): return trees + elif index == -1: + return trees else: return trees[index] diff --git a/tests/test_route_utils.py b/tests/test_route_utils.py new file mode 100644 index 0000000..79c537b --- /dev/null +++ b/tests/test_route_utils.py @@ -0,0 +1,96 @@ +from route_distances.utils.routes import ( + calc_depth, + calc_llr, + extract_leaves, + is_solved, + route_score, + route_scorer, + route_ranks, +) + + +def remove_in_stock(tree_dict): + if "in_stock" in tree_dict: + del tree_dict["in_stock"] + for child in tree_dict.get("children", []): + remove_in_stock(child) + + +def test_route_depth(load_reaction_tree): + routes = load_reaction_tree("example_routes.json", index=-1) + + assert calc_depth(routes[0]) == 2 + assert calc_depth(routes[1]) == 4 + + assert calc_llr(routes[0]) == 1 + assert calc_llr(routes[1]) == 2 + + +def test_route_leaves(load_reaction_tree): + route = load_reaction_tree("example_routes.json", index=0) + + assert extract_leaves(route) == { + "Cc1ccc2nc3ccccc3c(Cl)c2c1", + "Nc1ccc(NC(=S)Nc2ccccc2)cc1", + } + + +def test_route_solved(load_reaction_tree): + route = load_reaction_tree("example_routes.json", index=0) + + assert is_solved(route) + + +def test_route_not_solved(load_reaction_tree): + route = load_reaction_tree("example_routes.json", index=0) + route["children"][0]["children"][0]["in_stock"] = False + + assert not is_solved(route) + + +def test_route_solved_unspec(load_reaction_tree): + route = 
load_reaction_tree("example_routes.json", index=0) + remove_in_stock(route) + + assert is_solved(route) + + +def test_route_score(load_reaction_tree): + routes = load_reaction_tree("example_routes.json", index=-1) + + assert route_score(routes[0]) == 3.5 + assert route_score(routes[1]) == 6.625 + + +def test_route_score_unsolved(load_reaction_tree): + route = load_reaction_tree("example_routes.json", index=0) + route["children"][0]["children"][0]["in_stock"] = False + + assert route_score(route) == 14.75 + + +def test_route_score_unspec(load_reaction_tree): + route = load_reaction_tree("example_routes.json", index=0) + remove_in_stock(route) + + assert route_score(route) == 3.5 + + +def test_route_scorer(load_reaction_tree): + routes = load_reaction_tree("example_routes.json", index=-1) + routes2 = [routes[2], routes[0], routes[1]] + + sorted_routes, route_scores = route_scorer(routes2) + + assert route_scores == [3.5, 6.625, 6.625] + assert sorted_routes[0] == routes[0] + assert sorted_routes[1] == routes[2] + assert sorted_routes[2] == routes[1] + + +def test_route_rank(): + + assert route_ranks([4.0, 5.0, 5.0]) == [1, 2, 2] + assert route_ranks([4.0, 4.0, 5.0]) == [1, 1, 2] + assert route_ranks([4.0, 5.0, 6.0]) == [1, 2, 3] + assert route_ranks([4.0, 5.0, 5.0, 6.0]) == [1, 2, 2, 3] diff --git a/tests/test_ted.py b/tests/test_ted.py index 50f062d..8fc1797 100644 --- a/tests/test_ted.py +++ b/tests/test_ted.py @@ -5,6 +5,7 @@ from route_distances.ted.utils import ( AptedConfig, TreeContent, + StandardFingerprintFactory, ) from route_distances.ted.reactiontree import ReactionTreeWrapper from route_distances.ted.distances import distance_matrix @@ -21,6 +22,26 @@ def collect_smiles(tree, query_type, smiles_list): node2 = {"type": "mol", "fingerprint": [1, 1, 0]} +example_tree = { + "type": "mol", + "smiles": "Cc1ccc2nc3ccccc3c(Nc3ccc(NC(=S)Nc4ccccc4)cc3)c2c1", + "children": [ + { + "type": "reaction", + "children": [ + { + "type": "mol", + "smiles": 
"Cc1ccc2nc3ccccc3c(Cl)c2c1", + }, + { + "type": "mol", + "smiles": "Nc1ccc(NC(=S)Nc2ccccc2)cc1", + }, + ], + } + ], +} + def test_rename_cost_different_types(): config = AptedConfig() @@ -290,3 +311,31 @@ def test_distance_matrix_timeout(load_reaction_tree): with pytest.raises(ValueError): distance_matrix(reaction_trees, content="molecules", timeout=0) + + +def test_fingerprint_calculations(): + wrapper = ReactionTreeWrapper( + example_tree, content="both", fp_factory=StandardFingerprintFactory(nbits=128) + ) + + fp = wrapper.first_tree["sort_key"] + mol1 = "1000010000000000000010001000100101000101100000010000010000100001" + mol2 = "1100000001110110011000100010000001001000000100100000110000100100" + assert fp == mol1 + mol2 + + fp = wrapper.first_tree["children"][0]["sort_key"] + rxn1 = "00000-1000000-1000-100-2000000000000000000000000000-10-20000000000000-1" + rxn2 = "-10000000001000100-10000-100-10000000000-10-1000000000000-11000-10000100" + assert fp == rxn1 + rxn2 + + +def test_custom_fingerprint_calculations(): + def factory(tree, parent): + if tree["type"] != "reaction": + return + tree["fingerprint"] = [1, 2, 3, 4] + + wrapper = ReactionTreeWrapper(example_tree, content="both", fp_factory=factory) + + assert wrapper.first_tree["sort_key"] == "" + assert wrapper.first_tree["children"][0]["sort_key"] == "1234"