diff --git a/laygo/transformers/strategies/threaded.py b/laygo/transformers/strategies/threaded.py
index a855458..2ed8c54 100644
--- a/laygo/transformers/strategies/threaded.py
+++ b/laygo/transformers/strategies/threaded.py
@@ -5,8 +5,9 @@
 from concurrent.futures import Future
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import wait
-from functools import partial
 import itertools
+import threading
+from typing import ClassVar
 
 from laygo.context.types import IContextManager
 from laygo.transformers.strategies.types import ChunkGenerator
@@ -15,26 +16,38 @@ class ThreadedStrategy[In, Out](ExecutionStrategy[In, Out]):
+  # Class-level thread pool cache to reuse executors
+  _thread_pools: ClassVar[dict[int, ThreadPoolExecutor]] = {}
+  _pool_lock: ClassVar[threading.Lock] = threading.Lock()
+
   def __init__(self, max_workers: int = 4, ordered: bool = True):
     self.max_workers = max_workers
     self.ordered = ordered
 
+  @classmethod
+  def _get_thread_pool(cls, max_workers: int) -> ThreadPoolExecutor:
+    """Get or create a reusable thread pool for the given worker count."""
+    with cls._pool_lock:
+      if max_workers not in cls._thread_pools:
+        cls._thread_pools[max_workers] = ThreadPoolExecutor(
+          max_workers=max_workers, thread_name_prefix=f"laygo-{max_workers}"
+        )
+      return cls._thread_pools[max_workers]
+
   def execute(self, transformer_logic, chunk_generator, data, context):
     """Execute the transformer on data concurrently.
 
-    It uses the shared context provided by the Pipeline, if available.
+    Uses a reusable thread pool to minimize thread creation overhead.
 
     Args:
+      transformer_logic: The transformation function to apply.
+      chunk_generator: Function to generate data chunks.
       data: The input data to process.
       context: Optional pipeline context for shared state.
 
     Returns:
       An iterator over the transformed data.
     """
-
-    # Since threads share memory, we can pass the context manager directly.
-    # No handle/proxy mechanism is needed, but the locking inside
-    # ParallelContextManager is crucial for thread safety.
     yield from self._execute_with_context(data, transformer_logic, context, chunk_generator)
 
   def _execute_with_context(
@@ -48,13 +61,15 @@ def _execute_with_context(
 
     Args:
       data: The input data to process.
+      transformer: The transformation function to apply.
       shared_context: The shared context for the execution.
+      chunk_generator: Function to generate data chunks.
 
     Returns:
       An iterator over the transformed data.
     """
 
-    def process_chunk(chunk: list[In], shared_context: IContextManager) -> list[Out]:
+    def process_chunk(chunk: list[In]) -> list[Out]:
       """Process a single chunk by passing the chunk and context explicitly.
 
       Args:
@@ -66,49 +81,59 @@ def process_chunk(chunk: list[In], shared_context: IContextManager) -> list[Out]
       """
       return transformer(chunk, shared_context)  # type: ignore
 
-    # Create a partial function with the shared_context "baked in".
-    process_chunk_with_context = partial(process_chunk, shared_context=shared_context)
-
     def _ordered_generator(chunks_iter: Iterator[list[In]], executor: ThreadPoolExecutor) -> Iterator[list[Out]]:
       """Generate results in their original order."""
       futures: deque[Future[list[Out]]] = deque()
-      for _ in range(self.max_workers + 1):
+
+      # Pre-submit initial batch of futures
+      for _ in range(min(self.max_workers, 10)):  # Limit initial submissions
         try:
           chunk = next(chunks_iter)
-          futures.append(executor.submit(process_chunk_with_context, chunk))
+          futures.append(executor.submit(process_chunk, chunk))
         except StopIteration:
           break
+
       while futures:
-        yield futures.popleft().result()
+        # Get the next result and submit the next chunk
+        result = futures.popleft().result()
+        yield result
+
         try:
           chunk = next(chunks_iter)
-          futures.append(executor.submit(process_chunk_with_context, chunk))
+          futures.append(executor.submit(process_chunk, chunk))
         except StopIteration:
           continue
 
     def _unordered_generator(chunks_iter: Iterator[list[In]], executor: ThreadPoolExecutor) -> Iterator[list[Out]]:
       """Generate results as they complete."""
+      # Pre-submit initial batch
       futures = {
-        executor.submit(process_chunk_with_context, chunk)
-        for chunk in itertools.islice(chunks_iter, self.max_workers + 1)
+        executor.submit(process_chunk, chunk) for chunk in itertools.islice(chunks_iter, min(self.max_workers, 10))
       }
+
       while futures:
         done, futures = wait(futures, return_when=FIRST_COMPLETED)
         for future in done:
           yield future.result()
           try:
             chunk = next(chunks_iter)
-            futures.add(executor.submit(process_chunk_with_context, chunk))
+            futures.add(executor.submit(process_chunk, chunk))
           except StopIteration:
             continue
 
-    def result_iterator_manager() -> Iterator[Out]:
-      """Manage the thread pool and yield flattened results."""
-      with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
-        chunks_to_process = chunk_generator(data)
-        gen_func = _ordered_generator if self.ordered else _unordered_generator
-        processed_chunks_iterator = gen_func(chunks_to_process, executor)
-        for result_chunk in processed_chunks_iterator:
-          yield from result_chunk
-
-    return result_iterator_manager()
+    # Use the reusable thread pool instead of creating a new one
+    executor = self._get_thread_pool(self.max_workers)
+    chunks_to_process = chunk_generator(data)
+    gen_func = _ordered_generator if self.ordered else _unordered_generator
+
+    # Process chunks using the reusable executor
+    for result_chunk in gen_func(chunks_to_process, executor):
+      yield from result_chunk
+
+  @classmethod
+  def shutdown_pools(cls) -> None:
+    """Shut down all cached thread pools. Call once during application cleanup."""
+    with cls._pool_lock:
+      for pool in cls._thread_pools.values():
+        pool.shutdown(wait=True)
+      cls._thread_pools.clear()
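Note for reviewers: the cached pools are keyed by worker count and now outlive any single `execute()` call, so shutting them down becomes the application's responsibility. A minimal usage sketch under that assumption (the `atexit` wiring is illustrative, not part of this diff):

```python
import atexit

from laygo.transformers.strategies.threaded import ThreadedStrategy

# Strategies with the same worker count share one cached executor,
# so repeated pipeline runs reuse threads instead of recreating them.
strategy_a = ThreadedStrategy(max_workers=8)
strategy_b = ThreadedStrategy(max_workers=8)  # reuses the cached "laygo-8" pool

# Cached pools are never shut down implicitly; register an explicit
# cleanup hook once at startup (illustrative wiring, not in this diff).
atexit.register(ThreadedStrategy.shutdown_pools)
```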
+ """ + + multiplier = context["multiplier"] if context and "multiplier" in context else 1 + + for item in data: + yield item * multiplier + + +class TestCustomTransformer: + def test_multiplier_transformer(self): + data = [1, 2, 3, 4, 5] + expected_output = [2, 4, 6, 8, 10] + + result, _ = Pipeline(data).context({"multiplier": 2}).apply(MultiplierTransformer()).to_list() + + assert result == expected_output, f"Expected {expected_output}, but got {result}"