Change in download api #84

Merged · 13 commits · Jun 8, 2024
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
@@ -53,5 +53,5 @@ jobs:
      - name: Run tests
        run: python -m pytest

-     - name: Test building the doc
-       run: mkdocs build
+     #- name: Test building the doc
+     #  run: mkdocs build
4 changes: 2 additions & 2 deletions mkdocs.yml
@@ -12,8 +12,8 @@ docs_dir: "docs"
nav:
  - Overview: index.md
  - Available Datasets: datasets.md
-  - Tutorials:
-    - Really hard example: tutorials/usage.ipynb
+  #- Tutorials:
+  #  #- Really hard example: tutorials/usage.ipynb
  - API:
    - Datasets: API/available_datasets.md
    - Isolated Atoms Energies: API/isolated_atom_energies.md
52 changes: 40 additions & 12 deletions openqdc/cli.py
@@ -7,8 +7,11 @@
from typing_extensions import Annotated

from openqdc.datasets import COMMON_MAP_POTENTIALS  # noqa
-from openqdc.datasets import AVAILABLE_DATASETS, AVAILABLE_POTENTIAL_DATASETS
-from openqdc.raws.config_factory import DataConfigFactory, DataDownloader
+from openqdc.datasets import (
+    AVAILABLE_DATASETS,
+    AVAILABLE_INTERACTION_DATASETS,
+    AVAILABLE_POTENTIAL_DATASETS,
+)

app = typer.Typer(help="OpenQDC CLI")
@@ -83,22 +86,49 @@ def datasets():


@app.command()
-def fetch(datasets: List[str]):
+def fetch(
+    datasets: List[str],
+    overwrite: Annotated[
+        bool,
+        typer.Option(
+            help="Whether to overwrite or force the re-download of the files.",
+        ),
+    ] = False,
+    cache_dir: Annotated[
+        Optional[str],
+        typer.Option(
+            help="Path to the cache. If not provided, the default cache directory (.cache/openqdc/) will be used.",
+        ),
+    ] = None,
+):
"""
Download the raw datasets files from the main openQDC hub.
Special case: if the dataset is "all", all available datasets will be downloaded.

overwrite: bool = False,
If True, the files will be re-downloaded and overwritten.
cache_dir: Optional[str] = None,
Path to the cache. If not provided, the default cache directory will be used.
Special case: if the dataset is "all", "potential", "interaction".
all: all available datasets will be downloaded.
potential: all the potential datasets will be downloaded
interaction: all the interaction datasets will be downloaded
[Collaborator review comment] Can you add docstrings explaining what overwrite and cache_dir do? Also, can you add a usage command here and in the main README?
+    Example:
+        openqdc fetch Spice
    """
if datasets[0] == "all":
dataset_names = DataConfigFactory.available_datasets
if datasets[0].lower() == "all":
dataset_names = AVAILABLE_DATASETS
elif datasets[0].lower() == "potential":
dataset_names = AVAILABLE_POTENTIAL_DATASETS
elif datasets[0].lower() == "interaction":
dataset_names = AVAILABLE_INTERACTION_DATASETS
else:
dataset_names = datasets

-    for dataset_name in dataset_names:
-        dd = DataDownloader()
-        dd.from_name(dataset_name)
+    for dataset in list(map(lambda x: x.lower().replace("_", ""), dataset_names)):
+        if exist_dataset(dataset):
+            try:
+                AVAILABLE_DATASETS[dataset].fetch(cache_dir, overwrite)
+            except Exception as e:
+                logger.error(f"Something unexpected happened while fetching {dataset}: {repr(e)}")


@app.command()
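With overwrite and cache_dir exposed on the command line, the usage example the reviewer asked for generalizes beyond a single dataset. A hypothetical session (the --overwrite and --cache-dir spellings assume Typer's default option naming; the cache path is illustrative):

    # download one dataset (names are matched case- and underscore-insensitively)
    openqdc fetch Spice

    # force a re-download of every potential dataset into a custom cache
    openqdc fetch potential --overwrite --cache-dir /tmp/openqdc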
@@ -128,8 +158,6 @@ def preprocess(
        except Exception as e:
            logger.error(f"Error while preprocessing {dataset}. {e}. Did you fetch the dataset first?")
            raise e
-        else:
-            logger.warning(f"{dataset} not found.")


if __name__ == "__main__":
65 changes: 22 additions & 43 deletions openqdc/datasets/base.py
@@ -82,6 +82,7 @@ class BaseDataset(DatasetPropertyMixIn):
    __distance_unit__ = "ang"
    __forces_unit__ = "hartree/ang"
    __average_nb_atoms__ = None
+    __links__ = {}

    def __init__(
        self,
@@ -128,6 +129,7 @@ def __init__(
        set_cache_dir(cache_dir)
        # self._init_lambda_fn()
        self.data = None
+        self._original_unit = self.__energy_unit__
        self.recompute_statistics = recompute_statistics
        self.regressor_kwargs = regressor_kwargs
        self.transform = transform
@@ -145,6 +147,17 @@ def _init_lambda_fn(self):
        self._fn_distance = lambda x: x
        self._fn_forces = lambda x: x

+    @property
+    def config(self):
+        assert len(self.__links__) > 0, "No links provided for fetching"
+        return dict(dataset_name=self.__name__, links=self.__links__)
+
+    @classmethod
+    def fetch(cls, cache_path: Optional[str] = None, overwrite: bool = False) -> None:
+        from openqdc.utils.download_api import DataDownloader
+
+        DataDownloader(cache_path, overwrite).from_config(cls.no_init().config)
    def _post_init(
        self,
        overwrite_local_cache: bool = False,
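Together, __links__, config, and fetch form the new download path: a dataset subclass declares its remote files, and fetch hands that declaration to DataDownloader without fully instantiating the dataset. A minimal sketch of the pattern (the subclass, file name, and URL are made up for illustration; no_init and DataDownloader come from the PR):

    from openqdc.datasets.base import BaseDataset

    class MyDataset(BaseDataset):
        __name__ = "mydataset"
        # hypothetical remote file -> URL mapping consumed by DataDownloader
        __links__ = {"mydataset.zip": "https://example.com/mydataset.zip"}

    # downloads every file listed in __links__ into the cache (or a custom path),
    # re-downloading only when overwrite=True
    MyDataset.fetch(cache_path=None, overwrite=False)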
@@ -256,6 +269,10 @@ def pkl_data_keys(self):
    def pkl_data_types(self):
        return {"name": str, "subset": str, "n_atoms": np.int32}

+    @property
+    def atom_energies(self):
+        return self._e0s_dispatcher
+
    @property
    def data_types(self):
        return {
@@ -287,7 +304,11 @@ def _set_units(self, en, ds):
    def _set_isolated_atom_energies(self):
        if self.__energy_methods__ is None:
            logger.error("No energy methods defined for this dataset.")
-        f = get_conversion("hartree", self.__energy_unit__)
+        if self.energy_type == "formation":
+            f = get_conversion("hartree", self.__energy_unit__)
+        else:
+            # regressed e0s are calculated in the original unit of the dataset
+            f = get_conversion(self._original_unit, self.__energy_unit__)
        self.__isolated_atom_energies__ = f(self.e0s_dispatcher.e0s_matrix)

    def convert_energy(self, x):
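The new branch matters because regressed e0s are fitted on the dataset's raw energies, so they live in the dataset's original unit rather than hartree. A worked example with made-up numbers showing the size of the error the branch avoids (conversion factors are standard):

    # hypothetical dataset whose energies ship in kcal/mol; the user requests eV
    HARTREE_TO_EV = 27.211386245988
    KCAL_MOL_TO_EV = 0.0433641

    e0_regressed = -1050.0  # kcal/mol, fitted on the raw energies

    # formation e0s are tabulated in hartree, so hartree -> eV is right for them,
    # but applying it to a regressed e0 overshoots by a factor of ~627:
    wrong = e0_regressed * HARTREE_TO_EV   # about -28572 eV (incorrect)
    right = e0_regressed * KCAL_MOL_TO_EV  # about -45.5 eV (correct)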
@@ -546,48 +567,6 @@ def wrapper(idx):
datum["idxs"] = idxs
return datum

-    @classmethod
-    def as_dataloader(
-        cls,
-        batch_size: int = 8,
-        energy_unit: Optional[str] = None,
-        distance_unit: Optional[str] = None,
-        array_format: str = "torch",
-        energy_type: str = "formation",
-        overwrite_local_cache: bool = False,
-        cache_dir: Optional[str] = None,
-        recompute_statistics: bool = False,
-        transform: Optional[Callable] = None,
-    ):
-        """
-        Return the dataset as a dataloader.
-
-        Parameters
-        ----------
-        batch_size : int, optional
-            Batch size, by default 8
-        For other parameters, see the __init__ method.
-        """
-        if not has_package("torch_geometric"):
-            raise ImportError("torch_geometric is required to use this method.")
-        assert array_format in ["torch", "jax"], f"Format {array_format} must be torch or jax."
-        from torch_geometric.data import Data
-        from torch_geometric.loader import DataLoader
-
-        return DataLoader(
-            cls(
-                energy_unit=energy_unit,
-                distance_unit=distance_unit,
-                array_format=array_format,
-                energy_type=energy_type,
-                overwrite_local_cache=overwrite_local_cache,
-                cache_dir=cache_dir,
-                recompute_statistics=recompute_statistics,
-                transform=lambda x: Data(**x) if transform is None else transform,
-            ),
-            batch_size=batch_size,
-        )

    def as_iter(self, atoms: bool = False, energy_method: int = 0):
        """
        Return the dataset as an iterator.
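With as_dataloader removed, callers who depended on it can rebuild the same behavior in a few lines. A minimal sketch under the assumptions the deleted method made (torch_geometric is installed; SomeDataset stands in for any BaseDataset subclass):

    from torch_geometric.data import Data
    from torch_geometric.loader import DataLoader

    # wrap each item in a torch_geometric Data object, as the old default did
    ds = SomeDataset(array_format="torch", transform=lambda x: Data(**x))
    loader = DataLoader(ds, batch_size=8)

    for batch in loader:
        ...  # batched Data objects, ready for a GNN forward pass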