Skip to content

Commit

Permalink
Return ServiceError when handling push
Browse files Browse the repository at this point in the history
  • Loading branch information
threema-donat committed May 17, 2024
1 parent 2d80360 commit bda6b04
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 60 deletions.
35 changes: 14 additions & 21 deletions src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use a2::error::Error as A2Error;
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use reqwest::Error as ReqwestError;
use axum::response::{IntoResponse, Response};
use reqwest::{Error as ReqwestError, StatusCode};
use thiserror::Error;

#[derive(Error, Debug)]
Expand Down Expand Up @@ -54,25 +51,21 @@ pub enum InfluxdbError {
Other(String),
}

#[derive(Debug)]
pub struct ServiceError(anyhow::Error);
#[derive(Error, Debug)]
pub enum ServiceError {
#[error("Missing content type")]
MissingContentType,
#[error("Invalid content type: {0}")]
InvalidContentType(String),
#[error("Missing parameters")]
MissingParams,
#[error("Invalid parameters")]
InvalidParams,
}

// Tell axum how to convert `ServiceError` into a response.
impl IntoResponse for ServiceError {
fn into_response(self) -> Response {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Something went wrong: {}", self.0),
)
.into_response()
}
}

impl<E> From<E> for ServiceError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
(StatusCode::BAD_REQUEST, self.to_string()).into_response()
}
}
62 changes: 23 additions & 39 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ use crate::{
http_client,
influxdb::Influxdb,
push::{
apns, fcm,
fcm::FcmEndpointConfig,
apns,
fcm::{self, FcmEndpointConfig},
hms::{self, HmsContext, HmsEndpointConfig},
threema_gateway, ApnsToken, FcmToken, HmsToken, PushToken,
},
Expand Down Expand Up @@ -192,41 +192,25 @@ fn get_router(state: AppState) -> Router {
.with_state(state)
}

mod responses {
use super::*;

/// Return a generic "400 bad request" response.
pub fn bad_request(body: impl Into<Body>) -> axum::response::Response<Body> {
Response::builder()
.status(reqwest::StatusCode::BAD_REQUEST)
.header(CONTENT_TYPE, "text/plain")
.body(body.into())
.unwrap()
}
}

/// Main push handling entry point.
///
/// Handle a request, return a response.
async fn handle_push_request(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
body: axum::body::Bytes,
) -> Result<axum::response::Response, ServiceError> {
) -> Result<Response, ServiceError> {
// Verify content type
let content_type = headers.get(CONTENT_TYPE).and_then(|h| h.to_str().ok());
match content_type {
Some(ct) if ct.starts_with("application/x-www-form-urlencoded") => {}
Some(ct) => {
warn!("Bad request, invalid content type: {}", ct);
return Ok(responses::bad_request(format!(
"Invalid content type: {}",
ct
)));
return Err(ServiceError::InvalidContentType(ct.to_owned()));
}
None => {
warn!("Bad request, missing content type");
return Ok(responses::bad_request("Missing content type"));
return Err(ServiceError::MissingContentType);
}
}

Expand All @@ -235,7 +219,7 @@ async fn handle_push_request(

// Validate parameters
if parsed.is_empty() {
return Ok(responses::bad_request("Invalid or missing parameters"));
return Err(ServiceError::MissingParams);
}

/// Iterate over parameters and find first matching key.
Expand All @@ -257,7 +241,7 @@ async fn handle_push_request(
Some(v) => v,
None => {
warn!("Missing request parameter: {}", $name);
return Ok(responses::bad_request("Invalid or missing parameters"));
return Err(ServiceError::MissingParams);
}
}
};
Expand Down Expand Up @@ -292,29 +276,29 @@ async fn handle_push_request(
let identity = find_or_bad_request!("identity").to_string();
if identity.len() != 8 || identity.starts_with('*') {
warn!("Got push request with invalid identity: {}", identity);
return Ok(responses::bad_request("Invalid or missing parameters"));
return Err(ServiceError::InvalidParams);
}
let public_key_hex = find_or_bad_request!("public_key");
if public_key_hex.len() != 64 {
warn!(
"Got push request with invalid public key length: {}",
public_key_hex.len()
);
return Ok(responses::bad_request("Invalid or missing parameters"));
return Err(ServiceError::InvalidParams);
}
let Ok(public_key) = HEXLOWER_PERMISSIVE.decode(public_key_hex.as_bytes()) else {
warn!(
"Got push request with invalid public key: {}",
public_key_hex
);
return Ok(responses::bad_request("Invalid or missing parameters"));
return Err(ServiceError::InvalidParams);
};
let Ok(public_key) = public_key.try_into() else {
warn!(
"Got push request with invalid public key: {}",
public_key_hex
);
return Ok(responses::bad_request("Invalid or missing parameters"));
return Err(ServiceError::InvalidParams);
};
PushToken::ThreemaGateway {
identity,
Expand All @@ -323,7 +307,7 @@ async fn handle_push_request(
}
other => {
warn!("Got push request with invalid token type: {}", other);
return Ok(responses::bad_request("Invalid or missing parameters"));
return Err(ServiceError::InvalidParams);
}
};
let session_public_key = find_or_bad_request!("session");
Expand All @@ -332,7 +316,7 @@ async fn handle_push_request(
Ok(parsed) => parsed,
Err(e) => {
warn!("Got push request with invalid version param: {:?}", e);
return Ok(responses::bad_request("Invalid or missing parameters"));
return Err(ServiceError::InvalidParams);
}
};
let affiliation = find!("affiliation").map(Cow::as_ref);
Expand All @@ -341,7 +325,7 @@ async fn handle_push_request(
// Parsing as u32 succeeded
Some(Ok(val)) => val,
// Parsing as u32 failed
Some(Err(_)) => return Ok(responses::bad_request("Invalid or missing parameters")),
Some(Err(_)) => return Err(ServiceError::InvalidParams),
// No TTL value was specified
None => TTL_DEFAULT,
};
Expand All @@ -355,11 +339,11 @@ async fn handle_push_request(
let endpoint = Some(match endpoint_str.as_ref() {
"p" => Endpoint::Production,
"s" => Endpoint::Sandbox,
_ => return Ok(responses::bad_request("Invalid or missing parameters")),
_ => return Err(ServiceError::InvalidParams),
});
let collapse_id = match collapse_key.as_deref().map(CollapseId::new) {
Some(Ok(id)) => Some(id),
Some(Err(_)) => return Ok(responses::bad_request("Invalid or missing parameters")),
Some(Err(_)) => return Err(ServiceError::InvalidParams),
None => None,
};
(bundle_id, endpoint, collapse_id)
Expand Down Expand Up @@ -646,7 +630,7 @@ mod tests {

assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body = get_body(resp).await;
assert_eq!(&body, "Invalid or missing parameters");
assert_eq!(&body, "Missing parameters");
}

/// A request with missing parameters should result in a HTTP 400 response.
Expand All @@ -662,7 +646,7 @@ mod tests {

assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body = get_body(resp).await;
assert_eq!(&body, "Invalid or missing parameters");
assert_eq!(&body, "Missing parameters");
}

/// A request with missing parameters should result in a HTTP 400 response.
Expand All @@ -678,7 +662,7 @@ mod tests {

assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body = get_body(resp).await;
assert_eq!(&body, "Invalid or missing parameters");
assert_eq!(&body, "Missing parameters");
}

/// A request with bad parameters should result in a HTTP 400 response.
Expand All @@ -697,10 +681,10 @@ mod tests {

assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body = get_body(resp).await;
assert_eq!(&body, "Invalid or missing parameters");
assert_eq!(&body, "Invalid parameters");
}

/// A request wit missing parameters should result in a HTTP 400 response.
/// A request with missing parameters should result in a HTTP 400 response.
#[tokio::test]
async fn test_bad_token_type() {
let app = get_test_app();
Expand All @@ -713,7 +697,7 @@ mod tests {

assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body = get_body(resp).await;
assert_eq!(&body, "Invalid or missing parameters");
assert_eq!(&body, "Invalid parameters");
}

/// A request with invalid TTL parameter should result in a HTTP 400 response.
Expand All @@ -732,7 +716,7 @@ mod tests {

assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body = get_body(resp).await;
assert_eq!(&body, "Invalid or missing parameters");
assert_eq!(&body, "Invalid parameters");
}

#[tokio::test]
Expand Down

0 comments on commit bda6b04

Please sign in to comment.