From a78f802fd1382eb888d8a13600f5e3a166cd6438 Mon Sep 17 00:00:00 2001 From: rht Date: Sun, 7 Jan 2024 01:39:20 -0500 Subject: [PATCH] refactor: Move Matplotlib-specific components to separate file --- mesa/experimental/components/matplotlib.py | 114 +++++++++++++++++++ mesa/experimental/jupyter_viz.py | 122 +-------------------- 2 files changed, 120 insertions(+), 116 deletions(-) create mode 100644 mesa/experimental/components/matplotlib.py diff --git a/mesa/experimental/components/matplotlib.py b/mesa/experimental/components/matplotlib.py new file mode 100644 index 00000000000..b9d6e43c9df --- /dev/null +++ b/mesa/experimental/components/matplotlib.py @@ -0,0 +1,114 @@ +from typing import List, Optional + +import networkx as nx +import solara +from matplotlib.figure import Figure +from matplotlib.ticker import MaxNLocator + +import mesa + + +@solara.component +def SpaceMatplotlib(model, agent_portrayal, dependencies: Optional[List[any]] = None): + space_fig = Figure() + space_ax = space_fig.subplots() + space = getattr(model, "grid", None) + if space is None: + # Sometimes the space is defined as model.space instead of model.grid + space = model.space + if isinstance(space, mesa.space.NetworkGrid): + _draw_network_grid(space, space_ax, agent_portrayal) + elif isinstance(space, mesa.space.ContinuousSpace): + _draw_continuous_space(space, space_ax, agent_portrayal) + else: + _draw_grid(space, space_ax, agent_portrayal) + space_ax.set_axis_off() + solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies) + + +def _draw_grid(space, space_ax, agent_portrayal): + def portray(g): + x = [] + y = [] + s = [] # size + c = [] # color + for i in range(g.width): + for j in range(g.height): + content = g._grid[i][j] + if not content: + continue + if not hasattr(content, "__iter__"): + # Is a single grid + content = [content] + for agent in content: + data = agent_portrayal(agent) + x.append(i) + y.append(j) + if "size" in data: + s.append(data["size"]) + if "color" in data: + c.append(data["color"]) + out = {"x": x, "y": y} + if len(s) > 0: + out["s"] = s + if len(c) > 0: + out["c"] = c + return out + + space_ax.scatter(**portray(space)) + + +def _draw_network_grid(space, space_ax, agent_portrayal): + graph = space.G + pos = nx.spring_layout(graph, seed=0) + nx.draw( + graph, + ax=space_ax, + pos=pos, + **agent_portrayal(graph), + ) + + +def _draw_continuous_space(space, space_ax, agent_portrayal): + def portray(space): + x = [] + y = [] + s = [] # size + c = [] # color + for agent in space._agent_to_index: + data = agent_portrayal(agent) + _x, _y = agent.pos + x.append(_x) + y.append(_y) + if "size" in data: + s.append(data["size"]) + if "color" in data: + c.append(data["color"]) + out = {"x": x, "y": y} + if len(s) > 0: + out["s"] = s + if len(c) > 0: + out["c"] = c + return out + + space_ax.scatter(**portray(space)) + + +def make_plot(model, measure): + fig = Figure() + ax = fig.subplots() + df = model.datacollector.get_model_vars_dataframe() + if isinstance(measure, str): + ax.plot(df.loc[:, measure]) + ax.set_ylabel(measure) + elif isinstance(measure, dict): + for m, color in measure.items(): + ax.plot(df.loc[:, m], label=m, color=color) + fig.legend() + elif isinstance(measure, (list, tuple)): + for m in measure: + ax.plot(df.loc[:, m], label=m) + fig.legend() + # Set integer x axis + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + solara.FigureMatplotlib(fig) diff --git a/mesa/experimental/jupyter_viz.py b/mesa/experimental/jupyter_viz.py index 5cb12b531b3..2e5f066409f 100644 --- a/mesa/experimental/jupyter_viz.py +++ b/mesa/experimental/jupyter_viz.py @@ -1,16 +1,12 @@ import sys import threading -from typing import List, Optional import matplotlib.pyplot as plt -import networkx as nx import reacton.ipywidgets as widgets import solara -from matplotlib.figure import Figure -from matplotlib.ticker import MaxNLocator from solara.alias import rv -import mesa +import mesa.experimental.components.matplotlib as components_matplotlib # Avoid interactive backend plt.switch_backend("agg") @@ -72,7 +68,7 @@ def ColorCard(color, layout_type): rv.CardTitle(children=["Space"]) if space_drawer == "default": # draw with the default implementation - SpaceMatplotlib( + components_matplotlib.SpaceMatplotlib( model, agent_portrayal, dependencies=[current_step.value] ) elif space_drawer: @@ -85,7 +81,7 @@ def ColorCard(color, layout_type): # Is a custom object measure(model) else: - make_plot(model, measure) + components_matplotlib.make_plot(model, measure) return main # 3. Set up UI @@ -106,7 +102,7 @@ def render_in_jupyter(): # 4. Space if space_drawer == "default": # draw with the default implementation - SpaceMatplotlib( + components_matplotlib.SpaceMatplotlib( model, agent_portrayal, dependencies=[current_step.value] ) elif space_drawer: @@ -121,7 +117,7 @@ def render_in_jupyter(): # Is a custom object measure(model) else: - make_plot(model, measure) + components_matplotlib.make_plot(model, measure) def render_in_browser(): # if space drawer is disabled, do not include it @@ -182,7 +178,7 @@ def on_value_play(change): def do_step(): model.step() previous_step.value = current_step.value - current_step.value = model.schedule.steps + current_step.value += 1 def do_play(): model.running = True @@ -316,112 +312,6 @@ def change_handler(value, name=name): raise ValueError(f"{input_type} is not a supported input type") -@solara.component -def SpaceMatplotlib(model, agent_portrayal, dependencies: Optional[List[any]] = None): - space_fig = Figure() - space_ax = space_fig.subplots() - space = getattr(model, "grid", None) - if space is None: - # Sometimes the space is defined as model.space instead of model.grid - space = model.space - if isinstance(space, mesa.space.NetworkGrid): - _draw_network_grid(space, space_ax, agent_portrayal) - elif isinstance(space, mesa.space.ContinuousSpace): - _draw_continuous_space(space, space_ax, agent_portrayal) - else: - _draw_grid(space, space_ax, agent_portrayal) - space_ax.set_axis_off() - solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies) - - -def _draw_grid(space, space_ax, agent_portrayal): - def portray(g): - x = [] - y = [] - s = [] # size - c = [] # color - for i in range(g.width): - for j in range(g.height): - content = g._grid[i][j] - if not content: - continue - if not hasattr(content, "__iter__"): - # Is a single grid - content = [content] - for agent in content: - data = agent_portrayal(agent) - x.append(i) - y.append(j) - if "size" in data: - s.append(data["size"]) - if "color" in data: - c.append(data["color"]) - out = {"x": x, "y": y} - if len(s) > 0: - out["s"] = s - if len(c) > 0: - out["c"] = c - return out - - space_ax.scatter(**portray(space)) - - -def _draw_network_grid(space, space_ax, agent_portrayal): - graph = space.G - pos = nx.spring_layout(graph, seed=0) - nx.draw( - graph, - ax=space_ax, - pos=pos, - **agent_portrayal(graph), - ) - - -def _draw_continuous_space(space, space_ax, agent_portrayal): - def portray(space): - x = [] - y = [] - s = [] # size - c = [] # color - for agent in space._agent_to_index: - data = agent_portrayal(agent) - _x, _y = agent.pos - x.append(_x) - y.append(_y) - if "size" in data: - s.append(data["size"]) - if "color" in data: - c.append(data["color"]) - out = {"x": x, "y": y} - if len(s) > 0: - out["s"] = s - if len(c) > 0: - out["c"] = c - return out - - space_ax.scatter(**portray(space)) - - -def make_plot(model, measure): - fig = Figure() - ax = fig.subplots() - df = model.datacollector.get_model_vars_dataframe() - if isinstance(measure, str): - ax.plot(df.loc[:, measure]) - ax.set_ylabel(measure) - elif isinstance(measure, dict): - for m, color in measure.items(): - ax.plot(df.loc[:, m], label=m, color=color) - fig.legend() - elif isinstance(measure, (list, tuple)): - for m in measure: - ax.plot(df.loc[:, m], label=m) - fig.legend() - # Set integer x axis - ax.xaxis.set_major_locator(MaxNLocator(integer=True)) - solara.FigureMatplotlib(fig) - - def make_text(renderer): def function(model): solara.Markdown(renderer(model))