Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 55 additions & 71 deletions mineagent/affector/affector.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,81 @@
from dataclasses import dataclass

import torch
import torch.nn as nn
from gymnasium.spaces import MultiDiscrete

from ..client.protocol import NUM_KEYS
from ..utils import add_forward_hooks


@dataclass
class AffectorOutput:
    """All distribution parameters produced by the affector in a single object."""

    # Independent Bernoulli logits, one per key in KEY_LIST.
    key_logits: torch.Tensor
    # Gaussian parameters for horizontal mouse movement (std is positive).
    mouse_dx_mean: torch.Tensor
    mouse_dx_std: torch.Tensor
    # Gaussian parameters for vertical mouse movement.
    mouse_dy_mean: torch.Tensor
    mouse_dy_std: torch.Tensor
    # Independent Bernoulli logits for the 3 mouse buttons.
    mouse_button_logits: torch.Tensor
    # Gaussian parameters for the scroll delta.
    scroll_mean: torch.Tensor
    scroll_std: torch.Tensor
    # Gaussian parameters (2-dim each) for the internal focus / region-of-interest
    # mechanism used to sample where the agent attends.
    focus_means: torch.Tensor
    focus_stds: torch.Tensor


class LinearAffector(nn.Module):
"""
Feed-forward affector (action) module.

This module produces distributions over actions for the environment given some input using linear layers.
Produces distribution parameters for the raw-input action space:
- Independent Bernoulli logits for each key in KEY_LIST
- Gaussian parameters (mean, std) for mouse_dx, mouse_dy, scroll_delta
- Independent Bernoulli logits for 3 mouse buttons
- Gaussian parameters for the internal focus/ROI mechanism
"""

def __init__(self, embed_dim: int, action_space: MultiDiscrete):
"""
Parameters
----------
embed_dim : int
Dimension of the input embeddings
action_space : MultiDiscrete
The action space for Minecraft is a length 8 numpy array:
0: longitudinal movement (i.e. moving forward and back)
1: lateral movement (i.e. moving left and right)
2: vertical movement (i.e. jumping)
3: pitch movement (vertical rotation, i.e. looking up and down)
4: yaw movement (hortizontal rotation, i.e. looking left and right)
5: functional (0: noop, 1: use, 2: drop, 3: attack, 4: craft, 5: equip, 6: place, 7: destroy)
6: item index to craft
7: inventory index
"""
def __init__(self, embed_dim: int, num_keys: int = NUM_KEYS):
super().__init__()

# Movement
self.longitudinal_action = nn.Linear(embed_dim, int(action_space.nvec[0]))
self.lateral_action = nn.Linear(embed_dim, int(action_space.nvec[1]))
self.vertical_action = nn.Linear(embed_dim, int(action_space.nvec[2]))
self.pitch_action = nn.Linear(embed_dim, int(action_space.nvec[3]))
self.yaw_action = nn.Linear(embed_dim, int(action_space.nvec[4]))
self.num_keys = num_keys

# Binary key logits (one per key)
self.key_head = nn.Linear(embed_dim, num_keys)

# Mouse movement (mean + log-std for dx and dy)
self.mouse_dx_mean = nn.Linear(embed_dim, 1)
self.mouse_dx_logstd = nn.Linear(embed_dim, 1)
self.mouse_dy_mean = nn.Linear(embed_dim, 1)
self.mouse_dy_logstd = nn.Linear(embed_dim, 1)

# Mouse buttons (3 independent Bernoulli logits)
self.mouse_button_head = nn.Linear(embed_dim, 3)

# Manipulation
self.functional_action = nn.Linear(embed_dim, int(action_space.nvec[5]))
self.craft_action = nn.Linear(embed_dim, int(action_space.nvec[6]))
self.inventory_action = nn.Linear(embed_dim, int(action_space.nvec[7]))
# Scroll (mean + log-std)
self.scroll_mean = nn.Linear(embed_dim, 1)
self.scroll_logstd = nn.Linear(embed_dim, 1)

# Internal
## distribution for which we can sample regions of interest
# Internal focus / region-of-interest
self.focus_means = nn.Linear(embed_dim, 2)
self.focus_stds = nn.Linear(embed_dim, 2)
self.focus_logstds = nn.Linear(embed_dim, 2)

self.softmax = nn.Softmax(dim=1)
self.sigmoid = nn.Sigmoid()
self.softplus = nn.Softplus()
self.action_space = action_space

# Monitoring
self.start_monitoring()

def forward(
self, x: torch.Tensor
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
long_dist = self.softmax(self.longitudinal_action(x))
lat_dist = self.softmax(self.lateral_action(x))
vert_dist = self.softmax(self.vertical_action(x))
pitch_dist = self.softmax(self.pitch_action(x))
yaw_dist = self.softmax(self.yaw_action(x))
func_dist = self.softmax(self.functional_action(x))
craft_dist = self.softmax(self.craft_action(x))
inventory_dist = self.softmax(self.inventory_action(x))
roi_means = self.focus_means(x)
roi_stds = self.softplus(self.focus_stds(x))

return (
long_dist,
lat_dist,
vert_dist,
pitch_dist,
yaw_dist,
func_dist,
craft_dist,
inventory_dist,
roi_means,
roi_stds,
def forward(self, x: torch.Tensor) -> AffectorOutput:
    """Map input embeddings to all action-distribution parameters.

    Parameters
    ----------
    x : torch.Tensor
        Input embedding; assumed shape ``(batch, embed_dim)`` — TODO confirm
        against the caller.

    Returns
    -------
    AffectorOutput
        Bernoulli logits (keys, mouse buttons) and Gaussian ``(mean, std)``
        parameters (mouse dx/dy, scroll, focus). All stds are obtained by
        passing the "logstd" head outputs through softplus, which guarantees
        strictly positive values.
    """
    return AffectorOutput(
        key_logits=self.key_head(x),
        # .squeeze(-1) drops the trailing singleton dim from the 1-unit heads.
        mouse_dx_mean=self.mouse_dx_mean(x).squeeze(-1),
        mouse_dx_std=self.softplus(self.mouse_dx_logstd(x)).squeeze(-1),
        mouse_dy_mean=self.mouse_dy_mean(x).squeeze(-1),
        mouse_dy_std=self.softplus(self.mouse_dy_logstd(x)).squeeze(-1),
        mouse_button_logits=self.mouse_button_head(x),
        scroll_mean=self.scroll_mean(x).squeeze(-1),
        scroll_std=self.softplus(self.scroll_logstd(x)).squeeze(-1),
        # Focus heads emit 2 values each, so no squeeze is needed here.
        focus_means=self.focus_means(x),
        focus_stds=self.softplus(self.focus_logstds(x)),
    )

def stop_monitoring(self):
Expand Down
59 changes: 45 additions & 14 deletions mineagent/agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from datetime import datetime

import numpy as np
import torch
from torchvision.transforms.functional import center_crop, crop # type: ignore
from gymnasium.spaces import MultiDiscrete

from ..perception.visual import VisualPerception
from ..affector.affector import LinearAffector
Expand All @@ -15,6 +15,12 @@
from ..monitoring.event import Action
from ..utils import sample_action
from ..monitoring.event_bus import get_event_bus
from ..client.protocol import (
NUM_KEYS,
MOUSE_DX_RANGE,
MOUSE_DY_RANGE,
SCROLL_RANGE,
)


class AgentV1:
Expand All @@ -26,17 +32,18 @@ class AgentV1:
- How fast is the visual perception module? Does it need to be faster?
"""

EMBED_DIM = 32 + 32

def __init__(
self,
config: AgentConfig,
action_space: MultiDiscrete,
) -> None:
self.vision = VisualPerception(out_channels=32)
self.affector = LinearAffector(32 + 32, action_space)
self.critic = LinearCritic(32 + 32)
self.affector = LinearAffector(self.EMBED_DIM)
self.critic = LinearCritic(self.EMBED_DIM)
self.memory = TrajectoryBuffer(config.max_buffer_size)
self.inverse_dynamics = InverseDynamics(32 + 32, action_space)
self.forward_dynamics = ForwardDynamics(32 + 32, action_space.shape[0] + 2)
self.inverse_dynamics = InverseDynamics(self.EMBED_DIM)
self.forward_dynamics = ForwardDynamics(self.EMBED_DIM)
self.ppo = PPO(self.affector, self.critic, config.ppo)
self.icm = ICM(self.forward_dynamics, self.inverse_dynamics, config.icm)
self.config = config
Expand All @@ -45,7 +52,7 @@ def __init__(
# region of interest initialization
self.roi_action: torch.Tensor | None = None
self.prev_visual_features: torch.Tensor = torch.zeros(
(1, 64), dtype=torch.float
(1, self.EMBED_DIM), dtype=torch.float
)
self.event_bus = get_event_bus()

Expand Down Expand Up @@ -78,15 +85,38 @@ def stop_monitoring(self) -> None:
self.forward_dynamics.stop_monitoring()
self.monitor_actions = False

def act(self, obs: torch.Tensor, reward: float = 0.0) -> torch.Tensor:
@staticmethod
def action_tensor_to_env(action: torch.Tensor) -> dict[str, np.ndarray]:
    """
    Convert the flat action tensor (without the trailing 2 focus dims)
    into the Dict-space action expected by MinecraftEnv.
    """
    flat = action.squeeze(0)

    # Binary key presses: one Bernoulli sample per key, thresholded at 0.5.
    keys = (flat[:NUM_KEYS] > 0.5).to(torch.int8).numpy()

    # The continuous scalars follow the key block in a fixed order:
    # mouse_dx, mouse_dy, scroll_delta — each clipped to its protocol range.
    dx_idx = NUM_KEYS
    dy_idx = NUM_KEYS + 1
    scroll_idx = NUM_KEYS + 2
    mouse_dx = np.float32(np.clip(flat[dx_idx].item(), *MOUSE_DX_RANGE))
    mouse_dy = np.float32(np.clip(flat[dy_idx].item(), *MOUSE_DY_RANGE))
    scroll = np.float32(np.clip(flat[scroll_idx].item(), *SCROLL_RANGE))

    # Three mouse-button Bernoulli samples, thresholded like the keys.
    buttons_start = scroll_idx + 1
    mouse_buttons = (flat[buttons_start : buttons_start + 3] > 0.5).to(torch.int8).numpy()

    return {
        "keys": keys,
        "mouse_dx": mouse_dx,
        "mouse_dy": mouse_dy,
        "mouse_buttons": mouse_buttons,
        "scroll_delta": scroll,
    }

def act(self, obs: torch.Tensor, reward: float = 0.0) -> dict[str, np.ndarray]:
roi_obs = self._transform_observation(obs)
with torch.no_grad():
visual_features = self.vision(obs, roi_obs)
actions = self.affector(visual_features)
affector_output = self.affector(visual_features)
value = self.critic(visual_features)
action, logp_action = sample_action(actions)
action, logp_action = sample_action(affector_output)

# Get the intrinsic reward associated with the previous observation
with torch.no_grad():
intrinsic_reward = self.icm.intrinsic_reward(
self.prev_visual_features, action, visual_features
Expand All @@ -96,20 +126,21 @@ def act(self, obs: torch.Tensor, reward: float = 0.0) -> torch.Tensor:
visual_features, action, reward, intrinsic_reward, value, logp_action
)

# Once the trajectory buffer is full, we can start learning
if len(self.memory) == self.config.max_buffer_size:
self.ppo.update(self.memory)
self.icm.update(self.memory)

self.prev_visual_features = visual_features
self.roi_action = action[:, -2:].round().long()
env_action = action[:, :-2].long()

env_action = self.action_tensor_to_env(action[:, :-2])

if self.monitor_actions:
self.event_bus.publish(
Action(
timestamp=datetime.now(),
visual_features=visual_features,
action_distribution=actions,
action_distribution=affector_output,
action=action,
logp_action=logp_action,
value=value,
Expand Down
12 changes: 12 additions & 0 deletions mineagent/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,29 @@
from .protocol import (
COMMAND_TO_KEY,
GLFW,
KEY_LIST,
KEY_TO_INDEX,
NUM_KEYS,
Observation,
RawInput,
action_to_raw_input,
make_action_space,
parse_observation,
raw_input_to_action,
)

# Public API of the client package: connection machinery plus the raw-input
# protocol helpers and constants re-exported from .protocol.
__all__ = [
    "AsyncMinecraftClient",
    "ConnectionConfig",
    "COMMAND_TO_KEY",
    "GLFW",
    "KEY_LIST",
    "KEY_TO_INDEX",
    "NUM_KEYS",
    "Observation",
    "RawInput",
    "action_to_raw_input",
    "make_action_space",
    "parse_observation",
    "raw_input_to_action",
]
Loading
Loading