diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index b808684dd4..f72df467b6 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -139,6 +139,7 @@ import time import uuid from collections.abc import Mapping as _Mapping +from contextvars import ContextVar from typing import ( TYPE_CHECKING, Any, @@ -204,6 +205,7 @@ def __init__( causal_consistency: Optional[bool] = None, default_transaction_options: Optional[TransactionOptions] = None, snapshot: Optional[bool] = False, + bind: Optional[bool] = False, ) -> None: if snapshot: if causal_consistency: @@ -222,6 +224,7 @@ def __init__( ) self._default_transaction_options = default_transaction_options self._snapshot = snapshot + self._bind = bind @property def causal_consistency(self) -> bool: @@ -545,9 +548,12 @@ def _check_ended(self) -> None: raise InvalidOperation("Cannot use ended session") async def __aenter__(self) -> AsyncClientSession: + self._token = _SESSION.set(self) return self async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._token: + _SESSION.reset(self._token) await self._end_session(lock=True) @property @@ -1065,6 +1071,9 @@ def __copy__(self) -> NoReturn: raise TypeError("A AsyncClientSession cannot be copied, create a new session instead") +_SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None) + + class _EmptyServerSession: __slots__ = "dirty", "started_retryable_write" diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 1b25bf4ee8..08b1895c1f 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -136,9 +136,16 @@ def __init__( self._killed = False self._session: Optional[AsyncClientSession] + from .client_session import _SESSION + + bound_session = _SESSION.get() + if session: self._session = session self._explicit_session = True + elif bound_session: + self._session = bound_session + self._explicit_session = True else: self._session = None self._explicit_session = False diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index a236b21348..4b2664ad7a 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -65,7 +65,7 @@ from pymongo.asynchronous import client_session, database, uri_parser from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream from pymongo.asynchronous.client_bulk import _AsyncClientBulk -from pymongo.asynchronous.client_session import _EmptyServerSession +from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext @@ -1355,13 +1355,18 @@ def _close_cursor_soon( def _start_session(self, implicit: bool, **kwargs: Any) -> AsyncClientSession: server_session = _EmptyServerSession() opts = client_session.SessionOptions(**kwargs) - return client_session.AsyncClientSession(self, server_session, opts, implicit) + bind = opts._bind + session = client_session.AsyncClientSession(self, server_session, opts, implicit) + if bind: + _SESSION.set(session) + return session def start_session( self, causal_consistency: Optional[bool] = None, default_transaction_options: Optional[client_session.TransactionOptions] = None, snapshot: Optional[bool] = False, + bind: Optional[bool] = False, ) -> client_session.AsyncClientSession: """Start a logical session. @@ -1384,6 +1389,7 @@ def start_session( causal_consistency=causal_consistency, default_transaction_options=default_transaction_options, snapshot=snapshot, + bind=bind, ) def _ensure_session( diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index aaf2d7574f..d17bcc0868 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -139,6 +139,7 @@ import time import uuid from collections.abc import Mapping as _Mapping +from contextvars import ContextVar from typing import ( TYPE_CHECKING, Any, @@ -203,6 +204,7 @@ def __init__( causal_consistency: Optional[bool] = None, default_transaction_options: Optional[TransactionOptions] = None, snapshot: Optional[bool] = False, + bind: Optional[bool] = False, ) -> None: if snapshot: if causal_consistency: @@ -221,6 +223,7 @@ def __init__( ) self._default_transaction_options = default_transaction_options self._snapshot = snapshot + self._bind = bind @property def causal_consistency(self) -> bool: @@ -544,9 +547,12 @@ def _check_ended(self) -> None: raise InvalidOperation("Cannot use ended session") def __enter__(self) -> ClientSession: + self._token = _SESSION.set(self) return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._token: + _SESSION.reset(self._token) self._end_session(lock=True) @property @@ -1060,6 +1066,9 @@ def __copy__(self) -> NoReturn: raise TypeError("A ClientSession cannot be copied, create a new session instead") +_SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None) + + class _EmptyServerSession: __slots__ = "dirty", "started_retryable_write" diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index 31c4604f89..11f1327d53 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -136,9 +136,16 @@ def __init__( self._killed = False self._session: Optional[ClientSession] + from .client_session import _SESSION + + bound_session = _SESSION.get() + if session: self._session = session self._explicit_session = True + elif bound_session: + self._session = bound_session + self._explicit_session = True else: self._session = None self._explicit_session = False diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 99a517e5c1..c2fa1b01f9 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -107,7 +107,7 @@ from pymongo.synchronous import client_session, database, uri_parser from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.client_bulk import _ClientBulk -from pymongo.synchronous.client_session import _EmptyServerSession +from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext @@ -1353,13 +1353,18 @@ def _close_cursor_soon( def _start_session(self, implicit: bool, **kwargs: Any) -> ClientSession: server_session = _EmptyServerSession() opts = client_session.SessionOptions(**kwargs) - return client_session.ClientSession(self, server_session, opts, implicit) + bind = opts._bind + session = client_session.ClientSession(self, server_session, opts, implicit) + if bind: + _SESSION.set(session) + return session def start_session( self, causal_consistency: Optional[bool] = None, default_transaction_options: Optional[client_session.TransactionOptions] = None, snapshot: Optional[bool] = False, + bind: Optional[bool] = False, ) -> client_session.ClientSession: """Start a logical session. @@ -1382,6 +1387,7 @@ def start_session( causal_consistency=causal_consistency, default_transaction_options=default_transaction_options, snapshot=snapshot, + bind=bind, ) def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]: diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 3655f49aab..a12a29353d 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -394,6 +394,23 @@ async def test_cursor_clone(self): await cursor.close() await clone.close() + async def test_bind_session(self): + coll = self.client.pymongo_test.collection + + # Explicit session via context variable. + async with self.client.start_session(bind=True) as s: + cursor = coll.find() + self.assertTrue(cursor.session is s) + + # Nested sessions. + session1 = self.client.start_session(bind=True) + async with session1: + session2 = self.client.start_session(bind=True) + async with session2: + await coll.find_one() # uses session2 + await coll.find_one() # uses session1 + await coll.find_one() # uses implicit session + async def test_cursor(self): listener = self.listener client = self.client diff --git a/test/test_session.py b/test/test_session.py index a6266884aa..cbd78df1aa 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -394,6 +394,23 @@ def test_cursor_clone(self): cursor.close() clone.close() + def test_bind_session(self): + coll = self.client.pymongo_test.collection + + # Explicit session via context variable. + with self.client.start_session(bind=True) as s: + cursor = coll.find() + self.assertTrue(cursor.session is s) + + # Nested sessions. + session1 = self.client.start_session(bind=True) + with session1: + session2 = self.client.start_session(bind=True) + with session2: + coll.find_one() # uses session2 + coll.find_one() # uses session1 + coll.find_one() # uses implicit session + def test_cursor(self): listener = self.listener client = self.client