Skip to content

Commit

Permalink
Remove ase dependency, update ROY dataset and tests
Browse files: browse the repository at this point in the history
  • Loading branch information
GardevoirX committed Jun 13, 2024
1 parent bc76938 commit bf97f18
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 54 deletions.
10 changes: 5 additions & 5 deletions examples/selection/GCH-ROY.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@

roy_data = load_roy_dataset()

structures = roy_data["structures"]

density = np.array([s.info["density"] for s in structures])
energy = np.array([s.info["energy"] for s in structures])
structype = np.array([s.info["type"] for s in structures])
density = roy_data["densities"]
energy = roy_data["energies"]
structype = roy_data["structure_types"]
iknown = np.where(structype == "known")[0]
iothers = np.where(structype != "known")[0]

Expand Down Expand Up @@ -247,3 +245,5 @@
},
)
"""

# %%
25 changes: 9 additions & 16 deletions src/skmatter/datasets/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,19 +132,12 @@ def load_roy_dataset():
energies: `np.array` -- energies of the structures
"""
module_path = dirname(__file__)
target_structures = join(module_path, "data", "beran_roy_structures.xyz.bz2")

try:
from ase.io import read
except ImportError:
raise ImportError("load_roy_dataset requires the ASE package.")

import bz2

structures = read(bz2.open(target_structures, "rt"), ":", format="extxyz")
energies = np.array([f.info["energy"] for f in structures])

target_features = join(module_path, "data", "beran_roy_features.npz")
features = np.load(target_features)["feats"]

return Bunch(structures=structures, features=features, energies=energies)
target_properties = join(module_path, "data", "beran_roy_properties.npz")
properties = np.load(target_properties)

return Bunch(
densities=properties["densities"],
energies=properties["energies"],
structure_types=properties["structure_types"],
features=properties["feats"],
)
Binary file removed src/skmatter/datasets/data/beran_roy_features.npz
Binary file not shown.
Binary file not shown.
Binary file not shown.
37 changes: 4 additions & 33 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,45 +107,16 @@ class ROYTests(unittest.TestCase):
def setUpClass(cls):
cls.size = 264
cls.shape = (264, 32)
try:
from ase.io import read # NoQa: F401

cls.has_ase = True
cls.roy = load_roy_dataset()
except ImportError:
cls.has_ase = False

def test_load_dataset_without_ase(self):
"""Check if the correct exception occurs when ase isn't present."""
with unittest.mock.patch.dict("sys.modules", {"ase.io": None}):
with self.assertRaises(ImportError) as cm:
_ = load_roy_dataset()
self.assertEqual(
str(cm.exception), "load_roy_dataset requires the ASE package."
)
cls.roy = load_roy_dataset()

def test_dataset_content(self):
"""Check if the correct number of datapoints are present in the dataset.
Also check if the size of the dataset is correct.
"""
if self.has_ase is True:
self.assertEqual(len(self.roy["structures"]), self.size)
self.assertEqual(self.roy["features"].shape, self.shape)
self.assertEqual(len(self.roy["energies"]), self.size)

def test_dataset_consistency(self):
"""Check if the energies in the structures are the same as in the explicit
array.
"""
if self.has_ase is True:
self.assertTrue(
np.allclose(
self.roy["energies"],
[f.info["energy"] for f in self.roy["structures"]],
rtol=1e-6,
)
)
self.assertEqual(len(self.roy["structure_types"]), self.size)
self.assertEqual(self.roy["features"].shape, self.shape)
self.assertEqual(len(self.roy["energies"]), self.size)


if __name__ == "__main__":
Expand Down

0 comments on commit bf97f18

Please sign in to comment.