Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add option to scale heatmap vmin and vmax based on avg. computed values #186

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions ternary/heatmapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def polygon_generator(data, scale, style, permutation=None):
yield map(project, vertices), value


def heatmap(data, scale, vmin=None, vmax=None, cmap=None, ax=None,
def heatmap(data, scale, vmin=None, vmax=None, adj_vlims=False, cmap=None, ax=None,
scientific=False, style='triangular', colorbar=True,
permutation=None, use_rgba=False, cbarlabel=None, cb_kwargs=None):
"""
Expand All @@ -203,6 +203,8 @@ def heatmap(data, scale, vmin=None, vmax=None, cmap=None, ax=None,
The minimum color value, used to normalize colors. Computed if absent.
vmax: float, None
The maximum color value, used to normalize colors. Computed if absent.
adj_vlims: bool, False
Redefine min and max color values based on computed averages.
cmap: String or matplotlib.colors.Colormap, None
The name of the Matplotlib colormap to use.
ax: Matplotlib AxesSubplot, None
Expand All @@ -229,6 +231,12 @@ def heatmap(data, scale, vmin=None, vmax=None, cmap=None, ax=None,

if not ax:
fig, ax = plt.subplots()

if vmax or vmin:
vlims_defined = True
else:
vlims_defined = False

# If use_rgba, make the RGBA values numpy arrays so that they can
# be averaged.
if use_rgba:
Expand All @@ -247,6 +255,18 @@ def heatmap(data, scale, vmin=None, vmax=None, cmap=None, ax=None,
vertices_values = polygon_generator(data, scale, style,
permutation=permutation)

# adjust limits of the colormapper if requested,
# but only if user also didn't request specific vlims
if adj_vlims and not vlims_defined:
vals = []
for _, val in vertices_values:
vals.append(val)
vmin = min(vals)
vmax = max(vals)

vertices_values = polygon_generator(data, scale, style,
permutation=permutation)

# Draw the polygons and color them
for vertices, value in vertices_values:
if value is None:
Expand All @@ -257,7 +277,7 @@ def heatmap(data, scale, vmin=None, vmax=None, cmap=None, ax=None,
color = value # rgba tuple (r,g,b,a) all in [0,1]
# Matplotlib wants a list of xs and a list of ys
xs, ys = unzip(vertices)
ax.fill(xs, ys, facecolor=color, edgecolor=color)
ax.fill(xs, ys, facecolor=color, edgecolor=None)

if not cb_kwargs:
cb_kwargs = dict()
Expand All @@ -272,8 +292,8 @@ def heatmap(data, scale, vmin=None, vmax=None, cmap=None, ax=None,

def heatmapf(func, scale=10, boundary=True, cmap=None, ax=None,
scientific=False, style='triangular', colorbar=True,
permutation=None, vmin=None, vmax=None, cbarlabel=None,
cb_kwargs=None):
permutation=None, vmin=None, vmax=None, adj_vlims=False,
cbarlabel=None, cb_kwargs=None):
"""
Computes func on heatmap partition coordinates and plots heatmap. In other
words, computes the function on lattice points of the simplex (normalized
Expand Down Expand Up @@ -303,6 +323,8 @@ def heatmapf(func, scale=10, boundary=True, cmap=None, ax=None,
The minimum color value, used to normalize colors.
vmax: float
The maximum color value, used to normalize colors.
adj_vlims: bool, False
Redefine min and max color values based on computed averages.
cb_kwargs: dict
dict of kwargs to pass to colorbar

Expand All @@ -318,7 +340,8 @@ def heatmapf(func, scale=10, boundary=True, cmap=None, ax=None,
# Pass everything to the heatmapper
ax = heatmap(data, scale, cmap=cmap, ax=ax, style=style,
scientific=scientific, colorbar=colorbar,
permutation=permutation, vmin=vmin, vmax=vmax,
permutation=permutation, vmin=vmin, vmax=vmax,
adj_vlims=adj_vlims,
cbarlabel=cbarlabel, cb_kwargs=cb_kwargs)
return ax

Expand Down Expand Up @@ -347,8 +370,8 @@ def svg_polygon(coordinates, color):
return polygon


def svg_heatmap(data, scale, filename, vmax=None, vmin=None, style='h',
permutation=None, cmap=None):
def svg_heatmap(data, scale, filename, vmax=None, vmin=None, adj_vlims=False,
style='h', permutation=None, cmap=None):
"""
Create a heatmap in SVG format. Intended for use with very large datasets,
which would require large amounts of RAM using matplotlib. You can convert
Expand All @@ -370,6 +393,8 @@ def svg_heatmap(data, scale, filename, vmax=None, vmin=None, style='h',
The minimum color value, used to normalize colors.
vmax: float
The maximum color value, used to normalize colors.
adj_vlims: bool, False
Redefine min and max color values based on computed averages.
cmap: String or matplotlib.colors.Colormap, None
The name of the Matplotlib colormap to use.
style: String, "h"
Expand Down
6 changes: 3 additions & 3 deletions ternary/ternary_axes_subplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def plot_colored_trajectory(self, points, cmap=None, **kwargs):

def heatmap(self, data, scale=None, cmap=None, scientific=False,
style='triangular', colorbar=True, use_rgba=False,
vmin=None, vmax=None, cbarlabel=None, cb_kwargs=None):
vmin=None, vmax=None, adj_vlims=False, cbarlabel=None, cb_kwargs=None):
permutation = self._permutation
if not scale:
scale = self.get_scale()
Expand All @@ -446,12 +446,12 @@ def heatmap(self, data, scale=None, cmap=None, scientific=False,
heatmapping.heatmap(data, scale, cmap=cmap, style=style, ax=ax,
scientific=scientific, colorbar=colorbar,
permutation=permutation, use_rgba=use_rgba,
vmin=vmin, vmax=vmax, cbarlabel=cbarlabel,
vmin=vmin, vmax=vmax, adj_vlims=adj_vlims, cbarlabel=cbarlabel,
cb_kwargs=cb_kwargs)

def heatmapf(self, func, scale=None, cmap=None, boundary=True,
style='triangular', colorbar=True, scientific=False,
vmin=None, vmax=None, cbarlabel=None, cb_kwargs=None):
vmin=None, vmax=None, adj_vlims=False, cbarlabel=None, cb_kwargs=None):
if not scale:
scale = self.get_scale()
if style.lower()[0] == 'd':
Expand Down