Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions components/tool/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
package tool

// Option defines call option for InvokableTool or StreamableTool component, which is part of component interface signature.
// Each tool implementation could define its own options struct and option funcs within its own package,
// then wrap the impl specific option funcs into this type, before passing to InvokableRun or StreamableRun.
// Each tool implementation could define its own options struct and option functions within its own package,
// then wrap the impl specific option functions into this type, before passing to InvokableRun or StreamableRun.
type Option struct {
implSpecificOptFn any
}
Expand Down
7 changes: 3 additions & 4 deletions compose/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/internal/compose"
"github.com/cloudwego/eino/internal/generic"
"github.com/cloudwego/eino/internal/gmap"
"github.com/cloudwego/eino/schema"
Expand Down Expand Up @@ -1121,9 +1122,7 @@ func validateDAG(chanSubscribeTo map[string]*chanCall, invertedEdges map[string]
}

func NewNodePath(path ...string) *NodePath {
return &NodePath{path: path}
return compose.NewNodePath(path...)
}

type NodePath struct {
path []string
}
type NodePath = compose.NodePath
4 changes: 4 additions & 0 deletions compose/graph_call_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ func (o Option) DesignateNodeWithPath(path ...*NodePath) Option {
return o
}

func (o Option) Paths() []*NodePath {
return o.paths
}

// WithEmbeddingOption is a functional option type for embedding component.
// e.g.
//
Expand Down
12 changes: 6 additions & 6 deletions compose/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func initNodeCallbacks(ctx context.Context, key string, info *nodeInfo, meta *ex
if len(opts[i].handler) != 0 {
if len(opts[i].paths) != 0 {
for _, k := range opts[i].paths {
if len(k.path) == 1 && k.path[0] == key {
if len(k.Path()) == 1 && k.Path()[0] == key {
cbs = append(cbs, opts[i].handler...)
break
}
Expand Down Expand Up @@ -314,18 +314,18 @@ func extractOption(nodes map[string]*chanCall, opts ...Option) (map[string][]any
}
}
for _, path := range opt.paths {
if len(path.path) == 0 {
if len(path.Path()) == 0 {
return nil, fmt.Errorf("call option has designated an empty path")
}

var curNode *chanCall
var ok bool
if curNode, ok = nodes[path.path[0]]; !ok {
if curNode, ok = nodes[path.Path()[0]]; !ok {
return nil, fmt.Errorf("option has designated an unknown node: %s", path)
}
curNodeKey := path.path[0]
curNodeKey := path.Path()[0]

if len(path.path) == 1 {
if len(path.Path()) == 1 {
if len(opt.options) == 0 {
// sub graph common callbacks has been added to ctx in initNodeCallback and won't be passed to subgraph only pass options
// node callback also won't be passed
Expand All @@ -350,7 +350,7 @@ func extractOption(nodes map[string]*chanCall, opts ...Option) (map[string][]any
}
// designate to sub graph's nodes
nOpt := opt.deepCopy()
nOpt.paths = []*NodePath{NewNodePath(path.path[1:]...)}
nOpt.paths = []*NodePath{NewNodePath(path.Path()[1:]...)}
optMap[curNodeKey] = append(optMap[curNodeKey], nOpt)
}
}
Expand Down
3 changes: 2 additions & 1 deletion flow/agent/agent_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ package agent
import "github.com/cloudwego/eino/compose"

// AgentOption is the common option type for various agent and multi-agent implementations.
// For options intended to use with underlying graph or components, use WithComposeOptions to specify.
// For options intended to use with particular agent/multi-agent implementations, use WrapImplSpecificOptFn to specify.
type AgentOption struct {
implSpecificOptFn any
composeOptions []compose.Option
}

// GetComposeOptions returns all compose options from the given agent options.
// Deprecated
func GetComposeOptions(opts ...AgentOption) []compose.Option {
var result []compose.Option
for _, opt := range opts {
Expand All @@ -37,6 +37,7 @@ func GetComposeOptions(opts ...AgentOption) []compose.Option {
}

// WithComposeOptions returns an agent option that specifies compose options.
// Deprecated: use option functions defined by each agent flow implementation instead.
func WithComposeOptions(opts ...compose.Option) AgentOption {
return AgentOption{
composeOptions: opts,
Expand Down
74 changes: 73 additions & 1 deletion flow/agent/multiagent/host/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type HandOffInfo struct {
}

// ConvertCallbackHandlers converts []host.MultiAgentCallback to callbacks.Handler.
// Deprecated: use ConvertOptions to convert agent.AgentOption to compose.Option when adding MultiAgent's Graph to another Graph.
func ConvertCallbackHandlers(handlers ...MultiAgentCallback) callbacks.Handler {
onChatModelEnd := func(ctx context.Context, info *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
if output == nil || info == nil {
Expand Down Expand Up @@ -121,5 +122,76 @@ func convertCallbacks(opts ...agent.AgentOption) callbacks.Handler {
}

handlers := agentOptions.agentCallbacks
return ConvertCallbackHandlers(handlers...)

onChatModelEnd := func(ctx context.Context, info *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
if output == nil || info == nil {
return ctx
}

msg := output.Message
if msg == nil || msg.Role != schema.Assistant || len(msg.ToolCalls) == 0 {
return ctx
}

agentName := msg.ToolCalls[0].Function.Name
argument := msg.ToolCalls[0].Function.Arguments

for _, cb := range handlers {
ctx = cb.OnHandOff(ctx, &HandOffInfo{
ToAgentName: agentName,
Argument: argument,
})
}

return ctx
}

onChatModelEndWithStreamOutput := func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context {
if output == nil || info == nil {
return ctx
}

defer output.Close()

var msgs []*schema.Message
for {
oneOutput, err := output.Recv()
if err == io.EOF {
break
}
if err != nil {
return ctx
}

msg := oneOutput.Message
if msg == nil {
continue
}

msgs = append(msgs, msg)
}

msg, err := schema.ConcatMessages(msgs)
if err != nil {
return ctx
}

if msg.Role != schema.Assistant || len(msg.ToolCalls) == 0 {
return ctx
}

for _, cb := range handlers {
ctx = cb.OnHandOff(ctx, &HandOffInfo{
ToAgentName: msg.ToolCalls[0].Function.Name,
Argument: msg.ToolCalls[0].Function.Arguments,
})
}

return ctx
}

return template.NewHandlerHelper().ChatModel(&template.ModelCallbackHandler{
OnEnd: onChatModelEnd,
OnEndWithStreamOutput: onChatModelEndWithStreamOutput,
}).Handler()
}
105 changes: 97 additions & 8 deletions flow/agent/multiagent/host/compose_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@ import (
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"

"github.com/cloudwego/eino/callbacks"
chatmodel "github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/flow/agent"
"github.com/cloudwego/eino/internal/generic"
"github.com/cloudwego/eino/internal/mock/components/model"
"github.com/cloudwego/eino/schema"
template "github.com/cloudwego/eino/utils/callbacks"
)

func TestHostMultiAgent(t *testing.T) {
Expand All @@ -48,6 +51,14 @@ func TestHostMultiAgent(t *testing.T) {

specialist2 := &Specialist{
Invokable: func(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.Message, error) {
agentOpts := agent.GetImplSpecificOptions(&specialist2Options{}, opts...)
if agentOpts.mockOutput != nil {
return &schema.Message{
Role: schema.Assistant,
Content: *agentOpts.mockOutput,
}, nil
}

return &schema.Message{
Role: schema.Assistant,
Content: "specialist2 invoke answer",
Expand Down Expand Up @@ -92,11 +103,18 @@ func TestHostMultiAgent(t *testing.T) {
Content: "direct answer",
}

mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(directAnswerMsg, nil).Times(1)
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, input []*schema.Message, opts ...chatmodel.Option) (*schema.Message, error) {
modelOpts := chatmodel.GetCommonOptions(&chatmodel.Options{}, opts...)
assert.Equal(t, *modelOpts.Temperature, float32(0.7))
return directAnswerMsg, nil
}).
Times(1)

mockCallback := &mockAgentCallback{}

out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback))
out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback),
WithAgentModelOptions(hostMA.HostNodeKey(), chatmodel.WithTemperature(0.7)))
assert.NoError(t, err)
assert.Equal(t, "direct answer", out.Content)
assert.Empty(t, mockCallback.infos)
Expand Down Expand Up @@ -164,11 +182,18 @@ func TestHostMultiAgent(t *testing.T) {
}

mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(1)
mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(specialistMsg, nil).Times(1)
mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, input []*schema.Message, opts ...chatmodel.Option) (*schema.Message, error) {
modelOpts := chatmodel.GetCommonOptions(&chatmodel.Options{}, opts...)
assert.Equal(t, *modelOpts.Temperature, float32(0.7))
return specialistMsg, nil
}).
Times(1)

mockCallback := &mockAgentCallback{}

out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback))
out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback),
WithAgentModelOptions(specialist1.Name, chatmodel.WithTemperature(0.7)))
assert.NoError(t, err)
assert.Equal(t, "specialist 1 answer", out.Content)
assert.Equal(t, []*HandOffInfo{
Expand Down Expand Up @@ -379,16 +404,41 @@ func TestHostMultiAgent(t *testing.T) {
},
}

specialistMsg := &schema.Message{
specialist1Msg := &schema.Message{
Role: schema.Assistant,
Content: "Beijing",
}

mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(1)
mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(specialistMsg, nil).Times(1)
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(2)
mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, input []*schema.Message, opts ...chatmodel.Option) (*schema.Message, error) {
modelOpts := chatmodel.GetCommonOptions(&chatmodel.Options{}, opts...)
assert.Equal(t, *modelOpts.Temperature, float32(0.7))
return specialist1Msg, nil
}).
Times(1)

mockCallback := &mockAgentCallback{}

var hostOutput, specialist1Output, specialist2Output string
hostModelCallback := template.NewHandlerHelper().ChatModel(&template.ModelCallbackHandler{
OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *chatmodel.CallbackOutput) context.Context {
hostOutput = output.Message.ToolCalls[0].Function.Name
return ctx
},
}).Handler()
specialist1ModelCallback := template.NewHandlerHelper().ChatModel(&template.ModelCallbackHandler{
OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *chatmodel.CallbackOutput) context.Context {
specialist1Output = output.Message.Content
return ctx
},
}).Handler()
specialist2LambdaCallback := template.NewHandlerHelper().Lambda(callbacks.NewHandlerBuilder().OnEndFn(
func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
specialist2Output = output.(*schema.Message).Content
return ctx
}).Build()).Handler()

hostMA, err := NewMultiAgent(ctx, &MultiAgentConfig{
Host: Host{
ChatModel: mockHostLLM,
Expand All @@ -409,7 +459,14 @@ func TestHostMultiAgent(t *testing.T) {
Compile(ctx)
assert.NoError(t, err)

out, err := fullGraph.Invoke(ctx, map[string]any{"country_name": "China"}, compose.WithCallbacks(ConvertCallbackHandlers(mockCallback)).DesignateNodeWithPath(compose.NewNodePath("host_ma_node", hostMA.HostNodeKey())))
convertedOptions := ConvertOptions(compose.NewNodePath("host_ma_node"), WithAgentCallbacks(mockCallback),
WithAgentModelOptions(specialist1.Name, chatmodel.WithTemperature(0.7)),
WithAgentModelCallbacks(hostMA.HostNodeKey(), hostModelCallback),
WithAgentModelCallbacks(specialist1.Name, specialist1ModelCallback),
WithSpecialistLambdaCallbacks(specialist2.Name, specialist2LambdaCallback),
WithSpecialistLambdaOptions(specialist2.Name, withSpecialist2MockOutput("mock_city_name")))

out, err := fullGraph.Invoke(ctx, map[string]any{"country_name": "China"}, convertedOptions...)
assert.NoError(t, err)
assert.Equal(t, "Beijing", out.Content)
assert.Equal(t, []*HandOffInfo{
Expand All @@ -418,6 +475,28 @@ func TestHostMultiAgent(t *testing.T) {
Argument: `{"reason": "specialist 1 is the best"}`,
},
}, mockCallback.infos)
assert.Equal(t, hostOutput, specialist1.Name)
assert.Equal(t, specialist1Output, out.Content)
assert.Equal(t, specialist2Output, "")

handOffMsg.ToolCalls[0].Function.Name = specialist2.Name
handOffMsg.ToolCalls[0].Function.Arguments = `{"reason": "specialist 2 is even better"}`

out, err = fullGraph.Invoke(ctx, map[string]any{"country_name": "China"}, convertedOptions...)
assert.NoError(t, err)
assert.Equal(t, "mock_city_name", out.Content)
assert.Equal(t, []*HandOffInfo{
{
ToAgentName: specialist1.Name,
Argument: `{"reason": "specialist 1 is the best"}`,
},
{
ToAgentName: specialist2.Name,
Argument: `{"reason": "specialist 2 is even better"}`,
},
}, mockCallback.infos)
assert.Equal(t, hostOutput, specialist2.Name)
assert.Equal(t, specialist2Output, "mock_city_name")
})
}

Expand All @@ -429,3 +508,13 @@ func (m *mockAgentCallback) OnHandOff(ctx context.Context, info *HandOffInfo) co
m.infos = append(m.infos, info)
return ctx
}

type specialist2Options struct {
mockOutput *string
}

func withSpecialist2MockOutput(mockOutput string) agent.AgentOption {
return agent.WrapImplSpecificOptFn(func(o *specialist2Options) {
o.mockOutput = &mockOutput
})
}
Loading
Loading