Skip to content

Commit

Permalink
fix code formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
katosh committed Dec 14, 2022
1 parent f40933c commit 450bbc5
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions mellon/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

DEFAULT_COV_FUNC = Matern52


class BaseEstimator:
R"""
Base class for the mellon estimators.
Expand Down Expand Up @@ -117,7 +118,7 @@ def _set_x(self, x):
def _compute_landmarks(self):
x = self.x
n_landmarks = self.n_landmarks
logger.info(f'Computing {n_landmarks:,} landmarks with k-means clustering.')
logger.info(f"Computing {n_landmarks:,} landmarks with k-means clustering.")
landmarks = compute_landmarks(x, n_landmarks=n_landmarks)
return landmarks

Expand All @@ -136,7 +137,7 @@ def _compute_cov_func(self):
cov_func_curry = self.cov_func_curry
ls = self.ls
cov_func = compute_cov_func(cov_func_curry, ls)
logger.info('Using covariance function %s.', str(cov_func))
logger.info("Using covariance function %s.", str(cov_func))
return cov_func

def _compute_L(self):
Expand All @@ -150,7 +151,7 @@ def _compute_L(self):
if isinstance(rank, float):
logger.info(
f'Computing rank reduction using "{method}" method '
f'retaining > {rank:.2%} of variance.'
f"retaining > {rank:.2%} of variance."
)
else:
logger.info(
Expand All @@ -166,7 +167,7 @@ def _compute_L(self):
"indicates underrepresentation by landmarks. Consider "
"increasing n_landmarks!"
)
logger.info(f'Using rank {new_rank:,} covariance representation.')
logger.info(f"Using rank {new_rank:,} covariance representation.")
return L

def _run_inference(self):
Expand All @@ -175,7 +176,7 @@ def _run_inference(self):
n_iter = self.n_iter
init_learn_rate = self.init_learn_rate
optimizer = self.optimizer
logger.info('Running inference using %s.', optimizer)
logger.info("Running inference using %s.", optimizer)
if optimizer == "adam":
results = minimize_adam(
function,
Expand Down Expand Up @@ -549,7 +550,7 @@ def _compute_loss_func(self):
def _set_log_density_x(self):
pre_transformation = self.pre_transformation
transform = self.transform
logger.info('Decoding latent density representation.')
logger.info("Decoding latent density representation.")
log_density_x = compute_log_density_x(pre_transformation, transform)
self.log_density_x = log_density_x

Expand All @@ -560,7 +561,7 @@ def _set_log_density_func(self):
mu = self.mu
cov_func = self.cov_func
jitter = self.jitter
logger.info('Computing predictive function.')
logger.info("Computing predictive function.")
log_density_func = compute_conditional_mean(
x,
landmarks,
Expand Down

0 comments on commit 450bbc5

Please sign in to comment.