Skip to content

Commit

Permalink
use utils
Browse files Browse the repository at this point in the history
  • Loading branch information
malmans2 committed Jun 13, 2024
1 parent cad2248 commit 116526a
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 21 deletions.
2 changes: 0 additions & 2 deletions cads_broker/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@

dbsettings = None

TASKS_SUBDIR = "tasks_working_dir"


class SqlalchemySettings(pydantic_settings.BaseSettings):
"""Postgres-specific API settings.
Expand Down
27 changes: 8 additions & 19 deletions cads_broker/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import hashlib
import io
import os
import pathlib
import pickle
import shutil
import threading
import time
import traceback
Expand All @@ -23,7 +21,7 @@
except ModuleNotFoundError:
pass

from cads_broker import Environment, config, factory
from cads_broker import Environment, config, factory, utils
from cads_broker import database as db
from cads_broker.qos import QoS

Expand Down Expand Up @@ -200,31 +198,22 @@ def __init__(self, number_of_workers) -> None:
parser.parse_rules(self.rules, self.environment)


def rmtree_if_exists(path: pathlib.Path, **kwargs: Any) -> None:
if path.exists():
shutil.rmtree(path, **kwargs)


class TempDirNannyPlugin(distributed.NannyPlugin):
def setup(self, nanny: distributed.Nanny) -> None:
self.tasks_path = pathlib.Path(nanny.worker_dir) / config.TASKS_SUBDIR
rmtree_if_exists(self.tasks_path)
self.tasks_path.mkdir()
path = utils.rm_task_path(nanny, None)
path.mkdir()

def teardown(self, nanny: distributed.Nanny) -> None:
rmtree_if_exists(self.tasks_path)
utils.rm_task_path(nanny, None)


class TempDirsWorkerPlugin(distributed.WorkerPlugin):
def setup(self, worker: distributed.Worker) -> None:
self.tasks_path = pathlib.Path(worker.local_directory) / config.TASKS_SUBDIR

def delete_task_working_dir(self, key: Key) -> None:
rmtree_if_exists(self.tasks_path / str(key))
def setup(self, worker) -> None:
self.worker = worker

def teardown(self, worker: distributed.Worker) -> None:
for key in worker.state.tasks:
self.delete_task_working_dir(key)
utils.rm_task_path(worker, key)

def transition(
self,
Expand All @@ -234,7 +223,7 @@ def transition(
**kwargs: Any,
) -> None:
if finish in ("memory", "error"):
self.delete_task_working_dir(key)
utils.rm_task_path(self.worker, key)


@attrs.define
Expand Down
34 changes: 34 additions & 0 deletions cads_broker/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pathlib
import shutil
from typing import Any

import distributed
from dask.typing import Key


def get_task_path(
worker_or_nanny: distributed.Worker | distributed.Nanny, key: Key | None
) -> pathlib.Path:
if isinstance(worker_or_nanny, distributed.Worker):
root = worker_or_nanny.local_directory
elif isinstance(worker_or_nanny, distributed.Nanny):
root = worker_or_nanny.worker_dir
else:
raise TypeError(
f"`worker_or_nanny` is of the wrong type: {type(worker_or_nanny)}"
)
path = pathlib.Path(root) / "tasks_working_dir"
if key is not None:
path /= str(key)
return path


def rm_task_path(
worker_or_nanny: distributed.Worker | distributed.Nanny,
key: Key | None,
**kwargs: Any,
) -> pathlib.Path:
path = get_task_path(worker_or_nanny, key)
if path.exists():
shutil.rmtree(path, **kwargs)
return path

0 comments on commit 116526a

Please sign in to comment.