diff --git a/adk/call_option.go b/adk/call_option.go index 31cf56ba..e16fdd3b 100644 --- a/adk/call_option.go +++ b/adk/call_option.go @@ -20,6 +20,8 @@ type options struct { sessionValues map[string]any checkPointID *string skipTransferMessages bool + sessionID *string + sessionOptions []StoreOption } // AgentRunOption is the call option for adk Agent. @@ -55,6 +57,21 @@ func WithSkipTransferMessages() AgentRunOption { }) } +// WithSessionID sets the session ID for the agent run. +func WithSessionID(id string) AgentRunOption { + return WrapImplSpecificOptFn(func(t *options) { + t.sessionID = &id + }) +} + +// WithSessionOptions passes StoreOption (e.g. WithLimit, WithRoundID) to the SessionStore. +// These options control how session history is loaded and saved. +func WithSessionOptions(opts ...StoreOption) AgentRunOption { + return WrapImplSpecificOptFn(func(t *options) { + t.sessionOptions = append(t.sessionOptions, opts...) + }) +} + // WrapImplSpecificOptFn is the option to wrap the implementation specific option function. func WrapImplSpecificOptFn[T any](optFn func(*T)) AgentRunOption { return AgentRunOption{ diff --git a/adk/runner.go b/adk/runner.go index b0e9a486..dbc5df19 100644 --- a/adk/runner.go +++ b/adk/runner.go @@ -16,6 +16,14 @@ package adk +// Runner orchestrates agent execution, session, and checkpoints. +// +// Behavior summary: +// - Run: loads session history (if SessionService provided) and saves new input messages; +// outputs are persisted via SessionService while iterating. +// - Resume: resumes from checkpoints only; session history is not loaded during resume, +// but outputs will still be persisted if SessionService is provided. +// import ( "context" "fmt" @@ -37,6 +45,9 @@ type Runner struct { // store is the checkpoint store used to persist agent state upon interruption. // If nil, checkpointing is disabled. store compose.CheckPointStore + // sessionService manages session persistence and hooks. + // If nil, session persistence is disabled. + sessionService *SessionService } type RunnerConfig struct { @@ -44,6 +55,8 @@ type RunnerConfig struct { EnableStreaming bool CheckPointStore compose.CheckPointStore + // SessionService is optional. If provided, enables session persistence. + SessionService *SessionService } // ResumeParams contains all parameters needed to resume an execution. @@ -61,6 +74,7 @@ func NewRunner(_ context.Context, conf RunnerConfig) *Runner { enableStreaming: conf.EnableStreaming, a: conf.Agent, store: conf.CheckPointStore, + sessionService: conf.SessionService, } } @@ -83,14 +97,38 @@ func (r *Runner) Run(ctx context.Context, messages []Message, AddSessionValues(ctx, o.sessionValues) + if r.sessionService != nil && o.sessionID != nil { + // load history from session + history, err := r.sessionService.load(ctx, *o.sessionID, o.sessionOptions...) + if err != nil { + niter, gen := NewAsyncIteratorPair[*AgentEvent]() + gen.Send(&AgentEvent{Err: fmt.Errorf("failed to load session history: %w", err)}) + gen.Close() + return niter + } + // Add history to session context (not input messages) + session := getSession(ctx) + for _, event := range history { + session.addEvent(event) + } + + // Save new input messages to session + if err := r.sessionService.saveInput(ctx, *o.sessionID, messages, o.sessionOptions...); err != nil { + niter, gen := NewAsyncIteratorPair[*AgentEvent]() + gen.Send(&AgentEvent{Err: fmt.Errorf("failed to save input to session: %w", err)}) + gen.Close() + return niter + } + } + iter := fa.Run(ctx, input, opts...) 
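+	// Fast path: with neither a checkpoint store nor a session service configured,
+	// there is nothing left to persist, so the agent's iterator is returned directly.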
- if r.store == nil { + if r.store == nil && r.sessionService == nil { return iter } niter, gen := NewAsyncIteratorPair[*AgentEvent]() - go r.handleIter(ctx, iter, gen, o.checkPointID) + go r.handleIter(ctx, iter, gen, o.checkPointID, o.sessionID, o.sessionOptions) return niter } @@ -156,18 +194,18 @@ func (r *Runner) resume(ctx context.Context, checkPointID string, resumeData map fa := toFlowAgent(ctx, r.a) aIter := fa.Resume(ctx, resumeInfo, opts...) - if r.store == nil { + if r.store == nil && r.sessionService == nil { return aIter, nil } niter, gen := NewAsyncIteratorPair[*AgentEvent]() - go r.handleIter(ctx, aIter, gen, &checkPointID) + go r.handleIter(ctx, aIter, gen, &checkPointID, o.sessionID, o.sessionOptions) return niter, nil } func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEvent], - gen *AsyncGenerator[*AgentEvent], checkPointID *string) { + gen *AsyncGenerator[*AgentEvent], checkPointID *string, sessionID *string, sessionOptions []StoreOption) { defer func() { panicErr := recover() if panicErr != nil { @@ -177,21 +215,61 @@ func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEven gen.Close() }() + + // Wrap iterator with checkpoint handling if needed + processedIter := r.wrapWithCheckpoint(ctx, aIter, checkPointID) + + // Handle session or just forward events + if r.sessionService != nil && sessionID != nil { + // Delegate to SessionService for session handling + r.sessionService.saveOutput(ctx, *sessionID, processedIter, gen, sessionOptions...) + } else { + // Just forward all events + for { + item, ok := processedIter.Next() + if !ok { + break + } + gen.Send(item) + } + } +} + +// wrapWithCheckpoint wraps an iterator to process checkpoints for interrupt events. +// Returns the original iterator if checkPointID is nil. +func (r *Runner) wrapWithCheckpoint(ctx context.Context, aIter *AsyncIterator[*AgentEvent], checkPointID *string) *AsyncIterator[*AgentEvent] { + if checkPointID == nil { + return aIter // No checkpoint, return as-is + } + + // Create intermediate iterator for checkpoint processing + niter, gen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer gen.Close() + r.processCheckpoints(ctx, aIter, gen, *checkPointID) + }() + return niter +} + +// processCheckpoints handles checkpoint save logic for interrupt events. 
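+// As the original inline comment noted, the checkpoint is saved before the interrupt
+// event reaches the end user, so the checkpoint is already resumable when the
+// interrupt is observed.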
+func (r *Runner) processCheckpoints(ctx context.Context, aIter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent], checkPointID string) {
 	var (
 		interruptSignal *core.InterruptSignal
 		legacyData      any
 	)
+
 	for {
 		event, ok := aIter.Next()
 		if !ok {
 			break
 		}
+		// Handle checkpoint for interrupts
 		if event.Action != nil && event.Action.internalInterrupted != nil {
+			// Even if multiple interrupts happen, they should be merged into one
+			// action by CompositeInterrupt, so here in Runner we assume at most
+			// one interrupt action occurs.
 			if interruptSignal != nil {
-				// even if multiple interrupt happens, they should be merged into one
-				// action by CompositeInterrupt, so here in Runner we must assume at most
-				// one interrupt action happens
 				panic("multiple interrupt actions should not happen in Runner")
 			}
 			interruptSignal = event.Action.internalInterrupted
@@ -210,15 +288,11 @@ func (r *Runner) handleIter(ctx context.Context, aIter *AsyncIterator[*AgentEven
 			}
 			legacyData = event.Action.Interrupted.Data
-			if checkPointID != nil {
-				// save checkpoint first before sending interrupt event,
-				// so when end-user receives interrupt event, they can resume from this checkpoint
-				err := r.saveCheckPoint(ctx, *checkPointID, &InterruptInfo{
-					Data: legacyData,
-				}, interruptSignal)
-				if err != nil {
-					gen.Send(&AgentEvent{Err: fmt.Errorf("failed to save checkpoint: %w", err)})
-				}
+			err := r.saveCheckPoint(ctx, checkPointID, &InterruptInfo{
+				Data: legacyData,
+			}, interruptSignal)
+			if err != nil {
+				gen.Send(&AgentEvent{Err: fmt.Errorf("failed to save checkpoint: %w", err)})
 			}
 		}
diff --git a/adk/session.go b/adk/session.go
new file mode 100644
index 00000000..1532fb28
--- /dev/null
+++ b/adk/session.go
@@ -0,0 +1,609 @@
+/*
+ * Copyright 2025 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package adk
+
+// Package adk provides Agent Development Kit primitives for building multi-turn agents.
+//
+// Session vs Checkpoint vs Memory:
+// - Session: durable conversation history across runs. Loaded during Run, saved for inputs/outputs.
+// - Checkpoint: per-run snapshots on interruption, used only for Resume flows.
+// - Memory: long-term context injection/recall, orthogonal and application-specific.
+//
+// Using SessionService:
+// 1) Create a SessionStore (e.g., InMemorySessionStore for tests or a production store).
+// 2) Create SessionService with optional handlers:
+//    - AfterGetSessionHandler: batch transforms on loaded history (summarize/compact).
+//    - BeforeAddSessionHandler: per-event filter/transform before saving; return nil to skip.
+// 3) Pass SessionService in RunnerConfig and run with the WithSessionID option.
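+//
+// A minimal sketch of the flow above (myAgent stands in for any Agent implementation):
+//
+//	ctx := context.Background()
+//	store := NewInMemorySessionStore()
+//	service := NewSessionService(store)
+//	runner := NewRunner(ctx, RunnerConfig{Agent: myAgent, SessionService: service})
+//	iter := runner.Run(ctx, []Message{schema.UserMessage("hi")},
+//		WithSessionID("app_123:user_456:session_789"),
+//		WithSessionOptions(WithLimit(10)))
+//	// consume events with iter.Next() until ok == false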
+//
+// Handler examples:
+//	// Summarize assistant messages after load
+//	h1 := func(ctx context.Context, evs []*AgentEvent) ([]*AgentEvent, error) { /* modify evs */ return evs, nil }
+//	// Skip saving specific assistant replies
+//	h2 := func(ctx context.Context, ev *AgentEvent) (*AgentEvent, error) { /* return nil to skip */ return ev, nil }
+//
+// Retry Policy Guidance:
+// Implementations of SessionStore should handle transient errors internally with retry/backoff.
+// For example, a production store might wrap Add/Set with exponential backoff on network timeouts.
+// Pseudocode:
+//	for attempt := 0; attempt < max; attempt++ { err = store.Add(...); if transient(err) { sleep(backoff(attempt)); continue } break }
+// Errors returned from SessionStore methods should therefore represent only failures that are
+// non-retryable or that have exhausted retries; SessionService treats them as CRITICAL.
+
+import (
+	"bytes"
+	"context"
+	"encoding/gob"
+	"fmt"
+	"log"
+	"sync"
+)
+
+// StoreOption allows passing implementation-specific filters and metadata to SessionStore.
+// Different store implementations can define their own option types and process them accordingly.
+// This follows the same pattern as AgentRunOption for maximum flexibility.
+//
+// Example Usage:
+//
+//	// InMemorySessionStore defines its own options
+//	type InMemoryStoreOptions struct {
+//		Limit   int
+//		Offset  int
+//		RoundID string
+//	}
+//
+//	// Helper functions create options using WrapStoreImplSpecificOptFn
+//	func WithLimit(n int) StoreOption {
+//		return WrapStoreImplSpecificOptFn(func(o *InMemoryStoreOptions) {
+//			o.Limit = n
+//		})
+//	}
+//
+//	// Store implementation extracts options using GetStoreImplSpecificOptions
+//	func (s *InMemorySessionStore) Get(ctx context.Context, sessionID string, opts ...StoreOption) ([][]byte, error) {
+//		o := GetStoreImplSpecificOptions(&InMemoryStoreOptions{}, opts...)
+//		// Use o.Limit, o.Offset, o.RoundID
+//	}
+type StoreOption struct {
+	implSpecificOptFn any
+}
+
+// WrapStoreImplSpecificOptFn wraps an implementation-specific option function into a StoreOption.
+// This allows different SessionStore implementations to define their own option types.
+func WrapStoreImplSpecificOptFn[T any](optFn func(*T)) StoreOption {
+	return StoreOption{
+		implSpecificOptFn: optFn,
+	}
+}
+
+// GetStoreImplSpecificOptions extracts implementation-specific options from a StoreOption list.
+// Usage:
+//
+//	opts := GetStoreImplSpecificOptions(&MyStoreOptions{}, storeOpts...)
+func GetStoreImplSpecificOptions[T any](base *T, opts ...StoreOption) *T {
+	if base == nil {
+		base = new(T)
+	}
+
+	for i := range opts {
+		opt := opts[i]
+		if opt.implSpecificOptFn != nil {
+			optFn, ok := opt.implSpecificOptFn.(func(*T))
+			if ok {
+				optFn(base)
+			}
+		}
+	}
+
+	return base
+}
+
+// SessionStore is the interface for persisting session history.
+//
+// SessionID Design:
+// The sessionID should encapsulate all application-level metadata needed for isolation and routing.
+// Two common patterns:
+// 1. Structured ID: "app_123:user_456:session_789" (implementation can parse for sharding/indexing)
+// 2. Globally Unique ID: Use UUIDs or similar, ensuring uniqueness across all users/apps
+//
+// The framework treats sessionID as an opaque string; metadata extraction is the store's responsibility.
+//
+// Deletion Policy:
+// This interface intentionally omits a Delete method. Session cleanup (e.g., GDPR compliance, TTL expiry)
+// is an application-level concern, not a runtime execution concern. Implementations should provide
+// deletion capabilities outside this interface (e.g., admin APIs, background jobs).
+//
+// Error Handling:
+// Implementations should encapsulate retry logic for transient failures.
+// Any error returned from these methods represents a failure that cannot be retried or has
+// exhausted retries, and will be treated as CRITICAL by SessionService.
+// Implementations should also handle concurrency control (e.g., locking) if needed.
+type SessionStore interface {
+	// Get retrieves the session history from the store.
+	// Accepts options like WithLimit, WithOffset for filtering.
+	Get(ctx context.Context, sessionID string, opts ...StoreOption) ([][]byte, error)
+	// Add appends new entries to the session history.
+	// Accepts options like WithRoundID for metadata attachment.
+	Add(ctx context.Context, sessionID string, entries [][]byte, opts ...StoreOption) error
+	// Set overwrites the session history with the given entries.
+	// Accepts options like WithRoundID for metadata attachment.
+	Set(ctx context.Context, sessionID string, entries [][]byte, opts ...StoreOption) error
+}
+
+// AfterGetSessionHandler processes loaded events in batch (for summarization, compaction).
+// After processing, the resulting events are injected into the agent's run context.
+type AfterGetSessionHandler func(ctx context.Context, events []*AgentEvent) ([]*AgentEvent, error)
+
+// BeforeAddSessionHandler processes individual events before saving.
+// It applies to events from both saveOutput (agent execution) and saveInput (user input).
+// Returning nil means the event should not be saved.
+type BeforeAddSessionHandler func(ctx context.Context, event *AgentEvent) (*AgentEvent, error)
+
+// EventSerializer defines how AgentEvents are converted to/from bytes.
+// Implementations can use any format (e.g., gob, json, protobuf).
+type EventSerializer interface {
+	Marshal(event *AgentEvent) ([]byte, error)
+	Unmarshal(data []byte) (*AgentEvent, error)
+}
+
+// defaultGobSerializer implements EventSerializer using encoding/gob.
+type defaultGobSerializer struct{}
+
+// Marshal converts an AgentEvent to gob-encoded bytes.
+func (s *defaultGobSerializer) Marshal(event *AgentEvent) ([]byte, error) {
+	var buf bytes.Buffer
+	if err := gob.NewEncoder(&buf).Encode(event); err != nil {
+		return nil, err
+	}
+	return buf.Bytes(), nil
+}
+
+// Unmarshal converts gob-encoded bytes to an AgentEvent.
+func (s *defaultGobSerializer) Unmarshal(data []byte) (*AgentEvent, error) {
+	var event AgentEvent
+	if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&event); err != nil {
+		return nil, err
+	}
+	return &event, nil
+}
+
+// SessionService manages the retrieval and persistence of session history.
+// It acts as middleware between the Runner and the SessionStore.
+type SessionService struct {
+	store           SessionStore
+	serializer      EventSerializer
+	beforeAdd       []BeforeAddSessionHandler
+	afterGet        []AfterGetSessionHandler
+	persistAfterGet bool
+}
+
+// SessionServiceOption configures a SessionService.
+type SessionServiceOption func(*SessionService)
+
+// WithAfterGetSession adds handlers that run after loading session data.
+func WithAfterGetSession(handlers ...AfterGetSessionHandler) SessionServiceOption {
+	return func(s *SessionService) {
+		s.afterGet = append(s.afterGet, handlers...)
+	}
+}
+
+// WithBeforeAddSession adds handlers that run before saving session data.
+// These handlers are applied to both saveOutput and saveInput operations.
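+//
+// A sketch of a filtering handler (the content check is illustrative only):
+//
+//	skipDrafts := func(ctx context.Context, ev *AgentEvent) (*AgentEvent, error) {
+//		if ev.Output != nil && ev.Output.MessageOutput != nil {
+//			if m := ev.Output.MessageOutput.Message; m != nil && m.Content == "draft" {
+//				return nil, nil // returning nil skips persisting this event
+//			}
+//		}
+//		return ev, nil
+//	}
+//	service := NewSessionService(store, WithBeforeAddSession(skipDrafts))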
+func WithBeforeAddSession(handlers ...BeforeAddSessionHandler) SessionServiceOption {
+	return func(s *SessionService) {
+		s.beforeAdd = append(s.beforeAdd, handlers...)
+	}
+}
+
+// WithPersistAfterGetSession enables automatic persistence of modified events after AfterGetSession handlers.
+func WithPersistAfterGetSession() SessionServiceOption {
+	return func(s *SessionService) { s.persistAfterGet = true }
+}
+
+// WithEventSerializer configures a custom serializer for the SessionService.
+func WithEventSerializer(serializer EventSerializer) SessionServiceOption {
+	return func(s *SessionService) {
+		s.serializer = serializer
+	}
+}
+
+// NewSessionService creates a new SessionService with the given store and options.
+func NewSessionService(store SessionStore, opts ...SessionServiceOption) *SessionService {
+	s := &SessionService{
+		store:      store,
+		serializer: &defaultGobSerializer{},
+	}
+	for _, opt := range opts {
+		opt(s)
+	}
+	return s
+}
+
+// MarshalEvent converts an AgentEvent to bytes using the configured serializer.
+func (s *SessionService) MarshalEvent(event *AgentEvent) ([]byte, error) {
+	return s.serializer.Marshal(event)
+}
+
+// UnmarshalEvent converts bytes to an AgentEvent using the configured serializer.
+func (s *SessionService) UnmarshalEvent(data []byte) (*AgentEvent, error) {
+	return s.serializer.Unmarshal(data)
+}
+
+// marshalEvents serializes a batch of events using the configured serializer.
+// The label tailors error messages to the caller context (e.g., "event", "input event").
+func (s *SessionService) marshalEvents(events []*AgentEvent, label string) ([][]byte, error) {
+	if len(events) == 0 {
+		return nil, nil
+	}
+	entries := make([][]byte, 0, len(events))
+	for _, event := range events {
+		b, err := s.serializer.Marshal(event)
+		if err != nil {
+			return nil, fmt.Errorf("failed to marshal %s: %w", label, err)
+		}
+		entries = append(entries, b)
+	}
+	return entries, nil
+}
+
+// unmarshalEntries deserializes a batch of entries into events.
+func (s *SessionService) unmarshalEntries(entries [][]byte) ([]*AgentEvent, error) {
+	if len(entries) == 0 {
+		return nil, nil
+	}
+	events := make([]*AgentEvent, 0, len(entries))
+	for _, entry := range entries {
+		event, err := s.serializer.Unmarshal(entry)
+		if err != nil {
+			return nil, fmt.Errorf("failed to deserialize session entry: %w", err)
+		}
+		events = append(events, event)
+	}
+	return events, nil
+}
+
+// applyBeforeAdd runs the beforeAdd handler chain with consistent logging.
+// op labels the event kind in warning logs (e.g., "event", "input event").
+// Returns the transformed event; nil means "don't save".
+func (s *SessionService) applyBeforeAdd(ctx context.Context, event *AgentEvent, op string) *AgentEvent {
+	current := event
+	for i, handler := range s.beforeAdd {
+		transformed, err := handler(ctx, current)
+		if err != nil {
+			log.Printf("[WARN] BeforeAddSession handler[%d] failed for %s: %v, using last successful version", i, op, err)
+			break
+		}
+		if transformed == nil {
+			current = nil
+			break
+		}
+		current = transformed
+	}
+	return current
+}
+
+// load retrieves agent events from the session store.
+// CRITICAL errors (store.Get, deserialization) abort and return error.
+// NON-CRITICAL errors (handlers, persistAfterGet) log warnings and continue.
+func (s *SessionService) load(ctx context.Context, sessionID string, opts ...StoreOption) ([]*AgentEvent, error) {
+	// 1. Load from store (CRITICAL)
+	data, err := s.store.Get(ctx, sessionID, opts...)
+	if err != nil {
+		return nil, fmt.Errorf("failed to load session (sessionID=%s): %w", sessionID, err)
+	}
+
+	// 2. Deserialize (CRITICAL)
+	events, err := s.unmarshalEntries(data)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode session entries (sessionID=%s): %w", sessionID, err)
+	}
+
+	// 3. Apply AfterGetSession handlers (NON-CRITICAL)
+	for i, handler := range s.afterGet {
+		modifiedEvents, err := handler(ctx, events)
+		if err != nil {
+			log.Printf("[WARN] AfterGetSession handler[%d] failed (sessionID=%s): %v, skipping remaining handlers", i, sessionID, err)
+			break
+		}
+		events = modifiedEvents
+	}
+
+	// 4. PersistAfterGet if enabled (NON-CRITICAL)
+	if s.persistAfterGet {
+		if err := s.persistEvents(ctx, sessionID, events); err != nil {
+			log.Printf("[WARN] PersistAfterGet failed (sessionID=%s): %v", sessionID, err)
+		}
+	}
+
+	return events, nil
+}
+
+// saveOutput consumes events from an iterator, applies handlers, and saves to session.
+// It emits the transformed event when that event will be saved; when a handler filters
+// an event out, the original event is still emitted downstream, just not persisted.
+// CRITICAL errors (serialization, store.Add) are sent as error events to the generator.
+// NON-CRITICAL errors (handlers) log warnings and continue.
+func (s *SessionService) saveOutput(
+	ctx context.Context,
+	sessionID string,
+	iter *AsyncIterator[*AgentEvent],
+	gen *AsyncGenerator[*AgentEvent],
+	opts ...StoreOption,
+) {
+	var toSave []*AgentEvent
+
+	// 1. Consume iterator and apply handlers
+	for {
+		event, ok := iter.Next()
+		if !ok {
+			break
+		}
+
+		// Built-in filter: skip interrupt events (temporary state, not persisted)
+		if event.Action != nil && event.Action.Interrupted != nil {
+			gen.Send(event) // Emit original interrupt
+			continue        // Don't save
+		}
+
+		// Apply BeforeAddSession handlers with chaining
+		current := s.applyBeforeAdd(ctx, event, "event")
+
+		if current != nil {
+			gen.Send(current) // Emit transformed
+			toSave = append(toSave, current)
+		} else {
+			gen.Send(event) // Emit original (not saved)
+		}
+	}
+
+	// 2. Serialize (CRITICAL)
+	entries, err := s.marshalEvents(toSave, "event")
+	if err != nil {
+		gen.Send(&AgentEvent{Err: fmt.Errorf("failed to serialize events (sessionID=%s): %w", sessionID, err)})
+		return
+	}
+
+	// 3. Add to store (CRITICAL)
+	if len(entries) > 0 {
+		if err = s.store.Add(ctx, sessionID, entries, opts...); err != nil {
+			errorEvent := &AgentEvent{Err: fmt.Errorf("failed to save session (sessionID=%s): %w", sessionID, err)}
+			gen.Send(errorEvent)
+		}
+	}
+}
+
+// saveInput persists input messages to the session store.
+// Messages are converted to AgentEvents, processed by handlers, and saved.
+// CRITICAL errors (serialization, store.Add) return error.
+// NON-CRITICAL errors (handlers) log warnings and continue.
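+// Each input message is wrapped as an AgentEvent with only MessageOutput set, so the
+// same BeforeAddSessionHandler chain and serializer apply to inputs and outputs alike.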
+func (s *SessionService) saveInput(ctx context.Context, sessionID string, input []Message, opts ...StoreOption) error {
+	var toSave []*AgentEvent
+
+	for _, msg := range input {
+		event := &AgentEvent{
+			Output: &AgentOutput{
+				MessageOutput: &MessageVariant{
+					Message: msg,
+				},
+			},
+		}
+
+		// Apply BeforeAddSession handlers
+		current := s.applyBeforeAdd(ctx, event, "input event")
+
+		if current != nil {
+			toSave = append(toSave, current)
+		}
+	}
+
+	if len(toSave) == 0 {
+		return nil
+	}
+
+	// Serialize (CRITICAL)
+	entries, err := s.marshalEvents(toSave, "input event")
+	if err != nil {
+		return fmt.Errorf("failed to serialize input (sessionID=%s): %w", sessionID, err)
+	}
+
+	// Add to store (CRITICAL)
+	if err = s.store.Add(ctx, sessionID, entries, opts...); err != nil {
+		return fmt.Errorf("failed to save input to session (sessionID=%s): %w", sessionID, err)
+	}
+
+	return nil
+}
+
+// persistEvents serializes and persists events to the store (used by PersistAfterGet).
+func (s *SessionService) persistEvents(ctx context.Context, sessionID string, events []*AgentEvent) error {
+	entries, err := s.marshalEvents(events, "event")
+	if err != nil {
+		return err
+	}
+	return s.store.Set(ctx, sessionID, entries)
+}
+
+// InMemoryStoreOptions defines options specific to InMemorySessionStore.
+type InMemoryStoreOptions struct {
+	// Limit specifies the maximum number of events to return (latest N)
+	Limit int
+	// Offset specifies how many events to skip from the beginning
+	Offset int
+	// RoundID attaches metadata when saving or filters when querying
+	RoundID string
+}
+
+// WithLimit requests the latest N events from InMemorySessionStore.
+// Commonly used for "load last 10 turns" scenarios.
+func WithLimit(n int) StoreOption {
+	return WrapStoreImplSpecificOptFn(func(o *InMemoryStoreOptions) {
+		o.Limit = n
+	})
+}
+
+// WithOffset skips the first N events in InMemorySessionStore.
+// Typically used in combination with WithLimit for pagination.
+func WithOffset(n int) StoreOption {
+	return WrapStoreImplSpecificOptFn(func(o *InMemoryStoreOptions) {
+		o.Offset = n
+	})
+}
+
+// WithRoundID attaches or filters by round identifier in InMemorySessionStore.
+// When used with Add/Set, attaches the roundID metadata to stored events.
+// When used with Get, filters events to return only those matching the roundID.
+func WithRoundID(id string) StoreOption {
+	return WrapStoreImplSpecificOptFn(func(o *InMemoryStoreOptions) {
+		o.RoundID = id
+	})
+}
+
+// InMemorySessionStore is a simple in-memory implementation of SessionStore.
+//
+// Scope and Behavior:
+// - Intended for testing and non-production scenarios; data is process-local and volatile.
+// - Serializes store operations per sessionID via a mutex; does NOT enforce single-runner policy.
+//   Multiple runners may still target the same sessionID at the application level.
+// - Get returns a freshly allocated top-level slice, so callers cannot reorder or truncate
+//   internal state; the underlying entry bytes are shared and must not be mutated.
+// - Supports metadata attachment (RoundID) to demonstrate how stores can index/filter by metadata.
+//
+// This store is useful for tests and examples where persistence is not required.
+// Production deployments should use a resilient store with retries and durability guarantees.
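+//
+// A sketch of what such a resilient store might look like (retryStore, transient,
+// and backoff are hypothetical, not part of this package):
+//
+//	func (r *retryStore) Add(ctx context.Context, id string, entries [][]byte, opts ...StoreOption) error {
+//		var err error
+//		for attempt := 0; attempt < r.maxAttempts; attempt++ {
+//			if err = r.inner.Add(ctx, id, entries, opts...); err == nil || !transient(err) {
+//				return err
+//			}
+//			time.Sleep(backoff(attempt)) // exponential backoff between attempts
+//		}
+//		return err // exhausted retries: surfaced to SessionService as CRITICAL
+//	}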
+type InMemorySessionStore struct { + data map[string][]storedEntry + locks map[string]*sync.Mutex + mu sync.Mutex // Protects the maps themselves +} + +// storedEntry represents a single session event with optional metadata +type storedEntry struct { + data []byte + roundID string +} + +// NewInMemorySessionStore creates a new in-memory session store. +func NewInMemorySessionStore() *InMemorySessionStore { + return &InMemorySessionStore{ + data: make(map[string][]storedEntry), + locks: make(map[string]*sync.Mutex), + } +} + +// getSessionLock returns the mutex for the given sessionID, creating one if it doesn't exist. +func (s *InMemorySessionStore) getSessionLock(sessionID string) *sync.Mutex { + s.mu.Lock() + defer s.mu.Unlock() + + if _, exists := s.locks[sessionID]; !exists { + s.locks[sessionID] = &sync.Mutex{} + } + return s.locks[sessionID] +} + +// Get retrieves the session history from the store. +// This implementation supports WithLimit, WithOffset, and WithRoundID options. +func (s *InMemorySessionStore) Get(_ context.Context, sessionID string, opts ...StoreOption) ([][]byte, error) { + lock := s.getSessionLock(sessionID) + lock.Lock() + defer lock.Unlock() + + entries := s.data[sessionID] + if entries == nil { + return [][]byte{}, nil + } + + // Extract options + o := GetStoreImplSpecificOptions(&InMemoryStoreOptions{}, opts...) + + // Apply RoundID filter first + filtered := entries + if o.RoundID != "" { + filtered = make([]storedEntry, 0, len(entries)) + for _, entry := range entries { + if entry.roundID == o.RoundID { + filtered = append(filtered, entry) + } + } + } + + // Apply Offset (skip first N) + if o.Offset > 0 { + if o.Offset >= len(filtered) { + return [][]byte{}, nil + } + filtered = filtered[o.Offset:] + } + + // Apply Limit (latest N after offset and filter) + if o.Limit > 0 && len(filtered) > o.Limit { + // Return the latest N events (from the end of the remaining slice) + filtered = filtered[len(filtered)-o.Limit:] + } + + // Extract data bytes from filtered entries + result := make([][]byte, len(filtered)) + for i, entry := range filtered { + result[i] = entry.data + } + + return result, nil +} + +// Add appends new entries to the session history. +// Supports WithRoundID option for metadata attachment. +func (s *InMemorySessionStore) Add(_ context.Context, sessionID string, entries [][]byte, opts ...StoreOption) error { + lock := s.getSessionLock(sessionID) + lock.Lock() + defer lock.Unlock() + + // Extract options + o := GetStoreImplSpecificOptions(&InMemoryStoreOptions{}, opts...) + + // Wrap entries with metadata + for _, data := range entries { + s.data[sessionID] = append(s.data[sessionID], storedEntry{ + data: data, + roundID: o.RoundID, + }) + } + + return nil +} + +// Set overwrites the session history with the given entries. +// Supports WithRoundID option for metadata attachment. +func (s *InMemorySessionStore) Set(_ context.Context, sessionID string, entries [][]byte, opts ...StoreOption) error { + lock := s.getSessionLock(sessionID) + lock.Lock() + defer lock.Unlock() + + // Extract options + o := GetStoreImplSpecificOptions(&InMemoryStoreOptions{}, opts...) 
+ + // Wrap entries with metadata + newEntries := make([]storedEntry, 0, len(entries)) + for _, data := range entries { + newEntries = append(newEntries, storedEntry{ + data: data, + roundID: o.RoundID, + }) + } + + s.data[sessionID] = newEntries + return nil +} diff --git a/adk/session_test.go b/adk/session_test.go new file mode 100644 index 00000000..a7f83727 --- /dev/null +++ b/adk/session_test.go @@ -0,0 +1,689 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package adk + +import ( + "bytes" + "context" + "encoding/gob" + "encoding/json" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +type errorAddStore struct { + data map[string][][]byte + addErr error +} + +func newErrorAddStore(addErr error) *errorAddStore { + return &errorAddStore{data: make(map[string][][]byte), addErr: addErr} +} + +func (e *errorAddStore) Get(_ context.Context, sessionID string, _ ...StoreOption) ([][]byte, error) { + return e.data[sessionID], nil +} + +func (e *errorAddStore) Add(_ context.Context, _ string, _ [][]byte, _ ...StoreOption) error { + return e.addErr +} + +func (e *errorAddStore) Set(_ context.Context, sessionID string, entries [][]byte, _ ...StoreOption) error { + e.data[sessionID] = entries + return nil +} + +func TestSession_Saving_FiltersAndHandlers(t *testing.T) { + ctx := context.Background() + store := NewInMemorySessionStore() + sessionID := "saving_filters_handlers" + + skipHandler := func(ctx context.Context, event *AgentEvent) (*AgentEvent, error) { + if event.Output != nil && event.Output.MessageOutput != nil { + m := event.Output.MessageOutput.Message + if m != nil && m.Role == schema.Assistant && m.Content == "Skip this" { + return nil, nil + } + } + return event, nil + } + + service := NewSessionService(store, WithBeforeAddSession(skipHandler)) + + mockAgent_ := newMockRunnerAgent("TestAgent", "Test agent", []*AgentEvent{ + {AgentName: "TestAgent", Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("Skip this", nil)}}}, + {AgentName: "TestAgent", Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("Save this", nil)}}}, + {AgentName: "TestAgent", Action: &AgentAction{Interrupted: &InterruptInfo{}}}, + }) + + runner := NewRunner(ctx, RunnerConfig{Agent: mockAgent_, SessionService: service}) + iter := runner.Run(ctx, []Message{schema.UserMessage("Hello")}, WithSessionID(sessionID)) + + var received []*AgentEvent + for { + ev, ok := iter.Next() + if !ok { + break + } + received = append(received, ev) + } + assert.Equal(t, 3, len(received)) + assert.Equal(t, "Skip this", received[0].Output.MessageOutput.Message.Content) + assert.Equal(t, "Save this", received[1].Output.MessageOutput.Message.Content) + + storedData, _ := store.Get(ctx, sessionID) + assert.Equal(t, 2, len(storedData)) + + var savedInput, savedResp *AgentEvent + _ = gob.NewDecoder(bytes.NewReader(storedData[0])).Decode(&savedInput) 
+ _ = gob.NewDecoder(bytes.NewReader(storedData[1])).Decode(&savedResp) + assert.Equal(t, "Hello", savedInput.Output.MessageOutput.Message.Content) + assert.Equal(t, "Save this", savedResp.Output.MessageOutput.Message.Content) +} + +func TestSession_BeforeAdd_ChainAndError(t *testing.T) { + ctx := context.Background() + store := NewInMemorySessionStore() + sessionID := "before_add_chain_error" + + failingHandler := func(ctx context.Context, event *AgentEvent) (*AgentEvent, error) { + return nil, errors.New("handler failed") + } + successHandler := func(ctx context.Context, event *AgentEvent) (*AgentEvent, error) { + modified := *event + if modified.AgentName == "" { + modified.AgentName = "Input" + } + modified.AgentName = "Modified_" + modified.AgentName + return &modified, nil + } + + service := NewSessionService(store, WithBeforeAddSession(successHandler, failingHandler)) + + mockAgent_ := newMockRunnerAgent("TestAgent", "Test agent", []*AgentEvent{ + {AgentName: "TestAgent", Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("Response", nil)}}}, + }) + + runner := NewRunner(ctx, RunnerConfig{Agent: mockAgent_, SessionService: service}) + iter := runner.Run(ctx, []Message{schema.UserMessage("Hello")}, WithSessionID(sessionID)) + + var received *AgentEvent + for { + ev, ok := iter.Next() + if !ok { + break + } + received = ev + } + if received == nil { + t.Fatal("No event received") + } + assert.Equal(t, "Modified_TestAgent", received.AgentName) + + data, _ := store.Get(ctx, sessionID) + assert.Equal(t, 2, len(data)) + var in, out *AgentEvent + _ = gob.NewDecoder(bytes.NewReader(data[0])).Decode(&in) + _ = gob.NewDecoder(bytes.NewReader(data[1])).Decode(&out) + assert.Equal(t, "Modified_Input", in.AgentName) + assert.Equal(t, "Modified_TestAgent", out.AgentName) +} + +func TestSession_LoadAndSaveInput(t *testing.T) { + ctx := context.Background() + store := NewInMemorySessionStore() + sessionID := "load_and_save" + + // 1. Basic Load/Save and UnmarshalEvent + prev := &AgentEvent{Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.UserMessage("Previous message")}}} + buf := &bytes.Buffer{} + _ = gob.NewEncoder(buf).Encode(prev) + _ = store.Add(ctx, sessionID, [][]byte{buf.Bytes()}) + + service := NewSessionService(store) + mockAgent_ := newMockRunnerAgent("TestAgent", "Test agent", []*AgentEvent{ + {AgentName: "TestAgent", Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("Response", nil)}}}, + }) + runner := NewRunner(ctx, RunnerConfig{Agent: mockAgent_, SessionService: service}) + + iter := runner.Run(ctx, []Message{schema.UserMessage("New message")}, WithSessionID(sessionID)) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + data, _ := store.Get(ctx, sessionID) + assert.Equal(t, 3, len(data)) + var h, in, out *AgentEvent + _ = gob.NewDecoder(bytes.NewReader(data[0])).Decode(&h) + _ = gob.NewDecoder(bytes.NewReader(data[1])).Decode(&in) + _ = gob.NewDecoder(bytes.NewReader(data[2])).Decode(&out) + assert.Equal(t, "Previous message", h.Output.MessageOutput.Message.Content) + assert.Equal(t, "New message", in.Output.MessageOutput.Message.Content) + assert.Equal(t, "Response", out.Output.MessageOutput.Message.Content) + + // Verify UnmarshalEvent helper + unmarshaled, err := service.UnmarshalEvent(data[2]) + assert.NoError(t, err) + assert.Equal(t, "Response", unmarshaled.Output.MessageOutput.Message.Content) + + // 2. 
Runner with Session Options (WithLimit) + // Add more events to test limiting + addTestEvents(ctx, store, sessionID, 10, "History ") + // Store now has 3 (initial) + 10 (added) = 13 events + + // Run with limit + iter = runner.Run(ctx, []Message{schema.UserMessage("New run")}, + WithSessionID(sessionID), + WithSessionOptions(WithLimit(5))) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + // Verify total count increased by 2 (Input + Response) + finalData, _ := store.Get(ctx, sessionID) + assert.Equal(t, 15, len(finalData)) +} + +func TestSession_SerializationAndStoreErrors(t *testing.T) { + ctx := context.Background() + store := NewInMemorySessionStore() + sessionID := "round_trip_and_errors" + + service := NewSessionService(store) + + events := []*AgentEvent{ + {Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.UserMessage("Hello")}}}, + {AgentName: "Agent", Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("World", nil)}}}, + } + err := service.persistEvents(ctx, sessionID, events) + assert.NoError(t, err) + loaded, err := service.load(ctx, sessionID) + assert.NoError(t, err) + assert.Equal(t, 2, len(loaded)) + assert.Equal(t, "Hello", loaded[0].Output.MessageOutput.Message.Content) + assert.Equal(t, "World", loaded[1].Output.MessageOutput.Message.Content) + + badID := "bad_bytes" + _ = store.Set(ctx, badID, [][]byte{[]byte("not-gob-data")}) + _, err = service.load(ctx, badID) + assert.Error(t, err) + + errStore := newErrorAddStore(errors.New("store add failed")) + s2 := NewSessionService(errStore) + mockAgent_ := newMockRunnerAgent("TestAgent", "Test agent", []*AgentEvent{ + {AgentName: "TestAgent", Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("Response", nil)}}}, + }) + runner := NewRunner(ctx, RunnerConfig{Agent: mockAgent_, SessionService: s2}) + iter := runner.Run(ctx, []Message{schema.UserMessage("Hi")}, WithSessionID("add-error")) + var gotErrEvent bool + for { + ev, ok := iter.Next() + if !ok { + break + } + if ev.Err != nil { + gotErrEvent = true + } + } + assert.True(t, gotErrEvent) + + err = s2.saveInput(ctx, "sid", []Message{schema.UserMessage("Hi")}) + assert.Error(t, err) +} + +func TestSessionService_AfterGet_PersistAndRun_Integration(t *testing.T) { + ctx := context.Background() + store := NewInMemorySessionStore() + sessionID := "after_get_persist_integration" + + // Pre-populate store with history: user + assistant + prevUser := &AgentEvent{Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.UserMessage("Prev1")}}} + prevAssistant := &AgentEvent{AgentName: "AgentA", Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("Prev2", nil)}}} + buf1, buf2 := &bytes.Buffer{}, &bytes.Buffer{} + _ = gob.NewEncoder(buf1).Encode(prevUser) + _ = gob.NewEncoder(buf2).Encode(prevAssistant) + _ = store.Add(ctx, sessionID, [][]byte{buf1.Bytes(), buf2.Bytes()}) + + // AfterGet handlers: transform and then tag + h1 := func(ctx context.Context, events []*AgentEvent) ([]*AgentEvent, error) { + out := make([]*AgentEvent, 0, len(events)+1) + for _, ev := range events { + if ev.Output != nil && ev.Output.MessageOutput != nil { + msg := ev.Output.MessageOutput.Message + if msg != nil && msg.Role == schema.Assistant { + // modify assistant content + ev2 := *ev + ev2.Output = &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage(msg.Content+"-Transformed", nil)}} + out = append(out, &ev2) + continue + } + } + out = 
append(out, ev)
+		}
+		// Append summary event
+		out = append(out, &AgentEvent{AgentName: "Summarizer", Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("Summary", nil)}}})
+		return out, nil
+	}
+
+	h2 := func(ctx context.Context, events []*AgentEvent) ([]*AgentEvent, error) {
+		out := make([]*AgentEvent, 0, len(events))
+		for _, ev := range events {
+			// ensure AgentName tagged
+			name := ev.AgentName
+			if name == "" {
+				name = "Input"
+			}
+			ev2 := *ev
+			ev2.AgentName = "AG_" + name
+			out = append(out, &ev2)
+		}
+		return out, nil
+	}
+
+	// BeforeAdd handler to prefix agent name on save
+	beforeAdd := func(ctx context.Context, ev *AgentEvent) (*AgentEvent, error) {
+		ev2 := *ev
+		if ev2.AgentName == "" {
+			ev2.AgentName = "Input"
+		}
+		ev2.AgentName = "Before_Add_" + ev2.AgentName
+		return &ev2, nil
+	}
+
+	service := NewSessionService(store, WithAfterGetSession(h1, h2), WithPersistAfterGetSession(), WithBeforeAddSession(beforeAdd))
+
+	// Run with session to trigger load (AfterGet + persist) and saving input/output
+	mockAgent_ := newMockRunnerAgent("TestAgent", "Test agent", []*AgentEvent{
+		{AgentName: "TestAgent", Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("Response", nil)}}},
+	})
+	runner := NewRunner(ctx, RunnerConfig{Agent: mockAgent_, SessionService: service})
+
+	iter := runner.Run(ctx, []Message{schema.UserMessage("New")}, WithSessionID(sessionID))
+	for {
+		_, ok := iter.Next()
+		if !ok {
+			break
+		}
+	}
+
+	// Verify store contents: transformed history persisted, plus input and response (interrupts are filtered by default)
+	data, _ := store.Get(ctx, sessionID)
+	// transformed history should be 3 entries (Prev1, Prev2-Transformed, Summary) then +2 new entries
+	assert.Equal(t, 5, len(data))
+
+	// Decode and check first three transformed entries
+	var e0, e1, e2 *AgentEvent
+	_ = gob.NewDecoder(bytes.NewReader(data[0])).Decode(&e0)
+	_ = gob.NewDecoder(bytes.NewReader(data[1])).Decode(&e1)
+	_ = gob.NewDecoder(bytes.NewReader(data[2])).Decode(&e2)
+	// After h2, AgentName should be tagged
+	assert.Equal(t, "AG_Input", e0.AgentName)
+	assert.Equal(t, "AG_AgentA", e1.AgentName)
+	assert.Equal(t, "AG_Summarizer", e2.AgentName)
+	// content transformed
+	assert.Equal(t, "Prev2-Transformed", e1.Output.MessageOutput.Message.Content)
+
+	// Input event saved with BeforeAdd prefix
+	var e3 *AgentEvent
+	_ = gob.NewDecoder(bytes.NewReader(data[3])).Decode(&e3)
+	assert.Equal(t, "Before_Add_Input", e3.AgentName)
+	assert.Equal(t, "New", e3.Output.MessageOutput.Message.Content)
+
+	// Response saved with BeforeAdd prefix
+	var e4 *AgentEvent
+	_ = gob.NewDecoder(bytes.NewReader(data[4])).Decode(&e4)
+	assert.Equal(t, "Before_Add_TestAgent", e4.AgentName)
+	assert.Equal(t, "Response", e4.Output.MessageOutput.Message.Content)
+}
+
+func TestSessionService_AfterGet_ErrorAndPersist_Integration(t *testing.T) {
+	ctx := context.Background()
+	store := NewInMemorySessionStore()
+	sessionID := "after_get_error_persist"
+
+	// Pre-populate store
+	ev := &AgentEvent{Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.UserMessage("Seed")}}}
+	buf := &bytes.Buffer{}
+	_ = gob.NewEncoder(buf).Encode(ev)
+	_ = store.Add(ctx, sessionID, [][]byte{buf.Bytes()})
+
+	// First AfterGet fails, second should be skipped
+	hFail := func(ctx context.Context, events []*AgentEvent) ([]*AgentEvent, error) {
+		return nil, errors.New("boom")
+	}
+	hSkipped := func(ctx context.Context, events []*AgentEvent)
([]*AgentEvent, error) { + return []*AgentEvent{}, nil + } + + // BeforeAdd that skips saving assistant responses with content "SkipMe" + beforeAddSkip := func(ctx context.Context, ev *AgentEvent) (*AgentEvent, error) { + if ev.Output != nil && ev.Output.MessageOutput != nil { + m := ev.Output.MessageOutput.Message + if m != nil && m.Role == schema.Assistant && m.Content == "SkipMe" { + return nil, nil + } + } + return ev, nil + } + + service := NewSessionService(store, WithAfterGetSession(hFail, hSkipped), WithPersistAfterGetSession(), WithBeforeAddSession(beforeAddSkip)) + + mockAgent_ := newMockRunnerAgent("AgentB", "Test agent", []*AgentEvent{ + {AgentName: "AgentB", Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("SkipMe", nil)}}}, + {AgentName: "AgentB", Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("KeepMe", nil)}}}, + }) + + runner := NewRunner(ctx, RunnerConfig{Agent: mockAgent_, SessionService: service}) + iter := runner.Run(ctx, []Message{schema.UserMessage("New2")}, WithSessionID(sessionID)) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + // AfterGet failed: history persisted as original (1 entry), then input (1), then only the second assistant (skipping first) + data, _ := store.Get(ctx, sessionID) + assert.Equal(t, 3, len(data)) + + var h0, h1, h2 *AgentEvent + _ = gob.NewDecoder(bytes.NewReader(data[0])).Decode(&h0) + _ = gob.NewDecoder(bytes.NewReader(data[1])).Decode(&h1) + _ = gob.NewDecoder(bytes.NewReader(data[2])).Decode(&h2) + assert.Equal(t, "Seed", h0.Output.MessageOutput.Message.Content) + assert.Equal(t, "New2", h1.Output.MessageOutput.Message.Content) + assert.Equal(t, "KeepMe", h2.Output.MessageOutput.Message.Content) +} + +func TestSession_Streaming_RoundTrip(t *testing.T) { + ctx := context.Background() + store := NewInMemorySessionStore() + sessionID := "streaming_round_trip" + + // Build a streaming assistant message with multiple frames + frames := []*schema.Message{ + schema.AssistantMessage("part1", nil), + schema.AssistantMessage(" ", nil), + schema.AssistantMessage("part2", nil), + } + stream := schema.StreamReaderFromArray(frames) + stream.SetAutomaticClose() + + // Mock agent emits streaming output + i, g := NewAsyncIteratorPair[*AgentEvent]() + g.Send(EventFromMessage(nil, stream, schema.Assistant, "")) + g.Close() + + service := NewSessionService(store) + + // Persist output via session service + niter, ngen := NewAsyncIteratorPair[*AgentEvent]() + go func() { + defer ngen.Close() + service.saveOutput(ctx, sessionID, i, ngen) + }() + + // Drain iterator to ensure save completes + for { + _, ok := niter.Next() + if !ok { + break + } + } + + // Load back and verify concatenation + loaded, err := service.load(ctx, sessionID) + assert.NoError(t, err) + // Expect a single assistant message equal to concatenation of frames + assert.Equal(t, 1, len(loaded)) + msg, _, err := GetMessage(loaded[0]) + assert.NoError(t, err) + assert.NotNil(t, msg) + assert.Equal(t, schema.Assistant, msg.Role) + assert.Equal(t, "part1 part2", msg.Content) +} + +// Helper to add events to store +func addTestEvents(ctx context.Context, store SessionStore, sessionID string, count int, prefix string, opts ...StoreOption) { + for i := 0; i < count; i++ { + event := &AgentEvent{ + AgentName: "Agent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.AssistantMessage(fmt.Sprintf("%sEvent %d", prefix, i), nil), + }, + }, + } + buf := &bytes.Buffer{} + _ = 
gob.NewEncoder(buf).Encode(event) + _ = store.Add(ctx, sessionID, [][]byte{buf.Bytes()}, opts...) + } +} + +func TestInMemoryStore_Options(t *testing.T) { + ctx := context.Background() + + t.Run("GetStoreImplSpecificOptions", func(t *testing.T) { + opts := []StoreOption{ + WithLimit(10), + WithOffset(5), + WithRoundID("test_round"), + } + + extracted := GetStoreImplSpecificOptions(&InMemoryStoreOptions{}, opts...) + assert.Equal(t, 10, extracted.Limit) + assert.Equal(t, 5, extracted.Offset) + assert.Equal(t, "test_round", extracted.RoundID) + + // Test with base values + base := &InMemoryStoreOptions{Limit: 100, Offset: 20} + extracted2 := GetStoreImplSpecificOptions(base, WithLimit(50)) + assert.Equal(t, 50, extracted2.Limit) // Overridden + assert.Equal(t, 20, extracted2.Offset) // Kept from base + }) + + t.Run("WithLimit", func(t *testing.T) { + store := NewInMemorySessionStore() + sessionID := "limit_test" + addTestEvents(ctx, store, sessionID, 10, "") + + // Get all + allData, _ := store.Get(ctx, sessionID) + assert.Equal(t, 10, len(allData)) + + // Get latest 5 + limitedData, _ := store.Get(ctx, sessionID, WithLimit(5)) + assert.Equal(t, 5, len(limitedData)) + + // Verify content (last 5 are 5-9) + var lastEvent AgentEvent + _ = gob.NewDecoder(bytes.NewReader(limitedData[4])).Decode(&lastEvent) + assert.Equal(t, "Event 9", lastEvent.Output.MessageOutput.Message.Content) + + var firstOfLast5 AgentEvent + _ = gob.NewDecoder(bytes.NewReader(limitedData[0])).Decode(&firstOfLast5) + assert.Equal(t, "Event 5", firstOfLast5.Output.MessageOutput.Message.Content) + }) + + t.Run("WithOffset", func(t *testing.T) { + store := NewInMemorySessionStore() + sessionID := "offset_test" + addTestEvents(ctx, store, sessionID, 10, "") + + // Skip first 3 + offsetData, _ := store.Get(ctx, sessionID, WithOffset(3)) + assert.Equal(t, 7, len(offsetData)) + + // Verify first event after offset is Event 3 + var firstAfterOffset AgentEvent + _ = gob.NewDecoder(bytes.NewReader(offsetData[0])).Decode(&firstAfterOffset) + assert.Equal(t, "Event 3", firstAfterOffset.Output.MessageOutput.Message.Content) + + // Offset >= total length + emptyData, _ := store.Get(ctx, sessionID, WithOffset(10)) + assert.Equal(t, 0, len(emptyData)) + }) + + t.Run("WithLimitAndOffset", func(t *testing.T) { + store := NewInMemorySessionStore() + sessionID := "limit_offset_test" + addTestEvents(ctx, store, sessionID, 20, "") + + // Skip first 5, then take latest 3 from remaining (should get events 17, 18, 19) + data, _ := store.Get(ctx, sessionID, WithOffset(5), WithLimit(3)) + assert.Equal(t, 3, len(data)) + + // Verify: After skipping 5, we have events[5:20], latest 3 are events[17:20] + var event0 AgentEvent + _ = gob.NewDecoder(bytes.NewReader(data[0])).Decode(&event0) + assert.Equal(t, "Event 17", event0.Output.MessageOutput.Message.Content) + + var event2 AgentEvent + _ = gob.NewDecoder(bytes.NewReader(data[2])).Decode(&event2) + assert.Equal(t, "Event 19", event2.Output.MessageOutput.Message.Content) + + // WithLimit larger than remaining after offset + data2, _ := store.Get(ctx, sessionID, WithOffset(18), WithLimit(10)) + assert.Equal(t, 2, len(data2)) // Only events 18, 19 remain + }) + + t.Run("WithRoundID", func(t *testing.T) { + store := NewInMemorySessionStore() + sessionID := "roundid_test" + + // Add events with different RoundIDs + addTestEvents(ctx, store, sessionID, 5, "Round2-", WithRoundID("round_2")) + addTestEvents(ctx, store, sessionID, 3, "Round3-", WithRoundID("round_3")) + + // Get all events (no filter) + 
allData, _ := store.Get(ctx, sessionID) + assert.Equal(t, 8, len(allData)) + + // Filter by round_2 + round2Data, _ := store.Get(ctx, sessionID, WithRoundID("round_2")) + assert.Equal(t, 5, len(round2Data)) + var firstR2 AgentEvent + _ = gob.NewDecoder(bytes.NewReader(round2Data[0])).Decode(&firstR2) + assert.Equal(t, "Round2-Event 0", firstR2.Output.MessageOutput.Message.Content) + + // Filter by round_3 + round3Data, _ := store.Get(ctx, sessionID, WithRoundID("round_3")) + assert.Equal(t, 3, len(round3Data)) + var firstR3 AgentEvent + _ = gob.NewDecoder(bytes.NewReader(round3Data[0])).Decode(&firstR3) + assert.Equal(t, "Round3-Event 0", firstR3.Output.MessageOutput.Message.Content) + + // Filter by nonexistent round + noneData, err := store.Get(ctx, sessionID, WithRoundID("nonexistent")) + assert.NoError(t, err) + assert.Equal(t, 0, len(noneData)) + + // Combine RoundID filter with Limit + // Filter round_2 (5 events), then take Limit(2) -> last 2 of round_2 + limitedR2, _ := store.Get(ctx, sessionID, WithRoundID("round_2"), WithLimit(2)) + assert.Equal(t, 2, len(limitedR2)) + var lastR2 AgentEvent + _ = gob.NewDecoder(bytes.NewReader(limitedR2[1])).Decode(&lastR2) + assert.Equal(t, "Round2-Event 4", lastR2.Output.MessageOutput.Message.Content) + }) +} + +// JSONSerializer for testing +type JSONSerializer struct{} + +func (s *JSONSerializer) Marshal(event *AgentEvent) ([]byte, error) { + return json.Marshal(event) +} + +func (s *JSONSerializer) Unmarshal(data []byte) (*AgentEvent, error) { + var event AgentEvent + if err := json.Unmarshal(data, &event); err != nil { + return nil, err + } + return &event, nil +} + +func TestSessionService_CustomSerializer(t *testing.T) { + ctx := context.Background() + store := NewInMemorySessionStore() + sessionID := "json_serializer_test" + + // Use JSON serializer + service := NewSessionService(store, WithEventSerializer(&JSONSerializer{})) + + // Create an event + event := &AgentEvent{ + AgentName: "TestAgent", + Output: &AgentOutput{ + MessageOutput: &MessageVariant{ + Message: schema.AssistantMessage("Hello JSON", nil), + }, + }, + } + + // Save event via service + err := service.persistEvents(ctx, sessionID, []*AgentEvent{event}) + assert.NoError(t, err) + + // Verify storage contains JSON + data, err := store.Get(ctx, sessionID) + assert.NoError(t, err) + assert.Equal(t, 1, len(data)) + + // Check if it looks like JSON (starts with {) + assert.Equal(t, byte('{'), data[0][0]) + assert.Contains(t, string(data[0]), "Hello JSON") + + // Load back via service + loaded, err := service.load(ctx, sessionID) + assert.NoError(t, err) + assert.Equal(t, 1, len(loaded)) + assert.Equal(t, "TestAgent", loaded[0].AgentName) + assert.Equal(t, "Hello JSON", loaded[0].Output.MessageOutput.Message.Content) + + // Verify Runner integration + mockAgent_ := newMockRunnerAgent("RunnerAgent", "Test", []*AgentEvent{ + {AgentName: "RunnerAgent", Output: &AgentOutput{MessageOutput: &MessageVariant{Message: schema.AssistantMessage("Response", nil)}}}, + }) + runner := NewRunner(ctx, RunnerConfig{Agent: mockAgent_, SessionService: service}) + + iter := runner.Run(ctx, []Message{schema.UserMessage("Input")}, WithSessionID(sessionID)) + for { + _, ok := iter.Next() + if !ok { + break + } + } + + // Verify store has 3 events (1 initial + 1 input + 1 response), all JSON + finalData, _ := store.Get(ctx, sessionID) + assert.Equal(t, 3, len(finalData)) + for _, entry := range finalData { + assert.Equal(t, byte('{'), entry[0], "Entry should be JSON") + } +}