From de3bb1d0a28f04895f1850b8fb841c7755e5d3f6 Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sat, 29 Oct 2022 19:52:33 +0200 Subject: [PATCH] Add support for mount routes (#694) * cleanup * refactor * simplified design * fix ci * finished taversal * simplified implementation * added docs * Update starlite/asgi/routing_trie/validate.py Co-authored-by: provinzkraut <25355197+provinzkraut@users.noreply.github.com> * Update starlite/router.py Co-authored-by: provinzkraut <25355197+provinzkraut@users.noreply.github.com> * Update starlite/router.py Co-authored-by: provinzkraut <25355197+provinzkraut@users.noreply.github.com> * addressed revoew comments * 1.35.0 Co-authored-by: provinzkraut <25355197+provinzkraut@users.noreply.github.com> --- CHANGELOG.md | 30 +- docs/usage/1-routing/5-mounting-asgi-apps.md | 23 ++ .../routing}/__init__.py | 0 examples/routing/mount_custom_app.py | 24 ++ examples/routing/mounting_starlette_app.py | 30 ++ examples/tests/routing/test_mounting.py | 32 ++ mkdocs.yml | 1 + mypy.ini | 3 + poetry.lock | 42 +-- pyproject.toml | 3 +- starlite/__init__.py | 2 +- starlite/app.py | 246 ++------------ starlite/asgi.py | 318 ------------------ starlite/asgi/__init__.py | 3 + starlite/asgi/asgi_router.py | 199 +++++++++++ starlite/asgi/routing_trie/__init__.py | 6 + starlite/asgi/routing_trie/mapping.py | 156 +++++++++ starlite/asgi/routing_trie/traversal.py | 199 +++++++++++ starlite/asgi/routing_trie/types.py | 61 ++++ starlite/asgi/routing_trie/utils.py | 23 ++ starlite/asgi/routing_trie/validate.py | 29 ++ starlite/asgi/utils.py | 42 +++ starlite/config/compression.py | 8 +- starlite/config/logging.py | 43 +-- starlite/config/static_files.py | 9 +- starlite/datastructures/provide.py | 2 +- starlite/handlers/asgi.py | 13 + starlite/logging/picologging.py | 2 +- starlite/openapi/parameters.py | 8 +- starlite/router.py | 18 +- starlite/routes/base.py | 20 +- starlite/testing/test_client/transport.py | 4 +- starlite/types/internal_types.py | 8 +- starlite/utils/exception.py | 4 +- starlite/utils/extractors.py | 2 +- tests/conftest.py | 19 +- tests/kwargs/test_path_params.py | 1 + .../test_exception_handler_middleware.py | 4 +- tests/middleware/test_middleware_handling.py | 2 +- tests/openapi/test_request_body.py | 2 +- tests/routing/test_path_mounting.py | 80 +++++ tests/routing/test_path_resolution.py | 4 +- tests/routing/test_route_map.py | 155 --------- .../test_static_files.py | 33 +- tests/test_guards.py | 28 +- 45 files changed, 1117 insertions(+), 824 deletions(-) create mode 100644 docs/usage/1-routing/5-mounting-asgi-apps.md rename {tests/static_files => examples/routing}/__init__.py (100%) create mode 100644 examples/routing/mount_custom_app.py create mode 100644 examples/routing/mounting_starlette_app.py create mode 100644 examples/tests/routing/test_mounting.py delete mode 100644 starlite/asgi.py create mode 100644 starlite/asgi/__init__.py create mode 100644 starlite/asgi/asgi_router.py create mode 100644 starlite/asgi/routing_trie/__init__.py create mode 100644 starlite/asgi/routing_trie/mapping.py create mode 100644 starlite/asgi/routing_trie/traversal.py create mode 100644 starlite/asgi/routing_trie/types.py create mode 100644 starlite/asgi/routing_trie/utils.py create mode 100644 starlite/asgi/routing_trie/validate.py create mode 100644 starlite/asgi/utils.py create mode 100644 tests/routing/test_path_mounting.py delete mode 100644 tests/routing/test_route_map.py rename tests/{static_files => routing}/test_static_files.py (76%) diff --git a/CHANGELOG.md b/CHANGELOG.md index b636588b78..ee643da5f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,21 @@ # Changelog +[1.35.0] + +- add context-manager when using SQLAlchemy sessions. +- add support for mounting ASGI applications. +- fix `SQLAlchemyPlugin.to_dict()` where instance has relationship raising an exception. +- update route registration to ensure unique handlers. +- update routing logic to use a cleaner architecture. +- update sessions to support explicitly setting to `Empty`. +- update test client to run session creation in the client's portal. + [1.34.0] -- Add support for server-side sessions -- Fix an issue where header values would be forced to lower case -- Add a `__test__ = False` attribute to the `TestClient` so it won't get collected by pytest - together with an async test +- add a `__test__ = False` attribute to the `TestClient` so it won't get collected by pytest together with an async + test. +- add support for server-side sessions. +- fix an issue where header values would be forced to lower case. [1.33.0] @@ -102,7 +112,8 @@ - add `**kwargs` support to route handlers. - breaking: remove `create_test_request`. -- breaking: update Starlette to version `0.21.0`. This version changes the TestClient to use `httpx` instead of `requests`, which is a breaking change. +- breaking: update Starlette to version `0.21.0`. This version changes the TestClient to use `httpx` instead + of `requests`, which is a breaking change. - fix add default empty session to `RequestFactory`. [1.21.2] @@ -124,7 +135,8 @@ [1.20.0] -- update ASGI typings (`scope`, `receive`, `send`, `message` and `ASGIApp`) to use strong types derived from [asgiref](https://github.com/django/asgiref). +- update ASGI typings (`scope`, `receive`, `send`, `message` and `ASGIApp`) to use strong types derived + from [asgiref](https://github.com/django/asgiref). - update `SessionMiddleware` to use custom serializer used on request. - update `openapi-pydantic-schema` to `v1.3.0` adding support for `__schema_name__`. @@ -252,7 +264,8 @@ [1.8.0] - add [Stoplights Elements](https://stoplight.io/open-source/elements) OpenAPI support @aedify-swi -- breaking replace [openapi-pydantic-schema](https://github.com/kuimono/openapi-schema-pydantic) with [pydantic-openapi-schema](https://github.com/starlite-api/pydantic-openapi-schema). +- breaking replace [openapi-pydantic-schema](https://github.com/kuimono/openapi-schema-pydantic) + with [pydantic-openapi-schema](https://github.com/starlite-api/pydantic-openapi-schema). [1.7.3] @@ -434,7 +447,8 @@ - add template support @ashwinvin. - update `starlite.request` by renaming it to `starlite.connection`. -- update the kwarg parsing and data injection logic to compute required kwargs for each route handler during application bootstrap. +- update the kwarg parsing and data injection logic to compute required kwargs for each route handler during application + bootstrap. - update the redoc UI path from `/schema/redoc` to `/schema` @yudjinn. [0.7.2] diff --git a/docs/usage/1-routing/5-mounting-asgi-apps.md b/docs/usage/1-routing/5-mounting-asgi-apps.md new file mode 100644 index 0000000000..554ad3ec4d --- /dev/null +++ b/docs/usage/1-routing/5-mounting-asgi-apps.md @@ -0,0 +1,23 @@ +# Mounting ASGI Apps + +Starlite support "mounting" ASGI applications on sub paths, that is - specifying a handler function that will handle all +requests addressed to a given path. + +```py title="Mounting an ASGI App" +--8<-- "examples/routing/mount_custom_app.py" +``` + +The handler function will receive all requests with a url that begins with `/some/sub-path`, e.g. `/some/sub-path` and +`/some/sub-path/abc` and `/some/sub-path/123/another/sub-path` etc. + +!!! info Technical Details + If we were to send a request to the above with the url `/some/sub-path/abc`, the handler will be invoked and + the value of `scope["path"]` will equal `/`. If we send a request to `/some/sub-path/abc`, it will also be invoked, + and `scope["path"]` will equal `/abc`. + +Mounting is especially useful when you need to combine components of other ASGI applications - for example, for 3rd part libraries. +The following example is identical in principle to the one above but it uses `Starlette`: + +```py title="Mounting a Starlette App" +--8<-- "examples/routing/mounting_starlette_app.py" +``` diff --git a/tests/static_files/__init__.py b/examples/routing/__init__.py similarity index 100% rename from tests/static_files/__init__.py rename to examples/routing/__init__.py diff --git a/examples/routing/mount_custom_app.py b/examples/routing/mount_custom_app.py new file mode 100644 index 0000000000..ee7ef7875d --- /dev/null +++ b/examples/routing/mount_custom_app.py @@ -0,0 +1,24 @@ +from typing import TYPE_CHECKING + +from starlite import Response, Starlite, asgi + +if TYPE_CHECKING: + from starlite.types import Receive, Scope, Send + + +@asgi("/some/sub-path", is_mount=True) +async def my_asgi_app(scope: "Scope", receive: "Receive", send: "Send") -> None: + """ + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + response = Response(content={"forwarded_path": scope["path"]}) + await response(scope, receive, send) + + +app = Starlite(route_handlers=[my_asgi_app]) diff --git a/examples/routing/mounting_starlette_app.py b/examples/routing/mounting_starlette_app.py new file mode 100644 index 0000000000..9a282c6997 --- /dev/null +++ b/examples/routing/mounting_starlette_app.py @@ -0,0 +1,30 @@ +from typing import TYPE_CHECKING + +from starlette.applications import Starlette +from starlette.responses import JSONResponse +from starlette.routing import Route + +from starlite import Starlite, asgi + +if TYPE_CHECKING: + from starlette.requests import Request + + +async def index(request: "Request") -> JSONResponse: + """A generic starlette handler.""" + return JSONResponse({"forwarded_path": request.url.path}) + + +starlette_app = asgi(path="/some/sub-path", is_mount=True)( + Starlette( + debug=True, + routes=[ + Route("/", index), + Route("/abc", index), + Route("/123/another/sub-path", index), + ], + ) +) + + +app = Starlite(route_handlers=[starlette_app]) diff --git a/examples/tests/routing/test_mounting.py b/examples/tests/routing/test_mounting.py new file mode 100644 index 0000000000..7fc55bf2ec --- /dev/null +++ b/examples/tests/routing/test_mounting.py @@ -0,0 +1,32 @@ +from typing import TYPE_CHECKING + +import pytest + +from examples.routing import mount_custom_app, mounting_starlette_app +from starlite.status_codes import HTTP_200_OK +from starlite.testing import TestClient + +if TYPE_CHECKING: + from starlite import Starlite + + +@pytest.mark.parametrize( + "app", + ( + mount_custom_app.app, + mounting_starlette_app.app, + ), +) +def test_mounting_asgi_app_example(app: "Starlite") -> None: + with TestClient(app) as client: + response = client.get("/some/sub-path") + assert response.status_code == HTTP_200_OK + assert response.json() == {"forwarded_path": "/"} + + response = client.get("/some/sub-path/abc") + assert response.status_code == HTTP_200_OK + assert response.json() == {"forwarded_path": "/abc"} + + response = client.get("/some/sub-path/123/another/sub-path") + assert response.status_code == HTTP_200_OK + assert response.json() == {"forwarded_path": "/123/another/sub-path"} diff --git a/mkdocs.yml b/mkdocs.yml index 1eb1eb5aa0..0eeab93202 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -60,6 +60,7 @@ nav: - usage/1-routing/2-routers.md - usage/1-routing/3-controllers.md - usage/1-routing/4-registering-components-multiple-times.md + - usage/1-routing/5-mounting-asgi-apps.md - Route Handlers: - usage/2-route-handlers/0-route-handlers-concept.md - usage/2-route-handlers/1-http-route-handlers.md diff --git a/mypy.ini b/mypy.ini index 2676a9b92d..6d0e993d72 100644 --- a/mypy.ini +++ b/mypy.ini @@ -35,3 +35,6 @@ ignore_missing_imports = True [mypy-mako.*] ignore_missing_imports = True + +[mypy-fakeredis.*] +ignore_missing_imports = True diff --git a/poetry.lock b/poetry.lock index bb8f2d221c..c397e702dd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1171,14 +1171,18 @@ black = [ {file = "black-22.10.0-1fixedarch-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:197df8509263b0b8614e1df1756b1dd41be6738eed2ba9e9769f3880c2b9d7b6"}, {file = "black-22.10.0-1fixedarch-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:2644b5d63633702bc2c5f3754b1b475378fbbfb481f62319388235d0cd104c2d"}, {file = "black-22.10.0-1fixedarch-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:e41a86c6c650bcecc6633ee3180d80a025db041a8e2398dcc059b3afa8382cd4"}, + {file = "black-22.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2039230db3c6c639bd84efe3292ec7b06e9214a2992cd9beb293d639c6402edb"}, {file = "black-22.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14ff67aec0a47c424bc99b71005202045dc09270da44a27848d534600ac64fc7"}, {file = "black-22.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:819dc789f4498ecc91438a7de64427c73b45035e2e3680c92e18795a839ebb66"}, + {file = "black-22.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5b9b29da4f564ba8787c119f37d174f2b69cdfdf9015b7d8c5c16121ddc054ae"}, {file = "black-22.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8b49776299fece66bffaafe357d929ca9451450f5466e997a7285ab0fe28e3b"}, {file = "black-22.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:21199526696b8f09c3997e2b4db8d0b108d801a348414264d2eb8eb2532e540d"}, {file = "black-22.10.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e464456d24e23d11fced2bc8c47ef66d471f845c7b7a42f3bd77bf3d1789650"}, {file = "black-22.10.0-cp37-cp37m-win_amd64.whl", hash = "sha256:9311e99228ae10023300ecac05be5a296f60d2fd10fff31cf5c1fa4ca4b1988d"}, + {file = "black-22.10.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:fba8a281e570adafb79f7755ac8721b6cf1bbf691186a287e990c7929c7692ff"}, {file = "black-22.10.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:915ace4ff03fdfff953962fa672d44be269deb2eaf88499a0f8805221bc68c87"}, {file = "black-22.10.0-cp38-cp38-win_amd64.whl", hash = "sha256:444ebfb4e441254e87bad00c661fe32df9969b2bf224373a448d8aca2132b395"}, + {file = "black-22.10.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:974308c58d057a651d182208a484ce80a26dac0caef2895836a92dd6ebd725e0"}, {file = "black-22.10.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72ef3925f30e12a184889aac03d77d031056860ccae8a1e519f6cbb742736383"}, {file = "black-22.10.0-cp39-cp39-win_amd64.whl", hash = "sha256:432247333090c8c5366e69627ccb363bc58514ae3e63f7fc75c54b1ea80fa7de"}, {file = "black-22.10.0-py3-none-any.whl", hash = "sha256:c957b2b4ea88587b46cf49d1dc17681c1e672864fd7af32fc1e9664d572b3458"}, @@ -1661,51 +1665,20 @@ nodeenv = [ orjson = [ {file = "orjson-3.8.1-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:a70aaa2e56356e58c6e1b49f7b7f069df5b15e55db002a74db3ff3f7af67c7ff"}, {file = "orjson-3.8.1-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:d45db052d01d0ab7579470141d5c3592f4402d43cfacb67f023bc1210a67b7bc"}, - {file = "orjson-3.8.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2aae92398c0023ac26a6cd026375f765ef5afe127eccabf563c78af7b572d59"}, - {file = "orjson-3.8.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0bd5b4e539db8a9635776bdf9a25c3db84e37165e65d45c8ca90437adc46d6d8"}, {file = "orjson-3.8.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21efb87b168066201a120b0f54a2381f6f51ff3727e07b3908993732412b314a"}, - {file = "orjson-3.8.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:e073338e422f518c1d4d80efc713cd17f3ed6d37c8c7459af04a95459f3206d1"}, - {file = "orjson-3.8.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:8f672f3987f6424f60ab2e86ea7ed76dd2806b8e9b506a373fc8499aed85ddb5"}, - {file = "orjson-3.8.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:231c30958ed99c23128a21993c5ac0a70e1e568e6a898a47f70d5d37461ca47c"}, - {file = "orjson-3.8.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:59b4baf71c9f39125d7e535974b146cc180926462969f6d8821b4c5e975e11b3"}, {file = "orjson-3.8.1-cp310-none-win_amd64.whl", hash = "sha256:fe25f50dc3d45364428baa0dbe3f613a5171c64eb0286eb775136b74e61ba58a"}, {file = "orjson-3.8.1-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:6802edf98f6918e89df355f56be6e7db369b31eed64ff2496324febb8b0aa43b"}, {file = "orjson-3.8.1-cp311-cp311-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:a4244f4199a160717f0027e434abb886e322093ceadb2f790ff0c73ed3e17662"}, - {file = "orjson-3.8.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6956cf7a1ac97523e96f75b11534ff851df99a6474a561ad836b6e82004acbb8"}, - {file = "orjson-3.8.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0b4e3857dd2416b479f700e9bdf4fcec8c690d2716622397d2b7e848f9833e50"}, {file = "orjson-3.8.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8873e490dea0f9cd975d66f84618b6fb57b1ba45ecb218313707a71173d764f"}, - {file = "orjson-3.8.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:124207d2cd04e845eaf2a6171933cde40aebcb8c2d7d3b081e01be066d3014b6"}, - {file = "orjson-3.8.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:d8ed77098c2e22181fce971f49a34204c38b79ca91c01d515d07015339ae8165"}, {file = "orjson-3.8.1-cp311-none-win_amd64.whl", hash = "sha256:8623ac25fa0850a44ac845e9333c4da9ae5707b7cec8ac87cbe9d4e41137180f"}, {file = "orjson-3.8.1-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:d67a0bd0283a3b17ac43c5ab8e4a7e9d3aa758d6ec5d51c232343c408825a5ad"}, {file = "orjson-3.8.1-cp37-cp37m-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:d89ef8a4444d83e0a5171d14f2ab4895936ab1773165b020f97d29cf289a2d88"}, - {file = "orjson-3.8.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:97839a6abbebb06099294e6057d5b3061721ada08b76ae792e7041b6cb54c97f"}, - {file = "orjson-3.8.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6071bcf51f0ae4d53b9d3e9164f7138164df4291c484a7b14562075aaa7a2b7b"}, {file = "orjson-3.8.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c15e7d691cee75b5192fc1fa8487bf541d463246dc25c926b9b40f5b6ab56770"}, - {file = "orjson-3.8.1-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:b9abc49c014def1b832fcd53bdc670474b6fe41f373d16f40409882c0d0eccba"}, - {file = "orjson-3.8.1-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:3fd5472020042482d7da4c26a0ee65dbd931f691e1c838c6cf4232823179ecc1"}, - {file = "orjson-3.8.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:e399ed1b0d6f8089b9b6ff2cb3e71ba63a56d8ea88e1d95467949795cc74adfd"}, - {file = "orjson-3.8.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:5e3db6496463c3000d15b7a712da5a9601c6c43682f23f81862fe1d2a338f295"}, - {file = "orjson-3.8.1-cp37-none-win_amd64.whl", hash = "sha256:0f21eed14697083c01f7e00a87e21056fc8fb5851e8a7bca98345189abcdb4d4"}, - {file = "orjson-3.8.1-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:5a9e324213220578d324e0858baeab47808a13d3c3fbc6ba55a3f4f069d757cf"}, - {file = "orjson-3.8.1-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:69097c50c3ccbcc61292192b045927f1688ca57ce80525dc5d120e0b91e19bb0"}, - {file = "orjson-3.8.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e7822cba140f7ca48ed0256229f422dbae69e3a3475176185db0c0538cfadb57"}, - {file = "orjson-3.8.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:03389e3750c521a7f3d4837de23cfd21a7f24574b4b3985c9498f440d21adb03"}, {file = "orjson-3.8.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b0f9d9b5c6692097de07dd0b2d5ff20fd135bacd1b2fb7ea383ee717a4150c93"}, - {file = "orjson-3.8.1-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:c2c9ef10b6344465fd5ac002be2d34f818211274dd79b44c75b2c14a979f84f3"}, - {file = "orjson-3.8.1-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:7adaac93678ac61f5dc070f615b18639d16ee66f6a946d5221dbf315e8b74bec"}, - {file = "orjson-3.8.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b0c1750f73658906b82cabbf4be2f74300644c17cb037fbc8b48d746c3b90c76"}, - {file = "orjson-3.8.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:da6306e1f03e7085fe0db61d4a3377f70c6fd865118d0afe17f80ae9a8f6f124"}, {file = "orjson-3.8.1-cp38-none-win_amd64.whl", hash = "sha256:f532c2cbe8c140faffaebcfb34d43c9946599ea8138971f181a399bec7d6b123"}, {file = "orjson-3.8.1-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:6a7b76d4b44bca418f7797b1e157907b56b7d31caa9091db4e99ebee51c16933"}, {file = "orjson-3.8.1-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:f850489d89ea12be486492e68f0fd63e402fa28e426d4f0b5fc1eec0595e6109"}, - {file = "orjson-3.8.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4449e70b98f3ad3e43958360e4be1189c549865c0a128e8629ec96ce92d251c3"}, - {file = "orjson-3.8.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:45357eea9114bd41ef19280066591e9069bb4f6f5bffd533e9bfc12a439d735f"}, {file = "orjson-3.8.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f5a9bc5bc4d730153529cb0584c63ff286d50663ccd48c9435423660b1bb12d"}, - {file = "orjson-3.8.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:a806aca6b80fa1d996aa16593e4995a71126a085ee1a59fff19ccad29a4e47fd"}, - {file = "orjson-3.8.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:395d02fd6be45f960da014372e7ecefc9e5f8df57a0558b7111a5fa8423c0669"}, - {file = "orjson-3.8.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:caff3c1e964cfee044a03a46244ecf6373f3c56142ad16458a1446ac6d69824a"}, - {file = "orjson-3.8.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5ded261268d5dfd307078fe3370295e5eb15bdde838bbb882acf8538e061c451"}, {file = "orjson-3.8.1-cp39-none-win_amd64.whl", hash = "sha256:45c1914795ffedb2970bfcd3ed83daf49124c7c37943ed0a7368971c6ea5e278"}, {file = "orjson-3.8.1.tar.gz", hash = "sha256:07c42de52dfef56cdcaf2278f58e837b26f5b5af5f1fd133a68c4af203851fc7"}, ] @@ -1859,6 +1832,13 @@ pyyaml = [ {file = "PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5"}, {file = "PyYAML-6.0-cp310-cp310-win32.whl", hash = "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513"}, {file = "PyYAML-6.0-cp310-cp310-win_amd64.whl", hash = "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a"}, + {file = "PyYAML-6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4b0ba9512519522b118090257be113b9468d804b19d63c71dbcf4a48fa32358"}, + {file = "PyYAML-6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81957921f441d50af23654aa6c5e5eaf9b06aba7f0a19c18a538dc7ef291c5a1"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afa17f5bc4d1b10afd4466fd3a44dc0e245382deca5b3c353d8b757f9e3ecb8d"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbad0e9d368bb989f4515da330b88a057617d16b6a8245084f1b05400f24609f"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432557aa2c09802be39460360ddffd48156e30721f5e8d917f01d31694216782"}, + {file = "PyYAML-6.0-cp311-cp311-win32.whl", hash = "sha256:bfaef573a63ba8923503d27530362590ff4f576c626d86a9fed95822a8255fd7"}, + {file = "PyYAML-6.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b45c0191e6d66c470b6cf1b9531a771a83c1c4208272ead47a3ae4f2f603bf"}, {file = "PyYAML-6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86"}, {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f"}, {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92"}, diff --git a/pyproject.toml b/pyproject.toml index 7da339e0dc..494a52e0f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "starlite" -version = "1.34.0" +version = "1.35.0" description = "Light-weight and flexible ASGI API Framework" authors = ["Na'aman Hirschfeld "] maintainers = [ @@ -110,6 +110,7 @@ disable = [ "cyclic-import", "duplicate-code", "fixme", + "import-outside-toplevel", "line-too-long", "missing-class-docstring", "missing-module-docstring", diff --git a/starlite/__init__.py b/starlite/__init__.py index dce9a89e57..362fccb79e 100644 --- a/starlite/__init__.py +++ b/starlite/__init__.py @@ -77,7 +77,7 @@ from starlite.response import Response from starlite.router import Router from starlite.routes import ASGIRoute, BaseRoute, HTTPRoute, WebSocketRoute -from starlite.testing import TestClient, create_test_client # type: ignore[no-redef] +from starlite.testing import TestClient, create_test_client from starlite.types.partial import Partial __all__ = ( diff --git a/starlite/app.py b/starlite/app.py index 675c5eb2d0..555ce09fdc 100644 --- a/starlite/app.py +++ b/starlite/app.py @@ -1,43 +1,31 @@ -from collections import defaultdict from datetime import date, datetime, time, timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, cast -from starlette.middleware import Middleware as StarletteMiddleware from starlette.middleware.cors import CORSMiddleware from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.staticfiles import StaticFiles from typing_extensions import TypedDict -from starlite.asgi import ( - ASGIRouter, - PathParameterTypePathDesignator, - PathParamNode, - RouteMapNode, -) +from starlite.asgi import ASGIRouter +from starlite.asgi.utils import get_route_handlers, wrap_in_exception_handler from starlite.config import AppConfig, CacheConfig, OpenAPIConfig from starlite.config.logging import get_logger_placeholder from starlite.connection import Request, WebSocket from starlite.datastructures.state import State -from starlite.exceptions import ( - ImproperlyConfiguredException, - NoRouteMatchFoundException, -) -from starlite.handlers.asgi import asgi +from starlite.exceptions import NoRouteMatchFoundException from starlite.handlers.http import HTTPRouteHandler from starlite.middleware.compression.base import CompressionMiddleware -from starlite.middleware.csrf import CSRFMiddleware -from starlite.middleware.exceptions import ExceptionHandlerMiddleware from starlite.router import Router -from starlite.routes import ASGIRoute, BaseRoute, HTTPRoute, WebSocketRoute +from starlite.routes import ASGIRoute, HTTPRoute, WebSocketRoute from starlite.signature import SignatureModelFactory +from starlite.types.internal_types import PathParameterDefinition from starlite.utils import as_async_callable_list, join_paths, unique if TYPE_CHECKING: from pydantic_openapi_schema.v3_1_0 import SecurityRequirement from pydantic_openapi_schema.v3_1_0.open_api import OpenAPI - from starlite.asgi import ComponentsSet, PathParamPlaceholderType from starlite.config import ( BaseLoggingConfig, CompressionConfig, @@ -49,7 +37,6 @@ from starlite.datastructures import CacheControlHeader, ETag, Provide from starlite.handlers.base import BaseRouteHandler from starlite.plugins.base import PluginProtocol - from starlite.routes.base import PathParameterDefinition from starlite.types import ( AfterExceptionHookHandler, AfterRequestHookHandler, @@ -109,22 +96,8 @@ class HandlerIndex(TypedDict): """Unique identifier of the handler. Either equal to the 'name' attribute or the __str__ value of the handler.""" -class HandlerNode(TypedDict): - """This class encapsulates a route handler node.""" - - asgi_app: "ASGIApp" - """ASGI App stack""" - handler: "RouteHandlerType" - """Route handler instance.""" - - class Starlite(Router): __slots__ = ( - "_init", - "_registered_routes", - "_route_handler_index", - "_route_mapping", - "_static_paths", "after_exception", "after_shutdown", "after_startup", @@ -146,7 +119,6 @@ class Starlite(Router): "on_startup", "openapi_config", "openapi_schema", - "plain_routes", "plugins", "request_class", "route_map", @@ -277,17 +249,12 @@ def __init__( websocket_class: An optional subclass of [WebSocket][starlite.connection.websocket.WebSocket] to use for websocket connections. """ - self._registered_routes: Set[BaseRoute] = set() - self._route_mapping: Dict[str, List[BaseRoute]] = defaultdict(list) - self._route_handler_index: Dict[str, "RouteHandlerType"] = {} - self._static_paths: Set[str] = set() self.openapi_schema: Optional["OpenAPI"] = None self.get_logger: "GetLogger" = get_logger_placeholder self.logger: Optional["Logger"] = None - self.plain_routes: Set[str] = set() - self.route_map: RouteMapNode = {} - self.routes: List[BaseRoute] = [] + self.routes: List[Union["HTTPRoute", "ASGIRoute", "WebSocketRoute"]] = [] self.state = State() + self.asgi_router = ASGIRouter(app=self) # creates app config object from parameters config = AppConfig( @@ -391,10 +358,8 @@ def __init__( for static_config in ( self.static_files_config if isinstance(self.static_files_config, list) else [self.static_files_config] ): - self._static_paths.add(static_config.path) - self.register(asgi(path=static_config.path, name=static_config.name)(static_config.to_static_files_app())) + self.register(static_config.to_static_files_app()) - self.asgi_router = ASGIRouter(app=self) self.asgi_handler = self._create_asgi_handler() async def __call__( @@ -434,22 +399,27 @@ def register(self, value: "ControllerRouterHandler") -> None: # type: ignore[ov None """ routes = super().register(value=value) + for route in routes: - route_handlers = self._get_route_handlers(route) - for route_handler in route_handlers: + route_handlers = get_route_handlers(route) + for route_handler in route_handlers: self._create_handler_signature_model(route_handler=route_handler) route_handler.resolve_guards() route_handler.resolve_middleware() + if isinstance(route_handler, HTTPRouteHandler): route_handler.resolve_before_request() route_handler.resolve_after_response() route_handler.resolve_response_handler() + if isinstance(route, HTTPRoute): route.create_handler_map() + elif isinstance(route, WebSocketRoute): route.handler_parameter_model = route.create_handler_kwargs_model(route.route_handler) - self._construct_route_map() + + self.asgi_router.construct_routing_trie() def get_handler_index_by_name(self, name: str) -> Optional[HandlerIndex]: """Receives a route handler name and returns an optional dictionary @@ -478,12 +448,12 @@ def handler() -> None: Returns: A [HandlerIndex][starlite.app.HandlerIndex] instance or None. """ - handler = self._route_handler_index.get(name) + handler = self.asgi_router.route_handler_index.get(name) if not handler: return None identifier = handler.name or str(handler) - routes = self._route_mapping[identifier] + routes = self.asgi_router.route_mapping[identifier] paths = sorted(unique([route.path for route in routes])) return HandlerIndex(handler=handler, paths=paths, identifier=identifier) @@ -526,25 +496,26 @@ def get_membership_details(group_id: int, user_id: int) -> None: output: List[str] = [] routes = sorted( - self._route_mapping[handler_index["identifier"]], key=lambda r: len(r.path_parameters), reverse=True + self.asgi_router.route_mapping[handler_index["identifier"]], + key=lambda r: len(r.path_parameters), + reverse=True, ) passed_parameters = set(path_parameters.keys()) selected_route = routes[-1] for route in routes: - if passed_parameters.issuperset({param["name"] for param in route.path_parameters}): + if passed_parameters.issuperset({param.name for param in route.path_parameters}): selected_route = route break for component in selected_route.path_components: - if isinstance(component, dict): - val = path_parameters.get(component["name"]) + if isinstance(component, PathParameterDefinition): + val = path_parameters.get(component.name) if not ( - isinstance(val, component["type"]) - or (component["type"] in allow_str_instead and isinstance(val, str)) + isinstance(val, component.type) or (component.type in allow_str_instead and isinstance(val, str)) ): raise NoRouteMatchFoundException( - f"Received type for path parameter {component['name']} doesn't match declared type {component['type']}" + f"Received type for path parameter {component.name} doesn't match declared type {component.type}" ) output.append(str(val)) else: @@ -586,7 +557,7 @@ def url_for_static_asset(self, name: str, file_path: str) -> str: if not isinstance(handler_fn, StaticFiles): raise NoRouteMatchFoundException(f"Handler with name {name} is not a static files handler") - return join_paths([handler_index["paths"][0], file_path]) # type: ignore [unreachable] + return join_paths([handler_index["paths"][0], file_path]) # type: ignore[unreachable] @property def route_handler_method_view(self) -> Dict[str, List[str]]: @@ -595,7 +566,7 @@ def route_handler_method_view(self) -> Dict[str, List[str]]: A dictionary mapping route handlers to paths. """ route_map: Dict[str, List[str]] = {} - for handler, routes in self._route_mapping.items(): + for handler, routes in self.asgi_router.route_mapping.items(): route_map[handler] = [route.path for route in routes] return route_map @@ -614,163 +585,8 @@ def _create_asgi_handler(self) -> "ASGIApp": asgi_handler = TrustedHostMiddleware(app=asgi_handler, allowed_hosts=self.allowed_hosts) # type: ignore if self.cors_config: asgi_handler = CORSMiddleware(app=asgi_handler, **self.cors_config.dict()) # type: ignore - return self._wrap_in_exception_handler(asgi_handler, exception_handlers=self.exception_handlers or {}) - - def _wrap_in_exception_handler(self, app: "ASGIApp", exception_handlers: "ExceptionHandlersMap") -> "ASGIApp": - """Wraps the given ASGIApp in an instance of - ExceptionHandlerMiddleware.""" - return ExceptionHandlerMiddleware(app=app, exception_handlers=exception_handlers, debug=self.debug) - - def _add_node_to_route_map(self, route: BaseRoute) -> RouteMapNode: - """Adds a new route path (e.g. '/foo/bar/{param:int}') into the - route_map tree. - - Inserts non-parameter paths ('plain routes') off the tree's root - node. For paths containing parameters, splits the path on '/' - and nests each path segment under the previous segment's node - (see prefix tree / trie). - """ - current_node = self.route_map - path = route.path - - if route.path_parameters or path in self._static_paths: - components = cast( - "List[Union[str, PathParamPlaceholderType, PathParameterDefinition]]", ["/", *route.path_components] - ) - for component in components: - components_set = cast("ComponentsSet", current_node["_components"]) - - if isinstance(component, dict): - # The rest of the path should be regarded as a parameter value. - if component["type"] is Path: - components_set.add(PathParameterTypePathDesignator) - # Represent path parameters using a special value - component = PathParamNode - - components_set.add(component) - - if component not in current_node: - current_node[component] = {"_components": set()} - current_node = cast("RouteMapNode", current_node[component]) - if "_static_path" in current_node: - raise ImproperlyConfiguredException("Cannot have configured routes below a static path") - else: - if path not in self.route_map: - self.route_map[path] = {"_components": set()} - self.plain_routes.add(path) - current_node = self.route_map[path] - self._configure_route_map_node(route, current_node) - return current_node - - @staticmethod - def _get_route_handlers(route: BaseRoute) -> List["RouteHandlerType"]: - """Retrieve handler(s) as a list for given route.""" - route_handlers: List["RouteHandlerType"] = [] - if isinstance(route, (WebSocketRoute, ASGIRoute)): - route_handlers.append(route.route_handler) - else: - route_handlers.extend(cast("HTTPRoute", route).route_handlers) - - return route_handlers - - def _store_handler_to_route_mapping(self, route: BaseRoute) -> None: - """Stores the mapping of route handlers to routes and to route handler - names. - - Args: - route: A Route instance. - - Returns: - None - """ - route_handlers = self._get_route_handlers(route) - - for handler in route_handlers: - if handler.name in self._route_handler_index and str(self._route_handler_index[handler.name]) != str( - handler - ): - raise ImproperlyConfiguredException( - f"route handler names must be unique - {handler.name} is not unique." - ) - identifier = handler.name or str(handler) - self._route_mapping[identifier].append(route) - self._route_handler_index[identifier] = handler - - def _configure_route_map_node(self, route: BaseRoute, node: RouteMapNode) -> None: - """Set required attributes and route handlers on route_map tree - node.""" - if "_path_parameters" not in node: - node["_path_parameters"] = route.path_parameters - if "_asgi_handlers" not in node: - node["_asgi_handlers"] = {} - if "_is_asgi" not in node: - node["_is_asgi"] = False - if route.path in self._static_paths: - if node["_components"]: - raise ImproperlyConfiguredException("Cannot have configured routes below a static path") - node["_static_path"] = route.path - node["_is_asgi"] = True - asgi_handlers = cast("Dict[str, HandlerNode]", node["_asgi_handlers"]) - if isinstance(route, HTTPRoute): - for method, handler_mapping in route.route_handler_map.items(): - handler, _ = handler_mapping - asgi_handlers[method] = HandlerNode( - asgi_app=self._build_route_middleware_stack(route, handler), - handler=handler, - ) - elif isinstance(route, WebSocketRoute): - asgi_handlers["websocket"] = HandlerNode( - asgi_app=self._build_route_middleware_stack(route, route.route_handler), - handler=route.route_handler, - ) - elif isinstance(route, ASGIRoute): - asgi_handlers["asgi"] = HandlerNode( - asgi_app=self._build_route_middleware_stack(route, route.route_handler), - handler=route.route_handler, - ) - node["_is_asgi"] = True - - def _construct_route_map(self) -> None: - """Create a map of the app's routes. - - This map is used in the asgi router to route requests. - """ - if "_components" not in self.route_map: - self.route_map["_components"] = set() - new_routes = [route for route in self.routes if route not in self._registered_routes] - for route in new_routes: - node = self._add_node_to_route_map(route) - if node["_path_parameters"] != route.path_parameters: - raise ImproperlyConfiguredException("Should not use routes with conflicting path parameters") - self._store_handler_to_route_mapping(route) - self._registered_routes.add(route) - - def _build_route_middleware_stack( - self, - route: Union[HTTPRoute, WebSocketRoute, ASGIRoute], - route_handler: "RouteHandlerType", - ) -> "ASGIApp": - """Constructs a middleware stack that serves as the point of entry for - each route.""" - - # we wrap the route.handle method in the ExceptionHandlerMiddleware - asgi_handler = self._wrap_in_exception_handler( - app=route.handle, exception_handlers=route_handler.resolve_exception_handlers() # type: ignore[arg-type] - ) - - if self.csrf_config: - asgi_handler = CSRFMiddleware(app=asgi_handler, config=self.csrf_config) - - for middleware in route_handler.resolve_middleware(): - if isinstance(middleware, StarletteMiddleware): - handler, kwargs = middleware - asgi_handler = handler(app=asgi_handler, **kwargs) - else: - asgi_handler = middleware(app=asgi_handler) # type: ignore - - # we wrap the entire stack again in ExceptionHandlerMiddleware - return self._wrap_in_exception_handler( - app=asgi_handler, exception_handlers=route_handler.resolve_exception_handlers() # pyright: ignore + return wrap_in_exception_handler( + debug=self.debug, app=asgi_handler, exception_handlers=self.exception_handlers or {} ) def _create_handler_signature_model(self, route_handler: "BaseRouteHandler") -> None: diff --git a/starlite/asgi.py b/starlite/asgi.py deleted file mode 100644 index e98a6c332b..0000000000 --- a/starlite/asgi.py +++ /dev/null @@ -1,318 +0,0 @@ -import re -from datetime import date, datetime, time, timedelta -from decimal import Decimal -from pathlib import Path -from traceback import format_exc -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Set, - Tuple, - Type, - Union, - cast, -) -from uuid import UUID - -from pydantic.datetime_parse import ( - parse_date, - parse_datetime, - parse_duration, - parse_time, -) - -from starlite.enums import ScopeType -from starlite.exceptions import ( - MethodNotAllowedException, - NotFoundException, - ValidationException, -) -from starlite.utils import AsyncCallable - -if TYPE_CHECKING: - from starlite.app import HandlerNode, Starlite - from starlite.routes.base import PathParameterDefinition - from starlite.types import ( - ASGIApp, - LifeSpanHandler, - LifeSpanReceive, - LifeSpanSend, - LifeSpanShutdownCompleteEvent, - LifeSpanShutdownFailedEvent, - LifeSpanStartupCompleteEvent, - LifeSpanStartupFailedEvent, - Receive, - RouteHandlerType, - Scope, - Send, - ) - - -class PathParamNode: - """Sentinel object to represent a path param in the route map.""" - - -class PathParameterTypePathDesignator: - """Sentinel object to a path parameter of type 'path'.""" - - -PathParamPlaceholderType = Type[PathParamNode] -TerminusNodePlaceholderType = Type[PathParameterTypePathDesignator] -RouteMapNode = Dict[Union[str, PathParamPlaceholderType], Any] -ComponentsSet = Set[Union[str, PathParamPlaceholderType, TerminusNodePlaceholderType]] - - -class ASGIRouter: - __slots__ = ("app",) - - def __init__( - self, - app: "Starlite", - ) -> None: - """This class is the Starlite ASGI router. It handles both the ASGI - lifespan event and routing connection requests. - - Args: - app: The Starlite app instance - """ - self.app = app - - def _traverse_route_map(self, path: str, scope: "Scope") -> Tuple[RouteMapNode, List[str]]: - """Traverses the application route mapping and retrieves the correct - node for the request url. - - Args: - path: The request's path. - scope: The ASGI connection scope. - - Raises: - NotFoundException: if no correlating node is found. - - Returns: - A tuple containing the target RouteMapNode and a list containing all path parameter values. - """ - path_params: List[str] = [] - current_node = self.app.route_map - components = ["/", *[component for component in path.split("/") if component]] - for idx, component in enumerate(components): - components_set = cast("ComponentsSet", current_node["_components"]) - if component in components_set: - current_node = cast("RouteMapNode", current_node[component]) - if "_static_path" in current_node: - self._handle_static_path(scope=scope, node=current_node) - break - continue - if PathParamNode in components_set: - current_node = cast("RouteMapNode", current_node[PathParamNode]) - if PathParameterTypePathDesignator in components_set: - path_params.append("/".join(path.split("/")[idx:])) - break - path_params.append(component) - continue - raise NotFoundException() - return current_node, path_params - - @staticmethod - def _handle_static_path(scope: "Scope", node: RouteMapNode) -> None: - """Normalize the static path and update scope so file resolution will - work as expected. - - Args: - scope: The ASGI connection scope. - node: Trie Node - - Returns: - None - """ - static_path = cast("str", node["_static_path"]) - if static_path != "/" and scope["path"].startswith(static_path): - start_idx = len(static_path) - scope["path"] = scope["path"][start_idx:] + "/" - - @staticmethod - def _parse_path_parameters( - path_parameter_definitions: List["PathParameterDefinition"], request_path_parameter_values: List[str] - ) -> Dict[str, Any]: - """Parses path parameters into their expected types. - - Args: - path_parameter_definitions: A list of [PathParameterDefinition][starlite.route.base.PathParameterDefinition] instances - request_path_parameter_values: A list of raw strings sent as path parameters as part of the request - - Raises: - ValidationException: if path parameter parsing fails - - Returns: - A dictionary mapping path parameter names to parsed values - """ - result: Dict[str, Any] = {} - parsers_map: Dict[Any, Callable] = { - str: str, - float: float, - int: int, - Decimal: Decimal, - UUID: UUID, - Path: lambda x: Path(re.sub("//+", "", (x.lstrip("/")))), - date: parse_date, - datetime: parse_datetime, - time: parse_time, - timedelta: parse_duration, - } - - try: - for idx, parameter_definition in enumerate(path_parameter_definitions): - raw_param_value = request_path_parameter_values[idx] - parameter_type = parameter_definition["type"] - parameter_name = parameter_definition["name"] - parser = parsers_map[parameter_type] - result[parameter_name] = parser(raw_param_value) - return result - except (ValueError, TypeError, KeyError) as e: # pragma: no cover - raise ValidationException( - f"unable to parse path parameters {','.join(request_path_parameter_values)}" - ) from e - - def _parse_scope_to_route(self, scope: "Scope") -> Tuple[Dict[str, "HandlerNode"], bool]: - """Given a scope object, retrieve the _asgi_handlers and _is_asgi - values from correct trie node.""" - - path = scope["path"].strip() - if path != "/" and path.endswith("/"): - path = path.rstrip("/") - if path in self.app.plain_routes: - current_node: RouteMapNode = self.app.route_map[path] - path_params: List[str] = [] - else: - current_node, path_params = self._traverse_route_map(path=path, scope=scope) - - scope["path_params"] = ( - self._parse_path_parameters( - path_parameter_definitions=current_node["_path_parameters"], request_path_parameter_values=path_params - ) - if path_params - else {} - ) - - asgi_handlers = cast("Dict[str, HandlerNode]", current_node["_asgi_handlers"]) - is_asgi = cast("bool", current_node["_is_asgi"]) - return asgi_handlers, is_asgi - - @staticmethod - def _resolve_handler_node( - scope: "Scope", asgi_handlers: Dict[str, "HandlerNode"], is_asgi: bool - ) -> Tuple["ASGIApp", "RouteHandlerType"]: - """Given a scope, returns the ASGI App and route handler for the - route.""" - if is_asgi: - node = asgi_handlers[ScopeType.ASGI] - elif scope["type"] == ScopeType.HTTP: - if scope["method"] not in asgi_handlers: - raise MethodNotAllowedException() - node = asgi_handlers[scope["method"]] - else: - node = asgi_handlers[ScopeType.WEBSOCKET] - return node["asgi_app"], node["handler"] - - async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: - """The main entry point to the Router class.""" - try: - asgi_handlers, is_asgi = self._parse_scope_to_route(scope=scope) - asgi_app, handler = self._resolve_handler_node(scope=scope, asgi_handlers=asgi_handlers, is_asgi=is_asgi) - except KeyError as e: - raise NotFoundException() from e - scope["route_handler"] = handler - await asgi_app(scope, receive, send) - - async def lifespan(self, receive: "LifeSpanReceive", send: "LifeSpanSend") -> None: - """Handles the ASGI "lifespan" event on application startup and - shutdown. - - Args: - receive: The ASGI receive function. - send: The ASGI send function. - - Returns: - None. - """ - message = await receive() - try: - shutdown_event: "LifeSpanShutdownCompleteEvent" = {"type": "lifespan.shutdown.complete"} - - if message["type"] == "lifespan.startup": - await self.startup() - startup_event: "LifeSpanStartupCompleteEvent" = {"type": "lifespan.startup.complete"} - await send(startup_event) - await receive() - else: - await self.shutdown() - await send(shutdown_event) - except BaseException as e: - if message["type"] == "lifespan.startup": - startup_failure_event: "LifeSpanStartupFailedEvent" = { - "type": "lifespan.startup.failed", - "message": format_exc(), - } - await send(startup_failure_event) - else: - shutdown_failure_event: "LifeSpanShutdownFailedEvent" = { - "type": "lifespan.shutdown.failed", - "message": format_exc(), - } - await send(shutdown_failure_event) - raise e - else: - await self.shutdown() - await send(shutdown_event) - - async def _call_lifespan_handler(self, handler: "LifeSpanHandler") -> None: - """Determines whether the lifecycle handler expects an argument, and if - so passes the `app.state` to it. If the handler is an async function, - it awaits the return. - - Args: - handler (LifeSpanHandler): sync or async callable that may or may not have an argument. - """ - async_callable = AsyncCallable(handler) # type: ignore - - if async_callable.num_expected_args > 0: - await async_callable(self.app.state) # type: ignore[arg-type] - else: - await async_callable() - - async def startup(self) -> None: - """Run any [LifeSpanHandlers][starlite.types.LifeSpanHandler] defined - in the application's `.on_startup` list. - - Calls the `before_startup` hook and `after_startup` hook - handlers respectively before and after calling in the lifespan - handlers. - """ - for hook in self.app.before_startup: - await hook(self.app) - - for handler in self.app.on_startup: - await self._call_lifespan_handler(handler) - - for hook in self.app.after_startup: - await hook(self.app) - - async def shutdown(self) -> None: - """Run any [LifeSpanHandlers][starlite.types.LifeSpanHandler] defined - in the application's `.on_shutdown` list. - - Calls the `before_shutdown` hook and `after_shutdown` hook - handlers respectively before and after calling in the lifespan - handlers. - """ - - for hook in self.app.before_shutdown: - await hook(self.app) - - for handler in self.app.on_shutdown: - await self._call_lifespan_handler(handler) - - for hook in self.app.after_shutdown: - await hook(self.app) diff --git a/starlite/asgi/__init__.py b/starlite/asgi/__init__.py new file mode 100644 index 0000000000..6fae011ccf --- /dev/null +++ b/starlite/asgi/__init__.py @@ -0,0 +1,3 @@ +from starlite.asgi.asgi_router import ASGIRouter + +__all__ = ("ASGIRouter",) diff --git a/starlite/asgi/asgi_router.py b/starlite/asgi/asgi_router.py new file mode 100644 index 0000000000..75baa725a3 --- /dev/null +++ b/starlite/asgi/asgi_router.py @@ -0,0 +1,199 @@ +from collections import defaultdict +from traceback import format_exc +from typing import TYPE_CHECKING, Dict, List, Set, Union + +from starlite.asgi.routing_trie import validate_node +from starlite.asgi.routing_trie.mapping import add_map_route_to_trie +from starlite.asgi.routing_trie.traversal import parse_scope_to_route +from starlite.asgi.routing_trie.utils import create_node +from starlite.asgi.utils import get_route_handlers +from starlite.exceptions import ImproperlyConfiguredException +from starlite.utils import AsyncCallable + +if TYPE_CHECKING: + from starlite.app import Starlite + from starlite.asgi.routing_trie.types import RouteTrieNode + from starlite.routes import ASGIRoute, HTTPRoute, WebSocketRoute + from starlite.routes.base import BaseRoute + from starlite.types import ( + LifeSpanHandler, + LifeSpanReceive, + LifeSpanSend, + LifeSpanShutdownCompleteEvent, + LifeSpanShutdownFailedEvent, + LifeSpanStartupCompleteEvent, + LifeSpanStartupFailedEvent, + Receive, + RouteHandlerType, + Scope, + Send, + ) + + +class ASGIRouter: + __slots__ = ( + "_plain_routes", + "_registered_routes", + "app", + "root_route_map_node", + "route_handler_index", + "route_mapping", + ) + + def __init__( + self, + app: "Starlite", + ) -> None: + """This class is the Starlite ASGI router. It handles both the ASGI + lifespan event and routing connection requests. + + Args: + app: The Starlite app instance + """ + self._plain_routes: Set[str] = set() + self._registered_routes: Set[Union["HTTPRoute", "WebSocketRoute", "ASGIRoute"]] = set() + self.app = app + self.root_route_map_node: "RouteTrieNode" = create_node() + self.route_handler_index: Dict[str, "RouteHandlerType"] = {} + self.route_mapping: Dict[str, List["BaseRoute"]] = defaultdict(list) + + async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: + """The main entry point to the Router class.""" + asgi_app, handler = parse_scope_to_route( + root_node=self.root_route_map_node, scope=scope, plain_routes=self._plain_routes + ) + scope["route_handler"] = handler + await asgi_app(scope, receive, send) + + def _store_handler_to_route_mapping(self, route: "BaseRoute") -> None: + """Stores the mapping of route handlers to routes and to route handler + names. + + Args: + route: A Route instance. + + Returns: + None + """ + + for handler in get_route_handlers(route): + if handler.name in self.route_handler_index and str(self.route_handler_index[handler.name]) != str(handler): + raise ImproperlyConfiguredException( + f"route handler names must be unique - {handler.name} is not unique." + ) + identifier = handler.name or str(handler) + self.route_mapping[identifier].append(route) + self.route_handler_index[identifier] = handler + + async def _call_lifespan_handler(self, handler: "LifeSpanHandler") -> None: + """Determines whether the lifecycle handler expects an argument, and if + so passes the `app.state` to it. If the handler is an async function, + it awaits the return. + + Args: + handler (LifeSpanHandler): sync or async callable that may or may not have an argument. + """ + async_callable = AsyncCallable(handler) # type: ignore + + if async_callable.num_expected_args > 0: + await async_callable(self.app.state) # type: ignore[arg-type] + else: + await async_callable() + + def construct_routing_trie(self) -> None: + """Create a map of the app's routes. + + This map is used in the asgi router to route requests. + """ + new_routes = [route for route in self.app.routes if route not in self._registered_routes] + for route in new_routes: + node = add_map_route_to_trie( + root_node=self.root_route_map_node, + route=route, + app=self.app, + plain_routes=self._plain_routes, + ) + + if node["path_parameters"] != route.path_parameters: + raise ImproperlyConfiguredException("Should not use routes with conflicting path parameters") + + self._store_handler_to_route_mapping(route) + self._registered_routes.add(route) + + validate_node(node=self.root_route_map_node) + + async def lifespan(self, receive: "LifeSpanReceive", send: "LifeSpanSend") -> None: + """Handles the ASGI "lifespan" event on application startup and + shutdown. + + Args: + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None. + """ + message = await receive() + try: + shutdown_event: "LifeSpanShutdownCompleteEvent" = {"type": "lifespan.shutdown.complete"} + + if message["type"] == "lifespan.startup": + await self.startup() + startup_event: "LifeSpanStartupCompleteEvent" = {"type": "lifespan.startup.complete"} + await send(startup_event) + await receive() + else: + await self.shutdown() + await send(shutdown_event) + except BaseException as e: + if message["type"] == "lifespan.startup": + startup_failure_event: "LifeSpanStartupFailedEvent" = { + "type": "lifespan.startup.failed", + "message": format_exc(), + } + await send(startup_failure_event) + else: + shutdown_failure_event: "LifeSpanShutdownFailedEvent" = { + "type": "lifespan.shutdown.failed", + "message": format_exc(), + } + await send(shutdown_failure_event) + raise e + else: + await self.shutdown() + await send(shutdown_event) + + async def startup(self) -> None: + """Run any [LifeSpanHandlers][starlite.types.LifeSpanHandler] defined + in the application's `.on_startup` list. + + Calls the `before_startup` hook and `after_startup` hook + handlers respectively before and after calling in the lifespan + handlers. + """ + for hook in self.app.before_startup: + await hook(self.app) + + for handler in self.app.on_startup: + await self._call_lifespan_handler(handler) + + for hook in self.app.after_startup: + await hook(self.app) + + async def shutdown(self) -> None: + """Run any [LifeSpanHandlers][starlite.types.LifeSpanHandler] defined + in the application's `.on_shutdown` list. + + Calls the `before_shutdown` hook and `after_shutdown` hook + handlers respectively before and after calling in the lifespan + handlers. + """ + + for hook in self.app.before_shutdown: + await hook(self.app) + + for handler in self.app.on_shutdown: + await self._call_lifespan_handler(handler) + + for hook in self.app.after_shutdown: + await hook(self.app) diff --git a/starlite/asgi/routing_trie/__init__.py b/starlite/asgi/routing_trie/__init__.py new file mode 100644 index 0000000000..4e5d3b90c7 --- /dev/null +++ b/starlite/asgi/routing_trie/__init__.py @@ -0,0 +1,6 @@ +from starlite.asgi.routing_trie.mapping import add_map_route_to_trie +from starlite.asgi.routing_trie.traversal import parse_scope_to_route +from starlite.asgi.routing_trie.types import RouteTrieNode +from starlite.asgi.routing_trie.validate import validate_node + +__all__ = ["RouteTrieNode", "add_map_route_to_trie", "parse_scope_to_route", "validate_node"] diff --git a/starlite/asgi/routing_trie/mapping.py b/starlite/asgi/routing_trie/mapping.py new file mode 100644 index 0000000000..e23906f5c6 --- /dev/null +++ b/starlite/asgi/routing_trie/mapping.py @@ -0,0 +1,156 @@ +from pathlib import Path +from typing import TYPE_CHECKING, Set, Type, Union, cast + +from starlette.middleware import Middleware as StarletteMiddleware + +from starlite.asgi.routing_trie.types import ASGIHandlerTuple, PathParameterSentinel +from starlite.asgi.routing_trie.utils import create_node +from starlite.asgi.utils import wrap_in_exception_handler +from starlite.types.internal_types import PathParameterDefinition + +if TYPE_CHECKING: + from starlite.app import Starlite + from starlite.asgi.routing_trie.types import RouteTrieNode + from starlite.routes import ASGIRoute, HTTPRoute, WebSocketRoute + from starlite.types import ASGIApp, RouteHandlerType + + +def add_map_route_to_trie( + app: "Starlite", + root_node: "RouteTrieNode", + route: Union["HTTPRoute", "WebSocketRoute", "ASGIRoute"], + plain_routes: Set[str], +) -> "RouteTrieNode": + """Adds a new route path (e.g. '/foo/bar/{param:int}') into the route_map + tree. + + Inserts non-parameter paths ('plain routes') off the tree's root + node. For paths containing parameters, splits the path on '/' and + nests each path segment under the previous segment's node (see + prefix tree / trie). + + Args: + app: The Starlite app instance. + root_node: The root trie node. + route: The route that is being added. + plain_routes: The set of plain routes. + + Returns: + A RouteTrieNode instance. + """ + current_node = root_node + path = route.path + + is_mount = hasattr(route, "route_handler") and getattr(route.route_handler, "is_mount", False) # type: ignore[union-attr] + + if not (route.path_parameters or is_mount): + plain_routes.add(path) + if path not in root_node["children"]: + current_node["children"][path] = create_node() + current_node = root_node["children"][path] + else: + for component in route.path_components: + if isinstance(component, PathParameterDefinition): + if component.type is Path: + current_node["is_path_type"] = True + break + + next_node_key: Union[Type[PathParameterSentinel], str] = PathParameterSentinel + + else: + next_node_key = component + + if next_node_key not in current_node["children"]: + current_node["children"][next_node_key] = create_node() + + current_node["child_keys"] = set(current_node["children"].keys()) + + current_node = current_node["children"][next_node_key] + + configure_node(route=route, app=app, node=current_node) + return current_node + + +def configure_node( + app: "Starlite", + route: Union["HTTPRoute", "WebSocketRoute", "ASGIRoute"], + node: "RouteTrieNode", +) -> None: + """Set required attributes and route handlers on route_map tree node. + + Args: + app: The Starlite app instance. + route: The route that is being added. + node: The trie node being configured. + + Returns: + None + """ + from starlite.routes import HTTPRoute, WebSocketRoute + + if not node["path_parameters"]: + node["path_parameters"] = route.path_parameters + + if isinstance(route, HTTPRoute): + for method, handler_mapping in route.route_handler_map.items(): + handler, _ = handler_mapping + node["asgi_handlers"][method] = ASGIHandlerTuple( + asgi_app=build_route_middleware_stack(app=app, route=route, route_handler=handler), + handler=handler, + ) + + elif isinstance(route, WebSocketRoute): + node["asgi_handlers"]["websocket"] = ASGIHandlerTuple( + asgi_app=build_route_middleware_stack(app=app, route=route, route_handler=route.route_handler), + handler=route.route_handler, + ) + + else: + node["asgi_handlers"]["asgi"] = ASGIHandlerTuple( + asgi_app=build_route_middleware_stack(app=app, route=route, route_handler=route.route_handler), + handler=route.route_handler, + ) + node["is_asgi"] = True + node["is_mount"] = route.route_handler.is_mount + node["is_static"] = route.route_handler.is_static + + +def build_route_middleware_stack( + app: "Starlite", + route: Union["HTTPRoute", "WebSocketRoute", "ASGIRoute"], + route_handler: "RouteHandlerType", +) -> "ASGIApp": + """Constructs a middleware stack that serves as the point of entry for each + route. + + Args: + app: The Starlite app instance. + route: The route that is being added. + route_handler: The route handler that is being wrapped. + + Returns: + An ASGIApp that is composed of a "stack" of middlewares. + """ + from starlite.middleware.csrf import CSRFMiddleware + + # we wrap the route.handle method in the ExceptionHandlerMiddleware + asgi_handler = wrap_in_exception_handler( + debug=app.debug, app=route.handle, exception_handlers=route_handler.resolve_exception_handlers() # type: ignore[arg-type] + ) + + if app.csrf_config: + asgi_handler = CSRFMiddleware(app=asgi_handler, config=app.csrf_config) + + for middleware in route_handler.resolve_middleware(): + if isinstance(middleware, StarletteMiddleware): + handler, kwargs = middleware + asgi_handler = handler(app=asgi_handler, **kwargs) + else: + asgi_handler = middleware(app=asgi_handler) # type: ignore + + # we wrap the entire stack again in ExceptionHandlerMiddleware + return wrap_in_exception_handler( + debug=app.debug, + app=cast("ASGIApp", asgi_handler), + exception_handlers=route_handler.resolve_exception_handlers(), + ) # pyright: ignore diff --git a/starlite/asgi/routing_trie/traversal.py b/starlite/asgi/routing_trie/traversal.py new file mode 100644 index 0000000000..bbced270a8 --- /dev/null +++ b/starlite/asgi/routing_trie/traversal.py @@ -0,0 +1,199 @@ +from collections import deque +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from pathlib import Path +from re import sub +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Deque, + Dict, + List, + Set, + Tuple, + Type, + Union, +) +from uuid import UUID + +from pydantic.datetime_parse import ( + parse_date, + parse_datetime, + parse_duration, + parse_time, +) + +from starlite.asgi.routing_trie.types import PathParameterSentinel +from starlite.enums import ScopeType +from starlite.exceptions import ( + MethodNotAllowedException, + NotFoundException, + ValidationException, +) +from starlite.utils import normalize_path + +if TYPE_CHECKING: + from starlite.asgi.routing_trie.types import ASGIHandlerTuple, RouteTrieNode + from starlite.types import Scope + from starlite.types.internal_types import PathParameterDefinition + +parsers_map: Dict[Any, Callable[[Any], Any]] = { + str: str, + float: float, + int: int, + Decimal: Decimal, + UUID: UUID, + Path: lambda x: Path(sub("//+", "", (x.lstrip("/")))), + date: parse_date, + datetime: parse_datetime, + time: parse_time, + timedelta: parse_duration, +} + + +def traverse_route_map( + current_node: "RouteTrieNode", + path: str, + path_components: Deque[Union[str, Type[PathParameterSentinel]]], + path_params: List[str], + scope: "Scope", +) -> Tuple["RouteTrieNode", List[str]]: + """Traverses the application route mapping and retrieves the correct node + for the request url. + + Args: + current_node: A trie node. + path: The request's path. + path_components: A list of ordered path components. + path_params: A list of extracted path parameters. + scope: The ASGI connection scope. + + Raises: + NotFoundException: if no correlating node is found. + + Returns: + A tuple containing the target RouteMapNode and a list containing all path parameter values. + """ + + if current_node["is_mount"]: + if current_node["is_static"] and not (path_components and path_components[0] in current_node["child_keys"]): + # static paths require an ending slash. + scope["path"] = normalize_path("/".join(path_components) + "/") # type: ignore[arg-type] + return current_node, path_params + if not current_node["is_static"]: + scope["path"] = normalize_path("/".join(path_components)) # type: ignore[arg-type] + return current_node, path_params + + if current_node["is_path_type"]: + path_params.append(normalize_path("/".join(path_components))) # type: ignore[arg-type] + return current_node, path_params + + has_path_param = PathParameterSentinel in current_node["child_keys"] + + if not path_components: + if has_path_param or not current_node["asgi_handlers"]: + raise NotFoundException() + return current_node, path_params + + component = path_components.popleft() + + if component in current_node["child_keys"]: + return traverse_route_map( + current_node=current_node["children"][component], + path=path, + path_components=path_components, + path_params=path_params, + scope=scope, + ) + + if has_path_param: + path_params.append(component) # type: ignore[arg-type] + + return traverse_route_map( + current_node=current_node["children"][PathParameterSentinel], + path=path, + path_components=path_components, + path_params=path_params, + scope=scope, + ) + + raise NotFoundException() + + +def parse_path_parameters( + path_parameter_definitions: List["PathParameterDefinition"], request_path_parameter_values: List[str] +) -> Dict[str, Any]: + """Parses path parameters into their expected types. + + Args: + path_parameter_definitions: A list of [PathParameterDefinition][starlite.route.base.PathParameterDefinition] instances + request_path_parameter_values: A list of raw strings sent as path parameters as part of the request + + Raises: + ValidationException: if path parameter parsing fails + + Returns: + A dictionary mapping path parameter names to parsed values + """ + result: Dict[str, Any] = {} + + try: + for idx, parameter_definition in enumerate(path_parameter_definitions): + raw_param_value = request_path_parameter_values[idx] + parser = parsers_map[parameter_definition.type] + result[parameter_definition.name] = parser(raw_param_value) + return result + except (ValueError, TypeError, KeyError) as e: # pragma: no cover + raise ValidationException(f"unable to parse path parameters {','.join(request_path_parameter_values)}") from e + + +def parse_scope_to_route(root_node: "RouteTrieNode", scope: "Scope", plain_routes: Set[str]) -> "ASGIHandlerTuple": + """Given a scope object, retrieve the asgi_handlers and is_mount boolean + values from correct trie node. + + Args: + root_node: The root trie node. + scope: The ASGI scope instance. + plain_routes: The set of plain routes. + + Raises: + MethodNotAllowedException: if no matching method is found. + + Returns: + A tuple containing the stack of middlewares and the route handler that is wrapped by it. + """ + + path = scope["path"].strip().rstrip("/") or "/" + + if path in plain_routes: + current_node: "RouteTrieNode" = root_node["children"][path] + scope["path_params"] = {} + + else: + current_node, path_params = traverse_route_map( + current_node=root_node, + path=path, + path_components=deque([component for component in path.split("/") if component]), + path_params=[], + scope=scope, + ) + scope["path_params"] = ( + parse_path_parameters( + path_parameter_definitions=current_node["path_parameters"], + request_path_parameter_values=path_params, + ) + if path_params + else {} + ) + + try: + if current_node["is_asgi"]: + return current_node["asgi_handlers"]["asgi"] + + if scope["type"] == ScopeType.HTTP: + return current_node["asgi_handlers"][scope["method"]] + + return current_node["asgi_handlers"]["websocket"] + except KeyError as e: + raise MethodNotAllowedException() from e diff --git a/starlite/asgi/routing_trie/types.py b/starlite/asgi/routing_trie/types.py new file mode 100644 index 0000000000..4f655b5b7f --- /dev/null +++ b/starlite/asgi/routing_trie/types.py @@ -0,0 +1,61 @@ +from typing import TYPE_CHECKING, Dict, List, NamedTuple, Set, Type, Union + +from typing_extensions import TypedDict + +if TYPE_CHECKING: + from typing_extensions import Literal + + from starlite.types import ASGIApp, Method, RouteHandlerType + from starlite.types.internal_types import PathParameterDefinition + + +class PathParameterSentinel: + """Sentinel class designating a path parameter.""" + + +class ASGIHandlerTuple(NamedTuple): + """This class encapsulates a route handler node.""" + + asgi_app: "ASGIApp" + """An ASGI stack, composed of a handler function and layers of middleware that wrap it.""" + handler: "RouteHandlerType" + """The route handler instance.""" + + +class RouteTrieNode(TypedDict): + """This class represents a radix trie node.""" + + asgi_handlers: Dict[Union["Method", "Literal['websocket', 'asgi']"], "ASGIHandlerTuple"] + """ + A mapping of ASGI handlers stored on the node. + """ + child_keys: Set[Union[str, Type[PathParameterSentinel]]] + """ + A set containing the child keys, same as the children dictionary - but as a set, which offers faster lookup. + """ + children: Dict[Union[str, Type[PathParameterSentinel]], "RouteTrieNode"] # type: ignore[misc] + """ + A dictionary mapping path components or using the PathParameterSentinel class to child nodes. + """ + is_asgi: bool + """ + Designate the node as having an `@asgi` type handler. + """ + is_mount: bool + """ + Designates the node as a "mount" path, meaning that the handler function will be forwarded all sub paths. + """ + is_path_type: bool + """ + Designates the node as expecting a path parameter of type 'Path', + which means that any sub path under the node is considered to be a path parameter value rather than a url. + """ + is_static: bool + """ + Designates the node as a static path node, which means that any sub path under the node is considered to be + a file path in one of the static directories. + """ + path_parameters: List["PathParameterDefinition"] + """ + A list of tuples containing path parameter definitions. This is used for parsing extracted path parameter values. + """ diff --git a/starlite/asgi/routing_trie/utils.py b/starlite/asgi/routing_trie/utils.py new file mode 100644 index 0000000000..52677cf5c3 --- /dev/null +++ b/starlite/asgi/routing_trie/utils.py @@ -0,0 +1,23 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from starlite.asgi.routing_trie.types import RouteTrieNode + + +def create_node() -> "RouteTrieNode": + """Creates a RouteMapNode instance. + + Returns: + A route map node instance. + """ + + return { + "asgi_handlers": {}, + "child_keys": set(), + "children": {}, + "is_asgi": False, + "is_mount": False, + "is_static": False, + "is_path_type": False, + "path_parameters": [], + } diff --git a/starlite/asgi/routing_trie/validate.py b/starlite/asgi/routing_trie/validate.py new file mode 100644 index 0000000000..3617c93deb --- /dev/null +++ b/starlite/asgi/routing_trie/validate.py @@ -0,0 +1,29 @@ +from typing import TYPE_CHECKING + +from starlite.asgi.routing_trie.types import PathParameterSentinel +from starlite.exceptions import ImproperlyConfiguredException + +if TYPE_CHECKING: + from starlite.asgi.routing_trie.types import RouteTrieNode + + +def validate_node(node: "RouteTrieNode") -> None: + """Recursively traverses the trie from the given node upwards. + + Args: + node: A trie node. + + Raises: + ImproperlyConfiguredException + + Returns: + None + """ + if node["is_asgi"] and bool(set(node["asgi_handlers"]).difference({"asgi"})): + raise ImproperlyConfiguredException("ASGI handlers must have a unique path not shared by other route handlers.") + + if node["is_static"] and PathParameterSentinel in node["child_keys"]: + raise ImproperlyConfiguredException("Path parameters cannot be configured for a static path.") + + for child in node["children"].values(): + validate_node(node=child) diff --git a/starlite/asgi/utils.py b/starlite/asgi/utils.py new file mode 100644 index 0000000000..1b86da12b6 --- /dev/null +++ b/starlite/asgi/utils.py @@ -0,0 +1,42 @@ +from typing import TYPE_CHECKING, List, cast + +if TYPE_CHECKING: + from typing import Union + + from starlite.routes import ASGIRoute, HTTPRoute, WebSocketRoute + from starlite.routes.base import BaseRoute + from starlite.types import ASGIApp, ExceptionHandlersMap, RouteHandlerType + + +def wrap_in_exception_handler(debug: bool, app: "ASGIApp", exception_handlers: "ExceptionHandlersMap") -> "ASGIApp": + """Wraps the given ASGIApp in an instance of ExceptionHandlerMiddleware. + + Args: + debug: Dictates whether exceptions are raised in debug mode. + app: The ASGI app that is being wrapped. + exception_handlers: A mapping of exceptions to handler functions. + + Returns: + A wrapped ASGIApp. + """ + from starlite.middleware.exceptions import ExceptionHandlerMiddleware + + return ExceptionHandlerMiddleware(app=app, exception_handlers=exception_handlers, debug=debug) + + +def get_route_handlers(route: "BaseRoute") -> List["RouteHandlerType"]: + """Retrieve handler(s) as a list for given route. + + Args: + route: The route from which the route handlers are extracted. + + Returns: + The route handlers defined on the route. + """ + route_handlers: List["RouteHandlerType"] = [] + if hasattr(route, "route_handlers"): + route_handlers.extend(cast("HTTPRoute", route).route_handlers) + else: + route_handlers.append(cast("Union[WebSocketRoute, ASGIRoute]", route).route_handler) + + return route_handlers diff --git a/starlite/config/compression.py b/starlite/config/compression.py index 53a7e27359..05ff3b0881 100644 --- a/starlite/config/compression.py +++ b/starlite/config/compression.py @@ -83,14 +83,10 @@ def to_middleware(self, app: "ASGIApp") -> "ASGIApp": A middleware instance """ if self.backend == "gzip": - from starlite.middleware.compression.gzip import ( # pylint: disable=import-outside-toplevel - GZipMiddleware, - ) + from starlite.middleware.compression.gzip import GZipMiddleware return cast("ASGIApp", GZipMiddleware(app=app, **self.dict())) - from starlite.middleware.compression.brotli import ( # pylint: disable=import-outside-toplevel - BrotliMiddleware, - ) + from starlite.middleware.compression.brotli import BrotliMiddleware return BrotliMiddleware(app=app, **self.dict()) diff --git a/starlite/config/logging.py b/starlite/config/logging.py index 8c3fc09c2c..d9795486db 100644 --- a/starlite/config/logging.py +++ b/starlite/config/logging.py @@ -180,26 +180,22 @@ def configure(self) -> "GetLogger": Returns: A 'logging.getLogger' like function. """ - try: - if "picologging" in str(dumps(self.handlers)): - - from picologging import ( # pylint: disable=import-outside-toplevel - config, - getLogger, - ) - values = self.dict(exclude_none=True, exclude={"incremental"}) + if "picologging" in str(dumps(self.handlers)): + try: + from picologging import config, getLogger + except ImportError as e: # pragma: no cover + raise MissingDependencyException("picologging is not installed") from e else: - from logging import ( # type: ignore[no-redef] # pylint: disable=import-outside-toplevel - config, - getLogger, - ) - - values = self.dict(exclude_none=True) - config.dictConfig(values) - return cast("Callable[[str], Logger]", getLogger) - except ImportError as e: # pragma: no cover - raise MissingDependencyException("picologging is not installed") from e + values = self.dict(exclude_none=True, exclude={"incremental"}) + + else: + from logging import config, getLogger # type: ignore[no-redef] + + values = self.dict(exclude_none=True) + + config.dictConfig(values) + return cast("Callable[[str], Logger]", getLogger) def default_structlog_processors() -> Optional[Iterable[Processor]]: # pyright: ignore @@ -209,7 +205,7 @@ def default_structlog_processors() -> Optional[Iterable[Processor]]: # pyright: An optional list of processors. """ try: - import structlog # pylint: disable=import-outside-toplevel + import structlog return [ structlog.contextvars.merge_contextvars, @@ -230,7 +226,7 @@ def default_wrapper_class() -> Optional[Type[BindableLogger]]: # pyright: ignor """ try: - import structlog # pylint: disable=import-outside-toplevel + import structlog return structlog.make_filtering_bound_logger(INFO) except ImportError: # pragma: no cover @@ -244,7 +240,7 @@ def default_logger_factory() -> Optional[Callable[..., WrappedLogger]]: An optional logger factory. """ try: - import structlog # pylint: disable=import-outside-toplevel + import structlog return structlog.BytesLoggerFactory() except ImportError: # pragma: no cover @@ -276,10 +272,7 @@ def configure(self) -> "GetLogger": A 'logging.getLogger' like function. """ try: - from structlog import ( # pylint: disable=import-outside-toplevel - configure, - get_logger, - ) + from structlog import configure, get_logger # we now configure structlog configure(**self.dict(exclude={"standard_lib_logging_config"})) diff --git a/starlite/config/static_files.py b/starlite/config/static_files.py index f4662c045a..1317456fbb 100644 --- a/starlite/config/static_files.py +++ b/starlite/config/static_files.py @@ -1,12 +1,13 @@ -from typing import TYPE_CHECKING, List, Optional, cast +from typing import TYPE_CHECKING, List, Optional from pydantic import BaseModel, DirectoryPath, constr, validator from starlette.staticfiles import StaticFiles +from starlite.handlers import asgi from starlite.utils import normalize_path if TYPE_CHECKING: - from starlite.types import ASGIApp + from starlite.handlers import ASGIRouteHandler class StaticFilesConfig(BaseModel): @@ -49,7 +50,7 @@ def validate_path(cls, value: str) -> str: # pylint: disable=no-self-argument raise ValueError("path parameters are not supported for static files") return normalize_path(value) - def to_static_files_app(self) -> "ASGIApp": + def to_static_files_app(self) -> "ASGIRouteHandler": """Returns an ASGI app serving static files based on the config. Returns: @@ -61,4 +62,4 @@ def to_static_files_app(self) -> "ASGIApp": directory=str(self.directories[0]), ) static_files.all_directories = self.directories # type: ignore[assignment] - return cast("ASGIApp", static_files) + return asgi(path=self.path, name=self.name, is_static=True)(static_files) diff --git a/starlite/datastructures/provide.py b/starlite/datastructures/provide.py index fc536fafbf..9102e3b5d0 100644 --- a/starlite/datastructures/provide.py +++ b/starlite/datastructures/provide.py @@ -40,7 +40,7 @@ async def __call__(self, **kwargs: Any) -> Any: if self.use_cache: self.value = value - return value # noqa: R504 + return value def __eq__(self, other: Any) -> bool: # check if memory address is identical, otherwise compare attributes diff --git a/starlite/handlers/asgi.py b/starlite/handlers/asgi.py index 821ede168d..9f2eeb60bb 100644 --- a/starlite/handlers/asgi.py +++ b/starlite/handlers/asgi.py @@ -13,6 +13,8 @@ class ASGIRouteHandler(BaseRouteHandler["ASGIRouteHandler"]): + __slots__ = ("is_mount", "is_static") + @validate_arguments(config={"arbitrary_types_allowed": True}) def __init__( self, @@ -22,6 +24,8 @@ def __init__( guards: Optional[List[Guard]] = None, name: Optional[str] = None, opt: Optional[Dict[str, Any]] = None, + is_mount: bool = False, + is_static: bool = False, **kwargs: Any, ) -> None: """ASGI Route Handler decorator. Use this decorator to decorate ASGI @@ -33,8 +37,15 @@ def __init__( name: A string identifying the route handler. opt: A string key dictionary of arbitrary values that can be accessed [Guards][starlite.types.Guard]. path: A path fragment for the route handler function or a list of path fragments. If not given defaults to '/' + is_mount: A boolean dictating whether the handler's paths should be regarded as mount paths. Mount path accept + any arbitrary paths that begin with the defined prefixed path. For example, a mount with the path `/some-path/` + will accept requests for `/some-path/` and any sub path under this, e.g. `/some-path/sub-path/` etc. + is_static: A boolean dictating whether the handler's paths should be regarded as static paths. Static paths + are used to deliver static files. **kwargs: Any additional kwarg - will be set in the opt dictionary. """ + self.is_mount = is_mount or is_static + self.is_static = is_static super().__init__(path, exception_handlers=exception_handlers, guards=guards, name=name, opt=opt, **kwargs) def __call__(self, fn: "AnyCallable") -> "ASGIRouteHandler": @@ -53,10 +64,12 @@ def _validate_handler_function(self) -> None: if signature.return_annotation is not None: raise ImproperlyConfiguredException("ASGI handler functions should return 'None'") + if any(key not in signature.parameters for key in ("scope", "send", "receive")): raise ImproperlyConfiguredException( "ASGI handler functions should define 'scope', 'send' and 'receive' arguments" ) + if not is_async_callable(fn): raise ImproperlyConfiguredException("Functions decorated with 'asgi' must be async functions") diff --git a/starlite/logging/picologging.py b/starlite/logging/picologging.py index 5abab74f5a..4d4efb648d 100644 --- a/starlite/logging/picologging.py +++ b/starlite/logging/picologging.py @@ -12,7 +12,7 @@ raise MissingDependencyException("picologging is not installed") from e -class QueueListenerHandler(QueueHandler): # type: ignore[misc] +class QueueListenerHandler(QueueHandler): def __init__(self, handlers: Optional[List[Any]] = None) -> None: """Configures queue listener and handler to support non-blocking logging configuration. diff --git a/starlite/openapi/parameters.py b/starlite/openapi/parameters.py index f2d617f967..f0a6268e91 100644 --- a/starlite/openapi/parameters.py +++ b/starlite/openapi/parameters.py @@ -20,8 +20,8 @@ from pydantic_openapi_schema.v3_1_0.schema import Schema from starlite.handlers import BaseRouteHandler - from starlite.routes.base import PathParameterDefinition from starlite.types import Dependencies + from starlite.types.internal_types import PathParameterDefinition def create_path_parameter_schema( @@ -29,7 +29,7 @@ def create_path_parameter_schema( ) -> "Schema": """Create a path parameter from the given path_param definition.""" field.sub_fields = None - field.outer_type_ = path_parameter["type"] + field.outer_type_ = path_parameter.type return create_schema(field=field, generate_examples=generate_examples, plugins=[]) @@ -83,10 +83,10 @@ def create_parameter( is_required = cast("bool", model_field.required) if model_field.required is not Undefined else False extra = model_field.field_info.extra - if any(path_param["name"] == parameter_name for path_param in path_parameters): + if any(path_param.name == parameter_name for path_param in path_parameters): param_in = ParamType.PATH is_required = True - path_parameter = [p for p in path_parameters if parameter_name in p["name"]][0] + path_parameter = [p for p in path_parameters if parameter_name in p.name][0] schema = create_path_parameter_schema( path_parameter=path_parameter, field=model_field, diff --git a/starlite/router.py b/starlite/router.py index 50a5c66e9b..3ab3819079 100644 --- a/starlite/router.py +++ b/starlite/router.py @@ -138,7 +138,7 @@ def __init__( self.response_class = response_class self.response_cookies = response_cookies or [] self.response_headers = response_headers or {} - self.routes: List["BaseRoute"] = [] + self.routes: List[Union["HTTPRoute", "ASGIRoute", "WebSocketRoute"]] = [] self.security = security or [] self.tags = tags or [] @@ -161,16 +161,24 @@ def register(self, value: ControllerRouterHandler) -> List["BaseRoute"]: for route_path, handler_or_method_map in self._map_route_handlers(value=validated_value): path = join_paths([self.path, route_path]) if isinstance(handler_or_method_map, WebsocketRouteHandler): - route: "BaseRoute" = WebSocketRoute(path=path, route_handler=handler_or_method_map) + route: Union["WebSocketRoute", "ASGIRoute", "HTTPRoute"] = WebSocketRoute( + path=path, route_handler=handler_or_method_map + ) self.routes.append(route) elif isinstance(handler_or_method_map, ASGIRouteHandler): route = ASGIRoute(path=path, route_handler=handler_or_method_map) self.routes.append(route) else: - existing_handlers: List[HTTPRouteHandler] = list(self.route_handler_method_map.get(path, {}).values()) # type: ignore - route_handlers = unique(list(handler_or_method_map.values())) + existing_handlers = self.route_handler_method_map.get(path, {}) + + if not isinstance(existing_handlers, dict): + raise ImproperlyConfiguredException( + "Cannot have both HTTP routes and websocket / asgi route handlers on the same path" + ) + + route_handlers = unique(handler_or_method_map.values()) if existing_handlers: - route_handlers.extend(unique(existing_handlers)) + route_handlers.extend(unique(existing_handlers.values())) existing_route_index = find_index( self.routes, lambda x: x.path == path # pylint: disable=cell-var-from-loop # noqa: B023 ) diff --git a/starlite/routes/base.py b/starlite/routes/base.py index 389902fc83..5d33315d59 100644 --- a/starlite/routes/base.py +++ b/starlite/routes/base.py @@ -3,7 +3,7 @@ from datetime import date, datetime, time, timedelta from decimal import Decimal from pathlib import Path -from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union from uuid import UUID from typing_extensions import TypedDict @@ -11,6 +11,7 @@ from starlite.exceptions import ImproperlyConfiguredException from starlite.kwargs import KwargsModel from starlite.signature import get_signature_model +from starlite.types.internal_types import PathParameterDefinition from starlite.utils import join_paths, normalize_path if TYPE_CHECKING: @@ -19,6 +20,7 @@ from starlite.types import Method, Receive, Scope, Send param_match_regex = re.compile(r"{(.*?)}") + param_type_map = { "str": str, "int": int, @@ -33,12 +35,6 @@ } -class PathParameterDefinition(TypedDict): - name: str - full: str - type: Type - - class RouteHandlerIndex(TypedDict): name: str handler: "BaseRouteHandler" @@ -49,7 +45,6 @@ class BaseRoute(ABC): "app", "handler_names", "methods", - "param_convertors", "path", "path_format", "path_parameters", @@ -76,7 +71,7 @@ class meant to be extended. """ self.path, self.path_format, self.path_components = self._parse_path(path) self.path_parameters: List[PathParameterDefinition] = [ - component for component in self.path_components if isinstance(component, dict) + component for component in self.path_components if isinstance(component, PathParameterDefinition) ] self.handler_names = handler_names self.scope_type = scope_type @@ -105,10 +100,9 @@ def create_handler_kwargs_model(self, route_handler: "BaseRouteHandler") -> Kwar path_parameters = set() for param in self.path_parameters: - param_name = param["name"] - if param_name in path_parameters: - raise ImproperlyConfiguredException(f"Duplicate parameter '{param_name}' detected in '{self.path}'.") - path_parameters.add(param_name) + if param.name in path_parameters: + raise ImproperlyConfiguredException(f"Duplicate parameter '{param.name}' detected in '{self.path}'.") + path_parameters.add(param.name) return KwargsModel.create_for_signature_model( signature_model=signature_model, diff --git a/starlite/testing/test_client/transport.py b/starlite/testing/test_client/transport.py index a3950457d2..4378c7bd98 100644 --- a/starlite/testing/test_client/transport.py +++ b/starlite/testing/test_client/transport.py @@ -140,8 +140,8 @@ def handle_request(self, request: "Request") -> "Response": ) session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope)) raise ConnectionUpgradeException(session) - else: - scope.update(method=request.method, http_version="1.1", extensions={"http.response.template": {}}) + + scope.update(method=request.method, http_version="1.1", extensions={"http.response.template": {}}) raw_kwargs: Dict[str, Any] = {"stream": BytesIO()} diff --git a/starlite/types/internal_types.py b/starlite/types/internal_types.py index 65f1ce4556..f060c9ec3f 100644 --- a/starlite/types/internal_types.py +++ b/starlite/types/internal_types.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Callable, Dict, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, NamedTuple, Type, Union from typing_extensions import Literal @@ -27,3 +27,9 @@ ResponseType = Type[Response] ControllerRouterHandler = Union[Type[Controller], RouteHandlerType, Router, Callable[..., Any]] RouteHandlerMapItem = Union[WebsocketRouteHandler, ASGIRouteHandler, Dict[Method, HTTPRouteHandler]] + + +class PathParameterDefinition(NamedTuple): + name: str + full: str + type: Type diff --git a/starlite/utils/exception.py b/starlite/utils/exception.py index 34183e2503..bd4a1cd797 100644 --- a/starlite/utils/exception.py +++ b/starlite/utils/exception.py @@ -61,9 +61,7 @@ def to_response(self) -> "Response": Returns: A response instance. """ - from starlite.response import ( # pylint: disable=import-outside-toplevel - Response, - ) + from starlite.response import Response return Response( content=self.dict(exclude_none=True, exclude={"headers"}), diff --git a/starlite/utils/extractors.py b/starlite/utils/extractors.py index eb5fc32187..70c6e028f3 100644 --- a/starlite/utils/extractors.py +++ b/starlite/utils/extractors.py @@ -13,7 +13,7 @@ from typing_extensions import Literal, TypedDict -from starlite.connection import Request +from starlite.connection.request import Request from starlite.datastructures.upload_file import UploadFile from starlite.enums import HttpMethod, RequestEncodingType from starlite.parsers import parse_cookie_string diff --git a/tests/conftest.py b/tests/conftest.py index 232985d919..bdf4b4d198 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,9 @@ -import os -import pathlib -import sys +from os import environ, urandom +from pathlib import Path +from sys import version_info from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Generator, Union, cast -import fakeredis.aioredis # type: ignore -import py # type: ignore +import fakeredis.aioredis # pyright: ignore import pytest from piccolo.conf.apps import Finder from piccolo.table import create_db_tables, drop_db_tables @@ -51,11 +50,11 @@ def pytest_generate_tests(metafunc: Callable) -> None: """Sets ENV variables for testing.""" - os.environ.update(PICCOLO_CONF="tests.piccolo_conf") + environ.update(PICCOLO_CONF="tests.piccolo_conf") @pytest.fixture() -def template_dir(tmp_path: pathlib.Path) -> pathlib.Path: +def template_dir(tmp_path: Path) -> Path: return tmp_path @@ -105,7 +104,7 @@ async def mock_asgi_app(scope: "Scope", receive: "Receive", send: "Send") -> Non @pytest.fixture def cookie_session_backend_config() -> CookieBackendConfig: - return CookieBackendConfig(secret=SecretBytes(os.urandom(16))) + return CookieBackendConfig(secret=SecretBytes(urandom(16))) @pytest.fixture() @@ -119,7 +118,7 @@ def memory_session_backend_config() -> MemoryBackendConfig: @pytest.fixture -def file_session_backend_config(tmpdir: py.path.local) -> FileBackendConfig: +def file_session_backend_config(tmpdir: Path) -> FileBackendConfig: return FileBackendConfig(storage_path=tmpdir) @@ -223,7 +222,7 @@ def session_backend_config(request: pytest.FixtureRequest) -> Union[ServerSideSe def session_backend_config_async_safe( request: pytest.FixtureRequest, ) -> Union[ServerSideSessionConfig, CookieBackendConfig]: - if sys.version_info < (3, 10) and request.param == "redis_session_backend_config": + if version_info < (3, 10) and request.param == "redis_session_backend_config": return pytest.skip("") return cast("Union[ServerSideSessionConfig, CookieBackendConfig]", request.getfixturevalue(request.param)) diff --git a/tests/kwargs/test_path_params.py b/tests/kwargs/test_path_params.py index ad586ab7a2..54c1eec517 100644 --- a/tests/kwargs/test_path_params.py +++ b/tests/kwargs/test_path_params.py @@ -131,6 +131,7 @@ def test_method() -> None: ["datetime", datetime, datetime.now().isoformat()], ["timedelta", timedelta, timedelta(days=1).total_seconds()], ["path", Path, "/1/2/3/4/some-file.txt"], + ["path", Path, "1/2/3/4/some-file.txt"], ], ) def test_path_param_type_resolution(param_type_name: str, param_type_class: Any, value: Any) -> None: diff --git a/tests/middleware/test_exception_handler_middleware.py b/tests/middleware/test_exception_handler_middleware.py index 2f422f0890..77a01b8724 100644 --- a/tests/middleware/test_exception_handler_middleware.py +++ b/tests/middleware/test_exception_handler_middleware.py @@ -97,7 +97,9 @@ def exception_handler(request: Request, exc: Exception) -> Response: return Response(content={"an": "error"}, status_code=HTTP_500_INTERNAL_SERVER_ERROR, media_type=MediaType.JSON) app = Starlite(route_handlers=[handler], exception_handlers={Exception: exception_handler}, openapi_config=None) - assert app.route_map["/"]["_asgi_handlers"]["GET"]["asgi_app"].exception_handlers == {Exception: exception_handler} + assert app.asgi_router.root_route_map_node["children"]["/"]["asgi_handlers"]["GET"][0].exception_handlers == { + Exception: exception_handler + } def test_exception_handler_middleware_calls_app_level_after_exception_hook() -> None: diff --git a/tests/middleware/test_middleware_handling.py b/tests/middleware/test_middleware_handling.py index 2205434c95..eae8f071ed 100644 --- a/tests/middleware/test_middleware_handling.py +++ b/tests/middleware/test_middleware_handling.py @@ -84,7 +84,7 @@ def test_custom_middleware_processing(middleware: Any) -> None: assert app.middleware == [middleware] unpacked_middleware = [] - cur = client.app.route_map["/"]["_asgi_handlers"]["GET"]["asgi_app"] + cur = client.app.asgi_router.root_route_map_node["children"]["/"]["asgi_handlers"]["GET"][0] while hasattr(cur, "app"): unpacked_middleware.append(cur) cur = cast("ASGIApp", cur.app) diff --git a/tests/openapi/test_request_body.py b/tests/openapi/test_request_body.py index 032f32c3c3..19f618dc53 100644 --- a/tests/openapi/test_request_body.py +++ b/tests/openapi/test_request_body.py @@ -19,7 +19,7 @@ class Config(BaseConfig): def test_create_request_body() -> None: for route in Starlite(route_handlers=[PersonController]).routes: for route_handler, _ in route.route_handler_map.values(): # type: ignore - handler_fields = route_handler.signature_model.__fields__ + handler_fields = route_handler.signature_model.__fields__ # type: ignore if "data" in handler_fields: request_body = create_request_body(field=handler_fields["data"], generate_examples=True, plugins=[]) assert request_body diff --git a/tests/routing/test_path_mounting.py b/tests/routing/test_path_mounting.py new file mode 100644 index 0000000000..ffbe1fa43e --- /dev/null +++ b/tests/routing/test_path_mounting.py @@ -0,0 +1,80 @@ +from typing import TYPE_CHECKING + +import pytest + +from starlite import ( + ImproperlyConfiguredException, + MediaType, + Response, + Starlite, + asgi, + get, + websocket, +) +from starlite.status_codes import HTTP_200_OK +from starlite.testing import create_test_client + +if TYPE_CHECKING: + from starlite.connection import WebSocket + from starlite.types import Receive, Scope, Send + + +def test_supports_mounting() -> None: + @asgi("/base/sub/path", is_mount=True) + async def asgi_handler(scope: "Scope", receive: "Receive", send: "Send") -> None: + response = Response(scope["path"], media_type=MediaType.TEXT, status_code=HTTP_200_OK) + await response(scope, receive, send) + + with create_test_client(asgi_handler) as client: + response = client.get("/base/sub/path") + assert response.status_code == HTTP_200_OK + assert response.text == "/" + + response = client.get("/base/sub/path/abcd") + assert response.status_code == HTTP_200_OK + assert response.text == "/abcd" + + response = client.get("/base/sub/path/abcd/complex/123/terminus") + assert response.status_code == HTTP_200_OK + assert response.text == "/abcd/complex/123/terminus" + + +def test_supports_sub_routes_below_asgi_handlers() -> None: + @asgi("/base/sub/path") + async def asgi_handler(scope: "Scope", receive: "Receive", send: "Send") -> None: + response = Response(scope["path"], media_type=MediaType.TEXT, status_code=HTTP_200_OK) + await response(scope, receive, send) + + @get("/base/sub/path/abc") + def regular_handler() -> None: + return + + assert Starlite(route_handlers=[asgi_handler, regular_handler]) + + +def test_does_not_support_asgi_handlers_on_same_level_as_regular_handlers() -> None: + @asgi("/base/sub/path") + async def asgi_handler(scope: "Scope", receive: "Receive", send: "Send") -> None: + response = Response(scope["path"], media_type=MediaType.TEXT, status_code=HTTP_200_OK) + await response(scope, receive, send) + + @get("/base/sub/path") + def regular_handler() -> None: + return + + with pytest.raises(ImproperlyConfiguredException): + Starlite(route_handlers=[asgi_handler, regular_handler]) + + +def test_does_not_support_asgi_handlers_on_same_level_as_websockets() -> None: + @asgi("/base/sub/path") + async def asgi_handler(scope: "Scope", receive: "Receive", send: "Send") -> None: + response = Response(scope["path"], media_type=MediaType.TEXT, status_code=HTTP_200_OK) + await response(scope, receive, send) + + @websocket("/base/sub/path") + async def regular_handler(socket: "WebSocket") -> None: + return + + with pytest.raises(ImproperlyConfiguredException): + Starlite(route_handlers=[asgi_handler, regular_handler]) diff --git a/tests/routing/test_path_resolution.py b/tests/routing/test_path_resolution.py index 420c1a4b05..008100ddb6 100644 --- a/tests/routing/test_path_resolution.py +++ b/tests/routing/test_path_resolution.py @@ -111,7 +111,7 @@ def handler_fn(some_id: int = 1) -> str: assert some_id return str(some_id) - with create_test_client(handler_fn) as client: + with create_test_client(handler_fn, openapi_config=None) as client: first_response = client.get("/") assert first_response.status_code == HTTP_200_OK assert first_response.text == "1" @@ -167,7 +167,7 @@ def test_path_order() -> None: def handler_fn(some_id: int = 1) -> str: return str(some_id) - with create_test_client(handler_fn) as client: + with create_test_client(handler_fn, openapi_config=None) as client: first_response = client.get("/something/5") assert first_response.status_code == HTTP_200_OK assert first_response.text == "5" diff --git a/tests/routing/test_route_map.py b/tests/routing/test_route_map.py deleted file mode 100644 index d4e1d16a8e..0000000000 --- a/tests/routing/test_route_map.py +++ /dev/null @@ -1,155 +0,0 @@ -import re -from random import shuffle -from string import ascii_letters -from typing import List, Set, Tuple, Union, cast - -from hypothesis import given, settings -from hypothesis import strategies as st -from hypothesis.strategies import DrawFn - -from starlite import HTTPRoute, get -from starlite.asgi import PathParamNode, PathParamPlaceholderType, RouteMapNode -from starlite.middleware.exceptions import ExceptionHandlerMiddleware -from starlite.testing import create_test_client - -param_pattern = re.compile(r"{.*?:int}") - -RouteMapTestCase = Tuple[str, str, Set[str]] - - -def is_path_in_route_map(route_map: RouteMapNode, path: str, path_params: Set[str]) -> bool: - if not path_params: - return path in route_map - components = cast( - "List[Union[str, PathParamPlaceholderType]]", - [ - "/", - *[ - PathParamNode if param_pattern.fullmatch(component) else component - for component in path.split("/") - if component - ], - ], - ) - cur_node = route_map - for component in components: - if component not in cur_node: - return False - cur_node = cur_node[component] - route_params = {param["full"] for param in cur_node.get("_path_parameters", [])} - if path_params == route_params: - return True - return False - - -@st.composite -def route_test_paths(draw: DrawFn) -> List[RouteMapTestCase]: - def build_record(components: List[str], params: Set[str]) -> RouteMapTestCase: - segments = components + [f"{{{p}:int}}" for p in params] - shuffle(segments) - router_path = "/" + "/".join(segments) - request_path = param_pattern.sub("1", router_path) - return router_path, request_path, {f"{p}:int" for p in params} - - parameter_names = ["a", "b", "c", "d", "e"] - param_st = st.sets(st.sampled_from(parameter_names), max_size=3) - components_st = st.lists(st.text(alphabet=ascii_letters, min_size=1, max_size=4), min_size=1, max_size=3) - path_st = st.builds(build_record, components_st, param_st) - return cast( - "List[RouteMapTestCase]", draw(st.lists(path_st, min_size=10, max_size=10, unique_by=lambda record: record[1])) - ) - - -def test_route_map_starts_empty() -> None: - @get(path=[]) - def handler_fn() -> None: - ... - - client = create_test_client(handler_fn) - route_map = client.app.route_map - assert route_map["_components"] == set() - assert list(route_map.keys()) == ["_components", "/"] - - -@given(test_paths=route_test_paths()) -@settings( - max_examples=5, - deadline=None, -) -def test_add_route_map_path(test_paths: List[RouteMapTestCase]) -> None: - @get(path=[]) - def handler_fn(a: int = 0, b: int = 0, c: int = 0, d: int = 0, e: int = 0) -> None: - ... - - client = create_test_client(handler_fn) - app = client.app - route_map = app.route_map - for router_path, _, path_params in test_paths: - assert is_path_in_route_map(route_map, router_path, path_params) is False - route = HTTPRoute( - path=router_path, - route_handlers=[get(path=router_path)(handler_fn)], - ) - app._add_node_to_route_map(route) - assert is_path_in_route_map(route_map, router_path, path_params) is True - - -@given(test_paths=route_test_paths()) -@settings( - max_examples=5, - deadline=None, -) -def test_handler_paths_added(test_paths: List[RouteMapTestCase]) -> None: - @get(path=[router_path for router_path, _, _ in test_paths]) - def handler_fn(a: int = 0, b: int = 0, c: int = 0, d: int = 0, e: int = 0) -> None: - ... - - client = create_test_client(handler_fn) - route_map = client.app.route_map - for router_path, _, path_params in test_paths: - assert is_path_in_route_map(route_map, router_path, path_params) is True - - -@given(test_paths=route_test_paths()) -@settings( - max_examples=5, - deadline=None, -) -def test_find_existing_asgi_handlers(test_paths: List[RouteMapTestCase]) -> None: - def handler_fn(a: int = 0, b: int = 0) -> None: - ... - - client = create_test_client(get(path=[router_path for router_path, _, _ in test_paths])(handler_fn)) - app = client.app - router = app.asgi_router - for router_path, request_path, _ in test_paths: - route = HTTPRoute( - path=router_path, - route_handlers=[get(path=router_path)(handler_fn)], - ) - app._add_node_to_route_map(route) - asgi_handlers, is_asgi = router._parse_scope_to_route({"path": request_path}) # type: ignore[arg-type] - assert "GET" in asgi_handlers - assert isinstance(asgi_handlers["GET"]["asgi_app"], ExceptionHandlerMiddleware) - assert is_asgi is False - - -@given(test_paths=route_test_paths()) -@settings( - max_examples=5, - deadline=None, -) -def test_missing_asgi_handlers(test_paths: List[RouteMapTestCase]) -> None: - def handler_fn(a: int = 0, b: int = 0) -> None: - ... - - client = create_test_client(get(path=[])(handler_fn)) - app = client.app - router = app.asgi_router - for router_path, request_path, _ in test_paths: - route = HTTPRoute( - path=router_path, - route_handlers=[get(path=router_path)(handler_fn)], - ) - app._add_node_to_route_map(route) - assert router._parse_scope_to_route({"path": request_path}) == ({}, False) # type: ignore[arg-type] diff --git a/tests/static_files/test_static_files.py b/tests/routing/test_static_files.py similarity index 76% rename from tests/static_files/test_static_files.py rename to tests/routing/test_static_files.py index 4f07ec7a3d..74b17ebdda 100644 --- a/tests/static_files/test_static_files.py +++ b/tests/routing/test_static_files.py @@ -3,18 +3,20 @@ import pytest from pydantic import ValidationError -from starlite import ImproperlyConfiguredException, Starlite, get +from starlite import ImproperlyConfiguredException, MediaType, Starlite, get from starlite.config import StaticFilesConfig +from starlite.status_codes import HTTP_200_OK from starlite.testing import create_test_client if TYPE_CHECKING: from pathlib import Path -def test_staticfiles(tmpdir: "Path") -> None: +def test_staticfiles_standard_config(tmpdir: "Path") -> None: path = tmpdir / "test.txt" path.write_text("content", "utf-8") static_files_config = StaticFilesConfig(path="/static", directories=[tmpdir]) + with create_test_client([], static_files_config=static_files_config) as client: response = client.get("/static/test.txt") assert response.status_code == 200 @@ -53,21 +55,34 @@ def test_config_validation(tmpdir: "Path") -> None: StaticFilesConfig(path="/{param:int}", directories=[tmpdir]) -def test_path_inside_static(tmpdir: "Path") -> None: +def test_sub_path_under_static_path(tmpdir: "Path") -> None: path = tmpdir / "test.txt" path.write_text("content", "utf-8") - @get("/static/strange/{f:str}") + @get("/static/sub/{f:str}", media_type=MediaType.TEXT) def handler(f: str) -> str: return f - static_files_config = StaticFilesConfig(path="/static", directories=[tmpdir]) - with pytest.raises(ImproperlyConfiguredException): - Starlite(route_handlers=[handler], static_files_config=static_files_config) + with create_test_client( + handler, static_files_config=StaticFilesConfig(path="/static", directories=[tmpdir]) + ) as client: + response = client.get("/static/test.txt") + assert response.status_code == HTTP_200_OK + + response = client.get("/static/sub/abc") + assert response.status_code == HTTP_200_OK + + +def test_validation_of_static_path_and_path_parameter(tmpdir: "Path") -> None: + path = tmpdir / "test.txt" + path.write_text("content", "utf-8") + + @get("/static/{f:str}", media_type=MediaType.TEXT) + def handler(f: str) -> str: + return f - app = Starlite(route_handlers=[], static_files_config=static_files_config) with pytest.raises(ImproperlyConfiguredException): - app.register(handler) + Starlite(route_handlers=[handler], static_files_config=StaticFilesConfig(path="/static", directories=[tmpdir])) def test_multiple_configs(tmpdir: "Path") -> None: diff --git a/tests/test_guards.py b/tests/test_guards.py index db41c06db6..dd4724df5a 100644 --- a/tests/test_guards.py +++ b/tests/test_guards.py @@ -50,7 +50,9 @@ def my_http_route_handler() -> None: response = client.get("/secret", headers={"Authorization": "yes"}) assert response.status_code == HTTP_403_FORBIDDEN assert response.json().get("detail") == "local" - client.app.route_map["/secret"]["_asgi_handlers"]["GET"]["handler"].opt["allow_all"] = True + client.app.asgi_router.root_route_map_node["children"]["/secret"]["asgi_handlers"]["GET"][1].opt[ + "allow_all" + ] = True response = client.get("/secret", headers={"Authorization": "yes"}) assert response.status_code == HTTP_200_OK @@ -68,7 +70,9 @@ async def my_asgi_handler(scope: Scope, receive: Receive, send: Send) -> None: response = client.get("/secret", headers={"Authorization": "yes"}) assert response.status_code == HTTP_403_FORBIDDEN assert response.json().get("detail") == "local" - client.app.route_map["/secret"]["_asgi_handlers"]["asgi"]["handler"].opt["allow_all"] = True + client.app.asgi_router.root_route_map_node["children"]["/secret"]["asgi_handlers"]["asgi"][1].opt[ + "allow_all" + ] = True response = client.get("/secret", headers={"Authorization": "yes"}) assert response.status_code == HTTP_200_OK @@ -87,7 +91,7 @@ async def my_websocket_route_handler(socket: WebSocket) -> None: with pytest.raises(WebSocketDisconnect), client.websocket_connect("/") as ws: ws.send_json({"data": "123"}) - client.app.route_map["/"]["_asgi_handlers"]["websocket"]["handler"].opt["allow_all"] = True + client.app.asgi_router.root_route_map_node["children"]["/"]["asgi_handlers"]["websocket"][1].opt["allow_all"] = True with client.websocket_connect("/") as ws: ws.send_json({"data": "123"}) @@ -101,5 +105,19 @@ def http_route_handler() -> None: router = Router(path="/router", route_handlers=[http_route_handler], guards=[router_guard]) app = Starlite(route_handlers=[http_route_handler, router], guards=[app_guard]) - assert len(app.route_map["/http"]["_asgi_handlers"]["GET"]["handler"]._resolved_guards) == 2 - assert len(app.route_map["/router/http"]["_asgi_handlers"]["GET"]["handler"]._resolved_guards) == 3 + assert ( + len( + app.asgi_router.root_route_map_node["children"]["/http"]["asgi_handlers"]["GET"][ + 1 + ]._resolved_guards # pyright: ignore + ) + == 2 + ) + assert ( + len( + app.asgi_router.root_route_map_node["children"]["/router/http"]["asgi_handlers"]["GET"][ + 1 + ]._resolved_guards # pyright: ignore + ) + == 3 + )