Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wg: bounded channels #4037

Merged
merged 2 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions common/wireguard/src/active_peers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ use crate::event::Event;

// Channels that are used to communicate with the various tunnels
#[derive(Clone)]
pub struct PeerEventSender(mpsc::UnboundedSender<Event>);
pub(crate) struct PeerEventReceiver(mpsc::UnboundedReceiver<Event>);
pub struct PeerEventSender(mpsc::Sender<Event>);
pub(crate) struct PeerEventReceiver(mpsc::Receiver<Event>);

impl PeerEventSender {
pub(crate) fn send(&self, event: Event) -> Result<(), mpsc::error::SendError<Event>> {
self.0.send(event)
pub(crate) async fn send(&self, event: Event) -> Result<(), mpsc::error::SendError<Event>> {
self.0.send(event).await
}
}

Expand All @@ -27,7 +27,7 @@ impl PeerEventReceiver {
}

pub(crate) fn peer_event_channel() -> (PeerEventSender, PeerEventReceiver) {
let (tx, rx) = mpsc::unbounded_channel();
let (tx, rx) = mpsc::channel(16);
(PeerEventSender(tx), PeerEventReceiver(rx))
}

Expand Down
4 changes: 2 additions & 2 deletions common/wireguard/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ pub async fn start_wireguard(
gateway_client_registry: Arc<GatewayClientRegistry>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
// We can either index peers by their IP like standard wireguard
let peers_by_ip = Arc::new(std::sync::Mutex::new(network_table::NetworkTable::new()));
let peers_by_ip = Arc::new(tokio::sync::Mutex::new(network_table::NetworkTable::new()));

// ... or by their tunnel tag, which is a random number assigned to them
let peers_by_tag = Arc::new(std::sync::Mutex::new(wg_tunnel::PeersByTag::new()));
let peers_by_tag = Arc::new(tokio::sync::Mutex::new(wg_tunnel::PeersByTag::new()));

// Start the tun device that is used to relay traffic outbound
let (tun, tun_task_tx, tun_task_response_rx) = tun_device::TunDevice::new(peers_by_ip.clone());
Expand Down
12 changes: 6 additions & 6 deletions common/wireguard/src/packet_relayer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ pub(crate) struct PacketRelayer {
tun_task_response_rx: TunTaskResponseRx,

// After receiving from the tun device, relay back to the correct tunnel
peers_by_tag: Arc<std::sync::Mutex<HashMap<u64, PeerEventSender>>>,
peers_by_tag: Arc<tokio::sync::Mutex<HashMap<u64, PeerEventSender>>>,
}

impl PacketRelayer {
pub(crate) fn new(
tun_task_tx: TunTaskTx,
tun_task_response_rx: TunTaskResponseRx,
peers_by_tag: Arc<std::sync::Mutex<HashMap<u64, PeerEventSender>>>,
peers_by_tag: Arc<tokio::sync::Mutex<HashMap<u64, PeerEventSender>>>,
) -> (Self, PacketRelaySender) {
let (packet_tx, packet_rx) = packet_relay_channel();
(
Expand All @@ -58,13 +58,13 @@ impl PacketRelayer {
tokio::select! {
Some((tag, packet)) = self.packet_rx.0.recv() => {
log::info!("Sent packet to tun device with tag: {tag}");
self.tun_task_tx.send((tag, packet)).unwrap();
self.tun_task_tx.send((tag, packet)).await.tap_err(|e| log::error!("{e}")).ok();
},
Some((tag, packet)) = self.tun_task_response_rx.recv() => {
log::info!("Received response from tun device with tag: {tag}");
self.peers_by_tag.lock().unwrap().get(&tag).and_then(|tx| {
tx.send(Event::Ip(packet.into())).tap_err(|e| log::error!("{e}")).ok()
});
if let Some(tx) = self.peers_by_tag.lock().await.get(&tag) {
tx.send(Event::Ip(packet.into())).await.tap_err(|e| log::error!("{e}")).ok();
}
}
}
}
Expand Down
10 changes: 4 additions & 6 deletions common/wireguard/src/platform/linux/tun_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub struct TunDevice {
tun_task_response_tx: TunTaskResponseTx,

// The routing table, as how wireguard does it
peers_by_ip: Arc<std::sync::Mutex<PeersByIp>>,
peers_by_ip: Arc<tokio::sync::Mutex<PeersByIp>>,

// This is an alternative to the routing table, where we just match outgoing source IP with
// incoming destination IP.
Expand All @@ -52,7 +52,7 @@ pub struct TunDevice {

impl TunDevice {
pub fn new(
peers_by_ip: Arc<std::sync::Mutex<PeersByIp>>,
peers_by_ip: Arc<tokio::sync::Mutex<PeersByIp>>,
) -> (Self, TunTaskTx, TunTaskResponseRx) {
let tun = setup_tokio_tun_device(
format!("{TUN_BASE_NAME}%d").as_str(),
Expand Down Expand Up @@ -123,14 +123,12 @@ impl TunDevice {

// This is how wireguard does it, by consulting the AllowedIPs table.
if false {
let Ok(peers) = self.peers_by_ip.lock() else {
log::error!("Failed to lock peers_by_ip, aborting tun device read");
return;
};
let peers = self.peers_by_ip.lock().await;
if let Some(peer_tx) = peers.longest_match(dst_addr).map(|(_, tx)| tx) {
log::info!("Forward packet to wg tunnel");
peer_tx
.send(Event::Ip(packet.to_vec().into()))
.await
.tap_err(|err| log::error!("{err}"))
.ok();
return;
Expand Down
10 changes: 5 additions & 5 deletions common/wireguard/src/tun_task_channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ use tokio::sync::mpsc;
pub(crate) type TunTaskPayload = (u64, Vec<u8>);

#[derive(Clone)]
pub struct TunTaskTx(mpsc::UnboundedSender<TunTaskPayload>);
pub(crate) struct TunTaskRx(mpsc::UnboundedReceiver<TunTaskPayload>);
pub struct TunTaskTx(mpsc::Sender<TunTaskPayload>);
pub(crate) struct TunTaskRx(mpsc::Receiver<TunTaskPayload>);

impl TunTaskTx {
pub(crate) fn send(
pub(crate) async fn send(
&self,
data: TunTaskPayload,
) -> Result<(), tokio::sync::mpsc::error::SendError<TunTaskPayload>> {
self.0.send(data)
self.0.send(data).await
}
}

Expand All @@ -22,7 +22,7 @@ impl TunTaskRx {
}

pub(crate) fn tun_task_channel() -> (TunTaskTx, TunTaskRx) {
let (tun_task_tx, tun_task_rx) = tokio::sync::mpsc::unbounded_channel();
let (tun_task_tx, tun_task_rx) = tokio::sync::mpsc::channel(16);
(TunTaskTx(tun_task_tx), TunTaskRx(tun_task_rx))
}

Expand Down
15 changes: 9 additions & 6 deletions common/wireguard/src/udp_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ pub struct WgUdpListener {
registered_peers: RegisteredPeers,

// The routing table, as defined by wireguard
peers_by_ip: Arc<std::sync::Mutex<PeersByIp>>,
peers_by_ip: Arc<tokio::sync::Mutex<PeersByIp>>,

// ... or alternatively we can map peers by their tag
peers_by_tag: Arc<std::sync::Mutex<PeersByTag>>,
peers_by_tag: Arc<tokio::sync::Mutex<PeersByTag>>,

// The UDP socket to the peer
udp: Arc<UdpSocket>,
Expand All @@ -70,8 +70,8 @@ pub struct WgUdpListener {
impl WgUdpListener {
pub async fn new(
packet_tx: PacketRelaySender,
peers_by_ip: Arc<std::sync::Mutex<PeersByIp>>,
peers_by_tag: Arc<std::sync::Mutex<PeersByTag>>,
peers_by_ip: Arc<tokio::sync::Mutex<PeersByIp>>,
peers_by_tag: Arc<tokio::sync::Mutex<PeersByTag>>,
gateway_client_registry: Arc<GatewayClientRegistry>,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync + 'static>> {
let wg_address = SocketAddr::new(WG_ADDRESS.parse().unwrap(), WG_PORT);
Expand Down Expand Up @@ -143,6 +143,7 @@ impl WgUdpListener {
log::info!("udp: received {len} bytes from {addr} from known peer");
peer_tx
.send(Event::Wg(buf[..len].to_vec().into()))
.await
.tap_err(|e| log::error!("{e}"))
.ok();
continue;
Expand Down Expand Up @@ -186,6 +187,7 @@ impl WgUdpListener {
// We found the peer as connected, even though the addr was not known
log::info!("udp: received {len} bytes from {addr} which is a known peer with unknown addr");
peer_tx.send(Event::WgVerified(buf[..len].to_vec().into()))
.await
.tap_err(|err| log::error!("{err}"))
.ok();
} else {
Expand All @@ -205,10 +207,11 @@ impl WgUdpListener {
self.packet_tx.clone(),
);

self.peers_by_ip.lock().unwrap().insert(registered_peer.allowed_ips, peer_tx.clone());
self.peers_by_tag.lock().unwrap().insert(tag, peer_tx.clone());
self.peers_by_ip.lock().await.insert(registered_peer.allowed_ips, peer_tx.clone());
self.peers_by_tag.lock().await.insert(tag, peer_tx.clone());

peer_tx.send(Event::Wg(buf[..len].to_vec().into()))
.await
.tap_err(|e| log::error!("{e}"))
.ok();

Expand Down
Loading