diff --git a/compose/checkpoint_test.go b/compose/checkpoint_test.go index 73abaf7c..1b134a13 100644 --- a/compose/checkpoint_test.go +++ b/compose/checkpoint_test.go @@ -1015,3 +1015,184 @@ func TestPreHandlerInterrupt(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "1", result) } + +func TestCancelInterrupt(t *testing.T) { + g := NewGraph[string, string]() + _ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + time.Sleep(3 * time.Second) + return input + "1", nil + })) + _ = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input + "2", nil + })) + _ = g.AddEdge(START, "1") + _ = g.AddEdge("1", "2") + _ = g.AddEdge("2", END) + ctx := context.Background() + + // pregel + r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore())) + assert.NoError(t, err) + // interrupt after nodes + canceledCtx, cancel := WithGraphInterrupt(ctx) + go func() { + time.Sleep(500 * time.Millisecond) + cancel(WithGraphInterruptTimeout(time.Hour)) + }() + _, err = r.Invoke(canceledCtx, "input", WithCheckPointID("1")) + assert.Error(t, err) + info, success := ExtractInterruptInfo(err) + assert.True(t, success) + assert.Equal(t, []string{"1"}, info.AfterNodes) + result, err := r.Invoke(ctx, "input", WithCheckPointID("1")) + assert.NoError(t, err) + assert.Equal(t, "input12", result) + // infinite timeout + canceledCtx, cancel = WithGraphInterrupt(ctx) + go func() { + time.Sleep(500 * time.Millisecond) + cancel() + }() + _, err = r.Invoke(canceledCtx, "input", WithCheckPointID("2")) + assert.Error(t, err) + info, success = ExtractInterruptInfo(err) + assert.True(t, success) + assert.Equal(t, []string{"1"}, info.AfterNodes) + result, err = r.Invoke(ctx, "input", WithCheckPointID("2")) + assert.NoError(t, err) + assert.Equal(t, "input12", result) + + // interrupt rerun nodes + canceledCtx, cancel = WithGraphInterrupt(ctx) + go func() { + time.Sleep(500 * time.Millisecond) + cancel(WithGraphInterruptTimeout(0)) + }() + _, err = r.Invoke(canceledCtx, "input", WithCheckPointID("3")) + assert.Error(t, err) + info, success = ExtractInterruptInfo(err) + assert.True(t, success) + assert.Equal(t, []string{"1"}, info.RerunNodes) + result, err = r.Invoke(ctx, "input", WithCheckPointID("3")) + assert.NoError(t, err) + assert.Equal(t, "12", result) + + // dag + g = NewGraph[string, string]() + _ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + time.Sleep(3 * time.Second) + return input + "1", nil + })) + _ = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input + "2", nil + })) + _ = g.AddEdge(START, "1") + _ = g.AddEdge("1", "2") + _ = g.AddEdge("2", END) + r, err = g.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(newInMemoryStore())) + assert.NoError(t, err) + // interrupt after nodes + canceledCtx, cancel = WithGraphInterrupt(ctx) + go func() { + time.Sleep(500 * time.Millisecond) + cancel(WithGraphInterruptTimeout(time.Hour)) + }() + _, err = r.Invoke(canceledCtx, "input", WithCheckPointID("1")) + assert.Error(t, err) + info, success = ExtractInterruptInfo(err) + assert.True(t, success) + assert.Equal(t, []string{"1"}, info.AfterNodes) + result, err = r.Invoke(ctx, "input", WithCheckPointID("1")) + assert.NoError(t, err) + assert.Equal(t, "input12", result) + // infinite timeout + canceledCtx, cancel = WithGraphInterrupt(ctx) + go func() { + 
time.Sleep(500 * time.Millisecond) + cancel() + }() + _, err = r.Invoke(canceledCtx, "input", WithCheckPointID("2")) + assert.Error(t, err) + info, success = ExtractInterruptInfo(err) + assert.True(t, success) + assert.Equal(t, []string{"1"}, info.AfterNodes) + result, err = r.Invoke(ctx, "input", WithCheckPointID("2")) + assert.NoError(t, err) + assert.Equal(t, "input12", result) + + // interrupt rerun nodes + canceledCtx, cancel = WithGraphInterrupt(ctx) + go func() { + time.Sleep(300 * time.Millisecond) + cancel(WithGraphInterruptTimeout(0)) + }() + _, err = r.Invoke(canceledCtx, "input", WithCheckPointID("3")) + assert.Error(t, err) + info, success = ExtractInterruptInfo(err) + assert.True(t, success) + assert.Equal(t, []string{"1"}, info.RerunNodes) + result, err = r.Invoke(ctx, "input", WithCheckPointID("3")) + assert.NoError(t, err) + assert.Equal(t, "12", result) + + // dag multi canceled nodes + gg := NewGraph[string, map[string]any]() + _ = gg.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input + "1", nil + })) + _ = gg.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + time.Sleep(3 * time.Second) + return input + "2", nil + }), WithOutputKey("2")) + _ = gg.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + time.Sleep(3 * time.Second) + return input + "3", nil + }), WithOutputKey("3")) + _ = gg.AddLambdaNode("4", InvokableLambda(func(ctx context.Context, input map[string]any) (output map[string]any, err error) { + return input, nil + })) + _ = gg.AddEdge(START, "1") + _ = gg.AddEdge("1", "2") + _ = gg.AddEdge("1", "3") + _ = gg.AddEdge("2", "4") + _ = gg.AddEdge("3", "4") + _ = gg.AddEdge("4", END) + ctx = context.Background() + rr, err := gg.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(newInMemoryStore())) + assert.NoError(t, err) + // interrupt after nodes + canceledCtx, cancel = WithGraphInterrupt(ctx) + go func() { + time.Sleep(500 * time.Millisecond) + cancel(WithGraphInterruptTimeout(time.Hour)) + }() + _, err = rr.Invoke(canceledCtx, "input", WithCheckPointID("1")) + assert.Error(t, err) + info, success = ExtractInterruptInfo(err) + assert.True(t, success) + assert.Equal(t, 2, len(info.AfterNodes)) + result2, err := rr.Invoke(ctx, "input", WithCheckPointID("1")) + assert.NoError(t, err) + assert.Equal(t, map[string]any{ + "2": "input12", + "3": "input13", + }, result2) + + // interrupt rerun nodes + canceledCtx, cancel = WithGraphInterrupt(ctx) + go func() { + time.Sleep(500 * time.Millisecond) + cancel(WithGraphInterruptTimeout(0)) + }() + _, err = rr.Invoke(canceledCtx, "input", WithCheckPointID("2")) + assert.Error(t, err) + info, success = ExtractInterruptInfo(err) + assert.True(t, success) + assert.Equal(t, 2, len(info.RerunNodes)) + result2, err = rr.Invoke(ctx, "input", WithCheckPointID("2")) + assert.NoError(t, err) + assert.Equal(t, map[string]any{ + "2": "2", + "3": "3", + }, result2) +} diff --git a/compose/graph_call_options.go b/compose/graph_call_options.go index 9c753f80..277d72d7 100644 --- a/compose/graph_call_options.go +++ b/compose/graph_call_options.go @@ -17,8 +17,10 @@ package compose import ( + "context" "fmt" "reflect" + "time" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/document" @@ -29,6 +31,51 @@ import ( "github.com/cloudwego/eino/components/retriever" ) +type graphCancelChanKey struct{} +type graphCancelChanVal 
struct { + ch chan *time.Duration +} + +type graphInterruptOptions struct { + timeout *time.Duration +} + +type GraphInterruptOption func(o *graphInterruptOptions) + +// WithGraphInterruptTimeout specifies the max waiting time before generating an interrupt. +// After the max waiting time, the graph will force an interrupt. Any unfinished tasks will be re-run when the graph is resumed. +func WithGraphInterruptTimeout(timeout time.Duration) GraphInterruptOption { + return func(o *graphInterruptOptions) { + o.timeout = &timeout + } +} + +// WithGraphInterrupt creates a context with graph cancellation support. +// When the returned context is used to invoke a graph or workflow, calling the interrupt function will trigger an interrupt. +// The graph will wait for current tasks to complete by default. +func WithGraphInterrupt(parent context.Context) (ctx context.Context, interrupt func(opts ...GraphInterruptOption)) { + ch := make(chan *time.Duration, 1) + ctx = context.WithValue(parent, graphCancelChanKey{}, &graphCancelChanVal{ + ch: ch, + }) + return ctx, func(opts ...GraphInterruptOption) { + o := &graphInterruptOptions{} + for _, opt := range opts { + opt(o) + } + ch <- o.timeout + close(ch) + } +} + +func getGraphCancel(ctx context.Context) *graphCancelChanVal { + val, ok := ctx.Value(graphCancelChanKey{}).(*graphCancelChanVal) + if !ok { + return nil + } + return val +} + // Option is a functional option type for calling a graph. type Option struct { options []any diff --git a/compose/graph_manager.go b/compose/graph_manager.go index 211d9a8d..d14e7dd8 100644 --- a/compose/graph_manager.go +++ b/compose/graph_manager.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "runtime/debug" + "time" "github.com/cloudwego/eino/internal" "github.com/cloudwego/eino/internal/safe" @@ -260,8 +261,13 @@ type taskManager struct { opts []Option needAll bool - num uint32 - done *internal.UnboundedChan[*task] + num uint32 + done *internal.UnboundedChan[*task] + runningTasks map[string]*task + + cancelCh chan *time.Duration + canceled bool + deadline *time.Time } func (t *taskManager) execute(currentTask *task) { @@ -298,6 +304,8 @@ func (t *taskManager) submit(tasks []*task) error { t.num++ t.done.Send(currentTask) } + + t.runningTasks[currentTask.nodeKey] = currentTask } if len(tasks) == 0 { // all tasks' pre-handler failed @@ -305,7 +313,7 @@ func (t *taskManager) submit(tasks []*task) error { } var syncTask *task - if t.num == 0 && (len(tasks) == 1 || t.needAll) { + if t.num == 0 && (len(tasks) == 1 || t.needAll) && t.cancelCh == nil /*if graph can be interrupted by user, shouldn't sync run task*/ { syncTask = tasks[0] tasks = tasks[1:] } @@ -320,45 +328,169 @@ func (t *taskManager) submit(tasks []*task) error { return nil } -func (t *taskManager) wait() []*task { +func (t *taskManager) wait() (tasks []*task, canceled bool, canceledTasks []*task) { if t.needAll { - return t.waitAll() + tasks, canceledTasks = t.waitAll() + return tasks, t.canceled, canceledTasks } - ta, success := t.waitOne() + ta, success, canceled := t.waitOne() + if canceled { + // has canceled and timeout, return canceled tasks + for _, rta := range t.runningTasks { + canceledTasks = append(canceledTasks, rta) + } + t.runningTasks = make(map[string]*task) + t.num = 0 + return nil, true, canceledTasks + } + if t.canceled { + // has canceled, but not timeout, wait all + tasks, canceledTasks = t.waitAll() + return append(tasks, ta), true, canceledTasks + } if !success { - return []*task{} + return []*task{}, t.canceled, nil } - return 
[]*task{ta} + return []*task{ta}, t.canceled, nil } -func (t *taskManager) waitOne() (*task, bool) { +func (t *taskManager) waitOne() (ta *task, success bool, canceled bool) { if t.num == 0 { - return nil, false + return nil, false, false + } + + if t.cancelCh == nil { + ta, _ = t.done.Receive() + } else { + ta, _, canceled = t.receive(t.done.Receive) } - ta, _ := t.done.Receive() t.num-- + if canceled { + return nil, false, true + } + + delete(t.runningTasks, ta.nodeKey) if ta.err != nil { - return ta, true + // biz error, jump post processor + return ta, true, false } runPostHandler(ta, t.runWrapper) - return ta, true + return ta, true, false } -func (t *taskManager) waitAll() []*task { +func (t *taskManager) waitAll() (successTasks []*task, canceledTasks []*task) { result := make([]*task, 0, t.num) for { - ta, success := t.waitOne() + ta, success, canceled := t.waitOne() + if canceled { + for _, rt := range t.runningTasks { + canceledTasks = append(canceledTasks, rt) + } + t.runningTasks = make(map[string]*task) + t.num = 0 + return result, canceledTasks + } if !success { - return result + return result, nil } result = append(result, ta) } } +func (t *taskManager) receive(recv func() (*task, bool)) (ta *task, closed bool, canceled bool) { + if t.deadline != nil { + // have canceled, receive in a certain time + return receiveWithDeadline(recv, *t.deadline) + } + if t.canceled { + // canceled without timeout + ta, closed = recv() + return ta, closed, false + } + if t.cancelCh != nil { + // have not canceled, receive while listening + ta, closed, canceled, t.canceled, t.deadline = receiveWithListening(recv, t.cancelCh) + return ta, closed, canceled + } + // won't cancel + ta, closed = recv() + return ta, closed, false +} + +func receiveWithDeadline(recv func() (*task, bool), deadline time.Time) (ta *task, closed bool, canceled bool) { + now := time.Now() + if deadline.Before(now) { + return nil, false, true + } + + timeout := deadline.Sub(now) + + resultCh := make(chan struct{}, 1) + + go func() { + ta, closed = recv() + resultCh <- struct{}{} + }() + + timeoutCh := time.After(timeout) + + select { + case <-resultCh: + return ta, closed, false + case <-timeoutCh: + return nil, false, true + } +} + +func receiveWithListening(recv func() (*task, bool), cancel chan *time.Duration) (*task, bool, bool, bool, *time.Time) { + type pair struct { + ta *task + closed bool + } + resultCh := make(chan pair, 1) + var timeoutCh <-chan time.Time + + var deadline *time.Time + canceled := false + go func() { + ta, closed := recv() + resultCh <- pair{ta, closed} + }() + + select { + case p := <-resultCh: + return p.ta, p.closed, false, false, nil + case timeout, ok := <-cancel: + if !ok { + // unreachable + break + } + canceled = true + if timeout == nil { + // canceled without timeout + break + } + timeoutCh = time.After(*timeout) + dt := time.Now().Add(*timeout) + deadline = &dt + } + + if timeoutCh != nil { + select { + case p := <-resultCh: + return p.ta, p.closed, false, canceled, deadline + case <-timeoutCh: + return nil, false, true, canceled, deadline + } + } + p := <-resultCh + return p.ta, p.closed, false, canceled, nil +} + func runPreHandler(ta *task, runWrapper runnableCallWrapper) (err error) { defer func() { if e := recover(); e != nil { diff --git a/compose/graph_run.go b/compose/graph_run.go index f54ebd2a..7effe86a 100644 --- a/compose/graph_run.go +++ b/compose/graph_run.go @@ -105,18 +105,15 @@ func runnableTransform(ctx context.Context, r *composableRunnable, input any, op } func (r 
*runner) run(ctx context.Context, isStream bool, input any, opts ...Option) (result any, err error) { - // Choose the appropriate wrapper function based on whether we're handling a stream or not. - haveOnStart := false + ctx, input = onGraphStart(ctx, input, isStream) defer func() { - if !haveOnStart { - ctx, input = onGraphStart(ctx, input, isStream) - } if err != nil { ctx, err = onGraphError(ctx, err) } else { ctx, result = onGraphEnd(ctx, result, isStream) } }() + var runWrapper runnableCallWrapper runWrapper = runnableInvoke if isStream { @@ -125,7 +122,7 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti // Initialize channel and task managers. cm := r.initChannelManager(isStream) - tm := r.initTaskManager(runWrapper, opts...) + tm := r.initTaskManager(runWrapper, getGraphCancel(ctx), opts...) maxSteps := r.options.maxRunSteps if r.dag { @@ -166,36 +163,10 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti var nextTasks []*task if cp := getCheckPointFromCtx(ctx); cp != nil { // in subgraph, try to load checkpoint from ctx - // load checkpoint from ctx - initialized = true // don't init again - - err = r.checkPointer.restoreCheckPoint(cp, isStream) - if err != nil { - return nil, newGraphRunError(fmt.Errorf("restore checkpoint fail: %w", err)) - } - - err = cm.loadChannels(cp.Channels) - if err != nil { - return nil, newGraphRunError(err) - } - if sm := getStateModifier(ctx); sm != nil && cp.State != nil { - err = sm(ctx, *path, cp.State) - if err != nil { - return nil, newGraphRunError(fmt.Errorf("state modifier fail: %w", err)) - } - } - if cp.State != nil { - ctx = context.WithValue(ctx, stateKey{}, &internalState{state: cp.State}) - } - - ctx, input = onGraphStart(ctx, input, isStream) - haveOnStart = true - nextTasks, err = r.restoreTasks(ctx, cp.Inputs, cp.SkipPreHandler, cp.ToolsNodeExecutedTools, cp.RerunNodes, isStream, optMap) // should restore after set state to context - if err != nil { - return nil, newGraphRunError(fmt.Errorf("restore tasks fail: %w", err)) - } + initialized = true + ctx, nextTasks, err = r.restoreFromCheckPoint(ctx, *path, getStateModifier(ctx), cp, isStream, cm, optMap) } else if checkPointID != nil && !forceNewRun { - cp, err := getCheckPointFromStore(ctx, *checkPointID, r.checkPointer) + cp, err = getCheckPointFromStore(ctx, *checkPointID, r.checkPointer) if err != nil { return nil, newGraphRunError(fmt.Errorf("load checkpoint from store fail: %w", err)) } @@ -203,34 +174,10 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti // load checkpoint from store initialized = true - err = r.checkPointer.restoreCheckPoint(cp, isStream) - if err != nil { - return nil, newGraphRunError(fmt.Errorf("restore checkpoint fail: %w", err)) - } - - err = cm.loadChannels(cp.Channels) - if err != nil { - return nil, newGraphRunError(err) - } ctx = setStateModifier(ctx, stateModifier) ctx = setCheckPointToCtx(ctx, cp) - if stateModifier != nil && cp.State != nil { - err = stateModifier(ctx, *NewNodePath(), cp.State) - if err != nil { - return nil, newGraphRunError(fmt.Errorf("state modifier fail: %w", err)) - } - } - if cp.State != nil { - ctx = context.WithValue(ctx, stateKey{}, &internalState{state: cp.State}) - } - ctx, input = onGraphStart(ctx, input, isStream) - haveOnStart = true - // resume graph - nextTasks, err = r.restoreTasks(ctx, cp.Inputs, cp.SkipPreHandler, cp.ToolsNodeExecutedTools, cp.RerunNodes, isStream, optMap) - if err != nil { - return nil, 
newGraphRunError(fmt.Errorf("restore tasks fail: %w", err)) - } + ctx, nextTasks, err = r.restoreFromCheckPoint(ctx, *NewNodePath(), stateModifier, cp, isStream, cm, optMap) } } if !initialized { @@ -239,9 +186,6 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti ctx = r.runCtx(ctx) } - ctx, input = onGraphStart(ctx, input, isStream) - haveOnStart = true - var isEnd bool nextTasks, result, isEnd, err = r.calculateNextTasks(ctx, []*task{{ nodeKey: START, @@ -272,13 +216,15 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti } } + // used to reporting NoTask error var lastCompletedTask []*task + // Main execution loop. for step := 0; ; step++ { // Check for context cancellation. select { case <-ctx.Done(): - _ = tm.waitAll() + _, _ = tm.waitAll() return nil, newGraphRunError(fmt.Errorf("context has been canceled: %w", ctx.Err())) default: } @@ -294,10 +240,25 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti if err != nil { return nil, newGraphRunError(fmt.Errorf("failed to submit tasks: %w", err)) } - var completedTasks []*task - completedTasks = tm.wait() + var totalCanceledTasks []*task + + completedTasks, canceled, canceledTasks := tm.wait() + totalCanceledTasks = append(totalCanceledTasks, canceledTasks...) tempInfo := newInterruptTempInfo() + if canceled { + if len(canceledTasks) > 0 { + // as rerun nodes + for _, t := range canceledTasks { + tempInfo.interruptRerunNodes = append(tempInfo.interruptRerunNodes, t.nodeKey) + } + } else { + // as interrupt after + for _, t := range completedTasks { + tempInfo.interruptAfterNodes = append(tempInfo.interruptAfterNodes, t.nodeKey) + } + } + } err = r.resolveInterruptCompletedTasks(tempInfo, completedTasks) if err != nil { @@ -305,8 +266,15 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti } if len(tempInfo.subGraphInterrupts)+len(tempInfo.interruptRerunNodes) > 0 { - cpt := tm.waitAll() - err = r.resolveInterruptCompletedTasks(tempInfo, cpt) + var newCompletedTasks []*task + newCompletedTasks, canceledTasks = tm.waitAll() + totalCanceledTasks = append(totalCanceledTasks, canceledTasks...) + for _, ct := range canceledTasks { + // handle timeout tasks as rerun + tempInfo.interruptRerunNodes = append(tempInfo.interruptRerunNodes, ct.nodeKey) + } + + err = r.resolveInterruptCompletedTasks(tempInfo, newCompletedTasks) if err != nil { return nil, err // err has been wrapped } @@ -318,7 +286,7 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti return nil, r.handleInterruptWithSubGraphAndRerunNodes( ctx, tempInfo, - append(completedTasks, cpt...), + append(append(completedTasks, newCompletedTasks...), totalCanceledTasks...), // canceled tasks are handled as rerun writeToCheckPointID, isSubGraph, cm, @@ -343,7 +311,12 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti tempInfo.interruptBeforeNodes = getHitKey(nextTasks, r.interruptBeforeNodes) if len(tempInfo.interruptBeforeNodes) > 0 || len(tempInfo.interruptAfterNodes) > 0 { - newCompletedTasks := tm.waitAll() + var newCompletedTasks []*task + newCompletedTasks, canceledTasks = tm.waitAll() + totalCanceledTasks = append(totalCanceledTasks, canceledTasks...) 
+ for _, ct := range canceledTasks { + tempInfo.interruptRerunNodes = append(tempInfo.interruptRerunNodes, ct.nodeKey) + } err = r.resolveInterruptCompletedTasks(tempInfo, newCompletedTasks) if err != nil { @@ -354,7 +327,7 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti return nil, r.handleInterruptWithSubGraphAndRerunNodes( ctx, tempInfo, - append(completedTasks, newCompletedTasks...), + append(append(completedTasks, newCompletedTasks...), totalCanceledTasks...), writeToCheckPointID, isSubGraph, cm, @@ -380,6 +353,41 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti } } +func (r *runner) restoreFromCheckPoint( + ctx context.Context, + path NodePath, + sm StateModifier, + cp *checkpoint, + isStream bool, + cm *channelManager, + optMap map[string][]any, +) (context.Context, []*task, error) { + err := r.checkPointer.restoreCheckPoint(cp, isStream) + if err != nil { + return ctx, nil, newGraphRunError(fmt.Errorf("restore checkpoint fail: %w", err)) + } + + err = cm.loadChannels(cp.Channels) + if err != nil { + return ctx, nil, newGraphRunError(err) + } + if sm != nil && cp.State != nil { + err = sm(ctx, path, cp.State) + if err != nil { + return ctx, nil, newGraphRunError(fmt.Errorf("state modifier fail: %w", err)) + } + } + if cp.State != nil { + ctx = context.WithValue(ctx, stateKey{}, &internalState{state: cp.State}) + } + + nextTasks, err := r.restoreTasks(ctx, cp.Inputs, cp.SkipPreHandler, cp.ToolsNodeExecutedTools, cp.RerunNodes, isStream, optMap) // should restore after set state to context + if err != nil { + return ctx, nil, newGraphRunError(fmt.Errorf("restore tasks fail: %w", err)) + } + return ctx, nextTasks, nil +} + func newInterruptTempInfo() *interruptTempInfo { return &interruptTempInfo{ subGraphInterrupts: map[string]*subGraphInterruptError{}, @@ -808,13 +816,18 @@ func (r *runner) calculateBranch(ctx context.Context, curNodeKey string, startCh return ret, nil } -func (r *runner) initTaskManager(runWrapper runnableCallWrapper, opts ...Option) *taskManager { - return &taskManager{ - runWrapper: runWrapper, - opts: opts, - needAll: !r.eager, - done: internal.NewUnboundedChan[*task](), +func (r *runner) initTaskManager(runWrapper runnableCallWrapper, cancelVal *graphCancelChanVal, opts ...Option) *taskManager { + tm := &taskManager{ + runWrapper: runWrapper, + opts: opts, + needAll: !r.eager, + done: internal.NewUnboundedChan[*task](), + runningTasks: make(map[string]*task), + } + if cancelVal != nil { + tm.cancelCh = cancelVal.ch } + return tm } func (r *runner) initChannelManager(isStream bool) *channelManager {
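Usage note (not part of the diff): a minimal sketch of how the API added above is meant to be driven, mirroring the flow that TestCancelInterrupt exercises. It assumes it sits alongside the tests in package compose (so context and time are already imported); the node topology, the checkpoint ID and the newInMemoryStore helper are illustrative only.

// Sketch: interrupt a running graph from user code, then resume it from the checkpoint.
func runWithUserInterrupt(ctx context.Context) (string, error) {
	g := NewGraph[string, string]()
	_ = g.AddLambdaNode("slow", InvokableLambda(func(ctx context.Context, in string) (string, error) {
		time.Sleep(3 * time.Second) // long-running node the user may want to interrupt
		return in + "!", nil
	}))
	_ = g.AddEdge(START, "slow")
	_ = g.AddEdge("slow", END)

	r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore()))
	if err != nil {
		return "", err
	}

	// Derive a context that supports user-triggered interrupts.
	runCtx, interrupt := WithGraphInterrupt(ctx)
	go func() {
		time.Sleep(500 * time.Millisecond)
		// A zero timeout forces the interrupt immediately; unfinished nodes are
		// reported as RerunNodes. Calling interrupt() with no option waits for the
		// running nodes to finish and reports them as AfterNodes instead.
		interrupt(WithGraphInterruptTimeout(0))
	}()

	_, err = r.Invoke(runCtx, "input", WithCheckPointID("run-1"))
	if info, ok := ExtractInterruptInfo(err); ok {
		_ = info.RerunNodes // nodes that will be executed again on resume
		// Resume from the stored checkpoint using the same checkpoint ID.
		return r.Invoke(ctx, "input", WithCheckPointID("run-1"))
	}
	return "", err
}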
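Design note (also not part of the diff): once an interrupt with a timeout has been requested, taskManager.receive reduces to a standard Go pattern, racing a blocking receive against a timer, which is what receiveWithDeadline above implements. A stripped-down, self-contained sketch of that pattern; the names here are illustrative and not part of this change:

package main

import (
	"fmt"
	"time"
)

// waitWithDeadline races a blocking receive against a deadline. If the result
// misses the deadline the caller treats the work as canceled; in the graph this
// means the node is recorded as a rerun node and executed again on resume.
func waitWithDeadline[T any](recv func() T, deadline time.Time) (val T, timedOut bool) {
	remaining := time.Until(deadline)
	if remaining <= 0 {
		return val, true
	}

	resultCh := make(chan T, 1) // buffered so the worker never blocks on send
	go func() { resultCh <- recv() }()

	select {
	case v := <-resultCh:
		return v, false
	case <-time.After(remaining):
		return val, true
	}
}

func main() {
	slow := func() string { time.Sleep(time.Second); return "done" }
	v, timedOut := waitWithDeadline(slow, time.Now().Add(100*time.Millisecond))
	fmt.Println(v, timedOut) // prints " true": the slow task missed the deadline
}

As in receiveWithDeadline, the worker goroutine is left to finish in the background after a timeout; its result is simply discarded rather than handed back to the channel manager.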