diff --git a/selfservice/flow/login/handler.go b/selfservice/flow/login/handler.go index ff9c1517dc1f..88b3712602a0 100644 --- a/selfservice/flow/login/handler.go +++ b/selfservice/flow/login/handler.go @@ -822,7 +822,7 @@ continueLogin: return } - if err := h.d.LoginHookExecutor().PostLoginHook(w, r, group, f, i, sess, nil, ""); err != nil { + if err := h.d.LoginHookExecutor().PostLoginHook(w, r, group, f, i, sess, ""); err != nil { if errors.Is(err, ErrAddressNotVerified) { h.d.LoginFlowErrorHandler().WriteFlowError(w, r, f, node.DefaultGroup, errors.WithStack(schema.NewAddressNotVerifiedError())) return diff --git a/selfservice/flow/login/hook.go b/selfservice/flow/login/hook.go index 3b98097e66d5..ec2c899c7b90 100644 --- a/selfservice/flow/login/hook.go +++ b/selfservice/flow/login/hook.go @@ -63,6 +63,7 @@ type ( } HookExecutor struct { d executorDependencies + c *claims.Claims } HookExecutorProvider interface { LoginHookExecutor() *HookExecutor @@ -119,6 +120,14 @@ func (e *HookExecutor) handleLoginError(_ http.ResponseWriter, r *http.Request, return flowError } +type PostLoginHookOpt func(*HookExecutor) + +func WithClaims(c *claims.Claims) PostLoginHookOpt { + return func(h *HookExecutor) { + h.c = c + } +} + func (e *HookExecutor) PostLoginHook( w http.ResponseWriter, r *http.Request, @@ -126,8 +135,8 @@ func (e *HookExecutor) PostLoginHook( f *Flow, i *identity.Identity, s *session.Session, - c *claims.Claims, provider string, + opts ...PostLoginHookOpt, ) (err error) { ctx := r.Context() ctx, span := e.d.Tracer(ctx).Tracer().Start(ctx, "HookExecutor.PostLoginHook") @@ -164,13 +173,17 @@ func (e *HookExecutor) PostLoginHook( classified := s s = s.Declassified() + for _, o := range opts { + o(e) + } + e.d.Logger(). WithRequest(r). WithField("identity_id", i.ID). WithField("flow_method", f.Active). Debug("Running ExecuteLoginPostHook.") for k, executor := range e.d.PostLoginHooks(r.Context(), f.Active) { - if err := executor.ExecuteLoginPostHook(w, r, g, f, s, c); err != nil { + if err := executor.ExecuteLoginPostHook(w, r, g, f, s, e.c); err != nil { if errors.Is(err, ErrHookAbortFlow) { e.d.Logger(). WithRequest(r). diff --git a/selfservice/flow/login/hook_test.go b/selfservice/flow/login/hook_test.go index bcaa8bb58b2f..fe73f22d7eef 100644 --- a/selfservice/flow/login/hook_test.go +++ b/selfservice/flow/login/hook_test.go @@ -72,7 +72,7 @@ func TestLoginExecutor(t *testing.T) { } testhelpers.SelfServiceHookLoginErrorHandler(t, w, r, - reg.LoginHookExecutor().PostLoginHook(w, r, strategy.ToUiNodeGroup(), loginFlow, useIdentity, sess, nil, "")) + reg.LoginHookExecutor().PostLoginHook(w, r, strategy.ToUiNodeGroup(), loginFlow, useIdentity, sess, "")) }) ts := httptest.NewServer(router) diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 5c32de727721..34b8c63fc527 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -176,7 +176,7 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlo httprouter.ParamsFromContext(r.Context()).ByName("organization")) for _, c := range oidcCredentials.Providers { if c.Subject == claims.Subject && c.Provider == provider.Config().ID { - if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, claims, provider.Config().ID); err != nil { + if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, provider.Config().ID, login.WithClaims(claims)); err != nil { return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, err) } return nil, nil