diff --git a/adk/agent_middleware.go b/adk/agent_middleware.go
new file mode 100644
index 00000000..302dec2a
--- /dev/null
+++ b/adk/agent_middleware.go
@@ -0,0 +1,218 @@
+/*
+ * Copyright 2025 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package adk
+
+import (
+	"context"
+
+	"github.com/cloudwego/eino/components/tool"
+	"github.com/cloudwego/eino/compose"
+)
+
+// AgentMiddleware provides hooks to customize agent behavior at various stages of execution.
+type AgentMiddleware struct {
+	// Name identifies the middleware and is empty by default.
+	// When the BeforeAgent/OnEvents chain is built, a middleware whose non-empty Name
+	// was already seen is skipped, so the first occurrence wins.
+	Name string
+
+	// AdditionalInstruction adds supplementary text to the agent's system instruction.
+	// This instruction is concatenated with the base instruction before each chat model call.
+	AdditionalInstruction string
+
+	// AdditionalTools adds supplementary tools to the agent's available toolset.
+	// These tools are combined with the tools configured for the agent.
+	AdditionalTools []tool.BaseTool
+
+	// BeforeChatModel is called before each ChatModel invocation, allowing modification of the agent state.
+	BeforeChatModel func(context.Context, *ChatModelAgentState) error
+
+	// AfterChatModel is called after each ChatModel invocation, allowing modification of the agent state.
+	AfterChatModel func(context.Context, *ChatModelAgentState) error
+
+	// WrapToolCall wraps tool calls with custom middleware logic.
+	// Each middleware contains Invokable and/or Streamable functions for tool calls.
+	WrapToolCall compose.ToolMiddleware
+
+	// BeforeAgent is called before the agent starts executing. It allows modifying the context
+	// or performing any setup actions before the agent begins processing.
+	// A nil nextContext is ignored: the context that was passed in stays in effect.
+	// When an error is returned:
+	// 1. The framework will immediately return an AsyncIterator containing only this error
+	// 2. The BeforeAgent hooks of the remaining middlewares will be skipped
+	// 3. The OnEvents handlers in previously executed middlewares will be invoked
+	BeforeAgent func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error)
+
+	// OnEvents is called to handle events generated by the agent during execution.
+	// - iter: The iterator contains the original output from the agent or the processed output from the previous middlewares.
+	// - gen: The generator is used to send the processed events to the next middleware or directly as output.
+	// This allows for filtering, transforming, or adding events in the middleware chain.
+	OnEvents func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent])
+}
+
+// AgentMiddlewareChecker is an interface that agents can implement to indicate
+// whether they support and enable middleware functionality.
+// Agents implementing this interface will execute middlewares internally;
+// otherwise, middlewares will be executed outside the agent by Runner.
+type AgentMiddlewareChecker interface {
+	IsAgentMiddlewareEnabled() bool
+}
+
+// ChatModelAgentState represents the state of a chat model agent during conversation.
+type ChatModelAgentState struct {
+	// Messages contains all messages in the current conversation session.
+	Messages []Message
+}
+
+// InvocationType tells a middleware whether the agent is being run from scratch or resumed.
+type InvocationType string
+
+const (
+	// InvocationTypeRun indicates the agent is starting a new execution from scratch.
+	InvocationTypeRun InvocationType = "Run"
+	// InvocationTypeResume indicates the agent is resuming a previously interrupted execution.
+	InvocationTypeResume InvocationType = "Resume"
+)
+
+// AgentContext contains the context information for an agent's execution.
+// It provides access to input data, resume information, and execution options.
+type AgentContext struct {
+	// AgentInput contains the input data for the agent's execution.
+	AgentInput *AgentInput
+	// ResumeInfo contains information needed to resume a previously interrupted execution.
+	ResumeInfo *ResumeInfo
+	// AgentRunOptions contains options for configuring the agent's execution.
+	AgentRunOptions []AgentRunOption
+
+	// internal properties, read only
+	agentName      string
+	invocationType InvocationType
+}
+
+// AgentName returns the agent name recorded in this context.
+func (a *AgentContext) AgentName() string {
+	return a.agentName
+}
+
+// InvocationType returns the invocation type recorded in this context.
+func (a *AgentContext) InvocationType() InvocationType {
+	return a.invocationType
+}
+
+// isAgentMiddlewareEnabled reports whether a implements AgentMiddlewareChecker
+// and declares that middleware execution is enabled.
+func isAgentMiddlewareEnabled(a Agent) bool {
+	c, ok := a.(AgentMiddlewareChecker)
+	return ok && c.IsAgentMiddlewareEnabled()
+}
+
+// newAgentMWHelper collects the BeforeAgent and OnEvents hooks of mws in order.
+// A middleware whose non-empty Name was already seen is skipped; unnamed
+// middlewares are never deduplicated. Nil hooks are kept so that both slices
+// stay index-aligned, which execBeforeAgents relies on.
+func newAgentMWHelper(mws ...AgentMiddleware) *agentMWHelper {
+	helper := &agentMWHelper{}
+	seen := make(map[string]struct{}, len(mws))
+	for _, mw := range mws {
+		if mw.Name != "" {
+			if _, dup := seen[mw.Name]; dup {
+				continue
+			}
+			seen[mw.Name] = struct{}{}
+		}
+		helper.beforeAgentFns = append(helper.beforeAgentFns, mw.BeforeAgent)
+		helper.onEventsFns = append(helper.onEventsFns, mw.OnEvents)
+	}
+	return helper
+}
+
+// agentMWHelper holds the BeforeAgent and OnEvents hooks of a middleware chain.
+// Both slices are index-aligned: entry i of each belongs to the same middleware.
+type agentMWHelper struct {
+	beforeAgentFns []func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error)
+	onEventsFns    []func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent])
+}
+
+// execBeforeAgents runs the BeforeAgent hooks in registration order, threading the
+// context returned by each hook into the next one.
+//
+// On success it returns the final context and a nil iterator. When a hook fails it
+// returns an iterator that carries the error event, already routed through the
+// OnEvents handlers of the middlewares registered before the failing one.
+func (a *agentMWHelper) execBeforeAgents(ctx context.Context, ac *AgentContext) (context.Context, *AsyncIterator[*AgentEvent]) {
+	for i, beforeAgent := range a.beforeAgentFns {
+		if beforeAgent == nil {
+			continue
+		}
+		nextCtx, err := beforeAgent(ctx, ac)
+		// A hook may hand back a nil context, most commonly next to an error.
+		// Keep the last valid context so later handlers never receive nil.
+		if nextCtx != nil {
+			ctx = nextCtx
+		}
+		if err != nil {
+			iter, gen := NewAsyncIteratorPair[*AgentEvent]()
+			gen.Send(&AgentEvent{Err: err})
+			gen.Close()
+			// i-1: only middlewares registered before the failing one see the error event.
+			return ctx, a.execOnEventsFromIndex(ctx, ac, i-1, iter)
+		}
+	}
+	return ctx, nil
+}
+
+// execOnEvents routes iter through every OnEvents handler, starting with the last registered one.
+func (a *agentMWHelper) execOnEvents(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent]) *AsyncIterator[*AgentEvent] {
+	return a.execOnEventsFromIndex(ctx, ac, len(a.onEventsFns)-1, iter)
+}
+
+// execOnEventsFromIndex chains the OnEvents handlers from fromIdx down to 0.
+// Each handler reads from the previous stage's iterator and writes to a fresh
+// generator whose iterator feeds the next stage; the last iterator is returned.
+// Nil handlers are skipped, and a negative fromIdx returns iter unchanged.
+//
+// NOTE(review): handlers are invoked synchronously here, so a handler that drains iter
+// before returning presumably blocks this call until upstream closes — confirm callers expect that.
+func (a *agentMWHelper) execOnEventsFromIndex(ctx context.Context, ac *AgentContext, fromIdx int, iter *AsyncIterator[*AgentEvent]) *AsyncIterator[*AgentEvent] {
+	for idx := fromIdx; idx >= 0; idx-- {
+		onEvents := a.onEventsFns[idx]
+		if onEvents == nil {
+			continue
+		}
+		next, gen := NewAsyncIteratorPair[*AgentEvent]()
+		onEvents(ctx, ac, iter, gen)
+		iter = next
+	}
+	return iter
+}
+
+// globalAgentMiddlewares is the process-wide middleware registry; no lock guards it.
+var globalAgentMiddlewares []AgentMiddleware
+
+// AppendGlobalAgentMiddlewares is used to add global Agent middlewares.
+// These middlewares execute at the outermost layer of every Agent (following the "onion model" pattern).
+// It performs no synchronization, so it is not safe to call concurrently with other registry access.
+func AppendGlobalAgentMiddlewares(mws ...AgentMiddleware) {
+	globalAgentMiddlewares = append(globalAgentMiddlewares, mws...)
+}
+
+// GetGlobalAgentMiddlewares is used to retrieve global Agent middlewares.
+// This method is typically employed by custom Agent that has implemented the AgentMiddlewareChecker interface.
+// The returned slice is a copy, so appending to it cannot alter the global registry.
+func GetGlobalAgentMiddlewares() []AgentMiddleware {
+	return append([]AgentMiddleware(nil), globalAgentMiddlewares...)
+}
diff --git a/adk/agent_middleware_test.go b/adk/agent_middleware_test.go
new file mode 100644
index 00000000..5c15286d
--- /dev/null
+++ b/adk/agent_middleware_test.go
@@ -0,0 +1,644 @@
+/*
+ * Copyright 2025 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package adk
+
+import (
+	"context"
+	"fmt"
+	"sync/atomic"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"go.uber.org/mock/gomock"
+
+	"github.com/cloudwego/eino/components/tool"
+	"github.com/cloudwego/eino/components/tool/utils"
+	"github.com/cloudwego/eino/compose"
+	mockModel "github.com/cloudwego/eino/internal/mock/components/model"
+	"github.com/cloudwego/eino/schema"
+)
+
+func TestChatModelAgentWithMWCallbacks(t *testing.T) {
+	ctx := context.Background()
+	ctrl := gomock.NewController(t)
+	cm := mockModel.NewMockToolCallingChatModel(ctrl)
+
+	t.Run("run", func(t *testing.T) {
+		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
+			Return(schema.AssistantMessage("Hello", nil), nil).
+			Times(1)
+		mw := AgentMiddleware{
+			BeforeAgent: func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) {
+				ac.AgentInput.Messages[0].Content = "bye"
+				ac.AgentRunOptions = append(ac.AgentRunOptions, WithSessionValues(map[string]any{"sv": 1}))
+				return context.WithValue(ctx, "mw1", 1), nil
+			},
+			OnEvents: func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
+				assert.Equal(t, "bye", ac.AgentInput.Messages[0].Content)
+				assert.Equal(t, 1, len(ac.AgentRunOptions))
+				defer gen.Close()
+				for {
+					event, ok := iter.Next()
+					if !ok {
+						break
+					}
+					gen.Send(event)
+				}
+				gen.Send(&AgentEvent{
+					Output: &AgentOutput{
+						MessageOutput: &MessageVariant{
+							IsStreaming: false,
+							Message:     schema.AssistantMessage("bye from OnEvents", nil),
+							Role:        schema.Assistant,
+						},
+					},
+				})
+			},
+		}
+
+		a, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+			Name:        "a",
+			Description: "b",
+			Model:       cm,
+			Middlewares: []AgentMiddleware{mw},
+		})
+		assert.NoError(t, err)
+		iter := a.Run(ctx, &AgentInput{Messages: []Message{{Content: "Hi"}}})
+		event, ok := iter.Next()
+		assert.True(t, ok)
+		assert.Equal(t, "Hello", event.Output.MessageOutput.Message.Content)
+		event, ok = iter.Next()
+		assert.True(t, ok)
+		assert.Equal(t, "bye from OnEvents", event.Output.MessageOutput.Message.Content)
+		event, ok = iter.Next()
+		assert.False(t, ok)
+		assert.Nil(t, event)
+	})
+
+	t.Run("run with partly callbacks", func(t *testing.T) {
+		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
+			Return(schema.AssistantMessage("Hello", nil), nil).
+			Times(1)
+		mw1 := AgentMiddleware{
+			BeforeAgent: func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) {
+				ac.AgentInput.Messages[0].Content = "bye"
+				ac.AgentRunOptions = append(ac.AgentRunOptions, WithSessionValues(map[string]any{"sv": 1}))
+				return context.WithValue(ctx, "mw1", 1), nil
+			},
+		}
+		mw2 := AgentMiddleware{
+			OnEvents: func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
+				assert.Equal(t, "bye", ac.AgentInput.Messages[0].Content)
+				assert.Equal(t, 1, len(ac.AgentRunOptions))
+				defer gen.Close()
+				for {
+					event, ok := iter.Next()
+					if !ok {
+						break
+					}
+					gen.Send(event)
+				}
+				gen.Send(&AgentEvent{
+					Output: &AgentOutput{
+						MessageOutput: &MessageVariant{
+							IsStreaming: false,
+							Message:     schema.AssistantMessage("bye from OnEvents", nil),
+							Role:        schema.Assistant,
+						},
+					},
+				})
+			},
+		}
+		a, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+			Name:        "a",
+			Description: "b",
+			Model:       cm,
+			Middlewares: []AgentMiddleware{mw1, mw2},
+		})
+		assert.NoError(t, err)
+		iter := a.Run(ctx, &AgentInput{Messages: []Message{{Content: "Hi"}}})
+		event, ok := iter.Next()
+		assert.True(t, ok)
+		assert.Equal(t, "Hello", event.Output.MessageOutput.Message.Content)
+		event, ok = iter.Next()
+		assert.True(t, ok)
+		assert.Equal(t, "bye from OnEvents", event.Output.MessageOutput.Message.Content)
+		event, ok = iter.Next()
+		assert.False(t, ok)
+		assert.Nil(t, event)
+	})
+
+	t.Run("run with BeforeAgent interrupt", func(t *testing.T) {
+		mw1 := AgentMiddleware{
+			BeforeAgent: func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) {
+				ac.AgentInput.Messages[0].Content = "bye"
+				ac.AgentRunOptions = append(ac.AgentRunOptions, WithSessionValues(map[string]any{"sv": 1}))
+				return context.WithValue(ctx, "mw1", 1), nil
+			},
+			OnEvents: func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
+				assert.Equal(t, "bye", ac.AgentInput.Messages[0].Content)
+				assert.Equal(t, 1, len(ac.AgentRunOptions))
+				defer gen.Close()
+				gen.Send(&AgentEvent{
+					Output: &AgentOutput{
+						MessageOutput: &MessageVariant{
+							IsStreaming: false,
+							Message:     schema.AssistantMessage("bye from OnEvents", nil),
+							Role:        schema.Assistant,
+						},
+					},
+				})
+				for {
+					event, ok := iter.Next()
+					if !ok {
+						break
+					}
+					gen.Send(event)
+				}
+			},
+		}
+		mw2 := AgentMiddleware{
+			BeforeAgent: func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) {
+				return ctx, fmt.Errorf("mock err")
+			},
+		}
+		mw3 := AgentMiddleware{
+			BeforeAgent: func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) {
+				ac.AgentInput.Messages[0].Content = "invalid change"
+				return ctx, nil
+			},
+		}
+
+		a, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+			Name:        "a",
+			Description: "b",
+			Model:       cm,
+			Middlewares: []AgentMiddleware{mw1, mw2, mw3},
+		})
+		assert.NoError(t, err)
+		iter := a.Run(ctx, &AgentInput{Messages: []Message{{Content: "Hi"}}})
+		event, ok := iter.Next()
+		assert.True(t, ok)
+		assert.Equal(t, "bye from OnEvents", event.Output.MessageOutput.Message.Content)
+		event, ok = iter.Next()
+		assert.True(t, ok)
+		assert.ErrorContains(t, event.Err, "mock err") // check the event's error, not a freshly built one
+		event, ok = iter.Next()
+		assert.False(t, ok)
+		assert.Nil(t, event)
+	})
+}
+
+func TestRunnerWithMWCallbacks(t *testing.T) {
+	ctx := context.Background()
+	ctrl := gomock.NewController(t)
+	cm := mockModel.NewMockToolCallingChatModel(ctrl)
+
+	t.Run("run with ChatModelAgent", func(t *testing.T) {
+		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
+			Return(schema.AssistantMessage("Hello", nil), nil).
+			Times(1)
+		mw1 := AgentMiddleware{
+			OnEvents: func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
+				assert.Equal(t, "bye", ac.AgentInput.Messages[0].Content)
+				assert.Equal(t, 1, len(ac.AgentRunOptions))
+				defer gen.Close()
+				for {
+					event, ok := iter.Next()
+					if !ok {
+						break
+					}
+					gen.Send(event)
+				}
+				gen.Send(&AgentEvent{
+					Output: &AgentOutput{
+						MessageOutput: &MessageVariant{
+							IsStreaming: false,
+							Message:     schema.AssistantMessage("bye from OnEvents", nil),
+							Role:        schema.Assistant,
+						},
+					},
+				})
+			},
+		}
+		mw2 := AgentMiddleware{
+			BeforeAgent: func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) {
+				ac.AgentInput.Messages[0].Content = "bye"
+				ac.AgentRunOptions = append(ac.AgentRunOptions, WithSessionValues(map[string]any{"sv": 1}))
+				return context.WithValue(ctx, "mw1", 1), nil
+			},
+		}
+
+		globalAgentMiddlewares = []AgentMiddleware{mw2}
+		defer func() {
+			globalAgentMiddlewares = nil
+		}()
+
+		a, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+			Name:        "a",
+			Description: "b",
+			Model:       cm,
+			Middlewares: []AgentMiddleware{mw1},
+		})
+		assert.NoError(t, err)
+
+		runner := NewRunner(ctx, RunnerConfig{
+			Agent:           a,
+			EnableStreaming: false,
+			CheckPointStore: nil,
+		})
+
+		iter := runner.Run(ctx, []Message{{Content: "Hi"}})
+		event, ok := iter.Next()
+		assert.True(t, ok)
+		assert.Equal(t, "Hello", event.Output.MessageOutput.Message.Content)
+		event, ok = iter.Next()
+		assert.True(t, ok)
+		assert.Equal(t, "bye from OnEvents", event.Output.MessageOutput.Message.Content)
+		event, ok = iter.Next()
+		assert.False(t, ok)
+		assert.Nil(t, event)
+	})
+
+	t.Run("run with ChatModelAgent BeforeAgent interrupt", func(t *testing.T) {
+		mw1 := AgentMiddleware{
+			BeforeAgent: func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) {
+				ac.AgentInput.Messages[0].Content = "bye"
+				ac.AgentRunOptions = append(ac.AgentRunOptions, WithSessionValues(map[string]any{"sv": 1}))
+				return context.WithValue(ctx, "mw1", 1), nil
+			},
+			OnEvents: func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
+				assert.Equal(t, "bye", ac.AgentInput.Messages[0].Content)
+				assert.Equal(t, 1, len(ac.AgentRunOptions))
+				defer gen.Close()
+				gen.Send(&AgentEvent{
+					Output: &AgentOutput{
+						MessageOutput: &MessageVariant{
+							IsStreaming: false,
+							Message:     schema.AssistantMessage("bye from OnEvents", nil),
+							Role:        schema.Assistant,
+						},
+					},
+				})
+				for {
+					event, ok := iter.Next()
+					if !ok {
+						break
+					}
+					gen.Send(event)
+				}
+			},
+		}
+		mw2 := AgentMiddleware{
+			BeforeAgent: func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) {
+				return ctx, fmt.Errorf("mock err")
+			},
+		}
+
+		a, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+			Name:        "a",
+			Description: "b",
+			Model:       cm,
+			Middlewares: []AgentMiddleware{mw1},
+		})
+		assert.NoError(t, err)
+
+		globalAgentMiddlewares = []AgentMiddleware{mw2}
+		defer func() {
+			globalAgentMiddlewares = nil
+		}()
+
+		runner := NewRunner(ctx, RunnerConfig{
+			Agent: a,
+		})
+
+		iter := runner.Run(ctx, []Message{{Content: "Hi"}})
+		event, ok := iter.Next()
+		assert.True(t, ok)
+		assert.ErrorContains(t, event.Err, "mock err") // check the event's error, not a freshly built one
+		event, ok = iter.Next()
+		assert.False(t, ok)
+		assert.Nil(t, event)
+	})
+
+	t.Run("run with Workflow Agents", func(t *testing.T) {
+		var (
+			cmBeforeAgentCounter, cmOnEventsCounter         int32
+			runnerBeforeAgentCounter, runnerOnEventsCounter int32
+		)
+		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
+			Return(schema.AssistantMessage("Hello", nil), nil).
+			Times(5)
+
+		mwChatModelAgent := AgentMiddleware{
+			BeforeAgent: func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) {
+				atomic.AddInt32(&cmBeforeAgentCounter, 1)
+				return ctx, nil
+			},
+			OnEvents: func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
+				atomic.AddInt32(&cmOnEventsCounter, 1)
+				defer gen.Close()
+				for {
+					event, ok := iter.Next()
+					if !ok {
+						break
+					}
+					gen.Send(event)
+				}
+			},
+		}
+
+		mwRunner := AgentMiddleware{
+			BeforeAgent: func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) {
+				atomic.AddInt32(&runnerBeforeAgentCounter, 1)
+				return ctx, nil
+			},
+			OnEvents: func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
+				atomic.AddInt32(&runnerOnEventsCounter, 1)
+				BypassIterator(iter, gen)
+			},
+		}
+
+		newChatModelAgent := func() func() Agent {
+			nameCounter := 0
+			return func() Agent {
+				a, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+					Name:        fmt.Sprintf("chat_model_agent_%d", nameCounter),
+					Description: "b",
+					Model:       cm,
+					Middlewares: []AgentMiddleware{mwChatModelAgent},
+				})
+				assert.NoError(t, err)
+				nameCounter++
+				return a
+			}
+		}()
+
+		loop, err := NewLoopAgent(ctx, &LoopAgentConfig{
+			Name:          "loop",
+			SubAgents:     []Agent{newChatModelAgent(), newChatModelAgent()},
+			MaxIterations: 1,
+		})
+		assert.NoError(t, err)
+
+		parallel, err := NewParallelAgent(ctx, &ParallelAgentConfig{
+			Name:      "parallel",
+			SubAgents: []Agent{newChatModelAgent(), newChatModelAgent()},
+		})
+		assert.NoError(t, err)
+
+		seq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{
+			Name:      "seq",
+			SubAgents: []Agent{newChatModelAgent(), loop, parallel},
+		})
+		assert.NoError(t, err)
+
+		globalAgentMiddlewares = []AgentMiddleware{mwRunner}
+		defer func() {
+			globalAgentMiddlewares = nil
+		}()
+
+		runner := NewRunner(ctx, RunnerConfig{
+			Agent: seq,
+		})
+
+		iter := runner.Run(ctx, []Message{{Content: "work work"}})
+		for {
+			_, ok := iter.Next()
+			if !ok {
+				break
+			}
+		}
+		assert.Equal(t, int32(5), cmBeforeAgentCounter)
+		assert.Equal(t, int32(5), cmOnEventsCounter)
+		assert.Equal(t, int32(8), runnerBeforeAgentCounter)
+		assert.Equal(t, int32(8), runnerOnEventsCounter)
+	})
+
+	t.Run("run with Workflow Agents BeforeAgent interrupt", func(t *testing.T) {
+		var (
+			cmBeforeAgentCounter, cmOnEventsCounter         int32
+			runnerBeforeAgentCounter, runnerOnEventsCounter int32
+		)
+		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
+			Return(schema.AssistantMessage("Hello", nil), nil).
+			Times(1)
+
+		mwChatModelAgent := AgentMiddleware{
+			BeforeAgent: func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) {
+				atomic.AddInt32(&cmBeforeAgentCounter, 1)
+				ac.AgentInput.Messages[0].Content = "bye"
+				return ctx, nil
+			},
+			OnEvents: func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
+				atomic.AddInt32(&cmOnEventsCounter, 1)
+				assert.Equal(t, "bye", ac.AgentInput.Messages[0].Content)
+				defer gen.Close()
+				for {
+					event, ok := iter.Next()
+					if !ok {
+						break
+					}
+					gen.Send(event)
+				}
+			},
+		}
+
+		mwRunner := AgentMiddleware{
+			BeforeAgent: func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) {
+				v := atomic.AddInt32(&runnerBeforeAgentCounter, 1)
+				if v == 3 {
+					return ctx, fmt.Errorf("interrupt")
+				}
+				return ctx, nil
+			},
+			OnEvents: func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
+				atomic.AddInt32(&runnerOnEventsCounter, 1)
+				BypassIterator(iter, gen)
+			},
+		}
+
+		newChatModelAgent := func() func() Agent {
+			nameCounter := 0
+			return func() Agent {
+				a, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+					Name:        fmt.Sprintf("chat_model_agent_%d", nameCounter),
+					Description: "b",
+					Model:       cm,
+					Middlewares: []AgentMiddleware{mwChatModelAgent},
+				})
+				assert.NoError(t, err)
+				nameCounter++
+				return a
+			}
+		}()
+
+		loop, err := NewLoopAgent(ctx, &LoopAgentConfig{
+			Name:          "loop",
+			SubAgents:     []Agent{newChatModelAgent(), newChatModelAgent()},
+			MaxIterations: 1,
+		})
+		assert.NoError(t, err)
+
+		parallel, err := NewParallelAgent(ctx, &ParallelAgentConfig{
+			Name:      "parallel",
+			SubAgents: []Agent{newChatModelAgent(), newChatModelAgent()},
+		})
+		assert.NoError(t, err)
+
+		seq, err := NewSequentialAgent(ctx, &SequentialAgentConfig{
+			Name:      "seq",
+			SubAgents: []Agent{newChatModelAgent(), loop, parallel},
+		})
+		assert.NoError(t, err)
+
+		globalAgentMiddlewares = []AgentMiddleware{mwRunner}
+		defer func() {
+			globalAgentMiddlewares = nil
+		}()
+
+		runner := NewRunner(ctx, RunnerConfig{
+			Agent: seq,
+		})
+
+		iter := runner.Run(ctx, []Message{{Content: "work work"}})
+		for {
+			_, ok := iter.Next()
+			if !ok {
+				break
+			}
+		}
+		assert.Equal(t, int32(1), cmBeforeAgentCounter)
+		assert.Equal(t, int32(1), cmOnEventsCounter)
+		assert.Equal(t, int32(3), runnerBeforeAgentCounter)
+		assert.Equal(t, int32(2), runnerOnEventsCounter)
+	})
+
+	t.Run("resume with ChatModelAgent", func(t *testing.T) {
+		cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).Times(1)
+		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
+			Return(schema.AssistantMessage("Hello", []schema.ToolCall{
+				{
+					ID:   "1",
+					Type: "function",
+					Function: schema.FunctionCall{
+						Name:      "test_tool",
+						Arguments: "{\"input\":123}",
+					},
+					Extra: nil,
+				},
+			}), nil).
+			Times(1)
+
+		cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
+			Return(schema.AssistantMessage("ok", nil), nil).
+			Times(1)
+
+		type toolInput struct {
+			Input int `json:"input"`
+		}
+		type toolOutput struct {
+			Output string `json:"output"`
+		}
+		type interruptOptions struct {
+			NewInput *string
+		}
+		withOptions := func(newInput string) tool.Option {
+			return tool.WrapImplSpecificOptFn(func(t *interruptOptions) {
+				t.NewInput = &newInput
+			})
+		}
+		mockTool, err := utils.InferOptionableTool(
+			"test_tool",
+			"test",
+			func(ctx context.Context, input *toolInput, opts ...tool.Option) (output *toolOutput, err error) {
+				o := tool.GetImplSpecificOptions[interruptOptions](nil, opts...)
+				if o.NewInput == nil {
+					return nil, compose.Interrupt(ctx, input.Input)
+				}
+				return &toolOutput{Output: "from interrupt:" + *o.NewInput}, nil
+			},
+		)
+		assert.NoError(t, err)
+
+		mw1 := AgentMiddleware{
+			OnEvents: func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
+				BypassIterator(iter, gen)
+			},
+		}
+		mw2 := AgentMiddleware{
+			BeforeAgent: func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) {
+				if ac.InvocationType() == InvocationTypeRun {
+					assert.Equal(t, "Hi", ac.AgentInput.Messages[0].Content)
+				} else if ac.InvocationType() == InvocationTypeResume {
+					assert.NotNil(t, ac.ResumeInfo)
+				} else {
+					assert.Fail(t, "invalid invocationType")
+				}
+				return ctx, nil
+			},
+		}
+
+		a, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
+			Name:        "a",
+			Description: "b",
+			Model:       cm,
+			ToolsConfig: ToolsConfig{
+				ToolsNodeConfig: compose.ToolsNodeConfig{
+					Tools: []tool.BaseTool{mockTool},
+				},
+			},
+			Middlewares: []AgentMiddleware{mw1},
+		})
+		assert.NoError(t, err)
+
+		globalAgentMiddlewares = []AgentMiddleware{mw2}
+		defer func() {
+			globalAgentMiddlewares = nil
+		}()
+
+		runner := NewRunner(ctx, RunnerConfig{
+			Agent:           a,
+			EnableStreaming: false,
+			CheckPointStore: newMyStore(),
+		})
+
+		iter := runner.Run(ctx, []Message{{Content: "Hi"}}, WithCheckPointID("1"))
+		event, ok := iter.Next()
+		assert.True(t, ok)
+		assert.Equal(t, "Hello", event.Output.MessageOutput.Message.Content)
+		assert.NotNil(t, event.Output.MessageOutput.Message.ToolCalls[0])
+		event, ok = iter.Next()
+		assert.True(t, ok)
+		assert.NotNil(t, event.Action.Interrupted)
+		event, ok = iter.Next()
+		assert.False(t, ok)
+		assert.Nil(t, event)
+
+		iter, err = runner.Resume(ctx, "1", WithToolOptions([]tool.Option{withOptions("resume_input")}))
+		assert.NoError(t, err)
+		event, ok = iter.Next()
+		assert.True(t, ok)
+		assert.Equal(t, "{\"output\":\"from interrupt:resume_input\"}", event.Output.MessageOutput.Message.Content)
+		event, ok = iter.Next()
+		assert.True(t, ok)
+		assert.Equal(t, "ok", event.Output.MessageOutput.Message.Content)
+		event, ok = iter.Next()
+		assert.False(t, ok)
+		assert.Nil(t, event)
+	})
+}
diff --git a/adk/chatmodel.go b/adk/chatmodel.go
index 21c18eed..98ee8b43 100644
--- a/adk/chatmodel.go
+++ b/adk/chatmodel.go
@@ -54,6 +54,7 @@ type chatModelAgentRunOptions struct {
 	chatModelOptions []model.Option
 	toolOptions      []tool.Option
 	agentToolOptions map[ /*tool name*/ string][]AgentRunOption // todo: map or list?
+	graphCallbacks   []callbacks.Handler
 
 	// resume
 	historyModifier func(context.Context, []Message) []Message
@@ -80,6 +81,13 @@ func WithAgentToolRunOptions(opts map[string] /*tool name*/ []AgentRunOption) Ag
 	})
 }
 
+// WithGraphCallbacks sets callback handlers for internal graph / chain execution.
+func WithGraphCallbacks(callbacks ...callbacks.Handler) AgentRunOption {
+	return WrapImplSpecificOptFn(func(t *chatModelAgentRunOptions) {
+		t.graphCallbacks = callbacks
+	})
+}
+
 // WithHistoryModifier sets a function to modify history during resume.
 // Deprecated: use ResumeWithData and ChatModelAgentResumeData instead.
func WithHistoryModifier(f func(context.Context, []Message) []Message) AgentRunOption { @@ -129,33 +137,6 @@ func defaultGenModelInput(ctx context.Context, instruction string, input *AgentI return msgs, nil } -// ChatModelAgentState represents the state of a chat model agent during conversation. -type ChatModelAgentState struct { - // Messages contains all messages in the current conversation session. - Messages []Message -} - -// AgentMiddleware provides hooks to customize agent behavior at various stages of execution. -type AgentMiddleware struct { - // AdditionalInstruction adds supplementary text to the agent's system instruction. - // This instruction is concatenated with the base instruction before each chat model call. - AdditionalInstruction string - - // AdditionalTools adds supplementary tools to the agent's available toolset. - // These tools are combined with the tools configured for the agent. - AdditionalTools []tool.BaseTool - - // BeforeChatModel is called before each ChatModel invocation, allowing modification of the agent state. - BeforeChatModel func(context.Context, *ChatModelAgentState) error - - // AfterChatModel is called after each ChatModel invocation, allowing modification of the agent state. - AfterChatModel func(context.Context, *ChatModelAgentState) error - - // WrapToolCall wraps tool calls with custom middleware logic. - // Each middleware contains Invokable and/or Streamable functions for tool calls. - WrapToolCall compose.ToolMiddleware -} - type ChatModelAgentConfig struct { // Name of the agent. Better be unique across all agents. 
Name string @@ -221,17 +202,20 @@ type ChatModelAgent struct { exit tool.BaseTool + beforeAgents []func(context.Context, *AgentContext) (context.Context, error) beforeChatModels, afterChatModels []func(context.Context, *ChatModelAgentState) error + onEvents []func(context.Context, *AgentContext, *AsyncIterator[*AgentEvent], *AsyncGenerator[*AgentEvent]) modelRetryConfig *ModelRetryConfig // runner once sync.Once run runFunc + helper *agentMWHelper frozen uint32 } -type runFunc func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, opts ...compose.Option) +type runFunc func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, callbacks []callbacks.Handler, opts ...compose.Option) // NewChatModelAgent constructs a chat model-backed agent with the provided config. func NewChatModelAgent(_ context.Context, config *ChatModelAgentConfig) (*ChatModelAgent, error) { @@ -250,39 +234,26 @@ func NewChatModelAgent(_ context.Context, config *ChatModelAgentConfig) (*ChatMo genInput = config.GenModelInput } - beforeChatModels := make([]func(context.Context, *ChatModelAgentState) error, 0) - afterChatModels := make([]func(context.Context, *ChatModelAgentState) error, 0) - sb := &strings.Builder{} - sb.WriteString(config.Instruction) - tc := config.ToolsConfig - for _, m := range config.Middlewares { - sb.WriteString("\n") - sb.WriteString(m.AdditionalInstruction) - tc.Tools = append(tc.Tools, m.AdditionalTools...) 
- - if m.WrapToolCall.Invokable != nil || m.WrapToolCall.Streamable != nil { - tc.ToolCallMiddlewares = append(tc.ToolCallMiddlewares, m.WrapToolCall) - } - if m.BeforeChatModel != nil { - beforeChatModels = append(beforeChatModels, m.BeforeChatModel) - } - if m.AfterChatModel != nil { - afterChatModels = append(afterChatModels, m.AfterChatModel) - } + mwHelper := &chatModelMWHelper{ + instruction: config.Instruction, + toolsConfig: config.ToolsConfig, } + mwHelper = mwHelper.withMWs(config.Middlewares) return &ChatModelAgent{ name: config.Name, description: config.Description, - instruction: sb.String(), + instruction: mwHelper.instruction, model: config.Model, - toolsConfig: tc, + toolsConfig: mwHelper.toolsConfig, genModelInput: genInput, exit: config.Exit, outputKey: config.OutputKey, maxIterations: config.MaxIterations, - beforeChatModels: beforeChatModels, - afterChatModels: afterChatModels, + beforeAgents: mwHelper.beforeAgents, + beforeChatModels: mwHelper.beforeChatModels, + afterChatModels: mwHelper.afterChatModels, + onEvents: mwHelper.onEvents, modelRetryConfig: config.ModelRetryConfig, }, nil } @@ -542,7 +513,7 @@ func genReactCallbacks(ctx context.Context, agentName string, generator *AsyncGenerator[*AgentEvent], enableStreaming bool, store *bridgeStore, - modelRetryConfigs *ModelRetryConfig) compose.Option { + modelRetryConfigs *ModelRetryConfig) callbacks.Handler { h := &cbHandler{ ctx: ctx, @@ -551,7 +522,8 @@ func genReactCallbacks(ctx context.Context, agentName string, agentName: agentName, store: store, enableStreaming: enableStreaming, - modelRetryConfigs: modelRetryConfigs} + modelRetryConfigs: modelRetryConfigs, + } cmHandler := &ub.ModelCallbackHandler{ OnEnd: h.onChatModelEnd, @@ -616,7 +588,7 @@ func genReactCallbacks(ctx context.Context, agentName string, cb := ub.NewHandlerHelper().ChatModel(cmHandler).ToolsNode(toolsNodeHandler).Graph(reactGraphHandler).Chain(chainHandler).Handler() - return compose.WithCallbacks(cb) + return cb } type 
noToolsCbHandler struct { @@ -657,7 +629,7 @@ func (h *noToolsCbHandler) onGraphError(ctx context.Context, return ctx } -func genNoToolsCallbacks(generator *AsyncGenerator[*AgentEvent], modelRetryConfigs *ModelRetryConfig) compose.Option { +func genNoToolsCallbacks(generator *AsyncGenerator[*AgentEvent], modelRetryConfigs *ModelRetryConfig) callbacks.Handler { h := &noToolsCbHandler{ AsyncGenerator: generator, modelRetryConfigs: modelRetryConfigs, @@ -671,7 +643,7 @@ func genNoToolsCallbacks(generator *AsyncGenerator[*AgentEvent], modelRetryConfi cb := ub.NewHandlerHelper().ChatModel(cmHandler).Chain(graphHandler).Handler() - return compose.WithCallbacks(cb) + return cb } func setOutputToSession(ctx context.Context, msg Message, msgStream MessageStream, outputKey string) error { @@ -690,7 +662,7 @@ func setOutputToSession(ctx context.Context, msg Message, msgStream MessageStrea } func errFunc(err error) runFunc { - return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, _ ...compose.Option) { + return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, _ []callbacks.Handler, _ ...compose.Option) { generator.Send(&AgentEvent{Err: err}) } } @@ -703,198 +675,215 @@ type ChatModelAgentResumeData struct { HistoryModifier func(ctx context.Context, history []Message) []Message } -func (a *ChatModelAgent) buildRunFunc(ctx context.Context) runFunc { +func (a *ChatModelAgent) buildRunFunc(ctx context.Context) (*agentMWHelper, runFunc) { a.once.Do(func() { - instruction := a.instruction - toolsNodeConf := a.toolsConfig.ToolsNodeConfig - returnDirectly := copyMap(a.toolsConfig.ReturnDirectly) + a.helper, a.run = func() (*agentMWHelper, runFunc) { + helper := &chatModelMWHelper{ + instruction: a.instruction, + toolsConfig: ToolsConfig{ + ToolsNodeConfig: a.toolsConfig.ToolsNodeConfig, + ReturnDirectly: copyMap(a.toolsConfig.ReturnDirectly), + }, + beforeChatModels: 
a.beforeChatModels, + afterChatModels: a.afterChatModels, + beforeAgents: a.beforeAgents, + onEvents: a.onEvents, + } - transferToAgents := a.subAgents - if a.parentAgent != nil && !a.disallowTransferToParent { - transferToAgents = append(transferToAgents, a.parentAgent) - } + if mws := GetGlobalAgentMiddlewares(); len(mws) > 0 { + helper = helper.withMWs(mws) + } - if len(transferToAgents) > 0 { - transferInstruction := genTransferToAgentInstruction(ctx, transferToAgents) - instruction = concatInstructions(instruction, transferInstruction) + transferToAgents := a.subAgents + if a.parentAgent != nil && !a.disallowTransferToParent { + transferToAgents = append(transferToAgents, a.parentAgent) + } - toolsNodeConf.Tools = append(toolsNodeConf.Tools, &transferToAgent{}) - returnDirectly[TransferToAgentToolName] = true - } + if len(transferToAgents) > 0 { + helper = helper.withTransferToAgents(ctx, transferToAgents) + } - if a.exit != nil { - toolsNodeConf.Tools = append(toolsNodeConf.Tools, a.exit) - exitInfo, err := a.exit.Info(ctx) - if err != nil { - a.run = errFunc(err) - return + if a.exit != nil { + var ef runFunc + helper, ef = helper.withExitTool(ctx, a.exit) + if ef != nil { + return helper.toMWHelper(), ef + } } - returnDirectly[exitInfo.Name] = true - } - if len(toolsNodeConf.Tools) == 0 { - var chatModel model.ToolCallingChatModel = a.model - if a.modelRetryConfig != nil { - chatModel = newRetryChatModel(a.model, a.modelRetryConfig) + // without tools, call chat model once + if len(helper.toolsConfig.Tools) == 0 { + return helper.toMWHelper(), a.buildSimpleChatModelChain(helper) } - a.run = func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], - store *bridgeStore, opts ...compose.Option) { - r, err := compose.NewChain[*AgentInput, Message](compose.WithGenLocalState(func(ctx context.Context) (state *ChatModelAgentState) { - return &ChatModelAgentState{} - })). 
- AppendLambda(compose.InvokableLambda(func(ctx context.Context, input *AgentInput) ([]Message, error) { - messages, err := a.genModelInput(ctx, instruction, input) - if err != nil { - return nil, err - } - return messages, nil - })). - AppendChatModel( - chatModel, - compose.WithStatePreHandler(func(ctx context.Context, in []*schema.Message, state *ChatModelAgentState) ([]*schema.Message, error) { - state.Messages = in - for _, bc := range a.beforeChatModels { - err := bc(ctx, state) - if err != nil { - return nil, err - } - } - return state.Messages, nil - }), - compose.WithStatePostHandler(func(ctx context.Context, in *schema.Message, state *ChatModelAgentState) (*schema.Message, error) { - state.Messages = append(state.Messages, in) - for _, ac := range a.afterChatModels { - err := ac(ctx, state) - if err != nil { - return nil, err - } - } - return in, nil - }), - ). - Compile(ctx, compose.WithGraphName(a.name), - compose.WithCheckPointStore(store), - compose.WithSerializer(&gobSerializer{})) - if err != nil { - generator.Send(&AgentEvent{Err: err}) - return - } + // with tools, react + return helper.toMWHelper(), a.buildReActChain(ctx, helper) + }() + }) - callOpt := genNoToolsCallbacks(generator, a.modelRetryConfig) - var runOpts []compose.Option - runOpts = append(runOpts, opts...) - runOpts = append(runOpts, callOpt) - - var msg Message - var msgStream MessageStream - if input.EnableStreaming { - msgStream, err = r.Stream(ctx, input, runOpts...) - } else { - msg, err = r.Invoke(ctx, input, runOpts...) 
- } + atomic.StoreUint32(&a.frozen, 1) + return a.helper, a.run +} - if err == nil { - if a.outputKey != "" { - err = setOutputToSession(ctx, msg, msgStream, a.outputKey) - if err != nil { - generator.Send(&AgentEvent{Err: err}) - } - } else if msgStream != nil { - msgStream.Close() - } - } +func (a *ChatModelAgent) buildSimpleChatModelChain(helper *chatModelMWHelper) runFunc { + return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, handlers []callbacks.Handler, opts ...compose.Option) { + genState := func(ctx context.Context) *State { + return &State{AgentName: a.name} + } - generator.Close() - } + chatModel := a.model + if a.modelRetryConfig != nil { + chatModel = newRetryChatModel(a.model, a.modelRetryConfig) + } - return + modelPreHandle := func(ctx context.Context, input []Message, st *State) ([]Message, error) { + s := &ChatModelAgentState{Messages: append(st.Messages, input...)} + for _, bcm := range helper.beforeChatModels { + if err := bcm(ctx, s); err != nil { + return nil, err + } + } + st.Messages = s.Messages + return st.Messages, nil } - // react - conf := &reactConfig{ - model: a.model, - toolsConfig: &toolsNodeConf, - toolsReturnDirectly: returnDirectly, - agentName: a.name, - maxIterations: a.maxIterations, - beforeChatModel: a.beforeChatModels, - afterChatModel: a.afterChatModels, - modelRetryConfig: a.modelRetryConfig, + modelPostHandle := func(ctx context.Context, input Message, st *State) (Message, error) { + s := &ChatModelAgentState{Messages: append(st.Messages, input)} + for _, acm := range helper.afterChatModels { + if err := acm(ctx, s); err != nil { + return nil, err + } + } + st.Messages = s.Messages + return input, nil } - g, err := newReact(ctx, conf) + r, err := compose.NewChain[*AgentInput, Message](compose.WithGenLocalState(genState)). 
+ AppendLambda(compose.InvokableLambda(func(ctx context.Context, input *AgentInput) ([]Message, error) { + return a.genModelInput(ctx, helper.instruction, input) + })). + AppendChatModel(chatModel, + compose.WithStatePreHandler(modelPreHandle), + compose.WithStatePostHandler(modelPostHandle), + ). + Compile(ctx, compose.WithGraphName(a.name), + compose.WithCheckPointStore(store), + compose.WithSerializer(&gobSerializer{})) if err != nil { - a.run = errFunc(err) + generator.Send(&AgentEvent{Err: err}) return } - a.run = func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, - opts ...compose.Option) { - var compileOptions []compose.GraphCompileOption - compileOptions = append(compileOptions, - compose.WithGraphName(a.name), - compose.WithCheckPointStore(store), - compose.WithSerializer(&gobSerializer{}), - // ensure the graph won't exceed max steps due to max iterations - compose.WithMaxRunSteps(math.MaxInt)) - - runnable, err_ := compose.NewChain[*AgentInput, Message](). - AppendLambda( - compose.InvokableLambda(func(ctx context.Context, input *AgentInput) ([]Message, error) { - return a.genModelInput(ctx, instruction, input) - }), - ). - AppendGraph(g, compose.WithNodeName("ReAct"), compose.WithGraphCompileOptions(compose.WithMaxRunSteps(math.MaxInt))). - Compile(ctx, compileOptions...) - if err_ != nil { - generator.Send(&AgentEvent{Err: err_}) - return - } + opts = append(opts, compose.WithCallbacks(append(handlers, genNoToolsCallbacks(generator, a.modelRetryConfig))...)) - callOpt := genReactCallbacks(ctx, a.name, generator, input.EnableStreaming, store, a.modelRetryConfig) - var runOpts []compose.Option - runOpts = append(runOpts, opts...) 
- runOpts = append(runOpts, callOpt) - if a.toolsConfig.EmitInternalEvents { - runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEventGenerator(generator)))) - } - if input.EnableStreaming { - runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEnableStreaming(true)))) - } - - var msg Message - var msgStream MessageStream - if input.EnableStreaming { - msgStream, err_ = runnable.Stream(ctx, input, runOpts...) - } else { - msg, err_ = runnable.Invoke(ctx, input, runOpts...) - } + var msg Message + var msgStream MessageStream + if input.EnableStreaming { + msgStream, err = r.Stream(ctx, input, opts...) + } else { + msg, err = r.Invoke(ctx, input, opts...) + } - if err_ == nil { - if a.outputKey != "" { - err_ = setOutputToSession(ctx, msg, msgStream, a.outputKey) - if err_ != nil { - generator.Send(&AgentEvent{Err: err_}) - } - } else if msgStream != nil { - msgStream.Close() + if err == nil { + if a.outputKey != "" { + err = setOutputToSession(ctx, msg, msgStream, a.outputKey) + if err != nil { + generator.Send(&AgentEvent{Err: err}) } + } else if msgStream != nil { + msgStream.Close() } - - generator.Close() } + + generator.Close() + } +} + +func (a *ChatModelAgent) buildReActChain(ctx context.Context, helper *chatModelMWHelper) runFunc { + g, err := newReact(ctx, &reactConfig{ + model: a.model, + toolsConfig: &helper.toolsConfig.ToolsNodeConfig, + toolsReturnDirectly: helper.toolsConfig.ReturnDirectly, + agentName: a.name, + maxIterations: a.maxIterations, + beforeChatModel: helper.beforeChatModels, + afterChatModel: helper.afterChatModels, + modelRetryConfig: a.modelRetryConfig, }) + if err != nil { + return errFunc(err) + } - atomic.StoreUint32(&a.frozen, 1) + return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore, handlers []callbacks.Handler, opts ...compose.Option) { + var compileOptions []compose.GraphCompileOption + 
compileOptions = append(compileOptions, + compose.WithGraphName(a.name), + compose.WithCheckPointStore(store), + compose.WithSerializer(&gobSerializer{}), + // ensure the graph won't exceed max steps due to max iterations + compose.WithMaxRunSteps(math.MaxInt)) + + runnable, err_ := compose.NewChain[*AgentInput, Message](). + AppendLambda( + compose.InvokableLambda(func(ctx context.Context, input *AgentInput) ([]Message, error) { + return a.genModelInput(ctx, helper.instruction, input) + }), + ). + AppendGraph(g, compose.WithNodeName("ReAct"), compose.WithGraphCompileOptions(compose.WithMaxRunSteps(math.MaxInt))). + Compile(ctx, compileOptions...) + if err_ != nil { + generator.Send(&AgentEvent{Err: err_}) + return + } + + reactCallback := genReactCallbacks(ctx, a.name, generator, input.EnableStreaming, store, a.modelRetryConfig) + opts = append(opts, compose.WithCallbacks(append(handlers, reactCallback)...)) + if a.toolsConfig.EmitInternalEvents { + opts = append(opts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEventGenerator(generator)))) + } + if input.EnableStreaming { + opts = append(opts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEnableStreaming(true)))) + } + + var msg Message + var msgStream MessageStream + if input.EnableStreaming { + msgStream, err_ = runnable.Stream(ctx, input, opts...) + } else { + msg, err_ = runnable.Invoke(ctx, input, opts...) 
+ } + + if err_ == nil { + if a.outputKey != "" { + err_ = setOutputToSession(ctx, msg, msgStream, a.outputKey) + if err_ != nil { + generator.Send(&AgentEvent{Err: err_}) + } + } else if msgStream != nil { + msgStream.Close() + } + } - return a.run + generator.Close() + } } func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - run := a.buildRunFunc(ctx) + agentContext := &AgentContext{ + AgentInput: input, + AgentRunOptions: opts, + agentName: a.name, + invocationType: InvocationTypeRun, + } + + mwHelper, run := a.buildRunFunc(ctx) - co := getComposeOptions(opts) + ctx, termIter := mwHelper.execBeforeAgents(ctx, agentContext) + if termIter != nil { + return termIter + } + + co, ch := getComposeOptions(agentContext.AgentRunOptions) co = append(co, compose.WithCheckPointID(bridgeCheckpointID)) iterator, generator := NewAsyncIteratorPair[*AgentEvent]() @@ -909,23 +898,37 @@ func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...Age generator.Close() }() - run(ctx, input, generator, newBridgeStore(), co...) + store := newBridgeStore() + + run(ctx, agentContext.AgentInput, generator, store, ch, co...) 
}() - return iterator + return mwHelper.execOnEvents(ctx, agentContext, iterator) } func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { - run := a.buildRunFunc(ctx) + agentContext := &AgentContext{ + ResumeInfo: info, + AgentRunOptions: opts, + agentName: a.name, + invocationType: InvocationTypeResume, + } - co := getComposeOptions(opts) + mwHelper, run := a.buildRunFunc(ctx) + + ctx, termIter := mwHelper.execBeforeAgents(ctx, agentContext) + if termIter != nil { + return termIter + } + + co, ch := getComposeOptions(agentContext.AgentRunOptions) co = append(co, compose.WithCheckPointID(bridgeCheckpointID)) if info.InterruptState == nil { panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has no state", a.Name(ctx))) } - stateByte, ok := info.InterruptState.([]byte) + stateByte, ok := agentContext.ResumeInfo.InterruptState.([]byte) if !ok { panic(fmt.Sprintf("ChatModelAgent.Resume: agent '%s' was asked to resume but has invalid interrupt state type: %T", a.Name(ctx), info.InterruptState)) @@ -962,14 +965,22 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A generator.Close() }() - run(ctx, &AgentInput{EnableStreaming: info.EnableStreaming}, generator, - newResumeBridgeStore(stateByte), co...) + store := newResumeBridgeStore(stateByte) + if a.toolsConfig.EmitInternalEvents { + co = append(co, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEventGenerator(generator)))) + } + + run(ctx, &AgentInput{EnableStreaming: info.EnableStreaming}, generator, store, ch, co...) 
}() - return iterator + return mwHelper.execOnEvents(ctx, agentContext, iterator) +} + +func (a *ChatModelAgent) IsAgentMiddlewareEnabled() bool { + return true } -func getComposeOptions(opts []AgentRunOption) []compose.Option { +func getComposeOptions(opts []AgentRunOption) ([]compose.Option, []callbacks.Handler) { o := GetImplSpecificOptions[chatModelAgentRunOptions](nil, opts...) var co []compose.Option if len(o.chatModelOptions) > 0 { @@ -995,7 +1006,7 @@ func getComposeOptions(opts []AgentRunOption) []compose.Option { return nil })) } - return co + return co, o.graphCallbacks } type gobSerializer struct{} @@ -1013,3 +1024,79 @@ func (g *gobSerializer) Unmarshal(data []byte, v any) error { buf := bytes.NewBuffer(data) return gob.NewDecoder(buf).Decode(v) } + +type chatModelMWHelper struct { + instruction string + toolsConfig ToolsConfig + beforeChatModels []func(context.Context, *ChatModelAgentState) error + afterChatModels []func(context.Context, *ChatModelAgentState) error + beforeAgents []func(context.Context, *AgentContext) (context.Context, error) + onEvents []func(context.Context, *AgentContext, *AsyncIterator[*AgentEvent], *AsyncGenerator[*AgentEvent]) +} + +func (c *chatModelMWHelper) withMWs(mws []AgentMiddleware) *chatModelMWHelper { + dedup := make(map[string]struct{}) + beforeChatModels := make([]func(context.Context, *ChatModelAgentState) error, 0) + afterChatModels := make([]func(context.Context, *ChatModelAgentState) error, 0) + beforeAgents := make([]func(context.Context, *AgentContext) (context.Context, error), 0) + onEvents := make([]func(context.Context, *AgentContext, *AsyncIterator[*AgentEvent], *AsyncGenerator[*AgentEvent]), 0) + sb := &strings.Builder{} + sb.WriteString(c.instruction) + tc := c.toolsConfig + for _, m := range mws { + if _, found := dedup[m.Name]; m.Name != "" && found { + continue + } + dedup[m.Name] = struct{}{} + sb.WriteString("\n") + sb.WriteString(m.AdditionalInstruction) + tc.Tools = append(tc.Tools, 
m.AdditionalTools...) + + if m.WrapToolCall.Invokable != nil || m.WrapToolCall.Streamable != nil { + tc.ToolCallMiddlewares = append(tc.ToolCallMiddlewares, m.WrapToolCall) + } + if m.BeforeChatModel != nil { + beforeChatModels = append(beforeChatModels, m.BeforeChatModel) + } + if m.AfterChatModel != nil { + afterChatModels = append(afterChatModels, m.AfterChatModel) + } + beforeAgents = append(beforeAgents, m.BeforeAgent) + onEvents = append(onEvents, m.OnEvents) + } + + c.instruction = sb.String() + c.toolsConfig = tc + c.beforeChatModels = append(beforeChatModels, c.beforeChatModels...) + c.afterChatModels = append(afterChatModels, c.afterChatModels...) + c.beforeAgents = append(beforeAgents, c.beforeAgents...) + c.onEvents = append(onEvents, c.onEvents...) + return c +} + +func (c *chatModelMWHelper) withTransferToAgents(ctx context.Context, transferToAgents []Agent) *chatModelMWHelper { + transferInstruction := genTransferToAgentInstruction(ctx, transferToAgents) + c.instruction = concatInstructions(c.instruction, transferInstruction) + c.toolsConfig.Tools = append(c.toolsConfig.Tools, &transferToAgent{}) + c.toolsConfig.ReturnDirectly[TransferToAgentToolName] = true + + return c +} + +func (c *chatModelMWHelper) withExitTool(ctx context.Context, exitTool tool.BaseTool) (*chatModelMWHelper, runFunc) { + c.toolsConfig.ToolsNodeConfig.Tools = append(c.toolsConfig.ToolsNodeConfig.Tools, exitTool) + exitInfo, err := exitTool.Info(ctx) + if err != nil { + return c, errFunc(err) // keep c non-nil: the caller still calls toMWHelper on it + } + c.toolsConfig.ReturnDirectly[exitInfo.Name] = true + + return c, nil +} + +func (c *chatModelMWHelper) toMWHelper() *agentMWHelper { + return &agentMWHelper{ + beforeAgentFns: c.beforeAgents, + onEventsFns: c.onEvents, + } +} diff --git a/adk/flow.go b/adk/flow.go index 09e0cc8d..c46e039e 100644 --- a/adk/flow.go +++ b/adk/flow.go @@ -299,59 +299,128 @@ func (a *flowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRun return genErrorIter(err) } - if wf, ok := 
a.Agent.(*workflowAgent); ok { - return wf.Run(ctx, input, opts...) + var ( + mwHelper *agentMWHelper + needExecMW = !isAgentMiddlewareEnabled(a.Agent) + agentContext = &AgentContext{ + AgentInput: input, + AgentRunOptions: opts, + agentName: agentName, + invocationType: InvocationTypeRun, + } + ) + + if needExecMW { + mwHelper = newAgentMWHelper(GetGlobalAgentMiddlewares()...) + var termIter *AsyncIterator[*AgentEvent] + ctx, termIter = mwHelper.execBeforeAgents(ctx, agentContext) + if termIter != nil { + return termIter + } + // TODO: set back input in runCtx ? } - aIter := a.Agent.Run(ctx, input, filterOptions(agentName, opts)...) + input = agentContext.AgentInput + opts = agentContext.AgentRunOptions + iter := func() *AsyncIterator[*AgentEvent] { + if wf, ok := a.Agent.(*workflowAgent); ok { + return wf.Run(ctx, input, opts...) + } + if mf, ok := a.Agent.(*multiAgent); ok { + return mf.Run(ctx, input, opts...) + } + + aIter := a.Agent.Run(ctx, input, filterOptions(agentName, opts)...) + + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + + go a.run(ctx, runCtx, aIter, generator, opts...) - iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + return iterator + }() - go a.run(ctx, runCtx, aIter, generator, opts...) 
+ if needExecMW { + iter = mwHelper.execOnEvents(ctx, agentContext, iter) + } - return iterator + return iter } func (a *flowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] { + agentName := a.Name(ctx) ctx, info = buildResumeInfo(ctx, a.Name(ctx), info) - if info.WasInterrupted { - ra, ok := a.Agent.(ResumableAgent) - if !ok { - return genErrorIter(fmt.Errorf("failed to resume agent: agent '%s' is an interrupt point "+ - "but is not a ResumableAgent", a.Name(ctx))) + var ( + mwHelper *agentMWHelper + needExecMW = !isAgentMiddlewareEnabled(a.Agent) + agentContext = &AgentContext{ + ResumeInfo: info, + AgentRunOptions: opts, + agentName: agentName, + invocationType: InvocationTypeResume, // Resume, not Run: only ResumeInfo is populated here } - iterator, generator := NewAsyncIteratorPair[*AgentEvent]() - - aIter := ra.Resume(ctx, info, opts...) - if _, ok := ra.(*workflowAgent); ok { - return aIter + ) + + if needExecMW { + mwHelper = newAgentMWHelper(GetGlobalAgentMiddlewares()...) + var termIter *AsyncIterator[*AgentEvent] + ctx, termIter = mwHelper.execBeforeAgents(ctx, agentContext) + if termIter != nil { + return termIter } - go a.run(ctx, getRunCtx(ctx), aIter, generator, opts...) - return iterator } - nextAgentName, err := getNextResumeAgent(ctx, info) - if err != nil { - return genErrorIter(err) - } + info = agentContext.ResumeInfo + opts = agentContext.AgentRunOptions + + iter := func() *AsyncIterator[*AgentEvent] { + if info.WasInterrupted { + ra, ok := a.Agent.(ResumableAgent) + if !ok { + return genErrorIter(fmt.Errorf("failed to resume agent: agent '%s' is an interrupt point "+ + "but is not a ResumableAgent", a.Name(ctx))) + } + iterator, generator := NewAsyncIteratorPair[*AgentEvent]() + + aIter := ra.Resume(ctx, info, opts...) + if _, ok := ra.(*workflowAgent); ok { + return aIter + } + if _, ok := ra.(*multiAgent); ok { + return aIter + } + go a.run(ctx, getRunCtx(ctx), aIter, generator, opts...) 
+ return iterator + } + + nextAgentName, err := getNextResumeAgent(ctx, info) + if err != nil { + return genErrorIter(err) + } - subAgent := a.getAgent(ctx, nextAgentName) - if subAgent == nil { - // the inner agent wrapped by flowAgent may be ANY agent, including flowAgent, - // AgentWithDeterministicTransferTo, or any other custom agent user defined, - // or any combinations of the above in any order, - // that ultimately wraps the flowAgent with sub-agents - // We need to go through these wrappers to reach the flowAgent with sub-agents. - if len(a.subAgents) == 0 { - if ra, ok := a.Agent.(ResumableAgent); ok { - return ra.Resume(ctx, info, opts...) + subAgent := a.getAgent(ctx, nextAgentName) + if subAgent == nil { + // the inner agent wrapped by flowAgent may be ANY agent, including flowAgent, + // AgentWithDeterministicTransferTo, or any other custom agent user defined, + // or any combinations of the above in any order, + // that ultimately wraps the flowAgent with sub-agents + // We need to go through these wrappers to reach the flowAgent with sub-agents. + if len(a.subAgents) == 0 { + if ra, ok := a.Agent.(ResumableAgent); ok { + return ra.Resume(ctx, info, opts...) + } } + return genErrorIter(fmt.Errorf("failed to resume agent: agent '%s' not found from flowAgent '%s'", nextAgentName, a.Name(ctx))) } - return genErrorIter(fmt.Errorf("failed to resume agent: agent '%s' not found from flowAgent '%s'", nextAgentName, a.Name(ctx))) + + return subAgent.Resume(ctx, info, opts...) + }() + + if needExecMW { + iter = mwHelper.execOnEvents(ctx, agentContext, iter) } - return subAgent.Resume(ctx, info, opts...) 
+	return iter
 }
 
 type DeterministicTransferConfig struct {
diff --git a/adk/multi_agent.go b/adk/multi_agent.go
new file mode 100644
index 00000000..a7ae24dd
--- /dev/null
+++ b/adk/multi_agent.go
@@ -0,0 +1,138 @@
+/*
+ * Copyright 2025 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package adk
+
+import (
+	"context"
+	"fmt"
+)
+
+// MultiAgentConfig configures NewMultiAgent.
+type MultiAgentConfig struct {
+	Agent       Agent             // required; NewMultiAgent returns an error when nil
+	Name        string            // optional, default using Agent.Name
+	Description string            // optional, default using Agent.Description
+	Middlewares []AgentMiddleware // optional; appended after the global middlewares
+}
+
+// NewMultiAgent enables wrapping any individual Agent into a Multi-Agent structure.
+// It assigns a new Name and Description to the wrapped Agent and adds middleware-based runtime capabilities around it.
+// This method is essentially a simple wrapper for Agents, designed to address the following specific use cases:
+// 1. Customizing Agent Metadata: Modify the public-facing Name and Description of any existing Agent, while adding middleware to its execution flow.
+// 2. Trace Aggregation in Isolated Multi-Agent Systems: In multi-Agent architectures (e.g., Supervisor patterns) where internal TransferToAgent operations exist,
+// individual Agents operate in separate layers with isolated Contexts. This isolation means there is no "root span" for tracing across all sub-Agents.
+// By using NewMultiAgent with appropriate trace middleware, a unified root span is created to aggregate all child spans, enabling coherent tracing across the system.
+//
+// It returns an error when config.Agent is nil.
+func NewMultiAgent(ctx context.Context, config MultiAgentConfig) (ResumableAgent, error) {
+	if config.Agent == nil {
+		return nil, fmt.Errorf("missing agent")
+	}
+	ma := &multiAgent{
+		agent:       toFlowAgent(ctx, config.Agent),
+		name:        config.Name,
+		description: config.Description,
+		middlewares: config.Middlewares,
+	}
+	if ma.name == "" {
+		ma.name = config.Agent.Name(ctx)
+	}
+	if ma.description == "" {
+		ma.description = config.Agent.Description(ctx)
+	}
+	return ma, nil
+}
+
+// multiAgent wraps a flowAgent under its own name and description and runs
+// the configured middlewares around every Run and Resume.
+type multiAgent struct {
+	agent       *flowAgent
+	name        string
+	description string
+	middlewares []AgentMiddleware
+}
+
+// Name returns the name fixed at construction: config.Name, or the wrapped agent's name when that was empty.
+func (ma *multiAgent) Name(ctx context.Context) string {
+	return ma.name
+}
+
+// Description returns the description fixed at construction: config.Description, or the wrapped agent's when that was empty.
+func (ma *multiAgent) Description(ctx context.Context) string {
+	return ma.description
+}
+
+// allMiddlewares returns the global middlewares followed by this agent's own.
+// The result is a freshly allocated slice: appending to globalAgentMiddlewares
+// directly could write into its spare capacity, so concurrent calls sharing
+// that backing array would race and could observe each other's middlewares.
+func (ma *multiAgent) allMiddlewares() []AgentMiddleware {
+	mws := make([]AgentMiddleware, 0, len(globalAgentMiddlewares)+len(ma.middlewares))
+	mws = append(mws, globalAgentMiddlewares...)
+	return append(mws, ma.middlewares...)
+}
+
+// Run builds an AgentContext for a fresh invocation and calls execBeforeAgents.
+// If that yields a terminating iterator, it is returned and the wrapped agent
+// is not run. Otherwise the wrapped agent runs with the input and options read
+// back from the AgentContext, and its iterator is passed to execOnEvents.
+func (ma *multiAgent) Run(ctx context.Context, input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] {
+	agentContext := &AgentContext{
+		AgentInput:      input,
+		AgentRunOptions: options,
+		agentName:       ma.name,
+		invocationType:  InvocationTypeRun,
+	}
+
+	mwHelper := newAgentMWHelper(ma.allMiddlewares()...)
+
+	ctx, termIter := mwHelper.execBeforeAgents(ctx, agentContext)
+	if termIter != nil {
+		return termIter
+	}
+
+	iter := ma.agent.Run(ctx, agentContext.AgentInput, agentContext.AgentRunOptions...)
+
+	return mwHelper.execOnEvents(ctx, agentContext, iter)
+}
+
+// Resume mirrors Run for a resumed invocation: the AgentContext carries the
+// ResumeInfo and InvocationTypeResume, and the wrapped agent's Resume is called.
+func (ma *multiAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] {
+	agentContext := &AgentContext{
+		ResumeInfo:      info,
+		AgentRunOptions: opts,
+		agentName:       ma.name,
+		invocationType:  InvocationTypeResume,
+	}
+
+	mwHelper := newAgentMWHelper(ma.allMiddlewares()...)
+
+	ctx, termIter := mwHelper.execBeforeAgents(ctx, agentContext)
+	if termIter != nil {
+		return termIter
+	}
+
+	iter := ma.agent.Resume(ctx, agentContext.ResumeInfo, agentContext.AgentRunOptions...)
+
+	return mwHelper.execOnEvents(ctx, agentContext, iter)
+}
+
+// IsAgentMiddlewareEnabled reports true: multiAgent executes its middlewares itself in Run and Resume.
+func (ma *multiAgent) IsAgentMiddlewareEnabled() bool {
+	return true
+}
diff --git a/adk/multi_agent_test.go b/adk/multi_agent_test.go
new file mode 100644
index 00000000..dfb5537e
--- /dev/null
+++ b/adk/multi_agent_test.go
@@ -0,0 +1,104 @@
+/*
+ * Copyright 2025 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package adk
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/cloudwego/eino/schema"
+)
+
+func TestMultiAgent(t *testing.T) { // wraps a mock agent with one middleware and checks both hooks through a Runner
+	ctx := context.Background()
+	cnt := 0 // number of events seen by the OnEvents handler
+	ma := &mockAgent{
+		name:        "mock_name",
+		description: "mock desc",
+		responses: []*AgentEvent{
+			{
+				AgentName: "mock_name",
+				Output: &AgentOutput{
+					MessageOutput: &MessageVariant{
+						IsStreaming: false,
+						Message:     schema.AssistantMessage("ok", nil),
+						Role:        schema.Assistant,
+					},
+				},
+			},
+			{
+				AgentName: "mock_name",
+				Action: &AgentAction{
+					Exit: true,
+				},
+			},
+		},
+	}
+	mw := AgentMiddleware{
+		Name: "test",
+		BeforeAgent: func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) {
+			assert.Equal(t, "mock_name", ac.AgentName())
+			assert.Equal(t, InvocationTypeRun, ac.InvocationType())
+			assert.Equal(t, "hello", ac.AgentInput.Messages[0].Content)
+
+			ac.AgentInput.Messages[0].Content = "bye" // mutation must be visible to OnEvents below
+			ctx = context.WithValue(ctx, "mock_key", 1) // NOTE(review): plain string context key; staticcheck SA1029 prefers a private key type
+			return ctx, nil
+		},
+		OnEvents: NewSyncOnSingleEventHandler(func(ctx context.Context, ac *AgentContext, fromEvent *AgentEvent) (toEvent *AgentEvent) {
+			assert.Equal(t, "bye", ac.AgentInput.Messages[0].Content)
+			assert.Equal(t, 1, ctx.Value("mock_key").(int)) // the context returned by BeforeAgent must reach OnEvents
+			assert.Nil(t, fromEvent.Err)
+			if cnt == 0 {
+				assert.Equal(t, "ok", fromEvent.Output.MessageOutput.Message.Content)
+				fromEvent.Output.MessageOutput.Message.Content = "okok" // rewritten content is asserted by the reader loop
+			} else {
+				assert.True(t, fromEvent.Action.Exit)
+			}
+			cnt++
+			return fromEvent
+		}),
+	}
+
+	a, err := NewMultiAgent(ctx, MultiAgentConfig{
+		Agent:       ma,
+		Middlewares: []AgentMiddleware{mw},
+	})
+	assert.NoError(t, err)
+	assert.Equal(t, "mock_name", a.Name(ctx)) // Name and Description default to the wrapped agent's
+	assert.Equal(t, "mock desc", a.Description(ctx))
+
+	r := NewRunner(ctx, RunnerConfig{Agent: a})
+	iter := r.Run(ctx, []Message{schema.UserMessage("hello")})
+	readCnt := 0 // index of the event being read from the runner
+	for {
+		event, ok := iter.Next()
+		if !ok {
+			break
+		}
+		if readCnt == 0 {
+			assert.Equal(t, "okok", event.Output.MessageOutput.Message.Content)
+		} else {
+			assert.True(t, event.Action.Exit)
+		}
+		readCnt++
+	}
+
+	assert.Equal(t, 2, cnt) // the handler saw both mock responses
+}
diff --git a/adk/prebuilt/planexecute/plan_execute.go b/adk/prebuilt/planexecute/plan_execute.go
index 4dacc484..1ad3dfce 100644
--- a/adk/prebuilt/planexecute/plan_execute.go
+++ b/adk/prebuilt/planexecute/plan_execute.go
@@ -834,6 +834,10 @@ func NewReplanner(_ context.Context, cfg *ReplannerConfig) (adk.Agent, error) {
 
 // Config provides configuration options for creating a plan-execute-replan agent.
 type Config struct {
+	Name        string                // optional; New uses "plan_execute_replan" when empty
+	Description string                // passed to adk.SequentialAgentConfig.Description
+	Middlewares []adk.AgentMiddleware // passed to adk.SequentialAgentConfig.Middlewares
+
 	// Planner specifies the agent that generates the plan.
 	// You can use provided NewPlanner to create a planner agent.
 	Planner adk.Agent
@@ -871,8 +875,14 @@ func New(ctx context.Context, cfg *Config) (adk.ResumableAgent, error) {
 		return nil, err
 	}
 
+	name := cfg.Name
+	if name == "" {
+		name = "plan_execute_replan" // previous hard-coded name, kept as the default
+	}
 	return adk.NewSequentialAgent(ctx, &adk.SequentialAgentConfig{
-		Name:      "plan_execute_replan",
-		SubAgents: []adk.Agent{cfg.Planner, loop},
+		Name:        name,
+		Description: cfg.Description,
+		SubAgents:   []adk.Agent{cfg.Planner, loop},
+		Middlewares: cfg.Middlewares,
 	})
 }
diff --git a/adk/prebuilt/supervisor/supervisor.go b/adk/prebuilt/supervisor/supervisor.go
index 97d67780..cf105a55 100644
--- a/adk/prebuilt/supervisor/supervisor.go
+++ b/adk/prebuilt/supervisor/supervisor.go
@@ -25,6 +25,10 @@ import (
 )
 
 type Config struct {
+	Name        string                // optional; New uses supervisorName when empty
+	Description string                // passed to adk.MultiAgentConfig.Description
+	Middlewares []adk.AgentMiddleware // passed to adk.MultiAgentConfig.Middlewares
+
 	// Supervisor specifies the agent that will act as the supervisor, coordinating and managing the sub-agents.
 	Supervisor adk.Agent
 
@@ -48,5 +52,20 @@ func New(ctx context.Context, conf *Config) (adk.ResumableAgent, error) {
 		}))
 	}
 
-	return adk.SetSubAgents(ctx, conf.Supervisor, subAgents)
+	sa, err := adk.SetSubAgents(ctx, conf.Supervisor, subAgents)
+	if err != nil {
+		return nil, err
+	}
+
+	name := conf.Name
+	if name == "" {
+		name = supervisorName // fall back to the package default when conf.Name is empty
+	}
+
+	return adk.NewMultiAgent(ctx, adk.MultiAgentConfig{ // wrap the supervisor so conf.Middlewares run around it
+		Agent:       sa,
+		Name:        name,
+		Description: conf.Description,
+		Middlewares: conf.Middlewares,
+	})
 }
diff --git a/adk/prebuilt/supervisor/supervisor_test.go b/adk/prebuilt/supervisor/supervisor_test.go
index 73ea41fb..67318d88 100644
--- a/adk/prebuilt/supervisor/supervisor_test.go
+++ b/adk/prebuilt/supervisor/supervisor_test.go
@@ -47,6 +47,7 @@ func TestNewSupervisor(t *testing.T) {
 	subAgent2 := mockAdk.NewMockAgent(ctrl)
 
 	supervisorAgent.EXPECT().Name(gomock.Any()).Return("SupervisorAgent").AnyTimes()
+	supervisorAgent.EXPECT().Description(gomock.Any()).Return("mock desc").AnyTimes() // NewMultiAgent reads Description when conf.Description is empty
 	subAgent1.EXPECT().Name(gomock.Any()).Return("SubAgent1").AnyTimes()
 	subAgent2.EXPECT().Name(gomock.Any()).Return("SubAgent2").AnyTimes()
 
diff --git a/adk/utils.go b/adk/utils.go
index 84f962ce..d8c88f4e 100644
--- a/adk/utils.go
+++ b/adk/utils.go
@@ -20,11 +20,13 @@ import (
 	"context"
 	"errors"
 	"io"
+	"runtime/debug"
 	"strings"
 
 	"github.com/google/uuid"
 
 	"github.com/cloudwego/eino/internal"
+	"github.com/cloudwego/eino/internal/safe"
 	"github.com/cloudwego/eino/schema"
 )
 
@@ -239,3 +241,59 @@ func genErrorIter(err error) *AsyncIterator[*AgentEvent] {
 	generator.Close()
 	return iterator
 }
+
+// NewAsyncOnSingleEventHandler returns an OnEvents function that handles events on a separate goroutine:
+// each call starts `go NewSyncOnSingleEventHandler(onEvent)(ctx, ac, iter, gen)` and returns immediately.
+func NewAsyncOnSingleEventHandler(onEvent func(ctx context.Context, ac *AgentContext, fromEvent *AgentEvent) (toEvent *AgentEvent)) (
+	onEventsFn func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent])) {
+	return func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
+		go NewSyncOnSingleEventHandler(onEvent)(ctx, ac, iter, gen)
+	}
+}
+
+// NewSyncOnSingleEventHandler returns an OnEvents function that reads iter until it is exhausted, calls onEvent
+// for every event on the calling goroutine, and sends each non-nil result to gen. gen is always closed on return;
+// a panic is recovered and sent to gen as an AgentEvent whose Err is built by safe.NewPanicErr. onEvent can:
+// 1. Modify the event (modify message/error): Read and modify fromEvent, and finally return fromEvent
+// 2. Skip this event output: Return nil
+// 3. Terminate output: After consuming a specific event, return nil for any subsequent events received
+func NewSyncOnSingleEventHandler(onEvent func(ctx context.Context, ac *AgentContext, fromEvent *AgentEvent) (toEvent *AgentEvent)) (
+	onEventsFn func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent])) {
+	return func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
+		defer func() {
+			panicErr := recover()
+			if panicErr != nil {
+				e := safe.NewPanicErr(panicErr, debug.Stack())
+				gen.Send(&AgentEvent{Err: e}) // report the panic as an error event before closing
+			}
+			gen.Close()
+		}()
+
+		for {
+			event, ok := iter.Next()
+			if !ok {
+				break
+			}
+			toEvent := onEvent(ctx, ac, event)
+			if toEvent == nil { // nil means "drop this event"
+				continue
+			}
+			gen.Send(toEvent)
+		}
+	}
+}
+
+// BypassIterator starts a goroutine that forwards every event from iter to gen unchanged and closes gen once iter is exhausted.
+// This is useful when you need to do something without modifying events.
+func BypassIterator(iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
+	go func() { // forwarding goroutine; ends when iter is exhausted
+		defer gen.Close() // close gen once forwarding stops
+		for {
+			event, ok := iter.Next()
+			if !ok {
+				break
+			}
+			gen.Send(event)
+		}
+	}()
+}
diff --git a/adk/workflow.go b/adk/workflow.go
index 273ff3f5..2cbd3586 100644
--- a/adk/workflow.go
+++ b/adk/workflow.go
@@ -40,6 +40,7 @@ type workflowAgent struct {
 	name        string
 	description string
 	subAgents   []*flowAgent
+	middlewares []AgentMiddleware // per-agent middlewares, appended after the global ones in Run/Resume
 
 	mode workflowAgentMode
 
@@ -54,9 +55,30 @@ func (a *workflowAgent) Description(_ context.Context) string {
 	return a.description
 }
 
-func (a *workflowAgent) Run(ctx context.Context, _ *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] {
-	iterator, generator := NewAsyncIteratorPair[*AgentEvent]()
+func (a *workflowAgent) IsAgentMiddlewareEnabled() bool { // workflowAgent runs its middlewares itself in Run/Resume
+	return true
+}
 
+func (a *workflowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] {
+	var (
+		mwHelper     = newAgentMWHelper(append(GetGlobalAgentMiddlewares(), a.middlewares...)...) // NOTE(review): aliases the global slice if the getter does not return a copy — confirm
+		agentContext = &AgentContext{
+			AgentInput:      input,
+			AgentRunOptions: opts,
+			agentName:       a.name,
+			invocationType:  InvocationTypeRun,
+		}
+	)
+
+	// FIXME: write back new *AgentInput to runCtx ?
+	var termIter *AsyncIterator[*AgentEvent]
+	ctx, termIter = mwHelper.execBeforeAgents(ctx, agentContext)
+	if termIter != nil { // execBeforeAgents produced a terminating iterator: return it without running the workflow
+		return termIter
+	}
+
+	iterator, generator := NewAsyncIteratorPair[*AgentEvent]()
+	opts = agentContext.AgentRunOptions // re-read after execBeforeAgents
 	go func() {
 
 		var err error
 
@@ -85,7 +107,7 @@ func (a *workflowAgent) Run(ctx context.Context, _ *AgentInput, opts ...AgentRun
 		}
 	}()
 
-	return iterator
+	return mwHelper.execOnEvents(ctx, agentContext, iterator)
 }
 
 type sequentialWorkflowState struct {
@@ -108,8 +130,25 @@ func init() {
 }
 
 func (a *workflowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] {
-	iterator, generator := NewAsyncIteratorPair[*AgentEvent]()
+	var (
+		mwHelper     = newAgentMWHelper(append(GetGlobalAgentMiddlewares(), a.middlewares...)...) // NOTE(review): aliases the global slice if the getter does not return a copy — confirm
+		agentContext = &AgentContext{
+			ResumeInfo:      info,
+			AgentRunOptions: opts,
+			agentName:       a.name,
+			invocationType:  InvocationTypeResume, // was InvocationTypeRun: middlewares must be able to tell a resume from a fresh run
+		}
+	)
 
+	var termIter *AsyncIterator[*AgentEvent]
+	ctx, termIter = mwHelper.execBeforeAgents(ctx, agentContext)
+	if termIter != nil { // execBeforeAgents produced a terminating iterator: return it without resuming the workflow
+		return termIter
+	}
+
+	iterator, generator := NewAsyncIteratorPair[*AgentEvent]()
+	info = agentContext.ResumeInfo      // re-read after execBeforeAgents
+	opts = agentContext.AgentRunOptions // re-read after execBeforeAgents
 	go func() {
 		var err error
 		defer func() {
@@ -141,7 +180,8 @@ func (a *workflowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...Ag
 			err = fmt.Errorf("unsupported workflow agent state type: %T", s)
 		}
 	}()
-	return iterator
+
+	return mwHelper.execOnEvents(ctx, agentContext, iterator)
 }
 
 // WorkflowInterruptInfo CheckpointSchema: persisted via InterruptInfo.Data (gob).
@@ -538,29 +578,33 @@ type SequentialAgentConfig struct {
 	Name        string
 	Description string
 	SubAgents   []Agent
+	Middlewares []AgentMiddleware // optional; passed to newWorkflowAgent
 }
 
 type ParallelAgentConfig struct {
 	Name        string
 	Description string
 	SubAgents   []Agent
+	Middlewares []AgentMiddleware // optional; passed to newWorkflowAgent
 }
 
 type LoopAgentConfig struct {
 	Name        string
 	Description string
 	SubAgents   []Agent
+	Middlewares []AgentMiddleware // optional; passed to newWorkflowAgent
 
 	MaxIterations int
 }
 
 func newWorkflowAgent(ctx context.Context, name, desc string,
-	subAgents []Agent, mode workflowAgentMode, maxIterations int) (*flowAgent, error) {
+	subAgents []Agent, mode workflowAgentMode, maxIterations int, middlewares ...AgentMiddleware) (*flowAgent, error) {
 
 	wa := &workflowAgent{
 		name:        name,
 		description: desc,
 		mode:        mode,
+		middlewares: middlewares, // variadic, so existing callers without middlewares still compile
 
 		maxIterations: maxIterations,
 	}
@@ -582,15 +626,15 @@ func newWorkflowAgent(ctx context.Context, name, desc string,
 
 // NewSequentialAgent creates an agent that runs sub-agents sequentially.
 func NewSequentialAgent(ctx context.Context, config *SequentialAgentConfig) (ResumableAgent, error) {
-	return newWorkflowAgent(ctx, config.Name, config.Description, config.SubAgents, workflowAgentModeSequential, 0)
+	return newWorkflowAgent(ctx, config.Name, config.Description, config.SubAgents, workflowAgentModeSequential, 0, config.Middlewares...)
 }
 
 // NewParallelAgent creates an agent that runs sub-agents in parallel.
 func NewParallelAgent(ctx context.Context, config *ParallelAgentConfig) (ResumableAgent, error) {
-	return newWorkflowAgent(ctx, config.Name, config.Description, config.SubAgents, workflowAgentModeParallel, 0)
+	return newWorkflowAgent(ctx, config.Name, config.Description, config.SubAgents, workflowAgentModeParallel, 0, config.Middlewares...)
 }
 
 // NewLoopAgent creates an agent that loops over sub-agents with a max iteration limit.
 func NewLoopAgent(ctx context.Context, config *LoopAgentConfig) (ResumableAgent, error) {
-	return newWorkflowAgent(ctx, config.Name, config.Description, config.SubAgents, workflowAgentModeLoop, config.MaxIterations)
+	return newWorkflowAgent(ctx, config.Name, config.Description, config.SubAgents, workflowAgentModeLoop, config.MaxIterations, config.Middlewares...)
 }