diff --git a/tfx/orchestration/experimental/core/task_manager.py b/tfx/orchestration/experimental/core/task_manager.py index 15406dc769..bde35ceff7 100644 --- a/tfx/orchestration/experimental/core/task_manager.py +++ b/tfx/orchestration/experimental/core/task_manager.py @@ -52,19 +52,45 @@ def __init__(self, errors): self.errors = errors +class _ActiveSchedulerCounter: + """Class for keeping count of active task schedulers.""" + + def __init__(self): + self._lock = threading.Lock() + self._count = 0 + + def __enter__(self): + with self._lock: + self._count += 1 + + def __exit__(self, exc_type, exc_val, exc_tb): + with self._lock: + self._count -= 1 + + def count(self) -> int: + with self._lock: + return self._count + + class _SchedulerWrapper: """Wraps a TaskScheduler to store additional details.""" - def __init__(self, task_scheduler: ts.TaskScheduler): + def __init__( + self, + task_scheduler: ts.TaskScheduler, + active_scheduler_counter: _ActiveSchedulerCounter, + ): self._task_scheduler = task_scheduler + self._active_scheduler_counter = active_scheduler_counter self.pause = False def schedule(self) -> ts.TaskSchedulerResult: - logging.info('Starting task scheduler: %s', self._task_scheduler) - try: - return self._task_scheduler.schedule() - finally: - logging.info('Task scheduler finished: %s', self._task_scheduler) + with self._active_scheduler_counter: + logging.info('Starting task scheduler: %s', self._task_scheduler) + try: + return self._task_scheduler.schedule() + finally: + logging.info('Task scheduler finished: %s', self._task_scheduler) def cancel(self, cancel_task: task_lib.CancelNodeTask) -> None: logging.info('Cancelling task scheduler: %s', self._task_scheduler) @@ -111,16 +137,18 @@ def __init__(self, self._tm_lock = threading.Lock() self._stop_event = threading.Event() self._scheduler_by_node_uid: Dict[task_lib.NodeUid, _SchedulerWrapper] = {} + self._active_scheduler_counter = _ActiveSchedulerCounter() # Async executor for the main task management thread. self._main_executor = futures.ThreadPoolExecutor( max_workers=1, thread_name_prefix='orchestrator_task_manager_main' ) self._main_future = None + self._max_active_task_schedulers = max_active_task_schedulers # Async executor for task schedulers. self._ts_executor = futures.ThreadPoolExecutor( - max_workers=max_active_task_schedulers, + max_workers=self._max_active_task_schedulers, thread_name_prefix='orchestrator_active_task_schedulers', ) self._ts_futures = set() @@ -165,13 +193,16 @@ def _main(self) -> None: """Runs the main task management loop.""" try: while not self._stop_event.is_set(): + self._cleanup() + num_active = self._active_scheduler_counter.count() logging.log_every_n_seconds( logging.INFO, - 'Number of active task schedulers: %d', + 'Number of active task schedulers: %d (max: %d), queued: %d', 30, - len(self._ts_futures), + num_active, + self._max_active_task_schedulers, + len(self._ts_futures) - num_active, ) - self._cleanup() task = self._task_queue.dequeue(self._max_dequeue_wait_secs) if task is None: continue @@ -212,7 +243,11 @@ def _handle_exec_node_task(self, task: task_lib.ExecNodeTask) -> None: typing.cast( ts.TaskScheduler[task_lib.ExecNodeTask], ts.TaskSchedulerRegistry.create_task_scheduler( - self._mlmd_handle, task.pipeline, task))) + self._mlmd_handle, task.pipeline, task + ), + ), + self._active_scheduler_counter, + ) logging.info('Instantiated task scheduler: %s', scheduler) if task.cancel_type == task_lib.NodeCancelType.PAUSE_EXEC: scheduler.pause = True