From 62396276daefe88157ed1c39d5f9a311576e0748 Mon Sep 17 00:00:00 2001 From: Einar Omang Date: Thu, 2 Jan 2025 14:55:07 +0100 Subject: [PATCH] feat: Add SASL/OAUTHTOKEN support --- Cargo.toml | 3 +- src/client/mod.rs | 2 +- src/connection.rs | 3 +- src/connection/transport.rs | 2 +- src/connection/transport/sasl.rs | 114 +++++++++++++++++++++++++++++-- src/lib.rs | 1 + src/messenger.rs | 17 ++--- src/protocol/frame.rs | 4 +- 8 files changed, 124 insertions(+), 22 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1b5a150..d9af1cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,13 +30,14 @@ lz4 = { version = "1.23", optional = true } parking_lot = "0.12" rand = "0.8" rustls = { version = "0.23", optional = true, default-features = false, features = ["logging", "ring", "std", "tls12"] } +serde = "^1.0.210" snap = { version = "1", optional = true } thiserror = "1.0" tokio = { version = "1.19", default-features = false, features = ["io-util", "net", "rt", "sync", "time", "macros"] } tokio-rustls = { version = "0.26", optional = true, default-features = false, features = ["logging", "ring", "tls12"] } tracing = "0.1" zstd = { version = "0.13", optional = true } -rsasl = { version = "2.1", default-features = false, features = ["config_builder", "provider", "plain", "scram-sha-2"]} +rsasl = { version = "2.1", default-features = false, features = ["config_builder", "provider", "plain", "scram-sha-2", "oauthbearer"]} [dev-dependencies] assert_matches = "1.5" diff --git a/src/client/mod.rs b/src/client/mod.rs index 214f3dd..91e43dc 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -22,7 +22,7 @@ use error::{Error, Result}; use self::{controller::ControllerClient, partition::UnknownTopicHandling}; -pub use crate::connection::{Credentials, SaslConfig}; +pub use crate::connection::{Credentials, OauthBearerCredentials, OauthCallback, SaslConfig}; #[derive(Debug, Error)] pub enum ProduceError { diff --git a/src/connection.rs b/src/connection.rs index b90a6ba..8068872 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -20,9 +20,8 @@ use crate::{ client::metadata_cache::MetadataCache, }; -pub use self::transport::Credentials; -pub use self::transport::SaslConfig; pub use self::transport::TlsConfig; +pub use self::transport::{Credentials, OauthBearerCredentials, OauthCallback, SaslConfig}; mod topology; mod transport; diff --git a/src/connection/transport.rs b/src/connection/transport.rs index 7119a19..e1f6ffc 100644 --- a/src/connection/transport.rs +++ b/src/connection/transport.rs @@ -11,7 +11,7 @@ use tokio::net::TcpStream; use tokio_rustls::{client::TlsStream, TlsConnector}; mod sasl; -pub use sasl::{Credentials, SaslConfig}; +pub use sasl::{Credentials, OauthBearerCredentials, OauthCallback, SaslConfig}; #[cfg(feature = "transport-tls")] pub type TlsConfig = Option>; diff --git a/src/connection/transport/sasl.rs b/src/connection/transport/sasl.rs index c266b58..5338184 100644 --- a/src/connection/transport/sasl.rs +++ b/src/connection/transport/sasl.rs @@ -1,3 +1,14 @@ +use std::{fmt::Debug, sync::Arc}; + +use futures::future::BoxFuture; +use rsasl::{ + callback::SessionCallback, + config::SASLConfig, + property::{AuthzId, OAuthBearerKV, OAuthBearerToken}, +}; + +use crate::messenger::SaslError; + #[derive(Debug, Clone)] pub enum SaslConfig { /// SASL - PLAIN @@ -15,6 +26,11 @@ pub enum SaslConfig { /// # References /// - ScramSha512(Credentials), + /// SASL - OAUTHBEARER + /// + /// # References + /// - + Oauthbearer(OauthBearerCredentials), } #[derive(Debug, Clone)] @@ -30,19 +46,103 @@ impl Credentials { } impl SaslConfig { - pub(crate) fn credentials(&self) -> Credentials { + pub(crate) async fn get_sasl_config(&self) -> Result, SaslError> { match self { - Self::Plain(credentials) => credentials.clone(), - Self::ScramSha256(credentials) => credentials.clone(), - Self::ScramSha512(credentials) => credentials.clone(), + Self::Plain(credentials) + | Self::ScramSha256(credentials) + | Self::ScramSha512(credentials) => Ok(SASLConfig::with_credentials( + None, + credentials.username.clone(), + credentials.password.clone(), + )?), + Self::Oauthbearer(credentials) => { + // Fetch the token first, since that's an async call. + let token = (*credentials.callback)() + .await + .map_err(SaslError::Callback)?; + + struct OauthProvider { + authz_id: Option, + bearer_kvs: Vec<(String, String)>, + token: String, + } + + // Define a callback that is called while stepping through the SASL client + // to provide necessary data for oauth. + // Since this callback is synchronous, we fetch the token first. Generally + // speaking the SASL process should not take long enough for the token to + // expire, but we do need to check for token expiry each time we authenticate. + impl SessionCallback for OauthProvider { + fn callback( + &self, + _session_data: &rsasl::callback::SessionData, + _context: &rsasl::callback::Context<'_>, + request: &mut rsasl::callback::Request<'_>, + ) -> Result<(), rsasl::prelude::SessionError> { + request + .satisfy::( + &self + .bearer_kvs + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect::>(), + )? + .satisfy::(&self.token)?; + if let Some(authz_id) = &self.authz_id { + request.satisfy::(authz_id)?; + } + Ok(()) + } + } + + Ok(SASLConfig::builder() + .with_default_mechanisms() + .with_callback(OauthProvider { + authz_id: credentials.authz_id.clone(), + bearer_kvs: credentials.bearer_kvs.clone(), + token, + })?) + } } } pub(crate) fn mechanism(&self) -> &str { + use rsasl::mechanisms::*; match self { - Self::Plain { .. } => "PLAIN", - Self::ScramSha256 { .. } => "SCRAM-SHA-256", - Self::ScramSha512 { .. } => "SCRAM-SHA-512", + Self::Plain { .. } => plain::PLAIN.mechanism.as_str(), + Self::ScramSha256 { .. } => scram::SCRAM_SHA256.mechanism.as_str(), + Self::ScramSha512 { .. } => scram::SCRAM_SHA512.mechanism.as_str(), + Self::Oauthbearer { .. } => oauthbearer::OAUTHBEARER.mechanism.as_str(), } } } + +type DynError = Box; + +/// Callback for fetching an OAUTH token. This can and should cache tokens and only request a new token +/// when the old is close to expiring. +pub type OauthCallback = + Arc BoxFuture<'static, Result> + Send + Sync>; + +#[derive(Clone)] +pub struct OauthBearerCredentials { + /// Callback that should return a token that is valid and will remain valid for + /// long enough to complete authentication. + /// The token must be on [RFC 6750](https://www.rfc-editor.org/rfc/rfc6750) format. + pub callback: OauthCallback, + /// ID of a user to impersonate. Can be left as `None` to authenticate using + /// the user for the token returned by `callback`. + pub authz_id: Option, + /// Custom key-value pairs sent as part of the SASL request. Most normal usage + /// can let this be an empty list. + pub bearer_kvs: Vec<(String, String)>, +} + +impl Debug for OauthBearerCredentials { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OauthBearerCredentials") + .field("authz_id", &self.authz_id) + .field("bearer_kvs", &self.bearer_kvs) + .finish_non_exhaustive() + } +} diff --git a/src/lib.rs b/src/lib.rs index e8522cb..1c93810 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,6 +28,7 @@ pub mod client; mod connection; pub use connection::Error as ConnectionError; +pub use messenger::SaslError; #[cfg(feature = "unstable-fuzzing")] pub mod messenger; diff --git a/src/messenger.rs b/src/messenger.rs index efad01a..aeb38c5 100644 --- a/src/messenger.rs +++ b/src/messenger.rs @@ -13,9 +13,8 @@ use std::{ use futures::future::BoxFuture; use parking_lot::Mutex; use rsasl::{ - config::SASLConfig, mechname::MechanismNameError, - prelude::{Mechname, SessionError}, + prelude::{Mechname, SASLError, SessionError}, }; use thiserror::Error; use tokio::{ @@ -28,6 +27,7 @@ use tokio::{ }; use tracing::{debug, info, warn}; +use crate::protocol::{messages::ApiVersionsRequest, traits::ReadType}; use crate::{ backoff::ErrorOrThrottle, protocol::{ @@ -48,10 +48,6 @@ use crate::{ client::SaslConfig, protocol::{api_version::ApiVersionRange, primitives::CompactString}, }; -use crate::{ - connection::Credentials, - protocol::{messages::ApiVersionsRequest, traits::ReadType}, -}; #[derive(Debug)] struct Response { @@ -205,6 +201,12 @@ pub enum SaslError { #[error("Sasl session error: {0}")] SaslSessionError(#[from] SessionError), + #[error("Invalid SASL config: {0}")] + InvalidConfig(#[from] SASLError), + + #[error("Error in user defined callback: {0}")] + Callback(Box), + #[error("unsupported sasl mechanism")] UnsupportedSaslMechanism, } @@ -581,8 +583,7 @@ where let mechanism = config.mechanism(); let resp = self.sasl_handshake(mechanism).await?; - let Credentials { username, password } = config.credentials(); - let config = SASLConfig::with_credentials(None, username, password).unwrap(); + let config = config.get_sasl_config().await?; let sasl = rsasl::prelude::SASLClient::new(config); let raw_mechanisms = resp.mechanisms.0.unwrap_or_default(); let mechanisms = raw_mechanisms diff --git a/src/protocol/frame.rs b/src/protocol/frame.rs index 479b3c4..47997f8 100644 --- a/src/protocol/frame.rs +++ b/src/protocol/frame.rs @@ -163,7 +163,7 @@ mod tests { data.set_position(0); let actual = data.read_message(0).await.unwrap(); - assert_eq!(actual, vec![]); + assert!(actual.is_empty()) } #[tokio::test] @@ -172,6 +172,6 @@ mod tests { client.write_message(&[]).await.unwrap(); let actual = server.read_message(0).await.unwrap(); - assert_eq!(actual, vec![]); + assert!(actual.is_empty()) } }