-
Notifications
You must be signed in to change notification settings - Fork 75
/
Copy pathippo.py
331 lines (293 loc) · 11.7 KB
/
ippo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from dataclasses import dataclass, MISSING
from typing import Dict, Iterable, Tuple, Type
import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictSequential
from tensordict.nn.distributions import NormalParamExtractor
from torch.distributions import Categorical
from torchrl.data import Composite, Unbounded
from torchrl.modules import IndependentNormal, ProbabilisticActor, TanhNormal
from torchrl.modules.distributions import MaskedCategorical
from torchrl.objectives import ClipPPOLoss, LossModule, ValueEstimators
from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
from benchmarl.models.common import ModelConfig
class Ippo(Algorithm):
    """Independent PPO (from `https://arxiv.org/abs/2011.09533 <https://arxiv.org/abs/2011.09533>`__).

    Args:
        share_param_critic (bool): Whether to share the parameters of the critics within agent groups
        clip_epsilon (scalar): weight clipping threshold in the clipped PPO loss equation.
        entropy_coef (scalar): entropy multiplier when computing the total loss.
        critic_coef (scalar): critic loss multiplier when computing the total loss.
        loss_critic_type (str): loss function for the value discrepancy.
            Can be one of "l1", "l2" or "smooth_l1".
        lmbda (float): The GAE lambda
        scale_mapping (str): positive mapping function to be used with the std.
            choices: "softplus", "exp", "relu", "biased_softplus_1";
        use_tanh_normal (bool): if ``True``, use TanhNormal as the continuous action distribution with support bound
            to the action domain. Otherwise, an IndependentNormal is used.
        minibatch_advantage (bool): if ``True``, advantage computation is performed on minibatches of size
            ``experiment.config.on_policy_minibatch_size`` instead of the full
            ``experiment.config.on_policy_collected_frames_per_batch``, this helps not exploding memory usage

    """

    def __init__(
        self,
        share_param_critic: bool,
        clip_epsilon: float,
        entropy_coef: float,
        critic_coef: float,
        loss_critic_type: str,
        lmbda: float,
        scale_mapping: str,
        use_tanh_normal: bool,
        minibatch_advantage: bool,
        **kwargs
    ):
        # Remaining keyword arguments are forwarded untouched to ``Algorithm``.
        super().__init__(**kwargs)

        self.share_param_critic = share_param_critic
        self.clip_epsilon = clip_epsilon
        self.entropy_coef = entropy_coef
        self.critic_coef = critic_coef
        self.loss_critic_type = loss_critic_type
        self.lmbda = lmbda
        self.scale_mapping = scale_mapping
        self.use_tanh_normal = use_tanh_normal
        self.minibatch_advantage = minibatch_advantage

    #############################
    # Overridden abstract methods
    #############################

    def _get_loss(
        self, group: str, policy_for_loss: TensorDictModule, continuous: bool
    ) -> Tuple[LossModule, bool]:
        """Build the clipped-PPO loss module for ``group``.

        Args:
            group: name of the agent group; used as the prefix of every
                tensordict key the loss reads or writes.
            policy_for_loss: actor used by the loss.
            continuous: unused by this method.

        Returns:
            The configured loss module and the constant ``False``.
            NOTE(review): the flag presumably tells the caller that no target
            network has to be updated -- confirm against ``Algorithm``.
        """
        # Loss
        loss_module = ClipPPOLoss(
            actor=policy_for_loss,
            critic=self.get_critic(group),
            clip_epsilon=self.clip_epsilon,
            entropy_coef=self.entropy_coef,
            critic_coef=self.critic_coef,
            loss_critic_type=self.loss_critic_type,
            normalize_advantage=False,
        )
        # Point the loss at the per-group nested entries of the tensordict.
        loss_module.set_keys(
            reward=(group, "reward"),
            action=(group, "action"),
            done=(group, "done"),
            terminated=(group, "terminated"),
            advantage=(group, "advantage"),
            value_target=(group, "value_target"),
            value=(group, "state_value"),
            sample_log_prob=(group, "log_prob"),
        )
        loss_module.make_value_estimator(
            ValueEstimators.GAE, gamma=self.experiment_config.gamma, lmbda=self.lmbda
        )
        return loss_module, False

    def _get_parameters(self, group: str, loss: ClipPPOLoss) -> Dict[str, Iterable]:
        """Return the actor and critic parameters of ``loss``, keyed by loss name.

        The actor parameters are listed under ``"loss_objective"`` and the
        critic parameters under ``"loss_critic"``.
        """
        return {
            "loss_objective": list(loss.actor_network_params.flatten_keys().values()),
            "loss_critic": list(loss.critic_network_params.flatten_keys().values()),
        }

    def _get_policy_for_loss(
        self, group: str, model_config: ModelConfig, continuous: bool
    ) -> TensorDictModule:
        """Build the stochastic actor for ``group``.

        Args:
            group: name of the agent group the policy acts for.
            model_config: configuration used to instantiate the actor network.
            continuous: whether the group's action space is continuous.

        Returns:
            A ``ProbabilisticActor`` that writes ``(group, "action")`` and
            ``(group, "log_prob")``.
        """
        n_agents = len(self.group_map[group])
        if continuous:
            # The network emits loc and scale together, hence twice the action size.
            logits_shape = list(self.action_spec[group, "action"].shape)
            logits_shape[-1] *= 2
        else:
            # One logit per discrete action choice.
            logits_shape = [
                *self.action_spec[group, "action"].shape,
                self.action_spec[group, "action"].space.n,
            ]

        actor_input_spec = Composite(
            {group: self.observation_spec[group].clone().to(self.device)}
        )

        actor_output_spec = Composite(
            {
                group: Composite(
                    {"logits": Unbounded(shape=logits_shape)},
                    shape=(n_agents,),
                )
            }
        )
        actor_module = model_config.get_model(
            input_spec=actor_input_spec,
            output_spec=actor_output_spec,
            agent_group=group,
            input_has_agent_dim=True,
            n_agents=n_agents,
            centralised=False,
            share_params=self.experiment_config.share_policy_params,
            device=self.device,
            action_spec=self.action_spec,
        )

        if continuous:
            # Split the raw logits into the loc/scale pair the distribution expects.
            extractor_module = TensorDictModule(
                NormalParamExtractor(scale_mapping=self.scale_mapping),
                in_keys=[(group, "logits")],
                out_keys=[(group, "loc"), (group, "scale")],
            )
            policy = ProbabilisticActor(
                module=TensorDictSequential(actor_module, extractor_module),
                spec=self.action_spec[group, "action"],
                in_keys=[(group, "loc"), (group, "scale")],
                out_keys=[(group, "action")],
                distribution_class=(
                    IndependentNormal if not self.use_tanh_normal else TanhNormal
                ),
                # Only TanhNormal receives the action bounds.
                distribution_kwargs=(
                    {
                        "low": self.action_spec[(group, "action")].space.low,
                        "high": self.action_spec[(group, "action")].space.high,
                    }
                    if self.use_tanh_normal
                    else {}
                ),
                return_log_prob=True,
                log_prob_key=(group, "log_prob"),
            )

        else:
            if self.action_mask_spec is None:
                policy = ProbabilisticActor(
                    module=actor_module,
                    spec=self.action_spec[group, "action"],
                    in_keys=[(group, "logits")],
                    out_keys=[(group, "action")],
                    distribution_class=Categorical,
                    return_log_prob=True,
                    log_prob_key=(group, "log_prob"),
                )
            else:
                # An action mask is declared: feed it to a masked categorical.
                policy = ProbabilisticActor(
                    module=actor_module,
                    spec=self.action_spec[group, "action"],
                    in_keys={
                        "logits": (group, "logits"),
                        "mask": (group, "action_mask"),
                    },
                    out_keys=[(group, "action")],
                    distribution_class=MaskedCategorical,
                    return_log_prob=True,
                    log_prob_key=(group, "log_prob"),
                )

        return policy

    def _get_policy_for_collection(
        self, policy_for_loss: TensorDictModule, group: str, continuous: bool
    ) -> TensorDictModule:
        """Return the policy used for data collection (the loss policy itself)."""
        # IPPO uses the same stochastic actor for collection
        return policy_for_loss

    def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
        """Fill in missing per-group keys and run the value estimator on ``batch``.

        Args:
            group: name of the agent group whose data is processed.
            batch: collected data; missing per-group entries are set on it in
                place.

        Returns:
            The concatenation along dim 0 of the chunks passed to the value
            estimator.

        Raises:
            ValueError: if the computed chunk size is zero.
        """
        keys = list(batch.keys(True, True))
        group_shape = batch.get(group).shape

        # When the group has no done/terminated/reward entry of its own, expand
        # the top-level one over the group's shape.
        for name in ("done", "terminated", "reward"):
            nested_key = ("next", group, name)
            if nested_key not in keys:
                batch.set(
                    nested_key,
                    batch.get(("next", name)).unsqueeze(-1).expand((*group_shape, 1)),
                )

        loss = self.get_loss_and_updater(group)[0]
        if self.minibatch_advantage:
            # Ceil division: number of dim-0 rows per chunk.
            # NOTE(review): presumably batch.shape[1] is the rollout length -- confirm.
            increment = -(
                -self.experiment.config.train_minibatch_size(self.on_policy)
                // batch.shape[1]
            )
        else:
            # Larger than dim 0, so the whole batch is processed as one chunk.
            increment = batch.batch_size[0] + 1

        minibatches = []
        for start_index in range(0, batch.shape[0], increment):
            minibatch = batch[start_index : start_index + increment]
            minibatches.append(minibatch)
            # The return value is discarded; presumably the estimator writes its
            # outputs into ``minibatch`` in place -- confirm.
            with torch.no_grad():
                loss.value_estimator(
                    minibatch,
                    params=loss.critic_network_params,
                    target_params=loss.target_critic_network_params,
                )

        batch = torch.cat(minibatches, dim=0)
        return batch

    def process_loss_vals(
        self, group: str, loss_vals: TensorDictBase
    ) -> TensorDictBase:
        """Fold ``"loss_entropy"`` into ``"loss_objective"`` and drop the former."""
        loss_vals.set(
            "loss_objective", loss_vals["loss_objective"] + loss_vals["loss_entropy"]
        )
        del loss_vals["loss_entropy"]

        return loss_vals

    #####################
    # Custom new methods
    #####################

    def get_critic(self, group: str) -> TensorDictModule:
        """Build the decentralised critic for ``group``.

        The critic reads the group's observations and writes
        ``(group, "state_value")`` with one value per agent.
        """
        n_agents = len(self.group_map[group])

        critic_input_spec = Composite(
            {group: self.observation_spec[group].clone().to(self.device)}
        )
        critic_output_spec = Composite(
            {
                group: Composite(
                    {"state_value": Unbounded(shape=(n_agents, 1))},
                    shape=(n_agents,),
                )
            }
        )
        value_module = self.critic_model_config.get_model(
            input_spec=critic_input_spec,
            output_spec=critic_output_spec,
            n_agents=n_agents,
            centralised=False,
            input_has_agent_dim=True,
            agent_group=group,
            share_params=self.share_param_critic,
            device=self.device,
            action_spec=self.action_spec,
        )

        return value_module
@dataclass
class IppoConfig(AlgorithmConfig):
    """Configuration dataclass for :class:`~benchmarl.algorithms.Ippo`."""

    # No field has a concrete default here: each one is declared as MISSING.
    # The names mirror the constructor arguments of ``Ippo``.
    share_param_critic: bool = MISSING
    clip_epsilon: float = MISSING
    entropy_coef: float = MISSING
    critic_coef: float = MISSING
    loss_critic_type: str = MISSING
    lmbda: float = MISSING
    scale_mapping: str = MISSING
    use_tanh_normal: bool = MISSING
    minibatch_advantage: bool = MISSING

    @staticmethod
    def associated_class() -> Type[Algorithm]:
        """Return the algorithm class this configuration belongs to."""
        return Ippo

    # ------------------------------------------------------------------
    # Static capability flags
    # ------------------------------------------------------------------

    @staticmethod
    def on_policy() -> bool:
        """Report the algorithm as on-policy."""
        return True

    @staticmethod
    def has_independent_critic() -> bool:
        """Report that the algorithm has an independent critic."""
        return True

    @staticmethod
    def supports_continuous_actions() -> bool:
        """Report support for continuous action spaces."""
        return True

    @staticmethod
    def supports_discrete_actions() -> bool:
        """Report support for discrete action spaces."""
        return True