Skip to content

Commit 1a36a75

Browse files
committed
update boltmann rl for Mesa 3
- updated boltzmann_rl for Mesa 3.0 - creating duplicate agents for some reason; need ot reset the unique_id iterator
1 parent 877d9ee commit 1a36a75

File tree

1 file changed

+20
-24
lines changed

1 file changed

+20
-24
lines changed

rl/boltzmann_money/model.py

+20-24
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,16 @@
1919
# Import necessary libraries
2020
import numpy as np
2121
import seaborn as sns
22-
from mesa_models.boltzmann_wealth_model.model import (
23-
BoltzmannWealthModel,
24-
MoneyAgent,
25-
compute_gini,
26-
)
22+
from mesa.examples.basic.boltzmann_wealth_model.agents import MoneyAgent
23+
from mesa.examples.basic.boltzmann_wealth_model.model import BoltzmannWealth
2724

2825
NUM_AGENTS = 10
2926

3027

3128
# Define the agent class
3229
class MoneyAgentRL(MoneyAgent):
33-
def __init__(self, unique_id, model):
34-
super().__init__(unique_id, model)
30+
def __init__(self, model):
31+
super().__init__(model)
3532
self.wealth = np.random.randint(1, NUM_AGENTS)
3633

3734
def move(self, action):
@@ -74,45 +71,46 @@ def take_money(self):
7471

7572
def step(self):
7673
# Get the action for the agent
77-
action = self.model.action_dict[self.unique_id]
74+
# TODO: figure out why agents are being made twice
75+
action = self.model.action_dict[self.unique_id - 11]
7876
# Move the agent based on the action
7977
self.move(action)
8078
# Take money from other agents in the same cell
8179
self.take_money()
8280

8381

8482
# Define the model class
85-
class BoltzmannWealthModelRL(BoltzmannWealthModel, gymnasium.Env):
86-
def __init__(self, N, width, height):
87-
super().__init__(N, width, height)
83+
class BoltzmannWealthModelRL(BoltzmannWealth, gymnasium.Env):
84+
def __init__(self, n, width, height):
85+
super().__init__(n, width, height)
8886
# Define the observation and action space for the RL model
8987
# The observation space is the wealth of each agent and their position
90-
self.observation_space = gymnasium.spaces.Box(low=0, high=10 * N, shape=(N, 3))
88+
self.observation_space = gymnasium.spaces.Box(low=0, high=10 * n, shape=(n, 3))
9189
# The action space is a MultiDiscrete space with 5 possible actions for each agent
92-
self.action_space = gymnasium.spaces.MultiDiscrete([5] * N)
90+
self.action_space = gymnasium.spaces.MultiDiscrete([5] * n)
9391
self.is_visualize = False
9492

9593
def step(self, action):
9694
self.action_dict = action
9795
# Perform one step of the model
98-
self.schedule.step()
96+
self.agents.shuffle_do("step")
9997
# Collect data for visualization
10098
self.datacollector.collect(self)
10199
# Compute the new Gini coefficient
102-
new_gini = compute_gini(self)
100+
new_gini = self.compute_gini()
103101
# Compute the reward based on the change in Gini coefficient
104102
reward = self.calculate_reward(new_gini)
105103
self.prev_gini = new_gini
106104
# Get the observation for the RL model
107105
obs = self._get_obs()
108-
if self.schedule.time > 5 * NUM_AGENTS:
106+
if self.time > 5 * NUM_AGENTS:
109107
# Terminate the episode if the model has run for a certain number of timesteps
110108
done = True
111109
reward = -1
112110
elif new_gini < 0.1:
113111
# Terminate the episode if the Gini coefficient is below a certain threshold
114112
done = True
115-
reward = 50 / self.schedule.time
113+
reward = 50 / self.time
116114
else:
117115
done = False
118116
info = {}
@@ -142,20 +140,18 @@ def reset(self, *, seed=None, options=None):
142140
self.visualize()
143141
super().reset()
144142
self.grid = mesa.space.MultiGrid(self.grid.width, self.grid.height, True)
145-
self.schedule = mesa.time.RandomActivation(self)
143+
self.remove_all_agents()
146144
for i in range(self.num_agents):
147145
# Create MoneyAgentRL instances and add them to the schedule
148-
a = MoneyAgentRL(i, self)
149-
self.schedule.add(a)
146+
a = MoneyAgentRL(self)
150147
x = self.random.randrange(self.grid.width)
151148
y = self.random.randrange(self.grid.height)
152149
self.grid.place_agent(a, (x, y))
153-
self.prev_gini = compute_gini(self)
150+
self.prev_gini = self.compute_gini()
154151
return self._get_obs(), {}
155152

156153
def _get_obs(self):
157154
# The observation is the wealth of each agent and their position
158-
obs = []
159-
for a in self.schedule.agents:
160-
obs.append([a.wealth, *list(a.pos)])
155+
obs = [[a.wealth, *a.pos] for a in self.agents]
156+
obs = np.array(obs)
161157
return np.array(obs)

0 commit comments

Comments
 (0)