From 61b278c2c45f8867df645e12fcf9c973256ed20e Mon Sep 17 00:00:00 2001
From: Marco Vinciguerra
Date: Tue, 4 Mar 2025 09:37:03 +0100
Subject: [PATCH] feat: add queue to async client

---
 scrapegraph-py/scrapegraph_py/async_client.py | 155 +++++++++++++++++-
 scrapegraph-py/uv.lock                        |   2 +-
 2 files changed, 154 insertions(+), 3 deletions(-)
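Notes for reviewers: a minimal usage sketch of the new queue API. The method
names (submit_job, get_job_status, wait_for_job, close) come from this patch;
the AsyncClient constructor arguments are an assumption based on the existing
client and may need adjusting.

    import asyncio

    from scrapegraph_py.async_client import AsyncClient

    async def main():
        client = AsyncClient(api_key="your-api-key")  # constructor args assumed
        try:
            # Any coroutine function can be queued; a stand-in task keeps the
            # example self-contained.
            async def add(a: int, b: int) -> int:
                await asyncio.sleep(0.5)  # simulate work
                return a + b

            job_id = await client.submit_job(add, 2, 3)
            print(await client.get_job_status(job_id))  # "pending" or "running"

            # Block until the job finishes, or raise once the timeout elapses.
            result = await client.wait_for_job(job_id, timeout=10)
            print(result)  # 5
        finally:
            await client.close()  # also stops the queue processor

    asyncio.run(main())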
diff --git a/scrapegraph-py/scrapegraph_py/async_client.py b/scrapegraph-py/scrapegraph_py/async_client.py
index 99b6212..716fb70 100644
--- a/scrapegraph-py/scrapegraph_py/async_client.py
+++ b/scrapegraph-py/scrapegraph_py/async_client.py
@@ -1,5 +1,9 @@
 import asyncio
-from typing import Any, Optional
+from typing import Any, Optional, Dict, Callable, Awaitable, TypeVar, Generic
+from enum import Enum
+from dataclasses import dataclass
+from datetime import datetime
+from uuid import uuid4
 
 from aiohttp import ClientSession, ClientTimeout, TCPConnector
 from aiohttp.client_exceptions import ClientError
@@ -20,6 +24,26 @@
 )
 from scrapegraph_py.utils.helpers import handle_async_response, validate_api_key
 
+T = TypeVar('T')
+
+class JobStatus(Enum):
+    PENDING = "pending"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+
+@dataclass
+class Job(Generic[T]):
+    id: str
+    status: JobStatus
+    created_at: datetime
+    started_at: Optional[datetime] = None
+    completed_at: Optional[datetime] = None
+    result: Optional[T] = None
+    error: Optional[Exception] = None
+    task: Optional[Callable[..., Awaitable[T]]] = None
+    args: tuple = ()
+    kwargs: Optional[dict] = None
 
 class AsyncClient:
     @classmethod
@@ -58,6 +82,7 @@ def __init__(
         timeout: Optional[float] = None,
         max_retries: int = 3,
         retry_delay: float = 1.0,
+        max_queue_size: int = 1000,
     ):
         """Initialize AsyncClient with configurable parameters.
 
@@ -67,6 +92,7 @@ def __init__(
             timeout: Request timeout in seconds. None means no timeout (infinite)
             max_retries: Maximum number of retry attempts
             retry_delay: Delay between retries in seconds
+            max_queue_size: Maximum number of jobs in the queue
         """
 
         logger.info("🔑 Initializing AsyncClient")
@@ -96,8 +122,132 @@ def __init__(
             headers=self.headers, connector=TCPConnector(ssl=ssl), timeout=self.timeout
         )
 
+        # Initialize job queue and job bookkeeping
+        self.job_queue: asyncio.Queue[Job] = asyncio.Queue(maxsize=max_queue_size)
+        self.jobs: Dict[str, Job] = {}
+        self._queue_processor_task: Optional[asyncio.Task] = None
+
         logger.info("✅ AsyncClient initialized successfully")
 
+    async def start_queue_processor(self):
+        """Start the background job queue processor."""
+        if self._queue_processor_task is None:
+            self._queue_processor_task = asyncio.create_task(self._process_queue())
+            logger.info("🚀 Job queue processor started")
+
+    async def stop_queue_processor(self):
+        """Stop the background job queue processor."""
+        if self._queue_processor_task is not None:
+            self._queue_processor_task.cancel()
+            try:
+                await self._queue_processor_task
+            except asyncio.CancelledError:
+                pass
+            self._queue_processor_task = None
+            logger.info("⏚ī¸ Job queue processor stopped")
+
+    async def _process_queue(self):
+        """Process jobs from the queue."""
+        while True:
+            try:
+                job = await self.job_queue.get()
+                job.status = JobStatus.RUNNING
+                job.started_at = datetime.now()
+
+                try:
+                    if job.task:
+                        job.result = await job.task(*job.args, **(job.kwargs or {}))
+                        job.status = JobStatus.COMPLETED
+                except Exception as e:
+                    job.error = e
+                    job.status = JobStatus.FAILED
+                    logger.error(f"❌ Job {job.id} failed: {str(e)}")
+                finally:
+                    job.completed_at = datetime.now()
+                    self.job_queue.task_done()
+
+            except asyncio.CancelledError:
+                break
+            except Exception as e:
+                logger.error(f"❌ Queue processor error: {str(e)}")
+
+    async def submit_job(self, task: Callable[..., Awaitable[T]], *args, **kwargs) -> str:
+        """Submit a new job to the queue.
+
+        Args:
+            task: Async function to execute
+            *args: Positional arguments for the task
+            **kwargs: Keyword arguments for the task
+
+        Returns:
+            str: Job ID
+        """
+        job_id = str(uuid4())
+        job = Job(
+            id=job_id,
+            status=JobStatus.PENDING,
+            created_at=datetime.now(),
+            task=task,
+            args=args,
+            kwargs=kwargs,
+        )
+
+        self.jobs[job_id] = job
+        await self.job_queue.put(job)
+        logger.info(f"📋 Job {job_id} submitted to queue")
+
+        # Ensure queue processor is running
+        if self._queue_processor_task is None:
+            await self.start_queue_processor()
+
+        return job_id
+
+    async def get_job_status(self, job_id: str) -> Dict[str, Any]:
+        """Get the status of a job.
+
+        Args:
+            job_id: The ID of the job to check
+
+        Returns:
+            Dict containing job status information
+        """
+        if job_id not in self.jobs:
+            raise ValueError(f"Job {job_id} not found")
+
+        job = self.jobs[job_id]
+        return {
+            "id": job.id,
+            "status": job.status.value,
+            "created_at": job.created_at,
+            "started_at": job.started_at,
+            "completed_at": job.completed_at,
+            "result": job.result,
+            "error": str(job.error) if job.error else None,
+        }
+
+    async def wait_for_job(self, job_id: str, timeout: Optional[float] = None) -> Any:
+        """Wait for a job to complete and return its result.
+
+        Args:
+            job_id: The ID of the job to wait for
+            timeout: Maximum time to wait in seconds; None waits indefinitely
+
+        Returns:
+            The result of the job
+
+        Raises:
+            asyncio.TimeoutError: If the job does not finish within timeout seconds
+        """
+        if job_id not in self.jobs:
+            raise ValueError(f"Job {job_id} not found")
+
+        job = self.jobs[job_id]
+
+        # Poll until the job leaves the pending/running states, honoring the timeout
+        deadline = None if timeout is None else asyncio.get_running_loop().time() + timeout
+        while job.status in (JobStatus.PENDING, JobStatus.RUNNING):
+            if deadline is not None and asyncio.get_running_loop().time() >= deadline:
+                raise asyncio.TimeoutError(f"Job {job_id} did not complete within {timeout} seconds")
+            await asyncio.sleep(0.1)
+
+        if job.error:
+            raise job.error
+
+        return job.result
+
     async def _make_request(self, method: str, url: str, **kwargs) -> Any:
         """Make HTTP request with retry logic."""
         for attempt in range(self.max_retries):
@@ -285,8 +435,9 @@ async def get_searchscraper(self, request_id: str):
         return result
 
     async def close(self):
-        """Close the session to free up resources"""
+        """Close the session and stop the queue processor."""
         logger.info("🔒 Closing AsyncClient session")
+        await self.stop_queue_processor()
         await self.session.close()
         logger.debug("✅ Session closed successfully")
 
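Review note: close() cancels the processor immediately, so jobs still sitting
in the queue are dropped rather than finished. A graceful-drain sketch, under
the assumption that callers hold the same AsyncClient instance; job_queue.join()
works here because _process_queue calls task_done() for every job it takes:

    from scrapegraph_py.async_client import AsyncClient

    async def drain_and_close(client: AsyncClient) -> None:
        # Wait until every queued job has been picked up and finished;
        # task_done() is called in _process_queue's finally block.
        await client.job_queue.join()
        # Only then cancel the processor and close the HTTP session.
        await client.close()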
diff --git a/scrapegraph-py/uv.lock b/scrapegraph-py/uv.lock
index c250817..290ee64 100644
--- a/scrapegraph-py/uv.lock
+++ b/scrapegraph-py/uv.lock
@@ -1525,7 +1525,7 @@ dev = [
 
 [package.metadata]
 requires-dist = [
-    { name = "aiohttp", specifier = ">=3.11.8" },
+    { name = "aiohttp", specifier = ">=3.10" },
     { name = "beautifulsoup4", specifier = ">=4.12.3" },
     { name = "furo", marker = "extra == 'docs'", specifier = "==2024.5.6" },
     { name = "pydantic", specifier = ">=2.10.2" },
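A minimal regression-test sketch for the failure path. Assumptions: pytest with
pytest-asyncio as a dev dependency (the lock-file excerpt above does not confirm
this), and the AsyncClient constructor arguments as used earlier; the real
client may validate the API key or open a session on construction.

    import pytest

    from scrapegraph_py.async_client import AsyncClient

    @pytest.mark.asyncio
    async def test_failed_job_surfaces_its_exception():
        client = AsyncClient(api_key="test-key")  # constructor args assumed
        try:
            async def boom():
                raise RuntimeError("simulated failure")

            job_id = await client.submit_job(boom)

            # wait_for_job re-raises the stored exception for failed jobs.
            with pytest.raises(RuntimeError, match="simulated failure"):
                await client.wait_for_job(job_id, timeout=5)

            assert (await client.get_job_status(job_id))["status"] == "failed"
        finally:
            await client.close()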