diff --git a/mellon/__init__.py b/mellon/__init__.py index f7f3d91..e353293 100644 --- a/mellon/__init__.py +++ b/mellon/__init__.py @@ -27,7 +27,6 @@ "TimeSensitiveDensityEstimator", "Predictor", "Covariance", - "Log", "util", "cov", "model", @@ -70,3 +69,4 @@ } logging.config.dictConfig(LOGGING_CONFIG) +logger = logging.getLogger("mellon") diff --git a/mellon/base_cov.py b/mellon/base_cov.py index d0390fa..58af6e8 100644 --- a/mellon/base_cov.py +++ b/mellon/base_cov.py @@ -413,8 +413,8 @@ def k_grad(y): y_shape = y.shape y = select_active_dims(y, active_dims) - left_k = self.left.k(x, y) - right_k = self.right.k(x, y) + left_k = self.left.k(x, y)[..., None] + right_k = self.right.k(x, y)[..., None] left_grad = left_grad_func(y) right_grad = right_grad_func(y) @@ -447,6 +447,9 @@ def __repr__(self): return "(" + repr(self.left) + " ** " + repr(self.right) + ")" def k(self, x, y): + x = select_active_dims(x, self.active_dims) + y = select_active_dims(y, self.active_dims) + return self.left(x, y) ** self.right def k_grad(self, x): @@ -480,7 +483,7 @@ def k_grad(self, x): def k_grad(y): y_shape = y.shape y = select_active_dims(y, active_dims) - base_k = self.left.k(x, y) + base_k = self.left.k(x, y)[..., None] base_grad = base_grad_func(y) # Compute the gradient of the powered covariance function using the chain rule diff --git a/mellon/util.py b/mellon/util.py index f3efb49..8522466 100644 --- a/mellon/util.py +++ b/mellon/util.py @@ -443,7 +443,6 @@ def test_rank(input, tol=DEFAULT_RANK_TOL, threshold=None): rank_fraction = approx_rank / max_rank if threshold is not None: - logger = Log() if rank_fraction > threshold: logger.warning( f"High approx. rank fraction ({rank_fraction:.1%}). " diff --git a/tests/test_base_cov.py b/tests/test_base_cov.py index b53b575..8af0ba8 100644 --- a/tests/test_base_cov.py +++ b/tests/test_base_cov.py @@ -13,7 +13,7 @@ ) def test_Add(active_dims): n = 2 - d = 2 + d = 3 cov1 = mellon.cov.Matern32(1.4) cov2 = mellon.cov.Exponential(3.4) @@ -24,8 +24,8 @@ def test_Add(active_dims): x = jnp.ones((n, d)) values = cov(x, 2 * x) assert values.shape == ( - d, - d, + n, + n, ), "Covariance should be computed for each pair of samples." cov.active_dims = active_dims @@ -48,13 +48,9 @@ def test_Add(active_dims): expected_grad = jax.vmap(jax.jacfwd(k_func), in_axes=(0,), out_axes=1)(y) # Assert that the gradients are close - if active_dims is not None: - expected_grad = expected_grad[..., active_dims] - if jax.numpy.isscalar(active_dims): - expected_grad = expected_grad[..., None] assert jnp.allclose( computed_grad, expected_grad, atol=1e-6 - ), f"Gradients do not match in {CovarianceClass.__name__} with active_dims {active_dims}" + ), f"Gradients do not match in {cov.__class__.__name__} covariance with active_dims {active_dims}" @pytest.mark.parametrize( @@ -63,7 +59,7 @@ def test_Add(active_dims): ) def test_Mul(active_dims): n = 2 - d = 2 + d = 3 cov1 = mellon.cov.Matern32(1.4) cov2 = mellon.cov.Exponential(3.4) @@ -74,8 +70,8 @@ def test_Mul(active_dims): x = jnp.ones((n, d)) values = cov(x, 2 * x) assert values.shape == ( - d, - d, + n, + n, ), "Covariance should be computed for each pair of samples." cov.active_dims = active_dims @@ -97,14 +93,10 @@ def test_Mul(active_dims): k_func = lambda y: cov.k(x, y[None,])[..., 0] expected_grad = jax.vmap(jax.jacfwd(k_func), in_axes=(0,), out_axes=1)(y) - # Assert that the gradients are close - if active_dims is not None: - expected_grad = expected_grad[..., active_dims] - if jax.numpy.isscalar(active_dims): - expected_grad = expected_grad[..., None] + # Assert that the gradients are closn assert jnp.allclose( computed_grad, expected_grad, atol=1e-6 - ), f"Gradients do not match in {CovarianceClass.__name__} with active_dims {active_dims}" + ), f"Gradients do not match in {cov.__class__.__name__} covariance with active_dims {active_dims}" @pytest.mark.parametrize( @@ -113,7 +105,7 @@ def test_Mul(active_dims): ) def test_Pow(active_dims): n = 2 - d = 2 + d = 3 cov1 = mellon.cov.Matern32(1.4) cov = cov1**3.2 @@ -123,8 +115,8 @@ def test_Pow(active_dims): x = jnp.ones((n, d)) values = cov(x, 2 * x) assert values.shape == ( - d, - d, + n, + n, ), "Covariance should be computed for each pair of samples." cov.active_dims = active_dims @@ -147,13 +139,9 @@ def test_Pow(active_dims): expected_grad = jax.vmap(jax.jacfwd(k_func), in_axes=(0,), out_axes=1)(y) # Assert that the gradients are close - if active_dims is not None: - expected_grad = expected_grad[..., active_dims] - if jax.numpy.isscalar(active_dims): - expected_grad = expected_grad[..., None] assert jnp.allclose( computed_grad, expected_grad, atol=1e-6 - ), f"Gradients do not match in {CovarianceClass.__name__} with active_dims {active_dims}" + ), f"Gradients do not match in {cov.__class__.__name__} covariance with active_dims {active_dims}" def test_Hirachical(): @@ -183,16 +171,15 @@ def test_Hirachical(): ), "Serialization + deserialization of added covariance functions must return the same result." # Compute the gradient using k_grad - y = 2 * x + y = 2 * x**2 k_grad_func = cov.k_grad(x) computed_grad = k_grad_func(y) # Compute the gradient using JAX automatic differentiation k_func = lambda y: cov.k(x, y[None,])[..., 0] expected_grad = jax.vmap(jax.jacfwd(k_func), in_axes=(0,), out_axes=1)(y) - expected_grad = expected_grad[..., active_dims] # Assert that the gradients are close assert jnp.allclose( computed_grad, expected_grad, atol=1e-6 - ), f"Gradients do not match in hirachical covariance combination." + ), f"Gradients do not match in hirachichal covariance." diff --git a/tests/test_density_estimator.py b/tests/test_density_estimator.py index 62e4933..5340a67 100644 --- a/tests/test_density_estimator.py +++ b/tests/test_density_estimator.py @@ -10,7 +10,7 @@ def common_setup(tmp_path): d = 2 seed = 535 test_file = tmp_path / "predictor.json" - logger = mellon.Log() + logger = mellon.logger key = jax.random.PRNGKey(seed) L = jax.random.uniform(key, (d, d)) cov = L.T.dot(L) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index e762822..fe099ff 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -12,6 +12,7 @@ compute_landmarks_rescale_time, compute_nn_distances_within_time_points, compute_d, + compute_d_factal, compute_Lp, compute_L, ) @@ -55,41 +56,41 @@ def test_compute_nn_distances_within_time_points(): assert jnp.all(result_with_times == result) -def test_compute_d(caplog): +def test_compute_d_factal(caplog): # Create a random key for jax.random key = random.PRNGKey(0) # Create a random array using jax.random x_2d = random.normal(key, shape=(100, 10)) - result_2d = compute_d(x_2d) + result_2d = compute_d_factal(x_2d) assert isinstance(result_2d, float) # Test with 1D array (should return 1) x_1d = random.normal(key, shape=(100,)) - assert compute_d(x_1d) == 1 + assert compute_d_factal(x_1d) == 1 # Test with k > number of samples (expect a warning) x_small = random.normal(key, shape=(5, 10)) logger = logging.getLogger("mellon") logger.propagate = True with caplog.at_level(logging.WARNING, logger="mellon"): - compute_d(x_small, k=10) + compute_d_factal(x_small, k=10) logger.propagate = False assert "is greater than the number of samples" in caplog.text # Test with specific random seed x_seed = random.normal(key, shape=(100, 10)) - result_seed = compute_d(x_seed, seed=432) + result_seed = compute_d_factal(x_seed, seed=432) assert isinstance(result_seed, float) # Test with n < number of samples x_n = random.normal(key, shape=(1000, 10)) - result_n = compute_d(x_n, n=500) + result_n = compute_d_factal(x_n, n=500) assert isinstance(result_n, float) # Test with invalid input (negative k) with pytest.raises(ValueError): - compute_d(x_2d, k=-5) + compute_d_factal(x_2d, k=-5) def test_compute_Lp(): diff --git a/tests/test_time_sensitive_density_estimator.py b/tests/test_time_sensitive_density_estimator.py index 1b78471..eb8dd39 100644 --- a/tests/test_time_sensitive_density_estimator.py +++ b/tests/test_time_sensitive_density_estimator.py @@ -13,7 +13,7 @@ def common_setup_time_sensitive(tmp_path): d = 2 seed = 535 test_file = tmp_path / "predictor.json" - logger = mellon.Log() + logger = mellon.logger key = jax.random.PRNGKey(seed) L = jax.random.uniform(key, (d, d)) cov = L.T.dot(L) diff --git a/tests/test_util.py b/tests/test_util.py index 60c28a0..c2af85d 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,3 +1,4 @@ +import logging import mellon import jax import jax.numpy as jnp @@ -69,15 +70,18 @@ def test_local_dimensionality(): assert dist.shape == (n,), "Local dim should be computed for each point." -def test_Log(): - logger = mellon.util.Log() - assert logger is mellon.util.Log(), "Log should be a singelton class." - assert hasattr(logger, "debug") - assert hasattr(logger, "info") - assert hasattr(logger, "warn") - assert hasattr(logger, "error") - mellon.util.Log.off() - mellon.util.Log.on() +def test_set_verbosity_to_false_changes_level_to_warning(caplog): + mellon.util.set_verbosity(False) + assert ( + mellon.logger.getEffectiveLevel() == logging.WARNING + ), "Logging level should be set to WARNING when verbosity is False." + + +def test_set_verbosity_to_true_changes_level_to_info(caplog): + mellon.util.set_verbosity(True) + assert ( + mellon.logger.getEffectiveLevel() == logging.INFO + ), "Logging level should be set to INFO when verbosity is True." def test_set_jax_config():