@@ -115,7 +115,11 @@ def __init__( # noqa: PLR0913
115115 failure_chance : float ,
116116 mode : Mode ,
117117 probe : Callable [[Dependencies , int ], Any ] | None ,
118+ assert_eventually : Callable [[Dependencies , int ], None ] | None ,
119+ assert_always : Callable [[Dependencies , int ], None ] | None ,
118120 ) -> None :
121+ self ._assert_eventually = assert_eventually
122+ self ._assert_always = assert_always
119123 self ._runnable_coroutines : RunnableCoroutines = []
120124 self ._awaitables : Awaitables = {}
121125 self ._runnable_functions : RunnableFunctions = []
@@ -333,7 +337,7 @@ def _run_function_and_move_awaitables_to_runnables(
333337 v = v ,
334338 )
335339
336- def _run (self ) -> list [Promise [Any ]]:
340+ def _run (self ) -> list [Promise [Any ]]: # noqa: C901
337341 promises = self ._initialize_runnables ()
338342
339343 while True :
@@ -347,6 +351,10 @@ def _run(self) -> list[Promise[Any]]:
347351
348352 if self ._probe is not None :
349353 self ._probe_results .append (self ._probe (self .deps , self .tick ))
354+
355+ if self ._assert_always is not None :
356+ self ._assert_always (self .deps , self .tick )
357+
350358 self .tick += 1
351359
352360 if (
@@ -370,11 +378,19 @@ def _run(self) -> list[Promise[Any]]:
370378 else :
371379 assert_never (next_step )
372380
381+ if self ._assert_eventually is not None :
382+ self ._assert_eventually (self .deps , self .tick )
383+
373384 assert all (p .done () for p in promises ), "All promises should be resolved."
374385 if self ._log_file is not None :
375386 self .dump (file = self ._log_file )
376387 return promises
377388
389+ def _add_coro_to_awaitables (
390+ self , p : Promise [Any ], coro_and_promise : CoroAndPromise [Any ]
391+ ) -> None :
392+ self ._awaitables .setdefault (p , []).append (coro_and_promise )
393+
378394 def _process_each_runnable (
379395 self ,
380396 runnable : Runnable [Any ],
@@ -414,7 +430,9 @@ def _process_each_runnable(
414430 invocation = yieldable_or_final_value .to_invoke (),
415431 runnable = runnable ,
416432 )
417- self ._awaitables [p ] = [runnable .coro_and_promise ]
433+ self ._add_coro_to_awaitables (
434+ p = p , coro_and_promise = runnable .coro_and_promise
435+ )
418436 self ._events .append (
419437 AwaitedForPromise (promise_id = p .promise_id , tick = self .tick )
420438 )
@@ -427,8 +445,8 @@ def _process_each_runnable(
427445 self ._runnable_coroutines .append (Runnable (runnable .coro_and_promise , Ok (p )))
428446
429447 elif isinstance (yieldable_or_final_value , Promise ):
430- self ._awaitables . setdefault ( yieldable_or_final_value , []). append (
431- runnable .coro_and_promise ,
448+ self ._add_coro_to_awaitables (
449+ yieldable_or_final_value , runnable .coro_and_promise
432450 )
433451 if yieldable_or_final_value .done ():
434452 unblock_depands_coros (
0 commit comments