Skip to content

Commit e96ca3b

Browse files
committed
Add parallel reading support to S3SessionManager.list_messages()
- Add max_parallel_reads parameter to S3SessionManager.__init__()
- Implement parallel S3 reads using ThreadPoolExecutor in list_messages()
- Support both instance-level and per-call max_parallel_reads configuration
- Add comprehensive tests for parallel reading functionality
- Maintain backward compatibility (default max_parallel_reads=1)
1 parent cee5145 commit e96ca3b

File tree

2 files changed

+347
-5
lines changed

2 files changed

+347
-5
lines changed

src/strands/session/s3_session_manager.py

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
import logging
5+
from concurrent.futures import ThreadPoolExecutor, as_completed
56
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
67

78
import boto3
@@ -50,6 +51,7 @@ def __init__(
5051
boto_session: Optional[boto3.Session] = None,
5152
boto_client_config: Optional[BotocoreConfig] = None,
5253
region_name: Optional[str] = None,
54+
max_parallel_reads: int = 1,
5355
**kwargs: Any,
5456
):
5557
"""Initialize S3SessionManager with S3 storage.
@@ -62,11 +64,20 @@ def __init__(
6264
boto_session: Optional boto3 session
6365
boto_client_config: Optional boto3 client configuration
6466
region_name: AWS region for S3 storage
67+
max_parallel_reads: Maximum number of parallel S3 read operations for list_messages().
68+
Defaults to 1 (sequential) for backward compatibility and safety.
69+
Set to a higher value (e.g., 10) for better performance with many messages.
70+
Can be overridden per-call via list_messages() kwargs.
6571
**kwargs: Additional keyword arguments for future extensibility.
6672
"""
6773
self.bucket = bucket
6874
self.prefix = prefix
6975

76+
# Validate max_parallel_reads
77+
if not isinstance(max_parallel_reads, int) or max_parallel_reads < 1:
78+
raise ValueError(f"max_parallel_reads must be a positive integer, got {max_parallel_reads}")
79+
self.max_parallel_reads = max_parallel_reads
80+
7081
session = boto_session or boto3.Session(region_name=region_name)
7182

7283
# Add strands-agents to the request user agent
@@ -259,7 +270,24 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio
259270
def list_messages(
260271
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any
261272
) -> List[SessionMessage]:
262-
"""List messages for an agent with pagination from S3."""
273+
"""List messages for an agent with pagination from S3.
274+
275+
Args:
276+
session_id: ID of the session
277+
agent_id: ID of the agent
278+
limit: Optional limit on number of messages to return
279+
offset: Optional offset for pagination
280+
**kwargs: Additional keyword arguments. Supports:
281+
max_parallel_reads: Override the instance-level max_parallel_reads setting
282+
for this call only.
283+
284+
Returns:
285+
List of SessionMessage objects, sorted by message_id.
286+
287+
Raises:
288+
ValueError: If max_parallel_reads override is not a positive integer.
289+
SessionException: If S3 error occurs during message retrieval.
290+
"""
263291
messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/"
264292
try:
265293
paginator = self.client.get_paginator("list_objects_v2")
@@ -287,10 +315,51 @@ def list_messages(
287315
else:
288316
message_keys = message_keys[offset:]
289317

290-
# Load only the required message objects
318+
# Load message objects in parallel for better performance
291319
messages: List[SessionMessage] = []
292-
for key in message_keys:
293-
message_data = self._read_s3_object(key)
320+
if not message_keys:
321+
return messages
322+
323+
# Use ThreadPoolExecutor to fetch messages concurrently
324+
# Allow per-call override of max_parallel_reads via kwargs, otherwise use instance default
325+
max_parallel_reads_override = kwargs.get("max_parallel_reads")
326+
if max_parallel_reads_override is not None:
327+
if not isinstance(max_parallel_reads_override, int) or max_parallel_reads_override < 1:
328+
raise ValueError(
329+
f"max_parallel_reads must be a positive integer, got {max_parallel_reads_override}"
330+
)
331+
max_parallel_reads_value = max_parallel_reads_override
332+
else:
333+
# Instance default was already validated in __init__, no need to check again
334+
max_parallel_reads_value = self.max_parallel_reads
335+
336+
max_workers = min(max_parallel_reads_value, len(message_keys))
337+
338+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
339+
# Submit all read tasks
340+
future_to_key = {executor.submit(self._read_s3_object, key): key for key in message_keys}
341+
342+
# Create a mapping from key to index to maintain order
343+
key_to_index = {key: idx for idx, key in enumerate(message_keys)}
344+
345+
# Initialize results list with None placeholders to maintain order
346+
results: List[Optional[Dict[str, Any]]] = [None] * len(message_keys)
347+
348+
# Process results as they complete
349+
for future in as_completed(future_to_key):
350+
key = future_to_key[future]
351+
try:
352+
message_data = future.result()
353+
# Store result at the correct index to maintain order
354+
results[key_to_index[key]] = message_data
355+
except Exception as e:
356+
# Log error but continue processing other messages
357+
# Individual failures shouldn't stop the entire operation
358+
logger.warning("key=<%s> | failed to read message from s3", key, exc_info=e)
359+
360+
# Convert results to SessionMessage objects, filtering out None values
361+
# If SessionMessage.from_dict fails, let it propagate - data corruption should be visible
362+
for message_data in results:
294363
if message_data:
295364
messages.append(SessionMessage.from_dict(message_data))
296365

0 commit comments

Comments
 (0)