Integrate scheduler into dispatcher main loop
Refactor dispatcher to build the pg_notify publish payload
  in a separate method

Refactor periodic module under dispatcher entirely
  Use real numbers for schedule reference time
  Run based on due_to_run method

Address review comments about naming and code comments
AlanCoding committed Jul 27, 2023
1 parent 852bb07 commit e3b8320
Showing 10 changed files with 314 additions and 132 deletions.
14 changes: 9 additions & 5 deletions awx/main/dispatch/__init__.py
@@ -40,8 +40,12 @@ def get_task_queuename():
 
 
 class PubSub(object):
-    def __init__(self, conn):
+    def __init__(self, conn, select_timeout=None):
         self.conn = conn
+        if select_timeout is None:
+            self.select_timeout = 5
+        else:
+            self.select_timeout = select_timeout
 
     def listen(self, channel):
         with self.conn.cursor() as cur:
@@ -72,12 +76,12 @@ def current_notifies(conn):
                 n = psycopg.connection.Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
                 yield n
 
-    def events(self, select_timeout=5, yield_timeouts=False):
+    def events(self, yield_timeouts=False):
         if not self.conn.autocommit:
             raise RuntimeError('Listening for events can only be done in autocommit mode')
 
         while True:
-            if select.select([self.conn], [], [], select_timeout) == NOT_READY:
+            if select.select([self.conn], [], [], self.select_timeout) == NOT_READY:
                 if yield_timeouts:
                     yield None
                 else:
@@ -90,7 +94,7 @@ def close(self):
 
 
 @contextmanager
-def pg_bus_conn(new_connection=False):
+def pg_bus_conn(new_connection=False, select_timeout=None):
     '''
     Any listeners probably want to establish a new database connection,
     separate from the Django connection used for queries, because that will prevent
@@ -115,7 +119,7 @@ def pg_bus_conn(new_connection=False):
            raise RuntimeError('Unexpectedly could not connect to postgres for pg_notify actions')
        conn = pg_connection.connection
 
-    pubsub = PubSub(conn)
+    pubsub = PubSub(conn, select_timeout=select_timeout)
    yield pubsub
    if new_connection:
        conn.close()
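
For context, a minimal sketch of how a listener might use the new plumbing, with the timeout now configured once on the connection rather than per events() call. This is illustrative only and not part of the commit; the channel name is made up, and the Notify payload attribute follows psycopg's documented behavior.

    from awx.main.dispatch import pg_bus_conn

    def wait_for_one_message():
        # select_timeout is now stored on the PubSub object instead of
        # being passed to events() on every call
        with pg_bus_conn(new_connection=True, select_timeout=10) as conn:
            conn.listen('demo_channel')  # hypothetical channel name
            for event in conn.events(yield_timeouts=True):
                if event is None:
                    print('waited 10 seconds with no notification')
                else:
                    print(f'got payload: {event.payload}')
                break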
7 changes: 5 additions & 2 deletions awx/main/dispatch/control.py
@@ -40,6 +40,9 @@ def running(self, *args, **kwargs):
     def cancel(self, task_ids, *args, **kwargs):
         return self.control_with_reply('cancel', *args, extra_data={'task_ids': task_ids}, **kwargs)
 
+    def schedule(self, *args, **kwargs):
+        return self.control_with_reply('schedule', *args, **kwargs)
+
     @classmethod
     def generate_reply_queue_name(cls):
         return f"reply_to_{str(uuid.uuid4()).replace('-','_')}"
@@ -52,14 +55,14 @@ def control_with_reply(self, command, timeout=5, extra_data=None):
         if not connection.get_autocommit():
             raise RuntimeError('Control-with-reply messages can only be done in autocommit mode')
 
-        with pg_bus_conn() as conn:
+        with pg_bus_conn(select_timeout=timeout) as conn:
             conn.listen(reply_queue)
             send_data = {'control': command, 'reply_to': reply_queue}
             if extra_data:
                 send_data.update(extra_data)
             conn.notify(self.queuename, json.dumps(send_data))
 
-            for reply in conn.events(select_timeout=timeout, yield_timeouts=True):
+            for reply in conn.events(yield_timeouts=True):
                 if reply is None:
                     logger.error(f'{self.service} did not reply within {timeout}s')
                     raise RuntimeError(f"{self.service} did not reply within {timeout}s")
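
A sketch of how the new control command could be invoked (illustrative, not from this diff; it assumes the conventional Control('dispatcher') usage, and the dispatcher-side handler for 'schedule' presumably lives in the main-loop changes not rendered on this page):

    from awx.main.dispatch.control import Control

    # ask the local dispatcher to report its scheduler state; the reply
    # timeout now also bounds each select() wait on the reply queue
    print(Control('dispatcher').schedule(timeout=5))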
185 changes: 135 additions & 50 deletions awx/main/dispatch/periodic.py
@@ -1,57 +1,142 @@
 import logging
-import os
 import time
-from multiprocessing import Process
+import yaml
+from datetime import datetime
 
-from django.conf import settings
-from django.db import connections
-from schedule import Scheduler
-from django_guid import set_guid
-from django_guid.utils import generate_guid
-
-from awx.main.dispatch.worker import TaskWorker
-from awx.main.utils.db import set_connection_name
-
 logger = logging.getLogger('awx.main.dispatch.periodic')
 
 
-class Scheduler(Scheduler):
-    def run_continuously(self):
-        idle_seconds = max(1, min(self.jobs).period.total_seconds() / 2)
-
-        def run():
-            ppid = os.getppid()
-            logger.warning('periodic beat started')
-
-            set_connection_name('periodic')  # set application_name to distinguish from other dispatcher processes
-
-            while True:
-                if os.getppid() != ppid:
-                    # if the parent PID changes, this process has been orphaned
-                    # via e.g., segfault or sigkill, we should exit too
-                    pid = os.getpid()
-                    logger.warning(f'periodic beat exiting gracefully pid:{pid}')
-                    raise SystemExit()
-                try:
-                    for conn in connections.all():
-                        # If the database connection has a hiccup, re-establish a new
-                        # connection
-                        conn.close_if_unusable_or_obsolete()
-                    set_guid(generate_guid())
-                    self.run_pending()
-                except Exception:
-                    logger.exception('encountered an error while scheduling periodic tasks')
-                time.sleep(idle_seconds)
-
-        process = Process(target=run)
-        process.daemon = True
-        process.start()
-
-
-def run_continuously():
-    scheduler = Scheduler()
-    for task in settings.CELERYBEAT_SCHEDULE.values():
-        apply_async = TaskWorker.resolve_callable(task['task']).apply_async
-        total_seconds = task['schedule'].total_seconds()
-        scheduler.every(total_seconds).seconds.do(apply_async)
-    scheduler.run_continuously()
+class ScheduledTask:
+    """
+    Class representing schedules, very loosely modeled after the python schedule library Job.
+    The idea of this class is to:
+     - only deal in relative times (time since the scheduler global start)
+     - only deal in integer math for target runtimes, but float for the current relative time
+    Missed schedule policy:
+      Invariant target times are maintained, meaning that if interval=10s and offset=0
+      and it runs at t=7s, then it calls for the next run in 3s.
+      However, if a complete interval has passed, that is counted as a missed run,
+      and missed runs are abandoned (no catch-up runs).
+    """
+
+    def __init__(self, name: str, data: dict):
+        # parameters needed for schedule computation
+        self.interval = int(data['schedule'].total_seconds())
+        self.offset = 0  # offset relative to start time at which this schedule begins
+        self.index = 0  # number of periods of the schedule that have passed
+
+        # parameters that do not affect scheduling logic
+        self.last_run = None  # time of last run, only used for debug
+        self.completed_runs = 0  # number of times the schedule is known to have run
+        self.name = name
+        self.data = data  # used by the caller to know what to run
+
+    @property
+    def next_run(self):
+        "Time of the next run, with t=0 being the global_start of the scheduler class"
+        return (self.index + 1) * self.interval + self.offset
+
+    def due_to_run(self, relative_time):
+        return bool(self.next_run <= relative_time)
+
+    def expected_runs(self, relative_time):
+        return int((relative_time - self.offset) / self.interval)
+
+    def mark_run(self, relative_time):
+        self.last_run = relative_time
+        self.completed_runs += 1
+        new_index = self.expected_runs(relative_time)
+        if new_index > self.index + 1:
+            logger.warning(f'Missed {new_index - self.index - 1} schedules of {self.name}')
+        self.index = new_index
+
+    def missed_runs(self, relative_time):
+        "Number of times the job was supposed to run but did not; only used for debug"
+        missed_ct = self.expected_runs(relative_time) - self.completed_runs
+        # if this is currently due to run, do not count that as a missed run
+        if missed_ct and self.due_to_run(relative_time):
+            missed_ct -= 1
+        return missed_ct
+
+
+class Scheduler:
+    def __init__(self, schedule):
+        """
+        Expects the schedule in the form of a dictionary like
+        {
+            'job1': {'schedule': timedelta(seconds=50), 'other': 'stuff'}
+        }
+        Only the schedule's nearest-second value is used for scheduling;
+        the rest of the data is for use by the caller to know what to run.
+        """
+        self.jobs = [ScheduledTask(name, data) for name, data in schedule.items()]
+        min_interval = min(job.interval for job in self.jobs)
+        num_jobs = len(self.jobs)
+
+        # this is intentionally opinionated against spammy schedules
+        # a core goal is to spread out the scheduled tasks (for worker management)
+        # and high-frequency schedules just do not work with that
+        if num_jobs > min_interval:
+            raise RuntimeError(f'Number of schedules ({num_jobs}) is more than the shortest schedule interval ({min_interval} seconds).')
+
+        # evenly space out jobs over the base interval
+        for i, job in enumerate(self.jobs):
+            job.offset = (i * min_interval) // num_jobs
+
+        # internally, times are all referenced relative to startup time, plus a grace period
+        self.global_start = time.time() + 2.0
+
+    def get_and_mark_pending(self):
+        relative_time = time.time() - self.global_start
+        to_run = []
+        for job in self.jobs:
+            if job.due_to_run(relative_time):
+                to_run.append(job)
+                logger.debug(f'scheduler found {job.name} to run, {relative_time - job.next_run} seconds after target')
+                job.mark_run(relative_time)
+        return to_run
+
+    def time_until_next_run(self):
+        relative_time = time.time() - self.global_start
+        next_job = min(self.jobs, key=lambda j: j.next_run)
+        delta = next_job.next_run - relative_time
+        if delta <= 0.1:
+            # careful not to give 0 or negative values to the select timeout, which has unclear interpretation
+            logger.warning(f'Scheduler next run of {next_job.name} is {-delta} seconds in the past')
+            return 0.1
+        elif delta > 20.0:
+            logger.warning(f'Scheduler next run unexpectedly over 20 seconds in future: {delta}')
+            return 20.0
+        logger.debug(f'Scheduler next run is {next_job.name} in {delta} seconds')
+        return delta
+
+    def debug(self, *args, **kwargs):
+        data = dict()
+        data['title'] = 'Scheduler status'
+
+        now = datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S UTC')
+        start_time = datetime.fromtimestamp(self.global_start).strftime('%Y-%m-%d %H:%M:%S UTC')
+        relative_time = time.time() - self.global_start
+        data['started_time'] = start_time
+        data['current_time'] = now
+        data['current_time_relative'] = round(relative_time, 3)
+        data['total_schedules'] = len(self.jobs)
+
+        data['schedule_list'] = dict(
+            [
+                (
+                    job.name,
+                    dict(
+                        last_run_seconds_ago=round(relative_time - job.last_run, 3) if job.last_run else None,
+                        next_run_in_seconds=round(job.next_run - relative_time, 3),
+                        offset_in_seconds=job.offset,
+                        completed_runs=job.completed_runs,
+                        missed_runs=job.missed_runs(relative_time),
+                    ),
+                )
+                for job in sorted(self.jobs, key=lambda job: job.interval)
+            ]
+        )
+        return yaml.safe_dump(data, default_flow_style=False, sort_keys=False)
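
To make the offset and timing math concrete, here is a toy run of the new Scheduler (illustrative, not from the commit). With two jobs and a shortest interval of 10 seconds, the offsets come out to 0 and 5, spreading the jobs apart:

    from datetime import timedelta
    from awx.main.dispatch.periodic import Scheduler

    scheduler = Scheduler({
        'fast_job': {'schedule': timedelta(seconds=10)},  # offset 0, runs at t=10, 20, ...
        'slow_job': {'schedule': timedelta(seconds=20)},  # offset 5, runs at t=25, 45, ...
    })

    # a dispatcher main loop would alternate between these two calls,
    # using time_until_next_run() as its pg_notify select timeout
    timeout = scheduler.time_until_next_run()
    for job in scheduler.get_and_mark_pending():
        print(f'would submit {job.name} using {job.data}')
    print(scheduler.debug())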
23 changes: 17 additions & 6 deletions awx/main/dispatch/publish.py
@@ -73,22 +73,32 @@ def delay(cls, *args, **kwargs):
         return cls.apply_async(args, kwargs)
 
     @classmethod
-    def apply_async(cls, args=None, kwargs=None, queue=None, uuid=None, **kw):
+    def get_async_body(cls, args=None, kwargs=None, uuid=None, **kw):
+        """
+        Get the python dict to become JSON data in the pg_notify message
+        This same message gets passed over the dispatcher IPC queue to workers
+        If a task is submitted to a multiprocessing pool, skipping pg_notify, this might be used directly
+        """
         task_id = uuid or str(uuid4())
         args = args or []
         kwargs = kwargs or {}
-        queue = queue or getattr(cls.queue, 'im_func', cls.queue)
-        if not queue:
-            msg = f'{cls.name}: Queue value required and may not be None'
-            logger.error(msg)
-            raise ValueError(msg)
         obj = {'uuid': task_id, 'args': args, 'kwargs': kwargs, 'task': cls.name, 'time_pub': time.time()}
         guid = get_guid()
         if guid:
             obj['guid'] = guid
         if bind_kwargs:
             obj['bind_kwargs'] = bind_kwargs
         obj.update(**kw)
+        return obj
+
+    @classmethod
+    def apply_async(cls, args=None, kwargs=None, queue=None, uuid=None, **kw):
+        queue = queue or getattr(cls.queue, 'im_func', cls.queue)
+        if not queue:
+            msg = f'{cls.name}: Queue value required and may not be None'
+            logger.error(msg)
+            raise ValueError(msg)
+        obj = cls.get_async_body(args=args, kwargs=kwargs, uuid=uuid, **kw)
         if callable(queue):
             queue = queue()
         if not is_testing():
@@ -116,4 +126,5 @@ def apply_async(cls, args=None, kwargs=None, queue=None, uuid=None, **kw):
     setattr(fn, 'name', cls.name)
     setattr(fn, 'apply_async', cls.apply_async)
     setattr(fn, 'delay', cls.delay)
+    setattr(fn, 'get_async_body', cls.get_async_body)
     return fn
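
With the payload construction split out, the message body can be built without publishing it. A sketch (illustrative; `some_task` stands in for any @task()-decorated callable):

    import json

    # build the message body without going through pg_notify
    body = some_task.get_async_body(args=[1, 2])
    # apply_async builds this same dict and json.dumps() it into pg_notify;
    # per the docstring, a dispatcher-side scheduler can instead hand it
    # straight to its worker pool
    message = json.dumps(body)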