Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions mineagent/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def stop_monitoring(self) -> None:
@staticmethod
def action_tensor_to_env(action: torch.Tensor) -> dict[str, np.ndarray]:
"""
Convert the flat action tensor (without the trailing 2 focus dims)
into the Dict-space action expected by MinecraftEnv.
Convert the flat env action tensor into the Dict-space action
expected by MinecraftEnv.
"""
a = action.squeeze(0)
keys = (a[:NUM_KEYS] > 0.5).to(torch.int8).numpy()
Expand All @@ -115,37 +115,44 @@ def act(self, obs: torch.Tensor, reward: float = 0.0) -> dict[str, np.ndarray]:
visual_features = self.vision(obs, roi_obs)
affector_output = self.affector(visual_features)
value = self.critic(visual_features)
action, logp_action = sample_action(affector_output)
env_action, env_logp, focus_action, focus_logp = sample_action(affector_output)

with torch.no_grad():
intrinsic_reward = self.icm.intrinsic_reward(
self.prev_visual_features, action, visual_features
self.prev_visual_features, env_action, visual_features
)

self.memory.store(
visual_features, action, reward, intrinsic_reward, value, logp_action
visual_features,
env_action,
reward,
intrinsic_reward,
value,
env_logp,
focus=focus_action,
focus_logp=focus_logp,
)

if len(self.memory) == self.config.max_buffer_size:
self.ppo.update(self.memory)
self.icm.update(self.memory)

self.prev_visual_features = visual_features
self.roi_action = action[:, -2:].round().long()
self.roi_action = focus_action.round().long()

env_action = self.action_tensor_to_env(action[:, :-2])
env_action_dict = self.action_tensor_to_env(env_action)

if self.monitor_actions:
self.event_bus.publish(
Action(
timestamp=datetime.now(),
visual_features=visual_features,
action_distribution=affector_output,
action=action,
logp_action=logp_action,
action=env_action,
logp_action=env_logp,
value=value,
region_of_interest=roi_obs,
intrinsic_reward=intrinsic_reward,
)
)
return env_action
return env_action_dict
3 changes: 3 additions & 0 deletions mineagent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class PPOConfig:
Discount factor for calculating rewards
gae_discount_factor : float, optional
Discount factor for Generalized Advantage Estimation
focus_loss_coeff : float, optional
Coefficient for the separate REINFORCE loss on the focus/ROI head
"""

clip_ratio: float = 0.2
Expand All @@ -57,6 +59,7 @@ class PPOConfig:
train_critic_iters: int = 80
discount_factor: float = 0.99
gae_discount_factor: float = 0.97
focus_loss_coeff: float = 0.01


@dataclass
Expand Down
7 changes: 2 additions & 5 deletions mineagent/learning/icm.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,9 @@ def _compute_inverse_dynamics_loss(self, sample: ICMSample) -> torch.Tensor:
loss = loss + F.binary_cross_entropy_with_logits(
pred.mouse_button_logits, actions[:, col : col + 3]
)
col += 3

# Gaussian NLL for focus
loss = loss + F.gaussian_nll_loss(
pred.focus_means, actions[:, col : col + 2], pred.focus_stds**2
)
# Focus is excluded -- it's an internal perception decision,
# not part of the environment state transition.

return loss

Expand Down
43 changes: 37 additions & 6 deletions mineagent/learning/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,26 @@ class PPOSample:
features : torch.Tensor
Visual features computed from raw observations
actions : torch.Tensor
Actions taken using features
Environment actions taken (no focus)
returns : torch.Tensor
The return from the full trajectory computed so far
advantages : torch.Tensor
Advantages of taking each action over the alternative, e.g. `Q(s, a) - V(s)`
log_probabilities: torch.Tensor
Log of the probability of selecting the action taken
Joint log probability of the environment action
focus_actions : torch.Tensor
Focus/ROI coordinates (stored separately from env actions)
focus_log_probabilities : torch.Tensor
Joint log probability of the focus action
"""

features: torch.Tensor
actions: torch.Tensor
returns: torch.Tensor
advantages: torch.Tensor
log_probabilities: torch.Tensor
focus_actions: torch.Tensor
focus_log_probabilities: torch.Tensor

def __len__(self):
return self.features.shape[0]
Expand All @@ -44,8 +50,6 @@ def get(
self, shuffle: bool = False, batch_size: Union[int, None] = None
) -> Generator["PPOSample", None, None]:
"""


Parameters
----------
shuffle : bool, optional
Expand Down Expand Up @@ -76,6 +80,8 @@ def get(
returns=self.returns[batch_ind],
advantages=self.advantages[batch_ind],
log_probabilities=self.log_probabilities[batch_ind],
focus_actions=self.focus_actions[batch_ind],
focus_log_probabilities=self.focus_log_probabilities[batch_ind],
)

start_idx += batch_size
Expand Down Expand Up @@ -108,6 +114,7 @@ def __init__(self, actor: nn.Module, critic: nn.Module, config: PPOConfig):
self.gae_discount_factor = config.gae_discount_factor
self.clip_ratio = config.clip_ratio
self.target_kl = config.target_kl
self.focus_loss_coeff = config.focus_loss_coeff
self.actor_optim = optim.Adam(
self.actor.parameters(),
lr=config.actor_lr,
Expand All @@ -126,11 +133,23 @@ def _compute_actor_loss(self, data: PPOSample) -> tuple[torch.Tensor, torch.Tens
)

action_dist = self.actor(feat)

# Standard clipped surrogate loss (env actions only)
logp = joint_logp_action(action_dist, act)
ratio = torch.exp(logp - logp_old)
clip_adv = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * adv
loss = -(torch.min(ratio * adv, clip_adv)).mean()
env_loss = -(torch.min(ratio * adv, clip_adv)).mean()

# Separate REINFORCE loss for focus head (not clipped, not in KL)
focus_dist = torch.distributions.Normal(
action_dist.focus_means, action_dist.focus_stds
)
focus_logp_new = focus_dist.log_prob(data.focus_actions).sum(dim=-1)
focus_loss = -(focus_logp_new * adv.detach()).mean()

loss = env_loss + self.focus_loss_coeff * focus_loss

# KL is based on env actions only
kl = (logp_old - logp).mean()

return loss, kl
Expand Down Expand Up @@ -177,10 +196,12 @@ def _finalize_trajectory(self, data: TrajectoryBuffer) -> PPOSample:

This method expects the following at time `t`:
- data.features[t]: visual features of the environment computed at time `t`
- data.actions[t]: the action taken in the environment at time `t`
- data.actions[t]: the env action taken in the environment at time `t`
- data.log_probabilities[t]: the log probability of selecting `data.actions[t]`
- data.values[t]: the value assigned to `data.features[t]`
- data.rewards[t]: the reward given by the environment about the action take at time `t - 1`
- data.focus[t]: the focus/ROI coordinates at time `t`
- data.focus_logp[t]: the log probability of the focus coordinates
"""

# Cannot use the last values here since we don't have the associated reward yet
Expand All @@ -189,6 +210,14 @@ def _finalize_trajectory(self, data: TrajectoryBuffer) -> PPOSample:
# Sum per-component log probs to get joint log prob per timestep
raw_logp = torch.stack(list(data.log_probs_buffer)[:-1])
log_probabilities = raw_logp.sum(dim=-1) if raw_logp.dim() > 1 else raw_logp

# Focus data (stored separately)
focus_actions = torch.stack(list(data.focus_buffer)[:-1])
raw_focus_logp = torch.stack(list(data.focus_logp_buffer)[:-1])
focus_log_probabilities = (
raw_focus_logp.sum(dim=-1) if raw_focus_logp.dim() > 1 else raw_focus_logp
)

# Cannot use the first reward value since we no longer have the associated feature
# The reward for a_t is at r_{t+1}
env_rewards = torch.tensor(list(data.rewards_buffer)[1:])
Expand All @@ -215,6 +244,8 @@ def _finalize_trajectory(self, data: TrajectoryBuffer) -> PPOSample:
returns=returns,
advantages=advantages,
log_probabilities=log_probabilities,
focus_actions=focus_actions,
focus_log_probabilities=focus_log_probabilities,
)

def update(self, trajectory: TrajectoryBuffer) -> None:
Expand Down
18 changes: 15 additions & 3 deletions mineagent/memory/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def __init__(self, max_buffer_size: int):
self.intrinsic_rewards_buffer: deque[float] = deque([], maxlen=max_buffer_size)
self.values_buffer: deque[float] = deque([], maxlen=max_buffer_size)
self.log_probs_buffer: deque[torch.Tensor] = deque([], maxlen=max_buffer_size)
self.focus_buffer: deque[torch.Tensor] = deque([], maxlen=max_buffer_size)
self.focus_logp_buffer: deque[torch.Tensor] = deque([], maxlen=max_buffer_size)

def __len__(self):
return len(self.features_buffer)
Expand All @@ -40,28 +42,38 @@ def store(
intrinsic_reward: float,
value: float,
log_prob: torch.Tensor,
focus: torch.Tensor | None = None,
focus_logp: torch.Tensor | None = None,
) -> None:
"""
Append a single time-step to the trajectory. Updates values for the previous time steps.
Append a single time-step to the trajectory.

Parameters
----------
visual_features : torch.Tensor
Features computed by visual perception from the observation of the environment
action : torch.Tensor
Action tensor for the MineDojo environment + the region of interest (x,y) coordinates
Environment action tensor (keys, mouse, buttons, scroll -- no focus)
reward : float
Reward value from the environment for the previous action
intrinsic_reward : float
Reward value from the Intrinsic Curiosity Module (ICM)
value : float
Value assigned to the observation by the agent
log_prob : torch.Tensor
Log probability of selecting each sub-action
Log probability of selecting each environment sub-action
focus : torch.Tensor | None
Focus/ROI coordinates (2-dim), stored separately from env action
focus_logp : torch.Tensor | None
Log probability of the focus coordinates
"""
self.features_buffer.append(visual_features)
self.actions_buffer.append(action)
self.rewards_buffer.append(reward)
self.intrinsic_rewards_buffer.append(intrinsic_reward)
self.values_buffer.append(value)
self.log_probs_buffer.append(log_prob)
if focus is not None:
self.focus_buffer.append(focus)
if focus_logp is not None:
self.focus_logp_buffer.append(focus_logp)
2 changes: 1 addition & 1 deletion mineagent/reasoning/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, embed_dim: int, action_dim: int | None = None):
from ..client.protocol import NUM_KEYS

if action_dim is None:
action_dim = NUM_KEYS + 3 + 3 + 2
action_dim = NUM_KEYS + 3 + 3
self.l1 = nn.Linear(embed_dim + action_dim, 512)
self.l2 = nn.Linear(512, embed_dim)

Expand Down
Loading
Loading