diff --git a/_typos.toml b/_typos.toml index e600bca9..f73942fd 100644 --- a/_typos.toml +++ b/_typos.toml @@ -7,6 +7,7 @@ invokable = "invokable" InvokableLambda = "InvokableLambda" InvokableRun = "InvokableRun" typ = "typ" +byted = "byted" [files] extend-exclude = ["go.mod", "go.sum", "check_branch_name.sh"] diff --git a/adk/agent_tool.go b/adk/agent_tool.go new file mode 100644 index 00000000..25c00600 --- /dev/null +++ b/adk/agent_tool.go @@ -0,0 +1,249 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "fmt" + + "github.com/bytedance/sonic" + + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +var ( + defaultAgentToolParam = schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "request": { + Desc: "request to be processed", + Required: true, + Type: schema.String, + }, + }) +) + +type agentToolOptions struct { + agentName string + opts []AgentRunOption +} + +func withAgentToolOptions(agentName string, opts []AgentRunOption) tool.Option { + return tool.WrapImplSpecificOptFn(func(opt *agentToolOptions) { + opt.agentName = agentName + opt.opts = opts + }) +} + +func getOptionsByAgentName(agentName string, opts []tool.Option) []AgentRunOption { + var ret []AgentRunOption + for _, opt := range opts { + o := tool.GetImplSpecificOptions[agentToolOptions](nil, opt) + if o != nil && o.agentName == agentName { + ret = append(ret, o.opts...) + } + } + return ret +} + +type agentTool struct { + agent Agent + + fullChatHistoryAsInput bool +} + +func (at *agentTool) Info(ctx context.Context) (*schema.ToolInfo, error) { + var param *schema.ParamsOneOf + if !at.fullChatHistoryAsInput { + param = defaultAgentToolParam + } + + return &schema.ToolInfo{ + Name: at.agent.Name(ctx), + Desc: at.agent.Description(ctx), + ParamsOneOf: param, + }, nil +} + +func (at *agentTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + var intData *agentToolInterruptInfo + var bResume bool + err := compose.ProcessState(ctx, func(ctx context.Context, s *State) error { + toolCallID := compose.GetToolCallID(ctx) + intData, bResume = s.AgentToolInterruptData[toolCallID] + if bResume { + delete(s.AgentToolInterruptData, toolCallID) + } + return nil + }) + if err != nil { + // cannot resume + bResume = false + } + + var ms *mockStore + var iter *AsyncIterator[*AgentEvent] + if bResume { + ms = newResumeStore(intData.Data) + + iter, err = newInvokableAgentToolRunner(at.agent, ms).Resume(ctx, mockCheckPointID, getOptionsByAgentName(at.agent.Name(ctx), opts)...) + if err != nil { + return "", err + } + } else { + ms = newEmptyStore() + var input []Message + if at.fullChatHistoryAsInput { + history, err := getReactChatHistory(ctx, at.agent.Name(ctx)) + if err != nil { + return "", err + } + + input = history + } else { + type request struct { + Request string `json:"request"` + } + + req := &request{} + err := sonic.UnmarshalString(argumentsInJSON, req) + if err != nil { + return "", err + } + input = []Message{ + schema.UserMessage(req.Request), + } + } + + iter = newInvokableAgentToolRunner(at.agent, ms).Run(ctx, input, append(getOptionsByAgentName(at.agent.Name(ctx), opts), WithCheckPointID(mockCheckPointID))...) + } + + var lastEvent *AgentEvent + for { + event, ok := iter.Next() + if !ok { + break + } + + if event.Err != nil { + return "", event.Err + } + + lastEvent = event + } + + if lastEvent != nil && lastEvent.Action != nil && lastEvent.Action.Interrupted != nil { + data, existed, err_ := ms.Get(ctx, mockCheckPointID) + if err_ != nil { + return "", fmt.Errorf("failed to get interrupt info: %w", err_) + } + if !existed { + return "", fmt.Errorf("interrupt has happened, but cannot find interrupt info") + } + err = compose.ProcessState(ctx, func(ctx context.Context, st *State) error { + st.AgentToolInterruptData[compose.GetToolCallID(ctx)] = &agentToolInterruptInfo{ + LastEvent: lastEvent, + Data: data, + } + return nil + }) + if err != nil { + return "", fmt.Errorf("failed to save agent tool checkpoint to state: %w", err) + } + return "", compose.InterruptAndRerun + } + + if lastEvent == nil { + return "", errors.New("no event returned") + } + + var ret string + if lastEvent.Output != nil { + if output := lastEvent.Output.MessageOutput; output != nil { + if !output.IsStreaming { + ret = output.Message.Content + } else { + msg, err := schema.ConcatMessageStream(output.MessageStream) + if err != nil { + return "", err + } + ret = msg.Content + } + } + } + + return ret, nil +} + +type AgentToolOptions struct { + fullChatHistoryAsInput bool +} + +type AgentToolOption func(*AgentToolOptions) + +func WithFullChatHistoryAsInput() AgentToolOption { + return func(options *AgentToolOptions) { + options.fullChatHistoryAsInput = true + } +} + +func getReactChatHistory(ctx context.Context, destAgentName string) ([]Message, error) { + var messages []Message + var agentName string + err := compose.ProcessState(ctx, func(ctx context.Context, st *State) error { + messages = st.Messages + agentName = st.AgentName + return nil + }) + + messages = messages[:len(messages)-1] // remove the last assistant message, which is the tool call message + history := make([]Message, 0, len(messages)) + history = append(history, messages...) + a, t := GenTransferMessages(ctx, destAgentName) + history = append(history, a, t) + for _, msg := range messages { + if msg.Role == schema.System { + continue + } + + if msg.Role == schema.Assistant || msg.Role == schema.Tool { + msg = rewriteMessage(msg, agentName) + } + + history = append(history, msg) + } + + return history, err +} + +func NewAgentTool(_ context.Context, agent Agent, options ...AgentToolOption) tool.BaseTool { + opts := &AgentToolOptions{} + for _, opt := range options { + opt(opts) + } + + return &agentTool{agent: agent, fullChatHistoryAsInput: opts.fullChatHistoryAsInput} +} + +func newInvokableAgentToolRunner(agent Agent, store compose.CheckPointStore) *Runner { + return &Runner{ + a: agent, + enableStreaming: false, + store: store, + } +} diff --git a/adk/agent_tool_test.go b/adk/agent_tool_test.go new file mode 100644 index 00000000..5a744a6d --- /dev/null +++ b/adk/agent_tool_test.go @@ -0,0 +1,204 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +// mockAgent implements the Agent interface for testing +type mockAgentForTool struct { + name string + description string + responses []*AgentEvent +} + +func (a *mockAgentForTool) Name(_ context.Context) string { + return a.name +} + +func (a *mockAgentForTool) Description(_ context.Context) string { + return a.description +} + +func (a *mockAgentForTool) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + + go func() { + defer generator.Close() + + for _, event := range a.responses { + generator.Send(event) + + // If the event has an Exit action, stop sending events + if event.Action != nil && event.Action.Exit { + break + } + } + }() + + return iterator +} + +func newMockAgentForTool(name, description string, responses []*AgentEvent) *mockAgentForTool { + return &mockAgentForTool{ + name: name, + description: description, + responses: responses, + } +} + +func TestAgentTool_Info(t *testing.T) { + // Create a mock agent + mockAgent_ := newMockAgentForTool("TestAgent", "Test agent description", nil) + + // Create an agentTool with the mock agent + agentTool_ := NewAgentTool(context.Background(), mockAgent_) + + // Test the Info method + ctx := context.Background() + info, err := agentTool_.Info(ctx) + + // Verify results + assert.NoError(t, err) + assert.NotNil(t, info) + assert.Equal(t, "TestAgent", info.Name) + assert.Equal(t, "Test agent description", info.Desc) + assert.NotNil(t, info.ParamsOneOf) +} + +func TestAgentTool_InvokableRun(t *testing.T) { + // Create a context + ctx := context.Background() + + // Test cases + tests := []struct { + name string + agentResponses []*AgentEvent + request string + expectedOutput string + expectError bool + }{ + { + name: "successful model response", + agentResponses: []*AgentEvent{ + { + AgentName: "TestAgent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: false, + Message: schema.AssistantMessage("Test response", nil), + Role: schema.Assistant, + }, + }, + }, + }, + request: `{"request":"Test request"}`, + expectedOutput: "Test response", + expectError: false, + }, + { + name: "successful tool call response", + agentResponses: []*AgentEvent{ + { + AgentName: "TestAgent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: false, + Message: schema.ToolMessage("Tool response", "test-id"), + Role: schema.Tool, + }, + }, + }, + }, + request: `{"request":"Test tool request"}`, + expectedOutput: "Tool response", + expectError: false, + }, + { + name: "invalid request JSON", + agentResponses: nil, + request: `invalid json`, + expectedOutput: "", + expectError: true, + }, + { + name: "no events returned", + agentResponses: []*AgentEvent{}, + request: `{"request":"Test request"}`, + expectedOutput: "", + expectError: true, + }, + { + name: "error in event", + agentResponses: []*AgentEvent{ + { + AgentName: "TestAgent", + Err: assert.AnError, + }, + }, + request: `{"request":"Test request"}`, + expectedOutput: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a mock agent with the test responses + mockAgent_ := newMockAgentForTool("TestAgent", "Test agent description", tt.agentResponses) + + // Create an agentTool with the mock agent + agentTool_ := NewAgentTool(ctx, mockAgent_) + + // Call InvokableRun + output, err := agentTool_.(tool.InvokableTool).InvokableRun(ctx, tt.request) + + // Verify results + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedOutput, output) + } + }) + } +} + +func buildToolTestGraph(it tool.InvokableTool) compose.Runnable[string, string] { + ctx := context.Background() + g := compose.NewGraph[string, string](compose.WithGenLocalState(func(ctx context.Context) (state *State) { + return &State{} + })) + _ = g.AddLambdaNode("tool node", compose.InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return it.InvokableRun(ctx, input) + })) + _ = g.AddEdge(compose.START, "tool node") + _ = g.AddEdge("tool node", compose.END) + r, err := g.Compile(ctx) + if err != nil { + panic(err) + } + return r +} diff --git a/adk/call_option.go b/adk/call_option.go new file mode 100644 index 00000000..31cf56ba --- /dev/null +++ b/adk/call_option.go @@ -0,0 +1,110 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +type options struct { + sessionValues map[string]any + checkPointID *string + skipTransferMessages bool +} + +// AgentRunOption is the call option for adk Agent. +type AgentRunOption struct { + implSpecificOptFn any + + // specify which Agent can see this AgentRunOption, if empty, all Agents can see this AgentRunOption + agentNames []string +} + +func (o AgentRunOption) DesignateAgent(name ...string) AgentRunOption { + o.agentNames = append(o.agentNames, name...) + return o +} + +func getCommonOptions(base *options, opts ...AgentRunOption) *options { + if base == nil { + base = &options{} + } + + return GetImplSpecificOptions[options](base, opts...) +} + +func WithSessionValues(v map[string]any) AgentRunOption { + return WrapImplSpecificOptFn(func(o *options) { + o.sessionValues = v + }) +} + +func WithSkipTransferMessages() AgentRunOption { + return WrapImplSpecificOptFn(func(t *options) { + t.skipTransferMessages = true + }) +} + +// WrapImplSpecificOptFn is the option to wrap the implementation specific option function. +func WrapImplSpecificOptFn[T any](optFn func(*T)) AgentRunOption { + return AgentRunOption{ + implSpecificOptFn: optFn, + } +} + +// GetImplSpecificOptions extract the implementation specific options from AgentRunOption list, optionally providing a base options with default values. +// e.g. +// +// myOption := &MyOption{ +// Field1: "default_value", +// } +// +// myOption := model.GetImplSpecificOptions(myOption, opts...) +func GetImplSpecificOptions[T any](base *T, opts ...AgentRunOption) *T { + if base == nil { + base = new(T) + } + + for i := range opts { + opt := opts[i] + if opt.implSpecificOptFn != nil { + optFn, ok := opt.implSpecificOptFn.(func(*T)) + if ok { + optFn(base) + } + } + } + + return base +} + +func filterOptions(agentName string, opts []AgentRunOption) []AgentRunOption { + if len(opts) == 0 { + return nil + } + var filteredOpts []AgentRunOption + for i := range opts { + opt := opts[i] + if len(opt.agentNames) == 0 { + filteredOpts = append(filteredOpts, opt) + continue + } + for j := range opt.agentNames { + if opt.agentNames[j] == agentName { + filteredOpts = append(filteredOpts, opt) + break + } + } + } + return filteredOpts +} diff --git a/adk/call_option_test.go b/adk/call_option_test.go new file mode 100644 index 00000000..0deee265 --- /dev/null +++ b/adk/call_option_test.go @@ -0,0 +1,42 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" +) + +type mockAgentForOption struct { + opts []AgentRunOption + + options *options +} + +func (m *mockAgentForOption) Name(ctx context.Context) string { + return "agent_1" +} + +func (m *mockAgentForOption) Description(ctx context.Context) string { + return "" +} + +func (m *mockAgentForOption) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + m.opts = opts + m.options = getCommonOptions(&options{}, opts...) + + return nil +} diff --git a/adk/chatmodel.go b/adk/chatmodel.go new file mode 100644 index 00000000..4176790f --- /dev/null +++ b/adk/chatmodel.go @@ -0,0 +1,733 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "bytes" + "context" + "encoding/gob" + "errors" + "fmt" + "math" + "runtime/debug" + "sync" + "sync/atomic" + + "github.com/bytedance/sonic" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/internal/safe" + "github.com/cloudwego/eino/schema" + ub "github.com/cloudwego/eino/utils/callbacks" +) + +type chatModelAgentRunOptions struct { + // run + chatModelOptions []model.Option + toolOptions []tool.Option + agentToolOptions map[ /*tool name*/ string][]AgentRunOption // todo: map or list? + + // resume + historyModifier func(context.Context, []Message) []Message +} + +func WithChatModelOptions(opts []model.Option) AgentRunOption { + return WrapImplSpecificOptFn(func(t *chatModelAgentRunOptions) { + t.chatModelOptions = opts + }) +} + +func WithToolOptions(opts []tool.Option) AgentRunOption { + return WrapImplSpecificOptFn(func(t *chatModelAgentRunOptions) { + t.toolOptions = opts + }) +} + +func WithAgentToolRunOptions(opts map[string] /*tool name*/ []AgentRunOption) AgentRunOption { + return WrapImplSpecificOptFn(func(t *chatModelAgentRunOptions) { + t.agentToolOptions = opts + }) +} + +func WithHistoryModifier(f func(context.Context, []Message) []Message) AgentRunOption { + return WrapImplSpecificOptFn(func(t *chatModelAgentRunOptions) { + t.historyModifier = f + }) +} + +type ToolsConfig struct { + compose.ToolsNodeConfig + + // ReturnDirectly specifies tools that cause the agent to return immediately when called. + // If multiple listed tools are called simultaneously, only the first one triggers the return. + // The map keys are tool names indicate whether the tool should trigger immediate return. + ReturnDirectly map[string]bool +} + +// GenModelInput transforms agent instructions and input into a format suitable for the model. +type GenModelInput func(ctx context.Context, instruction string, input *AgentInput) ([]Message, error) + +func defaultGenModelInput(ctx context.Context, instruction string, input *AgentInput) ([]Message, error) { + msgs := make([]Message, 0, len(input.Messages)+1) + + if instruction != "" { + sp := schema.SystemMessage(instruction) + + vs := GetSessionValues(ctx) + if len(vs) > 0 { + ct := prompt.FromMessages(schema.FString, sp) + ms, err := ct.Format(ctx, vs) + if err != nil { + return nil, err + } + + sp = ms[0] + } + + msgs = append(msgs, sp) + } + + msgs = append(msgs, input.Messages...) + + return msgs, nil +} + +type ChatModelAgentConfig struct { + // Name of the agent. Better be unique across all agents. + Name string + // Description of the agent's capabilities. + // Helps other agents determine whether to transfer tasks to this agent. + Description string + // Instruction used as the system prompt for this agent. + // Optional. If empty, no system prompt will be used. + // Supports f-string placeholders for session values in default GenModelInput, for example: + // "You are a helpful assistant. The current time is {Time}. The current user is {User}." + // These placeholders will be replaced with session values for "Time" and "User". + Instruction string + + Model model.ToolCallingChatModel + + ToolsConfig ToolsConfig + + // GenModelInput transforms instructions and input messages into the model's input format. + // Optional. Defaults to defaultGenModelInput which combines instruction and messages. + GenModelInput GenModelInput + + // Exit defines the tool used to terminate the agent process. + // Optional. If nil, no Exit Action will be generated. + // You can use the provided 'ExitTool' implementation directly. + Exit tool.BaseTool + + // OutputKey stores the agent's response in the session. + // Optional. When set, stores output via AddSessionValue(ctx, outputKey, msg.Content). + OutputKey string + + // MaxIterations defines the upper limit of ChatModel generation cycles. + // The agent will terminate with an error if this limit is exceeded. + // Optional. Defaults to 20. + MaxIterations int +} + +type ChatModelAgent struct { + name string + description string + instruction string + + model model.ToolCallingChatModel + toolsConfig ToolsConfig + + genModelInput GenModelInput + + outputKey string + maxIterations int + + subAgents []Agent + parentAgent Agent + + disallowTransferToParent bool + + exit tool.BaseTool + + // runner + once sync.Once + run runFunc + frozen uint32 +} + +type runFunc func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *mockStore, opts ...compose.Option) + +var registerInternalTypeOnce sync.Once + +func NewChatModelAgent(_ context.Context, config *ChatModelAgentConfig) (*ChatModelAgent, error) { + if config.Name == "" { + return nil, errors.New("agent 'Name' is required") + } + if config.Description == "" { + return nil, errors.New("agent 'Description' is required") + } + if config.Model == nil { + return nil, errors.New("agent 'Model' is required") + } + + var err error + registerInternalTypeOnce.Do(func() { + err = compose.RegisterInternalType(func(key string, value any) error { + gob.RegisterName(key, value) + return nil + }) + gob.RegisterName("_eino_message", &schema.Message{}) + gob.RegisterName("_eino_document", &schema.Document{}) + gob.RegisterName("_eino_tool_call", schema.ToolCall{}) + gob.RegisterName("_eino_function_call", schema.FunctionCall{}) + gob.RegisterName("_eino_response_meta", &schema.ResponseMeta{}) + gob.RegisterName("_eino_token_usage", &schema.TokenUsage{}) + gob.RegisterName("_eino_log_probs", &schema.LogProbs{}) + gob.RegisterName("_eino_chat_message_part", schema.ChatMessagePart{}) + gob.RegisterName("_eino_chat_message_image_url", &schema.ChatMessageImageURL{}) + gob.RegisterName("_eino_chat_message_audio_url", &schema.ChatMessageAudioURL{}) + gob.RegisterName("_eino_chat_message_video_url", &schema.ChatMessageVideoURL{}) + gob.RegisterName("_eino_chat_message_file_url", &schema.ChatMessageFileURL{}) + gob.RegisterName("_eino_adk_chat_model_agent_interrupt_info", &ChatModelAgentInterruptInfo{}) + }) + if err != nil { + return nil, err + } + + genInput := defaultGenModelInput + if config.GenModelInput != nil { + genInput = config.GenModelInput + } + + return &ChatModelAgent{ + name: config.Name, + description: config.Description, + instruction: config.Instruction, + model: config.Model, + toolsConfig: config.ToolsConfig, + genModelInput: genInput, + exit: config.Exit, + outputKey: config.OutputKey, + maxIterations: config.MaxIterations, + }, nil +} + +const ( + TransferToAgentToolName = "transfer_to_agent" + TransferToAgentToolDesc = "Transfer the question to another agent." +) + +var ( + toolInfoTransferToAgent = &schema.ToolInfo{ + Name: TransferToAgentToolName, + Desc: TransferToAgentToolDesc, + + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "agent_name": { + Desc: "the name of the agent to transfer to", + Required: true, + Type: schema.String, + }, + }), + } + + ToolInfoExit = &schema.ToolInfo{ + Name: "exit", + Desc: "Exit the agent process and return the final result.", + + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "final_result": { + Desc: "the final result to return", + Required: true, + Type: schema.String, + }, + }), + } +) + +type ExitTool struct{} + +func (et ExitTool) Info(_ context.Context) (*schema.ToolInfo, error) { + return ToolInfoExit, nil +} + +func (et ExitTool) InvokableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { + type exitParams struct { + FinalResult string `json:"final_result"` + } + + params := &exitParams{} + err := sonic.UnmarshalString(argumentsInJSON, params) + if err != nil { + return "", err + } + + err = SendToolGenAction(ctx, "exit", NewExitAction()) + if err != nil { + return "", err + } + + return params.FinalResult, nil +} + +type transferToAgent struct{} + +func (tta transferToAgent) Info(_ context.Context) (*schema.ToolInfo, error) { + return toolInfoTransferToAgent, nil +} + +func transferToAgentToolOutput(destName string) string { + return fmt.Sprintf("successfully transferred to agent [%s]", destName) +} + +func (tta transferToAgent) InvokableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { + type transferParams struct { + AgentName string `json:"agent_name"` + } + + params := &transferParams{} + err := sonic.UnmarshalString(argumentsInJSON, params) + if err != nil { + return "", err + } + + err = SendToolGenAction(ctx, TransferToAgentToolName, NewTransferToAgentAction(params.AgentName)) + if err != nil { + return "", err + } + + return transferToAgentToolOutput(params.AgentName), nil +} + +func (a *ChatModelAgent) Name(_ context.Context) string { + return a.name +} + +func (a *ChatModelAgent) Description(_ context.Context) string { + return a.description +} + +func (a *ChatModelAgent) OnSetSubAgents(_ context.Context, subAgents []Agent) error { + if atomic.LoadUint32(&a.frozen) == 1 { + return errors.New("agent has been frozen after run") + } + + if len(a.subAgents) > 0 { + return errors.New("agent's sub-agents has already been set") + } + + a.subAgents = subAgents + return nil +} + +func (a *ChatModelAgent) OnSetAsSubAgent(_ context.Context, parent Agent) error { + if atomic.LoadUint32(&a.frozen) == 1 { + return errors.New("agent has been frozen after run") + } + + if a.parentAgent != nil { + return errors.New("agent has already been set as a sub-agent of another agent") + } + + a.parentAgent = parent + return nil +} + +func (a *ChatModelAgent) OnDisallowTransferToParent(_ context.Context) error { + if atomic.LoadUint32(&a.frozen) == 1 { + return errors.New("agent has been frozen after run") + } + + a.disallowTransferToParent = true + + return nil +} + +type cbHandler struct { + *AsyncGenerator[*AgentEvent] + agentName string + + enableStreaming bool + store *mockStore +} + +func (h *cbHandler) onChatModelEnd(ctx context.Context, + _ *callbacks.RunInfo, output *model.CallbackOutput) context.Context { + + event := EventFromMessage(output.Message, nil, schema.Assistant, "") + h.Send(event) + return ctx +} + +func (h *cbHandler) onChatModelEndWithStreamOutput(ctx context.Context, + _ *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context { + + cvt := func(in *model.CallbackOutput) (Message, error) { + return in.Message, nil + } + out := schema.StreamReaderWithConvert(output, cvt) + event := EventFromMessage(nil, out, schema.Assistant, "") + h.Send(event) + + return ctx +} + +func (h *cbHandler) onToolEnd(ctx context.Context, + runInfo *callbacks.RunInfo, output *tool.CallbackOutput) context.Context { + + toolCallID := compose.GetToolCallID(ctx) + msg := schema.ToolMessage(output.Response, toolCallID, schema.WithToolName(runInfo.Name)) + event := EventFromMessage(msg, nil, schema.Tool, runInfo.Name) + + action := popToolGenAction(ctx, runInfo.Name) + event.Action = action + + h.Send(event) + + return ctx +} + +func (h *cbHandler) onToolEndWithStreamOutput(ctx context.Context, + runInfo *callbacks.RunInfo, output *schema.StreamReader[*tool.CallbackOutput]) context.Context { + + toolCallID := compose.GetToolCallID(ctx) + cvt := func(in *tool.CallbackOutput) (Message, error) { + return schema.ToolMessage(in.Response, toolCallID), nil + } + out := schema.StreamReaderWithConvert(output, cvt) + event := EventFromMessage(nil, out, schema.Tool, runInfo.Name) + h.Send(event) + + return ctx +} + +type ChatModelAgentInterruptInfo struct { // replace temp info by info when save the data + Info *compose.InterruptInfo + Data []byte +} + +func (h *cbHandler) onGraphError(ctx context.Context, + _ *callbacks.RunInfo, err error) context.Context { + + info, ok := compose.ExtractInterruptInfo(err) + if !ok { + h.Send(&AgentEvent{Err: err}) + return ctx + } + + data, existed, err := h.store.Get(ctx, mockCheckPointID) + if err != nil { + h.Send(&AgentEvent{AgentName: h.agentName, Err: fmt.Errorf("failed to get interrupt info: %w", err)}) + return ctx + } + if !existed { + h.Send(&AgentEvent{AgentName: h.agentName, Err: fmt.Errorf("interrupt has happened, but cannot find interrupt info")}) + return ctx + } + h.Send(&AgentEvent{AgentName: h.agentName, Action: &AgentAction{ + Interrupted: &InterruptInfo{ + Data: &ChatModelAgentInterruptInfo{Data: data, Info: info}, + }, + }}) + + return ctx +} + +func genReactCallbacks(agentName string, + generator *AsyncGenerator[*AgentEvent], + enableStreaming bool, + store *mockStore) compose.Option { + + h := &cbHandler{AsyncGenerator: generator, agentName: agentName, store: store, enableStreaming: enableStreaming} + + cmHandler := &ub.ModelCallbackHandler{ + OnEnd: h.onChatModelEnd, + OnEndWithStreamOutput: h.onChatModelEndWithStreamOutput, + } + toolHandler := &ub.ToolCallbackHandler{ + OnEnd: h.onToolEnd, + OnEndWithStreamOutput: h.onToolEndWithStreamOutput, + } + graphHandler := callbacks.NewHandlerBuilder().OnErrorFn(h.onGraphError).Build() + + cb := ub.NewHandlerHelper().ChatModel(cmHandler).Tool(toolHandler).Graph(graphHandler).Handler() + + return compose.WithCallbacks(cb) +} + +func setOutputToSession(ctx context.Context, msg Message, msgStream MessageStream, outputKey string) error { + if msg != nil { + AddSessionValue(ctx, outputKey, msg.Content) + return nil + } + + concatenated, err := schema.ConcatMessageStream(msgStream) + if err != nil { + return err + } + + AddSessionValue(ctx, outputKey, concatenated.Content) + return nil +} + +func errFunc(err error) runFunc { + return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *mockStore, _ ...compose.Option) { + generator.Send(&AgentEvent{Err: err}) + } +} + +func (a *ChatModelAgent) buildRunFunc(ctx context.Context) runFunc { + a.once.Do(func() { + instruction := a.instruction + toolsNodeConf := a.toolsConfig.ToolsNodeConfig + returnDirectly := copyMap(a.toolsConfig.ReturnDirectly) + + transferToAgents := a.subAgents + if a.parentAgent != nil && !a.disallowTransferToParent { + transferToAgents = append(transferToAgents, a.parentAgent) + } + + if len(transferToAgents) > 0 { + transferInstruction := genTransferToAgentInstruction(ctx, transferToAgents) + instruction = concatInstructions(instruction, transferInstruction) + + toolsNodeConf.Tools = append(toolsNodeConf.Tools, &transferToAgent{}) + returnDirectly[TransferToAgentToolName] = true + } + + if a.exit != nil { + toolsNodeConf.Tools = append(toolsNodeConf.Tools, a.exit) + exitInfo, err := a.exit.Info(ctx) + if err != nil { + a.run = errFunc(err) + return + } + returnDirectly[exitInfo.Name] = true + } + + if len(toolsNodeConf.Tools) == 0 { + a.run = func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *mockStore, opts ...compose.Option) { + var err error + var msgs []Message + msgs, err = a.genModelInput(ctx, instruction, input) + if err != nil { + generator.Send(&AgentEvent{Err: err}) + return + } + + var msg Message + var msgStream MessageStream + if input.EnableStreaming { + msgStream, err = a.model.Stream(ctx, msgs) // todo: chat model option + } else { + msg, err = a.model.Generate(ctx, msgs) + } + + var event *AgentEvent + if err == nil { + if a.outputKey != "" { + if msgStream != nil { + // copy the stream first because when setting output to session, the stream will be consumed + ss := msgStream.Copy(2) + event = EventFromMessage(msg, ss[1], schema.Assistant, "") + msgStream = ss[0] + } else { + event = EventFromMessage(msg, nil, schema.Assistant, "") + } + // send event asap, because setting output to session will block until stream fully consumed + generator.Send(event) + err = setOutputToSession(ctx, msg, msgStream, a.outputKey) + if err != nil { + generator.Send(&AgentEvent{Err: err}) + } + } else { + event = EventFromMessage(msg, msgStream, schema.Assistant, "") + generator.Send(event) + } + } else { + event = &AgentEvent{Err: err} + generator.Send(event) + } + + generator.Close() + } + + return + } + + // react + conf := &reactConfig{ + model: a.model, + toolsConfig: &toolsNodeConf, + toolsReturnDirectly: returnDirectly, + agentName: a.name, + maxIterations: a.maxIterations, + } + + g, err := newReact(ctx, conf) + if err != nil { + a.run = errFunc(err) + return + } + + a.run = func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *mockStore, opts ...compose.Option) { + var compileOptions []compose.GraphCompileOption + compileOptions = append(compileOptions, + compose.WithGraphName("React"), + compose.WithCheckPointStore(store), + compose.WithSerializer(&gobSerializer{}), + // ensure the graph won't exceed max steps due to max iterations + compose.WithMaxRunSteps(math.MaxInt)) + + runnable, err_ := g.Compile(ctx, compileOptions...) + if err != nil { + generator.Send(&AgentEvent{AgentName: a.name, Err: err}) + return + } + + var msgs []Message + msgs, err_ = a.genModelInput(ctx, instruction, input) + if err_ != nil { + generator.Send(&AgentEvent{Err: err_}) + return + } + + callOpt := genReactCallbacks(a.name, generator, input.EnableStreaming, store) + + var msg Message + var msgStream MessageStream + if input.EnableStreaming { + msgStream, err_ = runnable.Stream(ctx, msgs, append(opts, callOpt)...) + } else { + msg, err_ = runnable.Invoke(ctx, msgs, append(opts, callOpt)...) + } + + if err_ == nil { + if a.outputKey != "" { + err_ = setOutputToSession(ctx, msg, msgStream, a.outputKey) + if err_ != nil { + generator.Send(&AgentEvent{Err: err_}) + } + } else if msgStream != nil { + msgStream.Close() + } + } + + generator.Close() + } + }) + + atomic.StoreUint32(&a.frozen, 1) + + return a.run +} + +func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + run := a.buildRunFunc(ctx) + + co := getComposeOptions(opts) + co = append(co, compose.WithCheckPointID(mockCheckPointID)) + + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer func() { + panicErr := recover() + if panicErr != nil { + e := safe.NewPanicErr(panicErr, debug.Stack()) + generator.Send(&AgentEvent{Err: e}) + } + + generator.Close() + }() + + run(ctx, input, generator, newEmptyStore(), co...) + }() + + return iterator +} + +func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + run := a.buildRunFunc(ctx) + + co := getComposeOptions(opts) + co = append(co, compose.WithCheckPointID(mockCheckPointID)) + + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer func() { + panicErr := recover() + if panicErr != nil { + e := safe.NewPanicErr(panicErr, debug.Stack()) + generator.Send(&AgentEvent{Err: e}) + } + + generator.Close() + }() + + run(ctx, &AgentInput{EnableStreaming: info.EnableStreaming}, generator, newResumeStore(info.Data.(*ChatModelAgentInterruptInfo).Data), co...) + }() + + return iterator +} + +func getComposeOptions(opts []AgentRunOption) []compose.Option { + o := GetImplSpecificOptions[chatModelAgentRunOptions](nil, opts...) + var co []compose.Option + if len(o.chatModelOptions) > 0 { + co = append(co, compose.WithChatModelOption(o.chatModelOptions...)) + } + var to []tool.Option + if len(o.toolOptions) > 0 { + to = append(to, o.toolOptions...) + } + for toolName, atos := range o.agentToolOptions { + to = append(to, withAgentToolOptions(toolName, atos)) + } + if len(to) > 0 { + co = append(co, compose.WithToolsNodeOption(compose.WithToolOption(to...))) + } + if o.historyModifier != nil { + co = append(co, compose.WithStateModifier(func(ctx context.Context, path compose.NodePath, state any) error { + s, ok := state.(*State) + if !ok { + return fmt.Errorf("unexpected state type: %T, expected: %T", state, &State{}) + } + s.Messages = o.historyModifier(ctx, s.Messages) + return nil + })) + } + return co +} + +type gobSerializer struct{} + +func (g *gobSerializer) Marshal(v any) ([]byte, error) { + buf := new(bytes.Buffer) + err := gob.NewEncoder(buf).Encode(v) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (g *gobSerializer) Unmarshal(data []byte, v any) error { + buf := bytes.NewBuffer(data) + return gob.NewDecoder(buf).Decode(v) +} diff --git a/adk/chatmodel_test.go b/adk/chatmodel_test.go new file mode 100644 index 00000000..c238cc3f --- /dev/null +++ b/adk/chatmodel_test.go @@ -0,0 +1,345 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + mockModel "github.com/cloudwego/eino/internal/mock/components/model" + "github.com/cloudwego/eino/schema" +) + +// TestChatModelAgentRun tests the Run method of ChatModelAgent +func TestChatModelAgentRun(t *testing.T) { + // Basic test for Run method + t.Run("BasicFunctionality", func(t *testing.T) { + ctx := context.Background() + + // Create a mock chat model + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + // Set up expectations for the mock model + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("Hello, I am an AI assistant.", nil), nil). + Times(1) + + // Create a ChatModelAgent + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent for unit testing", + Instruction: "You are a helpful assistant.", + Model: cm, + }) + assert.NoError(t, err) + assert.NotNil(t, agent) + + // Run the agent + input := &AgentInput{ + Messages: []Message{ + schema.UserMessage("Hello, who are you?"), + }, + } + iterator := agent.Run(ctx, input) + assert.NotNil(t, iterator) + + // Get the event from the iterator + event, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event) + assert.Nil(t, event.Err) + assert.NotNil(t, event.Output) + assert.NotNil(t, event.Output.MessageOutput) + + // Verify the message content + msg := event.Output.MessageOutput.Message + assert.NotNil(t, msg) + assert.Equal(t, "Hello, I am an AI assistant.", msg.Content) + + // No more events + _, ok = iterator.Next() + assert.False(t, ok) + }) + + // Test with streaming output + t.Run("StreamOutput", func(t *testing.T) { + ctx := context.Background() + + // Create a mock chat model + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + // Create a stream reader for the mock response + sr := schema.StreamReaderFromArray([]*schema.Message{ + schema.AssistantMessage("Hello", nil), + schema.AssistantMessage(", I am", nil), + schema.AssistantMessage(" an AI assistant.", nil), + }) + + // Set up expectations for the mock model + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + Return(sr, nil). + Times(1) + + // Create a ChatModelAgent + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent for unit testing", + Instruction: "You are a helpful assistant.", + Model: cm, + }) + assert.NoError(t, err) + assert.NotNil(t, agent) + + // Run the agent with streaming enabled + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello, who are you?")}, + EnableStreaming: true, + } + iterator := agent.Run(ctx, input) + assert.NotNil(t, iterator) + + // Get the event from the iterator + event, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event) + assert.Nil(t, event.Err) + assert.NotNil(t, event.Output) + assert.NotNil(t, event.Output.MessageOutput) + assert.True(t, event.Output.MessageOutput.IsStreaming) + + // No more events + _, ok = iterator.Next() + assert.False(t, ok) + }) + + // Test error handling + t.Run("ErrorHandling", func(t *testing.T) { + ctx := context.Background() + + // Create a mock chat model + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + // Set up expectations for the mock model to return an error + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, errors.New("model error")). + Times(1) + + // Create a ChatModelAgent + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent for unit testing", + Instruction: "You are a helpful assistant.", + Model: cm, + }) + assert.NoError(t, err) + assert.NotNil(t, agent) + + // Run the agent + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Hello, who are you?")}, + } + iterator := agent.Run(ctx, input) + assert.NotNil(t, iterator) + + // Get the event from the iterator, should contain an error + event, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event) + assert.NotNil(t, event.Err) + assert.Equal(t, "model error", event.Err.Error()) + + // No more events + _, ok = iterator.Next() + assert.False(t, ok) + }) + + // Test with tools + t.Run("WithTools", func(t *testing.T) { + ctx := context.Background() + + // Create a fake tool for testing + fakeTool := &fakeToolForTest{ + tarCount: 1, + } + + info, err := fakeTool.Info(ctx) + assert.NoError(t, err) + + // Create a mock chat model + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + // Set up expectations for the mock model + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("Using tool", + []schema.ToolCall{ + { + ID: "tool-call-1", + Function: schema.FunctionCall{ + Name: info.Name, + Arguments: `{"name": "test user"}`, + }, + }, + }), nil). + Times(1) + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("Task completed", nil), nil). + Times(1) + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + // Create a ChatModelAgent with tools + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent for unit testing", + Instruction: "You are a helpful assistant.", + Model: cm, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{fakeTool}, + }, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, agent) + + // Run the agent + input := &AgentInput{ + Messages: []Message{schema.UserMessage("Use the test tool")}, + } + iterator := agent.Run(ctx, input) + assert.NotNil(t, iterator) + + // Get events from the iterator + // First event should be the model output with tool call + event1, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event1) + assert.Nil(t, event1.Err) + assert.NotNil(t, event1.Output) + assert.NotNil(t, event1.Output.MessageOutput) + assert.Equal(t, schema.Assistant, event1.Output.MessageOutput.Role) + + // Second event should be the tool output + event2, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event2) + assert.Nil(t, event2.Err) + assert.NotNil(t, event2.Output) + assert.NotNil(t, event2.Output.MessageOutput) + assert.Equal(t, schema.Tool, event2.Output.MessageOutput.Role) + + // Third event should be the final model output + event3, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event3) + assert.Nil(t, event3.Err) + assert.NotNil(t, event3.Output) + assert.NotNil(t, event3.Output.MessageOutput) + assert.Equal(t, schema.Assistant, event3.Output.MessageOutput.Role) + + // No more events + _, ok = iterator.Next() + assert.False(t, ok) + }) +} + +// TestExitTool tests the Exit tool functionality +func TestExitTool(t *testing.T) { + ctx := context.Background() + + // Create a mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create a mock chat model + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + // Set up expectations for the mock model + // First call: model generates a message with Exit tool call + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("I'll exit with a final result", + []schema.ToolCall{ + { + ID: "tool-call-1", + Function: schema.FunctionCall{ + Name: "exit", + Arguments: `{"final_result": "This is the final result"}`}, + }, + }), nil). + Times(1) + + // Model should implement WithTools + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + // Create an agent with the Exit tool + agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "TestAgent", + Description: "Test agent with Exit tool", + Instruction: "You are a helpful assistant.", + Model: cm, + Exit: &ExitTool{}, + }) + assert.NoError(t, err) + assert.NotNil(t, agent) + + // Run the agent + input := &AgentInput{ + Messages: []Message{ + schema.UserMessage("Please exit with a final result"), + }, + } + iterator := agent.Run(ctx, input) + assert.NotNil(t, iterator) + + // First event: model output with tool call + event1, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event1) + assert.Nil(t, event1.Err) + assert.NotNil(t, event1.Output) + assert.NotNil(t, event1.Output.MessageOutput) + assert.Equal(t, schema.Assistant, event1.Output.MessageOutput.Role) + + // Second event: tool output (Exit) + event2, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event2) + assert.Nil(t, event2.Err) + assert.NotNil(t, event2.Output) + assert.NotNil(t, event2.Output.MessageOutput) + assert.Equal(t, schema.Tool, event2.Output.MessageOutput.Role) + + // Verify the action is Exit + assert.NotNil(t, event2.Action) + assert.True(t, event2.Action.Exit) + + // Verify the final result + assert.Equal(t, "This is the final result", event2.Output.MessageOutput.Message.Content) + + // No more events + _, ok = iterator.Next() + assert.False(t, ok) +} diff --git a/adk/deterministic_transfer.go b/adk/deterministic_transfer.go new file mode 100644 index 00000000..8cbb8021 --- /dev/null +++ b/adk/deterministic_transfer.go @@ -0,0 +1,149 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "runtime/debug" + + "github.com/cloudwego/eino/internal/safe" + "github.com/cloudwego/eino/schema" +) + +func AgentWithDeterministicTransferTo(_ context.Context, config *DeterministicTransferConfig) Agent { + if ra, ok := config.Agent.(ResumableAgent); ok { + return &resumableAgentWithDeterministicTransferTo{ + agent: ra, + toAgentNames: config.ToAgentNames, + } + } + return &agentWithDeterministicTransferTo{ + agent: config.Agent, + toAgentNames: config.ToAgentNames, + } +} + +type agentWithDeterministicTransferTo struct { + agent Agent + toAgentNames []string +} + +func (a *agentWithDeterministicTransferTo) Description(ctx context.Context) string { + return a.agent.Description(ctx) +} + +func (a *agentWithDeterministicTransferTo) Name(ctx context.Context) string { + return a.agent.Name(ctx) +} + +func (a *agentWithDeterministicTransferTo) Run(ctx context.Context, + input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { + + if _, ok := a.agent.(*flowAgent); ok { + ctx = ClearRunCtx(ctx) + } + + aIter := a.agent.Run(ctx, input, options...) + + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + go appendTransferAction(ctx, aIter, generator, a.toAgentNames) + + return iterator +} + +type resumableAgentWithDeterministicTransferTo struct { + agent ResumableAgent + toAgentNames []string +} + +func (a *resumableAgentWithDeterministicTransferTo) Description(ctx context.Context) string { + return a.agent.Description(ctx) +} + +func (a *resumableAgentWithDeterministicTransferTo) Name(ctx context.Context) string { + return a.agent.Name(ctx) +} + +func (a *resumableAgentWithDeterministicTransferTo) Run(ctx context.Context, + input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { + + if _, ok := a.agent.(*flowAgent); ok { + ctx = ClearRunCtx(ctx) + } + + aIter := a.agent.Run(ctx, input, options...) + + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + go appendTransferAction(ctx, aIter, generator, a.toAgentNames) + + return iterator +} + +func (a *resumableAgentWithDeterministicTransferTo) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + aIter := a.agent.Resume(ctx, info, opts...) + + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + go appendTransferAction(ctx, aIter, generator, a.toAgentNames) + + return iterator +} + +func appendTransferAction(ctx context.Context, aIter *AsyncIterator[*AgentEvent], generator *AsyncGenerator[*AgentEvent], toAgentNames []string) { + defer func() { + panicErr := recover() + if panicErr != nil { + e := safe.NewPanicErr(panicErr, debug.Stack()) + generator.Send(&AgentEvent{Err: e}) + } + + generator.Close() + }() + + interrupted := false + + for { + event, ok := aIter.Next() + if !ok { + break + } + + generator.Send(event) + + if event.Action != nil && event.Action.Interrupted != nil { + interrupted = true + } else { + interrupted = false + } + } + + if interrupted { + return + } + + for _, toAgentName := range toAgentNames { + aMsg, tMsg := GenTransferMessages(ctx, toAgentName) + aEvent := EventFromMessage(aMsg, nil, schema.Assistant, "") + generator.Send(aEvent) + tEvent := EventFromMessage(tMsg, nil, schema.Tool, tMsg.ToolName) + tEvent.Action = &AgentAction{ + TransferToAgent: &TransferToAgentAction{ + DestAgentName: toAgentName, + }, + } + generator.Send(tEvent) + } +} diff --git a/adk/flow.go b/adk/flow.go new file mode 100644 index 00000000..66313d11 --- /dev/null +++ b/adk/flow.go @@ -0,0 +1,457 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "fmt" + "runtime/debug" + "strings" + + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/internal/safe" + "github.com/cloudwego/eino/schema" +) + +type HistoryEntry struct { + IsUserInput bool + AgentName string + Message Message +} + +type HistoryRewriter func(ctx context.Context, entries []*HistoryEntry) ([]Message, error) + +type flowAgent struct { + Agent + + subAgents []*flowAgent + parentAgent *flowAgent + + disallowTransferToParent bool + historyRewriter HistoryRewriter + + checkPointStore compose.CheckPointStore +} + +func (a *flowAgent) deepCopy() *flowAgent { + ret := &flowAgent{ + Agent: a.Agent, + subAgents: make([]*flowAgent, 0, len(a.subAgents)), + parentAgent: a.parentAgent, + disallowTransferToParent: a.disallowTransferToParent, + historyRewriter: a.historyRewriter, + checkPointStore: a.checkPointStore, + } + + for _, sa := range a.subAgents { + ret.subAgents = append(ret.subAgents, sa.deepCopy()) + } + return ret +} + +func SetSubAgents(ctx context.Context, agent Agent, subAgents []Agent) (Agent, error) { + return setSubAgents(ctx, agent, subAgents) +} + +type AgentOption func(options *flowAgent) + +func WithDisallowTransferToParent() AgentOption { + return func(fa *flowAgent) { + fa.disallowTransferToParent = true + } +} + +func WithHistoryRewriter(h HistoryRewriter) AgentOption { + return func(fa *flowAgent) { + fa.historyRewriter = h + } +} + +func toFlowAgent(ctx context.Context, agent Agent, opts ...AgentOption) *flowAgent { + var fa *flowAgent + var ok bool + if fa, ok = agent.(*flowAgent); !ok { + fa = &flowAgent{Agent: agent} + } else { + fa = fa.deepCopy() + } + for _, opt := range opts { + opt(fa) + } + + if fa.historyRewriter == nil { + fa.historyRewriter = buildDefaultHistoryRewriter(agent.Name(ctx)) + } + + return fa +} + +func AgentWithOptions(ctx context.Context, agent Agent, opts ...AgentOption) Agent { + return toFlowAgent(ctx, agent, opts...) +} + +func setSubAgents(ctx context.Context, agent Agent, subAgents []Agent) (*flowAgent, error) { + fa := toFlowAgent(ctx, agent) + + if len(fa.subAgents) > 0 { + return nil, errors.New("agent's sub-agents has already been set") + } + + if onAgent, ok_ := fa.Agent.(OnSubAgents); ok_ { + err := onAgent.OnSetSubAgents(ctx, subAgents) + if err != nil { + return nil, err + } + } + + for _, s := range subAgents { + fsa := toFlowAgent(ctx, s) + + if fsa.parentAgent != nil { + return nil, errors.New("agent has already been set as a sub-agent of another agent") + } + + fsa.parentAgent = fa + if onAgent, ok__ := fsa.Agent.(OnSubAgents); ok__ { + err := onAgent.OnSetAsSubAgent(ctx, agent) + if err != nil { + return nil, err + } + + if fsa.disallowTransferToParent { + err = onAgent.OnDisallowTransferToParent(ctx) + if err != nil { + return nil, err + } + } + } + + fa.subAgents = append(fa.subAgents, fsa) + } + + return fa, nil +} + +func (a *flowAgent) getAgent(ctx context.Context, name string) *flowAgent { + for _, subAgent := range a.subAgents { + if subAgent.Name(ctx) == name { + return subAgent + } + } + + if a.parentAgent != nil && a.parentAgent.Name(ctx) == name { + return a.parentAgent + } + + return nil +} + +func belongToRunPath(eventRunPath []RunStep, runPath []RunStep) bool { + if len(runPath) < len(eventRunPath) { + return false + } + + for i, step := range eventRunPath { + if !runPath[i].Equals(step) { + return false + } + } + + return true +} + +func rewriteMessage(msg Message, agentName string) Message { + var sb strings.Builder + sb.WriteString("For context:") + if msg.Role == schema.Assistant { + if msg.Content != "" { + sb.WriteString(fmt.Sprintf(" [%s] said: %s.", agentName, msg.Content)) + } + if len(msg.ToolCalls) > 0 { + for i := range msg.ToolCalls { + f := msg.ToolCalls[i].Function + sb.WriteString(fmt.Sprintf(" [%s] called tool: `%s` with arguments: %s.", + agentName, f.Name, f.Arguments)) + } + } + } else if msg.Role == schema.Tool && msg.Content != "" { + sb.WriteString(fmt.Sprintf(" [%s] `%s` tool returned result: %s.", + agentName, msg.ToolName, msg.Content)) + } + + return schema.UserMessage(sb.String()) +} + +func genMsg(entry *HistoryEntry, agentName string) (Message, error) { + msg := entry.Message + if entry.AgentName != agentName { + msg = rewriteMessage(msg, entry.AgentName) + } + + return msg, nil +} + +func (ai *AgentInput) deepCopy() *AgentInput { + copied := &AgentInput{ + Messages: make([]Message, len(ai.Messages)), + EnableStreaming: ai.EnableStreaming, + } + + copy(copied.Messages, ai.Messages) + + return copied +} + +func (a *flowAgent) genAgentInput(ctx context.Context, runCtx *runContext, skipTransferMessages bool) (*AgentInput, error) { + input := runCtx.RootInput.deepCopy() + runPath := runCtx.RunPath + + events := runCtx.Session.getEvents() + historyEntries := make([]*HistoryEntry, 0) + + for _, m := range input.Messages { + historyEntries = append(historyEntries, &HistoryEntry{ + IsUserInput: true, + Message: m, + }) + } + + for _, event := range events { + if !belongToRunPath(event.RunPath, runPath) { + continue + } + + if skipTransferMessages && event.Action != nil && event.Action.TransferToAgent != nil { + // If skipTransferMessages is true and the event contain transfer action, the message in this event won't be appended to history entries. + if event.Output != nil && + event.Output.MessageOutput != nil && + event.Output.MessageOutput.Role == schema.Tool && + len(historyEntries) > 0 { + // If the skipped message's role is Tool, remove the previous history entry as it's also a transfer message(from ChatModelAgent and GenTransferMessages). + historyEntries = historyEntries[:len(historyEntries)-1] + } + continue + } + + msg, err := getMessageFromWrappedEvent(event) + if err != nil { + return nil, err + } + + if msg == nil { + continue + } + + historyEntries = append(historyEntries, &HistoryEntry{ + AgentName: event.AgentName, + Message: msg, + }) + } + + messages, err := a.historyRewriter(ctx, historyEntries) + if err != nil { + return nil, err + } + input.Messages = messages + + return input, nil +} + +func buildDefaultHistoryRewriter(agentName string) HistoryRewriter { + return func(ctx context.Context, entries []*HistoryEntry) ([]Message, error) { + messages := make([]Message, 0, len(entries)) + var err error + for _, entry := range entries { + msg := entry.Message + if !entry.IsUserInput { + msg, err = genMsg(entry, agentName) + if err != nil { + return nil, fmt.Errorf("gen agent input failed: %w", err) + } + } + + if msg != nil { + messages = append(messages, msg) + } + } + + return messages, nil + } +} + +func (a *flowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + agentName := a.Name(ctx) + + ctx, runCtx := initRunCtx(ctx, agentName, input) + + o := getCommonOptions(nil, opts...) + + input, err := a.genAgentInput(ctx, runCtx, o.skipTransferMessages) + if err != nil { + return genErrorIter(err) + } + + if wf, ok := a.Agent.(*workflowAgent); ok { + return wf.Run(ctx, input, opts...) + } + + aIter := a.Agent.Run(ctx, input, filterOptions(agentName, opts)...) + + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + + go a.run(ctx, runCtx, aIter, generator, opts...) + + return iterator +} + +func (a *flowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + runCtx := getRunCtx(ctx) + if runCtx == nil { + return genErrorIter(fmt.Errorf("failed to resume agent: run context is empty")) + } + + agentName := a.Name(ctx) + targetName := agentName + if len(runCtx.RunPath) > 0 { + targetName = runCtx.RunPath[len(runCtx.RunPath)-1].agentName + } + + if agentName != targetName { + // go to target flow agent + targetAgent := recursiveGetAgent(ctx, a, targetName) + if targetAgent == nil { + return genErrorIter(fmt.Errorf("failed to resume agent: cannot find agent: %s", agentName)) + } + return targetAgent.Resume(ctx, info, opts...) + } + + if wf, ok := a.Agent.(*workflowAgent); ok { + return wf.Resume(ctx, info, opts...) + } + + // resume current agent + ra, ok := a.Agent.(ResumableAgent) + if !ok { + return genErrorIter(fmt.Errorf("failed to resume agent: target agent[%s] isn't resumable", agentName)) + } + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + aIter := ra.Resume(ctx, info, opts...) + + go a.run(ctx, runCtx, aIter, generator, opts...) + + return iterator +} + +type DeterministicTransferConfig struct { + Agent Agent + ToAgentNames []string +} + +func (a *flowAgent) run( + ctx context.Context, + runCtx *runContext, + aIter *AsyncIterator[*AgentEvent], + generator *AsyncGenerator[*AgentEvent], + opts ...AgentRunOption) { + defer func() { + panicErr := recover() + if panicErr != nil { + e := safe.NewPanicErr(panicErr, debug.Stack()) + generator.Send(&AgentEvent{Err: e}) + } + + generator.Close() + }() + + var lastAction *AgentAction + for { + event, ok := aIter.Next() + if !ok { + break + } + + event.AgentName = a.Name(ctx) + event.RunPath = runCtx.RunPath + // copy the event so that the copied event's stream is exclusive for any potential consumer + // copy before adding to session because once added to session it's stream could be consumed by genAgentInput at any time + copied := copyAgentEvent(event) + setAutomaticClose(copied) + setAutomaticClose(event) + runCtx.Session.addEvent(copied) + lastAction = event.Action + generator.Send(event) + } + + var destName string + if lastAction != nil { + if lastAction.Interrupted != nil { + appendInterruptRunCtx(ctx, runCtx) + return + } + if lastAction.Exit { + return + } + + if lastAction.TransferToAgent != nil { + destName = lastAction.TransferToAgent.DestAgentName + } + } + + // handle transferring to another agent + if destName != "" { + agentToRun := a.getAgent(ctx, destName) + if agentToRun == nil { + e := errors.New(fmt.Sprintf( + "transfer failed: agent '%s' not found when transferring from '%s'", + destName, a.Name(ctx))) + generator.Send(&AgentEvent{Err: e}) + return + } + + subAIter := agentToRun.Run(ctx, nil /*subagents get input from runCtx*/, opts...) + for { + subEvent, ok_ := subAIter.Next() + if !ok_ { + break + } + + setAutomaticClose(subEvent) + generator.Send(subEvent) + } + } +} + +func recursiveGetAgent(ctx context.Context, agent *flowAgent, agentName string) *flowAgent { + if agent == nil { + return nil + } + if agent.Name(ctx) == agentName { + return agent + } + a := agent.getAgent(ctx, agentName) + if a != nil { + return a + } + for _, sa := range agent.subAgents { + a = recursiveGetAgent(ctx, sa, agentName) + if a != nil { + return a + } + } + return nil +} diff --git a/adk/flow_test.go b/adk/flow_test.go new file mode 100644 index 00000000..b1b3941c --- /dev/null +++ b/adk/flow_test.go @@ -0,0 +1,144 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + mockModel "github.com/cloudwego/eino/internal/mock/components/model" + "github.com/cloudwego/eino/schema" +) + +// TestTransferToAgent tests the TransferToAgent functionality +func TestTransferToAgent(t *testing.T) { + ctx := context.Background() + + // Create a mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock models for parent and child agents + parentModel := mockModel.NewMockToolCallingChatModel(ctrl) + childModel := mockModel.NewMockToolCallingChatModel(ctrl) + + // Set up expectations for the parent model + // First call: parent model generates a message with TransferToAgent tool call + parentModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("I'll transfer this to the child agent", + []schema.ToolCall{ + { + ID: "tool-call-1", + Function: schema.FunctionCall{ + Name: TransferToAgentToolName, + Arguments: `{"agent_name": "ChildAgent"}`, + }, + }, + }), nil). + Times(1) + + // Set up expectations for the child model + // Second call: child model generates a response + childModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + Return(schema.AssistantMessage("Hello from child agent", nil), nil). + Times(1) + + // Both models should implement WithTools + parentModel.EXPECT().WithTools(gomock.Any()).Return(parentModel, nil).AnyTimes() + childModel.EXPECT().WithTools(gomock.Any()).Return(childModel, nil).AnyTimes() + + // Create parent agent + parentAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "ParentAgent", + Description: "Parent agent that will transfer to child", + Instruction: "You are a parent agent.", + Model: parentModel, + }) + assert.NoError(t, err) + assert.NotNil(t, parentAgent) + + // Create child agent + childAgent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "ChildAgent", + Description: "Child agent that handles specific tasks", + Instruction: "You are a child agent.", + Model: childModel, + }) + assert.NoError(t, err) + assert.NotNil(t, childAgent) + + // Set up parent-child relationship + flowAgent, err := SetSubAgents(ctx, parentAgent, []Agent{childAgent}) + assert.NoError(t, err) + assert.NotNil(t, flowAgent) + + assert.NotNil(t, parentAgent.subAgents) + assert.NotNil(t, childAgent.parentAgent) + + // Run the parent agent + input := &AgentInput{ + Messages: []Message{ + schema.UserMessage("Please transfer this to the child agent"), + }, + } + iterator := flowAgent.Run(ctx, input) + assert.NotNil(t, iterator) + + // First event: parent model output with tool call + event1, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event1) + assert.Nil(t, event1.Err) + assert.NotNil(t, event1.Output) + assert.NotNil(t, event1.Output.MessageOutput) + assert.Equal(t, schema.Assistant, event1.Output.MessageOutput.Role) + + // Second event: tool output (TransferToAgent) + event2, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event2) + assert.Nil(t, event2.Err) + assert.NotNil(t, event2.Output) + assert.NotNil(t, event2.Output.MessageOutput) + assert.Equal(t, schema.Tool, event2.Output.MessageOutput.Role) + + // Verify the action is TransferToAgent + assert.NotNil(t, event2.Action) + assert.NotNil(t, event2.Action.TransferToAgent) + assert.Equal(t, "ChildAgent", event2.Action.TransferToAgent.DestAgentName) + + // Third event: child model output + event3, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event3) + assert.Nil(t, event3.Err) + assert.NotNil(t, event3.Output) + assert.NotNil(t, event3.Output.MessageOutput) + assert.Equal(t, schema.Assistant, event3.Output.MessageOutput.Role) + + // Verify the message content from child agent + msg := event3.Output.MessageOutput.Message + assert.NotNil(t, msg) + assert.Equal(t, "Hello from child agent", msg.Content) + + // No more events + _, ok = iterator.Next() + assert.False(t, ok) +} diff --git a/adk/instruction.go b/adk/instruction.go new file mode 100644 index 00000000..6ee2dafd --- /dev/null +++ b/adk/instruction.go @@ -0,0 +1,43 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "fmt" + "strings" +) + +const ( + TransferToAgentInstruction = `Available other agents: %s + +Decision rule: +- If you're best suited for the question according to your description: ANSWER +- If another agent is better according its description: CALL '%s' function with their agent name + +When transferring: OUTPUT ONLY THE FUNCTION CALL` +) + +func genTransferToAgentInstruction(ctx context.Context, agents []Agent) string { + var sb strings.Builder + for _, agent := range agents { + sb.WriteString(fmt.Sprintf("\n- Agent name: %s\n Agent description: %s", + agent.Name(ctx), agent.Description(ctx))) + } + + return fmt.Sprintf(TransferToAgentInstruction, sb.String(), TransferToAgentToolName) +} diff --git a/adk/interface.go b/adk/interface.go new file mode 100644 index 00000000..ed8a6699 --- /dev/null +++ b/adk/interface.go @@ -0,0 +1,230 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "bytes" + "context" + "encoding/gob" + "fmt" + "io" + + "github.com/cloudwego/eino/schema" +) + +type Message = *schema.Message +type MessageStream = *schema.StreamReader[Message] + +type MessageVariant struct { + IsStreaming bool + + Message Message + MessageStream MessageStream + // message role: Assistant or Tool + Role schema.RoleType + // only used when Role is Tool + ToolName string +} + +func EventFromMessage(msg Message, msgStream MessageStream, + role schema.RoleType, toolName string) *AgentEvent { + return &AgentEvent{ + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: msgStream != nil, + Message: msg, + MessageStream: msgStream, + Role: role, + ToolName: toolName, + }, + }, + } +} + +type messageVariantSerialization struct { + IsStreaming bool + Message Message + MessageStream Message +} + +func (mv *MessageVariant) GobEncode() ([]byte, error) { + s := &messageVariantSerialization{ + IsStreaming: mv.IsStreaming, + Message: mv.Message, + } + if mv.IsStreaming { + var messages []Message + for { + frame, err := mv.MessageStream.Recv() + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("error receiving message stream: %w", err) + } + messages = append(messages, frame) + } + m, err := schema.ConcatMessages(messages) + if err != nil { + return nil, fmt.Errorf("failed to encode message: cannot concat message stream: %w", err) + } + s.MessageStream = m + } + buf := &bytes.Buffer{} + err := gob.NewEncoder(buf).Encode(s) + if err != nil { + return nil, fmt.Errorf("failed to gob encode message variant: %w", err) + } + return buf.Bytes(), nil +} + +func (mv *MessageVariant) GobDecode(b []byte) error { + s := &messageVariantSerialization{} + err := gob.NewDecoder(bytes.NewReader(b)).Decode(s) + if err != nil { + return fmt.Errorf("failed to decoding message variant: %w", err) + } + mv.IsStreaming = s.IsStreaming + mv.Message = s.Message + if s.MessageStream != nil { + mv.MessageStream = schema.StreamReaderFromArray([]*schema.Message{s.MessageStream}) + } + return nil +} + +func (mv *MessageVariant) GetMessage() (Message, error) { + var message Message + if mv.IsStreaming { + var err error + message, err = schema.ConcatMessageStream(mv.MessageStream) + if err != nil { + return nil, err + } + } else { + message = mv.Message + } + + return message, nil +} + +type TransferToAgentAction struct { + DestAgentName string +} + +type AgentOutput struct { + MessageOutput *MessageVariant + + CustomizedOutput any +} + +func NewTransferToAgentAction(destAgentName string) *AgentAction { + return &AgentAction{TransferToAgent: &TransferToAgentAction{DestAgentName: destAgentName}} +} + +func NewExitAction() *AgentAction { + return &AgentAction{Exit: true} +} + +type AgentAction struct { + Exit bool + + Interrupted *InterruptInfo + + TransferToAgent *TransferToAgentAction + + CustomizedAction any +} + +type RunStep struct { + agentName string +} + +func (r *RunStep) String() string { + return r.agentName +} + +func (r *RunStep) Equals(r1 RunStep) bool { + return r.agentName == r1.agentName +} + +func (r *RunStep) GobEncode() ([]byte, error) { + s := &runStepSerialization{AgentName: r.agentName} + buf := &bytes.Buffer{} + err := gob.NewEncoder(buf).Encode(s) + if err != nil { + return nil, fmt.Errorf("failed to gob encode RunStep: %w", err) + } + return buf.Bytes(), nil +} + +func (r *RunStep) GobDecode(b []byte) error { + s := &runStepSerialization{} + err := gob.NewDecoder(bytes.NewReader(b)).Decode(s) + if err != nil { + return fmt.Errorf("failed to gob decode RunStep: %w", err) + } + r.agentName = s.AgentName + return nil +} + +type runStepSerialization struct { + AgentName string +} + +type AgentEvent struct { + AgentName string + + RunPath []RunStep + + Output *AgentOutput + + Action *AgentAction + + Err error +} + +type AgentInput struct { + Messages []Message + EnableStreaming bool +} + +//go:generate mockgen -destination ../internal/mock/adk/Agent_mock.go --package adk -source interface.go +type Agent interface { + Name(ctx context.Context) string + Description(ctx context.Context) string + + // Run runs the agent. + // The returned AgentEvent within the AsyncIterator must be safe to modify. + // If the returned AgentEvent within the AsyncIterator contains MessageStream, + // the MessageStream MUST be exclusive and safe to be received directly. + // NOTE: it's recommended to use SetAutomaticClose() on the MessageStream of AgentEvents emitted by AsyncIterator, + // so that even the events are not processed, the MessageStream can still be closed. + Run(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] +} + +type OnSubAgents interface { + OnSetSubAgents(ctx context.Context, subAgents []Agent) error + OnSetAsSubAgent(ctx context.Context, parent Agent) error + + OnDisallowTransferToParent(ctx context.Context) error +} + +type ResumableAgent interface { + Agent + + Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] +} diff --git a/adk/interrupt.go b/adk/interrupt.go new file mode 100644 index 00000000..3b518123 --- /dev/null +++ b/adk/interrupt.go @@ -0,0 +1,130 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "bytes" + "context" + "encoding/gob" + "fmt" + + "github.com/cloudwego/eino/compose" +) + +type ResumeInfo struct { + EnableStreaming bool + *InterruptInfo +} + +type InterruptInfo struct { + Data any +} + +func WithCheckPointID(id string) AgentRunOption { + return WrapImplSpecificOptFn(func(t *options) { + t.checkPointID = &id + }) +} + +func init() { + gob.RegisterName("_eino_adk_serialization", &serialization{}) + gob.RegisterName("_eino_adk_workflow_interrupt_info", &WorkflowInterruptInfo{}) + gob.RegisterName("_eino_adk_react_state", &State{}) + gob.RegisterName("_eino_compose_interrupt_info", &compose.InterruptInfo{}) + gob.RegisterName("_eino_compose_tools_interrupt_and_rerun_extra", &compose.ToolsInterruptAndRerunExtra{}) +} + +type serialization struct { + RunCtx *runContext + Info *InterruptInfo +} + +func getCheckPoint( + ctx context.Context, + store compose.CheckPointStore, + key string, +) (*runContext, *ResumeInfo, bool, error) { + data, existed, err := store.Get(ctx, key) + if err != nil { + return nil, nil, false, fmt.Errorf("failed to get checkpoint from store: %w", err) + } + if !existed { + return nil, nil, false, nil + } + s := &serialization{} + err = gob.NewDecoder(bytes.NewReader(data)).Decode(s) + if err != nil { + return nil, nil, false, fmt.Errorf("failed to decode checkpoint: %w", err) + } + enableStreaming := false + if s.RunCtx.RootInput != nil { + enableStreaming = s.RunCtx.RootInput.EnableStreaming + } + return s.RunCtx, &ResumeInfo{ + EnableStreaming: enableStreaming, + InterruptInfo: s.Info, + }, true, nil +} + +func saveCheckPoint( + ctx context.Context, + store compose.CheckPointStore, + key string, + runCtx *runContext, + info *InterruptInfo, +) error { + buf := &bytes.Buffer{} + err := gob.NewEncoder(buf).Encode(&serialization{ + RunCtx: runCtx, + Info: info, + }) + if err != nil { + return fmt.Errorf("failed to encode checkpoint: %w", err) + } + return store.Set(ctx, key, buf.Bytes()) +} + +const mockCheckPointID = "adk_react_mock_key" + +func newEmptyStore() *mockStore { + return &mockStore{} +} + +func newResumeStore(data []byte) *mockStore { + return &mockStore{ + Data: data, + Valid: true, + } +} + +type mockStore struct { + Data []byte + Valid bool +} + +func (m *mockStore) Get(ctx context.Context, checkPointID string) ([]byte, bool, error) { + if m.Valid { + return m.Data, true, nil + } + return nil, false, nil +} + +func (m *mockStore) Set(ctx context.Context, checkPointID string, checkPoint []byte) error { + m.Data = checkPoint + m.Valid = true + return nil +} diff --git a/adk/interrupt_test.go b/adk/interrupt_test.go new file mode 100644 index 00000000..8f73c363 --- /dev/null +++ b/adk/interrupt_test.go @@ -0,0 +1,866 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +func TestSaveAgentEventWrapper(t *testing.T) { + sr, sw := schema.Pipe[Message](1) + sw.Send(schema.UserMessage("test"), nil) + sw.Close() + sr = sr.Copy(2)[1] + + w := &agentEventWrapper{ + AgentEvent: &AgentEvent{ + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: true, + MessageStream: sr, + }, + }, + RunPath: []RunStep{ + { + "a1", + }, + { + "a2", + }, + }, + }, + mu: sync.Mutex{}, + concatenatedMessage: nil, + } + + _, err := getMessageFromWrappedEvent(w) + assert.NoError(t, err) + + buf, err := w.GobEncode() + assert.NoError(t, err) + assert.NoError(t, err) + + w1 := &agentEventWrapper{} + err = w1.GobDecode(buf) + assert.NoError(t, err) +} + +func TestSimpleInterrupt(t *testing.T) { + data := "hello world" + agent := &myAgent{ + runner: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Send(&AgentEvent{ + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: true, + Message: nil, + MessageStream: schema.StreamReaderFromArray([]Message{ + schema.UserMessage("hello "), + schema.UserMessage("world"), + }), + }, + }, + }) + generator.Send(&AgentEvent{ + Action: &AgentAction{Interrupted: &InterruptInfo{ + Data: data, + }}, + }) + generator.Close() + return iter + }, + resumer: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + assert.NotNil(t, info) + assert.True(t, info.EnableStreaming) + assert.Equal(t, data, info.Data) + iter, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Close() + return iter + }, + } + store := newMyStore() + ctx := context.Background() + runner := NewRunner(ctx, RunnerConfig{ + Agent: agent, + EnableStreaming: true, + CheckPointStore: store, + }) + iter := runner.Query(ctx, "hello world", WithCheckPointID("1")) + event, ok := iter.Next() + assert.True(t, ok) + event, ok = iter.Next() + assert.True(t, ok) + assert.Equal(t, data, event.Action.Interrupted.Data) + _, ok = iter.Next() + assert.False(t, ok) + + _, err := runner.Resume(ctx, "1") + assert.NoError(t, err) +} + +func TestMultiAgentInterrupt(t *testing.T) { + ctx := context.Background() + sa1 := &myAgent{ + name: "sa1", + runner: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Send(&AgentEvent{ + AgentName: "sa1", + Action: &AgentAction{ + TransferToAgent: &TransferToAgentAction{ + DestAgentName: "sa2", + }, + }, + }) + generator.Close() + return iter + }, + } + sa2 := &myAgent{ + name: "sa2", + runner: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Send(&AgentEvent{ + AgentName: "sa2", + Action: &AgentAction{ + Interrupted: &InterruptInfo{ + Data: "hello world", + }, + }, + }) + generator.Close() + return iter + }, + resumer: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + assert.NotNil(t, info) + assert.Equal(t, info.Data, "hello world") + iter, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Send(&AgentEvent{ + AgentName: "sa2", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{Message: schema.UserMessage("completed")}, + }, + }) + generator.Close() + return iter + }, + } + a, err := SetSubAgents(ctx, sa1, []Agent{sa2}) + assert.NoError(t, err) + runner := NewRunner(ctx, RunnerConfig{ + Agent: a, + EnableStreaming: false, + CheckPointStore: newMyStore(), + }) + iter := runner.Query(ctx, "", WithCheckPointID("1")) + event, ok := iter.Next() + assert.True(t, ok) + assert.NotNil(t, event.Action.TransferToAgent) + event, ok = iter.Next() + assert.True(t, ok) + assert.NotNil(t, event.Action.Interrupted) + _, ok = iter.Next() + assert.False(t, ok) + iter, err = runner.Resume(ctx, "1") + assert.NoError(t, err) + event, ok = iter.Next() + assert.True(t, ok) + assert.Equal(t, event.Output.MessageOutput.Message.Content, "completed") + _, ok = iter.Next() + assert.False(t, ok) +} + +func TestWorkflowInterrupt(t *testing.T) { + ctx := context.Background() + sa1 := &myAgent{ + name: "sa1", + runner: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Send(&AgentEvent{ + AgentName: "sa1", + Action: &AgentAction{ + Interrupted: &InterruptInfo{ + Data: "sa1 interrupt data", + }, + }, + }) + generator.Close() + return iter + }, + resumer: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + assert.Equal(t, info.Data, "sa1 interrupt data") + iter, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Close() + return iter + }, + } // interrupt once + sa2 := &myAgent{ + name: "sa2", + runner: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Send(&AgentEvent{ + AgentName: "sa2", + Action: &AgentAction{ + Interrupted: &InterruptInfo{ + Data: "sa2 interrupt data", + }, + }, + }) + generator.Close() + return iter + }, + resumer: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + assert.Equal(t, info.Data, "sa2 interrupt data") + iter, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Close() + return iter + }, + } // interrupt once + sa3 := &myAgent{ + name: "sa3", + runner: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Send(&AgentEvent{ + AgentName: "sa3", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.UserMessage("sa3 completed"), + }, + }, + }) + generator.Close() + return iter + }, + } // won't interrupt + sa4 := &myAgent{ + name: "sa4", + runner: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Send(&AgentEvent{ + AgentName: "sa4", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.UserMessage("sa4 completed"), + }, + }, + }) + generator.Close() + return iter + }, + } // won't interrupt + + // sequential + a, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "sequential", + Description: "sequential agent", + SubAgents: []Agent{sa1, sa2, sa3, sa4}, + }) + assert.NoError(t, err) + runner := NewRunner(ctx, RunnerConfig{ + Agent: a, + CheckPointStore: newMyStore(), + }) + var events []*AgentEvent + iter := runner.Query(ctx, "hello world", WithCheckPointID("sequential-1")) + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + } + // Resume after sa1 interrupt + iter, err = runner.Resume(ctx, "sequential-1") + assert.NoError(t, err) + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + } + // Resume after sa2 interrupt + iter, err = runner.Resume(ctx, "sequential-1") + assert.NoError(t, err) + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + } + + expectedSequentialEvents := []*AgentEvent{ + { + AgentName: "sa1", + RunPath: []RunStep{{"sequential"}, {"sa1"}}, + Action: &AgentAction{ + Interrupted: &InterruptInfo{ + Data: &WorkflowInterruptInfo{ + OrigInput: &AgentInput{ + Messages: []Message{schema.UserMessage("hello world")}, + }, + SequentialInterruptIndex: 0, + SequentialInterruptInfo: &InterruptInfo{ + Data: "sa1 interrupt data", + }, + LoopIterations: 0, + }, + }, + }, + }, + { + AgentName: "sa2", + RunPath: []RunStep{{"sequential"}, {"sa1"}, {"sa2"}}, + Action: &AgentAction{ + Interrupted: &InterruptInfo{ + Data: &WorkflowInterruptInfo{ + OrigInput: &AgentInput{ + Messages: []Message{schema.UserMessage("hello world")}, + }, + SequentialInterruptIndex: 1, + SequentialInterruptInfo: &InterruptInfo{ + Data: "sa2 interrupt data", + }, + LoopIterations: 0, + }, + }, + }, + }, + { + AgentName: "sa3", + RunPath: []RunStep{{"sequential"}, {"sa1"}, {"sa2"}, {"sa3"}}, + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.UserMessage("sa3 completed"), + }, + }, + }, + { + AgentName: "sa4", + RunPath: []RunStep{{"sequential"}, {"sa1"}, {"sa2"}, {"sa3"}, {"sa4"}}, + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.UserMessage("sa4 completed"), + }, + }, + }, + } + + assert.Equal(t, 4, len(events)) + assert.Equal(t, expectedSequentialEvents, events) + + // loop + a, err = NewLoopAgent(ctx, &LoopAgentConfig{ + Name: "loop", + SubAgents: []Agent{sa1, sa2, sa3, sa4}, + MaxIterations: 2, + }) + assert.NoError(t, err) + runner = NewRunner(ctx, RunnerConfig{ + Agent: a, + CheckPointStore: newMyStore(), + }) + events = []*AgentEvent{} + iter = runner.Query(ctx, "hello world", WithCheckPointID("1")) + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + } + for i := 0; i < 4; i++ { + iter, err = runner.Resume(ctx, "1") + assert.NoError(t, err) + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + } + } + expectedEvents := []*AgentEvent{ + { + AgentName: "sa1", + RunPath: []RunStep{{"loop"}, {"sa1"}}, + Action: &AgentAction{ + Interrupted: &InterruptInfo{ + Data: &WorkflowInterruptInfo{ + OrigInput: &AgentInput{ + Messages: []Message{schema.UserMessage("hello world")}, + }, + SequentialInterruptIndex: 0, + SequentialInterruptInfo: &InterruptInfo{ + Data: "sa1 interrupt data", + }, + LoopIterations: 0, + }, + }, + }, + }, + { + AgentName: "sa2", + RunPath: []RunStep{{"loop"}, {"sa1"}, {"sa2"}}, + Action: &AgentAction{ + Interrupted: &InterruptInfo{ + Data: &WorkflowInterruptInfo{ + OrigInput: &AgentInput{ + Messages: []Message{schema.UserMessage("hello world")}, + }, + SequentialInterruptIndex: 1, + SequentialInterruptInfo: &InterruptInfo{ + Data: "sa2 interrupt data", + }, + LoopIterations: 0, + }, + }, + }, + }, + { + AgentName: "sa3", + RunPath: []RunStep{{"loop"}, {"sa1"}, {"sa2"}, {"sa3"}}, + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.UserMessage("sa3 completed"), + }, + }, + }, + { + AgentName: "sa4", + RunPath: []RunStep{{"loop"}, {"sa1"}, {"sa2"}, {"sa3"}, {"sa4"}}, + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.UserMessage("sa4 completed"), + }, + }, + }, + { + AgentName: "sa1", + RunPath: []RunStep{{"loop"}, {"sa1"}, {"sa2"}, {"sa3"}, {"sa4"}, {"sa1"}}, + Action: &AgentAction{ + Interrupted: &InterruptInfo{ + Data: &WorkflowInterruptInfo{ + OrigInput: &AgentInput{ + Messages: []Message{schema.UserMessage("hello world")}, + }, + SequentialInterruptIndex: 0, + SequentialInterruptInfo: &InterruptInfo{ + Data: "sa1 interrupt data", + }, + LoopIterations: 1, + }, + }, + }, + }, + { + AgentName: "sa2", + RunPath: []RunStep{{"loop"}, {"sa1"}, {"sa2"}, {"sa3"}, {"sa4"}, {"sa1"}, {"sa2"}}, + Action: &AgentAction{ + Interrupted: &InterruptInfo{ + Data: &WorkflowInterruptInfo{ + OrigInput: &AgentInput{ + Messages: []Message{schema.UserMessage("hello world")}, + }, + SequentialInterruptIndex: 1, + SequentialInterruptInfo: &InterruptInfo{ + Data: "sa2 interrupt data", + }, + LoopIterations: 1, + }, + }, + }, + }, + { + AgentName: "sa3", + RunPath: []RunStep{{"loop"}, {"sa1"}, {"sa2"}, {"sa3"}, {"sa4"}, {"sa1"}, {"sa2"}, {"sa3"}}, + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.UserMessage("sa3 completed"), + }, + }, + }, + { + AgentName: "sa4", + RunPath: []RunStep{{"loop"}, {"sa1"}, {"sa2"}, {"sa3"}, {"sa4"}, {"sa1"}, {"sa2"}, {"sa3"}, {"sa4"}}, + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.UserMessage("sa4 completed"), + }, + }, + }, + } + + assert.Equal(t, 8, len(events)) + assert.Equal(t, expectedEvents, events) + + // parallel + a, err = NewParallelAgent(ctx, &ParallelAgentConfig{ + Name: "parallel agent", + SubAgents: []Agent{sa1, sa2, sa3, sa4}, + }) + assert.NoError(t, err) + runner = NewRunner(ctx, RunnerConfig{ + Agent: a, + CheckPointStore: newMyStore(), + }) + iter = runner.Query(ctx, "hello world", WithCheckPointID("1")) + events = []*AgentEvent{} + + for { + event, ok := iter.Next() + if !ok { + break + } + events = append(events, event) + } + assert.Equal(t, 3, len(events)) + + iter, err = runner.Resume(ctx, "1") + assert.NoError(t, err) + _, ok := iter.Next() + assert.False(t, ok) +} + +func TestChatModelInterrupt(t *testing.T) { + ctx := context.Background() + a, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "name", + Description: "description", + Instruction: "instruction", + Model: &myModel{ + validator: func(i int, messages []*schema.Message) bool { + if i > 0 && (len(messages) != 4 || messages[2].Content != "new user message") { + return false + } + return true + }, + messages: []*schema.Message{ + schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "1", + Function: schema.FunctionCall{ + Name: "tool1", + Arguments: "arguments", + }, + }, + }), + schema.AssistantMessage("completed", nil), + }, + }, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{&myTool1{}}, + }, + }, + }) + assert.NoError(t, err) + runner := NewRunner(ctx, RunnerConfig{ + Agent: a, + CheckPointStore: newMyStore(), + }) + iter := runner.Query(ctx, "hello world", WithCheckPointID("1")) + event, ok := iter.Next() + assert.True(t, ok) + event, ok = iter.Next() + assert.True(t, ok) + assert.NoError(t, event.Err) + assert.NotNil(t, event.Action.Interrupted) + event, ok = iter.Next() + assert.False(t, ok) + + iter, err = runner.Resume(ctx, "1", WithHistoryModifier(func(ctx context.Context, messages []Message) []Message { + messages[2].Content = "new user message" + return messages + })) + assert.NoError(t, err) + event, ok = iter.Next() + assert.True(t, ok) + assert.NoError(t, event.Err) + assert.Equal(t, event.Output.MessageOutput.Message.Content, "result") + event, ok = iter.Next() + assert.True(t, ok) + assert.NoError(t, event.Err) + assert.Equal(t, event.Output.MessageOutput.Message.Content, "completed") +} + +func TestChatModelAgentToolInterrupt(t *testing.T) { + sa := &myAgent{ + runner: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Send(&AgentEvent{ + Action: &AgentAction{Interrupted: &InterruptInfo{ + Data: "hello world", + }}, + }) + generator.Close() + return iter + }, + resumer: func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + assert.NotNil(t, info) + assert.False(t, info.EnableStreaming) + assert.Equal(t, "hello world", info.Data) + + o := GetImplSpecificOptions[myAgentOptions](nil, opts...) + if o.interrupt { + iter, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Send(&AgentEvent{ + Action: &AgentAction{Interrupted: &InterruptInfo{ + Data: "hello world", + }}, + }) + generator.Close() + return iter + } + + iter, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Send(&AgentEvent{Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.UserMessage("my agent completed")}}}) + generator.Close() + return iter + }, + } + ctx := context.Background() + a, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{ + Name: "name", + Description: "description", + Instruction: "instruction", + Model: &myModel{ + messages: []*schema.Message{ + schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "1", + Function: schema.FunctionCall{ + Name: "myAgent", + Arguments: "{\"request\":\"123\"}", + }, + }, + }), + schema.AssistantMessage("completed", nil), + }, + }, + ToolsConfig: ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{NewAgentTool(ctx, sa)}, + }, + }, + }) + assert.NoError(t, err) + runner := NewRunner(ctx, RunnerConfig{ + Agent: a, + CheckPointStore: newMyStore(), + }) + + iter := runner.Query(ctx, "hello world", WithCheckPointID("1")) + event, ok := iter.Next() + assert.True(t, ok) + event, ok = iter.Next() + assert.True(t, ok) + assert.NoError(t, event.Err) + assert.NotNil(t, event.Action.Interrupted) + event, ok = iter.Next() + assert.False(t, ok) + + iter, err = runner.Resume(ctx, "1", WithAgentToolRunOptions(map[string][]AgentRunOption{ + "myAgent": {withResume()}, + })) + assert.NoError(t, err) + event, ok = iter.Next() + assert.True(t, ok) + assert.NoError(t, event.Err) + assert.NotNil(t, event.Action.Interrupted) + event, ok = iter.Next() + assert.False(t, ok) + iter, err = runner.Resume(ctx, "1") + assert.NoError(t, err) + event, ok = iter.Next() + assert.True(t, ok) + assert.NoError(t, event.Err) + assert.Equal(t, event.Output.MessageOutput.Message.Content, "my agent completed") + event, ok = iter.Next() + assert.True(t, ok) + assert.NoError(t, event.Err) + assert.Equal(t, event.Output.MessageOutput.Message.Content, "completed") + _, ok = iter.Next() + assert.False(t, ok) +} + +func newMyStore() *myStore { + return &myStore{ + m: map[string][]byte{}, + } +} + +type myStore struct { + m map[string][]byte +} + +func (m *myStore) Set(ctx context.Context, key string, value []byte) error { + m.m[key] = value + return nil +} + +func (m *myStore) Get(ctx context.Context, key string) ([]byte, bool, error) { + v, ok := m.m[key] + return v, ok, nil +} + +type myAgentOptions struct { + interrupt bool +} + +func withResume() AgentRunOption { + return WrapImplSpecificOptFn(func(t *myAgentOptions) { + t.interrupt = true + }) +} + +type myAgent struct { + name string + runner func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] + resumer func(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] +} + +func (m *myAgent) Name(ctx context.Context) string { + if len(m.name) > 0 { + return m.name + } + return "myAgent" +} + +func (m *myAgent) Description(ctx context.Context) string { + return "myAgent description" +} + +func (m *myAgent) Run(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { + return m.runner(ctx, input, options...) +} + +func (m *myAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + return m.resumer(ctx, info, opts...) +} + +type myModel struct { + times int + messages []*schema.Message + validator func(int, []*schema.Message) bool +} + +func (m *myModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + if m.validator != nil && !m.validator(m.times, input) { + return nil, errors.New("invalid input") + } + if m.times >= len(m.messages) { + return nil, errors.New("exceeded max number of messages") + } + t := m.times + m.times++ + return m.messages[t], nil +} + +func (m *myModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + panic("implement me") +} + +func (m *myModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +type myTool1 struct { + times int +} + +func (m *myTool1) Info(ctx context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: "tool1", + Desc: "desc", + }, nil +} + +func (m *myTool1) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + if m.times == 0 { + m.times = 1 + return "", compose.InterruptAndRerun + } + return "result", nil +} + +// Add this test case after the existing TestWorkflowInterrupt function +func TestWorkflowInterruptInvalidDataType(t *testing.T) { + ctx := context.Background() + + // Create a simple workflow agent + sa1 := &myAgent{ + name: "sa1", + runner: func(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iter, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Send(&AgentEvent{ + AgentName: "sa1", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.UserMessage("completed"), + }, + }, + }) + generator.Close() + return iter + }, + } + + a, err := NewSequentialAgent(ctx, &SequentialAgentConfig{ + Name: "sequential", + Description: "sequential agent", + SubAgents: []Agent{sa1}, + }) + assert.NoError(t, err) + + // Cast to workflowAgent to access Resume method directly + workflowAgent := a.(*flowAgent).Agent.(*workflowAgent) + + // Create ResumeInfo with invalid Data type (not *WorkflowInterruptInfo) + resumeInfo := &ResumeInfo{ + EnableStreaming: false, + InterruptInfo: &InterruptInfo{ + Data: "invalid data type", // This should be *WorkflowInterruptInfo but we pass string + }, + } + + // Call Resume method directly to trigger the error path + iter := workflowAgent.Resume(ctx, resumeInfo) + + // Verify that an error event is generated + event, ok := iter.Next() + assert.True(t, ok) + assert.NotNil(t, event.Err) + assert.Contains(t, event.Err.Error(), "type of InterruptInfo.Data is expected to") + assert.Contains(t, event.Err.Error(), "actual: string") + + // Verify no more events + _, ok = iter.Next() + assert.False(t, ok) +} diff --git a/adk/prebuilt/planexecute/plan_execute.go b/adk/prebuilt/planexecute/plan_execute.go new file mode 100644 index 00000000..06878f15 --- /dev/null +++ b/adk/prebuilt/planexecute/plan_execute.go @@ -0,0 +1,898 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package planexecute + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "runtime/debug" + "strings" + + "github.com/bytedance/sonic" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/internal/safe" + "github.com/cloudwego/eino/schema" +) + +// Plan represents an execution plan with a sequence of actionable steps. +// It supports JSON serialization and deserialization while providing access to the first step. +type Plan interface { + // FirstStep returns the first step to be executed in the plan. + FirstStep() string + + // Marshaler serializes the Plan into JSON. + // The resulting JSON can be used in prompt templates. + json.Marshaler + // Unmarshaler deserializes JSON content into the Plan. + // This processes output from structured chat models or tool calls into the Plan structure. + json.Unmarshaler +} + +// NewPlan is a function type that creates a new Plan instance. +type NewPlan func(ctx context.Context) Plan + +// defaultPlan is the default implementation of the Plan interface. +// +// JSON Schema: +// +// { +// "type": "object", +// "properties": { +// "steps": { +// "type": "array", +// "items": { +// "type": "string" +// }, +// "description": "Ordered list of actions to be taken. Each step should be clear, actionable, and arranged in a logical sequence." +// } +// }, +// "required": ["steps"] +// } +type defaultPlan struct { + // Steps contains the ordered list of actions to be taken. + // Each step should be clear, actionable, and arranged in a logical sequence. + Steps []string `json:"steps"` +} + +// FirstStep returns the first step in the plan or an empty string if no steps exist. +func (p *defaultPlan) FirstStep() string { + if len(p.Steps) == 0 { + return "" + } + return p.Steps[0] +} + +func (p *defaultPlan) MarshalJSON() ([]byte, error) { + type planTyp defaultPlan + return sonic.Marshal((*planTyp)(p)) +} + +func (p *defaultPlan) UnmarshalJSON(bytes []byte) error { + type planTyp defaultPlan + return sonic.Unmarshal(bytes, (*planTyp)(p)) +} + +// Response represents the final response to the user. +// This struct is used for JSON serialization/deserialization of the final response +// generated by the model. +type Response struct { + // Response is the complete response to provide to the user. + // This field is required. + Response string `json:"response"` +} + +var ( + // PlanToolInfo defines the schema for the Plan tool that can be used with ToolCallingChatModel. + // This schema instructs the model to generate a structured plan with ordered steps. + PlanToolInfo = schema.ToolInfo{ + Name: "Plan", + Desc: "Plan with a list of steps to execute in order. Each step should be clear, actionable, and arranged in a logical sequence. The output will be used to guide the execution process.", + ParamsOneOf: schema.NewParamsOneOfByParams( + map[string]*schema.ParameterInfo{ + "steps": { + Type: schema.Array, + ElemInfo: &schema.ParameterInfo{Type: schema.String}, + Desc: "different steps to follow, should be in sorted order", + Required: true, + }, + }, + ), + } + + // RespondToolInfo defines the schema for the response tool that can be used with ToolCallingChatModel. + // This schema instructs the model to generate a direct response to the user. + RespondToolInfo = schema.ToolInfo{ + Name: "Respond", + Desc: "Generate a direct response to the user. Use this tool when you have all the information needed to provide a final answer.", + ParamsOneOf: schema.NewParamsOneOfByParams( + map[string]*schema.ParameterInfo{ + "response": { + Type: schema.String, + Desc: "The complete response to provide to the user", + Required: true, + }, + }, + ), + } + + // PlannerPrompt is the prompt template for the planner. + // It provides context and guidance to the planner on how to generate the Plan. + PlannerPrompt = prompt.FromMessages(schema.FString, + schema.SystemMessage(`You are an expert planning agent. Given an objective, create a comprehensive step-by-step plan to achieve the objective. + +## YOUR TASK +Analyze the objective and generate a strategic plan that breaks down the goal into manageable, executable steps. + +## PLANNING REQUIREMENTS +Each step in your plan must be: +- **Specific and actionable**: Clear instructions that can be executed without ambiguity +- **Self-contained**: Include all necessary context, parameters, and requirements +- **Independently executable**: Can be performed by an agent without dependencies on other steps +- **Logically sequenced**: Arranged in optimal order for efficient execution +- **Objective-focused**: Directly contribute to achieving the main goal + +## PLANNING GUIDELINES +- Eliminate redundant or unnecessary steps +- Include relevant constraints, parameters, and success criteria for each step +- Ensure the final step produces a complete answer or deliverable +- Anticipate potential challenges and include mitigation strategies +- Structure steps to build upon each other logically +- Provide sufficient detail for successful execution + +## QUALITY CRITERIA +- Plan completeness: Does it address all aspects of the objective? +- Step clarity: Can each step be understood and executed independently? +- Logical flow: Do steps follow a sensible progression? +- Efficiency: Is this the most direct path to the objective? +- Adaptability: Can the plan handle unexpected results or changes?`), + schema.MessagesPlaceholder("input", false), + ) + + // ExecutorPrompt is the prompt template for the executor. + // It provides context and guidance to the executor on how to execute the Task. + ExecutorPrompt = prompt.FromMessages(schema.FString, + schema.SystemMessage(`You are a diligent and meticulous executor agent. Follow the given plan and execute your tasks carefully and thoroughly.`), + schema.UserMessage(`## OBJECTIVE +{input} +## Given the following plan: +{plan} +## COMPLETED STEPS & RESULTS +{executed_steps} +## Your task is to execute the first step, which is: +{step}`)) + + // ReplannerPrompt is the prompt template for the replanner. + // It provides context and guidance to the replanner on how to regenerate the Plan. + ReplannerPrompt = prompt.FromMessages(schema.FString, + schema.SystemMessage( + `You are going to review the progress toward an objective. Analyze the current state and determine the optimal next action. + +## YOUR TASK +Based on the progress above, you MUST choose exactly ONE action: + +### Option 1: COMPLETE (if objective is fully achieved) +Call '{respond_tool}' with: +- A comprehensive final answer +- Clear conclusion summarizing how the objective was met +- Key insights from the execution process + +### Option 2: CONTINUE (if more work is needed) +Call '{plan_tool}' with a revised plan that: +- Contains ONLY remaining steps (exclude completed ones) +- Incorporates lessons learned from executed steps +- Addresses any gaps or issues discovered +- Maintains logical step sequence + +## PLANNING REQUIREMENTS +Each step in your plan must be: +- **Specific and actionable**: Clear instructions that can be executed without ambiguity +- **Self-contained**: Include all necessary context, parameters, and requirements +- **Independently executable**: Can be performed by an agent without dependencies on other steps +- **Logically sequenced**: Arranged in optimal order for efficient execution +- **Objective-focused**: Directly contribute to achieving the main goal + +## PLANNING GUIDELINES +- Eliminate redundant or unnecessary steps +- Adapt strategy based on new information +- Include relevant constraints, parameters, and success criteria for each step + +## DECISION CRITERIA +- Has the original objective been completely satisfied? +- Are there any remaining requirements or sub-goals? +- Do the results suggest a need for strategy adjustment? +- What specific actions are still required?`), + schema.UserMessage(`## OBJECTIVE +{input} + +## ORIGINAL PLAN +{plan} + +## COMPLETED STEPS & RESULTS +{executed_steps}`), + ) +) + +const ( + // UserInputSessionKey is the session key for the user input. + UserInputSessionKey = "UserInput" + + // PlanSessionKey is the session key for the plan. + PlanSessionKey = "Plan" + + // ExecutedStepSessionKey is the session key for the execute result. + ExecutedStepSessionKey = "ExecutedStep" + + // ExecutedStepsSessionKey is the session key for the execute results. + ExecutedStepsSessionKey = "ExecutedSteps" +) + +// PlannerConfig provides configuration options for creating a planner agent. +// There are two ways to configure the planner to generate structured Plan output: +// 1. Use ChatModelWithFormattedOutput: A model pre-configured to output in the Plan format +// 2. Use ToolCallingChatModel + ToolInfo: A model that uses tool calling to generate +// the Plan structure +type PlannerConfig struct { + // ChatModelWithFormattedOutput is a model pre-configured to output in the Plan format. + // Create this by configuring a model to output structured data directly. + // See example: https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go + ChatModelWithFormattedOutput model.BaseChatModel + + // ToolCallingChatModel is a model that supports tool calling capabilities. + // When provided with ToolInfo, it will use tool calling to generate the Plan structure. + ToolCallingChatModel model.ToolCallingChatModel + + // ToolInfo defines the schema for the Plan structure when using tool calling. + // Optional. If not provided, PlanToolInfo will be used as the default. + ToolInfo *schema.ToolInfo + + // GenInputFn is a function that generates the input messages for the planner. + // Optional. If not provided, defaultGenPlannerInputFn will be used. + GenInputFn GenPlannerModelInputFn + + // NewPlan creates a new Plan instance for JSON. + // The returned Plan will be used to unmarshal the model-generated JSON output. + // Optional. If not provided, defaultNewPlan will be used. + NewPlan NewPlan +} + +// GenPlannerModelInputFn is a function type that generates input messages for the planner. +type GenPlannerModelInputFn func(ctx context.Context, userInput []adk.Message) ([]adk.Message, error) + +func defaultNewPlan(ctx context.Context) Plan { + return &defaultPlan{} +} + +func defaultGenPlannerInputFn(ctx context.Context, userInput []adk.Message) ([]adk.Message, error) { + msgs, err := PlannerPrompt.Format(ctx, map[string]any{ + "input": userInput, + }) + if err != nil { + return nil, err + } + return msgs, nil +} + +type planner struct { + toolCall bool + chatModel model.BaseChatModel + genInputFn GenPlannerModelInputFn + newPlan NewPlan +} + +func (p *planner) Name(_ context.Context) string { + return "Planner" +} + +func (p *planner) Description(_ context.Context) string { + return "a planner agent" +} + +func argToContent(msg adk.Message) (adk.Message, error) { + if len(msg.ToolCalls) == 0 { + return nil, schema.ErrNoValue + } + + return schema.AssistantMessage(msg.ToolCalls[0].Function.Arguments, nil), nil +} + +func (p *planner) Run(ctx context.Context, input *adk.AgentInput, + _ ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { + + iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() + + adk.AddSessionValue(ctx, UserInputSessionKey, input.Messages) + + go func() { + defer func() { + panicErr := recover() + if panicErr != nil { + e := safe.NewPanicErr(panicErr, debug.Stack()) + generator.Send(&adk.AgentEvent{Err: e}) + } + + generator.Close() + }() + + msgs, err := p.genInputFn(ctx, input.Messages) + if err != nil { + generator.Send(&adk.AgentEvent{Err: err}) + return + } + var modelCallOptions []model.Option + if p.toolCall { + modelCallOptions = append(modelCallOptions, model.WithToolChoice(schema.ToolChoiceForced)) + } + + var msg adk.Message + if input.EnableStreaming { + s, err_ := p.chatModel.Stream(ctx, msgs, modelCallOptions...) + if err_ != nil { + generator.Send(&adk.AgentEvent{Err: err_}) + return + } + + ss := s.Copy(2) + var sOutput *schema.StreamReader[*schema.Message] + if p.toolCall { + sOutput = schema.StreamReaderWithConvert(ss[0], argToContent) + } else { + sOutput = ss[0] + } + + event := adk.EventFromMessage(nil, sOutput, schema.Assistant, "") + generator.Send(event) + + msg, err_ = schema.ConcatMessageStream(ss[1]) + if err_ != nil { + generator.Send(&adk.AgentEvent{Err: err_}) + return + } + + if p.toolCall && len(msg.ToolCalls) == 0 { + generator.Send(&adk.AgentEvent{Err: errors.New("no tool call")}) + return + } + } else { + var err_ error + msg, err_ = p.chatModel.Generate(ctx, msgs, modelCallOptions...) + if err_ != nil { + generator.Send(&adk.AgentEvent{Err: err_}) + return + } + + var output adk.Message + if p.toolCall { + if len(msg.ToolCalls) == 0 { + generator.Send(&adk.AgentEvent{Err: errors.New("no tool call")}) + return + } + output = schema.AssistantMessage(msg.ToolCalls[0].Function.Arguments, nil) + } else { + output = msg + } + + event := adk.EventFromMessage(output, nil, schema.Assistant, "") + generator.Send(event) + } + + var planJSON string + if p.toolCall { + planJSON = msg.ToolCalls[0].Function.Arguments + } else { + planJSON = msg.Content + } + plan := p.newPlan(ctx) + err = plan.UnmarshalJSON([]byte(planJSON)) + if err != nil { + err = fmt.Errorf("unmarshal plan error: %w", err) + generator.Send(&adk.AgentEvent{Err: err}) + return + } + + adk.AddSessionValue(ctx, PlanSessionKey, plan) + }() + + return iterator +} + +// NewPlanner creates a new planner agent based on the provided configuration. +// The planner agent uses either ChatModelWithFormattedOutput or ToolCallingChatModel+ToolInfo +// to generate structured Plan output. +// +// If ChatModelWithFormattedOutput is provided, it will be used directly. +// If ToolCallingChatModel is provided, it will be configured with ToolInfo (or PlanToolInfo by default) +// to generate structured Plan output. +func NewPlanner(_ context.Context, cfg *PlannerConfig) (adk.Agent, error) { + var chatModel model.BaseChatModel + var toolCall bool + if cfg.ChatModelWithFormattedOutput != nil { + chatModel = cfg.ChatModelWithFormattedOutput + } else { + toolCall = true + toolInfo := cfg.ToolInfo + if toolInfo == nil { + toolInfo = &PlanToolInfo + } + + var err error + chatModel, err = cfg.ToolCallingChatModel.WithTools([]*schema.ToolInfo{toolInfo}) + if err != nil { + return nil, err + } + } + + inputFn := cfg.GenInputFn + if inputFn == nil { + inputFn = defaultGenPlannerInputFn + } + + planParser := cfg.NewPlan + if planParser == nil { + planParser = defaultNewPlan + } + + return &planner{ + toolCall: toolCall, + chatModel: chatModel, + genInputFn: inputFn, + newPlan: planParser, + }, nil +} + +// ExecutionContext is the input information for the executor and the planner. +type ExecutionContext struct { + UserInput []adk.Message + Plan Plan + ExecutedSteps []ExecutedStep +} + +// GenModelInputFn is a function that generates the input messages for the executor and the planner. +type GenModelInputFn func(ctx context.Context, in *ExecutionContext) ([]adk.Message, error) + +// ExecutorConfig provides configuration options for creating an executor agent. +type ExecutorConfig struct { + // Model is the chat model used by the executor. + Model model.ToolCallingChatModel + + // ToolsConfig specifies the tools available to the executor. + ToolsConfig adk.ToolsConfig + + // MaxIterations defines the upper limit of ChatModel generation cycles. + // The agent will terminate with an error if this limit is exceeded. + // Optional. Defaults to 20. + MaxIterations int + + // GenInputFn generates the input messages for the Executor. + // Optional. If not provided, defaultGenExecutorInputFn will be used. + GenInputFn GenModelInputFn +} + +type ExecutedStep struct { + Step string + Result string +} + +// NewExecutor creates a new executor agent. +func NewExecutor(ctx context.Context, cfg *ExecutorConfig) (adk.Agent, error) { + + genInputFn := cfg.GenInputFn + if genInputFn == nil { + genInputFn = defaultGenExecutorInputFn + } + genInput := func(ctx context.Context, instruction string, _ *adk.AgentInput) ([]adk.Message, error) { + + plan, ok := adk.GetSessionValue(ctx, PlanSessionKey) + if !ok { + panic("impossible: plan not found") + } + plan_ := plan.(Plan) + + userInput, ok := adk.GetSessionValue(ctx, UserInputSessionKey) + if !ok { + panic("impossible: user input not found") + } + userInput_ := userInput.([]adk.Message) + + var executedSteps_ []ExecutedStep + executedStep, ok := adk.GetSessionValue(ctx, ExecutedStepsSessionKey) + if ok { + executedSteps_ = executedStep.([]ExecutedStep) + } + + in := &ExecutionContext{ + UserInput: userInput_, + Plan: plan_, + ExecutedSteps: executedSteps_, + } + + msgs, err := genInputFn(ctx, in) + if err != nil { + return nil, err + } + + return msgs, nil + } + + agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ + Name: "Executor", + Description: "an executor agent", + Model: cfg.Model, + ToolsConfig: cfg.ToolsConfig, + GenModelInput: genInput, + MaxIterations: cfg.MaxIterations, + OutputKey: ExecutedStepSessionKey, + }) + if err != nil { + return nil, err + } + + return agent, nil +} + +func defaultGenExecutorInputFn(ctx context.Context, in *ExecutionContext) ([]adk.Message, error) { + + planContent, err := in.Plan.MarshalJSON() + if err != nil { + return nil, err + } + + userMsgs, err := ExecutorPrompt.Format(ctx, map[string]any{ + "input": formatInput(in.UserInput), + "plan": string(planContent), + "executed_steps": formatExecutedSteps(in.ExecutedSteps), + "step": in.Plan.FirstStep(), + }) + if err != nil { + return nil, err + } + + return userMsgs, nil +} + +type replanner struct { + chatModel model.ToolCallingChatModel + planTool *schema.ToolInfo + respondTool *schema.ToolInfo + + genInputFn GenModelInputFn + newPlan NewPlan +} + +type ReplannerConfig struct { + // ChatModel is the model that supports tool calling capabilities. + // It will be configured with PlanTool and RespondTool to generate updated plans or responses. + ChatModel model.ToolCallingChatModel + + // PlanTool defines the schema for the Plan tool that can be used with ToolCallingChatModel. + // Optional. If not provided, the default PlanToolInfo will be used. + PlanTool *schema.ToolInfo + + // RespondTool defines the schema for the response tool that can be used with ToolCallingChatModel. + // Optional. If not provided, the default RespondToolInfo will be used. + RespondTool *schema.ToolInfo + + // GenInputFn generates the input messages for the Replanner. + // Optional. If not provided, buildGenReplannerInputFn will be used. + GenInputFn GenModelInputFn + + // NewPlan creates a new Plan instance. + // The returned Plan will be used to unmarshal the model-generated JSON output from PlanTool. + // Optional. If not provided, defaultNewPlan will be used. + NewPlan NewPlan +} + +// formatInput formats the input messages into a string. +func formatInput(input []adk.Message) string { + var sb strings.Builder + for _, msg := range input { + sb.WriteString(msg.Content) + sb.WriteString("\n") + } + + return sb.String() +} + +func formatExecutedSteps(results []ExecutedStep) string { + var sb strings.Builder + for _, result := range results { + sb.WriteString(fmt.Sprintf("Step: %s\nResult: %s\n\n", result.Step, result.Result)) + } + + return sb.String() +} + +func (r *replanner) Name(_ context.Context) string { + return "Replanner" +} + +func (r *replanner) Description(_ context.Context) string { + return "a replanner agent" +} + +func (r *replanner) genInput(ctx context.Context) ([]adk.Message, error) { + + executedStep, ok := adk.GetSessionValue(ctx, ExecutedStepSessionKey) + if !ok { + panic("impossible: execute result not found") + } + executedStep_ := executedStep.(string) + + plan, ok := adk.GetSessionValue(ctx, PlanSessionKey) + if !ok { + panic("impossible: plan not found") + } + plan_ := plan.(Plan) + step := plan_.FirstStep() + + var executedSteps_ []ExecutedStep + executedSteps, ok := adk.GetSessionValue(ctx, ExecutedStepsSessionKey) + if ok { + executedSteps_ = executedSteps.([]ExecutedStep) + } + + executedSteps_ = append(executedSteps_, ExecutedStep{ + Step: step, + Result: executedStep_, + }) + adk.AddSessionValue(ctx, ExecutedStepsSessionKey, executedSteps_) + + userInput, ok := adk.GetSessionValue(ctx, UserInputSessionKey) + if !ok { + panic("impossible: user input not found") + } + userInput_ := userInput.([]adk.Message) + + in := &ExecutionContext{ + UserInput: userInput_, + Plan: plan_, + ExecutedSteps: executedSteps_, + } + genInputFn := r.genInputFn + if genInputFn == nil { + genInputFn = buildGenReplannerInputFn(r.planTool.Name, r.respondTool.Name) + } + msgs, err := genInputFn(ctx, in) + if err != nil { + return nil, err + } + + return msgs, nil +} + +func (r *replanner) Run(ctx context.Context, input *adk.AgentInput, _ ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { + iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() + msgs, err := r.genInput(ctx) + if err != nil { + generator.Send(&adk.AgentEvent{Err: err}) + generator.Close() + return iterator + } + + go func() { + defer func() { + panicErr := recover() + if panicErr != nil { + e := safe.NewPanicErr(panicErr, debug.Stack()) + generator.Send(&adk.AgentEvent{Err: e}) + } + + generator.Close() + }() + + callOpt := model.WithToolChoice(schema.ToolChoiceForced) + + var planMsg adk.Message + if input.EnableStreaming { + var s adk.MessageStream + s, err = r.chatModel.Stream(ctx, msgs, callOpt) + if err != nil { + generator.Send(&adk.AgentEvent{Err: err}) + return + } + + ss := s.Copy(2) + sOutput := schema.StreamReaderWithConvert(ss[0], argToContent) + event := adk.EventFromMessage(nil, sOutput, schema.Assistant, "") + generator.Send(event) + + var chunks []adk.Message + s = ss[1] + var isResponse bool + for { + chunk, err_ := s.Recv() + if err_ != nil { + if err_ == io.EOF { + break + } + + generator.Send(&adk.AgentEvent{Err: err_}) + return + } + + if len(chunk.ToolCalls) > 0 && chunk.ToolCalls[0].Function.Name == r.respondTool.Name { + isResponse = true + break + } + + chunks = append(chunks, chunk) + } + s.Close() + + if isResponse { + action := adk.NewExitAction() + generator.Send(&adk.AgentEvent{Action: action}) + return + } + + planMsg, err = schema.ConcatMessages(chunks) + if err != nil { + generator.Send(&adk.AgentEvent{Err: err}) + return + } + + if len(planMsg.ToolCalls) == 0 { + generator.Send(&adk.AgentEvent{Err: errors.New("no tool call")}) + return + } + } else { + var msg adk.Message + msg, err = r.chatModel.Generate(ctx, msgs, callOpt) + if err != nil { + generator.Send(&adk.AgentEvent{Err: err}) + return + } + + if len(msg.ToolCalls) > 0 { + output := schema.AssistantMessage(msg.ToolCalls[0].Function.Arguments, nil) + event := adk.EventFromMessage(output, nil, schema.Assistant, "") + generator.Send(event) + + if len(msg.ToolCalls) > 0 && msg.ToolCalls[0].Function.Name == r.respondTool.Name { + action := adk.NewExitAction() + generator.Send(&adk.AgentEvent{Action: action}) + return + } + + planMsg = msg + } else { + generator.Send(&adk.AgentEvent{Err: errors.New("no tool call")}) + return + } + } + + // handle plan tool call + if planMsg.ToolCalls[0].Function.Name != r.planTool.Name { + errMsg := fmt.Sprintf("unexpected tool call: %s", planMsg.ToolCalls[0].Function.Name) + generator.Send(&adk.AgentEvent{Err: errors.New(errMsg)}) + return + } + + plan_ := r.newPlan(ctx) + err = plan_.UnmarshalJSON([]byte(planMsg.ToolCalls[0].Function.Arguments)) + if err != nil { + err = fmt.Errorf("unmarshal plan error: %w", err) + generator.Send(&adk.AgentEvent{Err: err}) + return + } + + adk.AddSessionValue(ctx, PlanSessionKey, plan_) + }() + + return iterator +} + +func buildGenReplannerInputFn(planToolName, respondToolName string) GenModelInputFn { + return func(ctx context.Context, in *ExecutionContext) ([]adk.Message, error) { + planContent, err := in.Plan.MarshalJSON() + if err != nil { + return nil, err + } + msgs, err := ReplannerPrompt.Format(ctx, map[string]any{ + "plan": string(planContent), + "input": formatInput(in.UserInput), + "executed_steps": formatExecutedSteps(in.ExecutedSteps), + "plan_tool": planToolName, + "respond_tool": respondToolName, + }) + if err != nil { + return nil, err + } + + return msgs, nil + } +} + +func NewReplanner(_ context.Context, cfg *ReplannerConfig) (adk.Agent, error) { + planTool := cfg.PlanTool + if planTool == nil { + planTool = &PlanToolInfo + } + + respondTool := cfg.RespondTool + if respondTool == nil { + respondTool = &RespondToolInfo + } + + chatModel, err := cfg.ChatModel.WithTools([]*schema.ToolInfo{planTool, respondTool}) + if err != nil { + return nil, err + } + + planParser := cfg.NewPlan + if planParser == nil { + planParser = defaultNewPlan + } + + return &replanner{ + chatModel: chatModel, + planTool: planTool, + respondTool: respondTool, + genInputFn: cfg.GenInputFn, + newPlan: planParser, + }, nil +} + +// Config provides configuration options for creating a plan-execute-replan agent. +type Config struct { + // Planner specifies the agent that generates the plan. + // You can use provided NewPlanner to create a planner agent. + Planner adk.Agent + + // Executor specifies the agent that executes the plan generated by planner or replanner. + // You can use provided NewExecutor to create an executor agent. + Executor adk.Agent + + // Replanner specifies the agent that replans the plan. + // You can use provided NewReplanner to create a replanner agent. + Replanner adk.Agent + + // MaxIterations defines the maximum number of loops for 'execute-replan'. + // Optional. If not provided, 10 will be used as the default. + MaxIterations int +} + +// New creates a new plan-execute-replan agent with the given configuration. +// The plan-execute-replan pattern works in three phases: +// 1. Planning: Generate a structured plan with clear, actionable steps +// 2. Execution: Execute the first step of the plan +// 3. Replanning: Evaluate progress and either complete the task or revise the plan +// This approach enables complex problem-solving through iterative refinement. +func New(ctx context.Context, cfg *Config) (adk.Agent, error) { + maxIterations := cfg.MaxIterations + if maxIterations <= 0 { + maxIterations = 10 + } + loop, err := adk.NewLoopAgent(ctx, &adk.LoopAgentConfig{ + Name: "execute_replan", + SubAgents: []adk.Agent{cfg.Executor, cfg.Replanner}, + MaxIterations: maxIterations, + }) + if err != nil { + return nil, err + } + + return adk.NewSequentialAgent(ctx, &adk.SequentialAgentConfig{ + Name: "plan_execute_replan", + SubAgents: []adk.Agent{cfg.Planner, loop}, + }) +} diff --git a/adk/prebuilt/planexecute/plan_execute_test.go b/adk/prebuilt/planexecute/plan_execute_test.go new file mode 100644 index 00000000..6b30d65f --- /dev/null +++ b/adk/prebuilt/planexecute/plan_execute_test.go @@ -0,0 +1,709 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package planexecute + +import ( + "context" + "testing" + + "github.com/bytedance/sonic" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/components/model" + mockAdk "github.com/cloudwego/eino/internal/mock/adk" + mockModel "github.com/cloudwego/eino/internal/mock/components/model" + "github.com/cloudwego/eino/schema" +) + +// TestNewPlanner tests the NewPlanner function with ChatModelWithFormattedOutput +func TestNewPlannerWithFormattedOutput(t *testing.T) { + ctx := context.Background() + + // Create a mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create a mock chat model + mockChatModel := mockModel.NewMockBaseChatModel(ctrl) + + // Create the PlannerConfig + conf := &PlannerConfig{ + ChatModelWithFormattedOutput: mockChatModel, + } + + // Create the planner + p, err := NewPlanner(ctx, conf) + assert.NoError(t, err) + assert.NotNil(t, p) + + // Verify the planner's name and description + assert.Equal(t, "Planner", p.Name(ctx)) + assert.Equal(t, "a planner agent", p.Description(ctx)) +} + +// TestNewPlannerWithToolCalling tests the NewPlanner function with ToolCallingChatModel +func TestNewPlannerWithToolCalling(t *testing.T) { + ctx := context.Background() + + // Create a mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create a mock tool calling chat model + mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) + mockToolCallingModel.EXPECT().WithTools(gomock.Any()).Return(mockToolCallingModel, nil).Times(1) + + // Create the PlannerConfig + conf := &PlannerConfig{ + ToolCallingChatModel: mockToolCallingModel, + // Use default instruction and tool info + } + + // Create the planner + p, err := NewPlanner(ctx, conf) + assert.NoError(t, err) + assert.NotNil(t, p) + + // Verify the planner's name and description + assert.Equal(t, "Planner", p.Name(ctx)) + assert.Equal(t, "a planner agent", p.Description(ctx)) +} + +// TestPlannerRunWithFormattedOutput tests the Run method of a planner created with ChatModelWithFormattedOutput +func TestPlannerRunWithFormattedOutput(t *testing.T) { + ctx := context.Background() + + // Create a mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create a mock chat model + mockChatModel := mockModel.NewMockBaseChatModel(ctrl) + + // Create a plan response + planJSON := `{"steps":["Step 1", "Step 2", "Step 3"]}` + planMsg := schema.AssistantMessage(planJSON, nil) + + // Mock the Generate method + mockChatModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).Return(planMsg, nil).Times(1) + + // Create the PlannerConfig + conf := &PlannerConfig{ + ChatModelWithFormattedOutput: mockChatModel, + } + + // Create the planner + p, err := NewPlanner(ctx, conf) + assert.NoError(t, err) + + // Run the planner + runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: p}) + iterator := runner.Run(ctx, []adk.Message{schema.UserMessage("Plan this task")}) + + // Get the event from the iterator + event, ok := iterator.Next() + assert.True(t, ok) + assert.Nil(t, event.Err) + msg, _, err := adk.GetMessage(event) + assert.NoError(t, err) + assert.Equal(t, planMsg.Content, msg.Content) + + event, ok = iterator.Next() + assert.False(t, ok) + + plan := defaultNewPlan(ctx) + err = plan.UnmarshalJSON([]byte(msg.Content)) + assert.NoError(t, err) + plan_ := plan.(*defaultPlan) + assert.Equal(t, 3, len(plan_.Steps)) + assert.Equal(t, "Step 1", plan_.Steps[0]) + assert.Equal(t, "Step 2", plan_.Steps[1]) + assert.Equal(t, "Step 3", plan_.Steps[2]) +} + +// TestPlannerRunWithToolCalling tests the Run method of a planner created with ToolCallingChatModel +func TestPlannerRunWithToolCalling(t *testing.T) { + ctx := context.Background() + + // Create a mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create a mock tool calling chat model + mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) + + // Create a tool call response with a plan + planArgs := `{"steps":["Step 1", "Step 2", "Step 3"]}` + toolCall := schema.ToolCall{ + ID: "tool_call_id", + Type: "function", + Function: schema.FunctionCall{ + Name: "Plan", // This should match PlanToolInfo.Name + Arguments: planArgs, + }, + } + + toolCallMsg := schema.AssistantMessage("", nil) + toolCallMsg.ToolCalls = []schema.ToolCall{toolCall} + + // Mock the WithTools method to return a model that will be used for Generate + mockToolCallingModel.EXPECT().WithTools(gomock.Any()).Return(mockToolCallingModel, nil).Times(1) + + // Mock the Generate method to return the tool call message + mockToolCallingModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).Return(toolCallMsg, nil).Times(1) + + // Create the PlannerConfig with ToolCallingChatModel + conf := &PlannerConfig{ + ToolCallingChatModel: mockToolCallingModel, + // Use default instruction and tool info + } + + // Create the planner + p, err := NewPlanner(ctx, conf) + assert.NoError(t, err) + + // Run the planner + runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: p}) + iterator := runner.Run(ctx, []adk.Message{schema.UserMessage("no input")}) + + // Get the event from the iterator + event, ok := iterator.Next() + assert.True(t, ok) + assert.Nil(t, event.Err) + + msg, _, err := adk.GetMessage(event) + assert.NoError(t, err) + assert.Equal(t, planArgs, msg.Content) + + _, ok = iterator.Next() + assert.False(t, ok) + + plan := defaultNewPlan(ctx) + err = plan.UnmarshalJSON([]byte(msg.Content)) + assert.NoError(t, err) + plan_ := plan.(*defaultPlan) + assert.NoError(t, err) + assert.Equal(t, 3, len(plan_.Steps)) + assert.Equal(t, "Step 1", plan_.Steps[0]) + assert.Equal(t, "Step 2", plan_.Steps[1]) + assert.Equal(t, "Step 3", plan_.Steps[2]) +} + +// TestNewExecutor tests the NewExecutor function +func TestNewExecutor(t *testing.T) { + ctx := context.Background() + + // Create a mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create a mock tool calling chat model + mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) + + // Create the ExecutorConfig + conf := &ExecutorConfig{ + Model: mockToolCallingModel, + MaxIterations: 3, + } + + // Create the executor + executor, err := NewExecutor(ctx, conf) + assert.NoError(t, err) + assert.NotNil(t, executor) + + // Verify the executor's name and description + assert.Equal(t, "Executor", executor.Name(ctx)) + assert.Equal(t, "an executor agent", executor.Description(ctx)) +} + +// TestExecutorRun tests the Run method of the executor +func TestExecutorRun(t *testing.T) { + ctx := context.Background() + + // Create a mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create a mock tool calling chat model + mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) + + // Store a plan in the session + plan := &defaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} + adk.AddSessionValue(ctx, PlanSessionKey, plan) + + // Set up expectations for the mock model + // The model should return the last user message as its response + mockToolCallingModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, messages []*schema.Message, opts ...model.Option) (*schema.Message, error) { + // Find the last user message + var lastUserMessage string + for _, msg := range messages { + if msg.Role == schema.User { + lastUserMessage = msg.Content + } + } + // Return the last user message as the model's response + return schema.AssistantMessage(lastUserMessage, nil), nil + }).Times(1) + + // Create the ExecutorConfig + conf := &ExecutorConfig{ + Model: mockToolCallingModel, + MaxIterations: 3, + } + + // Create the executor + executor, err := NewExecutor(ctx, conf) + assert.NoError(t, err) + + // Run the executor + runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: executor}) + iterator := runner.Run(ctx, []adk.Message{schema.UserMessage("no input")}, + adk.WithSessionValues(map[string]any{ + PlanSessionKey: plan, + UserInputSessionKey: []adk.Message{schema.UserMessage("no input")}, + }), + ) + + // Get the event from the iterator + event, ok := iterator.Next() + assert.True(t, ok) + assert.Nil(t, event.Err) + assert.NotNil(t, event.Output) + assert.NotNil(t, event.Output.MessageOutput) + msg, _, err := adk.GetMessage(event) + assert.NoError(t, err) + t.Logf("executor model input msg:\n %s\n", msg.Content) + + _, ok = iterator.Next() + assert.False(t, ok) +} + +// TestNewReplanner tests the NewReplanner function +func TestNewReplanner(t *testing.T) { + ctx := context.Background() + + // Create a mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create a mock tool calling chat model + mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) + // Mock the WithTools method + mockToolCallingModel.EXPECT().WithTools(gomock.Any()).Return(mockToolCallingModel, nil).Times(1) + + // Create plan and respond tools + planTool := &schema.ToolInfo{ + Name: "Plan", + Desc: "Plan tool", + } + + respondTool := &schema.ToolInfo{ + Name: "Respond", + Desc: "Respond tool", + } + + // Create the ReplannerConfig + conf := &ReplannerConfig{ + ChatModel: mockToolCallingModel, + PlanTool: planTool, + RespondTool: respondTool, + } + + // Create the replanner + rp, err := NewReplanner(ctx, conf) + assert.NoError(t, err) + assert.NotNil(t, rp) + + // Verify the replanner's name and description + assert.Equal(t, "Replanner", rp.Name(ctx)) + assert.Equal(t, "a replanner agent", rp.Description(ctx)) +} + +// TestReplannerRunWithPlan tests the Replanner's ability to use the plan_tool +func TestReplannerRunWithPlan(t *testing.T) { + ctx := context.Background() + + // Create a mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create a mock tool calling chat model + mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) + + // Create plan and respond tools + planTool := &schema.ToolInfo{ + Name: "Plan", + Desc: "Plan tool", + } + + respondTool := &schema.ToolInfo{ + Name: "Respond", + Desc: "Respond tool", + } + + // Create a tool call response for the Plan tool + planArgs := `{"steps":["Updated Step 1", "Updated Step 2"]}` + toolCall := schema.ToolCall{ + ID: "tool_call_id", + Type: "function", + Function: schema.FunctionCall{ + Name: planTool.Name, + Arguments: planArgs, + }, + } + + toolCallMsg := schema.AssistantMessage("", nil) + toolCallMsg.ToolCalls = []schema.ToolCall{toolCall} + + // Mock the Generate method + mockToolCallingModel.EXPECT().WithTools(gomock.Any()).Return(mockToolCallingModel, nil).Times(1) + mockToolCallingModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).Return(toolCallMsg, nil).Times(1) + + // Create the ReplannerConfig + conf := &ReplannerConfig{ + ChatModel: mockToolCallingModel, + PlanTool: planTool, + RespondTool: respondTool, + } + + // Create the replanner + rp, err := NewReplanner(ctx, conf) + assert.NoError(t, err) + + // Store necessary values in the session + plan := &defaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} + + rp, err = agentOutputSessionKVs(ctx, rp) + assert.NoError(t, err) + + // Run the replanner + runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: rp}) + iterator := runner.Run(ctx, []adk.Message{schema.UserMessage("no input")}, + adk.WithSessionValues(map[string]any{ + PlanSessionKey: plan, + ExecutedStepSessionKey: "Execution result", + UserInputSessionKey: []adk.Message{schema.UserMessage("User input")}, + }), + ) + + // Get the event from the iterator + event, ok := iterator.Next() + assert.True(t, ok) + assert.Nil(t, event.Err) + + event, ok = iterator.Next() + assert.True(t, ok) + kvs := event.Output.CustomizedOutput.(map[string]any) + assert.Greater(t, len(kvs), 0) + + // Verify the updated plan was stored in the session + planValue, ok := kvs[PlanSessionKey] + assert.True(t, ok) + updatedPlan, ok := planValue.(*defaultPlan) + assert.True(t, ok) + assert.Equal(t, 2, len(updatedPlan.Steps)) + assert.Equal(t, "Updated Step 1", updatedPlan.Steps[0]) + assert.Equal(t, "Updated Step 2", updatedPlan.Steps[1]) + + // Verify the execute results were updated + executeResultsValue, ok := kvs[ExecutedStepsSessionKey] + assert.True(t, ok) + executeResults, ok := executeResultsValue.([]ExecutedStep) + assert.True(t, ok) + assert.Equal(t, 1, len(executeResults)) + assert.Equal(t, "Step 1", executeResults[0].Step) + assert.Equal(t, "Execution result", executeResults[0].Result) + + _, ok = iterator.Next() + assert.False(t, ok) +} + +// TestReplannerRunWithRespond tests the Replanner's ability to use the respond_tool +func TestReplannerRunWithRespond(t *testing.T) { + ctx := context.Background() + + // Create a mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create a mock tool calling chat model + mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) + + // Create plan and respond tools + planTool := &schema.ToolInfo{ + Name: "Plan", + Desc: "Plan tool", + } + + respondTool := &schema.ToolInfo{ + Name: "Respond", + Desc: "Respond tool", + } + + // Create a tool call response for the Respond tool + responseArgs := `{"response":"This is the final response to the user"}` + toolCall := schema.ToolCall{ + ID: "tool_call_id", + Type: "function", + Function: schema.FunctionCall{ + Name: respondTool.Name, + Arguments: responseArgs, + }, + } + + toolCallMsg := schema.AssistantMessage("", nil) + toolCallMsg.ToolCalls = []schema.ToolCall{toolCall} + + // Mock the Generate method + mockToolCallingModel.EXPECT().WithTools(gomock.Any()).Return(mockToolCallingModel, nil).Times(1) + mockToolCallingModel.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).Return(toolCallMsg, nil).Times(1) + + // Create the ReplannerConfig + conf := &ReplannerConfig{ + ChatModel: mockToolCallingModel, + PlanTool: planTool, + RespondTool: respondTool, + } + + // Create the replanner + rp, err := NewReplanner(ctx, conf) + assert.NoError(t, err) + + // Store necessary values in the session + plan := &defaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} + + // Run the replanner + runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: rp}) + iterator := runner.Run(ctx, []adk.Message{schema.UserMessage("no input")}, + adk.WithSessionValues(map[string]any{ + PlanSessionKey: plan, + ExecutedStepSessionKey: "Execution result", + UserInputSessionKey: []adk.Message{schema.UserMessage("User input")}, + }), + ) + + // Get the event from the iterator + event, ok := iterator.Next() + assert.True(t, ok) + assert.Nil(t, event.Err) + msg, _, err := adk.GetMessage(event) + assert.NoError(t, err) + assert.Equal(t, responseArgs, msg.Content) + + // Verify that an exit action was generated + event, ok = iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event.Action) + assert.True(t, event.Action.Exit) + + _, ok = iterator.Next() + assert.False(t, ok) +} + +// TestNewPlanExecuteAgent tests the New function +func TestNewPlanExecuteAgent(t *testing.T) { + ctx := context.Background() + + // Create a mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock agents + mockPlanner := mockAdk.NewMockAgent(ctrl) + mockExecutor := mockAdk.NewMockAgent(ctrl) + mockReplanner := mockAdk.NewMockAgent(ctrl) + + // Set up expectations for the mock agents + mockPlanner.EXPECT().Name(gomock.Any()).Return("Planner").AnyTimes() + mockPlanner.EXPECT().Description(gomock.Any()).Return("a planner agent").AnyTimes() + + mockExecutor.EXPECT().Name(gomock.Any()).Return("Executor").AnyTimes() + mockExecutor.EXPECT().Description(gomock.Any()).Return("an executor agent").AnyTimes() + + mockReplanner.EXPECT().Name(gomock.Any()).Return("Replanner").AnyTimes() + mockReplanner.EXPECT().Description(gomock.Any()).Return("a replanner agent").AnyTimes() + + conf := &Config{ + Planner: mockPlanner, + Executor: mockExecutor, + Replanner: mockReplanner, + } + + // Create the plan execute agent + agent, err := New(ctx, conf) + assert.NoError(t, err) + assert.NotNil(t, agent) +} + +func TestPlanExecuteAgentWithReplan(t *testing.T) { + ctx := context.Background() + + // Create a mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock agents + mockPlanner := mockAdk.NewMockAgent(ctrl) + mockExecutor := mockAdk.NewMockAgent(ctrl) + mockReplanner := mockAdk.NewMockAgent(ctrl) + + // Set up expectations for the mock agents + mockPlanner.EXPECT().Name(gomock.Any()).Return("Planner").AnyTimes() + mockPlanner.EXPECT().Description(gomock.Any()).Return("a planner agent").AnyTimes() + + mockExecutor.EXPECT().Name(gomock.Any()).Return("Executor").AnyTimes() + mockExecutor.EXPECT().Description(gomock.Any()).Return("an executor agent").AnyTimes() + + mockReplanner.EXPECT().Name(gomock.Any()).Return("Replanner").AnyTimes() + mockReplanner.EXPECT().Description(gomock.Any()).Return("a replanner agent").AnyTimes() + + // Create a plan + originalPlan := &defaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} + // Create an updated plan with fewer steps (after replanning) + updatedPlan := &defaultPlan{Steps: []string{"Updated Step 2", "Updated Step 3"}} + // Create execute result + originalExecuteResult := "Execution result for Step 1" + updatedExecuteResult := "Execution result for Updated Step 2" + + // Create user input + userInput := []adk.Message{schema.UserMessage("User task input")} + + finalResponse := &Response{Response: "Final response to user after executing all steps"} + + // Mock the planner Run method to set the original plan + mockPlanner.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { + iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() + + // Set the plan in the session + adk.AddSessionValue(ctx, PlanSessionKey, originalPlan) + adk.AddSessionValue(ctx, UserInputSessionKey, userInput) + + // Send a message event + planJSON, _ := sonic.MarshalString(originalPlan) + msg := schema.AssistantMessage(planJSON, nil) + event := adk.EventFromMessage(msg, nil, schema.Assistant, "") + generator.Send(event) + generator.Close() + + return iterator + }, + ).Times(1) + + // Mock the executor Run method to set the execute result + mockExecutor.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { + iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() + + plan, _ := adk.GetSessionValue(ctx, PlanSessionKey) + currentPlan := plan.(*defaultPlan) + var msg adk.Message + // Check if this is the first replanning (original plan has 3 steps) + if len(currentPlan.Steps) == 3 { + msg = schema.AssistantMessage(originalExecuteResult, nil) + adk.AddSessionValue(ctx, ExecutedStepSessionKey, originalExecuteResult) + } else { + msg = schema.AssistantMessage(updatedExecuteResult, nil) + adk.AddSessionValue(ctx, ExecutedStepSessionKey, updatedExecuteResult) + } + event := adk.EventFromMessage(msg, nil, schema.Assistant, "") + generator.Send(event) + generator.Close() + + return iterator + }, + ).Times(2) + + // Mock the replanner Run method to first update the plan, then respond to user + mockReplanner.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { + iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() + + // First call: Update the plan + // Get the current plan from the session + plan, _ := adk.GetSessionValue(ctx, PlanSessionKey) + currentPlan := plan.(*defaultPlan) + + // Check if this is the first replanning (original plan has 3 steps) + if len(currentPlan.Steps) == 3 { + // Send a message event with the updated plan + planJSON, _ := sonic.MarshalString(updatedPlan) + msg := schema.AssistantMessage(planJSON, nil) + event := adk.EventFromMessage(msg, nil, schema.Assistant, "") + generator.Send(event) + + // Set the updated plan & execute result in the session + adk.AddSessionValue(ctx, PlanSessionKey, updatedPlan) + adk.AddSessionValue(ctx, ExecutedStepsSessionKey, []ExecutedStep{{ + Step: currentPlan.Steps[0], + Result: originalExecuteResult, + }}) + } else { + // Second call: Respond to user + responseJSON, err := sonic.MarshalString(finalResponse) + assert.NoError(t, err) + msg := schema.AssistantMessage(responseJSON, nil) + event := adk.EventFromMessage(msg, nil, schema.Assistant, "") + generator.Send(event) + + // Send exit action + action := adk.NewExitAction() + generator.Send(&adk.AgentEvent{Action: action}) + } + + generator.Close() + return iterator + }, + ).Times(2) + + conf := &Config{ + Planner: mockPlanner, + Executor: mockExecutor, + Replanner: mockReplanner, + } + + // Create the plan execute agent + agent, err := New(ctx, conf) + assert.NoError(t, err) + assert.NotNil(t, agent) + + // Run the agent + runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent}) + iterator := runner.Run(ctx, userInput) + + // Collect all events + var events []*adk.AgentEvent + for { + event, ok := iterator.Next() + if !ok { + break + } + events = append(events, event) + } + + // Verify the events + assert.Greater(t, len(events), 0) + + for i, event := range events { + eventJSON, e := sonic.MarshalString(event) + assert.NoError(t, e) + t.Logf("event %d:\n%s", i, eventJSON) + } +} diff --git a/adk/prebuilt/planexecute/utils.go b/adk/prebuilt/planexecute/utils.go new file mode 100644 index 00000000..82350828 --- /dev/null +++ b/adk/prebuilt/planexecute/utils.go @@ -0,0 +1,58 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package planexecute + +import ( + "context" + + "github.com/cloudwego/eino/adk" +) + +type outputSessionKVsAgent struct { + adk.Agent +} + +func (o *outputSessionKVsAgent) Run(ctx context.Context, input *adk.AgentInput, + options ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { + + iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() + + iterator_ := o.Agent.Run(ctx, input, options...) + go func() { + defer generator.Close() + for { + event, ok := iterator_.Next() + if !ok { + break + } + generator.Send(event) + } + + kvs := adk.GetSessionValues(ctx) + + event := &adk.AgentEvent{ + Output: &adk.AgentOutput{CustomizedOutput: kvs}, + } + generator.Send(event) + }() + + return iterator +} + +func agentOutputSessionKVs(ctx context.Context, agent adk.Agent) (adk.Agent, error) { + return &outputSessionKVsAgent{Agent: agent}, nil +} diff --git a/adk/prebuilt/supervisor/supervisor.go b/adk/prebuilt/supervisor/supervisor.go new file mode 100644 index 00000000..5a6bbc8d --- /dev/null +++ b/adk/prebuilt/supervisor/supervisor.go @@ -0,0 +1,50 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package supervisor + +import ( + "context" + + "github.com/cloudwego/eino/adk" +) + +type Config struct { + // Supervisor specifies the agent that will act as the supervisor, coordinating and managing the sub-agents. + Supervisor adk.Agent + + // SubAgents specifies the list of agents that will be supervised and coordinated by the supervisor agent. + SubAgents []adk.Agent +} + +// New creates a supervisor-based multi-agent system with the given configuration. +// +// In the supervisor pattern, a designated supervisor agent coordinates multiple sub-agents. +// The supervisor can delegate tasks to sub-agents and receive their responses, while +// sub-agents can only communicate with the supervisor (not with each other directly). +// This hierarchical structure enables complex problem-solving through coordinated agent interactions. +func New(ctx context.Context, conf *Config) (adk.Agent, error) { + subAgents := make([]adk.Agent, 0, len(conf.SubAgents)) + supervisorName := conf.Supervisor.Name(ctx) + for _, subAgent := range conf.SubAgents { + subAgents = append(subAgents, adk.AgentWithDeterministicTransferTo(ctx, &adk.DeterministicTransferConfig{ + Agent: subAgent, + ToAgentNames: []string{supervisorName}, + })) + } + + return adk.SetSubAgents(ctx, conf.Supervisor, subAgents) +} diff --git a/adk/prebuilt/supervisor/supervisor_test.go b/adk/prebuilt/supervisor/supervisor_test.go new file mode 100644 index 00000000..07a8f332 --- /dev/null +++ b/adk/prebuilt/supervisor/supervisor_test.go @@ -0,0 +1,169 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package supervisor + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/cloudwego/eino/adk" + mockAdk "github.com/cloudwego/eino/internal/mock/adk" + "github.com/cloudwego/eino/schema" +) + +// TestNewSupervisor tests the New function +func TestNewSupervisor(t *testing.T) { + ctx := context.Background() + + // Create a mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock agents + supervisorAgent := mockAdk.NewMockAgent(ctrl) + subAgent1 := mockAdk.NewMockAgent(ctrl) + subAgent2 := mockAdk.NewMockAgent(ctrl) + + supervisorAgent.EXPECT().Name(gomock.Any()).Return("SupervisorAgent").AnyTimes() + subAgent1.EXPECT().Name(gomock.Any()).Return("SubAgent1").AnyTimes() + subAgent2.EXPECT().Name(gomock.Any()).Return("SubAgent2").AnyTimes() + + aMsg, tMsg := adk.GenTransferMessages(ctx, "SubAgent1") + i, g := adk.NewAsyncIteratorPair[*adk.AgentEvent]() + g.Send(adk.EventFromMessage(aMsg, nil, schema.Assistant, "")) + event := adk.EventFromMessage(tMsg, nil, schema.Tool, tMsg.ToolName) + event.Action = &adk.AgentAction{TransferToAgent: &adk.TransferToAgentAction{DestAgentName: "SubAgent1"}} + g.Send(event) + g.Close() + supervisorAgent.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(i).Times(1) + + i, g = adk.NewAsyncIteratorPair[*adk.AgentEvent]() + subAgent1Msg := schema.AssistantMessage("SubAgent1", nil) + g.Send(adk.EventFromMessage(subAgent1Msg, nil, schema.Assistant, "")) + g.Close() + subAgent1.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(i).Times(1) + + aMsg, tMsg = adk.GenTransferMessages(ctx, "SubAgent2 message") + i, g = adk.NewAsyncIteratorPair[*adk.AgentEvent]() + g.Send(adk.EventFromMessage(aMsg, nil, schema.Assistant, "")) + event = adk.EventFromMessage(tMsg, nil, schema.Tool, tMsg.ToolName) + event.Action = &adk.AgentAction{TransferToAgent: &adk.TransferToAgentAction{DestAgentName: "SubAgent2"}} + g.Send(event) + g.Close() + supervisorAgent.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(i).Times(1) + + i, g = adk.NewAsyncIteratorPair[*adk.AgentEvent]() + subAgent2Msg := schema.AssistantMessage("SubAgent2 message", nil) + g.Send(adk.EventFromMessage(subAgent2Msg, nil, schema.Assistant, "")) + g.Close() + subAgent2.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(i).Times(1) + + i, g = adk.NewAsyncIteratorPair[*adk.AgentEvent]() + finishMsg := schema.AssistantMessage("finish", nil) + g.Send(adk.EventFromMessage(finishMsg, nil, schema.Assistant, "")) + g.Close() + supervisorAgent.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(i).Times(1) + + conf := &Config{ + Supervisor: supervisorAgent, + SubAgents: []adk.Agent{subAgent1, subAgent2}, + } + + multiAgent, err := New(ctx, conf) + assert.NoError(t, err) + assert.NotNil(t, multiAgent) + assert.Equal(t, "SupervisorAgent", multiAgent.Name(ctx)) + + runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: multiAgent}) + aIter := runner.Run(ctx, []adk.Message{schema.UserMessage("test")}) + + // transfer to agent1 + event, ok := aIter.Next() + assert.True(t, ok) + assert.Equal(t, "SupervisorAgent", event.AgentName) + assert.Equal(t, schema.Assistant, event.Output.MessageOutput.Role) + assert.NotEqual(t, 0, len(event.Output.MessageOutput.Message.ToolCalls)) + + event, ok = aIter.Next() + assert.True(t, ok) + assert.Equal(t, "SupervisorAgent", event.AgentName) + assert.Equal(t, schema.Tool, event.Output.MessageOutput.Role) + assert.Equal(t, "SubAgent1", event.Action.TransferToAgent.DestAgentName) + + // agent1's output + event, ok = aIter.Next() + assert.True(t, ok) + assert.Equal(t, "SubAgent1", event.AgentName) + assert.Equal(t, schema.Assistant, event.Output.MessageOutput.Role) + assert.Equal(t, subAgent1Msg.Content, event.Output.MessageOutput.Message.Content) + + // transfer back to supervisor + event, ok = aIter.Next() + assert.True(t, ok) + assert.Equal(t, "SubAgent1", event.AgentName) + assert.Equal(t, schema.Assistant, event.Output.MessageOutput.Role) + assert.NotEqual(t, 0, len(event.Output.MessageOutput.Message.ToolCalls)) + + event, ok = aIter.Next() + assert.True(t, ok) + assert.Equal(t, "SubAgent1", event.AgentName) + assert.Equal(t, schema.Tool, event.Output.MessageOutput.Role) + assert.Equal(t, "SupervisorAgent", event.Action.TransferToAgent.DestAgentName) + + // transfer to agent2 + event, ok = aIter.Next() + assert.True(t, ok) + assert.Equal(t, "SupervisorAgent", event.AgentName) + assert.Equal(t, schema.Assistant, event.Output.MessageOutput.Role) + assert.NotEqual(t, 0, len(event.Output.MessageOutput.Message.ToolCalls)) + + event, ok = aIter.Next() + assert.True(t, ok) + assert.Equal(t, "SupervisorAgent", event.AgentName) + assert.Equal(t, schema.Tool, event.Output.MessageOutput.Role) + assert.Equal(t, "SubAgent2", event.Action.TransferToAgent.DestAgentName) + + // agent1's output + event, ok = aIter.Next() + assert.True(t, ok) + assert.Equal(t, "SubAgent2", event.AgentName) + assert.Equal(t, schema.Assistant, event.Output.MessageOutput.Role) + assert.Equal(t, subAgent2Msg.Content, event.Output.MessageOutput.Message.Content) + + // transfer back to supervisor + event, ok = aIter.Next() + assert.True(t, ok) + assert.Equal(t, "SubAgent2", event.AgentName) + assert.Equal(t, schema.Assistant, event.Output.MessageOutput.Role) + assert.NotEqual(t, 0, len(event.Output.MessageOutput.Message.ToolCalls)) + + event, ok = aIter.Next() + assert.True(t, ok) + assert.Equal(t, "SubAgent2", event.AgentName) + assert.Equal(t, schema.Tool, event.Output.MessageOutput.Role) + assert.Equal(t, "SupervisorAgent", event.Action.TransferToAgent.DestAgentName) + + // finish + event, ok = aIter.Next() + assert.True(t, ok) + assert.Equal(t, "SupervisorAgent", event.AgentName) + assert.Equal(t, schema.Assistant, event.Output.MessageOutput.Role) + assert.Equal(t, finishMsg.Content, event.Output.MessageOutput.Message.Content) +} diff --git a/adk/react.go b/adk/react.go new file mode 100644 index 00000000..be972198 --- /dev/null +++ b/adk/react.go @@ -0,0 +1,256 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "io" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +var ErrExceedMaxIterations = errors.New("exceeds max iterations") + +type State struct { + Messages []Message + + ReturnDirectlyToolCallID string + + ToolGenActions map[string]*AgentAction + + AgentName string + + AgentToolInterruptData map[string] /*tool call id*/ *agentToolInterruptInfo + + RemainingIterations int +} + +type agentToolInterruptInfo struct { + LastEvent *AgentEvent + Data []byte +} + +func SendToolGenAction(ctx context.Context, toolName string, action *AgentAction) error { + return compose.ProcessState(ctx, func(ctx context.Context, st *State) error { + st.ToolGenActions[toolName] = action + + return nil + }) +} + +func popToolGenAction(ctx context.Context, toolName string) *AgentAction { + var action *AgentAction + err := compose.ProcessState(ctx, func(ctx context.Context, st *State) error { + action = st.ToolGenActions[toolName] + if action != nil { + delete(st.ToolGenActions, toolName) + } + + return nil + }) + + if err != nil { + panic("impossible") + } + + return action +} + +type reactConfig struct { + model model.ToolCallingChatModel + + toolsConfig *compose.ToolsNodeConfig + + toolsReturnDirectly map[string]bool + + agentName string + + maxIterations int +} + +func genToolInfos(ctx context.Context, config *compose.ToolsNodeConfig) ([]*schema.ToolInfo, error) { + toolInfos := make([]*schema.ToolInfo, 0, len(config.Tools)) + for _, t := range config.Tools { + tl, err := t.Info(ctx) + if err != nil { + return nil, err + } + + toolInfos = append(toolInfos, tl) + } + + return toolInfos, nil +} + +type reactGraph = *compose.Graph[[]Message, Message] +type sToolNodeOutput = *schema.StreamReader[[]Message] +type sGraphOutput = MessageStream + +func getReturnDirectlyToolCallID(ctx context.Context) string { + var toolCallID string + handler := func(_ context.Context, st *State) error { + toolCallID = st.ReturnDirectlyToolCallID + return nil + } + + _ = compose.ProcessState(ctx, handler) + + return toolCallID +} + +func newReact(ctx context.Context, config *reactConfig) (reactGraph, error) { + genState := func(ctx context.Context) *State { + return &State{ + ToolGenActions: map[string]*AgentAction{}, + AgentName: config.agentName, + AgentToolInterruptData: make(map[string]*agentToolInterruptInfo), + RemainingIterations: func() int { + if config.maxIterations <= 0 { + return 20 + } + return config.maxIterations + }(), + } + } + + const ( + chatModel_ = "ChatModel" + toolNode_ = "ToolNode" + ) + + g := compose.NewGraph[[]Message, Message](compose.WithGenLocalState(genState)) + + toolsInfo, err := genToolInfos(ctx, config.toolsConfig) + if err != nil { + return nil, err + } + + chatModel, err := config.model.WithTools(toolsInfo) + if err != nil { + return nil, err + } + + toolsNode, err := compose.NewToolNode(ctx, config.toolsConfig) + if err != nil { + return nil, err + } + + modelPreHandle := func(ctx context.Context, input []Message, st *State) ([]Message, error) { + if st.RemainingIterations <= 0 { + return nil, ErrExceedMaxIterations + } + st.RemainingIterations-- + + st.Messages = append(st.Messages, input...) + return st.Messages, nil + } + _ = g.AddChatModelNode(chatModel_, chatModel, + compose.WithStatePreHandler(modelPreHandle), compose.WithNodeName(chatModel_)) + + toolPreHandle := func(ctx context.Context, input Message, st *State) (Message, error) { + if input != nil { + // isn't resume + st.Messages = append(st.Messages, input) + } + + input = st.Messages[len(st.Messages)-1] + if len(config.toolsReturnDirectly) > 0 { + for i := range input.ToolCalls { + toolName := input.ToolCalls[i].Function.Name + if config.toolsReturnDirectly[toolName] { + st.ReturnDirectlyToolCallID = input.ToolCalls[i].ID + } + } + } + + return input, nil + } + + _ = g.AddToolsNode(toolNode_, toolsNode, + compose.WithStatePreHandler(toolPreHandle), compose.WithNodeName(toolNode_)) + + _ = g.AddEdge(compose.START, chatModel_) + + toolCallCheck := func(ctx context.Context, sMsg MessageStream) (string, error) { + defer sMsg.Close() + for { + chunk, err_ := sMsg.Recv() + if err_ != nil { + if err_ == io.EOF { + return compose.END, nil + } + + return "", err_ + } + + if len(chunk.ToolCalls) > 0 { + return toolNode_, nil + } + } + } + branch := compose.NewStreamGraphBranch(toolCallCheck, map[string]bool{compose.END: true, toolNode_: true}) + _ = g.AddBranch(chatModel_, branch) + + if len(config.toolsReturnDirectly) == 0 { + _ = g.AddEdge(toolNode_, chatModel_) + } else { + const ( + toolNodeToEndConverter = "ToolNodeToEndConverter" + ) + + cvt := func(ctx context.Context, sToolCallMessages sToolNodeOutput) (sGraphOutput, error) { + id := getReturnDirectlyToolCallID(ctx) + + return schema.StreamReaderWithConvert(sToolCallMessages, + func(in []Message) (Message, error) { + + for _, chunk := range in { + if chunk.ToolCallID == id { + return chunk, nil + } + } + + return nil, schema.ErrNoValue + }), nil + } + + _ = g.AddLambdaNode(toolNodeToEndConverter, compose.TransformableLambda(cvt), + compose.WithNodeName(toolNodeToEndConverter)) + _ = g.AddEdge(toolNodeToEndConverter, compose.END) + + checkReturnDirect := func(ctx context.Context, + sToolCallMessages sToolNodeOutput) (string, error) { + + id := getReturnDirectlyToolCallID(ctx) + + if len(id) != 0 { + return toolNodeToEndConverter, nil + } + + return chatModel_, nil + } + + branch = compose.NewStreamGraphBranch(checkReturnDirect, + map[string]bool{toolNodeToEndConverter: true, chatModel_: true}) + _ = g.AddBranch(toolNode_, branch) + } + + return g, nil +} diff --git a/adk/react_test.go b/adk/react_test.go new file mode 100644 index 00000000..169496f5 --- /dev/null +++ b/adk/react_test.go @@ -0,0 +1,584 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "fmt" + "io" + "math/rand" + "testing" + + "github.com/bytedance/sonic" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + mockModel "github.com/cloudwego/eino/internal/mock/components/model" + "github.com/cloudwego/eino/schema" +) + +// TestReact tests the newReact function with different scenarios +func TestReact(t *testing.T) { + // Basic test for newReact function + t.Run("Invoke", func(t *testing.T) { + ctx := context.Background() + + // Create a fake tool for testing + fakeTool := &fakeToolForTest{ + tarCount: 3, + } + + info, err := fakeTool.Info(ctx) + assert.NoError(t, err) + + // Create a mock chat model + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + // Set up expectations for the mock model + times := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []Message, opts ...model.Option) (Message, error) { + times++ + if times <= 2 { + return schema.AssistantMessage("hello test", + []schema.ToolCall{ + { + ID: randStrForTest(), + Function: schema.FunctionCall{ + Name: info.Name, + Arguments: fmt.Sprintf(`{"name": "%s", "hh": "123"}`, randStrForTest()), + }, + }, + }), + nil + } + + return schema.AssistantMessage("bye", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + // Create a reactConfig + config := &reactConfig{ + model: cm, + toolsConfig: &compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{fakeTool}, + }, + toolsReturnDirectly: map[string]bool{}, + } + + graph, err := newReact(ctx, config) + assert.NoError(t, err) + assert.NotNil(t, graph) + + compiled, err := graph.Compile(ctx) + assert.NoError(t, err) + assert.NotNil(t, compiled) + + // Test with a user message + result, err := compiled.Invoke(ctx, []Message{ + { + Role: schema.User, + Content: "Use the test tool to say hello", + }, + }) + assert.NoError(t, err) + assert.NotNil(t, result) + }) + + // Test with toolsReturnDirectly + t.Run("ToolsReturnDirectly", func(t *testing.T) { + ctx := context.Background() + + // Create a fake tool for testing + fakeTool := &fakeToolForTest{ + tarCount: 3, + } + + info, err := fakeTool.Info(ctx) + assert.NoError(t, err) + + // Create a mock chat model + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + // Set up expectations for the mock model + times := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []Message, opts ...model.Option) (Message, error) { + times++ + if times <= 2 { + return schema.AssistantMessage("hello test", + []schema.ToolCall{ + { + ID: randStrForTest(), + Function: schema.FunctionCall{ + Name: info.Name, + Arguments: fmt.Sprintf(`{"name": "%s", "hh": "123"}`, randStrForTest()), + }, + }, + }), + nil + } + + return schema.AssistantMessage("bye", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + // Create a reactConfig with toolsReturnDirectly + config := &reactConfig{ + model: cm, + toolsConfig: &compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{fakeTool}, + }, + toolsReturnDirectly: map[string]bool{info.Name: true}, + } + + graph, err := newReact(ctx, config) + assert.NoError(t, err) + assert.NotNil(t, graph) + + compiled, err := graph.Compile(ctx) + assert.NoError(t, err) + assert.NotNil(t, compiled) + + // Test with a user message when tool returns directly + result, err := compiled.Invoke(ctx, []Message{ + { + Role: schema.User, + Content: "Use the test tool to say hello", + }, + }) + assert.NoError(t, err) + assert.NotNil(t, result) + + assert.Equal(t, result.Role, schema.Tool) + }) + + // Test streaming functionality + t.Run("Stream", func(t *testing.T) { + ctx := context.Background() + + // Create a fake tool for testing + fakeTool := &fakeToolForTest{ + tarCount: 3, + } + + fakeStreamTool := &fakeStreamToolForTest{ + tarCount: 3, + } + + // Create a mock chat model + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + // Set up expectations for the mock model + times := 0 + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []Message, opts ...model.Option) ( + MessageStream, error) { + sr, sw := schema.Pipe[Message](1) + defer sw.Close() + + info, _ := fakeTool.Info(ctx) + streamInfo, _ := fakeStreamTool.Info(ctx) + + times++ + if times <= 1 { + sw.Send(schema.AssistantMessage("hello test", + []schema.ToolCall{ + { + ID: randStrForTest(), + Function: schema.FunctionCall{ + Name: info.Name, + Arguments: fmt.Sprintf(`{"name": "%s", "hh": "tool"}`, randStrForTest()), + }, + }, + }), + nil) + return sr, nil + } else if times == 2 { + sw.Send(schema.AssistantMessage("hello stream", + []schema.ToolCall{ + { + ID: randStrForTest(), + Function: schema.FunctionCall{ + Name: streamInfo.Name, + Arguments: fmt.Sprintf(`{"name": "%s", "hh": "stream tool"}`, randStrForTest()), + }, + }, + }), + nil) + return sr, nil + } + + sw.Send(schema.AssistantMessage("bye", nil), nil) + return sr, nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + // Create a reactConfig + config := &reactConfig{ + model: cm, + toolsConfig: &compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{fakeTool, fakeStreamTool}, + }, + toolsReturnDirectly: map[string]bool{}, + } + + graph, err := newReact(ctx, config) + assert.NoError(t, err) + assert.NotNil(t, graph) + + compiled, err := graph.Compile(ctx) + assert.NoError(t, err) + assert.NotNil(t, compiled) + + // Test streaming with a user message + outStream, err := compiled.Stream(ctx, []Message{ + { + Role: schema.User, + Content: "Use the test tool to say hello", + }, + }) + assert.NoError(t, err) + assert.NotNil(t, outStream) + + defer outStream.Close() + + msgs := make([]Message, 0) + for { + msg, err_ := outStream.Recv() + if err_ != nil { + if errors.Is(err_, io.EOF) { + break + } + t.Fatal(err_) + } + + msgs = append(msgs, msg) + } + + assert.NotEmpty(t, msgs) + }) + + // Test streaming with toolsReturnDirectly + t.Run("StreamWithToolsReturnDirectly", func(t *testing.T) { + ctx := context.Background() + + // Create a fake tool for testing + fakeTool := &fakeToolForTest{ + tarCount: 3, + } + + fakeStreamTool := &fakeStreamToolForTest{ + tarCount: 3, + } + + // Create a mock chat model + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + // Set up expectations for the mock model + times := 0 + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []Message, opts ...model.Option) ( + MessageStream, error) { + sr, sw := schema.Pipe[Message](1) + defer sw.Close() + + info, _ := fakeTool.Info(ctx) + streamInfo, _ := fakeStreamTool.Info(ctx) + + times++ + if times <= 1 { + sw.Send(schema.AssistantMessage("hello test", + []schema.ToolCall{ + { + ID: randStrForTest(), + Function: schema.FunctionCall{ + Name: info.Name, + Arguments: fmt.Sprintf(`{"name": "%s", "hh": "tool"}`, randStrForTest()), + }, + }, + }), + nil) + return sr, nil + } else if times == 2 { + sw.Send(schema.AssistantMessage("hello stream", + []schema.ToolCall{ + { + ID: randStrForTest(), + Function: schema.FunctionCall{ + Name: streamInfo.Name, + Arguments: fmt.Sprintf(`{"name": "%s", "hh": "stream tool"}`, randStrForTest()), + }, + }, + }), + nil) + return sr, nil + } + + sw.Send(schema.AssistantMessage("bye", nil), nil) + return sr, nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + streamInfo, err := fakeStreamTool.Info(ctx) + assert.NoError(t, err) + + // Create a reactConfig with toolsReturnDirectly + config := &reactConfig{ + model: cm, + toolsConfig: &compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{fakeTool, fakeStreamTool}, + }, + toolsReturnDirectly: map[string]bool{streamInfo.Name: true}, + } + + graph, err := newReact(ctx, config) + assert.NoError(t, err) + assert.NotNil(t, graph) + + compiled, err := graph.Compile(ctx) + assert.NoError(t, err) + assert.NotNil(t, compiled) + + // Reset times counter + times = 0 + + // Test streaming with a user message when tool returns directly + outStream, err := compiled.Stream(ctx, []Message{ + { + Role: schema.User, + Content: "Use the test tool to say hello", + }, + }) + assert.NoError(t, err) + assert.NotNil(t, outStream) + + msgs := make([]Message, 0) + for { + msg, err_ := outStream.Recv() + if err_ != nil { + if errors.Is(err_, io.EOF) { + break + } + t.Fatal(err) + } + + assert.Equal(t, msg.Role, schema.Tool) + + msgs = append(msgs, msg) + } + + outStream.Close() + + assert.NotEmpty(t, msgs) + }) + + t.Run("MaxIterations", func(t *testing.T) { + ctx := context.Background() + + // Create a fake tool for testing + fakeTool := &fakeToolForTest{ + tarCount: 3, + } + + info, err := fakeTool.Info(ctx) + assert.NoError(t, err) + + // Create a mock chat model + ctrl := gomock.NewController(t) + cm := mockModel.NewMockToolCallingChatModel(ctrl) + + // Set up expectations for the mock model + times := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []Message, opts ...model.Option) (Message, error) { + times++ + if times <= 5 { + return schema.AssistantMessage("hello test", + []schema.ToolCall{ + { + ID: randStrForTest(), + Function: schema.FunctionCall{ + Name: info.Name, + Arguments: fmt.Sprintf(`{"name": "%s", "hh": "123"}`, randStrForTest()), + }, + }, + }), + nil + } + + return schema.AssistantMessage("bye", nil), nil + }).AnyTimes() + cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes() + + // don't exceed max iterations + config := &reactConfig{ + model: cm, + toolsConfig: &compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{fakeTool}, + }, + toolsReturnDirectly: map[string]bool{}, + maxIterations: 6, + } + + graph, err := newReact(ctx, config) + assert.NoError(t, err) + assert.NotNil(t, graph) + + compiled, err := graph.Compile(ctx) + assert.NoError(t, err) + assert.NotNil(t, compiled) + + // Test with a user message + result, err := compiled.Invoke(ctx, []Message{ + { + Role: schema.User, + Content: "Use the test tool to say hello", + }, + }) + assert.NoError(t, err) + assert.Equal(t, result.Content, "bye") + + // reset chat model times counter + times = 0 + // exceed max iterations + config = &reactConfig{ + model: cm, + toolsConfig: &compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{fakeTool}, + }, + toolsReturnDirectly: map[string]bool{}, + maxIterations: 5, + } + + graph, err = newReact(ctx, config) + assert.NoError(t, err) + assert.NotNil(t, graph) + + compiled, err = graph.Compile(ctx) + assert.NoError(t, err) + assert.NotNil(t, compiled) + + // Test with a user message + result, err = compiled.Invoke(ctx, []Message{ + { + Role: schema.User, + Content: "Use the test tool to say hello", + }, + }) + assert.Error(t, err) + t.Logf("actual error: %v", err.Error()) + assert.ErrorIs(t, err, ErrExceedMaxIterations) + + assert.Contains(t, err.Error(), ErrExceedMaxIterations.Error()) + }) +} + +// Helper types and functions for testing + +type fakeStreamToolForTest struct { + tarCount int + curCount int +} + +func (t *fakeStreamToolForTest) StreamableRun(_ context.Context, argumentsInJSON string, _ ...tool.Option) ( + *schema.StreamReader[string], error) { + p := &fakeToolInputForTest{} + err := sonic.UnmarshalString(argumentsInJSON, p) + if err != nil { + return nil, err + } + + if t.curCount >= t.tarCount { + s := schema.StreamReaderFromArray([]string{`{"say": "bye"}`}) + return s, nil + } + t.curCount++ + s := schema.StreamReaderFromArray([]string{fmt.Sprintf(`{"say": "hello %v"}`, p.Name)}) + return s, nil +} + +type fakeToolForTest struct { + tarCount int + curCount int +} + +func (t *fakeToolForTest) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: "test_tool", + Desc: "test tool for unit testing", + ParamsOneOf: schema.NewParamsOneOfByParams( + map[string]*schema.ParameterInfo{ + "name": { + Desc: "user name for testing", + Required: true, + Type: schema.String, + }, + }), + }, nil +} + +func (t *fakeStreamToolForTest) Info(_ context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: "test_stream_tool", + Desc: "test stream tool for unit testing", + ParamsOneOf: schema.NewParamsOneOfByParams( + map[string]*schema.ParameterInfo{ + "name": { + Desc: "user name for testing", + Required: true, + Type: schema.String, + }, + }), + }, nil +} + +func (t *fakeToolForTest) InvokableRun(_ context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) { + p := &fakeToolInputForTest{} + err := sonic.UnmarshalString(argumentsInJSON, p) + if err != nil { + return "", err + } + + if t.curCount >= t.tarCount { + return `{"say": "bye"}`, nil + } + + t.curCount++ + return fmt.Sprintf(`{"say": "hello %v"}`, p.Name), nil +} + +type fakeToolInputForTest struct { + Name string `json:"name"` +} + +func randStrForTest() string { + seeds := []rune("test seed") + b := make([]rune, 8) + for i := range b { + b[i] = seeds[rand.Intn(len(seeds))] + } + return string(b) +} diff --git a/adk/runctx.go b/adk/runctx.go new file mode 100644 index 00000000..a16a9398 --- /dev/null +++ b/adk/runctx.go @@ -0,0 +1,276 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "bytes" + "context" + "encoding/gob" + "fmt" + "sync" + + "github.com/cloudwego/eino/schema" +) + +type runSession struct { + Events []*agentEventWrapper + Values map[string]any + + interruptRunCtxs []*runContext // won't consider concurrency now + + mtx sync.Mutex +} + +type agentEventWrapper struct { + *AgentEvent + mu sync.Mutex + concatenatedMessage Message +} + +type otherAgentEventWrapperForEncode agentEventWrapper + +func (a *agentEventWrapper) GobEncode() ([]byte, error) { + if a.concatenatedMessage != nil && a.Output != nil && a.Output.MessageOutput != nil && a.Output.MessageOutput.IsStreaming { + a.Output.MessageOutput.MessageStream = schema.StreamReaderFromArray([]Message{a.concatenatedMessage}) + } + + buf := &bytes.Buffer{} + err := gob.NewEncoder(buf).Encode((*otherAgentEventWrapperForEncode)(a)) + if err != nil { + return nil, fmt.Errorf("failed to gob encode agent event wrapper: %w", err) + } + return buf.Bytes(), nil +} + +func (a *agentEventWrapper) GobDecode(b []byte) error { + return gob.NewDecoder(bytes.NewReader(b)).Decode((*otherAgentEventWrapperForEncode)(a)) +} + +func newRunSession() *runSession { + return &runSession{ + Values: make(map[string]any), + } +} + +func getInterruptRunCtxs(ctx context.Context) []*runContext { + session := getSession(ctx) + if session == nil { + return nil + } + return session.getInterruptRunCtxs() +} + +func appendInterruptRunCtx(ctx context.Context, interruptRunCtx *runContext) { + session := getSession(ctx) + if session == nil { + return + } + session.appendInterruptRunCtx(interruptRunCtx) +} + +func replaceInterruptRunCtx(ctx context.Context, interruptRunCtx *runContext) { + session := getSession(ctx) + if session == nil { + return + } + session.replaceInterruptRunCtx(interruptRunCtx) +} + +func GetSessionValues(ctx context.Context) map[string]any { + session := getSession(ctx) + if session == nil { + return map[string]any{} + } + + return session.getValues() +} + +func AddSessionValue(ctx context.Context, key string, value any) { + session := getSession(ctx) + if session == nil { + return + } + + session.addValue(key, value) +} + +func AddSessionValues(ctx context.Context, kvs map[string]any) { + session := getSession(ctx) + if session == nil { + return + } + + session.addValues(kvs) +} + +func GetSessionValue(ctx context.Context, key string) (any, bool) { + session := getSession(ctx) + if session == nil { + return nil, false + } + + return session.getValue(key) +} + +func (rs *runSession) addEvent(event *AgentEvent) { + rs.mtx.Lock() + rs.Events = append(rs.Events, &agentEventWrapper{ + AgentEvent: event, + }) + rs.mtx.Unlock() +} + +func (rs *runSession) getEvents() []*agentEventWrapper { + rs.mtx.Lock() + events := rs.Events + rs.mtx.Unlock() + + return events +} + +func (rs *runSession) getInterruptRunCtxs() []*runContext { + rs.mtx.Lock() + defer rs.mtx.Unlock() + return rs.interruptRunCtxs +} + +func (rs *runSession) appendInterruptRunCtx(runCtx *runContext) { + rs.mtx.Lock() + rs.interruptRunCtxs = append(rs.interruptRunCtxs, runCtx) + rs.mtx.Unlock() +} + +func (rs *runSession) replaceInterruptRunCtx(interruptRunCtx *runContext) { + // remove runctx whose path belongs to the new run ctx, and append the new run ctx + rs.mtx.Lock() + for i := 0; i < len(rs.interruptRunCtxs); i++ { + rc := rs.interruptRunCtxs[i] + if belongToRunPath(interruptRunCtx.RunPath, rc.RunPath) { + rs.interruptRunCtxs = append(rs.interruptRunCtxs[:i], rs.interruptRunCtxs[i+1:]...) + i-- + } + } + rs.interruptRunCtxs = append(rs.interruptRunCtxs, interruptRunCtx) + rs.mtx.Unlock() +} + +func (rs *runSession) getValues() map[string]any { + rs.mtx.Lock() + values := make(map[string]any, len(rs.Values)) + for k, v := range rs.Values { + values[k] = v + } + rs.mtx.Unlock() + + return values +} + +func (rs *runSession) addValue(key string, value any) { + rs.mtx.Lock() + rs.Values[key] = value + rs.mtx.Unlock() +} + +func (rs *runSession) addValues(kvs map[string]any) { + rs.mtx.Lock() + for k, v := range kvs { + rs.Values[k] = v + } + rs.mtx.Unlock() +} + +func (rs *runSession) getValue(key string) (any, bool) { + rs.mtx.Lock() + value, ok := rs.Values[key] + rs.mtx.Unlock() + + return value, ok +} + +type runContext struct { + RootInput *AgentInput + RunPath []RunStep + + Session *runSession +} + +func (rc *runContext) isRoot() bool { + return len(rc.RunPath) == 1 +} + +func (rc *runContext) deepCopy() *runContext { + copied := &runContext{ + RootInput: rc.RootInput, + RunPath: make([]RunStep, len(rc.RunPath)), + Session: rc.Session, + } + + copy(copied.RunPath, rc.RunPath) + + return copied +} + +type runCtxKey struct{} + +func getRunCtx(ctx context.Context) *runContext { + runCtx, ok := ctx.Value(runCtxKey{}).(*runContext) + if !ok { + return nil + } + return runCtx +} + +func setRunCtx(ctx context.Context, runCtx *runContext) context.Context { + return context.WithValue(ctx, runCtxKey{}, runCtx) +} + +func initRunCtx(ctx context.Context, agentName string, input *AgentInput) (context.Context, *runContext) { + runCtx := getRunCtx(ctx) + if runCtx != nil { + runCtx = runCtx.deepCopy() + } else { + runCtx = &runContext{Session: newRunSession()} + } + + runCtx.RunPath = append(runCtx.RunPath, RunStep{agentName}) + if runCtx.isRoot() { + runCtx.RootInput = input + } + + return setRunCtx(ctx, runCtx), runCtx +} + +// ClearRunCtx clears the run context of the multi-agents. This is particularly useful +// when a customized agent with a multi-agents inside it is set as a subagent of another +// multi-agents. In such cases, it's not expected to pass the outside run context to the +// inside multi-agents, so this function helps isolate the contexts properly. +func ClearRunCtx(ctx context.Context) context.Context { + return context.WithValue(ctx, runCtxKey{}, nil) +} + +func ctxWithNewRunCtx(ctx context.Context) context.Context { + return setRunCtx(ctx, &runContext{Session: newRunSession()}) +} + +func getSession(ctx context.Context) *runSession { + runCtx := getRunCtx(ctx) + if runCtx != nil { + return runCtx.Session + } + + return nil +} diff --git a/adk/runner.go b/adk/runner.go new file mode 100644 index 00000000..667f6122 --- /dev/null +++ b/adk/runner.go @@ -0,0 +1,147 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "fmt" + "runtime/debug" + + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/internal/safe" + "github.com/cloudwego/eino/schema" +) + +type Runner struct { + a Agent + enableStreaming bool + store compose.CheckPointStore +} + +type RunnerConfig struct { + Agent Agent + EnableStreaming bool + + CheckPointStore compose.CheckPointStore +} + +func NewRunner(_ context.Context, conf RunnerConfig) *Runner { + return &Runner{ + enableStreaming: conf.EnableStreaming, + a: conf.Agent, + store: conf.CheckPointStore, + } +} + +func (r *Runner) Run(ctx context.Context, messages []Message, + opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + o := getCommonOptions(nil, opts...) + + fa := toFlowAgent(ctx, r.a) + + input := &AgentInput{ + Messages: messages, + EnableStreaming: r.enableStreaming, + } + + ctx = ctxWithNewRunCtx(ctx) + + AddSessionValues(ctx, o.sessionValues) + + iter := fa.Run(ctx, input, opts...) + if r.store == nil { + return iter + } + + niter, gen := NewAsyncIteratorPair[*AgentEvent]() + + go r.handleIter(ctx, iter, gen, o.checkPointID) + return niter +} + +func getInterruptRunCtx(ctx context.Context) *runContext { + cs := getInterruptRunCtxs(ctx) + if len(cs) == 0 { + return nil + } + return cs[0] // assume that concurrency isn't existed, so only one run ctx is in ctx +} + +func (r *Runner) Query(ctx context.Context, + query string, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + + return r.Run(ctx, []Message{schema.UserMessage(query)}, opts...) +} + +func (r *Runner) Resume(ctx context.Context, checkPointID string, opts ...AgentRunOption) (*AsyncIterator[*AgentEvent], error) { + if r.store == nil { + return nil, fmt.Errorf("failed to resume: store is nil") + } + + runCtx, info, existed, err := getCheckPoint(ctx, r.store, checkPointID) + if err != nil { + return nil, fmt.Errorf("failed to get checkpoint: %w", err) + } + if !existed { + return nil, fmt.Errorf("checkpoint[%s] is not existed", checkPointID) + } + + ctx = setRunCtx(ctx, runCtx) + aIter := toFlowAgent(ctx, r.a).Resume(ctx, info, opts...) + if r.store == nil { + return aIter, nil + } + + niter, gen := NewAsyncIteratorPair[*AgentEvent]() + + go r.handleIter(ctx, aIter, gen, &checkPointID) + return niter, nil +} + +func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent], checkPointID *string) { + defer func() { + panicErr := recover() + if panicErr != nil { + e := safe.NewPanicErr(panicErr, debug.Stack()) + gen.Send(&AgentEvent{Err: e}) + } + + gen.Close() + }() + var interruptedInfo *InterruptInfo + for { + event, ok := aIter.Next() + if !ok { + break + } + + if event.Action != nil && event.Action.Interrupted != nil { + interruptedInfo = event.Action.Interrupted + } else { + interruptedInfo = nil + } + + gen.Send(event) + } + + if interruptedInfo != nil && checkPointID != nil { + err := saveCheckPoint(ctx, r.store, *checkPointID, getInterruptRunCtx(ctx), interruptedInfo) + if err != nil { + gen.Send(&AgentEvent{Err: fmt.Errorf("failed to save checkpoint: %w", err)}) + } + } +} diff --git a/adk/runner_test.go b/adk/runner_test.go new file mode 100644 index 00000000..6ab3f128 --- /dev/null +++ b/adk/runner_test.go @@ -0,0 +1,263 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +// mockRunnerAgent is a simple implementation of the Agent interface for testing Runner +type mockRunnerAgent struct { + name string + description string + responses []*AgentEvent + // Track calls to verify correct parameters were passed + callCount int + lastInput *AgentInput + enableStreaming bool +} + +func (a *mockRunnerAgent) Name(_ context.Context) string { + return a.name +} + +func (a *mockRunnerAgent) Description(_ context.Context) string { + return a.description +} + +func (a *mockRunnerAgent) Run(_ context.Context, input *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + // Record the call details for verification + a.callCount++ + a.lastInput = input + a.enableStreaming = input.EnableStreaming + + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + + go func() { + defer generator.Close() + + for _, event := range a.responses { + generator.Send(event) + + // If the event has an Exit action, stop sending events + if event.Action != nil && event.Action.Exit { + break + } + } + }() + + return iterator +} + +func newMockRunnerAgent(name, description string, responses []*AgentEvent) *mockRunnerAgent { + return &mockRunnerAgent{ + name: name, + description: description, + responses: responses, + } +} + +func TestNewRunner(t *testing.T) { + ctx := context.Background() + config := RunnerConfig{} + + runner := NewRunner(ctx, config) + + // Verify that a non-nil runner is returned + assert.NotNil(t, runner) +} + +func TestRunner_Run(t *testing.T) { + ctx := context.Background() + + // Create a mock agent with predefined responses + mockAgent_ := newMockRunnerAgent("TestAgent", "Test agent for Runner", []*AgentEvent{ + { + AgentName: "TestAgent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: false, + Message: schema.AssistantMessage("Response from test agent", nil), + Role: schema.Assistant, + }, + }}, + }) + + // Create a runner + runner := NewRunner(ctx, RunnerConfig{Agent: mockAgent_}) + + // Create test messages + msgs := []Message{ + schema.UserMessage("Hello, agent!"), + } + + // Test Run method without streaming + iterator := runner.Run(ctx, msgs) + + // Verify that the agent's Run method was called with the correct parameters + assert.Equal(t, 1, mockAgent_.callCount) + assert.Equal(t, msgs, mockAgent_.lastInput.Messages) + assert.False(t, mockAgent_.enableStreaming) + + // Verify that we can get the expected response from the iterator + event, ok := iterator.Next() + assert.True(t, ok) + assert.Equal(t, "TestAgent", event.AgentName) + assert.NotNil(t, event.Output) + assert.NotNil(t, event.Output.MessageOutput) + assert.NotNil(t, event.Output.MessageOutput.Message) + assert.Equal(t, "Response from test agent", event.Output.MessageOutput.Message.Content) + + // Verify that the iterator is now closed + _, ok = iterator.Next() + assert.False(t, ok) +} + +func TestRunner_Run_WithStreaming(t *testing.T) { + ctx := context.Background() + + // Create a mock agent with predefined responses + mockAgent_ := newMockRunnerAgent("TestAgent", "Test agent for Runner", []*AgentEvent{ + { + AgentName: "TestAgent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: true, + Message: nil, + MessageStream: schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("Streaming response", nil)}), + Role: schema.Assistant, + }, + }}, + }) + + // Create a runner + runner := NewRunner(ctx, RunnerConfig{EnableStreaming: true, Agent: mockAgent_}) + + // Create test messages + msgs := []Message{ + schema.UserMessage("Hello, agent!"), + } + + // Test Run method with streaming enabled + iterator := runner.Run(ctx, msgs) + + // Verify that the agent's Run method was called with the correct parameters + assert.Equal(t, 1, mockAgent_.callCount) + assert.Equal(t, msgs, mockAgent_.lastInput.Messages) + assert.True(t, mockAgent_.enableStreaming) + + // Verify that we can get the expected response from the iterator + event, ok := iterator.Next() + assert.True(t, ok) + assert.Equal(t, "TestAgent", event.AgentName) + assert.NotNil(t, event.Output) + assert.NotNil(t, event.Output.MessageOutput) + assert.True(t, event.Output.MessageOutput.IsStreaming) + + // Verify that the iterator is now closed + _, ok = iterator.Next() + assert.False(t, ok) +} + +func TestRunner_Query(t *testing.T) { + ctx := context.Background() + + // Create a mock agent with predefined responses + mockAgent_ := newMockRunnerAgent("TestAgent", "Test agent for Runner", []*AgentEvent{ + { + AgentName: "TestAgent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: false, + Message: schema.AssistantMessage("Response to query", nil), + Role: schema.Assistant, + }, + }}, + }) + + // Create a runner + runner := NewRunner(ctx, RunnerConfig{Agent: mockAgent_}) + + // Test Query method + iterator := runner.Query(ctx, "Test query") + + // Verify that the agent's Run method was called with the correct parameters + assert.Equal(t, 1, mockAgent_.callCount) + assert.Equal(t, 1, len(mockAgent_.lastInput.Messages)) + assert.Equal(t, "Test query", mockAgent_.lastInput.Messages[0].Content) + assert.False(t, mockAgent_.enableStreaming) + + // Verify that we can get the expected response from the iterator + event, ok := iterator.Next() + assert.True(t, ok) + assert.Equal(t, "TestAgent", event.AgentName) + assert.NotNil(t, event.Output) + assert.NotNil(t, event.Output.MessageOutput) + assert.NotNil(t, event.Output.MessageOutput.Message) + assert.Equal(t, "Response to query", event.Output.MessageOutput.Message.Content) + + // Verify that the iterator is now closed + _, ok = iterator.Next() + assert.False(t, ok) +} + +func TestRunner_Query_WithStreaming(t *testing.T) { + ctx := context.Background() + + // Create a mock agent with predefined responses + mockAgent_ := newMockRunnerAgent("TestAgent", "Test agent for Runner", []*AgentEvent{ + { + AgentName: "TestAgent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: true, + Message: nil, + MessageStream: schema.StreamReaderFromArray([]*schema.Message{schema.AssistantMessage("Streaming query response", nil)}), + Role: schema.Assistant, + }, + }}, + }) + + // Create a runner + runner := NewRunner(ctx, RunnerConfig{EnableStreaming: true, Agent: mockAgent_}) + + // Test Query method with streaming enabled + iterator := runner.Query(ctx, "Test query") + + // Verify that the agent's Run method was called with the correct parameters + assert.Equal(t, 1, mockAgent_.callCount) + assert.Equal(t, 1, len(mockAgent_.lastInput.Messages)) + assert.Equal(t, "Test query", mockAgent_.lastInput.Messages[0].Content) + assert.True(t, mockAgent_.enableStreaming) + + // Verify that we can get the expected response from the iterator + event, ok := iterator.Next() + assert.True(t, ok) + assert.Equal(t, "TestAgent", event.AgentName) + assert.NotNil(t, event.Output) + assert.NotNil(t, event.Output.MessageOutput) + assert.True(t, event.Output.MessageOutput.IsStreaming) + + // Verify that the iterator is now closed + _, ok = iterator.Next() + assert.False(t, ok) +} diff --git a/adk/utils.go b/adk/utils.go new file mode 100644 index 00000000..fb81f894 --- /dev/null +++ b/adk/utils.go @@ -0,0 +1,220 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "io" + "strings" + + "github.com/google/uuid" + + "github.com/cloudwego/eino/internal" + "github.com/cloudwego/eino/schema" +) + +type AsyncIterator[T any] struct { + ch *internal.UnboundedChan[T] +} + +func (ai *AsyncIterator[T]) Next() (T, bool) { + return ai.ch.Receive() +} + +type AsyncGenerator[T any] struct { + ch *internal.UnboundedChan[T] +} + +func (ag *AsyncGenerator[T]) Send(v T) { + ag.ch.Send(v) +} + +func (ag *AsyncGenerator[T]) Close() { + ag.ch.Close() +} + +func NewAsyncIteratorPair[T any]() (*AsyncIterator[T], *AsyncGenerator[T]) { + ch := internal.NewUnboundedChan[T]() + return &AsyncIterator[T]{ch}, &AsyncGenerator[T]{ch} +} + +func copyMap[K comparable, V any](m map[K]V) map[K]V { + res := make(map[K]V, len(m)) + for k, v := range m { + res[k] = v + } + return res +} + +func concatInstructions(instructions ...string) string { + var sb strings.Builder + sb.WriteString(instructions[0]) + for i := 1; i < len(instructions); i++ { + sb.WriteString("\n\n") + sb.WriteString(instructions[i]) + } + + return sb.String() +} + +func GenTransferMessages(_ context.Context, destAgentName string) (Message, Message) { + toolCallID := uuid.NewString() + tooCall := schema.ToolCall{ID: toolCallID, Function: schema.FunctionCall{Name: TransferToAgentToolName, Arguments: destAgentName}} + assistantMessage := schema.AssistantMessage("", []schema.ToolCall{tooCall}) + toolMessage := schema.ToolMessage(transferToAgentToolOutput(destAgentName), toolCallID, schema.WithToolName(TransferToAgentToolName)) + return assistantMessage, toolMessage +} + +// set automatic close for event's message stream +func setAutomaticClose(e *AgentEvent) { + if e.Output == nil || e.Output.MessageOutput == nil || !e.Output.MessageOutput.IsStreaming { + return + } + + e.Output.MessageOutput.MessageStream.SetAutomaticClose() +} + +func getMessageFromWrappedEvent(e *agentEventWrapper) (Message, error) { + if e.AgentEvent.Output == nil || e.AgentEvent.Output.MessageOutput == nil { + return nil, nil + } + + if !e.AgentEvent.Output.MessageOutput.IsStreaming { + return e.AgentEvent.Output.MessageOutput.Message, nil + } + + if e.concatenatedMessage != nil { + return e.concatenatedMessage, nil + } + + e.mu.Lock() + defer e.mu.Unlock() + if e.concatenatedMessage != nil { + return e.concatenatedMessage, nil + } + + var ( + msgs []Message + s = e.AgentEvent.Output.MessageOutput.MessageStream + ) + + defer s.Close() + for { + msg, err := s.Recv() + if err != nil { + if err == io.EOF { + break + } + + return nil, err + } + + msgs = append(msgs, msg) + } + + if len(msgs) == 0 { + return nil, errors.New("no messages in MessageVariant.MessageStream") + } + + if len(msgs) == 1 { + e.concatenatedMessage = msgs[0] + } else { + var err error + e.concatenatedMessage, err = schema.ConcatMessages(msgs) + if err != nil { + return nil, err + } + } + + return e.concatenatedMessage, nil +} + +// copyAgentEvent copies an AgentEvent. +// If the MessageVariant is streaming, the MessageStream will be copied. +// RunPath will be deep copied. +// The result of Copy will be a new AgentEvent that is: +// - safe to set fields of AgentEvent +// - safe to extend RunPath +// - safe to receive from MessageStream +// NOTE: even if the AgentEvent is copied, it's still not recommended to modify +// the Message itself or Chunks of the MessageStream, as they are not copied. +// NOTE: if you have CustomizedOutput or CustomizedAction, they are NOT copied. +func copyAgentEvent(ae *AgentEvent) *AgentEvent { + rp := make([]RunStep, len(ae.RunPath)) + copy(rp, ae.RunPath) + + copied := &AgentEvent{ + AgentName: ae.AgentName, + RunPath: rp, + Action: ae.Action, + Err: ae.Err, + } + + if ae.Output == nil { + return copied + } + + copied.Output = &AgentOutput{ + CustomizedOutput: ae.Output.CustomizedOutput, + } + + mv := ae.Output.MessageOutput + if mv == nil { + return copied + } + + copied.Output.MessageOutput = &MessageVariant{ + IsStreaming: mv.IsStreaming, + Role: mv.Role, + ToolName: mv.ToolName, + } + if mv.IsStreaming { + sts := ae.Output.MessageOutput.MessageStream.Copy(2) + mv.MessageStream = sts[0] + copied.Output.MessageOutput.MessageStream = sts[1] + } else { + copied.Output.MessageOutput.Message = mv.Message + } + + return copied +} + +func GetMessage(e *AgentEvent) (Message, *AgentEvent, error) { + if e.Output == nil || e.Output.MessageOutput == nil { + return nil, e, nil + } + + msgOutput := e.Output.MessageOutput + if msgOutput.IsStreaming { + ss := msgOutput.MessageStream.Copy(2) + e.Output.MessageOutput.MessageStream = ss[0] + + msg, err := schema.ConcatMessageStream(ss[1]) + + return msg, e, err + } + + return msgOutput.Message, e, nil +} + +func genErrorIter(err error) *AsyncIterator[*AgentEvent] { + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Send(&AgentEvent{Err: err}) + generator.Close() + return iterator +} diff --git a/adk/utils_test.go b/adk/utils_test.go new file mode 100644 index 00000000..35501c60 --- /dev/null +++ b/adk/utils_test.go @@ -0,0 +1,167 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestAsyncIteratorPair_Basic(t *testing.T) { + // Create a new iterator-generator pair + iterator, generator := NewAsyncIteratorPair[string]() + + // Test sending and receiving a value + generator.Send("test1") + val, ok := iterator.Next() + if !ok { + t.Error("receive should succeed") + } + if val != "test1" { + t.Errorf("expected 'test1', got '%s'", val) + } + + // Test sending and receiving multiple values + generator.Send("test2") + generator.Send("test3") + + val, ok = iterator.Next() + if !ok { + t.Error("receive should succeed") + } + if val != "test2" { + t.Errorf("expected 'test2', got '%s'", val) + } + + val, ok = iterator.Next() + if !ok { + t.Error("receive should succeed") + } + if val != "test3" { + t.Errorf("expected 'test3', got '%s'", val) + } +} + +func TestAsyncIteratorPair_Close(t *testing.T) { + iterator, generator := NewAsyncIteratorPair[int]() + + // Send some values + generator.Send(1) + generator.Send(2) + + // Close the generator + generator.Close() + + // Should still be able to read existing values + val, ok := iterator.Next() + if !ok { + t.Error("receive should succeed") + } + if val != 1 { + t.Errorf("expected 1, got %d", val) + } + + val, ok = iterator.Next() + if !ok { + t.Error("receive should succeed") + } + if val != 2 { + t.Errorf("expected 2, got %d", val) + } + + // After consuming all values, Next should return false + _, ok = iterator.Next() + if ok { + t.Error("receive from closed, empty channel should return ok=false") + } +} + +func TestAsyncIteratorPair_Concurrency(t *testing.T) { + iterator, generator := NewAsyncIteratorPair[int]() + const numSenders = 5 + const numReceivers = 3 + const messagesPerSender = 100 + + var rwg, swg sync.WaitGroup + rwg.Add(numReceivers) + swg.Add(numSenders) + + // Start senders + for i := 0; i < numSenders; i++ { + go func(id int) { + defer swg.Done() + for j := 0; j < messagesPerSender; j++ { + generator.Send(id*messagesPerSender + j) + time.Sleep(time.Microsecond) // Small delay to increase concurrency chance + } + }(i) + } + + // Start receivers + received := make([]int, 0, numSenders*messagesPerSender) + var mu sync.Mutex + + for i := 0; i < numReceivers; i++ { + go func() { + defer rwg.Done() + for { + val, ok := iterator.Next() + if !ok { + return + } + mu.Lock() + received = append(received, val) + mu.Unlock() + } + }() + } + + // Wait for senders to finish + swg.Wait() + generator.Close() + + // Wait for all goroutines to finish + rwg.Wait() + + // Verify we received all messages + if len(received) != numSenders*messagesPerSender { + t.Errorf("expected %d messages, got %d", numSenders*messagesPerSender, len(received)) + } + + // Create a map to check for duplicates and missing values + receivedMap := make(map[int]bool) + for _, val := range received { + receivedMap[val] = true + } + + if len(receivedMap) != numSenders*messagesPerSender { + t.Error("duplicate or missing messages detected") + } +} + +func TestGenErrorIter(t *testing.T) { + iter := genErrorIter(fmt.Errorf("test")) + e, ok := iter.Next() + assert.True(t, ok) + assert.Equal(t, "test", e.Err.Error()) + _, ok = iter.Next() + assert.False(t, ok) +} diff --git a/adk/workflow.go b/adk/workflow.go new file mode 100644 index 00000000..45ef547a --- /dev/null +++ b/adk/workflow.go @@ -0,0 +1,443 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "errors" + "fmt" + "reflect" + "runtime/debug" + "sync" + + "github.com/cloudwego/eino/internal/safe" +) + +type workflowAgentMode int + +const ( + workflowAgentModeUnknown workflowAgentMode = iota + workflowAgentModeSequential + workflowAgentModeLoop + workflowAgentModeParallel +) + +type workflowAgent struct { + name string + description string + subAgents []*flowAgent + + mode workflowAgentMode + + maxIterations int +} + +func (a *workflowAgent) Name(_ context.Context) string { + return a.name +} + +func (a *workflowAgent) Description(_ context.Context) string { + return a.description +} + +func (a *workflowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + + go func() { + + var err error + defer func() { + panicErr := recover() + if panicErr != nil { + e := safe.NewPanicErr(panicErr, debug.Stack()) + generator.Send(&AgentEvent{Err: e}) + } else if err != nil { + generator.Send(&AgentEvent{Err: err}) + } + + generator.Close() + }() + + // Different workflow execution based on mode + switch a.mode { + case workflowAgentModeSequential: + a.runSequential(ctx, input, generator, nil, 0, opts...) + case workflowAgentModeLoop: + a.runLoop(ctx, input, generator, nil, opts...) + case workflowAgentModeParallel: + a.runParallel(ctx, input, generator, nil, opts...) + default: + err = errors.New(fmt.Sprintf("unsupported workflow agent mode: %d", a.mode)) + } + }() + + return iterator +} + +func (a *workflowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + wi, ok := info.Data.(*WorkflowInterruptInfo) + if !ok { + // unreachable + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + generator.Send(&AgentEvent{Err: fmt.Errorf("type of InterruptInfo.Data is expected to %s, actual: %T", reflect.TypeOf((*WorkflowInterruptInfo)(nil)).String(), info.Data)}) + generator.Close() + + return iterator + } + + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + + go func() { + + var err error + defer func() { + panicErr := recover() + if panicErr != nil { + e := safe.NewPanicErr(panicErr, debug.Stack()) + generator.Send(&AgentEvent{Err: e}) + } else if err != nil { + generator.Send(&AgentEvent{Err: err}) + } + + generator.Close() + }() + + // Different workflow execution based on mode + switch a.mode { + case workflowAgentModeSequential: + a.runSequential(ctx, wi.OrigInput, generator, wi, 0, opts...) + case workflowAgentModeLoop: + a.runLoop(ctx, wi.OrigInput, generator, wi, opts...) + case workflowAgentModeParallel: + a.runParallel(ctx, wi.OrigInput, generator, wi, opts...) + default: + err = errors.New(fmt.Sprintf("unsupported workflow agent mode: %d", a.mode)) + } + }() + return iterator +} + +type WorkflowInterruptInfo struct { + OrigInput *AgentInput + + SequentialInterruptIndex int + SequentialInterruptInfo *InterruptInfo + + LoopIterations int + + ParallelInterruptInfo map[int] /*index*/ *InterruptInfo +} + +func (a *workflowAgent) runSequential(ctx context.Context, input *AgentInput, + generator *AsyncGenerator[*AgentEvent], intInfo *WorkflowInterruptInfo, iterations int /*passed by loop agent*/, opts ...AgentRunOption) (exit, interrupted bool) { + var runPath []RunStep // reconstruct RunPath each loop + if iterations > 0 { + runPath = make([]RunStep, 0, (iterations+1)*len(a.subAgents)) + for iter := 0; iter < iterations; iter++ { + for j := 0; j < len(a.subAgents); j++ { + runPath = append(runPath, RunStep{ + agentName: a.subAgents[j].Name(ctx), + }) + } + } + } + + i := 0 + if intInfo != nil { // restore previous RunPath + i = intInfo.SequentialInterruptIndex + + for j := 0; j < i; j++ { + runPath = append(runPath, RunStep{ + agentName: a.subAgents[j].Name(ctx), + }) + } + } + + runCtx := getRunCtx(ctx) + nRunCtx := runCtx.deepCopy() + nRunCtx.RunPath = append(nRunCtx.RunPath, runPath...) + nCtx := setRunCtx(ctx, nRunCtx) + + for ; i < len(a.subAgents); i++ { + subAgent := a.subAgents[i] + + var subIterator *AsyncIterator[*AgentEvent] + if intInfo != nil && i == intInfo.SequentialInterruptIndex { + nCtx, nRunCtx = initRunCtx(nCtx, subAgent.Name(nCtx), nRunCtx.RootInput) + enableStreaming := false + if runCtx.RootInput != nil { + enableStreaming = runCtx.RootInput.EnableStreaming + } + subIterator = subAgent.Resume(nCtx, &ResumeInfo{ + EnableStreaming: enableStreaming, + InterruptInfo: intInfo.SequentialInterruptInfo, + }, opts...) + } else { + subIterator = subAgent.Run(nCtx, input, opts...) + nCtx, _ = initRunCtx(nCtx, subAgent.Name(nCtx), input) + } + + for { + event, ok := subIterator.Next() + if !ok { + break + } + + if event.Action != nil && event.Action.Interrupted != nil { + // shallow copy + newEvent := &AgentEvent{ + AgentName: event.AgentName, + RunPath: event.RunPath, + Output: event.Output, + Action: &AgentAction{ + Exit: event.Action.Exit, + Interrupted: &InterruptInfo{Data: event.Action.Interrupted.Data}, + TransferToAgent: event.Action.TransferToAgent, + CustomizedAction: event.Action.CustomizedAction, + }, + Err: event.Err, + } + newEvent.Action.Interrupted.Data = &WorkflowInterruptInfo{ + OrigInput: input, + SequentialInterruptIndex: i, + SequentialInterruptInfo: event.Action.Interrupted, + LoopIterations: iterations, + } + + // Reset run ctx, + // because the control should be transferred to the workflow agent, not the interrupted agent + replaceInterruptRunCtx(nCtx, runCtx) + + // Forward the event + generator.Send(newEvent) + return true, true + } + + // Forward the event + generator.Send(event) + + if event.Err != nil { + return true, false + } + + if event.Action != nil { + if event.Action.Exit { + return true, false + } + } + } + } + + return false, false +} + +func (a *workflowAgent) runLoop(ctx context.Context, input *AgentInput, + generator *AsyncGenerator[*AgentEvent], intInfo *WorkflowInterruptInfo, opts ...AgentRunOption) { + + if len(a.subAgents) == 0 { + return + } + var iterations int + if intInfo != nil { + iterations = intInfo.LoopIterations + } + for iterations < a.maxIterations || a.maxIterations == 0 { + exit, interrupted := a.runSequential(ctx, input, generator, intInfo, iterations, opts...) + if interrupted { + return + } + if exit { + return + } + intInfo = nil // only effect once + iterations++ + } +} + +func (a *workflowAgent) runParallel(ctx context.Context, input *AgentInput, + generator *AsyncGenerator[*AgentEvent], intInfo *WorkflowInterruptInfo, opts ...AgentRunOption) { + + if len(a.subAgents) == 0 { + return + } + + runners := getRunners(a.subAgents, input, intInfo, opts...) + var wg sync.WaitGroup + interruptMap := make(map[int]*InterruptInfo) + var mu sync.Mutex + if len(runners) > 1 { + for i := 1; i < len(runners); i++ { + wg.Add(1) + go func(idx int, runner func(ctx context.Context) *AsyncIterator[*AgentEvent]) { + defer func() { + panicErr := recover() + if panicErr != nil { + e := safe.NewPanicErr(panicErr, debug.Stack()) + generator.Send(&AgentEvent{Err: e}) + } + wg.Done() + }() + + iterator := runner(ctx) + for { + event, ok := iterator.Next() + if !ok { + break + } + if event.Action != nil && event.Action.Interrupted != nil { + mu.Lock() + interruptMap[idx] = event.Action.Interrupted + mu.Unlock() + break + } + // Forward the event + generator.Send(event) + } + }(i, runners[i]) + } + } + + runner := runners[0] + iterator := runner(ctx) + for { + event, ok := iterator.Next() + if !ok { + break + } + if event.Action != nil && event.Action.Interrupted != nil { + mu.Lock() + interruptMap[0] = event.Action.Interrupted + mu.Unlock() + break + } + // Forward the event + generator.Send(event) + } + + if len(a.subAgents) > 1 { + wg.Wait() + } + + if len(interruptMap) > 0 { + replaceInterruptRunCtx(ctx, getRunCtx(ctx)) + generator.Send(&AgentEvent{ + AgentName: a.Name(ctx), + RunPath: getRunCtx(ctx).RunPath, + Action: &AgentAction{ + Interrupted: &InterruptInfo{ + Data: &WorkflowInterruptInfo{ + OrigInput: input, + ParallelInterruptInfo: interruptMap, + }, + }, + }, + }) + } +} + +func getRunners(subAgents []*flowAgent, input *AgentInput, intInfo *WorkflowInterruptInfo, opts ...AgentRunOption) []func(ctx context.Context) *AsyncIterator[*AgentEvent] { + ret := make([]func(ctx context.Context) *AsyncIterator[*AgentEvent], 0, len(subAgents)) + if intInfo == nil { + // init run + for _, subAgent := range subAgents { + sa := subAgent + ret = append(ret, func(ctx context.Context) *AsyncIterator[*AgentEvent] { + return sa.Run(ctx, input, opts...) + }) + } + return ret + } + // resume + for i, subAgent := range subAgents { + sa := subAgent + info, ok := intInfo.ParallelInterruptInfo[i] + if !ok { + // have executed + continue + } + ret = append(ret, func(ctx context.Context) *AsyncIterator[*AgentEvent] { + nCtx, runCtx := initRunCtx(ctx, sa.Name(ctx), input) + enableStreaming := false + if runCtx.RootInput != nil { + enableStreaming = runCtx.RootInput.EnableStreaming + } + return sa.Resume(nCtx, &ResumeInfo{ + EnableStreaming: enableStreaming, + InterruptInfo: info, + }, opts...) + }) + } + return ret +} + +type SequentialAgentConfig struct { + Name string + Description string + SubAgents []Agent +} + +type ParallelAgentConfig struct { + Name string + Description string + SubAgents []Agent +} + +type LoopAgentConfig struct { + Name string + Description string + SubAgents []Agent + + MaxIterations int +} + +func newWorkflowAgent(ctx context.Context, name, desc string, + subAgents []Agent, mode workflowAgentMode, maxIterations int) (*flowAgent, error) { + + wa := &workflowAgent{ + name: name, + description: desc, + mode: mode, + + maxIterations: maxIterations, + } + + fas := make([]Agent, len(subAgents)) + for i, subAgent := range subAgents { + fas[i] = toFlowAgent(ctx, subAgent, WithDisallowTransferToParent()) + } + + fa, err := setSubAgents(ctx, wa, fas) + if err != nil { + return nil, err + } + + wa.subAgents = fa.subAgents + + return fa, nil +} + +func NewSequentialAgent(ctx context.Context, config *SequentialAgentConfig) (Agent, error) { + return newWorkflowAgent(ctx, config.Name, config.Description, config.SubAgents, workflowAgentModeSequential, 0) +} + +func NewParallelAgent(ctx context.Context, config *ParallelAgentConfig) (Agent, error) { + return newWorkflowAgent(ctx, config.Name, config.Description, config.SubAgents, workflowAgentModeParallel, 0) +} + +func NewLoopAgent(ctx context.Context, config *LoopAgentConfig) (Agent, error) { + return newWorkflowAgent(ctx, config.Name, config.Description, config.SubAgents, workflowAgentModeLoop, config.MaxIterations) +} diff --git a/adk/workflow_test.go b/adk/workflow_test.go new file mode 100644 index 00000000..21332142 --- /dev/null +++ b/adk/workflow_test.go @@ -0,0 +1,649 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +// mockAgent is a simple implementation of the Agent interface for testing +type mockAgent struct { + name string + description string + responses []*AgentEvent +} + +func (a *mockAgent) Name(_ context.Context) string { + return a.name +} + +func (a *mockAgent) Description(_ context.Context) string { + return a.description +} + +func (a *mockAgent) Run(_ context.Context, _ *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] { + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + + go func() { + defer generator.Close() + + for _, event := range a.responses { + generator.Send(event) + + // If the event has an Exit action, stop sending events + if event.Action != nil && event.Action.Exit { + break + } + } + }() + + return iterator +} + +// newMockAgent creates a new mock agent with the given name, description, and responses +func newMockAgent(name, description string, responses []*AgentEvent) *mockAgent { + return &mockAgent{ + name: name, + description: description, + responses: responses, + } +} + +// TestSequentialAgent tests the sequential workflow agent +func TestSequentialAgent(t *testing.T) { + ctx := context.Background() + + // Create mock agents with predefined responses + agent1 := newMockAgent("Agent1", "First agent", []*AgentEvent{ + { + AgentName: "Agent1", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: false, + Message: schema.AssistantMessage("Response from Agent1", nil), + Role: schema.Assistant, + }, + }, + }, + }) + + agent2 := newMockAgent("Agent2", "Second agent", []*AgentEvent{ + { + AgentName: "Agent2", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: false, + Message: schema.AssistantMessage("Response from Agent2", nil), + Role: schema.Assistant, + }, + }}, + }) + + // Create a sequential agent with the mock agents + config := &SequentialAgentConfig{ + Name: "SequentialTestAgent", + Description: "Test sequential agent", + SubAgents: []Agent{agent1, agent2}, + } + + sequentialAgent, err := NewSequentialAgent(ctx, config) + assert.NoError(t, err) + assert.NotNil(t, sequentialAgent) + + assert.Equal(t, "Test sequential agent", sequentialAgent.Description(ctx)) + + // Run the sequential agent + input := &AgentInput{ + Messages: []Message{ + schema.UserMessage("Test input"), + }, + } + + iterator := sequentialAgent.Run(ctx, input) + assert.NotNil(t, iterator) + + // First event should be from agent1 + event1, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event1) + assert.Nil(t, event1.Err) + assert.NotNil(t, event1.Output) + assert.NotNil(t, event1.Output.MessageOutput) + + // Get the message content from agent1 + msg1 := event1.Output.MessageOutput.Message + assert.NotNil(t, msg1) + assert.Equal(t, "Response from Agent1", msg1.Content) + + // Second event should be from agent2 + event2, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event2) + assert.Nil(t, event2.Err) + assert.NotNil(t, event2.Output) + assert.NotNil(t, event2.Output.MessageOutput) + + // Get the message content from agent2 + msg2 := event2.Output.MessageOutput.Message + assert.NotNil(t, msg2) + assert.Equal(t, "Response from Agent2", msg2.Content) + + // No more events + _, ok = iterator.Next() + assert.False(t, ok) +} + +// TestSequentialAgentWithExit tests the sequential workflow agent with an exit action +func TestSequentialAgentWithExit(t *testing.T) { + ctx := context.Background() + + // Create mock agents with predefined responses + agent1 := newMockAgent("Agent1", "First agent", []*AgentEvent{ + { + AgentName: "Agent1", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: false, + Message: schema.AssistantMessage("Response from Agent1", nil), + Role: schema.Assistant, + }, + }, + Action: &AgentAction{ + Exit: true, + }, + }, + }) + + agent2 := newMockAgent("Agent2", "Second agent", []*AgentEvent{ + { + AgentName: "Agent2", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: false, + Message: schema.AssistantMessage("Response from Agent2", nil), + Role: schema.Assistant, + }, + }, + }, + }) + + // Create a sequential agent with the mock agents + config := &SequentialAgentConfig{ + Name: "SequentialTestAgent", + Description: "Test sequential agent", + SubAgents: []Agent{agent1, agent2}, + } + + sequentialAgent, err := NewSequentialAgent(ctx, config) + assert.NoError(t, err) + assert.NotNil(t, sequentialAgent) + + // Run the sequential agent + input := &AgentInput{ + Messages: []Message{ + schema.UserMessage("Test input"), + }, + } + + iterator := sequentialAgent.Run(ctx, input) + assert.NotNil(t, iterator) + + // First event should be from agent1 with exit action + event1, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event1) + assert.Nil(t, event1.Err) + assert.NotNil(t, event1.Output) + assert.NotNil(t, event1.Output.MessageOutput) + assert.NotNil(t, event1.Action) + assert.True(t, event1.Action.Exit) + + // No more events due to exit action + _, ok = iterator.Next() + assert.False(t, ok) +} + +// TestParallelAgent tests the parallel workflow agent +func TestParallelAgent(t *testing.T) { + ctx := context.Background() + + // Create mock agents with predefined responses + agent1 := newMockAgent("Agent1", "First agent", []*AgentEvent{ + { + AgentName: "Agent1", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: false, + Message: schema.AssistantMessage("Response from Agent1", nil), + Role: schema.Assistant, + }, + }, + }, + }) + + agent2 := newMockAgent("Agent2", "Second agent", []*AgentEvent{ + { + AgentName: "Agent2", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: false, + Message: schema.AssistantMessage("Response from Agent2", nil), + Role: schema.Assistant, + }, + }, + }, + }) + + // Create a parallel agent with the mock agents + config := &ParallelAgentConfig{ + Name: "ParallelTestAgent", + Description: "Test parallel agent", + SubAgents: []Agent{agent1, agent2}, + } + + parallelAgent, err := NewParallelAgent(ctx, config) + assert.NoError(t, err) + assert.NotNil(t, parallelAgent) + + // Run the parallel agent + input := AgentInput{ + Messages: []Message{ + schema.UserMessage("Test input"), + }, + } + + iterator := parallelAgent.Run(ctx, &input) + assert.NotNil(t, iterator) + + // Collect all events + var events []*AgentEvent + for { + event, ok := iterator.Next() + if !ok { + break + } + events = append(events, event) + } + + // Should have two events, one from each agent + assert.Equal(t, 2, len(events)) + + // Verify the events + for _, event := range events { + assert.Nil(t, event.Err) + assert.NotNil(t, event.Output) + assert.NotNil(t, event.Output.MessageOutput) + + msg := event.Output.MessageOutput.Message + assert.NotNil(t, msg) + assert.NoError(t, err) + + // Check the source agent name and message content + if event.AgentName == "Agent1" { + assert.Equal(t, "Response from Agent1", msg.Content) + } else if event.AgentName == "Agent2" { + assert.Equal(t, "Response from Agent2", msg.Content) + } else { + t.Fatalf("Unexpected source agent name: %s", event.AgentName) + } + } +} + +// TestLoopAgent tests the loop workflow agent +func TestLoopAgent(t *testing.T) { + ctx := context.Background() + + // Create a mock agent that will be called multiple times + agent := newMockAgent("LoopAgent", "Loop agent", []*AgentEvent{ + { + AgentName: "LoopAgent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: false, + Message: schema.AssistantMessage("Loop iteration", nil), + Role: schema.Assistant, + }, + }, + }, + }) + + // Create a loop agent with the mock agent and max iterations set to 3 + config := &LoopAgentConfig{ + Name: "LoopTestAgent", + Description: "Test loop agent", + SubAgents: []Agent{agent}, + + MaxIterations: 3, + } + + loopAgent, err := NewLoopAgent(ctx, config) + assert.NoError(t, err) + assert.NotNil(t, loopAgent) + + // Run the loop agent + input := &AgentInput{ + Messages: []Message{ + schema.UserMessage("Test input"), + }, + } + + iterator := loopAgent.Run(ctx, input) + assert.NotNil(t, iterator) + + // Collect all events + var events []*AgentEvent + for { + event, ok := iterator.Next() + if !ok { + break + } + events = append(events, event) + } + + // Should have 3 events (one for each iteration) + assert.Equal(t, 3, len(events)) + + // Verify all events + for _, event := range events { + assert.Nil(t, event.Err) + assert.NotNil(t, event.Output) + assert.NotNil(t, event.Output.MessageOutput) + + msg := event.Output.MessageOutput.Message + assert.NotNil(t, msg) + assert.Equal(t, "Loop iteration", msg.Content) + } +} + +// TestLoopAgentWithExit tests the loop workflow agent with an exit action +func TestLoopAgentWithExit(t *testing.T) { + ctx := context.Background() + + // Create a mock agent that will exit after the first iteration + agent := newMockAgent("LoopAgent", "Loop agent", []*AgentEvent{ + { + AgentName: "LoopAgent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + IsStreaming: false, + Message: schema.AssistantMessage("Loop iteration with exit", nil), + Role: schema.Assistant, + }, + }, + Action: &AgentAction{ + Exit: true, + }, + }, + }) + + // Create a loop agent with the mock agent and max iterations set to 3 + config := &LoopAgentConfig{ + Name: "LoopTestAgent", + Description: "Test loop agent", + SubAgents: []Agent{agent}, + MaxIterations: 3, + } + + loopAgent, err := NewLoopAgent(ctx, config) + assert.NoError(t, err) + assert.NotNil(t, loopAgent) + + // Run the loop agent + input := &AgentInput{ + Messages: []Message{ + schema.UserMessage("Test input"), + }, + } + + iterator := loopAgent.Run(ctx, input) + assert.NotNil(t, iterator) + + // Collect all events + var events []*AgentEvent + for { + event, ok := iterator.Next() + if !ok { + break + } + events = append(events, event) + } + + // Should have only 1 event due to exit action + assert.Equal(t, 1, len(events)) + + // Verify the event + event := events[0] + assert.Nil(t, event.Err) + assert.NotNil(t, event.Output) + assert.NotNil(t, event.Output.MessageOutput) + assert.NotNil(t, event.Action) + assert.True(t, event.Action.Exit) + + msg := event.Output.MessageOutput.Message + assert.NotNil(t, msg) + assert.Equal(t, "Loop iteration with exit", msg.Content) +} + +// Add these test functions to the existing workflow_test.go file + +// Replace the existing TestWorkflowAgentPanicRecovery function +func TestWorkflowAgentPanicRecovery(t *testing.T) { + ctx := context.Background() + + // Create a panic agent that panics in Run method + panicAgent := &panicMockAgent{ + mockAgent: mockAgent{ + name: "PanicAgent", + description: "Agent that panics", + responses: []*AgentEvent{}, + }, + } + + // Create a sequential agent with the panic agent + config := &SequentialAgentConfig{ + Name: "PanicTestAgent", + Description: "Test agent with panic", + SubAgents: []Agent{panicAgent}, + } + + sequentialAgent, err := NewSequentialAgent(ctx, config) + assert.NoError(t, err) + + // Run the agent and expect panic recovery + input := &AgentInput{ + Messages: []Message{ + schema.UserMessage("Test input"), + }, + } + + iterator := sequentialAgent.Run(ctx, input) + assert.NotNil(t, iterator) + + // Should receive an error event due to panic recovery + event, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event) + assert.NotNil(t, event.Err) + assert.Contains(t, event.Err.Error(), "panic") + + // No more events + _, ok = iterator.Next() + assert.False(t, ok) +} + +// Add these new mock agent types that properly panic +type panicMockAgent struct { + mockAgent +} + +func (a *panicMockAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + panic("test panic in agent") +} + +type panicResumableMockAgent struct { + mockAgent +} + +func (a *panicResumableMockAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + panic("test panic in resume") +} + +// Remove the old mockResumableAgent type and replace it with panicResumableMockAgent + +// TestWorkflowAgentUnsupportedMode tests unsupported workflow mode error (lines 65-71) +func TestWorkflowAgentUnsupportedMode(t *testing.T) { + ctx := context.Background() + + // Create a workflow agent with unsupported mode + agent := &workflowAgent{ + name: "UnsupportedModeAgent", + description: "Agent with unsupported mode", + subAgents: []*flowAgent{}, + mode: workflowAgentMode(999), // Invalid mode + } + + // Run the agent and expect error + input := &AgentInput{ + Messages: []Message{ + schema.UserMessage("Test input"), + }, + } + + iterator := agent.Run(ctx, input) + assert.NotNil(t, iterator) + + // Should receive an error event due to unsupported mode + event, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event) + assert.NotNil(t, event.Err) + assert.Contains(t, event.Err.Error(), "unsupported workflow agent mode") + + // No more events + _, ok = iterator.Next() + assert.False(t, ok) +} + +// TestWorkflowAgentResumePanicRecovery tests panic recovery in Resume method (lines 108-115) +func TestWorkflowAgentResumePanicRecovery(t *testing.T) { + ctx := context.Background() + + // Create a mock resumable agent that panics on Resume + panicAgent := &mockResumableAgent{ + mockAgent: mockAgent{ + name: "PanicResumeAgent", + description: "Agent that panics on resume", + responses: []*AgentEvent{}, + }, + } + + // Create a sequential agent with the panic agent + config := &SequentialAgentConfig{ + Name: "ResumeTestAgent", + Description: "Test agent for resume panic", + SubAgents: []Agent{panicAgent}, + } + + sequentialAgent, err := NewSequentialAgent(ctx, config) + assert.NoError(t, err) + + // Initialize context with run context - this is the key fix + ctx = ctxWithNewRunCtx(ctx) + + // Create valid resume info + resumeInfo := &ResumeInfo{ + EnableStreaming: false, + InterruptInfo: &InterruptInfo{ + Data: &WorkflowInterruptInfo{ + OrigInput: &AgentInput{ + Messages: []Message{schema.UserMessage("test")}, + }, + SequentialInterruptIndex: 0, + SequentialInterruptInfo: &InterruptInfo{ + Data: "some interrupt data", + }, + LoopIterations: 0, + }, + }, + } + + // Call Resume and expect panic recovery + iterator := sequentialAgent.(ResumableAgent).Resume(ctx, resumeInfo) + assert.NotNil(t, iterator) + + // Should receive an error event due to panic recovery + event, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event) + assert.NotNil(t, event.Err) + assert.Contains(t, event.Err.Error(), "panic") + + // No more events + _, ok = iterator.Next() + assert.False(t, ok) +} + +// mockResumableAgent extends mockAgent to implement ResumableAgent interface +type mockResumableAgent struct { + mockAgent +} + +func (a *mockResumableAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + panic("test panic in resume") +} + +// TestWorkflowAgentResumeInvalidDataType tests invalid data type in Resume method +func TestWorkflowAgentResumeInvalidDataType(t *testing.T) { + ctx := context.Background() + + // Create a workflow agent + agent := &workflowAgent{ + name: "InvalidDataTestAgent", + description: "Agent for invalid data test", + subAgents: []*flowAgent{}, + mode: workflowAgentModeSequential, + } + + // Create resume info with invalid data type + resumeInfo := &ResumeInfo{ + EnableStreaming: false, + InterruptInfo: &InterruptInfo{ + Data: "invalid data type", // Should be *WorkflowInterruptInfo + }, + } + + // Call Resume and expect type assertion error + iterator := agent.Resume(ctx, resumeInfo) + assert.NotNil(t, iterator) + + // Should receive an error event due to type assertion failure + event, ok := iterator.Next() + assert.True(t, ok) + assert.NotNil(t, event) + assert.NotNil(t, event.Err) + assert.Contains(t, event.Err.Error(), "type of InterruptInfo.Data is expected to") + assert.Contains(t, event.Err.Error(), "actual: string") + + // No more events + _, ok = iterator.Next() + assert.False(t, ok) +} diff --git a/compose/generic_helper.go b/compose/generic_helper.go index 0d317717..0534dc62 100644 --- a/compose/generic_helper.go +++ b/compose/generic_helper.go @@ -201,7 +201,7 @@ func defaultStreamMapFilter[T any](key string, isr streamReader) (streamReader, return nil, false } - convert := func(m map[string]any) (T, error) { + cvt := func(m map[string]any) (T, error) { var t T v, ok_ := m[key] if !ok_ { @@ -217,7 +217,7 @@ func defaultStreamMapFilter[T any](key string, isr streamReader) (streamReader, return vv, nil } - ret := schema.StreamReaderWithConvert[map[string]any, T](sr, convert) + ret := schema.StreamReaderWithConvert[map[string]any, T](sr, cvt) return packStreamReader(ret), true } diff --git a/compose/graph_run.go b/compose/graph_run.go index f339ffc6..f54ebd2a 100644 --- a/compose/graph_run.go +++ b/compose/graph_run.go @@ -398,31 +398,32 @@ type interruptTempInfo struct { } func (r *runner) resolveInterruptCompletedTasks(tempInfo *interruptTempInfo, completedTasks []*task) (err error) { - for i := 0; i < len(completedTasks); i++ { - if completedTasks[i].err != nil { - if info := isSubGraphInterrupt(completedTasks[i].err); info != nil { - tempInfo.subGraphInterrupts[completedTasks[i].nodeKey] = info + for _, completedTask := range completedTasks { + if completedTask.err != nil { + if info := isSubGraphInterrupt(completedTask.err); info != nil { + tempInfo.subGraphInterrupts[completedTask.nodeKey] = info continue } - extra, ok := IsInterruptRerunError(completedTasks[i].err) + extra, ok := IsInterruptRerunError(completedTask.err) if ok { - tempInfo.interruptRerunNodes = append(tempInfo.interruptRerunNodes, completedTasks[i].nodeKey) + tempInfo.interruptRerunNodes = append(tempInfo.interruptRerunNodes, completedTask.nodeKey) if extra != nil { - tempInfo.interruptRerunExtra[completedTasks[i].nodeKey] = extra + tempInfo.interruptRerunExtra[completedTask.nodeKey] = extra // save tool node info - if completedTasks[i].call.action.meta.component == ComponentOfToolsNode { + if completedTask.call.action.meta.component == ComponentOfToolsNode { if e, ok := extra.(*ToolsInterruptAndRerunExtra); ok { - tempInfo.interruptExecutedTools[completedTasks[i].nodeKey] = e.ExecutedTools + tempInfo.interruptExecutedTools[completedTask.nodeKey] = e.ExecutedTools } } } continue } - return wrapGraphNodeError(completedTasks[i].nodeKey, completedTasks[i].err) + return wrapGraphNodeError(completedTask.nodeKey, completedTask.err) } + for _, key := range r.interruptAfterNodes { - if key == completedTasks[i].nodeKey { + if key == completedTask.nodeKey { tempInfo.interruptAfterNodes = append(tempInfo.interruptAfterNodes, key) break } diff --git a/compose/stream_concat.go b/compose/stream_concat.go index 29829f7b..4feeba80 100644 --- a/compose/stream_concat.go +++ b/compose/stream_concat.go @@ -59,6 +59,10 @@ func concatStreamReader[T any](sr *schema.StreamReader[T]) (T, error) { break } + if _, ok := schema.GetSourceName(err); ok { + continue + } + var t T return t, newStreamReadError(err) } diff --git a/compose/tool_node.go b/compose/tool_node.go index 2bfb1f4a..aac0db10 100644 --- a/compose/tool_node.go +++ b/compose/tool_node.go @@ -59,10 +59,14 @@ func withExecutedTools(executedTools map[string]string) ToolsNodeOption { } } -// ToolsNode a node that can run tools in a graph. the interface in Graph Node as below: +// ToolsNode represents a node capable of executing tools within a graph. +// The Graph Node interface is defined as follows: // // Invoke(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) ([]*schema.Message, error) // Stream(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) (*schema.StreamReader[[]*schema.Message], error) +// +// Input: An AssistantMessage containing ToolCalls +// Output: An array of ToolMessage where the order of elements corresponds to the order of ToolCalls in the input type ToolsNode struct { tuple *toolsTuple unknownToolHandler func(ctx context.Context, name, input string) (string, error) @@ -204,7 +208,9 @@ type toolCallTask struct { err error } -func (tn *ToolsNode) genToolCallTasks(ctx context.Context, tuple *toolsTuple, input *schema.Message, executedTools map[string]string, isStream bool) ([]toolCallTask, error) { +func (tn *ToolsNode) genToolCallTasks(ctx context.Context, tuple *toolsTuple, + input *schema.Message, executedTools map[string]string, isStream bool) ([]toolCallTask, error) { + if input.Role != schema.Assistant { return nil, fmt.Errorf("expected message role is Assistant, got %s", input.Role) } @@ -303,8 +309,11 @@ func runToolCallTaskByStream(ctx context.Context, task *toolCallTask, opts ...to } } -func sequentialRunToolCall(ctx context.Context, run func(ctx2 context.Context, callTask *toolCallTask, opts ...tool.Option), tasks []toolCallTask, opts ...tool.Option) { - for i := 0; i < len(tasks); i++ { +func sequentialRunToolCall(ctx context.Context, + run func(ctx2 context.Context, callTask *toolCallTask, opts ...tool.Option), + tasks []toolCallTask, opts ...tool.Option) { + + for i := range tasks { if tasks[i].executed { continue } @@ -313,7 +322,8 @@ func sequentialRunToolCall(ctx context.Context, run func(ctx2 context.Context, c } func parallelRunToolCall(ctx context.Context, - run func(ctx2 context.Context, callTask *toolCallTask, opts ...tool.Option), tasks []toolCallTask, opts ...tool.Option) { + run func(ctx2 context.Context, callTask *toolCallTask, opts ...tool.Option), + tasks []toolCallTask, opts ...tool.Option) { if len(tasks) == 1 { run(ctx, &tasks[0], opts...) @@ -370,6 +380,7 @@ func (tn *ToolsNode) Invoke(ctx context.Context, input *schema.Message, n := len(tasks) output := make([]*schema.Message, n) + rerunExtra := &ToolsInterruptAndRerunExtra{ ToolCalls: input.ToolCalls, ExecutedTools: make(map[string]string), @@ -454,9 +465,9 @@ func (tn *ToolsNode) Stream(ctx context.Context, input *schema.Message, // concat and save tool output for _, t := range tasks { if t.executed { - o, err := concatStreamReader(t.sOutput) - if err != nil { - return nil, fmt.Errorf("failed to concat tool[name:%s id:%s]'s stream output: %w", t.name, t.callID, err) + o, err_ := concatStreamReader(t.sOutput) + if err_ != nil { + return nil, fmt.Errorf("failed to concat tool[name:%s id:%s]'s stream output: %w", t.name, t.callID, err_) } rerunExtra.ExecutedTools[t.callID] = o } @@ -470,14 +481,14 @@ func (tn *ToolsNode) Stream(ctx context.Context, input *schema.Message, index := i callID := tasks[i].callID callName := tasks[i].name - convert := func(s string) ([]*schema.Message, error) { + cvt := func(s string) ([]*schema.Message, error) { ret := make([]*schema.Message, n) ret[index] = schema.ToolMessage(s, callID, schema.WithToolName(callName)) return ret, nil } - sOutput[i] = schema.StreamReaderWithConvert(tasks[i].sOutput, convert) + sOutput[i] = schema.StreamReaderWithConvert(tasks[i].sOutput, cvt) } return schema.MergeStreamReaders(sOutput), nil } diff --git a/compose/tool_node_test.go b/compose/tool_node_test.go index 9870118a..e5743035 100644 --- a/compose/tool_node_test.go +++ b/compose/tool_node_test.go @@ -47,32 +47,32 @@ func TestToolsNode(t *testing.T) { userCompanyToolInfo := &schema.ToolInfo{ Name: toolNameOfUserCompany, - Desc: "根据用户的姓名和邮箱,查询用户的公司和职位信息", + Desc: "Query user's company and position information based on user's name and email", ParamsOneOf: schema.NewParamsOneOfByParams( map[string]*schema.ParameterInfo{ "name": { Type: "string", - Desc: "用户的姓名", + Desc: "User's name", }, "email": { Type: "string", - Desc: "用户的邮箱", + Desc: "User's email", }, }), } userSalaryToolInfo := &schema.ToolInfo{ Name: toolNameOfUserSalary, - Desc: "根据用户的姓名和邮箱,查询用户的薪酬信息", + Desc: "Query user's salary information based on user's name and email", ParamsOneOf: schema.NewParamsOneOfByParams( map[string]*schema.ParameterInfo{ "name": { Type: "string", - Desc: "用户的姓名", + Desc: "User's name", }, "email": { Type: "string", - Desc: "用户的邮箱", + Desc: "User's email", }, }), } @@ -113,12 +113,14 @@ func TestToolsNode(t *testing.T) { out, err := r.Invoke(ctx, []*schema.Message{}) assert.NoError(t, err) - assert.Equal(t, toolIDOfUserCompany, findMsgByToolCallID(out, toolIDOfUserCompany).ToolCallID) + msg := findMsgByToolCallID(out, toolIDOfUserCompany) + assert.Equal(t, toolIDOfUserCompany, msg.ToolCallID) assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","gender":"male","company":"bytedance","position":"CEO"}`, - findMsgByToolCallID(out, toolIDOfUserCompany).Content) + msg.Content) - assert.Equal(t, toolIDOfUserSalary, findMsgByToolCallID(out, toolIDOfUserSalary).ToolCallID) - assert.Contains(t, findMsgByToolCallID(out, toolIDOfUserSalary).Content, + msg = findMsgByToolCallID(out, toolIDOfUserSalary) + assert.Equal(t, toolIDOfUserSalary, msg.ToolCallID) + assert.Contains(t, msg.Content, `{"user_id":"zhangsan-zhangsan@bytedance.com","salary":5000}{"user_id":"zhangsan-zhangsan@bytedance.com","salary":3000}{"user_id":"zhangsan-zhangsan@bytedance.com","salary":2000}`) // 测试流式调用 @@ -129,12 +131,15 @@ func TestToolsNode(t *testing.T) { defer reader.Close() + var arrMsgs [][]*schema.Message for ; loops < 10; loops++ { msgs, err := reader.Recv() if err == io.EOF { break } + arrMsgs = append(arrMsgs, msgs) + assert.NoError(t, err) assert.Len(t, msgs, 2) @@ -167,6 +172,11 @@ func TestToolsNode(t *testing.T) { assert.Equal(t, 4, loops) + msgs, err_ := schema.ConcatMessageArray(arrMsgs) + assert.NoError(t, err_) + msg = findMsgByToolCallID(msgs, toolIDOfUserCompany) + msg = findMsgByToolCallID(msgs, toolIDOfUserSalary) + sr, sw := schema.Pipe[[]*schema.Message](2) sw.Send([]*schema.Message{ { @@ -228,6 +238,80 @@ func TestToolsNode(t *testing.T) { assert.Equal(t, 4, loops) }) + + t.Run("order_consistency", func(t *testing.T) { + // Create a ToolsNode with multiple tools + ui := utils.NewTool(userCompanyToolInfo, queryUserCompany) + us := utils.NewTool(userSalaryToolInfo, queryUserSalary) + + toolsNode, err_ := NewToolNode(context.Background(), &ToolsNodeConfig{ + Tools: []tool.BaseTool{ui, us}, + }) + assert.NoError(t, err_) + + // Create an input message with multiple tool calls in a specific order + input := &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: toolIDOfUserSalary, + Function: schema.FunctionCall{ + Name: toolNameOfUserSalary, + Arguments: `{"name": "zhangsan", "email": "zhangsan@bytedance.com"}`, + }, + }, + { + ID: toolIDOfUserCompany, + Function: schema.FunctionCall{ + Name: toolNameOfUserCompany, + Arguments: `{"name": "zhangsan", "email": "zhangsan@bytedance.com"}`, + }, + }, + }, + } + + // Invoke the ToolsNode + output, err_ := toolsNode.Invoke(context.Background(), input) + assert.NoError(t, err_) + + // Verify the order of output messages matches the order of input tool calls + assert.Equal(t, 2, len(output)) + assert.Equal(t, toolIDOfUserSalary, output[0].ToolCallID) + assert.Equal(t, toolIDOfUserCompany, output[1].ToolCallID) + + // Test with Stream method as well + streamer, err_ := toolsNode.Stream(context.Background(), input) + assert.NoError(t, err_) + defer streamer.Close() + + // Collect all stream outputs + var streamOutputs [][]*schema.Message + for { + chunk, err__ := streamer.Recv() + if err__ == io.EOF { + break + } + assert.NoError(t, err__) + streamOutputs = append(streamOutputs, chunk) + } + + // Verify each chunk maintains the correct order + for _, chunk := range streamOutputs { + if chunk[0] != nil { + assert.Equal(t, toolIDOfUserSalary, chunk[0].ToolCallID) + } + if chunk[1] != nil { + assert.Equal(t, toolIDOfUserCompany, chunk[1].ToolCallID) + } + } + + // Concatenate all stream outputs and verify final result + concatenated, err_ := schema.ConcatMessageArray(streamOutputs) + assert.NoError(t, err_) + assert.Equal(t, 2, len(concatenated)) + assert.Equal(t, toolIDOfUserSalary, concatenated[0].ToolCallID) + assert.Equal(t, toolIDOfUserCompany, concatenated[1].ToolCallID) + }) } type userCompanyRequest struct { diff --git a/flow/agent/react/react_test.go b/flow/agent/react/react_test.go index 643f0722..cb2b0296 100644 --- a/flow/agent/react/react_test.go +++ b/flow/agent/react/react_test.go @@ -90,7 +90,7 @@ func TestReact(t *testing.T) { out, err := a.Generate(ctx, []*schema.Message{ { Role: schema.User, - Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!", + Content: "Use greet tool to continuously say hello until you get a bye response, greet names in the following order: max, bob, alice, john, marry, joe, ken, lily, please start directly! please start directly! please start directly!", }, }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) assert.Nil(t, err) @@ -118,7 +118,7 @@ func TestReact(t *testing.T) { out, err = a.Generate(ctx, []*schema.Message{ { Role: schema.User, - Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!", + Content: "Use greet tool to continuously say hello until you get a bye response, greet names in the following order: max, bob, alice, john, marry, joe, ken, lily, please start directly! please start directly! please start directly!", }, }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) assert.Nil(t, err) @@ -219,7 +219,7 @@ func TestReactStream(t *testing.T) { out, err := a.Stream(ctx, []*schema.Message{ { Role: schema.User, - Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!", + Content: "Use greet tool to continuously say hello until you get a bye response, greet names in the following order: max, bob, alice, john, marry, joe, ken, lily, please start directly! please start directly! please start directly!", }, }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) if err != nil { @@ -270,7 +270,7 @@ func TestReactStream(t *testing.T) { out, err = a.Stream(ctx, []*schema.Message{ { Role: schema.User, - Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!", + Content: "Use greet tool to continuously say hello until you get a bye response, greet names in the following order: max, bob, alice, john, marry, joe, ken, lily, please start directly! please start directly! please start directly!", }, }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) if err != nil { @@ -306,7 +306,7 @@ func TestReactStream(t *testing.T) { out, err = a.Stream(ctx, []*schema.Message{ { Role: schema.User, - Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!", + Content: "Use greet tool to continuously say hello until you get a bye response, greet names in the following order: max, bob, alice, john, marry, joe, ken, lily, please start directly! please start directly! please start directly!", }, }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) assert.NoError(t, err) diff --git a/go.mod b/go.mod index 28f43802..1d63f059 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/bytedance/sonic v1.13.2 github.com/eino-contrib/jsonschema v1.0.0 github.com/getkin/kin-openapi v0.118.0 + github.com/google/uuid v1.6.0 github.com/nikolalohinski/gonja v1.5.3 github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f github.com/smartystreets/goconvey v1.8.1 diff --git a/go.sum b/go.sum index e9484abc..da9fcaef 100644 --- a/go.sum +++ b/go.sum @@ -37,6 +37,8 @@ github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncV github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= diff --git a/internal/generic/generic.go b/internal/generic/generic.go index 7158b2f1..2f3293cb 100644 --- a/internal/generic/generic.go +++ b/internal/generic/generic.go @@ -79,3 +79,12 @@ func Reverse[S ~[]E, E any](s S) S { return d } + +// CopyMap copies a map to a new map. +func CopyMap[K comparable, V any](src map[K]V) map[K]V { + dst := make(map[K]V, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} diff --git a/internal/mock/adk/Agent_mock.go b/internal/mock/adk/Agent_mock.go new file mode 100644 index 00000000..9b98c2df --- /dev/null +++ b/internal/mock/adk/Agent_mock.go @@ -0,0 +1,171 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by MockGen. DO NOT EDIT. +// Source: interface.go +// +// Generated by this command: +// +// mockgen -destination ../internal/mock/adk/Agent_mock.go --package adk -source interface.go +// + +// Package adk is a generated GoMock package. +package adk + +import ( + context "context" + reflect "reflect" + + adk "github.com/cloudwego/eino/adk" + gomock "go.uber.org/mock/gomock" +) + +// MockAgent is a mock of Agent interface. +type MockAgent struct { + ctrl *gomock.Controller + recorder *MockAgentMockRecorder + isgomock struct{} +} + +// MockAgentMockRecorder is the mock recorder for MockAgent. +type MockAgentMockRecorder struct { + mock *MockAgent +} + +// NewMockAgent creates a new mock instance. +func NewMockAgent(ctrl *gomock.Controller) *MockAgent { + mock := &MockAgent{ctrl: ctrl} + mock.recorder = &MockAgentMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAgent) EXPECT() *MockAgentMockRecorder { + return m.recorder +} + +// Description mocks base method. +func (m *MockAgent) Description(ctx context.Context) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Description", ctx) + ret0, _ := ret[0].(string) + return ret0 +} + +// Description indicates an expected call of Description. +func (mr *MockAgentMockRecorder) Description(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Description", reflect.TypeOf((*MockAgent)(nil).Description), ctx) +} + +// Name mocks base method. +func (m *MockAgent) Name(ctx context.Context) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Name", ctx) + ret0, _ := ret[0].(string) + return ret0 +} + +// Name indicates an expected call of Name. +func (mr *MockAgentMockRecorder) Name(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockAgent)(nil).Name), ctx) +} + +// Run mocks base method. +func (m *MockAgent) Run(ctx context.Context, input *adk.AgentInput, options ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] { + m.ctrl.T.Helper() + varargs := []any{ctx, input} + for _, a := range options { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Run", varargs...) + ret0, _ := ret[0].(*adk.AsyncIterator[*adk.AgentEvent]) + return ret0 +} + +// Run indicates an expected call of Run. +func (mr *MockAgentMockRecorder) Run(ctx, input any, options ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, input}, options...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockAgent)(nil).Run), varargs...) +} + +// MockOnSubAgents is a mock of OnSubAgents interface. +type MockOnSubAgents struct { + ctrl *gomock.Controller + recorder *MockOnSubAgentsMockRecorder + isgomock struct{} +} + +// MockOnSubAgentsMockRecorder is the mock recorder for MockOnSubAgents. +type MockOnSubAgentsMockRecorder struct { + mock *MockOnSubAgents +} + +// NewMockOnSubAgents creates a new mock instance. +func NewMockOnSubAgents(ctrl *gomock.Controller) *MockOnSubAgents { + mock := &MockOnSubAgents{ctrl: ctrl} + mock.recorder = &MockOnSubAgentsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOnSubAgents) EXPECT() *MockOnSubAgentsMockRecorder { + return m.recorder +} + +// OnDisallowTransferToParent mocks base method. +func (m *MockOnSubAgents) OnDisallowTransferToParent(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnDisallowTransferToParent", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// OnDisallowTransferToParent indicates an expected call of OnDisallowTransferToParent. +func (mr *MockOnSubAgentsMockRecorder) OnDisallowTransferToParent(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnDisallowTransferToParent", reflect.TypeOf((*MockOnSubAgents)(nil).OnDisallowTransferToParent), ctx) +} + +// OnSetAsSubAgent mocks base method. +func (m *MockOnSubAgents) OnSetAsSubAgent(ctx context.Context, parent adk.Agent) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnSetAsSubAgent", ctx, parent) + ret0, _ := ret[0].(error) + return ret0 +} + +// OnSetAsSubAgent indicates an expected call of OnSetAsSubAgent. +func (mr *MockOnSubAgentsMockRecorder) OnSetAsSubAgent(ctx, parent any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnSetAsSubAgent", reflect.TypeOf((*MockOnSubAgents)(nil).OnSetAsSubAgent), ctx, parent) +} + +// OnSetSubAgents mocks base method. +func (m *MockOnSubAgents) OnSetSubAgents(ctx context.Context, subAgents []adk.Agent) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnSetSubAgents", ctx, subAgents) + ret0, _ := ret[0].(error) + return ret0 +} + +// OnSetSubAgents indicates an expected call of OnSetSubAgents. +func (mr *MockOnSubAgentsMockRecorder) OnSetSubAgents(ctx, subAgents any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnSetSubAgents", reflect.TypeOf((*MockOnSubAgents)(nil).OnSetSubAgents), ctx, subAgents) +} diff --git a/schema/message.go b/schema/message.go index 57103b68..07b01fd4 100644 --- a/schema/message.go +++ b/schema/message.go @@ -18,7 +18,6 @@ package schema import ( "context" - "errors" "fmt" "io" "reflect" @@ -34,14 +33,15 @@ import ( "github.com/slongfield/pyfmt" "github.com/cloudwego/eino/internal" + "github.com/cloudwego/eino/internal/generic" ) func init() { internal.RegisterStreamChunkConcatFunc(ConcatMessages) - internal.RegisterStreamChunkConcatFunc(concatMessageArray) + internal.RegisterStreamChunkConcatFunc(ConcatMessageArray) } -func concatMessageArray(mas [][]*Message) ([]*Message, error) { +func ConcatMessageArray(mas [][]*Message) ([]*Message, error) { arrayLen := len(mas[0]) ret := make([]*Message, arrayLen) @@ -491,7 +491,7 @@ func (m *Message) String() string { sb.WriteString(m.ReasoningContent) } if len(m.ToolCalls) > 0 { - sb.WriteString(fmt.Sprintf("\ntool_calls:\n")) + sb.WriteString("\ntool_calls:\n") for _, tc := range m.ToolCalls { if tc.Index != nil { sb.WriteString(fmt.Sprintf("index[%d]:", *tc.Index)) @@ -652,6 +652,14 @@ func concatToolCalls(chunks []ToolCall) ([]ToolCall, error) { return merged, nil } +func concatExtra(extraList []map[string]any) (map[string]any, error) { + if len(extraList) == 1 { + return generic.CopyMap(extraList[0]), nil + } + + return internal.ConcatItems(extraList) +} + // ConcatMessages concat messages with the same role and name. // It will concat tool calls with the same index. // It will return an error if the messages have different roles or names. @@ -818,12 +826,15 @@ func ConcatMessages(msgs []*Message) (*Message, error) { ret.ToolCalls = merged } - extra, err := internal.ConcatItems(extraList) - if err != nil { - return nil, fmt.Errorf("failed to concat message's extra: %w", err) - } - if len(extra) > 0 { - ret.Extra = extra + if len(extraList) > 0 { + extra, err := concatExtra(extraList) + if err != nil { + return nil, fmt.Errorf("failed to concat message's extra: %w", err) + } + + if len(extra) > 0 { + ret.Extra = extra + } } return &ret, nil @@ -846,14 +857,6 @@ func ConcatMessageStream(s *StreamReader[*Message]) (*Message, error) { msgs = append(msgs, msg) } - if len(msgs) == 0 { - return nil, errors.New("no messages in stream") - } - - if len(msgs) == 1 { - return msgs[0], nil - } - return ConcatMessages(msgs) } diff --git a/schema/select.go b/schema/select.go index 9e009f95..ed22dfc2 100644 --- a/schema/select.go +++ b/schema/select.go @@ -22,10 +22,8 @@ func receiveN[T any](chosenList []int, ss []*stream[T]) (int, *streamItem[T], bo return []func(chosenList []int, ss []*stream[T]) (index int, item *streamItem[T], ok bool){ nil, func(chosenList []int, ss []*stream[T]) (int, *streamItem[T], bool) { - select { - case item, ok := <-ss[chosenList[0]].items: - return chosenList[0], &item, ok - } + item, ok := <-ss[chosenList[0]].items + return chosenList[0], &item, ok }, func(chosenList []int, ss []*stream[T]) (int, *streamItem[T], bool) { select { diff --git a/schema/stream.go b/schema/stream.go index 1bb6e598..16cd9041 100644 --- a/schema/stream.go +++ b/schema/stream.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "reflect" + "runtime" "runtime/debug" "sync" "sync/atomic" @@ -197,7 +198,7 @@ func (sr *StreamReader[T]) Recv() (T, error) { // Close safely closes the StreamReader. // It should be called only once, as multiple calls may not work as expected. // Notice: always remember to call Close() after using Recv(). -// eg. +// e.g. // // defer sr.Close() // @@ -255,6 +256,41 @@ func (sr *StreamReader[T]) Copy(n int) []*StreamReader[T] { return copyStreamReaders[T](sr, n) } +// SetAutomaticClose sets the StreamReader to automatically close when it's no longer reachable and ready to be GCed. +// NOT concurrency safe. +func (sr *StreamReader[T]) SetAutomaticClose() { + switch sr.typ { + case readerTypeStream: + if !sr.st.automaticClose { + sr.st.automaticClose = true + var flag uint32 + sr.st.closedFlag = &flag + runtime.SetFinalizer(sr, func(s *StreamReader[T]) { + s.Close() + }) + } + case readerTypeMultiStream: + for _, s := range sr.msr.nonClosedStreams() { + if !s.automaticClose { + s.automaticClose = true + var flag uint32 + s.closedFlag = &flag + runtime.SetFinalizer(s, func(st *stream[T]) { + st.closeRecv() + }) + } + } + case readerTypeChild: + parent := sr.csr.parent.sr + parent.SetAutomaticClose() + case readerTypeWithConvert: + sr.srw.sr.SetAutomaticClose() + case readerTypeArray: + // no need to clean up + default: + } +} + func (sr *StreamReader[T]) recvAny() (any, error) { return sr.Recv() } @@ -312,6 +348,7 @@ type iStreamReader interface { recvAny() (any, error) copyAny(int) []iStreamReader Close() + SetAutomaticClose() } // stream is a channel-based stream with 1 sender and 1 receiver. @@ -321,6 +358,9 @@ type stream[T any] struct { items chan streamItem[T] closed chan struct{} + + automaticClose bool + closedFlag *uint32 // 0 = not closed, 1 = closed, only used when automaticClose is set } type streamItem[T any] struct { @@ -372,6 +412,13 @@ func (s *stream[T]) closeSend() { } func (s *stream[T]) closeRecv() { + if s.automaticClose { + if atomic.CompareAndSwapUint32(s.closedFlag, 0, 1) { + close(s.closed) + } + return + } + close(s.closed) } diff --git a/schema/stream_copy_external_test.go b/schema/stream_copy_external_test.go index 9147f90c..fc9cc475 100644 --- a/schema/stream_copy_external_test.go +++ b/schema/stream_copy_external_test.go @@ -144,9 +144,7 @@ func TestCopyDelay(t *testing.T) { wg.Wait() infos := make([]info, 0) for _, infoL := range infoList { - for _, info := range infoL { - infos = append(infos, info) - } + infos = append(infos, infoL...) } sort.Slice(infos, func(i, j int) bool { return infos[i].ts < infos[j].ts diff --git a/schema/stream_test.go b/schema/stream_test.go index ab01c7ad..e7020bda 100644 --- a/schema/stream_test.go +++ b/schema/stream_test.go @@ -287,7 +287,9 @@ func TestNewStreamCopy(t *testing.T) { wgEven := sync.WaitGroup{} wgEven.Add(m / 2) - copies := s.asReader().Copy(m) + sr := s.asReader() + sr.SetAutomaticClose() + copies := sr.Copy(m) for i := 0; i < m; i++ { idx := i go func() { @@ -341,8 +343,9 @@ func TestNewStreamCopy(t *testing.T) { copies := s.asReader().Copy(m) for i := 0; i < m; i++ { idx := i + cp := copies[idx] + cp.SetAutomaticClose() go func() { - cp := copies[idx] l := 0 defer func() { wg.Done() @@ -477,13 +480,12 @@ func TestStreamReaderWithConvert(t *testing.T) { } sta := StreamReaderWithConvert[int, int](s.asReader(), convA) + sta.SetAutomaticClose() s.send(1, nil) s.send(2, nil) s.closeSend() - defer sta.Close() - for { item, err := sta.Recv() if err != nil { @@ -518,6 +520,7 @@ func TestArrayStreamCombined(t *testing.T) { s.closeSend() nSR := MergeStreamReaders([]*StreamReader[int]{asr, s.asReader()}) + nSR.SetAutomaticClose() record := make([]bool, 6) for i := 0; i < 6; i++ { @@ -588,7 +591,7 @@ func TestMergeNamedStreamReaders(t *testing.T) { "stream2": sr2, } mergedSR := MergeNamedStreamReaders(namedStreams) - defer mergedSR.Close() + mergedSR.SetAutomaticClose() // Send data to the first stream and close it immediately go func() { @@ -688,7 +691,7 @@ func TestMergeNamedStreamReaders(t *testing.T) { "data": sr2, } mergedSR := MergeNamedStreamReaders(namedStreams) - defer mergedSR.Close() + mergedSR.SetAutomaticClose() // Send data to the second stream go func() { @@ -750,7 +753,7 @@ func TestMergeNamedStreamReaders(t *testing.T) { "stream3": sr3, } mergedSR := MergeNamedStreamReaders(namedStreams) - defer mergedSR.Close() + mergedSR.SetAutomaticClose() // Send data and close streams in sequence go func() {