From 24a1df30a12a399cc54d8ff2d578db57e0f37a4a Mon Sep 17 00:00:00 2001 From: rht Date: Thu, 17 Aug 2023 08:45:53 -0400 Subject: [PATCH] solara: Implement visualization for network grid (#1767) --- mesa/experimental/jupyter_viz.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/mesa/experimental/jupyter_viz.py b/mesa/experimental/jupyter_viz.py index fffcda5ebca..7a49d7aa64f 100644 --- a/mesa/experimental/jupyter_viz.py +++ b/mesa/experimental/jupyter_viz.py @@ -1,11 +1,14 @@ import threading import matplotlib.pyplot as plt +import networkx as nx import reacton.ipywidgets as widgets import solara from matplotlib.figure import Figure from matplotlib.ticker import MaxNLocator +import mesa + # Avoid interactive backend plt.switch_backend("agg") @@ -91,10 +94,24 @@ def portray(self, g): return out +def _draw_network_grid(viz, space_ax): + graph = viz.model.grid.G + pos = nx.spring_layout(graph, seed=0) + nx.draw( + graph, + ax=space_ax, + pos=pos, + **viz.agent_portrayal(graph), + ) + + def make_space(viz): space_fig = Figure() space_ax = space_fig.subplots() - space_ax.scatter(**viz.portray(viz.model.grid)) + if isinstance(viz.model.grid, mesa.space.NetworkGrid): + _draw_network_grid(viz, space_ax) + else: + space_ax.scatter(**viz.portray(viz.model.grid)) space_ax.set_axis_off() solara.FigureMatplotlib(space_fig, dependencies=[viz.model, viz.df])