@@ -2,6 +2,7 @@
 
 import json
 import logging
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
 
 import boto3
@@ -50,6 +51,7 @@ def __init__(
         boto_session: Optional[boto3.Session] = None,
         boto_client_config: Optional[BotocoreConfig] = None,
         region_name: Optional[str] = None,
+        max_parallel_reads: int = 1,
         **kwargs: Any,
     ):
         """Initialize S3SessionManager with S3 storage.
@@ -62,11 +64,20 @@ def __init__(
             boto_session: Optional boto3 session
             boto_client_config: Optional boto3 client configuration
             region_name: AWS region for S3 storage
+            max_parallel_reads: Maximum number of parallel S3 read operations for list_messages().
+                Defaults to 1 (sequential) for backward compatibility and safety.
+                Set to a higher value (e.g., 10) for better performance with many messages.
+                Can be overridden per-call via list_messages() kwargs.
             **kwargs: Additional keyword arguments for future extensibility.
         """
         self.bucket = bucket
         self.prefix = prefix
 
+        # Validate max_parallel_reads
+        if not isinstance(max_parallel_reads, int) or max_parallel_reads < 1:
+            raise ValueError(f"max_parallel_reads must be a positive integer, got {max_parallel_reads}")
+        self.max_parallel_reads = max_parallel_reads
+
         session = boto_session or boto3.Session(region_name=region_name)
 
         # Add strands-agents to the request user agent
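
As a quick illustration of the new constructor argument, a minimal usage sketch follows. The import path and any constructor arguments not shown in this diff are assumptions; adjust them to the actual S3SessionManager API.

# Usage sketch only; the import path and omitted required arguments are assumptions, not part of this diff.
from strands.session.s3_session_manager import S3SessionManager  # assumed module path

# Allow up to 10 concurrent S3 reads in list_messages(); the default of 1 keeps reads sequential.
manager = S3SessionManager(bucket="my-session-bucket", prefix="sessions/", max_parallel_reads=10)

# Non-positive or non-integer values are rejected at construction time.
try:
    S3SessionManager(bucket="my-session-bucket", max_parallel_reads=0)
except ValueError as exc:
    print(exc)  # max_parallel_reads must be a positive integer, got 0
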
@@ -259,7 +270,24 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio
     def list_messages(
         self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any
     ) -> List[SessionMessage]:
-        """List messages for an agent with pagination from S3."""
+        """List messages for an agent with pagination from S3.
+
+        Args:
+            session_id: ID of the session
+            agent_id: ID of the agent
+            limit: Optional limit on number of messages to return
+            offset: Optional offset for pagination
+            **kwargs: Additional keyword arguments. Supports:
+                max_parallel_reads: Override the instance-level max_parallel_reads setting
+                    for this call only.
+
+        Returns:
+            List of SessionMessage objects, sorted by message_id.
+
+        Raises:
+            ValueError: If the max_parallel_reads override is not a positive integer.
+            SessionException: If an S3 error occurs during message retrieval.
+        """
         messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/"
         try:
             paginator = self.client.get_paginator("list_objects_v2")
@@ -287,10 +315,51 @@ def list_messages(
             else:
                 message_keys = message_keys[offset:]
 
-            # Load only the required message objects
+            # Load message objects in parallel for better performance
             messages: List[SessionMessage] = []
-            for key in message_keys:
-                message_data = self._read_s3_object(key)
+            if not message_keys:
+                return messages
+
+            # Use ThreadPoolExecutor to fetch messages concurrently
+            # Allow per-call override of max_parallel_reads via kwargs, otherwise use instance default
+            max_parallel_reads_override = kwargs.get("max_parallel_reads")
+            if max_parallel_reads_override is not None:
+                if not isinstance(max_parallel_reads_override, int) or max_parallel_reads_override < 1:
+                    raise ValueError(
+                        f"max_parallel_reads must be a positive integer, got {max_parallel_reads_override}"
+                    )
+                max_parallel_reads_value = max_parallel_reads_override
+            else:
+                # Instance default was already validated in __init__, no need to check again
+                max_parallel_reads_value = self.max_parallel_reads
+
+            max_workers = min(max_parallel_reads_value, len(message_keys))
+
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                # Submit all read tasks
+                future_to_key = {executor.submit(self._read_s3_object, key): key for key in message_keys}
+
+                # Create a mapping from key to index to maintain order
+                key_to_index = {key: idx for idx, key in enumerate(message_keys)}
+
+                # Initialize results list with None placeholders to maintain order
+                results: List[Optional[Dict[str, Any]]] = [None] * len(message_keys)
+
+                # Process results as they complete
+                for future in as_completed(future_to_key):
+                    key = future_to_key[future]
+                    try:
+                        message_data = future.result()
+                        # Store result at the correct index to maintain order
+                        results[key_to_index[key]] = message_data
+                    except Exception as e:
+                        # Log error but continue processing other messages
+                        # Individual failures shouldn't stop the entire operation
+                        logger.warning("key=<%s> | failed to read message from s3", key, exc_info=e)
+
+            # Convert results to SessionMessage objects, filtering out None values
+            # If SessionMessage.from_dict fails, let it propagate - data corruption should be visible
+            for message_data in results:
                 if message_data:
                     messages.append(SessionMessage.from_dict(message_data))
 
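
To round out the picture, a sketch of the per-call override that list_messages() now accepts through kwargs; the manager instance and the IDs below are placeholders.

# Placeholder identifiers; assumes "manager" is an S3SessionManager as sketched above.
# Uses the instance-level default set in __init__.
history = manager.list_messages("session-1", "agent-1", limit=50)

# Raise concurrency for this call only, e.g. when hydrating a long conversation history.
history = manager.list_messages("session-1", "agent-1", max_parallel_reads=20)

# An invalid override fails fast with ValueError before any S3 reads are issued.
try:
    manager.list_messages("session-1", "agent-1", max_parallel_reads=-1)
except ValueError as exc:
    print(exc)

Note that, per this diff, a failed read of an individual message object is only logged, so the returned list can silently omit messages whose S3 objects could not be fetched.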