Dataset structure factory
FNTwin committed Jul 9, 2024
1 parent b221f72 commit d031b13
Showing 10 changed files with 404 additions and 284 deletions.
env.yml (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,7 @@ dependencies:
- typer
- prettytable
- s3fs
- pydantic
- pydantic
- python-dotenv


openqdc/cli.py (93 changes: 50 additions & 43 deletions)
@@ -12,6 +12,7 @@
AVAILABLE_INTERACTION_DATASETS,
AVAILABLE_POTENTIAL_DATASETS,
)
from openqdc.utils.io import get_local_cache

app = typer.Typer(help="OpenQDC CLI")

@@ -55,7 +56,7 @@ def download(
help="Path to the cache. If not provided, the default cache directory (.cache/openqdc/) will be used.",
),
] = None,
as_zarr : Annotated[
as_zarr: Annotated[
bool,
typer.Option(
help="Whether to overwrite or force the re-download of the datasets.",
@@ -70,14 +71,14 @@ def download(
"""
for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)):
if exist_dataset(dataset):
ds=SANITIZED_AVAILABLE_DATASETS[dataset].no_init()
ds.read_as_zarr=as_zarr
ds = SANITIZED_AVAILABLE_DATASETS[dataset].no_init()
ds.read_as_zarr = as_zarr
if ds.is_cached() and not overwrite:
logger.info(f"{dataset} is already cached. Skipping download")
else:
SANITIZED_AVAILABLE_DATASETS[dataset](overwrite_local_cache=True,
cache_dir=cache_dir,
read_as_zarr=as_zarr)
SANITIZED_AVAILABLE_DATASETS[dataset](
overwrite_local_cache=True, cache_dir=cache_dir, read_as_zarr=as_zarr
)
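
As a quick sanity check of the reformatted download command, the following sketch (not part of the commit) drives it in-process with Typer's test runner. The dataset name "ani1" is only an illustrative placeholder, and "--as-zarr" is the flag name Typer derives from the as_zarr option above.

from typer.testing import CliRunner

from openqdc.cli import app

runner = CliRunner()
# Invoke the CLI without spawning a subprocess; the dataset name is illustrative.
result = runner.invoke(app, ["download", "ani1", "--as-zarr"])
print(result.stdout)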


@app.command()
@@ -185,7 +186,7 @@ def upload(
help="Whether to overwrite or force the re-download of the datasets.",
),
] = True,
as_zarr : Annotated[
as_zarr: Annotated[
bool,
typer.Option(
help="Whether to upload the zarr files if available.",
@@ -204,6 +205,7 @@
logger.error(f"Error while uploading {dataset}. {e}. Did you preprocess the dataset first?")
raise e


@app.command()
def convert_to_zarr(
datasets: List[str],
@@ -221,80 +223,85 @@
] = False,
):
"""
Conver a preprocessed dataset to the zarr file format.
Convert a preprocessed dataset to the zarr file format.
"""
import os
from os.path import join as p_join

import numpy as np
import zarr

from openqdc.utils.io import load_pkl
from openqdc.utils.io import load_pkl

def silent_remove(filename):
try:
os.remove(filename)
except OSError:
pass

for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)):
if exist_dataset(dataset):
logger.info(f"Uploading {SANITIZED_AVAILABLE_DATASETS[dataset].__name__}")
try:
ds=SANITIZED_AVAILABLE_DATASETS[dataset](overwrite_local_cache=download)
#os.makedirs(p_join(ds.root, "zips", ds.__name__), exist_ok=True)
ds = SANITIZED_AVAILABLE_DATASETS[dataset](overwrite_local_cache=download)
# os.makedirs(p_join(ds.root, "zips", ds.__name__), exist_ok=True)

pkl = load_pkl(p_join(ds.preprocess_path, "props.pkl"))
metadata=p_join(ds.preprocess_path, "metadata.zip")
if overwrite: silent_remove(metadata)
metadata = p_join(ds.preprocess_path, "metadata.zip")
if overwrite:
silent_remove(metadata)
group = zarr.group(zarr.storage.ZipStore(metadata))
for key, value in pkl.items():
#sub=group.create_group(key)
if key in ['name', 'subset']:
data=group.create_dataset(key, shape=value[0].shape, dtype=value[0].dtype)
data[:]=value[0][:]
data2=group.create_dataset(key + "_ptr", shape=value[1].shape,
dtype=np.int32)
# sub=group.create_group(key)
if key in ["name", "subset"]:
data = group.create_dataset(key, shape=value[0].shape, dtype=value[0].dtype)
data[:] = value[0][:]
data2 = group.create_dataset(key + "_ptr", shape=value[1].shape, dtype=np.int32)
data2[:] = value[1][:]
else:
data=group.create_dataset(key, shape=value.shape, dtype=value.dtype)
data[:]=value[:]
data = group.create_dataset(key, shape=value.shape, dtype=value.dtype)
data[:] = value[:]

force_attrs = {
"unit" : str(ds.force_unit),
"level_of_theory" : ds.force_methods,
"unit": str(ds.force_unit),
"level_of_theory": ds.force_methods,
}

energy_attrs = {
"unit" : str(ds.energy_unit),
"level_of_theory": ds.energy_methods
}
energy_attrs = {"unit": str(ds.energy_unit), "level_of_theory": ds.energy_methods}

atomic_inputs_attrs = {
"unit" : str(ds.distance_unit),
}
attrs = {
"forces" : force_attrs,
"energies" : energy_attrs,
"atomic_inputs" : atomic_inputs_attrs
"unit": str(ds.distance_unit),
}
attrs = {"forces": force_attrs, "energies": energy_attrs, "atomic_inputs": atomic_inputs_attrs}


#os.makedirs(p_join(ds.root, "zips", ds.__name__), exist_ok=True)
# os.makedirs(p_join(ds.root, "zips", ds.__name__), exist_ok=True)
for key, value in ds.data.items():
if key not in ds.data_keys:
continue
print(key, value.shape)

zarr_path=p_join(ds.preprocess_path, key + ".zip") #ds.__name__,
if overwrite: silent_remove(zarr_path)
z=zarr.open(zarr.storage.ZipStore(zarr_path), "w", zarr_version=2, shape=value.shape,
dtype=value.dtype)
z[:]=value[:]
zarr_path = p_join(ds.preprocess_path, key + ".zip") # ds.__name__,
if overwrite:
silent_remove(zarr_path)
z = zarr.open(
zarr.storage.ZipStore(zarr_path), "w", zarr_version=2, shape=value.shape, dtype=value.dtype
)
z[:] = value[:]
if key in attrs:
z.attrs.update(attrs[key])

except Exception as e:
logger.error(f"Error while converting {dataset}. {e}. Did you preprocess the dataset first?")
raise e
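
Each converted key ends up as a single zarr array inside a ZipStore, with the unit and level-of-theory attributes attached. Below is a minimal sketch of inspecting one of the resulting files, assuming the zarr v2 API used above; "energies.zip" stands in for one of the files written to ds.preprocess_path.

import zarr

# Open one of the converted stores read-only; the path is illustrative.
store = zarr.storage.ZipStore("energies.zip", mode="r")
z = zarr.open(store, mode="r")
print(z.shape, z.dtype)
print(dict(z.attrs))  # e.g. the "unit" and "level_of_theory" attributes set during conversion
store.close()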



@app.command()
def cache():
"""
Get the current local cache path of openQDC
"""
print(f"openQDC local cache:\n {get_local_cache()}")


if __name__ == "__main__":
app()
openqdc/datasets/base.py (131 changes: 39 additions & 92 deletions)
@@ -1,19 +1,18 @@
"""The BaseDataset defining shared functionality between all datasets."""

import os
import pickle as pkl
from functools import partial
from itertools import compress
from os.path import join as p_join
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import zarr
from ase.io.extxyz import write_extxyz
from loguru import logger
from sklearn.utils import Bunch
from tqdm import tqdm

from openqdc.datasets.dataset_structure import MemMapDataset, ZarrDataset
from openqdc.datasets.energies import AtomEnergies
from openqdc.datasets.properties import DatasetPropertyMixIn
from openqdc.datasets.statistics import (
@@ -33,7 +32,6 @@
copy_exists,
dict_to_atoms,
get_local_cache,
pull_locally,
push_remote,
set_cache_dir,
)
@@ -156,6 +154,12 @@ def _init_lambda_fn(self):
self._fn_distance = lambda x: x
self._fn_forces = lambda x: x

@property
def dataset_wrapper(self):
if not hasattr("_dataset_wrapper", self):
self._dataset_wrapper = ZarrDataset() if self.read_as_zarr else MemMapDataset()
return self._dataset_wrapper
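
The MemMapDataset and ZarrDataset wrappers live in the new openqdc/datasets/dataset_structure.py module, which is not shown in this excerpt. The sketch below outlines the interface that the calls in this file (add_extension, save_preprocess, load_data, _extra_files) imply; the class name DatasetStructure and the exact signatures are inferred from this diff, not copied from the new module.

from abc import ABC, abstractmethod
from typing import Any, Dict, List

class DatasetStructure(ABC):
    """Inferred shape of the dataset structure factory objects."""

    _extension: str = ""          # e.g. ".mmap" for memmap, ".zip" for zarr
    _extra_files: List[str] = []  # e.g. ["props.pkl"] or ["metadata.zip"]

    def add_extension(self, filename: str) -> str:
        # Append the backend-specific file extension to a data key.
        return filename + self._extension

    @abstractmethod
    def save_preprocess(self, preprocess_path, data_keys, data_dict, extra_data_keys, extra_data_types) -> List[str]:
        """Write the arrays and metadata, returning the local paths to push remotely."""

    @abstractmethod
    def load_data(self, preprocess_path, data_keys, data_types, data_shapes, extra_data_keys, overwrite) -> Dict[str, Any]:
        """Pull files locally if needed and open them, returning the data dict."""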

@property
def config(self):
assert len(self.__links__) > 0, "No links provided for fetching"
@@ -167,17 +171,6 @@ def fetch(cls, cache_path: Optional[str] = None, overwrite: bool = False) -> Non

DataDownloader(cache_path, overwrite).from_config(cls.no_init().config)

@property
def ext(self):
return ".mmap" if not self.read_as_zarr else ".zip"

@property
def load_fn(self):
return np.memmap if not self.read_as_zarr else zarr.open

def add_extension(self, filename):
return filename + self.ext

def _post_init(
self,
overwrite_local_cache: bool = False,
@@ -392,29 +385,32 @@ def save_preprocess(self, data_dict, upload=False, overwrite=True, as_zarr: bool
"""
# save memmaps
logger.info("Preprocessing data and saving it to cache.")
for key in self.data_keys:
local_path = p_join(self.preprocess_path, f"{key}.mmap" if not as_zarr else f"{key}.zip")
out = np.memmap(local_path, mode="w+", dtype=data_dict[key].dtype, shape=data_dict[key].shape)
out[:] = data_dict.pop(key)[:]
out.flush()
if upload:
push_remote(local_path, overwrite=overwrite)

# save smiles and subset
local_path = p_join(self.preprocess_path, "props.pkl")

# assert that (required) pkl keys are present in data_dict
assert all([key in data_dict.keys() for key in self.pkl_data_keys])
paths = self.dataset_wrapper.save_preprocess(
self.preprocess_path, self.data_keys, data_dict, self.pkl_data_keys, self.pkl_data_types
)
if upload:
for local_path in paths:
push_remote(local_path, overwrite=overwrite) # make it async?

# store unique and inverse indices for str-based pkl keys
for key in self.pkl_data_keys:
if self.pkl_data_types[key] == str:
data_dict[key] = np.unique(data_dict[key], return_inverse=True)
def read_preprocess(self, overwrite_local_cache=False):
logger.info("Reading preprocessed data.")
logger.info(
f"Dataset {self.__name__} with the following units:\n\
Energy: {self.energy_unit},\n\
Distance: {self.distance_unit},\n\
Forces: {self.force_unit if self.force_methods else 'None'}"
)

with open(local_path, "wb") as f:
pkl.dump(data_dict, f)
if upload:
push_remote(local_path, overwrite=overwrite)
self.data = self.dataset_wrapper.load_data(
self.preprocess_path,
self.data_keys,
self.data_types,
self.data_shapes,
self.pkl_data_keys,
overwrite_local_cache,
) # this should be async if possible
for key in self.data:
logger.info(f"Loaded {key} with shape {self.data[key].shape}, dtype {self.data[key].dtype}")

def _convert_on_loading(self, x, key):
if key == "energies":
@@ -428,74 +424,26 @@ def _convert_on_loading(self, x, key):
else:
return x

def read_preprocess(self, overwrite_local_cache=False):
logger.info("Reading preprocessed data.")
logger.info(
f"Dataset {self.__name__} with the following units:\n\
Energy: {self.energy_unit},\n\
Distance: {self.distance_unit},\n\
Forces: {self.force_unit if self.force_methods else 'None'}"
)
self.data = {}
for key in self.data_keys:
filename = p_join(self.preprocess_path, self.add_extension(f"{key}"))
pull_locally(filename, overwrite=overwrite_local_cache)
self.data[key] = self.load_fn(filename, mode="r", dtype=self.data_types[key])
if self.read_as_zarr:
self.data[key] = self.data[key][:]
self.data[key] = self.data[key].reshape(*self.data_shapes[key])

if not self.read_as_zarr:
filename = p_join(self.preprocess_path, "props.pkl")
pull_locally(filename, overwrite=overwrite_local_cache)
with open(filename, "rb") as f:
tmp = pkl.load(f)
all_pkl_keys = set(tmp.keys()) - set(self.data_keys)
# assert required pkl_keys are present in all_pkl_keys
assert all([key in all_pkl_keys for key in self.pkl_data_keys])
for key in all_pkl_keys:
x = tmp.pop(key)
if len(x) == 2:
self.data[key] = x[0][x[1]]
else:
self.data[key] = x
else:
filename = p_join(self.preprocess_path, self.add_extension("metadata"))
pull_locally(filename, overwrite=overwrite_local_cache)
tmp = self.load_fn(filename)
all_pkl_keys = set(tmp.keys()) - set(self.data_keys)
# assert required pkl_keys are present in all_pkl_keys
assert all([key in all_pkl_keys for key in self.pkl_data_keys])
for key in all_pkl_keys:
if key not in self.pkl_data_keys:
#print(key, list(tmp.items()))
self.data[key] = tmp[key][:][tmp[key][:]]
else:
self.data[key] = tmp[key][:]

for key in self.data:
logger.info(f"Loaded {key} with shape {self.data[key].shape}, dtype {self.data[key].dtype}")

def is_preprocessed(self):
"""
Check if the dataset is preprocessed and available online or locally.
"""
predicats = [copy_exists(p_join(self.preprocess_path, self.add_extension(f"{key}"))) for key in self.data_keys]

if not self.read_as_zarr:
predicats += [copy_exists(p_join(self.preprocess_path, "props.pkl"))]
print(predicats)
predicats = [
copy_exists(p_join(self.preprocess_path, self.dataset_wrapper.add_extension(f"{key}")))
for key in self.data_keys
]
predicats += [copy_exists(p_join(self.preprocess_path, file)) for file in self.dataset_wrapper._extra_files]
return all(predicats)

def is_cached(self):
"""
Check if the dataset is cached locally.
"""
predicats = [
os.path.exists(p_join(self.preprocess_path, self.add_extension(f"{key}"))) for key in self.data_keys
os.path.exists(p_join(self.preprocess_path, self.dataset_wrapper.add_extension(f"{key}")))
for key in self.data_keys
]
if not self.read_as_zarr:
predicats += [os.path.exists(p_join(self.preprocess_path, "props.pkl"))]
predicats += [copy_exists(p_join(self.preprocess_path, file)) for file in self.dataset_wrapper._extra_files]
return all(predicats)

def preprocess(self, upload: bool = False, overwrite: bool = True, as_zarr: bool = True):
@@ -521,7 +469,6 @@ def upload(self, overwrite: bool = False, as_zarr: bool = False):
push_remote(local_path, overwrite=overwrite)
local_path = p_join(self.preprocess_path, "props.pkl" if not as_zarr else "metadata.zip")
push_remote(local_path, overwrite=overwrite)


def save_xyz(self, idx: int, energy_method: int = 0, path: Optional[str] = None, ext=True):
"""