Skip to content

Commit

Permalink
more... [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
hjoliver committed Mar 26, 2024
1 parent 451408b commit 676d374
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 83 deletions.
3 changes: 1 addition & 2 deletions cylc/flow/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,8 +801,7 @@ def _load_pool_from_db(self):
self.xtrigger_mgr.load_xtrigger_for_restart)
self.workflow_db_mgr.pri_dao.select_abs_outputs_for_restart(
self.pool.load_abs_outputs_for_restart)
self.pool.task_hold_mgr.load_from_db(
self.workflow_db_mgr.pri_dao.select_tasks_to_hold)
self.pool.task_hold_mgr.load_from_db()
self.pool.update_flow_mgr()

def restart_remote_init(self):
Expand Down
31 changes: 19 additions & 12 deletions cylc/flow/task_hold_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from dataclasses import dataclass
from typing import (
Callable,
Dict,
List,
Set,
Expand All @@ -33,6 +34,8 @@
if TYPE_CHECKING:
from cylc.flow.cycling import PointBase
from cylc.flow.task_proxy import TaskProxy
from cylc.flow.workflow_db_mgr import WorkflowDatabaseManager
from cylc.flow.data_store_mgr import DataStoreMgr


@dataclass
Expand All @@ -57,35 +60,35 @@ class TaskHoldMgr:

def __init__(
self,
workflow_db_mgr,
data_store_mgr,
ancestors
workflow_db_mgr: 'WorkflowDatabaseManager',
data_store_mgr: 'DataStoreMgr',
ancestors: Dict[str, List[str]]
) -> None:
self.ancestors = ancestors
self.workflow_db_mgr = workflow_db_mgr
self.data_store_mgr = data_store_mgr
self.store: Set[Tuple[str, 'PointBase']] = set()

def load_from_db(self):
"""Update the hold store from the database."""
"""Update the task hold store from the database."""
self.store.update(
(name, get_point(cycle)) for name, cycle in
self.workflow_db_mgr.select_tasks_to_hold()
self.workflow_db_mgr.pri_dao.select_tasks_to_hold()
)

def is_held(self, name, point) -> bool:
"""Is this task to be held?"""
"""Is this task listed in the hold store?"""
return (name, point) in self.store

def hold_active_tasks(self, itasks: List['TaskProxy']) -> None:
"""Add tasks to the hold store."""
"""What it says."""
for itask in itasks:
itask.state_reset(is_held=True)
self.store.add((itask.tdef.name, itask.point))
self.data_store_mgr.delta_task_held(itask)
itask.state_reset(is_held=True)
self.workflow_db_mgr.put_tasks_to_hold(self.store)

def update_future_tasks(self, tasks) -> None:
def hold_future_tasks(self, tasks: List[str, 'PointBase']) -> None:
"""Add a task to the hold store."""
for name, cycle in tasks:
self.data_store_mgr.delta_task_held((name, cycle, True))
Expand Down Expand Up @@ -113,8 +116,12 @@ def release_future_tasks(self, ftasks) -> None:
self.store.difference_update(matched)
self.workflow_db_mgr.put_tasks_to_hold(self.store)

def release_active_tasks(self, itasks: List['TaskProxy']) -> None:
"""Release held active tasks, and queue the if ready."""
def release_active_tasks(
self,
itasks: List['TaskProxy'],
queue_func: Callable
) -> None:
"""Release held active tasks, and queue them if ready."""
for itask in itasks:
if not itask.state_reset(is_held=False):
continue
Expand All @@ -123,7 +130,7 @@ def release_active_tasks(self, itasks: List['TaskProxy']) -> None:
not itask.state.is_runahead
and all(itask.is_ready_to_run())
):
self.queue_task(itask)
queue_func(itask)
self.remove_active_task(itask)

def clear(self) -> None:
Expand Down
5 changes: 3 additions & 2 deletions cylc/flow/task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,7 @@ def release_held_tasks(self, items: Iterable[str]) -> int:
items,
warn=False
)
self.task_hold_mgr.release_active_tasks(itasks)
self.task_hold_mgr.release_active_tasks(itasks, self.queue_task)

# Release matching future tasks (only 'waiting' selector is valid).
# future_matched: 'Set[Tuple[str, PointBase]]' = set()
Expand All @@ -1269,7 +1269,8 @@ def release_held_tasks(self, items: Iterable[str]) -> int:
def release_hold_point(self) -> None:
"""Unset the workflow hold point and release all held active tasks."""
self.hold_point = None
self.task_hold_mgr.release_active_tasks(self.get_tasks())
self.task_hold_mgr.release_active_tasks(
self.get_tasks(), self.queue_task)
self.task_hold_mgr.clear()
self.workflow_db_mgr.put_workflow_hold_cycle_point(None)

Expand Down
57 changes: 52 additions & 5 deletions tests/integration/test_task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ async def test_hold_tasks(
hold_expected = itask.identity in expected_tasks_to_hold_ids
assert itask.state.is_held is hold_expected

assert get_task_ids(task_pool.tasks_to_hold) == expected_tasks_to_hold_ids
assert get_task_ids(task_pool.task_hold_mgr.store) == expected_tasks_to_hold_ids

logged_warnings = assert_expected_log(caplog, expected_warnings)
assert n_warnings == len(logged_warnings)
Expand Down Expand Up @@ -424,7 +424,7 @@ async def test_release_held_tasks(
for itask in task_pool.get_tasks():
hold_expected = itask.identity in expected_tasks_to_hold_ids
assert itask.state.is_held is hold_expected
assert get_task_ids(task_pool.tasks_to_hold) == expected_tasks_to_hold_ids
assert get_task_ids(task_pool.task_hold_mgr.store) == expected_tasks_to_hold_ids
db_tasks_to_hold = db_select(example_flow, True, 'tasks_to_hold')
assert get_task_ids(db_tasks_to_hold) == expected_tasks_to_hold_ids

Expand All @@ -434,7 +434,7 @@ async def test_release_held_tasks(
assert itask.state.is_held is (itask.identity == '1/bar')

expected_tasks_to_hold_ids = sorted(['1/bar'])
assert get_task_ids(task_pool.tasks_to_hold) == expected_tasks_to_hold_ids
assert get_task_ids(task_pool.task_hold_mgr.store) == expected_tasks_to_hold_ids

db_tasks_to_hold = db_select(example_flow, True, 'tasks_to_hold')
assert get_task_ids(db_tasks_to_hold) == expected_tasks_to_hold_ids
Expand Down Expand Up @@ -468,7 +468,7 @@ async def test_hold_point(
hold_expected = itask.identity in expected_held_task_ids
assert itask.state.is_held is hold_expected

assert get_task_ids(task_pool.tasks_to_hold) == expected_held_task_ids
assert get_task_ids(task_pool.task_hold_mgr.store) == expected_held_task_ids
db_tasks_to_hold = db_select(example_flow, True, 'tasks_to_hold')
assert get_task_ids(db_tasks_to_hold) == expected_held_task_ids

Expand All @@ -482,7 +482,7 @@ async def test_hold_point(
for itask in task_pool.get_tasks():
assert itask.state.is_held is False

assert task_pool.tasks_to_hold == set()
assert task_pool.task_hold_mgr.store == set()
assert db_select(example_flow, True, 'tasks_to_hold') == []


Expand Down Expand Up @@ -1874,3 +1874,50 @@ def max_cycle(tasks):
mod_blah.pool.compute_runahead()
after = mod_blah.pool.runahead_limit_point
assert bool(before != after) == expected


async def test_task_hold(
flow,
scheduler,
start,
log_filter,
):
"""
"""
id_ = flow(
{
'scheduler': {'allow implicit tasks': 'True'},
'scheduling': {
'graph': {
'R1': """
foo => bar => baz
"""
}
}
}
)
schd = scheduler(id_)

async with start(schd) as log:

# Hold active tasks by glob.
schd.pool.hold_tasks(["1/*"])
assert schd.pool.task_hold_mgr.store == {
("foo", IntegerPoint(1))
}

# Hold future tasks explicitly.
schd.pool.hold_tasks(["1/baz"])
assert schd.pool.task_hold_mgr.store == {
("foo", IntegerPoint(1)),
("baz", IntegerPoint(1))
}

# Release all held tasks by glob.
schd.pool.release_held_tasks(["1/*"])
assert not schd.pool.task_hold_mgr.store

#assert not log_filter(
# log,
# contains="did not complete required outputs: ['a', 'b']"
#)
62 changes: 0 additions & 62 deletions tests/unit/test_id_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,65 +337,3 @@ def test_point_match(
):
set_cycling_type(point.TYPE, time_zone='Z')
assert point_match(point, value, pattern_match) is expected


@pytest.mark.parametrize(
'ids,matched,not_matched',
[
(['1/*'], ['1/a', '1/b', '1/c'], []),
(['*/b'], ['1/b', '2/b'], []),
(['1/[ab]', '3/c'], ['1/a', '1/b'], ['3/c']),
(['1/FAM'], ['1/a', '1/b'], []),
(['1/BA*'], ['1/b', '1/c'], []),
(['*'], ['1/b', '2/b', '1/a', '2/a', '1/c', '2/c'], []),
(['1'], ['1/a', '1/b', '1/c'], []),
(
['2:waiting', '[12]:running'],
['2/b', '2/a', '2/c'],
['[12]:running']
),
]
)
def test_filter_ids_2(
set_cycling_type: Callable,
ids: List[str],
matched: List[str],
not_matched: List[str]
):
"""Test glob-matching of tasks in the task-to-hold list.
Assumes task pool was matched first, for non-waiting tasks.
"""
set_cycling_type(CYCLER_TYPE_INTEGER)

def foo(ids: List[str]) -> 'Set[Tuple[str, PointBase]]':
"""Convert list of task IDs to set of (name, point) tuples."""
return set(
(t['task'], IntegerPoint(t['cycle']))
for t in [
Tokens(id, relative=True)
for id in ids
]
)

def oof(t_ids: 'Set[Tuple[str, PointBase]]') -> List[str]:
"""Convert set of (name, point) tuples to list of task IDs."""
return [
f"{id[1]}/{id[0]}"
for id in t_ids
]

namespaces = {
'a': ['FAM', 'a'],
'b': ['FAM', 'BAM', 'b'],
'c': ['BAM', 'c'],
}

_matched, _not_matched = filter_ids(
foo(['1/a', '1/b', '1/c', '2/a', '2/b', '2/c']),
ids,
namespaces
)

assert sorted(oof(_matched)) == sorted(matched)
assert sorted(_not_matched) == sorted(not_matched)

0 comments on commit 676d374

Please sign in to comment.