diff --git a/mineagent/agent/agent.py b/mineagent/agent/agent.py index e8ca219..4e977ab 100644 --- a/mineagent/agent/agent.py +++ b/mineagent/agent/agent.py @@ -88,17 +88,17 @@ def stop_monitoring(self) -> None: @staticmethod def action_tensor_to_env(action: torch.Tensor) -> dict[str, np.ndarray]: """ - Convert the flat action tensor (without the trailing 2 focus dims) - into the Dict-space action expected by MinecraftEnv. + Convert the flat env action tensor into the Dict-space action + expected by MinecraftEnv. """ a = action.squeeze(0) keys = (a[:NUM_KEYS] > 0.5).to(torch.int8).numpy() col = NUM_KEYS - mouse_dx = np.float32(np.clip(a[col].item(), *MOUSE_DX_RANGE)) + mouse_dx = np.clip(a[col], *MOUSE_DX_RANGE) col += 1 - mouse_dy = np.float32(np.clip(a[col].item(), *MOUSE_DY_RANGE)) + mouse_dy = np.clip(a[col], *MOUSE_DY_RANGE) col += 1 - scroll = np.float32(np.clip(a[col].item(), *SCROLL_RANGE)) + scroll = np.clip(a[col], *SCROLL_RANGE) col += 1 mouse_buttons = (a[col : col + 3] > 0.5).to(torch.int8).numpy() return { @@ -115,15 +115,22 @@ def act(self, obs: torch.Tensor, reward: float = 0.0) -> dict[str, np.ndarray]: visual_features = self.vision(obs, roi_obs) affector_output = self.affector(visual_features) value = self.critic(visual_features) - action, logp_action = sample_action(affector_output) + env_action, env_logp, focus_action, focus_logp = sample_action(affector_output) with torch.no_grad(): intrinsic_reward = self.icm.intrinsic_reward( - self.prev_visual_features, action, visual_features + self.prev_visual_features, env_action, visual_features ) self.memory.store( - visual_features, action, reward, intrinsic_reward, value, logp_action + visual_features, + env_action, + reward, + intrinsic_reward, + value, + env_logp, + focus=focus_action, + focus_logp=focus_logp, ) if len(self.memory) == self.config.max_buffer_size: @@ -131,9 +138,9 @@ def act(self, obs: torch.Tensor, reward: float = 0.0) -> dict[str, np.ndarray]: self.icm.update(self.memory) self.prev_visual_features = visual_features - self.roi_action = action[:, -2:].round().long() + self.roi_action = focus_action.round().long() - env_action = self.action_tensor_to_env(action[:, :-2]) + env_action_dict = self.action_tensor_to_env(env_action) if self.monitor_actions: self.event_bus.publish( @@ -141,11 +148,11 @@ def act(self, obs: torch.Tensor, reward: float = 0.0) -> dict[str, np.ndarray]: timestamp=datetime.now(), visual_features=visual_features, action_distribution=affector_output, - action=action, - logp_action=logp_action, + action=env_action, + logp_action=env_logp, value=value, region_of_interest=roi_obs, intrinsic_reward=intrinsic_reward, ) ) - return env_action + return env_action_dict diff --git a/mineagent/client/protocol.py b/mineagent/client/protocol.py index a708878..eab10fd 100644 --- a/mineagent/client/protocol.py +++ b/mineagent/client/protocol.py @@ -426,8 +426,8 @@ def raw_input_to_action(raw_input: RawInput) -> dict[str, np.ndarray]: return { "keys": keys, - "mouse_dx": np.float32(raw_input.mouse_dx), - "mouse_dy": np.float32(raw_input.mouse_dy), + "mouse_dx": np.array(raw_input.mouse_dx, dtype=np.float32), + "mouse_dy": np.array(raw_input.mouse_dy, dtype=np.float32), "mouse_buttons": mouse_buttons, - "scroll_delta": np.float32(raw_input.scroll_delta), + "scroll_delta": np.array(raw_input.scroll_delta, dtype=np.float32), } diff --git a/mineagent/config.py b/mineagent/config.py index 85c5c03..149ee58 100644 --- a/mineagent/config.py +++ b/mineagent/config.py @@ -47,6 +47,8 @@ class PPOConfig: Discount factor for calculating rewards gae_discount_factor : float, optional Discount factor for Generalized Advantage Estimation + focus_loss_coeff : float, optional + Coefficient for the separate REINFORCE loss on the focus/ROI head """ clip_ratio: float = 0.2 @@ -57,6 +59,7 @@ class PPOConfig: train_critic_iters: int = 80 discount_factor: float = 0.99 gae_discount_factor: float = 0.97 + focus_loss_coeff: float = 0.01 @dataclass diff --git a/mineagent/learning/icm.py b/mineagent/learning/icm.py index 34accc9..8f282b4 100644 --- a/mineagent/learning/icm.py +++ b/mineagent/learning/icm.py @@ -167,12 +167,9 @@ def _compute_inverse_dynamics_loss(self, sample: ICMSample) -> torch.Tensor: loss = loss + F.binary_cross_entropy_with_logits( pred.mouse_button_logits, actions[:, col : col + 3] ) - col += 3 - # Gaussian NLL for focus - loss = loss + F.gaussian_nll_loss( - pred.focus_means, actions[:, col : col + 2], pred.focus_stds**2 - ) + # Focus is excluded -- it's an internal perception decision, + # not part of the environment state transition. return loss diff --git a/mineagent/learning/ppo.py b/mineagent/learning/ppo.py index 0a6dae8..0cdffba 100644 --- a/mineagent/learning/ppo.py +++ b/mineagent/learning/ppo.py @@ -22,13 +22,17 @@ class PPOSample: features : torch.Tensor Visual features computed from raw observations actions : torch.Tensor - Actions taken using features + Environment actions taken (no focus) returns : torch.Tensor The return from the full trajectory computed so far advantages : torch.Tensor Advantages of taking each action over the alternative, e.g. `Q(s, a) - V(s)` log_probabilities: torch.Tensor - Log of the probability of selecting the action taken + Joint log probability of the environment action + focus_actions : torch.Tensor + Focus/ROI coordinates (stored separately from env actions) + focus_log_probabilities : torch.Tensor + Joint log probability of the focus action """ features: torch.Tensor @@ -36,6 +40,8 @@ class PPOSample: returns: torch.Tensor advantages: torch.Tensor log_probabilities: torch.Tensor + focus_actions: torch.Tensor + focus_log_probabilities: torch.Tensor def __len__(self): return self.features.shape[0] @@ -44,8 +50,6 @@ def get( self, shuffle: bool = False, batch_size: Union[int, None] = None ) -> Generator["PPOSample", None, None]: """ - - Parameters ---------- shuffle : bool, optional @@ -76,6 +80,8 @@ def get( returns=self.returns[batch_ind], advantages=self.advantages[batch_ind], log_probabilities=self.log_probabilities[batch_ind], + focus_actions=self.focus_actions[batch_ind], + focus_log_probabilities=self.focus_log_probabilities[batch_ind], ) start_idx += batch_size @@ -108,6 +114,7 @@ def __init__(self, actor: nn.Module, critic: nn.Module, config: PPOConfig): self.gae_discount_factor = config.gae_discount_factor self.clip_ratio = config.clip_ratio self.target_kl = config.target_kl + self.focus_loss_coeff = config.focus_loss_coeff self.actor_optim = optim.Adam( self.actor.parameters(), lr=config.actor_lr, @@ -126,11 +133,23 @@ def _compute_actor_loss(self, data: PPOSample) -> tuple[torch.Tensor, torch.Tens ) action_dist = self.actor(feat) + + # Standard clipped surrogate loss (env actions only) logp = joint_logp_action(action_dist, act) ratio = torch.exp(logp - logp_old) clip_adv = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * adv - loss = -(torch.min(ratio * adv, clip_adv)).mean() + env_loss = -(torch.min(ratio * adv, clip_adv)).mean() + + # Separate REINFORCE loss for focus head (not clipped, not in KL) + focus_dist = torch.distributions.Normal( + action_dist.focus_means, action_dist.focus_stds + ) + focus_logp_new = focus_dist.log_prob(data.focus_actions).sum(dim=-1) + focus_loss = -(focus_logp_new * adv.detach()).mean() + loss = env_loss + self.focus_loss_coeff * focus_loss + + # KL is based on env actions only kl = (logp_old - logp).mean() return loss, kl @@ -177,10 +196,12 @@ def _finalize_trajectory(self, data: TrajectoryBuffer) -> PPOSample: This method expects the following at time `t`: - data.features[t]: visual features of the environment computed at time `t` - - data.actions[t]: the action taken in the environment at time `t` + - data.actions[t]: the env action taken in the environment at time `t` - data.log_probabilities[t]: the log probability of selecting `data.actions[t]` - data.values[t]: the value assigned to `data.features[t]` - data.rewards[t]: the reward given by the environment about the action take at time `t - 1` + - data.focus[t]: the focus/ROI coordinates at time `t` + - data.focus_logp[t]: the log probability of the focus coordinates """ # Cannot use the last values here since we don't have the associated reward yet @@ -189,6 +210,14 @@ def _finalize_trajectory(self, data: TrajectoryBuffer) -> PPOSample: # Sum per-component log probs to get joint log prob per timestep raw_logp = torch.stack(list(data.log_probs_buffer)[:-1]) log_probabilities = raw_logp.sum(dim=-1) if raw_logp.dim() > 1 else raw_logp + + # Focus data (stored separately) + focus_actions = torch.stack(list(data.focus_buffer)[:-1]) + raw_focus_logp = torch.stack(list(data.focus_logp_buffer)[:-1]) + focus_log_probabilities = ( + raw_focus_logp.sum(dim=-1) if raw_focus_logp.dim() > 1 else raw_focus_logp + ) + # Cannot use the first reward value since we no longer have the associated feature # The reward for a_t is at r_{t+1} env_rewards = torch.tensor(list(data.rewards_buffer)[1:]) @@ -215,6 +244,8 @@ def _finalize_trajectory(self, data: TrajectoryBuffer) -> PPOSample: returns=returns, advantages=advantages, log_probabilities=log_probabilities, + focus_actions=focus_actions, + focus_log_probabilities=focus_log_probabilities, ) def update(self, trajectory: TrajectoryBuffer) -> None: diff --git a/mineagent/memory/trajectory.py b/mineagent/memory/trajectory.py index 7332723..2fa5851 100644 --- a/mineagent/memory/trajectory.py +++ b/mineagent/memory/trajectory.py @@ -28,6 +28,8 @@ def __init__(self, max_buffer_size: int): self.intrinsic_rewards_buffer: deque[float] = deque([], maxlen=max_buffer_size) self.values_buffer: deque[float] = deque([], maxlen=max_buffer_size) self.log_probs_buffer: deque[torch.Tensor] = deque([], maxlen=max_buffer_size) + self.focus_buffer: deque[torch.Tensor] = deque([], maxlen=max_buffer_size) + self.focus_logp_buffer: deque[torch.Tensor] = deque([], maxlen=max_buffer_size) def __len__(self): return len(self.features_buffer) @@ -40,16 +42,18 @@ def store( intrinsic_reward: float, value: float, log_prob: torch.Tensor, + focus: torch.Tensor | None = None, + focus_logp: torch.Tensor | None = None, ) -> None: """ - Append a single time-step to the trajectory. Updates values for the previous time steps. + Append a single time-step to the trajectory. Parameters ---------- visual_features : torch.Tensor Features computed by visual perception from the observation of the environment action : torch.Tensor - Action tensor for the MineDojo environment + the region of interest (x,y) coordinates + Environment action tensor (keys, mouse, buttons, scroll -- no focus) reward : float Reward value from the environment for the previous action intrinsic_reward : float @@ -57,7 +61,11 @@ def store( value : float Value assigned to the observation by the agent log_prob : torch.Tensor - Log probability of selecting each sub-action + Log probability of selecting each environment sub-action + focus : torch.Tensor | None + Focus/ROI coordinates (2-dim), stored separately from env action + focus_logp : torch.Tensor | None + Log probability of the focus coordinates """ self.features_buffer.append(visual_features) self.actions_buffer.append(action) @@ -65,3 +73,7 @@ def store( self.intrinsic_rewards_buffer.append(intrinsic_reward) self.values_buffer.append(value) self.log_probs_buffer.append(log_prob) + if focus is not None: + self.focus_buffer.append(focus) + if focus_logp is not None: + self.focus_logp_buffer.append(focus_logp) diff --git a/mineagent/reasoning/dynamics.py b/mineagent/reasoning/dynamics.py index 44da1ea..bc5ca86 100644 --- a/mineagent/reasoning/dynamics.py +++ b/mineagent/reasoning/dynamics.py @@ -45,7 +45,7 @@ def __init__(self, embed_dim: int, action_dim: int | None = None): from ..client.protocol import NUM_KEYS if action_dim is None: - action_dim = NUM_KEYS + 3 + 3 + 2 + action_dim = NUM_KEYS + 3 + 3 self.l1 = nn.Linear(embed_dim + action_dim, 512) self.l2 = nn.Linear(512, embed_dim) diff --git a/mineagent/utils.py b/mineagent/utils.py index 33413fe..a3e7987 100644 --- a/mineagent/utils.py +++ b/mineagent/utils.py @@ -123,72 +123,72 @@ def statistics(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: def sample_action( output: AffectorOutput, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Sample from the distribution parameters in an AffectorOutput. - The returned action tensor has shape ``(batch, NUM_KEYS + 3 + 3 + 2)`` - laid out as: - [key_0 .. key_N | mouse_dx | mouse_dy | scroll | - mb_0 mb_1 mb_2 | focus_x focus_y] + Returns environment actions and focus actions as separate tensors so that + PPO/ICM only operate on the environment-affecting components. Returns ------- - action : torch.Tensor - Sampled action vector. - logp_action : torch.Tensor - Per-component log probabilities (same shape as action). + env_action : torch.Tensor + Environment action vector ``(batch, NUM_KEYS + 3 + 3)``. + env_logp : torch.Tensor + Per-component log probabilities for env_action (same shape). + focus_action : torch.Tensor + Focus/ROI coordinates ``(batch, 2)``. + focus_logp : torch.Tensor + Per-component log probabilities for focus_action ``(batch, 2)``. """ num_keys = output.key_logits.shape[-1] batch_size = output.key_logits.shape[0] - action_dim = num_keys + 3 + 3 + 2 # keys + (dx,dy,scroll) + 3 buttons + 2 focus + env_dim = num_keys + 3 + 3 # keys + (dx,dy,scroll) + 3 buttons - action = torch.zeros(batch_size, action_dim, dtype=torch.float) - logp = torch.zeros(batch_size, action_dim, dtype=torch.float) + env_action = torch.zeros(batch_size, env_dim, dtype=torch.float) + env_logp = torch.zeros(batch_size, env_dim, dtype=torch.float) # --- Keys (independent Bernoulli) --- key_dist = torch.distributions.Bernoulli(logits=output.key_logits) key_sample = key_dist.sample() - action[:, :num_keys] = key_sample - logp[:, :num_keys] = key_dist.log_prob(key_sample) + env_action[:, :num_keys] = key_sample + env_logp[:, :num_keys] = key_dist.log_prob(key_sample) col = num_keys # --- Mouse dx --- dx_dist = torch.distributions.Normal(output.mouse_dx_mean, output.mouse_dx_std) dx_sample = dx_dist.rsample() - action[:, col] = dx_sample - logp[:, col] = dx_dist.log_prob(dx_sample) + env_action[:, col] = dx_sample + env_logp[:, col] = dx_dist.log_prob(dx_sample) col += 1 # --- Mouse dy --- dy_dist = torch.distributions.Normal(output.mouse_dy_mean, output.mouse_dy_std) dy_sample = dy_dist.rsample() - action[:, col] = dy_sample - logp[:, col] = dy_dist.log_prob(dy_sample) + env_action[:, col] = dy_sample + env_logp[:, col] = dy_dist.log_prob(dy_sample) col += 1 # --- Scroll --- scroll_dist = torch.distributions.Normal(output.scroll_mean, output.scroll_std) scroll_sample = scroll_dist.rsample() - action[:, col] = scroll_sample - logp[:, col] = scroll_dist.log_prob(scroll_sample) + env_action[:, col] = scroll_sample + env_logp[:, col] = scroll_dist.log_prob(scroll_sample) col += 1 # --- Mouse buttons (independent Bernoulli) --- mb_dist = torch.distributions.Bernoulli(logits=output.mouse_button_logits) mb_sample = mb_dist.sample() - action[:, col : col + 3] = mb_sample - logp[:, col : col + 3] = mb_dist.log_prob(mb_sample) - col += 3 + env_action[:, col : col + 3] = mb_sample + env_logp[:, col : col + 3] = mb_dist.log_prob(mb_sample) - # --- Focus / ROI --- + # --- Focus / ROI (separate from env action) --- focus_dist = torch.distributions.Normal(output.focus_means, output.focus_stds) focus_sample = focus_dist.rsample() - action[:, col : col + 2] = focus_sample - logp[:, col : col + 2] = focus_dist.log_prob(focus_sample) + focus_logp = focus_dist.log_prob(focus_sample) - return action, logp + return env_action, env_logp, focus_sample, focus_logp def joint_logp_action( @@ -196,15 +196,16 @@ def joint_logp_action( actions_taken: torch.Tensor, ) -> torch.Tensor: """ - Compute the joint log-probability of *actions_taken* under the - distributions described by *output*. + Compute the joint log-probability of environment actions only. + + Focus/ROI actions are excluded -- they get their own loss term. Parameters ---------- output : AffectorOutput Distribution parameters from the affector. actions_taken : torch.Tensor - Previously sampled action tensor (same layout as ``sample_action``). + Environment action tensor ``(batch, NUM_KEYS + 3 + 3)``. Returns ------- @@ -237,11 +238,6 @@ def joint_logp_action( # Mouse buttons mb_dist = torch.distributions.Bernoulli(logits=output.mouse_button_logits) joint = joint + mb_dist.log_prob(actions_taken[:, col : col + 3]).sum(dim=-1) - col += 3 - - # Focus - focus_dist = torch.distributions.Normal(output.focus_means, output.focus_stds) - joint = joint + focus_dist.log_prob(actions_taken[:, col : col + 2]).sum(dim=-1) return joint diff --git a/tests/learning/test_icm.py b/tests/learning/test_icm.py index 7ef40cc..6b78d45 100644 --- a/tests/learning/test_icm.py +++ b/tests/learning/test_icm.py @@ -6,7 +6,8 @@ from mineagent.memory.trajectory import TrajectoryBuffer from mineagent.client.protocol import NUM_KEYS -ACTION_DIM = NUM_KEYS + 3 + 3 + 2 +ENV_ACTION_DIM = NUM_KEYS + 3 + 3 +FOCUS_DIM = 2 EMBED_DIM = AgentV1.EMBED_DIM @@ -32,10 +33,12 @@ def test_icm_update(icm_module: ICM) -> None: for _ in range(buffer_size): trajectory.store( torch.zeros((EMBED_DIM,), dtype=torch.float), - torch.zeros((ACTION_DIM,), dtype=torch.float), + torch.zeros((ENV_ACTION_DIM,), dtype=torch.float), 0.0, 0.0, 0.0, - torch.ones((ACTION_DIM,), dtype=torch.float), + torch.ones((ENV_ACTION_DIM,), dtype=torch.float), + focus=torch.zeros((FOCUS_DIM,), dtype=torch.float), + focus_logp=torch.ones((FOCUS_DIM,), dtype=torch.float), ) icm_module.update(trajectory) diff --git a/tests/learning/test_ppo.py b/tests/learning/test_ppo.py index 0048432..2d55edf 100644 --- a/tests/learning/test_ppo.py +++ b/tests/learning/test_ppo.py @@ -6,7 +6,8 @@ from mineagent.memory.trajectory import TrajectoryBuffer from mineagent.client.protocol import NUM_KEYS -ACTION_DIM = NUM_KEYS + 3 + 3 + 2 +ENV_ACTION_DIM = NUM_KEYS + 3 + 3 +FOCUS_DIM = 2 EMBED_DIM = AgentV1.EMBED_DIM @@ -30,10 +31,12 @@ def test_ppo_update(ppo_module: PPO) -> None: for _ in range(buffer_size): trajectory.store( torch.zeros((EMBED_DIM,), dtype=torch.float), - torch.zeros((ACTION_DIM,), dtype=torch.float), + torch.zeros((ENV_ACTION_DIM,), dtype=torch.float), 0.0, 0.0, 0.0, - torch.ones((ACTION_DIM,), dtype=torch.float), + torch.ones((ENV_ACTION_DIM,), dtype=torch.float), + focus=torch.zeros((FOCUS_DIM,), dtype=torch.float), + focus_logp=torch.ones((FOCUS_DIM,), dtype=torch.float), ) ppo_module.update(trajectory) diff --git a/tests/reasoning/test_dynamics.py b/tests/reasoning/test_dynamics.py index 7751d2d..0271bb1 100644 --- a/tests/reasoning/test_dynamics.py +++ b/tests/reasoning/test_dynamics.py @@ -8,7 +8,7 @@ EMBED_DIM = 64 BATCH = 32 -ACTION_DIM = NUM_KEYS + 3 + 3 + 2 # keys + dx/dy/scroll + 3 buttons + 2 focus +ACTION_DIM = NUM_KEYS + 3 + 3 # keys + dx/dy/scroll + 3 buttons (no focus) FORWARD_DYNAMICS_EXPECTED_PARAMS = ((EMBED_DIM + ACTION_DIM + 1) * 512) + ( (512 + 1) * EMBED_DIM