diff --git a/mesa/experimental/components/altair.py b/mesa/experimental/components/altair.py new file mode 100644 index 00000000000..cf597460b36 --- /dev/null +++ b/mesa/experimental/components/altair.py @@ -0,0 +1,66 @@ +import contextlib +from typing import Optional + +import pandas as pd +import solara + +with contextlib.suppress(ImportError): + import altair as alt + +@solara.component +def SpaceAltair(model, agent_portrayal, dependencies: Optional[list[any]] = None): + space = getattr(model, "grid", None) + if space is None: + # Sometimes the space is defined as model.space instead of model.grid + space = model.space + chart = _draw_grid(space, agent_portrayal) + solara.FigureAltair(chart) + +# _draw_grid is derived from +# https://github.com/Princeton-CDH/simulating-risk/blob/907c290e12c97b28aa9ce9c80ea7fc52a4f280ae/simulatingrisk/hawkdove/server.py#L114-L171 +# Copyright 2023 The Center for Digital Humanities of Princeton +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +def _draw_grid(space, agent_portrayal): + def portray(g): + all_agent_data = [] + for content, (x, y) in space.coord_iter(): + if not content: + continue + if not hasattr(content, "__iter__"): + # Is a single grid + content = [content] + for agent in content: + # use all data from agent portrayal, and add x,y coordinates + agent_data = agent_portrayal(agent) + agent_data["x"] = x + agent_data["y"] = y + all_agent_data.append(agent_data) + return all_agent_data + + all_agent_data = portray(space) + df = pd.DataFrame(all_agent_data) + chart = ( + alt.Chart(df) + .mark_point(filled=True) + .encode( + # no x-axis label + x=alt.X("x", axis=None), + # no y-axis label + y=alt.Y("y", axis=None), + size=alt.Size("size"), + color=alt.Color("color"), + ) + # .configure_view(strokeOpacity=0) # hide grid/chart lines + ) + return chart diff --git a/mesa/experimental/jupyter_viz.py b/mesa/experimental/jupyter_viz.py index 7a586239058..adc4030f785 100644 --- a/mesa/experimental/jupyter_viz.py +++ b/mesa/experimental/jupyter_viz.py @@ -7,6 +7,7 @@ from solara.alias import rv import mesa.experimental.components.matplotlib as components_matplotlib +import mesa.experimental.components.altair as components_altair # Avoid interactive backend plt.switch_backend("agg") @@ -74,6 +75,10 @@ def ColorCard(color, layout_type): components_matplotlib.SpaceMatplotlib( model, agent_portrayal, dependencies=[current_step.value] ) + elif space_drawer == "altair": + components_altair.SpaceAltair( + model, agent_portrayal, dependencies=[current_step.value] + ) elif space_drawer: # if specified, draw agent space with an alternate renderer space_drawer(model, agent_portrayal)