Skip to content

Commit

Permalink
feat(openid/client): add support for the client_secret_post authentic…
Browse files Browse the repository at this point in the history
…ation method
  • Loading branch information
tronghn committed Oct 8, 2024
1 parent 5df7234 commit df5c78b
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 76 deletions.
7 changes: 4 additions & 3 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ The following flags are available:
| `openid.acr-values` | string | Space separated string that configures the default security level (`acr_values`) parameter for authorization requests. | |
| `openid.audiences` | strings | List of additional trusted audiences (other than the client_id) for OpenID Connect id_token validation. | |
| `openid.client-id` | string | Client ID for the OpenID client. | |
| `openid.client-jwk` | string | JWK containing the private key for the OpenID client in string format. | |
| `openid.client-jwk` | string | JWK containing the private key for the OpenID client in string format. If configured, this takes precedence over 'openid.client-secret'. | |
| `openid.client-secret` | string | Client secret for the OpenID client. Overridden by 'openid.client-jwk', if configured. | |
| `openid.post-logout-redirect-uri` | string | URI for redirecting the user after successful logout at the Identity Provider. | |
| `openid.provider` | string | Provider configuration to load and use, either `openid`, `azure`, `idporten`. | `openid` |
| `openid.resource-indicator` | string | OAuth2 resource indicator to include in authorization request for acquiring audience-restricted tokens. | |
Expand Down Expand Up @@ -82,7 +83,7 @@ The default configuration of Wonderwall will start in [_standalone mode_](archit
At minimum, the following configuration must be provided when in standalone mode:

- `openid.client-id`
- `openid.client-jwk`
- `openid.client-jwk` or `openid.client-secret`
- `openid.well-known-url`
- `ingress`

Expand All @@ -99,7 +100,7 @@ When the `sso.enabled` flag is enabled and the `sso.mode` flag is set to `server
At minimum, the following configuration must be provided when in SSO server mode:

- `openid.client-id`
- `openid.client-jwk`
- `openid.client-jwk` or `openid.client-secret`
- `openid.well-known-url`
- `ingress`
- `redis.address`
Expand Down
1 change: 1 addition & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ func Initialize() (*Config, error) {
masked := *cfg
masked.EncryptionKey = redacted
masked.OpenID.ClientJWK = redacted
masked.OpenID.ClientSecret = redacted
masked.Redis.Password = redacted
logger.Infof("config: %+v", masked)

Expand Down
5 changes: 4 additions & 1 deletion pkg/config/openid.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type OpenID struct {
Audiences []string `json:"audiences"`
ClientID string `json:"client-id"`
ClientJWK string `json:"client-jwk"`
ClientSecret string `json:"client-secret"`
PostLogoutRedirectURI string `json:"post-logout-redirect-uri"`
Provider Provider `json:"provider"`
ResourceIndicator string `json:"resource-indicator"`
Expand All @@ -43,6 +44,7 @@ const (
OpenIDAudiences = "openid.audiences"
OpenIDClientID = "openid.client-id"
OpenIDClientJWK = "openid.client-jwk"
OpenIDClientSecret = "openid.client-secret"
OpenIDPostLogoutRedirectURI = "openid.post-logout-redirect-uri"
OpenIDProvider = "openid.provider"
OpenIDResourceIndicator = "openid.resource-indicator"
Expand All @@ -55,7 +57,8 @@ func openidFlags() {
flag.String(OpenIDACRValues, "", "Space separated string that configures the default security level (acr_values) parameter for authorization requests.")
flag.StringSlice(OpenIDAudiences, []string{}, "List of additional trusted audiences (other than the client_id) for OpenID Connect id_token validation.")
flag.String(OpenIDClientID, "", "Client ID for the OpenID client.")
flag.String(OpenIDClientJWK, "", "JWK containing the private key for the OpenID client in string format.")
flag.String(OpenIDClientJWK, "", "JWK containing the private key for the OpenID client in string format. If configured, this takes precedence over 'openid.client-secret'.")
flag.String(OpenIDClientSecret, "", "Client secret for the OpenID client. Overridden by 'openid.client-jwk', if configured.")
flag.String(OpenIDPostLogoutRedirectURI, "", "URI for redirecting the user after successful logout at the Identity Provider.")
flag.String(OpenIDProvider, string(ProviderOpenID), "Provider configuration to load and use, either 'openid', 'azure', 'idporten'.")
flag.String(OpenIDResourceIndicator, "", "OAuth2 resource indicator to include in authorization request for acquiring audience-restricted tokens.")
Expand Down
10 changes: 9 additions & 1 deletion pkg/mock/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package mock

import (
"github.com/lestrrat-go/jwx/v2/jwk"

"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/crypto"
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
"github.com/nais/wonderwall/pkg/openid/scopes"
)

Expand All @@ -22,6 +22,10 @@ func (c *TestClientConfiguration) Audiences() map[string]bool {
return c.trustedAudiences
}

func (c *TestClientConfiguration) AuthMethod() openidconfig.AuthMethod {
return openidconfig.AuthMethodPrivateKeyJWT
}

func (c *TestClientConfiguration) ClientID() string {
return c.Config.OpenID.ClientID
}
Expand All @@ -30,6 +34,10 @@ func (c *TestClientConfiguration) ClientJWK() jwk.Key {
return c.clientJwk
}

func (c *TestClientConfiguration) ClientSecret() string {
return c.Config.OpenID.ClientSecret
}

func (c *TestClientConfiguration) SetPostLogoutRedirectURI(uri string) {
c.Config.OpenID.PostLogoutRedirectURI = uri
}
Expand Down
101 changes: 57 additions & 44 deletions pkg/openid/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -107,54 +106,19 @@ func (c *Client) AuthCodeGrant(ctx context.Context, code string, opts []oauth2.A
return c.oauth2Config.Exchange(ctx, code, opts...)
}

func (c *Client) MakeAssertion(expiration time.Duration) (string, error) {
clientCfg := c.cfg.Client()
providerCfg := c.cfg.Provider()
key := clientCfg.ClientJWK()

iat := time.Now().Add(-5 * time.Second).Truncate(time.Second)
exp := iat.Add(expiration)

errs := make([]error, 0)

tok := jwt.New()
errs = append(errs, tok.Set(jwt.IssuerKey, clientCfg.ClientID()))
errs = append(errs, tok.Set(jwt.SubjectKey, clientCfg.ClientID()))
errs = append(errs, tok.Set(jwt.AudienceKey, providerCfg.Issuer()))
errs = append(errs, tok.Set(jwt.IssuedAtKey, iat))
errs = append(errs, tok.Set(jwt.ExpirationKey, exp))
errs = append(errs, tok.Set(jwt.JwtIDKey, uuid.New().String()))

for _, err := range errs {
if err != nil {
return "", fmt.Errorf("setting claim for client assertion: %w", err)
}
}

encoded, err := jwt.Sign(tok, jwt.WithKey(key.Algorithm(), key))
if err != nil {
return "", fmt.Errorf("signing client assertion: %w", err)
}

return string(encoded), nil
}

func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid.TokenResponse, error) {
assertion, err := c.MakeAssertion(DefaultClientAssertionLifetime)
params, err := c.AuthParams()
if err != nil {
return nil, fmt.Errorf("creating client assertion: %w", err)
return nil, err
}

v := url.Values{}
v.Set("grant_type", "refresh_token")
v.Set("refresh_token", refreshToken)
v.Set("client_id", c.cfg.Client().ClientID())
requestBody := strings.NewReader(params.URLValues(map[string]string{
"grant_type": "refresh_token",
"refresh_token": refreshToken,
"client_id": c.cfg.Client().ClientID(),
}).Encode())

for key, val := range openid.JwtAuthenticationParameters(assertion) {
v.Set(key, val)
}

r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.Provider().TokenEndpoint(), strings.NewReader(v.Encode()))
r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.Provider().TokenEndpoint(), requestBody)
if err != nil {
return nil, fmt.Errorf("creating request: %w", err)
}
Expand Down Expand Up @@ -188,3 +152,52 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid

return &tokenResponse, nil
}

func (c *Client) AuthParams() (openid.AuthParams, error) {
switch c.cfg.Client().AuthMethod() {
case openidconfig.AuthMethodPrivateKeyJWT:
assertion, err := c.MakeAssertion(DefaultClientAssertionLifetime)
if err != nil {
return nil, fmt.Errorf("creating client assertion: %w", err)
}

return openid.AuthParamsJwtBearer(assertion), nil

case openidconfig.AuthMethodClientSecret:
return openid.AuthParamsClientSecret(c.cfg.Client().ClientSecret()), nil
}

return nil, fmt.Errorf("unsupported client authentication method: %q", c.cfg.Client().AuthMethod())
}

func (c *Client) MakeAssertion(expiration time.Duration) (string, error) {
clientCfg := c.cfg.Client()
providerCfg := c.cfg.Provider()
key := clientCfg.ClientJWK()

iat := time.Now().Add(-5 * time.Second).Truncate(time.Second)
exp := iat.Add(expiration)

errs := make([]error, 0)

tok := jwt.New()
errs = append(errs, tok.Set(jwt.IssuerKey, clientCfg.ClientID()))
errs = append(errs, tok.Set(jwt.SubjectKey, clientCfg.ClientID()))
errs = append(errs, tok.Set(jwt.AudienceKey, providerCfg.Issuer()))
errs = append(errs, tok.Set(jwt.IssuedAtKey, iat))
errs = append(errs, tok.Set(jwt.ExpirationKey, exp))
errs = append(errs, tok.Set(jwt.JwtIDKey, uuid.New().String()))

for _, err := range errs {
if err != nil {
return "", fmt.Errorf("setting claim for client assertion: %w", err)
}
}

encoded, err := jwt.Sign(tok, jwt.WithKey(key.Algorithm(), key))
if err != nil {
return "", fmt.Errorf("signing client assertion: %w", err)
}

return string(encoded), nil
}
9 changes: 4 additions & 5 deletions pkg/openid/client/login_callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,15 @@ func (in *LoginCallback) StateMismatchError() error {
}

func (in *LoginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, error) {
clientAssertion, err := in.MakeAssertion(DefaultClientAssertionLifetime)
params, err := in.AuthParams()
if err != nil {
return nil, fmt.Errorf("creating client assertion: %w", err)
return nil, err
}

opts := []oauth2.AuthCodeOption{
opts := params.AuthCodeOptions([]oauth2.AuthCodeOption{
openid.RedirectURIOption(in.cookie.RedirectURI),
oauth2.VerifierOption(in.cookie.CodeVerifier),
}
opts = openid.WithJwtAuthentication(opts, clientAssertion)
})

code := in.requestParams.Get("code")
rawTokens, err := in.AuthCodeGrant(ctx, code, opts)
Expand Down
56 changes: 43 additions & 13 deletions pkg/openid/config/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,26 @@ import (
"fmt"

"github.com/lestrrat-go/jwx/v2/jwk"
log "github.com/sirupsen/logrus"

"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/openid/scopes"
)

type AuthMethod string

const (
AuthMethodPrivateKeyJWT AuthMethod = "private_key_jwt"
AuthMethodClientSecret AuthMethod = "client_secret"
)

type Client interface {
ACRValues() string
Audiences() map[string]bool
AuthMethod() AuthMethod
ClientID() string
ClientJWK() jwk.Key
ClientSecret() string
PostLogoutRedirectURI() string
ResourceIndicator() string
Scopes() scopes.Scopes
Expand All @@ -23,6 +33,7 @@ type Client interface {

type client struct {
config.OpenID
authMethod AuthMethod
clientJwk jwk.Key
trustedAudiences map[string]bool
}
Expand All @@ -35,6 +46,10 @@ func (in *client) Audiences() map[string]bool {
return in.trustedAudiences
}

func (in *client) AuthMethod() AuthMethod {
return in.authMethod
}

func (in *client) ClientID() string {
return in.OpenID.ClientID
}
Expand All @@ -43,6 +58,10 @@ func (in *client) ClientJWK() jwk.Key {
return in.clientJwk
}

func (in *client) ClientSecret() string {
return in.OpenID.ClientSecret
}

func (in *client) PostLogoutRedirectURI() string {
return in.OpenID.PostLogoutRedirectURI
}
Expand All @@ -64,20 +83,31 @@ func (in *client) WellKnownURL() string {
}

func NewClientConfig(cfg *config.Config) (Client, error) {
clientJwkString := cfg.OpenID.ClientJWK
if len(clientJwkString) == 0 {
return nil, fmt.Errorf("missing required config %s", config.OpenIDClientJWK)
c := &client{
OpenID: cfg.OpenID,
trustedAudiences: cfg.OpenID.TrustedAudiences(),
}

clientJwk, err := jwk.ParseKey([]byte(clientJwkString))
if err != nil {
return nil, fmt.Errorf("parsing client JWK: %w", err)
if len(cfg.OpenID.ClientJWK) == 0 && len(cfg.OpenID.ClientSecret) == 0 {
return nil, fmt.Errorf("missing required config: at least one of %q or %q must be set", config.OpenIDClientJWK, config.OpenIDClientSecret)
}

c := &client{
OpenID: cfg.OpenID,
clientJwk: clientJwk,
trustedAudiences: cfg.OpenID.TrustedAudiences(),
if len(cfg.OpenID.ClientSecret) > 0 {
c.authMethod = AuthMethodClientSecret
}

if len(cfg.OpenID.ClientJWK) > 0 {
if c.authMethod == AuthMethodClientSecret {
log.WithField("logger", "wonderwall.config").Info("both client JWK and client secret were set; using client JWK...")
}

clientJwk, err := jwk.ParseKey([]byte(cfg.OpenID.ClientJWK))
if err != nil {
return nil, fmt.Errorf("parsing client JWK: %w", err)
}

c.clientJwk = clientJwk
c.authMethod = AuthMethodPrivateKeyJWT
}

var clientConfig Client
Expand All @@ -87,17 +117,17 @@ func NewClientConfig(cfg *config.Config) (Client, error) {
case config.ProviderAzure:
clientConfig = c.Azure()
case "":
return nil, fmt.Errorf("missing required config %s", config.OpenIDProvider)
return nil, fmt.Errorf("missing required config %q", config.OpenIDProvider)
default:
clientConfig = c
}

if len(clientConfig.ClientID()) == 0 {
return nil, fmt.Errorf("missing required config %s", config.OpenIDClientID)
return nil, fmt.Errorf("missing required config %q", config.OpenIDClientID)
}

if len(clientConfig.WellKnownURL()) == 0 {
return nil, fmt.Errorf("missing required config %s", config.OpenIDWellKnownURL)
return nil, fmt.Errorf("missing required config %q", config.OpenIDWellKnownURL)
}

return clientConfig, nil
Expand Down
Loading

0 comments on commit df5c78b

Please sign in to comment.