Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding some better examples #6

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/pyalluv.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Submodules
----------

.. toctree::
:maxdepth: 4

pyalluv.clusters
pyalluv.fluxes
Expand Down
50 changes: 50 additions & 0 deletions examples/_example_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import json

eg_data = [
# slice 0: initial state
{'a': {None: 10}, 'b': {None: 14}, 'c': {None: 4}},
# slice 1
{
None: {'a': 1},
'a': {'a': 4},
'b': {'a': 6},
'c': {'b': 8},
'd': {'b': 5},
'e': {'c': 4, None: 4}
},
# slice 2
{
None: {'b': 1, 'c': 3, 'd': 2, 'e': 1},
'a': {'b': 5},
'b': {'a': 3, 'c': 5, 'd': 1},
'c': {'e': 3},
'd': {'e': 4}
},
# slice 3
{
None: {'b': 2},
'a': {'a': 5, None: 4},
'b': {'b': 3},
'c': {'b': 5},
'd': {'c': 3},
'e': {'c': 1, 'd': 1},
'f': {'d': 4}
},
# slice 4
{
None: {'a': 1, 'c': 1, 'd': 1},
'a': {'a': 5},
'b': {'a': 3, 'b': 1},
'c': {'b': 2, 'c': 4},
'd': {'d': 2, 'e': 2, 'f': 4}
},
# slice 5
{
None: {},
'a': {'a': 5, 'b': 4, None: 1},
'b': {'c': 6, 'd': 8}
}
]

with open('example_data.json', 'w') as fobj:
json.dump(eg_data, fobj)
1 change: 1 addition & 0 deletions examples/example_data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"a": {"null": 10}, "b": {"null": 14}, "c": {"null": 4}}, {"null": {"a": 1}, "a": {"a": 4}, "b": {"a": 6}, "c": {"b": 8}, "d": {"b": 5}, "e": {"c": 4, "null": 4}}, {"null": {"b": 1, "c": 3, "d": 2, "e": 1}, "a": {"b": 5}, "b": {"a": 3, "c": 5, "d": 1}, "c": {"e": 3}, "d": {"e": 4}}, {"null": {"b": 2}, "a": {"a": 5, "null": 4}, "b": {"b": 3}, "c": {"b": 5}, "d": {"c": 3}, "e": {"c": 1, "d": 1}, "f": {"d": 4}}, {"null": {"a": 1, "c": 1, "d": 1}, "a": {"a": 5}, "b": {"a": 3, "b": 1}, "c": {"b": 2, "c": 4}, "d": {"d": 2, "e": 2, "f": 4}}, {"null": {}, "a": {"a": 5, "b": 4, "null": 1}, "b": {"c": 6, "d": 8}}]
60 changes: 60 additions & 0 deletions examples/stacked_clusters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import os
import pathlib
import json
from matplotlib import pyplot as plt
import matplotlib
matplotlib.use('TkAgg')

try:
from pyalluv import AlluvialPlot, Cluster, Flux
except ImportError:
os.sys.path.append(os.path.dirname('..'))
from pyalluv import AlluvialPlot, Cluster, Flux

with open(os.path.join(
pathlib.Path(__file__).parent.absolute(),
'example_data.json'
), 'r') as fobj:
eg_data = json.load(fobj)

fc_clusters = 'xkcd:gray'
fc_edges = 'lightgray'

clusters = []
# first create the clusters
for a_slice in eg_data:
slice_clusters = dict()
for target, fluxes in a_slice.items():
if target is None or target == 'null':
pass
else:
if target not in slice_clusters:
slice_clusters[target] = Cluster(
height=sum(fluxes.values()),
label=target,
width=0.2,
facecolor=fc_clusters
)
clusters.append(dict(slice_clusters))
# now the fluxes

for idx, a_slice in enumerate(eg_data):
slice_clusters = clusters[idx]
if idx:
prev_clusters = clusters[idx-1]
for target, fluxes in a_slice.items():
if target is not None and target != 'null':
for source, amount in fluxes.items():
if source is not None and source != 'null':
Flux(
flux=amount,
source_cluster=prev_clusters[source],
target_cluster=slice_clusters[target],
facecolor=fc_edges
)
fig, ax = plt.subplots()
AlluvialPlot(
axes=ax,
clusters=[list(_clusters.values()) for _clusters in clusters]
)
plt.show()
1 change: 1 addition & 0 deletions pyalluv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from .clusters import Cluster
from .fluxes import Flux


# backwards compatibility for obsolete version
SankeyPlot = AlluvialPlot
12 changes: 5 additions & 7 deletions pyalluv/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def __init__(self, height, anchor=None, width=1.0, label=None, **kwargs):
x_coord, y_coord = anchor
else:
x_coord, y_coord = anchor, None
self = self.set_x_pos(x_coord).set_y_pos(y_coord)
self.set_x_pos(x_coord)
self.set_y_pos(y_coord)

# init the in and out fluxes:
self.out_fluxes = []
Expand All @@ -97,7 +98,7 @@ def __init__(self, height, anchor=None, width=1.0, label=None, **kwargs):

def set_x_pos(self, x_pos):
r"""
Set the horizontal position of a cluster.
Set `self.x_pos`: the horizontal position of a cluster.

The position is set according to the value provided in ``x_pos`` and
``self.x_anchor``.
Expand All @@ -109,8 +110,7 @@ def set_x_pos(self, x_pos):

Returns
--------
self: :class:`.Cluster`
with new property ``x_pos``.
None

"""
self.x_pos = x_pos
Expand All @@ -121,8 +121,6 @@ def set_x_pos(self, x_pos):
elif self.x_anchor == 'right':
self.x_pos -= 0.5 * self.width

return self

def get_patch(self, **kwargs):
_kwargs = dict(kwargs)
_kwargs.update(self.patch_kwargs)
Expand Down Expand Up @@ -323,7 +321,7 @@ def set_y_pos(self, y_pos):
else:
self.mid_height = None

return self
return None

def set_in_out_anchors(self,):
"""
Expand Down
83 changes: 66 additions & 17 deletions pyalluv/plotting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import division, absolute_import, unicode_literals
import warnings
from matplotlib.collections import PatchCollection
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
Expand All @@ -12,7 +13,7 @@ class AlluvialPlot(object):
Parameters
===========

clusters: dict[str, dict], dict[float, list] or list[list]
clusters: list[list]
You have 2 options to create an Alluvial diagram\:

raw data: dict[str, dict]
Expand All @@ -31,14 +32,35 @@ class AlluvialPlot(object):
If it is present in the out-fluxes of a cluster, the specified amount
simply vanishes and will not lead to a flux.

collections of :obj:`.Cluster`: dict[float, list] and list[list]
collections of :obj:`.Cluster`: list[list]

.. WARNING:: support for passing a dictionary will de dropped in the
next release, use a list instead and provide
:ref:`x_positions<xposref>` separately.
Converting a dictionary of clusters to the new format can
be done as follows:

.. code-block:: python

x_pos, clusters = zip(*sorted(clusters.items(),key=lambda x:x[0]))

If a `list` is provided each element must be a `list`
of :obj:`.Cluster` objects. A `dictionary` must provide a `list` of
:obj:`.Cluster` (*value*) for a horizontal position (*key*), e.g.
``{1.0: [c11, c12, ...], 2.0: [c21, c22, ...], ...}``.

axes: :class:`matplotlib.axes.Axes`
Axes to draw an Alluvial diagram on.
axes: :class:`matplotlib.axes.Axes` (default=None)
Axes to draw an Alluvial diagram on. If provided the alluvial diagram
will be drawn directly to this axes. Alternatively you can omit this
argument when creating an instance and later call the
:meth:`AlluvialPlot.draw_on` method.


x_positions: list (default=None)
.. _xposref:

A list with horizontal positioning of the clusters.

y_pos: str
**options:** ``'overwrite'``, ``'keep'``, ``'complement'``, ``'sorted'``

Expand Down Expand Up @@ -136,7 +158,7 @@ class AlluvialPlot(object):

.. note::

This ca be used to draw multiple alluvial diagrams on the same
This can be used to draw multiple alluvial diagrams on the same
:obj:`~matplotlib.axes.Axes` by simply calling
:class:`~.AlluvialPlot` repeatedly with changing offset value, thus
stacking alluvial diagrams.
Expand All @@ -148,8 +170,9 @@ class AlluvialPlot(object):
Holds for each vertical position a list of :obj:`.Cluster` objects.
"""
def __init__(
self, clusters, axes, y_pos='overwrite', cluster_w_spacing=1,
cluster_kwargs={}, flux_kwargs={}, label_kwargs={},
self, clusters, axes=None, x_positions=None, y_pos='overwrite',
cluster_w_spacing=1, cluster_kwargs=None, flux_kwargs=None,
label_kwargs=None,
**kwargs
):
# if clusters are given in a list of lists (each list is a x position)
Expand All @@ -160,23 +183,39 @@ def __init__(
)
self.with_cluster_labels = kwargs.get('with_cluster_labels', True)
self.format_xaxis = kwargs.get('format_xaxis', True)
self._cluster_kwargs = cluster_kwargs
self._flux_kwargs = flux_kwargs
self._cluster_kwargs = cluster_kwargs or dict()
self._flux_kwargs = flux_kwargs or dict()
self._label_kwargs = label_kwargs or dict()
self._x_axis_offset = kwargs.get('x_axis_offset', 0.0)
self._fill_figure = kwargs.get('fill_figure', False)
self._invisible_y = kwargs.get('invisible_y', True)
self._invisible_x = kwargs.get('invisible_x', False)
self.y_offset = kwargs.get('y_offset', 0)
self.y_fix = kwargs.get('y_fix', None)
self._clusters = {}
if isinstance(clusters, dict):
warnings.warn(
"Support for providing a dictionary as argument for `clusters`"
" will be dropped in the next release.\nUse a list instead and"
" provide the x positions separately with the `x_positions`"
" argument.", PendingDeprecationWarning)
# key is the x position here, value is a list of clusters
self.clusters = clusters
# create composed key
# TODO
self._clusters.update()
else:
self.clusters = {}
for cluster in clusters:
try:
self.clusters[cluster.x_pos].append(cluster)
except KeyError:
self.clusters[cluster.x_pos] = [cluster]
for idx, slice_clusters in enumerate(clusters):
for cluster in slice_clusters:
if cluster.x_pos is not None:
x_pos = cluster.x_pos
else:
x_pos = idx
try:
self.clusters[x_pos].append(cluster)
except KeyError:
self.clusters[x_pos] = [cluster]
self.x_positions = sorted(self.clusters.keys())
# set the x positions correctly for the clusters
if self._set_x_pos:
Expand Down Expand Up @@ -287,7 +326,7 @@ def __init__(
)
axes.add_collection(patch_collection)
if self.with_cluster_labels:
label_collection = self.get_labelcollection(**label_kwargs)
label_collection = self.get_labelcollection(**self._label_kwargs)
if label_collection:
for label in label_collection:
axes.annotate(**label)
Expand Down Expand Up @@ -317,6 +356,16 @@ def __init__(
if isinstance(self.x_positions[0], datetime) and self.format_xaxis:
self.set_dates_xaxis(axes, _minor_tick)

def draw_on(self, axes):
r"""
Draws the alluvial diagram onto the provided axes.

Parameters
----------
axes: :class:`matplotlib.axes.Axes`
"""
pass

def distribute_clusters(self, x_pos):
r"""
Distribute the clusters for a given x_position vertically
Expand Down Expand Up @@ -648,8 +697,8 @@ def get_patchcollection(
)

def get_labelcollection(self, *args, **kwargs):
h_margin = kwargs.pop('h_margin', None)
v_margin = kwargs.pop('v_margin', None)
h_margin = kwargs.pop('h_margin', 0.1)
v_margin = kwargs.pop('v_margin', 0.1)
if 'horizontalalignment' not in kwargs:
kwargs['horizontalalignment'] = 'right'
if 'verticalalignment' not in kwargs:
Expand Down
16 changes: 0 additions & 16 deletions pyalluv/tests/test_draw.py

This file was deleted.

File renamed without changes.
Loading