Skip to content

PYTHON-4542 Improved sessions API #2335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -204,6 +205,7 @@ def __init__(
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> None:
if snapshot:
if causal_consistency:
Expand All @@ -222,6 +224,7 @@ def __init__(
)
self._default_transaction_options = default_transaction_options
self._snapshot = snapshot
self._bind = bind
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to bind/unbind the session in ClientSession.__enter__/__exit__. That way the stack of sessions is managed correctly (ie we call _SESSION.reset(token)). Think about how nested cases will work:

session1 = client.start_session(bind=True)
with session1:
    session2 = client.start_session(bind=True)
    with session2:
        coll.find_one() # uses session2
    coll.find_one() # uses session1
coll.find_one() # uses implicit session


@property
def causal_consistency(self) -> bool:
Expand Down Expand Up @@ -545,9 +548,12 @@ def _check_ended(self) -> None:
raise InvalidOperation("Cannot use ended session")

async def __aenter__(self) -> AsyncClientSession:
self._token = _SESSION.set(self)
return self

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._token:
_SESSION.reset(self._token)
await self._end_session(lock=True)

@property
Expand Down Expand Up @@ -1065,6 +1071,9 @@ def __copy__(self) -> NoReturn:
raise TypeError("A AsyncClientSession cannot be copied, create a new session instead")


_SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None)


class _EmptyServerSession:
__slots__ = "dirty", "started_retryable_write"

Expand Down
7 changes: 7 additions & 0 deletions pymongo/asynchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,16 @@ def __init__(
self._killed = False
self._session: Optional[AsyncClientSession]

from .client_session import _SESSION

bound_session = _SESSION.get()

if session:
self._session = session
self._explicit_session = True
elif bound_session:
self._session = bound_session
self._explicit_session = True
else:
self._session = None
self._explicit_session = False
Expand Down
10 changes: 8 additions & 2 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from pymongo.asynchronous import client_session, database, uri_parser
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _EmptyServerSession
from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext
Expand Down Expand Up @@ -1355,13 +1355,18 @@ def _close_cursor_soon(
def _start_session(self, implicit: bool, **kwargs: Any) -> AsyncClientSession:
server_session = _EmptyServerSession()
opts = client_session.SessionOptions(**kwargs)
return client_session.AsyncClientSession(self, server_session, opts, implicit)
bind = opts._bind
session = client_session.AsyncClientSession(self, server_session, opts, implicit)
if bind:
_SESSION.set(session)
return session

def start_session(
self,
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[client_session.TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> client_session.AsyncClientSession:
"""Start a logical session.

Expand All @@ -1384,6 +1389,7 @@ def start_session(
causal_consistency=causal_consistency,
default_transaction_options=default_transaction_options,
snapshot=snapshot,
bind=bind,
)

def _ensure_session(
Expand Down
9 changes: 9 additions & 0 deletions pymongo/synchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -203,6 +204,7 @@ def __init__(
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> None:
if snapshot:
if causal_consistency:
Expand All @@ -221,6 +223,7 @@ def __init__(
)
self._default_transaction_options = default_transaction_options
self._snapshot = snapshot
self._bind = bind

@property
def causal_consistency(self) -> bool:
Expand Down Expand Up @@ -544,9 +547,12 @@ def _check_ended(self) -> None:
raise InvalidOperation("Cannot use ended session")

def __enter__(self) -> ClientSession:
self._token = _SESSION.set(self)
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._token:
_SESSION.reset(self._token)
self._end_session(lock=True)

@property
Expand Down Expand Up @@ -1060,6 +1066,9 @@ def __copy__(self) -> NoReturn:
raise TypeError("A ClientSession cannot be copied, create a new session instead")


_SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None)


class _EmptyServerSession:
__slots__ = "dirty", "started_retryable_write"

Expand Down
7 changes: 7 additions & 0 deletions pymongo/synchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,16 @@ def __init__(
self._killed = False
self._session: Optional[ClientSession]

from .client_session import _SESSION

bound_session = _SESSION.get()

if session:
self._session = session
self._explicit_session = True
elif bound_session:
self._session = bound_session
self._explicit_session = True
else:
self._session = None
self._explicit_session = False
Expand Down
10 changes: 8 additions & 2 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
from pymongo.synchronous import client_session, database, uri_parser
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.synchronous.client_session import _EmptyServerSession
from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology, _ErrorContext
Expand Down Expand Up @@ -1353,13 +1353,18 @@ def _close_cursor_soon(
def _start_session(self, implicit: bool, **kwargs: Any) -> ClientSession:
server_session = _EmptyServerSession()
opts = client_session.SessionOptions(**kwargs)
return client_session.ClientSession(self, server_session, opts, implicit)
bind = opts._bind
session = client_session.ClientSession(self, server_session, opts, implicit)
if bind:
_SESSION.set(session)
return session

def start_session(
self,
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[client_session.TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> client_session.ClientSession:
"""Start a logical session.

Expand All @@ -1382,6 +1387,7 @@ def start_session(
causal_consistency=causal_consistency,
default_transaction_options=default_transaction_options,
snapshot=snapshot,
bind=bind,
)

def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]:
Expand Down
17 changes: 17 additions & 0 deletions test/asynchronous/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,23 @@ async def test_cursor_clone(self):
await cursor.close()
await clone.close()

async def test_bind_session(self):
coll = self.client.pymongo_test.collection

# Explicit session via context variable.
async with self.client.start_session(bind=True) as s:
cursor = coll.find()
self.assertTrue(cursor.session is s)

# Nested sessions.
session1 = self.client.start_session(bind=True)
async with session1:
session2 = self.client.start_session(bind=True)
async with session2:
await coll.find_one() # uses session2
await coll.find_one() # uses session1
await coll.find_one() # uses implicit session

async def test_cursor(self):
listener = self.listener
client = self.client
Expand Down
17 changes: 17 additions & 0 deletions test/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,23 @@ def test_cursor_clone(self):
cursor.close()
clone.close()

def test_bind_session(self):
coll = self.client.pymongo_test.collection

# Explicit session via context variable.
with self.client.start_session(bind=True) as s:
cursor = coll.find()
self.assertTrue(cursor.session is s)

# Nested sessions.
session1 = self.client.start_session(bind=True)
with session1:
session2 = self.client.start_session(bind=True)
with session2:
coll.find_one() # uses session2
coll.find_one() # uses session1
coll.find_one() # uses implicit session

def test_cursor(self):
listener = self.listener
client = self.client
Expand Down
Loading