From 01b8f8caafdcc20df4fa3853c20d4b229597b2f0 Mon Sep 17 00:00:00 2001 From: Wrench56 Date: Tue, 2 Jul 2024 03:48:39 -0400 Subject: [PATCH] Add `expose` API for handling dynamic plugin endpoints --- src/backend/api/expose.py | 50 ++++++++++++++++++++++++++++++++++++++ src/backend/server/main.py | 28 +++++++++++++++++++-- src/backend/utils/stack.py | 13 ++++++++++ 3 files changed, 89 insertions(+), 2 deletions(-) create mode 100644 src/backend/api/expose.py create mode 100644 src/backend/utils/stack.py diff --git a/src/backend/api/expose.py b/src/backend/api/expose.py new file mode 100644 index 0000000..7b8fe48 --- /dev/null +++ b/src/backend/api/expose.py @@ -0,0 +1,50 @@ +from typing import Any, Callable, Dict, Optional + +from fastapi import Request + +from utils import stack + +_SUBSCRIBERS: Dict[str, Dict[str, + Dict[str, Callable[[str, Request], Any]]]] = { + 'GET': {}, + 'PUT': {}, + 'POST': {}, + 'DELETE': {} +} + + +def subscribe_get(endpoint: str, callback: Callable[[str, Request], Any]) -> None: + _subscribe(endpoint, _SUBSCRIBERS['GET'], callback) + + +def subscribe_put(endpoint: str, callback: Callable[[str, Request], Any]) -> None: + _subscribe(endpoint, _SUBSCRIBERS['PUT'], callback) + + +def subscribe_post(endpoint: str, callback: Callable[[str, Request], Any]) -> None: + _subscribe(endpoint, _SUBSCRIBERS['POST'], callback) + + +def subscribe_delete(endpoint: str, callback: Callable[[str, Request], Any]) -> None: + _subscribe(endpoint, _SUBSCRIBERS['DELETE'], callback) + + +def _subscribe(endpoint: str, structure: Dict, callback: Callable[[str, Request], Any]) -> None: + module_name = stack.get_caller(depth=3)[0] + if not module_name.startswith('plugins.plugins.'): + return + + plugin_name = module_name.split('.')[2] + if plugin_name not in structure: + structure[plugin_name] = {} + + structure[plugin_name][endpoint] = callback + + +def fetch_callback(plugin: str, endpoint: str, method: str) -> Optional[Callable[[str, Request], Any]]: + endpoints = _SUBSCRIBERS[method.upper()].get(plugin) + if endpoints is None: + return None + if endpoint[-1] == '/': + endpoint = endpoint[:-1] + return endpoints.get(endpoint) diff --git a/src/backend/server/main.py b/src/backend/server/main.py index bca8364..77c024b 100644 --- a/src/backend/server/main.py +++ b/src/backend/server/main.py @@ -1,3 +1,5 @@ +from typing import Any + import datetime import logging @@ -5,6 +7,7 @@ from fastapi.responses import FileResponse, ORJSONResponse, PlainTextResponse from fastapi.staticfiles import StaticFiles +from api import expose from db import users from server import build from utils import config, const, motd, settings, status @@ -50,7 +53,8 @@ async def rebuild(request: Request) -> PlainTextResponse: return PlainTextResponse('ERROR: BUILD') build_size, units = build.get_frontend_size() return PlainTextResponse( - f'REBUILT: Rebuilt in {build_time}ms\nSize of build folder: {build_size}{units}' + f'REBUILT: Rebuilt in { + build_time}ms\nSize of build folder: {build_size}{units}' ) @@ -76,7 +80,8 @@ async def login(request: Request) -> PlainTextResponse: uuid = database.create_uuid(username) logging.info(f'Welcome user "{username}"!') - expire_time = float(config.get('security').get('auth_cookie_expire_time') or 3600.0) + expire_time = float(config.get('security').get( + 'auth_cookie_expire_time') or 3600.0) response.set_cookie( key='auth_cookie', value=uuid, @@ -120,3 +125,22 @@ async def update_setting(request: Request, id_: str) -> PlainTextResponse: settings.update_setting(id_, data.decode()) return response + + +# Plugins +@app.get('/plugins/{plugin}/{endpoint:path}') +@app.put('/plugins/{plugin}/{endpoint:path}') +@app.post('/plugins/{plugin}/{endpoint:path}') +@app.delete('/plugins/{plugin}/{endpoint:path}') +async def plugins(request: Request, plugin: str, endpoint: str) -> Any: + response = PlainTextResponse() + if not database.uuid_exists(request.cookies.get('auth_cookie')): + response.status_code = 401 + return response + + # Remove sensitive cookie(s) + request.cookies['auth_cookie'] = '' + callback = expose.fetch_callback(plugin, endpoint, request.method) + if callback: + return callback(endpoint, request) + return response diff --git a/src/backend/utils/stack.py b/src/backend/utils/stack.py new file mode 100644 index 0000000..f2a4b92 --- /dev/null +++ b/src/backend/utils/stack.py @@ -0,0 +1,13 @@ +# pylint: disable=protected-access + +from typing import Tuple +import sys + + +# Note: The sys wiki states that "It (sys._getframe()) +# is not guranteed to exist in all +# implementations of Python" +def get_caller(depth: int = 1) -> Tuple[str, str]: + frame = sys._getframe(depth) + return (frame.f_globals['__name__'], + frame.f_code.co_name)