diff --git a/endpoint/src/auth.rs b/endpoint/src/auth.rs index 3477541..1a19f34 100644 --- a/endpoint/src/auth.rs +++ b/endpoint/src/auth.rs @@ -35,11 +35,11 @@ use rsa::{pkcs8::LineEnding, RsaPrivateKey, RsaPublicKey}; const KEY_BITS: usize = 2048; -#[derive(Debug, Serialize, Deserialize)] -#[serde(crate = "rocket::serde")] +//#[derive(Debug)] pub struct KeyPair { - pub priv_key: String, - pub pub_key: String, + pub pub_key: RsaPublicKey, + enc_key: EncodingKey, + dec_key: DecodingKey, } #[derive(Debug, Serialize, Deserialize)] @@ -123,6 +123,8 @@ fn decode_basic_auth(raw_auth_info: &str) -> Option { None } +const BEARER_TOKEN_START: &str = "Bearer "; + #[rocket::async_trait] impl<'r> FromRequest<'r> for UserToken { type Error = status::Custom; @@ -132,10 +134,10 @@ impl<'r> FromRequest<'r> for UserToken { ) -> request::Outcome> { if let Some(authen_header) = req.headers().get_one("Authorization") { let authen_str = authen_header.to_string(); - if authen_str.starts_with("Bearer") { - let token = authen_str[6..authen_str.len()].trim(); - let pub_key = &req.rocket().state::().unwrap().pub_key; - if let Ok(token_data) = decode_token(token.to_string(), pub_key.to_string()) { + if authen_str.starts_with(BEARER_TOKEN_START) { + let token = authen_str[BEARER_TOKEN_START.len()..authen_str.len()].trim(); + let key_pair = req.rocket().state::().unwrap(); + if let Ok(token_data) = decode_token(token.to_string(), key_pair) { return Outcome::Success(token_data.claims); } } @@ -168,29 +170,28 @@ pub fn generate_keys() -> KeyPair { .to_public_key_pem(LineEnding::default()) .expect("could not serialize public key"); - KeyPair { priv_key, pub_key } + let dec_key = DecodingKey::from_rsa_pem(pub_key.as_bytes()).unwrap(); + let enc_key = EncodingKey::from_rsa_pem(priv_key.as_bytes()).unwrap(); + + KeyPair { + pub_key: public_key, + enc_key, + dec_key, + } } -fn decode_token(token: String, pub_key: String) -> Result> { +fn decode_token(token: String, key_pair: &KeyPair) -> Result> { let mut v = Validation::new(Algorithm::RS256); v.validate_exp = false; v.required_spec_claims = HashSet::new(); - jsonwebtoken::decode::( - &token, - &DecodingKey::from_rsa_pem(pub_key.as_bytes()).unwrap(), - &v, - ) + jsonwebtoken::decode::(&token, &key_pair.dec_key, &v) } -pub fn encode_token(u: &UserToken, priv_key: String) -> Result { +pub fn encode_token(u: &UserToken, key_pair: &KeyPair) -> Result { let header = Header::new(Algorithm::RS256); - jsonwebtoken::encode( - &header, - u, - &EncodingKey::from_rsa_pem(priv_key.as_bytes()).unwrap(), - ) + jsonwebtoken::encode(&header, u, &key_pair.enc_key) } impl<'a> OpenApiFromRequest<'a> for UserToken { diff --git a/endpoint/src/main.rs b/endpoint/src/main.rs index 1cef8c8..ff808d0 100644 --- a/endpoint/src/main.rs +++ b/endpoint/src/main.rs @@ -43,9 +43,7 @@ use rocket_okapi::{get_openapi_route, openapi, openapi_get_routes_spec}; use api_types::*; use datamodel::{PfId, ProductFootprint}; use openid_conf::OpenIdConfiguration; -use rsa::pkcs8::{self}; use rsa::traits::PublicKeyParts; -use rsa::RsaPublicKey; use sample_data::PCF_DEMO_DATA; use Either::Left; @@ -80,8 +78,7 @@ fn openid_configuration() -> Json { /// endpoint to retrieve the Json Web Key Set to verify the token's signature #[get("/2/jwks")] fn jwks(state: &State) -> Json { - let pub_key: RsaPublicKey = - pkcs8::DecodePublicKey::from_public_key_pem(&state.pub_key).unwrap(); + let pub_key = &state.pub_key; let jwks = JwkSet { keys: vec![Jwk { @@ -113,14 +110,9 @@ fn oauth2_create_token( body: Form>, state: &State, ) -> Either, error::OAuth2ErrorMessage> { - println!("{state:?}"); - if req.id == AUTH_USERNAME && req.secret == AUTH_PASSWORD { - let access_token = auth::encode_token( - &auth::UserToken { username: req.id }, - state.priv_key.clone(), - ) - .unwrap(); + let access_token = + auth::encode_token(&auth::UserToken { username: req.id }, state).unwrap(); let reply = auth::OAuth2TokenReply { access_token, @@ -595,9 +587,9 @@ fn verify_token_signature_test() { username: "hello".to_string(), }; - let server_priv_key: String = client.rocket().state::().unwrap().priv_key.clone(); + let key_pair = client.rocket().state::().unwrap(); - let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap(); + let jwt = auth::encode_token(&token, key_pair).ok().unwrap(); let response = client.get("/2/jwks").dispatch(); @@ -628,9 +620,9 @@ fn get_list_test() { username: "hello".to_string(), }; - let server_priv_key: String = client.rocket().state::().unwrap().priv_key.clone(); + let key_pair = client.rocket().state::().unwrap(); - let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap(); + let jwt = auth::encode_token(&token, key_pair).ok().unwrap(); let bearer_token = format!("Bearer {jwt}"); let get_list_uri = "/2/footprints"; @@ -670,9 +662,9 @@ fn get_list_with_filter_eq_test() { username: "hello".to_string(), }; - let server_priv_key: String = client.rocket().state::().unwrap().priv_key.clone(); + let key_pair = client.rocket().state::().unwrap(); - let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap(); + let jwt = auth::encode_token(&token, key_pair).ok().unwrap(); let bearer_token = format!("Bearer {jwt}"); let get_list_with_limit_uri = "/2/footprints?$filter=pcf/geographyCountry+eq+'FR'"; @@ -696,9 +688,9 @@ fn get_list_with_filter_lt_test() { username: "hello".to_string(), }; - let server_priv_key: String = client.rocket().state::().unwrap().priv_key.clone(); + let key_pair = client.rocket().state::().unwrap(); - let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap(); + let jwt = auth::encode_token(&token, key_pair).ok().unwrap(); let bearer_token = format!("Bearer {jwt}"); @@ -723,9 +715,9 @@ fn get_list_with_filter_eq_and_lt_test() { username: "hello".to_string(), }; - let server_priv_key: String = client.rocket().state::().unwrap().priv_key.clone(); + let key_pair = client.rocket().state::().unwrap(); - let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap(); + let jwt = auth::encode_token(&token, key_pair).ok().unwrap(); let bearer_token = format!("Bearer {jwt}"); let get_list_with_limit_uri = "/2/footprints?$filter=(pcf/geographyCountry+eq+'FR')+and+(updated+lt+'2023-01-01T00:00:00.000Z')"; @@ -749,9 +741,9 @@ fn get_list_with_filter_any_test() { username: "hello".to_string(), }; - let server_priv_key: String = client.rocket().state::().unwrap().priv_key.clone(); + let key_pair = client.rocket().state::().unwrap(); - let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap(); + let jwt = auth::encode_token(&token, key_pair).ok().unwrap(); let bearer_token = format!("Bearer {jwt}"); let get_list_with_limit_uri = @@ -792,9 +784,9 @@ fn get_list_with_limit_test() { username: "hello".to_string(), }; - let server_priv_key: String = client.rocket().state::().unwrap().priv_key.clone(); + let key_pair = client.rocket().state::().unwrap(); - let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap(); + let jwt = auth::encode_token(&token, key_pair).ok().unwrap(); let bearer_token = format!("Bearer {jwt}"); let get_list_with_limit_uri = "/2/footprints?limit=3"; @@ -863,9 +855,9 @@ fn post_events_test() { username: "hello".to_string(), }; - let server_priv_key: String = client.rocket().state::().unwrap().priv_key.clone(); + let key_pair = client.rocket().state::().unwrap(); - let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap(); + let jwt = auth::encode_token(&token, key_pair).ok().unwrap(); let bearer_token = format!("Bearer {jwt}"); let post_events_uri = "/2/events"; @@ -925,9 +917,9 @@ fn get_pcf_test() { username: "hello".to_string(), }; - let server_priv_key: String = client.rocket().state::().unwrap().priv_key.clone(); + let key_pair = client.rocket().state::().unwrap(); - let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap(); + let jwt = auth::encode_token(&token, key_pair).ok().unwrap(); let bearer_token = format!("Bearer {jwt}"); // test auth