185 changes: 104 additions & 81 deletions pymongo/asynchronous/pool.py
@@ -732,9 +732,14 @@ def __init__(
# and returned to pool from the left side. Stale sockets removed
# from the right side.
self.conns: collections.deque[AsyncConnection] = collections.deque()
self._conns_lock = _async_create_lock()
self.active_contexts: set[_CancellationContext] = set()
self._active_contexts_lock = _async_create_lock()
# The main lock for the pool. The lock should only be used to protect
# updating attributes.
# If possible, avoid any additional work while holding the lock.
# If looping over an attribute, copy the container and do not take the lock.
self.lock = _async_create_lock()
self._max_connecting_cond = _async_create_condition(self.lock)
self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1
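The comment added here spells out the pool's locking discipline: hold `self.lock` only long enough to update attributes, and iterate shared containers over a copy rather than under the lock. A minimal sketch of that copy-before-iterate pattern with plain asyncio primitives (the class and names below are illustrative, not the real pool API):

```python
import asyncio
import collections


class TinyPool:
    """Illustrative only: a deque guarded by a lock that is held briefly."""

    def __init__(self) -> None:
        self.conns: collections.deque[str] = collections.deque()
        self._conns_lock = asyncio.Lock()

    async def add(self, conn: str) -> None:
        # Hold the lock only for the attribute update itself.
        async with self._conns_lock:
            self.conns.appendleft(conn)

    async def broadcast(self, update) -> None:
        # Iterate over a snapshot so per-connection work never runs
        # while the lock is held.
        for conn in self.conns.copy():
            update(conn)


async def main() -> None:
    pool = TinyPool()
    await pool.add("conn-1")
    await pool.broadcast(print)


asyncio.run(main())
```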
@@ -760,15 +765,19 @@ def __init__(
# The first portion of the wait queue.
# Enforces: maxPoolSize
# Also used for: clearing the wait queue
self.size_cond = _async_create_condition(self.lock)
# Use a different lock to prevent lock contention. This lock protects
# "requests".
self.size_cond = _async_create_condition(_async_create_lock())
self.requests = 0
self.max_pool_size = self.opts.max_pool_size
if not self.max_pool_size:
self.max_pool_size = float("inf")
# The second portion of the wait queue.
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = _async_create_condition(self.lock)
# Use a different lock to prevent lock contention. This lock protects
# "_pending".
self._max_connecting_cond = _async_create_condition(_async_create_lock())
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._client_id = client_id
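Both halves of the wait queue now get conditions built on their own locks instead of sharing `self.lock`, so tasks blocked on maxPoolSize do not contend with tasks blocked on maxConnecting. Assuming `_async_create_condition` and `_async_create_lock` wrap `asyncio.Condition` and `asyncio.Lock` (an assumption, not shown in this diff), the split reduces to something like:

```python
import asyncio

# Hypothetical stand-ins for the pool's helpers: each condition owns a
# dedicated lock, so waiters on one queue never contend with the other.
size_cond = asyncio.Condition(asyncio.Lock())            # guards "requests"
max_connecting_cond = asyncio.Condition(asyncio.Lock())  # guards "_pending"


async def wait_for_slot(cond: asyncio.Condition, predicate) -> None:
    # Block until the predicate holds; only the condition's own lock is taken.
    async with cond:
        await cond.wait_for(predicate)


async def release_slot(cond: asyncio.Condition) -> None:
    # Wake one waiter after freeing a slot.
    async with cond:
        cond.notify()
```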
@@ -788,6 +797,7 @@ def __init__(
)
# Similar to active_sockets but includes threads in the wait queue.
self.operation_count: int = 0
self._operation_count_lock = _async_create_lock()
# Retain references to pinned connections to prevent the CPython GC
# from thinking that a cursor's pinned connection can be GC'd when the
# cursor is GC'd (see PYTHON-2751).
Expand All @@ -797,20 +807,24 @@ def __init__(

async def ready(self) -> None:
# Take the lock to avoid the race condition described in PYTHON-2699.
async with self.lock:
if self.state != PoolState.READY:
state_changed = False
if self.state != PoolState.READY:
async with self.lock:
self.state = PoolState.READY
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_pool_ready(self.address)
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_CONNECTION_LOGGER,
message=_ConnectionStatusMessage.POOL_READY,
clientId=self._client_id,
serverHost=self.address[0],
serverPort=self.address[1],
)
state_changed = True
if not state_changed:
return
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_pool_ready(self.address)
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_CONNECTION_LOGGER,
message=_ConnectionStatusMessage.POOL_READY,
clientId=self._client_id,
serverHost=self.address[0],
serverPort=self.address[1],
)

@property
def closed(self) -> bool:
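`ready()` now flips the state under `self.lock` but publishes the CMAP event and the debug log only after the lock is released, tracked by a `state_changed` flag. A hedged sketch of deferring the side effects until after the lock (illustrative names; this sketch uses the usual double-checked form of the pattern):

```python
import asyncio


class ReadyFlag:
    """Illustrative: flip a flag under a lock, emit events after releasing it."""

    def __init__(self) -> None:
        self.lock = asyncio.Lock()
        self.ready = False

    async def mark_ready(self, publish) -> None:
        state_changed = False
        if not self.ready:                 # cheap check without the lock
            async with self.lock:
                if not self.ready:         # re-check once the lock is held
                    self.ready = True
                    state_changed = True
        if state_changed:
            publish("pool ready")          # event published outside the lock
```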
@@ -824,38 +838,45 @@ async def _reset(
interrupt_connections: bool = False,
) -> None:
old_state = self.state
async with self.size_cond:
if self.closed:
return
if self.closed:
return

async with self.lock:
if self.opts.pause_enabled and pause and not self.opts.load_balanced:
old_state, self.state = self.state, PoolState.PAUSED
self.gen.inc(service_id)
newpid = os.getpid()

if self.pid != newpid:
self.pid = newpid
self.active_sockets = 0
self.operation_count = 0

self.active_sockets = 0
async with self._conns_lock:
if service_id is None:
sockets, self.conns = self.conns, collections.deque()
else:
discard: collections.deque = collections.deque() # type: ignore[type-arg]
keep: collections.deque = collections.deque() # type: ignore[type-arg]
for conn in self.conns:
if conn.service_id == service_id:
discard.append(conn)
else:
keep.append(conn)
sockets = discard
async with self._operation_count_lock:
self.operation_count = 0
if service_id is not None:
discard: collections.deque = collections.deque() # type: ignore[type-arg]
keep: collections.deque = collections.deque() # type: ignore[type-arg]
for conn in self.conns.copy():
if conn.service_id == service_id:
discard.append(conn)
else:
keep.append(conn)
sockets = discard
async with self._conns_lock:
self.conns = keep

if close:
self.state = PoolState.CLOSED
async with self.lock:
self.state = PoolState.CLOSED
# Clear the wait queue
self._max_connecting_cond.notify_all()
self.size_cond.notify_all()

if interrupt_connections:
for context in self.active_contexts:
for context in self.active_contexts.copy():
context.cancel()

listeners = self.opts._event_listeners
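Inside `_reset` the shared containers are now snapshotted (`self.conns.copy()`, `self.active_contexts.copy()`) and each narrow lock is held only around the attribute swap, rather than doing the whole reset under one pool-wide lock. A simplified sketch of the partition-then-swap step for a single `service_id` (hypothetical helper, not the real method):

```python
import collections


async def partition_by_service(pool, service_id):
    """Illustrative: pull one service's connections out of a shared deque.

    Iterate over a snapshot, then swap the attribute while holding the lock.
    """
    discard: collections.deque = collections.deque()
    keep: collections.deque = collections.deque()
    for conn in pool.conns.copy():          # snapshot; no lock held here
        if conn.service_id == service_id:
            discard.append(conn)
        else:
            keep.append(conn)
    async with pool._conns_lock:            # lock held only for the swap
        pool.conns = keep
    return discard                          # caller closes these connections
```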
@@ -914,9 +935,8 @@ async def update_is_writable(self, is_writable: Optional[bool]) -> None:
Pool.
"""
self.is_writable = is_writable
async with self.lock:
for _socket in self.conns:
_socket.update_is_writable(self.is_writable) # type: ignore[arg-type]
for _socket in self.conns.copy():
_socket.update_is_writable(self.is_writable) # type: ignore[arg-type]

async def reset(
self, service_id: Optional[ObjectId] = None, interrupt_connections: bool = False
@@ -947,12 +967,9 @@ async def remove_stale_sockets(self, reference_generation: int) -> None:

if self.opts.max_idle_time_seconds is not None:
close_conns = []
async with self.lock:
while (
self.conns
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
):
close_conns.append(self.conns.pop())
conns = self.conns.copy()
while conns and conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds:
close_conns.append(self.conns.pop())
if not _IS_SYNC:
await asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], # type: ignore[func-returns-value]
@@ -963,12 +980,12 @@ async def remove_stale_sockets(self, reference_generation: int) -> None:
await conn.close_conn(ConnectionClosedReason.IDLE)

while True:
# There are enough sockets in the pool.
if len(self.conns) + self.active_sockets >= self.opts.min_pool_size:
return
if self.requests >= self.opts.min_pool_size:
return
async with self.size_cond:
# There are enough sockets in the pool.
if len(self.conns) + self.active_sockets >= self.opts.min_pool_size:
return
if self.requests >= self.opts.min_pool_size:
return
self.requests += 1
incremented = False
try:
@@ -978,16 +995,17 @@ async def remove_stale_sockets(self, reference_generation: int) -> None:
if self._pending >= self._max_connecting:
return
self._pending += 1
incremented = True
incremented = True
conn = await self.connect()
close_conn = False
async with self.lock:
# Close connection and return if the pool was reset during
# socket creation or while acquiring the pool lock.
if self.gen.get_overall() != reference_generation:
close_conn = True
if not close_conn:
# Close connection and return if the pool was reset during
# socket creation or while acquiring the pool lock.
if self.gen.get_overall() != reference_generation:
close_conn = True
if not close_conn:
async with self._conns_lock:
self.conns.appendleft(conn)
async with self._active_contexts_lock:
self.active_contexts.discard(conn.cancel_context)
if close_conn:
await conn.close_conn(ConnectionClosedReason.STALE)
@@ -1011,11 +1029,11 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
Note that the pool does not keep a reference to the socket -- you
must call checkin() when you're done with it.
"""
async with self.lock:
conn_id = self.next_connection_id
# Use a temporary context so that interrupt_connections can cancel creating the socket.
tmp_context = _CancellationContext()
conn_id = self.next_connection_id
async with self._active_contexts_lock:
self.next_connection_id += 1
# Use a temporary context so that interrupt_connections can cancel creating the socket.
tmp_context = _CancellationContext()
self.active_contexts.add(tmp_context)

listeners = self.opts._event_listeners
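`connect()` registers a temporary cancellation context under `_active_contexts_lock` before the handshake starts, so `interrupt_connections` can cancel a connection that is still being created. A minimal sketch of that register/check/discard flow (the `CancelContext` class below is a stand-in, not the real `_CancellationContext`):

```python
import asyncio


class CancelContext:
    """Illustrative stand-in for a shared cancellation flag."""

    def __init__(self) -> None:
        self.cancelled = False

    def cancel(self) -> None:
        self.cancelled = True


async def connect_with_context(active_contexts: set, contexts_lock: asyncio.Lock):
    tmp = CancelContext()
    async with contexts_lock:
        active_contexts.add(tmp)       # visible to interrupt_connections
    try:
        await asyncio.sleep(0)         # stands in for the real handshake
        if tmp.cancelled:
            raise asyncio.CancelledError("interrupted while connecting")
    finally:
        async with contexts_lock:
            active_contexts.discard(tmp)
```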
@@ -1036,7 +1054,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
networking_interface = await _configured_protocol_interface(self.address, self.opts)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
async with self.lock:
async with self._active_contexts_lock:
self.active_contexts.discard(tmp_context)
if self.enabled_for_cmap:
assert listeners is not None
@@ -1061,7 +1079,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
raise

conn = AsyncConnection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
async with self.lock:
async with self._active_contexts_lock:
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
@@ -1076,7 +1094,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
await conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
async with self.lock:
async with self._active_contexts_lock:
self.active_contexts.discard(conn.cancel_context)
await conn.close_conn(ConnectionClosedReason.ERROR)
raise
@@ -1136,7 +1154,7 @@ async def checkout(
durationMS=duration,
)
try:
async with self.lock:
async with self._active_contexts_lock:
self.active_contexts.add(conn.cancel_context)
yield conn
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
@@ -1155,11 +1173,11 @@ async def checkout(
await self.checkin(conn)
raise
if conn.pinned_txn:
async with self.lock:
async with self._active_contexts_lock:
self.__pinned_sockets.add(conn)
self.ntxns += 1
elif conn.pinned_cursor:
async with self.lock:
async with self._active_contexts_lock:
self.__pinned_sockets.add(conn)
self.ncursors += 1
elif conn.active:
@@ -1223,7 +1241,7 @@ async def _get_conn(
"Attempted to check out a connection from closed connection pool"
)

async with self.lock:
async with self._operation_count_lock:
self.operation_count += 1

# Get a free socket or create one.
@@ -1254,7 +1272,7 @@ async def _get_conn(
try:
async with self.lock:
self.active_sockets += 1
incremented = True
incremented = True
while conn is None:
# CMAP: we MUST wait for either maxConnecting OR for a socket
# to be checked back into the pool.
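The comment above restates the CMAP requirement: a checkout must block until either a connection is checked back into the pool or a maxConnecting slot frees up. With a condition variable that is one `wait_for` on a combined predicate; a hedged sketch (illustrative names, `pending` passed as a one-element list so the sketch can mutate it):

```python
import asyncio
import collections


async def wait_for_conn_or_slot(
    cond: asyncio.Condition,
    conns: collections.deque,
    pending: list,              # [count]; stands in for self._pending
    max_connecting: int,
):
    """Illustrative: block until a pooled conn exists or we may create one."""
    async with cond:
        await cond.wait_for(lambda: len(conns) > 0 or pending[0] < max_connecting)
        if conns:
            return conns.popleft()   # reuse a checked-in connection
        pending[0] += 1              # claim a maxConnecting slot
        return None                  # caller must create a new connection
```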
@@ -1272,7 +1290,8 @@ async def _get_conn(
self._raise_if_not_ready(checkout_started_time, emit_event=False)

try:
conn = self.conns.popleft()
async with self._conns_lock:
conn = self.conns.popleft()
except IndexError:
self._pending += 1
if conn: # We got a socket from the pool
@@ -1291,10 +1310,11 @@ async def _get_conn(
if conn:
# We checked out a socket but authentication failed.
await conn.close_conn(ConnectionClosedReason.ERROR)
if incremented:
async with self.lock:
self.active_sockets -= 1
async with self.size_cond:
self.requests -= 1
if incremented:
self.active_sockets -= 1
self.size_cond.notify()

if not emitted_event:
@@ -1330,9 +1350,9 @@ async def checkin(self, conn: AsyncConnection) -> None:
conn.active = False
conn.pinned_txn = False
conn.pinned_cursor = False
self.__pinned_sockets.discard(conn)
listeners = self.opts._event_listeners
async with self.lock:
async with self._active_contexts_lock:
self.__pinned_sockets.discard(conn)
self.active_contexts.discard(conn.cancel_context)
if self.enabled_for_cmap:
assert listeners is not None
@@ -1371,28 +1391,31 @@ async def checkin(self, conn: AsyncConnection) -> None:
)
else:
close_conn = False
async with self.lock:
# Hold the lock to ensure this section does not race with
# Pool.reset().
if self.stale_generation(conn.generation, conn.service_id):
close_conn = True
else:
conn.update_last_checkin_time()
conn.update_is_writable(bool(self.is_writable))
conn.update_last_checkin_time()
conn.update_is_writable(bool(self.is_writable))
if self.stale_generation(conn.generation, conn.service_id):
close_conn = True
else:
async with self._conns_lock:
self.conns.appendleft(conn)
async with self._max_connecting_cond:
# Notify any threads waiting to create a connection.
self._max_connecting_cond.notify()
if close_conn:
await conn.close_conn(ConnectionClosedReason.STALE)

async with self.size_cond:
async with self._active_contexts_lock:
self.active_sockets -= 1
if txn:
self.ntxns -= 1
elif cursor:
self.ncursors -= 1
self.requests -= 1
self.active_sockets -= 1

async with self._operation_count_lock:
self.operation_count -= 1

async with self.size_cond:
self.requests -= 1
self.size_cond.notify()

async def _perished(self, conn: AsyncConnection) -> bool: