@@ -350,7 +350,7 @@ def act(self, state):
 class Mario(Mario):  # subclassing for continuity
     def __init__(self, state_dim, action_dim, save_dir):
         super().__init__(state_dim, action_dim, save_dir)
-        self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000))
+        self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000, device=torch.device("cpu")))
         self.batch_size = 32

     def cache(self, state, next_state, action, reward, done):
@@ -369,11 +369,11 @@ def first_if_tuple(x):
         state = first_if_tuple(state).__array__()
         next_state = first_if_tuple(next_state).__array__()

-        state = torch.tensor(state, device=self.device)
-        next_state = torch.tensor(next_state, device=self.device)
-        action = torch.tensor([action], device=self.device)
-        reward = torch.tensor([reward], device=self.device)
-        done = torch.tensor([done], device=self.device)
+        state = torch.tensor(state)
+        next_state = torch.tensor(next_state)
+        action = torch.tensor([action])
+        reward = torch.tensor([reward])
+        done = torch.tensor([done])

         # self.memory.append((state, next_state, action, reward, done,))
         self.memory.add(TensorDict({"state": state, "next_state": next_state, "action": action, "reward": reward, "done": done}, batch_size=[]))
@@ -382,7 +382,7 @@ def recall(self):
         """
         Retrieve a batch of experiences from memory
         """
-        batch = self.memory.sample(self.batch_size)
+        batch = self.memory.sample(self.batch_size).to(self.device)
         state, next_state, action, reward, done = (batch.get(key) for key in ("state", "next_state", "action", "reward", "done"))
         return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
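With this change, cache() builds the transition tensors on CPU, the LazyMemmapStorage backing the buffer is pinned to CPU explicitly, and recall() moves each sampled batch to self.device. A minimal, self-contained sketch of that pattern follows (not the tutorial's Mario class; the observation shape, buffer size, and device selection are placeholders, and it assumes torchrl and tensordict are installed):

# Sketch only: placeholder shapes and device selection, not the tutorial's Mario agent.
import torch
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage

device = "cuda" if torch.cuda.is_available() else "cpu"

# Storage is pinned to CPU, mirroring the updated __init__(): the replay buffer
# lives in memory-mapped CPU storage rather than on the training device.
memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000, device=torch.device("cpu")))

# Cache transitions without an explicit device, as in the updated cache():
# the tensors are created on CPU and written into the memmap storage.
for step in range(4):
    memory.add(
        TensorDict(
            {
                "state": torch.zeros(4, 84, 84),       # placeholder observation
                "next_state": torch.zeros(4, 84, 84),
                "action": torch.tensor([step]),
                "reward": torch.tensor([0.0]),
                "done": torch.tensor([False]),
            },
            batch_size=[],
        )
    )

# Recall a batch and move it to the training device in one step, as in the
# updated recall().
batch = memory.sample(2).to(device)
print(batch["state"].device)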