Skip to content

Commit

Permalink
#issue 354 html mode not working as expected (#355)
Browse files Browse the repository at this point in the history
* inital

* updated handling of path params for static paths

* updated app private members
  • Loading branch information
Goldziher authored Aug 12, 2022
1 parent 6391bbd commit ef5d3f5
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 112 deletions.
162 changes: 80 additions & 82 deletions starlite/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
LifeCycleHandler,
Middleware,
)
from starlite.utils import normalize_path
from starlite.utils.templates import create_template_engine

if TYPE_CHECKING:
Expand All @@ -57,20 +56,20 @@
class Starlite(Router):
__slots__ = (
"_registered_routes",
"_static_paths",
"allowed_hosts",
"asgi_handler",
"asgi_router",
"cache_config",
"compression_config",
"cors_config",
"csrf_config",
"debug",
"compression_config",
"openapi_schema",
"plain_routes",
"plugins",
"route_map",
"state",
"static_paths",
"template_engine",
)

Expand Down Expand Up @@ -149,6 +148,8 @@ def __init__(
template_config: An instance of [TemplateConfig][starlite.config.TemplateConfig]
tags: A list of string tags that will be appended to the schema of all route handlers under the application.
"""
self._registered_routes: Set[BaseRoute] = set()
self._static_paths: Set[str] = set()
self.allowed_hosts = allowed_hosts
self.cache_config = cache_config
self.cors_config = cors_config
Expand All @@ -159,9 +160,7 @@ def __init__(
self.plugins = plugins or []
self.route_map: Dict[str, Any] = {}
self.routes: List[BaseRoute] = []
self._registered_routes: Set[BaseRoute] = set()
self.state = State()
self.static_paths = set()

super().__init__(
after_request=after_request,
Expand All @@ -181,19 +180,65 @@ def __init__(
)

self.asgi_router = StarliteASGIRouter(on_shutdown=on_shutdown or [], on_startup=on_startup or [], app=self)
self.asgi_handler = self.create_asgi_handler()
self.asgi_handler = self._create_asgi_handler()
self.openapi_schema: Optional["OpenAPI"] = None
if openapi_config:
self.openapi_schema = openapi_config.create_openapi_schema_model(self)
self.register(openapi_config.openapi_controller)
if static_files_config:
for config in static_files_config if isinstance(static_files_config, list) else [static_files_config]:
path = normalize_path(config.path)
self.static_paths.add(path)
self.register(asgi(path=path)(config.to_static_files_app()))
self._static_paths.add(config.path)
self.register(asgi(path=config.path)(config.to_static_files_app()))
self.template_engine = create_template_engine(template_config)

def create_asgi_handler(self) -> "ASGIApp":
async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None:
"""
The application entry point.
Lifespan events (startup / shutdown) are sent to the lifespan handler, otherwise the ASGI handler is used
"""
scope["app"] = self
if scope["type"] == "lifespan":
await self.asgi_router.lifespan(scope, receive, send)
return
scope["state"] = {}
await self.asgi_handler(scope, receive, send)

def register(self, value: ControllerRouterHandler) -> None: # type: ignore[override]
"""
Registers a route handler on the app. This method can be used to dynamically add endpoints to an application.
Args:
value: an instance of [Router][starlite.router.Router], a subclasses of
[Controller][starlite.controller.Controller] or any function decorated by the route handler decorators.
Returns:
None
"""
routes = super().register(value=value)
for route in routes:
if isinstance(route, HTTPRoute):
route_handlers = route.route_handlers
else:
route_handlers = [cast("Union[WebSocketRoute, ASGIRoute]", route).route_handler] # type: ignore
for route_handler in route_handlers:
self._create_handler_signature_model(route_handler=route_handler)
route_handler.resolve_guards()
route_handler.resolve_middleware()
if isinstance(route_handler, HTTPRouteHandler):
route_handler.resolve_response_class()
route_handler.resolve_before_request()
route_handler.resolve_after_request()
route_handler.resolve_after_response()
route_handler.resolve_response_headers()
route_handler.resolve_response_cookies()
if isinstance(route, HTTPRoute):
route.create_handler_map()
elif isinstance(route, WebSocketRoute):
route.handler_parameter_model = route.create_handler_kwargs_model(route.route_handler)
self._construct_route_map()

def _create_asgi_handler(self) -> "ASGIApp":
"""
Creates an ASGIApp that wraps the ASGI router inside an exception handler.
Expand All @@ -208,21 +253,9 @@ def create_asgi_handler(self) -> "ASGIApp":
asgi_handler = CORSMiddleware(app=asgi_handler, **self.cors_config.dict())
if self.csrf_config:
asgi_handler = CSRFMiddleware(app=asgi_handler, config=self.csrf_config)
return self.wrap_in_exception_handler(asgi_handler, exception_handlers=self.exception_handlers or {})

async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None:
"""
The application entry point.
Lifespan events (startup / shutdown) are sent to the lifespan handler, otherwise the ASGI handler is used
"""
scope["app"] = self
if scope["type"] == "lifespan":
await self.asgi_router.lifespan(scope, receive, send)
return
scope["state"] = {}
await self.asgi_handler(scope, receive, send)
return self._wrap_in_exception_handler(asgi_handler, exception_handlers=self.exception_handlers or {})

def wrap_in_exception_handler(
def _wrap_in_exception_handler(
self, app: "ASGIApp", exception_handlers: Dict[Union[int, Type[Exception]], ExceptionHandler]
) -> "ASGIApp":
"""
Expand All @@ -231,38 +264,38 @@ def wrap_in_exception_handler(

return ExceptionHandlerMiddleware(app=app, exception_handlers=exception_handlers, debug=self.debug)

def add_node_to_route_map(self, route: BaseRoute) -> Dict[str, Any]:
def _add_node_to_route_map(self, route: BaseRoute) -> Dict[str, Any]:
"""
Adds a new route path (e.g. '/foo/bar/{param:int}') into the route_map tree.
Inserts non-parameter paths ('plain routes') off the tree's root node.
For paths containing parameters, splits the path on '/' and nests each path
segment under the previous segment's node (see prefix tree / trie).
"""
cur_node = self.route_map
current_node = self.route_map
path = route.path
if route.path_parameters or path in self.static_paths:
if route.path_parameters or path in self._static_paths:
for param_definition in route.path_parameters:
path = path.replace(param_definition["full"], "")
path = path.replace("{}", "*")
components = ["/", *[component for component in path.split("/") if component]]
for component in components:
components_set = cast("Set[str]", cur_node["_components"])
components_set = cast("Set[str]", current_node["_components"])
components_set.add(component)
if component not in cur_node:
cur_node[component] = {"_components": set()}
cur_node = cast("Dict[str, Any]", cur_node[component])
if "static_path" in cur_node:
if component not in current_node:
current_node[component] = {"_components": set()}
current_node = cast("Dict[str, Any]", current_node[component])
if "_static_path" in current_node:
raise ImproperlyConfiguredException("Cannot have configured routes below a static path")
else:
if path not in self.route_map:
self.route_map[path] = {"_components": set()}
self.plain_routes.add(path)
cur_node = self.route_map[path]
self.configure_route_map_node(route, cur_node)
return cur_node
current_node = self.route_map[path]
self._configure_route_map_node(route, current_node)
return current_node

def configure_route_map_node(self, route: BaseRoute, node: Dict[str, Any]) -> None:
def _configure_route_map_node(self, route: BaseRoute, node: Dict[str, Any]) -> None:
"""
Set required attributes and route handlers on route_map tree node.
"""
Expand All @@ -272,44 +305,44 @@ def configure_route_map_node(self, route: BaseRoute, node: Dict[str, Any]) -> No
node["_asgi_handlers"] = {}
if "_is_asgi" not in node:
node["_is_asgi"] = False
if route.path in self.static_paths:
if route.path in self._static_paths:
if node["_components"]:
raise ImproperlyConfiguredException("Cannot have configured routes below a static path")
node["static_path"] = route.path
node["_static_path"] = route.path
node["_is_asgi"] = True
asgi_handlers = cast("Dict[str, ASGIApp]", node["_asgi_handlers"])
if isinstance(route, HTTPRoute):
for method, handler_mapping in route.route_handler_map.items():
handler, _ = handler_mapping
asgi_handlers[method] = self.build_route_middleware_stack(route, handler)
asgi_handlers[method] = self._build_route_middleware_stack(route, handler)
elif isinstance(route, WebSocketRoute):
asgi_handlers["websocket"] = self.build_route_middleware_stack(route, route.route_handler)
asgi_handlers["websocket"] = self._build_route_middleware_stack(route, route.route_handler)
elif isinstance(route, ASGIRoute):
asgi_handlers["asgi"] = self.build_route_middleware_stack(route, route.route_handler)
asgi_handlers["asgi"] = self._build_route_middleware_stack(route, route.route_handler)
node["_is_asgi"] = True

def construct_route_map(self) -> None:
def _construct_route_map(self) -> None:
"""
Create a map of the app's routes. This map is used in the asgi router to route requests.
"""
if "_components" not in self.route_map:
self.route_map["_components"] = set()
new_routes = [route for route in self.routes if route not in self._registered_routes]
for route in new_routes:
node = self.add_node_to_route_map(route)
node = self._add_node_to_route_map(route)
if node["_path_parameters"] != route.path_parameters:
raise ImproperlyConfiguredException("Should not use routes with conflicting path parameters")
self._registered_routes.add(route)

def build_route_middleware_stack(
def _build_route_middleware_stack(
self,
route: Union[HTTPRoute, WebSocketRoute, ASGIRoute],
route_handler: Union[HTTPRouteHandler, "WebsocketRouteHandler", ASGIRouteHandler],
) -> "ASGIApp":
"""Constructs a middleware stack that serves as the point of entry for each route"""

# we wrap the route.handle method in the ExceptionHandlerMiddleware
asgi_handler = self.wrap_in_exception_handler(
asgi_handler = self._wrap_in_exception_handler(
app=route.handle, exception_handlers=route_handler.resolve_exception_handlers()
)

Expand All @@ -320,46 +353,11 @@ def build_route_middleware_stack(
asgi_handler = middleware(app=asgi_handler)

# we wrap the entire stack again in ExceptionHandlerMiddleware
return self.wrap_in_exception_handler(
return self._wrap_in_exception_handler(
app=asgi_handler, exception_handlers=route_handler.resolve_exception_handlers()
)

def register(self, value: ControllerRouterHandler) -> None: # type: ignore[override]
"""
Registers a route handler on the app. This method can be used to dynamically add endpoints to an application.
Args:
value: an instance of [Router][starlite.router.Router], a subclasses of
[Controller][starlite.controller.Controller] or any function decorated by the route handler decorators.
Returns:
None
"""
routes = super().register(value=value)
for route in routes:
if isinstance(route, HTTPRoute):
route_handlers = route.route_handlers
else:
route_handlers = [cast("Union[WebSocketRoute, ASGIRoute]", route).route_handler] # type: ignore
for route_handler in route_handlers:
self.create_handler_signature_model(route_handler=route_handler)
route_handler.resolve_guards()
route_handler.resolve_middleware()
if isinstance(route_handler, HTTPRouteHandler):
route_handler.resolve_response_class()
route_handler.resolve_before_request()
route_handler.resolve_after_request()
route_handler.resolve_after_response()
route_handler.resolve_response_headers()
route_handler.resolve_response_cookies()
if isinstance(route, HTTPRoute):
route.create_handler_map()
elif isinstance(route, WebSocketRoute):
route.handler_parameter_model = route.create_handler_kwargs_model(route.route_handler)
self.construct_route_map()

def create_handler_signature_model(self, route_handler: "BaseRouteHandler") -> None:
def _create_handler_signature_model(self, route_handler: "BaseRouteHandler") -> None:
"""
Creates function signature models for all route handler functions and provider dependencies
"""
Expand Down
Loading

0 comments on commit ef5d3f5

Please sign in to comment.