Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mbar overlap matrix plot #107

Merged
merged 16 commits into from
Aug 26, 2020
32 changes: 32 additions & 0 deletions docs/visualisation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
Visualisation of the results
============================
It is quite often that the user want to visualise the results to gain
confidence on the computed free energy. **alchemlyb** provides various
visualisation tools to help user to judge the estimate.

.. _plot_overlap_matrix:
Overlap Matrix of the MBAR
--------------------------
The accuracy of the :class:`~alchemlyb.estimators.MBAR` estimator depends on
the overlap between different lambda states. The overlap matrix from the
:class:`~alchemlyb.estimators.MBAR` estimator could be plotted to check
the degree of overlap. It is recommended that there should be at least
**0.03** [Klimovich2015]_ overlap between neighboring states. ::

>>> import pandas as pd
>>> from alchemtest.gmx import load_benzene
>>> from alchemlyb.parsing.gmx import extract_u_nk
>>> from alchemlyb.estimators import MBAR

>>> bz = load_benzene().data
>>> u_nk_coul = pd.concat([extract_u_nk(xvg, T=300) for xvg in bz['Coulomb']])
>>> mbar_coul = MBAR()
>>> mbar_coul.fit(u_nk_coul)

>>> from alchemlyb.visualisation.mbar_martix import plot_mbar_omatrix
>>> ax = plot_mbar_omatrix(mbar_coul.overlap_maxtrix)
>>> ax.figure.savefig('O_MBAR.pdf', bbox_inches='tight', pad_inches=0.0)
orbeckst marked this conversation as resolved.
Show resolved Hide resolved

.. [Klimovich2015] Klimovich, P.V., Shirts, M.R. & Mobley, D.L. Guidelines for
the analysis of free energy calculations. J Comput Aided Mol Des 29, 397–411
(2015). https://doi.org/10.1007/s10822-015-9840-9
16 changes: 16 additions & 0 deletions docs/visualisation/alchemlyb.visualisation.plot_mbar_omatrix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
Plot Overlap Matrix from MBar
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MBAR

=============================

The function :func:`~alchemlyb.visualisation.plot_mbar_omatrix` allows the user
to plot the overlap matrix from
:attr:`~alchemlyb.estimators.MBAR.overlap_maxtrix`. The user can pass
:class:`matplotlib.axes.Axes` into the function to have the overlap maxtrix
drawn on a specific axes. The user could also specify a list of lambda states
to be skipped when labelling the states.

Please check :ref:`How to plot MBAR overlap matrix <plot_overlap_matrix>` for
usage.

API Reference
-------------
.. autofunction:: alchemlyb.visualisation.plot_mbar_omatrix
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@
license='BSD',
long_description=open('README.rst').read(),
tests_require = ['pytest', 'alchemtest'],
install_requires=['numpy', 'pandas>=0.23.0', 'pymbar>=3.0.5', 'scipy', 'scikit-learn']
install_requires=['numpy', 'pandas>=0.23.0', 'pymbar>=3.0.5', 'scipy', 'scikit-learn', 'matplotlib']
)
4 changes: 4 additions & 0 deletions src/alchemlyb/estimators/mbar_.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class MBAR(BaseEstimator):
states_ : list
Lambda states for which free energy differences were obtained.

overlap_maxtrix: DataFrame
The overlap matrix.

"""

def __init__(self, maximum_iterations=10000, relative_tolerance=1.0e-7,
Expand Down Expand Up @@ -85,6 +88,7 @@ def fit(self, u_nk):
verbose=self.verbose)

self.states_ = u_nk.columns.values.tolist()
self.overlap_maxtrix = self._mbar.computeOverlap()['matrix']
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make it a managed attribute so that it is not computed by default or make it a method compute_overlap_matrix() -- not sure if @dotsdl has a preference.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maxtrix => matrix


# set attributes
out = self._mbar.getFreeEnergyDifferences(return_theta=True)
Expand Down
23 changes: 23 additions & 0 deletions src/alchemlyb/tests/test_visualisation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

def test_plot_mbar_omatrix():
'''Just test if the plot runs'''
import pandas as pd
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imports at top of test file

from alchemtest.gmx import load_benzene
from alchemlyb.parsing.gmx import extract_u_nk
from alchemlyb.estimators import MBAR

bz = load_benzene().data
u_nk_coul = pd.concat([extract_u_nk(xvg, T=300) for xvg in bz['Coulomb']])
mbar_coul = MBAR()
mbar_coul.fit(u_nk_coul)

orbeckst marked this conversation as resolved.
Show resolved Hide resolved
from alchemlyb.visualisation.mbar_martix import plot_mbar_omatrix
plot_mbar_omatrix(mbar_coul.overlap_maxtrix)
plot_mbar_omatrix(mbar_coul.overlap_maxtrix, [1,])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert isinstance axes


# Bump up coverage
overlap_maxtrix = mbar_coul.overlap_maxtrix
overlap_maxtrix[0,0] = 0.0025
overlap_maxtrix[-1, -1] = 0.9975
plot_mbar_omatrix(overlap_maxtrix)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert that you return an axes object


1 change: 1 addition & 0 deletions src/alchemlyb/visualisation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .mbar_martix import plot_mbar_omatrix
81 changes: 81 additions & 0 deletions src/alchemlyb/visualisation/mbar_martix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Functions for Plotting the overlay matrix for the Mbar estimator.

"""
from __future__ import division
import matplotlib
matplotlib.use('Agg')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove, this is up to the user

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@orbeckst Thank you for the review. I got a problem with CI for py2, where if I don't add this line in. I would get

src/alchemlyb/tests/test_visualisation.py:15: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
src/alchemlyb/visualisation/mbar_martix.py:31: in plot_mbar_omatrix
    fig, ax = plt.subplots(figsize=(size / 2, size / 2))
../../../virtualenv/python2.7.15/lib/python2.7/site-packages/matplotlib/pyplot.py:1184: in subplots
    fig = figure(**fig_kw)
../../../virtualenv/python2.7.15/lib/python2.7/site-packages/matplotlib/pyplot.py:533: in figure
    **kwargs)
../../../virtualenv/python2.7.15/lib/python2.7/site-packages/matplotlib/backend_bases.py:161: in new_figure_manager
    return cls.new_figure_manager_given_figure(num, fig)
../../../virtualenv/python2.7.15/lib/python2.7/site-packages/matplotlib/backends/_backend_tk.py:1046: in new_figure_manager_given_figure
    window = Tk.Tk(className="matplotlib")

As is seen https://travis-ci.org/github/alchemistry/alchemlyb/jobs/720700286, Google suggests that this is one way of solving the problem. I'm not familiar with CI so I'm not sure of what is the best way to process.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@orbeckst I'm sorry for the confusion. I originally put the line matplotlib.use('Agg') in order to avoid the travis failing for py2. I wonder what is your suggestion?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry to be brief: see https://github.com/MDAnalysis/mdanalysis/pull/2798/files#diff-354f30a63fb0907d4ad57269548329e3R39 : add the env var MPLBACKEND: "agg" to the CI.

import matplotlib.pyplot as plt
import numpy as np
def plot_mbar_omatrix(matrix, skip_lambda_index=[], ax=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

write out omatrix => overlap_matrix: just be explicit

'''Plot the MBar overlap matrix.

Parameters
----------
matrix : DataFrame
DataFrame of the overlap matrix obtained from
:attr:`~alchemlyb.estimators.MABR.overlap_maxtrix`
skip_lambda_index : List
list of lambda indices to be omitted from plotting process.
Default: [].
ax : matplotlib.axes.Axes
Matplotlib axes object where the plot will be drawn on. If ax=None,
a new axes will be generated.

Returns
-------
matplotlib.axes.Axes
An axes with the overlap matrix drawn.
'''
# Compute the size of the figure, if ax is not given.
max_prob = matrix.max()
size = len(matrix)
if ax is None:
fig, ax = plt.subplots(figsize=(size / 2, size / 2))
ax.set_xticks([])
ax.set_yticks([])
ax.axis('off')
for i in range(size):
if i != 0:
ax.axvline(x=i, ls='-', lw=0.5, color='k', alpha=0.25)
ax.axhline(y=i, ls='-', lw=0.5, color='k', alpha=0.25)
for j in range(size):
if matrix[j, i] < 0.005:
ii = ''
elif matrix[j, i] > 0.995:
ii = '1.00'
else:
ii = ("{:.2f}".format(matrix[j, i])[1:])
alf = matrix[j, i] / max_prob
ax.fill_between([i, i + 1], [size - j, size - j], [size - (j + 1), size - (j + 1)], color='k', alpha=alf)
ax.annotate(ii, xy=(i, j), xytext=(i + 0.5, size - (j + 0.5)), size=8, textcoords='data', va='center',
ha='center', color=('k' if alf < 0.5 else 'w'))

if skip_lambda_index:
ks = [int(l) for l in skip_lambda_index]
ks = np.delete(np.arange(size + len(ks)), ks)
else:
ks = range(size)
for i in range(size):
ax.annotate(ks[i], xy=(i + 0.5, 1), xytext=(i + 0.5, size + 0.5), size=10, textcoords=('data', 'data'),
va='center', ha='center', color='k')
ax.annotate(ks[i], xy=(-0.5, size - (size - 0.5)), xytext=(-0.5, size - (i + 0.5)), size=10, textcoords=('data', 'data'),
va='center', ha='center', color='k')
ax.annotate('$\lambda$', xy=(-0.5, size - (size - 0.5)), xytext=(-0.5, size + 0.5), size=10, textcoords=('data', 'data'),
va='center', ha='center', color='k')
ax.plot([0, size], [0, 0], 'k-', lw=4.0, solid_capstyle='butt')
ax.plot([size, size], [0, size], 'k-', lw=4.0, solid_capstyle='butt')
ax.plot([0, 0], [0, size], 'k-', lw=2.0, solid_capstyle='butt')
ax.plot([0, size], [size, size], 'k-', lw=2.0, solid_capstyle='butt')

cx = np.repeat(range(size + 1), 2)
cy = sorted(np.repeat(range(size + 1), 2), reverse=True)
ax.plot(cx[2:-1], cy[1:-2], 'k-', lw=2.0)
ax.plot(np.array(cx[2:-3]) + 1, cy[1:-4], 'k-', lw=2.0)
ax.plot(cx[1:-2], np.array(cy[:-3]) - 1, 'k-', lw=2.0)
ax.plot(cx[1:-4], np.array(cy[:-5]) - 2, 'k-', lw=2.0)

ax.set_xlim(-1, size)
ax.set_ylim(0, size + 1)
return ax