diff --git a/uit_plus_job/models.py b/uit_plus_job/models.py index b74e697..50cbf7c 100644 --- a/uit_plus_job/models.py +++ b/uit_plus_job/models.py @@ -18,7 +18,7 @@ from django.contrib.auth.models import User from social_django.utils import load_strategy from tethys_apps.base.function_extractor import TethysFunctionExtractor -from uit.exceptions import UITError +from uit.exceptions import MaxRetriesError, UITError from uit import AsyncClient, PbsScript, PbsJob, PbsArrayJob from uit.pbs_script import PbsDirective from tethys_compute.models.tethys_job import TethysJob @@ -368,11 +368,11 @@ def process_intermediate_results_function(self): @process_intermediate_results_function.setter def process_intermediate_results_function(self, function): if isinstance(function, str): - self._process_results_function = function + self._process_intermediate_results_function = function return module_path = inspect.getmodule(function).__name__.split(".") module_path.append(function.__name__) - self._process_results_function = ".".join(module_path) + self._process_intermediate_results_function = ".".join(module_path) @property def remote_workspace_id(self): @@ -547,8 +547,12 @@ async def update_status(self, status=None, *args, **kwargs): # Update status if status not given and still pending/running elif update_needed and self.is_time_to_update(): - await self._update_status(*args, **kwargs) - self._last_status_update = timezone.now() + try: + await self._update_status(*args, **kwargs) + self._last_status_update = timezone.now() + except (MaxRetriesError, UITError) as e: + log.info(f"Unable to connect to {self.system} for user {self.user} due to the following error: {e}") + return # Post-process status after update if old status was pending/running if update_needed: @@ -617,18 +621,20 @@ def intermediate_transfer_interval_exceeded(self): async def process_results(self): """Process the results using the UIT Plus Python client.""" - log.debug("Started processing results for job: {}".format(self)) + log.debug(f"Started processing results for job: {self}") await self.get_remote_files(self.transfer_output_files) self.completion_time = timezone.now() self._status = "COM" await self._safe_save() - log.debug("Finished processing results for job: {}".format(self)) + if self.process_results_function: + self.process_results_function(self) + log.debug(f"Finished processing results for job: {self}") async def get_intermediate_results(self): """Retrieve intermediate result files from the supercomputer.""" if await self.get_remote_files(self.transfer_intermediate_files): if self.process_intermediate_results_function: - self.process_intermediate_results_function() + self.process_intermediate_results_function(self) def resolve_paths(self, paths): resolved_paths = [] @@ -640,22 +646,29 @@ def resolve_paths(self, paths): resolved_paths.append(self.pbs_job.resolve_path(p)) return resolved_paths - async def get_remote_files(self, remote_filenames): + async def get_remote_files(self, remote_paths): """Transfer files from a directory on the super computer. Args: - remote_filenames (List[str]): Files to retrieve from remote_dir + remote_paths (List[str]): Files to retrieve from remote_dir Returns: bool: True if all file transfers succeed. """ + remote_dirnames = [] + if isinstance(remote_paths, dict): + remote_dirnames = remote_paths.get("dirs", []) + remote_filenames = remote_paths.get("files", []) + else: + remote_filenames = remote_paths # Ensure the local transfer directory exists workspace = Path(self.workspace) success = True - remote_paths = self.resolve_paths(remote_filenames) + remote_file_paths = self.resolve_paths(remote_filenames) + remote_dir_paths = self.resolve_paths(remote_dirnames) - for remote_path in remote_paths: + for remote_path in remote_file_paths: rel_path = remote_path.relative_to(self.working_dir) local_path = workspace / rel_path local_path.parent.mkdir(parents=True, exist_ok=True) @@ -668,6 +681,14 @@ async def get_remote_files(self, remote_filenames): success = False log.error("Failed to get remote file: {}".format(str(e))) + for remote_dir in remote_dir_paths: + rel_path = remote_dir.relative_to(self.working_dir) + local_dir = workspace / rel_path + try: + await self.client.get_dir(remote_dir=remote_dir, local_dir=local_dir) + except Exception as e: + log.exception(e) + return success @_ensure_connected diff --git a/uit_plus_job/submit_stage.py b/uit_plus_job/submit_stage.py index 01e04e3..b0b07b0 100644 --- a/uit_plus_job/submit_stage.py +++ b/uit_plus_job/submit_stage.py @@ -710,13 +710,38 @@ def action_button(self): return row + @property + def transfer_input_files(self): + return None + + @property + def transfer_intermediate_files(self): + return None + @property def transfer_output_files(self): return None + @property + def intermediate_transfer_interval(self): + return 0 + + @property + def process_intermediate_results_function(job): + pass + + @property + def process_results_function(job): + pass + async def submit(self, custom_logs=None): self.job.script = self.pbs_script # update script to ensure it reflects any UI updates job = await database_sync_to_async(UitPlusJob.instance_from_pbs_job)(self.job, self.tethys_user) job.custom_logs = custom_logs or self.custom_logs + job.transfer_input_files = self.transfer_input_files + job.transfer_intermediate_files = self.transfer_intermediate_files job.transfer_output_files = self.transfer_output_files + job.intermediate_transfer_interval = self.intermediate_transfer_interval + job.process_intermediate_results_function = self.process_intermediate_results_function + job.process_results_function = self.process_results_function await job.execute() diff --git a/uit_plus_job/tests/integrated_tests/test_models.py b/uit_plus_job/tests/integrated_tests/test_models.py index 7cd624c..cfc81ce 100644 --- a/uit_plus_job/tests/integrated_tests/test_models.py +++ b/uit_plus_job/tests/integrated_tests/test_models.py @@ -345,7 +345,7 @@ def test_get_remote_file_no_local_path(self, mock_client): remote_dir = "WORKDIR" mock_client.get_file.return_value = {"success": True} self.assertFalse( - asyncio.run(self.uitplusjob.get_remote_files(remote_dir=remote_dir, remote_filenames=remote_files_names)) + asyncio.run(self.uitplusjob.get_remote_files(remote_dir=remote_dir, remote_paths=remote_files_names)) ) @mock.patch("uit_plus_job.models.log") @@ -356,7 +356,7 @@ def test_get_remote_file_io_error(self, mock_client, mock_log): mock_client.get_file.side_effect = IOError # call the method - ret = asyncio.run(self.uitplusjob.get_remote_files(remote_dir=remote_dir, remote_filenames=remote_files_names)) + ret = asyncio.run(self.uitplusjob.get_remote_files(remote_dir=remote_dir, remote_paths=remote_files_names)) # test results self.assertFalse(ret) @@ -370,7 +370,7 @@ def test_get_remote_file(self, mock_client, mock_os): mock_os.path.join.side_effect = ["local_path", "remote_path"] mock_client.get_file.return_value = {"success": True} mock_os.path.exists.return_value = True - ret = asyncio.run(self.uitplusjob.get_remote_files(remote_dir=remote_dir, remote_filenames=remote_files_names)) + ret = asyncio.run(self.uitplusjob.get_remote_files(remote_dir=remote_dir, remote_paths=remote_files_names)) # test results self.assertTrue(ret)