4 changes: 4 additions & 0 deletions pyproject.toml
@@ -21,8 +21,12 @@ dependencies = [
     "cloudpickle>=3.1.1",
     "runpod",
     "python-dotenv>=1.0.0",
+    "rich>=14.0.0",
 ]
 
+[project.optional-dependencies]
+rich = ["rich>=13.0.0"]
+
 [dependency-groups]
 dev = [
     "mypy>=1.16.1",
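
Note that rich ends up pinned twice here: as a core dependency (>=14.0.0) and as an optional `rich` extra (>=13.0.0). Either way, the rich_ui helpers used throughout this PR have to degrade gracefully when Rich is absent. A minimal sketch of that import-time guard, assuming nothing about the real rich_ui module (RICH_AVAILABLE and _console are illustrative names, not code from this PR):

    # sketch: guarding an optional Rich dependency at import time
    try:
        from rich.console import Console

        RICH_AVAILABLE = True
        _console = Console()
    except ImportError:
        RICH_AVAILABLE = False
        _console = None


    def is_rich_enabled() -> bool:
        # The real helper might also consult an env var or user setting (assumption).
        return RICH_AVAILABLE


    def print_with_rich(renderable) -> None:
        # No-op fallback keeps call sites free of their own availability checks.
        if _console is not None:
            _console.print(renderable)
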
2 changes: 2 additions & 0 deletions src/tetra_rp/__init__.py
@@ -22,10 +22,12 @@
     runpod,
     NetworkVolume,
 )
+from .core.utils.rich_ui import capture_local_prints  # noqa: E402
 
 
 __all__ = [
     "remote",
+    "capture_local_prints",
     "CpuServerlessEndpoint",
     "CpuInstanceType",
     "CudaVersion",
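
Re-exporting the helper at package level makes it importable straight from tetra_rp; the diff does not show whether it is a context manager or a decorator, so only the import path itself is certain:

    # The export above makes this import work; the usage shape is not shown in the diff.
    from tetra_rp import capture_local_prints
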
16 changes: 13 additions & 3 deletions src/tetra_rp/core/api/runpod.py
@@ -9,9 +9,12 @@
 from typing import Any, Dict, Optional
 
 import aiohttp
+from ..utils.rich_ui import format_endpoint_created
+
+
 log = logging.getLogger(__name__)
 
 
 RUNPOD_API_BASE_URL = os.environ.get("RUNPOD_API_BASE_URL", "https://api.runpod.io")
 RUNPOD_REST_API_URL = os.environ.get("RUNPOD_REST_API_URL", "https://rest.runpod.io/v1")
 
@@ -126,9 +129,16 @@ async def create_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
             raise Exception("Unexpected GraphQL response structure")
 
         endpoint_data = result["saveEndpoint"]
-        log.info(
-            f"Created endpoint: {endpoint_data.get('id', 'unknown')} - {endpoint_data.get('name', 'unnamed')}"
-        )
+
+        # Use Rich formatting if available
+        try:
+            format_endpoint_created(
+                endpoint_data.get("id", "unknown"), endpoint_data.get("name", "unnamed")
+            )
+        except ImportError:
+            log.info(
+                f"Created endpoint: {endpoint_data.get('id', 'unknown')} - {endpoint_data.get('name', 'unnamed')}"
+            )
 
         return endpoint_data
 
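
One subtlety: format_endpoint_created is imported at module top, so if rich_ui itself raised ImportError when Rich is missing, this module would fail to import and the except ImportError fallback above would never run. For the fallback to be reachable, the helper presumably defers its Rich import to call time. A sketch of that shape (illustrative, not the actual rich_ui code):

    # sketch: defer the Rich import so callers can catch ImportError at call time
    def format_endpoint_created(endpoint_id: str, name: str) -> None:
        from rich.console import Console  # raises ImportError here, not at module import

        Console().print(f"[green]Created endpoint:[/green] {name} ({endpoint_id})")
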
24 changes: 19 additions & 5 deletions src/tetra_rp/core/resources/resource_manager.py
@@ -4,6 +4,12 @@
 from pathlib import Path
 
 from ..utils.singleton import SingletonMixin
+from ..utils.rich_ui import (
+    create_reused_resource_panel,
+    format_console_url,
+    is_rich_enabled,
+    print_with_rich,
+)
 
 from .base import DeployableResource
 
@@ -29,7 +35,8 @@ def _load_resources(self) -> Dict[str, DeployableResource]:
         try:
             with open(RESOURCE_STATE_FILE, "rb") as f:
                 self._resources = cloudpickle.load(f)
-            log.debug(f"Loaded saved resources from {RESOURCE_STATE_FILE}")
+            if not is_rich_enabled():
+                log.debug(f"Loaded saved resources from {RESOURCE_STATE_FILE}")
         except Exception as e:
             log.error(f"Failed to load resources from {RESOURCE_STATE_FILE}: {e}")
         return self._resources
@@ -38,7 +45,8 @@ def _save_resources(self) -> None:
         """Persist state of resources to disk using cloudpickle."""
         with open(RESOURCE_STATE_FILE, "wb") as f:
             cloudpickle.dump(self._resources, f)
-        log.debug(f"Saved resources in {RESOURCE_STATE_FILE}")
+        if not is_rich_enabled():
+            log.debug(f"Saved resources in {RESOURCE_STATE_FILE}")
 
     def add_resource(self, uid: str, resource: DeployableResource):
         """Add a resource to the manager."""
@@ -68,12 +76,18 @@ async def get_or_deploy_resource(
                 self.remove_resource(uid)
                 return await self.get_or_deploy_resource(config)
 
-            log.debug(f"{existing} exists, reusing.")
-            log.info(f"URL: {existing.url}")
+            if is_rich_enabled():
+                # Show a panel for reused resources similar to fresh deployments
+                panel = create_reused_resource_panel(
+                    existing.name, existing.id, existing.url
+                )
+                print_with_rich(panel)
+            else:
+                log.debug(f"{existing} exists, reusing.")
             return existing
 
         if deployed_resource := await config.deploy():
-            log.info(f"URL: {deployed_resource.url}")
+            format_console_url(deployed_resource.url)
             self.add_resource(uid, deployed_resource)
             return deployed_resource
 
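
The panel helpers themselves are not shown in this diff; from the call site above, create_reused_resource_panel takes a name, an id, and a URL and returns something print_with_rich can render. A plausible sketch built on rich.panel.Panel (title text, styling, and layout are guesses):

    from rich.panel import Panel


    def create_reused_resource_panel(name: str, resource_id: str, url: str) -> Panel:
        # Rich markup: bold name, clickable URL in terminals that support links
        body = f"[bold]{name}[/bold]\nID:  {resource_id}\nURL: [link={url}]{url}[/link]"
        return Panel(body, title="Reusing existing resource", border_style="cyan")
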
104 changes: 71 additions & 33 deletions src/tetra_rp/core/resources/serverless.py
@@ -10,6 +10,14 @@
     field_validator,
     model_validator,
 )
+
+from tetra_rp.core.utils.rich_ui import (
+    job_progress_tracker,
+    create_deployment_panel,
+    create_metrics_table,
+    print_with_rich,
+    is_rich_enabled,
+)
 from runpod.endpoint.runner import Job
 
 from ..api.runpod import RunpodGraphQLClient
@@ -22,6 +30,7 @@
 from .gpu import GpuGroup
 from .network_volume import NetworkVolume
 from .template import KeyValuePair, PodTemplate
+from ..utils.rich_ui import rich_ui, format_api_info, format_job_status
 
 
 # Environment variables are loaded from the .env file
@@ -153,7 +162,7 @@ def validate_gpus(cls, value: List[GpuGroup]) -> List[GpuGroup]:
     @model_validator(mode="after")
     def sync_input_fields(self):
         """Sync between temporary inputs and exported fields"""
-        if self.flashboot:
+        if self.flashboot and not self.name.endswith("-fb"):
             self.name += "-fb"
 
         if self.networkVolume and self.networkVolume.is_created:
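
The endswith guard makes the suffix idempotent: a model_validator(mode="after") can run again when the model is re-validated or rebuilt from persisted state, and without the guard each pass would append another "-fb". Roughly (illustrative values, not code from the PR):

    name = "worker"
    for _ in range(2):  # validator runs twice, e.g. on re-validation
        if not name.endswith("-fb"):
            name += "-fb"
    assert name == "worker-fb"  # the unguarded version would yield "worker-fb-fb"
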
@@ -243,6 +252,13 @@ async def deploy(self) -> "DeployableResource":
             result = await client.create_endpoint(payload)
 
             if endpoint := self.__class__(**result):
+                if is_rich_enabled():
+                    panel = create_deployment_panel(
+                        endpoint.name, endpoint.id or "", endpoint.url
+                    )
+                    print_with_rich(panel)
+                else:
+                    log.info(f"Deployed: {endpoint}")
                 return endpoint
 
             raise ValueError("Deployment failed, no endpoint was returned.")
@@ -292,44 +308,60 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput":
             # log.debug(f"[{self}] Payload: {payload}")
 
             # Create a job using the endpoint
-            log.info(f"{self} | API /run")
+            format_api_info(str(self), self.id or "", "/run")
             job = await asyncio.to_thread(self.endpoint.run, request_input=payload)
 
             log_subgroup = f"Job:{job.job_id}"
 
-            log.info(f"{self} | Started {log_subgroup}")
-
-            current_pace = 0
-            attempt = 0
-            job_status = Status.UNKNOWN
-            last_status = job_status
+            # Use Rich progress tracker if available
+            with job_progress_tracker(job.job_id, self.name) as tracker:
+                if not is_rich_enabled():
+                    log.info(f"{self} | Started {log_subgroup}")
+                elif tracker:
+                    # Initialize the progress tracker with starting status
+                    tracker.update_status(
+                        "IN_QUEUE", "Job submitted, waiting for worker..."
+                    )
+
+                current_pace = 0
+                attempt = 0
+                job_status = Status.UNKNOWN
+                last_status = job_status
 
-            # Poll for job status
-            while True:
-                await asyncio.sleep(current_pace)
+                # Poll for job status
+                while True:
+                    await asyncio.sleep(current_pace)
 
                     # Check job status
                     job_status = await asyncio.to_thread(job.status)
 
-                if last_status == job_status:
-                    # nothing changed, increase the gap
-                    attempt += 1
-                    indicator = "." * (attempt // 2) if attempt % 2 == 0 else ""
-                    if indicator:
-                        log.info(f"{log_subgroup} | {indicator}")
-                else:
-                    # status changed, reset the gap
-                    log.info(f"{log_subgroup} | Status: {job_status}")
-                    attempt = 0
-
-                last_status = job_status
-
-                # Adjust polling pace appropriately
-                current_pace = get_backoff_delay(attempt)
-
-                if job_status in ("COMPLETED", "FAILED", "CANCELLED"):
-                    response = await asyncio.to_thread(job._fetch_job)
-                    return JobOutput(**response)
+                    if last_status == job_status:
+                        # nothing changed, increase the gap
+                        attempt += 1
+                        if tracker:
+                            tracker.show_progress_indicator()
+                        else:
+                            if not rich_ui.enabled:
+                                indicator = (
+                                    "." * (attempt // 2) if attempt % 2 == 0 else ""
+                                )
+                                if indicator:
+                                    log.info(f"{log_subgroup} | {indicator}")
+                    else:
+                        # status changed, reset the gap
+                        format_job_status(job.job_id, job_status)
+                        attempt = 0
+
+                    last_status = job_status
+
+                    # Adjust polling pace appropriately
+                    current_pace = get_backoff_delay(attempt)
+
+                    if job_status in ("COMPLETED", "FAILED", "CANCELLED"):
+                        if tracker:
+                            tracker.update_status(job_status)
+                        response = await asyncio.to_thread(job._fetch_job)
+                        return JobOutput(**response)
 
         except Exception as e:
             if job and job.job_id:
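
The hunk above relies on a specific contract from job_progress_tracker: it is a context manager that yields a tracker object when Rich is active and None otherwise, so `if tracker:` cleanly selects the plain-logging path. A sketch of that contract built on rich.progress (the _Tracker class, its method bodies, and RICH_ENABLED are assumptions; only the yields-None convention is visible in the diff):

    from contextlib import contextmanager

    from rich.progress import Progress, SpinnerColumn, TextColumn

    RICH_ENABLED = True  # stand-in for rich_ui.is_rich_enabled()


    class _Tracker:
        def __init__(self, progress: Progress, task_id) -> None:
            self._progress = progress
            self._task_id = task_id

        def update_status(self, status: str, detail: str = "") -> None:
            self._progress.update(self._task_id, description=f"{status} {detail}".strip())

        def show_progress_indicator(self) -> None:
            self._progress.advance(self._task_id)  # keeps the spinner visibly alive


    @contextmanager
    def job_progress_tracker(job_id: str, endpoint_name: str):
        if not RICH_ENABLED:
            yield None  # callers fall back to plain log lines
            return
        with Progress(SpinnerColumn(), TextColumn("{task.description}")) as progress:
            task_id = progress.add_task(f"{endpoint_name} | Job:{job_id}", total=None)
            yield _Tracker(progress, task_id)
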
@@ -390,9 +422,15 @@ class JobOutput(BaseModel):
     error: Optional[str] = ""
 
     def model_post_init(self, __context):
-        log_group = f"Worker:{self.workerId}"
-        log.info(f"{log_group} | Delay Time: {self.delayTime} ms")
-        log.info(f"{log_group} | Execution Time: {self.executionTime} ms")
+        if is_rich_enabled():
+            metrics_table = create_metrics_table(
+                self.delayTime, self.executionTime, self.workerId
+            )
+            print_with_rich(metrics_table)
+        else:
+            log_group = f"Worker:{self.workerId}"
+            log.info(f"{log_group} | Delay Time: {self.delayTime} ms")
+            log.info(f"{log_group} | Execution Time: {self.executionTime} ms")
 
 
 class Status(str, Enum):
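
From its call site, create_metrics_table takes the delay, the execution time, and the worker id and returns a Rich renderable. A plausible sketch with rich.table.Table (column names and styling are guesses):

    from rich.table import Table


    def create_metrics_table(delay_ms, execution_ms, worker_id) -> Table:
        table = Table(title=f"Worker:{worker_id}")
        table.add_column("Metric")
        table.add_column("Value", justify="right")
        table.add_row("Delay Time", f"{delay_ms} ms")
        table.add_row("Execution Time", f"{execution_ms} ms")
        return table
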