 
 from laygo.context import IContextManager
 from laygo.context.parallel import ParallelContextManager
+from laygo.context.types import IContextHandle
 from laygo.helpers import is_context_aware
 from laygo.transformers.transformer import Transformer
-from laygo.transformers.transformer import passthrough_chunks
 
 T = TypeVar("T")
 U = TypeVar("U")
@@ -62,7 +62,7 @@ def __init__(self, *data: Iterable[T], context_manager: IContextManager | None =
         self.processed_data: Iterator = iter(self.data_source)
 
         # Rule 1: Pipeline creates a simple context manager by default.
-        self.context_manager = context_manager or ParallelContextManager()
+        self.context_manager = context_manager if context_manager is not None else ParallelContextManager()
 
     def __del__(self) -> None:
         """Clean up the context manager when the pipeline is destroyed."""
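A minimal usage sketch, not taken from the commit: the switch from "or" to an explicit is-not-None check means an explicitly passed manager is always kept, even if it happens to evaluate as falsy, and only an omitted argument falls back to a fresh ParallelContextManager. The Pipeline import path below is assumed for illustration.

from laygo.context.parallel import ParallelContextManager
from laygo.pipeline import Pipeline  # assumed module path, for illustration only

p_default = Pipeline([1, 2, 3])  # no manager given: falls back to ParallelContextManager()
manager = ParallelContextManager()
p_explicit = Pipeline([1, 2, 3], context_manager=manager)  # kept as-is, even if the manager reports itself empty/falsy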
@@ -93,16 +93,6 @@ def context(self, ctx: dict[str, Any]) -> "Pipeline[T]":
         self.context_manager.update(ctx)
         return self
 
-    def _sync_context_back(self) -> None:
-        """Synchronize the final pipeline context back to the original context reference.
-
-        This is called after processing is complete to update the original
-        context with any changes made during pipeline execution.
-        """
-        # This method is kept for backward compatibility but is no longer needed
-        # since we use the context manager directly
-        pass
-
     def transform[U](self, t: Callable[[Transformer[T, T]], Transformer[T, U]]) -> "Pipeline[U]":
         """Apply a transformation using a lambda function.
 
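For orientation, a short sketch of the chainable context seeding that the context() method above provides; nothing beyond update-and-return-self is taken from the diff, and the Pipeline import path is assumed.

from laygo.pipeline import Pipeline  # assumed module path, for illustration only

pipeline = Pipeline(range(10)).context({"multiplier": 2, "run_id": "demo"})
# context() forwards the dict to the pipeline's context manager via update()
# and returns self, so the call can be chained straight into transform()/apply().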
@@ -170,7 +160,7 @@ def apply[U](
         match transformer:
             case Transformer():
                 # Pass the pipeline's context manager to the transformer
-                self.processed_data = transformer(self.processed_data, self.context_manager)  # type: ignore
+                self.processed_data = transformer(self.processed_data, context=self.context_manager)  # type: ignore
             case _ if callable(transformer):
                 if is_context_aware(transformer):
                     self.processed_data = transformer(self.processed_data, self.context_manager)  # type: ignore
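A sketch of the dispatch above, assuming only what the hunk shows: a Transformer instance now receives the pipeline's context manager as the keyword argument context, while a plain callable receives it positionally only when is_context_aware() detects the extra parameter. The callables and import path below are illustrative.

from collections.abc import Iterator

from laygo.pipeline import Pipeline  # assumed module path, for illustration only


def doubled(rows: Iterator[int]) -> Iterator[int]:
    # Plain callable: apply() passes only the data iterator.
    return (row * 2 for row in rows)


def doubled_with_context(rows: Iterator[int], ctx) -> Iterator[int]:
    # Context-aware callable: is_context_aware() sees the second parameter,
    # so apply() also passes the pipeline's context manager as ctx.
    ctx.update({"last_stage": "doubled_with_context"})
    return (row * 2 for row in rows)


pipeline = Pipeline(range(5)).apply(doubled_with_context)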
@@ -318,14 +308,13 @@ def branch(
         branches: dict[str, Transformer[T, Any]],
         batch_size: int = 1000,
         max_batch_buffer: int = 1,
-        use_queue_chunks: bool = True,
     ) -> tuple[dict[str, list[Any]], dict[str, Any]]:
         """Forks the pipeline into multiple branches for concurrent, parallel processing.
 
         This is a **terminal operation** that implements a fan-out pattern where
         the entire dataset is copied to each branch for independent processing.
-        Each branch processes the complete dataset concurrently using separate
-        transformers, and results are collected and returned in a dictionary.
+        Each branch gets its own Pipeline instance with isolated context management,
+        and results are collected and returned in a dictionary.
 
         Args:
             branches: A dictionary where keys are branch names (str) and values
@@ -336,13 +325,12 @@ def branch(
             max_batch_buffer: The maximum number of batches to buffer for each
                               branch queue. Controls memory usage and creates
                               backpressure. Defaults to 1.
-            use_queue_chunks: Whether to use passthrough chunking for the
-                              transformers. When True, batches are processed
-                              as chunks. Defaults to True.
 
         Returns:
-            A dictionary where keys are the branch names and values are lists
-            of all items processed by that branch's transformer.
+            A tuple containing:
+            - A dictionary where keys are the branch names and values are lists
+              of all items processed by that branch's transformer.
+            - A merged dictionary of all context values from all branches.
 
         Note:
             This operation consumes the pipeline's iterator, making subsequent
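A sketch of the new return shape documented above; the data source and both transformers are placeholders, since Transformer construction is not part of this diff.

results, context = Pipeline(load_records()).branch(  # load_records() is a placeholder
    {
        "valid": valid_transformer,      # placeholder Transformer[T, Any]
        "invalid": invalid_transformer,  # placeholder Transformer[T, Any]
    },
    batch_size=500,
)
valid_items = results["valid"]  # every item produced by the "valid" branch
final_state = context           # merged context state after all branches finish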
@@ -372,32 +360,41 @@ def producer() -> None:
             for q in queues:
                 q.put(None)
 
-        def consumer(transformer: Transformer, queue: Queue) -> list[Any]:
-            """Consumes batches from a queue and runs them through a transformer."""
+        def consumer(
+            transformer: Transformer, queue: Queue, context_handle: IContextHandle
+        ) -> tuple[list[Any], dict[str, Any]]:
+            """Consumes batches from a queue and processes them through a dedicated pipeline."""
 
             def stream_from_queue() -> Iterator[T]:
                 while (batch := queue.get()) is not None:
-                    yield batch
+                    yield from batch
+
+            # Create a new pipeline for this branch but share the parent's context manager
+            # This ensures all branches share the same context
+            branch_pipeline = Pipeline(stream_from_queue(), context_manager=context_handle.create_proxy())  # type: ignore
 
-            if use_queue_chunks:
-                transformer = transformer.set_chunker(passthrough_chunks)
+            # Apply the transformer to the branch pipeline and get results
+            result_list, branch_context = branch_pipeline.apply(transformer).to_list()
 
-            result_iterator = transformer(stream_from_queue(), self.context_manager)  # type: ignore
-            return list(result_iterator)
+            return result_list, branch_context
 
         with ThreadPoolExecutor(max_workers=num_branches + 1) as executor:
             executor.submit(producer)
 
             future_to_name = {
-                executor.submit(consumer, transformer, queues[i]): name for i, (name, transformer) in enumerate(branch_items)
+                executor.submit(consumer, transformer, queues[i], self.context_manager.get_handle()): name
+                for i, (name, transformer) in enumerate(branch_items)
             }
 
+            # Collect results - context is shared through the same context manager
             for future in as_completed(future_to_name):
                 name = future_to_name[future]
                 try:
-                    final_results[name] = future.result()
-                except Exception as e:
-                    print(f"Branch '{name}' raised an exception: {e}")
+                    result_list, branch_context = future.result()
+                    final_results[name] = result_list
+                except Exception:
                     final_results[name] = []
 
-        return final_results, self.context_manager.to_dict()
+        # After all threads complete, get the final context state
+        final_context = self.context_manager.to_dict()
+        return final_results, final_context
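Finally, a sketch of the handle/proxy flow the rewritten consumer relies on, restricted to the calls that appear in this hunk (get_handle, create_proxy, to_dict, and the tuple-returning to_list); the data, the helper name, and the import paths are illustrative, and the proxy is passed as context_manager exactly as the diff does, type: ignore included.

from laygo.context.parallel import ParallelContextManager
from laygo.context.types import IContextHandle
from laygo.pipeline import Pipeline  # assumed module path, for illustration only


def run_branch(handle: IContextHandle) -> tuple[list, dict]:
    # Each worker builds its own proxy from the parent's handle, so every
    # branch pipeline reads and writes the same underlying context state.
    proxy = handle.create_proxy()
    branch_pipeline = Pipeline(["a", "b"], context_manager=proxy)  # type: ignore
    return branch_pipeline.to_list()  # (items, context), the same contract consumer() returns


manager = ParallelContextManager()
items, branch_context = run_branch(manager.get_handle())
parent_view = manager.to_dict()  # the parent manager reflects updates made through proxies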