Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 135 additions & 31 deletions mesa/visualization/components/altair_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import warnings

import altair as alt
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import solara
from matplotlib.colors import to_rgba
from matplotlib.colors import to_rgb

import mesa
from mesa.discrete_space import DiscreteSpace, Grid
Expand Down Expand Up @@ -183,48 +184,48 @@ def _draw_grid(space, agent_portrayal, propertylayer_portrayal):
encoding_dict["color"] = alt.Color(
"color:N",
scale=alt.Scale(domain=unique_colors, range=unique_colors),
legend=None,
)
has_size = "size" in all_agent_data[0]
if has_size:
encoding_dict["size"] = alt.Size("size", type="quantitative", legend=None)
encoding_dict["size"] = alt.Size("size", type="quantitative")

agent_chart = (
alt.Chart(
alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict)
)
.mark_point(filled=True)
.properties(width=300, height=300)
# .configure_view(strokeOpacity=0) # hide grid/chart lines
)
# This is the default value for the marker size, which auto-scales
# according to the grid area.

# This is the default value for the marker size, which auto-scales according to the grid area.
if not has_size:
length = min(space.width, space.height)
chart = agent_chart.mark_point(size=30000 / length**2, filled=True)
agent_chart = agent_chart.mark_point(size=30000 / length**2, filled=True)

if propertylayer_portrayal is not None:
base_width = agent_chart.properties().width
base_height = agent_chart.properties().height
chart_width = agent_chart.properties().width
chart_height = agent_chart.properties().height
chart = chart_property_layers(
space=space,
propertylayer_portrayal=propertylayer_portrayal,
base_width=base_width,
base_height=base_height,
chart_width=chart_width,
chart_height=chart_height,
)
chart = chart + agent_chart
else:
chart = agent_chart

chart = chart + agent_chart
return chart
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, chart does not exist here, you will have to declare a chart variable above to use it here. The chart declared inside the if statement get destroyed with it.

Copy link
Collaborator Author

@sanika-n sanika-n Feb 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could be wrong here but from what I know, in languages like C++ a variable defined inside a loop only exists within that loop’s scope but in python I am fairly sure that variables defined in loops remain accessible outside the loop and since I am defining chart both in the if and else part of the loop, it is definitely going to be defined by the time we reach the return line.
image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While you're right that Python variables from if/else blocks remain accessible afterward, it's safer to initialize chart at the function level first. This ensures it's always defined regardless of execution path. Could you update your code to follow this pattern? It prevents potential undefined variable issues if your conditions change later.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okk, that makes sense, will change it 👍



def chart_property_layers(space, propertylayer_portrayal, base_width, base_height):
def chart_property_layers(space, propertylayer_portrayal, chart_width, chart_height):
"""Creates Property Layers in the Altair Components.

Args:
space: the ContinuousSpace instance
propertylayer_portrayal:Dictionary of PropertyLayer portrayal specifications
base_width: width of the agent chart to maintain consistency with the property charts
base_height: height of the agent chart to maintain consistency with the property charts
chart_width: width of the agent chart to maintain consistency with the property charts
chart_height: height of the agent chart to maintain consistency with the property charts
Returns:
Altair Chart
"""
Expand All @@ -245,7 +246,7 @@ def chart_property_layers(space, propertylayer_portrayal, base_width, base_heigh

data = layer.data.astype(float) if layer.data.dtype == bool else layer.data

if (space.width, space.height) is not data.shape:
if (space.width, space.height) != data.shape:
warnings.warn(
f"Layer {layer_name} dimensions ({data.shape}) do not match space dimensions ({space.width}, {space.height}).",
UserWarning,
Expand All @@ -265,49 +266,152 @@ def chart_property_layers(space, propertylayer_portrayal, base_width, base_heigh
}
)

# Add RGBA color if "color" is in portrayal
if "color" in portrayal:
df["color"] = df["value"].apply(
lambda val,
portrayal=portrayal,
alpha=alpha: f"rgba({int(to_rgba(portrayal['color'], alpha=alpha)[0] * 255)}, {int(to_rgba(portrayal['color'], alpha=alpha)[1] * 255)}, {int(to_rgba(portrayal['color'], alpha=alpha)[2] * 255)}, {to_rgba(portrayal['color'], alpha=alpha)[3]:.2f})"
if val > 0
else "rgba(0, 0, 0, 0)"
)
# any value less than vmin will be mapped to the color corresponding to vmin
# any value more than vmax will be mapped to the color corresponding to vmax
def apply_rgba(val, vmin=vmin, vmax=vmax, alpha=alpha, portrayal=portrayal):
a = (val - vmin) / (vmax - vmin)
a = max(0, min(a, 1)) # to ensure that a is between 0 and 1
a *= alpha # vmax will have an opacity corresponding to alpha
rgb_color = to_rgb(portrayal["color"])
r = int(rgb_color[0] * 255)
g = int(rgb_color[1] * 255)
b = int(rgb_color[2] * 255)

return f"rgba({r}, {g}, {b}, {a:.2f})"

df["color"] = df["value"].apply(apply_rgba)

chart = (
alt.Chart(df)
.mark_rect()
.encode(
x=alt.X("x:O", axis=None),
y=alt.Y("y:O", axis=None),
color=alt.Color("color:N", legend=None),
fill=alt.Fill("color:N", scale=None),
)
.properties(width=base_width, height=base_height, title=layer_name)
.properties(width=chart_width, height=chart_height, title=layer_name)
)
base = (base + chart) if base is not None else chart
# Add colormap if "colormap" is in portrayal

if colorbar:
list_value = []
list_color = []

i = vmin

while i <= vmax:
list_value.append(i)
list_color.append(apply_rgba(i))
i += 1

if vmax not in list_value:
list_value.append(vmax)
list_color.append(apply_rgba(vmax))
df_colorbar = pd.DataFrame(
{
"value": list_value,
"color": list_color,
}
)

x_values = np.array(df_colorbar["value"])
rgba_colors = np.array(df_colorbar["color"])
# Ensure rgba_colors is a 2D array
if rgba_colors.ndim == 1:
rgba_colors = np.array(
[list(color) for color in rgba_colors]
) # Convert tuples to a 2D array

def parse_rgba(color_str):
if isinstance(color_str, str) and color_str.startswith("rgba"):
color_str = (
color_str.replace("rgba(", "").replace(")", "").split(",")
)
return np.array(
[
float(color_str[i]) / 255
if i < 3
else float(color_str[i])
for i in range(4)
],
dtype=float,
)
return np.array(
color_str, dtype=float
) # If already a tuple, convert to float

# Convert color strings to RGBA tuples (ensures correct dtype)
rgba_colors = np.array(
[parse_rgba(c) for c in df_colorbar["color"]], dtype=float
)

# Ensure rgba_colors is a 2D array with shape (n, 4)
rgba_colors = np.array(rgba_colors).reshape(-1, 4)

# Create an RGBA gradient image (256 steps for smooth transition)
gradient = np.zeros((50, 256, 4)) # (Height, Width, RGBA)

# Interpolate each channel (R, G, B, A) separately
interp_r = np.interp(
np.linspace(0, 255, 256),
np.linspace(0, 255, len(rgba_colors)),
rgba_colors[:, 0],
)
interp_g = np.interp(
np.linspace(0, 255, 256),
np.linspace(0, 255, len(rgba_colors)),
rgba_colors[:, 1],
)
interp_b = np.interp(
np.linspace(0, 255, 256),
np.linspace(0, 255, len(rgba_colors)),
rgba_colors[:, 2],
)
interp_a = np.interp(
np.linspace(0, 255, 256),
np.linspace(0, 255, len(rgba_colors)),
rgba_colors[:, 3],
)

interp_colors = np.stack(
[interp_r, interp_g, interp_b, interp_a], axis=-1
)
gradient[:] = interp_colors
fig, ax = plt.subplots(figsize=(6, 0.25), dpi=100)
ax.imshow(
gradient,
aspect="auto",
extent=[x_values.min(), x_values.max(), 0, 1],
)
ax.set_yticks([])
ax.set_xlabel(layer_name)
ax.set_xticks(np.linspace(x_values.min(), x_values.max(), 11))
plt.show()

elif "colormap" in portrayal:
cmap = portrayal.get("colormap", "viridis")
cmap_scale = alt.Scale(scheme=cmap, domain=[vmin, vmax])

chart = (
alt.Chart(df)
.mark_rect()
.mark_rect(opacity=alpha)
.encode(
x=alt.X("x:O", axis=None),
y=alt.Y("y:O", axis=None),
color=alt.Color(
"value:Q",
scale=cmap_scale,
title=layer_name if colorbar else None,
title=layer_name,
legend=alt.Legend(title=layer_name) if colorbar else None,
),
)
.properties(width=base_width, height=base_height, title=layer_name)
.properties(width=chart_width, height=chart_height)
)
base = (base + chart) if base is not None else chart

else:
raise ValueError(
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
)
return chart
return base
Loading