Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions compose/checkpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1015,3 +1015,184 @@ func TestPreHandlerInterrupt(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "1", result)
}

func TestCancelInterrupt(t *testing.T) {
g := NewGraph[string, string]()
_ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
time.Sleep(3 * time.Second)
return input + "1", nil
}))
_ = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
return input + "2", nil
}))
_ = g.AddEdge(START, "1")
_ = g.AddEdge("1", "2")
_ = g.AddEdge("2", END)
ctx := context.Background()

// pregel
r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore()))
assert.NoError(t, err)
// interrupt after nodes
canceledCtx, cancel := WithGraphInterrupt(ctx)
go func() {
time.Sleep(500 * time.Millisecond)
cancel(WithGraphInterruptTimeout(time.Hour))
}()
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("1"))
assert.Error(t, err)
info, success := ExtractInterruptInfo(err)
assert.True(t, success)
assert.Equal(t, []string{"1"}, info.AfterNodes)
result, err := r.Invoke(ctx, "input", WithCheckPointID("1"))
assert.NoError(t, err)
assert.Equal(t, "input12", result)
// infinite timeout
canceledCtx, cancel = WithGraphInterrupt(ctx)
go func() {
time.Sleep(500 * time.Millisecond)
cancel()
}()
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("2"))
assert.Error(t, err)
info, success = ExtractInterruptInfo(err)
assert.True(t, success)
assert.Equal(t, []string{"1"}, info.AfterNodes)
result, err = r.Invoke(ctx, "input", WithCheckPointID("2"))
assert.NoError(t, err)
assert.Equal(t, "input12", result)

// interrupt rerun nodes
canceledCtx, cancel = WithGraphInterrupt(ctx)
go func() {
time.Sleep(500 * time.Millisecond)
cancel(WithGraphInterruptTimeout(0))
}()
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("3"))
assert.Error(t, err)
info, success = ExtractInterruptInfo(err)
assert.True(t, success)
assert.Equal(t, []string{"1"}, info.RerunNodes)
result, err = r.Invoke(ctx, "input", WithCheckPointID("3"))
assert.NoError(t, err)
assert.Equal(t, "12", result)

// dag
g = NewGraph[string, string]()
_ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
time.Sleep(3 * time.Second)
return input + "1", nil
}))
_ = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
return input + "2", nil
}))
_ = g.AddEdge(START, "1")
_ = g.AddEdge("1", "2")
_ = g.AddEdge("2", END)
r, err = g.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(newInMemoryStore()))
assert.NoError(t, err)
// interrupt after nodes
canceledCtx, cancel = WithGraphInterrupt(ctx)
go func() {
time.Sleep(500 * time.Millisecond)
cancel(WithGraphInterruptTimeout(time.Hour))
}()
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("1"))
assert.Error(t, err)
info, success = ExtractInterruptInfo(err)
assert.True(t, success)
assert.Equal(t, []string{"1"}, info.AfterNodes)
result, err = r.Invoke(ctx, "input", WithCheckPointID("1"))
assert.NoError(t, err)
assert.Equal(t, "input12", result)
// infinite timeout
canceledCtx, cancel = WithGraphInterrupt(ctx)
go func() {
time.Sleep(500 * time.Millisecond)
cancel()
}()
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("2"))
assert.Error(t, err)
info, success = ExtractInterruptInfo(err)
assert.True(t, success)
assert.Equal(t, []string{"1"}, info.AfterNodes)
result, err = r.Invoke(ctx, "input", WithCheckPointID("2"))
assert.NoError(t, err)
assert.Equal(t, "input12", result)

// interrupt rerun nodes
canceledCtx, cancel = WithGraphInterrupt(ctx)
go func() {
time.Sleep(300 * time.Millisecond)
cancel(WithGraphInterruptTimeout(0))
}()
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("3"))
assert.Error(t, err)
info, success = ExtractInterruptInfo(err)
assert.True(t, success)
assert.Equal(t, []string{"1"}, info.RerunNodes)
result, err = r.Invoke(ctx, "input", WithCheckPointID("3"))
assert.NoError(t, err)
assert.Equal(t, "12", result)

// dag multi canceled nodes
gg := NewGraph[string, map[string]any]()
_ = gg.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
return input + "1", nil
}))
_ = gg.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
time.Sleep(3 * time.Second)
return input + "2", nil
}), WithOutputKey("2"))
_ = gg.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
time.Sleep(3 * time.Second)
return input + "3", nil
}), WithOutputKey("3"))
_ = gg.AddLambdaNode("4", InvokableLambda(func(ctx context.Context, input map[string]any) (output map[string]any, err error) {
return input, nil
}))
_ = gg.AddEdge(START, "1")
_ = gg.AddEdge("1", "2")
_ = gg.AddEdge("1", "3")
_ = gg.AddEdge("2", "4")
_ = gg.AddEdge("3", "4")
_ = gg.AddEdge("4", END)
ctx = context.Background()
rr, err := gg.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(newInMemoryStore()))
assert.NoError(t, err)
// interrupt after nodes
canceledCtx, cancel = WithGraphInterrupt(ctx)
go func() {
time.Sleep(500 * time.Millisecond)
cancel(WithGraphInterruptTimeout(time.Hour))
}()
_, err = rr.Invoke(canceledCtx, "input", WithCheckPointID("1"))
assert.Error(t, err)
info, success = ExtractInterruptInfo(err)
assert.True(t, success)
assert.Equal(t, 2, len(info.AfterNodes))
result2, err := rr.Invoke(ctx, "input", WithCheckPointID("1"))
assert.NoError(t, err)
assert.Equal(t, map[string]any{
"2": "input12",
"3": "input13",
}, result2)

// interrupt rerun nodes
canceledCtx, cancel = WithGraphInterrupt(ctx)
go func() {
time.Sleep(500 * time.Millisecond)
cancel(WithGraphInterruptTimeout(0))
}()
_, err = rr.Invoke(canceledCtx, "input", WithCheckPointID("2"))
assert.Error(t, err)
info, success = ExtractInterruptInfo(err)
assert.True(t, success)
assert.Equal(t, 2, len(info.RerunNodes))
result2, err = rr.Invoke(ctx, "input", WithCheckPointID("2"))
assert.NoError(t, err)
assert.Equal(t, map[string]any{
"2": "2",
"3": "3",
}, result2)
}
47 changes: 47 additions & 0 deletions compose/graph_call_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
package compose

import (
"context"
"fmt"
"reflect"
"time"

"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/document"
Expand All @@ -29,6 +31,51 @@ import (
"github.com/cloudwego/eino/components/retriever"
)

type graphCancelChanKey struct{}
type graphCancelChanVal struct {
ch chan *time.Duration
}

type graphInterruptOptions struct {
timeout *time.Duration
}

type GraphInterruptOption func(o *graphInterruptOptions)

// WithGraphInterruptTimeout specifies the max waiting time before generating an interrupt.
// After the max waiting time, the graph will force an interrupt. Any unfinished tasks will be re-run when the graph is resumed.
func WithGraphInterruptTimeout(timeout time.Duration) GraphInterruptOption {
return func(o *graphInterruptOptions) {
o.timeout = &timeout
}
}

// WithGraphInterrupt creates a context with graph cancellation support.
// When the returned context is used to invoke a graph or workflow, calling the interrupt function will trigger an interrupt.
// The graph will wait for current tasks to complete by default.
func WithGraphInterrupt(parent context.Context) (ctx context.Context, interrupt func(opts ...GraphInterruptOption)) {
ch := make(chan *time.Duration, 1)
ctx = context.WithValue(parent, graphCancelChanKey{}, &graphCancelChanVal{
ch: ch,
})
return ctx, func(opts ...GraphInterruptOption) {
o := &graphInterruptOptions{}
for _, opt := range opts {
opt(o)
}
ch <- o.timeout
close(ch)
}
}

func getGraphCancel(ctx context.Context) *graphCancelChanVal {
val, ok := ctx.Value(graphCancelChanKey{}).(*graphCancelChanVal)
if !ok {
return nil
}
return val
}

// Option is a functional option type for calling a graph.
type Option struct {
options []any
Expand Down
Loading
Loading