Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add SASL/OAUTHTOKEN support #253

Merged
merged 2 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ tokio = { version = "1.19", default-features = false, features = ["io-util", "ne
tokio-rustls = { version = "0.26", optional = true, default-features = false, features = ["logging", "ring", "tls12"] }
tracing = "0.1"
zstd = { version = "0.13", optional = true }
rsasl = { version = "2.1", default-features = false, features = ["config_builder", "provider", "plain", "scram-sha-2"]}
rsasl = { version = "2.1", default-features = false, features = ["config_builder", "provider", "plain", "scram-sha-2", "oauthbearer"]}

[dev-dependencies]
assert_matches = "1.5"
Expand Down
2 changes: 1 addition & 1 deletion src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use error::{Error, Result};

use self::{controller::ControllerClient, partition::UnknownTopicHandling};

pub use crate::connection::{Credentials, SaslConfig};
pub use crate::connection::{Credentials, OauthBearerCredentials, OauthCallback, SaslConfig};

#[derive(Debug, Error)]
pub enum ProduceError {
Expand Down
3 changes: 1 addition & 2 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ use crate::{
client::metadata_cache::MetadataCache,
};

pub use self::transport::Credentials;
pub use self::transport::SaslConfig;
pub use self::transport::TlsConfig;
pub use self::transport::{Credentials, OauthBearerCredentials, OauthCallback, SaslConfig};

mod topology;
mod transport;
Expand Down
2 changes: 1 addition & 1 deletion src/connection/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream, TlsConnector};

mod sasl;
pub use sasl::{Credentials, SaslConfig};
pub use sasl::{Credentials, OauthBearerCredentials, OauthCallback, SaslConfig};

#[cfg(feature = "transport-tls")]
pub type TlsConfig = Option<Arc<rustls::ClientConfig>>;
Expand Down
115 changes: 108 additions & 7 deletions src/connection/transport/sasl.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
use std::{fmt::Debug, sync::Arc};

use futures::future::BoxFuture;
use rsasl::{
callback::SessionCallback,
config::SASLConfig,
property::{AuthzId, OAuthBearerKV, OAuthBearerToken},
};

use crate::messenger::SaslError;

#[derive(Debug, Clone)]
pub enum SaslConfig {
/// SASL - PLAIN
Expand All @@ -15,6 +26,11 @@ pub enum SaslConfig {
/// # References
/// - <https://datatracker.ietf.org/doc/html/draft-melnikov-scram-sha-512-04>
ScramSha512(Credentials),
/// SASL - OAUTHBEARER
///
/// # References
/// - <https://datatracker.ietf.org/doc/html/rfc7628>
Oauthbearer(OauthBearerCredentials),
}

#[derive(Debug, Clone)]
Expand All @@ -30,19 +46,104 @@ impl Credentials {
}

impl SaslConfig {
pub(crate) fn credentials(&self) -> Credentials {
pub(crate) async fn get_sasl_config(&self) -> Result<Arc<SASLConfig>, SaslError> {
match self {
Self::Plain(credentials) => credentials.clone(),
Self::ScramSha256(credentials) => credentials.clone(),
Self::ScramSha512(credentials) => credentials.clone(),
Self::Plain(credentials)
| Self::ScramSha256(credentials)
| Self::ScramSha512(credentials) => Ok(SASLConfig::with_credentials(
None,
credentials.username.clone(),
credentials.password.clone(),
)?),
Self::Oauthbearer(credentials) => {
// Fetch the token first, since that's an async call.
let token = (*credentials.callback)()
.await
.map_err(SaslError::Callback)?;

struct OauthProvider {
authz_id: Option<String>,
bearer_kvs: Vec<(String, String)>,
token: String,
}

// Define a callback that is called while stepping through the SASL client
// to provide necessary data for oauth.
// Since this callback is synchronous, we fetch the token first. Generally
// speaking the SASL process should not take long enough for the token to
// expire, but we do need to check for token expiry each time we authenticate.
impl SessionCallback for OauthProvider {
fn callback(
&self,
_session_data: &rsasl::callback::SessionData,
_context: &rsasl::callback::Context<'_>,
request: &mut rsasl::callback::Request<'_>,
) -> Result<(), rsasl::prelude::SessionError> {
request
.satisfy::<OAuthBearerKV>(
&self
.bearer_kvs
.iter()
.map(|(k, v)| (k.as_str(), v.as_str()))
.collect::<Vec<_>>(),
)?
.satisfy::<OAuthBearerToken>(&self.token)?;
if let Some(authz_id) = &self.authz_id {
request.satisfy::<AuthzId>(authz_id)?;
}
Ok(())
}
}

Ok(SASLConfig::builder()
.with_default_mechanisms()
.with_callback(OauthProvider {
authz_id: credentials.authz_id.clone(),
bearer_kvs: credentials.bearer_kvs.clone(),
token,
})?)
}
}
}

pub(crate) fn mechanism(&self) -> &str {
use rsasl::mechanisms::*;
match self {
Self::Plain { .. } => "PLAIN",
Self::ScramSha256 { .. } => "SCRAM-SHA-256",
Self::ScramSha512 { .. } => "SCRAM-SHA-512",
Self::Plain { .. } => plain::PLAIN.mechanism.as_str(),
Self::ScramSha256 { .. } => scram::SCRAM_SHA256.mechanism.as_str(),
Self::ScramSha512 { .. } => scram::SCRAM_SHA512.mechanism.as_str(),
Self::Oauthbearer { .. } => oauthbearer::OAUTHBEARER.mechanism.as_str(),
}
}
}

type DynError = Box<dyn std::error::Error + Send + Sync>;

/// Callback for fetching an OAUTH token. This should cache tokens and only request a new token
/// when the old is close to expiring.
pub type OauthCallback =
Arc<dyn Fn() -> BoxFuture<'static, Result<String, DynError>> + Send + Sync>;

#[derive(Clone)]
pub struct OauthBearerCredentials {
/// Callback that should return a token that is valid and will remain valid for
/// long enough to complete authentication. This should cache the token and only request
/// a new one when the old is close to expiring.
/// The token must be on [RFC 6750](https://www.rfc-editor.org/rfc/rfc6750) format.
pub callback: OauthCallback,
/// ID of a user to impersonate. Can be left as `None` to authenticate using
/// the user for the token returned by `callback`.
pub authz_id: Option<String>,
/// Custom key-value pairs sent as part of the SASL request. Most normal usage
/// can let this be an empty list.
pub bearer_kvs: Vec<(String, String)>,
}

impl Debug for OauthBearerCredentials {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OauthBearerCredentials")
.field("authz_id", &self.authz_id)
.field("bearer_kvs", &self.bearer_kvs)
.finish_non_exhaustive()
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub mod client;
mod connection;

pub use connection::Error as ConnectionError;
pub use messenger::SaslError;

#[cfg(feature = "unstable-fuzzing")]
pub mod messenger;
Expand Down
27 changes: 15 additions & 12 deletions src/messenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ use std::{
use futures::future::BoxFuture;
use parking_lot::Mutex;
use rsasl::{
config::SASLConfig,
mechname::MechanismNameError,
prelude::{Mechname, SessionError},
prelude::{Mechname, SASLError, SessionError},
};
use thiserror::Error;
use tokio::{
Expand All @@ -28,6 +27,7 @@ use tokio::{
};
use tracing::{debug, info, warn};

use crate::protocol::{messages::ApiVersionsRequest, traits::ReadType};
use crate::{
backoff::ErrorOrThrottle,
protocol::{
Expand All @@ -48,10 +48,6 @@ use crate::{
client::SaslConfig,
protocol::{api_version::ApiVersionRange, primitives::CompactString},
};
use crate::{
connection::Credentials,
protocol::{messages::ApiVersionsRequest, traits::ReadType},
};

#[derive(Debug)]
struct Response {
Expand Down Expand Up @@ -205,6 +201,12 @@ pub enum SaslError {
#[error("Sasl session error: {0}")]
SaslSessionError(#[from] SessionError),

#[error("Invalid SASL config: {0}")]
InvalidConfig(#[from] SASLError),

#[error("Error in user defined callback: {0}")]
Callback(Box<dyn std::error::Error + Send + Sync>),

#[error("unsupported sasl mechanism")]
UnsupportedSaslMechanism,
}
Expand Down Expand Up @@ -581,8 +583,7 @@ where
let mechanism = config.mechanism();
let resp = self.sasl_handshake(mechanism).await?;

let Credentials { username, password } = config.credentials();
let config = SASLConfig::with_credentials(None, username, password).unwrap();
let config = config.get_sasl_config().await?;
let sasl = rsasl::prelude::SASLClient::new(config);
let raw_mechanisms = resp.mechanisms.0.unwrap_or_default();
let mechanisms = raw_mechanisms
Expand All @@ -604,12 +605,14 @@ where
loop {
let mut to_sent = Cursor::new(Vec::new());
let state = session.step(data_received.as_deref(), &mut to_sent)?;
if !state.is_running() {

if state.has_sent_message() {
let authentication_response =
self.sasl_authentication(to_sent.into_inner()).await?;
data_received = Some(authentication_response.auth_bytes.0);
} else {
break;
}

let authentication_response = self.sasl_authentication(to_sent.into_inner()).await?;
data_received = Some(authentication_response.auth_bytes.0);
}

Ok(())
Expand Down
4 changes: 2 additions & 2 deletions src/protocol/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ mod tests {

data.set_position(0);
let actual = data.read_message(0).await.unwrap();
assert_eq!(actual, vec![]);
assert!(actual.is_empty())
crepererum marked this conversation as resolved.
Show resolved Hide resolved
}

#[tokio::test]
Expand All @@ -172,6 +172,6 @@ mod tests {
client.write_message(&[]).await.unwrap();

let actual = server.read_message(0).await.unwrap();
assert_eq!(actual, vec![]);
assert!(actual.is_empty())
crepererum marked this conversation as resolved.
Show resolved Hide resolved
}
}
Loading