@@ -24,12 +24,15 @@ import (
2424 "github.com/stretchr/testify/assert"
2525 "go.uber.org/mock/gomock"
2626
27+ "github.com/cloudwego/eino/callbacks"
28+ chatmodel "github.com/cloudwego/eino/components/model"
2729 "github.com/cloudwego/eino/components/prompt"
2830 "github.com/cloudwego/eino/compose"
2931 "github.com/cloudwego/eino/flow/agent"
3032 "github.com/cloudwego/eino/internal/generic"
3133 "github.com/cloudwego/eino/internal/mock/components/model"
3234 "github.com/cloudwego/eino/schema"
35+ template "github.com/cloudwego/eino/utils/callbacks"
3336)
3437
3538func TestHostMultiAgent (t * testing.T ) {
@@ -48,6 +51,14 @@ func TestHostMultiAgent(t *testing.T) {
4851
4952 specialist2 := & Specialist {
5053 Invokable : func (ctx context.Context , input []* schema.Message , opts ... agent.AgentOption ) (* schema.Message , error ) {
54+ agentOpts := agent .GetImplSpecificOptions (& specialist2Options {}, opts ... )
55+ if agentOpts .mockOutput != nil {
56+ return & schema.Message {
57+ Role : schema .Assistant ,
58+ Content : * agentOpts .mockOutput ,
59+ }, nil
60+ }
61+
5162 return & schema.Message {
5263 Role : schema .Assistant ,
5364 Content : "specialist2 invoke answer" ,
@@ -92,11 +103,18 @@ func TestHostMultiAgent(t *testing.T) {
92103 Content : "direct answer" ,
93104 }
94105
95- mockHostLLM .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (directAnswerMsg , nil ).Times (1 )
106+ mockHostLLM .EXPECT ().Generate (gomock .Any (), gomock .Any (), gomock .Any ()).
107+ DoAndReturn (func (_ context.Context , input []* schema.Message , opts ... chatmodel.Option ) (* schema.Message , error ) {
108+ modelOpts := chatmodel .GetCommonOptions (& chatmodel.Options {}, opts ... )
109+ assert .Equal (t , * modelOpts .Temperature , float32 (0.7 ))
110+ return directAnswerMsg , nil
111+ }).
112+ Times (1 )
96113
97114 mockCallback := & mockAgentCallback {}
98115
99- out , err := hostMA .Generate (ctx , nil , WithAgentCallbacks (mockCallback ))
116+ out , err := hostMA .Generate (ctx , nil , WithAgentCallbacks (mockCallback ),
117+ WithAgentModelOptions (hostMA .HostNodeKey (), chatmodel .WithTemperature (0.7 )))
100118 assert .NoError (t , err )
101119 assert .Equal (t , "direct answer" , out .Content )
102120 assert .Empty (t , mockCallback .infos )
@@ -164,11 +182,18 @@ func TestHostMultiAgent(t *testing.T) {
164182 }
165183
166184 mockHostLLM .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (handOffMsg , nil ).Times (1 )
167- mockSpecialistLLM1 .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (specialistMsg , nil ).Times (1 )
185+ mockSpecialistLLM1 .EXPECT ().Generate (gomock .Any (), gomock .Any (), gomock .Any ()).
186+ DoAndReturn (func (_ context.Context , input []* schema.Message , opts ... chatmodel.Option ) (* schema.Message , error ) {
187+ modelOpts := chatmodel .GetCommonOptions (& chatmodel.Options {}, opts ... )
188+ assert .Equal (t , * modelOpts .Temperature , float32 (0.7 ))
189+ return specialistMsg , nil
190+ }).
191+ Times (1 )
168192
169193 mockCallback := & mockAgentCallback {}
170194
171- out , err := hostMA .Generate (ctx , nil , WithAgentCallbacks (mockCallback ))
195+ out , err := hostMA .Generate (ctx , nil , WithAgentCallbacks (mockCallback ),
196+ WithAgentModelOptions (specialist1 .Name , chatmodel .WithTemperature (0.7 )))
172197 assert .NoError (t , err )
173198 assert .Equal (t , "specialist 1 answer" , out .Content )
174199 assert .Equal (t , []* HandOffInfo {
@@ -379,16 +404,41 @@ func TestHostMultiAgent(t *testing.T) {
379404 },
380405 }
381406
382- specialistMsg := & schema.Message {
407+ specialist1Msg := & schema.Message {
383408 Role : schema .Assistant ,
384409 Content : "Beijing" ,
385410 }
386411
387- mockHostLLM .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (handOffMsg , nil ).Times (1 )
388- mockSpecialistLLM1 .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (specialistMsg , nil ).Times (1 )
412+ mockHostLLM .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (handOffMsg , nil ).Times (2 )
413+ mockSpecialistLLM1 .EXPECT ().Generate (gomock .Any (), gomock .Any (), gomock .Any ()).
414+ DoAndReturn (func (_ context.Context , input []* schema.Message , opts ... chatmodel.Option ) (* schema.Message , error ) {
415+ modelOpts := chatmodel .GetCommonOptions (& chatmodel.Options {}, opts ... )
416+ assert .Equal (t , * modelOpts .Temperature , float32 (0.7 ))
417+ return specialist1Msg , nil
418+ }).
419+ Times (1 )
389420
390421 mockCallback := & mockAgentCallback {}
391422
423+ var hostOutput , specialist1Output , specialist2Output string
424+ hostModelCallback := template .NewHandlerHelper ().ChatModel (& template.ModelCallbackHandler {
425+ OnEnd : func (ctx context.Context , runInfo * callbacks.RunInfo , output * chatmodel.CallbackOutput ) context.Context {
426+ hostOutput = output .Message .ToolCalls [0 ].Function .Name
427+ return ctx
428+ },
429+ }).Handler ()
430+ specialist1ModelCallback := template .NewHandlerHelper ().ChatModel (& template.ModelCallbackHandler {
431+ OnEnd : func (ctx context.Context , runInfo * callbacks.RunInfo , output * chatmodel.CallbackOutput ) context.Context {
432+ specialist1Output = output .Message .Content
433+ return ctx
434+ },
435+ }).Handler ()
436+ specialist2LambdaCallback := template .NewHandlerHelper ().Lambda (callbacks .NewHandlerBuilder ().OnEndFn (
437+ func (ctx context.Context , info * callbacks.RunInfo , output callbacks.CallbackOutput ) context.Context {
438+ specialist2Output = output .(* schema.Message ).Content
439+ return ctx
440+ }).Build ()).Handler ()
441+
392442 hostMA , err := NewMultiAgent (ctx , & MultiAgentConfig {
393443 Host : Host {
394444 ChatModel : mockHostLLM ,
@@ -409,7 +459,14 @@ func TestHostMultiAgent(t *testing.T) {
409459 Compile (ctx )
410460 assert .NoError (t , err )
411461
412- out , err := fullGraph .Invoke (ctx , map [string ]any {"country_name" : "China" }, compose .WithCallbacks (ConvertCallbackHandlers (mockCallback )).DesignateNodeWithPath (compose .NewNodePath ("host_ma_node" , hostMA .HostNodeKey ())))
462+ convertedOptions := ConvertOptions (compose .NewNodePath ("host_ma_node" ), WithAgentCallbacks (mockCallback ),
463+ WithAgentModelOptions (specialist1 .Name , chatmodel .WithTemperature (0.7 )),
464+ WithAgentModelCallbacks (hostMA .HostNodeKey (), hostModelCallback ),
465+ WithAgentModelCallbacks (specialist1 .Name , specialist1ModelCallback ),
466+ WithSpecialistLambdaCallbacks (specialist2 .Name , specialist2LambdaCallback ),
467+ WithSpecialistLambdaOptions (specialist2 .Name , withSpecialist2MockOutput ("mock_city_name" )))
468+
469+ out , err := fullGraph .Invoke (ctx , map [string ]any {"country_name" : "China" }, convertedOptions ... )
413470 assert .NoError (t , err )
414471 assert .Equal (t , "Beijing" , out .Content )
415472 assert .Equal (t , []* HandOffInfo {
@@ -418,6 +475,28 @@ func TestHostMultiAgent(t *testing.T) {
418475 Argument : `{"reason": "specialist 1 is the best"}` ,
419476 },
420477 }, mockCallback .infos )
478+ assert .Equal (t , hostOutput , specialist1 .Name )
479+ assert .Equal (t , specialist1Output , out .Content )
480+ assert .Equal (t , specialist2Output , "" )
481+
482+ handOffMsg .ToolCalls [0 ].Function .Name = specialist2 .Name
483+ handOffMsg .ToolCalls [0 ].Function .Arguments = `{"reason": "specialist 2 is even better"}`
484+
485+ out , err = fullGraph .Invoke (ctx , map [string ]any {"country_name" : "China" }, convertedOptions ... )
486+ assert .NoError (t , err )
487+ assert .Equal (t , "mock_city_name" , out .Content )
488+ assert .Equal (t , []* HandOffInfo {
489+ {
490+ ToAgentName : specialist1 .Name ,
491+ Argument : `{"reason": "specialist 1 is the best"}` ,
492+ },
493+ {
494+ ToAgentName : specialist2 .Name ,
495+ Argument : `{"reason": "specialist 2 is even better"}` ,
496+ },
497+ }, mockCallback .infos )
498+ assert .Equal (t , hostOutput , specialist2 .Name )
499+ assert .Equal (t , specialist2Output , "mock_city_name" )
421500 })
422501}
423502
@@ -429,3 +508,13 @@ func (m *mockAgentCallback) OnHandOff(ctx context.Context, info *HandOffInfo) co
429508 m .infos = append (m .infos , info )
430509 return ctx
431510}
511+
512+ type specialist2Options struct {
513+ mockOutput * string
514+ }
515+
516+ func withSpecialist2MockOutput (mockOutput string ) agent.AgentOption {
517+ return agent .WrapImplSpecificOptFn (func (o * specialist2Options ) {
518+ o .mockOutput = & mockOutput
519+ })
520+ }
0 commit comments