Skip to content

Commit

Permalink
viz: Automatically visualize Cell color attributes in cell space
Browse files Browse the repository at this point in the history
This change adds automatic visualization of Cell color attributes in the Mesa cell space visualization. When a Cell has a color attribute, it will now be automatically rendered in the space visualization without requiring manual configuration.

Key changes:
- Update orthogonal grid drawing to show cell colors using imshow 
- Update hex grid drawing to show cell colors using PatchCollection
- Handle invalid colors gracefully with warnings
- Keep existing agent and grid line visualization behavior
- Add documentation with examples

Colors can be specified using any valid matplotlib color format (names, hex codes, RGB/RGBA tuples). The visualization maintains proper layering with cell colors appearing behind agents and grid lines.

Cell colors are now fully optional - cells without color attributes will be transparent/unfilled just like before.
  • Loading branch information
EwoutH authored Dec 20, 2024
1 parent 393f6a0 commit 95a6b65
Showing 1 changed file with 77 additions and 45 deletions.
122 changes: 77 additions & 45 deletions mesa/visualization/mpl_space_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def draw_orthogonal_grid(
draw_grid: bool = True,
**kwargs,
):
"""Visualize a orthogonal grid.
"""Visualize a orthogonal grid with automatic cell coloring.
Args:
space: the space to visualize
Expand All @@ -263,13 +263,32 @@ def draw_orthogonal_grid(
Returns:
Returns the Axes object with the plot drawn onto it.
``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
"size", "marker", and "zorder". Other field are ignored and will result in a user warning.
Cell colors will be automatically visualized if cells have a 'color' attribute. The color
attribute can be any valid matplotlib color specification (name, hex, RGB tuple, etc.).
"""
if ax is None:
fig, ax = plt.subplots()

# First draw cell colors if they exist
if hasattr(space, "all_cells"): # Check if it's a cell space
cell_colors = np.full((space.height, space.width, 4), [1, 1, 1, 0]) # Transparent default

for cell in space.all_cells:
if hasattr(cell, "color"):
x, y = cell.coordinate
try:
rgba_color = to_rgba(cell.color)
cell_colors[y, x] = rgba_color
except ValueError:
warnings.warn(
f"Invalid color value '{cell.color}' for cell at {cell.coordinate}",
UserWarning,
stacklevel=2,
)

# Plot the cell colors
ax.imshow(cell_colors, origin='lower', interpolation='nearest')

# gather agent data
s_default = (180 / max(space.width, space.height)) ** 2
arguments = collect_agent_data(space, agent_portrayal, size=s_default)
Expand All @@ -290,15 +309,14 @@ def draw_orthogonal_grid(

return ax


def draw_hex_grid(
space: HexGrid,
agent_portrayal: Callable,
ax: Axes | None = None,
draw_grid: bool = True,
**kwargs,
):
"""Visualize a hex grid.
"""Visualize a hex grid with automatic cell coloring.
Args:
space: the space to visualize
Expand All @@ -310,28 +328,56 @@ def draw_hex_grid(
Returns:
Returns the Axes object with the plot drawn onto it.
``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
"size", "marker", and "zorder". Other field are ignored and will result in a user warning.
Cell colors will be automatically visualized if cells have a 'color' attribute. The color
attribute can be any valid matplotlib color specification (name, hex, RGB tuple, etc.).
"""
if ax is None:
fig, ax = plt.subplots()

# gather data
# First create hexagons for cells if they exist
if hasattr(space, "all_cells"):
patches = []
offset = math.sqrt(0.75)

for cell in space.all_cells:
x, y = cell.coordinate
if y % 2 == 0:
x += 0.5
y *= offset

hex_patch = RegularPolygon(
(x, y),
numVertices=6,
radius=math.sqrt(1 / 3),
orientation=np.radians(120),
)

if hasattr(cell, "color"):
try:
hex_patch.set_facecolor(cell.color)
except ValueError:
warnings.warn(
f"Invalid color value '{cell.color}' for cell at {cell.coordinate}",
UserWarning,
stacklevel=2,
)
hex_patch.set_facecolor('none')
else:
hex_patch.set_facecolor('none')

patches.append(hex_patch)

# Add colored hexagons
cell_collection = PatchCollection(patches, match_original=True)
ax.add_collection(cell_collection)

# gather data for agents
s_default = (180 / max(space.width, space.height)) ** 2
arguments = collect_agent_data(space, agent_portrayal, size=s_default)

# for hexgrids we have to go from logical coordinates to visual coordinates
# this is a bit messy.

# give all even rows an offset in the x direction
# give all rows an offset in the y direction

# numbers here are based on a distance of 1 between centers of hexes
# Convert logical to visual coordinates for agents
offset = math.sqrt(0.75)

loc = arguments["loc"].astype(float)

logical = np.mod(loc[:, 1], 2) == 0
loc[:, 0][logical] += 0.5
loc[:, 1] *= offset
Expand All @@ -340,43 +386,29 @@ def draw_hex_grid(
# plot the agents
_scatter(ax, arguments, **kwargs)

# further styling and adding of grid
# further styling
ax.set_xlim(-1, space.width + 0.5)
ax.set_ylim(-offset, space.height * offset)

def setup_hexmesh(
width,
height,
):
"""Helper function for creating the hexmaesh."""
# fixme: this should be done once, rather than in each update
# fixme check coordinate system in hexgrid (see https://www.redblobgames.com/grids/hexagons/#coordinates-offset)

patches = []
for x, y in itertools.product(range(width), range(height)):
if draw_grid:
# Grid lines
grid_patches = []
for x, y in itertools.product(range(space.width), range(space.height)):
if y % 2 == 0:
x += 0.5 # noqa: PLW2901
y *= offset # noqa: PLW2901
hex = RegularPolygon(
x += 0.5
y *= offset
hex_patch = RegularPolygon(
(x, y),
numVertices=6,
radius=math.sqrt(1 / 3),
orientation=np.radians(120),
)
patches.append(hex)
mesh = PatchCollection(
patches, edgecolor="k", facecolor=(1, 1, 1, 0), linestyle="dotted", lw=1
grid_patches.append(hex_patch)
grid_collection = PatchCollection(
grid_patches, edgecolor="k", facecolor="none", linestyle="dotted", lw=1
)
return mesh
ax.add_collection(grid_collection)

if draw_grid:
# add grid
ax.add_collection(
setup_hexmesh(
space.width,
space.height,
)
)
return ax


Expand Down

0 comments on commit 95a6b65

Please sign in to comment.