diff --git a/openqdc/__init__.py b/openqdc/__init__.py index 1bf99b1..7e0a39b 100644 --- a/openqdc/__init__.py +++ b/openqdc/__init__.py @@ -5,7 +5,11 @@ # The below lazy import logic is coming from openff-toolkit: # https://github.com/openforcefield/openff-toolkit/blob/b52879569a0344878c40248ceb3bd0f90348076a/openff/toolkit/__init__.py#L44 + # Dictionary of objects to lazily import; maps the object's name to its module path +def get_project_root(): + return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + _lazy_imports_obj = { "__version__": "openqdc._version", diff --git a/openqdc/datasets/base.py b/openqdc/datasets/base.py index 6e58f92..469a033 100644 --- a/openqdc/datasets/base.py +++ b/openqdc/datasets/base.py @@ -2,7 +2,6 @@ import os import pickle as pkl -from copy import deepcopy from functools import partial from itertools import compress from os.path import join as p_join @@ -71,13 +70,13 @@ class BaseDataset(DatasetPropertyMixIn): __energy_methods__ = [] __force_mask__ = [] __isolated_atom_energies__ = [] + _fn_energy = lambda x: x + _fn_distance = lambda x: x + _fn_forces = lambda x: x __energy_unit__ = "hartree" __distance_unit__ = "ang" __forces_unit__ = "hartree/ang" - __fn_energy__ = lambda x: x - __fn_distance__ = lambda x: x - __fn_forces__ = lambda x: x __average_nb_atoms__ = None def __init__( @@ -123,6 +122,7 @@ def __init__( solver_type can be one of ["linear", "ridge"] """ set_cache_dir(cache_dir) + # self._init_lambda_fn() self.data = None self.recompute_statistics = recompute_statistics self.regressor_kwargs = regressor_kwargs @@ -136,6 +136,11 @@ def __init__( self.set_array_format(array_format) self._post_init(overwrite_local_cache, energy_unit, distance_unit) + def _init_lambda_fn(self): + self._fn_energy = lambda x: x + self._fn_distance = lambda x: x + self._fn_forces = lambda x: x + def _post_init( self, overwrite_local_cache: bool = False, @@ -272,7 +277,7 @@ def _set_units(self, en, ds): self.set_distance_unit(ds) if self.__force_methods__: self.__forces_unit__ = self.energy_unit + "/" + self.distance_unit - self.__class__.__fn_forces__ = get_conversion(old_en + "/" + old_ds, self.__forces_unit__) + self._fn_forces = get_conversion(old_en + "/" + old_ds, self.__forces_unit__) def _set_isolated_atom_energies(self): if self.__energy_methods__ is None: @@ -281,13 +286,13 @@ def _set_isolated_atom_energies(self): self.__isolated_atom_energies__ = f(self.e0s_dispatcher.e0s_matrix) def convert_energy(self, x): - return self.__class__.__fn_energy__(x) + return self._fn_energy(x) def convert_distance(self, x): - return self.__class__.__fn_distance__(x) + return self._fn_distance(x) def convert_forces(self, x): - return self.__class__.__fn_forces__(x) + return self._fn_forces(x) def set_energy_unit(self, value: str): """ @@ -295,7 +300,7 @@ def set_energy_unit(self, value: str): """ old_unit = self.energy_unit self.__energy_unit__ = value - self.__class__.__fn_energy__ = get_conversion(old_unit, value) + self._fn_energy = get_conversion(old_unit, value) def set_distance_unit(self, value: str): """ @@ -303,7 +308,7 @@ def set_distance_unit(self, value: str): """ old_unit = self.distance_unit self.__distance_unit__ = value - self.__class__.__fn_distance__ = get_conversion(old_unit, value) + self._fn_distance = get_conversion(old_unit, value) def set_array_format(self, format: str): assert format in ["numpy", "torch", "jax"], f"Format {format} not supported." @@ -535,39 +540,29 @@ def get_statistics(self, return_none: bool = True): Whether to return None if the statistics for the forces are not available, by default True Otherwise, the statistics for the forces are set to 0.0 """ - stats = deepcopy(self.statistics.get_results()) - if len(stats) == 0: + selected_stats = self.statistics.get_results() + if len(selected_stats) == 0: raise StatisticsNotAvailableError(self.__name__) - selected_stats = stats if not return_none: selected_stats.update( { "ForcesCalculatorStats": { "mean": np.array([0.0]), "std": np.array([0.0]), - "components": { - "mean": np.array([[0.0], [0.0], [0.0]]), - "std": np.array([[0.0], [0.0], [0.0]]), - "rms": np.array([[0.0], [0.0], [0.0]]), - }, + "component_mean": np.array([[0.0], [0.0], [0.0]]), + "component_std": np.array([[0.0], [0.0], [0.0]]), + "component_rms": np.array([[0.0], [0.0], [0.0]]), } } ) # cycle trough dict to convert units - for key in selected_stats: - if key.lower() == str(ForcesCalculatorStats): - for key2 in selected_stats[key]: - if key2 != "components": - selected_stats[key][key2] = self.convert_forces(selected_stats[key][key2]) - else: - for key2 in selected_stats[key]["components"]: - selected_stats[key]["components"][key2] = self.convert_forces( - selected_stats[key]["components"][key2] - ) + for key, result in selected_stats.items(): + if isinstance(result, ForcesCalculatorStats): + result.transform(self.convert_forces) else: - for key2 in selected_stats[key]: - selected_stats[key][key2] = self.convert_energy(selected_stats[key][key2]) - return selected_stats + result.transform(self.convert_energy) + result.transform(self._convert_array) + return {k: result.to_dict() for k, result in selected_stats.items()} def __str__(self): return f"{self.__name__}" diff --git a/openqdc/datasets/io.py b/openqdc/datasets/io.py index d1ab330..bf90ea5 100644 --- a/openqdc/datasets/io.py +++ b/openqdc/datasets/io.py @@ -50,9 +50,14 @@ def __init__( self.__energy_unit__ = energy_unit self.__distance_unit__ = distance_unit self.__energy_methods__ = [PotentialMethod.NONE if not level_of_theory else level_of_theory] + self.energy_target_names = ["xyz"] self.regressor_kwargs = regressor_kwargs self.transform = transform self._read_and_preprocess() + if "forces" in self.data: + self.__force_mask__ = [True] + self.__class__.__force_methods__ = [level_of_theory] + self.force_target_names = ["xyz"] self.set_array_format(array_format) self._post_init(True, energy_unit, distance_unit) diff --git a/openqdc/datasets/potential/dummy.py b/openqdc/datasets/potential/dummy.py index cca6b78..1c7a61c 100644 --- a/openqdc/datasets/potential/dummy.py +++ b/openqdc/datasets/potential/dummy.py @@ -1,8 +1,11 @@ +import pickle as pkl +from os.path import join as p_join + import numpy as np +from loguru import logger from openqdc.datasets.base import BaseDataset from openqdc.methods import PotentialMethod -from openqdc.utils.constants import NOT_DEFINED class Dummy(BaseDataset): @@ -11,8 +14,8 @@ class Dummy(BaseDataset): """ __name__ = "dummy" - __energy_methods__ = [PotentialMethod.SVWN_DEF2_TZVP, PotentialMethod.PM6] - __force_mask__ = [False, True] + __energy_methods__ = [PotentialMethod.SVWN_DEF2_TZVP, PotentialMethod.PM6, PotentialMethod.GFN2_XTB] + __force_mask__ = [False, True, True] __energy_unit__ = "kcal/mol" __distance_unit__ = "ang" __forces_unit__ = "kcal/mol/ang" @@ -23,25 +26,6 @@ class Dummy(BaseDataset): __isolated_atom_energies__ = [] __average_n_atoms__ = None - @property - def _stats(self): - return { - "formation": { - "energy": { - "mean": np.array([[-12.94348027, -9.83037297]]), - "std": np.array([[4.39971409, 3.3574188]]), - }, - "forces": NOT_DEFINED, - }, - "total": { - "energy": { - "mean": np.array([[-89.44242, -1740.5336]]), - "std": np.array([[29.599571, 791.48663]]), - }, - "forces": NOT_DEFINED, - }, - } - def _post_init(self, overwrite_local_cache, energy_unit, distance_unit) -> None: self.setup_dummy() return super()._post_init(overwrite_local_cache, energy_unit, distance_unit) @@ -90,3 +74,55 @@ def read_raw_entries(self): def __len__(self): return 9999 + + +class PredefinedDataset(BaseDataset): + __name__ = "predefineddataset" + __energy_methods__ = [PotentialMethod.WB97M_D3BJ_DEF2_TZVPPD] # "wb97m-d3bj/def2-tzvppd"] + __force_mask__ = [True] + __energy_unit__ = "hartree" + __distance_unit__ = "bohr" + __forces_unit__ = "hartree/bohr" + force_target_names = __energy_methods__ + energy_target_names = __energy_methods__ + + @property + def preprocess_path(self, overwrite_local_cache=False): + from os.path import join as p_join + + from openqdc import get_project_root + + return p_join(get_project_root(), "tests", "files", self.__name__, "preprocessed") + + def is_preprocessed(self): + return True + + def read_raw_entries(self): + pass + + def read_preprocess(self, overwrite_local_cache=False): + logger.info("Reading preprocessed data.") + logger.info( + f"Dataset {self.__name__} with the following units:\n\ + Energy: {self.energy_unit},\n\ + Distance: {self.distance_unit},\n\ + Forces: {self.force_unit if self.force_methods else 'None'}" + ) + self.data = {} + for key in self.data_keys: + print(key, self.data_shapes[key], self.data_types[key]) + filename = p_join(self.preprocess_path, f"{key}.mmap") + self.data[key] = np.memmap(filename, mode="r", dtype=self.data_types[key]).reshape(*self.data_shapes[key]) + + filename = p_join(self.preprocess_path, "props.pkl") + with open(filename, "rb") as f: + tmp = pkl.load(f) + for key in ["name", "subset", "n_atoms"]: + x = tmp.pop(key) + if len(x) == 2: + self.data[key] = x[0][x[1]] + else: + self.data[key] = x + + for key in self.data: + logger.info(f"Loaded {key} with shape {self.data[key].shape}, dtype {self.data[key].dtype}") diff --git a/openqdc/datasets/statistics.py b/openqdc/datasets/statistics.py index 2997d2e..e4fe9e5 100644 --- a/openqdc/datasets/statistics.py +++ b/openqdc/datasets/statistics.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from copy import deepcopy from dataclasses import asdict, dataclass from os.path import join as p_join from typing import Optional @@ -18,12 +19,9 @@ class StatisticsResults: def to_dict(self): return asdict(self) - def convert(self, func): + def transform(self, func): for k, v in self.to_dict().items(): - if isinstance(v, dict): - self.convert(func) - else: - setattr(self, k, func(v)) + setattr(self, k, func(v)) @dataclass @@ -36,19 +34,6 @@ class EnergyStatistics(StatisticsResults): std: Optional[np.ndarray] -@dataclass -class ForceComponentsStatistics(StatisticsResults): - """ - Dataclass for force statistics related to the x,y,z components - mean,std,rms are supposed to be 2d arrays related to the x,y,z components - of the forces - """ - - mean: Optional[np.ndarray] - std: Optional[np.ndarray] - rms: Optional[np.ndarray] - - @dataclass class ForceStatistics(StatisticsResults): """ @@ -57,7 +42,9 @@ class ForceStatistics(StatisticsResults): mean: Optional[np.ndarray] std: Optional[np.ndarray] - components: ForceComponentsStatistics + component_mean: Optional[np.ndarray] + component_std: Optional[np.ndarray] + component_rms: Optional[np.ndarray] class StatisticManager: @@ -66,11 +53,9 @@ class StatisticManager: the statistic calculators """ - _state = {} - _results = {} - def __init__(self, dataset, recompute: bool = False, *statistic_calculators: "AbstractStatsCalculator"): - self.reset_state() + self._state = {} + self._results = {} self._statistic_calculators = [ statistic_calculators.from_openqdc_dataset(dataset, recompute) for statistic_calculators in statistic_calculators @@ -89,6 +74,12 @@ def reset_state(self): """ self._state = {} + def reset_results(self): + """ + Reset the results dictionary + """ + self._results = {} + def get_state(self, key: Optional[str] = None): """ key : str, default = None @@ -105,11 +96,14 @@ def has_state(self, key: str): """ return key in self._state - def get_results(self): + def get_results(self, as_dict: bool = False): """ Aggregate results from all the calculators """ - return self._results + results = deepcopy(self._results) + if as_dict: + return {k: v.as_dict() for k, v in results.items()} + return {k: v for k, v in self._results.items()} def run_calculators(self): """ @@ -205,7 +199,7 @@ def save_statistics(self) -> None: """ Save statistics file to the dataset folder as a pkl file """ - save_pkl(self.result.to_dict(), self.preprocess_path) + save_pkl(self.result, self.preprocess_path) def attempt_load(self) -> bool: """ @@ -266,17 +260,20 @@ class ForcesCalculatorStats(AbstractStatsCalculator): def compute(self) -> ForceStatistics: if not self.has_forces: - return ForceStatistics( - mean=None, std=None, components=ForceComponentsStatistics(rms=None, std=None, mean=None) - ) + return ForceStatistics(mean=None, std=None, component_mean=None, component_std=None, component_rms=None) converted_force_data = self.forces - force_mean = np.nanmean(converted_force_data, axis=0) - force_std = np.nanstd(converted_force_data, axis=0) - force_rms = np.sqrt(np.nanmean(converted_force_data**2, axis=0)) + num_methods = converted_force_data.shape[2] + mean = np.nanmean(converted_force_data.reshape(-1, num_methods), axis=0) + std = np.nanstd(converted_force_data.reshape(-1, num_methods), axis=0) + component_mean = np.nanmean(converted_force_data, axis=0) + component_std = np.nanstd(converted_force_data, axis=0) + component_rms = np.sqrt(np.nanmean(converted_force_data**2, axis=0)) return ForceStatistics( - mean=np.atleast_2d(force_mean), - std=np.atleast_2d(force_std), - components=ForceComponentsStatistics(rms=force_rms, std=force_std, mean=force_mean), + mean=np.atleast_2d(mean), + std=np.atleast_2d(std), + component_mean=np.atleast_2d(component_mean), + component_std=np.atleast_2d(component_std), + component_rms=np.atleast_2d(component_rms), ) @@ -285,7 +282,7 @@ class TotalEnergyStats(AbstractStatsCalculator): Total Energy statistics calculator class """ - def compute(self): + def compute(self) -> EnergyStatistics: converted_energy_data = self.energies total_E_mean = np.nanmean(converted_energy_data, axis=0) total_E_std = np.nanstd(converted_energy_data, axis=0) diff --git a/openqdc/utils/constants.py b/openqdc/utils/constants.py index 0db0336..f14be26 100644 --- a/openqdc/utils/constants.py +++ b/openqdc/utils/constants.py @@ -24,11 +24,9 @@ NOT_DEFINED = { "mean": None, "std": None, - "components": { - "mean": None, - "std": None, - "rms": None, - }, + "component_mean": None, + "component_std": None, + "component_rms": None, } ATOM_TABLE = Chem.GetPeriodicTable() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/files/predefineddataset/preprocessed/atomic_inputs.mmap b/tests/files/predefineddataset/preprocessed/atomic_inputs.mmap new file mode 100644 index 0000000..3068169 Binary files /dev/null and b/tests/files/predefineddataset/preprocessed/atomic_inputs.mmap differ diff --git a/tests/files/predefineddataset/preprocessed/energies.mmap b/tests/files/predefineddataset/preprocessed/energies.mmap new file mode 100644 index 0000000..8c7f22c Binary files /dev/null and b/tests/files/predefineddataset/preprocessed/energies.mmap differ diff --git a/tests/files/predefineddataset/preprocessed/forces.mmap b/tests/files/predefineddataset/preprocessed/forces.mmap new file mode 100644 index 0000000..bea7597 Binary files /dev/null and b/tests/files/predefineddataset/preprocessed/forces.mmap differ diff --git a/tests/files/predefineddataset/preprocessed/forcescalculatorstats.pkl b/tests/files/predefineddataset/preprocessed/forcescalculatorstats.pkl new file mode 100644 index 0000000..17d2b1d Binary files /dev/null and b/tests/files/predefineddataset/preprocessed/forcescalculatorstats.pkl differ diff --git a/tests/files/predefineddataset/preprocessed/formationenergystats_formation.pkl b/tests/files/predefineddataset/preprocessed/formationenergystats_formation.pkl new file mode 100644 index 0000000..cfcb18e Binary files /dev/null and b/tests/files/predefineddataset/preprocessed/formationenergystats_formation.pkl differ diff --git a/tests/files/predefineddataset/preprocessed/peratomformationenergystats_formation.pkl b/tests/files/predefineddataset/preprocessed/peratomformationenergystats_formation.pkl new file mode 100644 index 0000000..a79f9d0 Binary files /dev/null and b/tests/files/predefineddataset/preprocessed/peratomformationenergystats_formation.pkl differ diff --git a/tests/files/predefineddataset/preprocessed/position_idx_range.mmap b/tests/files/predefineddataset/preprocessed/position_idx_range.mmap new file mode 100644 index 0000000..a6a40eb Binary files /dev/null and b/tests/files/predefineddataset/preprocessed/position_idx_range.mmap differ diff --git a/tests/files/predefineddataset/preprocessed/props.pkl b/tests/files/predefineddataset/preprocessed/props.pkl new file mode 100644 index 0000000..a92429b Binary files /dev/null and b/tests/files/predefineddataset/preprocessed/props.pkl differ diff --git a/tests/files/predefineddataset/preprocessed/totalenergystats.pkl b/tests/files/predefineddataset/preprocessed/totalenergystats.pkl new file mode 100644 index 0000000..72176eb Binary files /dev/null and b/tests/files/predefineddataset/preprocessed/totalenergystats.pkl differ diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000..581cb98 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,74 @@ +import numpy as np +import pytest +from numpy import array, float32 + +from openqdc.datasets.potential.dummy import PredefinedDataset +from openqdc.datasets.statistics import EnergyStatistics, ForceStatistics + + +@pytest.fixture +def targets(): + return { + # hartree/bohr + "ForcesCalculatorStats": ForceStatistics( + mean=array([[-7.4906794e-08]], dtype=float32), + std=array([[0.02859425]], dtype=float32), + component_mean=array([[4.6292794e-07], [-2.1531498e-07], [-4.7250555e-07]], dtype=float32), + component_std=array([[0.02794589], [0.03237366], [0.02497733]], dtype=float32), + component_rms=array([[0.02794588], [0.03237367], [0.02497733]], dtype=float32), + ), + # Hartree + "TotalEnergyStats": EnergyStatistics( + mean=array([[-126.0]], dtype=float32), std=array([[79.64923]], dtype=float32) + ), + # Hartree + "FormationEnergyStats": EnergyStatistics(mean=array([[841.82607372]]), std=array([[448.15780975]])), + # Hartree + "PerAtomFormationEnergyStats": EnergyStatistics(mean=array([[20.18697415]]), std=array([[7.30153839]])), + } + + +@pytest.mark.parametrize( + "property,expected", + [ + ("n_atoms", [27, 48, 27, 45, 45]), + ("energies", [[-90.0], [-230.0], [-10.0], [-200.0], [-100.0]]), + ], +) +def test_dataset_load(property, expected): + ds = PredefinedDataset(energy_type="formation") + assert ds is not None + assert len(ds) == 5 + assert ds.data["atomic_inputs"].shape == (192, 5) + assert ds.data["forces"].shape == (192, 3, 1) + np.testing.assert_equal(ds.data[property], np.array(expected)) + + +def test_predefined_dataset(targets): + ds = PredefinedDataset(energy_type="formation") + keys = ["ForcesCalculatorStats", "FormationEnergyStats", "PerAtomFormationEnergyStats", "TotalEnergyStats"] + assert all(k in ds.get_statistics() for k in keys) + stats = ds.get_statistics() + + formation_energy_stats = stats["FormationEnergyStats"] + formation_energy_stats_t = targets["FormationEnergyStats"].to_dict() + np.testing.assert_almost_equal(formation_energy_stats["mean"], formation_energy_stats_t["mean"]) + np.testing.assert_almost_equal(formation_energy_stats["std"], formation_energy_stats_t["std"]) + + per_atom_formation_energy_stats = stats["PerAtomFormationEnergyStats"] + per_atom_formation_energy_stats_t = targets["PerAtomFormationEnergyStats"].to_dict() + np.testing.assert_almost_equal(per_atom_formation_energy_stats["mean"], per_atom_formation_energy_stats_t["mean"]) + np.testing.assert_almost_equal(per_atom_formation_energy_stats["std"], per_atom_formation_energy_stats_t["std"]) + + total_energy_stats = stats["TotalEnergyStats"] + total_energy_stats_t = targets["TotalEnergyStats"].to_dict() + np.testing.assert_almost_equal(total_energy_stats["mean"], total_energy_stats_t["mean"]) + np.testing.assert_almost_equal(total_energy_stats["std"], total_energy_stats_t["std"]) + + forces_stats = stats["ForcesCalculatorStats"] + forces_stats_t = targets["ForcesCalculatorStats"].to_dict() + np.testing.assert_almost_equal(forces_stats["mean"], forces_stats_t["mean"]) + np.testing.assert_almost_equal(forces_stats["std"], forces_stats_t["std"]) + np.testing.assert_almost_equal(forces_stats["component_mean"], forces_stats_t["component_mean"]) + np.testing.assert_almost_equal(forces_stats["component_std"], forces_stats_t["component_std"]) + np.testing.assert_almost_equal(forces_stats["component_rms"], forces_stats_t["component_rms"]) diff --git a/tests/test_dummy.py b/tests/test_dummy.py index 4929c4c..08ee127 100644 --- a/tests/test_dummy.py +++ b/tests/test_dummy.py @@ -1,11 +1,19 @@ """Path hack to make tests work.""" +import os + import numpy as np import pytest from openqdc.datasets.potential.dummy import Dummy # noqa: E402 +from openqdc.utils.io import get_local_cache from openqdc.utils.package_utils import has_package +# start by removing any cached data +cache_dir = get_local_cache() +os.system(f"rm -rf {cache_dir}/dummy") + + if has_package("torch"): import torch @@ -65,3 +73,56 @@ def custom_fn(bunch): assert "new_key" in data assert data["new_key"] == data["name"] + data["subset"] + + +def test_get_statistics(ds): + stats = ds.get_statistics() + + keys = ["ForcesCalculatorStats", "FormationEnergyStats", "PerAtomFormationEnergyStats", "TotalEnergyStats"] + assert all(k in stats for k in keys) + + +def test_energy_statistics_shapes(ds): + stats = ds.get_statistics() + + num_methods = len(ds.energy_methods) + + formation_energy_stats = stats["FormationEnergyStats"] + assert formation_energy_stats["mean"].shape == (1, num_methods) + assert formation_energy_stats["std"].shape == (1, num_methods) + + per_atom_formation_energy_stats = stats["PerAtomFormationEnergyStats"] + assert per_atom_formation_energy_stats["mean"].shape == (1, num_methods) + assert per_atom_formation_energy_stats["std"].shape == (1, num_methods) + + total_energy_stats = stats["TotalEnergyStats"] + assert total_energy_stats["mean"].shape == (1, num_methods) + assert total_energy_stats["std"].shape == (1, num_methods) + + +def test_force_statistics_shapes(ds): + stats = ds.get_statistics() + num_force_methods = len(ds.force_methods) + + forces_stats = stats["ForcesCalculatorStats"] + keys = ["mean", "std", "component_mean", "component_std", "component_rms"] + assert all(k in forces_stats for k in keys) + + assert forces_stats["mean"].shape == (1, num_force_methods) + assert forces_stats["std"].shape == (1, num_force_methods) + assert forces_stats["component_mean"].shape == (3, num_force_methods) + assert forces_stats["component_std"].shape == (3, num_force_methods) + assert forces_stats["component_rms"].shape == (3, num_force_methods) + + +@pytest.mark.parametrize("format", ["numpy", "torch", "jax"]) +def test_stats_array_format(format): + if not has_package(format): + pytest.skip(f"{format} is not installed, skipping test") + + ds = Dummy(array_format=format) + stats = ds.get_statistics() + + for key in stats.keys(): + for k, v in stats[key].items(): + assert isinstance(v, format_to_type[format]) diff --git a/tests/test_regressor.py b/tests/test_regressor.py index 6da0a28..db20b08 100644 --- a/tests/test_regressor.py +++ b/tests/test_regressor.py @@ -27,8 +27,9 @@ def test_regressors(small_dummy): setattr(reg, "solver_type", solver_type) reg.solver = reg._get_solver() assert isinstance(reg.solver, inst) + num_methods = len(small_dummy.energy_methods) try: results = reg.solve() - assert results[0].shape[1] == 2 + assert results[0].shape[1] == num_methods except np.linalg.LinAlgError: pass