Skip to content

Commit 588aa78

Browse files
committed
Support an option to disable communication during training
1 parent 4ac3a2b commit 588aa78

File tree

4 files changed

+14
-6
lines changed

4 files changed

+14
-6
lines changed

config.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ dhc:
66
hidden_dim: 256
77
max_comm_agents: 3 # includes the agent itself
88
batch_size: 192
9-
max_num_agents: 12
9+
max_num_agents: 16
1010
latent_dim: 784 # 16 * 7 * 7; remember to update this value if observation_shape changes
1111
max_episode_length: 256
1212

1313
communication:
14+
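disable_communication: 1 # when truthy, the communication layer runs at each step only with probability comm_enabled_prob; when falsy, it always runs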
15+
comm_enabled_prob: 0.7
1416
num_comm_layers: 2
1517
num_comm_heads: 2
1618

pathfinding/models/dhc/model.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# credits to https://github.com/ZiyuanMa/DHC/blob/master/model.py
2+
import random
23
import torch
34
import torch.nn as nn
45
import torch.nn.functional as F
@@ -269,7 +270,12 @@ def forward(self, obs, steps, hidden, comm_mask):
269270
# hidden size: batch_size*num_agents x self.hidden_dim
270271
hidden = self.recurrent(latent[i], hidden)
271272
hidden = hidden.view(self._batch_size, num_agents, self.hidden_dim)
272-
hidden = self.comm(hidden, comm_mask[:, i])
273+
274+
if DHC_CONFIG["communication"]["disable_communication"]:
275+
if random.random() < DHC_CONFIG["communication"]["comm_enabled_prob"]:
276+
hidden = self.comm(hidden, comm_mask[:, i])
277+
else:
278+
hidden = self.comm(hidden, comm_mask[:, i])
273279
# only hidden from agent 0
274280
hidden_buffer.append(hidden[:, 0])
275281
hidden = hidden.view(self._batch_size * num_agents, self.hidden_dim)

pathfinding/models/dhc/train.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
TRAIN_CONFIG = settings["dhc"]["train"]
1111

12-
torch.manual_seed(239)
13-
np.random.seed(239)
14-
random.seed(239)
12+
torch.manual_seed(0)
13+
np.random.seed(0)
14+
random.seed(0)
1515

1616

1717
def main(

pathfinding/models/dhc/worker.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def stats(self, interval: int):
371371
self.stat_dict[add_agent_key] = []
372372

373373
if key[1] < WRK_CONFIG["max_map_length"]:
374-
add_map_key = (key[0], key[1] + 5)
374+
add_map_key = (key[0], key[1] + 10)
375375
if add_map_key not in self.stat_dict:
376376
self.stat_dict[add_map_key] = []
377377

0 commit comments

Comments
 (0)