Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions curation/filter_timeouts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
Filter Stage 2 raw_tasks.jsonl by removing instances that timed out in Stage 3.

Usage:
python curation/filter_timeouts.py \
--raw-tasks curation/output/raw_tasks.jsonl \
--timeouts curation/output/timeouts.txt \
--output curation/output/raw_tasks.filtered.jsonl

The timeouts file should contain one instance_id per line.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the timeout filter.

    Args:
        argv: Argument strings to parse instead of the process command line.
            ``None`` (the default) lets argparse read ``sys.argv[1:]``, which
            is what the script entry point relies on. Passing a list makes
            the parser usable from other code and from tests.

    Returns:
        Namespace with ``raw_tasks``, ``timeouts`` and ``output`` attributes,
        each converted to a ``Path``.
    """
    parser = argparse.ArgumentParser(
        description="Remove timed-out instance IDs from raw_tasks.jsonl"
    )
    parser.add_argument(
        "--raw-tasks",
        type=Path,
        default=Path("curation/output/raw_tasks.jsonl"),
        help="Path to Stage 2 raw_tasks.jsonl",
    )
    parser.add_argument(
        "--timeouts",
        type=Path,
        default=Path("curation/output/timeouts.txt"),
        help="File containing one instance_id per line to exclude",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("curation/output/raw_tasks.filtered.jsonl"),
        help="Destination JSONL after filtering",
    )
    return parser.parse_args(argv)


def load_timeouts(path: Path) -> set[str]:
    """Read the instance IDs to exclude from *path*.

    Every line is stripped of surrounding whitespace; lines that are empty
    after stripping are ignored, and repeated IDs collapse into one entry.

    Args:
        path: Text file holding one instance_id per line.

    Returns:
        The set of IDs found in the file, or an empty set when *path* does
        not exist.
    """
    if not path.exists():
        return set()
    # Decode as UTF-8 explicitly so the IDs read here do not depend on the
    # platform's locale encoding.
    text = path.read_text(encoding="utf-8")
    stripped_lines = (line.strip() for line in text.splitlines())
    return {entry for entry in stripped_lines if entry}


def main() -> None:
    """Filter the raw tasks file and write the surviving lines to the output.

    Reads the three paths from the command line, loads the excluded instance
    IDs, and drops every JSONL record whose ``instance_id`` is in that set.
    Kept lines are written back verbatim (they are not re-serialized), one
    per line, after creating the output's parent directory if needed. A
    one-line summary is printed at the end.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    args = parse_args()
    timeouts = load_timeouts(args.timeouts)

    kept = []
    dropped = 0
    # JSON text is UTF-8; being explicit keeps the result independent of the
    # platform's locale encoding.
    with args.raw_tasks.open(encoding="utf-8") as src:
        for line in src:
            line = line.rstrip("\n")
            # Skip blank lines, including whitespace-only ones, which
            # json.loads would otherwise reject.
            if not line.strip():
                continue
            obj = json.loads(line)
            if obj.get("instance_id") in timeouts:
                dropped += 1
                continue
            kept.append(line)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    # Same encoding as the read, so kept lines round-trip unchanged; an empty
    # result yields an empty file rather than a lone newline.
    args.output.write_text(
        "\n".join(kept) + ("\n" if kept else ""), encoding="utf-8"
    )

    print(
        f"Filtered {dropped} instance(s); wrote {len(kept)} entries to {args.output}"
    )


if __name__ == "__main__":
main()

23 changes: 17 additions & 6 deletions launch/launch/agent/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from launch.agent.action_parser import ActionParser
from launch.agent.prompt import ReAct_prompt
from launch.agent.state import AgentState, auto_catch
from launch.runtime import start_session
from launch.runtime import TIMEOUT_EXIT_CODE, start_session
from launch.utilities.language_handlers import get_language_handler

system_msg = """You are a developer. Your task is to install dependencies and set up a environment that is able to run the tests of the project.
Expand Down Expand Up @@ -57,6 +57,9 @@ class SetupObservation(BaseModel):

content: str = Field("", description="The content of the observation")
is_stop: bool = Field(False, description="Whether stop the setup loop")
exit_code: int | None = Field(
None, description="Exit code of the last executed command, if any"
)


class SetupActionParser(ActionParser):
Expand Down Expand Up @@ -108,7 +111,10 @@ def observation_for_setup_action(
if action.action == "command":
session = state["session"]
result = session.send_command(action.args)
return SetupObservation(content=result.to_observation(), is_stop=False)
exit_code = result.metadata.exit_code if result.metadata else None
return SetupObservation(
content=result.to_observation(), is_stop=False, exit_code=exit_code
)
if action.action == "search":
result = state["search_tool"].invoke(action.args)
return SetupObservation(content=json.dumps(result), is_stop=False)
Expand Down Expand Up @@ -202,8 +208,8 @@ def setup(max_steps: int, state: AgentState) -> dict:
commands = []
step = 0
while step < max_steps:
if time.time() - state["start_time"] > 30 * 60:
raise TimeoutError("Reached global timeout of 30 minutes")
if time.time() - state["start_time"] > 10 * 60:
raise TimeoutError("Reached global timeout of 10 minutes")
step += 1
# uses a window to avoid exceed context
commands_history = HumanMessage(
Expand All @@ -226,17 +232,22 @@ def setup(max_steps: int, state: AgentState) -> dict:


# print(response.pretty_repr())
logger.info("\n" + response.pretty_repr())
if state.get("debug"):
logger.debug("\n" + response.pretty_repr())
messages.append(response)
action = parse_setup_action(response.content)
if action and action.action == "command":
commands.append(action.args)
observation = observation_for_setup_action(state, action)
if observation.exit_code == TIMEOUT_EXIT_CODE:
logger.error("Setup command timed out (exit code %s)", observation.exit_code)
raise TimeoutError("Setup command timed out (exit code 124)")
if observation.is_stop:
break
message = HumanMessage(f"Observation:\n{observation.content}")
# print(observation.content)
logger.info("\n" + message.pretty_repr())
if state.get("debug"):
logger.debug("\n" + message.pretty_repr())
messages.append(message)

logger.info("-" * 10 + "End setup conversation" + "-" * 10)
Expand Down
6 changes: 4 additions & 2 deletions launch/launch/agent/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,17 @@ def verify(max_steps: int, state: AgentState) -> dict:
)
response = llm.invoke(input_messages)
# print(response.pretty_repr())
logger.info(response.pretty_repr())
if state.get("debug"):
logger.debug(response.pretty_repr())
messages.append(response)
action = parse_verify_action(response.content)
if action.action == "command":
commands.append(action.args)
observation = observation_for_verify_action(action, session)
message = HumanMessage(f"Observation:\n{observation.content}")
# print(message.pretty_repr())
logger.info(message.pretty_repr())
if state.get("debug"):
logger.debug(message.pretty_repr())
messages.append(message)
if action.action == "issue":
if observation.content == "":
Expand Down
2 changes: 1 addition & 1 deletion launch/launch/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def _combine_outputs_between_matches(
output_segments.append(output_segment)
return "\n".join(output_segments) + "\n" if output_segments else ""

def send_command(self, command: str, timeout: float = 20 * 60) -> CommandResult:
def send_command(self, command: str, timeout: float = 5 * 60) -> CommandResult:
if not command.endswith("\n"):
command += "\n"

Expand Down