diff --git a/mesa/experimental/altair_grid.py b/mesa/experimental/altair_grid.py new file mode 100644 index 00000000000..eaa423aab37 --- /dev/null +++ b/mesa/experimental/altair_grid.py @@ -0,0 +1,99 @@ +from typing import Callable, Optional + +import altair as alt +import solara + +import mesa + + +def get_agent_data_from_coord_iter(data): + """ + Extracts agent data from a sequence of tuples containing agent objects and their coordinates. + + Parameters: + - data (iterable): A sequence of tuples where each tuple contains an agent object and its coordinates. + + Yields: + - dict: A dictionary containing agent data with updated coordinates. The dictionary excludes 'model' and 'pos' attributes. + """ + for agent, (x, y) in data: + if agent: + agent_data = agent[0].__dict__.copy() + agent_data.update({"x": x, "y": y}) + agent_data.pop("model", None) + agent_data.pop("pos", None) + yield agent_data + + +def create_grid( + color: Optional[str] = None, + on_click: Optional[Callable[[mesa.Model, mesa.space.Coordinate], None]] = None, +) -> Callable[[mesa.Model], solara.component]: + """ + Factory function for creating a grid component for a Mesa model. + + Parameters: + - color (Optional[str]): Color of the grid lines. Defaults to None. + - on_click (Optional[Callable[[mesa.Model, mesa.space.Coordinate], None]]): + Function to be called when a grid cell is clicked. Defaults to None. + + Returns: + - Callable[[mesa.Model], solara.component]: A function that creates a grid component for the given model. + """ + + def create_grid_function(model: mesa.Model) -> solara.component: + return Grid(model, color, on_click) + + return create_grid_function + + +def Grid(model, color=None, on_click=None): + """ + Handles click events on grid cells. + + Parameters: + - datum (dict): Data associated with the clicked cell. + + Notes: + - Invokes the provided `on_click` function with the model and cell coordinates. + - Updates the data displayed on the grid. + """ + if color is None: + color = "unique_id:N" + + if color[-2] != ":": + color = color + ":N" + + print(model.grid.coord_iter()) + + data = solara.reactive( + list(get_agent_data_from_coord_iter(model.grid.coord_iter())) + ) + + def update_data(): + data.value = list(get_agent_data_from_coord_iter(model.grid.coord_iter())) + + def click_handler(datum): + if datum is None: + return + on_click(model, datum["x"], datum["y"]) + update_data() + + default_tooltip = [ + f"{key}:N" for key in data.value[0] + ] # add all agent attributes to tooltip + chart = ( + alt.Chart(alt.Data(values=data.value)) + .mark_rect() + .encode( + x=alt.X("x:N", scale=alt.Scale(domain=list(range(model.grid.width)))), + y=alt.Y( + "y:N", + scale=alt.Scale(domain=list(range(model.grid.height - 1, -1, -1))), + ), + color=color, + tooltip=default_tooltip, + ) + .properties(width=600, height=600) + ) + return solara.FigureAltair(chart, on_click=click_handler) diff --git a/mesa/experimental/jupyter_viz.py b/mesa/experimental/jupyter_viz.py index a6ae318a822..9f371fc6fa0 100644 --- a/mesa/experimental/jupyter_viz.py +++ b/mesa/experimental/jupyter_viz.py @@ -1,5 +1,6 @@ import sys import threading +from typing import Optional import matplotlib.pyplot as plt import reacton.ipywidgets as widgets @@ -7,6 +8,7 @@ from solara.alias import rv import mesa.experimental.components.matplotlib as components_matplotlib +from mesa.experimental.altair_grid import create_grid from mesa.experimental.UserParam import Slider # Avoid interactive backend @@ -113,6 +115,9 @@ def render_in_jupyter(): components_matplotlib.SpaceMatplotlib( model, agent_portrayal, dependencies=[current_step.value] ) + elif space_drawer == "altair": + # draw with the default implementation + SpaceAltair(model, agent_portrayal, dependencies=[current_step.value]) elif space_drawer: # if specified, draw agent space with an alternate renderer space_drawer(model, agent_portrayal) @@ -128,7 +133,7 @@ def render_in_jupyter(): model, measure, dependencies=[current_step.value] ) - def render_in_browser(): + def render_in_browser(statistics=False): # if space drawer is disabled, do not include it layout_types = [{"Space": "default"}] if space_drawer else [] @@ -144,6 +149,13 @@ def render_in_browser(): ModelController(model, play_interval, current_step, reset_counter) with solara.Card("Progress", margin=1, elevation=2): solara.Markdown(md_text=f"####Step - {current_step}") + with solara.Card("Analytics", margin=1, elevation=2): + if statistics: + df = model.datacollector.get_model_vars_dataframe() + for col in list(df.columns): + solara.Markdown( + md_text=f"####Avg. {col} - {df.loc[:, f'{col}'].mean()}" + ) items = [ Card( @@ -345,6 +357,12 @@ def change_handler(value, name=name): raise ValueError(f"{input_type} is not a supported input type") +@solara.component +def SpaceAltair(model, agent_portrayal, dependencies: Optional[list[any]] = None): + grid = create_grid(color="wealth") + grid(model) + + def make_text(renderer): def function(model): solara.Markdown(renderer(model)) @@ -356,7 +374,7 @@ def get_initial_grid_layout(layout_types): grid_lay = [] y_coord = 0 for ii in range(len(layout_types)): - template_layout = {"h": 10, "i": 0, "moved": False, "w": 6, "y": 0, "x": 0} + template_layout = {"h": 20, "i": 0, "moved": False, "w": 6, "y": 0, "x": 0} if ii == 0: grid_lay.append(template_layout) else: diff --git a/pyproject.toml b/pyproject.toml index 0233bc5874b..a4e8df9f2e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "pandas", "solara", "tqdm", + "altair" ] dynamic = ["version"] diff --git a/tests/test_jupyter_viz.py b/tests/test_jupyter_viz.py index 248e8c319d5..ae0b324c0fa 100644 --- a/tests/test_jupyter_viz.py +++ b/tests/test_jupyter_viz.py @@ -92,6 +92,7 @@ def test_call_space_drawer(self, mock_space_matplotlib): } current_step = 0 dependencies = [current_step] + # initialize with space drawer unspecified (use default) # component must be rendered for code to run solara.render( @@ -99,8 +100,10 @@ def test_call_space_drawer(self, mock_space_matplotlib): model_class=mock_model_class, model_params={}, agent_portrayal=agent_portrayal, + space_drawer="default", ) ) + # should call default method with class instance and agent portrayal mock_space_matplotlib.assert_called_with( mock_model_class.return_value, agent_portrayal, dependencies=dependencies