diff --git a/docs/configuration.md b/docs/configuration.md index 62eea7a..43e33b8 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -23,7 +23,8 @@ The following flags are available: | `openid.acr-values` | string | Space separated string that configures the default security level (`acr_values`) parameter for authorization requests. | | | `openid.audiences` | strings | List of additional trusted audiences (other than the client_id) for OpenID Connect id_token validation. | | | `openid.client-id` | string | Client ID for the OpenID client. | | -| `openid.client-jwk` | string | JWK containing the private key for the OpenID client in string format. | | +| `openid.client-jwk` | string | JWK containing the private key for the OpenID client in string format. If configured, this takes precedence over 'openid.client-secret'. | | +| `openid.client-secret` | string | Client secret for the OpenID client. Overridden by 'openid.client-jwk', if configured. | | | `openid.post-logout-redirect-uri` | string | URI for redirecting the user after successful logout at the Identity Provider. | | | `openid.provider` | string | Provider configuration to load and use, either `openid`, `azure`, `idporten`. | `openid` | | `openid.resource-indicator` | string | OAuth2 resource indicator to include in authorization request for acquiring audience-restricted tokens. | | @@ -82,7 +83,7 @@ The default configuration of Wonderwall will start in [_standalone mode_](archit At minimum, the following configuration must be provided when in standalone mode: - `openid.client-id` -- `openid.client-jwk` +- `openid.client-jwk` or `openid.client-secret` - `openid.well-known-url` - `ingress` @@ -99,7 +100,7 @@ When the `sso.enabled` flag is enabled and the `sso.mode` flag is set to `server At minimum, the following configuration must be provided when in SSO server mode: - `openid.client-id` -- `openid.client-jwk` +- `openid.client-jwk` or `openid.client-secret` - `openid.well-known-url` - `ingress` - `redis.address` diff --git a/pkg/config/config.go b/pkg/config/config.go index 1eac636..208bc82 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -119,6 +119,7 @@ func Initialize() (*Config, error) { masked := *cfg masked.EncryptionKey = redacted masked.OpenID.ClientJWK = redacted + masked.OpenID.ClientSecret = redacted masked.Redis.Password = redacted logger.Infof("config: %+v", masked) diff --git a/pkg/config/openid.go b/pkg/config/openid.go index 07a4a4a..2455be4 100644 --- a/pkg/config/openid.go +++ b/pkg/config/openid.go @@ -20,6 +20,7 @@ type OpenID struct { Audiences []string `json:"audiences"` ClientID string `json:"client-id"` ClientJWK string `json:"client-jwk"` + ClientSecret string `json:"client-secret"` PostLogoutRedirectURI string `json:"post-logout-redirect-uri"` Provider Provider `json:"provider"` ResourceIndicator string `json:"resource-indicator"` @@ -43,6 +44,7 @@ const ( OpenIDAudiences = "openid.audiences" OpenIDClientID = "openid.client-id" OpenIDClientJWK = "openid.client-jwk" + OpenIDClientSecret = "openid.client-secret" OpenIDPostLogoutRedirectURI = "openid.post-logout-redirect-uri" OpenIDProvider = "openid.provider" OpenIDResourceIndicator = "openid.resource-indicator" @@ -55,7 +57,8 @@ func openidFlags() { flag.String(OpenIDACRValues, "", "Space separated string that configures the default security level (acr_values) parameter for authorization requests.") flag.StringSlice(OpenIDAudiences, []string{}, "List of additional trusted audiences (other than the client_id) for OpenID Connect id_token validation.") flag.String(OpenIDClientID, "", "Client ID for the OpenID client.") - flag.String(OpenIDClientJWK, "", "JWK containing the private key for the OpenID client in string format.") + flag.String(OpenIDClientJWK, "", "JWK containing the private key for the OpenID client in string format. If configured, this takes precedence over 'openid.client-secret'.") + flag.String(OpenIDClientSecret, "", "Client secret for the OpenID client. Overridden by 'openid.client-jwk', if configured.") flag.String(OpenIDPostLogoutRedirectURI, "", "URI for redirecting the user after successful logout at the Identity Provider.") flag.String(OpenIDProvider, string(ProviderOpenID), "Provider configuration to load and use, either 'openid', 'azure', 'idporten'.") flag.String(OpenIDResourceIndicator, "", "OAuth2 resource indicator to include in authorization request for acquiring audience-restricted tokens.") diff --git a/pkg/mock/client.go b/pkg/mock/client.go index 3286ecc..dba124c 100644 --- a/pkg/mock/client.go +++ b/pkg/mock/client.go @@ -2,9 +2,9 @@ package mock import ( "github.com/lestrrat-go/jwx/v2/jwk" - "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/crypto" + openidconfig "github.com/nais/wonderwall/pkg/openid/config" "github.com/nais/wonderwall/pkg/openid/scopes" ) @@ -22,6 +22,10 @@ func (c *TestClientConfiguration) Audiences() map[string]bool { return c.trustedAudiences } +func (c *TestClientConfiguration) AuthMethod() openidconfig.AuthMethod { + return openidconfig.AuthMethodPrivateKeyJWT +} + func (c *TestClientConfiguration) ClientID() string { return c.Config.OpenID.ClientID } @@ -30,6 +34,10 @@ func (c *TestClientConfiguration) ClientJWK() jwk.Key { return c.clientJwk } +func (c *TestClientConfiguration) ClientSecret() string { + return c.Config.OpenID.ClientSecret +} + func (c *TestClientConfiguration) SetPostLogoutRedirectURI(uri string) { c.Config.OpenID.PostLogoutRedirectURI = uri } diff --git a/pkg/openid/client/client.go b/pkg/openid/client/client.go index b628ca2..a66d6ff 100644 --- a/pkg/openid/client/client.go +++ b/pkg/openid/client/client.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "net/http" - "net/url" "strings" "time" @@ -107,54 +106,19 @@ func (c *Client) AuthCodeGrant(ctx context.Context, code string, opts []oauth2.A return c.oauth2Config.Exchange(ctx, code, opts...) } -func (c *Client) MakeAssertion(expiration time.Duration) (string, error) { - clientCfg := c.cfg.Client() - providerCfg := c.cfg.Provider() - key := clientCfg.ClientJWK() - - iat := time.Now().Add(-5 * time.Second).Truncate(time.Second) - exp := iat.Add(expiration) - - errs := make([]error, 0) - - tok := jwt.New() - errs = append(errs, tok.Set(jwt.IssuerKey, clientCfg.ClientID())) - errs = append(errs, tok.Set(jwt.SubjectKey, clientCfg.ClientID())) - errs = append(errs, tok.Set(jwt.AudienceKey, providerCfg.Issuer())) - errs = append(errs, tok.Set(jwt.IssuedAtKey, iat)) - errs = append(errs, tok.Set(jwt.ExpirationKey, exp)) - errs = append(errs, tok.Set(jwt.JwtIDKey, uuid.New().String())) - - for _, err := range errs { - if err != nil { - return "", fmt.Errorf("setting claim for client assertion: %w", err) - } - } - - encoded, err := jwt.Sign(tok, jwt.WithKey(key.Algorithm(), key)) - if err != nil { - return "", fmt.Errorf("signing client assertion: %w", err) - } - - return string(encoded), nil -} - func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid.TokenResponse, error) { - assertion, err := c.MakeAssertion(DefaultClientAssertionLifetime) + params, err := c.AuthParams() if err != nil { - return nil, fmt.Errorf("creating client assertion: %w", err) + return nil, err } - v := url.Values{} - v.Set("grant_type", "refresh_token") - v.Set("refresh_token", refreshToken) - v.Set("client_id", c.cfg.Client().ClientID()) + requestBody := strings.NewReader(params.URLValues(map[string]string{ + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "client_id": c.cfg.Client().ClientID(), + }).Encode()) - for key, val := range openid.JwtAuthenticationParameters(assertion) { - v.Set(key, val) - } - - r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.Provider().TokenEndpoint(), strings.NewReader(v.Encode())) + r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.Provider().TokenEndpoint(), requestBody) if err != nil { return nil, fmt.Errorf("creating request: %w", err) } @@ -188,3 +152,52 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid return &tokenResponse, nil } + +func (c *Client) AuthParams() (openid.AuthParams, error) { + switch c.cfg.Client().AuthMethod() { + case openidconfig.AuthMethodPrivateKeyJWT: + assertion, err := c.MakeAssertion(DefaultClientAssertionLifetime) + if err != nil { + return nil, fmt.Errorf("creating client assertion: %w", err) + } + + return openid.AuthParamsJwtBearer(assertion), nil + + case openidconfig.AuthMethodClientSecret: + return openid.AuthParamsClientSecret(c.cfg.Client().ClientSecret()), nil + } + + return nil, fmt.Errorf("unsupported client authentication method: %q", c.cfg.Client().AuthMethod()) +} + +func (c *Client) MakeAssertion(expiration time.Duration) (string, error) { + clientCfg := c.cfg.Client() + providerCfg := c.cfg.Provider() + key := clientCfg.ClientJWK() + + iat := time.Now().Add(-5 * time.Second).Truncate(time.Second) + exp := iat.Add(expiration) + + errs := make([]error, 0) + + tok := jwt.New() + errs = append(errs, tok.Set(jwt.IssuerKey, clientCfg.ClientID())) + errs = append(errs, tok.Set(jwt.SubjectKey, clientCfg.ClientID())) + errs = append(errs, tok.Set(jwt.AudienceKey, providerCfg.Issuer())) + errs = append(errs, tok.Set(jwt.IssuedAtKey, iat)) + errs = append(errs, tok.Set(jwt.ExpirationKey, exp)) + errs = append(errs, tok.Set(jwt.JwtIDKey, uuid.New().String())) + + for _, err := range errs { + if err != nil { + return "", fmt.Errorf("setting claim for client assertion: %w", err) + } + } + + encoded, err := jwt.Sign(tok, jwt.WithKey(key.Algorithm(), key)) + if err != nil { + return "", fmt.Errorf("signing client assertion: %w", err) + } + + return string(encoded), nil +} diff --git a/pkg/openid/client/login_callback.go b/pkg/openid/client/login_callback.go index c5744c3..484f827 100644 --- a/pkg/openid/client/login_callback.go +++ b/pkg/openid/client/login_callback.go @@ -55,16 +55,15 @@ func (in *LoginCallback) StateMismatchError() error { } func (in *LoginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, error) { - clientAssertion, err := in.MakeAssertion(DefaultClientAssertionLifetime) + params, err := in.AuthParams() if err != nil { - return nil, fmt.Errorf("creating client assertion: %w", err) + return nil, err } - opts := []oauth2.AuthCodeOption{ + opts := params.AuthCodeOptions([]oauth2.AuthCodeOption{ openid.RedirectURIOption(in.cookie.RedirectURI), oauth2.VerifierOption(in.cookie.CodeVerifier), - } - opts = openid.WithJwtAuthentication(opts, clientAssertion) + }) code := in.requestParams.Get("code") rawTokens, err := in.AuthCodeGrant(ctx, code, opts) diff --git a/pkg/openid/config/client.go b/pkg/openid/config/client.go index fac1dff..5e483e5 100644 --- a/pkg/openid/config/client.go +++ b/pkg/openid/config/client.go @@ -4,16 +4,26 @@ import ( "fmt" "github.com/lestrrat-go/jwx/v2/jwk" + log "github.com/sirupsen/logrus" "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/openid/scopes" ) +type AuthMethod string + +const ( + AuthMethodPrivateKeyJWT AuthMethod = "private_key_jwt" + AuthMethodClientSecret AuthMethod = "client_secret" +) + type Client interface { ACRValues() string Audiences() map[string]bool + AuthMethod() AuthMethod ClientID() string ClientJWK() jwk.Key + ClientSecret() string PostLogoutRedirectURI() string ResourceIndicator() string Scopes() scopes.Scopes @@ -23,6 +33,7 @@ type Client interface { type client struct { config.OpenID + authMethod AuthMethod clientJwk jwk.Key trustedAudiences map[string]bool } @@ -35,6 +46,10 @@ func (in *client) Audiences() map[string]bool { return in.trustedAudiences } +func (in *client) AuthMethod() AuthMethod { + return in.authMethod +} + func (in *client) ClientID() string { return in.OpenID.ClientID } @@ -43,6 +58,10 @@ func (in *client) ClientJWK() jwk.Key { return in.clientJwk } +func (in *client) ClientSecret() string { + return in.OpenID.ClientSecret +} + func (in *client) PostLogoutRedirectURI() string { return in.OpenID.PostLogoutRedirectURI } @@ -64,20 +83,31 @@ func (in *client) WellKnownURL() string { } func NewClientConfig(cfg *config.Config) (Client, error) { - clientJwkString := cfg.OpenID.ClientJWK - if len(clientJwkString) == 0 { - return nil, fmt.Errorf("missing required config %s", config.OpenIDClientJWK) + c := &client{ + OpenID: cfg.OpenID, + trustedAudiences: cfg.OpenID.TrustedAudiences(), } - clientJwk, err := jwk.ParseKey([]byte(clientJwkString)) - if err != nil { - return nil, fmt.Errorf("parsing client JWK: %w", err) + if len(cfg.OpenID.ClientJWK) == 0 && len(cfg.OpenID.ClientSecret) == 0 { + return nil, fmt.Errorf("missing required config: at least one of %q or %q must be set", config.OpenIDClientJWK, config.OpenIDClientSecret) } - c := &client{ - OpenID: cfg.OpenID, - clientJwk: clientJwk, - trustedAudiences: cfg.OpenID.TrustedAudiences(), + if len(cfg.OpenID.ClientSecret) > 0 { + c.authMethod = AuthMethodClientSecret + } + + if len(cfg.OpenID.ClientJWK) > 0 { + if c.authMethod == AuthMethodClientSecret { + log.WithField("logger", "wonderwall.config").Info("both client JWK and client secret were set; using client JWK...") + } + + clientJwk, err := jwk.ParseKey([]byte(cfg.OpenID.ClientJWK)) + if err != nil { + return nil, fmt.Errorf("parsing client JWK: %w", err) + } + + c.clientJwk = clientJwk + c.authMethod = AuthMethodPrivateKeyJWT } var clientConfig Client @@ -87,17 +117,17 @@ func NewClientConfig(cfg *config.Config) (Client, error) { case config.ProviderAzure: clientConfig = c.Azure() case "": - return nil, fmt.Errorf("missing required config %s", config.OpenIDProvider) + return nil, fmt.Errorf("missing required config %q", config.OpenIDProvider) default: clientConfig = c } if len(clientConfig.ClientID()) == 0 { - return nil, fmt.Errorf("missing required config %s", config.OpenIDClientID) + return nil, fmt.Errorf("missing required config %q", config.OpenIDClientID) } if len(clientConfig.WellKnownURL()) == 0 { - return nil, fmt.Errorf("missing required config %s", config.OpenIDWellKnownURL) + return nil, fmt.Errorf("missing required config %q", config.OpenIDWellKnownURL) } return clientConfig, nil diff --git a/pkg/openid/oauth2.go b/pkg/openid/oauth2.go index 8a398ac..2fe5be5 100644 --- a/pkg/openid/oauth2.go +++ b/pkg/openid/oauth2.go @@ -21,20 +21,47 @@ type TokenErrorResponse struct { ErrorDescription string `json:"error_description"` } -// JwtAuthenticationParameters returns a map of parameters to be sent to the authorization server when using a JWT for client authentication in RFC 7523, section 2.2. -func JwtAuthenticationParameters(clientAssertion string) map[string]string { - return map[string]string{ - "client_assertion": clientAssertion, - "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", +type AuthParams map[string]string + +// AuthCodeOptions adds AuthParams to the given [oauth2.AuthCodeOption] slice and returns the updated slice. +func (a AuthParams) AuthCodeOptions(opts []oauth2.AuthCodeOption) []oauth2.AuthCodeOption { + for key, val := range a { + opts = append(opts, oauth2.SetAuthURLParam(key, val)) } + + return opts } -func WithJwtAuthentication(opts []oauth2.AuthCodeOption, clientAssertion string) []oauth2.AuthCodeOption { - for k, v := range JwtAuthenticationParameters(clientAssertion) { - opts = append(opts, oauth2.SetAuthURLParam(k, v)) +// URLValues adds AuthParams to the given map of parameters and returns a [url.Values]. +func (a AuthParams) URLValues(params map[string]string) url.Values { + v := url.Values{} + + for key, val := range params { + v.Set(key, val) } - return opts + for key, val := range a { + v.Set(key, val) + } + + return v +} + +// AuthParamsClientSecret returns a map of parameters to be sent to the authorization server when using a client secret for client authentication in RFC 6749, section 2.3.1. +// The target authorization server must support the "client_secret_post" client authentication method. +func AuthParamsClientSecret(clientSecret string) AuthParams { + return map[string]string{ + "client_secret": clientSecret, + } +} + +// AuthParamsJwtBearer returns a map of parameters to be sent to the authorization server when using a JWT for client authentication in RFC 7523, section 2.2. +// The target authorization server must support the "private_key_jwt" client authentication method. +func AuthParamsJwtBearer(clientAssertion string) AuthParams { + return map[string]string{ + "client_assertion": clientAssertion, + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + } } func RedirectURIOption(redirectUri string) oauth2.AuthCodeOption {