Skip to content

Commit

Permalink
descriptor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FNTwin committed Apr 2, 2024
1 parent cef2b35 commit dbdd985
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 23 deletions.
41 changes: 18 additions & 23 deletions openqdc/utils/descriptors.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,32 @@
from abc import ABC, abstractmethod
from typing import Any, List

import datamol as dm
import numpy as np
from ase.atoms import Atoms
from numpy import ndarray

from openqdc.utils.io import to_atoms
from openqdc.utils.package_utils import requires_package
import datamol as dm


class Descriptor(ABC):
"""
Base class for all descriptors.
Base class for all descriptors.
Descriptors are used to transform 3D atomic structures into feature vectors.
"""

_model: Any

def __init__(self, *, species: List[str], **kwargs) -> None:
    """
    Store the chemical species and build the underlying descriptor model.

    Parameters
    ----------
    species : List[str]
        List of chemical species for the descriptor embedding.
    kwargs : dict
        Additional keyword arguments to be passed to the descriptor model.
    """
    # The species are assigned before instantiate_model() is called, so they
    # are already available on the instance while the model is being built.
    self.chemical_species = species
    self._model = self.instantiate_model(**kwargs)

Expand All @@ -37,11 +38,11 @@ def model(self) -> Any:
@abstractmethod
def instantiate_model(self, **kwargs) -> Any:
"""
Instantiate the descriptor model with the provided kwargs parameters
Instantiate the descriptor model with the provided kwargs parameters
and return it. The model will be stored in the _model attribute.
If a package is required to instantiate the model, it should be checked
using the requires_package decorator or in the method itself.
Parameters
----------
kwargs : dict
Expand All @@ -53,19 +54,18 @@ def instantiate_model(self, **kwargs) -> Any:
def calculate(self, atoms: Atoms, **kwargs) -> ndarray:
    """
    Calculate the descriptor for a single given Atoms object.

    Parameters
    ----------
    atoms : Atoms
        Ase Atoms object to calculate the descriptor for.
    kwargs : dict
        Additional keyword arguments; accepted by the signature but not used
        by this base implementation.

    Returns
    -------
    ndarray
        ndarray containing the descriptor values

    Raises
    ------
    NotImplementedError
        Always raised by this base implementation.
    """
    raise NotImplementedError


def fit_transform(self, atoms: List[Atoms], **kwargs) -> List[ndarray]:
"""Parallelized version of the calculate method.
Expand All @@ -75,32 +75,27 @@ def fit_transform(self, atoms: List[Atoms], **kwargs) -> List[ndarray]:
List of Ase Atoms object to calculate the descriptor for.
kwargs : dict
Additional keyword arguments to be passed to the datamol parallelized model.
Returns
-------
List[ndarray]
List of ndarray containing the descriptor values
"""

descr_values = dm.parallelized(self.calculate,
atoms,
scheduler="threads",
**kwargs)
descr_values = dm.parallelized(self.calculate, atoms, scheduler="threads", **kwargs)
return descr_values



def from_xyz(self, positions: np.ndarray, atomic_numbers: np.ndarray) -> ndarray:
"""
Calculate the descriptor from positions and atomic numbers of a single structure.
Parameters
----------
positions : np.ndarray (n_atoms, 3)
Positions of the chemical structure.
atomic_numbers : np.ndarray (n_atoms,)
Atomic numbers of the chemical structure.
Returns
-------
ndarray
Expand All @@ -117,7 +112,6 @@ def __repr__(self):


class SOAP(Descriptor):

@requires_package("dscribe")
def instantiate_model(self, **kwargs):
from dscribe.descriptors import SOAP as SOAPModel
Expand Down Expand Up @@ -154,7 +148,7 @@ def instantiate_model(self, **kwargs):

r_cut = kwargs.pop("r_cut", 5.0)
g2_params = kwargs.pop("g2_params", [[1, 1], [1, 2], [1, 3]])
g3_params = kwargs.pop("g3_params", [1,1,2,-1])
g3_params = kwargs.pop("g3_params", [1, 1, 2, -1])
g4_params = kwargs.pop("g4_params", [[1, 1, 1], [1, 2, 1], [1, 1, -1], [1, 2, -1]])
g5_params = kwargs.pop("g5_params", [[1, 2, -1], [1, 1, 1], [-1, 1, 1], [1, 2, 1]])
periodic = kwargs.pop("periodic", False)
Expand Down Expand Up @@ -191,15 +185,16 @@ def instantiate_model(self, **kwargs):
normalize_gaussians=normalize_gaussians,
normalization=normalization,
)

def calculate(self, atoms: Atoms, **kwargs) -> ndarray:
    """
    Calculate the descriptor for a single given Atoms object.

    The work is delegated to the ``create`` method of ``self.model``.

    Parameters
    ----------
    atoms : Atoms
        Ase Atoms object to calculate the descriptor for.
    kwargs : dict
        Additional keyword arguments forwarded unchanged to ``self.model.create``.

    Returns
    -------
    ndarray
        The value returned by ``self.model.create``.
    """
    return self.model.create(atoms, **kwargs)


# Dynamic mapping of available descriptors: lowercased class name -> class, for
# every Descriptor subclass bound in this module's globals at this point.
# The namespace is snapshotted with list(...) before iterating: on Python 3.12+
# comprehensions are inlined, and iterating the live module namespace directly
# may raise "dictionary changed size during iteration" (e.g. under tracing).
AVAILABLE_DESCRIPTORS = {
    str_name.lower(): cls
    for str_name, cls in list(globals().items())
    if isinstance(cls, type) and issubclass(cls, Descriptor) and str_name != "Descriptor"  # Exclude the base class
}


Expand Down
35 changes: 35 additions & 0 deletions tests/test_descriptors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest

from openqdc import Dummy
from openqdc.utils.descriptors import ACSF, MBTR, SOAP, Descriptor


@pytest.fixture
def dummy():
    """Provide a fresh ``Dummy`` instance to each test that requests it."""
    return Dummy()


@pytest.mark.parametrize("model", [SOAP, ACSF, MBTR])
def test_init(model):
    """Each descriptor class can be built from a species list and is a Descriptor."""
    # Keep the parametrized class and the built instance under separate names.
    descriptor = model(species=["H"])
    assert isinstance(descriptor, Descriptor)


@pytest.mark.parametrize("model", [SOAP, ACSF, MBTR])
def test_descriptor(model, dummy):
    """``fit_transform`` returns one result per input structure."""
    descriptor = model(species=dummy.chemical_species)
    structures = [dummy.get_ase_atoms(i) for i in range(4)]
    results = descriptor.fit_transform(structures)
    # Derive the expected count from the inputs rather than repeating the literal.
    assert len(results) == len(structures)


@pytest.mark.parametrize("model", [SOAP, ACSF, MBTR])
def test_from_positions(model):
    """``from_xyz`` accepts raw positions and atomic numbers and returns a value."""
    descriptor = model(species=["H"])
    result = descriptor.from_xyz([[0, 0, 0], [1, 1, 1]], [1, 1])
    # Previously the result was discarded; at least require that something is returned.
    assert result is not None


@pytest.mark.parametrize(
    "model,override", [(SOAP, {"r_cut": 3.0}), (ACSF, {"r_cut": 3.0}), (MBTR, {"normalize_gaussians": False})]
)
def test_overwrite(model, override, dummy):
    """Descriptors can be constructed with a default model parameter overridden."""
    descriptor = model(species=dummy.chemical_species, **override)
    # The original test asserted nothing; check that construction with the
    # override produced a usable Descriptor instance.
    assert isinstance(descriptor, Descriptor)

0 comments on commit dbdd985

Please sign in to comment.