diff --git a/env.yml b/env.yml
index 16ccc3c..7787e28 100644
--- a/env.yml
+++ b/env.yml
@@ -11,10 +11,15 @@ dependencies:
   - gcsfs
   - typer
   - prettytable
+  - s3fs
+  - pydantic
+  - python-dotenv
 
+  # Scientific
   - pandas
   - numpy
+  - zarr
 
   # Chem
   - datamol #==0.9.0
diff --git a/openqdc/cli.py b/openqdc/cli.py
index c9cc2b0..5c88eda 100644
--- a/openqdc/cli.py
+++ b/openqdc/cli.py
@@ -175,6 +175,12 @@ def upload(
             help="Whether to overwrite or force the re-download of the datasets.",
         ),
     ] = True,
+    as_zarr: Annotated[
+        bool,
+        typer.Option(
+            help="Whether to upload the zarr files if available.",
+        ),
+    ] = False,
 ):
     """
     Upload a preprocessed dataset to the remote storage.
@@ -183,11 +189,100 @@
         if exist_dataset(dataset):
             logger.info(f"Uploading {SANITIZED_AVAILABLE_DATASETS[dataset].__name__}")
             try:
-                SANITIZED_AVAILABLE_DATASETS[dataset]().upload(overwrite=overwrite)
+                SANITIZED_AVAILABLE_DATASETS[dataset]().upload(overwrite=overwrite, as_zarr=as_zarr)
             except Exception as e:
                 logger.error(f"Error while uploading {dataset}. {e}. Did you preprocess the dataset first?")
                 raise e
 
+
+@app.command()
+def convert_to_zarr(
+    datasets: List[str],
+    overwrite: Annotated[
+        bool,
+        typer.Option(
+            help="Whether to overwrite the existing zarr files.",
+        ),
+    ] = False,
+    download: Annotated[
+        bool,
+        typer.Option(
+            help="Whether to force the re-download of the datasets.",
+        ),
+    ] = False,
+):
+    """
+    Convert a preprocessed dataset to the zarr file format.
+    """
+    import os
+    from os.path import join as p_join
+
+    import numpy as np
+    import zarr
+
+    from openqdc.utils.io import load_pkl
+
+    def silent_remove(filename):
+        # Remove the file if it exists, do nothing otherwise.
+        try:
+            os.remove(filename)
+        except OSError:
+            pass
+
+    for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)):
+        if exist_dataset(dataset):
+            logger.info(f"Converting {SANITIZED_AVAILABLE_DATASETS[dataset].__name__}")
+            try:
+                ds = SANITIZED_AVAILABLE_DATASETS[dataset](overwrite_local_cache=download)
+                # os.makedirs(p_join(ds.root, "zips", ds.__name__), exist_ok=True)
+
+                # Pack the pickled properties into a single zipped zarr group.
+                pkl = load_pkl(p_join(ds.preprocess_path, "props.pkl"))
+                metadata = p_join(ds.preprocess_path, "metadata.zip")
+                if overwrite:
+                    silent_remove(metadata)
+                group = zarr.group(zarr.storage.ZipStore(metadata))
+                for key, value in pkl.items():
+                    # sub = group.create_group(key)
+                    if key in ["name", "subset"]:
+                        data = group.create_dataset(key, shape=value[0].shape, dtype=value[0].dtype)
+                        data[:] = value[0][:]
+                        data2 = group.create_dataset(key + "_ptr", shape=value[1].shape, dtype=np.int32)
+                        data2[:] = value[1][:]
+                    else:
+                        data = group.create_dataset(key, shape=value.shape, dtype=value.dtype)
+                        data[:] = value[:]
+
+                force_attrs = {
+                    "unit": str(ds.force_unit),
+                    "level_of_theory": ds.force_methods,
+                }
+
+                energy_attrs = {
+                    "unit": str(ds.energy_unit),
+                    "level_of_theory": ds.energy_methods,
+                }
+
+                atomic_inputs_attrs = {
+                    "unit": str(ds.distance_unit),
+                }
+                attrs = {
+                    "forces": force_attrs,
+                    "energies": energy_attrs,
+                    "atomic_inputs": atomic_inputs_attrs,
+                }
+
+                # Write one zipped zarr array per memmapped data key.
+                # os.makedirs(p_join(ds.root, "zips", ds.__name__), exist_ok=True)
+                for key, value in ds.data.items():
+                    if key not in ds.data_keys:
+                        continue
+                    logger.info(f"Converting {key} with shape {value.shape}")
+                    zarr_path = p_join(ds.preprocess_path, key + ".zip")  # ds.__name__,
+                    if overwrite:
+                        silent_remove(zarr_path)
+                    z = zarr.open(zarr.storage.ZipStore(zarr_path), "w", zarr_version=2,
+                                  shape=value.shape, dtype=value.dtype)
+                    z[:] = value[:]
+                    if key in attrs:
+                        z.attrs.update(attrs[key])
+
+            except Exception as e:
+                logger.error(f"Error while converting {dataset}. {e}. Did you preprocess the dataset first?")
Did you preprocess the dataset first?") + raise e + if __name__ == "__main__": app() diff --git a/openqdc/datasets/base.py b/openqdc/datasets/base.py index fbcdca7..0977a7f 100644 --- a/openqdc/datasets/base.py +++ b/openqdc/datasets/base.py @@ -273,7 +273,7 @@ def root(self): @property def preprocess_path(self): - path = p_join(self.root, "preprocessed" if not self.read_as_zarr else "zarr") + path = p_join(self.root, "preprocessed") os.makedirs(path, exist_ok=True) return path @@ -393,7 +393,7 @@ def save_preprocess(self, data_dict, upload=False, overwrite=True, as_zarr: bool # save memmaps logger.info("Preprocessing data and saving it to cache.") for key in self.data_keys: - local_path = p_join(self.preprocess_path, f"{key}.mmap") + local_path = p_join(self.preprocess_path, f"{key}.mmap" if not as_zarr else f"{key}.zip") out = np.memmap(local_path, mode="w+", dtype=data_dict[key].dtype, shape=data_dict[key].shape) out[:] = data_dict.pop(key)[:] out.flush() @@ -468,6 +468,7 @@ def read_preprocess(self, overwrite_local_cache=False): assert all([key in all_pkl_keys for key in self.pkl_data_keys]) for key in all_pkl_keys: if key not in self.pkl_data_keys: + print(key, list(tmp.items())) self.data[key] = tmp[key][:][tmp[key][:]] else: self.data[key] = tmp[key][:] @@ -513,7 +514,12 @@ def upload(self, overwrite: bool = False, as_zarr: bool = False): """ Upload the preprocessed data to the remote storage. """ - self.save_preprocess(self.data, True, overwrite, as_zarr) + for key in self.data_keys: + local_path = p_join(self.preprocess_path, f"{key}.mmap" if not as_zarr else f"{key}.zip") + push_remote(local_path, overwrite=overwrite) + local_path = p_join(self.preprocess_path, "props.pkl" if not as_zarr else "metadata.zip") + push_remote(local_path, overwrite=overwrite) + def save_xyz(self, idx: int, energy_method: int = 0, path: Optional[str] = None, ext=True): """ diff --git a/openqdc/datasets/potential/__init__.py b/openqdc/datasets/potential/__init__.py index da04cae..0c184cd 100644 --- a/openqdc/datasets/potential/__init__.py +++ b/openqdc/datasets/potential/__init__.py @@ -24,6 +24,7 @@ from .transition1x import Transition1X from .vqm24 import VQM24 from .waterclusters3_30 import WaterClusters +from .waterclusters import SCANWaterClusters AVAILABLE_POTENTIAL_DATASETS = { "Alchemy": Alchemy, @@ -59,6 +60,7 @@ "TMQM": TMQM, "Transition1X": Transition1X, "WaterClusters": WaterClusters, + "SCANWaterClusters": SCANWaterClusters, "MultixcQM9": MultixcQM9, "MultixcQM9_V2": MultixcQM9_V2, "RevMD17": RevMD17, diff --git a/openqdc/datasets/potential/qmx.py b/openqdc/datasets/potential/qmx.py index 95a3821..2348587 100644 --- a/openqdc/datasets/potential/qmx.py +++ b/openqdc/datasets/potential/qmx.py @@ -122,6 +122,11 @@ class QM7(QMX): 'WB97M-V:def2-tzvp', 'WB97M-V:sto-3g', 'WB97X-D:aug-cc-pvdz', 'WB97X-D:aug-cc-pvtz', 'WB97X-D:def2-svp', 'WB97X-D:def2-tzvp', 'WB97X-D:sto-3g'] + + __energy_methods__ = [ + PotentialMethod.NONE for _ in range(len(energy_target_names)) # "wb97x/6-31g(d)" + ] + diff --git a/openqdc/datasets/potential/waterclusters.py b/openqdc/datasets/potential/waterclusters.py new file mode 100644 index 0000000..b9961ee --- /dev/null +++ b/openqdc/datasets/potential/waterclusters.py @@ -0,0 +1,104 @@ +import zipfile +from io import StringIO +from os.path import join as p_join + +import numpy as np +from tqdm import tqdm +from collections import defaultdict +from openqdc.datasets.base import BaseDataset +from openqdc.methods import PotentialMethod +from openqdc.utils.constants import 
+from openqdc.utils.package_utils import requires_package
+
+_default_basis_sets = {
+    "BEGDB_H2O": "aug-cc-pVQZ",
+    "WATER27": "aug-cc-pVQZ",
+    "H2O_alkali_clusters": "def2-QZVPPD",
+    "H2O_halide_clusters": "def2-QZVPPD",
+}
+
+
+@requires_package("monty")
+@requires_package("pymatgen")
+def read_geometries(fname, dataset):
+    from monty.serialization import loadfn
+
+    geometries = {k: v.to_ase_atoms() for k, v in loadfn(fname)[dataset].items()}
+    return geometries
+
+
+@requires_package("monty")
+def read_energies(fname, dataset):
+    from monty.serialization import loadfn
+
+    _energies = loadfn(fname)[dataset]
+    metadata_restrictions = {"basis_set": _default_basis_sets.get(dataset)}
+
+    functionals_to_return = []
+    for dfa, at_dfa_d in _energies.items():
+        functionals_to_return += [f"{dfa}" if dfa == at_dfa else f"{dfa}@{at_dfa}" for at_dfa in at_dfa_d]
+
+    energies = defaultdict(dict)
+    for f in functionals_to_return:
+        if "-FLOSIC" in f and "@" not in f:
+            func = f.split("-FLOSIC")[0]
+            at_f = "-FLOSIC"
+        else:
+            func = f.split("@")[0]
+            at_f = f.split("@")[-1]
+
+        if func not in _energies:
+            print(f"No functional {func} included in dataset - available options:\n{', '.join(_energies.keys())}")
+        elif at_f not in _energies[func]:
+            print(f"No @functional {at_f} included in {func} dataset - available options:\n{', '.join(_energies[func].keys())}")
+        else:
+            if isinstance(_energies[func][at_f], list):
+                for entry in _energies[func][at_f]:
+                    if all(entry["metadata"].get(k) == v for k, v in metadata_restrictions.items()):
+                        energies[f] = entry
+                        break
+                if f not in energies:
+                    print(f"No matching metadata {json.dumps(metadata_restrictions)} for method {f}")
+            else:
+                energies[f] = _energies[func][at_f]
+    return dict(energies)
+
+
+def format_geometry_and_entries(geometry, energies, subset):
+    pass
+
+
+class SCANWaterClusters(BaseDataset):
+    """https://chemrxiv.org/engage/chemrxiv/article-details/662aaff021291e5d1db7d8ec"""
+
+    __name__ = "scanwaterclusters"
+    __energy_methods__ = [PotentialMethod.GFN2_XTB]
+
+    __energy_unit__ = "hartree"
+    __distance_unit__ = "ang"
+    __forces_unit__ = "hartree/ang"
+
+    energy_target_names = [
+        'HF', 'HF-r2SCAN-DC4', 'SCAN', 'SCAN@HF', 'SCAN@r2SCAN50', 'r2SCAN', 'r2SCAN@HF', 'r2SCAN@r2SCAN50',
+        'r2SCAN50', 'r2SCAN100', 'r2SCAN10', 'r2SCAN20', 'r2SCAN25', 'r2SCAN30', 'r2SCAN40', 'r2SCAN60',
+        'r2SCAN70', 'r2SCAN80', 'r2SCAN90',
+    ]
+    force_target_names = []
+
+    subsets = ["BEGDB_H2O", "WATER27", "H2O_alkali_clusters", "H2O_halide_clusters"]
+    __links__ = {
+        "geometries.json.gz": "https://github.com/esoteric-ephemera/water_cluster_density_errors/blob/main/data_files/geometries.json.gz?raw=True",
+        "total_energies.json.gz": "https://github.com/esoteric-ephemera/water_cluster_density_errors/blob/main/data_files/total_energies.json.gz?raw=True",
+    }
+
+    def read_raw_entries(self):
+        entries = []
+        for subset in self.subsets:
+            geometries = read_geometries(p_join(self.root, "geometries.json.gz"), subset)
+            energies = read_energies(p_join(self.root, "total_energies.json.gz"), subset)
+            datum = {}
+            for k in energies:
+                _ = energies[k].pop("metadata")
+                datum[k] = energies[k]["total_energies"]
+            # Collect one frame per subset instead of returning after the first one.
+            entries.append(pd.concat([pd.DataFrame({"positions": geometries}), pd.DataFrame(datum)], axis=1))
+        return entries
diff --git a/openqdc/utils/download_api.py b/openqdc/utils/download_api.py
index c96f3d9..1687b3e 100644
--- a/openqdc/utils/download_api.py
+++ b/openqdc/utils/download_api.py
@@ -10,6 +10,8 @@
 from dataclasses import dataclass
 from typing import Optional
 
+from dotenv import load_dotenv
+
 import fsspec
 import gdown
 import requests
@@ -23,7 +25,6 @@
 import openqdc.utils.io as ioqdc
 
-
 @dataclass
 class FileSystem:
     """
@@ -33,7 +34,14 @@ class FileSystem:
     public_endpoint: Optional[AbstractFileSystem] = None
     private_endpoint: Optional[AbstractFileSystem] = None
     local_endpoint: AbstractFileSystem = LocalFileSystem()
-
+    endpoint_url = "https://874f02b9d981bd6c279e979c0d91c4b4.r2.cloudflarestorage.com"
+
+    def __init__(self):
+        if not load_dotenv():
+            logger.warning("Problem loading environment variables from .env")
+        self.KEY = os.getenv("CLOUDFARE_KEY", None)
+        self.SECRET = os.getenv("CLOUDFARE_SECRET", None)
+
     @property
     def public(self):
         self.connect()
@@ -64,16 +71,24 @@ def connect(self):
             warnings.simplefilter("ignore")  # No quota warning
             self.public_endpoint = self.get_default_endpoint("public")
             self.private_endpoint = self.get_default_endpoint("private")
-            self.public_endpoint.client_kwargs = {"timeout": ClientTimeout(total=3600, connect=1000)}
+            # self.public_endpoint.client_kwargs = {"timeout": ClientTimeout(total=3600, connect=1000)}
 
     def get_default_endpoint(self, endpoint: str) -> AbstractFileSystem:
         """
         Return a default endpoint for the given str [public, private]
         """
         if endpoint == "private":
-            return fsspec.filesystem("gs")
+            # return fsspec.filesystem("gs")
+            return fsspec.filesystem("s3",
+                                     key=self.KEY,
+                                     secret=self.SECRET,
+                                     endpoint_url=self.endpoint_url)
         elif endpoint == "public":
-            return fsspec.filesystem("https")
+            return fsspec.filesystem("s3",
+                                     key="a046308c078b0134c9e261aa91f63ab2",
+                                     secret="d5b32f241ad8ee8d0a3173cd51b4f36d6869f168b21acef75f244a81dc10e1fb",
+                                     endpoint_url=self.endpoint_url)
+            # return fsspec.filesystem("https")
         else:
             return self.local_endpoint
@@ -283,3 +298,4 @@ def from_config(self, config: dict):
 
 
 API = FileSystem()
+
diff --git a/openqdc/utils/io.py b/openqdc/utils/io.py
index f462f9c..6e50389 100644
--- a/openqdc/utils/io.py
+++ b/openqdc/utils/io.py
@@ -54,9 +54,10 @@ def get_remote_cache(write_access=False) -> str:
     Returns the entry point based on the write access.
     """
     if write_access:
-        remote_cache = "gs://qmdata-public/openqdc"
+        remote_cache = "/openqdc/v1"  # "gs://qmdata-public/openqdc"
     else:
-        remote_cache = "https://storage.googleapis.com/qmdata-public/openqdc"
+        remote_cache = "/openqdc/v1"
+        # remote_cache = "https://storage.googleapis.com/qmdata-public/openqdc"
     return remote_cache
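
For reference, a minimal sketch of how the zipped zarr stores written by `convert-to-zarr` could be read back. This is not part of the diff: the cache location, the "spice" dataset name, and the "energies" key are illustrative assumptions; only the store layout and attributes mirror the conversion code above.

    # Minimal read-back sketch (assumes `openqdc convert-to-zarr spice` has already been run;
    # the cache path, dataset name and "energies" key are hypothetical).
    from os.path import expanduser, join as p_join

    import zarr

    preprocess_path = p_join(expanduser("~/.cache/openqdc"), "spice", "preprocessed")

    # Each data key is stored as a single zarr array inside its own zip store,
    # with the unit and level of theory attached as attributes.
    energies = zarr.open(zarr.storage.ZipStore(p_join(preprocess_path, "energies.zip"), mode="r"), mode="r")
    print(energies.shape, dict(energies.attrs))

    # The pickled properties are packed into one zarr group in metadata.zip.
    metadata = zarr.open(zarr.storage.ZipStore(p_join(preprocess_path, "metadata.zip"), mode="r"), mode="r")
    print(list(metadata.keys()))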