add sensitivity tutorial to doc #112

Merged
merged 2 commits on Dec 22, 2023
38 changes: 29 additions & 9 deletions docs/notebooks/tutorials/adv_features_relevances.ipynb
@@ -75,6 +75,13 @@
"## Sensitivity analysis"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook we show how to identify the most relevant features for a model using a sensitivity analysis based on partial derivatives. This has been used to gain insights into the workings of the [Deep-LDA](https://pubs.acs.org/doi/10.1021/acs.jpclett.0c00535) and [DeepTICA](https://www.pnas.org/doi/10.1073/pnas.2113533118) CVs."
]
},
{
"attachments": {},
"cell_type": "markdown",
@@ -87,9 +94,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use the DeepLDA CV trained for the [intermolecular aldol reaction](https://colab.research.google.com/github/luigibonati/mlcolvar/blob/main/docs/notebooks/examples/ex_DeepLDA.ipynb) from the [examples](https://mlcolvar.readthedocs.io/en/stable/examples.html) as a case study to explore the sensitivity analysis method. Of course, the same analysis can be applied also to the other CVs.\n",
"\n",
"Note: we have found it to work well with smooth activation functions such as the `shifted_softplus`."
"We will use the DeepLDA CV trained for the [intermolecular aldol reaction](https://colab.research.google.com/github/luigibonati/mlcolvar/blob/main/docs/notebooks/examples/ex_DeepLDA.ipynb) from the [examples](https://mlcolvar.readthedocs.io/en/stable/examples.html) as a case study to explore the sensitivity analysis method. Of course, the same analysis can be applied also to the other CVs."
]
},
{
@@ -235,6 +240,13 @@
"### Features relevance "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Introduction"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -255,14 +267,23 @@
"\n",
"- just the mean: $$s_i = \\frac{1}{N} \\sum_j {\\frac{\\partial s}{\\partial x_i}(\\mathbf{x}^{(j)})}\\ \\sigma_i$$\n",
"\n",
"As we will show below, if we have a labeled dataset, we can restrict this analysis to get per-class statistics by averaging only on a subset of samples ($j \\in A,B,...$)"
"Notes:\n",
"- As we will show below, if we have a labeled dataset, we can restrict this analysis to get per-class statistics by averaging only on a subset of samples ($j \\in A,B,...$)\n",
"- Since it is based on the derivatives of the model, we have found it to work well with smooth activation functions such as the `shifted_softplus`."
]
},
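The averaged-gradient statistic above can be sketched in plain NumPy. This is a minimal illustration, not the mlcolvar implementation: the linear toy CV `s(x) = w · x` and its analytic gradient are assumptions chosen so the expected ranking is easy to verify by hand.

```python
import numpy as np

# Toy CV: s(x) = w . x, so the per-sample gradient ds/dx is simply w.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))      # N samples, 3 input features
w = np.array([2.0, -0.5, 0.0])     # feature 2 is irrelevant by construction

def grad_s(x):
    """Analytic gradient of the linear toy CV (constant for every sample)."""
    return w

# Mean of the absolute gradients, scaled by the feature std sigma_i:
# s_i = (1/N) sum_j |ds/dx_i(x_j)| * sigma_i
sigma = X.std(axis=0)
grads = np.stack([np.abs(grad_s(x)) for x in X])
sensitivity = grads.mean(axis=0) * sigma

# Rank features from most to least sensitive
ranking = np.argsort(sensitivity)[::-1]
print(ranking)  # -> [0 1 2]: feature 0 dominates, feature 2 contributes nothing
```

For a neural-network CV the gradient would come from automatic differentiation instead of a closed form, but the averaging and scaling steps are the same.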
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Perform the sensitivity analysis"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sensitivity analysis is performed by the function `mlcolvar.utils.explain.sensitivity_analysis` which takes a model and a dataset as compulsory arguments. It can directly plot the features but in this case we will do it later and set `plot_mode=None`. "
"The sensitivity analysis is computed by the function `mlcolvar.utils.explain.sensitivity_analysis` which takes a model and a dataset as compulsory arguments. It can directly plot the features but in this case we will do it later and set `plot_mode=None`. "
]
},
{
@@ -285,7 +306,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The results are stored in a dictionary, which contains the `feature_names` (ranked according to the sensitivity), the sensitivity calculated on the whole dataset, and the per-sample gradients (useful to to more analyisis)."
"The results are stored in a dictionary, which contains the `feature_names` (ranked according to the metric), the sensitivity calculated on the whole dataset, and the per-sample gradients (useful for more detailed analysis)."
]
},
{
@@ -332,7 +353,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If the dataset is labeled we can also compute the sensitivity by averaging only on the corresponding subsets:"
"If the dataset is labeled we can also compute the sensitivity by averaging only on the corresponding subsets, using `per_class=True`:"
]
},
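The per-class restriction amounts to averaging the same statistic only over each label's subset of samples. A minimal sketch with stand-in gradient values (the arrays and labels below are made up for illustration, not produced by mlcolvar):

```python
import numpy as np

# Stand-in per-sample |gradient| matrix: 6 samples x 2 features
grads = np.array([[1.0, 0.2],
                  [1.2, 0.1],
                  [0.8, 0.3],
                  [0.1, 1.0],
                  [0.2, 1.1],
                  [0.3, 0.9]])
labels = np.array([0, 0, 0, 1, 1, 1])  # two classes, e.g. states A and B
sigma = np.ones(2)                     # feature standard deviations

# Average the scaled |gradients| only over the samples of each class
per_class = {c: grads[labels == c].mean(axis=0) * sigma
             for c in np.unique(labels)}

print(per_class[0])  # class 0 is dominated by feature 0: [1.  0.2]
print(per_class[1])  # class 1 is dominated by feature 1: [0.2 1. ]
```

Comparing the per-class vectors shows which features matter in which metastable state, which the whole-dataset average would blur together.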
{
@@ -447,8 +468,7 @@
"\n",
"results = sensitivity_analysis(model, dataset, per_class=True, plot_mode=None)\n",
"\n",
"# Plot \n",
"\n",
"# Plot sensitivity\n",
"fig,axs = plt.subplots(1,3,figsize=(12,12))\n",
"\n",
"modes = ['barh','scatter','violin']\n",
1 change: 1 addition & 0 deletions docs/tutorials_advanced.rst
@@ -8,3 +8,4 @@ Customizing CVs
notebooks/tutorials/adv_newcv_scratch.ipynb
notebooks/tutorials/adv_newcv_subclass.ipynb
notebooks/tutorials/adv_preprocessing.ipynb
notebooks/tutorials/adv_features_relevances.ipynb