Skip to content

Commit

Permalink
x-pack/filebeat/input/entityanalytics/provider/okta: Rate limiting fi…
Browse files Browse the repository at this point in the history
…xes (#41583)

- Separate rate limits by endpoint.
- Stop requests until reset when `x-rate-limit-remaining: 0`.
  • Loading branch information
chrisberkhout authored Dec 5, 2024
1 parent c51c8ce commit 4e19d09
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 112 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.next.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ https://github.com/elastic/beats/compare/v8.8.1\...main[Check the HEAD diff]
- Fix missing key in streaming input logging. {pull}41600[41600]
- Improve S3 object size metric calculation to support situations where Content-Length is not available. {pull}41755[41755]
- Fix handling of http_endpoint request exceeding memory limits. {issue}41764[41764] {pull}41765[41765]
- Rate limiting fixes in the Okta provider of the Entity Analytics input. {issue}40106[40106] {pull}41583[41583]

*Heartbeat*

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,9 @@ import (
"io"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"

"golang.org/x/time/rate"

"github.com/elastic/elastic-agent-libs/logp"
)

Expand Down Expand Up @@ -195,16 +191,23 @@ func (o Response) String() string {
// https://${yourOktaDomain}/reports/rate-limit.
//
// See https://developer.okta.com/docs/reference/api/users/#list-users for details.
func GetUserDetails(ctx context.Context, cli *http.Client, host, key, user string, query url.Values, omit Response, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]User, http.Header, error) {
const endpoint = "/api/v1/users"
func GetUserDetails(ctx context.Context, cli *http.Client, host, key, user string, query url.Values, omit Response, lim RateLimiter, window time.Duration, log *logp.Logger) ([]User, http.Header, error) {
var endpoint, path string
if user == "" {
endpoint = "/api/v1/users"
path = endpoint
} else {
endpoint = "/api/v1/users/{user}"
path = strings.Replace(endpoint, "{user}", user, 1)
}

u := &url.URL{
Scheme: "https",
Host: host,
Path: path.Join(endpoint, user),
Path: path,
RawQuery: query.Encode(),
}
return getDetails[User](ctx, cli, u, key, user == "", omit, lim, window, log)
return getDetails[User](ctx, cli, u, endpoint, key, user == "", omit, lim, window, log)
}

// GetUserFactors returns Okta group roles using the groups API endpoint. host is the
Expand All @@ -213,19 +216,20 @@ func GetUserDetails(ctx context.Context, cli *http.Client, host, key, user strin
// See GetUserDetails for details of the query and rate limit parameters.
//
// See https://developer.okta.com/docs/api/openapi/okta-management/management/tag/UserFactor/#tag/UserFactor/operation/listFactors.
func GetUserFactors(ctx context.Context, cli *http.Client, host, key, user string, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]Factor, http.Header, error) {
const endpoint = "/api/v1/users"

func GetUserFactors(ctx context.Context, cli *http.Client, host, key, user string, lim RateLimiter, window time.Duration, log *logp.Logger) ([]Factor, http.Header, error) {
if user == "" {
return nil, nil, errors.New("no user specified")
}

const endpoint = "/api/v1/users/{user}/factors"
path := strings.Replace(endpoint, "{user}", user, 1)

u := &url.URL{
Scheme: "https",
Host: host,
Path: path.Join(endpoint, user, "factors"),
Path: path,
}
return getDetails[Factor](ctx, cli, u, key, true, OmitNone, lim, window, log)
return getDetails[Factor](ctx, cli, u, endpoint, key, true, OmitNone, lim, window, log)
}

// GetUserRoles returns Okta group roles using the groups API endpoint. host is the
Expand All @@ -234,19 +238,20 @@ func GetUserFactors(ctx context.Context, cli *http.Client, host, key, user strin
// See GetUserDetails for details of the query and rate limit parameters.
//
// See https://developer.okta.com/docs/api/openapi/okta-management/management/tag/RoleAssignmentBGroup/#tag/RoleAssignmentBGroup/operation/listGroupAssignedRoles.
func GetUserRoles(ctx context.Context, cli *http.Client, host, key, user string, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]Role, http.Header, error) {
const endpoint = "/api/v1/users"

func GetUserRoles(ctx context.Context, cli *http.Client, host, key, user string, lim RateLimiter, window time.Duration, log *logp.Logger) ([]Role, http.Header, error) {
if user == "" {
return nil, nil, errors.New("no user specified")
}

const endpoint = "/api/v1/users/{user}/roles"
path := strings.Replace(endpoint, "{user}", user, 1)

u := &url.URL{
Scheme: "https",
Host: host,
Path: path.Join(endpoint, user, "roles"),
Path: path,
}
return getDetails[Role](ctx, cli, u, key, true, OmitNone, lim, window, log)
return getDetails[Role](ctx, cli, u, endpoint, key, true, OmitNone, lim, window, log)
}

// GetUserGroupDetails returns Okta group details using the users API endpoint. host is the
Expand All @@ -255,19 +260,20 @@ func GetUserRoles(ctx context.Context, cli *http.Client, host, key, user string,
// See GetUserDetails for details of the query and rate limit parameters.
//
// See https://developer.okta.com/docs/reference/api/users/#request-parameters-8 (no anchor exists on the page for this endpoint) for details.
func GetUserGroupDetails(ctx context.Context, cli *http.Client, host, key, user string, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]Group, http.Header, error) {
const endpoint = "/api/v1/users"

func GetUserGroupDetails(ctx context.Context, cli *http.Client, host, key, user string, lim RateLimiter, window time.Duration, log *logp.Logger) ([]Group, http.Header, error) {
if user == "" {
return nil, nil, errors.New("no user specified")
}

const endpoint = "/api/v1/users/{user}/groups"
path := strings.Replace(endpoint, "{user}", user, 1)

u := &url.URL{
Scheme: "https",
Host: host,
Path: path.Join(endpoint, user, "groups"),
Path: path,
}
return getDetails[Group](ctx, cli, u, key, true, OmitNone, lim, window, log)
return getDetails[Group](ctx, cli, u, endpoint, key, true, OmitNone, lim, window, log)
}

// GetGroupRoles returns Okta group roles using the groups API endpoint. host is the
Expand All @@ -276,19 +282,20 @@ func GetUserGroupDetails(ctx context.Context, cli *http.Client, host, key, user
// See GetUserDetails for details of the query and rate limit parameters.
//
// See https://developer.okta.com/docs/api/openapi/okta-management/management/tag/RoleAssignmentBGroup/#tag/RoleAssignmentBGroup/operation/listGroupAssignedRoles.
func GetGroupRoles(ctx context.Context, cli *http.Client, host, key, group string, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]Role, http.Header, error) {
const endpoint = "/api/v1/groups"

func GetGroupRoles(ctx context.Context, cli *http.Client, host, key, group string, lim RateLimiter, window time.Duration, log *logp.Logger) ([]Role, http.Header, error) {
if group == "" {
return nil, nil, errors.New("no group specified")
}

const endpoint = "/api/v1/groups/{group}/rules"
path := strings.Replace(endpoint, "{group}", group, 1)

u := &url.URL{
Scheme: "https",
Host: host,
Path: path.Join(endpoint, group, "roles"),
Path: path,
}
return getDetails[Role](ctx, cli, u, key, true, OmitNone, lim, window, log)
return getDetails[Role](ctx, cli, u, endpoint, key, true, OmitNone, lim, window, log)
}

// GetDeviceDetails returns Okta device details using the list devices API endpoint. host is the
Expand All @@ -298,16 +305,24 @@ func GetGroupRoles(ctx context.Context, cli *http.Client, host, key, group strin
// See GetUserDetails for details of the query and rate limit parameters.
//
// See https://developer.okta.com/docs/api/openapi/okta-management/management/tag/Device/#tag/Device/operation/listDevices for details.
func GetDeviceDetails(ctx context.Context, cli *http.Client, host, key, device string, query url.Values, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]Device, http.Header, error) {
const endpoint = "/api/v1/devices"
func GetDeviceDetails(ctx context.Context, cli *http.Client, host, key, device string, query url.Values, lim RateLimiter, window time.Duration, log *logp.Logger) ([]Device, http.Header, error) {
var endpoint string
var path string
if device == "" {
endpoint = "/api/v1/devices"
path = endpoint
} else {
endpoint = "/api/v1/devices/{device}"
path = strings.Replace(endpoint, "{device}", device, 1)
}

u := &url.URL{
Scheme: "https",
Host: host,
Path: path.Join(endpoint, device),
Path: path,
RawQuery: query.Encode(),
}
return getDetails[Device](ctx, cli, u, key, device == "", OmitNone, lim, window, log)
return getDetails[Device](ctx, cli, u, endpoint, key, device == "", OmitNone, lim, window, log)
}

// GetDeviceUsers returns Okta user details for users associated with the provided device identifier
Expand All @@ -317,21 +332,22 @@ func GetDeviceDetails(ctx context.Context, cli *http.Client, host, key, device s
// See GetUserDetails for details of the query and rate limit parameters.
//
// See https://developer.okta.com/docs/api/openapi/okta-management/management/tag/Device/#tag/Device/operation/listDeviceUsers for details.
func GetDeviceUsers(ctx context.Context, cli *http.Client, host, key, device string, query url.Values, omit Response, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]User, http.Header, error) {
func GetDeviceUsers(ctx context.Context, cli *http.Client, host, key, device string, query url.Values, omit Response, lim RateLimiter, window time.Duration, log *logp.Logger) ([]User, http.Header, error) {
if device == "" {
// No user associated with a null device. Not an error.
return nil, nil, nil
}

const endpoint = "/api/v1/devices"
const endpoint = "/api/v1/devices/{device}/users"
path := strings.Replace(endpoint, "{device}", device, 1)

u := &url.URL{
Scheme: "https",
Host: host,
Path: path.Join(endpoint, device, "users"),
Path: path,
RawQuery: query.Encode(),
}
du, h, err := getDetails[devUser](ctx, cli, u, key, true, omit, lim, window, log)
du, h, err := getDetails[devUser](ctx, cli, u, endpoint, key, true, omit, lim, window, log)
if err != nil {
return nil, h, err
}
Expand All @@ -356,7 +372,7 @@ type devUser struct {
// for the specific user are returned, otherwise a list of all users is returned.
//
// See GetUserDetails for details of the query and rate limit parameters.
func getDetails[E entity](ctx context.Context, cli *http.Client, u *url.URL, key string, all bool, omit Response, lim *rate.Limiter, window time.Duration, log *logp.Logger) ([]E, http.Header, error) {
func getDetails[E entity](ctx context.Context, cli *http.Client, u *url.URL, endpoint string, key string, all bool, omit Response, lim RateLimiter, window time.Duration, log *logp.Logger) ([]E, http.Header, error) {
url := u.String()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
Expand All @@ -370,8 +386,7 @@ func getDetails[E entity](ctx context.Context, cli *http.Client, u *url.URL, key
req.Header.Set("Content-Type", contentType)
req.Header.Set("Authorization", fmt.Sprintf("SSWS %s", key))

log.Debugw("rate limit", "limit", lim.Limit(), "burst", lim.Burst(), "url", url)
err = lim.Wait(ctx)
err = lim.Wait(ctx, endpoint, u, log)
if err != nil {
return nil, nil, err
}
Expand All @@ -380,7 +395,7 @@ func getDetails[E entity](ctx context.Context, cli *http.Client, u *url.URL, key
return nil, nil, err
}
defer resp.Body.Close()
err = oktaRateLimit(resp.Header, window, lim, log)
err = lim.Update(endpoint, resp.Header, window, log)
if err != nil {
io.Copy(io.Discard, resp.Body)
return nil, nil, err
Expand Down Expand Up @@ -443,59 +458,6 @@ func (e *Error) Error() string {
return fmt.Sprintf("%s: %s", summary, strings.Join(causes, ","))
}

// oktaRateLimit implements the Okta rate limit policy translation.
//
// See https://developer.okta.com/docs/reference/rl-best-practices/ for details.
func oktaRateLimit(h http.Header, window time.Duration, limiter *rate.Limiter, log *logp.Logger) error {
limit := h.Get("X-Rate-Limit-Limit")
remaining := h.Get("X-Rate-Limit-Remaining")
reset := h.Get("X-Rate-Limit-Reset")
log.Debugw("rate limit header", "X-Rate-Limit-Limit", limit, "X-Rate-Limit-Remaining", remaining, "X-Rate-Limit-Reset", reset)
if limit == "" || remaining == "" || reset == "" {
return nil
}

lim, err := strconv.ParseFloat(limit, 64)
if err != nil {
return err
}
rem, err := strconv.ParseFloat(remaining, 64)
if err != nil {
return err
}
rst, err := strconv.ParseInt(reset, 10, 64)
if err != nil {
return err
}
resetTime := time.Unix(rst, 0)
per := time.Until(resetTime).Seconds()

// Be conservative here; the docs don't exactly specify burst rates.
// Make sure we can make at least one new request, even if we fail
// to get a non-zero rate.Limit. We could set to zero for the case
// that limit=rate.Inf, but that detail is not important.
burst := 1

rateLimit := rate.Limit(rem / per)

// Process reset if we need to wait until reset to avoid a request against a zero quota.
if rateLimit <= 0 {
waitUntil := resetTime.UTC()
// next gives us a sane next window estimate, but the
// estimate will be overwritten when we make the next
// permissible API request.
next := rate.Limit(lim / window.Seconds())
limiter.SetLimitAt(waitUntil, next)
limiter.SetBurstAt(waitUntil, burst)
log.Debugw("rate limit adjust", "reset_time", waitUntil, "next_rate", next, "next_burst", burst)
return nil
}
limiter.SetLimit(rateLimit)
limiter.SetBurst(burst)
log.Debugw("rate limit adjust", "set_rate", rateLimit, "set_burst", burst)
return nil
}

// Next returns the next URL query for a pagination sequence. If no further
// page is available, Next returns io.EOF.
func Next(h http.Header) (query url.Values, err error) {
Expand Down
Loading

0 comments on commit 4e19d09

Please sign in to comment.