Skip to content

Commit

Permalink
fix tests and warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
luigibonati committed Jun 6, 2024
1 parent e80ca9c commit 5489c4b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
28 changes: 19 additions & 9 deletions mlcolvar/explain/lasso.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def lasso_classification(dataset,
if scale_inputs:
scaler = StandardScaler(with_mean=True, with_std=True)
descriptors = scaler.fit_transform(raw_descriptors)
else:
descriptors = raw_descriptors

# Define cross-validation for LASSO, using
# a custom scoring function based on accuracy and number of features
Expand Down Expand Up @@ -162,13 +164,13 @@ def lasso_classification(dataset,
print(f'- Accuracy : {accuracy*100:.2f}%')
print(f'- # features : {len(selected_coeffs)}\n')
print(f'Features: ')
for i,(f,c) in enumerate(zip(selected_feature_names, selected_coeffs)):
for j,(f,c) in enumerate(zip(selected_feature_names, selected_coeffs)):
print(f'({i+1}) {f:13s}: {c:.6f}')
print('==================================\n')

# plot results
if plot:
_ = plot_lasso_classification(classifier, feats, coeffs)
plot_lasso_classification(classifier, feats, coeffs)

return classifier, feats, coeffs

Expand All @@ -185,8 +187,11 @@ def plot_lasso_classification(classifier, feats = None, coeffs = None, draw_labe

# define figure
if axs is None:
fig, axs = plt.subplots(3, n_models, figsize=(6*n_models, 9), sharex=True)
init_axs = True
_, axs = plt.subplots(3, n_models, figsize=(6*n_models, 9), sharex=True)
plt.suptitle('LASSO CLASSIFICATION')
else:
init_axs = False

for i,key in enumerate(classifier.scores_.keys()):

Expand Down Expand Up @@ -252,9 +257,9 @@ def plot_lasso_classification(classifier, feats = None, coeffs = None, draw_labe
ax.axvline(classifier.C_[i],color='gray',linestyle='dotted')
ax.set_xmargin(0)

matplotlib.pyplot.tight_layout()

return (fig, axs)
if init_axs:
matplotlib.pyplot.tight_layout()


def lasso_regression(dataset,
alphas = None,
Expand Down Expand Up @@ -300,6 +305,8 @@ def lasso_regression(dataset,
if scale_inputs:
scaler = StandardScaler(with_mean=True, with_std=True)
descriptors = scaler.fit_transform(raw_descriptors)
else:
descriptors = raw_descriptors

# Define Cross-validation & fit
_regressor = LassoCV(alphas=alphas)
Expand Down Expand Up @@ -355,6 +362,9 @@ def plot_lasso_regression(regressor, feats = None, coeffs = None, draw_labels='a
if axs is None:
fig, axs = plt.subplots(3, 1, figsize=(6, 9), sharex=True)
plt.suptitle('LASSO REGRESSION')
init_axs = True
else:
init_axs = False

# (1) COEFFICIENTS PATH
ax = axs[0]
Expand Down Expand Up @@ -406,8 +416,8 @@ def plot_lasso_regression(regressor, feats = None, coeffs = None, draw_labels='a
for ax in axs:
ax.axvline(regressor.alpha_,color='gray',linestyle='dotted')

plt.tight_layout()

return (fig, axs)
if init_axs:
matplotlib.pyplot.tight_layout()



6 changes: 0 additions & 6 deletions mlcolvar/explain/sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,6 @@ def plot_sensitivity(results, mode="violin", per_class=None, max_features = 100,
fig = plt.figure(figsize=(5, 0.25 * n_inputs))
ax = fig.add_subplot(111)
ax.set_title("Sensitivity Analysis")
return_ax = True
else:
return_ax = False

# define utils functions
def _set_violin_attributes(violin_parts, color, alpha=0.5, label=None, zorder=None):

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
Expand Down Expand Up @@ -294,9 +291,6 @@ def _set_violin_attributes(violin_parts, color, alpha=0.5, label=None, zorder=No
ax.axvline(0,color='grey')
ax.set_ylim(-1, in_num[-1] + 1)

if return_ax:
return ax

def test_sensitivity_analysis():
from mlcolvar.data import DictDataset
from mlcolvar.cvs import DeepLDA
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from mlcolvar.utils.explain import test_sensitivity_analysis
from mlcolvar.explain.sensitivity import test_sensitivity_analysis

if __name__ == "__main__":
test_sensitivity_analysis()

0 comments on commit 5489c4b

Please sign in to comment.