-
Notifications
You must be signed in to change notification settings - Fork 4
feat: improved job wait_for strategies #318
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
11605d9
19f00d0
2503e74
30e47a2
07f4545
341a101
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| import asyncio | ||
| import json | ||
| import logging | ||
| import ssl | ||
| from datetime import datetime, timezone | ||
| from enum import Enum | ||
|
|
@@ -60,10 +61,11 @@ | |
| SystemRef, | ||
| WasmModuleRef, | ||
| ) | ||
|
|
||
| from qnexus.models.scope import ScopeFilterEnum | ||
| from qnexus.models.utils import assert_never | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| EPOCH_START = datetime(1970, 1, 1, tzinfo=timezone.utc) | ||
|
|
||
|
|
||
|
|
@@ -80,6 +82,22 @@ class RemoteRetryStrategy(str, Enum): | |
| FULL_RESTART = "FULL_RESTART" | ||
|
|
||
|
|
||
| class WaitStrategy(str, Enum): | ||
| """Strategy for waiting on job completion. | ||
|
|
||
| WEBSOCKET: Use a websocket connection for real-time updates. | ||
| Best for short-running jobs (<5 minutes). | ||
| POLLING: Use exponential backoff polling. | ||
| More robust for long-running jobs. | ||
| AUTO: Start with websocket, fall back to polling after 5 minutes. | ||
| Recommended for most use cases. | ||
| """ | ||
|
|
||
| WEBSOCKET = "websocket" | ||
| POLLING = "polling" | ||
| AUTO = "auto" | ||
|
|
||
|
|
||
| class Params( | ||
| CreatorFilter, | ||
| PropertiesFilter, | ||
|
|
@@ -338,19 +356,168 @@ def _fetch_by_id( | |
| ) | ||
|
|
||
|
|
||
| def wait_for( | ||
| async def poll_job_status( | ||
| job: JobRef, | ||
| wait_for_status: JobStatusEnum = JobStatusEnum.COMPLETED, | ||
| initial_interval: float = 1.0, | ||
| max_interval_queued: float = 1200.0, | ||
| max_interval_running: float = 180.0, | ||
| backoff_factor: float = 2.0, | ||
| ) -> JobStatus: | ||
| """Poll job status with exponential backoff and adaptive intervals. | ||
|
|
||
| Uses different maximum poll intervals based on job state: | ||
| - QUEUED: Polls less frequently (default 20 min) since queue position changes slowly | ||
| - RUNNING/SUBMITTED: Polls more frequently (default 3 min) for responsiveness | ||
|
Comment on lines
+397
to
+398
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great idea |
||
|
|
||
| Args: | ||
| job: The job to monitor. | ||
| wait_for_status: The status to wait for. | ||
| initial_interval: Starting poll interval in seconds. | ||
| max_interval_queued: Maximum poll interval when job is queued (default: 1200s). | ||
| max_interval_running: Maximum poll interval when job is running (default: 180s). | ||
| backoff_factor: Multiplier for interval after each poll. | ||
|
|
||
| Returns: | ||
| The final JobStatus when the target status is reached or job terminates. | ||
| """ | ||
| interval = initial_interval | ||
| logger.debug( | ||
| "Starting polling for job %s (target: %s, interval: %.1fs, " | ||
| "max queued: %.1fs, max running: %.1fs)", | ||
| job.id, | ||
| wait_for_status.value, | ||
| initial_interval, | ||
| max_interval_queued, | ||
| max_interval_running, | ||
| ) | ||
|
|
||
| while True: | ||
| job_status = status(job) | ||
|
|
||
| # Adapt max interval based on job state | ||
| if job_status.status == JobStatusEnum.QUEUED: | ||
| max_interval = max_interval_queued | ||
| else: | ||
| max_interval = max_interval_running | ||
|
|
||
| # Clamp interval to current max (allows faster polling when transitioning | ||
| # from QUEUED to RUNNING) | ||
| interval = min(interval, max_interval) | ||
|
|
||
| logger.debug( | ||
| "Job %s status: %s (next poll in %.1fs, max: %.1fs)", | ||
| job.id, | ||
| job_status.status.value, | ||
| interval, | ||
| max_interval, | ||
| ) | ||
|
|
||
| if ( | ||
| job_status.status not in WAITING_STATUS | ||
| or job_status.status == wait_for_status | ||
| ): | ||
| logger.debug("Job %s reached status: %s", job.id, job_status.status.value) | ||
| return job_status | ||
|
|
||
| await asyncio.sleep(interval) | ||
| interval = min(interval * backoff_factor, max_interval) | ||
|
|
||
|
|
||
| async def hybrid_wait( | ||
| job: JobRef, | ||
| wait_for_status: JobStatusEnum = JobStatusEnum.COMPLETED, | ||
| timeout: float | None = 900.0, | ||
| websocket_timeout: float = 600.0, | ||
| ) -> JobStatus: | ||
| """Check job status until the job is complete (or a specified status).""" | ||
| job_status = asyncio.run( | ||
| asyncio.wait_for( | ||
| """Use websocket for initial period, then fall back to polling. | ||
|
|
||
| Args: | ||
| job: The job to monitor. | ||
| wait_for_status: The status to wait for. | ||
| websocket_timeout: How long to use websocket before switching to polling. | ||
|
|
||
| Returns: | ||
| The final JobStatus when the target status is reached or job terminates. | ||
| """ | ||
| logger.debug( | ||
| "Using hybrid strategy for job %s (websocket timeout: %.1fs)", | ||
| job.id, | ||
| websocket_timeout, | ||
| ) | ||
| try: | ||
| # Try websocket first with a timeout | ||
| job_status = await asyncio.wait_for( | ||
| listen_job_status(job=job, wait_for_status=wait_for_status), | ||
| timeout=timeout, | ||
| timeout=websocket_timeout, | ||
| ) | ||
| return job_status | ||
| except asyncio.TimeoutError: | ||
| # Websocket phase timed out, switch to polling | ||
| logger.debug( | ||
| "Job %s: websocket timeout after %.1fs, switching to polling", | ||
| job.id, | ||
| websocket_timeout, | ||
| ) | ||
| return await poll_job_status(job=job, wait_for_status=wait_for_status) | ||
|
|
||
|
|
||
| def wait_for( | ||
| job: JobRef, | ||
| wait_for_status: JobStatusEnum = JobStatusEnum.COMPLETED, | ||
| timeout: float | None = None, | ||
| strategy: WaitStrategy = WaitStrategy.AUTO, | ||
| websocket_timeout: float = 600.0, | ||
| ) -> JobStatus: | ||
| """Check job status until the job is complete (or a specified status). | ||
|
|
||
| Args: | ||
| job: The job to monitor. | ||
| wait_for_status: The status to wait for (default: COMPLETED). | ||
| timeout: Overall timeout in seconds. None for no timeout (default: None). | ||
| strategy: How to monitor the job: | ||
| - WEBSOCKET: Real-time updates via websocket. Best for short jobs (<10 minutes). | ||
| - POLLING: Exponential backoff polling. Robust for long jobs (>10 minutes). | ||
| - AUTO: Websocket first, then polling fallback (default). | ||
| Recommended for most use cases. | ||
| websocket_timeout: For AUTO strategy, how long to use websocket | ||
| before switching to polling (default: 600 seconds). | ||
|
|
||
| Returns: | ||
| The final JobStatus. | ||
|
|
||
| Raises: | ||
| JobError: If the job errors, is cancelled, depleted, or terminated | ||
| (unless that was the status being waited for). | ||
| asyncio.TimeoutError: If the overall timeout is exceeded. | ||
| """ | ||
| logger.debug( | ||
| "Waiting for job %s with strategy=%s, timeout=%s, target=%s", | ||
| job.id, | ||
| strategy.value, | ||
| timeout, | ||
| wait_for_status.value, | ||
| ) | ||
|
|
||
| match strategy: | ||
| case WaitStrategy.WEBSOCKET: | ||
| coro = listen_job_status(job=job, wait_for_status=wait_for_status) | ||
| case WaitStrategy.POLLING: | ||
| coro = poll_job_status(job=job, wait_for_status=wait_for_status) | ||
| case WaitStrategy.AUTO: | ||
| coro = hybrid_wait( | ||
| job=job, | ||
| wait_for_status=wait_for_status, | ||
| websocket_timeout=websocket_timeout, | ||
| ) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of an
|
||
| case _: | ||
| assert_never(strategy) | ||
|
|
||
| if timeout is not None: | ||
| coro = asyncio.wait_for(coro, timeout=timeout) | ||
|
|
||
| job_status = asyncio.run(coro) | ||
| logger.info("Job %s finished with status: %s", job.id, job_status.status.value) | ||
|
|
||
| if ( | ||
| job_status.status == JobStatusEnum.ERROR | ||
| and wait_for_status != JobStatusEnum.ERROR | ||
|
|
@@ -387,7 +554,6 @@ def status(job: JobRef, scope: ScopeFilterEnum = ScopeFilterEnum.USER) -> JobSta | |
| message=resp.text, status_code=resp.status_code | ||
| ) | ||
| job_status = JobStatus.from_dict(resp.json()) | ||
| # job.last_status = job_status.status | ||
| return job_status | ||
|
|
||
|
|
||
|
|
@@ -397,7 +563,7 @@ async def listen_job_status( | |
| """Check the Status of a Job via a websocket connection. | ||
| Will use SSO tokens.""" | ||
| job_status = status(job) | ||
| # logger.debug("Current job status: %s", job_status.status) | ||
| logger.debug("Job %s initial status: %s", job.id, job_status.status.value) | ||
| if job_status.status not in WAITING_STATUS or job_status.status == wait_for_status: | ||
| return job_status | ||
|
|
||
|
|
@@ -418,17 +584,20 @@ def _process_exception(exc: Exception) -> Exception | None: | |
| # TODO, this cookie will expire frequently | ||
| "Cookie": f"myqos_id={get_nexus_client().auth.cookies.get('myqos_id')}" # type: ignore | ||
| } | ||
| logger.debug("Job %s: opening websocket connection", job.id) | ||
| async for websocket in connect( | ||
| f"{CONFIG.websockets_url}/api/jobs/v1beta3/{job.id}/attributes/status/ws", | ||
| ssl=ssl_context, | ||
| additional_headers=additional_headers, | ||
| process_exception=_process_exception, | ||
| # logger=logger, | ||
| logger=logger, | ||
| ): | ||
| try: | ||
| async for status_json in websocket: | ||
| # logger.debug("New status: %s", status_json) | ||
| job_status = JobStatus.from_dict(json.loads(status_json)) | ||
| logger.debug( | ||
| "Job %s websocket update: %s", job.id, job_status.status.value | ||
| ) | ||
|
|
||
| if ( | ||
| job_status.status not in WAITING_STATUS | ||
|
|
@@ -437,9 +606,9 @@ def _process_exception(exc: Exception) -> Exception | None: | |
| break | ||
| break | ||
| except ConnectionClosed: | ||
| # logger.debug( | ||
| # "Websocket connection closed... attempting to reconnect..." | ||
| # ) | ||
| logger.debug( | ||
| "Job %s: websocket connection closed, attempting to reconnect", job.id | ||
| ) | ||
| continue | ||
| finally: | ||
| try: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(related to my other comment about using subclasses for different strategies)
The implementation of
wait_forgives the user no way to choose different values for these params. They have to pass in anenummember and the code inwait_forpicks a function, calling it with just the required args, so these'll always get defaults.With a class-based approach, it could be:
and then a user could (if they chose) do:
and they could pass that to
wait_for