diff --git a/pkg/constant/constant.go b/pkg/constant/constant.go index 47bf294b..ee5563b7 100644 --- a/pkg/constant/constant.go +++ b/pkg/constant/constant.go @@ -51,9 +51,10 @@ const ( _ contextKey = iota ContextScopeName - HeaderXForwardedFor = "X-Forwarded-For" - HeaderXRealIP = "X-Real-IP" - HeaderXHMAC = "X-HMAC-SHA256" + HeaderXForwardedFor = "X-Forwarded-For" + HeaderXForwardedHost = "X-Forwarded-Host" + HeaderXRealIP = "X-Real-IP" + HeaderXHMAC = "X-HMAC-SHA256" DurationType = "time.Duration" diff --git a/pkg/proxy/middleware/base.go b/pkg/proxy/middleware/base.go index 928903c8..c87c72bf 100644 --- a/pkg/proxy/middleware/base.go +++ b/pkg/proxy/middleware/base.go @@ -322,12 +322,11 @@ func ProxyMiddleware( // @step: add the proxy forwarding headers req.Header.Set("X-Real-IP", utils.RealIP(req)) if xff := req.Header.Get(constant.HeaderXForwardedFor); xff == "" { - req.Header.Set("X-Forwarded-For", utils.RealIP(req)) - } else { - req.Header.Set("X-Forwarded-For", xff) + req.Header.Set(constant.HeaderXForwardedFor, utils.RealIP(req)) + } + if xfh := req.Header.Get(constant.HeaderXForwardedHost); xfh == "" { + req.Header.Set(constant.HeaderXForwardedHost, req.Host) } - req.Header.Set("X-Forwarded-Host", req.Host) - req.Header.Set("X-Forwarded-Proto", req.Header.Get("X-Forwarded-Proto")) if len(corsOrigins) > 0 { // if CORS is enabled by Gatekeeper, do not propagate CORS requests upstream diff --git a/pkg/testsuite/server_test.go b/pkg/testsuite/server_test.go index 9fdacb19..c46d3b6f 100644 --- a/pkg/testsuite/server_test.go +++ b/pkg/testsuite/server_test.go @@ -1518,7 +1518,7 @@ func TestXForwarded(t *testing.T) { ExecutionSettings []fakeRequest }{ { - Name: "TestEmptyXForwarded", + Name: "TestEmptyXForwardedFor", ProxySettings: func(_ *config.Config) { }, ExecutionSettings: []fakeRequest{ @@ -1535,7 +1535,7 @@ func TestXForwarded(t *testing.T) { }, }, { - Name: "TestXForwardedPresent", + Name: "TestXForwardedForPresent", ProxySettings: func(_ *config.Config) { }, ExecutionSettings: []fakeRequest{ @@ -1574,6 +1574,43 @@ func TestXForwarded(t *testing.T) { }, }, }, + { + Name: "TestEmptyXForwardedHost", + ProxySettings: func(_ *config.Config) { + }, + ExecutionSettings: []fakeRequest{ + { + URI: FakeAuthAllURL + FakeTestURL, + HasToken: true, + ExpectedProxy: true, + ExpectedProxyHeadersValidator: map[string]func(*testing.T, *config.Config, string){ + "X-Forwarded-Host": func(t *testing.T, _ *config.Config, value string) { + assert.Contains(t, value, "127.0.0.1") + }, + }, + ExpectedCode: http.StatusOK, + }, + }, + }, + { + Name: "TestXForwardedHostPresent", + ProxySettings: func(_ *config.Config) { + }, + ExecutionSettings: []fakeRequest{ + { + URI: FakeAuthAllURL + FakeTestURL, + HasToken: true, + ExpectedProxy: true, + Headers: map[string]string{ + "X-Forwarded-Host": "189.10.10.1", + }, + ExpectedProxyHeaders: map[string]string{ + "X-Forwarded-Host": "189.10.10.1", + }, + ExpectedCode: http.StatusOK, + }, + }, + }, } for _, testCase := range testCases {