diff --git a/laygo/pipeline.py b/laygo/pipeline.py index 931d83e..4f3da88 100644 --- a/laygo/pipeline.py +++ b/laygo/pipeline.py @@ -2,14 +2,19 @@ from collections.abc import Callable from collections.abc import Iterable from collections.abc import Iterator +from collections.abc import Mapping from concurrent.futures import ThreadPoolExecutor from concurrent.futures import as_completed import itertools +from multiprocessing import Manager from queue import Queue from typing import Any +from typing import Literal from typing import TypeVar from typing import overload +from loky import get_reusable_executor + from laygo.context import IContextManager from laygo.context.parallel import ParallelContextManager from laygo.context.types import IContextHandle @@ -21,6 +26,45 @@ PipelineFunction = Callable[[T], Any] +# This function must be defined at the top level of the module (e.g., after imports) +def _branch_consumer_process[T](transformer: Transformer, queue: "Queue", context_handle: IContextHandle) -> list[Any]: + """Entry point for a consumer process in parallel branching. + + Reconstructs the necessary objects and runs a dedicated pipeline instance + on the data from its queue. This function is called in separate processes + during process-based parallel execution. + + Args: + transformer: The transformer to apply to the data. + queue: Process-safe queue containing batched data items. + context_handle: Handle to create a context proxy in the new process. + + Returns: + List of processed results from applying the transformer. + """ + # Re-create the context proxy within the new process + context_proxy = context_handle.create_proxy() + + def stream_from_queue() -> Iterator[T]: + """Generate items from the process-safe queue. + + Yields items from batches until a None sentinel is received. + + Yields: + Items from the queue batches. + """ + while (batch := queue.get()) is not None: + yield from batch + + try: + # Each consumer process runs its own mini-pipeline + branch_pipeline = Pipeline(stream_from_queue(), context_manager=context_proxy) + result_list, _ = branch_pipeline.apply(transformer).to_list() + return result_list + finally: + context_proxy.shutdown() + + class Pipeline[T]: """Manages a data source and applies transformers to it. @@ -237,7 +281,9 @@ def to_list(self) -> tuple[list[T], dict[str, Any]]: and materializes all results into memory. Returns: - A list containing all processed items from the pipeline. + A tuple containing: + - A list of all processed items from the pipeline + - The final context dictionary Note: This operation consumes the pipeline's iterator, making subsequent @@ -246,7 +292,7 @@ def to_list(self) -> tuple[list[T], dict[str, Any]]: return list(self.processed_data), self.context_manager.to_dict() def each(self, function: PipelineFunction[T]) -> tuple[None, dict[str, Any]]: - """Apply a function to each element (terminal operation). + """Apply a function to each element for side effects. This is a terminal operation that processes each element for side effects and consumes the pipeline's iterator without returning results. @@ -255,6 +301,11 @@ def each(self, function: PipelineFunction[T]) -> tuple[None, dict[str, Any]]: function: The function to apply to each element. Should be used for side effects like logging, updating external state, etc. 
+ Returns: + A tuple containing: + - None (no results are collected) + - The final context dictionary + Note: This operation consumes the pipeline's iterator, making subsequent operations on the same pipeline return empty results. @@ -265,7 +316,7 @@ def each(self, function: PipelineFunction[T]) -> tuple[None, dict[str, Any]]: return None, self.context_manager.to_dict() def first(self, n: int = 1) -> tuple[list[T], dict[str, Any]]: - """Get the first n elements of the pipeline (terminal operation). + """Get the first n elements of the pipeline. This is a terminal operation that consumes up to n elements from the pipeline's iterator and returns them as a list. @@ -274,8 +325,10 @@ def first(self, n: int = 1) -> tuple[list[T], dict[str, Any]]: n: The number of elements to retrieve. Must be at least 1. Returns: - A list containing the first n elements, or fewer if the pipeline - contains fewer than n elements. + A tuple containing: + - A list containing the first n elements, or fewer if the pipeline + contains fewer than n elements + - The final context dictionary Raises: AssertionError: If n is less than 1. @@ -288,12 +341,17 @@ def first(self, n: int = 1) -> tuple[list[T], dict[str, Any]]: return list(itertools.islice(self.processed_data, n)), self.context_manager.to_dict() def consume(self) -> tuple[None, dict[str, Any]]: - """Consume the pipeline without returning results (terminal operation). + """Consume the pipeline without returning results. This is a terminal operation that processes all elements in the pipeline for their side effects without materializing any results. Useful when the pipeline operations have side effects and you don't need the results. + Returns: + A tuple containing: + - None (no results are collected) + - The final context dictionary + Note: This operation consumes the pipeline's iterator, making subsequent operations on the same pipeline return empty results. @@ -303,98 +361,341 @@ def consume(self) -> tuple[None, dict[str, Any]]: return None, self.context_manager.to_dict() + def _producer_fanout( + self, + source_iterator: Iterator[T], + queues: dict[str, Queue], + batch_size: int, + ) -> None: + """Producer for fan-out mode. + + Sends every item to every branch. Used for unconditional branching + where all branches process all items. + + Args: + source_iterator: The source data iterator. + queues: Dictionary mapping branch names to their queues. + batch_size: Number of items per batch. + """ + for batch_tuple in itertools.batched(source_iterator, batch_size): + batch_list = list(batch_tuple) + for q in queues.values(): + q.put(batch_list) + for q in queues.values(): + q.put(None) + + def _producer_router( + self, + source_iterator: Iterator[T], + queues: dict[str, Queue], + parsed_branches: list[tuple[str, Transformer, Callable]], + batch_size: int, + ) -> None: + """Producer for router mode. + + Sends each item to the first matching branch when first_match=True. + This implements conditional routing where items go to exactly one branch. + + Args: + source_iterator: The source data iterator. + queues: Dictionary mapping branch names to their queues. + parsed_branches: List of (name, transformer, condition) tuples. + batch_size: Number of items per batch. 
+ """ + buffers = {name: [] for name, _, _ in parsed_branches} + for item in source_iterator: + for name, _, condition in parsed_branches: + if condition(item): + branch_buffer = buffers[name] + branch_buffer.append(item) + if len(branch_buffer) >= batch_size: + queues[name].put(branch_buffer) + buffers[name] = [] + break + for name, buffer_list in buffers.items(): + if buffer_list: + queues[name].put(buffer_list) + for q in queues.values(): + q.put(None) + + def _producer_broadcast( + self, + source_iterator: Iterator[T], + queues: dict[str, Queue], + parsed_branches: list[tuple[str, Transformer, Callable]], + batch_size: int, + ) -> None: + """Producer for broadcast mode. + + Sends each item to all matching branches when first_match=False. + This implements conditional broadcasting where items can go to multiple branches. + + Args: + source_iterator: The source data iterator. + queues: Dictionary mapping branch names to their queues. + parsed_branches: List of (name, transformer, condition) tuples. + batch_size: Number of items per batch. + """ + buffers = {name: [] for name, _, _ in parsed_branches} + for item in source_iterator: + item_matches = [name for name, _, condition in parsed_branches if condition(item)] + + for name in item_matches: + buffers[name].append(item) + branch_buffer = buffers[name] + if len(branch_buffer) >= batch_size: + queues[name].put(branch_buffer) + buffers[name] = [] + + for name, buffer_list in buffers.items(): + if buffer_list: + queues[name].put(buffer_list) + for q in queues.values(): + q.put(None) + + # Overload 1: Unconditional fan-out + @overload + def branch( + self, + branches: Mapping[str, Transformer[T, Any]], + *, + executor_type: Literal["thread", "process"] = "thread", + batch_size: int = 1000, + max_batch_buffer: int = 1, + ) -> tuple[dict[str, list[Any]], dict[str, Any]]: ... + + # Overload 2: Conditional routing + @overload + def branch( + self, + branches: Mapping[str, tuple[Transformer[T, Any], Callable[[T], bool]]], + *, + executor_type: Literal["thread", "process"] = "thread", + first_match: bool = True, + batch_size: int = 1000, + max_batch_buffer: int = 1, + ) -> tuple[dict[str, list[Any]], dict[str, Any]]: ... + def branch( self, - branches: dict[str, Transformer[T, Any]], + branches: Mapping[str, Transformer[T, Any]] | Mapping[str, tuple[Transformer[T, Any], Callable[[T], bool]]], + *, + executor_type: Literal["thread", "process"] = "thread", + first_match: bool = True, batch_size: int = 1000, max_batch_buffer: int = 1, ) -> tuple[dict[str, list[Any]], dict[str, Any]]: - """Forks the pipeline into multiple branches for concurrent, parallel processing. + """ + Forks the pipeline for parallel processing with optional conditional routing. + + This is a **terminal operation** that consumes the pipeline. + + **1. Unconditional Fan-Out:** + If `branches` is a `Dict[str, Transformer]`, every item is sent to every branch. - This is a **terminal operation** that implements a fan-out pattern where - the entire dataset is copied to each branch for independent processing. - Each branch gets its own Pipeline instance with isolated context management, - and results are collected and returned in a dictionary. + **2. Conditional Routing:** + If `branches` is a `Dict[str, Tuple[Transformer, condition]]`, the `first_match` + argument determines the routing logic: + - `first_match=True` (default): Routes each item to the **first** branch + whose condition is met. This acts as a router. 
+ - `first_match=False`: Routes each item to **all** branches whose + conditions are met. This acts as a conditional broadcast. Args: - branches: A dictionary where keys are branch names (str) and values - are `Transformer` instances of any subtype. - batch_size: The number of items to batch together when sending data - to branches. Larger batches can improve throughput but - use more memory. Defaults to 1000. - max_batch_buffer: The maximum number of batches to buffer for each - branch queue. Controls memory usage and creates - backpressure. Defaults to 1. + branches: A dictionary defining the branches. + executor_type: The parallelism model. 'thread' for I/O-bound tasks, + 'process' for CPU-bound tasks. Defaults to 'thread'. + first_match: Determines the routing logic for conditional branches. + batch_size: The number of items to batch for processing. + max_batch_buffer: The max number of batches to buffer per branch. Returns: - A tuple containing: - - A dictionary where keys are the branch names and values are lists - of all items processed by that branch's transformer. - - A merged dictionary of all context values from all branches. - - Note: - This operation consumes the pipeline's iterator, making subsequent - operations on the same pipeline return empty results. + A tuple containing a dictionary of results and the final context. """ if not branches: self.consume() return {}, {} + first_value = next(iter(branches.values())) + is_conditional = isinstance(first_value, tuple) + + parsed_branches: list[tuple[str, Transformer[T, Any], Callable[[T], bool]]] + if is_conditional: + parsed_branches = [(name, trans, cond) for name, (trans, cond) in branches.items()] # type: ignore + else: + parsed_branches = [(name, trans, lambda _: True) for name, trans in branches.items()] # type: ignore + + producer_fn: Callable + if not is_conditional: + producer_fn = self._producer_fanout + elif first_match: + producer_fn = self._producer_router + else: + producer_fn = self._producer_broadcast + + # Dispatch to the correct executor based on the chosen type + if executor_type == "thread": + return self._execute_branching_thread( + producer_fn=producer_fn, + parsed_branches=parsed_branches, + batch_size=batch_size, + max_batch_buffer=max_batch_buffer, + ) + elif executor_type == "process": + return self._execute_branching_process( + producer_fn=producer_fn, + parsed_branches=parsed_branches, + batch_size=batch_size, + max_batch_buffer=max_batch_buffer, + ) + else: + raise ValueError(f"Unsupported executor_type: '{executor_type}'. Must be 'thread' or 'process'.") + + def _execute_branching_process( + self, + *, + producer_fn: Callable, + parsed_branches: list[tuple[str, Transformer, Callable]], + batch_size: int, + max_batch_buffer: int, + ) -> tuple[dict[str, list[Any]], dict[str, Any]]: + """Execute branching using a process pool for consumers. + + Uses multiprocessing for true CPU parallelism. The producer runs in a + thread while consumers run in separate processes. + + Args: + producer_fn: The producer function to use for routing items. + parsed_branches: List of (name, transformer, condition) tuples. + batch_size: Number of items per batch. + max_batch_buffer: Maximum number of batches to buffer per branch. 
+ + Returns: + A tuple containing: + - Dictionary mapping branch names to their result lists + - The final context dictionary + """ source_iterator = self.processed_data - branch_items = list(branches.items()) - num_branches = len(branch_items) - final_results: dict[str, list[Any]] = {} - - queues = [Queue(maxsize=max_batch_buffer) for _ in range(num_branches)] - - def producer() -> None: - """Reads from the source and distributes batches to ALL branch queues.""" - # Use itertools.batched for clean and efficient batch creation. - for batch_tuple in itertools.batched(source_iterator, batch_size): - # The batch is a tuple; convert to a list for consumers. - batch_list = list(batch_tuple) - for q in queues: - q.put(batch_list) - - # Signal to all consumers that the stream is finished. - for q in queues: - q.put(None) - - def consumer( - transformer: Transformer, queue: Queue, context_handle: IContextHandle - ) -> tuple[list[Any], dict[str, Any]]: - """Consumes batches from a queue and processes them through a dedicated pipeline.""" + num_branches = len(parsed_branches) + final_results: dict[str, list[Any]] = {name: [] for name, _, _ in parsed_branches} + context_handle = self.context_manager.get_handle() + + # A Manager creates queues that can be shared between processes + manager = Manager() + queues = {name: manager.Queue(maxsize=max_batch_buffer) for name, _, _ in parsed_branches} + + # The producer must run in a thread to access the pipeline's iterator, + # while consumers run in processes for true CPU parallelism. + producer_executor = ThreadPoolExecutor(max_workers=1) + consumer_executor = get_reusable_executor(max_workers=num_branches) + + try: + # Determine arguments for the producer function + producer_args: tuple + if producer_fn == self._producer_fanout: + producer_args = (source_iterator, queues, batch_size) + else: + producer_args = (source_iterator, queues, parsed_branches, batch_size) + + # Submit the producer to the thread pool + producer_future = producer_executor.submit(producer_fn, *producer_args) + + # Submit consumers to the process pool + future_to_name = { + consumer_executor.submit(_branch_consumer_process, transformer, queues[name], context_handle): name + for name, transformer, _ in parsed_branches + } + + # Collect results as they complete + for future in as_completed(future_to_name): + name = future_to_name[future] + try: + final_results[name] = future.result() + except Exception: + final_results[name] = [] + + # Check for producer errors after consumers are done + producer_future.result() + + finally: + producer_executor.shutdown() + # The reusable executor from loky is managed globally + + final_context = self.context_manager.to_dict() + return final_results, final_context + + # Rename original _execute_branching to be specific + def _execute_branching_thread( + self, + *, + producer_fn: Callable, + parsed_branches: list[tuple[str, Transformer, Callable]], + batch_size: int, + max_batch_buffer: int, + ) -> tuple[dict[str, list[Any]], dict[str, Any]]: + """Execute branching using a thread pool for consumers. + + Uses threading for I/O-bound tasks. Both producer and consumers run + in separate threads within the same process. + + Args: + producer_fn: The producer function to use for routing items. + parsed_branches: List of (name, transformer, condition) tuples. + batch_size: Number of items per batch. + max_batch_buffer: Maximum number of batches to buffer per branch. 
+ + Returns: + A tuple containing: + - Dictionary mapping branch names to their result lists + - The final context dictionary + """ + source_iterator = self.processed_data + num_branches = len(parsed_branches) + final_results: dict[str, list[Any]] = {name: [] for name, _, _ in parsed_branches} + queues = {name: Queue(maxsize=max_batch_buffer) for name, _, _ in parsed_branches} + + def consumer(transformer: Transformer, queue: Queue, context_handle: IContextHandle) -> list[Any]: + """Consume batches from a queue and process them with a transformer. + + Creates a mini-pipeline for the transformer and processes all + batches from the queue until completion. + + Args: + transformer: The transformer to apply to the data. + queue: Queue containing batched data items. + context_handle: Handle to create a context proxy. + + Returns: + List of processed results from applying the transformer. + """ def stream_from_queue() -> Iterator[T]: while (batch := queue.get()) is not None: yield from batch - # Create a new pipeline for this branch but share the parent's context manager - # This ensures all branches share the same context branch_pipeline = Pipeline(stream_from_queue(), context_manager=context_handle.create_proxy()) # type: ignore - - # Apply the transformer to the branch pipeline and get results - result_list, branch_context = branch_pipeline.apply(transformer).to_list() - - return result_list, branch_context + result_list, _ = branch_pipeline.apply(transformer).to_list() + return result_list with ThreadPoolExecutor(max_workers=num_branches + 1) as executor: - executor.submit(producer) + producer_args: tuple + if producer_fn == self._producer_fanout: + producer_args = (source_iterator, queues, batch_size) + else: + producer_args = (source_iterator, queues, parsed_branches, batch_size) + executor.submit(producer_fn, *producer_args) future_to_name = { - executor.submit(consumer, transformer, queues[i], self.context_manager.get_handle()): name - for i, (name, transformer) in enumerate(branch_items) + executor.submit(consumer, transformer, queues[name], self.context_manager.get_handle()): name + for name, transformer, _ in parsed_branches } - # Collect results - context is shared through the same context manager for future in as_completed(future_to_name): name = future_to_name[future] try: - result_list, branch_context = future.result() - final_results[name] = result_list + final_results[name] = future.result() except Exception: final_results[name] = [] - # After all threads complete, get the final context state final_context = self.context_manager.to_dict() return final_results, final_context diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index bb1f2ad..429964e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,5 +1,8 @@ """Tests for the Pipeline class.""" +import os +import time + from laygo import Pipeline from laygo.context.types import IContextManager from laygo.transformers.transformer import createTransformer @@ -500,3 +503,117 @@ def context_modifier_b(chunk: list[int], ctx: IContextManager) -> list[int]: # Context values should reflect the actual chunk sizes processed assert context["branch_b_processed"] == 3 assert context["branch_a_processed"] == 3 + + def test_branch_conditional_router_mode(self): + """Test conditional branch with first_match=True (router mode).""" + # Setup + data = [1, "a", 2, 99.9, 3, "b"] + pipeline = Pipeline(data) + + branches = { + "integers": ( + createTransformer(int).map(lambda x: x + 1), + lambda x: isinstance(x, int), + ), + 
"strings": ( + createTransformer(str).map(lambda x: x.upper()), + lambda x: isinstance(x, str), + ), + "numbers": ( # This condition also matches integers + createTransformer(float).map(lambda x: x * 10), + lambda x: isinstance(x, int | float), + ), + } + + # Action: Execute branch with default first_match=True + result, _ = pipeline.branch(branches, first_match=True) + + # Assert: Items are routed to the *first* matching branch only. + # Integers (1, 2, 3) are caught by the 'integers' branch first. + assert sorted(result["integers"]) == [2, 3, 4] + # Strings ('a', 'b') are caught by the 'strings' branch. + assert sorted(result["strings"]) == ["A", "B"] + # The float (99.9) is caught by 'numbers'. Integers are NOT processed + # here because they were already matched by the 'integers' branch. + assert result["numbers"] == [999.0] + + def test_branch_conditional_broadcast_mode(self): + """Test conditional branch with first_match=False (broadcast mode).""" + # Setup + data = [1, "a", 2, 99.9, 3, "b"] + pipeline = Pipeline(data) + + branches = { + "integers": ( + createTransformer(int).map(lambda x: x + 1), + lambda x: isinstance(x, int), + ), + "strings": ( + createTransformer(str).map(lambda x: x.upper()), + lambda x: isinstance(x, str), + ), + "numbers": ( # This condition also matches integers + createTransformer(float).map(lambda x: x * 10), + lambda x: isinstance(x, int | float), + ), + } + + # Action: Execute branch with first_match=False + result, _ = pipeline.branch(branches, first_match=False) + + # Assert: Items are routed to *all* matching branches. + # Integers (1, 2, 3) are processed by the 'integers' branch. + assert sorted(result["integers"]) == [2, 3, 4] + # Strings ('a', 'b') are processed by the 'strings' branch. + assert sorted(result["strings"]) == ["A", "B"] + # The float (99.9) AND the integers (1, 2, 3) are processed by the 'numbers' branch. + assert sorted(result["numbers"]) == [10.0, 20.0, 30.0, 999.0] + + def test_branch_process_executor(self): + """Test branching with executor_type='process' for CPU-bound work.""" + + # Setup: A CPU-bound task is ideal for demonstrating process parallelism. + def heavy_computation(x: int) -> int: + # A simple but non-trivial calculation + time.sleep(0.01) # Simulate work + return x * x + + # This function will run inside the worker process to check its PID + def check_pid(chunk: list[int], ctx: IContextManager) -> list[int]: + # Store the worker's process ID in the shared context + if chunk: + ctx[f"pid_for_item_{chunk[0]}"] = os.getpid() + return chunk + + data = [1, 2, 3, 4] + pipeline = Pipeline(data) + main_pid = os.getpid() + + # Define branches with CPU-bound work and the PID check + branches = { + "evens": ( + createTransformer(int).filter(lambda x: x % 2 == 0).map(heavy_computation)._pipe(check_pid), + lambda x: True, # Condition to route data + ), + "odds": ( + createTransformer(int).filter(lambda x: x % 2 != 0).map(heavy_computation)._pipe(check_pid), + lambda x: True, + ), + } + + # Action: Execute the branch with the process executor + result, context = pipeline.branch( + branches, + first_match=False, # Use broadcast to send to all matching + executor_type="process", + ) + + # Assert: The computational results are correct + assert sorted(result["evens"]) == [4, 16] # 2*2, 4*4 + assert sorted(result["odds"]) == [1, 9] # 1*1, 3*3 + + # Assert: The work was done in different processes + worker_pids = {v for k, v in context.items() if "pid" in k} + assert len(worker_pids) > 0, "No worker PIDs were found in the context." 
+        for pid in worker_pids:
+            assert pid != main_pid, f"Worker PID {pid} is the same as the main PID."
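
Usage sketch for the `branch` API introduced above. This is illustrative only and not part of the patch; the imports, call signatures, and expected results follow the updated tests.

    from laygo import Pipeline
    from laygo.transformers.transformer import createTransformer

    # 1. Unconditional fan-out: every item is sent to every branch.
    results, _ = Pipeline([1, 2, 3, 4]).branch(
        {
            "squares": createTransformer(int).map(lambda x: x * x),
            "doubles": createTransformer(int).map(lambda x: x * 2),
        }
    )
    # results -> {"squares": [1, 4, 9, 16], "doubles": [2, 4, 6, 8]}

    # 2. Conditional routing: each item goes to the first branch whose
    #    condition matches (first_match=True is the default).
    results, _ = Pipeline([1, "a", 2, "b"]).branch(
        {
            "ints": (createTransformer(int).map(lambda x: x + 1), lambda x: isinstance(x, int)),
            "strs": (createTransformer(str).map(lambda x: x.upper()), lambda x: isinstance(x, str)),
        },
        first_match=True,
    )
    # results -> {"ints": [2, 3], "strs": ["A", "B"]}

    # 3. Conditional broadcast with process-based consumers for CPU-bound work.
    #    (Run under an `if __name__ == "__main__":` guard so worker processes
    #    can be spawned safely.)
    results, context = Pipeline([1, 2, 3, 4]).branch(
        {
            "evens": (createTransformer(int).filter(lambda x: x % 2 == 0), lambda x: True),
            "odds": (createTransformer(int).filter(lambda x: x % 2 != 0), lambda x: True),
        },
        first_match=False,
        executor_type="process",
    )
    # results -> {"evens": [2, 4], "odds": [1, 3]}

In every mode the producer feeds per-branch queues in batches: `batch_size` sets the batch granularity, and `max_batch_buffer` bounds how many batches may wait in each branch's queue, which provides backpressure against a fast source.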