From 5489c4ba632358a8c9af447cacf56fec8e361378 Mon Sep 17 00:00:00 2001 From: Luigi Bonati Date: Thu, 6 Jun 2024 10:18:24 +0200 Subject: [PATCH] fix tests and warnings --- mlcolvar/explain/lasso.py | 28 +++++++++++++------ mlcolvar/explain/sensitivity.py | 6 ---- ...explain.py => test_explain_sensitivity.py} | 2 +- 3 files changed, 20 insertions(+), 16 deletions(-) rename mlcolvar/tests/{test_utils_explain.py => test_explain_sensitivity.py} (52%) diff --git a/mlcolvar/explain/lasso.py b/mlcolvar/explain/lasso.py index 018eec02..86d7cf18 100644 --- a/mlcolvar/explain/lasso.py +++ b/mlcolvar/explain/lasso.py @@ -113,6 +113,8 @@ def lasso_classification(dataset, if scale_inputs: scaler = StandardScaler(with_mean=True, with_std=True) descriptors = scaler.fit_transform(raw_descriptors) + else: + descriptors = raw_descriptors # Define cross-validation for LASSO, using # a custom scoring function based on accuracy and number of features @@ -162,13 +164,13 @@ def lasso_classification(dataset, print(f'- Accuracy : {accuracy*100:.2f}%') print(f'- # features : {len(selected_coeffs)}\n') print(f'Features: ') - for i,(f,c) in enumerate(zip(selected_feature_names, selected_coeffs)): + for j,(f,c) in enumerate(zip(selected_feature_names, selected_coeffs)): print(f'({i+1}) {f:13s}: {c:.6f}') print('==================================\n') # plot results if plot: - _ = plot_lasso_classification(classifier, feats, coeffs) + plot_lasso_classification(classifier, feats, coeffs) return classifier, feats, coeffs @@ -185,8 +187,11 @@ def plot_lasso_classification(classifier, feats = None, coeffs = None, draw_labe # define figure if axs is None: - fig, axs = plt.subplots(3, n_models, figsize=(6*n_models, 9), sharex=True) + init_axs = True + _, axs = plt.subplots(3, n_models, figsize=(6*n_models, 9), sharex=True) plt.suptitle('LASSO CLASSIFICATION') + else: + init_axs = False for i,key in enumerate(classifier.scores_.keys()): @@ -252,9 +257,9 @@ def plot_lasso_classification(classifier, feats = None, coeffs = None, draw_labe ax.axvline(classifier.C_[i],color='gray',linestyle='dotted') ax.set_xmargin(0) - matplotlib.pyplot.tight_layout() - - return (fig, axs) + if init_axs: + matplotlib.pyplot.tight_layout() + def lasso_regression(dataset, alphas = None, @@ -300,6 +305,8 @@ def lasso_regression(dataset, if scale_inputs: scaler = StandardScaler(with_mean=True, with_std=True) descriptors = scaler.fit_transform(raw_descriptors) + else: + descriptors = raw_descriptors # Define Cross-validation & fit _regressor = LassoCV(alphas=alphas) @@ -355,6 +362,9 @@ def plot_lasso_regression(regressor, feats = None, coeffs = None, draw_labels='a if axs is None: fig, axs = plt.subplots(3, 1, figsize=(6, 9), sharex=True) plt.suptitle('LASSO REGRESSION') + init_axs = True + else: + init_axs = False # (1) COEFFICIENTS PATH ax = axs[0] @@ -406,8 +416,8 @@ def plot_lasso_regression(regressor, feats = None, coeffs = None, draw_labels='a for ax in axs: ax.axvline(regressor.alpha_,color='gray',linestyle='dotted') - plt.tight_layout() - - return (fig, axs) + if init_axs: + matplotlib.pyplot.tight_layout() + diff --git a/mlcolvar/explain/sensitivity.py b/mlcolvar/explain/sensitivity.py index 0e82ede8..a553a4ba 100644 --- a/mlcolvar/explain/sensitivity.py +++ b/mlcolvar/explain/sensitivity.py @@ -202,9 +202,6 @@ def plot_sensitivity(results, mode="violin", per_class=None, max_features = 100, fig = plt.figure(figsize=(5, 0.25 * n_inputs)) ax = fig.add_subplot(111) ax.set_title("Sensitivity Analysis") - return_ax = True - else: - return_ax = False # define utils functions def _set_violin_attributes(violin_parts, color, alpha=0.5, label=None, zorder=None): @@ -294,9 +291,6 @@ def _set_violin_attributes(violin_parts, color, alpha=0.5, label=None, zorder=No ax.axvline(0,color='grey') ax.set_ylim(-1, in_num[-1] + 1) - if return_ax: - return ax - def test_sensitivity_analysis(): from mlcolvar.data import DictDataset from mlcolvar.cvs import DeepLDA diff --git a/mlcolvar/tests/test_utils_explain.py b/mlcolvar/tests/test_explain_sensitivity.py similarity index 52% rename from mlcolvar/tests/test_utils_explain.py rename to mlcolvar/tests/test_explain_sensitivity.py index 87d4cf0e..0b80078d 100644 --- a/mlcolvar/tests/test_utils_explain.py +++ b/mlcolvar/tests/test_explain_sensitivity.py @@ -1,6 +1,6 @@ import pytest -from mlcolvar.utils.explain import test_sensitivity_analysis +from mlcolvar.explain.sensitivity import test_sensitivity_analysis if __name__ == "__main__": test_sensitivity_analysis()