From b812fcd5076c54607c0dc16d540bd60889736bb1 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Mon, 11 Nov 2024 11:16:23 -0800 Subject: [PATCH] Move oauth2 things of the command package --- command/get.go | 5 +- command/login.go | 7 +- command/roles.go | 99 ++++++++++++++++++++++- command/{saml_test.go => roles_test.go} | 6 +- command/saml.go | 102 ------------------------ {command => oauth2}/html.go | 2 +- {command => oauth2}/html_test.go | 2 +- {command => oauth2}/oauth2.go | 87 ++++++-------------- {command => oauth2}/oauth2_test.go | 2 +- oauth2/websso.go | 65 +++++++++++++++ 10 files changed, 200 insertions(+), 177 deletions(-) rename command/{saml_test.go => roles_test.go} (82%) delete mode 100644 command/saml.go rename {command => oauth2}/html.go (99%) rename {command => oauth2}/html_test.go (99%) rename {command => oauth2}/oauth2.go (68%) rename {command => oauth2}/oauth2_test.go (99%) create mode 100644 oauth2/websso.go diff --git a/command/get.go b/command/get.go index abbbe7e4..318e7259 100644 --- a/command/get.go +++ b/command/get.go @@ -11,6 +11,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sts" + "github.com/riotgames/key-conjurer/oauth2" "github.com/spf13/cobra" "github.com/spf13/pflag" ) @@ -185,12 +186,12 @@ func (g GetCommand) Execute(ctx context.Context) error { } func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account) (*CloudCredentials, error) { - samlResponse, assertionStr, err := DiscoverConfigAndExchangeTokenForAssertion(ctx, g.Config.Tokens, g.OIDCDomain, g.ClientID, account.ID) + samlResponse, assertionStr, err := oauth2.DiscoverConfigAndExchangeTokenForAssertion(ctx, g.Config.Tokens.AccessToken, g.Config.Tokens.IDToken, g.OIDCDomain, g.ClientID, account.ID) if err != nil { return nil, err } - pair, ok := FindRoleInSAML(g.RoleName, samlResponse) + pair, ok := findRoleInSAML(g.RoleName, samlResponse) if !ok { return nil, UnknownRoleError(g.RoleName, g.Args[0]) } diff --git a/command/login.go b/command/login.go index 82bfb74f..27d8162a 100644 --- a/command/login.go +++ b/command/login.go @@ -10,6 +10,7 @@ import ( "log/slog" "github.com/pkg/browser" + "github.com/riotgames/key-conjurer/oauth2" "github.com/spf13/cobra" "github.com/spf13/pflag" ) @@ -69,7 +70,7 @@ type LoginCommand struct { } func (c LoginCommand) Execute(ctx context.Context) error { - oauthCfg, err := DiscoverOAuth2Config(ctx, c.OIDCDomain, c.ClientID) + oauthCfg, err := oauth2.DiscoverConfig(ctx, c.OIDCDomain, c.ClientID) if err != nil { return err } @@ -86,7 +87,7 @@ func (c LoginCommand) Execute(ctx context.Context) error { } oauthCfg.RedirectURL = fmt.Sprintf("http://%s", net.JoinHostPort("localhost", port)) - handler := RedirectionFlowHandler{ + handler := oauth2.RedirectionFlowHandler{ Config: oauthCfg, OnDisplayURL: openBrowserToURL, } @@ -99,7 +100,7 @@ func (c LoginCommand) Execute(ctx context.Context) error { } } - accessToken, err := handler.HandlePendingSession(ctx, sock, GeneratePkceChallenge(), GenerateState()) + accessToken, err := handler.HandlePendingSession(ctx, sock, oauth2.GeneratePkceChallenge(), oauth2.GenerateState()) if err != nil { return err } diff --git a/command/roles.go b/command/roles.go index 07e0145b..0f3501bc 100644 --- a/command/roles.go +++ b/command/roles.go @@ -1,9 +1,18 @@ package command import ( + "strings" + + "github.com/RobotsAndPencils/go-saml" + "github.com/riotgames/key-conjurer/oauth2" "github.com/spf13/cobra" ) +const ( + awsFlag = 0 + tencentFlag = 1 +) + var rolesCmd = cobra.Command{ Use: "roles ", Short: "Returns the roles that you have access to in the given account.", @@ -23,15 +32,101 @@ var rolesCmd = cobra.Command{ applicationID = account.ID } - samlResponse, _, err := DiscoverConfigAndExchangeTokenForAssertion(cmd.Context(), config.Tokens, oidcDomain, clientID, applicationID) + samlResponse, _, err := oauth2.DiscoverConfigAndExchangeTokenForAssertion(cmd.Context(), config.Tokens.AccessToken, config.Tokens.IDToken, oidcDomain, clientID, applicationID) if err != nil { return err } - for _, name := range ListSAMLRoles(samlResponse) { + for _, name := range listRoles(samlResponse) { cmd.Println(name) } return nil }, } + +type roleProviderPair struct { + RoleARN string + ProviderARN string +} + +func getARN(value string) roleProviderPair { + var p roleProviderPair + roles := strings.Split(value, ",") + if len(roles) >= 2 { + if strings.Contains(roles[0], "saml-provider/") { + p.ProviderARN = roles[0] + p.RoleARN = roles[1] + } else { + p.ProviderARN = roles[1] + p.RoleARN = roles[0] + } + } + return p +} + +func findRoleInSAML(roleName string, response *saml.Response) (roleProviderPair, bool) { + if response == nil { + return roleProviderPair{}, false + } + + roleURL := "https://aws.amazon.com/SAML/Attributes/Role" + roleSubstr := "role/" + attrs := response.GetAttributeValues(roleURL) + if len(attrs) == 0 { + attrs = response.GetAttributeValues("https://cloud.tencent.com/SAML/Attributes/Role") + roleSubstr = "roleName/" + } + + if len(attrs) == 0 { + // The SAML assertoin contains no known roles for AWS or Tencent. + return roleProviderPair{}, false + } + + var pairs []roleProviderPair + for _, v := range response.GetAttributeValues(roleURL) { + pairs = append(pairs, getARN(v)) + } + + if len(pairs) == 0 { + return roleProviderPair{}, false + } + + var pair roleProviderPair + for _, p := range pairs { + idx := strings.Index(p.RoleARN, roleSubstr) + parts := strings.Split(p.RoleARN[idx:], "/") + if strings.EqualFold(parts[1], roleName) { + pair = p + } + } + + if pair.RoleARN == "" { + return roleProviderPair{}, false + } + + return pair, true +} + +func listRoles(response *saml.Response) []string { + if response == nil { + return nil + } + + roleURL := "https://aws.amazon.com/SAML/Attributes/Role" + roleSubstr := "role/" + if response.GetAttribute(roleURL) == "" { + roleURL = "https://cloud.tencent.com/SAML/Attributes/Role" + roleSubstr = "roleName/" + } + + var names []string + for _, v := range response.GetAttributeValues(roleURL) { + p := getARN(v) + idx := strings.Index(p.RoleARN, roleSubstr) + parts := strings.Split(p.RoleARN[idx:], "/") + names = append(names, parts[1]) + } + + return names +} diff --git a/command/saml_test.go b/command/roles_test.go similarity index 82% rename from command/saml_test.go rename to command/roles_test.go index 60151499..ddc7f920 100644 --- a/command/saml_test.go +++ b/command/roles_test.go @@ -7,15 +7,15 @@ import ( "github.com/stretchr/testify/require" ) -func TestAwsFindRoleDoesntBreakIfYouHaveMultipleRoles(t *testing.T) { +func Test_findRoleInSAML_DoesntBreakIfYouHaveMultipleRoles(t *testing.T) { var resp saml.Response resp.AddAttribute("https://aws.amazon.com/SAML/Attributes/Role", "arn:cloud:iam::1234:saml-provider/Okta,arn:cloud:iam::1234:role/Admin") resp.AddAttribute("https://aws.amazon.com/SAML/Attributes/Role", "arn:cloud:iam::1234:saml-provider/Okta,arn:cloud:iam::1234:role/Power") - pair, err := FindRoleInSAML("Power", &resp) + pair, err := findRoleInSAML("Power", &resp) require.True(t, err) require.Equal(t, "arn:cloud:iam::1234:saml-provider/Okta", pair.ProviderARN) require.Equal(t, "arn:cloud:iam::1234:role/Power", pair.RoleARN) - pair, err = FindRoleInSAML("Admin", &resp) + pair, err = findRoleInSAML("Admin", &resp) require.True(t, err) require.Equal(t, "arn:cloud:iam::1234:saml-provider/Okta", pair.ProviderARN) require.Equal(t, "arn:cloud:iam::1234:role/Admin", pair.RoleARN) diff --git a/command/saml.go b/command/saml.go deleted file mode 100644 index bca1d1ca..00000000 --- a/command/saml.go +++ /dev/null @@ -1,102 +0,0 @@ -package command - -import ( - "strings" - - "github.com/RobotsAndPencils/go-saml" -) - -type RoleProviderPair struct { - RoleARN string - ProviderARN string -} - -const ( - awsFlag = 0 - tencentFlag = 1 -) - -func ListSAMLRoles(response *saml.Response) []string { - if response == nil { - return nil - } - - roleURL := "https://aws.amazon.com/SAML/Attributes/Role" - roleSubstr := "role/" - if response.GetAttribute(roleURL) == "" { - roleURL = "https://cloud.tencent.com/SAML/Attributes/Role" - roleSubstr = "roleName/" - } - - var names []string - for _, v := range response.GetAttributeValues(roleURL) { - p := getARN(v) - idx := strings.Index(p.RoleARN, roleSubstr) - parts := strings.Split(p.RoleARN[idx:], "/") - names = append(names, parts[1]) - } - - return names -} - -func FindRoleInSAML(roleName string, response *saml.Response) (RoleProviderPair, bool) { - if response == nil { - return RoleProviderPair{}, false - } - - roleURL := "https://aws.amazon.com/SAML/Attributes/Role" - roleSubstr := "role/" - attrs := response.GetAttributeValues(roleURL) - if len(attrs) == 0 { - attrs = response.GetAttributeValues("https://cloud.tencent.com/SAML/Attributes/Role") - roleSubstr = "roleName/" - } - - if len(attrs) == 0 { - // The SAML assertoin contains no known roles for AWS or Tencent. - return RoleProviderPair{}, false - } - - var pairs []RoleProviderPair - for _, v := range response.GetAttributeValues(roleURL) { - pairs = append(pairs, getARN(v)) - } - - if len(pairs) == 0 { - return RoleProviderPair{}, false - } - - var pair RoleProviderPair - for _, p := range pairs { - idx := strings.Index(p.RoleARN, roleSubstr) - parts := strings.Split(p.RoleARN[idx:], "/") - if strings.EqualFold(parts[1], roleName) { - pair = p - } - } - - if pair.RoleARN == "" { - return RoleProviderPair{}, false - } - - return pair, true -} - -func getARN(value string) RoleProviderPair { - p := RoleProviderPair{} - roles := strings.Split(value, ",") - if len(roles) >= 2 { - if strings.Contains(roles[0], "saml-provider/") { - p.ProviderARN = roles[0] - p.RoleARN = roles[1] - } else { - p.ProviderARN = roles[1] - p.RoleARN = roles[0] - } - } - return p -} - -func ParseBase64EncodedSAMLResponse(xml string) (*saml.Response, error) { - return saml.ParseEncodedResponse(xml) -} diff --git a/command/html.go b/oauth2/html.go similarity index 99% rename from command/html.go rename to oauth2/html.go index 3380484e..d01681e6 100644 --- a/command/html.go +++ b/oauth2/html.go @@ -1,4 +1,4 @@ -package command +package oauth2 import ( "errors" diff --git a/command/html_test.go b/oauth2/html_test.go similarity index 99% rename from command/html_test.go rename to oauth2/html_test.go index b1bfb12c..391fd65c 100644 --- a/command/html_test.go +++ b/oauth2/html_test.go @@ -1,4 +1,4 @@ -package command +package oauth2 import ( "strings" diff --git a/command/oauth2.go b/oauth2/oauth2.go similarity index 68% rename from command/oauth2.go rename to oauth2/oauth2.go index 4b3c9a30..6f232275 100644 --- a/command/oauth2.go +++ b/oauth2/oauth2.go @@ -1,4 +1,4 @@ -package command +package oauth2 import ( "context" @@ -9,13 +9,11 @@ import ( "fmt" "net" "net/http" - "net/url" "strings" "sync" "github.com/RobotsAndPencils/go-saml" "github.com/coreos/go-oidc" - "golang.org/x/net/html" "golang.org/x/oauth2" ) @@ -25,7 +23,7 @@ var ErrNoSAMLAssertion = errors.New("no saml assertion") // 43 is a magic number - It generates states that are not too short or long for Okta's validation. const stateBufSize = 43 -func DiscoverOAuth2Config(ctx context.Context, domain, clientID string) (*oauth2.Config, error) { +func DiscoverConfig(ctx context.Context, domain, clientID string) (*oauth2.Config, error) { provider, err := oidc.NewProvider(ctx, domain) if err != nil { return nil, fmt.Errorf("couldn't discover OIDC configuration for %s: %w", domain, err) @@ -145,7 +143,7 @@ type RedirectionFlowHandler struct { func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, listener net.Listener, challenge PkceChallenge, state string) (*oauth2.Token, error) { if r.OnDisplayURL == nil { - r.OnDisplayURL = printURLToConsole + panic("OnDisplayURL must be set") } url := r.Config.AuthCodeURL(state, @@ -175,74 +173,39 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, listen } } -func ExchangeAccessTokenForWebSSOToken(ctx context.Context, oauthCfg *oauth2.Config, token *TokenSet, applicationID string) (*oauth2.Token, error) { - return oauthCfg.Exchange(ctx, "", - oauth2.SetAuthURLParam("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange"), - oauth2.SetAuthURLParam("actor_token", token.AccessToken), - oauth2.SetAuthURLParam("actor_token_type", "urn:ietf:params:oauth:token-type:access_token"), - oauth2.SetAuthURLParam("subject_token", token.IDToken), - oauth2.SetAuthURLParam("subject_token_type", "urn:ietf:params:oauth:token-type:id_token"), - // https://www.linkedin.com/pulse/oktas-aws-cli-app-mysterious-case-powerful-okta-apis-chaim-sanders/ - oauth2.SetAuthURLParam("requested_token_type", "urn:okta:oauth:token-type:web_sso_token"), - oauth2.SetAuthURLParam("audience", fmt.Sprintf("urn:okta:apps:%s", applicationID)), - ) -} - -// TODO: This is actually an Okta-specific API -func ExchangeWebSSOTokenForSAMLAssertion(ctx context.Context, issuer string, token *oauth2.Token) ([]byte, error) { - data := url.Values{"token": {token.AccessToken}} - uri := fmt.Sprintf("%s/login/token/sso?%s", issuer, data.Encode()) - req, _ := http.NewRequestWithContext(ctx, "GET", uri, nil) - req.Header.Add("Accept", "text/html") - - client := http.DefaultClient - if val, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { - client = val - } - - resp, err := client.Do(req) +func DiscoverConfigAndExchangeTokenForAssertion(ctx context.Context, accessToken, idToken, oidcDomain, clientID, applicationID string) (*saml.Response, string, error) { + oauthCfg, err := DiscoverConfig(ctx, oidcDomain, clientID) if err != nil { - return nil, err - } - - if resp.StatusCode == http.StatusInternalServerError { - return nil, errors.New("internal okta error occurred") - } - - doc, _ := html.Parse(resp.Body) - form, ok := FindFirstForm(doc) - if !ok { - return nil, ErrNoSAMLAssertion - } - - saml, ok := form.Inputs["SAMLResponse"] - if !ok { - return nil, ErrNoSAMLAssertion + return nil, "", Error{Message: "could not discover oauth2 config", InnerError: err} } - return []byte(saml), nil -} - -func DiscoverConfigAndExchangeTokenForAssertion(ctx context.Context, toks *TokenSet, oidcDomain, clientID, applicationID string) (*saml.Response, string, error) { - oauthCfg, err := DiscoverOAuth2Config(ctx, oidcDomain, clientID) - if err != nil { - return nil, "", OktaError{Message: "could not discover oauth2 config", InnerError: err} - } - - tok, err := ExchangeAccessTokenForWebSSOToken(ctx, oauthCfg, toks, applicationID) + tok, err := exchangeAccessTokenForWebSSOToken(ctx, oauthCfg, accessToken, idToken, applicationID) if err != nil { - return nil, "", OktaError{Message: "error exchanging token", InnerError: err} + return nil, "", Error{Message: "error exchanging token", InnerError: err} } - assertionBytes, err := ExchangeWebSSOTokenForSAMLAssertion(ctx, oidcDomain, tok) + assertionBytes, err := exchangeWebSSOTokenForSAMLAssertion(ctx, oidcDomain, tok) if err != nil { - return nil, "", OktaError{Message: "failed to fetch SAML assertion", InnerError: err} + return nil, "", Error{Message: "failed to fetch SAML assertion", InnerError: err} } - response, err := ParseBase64EncodedSAMLResponse(string(assertionBytes)) + response, err := saml.ParseEncodedResponse(string(assertionBytes)) if err != nil { - return nil, "", OktaError{Message: "failed to parse SAML response", InnerError: err} + return nil, "", Error{Message: "failed to parse SAML response", InnerError: err} } return response, string(assertionBytes), nil } + +type Error struct { + InnerError error + Message string +} + +func (o Error) Unwrap() error { + return o.InnerError +} + +func (o Error) Error() string { + return o.Message +} diff --git a/command/oauth2_test.go b/oauth2/oauth2_test.go similarity index 99% rename from command/oauth2_test.go rename to oauth2/oauth2_test.go index 02fefc89..f5d2ef2d 100644 --- a/command/oauth2_test.go +++ b/oauth2/oauth2_test.go @@ -1,4 +1,4 @@ -package command +package oauth2 import ( "net/http" diff --git a/oauth2/websso.go b/oauth2/websso.go new file mode 100644 index 00000000..379a0c10 --- /dev/null +++ b/oauth2/websso.go @@ -0,0 +1,65 @@ +package oauth2 + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + + "golang.org/x/net/html" + "golang.org/x/oauth2" +) + +// exchangeAccessTokenForWebSSOToken exchanges an OAuth2 token for an Okta Web SSO token. +// +// An Okta Web SSO token is a non-standard authorization token for Okta's Web SSO endpoint. +func exchangeAccessTokenForWebSSOToken(ctx context.Context, oauthCfg *oauth2.Config, accessToken string, idToken string, applicationID string) (*oauth2.Token, error) { + return oauthCfg.Exchange(ctx, "", + oauth2.SetAuthURLParam("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange"), + oauth2.SetAuthURLParam("actor_token", accessToken), + oauth2.SetAuthURLParam("actor_token_type", "urn:ietf:params:oauth:token-type:access_token"), + oauth2.SetAuthURLParam("subject_token", idToken), + oauth2.SetAuthURLParam("subject_token_type", "urn:ietf:params:oauth:token-type:id_token"), + // https://www.linkedin.com/pulse/oktas-aws-cli-app-mysterious-case-powerful-okta-apis-chaim-sanders/ + oauth2.SetAuthURLParam("requested_token_type", "urn:okta:oauth:token-type:web_sso_token"), + oauth2.SetAuthURLParam("audience", fmt.Sprintf("urn:okta:apps:%s", applicationID)), + ) +} + +// exchangeWebSSOTokenForSAMLAssertion is an Okta-specific API which exchanges an Okta Web SSO token, which is obtained by exchanging an OAuth2 token using the RFC8693 Token Exchange Flow, for a SAML assertion. +// +// It is not standards compliant, but is used by Okta in their own okta-aws-cli. +func exchangeWebSSOTokenForSAMLAssertion(ctx context.Context, issuer string, token *oauth2.Token) ([]byte, error) { + data := url.Values{"token": {token.AccessToken}} + uri := fmt.Sprintf("%s/login/token/sso?%s", issuer, data.Encode()) + req, _ := http.NewRequestWithContext(ctx, "GET", uri, nil) + req.Header.Add("Accept", "text/html") + + client := http.DefaultClient + if val, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { + client = val + } + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode == http.StatusInternalServerError { + return nil, errors.New("internal okta error occurred") + } + + doc, _ := html.Parse(resp.Body) + form, ok := FindFirstForm(doc) + if !ok { + return nil, ErrNoSAMLAssertion + } + + saml, ok := form.Inputs["SAMLResponse"] + if !ok { + return nil, ErrNoSAMLAssertion + } + + return []byte(saml), nil +}