Skip to content

Commit

Permalink
* Changes made to address feedback
Browse files Browse the repository at this point in the history
* Code layout changed to play nice with sphinx docs template used
* removed useless `plot_a_bar` function
  • Loading branch information
matt s authored and DeliciousHair committed Aug 10, 2021
1 parent 3064c04 commit 4702251
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 112 deletions.
1 change: 1 addition & 0 deletions docs/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Diagram Visualization
persim.plot_diagrams
persim.bottleneck_matching
persim.wasserstein_matching
persim.Barcode
persim.plot_landscape
persim.plot_landscape_simple

Expand Down
240 changes: 128 additions & 112 deletions persim/visuals.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import numpy as np
import matplotlib.pyplot as plt

__all__ = ["plot_diagrams", "bottleneck_matching", "wasserstein_matching"]
import io

__all__ = [
"plot_diagrams",
"bottleneck_matching",
"wasserstein_matching",
"Barcode"
]


def plot_diagrams(
Expand All @@ -24,21 +31,21 @@ def plot_diagrams(
Parameters
----------
diagrams: ndarray (n_pairs, 2) or list of diagrams
A diagram or list of diagrams. If diagram is a list of diagrams,
A diagram or list of diagrams. If diagram is a list of diagrams,
then plot all on the same plot using different colors.
plot_only: list of numeric
If specified, an array of only the diagrams that should be plotted.
title: string, default is None
If title is defined, add it as title of the plot.
xy_range: list of numeric [xmin, xmax, ymin, ymax]
User provided range of axes. This is useful for comparing
User provided range of axes. This is useful for comparing
multiple persistence diagrams.
labels: string or list of strings
Legend labels for each diagram.
Legend labels for each diagram.
If none are specified, we use H_0, H_1, H_2,... by default.
colormap: string, default is 'default'
Any of matplotlib color palettes.
Some options are 'default', 'seaborn', 'sequential'.
Any of matplotlib color palettes.
Some options are 'default', 'seaborn', 'sequential'.
See all available styles with
.. code:: python
Expand All @@ -48,17 +55,17 @@ def plot_diagrams(
size: numeric, default is 20
Pixel size of each point plotted.
ax_color: any valid matplotlib color type.
ax_color: any valid matplotlib color type.
See [https://matplotlib.org/api/colors_api.html](https://matplotlib.org/api/colors_api.html) for complete API.
diagonal: bool, default is True
Plot the diagonal x=y line.
lifetime: bool, default is False. If True, diagonal is turned to False.
Plot life time of each point instead of birth and death.
Plot life time of each point instead of birth and death.
Essentially, visualize (x, y-x).
legend: bool, default is True
If true, show the legend.
show: bool, default is False
Call plt.show() after plotting. If you are using self.plot() as part
Call plt.show() after plotting. If you are using self.plot() as part
of a subplot, set show=False and call plt.show() only once at the end.
"""

Expand Down Expand Up @@ -165,26 +172,23 @@ def plot_diagrams(
if show is True:
plt.show()

def plot_a_bar(p, q, c='b', linestyle='-'):
plt.plot([p[0], q[0]], [p[1], q[1]], c=c, linestyle=linestyle, linewidth=1)

def bottleneck_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None):
""" Visualize bottleneck matching between two diagrams
Parameters
===========
dgm1: Mx(>=2)
dgm1: Mx(>=2)
array of birth/death pairs for PD 1
dgm2: Nx(>=2)
dgm2: Nx(>=2)
array of birth/death paris for PD 2
matching: ndarray(Mx+Nx, 3)
A list of correspondences in an optimal matching, as well as their distance, where:
* First column is index of point in first persistence diagram, or -1 if diagonal
* Second column is index of point in second persistence diagram, or -1 if diagonal
* Third column is the distance of each matching
labels: list of strings
names of diagrams for legend. Default = ["dgm1", "dgm2"],
names of diagrams for legend. Default = ["dgm1", "dgm2"],
ax: matplotlib Axis object
For plotting on a particular axis.
Expand Down Expand Up @@ -248,7 +252,7 @@ def wasserstein_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None)
* Second column is index of point in second persistence diagram, or -1 if diagonal
* Third column is the distance of each matching
labels: list of strings
names of diagrams for legend. Default = ["dgm1", "dgm2"],
names of diagrams for legend. Default = ["dgm1", "dgm2"],
ax: matplotlib Axis object
For plotting on a particular axis.
Expand Down Expand Up @@ -287,76 +291,100 @@ def wasserstein_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None)

plot_diagrams([dgm1, dgm2], labels=labels, ax=ax)

class Barcode():
def __init__(self, diagrams):
'''
parameters:
class Barcode:
__doc__ = """
Barcode visualisation made easy!
Note that this convenience class requires instantiation as the number
of subplots produced depends on the dimension of the data.
"""

def __init__(self, diagrams, verbose=False):
"""
Parameters
===========
diagrams: list-like
typically the output of ripser(nodes)['dgms']
diagrams: list-like
typically the output of ripser(nodes)['dgms']
verbose: bool
Execute print statemens for extra information; currently only echoes
number of bars in each dimension (Default=False).
examples:
Examples
===========
>>> n = 300
>>> t = np.linspace(0, 2 * np.pi, n)
>>> noise = np.random.normal(0, 0.1, size=n)
>>> data = np.vstack([((3+d) * np.cos(t[i]+d), (3+d) * np.sin(t[i]+d)) for i, d in enumerate(noise)])
>>> diagrams = ripser(data)
>>> bc = Barcode(diagrams['dgms'])
>>> bc.plot_barcode()
"""
if not isinstance(diagrams, list):
diagrams = [diagrams]

n = 300
t = np.linspace(0, 2 * np.pi, n)
noise = np.random.normal(0, 0.1, size=n)
self.diagrams = diagrams
self._verbose = verbose
self._dim = len(diagrams)

data = np.vstack([((3+d) * np.cos(t[i]+d), (3+d) * np.sin(t[i]+d)) for i, d in enumerate(noise)])
diagrams = ripser(data)
def plot_barcode(self, figsize=None, show=True, export_png=False, dpi=100, **kwargs):
"""Wrapper method to produce barcode plot
bc = Barcode(diagrams['dgms'])
bc.plot_barcode()
'''
if not isinstance(diagrams, list):
diagrams = [diagrams]
Parameters
===========
figsize: tuple
figure size, default=(6,6) if H0+H1 only, (6,4) otherwise
if len(diagrams) == 2:
self.plot_barcode = self._plot_H0_H1
show: boolean
show the figure via plt.show()
else:
self.plot_barcode = self._plot_Hn
export_png: boolean
write image to png data, returned as io.BytesIO() instance,
default=False
self.diagrams = diagrams
**kwargs: artist paramters for the barcodes, defaults:
c='grey'
linestyle='-'
linewidth=0.5
dpi=100 (for png export)
def _plot_H0_H1(self, **kwargs):
'''
parameters:
===========
figsize: tuple
figure size, default=(6,6)
show: boolean
show the figure via plt.show()
export_png: boolean
write image to png data, returned as io.BytesIO() instance, default=False
**kwargs: artist paramters for the barcodes, defaults:
c='grey'
linestyle='-'
linewidth=0.5
dpi=100 (for png export)
returns:
Returns
===========
list of png exports or []
'''
import io

fsize = kwargs.get('figsize', (6, 6))
show = kwargs.get('show', True)
export = kwargs.get('export_png', False)
dpi = kwargs.get('dpi', 100)
out: list or None
list of png exports if export_png=True, otherwise None
"""
if self._dim == 2:
if figsize is None:
figsize = (6, 6)

return self._plot_H0_H1(
figsize=figsize,
show=show,
export_png=export_png,
dpi=dpi,
**kwargs
)

else:
if figsize is None:
figsize = (6, 4)

return self._plot_Hn(
figsize=figsize,
show=show,
export_png=export_png,
dpi=dpi,
**kwargs
)

def _plot_H0_H1(self, *, figsize, show, export_png, dpi, **kwargs):
out = []

fig, ax = plt.subplots(2, 1, figsize=fsize)
fig, ax = plt.subplots(2, 1, figsize=figsize)

for dim, diagram in enumerate(self.diagrams):
self._plot_many_bars(dim, diagram, dim, ax, **kwargs)

if export:
if export_png:
fp = io.BytesIO()
plt.savefig(fp, dpi=dpi)
fp.seek(0)
Expand All @@ -365,45 +393,21 @@ def _plot_H0_H1(self, **kwargs):

if show:
plt.show()
else:
plt.close()

return out

def _plot_Hn(self, **kwargs):
'''
parameters:
===========
figsize: tuple
figure size, default=(6,6)
show: boolean
show the figure via plt.show()
export_png: boolean
write image to png data, returned as io.BytesIO() instance, default=False
**kwargs: artist paramters for the barcodes, defaults:
c='grey'
linestyle='-'
linewidth=0.5
dpi=100 (for png export)
returns:
===========
list of png exports or []
'''
fsize = kwargs.get('figsize', (6, 4))
show = kwargs.get('show', True)
export = kwargs.get('export_png', False)
dpi = kwargs.get('dpi', 100)
if any(out):
return out

def _plot_Hn(self, *, figsize, show, export_png, dpi, **kwargs):
out = []

for dim, diagram in enumerate(self.diagrams):
fig, ax = plt.subplots(1, 1, figsize=fsize)
fig, ax = plt.subplots(1, 1, figsize=figsize)

self._plot_many_bars(dim, diagram, 0, [ax], **kwargs)

if export:
if export_png:
fp = io.BytesIO()
plt.savefig(fp, dpi=dpi)
fp.seek(0)
Expand All @@ -412,12 +416,16 @@ def _plot_Hn(self, **kwargs):

if show:
plt.show()
else:
plt.close()

return out
if any(out):
return out

def _plot_many_bars(self, dim, diagram, idx, ax, **kwargs):
number_of_bars = len(diagram)
print("Number of bars in dimension %d: %d" % (dim, number_of_bars))
if self._verbose:
print("Number of bars in dimension %d: %d" % (dim, number_of_bars))

if number_of_bars > 0:
births = np.vstack([(elem[0], i) for i, elem in enumerate(diagram)])
Expand All @@ -432,29 +440,37 @@ def _plot_many_bars(self, dim, diagram, idx, ax, **kwargs):
_ = [self._plot_a_bar(ax[idx], birth, deaths[i], max_death, **kwargs) for i, birth in enumerate(births)]

# the line below is to plot a vertical red line showing the maximal finite bar length
ax[idx].plot([max_death, max_death], [0, number_of_bars - 1],
ax[idx].plot(
[max_death, max_death],
[0, number_of_bars - 1],
c='r',
linestyle='--',
linewidth=0.5)
linewidth=0.7
)

title = "H%d barcode: %d finite, %d infinite" % (dim, number_of_bars_fin, number_of_bars_inf)
ax[idx].set_title(title, fontsize=10)
ax[idx].set_title(title, fontsize=9)
ax[idx].set_yticks([])

ax[idx].spines['right'].set_visible(False)
ax[idx].spines['left'].set_visible(False)
ax[idx].spines['top'].set_visible(False)
for loc in ('right', 'left', 'top'):
ax[idx].spines[loc].set_visible(False)

@staticmethod
def _plot_a_bar(ax, birth, death, max_death, c='gray', linestyle='-', linewidth=0.5, **kwargs):
def _plot_a_bar(ax, birth, death, max_death, c='gray', linestyle='-', linewidth=0.5):
if np.isinf(death[0]):
death[0] = 1.05 * max_death
ax.plot(death[0], death[1],
ax.plot(
death[0],
death[1],
c=c,
markersize=4,
marker='>')
marker='>',
)

ax.plot([birth[0], death[0]], [birth[1], death[1]],
ax.plot(
[birth[0], death[0]],
[birth[1], death[1]],
c=c,
linestyle=linestyle,
linewidth=linewidth)
linewidth=linewidth,
)

0 comments on commit 4702251

Please sign in to comment.