From f71aa930823ee39bdc0f4c08691bd95bcf4229e4 Mon Sep 17 00:00:00 2001 From: tyranron Date: Tue, 21 Nov 2023 17:48:58 +0100 Subject: [PATCH] Showcase --- juniper_warp/src/lib.rs | 86 +++++++++++++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 17 deletions(-) diff --git a/juniper_warp/src/lib.rs b/juniper_warp/src/lib.rs index 4100bea7b..7af472632 100644 --- a/juniper_warp/src/lib.rs +++ b/juniper_warp/src/lib.rs @@ -668,13 +668,20 @@ pub mod subscriptions { #[cfg(test)] mod tests { - mod graphql { + mod make_graphql_filter { + use std::future; + use juniper::{ http::GraphQLBatchRequest, tests::fixtures::starwars::schema::{Database, Query}, - EmptyMutation, EmptySubscription, RootNode, + EmptyMutation, EmptySubscription, + }; + use warp::{ + http, + reject::{self, Reject}, + test::request, + Filter as _, }; - use warp::{http, test::request, Filter as _}; use super::super::make_graphql_filter; @@ -687,11 +694,7 @@ mod tests { EmptySubscription, >; - let schema: Schema = RootNode::new( - Query, - EmptyMutation::::new(), - EmptySubscription::::new(), - ); + let schema = Schema::new(Query, EmptyMutation::new(), EmptySubscription::new()); let state = warp::any().map(Database::new); let filter = warp::path("graphql2").and(make_graphql_filter(schema, state.boxed())); @@ -717,7 +720,17 @@ mod tests { } #[tokio::test] - async fn batch_requests() { + async fn rejects_fast_when_context_extractor_fails() { + use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }; + + #[derive(Debug)] + struct ExtractionError; + + impl Reject for ExtractionError {} + type Schema = juniper::RootNode< 'static, Query, @@ -725,11 +738,49 @@ mod tests { EmptySubscription, >; - let schema: Schema = RootNode::new( - Query, - EmptyMutation::::new(), - EmptySubscription::::new(), + let schema = Schema::new(Query, EmptyMutation::new(), EmptySubscription::new()); + + // Should error on first extraction only, to check whether it rejects fast and doesn't + // switch to other `.or()` filter branches. See #1177 for details: + // https://github.com/graphql-rust/juniper/issues/1177 + let is_called = Arc::new(AtomicBool::new(false)); + let context_extractor = warp::any().and_then(move || { + future::ready(if is_called.swap(true, Ordering::Relaxed) { + Ok(Database::new()) + } else { + Err(reject::custom(ExtractionError)) + }) + }); + + let filter = + warp::path("graphql").and(make_graphql_filter(schema, context_extractor.boxed())); + + let resp = request() + .method("POST") + .path("/graphql") + .header("accept", "application/json") + .header("content-type", "application/json") + .body(r#"{"variables": null, "query": "{ hero(episode: NEW_HOPE) { name } }"}"#) + .reply(&filter) + .await; + + assert_eq!( + resp.status(), + http::StatusCode::INTERNAL_SERVER_ERROR, + "response: {resp:#?}", ); + } + + #[tokio::test] + async fn batch_requests() { + type Schema = juniper::RootNode< + 'static, + Query, + EmptyMutation, + EmptySubscription, + >; + + let schema = Schema::new(Query, EmptyMutation::new(), EmptySubscription::new()); let state = warp::any().map(Database::new); let filter = warp::path("graphql2").and(make_graphql_filter(schema, state.boxed())); @@ -768,7 +819,7 @@ mod tests { } } - mod graphiql { + mod graphiql_filter { use warp::{http, test::request, Filter as _}; use super::super::{graphiql_filter, graphiql_response}; @@ -794,7 +845,7 @@ mod tests { } #[tokio::test] - async fn endpoint_returns_graphiql_source() { + async fn returns_graphiql_source() { let filter = warp::get() .and(warp::path("dogs-api")) .and(warp::path("graphiql")) @@ -833,7 +884,7 @@ mod tests { } } - mod playground { + mod playground_filter { use warp::{http, test::request, Filter as _}; use super::super::playground_filter; @@ -855,7 +906,7 @@ mod tests { } #[tokio::test] - async fn endpoint_returns_playground_source() { + async fn returns_playground_source() { let filter = warp::get() .and(warp::path("dogs-api")) .and(warp::path("playground")) @@ -875,6 +926,7 @@ mod tests { response.headers().get("content-type").unwrap(), "text/html;charset=utf-8" ); + let body = String::from_utf8(response.body().to_vec()).unwrap(); assert!(body.contains(