Skip to content

Commit

Permalink
Add pause/resume/context to workers (#101)
Browse files Browse the repository at this point in the history
* Add pause/resume/context to workers

- Allows a user to start/stop processes at will, via OS signals SIGSTOP and SIGCONT.
- Allows a user to bind processes to specific CPUs.
- Allows local_worker_pool to be used outside of a context manager
- Switch workers to be Protocol based, so Workers are effectively duck-typed (i.e. anything that has the required methods passes as a Worker)

Part of #96
  • Loading branch information
Northbadge authored Aug 18, 2022
1 parent 45d1e2d commit 16eb08b
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 9 deletions.
23 changes: 22 additions & 1 deletion compiler_opt/distributed/local/local_worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import dataclasses
import functools
import multiprocessing
import psutil
import threading

from absl import logging
Expand All @@ -39,7 +40,7 @@

from contextlib import AbstractContextManager
from multiprocessing import connection
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -214,6 +215,18 @@ def shutdown(self):
except: # pylint: disable=bare-except
pass

def set_nice(self, val: int):
    """Change the scheduling priority (niceness) of the worker process.

    The value is handed to psutil unchanged. It is treated as an integer
    niceness, which is the Unix convention, so this is only expected to work
    on Unix.

    Args:
      val: the niceness to apply to the worker process.
    """
    worker_process = psutil.Process(self._process.pid)
    worker_process.nice(val)

def set_affinity(self, val: List[int]):
    """Restrict the worker process to a given set of CPU cores.

    Args:
      val: the CPU indices the OS may schedule the worker process on.
    """
    worker_process = psutil.Process(self._process.pid)
    worker_process.cpu_affinity(val)

def join(self):
self._observer.join()
self._pump.join()
Expand Down Expand Up @@ -247,3 +260,11 @@ def __exit__(self, *args):
# now wait for the message pumps to indicate they exit.
for s in self._stubs:
s.join()

# Finalizer: runs the same cleanup as __exit__ so that a pool used without a
# `with` block is still torn down when the object is collected.
# NOTE(review): when the pool *was* used as a context manager, __exit__ has
# already run by the time this fires — confirm __exit__ is safe to call twice.
def __del__(self):
self.__exit__()

@property
def stubs(self):
    """The worker stubs managed by this pool, as a new list.

    A shallow copy is returned so callers cannot add to, remove from, or
    reorder the pool's internal list.
    """
    return [*self._stubs]
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Test for local worker manager."""

import concurrent.futures
import multiprocessing
import time

from absl.testing import absltest
Expand Down
4 changes: 2 additions & 2 deletions compiler_opt/distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# limitations under the License.
"""Common abstraction for a worker contract."""

from typing import Iterable, Optional, TypeVar, Protocol
from typing import Iterable, Optional, Protocol, TypeVar


class Worker:
class Worker(Protocol):

@classmethod
def is_priority_method(cls, method_name: str) -> bool:
Expand Down
40 changes: 37 additions & 3 deletions compiler_opt/rl/compilation_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import dataclasses
import json
import os
import signal
import subprocess
import threading
from typing import Dict, List, Optional, Tuple
Expand Down Expand Up @@ -108,6 +109,7 @@ def __init__(self):
# empty() is accurate and get() never blocks.
self._processes = set()
self._done = False
self._paused = False
self._lock = threading.Lock()

def enable(self):
Expand All @@ -129,10 +131,33 @@ def kill_all_processes(self):
for p in self._processes:
kill_process_ignore_exceptions(p)

def pause_all_processes(self):
    """Suspend every registered process by sending it SIGSTOP.

    Does nothing when the manager is already paused. A process that has
    already exited (but has not been unregistered yet) is skipped, so one
    stale entry cannot prevent the remaining processes from being paused.
    """
    with self._lock:
        if self._paused:
            return
        self._paused = True

        # Iterate while holding the lock: unregister_process mutates the set
        # under the same lock, and a set must not change size mid-iteration.
        for p in self._processes:
            try:
                # SIGSTOP only suspends the process; it does not terminate it.
                os.kill(p.pid, signal.SIGSTOP)
            except ProcessLookupError:
                # The process is gone already, so there is nothing to pause.
                continue

def resume_all_processes(self):
    """Resume every registered process by sending it SIGCONT.

    Does nothing when the manager is not currently paused. A process that has
    already exited (but has not been unregistered yet) is skipped, so one
    stale entry cannot leave the remaining processes stopped.
    """
    with self._lock:
        if not self._paused:
            return
        self._paused = False

        # Iterate while holding the lock: unregister_process mutates the set
        # under the same lock, and a set must not change size mid-iteration.
        for p in self._processes:
            try:
                # SIGCONT lets a stopped process run again; nothing is killed.
                os.kill(p.pid, signal.SIGCONT)
            except ProcessLookupError:
                # The process is gone already, so there is nothing to resume.
                continue

def unregister_process(self, p: 'subprocess.Popen[bytes]'):
    """Stop tracking `p`.

    The process is removed exactly once and unconditionally, including after
    the manager has been marked done, because the manager is expected to be
    empty by the time it is deleted.

    Args:
      p: a process previously registered with this manager.

    Raises:
      KeyError: if `p` is not currently registered.
    """
    with self._lock:
        self._processes.remove(p)

# Finalizer sanity check: every registered process is expected to have been
# unregistered before the manager is garbage-collected.
# NOTE(review): Python does not propagate exceptions raised from __del__; it
# reports them on stderr and carries on, so this RuntimeError acts as a
# diagnostic rather than a hard failure — confirm that is the intent.
def __del__(self):
if len(self._processes) > 0:
raise RuntimeError('Cancellation manager deleted while containing items.')


def start_cancellable_process(
Expand Down Expand Up @@ -174,6 +199,7 @@ def start_cancellable_process(
finally:
if cancellation_manager:
cancellation_manager.unregister_process(p)

if retcode != 0:
raise ProcessKilledError(
) if retcode == -9 else subprocess.CalledProcessError(retcode, cmdline)
Expand Down Expand Up @@ -249,7 +275,9 @@ class CompilationRunner(Worker):

@classmethod
def is_priority_method(cls, method_name: str) -> bool:
    """Tell whether `method_name` names a priority method of this worker.

    Cancelling, pausing and resuming work are the priority methods.

    Args:
      method_name: name of the worker method being invoked.

    Returns:
      True for 'cancel_all_work', 'pause_all_work' and 'resume_all_work',
      False for anything else.
    """
    return method_name in {
        'cancel_all_work', 'pause_all_work', 'resume_all_work'
    }

def __init__(self,
clang_path: Optional[str] = None,
Expand All @@ -275,6 +303,12 @@ def enable(self):
def cancel_all_work(self):
    """Kill every process tracked by this runner's cancellation manager."""
    manager = self._cancellation_manager
    manager.kill_all_processes()

def pause_all_work(self):
    """Pause every process tracked by this runner's cancellation manager."""
    manager = self._cancellation_manager
    manager.pause_all_processes()

def resume_all_work(self):
    """Resume every process tracked by this runner's cancellation manager."""
    manager = self._cancellation_manager
    manager.resume_all_processes()

def collect_data(
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
Expand Down
22 changes: 20 additions & 2 deletions compiler_opt/rl/compilation_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import string
import subprocess
import threading
import time
from unittest import mock

Expand Down Expand Up @@ -213,9 +214,9 @@ def test_exception_handling(self, mock_compile_fn):
self.assertEqual(1, mock_compile_fn.call_count)

def test_start_subprocess_output(self):
ct = compilation_runner.WorkerCancellationManager()
cm = compilation_runner.WorkerCancellationManager()
output = compilation_runner.start_cancellable_process(
['ls', '-l'], timeout=100, cancellation_manager=ct, want_output=True)
['ls', '-l'], timeout=100, cancellation_manager=cm, want_output=True)
if output:
output_str = output.decode('utf-8')
else:
Expand All @@ -235,6 +236,23 @@ def test_timeout_kills_process(self):
time.sleep(2)
self.assertFalse(os.path.exists(sentinel_file))

def test_pause_resume(self):
    """Pausing the manager must delay a running child process.

    A helper thread pauses all processes 0.25s in and resumes them a second
    later, while the main thread runs `sleep 0.5` through the manager. If the
    pause takes effect, the child cannot finish in under one second.
    """
    cm = compilation_runner.WorkerCancellationManager()
    # A monotonic clock keeps the measurement immune to wall-clock changes.
    start_time = time.monotonic()

    def stop_and_start():
        time.sleep(0.25)
        cm.pause_all_processes()
        time.sleep(1)
        cm.resume_all_processes()

    pauser = threading.Thread(target=stop_and_start)
    pauser.start()
    try:
        compilation_runner.start_cancellable_process(['sleep', '0.5'],
                                                     30,
                                                     cancellation_manager=cm)
        # Measure before joining the helper: the helper alone runs for 1.25s,
        # so timing after the join would make the assertion vacuous.
        elapsed = time.monotonic() - start_time
    finally:
        # Never leave the helper thread running past the end of the test.
        pauser.join()
    # should be at least 1 second due to the pause.
    self.assertGreater(elapsed, 1)


if __name__ == '__main__':
tf.test.main()
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ oauthlib==3.1.1
opt-einsum==3.3.0
pillow==8.3.1
protobuf==3.17.3
psutil==5.9.0
pyasn1==0.4.8
pyasn1_modules==0.2.8
pyglet==1.5.0
Expand Down

0 comments on commit 16eb08b

Please sign in to comment.