Skip to content

Commit

Permalink
resolve host for every new connection in static tunnels
Browse files Browse the repository at this point in the history
  • Loading branch information
pufferffish committed Apr 4, 2022
1 parent 06d425b commit f637b0f
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 37 deletions.
4 changes: 1 addition & 3 deletions cmd/wireproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,9 @@ func main() {
exePath := executablePath()
unveilOrPanic("/", "r")
unveilOrPanic(exePath, "x")
if err := protect.UnveilBlock(); err != nil {
log.Fatal(err)
}

// only allow standard stdio operation, file reading, networking, and exec
// also remove unveil permission to lock unveil
pledgeOrPanic("stdio rpath inet dns proc exec")

isDaemonProcess := len(os.Args) > 1 && os.Args[1] == daemonProcess
Expand Down
89 changes: 55 additions & 34 deletions routine.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ type RoutineSpawner interface {
SpawnRoutine(vt *VirtualTun)
}

type addressPort struct {
address string
port uint16
}

// LookupAddr lookups a hostname.
// DNS traffic may or may not be routed depending on VirtualTun's setting
func (d VirtualTun) LookupAddr(ctx context.Context, name string) ([]string, error) {
Expand All @@ -47,29 +52,7 @@ func (d VirtualTun) LookupAddr(ctx context.Context, name string) ([]string, erro
}
}

// ResolveAddrPort resolves a hostname and returns an AddrPort.
// DNS traffic may or may not be routed depending on VirtualTun's setting
func (d VirtualTun) ResolveAddrPort(saddr string) (*netip.AddrPort, error) {
name, sport, err := net.SplitHostPort(saddr)
if err != nil {
return nil, err
}

addr, err := d.ResolveAddrWithContext(context.Background(), name)
if err != nil {
return nil, err
}

port, err := strconv.Atoi(sport)
if err != nil || port < 0 || port > 65535 {
return nil, &net.OpError{Op: "dial", Err: errors.New("port must be numeric")}
}

addrPort := netip.AddrPortFrom(*addr, uint16(port))
return &addrPort, nil
}

// ResolveAddrPort resolves a hostname and returns an AddrPort.
// ResolveAddrPortWithContext resolves a hostname and returns an AddrPort.
// DNS traffic may or may not be routed depending on VirtualTun's setting
func (d VirtualTun) ResolveAddrWithContext(ctx context.Context, name string) (*netip.Addr, error) {
addrs, err := d.LookupAddr(ctx, name)
Expand Down Expand Up @@ -101,7 +84,7 @@ func (d VirtualTun) ResolveAddrWithContext(ctx context.Context, name string) (*n
return &addr, nil
}

// ResolveAddrPort resolves a hostname and returns an IP.
// Resolve resolves a hostname and returns an IP.
// DNS traffic may or may not be routed depending on VirtualTun's setting
func (d VirtualTun) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
addr, err := d.ResolveAddrWithContext(ctx, name)
Expand All @@ -112,6 +95,30 @@ func (d VirtualTun) Resolve(ctx context.Context, name string) (context.Context,
return ctx, addr.AsSlice(), nil
}

func parseAddressPort(endpoint string) (*addressPort, error) {
name, sport, err := net.SplitHostPort(endpoint)
if err != nil {
return nil, err
}

port, err := strconv.Atoi(sport)
if err != nil || port < 0 || port > 65535 {
return nil, &net.OpError{Op: "dial", Err: errors.New("port must be numeric")}
}

return &addressPort{address: name, port: uint16(port)}, nil
}

func (d VirtualTun) resolveToAddrPort(endpoint *addressPort) (*netip.AddrPort, error) {
addr, err := d.ResolveAddrWithContext(context.Background(), endpoint.address)
if err != nil {
return nil, err
}

addrPort := netip.AddrPortFrom(*addr, endpoint.port)
return &addrPort, nil
}

// Spawns a socks5 server.
func (config *Socks5Config) SpawnRoutine(vt *VirtualTun) {
conf := &socks5.Config{Dial: vt.tnet.DialContext, Resolver: vt}
Expand Down Expand Up @@ -150,8 +157,16 @@ func connForward(bufSize int, from io.ReadWriteCloser, to io.ReadWriteCloser) {
}

// tcpClientForward starts a new connection via wireguard and forward traffic from `conn`
func tcpClientForward(tnet *netstack.Net, target *net.TCPAddr, conn net.Conn) {
sconn, err := tnet.DialTCP(target)
func tcpClientForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) {
target, err := vt.resolveToAddrPort(raddr)
if err != nil {
errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error())
return
}

tcpAddr := TCPAddrFromAddrPort(*target)

sconn, err := vt.tnet.DialTCP(tcpAddr)
if err != nil {
errorLogger.Printf("TCP Client Tunnel to %s: %s\n", target, err.Error())
return
Expand All @@ -163,11 +178,10 @@ func tcpClientForward(tnet *netstack.Net, target *net.TCPAddr, conn net.Conn) {

// Spawns a local TCP server which acts as a proxy to the specified target
func (conf *TCPClientTunnelConfig) SpawnRoutine(vt *VirtualTun) {
raddr, err := vt.ResolveAddrPort(conf.Target)
raddr, err := parseAddressPort(conf.Target)
if err != nil {
log.Fatal(err)
}
tcpAddr := TCPAddrFromAddrPort(*raddr)

server, err := net.ListenTCP("tcp", conf.BindAddress)
if err != nil {
Expand All @@ -179,13 +193,21 @@ func (conf *TCPClientTunnelConfig) SpawnRoutine(vt *VirtualTun) {
if err != nil {
log.Fatal(err)
}
go tcpClientForward(vt.tnet, tcpAddr, conn)
go tcpClientForward(vt, raddr, conn)
}
}

// tcpServerForward starts a new connection locally and forward traffic from `conn`
func tcpServerForward(target *net.TCPAddr, conn net.Conn) {
sconn, err := net.DialTCP("tcp", nil, target)
func tcpServerForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) {
target, err := vt.resolveToAddrPort(raddr)
if err != nil {
errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error())
return
}

tcpAddr := TCPAddrFromAddrPort(*target)

sconn, err := net.DialTCP("tcp", nil, tcpAddr)
if err != nil {
errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error())
return
Expand All @@ -197,11 +219,10 @@ func tcpServerForward(target *net.TCPAddr, conn net.Conn) {

// Spawns a TCP server on wireguard which acts as a proxy to the specified target
func (conf *TCPServerTunnelConfig) SpawnRoutine(vt *VirtualTun) {
raddr, err := vt.ResolveAddrPort(conf.Target)
raddr, err := parseAddressPort(conf.Target)
if err != nil {
log.Fatal(err)
}
tcpAddr := TCPAddrFromAddrPort(*raddr)

addr := &net.TCPAddr{Port: conf.ListenPort}
server, err := vt.tnet.ListenTCP(addr)
Expand All @@ -214,6 +235,6 @@ func (conf *TCPServerTunnelConfig) SpawnRoutine(vt *VirtualTun) {
if err != nil {
log.Fatal(err)
}
go tcpServerForward(tcpAddr, conn)
go tcpServerForward(vt, raddr, conn)
}
}

0 comments on commit f637b0f

Please sign in to comment.