diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 944e3e8..bbb866f 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -59,6 +59,7 @@ Diagram Visualization persim.plot_diagrams persim.bottleneck_matching persim.wasserstein_matching + persim.Barcode persim.plot_landscape persim.plot_landscape_simple diff --git a/persim/visuals.py b/persim/visuals.py index 9bb50d0..e5ce7b3 100644 --- a/persim/visuals.py +++ b/persim/visuals.py @@ -1,7 +1,14 @@ import numpy as np import matplotlib.pyplot as plt -__all__ = ["plot_diagrams", "bottleneck_matching", "wasserstein_matching"] +import io + +__all__ = [ + "plot_diagrams", + "bottleneck_matching", + "wasserstein_matching", + "Barcode" +] def plot_diagrams( @@ -24,21 +31,21 @@ def plot_diagrams( Parameters ---------- diagrams: ndarray (n_pairs, 2) or list of diagrams - A diagram or list of diagrams. If diagram is a list of diagrams, + A diagram or list of diagrams. If diagram is a list of diagrams, then plot all on the same plot using different colors. plot_only: list of numeric If specified, an array of only the diagrams that should be plotted. title: string, default is None If title is defined, add it as title of the plot. xy_range: list of numeric [xmin, xmax, ymin, ymax] - User provided range of axes. This is useful for comparing + User provided range of axes. This is useful for comparing multiple persistence diagrams. labels: string or list of strings - Legend labels for each diagram. + Legend labels for each diagram. If none are specified, we use H_0, H_1, H_2,... by default. colormap: string, default is 'default' - Any of matplotlib color palettes. - Some options are 'default', 'seaborn', 'sequential'. + Any of matplotlib color palettes. + Some options are 'default', 'seaborn', 'sequential'. See all available styles with .. code:: python @@ -48,17 +55,17 @@ def plot_diagrams( size: numeric, default is 20 Pixel size of each point plotted. - ax_color: any valid matplotlib color type. 
+ ax_color: any valid matplotlib color type. See [https://matplotlib.org/api/colors_api.html](https://matplotlib.org/api/colors_api.html) for complete API. diagonal: bool, default is True Plot the diagonal x=y line. lifetime: bool, default is False. If True, diagonal is turned to False. - Plot life time of each point instead of birth and death. + Plot life time of each point instead of birth and death. Essentially, visualize (x, y-x). legend: bool, default is True If true, show the legend. show: bool, default is False - Call plt.show() after plotting. If you are using self.plot() as part + Call plt.show() after plotting. If you are using self.plot() as part of a subplot, set show=False and call plt.show() only once at the end. """ @@ -165,18 +172,15 @@ def plot_diagrams( if show is True: plt.show() -def plot_a_bar(p, q, c='b', linestyle='-'): - plt.plot([p[0], q[0]], [p[1], q[1]], c=c, linestyle=linestyle, linewidth=1) - def bottleneck_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None): """ Visualize bottleneck matching between two diagrams Parameters =========== - dgm1: Mx(>=2) + dgm1: Mx(>=2) array of birth/death pairs for PD 1 - dgm2: Nx(>=2) + dgm2: Nx(>=2) array of birth/death paris for PD 2 matching: ndarray(Mx+Nx, 3) A list of correspondences in an optimal matching, as well as their distance, where: @@ -184,7 +188,7 @@ def bottleneck_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None): * Second column is index of point in second persistence diagram, or -1 if diagonal * Third column is the distance of each matching labels: list of strings - names of diagrams for legend. Default = ["dgm1", "dgm2"], + names of diagrams for legend. Default = ["dgm1", "dgm2"], ax: matplotlib Axis object For plotting on a particular axis. 
@@ -248,7 +252,7 @@ def wasserstein_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None) * Second column is index of point in second persistence diagram, or -1 if diagonal * Third column is the distance of each matching labels: list of strings - names of diagrams for legend. Default = ["dgm1", "dgm2"], + names of diagrams for legend. Default = ["dgm1", "dgm2"], ax: matplotlib Axis object For plotting on a particular axis. @@ -287,76 +291,100 @@ def wasserstein_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None) plot_diagrams([dgm1, dgm2], labels=labels, ax=ax) -class Barcode(): - def __init__(self, diagrams): - ''' - parameters: +class Barcode: + __doc__ = """ + Barcode visualisation made easy! + + Note that this convenience class requires instantiation as the number + of subplots produced depends on the dimension of the data. + """ + + def __init__(self, diagrams, verbose=False): + """ + Parameters =========== - diagrams: list-like - typically the output of ripser(nodes)['dgms'] + diagrams: list-like + typically the output of ripser(nodes)['dgms'] + verbose: bool + Execute print statements for extra information; currently only echoes + number of bars in each dimension (Default=False). 
- examples: + Examples =========== + >>> n = 300 + >>> t = np.linspace(0, 2 * np.pi, n) + >>> noise = np.random.normal(0, 0.1, size=n) + >>> data = np.vstack([((3+d) * np.cos(t[i]+d), (3+d) * np.sin(t[i]+d)) for i, d in enumerate(noise)]) + >>> diagrams = ripser(data) + >>> bc = Barcode(diagrams['dgms']) + >>> bc.plot_barcode() + """ + if not isinstance(diagrams, list): + diagrams = [diagrams] - n = 300 - t = np.linspace(0, 2 * np.pi, n) - noise = np.random.normal(0, 0.1, size=n) + self.diagrams = diagrams + self._verbose = verbose + self._dim = len(diagrams) - data = np.vstack([((3+d) * np.cos(t[i]+d), (3+d) * np.sin(t[i]+d)) for i, d in enumerate(noise)]) - diagrams = ripser(data) + def plot_barcode(self, figsize=None, show=True, export_png=False, dpi=100, **kwargs): + """Wrapper method to produce barcode plot - bc = Barcode(diagrams['dgms']) - bc.plot_barcode() - ''' - if not isinstance(diagrams, list): - diagrams = [diagrams] + Parameters + =========== + figsize: tuple + figure size, default=(6,6) if H0+H1 only, (6,4) otherwise - if len(diagrams) == 2: - self.plot_barcode = self._plot_H0_H1 + show: boolean + show the figure via plt.show() - else: - self.plot_barcode = self._plot_Hn + export_png: boolean + write image to png data, returned as io.BytesIO() instance, + default=False - self.diagrams = diagrams + **kwargs: artist parameters for the barcodes, defaults: + c='grey' + linestyle='-' + linewidth=0.5 + dpi=100 (for png export) - def _plot_H0_H1(self, **kwargs): - ''' - parameters: - =========== - figsize: tuple - figure size, default=(6,6) - - show: boolean - show the figure via plt.show() - - export_png: boolean - write image to png data, returned as io.BytesIO() instance, default=False - - **kwargs: artist paramters for the barcodes, defaults: - c='grey' - linestyle='-' - linewidth=0.5 - dpi=100 (for png export) - - returns: + Returns =========== - list of png exports or [] - ''' - import io - - fsize = kwargs.get('figsize', (6, 6)) - show = 
kwargs.get('show', True) - export = kwargs.get('export_png', False) - dpi = kwargs.get('dpi', 100) + out: list or None + list of png exports if export_png=True, otherwise None + """ + if self._dim == 2: + if figsize is None: + figsize = (6, 6) + + return self._plot_H0_H1( + figsize=figsize, + show=show, + export_png=export_png, + dpi=dpi, + **kwargs + ) + else: + if figsize is None: + figsize = (6, 4) + + return self._plot_Hn( + figsize=figsize, + show=show, + export_png=export_png, + dpi=dpi, + **kwargs + ) + + def _plot_H0_H1(self, *, figsize, show, export_png, dpi, **kwargs): out = [] - fig, ax = plt.subplots(2, 1, figsize=fsize) + fig, ax = plt.subplots(2, 1, figsize=figsize) for dim, diagram in enumerate(self.diagrams): self._plot_many_bars(dim, diagram, dim, ax, **kwargs) - if export: + if export_png: fp = io.BytesIO() plt.savefig(fp, dpi=dpi) fp.seek(0) @@ -365,45 +393,21 @@ def _plot_H0_H1(self, **kwargs): if show: plt.show() + else: + plt.close() - return out - - def _plot_Hn(self, **kwargs): - ''' - parameters: - =========== - figsize: tuple - figure size, default=(6,6) - - show: boolean - show the figure via plt.show() - - export_png: boolean - write image to png data, returned as io.BytesIO() instance, default=False - - **kwargs: artist paramters for the barcodes, defaults: - c='grey' - linestyle='-' - linewidth=0.5 - dpi=100 (for png export) - - returns: - =========== - list of png exports or [] - ''' - fsize = kwargs.get('figsize', (6, 4)) - show = kwargs.get('show', True) - export = kwargs.get('export_png', False) - dpi = kwargs.get('dpi', 100) + if any(out): + return out + def _plot_Hn(self, *, figsize, show, export_png, dpi, **kwargs): out = [] for dim, diagram in enumerate(self.diagrams): - fig, ax = plt.subplots(1, 1, figsize=fsize) + fig, ax = plt.subplots(1, 1, figsize=figsize) self._plot_many_bars(dim, diagram, 0, [ax], **kwargs) - if export: + if export_png: fp = io.BytesIO() plt.savefig(fp, dpi=dpi) fp.seek(0) @@ -412,12 +416,16 @@ def 
_plot_Hn(self, **kwargs): if show: plt.show() + else: + plt.close() - return out + if any(out): + return out def _plot_many_bars(self, dim, diagram, idx, ax, **kwargs): number_of_bars = len(diagram) - print("Number of bars in dimension %d: %d" % (dim, number_of_bars)) + if self._verbose: + print("Number of bars in dimension %d: %d" % (dim, number_of_bars)) if number_of_bars > 0: births = np.vstack([(elem[0], i) for i, elem in enumerate(diagram)]) @@ -432,29 +440,37 @@ def _plot_many_bars(self, dim, diagram, idx, ax, **kwargs): _ = [self._plot_a_bar(ax[idx], birth, deaths[i], max_death, **kwargs) for i, birth in enumerate(births)] # the line below is to plot a vertical red line showing the maximal finite bar length - ax[idx].plot([max_death, max_death], [0, number_of_bars - 1], + ax[idx].plot( + [max_death, max_death], + [0, number_of_bars - 1], c='r', linestyle='--', - linewidth=0.5) + linewidth=0.7 + ) title = "H%d barcode: %d finite, %d infinite" % (dim, number_of_bars_fin, number_of_bars_inf) - ax[idx].set_title(title, fontsize=10) + ax[idx].set_title(title, fontsize=9) ax[idx].set_yticks([]) - ax[idx].spines['right'].set_visible(False) - ax[idx].spines['left'].set_visible(False) - ax[idx].spines['top'].set_visible(False) + for loc in ('right', 'left', 'top'): + ax[idx].spines[loc].set_visible(False) @staticmethod - def _plot_a_bar(ax, birth, death, max_death, c='gray', linestyle='-', linewidth=0.5, **kwargs): + def _plot_a_bar(ax, birth, death, max_death, c='gray', linestyle='-', linewidth=0.5): if np.isinf(death[0]): death[0] = 1.05 * max_death - ax.plot(death[0], death[1], + ax.plot( + death[0], + death[1], c=c, markersize=4, - marker='>') + marker='>', + ) - ax.plot([birth[0], death[0]], [birth[1], death[1]], + ax.plot( + [birth[0], death[0]], + [birth[1], death[1]], c=c, linestyle=linestyle, - linewidth=linewidth) + linewidth=linewidth, + )