Support limiting the number of states running in parallel
lkubb committed Oct 10, 2024
1 parent b65d8a1 commit 32ffb61
Showing 5 changed files with 132 additions and 35 deletions.
1 change: 1 addition & 0 deletions changelog/49301.added.md
@@ -0,0 +1 @@
Added support for limiting the number of parallel states executing at the same time via `state_max_parallel`.
10 changes: 10 additions & 0 deletions doc/ref/configuration/minion.rst
@@ -2331,6 +2331,16 @@ performance is hampered.
state_queue: 2
.. conf_minion:: state_max_parallel

``state_max_parallel``
----------------------

Default: ``0``

Limit the number of ``parallel: true`` states that can be running at the same time.
By default, there is no limit.
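For example, to allow at most two ``parallel`` states to run at once:

.. code-block:: yaml

    state_max_parallel: 2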

.. conf_minion:: state_verbose

``state_verbose``
4 changes: 4 additions & 0 deletions salt/config/__init__.py
@@ -402,6 +402,8 @@ def _gather_buffer_space():
"state_auto_order": bool,
# Fire events as state chunks are processed by the state compiler
"state_events": bool,
# Limit the number of states that can be running in parallel
"state_max_parallel": int,
# The number of seconds a minion should wait before retry when attempting authentication
"acceptance_wait_time": float,
# The number of seconds a minion should wait before giving up during authentication
@@ -1218,6 +1220,7 @@ def _gather_buffer_space():
"state_events": False,
"state_aggregate": False,
"state_queue": False,
"state_max_parallel": 0,
"snapper_states": False,
"snapper_states_config": "root",
"acceptance_wait_time": 10,
@@ -1557,6 +1560,7 @@ def _gather_buffer_space():
"state_auto_order": True,
"state_events": False,
"state_aggregate": False,
"state_max_parallel": 0,
"search": "",
"loop_interval": 60,
"nodegroups": {},
100 changes: 65 additions & 35 deletions salt/state.py
@@ -2002,7 +2002,9 @@ def _call_parallel_target(cls, instance, init_kwargs, name, cdata, low):
with salt.utils.files.fopen(tfile, "wb+") as fp_:
fp_.write(msgpack_serialize(ret))

def call_parallel(self, cdata: dict[str, Any], low: LowChunk):
def call_parallel(
self, cdata: dict[str, Any], low: LowChunk, running: dict[str, dict]
):
"""
Call the state defined in the given cdata in parallel
"""
@@ -2025,13 +2027,19 @@ def call_parallel(self, cdata: dict[str, Any], low: LowChunk):
args=(instance, self._init_kwargs, name, cdata, low),
name=f"ParallelState({name})",
)
proc.start()
if "__procs__" not in running:
running["__procs__"] = {}
running["__procs__"][_gen_tag(low)] = proc
if self.check_max_parallel(running):
proc.start()
comment = "Started in a separate process"
else:
comment = "Waiting to be started in a separate process, max_parallel hit"
ret = {
"name": name,
"result": None,
"changes": {},
"comment": "Started in a separate process",
"proc": proc,
"comment": comment,
}
return ret
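The pattern above registers each process under a private "__procs__" key of the shared running dict, keyed by the low chunk's tag, and only calls proc.start() when check_max_parallel reports a free slot. A minimal standalone sketch of that gating, using illustrative names (submit, max_parallel) rather than Salt's actual API:

import multiprocessing

def submit(running: dict, tag: str, target, max_parallel: int) -> str:
    """Register a process under __procs__; start it only if a slot is free."""
    procs = running.setdefault("__procs__", {})
    proc = procs[tag] = multiprocessing.Process(
        target=target, name=f"ParallelState({tag})"
    )
    # Only processes that were actually started and are still alive count;
    # queued (never-started) processes report ident == None.
    active = sum(
        int(p.ident is not None and p.is_alive()) for p in procs.values()
    )
    if not max_parallel or active < max_parallel:
        proc.start()
        return "Started in a separate process"
    return "Waiting to be started in a separate process, max_parallel hit"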

@@ -2177,7 +2185,9 @@ def call(
)
elif not low.get("__prereq__") and low.get("parallel"):
# run the state call in parallel, but only if not in a prereq
ret = self.call_parallel(cdata, low)
ret = self.call_parallel(
cdata, low, running if running is not None else {}
)
else:
self.format_slots(cdata)
with salt.utils.files.set_umask(low.get("__umask__")):
@@ -2498,6 +2508,9 @@ def _call_pending(
if "__FAILHARD__" in running:
running.pop("__FAILHARD__")
return running
# Start any states that were queued because state_max_parallel had been hit
self.reconcile_procs(running)

tag = _gen_tag(low)
if tag not in running:
# Check if this low chunk is paused
@@ -2518,6 +2531,7 @@ if self.reconcile_procs(running):
if self.reconcile_procs(running):
break
time.sleep(0.01)
running.pop("__procs__", None)
ret = {**disabled, **running}
return ret

@@ -2581,41 +2595,57 @@ def check_pause(self, low: LowChunk) -> Optional[str]:
return "run"
return "run"

def check_max_parallel(self, running: dict) -> bool:
"""
Check whether an additional ``parallel`` state can be started.
"""
if not (allowed := self.opts.get("state_max_parallel")):
return True
cnt = sum(
int(proc.ident is not None and proc.is_alive())
for proc in running.get("__procs__", {}).values()
)
return cnt < allowed
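The proc.ident is not None guard is what keeps queued processes from counting toward the cap: a multiprocessing.Process that has been constructed but never started reports ident == None and is_alive() == False. A quick self-contained demonstration of those semantics (not Salt code):

import multiprocessing

if __name__ == "__main__":
    p = multiprocessing.Process(target=sum, args=([1, 2, 3],))
    assert p.ident is None and not p.is_alive()  # constructed, still queued
    p.start()
    assert p.ident is not None  # has a PID now; counts while alive
    p.join()
    assert not p.is_alive()  # finished; frees a slot for a queued process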

def reconcile_procs(self, running: dict) -> bool:
"""
Check the running dict for processes and resolve them
"""
retset = set()
for tag in running:
proc = running[tag].get("proc")
if proc:
if not proc.is_alive():
ret_cache = os.path.join(
self.opts["cachedir"],
self.jid,
salt.utils.hashutils.sha1_digest(tag),
)
if not os.path.isfile(ret_cache):
ret = {
"result": False,
"comment": "Parallel process failed to return",
"name": running[tag]["name"],
"changes": {},
}
try:
with salt.utils.files.fopen(ret_cache, "rb") as fp_:
ret = msgpack_deserialize(fp_.read())
except OSError:
ret = {
"result": False,
"comment": "Parallel cache failure",
"name": running[tag]["name"],
"changes": {},
}
running[tag].update(ret)
running[tag].pop("proc")
else:
retset.add(False)
# Iterate over a copy of the keys, since entries are popped from the dict below
for tag in list(running.get("__procs__", {})):
proc = running["__procs__"][tag]
if proc.ident is None:
if self.check_max_parallel(running):
proc.start()
retset.add(False)
elif not proc.is_alive():
ret_cache = os.path.join(
self.opts["cachedir"],
self.jid,
salt.utils.hashutils.sha1_digest(tag),
)
if not os.path.isfile(ret_cache):
ret = {
"result": False,
"comment": "Parallel process failed to return",
"name": running[tag]["name"],
"changes": {},
}
try:
with salt.utils.files.fopen(ret_cache, "rb") as fp_:
ret = msgpack_deserialize(fp_.read())
except OSError:
ret = {
"result": False,
"comment": "Parallel cache failure",
"name": running[tag]["name"],
"changes": {},
}
running[tag].update(ret)
running["__procs__"].pop(tag)
else:
retset.add(False)
return False not in retset
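With this change, reconcile_procs serves two purposes: it merges the msgpack-cached return of each finished process back into running, and it starts queued processes whenever check_max_parallel reports a free slot. It returns True only once "__procs__" holds no unresolved process, which is the condition the polling loop shown earlier waits on. A condensed sketch of that contract (illustrative wrapper, not Salt's API):

import time

def wait_for_parallel(running: dict, reconcile) -> None:
    # Poll until every parallel process has been started, has finished,
    # and has had its return merged back into `running`.
    while running.get("__procs__"):
        if reconcile(running):
            break
        time.sleep(0.01)
    # Drop the internal bookkeeping key before results are returned.
    running.pop("__procs__", None)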

def _check_requisites(self, low: LowChunk, running: dict[str, dict[str, Any]]):
52 changes: 52 additions & 0 deletions tests/pytests/functional/modules/state/test_parallel.py
@@ -0,0 +1,52 @@
import datetime

import pytest


@pytest.fixture(scope="module")
def minion_config_overrides():
return {"state_max_parallel": 2}


@pytest.mark.skip_on_windows
def test_max_parallel(state, state_tree):
"""
Ensure the number of running ``parallel`` states can be limited.
"""
sls_contents = """
service_a:
cmd.run:
- name: sleep 3
- parallel: True
service_b:
cmd.run:
- name: sleep 3
- parallel: True
service_c:
cmd.run:
- name: 'true'
- parallel: True
"""

with pytest.helpers.temp_file("state_max_parallel.sls", sls_contents, state_tree):
ret = state.sls(
"state_max_parallel",
__pub_jid="1",  # Because these run in parallel we need a fake JID
)
start_a = datetime.datetime.combine(
datetime.date.today(),
datetime.time.fromisoformat(
ret["cmd_|-service_a_|-sleep 3_|-run"]["start_time"]
),
)
start_c = datetime.datetime.combine(
datetime.date.today(),
datetime.time.fromisoformat(
ret["cmd_|-service_c_|-true_|-run"]["start_time"]
),
)
start_diff = start_c - start_a
# With state_max_parallel=2, a and b occupy both slots, so c must wait for one to finish
assert start_diff > datetime.timedelta(seconds=3)
