Skip to content

Commit

Permalink
Implement Clone for ChainedInterceptor
Browse files Browse the repository at this point in the history
  • Loading branch information
c-thiel committed Mar 26, 2024
1 parent 92bec35 commit 333eaf1
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 28 deletions.
30 changes: 22 additions & 8 deletions src/api/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
//! specific interceptors for authentication.

use std::error::Error;
use std::sync::Arc;

use custom_error::custom_error;
use tonic::codegen::InterceptedService;
use tonic::service::Interceptor;
use tonic::transport::{Channel, Endpoint};
use tonic::{Request, Status};

use crate::api::interceptors::{AccessTokenInterceptor, ServiceAccountInterceptor};
use crate::api::interceptors::{
AccessTokenInterceptor, InterceptorImmutable, ServiceAccountInterceptor,
};
use crate::api::zitadel::oidc::v2beta::oidc_service_client::OidcServiceClient;
use crate::api::zitadel::org::v2beta::organization_service_client::OrganizationServiceClient;
use crate::api::zitadel::session::v2beta::session_service_client::SessionServiceClient;
Expand Down Expand Up @@ -45,8 +48,9 @@ enum AuthType {
/// would create its own return type. With this interceptor, the return type
/// stays the same and is not dependent on the authentication type used.
/// The builder can always return `Client<InterceptedService<Channel, ChainedInterceptor>>`.
#[derive(Clone)]
pub struct ChainedInterceptor {
interceptors: Vec<Box<dyn Interceptor + Send>>,
interceptors: Vec<Arc<dyn InterceptorImmutable + Send + Sync>>,
}

impl ChainedInterceptor {
Expand All @@ -56,22 +60,32 @@ impl ChainedInterceptor {
}
}

pub(crate) fn add_interceptor(mut self, interceptor: Box<dyn Interceptor + Send>) -> Self {
pub(crate) fn add_interceptor(
mut self,
interceptor: Arc<dyn InterceptorImmutable + Send + Sync>,
) -> Self {
self.interceptors.push(interceptor);
self
}
}

impl Interceptor for ChainedInterceptor {
fn call(&mut self, request: Request<()>) -> Result<Request<()>, Status> {
impl InterceptorImmutable for ChainedInterceptor {
fn call(&self, request: Request<()>) -> Result<Request<()>, Status> {
let mut request = request;
for interceptor in &mut self.interceptors {
for interceptor in &self.interceptors {
let interceptor = Arc::clone(interceptor);
request = interceptor.call(request)?;
}
Ok(request)
}
}

impl Interceptor for ChainedInterceptor {
fn call(&mut self, request: Request<()>) -> Result<Request<()>, Status> {
InterceptorImmutable::call(self, request)
}
}

/// A builder to create configured gRPC clients for ZITADEL API access.
/// The builder accepts the api endpoint and (depending on activated features)
/// an authentication method.
Expand Down Expand Up @@ -317,11 +331,11 @@ impl ClientBuilder {
match &self.auth_type {
AuthType::AccessToken(token) => {
interceptor =
interceptor.add_interceptor(Box::new(AccessTokenInterceptor::new(token)));
interceptor.add_interceptor(Arc::new(AccessTokenInterceptor::new(token)));
}
AuthType::ServiceAccount(service_account, auth_options) => {
interceptor =
interceptor.add_interceptor(Box::new(ServiceAccountInterceptor::new(
interceptor.add_interceptor(Arc::new(ServiceAccountInterceptor::new(
&self.api_endpoint,
service_account,
auth_options.clone(),
Expand Down
85 changes: 65 additions & 20 deletions src/api/interceptors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
//! interceptors is to authenticate the clients to ZITADEL with
//! provided credentials.

use std::sync::{Arc, RwLock};
use std::thread;

use tokio::runtime::Builder;
use tonic::{service::Interceptor, Request, Status};

use crate::credentials::{AuthenticationOptions, ServiceAccount};

pub(crate) trait InterceptorImmutable {
/// Intercept a request before it is sent, optionally cancelling it.
fn call(&self, request: Request<()>) -> Result<Request<()>, Status>;
}

/// Simple gRPC `Interceptor` that attaches a given access token to any request
/// a client sends. The token is attached with the `Bearer` auth-scheme.
///
Expand Down Expand Up @@ -41,8 +47,9 @@ use crate::credentials::{AuthenticationOptions, ServiceAccount};
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct AccessTokenInterceptor {
access_token: String,
access_token: std::sync::Arc<String>,
}

impl AccessTokenInterceptor {
Expand All @@ -55,13 +62,13 @@ impl AccessTokenInterceptor {
/// and the corresponding [`authenticate`][crate::credentials::ServiceAccount::authenticate] method
pub fn new(token: &str) -> Self {
AccessTokenInterceptor {
access_token: token.to_string(),
access_token: Arc::new(token.to_string()),
}
}
}

impl Interceptor for AccessTokenInterceptor {
fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
impl InterceptorImmutable for AccessTokenInterceptor {
fn call(&self, mut request: Request<()>) -> Result<Request<()>, Status> {
let meta = request.metadata_mut();
if !meta.contains_key("authorization") {
meta.insert(
Expand All @@ -73,6 +80,12 @@ impl Interceptor for AccessTokenInterceptor {
}
}

impl Interceptor for AccessTokenInterceptor {
fn call(&mut self, request: Request<()>) -> Result<Request<()>, Status> {
InterceptorImmutable::call(self, request)
}
}

/// gRPC `Interceptor` that authenticates the service client calls
/// with the given [`ServiceAccount`][crate::credentials::ServiceAccount].
///
Expand Down Expand Up @@ -125,7 +138,13 @@ impl Interceptor for AccessTokenInterceptor {
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct ServiceAccountInterceptor {
inner: Arc<RwLock<ServiceAccountInterceptorInner>>,
}

#[derive(Clone)]
pub(crate) struct ServiceAccountInterceptorInner {
audience: String,
service_account: ServiceAccount,
auth_options: AuthenticationOptions,
Expand All @@ -143,22 +162,37 @@ impl ServiceAccountInterceptor {
service_account: &ServiceAccount,
auth_options: Option<AuthenticationOptions>,
) -> Self {
Self {
let inner = ServiceAccountInterceptorInner {
audience: audience.to_string(),
service_account: service_account.clone(),
auth_options: auth_options.unwrap_or_default(),
token: None,
token_expiry: None,
};

ServiceAccountInterceptor {
inner: Arc::new(RwLock::new(inner)),
}
}
}

impl Interceptor for ServiceAccountInterceptor {
fn call(&mut self, mut request: tonic::Request<()>) -> Result<tonic::Request<()>, Status> {
impl InterceptorImmutable for ServiceAccountInterceptor {
fn call(&self, mut request: tonic::Request<()>) -> Result<tonic::Request<()>, Status> {
use std::ops::Deref;
let ServiceAccountInterceptorInner {
audience,
service_account,
auth_options,
token,
token_expiry,
// We unwrap the RWLock to propagate the error if any
// thread panics and the lock is poisoned
} = self.inner.read().unwrap().deref().clone();

let meta = request.metadata_mut();
if !meta.contains_key("authorization") {
if let Some(token) = &self.token {
if let Some(expiry) = self.token_expiry {
if let Some(token) = &token {
if let Some(expiry) = token_expiry {
if expiry > time::OffsetDateTime::now_utc() {
meta.insert(
"authorization",
Expand All @@ -170,14 +204,11 @@ impl Interceptor for ServiceAccountInterceptor {
}
}

let aud = self.audience.clone();
let auth = self.auth_options.clone();
let sa = self.service_account.clone();

let token = thread::spawn(move || {
let rt = Builder::new_multi_thread().enable_all().build().unwrap();
rt.block_on(async {
sa.authenticate_with_options(&aud, &auth)
service_account
.authenticate_with_options(&audience, &auth_options)
.await
.map_err(|e| Status::internal(e.to_string()))
})
Expand All @@ -187,8 +218,10 @@ impl Interceptor for ServiceAccountInterceptor {
.join()
.map_err(|_| Status::internal("could not fetch token"))??;

self.token = Some(token.clone());
self.token_expiry = Some(time::OffsetDateTime::now_utc() + time::Duration::minutes(59));
let mut inner = self.inner.write().unwrap();
inner.token = Some(token.clone());
inner.token_expiry =
Some(time::OffsetDateTime::now_utc() + time::Duration::minutes(59));

meta.insert(
"authorization",
Expand All @@ -200,6 +233,12 @@ impl Interceptor for ServiceAccountInterceptor {
}
}

impl Interceptor for ServiceAccountInterceptor {
fn call(&mut self, request: tonic::Request<()>) -> Result<tonic::Request<()>, Status> {
InterceptorImmutable::call(self, request)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -333,20 +372,26 @@ mod tests {
let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap();
let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None);

Check failure on line 373 in src/api/interceptors.rs

View workflow job for this annotation

GitHub Actions / Linting and Testing

variable does not need to be mutable
interceptor.call(Request::new(())).unwrap();
let token = interceptor.token.clone().unwrap();
let token = interceptor.inner.read().unwrap().token.clone().unwrap();
interceptor.call(Request::new(())).unwrap();

assert_eq!(token, interceptor.token.unwrap());
assert_eq!(
token,
interceptor.inner.read().unwrap().token.clone().unwrap()
);
}

#[tokio::test]
async fn service_account_interceptor_should_respect_token_lifetime_async() {
let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap();
let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None);
interceptor.call(Request::new(())).unwrap();
let token = interceptor.token.clone().unwrap();
let token = interceptor.inner.read().unwrap().token.clone().unwrap();
interceptor.call(Request::new(())).unwrap();

assert_eq!(token, interceptor.token.unwrap());
assert_eq!(
token,
interceptor.inner.read().unwrap().token.clone().unwrap()
);
}
}

0 comments on commit 333eaf1

Please sign in to comment.