diff --git a/env.yml b/env.yml index 7787e28..23e1720 100644 --- a/env.yml +++ b/env.yml @@ -12,7 +12,7 @@ dependencies: - typer - prettytable - s3fs - - pydantic + - pydantic - python-dotenv diff --git a/openqdc/cli.py b/openqdc/cli.py index 66e5888..4963c8e 100644 --- a/openqdc/cli.py +++ b/openqdc/cli.py @@ -12,6 +12,7 @@ AVAILABLE_INTERACTION_DATASETS, AVAILABLE_POTENTIAL_DATASETS, ) +from openqdc.utils.io import get_local_cache app = typer.Typer(help="OpenQDC CLI") @@ -55,7 +56,7 @@ def download( help="Path to the cache. If not provided, the default cache directory (.cache/openqdc/) will be used.", ), ] = None, - as_zarr : Annotated[ + as_zarr: Annotated[ bool, typer.Option( help="Whether to overwrite or force the re-download of the datasets.", @@ -70,14 +71,14 @@ def download( """ for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)): if exist_dataset(dataset): - ds=SANITIZED_AVAILABLE_DATASETS[dataset].no_init() - ds.read_as_zarr=as_zarr + ds = SANITIZED_AVAILABLE_DATASETS[dataset].no_init() + ds.read_as_zarr = as_zarr if ds.is_cached() and not overwrite: logger.info(f"{dataset} is already cached. Skipping download") else: - SANITIZED_AVAILABLE_DATASETS[dataset](overwrite_local_cache=True, - cache_dir=cache_dir, - read_as_zarr=as_zarr) + SANITIZED_AVAILABLE_DATASETS[dataset]( + overwrite_local_cache=True, cache_dir=cache_dir, read_as_zarr=as_zarr + ) @app.command() @@ -185,7 +186,7 @@ def upload( help="Whether to overwrite or force the re-download of the datasets.", ), ] = True, - as_zarr : Annotated[ + as_zarr: Annotated[ bool, typer.Option( help="Whether to upload the zarr files if available.", @@ -204,6 +205,7 @@ def upload( logger.error(f"Error while uploading {dataset}. {e}. Did you preprocess the dataset first?") raise e + @app.command() def convert_to_zarr( datasets: List[str], @@ -221,7 +223,7 @@ def convert_to_zarr( ] = False, ): """ - Conver a preprocessed dataset to the zarr file format. + Convert a preprocessed dataset to the zarr file format. 
""" import os from os.path import join as p_join @@ -229,72 +231,77 @@ def convert_to_zarr( import numpy as np import zarr - from openqdc.utils.io import load_pkl + from openqdc.utils.io import load_pkl + def silent_remove(filename): try: os.remove(filename) except OSError: pass + for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)): if exist_dataset(dataset): logger.info(f"Uploading {SANITIZED_AVAILABLE_DATASETS[dataset].__name__}") try: - ds=SANITIZED_AVAILABLE_DATASETS[dataset](overwrite_local_cache=download) - #os.makedirs(p_join(ds.root, "zips", ds.__name__), exist_ok=True) - + ds = SANITIZED_AVAILABLE_DATASETS[dataset](overwrite_local_cache=download) + # os.makedirs(p_join(ds.root, "zips", ds.__name__), exist_ok=True) + pkl = load_pkl(p_join(ds.preprocess_path, "props.pkl")) - metadata=p_join(ds.preprocess_path, "metadata.zip") - if overwrite: silent_remove(metadata) + metadata = p_join(ds.preprocess_path, "metadata.zip") + if overwrite: + silent_remove(metadata) group = zarr.group(zarr.storage.ZipStore(metadata)) for key, value in pkl.items(): - #sub=group.create_group(key) - if key in ['name', 'subset']: - data=group.create_dataset(key, shape=value[0].shape, dtype=value[0].dtype) - data[:]=value[0][:] - data2=group.create_dataset(key + "_ptr", shape=value[1].shape, - dtype=np.int32) + # sub=group.create_group(key) + if key in ["name", "subset"]: + data = group.create_dataset(key, shape=value[0].shape, dtype=value[0].dtype) + data[:] = value[0][:] + data2 = group.create_dataset(key + "_ptr", shape=value[1].shape, dtype=np.int32) data2[:] = value[1][:] else: - data=group.create_dataset(key, shape=value.shape, dtype=value.dtype) - data[:]=value[:] - + data = group.create_dataset(key, shape=value.shape, dtype=value.dtype) + data[:] = value[:] + force_attrs = { - "unit" : str(ds.force_unit), - "level_of_theory" : ds.force_methods, + "unit": str(ds.force_unit), + "level_of_theory": ds.force_methods, } - energy_attrs = { - "unit" : str(ds.energy_unit), - "level_of_theory": ds.energy_methods - } + energy_attrs = {"unit": str(ds.energy_unit), "level_of_theory": ds.energy_methods} atomic_inputs_attrs = { - "unit" : str(ds.distance_unit), - } - attrs = { - "forces" : force_attrs, - "energies" : energy_attrs, - "atomic_inputs" : atomic_inputs_attrs + "unit": str(ds.distance_unit), } + attrs = {"forces": force_attrs, "energies": energy_attrs, "atomic_inputs": atomic_inputs_attrs} - - #os.makedirs(p_join(ds.root, "zips", ds.__name__), exist_ok=True) + # os.makedirs(p_join(ds.root, "zips", ds.__name__), exist_ok=True) for key, value in ds.data.items(): if key not in ds.data_keys: continue print(key, value.shape) - zarr_path=p_join(ds.preprocess_path, key + ".zip") #ds.__name__, - if overwrite: silent_remove(zarr_path) - z=zarr.open(zarr.storage.ZipStore(zarr_path), "w", zarr_version=2, shape=value.shape, - dtype=value.dtype) - z[:]=value[:] + zarr_path = p_join(ds.preprocess_path, key + ".zip") # ds.__name__, + if overwrite: + silent_remove(zarr_path) + z = zarr.open( + zarr.storage.ZipStore(zarr_path), "w", zarr_version=2, shape=value.shape, dtype=value.dtype + ) + z[:] = value[:] if key in attrs: z.attrs.update(attrs[key]) except Exception as e: logger.error(f"Error while converting {dataset}. {e}. 
Did you preprocess the dataset first?") raise e - + + +@app.command() +def cache(): + """ + Get the current local cache path of openQDC + """ + print(f"openQDC local cache:\n {get_local_cache()}") + + if __name__ == "__main__": app() diff --git a/openqdc/datasets/base.py b/openqdc/datasets/base.py index 79c86df..44c64f2 100644 --- a/openqdc/datasets/base.py +++ b/openqdc/datasets/base.py @@ -1,19 +1,18 @@ """The BaseDataset defining shared functionality between all datasets.""" import os -import pickle as pkl from functools import partial from itertools import compress from os.path import join as p_join from typing import Callable, Dict, List, Optional, Union import numpy as np -import zarr from ase.io.extxyz import write_extxyz from loguru import logger from sklearn.utils import Bunch from tqdm import tqdm +from openqdc.datasets.dataset_structure import MemMapDataset, ZarrDataset from openqdc.datasets.energies import AtomEnergies from openqdc.datasets.properties import DatasetPropertyMixIn from openqdc.datasets.statistics import ( @@ -33,7 +32,6 @@ copy_exists, dict_to_atoms, get_local_cache, - pull_locally, push_remote, set_cache_dir, ) @@ -156,6 +154,12 @@ def _init_lambda_fn(self): self._fn_distance = lambda x: x self._fn_forces = lambda x: x + @property + def dataset_wrapper(self): + if not hasattr(self, "_dataset_wrapper"): + self._dataset_wrapper = ZarrDataset() if self.read_as_zarr else MemMapDataset() + return self._dataset_wrapper + @property def config(self): assert len(self.__links__) > 0, "No links provided for fetching" @@ -167,17 +171,6 @@ def fetch(cls, cache_path: Optional[str] = None, overwrite: bool = False) -> Non DataDownloader(cache_path, overwrite).from_config(cls.no_init().config) - @property - def ext(self): - return ".mmap" if not self.read_as_zarr else ".zip" - - @property - def load_fn(self): - return np.memmap if not self.read_as_zarr else zarr.open - - def add_extension(self, filename): - return filename + self.ext - def _post_init( self, overwrite_local_cache: bool = False, @@ -392,29 +385,32 @@ def save_preprocess(self, data_dict, upload=False, overwrite=True, as_zarr: bool """ # save memmaps logger.info("Preprocessing data and saving it to cache.") - for key in self.data_keys: - local_path = p_join(self.preprocess_path, f"{key}.mmap" if not as_zarr else f"{key}.zip") - out = np.memmap(local_path, mode="w+", dtype=data_dict[key].dtype, shape=data_dict[key].shape) - out[:] = data_dict.pop(key)[:] - out.flush() - if upload: - push_remote(local_path, overwrite=overwrite) - - # save smiles and subset - local_path = p_join(self.preprocess_path, "props.pkl") - - # assert that (required) pkl keys are present in data_dict - assert all([key in data_dict.keys() for key in self.pkl_data_keys]) + paths = self.dataset_wrapper.save_preprocess( + self.preprocess_path, self.data_keys, data_dict, self.pkl_data_keys, self.pkl_data_types + ) + if upload: + for local_path in paths: + push_remote(local_path, overwrite=overwrite) # make it async?
- # store unique and inverse indices for str-based pkl keys - for key in self.pkl_data_keys: - if self.pkl_data_types[key] == str: - data_dict[key] = np.unique(data_dict[key], return_inverse=True) + def read_preprocess(self, overwrite_local_cache=False): + logger.info("Reading preprocessed data.") + logger.info( + f"Dataset {self.__name__} with the following units:\n\ + Energy: {self.energy_unit},\n\ + Distance: {self.distance_unit},\n\ + Forces: {self.force_unit if self.force_methods else 'None'}" + ) - with open(local_path, "wb") as f: - pkl.dump(data_dict, f) - if upload: - push_remote(local_path, overwrite=overwrite) + self.data = self.dataset_wrapper.load_data( + self.preprocess_path, + self.data_keys, + self.data_types, + self.data_shapes, + self.pkl_data_keys, + overwrite_local_cache, + ) # this should be async if possible + for key in self.data: + logger.info(f"Loaded {key} with shape {self.data[key].shape}, dtype {self.data[key].dtype}") def _convert_on_loading(self, x, key): if key == "energies": @@ -428,63 +424,15 @@ def _convert_on_loading(self, x, key): else: return x - def read_preprocess(self, overwrite_local_cache=False): - logger.info("Reading preprocessed data.") - logger.info( - f"Dataset {self.__name__} with the following units:\n\ - Energy: {self.energy_unit},\n\ - Distance: {self.distance_unit},\n\ - Forces: {self.force_unit if self.force_methods else 'None'}" - ) - self.data = {} - for key in self.data_keys: - filename = p_join(self.preprocess_path, self.add_extension(f"{key}")) - pull_locally(filename, overwrite=overwrite_local_cache) - self.data[key] = self.load_fn(filename, mode="r", dtype=self.data_types[key]) - if self.read_as_zarr: - self.data[key] = self.data[key][:] - self.data[key] = self.data[key].reshape(*self.data_shapes[key]) - - if not self.read_as_zarr: - filename = p_join(self.preprocess_path, "props.pkl") - pull_locally(filename, overwrite=overwrite_local_cache) - with open(filename, "rb") as f: - tmp = pkl.load(f) - all_pkl_keys = set(tmp.keys()) - set(self.data_keys) - # assert required pkl_keys are present in all_pkl_keys - assert all([key in all_pkl_keys for key in self.pkl_data_keys]) - for key in all_pkl_keys: - x = tmp.pop(key) - if len(x) == 2: - self.data[key] = x[0][x[1]] - else: - self.data[key] = x - else: - filename = p_join(self.preprocess_path, self.add_extension("metadata")) - pull_locally(filename, overwrite=overwrite_local_cache) - tmp = self.load_fn(filename) - all_pkl_keys = set(tmp.keys()) - set(self.data_keys) - # assert required pkl_keys are present in all_pkl_keys - assert all([key in all_pkl_keys for key in self.pkl_data_keys]) - for key in all_pkl_keys: - if key not in self.pkl_data_keys: - #print(key, list(tmp.items())) - self.data[key] = tmp[key][:][tmp[key][:]] - else: - self.data[key] = tmp[key][:] - - for key in self.data: - logger.info(f"Loaded {key} with shape {self.data[key].shape}, dtype {self.data[key].dtype}") - def is_preprocessed(self): """ Check if the dataset is preprocessed and available online or locally. 
""" - predicats = [copy_exists(p_join(self.preprocess_path, self.add_extension(f"{key}"))) for key in self.data_keys] - - if not self.read_as_zarr: - predicats += [copy_exists(p_join(self.preprocess_path, "props.pkl"))] - print(predicats) + predicats = [ + copy_exists(p_join(self.preprocess_path, self.dataset_wrapper.add_extension(f"{key}"))) + for key in self.data_keys + ] + predicats += [copy_exists(p_join(self.preprocess_path, file)) for file in self.dataset_wrapper._extra_files] return all(predicats) def is_cached(self): @@ -492,10 +440,10 @@ def is_cached(self): Check if the dataset is cached locally. """ predicats = [ - os.path.exists(p_join(self.preprocess_path, self.add_extension(f"{key}"))) for key in self.data_keys + os.path.exists(p_join(self.preprocess_path, self.dataset_wrapper.add_extension(f"{key}"))) + for key in self.data_keys ] - if not self.read_as_zarr: - predicats += [os.path.exists(p_join(self.preprocess_path, "props.pkl"))] + predicats += [copy_exists(p_join(self.preprocess_path, file)) for file in self.dataset_wrapper._extra_files] return all(predicats) def preprocess(self, upload: bool = False, overwrite: bool = True, as_zarr: bool = True): @@ -521,7 +469,6 @@ def upload(self, overwrite: bool = False, as_zarr: bool = False): push_remote(local_path, overwrite=overwrite) local_path = p_join(self.preprocess_path, "props.pkl" if not as_zarr else "metadata.zip") push_remote(local_path, overwrite=overwrite) - def save_xyz(self, idx: int, energy_method: int = 0, path: Optional[str] = None, ext=True): """ diff --git a/openqdc/datasets/dataset_structure.py b/openqdc/datasets/dataset_structure.py index 44c561c..e072afd 100644 --- a/openqdc/datasets/dataset_structure.py +++ b/openqdc/datasets/dataset_structure.py @@ -11,15 +11,15 @@ class GeneralStructure(ABC): """ - Base class for datasets in the openQDC package. + Abstract Factory class for datasets type in the openQDC package. """ - _ext : Optional[str] = None - _extra_files : Optional[List[str]] = None + _ext: Optional[str] = None + _extra_files: Optional[List[str]] = None @property def ext(self): - return self._ext + return self._ext @property @abstractmethod @@ -37,22 +37,30 @@ def save_preprocess(self, preprocess_path, data_keys, data_dict, extra_data_keys def load_extra_files(self, data, preprocess_path, data_keys, pkl_data_keys, overwrite): raise NotImplementedError - def load_data(self, preprocess_path, data_keys, data_types, data_shapes, overwrite): + def join_and_ext(self, path, filename): + return p_join(path, self.add_extension(filename)) + + def load_data(self, preprocess_path, data_keys, data_types, data_shapes, extra_data_keys, overwrite): data = {} for key in data_keys: - filename = p_join(preprocess_path, self.add_extension(f"{key}")) + filename = self.join_and_ext(preprocess_path, key) pull_locally(filename, overwrite=overwrite) data[key] = self.load_fn(filename, mode="r", dtype=data_types[key]) - if isinstance(self, "ZarrDataset"): - data[key] = data[key][:] + data[key] = self.unpack(data[key]) data[key] = data[key].reshape(*data_shapes[key]) - data=self.load_extra_files(data, preprocess_path, data_keys, data_types, data_shapes, overwrite) - return data + data = self.load_extra_files(data, preprocess_path, data_keys, extra_data_keys, overwrite) + return data + def unpack(self, data): + return data class MemMapDataset(GeneralStructure): + """ + Dataset structure for memory-mapped numpy arrays and props.pkl files. 
+ """ + _ext = ".mmap" _extra_files = ["props.pkl"] @@ -68,7 +76,7 @@ def save_preprocess(self, preprocess_path, data_keys, data_dict, extra_data_keys """ local_paths = [] for key in data_keys: - local_path = p_join(preprocess_path, self.add_extension(key)) + local_path = self.join_and_ext(preprocess_path, key) out = np.memmap(local_path, mode="w+", dtype=data_dict[key].dtype, shape=data_dict[key].shape) out[:] = data_dict.pop(key)[:] out.flush() @@ -105,35 +113,47 @@ def load_extra_files(self, data, preprocess_path, data_keys, pkl_data_keys, over data[key] = x[0][x[1]] else: data[key] = x - return data - + return data + class ZarrDataset(GeneralStructure): + """ + Dataset structure for zarr files. + """ + _ext = ".zip" _extra_files = ["metadata.zip"] - zarr_version = 2 + _zarr_version = 2 @property def load_fn(self): return zarr.open + def unpack(self, data): + return data[:] + def save_preprocess(self, preprocess_path, data_keys, data_dict, extra_data_keys, extra_data_types) -> List[str]: - #os.makedirs(p_join(ds.root, "zips", ds.__name__), exist_ok=True) - local_paths =[] + # os.makedirs(p_join(ds.root, "zips", ds.__name__), exist_ok=True) + local_paths = [] for key, value in data_dict.items(): if key not in data_keys: continue - zarr_path=p_join(preprocess_path, self.add_extension(key)) - value=data_dict.pop(key) - z=zarr.open(zarr.storage.ZipStore(zarr_path), "w", zarr_version=self.zarr_version, shape=value.shape, - dtype=value.dtype) - z[:]=value[:] + zarr_path = self.join_and_ext(preprocess_path, key) + value = data_dict.pop(key) + z = zarr.open( + zarr.storage.ZipStore(zarr_path), + "w", + zarr_version=self._zarr_version, + shape=value.shape, + dtype=value.dtype, + ) + z[:] = value[:] local_paths.append(zarr_path) - #if key in attrs: + # if key in attrs: # z.attrs.update(attrs[key]) - metadata=p_join(preprocess_path, "metadata.zip") - + metadata = p_join(preprocess_path, "metadata.zip") + group = zarr.group(zarr.storage.ZipStore(metadata)) for key in extra_data_keys: @@ -141,21 +161,20 @@ def save_preprocess(self, preprocess_path, data_keys, data_dict, extra_data_keys data_dict[key] = np.unique(data_dict[key], return_inverse=True) for key, value in data_dict.items(): - #sub=group.create_group(key) - if key in ['name', 'subset']: - data=group.create_dataset(key, shape=value[0].shape, dtype=value[0].dtype) - data[:]=value[0][:] - data2=group.create_dataset(key + "_ptr", shape=value[1].shape, - dtype=np.int32) + # sub=group.create_group(key) + if key in ["name", "subset"]: + data = group.create_dataset(key, shape=value[0].shape, dtype=value[0].dtype) + data[:] = value[0][:] + data2 = group.create_dataset(key + "_ptr", shape=value[1].shape, dtype=np.int32) data2[:] = value[1][:] else: - data=group.create_dataset(key, shape=value.shape, dtype=value.dtype) - data[:]=value[:] + data = group.create_dataset(key, shape=value.shape, dtype=value.dtype) + data[:] = value[:] local_paths.append(metadata) return local_paths def load_extra_files(self, data, preprocess_path, data_keys, pkl_data_keys, overwrite): - filename = p_join(preprocess_path, self.add_extension("metadata")) + filename = self.join_and_ext(preprocess_path, "metadata") pull_locally(filename, overwrite=overwrite) tmp = self.load_fn(filename) all_pkl_keys = set(tmp.keys()) - set(data_keys) @@ -168,4 +187,4 @@ def load_extra_files(self, data, preprocess_path, data_keys, pkl_data_keys, over data[key] = tmp[key][:] return data - \ No newline at end of file + # TODO: checksum , maybe convert to archive instead of zips diff --git 
a/openqdc/datasets/potential/alchemy.py b/openqdc/datasets/potential/alchemy.py index 0643e82..969d13d 100644 --- a/openqdc/datasets/potential/alchemy.py +++ b/openqdc/datasets/potential/alchemy.py @@ -1,4 +1,3 @@ - from os.path import join as p_join import datamol as dm diff --git a/openqdc/datasets/potential/qmx.py b/openqdc/datasets/potential/qmx.py index 65fa583..e0551ce 100644 --- a/openqdc/datasets/potential/qmx.py +++ b/openqdc/datasets/potential/qmx.py @@ -4,11 +4,10 @@ import datamol as dm import numpy as np import pandas as pd -from tqdm import tqdm from openqdc.datasets.base import BaseDataset from openqdc.methods import PotentialMethod -from openqdc.utils import load_hdf5_file, read_qc_archive_h5 +from openqdc.utils import read_qc_archive_h5 from openqdc.utils.io import get_local_cache from openqdc.utils.molecule import get_atomic_number_and_charge @@ -80,63 +79,186 @@ def config(self): assert len(self.__links__) > 0, "No links provided for fetching" return dict(dataset_name="qmx", links=self.__links__) - - @property - def preprocess_path(self): - path = p_join(self.root, "preprocessed", self.__name__) - os.makedirs(path, exist_ok=True) - return path - def read_raw_entries(self): raw_path = p_join(self.root, f"{self.__name__}.h5.gz") samples = read_qc_archive_h5(raw_path, self.__name__, self.energy_target_names, None) return samples -# ['smiles', 'E1-CC2', 'E2-CC2', 'f1-CC2', 'f2-CC2', 'E1-PBE0', 'E2-PBE0', 'f1-PBE0', 'f2-PBE0', 'E1-PBE0.1', 'E2-PBE0.1', 'f1-PBE0.1', 'f2-PBE0.1', 'E1-CAM', 'E2-CAM', 'f1-CAM', 'f2-CAM'] +# ['smiles', 'E1-CC2', 'E2-CC2', 'f1-CC2', 'f2-CC2', 'E1-PBE0', 'E2-PBE0', +# 'f1-PBE0', 'f2-PBE0', 'E1-PBE0.1', 'E2-PBE0.1', 'f1-PBE0.1', 'f2-PBE0.1', +# 'E1-CAM', 'E2-CAM', 'f1-CAM', 'f2-CAM'] class QM7(QMX): __links__ = {"qm7.hdf5.gz": "https://zenodo.org/record/3588337/files/150.hdf5.gz?download=1"} __name__ = "qm7" - - energy_target_names = ['B2PLYP-D3(BJ):aug-cc-pvdz', 'B2PLYP-D3(BJ):aug-cc-pvtz', 'B2PLYP-D3(BJ):def2-svp', - 'B2PLYP-D3(BJ):def2-tzvp', 'B2PLYP-D3(BJ):sto-3g', 'B2PLYP-D3:aug-cc-pvdz', - 'B2PLYP-D3:aug-cc-pvtz', 'B2PLYP-D3:def2-svp', 'B2PLYP-D3:def2-tzvp', - 'B2PLYP-D3:sto-3g', 'B2PLYP-D3M(BJ):aug-cc-pvdz', 'B2PLYP-D3M(BJ):aug-cc-pvtz', - 'B2PLYP-D3M(BJ):def2-svp', 'B2PLYP-D3M(BJ):def2-tzvp', 'B2PLYP-D3M(BJ):sto-3g', - 'B2PLYP-D3M:aug-cc-pvdz', 'B2PLYP-D3M:aug-cc-pvtz', 'B2PLYP-D3M:def2-svp', - 'B2PLYP-D3M:def2-tzvp', 'B2PLYP-D3M:sto-3g', 'B2PLYP:aug-cc-pvdz', - 'B2PLYP:aug-cc-pvtz', 'B2PLYP:def2-svp', 'B2PLYP:def2-tzvp', - 'B2PLYP:sto-3g', 'B3LYP-D3(BJ):aug-cc-pvdz', 'B3LYP-D3(BJ):aug-cc-pvtz', - 'B3LYP-D3(BJ):def2-svp', 'B3LYP-D3(BJ):def2-tzvp', 'B3LYP-D3(BJ):sto-3g', - 'B3LYP-D3:aug-cc-pvdz', 'B3LYP-D3:aug-cc-pvtz', 'B3LYP-D3:def2-svp', - 'B3LYP-D3:def2-tzvp', 'B3LYP-D3:sto-3g', 'B3LYP-D3M(BJ):aug-cc-pvdz', - 'B3LYP-D3M(BJ):aug-cc-pvtz', 'B3LYP-D3M(BJ):def2-svp', 'B3LYP-D3M(BJ):def2-tzvp', - 'B3LYP-D3M(BJ):sto-3g', 'B3LYP-D3M:aug-cc-pvdz', 'B3LYP-D3M:aug-cc-pvtz', - 'B3LYP-D3M:def2-svp', 'B3LYP-D3M:def2-tzvp', 'B3LYP-D3M:sto-3g', - 'B3LYP:aug-cc-pvdz', 'B3LYP:aug-cc-pvtz', 'B3LYP:def2-svp', 'B3LYP:def2-tzvp', - 'B3LYP:sto-3g', 'HF:aug-cc-pvdz', 'HF:aug-cc-pvtz', 'HF:def2-svp', - 'HF:def2-tzvp', 'HF:sto-3g', 'MP2:aug-cc-pvdz', 'MP2:aug-cc-pvtz', - 'MP2:def2-svp', 'MP2:def2-tzvp', 'MP2:sto-3g', 'PBE0:aug-cc-pvdz', - 'PBE0:aug-cc-pvtz', 'PBE0:def2-svp', 'PBE0:def2-tzvp', 'PBE0:sto-3g', - 'PBE:aug-cc-pvdz', 'PBE:aug-cc-pvtz', 'PBE:def2-svp', 'PBE:def2-tzvp', - 'PBE:sto-3g', 'WB97M-V:aug-cc-pvdz', 'WB97M-V:aug-cc-pvtz', 'WB97M-V:def2-svp', - 
'WB97M-V:def2-tzvp', 'WB97M-V:sto-3g', 'WB97X-D:aug-cc-pvdz', - 'WB97X-D:aug-cc-pvtz', 'WB97X-D:def2-svp', 'WB97X-D:def2-tzvp', - 'WB97X-D:sto-3g'] - - __energy_methods__ = [ - PotentialMethod.NONE for _ in range(len(energy_target_names)) # "wb97x/6-31g(d)" - ] - + energy_target_names = [ + "B2PLYP-D3(BJ):aug-cc-pvdz", + "B2PLYP-D3(BJ):aug-cc-pvtz", + "B2PLYP-D3(BJ):def2-svp", + "B2PLYP-D3(BJ):def2-tzvp", + "B2PLYP-D3(BJ):sto-3g", + "B2PLYP-D3:aug-cc-pvdz", + "B2PLYP-D3:aug-cc-pvtz", + "B2PLYP-D3:def2-svp", + "B2PLYP-D3:def2-tzvp", + "B2PLYP-D3:sto-3g", + "B2PLYP-D3M(BJ):aug-cc-pvdz", + "B2PLYP-D3M(BJ):aug-cc-pvtz", + "B2PLYP-D3M(BJ):def2-svp", + "B2PLYP-D3M(BJ):def2-tzvp", + "B2PLYP-D3M(BJ):sto-3g", + "B2PLYP-D3M:aug-cc-pvdz", + "B2PLYP-D3M:aug-cc-pvtz", + "B2PLYP-D3M:def2-svp", + "B2PLYP-D3M:def2-tzvp", + "B2PLYP-D3M:sto-3g", + "B2PLYP:aug-cc-pvdz", + "B2PLYP:aug-cc-pvtz", + "B2PLYP:def2-svp", + "B2PLYP:def2-tzvp", + "B2PLYP:sto-3g", + "B3LYP-D3(BJ):aug-cc-pvdz", + "B3LYP-D3(BJ):aug-cc-pvtz", + "B3LYP-D3(BJ):def2-svp", + "B3LYP-D3(BJ):def2-tzvp", + "B3LYP-D3(BJ):sto-3g", + "B3LYP-D3:aug-cc-pvdz", + "B3LYP-D3:aug-cc-pvtz", + "B3LYP-D3:def2-svp", + "B3LYP-D3:def2-tzvp", + "B3LYP-D3:sto-3g", + "B3LYP-D3M(BJ):aug-cc-pvdz", + "B3LYP-D3M(BJ):aug-cc-pvtz", + "B3LYP-D3M(BJ):def2-svp", + "B3LYP-D3M(BJ):def2-tzvp", + "B3LYP-D3M(BJ):sto-3g", + "B3LYP-D3M:aug-cc-pvdz", + "B3LYP-D3M:aug-cc-pvtz", + "B3LYP-D3M:def2-svp", + "B3LYP-D3M:def2-tzvp", + "B3LYP-D3M:sto-3g", + "B3LYP:aug-cc-pvdz", + "B3LYP:aug-cc-pvtz", + "B3LYP:def2-svp", + "B3LYP:def2-tzvp", + "B3LYP:sto-3g", + "HF:aug-cc-pvdz", + "HF:aug-cc-pvtz", + "HF:def2-svp", + "HF:def2-tzvp", + "HF:sto-3g", + "MP2:aug-cc-pvdz", + "MP2:aug-cc-pvtz", + "MP2:def2-svp", + "MP2:def2-tzvp", + "MP2:sto-3g", + "PBE0:aug-cc-pvdz", + "PBE0:aug-cc-pvtz", + "PBE0:def2-svp", + "PBE0:def2-tzvp", + "PBE0:sto-3g", + "PBE:aug-cc-pvdz", + "PBE:aug-cc-pvtz", + "PBE:def2-svp", + "PBE:def2-tzvp", + "PBE:sto-3g", + "WB97M-V:aug-cc-pvdz", + "WB97M-V:aug-cc-pvtz", + "WB97M-V:def2-svp", + "WB97M-V:def2-tzvp", + "WB97M-V:sto-3g", + "WB97X-D:aug-cc-pvdz", + "WB97X-D:aug-cc-pvtz", + "WB97X-D:def2-svp", + "WB97X-D:def2-tzvp", + "WB97X-D:sto-3g", + ] + __energy_methods__ = [PotentialMethod.NONE for _ in range(len(energy_target_names))] # "wb97x/6-31g(d)" class QM7b(QMX): __links__ = {"qm7b.hdf5.gz": "https://zenodo.org/record/3588335/files/200.hdf5.gz?download=1"} __name__ = "qm7b" - energy_target_names = ['CCSD(T0):cc-pVDZ', 'HF:cc-pVDZ', 'HF:cc-pVTZ', 'MP2:cc-pVTZ', - 'B2PLYP-D3:aug-cc-pvdz', 'B2PLYP-D3:aug-cc-pvtz', 'B2PLYP-D3:def2-svp', 'B2PLYP-D3:def2-tzvp', 'B2PLYP-D3:sto-3g', 'B2PLYP-D3M(BJ):aug-cc-pvdz', 'B2PLYP-D3M(BJ):aug-cc-pvtz', 'B2PLYP-D3M(BJ):def2-svp', 'B2PLYP-D3M(BJ):def2-tzvp', 'B2PLYP-D3M(BJ):sto-3g', 'B2PLYP-D3M:aug-cc-pvdz', 'B2PLYP-D3M:aug-cc-pvtz', 'B2PLYP-D3M:def2-svp', 'B2PLYP-D3M:def2-tzvp', 'B2PLYP-D3M:sto-3g', 'B2PLYP:aug-cc-pvdz', 'B2PLYP:aug-cc-pvtz', 'B2PLYP:def2-svp', 'B2PLYP:def2-tzvp', 'B2PLYP:sto-3g', 'B3LYP-D3(BJ):aug-cc-pvdz', 'B3LYP-D3(BJ):aug-cc-pvtz', 'B3LYP-D3(BJ):def2-svp', 'B3LYP-D3(BJ):def2-tzvp', 'B3LYP-D3(BJ):sto-3g', 'B3LYP-D3:aug-cc-pvdz', 'B3LYP-D3:aug-cc-pvtz', 'B3LYP-D3:def2-svp', 'B3LYP-D3:def2-tzvp', 'B3LYP-D3:sto-3g', 'B3LYP-D3M(BJ):aug-cc-pvdz', 'B3LYP-D3M(BJ):aug-cc-pvtz', 'B3LYP-D3M(BJ):def2-svp', 'B3LYP-D3M(BJ):def2-tzvp', 'B3LYP-D3M(BJ):sto-3g', 'B3LYP-D3M:aug-cc-pvdz', 'B3LYP-D3M:aug-cc-pvtz', 'B3LYP-D3M:def2-svp', 'B3LYP-D3M:def2-tzvp', 'B3LYP-D3M:sto-3g', 'B3LYP:aug-cc-pvdz', 'B3LYP:aug-cc-pvtz', 'B3LYP:def2-svp', 
'B3LYP:def2-tzvp', 'B3LYP:sto-3g', 'HF:aug-cc-pvdz', 'HF:aug-cc-pvtz', 'HF:cc-pvtz', 'HF:def2-svp', 'HF:def2-tzvp', 'HF:sto-3g', 'PBE0:aug-cc-pvdz', 'PBE0:aug-cc-pvtz', 'PBE0:def2-svp', 'PBE0:def2-tzvp', 'PBE0:sto-3g', 'PBE:aug-cc-pvdz', 'PBE:aug-cc-pvtz', 'PBE:def2-svp', 'PBE:def2-tzvp', 'PBE:sto-3g', 'SVWN:sto-3g', 'WB97M-V:aug-cc-pvdz', 'WB97M-V:aug-cc-pvtz', 'WB97M-V:def2-svp', 'WB97M-V:def2-tzvp', 'WB97M-V:sto-3g', 'WB97X-D:aug-cc-pvdz', 'WB97X-D:aug-cc-pvtz', 'WB97X-D:def2-svp', 'WB97X-D:def2-tzvp', 'WB97X-D:sto-3g'] + energy_target_names = [ + "CCSD(T0):cc-pVDZ", + "HF:cc-pVDZ", + "HF:cc-pVTZ", + "MP2:cc-pVTZ", + "B2PLYP-D3:aug-cc-pvdz", + "B2PLYP-D3:aug-cc-pvtz", + "B2PLYP-D3:def2-svp", + "B2PLYP-D3:def2-tzvp", + "B2PLYP-D3:sto-3g", + "B2PLYP-D3M(BJ):aug-cc-pvdz", + "B2PLYP-D3M(BJ):aug-cc-pvtz", + "B2PLYP-D3M(BJ):def2-svp", + "B2PLYP-D3M(BJ):def2-tzvp", + "B2PLYP-D3M(BJ):sto-3g", + "B2PLYP-D3M:aug-cc-pvdz", + "B2PLYP-D3M:aug-cc-pvtz", + "B2PLYP-D3M:def2-svp", + "B2PLYP-D3M:def2-tzvp", + "B2PLYP-D3M:sto-3g", + "B2PLYP:aug-cc-pvdz", + "B2PLYP:aug-cc-pvtz", + "B2PLYP:def2-svp", + "B2PLYP:def2-tzvp", + "B2PLYP:sto-3g", + "B3LYP-D3(BJ):aug-cc-pvdz", + "B3LYP-D3(BJ):aug-cc-pvtz", + "B3LYP-D3(BJ):def2-svp", + "B3LYP-D3(BJ):def2-tzvp", + "B3LYP-D3(BJ):sto-3g", + "B3LYP-D3:aug-cc-pvdz", + "B3LYP-D3:aug-cc-pvtz", + "B3LYP-D3:def2-svp", + "B3LYP-D3:def2-tzvp", + "B3LYP-D3:sto-3g", + "B3LYP-D3M(BJ):aug-cc-pvdz", + "B3LYP-D3M(BJ):aug-cc-pvtz", + "B3LYP-D3M(BJ):def2-svp", + "B3LYP-D3M(BJ):def2-tzvp", + "B3LYP-D3M(BJ):sto-3g", + "B3LYP-D3M:aug-cc-pvdz", + "B3LYP-D3M:aug-cc-pvtz", + "B3LYP-D3M:def2-svp", + "B3LYP-D3M:def2-tzvp", + "B3LYP-D3M:sto-3g", + "B3LYP:aug-cc-pvdz", + "B3LYP:aug-cc-pvtz", + "B3LYP:def2-svp", + "B3LYP:def2-tzvp", + "B3LYP:sto-3g", + "HF:aug-cc-pvdz", + "HF:aug-cc-pvtz", + "HF:cc-pvtz", + "HF:def2-svp", + "HF:def2-tzvp", + "HF:sto-3g", + "PBE0:aug-cc-pvdz", + "PBE0:aug-cc-pvtz", + "PBE0:def2-svp", + "PBE0:def2-tzvp", + "PBE0:sto-3g", + "PBE:aug-cc-pvdz", + "PBE:aug-cc-pvtz", + "PBE:def2-svp", + "PBE:def2-tzvp", + "PBE:sto-3g", + "SVWN:sto-3g", + "WB97M-V:aug-cc-pvdz", + "WB97M-V:aug-cc-pvtz", + "WB97M-V:def2-svp", + "WB97M-V:def2-tzvp", + "WB97M-V:sto-3g", + "WB97X-D:aug-cc-pvdz", + "WB97X-D:aug-cc-pvtz", + "WB97X-D:def2-svp", + "WB97X-D:def2-tzvp", + "WB97X-D:sto-3g", + ] class QM8(QMX): @@ -152,6 +274,7 @@ class QM8(QMX): - Columns 14-17: LR-TDCAM-B3LYP/def2TZVP """ + __name__ = "qm8" __energy_methods__ = [ @@ -164,8 +287,7 @@ class QM8(QMX): PotentialMethod.NONE, PotentialMethod.NONE, ] - - + __links__ = { "qm8.csv": "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm8.csv", "qm8.tar.gz": "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/gdb8.tar.gz", @@ -174,27 +296,42 @@ class QM8(QMX): def read_raw_entries(self): df = pd.read_csv(p_join(self.root, "qm8.csv")) mols = dm.read_sdf(p_join(self.root, "qm8.sdf"), sanitize=False, remove_hs=False) - samples=[] + samples = [] for idx_row, mol in zip(df.iterrows(), mols): _, row = idx_row positions = mol.GetConformer().GetPositions() x = get_atomic_number_and_charge(mol) n_atoms = positions.shape[0] - samples.append(dict( - atomic_inputs=np.concatenate((x, positions), axis=-1, dtype=np.float32).reshape(-1, 5), - name=np.array([row["smiles"]]), - energies=np.array([row[['E1-CC2', 'E2-CC2', 'E1-PBE0', 'E2-PBE0', "E1-PBE0.1", "E2-PBE0.1", 'E1-CAM', 'E2-CAM']].tolist()], dtype=np.float64).reshape(1,-1), - n_atoms=np.array([n_atoms], dtype=np.int32), - subset=np.array([f"{self.__name__}"]), - )) + samples.append( + 
dict( + atomic_inputs=np.concatenate((x, positions), axis=-1, dtype=np.float32).reshape(-1, 5), + name=np.array([row["smiles"]]), + energies=np.array( + [ + row[ + ["E1-CC2", "E2-CC2", "E1-PBE0", "E2-PBE0", "E1-PBE0.1", "E2-PBE0.1", "E1-CAM", "E2-CAM"] + ].tolist() + ], + dtype=np.float64, + ).reshape(1, -1), + n_atoms=np.array([n_atoms], dtype=np.int32), + subset=np.array([f"{self.__name__}"]), + ) + ) return samples - class QM9(QMX): __links__ = {"qm9.hdf5.gz": "https://zenodo.org/record/3588339/files/155.hdf5.gz?download=1"} __name__ = "qm9" energy_target_names = [ - 'Internal energy at 0 K', - 'B3LYP:def2-svp', 'HF:cc-pvtz', 'HF:sto-3g', 'PBE:sto-3g', 'SVWN:sto-3g', 'WB97X-D:aug-cc-pvtz', 'WB97X-D:def2-svp', 'WB97X-D:def2-tzvp'] - + "Internal energy at 0 K", + "B3LYP:def2-svp", + "HF:cc-pvtz", + "HF:sto-3g", + "PBE:sto-3g", + "SVWN:sto-3g", + "WB97X-D:aug-cc-pvtz", + "WB97X-D:def2-svp", + "WB97X-D:def2-tzvp", + ] diff --git a/openqdc/datasets/potential/waterclusters.py b/openqdc/datasets/potential/waterclusters.py index 1535c08..04c0db0 100644 --- a/openqdc/datasets/potential/waterclusters.py +++ b/openqdc/datasets/potential/waterclusters.py @@ -1,14 +1,10 @@ -import zipfile from collections import defaultdict -from io import StringIO from os.path import join as p_join -import numpy as np -from tqdm import tqdm +import pandas as pd from openqdc.datasets.base import BaseDataset from openqdc.methods import PotentialMethod -from openqdc.utils.constants import ATOM_TABLE, MAX_ATOMIC_NUMBER from openqdc.utils.package_utils import requires_package _default_basis_sets = { @@ -18,13 +14,16 @@ "H2O_halide_clusters": "def2-QZVPPD", } + @requires_package("monty") @requires_package("pymatgen") def read_geometries(fname, dataset): from monty.serialization import loadfn + geometries = {k: v.to_ase_atoms() for k, v in loadfn(fname)[dataset].items()} return geometries + @requires_package("monty") def read_energies(fname, dataset): from monty.serialization import loadfn @@ -35,11 +34,8 @@ def read_energies(fname, dataset): functionals_to_return = [] for dfa, at_dfa_d in _energies.items(): - functionals_to_return += [ - f"{dfa}" if dfa == at_dfa else f"{dfa}@{at_dfa}" - for at_dfa in at_dfa_d - ] - + functionals_to_return += [f"{dfa}" if dfa == at_dfa else f"{dfa}@{at_dfa}" for at_dfa in at_dfa_d] + energies = defaultdict(dict) for f in functionals_to_return: if "-FLOSIC" in f and "@" not in f: @@ -50,23 +46,23 @@ def read_energies(fname, dataset): at_f = f.split("@")[-1] if func not in _energies: - print(f"No functional {func} included in dataset - available options:\n{', '.join(_energies.keys())}") + print(f"No functional {func} included in dataset" f"- available options:\n{', '.join(_energies.keys())}") elif at_f not in _energies[func]: - print(f"No @functional {at_f} included in {func} dataset - available options:\n{', '.join(_energies[func].keys())}") + print( + f"No @functional {at_f} included in {func} dataset" + f"- available options:\n{', '.join(_energies[func].keys())}" + ) else: - if isinstance(_energies[func][at_f],list): + if isinstance(_energies[func][at_f], list): for entry in _energies[func][at_f]: - if all( - entry["metadata"].get(k) == v for k, v in metadata_restrictions.items() - ): + if all(entry["metadata"].get(k) == v for k, v in metadata_restrictions.items()): energies[f] = entry break - if f not in energies: - print(f"No matching metadata {json.dumps(metadata_restrictions)} for method {f}") else: energies[f] = _energies[func][at_f] return dict(energies) + def 
format_geometry_and_entries(geometry, energies, subset): pass @@ -81,26 +77,43 @@ class SCANWaterClusters(BaseDataset): __distance_unit__ = "ang" __forces_unit__ = "hartree/ang" - energy_target_names = ['HF', 'HF-r2SCAN-DC4', 'SCAN', 'SCAN@HF', 'SCAN@r2SCAN50', 'r2SCAN', 'r2SCAN@HF', 'r2SCAN@r2SCAN50', 'r2SCAN50', 'r2SCAN100', 'r2SCAN10', 'r2SCAN20', 'r2SCAN25', 'r2SCAN30', 'r2SCAN40', 'r2SCAN60', 'r2SCAN70', 'r2SCAN80', 'r2SCAN90'] + energy_target_names = [ + "HF", + "HF-r2SCAN-DC4", + "SCAN", + "SCAN@HF", + "SCAN@r2SCAN50", + "r2SCAN", + "r2SCAN@HF", + "r2SCAN@r2SCAN50", + "r2SCAN50", + "r2SCAN100", + "r2SCAN10", + "r2SCAN20", + "r2SCAN25", + "r2SCAN30", + "r2SCAN40", + "r2SCAN60", + "r2SCAN70", + "r2SCAN80", + "r2SCAN90", + ] force_target_names = [] - subsets = ["BEGDB_H2O","WATER27","H2O_alkali_clusters","H2O_halide_clusters"] + subsets = ["BEGDB_H2O", "WATER27", "H2O_alkali_clusters", "H2O_halide_clusters"] __links__ = { - "geometries.json.gz" : "https://github.com/esoteric-ephemera/water_cluster_density_errors/blob/main/data_files/geometries.json.gz?raw=True", - "total_energies.json.gz" : "https://github.com/esoteric-ephemera/water_cluster_density_errors/blob/main/data_files/total_energies.json.gz?raw=True" + "geometries.json.gz": "https://github.com/esoteric-ephemera/water_cluster_density_errors/blob/main/data_files/geometries.json.gz?raw=True", # noqa + "total_energies.json.gz": "https://github.com/esoteric-ephemera/water_cluster_density_errors/blob/main/data_files/total_energies.json.gz?raw=True", # noqa } - + def read_raw_entries(self): - entries=[] + entries = [] # noqa for i, subset in enumerate(self.subsets): - - geometries = read_geometries(p_join(self.root, "geometries.json.gz" ), subset) - energies = read_energies(p_join(self.root, "total_energies.json.gz" ), subset) - datum ={} + geometries = read_geometries(p_join(self.root, "geometries.json.gz"), subset) + energies = read_energies(p_join(self.root, "total_energies.json.gz"), subset) + datum = {} for k in energies: - _ = energies[k].pop("metadata") - datum[k] = energies[k]["total_energies"] - - return pd.concat([pd.DataFrame({"positions" : geometries}),datum], axis=1) - - + _ = energies[k].pop("metadata") + datum[k] = energies[k]["total_energies"] + + return pd.concat([pd.DataFrame({"positions": geometries}), datum], axis=1) diff --git a/openqdc/datasets/statistics.py b/openqdc/datasets/statistics.py index 0165952..7197c2c 100644 --- a/openqdc/datasets/statistics.py +++ b/openqdc/datasets/statistics.py @@ -149,7 +149,7 @@ def __init__( self.e0_matrix = e0_matrix self.n_atoms = n_atoms self.atom_species_charges_tuple = (atom_species, atom_charges) - self._root=p_join(get_local_cache(), self.name) + self._root = p_join(get_local_cache(), self.name) if atom_species is not None and atom_charges is not None: # by value not reference self.atom_species_charges_tuple = np.concatenate((atom_species[:, None], atom_charges[:, None]), axis=-1) @@ -160,7 +160,7 @@ def has_forces(self) -> bool: @property def preprocess_path(self): - path = p_join(self.root, "statistics", self.name + f"_{str(self)}" + f".pkl") + path = p_join(self.root, "statistics", self.name + f"_{str(self)}" + ".pkl") return path @property @@ -187,7 +187,7 @@ def from_openqdc_dataset(cls, dataset, recompute: bool = False): atom_charges=dataset.data["atomic_inputs"][:, 1].ravel(), e0_matrix=dataset.__isolated_atom_energies__, ) - obj._root = dataset.root # set to the dataset root in case of multiple datasets + obj._root = dataset.root # set to the dataset root in case 
of multiple datasets return obj @abstractmethod diff --git a/openqdc/utils/download_api.py b/openqdc/utils/download_api.py index 36b34f5..9ca1e64 100644 --- a/openqdc/utils/download_api.py +++ b/openqdc/utils/download_api.py @@ -14,7 +14,8 @@ import gdown import requests import tqdm -from aiohttp import ClientTimeout + +# from aiohttp import ClientTimeout from dotenv import load_dotenv from fsspec import AbstractFileSystem from fsspec.callbacks import TqdmCallback @@ -34,13 +35,13 @@ class FileSystem: public_endpoint: Optional[AbstractFileSystem] = None private_endpoint: Optional[AbstractFileSystem] = None local_endpoint: AbstractFileSystem = LocalFileSystem() - endpoint_url = 'https://874f02b9d981bd6c279e979c0d91c4b4.r2.cloudflarestorage.com' - + endpoint_url = "https://874f02b9d981bd6c279e979c0d91c4b4.r2.cloudflarestorage.com" + def __init__(self): load_dotenv() self.KEY = os.getenv("CLOUDFARE_KEY", None) self.SECRET = os.getenv("CLOUDFARE_SECRET", None) - + @property def public(self): self.connect() @@ -71,24 +72,23 @@ def connect(self): warnings.simplefilter("ignore") # No quota warning self.public_endpoint = self.get_default_endpoint("public") self.private_endpoint = self.get_default_endpoint("private") - #self.public_endpoint.client_kwargs = {"timeout": ClientTimeout(total=3600, connect=1000)} + # self.public_endpoint.client_kwargs = {"timeout": ClientTimeout(total=3600, connect=1000)} def get_default_endpoint(self, endpoint: str) -> AbstractFileSystem: """ Return a default endpoint for the given str [public, private] """ if endpoint == "private": - #return fsspec.filesystem("gs") - return fsspec.filesystem("s3", - key=self.KEY, - secret=self.SECRET, - endpoint_url=self.endpoint_url) + # return fsspec.filesystem("gs") + return fsspec.filesystem("s3", key=self.KEY, secret=self.SECRET, endpoint_url=self.endpoint_url) elif endpoint == "public": - return fsspec.filesystem("s3", - key='a046308c078b0134c9e261aa91f63ab2', - secret='d5b32f241ad8ee8d0a3173cd51b4f36d6869f168b21acef75f244a81dc10e1fb', - endpoint_url=self.endpoint_url) - #return fsspec.filesystem("https") + return fsspec.filesystem( + "s3", + key="a046308c078b0134c9e261aa91f63ab2", + secret="d5b32f241ad8ee8d0a3173cd51b4f36d6869f168b21acef75f244a81dc10e1fb", + endpoint_url=self.endpoint_url, + ) + # return fsspec.filesystem("https") else: return self.local_endpoint @@ -298,4 +298,3 @@ def from_config(self, config: dict): API = FileSystem() - diff --git a/openqdc/utils/io.py b/openqdc/utils/io.py index 6fe8dd5..cc468c4 100644 --- a/openqdc/utils/io.py +++ b/openqdc/utils/io.py @@ -54,10 +54,10 @@ def get_remote_cache(write_access=False) -> str: Returns the entry point based on the write access. """ if write_access: - remote_cache = "/openqdc/v1" #"gs://qmdata-public/openqdc" + remote_cache = "/openqdc/v1" # "gs://qmdata-public/openqdc" else: remote_cache = "/openqdc/v1" - #remote_cache = "https://storage.googleapis.com/qmdata-public/openqdc" + # remote_cache = "https://storage.googleapis.com/qmdata-public/openqdc" return remote_cache @@ -88,7 +88,6 @@ def pull_locally(local_path, overwrite=False): def copy_exists(local_path): remote_path = local_path.replace(get_local_cache(), get_remote_cache()) - print(remote_path) return os.path.exists(local_path) or API.exists(remote_path)
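# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the diff above): exercising the updated CLI
# from Python with Typer's test runner. Assumes this branch of openQDC is
# installed; the dataset name and cache directory are placeholders.
# ---------------------------------------------------------------------------
from typer.testing import CliRunner

from openqdc.cli import app

runner = CliRunner()

# New `cache` command: prints the local cache directory openQDC resolves.
print(runner.invoke(app, ["cache"]).output)

# `download` with the new --as-zarr flag: zarr archives are fetched instead of
# the memory-mapped .mmap/props.pkl layout.
result = runner.invoke(app, ["download", "qm7", "--as-zarr", "--cache-dir", "/tmp/openqdc"])
print(result.output)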
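# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the diff above): the per-key layout that
# convert_to_zarr / ZarrDataset.save_preprocess write, i.e. one zip-backed
# zarr array per data key with unit metadata stored as attributes. The path,
# shape and attribute values are placeholders; assumes zarr-python 2.x.
# ---------------------------------------------------------------------------
import numpy as np
import zarr

store = zarr.storage.ZipStore("/tmp/energies.zip", mode="w")
z = zarr.open(store, mode="w", shape=(10, 1), dtype=np.float64)
z[:] = np.zeros((10, 1))
z.attrs.update({"unit": "hartree", "level_of_theory": ["b3lyp/def2-tzvp"]})
store.close()  # a ZipStore must be closed for the archive to be finalized

# Reading mirrors ZarrDataset.load_fn followed by unpack(): open lazily, slice.
arr = zarr.open(zarr.storage.ZipStore("/tmp/energies.zip", mode="r"), mode="r")[:]
print(arr.shape, arr.dtype)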