Skip to content

Commit c04737c

Browse files
fix: Runner run with new runctx
Change-Id: If0ae01e3a7d396d2c3080b359eff656a3ed43d0f
1 parent 9c9cb4b commit c04737c

File tree

3 files changed

+38
-36
lines changed

3 files changed

+38
-36
lines changed

adk/flow.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,10 @@ func initRunCtx(ctx context.Context, agentName string, input *AgentInput) (conte
319319
return context.WithValue(ctx, runCtxKey{}, runCtx), runCtx
320320
}
321321

322+
func ctxWithNewRunCtx(ctx context.Context) context.Context {
323+
return context.WithValue(ctx, runCtxKey{}, &runContext{session: newRunSession()})
324+
}
325+
322326
func getSession(ctx context.Context) *runSession {
323327
v := ctx.Value(runCtxKey{})
324328

@@ -369,7 +373,7 @@ func genAgentInput(runCtx *runContext, agentName string) (*AgentInput, error) {
369373
return input, nil
370374
}
371375

372-
func (a *flowAgent) Run(ctx context.Context, input *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] {
376+
func (a *flowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] {
373377
agentName := a.Name(ctx)
374378

375379
ctx, runCtx := initRunCtx(ctx, agentName, input)
@@ -384,10 +388,10 @@ func (a *flowAgent) Run(ctx context.Context, input *AgentInput, _ ...AgentRunOpt
384388
}
385389

386390
if wf, ok := a.Agent.(*workflowAgent); ok {
387-
return wf.Run(ctx, input)
391+
return wf.Run(ctx, input, opts...)
388392
}
389393

390-
aIter := a.Agent.Run(ctx, input)
394+
aIter := a.Agent.Run(ctx, input, opts...)
391395

392396
iterator, generator := NewAsyncIteratorPair[*AgentEvent]()
393397

@@ -434,7 +438,7 @@ func (a *flowAgent) Run(ctx context.Context, input *AgentInput, _ ...AgentRunOpt
434438
return
435439
}
436440

437-
subAIter := agentToRun.Run(ctx, input)
441+
subAIter := agentToRun.Run(ctx, input, opts...)
438442
for {
439443
subEvent, ok_ := subAIter.Next()
440444
if !ok_ {

adk/runner.go

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,41 +22,35 @@ import (
2222
"github.com/cloudwego/eino/schema"
2323
)
2424

25-
type Runner struct{}
26-
27-
type RunnerConfig struct{}
25+
type Runner struct {
26+
enableStreaming bool
27+
}
2828

29-
func NewRunner(_ context.Context, _ *RunnerConfig) *Runner {
30-
return &Runner{}
29+
type RunnerConfig struct {
30+
EnableStreaming bool
3131
}
3232

33-
type RunOptions struct {
34-
enableStreaming bool
33+
func NewRunner(_ context.Context, conf RunnerConfig) *Runner {
34+
return &Runner{enableStreaming: conf.EnableStreaming}
3535
}
3636

37-
type RunOption func(*RunOptions)
37+
func (r *Runner) Run(ctx context.Context, agent Agent, msgs []Message,
38+
opts ...AgentRunOption) *AsyncIterator[*AgentEvent] {
3839

39-
func WithEnableStreaming() RunOption {
40-
return func(opts *RunOptions) {
41-
opts.enableStreaming = true
42-
}
43-
}
44-
func (r *Runner) Run(ctx context.Context, agent Agent, msgs []Message, opts ...RunOption) *AsyncIterator[*AgentEvent] {
4540
fa := toFlowAgent(agent)
4641

47-
options := &RunOptions{}
48-
for _, opt := range opts {
49-
opt(options)
50-
}
51-
5242
input := &AgentInput{
5343
Msgs: msgs,
54-
EnableStreaming: options.enableStreaming,
44+
EnableStreaming: r.enableStreaming,
5545
}
5646

57-
return fa.Run(ctx, input)
47+
ctx = ctxWithNewRunCtx(ctx)
48+
49+
return fa.Run(ctx, input, opts...)
5850
}
5951

60-
func (r *Runner) Query(ctx context.Context, agent Agent, query string, opts ...RunOption) *AsyncIterator[*AgentEvent] {
52+
func (r *Runner) Query(ctx context.Context, agent Agent,
53+
query string, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] {
54+
6155
return r.Run(ctx, agent, []Message{schema.UserMessage(query)}, opts...)
6256
}

adk/workflow.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func (a *workflowAgent) Description(_ context.Context) string {
5353
return a.description
5454
}
5555

56-
func (a *workflowAgent) Run(ctx context.Context, input *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] {
56+
func (a *workflowAgent) Run(ctx context.Context, input *AgentInput, opts ...AgentRunOption) *AsyncIterator[*AgentEvent] {
5757
iterator, generator := NewAsyncIteratorPair[*AgentEvent]()
5858

5959
go func() {
@@ -74,11 +74,11 @@ func (a *workflowAgent) Run(ctx context.Context, input *AgentInput, _ ...AgentRu
7474
// Different workflow execution based on mode
7575
switch a.mode {
7676
case workflowAgentModeSequential:
77-
a.runSequential(ctx, input, generator)
77+
a.runSequential(ctx, input, generator, opts...)
7878
case workflowAgentModeLoop:
79-
a.runLoop(ctx, input, generator)
79+
a.runLoop(ctx, input, generator, opts...)
8080
case workflowAgentModeParallel:
81-
a.runParallel(ctx, input, generator)
81+
a.runParallel(ctx, input, generator, opts...)
8282
default:
8383
err = errors.New(fmt.Sprintf("unsupported workflow agent mode: %d", a.mode))
8484
}
@@ -88,10 +88,10 @@ func (a *workflowAgent) Run(ctx context.Context, input *AgentInput, _ ...AgentRu
8888
}
8989

9090
func (a *workflowAgent) runSequential(ctx context.Context, input *AgentInput,
91-
generator *AsyncGenerator[*AgentEvent]) (exit bool) {
91+
generator *AsyncGenerator[*AgentEvent], opts ...AgentRunOption) (exit bool) {
9292

9393
for _, subAgent := range a.subAgents {
94-
subIterator := subAgent.Run(ctx, input)
94+
subIterator := subAgent.Run(ctx, input, opts...)
9595
for {
9696
event, ok := subIterator.Next()
9797
if !ok {
@@ -116,22 +116,26 @@ func (a *workflowAgent) runSequential(ctx context.Context, input *AgentInput,
116116
return false
117117
}
118118

119-
func (a *workflowAgent) runLoop(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent]) {
119+
func (a *workflowAgent) runLoop(ctx context.Context, input *AgentInput,
120+
generator *AsyncGenerator[*AgentEvent], opts ...AgentRunOption) {
121+
120122
if len(a.subAgents) == 0 {
121123
return
122124
}
123125

124126
var iterations int
125127
for iterations < a.maxIterations || a.maxIterations == 0 {
126128
iterations++
127-
exit := a.runSequential(ctx, input, generator)
129+
exit := a.runSequential(ctx, input, generator, opts...)
128130
if exit {
129131
return
130132
}
131133
}
132134
}
133135

134-
func (a *workflowAgent) runParallel(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent]) {
136+
func (a *workflowAgent) runParallel(ctx context.Context, input *AgentInput,
137+
generator *AsyncGenerator[*AgentEvent], opts ...AgentRunOption) {
138+
135139
if len(a.subAgents) == 0 {
136140
return
137141
}
@@ -150,7 +154,7 @@ func (a *workflowAgent) runParallel(ctx context.Context, input *AgentInput, gene
150154
wg.Done()
151155
}()
152156

153-
iterator := agent.Run(ctx, input)
157+
iterator := agent.Run(ctx, input, opts...)
154158
for {
155159
event, ok := iterator.Next()
156160
if !ok {

0 commit comments

Comments
 (0)