Skip to content

Commit

Permalink
Allow CORS to be enabled in config
Browse files Browse the repository at this point in the history
  • Loading branch information
wcsanders1 committed Feb 23, 2019
1 parent abdf498 commit 655f6f4
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 18 deletions.
10 changes: 9 additions & 1 deletion api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ func (api *API) Start(dir, defaultCert, defaultKey string) error {
})
contextLogger.Debug("starting API")

api.handlers["OPTIONS"] = make(map[string]func(http.ResponseWriter, *http.Request))
for endpointName, endpoint := range api.endpoints {
var path string
if len(api.baseURL) > 0 {
Expand Down Expand Up @@ -149,7 +150,14 @@ func (api *API) Start(dir, defaultCert, defaultKey string) error {
}

contextLoggerEndpoint.Debug("registered endpoint; now assigning handler")
api.handlers[method][registeredRoute] = api.creator.getHandler(endpoint.EnforceValidJSON, endpoint.Headers, dir, file, api.file)
api.handlers[method][registeredRoute] = api.creator.getHandler(endpoint.EnforceValidJSON, endpoint.AllowCORS, endpoint.Headers, dir, file, api.file)
if endpoint.AllowCORS {
api.handlers["OPTIONS"][registeredRoute] = func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "*")
w.Header().Set("Access-Control-Allow-headers", "*")
}
}
}

return api.creator.startAPI(defaultCert, defaultKey, api.server, api.httpConfig)
Expand Down
20 changes: 14 additions & 6 deletions api/api_creator.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (

type (
iCreator interface {
getHandler(enforceValidJSON bool, headers []config.Header, dir, fileName string, file wrapper.IFileOps) func(w http.ResponseWriter, r *http.Request)
getHandler(enforceValidJSON, allowCORS bool, headers []config.Header, dir, fileName string, file wrapper.IFileOps) func(w http.ResponseWriter, r *http.Request)
startAPI(defaultCert, defaultKey string, server wrapper.IServerOps, httpConfig config.HTTP) error
}

Expand All @@ -32,7 +32,7 @@ func newCreator(logger *logrus.Entry) *creator {
}
}

func (c creator) getHandler(enforceValidJSON bool, headers []config.Header, dir, fileName string, file wrapper.IFileOps) func(w http.ResponseWriter, r *http.Request) {
func (c creator) getHandler(enforceValidJSON, allowCORS bool, headers []config.Header, dir, fileName string, file wrapper.IFileOps) func(w http.ResponseWriter, r *http.Request) {
var path string
if len(fileName) > 0 {
path = fmt.Sprintf("%s/%s/%s", constants.APIDir, dir, fileName)
Expand All @@ -47,18 +47,22 @@ func (c creator) getHandler(enforceValidJSON bool, headers []config.Header, dir,
})

if enforceValidJSON {
return getJSONHandler(path, headers, file, contextLogger)
return getJSONHandler(path, headers, file, contextLogger, allowCORS)
}
return getGeneralHandler(path, headers, file, contextLogger)
return getGeneralHandler(path, headers, file, contextLogger, allowCORS)
}

func getJSONHandler(path string, headers []config.Header, file wrapper.IFileOps, logger *logrus.Entry) func(w http.ResponseWriter, r *http.Request) {
func getJSONHandler(path string, headers []config.Header, file wrapper.IFileOps, logger *logrus.Entry, allowCORS bool) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {

for _, header := range headers {
w.Header().Set(header.Key, header.Value)
}

if allowCORS {
w.Header().Set("Access-Control-Allow-Origin", "*")
}

if len(path) == 0 {
return
}
Expand All @@ -74,12 +78,16 @@ func getJSONHandler(path string, headers []config.Header, file wrapper.IFileOps,
}
}

func getGeneralHandler(path string, headers []config.Header, file wrapper.IFileOps, logger *logrus.Entry) func(w http.ResponseWriter, r *http.Request) {
func getGeneralHandler(path string, headers []config.Header, file wrapper.IFileOps, logger *logrus.Entry, allowCORS bool) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
for _, header := range headers {
w.Header().Set(header.Key, header.Value)
}

if allowCORS {
w.Header().Set("Access-Control-Allow-Origin", "*")
}

if len(path) == 0 {
return
}
Expand Down
2 changes: 1 addition & 1 deletion api/api_creator_fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type fakeAPICreator struct {
mock.Mock
}

func (c *fakeAPICreator) getHandler(enforceValidJSON bool, headers []config.Header, dir, fileName string, file wrapper.IFileOps) func(w http.ResponseWriter, r *http.Request) {
func (c *fakeAPICreator) getHandler(enforceValidJSON, allowCORS bool, headers []config.Header, dir, fileName string, file wrapper.IFileOps) func(w http.ResponseWriter, r *http.Request) {
args := c.Called(enforceValidJSON, dir, fileName, file)
return args.Get(0).(func(w http.ResponseWriter, r *http.Request))
}
Expand Down
18 changes: 9 additions & 9 deletions api/api_creator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestGetHandler_ReturnsHandler_WhenEnforceJSONFalse(t *testing.T) {
log: log.GetFakeLogger(),
}

result := creator.getHandler(false, nil, "testDir", "testFile", &wrapper.FakeFileOps{})
result := creator.getHandler(false, false, nil, "testDir", "testFile", &wrapper.FakeFileOps{})

assert := assert.New(t)
assert.NotNil(result)
Expand All @@ -45,7 +45,7 @@ func TestGetHandler_ReturnsHandler_WhenEnforceJSONTrue(t *testing.T) {
log: log.GetFakeLogger(),
}

result := creator.getHandler(true, nil, "testDir", "testFile", &wrapper.FakeFileOps{})
result := creator.getHandler(true, false, nil, "testDir", "testFile", &wrapper.FakeFileOps{})

assert := assert.New(t)
assert.NotNil(result)
Expand All @@ -55,7 +55,7 @@ func TestGetHandler_ReturnsHandler_WhenEnforceJSONTrue(t *testing.T) {
func TestGetJSONHandler_ReturnsHandler_WhenCalled(t *testing.T) {
fileOps := wrapper.FakeFileOps{}
logger := log.GetFakeLogger()
funcResult := getJSONHandler("test", nil, &fileOps, logger)
funcResult := getJSONHandler("test", nil, &fileOps, logger, false)

assert.NotNil(t, funcResult)
}
Expand All @@ -66,7 +66,7 @@ func TestJSONHandler_Writes_OnSuccess(t *testing.T) {
logger := log.GetFakeLogger()
fileOps.On("Open", mock.AnythingOfType("string")).Return(os.NewFile(1, "fakefile"), nil)
fileOps.On("ReadAll", mock.AnythingOfType("*os.File")).Return(goodJSON, nil)
funcResult := getJSONHandler(path, nil, &fileOps, logger)
funcResult := getJSONHandler(path, nil, &fileOps, logger, false)
w := fake.ResponseWriter{}
w.On("WriteHeader", mock.AnythingOfType("int")).Return(1)
w.On("Write", mock.AnythingOfType("[]uint8")).Return(1, nil)
Expand All @@ -85,7 +85,7 @@ func TestJSONHandler_WritesError_OnFailure(t *testing.T) {
logger := log.GetFakeLogger()
fileOps.On("Open", mock.AnythingOfType("string")).Return(os.NewFile(1, "fakefile"), nil)
fileOps.On("ReadAll", mock.AnythingOfType("*os.File")).Return([]byte{}, errors.New(""))
funcResult := getJSONHandler(path, nil, &fileOps, logger)
funcResult := getJSONHandler(path, nil, &fileOps, logger, false)
w := fake.ResponseWriter{}
w.On("WriteHeader", mock.AnythingOfType("int")).Return(1)
w.On("Write", mock.AnythingOfType("[]uint8")).Return(1, nil)
Expand All @@ -102,7 +102,7 @@ func TestJSONHandler_WritesError_OnFailure(t *testing.T) {
func TestGetGeneralHanlder_ReturnsFunc_WhenCalled(t *testing.T) {
fileOps := wrapper.FakeFileOps{}
logger := log.GetFakeLogger()
funcResult := getGeneralHandler("test", nil, &fileOps, logger)
funcResult := getGeneralHandler("test", nil, &fileOps, logger, false)

assert.NotNil(t, funcResult)
}
Expand All @@ -113,7 +113,7 @@ func TestGeneralHandler_Writes_OnSuccess(t *testing.T) {
logger := log.GetFakeLogger()
fileOps.On("Open", mock.AnythingOfType("string")).Return(os.NewFile(1, "fakefile"), nil)
fileOps.On("ReadAll", mock.AnythingOfType("*os.File")).Return(goodJSON, nil)
funcResult := getGeneralHandler(path, nil, &fileOps, logger)
funcResult := getGeneralHandler(path, nil, &fileOps, logger, false)
w := fake.ResponseWriter{}
w.On("WriteHeader", mock.AnythingOfType("int")).Return(1)
w.On("Write", mock.AnythingOfType("[]uint8")).Return(1, nil)
Expand All @@ -132,7 +132,7 @@ func TestGeneralHandler_WritesError_WhenReadFails(t *testing.T) {
logger := log.GetFakeLogger()
fileOps.On("Open", mock.AnythingOfType("string")).Return(os.NewFile(1, "fakefile"), nil)
fileOps.On("ReadAll", mock.AnythingOfType("*os.File")).Return([]byte{}, errors.New(""))
funcResult := getGeneralHandler(path, nil, &fileOps, logger)
funcResult := getGeneralHandler(path, nil, &fileOps, logger, false)
w := fake.ResponseWriter{}
w.On("WriteHeader", mock.AnythingOfType("int")).Return(1)
w.On("Write", mock.AnythingOfType("[]uint8")).Return(1, nil)
Expand All @@ -151,7 +151,7 @@ func TestGeneralHandler_WritesError_WhenFileOpenFails(t *testing.T) {
fileOps := wrapper.FakeFileOps{}
logger := log.GetFakeLogger()
fileOps.On("Open", mock.AnythingOfType("string")).Return(os.NewFile(1, "fakefile"), errors.New(""))
funcResult := getGeneralHandler(path, nil, &fileOps, logger)
funcResult := getGeneralHandler(path, nil, &fileOps, logger, false)
w := fake.ResponseWriter{}
w.On("WriteHeader", mock.AnythingOfType("int")).Return(1)
w.On("Write", mock.AnythingOfType("[]uint8")).Return(1, nil)
Expand Down
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type (
Method string
Headers []Header
EnforceValidJSON bool
AllowCORS bool
}

// Header contains the keys and values to put on response headers.
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func main() {
hubFlags.Visit(func(f *flag.Flag) {
if f.Name == "v" {
*showVersion = true
fmt.Println("v0.2.1")
fmt.Println("v0.3.0")
}
})

Expand Down
1 change: 1 addition & 0 deletions mockApis/exampleCustomersApi/customersApi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ keyFile = ""
path = "customers/:id/balances"
file = "customers.json"
method = "GET"
allowCORS = true

[[endpoints.getCustomers.headers]]
key = "content-type"
Expand Down
2 changes: 2 additions & 0 deletions mockApis/exampleStudentsApi/studentsApi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ keyFile = ""
[endpoints.getStudents]
path = "students"
method = "GET"
allowCORS = true

[endpoints.getNothing]
path = ""
file = "students.json"
method = "GET"
enforceValidJSON = true
allowCORS = true

[endpoints.postStudents]
path = "students"
Expand Down

0 comments on commit 655f6f4

Please sign in to comment.