From ffefea08050ffe8022d3391f4bd5e5ab4e95d7c9 Mon Sep 17 00:00:00 2001 From: Arve Knudsen Date: Fri, 15 Jan 2021 19:00:45 +0100 Subject: [PATCH] Upgrade to Tokio 1.0 (#753) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: janpetschexain <58227040+janpetschexain@users.noreply.github.com> Co-authored-by: João Oliveira Co-authored-by: Paolo Barbolini Co-authored-by: Muhammad Hamza Co-authored-by: teenjuna <53595243+teenjuna@users.noreply.github.com> --- .github/workflows/ci.yml | 2 -- Cargo.toml | 18 ++++++----- examples/futures.rs | 2 +- examples/sse.rs | 5 ++- examples/sse_chat.rs | 2 ++ examples/unix_socket.rs | 5 +-- examples/websockets_chat.rs | 2 ++ src/filters/body.rs | 8 ++--- src/filters/compression.rs | 15 ++++++--- src/filters/fs.rs | 5 +-- src/filters/sse.rs | 12 +++++--- src/filters/ws.rs | 61 +++++++++++++++++++++---------------- src/route.rs | 3 +- src/test.rs | 13 +++++--- src/tls.rs | 6 ++-- src/transport.rs | 6 ++-- tests/body.rs | 2 +- tests/ws.rs | 1 + 18 files changed, 100 insertions(+), 68 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 246bbd0ec..7a1cd98a2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -45,8 +45,6 @@ jobs: benches: true - build: tls features: "--features tls" - - build: uds - features: "--features tokio/uds" - build: no-default-features features: "--no-default-features" - build: compression diff --git a/Cargo.toml b/Cargo.toml index b40c2b52c..5452c6b00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,12 +17,12 @@ edition = "2018" all-features = true [dependencies] -async-compression = { version = "0.3.1", features = ["brotli", "deflate", "gzip", "stream"], optional = true } -bytes = "0.5" +async-compression = { version = "0.3.7", features = ["brotli", "deflate", "gzip", "tokio"], optional = true } +bytes = "1.0" futures = { version = "0.3", default-features = false, features = ["alloc"] } headers = "0.3" http = "0.2" -hyper = { version = "0.13", features = ["stream"] } +hyper = { version = "0.14", features = ["stream", "server", "http1", "http2", "tcp", "client"] } log = "0.4" mime = "0.3" mime_guess = "2.0.0" @@ -31,15 +31,17 @@ scoped-tls = "1.0" serde = "1.0" serde_json = "1.0" serde_urlencoded = "0.7" -tokio = { version = "0.2", features = ["fs", "stream", "sync", "time"] } +tokio = { version = "1.0", features = ["fs", "sync", "time"] } +tokio-stream = "0.1.1" +tokio-util = { version = "0.6", features = ["io"] } tracing = { version = "0.1", default-features = false, features = ["log", "std"] } tracing-futures = { version = "0.2", default-features = false, features = ["std-future"] } tower-service = "0.3" # tls is enabled by default, we don't want that yet -tokio-tungstenite = { version = "0.11", default-features = false, optional = true } +tokio-tungstenite = { version = "0.13", default-features = false, optional = true } percent-encoding = "2.1" pin-project = "1.0" -tokio-rustls = { version = "0.14", optional = true } +tokio-rustls = { version = "0.22", optional = true } [dev-dependencies] pretty_env_logger = "0.4" @@ -47,7 +49,8 @@ tracing-subscriber = "0.2.7" tracing-log = "0.1" serde_derive = "1.0" handlebars = "3.0.0" -tokio = { version = "0.2", features = ["macros"] } +tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } +tokio-stream = { version = "0.1.1", features = ["net"] } listenfd = "0.3" [features] @@ -78,7 +81,6 @@ required-features = ["compression"] [[example]] name = "unix_socket" -required-features = ["tokio/uds"] [[example]] name = "websockets" diff --git a/examples/futures.rs b/examples/futures.rs index 013428093..43bf2f6ef 100644 --- a/examples/futures.rs +++ b/examples/futures.rs @@ -16,7 +16,7 @@ async fn main() { } async fn sleepy(Seconds(seconds): Seconds) -> Result { - tokio::time::delay_for(Duration::from_secs(seconds)).await; + tokio::time::sleep(Duration::from_secs(seconds)).await; Ok(format!("I waited {} seconds!", seconds)) } diff --git a/examples/sse.rs b/examples/sse.rs index 7177d91c0..c0a27661c 100644 --- a/examples/sse.rs +++ b/examples/sse.rs @@ -2,6 +2,7 @@ use futures::StreamExt; use std::convert::Infallible; use std::time::Duration; use tokio::time::interval; +use tokio_stream::wrappers::IntervalStream; use warp::{sse::Event, Filter}; // create server-sent event @@ -16,7 +17,9 @@ async fn main() { let routes = warp::path("ticks").and(warp::get()).map(|| { let mut counter: u64 = 0; // create server event source - let event_stream = interval(Duration::from_secs(1)).map(move |_| { + let interval = interval(Duration::from_secs(1)); + let stream = IntervalStream::new(interval); + let event_stream = stream.map(move |_| { counter += 1; sse_counter(counter) }); diff --git a/examples/sse_chat.rs b/examples/sse_chat.rs index b6b6221c7..e3c308cf4 100644 --- a/examples/sse_chat.rs +++ b/examples/sse_chat.rs @@ -5,6 +5,7 @@ use std::sync::{ Arc, Mutex, }; use tokio::sync::mpsc; +use tokio_stream::wrappers::UnboundedReceiverStream; use warp::{sse::Event, Filter}; #[tokio::main] @@ -83,6 +84,7 @@ fn user_connected(users: Users) -> impl Stream // Use an unbounded channel to handle buffering and flushing of messages // to the event source... let (tx, rx) = mpsc::unbounded_channel(); + let rx = UnboundedReceiverStream::new(rx); tx.send(Message::UserId(my_id)) // rx is right above, so this cannot fail diff --git a/examples/unix_socket.rs b/examples/unix_socket.rs index 951a28782..531518dfc 100644 --- a/examples/unix_socket.rs +++ b/examples/unix_socket.rs @@ -1,13 +1,14 @@ #![deny(warnings)] use tokio::net::UnixListener; +use tokio_stream::wrappers::UnixListenerStream; #[tokio::main] async fn main() { pretty_env_logger::init(); - let mut listener = UnixListener::bind("/tmp/warp.sock").unwrap(); - let incoming = listener.incoming(); + let listener = UnixListener::bind("/tmp/warp.sock").unwrap(); + let incoming = UnixListenerStream::new(listener); warp::serve(warp::fs::dir("examples/dir")) .run_incoming(incoming) .await; diff --git a/examples/websockets_chat.rs b/examples/websockets_chat.rs index 081a8472f..5086dd6b2 100644 --- a/examples/websockets_chat.rs +++ b/examples/websockets_chat.rs @@ -7,6 +7,7 @@ use std::sync::{ use futures::{FutureExt, StreamExt}; use tokio::sync::{mpsc, RwLock}; +use tokio_stream::wrappers::UnboundedReceiverStream; use warp::ws::{Message, WebSocket}; use warp::Filter; @@ -59,6 +60,7 @@ async fn user_connected(ws: WebSocket, users: Users) { // Use an unbounded channel to handle buffering and flushing of messages // to the websocket... let (tx, rx) = mpsc::unbounded_channel(); + let rx = UnboundedReceiverStream::new(rx); tokio::task::spawn(rx.forward(user_ws_tx).map(|result| { if let Err(e) = result { eprintln!("websocket send error: {}", e); diff --git a/src/filters/body.rs b/src/filters/body.rs index d13e99f83..82fe54453 100644 --- a/src/filters/body.rs +++ b/src/filters/body.rs @@ -7,7 +7,7 @@ use std::fmt; use std::pin::Pin; use std::task::{Context, Poll}; -use bytes::{buf::BufExt, Buf, Bytes}; +use bytes::{Buf, Bytes}; use futures::{future, ready, Stream, TryFutureExt}; use headers::ContentLength; use http::header::CONTENT_TYPE; @@ -131,8 +131,8 @@ pub fn bytes() -> impl Filter + Copy { /// fn full_body(mut body: impl Buf) { /// // It could have several non-contiguous slices of memory... /// while body.has_remaining() { -/// println!("slice = {:?}", body.bytes()); -/// let cnt = body.bytes().len(); +/// println!("slice = {:?}", body.chunk()); +/// let cnt = body.chunk().len(); /// body.advance(cnt); /// } /// } @@ -232,7 +232,7 @@ impl Decode for Json { const WITH_NO_CONTENT_TYPE: bool = true; fn decode(mut buf: B) -> Result { - serde_json::from_slice(&buf.to_bytes()).map_err(Into::into) + serde_json::from_slice(&buf.copy_to_bytes(buf.remaining())).map_err(Into::into) } } diff --git a/src/filters/compression.rs b/src/filters/compression.rs index 6a3086a21..d67f996f6 100644 --- a/src/filters/compression.rs +++ b/src/filters/compression.rs @@ -2,12 +2,13 @@ //! //! Filters that compress the body of a response. -use async_compression::stream::{BrotliEncoder, DeflateEncoder, GzipEncoder}; +use async_compression::tokio::bufread::{BrotliEncoder, DeflateEncoder, GzipEncoder}; use http::header::HeaderValue; use hyper::{ header::{CONTENT_ENCODING, CONTENT_LENGTH}, Body, }; +use tokio_util::io::{ReaderStream, StreamReader}; use crate::filter::{Filter, WrapSealed}; use crate::reject::IsReject; @@ -56,7 +57,9 @@ pub struct Compression { /// ``` pub fn gzip() -> Compression Response + Copy> { let func = move |mut props: CompressionProps| { - let body = Body::wrap_stream(GzipEncoder::new(props.body)); + let body = Body::wrap_stream(ReaderStream::new(GzipEncoder::new(StreamReader::new( + props.body, + )))); props .head .headers @@ -82,7 +85,9 @@ pub fn gzip() -> Compression Response + Copy> { /// ``` pub fn deflate() -> Compression Response + Copy> { let func = move |mut props: CompressionProps| { - let body = Body::wrap_stream(DeflateEncoder::new(props.body)); + let body = Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(StreamReader::new( + props.body, + )))); props .head .headers @@ -108,7 +113,9 @@ pub fn deflate() -> Compression Response + Copy> { /// ``` pub fn brotli() -> Compression Response + Copy> { let func = move |mut props: CompressionProps| { - let body = Body::wrap_stream(BrotliEncoder::new(props.body)); + let body = Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(StreamReader::new( + props.body, + )))); props .head .headers diff --git a/src/filters/fs.rs b/src/filters/fs.rs index 1f04f6bb5..2078095ba 100644 --- a/src/filters/fs.rs +++ b/src/filters/fs.rs @@ -22,7 +22,8 @@ use hyper::Body; use mime_guess; use percent_encoding::percent_decode_str; use tokio::fs::File as TkFile; -use tokio::io::AsyncRead; +use tokio::io::AsyncSeekExt; +use tokio_util::io::poll_read_buf; use crate::filter::{Filter, FilterClone, One}; use crate::reject::{self, Rejection}; @@ -419,7 +420,7 @@ fn file_stream( } reserve_at_least(&mut buf, buf_size); - let n = match ready!(Pin::new(&mut f).poll_read_buf(cx, &mut buf)) { + let n = match ready!(poll_read_buf(Pin::new(&mut f), cx, &mut buf)) { Ok(n) => n as u64, Err(err) => { tracing::debug!("file read error: {}", err); diff --git a/src/filters/sse.rs b/src/filters/sse.rs index 659c74225..31a225142 100644 --- a/src/filters/sse.rs +++ b/src/filters/sse.rs @@ -54,7 +54,7 @@ use http::header::{HeaderValue, CACHE_CONTROL, CONTENT_TYPE}; use hyper::Body; use pin_project::pin_project; use serde_json::{self, Error}; -use tokio::time::{self, Delay}; +use tokio::time::{self, Sleep}; use self::sealed::SseError; use super::header; @@ -386,7 +386,7 @@ impl KeepAlive { S: TryStream + Send + 'static, S::Error: StdError + Send + Sync + 'static, { - let alive_timer = time::delay_for(self.max_interval); + let alive_timer = time::sleep(self.max_interval); SseKeepAlive { event_stream, comment_text: self.comment_text, @@ -403,7 +403,8 @@ struct SseKeepAlive { event_stream: S, comment_text: Cow<'static, str>, max_interval: Duration, - alive_timer: Delay, + #[pin] + alive_timer: Sleep, } /// Keeps event source connection alive when no events sent over a some time. @@ -421,6 +422,7 @@ struct SseKeepAlive { /// use std::convert::Infallible; /// use futures::StreamExt; /// use tokio::time::interval; +/// use tokio_stream::wrappers::IntervalStream; /// use warp::{Filter, Stream, sse::Event}; /// /// // create server-sent event @@ -433,7 +435,9 @@ struct SseKeepAlive { /// .and(warp::get()) /// .map(|| { /// let mut counter: u64 = 0; -/// let event_stream = interval(Duration::from_secs(15)).map(move |_| { +/// let interval = interval(Duration::from_secs(15)); +/// let stream = IntervalStream::new(interval); +/// let event_stream = stream.map(move |_| { /// counter += 1; /// sse_counter(counter) /// }); diff --git a/src/filters/ws.rs b/src/filters/ws.rs index 8475523af..956cfe4ea 100644 --- a/src/filters/ws.rs +++ b/src/filters/ws.rs @@ -6,13 +6,14 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use super::{body, header}; -use crate::filter::{Filter, One}; +use super::header; +use crate::filter::{filter_fn_one, Filter, One}; use crate::reject::Rejection; use crate::reply::{Reply, Response}; use futures::{future, ready, FutureExt, Sink, Stream, TryFutureExt}; use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade}; use http; +use hyper::upgrade::OnUpgrade; use tokio_tungstenite::{ tungstenite::protocol::{self, WebSocketConfig}, WebSocketStream, @@ -57,19 +58,21 @@ pub fn ws() -> impl Filter, Error = Rejection> + Copy { //.and(header::exact2(Upgrade::websocket())) //.and(header::exact2(SecWebsocketVersion::V13)) .and(header::header2::()) - .and(body::body()) - .map(move |key: SecWebsocketKey, body: ::hyper::Body| Ws { - body, - config: None, - key, - }) + .and(on_upgrade()) + .map( + move |key: SecWebsocketKey, on_upgrade: Option| Ws { + config: None, + key, + on_upgrade, + }, + ) } /// Extracted by the [`ws`](ws) filter, and used to finish an upgrade. pub struct Ws { - body: ::hyper::Body, config: Option, key: SecWebsocketKey, + on_upgrade: Option, } impl Ws { @@ -132,23 +135,24 @@ where U: Future + Send + 'static, { fn into_response(self) -> Response { - let on_upgrade = self.on_upgrade; - let config = self.ws.config; - let fut = self - .ws - .body - .on_upgrade() - .and_then(move |upgraded| { - tracing::trace!("websocket upgrade complete"); - WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok) - }) - .and_then(move |socket| on_upgrade(socket).map(Ok)) - .map(|result| { - if let Err(err) = result { - tracing::debug!("ws upgrade error: {}", err); - } - }); - ::tokio::task::spawn(fut); + if let Some(on_upgrade) = self.ws.on_upgrade { + let on_upgrade_cb = self.on_upgrade; + let config = self.ws.config; + let fut = on_upgrade + .and_then(move |upgraded| { + tracing::trace!("websocket upgrade complete"); + WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok) + }) + .and_then(move |socket| on_upgrade_cb(socket).map(Ok)) + .map(|result| { + if let Err(err) = result { + tracing::debug!("ws upgrade error: {}", err); + } + }); + ::tokio::task::spawn(fut); + } else { + tracing::debug!("ws couldn't be upgraded since no upgrade state was present"); + } let mut res = http::Response::default(); @@ -163,6 +167,11 @@ where } } +// Extracts OnUpgrade state from the route. +fn on_upgrade() -> impl Filter,), Error = Rejection> + Copy { + filter_fn_one(|route| future::ready(Ok(route.extensions_mut().remove::()))) +} + /// A websocket `Stream` and `Sink`, provided to `ws` filters. /// /// Ping messages sent from the client will be handled internally by replying with a Pong message. diff --git a/src/route.rs b/src/route.rs index cceb2581f..c692ead38 100644 --- a/src/route.rs +++ b/src/route.rs @@ -75,11 +75,10 @@ impl Route { self.req.extensions() } - /* + #[cfg(feature = "websocket")] pub(crate) fn extensions_mut(&mut self) -> &mut http::Extensions { self.req.extensions_mut() } - */ pub(crate) fn uri(&self) -> &http::Uri { self.req.uri() diff --git a/src/test.rs b/src/test.rs index 50fc0e388..3deb5b10b 100644 --- a/src/test.rs +++ b/src/test.rs @@ -102,6 +102,8 @@ use serde::Serialize; use serde_json; #[cfg(feature = "websocket")] use tokio::sync::{mpsc, oneshot}; +#[cfg(feature = "websocket")] +use tokio_stream::wrappers::UnboundedReceiverStream; use crate::filter::Filter; use crate::reject::IsReject; @@ -451,7 +453,7 @@ impl WsBuilder { } } - /// Execute this Websocket request against te provided filter. + /// Execute this Websocket request against the provided filter. /// /// If the handshake succeeds, returns a `WsClient`. /// @@ -483,6 +485,7 @@ impl WsBuilder { { let (upgraded_tx, upgraded_rx) = oneshot::channel(); let (wr_tx, wr_rx) = mpsc::unbounded_channel(); + let wr_rx = UnboundedReceiverStream::new(wr_rx); let (rd_tx, rd_rx) = mpsc::unbounded_channel(); tokio::spawn(async move { @@ -515,7 +518,7 @@ impl WsBuilder { let upgrade = ::hyper::Client::builder() .build(AddrConnect(addr)) .request(req) - .and_then(|res| res.into_body().on_upgrade()); + .and_then(|res| hyper::upgrade::on(res)); let upgraded = match upgrade.await { Ok(up) => { @@ -576,9 +579,9 @@ impl WsClient { /// Receive a websocket message from the server. pub async fn recv(&mut self) -> Result { self.rx - .next() + .recv() .await - .map(|unbounded_result| unbounded_result.map_err(WsError::new)) + .map(|result| result.map_err(WsError::new)) .unwrap_or_else(|| { // websocket is closed Err(WsError::new("closed")) @@ -588,7 +591,7 @@ impl WsClient { /// Assert the server has closed the connection. pub async fn recv_closed(&mut self) -> Result<(), WsError> { self.rx - .next() + .recv() .await .map(|result| match result { Ok(msg) => Err(WsError::new(format!("received message: {:?}", msg))), diff --git a/src/tls.rs b/src/tls.rs index 632c24ddd..44cb7c13c 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -6,7 +6,7 @@ use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use futures::ready; use hyper::server::accept::Accept; @@ -295,8 +295,8 @@ impl AsyncRead for TlsStream { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf, + ) -> Poll> { let pin = self.get_mut(); match pin.state { State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { diff --git a/src/transport.rs b/src/transport.rs index 42f0431df..be553e706 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -4,7 +4,7 @@ use std::pin::Pin; use std::task::{Context, Poll}; use hyper::server::conn::AddrStream; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; pub trait Transport: AsyncRead + AsyncWrite { fn remote_addr(&self) -> Option; @@ -22,8 +22,8 @@ impl AsyncRead for LiftIo { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { Pin::new(&mut self.get_mut().0).poll_read(cx, buf) } } diff --git a/tests/body.rs b/tests/body.rs index 8d4bce7d0..3d4537b05 100644 --- a/tests/body.rs +++ b/tests/body.rs @@ -198,5 +198,5 @@ async fn stream() { let bufs = bufs.unwrap(); assert_eq!(bufs.len(), 1); - assert_eq!(bufs[0].bytes(), b"foo=bar"); + assert_eq!(bufs[0].chunk(), b"foo=bar"); } diff --git a/tests/ws.rs b/tests/ws.rs index 5ff23e482..052a195f4 100644 --- a/tests/ws.rs +++ b/tests/ws.rs @@ -237,6 +237,7 @@ async fn ws_with_query() { .expect("handshake"); } +// Websocket filter that echoes all messages back. fn ws_echo() -> impl Filter + Copy { warp::ws().map(|ws: warp::ws::Ws| { ws.on_upgrade(|websocket| {