Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sensitivity analysis #108

Merged
merged 15 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 172 additions & 36 deletions docs/notebooks/examples/ex_DeepLDA.ipynb

Large diffs are not rendered by default.

579 changes: 579 additions & 0 deletions docs/notebooks/tutorials/adv_features_relevances.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mlcolvar/cvs/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
class BaseCV:
"""
Base collective variable class.

To inherit from this class, the class must define a BLOCKS class attribute.
"""

Expand Down
48 changes: 33 additions & 15 deletions mlcolvar/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,13 @@ def __init__(
# Keeping this private for now. Changing it at runtime would
# require changing dataset_split and the dataloaders.
self._random_split = random_split

# save generator if given, otherwise set it to torch.default_generator
self.generator = generator if generator is not None else default_generator
if self.generator is not None and not self._random_split:
warnings.warn("A torch.generator was provided but it is not used with random_split=False")
warnings.warn(
"A torch.generator was provided but it is not used with random_split=False"
)

# Make sure batch_size and shuffle are lists.
if isinstance(batch_size, int):
Expand Down Expand Up @@ -219,11 +221,13 @@ def __repr__(self) -> str:

def _split(self, dataset):
    """Perform the random or sequential splitting of a single dataset.

    The split sizes are taken from ``self.lengths``; ``self._random_split``
    selects random vs. sequential splitting and ``self.generator`` is
    forwarded to ``split_dataset`` for the random case.

    Returns a list of Subset[DictDataset] objects.
    """
    # Call split_dataset exactly once: a second call would draw a second
    # permutation from the generator and discard the first one.
    return split_dataset(
        dataset, self.lengths, self._random_split, self.generator
    )

def _check_setup(self):
Expand All @@ -234,10 +238,13 @@ def _check_setup(self):
"outside a Lightning trainer please call .setup() first."
)

def split_dataset(dataset,
lengths: Sequence,
random_split : bool,
generator : Optional[torch.Generator] = default_generator) -> list:

def split_dataset(
dataset,
lengths: Sequence,
random_split: bool,
generator: Optional[torch.Generator] = default_generator,
) -> list:
"""
Sequentially or randomly split a dataset into non-overlapping new datasets of given lengths.

Expand Down Expand Up @@ -271,15 +278,22 @@ def split_dataset(dataset,
lengths = subset_lengths
for i, length in enumerate(lengths):
if length == 0:
warnings.warn(f"Length of split at index {i} is 0. "
f"This might result in an empty dataset.")
warnings.warn(
f"Length of split at index {i} is 0. "
f"This might result in an empty dataset."
)

# Cannot verify that dataset is Sized
if sum(lengths) != len(dataset): # type: ignore[arg-type]
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
if sum(lengths) != len(dataset): # type: ignore[arg-type]
raise ValueError(
"Sum of input lengths does not equal the length of the input dataset!"
)
if random_split:
indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[call-overload]
return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]
return [
Subset(dataset, indices[offset - length : offset])
for offset, length in zip(_accumulate(lengths), lengths)
]
else:
return [
Subset(dataset, np.arange(offset - length, offset))
Expand All @@ -302,8 +316,12 @@ def sequential_split(dataset, lengths: Sequence) -> list:
until there are no remainders left.
"""

warnings.warn("The function sequential_split is deprecated, use split_dataset(.., .., random_split=False, ..)", FutureWarning, stacklevel=2)

warnings.warn(
"The function sequential_split is deprecated, use split_dataset(.., .., random_split=False, ..)",
FutureWarning,
stacklevel=2,
)

return split_dataset(dataset=dataset, lengths=lengths, random_split=False)


Expand Down
18 changes: 17 additions & 1 deletion mlcolvar/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ class DictDataset(Dataset):
'weights' : np.asarray([0.5,1.5,1.5,0.5]) }
"""

def __init__(self, dictionary: dict = None, **kwargs):
def __init__(self, dictionary: dict = None, feature_names=None, **kwargs):
"""Create a Dataset from a dictionary or from a list of kwargs.

Parameters
----------
dictionary : dict
Dictionary with names and tensors
feature_names : array-like
List or numpy array with feature names

"""
# assert type dict
Expand All @@ -44,6 +46,9 @@ def __init__(self, dictionary: dict = None, **kwargs):
# save dictionary
self._dictionary = dictionary

# save feature names
self.feature_names = feature_names

# check that all elements of dict have same length
it = iter(dictionary.values())
self.length = len(next(it))
Expand Down Expand Up @@ -101,6 +106,17 @@ def __repr__(self) -> str:
def keys(self):
return tuple(self._dictionary.keys())

@property
def feature_names(self):
    """Names of the input features (numpy array of str), or None if unset."""
    return self._feature_names

@feature_names.setter
def feature_names(self, value):
    # None means "no names available" and is stored untouched; anything
    # else is converted to a numpy array of strings.
    if value is None:
        self._feature_names = None
        return
    self._feature_names = np.asarray(value, dtype=str)


def test_DictDataset():
# from list
Expand Down
6 changes: 6 additions & 0 deletions mlcolvar/tests/test_utils_explain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import pytest
Dismissed Show dismissed Hide dismissed

from mlcolvar.utils.explain import test_sensitivity_analysis

# Allow running this module directly as a script: the imported test
# function is then invoked explicitly.
if __name__ == "__main__":
test_sensitivity_analysis()
177 changes: 177 additions & 0 deletions mlcolvar/utils/explain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import numpy as np
import torch

from mlcolvar.utils.plot import plot_sensitivity


def sensitivity_analysis(
    model,
    dataset,
    std=None,
    feature_names=None,
    metric="mean_abs_val",
    per_class=False,
    plot_mode="violin",
    ax=None,
):
    r"""Perform a sensitivity analysis using the partial derivatives method.

    This allows us to measure which input features the model is most sensitive to
    (i.e., which quantities produce significant changes in the output).

    To do this, the partial derivatives of the model with respect to each input
    :math:`x_i` are computed over a dataset of `N` points
    :math:`\{\mathbf{x}^{(j)}\}_{j=1}^N`. These values are multiplied by the
    standard deviation :math:`\sigma_i` of each feature (either given through
    `std` or computed from the dataset).

    Then, an average sensitivity value :math:`s_i` is computed, either as the
    mean absolute value (metric=`MAV`):

    .. math:: s_i = \frac{1}{N} \sum_j \left|\frac{\partial s}{\partial x_i}(\mathbf{x}^{(j)})\right| \sigma_i

    or as the root mean square (metric=`RMS`):

    .. math:: s_i = \sqrt{\frac{1}{N} \sum_j \left(\frac{\partial s}{\partial x_i}(\mathbf{x}^{(j)}) \sigma_i\right)^2}

    Alternatively, one can compute the plain average, without taking the
    absolute values (metric=`mean`).

    In all cases the scaled gradients are normalized such that their mean
    absolute values, summed over the features, equal 1 (hence the `MAV`
    sensitivities sum to 1).

    If the dataset is labeled, these quantities can also be computed on the
    subset of the data belonging to each class.

    See also
    --------
    mlcolvar.utils.plot.plot_sensitivity
        Plot the sensitivity analysis results

    Parameters
    ----------
    model : mlcolvar.cvs.BaseCV
        collective variable model
    dataset : mlcolvar.data.DictDataset
        dataset on which to compute the sensitivity analysis.
    std : array_like, optional
        standard deviation of the features, by default it will be computed from the dataset
    feature_names : array-like, optional
        array-like with input features names, by default they will be taken from the dataset if available
    metric : str, optional
        sensitivity measure ('mean_abs_val'|'MAV','root_mean_square'|'RMS','mean'), by default 'mean_abs_val'
    per_class : bool, optional
        if the dataset has labels, compute also the sensitivity per class, by default False
    plot_mode : str, optional
        how to visualize the results ('violin','barh','scatter'), by default 'violin'.
        If None, nothing is plotted.
    ax : matplotlib.axis, optional
        ax where to plot the results, by default it will be initialized

    Returns
    -------
    results: dictionary
        results of the sensitivity analysis, containing 'feature_names', the
        'sensitivity' and the 'gradients' per sample. Features are ordered by
        increasing absolute sensitivity. 'sensitivity' and 'gradients' are
        dictionaries with a 'Dataset' entry plus one 'State {label}' entry per
        class when `per_class` is requested.

    Raises
    ------
    NotImplementedError
        if `metric` is not one of the supported values.
    KeyError
        if `per_class` is requested but the dataset has no 'labels' entry.
    """
    # validate the metric before doing any (possibly expensive) computation
    mean_metrics = ("mean_abs_val", "MEAN_ABS", "MAV", "mean")
    rms_metrics = ("root_mean_square", "rms", "RMS")
    if metric not in mean_metrics + rms_metrics:
        raise NotImplementedError(
            "only `mean_abs_val` (MAV), `root_mean_square` (RMS), or `mean` metrics are allowed"
        )

    # get dataset: use a detached view so that enabling gradients does not
    # flip `requires_grad` on the tensor stored inside the caller's dataset
    X = dataset["data"].detach().requires_grad_(True)
    n_inputs = X.shape[1]

    # get feature names (argument > dataset attribute > 1-based indices)
    if feature_names is None:
        feature_names = getattr(dataset, "feature_names", None)
    if feature_names is None:
        feature_names = [str(i + 1) for i in range(n_inputs)]
    feature_names = np.asarray(feature_names)

    # get standard deviation
    if std is None:
        std = dataset.get_stats()["data"]["std"].detach().cpu().numpy()
    else:
        std = np.asarray(std)

    # compute cv
    s = model(X)

    # get gradients (sum over output components via grad_outputs of ones)
    grad = torch.autograd.grad(s, X, grad_outputs=torch.ones_like(s))[0]
    grad = grad.detach().cpu().numpy()
    if metric != "mean":
        grad = np.abs(grad)

    # multiply grad_xi by std_xi
    grad = grad * std

    # normalize such that the average of the abs sums to 1
    grad /= np.abs(grad).mean(axis=0).sum()

    # get metrics
    def _compute_score(values):
        if metric in rms_metrics:
            return np.sqrt((values**2).mean(axis=0))
        # MAV and plain mean share the same reduction: the absolute value
        # (if any) has already been applied to the gradients above
        return values.mean(axis=0)

    score = _compute_score(grad)

    # sort features based on (absolute) sensitivity
    index = np.abs(score).argsort()
    feature_names = feature_names[index]
    score = score[index]
    grad = grad[:, index]

    # store into results
    out = {
        "feature_names": feature_names,
        "sensitivity": {"Dataset": score},
        "gradients": {"Dataset": grad},
    }

    # per class statistics
    if per_class:
        try:
            labels = dataset["labels"]
        except KeyError as err:
            raise KeyError(
                "Per class analyis requested but no labels found in the given dataset."
            ) from err
        labels = labels.detach().cpu().numpy().astype(int)

        for label in np.unique(labels):
            mask = np.argwhere(labels == label)[:, 0]
            grad_label = grad[mask, :]
            out["sensitivity"][f"State {label}"] = _compute_score(grad_label)
            out["gradients"][f"State {label}"] = grad_label

    # plot
    if plot_mode is not None:
        plot_sensitivity(out, mode=plot_mode, ax=ax)

    return out


def test_sensitivity_analysis():
    """Smoke-test `sensitivity_analysis` on a small random labeled dataset.

    The analysis is run for every combination of `per_class` and
    `feature_names` inputs, and the structure of the returned dictionary
    is checked.
    """
    from mlcolvar.data import DictDataset
    from mlcolvar.cvs import DeepLDA

    n_states = 2
    in_features, out_features = 2, n_states - 1
    layers = [in_features, 5, 5, out_features]

    # create dataset
    samples = 10
    X = torch.randn((samples * n_states, in_features))

    # create labels: the i-th block of `samples` points gets label i
    y = torch.zeros(X.shape[0])
    for i in range(1, n_states):
        y[samples * i :] += 1

    dataset = DictDataset({"data": X, "labels": y})

    # define CV
    opts = {
        "nn": {"activation": "shifted_softplus"},
    }
    model = DeepLDA(layers, n_states, options=opts)

    # feature importances
    for per_class in [True, False, None]:
        for names in [None, ["x", "y"], np.asarray(["x", "y"])]:
            results = sensitivity_analysis(
                model, dataset, feature_names=names, per_class=per_class, plot_mode=None
            )

            # check the structure of the results instead of discarding them
            assert len(results["feature_names"]) == in_features
            assert results["sensitivity"]["Dataset"].shape == (in_features,)
            assert results["gradients"]["Dataset"].shape == X.shape

            expected_keys = {"Dataset"}
            if per_class:
                expected_keys |= {f"State {i}" for i in range(n_states)}
            assert set(results["sensitivity"]) == expected_keys
            assert set(results["gradients"]) == expected_keys
2 changes: 1 addition & 1 deletion mlcolvar/utils/fes.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def test_compute_fes():

Y = np.random.rand(2, 100)

if SKLEARN_IS_INSTALLED: # TODO: change to use pytest functionalities?
if SKLEARN_IS_INSTALLED: # TODO: change to use pytest functionalities?
fes, bins, bounds, error_ = compute_fes(
X=[Y[0], Y[1]],
weights=np.ones_like(X),
Expand Down
2 changes: 1 addition & 1 deletion mlcolvar/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def create_dataset_from_files(
dictionary = {"data": torch.Tensor(df_data.values)}
if create_labels:
dictionary["labels"] = torch.Tensor(df["labels"].values)
dataset = DictDataset(dictionary)
dataset = DictDataset(dictionary, feature_names=df_data.columns.values)

if return_dataframe:
return dataset, df
Expand Down
Loading