Skip to content

Commit

Permalink
Merge pull request #5 from MagicCastle/suspendfail
Browse files Browse the repository at this point in the history
Add suspend_fail command
  • Loading branch information
cmd-ntrf authored Jan 13, 2025
2 parents 5f3534c + ec43862 commit 6261f07
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 83 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
python-version: ["3.6", "3.7", "3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.6.2"
python-hostlist = "^1.21"
filelock = "^3.16.1"
requests = ">=2"

[tool.poetry.dev-dependencies]
Expand All @@ -20,3 +21,4 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry.scripts]
slurm_suspend = "slurm_autoscale_tfe:suspend"
slurm_resume = "slurm_autoscale_tfe:resume"
slurm_resume_fail = "slurm_autoscale_tfe:resume_fail"
156 changes: 77 additions & 79 deletions src/slurm_autoscale_tfe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from subprocess import run, PIPE
from requests.exceptions import Timeout

from filelock import FileLock
from hostlist import expand_hostlist

from .tfe import TFECLient, InvalidAPIToken, InvalidWorkspaceId
Expand All @@ -24,6 +25,14 @@

NODE_STATE_REGEX = re.compile(r"^NodeName=([a-z0-9-]*).*State=([A-Z_+]*).*$")
DOWN_FLAG_SET = frozenset(["DOWN", "POWER_DOWN", "POWERED_DOWN", "POWERING_DOWN"])
INSTANCE_TYPES = frozenset(
[
"aws_instance",
"azurerm_linux_virtual_machine",
"google_compute_instance",
"openstack_compute_instance_v2",
]
)


class AutoscaleException(Exception):
Expand All @@ -34,6 +43,7 @@ class Commands(Enum):
"""Enumerate the name of script's commands"""

RESUME = "resume"
RESUME_FAIL = "resume_fail"
SUSPEND = "suspend"


def change_host_state(hostlist, state, reason=None):
    """Ask Slurm, via scontrol, to change the state of the nodes in hostlist.

    The state set here may differ from the state Slurm itself assigns after
    running ResumeProgram or SuspendProgram.

    Raises AutoscaleException when scontrol is missing or reports an error.
    """
    command = ["scontrol", "update", f"NodeName={hostlist}", f"state={state}"]
    if reason is not None:
        command.append(f"reason={reason}")
    try:
        scontrol_run = run(command, stdout=PIPE, stderr=PIPE, check=False)
    except FileNotFoundError as exc:
        raise AutoscaleException("Cannot find command scontrol") from exc
    # scontrol exits 0 on some failures; anything on stderr signals a problem.
    if scontrol_run.stderr:
        raise AutoscaleException(
            f"Error while calling scontrol update: {scontrol_run.stderr.decode()}"
        )


def resume(hostlist=sys.argv[-1]):
Expand Down Expand Up @@ -77,14 +94,26 @@ def suspend(hostlist=sys.argv[-1]):
return 0


def connect_tfe_client():
"""Return a TFE client object using environment variables for authentication
def resume_fail(hostlist=sys.argv[-1]):
    """Issue a request to Terraform cloud to power down the instances listed in
    hostlist.

    Returns 0 on success, 1 when the autoscale operation failed (the nodes are
    then marked DOWN in Slurm with the failure reason).
    """
    try:
        main(Commands.RESUME_FAIL, frozenset.difference, hostlist)
    except AutoscaleException as exc:
        error_text = str(exc)
        logging.error("Failed to resume_fail '%s': %s", hostlist, error_text)
        change_host_state(hostlist, "DOWN", reason=error_text)
        return 1
    return 0


def connect_tfe_client():
"""Return a TFE client object using environment variables for authentication"""
if "TFE_TOKEN" not in environ:
raise AutoscaleException(
f"{sys.argv[0]} requires environment variable TFE_TOKEN"
)
if environ.get("TFE_WORKSPACE", "") == "":
if "TFE_WORKSPACE" not in environ:
raise AutoscaleException(
f"{sys.argv[0]} requires environment variable TFE_WORKSPACE"
)
Expand All @@ -103,8 +132,7 @@ def connect_tfe_client():


def get_pool_from_tfe(tfe_client):
"""Retrieve id and content of POOL variable from Terraform cloud
"""
"""Retrieve id and content of POOL variable from Terraform cloud"""
try:
tfe_var = tfe_client.fetch_variable(POOL_VAR)
except Timeout as exc:
Expand All @@ -122,31 +150,19 @@ def get_pool_from_tfe(tfe_client):
return tfe_var["id"], frozenset()


def identify_online_nodes(tfe_pool):
"""Identify from a list of hosts which ones are online based on Slurm."""
def get_instances_from_tfe(tfe_client):
    """Return all names of instances that are created in Terraform Cloud state."""
    try:
        resources = tfe_client.fetch_resources()
    except Timeout as exc:
        raise AutoscaleException("Connection to Terraform cloud timeout (5s)") from exc

    names = []
    # Module address prefix shared by the instance resources, e.g. "module.foo";
    # stays None when no instance resource is present in the state.
    prefix = None
    for resource in resources:
        attributes = resource["attributes"]
        if attributes["provider-type"] in INSTANCE_TYPES:
            names.append(attributes["name-index"])
            prefix = attributes["address"].split("[")[0]
    return frozenset(names), prefix


def main(command, set_op, hostlist):
Expand All @@ -156,51 +172,33 @@ def main(command, set_op, hostlist):
"""
hosts = frozenset(expand_hostlist(hostlist))
tfe_client = connect_tfe_client()
var_id, tfe_pool = get_pool_from_tfe(tfe_client)

# Verify that TFE pool corresponds to Slurm pool:
# When a powered up node fail to respond after slurm.conf's ResumeTimeout
# slurmctld marks the node as "DOWN", but it will not call the SuspendProgram
# on the node. Therefore, a change drift can happen between Slurm internal memory
# of what nodes are online and the Terraform Cloud pool variable. To limit the
# drift effect, we validate the state in Slurm of each node present in Terraform Cloud
# pool variable. We only keep the nodes that are present in Slurm.
slurm_pool = identify_online_nodes(tfe_pool)
zombie_nodes = tfe_pool - slurm_pool - hosts
extra_command = ""
if len(zombie_nodes) > 0:
zombie_nodes_string = ",".join(sorted(zombie_nodes))
logging.warning(
"TFE vs Slurm drift detected, these nodes will be suspended: %s",
zombie_nodes_string,
)
extra_command = f" & suspend {zombie_nodes_string} (drift detection)"

new_pool = set_op(slurm_pool, hosts)

if tfe_pool != new_pool:
try:
tfe_client.update_variable(var_id, list(new_pool))
except Timeout as exc:
raise AutoscaleException(
"Connection to Terraform cloud timeout (5s)"
) from exc
else:
logging.warning(
'TFE pool was already correctly set when "%s %s" was issued',
command.value,
hostlist,
)

with FileLock("/tmp/slurm_autoscale_tfe_pool.lock"):
var_id, tfe_pool = get_pool_from_tfe(tfe_client)
next_pool = set_op(tfe_pool, hosts)
if tfe_pool != next_pool:
try:
tfe_client.update_variable(var_id, list(next_pool))
except Timeout as exc:
raise AutoscaleException(
"Connection to Terraform cloud timeout (5s)"
) from exc
else:
logging.warning(
'TFE pool variable is unchanged following the issue of "%s %s"',
command.value,
hostlist,
)

_, address_prefix = get_instances_from_tfe(tfe_client)
try:
tfe_client.apply(f"Slurm {command.value} {hostlist} {extra_command}".strip())
run_id = tfe_client.apply(
f"Slurm {command.value} {hostlist}".strip(),
targets=[f'module.{address_prefix}["{hostname}"]' for hostname in hosts],
)
except Timeout as exc:
raise AutoscaleException("Connection to Terraform cloud timeout (5s)") from exc
logging.info("%s %s", command.value, hostlist)

logging.info("%s %s (%s)", command.value, hostlist, run_id)

if __name__ == "__main__":
if sys.argv[1] == Commands.RESUME.value:
sys.exit(resume())
elif sys.argv[1] == Commands.SUSPEND.value:
sys.exit(suspend())
if command == Commands.RESUME_FAIL:
change_host_state(hostlist, "IDLE")
29 changes: 26 additions & 3 deletions src/slurm_autoscale_tfe/tfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ def fetch_variable(self, var_name):
}
return None

def fetch_resources(self):
    """Get all resources from the workspace, following pagination links.

    Returns the concatenated list of resource objects from every page of
    the Terraform Cloud workspace "resources" endpoint.
    """
    url = "/".join((WORKSPACE_API, self.workspace, "resources"))
    resources = []
    while url is not None:
        resp = requests.get(url, headers=self.headers, timeout=self.timeout)
        json_ = resp.json()
        resources.extend(json_["data"])
        # Fix: "links"/"next" may be absent from a response; indexing
        # json_["links"]["next"] directly raised KeyError in that case.
        # Treat a missing link as the last page.
        url = json_.get("links", {}).get("next")
    return resources

def update_variable(self, var_id, value):
"""Update a workspace variable content"""
patch_data = {
Expand All @@ -74,16 +86,27 @@ def update_variable(self, var_id, value):
url, headers=self.headers, json=patch_data, timeout=self.timeout
)

def apply(self, message, targets):
    """Queue a workspace run

    Posts an auto-applied run restricted to ``targets`` (resource
    addresses) with ``message`` as the run message, and returns the id of
    the created run.
    """
    payload = {
        "data": {
            "attributes": {
                "message": message,
                "target-addrs": targets,
                "auto-apply": True,
            },
            "relationships": {
                "workspace": {"data": {"type": "workspaces", "id": self.workspace}},
            },
        }
    }
    response = requests.post(
        RUNS_API, headers=self.headers, json=payload, timeout=self.timeout
    )
    return response.json()["data"]["id"]

def get_run_status(self, run_id):
    """Return status of run

    Fetches the run identified by ``run_id`` from the Terraform Cloud runs
    API and returns its status attribute (a string such as reported by the
    API).
    """
    resp = requests.get(
        "/".join((RUNS_API, run_id)), headers=self.headers, timeout=self.timeout
    )
    return resp.json()["data"]["attributes"]["status"]

0 comments on commit 6261f07

Please sign in to comment.