Skip to content

Commit ea0a11c

Browse files
authored
1 parent 4edfd23 commit ea0a11c

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

intermediate_source/mario_rl_tutorial.py

+7 -7
Original file line number | Diff line number | Diff line change
@@ -350,7 +350,7 @@ def act(self, state):
350350
class Mario(Mario): # subclassing for continuity
351351
def __init__(self, state_dim, action_dim, save_dir):
352352
super().__init__(state_dim, action_dim, save_dir)
353-
self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000))
353+
self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000, device=torch.device("cpu")))
354354
self.batch_size = 32
355355

356356
def cache(self, state, next_state, action, reward, done):
@@ -369,11 +369,11 @@ def first_if_tuple(x):
369369
state = first_if_tuple(state).__array__()
370370
next_state = first_if_tuple(next_state).__array__()
371371

372-
state = torch.tensor(state, device=self.device)
373-
next_state = torch.tensor(next_state, device=self.device)
374-
action = torch.tensor([action], device=self.device)
375-
reward = torch.tensor([reward], device=self.device)
376-
done = torch.tensor([done], device=self.device)
372+
state = torch.tensor(state)
373+
next_state = torch.tensor(next_state)
374+
action = torch.tensor([action])
375+
reward = torch.tensor([reward])
376+
done = torch.tensor([done])
377377

378378
# self.memory.append((state, next_state, action, reward, done,))
379379
self.memory.add(TensorDict({"state": state, "next_state": next_state, "action": action, "reward": reward, "done": done}, batch_size=[]))
@@ -382,7 +382,7 @@ def recall(self):
382382
"""
383383
Retrieve a batch of experiences from memory
384384
"""
385-
batch = self.memory.sample(self.batch_size)
385+
batch = self.memory.sample(self.batch_size).to(self.device)
386386
state, next_state, action, reward, done = (batch.get(key) for key in ("state", "next_state", "action", "reward", "done"))
387387
return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
388388

0 commit comments

Comments
 (0)