@@ -1015,3 +1015,184 @@ func TestPreHandlerInterrupt(t *testing.T) {
10151015 assert .NoError (t , err )
10161016 assert .Equal (t , "1" , result )
10171017}
1018+
1019+ func TestCancelInterrupt (t * testing.T ) {
1020+ g := NewGraph [string , string ]()
1021+ _ = g .AddLambdaNode ("1" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1022+ time .Sleep (3 * time .Second )
1023+ return input + "1" , nil
1024+ }))
1025+ _ = g .AddLambdaNode ("2" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1026+ return input + "2" , nil
1027+ }))
1028+ _ = g .AddEdge (START , "1" )
1029+ _ = g .AddEdge ("1" , "2" )
1030+ _ = g .AddEdge ("2" , END )
1031+ ctx := context .Background ()
1032+
1033+ // pregel
1034+ r , err := g .Compile (ctx , WithCheckPointStore (newInMemoryStore ()))
1035+ assert .NoError (t , err )
1036+ // interrupt after nodes
1037+ canceledCtx , cancel := WithGraphInterrupt (ctx )
1038+ go func () {
1039+ time .Sleep (500 * time .Millisecond )
1040+ cancel (WithGraphInterruptTimeout (time .Hour ))
1041+ }()
1042+ _ , err = r .Invoke (canceledCtx , "input" , WithCheckPointID ("1" ))
1043+ assert .Error (t , err )
1044+ info , success := ExtractInterruptInfo (err )
1045+ assert .True (t , success )
1046+ assert .Equal (t , []string {"1" }, info .AfterNodes )
1047+ result , err := r .Invoke (ctx , "input" , WithCheckPointID ("1" ))
1048+ assert .NoError (t , err )
1049+ assert .Equal (t , "input12" , result )
1050+ // infinite timeout
1051+ canceledCtx , cancel = WithGraphInterrupt (ctx )
1052+ go func () {
1053+ time .Sleep (500 * time .Millisecond )
1054+ cancel ()
1055+ }()
1056+ _ , err = r .Invoke (canceledCtx , "input" , WithCheckPointID ("2" ))
1057+ assert .Error (t , err )
1058+ info , success = ExtractInterruptInfo (err )
1059+ assert .True (t , success )
1060+ assert .Equal (t , []string {"1" }, info .AfterNodes )
1061+ result , err = r .Invoke (ctx , "input" , WithCheckPointID ("2" ))
1062+ assert .NoError (t , err )
1063+ assert .Equal (t , "input12" , result )
1064+
1065+ // interrupt rerun nodes
1066+ canceledCtx , cancel = WithGraphInterrupt (ctx )
1067+ go func () {
1068+ time .Sleep (500 * time .Millisecond )
1069+ cancel (WithGraphInterruptTimeout (0 ))
1070+ }()
1071+ _ , err = r .Invoke (canceledCtx , "input" , WithCheckPointID ("3" ))
1072+ assert .Error (t , err )
1073+ info , success = ExtractInterruptInfo (err )
1074+ assert .True (t , success )
1075+ assert .Equal (t , []string {"1" }, info .RerunNodes )
1076+ result , err = r .Invoke (ctx , "input" , WithCheckPointID ("3" ))
1077+ assert .NoError (t , err )
1078+ assert .Equal (t , "12" , result )
1079+
1080+ // dag
1081+ g = NewGraph [string , string ]()
1082+ _ = g .AddLambdaNode ("1" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1083+ time .Sleep (3 * time .Second )
1084+ return input + "1" , nil
1085+ }))
1086+ _ = g .AddLambdaNode ("2" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1087+ return input + "2" , nil
1088+ }))
1089+ _ = g .AddEdge (START , "1" )
1090+ _ = g .AddEdge ("1" , "2" )
1091+ _ = g .AddEdge ("2" , END )
1092+ r , err = g .Compile (ctx , WithNodeTriggerMode (AllPredecessor ), WithCheckPointStore (newInMemoryStore ()))
1093+ assert .NoError (t , err )
1094+ // interrupt after nodes
1095+ canceledCtx , cancel = WithGraphInterrupt (ctx )
1096+ go func () {
1097+ time .Sleep (500 * time .Millisecond )
1098+ cancel (WithGraphInterruptTimeout (time .Hour ))
1099+ }()
1100+ _ , err = r .Invoke (canceledCtx , "input" , WithCheckPointID ("1" ))
1101+ assert .Error (t , err )
1102+ info , success = ExtractInterruptInfo (err )
1103+ assert .True (t , success )
1104+ assert .Equal (t , []string {"1" }, info .AfterNodes )
1105+ result , err = r .Invoke (ctx , "input" , WithCheckPointID ("1" ))
1106+ assert .NoError (t , err )
1107+ assert .Equal (t , "input12" , result )
1108+ // infinite timeout
1109+ canceledCtx , cancel = WithGraphInterrupt (ctx )
1110+ go func () {
1111+ time .Sleep (500 * time .Millisecond )
1112+ cancel ()
1113+ }()
1114+ _ , err = r .Invoke (canceledCtx , "input" , WithCheckPointID ("2" ))
1115+ assert .Error (t , err )
1116+ info , success = ExtractInterruptInfo (err )
1117+ assert .True (t , success )
1118+ assert .Equal (t , []string {"1" }, info .AfterNodes )
1119+ result , err = r .Invoke (ctx , "input" , WithCheckPointID ("2" ))
1120+ assert .NoError (t , err )
1121+ assert .Equal (t , "input12" , result )
1122+
1123+ // interrupt rerun nodes
1124+ canceledCtx , cancel = WithGraphInterrupt (ctx )
1125+ go func () {
1126+ time .Sleep (300 * time .Millisecond )
1127+ cancel (WithGraphInterruptTimeout (0 ))
1128+ }()
1129+ _ , err = r .Invoke (canceledCtx , "input" , WithCheckPointID ("3" ))
1130+ assert .Error (t , err )
1131+ info , success = ExtractInterruptInfo (err )
1132+ assert .True (t , success )
1133+ assert .Equal (t , []string {"1" }, info .RerunNodes )
1134+ result , err = r .Invoke (ctx , "input" , WithCheckPointID ("3" ))
1135+ assert .NoError (t , err )
1136+ assert .Equal (t , "12" , result )
1137+
1138+ // dag multi canceled nodes
1139+ gg := NewGraph [string , map [string ]any ]()
1140+ _ = gg .AddLambdaNode ("1" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1141+ return input + "1" , nil
1142+ }))
1143+ _ = gg .AddLambdaNode ("2" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1144+ time .Sleep (3 * time .Second )
1145+ return input + "2" , nil
1146+ }), WithOutputKey ("2" ))
1147+ _ = gg .AddLambdaNode ("3" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1148+ time .Sleep (3 * time .Second )
1149+ return input + "3" , nil
1150+ }), WithOutputKey ("3" ))
1151+ _ = gg .AddLambdaNode ("4" , InvokableLambda (func (ctx context.Context , input map [string ]any ) (output map [string ]any , err error ) {
1152+ return input , nil
1153+ }))
1154+ _ = gg .AddEdge (START , "1" )
1155+ _ = gg .AddEdge ("1" , "2" )
1156+ _ = gg .AddEdge ("1" , "3" )
1157+ _ = gg .AddEdge ("2" , "4" )
1158+ _ = gg .AddEdge ("3" , "4" )
1159+ _ = gg .AddEdge ("4" , END )
1160+ ctx = context .Background ()
1161+ rr , err := gg .Compile (ctx , WithNodeTriggerMode (AllPredecessor ), WithCheckPointStore (newInMemoryStore ()))
1162+ assert .NoError (t , err )
1163+ // interrupt after nodes
1164+ canceledCtx , cancel = WithGraphInterrupt (ctx )
1165+ go func () {
1166+ time .Sleep (500 * time .Millisecond )
1167+ cancel (WithGraphInterruptTimeout (time .Hour ))
1168+ }()
1169+ _ , err = rr .Invoke (canceledCtx , "input" , WithCheckPointID ("1" ))
1170+ assert .Error (t , err )
1171+ info , success = ExtractInterruptInfo (err )
1172+ assert .True (t , success )
1173+ assert .Equal (t , 2 , len (info .AfterNodes ))
1174+ result2 , err := rr .Invoke (ctx , "input" , WithCheckPointID ("1" ))
1175+ assert .NoError (t , err )
1176+ assert .Equal (t , map [string ]any {
1177+ "2" : "input12" ,
1178+ "3" : "input13" ,
1179+ }, result2 )
1180+
1181+ // interrupt rerun nodes
1182+ canceledCtx , cancel = WithGraphInterrupt (ctx )
1183+ go func () {
1184+ time .Sleep (500 * time .Millisecond )
1185+ cancel (WithGraphInterruptTimeout (0 ))
1186+ }()
1187+ _ , err = rr .Invoke (canceledCtx , "input" , WithCheckPointID ("2" ))
1188+ assert .Error (t , err )
1189+ info , success = ExtractInterruptInfo (err )
1190+ assert .True (t , success )
1191+ assert .Equal (t , 2 , len (info .RerunNodes ))
1192+ result2 , err = rr .Invoke (ctx , "input" , WithCheckPointID ("2" ))
1193+ assert .NoError (t , err )
1194+ assert .Equal (t , map [string ]any {
1195+ "2" : "2" ,
1196+ "3" : "3" ,
1197+ }, result2 )
1198+ }
0 commit comments