Skip to content

Commit

Permalink
Merge pull request #15 from Angelin01/hot-reload-certs
Browse files Browse the repository at this point in the history
Implement TLS hot reloading
  • Loading branch information
Angelin01 authored Nov 28, 2023
2 parents 2635a67 + 404006f commit 014469b
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 6 deletions.
152 changes: 152 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ figment = { version = "0.10.12", features = ["env", "yaml"] }
axum = { version = "0.6.20", default-features = false, features = ["json", "tokio"] }
axum-server = { version = "0.5.1", features = ["tls-rustls"] }
tokio = { version = "1.34.0", features = ["macros", "rt-multi-thread", "signal"] }
notify = { version = "6.1.1", default-features = false }
notify-debouncer-full = { version = "0.3.1", default-features = false}
anyhow = "1.0.75"
thiserror = "1.0.50"

Expand Down
4 changes: 2 additions & 2 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ pub struct ServerConfig {
bind_addr: IpAddr,
port: u16,
pub insecure: bool,
cert: PathBuf,
key: PathBuf,
pub cert: PathBuf,
pub key: PathBuf,
}

impl ServerConfig {
Expand Down
74 changes: 70 additions & 4 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
mod health_handler;

use std::path::{Path};
use std::time::Duration;
use anyhow::Result;
use anyhow::{Result, Error};
use axum::Router;
use axum::routing::get;
use axum_server::Handle;
use axum_server::tls_rustls::RustlsConfig;
use notify::{RecommendedWatcher, RecursiveMode, Watcher};
use notify_debouncer_full::{DebouncedEvent, Debouncer, FileIdMap, new_debouncer};
use tokio::signal;
use tokio::sync::mpsc::Receiver;
use health_handler::health_handler;
use crate::config::Config;
use crate::config::{Config};

fn build_app() -> Router {
Router::new()
Expand All @@ -32,10 +37,18 @@ pub async fn serve(config: &Config) -> Result<()> {
}
else {
let tls_config = config.server.tls_config().await?;
axum_server::bind_rustls(addr, tls_config)
let hot_reload = tokio::spawn(hot_reload_tls(tls_config.clone(), config.server.cert.clone(), config.server.key.clone()));

let result = axum_server::bind_rustls(addr, tls_config)
.handle(shutdown_handle)
.serve(service)
.await?;
.await;

hot_reload.abort();

if let Err(e) = result {
return Err(Error::new(e));
}
}

Ok(())
Expand Down Expand Up @@ -72,3 +85,56 @@ async fn graceful_shutdown(handle: Handle) {
handle.graceful_shutdown(Some(Duration::from_secs(30)));
}
}


async fn hot_reload_tls(tls_config: RustlsConfig, cert_path: impl AsRef<Path>, key_path: impl AsRef<Path>) -> Result<()> {
let (mut debouncer, mut event_rx) = tls_watcher().await?;

debouncer.watcher().watch(cert_path.as_ref(), RecursiveMode::NonRecursive)?;
debouncer.watcher().watch(key_path.as_ref(), RecursiveMode::NonRecursive)?;

while let Some(events) = event_rx.recv().await {
let should_reload = events.iter().any(|e| {
let kind = &e.kind;
kind.is_modify() || kind.is_create() || kind.is_remove()
});

if should_reload {
match tls_config.reload_from_pem_file(&cert_path, &key_path).await {
Ok(_) => println!("Reloaded TLS certificates"),
Err(e) => println!("Failed reloading TLS certificates: {e}"),
};
}
}

Ok(())
}

async fn tls_watcher() -> Result<(Debouncer<RecommendedWatcher, FileIdMap>, Receiver<Vec<DebouncedEvent>>)> {
let (tx, rx) = tokio::sync::mpsc::channel(1);
// We're using this since async closures are unstable and I'd rather avoid nightly
let current_thread = tokio::runtime::Handle::current();

let debouncer = new_debouncer(
Duration::from_secs(1),
None,
move |res| {
let tx = tx.clone();

match res {
Ok(value) => {
current_thread.spawn(async move {
if let Err(e) = tx.send(value).await {
println!("Failed sending TLS reload event, error: {e}");
}
});
}
Err(err) => {
println!("Errored while watching TLS, errors: {err:?}");
}
};
}
)?;

Ok((debouncer, rx))
}

0 comments on commit 014469b

Please sign in to comment.