diff --git a/mineagent/client/__init__.py b/mineagent/client/__init__.py new file mode 100644 index 0000000..c08c9e2 --- /dev/null +++ b/mineagent/client/__init__.py @@ -0,0 +1,18 @@ +from .connection import AsyncMinecraftClient, ConnectionConfig +from .protocol import ( + COMMAND_TO_KEY, + GLFW, + Observation, + RawInput, + parse_observation, +) + +__all__ = [ + "AsyncMinecraftClient", + "ConnectionConfig", + "COMMAND_TO_KEY", + "GLFW", + "Observation", + "RawInput", + "parse_observation", +] diff --git a/mineagent/client/connection.py b/mineagent/client/connection.py new file mode 100644 index 0000000..64430a4 --- /dev/null +++ b/mineagent/client/connection.py @@ -0,0 +1,151 @@ +import struct +import asyncio +import logging +from dataclasses import dataclass + +import numpy as np + +from .protocol import Observation, RawInput, parse_observation + + +@dataclass +class ConnectionConfig: + """Configuration for Minecraft Forge mod connection.""" + + observation_socket: str = "/tmp/mineagent_observation.sock" + action_socket: str = "/tmp/mineagent_action.sock" + frame_width: int = 320 + frame_height: int = 240 + timeout: float = 30.0 + max_retries: int = 3 + retry_delay: float = 1.0 + + +class AsyncMinecraftClient: + """ + Async client for communicating with the Minecraft Forge mod via Unix domain sockets. + """ + + def __init__(self, config: ConnectionConfig | None = None): + self.config = config or ConnectionConfig() + self._observation_reader: asyncio.StreamReader | None = None + self._action_writer: asyncio.StreamWriter | None = None + self._connected: bool = False + self._logger = logging.getLogger(__name__) + + @property + def connected(self) -> bool: + return self._connected + + async def connect(self) -> bool: + """Establish connection to the Minecraft Forge mod.""" + for attempt in range(self.config.max_retries): + try: + self._observation_reader, _ = await asyncio.open_unix_connection( + self.config.observation_socket + ) + _, self._action_writer = await asyncio.open_unix_connection( + self.config.action_socket + ) + self._connected = True + self._logger.info( + "Connected to Minecraft Forge mod - Observation: %s, Action: %s", + self.config.observation_socket, + self.config.action_socket, + ) + return True + except OSError as e: + self._logger.warning("Connection attempt %d failed: %s", attempt + 1, e) + await self._cleanup() + if attempt < self.config.max_retries - 1: + await asyncio.sleep(self.config.retry_delay) + else: + self._logger.error( + "Failed to connect after %d attempts", self.config.max_retries + ) + return False + return False + + async def disconnect(self) -> None: + """Disconnect from the Minecraft Forge mod.""" + await self._cleanup() + self._connected = False + self._logger.info("Disconnected from Minecraft Forge mod") + + async def _cleanup(self) -> None: + """Clean up sockets.""" + if self._action_writer: + self._action_writer.close() + await self._action_writer.wait_closed() + self._action_writer = None + self._observation_reader = None + + async def send_action(self, raw_input: RawInput) -> bool: + """ + Send a raw input action to the Minecraft Forge mod. + + Parameters + ---------- + raw_input : RawInput + The input to send + + Returns + ------- + bool + True if sent successfully, False otherwise + """ + if not self._connected or not self._action_writer: + self._logger.error("Not connected to Minecraft Forge mod") + return False + + try: + data = raw_input.to_bytes() + self._action_writer.write(data) + await self._action_writer.drain() + return True + except OSError as e: + self._logger.error("Failed to send action: %s", e) + self._connected = False + return False + + async def receive_observation(self) -> Observation | None: + """ + Receive an observation from the Minecraft Forge mod. + + Returns + ------- + Observation | None + The observation if received successfully, None otherwise + """ + if not self._connected or not self._observation_reader: + self._logger.error("Not connected to Minecraft Forge mod") + return None + + try: + header = await self._observation_reader.readexactly(12) + reward = struct.unpack(">d", header[0:8])[0] + frame_length = struct.unpack(">I", header[8:12])[0] + + if frame_length == 0: + return Observation( + reward=reward, + frame=np.zeros( + (self.config.frame_height, self.config.frame_width, 3), + dtype=np.uint8, + ), + ) + + frame_data = await self._observation_reader.readexactly(frame_length) + + return parse_observation( + header, + frame_data, + (self.config.frame_height, self.config.frame_width), + ) + except asyncio.IncompleteReadError: + self._logger.warning("Incomplete observation received") + return None + except OSError as e: + self._logger.error("Failed to receive observation: %s", e) + self._connected = False + return None diff --git a/mineagent/client/protocol.py b/mineagent/client/protocol.py new file mode 100644 index 0000000..336eb30 --- /dev/null +++ b/mineagent/client/protocol.py @@ -0,0 +1,278 @@ +import struct +from dataclasses import dataclass, field + +import numpy as np + + +class GLFW: + """GLFW key and mouse button constants.""" + + KEY_SPACE = 32 + KEY_APOSTROPHE = 39 + KEY_COMMA = 44 + KEY_MINUS = 45 + KEY_PERIOD = 46 + KEY_SLASH = 47 + KEY_0 = 48 + KEY_1 = 49 + KEY_2 = 50 + KEY_3 = 51 + KEY_4 = 52 + KEY_5 = 53 + KEY_6 = 54 + KEY_7 = 55 + KEY_8 = 56 + KEY_9 = 57 + KEY_SEMICOLON = 59 + KEY_EQUAL = 61 + KEY_A = 65 + KEY_B = 66 + KEY_C = 67 + KEY_D = 68 + KEY_E = 69 + KEY_F = 70 + KEY_G = 71 + KEY_H = 72 + KEY_I = 73 + KEY_J = 74 + KEY_K = 75 + KEY_L = 76 + KEY_M = 77 + KEY_N = 78 + KEY_O = 79 + KEY_P = 80 + KEY_Q = 81 + KEY_R = 82 + KEY_S = 83 + KEY_T = 84 + KEY_U = 85 + KEY_V = 86 + KEY_W = 87 + KEY_X = 88 + KEY_Y = 89 + KEY_Z = 90 + KEY_LEFT_BRACKET = 91 + KEY_BACKSLASH = 92 + KEY_RIGHT_BRACKET = 93 + KEY_GRAVE_ACCENT = 96 + + KEY_ESCAPE = 256 + KEY_ENTER = 257 + KEY_TAB = 258 + KEY_BACKSPACE = 259 + KEY_INSERT = 260 + KEY_DELETE = 261 + KEY_RIGHT = 262 + KEY_LEFT = 263 + KEY_DOWN = 264 + KEY_UP = 265 + KEY_PAGE_UP = 266 + KEY_PAGE_DOWN = 267 + KEY_HOME = 268 + KEY_END = 269 + KEY_CAPS_LOCK = 280 + KEY_SCROLL_LOCK = 281 + KEY_NUM_LOCK = 282 + KEY_PRINT_SCREEN = 283 + KEY_PAUSE = 284 + KEY_F1 = 290 + KEY_F2 = 291 + KEY_F3 = 292 + KEY_F4 = 293 + KEY_F5 = 294 + KEY_F6 = 295 + KEY_F7 = 296 + KEY_F8 = 297 + KEY_F9 = 298 + KEY_F10 = 299 + KEY_F11 = 300 + KEY_F12 = 301 + + KEY_LEFT_SHIFT = 340 + KEY_LEFT_CONTROL = 341 + KEY_LEFT_ALT = 342 + KEY_LEFT_SUPER = 343 + KEY_RIGHT_SHIFT = 344 + KEY_RIGHT_CONTROL = 345 + KEY_RIGHT_ALT = 346 + KEY_RIGHT_SUPER = 347 + KEY_MENU = 348 + + MOUSE_BUTTON_LEFT = 0 + MOUSE_BUTTON_RIGHT = 1 + MOUSE_BUTTON_MIDDLE = 2 + + +COMMAND_TO_KEY: dict[str, int] = { + "w": GLFW.KEY_W, + "forward": GLFW.KEY_W, + "s": GLFW.KEY_S, + "back": GLFW.KEY_S, + "a": GLFW.KEY_A, + "left": GLFW.KEY_A, + "d": GLFW.KEY_D, + "right": GLFW.KEY_D, + "space": GLFW.KEY_SPACE, + "jump": GLFW.KEY_SPACE, + "shift": GLFW.KEY_LEFT_SHIFT, + "sneak": GLFW.KEY_LEFT_SHIFT, + "ctrl": GLFW.KEY_LEFT_CONTROL, + "sprint": GLFW.KEY_LEFT_CONTROL, + "e": GLFW.KEY_E, + "inventory": GLFW.KEY_E, + "q": GLFW.KEY_Q, + "drop": GLFW.KEY_Q, + "f": GLFW.KEY_F, + "swap": GLFW.KEY_F, + "1": GLFW.KEY_1, + "hotbar1": GLFW.KEY_1, + "2": GLFW.KEY_2, + "hotbar2": GLFW.KEY_2, + "3": GLFW.KEY_3, + "hotbar3": GLFW.KEY_3, + "4": GLFW.KEY_4, + "hotbar4": GLFW.KEY_4, + "5": GLFW.KEY_5, + "hotbar5": GLFW.KEY_5, + "6": GLFW.KEY_6, + "hotbar6": GLFW.KEY_6, + "7": GLFW.KEY_7, + "hotbar7": GLFW.KEY_7, + "8": GLFW.KEY_8, + "hotbar8": GLFW.KEY_8, + "9": GLFW.KEY_9, + "hotbar9": GLFW.KEY_9, + "esc": GLFW.KEY_ESCAPE, + "escape": GLFW.KEY_ESCAPE, + "enter": GLFW.KEY_ENTER, + "tab": GLFW.KEY_TAB, + "t": GLFW.KEY_T, + "chat": GLFW.KEY_T, + "/": GLFW.KEY_SLASH, + "command": GLFW.KEY_SLASH, +} + + +@dataclass +class RawInput: + """ + Raw input data to send to Minecraft. + + Protocol format (variable size): + - 1 byte: numKeysPressed (0-255) + - N*2 bytes: keyCodes (shorts, big-endian) + - 4 bytes: mouseDeltaX (float, big-endian) + - 4 bytes: mouseDeltaY (float, big-endian) + - 1 byte: mouseButtons (bits: 0=left, 1=right, 2=middle) + - 4 bytes: scrollDelta (float, big-endian) + - 2 bytes: textLength (big-endian) + - M bytes: textBytes (UTF-8) + """ + + key_codes: list[int] = field(default_factory=list) + mouse_dx: float = 0.0 + mouse_dy: float = 0.0 + mouse_buttons: int = 0 + scroll_delta: float = 0.0 + text: str = "" + + def to_bytes(self) -> bytes: + """Serialize to binary protocol format.""" + data = bytearray() + + num_keys = len(self.key_codes) + if num_keys > 255: + raise ValueError(f"Too many keys pressed: {num_keys} (max 255)") + data.append(num_keys) + + for key_code in self.key_codes: + data.extend(struct.pack(">h", key_code)) + + data.extend(struct.pack(">f", self.mouse_dx)) + data.extend(struct.pack(">f", self.mouse_dy)) + + data.append(self.mouse_buttons & 0xFF) + + data.extend(struct.pack(">f", self.scroll_delta)) + + text_bytes = self.text.encode("utf-8") + text_length = len(text_bytes) + if text_length > 65535: + raise ValueError(f"Text too long: {text_length} bytes (max 65535)") + data.extend(struct.pack(">H", text_length)) + data.extend(text_bytes) + + return bytes(data) + + def set_left_mouse(self, pressed: bool) -> None: + if pressed: + self.mouse_buttons |= 1 << GLFW.MOUSE_BUTTON_LEFT + else: + self.mouse_buttons &= ~(1 << GLFW.MOUSE_BUTTON_LEFT) + + def set_right_mouse(self, pressed: bool) -> None: + if pressed: + self.mouse_buttons |= 1 << GLFW.MOUSE_BUTTON_RIGHT + else: + self.mouse_buttons &= ~(1 << GLFW.MOUSE_BUTTON_RIGHT) + + def set_middle_mouse(self, pressed: bool) -> None: + if pressed: + self.mouse_buttons |= 1 << GLFW.MOUSE_BUTTON_MIDDLE + else: + self.mouse_buttons &= ~(1 << GLFW.MOUSE_BUTTON_MIDDLE) + + @staticmethod + def release_all() -> "RawInput": + """Create an empty input that releases all keys and mouse buttons.""" + return RawInput() + + +@dataclass +class Observation: + """Observation received from the Minecraft mod.""" + + reward: float + frame: np.ndarray + + +def parse_observation( + header: bytes, frame_data: bytes, frame_shape: tuple[int, int] = (240, 320) +) -> Observation: + """ + Parse observation from raw bytes. + + Parameters + ---------- + header : bytes + 12 bytes: reward (double, 8 bytes) + frame length (uint32, 4 bytes) + frame_data : bytes + Raw RGB frame data (H*W*3 bytes, RGB order) + frame_shape : tuple[int, int] + (height, width) of the frame + + Returns + ------- + Observation + Parsed observation with reward and frame + """ + if len(header) != 12: + raise ValueError(f"Header must be 12 bytes, got {len(header)}") + + reward = struct.unpack(">d", header[0:8])[0] + frame_length = struct.unpack(">I", header[8:12])[0] + + if frame_length != len(frame_data): + raise ValueError( + f"Frame length mismatch: header says {frame_length}, got {len(frame_data)}" + ) + + height, width = frame_shape + if len(frame_data) != height * width * 3: + raise ValueError( + f"Frame data size mismatch: expected {height * width * 3}, got {len(frame_data)}" + ) + + frame = np.frombuffer(frame_data, dtype=np.uint8).reshape(height, width, 3) + + return Observation(reward=reward, frame=frame) diff --git a/mineagent/env.py b/mineagent/env.py index 1b8c90d..6cdeb7f 100644 --- a/mineagent/env.py +++ b/mineagent/env.py @@ -1,398 +1,144 @@ -""" -Minecraft Forge Client for Gymnasium API - -This module provides a client implementation that connects to the Minecraft Forge mod -and exposes a Gymnasium-compatible interface for reinforcement learning experiments. -The current implementation focuses on reading frame data only. -""" +import asyncio +from dataclasses import dataclass +from typing import Any -import socket -import numpy as np import gymnasium as gym +import numpy as np from gymnasium import spaces -from typing import Any -import time -import logging -from dataclasses import dataclass -from PIL import Image -from .config import Config - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) +from .client import AsyncMinecraftClient, ConnectionConfig, RawInput @dataclass -class ConnectionConfig: - """Configuration for Minecraft Forge mod connection""" - - command_port: str = "/tmp/mineagent_receive.sock" - data_port: str = "/tmp/mineagent_send.sock" - width: int = 320 - height: int = 240 - timeout: float = 30.0 - max_retries: int = 3 - retry_delay: float = 1.0 - - -class MinecraftForgeClient: - """ - Low-level client for communicating with the Minecraft Forge mod - Uses TCP for commands and UDP for frame data (asynchronous) - """ +class MinecraftEnvConfig: + """Configuration for the Minecraft environment.""" - def __init__(self, config: ConnectionConfig): - self.config = config - self.command_socket: socket.socket | None = None - self.data_socket: socket.socket | None = None - self.connected: bool = False - self.logger: logging.Logger = logging.getLogger(__name__) - - def connect(self) -> bool: - """ - Establish connection to the Minecraft Forge mod - - Returns - ------- - bool - True if connection successful, False otherwise - """ - for attempt in range(self.config.max_retries): - try: - # Connect Unix domain socket for commands - self.command_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.command_socket.settimeout(self.config.timeout) - self.command_socket.connect(self.config.command_port) - - # Create Unix domain socket for frame data - self.data_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.data_socket.settimeout(self.config.timeout) - - # Performance optimizations for large data transfers - # Set large receive buffer (1MB) for better throughput - self.data_socket.setsockopt( - socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 1024 - ) - self.data_socket.connect(self.config.data_port) - - self.connected = True - self.logger.info( - f"Connected to Minecraft Forge mod - Command: {self.config.command_port}, Data: {self.config.data_port}" - ) - return True - except (socket.error, ConnectionRefusedError) as e: - self.logger.warning(f"Connection attempt {attempt + 1} failed: {e}") - self._cleanup_sockets() - if attempt < self.config.max_retries - 1: - time.sleep(self.config.retry_delay) - else: - self.logger.error("Failed to connect after all retries") - return False - - return False - - def disconnect(self): - """Disconnect from the Minecraft Forge mod""" - self._cleanup_sockets() - self.connected = False - self.logger.info("Disconnected from Minecraft Forge mod") - - def _cleanup_sockets(self): - """Clean up both sockets""" - if self.command_socket: - try: - self.command_socket.close() - finally: - self.command_socket = None - - if self.data_socket: - try: - self.data_socket.close() - finally: - self.data_socket = None - - def send_command(self, command: str) -> bool: - """ - Send a command to the Minecraft Forge mod via TCP - - Parameters - ---------- - command : str - Command to send to the mod - - Returns - ------- - bool - True if command sent successfully, False otherwise - """ - if not self.connected or not self.command_socket: - self.logger.error("Not connected to Minecraft Forge mod") - return False - - try: - self.command_socket.send((command + "\n").encode()) - return True - except socket.error as e: - self.logger.error(f"Failed to send command: {e}") - self.connected = False - return False - - def receive_frame_data(self) -> np.ndarray | None: - """ - Receive frame data from the Minecraft Forge mod via Unix domain socket - - Returns - ------- - np.ndarray | None - Frame data as numpy array if successful, None otherwise - """ - if not self.connected or not self.data_socket: - self.logger.error("Not connected to Minecraft Forge mod") - return None - - try: - # Fastest approach: receive header + data in one shot if possible - # First get just the header to know the total size - header_data = self.data_socket.recv(8, socket.MSG_WAITALL) - if len(header_data) != 8: - self.logger.error( - "Failed to receive complete protocol header, got %d bytes", - len(header_data), - ) - return None - else: - self.logger.info("Received header: %s", len(header_data)) - - # Parse header - reward = int.from_bytes(header_data[0:4], byteorder="big", signed=True) - data_length = int.from_bytes( - header_data[4:8], byteorder="big", signed=False - ) - self.logger.info( - "Received reward: %d, data length: %d", reward, data_length - ) + frame_width: int = 320 + frame_height: int = 240 + max_steps: int = 10_000 - if data_length == 0: - return None - - # Receive all frame data in single call with MSG_WAITALL - frame_data = self.data_socket.recv(data_length, socket.MSG_WAITALL) - if len(frame_data) != data_length: - self.logger.error( - "Failed to receive complete frame data, got %d bytes, expected %d bytes", - len(frame_data), - data_length, - ) - return None - - # Parse frame data using delta encoding protocol - return self._parse_frame_data(frame_data) - - except socket.timeout: - # Timeout is expected if no frames are being sent - self.logger.warning("Timeout waiting for frame data") - return None - except socket.error as e: - self.logger.error(f"Failed to receive frame data: {e}") - self.connected = False - return None - - def _parse_frame_data(self, frame_data: bytes) -> np.ndarray | None: - """ - Parse frame data using delta encoding protocol - - Parameters - ---------- - frame_data : bytes - Raw frame data from socket - - Returns - ------- - np.ndarray | None - Parsed frame as numpy array if successful, None otherwise - """ - return np.flipud( - np.frombuffer(frame_data, dtype=np.uint8).reshape( - self.config.height, self.config.width, 3 - ) - ) - -class MinecraftEnv(gym.Env[np.ndarray, np.int64]): +class MinecraftEnv(gym.Env): """ - Gymnasium environment for Minecraft using the Forge mod + Gymnasium environment for Minecraft using the Forge mod. This environment provides a Gymnasium-compatible interface for interacting - with Minecraft through the custom Forge mod. Currently supports reading - frame data only. + with Minecraft through the custom Forge mod. """ metadata = {"render_modes": ["rgb_array"]} def __init__( self, - config: Config | None = None, + env_config: MinecraftEnvConfig | None = None, connection_config: ConnectionConfig | None = None, ): super().__init__() - self.config = config or Config() + self.env_config = env_config or MinecraftEnvConfig() self.connection_config = connection_config or ConnectionConfig() - self.client = MinecraftForgeClient(self.connection_config) - # Set up observation space (RGB image) - height, width = self.config.engine.image_size + self.connection_config.frame_width = self.env_config.frame_width + self.connection_config.frame_height = self.env_config.frame_height + + self._client = AsyncMinecraftClient(self.connection_config) + self._loop: asyncio.AbstractEventLoop | None = None + self.observation_space = spaces.Box( - low=0, high=255, shape=(height, width, 3), dtype=np.uint8 + low=0, + high=255, + shape=(self.env_config.frame_height, self.env_config.frame_width, 3), + dtype=np.uint8, ) - # For now, we don't support actions, but we need to define the space - # This will be expanded when action support is added - self.action_space = spaces.Discrete(1) # No-op action + self.action_space = spaces.Discrete(1) - self.step_count = 0 - self.logger = logging.getLogger(__name__) + self._step_count = 0 + self._last_reward: float = 0.0 + + def _ensure_loop(self) -> asyncio.AbstractEventLoop: + if self._loop is None or self._loop.is_closed(): + self._loop = asyncio.new_event_loop() + return self._loop + + def _run_async(self, coro): + loop = self._ensure_loop() + return loop.run_until_complete(coro) def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None ) -> tuple[np.ndarray, dict[str, Any]]: - """ - Reset the environment - - Parameters - ---------- - seed : int | None - Random seed for reproducibility - options : dict[str, Any] | None - Additional options for reset - - Returns - ------- - Tuple[np.ndarray, Dict[str, Any]] - Initial observation and info dictionary - """ super().reset(seed=seed, options=options) - # Connect to Minecraft if not already connected - if not self.client.connected: - if not self.client.connect(): + if not self._client.connected: + if not self._run_async(self._client.connect()): raise RuntimeError("Failed to connect to Minecraft Forge mod") - # Send reset command to mod (for future use) - self.client.send_command("RESET") + self._run_async(self._client.send_action(RawInput.release_all())) - # Get initial frame - current_frame = self.client.receive_frame_data() - if current_frame is None: - self.logger.warning("No frame data received") - # Return a black frame if we can't get data - height, width = self.connection_config.width, self.connection_config.height - current_frame = np.zeros((height, width, 3), dtype=np.uint8) + obs = self._run_async(self._client.receive_observation()) + if obs is None: + frame = np.zeros( + (self.env_config.frame_height, self.env_config.frame_width, 3), + dtype=np.uint8, + ) else: - self.logger.info(f"Received frame data: {current_frame.shape}") + frame = obs.frame + self._last_reward = obs.reward - self.step_count = 0 + self._step_count = 0 - return current_frame, {"step_count": self.step_count} + return frame, {"step_count": self._step_count, "reward": self._last_reward} - def step(self, action) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]: - """ - Execute one step in the environment + def step(self, action: int) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]: + obs = self._run_async(self._client.receive_observation()) - Parameters - ---------- - action - Action to take (currently ignored) - - Returns - ------- - Tuple[np.ndarray, float, bool, bool, Dict[str, Any]] - Observation, reward, terminated, truncated, info - """ - # For now, we ignore the action since we're only reading frame data - - # Get new frame data - new_frame = self.client.receive_frame_data() - if new_frame is not None: - current_frame = self._resize_frame(new_frame) + if obs is not None: + frame = obs.frame + reward = obs.reward else: - height, width = self.config.engine.image_size - current_frame = np.zeros((height, width, 3), dtype=np.uint8) + frame = np.zeros( + (self.env_config.frame_height, self.env_config.frame_width, 3), + dtype=np.uint8, + ) + reward = 0.0 - self.step_count += 1 + self._last_reward = reward + self._step_count += 1 - # Placeholder values for reward and termination - reward = 0.0 terminated = False - truncated = self.step_count >= self.config.engine.max_steps + truncated = self._step_count >= self.env_config.max_steps - info = {"step_count": self.step_count, "frame_received": new_frame is not None} + info = { + "step_count": self._step_count, + "reward": reward, + "frame_received": obs is not None, + } - return current_frame, reward, terminated, truncated, info + return frame, reward, terminated, truncated, info def render(self, mode: str = "rgb_array") -> np.ndarray | None: - """ - Parameters - ---------- - mode : str - Render mode - - Raises - ------ - NotImplementedError - Rendering is not supported. - """ raise NotImplementedError("Rendering is not supported.") def close(self): - """Close the environment and disconnect from Minecraft""" - self.client.disconnect() - - def _resize_frame(self, frame: np.ndarray) -> np.ndarray: - """ - Resize frame to match expected dimensions - - Parameters - ---------- - frame : np.ndarray - Input frame - - Returns - ------- - np.ndarray - Resized frame - """ - target_height, target_width = self.config.engine.image_size - - if frame.shape[:2] != (target_height, target_width): - # Use PIL for high-quality resizing - pil_image = Image.fromarray(frame) - try: - # Try new PIL API first - from PIL.Image import Resampling - - pil_image = pil_image.resize( - (target_width, target_height), Resampling.LANCZOS - ) - except (ImportError, AttributeError): - # Fall back to old PIL API using numeric constant (1 = LANCZOS) - pil_image = pil_image.resize((target_width, target_height), 1) - frame = np.array(pil_image) - - return frame + if self._client.connected: + self._run_async(self._client.disconnect()) + if self._loop and not self._loop.is_closed(): + self._loop.close() + self._loop = None def create_minecraft_env( - config: Config | None = None, connection_config: ConnectionConfig | None = None + env_config: MinecraftEnvConfig | None = None, + connection_config: ConnectionConfig | None = None, ) -> MinecraftEnv: """ - Factory function to create a Minecraft environment + Factory function to create a Minecraft environment. Parameters ---------- - config : Config | None - MineAgent configuration object + env_config : MinecraftEnvConfig | None + Environment configuration connection_config : ConnectionConfig | None Connection configuration for the Minecraft mod @@ -401,4 +147,4 @@ def create_minecraft_env( MinecraftEnv Configured Minecraft environment """ - return MinecraftEnv(config=config, connection_config=connection_config) + return MinecraftEnv(env_config=env_config, connection_config=connection_config) diff --git a/pixi.lock b/pixi.lock index f891b59..54510d8 100644 --- a/pixi.lock +++ b/pixi.lock @@ -18,10 +18,10 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/c-ares-1.34.6-hb03c661_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/ca-certificates-2026.2.25-hbd8a1cb_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/cairo-1.18.4-he90730b_1.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/cpython-3.12.12-py312hd8ed1ab_2.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/cpython-3.12.13-py312hd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/dav1d-1.2.1-hd590300_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/editables-0.5-pyhd8ed1ab_1.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/filelock-3.24.3-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/filelock-3.25.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/fmt-12.1.0-hff5e90c_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed37_0.tar.bz2 @@ -48,7 +48,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.45.1-default_hbd61a6d_101.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/lerc-4.0.0-h0aef613_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libabseil-20260107.1-cxx17_h7b12aa8_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libavif16-1.3.0-h316e467_3.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libavif16-1.4.0-hcfa2d63_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libblas-3.11.0-5_h5875eb1_mkl.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.11.0-5_hfef963f_mkl.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libcups-2.3.3-h7a8fb5f_6.conda @@ -57,8 +57,8 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libedit-3.1.20250104-pl5321h7949ede_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.7.4-hecca717_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libffi-3.5.2-h3435931_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype-2.14.1-ha770c72_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.1-h73754d4_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype-2.14.2-ha770c72_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.2-h73754d4_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libgcc-15.2.0-he0feb66_18.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-15.2.0-h69a702a_18.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libgfortran-15.2.0-h69a702a_18.conda @@ -75,7 +75,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.55-h421ea60_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libprotobuf-6.33.5-h2b00c02_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libre2-11-2025.11.05-h0dc7533_1.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.51.2-hf4e2dac_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.52.0-hf4e2dac_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.2.0-h934c35e_18.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-15.2.0-hdf11a46_18.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.7.1-h9d88235_1.conda @@ -85,12 +85,12 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.6.0-hd42ef1d_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.17.0-h8a09558_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-16-2.15.1-hca6bf5a_1.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.15.1-he237659_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-16-2.15.2-hca6bf5a_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.15.2-he237659_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-22.1.0-h4922eb0_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/markdown-3.10.2-pyhcf101f3_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/markupsafe-3.0.3-py312h8a5da7c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/markupsafe-3.0.3-py312h8a5da7c_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/mkl-2025.3.0-h0e700b2_463.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/mpc-1.3.1-h24ddda3_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/mpfr-4.2.1-h90cbb55_3.conda @@ -113,16 +113,16 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/pybind11-3.0.1-pyh7a1b43c_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pybind11-abi-11-hc364b38_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pybind11-global-3.0.1-pyhc7ab6ef_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.12.12-hd63d673_2_cpython.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.12.13-hd63d673_0_cpython.conda - conda: https://conda.anaconda.org/conda-forge/noarch/python_abi-3.12-8_cp312.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pytorch-2.10.0-cpu_mkl_py312_hca44ed5_103.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/rav1e-0.7.1-h8fae777_3.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/rav1e-0.8.1-h1fbca29_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/re2-2025.11.05-h5301d42_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.3-h853b02a_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/scipy-1.17.1-py312h54fa4ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/setuptools-82.0.0-pyh332efcf_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/sleef-3.9.0-ha0421bc_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/svt-av1-4.0.0-hecca717_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/svt-av1-4.0.1-hecca717_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/sympy-1.14.0-pyh2585a3b_106.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tbb-2022.3.0-hb700be7_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tensorboard-2.20.0-pyhe01879c_0.conda @@ -179,10 +179,10 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/c-ares-1.34.6-hb03c661_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/ca-certificates-2026.2.25-hbd8a1cb_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/cairo-1.18.4-he90730b_1.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/cpython-3.12.12-py312hd8ed1ab_2.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/cpython-3.12.13-py312hd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/dav1d-1.2.1-hd590300_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/editables-0.5-pyhd8ed1ab_1.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/filelock-3.24.3-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/filelock-3.25.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/fmt-12.1.0-hff5e90c_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed37_0.tar.bz2 @@ -209,7 +209,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.45.1-default_hbd61a6d_101.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/lerc-4.0.0-h0aef613_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libabseil-20260107.1-cxx17_h7b12aa8_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libavif16-1.3.0-h316e467_3.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libavif16-1.4.0-hcfa2d63_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libblas-3.11.0-5_h5875eb1_mkl.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libbrotlicommon-1.2.0-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libbrotlidec-1.2.0-hb03c661_1.conda @@ -222,8 +222,8 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libev-4.33-hd590300_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.7.4-hecca717_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libffi-3.5.2-h3435931_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype-2.14.1-ha770c72_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.1-h73754d4_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype-2.14.2-ha770c72_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.2-h73754d4_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libgcc-15.2.0-he0feb66_18.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-15.2.0-h69a702a_18.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libgfortran-15.2.0-h69a702a_18.conda @@ -241,7 +241,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.55-h421ea60_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libprotobuf-6.33.5-h2b00c02_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libre2-11-2025.11.05-h0dc7533_1.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.51.2-hf4e2dac_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.52.0-hf4e2dac_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.2.0-h934c35e_18.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-15.2.0-hdf11a46_18.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.7.1-h9d88235_1.conda @@ -251,12 +251,12 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.6.0-hd42ef1d_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.17.0-h8a09558_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-16-2.15.1-hca6bf5a_1.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.15.1-he237659_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-16-2.15.2-hca6bf5a_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.15.2-he237659_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-22.1.0-h4922eb0_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/markdown-3.10.2-pyhcf101f3_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/markupsafe-3.0.3-py312h8a5da7c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/markupsafe-3.0.3-py312h8a5da7c_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/mkl-2025.3.0-h0e700b2_463.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/mpc-1.3.1-h24ddda3_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/mpfr-4.2.1-h90cbb55_3.conda @@ -280,16 +280,16 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/pybind11-3.0.1-pyh7a1b43c_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pybind11-abi-11-hc364b38_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pybind11-global-3.0.1-pyhc7ab6ef_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.12.12-hd63d673_2_cpython.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.12.13-hd63d673_0_cpython.conda - conda: https://conda.anaconda.org/conda-forge/noarch/python_abi-3.12-8_cp312.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pytorch-2.10.0-cpu_mkl_py312_hca44ed5_103.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/rav1e-0.7.1-h8fae777_3.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/rav1e-0.8.1-h1fbca29_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/re2-2025.11.05-h5301d42_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.3-h853b02a_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/scipy-1.17.1-py312h54fa4ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/setuptools-82.0.0-pyh332efcf_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/sleef-3.9.0-ha0421bc_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/svt-av1-4.0.0-hecca717_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/svt-av1-4.0.1-hecca717_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/sympy-1.14.0-pyh2585a3b_106.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tbb-2022.3.0-hb700be7_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tensorboard-2.20.0-pyhe01879c_0.conda @@ -328,22 +328,23 @@ environments: - pypi: https://files.pythonhosted.org/packages/05/2c/ffc08c54c05cdce6fbed2aeebc46348dbe180c6d2c541c7af7ba0aa5f5f8/Farama_Notifications-0.0.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/37/82dbef0f6342eb01f54bca073ac1498433d6ce71e50c3c3282b655733b31/fonttools-4.61.1-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl - pypi: https://files.pythonhosted.org/packages/56/d3/ea5f088e3638dbab12e5c20d6559d5b3bdaeaa1f2af74e526e6815836285/gymnasium-1.2.3-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/b8/58/40fbbcefeda82364720eba5cf2270f98496bdfa19ea75b4cccae79c698e6/identify-2.6.16-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/40/66/71c1227dff78aaeb942fed29dd5651f2aec166cc7c9aeea3e8b26a539b7d/identify-2.6.17-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/90/6d240beb0f24b74371762873e9b7f499f1e02166a2d9c5801f4dbf8fa12e/kiwisolver-1.4.9-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/3e/f3/c5195b1ae57ef85339fd7285dfb603b22c8b4e79114bae5f4f0fcf688677/matplotlib-3.10.8-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/48/31/05e764397056194206169869b50cf2fee4dbbbc71b344705b9c0d878d4d8/platformdirs-4.9.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/63/d7/97f7e3a6abb67d8080dd406fd4df842c2be0efaf712d1c899c32a075027c/platformdirs-4.9.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5d/19/fd3ef348460c80af7bb4669ea7926651d1f95c23ff2df18b9d24bab4f3fa/pre_commit-4.5.1-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/10/bd/c038d7cc38edc1aa5bf91ab8068b63d4308c66c4c8bb3cbba7dfbc049f9c/pyparsing-3.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0c/82/a2c93e32800940d9573fb28c346772a14778b84ba7524e691b324620ab89/pyright-1.1.408-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/06/54/82a6e2ef37f0f23dccac604b9585bdcbd0698604feb64807dcb72853693e/python_discovery-1.1.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/75/0f/2bf7e3b5a4a65f623cb820feb5793e243fad58ae561015ee15a6152f67a2/python_discovery-1.1.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8b/9d/b3589d3877982d4f2329302ef98a8026e7f4443c765c46cfecc8858c6b4b/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/ff/90/bf134f4c1e5243e62690e09d63c55df948a74084c8ac3e48a88468314da6/ruff-0.15.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/d3/01/a10fe54b653061585e655f5286c2662ebddb68831ed3eaebfb0eb08c0a16/ruff-0.15.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/16/e1/3079a9ff9b8e11b846c6ac5c8b5bfb7ff225eee721825310c91b3b50304f/tqdm-4.67.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/bd/e0/1eed384f02555dde685fff1a1ac805c1c7dcb6dd019c916fe659b1c1f9ec/types_pyyaml-6.0.12.20250915-py3-none-any.whl @@ -487,17 +488,17 @@ packages: - pytest-xdist ; extra == 'test-no-images' - wurlitzer ; extra == 'test-no-images' requires_python: '>=3.11' -- conda: https://conda.anaconda.org/conda-forge/noarch/cpython-3.12.12-py312hd8ed1ab_2.conda +- conda: https://conda.anaconda.org/conda-forge/noarch/cpython-3.12.13-py312hd8ed1ab_0.conda noarch: generic - sha256: ccb90d95bac9f1f4f6629a4addb44d36433e4ad1fe4ac87a864f90ff305dbf6d - md5: ef3e093ecfd4533eee992cdaa155b47e + sha256: d3e9bbd7340199527f28bbacf947702368f31de60c433a16446767d3c6aaf6fe + md5: f54c1ffb8ecedb85a8b7fcde3a187212 depends: - python >=3.12,<3.13.0a0 - python_abi * *_cp312 license: Python-2.0 purls: [] - size: 46644 - timestamp: 1769471040321 + size: 46463 + timestamp: 1772728929620 - pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl name: cycler version: 0.12.1 @@ -555,16 +556,16 @@ packages: name: farama-notifications version: 0.0.4 sha256: 14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae -- conda: https://conda.anaconda.org/conda-forge/noarch/filelock-3.24.3-pyhd8ed1ab_0.conda - sha256: 6d576ed3bd0e7c57b1144f0b2014de9ea3fab9786316bc3e748105d44e0140a0 - md5: 9dbb20eec24beb026291c20a35ce1ff9 +- conda: https://conda.anaconda.org/conda-forge/noarch/filelock-3.25.0-pyhd8ed1ab_0.conda + sha256: 55162ec0ff4e22d8f762b3e774f08f79f234ba37ab5e64d7a1421ed9216a222c + md5: 49a92015e912176999ae81bea11ea778 depends: - python >=3.10 license: Unlicense purls: - pkg:pypi/filelock?source=compressed-mapping - size: 24808 - timestamp: 1771468713029 + size: 25656 + timestamp: 1772380968183 - conda: https://conda.anaconda.org/conda-forge/linux-64/fmt-12.1.0-hff5e90c_0.conda sha256: d4e92ba7a7b4965341dc0fca57ec72d01d111b53c12d11396473115585a9ead6 md5: f7d7a4104082b39e3b3473fbd4a38229 @@ -885,10 +886,10 @@ packages: purls: [] size: 12728445 timestamp: 1767969922681 -- pypi: https://files.pythonhosted.org/packages/b8/58/40fbbcefeda82364720eba5cf2270f98496bdfa19ea75b4cccae79c698e6/identify-2.6.16-py2.py3-none-any.whl +- pypi: https://files.pythonhosted.org/packages/40/66/71c1227dff78aaeb942fed29dd5651f2aec166cc7c9aeea3e8b26a539b7d/identify-2.6.17-py2.py3-none-any.whl name: identify - version: 2.6.16 - sha256: 391ee4d77741d994189522896270b787aed8670389bfd60f326d677d64a6dfb0 + version: 2.6.17 + sha256: be5f8412d5ed4b20f2bd41a65f920990bdccaa6a4a18a08f1eefdcd0bdd885f0 requires_dist: - ukkonen ; extra == 'license' requires_python: '>=3.10' @@ -1007,21 +1008,21 @@ packages: purls: [] size: 1384817 timestamp: 1770863194876 -- conda: https://conda.anaconda.org/conda-forge/linux-64/libavif16-1.3.0-h316e467_3.conda - sha256: f5ab201b8b4e1f776ced0340c59f87e441fd6763d3face527b5cf3f2280502c9 - md5: 22d5cc5fb45aab8ed3c00cde2938b825 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libavif16-1.4.0-hcfa2d63_0.conda + sha256: 918fd09af66968361c8fa40a76f864b7febb8286dd5dcb1419517b9db950c84c + md5: e226d3dbe1e2482fd8e15cb924fd1e7c depends: - __glibc >=2.17,<3.0.a0 - aom >=3.9.1,<3.10.0a0 - dav1d >=1.2.1,<1.2.2.0a0 - libgcc >=14 - - rav1e >=0.7.1,<0.8.0a0 - - svt-av1 >=4.0.0,<4.0.1.0a0 + - rav1e >=0.8.1,<0.9.0a0 + - svt-av1 >=4.0.1,<4.0.2.0a0 license: BSD-2-Clause license_family: BSD purls: [] - size: 140323 - timestamp: 1769476997956 + size: 148589 + timestamp: 1772682433596 - conda: https://conda.anaconda.org/conda-forge/linux-64/libblas-3.11.0-5_h5875eb1_mkl.conda build_number: 5 sha256: 328d64d4eb51047c39a8039a30eb47695855829d0a11b72d932171cb1dcdfad3 @@ -1176,29 +1177,29 @@ packages: purls: [] size: 58592 timestamp: 1769456073053 -- conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype-2.14.1-ha770c72_0.conda - sha256: 4641d37faeb97cf8a121efafd6afd040904d4bca8c46798122f417c31d5dfbec - md5: f4084e4e6577797150f9b04a4560ceb0 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype-2.14.2-ha770c72_0.conda + sha256: 2e1bfe1e856eb707d258f669ef6851af583ceaffab5e64821b503b0f7cd09e9e + md5: 26c746d14402a3b6c684d045b23b9437 depends: - - libfreetype6 >=2.14.1 + - libfreetype6 >=2.14.2 license: GPL-2.0-only OR FTL purls: [] - size: 7664 - timestamp: 1757945417134 -- conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.1-h73754d4_0.conda - sha256: 4a7af818a3179fafb6c91111752954e29d3a2a950259c14a2fc7ba40a8b03652 - md5: 8e7251989bca326a28f4a5ffbd74557a + size: 8035 + timestamp: 1772757210108 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.2-h73754d4_0.conda + sha256: aba65b94bdbed52de17ec3d0c6f2ebac2ef77071ad22d6900d1614d0dd702a0c + md5: 8eaba3d1a4d7525c6814e861614457fd depends: - __glibc >=2.17,<3.0.a0 - libgcc >=14 - - libpng >=1.6.50,<1.7.0a0 + - libpng >=1.6.55,<1.7.0a0 - libzlib >=1.3.1,<2.0a0 constrains: - - freetype >=2.14.1 + - freetype >=2.14.2 license: GPL-2.0-only OR FTL purls: [] - size: 386739 - timestamp: 1757945416744 + size: 386316 + timestamp: 1772757193822 - conda: https://conda.anaconda.org/conda-forge/linux-64/libgcc-15.2.0-he0feb66_18.conda sha256: faf7d2017b4d718951e3a59d081eb09759152f93038479b768e3d612688f83f5 md5: 0aa00f03f9e39fb9876085dee11a85d4 @@ -1438,9 +1439,9 @@ packages: purls: [] size: 213122 timestamp: 1768190028309 -- conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.51.2-hf4e2dac_0.conda - sha256: 04596fcee262a870e4b7c9807224680ff48d4d0cc0dac076a602503d3dc6d217 - md5: da5be73701eecd0e8454423fd6ffcf30 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.52.0-hf4e2dac_0.conda + sha256: d716847b7deca293d2e49ed1c8ab9e4b9e04b9d780aea49a97c26925b28a7993 + md5: fd893f6a3002a635b5e50ceb9dd2c0f4 depends: - __glibc >=2.17,<3.0.a0 - icu >=78.2,<79.0a0 @@ -1448,8 +1449,8 @@ packages: - libzlib >=1.3.1,<2.0a0 license: blessing purls: [] - size: 942808 - timestamp: 1768147973361 + size: 951405 + timestamp: 1772818874251 - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.2.0-h934c35e_18.conda sha256: 78668020064fdaa27e9ab65cd2997e2c837b564ab26ce3bf0e58a2ce1a525c6e md5: 1b08cd684f34175e4514474793d44bcb @@ -1579,39 +1580,39 @@ packages: purls: [] size: 100393 timestamp: 1702724383534 -- conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.15.1-he237659_1.conda - sha256: 047be059033c394bd32ae5de66ce389824352120b3a7c0eff980195f7ed80357 - md5: 417955234eccd8f252b86a265ccdab7f +- conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.15.2-he237659_0.conda + sha256: 275c324f87bda1a3b67d2f4fcc3555eeff9e228a37655aa001284a7ceb6b0392 + md5: e49238a1609f9a4a844b09d9926f2c3d depends: - __glibc >=2.17,<3.0.a0 - - icu >=78.1,<79.0a0 + - icu >=78.2,<79.0a0 - libgcc >=14 - libiconv >=1.18,<2.0a0 - - liblzma >=5.8.1,<6.0a0 - - libxml2-16 2.15.1 hca6bf5a_1 + - liblzma >=5.8.2,<6.0a0 + - libxml2-16 2.15.2 hca6bf5a_0 - libzlib >=1.3.1,<2.0a0 license: MIT license_family: MIT purls: [] - size: 45402 - timestamp: 1766327161688 -- conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-16-2.15.1-hca6bf5a_1.conda - sha256: 8331284bf9ae641b70cdc0e5866502dd80055fc3b9350979c74bb1d192e8e09e - md5: 3fdd8d99683da9fe279c2f4cecd1e048 + size: 45968 + timestamp: 1772704614539 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-16-2.15.2-hca6bf5a_0.conda + sha256: 08d2b34b49bec9613784f868209bb7c3bb8840d6cf835ff692e036b09745188c + md5: f3bc152cb4f86babe30f3a4bf0dbef69 depends: - __glibc >=2.17,<3.0.a0 - - icu >=78.1,<79.0a0 + - icu >=78.2,<79.0a0 - libgcc >=14 - libiconv >=1.18,<2.0a0 - - liblzma >=5.8.1,<6.0a0 + - liblzma >=5.8.2,<6.0a0 - libzlib >=1.3.1,<2.0a0 constrains: - - libxml2 2.15.1 + - libxml2 2.15.2 license: MIT license_family: MIT purls: [] - size: 555747 - timestamp: 1766327145986 + size: 557492 + timestamp: 1772704601644 - conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda sha256: d4bfe88d7cb447768e31650f06257995601f89076080e76df55e3112d4e47dc4 md5: edb0dca6bc32e4f4789199455a1dbeb8 @@ -1651,9 +1652,9 @@ packages: - pkg:pypi/markdown?source=compressed-mapping size: 85893 timestamp: 1770694658918 -- conda: https://conda.anaconda.org/conda-forge/linux-64/markupsafe-3.0.3-py312h8a5da7c_0.conda - sha256: f77f9f1a4da45cbc8792d16b41b6f169f649651a68afdc10b2da9da12b9aa42b - md5: f775a43412f7f3d7ed218113ad233869 +- conda: https://conda.anaconda.org/conda-forge/linux-64/markupsafe-3.0.3-py312h8a5da7c_1.conda + sha256: 5f3aad1f3a685ed0b591faad335957dbdb1b73abfd6fc731a0d42718e0653b33 + md5: 93a4752d42b12943a355b682ee43285b depends: - __glibc >=2.17,<3.0.a0 - libgcc >=14 @@ -1664,9 +1665,9 @@ packages: license: BSD-3-Clause license_family: BSD purls: - - pkg:pypi/markupsafe?source=hash-mapping - size: 25321 - timestamp: 1759055268795 + - pkg:pypi/markupsafe?source=compressed-mapping + size: 26057 + timestamp: 1772445297924 - pypi: https://files.pythonhosted.org/packages/3e/f3/c5195b1ae57ef85339fd7285dfb603b22c8b4e79114bae5f4f0fcf688677/matplotlib-3.10.8-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl name: matplotlib version: 3.10.8 @@ -1689,7 +1690,7 @@ packages: - pypi: ./ name: mineagent version: 0.0.1 - sha256: ae8bfca4b4c60d8874099937409006912ba77d6ef257a029af735f7b3ab9786b + sha256: 753be6a849db87c6880ff10b3da76be9a4c1513b8e7c10dceeb47b8eab4db221 requires_dist: - pyyaml - dacite @@ -1703,6 +1704,7 @@ packages: - pyright - pre-commit ; extra == 'dev' - pytest ; extra == 'dev' + - pytest-asyncio ; extra == 'dev' - pytest-mock ; extra == 'dev' - ruff ; extra == 'dev' - types-pyyaml ; extra == 'dev' @@ -1987,10 +1989,10 @@ packages: purls: [] size: 450960 timestamp: 1754665235234 -- pypi: https://files.pythonhosted.org/packages/48/31/05e764397056194206169869b50cf2fee4dbbbc71b344705b9c0d878d4d8/platformdirs-4.9.2-py3-none-any.whl +- pypi: https://files.pythonhosted.org/packages/63/d7/97f7e3a6abb67d8080dd406fd4df842c2be0efaf712d1c899c32a075027c/platformdirs-4.9.4-py3-none-any.whl name: platformdirs - version: 4.9.2 - sha256: 9170634f126f8efdae22fb58ae8a0eaa86f38365bc57897a6c4f781d1f5875bd + version: 4.9.4 + sha256: 68a9a4619a666ea6439f2ff250c12a853cd1cbd5158d258bd824a7df6be2f868 requires_python: '>=3.10' - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda sha256: e14aafa63efa0528ca99ba568eaf506eb55a0371d12e6250aaaa61718d2eb62e @@ -2131,6 +2133,19 @@ packages: - setuptools ; extra == 'dev' - xmlschema ; extra == 'dev' requires_python: '>=3.10' +- pypi: https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl + name: pytest-asyncio + version: 1.3.0 + sha256: 611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5 + requires_dist: + - backports-asyncio-runner>=1.1,<2 ; python_full_version < '3.11' + - pytest>=8.2,<10 + - typing-extensions>=4.12 ; python_full_version < '3.13' + - sphinx>=5.3 ; extra == 'docs' + - sphinx-rtd-theme>=1 ; extra == 'docs' + - coverage>=6.2 ; extra == 'testing' + - hypothesis>=5.7.1 ; extra == 'testing' + requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl name: pytest-mock version: 3.15.1 @@ -2141,15 +2156,14 @@ packages: - pytest-asyncio ; extra == 'dev' - tox ; extra == 'dev' requires_python: '>=3.9' -- conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.12.12-hd63d673_2_cpython.conda - build_number: 2 - sha256: 6621befd6570a216ba94bc34ec4618e4f3777de55ad0adc15fc23c28fadd4d1a - md5: c4540d3de3fa228d9fa95e31f8e97f89 +- conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.12.13-hd63d673_0_cpython.conda + sha256: a44655c1c3e1d43ed8704890a91e12afd68130414ea2c0872e154e5633a13d7e + md5: 7eccb41177e15cc672e1babe9056018e depends: - __glibc >=2.17,<3.0.a0 - bzip2 >=1.0.8,<2.0a0 - ld_impl_linux-64 >=2.36.1 - - libexpat >=2.7.3,<3.0a0 + - libexpat >=2.7.4,<3.0a0 - libffi >=3.5.2,<3.6.0a0 - libgcc >=14 - liblzma >=5.8.2,<6.0a0 @@ -2159,7 +2173,7 @@ packages: - libxcrypt >=4.4.36 - libzlib >=1.3.1,<2.0a0 - ncurses >=6.5,<7.0a0 - - openssl >=3.5.4,<4.0a0 + - openssl >=3.5.5,<4.0a0 - readline >=8.3,<9.0a0 - tk >=8.6.13,<8.7.0a0 - tzdata @@ -2167,8 +2181,8 @@ packages: - python_abi 3.12.* *_cp312 license: Python-2.0 purls: [] - size: 31457785 - timestamp: 1769472855343 + size: 31608571 + timestamp: 1772730708989 - pypi: https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl name: python-dateutil version: 2.9.0.post0 @@ -2176,10 +2190,10 @@ packages: requires_dist: - six>=1.5 requires_python: '>=2.7,!=3.0.*,!=3.1.*,!=3.2.*' -- pypi: https://files.pythonhosted.org/packages/06/54/82a6e2ef37f0f23dccac604b9585bdcbd0698604feb64807dcb72853693e/python_discovery-1.1.0-py3-none-any.whl +- pypi: https://files.pythonhosted.org/packages/75/0f/2bf7e3b5a4a65f623cb820feb5793e243fad58ae561015ee15a6152f67a2/python_discovery-1.1.1-py3-none-any.whl name: python-discovery - version: 1.1.0 - sha256: a162893b8809727f54594a99ad2179d2ede4bf953e12d4c7abc3cc9cdbd1437b + version: 1.1.1 + sha256: 69f11073fa2392251e405d4e847d60ffffd25fd762a0dc4d1a7d6b9c3f79f1a3 requires_dist: - filelock>=3.15.4 - platformdirs>=4.3.6,<5 @@ -2253,19 +2267,19 @@ packages: version: 6.0.3 sha256: ba1cc08a7ccde2d2ec775841541641e4548226580ab850948cbfda66a1befcdc requires_python: '>=3.8' -- conda: https://conda.anaconda.org/conda-forge/linux-64/rav1e-0.7.1-h8fae777_3.conda - sha256: 6e5e704c1c21f820d760e56082b276deaf2b53cf9b751772761c3088a365f6f4 - md5: 2c42649888aac645608191ffdc80d13a +- conda: https://conda.anaconda.org/conda-forge/linux-64/rav1e-0.8.1-h1fbca29_0.conda + sha256: cf550bbc8e5ebedb6dba9ccaead3e07bd1cb86b183644a4c853e06e4b3ad5ac7 + md5: d83958768626b3c8471ce032e28afcd3 depends: - __glibc >=2.17,<3.0.a0 - - libgcc >=13 + - libgcc >=14 constrains: - __glibc >=2.17 license: BSD-2-Clause license_family: BSD purls: [] - size: 5176669 - timestamp: 1746622023242 + size: 5595970 + timestamp: 1772540833621 - conda: https://conda.anaconda.org/conda-forge/linux-64/re2-2025.11.05-h5301d42_1.conda sha256: 3fc684b81631348540e9a42f6768b871dfeab532d3f47d5c341f1f83e2a2b2b2 md5: 66a715bc01c77d43aca1f9fcb13dde3c @@ -2288,10 +2302,10 @@ packages: purls: [] size: 345073 timestamp: 1765813471974 -- pypi: https://files.pythonhosted.org/packages/ff/90/bf134f4c1e5243e62690e09d63c55df948a74084c8ac3e48a88468314da6/ruff-0.15.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +- pypi: https://files.pythonhosted.org/packages/d3/01/a10fe54b653061585e655f5286c2662ebddb68831ed3eaebfb0eb08c0a16/ruff-0.15.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl name: ruff - version: 0.15.4 - sha256: 451a2e224151729b3b6c9ffb36aed9091b2996fe4bdbd11f47e27d8f2e8888ec + version: 0.15.5 + sha256: c1cb7169f53c1ddb06e71a9aebd7e98fc0fea936b39afb36d8e86d36ecc2636a requires_python: '>=3.7' - conda: https://conda.anaconda.org/conda-forge/linux-64/scipy-1.17.1-py312h54fa4ab_0.conda sha256: e3ad577361d67f6c078a6a7a3898bf0617b937d44dc4ccd57aa3336f2b5778dd @@ -2344,9 +2358,9 @@ packages: purls: [] size: 1951720 timestamp: 1756274576844 -- conda: https://conda.anaconda.org/conda-forge/linux-64/svt-av1-4.0.0-hecca717_0.conda - sha256: e5e036728ef71606569232cc94a0480722e14ed69da3dd1e363f3d5191d83c01 - md5: 9a6117aee038999ffefe6082ff1e9a81 +- conda: https://conda.anaconda.org/conda-forge/linux-64/svt-av1-4.0.1-hecca717_0.conda + sha256: 4a1d2005153b9454fc21c9bad1b539df189905be49e851ec62a6212c2e045381 + md5: 2a2170a3e5c9a354d09e4be718c43235 depends: - __glibc >=2.17,<3.0.a0 - libgcc >=14 @@ -2354,8 +2368,8 @@ packages: license: BSD-2-Clause license_family: BSD purls: [] - size: 2620937 - timestamp: 1769280649780 + size: 2619743 + timestamp: 1769664536467 - conda: https://conda.anaconda.org/conda-forge/noarch/sympy-1.14.0-pyh2585a3b_106.conda sha256: 1c8057e6875eba958aa8b3c1a072dc9a75d013f209c26fd8125a5ebd3abbec0c md5: 32d866e43b25275f61566b9391ccb7b5 diff --git a/pyproject.toml b/pyproject.toml index c388490..b5aa283 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ mineagent = "mineagent.engine:run" dev = [ "pre-commit", "pytest", + "pytest-asyncio", "pytest-mock", "ruff", "types-PyYAML", diff --git a/scripts/test_action_client.py b/scripts/test_action_client.py index 2d9dc0f..c114547 100644 --- a/scripts/test_action_client.py +++ b/scripts/test_action_client.py @@ -10,264 +10,10 @@ import struct import time import threading -from dataclasses import dataclass, field from typing import Optional import argparse - -# ============================================================================= -# GLFW Key Code Constants -# Reference: https://www.glfw.org/docs/3.3/group__keys.html -# ============================================================================= - - -class GLFW: - """GLFW key and mouse button constants.""" - - # Printable keys - KEY_SPACE = 32 - KEY_APOSTROPHE = 39 - KEY_COMMA = 44 - KEY_MINUS = 45 - KEY_PERIOD = 46 - KEY_SLASH = 47 - KEY_0 = 48 - KEY_1 = 49 - KEY_2 = 50 - KEY_3 = 51 - KEY_4 = 52 - KEY_5 = 53 - KEY_6 = 54 - KEY_7 = 55 - KEY_8 = 56 - KEY_9 = 57 - KEY_SEMICOLON = 59 - KEY_EQUAL = 61 - KEY_A = 65 - KEY_B = 66 - KEY_C = 67 - KEY_D = 68 - KEY_E = 69 - KEY_F = 70 - KEY_G = 71 - KEY_H = 72 - KEY_I = 73 - KEY_J = 74 - KEY_K = 75 - KEY_L = 76 - KEY_M = 77 - KEY_N = 78 - KEY_O = 79 - KEY_P = 80 - KEY_Q = 81 - KEY_R = 82 - KEY_S = 83 - KEY_T = 84 - KEY_U = 85 - KEY_V = 86 - KEY_W = 87 - KEY_X = 88 - KEY_Y = 89 - KEY_Z = 90 - KEY_LEFT_BRACKET = 91 - KEY_BACKSLASH = 92 - KEY_RIGHT_BRACKET = 93 - KEY_GRAVE_ACCENT = 96 - - # Function keys - KEY_ESCAPE = 256 - KEY_ENTER = 257 - KEY_TAB = 258 - KEY_BACKSPACE = 259 - KEY_INSERT = 260 - KEY_DELETE = 261 - KEY_RIGHT = 262 - KEY_LEFT = 263 - KEY_DOWN = 264 - KEY_UP = 265 - KEY_PAGE_UP = 266 - KEY_PAGE_DOWN = 267 - KEY_HOME = 268 - KEY_END = 269 - KEY_CAPS_LOCK = 280 - KEY_SCROLL_LOCK = 281 - KEY_NUM_LOCK = 282 - KEY_PRINT_SCREEN = 283 - KEY_PAUSE = 284 - KEY_F1 = 290 - KEY_F2 = 291 - KEY_F3 = 292 - KEY_F4 = 293 - KEY_F5 = 294 - KEY_F6 = 295 - KEY_F7 = 296 - KEY_F8 = 297 - KEY_F9 = 298 - KEY_F10 = 299 - KEY_F11 = 300 - KEY_F12 = 301 - - # Modifier keys - KEY_LEFT_SHIFT = 340 - KEY_LEFT_CONTROL = 341 - KEY_LEFT_ALT = 342 - KEY_LEFT_SUPER = 343 - KEY_RIGHT_SHIFT = 344 - KEY_RIGHT_CONTROL = 345 - KEY_RIGHT_ALT = 346 - KEY_RIGHT_SUPER = 347 - KEY_MENU = 348 - - # Mouse buttons - MOUSE_BUTTON_LEFT = 0 - MOUSE_BUTTON_RIGHT = 1 - MOUSE_BUTTON_MIDDLE = 2 - - -# Command name to GLFW key code mapping -COMMAND_TO_KEY = { - # Movement (Minecraft default bindings) - "w": GLFW.KEY_W, - "forward": GLFW.KEY_W, - "s": GLFW.KEY_S, - "back": GLFW.KEY_S, - "a": GLFW.KEY_A, - "left": GLFW.KEY_A, - "d": GLFW.KEY_D, - "right": GLFW.KEY_D, - "space": GLFW.KEY_SPACE, - "jump": GLFW.KEY_SPACE, - "shift": GLFW.KEY_LEFT_SHIFT, - "sneak": GLFW.KEY_LEFT_SHIFT, - "ctrl": GLFW.KEY_LEFT_CONTROL, - "sprint": GLFW.KEY_LEFT_CONTROL, - # Interaction - "e": GLFW.KEY_E, - "inventory": GLFW.KEY_E, - "q": GLFW.KEY_Q, - "drop": GLFW.KEY_Q, - "f": GLFW.KEY_F, - "swap": GLFW.KEY_F, - # Hotbar (number keys) - "1": GLFW.KEY_1, - "hotbar1": GLFW.KEY_1, - "2": GLFW.KEY_2, - "hotbar2": GLFW.KEY_2, - "3": GLFW.KEY_3, - "hotbar3": GLFW.KEY_3, - "4": GLFW.KEY_4, - "hotbar4": GLFW.KEY_4, - "5": GLFW.KEY_5, - "hotbar5": GLFW.KEY_5, - "6": GLFW.KEY_6, - "hotbar6": GLFW.KEY_6, - "7": GLFW.KEY_7, - "hotbar7": GLFW.KEY_7, - "8": GLFW.KEY_8, - "hotbar8": GLFW.KEY_8, - "9": GLFW.KEY_9, - "hotbar9": GLFW.KEY_9, - # Special keys - "esc": GLFW.KEY_ESCAPE, - "escape": GLFW.KEY_ESCAPE, - "enter": GLFW.KEY_ENTER, - "tab": GLFW.KEY_TAB, - "t": GLFW.KEY_T, # Chat - "chat": GLFW.KEY_T, - "/": GLFW.KEY_SLASH, # Command - "command": GLFW.KEY_SLASH, -} - - -# ============================================================================= -# RawInput Data Class -# ============================================================================= - - -@dataclass -class RawInput: - """ - Raw input data to send to Minecraft. - - Protocol format (variable size): - - 1 byte: numKeysPressed (0-255) - - N*2 bytes: keyCodes (shorts, big-endian) - - 4 bytes: mouseDeltaX (float, big-endian) - - 4 bytes: mouseDeltaY (float, big-endian) - - 1 byte: mouseButtons (bits: 0=left, 1=right, 2=middle) - - 4 bytes: scrollDelta (float, big-endian) - - 2 bytes: textLength (big-endian) - - M bytes: textBytes (UTF-8) - """ - - key_codes: list[int] = field(default_factory=list) - mouse_dx: float = 0.0 - mouse_dy: float = 0.0 - mouse_buttons: int = 0 # Bit flags: 0=left, 1=right, 2=middle - scroll_delta: float = 0.0 - text: str = "" - - def to_bytes(self) -> bytes: - """Convert to binary protocol format.""" - data = bytearray() - - # Number of keys (1 byte) - num_keys = len(self.key_codes) - if num_keys > 255: - raise ValueError(f"Too many keys pressed: {num_keys} (max 255)") - data.append(num_keys) - - # Key codes (N * 2 bytes, big-endian shorts) - for key_code in self.key_codes: - data.extend(struct.pack(">h", key_code)) - - # Mouse delta X (4 bytes, big-endian float) - data.extend(struct.pack(">f", self.mouse_dx)) - - # Mouse delta Y (4 bytes, big-endian float) - data.extend(struct.pack(">f", self.mouse_dy)) - - # Mouse buttons (1 byte) - data.append(self.mouse_buttons & 0xFF) - - # Scroll delta (4 bytes, big-endian float) - data.extend(struct.pack(">f", self.scroll_delta)) - - # Text length and content - text_bytes = self.text.encode("utf-8") - text_length = len(text_bytes) - if text_length > 65535: - raise ValueError(f"Text too long: {text_length} bytes (max 65535)") - data.extend(struct.pack(">H", text_length)) - data.extend(text_bytes) - - return bytes(data) - - def set_left_mouse(self, pressed: bool): - """Set left mouse button state.""" - if pressed: - self.mouse_buttons |= 1 << GLFW.MOUSE_BUTTON_LEFT - else: - self.mouse_buttons &= ~(1 << GLFW.MOUSE_BUTTON_LEFT) - - def set_right_mouse(self, pressed: bool): - """Set right mouse button state.""" - if pressed: - self.mouse_buttons |= 1 << GLFW.MOUSE_BUTTON_RIGHT - else: - self.mouse_buttons &= ~(1 << GLFW.MOUSE_BUTTON_RIGHT) - - def set_middle_mouse(self, pressed: bool): - """Set middle mouse button state.""" - if pressed: - self.mouse_buttons |= 1 << GLFW.MOUSE_BUTTON_MIDDLE - else: - self.mouse_buttons &= ~(1 << GLFW.MOUSE_BUTTON_MIDDLE) - - -# ============================================================================= -# Test Clients -# ============================================================================= +from mineagent.client import GLFW, RawInput, COMMAND_TO_KEY class RawInputTestClient: @@ -359,19 +105,16 @@ def receive_observations(self): try: while self.running: - # Read reward (8 bytes, double) reward_data = self._read_exact(8) if not reward_data: break reward = struct.unpack(">d", reward_data)[0] - # Read frame length (4 bytes, int) length_data = self._read_exact(4) if not length_data: break frame_length = struct.unpack(">I", length_data)[0] - # Read frame data frame_data = self._read_exact(frame_length) if not frame_data: break @@ -401,11 +144,6 @@ def _read_exact(self, n: int) -> Optional[bytes]: return data -# ============================================================================= -# Command Parsing and Help -# ============================================================================= - - def show_help(): """Display comprehensive help information.""" print("\n" + "=" * 70) @@ -474,7 +212,7 @@ class HeldState: def __init__(self): self.keys: set[int] = set() - self.mouse_buttons: int = 0 # Bit flags: 0=left, 1=right, 2=middle + self.mouse_buttons: int = 0 def set_mouse_button(self, button: int, pressed: bool): """Set a mouse button state. button: 0=left, 1=right, 2=middle""" @@ -503,7 +241,6 @@ def parse_command( Returns: RawInput to send, or None if command was informational only """ - # Use held_state if provided, otherwise fall back to old behavior if held_state is None: held_state = HeldState() held_state.keys = held_keys @@ -514,14 +251,12 @@ def parse_command( main_command = parts[0] - # Handle special commands if main_command in ["help", "h", "?"]: show_help() return None raw_input = RawInput() - # Mouse movement if main_command == "mouse" and len(parts) >= 3: try: raw_input.mouse_dx = float(parts[1]) @@ -533,7 +268,6 @@ def parse_command( print("✗ Invalid mouse coordinates. Use: mouse ") return None - # Scroll wheel if main_command == "scroll" and len(parts) >= 2: try: raw_input.scroll_delta = float(parts[1]) @@ -544,23 +278,18 @@ def parse_command( print("✗ Invalid scroll amount. Use: scroll ") return None - # Text input if main_command == "text" and len(parts) >= 2: raw_input.text = " ".join(parts[1:]) raw_input.key_codes = list(held_state.keys) raw_input.mouse_buttons = held_state.mouse_buttons return raw_input - # Chat shortcut - returns a list of RawInputs to send in sequence if main_command == "say" and len(parts) >= 2: message = " ".join(parts[1:]) - # Return special marker - will be handled by caller raw_input.key_codes = [GLFW.KEY_T] - raw_input.text = f"__SAY__{message}" # Special marker for say command + raw_input.text = f"__SAY__{message}" return raw_input - # Mouse clicks (single shot - press then immediately release) - # These are one-shot and don't modify held state if main_command == "lclick": raw_input.mouse_buttons = held_state.mouse_buttons | ( 1 << GLFW.MOUSE_BUTTON_LEFT @@ -580,7 +309,6 @@ def parse_command( raw_input.key_codes = list(held_state.keys) return raw_input - # Mouse hold/release - these modify the held state if main_command == "ldown": held_state.set_mouse_button(GLFW.MOUSE_BUTTON_LEFT, True) raw_input.mouse_buttons = held_state.mouse_buttons @@ -602,13 +330,11 @@ def parse_command( raw_input.key_codes = list(held_state.keys) return raw_input - # Release all keys and mouse buttons if main_command == "release": held_keys.clear() held_state.clear() - return raw_input # Empty state releases all + return raw_input - # Combo command (press multiple keys together, then release) if main_command == "combo" and len(parts) > 1: for action_name in parts[1:]: if action_name in COMMAND_TO_KEY: @@ -616,23 +342,20 @@ def parse_command( raw_input.mouse_buttons = held_state.mouse_buttons return raw_input - # Hold command (add to held keys) if main_command == "hold" and len(parts) > 1: for action_name in parts[1:]: if action_name in COMMAND_TO_KEY: held_state.keys.add(COMMAND_TO_KEY[action_name]) - held_keys.add(COMMAND_TO_KEY[action_name]) # Keep old behavior too + held_keys.add(COMMAND_TO_KEY[action_name]) raw_input.key_codes = list(held_state.keys) raw_input.mouse_buttons = held_state.mouse_buttons return raw_input - # Single key command (momentary press - does NOT add to held state) if main_command in COMMAND_TO_KEY: raw_input.key_codes = [COMMAND_TO_KEY[main_command]] raw_input.mouse_buttons = held_state.mouse_buttons return raw_input - # Try to interpret as raw key code try: key_code = int(main_command) raw_input.key_codes = [key_code] @@ -645,26 +368,19 @@ def parse_command( return None -# ============================================================================= -# Main Functions -# ============================================================================= - - def run_interactive_mode(): """Run interactive mode for testing raw input.""" client = RawInputTestClient() observation_client = ObservationTestClient() held_state = HeldState() - held_keys = held_state.keys # For backward compatibility + held_keys = held_state.keys print("MineAgent Raw Input Test Client - Interactive Mode") print("=" * 50) - # Connect to action socket if not client.connect(): return - # Optionally connect to observation socket print("\nDo you want to monitor observations? (y/n): ", end="") try: if input().lower().startswith("y"): @@ -688,7 +404,6 @@ def run_interactive_mode(): if not command_line: continue - # Handle special meta commands if command_line.lower() in ["quit", "q"]: break elif command_line.lower() == "status": @@ -704,37 +419,30 @@ def run_interactive_mode(): ) continue elif command_line.lower() == "clear": - print("\033[2J\033[H") # Clear screen + print("\033[2J\033[H") continue elif command_line.lower() == "test": run_test_sequence(client) continue - # Parse and send command raw_input = parse_command(command_line, held_keys, held_state) if raw_input is not None: - # Handle special 'say' command (multi-step) if raw_input.text.startswith("__SAY__"): - message = raw_input.text[7:] # Remove __SAY__ prefix + message = raw_input.text[7:] print(f" Opening chat and typing: {message}") - # Step 1: Press T to open chat client.send_raw_input(RawInput(key_codes=[GLFW.KEY_T])) - time.sleep(0.15) # Wait for chat to open + time.sleep(0.15) - # Step 2: Release T client.send_raw_input(RawInput()) time.sleep(0.05) - # Step 3: Type the message client.send_raw_input(RawInput(text=message)) time.sleep(0.05) - # Step 4: Press Enter to send client.send_raw_input(RawInput(key_codes=[GLFW.KEY_ENTER])) time.sleep(0.05) - # Step 5: Release Enter client.send_raw_input(RawInput()) else: client.send_raw_input(raw_input) @@ -748,7 +456,6 @@ def run_interactive_mode(): except KeyboardInterrupt: print("\nInterrupted by user") finally: - # Send release all before disconnect client.send_raw_input(RawInput()) client.disconnect() observation_client.disconnect() @@ -797,7 +504,7 @@ def run_automated_test(): except KeyboardInterrupt: print("\nTest interrupted by user") finally: - client.send_raw_input(RawInput()) # Release all + client.send_raw_input(RawInput()) client.disconnect() diff --git a/tests/client/__init__.py b/tests/client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/client/test_connection.py b/tests/client/test_connection.py new file mode 100644 index 0000000..0b6994c --- /dev/null +++ b/tests/client/test_connection.py @@ -0,0 +1,172 @@ +import asyncio +import struct +from unittest.mock import AsyncMock, MagicMock, patch + +import numpy as np +import pytest + +from mineagent.client.connection import AsyncMinecraftClient, ConnectionConfig +from mineagent.client.protocol import RawInput + +FRAME_WIDTH = 2 +FRAME_HEIGHT = 2 +FRAME_SIZE = FRAME_WIDTH * FRAME_HEIGHT * 3 + + +@pytest.fixture +def config(): + return ConnectionConfig( + frame_width=FRAME_WIDTH, + frame_height=FRAME_HEIGHT, + max_retries=3, + retry_delay=0.0, + ) + + +@pytest.fixture +def client(config): + return AsyncMinecraftClient(config) + + +def _make_mock_reader(): + reader = AsyncMock(spec=asyncio.StreamReader) + return reader + + +def _make_mock_writer(): + writer = MagicMock(spec=asyncio.StreamWriter) + writer.drain = AsyncMock() + writer.wait_closed = AsyncMock() + return writer + + +def _build_observation_bytes(reward: float, frame: bytes) -> tuple[bytes, bytes]: + """Build the header and frame bytes matching the wire protocol.""" + header = struct.pack(">d", reward) + struct.pack(">I", len(frame)) + return header, frame + + +async def _connect_client(client: AsyncMinecraftClient): + """Patch open_unix_connection and connect, returning the mock reader/writer.""" + reader = _make_mock_reader() + writer = _make_mock_writer() + + async def fake_open(path): + if path == client.config.observation_socket: + return (reader, MagicMock()) + return (MagicMock(), writer) + + with patch( + "mineagent.client.connection.asyncio.open_unix_connection", + side_effect=fake_open, + ): + result = await client.connect() + + assert result is True + return reader, writer + + +@pytest.mark.asyncio +async def test_connect_success(client): + reader = _make_mock_reader() + writer = _make_mock_writer() + + async def fake_open(path): + if path == client.config.observation_socket: + return (reader, MagicMock()) + return (MagicMock(), writer) + + with patch( + "mineagent.client.connection.asyncio.open_unix_connection", + side_effect=fake_open, + ): + result = await client.connect() + + assert result is True + assert client.connected is True + + +@pytest.mark.asyncio +async def test_connect_failure_retries(client): + call_count = 0 + + async def fake_open(path): + nonlocal call_count + call_count += 1 + raise OSError("Connection refused") + + with patch( + "mineagent.client.connection.asyncio.open_unix_connection", + side_effect=fake_open, + ): + result = await client.connect() + + assert result is False + assert client.connected is False + assert call_count == client.config.max_retries + + +@pytest.mark.asyncio +async def test_disconnect(client): + reader, writer = await _connect_client(client) + assert client.connected is True + + await client.disconnect() + + assert client.connected is False + writer.close.assert_called_once() + writer.wait_closed.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_action(client): + raw_input = RawInput(key_codes=[87], mouse_dx=1.0, mouse_dy=-1.0) + expected_bytes = raw_input.to_bytes() + + _, writer = await _connect_client(client) + + result = await client.send_action(raw_input) + + assert result is True + writer.write.assert_called_once_with(expected_bytes) + writer.drain.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_action_not_connected(client): + result = await client.send_action(RawInput()) + assert result is False + + +@pytest.mark.asyncio +async def test_receive_observation(client): + reward = 1.5 + frame_bytes = np.zeros((FRAME_HEIGHT, FRAME_WIDTH, 3), dtype=np.uint8).tobytes() + header, frame_data = _build_observation_bytes(reward, frame_bytes) + + reader, _ = await _connect_client(client) + reader.readexactly = AsyncMock(side_effect=[header, frame_data]) + + obs = await client.receive_observation() + + assert obs is not None + assert obs.reward == reward + assert obs.frame.shape == (FRAME_HEIGHT, FRAME_WIDTH, 3) + + +@pytest.mark.asyncio +async def test_receive_observation_incomplete(client): + reader, _ = await _connect_client(client) + reader.readexactly = AsyncMock( + side_effect=asyncio.IncompleteReadError(partial=b"", expected=12) + ) + + obs = await client.receive_observation() + + assert obs is None + + +@pytest.mark.asyncio +async def test_receive_observation_not_connected(client): + obs = await client.receive_observation() + assert obs is None diff --git a/tests/client/test_protocol.py b/tests/client/test_protocol.py new file mode 100644 index 0000000..ddd52a2 --- /dev/null +++ b/tests/client/test_protocol.py @@ -0,0 +1,276 @@ +import struct + +import numpy as np +import pytest + +from mineagent.client.protocol import GLFW, RawInput, parse_observation + + +# --- to_bytes serialization --- + + +def test_empty_raw_input_to_bytes(): + data = RawInput().to_bytes() + + # 1 (key count) + 4+4 (mouse dx/dy) + 1 (buttons) + 4 (scroll) + 2 (text len) = 16 + assert len(data) == 16 + + assert data[0] == 0 + assert struct.unpack(">f", data[1:5])[0] == 0.0 + assert struct.unpack(">f", data[5:9])[0] == 0.0 + assert data[9] == 0 + assert struct.unpack(">f", data[10:14])[0] == 0.0 + assert struct.unpack(">H", data[14:16])[0] == 0 + + +def test_to_bytes_with_key_codes(): + raw = RawInput(key_codes=[GLFW.KEY_W, GLFW.KEY_SPACE]) + data = raw.to_bytes() + + assert data[0] == 2 + assert struct.unpack(">h", data[1:3])[0] == GLFW.KEY_W + assert struct.unpack(">h", data[3:5])[0] == GLFW.KEY_SPACE + + +def test_to_bytes_with_mouse_deltas(): + raw = RawInput(mouse_dx=10.5, mouse_dy=-3.25) + data = raw.to_bytes() + + offset = 1 # skip key count (0 keys) + assert struct.unpack(">f", data[offset : offset + 4])[0] == pytest.approx(10.5) + assert struct.unpack(">f", data[offset + 4 : offset + 8])[0] == pytest.approx(-3.25) + + +def test_to_bytes_with_text(): + raw = RawInput(text="hello") + data = raw.to_bytes() + + text_len_offset = 14 # 1 + 4 + 4 + 1 + 4 + text_len = struct.unpack(">H", data[text_len_offset : text_len_offset + 2])[0] + assert text_len == 5 + assert data[text_len_offset + 2 :] == b"hello" + + +def test_to_bytes_with_unicode_text(): + raw = RawInput(text="héllo") + data = raw.to_bytes() + + text_len_offset = 14 + text_len = struct.unpack(">H", data[text_len_offset : text_len_offset + 2])[0] + encoded = "héllo".encode("utf-8") + assert text_len == len(encoded) # 6 bytes, not 5 characters + assert data[text_len_offset + 2 :] == encoded + + +def test_to_bytes_round_trip(): + raw = RawInput( + key_codes=[GLFW.KEY_W, GLFW.KEY_A], + mouse_dx=5.0, + mouse_dy=-2.0, + mouse_buttons=0b101, # left + middle + scroll_delta=1.5, + text="test", + ) + data = raw.to_bytes() + + offset = 0 + num_keys = data[offset] + offset += 1 + assert num_keys == 2 + + keys = [] + for _ in range(num_keys): + keys.append(struct.unpack(">h", data[offset : offset + 2])[0]) + offset += 2 + assert keys == [GLFW.KEY_W, GLFW.KEY_A] + + mouse_dx = struct.unpack(">f", data[offset : offset + 4])[0] + offset += 4 + assert mouse_dx == pytest.approx(5.0) + + mouse_dy = struct.unpack(">f", data[offset : offset + 4])[0] + offset += 4 + assert mouse_dy == pytest.approx(-2.0) + + mouse_buttons = data[offset] + offset += 1 + assert mouse_buttons == 0b101 + + scroll_delta = struct.unpack(">f", data[offset : offset + 4])[0] + offset += 4 + assert scroll_delta == pytest.approx(1.5) + + text_len = struct.unpack(">H", data[offset : offset + 2])[0] + offset += 2 + assert text_len == 4 + + text = data[offset : offset + text_len].decode("utf-8") + assert text == "test" + + assert offset + text_len == len(data) + + +# --- to_bytes edge cases --- + + +def test_to_bytes_too_many_keys(): + raw = RawInput(key_codes=list(range(256))) + with pytest.raises(ValueError, match="Too many keys"): + raw.to_bytes() + + +def test_to_bytes_max_valid_keys(): + raw = RawInput(key_codes=list(range(255))) + data = raw.to_bytes() + assert data[0] == 255 + + +def test_to_bytes_text_too_long(): + raw = RawInput(text="x" * 65536) + with pytest.raises(ValueError, match="Text too long"): + raw.to_bytes() + + +# --- Mouse button helpers --- + + +def test_set_left_mouse(): + raw = RawInput() + raw.set_left_mouse(True) + assert raw.mouse_buttons & 1 + + raw.set_left_mouse(False) + assert raw.mouse_buttons == 0 + + +def test_set_right_mouse(): + raw = RawInput() + raw.set_right_mouse(True) + assert raw.mouse_buttons & 2 + + raw.set_right_mouse(False) + assert raw.mouse_buttons == 0 + + +def test_set_middle_mouse(): + raw = RawInput() + raw.set_middle_mouse(True) + assert raw.mouse_buttons & 4 + + raw.set_middle_mouse(False) + assert raw.mouse_buttons == 0 + + +def test_set_multiple_mouse_buttons(): + raw = RawInput() + raw.set_left_mouse(True) + raw.set_right_mouse(True) + raw.set_middle_mouse(True) + assert raw.mouse_buttons == 0b111 + + raw.set_right_mouse(False) + assert raw.mouse_buttons == 0b101 + + +# --- release_all --- + + +def test_release_all(): + raw = RawInput.release_all() + assert raw.key_codes == [] + assert raw.mouse_dx == 0.0 + assert raw.mouse_dy == 0.0 + assert raw.mouse_buttons == 0 + assert raw.scroll_delta == 0.0 + assert raw.text == "" + + +# --- parse_observation --- + +FRAME_HEIGHT = 4 +FRAME_WIDTH = 4 +FRAME_SHAPE = (FRAME_HEIGHT, FRAME_WIDTH) +FRAME_NUM_BYTES = FRAME_HEIGHT * FRAME_WIDTH * 3 + + +def _build_header(reward: float, frame_length: int) -> bytes: + return struct.pack(">d", reward) + struct.pack(">I", frame_length) + + +def test_parse_observation(): + reward = 2.5 + frame_data = bytes(range(FRAME_NUM_BYTES % 256)) * ( + FRAME_NUM_BYTES // (FRAME_NUM_BYTES % 256) + 1 + ) + frame_data = frame_data[:FRAME_NUM_BYTES] + header = _build_header(reward, len(frame_data)) + + obs = parse_observation(header, frame_data, FRAME_SHAPE) + + assert obs.reward == reward + assert obs.frame.shape == (FRAME_HEIGHT, FRAME_WIDTH, 3) + assert obs.frame.dtype == np.uint8 + + +def test_parse_observation_zero_reward(): + frame_data = b"\x00" * FRAME_NUM_BYTES + header = _build_header(0.0, FRAME_NUM_BYTES) + + obs = parse_observation(header, frame_data, FRAME_SHAPE) + + assert obs.reward == 0.0 + assert obs.frame.shape == (FRAME_HEIGHT, FRAME_WIDTH, 3) + assert np.all(obs.frame == 0) + + +def test_parse_observation_negative_reward(): + frame_data = b"\xff" * FRAME_NUM_BYTES + header = _build_header(-100.0, FRAME_NUM_BYTES) + + obs = parse_observation(header, frame_data, FRAME_SHAPE) + + assert obs.reward == -100.0 + assert np.all(obs.frame == 255) + + +def test_parse_observation_frame_values(): + frame_array = np.arange(FRAME_NUM_BYTES, dtype=np.uint8) + frame_data = frame_array.tobytes() + header = _build_header(1.0, FRAME_NUM_BYTES) + + obs = parse_observation(header, frame_data, FRAME_SHAPE) + + expected = frame_array.reshape(FRAME_HEIGHT, FRAME_WIDTH, 3) + np.testing.assert_array_equal(obs.frame, expected) + + +def test_parse_observation_invalid_header_length(): + with pytest.raises(ValueError, match="Header must be 12 bytes"): + parse_observation(b"\x00" * 8, b"\x00" * FRAME_NUM_BYTES, FRAME_SHAPE) + + +def test_parse_observation_frame_length_mismatch(): + wrong_length = FRAME_NUM_BYTES + 1 + header = _build_header(0.0, wrong_length) + + with pytest.raises(ValueError, match="Frame length mismatch"): + parse_observation(header, b"\x00" * FRAME_NUM_BYTES, FRAME_SHAPE) + + +def test_parse_observation_frame_size_mismatch(): + bad_frame = b"\x00" * (FRAME_NUM_BYTES - 1) + header = _build_header(0.0, len(bad_frame)) + + with pytest.raises(ValueError, match="Frame data size mismatch"): + parse_observation(header, bad_frame, FRAME_SHAPE) + + +def test_parse_observation_default_frame_shape(): + height, width = 240, 320 + num_bytes = height * width * 3 + frame_data = b"\x00" * num_bytes + header = _build_header(0.0, num_bytes) + + obs = parse_observation(header, frame_data) + + assert obs.frame.shape == (240, 320, 3)