Skip to content

Commit

Permalink
* Changes made to address feedback
Browse files Browse the repository at this point in the history
* Code layout changed to play nice with sphinx docs template used
* removed useless `plot_a_bar` function
  • Loading branch information
matt s authored and DeliciousHair committed Aug 6, 2021
1 parent 3064c04 commit a9d0adc
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 101 deletions.
1 change: 1 addition & 0 deletions docs/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Diagram Visualization
persim.plot_diagrams
persim.bottleneck_matching
persim.wasserstein_matching
persim.Barcode
persim.plot_landscape
persim.plot_landscape_simple

Expand Down
211 changes: 110 additions & 101 deletions persim/visuals.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import numpy as np
import matplotlib.pyplot as plt

__all__ = ["plot_diagrams", "bottleneck_matching", "wasserstein_matching"]
import io

__all__ = [
"plot_diagrams",
"bottleneck_matching",
"wasserstein_matching",
"Barcode"
]


def plot_diagrams(
Expand Down Expand Up @@ -165,9 +172,6 @@ def plot_diagrams(
if show is True:
plt.show()

def plot_a_bar(p, q, c='b', linestyle='-'):
plt.plot([p[0], q[0]], [p[1], q[1]], c=c, linestyle=linestyle, linewidth=1)

def bottleneck_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None):
""" Visualize bottleneck matching between two diagrams
Expand Down Expand Up @@ -287,76 +291,93 @@ def wasserstein_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None)

plot_diagrams([dgm1, dgm2], labels=labels, ax=ax)

class Barcode():
def __init__(self, diagrams):
'''
parameters:
===========
diagrams: list-like
typically the output of ripser(nodes)['dgms']
examples:
===========
class Barcode:
__doc__ = """
Barcode visualisation made easy!
n = 300
t = np.linspace(0, 2 * np.pi, n)
noise = np.random.normal(0, 0.1, size=n)
Note that this convenience class requires instantiation as the number
of subplots produced depends on the dimension of the data.
"""

data = np.vstack([((3+d) * np.cos(t[i]+d), (3+d) * np.sin(t[i]+d)) for i, d in enumerate(noise)])
diagrams = ripser(data)
def __init__(self, diagrams, verbose=False):
"""
Parameters
===========
diagrams: list-like
typically the output of ripser(nodes)['dgms']
verbose: bool
Execute print statemens for extra information; currently only echoes
number of bars in each dimension (Default=False).
bc = Barcode(diagrams['dgms'])
bc.plot_barcode()
'''
Examples
===========
>>> n = 300
>>> t = np.linspace(0, 2 * np.pi, n)
>>> noise = np.random.normal(0, 0.1, size=n)
>>> data = np.vstack([((3+d) * np.cos(t[i]+d), (3+d) * np.sin(t[i]+d)) for i, d in enumerate(noise)])
>>> diagrams = ripser(data)
>>> bc = Barcode(diagrams['dgms'])
>>> bc.plot_barcode()
"""
if not isinstance(diagrams, list):
diagrams = [diagrams]

if len(diagrams) == 2:
self.plot_barcode = self._plot_H0_H1

else:
self.plot_barcode = self._plot_Hn

self.diagrams = diagrams
self._verbose = verbose
self._dim = len(diagrams)

def plot_barcode(self, figsize=(6,5), show=True, export_png=False, dpi=100, **kwargs):
"""Wrapper method to produce barcode plot
def _plot_H0_H1(self, **kwargs):
'''
parameters:
Parameters
===========
figsize: tuple
figure size, default=(6,6)
show: boolean
show the figure via plt.show()
export_png: boolean
write image to png data, returned as io.BytesIO() instance, default=False
**kwargs: artist paramters for the barcodes, defaults:
c='grey'
linestyle='-'
linewidth=0.5
dpi=100 (for png export)
returns:
figsize: tuple
figure size, default=(6,5)
show: boolean
show the figure via plt.show()
export_png: boolean
write image to png data, returned as io.BytesIO() instance, default=False
**kwargs: artist paramters for the barcodes, defaults:
c='grey'
linestyle='-'
linewidth=0.5
dpi=100 (for png export)
Returns
===========
list of png exports or []
'''
import io

fsize = kwargs.get('figsize', (6, 6))
show = kwargs.get('show', True)
export = kwargs.get('export_png', False)
dpi = kwargs.get('dpi', 100)
out: list or None
list of png exports if export_png=True, otherwise None
"""
if self._dim == 2:
return self._plot_H0_H1(
figsize=figsize,
show=show,
export_png=export_png,
dpi=dpi,
**kwargs
)

else:
return self._plot_Hn(
figsize=figsize,
show=show,
export_png=export_png,
dpi=dpi,
**kwargs
)

def _plot_H0_H1(self, *, figsize, show, export_png, dpi, **kwargs):
out = []

fig, ax = plt.subplots(2, 1, figsize=fsize)
fig, ax = plt.subplots(2, 1, figsize=figsize)

for dim, diagram in enumerate(self.diagrams):
self._plot_many_bars(dim, diagram, dim, ax, **kwargs)

if export:
if export_png:
fp = io.BytesIO()
plt.savefig(fp, dpi=dpi)
fp.seek(0)
Expand All @@ -365,45 +386,21 @@ def _plot_H0_H1(self, **kwargs):

if show:
plt.show()
else:
plt.close()

return out

def _plot_Hn(self, **kwargs):
'''
parameters:
===========
figsize: tuple
figure size, default=(6,6)
show: boolean
show the figure via plt.show()
export_png: boolean
write image to png data, returned as io.BytesIO() instance, default=False
**kwargs: artist paramters for the barcodes, defaults:
c='grey'
linestyle='-'
linewidth=0.5
dpi=100 (for png export)
returns:
===========
list of png exports or []
'''
fsize = kwargs.get('figsize', (6, 4))
show = kwargs.get('show', True)
export = kwargs.get('export_png', False)
dpi = kwargs.get('dpi', 100)
if any(out):
return out

def _plot_Hn(self, *, figsize, show, export_png, dpi, **kwargs):
out = []

for dim, diagram in enumerate(self.diagrams):
fig, ax = plt.subplots(1, 1, figsize=fsize)
fig, ax = plt.subplots(1, 1, figsize=figsize)

self._plot_many_bars(dim, diagram, 0, [ax], **kwargs)

if export:
if export_png:
fp = io.BytesIO()
plt.savefig(fp, dpi=dpi)
fp.seek(0)
Expand All @@ -412,12 +409,16 @@ def _plot_Hn(self, **kwargs):

if show:
plt.show()
else:
plt.close()

return out
if any(out):
return out

def _plot_many_bars(self, dim, diagram, idx, ax, **kwargs):
number_of_bars = len(diagram)
print("Number of bars in dimension %d: %d" % (dim, number_of_bars))
if self._verbose:
print("Number of bars in dimension %d: %d" % (dim, number_of_bars))

if number_of_bars > 0:
births = np.vstack([(elem[0], i) for i, elem in enumerate(diagram)])
Expand All @@ -432,29 +433,37 @@ def _plot_many_bars(self, dim, diagram, idx, ax, **kwargs):
_ = [self._plot_a_bar(ax[idx], birth, deaths[i], max_death, **kwargs) for i, birth in enumerate(births)]

# the line below is to plot a vertical red line showing the maximal finite bar length
ax[idx].plot([max_death, max_death], [0, number_of_bars - 1],
ax[idx].plot(
[max_death, max_death],
[0, number_of_bars - 1],
c='r',
linestyle='--',
linewidth=0.5)
linewidth=0.7
)

title = "H%d barcode: %d finite, %d infinite" % (dim, number_of_bars_fin, number_of_bars_inf)
ax[idx].set_title(title, fontsize=10)
ax[idx].set_title(title, fontsize=9)
ax[idx].set_yticks([])

ax[idx].spines['right'].set_visible(False)
ax[idx].spines['left'].set_visible(False)
ax[idx].spines['top'].set_visible(False)
for loc in ('right', 'left', 'top'):
ax[idx].spines[loc].set_visible(False)

@staticmethod
def _plot_a_bar(ax, birth, death, max_death, c='gray', linestyle='-', linewidth=0.5, **kwargs):
def _plot_a_bar(ax, birth, death, max_death, c='gray', linestyle='-', linewidth=0.5):
if np.isinf(death[0]):
death[0] = 1.05 * max_death
ax.plot(death[0], death[1],
ax.plot(
death[0],
death[1],
c=c,
markersize=4,
marker='>')
marker='>',
)

ax.plot([birth[0], death[0]], [birth[1], death[1]],
ax.plot(
[birth[0], death[0]],
[birth[1], death[1]],
c=c,
linestyle=linestyle,
linewidth=linewidth)
linewidth=linewidth,
)

0 comments on commit a9d0adc

Please sign in to comment.