diff --git a/qnexus/client/jobs/__init__.py b/qnexus/client/jobs/__init__.py index bd281ff..ac56558 100644 --- a/qnexus/client/jobs/__init__.py +++ b/qnexus/client/jobs/__init__.py @@ -1,10 +1,11 @@ """Client API for jobs in Nexus.""" +import abc import asyncio import json import logging import ssl -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime, timezone from enum import Enum from typing import Any, Type, Union, cast, overload @@ -84,15 +85,90 @@ class RemoteRetryStrategy(str, Enum): @dataclass -class WebsocketStrategy: +class WaitStrategy(abc.ABC): + wait_for_status: JobStatusEnum = JobStatusEnum.COMPLETED + + @abc.abstractmethod + async def get_status(self, job: JobRef) -> JobStatus: + pass + + def _finished(self, job_status: JobStatus) -> bool: + return ( + job_status.status not in WAITING_STATUS + or job_status.status == self.wait_for_status + ) + + +@dataclass +class WebsocketStrategy(WaitStrategy): """Use a websocket connection for real-time updates. Best for short-running jobs (<10 minutes). """ + async def get_status(self, job: JobRef) -> JobStatus: + """Check the Status of a Job via a websocket connection. + Will use SSO tokens.""" + + job_status = status(job) + logger.debug("Job %s initial status: %s", job.id, job_status.status.value) + if self._finished(job_status): + return job_status + + ssl_context = httpx.create_ssl_context(verify=CONFIG.httpx_verify) + + def _process_exception(exc: Exception) -> Exception | None: + """Utility wrapper around process_exception that tells the websockets + library not to auto-retry SSLErrors as they are usually not recoverable. + + Unfortunately SSLError inherits from OSError which websockets will always + retry when `connect` is used in an async for loop. 
+ """ + if isinstance(exc, ssl.SSLError): + return exc + return process_exception(exc) + + additional_headers = { + # TODO, this cookie will expire frequently + "Cookie": f"myqos_id={get_nexus_client().auth.cookies.get('myqos_id')}" # type: ignore + } + logger.debug("Job %s: opening websocket connection", job.id) + async for websocket in connect( + f"{CONFIG.websockets_url}/api/jobs/v1beta3/{job.id}/attributes/status/ws", + ssl=ssl_context, + additional_headers=additional_headers, + process_exception=_process_exception, + logger=logger, + ): + try: + async for status_json in websocket: + job_status = JobStatus.from_dict(json.loads(status_json)) + logger.debug( + "Job %s websocket update: %s", + job.id, + job_status.status.value, + ) + + if self._finished(job_status): + break + break + except ConnectionClosed: + logger.debug( + "Job %s: websocket connection closed, attempting to reconnect", + job.id, + ) + continue + finally: + try: + await websocket.close(code=1000, reason="Client closed connection") + except GeneratorExit: + pass + + return job_status + @dataclass -class PollingStrategy: +class PollingStrategy(WaitStrategy): """Use exponential backoff polling. More robust for long-running jobs (>10 minutes). @@ -109,23 +185,105 @@ class PollingStrategy: max_interval_running: float = 180.0 backoff_factor: float = 2.0 + async def get_status(self, job: JobRef) -> JobStatus: + """Poll job status with exponential backoff and adaptive intervals. + + Uses different maximum poll intervals based on job state: + - QUEUED: Polls less frequently (default 20 min) since queue position changes slowly + - RUNNING/SUBMITTED: Polls more frequently (default 3 min) for responsiveness + + Args: + job: The job to monitor. The target status and the polling + configuration are not arguments; they are read from this strategy + instance (e.g. self.wait_for_status, self.initial_interval). + + Returns: + The final JobStatus when the target status is reached or job terminates. 
+ """ + interval = self.initial_interval + logger.debug( + "Starting polling for job %s (target: %s, interval: %.1fs, " + "max queued: %.1fs, max running: %.1fs)", + job.id, + self.wait_for_status.value, + self.initial_interval, + self.max_interval_queued, + self.max_interval_running, + ) + + while True: + job_status = status(job) + + # Adapt max interval based on job state + if job_status.status == JobStatusEnum.QUEUED: + max_interval = self.max_interval_queued + else: + max_interval = self.max_interval_running + + # Clamp interval to current max (allows faster polling when transitioning + # from QUEUED to RUNNING) + interval = min(interval, max_interval) + + logger.debug( + "Job %s status: %s (next poll in %.1fs, max: %.1fs)", + job.id, + job_status.status.value, + interval, + max_interval, + ) + + if self._finished(job_status): + logger.debug( + "Job %s reached status: %s", job.id, job_status.status.value + ) + return job_status + + await asyncio.sleep(interval) + interval = min(interval * self.backoff_factor, max_interval) + @dataclass -class HybridStrategy: +class HybridStrategy(WebsocketStrategy, PollingStrategy): """Start with websocket, fall back to polling. Recommended for most use cases. Attributes: websocket_timeout: How long to use websocket before switching to polling. - polling: Configuration for the polling fallback. """ websocket_timeout: float = 600.0 - polling: PollingStrategy = field(default_factory=PollingStrategy) + async def get_status(self, job: JobRef) -> JobStatus: + """Use websocket for initial period, then fall back to polling. -WaitStrategy = WebsocketStrategy | PollingStrategy | HybridStrategy + Args: + job: The job to monitor. The target status and the websocket + timeout are not arguments; they are read from this strategy + instance (self.wait_for_status, self.websocket_timeout). + + Returns: + The final JobStatus when the target status is reached or job terminates. 
+ """ + logger.debug( + "Using hybrid strategy for job %s (websocket timeout: %.1fs)", + job.id, + self.websocket_timeout, + ) + try: + # Try websocket first with a timeout + return await asyncio.wait_for( + WebsocketStrategy.get_status(self, job), + timeout=self.websocket_timeout, + ) + except asyncio.TimeoutError: + # Websocket phase timed out, switch to polling + logger.debug( + "Job %s: websocket timeout after %.1fs, switching to polling", + job.id, + self.websocket_timeout, + ) + return await PollingStrategy.get_status(self, job) class Params( @@ -386,112 +544,11 @@ def _fetch_by_id( ) -async def poll_job_status( - job: JobRef, - wait_for_status: JobStatusEnum = JobStatusEnum.COMPLETED, - strategy: PollingStrategy = PollingStrategy(), -) -> JobStatus: - """Poll job status with exponential backoff and adaptive intervals. - - Uses different maximum poll intervals based on job state: - - QUEUED: Polls less frequently (default 20 min) since queue position changes slowly - - RUNNING/SUBMITTED: Polls more frequently (default 3 min) for responsiveness - - Args: - job: The job to monitor. - wait_for_status: The status to wait for. - strategy: Polling configuration. - - Returns: - The final JobStatus when the target status is reached or job terminates. 
- """ - interval = strategy.initial_interval - logger.debug( - "Starting polling for job %s (target: %s, interval: %.1fs, " - "max queued: %.1fs, max running: %.1fs)", - job.id, - wait_for_status.value, - strategy.initial_interval, - strategy.max_interval_queued, - strategy.max_interval_running, - ) - - while True: - job_status = status(job) - - # Adapt max interval based on job state - if job_status.status == JobStatusEnum.QUEUED: - max_interval = strategy.max_interval_queued - else: - max_interval = strategy.max_interval_running - - # Clamp interval to current max (allows faster polling when transitioning - # from QUEUED to RUNNING) - interval = min(interval, max_interval) - - logger.debug( - "Job %s status: %s (next poll in %.1fs, max: %.1fs)", - job.id, - job_status.status.value, - interval, - max_interval, - ) - - if ( - job_status.status not in WAITING_STATUS - or job_status.status == wait_for_status - ): - logger.debug("Job %s reached status: %s", job.id, job_status.status.value) - return job_status - - await asyncio.sleep(interval) - interval = min(interval * strategy.backoff_factor, max_interval) - - -async def hybrid_wait( - job: JobRef, - wait_for_status: JobStatusEnum = JobStatusEnum.COMPLETED, - strategy: HybridStrategy = HybridStrategy(), -) -> JobStatus: - """Use websocket for initial period, then fall back to polling. - - Args: - job: The job to monitor. - wait_for_status: The status to wait for. - strategy: Hybrid strategy configuration. - - Returns: - The final JobStatus when the target status is reached or job terminates. 
- """ - logger.debug( - "Using hybrid strategy for job %s (websocket timeout: %.1fs)", - job.id, - strategy.websocket_timeout, - ) - try: - # Try websocket first with a timeout - job_status = await asyncio.wait_for( - listen_job_status(job=job, wait_for_status=wait_for_status), - timeout=strategy.websocket_timeout, - ) - return job_status - except asyncio.TimeoutError: - # Websocket phase timed out, switch to polling - logger.debug( - "Job %s: websocket timeout after %.1fs, switching to polling", - job.id, - strategy.websocket_timeout, - ) - return await poll_job_status( - job=job, wait_for_status=wait_for_status, strategy=strategy.polling - ) - - def wait_for( job: JobRef, wait_for_status: JobStatusEnum = JobStatusEnum.COMPLETED, timeout: float | None = None, - strategy: WaitStrategy = HybridStrategy(), + strategy: WaitStrategy | None = None, ) -> JobStatus: """Check job status until the job is complete (or a specified status). @@ -528,6 +585,9 @@ def wait_for( polling=PollingStrategy(max_interval_running=60.0) )) """ + if strategy is None: + strategy = HybridStrategy() + logger.debug( "Waiting for job %s with strategy=%s, timeout=%s, target=%s", job.id, @@ -536,19 +596,7 @@ def wait_for( wait_for_status.value, ) - match strategy: - case WebsocketStrategy(): - coro = listen_job_status(job=job, wait_for_status=wait_for_status) - case PollingStrategy(): - coro = poll_job_status( - job=job, wait_for_status=wait_for_status, strategy=strategy - ) - case HybridStrategy(): - coro = hybrid_wait( - job=job, wait_for_status=wait_for_status, strategy=strategy - ) - case _: - assert_never(strategy) + coro = strategy.get_status(job) if timeout is not None: coro = asyncio.wait_for(coro, timeout=timeout) @@ -595,68 +643,6 @@ def status(job: JobRef, scope: ScopeFilterEnum = ScopeFilterEnum.USER) -> JobSta return job_status -async def listen_job_status( - job: JobRef, wait_for_status: JobStatusEnum = JobStatusEnum.COMPLETED -) -> JobStatus: - """Check the Status of a Job via a 
websocket connection. - Will use SSO tokens.""" - job_status = status(job) - logger.debug("Job %s initial status: %s", job.id, job_status.status.value) - if job_status.status not in WAITING_STATUS or job_status.status == wait_for_status: - return job_status - - ssl_context = httpx.create_ssl_context(verify=CONFIG.httpx_verify) - - def _process_exception(exc: Exception) -> Exception | None: - """Utility wrapper around process_exception that tells the websockets - library not to auto-retry SSLErrors as they are usually not recoverable. - - Unfortunately SSLError inherits from OSError which websockets will always - retried when `connect` is used in an async for loop. - """ - if isinstance(exc, ssl.SSLError): - return exc - return process_exception(exc) - - additional_headers = { - # TODO, this cookie will expire frequently - "Cookie": f"myqos_id={get_nexus_client().auth.cookies.get('myqos_id')}" # type: ignore - } - logger.debug("Job %s: opening websocket connection", job.id) - async for websocket in connect( - f"{CONFIG.websockets_url}/api/jobs/v1beta3/{job.id}/attributes/status/ws", - ssl=ssl_context, - additional_headers=additional_headers, - process_exception=_process_exception, - logger=logger, - ): - try: - async for status_json in websocket: - job_status = JobStatus.from_dict(json.loads(status_json)) - logger.debug( - "Job %s websocket update: %s", job.id, job_status.status.value - ) - - if ( - job_status.status not in WAITING_STATUS - or job_status.status == wait_for_status - ): - break - break - except ConnectionClosed: - logger.debug( - "Job %s: websocket connection closed, attempting to reconnect", job.id - ) - continue - finally: - try: - await websocket.close(code=1000, reason="Client closed connection") - except GeneratorExit: - pass - - return job_status - - @merge_scope_from_context @overload def results(