diff --git a/auth/cors.go b/auth/cors.go index 095043a9..3d8c2aa0 100644 --- a/auth/cors.go +++ b/auth/cors.go @@ -32,11 +32,23 @@ func CorsConfig(conf *config.Configuration) cors.Config { } return false } + if allowedOrigin := headerIgnoreCase(conf, "access-control-allow-origin"); allowedOrigin != "" && len(compiledOrigins) == 0 { + corsConf.AllowOrigins = append(corsConf.AllowOrigins, allowedOrigin) + } } return corsConf } +func headerIgnoreCase(conf *config.Configuration, search string) (value string) { + for key, value := range conf.Server.ResponseHeaders { + if strings.ToLower(key) == search { + return value + } + } + return "" +} + func compileAllowedCORSOrigins(allowedOrigins []string) []*regexp.Regexp { var compiledAllowedOrigins []*regexp.Regexp for _, origin := range allowedOrigins { diff --git a/auth/cors_test.go b/auth/cors_test.go index edf9209c..a3254de2 100644 --- a/auth/cors_test.go +++ b/auth/cors_test.go @@ -32,6 +32,21 @@ func TestCorsConfig(t *testing.T) { assert.False(t, allowF("https://test.com")) assert.False(t, allowF("https://other.com")) } +func TestEmptyCorsConfigWithResponseHeaders(t *testing.T) { + mode.Set(mode.Prod) + serverConf := config.Configuration{} + serverConf.Server.ResponseHeaders = map[string]string{"Access-control-allow-origin": "https://example.com"} + + actual := CorsConfig(&serverConf) + assert.NotNil(t, actual.AllowOriginFunc) + actual.AllowOriginFunc = nil // func cannot be checked with equal + + assert.Equal(t, cors.Config{ + AllowAllOrigins: false, + AllowOrigins: []string{"https://example.com"}, + MaxAge: 12 * time.Hour, + }, actual) +} func TestDevCorsConfig(t *testing.T) { mode.Set(mode.Dev) diff --git a/router/router_test.go b/router/router_test.go index c7082557..14f74456 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -86,7 +86,7 @@ func TestHeadersFromConfiguration(t *testing.T) { config := config.Configuration{PassStrength: 5} config.Server.ResponseHeaders = map[string]string{ "New-Cool-Header": "Nice", - "Access-Control-Allow-Origin": "---", + "Access-Control-Allow-Origin": "http://test1.com", } g, closable := Create(db.GormDatabase, @@ -106,7 +106,7 @@ func TestHeadersFromConfiguration(t *testing.T) { res, err := client.Do(req) assert.Nil(t, err) - assert.Equal(t, "---", res.Header.Get("Access-Control-Allow-Origin")) + assert.Equal(t, "http://test1.com", res.Header.Get("Access-Control-Allow-Origin")) assert.Equal(t, "Nice", res.Header.Get("New-Cool-Header")) } @@ -168,6 +168,74 @@ func TestInvalidOrigin(t *testing.T) { assert.Equal(t, http.StatusForbidden, res.StatusCode) } +func TestAllowedOriginFromResponseHeaders(t *testing.T) { + mode.Set(mode.Prod) + db := testdb.NewDBWithDefaultUser(t) + defer db.Close() + + config := config.Configuration{PassStrength: 5} + config.Server.ResponseHeaders = map[string]string{ + "Access-Control-Allow-Origin": "http://test1.com", + "Access-Control-Allow-Methods": "GET,POST"} + + g, closable := Create(db.GormDatabase, + &model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"}, + &config, + ) + server := httptest.NewServer(g) + + defer func() { + closable() + server.Close() + }() + + req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", server.URL, "version"), nil) + req.Header.Add("Origin", "http://test1.com") + assert.Nil(t, err) + + res, err := client.Do(req) + assert.Nil(t, err) + assert.Equal(t, "http://test1.com", res.Header.Get("Access-Control-Allow-Origin")) + assert.Equal(t, http.StatusOK, res.StatusCode) + + req.Header.Set("Origin", "http://example.com") + res, err = client.Do(req) + assert.Nil(t, err) + assert.Equal(t, "http://test1.com", res.Header.Get("Access-Control-Allow-Origin")) + assert.Equal(t, http.StatusForbidden, res.StatusCode) +} + +func TestAllowedWildcardOriginInHeader(t *testing.T) { + mode.Set(mode.Prod) + db := testdb.NewDBWithDefaultUser(t) + defer db.Close() + + config := config.Configuration{PassStrength: 5} + config.Server.ResponseHeaders = map[string]string{ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET,POST"} + + g, closable := Create(db.GormDatabase, + &model.VersionInfo{Version: "1.0.0", BuildDate: "2018-02-20-17:30:47", Commit: "asdasds"}, + &config, + ) + server := httptest.NewServer(g) + + defer func() { + closable() + server.Close() + }() + + req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", server.URL, "version"), nil) + req.Header.Add("Origin", "http://test1.com") + assert.Nil(t, err) + + res, err := client.Do(req) + assert.Nil(t, err) + assert.Equal(t, "*", res.Header.Get("Access-Control-Allow-Origin")) + assert.Equal(t, http.StatusOK, res.StatusCode) +} + func TestCORSHeaderRegex(t *testing.T) { mode.Set(mode.Prod) db := testdb.NewDBWithDefaultUser(t) @@ -206,7 +274,7 @@ func TestCORSConfigOverride(t *testing.T) { config := config.Configuration{PassStrength: 5} config.Server.ResponseHeaders = map[string]string{ "New-Cool-Header": "Nice", - "Access-Control-Allow-Origin": "something-else", + "Access-Control-Allow-Origin": "http://example.com/", "Access-Control-Allow-Methods": "321test", "Access-Control-Allow-Headers": "some-headers", } @@ -232,10 +300,16 @@ func TestCORSConfigOverride(t *testing.T) { res, err := client.Do(req) assert.Nil(t, err) + assert.Equal(t, http.StatusNoContent, res.StatusCode) assert.Equal(t, "Nice", res.Header.Get("New-Cool-Header")) assert.Equal(t, "http://test123.com", res.Header.Get("Access-Control-Allow-Origin")) assert.Equal(t, "GET,OPTIONS", res.Header.Get("Access-Control-Allow-Methods")) assert.Equal(t, "Content-Type", res.Header.Get("Access-Control-Allow-Headers")) + + req.Header.Set("Origin", "http://example.com") + res, err = client.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusForbidden, res.StatusCode) } func (s *IntegrationSuite) TestOptionsRequest() {