Skip to content

Commit

Permalink
Bed-5008 feat: Add role provision support (#1043)
Browse files Browse the repository at this point in the history
  • Loading branch information
mistahj67 authored Jan 2, 2025
1 parent 7c1676e commit 19089bc
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 46 deletions.
40 changes: 22 additions & 18 deletions cmd/api/src/api/v2/auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/specterops/bloodhound/headers"
"github.com/specterops/bloodhound/log"
"github.com/specterops/bloodhound/src/api"
"github.com/specterops/bloodhound/src/api/v2"
"github.com/specterops/bloodhound/src/config"
"github.com/specterops/bloodhound/src/ctx"
"github.com/specterops/bloodhound/src/database"
Expand All @@ -45,11 +46,12 @@ var (
)

type oidcClaims struct {
Name string `json:"name"`
FamilyName string `json:"family_name"`
DisplayName string `json:"given_name"`
Email string `json:"email"`
Verified bool `json:"email_verified"`
Name string `json:"name"`
FamilyName string `json:"family_name"`
DisplayName string `json:"given_name"`
Email string `json:"email"`
Verified bool `json:"email_verified"`
Roles []string `json:"roles"`
}

// UpsertOIDCProviderRequest represents the body of create & update provider endpoints
Expand Down Expand Up @@ -148,16 +150,16 @@ func (s ManagementResource) OIDCLoginHandler(response http.ResponseWriter, reque

if ssoProvider.OIDCProvider == nil {
// SSO misconfiguration scenario
redirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
v2.RedirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
} else if state, err := config.GenerateRandomBase64String(77); err != nil {
log.Errorf("[OIDC] Failed to generate state: %v", err)
// Technical issues scenario
redirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
v2.RedirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
} else if provider, err := oidc.NewProvider(request.Context(), ssoProvider.OIDCProvider.Issuer); err != nil {
log.Errorf("[OIDC] Failed to create OIDC provider: %v", err)
// SSO misconfiguration or technical issue
// Treat this as a misconfiguration scenario
redirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
v2.RedirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
} else {
conf := &oauth2.Config{
ClientID: ssoProvider.OIDCProvider.ClientID,
Expand Down Expand Up @@ -193,27 +195,27 @@ func (s ManagementResource) OIDCCallbackHandler(response http.ResponseWriter, re

if ssoProvider.OIDCProvider == nil {
// SSO misconfiguration scenario
redirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
v2.RedirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
} else if len(code) == 0 {
// Missing authorization code implies a credentials or form issue
// Not explicitly covered, treat as technical issue
redirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
v2.RedirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
} else if pkceVerifier, err := request.Cookie(api.AuthPKCECookieName); err != nil {
// Missing PKCE verifier - likely a technical or config issue
redirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
v2.RedirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
} else if len(state) == 0 {
// Missing state parameter - treat as technical issue
redirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
v2.RedirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
} else if stateCookie, err := request.Cookie(api.AuthStateCookieName); err != nil || stateCookie.Value != state[0] {
// Invalid state - treat as technical issue or misconfiguration
redirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
v2.RedirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
} else if provider, err := oidc.NewProvider(request.Context(), ssoProvider.OIDCProvider.Issuer); err != nil {
log.Errorf("[OIDC] Failed to create OIDC provider: %v", err)
// SSO misconfiguration scenario
redirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
v2.RedirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
} else if claims, err := getOIDCClaims(request.Context(), provider, ssoProvider, pkceVerifier, code[0]); err != nil {
log.Errorf("[OIDC] %v", err)
redirectToLoginPage(response, request, "Your SSO was unable to authenticate your user, please contact your Administrator")
v2.RedirectToLoginPage(response, request, "Your SSO was unable to authenticate your user, please contact your Administrator")
} else {
if ssoProvider.Config.AutoProvision.Enabled {
if err := jitOIDCUserCreation(request.Context(), ssoProvider, claims, s.db); err != nil {
Expand Down Expand Up @@ -254,15 +256,17 @@ func getOIDCClaims(reqCtx context.Context, provider *oidc.Provider, ssoProvider
}

func jitOIDCUserCreation(ctx context.Context, ssoProvider model.SSOProvider, claims oidcClaims, u jitUserCreator) error {
if role, err := u.GetRole(ctx, ssoProvider.Config.AutoProvision.DefaultRoleId); err != nil {
return fmt.Errorf("get role: %v", err)
if roles, err := SanitizeAndGetRoles(ctx, ssoProvider.Config.AutoProvision, claims.Roles, u); err != nil {
return fmt.Errorf("sanitize roles: %v", err)
} else if len(roles) != 1 {
return fmt.Errorf("invalid roles")
} else if _, err := u.LookupUser(ctx, claims.Email); err != nil && !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("lookup user: %v", err)
} else if errors.Is(err, database.ErrNotFound) {
var user = model.User{
EmailAddress: null.StringFrom(claims.Email),
PrincipalName: claims.Email,
Roles: model.Roles{role},
Roles: roles,
SSOProviderID: null.Int32From(ssoProvider.ID),
EULAAccepted: true, // EULA Acceptance does not pertain to Bloodhound Community Edition; this flag is used for Bloodhound Enterprise users
FirstName: null.StringFrom(claims.Email),
Expand Down
32 changes: 17 additions & 15 deletions cmd/api/src/api/v2/auth/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
"github.com/specterops/bloodhound/log"
"github.com/specterops/bloodhound/mediatypes"
"github.com/specterops/bloodhound/src/api"
v2 "github.com/specterops/bloodhound/src/api/v2"
"github.com/specterops/bloodhound/src/api/v2"
"github.com/specterops/bloodhound/src/auth"
"github.com/specterops/bloodhound/src/ctx"
"github.com/specterops/bloodhound/src/database"
Expand Down Expand Up @@ -379,12 +379,12 @@ func (s ManagementResource) ServeSigningCertificate(response http.ResponseWriter
func (s ManagementResource) SAMLLoginHandler(response http.ResponseWriter, request *http.Request, ssoProvider model.SSOProvider) {
if ssoProvider.SAMLProvider == nil {
// SAML misconfiguration scenario
redirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
v2.RedirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")

} else if serviceProvider, err := auth.NewServiceProvider(*ctx.Get(request.Context()).Host, s.config, *ssoProvider.SAMLProvider); err != nil {
log.Errorf("[SAML] Service provider creation failed: %v", err)
// Technical issues scenario
redirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
v2.RedirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
} else {
var (
binding = saml.HTTPRedirectBinding
Expand All @@ -400,14 +400,14 @@ func (s ManagementResource) SAMLLoginHandler(response http.ResponseWriter, reque
log.Errorf("[SAML] Failed creating SAML authentication request: %v", err)
// SAML misconfiguration or technical issue
// Since this likely indicates a configuration problem, we treat it as a misconfiguration scenario
redirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
v2.RedirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
} else {
switch binding {
case saml.HTTPRedirectBinding:
if redirectURL, err := authReq.Redirect("", &serviceProvider); err != nil {
log.Errorf("[SAML] Failed to format a redirect for SAML provider %s: %v", serviceProvider.EntityID, err)
// Likely a technical or configuration issue
redirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
v2.RedirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
} else {
response.Header().Add(headers.Location.String(), redirectURL.String())
response.WriteHeader(http.StatusFound)
Expand All @@ -421,13 +421,13 @@ func (s ManagementResource) SAMLLoginHandler(response http.ResponseWriter, reque
if _, err := response.Write([]byte(fmt.Sprintf(authInitiationContentBodyFormat, authReq.Post("")))); err != nil {
log.Errorf("[SAML] Failed to write response with HTTP POST binding: %v", err)
// Technical issues scenario
redirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
v2.RedirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
}

default:
log.Errorf("[SAML] Unhandled binding type %s", binding)
// Treating unknown binding as a misconfiguration
redirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
v2.RedirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
}
}
}
Expand All @@ -437,15 +437,15 @@ func (s ManagementResource) SAMLLoginHandler(response http.ResponseWriter, reque
func (s ManagementResource) SAMLCallbackHandler(response http.ResponseWriter, request *http.Request, ssoProvider model.SSOProvider) {
if ssoProvider.SAMLProvider == nil {
// SAML misconfiguration
redirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
v2.RedirectToLoginPage(response, request, "Your SSO Connection failed, please contact your Administrator")
} else if serviceProvider, err := auth.NewServiceProvider(*ctx.Get(request.Context()).Host, s.config, *ssoProvider.SAMLProvider); err != nil {
log.Errorf("[SAML] Service provider creation failed: %v", err)
redirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
v2.RedirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
} else if err := request.ParseForm(); err != nil {
log.Errorf("[SAML] Failed to parse form POST: %v", err)
// Technical issues or invalid form data
// This is not covered by acceptance criteria directly; treat as technical issue
redirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
v2.RedirectToLoginPage(response, request, "We’re having trouble connecting. Please check your internet and try again.")
} else if assertion, err := serviceProvider.ParseResponse(request, nil); err != nil {
var typedErr *saml.InvalidResponseError
switch {
Expand All @@ -455,11 +455,11 @@ func (s ManagementResource) SAMLCallbackHandler(response http.ResponseWriter, re
log.Errorf("[SAML] Failed to parse ACS response for provider %s: %v", ssoProvider.SAMLProvider.IssuerURI, err)
}
// SAML credentials issue scenario (authentication failed)
redirectToLoginPage(response, request, "Your SSO was unable to authenticate your user, please contact your Administrator")
v2.RedirectToLoginPage(response, request, "Your SSO was unable to authenticate your user, please contact your Administrator")
} else if principalName, err := ssoProvider.SAMLProvider.GetSAMLUserPrincipalNameFromAssertion(assertion); err != nil {
log.Errorf("[SAML] Failed to lookup user for SAML provider %s: %v", ssoProvider.Name, err)
// SAML credentials issue scenario again
redirectToLoginPage(response, request, "Your SSO was unable to authenticate your user, please contact your Administrator")
v2.RedirectToLoginPage(response, request, "Your SSO was unable to authenticate your user, please contact your Administrator")
} else {
if ssoProvider.Config.AutoProvision.Enabled {
if err := jitSAMLUserCreation(request.Context(), ssoProvider, principalName, assertion, s.db); err != nil {
Expand All @@ -473,15 +473,17 @@ func (s ManagementResource) SAMLCallbackHandler(response http.ResponseWriter, re
}

func jitSAMLUserCreation(ctx context.Context, ssoProvider model.SSOProvider, principalName string, assertion *saml.Assertion, u jitUserCreator) error {
if role, err := u.GetRole(ctx, ssoProvider.Config.AutoProvision.DefaultRoleId); err != nil {
return fmt.Errorf("get role: %v", err)
if roles, err := SanitizeAndGetRoles(ctx, ssoProvider.Config.AutoProvision, ssoProvider.SAMLProvider.GetSAMLUserRolesFromAssertion(assertion), u); err != nil {
return fmt.Errorf("sanitize roles: %v", err)
} else if len(roles) != 1 {
return fmt.Errorf("invalid roles detected")
} else if _, err := u.LookupUser(ctx, principalName); err != nil && !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("lookup user: %v", err)
} else if errors.Is(err, database.ErrNotFound) {
user := model.User{
EmailAddress: null.StringFrom(principalName),
PrincipalName: principalName,
Roles: model.Roles{role},
Roles: roles,
SSOProviderID: null.Int32From(ssoProvider.ID),
EULAAccepted: true, // EULA Acceptance does not pertain to Bloodhound Community Edition; this flag is used for Bloodhound Enterprise users
FirstName: null.StringFrom(principalName),
Expand Down
60 changes: 47 additions & 13 deletions cmd/api/src/api/v2/auth/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@ package auth

import (
"context"
"fmt"
"net/http"
"net/url"
"path"
"strconv"
"strings"

"github.com/specterops/bloodhound/headers"

"github.com/gorilla/mux"
"github.com/specterops/bloodhound/dawgs/cardinality"
"github.com/specterops/bloodhound/log"
"github.com/specterops/bloodhound/src/api"
"github.com/specterops/bloodhound/src/auth"
"github.com/specterops/bloodhound/src/ctx"
Expand Down Expand Up @@ -61,8 +62,12 @@ type getRoler interface {
GetRole(ctx context.Context, roleID int32) (model.Role, error)
}

type getAllRoler interface {
GetAllRoles(ctx context.Context, order string, filter model.SQLFilter) (model.Roles, error)
}

type jitUserCreator interface {
getRoler
getAllRoler

LookupUser(ctx context.Context, principalNameOrEmail string) (model.User, error)
CreateUser(ctx context.Context, user model.User) (model.User, error)
Expand Down Expand Up @@ -245,16 +250,45 @@ func (s ManagementResource) SSOCallbackHandler(response http.ResponseWriter, req
}
}

func redirectToLoginPage(response http.ResponseWriter, request *http.Request, errorMessage string) {
hostURL := *ctx.FromRequest(request).Host
redirectURL := api.URLJoinPath(hostURL, api.UserInterfacePath)
func SanitizeAndGetRoles(ctx context.Context, autoProvisionConfig model.SSOProviderAutoProvisionConfig, maybeBHRoles []string, r getAllRoler) (model.Roles, error) {
if dbRoles, err := r.GetAllRoles(ctx, "", model.SQLFilter{}); err != nil {
return nil, err
} else {
var defaultRole model.Role
dbRolesBySlug := make(map[string]*model.Role)
// Make quick lookup by role slug -> lower cased, dashes for spaces, and prefixed by `bh` e.g. bh-power-user
for _, r := range dbRoles {
dbRolesBySlug[fmt.Sprintf("bh-%s", strings.ReplaceAll(strings.ToLower(r.Name), " ", "-"))] = &r
if r.ID == autoProvisionConfig.DefaultRoleId {
defaultRole = r
}
}

// Optionally, include the error message as a query parameter or in session storage
query := redirectURL.Query()
query.Set("error", errorMessage)
redirectURL.RawQuery = query.Encode()
if autoProvisionConfig.RoleProvision {
var validRoles model.Roles
validRolesSeen := cardinality.NewBitmap32() // Ensure no dupes
// Only add valid roles
for _, r := range maybeBHRoles {
if dbRole := dbRolesBySlug[strings.ReplaceAll(strings.ToLower(r), " ", "-")]; dbRole != nil && !validRolesSeen.Contains(uint32(dbRole.ID)) {
validRoles = append(validRoles, *dbRole)
validRolesSeen.Add(uint32(dbRole.ID))
}
}
switch {
case len(validRoles) == 1:
return validRoles, nil
case len(validRoles) > 1:
log.Warnf("[SSO] JIT Role Provision detected multiple valid roles - %s , falling back to default role %s", validRoles.Names(), defaultRole.Name)
default:
log.Warnf("[SSO] JIT Role Provision detected no valid roles from %s , falling back to default role %s", maybeBHRoles, defaultRole.Name)
}
}

// Redirect to the login page
response.Header().Add(headers.Location.String(), redirectURL.String())
response.WriteHeader(http.StatusFound)
/* Fallback to default role:
- Role provision is disabled
- Role provision is enabled but no valid roles are found
- Role provision is enabled but multiple valid roles are found
*/
return model.Roles{defaultRole}, nil
}
}
Loading

0 comments on commit 19089bc

Please sign in to comment.