 free_nats = torch.full((1, ), args.free_nats, device=args.device)  # Allowed deviation in KL divergence


-def update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation, test):
+def update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation, explore=False):
   # Infer belief over current state q(s_t|o≤t,a<t) from the history
   belief, _, _, _, posterior_state, _, _ = transition_model(posterior_state, action.unsqueeze(dim=0), belief, encoder(observation).unsqueeze(dim=0))  # Action and observation need extra time dimension
   belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0)  # Remove time dimension from belief/state
   action = planner(belief, posterior_state)  # Get action from planner(q(s_t|o≤t,a<t), p)
-  if not test:
+  if explore:
     action = action + args.action_noise * torch.randn_like(action)  # Add exploration noise ε ~ p(ε) to the action
   next_observation, reward, done = env.step(action.cpu() if isinstance(env, EnvBatcher) else action[0].cpu())  # Perform environment step (action repeats handled internally)
   return belief, posterior_state, action, next_observation, reward, done
@@ -142,7 +142,7 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
       belief, posterior_state, action = torch.zeros(1, args.belief_size, device=args.device), torch.zeros(1, args.state_size, device=args.device), torch.zeros(1, env.action_size, device=args.device)
       pbar = tqdm(range(args.max_episode_length // args.action_repeat))
       for t in pbar:
-        belief, posterior_state, action, observation, reward, done = update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), test=True)
+        belief, posterior_state, action, observation, reward, done = update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device))
         total_reward += reward
         if args.render:
           env.render()
@@ -218,7 +218,7 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
     belief, posterior_state, action = torch.zeros(1, args.belief_size, device=args.device), torch.zeros(1, args.state_size, device=args.device), torch.zeros(1, env.action_size, device=args.device)
     pbar = tqdm(range(args.max_episode_length // args.action_repeat))
     for t in pbar:
-      belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), test=False)
+      belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), explore=True)
       D.append(observation, action.cpu(), reward, done)
       total_reward += reward
       observation = next_observation
@@ -250,7 +250,7 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
       belief, posterior_state, action = torch.zeros(args.test_episodes, args.belief_size, device=args.device), torch.zeros(args.test_episodes, args.state_size, device=args.device), torch.zeros(args.test_episodes, env.action_size, device=args.device)
       pbar = tqdm(range(args.max_episode_length // args.action_repeat))
       for t in pbar:
-        belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, test_envs, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), test=True)
+        belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, test_envs, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device))
         total_rewards += reward.numpy()
         if not args.symbolic_env:  # Collect real vs. predicted frames for video
           video_frames.append(make_grid(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3) + 0.5, nrow=5).numpy())  # Decentre
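Taken together, the hunks flip the flag's polarity: exploration noise becomes opt-in via explore=True at the data-collection call site, and both test paths simply omit the flag to get the planner's deterministic action. Below is a minimal, self-contained sketch of that gating logic; select_action and the action_noise=0.3 default are illustrative stand-ins, not names from the repository.

import torch

def select_action(planner_action, explore=False, action_noise=0.3):
    # Mirrors the new gating: Gaussian exploration noise is added only when
    # explore=True; otherwise the planner's action is returned unchanged.
    if explore:
        return planner_action + action_noise * torch.randn_like(planner_action)
    return planner_action

planner_action = torch.zeros(1, 6)                          # e.g. a 6-dimensional continuous action
train_action = select_action(planner_action, explore=True)  # noisy action for data collection
test_action = select_action(planner_action)                 # deterministic action for evaluation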