diff --git a/src/backend/main.py b/src/backend/main.py index 9def165..2138a4d 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -1,13 +1,8 @@ import logging import logging.config -import asyncio -import atexit -import signal -import sys - from server import build, start -from utils import config, processes, status +from utils import config, cleanup, status def main() -> None: @@ -21,9 +16,7 @@ def main() -> None: status.update() # Register cleanup functions - atexit.register(lambda: cleanup) - signal.signal(signal.SIGTERM, lambda _, __: cleanup()) - signal.signal(signal.SIGINT, lambda _, __: cleanup()) + cleanup.init('main') # Build the frontend with Vite build.build_frontend() @@ -37,17 +30,8 @@ def init_logger(): logging.info('Logging init done') -def cleanup() -> None: - logging.info('Cleaning up...') - processes.terminate_subprocesses() - loop = asyncio.get_event_loop() - for task in asyncio.all_tasks(loop=loop): - task.cancel() - loop.stop() - - # Raises SystemExit - sys.exit() - - if __name__ == '__main__': - main() + try: + main() + except KeyboardInterrupt: + pass diff --git a/src/backend/server/main.py b/src/backend/server/main.py index e30d957..1ede963 100644 --- a/src/backend/server/main.py +++ b/src/backend/server/main.py @@ -1,6 +1,7 @@ from typing import Any, AsyncGenerator import asyncio +import contextlib import datetime import inspect import json @@ -16,13 +17,18 @@ from plugins import downloader, handler from server import build from server.endpoint_filter import EndpointFilter -from utils import config, const, motd, settings, status +from utils import cleanup, config, const, motd, settings, status database = users.Database() app = FastAPI() app.mount('/assets', StaticFiles(directory='../public/assets'), name='static') +# Note: Put this after FastAPI init! +# cleanup.init(...) uses the previously +# set signal handler. +cleanup.init('server') + # Ignore /ping logs uvicorn_logger = logging.getLogger('uvicorn.access') uvicorn_logger.addFilter(EndpointFilter(path='/ping')) @@ -227,7 +233,7 @@ async def plugin_status(request: Request) -> EventSourceResponse: handler.set_update_flag() async def event_generator() -> AsyncGenerator[str, None]: - while True: + while not cleanup.get_flag(): if await request.is_disconnected(): break @@ -235,6 +241,6 @@ async def event_generator() -> AsyncGenerator[str, None]: yield json.dumps( {'ok': True, 'plugins': handler.get_plugin_statuses()} ) - await asyncio.sleep(2.0) + await asyncio.sleep(1.0) return EventSourceResponse(event_generator(), media_type='text/event-stream') diff --git a/src/backend/utils/cleanup.py b/src/backend/utils/cleanup.py new file mode 100644 index 0000000..3bca986 --- /dev/null +++ b/src/backend/utils/cleanup.py @@ -0,0 +1,50 @@ +from types import FrameType +from typing import Optional + +import logging +import signal +import threading +import sys + +from utils import flag, processes + +_SHUTDOWN_FLAG = flag.Flag(False) +_HANDLERS = {} + + +def _timeout() -> None: + logging.error('Graceful cleanup timeout! Shutting down forcefully...') + sys.exit() + + +_GRACEFUL_TIMER = threading.Timer(2.5, _timeout) + + +def init(context: str) -> None: + _HANDLERS[2] = signal.getsignal(signal.SIGINT) + _HANDLERS[15] = signal.getsignal(signal.SIGTERM) + signal.signal(signal.SIGINT, lambda s, f: cleanup(s, f, context)) + signal.signal(signal.SIGTERM, lambda s, f: cleanup(s, f, context)) + + +def cleanup(sig: int, frame: Optional[FrameType], context: str) -> None: + logging.info(f'(Context: {context}) Cleaning up...') + _SHUTDOWN_FLAG.set() + processes.terminate_subprocesses() + handler = _HANDLERS[sig] + if callable(handler): + handler(sig, frame) + logging.info(f'(Context: {context}) Cleanup done!{ + " Bye-bye!" if context == "main" else ""}') + + # Set forceful cleanup timeout to 2.5 seconds + if not _GRACEFUL_TIMER.is_alive(): + _GRACEFUL_TIMER.start() + + # Stop timeout timer if everything finished + elif context == 'main': + _GRACEFUL_TIMER.cancel() + + +def get_flag() -> bool: + return _SHUTDOWN_FLAG.get()