|
19 | 19 | # Import necessary libraries
|
20 | 20 | import numpy as np
|
21 | 21 | import seaborn as sns
|
22 |
| -from mesa_models.boltzmann_wealth_model.model import ( |
23 |
| - BoltzmannWealthModel, |
24 |
| - MoneyAgent, |
25 |
| - compute_gini, |
26 |
| -) |
| 22 | +from mesa.examples.basic.boltzmann_wealth_model.agents import MoneyAgent |
| 23 | +from mesa.examples.basic.boltzmann_wealth_model.model import BoltzmannWealth |
27 | 24 |
|
28 | 25 | NUM_AGENTS = 10
|
29 | 26 |
|
30 | 27 |
|
# Define the agent class
class MoneyAgentRL(MoneyAgent):
    def __init__(self, model):
        """Create an RL-enabled money agent with a randomized starting wealth."""
        super().__init__(model)
        # Initial wealth is drawn uniformly from [1, NUM_AGENTS).
        # NOTE(review): this uses numpy's global RNG, not the model's seeded
        # `self.random` — confirm whether reproducibility matters here.
        self.wealth = np.random.randint(1, NUM_AGENTS)
|
36 | 33 |
|
37 | 34 | def move(self, action):
|
@@ -74,45 +71,46 @@ def take_money(self):
|
74 | 71 |
|
75 | 72 | def step(self):
|
76 | 73 | # Get the action for the agent
|
77 |
| - action = self.model.action_dict[self.unique_id] |
| 74 | + # TODO: figure out why agents are being made twice |
| 75 | + action = self.model.action_dict[self.unique_id - 11] |
78 | 76 | # Move the agent based on the action
|
79 | 77 | self.move(action)
|
80 | 78 | # Take money from other agents in the same cell
|
81 | 79 | self.take_money()
|
82 | 80 |
|
83 | 81 |
|
# Define the model class
class BoltzmannWealthModelRL(BoltzmannWealth, gymnasium.Env):
    def __init__(self, n, width, height):
        """Wrap the Boltzmann wealth model as a Gymnasium environment.

        n: number of agents; width, height: grid dimensions.
        """
        super().__init__(n, width, height)
        # Observation: one row per agent of [wealth, x, y]; wealth is
        # bounded above by 10 * n for the Box space.
        self.observation_space = gymnasium.spaces.Box(low=0, high=10 * n, shape=(n, 3))
        # Action: one of 5 discrete moves per agent, chosen jointly.
        self.action_space = gymnasium.spaces.MultiDiscrete([5] * n)
        # Visualization is off until explicitly enabled.
        self.is_visualize = False
|
94 | 92 |
|
95 | 93 | def step(self, action):
|
96 | 94 | self.action_dict = action
|
97 | 95 | # Perform one step of the model
|
98 |
| - self.schedule.step() |
| 96 | + self.agents.shuffle_do("step") |
99 | 97 | # Collect data for visualization
|
100 | 98 | self.datacollector.collect(self)
|
101 | 99 | # Compute the new Gini coefficient
|
102 |
| - new_gini = compute_gini(self) |
| 100 | + new_gini = self.compute_gini() |
103 | 101 | # Compute the reward based on the change in Gini coefficient
|
104 | 102 | reward = self.calculate_reward(new_gini)
|
105 | 103 | self.prev_gini = new_gini
|
106 | 104 | # Get the observation for the RL model
|
107 | 105 | obs = self._get_obs()
|
108 |
| - if self.schedule.time > 5 * NUM_AGENTS: |
| 106 | + if self.time > 5 * NUM_AGENTS: |
109 | 107 | # Terminate the episode if the model has run for a certain number of timesteps
|
110 | 108 | done = True
|
111 | 109 | reward = -1
|
112 | 110 | elif new_gini < 0.1:
|
113 | 111 | # Terminate the episode if the Gini coefficient is below a certain threshold
|
114 | 112 | done = True
|
115 |
| - reward = 50 / self.schedule.time |
| 113 | + reward = 50 / self.time |
116 | 114 | else:
|
117 | 115 | done = False
|
118 | 116 | info = {}
|
@@ -142,20 +140,18 @@ def reset(self, *, seed=None, options=None):
|
142 | 140 | self.visualize()
|
143 | 141 | super().reset()
|
144 | 142 | self.grid = mesa.space.MultiGrid(self.grid.width, self.grid.height, True)
|
145 |
| - self.schedule = mesa.time.RandomActivation(self) |
| 143 | + self.remove_all_agents() |
146 | 144 | for i in range(self.num_agents):
|
147 | 145 | # Create MoneyAgentRL instances and add them to the schedule
|
148 |
| - a = MoneyAgentRL(i, self) |
149 |
| - self.schedule.add(a) |
| 146 | + a = MoneyAgentRL(self) |
150 | 147 | x = self.random.randrange(self.grid.width)
|
151 | 148 | y = self.random.randrange(self.grid.height)
|
152 | 149 | self.grid.place_agent(a, (x, y))
|
153 |
| - self.prev_gini = compute_gini(self) |
| 150 | + self.prev_gini = self.compute_gini() |
154 | 151 | return self._get_obs(), {}
|
155 | 152 |
|
156 | 153 | def _get_obs(self):
|
157 | 154 | # The observation is the wealth of each agent and their position
|
158 |
| - obs = [] |
159 |
| - for a in self.schedule.agents: |
160 |
| - obs.append([a.wealth, *list(a.pos)]) |
| 155 | + obs = [[a.wealth, *a.pos] for a in self.agents] |
| 156 | + obs = np.array(obs) |
161 | 157 | return np.array(obs)
|
0 commit comments