Skip to content

Commit

Permalink
Merge pull request #42 from oliver-sanders/cylc-set-task
Browse files Browse the repository at this point in the history
tokens: avoid surplus detokenising
  • Loading branch information
hjoliver authored Jan 25, 2024
2 parents cfada17 + 89c02ce commit ce2065b
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 21 deletions.
4 changes: 2 additions & 2 deletions cylc/flow/prerequisite.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _conditional_is_satisfied(self):
'"%s":\n%s' % (self.get_raw_conditional_expression(), err_msg))
return res

def satisfy_me(self, outputs: Iterable['Tokens']) -> Set[str]:
def satisfy_me(self, outputs: Iterable['Tokens']) -> 'Set[Tokens]':
"""Attempt to satisfy me with given outputs.
Updates cache with the result.
Expand All @@ -214,7 +214,7 @@ def satisfy_me(self, outputs: Iterable['Tokens']) -> Set[str]:
prereq = output.to_prereq_tuple()
if prereq not in self.satisfied:
continue
valid.add(output.relative_id_with_selectors)
valid.add(output)
self.satisfied[prereq] = self.DEP_STATE_SATISFIED
if self.conditional_expression is None:
self._all_satisfied = all(self.satisfied.values())
Expand Down
28 changes: 18 additions & 10 deletions cylc/flow/task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,7 +1299,7 @@ def spawn_on_output(self, itask, output, forced=False):
else:
tasks = [c_task]
for t in tasks:
t.satisfy_me([f"{itask.identity}:{output}"])
t.satisfy_me([itask.tokens.duplicate(task_sel=output)])
self.data_store_mgr.delta_task_prerequisite(t)
self.add_to_pool(t)

Expand Down Expand Up @@ -1423,7 +1423,9 @@ def spawn_on_all_outputs(
# not spawnable
continue
if completed_only:
c_task.satisfy_me([f"{itask.identity}:{output}"])
c_task.satisfy_me(
[itask.tokens.duplicate(task_sel=output)]
)
self.data_store_mgr.delta_task_prerequisite(c_task)
self.add_to_pool(c_task)
if (
Expand Down Expand Up @@ -1604,9 +1606,10 @@ def spawn_task(
and itask.tdef.has_abs_triggers
and itask.state.prerequisites_are_not_all_satisfied()
):
itask.satisfy_me(
[f"{a[0]}/{a[1]}:{a[2]}" for a in self.abs_outputs_done]
)
itask.satisfy_me([
Tokens(cycle=cycle, task=task, task_sel=output)
for cycle, task, output in self.abs_outputs_done
])

if prev_flow_wait:
self._spawn_after_flow_wait(itask)
Expand Down Expand Up @@ -1716,6 +1719,11 @@ def set( # noqa: A003
# Illegal flow command opts
return

_prereqs: List[Tokens] = [
Tokens(prereq, relative=True)
for prereq in (prereqs or [])
]

# Get matching pool tasks and future task definitions.
itasks, future_tasks, unmatched = self.filter_task_proxies(
items,
Expand All @@ -1725,17 +1733,17 @@ def set( # noqa: A003

for itask in itasks:
self.merge_flows(itask, flow_nums)
if prereqs:
if _prereqs:
self._set_prereqs_itask(
itask, prereqs, flow_nums, flow_wait)
itask, _prereqs, flow_nums, flow_wait)
else:
self._set_outputs_itask(itask, outputs)

for name, point in future_tasks:
tdef = self.config.get_taskdef(name)
if prereqs:
if _prereqs:
self._set_prereqs_tdef(
point, tdef, prereqs, flow_nums, flow_wait)
point, tdef, _prereqs, flow_nums, flow_wait)
else:
trans = self._get_task_proxy(
point, tdef, flow_nums, flow_wait, transient=True)
Expand Down Expand Up @@ -1778,7 +1786,7 @@ def _set_outputs_itask(
def _set_prereqs_itask(
self,
itask: 'TaskProxy',
prereqs: List[str],
prereqs: List[Tokens],
flow_nums: Set[int],
flow_wait: bool
) -> None:
Expand Down
10 changes: 5 additions & 5 deletions cylc/flow/task_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@

from cylc.flow import LOG
from cylc.flow.flow_mgr import stringify_flow_nums
from cylc.flow.id import Tokens
from cylc.flow.platforms import get_platform
from cylc.flow.task_action_timer import TimerFlags
from cylc.flow.task_state import (
Expand All @@ -57,6 +56,7 @@
from cylc.flow.cycling import PointBase
from cylc.flow.task_action_timer import TaskActionTimer
from cylc.flow.taskdef import TaskDef
from cylc.flow.id import Tokens


class TaskProxy:
Expand Down Expand Up @@ -518,18 +518,18 @@ def state_reset(
return True
return False

def satisfy_me(self, outputs: Iterable[str]) -> None:
def satisfy_me(self, outputs: 'Iterable[Tokens]') -> None:
"""Try to satisfy my prerequisites with given outputs.
The output strings are of the form "cycle/task:message"
Log a warning for outputs that I don't depend on.
"""
tokens = [Tokens(p, relative=True) for p in outputs]
used = self.state.satisfy_me(tokens)
used = self.state.satisfy_me(outputs)
for output in set(outputs) - used:
LOG.warning(
f"{self.identity} does not depend on {output}"
f"{self.identity} does not depend on"
f" {output.relative_id_with_selectors}"
)

def clock_expire(self) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions cylc/flow/task_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,12 +313,12 @@ def __call__(
def satisfy_me(
self,
outputs: Iterable['Tokens']
) -> Set[str]:
) -> Set['Tokens']:
"""Try to satisfy my prerequisites with given outputs.
Return which outputs I actually depend on.
"""
valid: Set[str] = set()
valid: Set[Tokens] = set()
for prereq in (*self.prerequisites, *self.suicide_prerequisites):
yep = prereq.satisfy_me(outputs)
if yep:
Expand Down
8 changes: 6 additions & 2 deletions tests/integration/test_task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@
from cylc.flow import CYLC_LOG
from cylc.flow.cycling.integer import IntegerPoint
from cylc.flow.cycling.iso8601 import ISO8601Point
from cylc.flow.task_events_mgr import TaskEventsManager
from cylc.flow.data_store_mgr import TASK_PROXIES
from cylc.flow.id import Tokens
from cylc.flow.task_events_mgr import TaskEventsManager
from cylc.flow.task_outputs import (
TASK_OUTPUT_STARTED,
TASK_OUTPUT_SUCCEEDED
Expand Down Expand Up @@ -1523,7 +1524,10 @@ async def test_prereq_satisfaction(
assert not b.is_waiting_prereqs_done()

# set valid and invalid prerequisites, check log.
b.satisfy_me(["1/a:x", "1/a:y", "1/a:z", "1/a:w"])
b.satisfy_me([
Tokens(id_, relative=True)
for id_ in ["1/a:x", "1/a:y", "1/a:z", "1/a:w"]
])
assert log_filter(log, contains="1/b does not depend on 1/a:z")
assert log_filter(log, contains="1/b does not depend on 1/a:w")
assert not log_filter(log, contains="1/b does not depend on 1/a:x")
Expand Down

0 comments on commit ce2065b

Please sign in to comment.