diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 2e975d99b7..df33e977ac 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -6,6 +6,7 @@ use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::{ body::{Body, HttpBody}, boxed::BoxedIntoRoute, + extract::MatchedPath, handler::Handler, util::try_downcast, }; @@ -20,7 +21,8 @@ use std::{ sync::Arc, task::{Context, Poll}, }; -use tower_layer::Layer; +use tower::service_fn; +use tower_layer::{layer_fn, Layer}; use tower_service::Service; pub mod future; @@ -72,8 +74,7 @@ impl Clone for Router { } struct RouterInner { - path_router: PathRouter, - fallback_router: PathRouter, + path_router: PathRouter, default_fallback: bool, catch_all_fallback: Fallback, } @@ -91,7 +92,6 @@ impl fmt::Debug for Router { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Router") .field("path_router", &self.inner.path_router) - .field("fallback_router", &self.inner.fallback_router) .field("default_fallback", &self.inner.default_fallback) .field("catch_all_fallback", &self.inner.catch_all_fallback) .finish() @@ -141,7 +141,6 @@ where Self { inner: Arc::new(RouterInner { path_router: Default::default(), - fallback_router: PathRouter::new_fallback(), default_fallback: true, catch_all_fallback: Fallback::Default(Route::new(NotFound)), }), @@ -153,7 +152,6 @@ where Ok(inner) => inner, Err(arc) => RouterInner { path_router: arc.path_router.clone(), - fallback_router: arc.fallback_router.clone(), default_fallback: arc.default_fallback, catch_all_fallback: arc.catch_all_fallback.clone(), }, @@ -207,8 +205,7 @@ where let RouterInner { path_router, - fallback_router, - default_fallback, + default_fallback: _, // we don't need to inherit the catch-all fallback. It is only used for CONNECT // requests with an empty path. If we were to inherit the catch-all fallback // it would end up matching `/{path}/*` which doesn't match empty paths. @@ -217,10 +214,6 @@ where tap_inner!(self, mut this => { panic_on_err!(this.path_router.nest(path, path_router)); - - if !default_fallback { - panic_on_err!(this.fallback_router.nest(path, fallback_router)); - } }) } @@ -247,36 +240,24 @@ where where R: Into>, { - const PANIC_MSG: &str = - "Failed to merge fallbacks. This is a bug in axum. Please file an issue"; - let other: Router = other.into(); let RouterInner { path_router, - fallback_router: mut other_fallback, default_fallback, catch_all_fallback, } = other.into_inner(); map_inner!(self, mut this => { - panic_on_err!(this.path_router.merge(path_router)); - match (this.default_fallback, default_fallback) { // both have the default fallback // use the one from other - (true, true) => { - this.fallback_router.merge(other_fallback).expect(PANIC_MSG); - } + (true, true) => {} // this has default fallback, other has a custom fallback (true, false) => { - this.fallback_router.merge(other_fallback).expect(PANIC_MSG); this.default_fallback = false; } // this has a custom fallback, other has a default (false, true) => { - let fallback_router = std::mem::take(&mut this.fallback_router); - other_fallback.merge(fallback_router).expect(PANIC_MSG); - this.fallback_router = other_fallback; } // both have a custom fallback, not allowed (false, false) => { @@ -284,6 +265,8 @@ where } }; + panic_on_err!(this.path_router.merge(path_router)); + this.catch_all_fallback = this .catch_all_fallback .merge(catch_all_fallback) @@ -304,7 +287,6 @@ where { map_inner!(self, this => RouterInner { path_router: this.path_router.layer(layer.clone()), - fallback_router: this.fallback_router.layer(layer.clone()), default_fallback: this.default_fallback, catch_all_fallback: this.catch_all_fallback.map(|route| route.layer(layer)), }) @@ -322,7 +304,6 @@ where { map_inner!(self, this => RouterInner { path_router: this.path_router.route_layer(layer), - fallback_router: this.fallback_router, default_fallback: this.default_fallback, catch_all_fallback: this.catch_all_fallback, }) @@ -376,8 +357,47 @@ where } fn fallback_endpoint(self, endpoint: Endpoint) -> Self { + // TODO make this better, get rid of the `unwrap`s. + // We need the returned `Service` to be `Clone` and the function inside `service_fn` to be + // `FnMut` so instead of just using the owned service, we do this trick with `Option`. We + // know this will be called just once so it's fine. We're doing that so that we avoid one + // clone inside `oneshot_inner` so that the `Router` and subsequently the `State` is not + // cloned too much. tap_inner!(self, mut this => { - this.fallback_router.set_fallback(endpoint); + _ = this.path_router.route_endpoint( + "/", + endpoint.clone().layer( + layer_fn( + |service: Route| { + let mut service = Some(service); + service_fn( + move |mut request: Request| { + request.extensions_mut().remove::(); + service.take().unwrap().oneshot_inner_owned(request) + } + ) + } + ) + ) + ); + + _ = this.path_router.route_endpoint( + FALLBACK_PARAM_PATH, + endpoint.layer( + layer_fn( + |service: Route| { + let mut service = Some(service); + service_fn( + move |mut request: Request| { + request.extensions_mut().remove::(); + service.take().unwrap().oneshot_inner_owned(request) + } + ) + } + ) + ) + ); + this.default_fallback = false; }) } @@ -386,7 +406,6 @@ where pub fn with_state(self, state: S) -> Router { map_inner!(self, this => RouterInner { path_router: this.path_router.with_state(state.clone()), - fallback_router: this.fallback_router.with_state(state.clone()), default_fallback: this.default_fallback, catch_all_fallback: this.catch_all_fallback.with_state(state), }) @@ -398,11 +417,6 @@ where Err((req, state)) => (req, state), }; - let (req, state) = match self.inner.fallback_router.call_with_state(req, state) { - Ok(future) => return future, - Err((req, state)) => (req, state), - }; - self.inner .catch_all_fallback .clone() diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index 83f33e4e1f..c73778f2b0 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -9,33 +9,17 @@ use tower_layer::Layer; use tower_service::Service; use super::{ - future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint, - MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM, + future::RouteFuture, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route, + RouteId, NEST_TAIL_PARAM, }; -pub(super) struct PathRouter { +pub(super) struct PathRouter { routes: HashMap>, node: Arc, prev_route_id: RouteId, v7_checks: bool, } -impl PathRouter -where - S: Clone + Send + Sync + 'static, -{ - pub(super) fn new_fallback() -> Self { - let mut this = Self::default(); - this.set_fallback(Endpoint::Route(Route::new(NotFound))); - this - } - - pub(super) fn set_fallback(&mut self, endpoint: Endpoint) { - self.replace_endpoint("/", endpoint.clone()); - self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint); - } -} - fn validate_path(v7_checks: bool, path: &str) -> Result<(), &'static str> { if path.is_empty() { return Err("Paths must start with a `/`. Use \"/\" for root routes"); @@ -72,7 +56,7 @@ fn validate_v07_paths(path: &str) -> Result<(), &'static str> { .unwrap_or(Ok(())) } -impl PathRouter +impl PathRouter where S: Clone + Send + Sync + 'static, { @@ -159,10 +143,7 @@ where .map_err(|err| format!("Invalid route {path:?}: {err}")) } - pub(super) fn merge( - &mut self, - other: PathRouter, - ) -> Result<(), Cow<'static, str>> { + pub(super) fn merge(&mut self, other: PathRouter) -> Result<(), Cow<'static, str>> { let PathRouter { routes, node, @@ -179,24 +160,9 @@ where .get(&id) .expect("no path for route id. This is a bug in axum. Please file an issue"); - if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) { - // when merging two routers it doesn't matter if you do `a.merge(b)` or - // `b.merge(a)`. This must also be true for fallbacks. - // - // However all fallback routers will have routes for `/` and `/*` so when merging - // we have to ignore the top level fallbacks on one side otherwise we get - // conflicts. - // - // `Router::merge` makes sure that when merging fallbacks `other` always has the - // fallback we want to keep. It panics if both routers have a custom fallback. Thus - // it is always okay to ignore one fallback and `Router::merge` also makes sure the - // one we can ignore is that of `self`. - self.replace_endpoint(path, route); - } else { - match route { - Endpoint::MethodRouter(method_router) => self.route(path, method_router)?, - Endpoint::Route(route) => self.route_service(path, route)?, - } + match route { + Endpoint::MethodRouter(method_router) => self.route(path, method_router)?, + Endpoint::Route(route) => self.route_service(path, route)?, } } @@ -206,7 +172,7 @@ where pub(super) fn nest( &mut self, path_to_nest_at: &str, - router: PathRouter, + router: PathRouter, ) -> Result<(), Cow<'static, str>> { let prefix = validate_nest_path(self.v7_checks, path_to_nest_at); @@ -282,7 +248,7 @@ where Ok(()) } - pub(super) fn layer(self, layer: L) -> PathRouter + pub(super) fn layer(self, layer: L) -> PathRouter where L: Layer + Clone + Send + Sync + 'static, L::Service: Service + Clone + Send + Sync + 'static, @@ -344,7 +310,7 @@ where !self.routes.is_empty() } - pub(super) fn with_state(self, state: S) -> PathRouter { + pub(super) fn with_state(self, state: S) -> PathRouter { let routes = self .routes .into_iter() @@ -388,14 +354,12 @@ where Ok(match_) => { let id = *match_.value; - if !IS_FALLBACK { - #[cfg(feature = "matched-path")] - crate::extract::matched_path::set_matched_path_for_request( - id, - &self.node.route_id_to_path, - &mut parts.extensions, - ); - } + #[cfg(feature = "matched-path")] + crate::extract::matched_path::set_matched_path_for_request( + id, + &self.node.route_id_to_path, + &mut parts.extensions, + ); url_params::insert_url_params(&mut parts.extensions, match_.params); @@ -418,18 +382,6 @@ where } } - pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint) { - match self.node.at(path) { - Ok(match_) => { - let id = *match_.value; - self.routes.insert(id, endpoint); - } - Err(_) => self - .route_endpoint(path, endpoint) - .expect("path wasn't matched so endpoint shouldn't exist"), - } - } - fn next_route_id(&mut self) -> RouteId { let next_id = self .prev_route_id @@ -441,7 +393,7 @@ where } } -impl Default for PathRouter { +impl Default for PathRouter { fn default() -> Self { Self { routes: Default::default(), @@ -452,7 +404,7 @@ impl Default for PathRouter { } } -impl fmt::Debug for PathRouter { +impl fmt::Debug for PathRouter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PathRouter") .field("routes", &self.routes) @@ -461,7 +413,7 @@ impl fmt::Debug for PathRouter { } } -impl Clone for PathRouter { +impl Clone for PathRouter { fn clone(&self) -> Self { Self { routes: self.routes.clone(), diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 8851738667..fe41dc72a4 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -407,6 +407,87 @@ async fn what_matches_wildcard() { assert_eq!(get("/x/a/b/").await, "x"); } +#[should_panic( + expected = "Invalid route \"/{*wild}\": Insertion failed due to conflict with previously registered route: /{*__private__axum_fallback}" +)] +#[test] +fn colliding_fallback_with_wildcard() { + _ = Router::<()>::new() + .fallback(|| async { "fallback" }) + .route("/{*wild}", get(|| async { "wildcard" })); +} + +// We might want to reject this too +#[crate::test] +async fn colliding_wildcard_with_fallback() { + let router = Router::new() + .route("/{*wild}", get(|| async { "wildcard" })) + .fallback(|| async { "fallback" }); + + let client = TestClient::new(router); + + let res = client.get("/").await; + let body = res.text().await; + assert_eq!(body, "fallback"); + + let res = client.get("/x").await; + let body = res.text().await; + assert_eq!(body, "wildcard"); +} + +// We might want to reject this too +#[crate::test] +async fn colliding_fallback_with_fallback() { + let router = Router::new() + .fallback(|| async { "fallback1" }) + .fallback(|| async { "fallback2" }); + + let client = TestClient::new(router); + + let res = client.get("/").await; + let body = res.text().await; + assert_eq!(body, "fallback1"); + + let res = client.get("/x").await; + let body = res.text().await; + assert_eq!(body, "fallback1"); +} + +#[crate::test] +async fn colliding_root_with_fallback() { + let router = Router::new() + .route("/", get(|| async { "root" })) + .fallback(|| async { "fallback" }); + + let client = TestClient::new(router); + + let res = client.get("/").await; + let body = res.text().await; + assert_eq!(body, "root"); + + let res = client.get("/x").await; + let body = res.text().await; + assert_eq!(body, "fallback"); +} + +#[crate::test] +async fn colliding_fallback_with_root() { + let router = Router::new() + .fallback(|| async { "fallback" }) + .route("/", get(|| async { "root" })); + + let client = TestClient::new(router); + + // This works because fallback registers `any` so the `get` gets merged into it. + let res = client.get("/").await; + let body = res.text().await; + assert_eq!(body, "root"); + + let res = client.get("/x").await; + let body = res.text().await; + assert_eq!(body, "fallback"); +} + #[crate::test] async fn static_and_dynamic_paths() { let app = Router::new() diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index 6e14203662..3368346edd 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -387,3 +387,100 @@ async fn colon_in_route() { async fn asterisk_in_route() { _ = Router::<()>::new().nest("/*foo", Router::new()); } + +#[crate::test] +async fn nesting_router_with_fallback() { + let nested = Router::new().fallback(|| async { "nested" }); + let router = Router::new().route("/{x}/{y}", get(|| async { "two segments" })); + + let client = TestClient::new(router.nest("/nest", nested)); + + let res = client.get("/a/b").await; + let body = res.text().await; + assert_eq!(body, "two segments"); + + let res = client.get("/nest/b").await; + let body = res.text().await; + assert_eq!(body, "nested"); +} + +#[crate::test] +async fn defining_missing_routes_in_nested_router() { + let router = Router::new() + .route("/nest/before", get(|| async { "before" })) + .nest( + "/nest", + Router::new() + .route("/mid", get(|| async { "nested mid" })) + .fallback(|| async { "nested fallback" }), + ) + .route("/nest/after", get(|| async { "after" })); + + let client = TestClient::new(router); + + let res = client.get("/nest/before").await; + let body = res.text().await; + assert_eq!(body, "before"); + + let res = client.get("/nest/after").await; + let body = res.text().await; + assert_eq!(body, "after"); + + let res = client.get("/nest/mid").await; + let body = res.text().await; + assert_eq!(body, "nested mid"); + + let res = client.get("/nest/fallback").await; + let body = res.text().await; + assert_eq!(body, "nested fallback"); +} + +#[test] +#[should_panic( + expected = "Overlapping method route. Handler for `GET /nest/override` already exists" +)] +fn overriding_by_nested_router() { + _ = Router::<()>::new() + .route("/nest/override", get(|| async { "outer" })) + .nest( + "/nest", + Router::new().route("/override", get(|| async { "inner" })), + ); +} + +#[test] +#[should_panic( + expected = "Overlapping method route. Handler for `GET /nest/override` already exists" +)] +fn overriding_nested_router_() { + _ = Router::<()>::new() + .nest( + "/nest", + Router::new().route("/override", get(|| async { "inner" })), + ) + .route("/nest/override", get(|| async { "outer" })); +} + +// This is just documenting current state, not intended behavior. +#[crate::test] +async fn overriding_nested_service_router() { + let router = Router::new() + .route("/nest/before", get(|| async { "outer" })) + .nest_service( + "/nest", + Router::new() + .route("/before", get(|| async { "inner" })) + .route("/after", get(|| async { "inner" })), + ) + .route("/nest/after", get(|| async { "outer" })); + + let client = TestClient::new(router); + + let res = client.get("/nest/before").await; + let body = res.text().await; + assert_eq!(body, "outer"); + + let res = client.get("/nest/after").await; + let body = res.text().await; + assert_eq!(body, "outer"); +}