From 3d6017ab75a37b114bd4e999291933d5cad683e4 Mon Sep 17 00:00:00 2001 From: Erik Pellizzon Date: Tue, 31 Dec 2024 15:40:23 +0100 Subject: [PATCH] Prevent Request headers canonicalization (#607) * Add header extraction with tests * Add prevent canonicalization implementation * Update documentation * Fix linter issues * Remove redundant body read * Avoid to duplicate read data if prevent canonicalization is set to false * Copy https mitm fixes to http mitm * Prevent response body memory leak by making sure to close it when we are done * Add additional request parsing tests * Rely on stdlib header parsing implementation to extract header names --- .golangci.yml | 2 +- README.md | 1 + go.mod | 12 +- go.sum | 10 ++ https.go | 153 ++++++++++++-------- internal/http1parser/header.go | 43 ++++++ internal/http1parser/header_test.go | 48 +++++++ internal/http1parser/request.go | 94 ++++++++++++ internal/http1parser/request_test.go | 204 +++++++++++++++++++++++++++ proxy.go | 13 +- 10 files changed, 508 insertions(+), 72 deletions(-) create mode 100644 internal/http1parser/header.go create mode 100644 internal/http1parser/header_test.go create mode 100644 internal/http1parser/request.go create mode 100644 internal/http1parser/request_test.go diff --git a/.golangci.yml b/.golangci.yml index 1deff156..d51cb34a 100755 --- a/.golangci.yml +++ b/.golangci.yml @@ -20,7 +20,6 @@ linters: - containedctx - decorder - dogsled - - dupl - durationcheck - errchkjson - errname @@ -80,6 +79,7 @@ linters: - copyloopvar - cyclop - depguard + - dupl - dupword - err113 - exhaustruct diff --git a/README.md b/README.md index 405ada64..367fae26 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ proxy is `localhost:8080`, which is the default one in our example. - You can specify a `MITM certificates cache`, to reuse them later for other requests to the same host, thus saving CPU. Not enabled by default, but you should use it in production! - Redirect normal HTTP traffic to a `custom handler`, when the target is a `relative path` (e.g. `/ping`) - You can choose the logger to use, by implementing the `Logger` interface +- You can `disable` the HTTP request headers `canonicalization`, by setting `PreventCanonicalization` to true ## Proxy modes 1. 
Regular HTTP proxy diff --git a/go.mod b/go.mod index 8ef6b038..8823c52a 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,14 @@ module github.com/elazarl/goproxy go 1.20 -require golang.org/x/net v0.33.0 +require ( + github.com/stretchr/testify v1.10.0 + golang.org/x/net v0.33.0 +) -require golang.org/x/text v0.21.0 // indirect +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/text v0.21.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 845330e7..0c0c948a 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,14 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/https.go b/https.go index 26b6202f..3ea98873 100644 --- a/https.go +++ b/https.go @@ -16,6 +16,7 @@ import ( "sync" "sync/atomic" + "github.com/elazarl/goproxy/internal/http1parser" "github.com/elazarl/goproxy/internal/signer" ) @@ -192,15 +193,25 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request var targetSiteCon net.Conn var remote *bufio.Reader - for { - client := bufio.NewReader(proxyClient) - req, err := http.ReadRequest(client) + client := http1parser.NewRequestReader(proxy.PreventCanonicalization, proxyClient) + for !client.IsEOF() { + req, err := client.ReadRequest() if err != nil && !errors.Is(err, io.EOF) { ctx.Warnf("cannot read request of MITM HTTP client: %+#v", err) } if err != nil { return } + + // Take the original value before filtering the request + closeConn := req.Close + + // since we're converting the request, need to carry over the + // original connecting IP as well + req.RemoteAddr = r.RemoteAddr + ctx.Logf("req %v", r.Host) + ctx.Req = req + req, resp := proxy.filterRequest(req, ctx) if resp == nil { // Establish a connection with the remote server only if the proxy @@ -218,18 +229,27 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request httpError(proxyClient, ctx, err) return } - resp, err = http.ReadResponse(remote, req) + resp, err = func() (*http.Response, error) { + defer req.Body.Close() + return http.ReadResponse(remote, req) + }() if err != nil { httpError(proxyClient, ctx, err) return } - defer resp.Body.Close() } resp = proxy.filterResponse(resp, ctx) - if err := resp.Write(proxyClient); err != nil { + err = resp.Write(proxyClient) + _ = resp.Body.Close() + if err != nil { 
httpError(proxyClient, ctx, err) return } + + if closeConn { + ctx.Logf("Non-persistent connection; closing") + return + } } case ConnectMitm: _, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) @@ -255,9 +275,10 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request ctx.Warnf("Cannot handshake client %v %v", r.Host, err) return } - clientTlsReader := bufio.NewReader(rawClientTls) - for !isEOF(clientTlsReader) { - req, err := http.ReadRequest(clientTlsReader) + + clientTlsReader := http1parser.NewRequestReader(proxy.PreventCanonicalization, rawClientTls) + for !clientTlsReader.IsEOF() { + req, err := clientTlsReader.ReadRequest() ctx := &ProxyCtx{ Req: req, Session: atomic.AddInt64(&proxy.sess, 1), @@ -266,10 +287,9 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request RoundTripper: ctx.RoundTripper, } if err != nil && !errors.Is(err, io.EOF) { - return + ctx.Warnf("Cannot read TLS request from mitm'd client %v %v", r.Host, err) } if err != nil { - ctx.Warnf("Cannot read TLS request from mitm'd client %v %v", r.Host, err) return } @@ -298,7 +318,8 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request // parse the HTTP Body for PRI requests. This leaves the body of // the http2.ClientPreface ("SM\r\n\r\n") on the wire which we need // to clear before setting up the connection. - _, err := clientTlsReader.Discard(6) + reader := clientTlsReader.Reader() + _, err := reader.Discard(6) if err != nil { ctx.Warnf("Failed to process HTTP2 client preface: %v", err) return @@ -307,7 +328,7 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request ctx.Warnf("HTTP2 connection failed: disallowed") return } - tr := H2Transport{clientTlsReader, rawClientTls, tlsConfig.Clone(), host} + tr := H2Transport{reader, rawClientTls, tlsConfig.Clone(), host} if _, err := tr.RoundTrip(req); err != nil { ctx.Warnf("HTTP2 connection failed: %v", err) } else { @@ -349,61 +370,69 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request ctx.Logf("resp %v", resp.Status) } resp = proxy.filterResponse(resp, ctx) - defer resp.Body.Close() - - text := resp.Status - statusCode := strconv.Itoa(resp.StatusCode) + " " - text = strings.TrimPrefix(text, statusCode) - // always use 1.1 to support chunked encoding - if _, err := io.WriteString(rawClientTls, "HTTP/1.1"+" "+statusCode+text+"\r\n"); err != nil { - ctx.Warnf("Cannot write TLS response HTTP status from mitm'd client: %v", err) - return - } - if resp.Request.Method == http.MethodHead { - // don't change Content-Length for HEAD request - } else if (resp.StatusCode >= 100 && resp.StatusCode < 200) || - resp.StatusCode == http.StatusNoContent { - // RFC7230: A server MUST NOT send a Content-Length header field in any response - // with a status code of 1xx (Informational) or 204 (No Content) - resp.Header.Del("Content-Length") - } else { - // Since we don't know the length of resp, return chunked encoded response - // TODO: use a more reasonable scheme - resp.Header.Del("Content-Length") - resp.Header.Set("Transfer-Encoding", "chunked") - } - // Force connection close otherwise chrome will keep CONNECT tunnel open forever - resp.Header.Set("Connection", "close") - if err := resp.Header.Write(rawClientTls); err != nil { - ctx.Warnf("Cannot write TLS response header from mitm'd client: %v", err) - return - } - if _, err = io.WriteString(rawClientTls, "\r\n"); err != nil { - ctx.Warnf("Cannot write TLS response header end from 
mitm'd client: %v", err) - return - } + // Run defer inside a custom function to prevent response body memory leak + if ok := func() bool { + defer resp.Body.Close() + + text := resp.Status + statusCode := strconv.Itoa(resp.StatusCode) + " " + text = strings.TrimPrefix(text, statusCode) + // always use 1.1 to support chunked encoding + if _, err := io.WriteString(rawClientTls, "HTTP/1.1"+" "+statusCode+text+"\r\n"); err != nil { + ctx.Warnf("Cannot write TLS response HTTP status from mitm'd client: %v", err) + return false + } - if resp.Request.Method == http.MethodHead || - (resp.StatusCode >= 100 && resp.StatusCode < 200) || - resp.StatusCode == http.StatusNoContent || - resp.StatusCode == http.StatusNotModified { - // Don't write out a response body, when it's not allowed - // in RFC7230 - } else { - chunked := newChunkedWriter(rawClientTls) - if _, err := io.Copy(chunked, resp.Body); err != nil { - ctx.Warnf("Cannot write TLS response body from mitm'd client: %v", err) - return + if resp.Request.Method == http.MethodHead { + // don't change Content-Length for HEAD request + } else if (resp.StatusCode >= 100 && resp.StatusCode < 200) || + resp.StatusCode == http.StatusNoContent { + // RFC7230: A server MUST NOT send a Content-Length header field in any response + // with a status code of 1xx (Informational) or 204 (No Content) + resp.Header.Del("Content-Length") + } else { + // Since we don't know the length of resp, return chunked encoded response + // TODO: use a more reasonable scheme + resp.Header.Del("Content-Length") + resp.Header.Set("Transfer-Encoding", "chunked") } - if err := chunked.Close(); err != nil { - ctx.Warnf("Cannot write TLS chunked EOF from mitm'd client: %v", err) - return + // Force connection close otherwise chrome will keep CONNECT tunnel open forever + resp.Header.Set("Connection", "close") + if err := resp.Header.Write(rawClientTls); err != nil { + ctx.Warnf("Cannot write TLS response header from mitm'd client: %v", err) + return false } if _, err = io.WriteString(rawClientTls, "\r\n"); err != nil { - ctx.Warnf("Cannot write TLS response chunked trailer from mitm'd client: %v", err) - return + ctx.Warnf("Cannot write TLS response header end from mitm'd client: %v", err) + return false } + + if resp.Request.Method == http.MethodHead || + (resp.StatusCode >= 100 && resp.StatusCode < 200) || + resp.StatusCode == http.StatusNoContent || + resp.StatusCode == http.StatusNotModified { + // Don't write out a response body, when it's not allowed + // in RFC7230 + } else { + chunked := newChunkedWriter(rawClientTls) + if _, err := io.Copy(chunked, resp.Body); err != nil { + ctx.Warnf("Cannot write TLS response body from mitm'd client: %v", err) + return false + } + if err := chunked.Close(); err != nil { + ctx.Warnf("Cannot write TLS chunked EOF from mitm'd client: %v", err) + return false + } + if _, err = io.WriteString(rawClientTls, "\r\n"); err != nil { + ctx.Warnf("Cannot write TLS response chunked trailer from mitm'd client: %v", err) + return false + } + } + + return true + }(); !ok { + return } if closeConn { diff --git a/internal/http1parser/header.go b/internal/http1parser/header.go new file mode 100644 index 00000000..d4ef3e64 --- /dev/null +++ b/internal/http1parser/header.go @@ -0,0 +1,43 @@ +package http1parser + +import ( + "errors" + "net/textproto" + "strings" +) + +var ErrBadProto = errors.New("bad protocol") + +// Http1ExtractHeaders is an HTTP/1.0 and HTTP/1.1 header-only parser, +// to extract the original header names for the received request. 
+// Fully inspired by readMIMEHeader() in +// https://github.com/golang/go/blob/master/src/net/textproto/reader.go +func Http1ExtractHeaders(r *textproto.Reader) ([]string, error) { + // Discard first line, it doesn't contain useful information, and it has + // already been validated in http.ReadRequest() + if _, err := r.ReadLine(); err != nil { + return nil, err + } + + // The first line cannot start with a leading space. + if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') { + return nil, ErrBadProto + } + + var headerNames []string + for { + kv, err := r.ReadContinuedLine() + if len(kv) == 0 { + // We have finished to parse the headers if we receive empty + // data without an error + return headerNames, err + } + + // Key ends at first colon. + k, _, ok := strings.Cut(kv, ":") + if !ok { + return nil, ErrBadProto + } + headerNames = append(headerNames, k) + } +} diff --git a/internal/http1parser/header_test.go b/internal/http1parser/header_test.go new file mode 100644 index 00000000..1e905442 --- /dev/null +++ b/internal/http1parser/header_test.go @@ -0,0 +1,48 @@ +package http1parser_test + +import ( + "bufio" + "bytes" + "net/textproto" + "testing" + + "github.com/elazarl/goproxy/internal/http1parser" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHttp1ExtractHeaders_Empty(t *testing.T) { + http1Data := "POST /index.html HTTP/1.1\r\n" + + "\r\n" + + textParser := textproto.NewReader(bufio.NewReader(bytes.NewReader([]byte(http1Data)))) + headers, err := http1parser.Http1ExtractHeaders(textParser) + require.NoError(t, err) + assert.Empty(t, headers) +} + +func TestHttp1ExtractHeaders(t *testing.T) { + http1Data := "POST /index.html HTTP/1.1\r\n" + + "Host: www.test.com\r\n" + + "Accept: */ /*\r\n" + + "Content-Length: 17\r\n" + + "lowercase: 3z\r\n" + + "\r\n" + + `{"hello":"world"}` + + textParser := textproto.NewReader(bufio.NewReader(bytes.NewReader([]byte(http1Data)))) + headers, err := http1parser.Http1ExtractHeaders(textParser) + require.NoError(t, err) + assert.Len(t, headers, 4) + assert.Contains(t, headers, "Content-Length") + assert.Contains(t, headers, "lowercase") +} + +func TestHttp1ExtractHeaders_InvalidData(t *testing.T) { + http1Data := "POST /index.html HTTP/1.1\r\n" + + `{"hello":"world"}` + + textParser := textproto.NewReader(bufio.NewReader(bytes.NewReader([]byte(http1Data)))) + _, err := http1parser.Http1ExtractHeaders(textParser) + require.Error(t, err) +} diff --git a/internal/http1parser/request.go b/internal/http1parser/request.go new file mode 100644 index 00000000..0e37bc2d --- /dev/null +++ b/internal/http1parser/request.go @@ -0,0 +1,94 @@ +package http1parser + +import ( + "bufio" + "bytes" + "errors" + "io" + "net/http" + "net/textproto" +) + +type RequestReader struct { + preventCanonicalization bool + reader *bufio.Reader + // Used only when preventCanonicalization value is true + cloned *bytes.Buffer +} + +func NewRequestReader(preventCanonicalization bool, conn io.Reader) *RequestReader { + if !preventCanonicalization { + return &RequestReader{ + preventCanonicalization: false, + reader: bufio.NewReader(conn), + } + } + + var cloned bytes.Buffer + reader := bufio.NewReader(io.TeeReader(conn, &cloned)) + return &RequestReader{ + preventCanonicalization: true, + reader: reader, + cloned: &cloned, + } +} + +// IsEOF returns true if there is no more data that can be read from the +// buffer and the underlying connection is closed. 
+func (r *RequestReader) IsEOF() bool { + _, err := r.reader.Peek(1) + return errors.Is(err, io.EOF) +} + +// Reader is used to take over the buffered connection data +// (e.g. with HTTP/2 data). +// After calling this function, make sure to consume all the data related +// to the current request. +func (r *RequestReader) Reader() *bufio.Reader { + return r.reader +} + +func (r *RequestReader) ReadRequest() (*http.Request, error) { + if !r.preventCanonicalization { + // Just call the HTTP library function if the preventCanonicalization + // configuration is disabled + return http.ReadRequest(r.reader) + } + + req, err := http.ReadRequest(r.reader) + if err != nil { + return nil, err + } + + httpDataReader := getRequestReader(r.reader, r.cloned) + headers, _ := Http1ExtractHeaders(httpDataReader) + + for _, headerName := range headers { + canonicalizedName := textproto.CanonicalMIMEHeaderKey(headerName) + if canonicalizedName == headerName { + continue + } + + // Rewrite header keys to the non-canonical parsed value + values, ok := req.Header[canonicalizedName] + if ok { + req.Header.Del(canonicalizedName) + req.Header[headerName] = values + } + } + + return req, nil +} + +func getRequestReader(r *bufio.Reader, cloned *bytes.Buffer) *textproto.Reader { + // "Cloned" buffer uses the raw connection as the data source. + // However, the *bufio.Reader can read also bytes of another unrelated + // request on the same connection, since it's buffered, so we have to + // ignore them before passing the data to our headers parser. + // Data related to the next request will remain inside the buffer for + // later usage. + data := cloned.Next(cloned.Len() - r.Buffered()) + return &textproto.Reader{ + R: bufio.NewReader(bytes.NewReader(data)), + } +} diff --git a/internal/http1parser/request_test.go b/internal/http1parser/request_test.go new file mode 100644 index 00000000..adc87bbe --- /dev/null +++ b/internal/http1parser/request_test.go @@ -0,0 +1,204 @@ +package http1parser_test + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/elazarl/goproxy/internal/http1parser" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + _data = "POST /index.html HTTP/1.1\r\n" + + "Host: www.test.com\r\n" + + "Accept: */*\r\n" + + "Content-Length: 17\r\n" + + "lowercase: 3z\r\n" + + "\r\n" + + `{"hello":"world"}` + + _data2 = "GET /index.html HTTP/1.1\r\n" + + "Host: www.test.com\r\n" + + "Accept: */*\r\n" + + "lowercase: 3z\r\n" + + "\r\n" +) + +func TestCanonicalRequest(t *testing.T) { + // Here we are simulating two requests on the same connection + http1Data := bytes.NewReader(append([]byte(_data), _data2...)) + parser := http1parser.NewRequestReader(false, http1Data) + + // 1st request + req, err := parser.ReadRequest() + require.NoError(t, err) + assert.NotEmpty(t, req.Header) + assert.NotContains(t, req.Header, "lowercase") + assert.Contains(t, req.Header, "Lowercase") + require.NoError(t, req.Body.Close()) + + // 2nd request + req, err = parser.ReadRequest() + require.NoError(t, err) + assert.NotEmpty(t, req.Header) + + // Make sure that the buffers are empty after all requests have been processed + assert.True(t, parser.IsEOF()) +} + +func TestNonCanonicalRequest(t *testing.T) { + http1Data := bytes.NewReader([]byte(_data)) + parser := http1parser.NewRequestReader(true, http1Data) + + req, err := parser.ReadRequest() + require.NoError(t, err) + assert.NotEmpty(t, req.Header) + assert.Contains(t, req.Header, 
"lowercase") + assert.NotContains(t, req.Header, "Lowercase") +} + +func TestMultipleNonCanonicalRequests(t *testing.T) { + http1Data := bytes.NewReader(append([]byte(_data), _data2...)) + parser := http1parser.NewRequestReader(true, http1Data) + + req, err := parser.ReadRequest() + require.NoError(t, err) + assert.NotEmpty(t, req.Header) + assert.Contains(t, req.Header, "lowercase") + assert.NotContains(t, req.Header, "Lowercase") + + body, err := io.ReadAll(req.Body) + require.NoError(t, err) + assert.Len(t, body, 17) + require.NoError(t, req.Body.Close()) + + req, err = parser.ReadRequest() + require.NoError(t, err) + assert.NotEmpty(t, req.Header) + + assert.True(t, parser.IsEOF()) +} + +// reqTest is inspired by https://github.com/golang/go/blob/master/src/net/http/readrequest_test.go +type reqTest struct { + Raw string + Req *http.Request + Body string + Trailer http.Header + Error string +} + +var ( + noError = "" + noBodyStr = "" + noTrailer http.Header +) + +var reqTests = []reqTest{ + // Baseline test; All Request fields included for template use + { + "GET http://www.techcrunch.com/ HTTP/1.1\r\n" + + "Host: www.techcrunch.com\r\n" + + "user-agent: Fake\r\n" + + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" + + "Accept-Language: en-us,en;q=0.5\r\n" + + "Accept-Encoding: gzip,deflate\r\n" + + "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" + + "Keep-Alive: 300\r\n" + + "Content-Length: 7\r\n" + + "Proxy-Connection: keep-alive\r\n\r\n" + + "abcdef\n???", + &http.Request{ + Method: http.MethodGet, + URL: &url.URL{ + Scheme: "http", + Host: "www.techcrunch.com", + Path: "/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}, + "Accept-Language": {"en-us,en;q=0.5"}, + "Accept-Encoding": {"gzip,deflate"}, + "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"}, + "Keep-Alive": {"300"}, + "Proxy-Connection": {"keep-alive"}, + "Content-Length": {"7"}, + "user-agent": {"Fake"}, + }, + Close: false, + ContentLength: 7, + Host: "www.techcrunch.com", + RequestURI: "http://www.techcrunch.com/", + }, + "abcdef\n", + noTrailer, + noError, + }, + + // GET request with no body (the normal case) + { + "GET / HTTP/1.1\r\n" + + "Host: foo.com\r\n\r\n", + &http.Request{ + Method: http.MethodGet, + URL: &url.URL{ + Path: "/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{}, + Close: false, + ContentLength: 0, + Host: "foo.com", + RequestURI: "/", + }, + noBodyStr, + noTrailer, + noError, + }, +} + +func TestReadRequest(t *testing.T) { + for i := range reqTests { + tt := &reqTests[i] + + testName := fmt.Sprintf("Test %d (%q)", i, tt.Raw) + t.Run(testName, func(t *testing.T) { + r := bufio.NewReader(strings.NewReader(tt.Raw)) + parser := http1parser.NewRequestReader(true, r) + req, err := parser.ReadRequest() + if err != nil && err.Error() == tt.Error { + // Test finished, we expected an error + return + } + require.NoError(t, err) + + // Check request equality (excluding body) + rbody := req.Body + req.Body = nil + assert.Equal(t, tt.Req, req) + + // Check if the two bodies match + var bodyString string + if rbody != nil { + data, err := io.ReadAll(rbody) + require.NoError(t, err) + bodyString = string(data) + _ = rbody.Close() + } + assert.Equal(t, tt.Body, bodyString) + assert.Equal(t, tt.Trailer, req.Trailer) + }) + } +} diff --git a/proxy.go b/proxy.go index 21add93e..f9a82946 100644 --- a/proxy.go +++ b/proxy.go @@ 
-1,8 +1,6 @@ package goproxy import ( - "bufio" - "errors" "io" "log" "net" @@ -39,6 +37,12 @@ type ProxyHttpServer struct { CertStore CertStorage KeepHeader bool AllowHTTP2 bool + // When PreventCanonicalization is true, the header names present in + // the request sent through the proxy are directly passed to the destination server, + // instead of following the HTTP RFC for their canonicalization. + // This is useful when the header name isn't treated as a case-insensitive + // value by the target server, because it doesn't follow the spec. + PreventCanonicalization bool // KeepAcceptEncoding, if true, prevents the proxy from dropping // Accept-Encoding headers from the client. // @@ -63,11 +67,6 @@ func copyHeaders(dst, src http.Header, keepDestHeaders bool) { } } -func isEOF(r *bufio.Reader) bool { - _, err := r.Peek(1) - return errors.Is(err, io.EOF) -} - func (proxy *ProxyHttpServer) filterRequest(r *http.Request, ctx *ProxyCtx) (req *http.Request, resp *http.Response) { req = r for _, h := range proxy.reqHandlers {
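
Enabling the new option is a single field on the proxy server. A minimal sketch of typical usage (the listen address and the Verbose setting are illustrative choices, not part of this patch):

package main

import (
	"log"
	"net/http"

	"github.com/elazarl/goproxy"
)

func main() {
	proxy := goproxy.NewProxyHttpServer()
	proxy.Verbose = true
	// Forward header names to the destination server exactly as the client
	// sent them (e.g. "x-token" stays "x-token" instead of becoming "X-Token").
	proxy.PreventCanonicalization = true
	log.Fatal(http.ListenAndServe(":8080", proxy))
}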
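
The core idea behind internal/http1parser can be illustrated with the standard library alone: tee the connection bytes into a side buffer, let http.ReadRequest parse (and canonicalize) the request, then re-scan the teed bytes with net/textproto to recover the original header spellings and move the canonicalized entries back under them. The sketch below uses a made-up request and header name, and omits the leftover-buffer trimming that the real RequestReader performs for pipelined requests:

package main

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"net/http"
	"net/textproto"
	"strings"
)

func main() {
	raw := "GET /index.html HTTP/1.1\r\n" +
		"Host: www.test.com\r\n" +
		"x-custom-token: abc\r\n" +
		"\r\n"

	// Tee everything the parser reads into a side buffer, so the original
	// header names survive the canonicalization done by http.ReadRequest.
	var cloned bytes.Buffer
	br := bufio.NewReader(io.TeeReader(strings.NewReader(raw), &cloned))

	req, err := http.ReadRequest(br)
	if err != nil {
		panic(err)
	}

	// Re-scan the cloned bytes, keeping only the header names as written.
	tp := textproto.NewReader(bufio.NewReader(&cloned))
	_, _ = tp.ReadLine() // discard the request line
	var names []string
	for {
		kv, err := tp.ReadContinuedLine()
		if len(kv) == 0 || err != nil {
			break
		}
		name, _, _ := strings.Cut(kv, ":")
		names = append(names, name)
	}

	// Move each canonicalized entry back under its original spelling.
	for _, name := range names {
		canonical := textproto.CanonicalMIMEHeaderKey(name)
		if canonical == name {
			continue
		}
		if values, ok := req.Header[canonical]; ok {
			req.Header.Del(canonical)
			req.Header[name] = values
		}
	}

	fmt.Println(req.Header) // prints map[x-custom-token:[abc]]
}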