From cb169e34e3e3156397403157d0ea76eaa093472c Mon Sep 17 00:00:00 2001 From: nggit Date: Tue, 7 Nov 2023 21:05:37 +0700 Subject: [PATCH] release 0.0.207 (#27) * add auto reload support * fix python tests below 3.8 * remove --reuse-port, as it's enabled by default * add timeout option in tasks.create --------- Co-authored-by: nggit <12218311+nggit@users.noreply.github.com> --- alltests.py | 1 + example_uvloop.py | 14 +-- setup.py | 8 +- tests/asgi_server.py | 16 ++-- tests/http_server.py | 14 ++- tests/test_http_client.py | 45 ++++++++- tests/utils.py | 10 +- tremolo/__init__.py | 2 +- tremolo/__main__.py | 7 +- tremolo/http_server.py | 3 +- tremolo/lib/http_response.py | 6 +- tremolo/lib/tasks.py | 9 +- tremolo/tremolo.py | 179 ++++++++++++++++++++++++++++------- tremolo/utils.py | 11 ++- 14 files changed, 244 insertions(+), 81 deletions(-) diff --git a/alltests.py b/alltests.py index 8ac1af0..2997f44 100644 --- a/alltests.py +++ b/alltests.py @@ -20,6 +20,7 @@ kwargs=dict(host=HTTP_HOST, port=HTTP_PORT, debug=False, + reload=True, client_max_body_size=73728)) ) processes.append(mp.Process( diff --git a/example_uvloop.py b/example_uvloop.py index f55168d..8c93eac 100644 --- a/example_uvloop.py +++ b/example_uvloop.py @@ -2,6 +2,13 @@ import asyncio +try: + import uvloop + + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +except ImportError: + print('INFO: uvloop is not installed') + import tremolo @@ -22,11 +29,4 @@ async def app(scope, receive, send): }) if __name__ == '__main__': - try: - import uvloop - - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - except ImportError: - print('INFO: uvloop is not installed') - tremolo.run(app, host='0.0.0.0', port=8000, debug=True) diff --git a/setup.py b/setup.py index 5acf42d..87fd668 100644 --- a/setup.py +++ b/setup.py @@ -2,18 +2,12 @@ from setuptools import setup -if __name__ == '__main__': - import sys - - if len(sys.argv) == 1: - sys.argv.append('install') - with open('README.md', 'r') as f: long_description = f.read() setup( name='tremolo', - version='0.0.205', + version='0.0.207', license='MIT', author='nggit', author_email='contact@anggit.com', diff --git a/tests/asgi_server.py b/tests/asgi_server.py index 46a41cb..9da40c2 100644 --- a/tests/asgi_server.py +++ b/tests/asgi_server.py @@ -2,7 +2,6 @@ __all__ = ('app', 'ASGI_HOST', 'ASGI_PORT') -import asyncio # noqa: E402 import os # noqa: E402 import sys # noqa: E402 @@ -16,7 +15,13 @@ from tests.http_server import HTTP_PORT, TEST_FILE # noqa: E402 -ASGI_HOST = '::' +if sys.version_info[:2] < (3, 8): + # on Windows, Python versions below 3.8 don't properly support + # dual-stack IPv4/6. https://github.com/python/cpython/issues/73701 + ASGI_HOST = '0.0.0.0' +else: + ASGI_HOST = '::' + ASGI_PORT = HTTP_PORT + 10 @@ -116,11 +121,4 @@ async def app(scope, receive, send): }) if __name__ == '__main__': - try: - import uvloop - - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - except ImportError: - print('INFO: uvloop is not installed') - tremolo.run(app, host=ASGI_HOST, port=ASGI_PORT, debug=True, worker_num=2) diff --git a/tests/http_server.py b/tests/http_server.py index 26c3e02..aad9ca3 100644 --- a/tests/http_server.py +++ b/tests/http_server.py @@ -270,6 +270,17 @@ async def timeouts(request=None, **_): # should raise a TimeoutError and ended up with a RequestTimeout await request.recv(100) + +@app.route('/reload') +async def reload(request=None, **_): + yield b'%d' % hash(app) + + if request.query_string != b'': + mtime = float(request.query_string) + + # simulate a code change + os.utime(TEST_FILE, (mtime, mtime)) + # test multiple ports app.listen(HTTP_PORT + 1, request_timeout=2, keepalive_timeout=2) app.listen(HTTP_PORT + 2) @@ -279,6 +290,7 @@ async def timeouts(request=None, **_): app.listen('tremolo-test', debug=False, client_max_body_size=73728) if __name__ == '__main__': - app.run(HTTP_HOST, port=HTTP_PORT, debug=True, client_max_body_size=73728) + app.run(HTTP_HOST, port=HTTP_PORT, debug=True, reload=True, + client_max_body_size=73728) # END diff --git a/tests/test_http_client.py b/tests/test_http_client.py index 9c03830..ccf9e26 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -475,11 +475,17 @@ def test_headertoolarge(self): self.assertEqual(data, (b'', b'')) def test_requesttimeout(self): - data = getcontents( - host=HTTP_HOST, - port=HTTP_PORT + 1, - raw=b'GET / HTTP/1.1\r\nHost: localhost:%d\r\n' % (HTTP_PORT + 1) - ) + for _ in range(10): + try: + data = getcontents( + host=HTTP_HOST, + port=HTTP_PORT + 1, + raw=b'GET / HTTP/1.1\r\n' + b'Host: localhost:%d\r\n' % (HTTP_PORT + 1) + ) + break + except ConnectionResetError: + continue self.assertEqual(data, (b'', b'')) @@ -740,6 +746,34 @@ def test_websocket(self): self.assertEqual(payload[:7], data_out[:7]) + def test_reload(self): + header, body1 = getcontents(host=HTTP_HOST, + port=HTTP_PORT, + method='GET', + url='/reload?%f' % time.time(), + version='1.0') + + self.assertFalse(body1 == b'') + + for _ in range(10): + time.sleep(1) + + try: + header, body2 = getcontents(host=HTTP_HOST, + port=HTTP_PORT, + method='GET', + url='/reload', + version='1.0') + except ConnectionResetError: + continue + + self.assertFalse(body2 == b'') + + if body2 != body1: + break + + self.assertFalse(body2 == body1) + if __name__ == '__main__': mp.set_start_method('spawn') @@ -749,6 +783,7 @@ def test_websocket(self): kwargs=dict(host=HTTP_HOST, port=HTTP_PORT, debug=False, + reload=True, client_max_body_size=73728) ) diff --git a/tests/utils.py b/tests/utils.py index 97d6699..ea945b2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -57,14 +57,14 @@ def getcontents( family = socket.AF_INET if ':' in host: - if host == '::': - host = '127.0.0.1' - else: - family = socket.AF_INET6 + family = socket.AF_INET6 + + if host in ('0.0.0.0', '::'): + host = 'localhost' with socket.socket(family, socket.SOCK_STREAM) as sock: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - sock.settimeout(5) + sock.settimeout(10) while sock.connect_ex((host, port)) != 0: time.sleep(1) diff --git a/tremolo/__init__.py b/tremolo/__init__.py index d5c672d..aa7d888 100644 --- a/tremolo/__init__.py +++ b/tremolo/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.0.205' +__version__ = '0.0.207' from .tremolo import Tremolo # noqa: E402 from . import exceptions # noqa: E402,F401 diff --git a/tremolo/__main__.py b/tremolo/__main__.py index b6bf70e..70d7842 100644 --- a/tremolo/__main__.py +++ b/tremolo/__main__.py @@ -23,7 +23,6 @@ print(' --bind Address to bind.') print(' Instead of using --host or --port') print(' E.g. "127.0.0.1:8000" or "/tmp/file.sock"') # noqa: E501 - print(' --reuse-port Use SO_REUSEPORT when available') print(' --worker-num Number of worker processes. Defaults to 1') # noqa: E501 print(' --backlog Maximum number of pending connections') # noqa: E501 print(' Defaults to 100') @@ -33,7 +32,9 @@ print(' E.g. "/path/to/privkey.pem"') print(' --debug Enable debug mode.') print(' Intended for development') - print(' --no-ws Disable built-in WebSocket support.') # noqa: E501 + print(' --reload Enable auto reload on code changes.') # noqa: E501 + print(' Intended for development') + print(' --no-ws Disable built-in WebSocket support') # noqa: E501 print(' --log-level Defaults to "DEBUG". See') print(' https://docs.python.org/3/library/logging.html#levels') # noqa: E501 print(' --download-rate Limits the sending speed to the client') # noqa: E501 @@ -53,7 +54,7 @@ sys.exit() elif sys.argv[i - 1] == '--no-ws': options['ws'] = False - elif sys.argv[i - 1] in ('--debug', '--reuse-port'): + elif sys.argv[i - 1] in ('--debug', '--reload'): options[sys.argv[i - 1].lstrip('-').replace('-', '_')] = True elif sys.argv[i - 1] in ('--host', '--log-level', diff --git a/tremolo/http_server.py b/tremolo/http_server.py index 4dfddd7..11f8083 100644 --- a/tremolo/http_server.py +++ b/tremolo/http_server.py @@ -55,7 +55,8 @@ def connection_made(self, transport): def connection_lost(self, exc): if self._middlewares['close']: - self.loop.create_task(self._connection_lost(exc)) + task = self.loop.create_task(self._connection_lost(exc)) + self.loop.call_at(self.loop.time() + 30, task.cancel) else: super().connection_lost(exc) diff --git a/tremolo/lib/http_response.py b/tremolo/lib/http_response.py index 8eff8a0..5214807 100644 --- a/tremolo/lib/http_response.py +++ b/tremolo/lib/http_response.py @@ -298,10 +298,10 @@ async def sendfile( self._request.context.RESPONSE_SENDFILE_HANDLE.close ) - file_size = os.stat(path).st_size - mtime = os.path.getmtime(path) + st = os.stat(path) + file_size = st.st_size mdate = time.strftime('%a, %d %b %Y %H:%M:%S GMT', - time.gmtime(mtime)).encode('latin-1') + time.gmtime(st.st_mtime)).encode('latin-1') if (self._request.version == b'1.1' and b'range' in self._request.headers): diff --git a/tremolo/lib/tasks.py b/tremolo/lib/tasks.py index b20d325..7a49cfc 100644 --- a/tremolo/lib/tasks.py +++ b/tremolo/lib/tasks.py @@ -11,8 +11,13 @@ def __init__(self, tasks, loop=None): self._loop = loop self._tasks = tasks - def create(self, coro): + def create(self, coro, timeout=0): task = self._loop.create_task(coro) - self._tasks.append(task.cancel) + + if timeout > 0: + self._loop.call_at(self._loop.time() + timeout, task.cancel) + else: + # until the connection is lost + self._tasks.append(task.cancel) return task diff --git a/tremolo/tremolo.py b/tremolo/tremolo.py index 44ab6ed..d4d25e3 100644 --- a/tremolo/tremolo.py +++ b/tremolo/tremolo.py @@ -13,11 +13,11 @@ import time # noqa: E402 from functools import wraps # noqa: E402 -from importlib import import_module # noqa: E402 +from importlib import import_module, reload as reload_module # noqa: E402 from shutil import get_terminal_size # noqa: E402 from . import __version__, handlers # noqa: E402 -from .utils import log_date, server_date # noqa: E402 +from .utils import file_signature, log_date, server_date # noqa: E402 from .lib.connections import KeepAliveConnections # noqa: E402 from .lib.contexts import ServerContext as WorkerContext # noqa: E402 from .lib.locks import ServerLock # noqa: E402 @@ -263,7 +263,6 @@ async def _serve(self, host, port, **options): pools = { 'queue': QueuePool(1024, self._logger) } - lifespan = None if 'app' in options and isinstance(options['app'], str): from .asgi_lifespan import ASGILifespan @@ -278,13 +277,14 @@ async def _serve(self, host, port, **options): options['app'] += ':app' path, attr_name = options['app'].rsplit(':', 1) - dir_name, base_name = os.path.split(path) + options['app_dir'], base_name = os.path.split( + os.path.abspath(path)) module_name = os.path.splitext(base_name)[0] - if dir_name == '': - dir_name = os.getcwd() + if options['app_dir'] == '': + options['app_dir'] = os.getcwd() - sys.path.insert(0, dir_name) + sys.path.insert(0, options['app_dir']) options['app'] = getattr(import_module(module_name), attr_name) @@ -365,6 +365,9 @@ async def _serve(self, host, port, **options): print() + paths = [path for path in sys.path + if not options['app_dir'].startswith(path)] + modules = {} process_num = 1 # serve forever @@ -379,6 +382,43 @@ async def _serve(self, host, port, **options): # update server date server_info['date'] = server_date() + # detect code changes + if options.get('reload', False): + for module in (dict(modules) or sys.modules.values()): + if hasattr(module, '__file__'): + for path in paths: + if (module.__file__ is None or + module.__file__.startswith(path)): + break + else: + if not os.path.exists(module.__file__): + if module in modules: + del modules[module] + + continue + + _sign = file_signature(module.__file__) + + if module in modules: + if modules[module] == _sign: + # file not modified + continue + + modules[module] = _sign + else: + modules[module] = _sign + continue + + self._logger.info('reload: %s', + module.__file__) + + server.close() + await server.wait_closed() + + # essentially means sys.exit(0) + # to trigger a reload + return + if options['_conn'].poll(): break @@ -389,7 +429,7 @@ async def _serve(self, host, port, **options): server.close() await server.wait_closed() - if lifespan is None: + if options['app'] is None: i = len(self.events['worker_stop']) while i > 0: @@ -436,6 +476,12 @@ def _worker(self, host, port, **kwargs): self._loop.create_task(self._stop(task)) self._loop.run_forever() finally: + exc = task.exception() + + # to avoid None, SystemExit, etc. for being printed + if isinstance(exc, Exception): + self._logger.error(exc) + self._loop.close() def create_sock(self, host, port, reuse_port=True): @@ -447,6 +493,9 @@ def create_sock(self, host, port, reuse_port=True): sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) if host == '::' and hasattr(socket, 'IPPROTO_IPV6'): + # on Windows, Python versions below 3.8 + # don't properly support dual-stack IPv4/6. + # https://github.com/python/cpython/issues/73701 sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) else: @@ -482,6 +531,18 @@ def create_sock(self, host, port, reuse_port=True): return sock + def close_sock(self, sock): + try: + sock.shutdown(socket.SHUT_RDWR) + sock.close() + + if sock.family.name == 'AF_UNIX': + os.unlink(sock.getsockname()) + except FileNotFoundError: + pass + except OSError: + sock.close() + def run(self, host=None, port=0, reuse_port=True, worker_num=1, **kwargs): kwargs['log_level'] = kwargs.get('log_level', 'DEBUG').upper() terminal_width = min(get_terminal_size()[0], 72) @@ -495,26 +556,31 @@ def run(self, host=None, port=0, reuse_port=True, worker_num=1, **kwargs): ) print('-' * terminal_width) - if 'app' in kwargs: - if not isinstance(kwargs['app'], str): - import __main__ + import __main__ - if hasattr(__main__, '__file__'): - for attr_name in dir(__main__): - if attr_name.startswith('__'): - continue + if 'app' in kwargs: + if (not isinstance(kwargs['app'], str) and + hasattr(__main__, '__file__')): + for attr_name in dir(__main__): + if attr_name.startswith('__'): + continue - if getattr(__main__, attr_name) == kwargs['app']: - break - else: - attr_name = 'app' + if getattr(__main__, attr_name) == kwargs['app']: + break + else: + attr_name = 'app' - kwargs['app'] = '%s:%s' % (__main__.__file__, attr_name) + kwargs['app'] = '%s:%s' % (__main__.__file__, attr_name) locks = [] else: locks = [mp.Lock() for _ in range(kwargs.get('locks', 16))] + if hasattr(__main__, '__file__'): + kwargs['app_dir'], base_name = os.path.split( + os.path.abspath(__main__.__file__)) + module_name = os.path.splitext(base_name)[0] + if kwargs['log_level'] in ('DEBUG', 'INFO'): print('Routes:') @@ -542,6 +608,9 @@ def run(self, host=None, port=0, reuse_port=True, worker_num=1, **kwargs): else: self.listen(port, host=host, **kwargs) + if worker_num < 1: + raise ValueError('worker_num must be greater than 0') + try: worker_num = min(worker_num, len(os.sched_getaffinity(0))) except AttributeError: @@ -608,7 +677,48 @@ def run(self, host=None, port=0, reuse_port=True, worker_num=1, **kwargs): try: for i, (parent_conn, p, args, options) in enumerate(processes): if not p.is_alive(): - print('A worker process died. Restarting...') + if p.exitcode == 0: + print('Reloading...') + + if 'app' not in kwargs: + for module in list(sys.modules.values()): + if (hasattr(module, '__file__') and + module.__name__ not in ( + '__main__', + '__mp_main__', + 'tremolo') and + not module.__name__.startswith( + 'tremolo.') and + module.__file__ is not None and + module.__file__.startswith( + kwargs['app_dir']) and + os.path.exists(module.__file__)): + reload_module(module) + + if module_name in sys.modules: + _module = sys.modules[module_name] + else: + _module = import_module(module_name) + + # we need to update/rebind objects like + # routes, middleware, etc. + for attr_name in dir(_module): + if attr_name.startswith('__'): + continue + + attr = getattr(_module, attr_name) + + if isinstance(attr, self.__class__): + self.__dict__.update(attr.__dict__) + else: + print('A worker process died. Restarting...') + + if p.exitcode != 0 or hasattr(socks[args], 'share'): + # renew socket + # this is a workaround, especially on Windows + socks[args] = self.create_sock( + *args, options.get('reuse_port', reuse_port) + ) parent_conn.close() parent_conn, child_conn = mp.Pipe() @@ -624,10 +734,10 @@ def run(self, host=None, port=0, reuse_port=True, worker_num=1, **kwargs): ) p.start() - pid = parent_conn.recv() + child_pid = parent_conn.recv() if hasattr(socks[args], 'share'): - parent_conn.send(socks[args].share(pid)) + parent_conn.send(socks[args].share(child_pid)) else: parent_conn.send(socks[args].fileno()) @@ -637,9 +747,15 @@ def run(self, host=None, port=0, reuse_port=True, worker_num=1, **kwargs): processes[i] = (parent_conn, p, args, options) # response ping from child - while parent_conn.poll(): - parent_conn.recv() - parent_conn.send(len(processes)) + while True: + try: + if not parent_conn.poll(): + break + + parent_conn.recv() + parent_conn.send(len(processes)) + except BrokenPipeError: + break time.sleep(1) except KeyboardInterrupt: @@ -654,13 +770,4 @@ def run(self, host=None, port=0, reuse_port=True, worker_num=1, **kwargs): print('pid %d terminated' % p.pid) for sock in socks.values(): - try: - sock.shutdown(socket.SHUT_RDWR) - sock.close() - - if sock.family.name == 'AF_UNIX': - os.unlink(sock.getsockname()) - except FileNotFoundError: - pass - except OSError: - sock.close() + self.close_sock(sock) diff --git a/tremolo/utils.py b/tremolo/utils.py index e8a4510..210129e 100644 --- a/tremolo/utils.py +++ b/tremolo/utils.py @@ -1,11 +1,20 @@ # Copyright (c) 2023 nggit -__all__ = ('html_escape', 'log_date', 'server_date') +__all__ = ('file_signature', 'html_escape', 'log_date', 'server_date') + +import os # noqa: E402 +import stat # noqa: E402 from datetime import datetime # noqa: E402 from html import escape # noqa: E402 +def file_signature(path): + st = os.stat(path) + + return (stat.S_IFMT(st.st_mode), st.st_size, st.st_mtime) + + def html_escape(data): if isinstance(data, str): return escape(data)