Skip to content

Commit

Permalink
Upgrade to Tokio 1.0 (#753)
Browse files Browse the repository at this point in the history
Co-authored-by: janpetschexain <[email protected]>
Co-authored-by: João Oliveira <[email protected]>
Co-authored-by: Paolo Barbolini <[email protected]>
Co-authored-by: Muhammad Hamza <[email protected]>
Co-authored-by: teenjuna <[email protected]>
  • Loading branch information
6 people authored Jan 15, 2021
1 parent f59de24 commit ffefea0
Show file tree
Hide file tree
Showing 18 changed files with 100 additions and 68 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ jobs:
benches: true
- build: tls
features: "--features tls"
- build: uds
features: "--features tokio/uds"
- build: no-default-features
features: "--no-default-features"
- build: compression
Expand Down
18 changes: 10 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ edition = "2018"
all-features = true

[dependencies]
async-compression = { version = "0.3.1", features = ["brotli", "deflate", "gzip", "stream"], optional = true }
bytes = "0.5"
async-compression = { version = "0.3.7", features = ["brotli", "deflate", "gzip", "tokio"], optional = true }
bytes = "1.0"
futures = { version = "0.3", default-features = false, features = ["alloc"] }
headers = "0.3"
http = "0.2"
hyper = { version = "0.13", features = ["stream"] }
hyper = { version = "0.14", features = ["stream", "server", "http1", "http2", "tcp", "client"] }
log = "0.4"
mime = "0.3"
mime_guess = "2.0.0"
Expand All @@ -31,23 +31,26 @@ scoped-tls = "1.0"
serde = "1.0"
serde_json = "1.0"
serde_urlencoded = "0.7"
tokio = { version = "0.2", features = ["fs", "stream", "sync", "time"] }
tokio = { version = "1.0", features = ["fs", "sync", "time"] }
tokio-stream = "0.1.1"
tokio-util = { version = "0.6", features = ["io"] }
tracing = { version = "0.1", default-features = false, features = ["log", "std"] }
tracing-futures = { version = "0.2", default-features = false, features = ["std-future"] }
tower-service = "0.3"
# tls is enabled by default, we don't want that yet
tokio-tungstenite = { version = "0.11", default-features = false, optional = true }
tokio-tungstenite = { version = "0.13", default-features = false, optional = true }
percent-encoding = "2.1"
pin-project = "1.0"
tokio-rustls = { version = "0.14", optional = true }
tokio-rustls = { version = "0.22", optional = true }

[dev-dependencies]
pretty_env_logger = "0.4"
tracing-subscriber = "0.2.7"
tracing-log = "0.1"
serde_derive = "1.0"
handlebars = "3.0.0"
tokio = { version = "0.2", features = ["macros"] }
tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] }
tokio-stream = { version = "0.1.1", features = ["net"] }
listenfd = "0.3"

[features]
Expand Down Expand Up @@ -78,7 +81,6 @@ required-features = ["compression"]

[[example]]
name = "unix_socket"
required-features = ["tokio/uds"]

[[example]]
name = "websockets"
Expand Down
2 changes: 1 addition & 1 deletion examples/futures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async fn main() {
}

async fn sleepy(Seconds(seconds): Seconds) -> Result<impl warp::Reply, Infallible> {
tokio::time::delay_for(Duration::from_secs(seconds)).await;
tokio::time::sleep(Duration::from_secs(seconds)).await;
Ok(format!("I waited {} seconds!", seconds))
}

Expand Down
5 changes: 4 additions & 1 deletion examples/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use futures::StreamExt;
use std::convert::Infallible;
use std::time::Duration;
use tokio::time::interval;
use tokio_stream::wrappers::IntervalStream;
use warp::{sse::Event, Filter};

// create server-sent event
Expand All @@ -16,7 +17,9 @@ async fn main() {
let routes = warp::path("ticks").and(warp::get()).map(|| {
let mut counter: u64 = 0;
// create server event source
let event_stream = interval(Duration::from_secs(1)).map(move |_| {
let interval = interval(Duration::from_secs(1));
let stream = IntervalStream::new(interval);
let event_stream = stream.map(move |_| {
counter += 1;
sse_counter(counter)
});
Expand Down
2 changes: 2 additions & 0 deletions examples/sse_chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::{
Arc, Mutex,
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use warp::{sse::Event, Filter};

#[tokio::main]
Expand Down Expand Up @@ -83,6 +84,7 @@ fn user_connected(users: Users) -> impl Stream<Item = Result<Event, warp::Error>
// Use an unbounded channel to handle buffering and flushing of messages
// to the event source...
let (tx, rx) = mpsc::unbounded_channel();
let rx = UnboundedReceiverStream::new(rx);

tx.send(Message::UserId(my_id))
// rx is right above, so this cannot fail
Expand Down
5 changes: 3 additions & 2 deletions examples/unix_socket.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#![deny(warnings)]

use tokio::net::UnixListener;
use tokio_stream::wrappers::UnixListenerStream;

#[tokio::main]
async fn main() {
pretty_env_logger::init();

let mut listener = UnixListener::bind("/tmp/warp.sock").unwrap();
let incoming = listener.incoming();
let listener = UnixListener::bind("/tmp/warp.sock").unwrap();
let incoming = UnixListenerStream::new(listener);
warp::serve(warp::fs::dir("examples/dir"))
.run_incoming(incoming)
.await;
Expand Down
2 changes: 2 additions & 0 deletions examples/websockets_chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::sync::{

use futures::{FutureExt, StreamExt};
use tokio::sync::{mpsc, RwLock};
use tokio_stream::wrappers::UnboundedReceiverStream;
use warp::ws::{Message, WebSocket};
use warp::Filter;

Expand Down Expand Up @@ -59,6 +60,7 @@ async fn user_connected(ws: WebSocket, users: Users) {
// Use an unbounded channel to handle buffering and flushing of messages
// to the websocket...
let (tx, rx) = mpsc::unbounded_channel();
let rx = UnboundedReceiverStream::new(rx);
tokio::task::spawn(rx.forward(user_ws_tx).map(|result| {
if let Err(e) = result {
eprintln!("websocket send error: {}", e);
Expand Down
8 changes: 4 additions & 4 deletions src/filters/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::fmt;
use std::pin::Pin;
use std::task::{Context, Poll};

use bytes::{buf::BufExt, Buf, Bytes};
use bytes::{Buf, Bytes};
use futures::{future, ready, Stream, TryFutureExt};
use headers::ContentLength;
use http::header::CONTENT_TYPE;
Expand Down Expand Up @@ -131,8 +131,8 @@ pub fn bytes() -> impl Filter<Extract = (Bytes,), Error = Rejection> + Copy {
/// fn full_body(mut body: impl Buf) {
/// // It could have several non-contiguous slices of memory...
/// while body.has_remaining() {
/// println!("slice = {:?}", body.bytes());
/// let cnt = body.bytes().len();
/// println!("slice = {:?}", body.chunk());
/// let cnt = body.chunk().len();
/// body.advance(cnt);
/// }
/// }
Expand Down Expand Up @@ -232,7 +232,7 @@ impl Decode for Json {
const WITH_NO_CONTENT_TYPE: bool = true;

fn decode<B: Buf, T: DeserializeOwned>(mut buf: B) -> Result<T, BoxError> {
serde_json::from_slice(&buf.to_bytes()).map_err(Into::into)
serde_json::from_slice(&buf.copy_to_bytes(buf.remaining())).map_err(Into::into)
}
}

Expand Down
15 changes: 11 additions & 4 deletions src/filters/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
//!
//! Filters that compress the body of a response.

use async_compression::stream::{BrotliEncoder, DeflateEncoder, GzipEncoder};
use async_compression::tokio::bufread::{BrotliEncoder, DeflateEncoder, GzipEncoder};
use http::header::HeaderValue;
use hyper::{
header::{CONTENT_ENCODING, CONTENT_LENGTH},
Body,
};
use tokio_util::io::{ReaderStream, StreamReader};

use crate::filter::{Filter, WrapSealed};
use crate::reject::IsReject;
Expand Down Expand Up @@ -56,7 +57,9 @@ pub struct Compression<F> {
/// ```
pub fn gzip() -> Compression<impl Fn(CompressionProps) -> Response + Copy> {
let func = move |mut props: CompressionProps| {
let body = Body::wrap_stream(GzipEncoder::new(props.body));
let body = Body::wrap_stream(ReaderStream::new(GzipEncoder::new(StreamReader::new(
props.body,
))));
props
.head
.headers
Expand All @@ -82,7 +85,9 @@ pub fn gzip() -> Compression<impl Fn(CompressionProps) -> Response + Copy> {
/// ```
pub fn deflate() -> Compression<impl Fn(CompressionProps) -> Response + Copy> {
let func = move |mut props: CompressionProps| {
let body = Body::wrap_stream(DeflateEncoder::new(props.body));
let body = Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(StreamReader::new(
props.body,
))));
props
.head
.headers
Expand All @@ -108,7 +113,9 @@ pub fn deflate() -> Compression<impl Fn(CompressionProps) -> Response + Copy> {
/// ```
pub fn brotli() -> Compression<impl Fn(CompressionProps) -> Response + Copy> {
let func = move |mut props: CompressionProps| {
let body = Body::wrap_stream(BrotliEncoder::new(props.body));
let body = Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(StreamReader::new(
props.body,
))));
props
.head
.headers
Expand Down
5 changes: 3 additions & 2 deletions src/filters/fs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ use hyper::Body;
use mime_guess;
use percent_encoding::percent_decode_str;
use tokio::fs::File as TkFile;
use tokio::io::AsyncRead;
use tokio::io::AsyncSeekExt;
use tokio_util::io::poll_read_buf;

use crate::filter::{Filter, FilterClone, One};
use crate::reject::{self, Rejection};
Expand Down Expand Up @@ -419,7 +420,7 @@ fn file_stream(
}
reserve_at_least(&mut buf, buf_size);

let n = match ready!(Pin::new(&mut f).poll_read_buf(cx, &mut buf)) {
let n = match ready!(poll_read_buf(Pin::new(&mut f), cx, &mut buf)) {
Ok(n) => n as u64,
Err(err) => {
tracing::debug!("file read error: {}", err);
Expand Down
12 changes: 8 additions & 4 deletions src/filters/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ use http::header::{HeaderValue, CACHE_CONTROL, CONTENT_TYPE};
use hyper::Body;
use pin_project::pin_project;
use serde_json::{self, Error};
use tokio::time::{self, Delay};
use tokio::time::{self, Sleep};

use self::sealed::SseError;
use super::header;
Expand Down Expand Up @@ -386,7 +386,7 @@ impl KeepAlive {
S: TryStream<Ok = Event> + Send + 'static,
S::Error: StdError + Send + Sync + 'static,
{
let alive_timer = time::delay_for(self.max_interval);
let alive_timer = time::sleep(self.max_interval);
SseKeepAlive {
event_stream,
comment_text: self.comment_text,
Expand All @@ -403,7 +403,8 @@ struct SseKeepAlive<S> {
event_stream: S,
comment_text: Cow<'static, str>,
max_interval: Duration,
alive_timer: Delay,
#[pin]
alive_timer: Sleep,
}

/// Keeps event source connection alive when no events sent over a some time.
Expand All @@ -421,6 +422,7 @@ struct SseKeepAlive<S> {
/// use std::convert::Infallible;
/// use futures::StreamExt;
/// use tokio::time::interval;
/// use tokio_stream::wrappers::IntervalStream;
/// use warp::{Filter, Stream, sse::Event};
///
/// // create server-sent event
Expand All @@ -433,7 +435,9 @@ struct SseKeepAlive<S> {
/// .and(warp::get())
/// .map(|| {
/// let mut counter: u64 = 0;
/// let event_stream = interval(Duration::from_secs(15)).map(move |_| {
/// let interval = interval(Duration::from_secs(15));
/// let stream = IntervalStream::new(interval);
/// let event_stream = stream.map(move |_| {
/// counter += 1;
/// sse_counter(counter)
/// });
Expand Down
61 changes: 35 additions & 26 deletions src/filters/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use super::{body, header};
use crate::filter::{Filter, One};
use super::header;
use crate::filter::{filter_fn_one, Filter, One};
use crate::reject::Rejection;
use crate::reply::{Reply, Response};
use futures::{future, ready, FutureExt, Sink, Stream, TryFutureExt};
use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade};
use http;
use hyper::upgrade::OnUpgrade;
use tokio_tungstenite::{
tungstenite::protocol::{self, WebSocketConfig},
WebSocketStream,
Expand Down Expand Up @@ -57,19 +58,21 @@ pub fn ws() -> impl Filter<Extract = One<Ws>, Error = Rejection> + Copy {
//.and(header::exact2(Upgrade::websocket()))
//.and(header::exact2(SecWebsocketVersion::V13))
.and(header::header2::<SecWebsocketKey>())
.and(body::body())
.map(move |key: SecWebsocketKey, body: ::hyper::Body| Ws {
body,
config: None,
key,
})
.and(on_upgrade())
.map(
move |key: SecWebsocketKey, on_upgrade: Option<OnUpgrade>| Ws {
config: None,
key,
on_upgrade,
},
)
}

/// Extracted by the [`ws`](ws) filter, and used to finish an upgrade.
pub struct Ws {
body: ::hyper::Body,
config: Option<WebSocketConfig>,
key: SecWebsocketKey,
on_upgrade: Option<OnUpgrade>,
}

impl Ws {
Expand Down Expand Up @@ -132,23 +135,24 @@ where
U: Future<Output = ()> + Send + 'static,
{
fn into_response(self) -> Response {
let on_upgrade = self.on_upgrade;
let config = self.ws.config;
let fut = self
.ws
.body
.on_upgrade()
.and_then(move |upgraded| {
tracing::trace!("websocket upgrade complete");
WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok)
})
.and_then(move |socket| on_upgrade(socket).map(Ok))
.map(|result| {
if let Err(err) = result {
tracing::debug!("ws upgrade error: {}", err);
}
});
::tokio::task::spawn(fut);
if let Some(on_upgrade) = self.ws.on_upgrade {
let on_upgrade_cb = self.on_upgrade;
let config = self.ws.config;
let fut = on_upgrade
.and_then(move |upgraded| {
tracing::trace!("websocket upgrade complete");
WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok)
})
.and_then(move |socket| on_upgrade_cb(socket).map(Ok))
.map(|result| {
if let Err(err) = result {
tracing::debug!("ws upgrade error: {}", err);
}
});
::tokio::task::spawn(fut);
} else {
tracing::debug!("ws couldn't be upgraded since no upgrade state was present");
}

let mut res = http::Response::default();

Expand All @@ -163,6 +167,11 @@ where
}
}

// Extracts OnUpgrade state from the route.
fn on_upgrade() -> impl Filter<Extract = (Option<OnUpgrade>,), Error = Rejection> + Copy {
filter_fn_one(|route| future::ready(Ok(route.extensions_mut().remove::<OnUpgrade>())))
}

/// A websocket `Stream` and `Sink`, provided to `ws` filters.
///
/// Ping messages sent from the client will be handled internally by replying with a Pong message.
Expand Down
3 changes: 1 addition & 2 deletions src/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,10 @@ impl Route {
self.req.extensions()
}

/*
#[cfg(feature = "websocket")]
pub(crate) fn extensions_mut(&mut self) -> &mut http::Extensions {
self.req.extensions_mut()
}
*/

pub(crate) fn uri(&self) -> &http::Uri {
self.req.uri()
Expand Down
Loading

0 comments on commit ffefea0

Please sign in to comment.