Skip to content

Commit

Permalink
Merge branch 'main' into ui_layout
Browse files Browse the repository at this point in the history
  • Loading branch information
ankitk50 authored Oct 30, 2023
2 parents e8fd953 + adc5549 commit c464a33
Showing 1 changed file with 45 additions and 10 deletions.
55 changes: 45 additions & 10 deletions mesa/experimental/jupyter_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,23 @@ def check_param_is_fixed(param):


def make_space(model, agent_portrayal):
space_fig = Figure()
space_ax = space_fig.subplots()
space = getattr(model, "grid", None)
if space is None:
# Sometimes the space is defined as model.space instead of model.grid
space = model.space
if isinstance(space, mesa.space.NetworkGrid):
_draw_network_grid(space, space_ax, agent_portrayal)
elif isinstance(space, mesa.space.ContinuousSpace):
_draw_continuous_space(space, space_ax, agent_portrayal)
else:
_draw_grid(space, space_ax, agent_portrayal)
space_ax.set_axis_off()
solara.FigureMatplotlib(space_fig, format="png")


def _draw_grid(space, space_ax, agent_portrayal):
def portray(g):
x = []
y = []
Expand Down Expand Up @@ -212,18 +229,11 @@ def portray(g):
out["c"] = c
return out

space_fig = Figure()
space_ax = space_fig.subplots()
if isinstance(model.grid, mesa.space.NetworkGrid):
_draw_network_grid(model, space_ax, agent_portrayal)
else:
space_ax.scatter(**portray(model.grid))
space_ax.set_axis_off()
solara.FigureMatplotlib(space_fig)
space_ax.scatter(**portray(space))


def _draw_network_grid(model, space_ax, agent_portrayal):
graph = model.grid.G
def _draw_network_grid(space, space_ax, agent_portrayal):
graph = space.G
pos = nx.spring_layout(graph, seed=0)
nx.draw(
graph,
Expand All @@ -233,6 +243,31 @@ def _draw_network_grid(model, space_ax, agent_portrayal):
)


def _draw_continuous_space(space, space_ax, agent_portrayal):
def portray(space):
x = []
y = []
s = [] # size
c = [] # color
for agent in space._agent_to_index:
data = agent_portrayal(agent)
_x, _y = agent.pos
x.append(_x)
y.append(_y)
if "size" in data:
s.append(data["size"])
if "color" in data:
c.append(data["color"])
out = {"x": x, "y": y}
if len(s) > 0:
out["s"] = s
if len(c) > 0:
out["c"] = c
return out

space_ax.scatter(**portray(space))


def make_plot(model, measure):
fig = Figure()
ax = fig.subplots()
Expand Down

0 comments on commit c464a33

Please sign in to comment.