Skip to content

Commit

Permalink
fix: replace host header by default
Browse files Browse the repository at this point in the history
This also aligns the behavior with the WS proxy behavior and adds header
support for WS proxy requests too.

It also adds the XFP header.

Closes: #879
  • Loading branch information
ctron committed Oct 4, 2024
1 parent 91ea563 commit c793b91
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 52 deletions.
182 changes: 135 additions & 47 deletions src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use axum::{
use bytes::BytesMut;
use futures_util::{sink::SinkExt, stream::StreamExt, TryStreamExt};
use http::{header::HOST, HeaderMap};
use reqwest::header::HeaderValue;
use std::sync::Arc;
use tokio_tungstenite::{
connect_async,
Expand All @@ -27,9 +26,16 @@ use tower_http::trace::TraceLayer;
///
/// Refer: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-Host
const X_FORWARDED_HOST: &str = "x-forwarded-host";
/// The X-Forwarded-Proto (XFP) header is a de-facto standard header for identifying the protocol
/// (HTTP or HTTPS) that a client used to connect to your proxy or load balancer.
///
/// Refer: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-Proto
const X_FORWARDED_PROTO: &str = "x-forwarded-proto";

/// A handler used for proxying HTTP requests to a backend.
pub(crate) struct ProxyHandlerHttp {
/// The protocol the proxy bound to
proto: String,
/// The client to use for proxy logic.
client: reqwest::Client,
/// The URL of the backend to which requests are to be proxied.
Expand Down Expand Up @@ -79,38 +85,76 @@ fn make_outbound_uri(backend: &Uri, request: &Uri) -> anyhow::Result<Uri> {
}

fn make_outbound_request(
inbound_proto: &str,
outbound_uri: &Uri,
headers: HeaderMap,
) -> anyhow::Result<http::Request<()>> {
let mut request = http::Request::builder().uri(outbound_uri.to_string());
method: http::Method,
original_headers: HeaderMap,
override_headers: HeaderMap,
) -> anyhow::Result<http::request::Builder> {
let mut request = http::Request::builder()
.uri(outbound_uri.to_string())
.method(method);

// get the host header value from the outbound request

let Some(outbound_host) = outbound_uri.authority().map(|authority| authority.host()) else {
anyhow::bail!("No host found in outbound URI");
};

for (maybe_key, val) in headers {
if let Some(key) = maybe_key {
// forward all inbound headers

for key in original_headers.keys() {
let values = original_headers
.get_all(key)
.iter()
.cloned()
.collect::<Vec<_>>();

for value in values {
if key == HOST {
// Except for the host header, which we replace with the backend host value.
// We also provide the original information in the XFH, XFP headers.
request = request.header(HOST, outbound_host);
request = request.header(X_FORWARDED_HOST, val);
request = request.header(X_FORWARDED_HOST, value);
request = request.header(X_FORWARDED_PROTO, inbound_proto);
} else {
request = request.header(key, value);
}
}
}

// Apply all header overrides.
// There is no special handling for any header (like host), as we leave manual intervention to
// the user.

if let Some(headers) = request.headers_mut() {
for (key, value) in override_headers {
let Some(key) = key else { continue };

if value.is_empty() {
// if the header value is empty, remove the header
headers.remove(key);
} else {
request = request.header(key, val);
// otherwise, replace header
headers.insert(key, value);
}
}
}

request.body(()).context("Failed to build outbound request")
Ok(request)
}

impl ProxyHandlerHttp {
/// Construct a new instance.
pub fn new(
proto: String,
client: reqwest::Client,
backend: Uri,
request_headers: HeaderMap,
rewrite: Option<String>,
) -> Arc<Self> {
Arc::new(Self {
proto,
client,
backend,
request_headers,
Expand Down Expand Up @@ -143,18 +187,16 @@ impl ProxyHandlerHttp {
) -> ServerResult<Response<Body>> {
// Construct the outbound URI & build a new request to be sent to the proxy backend.
let outbound_uri = make_outbound_uri(&state.backend, req.uri())?;

let mut headers = req.headers().clone();
for (header_name, header_value) in state.request_headers.clone() {
if let Some(name) = header_name {
headers.insert(name, header_value);
}
}

let mut outbound_req = state
.client
.request(req.method().clone(), outbound_uri.to_string())
.headers(headers.clone())
let outbound_req = make_outbound_request(
&state.proto,
&outbound_uri,
req.method().clone(),
req.headers().clone(),
state.request_headers.clone(),
)?;

// set body
let outbound_req = outbound_req
.body(reqwest::Body::from(
// It would be better to use a stream for this. However, right now,
// .into_data_stream() returns a stream which is not Send+Sync, so we can't pass it
Expand All @@ -166,17 +208,12 @@ impl ProxyHandlerHttp {
.map_err(|err| ServerError(err.into()))?
.freeze(),
))
.build()
.context("error building outbound request to proxy backend")?;

// Ensure the host header is set to target the backend if not given in the header override list.
if !headers.contains_key(HOST) {
if let Some(host) = state.backend.authority().map(|authority| authority.host()) {
if let Ok(host) = HeaderValue::from_str(host) {
outbound_req.headers_mut().insert("host", host);
}
}
}
// turn into reqwest type
let outbound_req = outbound_req
.try_into()
.context("error translating outbound request")?;

// Send the request & unpack the response.
let backend_res = state
Expand All @@ -197,31 +234,50 @@ impl ProxyHandlerHttp {

/// A handler used for proxying WebSockets to a backend.
pub struct ProxyHandlerWebSocket {
/// The protocol the proxy bound to
proto: String,
/// The URL of the backend to which requests are to be proxied.
backend: Uri,
/// An optional rewrite path to be used as the listening URI prefix, but which will be
/// stripped before being sent to the proxy backend.
rewrite: Option<String>,
/// The headers to inject with the request
request_headers: HeaderMap,
}

impl ProxyHandlerWebSocket {
/// Construct a new instance.
pub fn new(backend: Uri, rewrite: Option<String>) -> Arc<Self> {
Arc::new(Self { backend, rewrite })
pub fn new(
proto: String,
backend: Uri,
headers: HeaderMap,
rewrite: Option<String>,
) -> Arc<Self> {
Arc::new(Self {
proto,
backend,
rewrite,
request_headers: headers,
})
}

/// Build the sub-router for this proxy.
pub fn register(self: Arc<Self>, router: Router) -> Router {
let proxy = self.clone();
let override_headers = self.request_headers.clone();
let proto = self.proto.clone();
router.nest_service(
self.path(),
get(|req: Request<Body>| async move {
let headers = req.headers().to_owned();
let req_headers = req.headers().to_owned();
let uri = req.uri().clone();
let ws = req.extract::<WebSocketUpgrade, _>().await;
ws.map(|e| {
e.on_upgrade(|socket| async move {
proxy.clone().proxy_ws_request(socket, uri, headers).await
proxy
.clone()
.proxy_ws_request(&proto, socket, uri, req_headers, override_headers)
.await
})
})
}),
Expand All @@ -239,9 +295,11 @@ impl ProxyHandlerWebSocket {
#[tracing::instrument(level = "debug", skip(self, ws))]
async fn proxy_ws_request(
self: Arc<Self>,
inbound_proto: &str,
ws: WebSocket,
request_uri: Uri,
headers: HeaderMap,
req_headers: HeaderMap,
override_headers: HeaderMap,
) {
tracing::debug!("new websocket connection");

Expand All @@ -254,7 +312,24 @@ impl ProxyHandlerWebSocket {
}
};

let outbound_request = match make_outbound_request(&outbound_uri, headers) {
let outbound_request = match make_outbound_request(
inbound_proto,
&outbound_uri,
http::Method::CONNECT,
req_headers,
override_headers,
) {
Ok(outbound_uri) => outbound_uri,
Err(err) => {
tracing::error!(error = ?err, "failed to create outbound request");
return;
}
};

let outbound_request = match outbound_request
.body(())
.context("Failed to build outbound request")
{
Ok(outbound_uri) => outbound_uri,
Err(err) => {
tracing::error!(error = ?err, "failed to build outbound request");
Expand Down Expand Up @@ -330,6 +405,7 @@ impl ProxyHandlerWebSocket {

#[cfg(test)]
mod tests {
use crate::proxy::make_outbound_uri;
use axum::http::{HeaderValue, Uri};
use http::{
header::{
Expand All @@ -339,8 +415,6 @@ mod tests {
HeaderMap,
};

use crate::proxy::make_outbound_uri;

use super::{make_outbound_request, X_FORWARDED_HOST};

#[test]
Expand Down Expand Up @@ -450,17 +524,30 @@ mod tests {
HeaderValue::from_str("cookie1=value1; cookie2=value2")
.expect("Failed to create Header Value"),
),
(
COOKIE,
HeaderValue::from_str("cookie3=value1; cookie4=value2")
.expect("Failed to create Header Value"),
),
];
let mut want_headers = HeaderMap::new();

for (key, val) in inbound_headers {
want_headers.insert(key, val);
want_headers.append(key, val);
}

let have_outbound_uri = make_outbound_uri(&backend_uri, &inbound_uri)
.expect("Failed to create Uri instance from inbound");
let have_outbound_req = make_outbound_request(&have_outbound_uri, want_headers.clone())
.expect("Failed to create Request instance from inbound");
let have_outbound_req = make_outbound_request(
"http",
&have_outbound_uri,
http::Method::GET,
want_headers.clone(),
Default::default(),
)
.expect("Failed to create Request instance from inbound")
.body(())
.expect("Failed to create Request from builder");

assert_eq!(have_outbound_req.uri(), &have_outbound_uri);
assert_eq!(have_outbound_req.method(), &http::Method::GET);
Expand All @@ -472,9 +559,7 @@ mod tests {
&HeaderValue::from_static("backend")
);

for (key, val) in want_headers {
let key = key.expect("Expected header");

for key in want_headers.keys() {
if key == HOST {
continue;
}
Expand All @@ -490,12 +575,15 @@ mod tests {
continue;
}

let val = want_headers.get_all(key).iter().collect::<Vec<_>>();

assert_eq!(
have_outbound_req
.headers()
.get(key.clone())
.unwrap_or_else(|| panic!("Expected header value for {}", key)),
&val
.get_all(key.clone())
.iter()
.collect::<Vec<_>>(),
val
);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/serve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ fn router(state: Arc<State>, cfg: Arc<RtcServe>) -> Result<Router> {
state.serve_base.as_str()
);

let mut builder = ProxyBuilder::new(router);
let mut builder = ProxyBuilder::new(cfg.tls.is_some(), router);

// Build proxies

Expand Down
26 changes: 22 additions & 4 deletions src/serve/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ const DANGER: Emoji = Emoji("⚠️", "(!)");

/// A builder for the proxy router
pub(crate) struct ProxyBuilder {
tls: bool,
router: Router,
clients: ProxyClients,
}

impl ProxyBuilder {
/// Create a new builder
pub fn new(router: Router) -> Self {
pub fn new(tls: bool, router: Router) -> Self {
Self {
tls,
router,
clients: Default::default(),
}
Expand All @@ -36,8 +38,19 @@ impl ProxyBuilder {
rewrite: Option<String>,
opts: ProxyClientOptions,
) -> anyhow::Result<Self> {
let proto = match self.tls {
true => "https",
false => "http",
}
.to_string();

if ws {
let handler = ProxyHandlerWebSocket::new(backend.clone(), rewrite);
let handler = ProxyHandlerWebSocket::new(
proto,
backend.clone(),
request_headers.clone(),
rewrite,
);
tracing::info!(
"{}proxying websocket {} -> {}",
SERVER,
Expand All @@ -50,8 +63,13 @@ impl ProxyBuilder {
let no_sys_proxy = opts.no_system_proxy;
let insecure = opts.insecure;
let client = self.clients.get_client(opts)?;
let handler =
ProxyHandlerHttp::new(client, backend.clone(), request_headers.clone(), rewrite);
let handler = ProxyHandlerHttp::new(
proto,
client,
backend.clone(),
request_headers.clone(),
rewrite,
);
tracing::info!(
"{}proxying {} -> {} {} {}{}",
SERVER,
Expand Down

0 comments on commit c793b91

Please sign in to comment.