Skip to content

Commit

Permalink
Implement on-the-fly plugin loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrench56 committed Jul 14, 2024
1 parent a7e093d commit c216660
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 16 deletions.
63 changes: 54 additions & 9 deletions src/backend/api/expose.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

from fastapi import Request, WebSocket

Expand All @@ -9,10 +9,10 @@

class TrieNode:
def __init__(self):
self.children = {}
self.children: Dict[str, TrieNode] = {}
self.callback: Optional[CallbackFunctionType] = None

def __str__(self, level=0):
def __str__(self, level=0) -> str:
result = []
indent = ' ' * (level * 2)
if self.callback:
Expand All @@ -26,18 +26,19 @@ def __str__(self, level=0):


class URLRouter:
def __init__(self):
self.routes = {}
def __init__(self) -> None:
self.routes: Dict[str, Dict[str, TrieNode]] = {}

def add_route(
self,
method: str,
plugin_name: str,
pattern: str,
callback: CallbackFunctionType,
):
) -> None:
parts = pattern.strip('/').split('/')
node = self.routes.setdefault(method, {}).setdefault(plugin_name, TrieNode())
node = self.routes.setdefault(
method, {}).setdefault(plugin_name, TrieNode())

for part in parts:
if part not in node.children:
Expand All @@ -47,7 +48,37 @@ def add_route(
# Replace the existing callback with the new one
node.callback = callback

def match(self, method: str, plugin_name: str, url: str):
def remove_route(self, method: str, plugin_name: str, pattern: str) -> bool:
parts = pattern.strip('/').split('/')
node = self.routes.get(method, {}).get(plugin_name)

if not node:
return False

return self._remove_parts(node, parts, 0)

def _remove_parts(self, node: TrieNode, parts: List[str], index: int) -> bool:
if index == len(parts):
if node.callback:
node.callback = None
return len(node.children) == 0
return False

part = parts[index]
if part in node.children:
should_delete_child = self._remove_parts(
node.children[part], parts, index + 1)
if should_delete_child:
del node.children[part]
return len(node.children) == 0 and node.callback is None
return False

def remove_plugin(self, plugin_name: str) -> None:
self.routes = {method: {pn: pd for pn, pd in plugin_dict.items() if pn != plugin_name}
for method, plugin_dict in self.routes.items()
if any(pn != plugin_name for pn in plugin_dict)}

def match(self, method: str, plugin_name: str, url: str) -> Optional[CallbackFunctionType]:
parts = url.strip('/').split('/')
node = self.routes.get(method, {}).get(plugin_name, TrieNode())
return self._match_parts(node, parts, 0)
Expand Down Expand Up @@ -78,7 +109,7 @@ def _match_parts(

return None

def __str__(self):
def __str__(self) -> str:
result = []
for method, plugin_dict in self.routes.items():
result.append(f'Method: {method}')
Expand Down Expand Up @@ -126,3 +157,17 @@ def fetch_callback(
plugin: str, url: str, method: str
) -> Optional[CallbackFunctionType]:
return _ROUTER.match(method.upper(), plugin, url)


def unsubscribe(endpoint: str, method: str) -> bool:
module_name = stack.get_caller(depth=2)[0]
if not module_name.startswith('plugins.plugins.'):
return False
return _ROUTER.remove_route(method, module_name.split('.')[2], endpoint)


# Do not allow plugins to use this
def unload_plugin(name: str) -> None:
if stack.get_caller(depth=2)[0].startswith('plugins.plugins.'):
return
_ROUTER.remove_plugin(name)
4 changes: 4 additions & 0 deletions src/backend/plugins/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def from_url(url: str) -> bool:
return False

name = config['plugin'].get('name').replace('-', '_')
handler.unload(name)
priority.remove_plugin(name)

zip_url = config['plugin'].get('zip_url')

try:
Expand All @@ -45,6 +48,7 @@ def from_url(url: str) -> bool:

priority.add_new_plugin(name, 2)
handler.load(name)
logging.info(f'Plugin "{name}" installed successfully')
return True


Expand Down
47 changes: 41 additions & 6 deletions src/backend/plugins/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import importlib
import logging

from api import expose
from plugins.base_plugin import Plugin
from plugins import priority

Expand All @@ -11,14 +12,17 @@


def load_all() -> None:
for plugin_name, prio in priority.fetch_plugins():
# Skip if plugin has already been loaded
if _PLUGINS.get(plugin_name) is not None:
continue
logging.info(f'Loading plugin "{plugin_name}" with priority {prio}')
loaded = 0
for plugin_name, _ in priority.fetch_plugins():
plugin = load(plugin_name)
if plugin is not None:
plugin.load()
loaded += 1

log = f'Loaded {loaded}/{priority.length()} plugins'
if loaded == priority.length():
logging.info(log)
return
logging.warning(log)


def load(name: str) -> Optional[Plugin]:
Expand All @@ -27,6 +31,8 @@ def load(name: str) -> Optional[Plugin]:
plugin: Plugin = importlib.import_module(source).init()
if plugin:
_PLUGINS[name] = plugin
plugin.load()
logging.info(f'Loaded plugin "{name}" successfully')
return plugin
except TypeError:
# Abstract class (Plugin) does not implement methods like load & unload
Expand All @@ -41,5 +47,34 @@ def load(name: str) -> Optional[Plugin]:
return None


def unload(name: str) -> bool:
plugin = _PLUGINS.get('name')
if plugin is None:
logging.warning(f'Plugin "{name}" is not loaded')
return False
_PLUGINS[name] = None
if not plugin.unload():
logging.error(f'Plugin "{name}" could not be unloaded')
return False

# Remove all subscription
expose.unload_plugin(name)

logging.info(f'Plugin "{name}" has been unloaded')
return True


def unload_all() -> bool:
success = True
for plugin in _PLUGINS.values():
if not plugin.unload():
logging.error(f'Plugin "{plugin.name}" could not be unloaded')
success = False
logging.info(f'Plugin "{plugin.name}" has been unloaded')

_PLUGINS.clear()
return success


def get_plugin_names() -> Tuple[str, ...]:
return tuple(_PLUGINS.keys())
17 changes: 16 additions & 1 deletion src/backend/plugins/priority.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Generator, List, Tuple

import logging

from utils.const import PLUGINS_DIR

_PRIORITIES: List[Tuple[str, int]] = []
Expand Down Expand Up @@ -28,10 +30,23 @@ def change_plugin_priority(name: str, priority: int) -> None:
_PRIORITIES[i] = (name, priority)


def remove_plugin(name: str) -> None:
for i, pp_pair in enumerate(_PRIORITIES):
if name == pp_pair[0]:
_PRIORITIES.pop(i)
_save_priorities()
return
logging.warning(f'Plugin "{name}" was not found in the priority list')


def length() -> int:
return len(_PRIORITIES)


def _save_priorities() -> None:
with open(f'{PLUGINS_DIR}/priorities.csv', 'w', encoding='utf-8') as f:
for prio in _PRIORITIES:
f.write(f'{prio[0]},{prio[1]}')
f.write(f'{prio[0]},{prio[1]}\n')
f.close()


Expand Down

0 comments on commit c216660

Please sign in to comment.