From 10f032fefd3e099df86726f1b855f0317658b874 Mon Sep 17 00:00:00 2001 From: Lukas Hergt Date: Thu, 2 Nov 2023 11:58:03 -0700 Subject: [PATCH] implement axes logscales (#328) * add logscale capability to `kde_contour_plot_2d` * add logscale capability to `fastkde_contour_plot_2d`, and remove `set_xlim` which should really be determined by the contourplots * add logscale capability to `hist_plot_2d` * presumably fix `test_hist_plot_2d` which did not create new figures despite checking for axes limits * add logscale capability to `kde_plot_1d` * add logscale capability to `fastkde_plot_1d` * add `logx` and `logxy` kwargs to `make_1d_axes`, `make_2d_axes`, `plot_1d`, and `plot_2d` which takes a list of parameters that are to be plotted on a log-scale * add logscale capability to `hist_plot_1d` * fix `logx` and `logxy` behaviour to work for empty lists and both for already existing axes and newly to create axes * add log-scale capability to `hist_plot_1d` when it is called on its own * add tests for log-scale capability of the anesthetic plotting functions * attempt setting logscale only if ax not None * move `local_kwargs` instantiation until after `logxy` is popped from kwargs * add test for logscale creation in `make_1d_axes` and `make_2d_axes` * add test for correct logscale handling of `Samples.plot_1d` and `Samples.plot_2d` * version bump to 2.2.0 * skip logscale tests for fastkde if not installed * add `logx`, `logxy`, and `label` to the docstring of `plot_1d` and/or `plot_2d` * add documentation for log-scale usage * add test checking for ValueError if log-axes do not match in repeated calls to `plot_1d` or `plot_2d` * add minimum requirement of `Sphinx>=4.2.0` * add minimum requirement of `sphinx_rtd_theme>=1.2.2` * version bump to 2.3.0 * version bump to 2.4.0 * fix logscaling post master merge * add test for combination of logscale hist plot with bins and range kwargs * add test for combination of logscale hist plot with bins and range kwargs in test_samples 
* fix `hist_plot_1d` for various combinations of `bins` and `range` kwargs with logscale, and remove astropy option since we now have string input for bins independent of astropy * add `noqa: disable=D101` to suppress `missing docstring` * change tests involving astropy to now test for ValueError, since astropy has been removed now that automatic bin computation is directly integrated in anesthetic * change `logxy` to independent `logx` and `logy` * adjust docs to new `logx` and `logy` kwargs instead of the `logxy` kwarg * change docs on logscale to plot more parameters to see the different `logx` and `logy` behaviour * change docs on logscale to include descriptive legends * version bump to 2.5.0 * add pytest warning capture for hist * replace allsegs with get_paths * fix range of hist1d as data has already been log-transformed * get_paths only works on matplotlib>=3.8.0 * split logscale tests up for 1d and 2d, to make it easier to identify where issues reside * do not emit to orthogonal axes when logx and logy are different, and use data bounds for fastkde by default * revert the data-constrained bounds from fastkde: they help with sharp bounds, but they ruin Gaussian results. This is very different from gaussian_kde, so for fastkde we would actually need to provide prior bounds; it is not good enough to infer bounds from the data... 
* add test for setting limits when horizontal axes are linear and vertical are logarithmic or vice versa, where those limits should not emit * simplify tests from previous commit to focus only on the cases that could actually fail * make new tests compatible with python 3.9 * clean up warnings import setup --------- Co-authored-by: AdamOrmondroyd --- README.rst | 2 +- anesthetic/_version.py | 2 +- anesthetic/plot.py | 182 +++++++++++++++++----- anesthetic/plotting/_matplotlib/hist.py | 61 +++++--- anesthetic/samples.py | 55 +++++-- docs/source/plotting.rst | 29 +++- tests/test_plot.py | 192 +++++++++++++++++++++++- tests/test_samples.py | 133 +++++++++++++++- 8 files changed, 580 insertions(+), 76 deletions(-) diff --git a/README.rst b/README.rst index 0beff72a..7a289866 100644 --- a/README.rst +++ b/README.rst @@ -2,7 +2,7 @@ anesthetic: nested sampling post-processing =========================================== :Authors: Will Handley and Lukas Hergt -:Version: 2.4.2 +:Version: 2.5.0 :Homepage: https://github.com/handley-lab/anesthetic :Documentation: http://anesthetic.readthedocs.io/ diff --git a/anesthetic/_version.py b/anesthetic/_version.py index cb9dc8a9..e59b17b4 100644 --- a/anesthetic/_version.py +++ b/anesthetic/_version.py @@ -1 +1 @@ -__version__ = '2.4.2' +__version__ = '2.5.0' diff --git a/anesthetic/plot.py b/anesthetic/plot.py index b25f60d9..1b1e065a 100644 --- a/anesthetic/plot.py +++ b/anesthetic/plot.py @@ -43,6 +43,8 @@ class AxesSeries(Series): labels : dict(str:str), optional Dictionary mapping params to plot labels. Default: params + logx : list(str), optional + List of parameters to be plotted on a log scale. gridspec_kw : dict, optional Dict with keywords passed to the :class:`matplotlib.gridspec.GridSpec` constructor used to create the grid the subplots are placed on. 
@@ -58,14 +60,23 @@ class AxesSeries(Series): """ + _metadata = ['_logx'] + _logx = [] + def __init__(self, data=None, index=None, fig=None, ncol=None, labels=None, - gridspec_kw=None, subplot_spec=None, *args, **kwargs): + logx=None, gridspec_kw=None, subplot_spec=None, + *args, **kwargs): if data is None and index is not None: data = self.axes_series(index=index, fig=fig, ncol=ncol, gridspec_kw=gridspec_kw, subplot_spec=subplot_spec) self._set_xlabels(axes=data, labels=labels) super().__init__(data=data, index=index, *args, **kwargs) + if logx is None: + self._logx = [] + else: + self._logx = logx + self._set_xscale() @property def _constructor(self): @@ -100,6 +111,11 @@ def axes_series(index, fig, ncol=None, gridspec_kw=None, ax.set_yticks([]) return axes + def _set_xscale(self): + for p, ax in self.items(): + if p in self._logx: + ax.set_xscale('log') + @staticmethod def _set_xlabels(axes, labels, **kwargs): if labels is None: @@ -148,6 +164,9 @@ class AxesDataFrame(DataFrame): If 'outer', plot ticks only on the very left and very bottom. If 'inner', plot ticks also in inner subplots. If None, plot no ticks at all. + logx, logy : list(str), optional + Lists of parameters to be plotted on a log scale on the x-axis or + y-axis, respectively. gridspec_kw : dict, optional Dict with keywords passed to the :class:`matplotlib.gridspec.GridSpec` constructor used to create the grid the subplots are placed on. 
@@ -171,10 +190,14 @@ class AxesDataFrame(DataFrame): """ + _metadata = ['_logx', '_logy'] + _logx = [] + _logy = [] + def __init__(self, data=None, index=None, columns=None, fig=None, lower=True, diagonal=True, upper=True, labels=None, - ticks='inner', gridspec_kw=None, subplot_spec=None, - *args, **kwargs): + ticks='inner', logx=None, logy=None, + gridspec_kw=None, subplot_spec=None, *args, **kwargs): if data is None and index is not None and columns is not None: position = self._position_frame(index=index, columns=columns, @@ -193,6 +216,16 @@ def __init__(self, data=None, index=None, columns=None, fig=None, index=index, columns=columns, *args, **kwargs) + if logx is None: + self._logx = [] + else: + self._logx = logx + if logy is None: + self._logy = [] + else: + self._logy = logy + if self._logx or self._logy: + self._set_scale() self.tick_params(axis='both', which='both', labelrotation=45, labelsize='small') @@ -289,15 +322,19 @@ def _make_diagonal(ax): class DiagonalAxes(type(ax)): def set_xlim(self, left=None, right=None, emit=True, auto=False, xmin=None, xmax=None): - super().set_ylim(bottom=left, top=right, emit=True, auto=auto, - ymin=xmin, ymax=xmax) + if (self.get_xaxis().get_scale() == + self.get_yaxis().get_scale()): + super().set_ylim(bottom=left, top=right, emit=True, + auto=auto, ymin=xmin, ymax=xmax) return super().set_xlim(left=left, right=right, emit=emit, auto=auto, xmin=xmin, xmax=xmax) def set_ylim(self, bottom=None, top=None, emit=True, auto=False, ymin=None, ymax=None): - super().set_xlim(left=bottom, right=top, emit=True, auto=auto, - xmin=ymin, xmax=ymax) + if (self.get_xaxis().get_scale() == + self.get_yaxis().get_scale()): + super().set_xlim(left=bottom, right=top, emit=True, + auto=auto, xmin=ymin, xmax=ymax) return super().set_ylim(bottom=bottom, top=top, emit=emit, auto=auto, ymin=ymin, ymax=ymax) @@ -342,6 +379,15 @@ def set_ylim(self, bottom=None, top=None, emit=True, auto=False, ax.__class__ = OffDiagonalAxes + def 
_set_scale(self): + for y, rows in self.iterrows(): + for x, ax in rows.items(): + if ax is not None: + if x in self._logx: + ax.set_xscale('log') + if y in self._logy: + ax.set_yscale('log') + @staticmethod def _set_labels(axes, labels, **kwargs): all_params = list(axes.columns) + list(axes.index) @@ -559,7 +605,7 @@ def scatter(self, params, lower=True, upper=True, **kwargs): ax.scatter(params[x], params[y], zorder=z, **kwargs) -def make_1d_axes(params, ncol=None, labels=None, +def make_1d_axes(params, ncol=None, labels=None, logx=None, gridspec_kw=None, subplot_spec=None, **fig_kw): """Create a set of axes for plotting 1D marginalised posteriors. @@ -576,6 +622,9 @@ def make_1d_axes(params, ncol=None, labels=None, Dictionary mapping params to plot labels. Default: params + logx : list(str), optional + List of parameters to be plotted on a log scale. + gridspec_kw : dict, optional Dict with keywords passed to the :class:`matplotlib.gridspec.GridSpec` constructor used to create the grid the subplots are placed on. @@ -611,6 +660,7 @@ def make_1d_axes(params, ncol=None, labels=None, fig=fig, ncol=ncol, labels=labels, + logx=logx, gridspec_kw=gridspec_kw, subplot_spec=subplot_spec) if gridspec_kw is None: @@ -619,7 +669,8 @@ def make_1d_axes(params, ncol=None, labels=None, def make_2d_axes(params, labels=None, lower=True, diagonal=True, upper=True, - ticks='inner', gridspec_kw=None, subplot_spec=None, **fig_kw): + ticks='inner', logx=None, logy=None, + gridspec_kw=None, subplot_spec=None, **fig_kw): """Create a set of axes for plotting 2D marginalised posteriors. Parameters @@ -647,6 +698,10 @@ def make_2d_axes(params, labels=None, lower=True, diagonal=True, upper=True, * ``'inner'``: plot ticks also in inner subplots. * ``None``: plot no ticks at all. + logx, logy : list(str), optional + Lists of parameters to be plotted on a log scale on the x-axis or + y-axis, respectively. 
+ gridspec_kw : dict, optional Dict with keywords passed to the :class:`matplotlib.gridspec.GridSpec` constructor used to create the grid the subplots are placed on. @@ -689,6 +744,8 @@ def make_2d_axes(params, labels=None, lower=True, diagonal=True, upper=True, upper=upper, labels=labels, ticks=ticks, + logx=logx, + logy=logy, gridspec_kw=gridspec_kw, subplot_spec=subplot_spec) fig.align_labels() @@ -739,6 +796,8 @@ def fastkde_plot_1d(ax, data, *args, **kwargs): """ kwargs = normalize_kwargs(kwargs) + if ax.get_xaxis().get_scale() == 'log': + data = np.log10(data) xmin = kwargs.pop('xmin', None) xmax = kwargs.pop('xmax', None) levels = kwargs.pop('levels', [0.95, 0.68]) @@ -768,6 +827,8 @@ def fastkde_plot_1d(ax, data, *args, **kwargs): i = ((x > quantile(x, q[0], p)) & (x < quantile(x, q[-1], p))) area = np.trapz(x=x[i], y=p[i]) if density else 1 + if ax.get_xaxis().get_scale() == 'log': + x = 10**x ans = ax.plot(x[i], p[i]/area, color=color, *args, **kwargs) if facecolor and facecolor not in [None, 'None', 'none']: @@ -853,6 +914,8 @@ def kde_plot_1d(ax, data, *args, **kwargs): if weights is not None: data = data[weights != 0] weights = weights[weights != 0] + if ax.get_xaxis().get_scale() == 'log': + data = np.log10(data) ncompress = kwargs.pop('ncompress', False) nplot = kwargs.pop('nplot_1d', 100) @@ -887,6 +950,8 @@ def kde_plot_1d(ax, data, *args, **kwargs): pp = cut_and_normalise_gaussian(x, p, bw, xmin=data.min(), xmax=data.max()) pp /= pp.max() area = np.trapz(x=x, y=pp) if density else 1 + if ax.get_xaxis().get_scale() == 'log': + x = 10**x ans = ax.plot(x, pp/area, color=color, *args, **kwargs) if facecolor and facecolor not in [None, 'None', 'none']: @@ -947,7 +1012,7 @@ def hist_plot_1d(ax, data, *args, **kwargs): """ kwargs = normalize_kwargs(kwargs) weights = kwargs.pop('weights', None) - bins = kwargs.pop('bins', 10) + bins = kwargs.pop('bins', 'fd') histtype = kwargs.pop('histtype', 'bar') density = kwargs.get('density', False) @@ -958,24 
+1023,36 @@ def hist_plot_1d(ax, data, *args, **kwargs): q = kwargs.pop('q', 5) q = quantile_plot_interval(q=q) + if ax.get_xaxis().get_scale() == 'log': + data = np.log10(data) xmin = quantile(data, q[0], weights) xmax = quantile(data, q[-1], weights) - range = kwargs.pop('range', (xmin, xmax)) - - if isinstance(bins, str) and bins in ['fd', 'scott', 'sqrt']: - bins = histogram_bin_edges(data, - weights=weights, - bins=bins, - beta=kwargs.pop('beta', 'equal'), - range=range) + if 'range' in kwargs and ax.get_xaxis().get_scale() == 'log': + range = kwargs.pop('range') + if range is not None: + range = (np.log10(range[0]), np.log10(range[1])) + else: + range = (data.min(), data.max()) + else: + range = kwargs.pop('range', (xmin, xmax)) + if isinstance(bins, (int, str)): + if isinstance(bins, int): + bins = np.linspace(range[0], range[1], bins+1) + elif isinstance(bins, str) and bins in ['fd', 'scott', 'sqrt']: + bins = histogram_bin_edges(data, + weights=weights, + bins=bins, + beta=kwargs.pop('beta', 'equal'), + range=range) + if ax.get_xaxis().get_scale() == 'log': + bins = 10 ** bins + if ax.get_xaxis().get_scale() == 'log': + data = 10**data + range = (10**range[0], 10**range[1]) if isinstance(bins, str) and bins in ['knuth', 'freedman', 'blocks']: - try: - from astropy.visualization import hist - h, edges, bars = hist(data, ax=ax, bins=bins, - range=range, histtype=histtype, - color=color, *args, **kwargs) - except ImportError: - raise ImportError("You need to install astropy to use astropyhist") + raise ValueError("The astropy strings 'knuth', 'freedman', and " + "'blocks' are no longer supported. 
Please use the" + "similar 'fd', 'scott', or 'sqrt' from now on.") else: h, edges, bars = ax.hist(data, weights=weights, bins=bins, range=range, histtype=histtype, @@ -1034,6 +1111,14 @@ def fastkde_contour_plot_2d(ax, data_x, data_y, *args, **kwargs): xmax = kwargs.pop('xmax', None) ymin = kwargs.pop('ymin', None) ymax = kwargs.pop('ymax', None) + if ax.get_xaxis().get_scale() == 'log': + data_x = np.log10(data_x) + xmin = None if xmin is None else np.log10(xmin) + xmax = None if xmax is None else np.log10(xmax) + if ax.get_yaxis().get_scale() == 'log': + data_y = np.log10(data_y) + ymin = None if ymin is None else np.log10(ymin) + ymax = None if ymax is None else np.log10(ymax) label = kwargs.pop('label', None) zorder = kwargs.pop('zorder', 1) levels = kwargs.pop('levels', [0.95, 0.68]) @@ -1060,6 +1145,11 @@ def fastkde_contour_plot_2d(ax, data_x, data_y, *args, **kwargs): i = (pdf >= levels[0]*0.5).any(axis=0) j = (pdf >= levels[0]*0.5).any(axis=1) + if ax.get_xaxis().get_scale() == 'log': + x = 10**x + if ax.get_yaxis().get_scale() == 'log': + y = 10**y + if facecolor not in [None, 'None', 'none']: linewidths = kwargs.pop('linewidths', 0.5) contf = ax.contourf(x[i], y[j], pdf[np.ix_(j, i)], levels, cmap=cmap, @@ -1084,8 +1174,6 @@ def fastkde_contour_plot_2d(ax, data_x, data_y, *args, **kwargs): vmin=vmin, vmax=vmax, linewidths=linewidths, colors=edgecolor, cmap=cmap, *args, **kwargs) - ax.set_xlim(xmin, xmax, auto=True) - ax.set_ylim(ymin, ymax, auto=True) return contf, cont @@ -1149,6 +1237,10 @@ def kde_contour_plot_2d(ax, data_x, data_y, *args, **kwargs): data_x = data_x[weights != 0] data_y = data_y[weights != 0] weights = weights[weights != 0] + if ax.get_xaxis().get_scale() == 'log': + data_x = np.log10(data_x) + if ax.get_yaxis().get_scale() == 'log': + data_y = np.log10(data_y) ncompress = kwargs.pop('ncompress', 'equal') nplot = kwargs.pop('nplot_2d', 1000) @@ -1189,6 +1281,10 @@ def kde_contour_plot_2d(ax, data_x, data_y, *args, **kwargs): 
xmin=data_y.min(), xmax=data_y.max()) levels = iso_probability_contours(P, contours=levels) + if ax.get_xaxis().get_scale() == 'log': + X = 10**X + if ax.get_yaxis().get_scale() == 'log': + Y = 10**Y if facecolor not in [None, 'None', 'none']: linewidths = kwargs.pop('linewidths', 0.5) @@ -1251,6 +1347,10 @@ def hist_plot_2d(ax, data_x, data_y, *args, **kwargs): """ kwargs = normalize_kwargs(kwargs) weights = kwargs.pop('weights', None) + if ax.get_xaxis().get_scale() == 'log': + data_x = np.log10(data_x) + if ax.get_yaxis().get_scale() == 'log': + data_y = np.log10(data_y) vmin = kwargs.pop('vmin', 0) label = kwargs.pop('label', None) @@ -1267,17 +1367,13 @@ def hist_plot_2d(ax, data_x, data_y, *args, **kwargs): ymax = quantile(data_y, q[-1], weights) rge = kwargs.pop('range', ((xmin, xmax), (ymin, ymax))) - if levels is None: - pdf, x, y, image = ax.hist2d(data_x, data_y, weights=weights, - cmap=cmap, range=rge, vmin=vmin, - *args, **kwargs) - else: - bins = kwargs.pop('bins', 10) - density = kwargs.pop('density', False) - cmin = kwargs.pop('cmin', None) - cmax = kwargs.pop('cmax', None) - pdf, x, y = np.histogram2d(data_x, data_y, bins, rge, - density, weights) + bins = kwargs.pop('bins', 10) + density = kwargs.pop('density', False) + cmin = kwargs.pop('cmin', None) + cmax = kwargs.pop('cmax', None) + pdf, x, y = np.histogram2d(data_x, data_y, bins, rge, + density, weights) + if levels is not None: levels = iso_probability_contours(pdf, levels) pdf = np.digitize(pdf, levels, right=True) pdf = np.array(levels)[pdf] @@ -1286,9 +1382,13 @@ def hist_plot_2d(ax, data_x, data_y, *args, **kwargs): pdf[pdf < cmin] = np.ma.masked if cmax is not None: pdf[pdf > cmax] = np.ma.masked - snap = kwargs.pop('snap', True) - image = ax.pcolormesh(x, y, pdf.T, cmap=cmap, vmin=vmin, snap=snap, - *args, **kwargs) + snap = kwargs.pop('snap', True) + if ax.get_xaxis().get_scale() == 'log': + x = 10**x + if ax.get_yaxis().get_scale() == 'log': + y = 10**y + image = ax.pcolormesh(x, y, 
pdf.T, cmap=cmap, vmin=vmin, snap=snap, + *args, **kwargs) ax.add_patch(plt.Rectangle((0, 0), 0, 0, fc=cmap(0.999), ec=cmap(0.32), lw=2, label=label)) diff --git a/anesthetic/plotting/_matplotlib/hist.py b/anesthetic/plotting/_matplotlib/hist.py index 47c3fb26..763591cb 100644 --- a/anesthetic/plotting/_matplotlib/hist.py +++ b/anesthetic/plotting/_matplotlib/hist.py @@ -26,8 +26,24 @@ class HistPlot(_WeightedMPLPlot, _HistPlot): + + # noqa: disable=D101 + def _args_adjust(self) -> None: + if ( + hasattr(self, 'bins') and + isinstance(self.bins, str) and + self.bins in ['fd', 'scott', 'sqrt'] + ): + self.bins = self._calculate_bins(self.data) + super()._args_adjust() + # noqa: disable=D101 def _calculate_bins(self, data): + if self.logx: + data = np.log10(data) + if 'range' in self.kwds and self.kwds['range'] is not None: + xmin, xmax = self.kwds['range'] + self.kwds['range'] = (np.log10(xmin), np.log10(xmax)) nd_values = data.infer_objects(copy=False)._get_numeric_data() values = np.ravel(nd_values) weights = self.kwds.get("weights", None) @@ -41,10 +57,25 @@ def _calculate_bins(self, data): values = values[~isna(values)] - hist, bins = np.histogram( - values, bins=self.bins, range=self.kwds.get("range", None), - weights=weights - ) + if isinstance(self.bins, str) and self.bins in ['fd', 'scott', 'sqrt']: + bins = histogram_bin_edges( + values, + weights=weights, + bins=self.bins, + beta=self.kwds.pop('beta', 'equal'), + range=self.kwds.get('range', None) + ) + else: + bins = np.histogram_bin_edges( + values, + weights=weights, + bins=self.bins, + range=self.kwds.get('range', None) + ) + if self.logx: + bins = 10**bins + if 'range' in self.kwds and self.kwds['range'] is not None: + self.kwds['range'] = (xmin, xmax) return bins def _get_colors(self, num_colors=None, color_kwds='color'): @@ -149,23 +180,19 @@ def __init__( ) -> None: super().__init__(data, bins=bins, bottom=bottom, **kwargs) - def _args_adjust(self) -> None: - if 'range' not in self.kwds: + def 
_calculate_bins(self, data): + if 'range' not in self.kwds or self.kwds['range'] is None: q = self.kwds.get('q', 5) q = quantile_plot_interval(q=q) weights = self.kwds.get('weights', None) - xmin = quantile(self.data, q[0], weights) - xmax = quantile(self.data, q[-1], weights) + xmin = quantile(data, q[0], weights) + xmax = quantile(data, q[-1], weights) self.kwds['range'] = (xmin, xmax) - if isinstance(self.bins, str) and self.bins in ['fd', 'scott', 'sqrt']: - self.bins = histogram_bin_edges( - self.data, - weights=self.kwds.get('weights', None), - bins=self.bins, - beta=self.kwds.pop('beta', 'equal'), - range=self.kwds.get('range', None) - ) - super()._args_adjust() + bins = super()._calculate_bins(data) + self.kwds.pop('range') + else: + bins = super()._calculate_bins(data) + return bins @classmethod def _plot( diff --git a/anesthetic/samples.py b/anesthetic/samples.py index 2e76eee7..cc3bc3f8 100644 --- a/anesthetic/samples.py +++ b/anesthetic/samples.py @@ -198,6 +198,13 @@ def plot_1d(self, axes=None, *args, **kwargs): can be hard to interpret/expensive for :class:`Samples`, :class:`MCMCSamples`, or :class:`NestedSamples`. + logx : list(str), optional + Which parameters/columns to plot on a log scale. + Needs to match if plotting on top of a pre-existing axes. + + label : str, optional + Legend label added to each axis. + Returns ------- axes : :class:`pandas.Series` of :class:`matplotlib.axes.Axes` @@ -215,7 +222,14 @@ def plot_1d(self, axes=None, *args, **kwargs): axes = self.drop_labels().columns if not isinstance(axes, AxesSeries): - _, axes = make_1d_axes(axes, labels=self.get_labels_map()) + _, axes = make_1d_axes(axes, labels=self.get_labels_map(), + logx=kwargs.pop('logx', None)) + logx = axes._logx + else: + logx = kwargs.pop('logx', axes._logx) + if logx != axes._logx: + raise ValueError(f"logx does not match the pre-existing axes." 
+ f"logx={logx}, axes._logx={axes._logx}") kwargs['kind'] = kwargs.get('kind', 'kde_1d') kwargs['label'] = kwargs.get('label', self.label) @@ -237,7 +251,7 @@ def plot_1d(self, axes=None, *args, **kwargs): for x, ax in axes.items(): if x in self and kwargs['kind'] is not None: xlabel = self.get_label(x) - self[x].plot(ax=ax, xlabel=xlabel, + self[x].plot(ax=ax, xlabel=xlabel, logx=x in logx, *args, **kwargs) ax.set_xlabel(xlabel) else: @@ -314,6 +328,14 @@ def plot_2d(self, axes=None, *args, **kwargs): overwrite any kwarg with the same key passed to _kwargs. Default: {} + logx, logy : list(str), optional + Which parameters/columns to plot on a log scale for the x-axis and + y-axis, respectively. + Needs to match if plotting on top of a pre-existing axes. + + label : str, optional + Legend label added to each axis. + Returns ------- axes : :class:`pandas.DataFrame` of :class:`matplotlib.axes.Axes` @@ -341,13 +363,6 @@ def plot_2d(self, axes=None, *args, **kwargs): "the following string shortcuts: " f"{list(self.plot_2d_default_kinds.keys())}") - local_kwargs = {pos: kwargs.pop('%s_kwargs' % pos, {}) - for pos in ['upper', 'lower', 'diagonal']} - kwargs['label'] = kwargs.get('label', self.label) - - for pos in local_kwargs: - local_kwargs[pos].update(kwargs) - if axes is None: axes = self.drop_labels().columns @@ -355,7 +370,25 @@ def plot_2d(self, axes=None, *args, **kwargs): _, axes = make_2d_axes(axes, labels=self.get_labels_map(), upper=('upper' in kind), lower=('lower' in kind), - diagonal=('diagonal' in kind)) + diagonal=('diagonal' in kind), + logx=kwargs.pop('logx', None), + logy=kwargs.pop('logy', None)) + logx = axes._logx + logy = axes._logy + else: + logx = kwargs.pop('logx', axes._logx) + logy = kwargs.pop('logy', axes._logy) + if logx != axes._logx or logy != axes._logy: + raise ValueError(f"logx or logy not matching existing axes:" + f"logx={logx}, axes._logx={axes._logx}" + f"logy={logy}, axes._logy={axes._logy}") + + local_kwargs = {pos: 
kwargs.pop('%s_kwargs' % pos, {}) + for pos in ['upper', 'lower', 'diagonal']} + kwargs['label'] = kwargs.get('label', self.label) + + for pos in local_kwargs: + local_kwargs[pos].update(kwargs) for y, row in axes.iterrows(): for x, ax in row.items(): @@ -383,11 +416,13 @@ def plot_2d(self, axes=None, *args, **kwargs): ylabel = self.get_label(y) if x == y: self[x].plot(ax=ax.twin, xlabel=xlabel, + logx=x in logx, *args, **lkwargs) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) else: self.plot(x, y, ax=ax, xlabel=xlabel, + logx=x in logx, logy=y in logy, ylabel=ylabel, *args, **lkwargs) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) diff --git a/docs/source/plotting.rst b/docs/source/plotting.rst index 72f736cb..d4fb2d56 100644 --- a/docs/source/plotting.rst +++ b/docs/source/plotting.rst @@ -251,10 +251,37 @@ method from there, directing it to the correct position with the ``loc`` and axes.iloc[ 0, -1].legend(loc='lower right', bbox_to_anchor=(1, 1)) axes.iloc[-1, 0].legend(loc='lower center', bbox_to_anchor=(len(axes)/2, len(axes))) +Log-scale +--------- + +You can plot selected parameters on a log-scale by passing a list of those +parameters under the keyword ``logx`` to :func:`anesthetic.plot.make_1d_axes` +or :meth:`anesthetic.samples.Samples.plot_1d`, and under the keywords ``logx`` +and ``logy`` to :func:`anesthetic.plot.make_2d_axes` or +:meth:`anesthetic.samples.Samples.plot_2d`: + +.. plot:: :context: close-figs + + fig, axes = make_1d_axes(['x0', 'x1', 'x2', 'x3'], logx=['x2']) + samples.plot_1d(axes, label="'x2' on log-scale") + axes['x2'].legend() + +.. plot:: :context: close-figs + + fig, axes = make_2d_axes(['x0', 'x1', 'x2', 'x3'], logx=['x2'], logy=['x2']) + samples.plot_2d(axes, label="'x2' on log-scale") + axes.iloc[-1, 0].legend(bbox_to_anchor=(len(axes), len(axes)), loc='lower right') + +.. 
plot:: :context: close-figs + + fig, axes = make_2d_axes(['x0', 'x1', 'x2', 'x3'], logx=['x2']) + samples.plot_2d(axes, label="'x2' on log-scale for x-axis, but not for y-axis") + axes.iloc[-1, 0].legend(bbox_to_anchor=(len(axes), len(axes)), loc='lower right') + Ticks ----- -You can pass the keyword ``ticks`` to :func:anesthetic.plot.make_2d_axes: to +You can pass the keyword ``ticks`` to :func:`anesthetic.plot.make_2d_axes`: to adjust the tick settings of the 2D axes. There are three options: * ``ticks='inner'`` diff --git a/tests/test_plot.py b/tests/test_plot.py index 47db51cd..8864d905 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -1,5 +1,6 @@ import anesthetic.examples._matplotlib_agg # noqa: F401 from packaging import version +from warnings import catch_warnings, filterwarnings import pytest import numpy as np import matplotlib @@ -486,10 +487,8 @@ def test_astropyhist_plot_1d(bins): fig, ax = plt.subplots() np.random.seed(0) data = np.random.randn(100) - bars = hist_plot_1d(ax, data, histtype='bar', bins=bins)[2] - assert np.all([isinstance(b, Patch) for b in bars]) - assert max([b.get_height() for b in bars]) == 1. - assert np.all(np.array([b.get_height() for b in bars]) <= 1.) 
+ with pytest.raises(ValueError): + hist_plot_1d(ax, data, bins=bins) @pytest.mark.parametrize('bins', ['fd', 'scott', 'sqrt']) @@ -534,12 +533,14 @@ def test_hist_plot_2d(): ymin, ymax = ax.get_ylim() assert xmin > -5 and xmax < 5 and ymin > -5 and ymax < 5 + fig, ax = plt.subplots() data_x, data_y = np.random.uniform(-10, 10, (2, 1000)) hist_plot_2d(ax, data_x, data_y) xmin, xmax = ax.get_xlim() ymin, ymax = ax.get_ylim() assert xmin > -10 and xmax < 10 and ymin > -10 and ymax < 10 + fig, ax = plt.subplots() data_x, data_y = np.random.uniform(-10, 10, (2, 1000)) weights = np.exp(-(data_x**2 + data_y**2)/2) hist_plot_2d(ax, data_x, data_y, weights=weights, bins=30) @@ -745,6 +746,189 @@ def test_scatter_plot_2d(): scatter_plot_2d(ax, data_x, data_y, q=0) +def test_make_axes_logscale(): + # 1d + fig, axes = make_1d_axes(['x0', 'x1', 'x2', 'x3'], logx=['x1', 'x3']) + assert axes.loc['x0'].get_xscale() == 'linear' + assert axes.loc['x1'].get_xscale() == 'log' + assert axes.loc['x2'].get_xscale() == 'linear' + assert axes.loc['x3'].get_xscale() == 'log' + + # 2d, logx only + fig, axes = make_2d_axes(['x0', 'x1', 'x2', 'x3'], + logx=['x1', 'x3']) + for y, rows in axes.iterrows(): + for x, ax in rows.items(): + if x in ['x0', 'x2']: + assert ax.get_xscale() == 'linear' + else: + assert ax.get_xscale() == 'log' + assert ax.get_yscale() == 'linear' + with catch_warnings(): + filterwarnings('error', category=UserWarning, + message="Attempt to set non-positive") + ax.set_ylim(-1, 1) + + # 2d, logy only + fig, axes = make_2d_axes(['x0', 'x1', 'x2', 'x3'], + logy=['x1', 'x3']) + for y, rows in axes.iterrows(): + for x, ax in rows.items(): + assert ax.get_xscale() == 'linear' + with catch_warnings(): + filterwarnings('error', category=UserWarning, + message="Attempt to set non-positive") + ax.set_xlim(-1, 1) + if y in ['x0', 'x2']: + assert ax.get_yscale() == 'linear' + else: + assert ax.get_yscale() == 'log' + + # 2d, logx and logy + fig, axes = make_2d_axes(['x0', 'x1', 
'x2', 'x3'], + logx=['x1', 'x3'], + logy=['x1', 'x3']) + for y, rows in axes.iterrows(): + for x, ax in rows.items(): + if x in ['x0', 'x2']: + assert ax.get_xscale() == 'linear' + else: + assert ax.get_xscale() == 'log' + if y in ['x0', 'x2']: + assert ax.get_yscale() == 'linear' + else: + assert ax.get_yscale() == 'log' + + +@pytest.mark.parametrize('plot_1d', [kde_plot_1d, + skipif_no_fastkde(fastkde_plot_1d), + hist_plot_1d]) +def test_logscale_1d(plot_1d): + np.random.seed(42) + logdata = np.random.randn(1000) + data = 10**logdata + + fig, ax = plt.subplots() + ax.set_xscale('log') + p = plot_1d(ax, data) + if 'kde' in plot_1d.__name__: + amax = abs(np.log10(p[0].get_xdata()[np.argmax(p[0].get_ydata())])) + else: + amax = abs(np.log10(p[1][np.argmax(p[0])])) + assert amax < 0.5 + + +@pytest.mark.parametrize('b', ['scott', 20, np.logspace(-5, 5, 20)]) +def test_logscale_hist_kwargs(b): + np.random.seed(42) + logdata = np.random.randn(1000) + data = 10**logdata + + fig, ax = plt.subplots() + ax.set_xscale('log') + h, edges, _ = hist_plot_1d(ax, data, bins=b) + amax = abs(np.log10(edges[np.argmax(h)])) + assert amax < 0.5 + assert edges[0] < 1e-3 + assert edges[-1] > 1e3 + h, edges, _ = hist_plot_1d(ax, data, bins=b, range=(1e-3, 1e3)) + amax = abs(np.log10(edges[np.argmax(h)])) + assert amax < 0.5 + if isinstance(b, (int, str)): + # edges are trimmed according to range + assert edges[0] == 1e-3 + assert edges[-1] == 1e3 + else: + # edges passed directly to bins are not trimmed according to range + assert edges[0] == b[0] + assert edges[-1] == b[-1] + + +@pytest.mark.parametrize('plot_2d', + [kde_contour_plot_2d, + skipif_no_fastkde(fastkde_contour_plot_2d), + hist_plot_2d, scatter_plot_2d]) +def test_logscale_2d(plot_2d): + np.random.seed(0) + logx = np.random.randn(1000) + logy = np.random.randn(1000) + x = 10**logx + y = 10**logy + + # logx + fig, ax = plt.subplots() + ax.set_xscale('log') + p = plot_2d(ax, x, logy) + if 'kde' in plot_2d.__name__: + if 
version.parse(matplotlib.__version__) >= version.parse('3.8.0'): + xmax, ymax = p[0].get_paths()[1].vertices[0].T + else: + xmax, ymax = p[0].allsegs[1][0].T + xmax = np.mean(np.log10(xmax)) + ymax = np.mean(ymax) + elif 'hist' in plot_2d.__name__: + c = p.get_coordinates() + c = (c[:-1, :] + c[1:, :]) / 2 + c = (c[:, :-1] + c[:, 1:]) / 2 + c = c.reshape((-1, 2)) + xmax = abs(np.log10(c[np.argmax(p.get_array())][0])) + ymax = abs(c[np.argmax(p.get_array())][1]) + else: + xmax = np.mean(np.log10(p[0].get_xdata())) + ymax = np.mean(p[0].get_ydata()) + assert xmax < 0.5 + assert ymax < 0.5 + + # logy + fig, ax = plt.subplots() + ax.set_yscale('log') + p = plot_2d(ax, logx, y) + if 'kde' in plot_2d.__name__: + if version.parse(matplotlib.__version__) >= version.parse('3.8.0'): + xmax, ymax = p[0].get_paths()[1].vertices[0].T + else: + xmax, ymax = p[0].allsegs[1][0].T + xmax = np.mean(xmax) + ymax = np.mean(np.log10(ymax)) + elif 'hist' in plot_2d.__name__: + c = p.get_coordinates() + c = (c[:-1, :] + c[1:, :]) / 2 + c = (c[:, :-1] + c[:, 1:]) / 2 + c = c.reshape((-1, 2)) + xmax = abs(c[np.argmax(p.get_array())][0]) + ymax = abs(np.log10(c[np.argmax(p.get_array())][1])) + else: + xmax = np.mean(p[0].get_xdata()) + ymax = np.mean(np.log10(p[0].get_ydata())) + assert xmax < 0.5 + assert ymax < 0.5 + + # logx and logy + fig, ax = plt.subplots() + ax.set_xscale('log') + ax.set_yscale('log') + p = plot_2d(ax, x, y) + if 'kde' in plot_2d.__name__: + if version.parse(matplotlib.__version__) >= version.parse('3.8.0'): + xmax, ymax = p[0].get_paths()[1].vertices[0].T + else: + xmax, ymax = p[0].allsegs[1][0].T + xmax = np.mean(np.log10(xmax)) + ymax = np.mean(np.log10(ymax)) + elif 'hist' in plot_2d.__name__: + c = p.get_coordinates() + c = (c[:-1, :] + c[1:, :]) / 2 + c = (c[:, :-1] + c[:, 1:]) / 2 + c = c.reshape((-1, 2)) + xmax = abs(np.log10(c[np.argmax(p.get_array())][0])) + ymax = abs(np.log10(c[np.argmax(p.get_array())][1])) + else: + xmax = 
np.mean(np.log10(p[0].get_xdata())) + ymax = np.mean(np.log10(p[0].get_ydata())) + assert xmax < 0.5 + assert ymax < 0.5 + + @pytest.mark.parametrize('sigmas', [(1, '1sigma', 0.682689492137086), (2, '2sigma', 0.954499736103642), (3, '3sigma', 0.997300203936740), diff --git a/tests/test_samples.py b/tests/test_samples.py index 42263538..4dca96f4 100644 --- a/tests/test_samples.py +++ b/tests/test_samples.py @@ -394,7 +394,8 @@ def test_plot_1d_colours(kind): def test_astropyhist(): np.random.seed(3) ns = read_chains('./tests/example_data/pc') - ns.plot_1d(['x0', 'x1', 'x2', 'x3'], kind='hist_1d', bins='knuth') + with pytest.raises(ValueError): + ns.plot_1d(['x0', 'x1', 'x2', 'x3'], kind='hist_1d', bins='knuth') def test_hist_levels(): @@ -431,6 +432,136 @@ def test_plot_1d_no_axes(): assert axes.iloc[2].get_xlabel() == 'x2' +@pytest.mark.parametrize('kind', ['kde', 'hist', skipif_no_fastkde('fastkde')]) +def test_plot_logscale_1d(kind): + ns = read_chains('./tests/example_data/pc') + params = ['x0', 'x1', 'x2', 'x3', 'x4'] + + # 1d + axes = ns.plot_1d(params, kind=kind + '_1d', logx=['x2']) + for x, ax in axes.items(): + if x == 'x2': + assert ax.get_xscale() == 'log' + else: + assert ax.get_xscale() == 'linear' + ax = axes.loc['x2'] + if 'kde' in kind: + p = ax.get_children() + arg = np.argmax(p[0].get_ydata()) + pmax = np.log10(p[0].get_xdata()[arg]) + d = 0.1 + else: + arg = np.argmax([p.get_height() for p in ax.patches]) + pmax = np.log10(ax.patches[arg].get_x()) + d = np.log10(ax.patches[arg+1].get_x() / ax.patches[arg].get_x()) + assert pmax == pytest.approx(-1, abs=d) + + +@pytest.mark.parametrize('kind', ['kde', 'hist', skipif_no_fastkde('fastkde')]) +def test_plot_logscale_2d(kind): + ns = read_chains('./tests/example_data/pc') + params = ['x0', 'x1', 'x2', 'x3', 'x4'] + + # 2d, logx only + axes = ns.plot_2d(params, kind=kind, logx=['x2']) + for y, rows in axes.iterrows(): + for x, ax in rows.items(): + if ax is not None: + if x == 'x2': + assert 
ax.get_xscale() == 'log' + else: + assert ax.get_xscale() == 'linear' + assert ax.get_yscale() == 'linear' + if x == y: + if x == 'x2': + assert ax.twin.get_xscale() == 'log' + else: + assert ax.twin.get_xscale() == 'linear' + assert ax.twin.get_yscale() == 'linear' + + # 2d, logy only + axes = ns.plot_2d(params, kind=kind, logy=['x2']) + for y, rows in axes.iterrows(): + for x, ax in rows.items(): + if ax is not None: + assert ax.get_xscale() == 'linear' + if y == 'x2': + assert ax.get_yscale() == 'log' + else: + assert ax.get_yscale() == 'linear' + if x == y: + assert ax.twin.get_xscale() == 'linear' + assert ax.twin.get_yscale() == 'linear' + + # 2d, logx and logy + axes = ns.plot_2d(params, kind=kind, logx=['x2'], logy=['x2']) + for y, rows in axes.iterrows(): + for x, ax in rows.items(): + if ax is not None: + if x == 'x2': + assert ax.get_xscale() == 'log' + else: + assert ax.get_xscale() == 'linear' + if y == 'x2': + assert ax.get_yscale() == 'log' + else: + assert ax.get_yscale() == 'linear' + if x == y: + if x == 'x2': + assert ax.twin.get_xscale() == 'log' + else: + assert ax.twin.get_xscale() == 'linear' + assert ax.twin.get_yscale() == 'linear' + + +@pytest.mark.parametrize('k', ['hist_1d', 'hist']) +@pytest.mark.parametrize('b', ['scott', 10, np.logspace(-3, 0, 20)]) +@pytest.mark.parametrize('r', [None, (1e-5, 1)]) +def test_plot_logscale_hist_kwargs(k, b, r): + ns = read_chains('./tests/example_data/pc') + with pytest.warns(UserWarning) if k == 'hist' else nullcontext(): + axes = ns[['x2']].plot_1d(kind=k, logx=['x2'], bins=b, range=r) + ax = axes.loc['x2'] + assert ax.get_xscale() == 'log' + arg = np.argmax([p.get_height() for p in ax.patches]) + pmax = np.log10(ax.patches[arg].get_x()) + d = np.log10(ax.patches[arg+1].get_x() / ax.patches[arg].get_x()) + assert pmax == pytest.approx(-1, abs=d) + + +def test_logscale_failure_without_match(): + ns = read_chains('./tests/example_data/pc') + params = ['x0', 'x2'] + + # 1d + axes = ns.plot_1d(params) + with
pytest.raises(ValueError): + ns.plot_1d(axes, logx=['x2']) + fig, axes = make_1d_axes(params) + with pytest.raises(ValueError): + ns.plot_1d(axes, logx=['x2']) + + # 2d + axes = ns.plot_2d(params) + with pytest.raises(ValueError): + ns.plot_2d(axes, logx=['x2']) + axes = ns.plot_2d(params) + with pytest.raises(ValueError): + ns.plot_2d(axes, logy=['x2']) + axes = ns.plot_2d(params) + with pytest.raises(ValueError): + ns.plot_2d(axes, logx=['x2'], logy=['x2']) + fig, axes = make_2d_axes(params) + with pytest.raises(ValueError): + ns.plot_2d(axes, logx=['x2']) + fig, axes = make_2d_axes(params) + with pytest.raises(ValueError): + ns.plot_2d(axes, logy=['x2']) + fig, axes = make_2d_axes(params) + with pytest.raises(ValueError): + ns.plot_2d(axes, logx=['x2'], logy=['x2']) + + def test_mcmc_stats(): mcmc = read_chains('./tests/example_data/cb') chains = mcmc.groupby(('chain', '$n_\\mathrm{chain}$'), group_keys=False)