Skip to content

Commit

Permalink
added mean metric to sensitivity
Browse files Browse the repository at this point in the history
  • Loading branch information
luigibonati committed Dec 18, 2023
1 parent 3b08667 commit 9674613
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
23 changes: 13 additions & 10 deletions mlcolvar/utils/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def sensitivity_analysis(
plot_mode="violin",
ax=None,
):
"""Perform a sensitivity analysis to measure which input features the model is most sensitive to (i.e., which quantities produce significant changes in the output).
"""Perform a sensitivity analysis using the partial derivatives method. This allows us to measure which input features the model is most sensitive to (i.e., which quantities produce significant changes in the output).
To do this, the partial derivatives of the model with respect to each input :math:`x_i` are computed over a set of `N` points of a :math:`$$\{\mathbf{x}^{(j)}\}_{j=1} ^N$$` dataset.
These values, in the case where the dataset is not standardized, are multiplied by the standard deviation of the features over the dataset.
Expand All @@ -24,7 +24,9 @@ def sensitivity_analysis(
or as the root mean square (metric=`RMS`):
.. math:: s_i = \sqrt{\frac{1}{N} \sum_j \left({\frac{\partial s}{\partial x_i}(\mathbf{x}^{(j)})}\ \sigma_i\right)^2 }
The sensitivity values are normalized such that they sum to 1.
In alternative, one can also compute simply average, without taking the absolute values (metric=`mean`).
In all the above cases, the sensitivity values are normalized such that they sum to 1.
In case in which a labeled dataset these quantities can be computed also on the subset of the data belonging to each class.
Expand All @@ -42,7 +44,7 @@ def sensitivity_analysis(
feature_names : _type_, optional
array-like with input features names, by default it takes them from the dataset if available
metric : str, optional
sensitivity measure ('mean_abs_val'|'MAV','root_mean_square'|'RMS')', by default 'mean_abs_val'
sensitivity measure ('mean_abs_val'|'MAV','root_mean_square'|'RMS','mean'), by default 'mean_abs_val'
per_class : bool, optional
if the dataset has labels, compute also the sensitivity per class, by default False
plot_mode : str, optional
Expand Down Expand Up @@ -75,30 +77,31 @@ def sensitivity_analysis(
# get gradients
grad_output = torch.ones_like(s)
grad = torch.autograd.grad(s, X, grad_outputs=grad_output)[0].detach().cpu().numpy()
grad = np.abs(grad)
if metric != "mean":
grad = np.abs(grad)

# multiply grad_xi by std_xi
grad = grad * std

# normalize such that the averages sums to 1
grad /= grad.mean(axis=0).sum()
# normalize such that the average of the abs sums to 1
grad /= np.abs(grad).mean(axis=0).sum()

# get metrics
def _compute_score(grad, metric):
if (metric == "mean_abs_val") | (metric == "MEAN_ABS") | (metric == "MAV"):
if (metric == "mean_abs_val") | (metric == "MEAN_ABS") | (metric == "MAV") | (metric == 'mean'):
score = grad.mean(axis=0)
elif (metric == "root_mean_square") | (metric == "rms") | (metric == "RMS"):
score = np.sqrt((grad**2).mean(axis=0))
else:
raise NotImplementedError(
"only mean_abs_value (MAV) or root_mean_square (RMS) metrics are allowed"
"only `mean_abs_value` (MAV) or `root_mean_square` (RMS), or `mean` metrics are allowed"
)
return score

score = _compute_score(grad, metric)

# sort features based on score
index = score.argsort()
# sort features based on (absolute) sensitivity
index = np.abs(score).argsort()
feature_names = np.asarray(feature_names)[index]
score = score[index]
grad = grad[:, index]
Expand Down
5 changes: 4 additions & 1 deletion mlcolvar/utils/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,10 @@ def _set_violin_attributes(violin_parts, color, alpha=0.5, label=None, zorder=No
ax.legend(*zip(*patch_labels), loc="lower right", frameon=False)
else:
ax.legend(loc="lower right", frameon=False)
ax.set_xlim(0, None)
if np.min(results["sensitivity"]["Dataset"])>=0:
ax.set_xlim(0, None)
else:
ax.axvline(0,color='grey')
ax.set_ylim(-1, in_num[-1] + 1)

if return_ax:
Expand Down

0 comments on commit 9674613

Please sign in to comment.