Skip to content

Commit fcbfae9

Browse files
chrisguidryclaude
andcommitted
Reset contextvar tokens in resolved_dependencies
The resolved_dependencies function was setting contextvars without capturing reset tokens, which left stale references after the context manager exited. This could cause issues in reentrant or sequential calls, like "stack is closed" errors when the AsyncExitStack was already closed but still referenced. Now we capture tokens from all contextvar.set() calls and reset them in finally blocks to restore prior state. This follows the same pattern suggested in the fastmcp PR review: jlowin/fastmcp#2318 Added tests that verify contextvars are properly isolated and cleaned up between task executions, including a test for reentrant calls to resolved_dependencies that would previously fail with stale context. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 2822eaa commit fcbfae9

File tree

2 files changed

+269
-35
lines changed

2 files changed

+269
-35
lines changed

src/docket/dependencies.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -697,38 +697,50 @@ def __init__(self, parameter: str, error: Exception) -> None:
697697
async def resolved_dependencies(
698698
worker: "Worker", execution: Execution
699699
) -> AsyncGenerator[dict[str, Any], None]:
700-
# Set context variables once at the beginning
701-
Dependency.docket.set(worker.docket)
702-
Dependency.worker.set(worker)
703-
Dependency.execution.set(execution)
704-
705-
_Depends.cache.set({})
706-
707-
async with AsyncExitStack() as stack:
708-
_Depends.stack.set(stack)
709-
710-
arguments: dict[str, Any] = {}
711-
712-
parameters = get_dependency_parameters(execution.function)
713-
for parameter, dependency in parameters.items():
714-
kwargs = execution.kwargs
715-
if parameter in kwargs:
716-
arguments[parameter] = kwargs[parameter]
717-
continue
718-
719-
# Special case for TaskArguments, they are "magical" and infer the parameter
720-
# they refer to from the parameter name (unless otherwise specified). At
721-
# the top-level task function call, it doesn't make sense to specify one
722-
# _without_ a parameter name, so we'll call that a failed dependency.
723-
if isinstance(dependency, _TaskArgument) and not dependency.parameter:
724-
arguments[parameter] = FailedDependency(
725-
parameter, ValueError("No parameter name specified")
726-
)
727-
continue
728-
700+
# Capture tokens for all contextvar sets to ensure proper cleanup
701+
docket_token = Dependency.docket.set(worker.docket)
702+
worker_token = Dependency.worker.set(worker)
703+
execution_token = Dependency.execution.set(execution)
704+
cache_token = _Depends.cache.set({})
705+
706+
try:
707+
async with AsyncExitStack() as stack:
708+
stack_token = _Depends.stack.set(stack)
729709
try:
730-
arguments[parameter] = await stack.enter_async_context(dependency)
731-
except Exception as error:
732-
arguments[parameter] = FailedDependency(parameter, error)
733-
734-
yield arguments
710+
arguments: dict[str, Any] = {}
711+
712+
parameters = get_dependency_parameters(execution.function)
713+
for parameter, dependency in parameters.items():
714+
kwargs = execution.kwargs
715+
if parameter in kwargs:
716+
arguments[parameter] = kwargs[parameter]
717+
continue
718+
719+
# Special case for TaskArguments, they are "magical" and infer the parameter
720+
# they refer to from the parameter name (unless otherwise specified). At
721+
# the top-level task function call, it doesn't make sense to specify one
722+
# _without_ a parameter name, so we'll call that a failed dependency.
723+
if (
724+
isinstance(dependency, _TaskArgument)
725+
and not dependency.parameter
726+
):
727+
arguments[parameter] = FailedDependency(
728+
parameter, ValueError("No parameter name specified")
729+
)
730+
continue
731+
732+
try:
733+
arguments[parameter] = await stack.enter_async_context(
734+
dependency
735+
)
736+
except Exception as error:
737+
arguments[parameter] = FailedDependency(parameter, error)
738+
739+
yield arguments
740+
finally:
741+
_Depends.stack.reset(stack_token)
742+
finally:
743+
_Depends.cache.reset(cache_token)
744+
Dependency.execution.reset(execution_token)
745+
Dependency.worker.reset(worker_token)
746+
Dependency.docket.reset(docket_token)

tests/test_dependencies.py

Lines changed: 223 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,16 @@
55
import pytest
66

77
from docket import CurrentDocket, CurrentWorker, Docket, Worker
8-
from docket.dependencies import Depends, ExponentialRetry, Retry, TaskArgument
8+
from docket.dependencies import (
9+
Depends,
10+
Dependency,
11+
ExponentialRetry,
12+
Retry,
13+
TaskArgument,
14+
_Depends, # type: ignore[attr-defined]
15+
resolved_dependencies,
16+
)
17+
from docket.execution import Execution
918

1019

1120
async def test_dependencies_may_be_duplicated(docket: Docket, worker: Worker):
@@ -449,3 +458,216 @@ async def dependent_task(result: int = Depends(sync_adder)):
449458
await worker.run_until_finished()
450459

451460
assert called
461+
462+
463+
async def test_contextvar_isolation_between_tasks(docket: Docket, worker: Worker):
464+
"""Contextvars should be isolated between sequential task executions"""
465+
executions_seen: list[tuple[str, Execution]] = []
466+
467+
async def first_task(a: str):
468+
# Capture the execution context during first task
469+
execution = Dependency.execution.get()
470+
executions_seen.append(("first", execution))
471+
assert a == "first"
472+
473+
async def second_task(b: str):
474+
# Capture the execution context during second task
475+
execution = Dependency.execution.get()
476+
executions_seen.append(("second", execution))
477+
assert b == "second"
478+
479+
# The execution should be different from the first task
480+
first_execution = executions_seen[0][1]
481+
assert execution is not first_execution
482+
assert execution.kwargs["b"] == "second"
483+
assert first_execution.kwargs["a"] == "first"
484+
485+
await docket.add(first_task)(a="first")
486+
await docket.add(second_task)(b="second")
487+
await worker.run_until_finished()
488+
489+
assert len(executions_seen) == 2
490+
assert executions_seen[0][0] == "first"
491+
assert executions_seen[1][0] == "second"
492+
493+
494+
async def test_contextvar_cleanup_after_task(docket: Docket, worker: Worker):
495+
"""Contextvars should be reset after task execution completes"""
496+
captured_stack = None
497+
captured_cache = None
498+
499+
async def capture_task():
500+
nonlocal captured_stack, captured_cache
501+
# Capture references during task execution
502+
captured_stack = _Depends.stack.get()
503+
captured_cache = _Depends.cache.get()
504+
505+
await docket.add(capture_task)()
506+
await worker.run_until_finished()
507+
508+
# After the task completes, the contextvars should be reset
509+
# Attempting to get them should raise LookupError
510+
with pytest.raises(LookupError):
511+
_Depends.stack.get()
512+
513+
with pytest.raises(LookupError):
514+
_Depends.cache.get()
515+
516+
with pytest.raises(LookupError):
517+
Dependency.execution.get()
518+
519+
with pytest.raises(LookupError):
520+
Dependency.worker.get()
521+
522+
with pytest.raises(LookupError):
523+
Dependency.docket.get()
524+
525+
526+
async def test_dependency_cache_isolated_between_tasks(docket: Docket, worker: Worker):
527+
"""Dependency cache should be fresh for each task, not reused"""
528+
call_counts = {"task1": 0, "task2": 0}
529+
530+
def dependency_for_task1() -> str:
531+
call_counts["task1"] += 1
532+
return f"task1-call-{call_counts['task1']}"
533+
534+
def dependency_for_task2() -> str:
535+
call_counts["task2"] += 1
536+
return f"task2-call-{call_counts['task2']}"
537+
538+
async def first_task(val: str = Depends(dependency_for_task1)):
539+
assert val == "task1-call-1"
540+
541+
async def second_task(val: str = Depends(dependency_for_task2)):
542+
assert val == "task2-call-1"
543+
544+
# Run tasks sequentially
545+
await docket.add(first_task)()
546+
await worker.run_until_finished()
547+
548+
await docket.add(second_task)()
549+
await worker.run_until_finished()
550+
551+
# Each dependency should have been called once (no cache leakage between tasks)
552+
assert call_counts["task1"] == 1
553+
assert call_counts["task2"] == 1
554+
555+
556+
async def test_async_exit_stack_cleanup(docket: Docket, worker: Worker):
557+
"""AsyncExitStack should be properly cleaned up after task execution"""
558+
from contextlib import asynccontextmanager
559+
560+
cleanup_called: list[str] = []
561+
562+
@asynccontextmanager
563+
async def tracked_resource():
564+
try:
565+
yield "resource"
566+
finally:
567+
cleanup_called.append("cleaned")
568+
569+
async def task_with_context(res: str = Depends(tracked_resource)):
570+
assert res == "resource"
571+
assert len(cleanup_called) == 0 # Not cleaned up yet
572+
573+
await docket.add(task_with_context)()
574+
await worker.run_until_finished()
575+
576+
# After task completes, cleanup should have been called
577+
assert cleanup_called == ["cleaned"]
578+
579+
580+
async def test_contextvar_reset_on_reentrant_call(docket: Docket, worker: Worker):
581+
"""Contextvars should be properly reset on reentrant calls to resolved_dependencies"""
582+
583+
# Create two mock executions
584+
async def task1():
585+
pass
586+
587+
async def task2():
588+
pass
589+
590+
execution1 = Execution(
591+
key="task1-key",
592+
function=task1,
593+
args=(),
594+
kwargs={},
595+
attempt=1,
596+
when=datetime.now(timezone.utc),
597+
)
598+
599+
execution2 = Execution(
600+
key="task2-key",
601+
function=task2,
602+
args=(),
603+
kwargs={},
604+
attempt=1,
605+
when=datetime.now(timezone.utc),
606+
)
607+
608+
# Capture contextvars from first call
609+
captured_exec1 = None
610+
captured_stack1 = None
611+
612+
async with resolved_dependencies(worker, execution1):
613+
captured_exec1 = Dependency.execution.get()
614+
captured_stack1 = _Depends.stack.get()
615+
assert captured_exec1 is execution1
616+
617+
# After exiting, contextvars should be reset (raise LookupError)
618+
try:
619+
current_exec = Dependency.execution.get()
620+
# If we get here without LookupError, check if it's stale
621+
assert current_exec is not execution1, (
622+
"Contextvar still points to old execution!"
623+
)
624+
except LookupError:
625+
# Expected - contextvar was properly reset
626+
pass
627+
628+
# Now make a second call - should not see values from first call
629+
async with resolved_dependencies(worker, execution2):
630+
captured_exec2 = Dependency.execution.get()
631+
captured_stack2 = _Depends.stack.get()
632+
assert captured_exec2 is execution2
633+
assert captured_exec2 is not captured_exec1
634+
# Stacks should be different objects
635+
assert captured_stack2 is not captured_stack1
636+
637+
638+
async def test_contextvar_not_leaked_to_caller(docket: Docket):
639+
"""Verify contextvars don't leak outside resolved_dependencies context"""
640+
# Before calling resolved_dependencies, contextvars should not be set
641+
with pytest.raises(LookupError):
642+
Dependency.execution.get()
643+
644+
async def dummy_task():
645+
pass
646+
647+
execution = Execution(
648+
key="test-key",
649+
function=dummy_task,
650+
args=(),
651+
kwargs={},
652+
attempt=1,
653+
when=datetime.now(timezone.utc),
654+
)
655+
656+
from docket.worker import Worker
657+
658+
async with Docket("test-contextvar-leak", url="memory://leak-test") as test_docket:
659+
async with Worker(test_docket) as test_worker:
660+
# Use resolved_dependencies
661+
async with resolved_dependencies(test_worker, execution):
662+
# Inside context, we should be able to get values
663+
assert Dependency.execution.get() is execution
664+
665+
# After exiting context, contextvars should be cleaned up
666+
with pytest.raises(LookupError):
667+
Dependency.execution.get()
668+
669+
with pytest.raises(LookupError):
670+
_Depends.stack.get()
671+
672+
with pytest.raises(LookupError):
673+
_Depends.cache.get()

0 commit comments

Comments
 (0)