Skip to content

Commit 687f0b8

Browse files
feat: support cancel interrupt
1 parent 5c18870 commit 687f0b8

File tree

4 files changed

+453
-92
lines changed

4 files changed

+453
-92
lines changed

compose/checkpoint_test.go

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,3 +984,179 @@ func (f *failStore) Set(ctx context.Context, checkPointID string, checkPoint []b
984984
f.t.Fatalf("cannot call store")
985985
return errors.New("fail")
986986
}
987+
988+
func TestCancelInterrupt(t *testing.T) {
989+
g := NewGraph[string, string]()
990+
_ = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
991+
time.Sleep(3 * time.Second)
992+
return input + "1", nil
993+
}))
994+
_ = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
995+
return input + "2", nil
996+
}))
997+
_ = g.AddEdge(START, "1")
998+
_ = g.AddEdge("1", "2")
999+
_ = g.AddEdge("2", END)
1000+
ctx := context.Background()
1001+
1002+
// pregel
1003+
r, err := g.Compile(ctx, WithCheckPointStore(newInMemoryStore()))
1004+
assert.NoError(t, err)
1005+
// interrupt after nodes
1006+
canceledCtx, cancel := WithGraphInterrupt(ctx)
1007+
go func() {
1008+
time.Sleep(500 * time.Millisecond)
1009+
timeout := time.Hour
1010+
cancel(&timeout)
1011+
}()
1012+
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("1"))
1013+
assert.Error(t, err)
1014+
info, success := ExtractInterruptInfo(err)
1015+
assert.True(t, success)
1016+
assert.Equal(t, []string{"1"}, info.AfterNodes)
1017+
result, err := r.Invoke(ctx, "input", WithCheckPointID("1"))
1018+
assert.NoError(t, err)
1019+
assert.Equal(t, "input12", result)
1020+
// infinite timeout
1021+
canceledCtx, cancel = WithGraphInterrupt(ctx)
1022+
go func() {
1023+
time.Sleep(500 * time.Millisecond)
1024+
cancel(nil)
1025+
}()
1026+
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("2"))
1027+
assert.Error(t, err)
1028+
info, success = ExtractInterruptInfo(err)
1029+
assert.True(t, success)
1030+
assert.Equal(t, []string{"1"}, info.AfterNodes)
1031+
result, err = r.Invoke(ctx, "input", WithCheckPointID("2"))
1032+
assert.NoError(t, err)
1033+
assert.Equal(t, "input12", result)
1034+
1035+
// interrupt rerun nodes
1036+
canceledCtx, cancel = WithGraphInterrupt(ctx)
1037+
go func() {
1038+
time.Sleep(500 * time.Millisecond)
1039+
timeout := time.Duration(0)
1040+
cancel(&timeout)
1041+
}()
1042+
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("3"))
1043+
assert.Error(t, err)
1044+
info, success = ExtractInterruptInfo(err)
1045+
assert.True(t, success)
1046+
assert.Equal(t, []string{"1"}, info.RerunNodes)
1047+
result, err = r.Invoke(ctx, "input", WithCheckPointID("3"))
1048+
assert.NoError(t, err)
1049+
assert.Equal(t, "12", result)
1050+
1051+
// dag
1052+
r, err = g.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(newInMemoryStore()))
1053+
assert.NoError(t, err)
1054+
// interrupt after nodes
1055+
canceledCtx, cancel = WithGraphInterrupt(ctx)
1056+
go func() {
1057+
time.Sleep(500 * time.Millisecond)
1058+
timeout := time.Hour
1059+
cancel(&timeout)
1060+
}()
1061+
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("1"))
1062+
assert.Error(t, err)
1063+
info, success = ExtractInterruptInfo(err)
1064+
assert.True(t, success)
1065+
assert.Equal(t, []string{"1"}, info.AfterNodes)
1066+
result, err = r.Invoke(ctx, "input", WithCheckPointID("1"))
1067+
assert.NoError(t, err)
1068+
assert.Equal(t, "input12", result)
1069+
// infinite timeout
1070+
canceledCtx, cancel = WithGraphInterrupt(ctx)
1071+
go func() {
1072+
time.Sleep(500 * time.Millisecond)
1073+
cancel(nil)
1074+
}()
1075+
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("2"))
1076+
assert.Error(t, err)
1077+
info, success = ExtractInterruptInfo(err)
1078+
assert.True(t, success)
1079+
assert.Equal(t, []string{"1"}, info.AfterNodes)
1080+
result, err = r.Invoke(ctx, "input", WithCheckPointID("2"))
1081+
assert.NoError(t, err)
1082+
assert.Equal(t, "input12", result)
1083+
1084+
// interrupt rerun nodes
1085+
canceledCtx, cancel = WithGraphInterrupt(ctx)
1086+
go func() {
1087+
time.Sleep(300 * time.Millisecond)
1088+
timeout := time.Duration(0)
1089+
cancel(&timeout)
1090+
}()
1091+
_, err = r.Invoke(canceledCtx, "input", WithCheckPointID("3"))
1092+
assert.Error(t, err)
1093+
info, success = ExtractInterruptInfo(err)
1094+
assert.True(t, success)
1095+
assert.Equal(t, []string{"1"}, info.RerunNodes)
1096+
result, err = r.Invoke(ctx, "input", WithCheckPointID("3"))
1097+
assert.NoError(t, err)
1098+
assert.Equal(t, "12", result)
1099+
1100+
// dag multi canceled nodes
1101+
gg := NewGraph[string, map[string]any]()
1102+
_ = gg.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1103+
return input + "1", nil
1104+
}))
1105+
_ = gg.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1106+
time.Sleep(3 * time.Second)
1107+
return input + "2", nil
1108+
}), WithOutputKey("2"))
1109+
_ = gg.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1110+
time.Sleep(3 * time.Second)
1111+
return input + "3", nil
1112+
}), WithOutputKey("3"))
1113+
_ = gg.AddLambdaNode("4", InvokableLambda(func(ctx context.Context, input map[string]any) (output map[string]any, err error) {
1114+
return input, nil
1115+
}))
1116+
_ = gg.AddEdge(START, "1")
1117+
_ = gg.AddEdge("1", "2")
1118+
_ = gg.AddEdge("1", "3")
1119+
_ = gg.AddEdge("2", "4")
1120+
_ = gg.AddEdge("3", "4")
1121+
_ = gg.AddEdge("4", END)
1122+
ctx = context.Background()
1123+
rr, err := gg.Compile(ctx, WithNodeTriggerMode(AllPredecessor), WithCheckPointStore(newInMemoryStore()))
1124+
assert.NoError(t, err)
1125+
// interrupt after nodes
1126+
canceledCtx, cancel = WithGraphInterrupt(ctx)
1127+
go func() {
1128+
time.Sleep(500 * time.Millisecond)
1129+
timeout := time.Hour
1130+
cancel(&timeout)
1131+
}()
1132+
_, err = rr.Invoke(canceledCtx, "input", WithCheckPointID("1"))
1133+
assert.Error(t, err)
1134+
info, success = ExtractInterruptInfo(err)
1135+
assert.True(t, success)
1136+
assert.Equal(t, 2, len(info.AfterNodes))
1137+
result2, err := rr.Invoke(ctx, "input", WithCheckPointID("1"))
1138+
assert.NoError(t, err)
1139+
assert.Equal(t, map[string]any{
1140+
"2": "input12",
1141+
"3": "input13",
1142+
}, result2)
1143+
1144+
// interrupt rerun nodes
1145+
canceledCtx, cancel = WithGraphInterrupt(ctx)
1146+
go func() {
1147+
time.Sleep(500 * time.Millisecond)
1148+
timeout := time.Duration(0)
1149+
cancel(&timeout)
1150+
}()
1151+
_, err = rr.Invoke(canceledCtx, "input", WithCheckPointID("2"))
1152+
assert.Error(t, err)
1153+
info, success = ExtractInterruptInfo(err)
1154+
assert.True(t, success)
1155+
assert.Equal(t, 2, len(info.RerunNodes))
1156+
result2, err = rr.Invoke(ctx, "input", WithCheckPointID("2"))
1157+
assert.NoError(t, err)
1158+
assert.Equal(t, map[string]any{
1159+
"2": "2",
1160+
"3": "3",
1161+
}, result2)
1162+
}

compose/graph_call_options.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
package compose
1818

1919
import (
20+
"context"
2021
"fmt"
2122
"reflect"
23+
"time"
2224

2325
"github.com/cloudwego/eino/callbacks"
2426
"github.com/cloudwego/eino/components/document"
@@ -29,6 +31,34 @@ import (
2931
"github.com/cloudwego/eino/components/retriever"
3032
)
3133

34+
type graphCancelChanKey struct{}
35+
type graphCancelChanVal struct {
36+
ch chan *time.Duration
37+
}
38+
39+
// WithGraphInterrupt creates a context with graph cancellation support.
40+
// When the returned context is used to invoke a graph or workflow, calling the interrupt function will trigger an interrupt.
41+
// The graph will wait for current tasks to complete (up to maxWaitTime) before generating an interrupt.
42+
// After maxWaitTime, the graph will force an interrupt. Any unfinished tasks will be re-run when the graph is resumed.
43+
func WithGraphInterrupt(parent context.Context) (ctx context.Context, interrupt func(maxWaitTime *time.Duration)) {
44+
ch := make(chan *time.Duration, 1)
45+
ctx = context.WithValue(parent, graphCancelChanKey{}, &graphCancelChanVal{
46+
ch: ch,
47+
})
48+
return ctx, func(maxWaitTime *time.Duration) {
49+
ch <- maxWaitTime
50+
close(ch)
51+
}
52+
}
53+
54+
func getGraphCancel(ctx context.Context) *graphCancelChanVal {
55+
val, ok := ctx.Value(graphCancelChanKey{}).(*graphCancelChanVal)
56+
if !ok {
57+
return nil
58+
}
59+
return val
60+
}
61+
3262
// Option is a functional option type for calling a graph.
3363
type Option struct {
3464
options []any

0 commit comments

Comments
 (0)