Skip to content

Commit

Permalink
refactor: Move Matplotlib-specific Solara components to separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
rht committed Jan 15, 2024
1 parent d3a0e2f commit a2750ea
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 116 deletions.
114 changes: 114 additions & 0 deletions mesa/experimental/components/matplotlib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from typing import Optional

import networkx as nx
import solara
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator

import mesa


@solara.component
def SpaceMatplotlib(model, agent_portrayal, dependencies: Optional[list[any]] = None):
space_fig = Figure()
space_ax = space_fig.subplots()
space = getattr(model, "grid", None)

Check warning on line 15 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L13-L15

Added lines #L13 - L15 were not covered by tests
if space is None:
# Sometimes the space is defined as model.space instead of model.grid
space = model.space

Check warning on line 18 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L18

Added line #L18 was not covered by tests
if isinstance(space, mesa.space.NetworkGrid):
_draw_network_grid(space, space_ax, agent_portrayal)

Check warning on line 20 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L20

Added line #L20 was not covered by tests
elif isinstance(space, mesa.space.ContinuousSpace):
_draw_continuous_space(space, space_ax, agent_portrayal)

Check warning on line 22 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L22

Added line #L22 was not covered by tests
else:
_draw_grid(space, space_ax, agent_portrayal)
space_ax.set_axis_off()
solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)

Check warning on line 26 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L24-L26

Added lines #L24 - L26 were not covered by tests


def _draw_grid(space, space_ax, agent_portrayal):
def portray(g):
x = []
y = []
s = [] # size
c = [] # color

Check warning on line 34 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L30-L34

Added lines #L30 - L34 were not covered by tests
for i in range(g.width):
for j in range(g.height):
content = g._grid[i][j]

Check warning on line 37 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L37

Added line #L37 was not covered by tests
if not content:
continue

Check warning on line 39 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L39

Added line #L39 was not covered by tests
if not hasattr(content, "__iter__"):
# Is a single grid
content = [content]

Check warning on line 42 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L42

Added line #L42 was not covered by tests
for agent in content:
data = agent_portrayal(agent)
x.append(i)
y.append(j)

Check warning on line 46 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L44-L46

Added lines #L44 - L46 were not covered by tests
if "size" in data:
s.append(data["size"])

Check warning on line 48 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L48

Added line #L48 was not covered by tests
if "color" in data:
c.append(data["color"])
out = {"x": x, "y": y}

Check warning on line 51 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L50-L51

Added lines #L50 - L51 were not covered by tests
if len(s) > 0:
out["s"] = s

Check warning on line 53 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L53

Added line #L53 was not covered by tests
if len(c) > 0:
out["c"] = c
return out

Check warning on line 56 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L55-L56

Added lines #L55 - L56 were not covered by tests

space_ax.scatter(**portray(space))

Check warning on line 58 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L58

Added line #L58 was not covered by tests


def _draw_network_grid(space, space_ax, agent_portrayal):
graph = space.G
pos = nx.spring_layout(graph, seed=0)
nx.draw(

Check warning on line 64 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L62-L64

Added lines #L62 - L64 were not covered by tests
graph,
ax=space_ax,
pos=pos,
**agent_portrayal(graph),
)


def _draw_continuous_space(space, space_ax, agent_portrayal):
def portray(space):
x = []
y = []
s = [] # size
c = [] # color

Check warning on line 77 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L73-L77

Added lines #L73 - L77 were not covered by tests
for agent in space._agent_to_index:
data = agent_portrayal(agent)
_x, _y = agent.pos
x.append(_x)
y.append(_y)

Check warning on line 82 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L79-L82

Added lines #L79 - L82 were not covered by tests
if "size" in data:
s.append(data["size"])

Check warning on line 84 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L84

Added line #L84 was not covered by tests
if "color" in data:
c.append(data["color"])
out = {"x": x, "y": y}

Check warning on line 87 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L86-L87

Added lines #L86 - L87 were not covered by tests
if len(s) > 0:
out["s"] = s

Check warning on line 89 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L89

Added line #L89 was not covered by tests
if len(c) > 0:
out["c"] = c
return out

Check warning on line 92 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L91-L92

Added lines #L91 - L92 were not covered by tests

space_ax.scatter(**portray(space))

Check warning on line 94 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L94

Added line #L94 was not covered by tests


def make_plot(model, measure):
fig = Figure()
ax = fig.subplots()
df = model.datacollector.get_model_vars_dataframe()

Check warning on line 100 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L98-L100

Added lines #L98 - L100 were not covered by tests
if isinstance(measure, str):
ax.plot(df.loc[:, measure])
ax.set_ylabel(measure)

Check warning on line 103 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L102-L103

Added lines #L102 - L103 were not covered by tests
elif isinstance(measure, dict):
for m, color in measure.items():
ax.plot(df.loc[:, m], label=m, color=color)
fig.legend()

Check warning on line 107 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L106-L107

Added lines #L106 - L107 were not covered by tests
elif isinstance(measure, (list, tuple)):
for m in measure:
ax.plot(df.loc[:, m], label=m)
fig.legend()

Check warning on line 111 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L110-L111

Added lines #L110 - L111 were not covered by tests
# Set integer x axis
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
solara.FigureMatplotlib(fig)

Check warning on line 114 in mesa/experimental/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/components/matplotlib.py#L113-L114

Added lines #L113 - L114 were not covered by tests
121 changes: 6 additions & 115 deletions mesa/experimental/jupyter_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@
from typing import Optional

import matplotlib.pyplot as plt
import networkx as nx
import reacton.ipywidgets as widgets
import solara
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
from solara.alias import rv

import mesa
import mesa.experimental.components.matplotlib as components_matplotlib

# Avoid interactive backend
plt.switch_backend("agg")
Expand Down Expand Up @@ -72,7 +69,7 @@ def ColorCard(color, layout_type):
rv.CardTitle(children=["Space"])
if space_drawer == "default":
# draw with the default implementation
SpaceMatplotlib(
components_matplotlib.SpaceMatplotlib(
model, agent_portrayal, dependencies=[current_step.value]
)
elif space_drawer:
Expand All @@ -85,7 +82,7 @@ def ColorCard(color, layout_type):
# Is a custom object
measure(model)
else:
make_plot(model, measure)
components_matplotlib.make_plot(model, measure)

Check warning on line 85 in mesa/experimental/jupyter_viz.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/jupyter_viz.py#L85

Added line #L85 was not covered by tests
return main

# 3. Set up UI
Expand All @@ -106,7 +103,7 @@ def render_in_jupyter():
# 4. Space
if space_drawer == "default":
# draw with the default implementation
SpaceMatplotlib(
components_matplotlib.SpaceMatplotlib(

Check warning on line 106 in mesa/experimental/jupyter_viz.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/jupyter_viz.py#L106

Added line #L106 was not covered by tests
model, agent_portrayal, dependencies=[current_step.value]
)
elif space_drawer:
Expand All @@ -121,7 +118,7 @@ def render_in_jupyter():
# Is a custom object
measure(model)
else:
make_plot(model, measure)
components_matplotlib.make_plot(model, measure)

Check warning on line 121 in mesa/experimental/jupyter_viz.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/jupyter_viz.py#L121

Added line #L121 was not covered by tests

def render_in_browser():
# if space drawer is disabled, do not include it
Expand Down Expand Up @@ -182,7 +179,7 @@ def on_value_play(change):
def do_step():
model.step()
previous_step.value = current_step.value
current_step.value = model.schedule.steps
current_step.value += 1

Check warning on line 182 in mesa/experimental/jupyter_viz.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/jupyter_viz.py#L182

Added line #L182 was not covered by tests

def do_play():
model.running = True
Expand Down Expand Up @@ -316,112 +313,6 @@ def change_handler(value, name=name):
raise ValueError(f"{input_type} is not a supported input type")


@solara.component
def SpaceMatplotlib(model, agent_portrayal, dependencies: Optional[list[any]] = None):
space_fig = Figure()
space_ax = space_fig.subplots()
space = getattr(model, "grid", None)
if space is None:
# Sometimes the space is defined as model.space instead of model.grid
space = model.space
if isinstance(space, mesa.space.NetworkGrid):
_draw_network_grid(space, space_ax, agent_portrayal)
elif isinstance(space, mesa.space.ContinuousSpace):
_draw_continuous_space(space, space_ax, agent_portrayal)
else:
_draw_grid(space, space_ax, agent_portrayal)
space_ax.set_axis_off()
solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)


def _draw_grid(space, space_ax, agent_portrayal):
def portray(g):
x = []
y = []
s = [] # size
c = [] # color
for i in range(g.width):
for j in range(g.height):
content = g._grid[i][j]
if not content:
continue
if not hasattr(content, "__iter__"):
# Is a single grid
content = [content]
for agent in content:
data = agent_portrayal(agent)
x.append(i)
y.append(j)
if "size" in data:
s.append(data["size"])
if "color" in data:
c.append(data["color"])
out = {"x": x, "y": y}
if len(s) > 0:
out["s"] = s
if len(c) > 0:
out["c"] = c
return out

space_ax.scatter(**portray(space))


def _draw_network_grid(space, space_ax, agent_portrayal):
graph = space.G
pos = nx.spring_layout(graph, seed=0)
nx.draw(
graph,
ax=space_ax,
pos=pos,
**agent_portrayal(graph),
)


def _draw_continuous_space(space, space_ax, agent_portrayal):
def portray(space):
x = []
y = []
s = [] # size
c = [] # color
for agent in space._agent_to_index:
data = agent_portrayal(agent)
_x, _y = agent.pos
x.append(_x)
y.append(_y)
if "size" in data:
s.append(data["size"])
if "color" in data:
c.append(data["color"])
out = {"x": x, "y": y}
if len(s) > 0:
out["s"] = s
if len(c) > 0:
out["c"] = c
return out

space_ax.scatter(**portray(space))


def make_plot(model, measure):
fig = Figure()
ax = fig.subplots()
df = model.datacollector.get_model_vars_dataframe()
if isinstance(measure, str):
ax.plot(df.loc[:, measure])
ax.set_ylabel(measure)
elif isinstance(measure, dict):
for m, color in measure.items():
ax.plot(df.loc[:, m], label=m, color=color)
fig.legend()
elif isinstance(measure, (list, tuple)):
for m in measure:
ax.plot(df.loc[:, m], label=m)
fig.legend()
# Set integer x axis
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
solara.FigureMatplotlib(fig)


def make_text(renderer):
def function(model):
solara.Markdown(renderer(model))
Expand Down
2 changes: 1 addition & 1 deletion tests/test_jupyter_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def Test(user_params):


class TestJupyterViz(unittest.TestCase):
@patch("mesa.experimental.jupyter_viz.SpaceMatplotlib")
@patch("mesa.experimental.components.matplotlib.SpaceMatplotlib")
def test_call_space_drawer(self, mock_space_matplotlib):
mock_model_class = Mock()
agent_portrayal = {
Expand Down

0 comments on commit a2750ea

Please sign in to comment.