Skip to content

Commit fb4c0db

Browse files
feat: support cancel interrupt
1 parent 4768f82 commit fb4c0db

File tree

4 files changed

+466
-93
lines changed

4 files changed

+466
-93
lines changed

compose/checkpoint_test.go

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,3 +1015,184 @@ func TestPreHandlerInterrupt(t *testing.T) {
10151015
assert.NoError(t, err)
10161016
assert.Equal(t, "1", result)
10171017
}
1018+
1019+
func TestCancelInterrupt(t *testing.T) {
1020+
g := NewGraph[string, string]()
1021+
_ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1022+
time.Sleep(3 * time.Second)
1023+
return input + "1", nil
1024+
}))
1025+
_ = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1026+
return input + "2", nil
1027+
}))
1028+
_ = g.AddEdge(START, "1")
1029+
_ = g.AddEdge("1", "2")
1030+
_ = g.AddEdge("2", END)
1031+
ctx := context.Background()
1032+
1033+
// pregel
1034+
r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore()))
1035+
assert.NoError(t, err)
1036+
// interrupt after nodes
1037+
canceledCtx, cancel := WithGraphInterrupt(ctx)
1038+
go func() {
1039+
time.Sleep(500 * time.Millisecond)
1040+
cancel(WithGraphInterruptTimeout(time.Hour))
1041+
}()
1042+
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("1"))
1043+
assert.Error(t, err)
1044+
info, success := ExtractInterruptInfo(err)
1045+
assert.True(t, success)
1046+
assert.Equal(t, []string{"1"}, info.AfterNodes)
1047+
result, err := r.Invoke(ctx, "input", WithCheckPointID("1"))
1048+
assert.NoError(t, err)
1049+
assert.Equal(t, "input12", result)
1050+
// infinite timeout
1051+
canceledCtx, cancel = WithGraphInterrupt(ctx)
1052+
go func() {
1053+
time.Sleep(500 * time.Millisecond)
1054+
cancel()
1055+
}()
1056+
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("2"))
1057+
assert.Error(t, err)
1058+
info, success = ExtractInterruptInfo(err)
1059+
assert.True(t, success)
1060+
assert.Equal(t, []string{"1"}, info.AfterNodes)
1061+
result, err = r.Invoke(ctx, "input", WithCheckPointID("2"))
1062+
assert.NoError(t, err)
1063+
assert.Equal(t, "input12", result)
1064+
1065+
// interrupt rerun nodes
1066+
canceledCtx, cancel = WithGraphInterrupt(ctx)
1067+
go func() {
1068+
time.Sleep(500 * time.Millisecond)
1069+
cancel(WithGraphInterruptTimeout(0))
1070+
}()
1071+
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("3"))
1072+
assert.Error(t, err)
1073+
info, success = ExtractInterruptInfo(err)
1074+
assert.True(t, success)
1075+
assert.Equal(t, []string{"1"}, info.RerunNodes)
1076+
result, err = r.Invoke(ctx, "input", WithCheckPointID("3"))
1077+
assert.NoError(t, err)
1078+
assert.Equal(t, "12", result)
1079+
1080+
// dag
1081+
g = NewGraph[string, string]()
1082+
_ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1083+
time.Sleep(3 * time.Second)
1084+
return input + "1", nil
1085+
}))
1086+
_ = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1087+
return input + "2", nil
1088+
}))
1089+
_ = g.AddEdge(START, "1")
1090+
_ = g.AddEdge("1", "2")
1091+
_ = g.AddEdge("2", END)
1092+
r, err = g.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(newInMemoryStore()))
1093+
assert.NoError(t, err)
1094+
// interrupt after nodes
1095+
canceledCtx, cancel = WithGraphInterrupt(ctx)
1096+
go func() {
1097+
time.Sleep(500 * time.Millisecond)
1098+
cancel(WithGraphInterruptTimeout(time.Hour))
1099+
}()
1100+
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("1"))
1101+
assert.Error(t, err)
1102+
info, success = ExtractInterruptInfo(err)
1103+
assert.True(t, success)
1104+
assert.Equal(t, []string{"1"}, info.AfterNodes)
1105+
result, err = r.Invoke(ctx, "input", WithCheckPointID("1"))
1106+
assert.NoError(t, err)
1107+
assert.Equal(t, "input12", result)
1108+
// infinite timeout
1109+
canceledCtx, cancel = WithGraphInterrupt(ctx)
1110+
go func() {
1111+
time.Sleep(500 * time.Millisecond)
1112+
cancel()
1113+
}()
1114+
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("2"))
1115+
assert.Error(t, err)
1116+
info, success = ExtractInterruptInfo(err)
1117+
assert.True(t, success)
1118+
assert.Equal(t, []string{"1"}, info.AfterNodes)
1119+
result, err = r.Invoke(ctx, "input", WithCheckPointID("2"))
1120+
assert.NoError(t, err)
1121+
assert.Equal(t, "input12", result)
1122+
1123+
// interrupt rerun nodes
1124+
canceledCtx, cancel = WithGraphInterrupt(ctx)
1125+
go func() {
1126+
time.Sleep(300 * time.Millisecond)
1127+
cancel(WithGraphInterruptTimeout(0))
1128+
}()
1129+
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("3"))
1130+
assert.Error(t, err)
1131+
info, success = ExtractInterruptInfo(err)
1132+
assert.True(t, success)
1133+
assert.Equal(t, []string{"1"}, info.RerunNodes)
1134+
result, err = r.Invoke(ctx, "input", WithCheckPointID("3"))
1135+
assert.NoError(t, err)
1136+
assert.Equal(t, "12", result)
1137+
1138+
// dag multi canceled nodes
1139+
gg := NewGraph[string, map[string]any]()
1140+
_ = gg.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1141+
return input + "1", nil
1142+
}))
1143+
_ = gg.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1144+
time.Sleep(3 * time.Second)
1145+
return input + "2", nil
1146+
}), WithOutputKey("2"))
1147+
_ = gg.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1148+
time.Sleep(3 * time.Second)
1149+
return input + "3", nil
1150+
}), WithOutputKey("3"))
1151+
_ = gg.AddLambdaNode("4", InvokableLambda(func(ctx context.Context, input map[string]any) (output map[string]any, err error) {
1152+
return input, nil
1153+
}))
1154+
_ = gg.AddEdge(START, "1")
1155+
_ = gg.AddEdge("1", "2")
1156+
_ = gg.AddEdge("1", "3")
1157+
_ = gg.AddEdge("2", "4")
1158+
_ = gg.AddEdge("3", "4")
1159+
_ = gg.AddEdge("4", END)
1160+
ctx = context.Background()
1161+
rr, err := gg.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(newInMemoryStore()))
1162+
assert.NoError(t, err)
1163+
// interrupt after nodes
1164+
canceledCtx, cancel = WithGraphInterrupt(ctx)
1165+
go func() {
1166+
time.Sleep(500 * time.Millisecond)
1167+
cancel(WithGraphInterruptTimeout(time.Hour))
1168+
}()
1169+
_, err = rr.Invoke(canceledCtx, "input", WithCheckPointID("1"))
1170+
assert.Error(t, err)
1171+
info, success = ExtractInterruptInfo(err)
1172+
assert.True(t, success)
1173+
assert.Equal(t, 2, len(info.AfterNodes))
1174+
result2, err := rr.Invoke(ctx, "input", WithCheckPointID("1"))
1175+
assert.NoError(t, err)
1176+
assert.Equal(t, map[string]any{
1177+
"2": "input12",
1178+
"3": "input13",
1179+
}, result2)
1180+
1181+
// interrupt rerun nodes
1182+
canceledCtx, cancel = WithGraphInterrupt(ctx)
1183+
go func() {
1184+
time.Sleep(500 * time.Millisecond)
1185+
cancel(WithGraphInterruptTimeout(0))
1186+
}()
1187+
_, err = rr.Invoke(canceledCtx, "input", WithCheckPointID("2"))
1188+
assert.Error(t, err)
1189+
info, success = ExtractInterruptInfo(err)
1190+
assert.True(t, success)
1191+
assert.Equal(t, 2, len(info.RerunNodes))
1192+
result2, err = rr.Invoke(ctx, "input", WithCheckPointID("2"))
1193+
assert.NoError(t, err)
1194+
assert.Equal(t, map[string]any{
1195+
"2": "2",
1196+
"3": "3",
1197+
}, result2)
1198+
}

compose/graph_call_options.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
package compose
1818

1919
import (
20+
"context"
2021
"fmt"
2122
"reflect"
23+
"time"
2224

2325
"github.com/cloudwego/eino/callbacks"
2426
"github.com/cloudwego/eino/components/document"
@@ -29,6 +31,51 @@ import (
2931
"github.com/cloudwego/eino/components/retriever"
3032
)
3133

34+
type graphCancelChanKey struct{}
35+
type graphCancelChanVal struct {
36+
ch chan *time.Duration
37+
}
38+
39+
type graphInterruptOptions struct {
40+
timeout *time.Duration
41+
}
42+
43+
type GraphInterruptOption func(o *graphInterruptOptions)
44+
45+
// WithGraphInterruptTimeout specifies the max waiting time before generating an interrupt.
46+
// After the max waiting time, the graph will force an interrupt. Any unfinished tasks will be re-run when the graph is resumed.
47+
func WithGraphInterruptTimeout(timeout time.Duration) GraphInterruptOption {
48+
return func(o *graphInterruptOptions) {
49+
o.timeout = &timeout
50+
}
51+
}
52+
53+
// WithGraphInterrupt creates a context with graph cancellation support.
54+
// When the returned context is used to invoke a graph or workflow, calling the interrupt function will trigger an interrupt.
55+
// The graph will wait for current tasks to complete by default.
56+
func WithGraphInterrupt(parent context.Context) (ctx context.Context, interrupt func(opts ...GraphInterruptOption)) {
57+
ch := make(chan *time.Duration, 1)
58+
ctx = context.WithValue(parent, graphCancelChanKey{}, &graphCancelChanVal{
59+
ch: ch,
60+
})
61+
return ctx, func(opts ...GraphInterruptOption) {
62+
o := &graphInterruptOptions{}
63+
for _, opt := range opts {
64+
opt(o)
65+
}
66+
ch <- o.timeout
67+
close(ch)
68+
}
69+
}
70+
71+
func getGraphCancel(ctx context.Context) *graphCancelChanVal {
72+
val, ok := ctx.Value(graphCancelChanKey{}).(*graphCancelChanVal)
73+
if !ok {
74+
return nil
75+
}
76+
return val
77+
}
78+
3279
// Option is a functional option type for calling a graph.
3380
type Option struct {
3481
options []any

0 commit comments

Comments
 (0)