Skip to content

Commit

Permalink
feat: refactor KeyPair and simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
zeitgeist committed Dec 21, 2023
1 parent 02499d8 commit f8aa96d
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 50 deletions.
43 changes: 22 additions & 21 deletions endpoint/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ use rsa::{pkcs8::LineEnding, RsaPrivateKey, RsaPublicKey};

const KEY_BITS: usize = 2048;

#[derive(Debug, Serialize, Deserialize)]
#[serde(crate = "rocket::serde")]
//#[derive(Debug)]
pub struct KeyPair {
pub priv_key: String,
pub pub_key: String,
pub pub_key: RsaPublicKey,
enc_key: EncodingKey,
dec_key: DecodingKey,
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -123,6 +123,8 @@ fn decode_basic_auth(raw_auth_info: &str) -> Option<OAuth2ClientCredentials> {
None
}

const BEARER_TOKEN_START: &str = "Bearer ";

#[rocket::async_trait]
impl<'r> FromRequest<'r> for UserToken {
type Error = status::Custom<UserTokenError>;
Expand All @@ -132,10 +134,10 @@ impl<'r> FromRequest<'r> for UserToken {
) -> request::Outcome<Self, status::Custom<UserTokenError>> {
if let Some(authen_header) = req.headers().get_one("Authorization") {
let authen_str = authen_header.to_string();
if authen_str.starts_with("Bearer") {
let token = authen_str[6..authen_str.len()].trim();
let pub_key = &req.rocket().state::<KeyPair>().unwrap().pub_key;
if let Ok(token_data) = decode_token(token.to_string(), pub_key.to_string()) {
if authen_str.starts_with(BEARER_TOKEN_START) {
let token = authen_str[BEARER_TOKEN_START.len()..authen_str.len()].trim();
let key_pair = req.rocket().state::<KeyPair>().unwrap();
if let Ok(token_data) = decode_token(token.to_string(), key_pair) {
return Outcome::Success(token_data.claims);
}
}
Expand Down Expand Up @@ -168,29 +170,28 @@ pub fn generate_keys() -> KeyPair {
.to_public_key_pem(LineEnding::default())
.expect("could not serialize public key");

KeyPair { priv_key, pub_key }
let dec_key = DecodingKey::from_rsa_pem(pub_key.as_bytes()).unwrap();
let enc_key = EncodingKey::from_rsa_pem(priv_key.as_bytes()).unwrap();

KeyPair {
pub_key: public_key,
enc_key,
dec_key,
}
}

fn decode_token(token: String, pub_key: String) -> Result<TokenData<UserToken>> {
fn decode_token(token: String, key_pair: &KeyPair) -> Result<TokenData<UserToken>> {
let mut v = Validation::new(Algorithm::RS256);
v.validate_exp = false;
v.required_spec_claims = HashSet::new();

jsonwebtoken::decode::<UserToken>(
&token,
&DecodingKey::from_rsa_pem(pub_key.as_bytes()).unwrap(),
&v,
)
jsonwebtoken::decode::<UserToken>(&token, &key_pair.dec_key, &v)
}

pub fn encode_token(u: &UserToken, priv_key: String) -> Result<String> {
pub fn encode_token(u: &UserToken, key_pair: &KeyPair) -> Result<String> {
let header = Header::new(Algorithm::RS256);

jsonwebtoken::encode(
&header,
u,
&EncodingKey::from_rsa_pem(priv_key.as_bytes()).unwrap(),
)
jsonwebtoken::encode(&header, u, &key_pair.enc_key)
}

impl<'a> OpenApiFromRequest<'a> for UserToken {
Expand Down
50 changes: 21 additions & 29 deletions endpoint/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ use rocket_okapi::{get_openapi_route, openapi, openapi_get_routes_spec};
use api_types::*;
use datamodel::{PfId, ProductFootprint};
use openid_conf::OpenIdConfiguration;
use rsa::pkcs8::{self};
use rsa::traits::PublicKeyParts;
use rsa::RsaPublicKey;
use sample_data::PCF_DEMO_DATA;
use Either::Left;

Expand Down Expand Up @@ -80,8 +78,7 @@ fn openid_configuration() -> Json<OpenIdConfiguration> {
/// endpoint to retrieve the Json Web Key Set to verify the token's signature
#[get("/2/jwks")]
fn jwks(state: &State<KeyPair>) -> Json<JwkSet> {
let pub_key: RsaPublicKey =
pkcs8::DecodePublicKey::from_public_key_pem(&state.pub_key).unwrap();
let pub_key = &state.pub_key;

let jwks = JwkSet {
keys: vec![Jwk {
Expand Down Expand Up @@ -113,14 +110,9 @@ fn oauth2_create_token(
body: Form<auth::OAuth2ClientCredentialsBody<'_>>,
state: &State<KeyPair>,
) -> Either<Json<auth::OAuth2TokenReply>, error::OAuth2ErrorMessage> {
println!("{state:?}");

if req.id == AUTH_USERNAME && req.secret == AUTH_PASSWORD {
let access_token = auth::encode_token(
&auth::UserToken { username: req.id },
state.priv_key.clone(),
)
.unwrap();
let access_token =
auth::encode_token(&auth::UserToken { username: req.id }, state).unwrap();

let reply = auth::OAuth2TokenReply {
access_token,
Expand Down Expand Up @@ -595,9 +587,9 @@ fn verify_token_signature_test() {
username: "hello".to_string(),
};

let server_priv_key: String = client.rocket().state::<KeyPair>().unwrap().priv_key.clone();
let key_pair = client.rocket().state::<KeyPair>().unwrap();

let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap();
let jwt = auth::encode_token(&token, key_pair).ok().unwrap();

let response = client.get("/2/jwks").dispatch();

Expand Down Expand Up @@ -628,9 +620,9 @@ fn get_list_test() {
username: "hello".to_string(),
};

let server_priv_key: String = client.rocket().state::<KeyPair>().unwrap().priv_key.clone();
let key_pair = client.rocket().state::<KeyPair>().unwrap();

let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap();
let jwt = auth::encode_token(&token, key_pair).ok().unwrap();
let bearer_token = format!("Bearer {jwt}");

let get_list_uri = "/2/footprints";
Expand Down Expand Up @@ -670,9 +662,9 @@ fn get_list_with_filter_eq_test() {
username: "hello".to_string(),
};

let server_priv_key: String = client.rocket().state::<KeyPair>().unwrap().priv_key.clone();
let key_pair = client.rocket().state::<KeyPair>().unwrap();

let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap();
let jwt = auth::encode_token(&token, key_pair).ok().unwrap();
let bearer_token = format!("Bearer {jwt}");

let get_list_with_limit_uri = "/2/footprints?$filter=pcf/geographyCountry+eq+'FR'";
Expand All @@ -696,9 +688,9 @@ fn get_list_with_filter_lt_test() {
username: "hello".to_string(),
};

let server_priv_key: String = client.rocket().state::<KeyPair>().unwrap().priv_key.clone();
let key_pair = client.rocket().state::<KeyPair>().unwrap();

let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap();
let jwt = auth::encode_token(&token, key_pair).ok().unwrap();

let bearer_token = format!("Bearer {jwt}");

Expand All @@ -723,9 +715,9 @@ fn get_list_with_filter_eq_and_lt_test() {
username: "hello".to_string(),
};

let server_priv_key: String = client.rocket().state::<KeyPair>().unwrap().priv_key.clone();
let key_pair = client.rocket().state::<KeyPair>().unwrap();

let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap();
let jwt = auth::encode_token(&token, key_pair).ok().unwrap();
let bearer_token = format!("Bearer {jwt}");

let get_list_with_limit_uri = "/2/footprints?$filter=(pcf/geographyCountry+eq+'FR')+and+(updated+lt+'2023-01-01T00:00:00.000Z')";
Expand All @@ -749,9 +741,9 @@ fn get_list_with_filter_any_test() {
username: "hello".to_string(),
};

let server_priv_key: String = client.rocket().state::<KeyPair>().unwrap().priv_key.clone();
let key_pair = client.rocket().state::<KeyPair>().unwrap();

let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap();
let jwt = auth::encode_token(&token, key_pair).ok().unwrap();
let bearer_token = format!("Bearer {jwt}");

let get_list_with_limit_uri =
Expand Down Expand Up @@ -792,9 +784,9 @@ fn get_list_with_limit_test() {
username: "hello".to_string(),
};

let server_priv_key: String = client.rocket().state::<KeyPair>().unwrap().priv_key.clone();
let key_pair = client.rocket().state::<KeyPair>().unwrap();

let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap();
let jwt = auth::encode_token(&token, key_pair).ok().unwrap();
let bearer_token = format!("Bearer {jwt}");

let get_list_with_limit_uri = "/2/footprints?limit=3";
Expand Down Expand Up @@ -863,9 +855,9 @@ fn post_events_test() {
username: "hello".to_string(),
};

let server_priv_key: String = client.rocket().state::<KeyPair>().unwrap().priv_key.clone();
let key_pair = client.rocket().state::<KeyPair>().unwrap();

let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap();
let jwt = auth::encode_token(&token, key_pair).ok().unwrap();
let bearer_token = format!("Bearer {jwt}");

let post_events_uri = "/2/events";
Expand Down Expand Up @@ -925,9 +917,9 @@ fn get_pcf_test() {
username: "hello".to_string(),
};

let server_priv_key: String = client.rocket().state::<KeyPair>().unwrap().priv_key.clone();
let key_pair = client.rocket().state::<KeyPair>().unwrap();

let jwt = auth::encode_token(&token, server_priv_key).ok().unwrap();
let jwt = auth::encode_token(&token, key_pair).ok().unwrap();
let bearer_token = format!("Bearer {jwt}");

// test auth
Expand Down

0 comments on commit f8aa96d

Please sign in to comment.