From e4fbdbd34cc22cc22985a10e103ccbe5dc3a0968 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20H=C3=A4ggblad?= Date: Tue, 24 Oct 2023 11:34:44 +0200 Subject: [PATCH 1/2] Make peer event channel bounded --- common/wireguard/src/active_peers.rs | 10 +++++----- common/wireguard/src/lib.rs | 4 ++-- common/wireguard/src/packet_relayer.rs | 10 +++++----- common/wireguard/src/platform/linux/tun_device.rs | 10 ++++------ common/wireguard/src/udp_listener.rs | 15 +++++++++------ 5 files changed, 25 insertions(+), 24 deletions(-) diff --git a/common/wireguard/src/active_peers.rs b/common/wireguard/src/active_peers.rs index 6973d29be1..a22453c861 100644 --- a/common/wireguard/src/active_peers.rs +++ b/common/wireguard/src/active_peers.rs @@ -11,12 +11,12 @@ use crate::event::Event; // Channels that are used to communicate with the various tunnels #[derive(Clone)] -pub struct PeerEventSender(mpsc::UnboundedSender); -pub(crate) struct PeerEventReceiver(mpsc::UnboundedReceiver); +pub struct PeerEventSender(mpsc::Sender); +pub(crate) struct PeerEventReceiver(mpsc::Receiver); impl PeerEventSender { - pub(crate) fn send(&self, event: Event) -> Result<(), mpsc::error::SendError> { - self.0.send(event) + pub(crate) async fn send(&self, event: Event) -> Result<(), mpsc::error::SendError> { + self.0.send(event).await } } @@ -27,7 +27,7 @@ impl PeerEventReceiver { } pub(crate) fn peer_event_channel() -> (PeerEventSender, PeerEventReceiver) { - let (tx, rx) = mpsc::unbounded_channel(); + let (tx, rx) = mpsc::channel(16); (PeerEventSender(tx), PeerEventReceiver(rx)) } diff --git a/common/wireguard/src/lib.rs b/common/wireguard/src/lib.rs index cd13ba6937..dc226c56fd 100644 --- a/common/wireguard/src/lib.rs +++ b/common/wireguard/src/lib.rs @@ -33,10 +33,10 @@ pub async fn start_wireguard( gateway_client_registry: Arc, ) -> Result<(), Box> { // We can either index peers by their IP like standard wireguard - let peers_by_ip = Arc::new(std::sync::Mutex::new(network_table::NetworkTable::new())); + let peers_by_ip = Arc::new(tokio::sync::Mutex::new(network_table::NetworkTable::new())); // ... or by their tunnel tag, which is a random number assigned to them - let peers_by_tag = Arc::new(std::sync::Mutex::new(wg_tunnel::PeersByTag::new())); + let peers_by_tag = Arc::new(tokio::sync::Mutex::new(wg_tunnel::PeersByTag::new())); // Start the tun device that is used to relay traffic outbound let (tun, tun_task_tx, tun_task_response_rx) = tun_device::TunDevice::new(peers_by_ip.clone()); diff --git a/common/wireguard/src/packet_relayer.rs b/common/wireguard/src/packet_relayer.rs index 40d49d5d37..67f6bf9cba 100644 --- a/common/wireguard/src/packet_relayer.rs +++ b/common/wireguard/src/packet_relayer.rs @@ -32,14 +32,14 @@ pub(crate) struct PacketRelayer { tun_task_response_rx: TunTaskResponseRx, // After receiving from the tun device, relay back to the correct tunnel - peers_by_tag: Arc>>, + peers_by_tag: Arc>>, } impl PacketRelayer { pub(crate) fn new( tun_task_tx: TunTaskTx, tun_task_response_rx: TunTaskResponseRx, - peers_by_tag: Arc>>, + peers_by_tag: Arc>>, ) -> (Self, PacketRelaySender) { let (packet_tx, packet_rx) = packet_relay_channel(); ( @@ -62,9 +62,9 @@ impl PacketRelayer { }, Some((tag, packet)) = self.tun_task_response_rx.recv() => { log::info!("Received response from tun device with tag: {tag}"); - self.peers_by_tag.lock().unwrap().get(&tag).and_then(|tx| { - tx.send(Event::Ip(packet.into())).tap_err(|e| log::error!("{e}")).ok() - }); + if let Some(tx) = self.peers_by_tag.lock().await.get(&tag) { + tx.send(Event::Ip(packet.into())).await.tap_err(|e| log::error!("{e}")).ok(); + } } } } diff --git a/common/wireguard/src/platform/linux/tun_device.rs b/common/wireguard/src/platform/linux/tun_device.rs index cc449112bc..348abdb609 100644 --- a/common/wireguard/src/platform/linux/tun_device.rs +++ b/common/wireguard/src/platform/linux/tun_device.rs @@ -43,7 +43,7 @@ pub struct TunDevice { tun_task_response_tx: TunTaskResponseTx, // The routing table, as how wireguard does it - peers_by_ip: Arc>, + peers_by_ip: Arc>, // This is an alternative to the routing table, where we just match outgoing source IP with // incoming destination IP. @@ -52,7 +52,7 @@ pub struct TunDevice { impl TunDevice { pub fn new( - peers_by_ip: Arc>, + peers_by_ip: Arc>, ) -> (Self, TunTaskTx, TunTaskResponseRx) { let tun = setup_tokio_tun_device( format!("{TUN_BASE_NAME}%d").as_str(), @@ -123,14 +123,12 @@ impl TunDevice { // This is how wireguard does it, by consulting the AllowedIPs table. if false { - let Ok(peers) = self.peers_by_ip.lock() else { - log::error!("Failed to lock peers_by_ip, aborting tun device read"); - return; - }; + let peers = self.peers_by_ip.lock().await; if let Some(peer_tx) = peers.longest_match(dst_addr).map(|(_, tx)| tx) { log::info!("Forward packet to wg tunnel"); peer_tx .send(Event::Ip(packet.to_vec().into())) + .await .tap_err(|err| log::error!("{err}")) .ok(); return; diff --git a/common/wireguard/src/udp_listener.rs b/common/wireguard/src/udp_listener.rs index 5c04c010a4..8e32325c36 100644 --- a/common/wireguard/src/udp_listener.rs +++ b/common/wireguard/src/udp_listener.rs @@ -50,10 +50,10 @@ pub struct WgUdpListener { registered_peers: RegisteredPeers, // The routing table, as defined by wireguard - peers_by_ip: Arc>, + peers_by_ip: Arc>, // ... or alternatively we can map peers by their tag - peers_by_tag: Arc>, + peers_by_tag: Arc>, // The UDP socket to the peer udp: Arc, @@ -70,8 +70,8 @@ pub struct WgUdpListener { impl WgUdpListener { pub async fn new( packet_tx: PacketRelaySender, - peers_by_ip: Arc>, - peers_by_tag: Arc>, + peers_by_ip: Arc>, + peers_by_tag: Arc>, gateway_client_registry: Arc, ) -> Result> { let wg_address = SocketAddr::new(WG_ADDRESS.parse().unwrap(), WG_PORT); @@ -143,6 +143,7 @@ impl WgUdpListener { log::info!("udp: received {len} bytes from {addr} from known peer"); peer_tx .send(Event::Wg(buf[..len].to_vec().into())) + .await .tap_err(|e| log::error!("{e}")) .ok(); continue; @@ -186,6 +187,7 @@ impl WgUdpListener { // We found the peer as connected, even though the addr was not known log::info!("udp: received {len} bytes from {addr} which is a known peer with unknown addr"); peer_tx.send(Event::WgVerified(buf[..len].to_vec().into())) + .await .tap_err(|err| log::error!("{err}")) .ok(); } else { @@ -205,10 +207,11 @@ impl WgUdpListener { self.packet_tx.clone(), ); - self.peers_by_ip.lock().unwrap().insert(registered_peer.allowed_ips, peer_tx.clone()); - self.peers_by_tag.lock().unwrap().insert(tag, peer_tx.clone()); + self.peers_by_ip.lock().await.insert(registered_peer.allowed_ips, peer_tx.clone()); + self.peers_by_tag.lock().await.insert(tag, peer_tx.clone()); peer_tx.send(Event::Wg(buf[..len].to_vec().into())) + .await .tap_err(|e| log::error!("{e}")) .ok(); From edf90158d0338cc513013effbe204ff3cfcbfdea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20H=C3=A4ggblad?= Date: Tue, 24 Oct 2023 11:37:10 +0200 Subject: [PATCH 2/2] Make tun task channel bounded --- common/wireguard/src/packet_relayer.rs | 2 +- common/wireguard/src/tun_task_channel.rs | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/common/wireguard/src/packet_relayer.rs b/common/wireguard/src/packet_relayer.rs index 67f6bf9cba..c37af0648e 100644 --- a/common/wireguard/src/packet_relayer.rs +++ b/common/wireguard/src/packet_relayer.rs @@ -58,7 +58,7 @@ impl PacketRelayer { tokio::select! { Some((tag, packet)) = self.packet_rx.0.recv() => { log::info!("Sent packet to tun device with tag: {tag}"); - self.tun_task_tx.send((tag, packet)).unwrap(); + self.tun_task_tx.send((tag, packet)).await.tap_err(|e| log::error!("{e}")).ok(); }, Some((tag, packet)) = self.tun_task_response_rx.recv() => { log::info!("Received response from tun device with tag: {tag}"); diff --git a/common/wireguard/src/tun_task_channel.rs b/common/wireguard/src/tun_task_channel.rs index 3ecc6a76b4..1cbd6985da 100644 --- a/common/wireguard/src/tun_task_channel.rs +++ b/common/wireguard/src/tun_task_channel.rs @@ -3,15 +3,15 @@ use tokio::sync::mpsc; pub(crate) type TunTaskPayload = (u64, Vec); #[derive(Clone)] -pub struct TunTaskTx(mpsc::UnboundedSender); -pub(crate) struct TunTaskRx(mpsc::UnboundedReceiver); +pub struct TunTaskTx(mpsc::Sender); +pub(crate) struct TunTaskRx(mpsc::Receiver); impl TunTaskTx { - pub(crate) fn send( + pub(crate) async fn send( &self, data: TunTaskPayload, ) -> Result<(), tokio::sync::mpsc::error::SendError> { - self.0.send(data) + self.0.send(data).await } } @@ -22,7 +22,7 @@ impl TunTaskRx { } pub(crate) fn tun_task_channel() -> (TunTaskTx, TunTaskRx) { - let (tun_task_tx, tun_task_rx) = tokio::sync::mpsc::unbounded_channel(); + let (tun_task_tx, tun_task_rx) = tokio::sync::mpsc::channel(16); (TunTaskTx(tun_task_tx), TunTaskRx(tun_task_rx)) }