
Final merge
FNTwin committed Jun 18, 2024
2 parents ea08244 + 85d0435 commit 1885a0b
Showing 43 changed files with 649 additions and 580 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
@@ -53,5 +53,5 @@ jobs:
       - name: Run tests
         run: python -m pytest

-      - name: Test building the doc
-        run: mkdocs build
+      #- name: Test building the doc
+      #  run: mkdocs build
4 changes: 2 additions & 2 deletions mkdocs.yml
@@ -12,8 +12,8 @@ docs_dir: "docs"
 nav:
   - Overview: index.md
   - Available Datasets: datasets.md
-  - Tutorials:
-      - Really hard example: tutorials/usage.ipynb
+  #- Tutorials:
+  #  #- Really hard example: tutorials/usage.ipynb
   - API:
     - Datasets: API/available_datasets.md
     - Isolated Atoms Energies: API/isolated_atom_energies.md
3 changes: 2 additions & 1 deletion openqdc/__init__.py
@@ -19,6 +19,7 @@ def get_project_root():
     "ANI1CCX": "openqdc.datasets.potential.ani",
     "ANI1CCX_V2": "openqdc.datasets.potential.ani",
     "ANI1X": "openqdc.datasets.potential.ani",
+    "ANI2X": "openqdc.datasets.potential.ani",
     "Spice": "openqdc.datasets.potential.spice",
     "SpiceV2": "openqdc.datasets.potential.spice",
     "SpiceVL2": "openqdc.datasets.potential.spice",
@@ -100,7 +101,7 @@ def __dir__():
     from .datasets.interaction.metcalf import Metcalf
     from .datasets.interaction.splinter import Splinter
     from .datasets.interaction.x40 import X40
-    from .datasets.potential.ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X
+    from .datasets.potential.ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X, ANI2X
     from .datasets.potential.comp6 import COMP6
     from .datasets.potential.dummy import Dummy
     from .datasets.potential.gdml import GDML
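The mapping above associates each dataset class name with the module that defines it. A minimal sketch of how such a table can back lazy imports through PEP 562's module-level `__getattr__` follows; only the mapping entries come from the diff, the `__getattr__` wiring is an assumption about how openqdc uses it.

```python
# Hypothetical sketch: a name-to-module table driving lazy imports (PEP 562).
import importlib

_lazy_imports = {
    "ANI2X": "openqdc.datasets.potential.ani",
    "Spice": "openqdc.datasets.potential.spice",
}

def __getattr__(name):
    # Import the defining module only on first attribute access.
    if name in _lazy_imports:
        return getattr(importlib.import_module(_lazy_imports[name]), name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```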
70 changes: 52 additions & 18 deletions openqdc/cli.py
@@ -7,14 +7,24 @@
 from typing_extensions import Annotated

 from openqdc.datasets import COMMON_MAP_POTENTIALS  # noqa
-from openqdc.datasets import AVAILABLE_DATASETS, AVAILABLE_POTENTIAL_DATASETS
-from openqdc.raws.config_factory import DataConfigFactory, DataDownloader
+from openqdc.datasets import (
+    AVAILABLE_DATASETS,
+    AVAILABLE_INTERACTION_DATASETS,
+    AVAILABLE_POTENTIAL_DATASETS,
+)

 app = typer.Typer(help="OpenQDC CLI")


+def sanitize(dictionary):
+    return {k.lower().replace("_", "").replace("-", ""): v for k, v in dictionary.items()}
+
+
+SANITIZED_AVAILABLE_DATASETS = sanitize(AVAILABLE_DATASETS)
+
+
 def exist_dataset(dataset):
-    if dataset not in AVAILABLE_DATASETS:
+    if dataset not in sanitize(AVAILABLE_DATASETS):
         logger.error(f"{dataset} is not available. Please open an issue on Github for the team to look into it.")
         return False
     return True
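The new sanitize helper makes dataset lookups case- and separator-insensitive. A quick self-contained illustration of the normalization, using a toy dictionary rather than the real registry:

```python
# sanitize() folds case and strips "_" / "-" so user-typed names match registry keys.
def sanitize(dictionary):
    return {k.lower().replace("_", "").replace("-", ""): v for k, v in dictionary.items()}

print(sanitize({"ANI1CCX_V2": 1, "Spice": 2}))
# -> {'ani1ccxv2': 1, 'spice': 2}
```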
@@ -54,10 +64,10 @@ def download(
     """
     for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)):
         if exist_dataset(dataset):
-            if AVAILABLE_DATASETS[dataset].no_init().is_cached() and not overwrite:
+            if SANITIZED_AVAILABLE_DATASETS[dataset].no_init().is_cached() and not overwrite:
                 logger.info(f"{dataset} is already cached. Skipping download")
             else:
-                AVAILABLE_DATASETS[dataset](overwrite_local_cache=True, cache_dir=cache_dir)
+                SANITIZED_AVAILABLE_DATASETS[dataset](overwrite_local_cache=True, cache_dir=cache_dir)


 @app.command()
@@ -83,22 +93,48 @@ def datasets():


 @app.command()
-def fetch(datasets: List[str]):
+def fetch(
+    datasets: List[str],
+    overwrite: Annotated[
+        bool,
+        typer.Option(
+            help="Whether to overwrite or force the re-download of the files.",
+        ),
+    ] = False,
+    cache_dir: Annotated[
+        Optional[str],
+        typer.Option(
+            help="Path to the cache. If not provided, the default cache directory (.cache/openqdc/) will be used.",
+        ),
+    ] = None,
+):
     """
     Download the raw datasets files from the main openQDC hub.
-    Special case: if the dataset is "all", all available datasets will be downloaded.
+    overwrite: bool = False,
+        If True, the files will be re-downloaded and overwritten.
+    cache_dir: Optional[str] = None,
+        Path to the cache. If not provided, the default cache directory will be used.
+    Special case: if the dataset is "all", "potential" or "interaction".
+        all: all available datasets will be downloaded.
+        potential: all the potential datasets will be downloaded.
+        interaction: all the interaction datasets will be downloaded.
     Example:
         openqdc fetch Spice
     """
-    if datasets[0] == "all":
-        dataset_names = DataConfigFactory.available_datasets
+    if datasets[0].lower() == "all":
+        dataset_names = list(sanitize(AVAILABLE_DATASETS).keys())
+    elif datasets[0].lower() == "potential":
+        dataset_names = list(sanitize(AVAILABLE_POTENTIAL_DATASETS).keys())
+    elif datasets[0].lower() == "interaction":
+        dataset_names = list(sanitize(AVAILABLE_INTERACTION_DATASETS).keys())
     else:
         dataset_names = datasets

-    for dataset_name in dataset_names:
-        dd = DataDownloader()
-        dd.from_name(dataset_name)
+    for dataset in list(map(lambda x: x.lower().replace("_", ""), dataset_names)):
+        if exist_dataset(dataset):
+            try:
+                SANITIZED_AVAILABLE_DATASETS[dataset].fetch(cache_dir, overwrite)
+            except Exception as e:
+                logger.error(f"Something unexpected happened while fetching {dataset}: {repr(e)}")

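Since typer turns the boolean overwrite parameter into a --overwrite/--no-overwrite flag, the reworked command can be exercised end to end with typer's test runner. A sketch follows; the dataset name is illustrative, and invoking it would attempt a real download.

```python
# Drive the new fetch command programmatically through typer's test runner.
from typer.testing import CliRunner

from openqdc.cli import app

runner = CliRunner()
result = runner.invoke(app, ["fetch", "spice", "--overwrite"])
print(result.exit_code)  # 0 if the fetch succeeded
```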

@app.command()
Expand All @@ -122,14 +158,12 @@ def preprocess(
"""
for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)):
if exist_dataset(dataset):
logger.info(f"Preprocessing {AVAILABLE_DATASETS[dataset].__name__}")
logger.info(f"Preprocessing {SANITIZED_AVAILABLE_DATASETS[dataset].__name__}")
try:
AVAILABLE_DATASETS[dataset].no_init().preprocess(upload=upload, overwrite=overwrite)
SANITIZED_AVAILABLE_DATASETS[dataset].no_init().preprocess(upload=upload, overwrite=overwrite)
except Exception as e:
logger.error(f"Error while preprocessing {dataset}. {e}. Did you fetch the dataset first?")
raise e
else:
logger.warning(f"{dataset} not found.")


if __name__ == "__main__":
Expand Down
29 changes: 25 additions & 4 deletions openqdc/datasets/base.py
@@ -82,13 +82,14 @@ class BaseDataset(DatasetPropertyMixIn):
     __distance_unit__ = "ang"
     __forces_unit__ = "hartree/ang"
     __average_nb_atoms__ = None
+    __links__ = {}

     def __init__(
         self,
         energy_unit: Optional[str] = None,
         distance_unit: Optional[str] = None,
         array_format: str = "numpy",
-        energy_type: str = "formation",
+        energy_type: Optional[str] = "formation",
         overwrite_local_cache: bool = False,
         cache_dir: Optional[str] = None,
         recompute_statistics: bool = False,
@@ -111,7 +112,7 @@ def __init__(
         Format to return arrays in. Supported formats: ["numpy", "torch", "jax"]
     energy_type
         Type of isolated atom energy to use for the dataset. Default: "formation"
-        Supported types: ["formation", "regression", "null"]
+        Supported types: ["formation", "regression", "null", None]
     overwrite_local_cache
         Whether to overwrite the locally cached dataset.
     cache_dir
@@ -128,10 +129,11 @@
         set_cache_dir(cache_dir)
         # self._init_lambda_fn()
         self.data = None
+        self._original_unit = self.__energy_unit__
         self.recompute_statistics = recompute_statistics
         self.regressor_kwargs = regressor_kwargs
         self.transform = transform
-        self.energy_type = energy_type
+        self.energy_type = energy_type if energy_type is not None else "null"
         self.refit_e0s = recompute_statistics or overwrite_local_cache
         if not self.is_preprocessed():
             raise DatasetNotAvailableError(self.__name__)
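Because energy_type now accepts None and coerces it to "null", callers no longer need to know the sentinel string. A tiny self-contained sketch of the same fallback pattern:

```python
from typing import Optional

def resolve_energy_type(energy_type: Optional[str] = "formation") -> str:
    # None is coerced to the "null" sentinel, mirroring BaseDataset.__init__.
    return energy_type if energy_type is not None else "null"

assert resolve_energy_type(None) == "null"
assert resolve_energy_type() == "formation"
```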
@@ -145,6 +147,17 @@ def _init_lambda_fn(self):
         self._fn_distance = lambda x: x
         self._fn_forces = lambda x: x

+    @property
+    def config(self):
+        assert len(self.__links__) > 0, "No links provided for fetching"
+        return dict(dataset_name=self.__name__, links=self.__links__)
+
+    @classmethod
+    def fetch(cls, cache_path: Optional[str] = None, overwrite: bool = False) -> None:
+        from openqdc.utils.download_api import DataDownloader
+
+        DataDownloader(cache_path, overwrite).from_config(cls.no_init().config)
+
     def _post_init(
         self,
         overwrite_local_cache: bool = False,
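The new fetch() classmethod pairs with the __links__ class attribute: config bundles the dataset name and links, and DataDownloader consumes that config via no_init(), so no preprocessed cache or __init__ call is required. A hypothetical subclass with a made-up file name and URL:

```python
# Illustrative only: a toy subclass wiring __links__ into the new fetch() API.
from openqdc.datasets.base import BaseDataset

class ToyDataset(BaseDataset):
    __links__ = {"toy.hdf5": "https://example.com/toy.hdf5"}

# Downloads the raw files listed in __links__; would hit the network for real.
ToyDataset.fetch(cache_path=None, overwrite=False)
```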
@@ -256,6 +269,10 @@ def pkl_data_keys(self):
     def pkl_data_types(self):
         return {"name": str, "subset": str, "n_atoms": np.int32}

+    @property
+    def atom_energies(self):
+        return self._e0s_dispatcher
+
     @property
     def data_types(self):
         return {
@@ -287,7 +304,11 @@ def _set_units(self, en, ds):
     def _set_isolated_atom_energies(self):
         if self.__energy_methods__ is None:
             logger.error("No energy methods defined for this dataset.")
-        f = get_conversion("hartree", self.__energy_unit__)
+        if self.energy_type == "formation":
+            f = get_conversion("hartree", self.__energy_unit__)
+        else:
+            # regression e0s are computed in the dataset's original unit
+            f = get_conversion(self._original_unit, self.__energy_unit__)
         self.__isolated_atom_energies__ = f(self.e0s_dispatcher.e0s_matrix)

     def convert_energy(self, x):
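get_conversion returns a callable that maps values between units, so formation e0s keep converting from hartree while regression e0s now convert from the dataset's original unit. A sketch with made-up numbers; the import path and the "ev" unit name are assumptions:

```python
import numpy as np

from openqdc.utils.units import get_conversion  # assumed import path

e0s_matrix = np.array([[-0.5, -37.8]])  # toy isolated-atom energies in hartree
to_ev = get_conversion("hartree", "ev")
print(to_ev(e0s_matrix))  # approx. [[-13.61, -1028.59]]
```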
(Diff truncated: the remaining changed files in this commit are not shown.)
