17 changes: 17 additions & 0 deletions adk/call_option.go
@@ -20,6 +20,8 @@ type options struct {
sessionValues map[string]any
checkPointID *string
skipTransferMessages bool
sessionID *string
sessionOptions []StoreOption
}

// AgentRunOption is the call option for adk Agent.
@@ -55,6 +57,21 @@ func WithSkipTransferMessages() AgentRunOption {
})
}

// WithSessionID sets the session ID for the agent run.
func WithSessionID(id string) AgentRunOption {
return WrapImplSpecificOptFn(func(t *options) {
t.sessionID = &id
})
}

// WithSessionOptions passes StoreOption (e.g. WithLimit, WithRoundID) to the SessionStore.
// These options control how session history is loaded and saved.
func WithSessionOptions(opts ...StoreOption) AgentRunOption {
return WrapImplSpecificOptFn(func(t *options) {
t.sessionOptions = append(t.sessionOptions, opts...)
})
}

// WrapImplSpecificOptFn wraps an implementation-specific option function into an AgentRunOption.
func WrapImplSpecificOptFn[T any](optFn func(*T)) AgentRunOption {
return AgentRunOption{
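Taken together with the Runner changes below, these options ride along on a normal Run call. The following is a caller-side sketch, not part of the diff: the import path, the runWithSession helper, the session ID value, and WithLimit's integer argument are assumptions for illustration; only WithSessionID, WithSessionOptions, WithLimit, and WithRoundID come from this change.

package main

import (
	"context"

	"github.com/cloudwego/eino/adk" // import path is an assumption for this sketch
)

// runWithSession is a hypothetical caller; everything not named in the diff is illustrative.
func runWithSession(ctx context.Context, runner *adk.Runner, messages []adk.Message) error {
	iter := runner.Run(ctx, messages,
		adk.WithSessionID("chat-42"),              // which session to load history from and append to
		adk.WithSessionOptions(adk.WithLimit(20)), // StoreOption: assumed to cap the loaded history
	)
	for {
		event, ok := iter.Next()
		if !ok {
			return nil // iterator closed, run finished
		}
		if event.Err != nil {
			return event.Err // errors (e.g. session load/save failures) surface as events
		}
		// consume the event's output/action here...
	}
}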
108 changes: 91 additions & 17 deletions adk/runner.go
@@ -16,6 +16,14 @@

package adk

// Runner orchestrates agent execution, session, and checkpoints.
//
// Behavior summary:
// - Run: loads session history (if SessionService provided) and saves new input messages;
// outputs are persisted via SessionService while iterating.
// - Resume: resumes from checkpoints only; session history is not loaded during resume,
// but outputs will still be persisted if SessionService is provided.
//
import (
"context"
"fmt"
@@ -37,13 +45,18 @@ type Runner struct {
// store is the checkpoint store used to persist agent state upon interruption.
// If nil, checkpointing is disabled.
store compose.CheckPointStore
// sessionService manages session persistence and hooks.
// If nil, session persistence is disabled.
sessionService *SessionService
}

type RunnerConfig struct {
Agent Agent
EnableStreaming bool

CheckPointStore compose.CheckPointStore
// SessionService is optional. If provided, enables session persistence.
SessionService *SessionService
}

// ResumeParams contains all parameters needed to resume an execution.
@@ -61,6 +74,7 @@ func NewRunner(_ context.Context, conf RunnerConfig) *Runner {
enableStreaming: conf.EnableStreaming,
a: conf.Agent,
store: conf.CheckPointStore,
sessionService: conf.SessionService,
}
}

@@ -83,14 +97,38 @@ func (r *Runner) Run(ctx context.Context, messages []Message,

AddSessionValues(ctx, o.sessionValues)

if r.sessionService != nil && o.sessionID != nil {
// load history from session
history, err := r.sessionService.load(ctx, *o.sessionID, o.sessionOptions...)
if err != nil {
niter, gen := NewAsyncIteratorPair[*AgentEvent]()
gen.Send(&AgentEvent{Err: fmt.Errorf("failed to load session history: %w", err)})
gen.Close()
return niter
}
// Add history to session context (not input messages)
session := getSession(ctx)
for _, event := range history {
session.addEvent(event)
}

// Save new input messages to session
if err := r.sessionService.saveInput(ctx, *o.sessionID, messages, o.sessionOptions...); err != nil {
niter, gen := NewAsyncIteratorPair[*AgentEvent]()
gen.Send(&AgentEvent{Err: fmt.Errorf("failed to save input to session: %w", err)})
gen.Close()
return niter
}
}

iter := fa.Run(ctx, input, opts...)
if r.store == nil && r.sessionService == nil {
return iter
}

niter, gen := NewAsyncIteratorPair[*AgentEvent]()

go r.handleIter(ctx, iter, gen, o.checkPointID, o.sessionID, o.sessionOptions)
return niter
}

@@ -156,18 +194,18 @@ func (r *Runner) resume(ctx context.Context, checkPointID string, resumeData map

fa := toFlowAgent(ctx, r.a)
aIter := fa.Resume(ctx, resumeInfo, opts...)
if r.store == nil && r.sessionService == nil {
return aIter, nil
}

niter, gen := NewAsyncIteratorPair[*AgentEvent]()

go r.handleIter(ctx, aIter, gen, &checkPointID, o.sessionID, o.sessionOptions)
return niter, nil
}

func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEvent],
gen *AsyncGenerator[*AgentEvent], checkPointID *string, sessionID *string, sessionOptions []StoreOption) {
defer func() {
panicErr := recover()
if panicErr != nil {
@@ -177,21 +215,61 @@ func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEven

gen.Close()
}()

// Wrap iterator with checkpoint handling if needed
processedIter := r.wrapWithCheckpoint(ctx, aIter, checkPointID)

// Handle session or just forward events
if r.sessionService != nil && sessionID != nil {
// Delegate to SessionService for session handling
r.sessionService.saveOutput(ctx, *sessionID, processedIter, gen, sessionOptions...)
} else {
// Just forward all events
for {
item, ok := processedIter.Next()
if !ok {
break
}
gen.Send(item)
}
}
}

// wrapWithCheckpoint wraps an iterator to process checkpoints for interrupt events.
// Returns the original iterator if checkPointID is nil.
func (r *Runner) wrapWithCheckpoint(ctx context.Context, aIter *AsyncIterator[*AgentEvent], checkPointID *string) *AsyncIterator[*AgentEvent] {
if checkPointID == nil {
return aIter // No checkpoint, return as-is
}

// Create intermediate iterator for checkpoint processing
niter, gen := NewAsyncIteratorPair[*AgentEvent]()
go func() {
defer gen.Close()
r.processCheckpoints(ctx, aIter, gen, *checkPointID)
}()
return niter
}

// processCheckpoints handles checkpoint save logic for interrupt events.
func (r *Runner) processCheckpoints(ctx context.Context, aIter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent], checkPointID string) {
var (
interruptSignal *core.InterruptSignal
legacyData any
)

for {
event, ok := aIter.Next()
if !ok {
break
}

// Handle checkpoint for interrupts
if event.Action != nil && event.Action.internalInterrupted != nil {
// Even if multiple interrupts happen, they should be merged into one
// action by CompositeInterrupt, so here in Runner we must assume at most
// one interrupt action happens
if interruptSignal != nil {
panic("multiple interrupt actions should not happen in Runner")
}
interruptSignal = event.Action.internalInterrupted
@@ -210,15 +288,11 @@ func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEven
}
legacyData = event.Action.Interrupted.Data

// save the checkpoint before sending the interrupt event,
// so when the end-user receives the interrupt event, they can resume from this checkpoint
err := r.saveCheckPoint(ctx, checkPointID, &InterruptInfo{
Data: legacyData,
}, interruptSignal)
if err != nil {
gen.Send(&AgentEvent{Err: fmt.Errorf("failed to save checkpoint: %w", err)})
}
}

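To place the new field in context, here is a hedged wiring sketch. The constructor arguments (myAgent, myStore, mySessions) and the import paths are placeholders; only Runner, RunnerConfig, NewRunner, and the SessionService field itself come from this file.

package main

import (
	"context"

	"github.com/cloudwego/eino/adk"     // import path is an assumption for this sketch
	"github.com/cloudwego/eino/compose" // import path is an assumption for this sketch
)

// newSessionRunner is hypothetical wiring; the myAgent, myStore, and mySessions
// arguments stand in for concrete implementations that are not part of this diff.
func newSessionRunner(ctx context.Context, myAgent adk.Agent,
	myStore compose.CheckPointStore, mySessions *adk.SessionService) *adk.Runner {

	return adk.NewRunner(ctx, adk.RunnerConfig{
		Agent:           myAgent,
		EnableStreaming: true,
		CheckPointStore: myStore,    // optional: nil disables checkpointing
		SessionService:  mySessions, // optional: nil disables session load/save around Run
	})
}

As the behavior summary added at the top of runner.go states, a Runner built this way loads and persists session history on Run (when WithSessionID is supplied), while Resume restores from a checkpoint without reloading history but still persists outputs.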