Skip to content

Commit

Permalink
feat: Add SASL/OAUTHTOKEN support
Browse files Browse the repository at this point in the history
  • Loading branch information
einarmo committed Jan 3, 2025
1 parent 8c20bc9 commit 493d791
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ tokio = { version = "1.19", default-features = false, features = ["io-util", "ne
tokio-rustls = { version = "0.26", optional = true, default-features = false, features = ["logging", "ring", "tls12"] }
tracing = "0.1"
zstd = { version = "0.13", optional = true }
rsasl = { version = "2.1", default-features = false, features = ["config_builder", "provider", "plain", "scram-sha-2"]}
rsasl = { version = "2.1", default-features = false, features = ["config_builder", "provider", "plain", "scram-sha-2", "oauthbearer"]}

[dev-dependencies]
assert_matches = "1.5"
Expand Down
2 changes: 1 addition & 1 deletion src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use error::{Error, Result};

use self::{controller::ControllerClient, partition::UnknownTopicHandling};

pub use crate::connection::{Credentials, SaslConfig};
pub use crate::connection::{Credentials, OauthBearerCredentials, OauthCallback, SaslConfig};

#[derive(Debug, Error)]
pub enum ProduceError {
Expand Down
3 changes: 1 addition & 2 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ use crate::{
client::metadata_cache::MetadataCache,
};

pub use self::transport::Credentials;
pub use self::transport::SaslConfig;
pub use self::transport::TlsConfig;
pub use self::transport::{Credentials, OauthBearerCredentials, OauthCallback, SaslConfig};

mod topology;
mod transport;
Expand Down
2 changes: 1 addition & 1 deletion src/connection/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream, TlsConnector};

mod sasl;
pub use sasl::{Credentials, SaslConfig};
pub use sasl::{Credentials, OauthBearerCredentials, OauthCallback, SaslConfig};

#[cfg(feature = "transport-tls")]
pub type TlsConfig = Option<Arc<rustls::ClientConfig>>;
Expand Down
115 changes: 108 additions & 7 deletions src/connection/transport/sasl.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
use std::{fmt::Debug, sync::Arc};

use futures::future::BoxFuture;
use rsasl::{
callback::SessionCallback,
config::SASLConfig,
property::{AuthzId, OAuthBearerKV, OAuthBearerToken},
};

use crate::messenger::SaslError;

#[derive(Debug, Clone)]
pub enum SaslConfig {
/// SASL - PLAIN
Expand All @@ -15,6 +26,11 @@ pub enum SaslConfig {
/// # References
/// - <https://datatracker.ietf.org/doc/html/draft-melnikov-scram-sha-512-04>
ScramSha512(Credentials),
/// SASL - OAUTHBEARER
///
/// # References
/// - <https://datatracker.ietf.org/doc/html/rfc7628>
Oauthbearer(OauthBearerCredentials),
}

#[derive(Debug, Clone)]
Expand All @@ -30,19 +46,104 @@ impl Credentials {
}

impl SaslConfig {
pub(crate) fn credentials(&self) -> Credentials {
pub(crate) async fn get_sasl_config(&self) -> Result<Arc<SASLConfig>, SaslError> {
match self {
Self::Plain(credentials) => credentials.clone(),
Self::ScramSha256(credentials) => credentials.clone(),
Self::ScramSha512(credentials) => credentials.clone(),
Self::Plain(credentials)
| Self::ScramSha256(credentials)
| Self::ScramSha512(credentials) => Ok(SASLConfig::with_credentials(
None,
credentials.username.clone(),
credentials.password.clone(),
)?),
Self::Oauthbearer(credentials) => {
// Fetch the token first, since that's an async call.
let token = (*credentials.callback)()
.await
.map_err(SaslError::Callback)?;

struct OauthProvider {
authz_id: Option<String>,
bearer_kvs: Vec<(String, String)>,
token: String,
}

// Define a callback that is called while stepping through the SASL client
// to provide necessary data for oauth.
// Since this callback is synchronous, we fetch the token first. Generally
// speaking the SASL process should not take long enough for the token to
// expire, but we do need to check for token expiry each time we authenticate.
impl SessionCallback for OauthProvider {
fn callback(
&self,
_session_data: &rsasl::callback::SessionData,
_context: &rsasl::callback::Context<'_>,
request: &mut rsasl::callback::Request<'_>,
) -> Result<(), rsasl::prelude::SessionError> {
request
.satisfy::<OAuthBearerKV>(
&self
.bearer_kvs
.iter()
.map(|(k, v)| (k.as_str(), v.as_str()))
.collect::<Vec<_>>(),
)?
.satisfy::<OAuthBearerToken>(&self.token)?;
if let Some(authz_id) = &self.authz_id {
request.satisfy::<AuthzId>(authz_id)?;
}
Ok(())
}
}

Ok(SASLConfig::builder()
.with_default_mechanisms()
.with_callback(OauthProvider {
authz_id: credentials.authz_id.clone(),
bearer_kvs: credentials.bearer_kvs.clone(),
token,
})?)
}
}
}

pub(crate) fn mechanism(&self) -> &str {
use rsasl::mechanisms::*;
match self {
Self::Plain { .. } => "PLAIN",
Self::ScramSha256 { .. } => "SCRAM-SHA-256",
Self::ScramSha512 { .. } => "SCRAM-SHA-512",
Self::Plain { .. } => plain::PLAIN.mechanism.as_str(),
Self::ScramSha256 { .. } => scram::SCRAM_SHA256.mechanism.as_str(),
Self::ScramSha512 { .. } => scram::SCRAM_SHA512.mechanism.as_str(),
Self::Oauthbearer { .. } => oauthbearer::OAUTHBEARER.mechanism.as_str(),
}
}
}

type DynError = Box<dyn std::error::Error + Send + Sync>;

/// Callback for fetching an OAUTH token. This should cache tokens and only request a new token
/// when the old is close to expiring.
pub type OauthCallback =
Arc<dyn Fn() -> BoxFuture<'static, Result<String, DynError>> + Send + Sync>;

#[derive(Clone)]
pub struct OauthBearerCredentials {
/// Callback that should return a token that is valid and will remain valid for
/// long enough to complete authentication. This should cache the token and only request
/// a new one when the old is close to expiring.
/// The token must be on [RFC 6750](https://www.rfc-editor.org/rfc/rfc6750) format.
pub callback: OauthCallback,
/// ID of a user to impersonate. Can be left as `None` to authenticate using
/// the user for the token returned by `callback`.
pub authz_id: Option<String>,
/// Custom key-value pairs sent as part of the SASL request. Most normal usage
/// can let this be an empty list.
pub bearer_kvs: Vec<(String, String)>,
}

impl Debug for OauthBearerCredentials {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OauthBearerCredentials")
.field("authz_id", &self.authz_id)
.field("bearer_kvs", &self.bearer_kvs)
.finish_non_exhaustive()
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub mod client;
mod connection;

pub use connection::Error as ConnectionError;
pub use messenger::SaslError;

#[cfg(feature = "unstable-fuzzing")]
pub mod messenger;
Expand Down
17 changes: 9 additions & 8 deletions src/messenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ use std::{
use futures::future::BoxFuture;
use parking_lot::Mutex;
use rsasl::{
config::SASLConfig,
mechname::MechanismNameError,
prelude::{Mechname, SessionError},
prelude::{Mechname, SASLError, SessionError},
};
use thiserror::Error;
use tokio::{
Expand All @@ -28,6 +27,7 @@ use tokio::{
};
use tracing::{debug, info, warn};

use crate::protocol::{messages::ApiVersionsRequest, traits::ReadType};
use crate::{
backoff::ErrorOrThrottle,
protocol::{
Expand All @@ -48,10 +48,6 @@ use crate::{
client::SaslConfig,
protocol::{api_version::ApiVersionRange, primitives::CompactString},
};
use crate::{
connection::Credentials,
protocol::{messages::ApiVersionsRequest, traits::ReadType},
};

#[derive(Debug)]
struct Response {
Expand Down Expand Up @@ -205,6 +201,12 @@ pub enum SaslError {
#[error("Sasl session error: {0}")]
SaslSessionError(#[from] SessionError),

#[error("Invalid SASL config: {0}")]
InvalidConfig(#[from] SASLError),

#[error("Error in user defined callback: {0}")]
Callback(Box<dyn std::error::Error + Send + Sync>),

#[error("unsupported sasl mechanism")]
UnsupportedSaslMechanism,
}
Expand Down Expand Up @@ -581,8 +583,7 @@ where
let mechanism = config.mechanism();
let resp = self.sasl_handshake(mechanism).await?;

let Credentials { username, password } = config.credentials();
let config = SASLConfig::with_credentials(None, username, password).unwrap();
let config = config.get_sasl_config().await?;
let sasl = rsasl::prelude::SASLClient::new(config);
let raw_mechanisms = resp.mechanisms.0.unwrap_or_default();
let mechanisms = raw_mechanisms
Expand Down
4 changes: 2 additions & 2 deletions src/protocol/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ mod tests {

data.set_position(0);
let actual = data.read_message(0).await.unwrap();
assert_eq!(actual, vec![]);
assert!(actual.is_empty())
}

#[tokio::test]
Expand All @@ -172,6 +172,6 @@ mod tests {
client.write_message(&[]).await.unwrap();

let actual = server.read_message(0).await.unwrap();
assert_eq!(actual, vec![]);
assert!(actual.is_empty())
}
}

0 comments on commit 493d791

Please sign in to comment.