diff --git a/.gitignore b/.gitignore index 119b3cc..e573469 100755 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ dist build scripts/ logs/ +shortcuts.txt diff --git a/setup.py b/setup.py index 7cbaca7..bf8f299 100755 --- a/setup.py +++ b/setup.py @@ -36,3 +36,4 @@ def get_version(path): 'Programming Language :: Python :: 3' ], ) + diff --git a/slurmpy/slurmpy.py b/slurmpy/slurmpy.py index 966e9b0..56a612f 100755 --- a/slurmpy/slurmpy.py +++ b/slurmpy/slurmpy.py @@ -7,59 +7,57 @@ #SBATCH -e logs/job-name.%J.err #SBATCH -o logs/job-name.%J.out #SBATCH -J job-name +#SBATCH --no-requeue -#SBATCH --account=ucgd-kp -#SBATCH --partition=ucgd-kp +#SBATCH -J -Q +#SBATCH --no-requeue +#SBATCH --partition=ucgd #SBATCH --time=84:00:00 set -eo pipefail -o nounset -__script__ - ->>> s = Slurm("job-name", {"account": "ucgd-kp", "partition": "ucgd-kp"}, bash_strict=False) ->>> print(str(s)) -#!/bin/bash - -#SBATCH -e logs/job-name.%J.err -#SBATCH -o logs/job-name.%J.out -#SBATCH -J job-name - -#SBATCH --account=ucgd-kp -#SBATCH --partition=ucgd-kp -#SBATCH --time=84:00:00 - - +hostname +date +touch job-name.started __script__ - - ->>> job_id = s.run("rm -f aaa; sleep 10; echo 213 > aaa", name_addition="", tries=1) - ->>> job = s.run("cat aaa; rm aaa", name_addition="", tries=1, depends_on=[job_id]) - +touch job-name.finished + +>>> job_id = s.run("rm -f aaa; sleep 10; echo 213 > aaa", name_addition="") +>>> job = s.run("cat aaa; rm aaa", name_addition="", depends_on=[job_id]) +>>> s.query(job, "JobState") +'PENDING' """ + from __future__ import print_function -import sys +import atexit +import datetime +import hashlib import os import subprocess +import sys import tempfile -import atexit -import hashlib -import datetime +import time TMPL = """\ #!/bin/bash -#SBATCH -e {log_dir}/{name}.%J.err -#SBATCH -o {log_dir}/{name}.%J.out +#SBATCH -e logs/{name}.%J.err +#SBATCH -o logs/{name}.%J.out #SBATCH -J {name} +#SBATCH --no-requeue {header} -{bash_setup} +set -eo pipefail -o 
nounset -__script__""" +hostname +date +touch job-name.started +__script__ +touch job-name.finished +""" def tmp(suffix=".sh"): @@ -68,36 +66,34 @@ def tmp(suffix=".sh"): return t +class SlurmException(Exception): + def __init__(self, *args, **kwargs): + super(SlurmException, self).__init__(*args, **kwargs) + pass + + class Slurm(object): def __init__(self, name, slurm_kwargs=None, tmpl=None, - date_in_name=True, scripts_dir="slurm-scripts", - log_dir='logs', bash_strict=True): + date_in_name=True, scripts_dir="slurm-scripts/"): if slurm_kwargs is None: slurm_kwargs = {} if tmpl is None: tmpl = TMPL - self.log_dir = log_dir - self.bash_strict = bash_strict header = [] - if 'time' not in slurm_kwargs.keys(): - slurm_kwargs['time'] = '84:00:00' - for k, v in slurm_kwargs.items(): - if len(k) > 1: - k = "--" + k + "=" - else: - k = "-" + k + " " - header.append("#SBATCH %s%s" % (k, v)) - - # add bash setup list to collect bash script config - bash_setup = [] - if bash_strict: - bash_setup.append("set -eo pipefail -o nounset") + for k, v in sorted(slurm_kwargs.items()): # sort only for the doctest purpose + is_long_option = len(k) > 1 + if is_long_option: + k = "--" + k + if v is not None: + k += "=" + else: + k = "-" + k + " " + header.append("#SBATCH %s%s" % (k, v if v is not None else "")) self.header = "\n".join(header) - self.bash_setup = "\n".join(bash_setup) - self.name = "".join(x for x in name.replace( - " ", "-") if x.isalnum() or x == "-") + self.name = "".join(x for x in name.replace(" ", "_") if x.isalnum() or x in ("-", "_")) + self.name = name self.tmpl = tmpl self.slurm_kwargs = slurm_kwargs if scripts_dir is not None: @@ -107,46 +103,47 @@ def __init__(self, name, slurm_kwargs=None, tmpl=None, self.date_in_name = bool(date_in_name) def __str__(self): - return self.tmpl.format(name=self.name, header=self.header, - log_dir=self.log_dir, - bash_setup=self.bash_setup) + return self.tmpl.format(name=self.name, header=self.header) - def _tmpfile(self): + def 
_get_scriptname(self, name_addition=None): if self.scripts_dir is None: return tmp() else: - for _dir in [self.scripts_dir, self.log_dir]: - if not os.path.exists(_dir): - os.makedirs(_dir) - return "%s/%s.sh" % (self.scripts_dir, self.name) + if not os.path.exists(self.scripts_dir): + os.makedirs(self.scripts_dir) - def run(self, command, name_addition=None, cmd_kwargs=None, - _cmd="sbatch", tries=1, depends_on=None): + script_name = self.name.strip("-") + if name_addition: + script_name += name_addition.strip(" -") + return "%s/%s.sh" % (self.scripts_dir, script_name) + + def run(self, command, name_addition=None, cmd_kwargs=None, local=False, depends_on=None, log_file=None, + after=None): """ command: a bash command that you want to run - name_addition: if not specified, the sha1 of the command to run - appended to job name. if it is "date", the yyyy-mm-dd - date will be added to the job name. - cmd_kwargs: dict of extra arguments to fill in command - (so command itself can be a template). - _cmd: submit command (change to "bash" for testing). - tries: try to run a job either this many times or until the first - success. - depends_on: job ids that this depends on before it is run (users 'afterok') + name_addition: if not specified, the sha1 of the command to run + appended to job name. if it is "date", the yyyy-mm-dd + date will be added to the job name. + cmd_kwargs: dict of extra arguments to fill in command + (so command itself can be a template). + local: if True, run locally in the background (for testing). 
Returns "pid#<pid>" + depends_on: job ids that this depends on before it is run (uses 'afterok') + after: job ids that must have started before this one is run (uses 'after') """ + if name_addition is None: - name_addition = hashlib.sha1(command.encode("utf-8")).hexdigest() + # name_addition = hashlib.sha1(command.encode("utf-8")).hexdigest() + name_addition = '' if self.date_in_name: - name_addition += "-" + str(datetime.date.today()) + name_addition += "-" + datetime.datetime.fromtimestamp(time.time()).isoformat() name_addition = name_addition.strip(" -") + script_name = self._get_scriptname(name_addition) + if cmd_kwargs is None: cmd_kwargs = {} - n = self.name - self.name = self.name.strip(" -") - self.name += ("-" + name_addition.strip(" -")) args = [] for k, v in cmd_kwargs.items(): args.append("export %s=%s" % (k, v)) @@ -156,29 +153,103 @@ def run(self, command, name_addition=None, cmd_kwargs=None, if depends_on is None or (len(depends_on) == 1 and depends_on[0] is None): depends_on = [] - with open(self._tmpfile(), "w") as sh: + if after is None or (len(after) == 1 and after[0] is None): + after = [] + + if "logs/" in tmpl and not os.path.exists("logs/"): + os.makedirs("logs") + + with open(script_name, "w") as sh: sh.write(tmpl) - job_id = None - for itry in range(1, tries + 1): - args = [_cmd] - args.extend([("--dependency=afterok:%d" % int(d)) - for d in depends_on]) - if itry > 1: - mid = "--dependency=afternotok:%d" % job_id - args.append(mid) - args.append(sh.name) - res = subprocess.check_output(args).strip() - print(res, file=sys.stderr) - self.name = n - if not res.startswith(b"Submitted batch"): + log_file = log_file if log_file else sys.stderr + + _cmd = 'bash' if local else 'sbatch' + args = [_cmd] + if not local: + args.extend([("--dependency=afterok:%d" % int(d)) for d in depends_on]) + args.extend([("--dependency=after:%d" % int(d)) for d in after]) + args.append(sh.name) + if not local: + # res = subprocess.check_output(args, 
text=True).strip() + res = subprocess.check_output(args, universal_newlines=True).strip() + + print(res, file=log_file) + if not res.startswith("Submitted batch"): + return None + job_id = res.split()[-1] + return job_id + else: + pid = subprocess.Popen(args, stdout=log_file, stderr=log_file) + return "pid#" + str(pid.pid) + + @staticmethod + def query(job_id, field=None, on_failure='exception'): + try: + ret = subprocess.check_output(["scontrol", "-d", "-o", "show", "job", str(job_id)], + universal_newlines=True, stderr=subprocess.STDOUT) + + except: + if on_failure == 'warn': + print("warning: scontrol query of job_id=%s failed" % str(job_id)) return None - j_id = int(res.split()[-1]) - if itry == 1: - job_id = j_id - return job_id + elif on_failure == 'silent': + return None + else: + raise SlurmException("Failed to query SLURM") + + try: + ret_dict = {pair[0]: "=".join(pair[1:]) for pair in [_.split("=") for _ in ret.split()]} + except: + print('ret_dict failed for job_id=' + str(job_id) + ', ret=' + str(ret)) + raise SlurmException("Failed to create ret_dict") + + if field is not None: + return ret_dict[field] + return ret_dict + + @staticmethod + def _still_running_pid(pid): + try: + with open('/dev/null', 'w') as devnull: + subprocess.check_call(['ps', '-p', str(pid)], stdout=devnull, stderr=devnull) + except subprocess.CalledProcessError: + return False + return True + + @staticmethod + def _still_running_jobid(job_id): + status = Slurm.query(job_id, field='JobState', on_failure='silent') + if status in ('PENDING', 'RUNNING', 'SUSPENDED', 'CONFIGURING'): + return True + return False + + @staticmethod + def still_running(job_id): + if job_id is None: + return False + job_id = str(job_id) + if job_id.startswith('pid#'): + pid = job_id[4:] + return Slurm._still_running_pid(pid) + else: + return Slurm._still_running_jobid(job_id) + + @staticmethod + def kill(job_id): + if job_id is None: + return False + job_id = str(job_id) + if job_id.startswith('pid#'): + 
pid = job_id[4:] + os.system('kill -9 ' + pid) + return True + else: + os.system('scancel ' + job_id) + return True if __name__ == "__main__": import doctest + doctest.testmod() diff --git a/test/TestSlurmpy.py b/test/TestSlurmpy.py new file mode 100644 index 0000000..02b1c33 --- /dev/null +++ b/test/TestSlurmpy.py @@ -0,0 +1,58 @@ + +import time +import unittest + +from slurmpy.slurmpy import Slurm, SlurmException + +class TestSlurmpy(unittest.TestCase): + def setUp(self): + self.slurm_queue_args = {"time": "00:00:15", 'no-requeue': None, "Q": None} + + def test_silent_query_of_nonexistent_job(self): + ret = Slurm.query('101', on_failure='silent') + self.assertIsNone(ret) + + def test_exception_query_of_nonexistent_job(self): + with self.assertRaises(SlurmException): + Slurm.query('101') + + def test_sending_to_queue(self): + s = Slurm("job-name", self.slurm_queue_args) + job_id = s.run('sleep 5') + self.assertTrue(s.still_running(job_id)) + + def test_sending_local(self): + s = Slurm("job-name") + job_id = s.run('sleep 5', local=True) + self.assertTrue(s.still_running(job_id)) + time.sleep(7) + self.assertFalse(s.still_running(job_id)) + + def test_multiple_local_sends(self): + s = Slurm("job-name") + ids = [] + for i in range(5): + ids.append(s.run('sleep 5', local=True)) + + def test_multiple_queue_sends(self): + s = Slurm("job-name") + ids = [] + for i in range(5): + ids.append(s.run('sleep 5')) + + def test_kill_local(self): + s = Slurm("job-name") + job_id = s.run('sleep 10', local=True) + self.assertTrue(s.still_running(job_id)) + s.kill(job_id) + self.assertFalse(s.still_running(job_id)) + + def test_kill_queue(self): + s = Slurm("job-name", self.slurm_queue_args) + job_id = s.run('sleep 10', local=True) + self.assertTrue(s.still_running(job_id)) + s.kill(job_id) + self.assertFalse(s.still_running(job_id)) + +if __name__ == '__main__': + unittest.main() diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29