diff --git a/.gitignore b/.gitignore index 4ad12d7e..04928f82 100644 --- a/.gitignore +++ b/.gitignore @@ -268,7 +268,7 @@ target/ # Jupyter Notebook .ipynb_checkpoints -.virtual_documents +*.virtual_documents # pyenv .python-version diff --git a/mesa_geo/__init__.py b/mesa_geo/__init__.py index 8021bd5c..a89c0915 100644 --- a/mesa_geo/__init__.py +++ b/mesa_geo/__init__.py @@ -6,7 +6,6 @@ import datetime -from mesa_geo import visualization from mesa_geo.geoagent import AgentCreator, GeoAgent from mesa_geo.geospace import GeoSpace from mesa_geo.raster_layers import Cell, ImageLayer, RasterLayer diff --git a/mesa_geo/geoagent.py b/mesa_geo/geoagent.py index 73887cd2..2a142043 100644 --- a/mesa_geo/geoagent.py +++ b/mesa_geo/geoagent.py @@ -12,7 +12,7 @@ import geopandas as gpd import numpy as np import pyproj -from mesa import Agent +from mesa import Agent, Model from shapely.geometry import mapping from shapely.geometry.base import BaseGeometry from shapely.ops import transform @@ -153,6 +153,9 @@ def create_agent(self, geometry, unique_id): f"Unable to set CRS for {self.agent_class.__name__} due to empty CRS in {self.__class__.__name__}" ) + if not isinstance(self.model, Model): + raise ValueError("Model must be a valid Mesa model object") + new_agent = self.agent_class( unique_id=unique_id, model=self.model, diff --git a/mesa_geo/raster_layers.py b/mesa_geo/raster_layers.py index 09c287cd..98406b6c 100644 --- a/mesa_geo/raster_layers.py +++ b/mesa_geo/raster_layers.py @@ -164,7 +164,7 @@ class Cell(Agent): pos: Coordinate | None indices: Coordinate | None - def __init__(self, pos=None, indices=None): + def __init__(self, model, pos=None, indices=None): """ Initialize a cell. @@ -174,7 +174,7 @@ def __init__(self, pos=None, indices=None): Origin is at upper left corner of the grid """ - super().__init__(uuid.uuid4().int, None) + super().__init__(uuid.uuid4().int, model) self.pos = pos self.indices = indices @@ -218,15 +218,18 @@ class RasterLayer(RasterBase): _neighborhood_cache: dict[Any, list[Coordinate]] _attributes: set[str] - def __init__(self, width, height, crs, total_bounds, cell_cls: type[Cell] = Cell): + def __init__( + self, width, height, crs, total_bounds, model, cell_cls: type[Cell] = Cell + ): super().__init__(width, height, crs, total_bounds) + self.model = model self.cell_cls = cell_cls self.cells = [] for x in range(self.width): col: list[cell_cls] = [] for y in range(self.height): row_idx, col_idx = self.height - y - 1, x - col.append(self.cell_cls(pos=(x, y), indices=(row_idx, col_idx))) + col.append(self.cell_cls(model, pos=(x, y), indices=(row_idx, col_idx))) self.cells.append(col) self._attributes = set() diff --git a/mesa_geo/tile_layers.py b/mesa_geo/tile_layers.py index 1ec4716a..6defe516 100644 --- a/mesa_geo/tile_layers.py +++ b/mesa_geo/tile_layers.py @@ -27,7 +27,7 @@ class RasterWebTile: kind: str = "raster_web_tile" @classmethod - def from_xyzservices(cls, provider: xyzservices.TileProvider) -> RasterWebTile: + def from_xyzservices(cls, provider=xyzservices.TileProvider) -> RasterWebTile: """ Create a RasterWebTile from an xyzservices TileProvider. diff --git a/mesa_geo/visualization/ModularVisualization.py b/mesa_geo/visualization/ModularVisualization.py deleted file mode 100644 index 6fcf9d85..00000000 --- a/mesa_geo/visualization/ModularVisualization.py +++ /dev/null @@ -1,24 +0,0 @@ -import warnings - -import mesa - - -class ModularServer(mesa.visualization.ModularServer): - def __init__( - self, - model_cls, - visualization_elements, - name="Mesa Model", - model_params=None, - port=None, - ): - super().__init__(model_cls, visualization_elements, name, model_params, port) - - def launch(self, port=None, open_browser=True): - warnings.warn( - "Importing ModularServer from mesa_geo is deprecated, and will be removed in a future release. " - "Import from mesa instead.", - DeprecationWarning, - stacklevel=2, - ) - super().launch(port, open_browser) diff --git a/mesa_geo/visualization/__init__.py b/mesa_geo/visualization/__init__.py index 6d5f0293..256d887d 100644 --- a/mesa_geo/visualization/__init__.py +++ b/mesa_geo/visualization/__init__.py @@ -1 +1,5 @@ -from mesa_geo.visualization.modules import * # noqa +# Import specific classes or functions from the modules +from mesa_geo.visualization.geojupyter_viz import GeoJupyterViz +from mesa_geo.visualization.leaflet_viz import LeafletViz + +__all__ = ["GeoJupyterViz", "LeafletViz"] diff --git a/mesa_geo/visualization/geojupyter_viz.py b/mesa_geo/visualization/geojupyter_viz.py new file mode 100644 index 00000000..5002165f --- /dev/null +++ b/mesa_geo/visualization/geojupyter_viz.py @@ -0,0 +1,211 @@ +import sys + +import matplotlib.pyplot as plt +import mesa.experimental.components.matplotlib as components_matplotlib +import solara +import xyzservices.providers as xyz +from mesa.experimental import jupyter_viz as jv +from solara.alias import rv + +import mesa_geo.visualization.leaflet_viz as leaflet_viz + +# Avoid interactive backend +plt.switch_backend("agg") + + +# TODO: Turn this function into a Solara component once the current_step.value +# dependency is passed to measure() +""" +Geo-Mesa Visualization Module +============================= +Card: Helper Function that initiates the Solara Card for Browser +GeoJupyterViz: Main Function users employ to create visualization +""" + + +def Card( + model, + measures, + agent_portrayal, + map_drawer, + center_default, + zoom, + current_step, + color, + layout_type, +): + with rv.Card( + style_=f"background-color: {color}; width: 100%; height: 100%" + ) as main: + if "Map" in layout_type: + rv.CardTitle(children=["Map"]) + leaflet_viz.map(model, map_drawer, zoom, center_default) + + if "Measure" in layout_type: + rv.CardTitle(children=["Measure"]) + measure = measures[layout_type["Measure"]] + if callable(measure): + # Is a custom object + measure(model) + else: + components_matplotlib.PlotMatplotlib( + model, measure, dependencies=[current_step.value] + ) + return main + + +@solara.component +def GeoJupyterViz( + model_class, + model_params, + measures=None, + name=None, + agent_portrayal=None, + play_interval=150, + # parameters for leaflet_viz + view=None, + zoom=None, + tiles=xyz.OpenStreetMap.Mapnik, + center_point=None, # Due to projection challenges in calculation allow user to specify center point +): + """Initialize a component to visualize a model. + Args: + model_class: class of the model to instantiate + model_params: parameters for initializing the model + measures: list of callables or data attributes to plot + name: name for display + agent_portrayal: options for rendering agents (dictionary) + space_drawer: method to render the agent space for + the model; default implementation is the `SpaceMatplotlib` component; + simulations with no space to visualize should + specify `space_drawer=False` + play_interval: play interval (default: 150) + center_point: list of center coords + """ + if name is None: + name = model_class.__name__ + + current_step = solara.use_reactive(0) + + # 1. Set up model parameters + user_params, fixed_params = jv.split_model_params(model_params) + model_parameters, set_model_parameters = solara.use_state( + {**fixed_params, **{k: v.get("value") for k, v in user_params.items()}} + ) + + # 2. Set up Model + def make_model(): + model = model_class(**model_parameters) + current_step.value = 0 + return model + + reset_counter = solara.use_reactive(0) + model = solara.use_memo( + make_model, dependencies=[*list(model_parameters.values()), reset_counter.value] + ) + + def handle_change_model_params(name: str, value: any): + set_model_parameters({**model_parameters, name: value}) + + # 3. Set up UI + with solara.AppBar(): + solara.AppBarTitle(name) + + # 4. Set Up Map + # render layout, pass through map build parameters + map_drawer = leaflet_viz.MapModule( + portrayal_method=agent_portrayal, + view=view, + zoom=zoom, + tiles=tiles, + ) + layers = map_drawer.render(model) + + # determine center point + if center_point: + center_default = center_point + else: + bounds = layers["layers"]["total_bounds"] + center_default = list((bounds[2:] + bounds[:2]) / 2) + + def render_in_jupyter(): + # TODO: Build API to allow users to set rows and columns + # call in property of model layers geospace line; use 1 column to prevent map overlap + + with solara.Row( + justify="space-between", style={"flex-grow": "1"} + ) and solara.GridFixed(columns=2): + jv.UserInputs(user_params, on_change=handle_change_model_params) + jv.ModelController(model, play_interval, current_step, reset_counter) + solara.Markdown(md_text=f"###Step - {current_step}") + + # Builds Solara component of map + leaflet_viz.map_jupyter(model, map_drawer, zoom, center_default) + + # Place measurement in separate row + with solara.Row( + justify="space-between", + style={"flex-grow": "1"}, + ): + # 5. Plots + for measure in measures: + if callable(measure): + # Is a custom object + measure(model) + else: + components_matplotlib.PlotMatplotlib( + model, measure, dependencies=[current_step.value] + ) + + def render_in_browser(): + # determine center point + if center_point: + center_default = center_point + else: + bounds = layers["layers"]["total_bounds"] + center_default = list((bounds[2:] + bounds[:2]) / 2) + + # if space drawer is disabled, do not include it + layout_types = [{"Map": "default"}] + + if measures: + layout_types += [{"Measure": elem} for elem in range(len(measures))] + + grid_layout_initial = jv.make_initial_grid_layout(layout_types=layout_types) + grid_layout, set_grid_layout = solara.use_state(grid_layout_initial) + + with solara.Sidebar(): + with solara.Card("Controls", margin=1, elevation=2): + jv.UserInputs(user_params, on_change=handle_change_model_params) + jv.ModelController(model, play_interval, current_step, reset_counter) + with solara.Card("Progress", margin=1, elevation=2): + solara.Markdown(md_text=f"####Step - {current_step}") + + items = [ + Card( + model, + measures, + agent_portrayal, + map_drawer, + center_default, + zoom, + current_step, + color="white", + layout_type=layout_types[i], + ) + for i in range(len(layout_types)) + ] + + solara.GridDraggable( + items=items, + grid_layout=grid_layout, + resizable=True, + draggable=True, + on_grid_layout=set_grid_layout, + ) + + if ("ipykernel" in sys.argv[0]) or ("colab_kernel_launcher.py" in sys.argv[0]): + # When in Jupyter or Google Colab + render_in_jupyter() + else: + render_in_browser() diff --git a/mesa_geo/visualization/modules/MapVisualization.py b/mesa_geo/visualization/leaflet_viz.py similarity index 81% rename from mesa_geo/visualization/modules/MapVisualization.py rename to mesa_geo/visualization/leaflet_viz.py index f8b10b2e..c09cb534 100644 --- a/mesa_geo/visualization/modules/MapVisualization.py +++ b/mesa_geo/visualization/leaflet_viz.py @@ -1,22 +1,65 @@ -from __future__ import annotations +""" +# ipyleaflet +Map visualization using [ipyleaflet](https://ipyleaflet.readthedocs.io/), a ipywidgets wrapper for [leaflet.js](https://leafletjs.com/) +""" import dataclasses from dataclasses import dataclass -from pathlib import Path import geopandas as gpd +import ipyleaflet +import solara import xyzservices -import xyzservices.providers as xyz from folium.utilities import image_to_url -from mesa.visualization.ModularVisualization import VisualizationElement from shapely.geometry import Point, mapping from mesa_geo.raster_layers import RasterBase, RasterLayer from mesa_geo.tile_layers import LeafletOption, RasterWebTile +@solara.component +def map(model, map_drawer, zoom, center_default): + # render map in browser + zoom_map = solara.reactive(zoom) + center = solara.reactive(center_default) + + base_map = map_drawer.tiles + layers = map_drawer.render(model) + + ipyleaflet.Map.element( + zoom=zoom_map.value, + center=center.value, + scroll_wheel_zoom=True, + layers=[ + ipyleaflet.TileLayer.element(url=base_map["url"]), + ipyleaflet.GeoJSON.element(data=layers["agents"]), + ], + ) + + +@solara.component +def map_jupyter(model, map_drawer, zoom, center_default): + zoom_map = solara.reactive(zoom) + center = solara.reactive(center_default) + + base_map = map_drawer.tiles + layers = map_drawer.render(model) + + # prevents overlap of map with measures + with solara.Column(style={"isolation": "isolate"}): + ipyleaflet.Map.element( + zoom=zoom_map.value, + center=center.value, + scroll_wheel_zoom=True, + layers=[ + ipyleaflet.TileLayer.element(url=base_map["url"]), + ipyleaflet.GeoJSON.element(data=layers["agents"]), + ], + ) + + @dataclass -class LeafletPortrayal: +class LeafletViz: """A dataclass defining the portrayal of a GeoAgent in Leaflet map. The fields are defined to be consistent with GeoJSON options in @@ -28,7 +71,7 @@ class LeafletPortrayal: popupProperties: dict[str, LeafletOption] | None = None # noqa: N815 -class MapModule(VisualizationElement): +class MapModule: """A MapModule for Leaflet maps that uses a user-defined portrayal method to generate a portrayal of a raster Cell or a GeoAgent. @@ -40,22 +83,12 @@ class MapModule(VisualizationElement): - In addition, the portrayal dictionary can contain a "description" key, which will be used as the popup text. """ - local_includes = [ - "js/MapModule.js", - "css/external/leaflet.css", - "js/external/leaflet.js", - ] - local_dir = (Path(__file__).parent / "../templates").resolve() - def __init__( self, - portrayal_method=None, - view=None, - zoom=None, - map_width=500, - map_height=500, - tiles=xyz.OpenStreetMap.Mapnik, - scale_options=None, + portrayal_method, + view, + zoom, + tiles, ): """ Create a new MapModule. @@ -112,16 +145,8 @@ def __init__( self._crs = "epsg:4326" if isinstance(tiles, xyzservices.TileProvider): - tiles = RasterWebTile.from_xyzservices(tiles) - tiles_js = tiles.to_dict() if tiles is not None else None - new_element = f"new MapModule({view}, {zoom}, {map_width}, {map_height}, {tiles_js}, {scale_options})" - self.js_code = f"elements.push({new_element});" - for py_str, js_str in { - "None": "null", - "True": "true", - "False": "false", - }.items(): - self.js_code = self.js_code.replace(py_str, js_str) + tiles = RasterWebTile.from_xyzservices(tiles).to_dict() + self.tiles = tiles def render(self, model): return { @@ -167,7 +192,7 @@ def _render_agents(self, model): agent_portrayal = {} if self.portrayal_method: properties = self.portrayal_method(agent) - agent_portrayal = LeafletPortrayal( + agent_portrayal = LeafletViz( popupProperties=properties.pop("description", None) ) if isinstance(agent.geometry, Point): diff --git a/mesa_geo/visualization/modules/__init__.py b/mesa_geo/visualization/modules/__init__.py deleted file mode 100644 index a4afab58..00000000 --- a/mesa_geo/visualization/modules/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -Container for all built-in visualization modules. -""" - -from mesa_geo.visualization.modules.MapVisualization import MapModule # noqa diff --git a/mesa_geo/visualization/templates/js/MapModule.js b/mesa_geo/visualization/templates/js/MapModule.js deleted file mode 100644 index 1ba222e6..00000000 --- a/mesa_geo/visualization/templates/js/MapModule.js +++ /dev/null @@ -1,90 +0,0 @@ -const MapModule = function (view, zoom, map_width, map_height, tiles, scale_options) { - // Create the map tag - const map_tag = document.createElement("div"); - map_tag.style.width = map_width + "px"; - map_tag.style.height = map_height + "px"; - map_tag.style.border = "1px dotted"; - map_tag.id = "mapid" - const customView = (view !== null && zoom !== null) - - // Append it to #elements - const elements = document.getElementById("elements"); - elements.appendChild(map_tag); - - // Create Leaflet map and Agent layers - const Lmap = L.map('mapid', {zoomSnap: 0.1}) - if (customView) { - Lmap.setView(view, zoom) - } - if (scale_options !== null) { - L.control.scale(scale_options).addTo(Lmap) - } - let agentLayer = L.geoJSON().addTo(Lmap) - - // create tile layer - if (tiles !== null) { - if (tiles.kind === "raster_web_tile") { - L.tileLayer(tiles.url, tiles.options).addTo(Lmap) - } else if (tiles.kind === "wms_web_tile") { - L.tileLayer.wms(tiles.url, tiles.options).addTo(Lmap) - } else { - throw new Error("Unknown tile type: " + tiles.kind) - } - } - - let mapLayers = [] - let hasFitBounds = false - this.renderLayers = function (layers) { - mapLayers.forEach(layer => {layer.remove()}) - mapLayers = [] - - layers.rasters.forEach(function (layer) { - const rasterLayer = L.imageOverlay(layer, layers.total_bounds).addTo(Lmap) - mapLayers.push(rasterLayer) - }) - layers.vectors.forEach(function (layer) { - const vectorLayer = L.geoJSON(layer).addTo(Lmap) - mapLayers.push(vectorLayer) - }) - if (!hasFitBounds && !customView && layers.total_bounds.length !== 0) { - Lmap.fitBounds(layers.total_bounds) - hasFitBounds = true - } - } - - this.renderAgents = function (agents) { - agentLayer.remove() - agentLayer = L.geoJSON(agents, { - onEachFeature: PopUpProperties, - style: function (feature) { - return feature.properties.style - }, - pointToLayer: function (feature, latlang) { - return L.circleMarker(latlang, feature.properties.pointToLayer); - } - }).addTo(Lmap) - } - - this.render = function (data) { - this.renderLayers(data.layers) - this.renderAgents(data.agents) - } - - this.reset = function () { - agentLayer.remove() - mapLayers.forEach(layer => {layer.remove()}) - mapLayers = [] - } -} - - -function PopUpProperties(feature, layer) { - let popupContent = '
' + p + ' | ' + feature.properties.popupProperties[p] + ' |