diff --git a/gateway/gateway_test.go b/gateway/gateway_test.go index 2238f65e13c..4ba3d693f71 100644 --- a/gateway/gateway_test.go +++ b/gateway/gateway_test.go @@ -1,6 +1,7 @@ package gateway import ( + "bufio" "bytes" "context" "encoding/json" @@ -21,6 +22,7 @@ import ( "github.com/gorilla/websocket" proxyproto "github.com/pires/go-proxyproto" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" msgpack "gopkg.in/vmihailenco/msgpack.v2" "github.com/TykTechnologies/tyk-pump/analytics" @@ -1486,6 +1488,50 @@ func TestWebsocketsSeveralOpenClose(t *testing.T) { conn3.Close() } +func TestWebsocketsWithConnectionKeepAlive(t *testing.T) { + ts := StartTest(nil) + defer ts.Close() + + globalConf := ts.Gw.GetConfig() + globalConf.HttpServerOptions.EnableWebSockets = true + ts.Gw.SetConfig(globalConf) + + ts.Gw.BuildAndLoadAPI(func(spec *APISpec) { + spec.Proxy.ListenPath = "/" + }) + + baseURL := strings.Replace(ts.URL, "http://", "ws://", -1) + url, err := url.Parse(baseURL) + require.NoError(t, err) + + conn, err := net.Dial("tcp", url.Host) + require.NoError(t, err) + defer conn.Close() + + req := fmt.Sprintf(`GET %s/ws HTTP/1.1 +Host: %s +Accept-Encoding: gzip, deflate, br, zstd +Sec-WebSocket-Version: 13 +Sec-WebSocket-Extensions: permessage-deflate +Sec-WebSocket-Key: X62lCXELOHFcBBG72P2S2Q== +Connection: Upgrade, keep-alive +Upgrade: websocket + +`, baseURL, url.Host) + req = strings.Replace(req, "\n", "\r\n", -1) + _, err = conn.Write([]byte(req)) + require.NoError(t, err) + buf, err := bufio.NewReader(conn).ReadString('\n') + require.NoError(t, err) + assert.Contains(t, buf, "HTTP/1.1 101 Switching Protocols") + + _, _ = ts.Run(t, test.TestCase{ + Method: "GET", + Path: "/abc", + Code: http.StatusOK, + }) +} + func TestWebsocketsAndHTTPEndpointMatch(t *testing.T) { ts := StartTest(nil) t.Cleanup(ts.Close) diff --git a/internal/httputil/streaming.go b/internal/httputil/streaming.go index 2589d340058..94432483cf6 100644 --- a/internal/httputil/streaming.go +++ b/internal/httputil/streaming.go @@ -2,6 +2,7 @@ package httputil import ( "net/http" + "net/textproto" "strings" ) @@ -23,8 +24,7 @@ func IsSseStreamingResponse(r *http.Response) bool { // IsUpgrade checks if the request is an upgrade request and returns the upgrade type. func IsUpgrade(req *http.Request) (string, bool) { - connection := strings.ToLower(strings.TrimSpace(req.Header.Get(headerConnection))) - if connection != "upgrade" { + if !headerContainsTokenIgnoreCase(req.Header, headerConnection, "Upgrade") { return "", false } @@ -36,6 +36,28 @@ func IsUpgrade(req *http.Request) (string, bool) { return "", false } +func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { + for _, t := range headerTokens(h, key) { + if strings.EqualFold(t, token) { + return true + } + } + return false +} + +func headerTokens(h http.Header, key string) []string { + key = textproto.CanonicalMIMEHeaderKey(key) + var tokens []string + for _, v := range h[key] { + v = strings.TrimSpace(v) + for _, t := range strings.Split(v, ",") { + t = strings.TrimSpace(t) + tokens = append(tokens, t) + } + } + return tokens +} + // IsStreamingRequest returns true if the request designates streaming (gRPC or WebSocket). func IsStreamingRequest(r *http.Request) bool { _, upgrade := IsUpgrade(r) diff --git a/internal/httputil/streaming_test.go b/internal/httputil/streaming_test.go index 6b44dc765f6..806ee380959 100644 --- a/internal/httputil/streaming_test.go +++ b/internal/httputil/streaming_test.go @@ -57,11 +57,21 @@ func TestIsUpgrade(t *testing.T) { assert.True(t, ok) assert.Equal(t, "websocket", upgradeType) + req = newRequestWithHeaders(t, 0, map[string]string{headerConnection: "keep-alive, Upgrade", headerUpgrade: "websocket"}) + upgradeType, ok = IsUpgrade(req) + assert.True(t, ok) + assert.Equal(t, "websocket", upgradeType) + req = newRequestWithHeaders(t, 0, map[string]string{headerConnection: "keep-alive", headerUpgrade: "websocket"}) upgradeType, ok = IsUpgrade(req) assert.False(t, ok) assert.Empty(t, upgradeType) + req = newRequestWithHeaders(t, 0, map[string]string{headerConnection: "keep-alive, Upgrade"}) + upgradeType, ok = IsUpgrade(req) + assert.False(t, ok) + assert.Empty(t, upgradeType) + req = newRequestWithHeaders(t, 0, map[string]string{headerConnection: "Upgrade"}) upgradeType, ok = IsUpgrade(req) assert.False(t, ok)