Skip to content
176 changes: 140 additions & 36 deletions uit/uit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from functools import wraps
from itertools import chain
from pathlib import PurePosixPath, Path
from enum import StrEnum, auto
from io import StringIO
from urllib.parse import urljoin, urlencode # noqa: F401

import param
Expand Down Expand Up @@ -44,6 +46,37 @@
_server = None


class BatchSystem(StrEnum):

PBS = auto()
SLURM = auto()


SYSTEMS = {"carpenter": BatchSystem.PBS, "nautilus": BatchSystem.SLURM}

COMMANDS = {
BatchSystem.PBS: {
"status": {
"command": "qstat",
"full": " -f",
"username": " -u",
"job_id": " -x",
},
"submit": "qsub",
"delete": "qdel",
},
BatchSystem.SLURM: {
"status": {
"command": "squeue -l",
"username": " -u",
"job_id": " -j ",
},
"submit": "sbatch",
"delete": "scancel",
},
}


class Client(param.Parameterized):
"""Provides a python abstraction for interacting with the UIT API.

Expand Down Expand Up @@ -101,6 +134,9 @@ def __init__(
self.scope = scope
self.port = port

self.batch_system = None
self.commands = None

if self.token is not None:
self.param.trigger("token")

Expand Down Expand Up @@ -287,6 +323,8 @@ def prepare_connect(self, system, login_node, exclude_login_nodes, retry_on_fail
self._username = self._userinfo["SYSTEMS"][self._system.upper()]["USERNAME"]
self._uit_url = self._uit_urls[login_node]
self.connected = True
self.batch_system = SYSTEMS[self.system]
self.commands = COMMANDS[self.batch_system]

return login_node, retry_on_failure

Expand Down Expand Up @@ -683,7 +721,7 @@ def show_usage(self, parse=True, as_df=False):
if not parse:
return result

return self._parse_hpc_output(result, as_df)
return self._parse_hpc_output(result, as_df, batch_system=self.batch_system)

@_ensure_connected
@robust()
Expand All @@ -698,17 +736,25 @@ def status(
):
username = username if username is not None else self.username

cmd = "qstat"
# cmd will either be "qstat" or "squeue"
cmd = self.commands["status"]["command"]

if full:
cmd += " -f"
elif username:
cmd += f" -u {username}"
if self.batch_system == BatchSystem.SLURM:
if username:
cmd += self.commands["status"]["username"]
cmd += f" {username}"
else:
if full:
cmd += self.commands["status"]["full"]
elif username:
cmd += self.commands["status"]["username"]
cmd += f" {username}"

if job_id:
if isinstance(job_id, (tuple, list)):
job_id = " ".join([j.split(".")[0] for j in job_id])
cmd += f" -x {job_id}"
cmd += self.commands["status"]["job_id"]
cmd += job_id
result = self.call(cmd)
return self._process_status_result(result, parse=parse, full=full, as_df=as_df)
else:
Expand All @@ -718,11 +764,12 @@ def status(
if not with_historic:
return result1
else:
cmd += " -x"
cmd += self.commands["status"]["job_id"]
result = self.call(cmd)
result2 = self._process_status_result(result, parse=parse, full=full, as_df=as_df)

if not parse:
if self.batch_system == BatchSystem.SLURM:
return pd.concat((result1, result2))
elif not parse:
return result1, result2
elif as_df:
return pd.concat((result1, result2))
Expand Down Expand Up @@ -768,7 +815,7 @@ def submit(self, pbs_script, working_dir=None, remote_name="run.pbs", local_temp

# Submit the script using call() with qsub command
try:
job_id = self.call(f"qsub {remote_name}", working_dir=working_dir)
job_id = self.call(f"{self.commands['submit']} {remote_name}", working_dir=working_dir)
except RuntimeError as e:
raise RuntimeError("An exception occurred while submitting job script: {}".format(str(e)))

Expand All @@ -780,21 +827,53 @@ def submit(self, pbs_script, working_dir=None, remote_name="run.pbs", local_temp
@_ensure_connected
def get_queues(self, update_cache=False):
    """Return the list of queue names for the connected system.

    The result is cached on the client; pass ``update_cache=True`` to
    force a fresh query of the scheduler.
    """
    # Serve from cache unless a refresh was requested.
    if self._queues is not None and not update_cache:
        return self._queues

    if self.batch_system == BatchSystem.SLURM:
        raw = self.call("sacctmgr show qos format=Name%20")
    else:
        raw = self.call("qstat -Q")
    self._queues = self._process_get_queues_output(raw)
    return self._queues

def _process_get_queues_output(self, output):
    """Combine the standard queue list with queues parsed from scheduler output.

    Skips the first two (header) lines of ``output``, takes the first
    whitespace-delimited token of each remaining line as a queue name,
    and returns the standard queues followed by any extra queues sorted
    alphabetically. Extra queues containing an underscore are excluded.
    """
    # Fixed: a stale duplicate assignment of standard_queues (special-casing
    # system "jim") was immediately overwritten and has been removed.
    standard_queues = QUEUES
    parsed_names = [line.split()[0] for line in output.splitlines()][2:]
    other_queues = set(parsed_names) - set(standard_queues)
    return standard_queues + sorted(q for q in other_queues if "_" not in q)

@_ensure_connected
def get_raw_queue_stats(self):
    """Return per-queue limit information for the connected system.

    For SLURM, queries ``sacctmgr show qos --json`` and flattens each QOS
    entry into a whitespace-delimited table parsed into a DataFrame.
    For PBS, returns the ``Queue`` mapping from ``qstat``'s JSON output.
    """
    if self.batch_system != BatchSystem.SLURM:
        return json.loads(self.call("qstat -Q -f -F json"))["Queue"]

    rows = ["id name max_walltime max_jobs max_nodes"]
    for qos in json.loads(self.call("sacctmgr show qos --json"))["QOS"]:
        limits = qos["limits"]["max"]
        walltime = str(limits["wall_clock"]["per"]["job"]["number"])
        jobs = str(limits["jobs"]["active_jobs"]["per"]["user"]["number"])
        # -1 marks "no node limit found"; a later consumer maps it to "Not Found".
        nodes = -1
        for tres in limits["tres"]["per"]["job"]:
            if tres["type"] == "node":
                nodes = tres["count"]
        rows.append(f"{qos['id']} {qos['name']} {walltime} {jobs} {nodes}")
    return self._parse_slurm_output("\n".join(rows))

@_ensure_connected
def get_node_maxes(self, queues, queues_stats):
    """Dispatch to the scheduler-specific per-queue node-maximum lookup."""
    handler = (
        self._slurm_node_maxes
        if self.batch_system == BatchSystem.SLURM
        else self._pbs_node_maxes
    )
    return handler(queues, queues_stats)

def _slurm_node_maxes(self, queues, queues_stats):
ncpus_maxes = dict()

for q in queues:
max_nodes = str(queues_stats.loc[queues_stats["name"] == f"{q.lower()}", "max_nodes"].iloc[0])
ncpus_maxes[q] = max_nodes if max_nodes != "-1" else "Not Found"

return ncpus_maxes

def _pbs_node_maxes(self, queues, queues_stats):
q_sts = {q: queues_stats[q] for q in queues if q in queues_stats.keys()}

ncpus_maxes = dict()
Expand All @@ -809,6 +888,21 @@ def get_node_maxes(self, queues, queues_stats):

@_ensure_connected
def get_wall_time_maxes(self, queues, queues_stats):
    """Dispatch to the scheduler-specific per-queue wall-time-maximum lookup."""
    if self.batch_system != BatchSystem.SLURM:
        return self._pbs_wall_time_maxes(queues, queues_stats)
    return self._slurm_wall_time_maxes(queues, queues_stats)

def _slurm_wall_time_maxes(self, queues, queues_stats):
wall_time_maxes = dict()

for q in queues:
max_walltimes = str(queues_stats.loc[queues_stats["name"] == f"{q.lower()}", "max_walltime"].iloc[0])
wall_time_maxes[q] = max_walltimes

return wall_time_maxes

def _pbs_wall_time_maxes(self, queues, queues_stats):
q_sts = {q: queues_stats[q] for q in queues if q in queues_stats.keys()}

wall_time_maxes = dict()
Expand Down Expand Up @@ -847,28 +941,37 @@ def _process_status_result(self, result, parse, full, as_df):
if not parse:
return result

if full:
result = self._parse_full_status(result)
if as_df:
return self._as_df(result).T
else:
return result

columns = (
"job_id",
"username",
"queue",
"jobname",
"session_id",
"nds",
"tsk",
"requested_memory",
"requested_time",
"status",
"elapsed_time",
)

return self._parse_hpc_output(result, as_df, columns=columns, delimiter_char="-")
if self.batch_system == BatchSystem.SLURM:
# Trimming the top of result so that read_tables works properly
result = result.split("\n", 1)[1]
return self._parse_slurm_output(result=result)
else:
if full:
result = self._parse_full_status(result)
if as_df:
return self._as_df(result).T
else:
return result

columns = (
"job_id",
"username",
"queue",
"jobname",
"session_id",
"nds",
"tsk",
"requested_memory",
"requested_time",
"status",
"elapsed_time",
)

return self._parse_hpc_output(result, as_df, columns=columns, delimiter_char="-")

@staticmethod
def _parse_slurm_output(result):
return pd.read_table(StringIO(result), delim_whitespace=True)

@staticmethod
def _parse_full_status(status_str):
Expand Down Expand Up @@ -933,6 +1036,7 @@ def _parse_hpc_output(
delimiter_char="=",
num_header_lines=3,
):

if output:
delimiter = delimiter or cls._parse_hpc_delimiter(output, delimiter_char=delimiter_char)

Expand Down