Skip to content

Commit

Permalink
refactor: move claims-related code into module
Browse files Browse the repository at this point in the history
Co-authored-by: Trong Huu Nguyen <[email protected]>
Co-authored-by: Tommy Trøen <[email protected]>
  • Loading branch information
3 people committed Nov 4, 2024
1 parent 2a32ed2 commit 28ba1cd
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 86 deletions.
78 changes: 78 additions & 0 deletions src/claims.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use serde::Serialize;
use jsonwebtoken as jwt;

pub trait Assertion {
fn new(token_endpoint: String, client_id: String, target: String) -> Self;
}

#[derive(Serialize)]
pub struct ClientAssertion {
exp: usize,
iat: usize,
nbf: usize,
jti: String,
sub: String,
iss: String,
aud: String,
}

#[derive(Serialize)]
pub struct JWTBearerAssertion {
exp: usize,
iat: usize,
nbf: usize,
jti: String,
scope: String,
iss: String,
aud: String,
}

impl Assertion for JWTBearerAssertion {
fn new(token_endpoint: String, client_id: String, target: String) -> Self {
let now = epoch_now_secs();
let jti = uuid::Uuid::new_v4();

Self {
exp: (now + 30) as usize,
iat: now as usize,
nbf: now as usize,
jti: jti.to_string(),
iss: client_id, // issuer of the token is the client itself
aud: token_endpoint, // audience of the token is the issuer
scope: target,
}
}
}

impl Assertion for ClientAssertion {
fn new(token_endpoint: String, client_id: String, _target: String) -> Self {
let now = epoch_now_secs();
let jti = uuid::Uuid::new_v4();

Self {
exp: (now + 30) as usize,
iat: now as usize,
nbf: now as usize,
jti: jti.to_string(),
iss: client_id.clone(), // issuer of the token is the client itself
aud: token_endpoint, // audience of the token is the issuer
sub: client_id,
}
}
}

fn epoch_now_secs() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs()
}

pub fn serialize<T: Serialize + Assertion>(
claims: T,
client_assertion_header: &jwt::Header,
key: &jwt::EncodingKey,
) -> Result<String, jsonwebtoken::errors::Error> {
jwt::encode(client_assertion_header, &claims, key)
}

9 changes: 5 additions & 4 deletions src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use log::error;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::RwLock;
use crate::claims::{ClientAssertion, JWTBearerAssertion};

#[axum::debug_handler]
pub async fn token(
Expand Down Expand Up @@ -100,10 +101,10 @@ impl Claims {
#[derive(Clone)]
pub struct HandlerState {
pub cfg: Config,
pub maskinporten: Arc<RwLock<Provider<MaskinportenTokenRequest, JWTBearerAssertionClaims>>>,
pub azure_ad_obo: Arc<RwLock<Provider<AzureADOnBehalfOfTokenRequest, ClientAssertionClaims>>>,
pub azure_ad_cc: Arc<RwLock<Provider<AzureADClientCredentialsTokenRequest, ClientAssertionClaims>>>,
pub token_x: Arc<RwLock<Provider<TokenXTokenRequest, ClientAssertionClaims>>>,
pub maskinporten: Arc<RwLock<Provider<MaskinportenTokenRequest, JWTBearerAssertion>>>,
pub azure_ad_obo: Arc<RwLock<Provider<AzureADOnBehalfOfTokenRequest, ClientAssertion>>>,
pub azure_ad_cc: Arc<RwLock<Provider<AzureADClientCredentialsTokenRequest, ClientAssertion>>>,
pub token_x: Arc<RwLock<Provider<TokenXTokenRequest, ClientAssertion>>>,
}

#[derive(Debug, Error)]
Expand Down
80 changes: 3 additions & 77 deletions src/identity_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use axum::Json;
use axum::response::IntoResponse;
use log::error;
use reqwest::StatusCode;
use crate::claims::{serialize, Assertion};
use crate::handlers::{ApiError};
use crate::types::{TokenExchangeRequest, TokenRequest, TokenResponse};

Expand Down Expand Up @@ -128,7 +129,7 @@ impl TokenRequestFactory for TokenXTokenRequest {
impl<T, U> Provider<T, U>
where
T: Serialize + TokenRequestFactory,
U: Serialize + ClientAssertion,
U: Serialize + Assertion,
{
pub fn new(
issuer: String,
Expand Down Expand Up @@ -199,7 +200,7 @@ where

fn create_assertion(&self, target: String) -> String {
let assertion = U::new(self.token_endpoint.clone(), self.client_id.clone(), target);
serialize_claims(assertion, &self.client_assertion_header, &self.private_jwk).unwrap()
serialize(assertion, &self.client_assertion_header, &self.private_jwk).unwrap()
}

pub async fn get_token(
Expand Down Expand Up @@ -228,78 +229,3 @@ where
self.get_token_with_config(token_request).await
}
}

pub trait ClientAssertion {
fn new(token_endpoint: String, client_id: String, target: String) -> Self;
}

#[derive(Serialize)]
pub struct ClientAssertionClaims {
exp: usize,
iat: usize,
nbf: usize,
jti: String,
sub: String,
iss: String,
aud: String,
}

#[derive(Serialize)]
pub struct JWTBearerAssertionClaims {
exp: usize,
iat: usize,
nbf: usize,
jti: String,
scope: String,
iss: String,
aud: String,
}

fn epoch_now_secs() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs()
}

fn serialize_claims<T: Serialize>(
claims: T,
client_assertion_header: &jwt::Header,
key: &jwt::EncodingKey,
) -> Result<String, jsonwebtoken::errors::Error> {
jwt::encode(client_assertion_header, &claims, key)
}

impl ClientAssertion for JWTBearerAssertionClaims {
fn new(token_endpoint: String, client_id: String, target: String) -> Self {
let now = epoch_now_secs();
let jti = uuid::Uuid::new_v4();

Self {
exp: (now + 30) as usize,
iat: now as usize,
nbf: now as usize,
jti: jti.to_string(),
iss: client_id, // issuer of the token is the client itself
aud: token_endpoint, // audience of the token is the issuer
scope: target,
}
}
}

impl ClientAssertion for ClientAssertionClaims {
fn new(token_endpoint: String, client_id: String, _target: String) -> Self {
let now = epoch_now_secs();
let jti = uuid::Uuid::new_v4();

Self {
exp: (now + 30) as usize,
iat: now as usize,
nbf: now as usize,
jti: jti.to_string(),
iss: client_id.clone(), // issuer of the token is the client itself
aud: token_endpoint, // audience of the token is the issuer
sub: client_id,
}
}
}
12 changes: 7 additions & 5 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod handlers;
pub mod identity_provider;
pub mod jwks;
pub mod types;
mod claims;

use crate::config::Config;
use axum::routing::post;
Expand All @@ -12,7 +13,8 @@ use log::{info, LevelFilter};
use std::sync::Arc;
use tokio::sync::RwLock;
use identity_provider::Provider;
use crate::identity_provider::{AzureADClientCredentialsTokenRequest, AzureADOnBehalfOfTokenRequest, ClientAssertionClaims, JWTBearerAssertionClaims, MaskinportenTokenRequest, TokenXTokenRequest};
use crate::claims::{ClientAssertion, JWTBearerAssertion};
use crate::identity_provider::{AzureADClientCredentialsTokenRequest, AzureADOnBehalfOfTokenRequest, MaskinportenTokenRequest, TokenXTokenRequest};

pub mod config {
use clap::Parser;
Expand Down Expand Up @@ -90,7 +92,7 @@ async fn main() {
let cfg = Config::parse();

info!("Fetch JWKS for Maskinporten...");
let maskinporten: Provider<MaskinportenTokenRequest, JWTBearerAssertionClaims> = Provider::new(
let maskinporten: Provider<MaskinportenTokenRequest, JWTBearerAssertion> = Provider::new(
cfg.maskinporten_issuer.clone(),
cfg.maskinporten_client_id.clone(),
cfg.maskinporten_token_endpoint.clone(),
Expand All @@ -101,7 +103,7 @@ async fn main() {
).unwrap();

info!("Fetch JWKS for Azure AD (on behalf of)...");
let azure_ad_obo: Provider<AzureADOnBehalfOfTokenRequest, ClientAssertionClaims> = Provider::new(
let azure_ad_obo: Provider<AzureADOnBehalfOfTokenRequest, ClientAssertion> = Provider::new(
cfg.azure_ad_issuer.clone(),
cfg.azure_ad_client_id.clone(),
cfg.azure_ad_token_endpoint.clone(),
Expand All @@ -112,7 +114,7 @@ async fn main() {
).unwrap();

info!("Fetch JWKS for Azure AD (client credentials)...");
let azure_ad_cc: Provider<AzureADClientCredentialsTokenRequest, ClientAssertionClaims> = Provider::new(
let azure_ad_cc: Provider<AzureADClientCredentialsTokenRequest, ClientAssertion> = Provider::new(
cfg.azure_ad_issuer.clone(),
cfg.azure_ad_client_id.clone(),
cfg.azure_ad_token_endpoint.clone(),
Expand All @@ -123,7 +125,7 @@ async fn main() {
).unwrap();

info!("Fetch JWKS for TokenX...");
let token_x: Provider<TokenXTokenRequest, ClientAssertionClaims> = Provider::new(
let token_x: Provider<TokenXTokenRequest, ClientAssertion> = Provider::new(
cfg.token_x_issuer.clone(),
cfg.token_x_client_id.clone(),
cfg.token_x_token_endpoint.clone(),
Expand Down

0 comments on commit 28ba1cd

Please sign in to comment.