103 changes: 81 additions & 22 deletions torchx/schedulers/slurm_scheduler.py
@@ -570,6 +570,8 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
return self._describe_sacct(app_id)

def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
# NOTE: Handles multiple job ID formats due to SLURM version differences.
# Different clusters use heterogeneous (+) vs regular (.) job ID formats.
try:
output = subprocess.check_output(
["sacct", "--parsable2", "-j", app_id],
@@ -594,15 +596,27 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
msg = ""
app_state = AppState.UNKNOWN
for row in reader:
job_id, *parts = row["JobID"].split("+")
# Handle both "+" (heterogeneous) and "." (regular) job ID formats
job_id_full = row["JobID"]

# Split on both "+" and "." to handle different SLURM configurations
if "+" in job_id_full:
job_id, *parts = job_id_full.split("+")
is_subjob = len(parts) > 0 and "." in parts[0]
else:
job_id, *parts = job_id_full.split(".")
is_subjob = len(parts) > 0

if job_id != app_id:
continue
if len(parts) > 0 and "." in parts[0]:
# we only care about the worker not the child jobs

if is_subjob:
# we only care about the main job not the child jobs (.batch, .0, etc.)
continue

state = row["State"]
msg = state
msg = row["State"]
# Remove truncation indicator (CANCELLED+) and extract base state from verbose formats
state = msg.split()[0].rstrip("+")
app_state = appstate_from_slurm_state(state)

role, _, replica_id = row["JobName"].rpartition("-")
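
For reference, a minimal standalone sketch of the JobID/State parsing that the hunk above implements. The helper names (split_job_id, parse_state) are illustrative only and are not part of the TorchX code; they just mirror the branching shown in the diff:

from typing import Tuple


def split_job_id(job_id_full: str) -> Tuple[str, bool]:
    """Return (main_job_id, is_subjob) for a sacct JobID value.

    Mirrors the branching above: heterogeneous IDs ("89+0", "89+0.batch")
    split on "+", regular IDs ("89", "89.batch", "89.0") split on ".".
    """
    if "+" in job_id_full:
        job_id, *parts = job_id_full.split("+")
        # "89+0" is a het-job component; only "89+0.batch"-style rows are sub-jobs
        return job_id, len(parts) > 0 and "." in parts[0]
    job_id, *parts = job_id_full.split(".")
    # anything after "." ("batch", "0", ...) is a job step, i.e. a sub-job
    return job_id, len(parts) > 0


def parse_state(state_field: str) -> str:
    """Strip verbose suffixes: 'CANCELLED by 2166' -> 'CANCELLED', 'CANCELLED+' -> 'CANCELLED'."""
    return state_field.split()[0].rstrip("+")


assert split_job_id("89") == ("89", False)
assert split_job_id("89.batch") == ("89", True)
assert split_job_id("89+0") == ("89", False)
assert split_job_id("89+0.batch") == ("89", True)
assert parse_state("CANCELLED by 2166") == "CANCELLED"
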
@@ -629,6 +643,9 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
)

def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
# NOTE: This method contains multiple compatibility checks for different SLURM versions
# due to API format changes across versions (20.02, 23.02, 24.05, 24.11+).

# squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
# if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
output = subprocess.check_output(
@@ -670,7 +687,18 @@ def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
if state == AppState.PENDING:
# NOTE: torchx-launched jobs point to exactly one host
# otherwise, scheduled_nodes could be a node list expression (eg. 'slurm-compute-node[0-20,21,45-47]')
hostname = job_resources.get("scheduled_nodes", "")

# SLURM 24.11.5+ returns job_resources=None for pending jobs (issue #1101)
if job_resources is not None:
hostname = job_resources.get("scheduled_nodes", "")
# If scheduled_nodes not found in job_resources, try nodes.list
if not hostname and "nodes" in job_resources:
nodes_info = job_resources.get("nodes", {})
if isinstance(nodes_info, dict):
hostname = nodes_info.get("list", "")
else:
# For pending jobs where job_resources is None, check top-level fields
hostname = job.get("nodes", "") or job.get("scheduled_nodes", "")

role.num_replicas += 1
role_status.replicas.append(
@@ -686,24 +714,35 @@ def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
# where each replica is a "sub-job" so `allocated_nodes` will always be 1
# but we deal with jobs that have not been launched with torchx
# which can have multiple hosts per sub-job (count them as replicas)
node_infos = job_resources.get("allocated_nodes", [])
nodes_data = job_resources.get("nodes", {})

# SLURM 24.11+ changed from allocated_nodes to nodes.allocation structure
if "allocation" in nodes_data and isinstance(
nodes_data["allocation"], list
):
# SLURM 24.11+ format: nodes.allocation is a list
for node_info in nodes_data["allocation"]:
hostname = node_info["name"]
cpu = int(node_info["cpus"]["used"])
memMB = (
int(node_info["memory"]["allocated"]) // 1024
) # Convert to MB

if not isinstance(node_infos, list):
# NOTE: in some versions of slurm jobs[].job_resources.allocated_nodes
# is not a list of individual nodes, but a map of the nodelist specs
# in this case just use jobs[].job_resources.nodes
hostname = job_resources.get("nodes")
role.num_replicas += 1
role_status.replicas.append(
ReplicaStatus(
id=int(replica_id),
role=role_name,
state=state,
hostname=hostname,
role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
role.num_replicas += 1
role_status.replicas.append(
ReplicaStatus(
id=int(replica_id),
role=role_name,
state=state,
hostname=hostname,
)
)
)
else:
for node_info in node_infos:
elif "allocated_nodes" in job_resources and isinstance(
job_resources["allocated_nodes"], list
):
# Legacy format: allocated_nodes is a list
for node_info in job_resources["allocated_nodes"]:
# NOTE: we expect resource specs for all the nodes to be the same
# NOTE: use allocated (not used/requested) memory since
# users may only specify --cpu, in which case slurm
@@ -726,6 +765,26 @@ def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
hostname=hostname,
)
)
else:
# Fallback: use hostname from nodes.list
if isinstance(nodes_data, str):
hostname = nodes_data
else:
hostname = (
nodes_data.get("list", "")
if isinstance(nodes_data, dict)
else ""
)

role.num_replicas += 1
role_status.replicas.append(
ReplicaStatus(
id=int(replica_id),
role=role_name,
state=state,
hostname=hostname,
)
)

return DescribeAppResponse(
app_id=app_id,
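Taken together, the branches above cover the different shapes of jobs[].job_resources seen across SLURM releases. A condensed, hostname-only sketch of that dispatch, written outside the scheduler class for illustration (extract_hostnames is a hypothetical helper; the legacy allocated_nodes record layout is left out because its field names are truncated in this diff):

from typing import Any, Dict, List


def extract_hostnames(job: Dict[str, Any]) -> List[str]:
    """Best-effort hostname extraction across squeue --json payload variants."""
    job_resources = job.get("job_resources")
    if job_resources is None:
        # SLURM 24.11.5+ may return job_resources=None for pending jobs;
        # fall back to the top-level fields, as the diff above does.
        host = job.get("nodes", "") or job.get("scheduled_nodes", "")
        return [host] if host else []

    nodes = job_resources.get("nodes", {})
    if isinstance(nodes, dict) and isinstance(nodes.get("allocation"), list):
        # SLURM 24.11+ keeps per-node records under nodes.allocation
        return [node["name"] for node in nodes["allocation"]]

    if isinstance(nodes, str):
        # some versions return the node list directly as a string
        return [nodes] if nodes else []

    # last resort: the node-list expression under nodes.list
    return [nodes.get("list", "")] if isinstance(nodes, dict) else []


# e.g. the payload shapes exercised by the tests below:
assert extract_hostnames({"job_resources": None, "nodes": "slurm-compute-node-0"}) == ["slurm-compute-node-0"]
assert extract_hostnames({"job_resources": {"nodes": "compute-node-123"}}) == ["compute-node-123"]
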
86 changes: 83 additions & 3 deletions torchx/schedulers/test/slurm_scheduler_test.py
@@ -7,12 +7,13 @@
# pyre-strict

import datetime
import importlib
import json
import os
import subprocess
import tempfile
import unittest
from contextlib import contextmanager
from importlib import resources
from typing import Generator
from unittest.mock import call, MagicMock, patch

@@ -244,7 +245,6 @@ def test_dryrun_multi_role(self, mock_version: MagicMock) -> None:
)

script = req.materialize()
print(script)
self.assertEqual(
script,
f"""#!/bin/bash
@@ -455,7 +455,7 @@ def test_describe_sacct_running(

def test_describe_squeue(self) -> None:
with (
importlib.resources.path(__package__, "slurm-squeue-output.json") as path,
resources.path(__package__, "slurm-squeue-output.json") as path,
open(path) as fp,
):
mock_output = fp.read()
@@ -1048,3 +1048,83 @@ def test_no_gpu_resources(self) -> None:
).materialize()
self.assertNotIn("--gpus-per-node", " ".join(sbatch))
self.assertNotIn("--gpus-per-task", " ".join(sbatch))

def test_describe_squeue_handles_none_job_resources(self) -> None:
"""Test that describe handles job_resources=None without crashing (i.e. for SLURM 24.11.5)."""

# Mock SLURM 24.11.5 response with job_resources=None
mock_job_data = {
"jobs": [
{
"name": "test-job-0",
"job_state": ["PENDING"],
"job_resources": None, # This was causing the crash
"nodes": "",
"scheduled_nodes": "",
"command": "/bin/echo",
"current_working_directory": "/tmp",
}
]
}

with patch("subprocess.check_output") as mock_subprocess:
mock_subprocess.return_value = json.dumps(mock_job_data)

scheduler = SlurmScheduler("test")
result = scheduler._describe_squeue("123")

# Should not crash and should return a valid response
assert result is not None
assert result.app_id == "123"
assert result.state == AppState.PENDING

def test_describe_sacct_handles_dot_separated_job_ids(self) -> None:
"""Test that _describe_sacct handles job IDs with '.' separators (not just '+')."""
sacct_output = """JobID|JobName|Partition|Account|AllocCPUS|State|ExitCode
89|mesh0-0|all|root|8|CANCELLED by 2166|0:0
89.batch|batch||root|8|CANCELLED|0:15
89.0|process_allocator||root|8|CANCELLED|0:15
"""

with patch("subprocess.check_output") as mock_subprocess:
mock_subprocess.return_value = sacct_output

scheduler = SlurmScheduler("test")
result = scheduler._describe_sacct("89")

# Should process only the main job "89", not the sub-jobs
assert result is not None
assert result.app_id == "89"
assert result.state == AppState.CANCELLED
assert result.msg == "CANCELLED by 2166"

# Should have one role "mesh0" with one replica "0"
assert len(result.roles) == 1
assert result.roles[0].name == "mesh0"
assert result.roles[0].num_replicas == 1

def test_describe_squeue_nodes_as_string(self) -> None:
"""Test when job_resources.nodes is a string (hostname) not a dict."""
mock_job_data = {
"jobs": [
{
"name": "test-job-0",
"job_state": ["RUNNING"],
"job_resources": {
"nodes": "compute-node-123" # String, not dict
# No allocated_nodes field
},
"command": "/bin/echo",
"current_working_directory": "/tmp",
}
]
}

with patch("subprocess.check_output") as mock_subprocess:
mock_subprocess.return_value = json.dumps(mock_job_data)

scheduler = SlurmScheduler("test")
result = scheduler._describe_squeue("123")

assert result is not None
assert result.roles_statuses[0].replicas[0].hostname == "compute-node-123"
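
To see which of these payload shapes a given cluster actually produces, the raw squeue JSON can be inspected directly. A rough sketch, assuming squeue was built with JSON serializer support and that the job ID used here (123, a placeholder) exists:

import json
import subprocess

# Dump the raw payload for an existing job ID.
out = subprocess.check_output(["squeue", "--json", "-j", "123"])
job = json.loads(out)["jobs"][0]
resources = job.get("job_resources")
if resources is None:
    print("job_resources is None (pending job on SLURM 24.11.5+)")
else:
    nodes = resources.get("nodes")
    print("nodes field type:", type(nodes).__name__)
    print("has legacy allocated_nodes list:", isinstance(resources.get("allocated_nodes"), list))
    print("has 24.11+ nodes.allocation list:",
          isinstance(nodes, dict) and isinstance(nodes.get("allocation"), list))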