@@ -984,3 +984,179 @@ func (f *failStore) Set(ctx context.Context, checkPointID string, checkPoint []b
984984 f .t .Fatalf ("cannot call store" )
985985 return errors .New ("fail" )
986986}
987+
988+ func TestCancelInterrupt (t * testing.T ) {
989+ g := NewGraph [string , string ]()
990+ _ = g .AddLambdaNode ("1" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
991+ time .Sleep (3 * time .Second )
992+ return input + "1" , nil
993+ }))
994+ _ = g .AddLambdaNode ("2" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
995+ return input + "2" , nil
996+ }))
997+ _ = g .AddEdge (START , "1" )
998+ _ = g .AddEdge ("1" , "2" )
999+ _ = g .AddEdge ("2" , END )
1000+ ctx := context .Background ()
1001+
1002+ // pregel
1003+ r , err := g .Compile (ctx , WithCheckPointStore (newInMemoryStore ()))
1004+ assert .NoError (t , err )
1005+ // interrupt after nodes
1006+ canceledCtx , cancel := WithGraphInterrupt (ctx )
1007+ go func () {
1008+ time .Sleep (500 * time .Millisecond )
1009+ timeout := time .Hour
1010+ cancel (& timeout )
1011+ }()
1012+ _ , err = r .Invoke (canceledCtx , "input" , WithCheckPointID ("1" ))
1013+ assert .Error (t , err )
1014+ info , success := ExtractInterruptInfo (err )
1015+ assert .True (t , success )
1016+ assert .Equal (t , []string {"1" }, info .AfterNodes )
1017+ result , err := r .Invoke (ctx , "input" , WithCheckPointID ("1" ))
1018+ assert .NoError (t , err )
1019+ assert .Equal (t , "input12" , result )
1020+ // infinite timeout
1021+ canceledCtx , cancel = WithGraphInterrupt (ctx )
1022+ go func () {
1023+ time .Sleep (500 * time .Millisecond )
1024+ cancel (nil )
1025+ }()
1026+ _ , err = r .Invoke (canceledCtx , "input" , WithCheckPointID ("2" ))
1027+ assert .Error (t , err )
1028+ info , success = ExtractInterruptInfo (err )
1029+ assert .True (t , success )
1030+ assert .Equal (t , []string {"1" }, info .AfterNodes )
1031+ result , err = r .Invoke (ctx , "input" , WithCheckPointID ("2" ))
1032+ assert .NoError (t , err )
1033+ assert .Equal (t , "input12" , result )
1034+
1035+ // interrupt rerun nodes
1036+ canceledCtx , cancel = WithGraphInterrupt (ctx )
1037+ go func () {
1038+ time .Sleep (500 * time .Millisecond )
1039+ timeout := time .Duration (0 )
1040+ cancel (& timeout )
1041+ }()
1042+ _ , err = r .Invoke (canceledCtx , "input" , WithCheckPointID ("3" ))
1043+ assert .Error (t , err )
1044+ info , success = ExtractInterruptInfo (err )
1045+ assert .True (t , success )
1046+ assert .Equal (t , []string {"1" }, info .RerunNodes )
1047+ result , err = r .Invoke (ctx , "input" , WithCheckPointID ("3" ))
1048+ assert .NoError (t , err )
1049+ assert .Equal (t , "12" , result )
1050+
1051+ // dag
1052+ r , err = g .Compile (ctx , WithNodeTriggerMode (AllPredecessor ), WithCheckPointStore (newInMemoryStore ()))
1053+ assert .NoError (t , err )
1054+ // interrupt after nodes
1055+ canceledCtx , cancel = WithGraphInterrupt (ctx )
1056+ go func () {
1057+ time .Sleep (500 * time .Millisecond )
1058+ timeout := time .Hour
1059+ cancel (& timeout )
1060+ }()
1061+ _ , err = r .Invoke (canceledCtx , "input" , WithCheckPointID ("1" ))
1062+ assert .Error (t , err )
1063+ info , success = ExtractInterruptInfo (err )
1064+ assert .True (t , success )
1065+ assert .Equal (t , []string {"1" }, info .AfterNodes )
1066+ result , err = r .Invoke (ctx , "input" , WithCheckPointID ("1" ))
1067+ assert .NoError (t , err )
1068+ assert .Equal (t , "input12" , result )
1069+ // infinite timeout
1070+ canceledCtx , cancel = WithGraphInterrupt (ctx )
1071+ go func () {
1072+ time .Sleep (500 * time .Millisecond )
1073+ cancel (nil )
1074+ }()
1075+ _ , err = r .Invoke (canceledCtx , "input" , WithCheckPointID ("2" ))
1076+ assert .Error (t , err )
1077+ info , success = ExtractInterruptInfo (err )
1078+ assert .True (t , success )
1079+ assert .Equal (t , []string {"1" }, info .AfterNodes )
1080+ result , err = r .Invoke (ctx , "input" , WithCheckPointID ("2" ))
1081+ assert .NoError (t , err )
1082+ assert .Equal (t , "input12" , result )
1083+
1084+ // interrupt rerun nodes
1085+ canceledCtx , cancel = WithGraphInterrupt (ctx )
1086+ go func () {
1087+ time .Sleep (300 * time .Millisecond )
1088+ timeout := time .Duration (0 )
1089+ cancel (& timeout )
1090+ }()
1091+ _ , err = r .Invoke (canceledCtx , "input" , WithCheckPointID ("3" ))
1092+ assert .Error (t , err )
1093+ info , success = ExtractInterruptInfo (err )
1094+ assert .True (t , success )
1095+ assert .Equal (t , []string {"1" }, info .RerunNodes )
1096+ result , err = r .Invoke (ctx , "input" , WithCheckPointID ("3" ))
1097+ assert .NoError (t , err )
1098+ assert .Equal (t , "12" , result )
1099+
1100+ // dag multi canceled nodes
1101+ gg := NewGraph [string , map [string ]any ]()
1102+ _ = gg .AddLambdaNode ("1" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1103+ return input + "1" , nil
1104+ }))
1105+ _ = gg .AddLambdaNode ("2" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1106+ time .Sleep (3 * time .Second )
1107+ return input + "2" , nil
1108+ }), WithOutputKey ("2" ))
1109+ _ = gg .AddLambdaNode ("3" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1110+ time .Sleep (3 * time .Second )
1111+ return input + "3" , nil
1112+ }), WithOutputKey ("3" ))
1113+ _ = gg .AddLambdaNode ("4" , InvokableLambda (func (ctx context.Context , input map [string ]any ) (output map [string ]any , err error ) {
1114+ return input , nil
1115+ }))
1116+ _ = gg .AddEdge (START , "1" )
1117+ _ = gg .AddEdge ("1" , "2" )
1118+ _ = gg .AddEdge ("1" , "3" )
1119+ _ = gg .AddEdge ("2" , "4" )
1120+ _ = gg .AddEdge ("3" , "4" )
1121+ _ = gg .AddEdge ("4" , END )
1122+ ctx = context .Background ()
1123+ rr , err := gg .Compile (ctx , WithNodeTriggerMode (AllPredecessor ), WithCheckPointStore (newInMemoryStore ()))
1124+ assert .NoError (t , err )
1125+ // interrupt after nodes
1126+ canceledCtx , cancel = WithGraphInterrupt (ctx )
1127+ go func () {
1128+ time .Sleep (500 * time .Millisecond )
1129+ timeout := time .Hour
1130+ cancel (& timeout )
1131+ }()
1132+ _ , err = rr .Invoke (canceledCtx , "input" , WithCheckPointID ("1" ))
1133+ assert .Error (t , err )
1134+ info , success = ExtractInterruptInfo (err )
1135+ assert .True (t , success )
1136+ assert .Equal (t , 2 , len (info .AfterNodes ))
1137+ result2 , err := rr .Invoke (ctx , "input" , WithCheckPointID ("1" ))
1138+ assert .NoError (t , err )
1139+ assert .Equal (t , map [string ]any {
1140+ "2" : "input12" ,
1141+ "3" : "input13" ,
1142+ }, result2 )
1143+
1144+ // interrupt rerun nodes
1145+ canceledCtx , cancel = WithGraphInterrupt (ctx )
1146+ go func () {
1147+ time .Sleep (500 * time .Millisecond )
1148+ timeout := time .Duration (0 )
1149+ cancel (& timeout )
1150+ }()
1151+ _ , err = rr .Invoke (canceledCtx , "input" , WithCheckPointID ("2" ))
1152+ assert .Error (t , err )
1153+ info , success = ExtractInterruptInfo (err )
1154+ assert .True (t , success )
1155+ assert .Equal (t , 2 , len (info .RerunNodes ))
1156+ result2 , err = rr .Invoke (ctx , "input" , WithCheckPointID ("2" ))
1157+ assert .NoError (t , err )
1158+ assert .Equal (t , map [string ]any {
1159+ "2" : "2" ,
1160+ "3" : "3" ,
1161+ }, result2 )
1162+ }
0 commit comments