|
5 | 5 | import pytest |
6 | 6 |
|
7 | 7 | from docket import CurrentDocket, CurrentWorker, Docket, Worker |
8 | | -from docket.dependencies import Depends, ExponentialRetry, Retry, TaskArgument |
| 8 | +from docket.dependencies import ( |
| 9 | + Depends, |
| 10 | + Dependency, |
| 11 | + ExponentialRetry, |
| 12 | + Retry, |
| 13 | + TaskArgument, |
| 14 | + _Depends, # type: ignore[attr-defined] |
| 15 | + resolved_dependencies, |
| 16 | +) |
| 17 | +from docket.execution import Execution |
9 | 18 |
|
10 | 19 |
|
11 | 20 | async def test_dependencies_may_be_duplicated(docket: Docket, worker: Worker): |
@@ -449,3 +458,216 @@ async def dependent_task(result: int = Depends(sync_adder)): |
449 | 458 | await worker.run_until_finished() |
450 | 459 |
|
451 | 460 | assert called |
| 461 | + |
| 462 | + |
| 463 | +async def test_contextvar_isolation_between_tasks(docket: Docket, worker: Worker): |
| 464 | + """Contextvars should be isolated between sequential task executions""" |
| 465 | + executions_seen: list[tuple[str, Execution]] = [] |
| 466 | + |
| 467 | + async def first_task(a: str): |
| 468 | + # Capture the execution context during first task |
| 469 | + execution = Dependency.execution.get() |
| 470 | + executions_seen.append(("first", execution)) |
| 471 | + assert a == "first" |
| 472 | + |
| 473 | + async def second_task(b: str): |
| 474 | + # Capture the execution context during second task |
| 475 | + execution = Dependency.execution.get() |
| 476 | + executions_seen.append(("second", execution)) |
| 477 | + assert b == "second" |
| 478 | + |
| 479 | + # The execution should be different from the first task |
| 480 | + first_execution = executions_seen[0][1] |
| 481 | + assert execution is not first_execution |
| 482 | + assert execution.kwargs["b"] == "second" |
| 483 | + assert first_execution.kwargs["a"] == "first" |
| 484 | + |
| 485 | + await docket.add(first_task)(a="first") |
| 486 | + await docket.add(second_task)(b="second") |
| 487 | + await worker.run_until_finished() |
| 488 | + |
| 489 | + assert len(executions_seen) == 2 |
| 490 | + assert executions_seen[0][0] == "first" |
| 491 | + assert executions_seen[1][0] == "second" |
| 492 | + |
| 493 | + |
| 494 | +async def test_contextvar_cleanup_after_task(docket: Docket, worker: Worker): |
| 495 | + """Contextvars should be reset after task execution completes""" |
| 496 | + captured_stack = None |
| 497 | + captured_cache = None |
| 498 | + |
| 499 | + async def capture_task(): |
| 500 | + nonlocal captured_stack, captured_cache |
| 501 | + # Capture references during task execution |
| 502 | + captured_stack = _Depends.stack.get() |
| 503 | + captured_cache = _Depends.cache.get() |
| 504 | + |
| 505 | + await docket.add(capture_task)() |
| 506 | + await worker.run_until_finished() |
| 507 | + |
| 508 | + # After the task completes, the contextvars should be reset |
| 509 | + # Attempting to get them should raise LookupError |
| 510 | + with pytest.raises(LookupError): |
| 511 | + _Depends.stack.get() |
| 512 | + |
| 513 | + with pytest.raises(LookupError): |
| 514 | + _Depends.cache.get() |
| 515 | + |
| 516 | + with pytest.raises(LookupError): |
| 517 | + Dependency.execution.get() |
| 518 | + |
| 519 | + with pytest.raises(LookupError): |
| 520 | + Dependency.worker.get() |
| 521 | + |
| 522 | + with pytest.raises(LookupError): |
| 523 | + Dependency.docket.get() |
| 524 | + |
| 525 | + |
| 526 | +async def test_dependency_cache_isolated_between_tasks(docket: Docket, worker: Worker): |
| 527 | + """Dependency cache should be fresh for each task, not reused""" |
| 528 | + call_counts = {"task1": 0, "task2": 0} |
| 529 | + |
| 530 | + def dependency_for_task1() -> str: |
| 531 | + call_counts["task1"] += 1 |
| 532 | + return f"task1-call-{call_counts['task1']}" |
| 533 | + |
| 534 | + def dependency_for_task2() -> str: |
| 535 | + call_counts["task2"] += 1 |
| 536 | + return f"task2-call-{call_counts['task2']}" |
| 537 | + |
| 538 | + async def first_task(val: str = Depends(dependency_for_task1)): |
| 539 | + assert val == "task1-call-1" |
| 540 | + |
| 541 | + async def second_task(val: str = Depends(dependency_for_task2)): |
| 542 | + assert val == "task2-call-1" |
| 543 | + |
| 544 | + # Run tasks sequentially |
| 545 | + await docket.add(first_task)() |
| 546 | + await worker.run_until_finished() |
| 547 | + |
| 548 | + await docket.add(second_task)() |
| 549 | + await worker.run_until_finished() |
| 550 | + |
| 551 | + # Each dependency should have been called once (no cache leakage between tasks) |
| 552 | + assert call_counts["task1"] == 1 |
| 553 | + assert call_counts["task2"] == 1 |
| 554 | + |
| 555 | + |
| 556 | +async def test_async_exit_stack_cleanup(docket: Docket, worker: Worker): |
| 557 | + """AsyncExitStack should be properly cleaned up after task execution""" |
| 558 | + from contextlib import asynccontextmanager |
| 559 | + |
| 560 | + cleanup_called: list[str] = [] |
| 561 | + |
| 562 | + @asynccontextmanager |
| 563 | + async def tracked_resource(): |
| 564 | + try: |
| 565 | + yield "resource" |
| 566 | + finally: |
| 567 | + cleanup_called.append("cleaned") |
| 568 | + |
| 569 | + async def task_with_context(res: str = Depends(tracked_resource)): |
| 570 | + assert res == "resource" |
| 571 | + assert len(cleanup_called) == 0 # Not cleaned up yet |
| 572 | + |
| 573 | + await docket.add(task_with_context)() |
| 574 | + await worker.run_until_finished() |
| 575 | + |
| 576 | + # After task completes, cleanup should have been called |
| 577 | + assert cleanup_called == ["cleaned"] |
| 578 | + |
| 579 | + |
| 580 | +async def test_contextvar_reset_on_reentrant_call(docket: Docket, worker: Worker): |
| 581 | + """Contextvars should be properly reset on reentrant calls to resolved_dependencies""" |
| 582 | + |
| 583 | + # Create two mock executions |
| 584 | + async def task1(): |
| 585 | + pass |
| 586 | + |
| 587 | + async def task2(): |
| 588 | + pass |
| 589 | + |
| 590 | + execution1 = Execution( |
| 591 | + key="task1-key", |
| 592 | + function=task1, |
| 593 | + args=(), |
| 594 | + kwargs={}, |
| 595 | + attempt=1, |
| 596 | + when=datetime.now(timezone.utc), |
| 597 | + ) |
| 598 | + |
| 599 | + execution2 = Execution( |
| 600 | + key="task2-key", |
| 601 | + function=task2, |
| 602 | + args=(), |
| 603 | + kwargs={}, |
| 604 | + attempt=1, |
| 605 | + when=datetime.now(timezone.utc), |
| 606 | + ) |
| 607 | + |
| 608 | + # Capture contextvars from first call |
| 609 | + captured_exec1 = None |
| 610 | + captured_stack1 = None |
| 611 | + |
| 612 | + async with resolved_dependencies(worker, execution1): |
| 613 | + captured_exec1 = Dependency.execution.get() |
| 614 | + captured_stack1 = _Depends.stack.get() |
| 615 | + assert captured_exec1 is execution1 |
| 616 | + |
| 617 | + # After exiting, contextvars should be reset (raise LookupError) |
| 618 | + try: |
| 619 | + current_exec = Dependency.execution.get() |
| 620 | + # If we get here without LookupError, check if it's stale |
| 621 | + assert current_exec is not execution1, ( |
| 622 | + "Contextvar still points to old execution!" |
| 623 | + ) |
| 624 | + except LookupError: |
| 625 | + # Expected - contextvar was properly reset |
| 626 | + pass |
| 627 | + |
| 628 | + # Now make a second call - should not see values from first call |
| 629 | + async with resolved_dependencies(worker, execution2): |
| 630 | + captured_exec2 = Dependency.execution.get() |
| 631 | + captured_stack2 = _Depends.stack.get() |
| 632 | + assert captured_exec2 is execution2 |
| 633 | + assert captured_exec2 is not captured_exec1 |
| 634 | + # Stacks should be different objects |
| 635 | + assert captured_stack2 is not captured_stack1 |
| 636 | + |
| 637 | + |
| 638 | +async def test_contextvar_not_leaked_to_caller(docket: Docket): |
| 639 | + """Verify contextvars don't leak outside resolved_dependencies context""" |
| 640 | + # Before calling resolved_dependencies, contextvars should not be set |
| 641 | + with pytest.raises(LookupError): |
| 642 | + Dependency.execution.get() |
| 643 | + |
| 644 | + async def dummy_task(): |
| 645 | + pass |
| 646 | + |
| 647 | + execution = Execution( |
| 648 | + key="test-key", |
| 649 | + function=dummy_task, |
| 650 | + args=(), |
| 651 | + kwargs={}, |
| 652 | + attempt=1, |
| 653 | + when=datetime.now(timezone.utc), |
| 654 | + ) |
| 655 | + |
| 656 | + from docket.worker import Worker |
| 657 | + |
| 658 | + async with Docket("test-contextvar-leak", url="memory://leak-test") as test_docket: |
| 659 | + async with Worker(test_docket) as test_worker: |
| 660 | + # Use resolved_dependencies |
| 661 | + async with resolved_dependencies(test_worker, execution): |
| 662 | + # Inside context, we should be able to get values |
| 663 | + assert Dependency.execution.get() is execution |
| 664 | + |
| 665 | + # After exiting context, contextvars should be cleaned up |
| 666 | + with pytest.raises(LookupError): |
| 667 | + Dependency.execution.get() |
| 668 | + |
| 669 | + with pytest.raises(LookupError): |
| 670 | + _Depends.stack.get() |
| 671 | + |
| 672 | + with pytest.raises(LookupError): |
| 673 | + _Depends.cache.get() |
0 commit comments