diff --git a/ctx.go b/ctx.go index 27007bfa..32a3434f 100644 --- a/ctx.go +++ b/ctx.go @@ -1,8 +1,10 @@ package goproxy import ( + "context" "crypto/tls" "mime" + "net" "net/http" ) @@ -14,6 +16,10 @@ type ProxyCtx struct { // Will contain the remote server's response (if available. nil if the request wasn't send yet) Resp *http.Response RoundTripper RoundTripper + // Specify a custom connection dialer that will be used only for the current + // request, including WebSocket connection upgrades + Dialer func(ctx context.Context, network string, addr string) (net.Conn, error) + InitializeTLS func(rawConn net.Conn, cfg *tls.Config) (net.Conn, error) // will contain the recent error that occurred while trying to send receive or parse traffic Error error // A handle for the user to keep data in the context, from the call of ReqHandler to the diff --git a/examples/go.mod b/examples/go.mod index da68a941..f171aec1 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -3,15 +3,14 @@ module github.com/elazarl/goproxy/examples/goproxy-transparent go 1.20 require ( - github.com/elazarl/goproxy v0.0.0-20241217120900-7711dfa3811c - github.com/elazarl/goproxy/ext v0.0.0-20241217120900-7711dfa3811c + github.com/elazarl/goproxy v1.3.0 + github.com/elazarl/goproxy/ext v0.0.0-20250110140559-10fc34b80676 github.com/gorilla/websocket v1.5.3 github.com/inconshreveable/go-vhost v1.0.0 ) require ( - github.com/rogpeppe/go-charset v0.0.0-20190617161244-0dc95cdf6f31 // indirect - golang.org/x/net v0.33.0 // indirect + golang.org/x/net v0.34.0 // indirect golang.org/x/text v0.21.0 // indirect ) diff --git a/examples/go.sum b/examples/go.sum index 8aa1074a..aace2ae7 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -1,13 +1,14 @@ -github.com/elazarl/goproxy/ext v0.0.0-20241217120900-7711dfa3811c h1:R+i10jtNSzKJKqEZAYJnR9M8y14k0zrNHqD1xkv/A2M= -github.com/elazarl/goproxy/ext v0.0.0-20241217120900-7711dfa3811c/go.mod h1:gNh8nYJoAm43RfaxurUnxr+N1PwuFV3ZMl/efxlIlY8= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/elazarl/goproxy/ext v0.0.0-20250110140559-10fc34b80676 h1:3bAtOWqImclW/5rXbhNyAcM122jafst+/+4J4vC8wZI= +github.com/elazarl/goproxy/ext v0.0.0-20250110140559-10fc34b80676/go.mod h1:q2JQCFWg+AQfe6O2cbf7LJDB48R68w+q0pBU53v02iM= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/inconshreveable/go-vhost v1.0.0 h1:IK4VZTlXL4l9vz2IZoiSFbYaaqUW7dXJAiPriUN5Ur8= github.com/inconshreveable/go-vhost v1.0.0/go.mod h1:aA6DnFhALT3zH0y+A39we+zbrdMC2N0X/q21e6FI0LU= -github.com/rogpeppe/go-charset v0.0.0-20180617210344-2471d30d28b4/go.mod h1:qgYeAmZ5ZIpBWTGllZSQnw97Dj+woV0toclVaRGI8pc= -github.com/rogpeppe/go-charset v0.0.0-20190617161244-0dc95cdf6f31 h1:DE4LcMKyqAVa6a0CGmVxANbnVb7stzMmPkQiieyNmfQ= -github.com/rogpeppe/go-charset v0.0.0-20190617161244-0dc95cdf6f31/go.mod h1:qgYeAmZ5ZIpBWTGllZSQnw97Dj+woV0toclVaRGI8pc= -golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= -golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/http.go b/http.go index 63050270..38a3be5b 100644 --- a/http.go +++ b/http.go @@ -21,7 +21,9 @@ func (proxy *ProxyHttpServer) handleHttp(w http.ResponseWriter, r *http.Request) if resp == nil { if isWebSocketRequest(r) { ctx.Logf("Request looks like websocket upgrade.") - proxy.serveWebsocket(ctx, w, r) + if conn, err := proxy.hijackConnection(ctx, w); err == nil { + proxy.serveWebsocket(ctx, conn, r) + } } if !proxy.KeepHeader { diff --git a/https.go b/https.go index 3ea98873..0af46222 100644 --- a/https.go +++ b/https.go @@ -2,7 +2,6 @@ package goproxy import ( "bufio" - "context" "crypto/tls" "errors" "fmt" @@ -72,16 +71,23 @@ func stripPort(s string) string { return s[:ix] } -func (proxy *ProxyHttpServer) dial(ctx context.Context, network, addr string) (c net.Conn, err error) { +func (proxy *ProxyHttpServer) dial(ctx *ProxyCtx, network, addr string) (c net.Conn, err error) { + if ctx.Dialer != nil { + return ctx.Dialer(ctx.Req.Context(), network, addr) + } + if proxy.Tr.DialContext != nil { - return proxy.Tr.DialContext(ctx, network, addr) + return proxy.Tr.DialContext(ctx.Req.Context(), network, addr) } + + // if the user didn't specify any dialer, we just use the default one, + // provided by net package return net.Dial(network, addr) } func (proxy *ProxyHttpServer) connectDial(ctx *ProxyCtx, network, addr string) (c net.Conn, err error) { if proxy.ConnectDialWithReq == nil && proxy.ConnectDial == nil { - return proxy.dial(ctx.Req.Context(), network, addr) + return proxy.dial(ctx, network, addr) } if proxy.ConnectDialWithReq != nil { @@ -340,9 +346,9 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request ctx.Logf("Request looks like websocket upgrade.") if req.URL.Scheme == "http" { ctx.Logf("Enforced HTTP websocket forwarding over TLS") - proxy.serveWebsocketHttpOverTLS(ctx, w, req, rawClientTls) + proxy.serveWebsocket(ctx, rawClientTls, req) } else { - proxy.serveWebsocketTLS(ctx, w, req, tlsConfig, rawClientTls) + proxy.serveWebsocketTLS(ctx, req, tlsConfig, rawClientTls) } return } @@ -533,7 +539,7 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler( if connectReqHandler != nil { connectReqHandler(connectReq) } - c, err := proxy.dial(context.Background(), network, u.Host) + c, err := proxy.dial(&ProxyCtx{Req: &http.Request{}}, network, u.Host) if err != nil { return nil, err } @@ -564,11 +570,17 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler( u.Host += ":443" } return func(network, addr string) (net.Conn, error) { - c, err := proxy.dial(context.Background(), network, u.Host) + ctx := &ProxyCtx{Req: &http.Request{}} + c, err := proxy.dial(ctx, network, u.Host) if err != nil { return nil, err } - c = tls.Client(c, proxy.Tr.TLSClientConfig) + + c, err = proxy.initializeTLSconnection(ctx, c, proxy.Tr.TLSClientConfig) + if err != nil { + return nil, err + } + connectReq := &http.Request{ Method: http.MethodConnect, URL: &url.URL{Opaque: addr}, @@ -630,3 +642,14 @@ func TLSConfigFromCA(ca *tls.Certificate) func(host string, ctx *ProxyCtx) (*tls return config, nil } } + +func (proxy *ProxyHttpServer) initializeTLSconnection( + ctx *ProxyCtx, + targetConn net.Conn, + tlsConfig *tls.Config, +) (net.Conn, error) { + if ctx.InitializeTLS != nil { + return ctx.InitializeTLS(targetConn, tlsConfig) + } + return tls.Client(targetConn, tlsConfig), nil +} diff --git a/websocket.go b/websocket.go index 97108ee5..a376e879 100644 --- a/websocket.go +++ b/websocket.go @@ -6,7 +6,6 @@ import ( "io" "net" "net/http" - "net/url" "strings" ) @@ -28,11 +27,11 @@ func isWebSocketRequest(r *http.Request) bool { func (proxy *ProxyHttpServer) serveWebsocketTLS( ctx *ProxyCtx, - w http.ResponseWriter, req *http.Request, tlsConfig *tls.Config, clientConn *tls.Conn, ) { + // wss host := req.URL.Host // Port is optional in req.URL.Host, in this case SplitHostPort returns // an error, and we add the default port @@ -40,48 +39,20 @@ func (proxy *ProxyHttpServer) serveWebsocketTLS( if err != nil || port == "" { host = net.JoinHostPort(req.URL.Host, "443") } - targetURL := url.URL{Scheme: "wss", Host: host, Path: req.URL.Path} - // Connect to upstream - targetConn, err := tls.Dial("tcp", targetURL.Host, tlsConfig) + targetConn, err := proxy.connectDial(ctx, "tcp", host) if err != nil { ctx.Warnf("Error dialing target site: %v", err) return } defer targetConn.Close() - // Perform handshake - if err := proxy.websocketHandshake(ctx, req, targetConn, clientConn); err != nil { - ctx.Warnf("Websocket handshake error: %v", err) - return - } - - // Proxy wss connection - proxy.proxyWebsocket(ctx, targetConn, clientConn) -} - -func (proxy *ProxyHttpServer) serveWebsocketHttpOverTLS( - ctx *ProxyCtx, - w http.ResponseWriter, - req *http.Request, - clientConn *tls.Conn, -) { - host := req.URL.Host - // Port is optional in req.URL.Host, in this case SplitHostPort returns - // an error, and we add the default port - _, port, err := net.SplitHostPort(req.URL.Host) - if err != nil || port == "" { - host = net.JoinHostPort(req.URL.Host, "80") - } - targetURL := url.URL{Scheme: "ws", Host: host, Path: req.URL.Path} - - // Connect to upstream - targetConn, err := proxy.connectDial(ctx, "tcp", targetURL.Host) + // Add TLS to the raw TCP connection + targetConn, err = proxy.initializeTLSconnection(ctx, targetConn, tlsConfig) if err != nil { - ctx.Warnf("Error dialing target site: %v", err) + ctx.Warnf("Websocket TLS connection error: %v", err) return } - defer targetConn.Close() // Perform handshake if err := proxy.websocketHandshake(ctx, req, targetConn, clientConn); err != nil { @@ -93,16 +64,7 @@ func (proxy *ProxyHttpServer) serveWebsocketHttpOverTLS( proxy.proxyWebsocket(ctx, targetConn, clientConn) } -func (proxy *ProxyHttpServer) serveWebsocket(ctx *ProxyCtx, w http.ResponseWriter, req *http.Request) { - targetURL := url.URL{Scheme: "ws", Host: req.URL.Host, Path: req.URL.Path} - - targetConn, err := proxy.connectDial(ctx, "tcp", targetURL.Host) - if err != nil { - ctx.Warnf("Error dialing target site: %v", err) - return - } - defer targetConn.Close() - +func (proxy *ProxyHttpServer) hijackConnection(ctx *ProxyCtx, w http.ResponseWriter) (net.Conn, error) { // Connect to Client hj, ok := w.(http.Hijacker) if !ok { @@ -111,8 +73,27 @@ func (proxy *ProxyHttpServer) serveWebsocket(ctx *ProxyCtx, w http.ResponseWrite clientConn, _, err := hj.Hijack() if err != nil { ctx.Warnf("Hijack error: %v", err) + return nil, err + } + return clientConn, nil +} + +func (proxy *ProxyHttpServer) serveWebsocket(ctx *ProxyCtx, clientConn net.Conn, req *http.Request) { + // ws + host := req.URL.Host + // Port is optional in req.URL.Host, in this case SplitHostPort returns + // an error, and we add the default port + _, port, err := net.SplitHostPort(req.URL.Host) + if err != nil || port == "" { + host = net.JoinHostPort(req.URL.Host, "80") + } + + targetConn, err := proxy.connectDial(ctx, "tcp", host) + if err != nil { + ctx.Warnf("Error dialing target site: %v", err) return } + defer targetConn.Close() // Perform handshake if err := proxy.websocketHandshake(ctx, req, targetConn, clientConn); err != nil {