Skip to content

Commit adc5c10

Browse files
committed
Track next_msg and async for generators the same way
Signed-off-by: Waldemar Quevedo <[email protected]>
1 parent 54642f2 commit adc5c10

File tree

3 files changed

+279
-27
lines changed

3 files changed

+279
-27
lines changed

nats/src/nats/aio/client.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -754,12 +754,8 @@ async def _close(self, status: int, do_cbs: bool = True) -> None:
754754
# Async subs use join when draining already so just cancel here.
755755
if sub._wait_for_msgs_task and not sub._wait_for_msgs_task.done():
756756
sub._wait_for_msgs_task.cancel()
757-
# Sync subs may have some inflight next_msg calls that could be blocking
758-
# so cancel them here to unblock them.
759-
if sub._pending_next_msgs_calls:
760-
for fut in sub._pending_next_msgs_calls.values():
761-
fut.cancel()
762-
sub._pending_next_msgs_calls.clear()
757+
# For sync subs, _stop_processing() sends sentinels to unblock any waiting consumers.
758+
sub._stop_processing()
763759
self._subs.clear()
764760

765761
if self._transport is not None:
@@ -1802,9 +1798,9 @@ async def _process_msg(
18021798
await sub._jsi.check_for_sequence_mismatch(msg)
18031799

18041800
# Send sentinel after reaching max messages for non-callback subscriptions.
1805-
if max_msgs_reached and not sub._cb and sub._active_generators > 0:
1806-
# Send one sentinel per active generator to unblock them all.
1807-
for _ in range(sub._active_generators):
1801+
if max_msgs_reached and not sub._cb and sub._active_consumers is not None and sub._active_consumers > 0:
1802+
# Send one sentinel per active consumer to unblock them all.
1803+
for _ in range(sub._active_consumers):
18081804
try:
18091805
sub._pending_queue.put_nowait(None)
18101806
except Exception:

nats/src/nats/aio/subscription.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
AsyncIterator,
2121
Awaitable,
2222
Callable,
23-
Dict,
2423
Optional,
2524
)
2625

@@ -80,18 +79,16 @@ def __init__(
8079
self._cb = cb
8180
self._future = future
8281
self._closed = False
83-
self._active_generators = 0 # Track active async generators
8482

8583
# Per subscription message processor.
8684
self._pending_msgs_limit = pending_msgs_limit
8785
self._pending_bytes_limit = pending_bytes_limit
8886
self._pending_queue: asyncio.Queue[Msg] = asyncio.Queue(maxsize=pending_msgs_limit)
89-
# If no callback, then this is a sync subscription which will
90-
# require tracking the next_msg calls inflight for cancelling.
87+
# Track active consumers (both async generators and next_msg calls) for non-callback subscriptions.
9188
if cb is None:
92-
self._pending_next_msgs_calls: Optional[Dict[str, asyncio.Task]] = {}
89+
self._active_consumers = 0 # Counter of active consumers waiting for messages
9390
else:
94-
self._pending_next_msgs_calls = None
91+
self._active_consumers = None
9592
self._pending_size = 0
9693
self._wait_for_msgs_task = None
9794

@@ -138,7 +135,8 @@ async def _message_generator(self) -> AsyncIterator[Msg]:
138135
Async generator that yields messages directly from the subscription queue.
139136
"""
140137
yielded_count = 0
141-
self._active_generators += 1
138+
if self._active_consumers is not None:
139+
self._active_consumers += 1
142140
try:
143141
while True:
144142
# Check if subscription was cancelled/closed.
@@ -171,7 +169,8 @@ async def _message_generator(self) -> AsyncIterator[Msg]:
171169
except asyncio.CancelledError:
172170
pass
173171
finally:
174-
self._active_generators -= 1
172+
if self._active_consumers is not None:
173+
self._active_consumers -= 1
175174

176175
@property
177176
def pending_msgs(self) -> int:
@@ -225,6 +224,10 @@ async def next_msg(self, timeout: Optional[float] = 1.0) -> Msg:
225224
if self._cb:
226225
raise errors.Error("nats: next_msg cannot be used in async subscriptions")
227226

227+
# Track this next_msg call
228+
if self._active_consumers is not None:
229+
self._active_consumers += 1
230+
228231
try:
229232
if timeout == 0 or timeout is None:
230233
# Wait forever for a message
@@ -240,13 +243,25 @@ async def next_msg(self, timeout: Optional[float] = 1.0) -> Msg:
240243
if self._conn.is_closed:
241244
raise errors.ConnectionClosedError
242245
raise
243-
else:
244-
self._pending_size -= len(msg.data)
245-
# For sync subscriptions we will consider a message
246-
# to be done once it has been consumed by the client
247-
# regardless of whether it has been processed.
246+
finally:
247+
# Untrack this next_msg call.
248+
if self._active_consumers is not None:
249+
self._active_consumers -= 1
250+
251+
# Check for the None sentinel, which signals this consumer to stop waiting.
252+
if msg is None:
248253
self._pending_queue.task_done()
249-
return msg
254+
if self._conn.is_closed:
255+
raise errors.ConnectionClosedError
256+
raise errors.TimeoutError
257+
258+
self._pending_size -= len(msg.data)
259+
260+
# NOTE: For sync subscriptions we will consider a message
261+
# to be done once it has been consumed by the client
262+
# regardless of whether it has been processed.
263+
self._pending_queue.task_done()
264+
return msg
250265

251266
def _start(self, error_cb):
252267
"""
@@ -337,11 +352,12 @@ def _stop_processing(self) -> None:
337352
if self._wait_for_msgs_task and not self._wait_for_msgs_task.done():
338353
self._wait_for_msgs_task.cancel()
339354

340-
# Only put sentinel if there are active async generators
355+
# Send sentinels to unblock waiting consumers
341356
try:
342-
if self._pending_queue and self._active_generators > 0:
343-
# Put a None sentinel to wake up any async generators
344-
self._pending_queue.put_nowait(None)
357+
if self._pending_queue and self._active_consumers is not None and self._active_consumers > 0:
358+
# Send one sentinel for each active consumer (both generators and next_msg calls)
359+
for _ in range(self._active_consumers):
360+
self._pending_queue.put_nowait(None)
345361
except Exception:
346362
# Queue might be closed or full, that's ok
347363
pass

nats/tests/test_client.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,246 @@ async def test_subscribe_async_generator_with_drain(self):
800800

801801
await nc.close()
802802

803+
@async_test
804+
async def test_subscribe_concurrent_next_msg(self):
805+
"""Test multiple concurrent next_msg() calls on the same subscription"""
806+
nc = NATS()
807+
await nc.connect()
808+
809+
sub = await nc.subscribe("test.concurrent.next")
810+
811+
# Publish messages
812+
num_msgs = 12
813+
for i in range(num_msgs):
814+
await nc.publish("test.concurrent.next", f"msg-{i}".encode())
815+
await nc.flush()
816+
817+
# Track results from concurrent next_msg calls
818+
consumer_results = {}
819+
820+
async def consumer_task(consumer_id: str, msg_count: int):
821+
"""Consumer task that uses next_msg() to get messages"""
822+
import random
823+
824+
received = []
825+
try:
826+
for _ in range(msg_count):
827+
msg = await sub.next_msg(timeout=2.0)
828+
received.append(msg.data.decode())
829+
# Add random processing delay
830+
await asyncio.sleep(random.uniform(0.01, 0.03))
831+
except Exception as e:
832+
consumer_results[consumer_id] = f"Error: {e}"
833+
return
834+
consumer_results[consumer_id] = received
835+
836+
# Start multiple concurrent consumers using next_msg()
837+
tasks = [
838+
asyncio.create_task(consumer_task("consumer_A", 3)),
839+
asyncio.create_task(consumer_task("consumer_B", 5)),
840+
asyncio.create_task(consumer_task("consumer_C", 4)),
841+
]
842+
843+
# Wait for all consumers to finish
844+
await asyncio.gather(*tasks)
845+
846+
# Verify results
847+
consumer_A_msgs = consumer_results.get("consumer_A", [])
848+
consumer_B_msgs = consumer_results.get("consumer_B", [])
849+
consumer_C_msgs = consumer_results.get("consumer_C", [])
850+
851+
# All consumers should have finished without errors
852+
self.assertIsInstance(consumer_A_msgs, list, f"Consumer A failed: {consumer_A_msgs}")
853+
self.assertIsInstance(consumer_B_msgs, list, f"Consumer B failed: {consumer_B_msgs}")
854+
self.assertIsInstance(consumer_C_msgs, list, f"Consumer C failed: {consumer_C_msgs}")
855+
856+
# Each consumer should get exactly what they requested
857+
self.assertEqual(len(consumer_A_msgs), 3, f"Consumer A got {len(consumer_A_msgs)} messages, expected 3")
858+
self.assertEqual(len(consumer_B_msgs), 5, f"Consumer B got {len(consumer_B_msgs)} messages, expected 5")
859+
self.assertEqual(len(consumer_C_msgs), 4, f"Consumer C got {len(consumer_C_msgs)} messages, expected 4")
860+
861+
# All messages should be unique (no duplicates across consumers)
862+
all_received = consumer_A_msgs + consumer_B_msgs + consumer_C_msgs
863+
self.assertEqual(
864+
len(all_received),
865+
len(set(all_received)),
866+
f"Found duplicate messages: {[msg for msg in all_received if all_received.count(msg) > 1]}",
867+
)
868+
869+
# All received messages should be from our published set
870+
expected_msgs = {f"msg-{i}" for i in range(num_msgs)}
871+
received_msgs = set(all_received)
872+
self.assertTrue(received_msgs.issubset(expected_msgs))
873+
874+
# Total should be exactly 12 messages consumed
875+
self.assertEqual(len(received_msgs), 12)
876+
877+
await nc.close()
878+
879+
@async_test
880+
async def test_subscribe_concurrent_next_msg_with_unsubscribe_limit(self):
881+
"""Test concurrent next_msg() calls with unsubscribe limit"""
882+
nc = NATS()
883+
await nc.connect()
884+
885+
sub = await nc.subscribe("test.concurrent.next.limit")
886+
await sub.unsubscribe(limit=8) # Auto-unsubscribe after 8 messages
887+
888+
# Publish more messages than the limit
889+
num_msgs = 15
890+
for i in range(num_msgs):
891+
await nc.publish("test.concurrent.next.limit", f"msg-{i}".encode())
892+
await nc.flush()
893+
894+
# Track results from concurrent next_msg calls
895+
consumer_results = {}
896+
897+
async def consumer_task(consumer_id: str, max_attempts: int):
898+
"""Consumer that keeps calling next_msg until timeout or limit reached"""
899+
import random
900+
901+
received = []
902+
try:
903+
for attempt in range(max_attempts):
904+
try:
905+
msg = await sub.next_msg(timeout=0.5)
906+
received.append(msg.data.decode())
907+
# Add random processing delay
908+
await asyncio.sleep(random.uniform(0.005, 0.02))
909+
except Exception as e:
910+
# Expected when subscription reaches limit
911+
break
912+
except Exception as e:
913+
consumer_results[consumer_id] = f"Error: {e}"
914+
return
915+
consumer_results[consumer_id] = received
916+
917+
# Start multiple concurrent consumers
918+
tasks = [
919+
asyncio.create_task(consumer_task("consumer_A", 10)),
920+
asyncio.create_task(consumer_task("consumer_B", 10)),
921+
asyncio.create_task(consumer_task("consumer_C", 10)),
922+
]
923+
924+
# Wait for all consumers to finish
925+
await asyncio.gather(*tasks)
926+
927+
# Verify results
928+
consumer_A_msgs = consumer_results.get("consumer_A", [])
929+
consumer_B_msgs = consumer_results.get("consumer_B", [])
930+
consumer_C_msgs = consumer_results.get("consumer_C", [])
931+
932+
# All consumers should have finished without errors
933+
self.assertIsInstance(consumer_A_msgs, list, f"Consumer A failed: {consumer_A_msgs}")
934+
self.assertIsInstance(consumer_B_msgs, list, f"Consumer B failed: {consumer_B_msgs}")
935+
self.assertIsInstance(consumer_C_msgs, list, f"Consumer C failed: {consumer_C_msgs}")
936+
937+
# Total messages across all consumers should be exactly 8 (the unsubscribe limit)
938+
all_received = consumer_A_msgs + consumer_B_msgs + consumer_C_msgs
939+
self.assertEqual(len(all_received), 8, f"Expected 8 total messages, got {len(all_received)}: {all_received}")
940+
941+
# All messages should be unique (no duplicates)
942+
self.assertEqual(
943+
len(all_received),
944+
len(set(all_received)),
945+
f"Found duplicate messages: {[msg for msg in all_received if all_received.count(msg) > 1]}",
946+
)
947+
948+
# All received messages should be from our published set
949+
expected_msgs = {f"msg-{i}" for i in range(num_msgs)}
950+
received_msgs = set(all_received)
951+
self.assertTrue(received_msgs.issubset(expected_msgs))
952+
953+
# Verify subscription reached its limit
954+
self.assertEqual(sub._received, 8)
955+
self.assertEqual(sub._max_msgs, 8)
956+
957+
await nc.close()
958+
959+
@async_test
960+
async def test_subscribe_concurrent_next_msg_with_timeout(self):
961+
"""Test concurrent next_msg() calls with different timeout behaviors"""
962+
nc = NATS()
963+
await nc.connect()
964+
965+
sub = await nc.subscribe("test.concurrent.next.timeout")
966+
967+
# Publish only a few messages (less than what consumers will request)
968+
num_msgs = 3
969+
for i in range(num_msgs):
970+
await nc.publish("test.concurrent.next.timeout", f"msg-{i}".encode())
971+
await nc.flush()
972+
973+
# Track results and timing
974+
consumer_results = {}
975+
976+
async def consumer_task(consumer_id: str, requests: int, timeout: float):
977+
"""Consumer that requests more messages than available"""
978+
import time
979+
980+
received = []
981+
timeouts = 0
982+
start_time = time.time()
983+
984+
try:
985+
for _ in range(requests):
986+
try:
987+
msg = await sub.next_msg(timeout=timeout)
988+
received.append(msg.data.decode())
989+
except Exception as e:
990+
if "timeout" in str(e).lower():
991+
timeouts += 1
992+
else:
993+
break
994+
995+
end_time = time.time()
996+
consumer_results[consumer_id] = {
997+
"received": received,
998+
"timeouts": timeouts,
999+
"duration": end_time - start_time,
1000+
}
1001+
except Exception as e:
1002+
consumer_results[consumer_id] = f"Error: {e}"
1003+
1004+
# Start consumers with different timeout strategies
1005+
tasks = [
1006+
asyncio.create_task(consumer_task("fast_timeout", 5, 0.1)), # Fast timeout
1007+
asyncio.create_task(consumer_task("medium_timeout", 5, 0.3)), # Medium timeout
1008+
asyncio.create_task(consumer_task("slow_timeout", 5, 0.5)), # Slow timeout
1009+
]
1010+
1011+
# Wait for all consumers to finish
1012+
await asyncio.gather(*tasks)
1013+
1014+
# Verify results - collect all data first
1015+
all_received = []
1016+
total_timeouts = 0
1017+
consumers_with_msgs = 0
1018+
1019+
for consumer_id, result in consumer_results.items():
1020+
self.assertIsInstance(result, dict, f"Consumer {consumer_id} failed: {result}")
1021+
1022+
received = result["received"]
1023+
timeouts = result["timeouts"]
1024+
1025+
all_received.extend(received)
1026+
total_timeouts += timeouts
1027+
1028+
if len(received) > 0:
1029+
consumers_with_msgs += 1
1030+
1031+
# With only 3 messages and 3 consumers requesting 5 each, how the messages are split across consumers is not deterministic.
1032+
# The key requirement is that all 3 messages are consumed.
1033+
self.assertEqual(len(set(all_received)), 3, f"Expected 3 unique messages, got {set(all_received)}")
1034+
1035+
# There should be timeouts since we're requesting more messages than available
1036+
self.assertGreater(total_timeouts, 0, "Should have some timeouts when requesting more messages than available")
1037+
1038+
# At least one consumer should get messages (but due to race conditions, not necessarily all)
1039+
self.assertGreater(consumers_with_msgs, 0, "At least one consumer should receive messages")
1040+
1041+
await nc.close()
1042+
8031043
@async_test
8041044
async def test_subscribe_iterate_unsub_comprehension(self):
8051045
nc = NATS()

0 commit comments

Comments
 (0)