From 86293108d98492f0cc45420338fb470cf8f1ed15 Mon Sep 17 00:00:00 2001 From: Robert Hopkins Date: Wed, 21 Aug 2024 13:16:19 -0400 Subject: [PATCH 01/10] Update matplotlib.py --- mesa/visualization/components/matplotlib.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index 83d0e3d8eaf..ca1f94be34c 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -22,6 +22,7 @@ def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = Non _draw_continuous_space(space, space_ax, agent_portrayal) else: _draw_grid(space, space_ax, agent_portrayal) + solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies) @@ -94,6 +95,7 @@ def portray(g): _split_and_scatter(portray(space), space_ax) +# draws using networkx's matplotlib integration def _draw_network_grid(space, space_ax, agent_portrayal): graph = space.G pos = nx.spring_layout(graph, seed=0) From 2cc3f105a3b89dcc1ce68dea9f2ed78e069aa637 Mon Sep 17 00:00:00 2001 From: Robert Hopkins Date: Thu, 22 Aug 2024 11:35:20 -0400 Subject: [PATCH 02/10] matplotlib visualization supports more params No longer are shape, color, and marker explicitly implemented. Instead, the implementation is more parameter-independent. --- mesa/visualization/components/matplotlib.py | 164 ++++++++++---------- 1 file changed, 86 insertions(+), 78 deletions(-) diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index ca1f94be34c..8de5eae40d8 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -4,6 +4,7 @@ import solara from matplotlib.figure import Figure from matplotlib.ticker import MaxNLocator +from matplotlib.pyplot import get_cmap import mesa @@ -27,71 +28,80 @@ def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = Non # matplotlib scatter does not allow for multiple shapes in one call -def _split_and_scatter(portray_data, space_ax): - grouped_data = defaultdict(lambda: {"x": [], "y": [], "s": [], "c": []}) - - # Extract data from the dictionary - x = portray_data["x"] - y = portray_data["y"] - s = portray_data["s"] - c = portray_data["c"] - m = portray_data["m"] - - if not (len(x) == len(y) == len(s) == len(c) == len(m)): - raise ValueError( - "Length mismatch in portrayal data lists: " - f"x: {len(x)}, y: {len(y)}, size: {len(s)}, " - f"color: {len(c)}, marker: {len(m)}" - ) - - # Group the data by marker - for i in range(len(x)): - marker = m[i] - grouped_data[marker]["x"].append(x[i]) - grouped_data[marker]["y"].append(y[i]) - grouped_data[marker]["s"].append(s[i]) - grouped_data[marker]["c"].append(c[i]) - - # Plot each group with the same marker - for marker, data in grouped_data.items(): - space_ax.scatter(data["x"], data["y"], s=data["s"], c=data["c"], marker=marker) +def _split_and_scatter(space_ax, portray_data) -> None: + cmap = portray_data.pop("cmap", None) + + # enforce marker iterability + markers = portray_data.pop("marker", ["o"] * len(portray_data["x"])) + + # enforce default color + # if no 'color' or 'facecolor' or 'c' then default to "tab:blue" color + if ( + "color" not in portray_data + and "facecolor" not in portray_data + and "c" not in portray_data + ): + portray_data["color"] = ["tab:blue"] * len(portray_data["x"]) + + grouped_data = defaultdict(lambda: {key: [] for key in portray_data}) + + for i, marker in enumerate(markers): + for key in portray_data: + # apply colormap + if cmap and key == "c": + color = portray_data[key][i] + + # TODO: break into helper functions for readability + # apply color map if not RGB(A) format or color string (mimicking default matplotlib behavior) + if not ( + isinstance(color, str) # str format + or ( + (len(color) == 3 or len(color) == 4) # RGB(A) + and ( + all( + # all floats, valid RGB(A) + isinstance(c, (int, float)) and 0 <= c <= 1 + for c in color + ) + ) + ) + ): + color = get_cmap(cmap[i])(color) + + grouped_data[marker][key].append(color) + elif key != "cmap": # do nothing special, don't pass on color maps + grouped_data[marker][key].append(portray_data[key][i]) + + print(grouped_data) def _draw_grid(space, space_ax, agent_portrayal): def portray(g): - x = [] - y = [] - s = [] # size - c = [] # color - m = [] # shape - for i in range(g.width): - for j in range(g.height): - content = g._grid[i][j] - if not content: - continue - if not hasattr(content, "__iter__"): - # Is a single grid - content = [content] - for agent in content: - data = agent_portrayal(agent) - x.append(i) - y.append(j) - - # This is the default value for the marker size, which auto-scales - # according to the grid area. - default_size = (180 / max(g.width, g.height)) ** 2 - # establishing a default prevents misalignment if some agents are not given size, color, etc. - size = data.get("size", default_size) - s.append(size) - color = data.get("color", "tab:blue") - c.append(color) - mark = data.get("shape", "o") - m.append(mark) - out = {"x": x, "y": y, "s": s, "c": c, "m": m} + + default_values = { + "size": (180 / max(g.width, g.height)) ** 2, + } + + out = {} + + # used to initialize lists for alignment purposes + num_agents = len(space._agent_to_index) + + for i, agent in enumerate(space._agent_to_index): + data = agent_portrayal(agent) + + for key, value in data.items(): + if key not in out: + # initialize list + out[key] = [default_values.get(key, default=None)] * num_agents + out[key][i] = value + return out space_ax.set_xlim(-1, space.width) space_ax.set_ylim(-1, space.height) + + # portray and scatter the agents in the space _split_and_scatter(portray(space), space_ax) @@ -109,27 +119,25 @@ def _draw_network_grid(space, space_ax, agent_portrayal): def _draw_continuous_space(space, space_ax, agent_portrayal): def portray(space): - x = [] - y = [] - s = [] # size - c = [] # color - m = [] # shape - for agent in space._agent_to_index: + + default_values = { + "size": 20, + } + + out = {} + + # used to initialize lists for alignment purposes + num_agents = len(space._agent_to_index) + + for i, agent in enumerate(space._agent_to_index): data = agent_portrayal(agent) - _x, _y = agent.pos - x.append(_x) - y.append(_y) - - # This is matplotlib's default marker size - default_size = 20 - # establishing a default prevents misalignment if some agents are not given size, color, etc. - size = data.get("size", default_size) - s.append(size) - color = data.get("color", "tab:blue") - c.append(color) - mark = data.get("shape", "o") - m.append(mark) - out = {"x": x, "y": y, "s": s, "c": c, "m": m} + + for key, value in data.items(): + if key not in out: + # initialize list + out[key] = [default_values.get(key, default=None)] * num_agents + out[key][i] = value + return out # Determine border style based on space.torus @@ -148,7 +156,7 @@ def portray(space): space_ax.set_xlim(space.x_min - x_padding, space.x_max + x_padding) space_ax.set_ylim(space.y_min - y_padding, space.y_max + y_padding) - # Portray and scatter the agents in the space + # portray and scatter the agents in the space _split_and_scatter(portray(space), space_ax) From bc67c60f87e099085bf14bec417ecf76c489fece Mon Sep 17 00:00:00 2001 From: Robert Hopkins Date: Thu, 22 Aug 2024 17:17:01 -0400 Subject: [PATCH 03/10] old param keywords, x,y pos old parameter keywords now work for backwards compatibility. x and y position are now added and the code definitely does work. --- mesa/visualization/components/matplotlib.py | 77 +++++++++++++++------ 1 file changed, 55 insertions(+), 22 deletions(-) diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index 8de5eae40d8..08a1beb5402 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -27,8 +27,15 @@ def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = Non solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies) +# used to make non(less?)-breaking change +# this *does* however block the matplotlib 'color' param which is distinct from 'c'. +def _translate_old_keywords(dict): + key_mapping: dict[str, str] = {"size": "s", "color": "c", "shape": "marker"} + return {key_mapping.get(key, key): val for (key, val) in dict.items()} + + # matplotlib scatter does not allow for multiple shapes in one call -def _split_and_scatter(space_ax, portray_data) -> None: +def _split_and_scatter(portray_data, space_ax) -> None: cmap = portray_data.pop("cmap", None) # enforce marker iterability @@ -51,16 +58,16 @@ def _split_and_scatter(space_ax, portray_data) -> None: if cmap and key == "c": color = portray_data[key][i] - # TODO: break into helper functions for readability - # apply color map if not RGB(A) format or color string (mimicking default matplotlib behavior) + # TODO: somehow format differently, black formatter makes this very ugly + # apply color map if not RGB(A) or string format (mimicking default matplotlib behavior) if not ( - isinstance(color, str) # str format + isinstance(color, str) # string format color ('tab:blue') or ( - (len(color) == 3 or len(color) == 4) # RGB(A) + (len(color) in {3, 4}) and ( all( - # all floats, valid RGB(A) - isinstance(c, (int, float)) and 0 <= c <= 1 + # all floats, then RGB(A) + isinstance(c, (int, float)) for c in color ) ) @@ -72,7 +79,8 @@ def _split_and_scatter(space_ax, portray_data) -> None: elif key != "cmap": # do nothing special, don't pass on color maps grouped_data[marker][key].append(portray_data[key][i]) - print(grouped_data) + for marker, data in grouped_data.items(): + space_ax.scatter(marker=marker, **data) def _draw_grid(space, space_ax, agent_portrayal): @@ -84,19 +92,40 @@ def portray(g): out = {} + # TODO: find way to avoid iterating twice # used to initialize lists for alignment purposes - num_agents = len(space._agent_to_index) - - for i, agent in enumerate(space._agent_to_index): - data = agent_portrayal(agent) - - for key, value in data.items(): - if key not in out: - # initialize list - out[key] = [default_values.get(key, default=None)] * num_agents - out[key][i] = value - - return out + num_agents = 0 + for i in range(g.width): + for j in range(g.height): + content = g._grid[i][j] + if not content: + continue + if not hasattr(content, "__iter__"): + num_agents += 1 + continue + num_agents += len(content) + + index = 0 + for i in range(g.width): + for j in range(g.height): + content = g._grid[i][j] + if not content: + continue + if not hasattr(content, "__iter__"): + # Is a single grid + content = [content] + for agent in content: + data = agent_portrayal(agent) + data["x"] = i + data["y"] = j + + for key, value in data.items(): + if key not in out: + # initialize list + out[key] = [default_values.get(key, None)] * num_agents + out[key][index] = value + index += 1 + return _translate_old_keywords(out) space_ax.set_xlim(-1, space.width) space_ax.set_ylim(-1, space.height) @@ -120,8 +149,11 @@ def _draw_network_grid(space, space_ax, agent_portrayal): def _draw_continuous_space(space, space_ax, agent_portrayal): def portray(space): + # TODO: look into if more default values are needed + # especially relating to 'color', 'facecolor', and 'c' params & + # interactions w/ the current implementation of _split_and_scatter default_values = { - "size": 20, + "s": 20, } out = {} @@ -131,6 +163,7 @@ def portray(space): for i, agent in enumerate(space._agent_to_index): data = agent_portrayal(agent) + data["x"], data["y"] = agent.pos for key, value in data.items(): if key not in out: @@ -138,7 +171,7 @@ def portray(space): out[key] = [default_values.get(key, default=None)] * num_agents out[key][i] = value - return out + return _translate_old_keywords(out) # Determine border style based on space.torus border_style = "solid" if not space.torus else (0, (5, 10)) From 96326003c0a4d64f3c10dc9ad1878133017d2e3e Mon Sep 17 00:00:00 2001 From: Robert Hopkins Date: Thu, 22 Aug 2024 21:42:19 -0400 Subject: [PATCH 04/10] implemented 'norm', reformatted colormap application --- mesa/visualization/components/matplotlib.py | 35 ++++++++++----------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index 08a1beb5402..ce909a75fe0 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -1,5 +1,6 @@ from collections import defaultdict +from matplotlib.colors import Normalize import networkx as nx import solara from matplotlib.figure import Figure @@ -37,10 +38,10 @@ def _translate_old_keywords(dict): # matplotlib scatter does not allow for multiple shapes in one call def _split_and_scatter(portray_data, space_ax) -> None: cmap = portray_data.pop("cmap", None) + norm = portray_data.pop("norm", None) # enforce marker iterability markers = portray_data.pop("marker", ["o"] * len(portray_data["x"])) - # enforce default color # if no 'color' or 'facecolor' or 'c' then default to "tab:blue" color if ( @@ -57,26 +58,23 @@ def _split_and_scatter(portray_data, space_ax) -> None: # apply colormap if cmap and key == "c": color = portray_data[key][i] - - # TODO: somehow format differently, black formatter makes this very ugly - # apply color map if not RGB(A) or string format (mimicking default matplotlib behavior) - if not ( - isinstance(color, str) # string format color ('tab:blue') - or ( - (len(color) in {3, 4}) - and ( - all( - # all floats, then RGB(A) - isinstance(c, (int, float)) - for c in color + # apply color map only if the color is numerical representation + # this ignores RGB(A) and string formats (mimicking default matplotlib behavior) + if isinstance(color, (int, float)): + if norm: + if not isinstance( + norm[i], Normalize + ): # does not support string norms (yet?) + raise TypeError( + "'norm' param must be of type Normalize or a subclass." ) - ) - ) - ): + else: + color = norm[i](color) color = get_cmap(cmap[i])(color) - grouped_data[marker][key].append(color) - elif key != "cmap": # do nothing special, don't pass on color maps + elif ( + key != "cmap" and key != "norm" + ): # do nothing special, don't pass on color maps grouped_data[marker][key].append(portray_data[key][i]) for marker, data in grouped_data.items(): @@ -125,6 +123,7 @@ def portray(g): out[key] = [default_values.get(key, None)] * num_agents out[key][index] = value index += 1 + return _translate_old_keywords(out) space_ax.set_xlim(-1, space.width) From ed6cf811eb1aa7adc95b2970e9291ca98848a50a Mon Sep 17 00:00:00 2001 From: Robert Hopkins Date: Thu, 22 Aug 2024 22:05:06 -0400 Subject: [PATCH 05/10] formatting --- mesa/visualization/components/matplotlib.py | 52 +++++++++------------ 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index ce909a75fe0..9352ec011d4 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -1,12 +1,11 @@ from collections import defaultdict +import matplotlib.pyplot as plt from matplotlib.colors import Normalize -import networkx as nx -import solara from matplotlib.figure import Figure from matplotlib.ticker import MaxNLocator -from matplotlib.pyplot import get_cmap - +import networkx as nx +import solara import mesa @@ -30,21 +29,23 @@ def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = Non # used to make non(less?)-breaking change # this *does* however block the matplotlib 'color' param which is distinct from 'c'. -def _translate_old_keywords(dict): - key_mapping: dict[str, str] = {"size": "s", "color": "c", "shape": "marker"} - return {key_mapping.get(key, key): val for (key, val) in dict.items()} +def _translate_old_keywords(data): + """ + Translates old keyword names in the given dictionary to the new names. + """ + key_mapping = {"size": "s", "color": "c", "shape": "marker"} + return {key_mapping.get(key, key): val for (key, val) in data.items()} # matplotlib scatter does not allow for multiple shapes in one call -def _split_and_scatter(portray_data, space_ax) -> None: +def _split_and_scatter(portray_data: dict, space_ax) -> None: cmap = portray_data.pop("cmap", None) norm = portray_data.pop("norm", None) # enforce marker iterability markers = portray_data.pop("marker", ["o"] * len(portray_data["x"])) # enforce default color - # if no 'color' or 'facecolor' or 'c' then default to "tab:blue" color - if ( + if ( # if no 'color' or 'facecolor' or 'c' then default to "tab:blue" color "color" not in portray_data and "facecolor" not in portray_data and "c" not in portray_data @@ -64,16 +65,16 @@ def _split_and_scatter(portray_data, space_ax) -> None: if norm: if not isinstance( norm[i], Normalize - ): # does not support string norms (yet?) + ): # string param norms not yet supported raise TypeError( - "'norm' param must be of type Normalize or a subclass." + "'norm' must be an instance of Normalize or its subclasses." ) - else: - color = norm[i](color) - color = get_cmap(cmap[i])(color) + color = norm[i](color) + color = plt.get_cmap(cmap[i])(color) grouped_data[marker][key].append(color) - elif ( - key != "cmap" and key != "norm" + elif key not in ( + "cmap", + "norm", ): # do nothing special, don't pass on color maps grouped_data[marker][key].append(portray_data[key][i]) @@ -89,10 +90,7 @@ def portray(g): } out = {} - - # TODO: find way to avoid iterating twice - # used to initialize lists for alignment purposes - num_agents = 0 + num_agents = 0 # TODO: find way to avoid iterating twice for i in range(g.width): for j in range(g.height): content = g._grid[i][j] @@ -129,7 +127,6 @@ def portray(g): space_ax.set_xlim(-1, space.width) space_ax.set_ylim(-1, space.height) - # portray and scatter the agents in the space _split_and_scatter(portray(space), space_ax) @@ -151,13 +148,8 @@ def portray(space): # TODO: look into if more default values are needed # especially relating to 'color', 'facecolor', and 'c' params & # interactions w/ the current implementation of _split_and_scatter - default_values = { - "s": 20, - } - + default_values = {"s": 20} out = {} - - # used to initialize lists for alignment purposes num_agents = len(space._agent_to_index) for i, agent in enumerate(space._agent_to_index): @@ -165,8 +157,7 @@ def portray(space): data["x"], data["y"] = agent.pos for key, value in data.items(): - if key not in out: - # initialize list + if key not in out: # initialize list out[key] = [default_values.get(key, default=None)] * num_agents out[key][i] = value @@ -188,7 +179,6 @@ def portray(space): space_ax.set_xlim(space.x_min - x_padding, space.x_max + x_padding) space_ax.set_ylim(space.y_min - y_padding, space.y_max + y_padding) - # portray and scatter the agents in the space _split_and_scatter(portray(space), space_ax) From 5b67adedb224c2faa221bcd3d1ad2b3faa54a8a3 Mon Sep 17 00:00:00 2001 From: Robert Hopkins Date: Fri, 23 Aug 2024 09:19:37 -0400 Subject: [PATCH 06/10] update to cmap, norm flow --- mesa/visualization/components/matplotlib.py | 31 +++++++++++++-------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index 9352ec011d4..30c5fa69b7b 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -39,8 +39,11 @@ def _translate_old_keywords(data): # matplotlib scatter does not allow for multiple shapes in one call def _split_and_scatter(portray_data: dict, space_ax) -> None: - cmap = portray_data.pop("cmap", None) - norm = portray_data.pop("norm", None) + # if any cmaps are passed, this is true + cmap_exists = portray_data.get("cmap", None) + cmap = None + # if any norms are passed, this is true + norm_exists = portray_data.get("norm", None) # enforce marker iterability markers = portray_data.pop("marker", ["o"] * len(portray_data["x"])) @@ -55,26 +58,32 @@ def _split_and_scatter(portray_data: dict, space_ax) -> None: grouped_data = defaultdict(lambda: {key: [] for key in portray_data}) for i, marker in enumerate(markers): + if cmap_exists: + cmap = portray_data.get("cmap")[i] for key in portray_data: - # apply colormap + # apply colormap if applicable for this index if cmap and key == "c": color = portray_data[key][i] # apply color map only if the color is numerical representation # this ignores RGB(A) and string formats (mimicking default matplotlib behavior) if isinstance(color, (int, float)): - if norm: - if not isinstance( - norm[i], Normalize - ): # string param norms not yet supported - raise TypeError( - "'norm' must be an instance of Normalize or its subclasses." - ) - color = norm[i](color) + if norm_exists: + norm = portray_data.get("norm")[i] + if norm: + if not isinstance( + norm, Normalize + ): # string param norms not yet supported + raise TypeError( + "'norm' must be an instance of Normalize or its subclasses." + ) + color = norm(color) color = plt.get_cmap(cmap[i])(color) grouped_data[marker][key].append(color) elif key not in ( "cmap", "norm", + "vmin", + "vmax", ): # do nothing special, don't pass on color maps grouped_data[marker][key].append(portray_data[key][i]) From 2e4cdd0710d35fdbd4cd1665afb81bc307812cbe Mon Sep 17 00:00:00 2001 From: Robert Hopkins Date: Fri, 23 Aug 2024 09:41:36 -0400 Subject: [PATCH 07/10] colormap application moved to its own function --- mesa/visualization/components/matplotlib.py | 74 ++++++++++++--------- mesa/visualization/solara_viz.py | 3 +- 2 files changed, 44 insertions(+), 33 deletions(-) diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index 30c5fa69b7b..e863a0c9cba 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -1,5 +1,6 @@ from collections import defaultdict +from matplotlib.pylab import norm import matplotlib.pyplot as plt from matplotlib.colors import Normalize from matplotlib.figure import Figure @@ -37,13 +38,36 @@ def _translate_old_keywords(data): return {key_mapping.get(key, key): val for (key, val) in data.items()} +def _apply_color_map(color, cmap=None, norm=None, vmin=None, vmax=None): + """ + Given parameters for manual colormap application, applies color map + according to default implementation in matplotlib + """ + if not cmap: # if no colormap is provided, return original color + return color + color_map = plt.get_cmap(cmap) + if norm: # check if norm is provided and apply it + if not isinstance(norm, Normalize): + raise TypeError( + "'norm' must be an instance of Normalize or its subclasses." + ) + return color_map(norm(color)) + if not (vmin == None or vmax == None): # check for custom norm params + new_norm = Normalize(vmin, vmax) + return color_map(new_norm(color)) + try: + return color_map(color) + except Exception as e: + raise ValueError("Color mapping failed due to invalid arguments") from e + + # matplotlib scatter does not allow for multiple shapes in one call def _split_and_scatter(portray_data: dict, space_ax) -> None: - # if any cmaps are passed, this is true - cmap_exists = portray_data.get("cmap", None) - cmap = None - # if any norms are passed, this is true - norm_exists = portray_data.get("norm", None) + # if any of the following params are passed into portray(), this is true + cmap_exists = portray_data.pop("cmap", None) + norm_exists = portray_data.pop("norm", None) + vmin_exists = portray_data.pop("vmin", None) + vmax_exists = portray_data.pop("vmax", None) # enforce marker iterability markers = portray_data.pop("marker", ["o"] * len(portray_data["x"])) @@ -58,34 +82,20 @@ def _split_and_scatter(portray_data: dict, space_ax) -> None: grouped_data = defaultdict(lambda: {key: [] for key in portray_data}) for i, marker in enumerate(markers): - if cmap_exists: - cmap = portray_data.get("cmap")[i] + for key in portray_data: - # apply colormap if applicable for this index - if cmap and key == "c": - color = portray_data[key][i] - # apply color map only if the color is numerical representation - # this ignores RGB(A) and string formats (mimicking default matplotlib behavior) - if isinstance(color, (int, float)): - if norm_exists: - norm = portray_data.get("norm")[i] - if norm: - if not isinstance( - norm, Normalize - ): # string param norms not yet supported - raise TypeError( - "'norm' must be an instance of Normalize or its subclasses." - ) - color = norm(color) - color = plt.get_cmap(cmap[i])(color) - grouped_data[marker][key].append(color) - elif key not in ( - "cmap", - "norm", - "vmin", - "vmax", - ): # do nothing special, don't pass on color maps - grouped_data[marker][key].append(portray_data[key][i]) + if key == "c": # apply colormap if possible + # prepare arguments + cmap = cmap_exists[i] if cmap_exists else None + norm = norm_exists[i] if norm_exists else None + vmin = vmin_exists[i] if vmin_exists else None + vmax = vmax_exists[i] if vmax_exists else None + # apply colormap with prepared arguments + portray_data["c"][i] = _apply_color_map( + portray_data["c"][i], cmap, norm, vmin, vmax + ) + + grouped_data[marker][key].append(portray_data[key][i]) for marker, data in grouped_data.items(): space_ax.scatter(marker=marker, **data) diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py index 6ec33231cae..347a44616d9 100644 --- a/mesa/visualization/solara_viz.py +++ b/mesa/visualization/solara_viz.py @@ -104,7 +104,8 @@ def SolaraViz( measures: List of callables or data attributes to plot name: Name for display agent_portrayal: Options for rendering agents (dictionary); - Default drawer supports custom `"size"`, `"color"`, and `"shape"`. + Default drawer supports custom matplotlib's [scatter](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.scatter.html) + params, with the exception of (currently) vmin, vmax, & plotnonfinite. space_drawer: Method to render the agent space for the model; default implementation is the `SpaceMatplotlib` component; simulations with no space to visualize should From bf59cbca1b0bc8f9f2b1ca912dceaa2f117b36bd Mon Sep 17 00:00:00 2001 From: Robert Hopkins Date: Fri, 23 Aug 2024 09:43:01 -0400 Subject: [PATCH 08/10] Update matplotlib.py --- mesa/visualization/components/matplotlib.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index e863a0c9cba..09a1abe4cd4 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -29,7 +29,8 @@ def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = Non # used to make non(less?)-breaking change -# this *does* however block the matplotlib 'color' param which is distinct from 'c'. +# this *does* however block the matplotlib 'color' param which is somewhat distinct from 'c'. +# maybe translate 'size' and 'shape' but not 'color'? def _translate_old_keywords(data): """ Translates old keyword names in the given dictionary to the new names. From ddec45acf13db76f37e0a73dc178d60513574b9e Mon Sep 17 00:00:00 2001 From: Robert Hopkins Date: Fri, 23 Aug 2024 13:07:46 -0400 Subject: [PATCH 09/10] improved num_agents iteration --- mesa/visualization/components/matplotlib.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index 09a1abe4cd4..10a60f3fdcd 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -1,6 +1,5 @@ from collections import defaultdict -from matplotlib.pylab import norm import matplotlib.pyplot as plt from matplotlib.colors import Normalize from matplotlib.figure import Figure @@ -8,6 +7,7 @@ import networkx as nx import solara import mesa +from mesa.space import GridContent @solara.component @@ -110,16 +110,14 @@ def portray(g): } out = {} - num_agents = 0 # TODO: find way to avoid iterating twice - for i in range(g.width): - for j in range(g.height): - content = g._grid[i][j] - if not content: - continue - if not hasattr(content, "__iter__"): - num_agents += 1 - continue - num_agents += len(content) + num_agents = 0 + for content in g: + if not content: + continue + if isinstance(content, GridContent): # one agent + num_agents += 1 + continue + num_agents += len(content) index = 0 for i in range(g.width): From c6a8a65d07563ae15f4b905ba1cf62e7c24e868a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 23 Aug 2024 17:08:54 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa/visualization/components/matplotlib.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index 10a60f3fdcd..f85b640ed25 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -1,11 +1,12 @@ from collections import defaultdict import matplotlib.pyplot as plt +import networkx as nx +import solara from matplotlib.colors import Normalize from matplotlib.figure import Figure from matplotlib.ticker import MaxNLocator -import networkx as nx -import solara + import mesa from mesa.space import GridContent @@ -83,7 +84,6 @@ def _split_and_scatter(portray_data: dict, space_ax) -> None: grouped_data = defaultdict(lambda: {key: [] for key in portray_data}) for i, marker in enumerate(markers): - for key in portray_data: if key == "c": # apply colormap if possible # prepare arguments @@ -104,7 +104,6 @@ def _split_and_scatter(portray_data: dict, space_ax) -> None: def _draw_grid(space, space_ax, agent_portrayal): def portray(g): - default_values = { "size": (180 / max(g.width, g.height)) ** 2, } @@ -136,7 +135,7 @@ def portray(g): for key, value in data.items(): if key not in out: # initialize list - out[key] = [default_values.get(key, None)] * num_agents + out[key] = [default_values.get(key)] * num_agents out[key][index] = value index += 1 @@ -162,7 +161,6 @@ def _draw_network_grid(space, space_ax, agent_portrayal): def _draw_continuous_space(space, space_ax, agent_portrayal): def portray(space): - # TODO: look into if more default values are needed # especially relating to 'color', 'facecolor', and 'c' params & # interactions w/ the current implementation of _split_and_scatter