
Commit b1b5409

Change test=False to explore=True in acting
1 parent 613f24b commit b1b5409

File tree

1 file changed: +5 -5 lines changed


main.py

Lines changed: 5 additions & 5 deletions
@@ -118,12 +118,12 @@
 free_nats = torch.full((1, ), args.free_nats, device=args.device) # Allowed deviation in KL divergence


-def update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation, test):
+def update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation, explore=False):
   # Infer belief over current state q(s_t|o≤t,a<t) from the history
   belief, _, _, _, posterior_state, _, _ = transition_model(posterior_state, action.unsqueeze(dim=0), belief, encoder(observation).unsqueeze(dim=0)) # Action and observation need extra time dimension
   belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0) # Remove time dimension from belief/state
   action = planner(belief, posterior_state) # Get action from planner(q(s_t|o≤t,a<t), p)
-  if not test:
+  if explore:
     action = action + args.action_noise * torch.randn_like(action) # Add exploration noise ε ~ p(ε) to the action
   next_observation, reward, done = env.step(action.cpu() if isinstance(env, EnvBatcher) else action[0].cpu()) # Perform environment step (action repeats handled internally)
   return belief, posterior_state, action, next_observation, reward, done
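The behavioural core of the change is the flipped condition: exploration noise is now opt-in via an explore keyword (default False) instead of being implied whenever test was False. A minimal standalone sketch of that branch, assuming only a scalar noise level like args.action_noise (names here are illustrative, not the repo's):

import torch

def maybe_explore(action, action_noise, explore=False):
  # Mirrors the branch above: Gaussian noise is added only when exploration
  # is explicitly requested; the old code added it whenever test was False.
  if explore:
    action = action + action_noise * torch.randn_like(action)
  return action

With the default left as False, any caller that omits the argument gets the noise-free planner action, which is the safe behaviour for evaluation.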
@@ -142,7 +142,7 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
 belief, posterior_state, action = torch.zeros(1, args.belief_size, device=args.device), torch.zeros(1, args.state_size, device=args.device), torch.zeros(1, env.action_size, device=args.device)
 pbar = tqdm(range(args.max_episode_length // args.action_repeat))
 for t in pbar:
-  belief, posterior_state, action, observation, reward, done = update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), test=True)
+  belief, posterior_state, action, observation, reward, done = update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device))
   total_reward += reward
   if args.render:
     env.render()
@@ -218,7 +218,7 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
 belief, posterior_state, action = torch.zeros(1, args.belief_size, device=args.device), torch.zeros(1, args.state_size, device=args.device), torch.zeros(1, env.action_size, device=args.device)
 pbar = tqdm(range(args.max_episode_length // args.action_repeat))
 for t in pbar:
-  belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), test=False)
+  belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), explore=True)
   D.append(observation, action.cpu(), reward, done)
   total_reward += reward
   observation = next_observation
@@ -250,7 +250,7 @@ def update_belief_and_act(args, env, planner, transition_model, encoder, belief,
 belief, posterior_state, action = torch.zeros(args.test_episodes, args.belief_size, device=args.device), torch.zeros(args.test_episodes, args.state_size, device=args.device), torch.zeros(args.test_episodes, env.action_size, device=args.device)
 pbar = tqdm(range(args.max_episode_length // args.action_repeat))
 for t in pbar:
-  belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, test_envs, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), test=True)
+  belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, test_envs, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device))
   total_rewards += reward.numpy()
   if not args.symbolic_env: # Collect real vs. predicted frames for video
     video_frames.append(make_grid(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3) + 0.5, nrow=5).numpy()) # Decentre
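A quick sanity check of the convention the updated call sites rely on (a sketch with a stub standing in for update_belief_and_act, not the repo's code): the seed and test-episode loops simply drop the old test=True argument and get the planner action back unchanged, while the data-collection loop asks for noise explicitly with explore=True.

import torch

def act(action, action_noise=0.3, explore=False):
  # Stub for the noise branch of update_belief_and_act
  if explore:
    action = action + action_noise * torch.randn_like(action)
  return action

planner_action = torch.full((1, 6), 0.5)
assert torch.equal(act(planner_action), planner_action)                    # evaluation: unchanged
assert not torch.equal(act(planner_action, explore=True), planner_action)  # collection: perturbed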
