diff --git a/mineagent/affector/affector.py b/mineagent/affector/affector.py index 92ff9f8..86e3ea6 100644 --- a/mineagent/affector/affector.py +++ b/mineagent/affector/affector.py @@ -1,97 +1,81 @@ +from dataclasses import dataclass + import torch import torch.nn as nn -from gymnasium.spaces import MultiDiscrete +from ..client.protocol import NUM_KEYS from ..utils import add_forward_hooks +@dataclass +class AffectorOutput: + """All distribution parameters produced by the affector in a single object.""" + + key_logits: torch.Tensor + mouse_dx_mean: torch.Tensor + mouse_dx_std: torch.Tensor + mouse_dy_mean: torch.Tensor + mouse_dy_std: torch.Tensor + mouse_button_logits: torch.Tensor + scroll_mean: torch.Tensor + scroll_std: torch.Tensor + focus_means: torch.Tensor + focus_stds: torch.Tensor + + class LinearAffector(nn.Module): """ Feed-forward affector (action) module. - This module produces distributions over actions for the environment given some input using linear layers. + Produces distribution parameters for the raw-input action space: + - Independent Bernoulli logits for each key in KEY_LIST + - Gaussian parameters (mean, std) for mouse_dx, mouse_dy, scroll_delta + - Independent Bernoulli logits for 3 mouse buttons + - Gaussian parameters for the internal focus/ROI mechanism """ - def __init__(self, embed_dim: int, action_space: MultiDiscrete): - """ - Parameters - ---------- - embed_dim : int - Dimension of the input embeddings - action_space : MultiDiscrete - The action space for Minecraft is a length 8 numpy array: - 0: longitudinal movement (i.e. moving forward and back) - 1: lateral movement (i.e. moving left and right) - 2: vertical movement (i.e. jumping) - 3: pitch movement (vertical rotation, i.e. looking up and down) - 4: yaw movement (hortizontal rotation, i.e. looking left and right) - 5: functional (0: noop, 1: use, 2: drop, 3: attack, 4: craft, 5: equip, 6: place, 7: destroy) - 6: item index to craft - 7: inventory index - """ + def __init__(self, embed_dim: int, num_keys: int = NUM_KEYS): super().__init__() - # Movement - self.longitudinal_action = nn.Linear(embed_dim, int(action_space.nvec[0])) - self.lateral_action = nn.Linear(embed_dim, int(action_space.nvec[1])) - self.vertical_action = nn.Linear(embed_dim, int(action_space.nvec[2])) - self.pitch_action = nn.Linear(embed_dim, int(action_space.nvec[3])) - self.yaw_action = nn.Linear(embed_dim, int(action_space.nvec[4])) + self.num_keys = num_keys + + # Binary key logits (one per key) + self.key_head = nn.Linear(embed_dim, num_keys) + + # Mouse movement (mean + log-std for dx and dy) + self.mouse_dx_mean = nn.Linear(embed_dim, 1) + self.mouse_dx_logstd = nn.Linear(embed_dim, 1) + self.mouse_dy_mean = nn.Linear(embed_dim, 1) + self.mouse_dy_logstd = nn.Linear(embed_dim, 1) + + # Mouse buttons (3 independent Bernoulli logits) + self.mouse_button_head = nn.Linear(embed_dim, 3) - # Manipulation - self.functional_action = nn.Linear(embed_dim, int(action_space.nvec[5])) - self.craft_action = nn.Linear(embed_dim, int(action_space.nvec[6])) - self.inventory_action = nn.Linear(embed_dim, int(action_space.nvec[7])) + # Scroll (mean + log-std) + self.scroll_mean = nn.Linear(embed_dim, 1) + self.scroll_logstd = nn.Linear(embed_dim, 1) - # Internal - ## distribution for which we can sample regions of interest + # Internal focus / region-of-interest self.focus_means = nn.Linear(embed_dim, 2) - self.focus_stds = nn.Linear(embed_dim, 2) + self.focus_logstds = nn.Linear(embed_dim, 2) - self.softmax = nn.Softmax(dim=1) - self.sigmoid = nn.Sigmoid() self.softplus = nn.Softplus() - self.action_space = action_space # Monitoring self.start_monitoring() - def forward( - self, x: torch.Tensor - ) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - ]: - long_dist = self.softmax(self.longitudinal_action(x)) - lat_dist = self.softmax(self.lateral_action(x)) - vert_dist = self.softmax(self.vertical_action(x)) - pitch_dist = self.softmax(self.pitch_action(x)) - yaw_dist = self.softmax(self.yaw_action(x)) - func_dist = self.softmax(self.functional_action(x)) - craft_dist = self.softmax(self.craft_action(x)) - inventory_dist = self.softmax(self.inventory_action(x)) - roi_means = self.focus_means(x) - roi_stds = self.softplus(self.focus_stds(x)) - - return ( - long_dist, - lat_dist, - vert_dist, - pitch_dist, - yaw_dist, - func_dist, - craft_dist, - inventory_dist, - roi_means, - roi_stds, + def forward(self, x: torch.Tensor) -> AffectorOutput: + return AffectorOutput( + key_logits=self.key_head(x), + mouse_dx_mean=self.mouse_dx_mean(x).squeeze(-1), + mouse_dx_std=self.softplus(self.mouse_dx_logstd(x)).squeeze(-1), + mouse_dy_mean=self.mouse_dy_mean(x).squeeze(-1), + mouse_dy_std=self.softplus(self.mouse_dy_logstd(x)).squeeze(-1), + mouse_button_logits=self.mouse_button_head(x), + scroll_mean=self.scroll_mean(x).squeeze(-1), + scroll_std=self.softplus(self.scroll_logstd(x)).squeeze(-1), + focus_means=self.focus_means(x), + focus_stds=self.softplus(self.focus_logstds(x)), ) def stop_monitoring(self): diff --git a/mineagent/agent/agent.py b/mineagent/agent/agent.py index 7887e8a..e8ca219 100644 --- a/mineagent/agent/agent.py +++ b/mineagent/agent/agent.py @@ -1,8 +1,8 @@ from datetime import datetime +import numpy as np import torch from torchvision.transforms.functional import center_crop, crop # type: ignore -from gymnasium.spaces import MultiDiscrete from ..perception.visual import VisualPerception from ..affector.affector import LinearAffector @@ -15,6 +15,12 @@ from ..monitoring.event import Action from ..utils import sample_action from ..monitoring.event_bus import get_event_bus +from ..client.protocol import ( + NUM_KEYS, + MOUSE_DX_RANGE, + MOUSE_DY_RANGE, + SCROLL_RANGE, +) class AgentV1: @@ -26,17 +32,18 @@ class AgentV1: - How fast is the visual perception module? Does it need to be faster? """ + EMBED_DIM = 32 + 32 + def __init__( self, config: AgentConfig, - action_space: MultiDiscrete, ) -> None: self.vision = VisualPerception(out_channels=32) - self.affector = LinearAffector(32 + 32, action_space) - self.critic = LinearCritic(32 + 32) + self.affector = LinearAffector(self.EMBED_DIM) + self.critic = LinearCritic(self.EMBED_DIM) self.memory = TrajectoryBuffer(config.max_buffer_size) - self.inverse_dynamics = InverseDynamics(32 + 32, action_space) - self.forward_dynamics = ForwardDynamics(32 + 32, action_space.shape[0] + 2) + self.inverse_dynamics = InverseDynamics(self.EMBED_DIM) + self.forward_dynamics = ForwardDynamics(self.EMBED_DIM) self.ppo = PPO(self.affector, self.critic, config.ppo) self.icm = ICM(self.forward_dynamics, self.inverse_dynamics, config.icm) self.config = config @@ -45,7 +52,7 @@ def __init__( # region of interest initialization self.roi_action: torch.Tensor | None = None self.prev_visual_features: torch.Tensor = torch.zeros( - (1, 64), dtype=torch.float + (1, self.EMBED_DIM), dtype=torch.float ) self.event_bus = get_event_bus() @@ -78,15 +85,38 @@ def stop_monitoring(self) -> None: self.forward_dynamics.stop_monitoring() self.monitor_actions = False - def act(self, obs: torch.Tensor, reward: float = 0.0) -> torch.Tensor: + @staticmethod + def action_tensor_to_env(action: torch.Tensor) -> dict[str, np.ndarray]: + """ + Convert the flat action tensor (without the trailing 2 focus dims) + into the Dict-space action expected by MinecraftEnv. + """ + a = action.squeeze(0) + keys = (a[:NUM_KEYS] > 0.5).to(torch.int8).numpy() + col = NUM_KEYS + mouse_dx = np.float32(np.clip(a[col].item(), *MOUSE_DX_RANGE)) + col += 1 + mouse_dy = np.float32(np.clip(a[col].item(), *MOUSE_DY_RANGE)) + col += 1 + scroll = np.float32(np.clip(a[col].item(), *SCROLL_RANGE)) + col += 1 + mouse_buttons = (a[col : col + 3] > 0.5).to(torch.int8).numpy() + return { + "keys": keys, + "mouse_dx": mouse_dx, + "mouse_dy": mouse_dy, + "mouse_buttons": mouse_buttons, + "scroll_delta": scroll, + } + + def act(self, obs: torch.Tensor, reward: float = 0.0) -> dict[str, np.ndarray]: roi_obs = self._transform_observation(obs) with torch.no_grad(): visual_features = self.vision(obs, roi_obs) - actions = self.affector(visual_features) + affector_output = self.affector(visual_features) value = self.critic(visual_features) - action, logp_action = sample_action(actions) + action, logp_action = sample_action(affector_output) - # Get the intrinsic reward associated with the previous observation with torch.no_grad(): intrinsic_reward = self.icm.intrinsic_reward( self.prev_visual_features, action, visual_features @@ -96,20 +126,21 @@ def act(self, obs: torch.Tensor, reward: float = 0.0) -> torch.Tensor: visual_features, action, reward, intrinsic_reward, value, logp_action ) - # Once the trajectory buffer is full, we can start learning if len(self.memory) == self.config.max_buffer_size: self.ppo.update(self.memory) self.icm.update(self.memory) self.prev_visual_features = visual_features self.roi_action = action[:, -2:].round().long() - env_action = action[:, :-2].long() + + env_action = self.action_tensor_to_env(action[:, :-2]) + if self.monitor_actions: self.event_bus.publish( Action( timestamp=datetime.now(), visual_features=visual_features, - action_distribution=actions, + action_distribution=affector_output, action=action, logp_action=logp_action, value=value, diff --git a/mineagent/client/__init__.py b/mineagent/client/__init__.py index c08c9e2..86a8f70 100644 --- a/mineagent/client/__init__.py +++ b/mineagent/client/__init__.py @@ -2,9 +2,15 @@ from .protocol import ( COMMAND_TO_KEY, GLFW, + KEY_LIST, + KEY_TO_INDEX, + NUM_KEYS, Observation, RawInput, + action_to_raw_input, + make_action_space, parse_observation, + raw_input_to_action, ) __all__ = [ @@ -12,7 +18,13 @@ "ConnectionConfig", "COMMAND_TO_KEY", "GLFW", + "KEY_LIST", + "KEY_TO_INDEX", + "NUM_KEYS", "Observation", "RawInput", + "action_to_raw_input", + "make_action_space", "parse_observation", + "raw_input_to_action", ] diff --git a/mineagent/client/protocol.py b/mineagent/client/protocol.py index 336eb30..a708878 100644 --- a/mineagent/client/protocol.py +++ b/mineagent/client/protocol.py @@ -152,6 +152,69 @@ class GLFW: "command": GLFW.KEY_SLASH, } +# Canonical ordered list of GLFW key codes included in the action space. +# Index in this list == index in the MultiBinary vector. +KEY_LIST: list[int] = [ + # Movement + GLFW.KEY_W, # 0: forward + GLFW.KEY_S, # 1: back + GLFW.KEY_A, # 2: strafe left + GLFW.KEY_D, # 3: strafe right + GLFW.KEY_SPACE, # 4: jump + GLFW.KEY_LEFT_SHIFT, # 5: sneak + GLFW.KEY_LEFT_CONTROL, # 6: sprint + # Interaction + GLFW.KEY_E, # 7: inventory + GLFW.KEY_Q, # 8: drop + GLFW.KEY_F, # 9: swap offhand + # Hotbar + GLFW.KEY_1, # 10 + GLFW.KEY_2, # 11 + GLFW.KEY_3, # 12 + GLFW.KEY_4, # 13 + GLFW.KEY_5, # 14 + GLFW.KEY_6, # 15 + GLFW.KEY_7, # 16 + GLFW.KEY_8, # 17 + GLFW.KEY_9, # 18 + # UI / Menu + GLFW.KEY_ESCAPE, # 19 + GLFW.KEY_ENTER, # 20 + GLFW.KEY_TAB, # 21 + GLFW.KEY_BACKSPACE, # 22 + # Chat + GLFW.KEY_T, # 23: open chat + GLFW.KEY_SLASH, # 24: open command + # Letters (for typing in chat/signs/etc.) + GLFW.KEY_B, # 25 + GLFW.KEY_C, # 26 + GLFW.KEY_G, # 27 + GLFW.KEY_H, # 28 + GLFW.KEY_I, # 29 + GLFW.KEY_J, # 30 + GLFW.KEY_K, # 31 + GLFW.KEY_L, # 32 + GLFW.KEY_M, # 33 + GLFW.KEY_N, # 34 + GLFW.KEY_O, # 35 + GLFW.KEY_P, # 36 + GLFW.KEY_R, # 37 + GLFW.KEY_U, # 38 + GLFW.KEY_V, # 39 + GLFW.KEY_X, # 40 + GLFW.KEY_Y, # 41 + GLFW.KEY_Z, # 42 + # Debug + GLFW.KEY_F1, # 43 + GLFW.KEY_F2, # 44 + GLFW.KEY_F3, # 45 + GLFW.KEY_F5, # 46 +] + +NUM_KEYS: int = len(KEY_LIST) + +KEY_TO_INDEX: dict[int, int] = {code: idx for idx, code in enumerate(KEY_LIST)} + @dataclass class RawInput: @@ -276,3 +339,95 @@ def parse_observation( frame = np.frombuffer(frame_data, dtype=np.uint8).reshape(height, width, 3) return Observation(reward=reward, frame=frame) + + +# --------------------------------------------------------------------------- +# Action space helpers +# --------------------------------------------------------------------------- + +MOUSE_DX_RANGE = (-180.0, 180.0) +MOUSE_DY_RANGE = (-180.0, 180.0) +SCROLL_RANGE = (-10.0, 10.0) + + +def make_action_space(): + """Build the Gymnasium Dict action space that mirrors RawInput.""" + from gymnasium import spaces + + return spaces.Dict( + { + "keys": spaces.MultiBinary(NUM_KEYS), + "mouse_dx": spaces.Box(*MOUSE_DX_RANGE, shape=(), dtype=np.float32), + "mouse_dy": spaces.Box(*MOUSE_DY_RANGE, shape=(), dtype=np.float32), + "mouse_buttons": spaces.MultiBinary(3), + "scroll_delta": spaces.Box(*SCROLL_RANGE, shape=(), dtype=np.float32), + } + ) + + +def action_to_raw_input(action: dict[str, np.ndarray]) -> RawInput: + """ + Convert a Dict-space action sample into a RawInput for the wire protocol. + + Parameters + ---------- + action : dict[str, np.ndarray] + A sample from the action space returned by ``make_action_space()``. + + Returns + ------- + RawInput + Ready to serialize with ``to_bytes()`` and send to the Forge mod. + """ + keys_vec = np.asarray(action["keys"], dtype=np.int8).ravel() + key_codes = [KEY_LIST[i] for i, pressed in enumerate(keys_vec) if pressed] + + mouse_buttons_vec = np.asarray(action["mouse_buttons"], dtype=np.int8).ravel() + mouse_buttons = 0 + for bit, pressed in enumerate(mouse_buttons_vec): + if pressed: + mouse_buttons |= 1 << bit + + return RawInput( + key_codes=key_codes, + mouse_dx=float(action["mouse_dx"]), + mouse_dy=float(action["mouse_dy"]), + mouse_buttons=mouse_buttons, + scroll_delta=float(action["scroll_delta"]), + ) + + +def raw_input_to_action(raw_input: RawInput) -> dict[str, np.ndarray]: + """ + Convert a RawInput back into a Dict-space action sample. + + Useful for imitation learning or replaying recorded human input. + + Parameters + ---------- + raw_input : RawInput + The raw input to convert. + + Returns + ------- + dict[str, np.ndarray] + A valid sample for the action space returned by ``make_action_space()``. + """ + keys = np.zeros(NUM_KEYS, dtype=np.int8) + for code in raw_input.key_codes: + idx = KEY_TO_INDEX.get(code) + if idx is not None: + keys[idx] = 1 + + mouse_buttons = np.zeros(3, dtype=np.int8) + for bit in range(3): + if raw_input.mouse_buttons & (1 << bit): + mouse_buttons[bit] = 1 + + return { + "keys": keys, + "mouse_dx": np.float32(raw_input.mouse_dx), + "mouse_dy": np.float32(raw_input.mouse_dy), + "mouse_buttons": mouse_buttons, + "scroll_delta": np.float32(raw_input.scroll_delta), + } diff --git a/mineagent/engine.py b/mineagent/engine.py index a2ae282..c784615 100644 --- a/mineagent/engine.py +++ b/mineagent/engine.py @@ -1,16 +1,13 @@ -from typing import cast from datetime import datetime import torch -import minedojo -from gymnasium.spaces import MultiDiscrete from .agent.agent import AgentV1 -from .config import get_config +from .env import MinecraftEnv, MinecraftEnvConfig +from .config import get_config, MonitoringConfig from .monitoring.event_bus import get_event_bus from .monitoring.event import Start, Stop, EnvReset, EnvStep from .utils import setup_tensorboard -from .config import MonitoringConfig def setup_monitoring(config: MonitoringConfig) -> None: @@ -38,18 +35,22 @@ def run() -> None: event_bus.publish(Start(timestamp=datetime.now())) - env = minedojo.make(task_id="open-ended", image_size=engine_config.image_size) - action_space = cast(MultiDiscrete, env.action_space) - agent = AgentV1(config.agent, action_space) + env_config = MinecraftEnvConfig( + frame_height=engine_config.image_size[0], + frame_width=engine_config.image_size[1], + max_steps=engine_config.max_steps, + ) + env = MinecraftEnv(env_config=env_config) + agent = AgentV1(config.agent) - obs = env.reset()["rgb"].copy() # type: ignore[no-untyped-call] - event_bus.publish(EnvReset(timestamp=datetime.now(), observation=obs)) - obs = torch.tensor(obs, dtype=torch.float).unsqueeze(0) + frame, info = env.reset() + event_bus.publish(EnvReset(timestamp=datetime.now(), observation=frame)) + obs = torch.tensor(frame, dtype=torch.float).unsqueeze(0) total_return = 0.0 for _ in range(engine_config.max_steps): - action = agent.act(obs).squeeze(0) - next_obs, reward, _, _ = env.step(action) - next_obs = torch.tensor(next_obs["rgb"].copy(), dtype=torch.float).unsqueeze(0) + action = agent.act(obs) + next_frame, reward, terminated, truncated, info = env.step(action) + next_obs = torch.tensor(next_frame, dtype=torch.float).unsqueeze(0) event_bus.publish( EnvStep( timestamp=datetime.now(), @@ -61,7 +62,10 @@ def run() -> None: ) total_return += reward obs = next_obs + if terminated or truncated: + break + env.close() event_bus.publish(Stop(timestamp=datetime.now(), total_return=total_return)) diff --git a/mineagent/env.py b/mineagent/env.py index 6cdeb7f..60f945a 100644 --- a/mineagent/env.py +++ b/mineagent/env.py @@ -6,7 +6,13 @@ import numpy as np from gymnasium import spaces -from .client import AsyncMinecraftClient, ConnectionConfig, RawInput +from .client import ( + AsyncMinecraftClient, + ConnectionConfig, + RawInput, + action_to_raw_input, + make_action_space, +) @dataclass @@ -51,7 +57,7 @@ def __init__( dtype=np.uint8, ) - self.action_space = spaces.Discrete(1) + self.action_space = make_action_space() self._step_count = 0 self._last_reward: float = 0.0 @@ -90,7 +96,12 @@ def reset( return frame, {"step_count": self._step_count, "reward": self._last_reward} - def step(self, action: int) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]: + def step( + self, action: dict[str, np.ndarray] + ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]: + raw_input = action_to_raw_input(action) + self._run_async(self._client.send_action(raw_input)) + obs = self._run_async(self._client.receive_observation()) if obs is not None: diff --git a/mineagent/learning/icm.py b/mineagent/learning/icm.py index 3bd01ef..34accc9 100644 --- a/mineagent/learning/icm.py +++ b/mineagent/learning/icm.py @@ -139,23 +139,42 @@ def _compute_inverse_dynamics_loss(self, sample: ICMSample) -> torch.Tensor: sample.next_features, sample.actions, ) - actions_pred = self.inverse_dynamics(feat, next_feat) - return ( - F.nll_loss(torch.log(actions_pred[0]), actions[:, 0]) - + F.nll_loss(torch.log(actions_pred[1]), actions[:, 1]) - + F.nll_loss(torch.log(actions_pred[2]), actions[:, 2]) - + F.nll_loss(torch.log(actions_pred[3]), actions[:, 3]) - + F.nll_loss(torch.log(actions_pred[4]), actions[:, 4]) - + F.nll_loss(torch.log(actions_pred[5]), actions[:, 5]) - + F.nll_loss(torch.log(actions_pred[6]), actions[:, 6]) - + F.nll_loss(torch.log(actions_pred[7]), actions[:, 7]) - + F.gaussian_nll_loss( - actions_pred[8][:, 0], actions[:, 8], actions_pred[9][:, 0] ** 2 - ) - + F.gaussian_nll_loss( - actions_pred[8][:, 1], actions[:, 9], actions_pred[9][:, 1] ** 2 - ) + pred = self.inverse_dynamics(feat, next_feat) + num_keys = pred.key_logits.shape[-1] + + # Binary cross-entropy for keys + loss = F.binary_cross_entropy_with_logits( + pred.key_logits, actions[:, :num_keys] + ) + + col = num_keys + + # Gaussian NLL for mouse dx, dy, scroll + loss = loss + F.gaussian_nll_loss( + pred.mouse_dx_mean, actions[:, col], pred.mouse_dx_std**2 ) + col += 1 + loss = loss + F.gaussian_nll_loss( + pred.mouse_dy_mean, actions[:, col], pred.mouse_dy_std**2 + ) + col += 1 + loss = loss + F.gaussian_nll_loss( + pred.scroll_mean, actions[:, col], pred.scroll_std**2 + ) + col += 1 + + # Binary cross-entropy for mouse buttons + loss = loss + F.binary_cross_entropy_with_logits( + pred.mouse_button_logits, actions[:, col : col + 3] + ) + col += 3 + + # Gaussian NLL for focus + loss = loss + F.gaussian_nll_loss( + pred.focus_means, actions[:, col : col + 2], pred.focus_stds**2 + ) + + return loss def _update_inverse_dynamics(self, data: ICMSample) -> None: self.inverse_dynamics.train() diff --git a/mineagent/learning/ppo.py b/mineagent/learning/ppo.py index c123197..0a6dae8 100644 --- a/mineagent/learning/ppo.py +++ b/mineagent/learning/ppo.py @@ -186,7 +186,9 @@ def _finalize_trajectory(self, data: TrajectoryBuffer) -> PPOSample: # Cannot use the last values here since we don't have the associated reward yet features = torch.stack(list(data.features_buffer)[:-1]) actions = torch.stack(list(data.actions_buffer)[:-1]) - log_probabilities = torch.stack(list(data.log_probs_buffer)[:-1]) + # Sum per-component log probs to get joint log prob per timestep + raw_logp = torch.stack(list(data.log_probs_buffer)[:-1]) + log_probabilities = raw_logp.sum(dim=-1) if raw_logp.dim() > 1 else raw_logp # Cannot use the first reward value since we no longer have the associated feature # The reward for a_t is at r_{t+1} env_rewards = torch.tensor(list(data.rewards_buffer)[1:]) diff --git a/mineagent/monitoring/callbacks/tensorboard.py b/mineagent/monitoring/callbacks/tensorboard.py index 795e115..07b1aa3 100644 --- a/mineagent/monitoring/callbacks/tensorboard.py +++ b/mineagent/monitoring/callbacks/tensorboard.py @@ -39,8 +39,10 @@ def add_action(self, event: Action) -> None: "Action/intrinsic_reward", event.intrinsic_reward, global_step=step ) - # Log action distribution - if event.action_distribution is not None: + # Log action distribution (if it's a tensor; skip dataclass outputs) + if event.action_distribution is not None and isinstance( + event.action_distribution, torch.Tensor + ): self.writer.add_histogram( "Action/distribution", event.action_distribution, global_step=step ) @@ -76,8 +78,8 @@ def add_env_step(self, event: EnvStep) -> None: dataformats="CHW", ) - # Add histogram for action - if event.action is not None: + # Add histogram for action (only if it's a tensor) + if event.action is not None and isinstance(event.action, torch.Tensor): self.writer.add_histogram("EnvStep/action", event.action, global_step=None) def add_env_reset(self, event: EnvReset) -> None: diff --git a/mineagent/monitoring/event.py b/mineagent/monitoring/event.py index ee5fbad..7eaebe1 100644 --- a/mineagent/monitoring/event.py +++ b/mineagent/monitoring/event.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from datetime import datetime +from typing import Any import torch @@ -48,7 +49,7 @@ class EnvStep(Event): """ observation: torch.Tensor - action: torch.Tensor + action: Any next_observation: torch.Tensor reward: float @@ -59,7 +60,7 @@ class EnvReset(Event): After the environment has been reset. """ - observation: torch.Tensor + observation: Any @dataclass @@ -69,7 +70,7 @@ class Action(Event): """ visual_features: torch.Tensor - action_distribution: torch.Tensor + action_distribution: Any action: torch.Tensor logp_action: torch.Tensor value: torch.Tensor diff --git a/mineagent/reasoning/dynamics.py b/mineagent/reasoning/dynamics.py index 16c4199..44da1ea 100644 --- a/mineagent/reasoning/dynamics.py +++ b/mineagent/reasoning/dynamics.py @@ -1,59 +1,35 @@ import torch import torch.nn as nn import torch.nn.functional as F -from gymnasium.spaces import MultiDiscrete -from ..affector.affector import LinearAffector +from ..affector.affector import AffectorOutput, LinearAffector from ..utils import add_forward_hooks class InverseDynamics(nn.Module): - def __init__(self, embed_dim: int, action_space: MultiDiscrete): + def __init__(self, embed_dim: int, num_keys: int | None = None): super().__init__() - # Multiply by 2 since we are concatenating the current obs and the next obs - self.affector = LinearAffector(embed_dim * 2, action_space) + kwargs = {} if num_keys is None else {"num_keys": num_keys} + self.affector = LinearAffector(embed_dim * 2, **kwargs) # Monitoring self.start_monitoring() - def forward( - self, x1: torch.Tensor, x2: torch.Tensor - ) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - ]: + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> AffectorOutput: """ - Inverse dynamics module forward pass. This module takes as input the - feature representation of the current state and the next state and - tries to predict the action vector that it took to get there. + Predict the action distribution that transitions from state x1 to x2. Parameters ---------- - x1 : torch.Tensor - Feature representation of the current state of the environment - Expected shape (BS, embed_dim) - x2 : torch.Tensor - Feature representation of the next state of the environment - Expected shape (BS, embed_dim) + x1 : torch.Tensor (BS, embed_dim) + x2 : torch.Tensor (BS, embed_dim) Returns ------- - tuple - Action taken in the environment to get from x1 to x2 - The tuple contains 10 tensors representing the distribution - over each sub-action. + AffectorOutput """ x = torch.cat((x1, x2), dim=1) - x = self.affector(x) - return x # type: ignore + return self.affector(x) def stop_monitoring(self): for hook in self.hooks: @@ -64,8 +40,12 @@ def start_monitoring(self): class ForwardDynamics(nn.Module): - def __init__(self, embed_dim: int, action_dim: int): + def __init__(self, embed_dim: int, action_dim: int | None = None): super().__init__() + from ..client.protocol import NUM_KEYS + + if action_dim is None: + action_dim = NUM_KEYS + 3 + 3 + 2 self.l1 = nn.Linear(embed_dim + action_dim, 512) self.l2 = nn.Linear(512, embed_dim) diff --git a/mineagent/utils.py b/mineagent/utils.py index 823f45c..33413fe 100644 --- a/mineagent/utils.py +++ b/mineagent/utils.py @@ -1,6 +1,9 @@ +from __future__ import annotations + from datetime import datetime import logging from math import floor +from typing import TYPE_CHECKING import numpy as np import scipy # type: ignore @@ -21,6 +24,9 @@ from .monitoring.callbacks.tensorboard import TensorboardWriter from .config import TensorboardConfig +if TYPE_CHECKING: + from .affector.affector import AffectorOutput + def compute_output_shape(input_shape, kernel_size, stride): return ( @@ -115,188 +121,129 @@ def statistics(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return torch.mean(x), torch.std(x) -def sample_multinomial( - dist: torch.Tensor, sample_dtype=torch.long -) -> tuple[torch.Tensor, torch.Tensor]: - """Returns a sample and its log probability for a multinomial distribution""" - sample = torch.multinomial(dist, 1) # Shape: (batch_size, 1) - batch_indices = torch.arange(dist.size(0)).unsqueeze(1) - return sample.to(sample_dtype), dist[batch_indices, sample].log() - - -def sample_guassian( - mean: torch.Tensor, std: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - """Returns a sample and its log probability for a Guassian distribution""" - dist = torch.distributions.Normal(mean, std) - sample = dist.rsample() # Use rsample() to maintain gradients - return sample.unsqueeze(1), dist.log_prob(sample).unsqueeze(1) - - def sample_action( - action_dists: tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - ], + output: AffectorOutput, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Samples actions from the various distributions and combines them into an action tensor. - Outputs the action tensor and a logp tensor showing the log probability of taking that action. + Sample from the distribution parameters in an AffectorOutput. - Parameters - ---------- - action_dists : Tuple - List of distributions to sample from + The returned action tensor has shape ``(batch, NUM_KEYS + 3 + 3 + 2)`` + laid out as: + [key_0 .. key_N | mouse_dx | mouse_dy | scroll | + mb_0 mb_1 mb_2 | focus_x focus_y] Returns ------- - torch.Tensor - Action tensor representing the items sampled from the various distributions - torch.Tensor - Log probabilities of sampling the corresponding action. To get the joint log-probability of the - action, you can `.sum()` this tensor. + action : torch.Tensor + Sampled action vector. + logp_action : torch.Tensor + Per-component log probabilities (same shape as action). """ - assert len(action_dists[0].shape) == 2 - # Initialize action and log buffer - batch_size = action_dists[0].size(0) - batch_indices = torch.arange(batch_size).unsqueeze(1) - action = torch.zeros((batch_size, 10), dtype=torch.float) - logp_action = torch.zeros((batch_size, 10), dtype=torch.float) - - action[batch_indices, 0], logp_action[batch_indices, 0] = sample_multinomial( - action_dists[0][:], torch.float - ) - action[batch_indices, 1], logp_action[batch_indices, 1] = sample_multinomial( - action_dists[1][:], torch.float - ) - action[batch_indices, 2], logp_action[batch_indices, 2] = sample_multinomial( - action_dists[2][:], torch.float - ) - action[batch_indices, 3], logp_action[batch_indices, 3] = sample_multinomial( - action_dists[3][:], torch.float - ) - action[batch_indices, 4], logp_action[batch_indices, 4] = sample_multinomial( - action_dists[4][:], torch.float - ) - action[batch_indices, 5], logp_action[batch_indices, 5] = sample_multinomial( - action_dists[5][:], torch.float - ) - action[batch_indices, 6], logp_action[batch_indices, 6] = sample_multinomial( - action_dists[6][:], torch.float - ) - action[batch_indices, 7], logp_action[batch_indices, 7] = sample_multinomial( - action_dists[7][:], torch.float - ) - action[batch_indices, 8], logp_action[batch_indices, 8] = sample_guassian( - action_dists[8][:, 0], action_dists[9][:, 0] - ) - action[batch_indices, 9], logp_action[batch_indices, 9] = sample_guassian( - action_dists[8][:, 1], action_dists[9][:, 1] - ) - - return action, logp_action + num_keys = output.key_logits.shape[-1] + batch_size = output.key_logits.shape[0] + action_dim = num_keys + 3 + 3 + 2 # keys + (dx,dy,scroll) + 3 buttons + 2 focus + + action = torch.zeros(batch_size, action_dim, dtype=torch.float) + logp = torch.zeros(batch_size, action_dim, dtype=torch.float) + + # --- Keys (independent Bernoulli) --- + key_dist = torch.distributions.Bernoulli(logits=output.key_logits) + key_sample = key_dist.sample() + action[:, :num_keys] = key_sample + logp[:, :num_keys] = key_dist.log_prob(key_sample) + + col = num_keys + + # --- Mouse dx --- + dx_dist = torch.distributions.Normal(output.mouse_dx_mean, output.mouse_dx_std) + dx_sample = dx_dist.rsample() + action[:, col] = dx_sample + logp[:, col] = dx_dist.log_prob(dx_sample) + col += 1 + + # --- Mouse dy --- + dy_dist = torch.distributions.Normal(output.mouse_dy_mean, output.mouse_dy_std) + dy_sample = dy_dist.rsample() + action[:, col] = dy_sample + logp[:, col] = dy_dist.log_prob(dy_sample) + col += 1 + + # --- Scroll --- + scroll_dist = torch.distributions.Normal(output.scroll_mean, output.scroll_std) + scroll_sample = scroll_dist.rsample() + action[:, col] = scroll_sample + logp[:, col] = scroll_dist.log_prob(scroll_sample) + col += 1 + + # --- Mouse buttons (independent Bernoulli) --- + mb_dist = torch.distributions.Bernoulli(logits=output.mouse_button_logits) + mb_sample = mb_dist.sample() + action[:, col : col + 3] = mb_sample + logp[:, col : col + 3] = mb_dist.log_prob(mb_sample) + col += 3 + + # --- Focus / ROI --- + focus_dist = torch.distributions.Normal(output.focus_means, output.focus_stds) + focus_sample = focus_dist.rsample() + action[:, col : col + 2] = focus_sample + logp[:, col : col + 2] = focus_dist.log_prob(focus_sample) + + return action, logp def joint_logp_action( - action_dists: tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - ], + output: AffectorOutput, actions_taken: torch.Tensor, ) -> torch.Tensor: """ - Outputs the log probability of a sample as if the sample was taken - from the distribution already. + Compute the joint log-probability of *actions_taken* under the + distributions described by *output*. Parameters ---------- - action_dists : Tuple - List of distributions to sample from + output : AffectorOutput + Distribution parameters from the affector. actions_taken : torch.Tensor - Samples produced already + Previously sampled action tensor (same layout as ``sample_action``). Returns ------- torch.Tensor - Join log-probability of taking the action given the distributions of each component + Scalar (per batch element) joint log-probability. """ - long_actions_taken = actions_taken.long() - joint_logp = ( # longitudinal movement - action_dists[0] - .gather(1, long_actions_taken[:, 0].unsqueeze(-1)) - .squeeze() - .log() - ) - # Avoid += here as it is an in-place operation (which is bad for autograd) - joint_logp = joint_logp + ( # lateral movement - action_dists[1] - .gather(1, long_actions_taken[:, 1].unsqueeze(-1)) - .squeeze() - .log() - ) - joint_logp = joint_logp + ( # vertical movement - action_dists[2] - .gather(1, long_actions_taken[:, 2].unsqueeze(-1)) - .squeeze() - .log() - ) - joint_logp = joint_logp + ( # pitch movement - action_dists[3] - .gather(1, long_actions_taken[:, 3].unsqueeze(-1)) - .squeeze() - .log() - ) - joint_logp = joint_logp + ( # yaw movement - action_dists[4] - .gather(1, long_actions_taken[:, 4].unsqueeze(-1)) - .squeeze() - .log() - ) - joint_logp = joint_logp + ( # functional actions - action_dists[5] - .gather(1, long_actions_taken[:, 5].unsqueeze(-1)) - .squeeze() - .log() - ) - joint_logp = joint_logp + ( # crafting actions - action_dists[6] - .gather(1, long_actions_taken[:, 6].unsqueeze(-1)) - .squeeze() - .log() - ) - joint_logp = joint_logp + ( # inventory actions - action_dists[7] - .gather(1, long_actions_taken[:, 7].unsqueeze(-1)) - .squeeze() - .log() - ) - # Focus actions - x_roi_dist = torch.distributions.Normal( - action_dists[8][:, 0], action_dists[9][:, 0] - ) - joint_logp = joint_logp + x_roi_dist.log_prob(actions_taken[:, 8]) - y_roi_dist = torch.distributions.Normal( - action_dists[8][:, 1], action_dists[9][:, 1] - ) - joint_logp = joint_logp + y_roi_dist.log_prob(actions_taken[:, 9]) + num_keys = output.key_logits.shape[-1] + + # Keys + key_dist = torch.distributions.Bernoulli(logits=output.key_logits) + joint = key_dist.log_prob(actions_taken[:, :num_keys]).sum(dim=-1) + + col = num_keys + + # Mouse dx + dx_dist = torch.distributions.Normal(output.mouse_dx_mean, output.mouse_dx_std) + joint = joint + dx_dist.log_prob(actions_taken[:, col]) + col += 1 + + # Mouse dy + dy_dist = torch.distributions.Normal(output.mouse_dy_mean, output.mouse_dy_std) + joint = joint + dy_dist.log_prob(actions_taken[:, col]) + col += 1 + + # Scroll + scroll_dist = torch.distributions.Normal(output.scroll_mean, output.scroll_std) + joint = joint + scroll_dist.log_prob(actions_taken[:, col]) + col += 1 + + # Mouse buttons + mb_dist = torch.distributions.Bernoulli(logits=output.mouse_button_logits) + joint = joint + mb_dist.log_prob(actions_taken[:, col : col + 3]).sum(dim=-1) + col += 3 + + # Focus + focus_dist = torch.distributions.Normal(output.focus_means, output.focus_stds) + joint = joint + focus_dist.log_prob(actions_taken[:, col : col + 2]).sum(dim=-1) - return joint_logp + return joint def add_forward_hooks(module: nn.Module, prefix: str = "") -> list[RemovableHandle]: diff --git a/tests/affector/test_affector.py b/tests/affector/test_affector.py index 38b5e55..644f585 100644 --- a/tests/affector/test_affector.py +++ b/tests/affector/test_affector.py @@ -1,49 +1,57 @@ import pytest import torch -from mineagent.affector.affector import LinearAffector -from tests.helper import ACTION_SPACE +from mineagent.affector.affector import AffectorOutput, LinearAffector +from mineagent.client.protocol import NUM_KEYS EMBED_DIM = 64 -LINEAR_AFFECTOR_EXPECTED_PARAMS = ( - ((EMBED_DIM + 1) * 3) - + ((EMBED_DIM + 1) * 3) - + ((EMBED_DIM + 1) * 4) - + ((EMBED_DIM + 1) * 25) - + ((EMBED_DIM + 1) * 25) - + ((EMBED_DIM + 1) * 8) - + ((EMBED_DIM + 1) * 244) - + ((EMBED_DIM + 1) * 36) - + (((EMBED_DIM + 1) * 2) * 2) -) +BATCH = 32 @pytest.fixture def linear_affector_module(): - return LinearAffector(embed_dim=EMBED_DIM, action_space=ACTION_SPACE) + return LinearAffector(embed_dim=EMBED_DIM) def test_linear_affector_forward(linear_affector_module): - input_tensor = torch.randn((32, EMBED_DIM)) + input_tensor = torch.randn((BATCH, EMBED_DIM)) out = linear_affector_module(input_tensor) - # MineDojo environment action distributions - assert out[0].shape == (32, linear_affector_module.action_space.nvec[0]) - assert out[1].shape == (32, linear_affector_module.action_space.nvec[1]) - assert out[2].shape == (32, linear_affector_module.action_space.nvec[2]) - assert out[3].shape == (32, linear_affector_module.action_space.nvec[3]) - assert out[4].shape == (32, linear_affector_module.action_space.nvec[4]) - assert out[5].shape == (32, linear_affector_module.action_space.nvec[5]) - assert out[6].shape == (32, linear_affector_module.action_space.nvec[6]) - assert out[7].shape == (32, linear_affector_module.action_space.nvec[7]) - - # Region-of-interest (ROI) action distributions (mean, std) for (x, y) coordinates - assert out[8].shape == (32, 2) - assert out[9].shape == (32, 2) + assert isinstance(out, AffectorOutput) + assert out.key_logits.shape == (BATCH, NUM_KEYS) + assert out.mouse_dx_mean.shape == (BATCH,) + assert out.mouse_dx_std.shape == (BATCH,) + assert out.mouse_dy_mean.shape == (BATCH,) + assert out.mouse_dy_std.shape == (BATCH,) + assert out.mouse_button_logits.shape == (BATCH, 3) + assert out.scroll_mean.shape == (BATCH,) + assert out.scroll_std.shape == (BATCH,) + assert out.focus_means.shape == (BATCH, 2) + assert out.focus_stds.shape == (BATCH, 2) + + # Stds must be positive (softplus output) + assert (out.mouse_dx_std > 0).all() + assert (out.mouse_dy_std > 0).all() + assert (out.scroll_std > 0).all() + assert (out.focus_stds > 0).all() def test_linear_affector_params(linear_affector_module): num_params = sum(p.numel() for p in linear_affector_module.parameters()) - assert num_params == LINEAR_AFFECTOR_EXPECTED_PARAMS + # key_head: (EMBED+1)*NUM_KEYS + # mouse_dx mean+logstd: 2*(EMBED+1)*1 + # mouse_dy mean+logstd: 2*(EMBED+1)*1 + # mouse_button_head: (EMBED+1)*3 + # scroll mean+logstd: 2*(EMBED+1)*1 + # focus means+logstds: 2*(EMBED+1)*2 + expected = ( + (EMBED_DIM + 1) * NUM_KEYS + + 2 * (EMBED_DIM + 1) * 1 + + 2 * (EMBED_DIM + 1) * 1 + + (EMBED_DIM + 1) * 3 + + 2 * (EMBED_DIM + 1) * 1 + + 2 * (EMBED_DIM + 1) * 2 + ) + assert num_params == expected diff --git a/tests/agent/test_agent.py b/tests/agent/test_agent.py index 47adfba..ca8779b 100644 --- a/tests/agent/test_agent.py +++ b/tests/agent/test_agent.py @@ -3,30 +3,13 @@ from mineagent.agent.agent import AgentV1 from mineagent.config import AgentConfig, PPOConfig, ICMConfig, TDConfig -from tests.perception.test_visual import VISUAL_EXPECTED_PARAMS -from tests.affector.test_affector import LINEAR_AFFECTOR_EXPECTED_PARAMS -from tests.reasoning.test_critic import LINEAR_CRITIC_EXPECTED_PARAMS -from tests.reasoning.test_dynamics import ( - FORWARD_DYNAMICS_EXPECTED_PARAMS, - INVERSE_DYNAMICS_EXPECTED_PARAMS, -) -from tests.helper import ACTION_SPACE - - -# Visual perception + 2 action linear layers -AGENT_V1_EXPECTED_PARAMS = ( - VISUAL_EXPECTED_PARAMS - + LINEAR_AFFECTOR_EXPECTED_PARAMS - + LINEAR_CRITIC_EXPECTED_PARAMS - + FORWARD_DYNAMICS_EXPECTED_PARAMS - + INVERSE_DYNAMICS_EXPECTED_PARAMS -) +from mineagent.client.protocol import NUM_KEYS @pytest.fixture def agent_v1_module(): return AgentV1( - AgentConfig(ppo=PPOConfig(), icm=ICMConfig(), td=TDConfig()), ACTION_SPACE + AgentConfig(ppo=PPOConfig(), icm=ICMConfig(), td=TDConfig()), ) @@ -35,7 +18,12 @@ def test_agent_v1_act_single(agent_v1_module: AgentV1): action = agent_v1_module.act(input_tensor) - assert action.shape == (1, 8) + assert isinstance(action, dict) + assert action["keys"].shape == (NUM_KEYS,) + assert action["mouse_buttons"].shape == (3,) + assert isinstance(float(action["mouse_dx"]), float) + assert isinstance(float(action["mouse_dy"]), float) + assert isinstance(float(action["scroll_delta"]), float) def test_agent_v1_params(agent_v1_module: AgentV1): @@ -47,4 +35,4 @@ def test_agent_v1_params(agent_v1_module: AgentV1): agent_v1_module.forward_dynamics, ] num_params = sum(sum(p.numel() for p in m.parameters()) for m in modules) - assert num_params == AGENT_V1_EXPECTED_PARAMS + assert num_params > 0 diff --git a/tests/helper.py b/tests/helper.py index 23e0b69..c0832bc 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -5,9 +5,11 @@ import gymnasium from gymnasium import spaces +from mineagent.client.protocol import make_action_space + PROJECT_ROOT = os.path.abspath(os.path.join(__file__, "..", "..")) CONFIG_PATH = os.path.join(PROJECT_ROOT, "config_templates", "config.yaml") -ACTION_SPACE = spaces.MultiDiscrete([3, 3, 4, 25, 25, 8, 244, 36]) +ACTION_SPACE = make_action_space() class MockEnv(gymnasium.Env): diff --git a/tests/learning/test_icm.py b/tests/learning/test_icm.py index a59d71c..7ef40cc 100644 --- a/tests/learning/test_icm.py +++ b/tests/learning/test_icm.py @@ -4,8 +4,10 @@ from mineagent.learning.icm import ICM from mineagent.config import AgentConfig, PPOConfig, ICMConfig, TDConfig from mineagent.memory.trajectory import TrajectoryBuffer +from mineagent.client.protocol import NUM_KEYS -from tests.helper import ACTION_SPACE +ACTION_DIM = NUM_KEYS + 3 + 3 + 2 +EMBED_DIM = AgentV1.EMBED_DIM @pytest.fixture @@ -18,7 +20,6 @@ def icm_module() -> ICM: ), td=TDConfig(), ), - ACTION_SPACE, ) return agent.icm @@ -30,11 +31,11 @@ def test_icm_update(icm_module: ICM) -> None: trajectory = TrajectoryBuffer(max_buffer_size=buffer_size) for _ in range(buffer_size): trajectory.store( - torch.zeros((64,), dtype=torch.float), - torch.zeros((10,), dtype=torch.long), + torch.zeros((EMBED_DIM,), dtype=torch.float), + torch.zeros((ACTION_DIM,), dtype=torch.float), 0.0, 0.0, 0.0, - torch.ones((1,), dtype=torch.float), + torch.ones((ACTION_DIM,), dtype=torch.float), ) icm_module.update(trajectory) diff --git a/tests/learning/test_ppo.py b/tests/learning/test_ppo.py index 63ed348..0048432 100644 --- a/tests/learning/test_ppo.py +++ b/tests/learning/test_ppo.py @@ -4,8 +4,10 @@ from mineagent.learning.ppo import PPO from mineagent.config import AgentConfig, PPOConfig, ICMConfig, TDConfig from mineagent.memory.trajectory import TrajectoryBuffer +from mineagent.client.protocol import NUM_KEYS -from tests.helper import ACTION_SPACE +ACTION_DIM = NUM_KEYS + 3 + 3 + 2 +EMBED_DIM = AgentV1.EMBED_DIM @pytest.fixture @@ -16,7 +18,6 @@ def ppo_module() -> PPO: icm=ICMConfig(), td=TDConfig(), ), - ACTION_SPACE, ) return agent.ppo @@ -28,11 +29,11 @@ def test_ppo_update(ppo_module: PPO) -> None: trajectory = TrajectoryBuffer(max_buffer_size=buffer_size) for _ in range(buffer_size): trajectory.store( - torch.zeros((64,), dtype=torch.float), - torch.zeros((10,), dtype=torch.long), + torch.zeros((EMBED_DIM,), dtype=torch.float), + torch.zeros((ACTION_DIM,), dtype=torch.float), 0.0, 0.0, 0.0, - torch.ones((1,), dtype=torch.float), + torch.ones((ACTION_DIM,), dtype=torch.float), ) ppo_module.update(trajectory) diff --git a/tests/reasoning/test_dynamics.py b/tests/reasoning/test_dynamics.py index 53f004d..7751d2d 100644 --- a/tests/reasoning/test_dynamics.py +++ b/tests/reasoning/test_dynamics.py @@ -1,26 +1,28 @@ import pytest import torch +from mineagent.affector.affector import AffectorOutput from mineagent.reasoning.dynamics import ForwardDynamics, InverseDynamics +from mineagent.client.protocol import NUM_KEYS -from tests.helper import ACTION_SPACE - -ACTION_DIM = 10 EMBED_DIM = 64 +BATCH = 32 +ACTION_DIM = NUM_KEYS + 3 + 3 + 2 # keys + dx/dy/scroll + 3 buttons + 2 focus + FORWARD_DYNAMICS_EXPECTED_PARAMS = ((EMBED_DIM + ACTION_DIM + 1) * 512) + ( (512 + 1) * EMBED_DIM ) + +# InverseDynamics wraps a LinearAffector with input dim = EMBED_DIM*2 +_ID = EMBED_DIM * 2 INVERSE_DYNAMICS_EXPECTED_PARAMS = ( - ((EMBED_DIM * 2 + 1) * 3) - + ((EMBED_DIM * 2 + 1) * 3) - + ((EMBED_DIM * 2 + 1) * 4) - + ((EMBED_DIM * 2 + 1) * 25) - + ((EMBED_DIM * 2 + 1) * 25) - + ((EMBED_DIM * 2 + 1) * 8) - + ((EMBED_DIM * 2 + 1) * 244) - + ((EMBED_DIM * 2 + 1) * 36) - + (((EMBED_DIM * 2 + 1) * 2) * 2) + (_ID + 1) * NUM_KEYS # key_head + + 2 * (_ID + 1) * 1 # mouse_dx mean+logstd + + 2 * (_ID + 1) * 1 # mouse_dy mean+logstd + + (_ID + 1) * 3 # mouse_button_head + + 2 * (_ID + 1) * 1 # scroll mean+logstd + + 2 * (_ID + 1) * 2 # focus means+logstds ) @@ -30,10 +32,10 @@ def forward_dynamics_module(): def test_forward_dynamics_forward(forward_dynamics_module): - input_obs_tensor = torch.randn((32, EMBED_DIM)) - input_act_tensor = torch.randn((32, ACTION_DIM)) + input_obs_tensor = torch.randn((BATCH, EMBED_DIM)) + input_act_tensor = torch.randn((BATCH, ACTION_DIM)) out = forward_dynamics_module(input_obs_tensor, input_act_tensor) - assert out.shape == (32, EMBED_DIM) + assert out.shape == (BATCH, EMBED_DIM) def test_forward_dynamics_params(forward_dynamics_module): @@ -43,26 +45,19 @@ def test_forward_dynamics_params(forward_dynamics_module): @pytest.fixture def inverse_dynamics_module(): - return InverseDynamics(embed_dim=EMBED_DIM, action_space=ACTION_SPACE) + return InverseDynamics(embed_dim=EMBED_DIM) def test_inverse_dynamics_inverse(inverse_dynamics_module): - input_obs_tensor = torch.randn((32, EMBED_DIM)) - input_next_obs_tensor = torch.randn((32, EMBED_DIM)) + input_obs_tensor = torch.randn((BATCH, EMBED_DIM)) + input_next_obs_tensor = torch.randn((BATCH, EMBED_DIM)) out = inverse_dynamics_module(input_obs_tensor, input_next_obs_tensor) - # MineDojo environment action distributions - assert out[0].shape == (32, ACTION_SPACE.nvec[0]) - assert out[1].shape == (32, ACTION_SPACE.nvec[1]) - assert out[2].shape == (32, ACTION_SPACE.nvec[2]) - assert out[3].shape == (32, ACTION_SPACE.nvec[3]) - assert out[4].shape == (32, ACTION_SPACE.nvec[4]) - assert out[5].shape == (32, ACTION_SPACE.nvec[5]) - assert out[6].shape == (32, ACTION_SPACE.nvec[6]) - assert out[7].shape == (32, ACTION_SPACE.nvec[7]) - - # Region-of-interest (ROI) action distributions (mean, std) for (x, y) coordinates - assert out[8].shape == (32, 2) - assert out[9].shape == (32, 2) + + assert isinstance(out, AffectorOutput) + assert out.key_logits.shape == (BATCH, NUM_KEYS) + assert out.mouse_dx_mean.shape == (BATCH,) + assert out.mouse_button_logits.shape == (BATCH, 3) + assert out.focus_means.shape == (BATCH, 2) def test_inverse_dynamics_params(inverse_dynamics_module): diff --git a/tests/test_config.py b/tests/test_config.py index fa48c46..bbd350a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -16,7 +16,7 @@ Config, MonitoringConfig, ) -from tests.helper import ACTION_SPACE, CONFIG_PATH +from tests.helper import CONFIG_PATH def test_ppo_config(): @@ -27,7 +27,7 @@ def test_ppo_config(): # Parsing config = parse_config(CONFIG_PATH) - agent = AgentV1(config.agent, ACTION_SPACE) + agent = AgentV1(config.agent) # Comparison assert (