Skip to content

Commit 752c6cb

Browse files
committed
fix: context sharing when branching
1 parent 742e9fe commit 752c6cb

File tree

6 files changed

+36
-38
lines changed

6 files changed

+36
-38
lines changed

laygo/pipeline.py

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212

1313
from laygo.context import IContextManager
1414
from laygo.context.parallel import ParallelContextManager
15+
from laygo.context.types import IContextHandle
1516
from laygo.helpers import is_context_aware
1617
from laygo.transformers.transformer import Transformer
17-
from laygo.transformers.transformer import passthrough_chunks
1818

1919
T = TypeVar("T")
2020
U = TypeVar("U")
@@ -62,7 +62,7 @@ def __init__(self, *data: Iterable[T], context_manager: IContextManager | None =
6262
self.processed_data: Iterator = iter(self.data_source)
6363

6464
# Rule 1: Pipeline creates a simple context manager by default.
65-
self.context_manager = context_manager or ParallelContextManager()
65+
self.context_manager = context_manager if context_manager is not None else ParallelContextManager()
6666

6767
def __del__(self) -> None:
6868
"""Clean up the context manager when the pipeline is destroyed."""
@@ -93,16 +93,6 @@ def context(self, ctx: dict[str, Any]) -> "Pipeline[T]":
9393
self.context_manager.update(ctx)
9494
return self
9595

96-
def _sync_context_back(self) -> None:
97-
"""Synchronize the final pipeline context back to the original context reference.
98-
99-
This is called after processing is complete to update the original
100-
context with any changes made during pipeline execution.
101-
"""
102-
# This method is kept for backward compatibility but is no longer needed
103-
# since we use the context manager directly
104-
pass
105-
10696
def transform[U](self, t: Callable[[Transformer[T, T]], Transformer[T, U]]) -> "Pipeline[U]":
10797
"""Apply a transformation using a lambda function.
10898
@@ -170,7 +160,7 @@ def apply[U](
170160
match transformer:
171161
case Transformer():
172162
# Pass the pipeline's context manager to the transformer
173-
self.processed_data = transformer(self.processed_data, self.context_manager) # type: ignore
163+
self.processed_data = transformer(self.processed_data, context=self.context_manager) # type: ignore
174164
case _ if callable(transformer):
175165
if is_context_aware(transformer):
176166
self.processed_data = transformer(self.processed_data, self.context_manager) # type: ignore
@@ -318,14 +308,13 @@ def branch(
318308
branches: dict[str, Transformer[T, Any]],
319309
batch_size: int = 1000,
320310
max_batch_buffer: int = 1,
321-
use_queue_chunks: bool = True,
322311
) -> tuple[dict[str, list[Any]], dict[str, Any]]:
323312
"""Forks the pipeline into multiple branches for concurrent, parallel processing.
324313
325314
This is a **terminal operation** that implements a fan-out pattern where
326315
the entire dataset is copied to each branch for independent processing.
327-
Each branch processes the complete dataset concurrently using separate
328-
transformers, and results are collected and returned in a dictionary.
316+
Each branch gets its own Pipeline instance with isolated context management,
317+
and results are collected and returned in a dictionary.
329318
330319
Args:
331320
branches: A dictionary where keys are branch names (str) and values
@@ -336,13 +325,12 @@ def branch(
336325
max_batch_buffer: The maximum number of batches to buffer for each
337326
branch queue. Controls memory usage and creates
338327
backpressure. Defaults to 1.
339-
use_queue_chunks: Whether to use passthrough chunking for the
340-
transformers. When True, batches are processed
341-
as chunks. Defaults to True.
342328
343329
Returns:
344-
A dictionary where keys are the branch names and values are lists
345-
of all items processed by that branch's transformer.
330+
A tuple containing:
331+
- A dictionary where keys are the branch names and values are lists
332+
of all items processed by that branch's transformer.
333+
- A merged dictionary of all context values from all branches.
346334
347335
Note:
348336
This operation consumes the pipeline's iterator, making subsequent
@@ -372,32 +360,41 @@ def producer() -> None:
372360
for q in queues:
373361
q.put(None)
374362

375-
def consumer(transformer: Transformer, queue: Queue) -> list[Any]:
376-
"""Consumes batches from a queue and runs them through a transformer."""
363+
def consumer(
364+
transformer: Transformer, queue: Queue, context_handle: IContextHandle
365+
) -> tuple[list[Any], dict[str, Any]]:
366+
"""Consumes batches from a queue and processes them through a dedicated pipeline."""
377367

378368
def stream_from_queue() -> Iterator[T]:
379369
while (batch := queue.get()) is not None:
380-
yield batch
370+
yield from batch
371+
372+
# Create a new pipeline for this branch but share the parent's context manager
373+
# This ensures all branches share the same context
374+
branch_pipeline = Pipeline(stream_from_queue(), context_manager=context_handle.create_proxy()) # type: ignore
381375

382-
if use_queue_chunks:
383-
transformer = transformer.set_chunker(passthrough_chunks)
376+
# Apply the transformer to the branch pipeline and get results
377+
result_list, branch_context = branch_pipeline.apply(transformer).to_list()
384378

385-
result_iterator = transformer(stream_from_queue(), self.context_manager) # type: ignore
386-
return list(result_iterator)
379+
return result_list, branch_context
387380

388381
with ThreadPoolExecutor(max_workers=num_branches + 1) as executor:
389382
executor.submit(producer)
390383

391384
future_to_name = {
392-
executor.submit(consumer, transformer, queues[i]): name for i, (name, transformer) in enumerate(branch_items)
385+
executor.submit(consumer, transformer, queues[i], self.context_manager.get_handle()): name
386+
for i, (name, transformer) in enumerate(branch_items)
393387
}
394388

389+
# Collect results - context is shared through the same context manager
395390
for future in as_completed(future_to_name):
396391
name = future_to_name[future]
397392
try:
398-
final_results[name] = future.result()
399-
except Exception as e:
400-
print(f"Branch '{name}' raised an exception: {e}")
393+
result_list, branch_context = future.result()
394+
final_results[name] = result_list
395+
except Exception:
401396
final_results[name] = []
402397

403-
return final_results, self.context_manager.to_dict()
398+
# After all threads complete, get the final context state
399+
final_context = self.context_manager.to_dict()
400+
return final_results, final_context

laygo/transformers/http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def __call__(self, data: Iterable[In], context: IContextManager | None = None) -
124124
Returns:
125125
An iterator over the processed data.
126126
"""
127-
run_context = context or self._default_context
127+
run_context = self._default_context
128128

129129
self._finalize_config()
130130

laygo/transformers/parallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def from_transformer[T, U](
124124

125125
def __call__(self, data: Iterable[In], context: IContextManager | None = None) -> Iterator[Out]:
126126
"""Execute the transformer by distributing chunks to a process pool."""
127-
run_context = context or self._default_context
127+
run_context = context if context is not None else self._default_context
128128

129129
# Get the picklable handle from the context manager.
130130
context_handle = run_context.get_handle()

laygo/transformers/threaded.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def __call__(self, data: Iterable[In], context: IContextManager | None = None) -
119119
Returns:
120120
An iterator over the transformed data.
121121
"""
122-
run_context = context or self._default_context
122+
run_context = context if context is not None else self._default_context
123123

124124
# Since threads share memory, we can pass the context manager directly.
125125
# No handle/proxy mechanism is needed, but the locking inside

laygo/transformers/transformer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,8 +345,9 @@ def __call__(self, data: Iterable[In], context: IContextManager | None = None) -
345345
Returns:
346346
An iterator over the transformed data.
347347
"""
348+
348349
# Use the provided context by reference, or default to a simple context.
349-
run_context = context or self._default_context
350+
run_context = context if context is not None else self._default_context
350351

351352
try:
352353
for chunk in self._chunk_generator(data):

tests/test_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,5 +498,5 @@ def context_modifier_b(chunk: list[int], ctx: IContextManager) -> list[int]:
498498
assert result["branch_b"] == [3, 6, 9]
499499

500500
# Context values should reflect the actual chunk sizes processed
501-
assert context["branch_a_processed"] == 3
502501
assert context["branch_b_processed"] == 3
502+
assert context["branch_a_processed"] == 3

0 commit comments

Comments (0)