diff --git a/examples/terminal/README.md b/examples/terminal/README.md
new file mode 100644
index 000000000..0b9acc65c
--- /dev/null
+++ b/examples/terminal/README.md
@@ -0,0 +1,12 @@
+### Terminal-Bench examples
+
+- Requirements: Python >= 3.12
+- Install Terminal-Bench:
+
+```bash
+pip install terminal-bench
+```
+
+After installing, run the `run_terminal.py` script in this folder to evaluate `openai/o4-mini` on the `terminal-bench-core` v0.1.1 dataset with Terminal-Bench's Terminus 1 agent.
+
+
diff --git a/examples/terminal/prepare_terminal_data.py b/examples/terminal/prepare_terminal_data.py
new file mode 100644
index 000000000..c6d4e552a
--- /dev/null
+++ b/examples/terminal/prepare_terminal_data.py
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+import yaml
+
+from terminal_bench.dataset.dataset import Dataset
+
+def load_terminal_bench_dataset(
+    dataset_name: str,
+    dataset_version: str = "head",
+    task_ids: Optional[List[str]] = None,
+    n_tasks: Optional[int] = None,
+    cache_path: Optional[Path] = None,
+    local_registry_path: Optional[Path] = None,
+) -> List[Dict[str, Any]]:
+    """Load a Terminal-Bench dataset and convert it to minimal rLLM task dicts.
+
+    Args:
+        dataset_name: Dataset registry name.
+        dataset_version: Concrete version or "head".
+        task_ids: Optional subset of task IDs to include.
+        n_tasks: Optional cap on the number of tasks.
+        cache_path: Optional path for the dataset cache (currently unused).
+        local_registry_path: Optional path to a local registry.
+
+    Returns:
+        List[Dict[str, Any]]: Each dict includes ``task_path``, ``task_id``,
+        and ``instruction``.
+    """
+    dataset = Dataset(
+        name=dataset_name,
+        version=dataset_version,
+        task_ids=task_ids,
+        n_tasks=n_tasks,
+        local_registry_path=local_registry_path,
+    )
+
+    tasks: List[Dict[str, Any]] = []
+    for task_path in dataset:
+        task_config = load_task_config(task_path)
+
+        task_dict = {
+            "task_path": str(task_path),
+            "task_id": task_path.name,
+            "instruction": task_config["instruction"],
+        }
+        tasks.append(task_dict)
+
+    return tasks
+
+
+def load_task_config(task_path: Path) -> Dict[str, Any]:
+    """Load and validate the task configuration from a task.yaml file.
+
+    Args:
+        task_path: Path to a Terminal-Bench task directory.
+
+    Returns:
+        Dict[str, Any]: Parsed YAML mapping.
+
+    Raises:
+        FileNotFoundError: If ``task.yaml`` is missing.
+        ValueError: If required fields are missing.
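+
+    Example:
+        Illustrative usage; the task directory path below is hypothetical::
+
+            config = load_task_config(Path("/data/terminal-bench/tasks/hello-world"))
+            print(config["instruction"])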
+ """ + task_yaml_path = task_path / "task.yaml" + + if not task_yaml_path.exists(): + raise FileNotFoundError(f"task.yaml not found at {task_yaml_path}") + + with open(task_yaml_path, 'r') as f: + config = yaml.safe_load(f) + + # Validate required fields + required_fields = ["instruction"] + for field in required_fields: + if field not in config: + raise ValueError(f"Missing required field '{field}' in {task_yaml_path}") + + return config \ No newline at end of file diff --git a/examples/terminal/run_terminal.py b/examples/terminal/run_terminal.py new file mode 100644 index 000000000..83cefd4a9 --- /dev/null +++ b/examples/terminal/run_terminal.py @@ -0,0 +1,55 @@ +import asyncio +import os + +from rllm.engine.agent_workflow_engine import AgentWorkflowEngine +from rllm.integrations.terminal_terminus_1 import TerminalLiteLLMEngine +from rllm.workflows.terminal_workflow import TerminalWorkflow +from rllm.agents.terminal_terminus_agent import TerminalTerminusAgent +from rllm.environments.terminal.terminal_terminus import TerminalTerminusEnv +from examples.terminal.prepare_terminal_data import load_terminal_bench_dataset + +if __name__ == "__main__": + os.environ["TOKENIZERS_PARALLELISM"] = "true" + + dataset_name = "terminal-bench-core" + dataset_version = "0.1.1" + + model_name = "openai/o4-mini" + openai_base_url = None + max_turns = 50 + max_agent_timeout_sec = 600.0 + + env_args = {"model_name": model_name, "api_base": openai_base_url, "cleanup": True} + rollout_engine = TerminalLiteLLMEngine( + model=env_args["model_name"], api_base=env_args["api_base"] + ) + + engine = AgentWorkflowEngine( + workflow_cls=TerminalWorkflow, + workflow_args={ + "agent_cls": TerminalTerminusAgent, + "env_cls": TerminalTerminusEnv, + "env_args": env_args, + "max_steps": max_turns, + "global_agent_timeout_sec": max_agent_timeout_sec, + }, + rollout_engine=rollout_engine, + n_parallel_tasks=1, + # Terminal-Bench already retries LLM calls 3 times in handle_llm_interaction + retry_limit=1, + ) + + asyncio.run(engine.initialize_pool()) + + tasks = load_terminal_bench_dataset( + dataset_name=dataset_name, + dataset_version=dataset_version, + ) + + print(f"Loaded {len(tasks)} tasks from {dataset_name} {dataset_version}") + + episodes = asyncio.run(engine.execute_tasks(tasks=tasks)) + + total = len(episodes) + correct = sum(ep.is_correct for ep in episodes) + print(f"Accuracy: {correct}/{total} = {correct / total:.3f}") \ No newline at end of file diff --git a/rllm/agents/terminal_terminus_agent.py b/rllm/agents/terminal_terminus_agent.py new file mode 100644 index 000000000..cf5867678 --- /dev/null +++ b/rllm/agents/terminal_terminus_agent.py @@ -0,0 +1,90 @@ +from typing import Any, Dict, List, Optional +import copy + +from rllm.agents.agent import Action, BaseAgent, Step, Trajectory + + +class TerminalTerminusAgent(BaseAgent): + """Thin agent wrapper; environment handles Terminal-Bench specifics. + + Maintains a simple alternating chat message history and mirrors raw + model responses to ``Action`` objects consumed by the environment. + """ + + def __init__(self, **kwargs): + """Initialize internal state.""" + self.reset() + + def update_from_env( + self, + observation: Any, + reward: float, + done: bool, + info: Dict[str, Any], + **kwargs, + ) -> None: + """Update agent state from an environment transition. + + Args: + observation: Latest observation dict from the environment. + reward: Scalar reward from the previous action. + done: Whether the episode has terminated. + info: Auxiliary environment info. 
+ **kwargs: Unused; reserved for extensions. + """ + if self._trajectory.steps: + prior_step = self._trajectory.steps[-1] + prior_step.observation = observation + prior_step.reward = reward + prior_step.done = done + prior_step.info = info + + self.messages.append({"role": "user", "content": observation["prompt"]}) + self.cur_step = Step(observation=observation) + + def update_from_model(self, response: str, **kwargs) -> Action: + """Record model response and produce an action. + + Args: + response: Raw assistant text. + **kwargs: Unused; reserved for extensions. + + Returns: + Action: Action object whose ``action`` is the raw response. + """ + self._trajectory.steps.append(self.cur_step) + + cur_step = self._trajectory.steps[-1] + cur_step.model_response = response + cur_step.action = response + + self.messages.append({"role": "assistant", "content": response}) + cur_step.chat_completions = copy.deepcopy(self.messages) + self.step += 1 + return Action(action=response) + + def get_current_state(self) -> Optional[Step]: + """Return the most recent step in the trajectory. + + Returns: + Optional[Step]: Last step if available. + """ + assert self._trajectory.steps, "Trajectory should not be empty when get_current_state is called." + return self._trajectory.steps[-1] + + def reset(self) -> None: + """Reset message history and trajectory.""" + self._trajectory = Trajectory() + self.messages = [] + self.step = 0 + + @property + def chat_completions(self) -> List[Dict[str, str]]: + """OpenAI-style message history consumed by the rollout engine.""" + return self.messages + + @property + def trajectory(self) -> Trajectory: + return self._trajectory + + \ No newline at end of file diff --git a/rllm/environments/terminal/terminal_terminus.py b/rllm/environments/terminal/terminal_terminus.py new file mode 100644 index 000000000..1ad9ca3fc --- /dev/null +++ b/rllm/environments/terminal/terminal_terminus.py @@ -0,0 +1,357 @@ +import json +import uuid +from pathlib import Path +from typing import Any, Dict, Tuple + +from rllm.environments.base.base_env import BaseEnv +from rllm.integrations.terminal_terminus_1 import RLLMTerminus as Terminus + +from terminal_bench.terminal.terminal import Terminal +from terminal_bench.terminal.docker_compose_manager import DockerComposeManager +from terminal_bench.agents.terminus_1 import CommandBatchResponse +from terminal_bench.parsers.parser_factory import ParserFactory +from terminal_bench.parsers.base_parser import UnitTestStatus +from terminal_bench.handlers.trial_handler import TrialHandler +from pydantic import ValidationError + +class TerminalTerminusEnv(BaseEnv): + """Environment bridging rLLM and Terminal-Bench's Terminus agent. + + Manages Docker/tmux Terminal-Bench sessions, builds prompts, executes + command batches, and runs unit tests to compute rewards. + + Args: + model_name: LLM model identifier used by Terminus. + api_base: Optional base URL for the LLM API. + task_path: Path to the Terminal-Bench task directory. + instruction: Natural language instruction for the task. + task_id: Identifier for the task instance. + task: Optional task dictionary overriding individual parameters. + max_episodes: Maximum number of steps before forced evaluation. + cleanup: Whether to remove Docker artifacts on shutdown. + no_rebuild: Skip Docker image rebuilds if True. + logging_dir: Optional directory for logs and markers. + max_test_timeout_sec: Maximum time to wait for tests to complete. + **kwargs: Reserved for future configuration. 
+ """ + + def __init__( + self, + model_name: str, + api_base: str = None, + task_path: str = None, + instruction: str = None, + task_id: str = "unknown", + task: Dict[str, Any] = None, + max_episodes: int = 50, + cleanup: bool = True, + no_rebuild: bool = False, + logging_dir: str = None, + max_test_timeout_sec: int = 120, + **kwargs + ): + """Initialize Terminal-Bench environment.""" + # Handle both task dictionary and individual parameters + if task is not None: + self.task = task + task_path = task.get("task_path") + instruction = task.get("instruction") + task_id = task.get("task_id", "unknown") + else: + self.task = { + "task_path": task_path, + "instruction": instruction, + "task_id": task_id, + } + + self.model_name = model_name + self.api_base = api_base + self.max_episodes = max_episodes + self.cleanup = cleanup + self.no_rebuild = no_rebuild + self.logging_dir = Path(logging_dir) if logging_dir else None + self.max_test_timeout_sec = max_test_timeout_sec + + # Task configuration (may be provided later via reset(task=...)) + self.task_path = Path(task_path) if task_path else None + self.instruction = instruction + self.task_id = task_id + + # Unique session ID + self.session_id = f"{self.task_id}_{uuid.uuid4().hex[:8]}" + + # Terminal-Bench components (initialized in reset) + self.terminal = None + self.session = None + self.terminus_agent = None + self.trial_handler = None + self.parser = None + + # Episode tracking + self.current_episode = 0 + self.is_initialized = False + + def reset(self, task: Dict[str, Any] | None = None, uid: str | None = None) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Reset environment and return initial observation. + + Args: + task: Optional task dictionary with ``task_path``, ``instruction``, ``task_id``. + uid: Rollout identifier to namespace logs and sessions. + + Returns: + Tuple[Dict[str, Any], Dict[str, Any]]: Initial observation and info. + """ + if task is not None: + self.task = task + self.task_path = Path(task.get("task_path")) + self.instruction = task.get("instruction") + self.task_id = task.get("task_id", "unknown") + self.session_id = f"{self.task_id}_{uuid.uuid4().hex[:8]}" + + if not self.task_path: + raise ValueError("TerminalTerminusEnv.reset requires a task with 'task_path'") + + # Initialize trial handler + output_path = self.logging_dir or Path("/tmp/rllm_terminal_bench_logs") + output_path.mkdir(parents=True, exist_ok=True) + self.trial_handler = TrialHandler( + trial_name=f"{self.task_id}.{uid}.rllm-run", + input_path=self.task_path, + output_path=output_path + ) + + task_config = self.trial_handler.task + self.parser = ParserFactory.get_parser(task_config.parser_name) + self.max_test_timeout_sec = task_config.max_test_timeout_sec + self._initialize_terminal_sync() + self.terminus_agent = Terminus( + model_name=self.model_name, + max_episodes=self.max_episodes, + api_base=self.api_base + ) + initial_prompt = self._build_initial_prompt_sync() + + self.current_episode = 0 + self.is_initialized = True + + observation = { + "prompt": initial_prompt, + "type": "initial" + } + info = { + "task_id": self.task_id, + "episode": self.current_episode, + "max_episodes": self.max_episodes, + "instruction": self.instruction + } + return observation, info + + def step(self, action) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]: + """Execute agent action and return environment response. + + Args: + action: Raw string or object with ``action`` attribute containing the + JSON command batch produced by the agent. 
+ + Returns: + Tuple[observation, reward, done, info]. + """ + + if not self.is_initialized: + raise RuntimeError("Environment not initialized. Call reset() first.") + + # Ensure action is a raw JSON string + if isinstance(action, str): + action_str = action + elif hasattr(action, "action"): + action_str = action.action + else: + action_str = str(action) + + # Parse model response into command batch + try: + parsed_response = CommandBatchResponse.model_validate_json(action_str) + except (json.JSONDecodeError, ValidationError) as e: + # End trajectory if we can't parse the response + reward, _ = self._evaluate_completion_sync() + observation = {"prompt": "", "type": "terminal"} + info = { + "task_id": self.task_id, + "episode": self.current_episode, + "parse_error": True, + "error_message": str(e), + "is_task_complete": False + } + return observation, reward, True, info + + self._record_asciinema_marker_sync(parsed_response.model_dump_json()) + timeout_occurred, terminal_output = self._execute_commands( + parsed_response.commands + ) + + # Determine whether to run tests now (on "done" or max episodes) + should_run_tests = parsed_response.is_task_complete or self._check_episode_limit() + if should_run_tests: + reward, _ = self._evaluate_completion_sync() + done = True + else: + reward = 0.0 + done = False + + + self.current_episode += 1 + + # Prepare next observation + if done: + observation = {"prompt": "", "type": "terminal"} + else: + observation = { + "prompt": terminal_output, + "type": "timeout" if timeout_occurred else "continuation" + } + info = { + "task_id": self.task_id, + "episode": self.current_episode, + "max_episodes": self.max_episodes, + "timeout_occurred": timeout_occurred, + "is_task_complete": parsed_response.is_task_complete + } + return observation, reward, done, info + + def close(self): + """Clean up terminal and container resources.""" + if self.terminal: + self.terminal.stop() + + @staticmethod + def from_dict(env_args: Dict[str, Any]) -> "TerminalTerminusEnv": + """Create environment instance from dictionary configuration. + + If top-level task keys are present (``task_path``, ``instruction``, + ``task_id``), they are collected into a nested ``task`` dict. 
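+
+        Example:
+            Illustrative configuration; the task path is hypothetical::
+
+                env = TerminalTerminusEnv.from_dict({
+                    "model_name": "openai/o4-mini",
+                    "task_path": "/data/terminal-bench/tasks/hello-world",
+                    "instruction": "Create hello.txt containing 'hi'.",
+                    "task_id": "hello-world",
+                })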
+ """ + # Handle case where task data is passed at the top level + if "task_path" in env_args and "instruction" in env_args and "task_id" in env_args: + # Create task dict from individual parameters + task = { + "task_path": env_args.pop("task_path"), + "instruction": env_args.pop("instruction"), + "task_id": env_args.pop("task_id") + } + env_args["task"] = task + return TerminalTerminusEnv(**env_args) + + @staticmethod + def is_multithread_safe() -> bool: + """Thread-safe via per-instance isolated containers.""" + return True + + def _initialize_terminal_sync(self): + """Initialize Docker container and tmux session synchronously.""" + # Create terminal interface + self.terminal = Terminal( + client_container_name=self.trial_handler.client_container_name, + client_image_name=self.trial_handler.client_image_name, + docker_compose_path=self.trial_handler.task_paths.docker_compose_path, + docker_image_name_prefix=self.trial_handler.docker_image_name_prefix, + sessions_logs_path=self.trial_handler.trial_paths.sessions_path, + agent_logs_path=self.trial_handler.trial_paths.agent_logging_dir, + no_rebuild=self.no_rebuild, + cleanup=self.cleanup + ) + + # Start containers and get the container object + self.terminal.start() + + # Create tmux session for agent interaction + self.session = self.terminal.create_session( + "agent", + is_active_stream=False, + as_configured_user=True + ) + + def _build_initial_prompt_sync(self) -> str: + """Build initial prompt using Terminus template synchronously.""" + terminal_state = self.session.capture_pane() + + return self.terminus_agent.build_initial_prompt( + self.instruction, + terminal_state + ) + + def _execute_commands(self, commands) -> Tuple[bool, str]: + """Execute command batch synchronously.""" + return self.terminus_agent.execute_commands(commands, self.session) + + def _record_asciinema_marker_sync(self, marker_text: str): + """Record interaction marker for debugging synchronously.""" + if self.logging_dir and hasattr(self.session, 'get_asciinema_timestamp'): + current_timestamp = self.session.get_asciinema_timestamp() + marker_file = self.logging_dir / f"{self.task_id}_markers.jsonl" + with open(marker_file, 'a') as f: + json.dump({ + "timestamp": current_timestamp, + "episode": self.current_episode, + "marker": marker_text + }, f) + f.write('\n') + + def _evaluate_completion_sync(self) -> Tuple[float, bool]: + """Evaluate task completion by running tests synchronously. + + Copies test artifacts into the container, executes the task's test script, + parses the results, and returns a binary reward. + + Returns: + Tuple[float, bool]: ``(reward, done)`` where reward is 1.0 if all tests + pass, else 0.0; ``done`` is always True. 
+ """ + # Ensure test artifacts are copied into the container under /tests + paths = [self.trial_handler.task_paths.run_tests_path] + if self.trial_handler.task_paths.test_dir.exists(): + paths.append(self.trial_handler.task_paths.test_dir) + self.terminal.copy_to_container( + paths=paths, + container_dir=str(DockerComposeManager.CONTAINER_TEST_DIR), + ) + + # Choose session according to run_tests_in_same_shell + if self.trial_handler.task.run_tests_in_same_shell: + test_session = self.session + else: + test_session = self.terminal.create_session( + "tests", is_active_stream=False, as_configured_user=False + ) + + # Execute test script + test_script_path = str(DockerComposeManager.CONTAINER_TEST_DIR / "run-tests.sh") + try: + test_session.send_keys( + [f"bash {test_script_path}", "Enter"], + block=True, + max_timeout_sec=self.trial_handler.task.max_test_timeout_sec, + ) + + # Capture test output (blocking send_keys should be sufficient) + test_output = test_session.capture_pane(capture_entire=True) + + # Parse test results using Terminal-Bench parser + parser_results = self.parser.parse(test_output) + + # Check if all tests passed + if parser_results and all( + status == UnitTestStatus.PASSED for status in parser_results.values() + ): + reward = 1.0 + else: + reward = 0.0 + + except Exception: + reward = 0.0 + + return reward, True + + def _check_episode_limit(self) -> bool: + """Check if episode limit reached.""" + return self.current_episode >= self.max_episodes - 1 + diff --git a/rllm/integrations/terminal_terminus_1.py b/rllm/integrations/terminal_terminus_1.py new file mode 100644 index 000000000..5f6d66cf3 --- /dev/null +++ b/rllm/integrations/terminal_terminus_1.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Tuple + +from terminal_bench.agents.terminus_1 import ( + Command, + CommandBatchResponse, + Terminus, +) +from terminal_bench.llms.lite_llm import LiteLLM +from terminal_bench.llms.chat import Chat +from terminal_bench.terminal.tmux_session import TmuxSession + +from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine + +class RLLMTerminus(Terminus): + """ + rLLM integration subclass for Terminal Bench's `Terminus` agent. + + This class exposes public wrappers around selected private methods/fields of + the upstream `Terminus` class, enabling rLLM workflows to leverage the + agent's internal building blocks (prompt construction, command execution, + and LLM interaction) without modifying the third-party source. + """ + + def __init__( + self, + model_name: str, + max_episodes: int = 50, + api_base: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__( + model_name=model_name, + max_episodes=max_episodes, + api_base=api_base, + **kwargs, + ) + + # Public methods exposing private Terminus methods + + def handle_llm_interaction( + self, + chat: Chat, + prompt: str, + logging_paths: tuple[Path | None, Path | None, Path | None], + ) -> CommandBatchResponse: + """Public wrapper for the internal LLM interaction handler.""" + # Uses the upstream protected method directly + return self._handle_llm_interaction( + chat=chat, + prompt=prompt, + logging_paths=logging_paths, + ) + + def build_initial_prompt(self, instruction: str, terminal_state: str) -> str: + """ + Build the initial prompt without executing the agent loop. + + Mirrors the original formatting logic used by `perform_task`. 
+ """ + # Access upstream private fields for formatting consistency + return self._prompt_template.format( + response_schema=self._response_schema, + instruction=instruction, + history="", + terminal_state=terminal_state, + ) + + def execute_commands( + self, + commands: list[Command], + session: TmuxSession, + ) -> Tuple[bool, str]: + """ + Execute commands and return formatted output. + + Returns a tuple of (timeout_occurred, formatted_output). + """ + return self._execute_commands(commands, session) + + def get_response_schema(self) -> str: + """Get the JSON schema for the response format.""" + return self._response_schema + + def get_prompt_template(self) -> str: + """Get the prompt template string.""" + return self._prompt_template + + def get_timeout_template(self) -> str: + """Get the timeout prompt template.""" + return self._timeout_template + + def format_timeout_prompt( + self, + instruction: str, + history: str, + terminal_state: str, + ) -> str: + """Format the prompt after a timeout occurrence.""" + return self._timeout_template.format( + response_schema=self._response_schema, + instruction=instruction, + history=history, + terminal_state=terminal_state, + ) + + +class TerminalLiteLLMEngine(RolloutEngine): + """Minimal rollout engine delegating to Terminal-Bench's LiteLLM + Chat. + + Args: + model: LLM model identifier. + tokenizer: Optional tokenizer (unused; Terminal-Bench handles counting). + api_base: Optional base URL for the LLM API. + sampling_params: Optional dict of generation parameters. + max_episodes: Max steps used to configure the Terminus helper. + logging_dir: Optional path to write logs (unused at per-call level). + **kwargs: Reserved for future configuration. + """ + + def __init__(self, model: str, tokenizer=None, api_base: str | None = None, sampling_params: dict | None = None, max_episodes: int = 50, logging_dir: str | None = None, **kwargs: Any): + self.model = model + self.tokenizer = tokenizer # Unused; Terminal-Bench handles token counting + self.api_base = api_base + self.sampling_params = sampling_params or {} + + self._llm = LiteLLM(model_name=model, api_base=api_base) + self._terminus = RLLMTerminus(model_name=model, max_episodes=max_episodes, api_base=api_base) + + async def get_model_response(self, messages: list[dict], **kwargs: Any) -> ModelOutput: + """Get a chat completion via Terminal-Bench's Chat abstraction. + + Expects ``messages`` in OpenAI-style format and returns a ``ModelOutput`` + containing the assistant text and token accounting computed by LiteLLM. + + Args: + messages: List of role-content message dictionaries. + **kwargs: Unused; present for API compatibility. + + Returns: + ModelOutput: Assistant text and token counts. 
+ """ + # Expect the last message to be the user prompt + assert messages and messages[-1]["role"] == "user", "Last message must be a user turn" + prompt = messages[-1]["content"] + message_history = messages[:-1] + + # Create per-call Chat to avoid cross-thread state contamination + chat = Chat(self._llm) + chat._messages = list(message_history) + + # Token counts + prompt_messages_for_count = message_history + [{"role": "user", "content": prompt}] + prompt_tokens = self._llm.count_tokens(prompt_messages_for_count) + + # Disable per-call file logging + logging_paths = (None, None, None) + # Delegate to Terminus' handle_llm_interaction with a per-call agent + self._terminus.handle_llm_interaction( + chat=chat, + prompt=prompt, + logging_paths=logging_paths, + ) + + # Assistant raw text is last assistant message in Chat history + assistant_text = next((m["content"] for m in reversed(chat._messages) if m["role"] == "assistant"), "") + + # Completion tokens + completion_tokens = self._llm.count_tokens([{"role": "assistant", "content": assistant_text}]) + + return ModelOutput( + text=assistant_text, + tool_calls=[], + finish_reason="stop", + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + ) \ No newline at end of file diff --git a/rllm/workflows/terminal_workflow.py b/rllm/workflows/terminal_workflow.py new file mode 100644 index 000000000..8a89ed16d --- /dev/null +++ b/rllm/workflows/terminal_workflow.py @@ -0,0 +1,98 @@ +import time +from rllm.agents.agent import Episode +from rllm.workflows.workflow import TerminationEvent, TerminationReason, Workflow + +class TerminalWorkflow(Workflow): + """Multi-step workflow for Terminal-Bench integration. + + This workflow wires a thin agent to a Terminal-Bench-based environment + and iterates for up to ``max_steps`` steps or until the environment + signals completion. + + Args: + agent_cls: Class of the agent to instantiate. + env_cls: Class of the environment to instantiate. + agent_args: Optional constructor kwargs for ``agent_cls``. + env_args: Optional constructor kwargs for ``env_cls``. + max_steps: Maximum number of agent-environment interaction steps. + sampling_params: Optional sampling parameters forwarded to the engine. + **kwargs: Additional parameters forwarded to the base ``Workflow``. + """ + def __init__( + self, + agent_cls, + env_cls, + agent_args=None, + env_args=None, + max_steps=50, + global_agent_timeout_sec=600.0, + sampling_params=None, + **kwargs, + ): + super().__init__(**kwargs) + + agent_args = dict(agent_args) if agent_args is not None else {} + env_args = dict(env_args) if env_args is not None else {} + sampling_params = dict(sampling_params) if sampling_params is not None else {} + + self.agent = agent_cls(**agent_args) + self.register_agent(self.agent) + self.env = env_cls(**env_args) + self.max_steps = max_steps + self.sampling_params = sampling_params + self.global_agent_timeout_sec = global_agent_timeout_sec + + async def run(self, task: dict, uid: str, **kwargs) -> Episode: + """Execute a multi-step Terminal-Bench workflow. + + Args: + task: Task specification dictionary. + uid: Unique identifier for this rollout. + **kwargs: Unused; present for API compatibility. + + Returns: + Episode: Post-processed episode when the workflow terminates. 
+ """ + + observation, info = await self.run_in_executor(self.reset, task=task, uid=uid) + self.agent.update_from_env(observation, 0, False, info) + + # Compute absolute deadline if a global agent timeout is configured + deadline = None + if self.global_agent_timeout_sec is not None: + deadline = time.time() + float(self.global_agent_timeout_sec) + + for _ in range(self.max_steps): + # Enforce global agent timeout before each step + if deadline is not None and time.time() > deadline: + await self._eval_and_terminate() + + # Get model response via rollout engine (delegates to TB Terminus under the hood) + try: + output = await self.get_model_response(self.agent) + except Exception: + await self._eval_and_terminate() + + action = self.agent.update_from_model(output.text) + + next_obs, reward, done, info = await self.run_in_executor(self.env.step, action) + self.agent.update_from_env(next_obs, reward, done, info) + + if done: + await self.run_in_executor(self.env.close) + raise TerminationEvent(TerminationReason.ENV_DONE) + + # Terminal-Bench parity: always run tests once the agent loop ends + await self._eval_and_terminate() + + async def _eval_and_terminate(self) -> None: + """Run final evaluation, close environment, and terminate the workflow. + + Always raises ``TerminationEvent`` with ``TerminationReason.ENV_DONE`` + after attempting to evaluate and close the environment. + """ + try: + await self.run_in_executor(self.env._evaluate_completion_sync) + finally: + await self.run_in_executor(self.env.close) + raise TerminationEvent(TerminationReason.ENV_DONE) diff --git a/rllm/workflows/workflow.py b/rllm/workflows/workflow.py index 78f5ccd35..8acb14f41 100644 --- a/rllm/workflows/workflow.py +++ b/rllm/workflows/workflow.py @@ -206,7 +206,7 @@ def reset(self, task: dict | None = None, uid: str | None = None) -> tuple[Any, continue attr_value = getattr(self, attr_name) if isinstance(attr_value, BaseEnv) and hasattr(attr_value, "reset"): - return attr_value.reset(task=task) + return attr_value.reset(task=task, uid=uid) print(f"No environments found to resetin {self.__class__.__name__}") diff --git a/tests/agents/test_terminal_terminus_agent.py b/tests/agents/test_terminal_terminus_agent.py new file mode 100644 index 000000000..ffb4345fb --- /dev/null +++ b/tests/agents/test_terminal_terminus_agent.py @@ -0,0 +1,59 @@ +from rllm.agents.terminal_terminus_agent import TerminalTerminusAgent + + +class TestTerminalTerminusAgent: + def test_init_and_properties(self): + agent = TerminalTerminusAgent() + # After init/reset + assert agent.messages == [] + assert agent.step == 0 + # trajectory is initialized + assert agent.trajectory is agent._trajectory + + def test_update_flow(self): + agent = TerminalTerminusAgent() + + # Provide initial observation + obs = {"prompt": "hello"} + agent.update_from_env(observation=obs, reward=0.0, done=False, info={}) + + # Model responds with raw text (Terminus 1 flow mirrors raw response) + response = "run ls -la" + action = agent.update_from_model(response) + + # Check action mirrors response + assert action.action == response + # Trajectory has one step, with observation and model response + assert len(agent._trajectory.steps) == 1 + step = agent._trajectory.steps[0] + assert step.observation == obs + assert step.model_response == response + assert step.action == response + + # Agent messages alternate user/assistant + assert agent.messages[0]["role"] == "user" + assert agent.messages[0]["content"] == "hello" + assert agent.messages[1]["role"] == "assistant" + 
assert agent.messages[1]["content"] == response + + def test_reset_clears_state(self): + agent = TerminalTerminusAgent() + + # Build some state + agent.update_from_env({"prompt": "p"}, 0.0, False, {}) + agent.update_from_model("resp") + assert len(agent._trajectory.steps) == 1 + assert len(agent.messages) == 2 + + # Reset + agent.reset() + assert len(agent._trajectory.steps) == 0 + assert agent.messages == [] + assert agent.step == 0 + + def test_get_current_state(self): + agent = TerminalTerminusAgent() + agent.update_from_env({"prompt": "a"}, 0.0, False, {}) + agent.update_from_model("b") + cur = agent.get_current_state() + assert cur is agent._trajectory.steps[-1] diff --git a/tests/envs/test_terminal_terminus_env.py b/tests/envs/test_terminal_terminus_env.py new file mode 100644 index 000000000..8dae124f9 --- /dev/null +++ b/tests/envs/test_terminal_terminus_env.py @@ -0,0 +1,201 @@ +from unittest.mock import Mock, patch + +from terminal_bench.agents.terminus_1 import CommandBatchResponse + +from rllm.environments.terminal.terminal_terminus import TerminalTerminusEnv + + +def _setup_trial_handler_and_terminal(mock_trial_handler, mock_terminal, run_tests_in_same_shell=None): + # Trial handler task config + mock_task = Mock() + mock_task.parser_name = "pytest" + mock_task.max_test_timeout_sec = 10 + if run_tests_in_same_shell is not None: + mock_task.run_tests_in_same_shell = run_tests_in_same_shell + + mock_th_instance = Mock() + mock_th_instance.task = mock_task + # Create nested path mocks explicitly + mock_th_instance.task_paths = Mock() + mock_th_instance.trial_paths = Mock() + mock_th_instance.task_paths.run_tests_path = Mock() + mock_th_instance.task_paths.test_dir = Mock() + mock_th_instance.trial_paths.sessions_path = Mock() + mock_th_instance.trial_paths.agent_logging_dir = Mock() + mock_th_instance.client_container_name = "client" + mock_th_instance.client_image_name = "image" + mock_th_instance.task_paths.docker_compose_path = Mock() + mock_th_instance.docker_image_name_prefix = "prefix" + mock_trial_handler.return_value = mock_th_instance + + # Terminal + mock_term_instance = Mock() + mock_term_instance.create_session = Mock(return_value=Mock()) + mock_terminal.return_value = mock_term_instance + + return mock_th_instance, mock_term_instance + + +def _make_env_and_reset(): + env = TerminalTerminusEnv(model_name="m", api_base=None, task_path="/tmp", instruction="do x", task_id="t1") + obs, info = env.reset(uid="u1") + return env, obs, info + + +def _cbr_json(is_done: bool): + from terminal_bench.agents.terminus_1 import CommandBatchResponse + + return CommandBatchResponse( + state_analysis="", + explanation="", + commands=[], + is_task_complete=is_done, + ).model_dump_json() + + +class TestTerminalTerminusEnv: + # Patch where the symbols are looked up (in the module under test) + @patch("rllm.environments.terminal.terminal_terminus.Terminal") + @patch("rllm.environments.terminal.terminal_terminus.TrialHandler") + @patch("rllm.environments.terminal.terminal_terminus.ParserFactory.get_parser") + def test_reset_initializes_terminal_and_returns_observation(self, mock_get_parser, mock_trial_handler, mock_terminal): + # Mocks + _setup_trial_handler_and_terminal(mock_trial_handler, mock_terminal) + mock_get_parser.return_value = Mock() + + env, obs, info = _make_env_and_reset() + + assert isinstance(obs, dict) + assert "prompt" in obs + # We don't assert on info here since we didn't capture the actual return values + + @patch("rllm.environments.terminal.terminal_terminus.Terminal") 
+ @patch("rllm.environments.terminal.terminal_terminus.TrialHandler") + @patch("rllm.environments.terminal.terminal_terminus.ParserFactory.get_parser") + def test_step_parses_and_executes_commands(self, mock_get_parser, mock_trial_handler, mock_terminal): + _setup_trial_handler_and_terminal(mock_trial_handler, mock_terminal) + mock_get_parser.return_value = Mock() + + env, _, _ = _make_env_and_reset() + sample = _cbr_json(is_done=False) + + with patch.object(CommandBatchResponse, "model_validate_json", wraps=CommandBatchResponse.model_validate_json) as spy_validate, patch.object(TerminalTerminusEnv, "_execute_commands", return_value=(False, "out")) as spy_exec: + obs, reward, done, info = env.step(sample) + + spy_validate.assert_called() + spy_exec.assert_called() + assert done is False + assert obs["prompt"] == "out" + + @patch("rllm.environments.terminal.terminal_terminus.Terminal") + @patch("rllm.environments.terminal.terminal_terminus.TrialHandler") + @patch("rllm.environments.terminal.terminal_terminus.ParserFactory.get_parser") + def test_parse_failure_sets_done_and_info(self, mock_get_parser, mock_trial_handler, mock_terminal): + _setup_trial_handler_and_terminal(mock_trial_handler, mock_terminal) + mock_get_parser.return_value = Mock() + + env, _, _ = _make_env_and_reset() + + # Send invalid JSON to trigger parse failure branch + obs, reward, done, info = env.step("{not-json}") + assert done is True + assert info.get("parse_error") is True + + @patch("rllm.environments.terminal.terminal_terminus.Terminal") + @patch("rllm.environments.terminal.terminal_terminus.TrialHandler") + @patch("rllm.environments.terminal.terminal_terminus.ParserFactory.get_parser") + def test_is_task_complete_runs_tests_once(self, mock_get_parser, mock_trial_handler, mock_terminal): + _setup_trial_handler_and_terminal(mock_trial_handler, mock_terminal) + mock_get_parser.return_value = Mock() + + env, _, _ = _make_env_and_reset() + sample_done = _cbr_json(is_done=True) + + # Spy on evaluation path + with patch.object(TerminalTerminusEnv, "_evaluate_completion_sync", return_value=(1.0, True)) as spy_eval: + obs, reward, done, info = env.step(sample_done) + spy_eval.assert_called_once() + assert done is True + + @patch("rllm.environments.terminal.terminal_terminus.Terminal") + @patch("rllm.environments.terminal.terminal_terminus.TrialHandler") + @patch("rllm.environments.terminal.terminal_terminus.ParserFactory.get_parser") + def test_run_tests_in_new_shell_when_flag_false(self, mock_get_parser, mock_trial_handler, mock_terminal): + # Default run_tests_in_same_shell is False; expect a second session for tests + _setup_trial_handler_and_terminal(mock_trial_handler, mock_terminal, run_tests_in_same_shell=False) + mock_get_parser.return_value = Mock() + + # Configure created sessions to have send_keys and capture_pane + agent_sess = mock_terminal.return_value.create_session.return_value + agent_sess.send_keys = Mock() + agent_sess.capture_pane = Mock(return_value="output") + + env, _, _ = _make_env_and_reset() + + with patch.object(TerminalTerminusEnv, "_execute_commands", return_value=(False, "out")): + sample_done = _cbr_json(is_done=True) + + # After reset: one session created. 
On finalize tests, another session should be created + initial_calls = mock_terminal.return_value.create_session.call_count + _ = env.step(sample_done) + assert mock_terminal.return_value.create_session.call_count == initial_calls + 1 + + @patch("rllm.environments.terminal.terminal_terminus.Terminal") + @patch("rllm.environments.terminal.terminal_terminus.TrialHandler") + @patch("rllm.environments.terminal.terminal_terminus.ParserFactory.get_parser") + def test_copy_to_container_inputs(self, mock_get_parser, mock_trial_handler, mock_terminal): + th, term = _setup_trial_handler_and_terminal(mock_trial_handler, mock_terminal) + # Pretend tests dir exists + th.task_paths.test_dir.exists.return_value = True + mock_get_parser.return_value = Mock() + + # Configure created sessions to have send_keys and capture_pane + agent_sess = mock_terminal.return_value.create_session.return_value + agent_sess.send_keys = Mock() + agent_sess.capture_pane = Mock(return_value="output") + + env, _, _ = _make_env_and_reset() + + # Trigger test run path + with patch.object(TerminalTerminusEnv, "_execute_commands", return_value=(False, "out")): + sample_done = _cbr_json(is_done=True) + _ = env.step(sample_done) + + # copy_to_container should be called with run-tests and tests dir + assert term.copy_to_container.called + + @patch("rllm.environments.terminal.terminal_terminus.Terminal") + @patch("rllm.environments.terminal.terminal_terminus.TrialHandler") + @patch("rllm.environments.terminal.terminal_terminus.ParserFactory.get_parser") + def test_close_stops_terminal(self, mock_get_parser, mock_trial_handler, mock_terminal): + # Set up mocks and reset to initialize terminal + _, term = _setup_trial_handler_and_terminal(mock_trial_handler, mock_terminal) + mock_get_parser.return_value = Mock() + + env, _, _ = _make_env_and_reset() + + # Terminal.stop should be invoked by close() + assert not term.stop.called + env.close() + assert term.stop.called + + @patch("rllm.environments.terminal.terminal_terminus.Terminal") + @patch("rllm.environments.terminal.terminal_terminus.TrialHandler") + @patch("rllm.environments.terminal.terminal_terminus.ParserFactory.get_parser") + def test_run_tests_in_same_shell_toggle(self, mock_get_parser, mock_trial_handler, mock_terminal): + # Configure with same-shell flag + _setup_trial_handler_and_terminal(mock_trial_handler, mock_terminal, run_tests_in_same_shell=True) + mock_get_parser.return_value = Mock() + + env, _, _ = _make_env_and_reset() + + with patch.object(TerminalTerminusEnv, "_execute_commands", return_value=(False, "out")): + with patch.object(TerminalTerminusEnv, "_evaluate_completion_sync", return_value=(1.0, True)): + sample_done = _cbr_json(is_done=True) + + _ = env.step(sample_done) + + # Because run_tests_in_same_shell=True, workflow should reuse agent session + # and not create a new one for tests + # After reset there should be exactly one create_session call; no extra calls should be made + mock_terminal.return_value.create_session.assert_called_once()