From a4a3f4c1128d65655377cb5442bf81c667d820c5 Mon Sep 17 00:00:00 2001 From: Jorge Alejandro Jimenez Luna Date: Thu, 29 Jun 2023 13:04:10 -0400 Subject: [PATCH] Split structure in order to implement transparent mode later --- Cargo.lock | 1 + Cargo.toml | 1 + src/app.rs | 167 ++----------------------------------- src/core.rs | 116 ++++++++++++++++++++++++++ src/http.rs | 103 ++++++++++------------- src/main.rs | 2 + src/transport/http.rs | 190 ++++++++++++++++++++++++++++++++++++++++++ src/transport/mod.rs | 10 +++ src/transport/tcp.rs | 44 ++++++++++ 9 files changed, 414 insertions(+), 220 deletions(-) create mode 100644 src/core.rs create mode 100644 src/transport/http.rs create mode 100644 src/transport/mod.rs create mode 100644 src/transport/tcp.rs diff --git a/Cargo.lock b/Cargo.lock index 6e4af12..68f7845 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -788,6 +788,7 @@ dependencies = [ name = "proxyswarm" version = "0.3.6" dependencies = [ + "async-trait", "base64", "bytes", "clap", diff --git a/Cargo.toml b/Cargo.toml index 1d4497e..b99b610 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,3 +27,4 @@ config = { version = "0.13.3", default-features = false, features = ["ini"] } thiserror = "1.0.40" http = "0.2.9" wildmatch = "2.1.1" +async-trait = "0.1.68" diff --git a/src/app.rs b/src/app.rs index 58bbed2..7cc4a11 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,78 +1,19 @@ -use super::http::{DigestState, HttpHandler}; +use super::http::DigestState; use super::proxy::{Credentials, Proxy}; use crate::acl::{Acl, Rule}; use crate::error::Error; - -use tokio::net::{TcpListener, TcpStream}; -use tokio::{self, signal, sync::oneshot}; - -use http_body_util::{combinators::BoxBody, BodyExt, Empty}; -use hyper::{ - body::{Bytes, Incoming}, - header::{PROXY_AUTHENTICATE, PROXY_AUTHORIZATION}, - server::conn::http1, - service::service_fn, - Request, Response, StatusCode, -}; +use crate::transport::http::HttpServer; +use crate::transport::Server; use config::Config; -use log::{debug, error, info, trace, warn}; +use log::info; use std::{ - convert::Infallible, net::SocketAddr, str::FromStr, - sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, - }, + sync::{Arc, Mutex}, }; -macro_rules! box_body { - ($t:expr) => { - $t.map(|f| f.boxed()) - }; -} - -#[inline] -pub(crate) fn empty() -> BoxBody { - Empty::::new() - .map_err(|never| match never {}) - .boxed() -} - -pub async fn redirect_http( - req: Request, -) -> Result>, Error> { - debug!("Request forwarded directly to original destination"); - - let host = req.uri().host().ok_or("Uri has no host")?; - let port = req.uri().port_u16().unwrap_or(80); - - let address = format!("{port}:{host}"); - - // Open a TCP connection to the remote host - let stream = match TcpStream::connect(address).await { - Ok(v) => v, - Err(e) => { - warn!("Unable to connect to {}: {e}", req.uri()); - return Ok(Response::builder() - .status(StatusCode::BAD_GATEWAY) - .body(empty()) - .unwrap()); - } - }; - - let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?; - tokio::spawn(async move { - if let Err(err) = conn.await { - println!("Connection failed: {:?}", err); - } - }); - - Ok(box_body!(sender.send_request(box_body!(req)).await?)) -} - #[derive(Clone, Debug)] pub enum OperationMode { Transparent, @@ -98,7 +39,7 @@ pub struct AppContext { pub mode: OperationMode, pub acl: Acl, - digest_state: Arc>, + pub digest_state: Arc>, } pub struct App { @@ -182,107 +123,13 @@ impl App { }) } - async fn handle_connection( - context: AppContext, - id: u32, - mut req: Request, - ) -> Result>, Infallible> { - debug!("[#{id}] Requested: {}", req.uri()); - trace!("[#{id}] Request struct: {req:?}"); - - if let Some(host) = req.uri().host() { - if context.acl.match_hostname(host) == Rule::Deny { - debug!("[#{id}] Avoided try to connect with {host}"); - return Ok(redirect_http(req).await.unwrap_or_else(|e| { - error!("Error forwarding request to destination: {e}"); - - Response::builder() - .status(StatusCode::BAD_GATEWAY) - .body(empty()) - .unwrap() - })); - } - } - - // Remove proxy headers - if matches!(context.mode, OperationMode::Proxy) { - let headers = req.headers_mut(); - headers.remove(PROXY_AUTHENTICATE); - headers.remove(PROXY_AUTHORIZATION); - } - - // Forward the request - let client = HttpHandler::new(id, context.proxies, Arc::clone(&context.digest_state)); - if let Err(e) = client.request(req).await { - error!("Error forwarding request to destination: {e}"); - return Ok(Response::builder() - .status(StatusCode::BAD_GATEWAY) - .body(empty()) - .unwrap()); - } - - debug!("[#{id}] Connection processed successful"); - Ok(Response::builder().body(empty()).unwrap()) - } - - async fn serve_http(context: AppContext) -> Result<(), Error> { - let addr = context.addr; - let count = Arc::new(AtomicU32::new(0)); - - let tcp_listener = TcpListener::bind(addr).await?; - info!("Proxy listening at http://{addr}. Press Ctrl+C to stop it",); - - let (tx, mut rx) = oneshot::channel(); - - // Main loop - tokio::spawn(async move { - loop { - tokio::select! { - conn = tcp_listener.accept() => { - let (stream, remote_addr) = match conn { - Ok(v) => v, - Err(e) => { - error!("Unable to accept incomming TCP connection: {e}"); - return; - } - }; - - // Get connections count - let id = count.fetch_add(1, Ordering::SeqCst); - debug!("[#{id}] Incoming connection: <{remote_addr}>"); - - let context = context.clone(); - let proxy = - service_fn(move |req| App::handle_connection(context.clone(), id, req)); - - tokio::spawn(async move { - if let Err(e) = http1::Builder::new() - .keep_alive(true) - .preserve_header_case(true) - .serve_connection(stream, proxy) - .with_upgrades() - .await { - error!("Server error: {e}"); - } - }); - } - _ = (&mut rx) => { break; } - } - } - }); - - signal::ctrl_c().await?; - let _ = tx.send(()); - Ok(()) - } - pub async fn run(self) -> Result<(), String> { // Separate to avoid add more logic if matches!(self.context.mode, OperationMode::Transparent) { todo!("Wait a little more :("); } - App::serve_http(self.context) + HttpServer::serve(self.context) .await .map_err(|e| format!("Server Error: {}", e))?; diff --git a/src/core.rs b/src/core.rs new file mode 100644 index 0000000..c1b68b8 --- /dev/null +++ b/src/core.rs @@ -0,0 +1,116 @@ +use async_trait::async_trait; +use hyper::{body::Incoming, Request}; +use std::{ + marker::PhantomData, + net::{IpAddr, SocketAddr}, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpStream, +}; + +use crate::error::Error; + +#[derive(Hash, Clone, Eq, PartialEq, Debug)] +pub(crate) enum MaybeNamedHost { + Address(IpAddr), + Hostname(String), +} + +impl From for MaybeNamedHost { + fn from(value: IpAddr) -> Self { + MaybeNamedHost::Address(value) + } +} + +impl From<&str> for MaybeNamedHost { + fn from(value: &str) -> Self { + MaybeNamedHost::Hostname(value.to_string()) + } +} + +impl std::fmt::Display for MaybeNamedHost { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MaybeNamedHost::Address(addr) => addr.fmt(f), + MaybeNamedHost::Hostname(name) => name.fmt(f), + } + } +} + +#[derive(Hash, Clone, Eq, PartialEq, Debug)] +pub struct MaybeNamedSock { + pub(crate) host: MaybeNamedHost, + pub(crate) port: u16, +} + +impl TryFrom for SocketAddr { + type Error = Error; + fn try_from(value: MaybeNamedSock) -> Result { + let ip = match value.host { + MaybeNamedHost::Address(addr) => addr, + MaybeNamedHost::Hostname(e) => { + return Err(e.into()); + } + }; + Ok(SocketAddr::new(ip, value.port)) + } +} + +impl From for MaybeNamedSock { + fn from(addr: SocketAddr) -> Self { + Self { + host: MaybeNamedHost::Address(addr.ip()), + port: addr.port(), + } + } +} + +impl std::fmt::Display for MaybeNamedSock { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let MaybeNamedHost::Address(IpAddr::V6(addr)) = self.host { + write!(f, "[{}]:{}", addr, self.port) + } else { + write!(f, "{}:{}", self.host, self.port) + } + } +} + +#[async_trait] +pub trait ToStream { + async fn into_stream(self) -> Result; +} + +pub struct ProxyRequest +where + T: ToStream, + S: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + pub destination: MaybeNamedSock, + pub inner: T, + pub _phanton: PhantomData, +} + +impl ProxyRequest +where + T: ToStream, + S: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + pub async fn into_stream(self) -> Result { + self.inner.into_stream().await + } +} + +#[async_trait] +impl ToStream for Request { + async fn into_stream(self) -> Result { + Ok(hyper::upgrade::on(self).await?) + } +} + +#[async_trait] +impl ToStream for TcpStream { + async fn into_stream(self) -> Result { + Ok(self) + } +} \ No newline at end of file diff --git a/src/http.rs b/src/http.rs index 08572d0..94d4dda 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,13 +1,15 @@ -use super::proxy::{AuthenticationScheme, Proxy}; -use super::utils::natural_size; +use crate::core::{ProxyRequest, ToStream}; use crate::error::Error; +use crate::proxy::{AuthenticationScheme, Proxy}; +use crate::utils::natural_size; + use base64::{engine::general_purpose::STANDARD, Engine as _}; use http_body_util::Empty; use hyper::{ self, - body::{Body, Bytes, Incoming}, + body::{Body, Bytes}, client::conn::http1::{self, Connection, SendRequest}, - header, Method, Request, StatusCode, Uri, Version, + header, Method, Request, StatusCode, Version, }; use log::{debug, error, trace, warn}; use std::{ @@ -46,14 +48,16 @@ where }) } -async fn tunnel( +async fn tunnel( id: u32, connection: Connection, - mut request: Request, + request: ProxyRequest, cancellation_token: Receiver<()>, -) -> Request +) -> Option> where T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + R: ToStream + Send + Unpin + 'static, + S: AsyncRead + AsyncWrite + Send + Unpin + 'static, B1: Body + 'static, ::Error: Into>, { @@ -64,12 +68,12 @@ where Ok(v) => v, Err(e) => { error!("[#{id}] Unable to get underline stream: {e}"); - return request; + return Some(request); } } } _ = cancellation_token => { - return request; + return Some(request); } }; @@ -78,32 +82,27 @@ where // Upgrade the request to a tunnel. trace!("[#{id}] Upgrading request connection"); - match hyper::upgrade::on(&mut request).await { - Ok(mut upgraded) => { - // Proxying data - let (from, to) = match tokio::io::copy_bidirectional(&mut upgraded, &mut io).await { - Ok(v) => v, - Err(e) => { - warn!("[#{id}] Server io error: {e}"); - return request; - } - }; + let Ok(mut inner) = request.into_stream().await else { + error!("[#{id}] Unable to get incomming stream"); + return None; + }; - // Print message when done - debug!( - "[#{id}] Client wrote {} and received {}", - natural_size(from, false), - natural_size(to, false) - ); + let (from, to) = match tokio::io::copy_bidirectional(&mut inner, &mut io).await { + Ok(v) => v, + Err(e) => { + warn!("[#{id}] Server io error: {e}"); + return None; } - Err(e) => warn!("[#{id}] Upgrade error: {e}"), - } + }; - request -} + // Print message when done + debug!( + "[#{id}] Client wrote {} and received {}", + natural_size(from, false), + natural_size(to, false) + ); -fn host_addr(uri: &Uri) -> Option { - uri.authority().map(|auth| auth.to_string()) + None } impl HttpHandler { @@ -119,25 +118,6 @@ impl HttpHandler { } } - // pub fn from_proxy(proxy: Proxy) -> Self { - // ProxyClient { - // proxies: vec![proxy], - // bypass: Vec::new() - // } - // } - - // pub fn add_bypass_uri(&mut self, uri: &str) { - // self.bypass.push(String::from(uri)); - // } - - // pub fn add_proxy(&mut self, proxy: Proxy) { - // self.proxies.push(proxy); - // } - - // pub fn proxies(&self) -> &[Proxy] { - // return &self.proxies; - // } - pub async fn get_proxy_transport( &self, proxy: &Proxy, @@ -162,18 +142,17 @@ impl HttpHandler { Ok((sender, conn)) } - pub fn get_auth_response(&self, proxy: &Proxy, uri: &Uri) -> Result { + pub fn get_auth_response(&self, proxy: &Proxy, uri: &str) -> Result { let Some(credentials) = &proxy.credentials else { return Err(Error::AuthenticationRequired); }; // If the digest state is present, then we use it if let Some(state) = self.digest_state.lock().unwrap().as_mut() { - let uri = uri.to_string(); let context = digest_auth::AuthContext::new_with_method( &credentials.username, &credentials.password, - &uri, + uri, Option::<&'_ [u8]>::None, digest_auth::HttpMethod::CONNECT, ); @@ -187,8 +166,12 @@ impl HttpHandler { } } - pub async fn request(&self, mut req: Request) -> Result<(), Error> { - let uri = req.uri().clone(); + pub async fn request(&self, mut req: ProxyRequest) -> Result<(), Error> + where + T: ToStream + Send + Unpin + 'static, + S: AsyncRead + AsyncWrite + Send + Unpin + 'static, + { + let dest = req.destination.to_string(); for proxy in self.proxies.iter() { let Ok((mut sender, conn)) = self.get_proxy_transport(proxy).await else { @@ -204,10 +187,10 @@ impl HttpHandler { while retry_count > 0 { retry_count -= 1; let shallow = Request::builder() - .uri(&uri) + .uri(&dest) .method(Method::CONNECT) .version(Version::HTTP_11) - .header(header::HOST, host_addr(&uri).unwrap()) + .header(header::HOST, &dest) // In order to make a persistent connection .header("Proxy-Connection", "keep-alive") .header( @@ -216,7 +199,7 @@ impl HttpHandler { ) .header( header::PROXY_AUTHORIZATION, - self.get_auth_response(proxy, &uri)?, + self.get_auth_response(proxy, &dest)?, ) .body(Empty::new()) .unwrap(); @@ -284,7 +267,7 @@ impl HttpHandler { debug!("[#{id}] Successful connected with the proxy"); // Re-init all - req = wrapper.await?; + req = wrapper.await?.unwrap(); (tx, rx) = oneshot::channel(); wrapper = tokio::spawn(async move { tunnel(id, t.1, req, rx).await }); continue; @@ -293,7 +276,7 @@ impl HttpHandler { trace!("[#{id}] Reusing old connection"); } - req = wrapper.await?; + req = wrapper.await?.unwrap(); } Err("Unable to proxy".into()) diff --git a/src/main.rs b/src/main.rs index ef6b910..c6480cf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,6 +11,8 @@ pub mod error; pub mod http; pub mod proxy; pub mod utils; +pub mod core; +pub mod transport; use crate::app::App; use log::{error, info, LevelFilter}; diff --git a/src/transport/http.rs b/src/transport/http.rs new file mode 100644 index 0000000..2eed138 --- /dev/null +++ b/src/transport/http.rs @@ -0,0 +1,190 @@ +use super::Server; + +use crate::acl::Rule; +use crate::app::AppContext; +use crate::core::{MaybeNamedHost, MaybeNamedSock, ProxyRequest}; +use crate::error::Error; +use crate::http::HttpHandler; + +use tokio::net::{TcpListener, TcpStream}; +use tokio::{self, signal, sync::oneshot}; + +use http_body_util::{combinators::BoxBody, BodyExt, Empty}; +use hyper::{ + body::{Bytes, Incoming}, + header::{PROXY_AUTHENTICATE, PROXY_AUTHORIZATION}, + server::conn::http1, + service::service_fn, + Request, Response, StatusCode, +}; + +use async_trait::async_trait; +use log::{debug, error, info, trace, warn}; + +use std::{ + convert::Infallible, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, +}; + +pub struct HttpServer; + +macro_rules! box_body { + ($t:expr) => { + $t.map(|f| f.boxed()) + }; +} + +#[inline] +pub(crate) fn empty() -> BoxBody { + Empty::::new() + .map_err(|never| match never {}) + .boxed() +} + +pub async fn redirect_http( + req: Request, +) -> Result>, Error> { + debug!("Request forwarded directly to original destination"); + + let host = req.uri().host().ok_or("Uri has no host")?; + let port = req.uri().port_u16().unwrap_or(80); + + let address = format!("{port}:{host}"); + + // Open a TCP connection to the remote host + let stream = match TcpStream::connect(address).await { + Ok(v) => v, + Err(e) => { + warn!("Unable to connect to {}: {e}", req.uri()); + return Ok(Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body(empty()) + .unwrap()); + } + }; + + let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?; + tokio::spawn(async move { + if let Err(err) = conn.await { + println!("Connection failed: {:?}", err); + } + }); + + Ok(box_body!(sender.send_request(box_body!(req)).await?)) +} + +impl HttpServer { + async fn handle_http_connection( + context: AppContext, + id: u32, + mut req: Request, + ) -> Result>, Infallible> { + debug!("[#{id}] Requested: {}", req.uri()); + trace!("[#{id}] Request struct: {req:?}"); + + let Some(host) = req.uri().host().map(|x| x.to_string()) else { + error!("[#{id}] Invalid request, missing host part in the uri"); + return Ok(Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body(empty()) + .unwrap()); + }; + + if context.acl.match_hostname(&host) == Rule::Deny { + debug!("[#{id}] Avoided try to connect with {host}"); + return Ok(redirect_http(req).await.unwrap_or_else(|e| { + error!("Error forwarding request to destination: {e}"); + + Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body(empty()) + .unwrap() + })); + } + + // Remove proxy headers + let headers = req.headers_mut(); + headers.remove(PROXY_AUTHENTICATE); + headers.remove(PROXY_AUTHORIZATION); + + // Forward the request + let client = HttpHandler::new(id, context.proxies, Arc::clone(&context.digest_state)); + + let request = ProxyRequest { + destination: MaybeNamedSock { + host: MaybeNamedHost::Hostname(host.to_string()), + port: req.uri().port_u16().unwrap_or(80), + }, + inner: req, + _phanton: std::marker::PhantomData, + }; + + if let Err(e) = client.request(request).await { + error!("Error forwarding request to destination: {e}"); + return Ok(Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body(empty()) + .unwrap()); + } + + debug!("[#{id}] Connection processed successful"); + Ok(Response::builder().body(empty()).unwrap()) + } +} + +#[async_trait] +impl Server for HttpServer { + async fn serve(context: AppContext) -> Result<(), Error> { + let addr = context.addr; + let count = Arc::new(AtomicU32::new(0)); + + let tcp_listener = TcpListener::bind(addr).await?; + info!("Proxy listening at http://{addr}. Press Ctrl+C to stop it",); + + let (tx, mut rx) = oneshot::channel(); + + // Main loop + tokio::spawn(async move { + loop { + tokio::select! { + conn = tcp_listener.accept() => { + let (stream, remote_addr) = match conn { + Ok(v) => v, + Err(e) => { + error!("Unable to accept incomming TCP connection: {e}"); + return; + } + }; + + // Get connections count + let id = count.fetch_add(1, Ordering::SeqCst); + debug!("[#{id}] Incoming connection: <{remote_addr}>"); + + let context = context.clone(); + let proxy = + service_fn(move |req| Self::handle_http_connection(context.clone(), id, req)); + + tokio::spawn(async move { + if let Err(e) = http1::Builder::new() + .keep_alive(true) + .preserve_header_case(true) + .serve_connection(stream, proxy) + .with_upgrades() + .await { + error!("Server error: {e}"); + } + }); + } + _ = (&mut rx) => { break; } + } + } + }); + + signal::ctrl_c().await?; + let _ = tx.send(()); + Ok(()) + } +} diff --git a/src/transport/mod.rs b/src/transport/mod.rs new file mode 100644 index 0000000..6bc8696 --- /dev/null +++ b/src/transport/mod.rs @@ -0,0 +1,10 @@ +use crate::{app::AppContext, error::Error}; +use async_trait::async_trait; + +pub mod http; +// pub mod tcp; + +#[async_trait] +pub trait Server { + async fn serve(context: AppContext) -> Result<(), Error>; +} diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs new file mode 100644 index 0000000..6199273 --- /dev/null +++ b/src/transport/tcp.rs @@ -0,0 +1,44 @@ +use std::sync::Arc; + +use super::Server; +use crate::{app::AppContext, core::ProxyRequest, error::Error, http::HttpHandler}; + +use async_trait::async_trait; +use log::{debug, error}; +use tokio::net::TcpStream; + +pub struct TcpServer; + +impl TcpServer { + async fn handle_tcp_connection(context: AppContext, id: u32, mut stream: TcpStream) { + let peer_addr = match stream.peer_addr() { + Ok(x) => x, + Err(e) => { + error!("[#{id}] Unable to get destination address: {e}"); + return; + } + }; + + let client = HttpHandler::new(id, context.proxies, Arc::clone(&context.digest_state)); + + let request = ProxyRequest { + destination: peer_addr.into(), + inner: stream, + _phanton: std::marker::PhantomData, + }; + + if let Err(e) = client.request(request).await { + error!("Error forwarding request to destination: {e}"); + return; + } + + debug!("[#{id}] Connection processed successful"); + } +} + +#[async_trait] +impl Server for TcpServer { + async fn serve(context: AppContext) -> Result<(), Error> { + Ok(()) + } +}