Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion tests/test_tool_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Tests for the tool_utils module."""

from verifiers.utils.tool_utils import convert_func_to_oai_tool
from typing import Optional

from verifiers.types import Message
from verifiers.envs.stateful_tool_env import StatefulToolEnv
from verifiers.utils.tool_utils import build_schema_only_tool, convert_func_to_oai_tool


class TestToolUtils:
"""Test cases for the tool_utils module."""
Expand Down Expand Up @@ -177,3 +180,57 @@ def test_func(param1):
"strict": True,
},
}

def test_build_schema_only_tool_skips_complex_args(self):
    """Skipped parameters must be absent from the schema generated from the stub."""

    def complex_tool(
        user_id: str,
        msg: Message,
        payload: list[dict[str, str]],
        env: StatefulToolEnv,
    ) -> str:
        return "ok"

    hidden_args = ["msg", "env"]

    # Build the stub first, then feed it through the regular schema converter.
    stub = build_schema_only_tool(complex_tool, args_to_skip=hidden_args)
    params_schema = convert_func_to_oai_tool(stub)["function"]["parameters"]
    exposed = params_schema["properties"]
    required_names = params_schema.get("required", [])

    # Only the two non-skipped parameters may be exposed.
    assert set(exposed) == {"user_id", "payload"}
    for hidden in hidden_args:
        assert hidden not in exposed
    for hidden in hidden_args:
        assert hidden not in required_names

def test_add_tool_to_stateful_tool_env(
    self, mock_stateful_tool_env: StatefulToolEnv
):
    """Adding a tool with skipped args extends both tool lists and hides those args."""

    def complex_tool(
        user_id: str,
        msg: Message,
        payload: list[dict[str, str]],
        env: StatefulToolEnv,
    ) -> str:
        return "ok"

    hidden_args = ["msg", "env"]
    env = mock_stateful_tool_env
    tools_before = len(env.tools)

    env.add_tool(complex_tool, args_to_skip=["msg", "env"])

    # Exactly one new entry in both the callable list and the schema list.
    expected_count = tools_before + 1
    assert len(env.tools) == expected_count
    assert len(env.oai_tools) == expected_count

    # The most recently registered schema must expose only the non-skipped params.
    params_schema = env.oai_tools[-1]["function"]["parameters"]
    props = params_schema["properties"]
    required = params_schema.get("required", [])
    assert set(props) == {"user_id", "payload"}
    for hidden in hidden_args:
        assert hidden not in props
    for hidden in hidden_args:
        assert hidden not in required

    # The env records which args were skipped, keyed by the tool's name.
    assert env.skipped_args["complex_tool"] == ["msg", "env"]
56 changes: 24 additions & 32 deletions verifiers/envs/stateful_tool_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from verifiers.envs.tool_env import ToolEnv
from verifiers.types import ChatCompletionMessageToolCall, Message, Messages, State
from verifiers.utils.async_utils import maybe_await
from verifiers.utils.tool_utils import convert_func_to_oai_tool
from verifiers.utils.tool_utils import build_schema_only_tool, convert_func_to_oai_tool


class StatefulToolEnv(ToolEnv):
Expand Down Expand Up @@ -38,22 +38,10 @@ def __init__(

def add_tool(self, tool: Callable, args_to_skip: list[str] = []):
self.tools.append(tool)
oai_tool = convert_func_to_oai_tool(tool)
for arg in args_to_skip:
assert "function" in oai_tool
assert "parameters" in oai_tool["function"]
if (
"properties" in oai_tool["function"]["parameters"]
and isinstance(oai_tool["function"]["parameters"]["properties"], dict)
and arg in oai_tool["function"]["parameters"]["properties"]
):
oai_tool["function"]["parameters"]["properties"].pop(arg)
if (
"required" in oai_tool["function"]["parameters"]
and isinstance(oai_tool["function"]["parameters"]["required"], list)
and arg in oai_tool["function"]["parameters"]["required"]
):
oai_tool["function"]["parameters"]["required"].remove(arg)
schema_only_tool = build_schema_only_tool(tool, args_to_skip)
oai_tool = convert_func_to_oai_tool(schema_only_tool)
assert "function" in oai_tool
assert "parameters" in oai_tool["function"]
if self.oai_tools is None:
self.oai_tools = []
self.oai_tools.append(oai_tool)
Expand Down Expand Up @@ -107,18 +95,22 @@ async def env_response(
self, messages: Messages, state: State, **kwargs
) -> tuple[Messages, State]:
assert isinstance(messages, list)
assert "tool_calls" in messages[-1]
tool_messages = []
for tool_call in messages[-1]["tool_calls"]:
assert isinstance(tool_call, ChatCompletionMessageToolCall)
tool_name: str = tool_call.function.name
tool_args: dict = json.loads(tool_call.function.arguments)
tool_call_id: str = tool_call.id or ""
tool_args = self.update_tool_args(
tool_name, tool_args, messages, state, **kwargs
)
tool_message: Message = await self.call_tool(
tool_name, tool_args, tool_call_id
)
tool_messages.append(tool_message)
return tool_messages, state
if self.disallow_non_tool_responses:
assert "tool_calls" in messages[-1]
if "tool_calls" in messages[-1]:
tool_messages = []
for tool_call in messages[-1]["tool_calls"]:
assert isinstance(tool_call, ChatCompletionMessageToolCall)
tool_name: str = tool_call.function.name
tool_args: dict = json.loads(tool_call.function.arguments)
tool_call_id: str = tool_call.id or ""
tool_args = self.update_tool_args(
tool_name, tool_args, messages, state, **kwargs
)
tool_message: Message = await self.call_tool(
tool_name, tool_args, tool_call_id
)
tool_messages.append(tool_message)
return tool_messages, state
else:
return await self.response_to_non_tool_call(messages, state, **kwargs)
51 changes: 34 additions & 17 deletions verifiers/envs/tool_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ def __init__(
tools: list[Callable] | None = None,
max_turns: int = 10,
error_formatter: Callable[[Exception], str] = lambda e: f"{str(e)}",
disallow_non_tool_responses: bool = True,
**kwargs,
):
self.tools = tools or []
self.max_turns = max_turns
self.error_formatter = error_formatter
self.disallow_non_tool_responses = disallow_non_tool_responses
self.oai_tools = [convert_func_to_oai_tool(tool) for tool in self.tools]
self.tool_map = {
getattr(tool, "__name__", tool.__class__.__name__): tool
Expand Down Expand Up @@ -74,20 +76,35 @@ async def env_response(
self, messages: Messages, state: State, **kwargs
) -> tuple[Messages, State]:
assert isinstance(messages, list)
assert "tool_calls" in messages[-1]
tool_messages = []
for tool_call in messages[-1]["tool_calls"]:
match tool_call:
case ChatCompletionMessageToolCall():
tool_name: str = tool_call.function.name
tool_args: dict = json.loads(tool_call.function.arguments)
tool_call_id: str = tool_call.id or ""
case _:
tool_name: str = tool_call["function"]["name"]
tool_args: dict = json.loads(tool_call["function"]["arguments"])
tool_call_id: str = tool_call["id"]
tool_message: Message = await self.call_tool(
tool_name, tool_args, tool_call_id
)
tool_messages.append(tool_message)
return tool_messages, state
if self.disallow_non_tool_responses:
assert "tool_calls" in messages[-1]
if "tool_calls" in messages[-1]:
tool_messages = []
for tool_call in messages[-1]["tool_calls"]:
match tool_call:
case ChatCompletionMessageToolCall():
tool_name: str = tool_call.function.name
tool_args: dict = json.loads(tool_call.function.arguments)
tool_call_id: str = tool_call.id or ""
case _:
tool_name: str = tool_call["function"]["name"]
tool_args: dict = json.loads(tool_call["function"]["arguments"])
tool_call_id: str = tool_call["id"]
tool_message: Message = await self.call_tool(
tool_name, tool_args, tool_call_id
)
tool_messages.append(tool_message)
return tool_messages, state
else:
return await self.response_to_non_tool_call(messages, state, **kwargs)

async def response_to_non_tool_call(
    self, messages: Messages, state: State, **kwargs
) -> tuple[Messages, State]:
    """
    Produce the environment's reply to a turn that contains no tool calls.

    Only valid when ``disallow_non_tool_responses`` is False. The default
    implementation adds no messages and returns ``state`` as received;
    subclasses may override it to inject "nudging" user messages that
    remind the model what to do.
    """
    # This hook must never run while non-tool responses are disallowed.
    assert self.disallow_non_tool_responses is False
    no_reply = []
    return no_reply, state
57 changes: 56 additions & 1 deletion verifiers/utils/tool_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any
import inspect
from typing import Any, Callable

from agents.function_schema import function_schema
from openai.types.chat import ChatCompletionFunctionToolParam
Expand All @@ -19,3 +20,57 @@ def convert_func_to_oai_tool(func: Any) -> ChatCompletionFunctionToolParam:
"strict": True,
},
}


def build_schema_only_tool(tool: Callable, args_to_skip: list[str]) -> Callable:
    """
    Build a stand-in for ``tool`` whose signature omits ``args_to_skip``.

    The stand-in copies the tool's name, qualified name, module and docstring,
    and carries a signature and annotations with the skipped parameters
    removed, so code that introspects the callable to build a schema never
    sees them. The stand-in is for introspection only: calling it raises
    ``RuntimeError``.

    Args:
        tool: The callable to mirror.
        args_to_skip: Names of parameters to hide from the signature. When
            empty, ``tool`` itself is returned unchanged.

    Returns:
        ``tool`` when nothing is skipped, otherwise a stub (async if ``tool``
        is a coroutine function) exposing the filtered signature.

    Raises:
        AssertionError: If a name in ``args_to_skip`` is not a parameter of
            ``tool``.
    """
    if not args_to_skip:
        return tool

    # Not every callable has __name__ (e.g. callable instances), so fall back
    # to the class name both for the error message and for the stub's name.
    tool_name = getattr(tool, "__name__", tool.__class__.__name__)
    skipped = set(args_to_skip)
    original_signature = inspect.signature(tool)

    missing_args = [
        name for name in args_to_skip if name not in original_signature.parameters
    ]
    if missing_args:
        # Raised explicitly (rather than via `assert`) so the check still runs
        # under `python -O`; the exception type is kept for existing callers.
        raise AssertionError(f"{tool_name} does not define {missing_args}.")

    filtered_parameters = [
        parameter
        for name, parameter in original_signature.parameters.items()
        if name not in skipped
    ]
    schema_signature = original_signature.replace(parameters=filtered_parameters)

    # Filter rather than pop: a skipped parameter may be unannotated, and some
    # callables expose no __annotations__ at all.
    tool_annotations = {
        name: annotation
        for name, annotation in getattr(tool, "__annotations__", {}).items()
        if name not in skipped
    }

    invoke_error = (
        "Schema-only stub created for tool registration; this callable should not be invoked."
    )

    # Preserve the sync/async nature of the original for anything that
    # inspects the stub with inspect.iscoroutinefunction.
    if inspect.iscoroutinefunction(tool):

        async def schema_stub(*args: Any, **kwargs: Any) -> Any:
            raise RuntimeError(invoke_error)
    else:

        def schema_stub(*args: Any, **kwargs: Any) -> Any:
            raise RuntimeError(invoke_error)

    schema_stub.__name__ = tool_name
    schema_stub.__qualname__ = getattr(tool, "__qualname__", schema_stub.__name__)
    schema_stub.__module__ = getattr(tool, "__module__", schema_stub.__module__)
    schema_stub.__doc__ = getattr(tool, "__doc__", schema_stub.__doc__)
    schema_stub.__annotations__ = tool_annotations
    # inspect.signature() honours __signature__, so introspection reports the
    # filtered parameters instead of the stub's (*args, **kwargs).
    schema_stub.__signature__ = schema_signature  # type: ignore[attr-defined]

    return schema_stub