@@ -800,6 +800,246 @@ async def test_subscribe_async_generator_with_drain(self):
800800
801801 await nc .close ()
802802
803+ @async_test
804+ async def test_subscribe_concurrent_next_msg (self ):
805+ """Test multiple concurrent next_msg() calls on the same subscription"""
806+ nc = NATS ()
807+ await nc .connect ()
808+
809+ sub = await nc .subscribe ("test.concurrent.next" )
810+
811+ # Publish messages
812+ num_msgs = 12
813+ for i in range (num_msgs ):
814+ await nc .publish ("test.concurrent.next" , f"msg-{ i } " .encode ())
815+ await nc .flush ()
816+
817+ # Track results from concurrent next_msg calls
818+ consumer_results = {}
819+
820+ async def consumer_task (consumer_id : str , msg_count : int ):
821+ """Consumer task that uses next_msg() to get messages"""
822+ import random
823+
824+ received = []
825+ try :
826+ for _ in range (msg_count ):
827+ msg = await sub .next_msg (timeout = 2.0 )
828+ received .append (msg .data .decode ())
829+ # Add random processing delay
830+ await asyncio .sleep (random .uniform (0.01 , 0.03 ))
831+ except Exception as e :
832+ consumer_results [consumer_id ] = f"Error: { e } "
833+ return
834+ consumer_results [consumer_id ] = received
835+
836+ # Start multiple concurrent consumers using next_msg()
837+ tasks = [
838+ asyncio .create_task (consumer_task ("consumer_A" , 3 )),
839+ asyncio .create_task (consumer_task ("consumer_B" , 5 )),
840+ asyncio .create_task (consumer_task ("consumer_C" , 4 )),
841+ ]
842+
843+ # Wait for all consumers to finish
844+ await asyncio .gather (* tasks )
845+
846+ # Verify results
847+ consumer_A_msgs = consumer_results .get ("consumer_A" , [])
848+ consumer_B_msgs = consumer_results .get ("consumer_B" , [])
849+ consumer_C_msgs = consumer_results .get ("consumer_C" , [])
850+
851+ # All consumers should have finished without errors
852+ self .assertIsInstance (consumer_A_msgs , list , f"Consumer A failed: { consumer_A_msgs } " )
853+ self .assertIsInstance (consumer_B_msgs , list , f"Consumer B failed: { consumer_B_msgs } " )
854+ self .assertIsInstance (consumer_C_msgs , list , f"Consumer C failed: { consumer_C_msgs } " )
855+
856+ # Each consumer should get exactly what they requested
857+ self .assertEqual (len (consumer_A_msgs ), 3 , f"Consumer A got { len (consumer_A_msgs )} messages, expected 3" )
858+ self .assertEqual (len (consumer_B_msgs ), 5 , f"Consumer B got { len (consumer_B_msgs )} messages, expected 5" )
859+ self .assertEqual (len (consumer_C_msgs ), 4 , f"Consumer C got { len (consumer_C_msgs )} messages, expected 4" )
860+
861+ # All messages should be unique (no duplicates across consumers)
862+ all_received = consumer_A_msgs + consumer_B_msgs + consumer_C_msgs
863+ self .assertEqual (
864+ len (all_received ),
865+ len (set (all_received )),
866+ f"Found duplicate messages: { [msg for msg in all_received if all_received .count (msg ) > 1 ]} " ,
867+ )
868+
869+ # All received messages should be from our published set
870+ expected_msgs = {f"msg-{ i } " for i in range (num_msgs )}
871+ received_msgs = set (all_received )
872+ self .assertTrue (received_msgs .issubset (expected_msgs ))
873+
874+ # Total should be exactly 12 messages consumed
875+ self .assertEqual (len (received_msgs ), 12 )
876+
877+ await nc .close ()
878+
879+ @async_test
880+ async def test_subscribe_concurrent_next_msg_with_unsubscribe_limit (self ):
881+ """Test concurrent next_msg() calls with unsubscribe limit"""
882+ nc = NATS ()
883+ await nc .connect ()
884+
885+ sub = await nc .subscribe ("test.concurrent.next.limit" )
886+ await sub .unsubscribe (limit = 8 ) # Auto-unsubscribe after 8 messages
887+
888+ # Publish more messages than the limit
889+ num_msgs = 15
890+ for i in range (num_msgs ):
891+ await nc .publish ("test.concurrent.next.limit" , f"msg-{ i } " .encode ())
892+ await nc .flush ()
893+
894+ # Track results from concurrent next_msg calls
895+ consumer_results = {}
896+
897+ async def consumer_task (consumer_id : str , max_attempts : int ):
898+ """Consumer that keeps calling next_msg until timeout or limit reached"""
899+ import random
900+
901+ received = []
902+ try :
903+ for attempt in range (max_attempts ):
904+ try :
905+ msg = await sub .next_msg (timeout = 0.5 )
906+ received .append (msg .data .decode ())
907+ # Add random processing delay
908+ await asyncio .sleep (random .uniform (0.005 , 0.02 ))
909+ except Exception as e :
910+ # Expected when subscription reaches limit
911+ break
912+ except Exception as e :
913+ consumer_results [consumer_id ] = f"Error: { e } "
914+ return
915+ consumer_results [consumer_id ] = received
916+
917+ # Start multiple concurrent consumers
918+ tasks = [
919+ asyncio .create_task (consumer_task ("consumer_A" , 10 )),
920+ asyncio .create_task (consumer_task ("consumer_B" , 10 )),
921+ asyncio .create_task (consumer_task ("consumer_C" , 10 )),
922+ ]
923+
924+ # Wait for all consumers to finish
925+ await asyncio .gather (* tasks )
926+
927+ # Verify results
928+ consumer_A_msgs = consumer_results .get ("consumer_A" , [])
929+ consumer_B_msgs = consumer_results .get ("consumer_B" , [])
930+ consumer_C_msgs = consumer_results .get ("consumer_C" , [])
931+
932+ # All consumers should have finished without errors
933+ self .assertIsInstance (consumer_A_msgs , list , f"Consumer A failed: { consumer_A_msgs } " )
934+ self .assertIsInstance (consumer_B_msgs , list , f"Consumer B failed: { consumer_B_msgs } " )
935+ self .assertIsInstance (consumer_C_msgs , list , f"Consumer C failed: { consumer_C_msgs } " )
936+
937+ # Total messages across all consumers should be exactly 8 (the unsubscribe limit)
938+ all_received = consumer_A_msgs + consumer_B_msgs + consumer_C_msgs
939+ self .assertEqual (len (all_received ), 8 , f"Expected 8 total messages, got { len (all_received )} : { all_received } " )
940+
941+ # All messages should be unique (no duplicates)
942+ self .assertEqual (
943+ len (all_received ),
944+ len (set (all_received )),
945+ f"Found duplicate messages: { [msg for msg in all_received if all_received .count (msg ) > 1 ]} " ,
946+ )
947+
948+ # All received messages should be from our published set
949+ expected_msgs = {f"msg-{ i } " for i in range (num_msgs )}
950+ received_msgs = set (all_received )
951+ self .assertTrue (received_msgs .issubset (expected_msgs ))
952+
953+ # Verify subscription reached its limit
954+ self .assertEqual (sub ._received , 8 )
955+ self .assertEqual (sub ._max_msgs , 8 )
956+
957+ await nc .close ()
958+
959+ @async_test
960+ async def test_subscribe_concurrent_next_msg_with_timeout (self ):
961+ """Test concurrent next_msg() calls with different timeout behaviors"""
962+ nc = NATS ()
963+ await nc .connect ()
964+
965+ sub = await nc .subscribe ("test.concurrent.next.timeout" )
966+
967+ # Publish only a few messages (less than what consumers will request)
968+ num_msgs = 3
969+ for i in range (num_msgs ):
970+ await nc .publish ("test.concurrent.next.timeout" , f"msg-{ i } " .encode ())
971+ await nc .flush ()
972+
973+ # Track results and timing
974+ consumer_results = {}
975+
976+ async def consumer_task (consumer_id : str , requests : int , timeout : float ):
977+ """Consumer that requests more messages than available"""
978+ import time
979+
980+ received = []
981+ timeouts = 0
982+ start_time = time .time ()
983+
984+ try :
985+ for _ in range (requests ):
986+ try :
987+ msg = await sub .next_msg (timeout = timeout )
988+ received .append (msg .data .decode ())
989+ except Exception as e :
990+ if "timeout" in str (e ).lower ():
991+ timeouts += 1
992+ else :
993+ break
994+
995+ end_time = time .time ()
996+ consumer_results [consumer_id ] = {
997+ "received" : received ,
998+ "timeouts" : timeouts ,
999+ "duration" : end_time - start_time ,
1000+ }
1001+ except Exception as e :
1002+ consumer_results [consumer_id ] = f"Error: { e } "
1003+
1004+ # Start consumers with different timeout strategies
1005+ tasks = [
1006+ asyncio .create_task (consumer_task ("fast_timeout" , 5 , 0.1 )), # Fast timeout
1007+ asyncio .create_task (consumer_task ("medium_timeout" , 5 , 0.3 )), # Medium timeout
1008+ asyncio .create_task (consumer_task ("slow_timeout" , 5 , 0.5 )), # Slow timeout
1009+ ]
1010+
1011+ # Wait for all consumers to finish
1012+ await asyncio .gather (* tasks )
1013+
1014+ # Verify results - collect all data first
1015+ all_received = []
1016+ total_timeouts = 0
1017+ consumers_with_msgs = 0
1018+
1019+ for consumer_id , result in consumer_results .items ():
1020+ self .assertIsInstance (result , dict , f"Consumer { consumer_id } failed: { result } " )
1021+
1022+ received = result ["received" ]
1023+ timeouts = result ["timeouts" ]
1024+
1025+ all_received .extend (received )
1026+ total_timeouts += timeouts
1027+
1028+ if len (received ) > 0 :
1029+ consumers_with_msgs += 1
1030+
1031+ # With only 3 messages and 3 consumers requesting 5 each, some distribution is expected
1032+ # But the key thing is that all 3 messages should be consumed
1033+ self .assertEqual (len (set (all_received )), 3 , f"Expected 3 unique messages, got { set (all_received )} " )
1034+
1035+ # There should be timeouts since we're requesting more messages than available
1036+ self .assertGreater (total_timeouts , 0 , "Should have some timeouts when requesting more messages than available" )
1037+
1038+ # At least one consumer should get messages (but due to race conditions, not necessarily all)
1039+ self .assertGreater (consumers_with_msgs , 0 , "At least one consumer should receive messages" )
1040+
1041+ await nc .close ()
1042+
8031043 @async_test
8041044 async def test_subscribe_iterate_unsub_comprehension (self ):
8051045 nc = NATS ()
0 commit comments