Skip to content

Commit

Permalink
Fix: Log the actual number of active task schedulers
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565553684
  • Loading branch information
goutham authored and tfx-copybara committed Sep 15, 2023
1 parent efcedd0 commit 204d93a
Showing 1 changed file with 46 additions and 11 deletions.
57 changes: 46 additions & 11 deletions tfx/orchestration/experimental/core/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,45 @@ def __init__(self, errors):
self.errors = errors


class _ActiveSchedulerCounter:
"""Class for keeping count of active task schedulers."""

def __init__(self):
self._lock = threading.Lock()
self._count = 0

def __enter__(self):
with self._lock:
self._count += 1

def __exit__(self, exc_type, exc_val, exc_tb):
with self._lock:
self._count -= 1

def count(self) -> int:
with self._lock:
return self._count


class _SchedulerWrapper:
"""Wraps a TaskScheduler to store additional details."""

def __init__(self, task_scheduler: ts.TaskScheduler):
def __init__(
self,
task_scheduler: ts.TaskScheduler,
active_scheduler_counter: _ActiveSchedulerCounter,
):
self._task_scheduler = task_scheduler
self._active_scheduler_counter = active_scheduler_counter
self.pause = False

def schedule(self) -> ts.TaskSchedulerResult:
logging.info('Starting task scheduler: %s', self._task_scheduler)
try:
return self._task_scheduler.schedule()
finally:
logging.info('Task scheduler finished: %s', self._task_scheduler)
with self._active_scheduler_counter:
logging.info('Starting task scheduler: %s', self._task_scheduler)
try:
return self._task_scheduler.schedule()
finally:
logging.info('Task scheduler finished: %s', self._task_scheduler)

def cancel(self, cancel_task: task_lib.CancelNodeTask) -> None:
logging.info('Cancelling task scheduler: %s', self._task_scheduler)
Expand Down Expand Up @@ -111,16 +137,18 @@ def __init__(self,
self._tm_lock = threading.Lock()
self._stop_event = threading.Event()
self._scheduler_by_node_uid: Dict[task_lib.NodeUid, _SchedulerWrapper] = {}
self._active_scheduler_counter = _ActiveSchedulerCounter()

# Async executor for the main task management thread.
self._main_executor = futures.ThreadPoolExecutor(
max_workers=1, thread_name_prefix='orchestrator_task_manager_main'
)
self._main_future = None
self._max_active_task_schedulers = max_active_task_schedulers

# Async executor for task schedulers.
self._ts_executor = futures.ThreadPoolExecutor(
max_workers=max_active_task_schedulers,
max_workers=self._max_active_task_schedulers,
thread_name_prefix='orchestrator_active_task_schedulers',
)
self._ts_futures = set()
Expand Down Expand Up @@ -165,13 +193,16 @@ def _main(self) -> None:
"""Runs the main task management loop."""
try:
while not self._stop_event.is_set():
self._cleanup()
num_active = self._active_scheduler_counter.count()
logging.log_every_n_seconds(
logging.INFO,
'Number of active task schedulers: %d',
'Number of active task schedulers: %d (max: %d), queued: %d',
30,
len(self._ts_futures),
num_active,
self._max_active_task_schedulers,
len(self._ts_futures) - num_active,
)
self._cleanup()
task = self._task_queue.dequeue(self._max_dequeue_wait_secs)
if task is None:
continue
Expand Down Expand Up @@ -212,7 +243,11 @@ def _handle_exec_node_task(self, task: task_lib.ExecNodeTask) -> None:
typing.cast(
ts.TaskScheduler[task_lib.ExecNodeTask],
ts.TaskSchedulerRegistry.create_task_scheduler(
self._mlmd_handle, task.pipeline, task)))
self._mlmd_handle, task.pipeline, task
),
),
self._active_scheduler_counter,
)
logging.info('Instantiated task scheduler: %s', scheduler)
if task.cancel_type == task_lib.NodeCancelType.PAUSE_EXEC:
scheduler.pause = True
Expand Down

0 comments on commit 204d93a

Please sign in to comment.