diff --git a/uit/async_client.py b/uit/async_client.py
index da74389..6966223 100644
--- a/uit/async_client.py
+++ b/uit/async_client.py
@@ -29,12 +29,36 @@
     ALL_OFF,
 )
 from .util import robust, AsyncHpcEnv
-from .pbs_script import PbsScript
+from .pbs_script import PbsScript, NODE_TYPES
 from .exceptions import UITError, MaxRetriesError
 
 logger = logging.getLogger(__name__)
 
 _ensure_connected = Client._ensure_connected
 
+COMMANDS = {
+    "pbs": {
+        "status": {
+            "command": "qstat",
+            "full": " -f",
+            "username": " -u",
+            "job_id": " -x",
+        },
+        "submit": "qsub",
+        "delete": "qdel",
+        "list_queues": "qstat -Q",
+    },
+    "slurm": {
+        "status": {
+            "command": "squeue -l",
+            "username": " -u",
+            "job_id": " -j ",
+        },
+        "submit": "sbatch",
+        "delete": "scancel",
+        "list_queues": "sacctmgr show qos format=Name%20",
+    },
+}
+
 
 class AsyncClient(Client):
     """Provides a python abstraction for interacting with the UIT API.
@@ -77,6 +101,8 @@ def __init__(
             delay_token=True,
         )
         self.env = AsyncHpcEnv(self)
+        self.scheduler = None
+        self.commands = None
         self._session = None
         if async_init:
             self.param.trigger("_async_init")
@@ -480,17 +506,23 @@ async def status(
     ):
         username = username if username is not None else self.username
 
-        cmd = "qstat"
+        cmd = self.commands["status"]["command"]
 
-        if full:
-            cmd += " -f"
-        elif username:
-            cmd += f" -u {username}"
+        if self.scheduler == "slurm":
+            if username:
+                cmd += self.commands["status"]["username"]
+                cmd += f" {username}"
+        else:
+            if full:
+                cmd += self.commands["status"]["full"]
+            elif username:
+                cmd += self.commands["status"]["username"] + f" {username}"
 
         if job_id:
             if isinstance(job_id, (tuple, list)):
                 job_id = " ".join([j.split(".")[0] for j in job_id])
-            cmd += f" -x {job_id}"
+            cmd += self.commands["status"]["job_id"]
+            cmd += f" {job_id}"
             result = await self.call(cmd)
             return self._process_status_result(result, parse=parse, full=full, as_df=as_df)
         else:
@@ -500,11 +532,13 @@ async def status(
             if not with_historic:
                 return result1
             else:
-                cmd += " -x"
+                cmd += self.commands["status"]["job_id"]
                 result = await self.call(cmd)
                 result2 = self._process_status_result(result, parse=parse, full=full, as_df=as_df)
 
-            if not parse:
+            if self.scheduler == "slurm":
+                return pd.concat((result1, result2))
+            elif not parse:
                 return result1, result2
             elif as_df:
                 return pd.concat((result1, result2))
@@ -550,7 +584,7 @@ async def submit(self, pbs_script, working_dir=None, remote_name="run.pbs", loca
 
         # Submit the script using call() with qsub command
         try:
-            job_id = await self.call(f"qsub {remote_name}", working_dir=working_dir)
+            job_id = await self.call(f"{self.commands['submit']} {remote_name}", working_dir=working_dir)
         except RuntimeError as e:
             raise RuntimeError("An exception occurred while submitting job script: {}".format(str(e)))
 
@@ -562,12 +596,25 @@ async def submit(self, pbs_script, working_dir=None, remote_name="run.pbs", loca
     @_ensure_connected
     async def get_queues(self, update_cache=False):
         if self._queues is None or update_cache:
-            self._queues = self._process_get_queues_output(await self.call("qstat -Q"))
+            self._queues = self._process_get_queues_output(await self.call(self.commands["list_queues"]))
         return self._queues
 
     @_ensure_connected
     async def get_raw_queue_stats(self):
-        return json.loads(await self.call("qstat -Q -f -F json"))["Queue"]
+        if self.scheduler == "slurm":
+            output = "id name max_walltime max_jobs max_nodes"
+            for queue in json.loads(await self.call("sacctmgr show qos --json"))["QOS"]:
+                max_walltime = PbsScript.parse_minutes(queue["limits"]["max"]["wall_clock"]["per"]["job"]["number"])
+                max_jobs = str(queue["limits"]["max"]["jobs"]["active_jobs"]["per"]["user"]["number"])
+                max_nodes = -1
+                for max_tres in queue["limits"]["max"]["tres"]["per"]["job"]:
+                    if max_tres["type"] == "node":
+                        max_nodes = max_tres["count"] * int(NODE_TYPES[self.system]["compute"])
+                output += f"\n{queue['id']} {queue['name']} {max_walltime} {max_jobs} {max_nodes}"
+            return self._parse_slurm_output(output)
+
+        else:
+            return json.loads(await self.call("qstat -Q -f -F json"))["Queue"]
 
     @_ensure_connected
     async def get_available_modules(self, flatten=False):
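For reference, a minimal standalone sketch of how the COMMANDS table above is composed into a scheduler-specific status command. The build_status_cmd helper is hypothetical (the real logic lives in AsyncClient.status() / Client.status() and also handles historic jobs and result parsing); only the "status" entries of the table are reproduced here.

    # Hypothetical helper mirroring the command-building logic in status();
    # only the "status" sub-dicts from COMMANDS are copied.
    COMMANDS = {
        "pbs": {"status": {"command": "qstat", "full": " -f", "username": " -u", "job_id": " -x"}},
        "slurm": {"status": {"command": "squeue -l", "username": " -u", "job_id": " -j "}},
    }

    def build_status_cmd(scheduler, username=None, full=False, job_id=None):
        status = COMMANDS[scheduler]["status"]
        cmd = status["command"]
        if scheduler == "slurm":
            if username:
                cmd += status["username"] + f" {username}"
        else:
            if full:
                cmd += status["full"]
            elif username:
                cmd += status["username"] + f" {username}"
        if job_id:
            cmd += status["job_id"] + f" {job_id}"
        return cmd

    print(build_status_cmd("pbs", username="jdoe"))    # qstat -u jdoe
    print(build_status_cmd("pbs", job_id="1234"))      # qstat -x 1234
    print(build_status_cmd("slurm", username="jdoe"))  # squeue -l -u jdoe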
str(queue["limits"]["max"]["jobs"]["active_jobs"]["per"]["user"]["number"]) + max_nodes = -1 + for max_tres in queue["limits"]["max"]["tres"]["per"]["job"]: + if max_tres["type"] == "node": + max_nodes = max_tres["count"] + output += f"\n{queue['id']} {queue['name']} {max_walltime} {max_jobs} {max_nodes}" + return self._parse_slurm_output(output) + + else: + return json.loads(await self.call("qstat -Q -f -F json"))["Queue"] @_ensure_connected async def get_available_modules(self, flatten=False): diff --git a/uit/gui_tools/submit.py b/uit/gui_tools/submit.py index 66cae46..0ae026e 100644 --- a/uit/gui_tools/submit.py +++ b/uit/gui_tools/submit.py @@ -91,8 +91,13 @@ def set_file_browser(self): self.workdir.file_browser = create_file_browser(self.uit_client, patterns=[]) @staticmethod - def get_default(value, objects): - return value if value in objects else objects[0] + def get_default(value, objects, default=None): + """Verify that value exists in the objects list, otherwise return a default or the first item in the list""" + if value in objects: + return value + if default in objects: + return default + return objects[0] @param.depends("uit_client", watch=True) async def update_hpc_connection_dependent_defaults(self): @@ -107,7 +112,7 @@ async def update_hpc_connection_dependent_defaults(self): self.hpc_subproject = self.get_default(self.hpc_subproject, subprojects) self.workdir.file_path = self.uit_client.WORKDIR.as_posix() self.param.node_type.objects = list(NODE_TYPES[self.uit_client.system].keys()) - self.node_type = self.get_default(self.node_type, self.param.node_type.objects) + self.node_type = self.get_default(self.node_type, self.param.node_type.objects, default="compute") self.param.queue.objects = await self.await_if_async(self.uit_client.get_queues()) self.queue = self.get_default(self.queue, self.param.queue.objects) self.node_maxes = await self.await_if_async( diff --git a/uit/node_types.csv b/uit/node_types.csv index 9d39b46..3feb066 100644 --- a/uit/node_types.csv +++ b/uit/node_types.csv @@ -1,5 +1,7 @@ -system,compute,gpu,bigmem,transfer,mla,highclock -nautilus,128,128,128,1,128,32 -narwhal,128,128,128,1 -warhawk,128,128,128,1,128 -carpenter,192,128,192,1 \ No newline at end of file +system,scheduler,compute,gpu,bigmem,transfer,mla,highclock +carpenter,pbs,192,128,192,1,, +narwhal,pbs,128,128,128,1,, +nautilus,slurm,128,128,128,1,128,32 +raider,slurm,128,128,128,1,128,32 +ruth,pbs,192,128,192,1,64,32 +warhawk,pbs,128,128,128,1,128 \ No newline at end of file diff --git a/uit/pbs_script.py b/uit/pbs_script.py index b689cef..da36717 100644 --- a/uit/pbs_script.py +++ b/uit/pbs_script.py @@ -153,11 +153,15 @@ def parse_time(time_str): return None @staticmethod - def format_time(date_time_obj): - hours = date_time_obj.days * 24 + date_time_obj.seconds // 3600 - minutes = date_time_obj.seconds % 3600 // 60 - seconds = date_time_obj.seconds % 3600 % 60 + def format_time(time_delta_obj): + hours = time_delta_obj.days * 24 + time_delta_obj.seconds // 3600 + minutes = time_delta_obj.seconds % 3600 // 60 + seconds = time_delta_obj.seconds % 3600 % 60 return f"{hours}:{minutes:02}:{seconds:02}" + + @staticmethod + def parse_minutes(minutes): + return PbsScript.format_time(datetime.timedelta(minutes=minutes)) @property def max_time(self): diff --git a/uit/uit.py b/uit/uit.py index 41f9e8d..b8cd431 100644 --- a/uit/uit.py +++ b/uit/uit.py @@ -11,6 +11,7 @@ from functools import wraps from itertools import chain from pathlib import PurePosixPath, Path +from io import StringIO from 
diff --git a/uit/pbs_script.py b/uit/pbs_script.py
index b689cef..da36717 100644
--- a/uit/pbs_script.py
+++ b/uit/pbs_script.py
@@ -153,11 +153,15 @@ def parse_time(time_str):
         return None
 
     @staticmethod
-    def format_time(date_time_obj):
-        hours = date_time_obj.days * 24 + date_time_obj.seconds // 3600
-        minutes = date_time_obj.seconds % 3600 // 60
-        seconds = date_time_obj.seconds % 3600 % 60
+    def format_time(time_delta_obj):
+        hours = time_delta_obj.days * 24 + time_delta_obj.seconds // 3600
+        minutes = time_delta_obj.seconds % 3600 // 60
+        seconds = time_delta_obj.seconds % 3600 % 60
         return f"{hours}:{minutes:02}:{seconds:02}"
+
+    @staticmethod
+    def parse_minutes(minutes):
+        return PbsScript.format_time(datetime.timedelta(minutes=minutes))
 
     @property
     def max_time(self):
diff --git a/uit/uit.py b/uit/uit.py
index 41f9e8d..b8cd431 100644
--- a/uit/uit.py
+++ b/uit/uit.py
@@ -11,6 +11,7 @@
 from functools import wraps
 from itertools import chain
 from pathlib import PurePosixPath, Path
+from io import StringIO
 from urllib.parse import urljoin, urlencode  # noqa: F401
 
 import param
@@ -19,7 +20,7 @@
 from werkzeug.serving import make_server
 
 from .config import parse_config, DEFAULT_CA_FILE, DEFAULT_CONFIG
-from .pbs_script import PbsScript
+from .pbs_script import PbsScript, NODE_TYPES
 from .util import robust, HpcEnv
 from .exceptions import UITError, MaxRetriesError
 
@@ -43,6 +44,30 @@
 _auth_code = None
 _server = None
 
+COMMANDS = {
+    "pbs": {
+        "status": {
+            "command": "qstat",
+            "full": " -f",
+            "username": " -u",
+            "job_id": " -x",
+        },
+        "submit": "qsub",
+        "delete": "qdel",
+        "list_queues": "qstat -Q",
+    },
+    "slurm": {
+        "status": {
+            "command": "squeue -l",
+            "username": " -u",
+            "job_id": " -j ",
+        },
+        "submit": "sbatch",
+        "delete": "scancel",
+        "list_queues": "sacctmgr show qos format=Name%20",
+    },
+}
+
 
 class Client(param.Parameterized):
     """Provides a python abstraction for interacting with the UIT API.
@@ -101,6 +126,9 @@ def __init__(
         self.scope = scope
         self.port = port
 
+        self.scheduler = None
+        self.commands = None
+
         if self.token is not None:
             self.param.trigger("token")
 
@@ -287,6 +315,8 @@ def prepare_connect(self, system, login_node, exclude_login_nodes, retry_on_fail
             self._username = self._userinfo["SYSTEMS"][self._system.upper()]["USERNAME"]
             self._uit_url = self._uit_urls[login_node]
             self.connected = True
+            self.scheduler = NODE_TYPES[self.system]["scheduler"]
+            self.commands = COMMANDS[self.scheduler]
 
             return login_node, retry_on_failure
 
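A standalone equivalent of the new PbsScript.parse_minutes()/format_time() pair, which converts Slurm's per-job wall-clock limit (reported in minutes by sacctmgr) into the H:MM:SS string used elsewhere in the package. The function bodies mirror the diff; only the sample inputs are made up.

    import datetime

    def format_time(time_delta_obj):
        # Same arithmetic as PbsScript.format_time(): total hours, then MM:SS.
        hours = time_delta_obj.days * 24 + time_delta_obj.seconds // 3600
        minutes = time_delta_obj.seconds % 3600 // 60
        seconds = time_delta_obj.seconds % 3600 % 60
        return f"{hours}:{minutes:02}:{seconds:02}"

    def parse_minutes(minutes):
        # Same as PbsScript.parse_minutes(): minutes -> H:MM:SS.
        return format_time(datetime.timedelta(minutes=minutes))

    print(parse_minutes(2880))  # 48:00:00
    print(parse_minutes(90))    # 1:30:00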
@@ -698,17 +728,25 @@ def status(
     ):
         username = username if username is not None else self.username
 
-        cmd = "qstat"
+        # cmd will be either "qstat" (PBS) or "squeue -l" (Slurm)
+        cmd = self.commands["status"]["command"]
 
-        if full:
-            cmd += " -f"
-        elif username:
-            cmd += f" -u {username}"
+        if self.scheduler == "slurm":
+            if username:
+                cmd += self.commands["status"]["username"]
+                cmd += f" {username}"
+        else:
+            if full:
+                cmd += self.commands["status"]["full"]
+            elif username:
+                cmd += self.commands["status"]["username"]
+                cmd += f" {username}"
 
         if job_id:
             if isinstance(job_id, (tuple, list)):
                 job_id = " ".join([j.split(".")[0] for j in job_id])
-            cmd += f" -x {job_id}"
+            cmd += self.commands["status"]["job_id"]
+            cmd += f" {job_id}"
             result = self.call(cmd)
             return self._process_status_result(result, parse=parse, full=full, as_df=as_df)
         else:
@@ -718,11 +756,12 @@ def status(
             if not with_historic:
                 return result1
             else:
-                cmd += " -x"
+                cmd += self.commands["status"]["job_id"]
                 result = self.call(cmd)
                 result2 = self._process_status_result(result, parse=parse, full=full, as_df=as_df)
-
-            if not parse:
+            if self.scheduler == "slurm":
+                return pd.concat((result1, result2))
+            elif not parse:
                 return result1, result2
             elif as_df:
                 return pd.concat((result1, result2))
@@ -768,7 +807,7 @@ def submit(self, pbs_script, working_dir=None, remote_name="run.pbs", local_temp
 
         # Submit the script using call() with qsub command
         try:
-            job_id = self.call(f"qsub {remote_name}", working_dir=working_dir)
+            job_id = self.call(f"{self.commands['submit']} {remote_name}", working_dir=working_dir)
         except RuntimeError as e:
             raise RuntimeError("An exception occurred while submitting job script: {}".format(str(e)))
 
@@ -780,21 +819,51 @@ def submit(self, pbs_script, working_dir=None, remote_name="run.pbs", local_temp
     @_ensure_connected
     def get_queues(self, update_cache=False):
         if self._queues is None or update_cache:
-            self._queues = self._process_get_queues_output(self.call("qstat -Q"))
+            self._queues = self._process_get_queues_output(self.call(self.commands["list_queues"]))
         return self._queues
 
     def _process_get_queues_output(self, output):
-        standard_queues = [] if self.system == "jim" else QUEUES
+        standard_queues = QUEUES
         other_queues = set([i.split()[0] for i in output.splitlines()][2:]) - set(standard_queues)
         all_queues = standard_queues + sorted([q for q in other_queues if "_" not in q])
         return all_queues
 
     @_ensure_connected
     def get_raw_queue_stats(self):
-        return json.loads(self.call("qstat -Q -f -F json"))["Queue"]
+        if self.scheduler == "slurm":
+            output = "id name max_walltime max_jobs max_nodes"
+            for queue in json.loads(self.call("sacctmgr show qos --json"))["QOS"]:
+                minutes = queue["limits"]["max"]["wall_clock"]["per"]["job"]["number"]
+                max_walltime = PbsScript.parse_minutes(minutes)
+                max_jobs = str(queue["limits"]["max"]["jobs"]["active_jobs"]["per"]["user"]["number"])
+                max_nodes = -1
+                for max_tres in queue["limits"]["max"]["tres"]["per"]["job"]:
+                    if max_tres["type"] == "node":
+                        max_nodes = max_tres["count"] * int(NODE_TYPES[self.system]["compute"])
+                output += f"\n{queue['id']} {queue['name']} {max_walltime} {max_jobs} {max_nodes}"
+            return self._parse_slurm_output(output)
+
+        else:
+            return json.loads(self.call("qstat -Q -f -F json"))["Queue"]
 
     @_ensure_connected
     def get_node_maxes(self, queues, queues_stats):
+        if self.scheduler == "slurm":
+            return self._slurm_node_maxes(queues, queues_stats)
+
+        else:
+            return self._pbs_node_maxes(queues, queues_stats)
+
+    def _slurm_node_maxes(self, queues, queues_stats):
+        ncpus_maxes = dict()
+
+        for q in queues:
+            max_nodes = str(queues_stats.loc[queues_stats["name"] == q.lower(), "max_nodes"].iloc[0])
+            ncpus_maxes[q] = max_nodes if max_nodes != "-1" else "Not Found"
+
+        return ncpus_maxes
+
+    def _pbs_node_maxes(self, queues, queues_stats):
         q_sts = {q: queues_stats[q] for q in queues if q in queues_stats.keys()}
 
         ncpus_maxes = dict()
@@ -809,6 +878,21 @@ def get_node_maxes(self, queues, queues_stats):
 
     @_ensure_connected
     def get_wall_time_maxes(self, queues, queues_stats):
+        if self.scheduler == "slurm":
+            return self._slurm_wall_time_maxes(queues, queues_stats)
+        else:
+            return self._pbs_wall_time_maxes(queues, queues_stats)
+
+    def _slurm_wall_time_maxes(self, queues, queues_stats):
+        wall_time_maxes = dict()
+
+        for q in queues:
+            max_walltime = str(queues_stats.loc[queues_stats["name"] == q.lower(), "max_walltime"].iloc[0])
+            wall_time_maxes[q] = max_walltime
+
+        return wall_time_maxes
+
+    def _pbs_wall_time_maxes(self, queues, queues_stats):
         q_sts = {q: queues_stats[q] for q in queues if q in queues_stats.keys()}
 
         wall_time_maxes = dict()
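A sketch of the Slurm path through get_raw_queue_stats() and get_node_maxes(). The `sacctmgr show qos --json` call is replaced by a hard-coded sample whose values are made up; the key paths into the JSON and the flattening into a whitespace table are the ones used in the hunk above.

    import pandas as pd
    from io import StringIO

    # Fabricated QOS record shaped like the fields the code above reads.
    sample_qos = [
        {
            "id": 1,
            "name": "standard",
            "limits": {
                "max": {
                    "wall_clock": {"per": {"job": {"number": 2880}}},
                    "jobs": {"active_jobs": {"per": {"user": {"number": 50}}}},
                    "tres": {"per": {"job": [{"type": "node", "count": 16}]}},
                }
            },
        },
    ]

    cores_per_node = 128  # stands in for NODE_TYPES[self.system]["compute"]

    output = "id name max_walltime max_jobs max_nodes"
    for qos in sample_qos:
        limits = qos["limits"]["max"]
        max_walltime = limits["wall_clock"]["per"]["job"]["number"]
        max_jobs = limits["jobs"]["active_jobs"]["per"]["user"]["number"]
        max_nodes = -1
        for tres in limits["tres"]["per"]["job"]:
            if tres["type"] == "node":
                max_nodes = tres["count"] * cores_per_node
        output += f"\n{qos['id']} {qos['name']} {max_walltime} {max_jobs} {max_nodes}"

    # Same parsing step as _parse_slurm_output(), then the node-max lookup.
    stats = pd.read_table(StringIO(output), sep=r"\s+")
    print(stats.loc[stats["name"] == "standard", "max_nodes"].iloc[0])  # 2048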
delimiter_char="-") + + @staticmethod + def _parse_slurm_output(result): + return pd.read_table(StringIO(result), delim_whitespace=True) @staticmethod def _parse_full_status(status_str): @@ -933,6 +1025,7 @@ def _parse_hpc_output( delimiter_char="=", num_header_lines=3, ): + if output: delimiter = delimiter or cls._parse_hpc_delimiter(output, delimiter_char=delimiter_char)