Skip to content

Commit

Permalink
Improve Websocket implementation (#622)
Browse files Browse the repository at this point in the history
* Update examples modules

* Remove redundant websocket functions

* Add a custom per-request custom dialer

* Remove unused target URL

* Use configurable dialer also for WSS connections

* Add customizable TLS client initialization for each request
  • Loading branch information
ErikPelli authored Jan 12, 2025
1 parent 10fc34b commit dd21e8d
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 65 deletions.
6 changes: 6 additions & 0 deletions ctx.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package goproxy

import (
"context"
"crypto/tls"
"mime"
"net"
"net/http"
)

Expand All @@ -14,6 +16,10 @@ type ProxyCtx struct {
// Will contain the remote server's response (if available. nil if the request wasn't send yet)
Resp *http.Response
RoundTripper RoundTripper
// Specify a custom connection dialer that will be used only for the current
// request, including WebSocket connection upgrades
Dialer func(ctx context.Context, network string, addr string) (net.Conn, error)
InitializeTLS func(rawConn net.Conn, cfg *tls.Config) (net.Conn, error)
// will contain the recent error that occurred while trying to send receive or parse traffic
Error error
// A handle for the user to keep data in the context, from the call of ReqHandler to the
Expand Down
7 changes: 3 additions & 4 deletions examples/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@ module github.com/elazarl/goproxy/examples/goproxy-transparent
go 1.20

require (
github.com/elazarl/goproxy v0.0.0-20241217120900-7711dfa3811c
github.com/elazarl/goproxy/ext v0.0.0-20241217120900-7711dfa3811c
github.com/elazarl/goproxy v1.3.0
github.com/elazarl/goproxy/ext v0.0.0-20250110140559-10fc34b80676
github.com/gorilla/websocket v1.5.3
github.com/inconshreveable/go-vhost v1.0.0
)

require (
github.com/rogpeppe/go-charset v0.0.0-20190617161244-0dc95cdf6f31 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/net v0.34.0 // indirect
golang.org/x/text v0.21.0 // indirect
)

Expand Down
15 changes: 8 additions & 7 deletions examples/go.sum
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
github.com/elazarl/goproxy/ext v0.0.0-20241217120900-7711dfa3811c h1:R+i10jtNSzKJKqEZAYJnR9M8y14k0zrNHqD1xkv/A2M=
github.com/elazarl/goproxy/ext v0.0.0-20241217120900-7711dfa3811c/go.mod h1:gNh8nYJoAm43RfaxurUnxr+N1PwuFV3ZMl/efxlIlY8=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/elazarl/goproxy/ext v0.0.0-20250110140559-10fc34b80676 h1:3bAtOWqImclW/5rXbhNyAcM122jafst+/+4J4vC8wZI=
github.com/elazarl/goproxy/ext v0.0.0-20250110140559-10fc34b80676/go.mod h1:q2JQCFWg+AQfe6O2cbf7LJDB48R68w+q0pBU53v02iM=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/inconshreveable/go-vhost v1.0.0 h1:IK4VZTlXL4l9vz2IZoiSFbYaaqUW7dXJAiPriUN5Ur8=
github.com/inconshreveable/go-vhost v1.0.0/go.mod h1:aA6DnFhALT3zH0y+A39we+zbrdMC2N0X/q21e6FI0LU=
github.com/rogpeppe/go-charset v0.0.0-20180617210344-2471d30d28b4/go.mod h1:qgYeAmZ5ZIpBWTGllZSQnw97Dj+woV0toclVaRGI8pc=
github.com/rogpeppe/go-charset v0.0.0-20190617161244-0dc95cdf6f31 h1:DE4LcMKyqAVa6a0CGmVxANbnVb7stzMmPkQiieyNmfQ=
github.com/rogpeppe/go-charset v0.0.0-20190617161244-0dc95cdf6f31/go.mod h1:qgYeAmZ5ZIpBWTGllZSQnw97Dj+woV0toclVaRGI8pc=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
4 changes: 3 additions & 1 deletion http.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ func (proxy *ProxyHttpServer) handleHttp(w http.ResponseWriter, r *http.Request)
if resp == nil {
if isWebSocketRequest(r) {
ctx.Logf("Request looks like websocket upgrade.")
proxy.serveWebsocket(ctx, w, r)
if conn, err := proxy.hijackConnection(ctx, w); err == nil {
proxy.serveWebsocket(ctx, conn, r)
}
}

if !proxy.KeepHeader {
Expand Down
41 changes: 32 additions & 9 deletions https.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package goproxy

import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
Expand Down Expand Up @@ -72,16 +71,23 @@ func stripPort(s string) string {
return s[:ix]
}

func (proxy *ProxyHttpServer) dial(ctx context.Context, network, addr string) (c net.Conn, err error) {
func (proxy *ProxyHttpServer) dial(ctx *ProxyCtx, network, addr string) (c net.Conn, err error) {
if ctx.Dialer != nil {
return ctx.Dialer(ctx.Req.Context(), network, addr)
}

if proxy.Tr.DialContext != nil {
return proxy.Tr.DialContext(ctx, network, addr)
return proxy.Tr.DialContext(ctx.Req.Context(), network, addr)
}

// if the user didn't specify any dialer, we just use the default one,
// provided by net package
return net.Dial(network, addr)
}

func (proxy *ProxyHttpServer) connectDial(ctx *ProxyCtx, network, addr string) (c net.Conn, err error) {
if proxy.ConnectDialWithReq == nil && proxy.ConnectDial == nil {
return proxy.dial(ctx.Req.Context(), network, addr)
return proxy.dial(ctx, network, addr)
}

if proxy.ConnectDialWithReq != nil {
Expand Down Expand Up @@ -340,9 +346,9 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
ctx.Logf("Request looks like websocket upgrade.")
if req.URL.Scheme == "http" {
ctx.Logf("Enforced HTTP websocket forwarding over TLS")
proxy.serveWebsocketHttpOverTLS(ctx, w, req, rawClientTls)
proxy.serveWebsocket(ctx, rawClientTls, req)
} else {
proxy.serveWebsocketTLS(ctx, w, req, tlsConfig, rawClientTls)
proxy.serveWebsocketTLS(ctx, req, tlsConfig, rawClientTls)
}
return
}
Expand Down Expand Up @@ -533,7 +539,7 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
if connectReqHandler != nil {
connectReqHandler(connectReq)
}
c, err := proxy.dial(context.Background(), network, u.Host)
c, err := proxy.dial(&ProxyCtx{Req: &http.Request{}}, network, u.Host)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -564,11 +570,17 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
u.Host += ":443"
}
return func(network, addr string) (net.Conn, error) {
c, err := proxy.dial(context.Background(), network, u.Host)
ctx := &ProxyCtx{Req: &http.Request{}}
c, err := proxy.dial(ctx, network, u.Host)
if err != nil {
return nil, err
}
c = tls.Client(c, proxy.Tr.TLSClientConfig)

c, err = proxy.initializeTLSconnection(ctx, c, proxy.Tr.TLSClientConfig)
if err != nil {
return nil, err
}

connectReq := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Opaque: addr},
Expand Down Expand Up @@ -630,3 +642,14 @@ func TLSConfigFromCA(ca *tls.Certificate) func(host string, ctx *ProxyCtx) (*tls
return config, nil
}
}

func (proxy *ProxyHttpServer) initializeTLSconnection(
ctx *ProxyCtx,
targetConn net.Conn,
tlsConfig *tls.Config,
) (net.Conn, error) {
if ctx.InitializeTLS != nil {
return ctx.InitializeTLS(targetConn, tlsConfig)
}
return tls.Client(targetConn, tlsConfig), nil
}
69 changes: 25 additions & 44 deletions websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"io"
"net"
"net/http"
"net/url"
"strings"
)

Expand All @@ -28,60 +27,32 @@ func isWebSocketRequest(r *http.Request) bool {

func (proxy *ProxyHttpServer) serveWebsocketTLS(
ctx *ProxyCtx,
w http.ResponseWriter,
req *http.Request,
tlsConfig *tls.Config,
clientConn *tls.Conn,
) {
// wss
host := req.URL.Host
// Port is optional in req.URL.Host, in this case SplitHostPort returns
// an error, and we add the default port
_, port, err := net.SplitHostPort(req.URL.Host)
if err != nil || port == "" {
host = net.JoinHostPort(req.URL.Host, "443")
}
targetURL := url.URL{Scheme: "wss", Host: host, Path: req.URL.Path}

// Connect to upstream
targetConn, err := tls.Dial("tcp", targetURL.Host, tlsConfig)
targetConn, err := proxy.connectDial(ctx, "tcp", host)
if err != nil {
ctx.Warnf("Error dialing target site: %v", err)
return
}
defer targetConn.Close()

// Perform handshake
if err := proxy.websocketHandshake(ctx, req, targetConn, clientConn); err != nil {
ctx.Warnf("Websocket handshake error: %v", err)
return
}

// Proxy wss connection
proxy.proxyWebsocket(ctx, targetConn, clientConn)
}

func (proxy *ProxyHttpServer) serveWebsocketHttpOverTLS(
ctx *ProxyCtx,
w http.ResponseWriter,
req *http.Request,
clientConn *tls.Conn,
) {
host := req.URL.Host
// Port is optional in req.URL.Host, in this case SplitHostPort returns
// an error, and we add the default port
_, port, err := net.SplitHostPort(req.URL.Host)
if err != nil || port == "" {
host = net.JoinHostPort(req.URL.Host, "80")
}
targetURL := url.URL{Scheme: "ws", Host: host, Path: req.URL.Path}

// Connect to upstream
targetConn, err := proxy.connectDial(ctx, "tcp", targetURL.Host)
// Add TLS to the raw TCP connection
targetConn, err = proxy.initializeTLSconnection(ctx, targetConn, tlsConfig)
if err != nil {
ctx.Warnf("Error dialing target site: %v", err)
ctx.Warnf("Websocket TLS connection error: %v", err)
return
}
defer targetConn.Close()

// Perform handshake
if err := proxy.websocketHandshake(ctx, req, targetConn, clientConn); err != nil {
Expand All @@ -93,16 +64,7 @@ func (proxy *ProxyHttpServer) serveWebsocketHttpOverTLS(
proxy.proxyWebsocket(ctx, targetConn, clientConn)
}

func (proxy *ProxyHttpServer) serveWebsocket(ctx *ProxyCtx, w http.ResponseWriter, req *http.Request) {
targetURL := url.URL{Scheme: "ws", Host: req.URL.Host, Path: req.URL.Path}

targetConn, err := proxy.connectDial(ctx, "tcp", targetURL.Host)
if err != nil {
ctx.Warnf("Error dialing target site: %v", err)
return
}
defer targetConn.Close()

func (proxy *ProxyHttpServer) hijackConnection(ctx *ProxyCtx, w http.ResponseWriter) (net.Conn, error) {
// Connect to Client
hj, ok := w.(http.Hijacker)
if !ok {
Expand All @@ -111,8 +73,27 @@ func (proxy *ProxyHttpServer) serveWebsocket(ctx *ProxyCtx, w http.ResponseWrite
clientConn, _, err := hj.Hijack()
if err != nil {
ctx.Warnf("Hijack error: %v", err)
return nil, err
}
return clientConn, nil
}

func (proxy *ProxyHttpServer) serveWebsocket(ctx *ProxyCtx, clientConn net.Conn, req *http.Request) {
// ws
host := req.URL.Host
// Port is optional in req.URL.Host, in this case SplitHostPort returns
// an error, and we add the default port
_, port, err := net.SplitHostPort(req.URL.Host)
if err != nil || port == "" {
host = net.JoinHostPort(req.URL.Host, "80")
}

targetConn, err := proxy.connectDial(ctx, "tcp", host)
if err != nil {
ctx.Warnf("Error dialing target site: %v", err)
return
}
defer targetConn.Close()

// Perform handshake
if err := proxy.websocketHandshake(ctx, req, targetConn, clientConn); err != nil {
Expand Down

0 comments on commit dd21e8d

Please sign in to comment.