diff --git a/tests/test_tool_utils.py b/tests/test_tool_utils.py
index 264ff2e98..417f1d2ac 100644
--- a/tests/test_tool_utils.py
+++ b/tests/test_tool_utils.py
@@ -1,8 +1,11 @@
 """Tests for the tool_utils module."""
 
-from verifiers.utils.tool_utils import convert_func_to_oai_tool
 from typing import Optional
 
+from verifiers.types import Message
+from verifiers.envs.stateful_tool_env import StatefulToolEnv
+from verifiers.utils.tool_utils import build_schema_only_tool, convert_func_to_oai_tool
+
 
 class TestToolUtils:
     """Test cases for the tool_utils module."""
@@ -177,3 +180,57 @@ def test_func(param1):
                 "strict": True,
             },
         }
+
+    def test_build_schema_only_tool_skips_complex_args(self):
+        """Schema-only tools should drop skipped non-pydantic params from the generated schema."""
+
+        def complex_tool(
+            user_id: str,
+            msg: Message,
+            payload: list[dict[str, str]],
+            env: StatefulToolEnv,
+        ) -> str:
+            return "ok"
+
+        schema_stub = build_schema_only_tool(
+            complex_tool,
+            args_to_skip=["msg", "env"],
+        )
+        stub_schema = convert_func_to_oai_tool(schema_stub)
+        parameters = stub_schema["function"]["parameters"]
+        properties = parameters["properties"]
+        required = parameters.get("required", [])
+
+        assert set(properties) == {"user_id", "payload"}
+        assert "msg" not in properties
+        assert "env" not in properties
+        assert "msg" not in required
+        assert "env" not in required
+
+    def test_add_tool_to_stateful_tool_env(
+        self, mock_stateful_tool_env: StatefulToolEnv
+    ):
+        def complex_tool(
+            user_id: str,
+            msg: Message,
+            payload: list[dict[str, str]],
+            env: StatefulToolEnv,
+        ) -> str:
+            return "ok"
+
+        env = mock_stateful_tool_env
+        original_tool_count = len(env.tools)
+        env.add_tool(complex_tool, args_to_skip=["msg", "env"])
+
+        assert len(env.tools) == original_tool_count + 1
+        assert len(env.oai_tools) == original_tool_count + 1
+
+        oai_tool = env.oai_tools[-1]
+        props = oai_tool["function"]["parameters"]["properties"]
+        required = oai_tool["function"]["parameters"].get("required", [])
+        assert set(props) == {"user_id", "payload"}
+        assert "msg" not in props
+        assert "env" not in props
+        assert "msg" not in required
+        assert "env" not in required
+        assert env.skipped_args["complex_tool"] == ["msg", "env"]
diff --git a/verifiers/envs/stateful_tool_env.py b/verifiers/envs/stateful_tool_env.py
index 7b9fb8dcb..ffb081ebe 100644
--- a/verifiers/envs/stateful_tool_env.py
+++ b/verifiers/envs/stateful_tool_env.py
@@ -7,7 +7,7 @@
 from verifiers.envs.tool_env import ToolEnv
 from verifiers.types import ChatCompletionMessageToolCall, Message, Messages, State
 from verifiers.utils.async_utils import maybe_await
-from verifiers.utils.tool_utils import convert_func_to_oai_tool
+from verifiers.utils.tool_utils import build_schema_only_tool, convert_func_to_oai_tool
 
 
 class StatefulToolEnv(ToolEnv):
@@ -38,22 +38,10 @@ def __init__(
 
     def add_tool(self, tool: Callable, args_to_skip: list[str] = []):
         self.tools.append(tool)
-        oai_tool = convert_func_to_oai_tool(tool)
-        for arg in args_to_skip:
-            assert "function" in oai_tool
-            assert "parameters" in oai_tool["function"]
-            if (
-                "properties" in oai_tool["function"]["parameters"]
-                and isinstance(oai_tool["function"]["parameters"]["properties"], dict)
-                and arg in oai_tool["function"]["parameters"]["properties"]
-            ):
-                oai_tool["function"]["parameters"]["properties"].pop(arg)
-            if (
-                "required" in oai_tool["function"]["parameters"]
-                and isinstance(oai_tool["function"]["parameters"]["required"], list)
-                and arg in oai_tool["function"]["parameters"]["required"]
-            ):
-                oai_tool["function"]["parameters"]["required"].remove(arg)
+        schema_only_tool = build_schema_only_tool(tool, args_to_skip)
+        oai_tool = convert_func_to_oai_tool(schema_only_tool)
+        assert "function" in oai_tool
+        assert "parameters" in oai_tool["function"]
         if self.oai_tools is None:
             self.oai_tools = []
         self.oai_tools.append(oai_tool)
@@ -107,18 +95,22 @@ async def env_response(
         self, messages: Messages, state: State, **kwargs
     ) -> tuple[Messages, State]:
         assert isinstance(messages, list)
-        assert "tool_calls" in messages[-1]
-        tool_messages = []
-        for tool_call in messages[-1]["tool_calls"]:
-            assert isinstance(tool_call, ChatCompletionMessageToolCall)
-            tool_name: str = tool_call.function.name
-            tool_args: dict = json.loads(tool_call.function.arguments)
-            tool_call_id: str = tool_call.id or ""
-            tool_args = self.update_tool_args(
-                tool_name, tool_args, messages, state, **kwargs
-            )
-            tool_message: Message = await self.call_tool(
-                tool_name, tool_args, tool_call_id
-            )
-            tool_messages.append(tool_message)
-        return tool_messages, state
+        if self.disallow_non_tool_responses:
+            assert "tool_calls" in messages[-1]
+        if "tool_calls" in messages[-1]:
+            tool_messages = []
+            for tool_call in messages[-1]["tool_calls"]:
+                assert isinstance(tool_call, ChatCompletionMessageToolCall)
+                tool_name: str = tool_call.function.name
+                tool_args: dict = json.loads(tool_call.function.arguments)
+                tool_call_id: str = tool_call.id or ""
+                tool_args = self.update_tool_args(
+                    tool_name, tool_args, messages, state, **kwargs
+                )
+                tool_message: Message = await self.call_tool(
+                    tool_name, tool_args, tool_call_id
+                )
+                tool_messages.append(tool_message)
+            return tool_messages, state
+        else:
+            return await self.response_to_non_tool_call(messages, state, **kwargs)
diff --git a/verifiers/envs/tool_env.py b/verifiers/envs/tool_env.py
index 8dd559d71..2d27ddc60 100644
--- a/verifiers/envs/tool_env.py
+++ b/verifiers/envs/tool_env.py
@@ -13,11 +13,13 @@ def __init__(
         tools: list[Callable] | None = None,
         max_turns: int = 10,
         error_formatter: Callable[[Exception], str] = lambda e: f"{str(e)}",
+        disallow_non_tool_responses: bool = True,
         **kwargs,
     ):
         self.tools = tools or []
         self.max_turns = max_turns
         self.error_formatter = error_formatter
+        self.disallow_non_tool_responses = disallow_non_tool_responses
         self.oai_tools = [convert_func_to_oai_tool(tool) for tool in self.tools]
         self.tool_map = {
             getattr(tool, "__name__", tool.__class__.__name__): tool
@@ -74,20 +76,35 @@ async def env_response(
         self, messages: Messages, state: State, **kwargs
     ) -> tuple[Messages, State]:
         assert isinstance(messages, list)
-        assert "tool_calls" in messages[-1]
-        tool_messages = []
-        for tool_call in messages[-1]["tool_calls"]:
-            match tool_call:
-                case ChatCompletionMessageToolCall():
-                    tool_name: str = tool_call.function.name
-                    tool_args: dict = json.loads(tool_call.function.arguments)
-                    tool_call_id: str = tool_call.id or ""
-                case _:
-                    tool_name: str = tool_call["function"]["name"]
-                    tool_args: dict = json.loads(tool_call["function"]["arguments"])
-                    tool_call_id: str = tool_call["id"]
-            tool_message: Message = await self.call_tool(
-                tool_name, tool_args, tool_call_id
-            )
-            tool_messages.append(tool_message)
-        return tool_messages, state
+        if self.disallow_non_tool_responses:
+            assert "tool_calls" in messages[-1]
+        if "tool_calls" in messages[-1]:
+            tool_messages = []
+            for tool_call in messages[-1]["tool_calls"]:
+                match tool_call:
+                    case ChatCompletionMessageToolCall():
+                        tool_name: str = tool_call.function.name
+                        tool_args: dict = json.loads(tool_call.function.arguments)
+                        tool_call_id: str = tool_call.id or ""
+                    case _:
+                        tool_name: str = tool_call["function"]["name"]
+                        tool_args: dict = json.loads(tool_call["function"]["arguments"])
+                        tool_call_id: str = tool_call["id"]
+                tool_message: Message = await self.call_tool(
+                    tool_name, tool_args, tool_call_id
+                )
+                tool_messages.append(tool_message)
+            return tool_messages, state
+        else:
+            return await self.response_to_non_tool_call(messages, state, **kwargs)
+
+    async def response_to_non_tool_call(
+        self, messages: Messages, state: State, **kwargs
+    ) -> tuple[Messages, State]:
+        """
+        If disallow_non_tool_responses=False, we might want to add "nudging"
+        as user messages to remind the model what to do. By default, we don't
+        do anything.
+        """
+        assert not self.disallow_non_tool_responses
+        return [], state
diff --git a/verifiers/utils/tool_utils.py b/verifiers/utils/tool_utils.py
index f8c2fb275..ba21a2c2d 100644
--- a/verifiers/utils/tool_utils.py
+++ b/verifiers/utils/tool_utils.py
@@ -1,4 +1,5 @@
-from typing import Any
+import inspect
+from typing import Any, Callable
 
 from agents.function_schema import function_schema
 from openai.types.chat import ChatCompletionFunctionToolParam
@@ -19,3 +20,62 @@ def convert_func_to_oai_tool(func: Any) -> ChatCompletionFunctionToolParam:
             "strict": True,
         },
     }
+
+
+def build_schema_only_tool(tool: Callable, args_to_skip: list[str]) -> Callable:
+    """
+    Convert a function to an OpenAI/Pydantic-compatible stub tool, excluding specified parameters.
+
+    Args:
+        tool: The function to convert
+        args_to_skip: List of parameter names to exclude from the schema
+
+    Returns:
+        ``tool`` itself when ``args_to_skip`` is empty; otherwise a stub with the
+        same name, docstring, and filtered signature that raises if invoked.
+    """
+    if not args_to_skip:
+        return tool
+
+    original_signature = inspect.signature(tool)
+
+    missing_args = [
+        name for name in args_to_skip if name not in original_signature.parameters
+    ]
+    assert not missing_args, (
+        f"{getattr(tool, '__name__', tool.__class__.__name__)} does not define {missing_args}."
+    )
+
+    filtered_parameters = [
+        parameter
+        for name, parameter in original_signature.parameters.items()
+        if name not in args_to_skip
+    ]
+    schema_signature = original_signature.replace(parameters=filtered_parameters)
+
+    tool_annotations = dict(getattr(tool, "__annotations__", {}))
+    # Skipped parameters may be unannotated or listed twice; tolerate missing keys.
+    for arg in args_to_skip:
+        tool_annotations.pop(arg, None)
+
+    if inspect.iscoroutinefunction(tool):
+
+        async def schema_stub(*args: Any, **kwargs: Any) -> Any:
+            raise RuntimeError(
+                "Schema-only stub created for tool registration; this callable should not be invoked."
+            )
+    else:
+
+        def schema_stub(*args: Any, **kwargs: Any) -> Any:
+            raise RuntimeError(
+                "Schema-only stub created for tool registration; this callable should not be invoked."
+            )
+
+    schema_stub.__name__ = getattr(tool, "__name__", tool.__class__.__name__)
+    schema_stub.__qualname__ = getattr(tool, "__qualname__", schema_stub.__name__)
+    schema_stub.__module__ = getattr(tool, "__module__", schema_stub.__module__)
+    schema_stub.__doc__ = getattr(tool, "__doc__", schema_stub.__doc__)
+    schema_stub.__annotations__ = tool_annotations
+    schema_stub.__signature__ = schema_signature  # type: ignore[attr-defined]
+
+    return schema_stub