Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cahity committed Nov 18, 2024
1 parent edcc55f commit 76daf4a
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion test/models/test_empirical_mean_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_predict(self):

# NOTE: Actual coordinates of test points are not used, only the indices are used.

with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
self.model.predict(np.array([[-1, -1]]))

X = np.array([[-1, -1, 0], [-1, -1, 1]])
Expand Down
6 changes: 6 additions & 0 deletions test/models/test_gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,12 @@ def test_sample_from_single_posterior(self):
samples.shape, (2, len(self.X)), "Sample shape mismatch for single posterior."
)

def test_get_lengthscale_and_var(self):
"""Test get_lengthscale_and_var method."""
lengthscales, variances = self.model.get_lengthscale_and_var()
self.assertGreaterEqual(np.min(variances), 0, "Negative variance.")
self.assertGreaterEqual(np.min(lengthscales), 0, "Negative lengthscale.")


class TestGetGPyTorchModelListWithKnownHyperparams(unittest.TestCase):
"""Test the `get_gpytorch_modellist_w_known_hyperparams` function."""
Expand Down
2 changes: 1 addition & 1 deletion test/test_confidence_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def setUp(self):

def test_update(self):
"""Test the update method."""
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
self.confidence_region.update(np.array([1, 1]), np.ones([2, 3]))

self.confidence_region_new = copy.deepcopy(self.confidence_region)
Expand Down
2 changes: 1 addition & 1 deletion vectoptal/confidence_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def update(
:type scale: np.ndarray
"""
if covariance.shape[-1] != covariance.shape[-2]:
raise AssertionError("Covariance matrix must be square.")
raise ValueError("Covariance matrix must be square.")
std = np.sqrt(np.diag(covariance.squeeze()))

L = mean - std * scale
Expand Down
4 changes: 2 additions & 2 deletions vectoptal/models/gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,11 +676,11 @@ def get_lengthscale_and_var(self) -> tuple[np.ndarray, np.ndarray]:
:return: Lengthscales and variances for each objective.
:rtype: tuple[np.ndarray, np.ndarray]
"""
lengthscales = np.zeros(self.input_dim)
lengthscales = np.zeros((len(self.model.models), self.input_dim))
variances = np.zeros(self.input_dim)
for model_i, model in enumerate(self.model.models):
cov_module = model.covar_module
lengthscale = cov_module.base_kernel.lengthscale.squeeze().numpy(force=True).item()
lengthscale = cov_module.base_kernel.lengthscale.squeeze().numpy(force=True)
variance = cov_module.outputscale.squeeze().numpy(force=True).item()

lengthscales[model_i] = lengthscale
Expand Down

0 comments on commit 76daf4a

Please sign in to comment.