Skip to content

Commit

Permalink
Handle auth info
Browse files Browse the repository at this point in the history
  • Loading branch information
moubctez committed Nov 7, 2024
1 parent 255a3cd commit a5ffb68
Show file tree
Hide file tree
Showing 11 changed files with 650 additions and 386 deletions.
882 changes: 532 additions & 350 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
config.protoc_arg("--experimental_allow_proto3_optional");
// Make all messages serde-serializable
config.type_attribute(".", "#[derive(serde::Serialize,serde::Deserialize)]");
tonic_build::configure().compile_with_config(
tonic_build::configure().compile_protos_with_config(
config,
&["proto/core/proxy.proto"],
&["proto/core"],
Expand Down
16 changes: 8 additions & 8 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{fs, io::Error as IoError};
use std::{fs::read_to_string, path::PathBuf};

use clap::Parser;
use log::LevelFilter;
Expand Down Expand Up @@ -38,13 +38,13 @@ pub struct Config {
/// Configuration file path
#[arg(long = "config", short)]
#[serde(skip)]
config_path: Option<std::path::PathBuf>,
config_path: Option<PathBuf>,
}

#[derive(thiserror::Error, Debug)]
pub enum ConfigError {
#[error("Failed to read config file")]
IoError(#[from] IoError),
IoError(#[from] std::io::Error),
#[error("Failed to parse config file")]
ParseError(#[from] toml::de::Error),
}
Expand All @@ -55,11 +55,11 @@ pub fn get_config() -> Result<Config, ConfigError> {

// load config from file if one was specified
if let Some(config_path) = cli_config.config_path {
info!("Reading configuration from config file: {config_path:?}");
let config_toml = fs::read_to_string(config_path)?;
info!("Reading configuration from file: {config_path:?}");
let config_toml = read_to_string(config_path)?;
let file_config: Config = toml::from_str(&config_toml)?;
return Ok(file_config);
Ok(file_config)
} else {
Ok(cli_config)
}

Ok(cli_config)
}
12 changes: 7 additions & 5 deletions src/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub(crate) struct ProxyServer {
impl ProxyServer {
#[must_use]
/// Create new `ProxyServer`.
pub fn new() -> Self {
pub(crate) fn new() -> Self {
Self {
current_id: Arc::new(AtomicU64::new(1)),
clients: Arc::new(Mutex::new(HashMap::new())),
Expand All @@ -42,7 +42,7 @@ impl ProxyServer {
/// Sends message to the other side of RPC, with given `payload` and optional `device_info`.
/// Returns `tokio::sync::oneshot::Reveicer` to let the caller await reply.
#[instrument(name = "send_grpc_message", level = "debug", skip(self))]
pub fn send(
pub(crate) fn send(
&self,
payload: Option<core_request::Payload>,
device_info: Option<DeviceInfo>,
Expand All @@ -64,9 +64,11 @@ impl ProxyServer {
self.connected.store(true, Ordering::Relaxed);
Ok(rx)
} else {
error!("Defguard core is disconnected");
error!("Defguard Core is not connected");
self.connected.store(false, Ordering::Relaxed);
Err(ApiError::Unexpected("Defguard core is disconnected".into()))
Err(ApiError::Unexpected(
"Defguard Core is not connected".into(),
))
}
}
}
Expand Down Expand Up @@ -96,7 +98,7 @@ impl proxy_server::Proxy for ProxyServer {
error!("Failed to determine client address for request: {request:?}");
return Err(Status::internal("Failed to determine client address"));
};
info!("Defguard core RPC client connected from: {address}");
info!("Defguard Core gRPC client connected from: {address}");

let (tx, rx) = mpsc::unbounded_channel();
self.clients.lock().unwrap().insert(address, tx);
Expand Down
10 changes: 5 additions & 5 deletions src/handlers/enrollment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
},
};

pub fn router() -> Router<AppState> {
pub(crate) fn router() -> Router<AppState> {
Router::new()
.route("/start", post(start_enrollment_process))
.route("/activate_user", post(activate_user))
Expand All @@ -21,7 +21,7 @@ pub fn router() -> Router<AppState> {
}

#[instrument(level = "debug", skip(state))]
pub async fn start_enrollment_process(
async fn start_enrollment_process(
State(state): State<AppState>,
mut private_cookies: PrivateCookieJar,
Json(req): Json<EnrollmentStartRequest>,
Expand Down Expand Up @@ -60,7 +60,7 @@ pub async fn start_enrollment_process(
}

#[instrument(level = "debug", skip(state))]
pub async fn activate_user(
async fn activate_user(
State(state): State<AppState>,
device_info: Option<DeviceInfo>,
mut private_cookies: PrivateCookieJar,
Expand Down Expand Up @@ -95,7 +95,7 @@ pub async fn activate_user(
}

#[instrument(level = "debug", skip(state))]
pub async fn create_device(
async fn create_device(
State(state): State<AppState>,
device_info: Option<DeviceInfo>,
private_cookies: PrivateCookieJar,
Expand Down Expand Up @@ -123,7 +123,7 @@ pub async fn create_device(
}

#[instrument(level = "debug", skip(state))]
pub async fn get_network_info(
async fn get_network_info(
State(state): State<AppState>,
private_cookies: PrivateCookieJar,
Json(mut req): Json<ExistingDevice>,
Expand Down
9 changes: 5 additions & 4 deletions src/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ use crate::{error::ApiError, proto::core_response::Payload};

pub(crate) mod desktop_client_mfa;
pub(crate) mod enrollment;
pub(crate) mod openid_login;
pub(crate) mod password_reset;
pub(crate) mod polling;

// timeout in seconds for awaiting core response
const CORE_RESPONSE_TIMEOUT: u64 = 5;
// Timeout for awaiting response from Defguard Core.
const CORE_RESPONSE_TIMEOUT: Duration = Duration::from_secs(5);

#[tonic::async_trait]
impl<S> FromRequestParts<S> for DeviceInfo
Expand Down Expand Up @@ -49,7 +50,7 @@ where
/// Waits for core response with a given timeout and returns the response payload.
async fn get_core_response(rx: Receiver<Payload>) -> Result<Payload, ApiError> {
debug!("Fetching core response...");
if let Ok(core_response) = timeout(Duration::from_secs(CORE_RESPONSE_TIMEOUT), rx).await {
if let Ok(core_response) = timeout(CORE_RESPONSE_TIMEOUT, rx).await {
debug!("Got gRPC response from Defguard core: {core_response:?}");
if let Ok(Payload::CoreError(core_error)) = core_response {
error!(
Expand All @@ -61,7 +62,7 @@ async fn get_core_response(rx: Receiver<Payload>) -> Result<Payload, ApiError> {
core_response
.map_err(|err| ApiError::Unexpected(format!("Failed to receive core response: {err}")))
} else {
error!("Did not receive core response within {CORE_RESPONSE_TIMEOUT} seconds");
error!("Did not receive core response within {CORE_RESPONSE_TIMEOUT:?}");
Err(ApiError::CoreTimeout)
}
}
75 changes: 75 additions & 0 deletions src/handlers/openid_login.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use axum::{extract::State, routing::get, Json, Router};
use axum_extra::extract::{
cookie::{Cookie, SameSite},
PrivateCookieJar,
};
use serde::Serialize;
use time::Duration;

use crate::{
error::ApiError,
handlers::get_core_response,
http::AppState,
proto::{core_request, core_response},
};

const COOKIE_MAX_AGE: Duration = Duration::days(1);
static CSRF_COOKIE_NAME: &str = "csrf";
static NONCE_COOKIE_NAME: &str = "nonce";

pub(crate) fn router() -> Router<AppState> {
Router::new().route("/auth_info", get(auth_info))
}

#[derive(Serialize)]
struct AuthInfo {
url: String,
}

impl AuthInfo {
#[must_use]
fn new(url: String) -> Self {
Self { url }
}
}

/// Request external OAuth2/OpenID provider details from Defguard Core.
#[instrument(level = "debug", skip(state))]
async fn auth_info(
State(state): State<AppState>,
private_cookies: PrivateCookieJar,
) -> Result<(PrivateCookieJar, Json<AuthInfo>), ApiError> {
debug!("Getting auth info for OAuth2/OpenID login");

let rx = state
.grpc_server
.send(Some(core_request::Payload::AuthInfo(())), None)?;
let payload = get_core_response(rx).await?;
if let core_response::Payload::AuthInfo(response) = payload {
debug!("Got auth info {response:?}");

let nonce_cookie = Cookie::build((NONCE_COOKIE_NAME, response.nonce))
// .domain(cookie_domain)
// .path("/api/v1/openid/callback")
.http_only(true)
.same_site(SameSite::Strict)
.secure(true)
.max_age(COOKIE_MAX_AGE)
.build();
let csrf_cookie = Cookie::build((CSRF_COOKIE_NAME, response.csrf_token))
// .domain(cookie_domain)
// .path("/api/v1/openid/callback")
.http_only(true)
.same_site(SameSite::Strict)
.secure(true)
.max_age(COOKIE_MAX_AGE)
.build();
let private_cookies = private_cookies.add(nonce_cookie).add(csrf_cookie);

let auth_info = AuthInfo::new(response.url);
Ok((private_cookies, Json(auth_info)))
} else {
error!("Received invalid gRPC response type: {payload:#?}");
Err(ApiError::InvalidResponseType)
}
}
8 changes: 4 additions & 4 deletions src/handlers/password_reset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ use crate::{
},
};

pub fn router() -> Router<AppState> {
pub(crate) fn router() -> Router<AppState> {
Router::new()
.route("/request", post(request_password_reset))
.route("/start", post(start_password_reset))
.route("/reset", post(reset_password))
}

#[instrument(level = "debug", skip(state))]
pub async fn request_password_reset(
async fn request_password_reset(
State(state): State<AppState>,
device_info: Option<DeviceInfo>,
Json(req): Json<PasswordResetInitializeRequest>,
Expand All @@ -42,7 +42,7 @@ pub async fn request_password_reset(
}

#[instrument(level = "debug", skip(state))]
pub async fn start_password_reset(
async fn start_password_reset(
State(state): State<AppState>,
device_info: Option<DeviceInfo>,
mut private_cookies: PrivateCookieJar,
Expand Down Expand Up @@ -77,7 +77,7 @@ pub async fn start_password_reset(
}

#[instrument(level = "debug", skip(state))]
pub async fn reset_password(
async fn reset_password(
State(state): State<AppState>,
device_info: Option<DeviceInfo>,
mut private_cookies: PrivateCookieJar,
Expand Down
2 changes: 1 addition & 1 deletion src/handlers/polling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
};

#[instrument(level = "debug", skip(state))]
pub async fn info(
pub(crate) async fn info(
State(state): State<AppState>,
Json(req): Json<InstanceInfoRequest>,
) -> Result<Json<InstanceInfoResponse>, ApiError> {
Expand Down
18 changes: 11 additions & 7 deletions src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::{
config::Config,
error::ApiError,
grpc::ProxyServer,
handlers::{desktop_client_mfa, enrollment, password_reset, polling},
handlers::{desktop_client_mfa, enrollment, openid_login, password_reset, polling},
proto::proxy_server,
};

Expand Down Expand Up @@ -71,7 +71,10 @@ async fn healthcheckgrpc(State(state): State<AppState>) -> (StatusCode, &'static
if state.grpc_server.connected.load(Ordering::Relaxed) {
(StatusCode::OK, "Alive")
} else {
(StatusCode::SERVICE_UNAVAILABLE, "Not connected to core")
(
StatusCode::SERVICE_UNAVAILABLE,
"Not connected to Defguard Core",
)
}
}

Expand All @@ -95,7 +98,7 @@ fn get_client_addr(request: &Request<Body>) -> String {
}

pub async fn run_server(config: Config) -> anyhow::Result<()> {
info!("Starting Defguard proxy server");
info!("Starting Defguard Proxy server");
debug!("Using config: {config:?}");

let mut tasks = JoinSet::new();
Expand All @@ -111,7 +114,7 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
key: Key::generate(),
};

// read gRPC TLS cert and key
// Read gRPC TLS certificate and key.
debug!("Configuring certificates for gRPC");
let grpc_cert = config
.grpc_cert
Expand All @@ -121,7 +124,7 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
.grpc_key
.as_ref()
.and_then(|path| read_to_string(path).ok());
debug!("Configured certificates for gRPC, cert: {grpc_cert:?}");
debug!("Configured gRPC certificate: {grpc_cert:?}");

// Start gRPC server.
debug!("Spawning gRPC server");
Expand Down Expand Up @@ -159,7 +162,7 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
tokio::spawn(async move {
loop {
tokio::time::sleep(RATE_LIMITER_CLEANUP_PERIOD).await;
tracing::debug!(
debug!(
"Cleaning-up rate limiter storage, current size: {}",
governor_limiter.len()
);
Expand Down Expand Up @@ -188,6 +191,7 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
.nest("/enrollment", enrollment::router())
.nest("/password-reset", password_reset::router())
.nest("/client-mfa", desktop_client_mfa::router())
.nest("/openid", openid_login::router())
.route("/poll", post(polling::info))
.route("/health", get(healthcheck))
.route("/health-grpc", get(healthcheckgrpc))
Expand Down Expand Up @@ -231,7 +235,7 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
.context("Error running HTTP server")
});

info!("Defguard proxy server initialization complete");
info!("Defguard Proxy server initialization complete");
while let Some(Ok(result)) = tasks.join_next().await {
result?;
}
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ pub(crate) mod proto {
#[macro_use]
extern crate tracing;

pub const VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), "-", env!("VERGEN_GIT_SHA"));
pub static VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), "-", env!("VERGEN_GIT_SHA"));

0 comments on commit a5ffb68

Please sign in to comment.