From c216660ccf4ce95d8c5d434e3199067f6af2ff60 Mon Sep 17 00:00:00 2001 From: Wrench56 Date: Sun, 14 Jul 2024 06:15:22 -0400 Subject: [PATCH] Implement on-the-fly plugin loading --- src/backend/api/expose.py | 63 ++++++++++++++++++++++++++----- src/backend/plugins/downloader.py | 4 ++ src/backend/plugins/handler.py | 47 ++++++++++++++++++++--- src/backend/plugins/priority.py | 17 ++++++++- 4 files changed, 115 insertions(+), 16 deletions(-) diff --git a/src/backend/api/expose.py b/src/backend/api/expose.py index 00d653c..8309ff3 100644 --- a/src/backend/api/expose.py +++ b/src/backend/api/expose.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from fastapi import Request, WebSocket @@ -9,10 +9,10 @@ class TrieNode: def __init__(self): - self.children = {} + self.children: Dict[str, TrieNode] = {} self.callback: Optional[CallbackFunctionType] = None - def __str__(self, level=0): + def __str__(self, level=0) -> str: result = [] indent = ' ' * (level * 2) if self.callback: @@ -26,8 +26,8 @@ def __str__(self, level=0): class URLRouter: - def __init__(self): - self.routes = {} + def __init__(self) -> None: + self.routes: Dict[str, Dict[str, TrieNode]] = {} def add_route( self, @@ -35,9 +35,10 @@ def add_route( plugin_name: str, pattern: str, callback: CallbackFunctionType, - ): + ) -> None: parts = pattern.strip('/').split('/') - node = self.routes.setdefault(method, {}).setdefault(plugin_name, TrieNode()) + node = self.routes.setdefault( + method, {}).setdefault(plugin_name, TrieNode()) for part in parts: if part not in node.children: @@ -47,7 +48,37 @@ def add_route( # Replace the existing callback with the new one node.callback = callback - def match(self, method: str, plugin_name: str, url: str): + def remove_route(self, method: str, plugin_name: str, pattern: str) -> bool: + parts = pattern.strip('/').split('/') + node = self.routes.get(method, {}).get(plugin_name) + + if not node: + return False + + return self._remove_parts(node, parts, 0) + + def _remove_parts(self, node: TrieNode, parts: List[str], index: int) -> bool: + if index == len(parts): + if node.callback: + node.callback = None + return len(node.children) == 0 + return False + + part = parts[index] + if part in node.children: + should_delete_child = self._remove_parts( + node.children[part], parts, index + 1) + if should_delete_child: + del node.children[part] + return len(node.children) == 0 and node.callback is None + return False + + def remove_plugin(self, plugin_name: str) -> None: + self.routes = {method: {pn: pd for pn, pd in plugin_dict.items() if pn != plugin_name} + for method, plugin_dict in self.routes.items() + if any(pn != plugin_name for pn in plugin_dict)} + + def match(self, method: str, plugin_name: str, url: str) -> Optional[CallbackFunctionType]: parts = url.strip('/').split('/') node = self.routes.get(method, {}).get(plugin_name, TrieNode()) return self._match_parts(node, parts, 0) @@ -78,7 +109,7 @@ def _match_parts( return None - def __str__(self): + def __str__(self) -> str: result = [] for method, plugin_dict in self.routes.items(): result.append(f'Method: {method}') @@ -126,3 +157,17 @@ def fetch_callback( plugin: str, url: str, method: str ) -> Optional[CallbackFunctionType]: return _ROUTER.match(method.upper(), plugin, url) + + +def unsubscribe(endpoint: str, method: str) -> bool: + module_name = stack.get_caller(depth=2)[0] + if not module_name.startswith('plugins.plugins.'): + return False + return _ROUTER.remove_route(method, module_name.split('.')[2], endpoint) + + +# Do not allow plugins to use this +def unload_plugin(name: str) -> None: + if stack.get_caller(depth=2)[0].startswith('plugins.plugins.'): + return + _ROUTER.remove_plugin(name) diff --git a/src/backend/plugins/downloader.py b/src/backend/plugins/downloader.py index 3532216..ddcf6bc 100644 --- a/src/backend/plugins/downloader.py +++ b/src/backend/plugins/downloader.py @@ -19,6 +19,9 @@ def from_url(url: str) -> bool: return False name = config['plugin'].get('name').replace('-', '_') + handler.unload(name) + priority.remove_plugin(name) + zip_url = config['plugin'].get('zip_url') try: @@ -45,6 +48,7 @@ def from_url(url: str) -> bool: priority.add_new_plugin(name, 2) handler.load(name) + logging.info(f'Plugin "{name}" installed successfully') return True diff --git a/src/backend/plugins/handler.py b/src/backend/plugins/handler.py index a48b1b4..96e16b5 100644 --- a/src/backend/plugins/handler.py +++ b/src/backend/plugins/handler.py @@ -3,6 +3,7 @@ import importlib import logging +from api import expose from plugins.base_plugin import Plugin from plugins import priority @@ -11,14 +12,17 @@ def load_all() -> None: - for plugin_name, prio in priority.fetch_plugins(): - # Skip if plugin has already been loaded - if _PLUGINS.get(plugin_name) is not None: - continue - logging.info(f'Loading plugin "{plugin_name}" with priority {prio}') + loaded = 0 + for plugin_name, _ in priority.fetch_plugins(): plugin = load(plugin_name) if plugin is not None: - plugin.load() + loaded += 1 + + log = f'Loaded {loaded}/{priority.length()} plugins' + if loaded == priority.length(): + logging.info(log) + return + logging.warning(log) def load(name: str) -> Optional[Plugin]: @@ -27,6 +31,8 @@ def load(name: str) -> Optional[Plugin]: plugin: Plugin = importlib.import_module(source).init() if plugin: _PLUGINS[name] = plugin + plugin.load() + logging.info(f'Loaded plugin "{name}" successfully') return plugin except TypeError: # Abstract class (Plugin) does not implement methods like load & unload @@ -41,5 +47,34 @@ def load(name: str) -> Optional[Plugin]: return None +def unload(name: str) -> bool: + plugin = _PLUGINS.get('name') + if plugin is None: + logging.warning(f'Plugin "{name}" is not loaded') + return False + _PLUGINS[name] = None + if not plugin.unload(): + logging.error(f'Plugin "{name}" could not be unloaded') + return False + + # Remove all subscription + expose.unload_plugin(name) + + logging.info(f'Plugin "{name}" has been unloaded') + return True + + +def unload_all() -> bool: + success = True + for plugin in _PLUGINS.values(): + if not plugin.unload(): + logging.error(f'Plugin "{plugin.name}" could not be unloaded') + success = False + logging.info(f'Plugin "{plugin.name}" has been unloaded') + + _PLUGINS.clear() + return success + + def get_plugin_names() -> Tuple[str, ...]: return tuple(_PLUGINS.keys()) diff --git a/src/backend/plugins/priority.py b/src/backend/plugins/priority.py index 3d6fa13..5a0d60a 100644 --- a/src/backend/plugins/priority.py +++ b/src/backend/plugins/priority.py @@ -1,5 +1,7 @@ from typing import Generator, List, Tuple +import logging + from utils.const import PLUGINS_DIR _PRIORITIES: List[Tuple[str, int]] = [] @@ -28,10 +30,23 @@ def change_plugin_priority(name: str, priority: int) -> None: _PRIORITIES[i] = (name, priority) +def remove_plugin(name: str) -> None: + for i, pp_pair in enumerate(_PRIORITIES): + if name == pp_pair[0]: + _PRIORITIES.pop(i) + _save_priorities() + return + logging.warning(f'Plugin "{name}" was not found in the priority list') + + +def length() -> int: + return len(_PRIORITIES) + + def _save_priorities() -> None: with open(f'{PLUGINS_DIR}/priorities.csv', 'w', encoding='utf-8') as f: for prio in _PRIORITIES: - f.write(f'{prio[0]},{prio[1]}') + f.write(f'{prio[0]},{prio[1]}\n') f.close()