Skip to content

Commit

Permalink
feat: load session only once when middleware is used (#4187)
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr authored Nov 4, 2024
1 parent 5665f20 commit 234b6f2
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 6 deletions.
8 changes: 4 additions & 4 deletions selfservice/flow/settings/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ type createNativeSettingsFlow struct {
// default: errorGeneric
func (h *Handler) createNativeSettingsFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
ctx := r.Context()
s, err := h.d.SessionManager().FetchFromRequest(ctx, r)
s, err := h.d.SessionManager().FetchFromRequestContext(ctx, r)
if err != nil {
h.d.Writer().WriteError(w, r, err)
return
Expand Down Expand Up @@ -298,7 +298,7 @@ type createBrowserSettingsFlow struct {
// default: errorGeneric
func (h *Handler) createBrowserSettingsFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
ctx := r.Context()
s, err := h.d.SessionManager().FetchFromRequest(ctx, r)
s, err := h.d.SessionManager().FetchFromRequestContext(ctx, r)
if err != nil {
h.d.SelfServiceErrorManager().Forward(ctx, w, r, err)
return
Expand Down Expand Up @@ -404,7 +404,7 @@ func (h *Handler) getSettingsFlow(w http.ResponseWriter, r *http.Request, _ http
return
}

sess, err := h.d.SessionManager().FetchFromRequest(ctx, r)
sess, err := h.d.SessionManager().FetchFromRequestContext(ctx, r)
if err != nil {
h.d.Writer().WriteError(w, r, err)
return
Expand Down Expand Up @@ -574,7 +574,7 @@ func (h *Handler) updateSettingsFlow(w http.ResponseWriter, r *http.Request, ps
return
}

ss, err := h.d.SessionManager().FetchFromRequest(ctx, r)
ss, err := h.d.SessionManager().FetchFromRequestContext(ctx, r)
if err != nil {
h.d.SettingsFlowErrorHandler().WriteFlowError(w, r, node.DefaultGroup, f, nil, err)
return
Expand Down
13 changes: 11 additions & 2 deletions session/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package session

import (
"context"
"fmt"
"net/http"
"strconv"
Expand Down Expand Up @@ -837,9 +838,17 @@ func (h *Handler) listMySessions(w http.ResponseWriter, r *http.Request, _ httpr
h.r.Writer().Write(w, r, sess)
}

type sessionInContext int

const (
sessionInContextKey sessionInContext = iota
)

func (h *Handler) IsAuthenticated(wrap httprouter.Handle, onUnauthenticated httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
if _, err := h.r.SessionManager().FetchFromRequest(r.Context(), r); err != nil {
ctx := r.Context()
sess, err := h.r.SessionManager().FetchFromRequest(ctx, r)
if err != nil {
if onUnauthenticated != nil {
onUnauthenticated(w, r, ps)
return
Expand All @@ -849,7 +858,7 @@ func (h *Handler) IsAuthenticated(wrap httprouter.Handle, onUnauthenticated http
return
}

wrap(w, r, ps)
wrap(w, r.WithContext(context.WithValue(ctx, sessionInContextKey, sess)), ps)
}
}

Expand Down
3 changes: 3 additions & 0 deletions session/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ type Manager interface {
// FetchFromRequest creates an HTTP session using cookies.
FetchFromRequest(context.Context, *http.Request) (*Session, error)

// FetchFromRequestContext returns the session from the context or if that is unset, falls back to FetchFromRequest.
FetchFromRequestContext(context.Context, *http.Request) (*Session, error)

// PurgeFromRequest removes an HTTP session.
PurgeFromRequest(context.Context, http.ResponseWriter, *http.Request) error

Expand Down
11 changes: 11 additions & 0 deletions session/manager_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,17 @@ func (s *ManagerHTTP) extractToken(r *http.Request) string {
return token
}

func (s *ManagerHTTP) FetchFromRequestContext(ctx context.Context, r *http.Request) (_ *Session, err error) {
ctx, span := s.r.Tracer(ctx).Tracer().Start(ctx, "sessions.ManagerHTTP.FetchFromRequestContext")
otelx.End(span, &err)

if sess, ok := ctx.Value(sessionInContextKey).(*Session); ok {
return sess, nil
}

return s.FetchFromRequest(ctx, r)
}

func (s *ManagerHTTP) FetchFromRequest(ctx context.Context, r *http.Request) (_ *Session, err error) {
ctx, span := s.r.Tracer(ctx).Tracer().Start(ctx, "sessions.ManagerHTTP.FetchFromRequest")
defer func() {
Expand Down
14 changes: 14 additions & 0 deletions session/manager_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,16 @@ func TestManagerHTTP(t *testing.T) {
reg.Writer().Write(w, r, sess)
})

rp.GET("/session/get-middleware", reg.SessionHandler().IsAuthenticated(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
sess, err := reg.SessionManager().FetchFromRequestContext(r.Context(), r)
if err != nil {
t.Logf("Got error on lookup: %s %T", err, errors.Unwrap(err))
reg.Writer().WriteError(w, r, err)
return
}
reg.Writer().Write(w, r, sess)
}, session.RedirectOnUnauthenticated("https://failed.com")))

pts := httptest.NewServer(x.NewTestCSRFHandler(rp, reg))
t.Cleanup(pts.Close)
conf.MustSet(ctx, config.ViperKeyPublicBaseURL, pts.URL)
Expand All @@ -263,6 +273,10 @@ func TestManagerHTTP(t *testing.T) {
res, err := c.Get(pts.URL + "/session/get")
require.NoError(t, err)
assert.EqualValues(t, http.StatusOK, res.StatusCode)

res, err = c.Get(pts.URL + "/session/get-middleware")
require.NoError(t, err)
assert.EqualValues(t, http.StatusOK, res.StatusCode)
})

t.Run("case=key rotation", func(t *testing.T) {
Expand Down

0 comments on commit 234b6f2

Please sign in to comment.