Skip to content

Commit

Permalink
Add logic to override CORS headers
Browse files Browse the repository at this point in the history
  • Loading branch information
mbillewicz-olx committed Sep 18, 2024
1 parent 3301c0c commit 9acdf28
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
16 changes: 15 additions & 1 deletion cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,9 @@ func (client *Client) Middleware() echo.MiddlewareFunc {
}
} else {
b, ok := client.adapter.Get(key)
response := BytesToResponse(b)
if ok {
response := BytesToResponse(b)
response.Header = client.rewriteCorsHeaders(response.Header, c.Response().Header())
if response.Expiration.After(time.Now()) {
response.LastAccess = time.Now()
response.Frequency++
Expand Down Expand Up @@ -209,6 +210,19 @@ func (client *Client) Middleware() echo.MiddlewareFunc {
}
}

func (client *Client) rewriteCorsHeaders(cachedHeaders http.Header, responseHeaders http.Header) http.Header {
corsHeaders := []string{
"Access-Control-Allow-Origin",
"Access-Control-Allow-Credentials",
}
for _, h := range corsHeaders {
if val := responseHeaders.Get(h); val != "" {
cachedHeaders.Set(h, val)
}
}
return cachedHeaders
}

func (client *Client) cacheableMethod(method string) bool {
for _, m := range client.methods {
if method == m {
Expand Down
46 changes: 46 additions & 0 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"testing"
"time"

"github.com/coinpaprika/echo-http-cache/adapter/memory"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -231,6 +232,51 @@ func TestMiddleware(t *testing.T) {
}
}

func TestCorsHeaders(t *testing.T) {
e := echo.New()

handler := func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
}
memoryAdapter, err := memory.NewAdapter()
require.NoError(t, err)

client, _ := NewClient(
ClientWithAdapter(memoryAdapter),
ClientWithTTL(1*time.Minute),
)

cacheMiddleware := client.Middleware()

req, err := http.NewRequest(http.MethodGet, "/test", nil)
require.NoError(t, err)
rec := httptest.NewRecorder()

// simulate CORS middleware
rec.Header().Add("Access-Control-Allow-Origin", "http://localhost:8181")
rec.Header().Add("Access-Control-Allow-Credentials", "true")

c := e.NewContext(req, rec)
_ = cacheMiddleware(handler)(c)

assert.Equal(t, "http://localhost:8181", rec.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "true", rec.Header().Get("Access-Control-Allow-Credentials"))

secondRec := httptest.NewRecorder()

// simulate CORS middleware
secondRec.Header().Add("Access-Control-Allow-Origin", "http://coinpaprika.com")
secondRec.Header().Add("Access-Control-Allow-Credentials", "true")

secondC := e.NewContext(req, secondRec)
_ = cacheMiddleware(handler)(secondC)

assert.Equal(t, "http://coinpaprika.com", secondRec.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "true", secondRec.Header().Get("Access-Control-Allow-Credentials"))
time.Sleep(time.Second)
}


func TestRestrictedPaths(t *testing.T) {
tests := []struct {
name string
Expand Down

0 comments on commit 9acdf28

Please sign in to comment.