Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use TUN device for forwarding wireguard traffic #3902

Merged
merged 11 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion common/wireguard/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ base64 = "0.21.3"
#boringtun = "0.6.0"
boringtun = { git = "https://github.com/cloudflare/boringtun", rev = "e1d6360d6ab4529fc942a078e4c54df107abe2ba" }
bytes = "1.5.0"
dashmap = "5.5.3"
etherparse = "0.13.0"
futures = "0.3.28"
log.workspace = true
nym-task = { path = "../task" }
tap.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["rt-multi-thread", "net"]}
tokio = { workspace = true, features = ["rt-multi-thread", "net", "io-util"] }
tokio-tun = "0.9.0"
5 changes: 0 additions & 5 deletions common/wireguard/src/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ use bytes::Bytes;
#[allow(unused)]
#[derive(Debug, Clone)]
pub enum Event {
/// Dumb event with no data.
Dumb,
/// IP packet received from the WireGuard tunnel that should be passed through to the corresponding virtual device/internet.
/// Original implementation also has protocol here since it understands it, but we'll have to infer it downstream
WgPacket(Bytes),
Expand All @@ -17,9 +15,6 @@ pub enum Event {
impl Display for Event {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Event::Dumb => {
write!(f, "Dumb{{}}")
}
Event::WgPacket(data) => {
let size = data.len();
write!(f, "WgPacket{{ size={size} }}")
Expand Down
161 changes: 139 additions & 22 deletions common/wireguard/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use std::{
net::{Ipv4Addr, SocketAddr},
sync::Arc,
};

use base64::{engine::general_purpose, Engine as _};
use boringtun::x25519;
use dashmap::DashMap;
use etherparse::{InternetSlice, SlicedPacket};
use futures::StreamExt;
use log::{error, info};
use nym_task::TaskClient;
use tap::TapFallible;
use tokio::{net::UdpSocket, sync::mpsc, task::JoinHandle};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::UdpSocket,
sync::mpsc::{self, UnboundedSender},
task::JoinHandle,
};
use tun::WireGuardTunnel;

use crate::event::Event;
Expand All @@ -17,8 +27,14 @@ mod error;
mod event;
mod tun;

//const WG_ADDRESS = "0.0.0.0:51820";
const WG_ADDRESS: &str = "0.0.0.0:51822";
// The wireguard UDP listener
const WG_ADDRESS: &str = "0.0.0.0";
const WG_PORT: u16 = 51822;

// The interface used to route traffic
const TUN_BASE_NAME: &str = "nymtun";
const TUN_DEVICE_ADDRESS: &str = "10.0.0.1";
const TUN_DEVICE_NETMASK: &str = "255.255.255.0";

// The private key of the listener
// Corresponding public key: "WM8s8bYegwMa0TJ+xIwhk+dImk2IpDUKslDBCZPizlE="
Expand All @@ -27,11 +43,15 @@ const PRIVATE_KEY: &str = "AEqXrLFT4qjYq3wmX0456iv94uM6nDj5ugp6Jedcflg=";
// The public keys of the registered peers (clients)
const PEERS: &[&str; 1] = &[
// Corresponding private key: "ILeN6gEh6vJ3Ju8RJ3HVswz+sPgkcKtAYTqzQRhTtlo="
"NCIhkgiqxFx1ckKl3Zuh595DzIFl8mxju1Vg995EZhI=", // "mxV/mw7WZTe+0Msa0kvJHMHERDA/cSskiZWQce+TdEs=",
"NCIhkgiqxFx1ckKl3Zuh595DzIFl8mxju1Vg995EZhI=",
// Another key
// "mxV/mw7WZTe+0Msa0kvJHMHERDA/cSskiZWQce+TdEs=",
];

const MAX_PACKET: usize = 65535;

type ActivePeers = DashMap<SocketAddr, mpsc::UnboundedSender<Event>>;

fn init_static_dev_keys() -> (x25519::StaticSecret, x25519::PublicKey) {
// TODO: this is a temporary solution for development
let static_private_bytes: [u8; 32] = general_purpose::STANDARD
Expand All @@ -58,31 +78,109 @@ fn init_static_dev_keys() -> (x25519::StaticSecret, x25519::PublicKey) {
}

fn start_wg_tunnel(
addr: SocketAddr,
endpoint: SocketAddr,
udp: Arc<UdpSocket>,
static_private: x25519::StaticSecret,
peer_static_public: x25519::PublicKey,
tunnel_tx: UnboundedSender<Vec<u8>>,
) -> (JoinHandle<SocketAddr>, mpsc::UnboundedSender<Event>) {
let (mut tunnel, peer_tx) = WireGuardTunnel::new(udp, addr, static_private, peer_static_public);
let (mut tunnel, peer_tx) =
WireGuardTunnel::new(udp, endpoint, static_private, peer_static_public, tunnel_tx);
let join_handle = tokio::spawn(async move {
tunnel.spin_off().await;
addr
endpoint
});
(join_handle, peer_tx)
}

pub async fn start_wg_listener(
fn setup_tokio_tun_device(name: &str, address: Ipv4Addr, netmask: Ipv4Addr) -> tokio_tun::Tun {
log::info!("Creating TUN device with: address={address}, netmask={netmask}");
tokio_tun::Tun::builder()
.name(name)
.tap(false)
.packet_info(false)
.mtu(1350)
.up()
.address(address)
.netmask(netmask)
.try_build()
.expect("Failed to setup tun device, do you have permission?")
}

fn start_tun_device(_active_peers: Arc<ActivePeers>) -> UnboundedSender<Vec<u8>> {
let tun = setup_tokio_tun_device(
format!("{}%d", TUN_BASE_NAME).as_str(),
TUN_DEVICE_ADDRESS.parse().unwrap(),
TUN_DEVICE_NETMASK.parse().unwrap(),
);
log::info!("Created TUN device: {}", tun.name());

let (mut tun_device_rx, mut tun_device_tx) = tokio::io::split(tun);

// Channels to communicate with the other tasks
let (tun_task_tx, mut tun_task_rx) = mpsc::unbounded_channel::<Vec<u8>>();

tokio::spawn(async move {
let mut buf = [0u8; 1024];
loop {
tokio::select! {
// Reading from the TUN device
len = tun_device_rx.read(&mut buf) => match len {
Ok(len) => {
let packet = &buf[..len];
let dst_addr = boringtun::noise::Tunn::dst_address(packet).unwrap();

let headers = SlicedPacket::from_ip(packet).unwrap();
let src_addr = match headers.ip.unwrap() {
InternetSlice::Ipv4(ip, _) => ip.source_addr().to_string(),
InternetSlice::Ipv6(ip, _) => ip.source_addr().to_string(),
};
log::info!("iface: read Packet({src_addr} -> {dst_addr}, {len} bytes)");

// TODO: route packet to the correct peer.
log::info!("...forward packet to the correct peer (NOT YET IMPLEMENTED)");
},
Err(err) => {
log::info!("iface: read error: {err}");
break;
}
},

// Writing to the TUN device
Some(data) = tun_task_rx.recv() => {
let headers = SlicedPacket::from_ip(&data).unwrap();
let (source_addr, destination_addr) = match headers.ip.unwrap() {
InternetSlice::Ipv4(ip, _) => (ip.source_addr(), ip.destination_addr()),
InternetSlice::Ipv6(_, _) => unimplemented!(),
};

log::info!(
"iface: write Packet({source_addr} -> {destination_addr}, {} bytes)",
data.len()
);
// log::info!("iface: writing {} bytes", data.len());
tun_device_tx.write_all(&data).await.unwrap();
}
}
}
log::info!("TUN device shutting down");
});
tun_task_tx
}

async fn start_udp_listener(
tun_task_tx: UnboundedSender<Vec<u8>>,
active_peers: Arc<ActivePeers>,
mut task_client: TaskClient,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
log::info!("Starting wireguard listener on {}", WG_ADDRESS);
let udp_socket = Arc::new(UdpSocket::bind(WG_ADDRESS).await?);
let wg_address = SocketAddr::new(WG_ADDRESS.parse().unwrap(), WG_PORT);
log::info!("Starting wireguard UDP listener on {wg_address}");
let udp_socket = Arc::new(UdpSocket::bind(wg_address).await?);

// Setup some static keys for development
let (static_private, peer_static_public) = init_static_dev_keys();

tokio::spawn(async move {
// The set of active tunnels indexed by the peer's address
let mut active_peers: HashMap<SocketAddr, mpsc::UnboundedSender<Event>> = HashMap::new();
// Each tunnel is run in its own task, and the task handle is stored here so we can remove
// it from `active_peers` when the tunnel is closed
let mut active_peers_task_handles = futures::stream::FuturesUnordered::new();
Expand All @@ -91,50 +189,69 @@ pub async fn start_wg_listener(
while !task_client.is_shutdown() {
tokio::select! {
_ = task_client.recv() => {
log::trace!("WireGuard listener: received shutdown");
log::trace!("WireGuard UDP listener: received shutdown");
break;
}
// Handle tunnel closing
Some(addr) = active_peers_task_handles.next() => {
match addr {
Ok(addr) => {
info!("WireGuard listener: closed {addr:?}");
log::info!("Removing peer: {addr:?}");
active_peers.remove(&addr);
}
Err(err) => {
error!("WireGuard listener: error receiving shutdown from peer: {err}");
error!("WireGuard UDP listener: error receiving shutdown from peer: {err}");
}
}
}
},
// Handle incoming packets
Ok((len, addr)) = udp_socket.recv_from(&mut buf) => {
log::info!("Received {} bytes from {}", len, addr);
log::trace!("udp: received {} bytes from {}", len, addr);

if let Some(peer_tx) = active_peers.get_mut(&addr) {
log::info!("WireGuard listener: received packet from known peer");
log::info!("udp: received {len} bytes from {addr} from known peer");
peer_tx.send(Event::WgPacket(buf[..len].to_vec().into()))
.tap_err(|err| log::error!("{err}"))
.unwrap();
} else {
log::info!("WireGuard listener: received packet from unknown peer, starting tunnel");
log::info!("udp: received {len} bytes from {addr} from unknown peer, starting tunnel");
let (join_handle, peer_tx) = start_wg_tunnel(
addr,
udp_socket.clone(),
static_private.clone(),
peer_static_public
peer_static_public,
tun_task_tx.clone(),
);
peer_tx.send(Event::WgPacket(buf[..len].to_vec().into()))
.tap_err(|err| log::error!("{err}"))
.unwrap();

// WIP(JON): active peers should probably be keyed by peer_static_public
// instead. Does this current setup lead to any issues?
log::info!("Adding peer: {addr}");
active_peers.insert(addr, peer_tx);
active_peers_task_handles.push(join_handle);
}
}
},
}
}
log::info!("WireGuard listener: shutting down");
});

Ok(())
}

pub async fn start_wireguard(
task_client: TaskClient,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
// The set of active tunnels indexed by the peer's address
let active_peers: Arc<ActivePeers> = Arc::new(ActivePeers::new());

// Start the tun device that is used to relay traffic outbound
let tun_task_tx = start_tun_device(active_peers.clone());

// Start the UDP listener that clients connect to
start_udp_listener(tun_task_tx, active_peers, task_client).await?;

Ok(())
}
Loading
Loading