Skip to content

Commit

Permalink
Merge pull request #84 from OpenDrugDiscovery/downloader
Browse files Browse the repository at this point in the history
Change in download api
  • Loading branch information
prtos authored Jun 8, 2024
2 parents e85839a + 2669d21 commit 318e9c5
Show file tree
Hide file tree
Showing 36 changed files with 539 additions and 454 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ jobs:
- name: Run tests
run: python -m pytest

- name: Test building the doc
run: mkdocs build
#- name: Test building the doc
# run: mkdocs build
4 changes: 2 additions & 2 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ docs_dir: "docs"
nav:
- Overview: index.md
- Available Datasets: datasets.md
- Tutorials:
- Really hard example: tutorials/usage.ipynb
#- Tutorials:
# #- Really hard example: tutorials/usage.ipynb
- API:
- Datasets: API/available_datasets.md
- Isolated Atoms Energies: API/isolated_atom_energies.md
Expand Down
52 changes: 40 additions & 12 deletions openqdc/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
from typing_extensions import Annotated

from openqdc.datasets import COMMON_MAP_POTENTIALS # noqa
from openqdc.datasets import AVAILABLE_DATASETS, AVAILABLE_POTENTIAL_DATASETS
from openqdc.raws.config_factory import DataConfigFactory, DataDownloader
from openqdc.datasets import (
AVAILABLE_DATASETS,
AVAILABLE_INTERACTION_DATASETS,
AVAILABLE_POTENTIAL_DATASETS,
)

app = typer.Typer(help="OpenQDC CLI")

Expand Down Expand Up @@ -83,22 +86,49 @@ def datasets():


@app.command()
def fetch(datasets: List[str]):
def fetch(
datasets: List[str],
overwrite: Annotated[
bool,
typer.Option(
help="Whether to overwrite or force the re-download of the files.",
),
] = False,
cache_dir: Annotated[
Optional[str],
typer.Option(
help="Path to the cache. If not provided, the default cache directory (.cache/openqdc/) will be used.",
),
] = None,
):
"""
Download the raw datasets files from the main openQDC hub.
Special case: if the dataset is "all", all available datasets will be downloaded.
overwrite: bool = False,
If True, the files will be re-downloaded and overwritten.
cache_dir: Optional[str] = None,
Path to the cache. If not provided, the default cache directory will be used.
Special case: if the dataset is "all", "potential", "interaction".
all: all available datasets will be downloaded.
potential: all the potential datasets will be downloaded
interaction: all the interaction datasets will be downloaded
Example:
openqdc fetch Spice
"""
if datasets[0] == "all":
dataset_names = DataConfigFactory.available_datasets
if datasets[0].lower() == "all":
dataset_names = AVAILABLE_DATASETS
elif datasets[0].lower() == "potential":
dataset_names = AVAILABLE_POTENTIAL_DATASETS
elif datasets[0].lower() == "interaction":
dataset_names = AVAILABLE_INTERACTION_DATASETS
else:
dataset_names = datasets

for dataset_name in dataset_names:
dd = DataDownloader()
dd.from_name(dataset_name)
for dataset in list(map(lambda x: x.lower().replace("_", ""), dataset_names)):
if exist_dataset(dataset):
try:
AVAILABLE_DATASETS[dataset].fetch(cache_dir, overwrite)
except Exception as e:
logger.error(f"Something unexpected happended while fetching {dataset}: {repr(e)}")


@app.command()
Expand Down Expand Up @@ -128,8 +158,6 @@ def preprocess(
except Exception as e:
logger.error(f"Error while preprocessing {dataset}. {e}. Did you fetch the dataset first?")
raise e
else:
logger.warning(f"{dataset} not found.")


if __name__ == "__main__":
Expand Down
23 changes: 22 additions & 1 deletion openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class BaseDataset(DatasetPropertyMixIn):
__distance_unit__ = "ang"
__forces_unit__ = "hartree/ang"
__average_nb_atoms__ = None
__links__ = {}

def __init__(
self,
Expand Down Expand Up @@ -128,6 +129,7 @@ def __init__(
set_cache_dir(cache_dir)
# self._init_lambda_fn()
self.data = None
self._original_unit = self.__energy_unit__
self.recompute_statistics = recompute_statistics
self.regressor_kwargs = regressor_kwargs
self.transform = transform
Expand All @@ -145,6 +147,17 @@ def _init_lambda_fn(self):
self._fn_distance = lambda x: x
self._fn_forces = lambda x: x

@property
def config(self):
    """
    Build the download configuration of the dataset.

    Returns a dictionary with the dataset name under ``dataset_name`` and
    the class-level ``__links__`` mapping under ``links``.
    Fails with an ``AssertionError`` when ``__links__`` is empty.
    """
    links = self.__links__
    assert len(links) > 0, "No links provided for fetching"
    return {"dataset_name": self.__name__, "links": links}

@classmethod
def fetch(cls, cache_path: Optional[str] = None, overwrite: bool = False) -> None:
    """
    Download the raw files of this dataset.

    Parameters
    ----------
    cache_path : Optional[str]
        Forwarded as first argument to ``DataDownloader``.
    overwrite : bool
        Forwarded as second argument to ``DataDownloader``.

    The download configuration is read from ``cls.no_init().config``.
    NOTE(review): ``no_init`` is defined outside this view; presumably it
    builds an instance without running the regular constructor - confirm.
    """
    # Imported lazily, exactly as before, to keep the module import light
    # and avoid a potential import cycle with the download utilities.
    from openqdc.utils.download_api import DataDownloader

    # The downloader is created before the configuration is evaluated so the
    # order of side effects stays the same as a single chained expression.
    downloader = DataDownloader(cache_path, overwrite)
    downloader.from_config(cls.no_init().config)

def _post_init(
self,
overwrite_local_cache: bool = False,
Expand Down Expand Up @@ -256,6 +269,10 @@ def pkl_data_keys(self):
def pkl_data_types(self):
return {"name": str, "subset": str, "n_atoms": np.int32}

@property
def atom_energies(self):
    """
    Return the object stored in ``self._e0s_dispatcher``.

    NOTE(review): other code in this class reads ``self.e0s_dispatcher``
    (no leading underscore); presumably that accessor populates the private
    attribute returned here - confirm it is always set before this property
    is used.
    """
    return self._e0s_dispatcher

@property
def data_types(self):
return {
Expand Down Expand Up @@ -287,7 +304,11 @@ def _set_units(self, en, ds):
def _set_isolated_atom_energies(self):
if self.__energy_methods__ is None:
logger.error("No energy methods defined for this dataset.")
f = get_conversion("hartree", self.__energy_unit__)
if self.energy_type == "formation":
f = get_conversion("hartree", self.__energy_unit__)
else:
# regression energies are computed in the original unit of the dataset
f = get_conversion(self._original_unit, self.__energy_unit__)
self.__isolated_atom_energies__ = f(self.e0s_dispatcher.e0s_matrix)

def convert_energy(self, x):
Expand Down
150 changes: 138 additions & 12 deletions openqdc/datasets/energies.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from os.path import join as p_join
from typing import Dict, Union

import numpy as np
from loguru import logger

from openqdc.methods.enums import PotentialMethod
from openqdc.utils.constants import ATOM_SYMBOLS, ATOMIC_NUMBERS
from openqdc.utils.io import load_pkl, save_pkl
from openqdc.utils.regressor import Regressor

Expand Down Expand Up @@ -41,6 +44,56 @@ def dispatch_factory(data, **kwargs) -> "IsolatedEnergyInterface":
return NullEnergy(data, **kwargs)


@dataclass(frozen=False, eq=True)
class AtomSpecies:
    """
    Pair of chemical species and charge.

    The species can be given either as a chemical symbol or as an atomic
    number; an atomic number is converted to its chemical symbol on
    construction and the atomic number is cached in ``number``.

    Instances hash like the plain tuple ``(symbol, charge)`` and compare
    equal to ``(symbol_or_number, charge)`` sequences, so a dictionary keyed
    by ``AtomSpecies`` can also be queried with such a tuple.
    """

    symbol: Union[str, int]
    charge: int = 0

    def __post_init__(self):
        # Normalise to the chemical symbol first, then look up its number.
        if not isinstance(self.symbol, str):
            self.symbol = ATOM_SYMBOLS[self.symbol]
        self.number = ATOMIC_NUMBERS[self.symbol]

    def __hash__(self):
        # Must stay identical to hash((symbol, charge)) so that tuple keys
        # land in the same dictionary bucket as the instance.
        return hash((self.symbol, self.charge))

    def __eq__(self, other):
        if not isinstance(other, AtomSpecies):
            try:
                other = AtomSpecies(symbol=other[0], charge=other[1])
            except (TypeError, IndexError, KeyError):
                # Not a (species, charge) pair of a known species: let Python
                # fall back to its default comparison instead of crashing.
                return NotImplemented
        return (self.number, self.charge) == (other.number, other.charge)


@dataclass
class AtomEnergy:
    """
    Data structure storing isolated atom energies together with the
    standard deviation associated with each value.

    When no standard deviation is given it defaults to 1 (value not
    calculated or not available, e.g. the formation energy case).
    Values that are not already numpy arrays are wrapped in a
    one-element ``float32`` array, for both ``mean`` and ``std``.
    """

    mean: np.ndarray
    std: np.ndarray = field(default_factory=lambda: np.array([1], dtype=np.float32))

    def __post_init__(self):
        if not isinstance(self.mean, np.ndarray):
            self.mean = np.array([self.mean], dtype=np.float32)
        # Keep std consistent with mean: a scalar std (e.g. the literal 1)
        # would otherwise stay a plain number and break array indexing.
        if not isinstance(self.std, np.ndarray):
            self.std = np.array([self.std], dtype=np.float32)

    def append(self, other: "AtomEnergy"):
        """
        Append the mean and std of another atom energy to this one (in place).
        """
        self.mean = np.append(self.mean, other.mean)
        self.std = np.append(self.std, other.std)


class AtomEnergies:
"""
Manager class for interface with the isolated atom energies classes
Expand Down Expand Up @@ -71,15 +124,49 @@ def e0s_matrix(self) -> np.ndarray:
"""
return self.factory.e0_matrix

@property
def e0s_dict(self) -> Dict[AtomSpecies, AtomEnergy]:
    """
    Return the isolated atom energies dictionary.

    This simply exposes ``self.factory.e0_dict``.
    """
    return self.factory.e0_dict

def __str__(self):
    """
    Return ``"Atoms: [...]"`` listing the distinct chemical symbols found in
    the keys of ``self.e0s_dict``.

    Symbols are sorted so the text is the same on every call and every run.
    """
    symbols = sorted({species.symbol for species in self.e0s_dict})
    return f"Atoms: {symbols}"

def __repr__(self):
    # Same text as __str__, so interactive display and print() agree.
    return str(self)

def __getitem__(self, item: AtomSpecies) -> AtomEnergy:
    """
    Retrieve an entry from the isolated atom energies dictionary.

    ``item`` may be an ``AtomSpecies``, a bare chemical symbol or atomic
    number (the charge then defaults to 0), or a
    ``(symbol_or_number, charge)`` sequence. A one-element sequence also
    defaults the charge to 0.

    Examples:
        AtomEnergies[6], AtomEnergies[6, 1], AtomEnergies["Cl"],
        AtomEnergies["C", 1], AtomEnergies[(6, 1)], AtomEnergies[("C", 1)]

    Raises:
        KeyError: if the resulting (symbol, charge) pair is not present.
    """
    if isinstance(item, AtomSpecies):
        atom, charge = item.symbol, item.charge
    elif isinstance(item, str):
        # A bare symbol is a sequence of characters: indexing it would turn
        # "Cl" into atom "C" with charge "l", so treat it as a whole.
        atom, charge = item, 0
    elif hasattr(item, "__len__"):
        # (atom, charge) or (atom,) style sequence.
        atom = item[0]
        charge = item[1] if len(item) > 1 else 0
    else:
        # Scalar atomic number (Python int or numpy integer scalar).
        atom, charge = item, 0
    if not isinstance(atom, str):
        atom = ATOM_SYMBOLS[atom]
    return self.e0s_dict[(atom, charge)]


class IsolatedEnergyInterface(ABC):
"""
Abstract class that defines the interface for the
different implementation of an isolated atom energy value
"""

_e0_matrixs = []

def __init__(self, data, **kwargs):
"""
Parameters
Expand All @@ -93,7 +180,8 @@ def __init__(self, data, **kwargs):
selected energy class. Mostly used for regression
to pass the regressor_kwargs.
"""

self._e0_matrixs = []
self._e0_dict = None
self.kwargs = kwargs
self.data = data
self._post_init()
Expand All @@ -120,27 +208,61 @@ def e0_matrix(self) -> np.ndarray:
"""
return np.array(self._e0_matrixs)

@property
def e0_dict(self) -> Dict:
    """
    Return the isolated atom energies dict.

    The dictionary is stored in ``self._e0s_dict`` by the concrete
    implementations. ``None`` is returned when it has not been assembled
    yet, instead of failing with an ``AttributeError``.
    """
    # The constructor initialises a differently spelled attribute
    # (``_e0_dict``), so ``_e0s_dict`` may legitimately be missing here.
    return getattr(self, "_e0s_dict", None)

def __str__(self) -> str:
return self.__class__.__name__.lower()


class NullEnergy(IsolatedEnergyInterface):
class PhysicalEnergy(IsolatedEnergyInterface):
"""
Class that returns a null (zeros) matrix for the isolated atom energies in case
of no energies are available.
Class that returns a physical (SE,DFT,etc) isolated atom energies.
"""

def _assembly_e0_dict(self):
datum = {}
for method in self.data.__energy_methods__:
for key, values in method.atom_energies_dict.items():
atm = AtomSpecies(*key)
ens = AtomEnergy(values)
if atm not in datum:
datum[atm] = ens
else:
datum[atm].append(ens)
self._e0s_dict = datum

def _post_init(self):
self._e0_matrixs = [PotentialMethod.NONE.atom_energies_matrix for _ in range(len(self.data.energy_methods))]
self._e0_matrixs = [energy_method.atom_energies_matrix for energy_method in self.data.__energy_methods__]
self._assembly_e0_dict()


class PhysicalEnergy(IsolatedEnergyInterface):
class NullEnergy(IsolatedEnergyInterface):
"""
Class that returns a physical (SE,DFT,etc) isolated atom energies.
Class that returns a null (zeros) matrix for the isolated atom energies in case
of no energies are available.
"""

def _assembly_e0_dict(self):
datum = {}
for _ in self.data.__energy_methods__:
for key, values in PotentialMethod.NONE.atom_energies_dict.items():
atm = AtomSpecies(*key)
ens = AtomEnergy(values)
if atm not in datum:
datum[atm] = ens
else:
datum[atm].append(ens)
self._e0s_dict = datum

def _post_init(self):
self._e0_matrixs = [energy_method.atom_energies_matrix for energy_method in self.data.__energy_methods__]
self._e0_matrixs = [PotentialMethod.NONE.atom_energies_matrix for _ in range(len(self.data.energy_methods))]
self._assembly_e0_dict()


class RegressionEnergy(IsolatedEnergyInterface):
Expand Down Expand Up @@ -175,7 +297,9 @@ def _set_lin_atom_species_dict(self, E0s, covs) -> None:
"""
atomic_energies_dict = {}
for i, z in enumerate(self.regressor.numbers):
atomic_energies_dict[z] = E0s[i]
for charge in range(-10, 11):
atomic_energies_dict[AtomSpecies(z, charge)] = AtomEnergy(E0s[i], 1 if covs is None else covs[i])
# atomic_energies_dict[z] = E0s[i]
self._e0s_dict = atomic_energies_dict
self.save_e0s()

Expand All @@ -187,7 +311,9 @@ def _set_linear_e0s(self) -> None:
new_e0s = [np.zeros((max(self.data.numbers) + 1, MAX_CHARGE_NUMBER)) for _ in range(len(self))]
for z, e0 in self._e0s_dict.items():
for i in range(len(self)):
new_e0s[i][z, :] = e0[i]
# new_e0s[i][z, :] = e0[i]
new_e0s[i][z.number, z.charge] = e0.mean[i]
# for atom_sp, values in
self._e0_matrixs = new_e0s

def save_e0s(self) -> None:
Expand Down
Loading

0 comments on commit 318e9c5

Please sign in to comment.