Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions uit_plus_job/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from django.contrib.auth.models import User
from social_django.utils import load_strategy
from tethys_apps.base.function_extractor import TethysFunctionExtractor
from uit.exceptions import UITError
from uit.exceptions import MaxRetriesError, UITError
from uit import AsyncClient, PbsScript, PbsJob, PbsArrayJob
from uit.pbs_script import PbsDirective
from tethys_compute.models.tethys_job import TethysJob
Expand Down Expand Up @@ -368,11 +368,11 @@ def process_intermediate_results_function(self):
@process_intermediate_results_function.setter
def process_intermediate_results_function(self, function):
if isinstance(function, str):
self._process_results_function = function
self._process_intermediate_results_function = function
return
module_path = inspect.getmodule(function).__name__.split(".")
module_path.append(function.__name__)
self._process_results_function = ".".join(module_path)
self._process_intermediate_results_function = ".".join(module_path)

@property
def remote_workspace_id(self):
Expand Down Expand Up @@ -547,8 +547,12 @@ async def update_status(self, status=None, *args, **kwargs):

# Update status if status not given and still pending/running
elif update_needed and self.is_time_to_update():
await self._update_status(*args, **kwargs)
self._last_status_update = timezone.now()
try:
await self._update_status(*args, **kwargs)
self._last_status_update = timezone.now()
except (MaxRetriesError, UITError) as e:
log.info(f"Unable to connect to {self.system} for user {self.user} due to the following error: {e}")
return

# Post-process status after update if old status was pending/running
if update_needed:
Expand Down Expand Up @@ -617,18 +621,20 @@ def intermediate_transfer_interval_exceeded(self):

async def process_results(self):
"""Process the results using the UIT Plus Python client."""
log.debug("Started processing results for job: {}".format(self))
log.debug(f"Started processing results for job: {self}")
await self.get_remote_files(self.transfer_output_files)
self.completion_time = timezone.now()
self._status = "COM"
await self._safe_save()
log.debug("Finished processing results for job: {}".format(self))
if self.process_results_function:
self.process_results_function(self)
log.debug(f"Finished processing results for job: {self}")

async def get_intermediate_results(self):
"""Retrieve intermediate result files from the supercomputer."""
if await self.get_remote_files(self.transfer_intermediate_files):
if self.process_intermediate_results_function:
self.process_intermediate_results_function()
self.process_intermediate_results_function(self)

def resolve_paths(self, paths):
resolved_paths = []
Expand All @@ -640,22 +646,29 @@ def resolve_paths(self, paths):
resolved_paths.append(self.pbs_job.resolve_path(p))
return resolved_paths

async def get_remote_files(self, remote_filenames):
async def get_remote_files(self, remote_paths):
"""Transfer files from a directory on the super computer.

Args:
remote_filenames (List[str]): Files to retrieve from remote_dir
remote_paths (List[str]): Files to retrieve from remote_dir

Returns:
bool: True if all file transfers succeed.
"""
remote_dirnames = []
if isinstance(remote_paths, dict):
remote_dirnames = remote_paths.get("dirs", [])
remote_filenames = remote_paths.get("files", [])
else:
remote_filenames = remote_paths

# Ensure the local transfer directory exists
workspace = Path(self.workspace)
success = True
remote_paths = self.resolve_paths(remote_filenames)
remote_file_paths = self.resolve_paths(remote_filenames)
remote_dir_paths = self.resolve_paths(remote_dirnames)

for remote_path in remote_paths:
for remote_path in remote_file_paths:
rel_path = remote_path.relative_to(self.working_dir)
local_path = workspace / rel_path
local_path.parent.mkdir(parents=True, exist_ok=True)
Expand All @@ -668,6 +681,14 @@ async def get_remote_files(self, remote_filenames):
success = False
log.error("Failed to get remote file: {}".format(str(e)))

for remote_dir in remote_dir_paths:
rel_path = remote_dir.relative_to(self.working_dir)
local_dir = workspace / rel_path
try:
await self.client.get_dir(remote_dir=remote_dir, local_dir=local_dir)
except Exception as e:
log.exception(e)

return success

@_ensure_connected
Expand Down
25 changes: 25 additions & 0 deletions uit_plus_job/submit_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,13 +710,38 @@ def action_button(self):

return row

@property
def transfer_input_files(self):
return None

@property
def transfer_intermediate_files(self):
return None

@property
def transfer_output_files(self):
return None

@property
def intermediate_transfer_interval(self):
return 0

@property
def process_intermediate_results_function(job):
pass

@property
def process_results_function(job):
pass

async def submit(self, custom_logs=None):
self.job.script = self.pbs_script # update script to ensure it reflects any UI updates
job = await database_sync_to_async(UitPlusJob.instance_from_pbs_job)(self.job, self.tethys_user)
job.custom_logs = custom_logs or self.custom_logs
job.transfer_input_files = self.transfer_input_files
job.transfer_intermediate_files = self.transfer_intermediate_files
job.transfer_output_files = self.transfer_output_files
job.intermediate_transfer_interval = self.intermediate_transfer_interval
job.process_intermediate_results_function = self.process_intermediate_results_function
job.process_results_function = self.process_results_function
await job.execute()
6 changes: 3 additions & 3 deletions uit_plus_job/tests/integrated_tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def test_get_remote_file_no_local_path(self, mock_client):
remote_dir = "WORKDIR"
mock_client.get_file.return_value = {"success": True}
self.assertFalse(
asyncio.run(self.uitplusjob.get_remote_files(remote_dir=remote_dir, remote_filenames=remote_files_names))
asyncio.run(self.uitplusjob.get_remote_files(remote_dir=remote_dir, remote_paths=remote_files_names))
)

@mock.patch("uit_plus_job.models.log")
Expand All @@ -356,7 +356,7 @@ def test_get_remote_file_io_error(self, mock_client, mock_log):
mock_client.get_file.side_effect = IOError

# call the method
ret = asyncio.run(self.uitplusjob.get_remote_files(remote_dir=remote_dir, remote_filenames=remote_files_names))
ret = asyncio.run(self.uitplusjob.get_remote_files(remote_dir=remote_dir, remote_paths=remote_files_names))

# test results
self.assertFalse(ret)
Expand All @@ -370,7 +370,7 @@ def test_get_remote_file(self, mock_client, mock_os):
mock_os.path.join.side_effect = ["local_path", "remote_path"]
mock_client.get_file.return_value = {"success": True}
mock_os.path.exists.return_value = True
ret = asyncio.run(self.uitplusjob.get_remote_files(remote_dir=remote_dir, remote_filenames=remote_files_names))
ret = asyncio.run(self.uitplusjob.get_remote_files(remote_dir=remote_dir, remote_paths=remote_files_names))

# test results
self.assertTrue(ret)
Expand Down