Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions adk/agent_tool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
/*
* Copyright 2025 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package adk

import (
"context"
"errors"

"github.com/bytedance/sonic"

"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
)

var (
defaultAgentToolParam = schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
"request": {
Desc: "request to be processed",
Required: true,
Type: schema.String,
},
})
)

type agentTool struct {
agent Agent

fullChatHistoryAsInput bool
}

func (at *agentTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
var param *schema.ParamsOneOf
if !at.fullChatHistoryAsInput {
param = defaultAgentToolParam
}

return &schema.ToolInfo{
Name: at.agent.Name(ctx),
Desc: at.agent.Description(ctx),
ParamsOneOf: param,
}, nil
}

func genTransferMessages(agentName string) []Message {
tooCall := schema.ToolCall{Function: schema.FunctionCall{Name: TransferToAgentToolName, Arguments: agentName}}

assistantMessage := schema.AssistantMessage("", []schema.ToolCall{tooCall})
toolMessage := schema.ToolMessage(transferToAgentToolOutput(agentName), "", schema.WithToolName(TransferToAgentToolName))

return []Message{
rewriteMessage(assistantMessage, agentName),
rewriteMessage(toolMessage, agentName),
}
}

func (at *agentTool) InvokableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) {
var input []Message
if at.fullChatHistoryAsInput {
history, err := getReactChatHistory(ctx)
if err != nil {
return "", err
}

input = history
} else {
type request struct {
Request string `json:"request"`
}

req := &request{}
err := sonic.UnmarshalString(argumentsInJSON, req)
if err != nil {
return "", err
}

input = []Message{
schema.UserMessage(req.Request),
}
}

events := NewRunner(ctx, RunnerConfig{EnableStreaming: false}).Run(ctx, at.agent, input)
var lastEvent *AgentEvent
for {
event, ok := events.Next()
if !ok {
break
}

if event.Err != nil {
return "", event.Err
}

lastEvent = event
}

if lastEvent == nil {
return "", errors.New("no event returned")
}

var ret string
if output := lastEvent.GetModelOutput(); output != nil {
msg, e := output.Response.GetMessage()
if e != nil {
return "", e
}

ret = msg.Content
}

if output := lastEvent.GetToolCallOutput(); output != nil {
msg, e := output.Response.GetMessage()
if e != nil {
return "", e
}

ret = msg.Content
}

return ret, nil
}

type AgentToolOptions struct {
fullChatHistoryAsInput bool
}

type AgentToolOption func(*AgentToolOptions)

func WithFullChatHistoryAsInput() AgentToolOption {
return func(options *AgentToolOptions) {
options.fullChatHistoryAsInput = true
}
}

func getReactChatHistory(ctx context.Context) ([]Message, error) {
var messages []Message
var agentName string
err := compose.ProcessState(ctx, func(ctx context.Context, st *State) error {
messages = st.Messages
agentName = st.AgentName
return nil
})

messages = messages[:len(messages)-1] // remove the last assistant message, which is the tool call message
history := make([]Message, 0, len(messages))
for _, msg := range messages {
if msg.Role == schema.System {
continue
}

if msg.Role == schema.Assistant || msg.Role == schema.Tool {
msg = rewriteMessage(msg, agentName)
}

history = append(history, msg)
}

history = append(history, genTransferMessages(agentName)...)

return history, err
}

func NewAgentTool(_ context.Context, agent Agent, options ...AgentToolOption) tool.BaseTool {
opts := &AgentToolOptions{}
for _, opt := range options {
opt(opts)
}

return &agentTool{agent: agent, fullChatHistoryAsInput: opts.fullChatHistoryAsInput}
}
191 changes: 191 additions & 0 deletions adk/agent_tool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/*
* Copyright 2025 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package adk

import (
"context"
"testing"

"github.com/stretchr/testify/assert"

"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema"
)

// mockAgent implements the Agent interface for testing
type mockAgentForTool struct {
name string
description string
responses []*AgentEvent
}

func (a *mockAgentForTool) Name(_ context.Context) string {
return a.name
}

func (a *mockAgentForTool) Description(_ context.Context) string {
return a.description
}

func (a *mockAgentForTool) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] {
iterator, generator := NewAsyncIteratorPair[*AgentEvent]()

go func() {
defer generator.Close()

for _, event := range a.responses {
generator.Send(event)

// If the event has an Exit action, stop sending events
if event.Action != nil && event.Action.Exit {
break
}
}
}()

return iterator
}

func newMockAgentForTool(name, description string, responses []*AgentEvent) *mockAgentForTool {
return &mockAgentForTool{
name: name,
description: description,
responses: responses,
}
}

func TestAgentTool_Info(t *testing.T) {
// Create a mock agent
mockAgent_ := newMockAgentForTool("TestAgent", "Test agent description", nil)

// Create an agentTool with the mock agent
agentTool_ := NewAgentTool(context.Background(), mockAgent_)

// Test the Info method
ctx := context.Background()
info, err := agentTool_.Info(ctx)

// Verify results
assert.NoError(t, err)
assert.NotNil(t, info)
assert.Equal(t, "TestAgent", info.Name)
assert.Equal(t, "Test agent description", info.Desc)
assert.NotNil(t, info.ParamsOneOf)
}

func TestAgentTool_InvokableRun(t *testing.T) {
// Create a context
ctx := context.Background()

// Test cases
tests := []struct {
name string
agentResponses []*AgentEvent
request string
expectedOutput string
expectError bool
}{
{
name: "successful model response",
agentResponses: []*AgentEvent{
{
AgentName: "TestAgent",
Output: &AgentOutput{
ModelResponse: &ModelOutput{
Response: &MessageVariant{
IsStreaming: false,
Message: schema.AssistantMessage("Test response", nil),
},
},
},
},
},
request: `{"request":"Test request"}`,
expectedOutput: "Test response",
expectError: false,
},
{
name: "successful tool call response",
agentResponses: []*AgentEvent{
{
AgentName: "TestAgent",
Output: &AgentOutput{
ToolCallResponse: &ToolCallOutput{
Name: "TestTool",
ToolCallID: "test-id",
Response: &MessageVariant{
IsStreaming: false,
Message: schema.ToolMessage("Tool response", "test-id"),
},
},
},
},
},
request: `{"request":"Test tool request"}`,
expectedOutput: "Tool response",
expectError: false,
},
{
name: "invalid request JSON",
agentResponses: nil,
request: `invalid json`,
expectedOutput: "",
expectError: true,
},
{
name: "no events returned",
agentResponses: []*AgentEvent{},
request: `{"request":"Test request"}`,
expectedOutput: "",
expectError: true,
},
{
name: "error in event",
agentResponses: []*AgentEvent{
{
AgentName: "TestAgent",
Err: assert.AnError,
},
},
request: `{"request":"Test request"}`,
expectedOutput: "",
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a mock agent with the test responses
mockAgent_ := newMockAgentForTool("TestAgent", "Test agent description", tt.agentResponses)

// Create an agentTool with the mock agent
agentTool_ := NewAgentTool(ctx, mockAgent_)

// Call InvokableRun

output, err := agentTool_.(tool.InvokableTool).InvokableRun(ctx, tt.request)

// Verify results
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedOutput, output)
}
})
}
}
Loading
Loading