Commit 0c60be6 — WIP
FNTwin committed Jun 30, 2024 · 1 parent ea59515
Showing 8 changed files with 245 additions and 11 deletions.
5 changes: 5 additions & 0 deletions env.yml
@@ -11,10 +11,15 @@ dependencies:
- gcsfs
- typer
- prettytable
- s3fs
- pydantic
- python-dotenv


# Scientific
- pandas
- numpy
- zarr

# Chem
- datamol #==0.9.0
97 changes: 96 additions & 1 deletion openqdc/cli.py
@@ -175,6 +175,12 @@ def upload(
help="Whether to overwrite or force the re-download of the datasets.",
),
] = True,
as_zarr: Annotated[
bool,
typer.Option(
help="Whether to upload the zarr files if available.",
),
] = False,
):
"""
Upload a preprocessed dataset to the remote storage.
@@ -183,11 +189,100 @@
if exist_dataset(dataset):
logger.info(f"Uploading {SANITIZED_AVAILABLE_DATASETS[dataset].__name__}")
try:
SANITIZED_AVAILABLE_DATASETS[dataset]().upload(overwrite=overwrite)
SANITIZED_AVAILABLE_DATASETS[dataset]().upload(overwrite=overwrite, as_zarr=as_zarr)
except Exception as e:
logger.error(f"Error while uploading {dataset}. {e}. Did you preprocess the dataset first?")
raise e
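
For context, a minimal sketch of how the new as_zarr flag is exercised through the Python API rather than the CLI (the dataset choice is illustrative and assumes it was preprocessed locally with write credentials configured):

    # sketch only: any dataset registered in AVAILABLE_POTENTIAL_DATASETS works the same way
    from openqdc.datasets.potential import Spice

    ds = Spice()                             # loads the locally preprocessed data
    ds.upload(overwrite=True, as_zarr=True)  # pushes the zarr .zip archives instead of the .mmap files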

@app.command()
def convert_to_zarr(
datasets: List[str],
overwrite: Annotated[
bool,
typer.Option(
help="Whether to overwrite or force the re-download of the datasets.",
),
] = False,
download: Annotated[
bool,
typer.Option(
help="Whether to force the re-download of the datasets.",
),
] = False,
):
"""
Convert a preprocessed dataset to the zarr file format.
"""
import os
from os.path import join as p_join

import numpy as np
import zarr

from openqdc.utils.io import load_pkl
def silent_remove(filename):
try:
os.remove(filename)
except OSError:
pass
for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)):
if exist_dataset(dataset):
logger.info(f"Uploading {SANITIZED_AVAILABLE_DATASETS[dataset].__name__}")
try:
ds = SANITIZED_AVAILABLE_DATASETS[dataset](overwrite_local_cache=download)

pkl = load_pkl(p_join(ds.preprocess_path, "props.pkl"))
metadata = p_join(ds.preprocess_path, "metadata.zip")
if overwrite:
    silent_remove(metadata)
group = zarr.group(zarr.storage.ZipStore(metadata))
for key, value in pkl.items():
    if key in ["name", "subset"]:
        data = group.create_dataset(key, shape=value[0].shape, dtype=value[0].dtype)
        data[:] = value[0][:]
        data2 = group.create_dataset(key + "_ptr", shape=value[1].shape, dtype=np.int32)
        data2[:] = value[1][:]
    else:
        data = group.create_dataset(key, shape=value.shape, dtype=value.dtype)
        data[:] = value[:]

force_attrs = {
    "unit": str(ds.force_unit),
    "level_of_theory": ds.force_methods,
}

energy_attrs = {
    "unit": str(ds.energy_unit),
    "level_of_theory": ds.energy_methods,
}

atomic_inputs_attrs = {
    "unit": str(ds.distance_unit),
}

attrs = {
    "forces": force_attrs,
    "energies": energy_attrs,
    "atomic_inputs": atomic_inputs_attrs,
}


for key, value in ds.data.items():
    if key not in ds.data_keys:
        continue
    logger.debug(f"{key}: {value.shape}")

    zarr_path = p_join(ds.preprocess_path, key + ".zip")
    if overwrite:
        silent_remove(zarr_path)
    z = zarr.open(
        zarr.storage.ZipStore(zarr_path),
        "w",
        zarr_version=2,
        shape=value.shape,
        dtype=value.dtype,
    )
    z[:] = value[:]
    if key in attrs:
        z.attrs.update(attrs[key])

except Exception as e:
logger.error(f"Error while converting {dataset}. {e}. Did you preprocess the dataset first?")
raise e
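
As a sanity check, a hedged sketch of reading back the archives that convert_to_zarr writes (paths are hypothetical and assume the command ran on a preprocessed dataset):

    import zarr

    # each data key becomes a single zarr array stored inside a ZipStore
    z = zarr.open(zarr.storage.ZipStore("atomic_inputs.zip", mode="r"), mode="r")
    print(z.shape, z.dtype, dict(z.attrs))  # unit/level_of_theory metadata lives in .attrs

    # the props.pkl contents are mirrored into metadata.zip as a zarr group
    meta = zarr.open_group(zarr.storage.ZipStore("metadata.zip", mode="r"), mode="r")
    names, name_ptr = meta["name"][:], meta["name_ptr"][:]  # values plus their int32 offsets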

if __name__ == "__main__":
app()
12 changes: 9 additions & 3 deletions openqdc/datasets/base.py
@@ -273,7 +273,7 @@ def root(self):

@property
def preprocess_path(self):
path = p_join(self.root, "preprocessed" if not self.read_as_zarr else "zarr")
path = p_join(self.root, "preprocessed")
os.makedirs(path, exist_ok=True)
return path

@@ -393,7 +393,7 @@ def save_preprocess(self, data_dict, upload=False, overwrite=True, as_zarr: bool
# save memmaps
logger.info("Preprocessing data and saving it to cache.")
for key in self.data_keys:
local_path = p_join(self.preprocess_path, f"{key}.mmap")
local_path = p_join(self.preprocess_path, f"{key}.mmap" if not as_zarr else f"{key}.zip")
out = np.memmap(local_path, mode="w+", dtype=data_dict[key].dtype, shape=data_dict[key].shape)
out[:] = data_dict.pop(key)[:]
out.flush()
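
For reference, reading one of these memmaps back looks roughly like the sketch below; the shape and dtype are not stored in the file itself, so they must come from the props.pkl sidecar (the filename and dtype here are assumptions):

    import numpy as np

    # mode="r" maps the file read-only without loading it fully into memory
    energies = np.memmap("energies.mmap", mode="r", dtype=np.float64)  # dtype is an assumption
    print(energies[:10])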
@@ -468,6 +468,7 @@ def read_preprocess(self, overwrite_local_cache=False):
assert all([key in all_pkl_keys for key in self.pkl_data_keys])
for key in all_pkl_keys:
if key not in self.pkl_data_keys:
logger.debug(f"{key}: {list(tmp.items())}")
self.data[key] = tmp[key][:][tmp[key][:]]
else:
self.data[key] = tmp[key][:]
@@ -513,7 +514,12 @@ def upload(self, overwrite: bool = False, as_zarr: bool = False):
"""
Upload the preprocessed data to the remote storage.
"""
self.save_preprocess(self.data, True, overwrite, as_zarr)
for key in self.data_keys:
local_path = p_join(self.preprocess_path, f"{key}.mmap" if not as_zarr else f"{key}.zip")
push_remote(local_path, overwrite=overwrite)
local_path = p_join(self.preprocess_path, "props.pkl" if not as_zarr else "metadata.zip")
push_remote(local_path, overwrite=overwrite)


def save_xyz(self, idx: int, energy_method: int = 0, path: Optional[str] = None, ext=True):
"""
2 changes: 2 additions & 0 deletions openqdc/datasets/potential/__init__.py
@@ -24,6 +24,7 @@
from .transition1x import Transition1X
from .vqm24 import VQM24
from .waterclusters3_30 import WaterClusters
from .waterclusters import SCANWaterClusters

AVAILABLE_POTENTIAL_DATASETS = {
"Alchemy": Alchemy,
@@ -59,6 +60,7 @@
"TMQM": TMQM,
"Transition1X": Transition1X,
"WaterClusters": WaterClusters,
"SCANWaterClusters": SCANWaterClusters,
"MultixcQM9": MultixcQM9,
"MultixcQM9_V2": MultixcQM9_V2,
"RevMD17": RevMD17,
5 changes: 5 additions & 0 deletions openqdc/datasets/potential/qmx.py
@@ -122,6 +122,11 @@ class QM7(QMX):
'WB97M-V:def2-tzvp', 'WB97M-V:sto-3g', 'WB97X-D:aug-cc-pvdz',
'WB97X-D:aug-cc-pvtz', 'WB97X-D:def2-svp', 'WB97X-D:def2-tzvp',
'WB97X-D:sto-3g']

__energy_methods__ = [
PotentialMethod.NONE for _ in range(len(energy_target_names)) # "wb97x/6-31g(d)"
]




104 changes: 104 additions & 0 deletions openqdc/datasets/potential/waterclusters.py
@@ -0,0 +1,104 @@
import json
import zipfile
from collections import defaultdict
from io import StringIO
from os.path import join as p_join

import numpy as np
import pandas as pd
from tqdm import tqdm

from openqdc.datasets.base import BaseDataset
from openqdc.methods import PotentialMethod
from openqdc.utils.constants import ATOM_TABLE, MAX_ATOMIC_NUMBER
from openqdc.utils.package_utils import requires_package

_default_basis_sets = {
"BEGDB_H2O": "aug-cc-pVQZ",
"WATER27": "aug-cc-pVQZ",
"H2O_alkali_clusters": "def2-QZVPPD",
"H2O_halide_clusters": "def2-QZVPPD",
}

@requires_package("monty")
@requires_package("pymatgen")
def read_geometries(fname, dataset):
from monty.serialization import loadfn
geometries = {k: v.to_ase_atoms() for k, v in loadfn(fname)[dataset].items()}
return geometries

@requires_package("monty")
def read_energies(fname, dataset):
from monty.serialization import loadfn
_energies = loadfn(fname)[dataset]
metadata_restrictions = {"basis_set": _default_basis_sets.get(dataset)}

functionals_to_return = []
for dfa, at_dfa_d in _energies.items():
functionals_to_return += [
f"{dfa}" if dfa == at_dfa else f"{dfa}@{at_dfa}"
for at_dfa in at_dfa_d
]

energies = defaultdict(dict)
for f in functionals_to_return:
if "-FLOSIC" in f and "@" not in f:
func = f.split("-FLOSIC")[0]
at_f = "-FLOSIC"
else:
func = f.split("@")[0]
at_f = f.split("@")[-1]

if func not in _energies:
print(f"No functional {func} included in dataset - available options:\n{', '.join(_energies.keys())}")
elif at_f not in _energies[func]:
print(f"No @functional {at_f} included in {func} dataset - available options:\n{', '.join(_energies[func].keys())}")
else:
if isinstance(_energies[func][at_f], list):
for entry in _energies[func][at_f]:
if all(
entry["metadata"].get(k) == v for k, v in metadata_restrictions.items()
):
energies[f] = entry
break
if f not in energies:
print(f"No matching metadata {json.dumps(metadata_restrictions)} for method {f}")
else:
energies[f] = _energies[func][at_f]
return dict(energies)
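
To make the naming convention concrete, a small worked example of the parsing above: a functional evaluated on its own density is keyed as-is, "@" separates a functional from the density it is evaluated at, and "-FLOSIC" is special-cased:

    for f in ["SCAN", "SCAN@r2SCAN50", "SCAN-FLOSIC"]:
        if "-FLOSIC" in f and "@" not in f:
            func, at_f = f.split("-FLOSIC")[0], "-FLOSIC"
        else:
            func, at_f = f.split("@")[0], f.split("@")[-1]
        print(f, "->", func, at_f)
    # SCAN -> SCAN SCAN
    # SCAN@r2SCAN50 -> SCAN r2SCAN50
    # SCAN-FLOSIC -> SCAN -FLOSIC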

def format_geometry_and_entries(geometry, energies, subset):
pass


class SCANWaterClusters(BaseDataset):
"""https://chemrxiv.org/engage/chemrxiv/article-details/662aaff021291e5d1db7d8ec"""

__name__ = "scanwaterclusters"
__energy_methods__ = [PotentialMethod.GFN2_XTB]

__energy_unit__ = "hartree"
__distance_unit__ = "ang"
__forces_unit__ = "hartree/ang"

energy_target_names = [
    "HF", "HF-r2SCAN-DC4", "SCAN", "SCAN@HF", "SCAN@r2SCAN50",
    "r2SCAN", "r2SCAN@HF", "r2SCAN@r2SCAN50", "r2SCAN50", "r2SCAN100",
    "r2SCAN10", "r2SCAN20", "r2SCAN25", "r2SCAN30", "r2SCAN40",
    "r2SCAN60", "r2SCAN70", "r2SCAN80", "r2SCAN90",
]
force_target_names = []

subsets = ["BEGDB_H2O", "WATER27", "H2O_alkali_clusters", "H2O_halide_clusters"]
__links__ = {
    "geometries.json.gz": "https://github.com/esoteric-ephemera/water_cluster_density_errors/blob/main/data_files/geometries.json.gz?raw=True",
    "total_energies.json.gz": "https://github.com/esoteric-ephemera/water_cluster_density_errors/blob/main/data_files/total_energies.json.gz?raw=True",
}

def read_raw_entries(self):
    entries = []
    for subset in self.subsets:
        geometries = read_geometries(p_join(self.root, "geometries.json.gz"), subset)
        energies = read_energies(p_join(self.root, "total_energies.json.gz"), subset)
        datum = {}
        for k in energies:
            _ = energies[k].pop("metadata")
            datum[k] = energies[k]["total_energies"]
        entries.append(pd.concat([pd.DataFrame({"positions": geometries}), pd.DataFrame(datum)], axis=1))
    return pd.concat(entries, axis=0)
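
A hedged sketch of exercising the new dataset end to end (it assumes the base class downloads the files listed in __links__ into self.root on first use):

    from openqdc.datasets.potential import SCANWaterClusters

    ds = SCANWaterClusters()  # fetches geometries/energies and preprocesses on first run
    print(ds.subsets)         # ['BEGDB_H2O', 'WATER27', 'H2O_alkali_clusters', 'H2O_halide_clusters']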


26 changes: 21 additions & 5 deletions openqdc/utils/download_api.py
@@ -10,6 +10,8 @@
from dataclasses import dataclass
from typing import Optional

from dotenv import load_dotenv

import fsspec
import gdown
import requests
@@ -23,7 +25,6 @@

import openqdc.utils.io as ioqdc


@dataclass
class FileSystem:
"""
@@ -33,7 +34,13 @@ class FileSystem:
public_endpoint: Optional[AbstractFileSystem] = None
private_endpoint: Optional[AbstractFileSystem] = None
local_endpoint: AbstractFileSystem = LocalFileSystem()

endpoint_url = 'https://874f02b9d981bd6c279e979c0d91c4b4.r2.cloudflarestorage.com'

def __init__(self):
if not load_dotenv():
    logger.warning("Problem loading environment variables")
self.KEY = os.getenv("CLOUDFARE_KEY", None)
self.SECRET = os.getenv("CLOUDFARE_SECRET", None)

@property
def public(self):
self.connect()
@@ -64,16 +71,24 @@ def connect(self):
warnings.simplefilter("ignore") # No quota warning
self.public_endpoint = self.get_default_endpoint("public")
self.private_endpoint = self.get_default_endpoint("private")
self.public_endpoint.client_kwargs = {"timeout": ClientTimeout(total=3600, connect=1000)}
#self.public_endpoint.client_kwargs = {"timeout": ClientTimeout(total=3600, connect=1000)}

def get_default_endpoint(self, endpoint: str) -> AbstractFileSystem:
"""
Return a default endpoint for the given str [public, private]
"""
if endpoint == "private":
return fsspec.filesystem("gs")
#return fsspec.filesystem("gs")
return fsspec.filesystem(
    "s3",
    key=self.KEY,
    secret=self.SECRET,
    endpoint_url=self.endpoint_url,
)
elif endpoint == "public":
return fsspec.filesystem("https")
return fsspec.filesystem(
    "s3",
    key="a046308c078b0134c9e261aa91f63ab2",
    secret="d5b32f241ad8ee8d0a3173cd51b4f36d6869f168b21acef75f244a81dc10e1fb",
    endpoint_url=self.endpoint_url,
)
#return fsspec.filesystem("https")
else:
return self.local_endpoint
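
Once constructed, either endpoint behaves like any other fsspec filesystem; a minimal usage sketch (credentials are placeholders, and the bucket path matches the "/openqdc/v1" remote cache in io.py):

    import fsspec

    fs = fsspec.filesystem(
        "s3",
        key="<access-key-id>",
        secret="<secret-access-key>",
        endpoint_url="https://<account-id>.r2.cloudflarestorage.com",
    )
    fs.ls("openqdc/v1")               # list the remote cache
    fs.get("openqdc/v1/<file>", ".")  # download a file locally (path hypothetical)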

@@ -283,3 +298,4 @@ def from_config(self, config: dict):


API = FileSystem()

5 changes: 3 additions & 2 deletions openqdc/utils/io.py
@@ -54,9 +54,10 @@ def get_remote_cache(write_access=False) -> str:
Returns the entry point based on the write access.
"""
if write_access:
remote_cache = "gs://qmdata-public/openqdc"
remote_cache = "/openqdc/v1" #"gs://qmdata-public/openqdc"
else:
remote_cache = "https://storage.googleapis.com/qmdata-public/openqdc"
remote_cache = "/openqdc/v1"
#remote_cache = "https://storage.googleapis.com/qmdata-public/openqdc"
return remote_cache
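
As an illustration only (the exact layout is an assumption): with the remote cache rooted at "/openqdc/v1", a file pushed from the local preprocess cache would keep its relative path on the bucket:

    # hypothetical mapping, for illustration
    local = "~/.cache/openqdc/scanwaterclusters/preprocessed/energies.mmap"
    remote = "/openqdc/v1/scanwaterclusters/preprocessed/energies.mmap"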


