-
Notifications
You must be signed in to change notification settings - Fork 136
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
296 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Plugins | ||
|
||
This folder contains some features for `aizynthfinder` that | ||
does not yet fit into the main codebase. It could be experimental | ||
features, or features that require the user to install some | ||
additional third-party dependencies. | ||
|
||
For the expansion models, you generally need to add the `plugins` folder to the `PYTHONPATH`, e.g. | ||
|
||
export PYTHONPATH=~/aizynthfinder/plugins/ | ||
|
||
where the `aizynthfinder` repository is in the home folder | ||
|
||
## Chemformer expansion model | ||
|
||
An expansion model using a REST API for the Chemformer model | ||
is supplied in the `expansion_strategies` module. | ||
|
||
To use it, you first need to install the `chemformer` package | ||
and launch the REST API service that comes with it. | ||
|
||
To use the expansion model in `aizynthfinder` you can use a config-file | ||
containing these lines | ||
|
||
expansion: | ||
chemformer: | ||
type: expansion_strategies.ChemformerBasedExpansionStrategy | ||
url: http://localhost:8000/chemformer-api/predict | ||
search: | ||
algorithm_config: | ||
immediate_instantiation: [chemformer] | ||
time_limit: 300 | ||
|
||
The `time_limit` is a recommandation for allowing the more expensive expansion model | ||
to finish a sufficient number of retrosynthesis iterations. | ||
|
||
You would have to change `localhost:8000` to the name and port of the machine hosting the REST service. | ||
|
||
You can then use the config-file with either `aizynthcli` or the Jupyter notebook interface. | ||
|
||
## ModelZoo expansion model | ||
|
||
An expansion model using the ModelZoo feature is supplied in the `expansion_strategies` | ||
module. This is an adoption of the code from this repo: `https://github.com/AlanHassen/modelsmatter` that were used in the publications [Models Matter: The Impact of Single-Step Models on Synthesis Prediction](https://arxiv.org/abs/2308.05522) and [Mind the Retrosynthesis Gap: Bridging the divide between Single-step and Multi-step Retrosynthesis Prediction](https://openreview.net/forum?id=LjdtY0hM7tf). | ||
|
||
To use it, you first need to install the `modelsmatter_modelzoo` package from | ||
https://github.com/PTorrenPeraire/modelsmatter_modelzoo and set up the `ssbenchmark` | ||
environment. | ||
|
||
Ensure that the `external_models` sub-package contains the models required. | ||
If it does not, you will need to manually clone the required model repositories | ||
within `external_models`. | ||
|
||
To use the expansion model in `aizynthfinder`, you can specify it in the config-file | ||
under `expansion`. Here is an example setting to use the expansion model with `chemformer` | ||
as the external model: | ||
|
||
expansion: | ||
chemformer: | ||
type: expansion_strategies.ModelZooExpansionStrategy: | ||
module_path: /path_to_folder_containing_cloned_repository/modelsmatter_modelzoo/external_models/modelsmatter_chemformer_hpc/ | ||
use_gpu: False | ||
params: | ||
module_path: /path_to_model_file/chemformer_backward.ckpt | ||
vocab_path: /path_to_vocab_file/bart_vocab_downstream.txt | ||
search: | ||
algorithm_config: | ||
immediate_instantiation: [chemformer] | ||
time_limit: 300 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
""" Module containing classes that implements different expansion policy strategies | ||
""" | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
import numpy as np | ||
import requests | ||
from requests.exceptions import ConnectionError # pylint: disable=redefined-builtin | ||
|
||
from aizynthfinder.chem import SmilesBasedRetroReaction | ||
from aizynthfinder.context.policy import ExpansionStrategy | ||
from aizynthfinder.utils.math import softmax | ||
|
||
if TYPE_CHECKING: | ||
from aizynthfinder.chem import TreeMolecule | ||
from aizynthfinder.chem.reaction import RetroReaction | ||
from aizynthfinder.context.config import Configuration | ||
from aizynthfinder.utils.type_utils import Dict, List, Sequence, Tuple | ||
|
||
try: | ||
from ssbenchmark.model_zoo import ModelZoo | ||
except ImportError: | ||
HAS_MODELZOO = False | ||
else: | ||
HAS_MODELZOO = True | ||
|
||
|
||
class ChemformerBasedExpansionStrategy(ExpansionStrategy): | ||
""" | ||
A template-free expansion strategy that will return `SmilesBasedRetroReaction` objects upon expansion. | ||
It is based on calls to a REST API to the Chemformer model | ||
:param key: the key or label | ||
:param config: the configuration of the tree search | ||
:param url: the URL to the REST API | ||
:param ntrials: how many time to try a REST request | ||
""" | ||
|
||
_required_kwargs = ["url"] | ||
|
||
def __init__(self, key: str, config: Configuration, **kwargs: str) -> None: | ||
super().__init__(key, config, **kwargs) | ||
|
||
self._url: str = kwargs["url"] | ||
self._ntrials = kwargs.get("ntrials", 3) | ||
self._cache: Dict[str, Tuple[Sequence[str], Sequence[float]]] = {} | ||
|
||
# pylint: disable=R0914 | ||
def get_actions( | ||
self, | ||
molecules: Sequence[TreeMolecule], | ||
cache_molecules: Sequence[TreeMolecule] = None, | ||
) -> Tuple[List[RetroReaction], List[float]]: | ||
""" | ||
Get all the probable actions of a set of molecules, using the selected policies and given cutoffs | ||
:param molecules: the molecules to consider | ||
:param cache_molecules: additional molecules that are sent to | ||
the expansion model but for which predictions are not returned | ||
:return: the actions and the priors of those actions | ||
""" | ||
possible_actions = [] | ||
priors = [] | ||
|
||
cache_molecules = cache_molecules or [] | ||
self._update_cache(molecules + cache_molecules) | ||
|
||
for mol in molecules: | ||
try: | ||
output_smiles, probs = self._cache[mol.inchi_key] | ||
except KeyError: | ||
continue | ||
|
||
priors.extend(probs) | ||
for idx, reactants_str in enumerate(output_smiles): | ||
metadata = {} | ||
metadata["policy_probability"] = float(probs[idx]) | ||
metadata["policy_probability_rank"] = idx | ||
metadata["policy_name"] = self.key | ||
possible_actions.append( | ||
SmilesBasedRetroReaction( | ||
mol, metadata=metadata, reactants_str=reactants_str | ||
) | ||
) | ||
return possible_actions, priors # type: ignore | ||
|
||
def reset_cache(self) -> None: | ||
"""Reset the prediction cache""" | ||
self._cache = {} | ||
|
||
def _update_cache(self, molecules: Sequence[TreeMolecule]) -> None: | ||
pred_inchis = [] | ||
smiles_list = [] | ||
for molecule in molecules: | ||
if molecule.inchi_key in self._cache or molecule.inchi_key in pred_inchis: | ||
continue | ||
smiles_list.append(molecule.smiles) | ||
pred_inchis.append(molecule.inchi_key) | ||
|
||
if not pred_inchis: | ||
return | ||
|
||
for _ in range(self._ntrials): | ||
try: | ||
ret = requests.post(self._url, json=smiles_list) | ||
except ConnectionError: | ||
continue | ||
if ret.status_code == requests.codes.ok: | ||
break | ||
|
||
if ret.status_code != requests.codes.ok: | ||
self._logger.debug( | ||
f"Failed to retrieve results from Chemformer model: {ret.content}" | ||
) | ||
return | ||
|
||
predictions = ret.json() | ||
for prediction, inchi in zip(predictions, pred_inchis): | ||
self._cache[inchi] = (prediction["output"], softmax(prediction["lhs"])) | ||
|
||
|
||
class ModelZooExpansionStrategy(ExpansionStrategy): | ||
""" | ||
An expansion strategy that uses a single step model to operate on a Smiles-level | ||
of abstraction | ||
:param key: the key or label of the single step model | ||
:param config: the configuration of the tree search | ||
:param module_path: the path to the external model | ||
:raises ImportError: if ssbenchmark has not been installed. | ||
""" | ||
|
||
_required_kwargs = ["module_path"] | ||
|
||
def __init__(self, key: str, config: Configuration, **kwargs: str) -> None: | ||
if not HAS_MODELZOO: | ||
raise ImportError( | ||
"Cannot use this expansion strategy as it seems like " | ||
"ssbenchmark is not installed." | ||
) | ||
|
||
super().__init__(key, config, **kwargs) | ||
module_path = kwargs["module_path"] | ||
gpu_mode = kwargs.pop("use_gpu", False) | ||
model_params = dict(kwargs.pop("params", {})) | ||
|
||
self.model_zoo = ModelZoo(key, module_path, gpu_mode, model_params) | ||
self._cache: Dict[str, Tuple[Sequence[str], Sequence[float]]] = {} | ||
|
||
def get_actions( | ||
self, | ||
molecules: Sequence[TreeMolecule], | ||
cache_molecules: Sequence[TreeMolecule] = None, | ||
) -> Tuple[List[RetroReaction], List[float]]: | ||
""" | ||
Get all the probable actions of a set of molecules, using the selected policies | ||
and given cutoffs. | ||
:param molecules: the molecules to consider | ||
:param cache_molecules: additional molecules that are sent to the expansion | ||
model but for which predictions are not returned | ||
:return: the actions and the priors of those actions | ||
""" | ||
possible_actions = [] | ||
priors = [] | ||
cache_molecules = cache_molecules or [] | ||
self._update_cache(molecules + cache_molecules) | ||
|
||
for mol in molecules: | ||
output_smiles, probs = self._cache[mol.inchi_key] | ||
priors.extend(probs) | ||
|
||
for idx, move in enumerate(output_smiles): | ||
metadata = {} | ||
metadata["reaction"] = move | ||
metadata["policy_probability"] = float(probs[idx].round(4)) | ||
metadata["policy_probability_rank"] = idx | ||
metadata["policy_name"] = self.key | ||
|
||
possible_actions.append( | ||
SmilesBasedRetroReaction(mol, reactants_str=move, metadata=metadata) | ||
) | ||
|
||
return possible_actions, priors | ||
|
||
def _update_cache(self, molecules: Sequence[TreeMolecule]) -> None: | ||
pred_inchis = [] | ||
smiles_list = [] | ||
|
||
for molecule in molecules: | ||
if molecule.inchi_key in self._cache or molecule.inchi_key in pred_inchis: | ||
continue | ||
smiles_list.append(molecule.smiles) | ||
pred_inchis.append(molecule.inchi_key) | ||
|
||
if not pred_inchis: | ||
return | ||
|
||
pred_reactants, pred_priors = ( | ||
np.array(item) for item in self.model_zoo.model_call(smiles_list) | ||
) | ||
|
||
for reactants, priors, inchi in zip(pred_reactants, pred_priors, pred_inchis): | ||
self._cache[inchi] = (reactants, priors) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,4 +7,4 @@ dependencies: | |
- git | ||
- pip>=20.0 | ||
- pip: | ||
- git+ssh://[email protected]:7999/com/aizynthfinder.git | ||
- https://github.com/MolecularAI/aizynthfinder/archive/v3.7.0.tar.gz |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters