diff --git a/assets/cla/consent.yaml b/assets/cla/consent.yaml index a5958a6..ce8bcb6 100644 --- a/assets/cla/consent.yaml +++ b/assets/cla/consent.yaml @@ -39,3 +39,6 @@ email: riccardopiola@live.it - name: Nick Gregory email: ng-cla@openenterprise.co.uk +- name: Michael Ellis + email: Michael94Ellis@gmail.com + diff --git a/pkg/idp/oauth/user.go b/pkg/idp/oauth/user.go index 44e1632..7a7dd97 100644 --- a/pkg/idp/oauth/user.go +++ b/pkg/idp/oauth/user.go @@ -28,6 +28,10 @@ import ( "strings" ) +type discordMember struct { + Roles []string `json:"roles"` +} + type userData struct { Groups []string `json:"groups,omitempty"` } @@ -425,25 +429,68 @@ func (b *IdentityProvider) fetchDiscordGuilds(authToken string) (*userData, erro continue } + b.logger.Debug( + "Checking Guild Permissions", + zap.String("guildName", guild["name"].(string)), + ) + // Check if the user has special permissions if _, exists := guild["permissions"]; exists { - perm, err := strconv.Atoi(guild["permissions"].(string)) + // Parses to int64 for 32-bit system support + perm, err := strconv.ParseInt(guild["permissions"].(string), 10, 64) if err != nil { - continue - } - if (perm & 0x08) == 0x08 { // Check for admin privileges + b.logger.Debug( + "Error converting Guild permissions to integer", + zap.Any("error", err), + ) + } else if (perm & 0x08) == 0x08 { // Check for admin privileges data.Groups = append(data.Groups, fmt.Sprintf("discord.com/%s/admins", guildID)) - } + } } data.Groups = append(data.Groups, fmt.Sprintf("discord.com/%s/members", guildID)) - } + // Fetch roles information for the guild + if b.ScopeExists("guilds.members.read") { + reqURL = fmt.Sprintf("https://discord.com/api/v10/users/@me/guilds/%s/member", guildID) + req, err = http.NewRequest("GET", reqURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/json") + req.Header.Add("Authorization", "Bearer " + authToken) - b.logger.Debug( - "Parsed additional user data", - zap.String("url", reqURL), - zap.Any("data", data), - ) + resp, err = cli.Do(req) + if err != nil { + return nil, err + } + + respBody, err = ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, err + } + + var memberData discordMember + if err := json.Unmarshal(respBody, &memberData); err != nil { + b.logger.Debug( + "Guild Roles request failed", + zap.Any("response", respBody), + zap.Any("error", err), + ) + return nil, err + } + + for _, roleID := range memberData.Roles { + data.Groups = append(data.Groups, fmt.Sprintf("discord.com/%s/role/%s", guildID, roleID)) + } + } + + b.logger.Debug( + "Parsed additional discord user data", + zap.String("url", reqURL), + zap.Any("data", data), + ) + } return data, nil }