Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #93 #94

Merged
merged 2 commits into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
energy_unit: Optional[str] = None,
distance_unit: Optional[str] = None,
array_format: str = "numpy",
energy_type: str = "formation",
energy_type: Optional[str] = "formation",
overwrite_local_cache: bool = False,
cache_dir: Optional[str] = None,
recompute_statistics: bool = False,
Expand All @@ -112,7 +112,7 @@ def __init__(
Format to return arrays in. Supported formats: ["numpy", "torch", "jax"]
energy_type
Type of isolated atom energy to use for the dataset. Default: "formation"
Supported types: ["formation", "regression", "null"]
Supported types: ["formation", "regression", "null", None]
overwrite_local_cache
Whether to overwrite the locally cached dataset.
cache_dir
Expand All @@ -133,7 +133,7 @@ def __init__(
self.recompute_statistics = recompute_statistics
self.regressor_kwargs = regressor_kwargs
self.transform = transform
self.energy_type = energy_type
self.energy_type = energy_type if energy_type is not None else "null"
self.refit_e0s = recompute_statistics or overwrite_local_cache
if not self.is_preprocessed():
raise DatasetNotAvailableError(self.__name__)
Expand Down
3 changes: 1 addition & 2 deletions openqdc/datasets/energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
from loguru import logger

from openqdc.methods.enums import PotentialMethod
from openqdc.utils.constants import ATOM_SYMBOLS, ATOMIC_NUMBERS
from openqdc.utils.constants import ATOM_SYMBOLS, ATOMIC_NUMBERS, MAX_CHARGE_NUMBER
from openqdc.utils.io import load_pkl, save_pkl
from openqdc.utils.regressor import Regressor

POSSIBLE_ENERGIES = ["formation", "regression", "null"]
MAX_CHARGE_NUMBER = 21


def dispatch_factory(data, **kwargs) -> "IsolatedEnergyInterface":
Expand Down
9 changes: 7 additions & 2 deletions openqdc/methods/atom_energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
import numpy as np
from loguru import logger

from openqdc.utils.constants import ATOMIC_NUMBERS, MAX_ATOMIC_NUMBER, MAX_CHARGE
from openqdc.utils.constants import (
ATOMIC_NUMBERS,
MAX_ATOMIC_NUMBER,
MAX_CHARGE,
MAX_CHARGE_NUMBER,
)

EF_KEY = Tuple[str, int]

Expand Down Expand Up @@ -35,7 +40,7 @@ def to_e_matrix(atom_energies: dict) -> np.ndarray:
| 2 | | | | | |
"""

matrix = np.zeros((MAX_ATOMIC_NUMBER, MAX_CHARGE * 2 + 1))
matrix = np.zeros((MAX_ATOMIC_NUMBER, MAX_CHARGE_NUMBER))
if len(atom_energies) > 0:
for key in atom_energies.keys():
try:
Expand Down
3 changes: 2 additions & 1 deletion openqdc/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import numpy as np
from rdkit import Chem

MAX_CHARGE: Final[int] = 6
MAX_CHARGE: Final[int] = 10
MAX_CHARGE_NUMBER: Final[int] = 2 * MAX_CHARGE + 1

NB_ATOMIC_FEATURES: Final[int] = 5

Expand Down
Loading