Skip to content

Commit

Permalink
server: uniform whitespace handling in rpc calls (#1082)
Browse files Browse the repository at this point in the history
* replace FutureDriver with mpsc and tokio::task

* tokio spawn for calls

* refactor round trip for multiple calls

* cleanup

* cleanup

* ws server: reject calls > 127 starting whitespaces

* fix clippy
  • Loading branch information
niklasad1 authored Apr 17, 2023
1 parent 9c58d09 commit 7364c02
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 12 deletions.
13 changes: 13 additions & 0 deletions server/src/tests/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,19 @@ async fn whitespace_is_not_significant() {
let expected = r#"[{"jsonrpc":"2.0","result":3,"id":1}]"#;
assert_eq!(response.status, StatusCode::OK);
assert_eq!(response.body, expected);

// Up to 127 whitespace chars are accepted.
let req = format!("{}{}", " ".repeat(127), r#"{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#);
let response = http_request(req.into(), uri.clone()).await.unwrap();
let expected = r#"{"jsonrpc":"2.0","result":3,"id":1}"#;
assert_eq!(response.status, StatusCode::OK);
assert_eq!(response.body, expected);

// More than 127 whitespace chars are not accepted.
let req = format!("{}{}", " ".repeat(128), r#"{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#);
let response = http_request(req.into(), uri.clone()).await.unwrap();
assert_eq!(response.status, StatusCode::BAD_REQUEST);
assert_eq!(response.body, parse_error(Id::Null));
}

#[tokio::test]
Expand Down
12 changes: 12 additions & 0 deletions server/src/tests/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ async fn garbage_request_fails() {

#[tokio::test]
async fn whitespace_is_not_significant() {
init_logger();

let addr = server().await;
let mut client = WebSocketTestClient::new(addr).await.unwrap();

Expand All @@ -275,6 +277,16 @@ async fn whitespace_is_not_significant() {
let req = r#" [{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}]"#;
let response = client.send_request_text(req).await.unwrap();
assert_eq!(response, r#"[{"jsonrpc":"2.0","result":3,"id":1}]"#);

// Up to 127 whitespace chars are accepted.
let req = format!("{}{}", " ".repeat(127), r#"{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#);
let response = client.send_request_text(req).await.unwrap();
assert_eq!(response, ok_response(JsonValue::Number(3u32.into()), Id::Num(1)));

// More than 127 whitespace chars are not accepted.
let req = format!("{}{}", " ".repeat(128), r#"{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#);
let response = client.send_request_text(req).await.unwrap();
assert_eq!(response, parse_error(Id::Null));
}

#[tokio::test]
Expand Down
23 changes: 11 additions & 12 deletions server/src/transport/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub(crate) async fn send_ping(sender: &mut Sender) -> Result<(), Error> {

#[derive(Debug, Clone)]
pub(crate) struct Batch<'a, L: Logger> {
pub(crate) data: Vec<u8>,
pub(crate) data: &'a [u8],
pub(crate) call: CallData<'a, L>,
pub(crate) max_len: usize,
}
Expand All @@ -78,7 +78,7 @@ pub(crate) struct CallData<'a, L: Logger> {
pub(crate) async fn process_batch_request<L: Logger>(b: Batch<'_, L>) -> Option<String> {
let Batch { data, call, max_len } = b;

if let Ok(batch) = serde_json::from_slice::<Vec<&JsonRawValue>>(&data) {
if let Ok(batch) = serde_json::from_slice::<Vec<&JsonRawValue>>(data) {
if batch.len() > max_len {
return Some(batch_response_error(Id::Null, reject_too_big_batch_request(max_len)));
}
Expand Down Expand Up @@ -126,15 +126,15 @@ pub(crate) async fn process_batch_request<L: Logger>(b: Batch<'_, L>) -> Option<
}

pub(crate) async fn process_single_request<L: Logger>(
data: Vec<u8>,
data: &[u8],
call: CallData<'_, L>,
) -> Option<CallOrSubscription> {
if let Ok(req) = serde_json::from_slice::<Request>(&data) {
if let Ok(req) = serde_json::from_slice::<Request>(data) {
Some(execute_call_with_tracing(req, call).await)
} else if serde_json::from_slice::<Notif>(&data).is_ok() {
} else if serde_json::from_slice::<Notif>(data).is_ok() {
None
} else {
let (id, code) = prepare_error(&data);
let (id, code) = prepare_error(data);
Some(CallOrSubscription::Call(MethodResponse::error(id, ErrorObject::from(code))))
}
}
Expand Down Expand Up @@ -341,7 +341,6 @@ pub(crate) async fn background_task<L: Logger>(

logger.on_disconnect(remote_addr, TransportProtocol::WebSocket);
drop(conn);

result
}

Expand Down Expand Up @@ -490,10 +489,10 @@ async fn execute_unchecked_call<L: Logger>(params: ExecuteCallParams<L>) {
} = params;

let request_start = logger.on_request(TransportProtocol::WebSocket);
let first_non_whitespace = data.iter().find(|byte| !byte.is_ascii_whitespace());
let first_non_whitespace = data.iter().enumerate().take(128).find(|(_, byte)| !byte.is_ascii_whitespace());

match first_non_whitespace {
Some(b'{') => {
Some((start, b'{')) => {
let call_data = CallData {
conn_id: conn_id as usize,
bounded_subscriptions,
Expand All @@ -506,7 +505,7 @@ async fn execute_unchecked_call<L: Logger>(params: ExecuteCallParams<L>) {
request_start,
};

if let Some(rp) = process_single_request(data, call_data).await {
if let Some(rp) = process_single_request(&data[start..], call_data).await {
match rp {
CallOrSubscription::Subscription(r) => {
logger.on_response(&r.result, request_start, TransportProtocol::WebSocket);
Expand All @@ -519,7 +518,7 @@ async fn execute_unchecked_call<L: Logger>(params: ExecuteCallParams<L>) {
}
}
}
Some(b'[') => {
Some((start, b'[')) => {
let limit = match batch_requests_config {
BatchRequestConfig::Disabled => {
let response = MethodResponse::error(
Expand All @@ -546,7 +545,7 @@ async fn execute_unchecked_call<L: Logger>(params: ExecuteCallParams<L>) {
request_start,
};

let response = process_batch_request(Batch { data, call: call_data, max_len: limit }).await;
let response = process_batch_request(Batch { data: &data[start..], call: call_data, max_len: limit }).await;

if let Some(response) = response {
tx_log_from_str(&response, max_log_length);
Expand Down

0 comments on commit 7364c02

Please sign in to comment.