Skip to content

Commit

Permalink
feat(server): add TowerService::on_session_close (#1284)
Browse files Browse the repository at this point in the history
* add TowerService build and notify on session close

* refactor the API

* clarify docs

* add test for on_session_close
  • Loading branch information
niklasad1 authored Feb 6, 2024
1 parent 387f52f commit 8470f2b
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 22 deletions.
70 changes: 55 additions & 15 deletions examples/examples/jsonrpsee_as_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use futures::FutureExt;
use hyper::header::AUTHORIZATION;
use hyper::server::conn::AddrStream;
use hyper::HeaderMap;
Expand All @@ -45,16 +46,18 @@ use jsonrpsee::proc_macros::rpc;
use jsonrpsee::server::middleware::rpc::{ResponseFuture, RpcServiceBuilder, RpcServiceT};
use jsonrpsee::server::{stop_channel, ServerHandle, StopHandle, TowerServiceBuilder};
use jsonrpsee::types::{ErrorObject, ErrorObjectOwned, Request};
use jsonrpsee::ws_client::HeaderValue;
use jsonrpsee::ws_client::{HeaderValue, WsClientBuilder};
use jsonrpsee::{MethodResponse, Methods};
use tower::Service;
use tower_http::cors::CorsLayer;
use tracing_subscriber::util::SubscriberInitExt;

#[derive(Default, Clone)]
#[derive(Default, Clone, Debug)]
struct Metrics {
ws_connections: Arc<AtomicUsize>,
http_connections: Arc<AtomicUsize>,
opened_ws_connections: Arc<AtomicUsize>,
closed_ws_connections: Arc<AtomicUsize>,
http_calls: Arc<AtomicUsize>,
success_http_calls: Arc<AtomicUsize>,
}

#[derive(Clone)]
Expand Down Expand Up @@ -106,7 +109,9 @@ async fn main() -> anyhow::Result<()> {
let filter = tracing_subscriber::EnvFilter::try_from_default_env()?;
tracing_subscriber::FmtSubscriber::builder().with_env_filter(filter).finish().try_init()?;

let handle = run_server();
let metrics = Metrics::default();

let handle = run_server(metrics.clone());
tokio::spawn(handle.stopped());

{
Expand All @@ -117,6 +122,14 @@ async fn main() -> anyhow::Result<()> {
tracing::info!("response: {x}");
}

{
let client = WsClientBuilder::default().build("ws://127.0.0.1:9944").await.unwrap();

// Fails because the authorization header is missing.
let x = client.trusted_call().await.unwrap_err();
tracing::info!("response: {x}");
}

{
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, HeaderValue::from_static("don't care in this example"));
Expand All @@ -127,10 +140,12 @@ async fn main() -> anyhow::Result<()> {
tracing::info!("response: {x}");
}

tracing::info!("{:?}", metrics);

Ok(())
}

fn run_server() -> ServerHandle {
fn run_server(metrics: Metrics) -> ServerHandle {
use hyper::service::{make_service_fn, service_fn};

let addr = SocketAddr::from(([127, 0, 0, 1], 9944));
Expand Down Expand Up @@ -159,7 +174,7 @@ fn run_server() -> ServerHandle {
let per_conn = PerConnection {
methods: ().into_rpc().into(),
stop_handle: stop_handle.clone(),
metrics: Metrics::default(),
metrics,
svc_builder: jsonrpsee::server::Server::builder()
.set_http_middleware(tower::ServiceBuilder::new().layer(CorsLayer::permissive()))
.max_connections(33)
Expand All @@ -185,15 +200,40 @@ fn run_server() -> ServerHandle {

let mut svc = svc_builder.set_rpc_middleware(rpc_middleware).build(methods, stop_handle);

async move {
// You can't determine whether the websocket upgrade handshake failed or not here.
let rp = svc.call(req).await;
if is_websocket {
metrics.ws_connections.fetch_add(1, Ordering::Relaxed);
} else {
metrics.http_connections.fetch_add(1, Ordering::Relaxed);
if is_websocket {
// Utilize the session close future to know when the actual WebSocket
// session was closed.
let session_close = svc.on_session_closed();

// A little bit weird API but the response to HTTP request must be returned below
// and we spawn a task to register when the session is closed.
tokio::spawn(async move {
session_close.await;
tracing::info!("Closed WebSocket connection");
metrics.closed_ws_connections.fetch_add(1, Ordering::Relaxed);
});

async move {
tracing::info!("Opened WebSocket connection");
metrics.opened_ws_connections.fetch_add(1, Ordering::Relaxed);
svc.call(req).await
}
.boxed()
} else {
// HTTP.
async move {
tracing::info!("Opened HTTP connection");
metrics.http_calls.fetch_add(1, Ordering::Relaxed);
let rp = svc.call(req).await;

if rp.is_ok() {
metrics.success_http_calls.fetch_add(1, Ordering::Relaxed);
}

tracing::info!("Closed HTTP connection");
rp
}
rp
.boxed()
}
}))
}
Expand Down
2 changes: 1 addition & 1 deletion server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ serde_json = { version = "1", features = ["raw_value"] }
soketto = { version = "0.7.1", features = ["http"] }
tokio = { version = "1.16", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.7", features = ["compat"] }
tokio-stream = "0.1.7"
tokio-stream = { version = "0.1.7", features = ["sync"] }
hyper = { version = "0.14", features = ["server", "http1", "http2"] }
tower = "0.4.13"
thiserror = "1"
Expand Down
40 changes: 39 additions & 1 deletion server/src/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures_util::{Stream, StreamExt};
use futures_util::{Future, Stream, StreamExt};
use pin_project::pin_project;
use tokio::sync::{watch, OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::time::Interval;
use tokio_stream::wrappers::BroadcastStream;

/// Create channel to determine whether
/// the server shall continue to run or not.
Expand Down Expand Up @@ -157,3 +158,40 @@ impl Stream for IntervalStream {
}
}
}

#[derive(Debug, Clone)]
pub(crate) struct SessionClose(tokio::sync::broadcast::Sender<()>);

impl SessionClose {
pub(crate) fn close(self) {
let _ = self.0.send(());
}

pub(crate) fn closed(&self) -> SessionClosedFuture {
SessionClosedFuture(BroadcastStream::new(self.0.subscribe()))
}
}

/// A future that resolves when the connection has been closed.
#[derive(Debug)]
pub struct SessionClosedFuture(BroadcastStream<()>);

impl Future for SessionClosedFuture {
type Output = ();

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.0.poll_next_unpin(cx) {
Poll::Pending => Poll::Pending,
// Only message is only sent and
// ignore can't keep up errors.
Poll::Ready(_) => Poll::Ready(()),
}
}
}

pub(crate) fn session_close() -> (SessionClose, SessionClosedFuture) {
// SessionClosedFuture is closed after one message has been recevied
// and max one message is handled then it's closed.
let (tx, rx) = tokio::sync::broadcast::channel(1);
(SessionClose(tx), SessionClosedFuture(BroadcastStream::new(rx)))
}
25 changes: 24 additions & 1 deletion server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;

use crate::future::{ConnectionGuard, ServerHandle, StopHandle};
use crate::future::{session_close, ConnectionGuard, ServerHandle, SessionClose, SessionClosedFuture, StopHandle};
use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT};
use crate::transport::ws::BackgroundTaskParams;
use crate::transport::{http, ws};
Expand Down Expand Up @@ -501,6 +501,7 @@ impl<RpcMiddleware, HttpMiddleware> TowerServiceBuilder<RpcMiddleware, HttpMiddl
conn_guard: self.conn_guard,
server_cfg: self.server_cfg,
},
on_session_close: None,
};

TowerService { rpc_middleware, http_middleware: self.http_middleware }
Expand Down Expand Up @@ -941,6 +942,24 @@ pub struct TowerService<RpcMiddleware, HttpMiddleware> {
http_middleware: tower::ServiceBuilder<HttpMiddleware>,
}

impl<RpcMiddleware, HttpMiddleware> TowerService<RpcMiddleware, HttpMiddleware> {
/// A future that returns when the connection has been closed.
///
/// This method must be called before every [`TowerService::call`]
/// because the `SessionClosedFuture` may already been consumed or
/// not used.
pub fn on_session_closed(&mut self) -> SessionClosedFuture {
if let Some(n) = self.rpc_middleware.on_session_close.as_mut() {
// If it's called more then once another listener is created.
n.closed()
} else {
let (session_close, fut) = session_close();
self.rpc_middleware.on_session_close = Some(session_close);
fut
}
}
}

impl<RpcMiddleware, HttpMiddleware> hyper::service::Service<hyper::Request<hyper::Body>>
for TowerService<RpcMiddleware, HttpMiddleware>
where
Expand Down Expand Up @@ -979,6 +998,7 @@ where
pub struct TowerServiceNoHttp<L> {
inner: ServiceData,
rpc_middleware: RpcServiceBuilder<L>,
on_session_close: Option<SessionClose>,
}

impl<RpcMiddleware> hyper::service::Service<hyper::Request<hyper::Body>> for TowerServiceNoHttp<RpcMiddleware>
Expand All @@ -1004,6 +1024,7 @@ where
let conn_guard = &self.inner.conn_guard;
let stop_handle = self.inner.stop_handle.clone();
let conn_id = self.inner.conn_id;
let on_session_close = self.on_session_close.take();

tracing::trace!(target: LOG_TARGET, "{:?}", request);

Expand Down Expand Up @@ -1076,6 +1097,7 @@ where
sink,
rx,
pending_calls_completed,
on_session_close,
};

ws::background_task(params).await;
Expand Down Expand Up @@ -1176,6 +1198,7 @@ fn process_connection<'a, RpcMiddleware, HttpMiddleware, U>(
conn_guard: conn_guard.clone(),
},
rpc_middleware,
on_session_close: None,
};

let service = http_middleware.service(tower_service);
Expand Down
71 changes: 69 additions & 2 deletions server/src/tests/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
use std::fmt;
use std::error::Error as StdError;
use std::net::SocketAddr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::{fmt, sync::atomic::AtomicUsize};

use crate::{RpcModule, ServerBuilder, ServerHandle};
use crate::{stop_channel, RpcModule, Server, ServerBuilder, ServerHandle};

use futures_util::FutureExt;
use hyper::server::conn::AddrStream;
use jsonrpsee_core::server::Methods;
use jsonrpsee_core::{DeserializeOwned, RpcResult, StringError};
use jsonrpsee_test_utils::TimeoutFutureExt;
use jsonrpsee_types::{error::ErrorCode, ErrorObject, ErrorObjectOwned, Response, ResponseSuccess};
use tower::Service;
use tracing_subscriber::{EnvFilter, FmtSubscriber};

pub(crate) struct TestContext;
Expand Down Expand Up @@ -194,3 +201,63 @@ impl From<MyAppError> for ErrorObjectOwned {
fn invalid_params() -> ErrorObjectOwned {
ErrorCode::InvalidParams.into()
}

#[derive(Debug, Clone, Default)]
pub(crate) struct Metrics {
pub(crate) ws_sessions_opened: Arc<AtomicUsize>,
pub(crate) ws_sessions_closed: Arc<AtomicUsize>,
}

pub(crate) fn ws_server_with_stats(metrics: Metrics) -> SocketAddr {
use hyper::service::{make_service_fn, service_fn};

let addr = SocketAddr::from(([127, 0, 0, 1], 0));
let (stop_handle, server_handle) = stop_channel();
let stop_handle2 = stop_handle.clone();

// And a MakeService to handle each connection...
let make_service = make_service_fn(move |_conn: &AddrStream| {
let stop_handle = stop_handle2.clone();
let metrics = metrics.clone();

async move {
Ok::<_, Box<dyn StdError + Send + Sync>>(service_fn(move |req| {
let is_websocket = crate::ws::is_upgrade_request(&req);
let metrics = metrics.clone();
let stop_handle = stop_handle.clone();

let mut svc =
Server::builder().max_connections(33).to_service_builder().build(Methods::new(), stop_handle);

if is_websocket {
// This should work for each callback.
let session_close1 = svc.on_session_closed();
let session_close2 = svc.on_session_closed();

tokio::spawn(async move {
metrics.ws_sessions_opened.fetch_add(1, Ordering::SeqCst);
tokio::join!(session_close2, session_close1);
metrics.ws_sessions_closed.fetch_add(1, Ordering::SeqCst);
});

async move { svc.call(req).await }.boxed()
} else {
// HTTP.
async move { svc.call(req).await }.boxed()
}
}))
}
});

let server = hyper::Server::bind(&addr).serve(make_service);

let addr = server.local_addr();

tokio::spawn(async move {
let graceful = server.with_graceful_shutdown(async move { stop_handle.shutdown().await });
graceful.await.unwrap();
drop(server_handle)
});

addr
}
27 changes: 26 additions & 1 deletion server/src/tests/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use std::sync::atomic::Ordering;
use std::time::Duration;

use crate::tests::helpers::{deser_call, init_logger, server_with_context};
use crate::tests::helpers::{deser_call, init_logger, server_with_context, ws_server_with_stats, Metrics};
use crate::types::SubscriptionId;
use crate::{BatchRequestConfig, RegisterMethodError};
use crate::{RpcModule, ServerBuilder};
Expand Down Expand Up @@ -874,6 +875,30 @@ async fn drop_client_with_pending_calls_works() {
assert!(handle.stopped().with_timeout(MAX_TIMEOUT).await.is_ok());
}

#[tokio::test]
async fn server_notify_on_conn_close() {
init_logger();

let metrics = Metrics::default();
let addr = ws_server_with_stats(metrics.clone());

let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap();

// Wait for the server to process
tokio::time::sleep(Duration::from_millis(100)).await;

assert_eq!(metrics.ws_sessions_opened.load(Ordering::SeqCst), 1);
assert_eq!(metrics.ws_sessions_closed.load(Ordering::SeqCst), 0);

client.close().with_default_timeout().await.unwrap().unwrap();

// Wait for the server to process
tokio::time::sleep(Duration::from_millis(100)).await;

assert_eq!(metrics.ws_sessions_opened.load(Ordering::SeqCst), 1);
assert_eq!(metrics.ws_sessions_closed.load(Ordering::SeqCst), 1);
}

async fn server_with_infinite_call(
timeout: Duration,
tx: tokio::sync::mpsc::UnboundedSender<()>,
Expand Down
Loading

0 comments on commit 8470f2b

Please sign in to comment.