diff --git a/samlsp/new.go b/samlsp/new.go index 81fa75f6..8bec69eb 100644 --- a/samlsp/new.go +++ b/samlsp/new.go @@ -2,21 +2,22 @@ package samlsp import ( + "crypto" + "crypto/ecdsa" "crypto/rsa" "crypto/x509" "net/http" "net/url" - dsig "github.com/russellhaering/goxmldsig" - "github.com/crewjam/saml" + "github.com/golang-jwt/jwt/v4" ) // Options represents the parameters for creating a new middleware type Options struct { EntityID string URL url.URL - Key *rsa.PrivateKey + Key crypto.Signer Certificate *x509.Certificate Intermediates []*x509.Certificate HTTPClient *http.Client @@ -33,11 +34,23 @@ type Options struct { LogoutBindings []string } +func getDefaultSigningMethod(signer crypto.Signer) jwt.SigningMethod { + if signer != nil { + switch signer.Public().(type) { + case *ecdsa.PublicKey: + return jwt.SigningMethodES256 + case *rsa.PublicKey: + return jwt.SigningMethodRS256 + } + } + return jwt.SigningMethodRS256 +} + // DefaultSessionCodec returns the default SessionCodec for the provided options, // a JWTSessionCodec configured to issue signed tokens. func DefaultSessionCodec(opts Options) JWTSessionCodec { return JWTSessionCodec{ - SigningMethod: defaultJWTSigningMethod, + SigningMethod: getDefaultSigningMethod(opts.Key), Audience: opts.URL.String(), Issuer: opts.URL.String(), MaxAge: defaultSessionMaxAge, @@ -67,7 +80,7 @@ func DefaultSessionProvider(opts Options) CookieSessionProvider { // options, a JWTTrackedRequestCodec that uses a JWT to encode TrackedRequests. func DefaultTrackedRequestCodec(opts Options) JWTTrackedRequestCodec { return JWTTrackedRequestCodec{ - SigningMethod: defaultJWTSigningMethod, + SigningMethod: getDefaultSigningMethod(opts.Key), Audience: opts.URL.String(), Issuer: opts.URL.String(), MaxAge: saml.MaxIssueDelay, @@ -99,9 +112,9 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider { if opts.ForceAuthn { forceAuthn = &opts.ForceAuthn } - signatureMethod := dsig.RSASHA1SignatureMethod - if !opts.SignRequest { - signatureMethod = "" + var signatureMethod string + if opts.SignRequest { + signatureMethod = "auto" } if opts.DefaultRedirectURI == "" { diff --git a/samlsp/request_tracker_jwt.go b/samlsp/request_tracker_jwt.go index 0ca47258..caafe540 100644 --- a/samlsp/request_tracker_jwt.go +++ b/samlsp/request_tracker_jwt.go @@ -1,7 +1,7 @@ package samlsp import ( - "crypto/rsa" + "crypto" "fmt" "time" @@ -10,15 +10,13 @@ import ( "github.com/crewjam/saml" ) -var defaultJWTSigningMethod = jwt.SigningMethodRS256 - // JWTTrackedRequestCodec encodes TrackedRequests as signed JWTs type JWTTrackedRequestCodec struct { SigningMethod jwt.SigningMethod Audience string Issuer string MaxAge time.Duration - Key *rsa.PrivateKey + Key crypto.Signer } var _ TrackedRequestCodec = JWTTrackedRequestCodec{} diff --git a/samlsp/session_jwt.go b/samlsp/session_jwt.go index 8d801e47..d7217251 100644 --- a/samlsp/session_jwt.go +++ b/samlsp/session_jwt.go @@ -1,7 +1,7 @@ package samlsp import ( - "crypto/rsa" + "crypto" "errors" "fmt" "time" @@ -23,7 +23,7 @@ type JWTSessionCodec struct { Audience string Issuer string MaxAge time.Duration - Key *rsa.PrivateKey + Key crypto.Signer } var _ SessionCodec = JWTSessionCodec{} diff --git a/service_provider.go b/service_provider.go index 30b35670..72aa38f1 100644 --- a/service_provider.go +++ b/service_provider.go @@ -4,7 +4,7 @@ import ( "bytes" "compress/flate" "context" - "crypto/rsa" + "crypto" "crypto/tls" "crypto/x509" "encoding/base64" @@ -67,7 +67,7 @@ type ServiceProvider struct { EntityID string // Key is the RSA private key we use to sign requests. - Key *rsa.PrivateKey + Key crypto.Signer // Certificate is the RSA public part of Key. Certificate *x509.Certificate @@ -117,7 +117,7 @@ type ServiceProvider struct { // to verify signatures. SignatureVerifier SignatureVerifier - // SignatureMethod, if non-empty, authentication requests will be signed + // SignatureMethod, if non-empty, authentication requests will be signed. "auto" will determine method based on certificate type. SignatureMethod string // LogoutBindings specify the bindings available for SLO endpoint. If empty, @@ -141,6 +141,11 @@ const DefaultValidDuration = time.Hour * 24 * 2 // DefaultCacheDuration is how long we ask the IDP to cache the SP metadata. const DefaultCacheDuration = time.Hour * 24 * 1 +// SignRequests returns true if the service provider should sign requests. +func (sp *ServiceProvider) SignRequests() bool { + return len(sp.SignatureMethod) > 0 +} + // Metadata returns the service provider metadata func (sp *ServiceProvider) Metadata() *EntityDescriptor { validDuration := DefaultValidDuration @@ -245,6 +250,19 @@ func (sp *ServiceProvider) MakeRedirectAuthenticationRequest(relayState string) return req.Redirect(relayState, sp) } +// GetSignatureMethod returns the appropriate string to represent the +// signature method for the service provider. +func (sp *ServiceProvider) GetSignatureMethod() (string, error) { + if sp.SignatureMethod == "auto" { + signingContext, err := GetSigningContext(sp) + if err != nil { + return "auto", err + } + return signingContext.GetSignatureMethodIdentifier(), nil + } + return sp.SignatureMethod, nil +} + // Redirect returns a URL suitable for using the redirect binding with the request func (r *AuthnRequest) Redirect(relayState string, sp *ServiceProvider) (*url.URL, error) { w := &bytes.Buffer{} @@ -274,13 +292,16 @@ func (r *AuthnRequest) Redirect(relayState string, sp *ServiceProvider) (*url.UR if relayState != "" { query += "&RelayState=" + relayState } - if len(sp.SignatureMethod) > 0 { - query += "&SigAlg=" + url.QueryEscape(sp.SignatureMethod) + if sp.SignRequests() { signingContext, err := GetSigningContext(sp) - if err != nil { return nil, err } + sigMethod, err := sp.GetSignatureMethod() + if err != nil { + return nil, err + } + query += "&SigAlg=" + url.QueryEscape(sigMethod) sig, err := signingContext.SignString(query) if err != nil { @@ -391,7 +412,7 @@ func (sp *ServiceProvider) MakeArtifactResolveRequest(artifactID string) (*Artif Artifact: artifactID, } - if len(sp.SignatureMethod) > 0 { + if sp.SignRequests() { if err := sp.SignArtifactResolve(&req); err != nil { return nil, err } @@ -428,7 +449,7 @@ func (sp *ServiceProvider) MakeAuthenticationRequest(idpURL string, binding stri RequestedAuthnContext: sp.RequestedAuthnContext, } // We don't need to sign the XML document if the IDP uses HTTP-Redirect binding - if len(sp.SignatureMethod) > 0 && binding == HTTPPostBinding { + if sp.SignRequests() && binding == HTTPPostBinding { if err := sp.SignAuthnRequest(&req); err != nil { return nil, err } @@ -449,13 +470,13 @@ func GetSigningContext(sp *ServiceProvider) (*dsig.SigningContext, error) { // } keyStore := dsig.TLSCertKeyStore(keyPair) - if sp.SignatureMethod != dsig.RSASHA1SignatureMethod && - sp.SignatureMethod != dsig.RSASHA256SignatureMethod && - sp.SignatureMethod != dsig.RSASHA512SignatureMethod { - return nil, fmt.Errorf("invalid signing method %s", sp.SignatureMethod) + signer, _ := sp.Key.(crypto.Signer) + chain, _ := keyStore.GetChain() + signingContext, err := dsig.NewSigningContext(signer, chain) + if err != nil { + return nil, err } - signatureMethod := sp.SignatureMethod - signingContext := dsig.NewDefaultSigningContext(keyStore) + signatureMethod := signingContext.GetSignatureMethodIdentifier() signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { return nil, err @@ -1170,13 +1191,13 @@ func (sp *ServiceProvider) SignLogoutRequest(req *LogoutRequest) error { // } keyStore := dsig.TLSCertKeyStore(keyPair) - if sp.SignatureMethod != dsig.RSASHA1SignatureMethod && - sp.SignatureMethod != dsig.RSASHA256SignatureMethod && - sp.SignatureMethod != dsig.RSASHA512SignatureMethod { - return fmt.Errorf("invalid signing method %s", sp.SignatureMethod) + signer, _ := sp.Key.(crypto.Signer) + chain, _ := keyStore.GetChain() + signingContext, err := dsig.NewSigningContext(signer, chain) + if err != nil { + return err } - signatureMethod := sp.SignatureMethod - signingContext := dsig.NewDefaultSigningContext(keyStore) + signatureMethod := signingContext.GetSignatureMethodIdentifier() signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { return err @@ -1213,7 +1234,7 @@ func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID string) (*LogoutRequ SPNameQualifier: sp.Metadata().EntityID, }, } - if len(sp.SignatureMethod) > 0 { + if sp.SignRequests() { if err := sp.SignLogoutRequest(&req); err != nil { return nil, err } @@ -1327,7 +1348,7 @@ func (sp *ServiceProvider) MakeLogoutResponse(idpURL, logoutRequestID string) (* }, } - if len(sp.SignatureMethod) > 0 { + if sp.SignRequests() { if err := sp.SignLogoutResponse(&response); err != nil { return nil, err } @@ -1435,13 +1456,13 @@ func (sp *ServiceProvider) SignLogoutResponse(resp *LogoutResponse) error { // } keyStore := dsig.TLSCertKeyStore(keyPair) - if sp.SignatureMethod != dsig.RSASHA1SignatureMethod && - sp.SignatureMethod != dsig.RSASHA256SignatureMethod && - sp.SignatureMethod != dsig.RSASHA512SignatureMethod { - return fmt.Errorf("invalid signing method %s", sp.SignatureMethod) + signer, _ := sp.Key.(crypto.Signer) + chain, _ := keyStore.GetChain() + signingContext, err := dsig.NewSigningContext(signer, chain) + if err != nil { + return err } - signatureMethod := sp.SignatureMethod - signingContext := dsig.NewDefaultSigningContext(keyStore) + signatureMethod := signingContext.GetSignatureMethodIdentifier() signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { return err diff --git a/service_provider_test.go b/service_provider_test.go index 4309738c..2b6e5480 100644 --- a/service_provider_test.go +++ b/service_provider_test.go @@ -350,7 +350,7 @@ func TestSPFailToProduceSignedRequestWithBogusSignatureMethod(t *testing.T) { assert.Check(t, err) _, err = s.MakeRedirectAuthenticationRequest("relayState") - assert.Check(t, is.ErrorContains(err, "invalid signing method bogus")) + assert.Check(t, is.ErrorContains(err, "unknown SignatureMethod: bogus")) } func TestSPCanProducePostLogoutRequest(t *testing.T) { @@ -1665,7 +1665,7 @@ func TestMakeSignedArtifactResolveRequestWithBogusSignatureMethod(t *testing.T) } _, err := sp.MakeArtifactResolveRequest("artifactId") - assert.Check(t, is.ErrorContains(err, "invalid signing method bogus")) + assert.Check(t, is.ErrorContains(err, "unknown SignatureMethod: bogus")) }