Skip to content

Commit

Permalink
Type checking improved.
Browse files Browse the repository at this point in the history
  • Loading branch information
wxtim committed Oct 21, 2024
1 parent 31290ba commit eccd7fe
Show file tree
Hide file tree
Showing 18 changed files with 92 additions and 63 deletions.
4 changes: 2 additions & 2 deletions cylc/flow/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ async def poll_tasks(schd: 'Scheduler', tasks: Iterable[str]):
"""Poll pollable tasks or a task or family if options are provided."""
validate.is_tasks(tasks)
yield
if schd.get_run_mode() == RunMode.SIMULATION.value:
if schd.get_run_mode() == RunMode.SIMULATION:
yield 0
itasks, _, bad_items = schd.pool.filter_task_proxies(tasks)
schd.task_job_mgr.poll_task_jobs(schd.workflow, itasks)
Expand All @@ -262,7 +262,7 @@ async def kill_tasks(schd: 'Scheduler', tasks: Iterable[str]):
validate.is_tasks(tasks)
yield
itasks, _, bad_items = schd.pool.filter_task_proxies(tasks)
if schd.get_run_mode() == RunMode.SIMULATION.value:
if schd.get_run_mode() == RunMode.SIMULATION:
for itask in itasks:
if itask.state(*TASK_STATUSES_ACTIVE):
itask.state_reset(TASK_STATUS_FAILED)
Expand Down
2 changes: 1 addition & 1 deletion cylc/flow/data_store_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ def generate_definition_elements(self):
time_zone_info = TIME_ZONE_LOCAL_INFO
for key, val in time_zone_info.items():
setbuff(workflow.time_zone_info, key, val)
workflow.run_mode = RunMode.get(config.options)
workflow.run_mode = RunMode.get(config.options).value
workflow.cycling_mode = config.cfg['scheduling']['cycling mode']
workflow.workflow_log_dir = self.schd.workflow_log_dir
workflow.job_log_names.extend(list(JOB_LOG_OPTS.values()))
Expand Down
7 changes: 4 additions & 3 deletions cylc/flow/prerequisite.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from cylc.flow.data_messages_pb2 import PbCondition, PbPrerequisite
from cylc.flow.exceptions import TriggerExpressionError
from cylc.flow.id import quick_relative_detokenise
from cylc.flow.run_modes import RunMode


if TYPE_CHECKING:
Expand Down Expand Up @@ -263,7 +264,7 @@ def _eval_satisfied(self) -> bool:

def satisfy_me(
self, outputs: Iterable['Tokens'],
mode: Literal['skip', 'live', 'simulation', 'skip'] = 'live'
mode: "RunMode" = RunMode.LIVE
) -> 'Set[Tokens]':
"""Attempt to satisfy me with given outputs.
Expand All @@ -273,9 +274,9 @@ def satisfy_me(
"""
satisfied_message: SatisfiedState

if mode != 'live':
if mode != RunMode.LIVE:
satisfied_message = self.DEP_STATE_SATISFIED_BY.format(
mode) # type: ignore
mode.value) # type: ignore
else:
satisfied_message = self.DEP_STATE_SATISFIED
valid = set()
Expand Down
11 changes: 7 additions & 4 deletions cylc/flow/run_modes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,12 @@ def describe(self):
raise KeyError(f'No description for {self}.')

@staticmethod
def get(options: 'Values') -> str:
def get(options: 'Values') -> "RunMode":
"""Return the workflow run mode from the options."""
return getattr(options, 'run_mode', None) or RunMode.LIVE.value
run_mode = getattr(options, 'run_mode', None)
if run_mode:
return RunMode(run_mode)
return RunMode.LIVE

def get_submit_method(self) -> 'Optional[SubmissionInterface]':
"""Return the job submission method for this run mode.
Expand Down Expand Up @@ -113,9 +116,9 @@ def disable_task_event_handlers(itask: 'TaskProxy'):
"""
mode = itask.run_mode
return (
mode == RunMode.SIMULATION.value
mode == RunMode.SIMULATION
or (
mode == RunMode.SKIP.value
mode == RunMode.SKIP
and itask.platform.get(
'disable task event handlers', False)
)
Expand Down
3 changes: 2 additions & 1 deletion cylc/flow/run_modes/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,13 @@ def sim_time_check(
"""
now = time()
sim_task_state_changed: bool = False

for itask in itasks:
if (
itask.state.status != TASK_STATUS_RUNNING
or (
itask.run_mode
and itask.run_mode != RunMode.SIMULATION.value
and itask.run_mode != RunMode.SIMULATION
)
):
continue
Expand Down
5 changes: 4 additions & 1 deletion cylc/flow/run_modes/skip.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ def submit_task_job(
'execution retry delays': []
}
itask.summary['job_runner_name'] = RunMode.SKIP.value
itask.run_mode = RunMode.SKIP.value
itask.jobs.append(
task_job_mgr.get_simulation_job_conf(itask, _workflow)
)
itask.run_mode = RunMode.SKIP
task_job_mgr.workflow_db_mgr.put_insert_task_jobs(
itask, {
'time_submit': now[1],
Expand Down
15 changes: 9 additions & 6 deletions cylc/flow/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,10 @@ async def configure(self, params):
og_run_mode = self.get_run_mode()
if run_mode != og_run_mode:
raise InputError(
f'This workflow was originally run in {og_run_mode} mode:'
f' Will not restart in {run_mode} mode.')
"This workflow was originally run in "
f"{run_mode.value} mode:"
f" Will not restart in {run_mode.value} mode."
)

self.profiler.log_memory("scheduler.py: before load_flow_file")
try:
Expand Down Expand Up @@ -1195,7 +1197,7 @@ def run_event_handlers(self, event, reason=""):
Run workflow events only in live mode or skip mode.
"""
if self.get_run_mode() in WORKFLOW_ONLY_MODES:
if self.get_run_mode().value in WORKFLOW_ONLY_MODES:
return
self.workflow_event_handler.handle(self, event, str(reason))

Expand Down Expand Up @@ -1320,7 +1322,7 @@ def timeout_check(self):
"""Check workflow and task timers."""
self.check_workflow_timers()
# check submission and execution timeout and polling timers
if self.get_run_mode() != RunMode.SIMULATION.value:
if self.get_run_mode() != RunMode.SIMULATION:
self.task_job_mgr.check_task_jobs(self.workflow, self.pool)

async def workflow_shutdown(self):
Expand Down Expand Up @@ -1518,8 +1520,9 @@ async def _main_loop(self) -> None:
self.xtrigger_mgr.housekeep(self.pool.get_tasks())
self.pool.clock_expire_tasks()
self.release_queued_tasks()

if (
self.options.run_mode == RunMode.SIMULATION.value
self.get_run_mode() == RunMode.SIMULATION
and sim_time_check(
self.task_events_mgr,
self.pool.get_tasks(),
Expand Down Expand Up @@ -1979,7 +1982,7 @@ def _check_startup_opts(self) -> None:
f"option --{opt}=reload is only valid for restart"
)

def get_run_mode(self) -> str:
def get_run_mode(self) -> RunMode:
return RunMode.get(self.options)

async def handle_exception(self, exc: BaseException) -> NoReturn:
Expand Down
8 changes: 4 additions & 4 deletions cylc/flow/task_events_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ def process_message(

# ... but either way update the job ID in the job proxy (it only
# comes in via the submission message).
if itask.run_mode != RunMode.SIMULATION.value:
if itask.run_mode != RunMode.SIMULATION:
job_tokens = itask.tokens.duplicate(
job=str(itask.submit_num)
)
Expand Down Expand Up @@ -896,7 +896,7 @@ def _process_message_check(
if (
itask.state(TASK_STATUS_WAITING)
# Polling in live mode only:
and itask.run_mode == RunMode.LIVE.value
and itask.run_mode == RunMode.LIVE
and (
(
# task has a submit-retry lined up
Expand Down Expand Up @@ -1470,7 +1470,7 @@ def _process_message_submitted(
)

itask.set_summary_time('submitted', event_time)
if itask.run_mode == RunMode.SIMULATION.value:
if itask.run_mode == RunMode.SIMULATION:
# Simulate job started as well.
itask.set_summary_time('started', event_time)
if itask.state_reset(TASK_STATUS_RUNNING, forced=forced):
Expand Down Expand Up @@ -1507,7 +1507,7 @@ def _process_message_submitted(
'submitted',
event_time,
)
if itask.run_mode == RunMode.SIMULATION.value:
if itask.run_mode == RunMode.SIMULATION:
# Simulate job started as well.
self.data_store_mgr.delta_job_time(
job_tokens,
Expand Down
13 changes: 9 additions & 4 deletions cylc/flow/task_job_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def submit_task_jobs(
itasks,
curve_auth,
client_pub_key_dir,
run_mode: Union[str, RunMode] = RunMode.LIVE,
run_mode: RunMode = RunMode.LIVE,
):
"""Prepare for job submission and submit task jobs.
Expand Down Expand Up @@ -1028,7 +1028,7 @@ def submit_nonlive_task_jobs(
self: 'TaskJobManager',
workflow: str,
itasks: 'List[TaskProxy]',
workflow_run_mode: Union[str, RunMode],
workflow_run_mode: RunMode,
) -> 'Tuple[List[TaskProxy], List[TaskProxy]]':
"""Identify task mode and carry out alternative submission
paths if required:
Expand Down Expand Up @@ -1058,14 +1058,19 @@ def submit_nonlive_task_jobs(
# Get task config with broadcasts applied:
rtconfig = self.task_events_mgr.broadcast_mgr.get_updated_rtconfig(
itask)

# Apply task run mode
if workflow_run_mode in WORKFLOW_ONLY_MODES:
if workflow_run_mode.value in WORKFLOW_ONLY_MODES:
# Task run mode cannot override workflow run-mode sim or dummy:
run_mode = workflow_run_mode
else:
# If workflow mode is skip or live and task mode is set,
# override workflow mode, else use workflow mode.
run_mode = rtconfig.get('run mode', None) or workflow_run_mode
run_mode = rtconfig.get('run mode', None)
if run_mode:
run_mode = RunMode(run_mode)
else:
run_mode = workflow_run_mode

# Store the run mode of the this submission:
itask.run_mode = run_mode
Expand Down
7 changes: 3 additions & 4 deletions cylc/flow/task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,10 +1422,9 @@ def spawn_on_output(self, itask: TaskProxy, output: str) -> None:
tasks = [c_task]

for t in tasks:

t.satisfy_me(
[itask.tokens.duplicate(task_sel=output)],
mode=itask.run_mode
mode=itask.run_mode # type: ignore
)
self.data_store_mgr.delta_task_prerequisite(t)
if not in_pool:
Expand Down Expand Up @@ -1554,7 +1553,7 @@ def spawn_on_all_outputs(
if completed_only:
c_task.satisfy_me(
[itask.tokens.duplicate(task_sel=message)],
mode=itask.run_mode
mode=itask.run_mode # type: ignore
)
self.data_store_mgr.delta_task_prerequisite(c_task)
self.add_to_pool(c_task)
Expand Down Expand Up @@ -1979,7 +1978,7 @@ def _set_outputs_itask(
rtconfig = bc_mgr.get_updated_rtconfig(itask)
outputs.remove(RunMode.SKIP.value)
skips = get_skip_mode_outputs(itask, rtconfig)
itask.run_mode = RunMode.SKIP.value
itask.run_mode = RunMode.SKIP
outputs = self._standardise_outputs(
itask.point, itask.tdef, outputs)
outputs = list(set(outputs + skips))
Expand Down
5 changes: 2 additions & 3 deletions cylc/flow/task_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
Optional,
Set,
Tuple,
Union,
)

from metomi.isodatetime.timezone import get_local_time_zone
Expand Down Expand Up @@ -300,7 +299,7 @@ def __init__(
self.graph_children = generate_graph_children(tdef, self.point)

self.mode_settings: Optional['ModeSettings'] = None
self.run_mode: Optional[Union[str, RunMode]] = None
self.run_mode: Optional[RunMode] = None

if self.tdef.expiration_offset is not None:
self.expire_time = (
Expand Down Expand Up @@ -551,7 +550,7 @@ def state_reset(
return False

def satisfy_me(
self, task_messages: 'Iterable[Tokens]', mode=RunMode.LIVE.value
self, task_messages: 'Iterable[Tokens]', mode: "RunMode" = RunMode.LIVE
) -> 'Set[Tokens]':
"""Try to satisfy my prerequisites with given output messages.
Expand Down
3 changes: 2 additions & 1 deletion cylc/flow/task_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from cylc.flow.cycling import PointBase
from cylc.flow.id import Tokens
from cylc.flow.prerequisite import PrereqMessage
from cylc.flow.run_modes import RunMode
from cylc.flow.taskdef import TaskDef


Expand Down Expand Up @@ -324,7 +325,7 @@ def __call__(
def satisfy_me(
self,
outputs: Iterable['Tokens'],
mode,
mode: "RunMode",
) -> Set['Tokens']:
"""Try to satisfy my prerequisites with given outputs.
Expand Down
12 changes: 8 additions & 4 deletions cylc/flow/workflow_db_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,15 @@ def put_workflow_params(self, schd: 'Scheduler') -> None:
value = getattr(schd.options, key, None)
value = None if value == 'reload' else value
self.put_workflow_params_1(key, value)
for key in (

self.put_workflow_params_1(
self.KEY_CYCLE_POINT_TIME_ZONE,
getattr(schd.options, self.KEY_CYCLE_POINT_TIME_ZONE, None),
)
self.put_workflow_params_1(
self.KEY_RUN_MODE,
self.KEY_CYCLE_POINT_TIME_ZONE
):
self.put_workflow_params_1(key, getattr(schd.options, key, None))
schd.get_run_mode().value,
)

def put_workflow_params_1(
self, key: str, value: Union[AnyStr, float, None]
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from cylc.flow.option_parsers import Options
from cylc.flow.pathutil import get_cylc_run_dir
from cylc.flow.rundb import CylcWorkflowDAO
from cylc.flow.run_modes import RunMode
from cylc.flow.scripts.validate import ValidateOptions
from cylc.flow.scripts.install import (
install as cylc_install,
Expand Down Expand Up @@ -686,7 +687,7 @@ def capture_live_submissions(capcall, monkeypatch):
would have been submitted had this fixture not been used.
"""
def fake_submit(self, _workflow, itasks, *_):
self.submit_nonlive_task_jobs(_workflow, itasks, 'simulation')
self.submit_nonlive_task_jobs(_workflow, itasks, RunMode.SIMULATION)
for itask in itasks:
for status in (TASK_STATUS_SUBMITTED, TASK_STATUS_SUCCEEDED):
self.task_events_mgr.process_message(
Expand Down
Loading

0 comments on commit eccd7fe

Please sign in to comment.