From 8c55788673609871273b30dba89c7a74d332884d Mon Sep 17 00:00:00 2001 From: haliphax Date: Wed, 10 Apr 2024 12:17:57 -0500 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20ground=20work=20for=20forc?= =?UTF-8?q?ibly=20closing=20connections?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- xthulu/ssh/__init__.py | 7 ++++--- xthulu/ssh/context/__init__.py | 4 ++-- xthulu/ssh/exceptions.py | 4 ++++ xthulu/ssh/process_factory.py | 5 ++++- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/xthulu/ssh/__init__.py b/xthulu/ssh/__init__.py index 29b88b5..8a699f6 100644 --- a/xthulu/ssh/__init__.py +++ b/xthulu/ssh/__init__.py @@ -8,7 +8,7 @@ from tracemalloc import start # 3rd party -from asyncssh import listen +from asyncssh import SSHAcceptor, listen # local from ..configuration import get_config @@ -20,7 +20,7 @@ from .server import SSHServer -async def start_server(): +async def start_server() -> SSHAcceptor: """Start the SSH server.""" register_encodings() @@ -49,4 +49,5 @@ async def start_server(): start() await res.db.set_bind(res.db.bind) - await listen(**kwargs) + + return await listen(**kwargs) diff --git a/xthulu/ssh/context/__init__.py b/xthulu/ssh/context/__init__.py index 944b048..eecc4b8 100644 --- a/xthulu/ssh/context/__init__.py +++ b/xthulu/ssh/context/__init__.py @@ -21,7 +21,7 @@ from ...models import User from ...scripting import load_userland_module from ..console import XthuluConsole -from ..exceptions import Goto, ProcessClosing +from ..exceptions import Goto, ProcessClosing, ProcessForciblyClosed from ..structs import Script from .lock_manager import _LockManager from .logger_adapter import ContextLoggerAdapter @@ -275,7 +275,7 @@ async def runscript(self, script: Script) -> Any: mod = load_userland_module(script.name) main: Callable[..., Any] = getattr(mod, "main") return await main(self, *script.args, **script.kwargs) - except (ProcessClosing, Goto): + except (ProcessClosing, ProcessForciblyClosed, Goto): raise except Exception: message = f"Exception in script {script.name}" diff --git a/xthulu/ssh/exceptions.py b/xthulu/ssh/exceptions.py index 19107ce..63fb375 100644 --- a/xthulu/ssh/exceptions.py +++ b/xthulu/ssh/exceptions.py @@ -20,3 +20,7 @@ def __init__(self, script: str, *args, **kwargs): class ProcessClosing(Exception): """Thrown when the `asyncssh.SSHServerProcess` is closing""" + + +class ProcessForciblyClosed(Exception): + """Thrown when the process is being forcibly closed by the server""" diff --git a/xthulu/ssh/process_factory.py b/xthulu/ssh/process_factory.py index ee41a3e..a4e835f 100644 --- a/xthulu/ssh/process_factory.py +++ b/xthulu/ssh/process_factory.py @@ -12,7 +12,7 @@ from ..events.structs import EventData from .console import XthuluConsole from .context import SSHContext -from .exceptions import Goto, ProcessClosing +from .exceptions import Goto, ProcessClosing, ProcessForciblyClosed from .structs import Script @@ -68,6 +68,9 @@ async def input_loop(): # process is likely closing break + except ProcessForciblyClosed: + break + except TimeoutError: cx.log.warn("Timed out") cx.echo("\n\n[bright_white on red] TIMED OUT [/]\n\n")