diff --git a/changelog/@unreleased/pr-341.v2.yml b/changelog/@unreleased/pr-341.v2.yml new file mode 100644 index 00000000..9303234e --- /dev/null +++ b/changelog/@unreleased/pr-341.v2.yml @@ -0,0 +1,5 @@ +type: improvement +improvement: + description: Introduce new URIPool and URISelector interfaces. + links: + - https://github.com/palantir/conjure-go-runtime/pull/341 diff --git a/conjure-go-client/httpclient/client.go b/conjure-go-client/httpclient/client.go index 136b9752..399a65dc 100644 --- a/conjure-go-client/httpclient/client.go +++ b/conjure-go-client/httpclient/client.go @@ -55,7 +55,8 @@ type clientImpl struct { errorDecoderMiddleware Middleware recoveryMiddleware Middleware - uriScorer internal.RefreshableURIScoringMiddleware + uriPool internal.URIPool + uriSelector internal.URISelector maxAttempts refreshable.IntPtr // 0 means no limit. If nil, uses 2*len(uris). backoffOptions refreshingclient.RefreshableRetryParams bufferPool bytesbuffers.Pool @@ -82,12 +83,10 @@ func (c *clientImpl) Delete(ctx context.Context, params ...RequestParam) (*http. 
} func (c *clientImpl) Do(ctx context.Context, params ...RequestParam) (*http.Response, error) { - uris := c.uriScorer.CurrentURIScoringMiddleware().GetURIsInOrderOfIncreasingScore() - if len(uris) == 0 { + attempts := 2 * c.uriPool.NumURIs() + if attempts == 0 { return nil, werror.ErrorWithContextParams(ctx, "no base URIs are configured") } - - attempts := 2 * len(uris) if c.maxAttempts != nil { if confMaxAttempts := c.maxAttempts.CurrentIntPtr(); confMaxAttempts != nil { attempts = *confMaxAttempts @@ -96,17 +95,18 @@ func (c *clientImpl) Do(ctx context.Context, params ...RequestParam) (*http.Resp var err error var resp *http.Response - - retrier := internal.NewRequestRetrier(uris, c.backoffOptions.CurrentRetryParams().Start(ctx), attempts) + retrier := internal.NewRequestRetrier(c.backoffOptions.CurrentRetryParams().Start(ctx), attempts) for { - uri, isRelocated := retrier.GetNextURI(resp, err) - if uri == "" { + shouldRetry, retryURL := retrier.Next(resp, err) + if !shouldRetry { break } + resp, err = c.doOnce(ctx, retryURL, params...) if err != nil { svc1log.FromContext(ctx).Debug("Retrying request", svc1log.Stacktrace(err)) + continue } - resp, err = c.doOnce(ctx, uri, isRelocated, params...) + break } if err != nil { return nil, err @@ -116,11 +116,9 @@ func (c *clientImpl) Do(ctx context.Context, params ...RequestParam) (*http.Resp func (c *clientImpl) doOnce( ctx context.Context, - baseURI string, - useBaseURIOnly bool, + retryURL *url.URL, params ...RequestParam, ) (*http.Response, error) { - // 1. 
create the request b := &requestBuilder{ headers: make(http.Header), @@ -136,9 +134,6 @@ func (c *clientImpl) doOnce( return nil, err } } - if useBaseURIOnly { - b.path = "" - } for _, c := range b.configureCtx { ctx = c(ctx) @@ -147,12 +142,22 @@ func (c *clientImpl) doOnce( if b.method == "" { return nil, werror.ErrorWithContextParams(ctx, "httpclient: use WithRequestMethod() to specify HTTP method") } - reqURI := joinURIAndPath(baseURI, b.path) - req, err := http.NewRequest(b.method, reqURI, nil) + var uri string + if retryURL == nil { + var err error + uri, err = c.uriSelector.Select(c.uriPool.URIs(), b.headers) + if err != nil { + return nil, werror.WrapWithContextParams(ctx, err, "failed to select uri") + } + uri = joinURIAndPath(uri, b.path) + } else { + uri = retryURL.String() + } + req, err := http.NewRequestWithContext(ctx, b.method, uri, nil) if err != nil { return nil, werror.WrapWithContextParams(ctx, err, "failed to build new HTTP request") } - req = req.WithContext(ctx) + req.Header = b.headers if q := b.query.Encode(); q != "" { req.URL.RawQuery = q @@ -164,7 +169,6 @@ func (c *clientImpl) doOnce( transport := clientCopy.Transport // start with the client's transport configured with default middleware // must precede the error decoders to read the status code of the raw response. 
- transport = wrapTransport(transport, c.uriScorer.CurrentURIScoringMiddleware()) // request decoder must precede the client decoder // must precede the body middleware to read the response body transport = wrapTransport(transport, b.errorDecoderMiddleware, c.errorDecoderMiddleware) diff --git a/conjure-go-client/httpclient/client_builder.go b/conjure-go-client/httpclient/client_builder.go index 3779aabf..d986d45e 100644 --- a/conjure-go-client/httpclient/client_builder.go +++ b/conjure-go-client/httpclient/client_builder.go @@ -49,8 +49,8 @@ const ( type clientBuilder struct { HTTP *httpClientBuilder - URIs refreshable.StringSlice - URIScorerBuilder func([]string) internal.URIScoringMiddleware + URIs refreshable.StringSlice + URISelector internal.URISelector ErrorDecoder ErrorDecoder @@ -153,15 +153,16 @@ func newClient(ctx context.Context, b *clientBuilder, params ...ClientParam) (Cl if !b.HTTP.DisableRecovery { recovery = recoveryMiddleware{} } - uriScorer := internal.NewRefreshableURIScoringMiddleware(b.URIs, func(uris []string) internal.URIScoringMiddleware { - if b.URIScorerBuilder == nil { - return internal.NewRandomURIScoringMiddleware(uris, func() int64 { return time.Now().UnixNano() }) - } - return b.URIScorerBuilder(uris) - }) + uriPool := internal.NewStatefulURIPool(b.URIs) + if b.URISelector == nil { + b.URISelector = internal.NewRoundRobinURISelector(func() int64 { return time.Now().UnixNano() }) + } + // append uriSelector and uriPool middlewares + middleware = append(middleware, uriPool, b.URISelector) return &clientImpl{ client: httpClient, - uriScorer: uriScorer, + uriPool: uriPool, + uriSelector: b.URISelector, maxAttempts: b.MaxAttempts, backoffOptions: b.RetryParams, middlewares: middleware, diff --git a/conjure-go-client/httpclient/client_params.go b/conjure-go-client/httpclient/client_params.go index e0cef56f..8184d9e4 100644 --- a/conjure-go-client/httpclient/client_params.go +++ b/conjure-go-client/httpclient/client_params.go @@ -537,11 
+537,7 @@ func WithBasicAuth(username, password string) ClientParam { // and least recent errors. func WithBalancedURIScoring() ClientParam { return clientParamFunc(func(b *clientBuilder) error { - b.URIScorerBuilder = func(uris []string) internal.URIScoringMiddleware { - return internal.NewBalancedURIScoringMiddleware(uris, func() int64 { - return time.Now().UnixNano() - }) - } + b.URISelector = internal.NewBalancedURISelector(func() int64 { return time.Now().UnixNano() }) return nil }) } diff --git a/conjure-go-client/httpclient/internal/balanced_scorer.go b/conjure-go-client/httpclient/internal/balanced_selector.go similarity index 53% rename from conjure-go-client/httpclient/internal/balanced_scorer.go rename to conjure-go-client/httpclient/internal/balanced_selector.go index cff197eb..e25e4c2c 100644 --- a/conjure-go-client/httpclient/internal/balanced_scorer.go +++ b/conjure-go-client/httpclient/internal/balanced_selector.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Palantir Technologies. All rights reserved. +// Copyright (c) 2022 Palantir Technologies. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -20,8 +20,11 @@ import ( "net/http" "net/url" "sort" + "sync" "sync/atomic" "time" + + werror "github.com/palantir/witchcraft-go-error" ) const ( @@ -29,41 +32,58 @@ const ( failureMemory = 30 * time.Second ) -type URIScoringMiddleware interface { - GetURIsInOrderOfIncreasingScore() []string - RoundTrip(req *http.Request, next http.RoundTripper) (*http.Response, error) -} - -type balancedScorer struct { - uriInfos map[string]uriInfo -} - -type uriInfo struct { - inflight int32 - recentFailures CourseExponentialDecayReservoir -} - -// NewBalancedURIScoringMiddleware returns URI scoring middleware that tracks in-flight requests and recent failures +// NewBalancedURISelector returns URI scoring middleware that tracks in-flight requests and recent failures // for each URI configured on an HTTP client. URIs are scored based on fewest in-flight requests and recent errors, // where client errors are weighted the same as 1/10 of an in-flight request, server errors are weighted as 10 // in-flight requests, and errors are decayed using exponential decay with a half-life of 30 seconds. 
// // This implementation is based on Dialogue's BalancedScoreTracker: // https://github.com/palantir/dialogue/blob/develop/dialogue-core/src/main/java/com/palantir/dialogue/core/BalancedScoreTracker.java -func NewBalancedURIScoringMiddleware(uris []string, nanoClock func() int64) URIScoringMiddleware { +func NewBalancedURISelector(nanoClock func() int64) URISelector { + return &balancedSelector{ + nanoClock: nanoClock, + } +} + +type balancedSelector struct { + sync.Mutex + + nanoClock func() int64 + uriInfos map[string]uriInfo +} + +// Select implements Selector interface +func (s *balancedSelector) Select(uris []string, _ http.Header) (string, error) { + s.Lock() + defer s.Unlock() + + s.updateURIs(uris) + return s.next() +} + +func (s *balancedSelector) updateURIs(uris []string) { uriInfos := make(map[string]uriInfo, len(uris)) for _, uri := range uris { + if exisiting, ok := s.uriInfos[uri]; ok { + uriInfos[uri] = exisiting + continue + } uriInfos[uri] = uriInfo{ - recentFailures: NewCourseExponentialDecayReservoir(nanoClock, failureMemory), + recentFailures: NewCourseExponentialDecayReservoir(s.nanoClock, failureMemory), } } - return &balancedScorer{uriInfos} + + s.uriInfos = uriInfos + return } -func (u *balancedScorer) GetURIsInOrderOfIncreasingScore() []string { - uris := make([]string, 0, len(u.uriInfos)) - scores := make(map[string]int32, len(u.uriInfos)) - for uri, info := range u.uriInfos { +func (s *balancedSelector) next() (string, error) { + if len(s.uriInfos) == 0 { + return "", werror.Error("no valid connections available") + } + uris := make([]string, 0, len(s.uriInfos)) + scores := make(map[string]int32, len(s.uriInfos)) + for uri, info := range s.uriInfos { uris = append(uris, uri) scores[uri] = info.computeScore() } @@ -74,32 +94,52 @@ func (u *balancedScorer) GetURIsInOrderOfIncreasingScore() []string { sort.Slice(uris, func(i, j int) bool { return scores[uris[i]] < scores[uris[j]] }) - return uris + return uris[0], nil } -func (u 
*balancedScorer) RoundTrip(req *http.Request, next http.RoundTripper) (*http.Response, error) { +func (s *balancedSelector) RoundTrip(req *http.Request, next http.RoundTripper) (*http.Response, error) { baseURI := getBaseURI(req.URL) - info, foundInfo := u.uriInfos[baseURI] - if foundInfo { - atomic.AddInt32(&info.inflight, 1) - defer atomic.AddInt32(&info.inflight, -1) - } + s.updateInflight(baseURI, 1) + defer s.updateInflight(baseURI, -1) + resp, err := next.RoundTrip(req) - if resp == nil || err != nil { - if foundInfo { - info.recentFailures.Update(failureWeight) - } - return nil, err + errCode, ok := StatusCodeFromError(err) + // fall back to the status code from the response + if !ok && resp != nil { + errCode = resp.StatusCode } - if foundInfo { - statusCode := resp.StatusCode - if isGlobalQosStatus(statusCode) || isServerErrorRange(statusCode) { - info.recentFailures.Update(failureWeight) - } else if isClientError(statusCode) { - info.recentFailures.Update(failureWeight / 100) - } + + if isGlobalQosStatus(errCode) || isServerErrorRange(errCode) { + s.updateRecentFailures(baseURI, failureWeight) + } else if isClientError(errCode) { + s.updateRecentFailures(baseURI, failureWeight/100) + } + return resp, err +} + +func (s *balancedSelector) updateInflight(uri string, score int32) { + s.Lock() + defer s.Unlock() + + info, ok := s.uriInfos[uri] + if ok { + atomic.AddInt32(&info.inflight, score) + } +} + +func (s *balancedSelector) updateRecentFailures(uri string, weight float64) { + s.Lock() + defer s.Unlock() + + info, ok := s.uriInfos[uri] + if ok { + info.recentFailures.Update(weight) } - return resp, nil +} + +type uriInfo struct { + inflight int32 + recentFailures CourseExponentialDecayReservoir } func (i *uriInfo) computeScore() int32 { diff --git a/conjure-go-client/httpclient/internal/balanced_scorer_test.go b/conjure-go-client/httpclient/internal/balanced_selector_test.go similarity index 73% rename from 
conjure-go-client/httpclient/internal/balanced_scorer_test.go rename to conjure-go-client/httpclient/internal/balanced_selector_test.go index 665ece5a..18de4ada 100644 --- a/conjure-go-client/httpclient/internal/balanced_scorer_test.go +++ b/conjure-go-client/httpclient/internal/balanced_selector_test.go @@ -22,15 +22,15 @@ import ( "github.com/stretchr/testify/assert" ) -func TestBalancedScorerRandomizesWithNoneInflight(t *testing.T) { +func TestBalancedSelectorRandomizesWithNoneInflight(t *testing.T) { uris := []string{"uri1", "uri2", "uri3", "uri4", "uri5"} - scorer := NewBalancedURIScoringMiddleware(uris, func() int64 { return 0 }) - scoredUris := scorer.GetURIsInOrderOfIncreasingScore() - assert.ElementsMatch(t, scoredUris, uris) - assert.NotEqual(t, scoredUris, uris) + scorer := NewBalancedURISelector(func() int64 { return 0 }) + scoredURI, err := scorer.Select(uris, nil) + assert.NoError(t, err) + assert.Contains(t, uris, scoredURI) } -func TestBalancedScoring(t *testing.T) { +func TestBalancedSelect(t *testing.T) { server200 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusOK) })) @@ -44,15 +44,20 @@ func TestBalancedScoring(t *testing.T) { })) defer server503.Close() uris := []string{server503.URL, server429.URL, server200.URL} - scorer := NewBalancedURIScoringMiddleware(uris, func() int64 { return 0 }) + scorer := NewBalancedURISelector(func() int64 { return 0 }) for _, server := range []*httptest.Server{server200, server429, server503} { for i := 0; i < 10; i++ { - req, err := http.NewRequest("GET", server.URL, nil) + uri, err := scorer.Select(uris, nil) assert.NoError(t, err) + req, err := http.NewRequest("GET", uri, nil) + assert.NoError(t, err) + _, err = scorer.RoundTrip(req, server.Client().Transport) assert.NoError(t, err) } } - scoredUris := scorer.GetURIsInOrderOfIncreasingScore() - assert.Equal(t, []string{server200.URL, server429.URL, server503.URL}, scoredUris) + + uri, err := 
scorer.Select(uris, nil) + assert.NoError(t, err) + assert.Equal(t, server200.URL, uri) } diff --git a/conjure-go-client/httpclient/internal/interface.go b/conjure-go-client/httpclient/internal/interface.go new file mode 100644 index 00000000..73d9eaa1 --- /dev/null +++ b/conjure-go-client/httpclient/internal/interface.go @@ -0,0 +1,38 @@ +// Copyright (c) 2022 Palantir Technologies. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "net/http" +) + +// URISelector is used in combination with a URIPool to get the +// preferred next URL for a given request. +type URISelector interface { + Select([]string, http.Header) (string, error) + RoundTrip(*http.Request, http.RoundTripper) (*http.Response, error) +} + +// URIPool stores all possible URIs for a given connection. It can be used +// as http middleware in order to maintain state between requests. +type URIPool interface { + // NumURIs returns the overall count of URIs available in the + // connection pool. 
+ NumURIs() int + // URIs returns the set of URIs that should be considered by a + // URISelector + URIs() []string + RoundTrip(*http.Request, http.RoundTripper) (*http.Response, error) +} diff --git a/conjure-go-client/httpclient/internal/random_scorer.go b/conjure-go-client/httpclient/internal/random_selector.go similarity index 52% rename from conjure-go-client/httpclient/internal/random_scorer.go rename to conjure-go-client/httpclient/internal/random_selector.go index 63ebc1d6..24708fa1 100644 --- a/conjure-go-client/httpclient/internal/random_scorer.go +++ b/conjure-go-client/httpclient/internal/random_selector.go @@ -17,31 +17,43 @@ package internal import ( "math/rand" "net/http" + "sync" + + werror "github.com/palantir/witchcraft-go-error" ) -type randomScorer struct { - uris []string +type randomSelector struct { + sync.Mutex nanoClock func() int64 } -func (n *randomScorer) GetURIsInOrderOfIncreasingScore() []string { - uris := make([]string, len(n.uris)) - copy(uris, n.uris) - rand.New(rand.NewSource(n.nanoClock())).Shuffle(len(uris), func(i, j int) { - uris[i], uris[j] = uris[j], uris[i] - }) - return uris +// NewRandomURISelector returns a URISelector that shuffles the candidate URIs using a rand.Rand seeded by the +// nanoClock function and selects the first one. The middleware no-ops on each request. +func NewRandomURISelector(nanoClock func() int64) URISelector { + return &randomSelector{ + nanoClock: nanoClock, + } } -func (n *randomScorer) RoundTrip(req *http.Request, next http.RoundTripper) (*http.Response, error) { - return next.RoundTrip(req) +// Select implements the URISelector interface +func (s *randomSelector) Select(uris []string, _ http.Header) (string, error) { + s.Lock() + defer s.Unlock() + + return s.next(uris) } -// NewRandomURIScoringMiddleware returns a URI scorer that randomizes the order of URIs when scoring using a rand.Rand -// seeded by the nanoClock function. The middleware no-ops on each request. 
-func NewRandomURIScoringMiddleware(uris []string, nanoClock func() int64) URIScoringMiddleware { - return &randomScorer{ - uris: uris, - nanoClock: nanoClock, +func (s *randomSelector) next(uris []string) (string, error) { + if len(uris) == 0 { + return "", werror.Error("no valid connections available") } + rand.New(rand.NewSource(s.nanoClock())).Shuffle(len(uris), func(i, j int) { + uris[i], uris[j] = uris[j], uris[i] + }) + + return uris[0], nil +} + +func (s *randomSelector) RoundTrip(req *http.Request, next http.RoundTripper) (*http.Response, error) { + return next.RoundTrip(req) } diff --git a/conjure-go-client/httpclient/internal/random_scorer_test.go b/conjure-go-client/httpclient/internal/random_selector_test.go similarity index 68% rename from conjure-go-client/httpclient/internal/random_scorer_test.go rename to conjure-go-client/httpclient/internal/random_selector_test.go index 43c332e4..063baf5f 100644 --- a/conjure-go-client/httpclient/internal/random_scorer_test.go +++ b/conjure-go-client/httpclient/internal/random_selector_test.go @@ -21,11 +21,14 @@ import ( "github.com/stretchr/testify/assert" ) -func TestRandomScorerGetURIsRandomizes(t *testing.T) { +func TestRandomSelector_Select(t *testing.T) { uris := []string{"uri1", "uri2", "uri3", "uri4", "uri5"} - scorer := NewRandomURIScoringMiddleware(uris, func() int64 { return time.Now().UnixNano() }) - scoredUris1 := scorer.GetURIsInOrderOfIncreasingScore() - scoredUris2 := scorer.GetURIsInOrderOfIncreasingScore() - assert.ElementsMatch(t, scoredUris1, scoredUris2) - assert.NotEqual(t, scoredUris1, scoredUris2) + scorer := NewRandomURISelector(func() int64 { return time.Now().UnixNano() }) + uri, err := scorer.Select(uris, nil) + assert.NoError(t, err) + assert.Contains(t, uris, uri) + + uri2, err := scorer.Select(uris, nil) + assert.NoError(t, err) + assert.Contains(t, uris, uri2) } diff --git a/conjure-go-client/httpclient/internal/refreshable_scorer.go 
b/conjure-go-client/httpclient/internal/refreshable_scorer.go deleted file mode 100644 index 4ab94363..00000000 --- a/conjure-go-client/httpclient/internal/refreshable_scorer.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2021 Palantir Technologies. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "github.com/palantir/pkg/refreshable" -) - -type RefreshableURIScoringMiddleware interface { - CurrentURIScoringMiddleware() URIScoringMiddleware -} - -func NewRefreshableURIScoringMiddleware(uris refreshable.StringSlice, constructor func([]string) URIScoringMiddleware) RefreshableURIScoringMiddleware { - return refreshableURIScoringMiddleware{uris.MapStringSlice(func(uris []string) interface{} { - return constructor(uris) - })} -} - -type refreshableURIScoringMiddleware struct{ refreshable.Refreshable } - -func (r refreshableURIScoringMiddleware) CurrentURIScoringMiddleware() URIScoringMiddleware { - return r.Current().(URIScoringMiddleware) -} diff --git a/conjure-go-client/httpclient/internal/request_retrier.go b/conjure-go-client/httpclient/internal/request_retrier.go index e7c738f7..e9ac970e 100644 --- a/conjure-go-client/httpclient/internal/request_retrier.go +++ b/conjure-go-client/httpclient/internal/request_retrier.go @@ -22,38 +22,28 @@ import ( "github.com/palantir/pkg/retry" ) -const ( - meshSchemePrefix = "mesh-" -) - -// RequestRetrier manages URIs for an HTTP client, providing an API which 
determines whether requests should be retries -// and supplying the correct URL for the client to retry. -// In the case of servers in a service-mesh, requests will never be retried and the mesh URI will only be returned on the -// first call to GetNextURI +// RequestRetrier manages the lifecycle of a single request. It tracks the +// backoff timing between subsequent requests. The retrier should only suggest +// a retry if the previous request returned a redirect or is a mesh URI. In the +// case of a mesh URI being detected, the request retrier will only attempt the +// request once. type RequestRetrier struct { - currentURI string - retrier retry.Retrier - uris []string - offset int - relocatedURIs map[string]struct{} - failedURIs map[string]struct{} - maxAttempts int - attemptCount int + retrier retry.Retrier + + maxAttempts int + attemptCount int } // NewRequestRetrier creates a new request retrier. // Regardless of maxAttempts, mesh URIs will never be retried. -func NewRequestRetrier(uris []string, retrier retry.Retrier, maxAttempts int) *RequestRetrier { - offset := 0 +func NewRequestRetrier( + retrier retry.Retrier, + maxAttempts int, +) *RequestRetrier { return &RequestRetrier{ - currentURI: uris[offset], - retrier: retrier, - uris: uris, - offset: offset, - relocatedURIs: map[string]struct{}{}, - failedURIs: map[string]struct{}{}, - maxAttempts: maxAttempts, - attemptCount: 0, + retrier: retrier, + maxAttempts: maxAttempts, + attemptCount: 0, } } @@ -65,119 +55,73 @@ func (r *RequestRetrier) attemptsRemaining() bool { return r.attemptCount < r.maxAttempts } -// GetNextURI returns the next URI a client should use, or empty string if no suitable URI remaining to retry. -// isRelocated is true when the URI comes from a redirect's Location header. In this case, it already includes the request path. 
-func (r *RequestRetrier) GetNextURI(resp *http.Response, respErr error) (uri string, isRelocated bool) { - defer func() { - r.attemptCount++ - }() +// Next returns true if a subsequent request attempt should be attempted. It uses the previous response/resp err (if +// provided) to determine if the request should be attempted. If the returned value is true, the retrier will have +// waited the desired backoff interval before returning when applicable. If the previous response was a redirect, the +// retrier will also return the URL that should be used for the next request. +func (r *RequestRetrier) Next(resp *http.Response, err error) (bool, *url.URL) { + defer func() { r.attemptCount++ }() + // should always try first request if r.attemptCount == 0 { - // First attempt is always successful. Trigger the first retry so later calls have backoff - // but ignore the returned value to ensure that the client can instrument the request even - // if the context is done. + // Trigger the first retry so later calls have backoff but ignore the returned value to ensure that the + // client can instrument the request even if the context is done. r.retrier.Next() - return r.removeMeshSchemeIfPresent(r.currentURI), false + return true, nil } + if !r.attemptsRemaining() { // Retries exhausted - return "", false - } - if r.isMeshURI(r.currentURI) { - // Mesh uris don't get retried - return "", false + return false, nil } - retryFn := r.getRetryFn(resp, respErr) - if retryFn == nil { - // The previous response was not retryable - return "", false - } - // Updates currentURI - if !retryFn() { - return "", false - } - return r.currentURI, r.isRelocatedURI(r.currentURI) -} -func (r *RequestRetrier) getRetryFn(resp *http.Response, respErr error) func() bool { - errCode, _ := StatusCodeFromError(respErr) - if retryOther, _ := isThrottleResponse(resp, errCode); retryOther { - // 429: throttle - // Immediately backoff and select the next URI. 
- // TODO(whickman): use the retry-after header once #81 is resolved - return r.nextURIAndBackoff - } else if isUnavailableResponse(resp, errCode) { - // 503: go to next node - return r.nextURIOrBackoff - } else if shouldTryOther, otherURI := isRetryOtherResponse(resp, respErr, errCode); shouldTryOther { - // 307 or 308: go to next node, or particular node if provided. - if otherURI != nil { - return func() bool { - r.setURIAndResetBackoff(otherURI) - return true - } - } - return r.nextURIOrBackoff - } else if errCode >= http.StatusBadRequest && errCode < http.StatusInternalServerError { - return nil - } else if resp == nil { - // if we get a nil response, we can assume there is a problem with host and can move on to the next. - return r.nextURIOrBackoff + if r.isSuccess(resp) { + return false, nil } - return nil -} -func (r *RequestRetrier) setURIAndResetBackoff(otherURI *url.URL) { - // If the URI returned by relocation header is a relative path - // We will resolve it with the current URI - if !otherURI.IsAbs() { - if currentURI := parseLocationURL(r.currentURI); currentURI != nil { - otherURI = currentURI.ResolveReference(otherURI) - } + if r.isNonRetryableClientError(resp, err) { + return false, nil } - nextURI := otherURI.String() - r.relocatedURIs[otherURI.String()] = struct{}{} - r.retrier.Reset() - r.currentURI = nextURI -} -// If lastURI was already marked failed, we perform a backoff as determined by the retrier before returning the next URI and its offset. -// Otherwise, we add lastURI to failedURIs and return the next URI and its offset immediately. 
-func (r *RequestRetrier) nextURIOrBackoff() bool { - _, performBackoff := r.failedURIs[r.currentURI] - r.markFailedAndMoveToNextURI() - // If the URI has failed before, perform a backoff - if performBackoff || len(r.uris) == 1 { - return r.retrier.Next() + // handle redirects + if tryOther, otherURI := isRetryOtherResponse(resp, err); tryOther { + return true, otherURI } - return true -} -// Marks the current URI as failed, gets the next URI, and performs a backoff as determined by the retrier. -func (r *RequestRetrier) nextURIAndBackoff() bool { - r.markFailedAndMoveToNextURI() - return r.retrier.Next() -} + // don't retry mesh uris + if r.isMeshURI(resp) { + return false, nil + } -func (r *RequestRetrier) markFailedAndMoveToNextURI() { - r.failedURIs[r.currentURI] = struct{}{} - nextURIOffset := (r.offset + 1) % len(r.uris) - nextURI := r.uris[nextURIOffset] - r.currentURI = nextURI - r.offset = nextURIOffset + // retry with backoff + return r.retrier.Next(), nil } -func (r *RequestRetrier) removeMeshSchemeIfPresent(uri string) string { - if r.isMeshURI(uri) { - return strings.Replace(uri, meshSchemePrefix, "", 1) +func (*RequestRetrier) isSuccess(resp *http.Response) bool { + if resp == nil { + return false } - return uri + // Check for a 2XX status + return resp.StatusCode >= 200 && resp.StatusCode < 300 } -func (r *RequestRetrier) isMeshURI(uri string) bool { - return strings.HasPrefix(uri, meshSchemePrefix) +func (*RequestRetrier) isNonRetryableClientError(resp *http.Response, err error) bool { + errCode, ok := StatusCodeFromError(err) + // Check for a 4XX status parsed from the error or in the response + if ok && isClientError(errCode) && errCode != StatusCodeThrottle { + return true + } + if resp != nil && isClientError(resp.StatusCode) { + // 429 is retryable + if isThrottle, _ := isThrottleResponse(resp, errCode); !isThrottle { + return true + } + } + return false } -func (r *RequestRetrier) isRelocatedURI(uri string) bool { - _, relocatedURI := 
r.relocatedURIs[uri] - return relocatedURI +func (*RequestRetrier) isMeshURI(resp *http.Response) bool { + if resp == nil || resp.Request == nil { + return false + } + return strings.HasPrefix(getBaseURI(resp.Request.URL), meshSchemePrefix) } diff --git a/conjure-go-client/httpclient/internal/request_retrier_test.go b/conjure-go-client/httpclient/internal/request_retrier_test.go index ae830514..ceffad8a 100644 --- a/conjure-go-client/httpclient/internal/request_retrier_test.go +++ b/conjure-go-client/httpclient/internal/request_retrier_test.go @@ -16,7 +16,9 @@ package internal import ( "context" + "errors" "net/http" + "net/url" "testing" "time" @@ -26,75 +28,73 @@ import ( "github.com/stretchr/testify/require" ) -var _ retry.Retrier = &mockRetrier{} - func TestRequestRetrier_HandleMeshURI(t *testing.T) { - r := NewRequestRetrier([]string{"mesh-http://example.com"}, retry.Start(context.Background()), 1) - uri, _ := r.GetNextURI(nil, nil) - require.Equal(t, uri, "http://example.com") + r := NewRequestRetrier(retry.Start(context.Background()), 1) + shouldRetry, _ := r.Next(nil, nil) + require.True(t, shouldRetry) respErr := werror.ErrorWithContextParams(context.Background(), "error", werror.SafeParam("statusCode", 429)) - uri, _ = r.GetNextURI(nil, respErr) - require.Empty(t, uri) + shouldRetry, _ = r.Next(nil, respErr) + require.False(t, shouldRetry) } func TestRequestRetrier_AttemptCount(t *testing.T) { maxAttempts := 3 - r := NewRequestRetrier([]string{"https://example.com"}, retry.Start(context.Background()), maxAttempts) + r := NewRequestRetrier(retry.Start(context.Background()), maxAttempts) // first request is not a retry - uri, _ := r.GetNextURI(nil, nil) - require.Equal(t, uri, "https://example.com") + shouldRetry, _ := r.Next(nil, nil) + require.True(t, shouldRetry) for i := 0; i < maxAttempts-1; i++ { - uri, _ = r.GetNextURI(nil, nil) - require.Equal(t, uri, "https://example.com") + shouldRetry, _ = r.Next(nil, errors.New("error")) + require.True(t, 
shouldRetry) } - uri, _ = r.GetNextURI(nil, nil) - require.Empty(t, uri) + shouldRetry, _ = r.Next(nil, nil) + require.False(t, shouldRetry) } func TestRequestRetrier_UnlimitedAttempts(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - r := NewRequestRetrier([]string{"https://example.com"}, retry.Start(ctx, retry.WithInitialBackoff(50*time.Millisecond), retry.WithRandomizationFactor(0)), 0) + r := NewRequestRetrier(retry.Start(ctx, retry.WithInitialBackoff(50*time.Millisecond), retry.WithRandomizationFactor(0)), 0) startTime := time.Now() - uri, _ := r.GetNextURI(nil, nil) - require.Equal(t, uri, "https://example.com") + shouldRetry, _ := r.Next(nil, nil) + require.True(t, shouldRetry) require.Lessf(t, time.Since(startTime), 49*time.Millisecond, "first GetNextURI should not have any delay") startTime = time.Now() - uri, _ = r.GetNextURI(nil, nil) - require.Equal(t, uri, "https://example.com") + shouldRetry, _ = r.Next(nil, errors.New("error")) + require.True(t, shouldRetry) assert.Greater(t, time.Since(startTime), 50*time.Millisecond, "delay should be at least 1 backoff") assert.Less(t, time.Since(startTime), 100*time.Millisecond, "delay should be less than 2 backoffs") startTime = time.Now() - uri, _ = r.GetNextURI(nil, nil) - require.Equal(t, uri, "https://example.com") + shouldRetry, _ = r.Next(nil, errors.New("error")) + require.True(t, shouldRetry) assert.Greater(t, time.Since(startTime), 100*time.Millisecond, "delay should be at least 2 backoffs") assert.Less(t, time.Since(startTime), 200*time.Millisecond, "delay should be less than 3 backoffs") // Success should stop retries - uri, _ = r.GetNextURI(&http.Response{StatusCode: 200}, nil) - require.Empty(t, uri) + shouldRetry, _ = r.Next(&http.Response{StatusCode: http.StatusOK}, nil) + require.False(t, shouldRetry) } func TestRequestRetrier_ContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - r := 
NewRequestRetrier([]string{"https://example.com"}, retry.Start(ctx), 0) + r := NewRequestRetrier(retry.Start(ctx), 0) // First attempt should return a URI to ensure that the client can instrument the request even // if the context is done - uri, _ := r.GetNextURI(nil, nil) - require.Equal(t, uri, "https://example.com") + shouldRetry, _ := r.Next(nil, nil) + require.True(t, shouldRetry) // Subsequent attempt should stop retries - uri, _ = r.GetNextURI(nil, nil) - require.Empty(t, uri) + shouldRetry, _ = r.Next(nil, nil) + require.False(t, shouldRetry) } func TestRequestRetrier_UsesLocationHeader(t *testing.T) { @@ -103,177 +103,105 @@ func TestRequestRetrier_UsesLocationHeader(t *testing.T) { Header: http.Header{"Location": []string{"http://example.com"}}, } - r := NewRequestRetrier([]string{"a"}, retry.Start(context.Background()), 2) - uri, isRelocated := r.GetNextURI(nil, nil) - require.Equal(t, uri, "a") - require.False(t, isRelocated) + r := NewRequestRetrier(retry.Start(context.Background()), 2) + shouldRetry, uri := r.Next(nil, nil) + require.True(t, shouldRetry) + require.Nil(t, uri) - uri, isRelocated = r.GetNextURI(respWithLocationHeader, nil) - require.Equal(t, uri, "http://example.com") - require.True(t, isRelocated) + shouldRetry, uri = r.Next(respWithLocationHeader, nil) + require.Equal(t, uri.String(), "http://example.com") + require.True(t, shouldRetry) } func TestRequestRetrier_UsesLocationFromErr(t *testing.T) { - r := NewRequestRetrier([]string{"http://example-1.com"}, retry.Start(context.Background()), 2) + r := NewRequestRetrier(retry.Start(context.Background()), 2) + respErr := werror.ErrorWithContextParams(context.Background(), "307", werror.SafeParam("statusCode", 307), werror.SafeParam("location", "http://example-2.com")) - uri, isRelocated := r.GetNextURI(nil, nil) - require.Equal(t, uri, "http://example-1.com") - require.False(t, isRelocated) + // first request is not a retry + shouldRetry, uri := r.Next(nil, nil) + require.True(t, 
shouldRetry) + require.Nil(t, uri) - uri, isRelocated = r.GetNextURI(nil, respErr) - require.Equal(t, uri, "http://example-2.com") - require.True(t, isRelocated) + shouldRetry, uri = r.Next(nil, respErr) + require.NotNil(t, uri) + require.Equal(t, uri.String(), "http://example-2.com") + require.True(t, shouldRetry) } -func TestRequestRetrier_GetNextURI(t *testing.T) { +func TestRequestRetrier_Next(t *testing.T) { for _, tc := range []struct { name string resp *http.Response respErr error - uris []string + retryURI *url.URL shouldRetry bool - shouldRetrySameURI bool shouldRetryBackoff bool shouldRetryReset bool }{ { - name: "returns error if response exists and doesn't appear retryable", - resp: &http.Response{}, - respErr: nil, - uris: []string{"a", "b"}, - shouldRetry: false, - shouldRetrySameURI: false, - shouldRetryBackoff: false, - shouldRetryReset: false, - }, - { - name: "returns error if error code not retryable", - resp: &http.Response{}, - respErr: nil, - uris: []string{"a", "b"}, - shouldRetry: false, - shouldRetrySameURI: false, - shouldRetryBackoff: false, - shouldRetryReset: false, - }, - { - name: "returns a URI if response and error are nil", - resp: nil, - respErr: nil, - uris: []string{"a", "b"}, - shouldRetry: true, - shouldRetrySameURI: false, - shouldRetryBackoff: false, - shouldRetryReset: false, - }, - { - name: "returns a URI if response and error are nil", - resp: nil, - respErr: nil, - uris: []string{"a", "b"}, - shouldRetry: true, - shouldRetrySameURI: false, - shouldRetryBackoff: false, - shouldRetryReset: false, - }, - { - name: "retries and backs off the single URI if response and error are nil", - resp: nil, - respErr: nil, - uris: []string{"a"}, - shouldRetry: true, - shouldRetrySameURI: true, - shouldRetryBackoff: true, - shouldRetryReset: false, - }, - { - name: "returns a new URI if unavailable", + name: "retries if unavailable", resp: nil, respErr: werror.ErrorWithContextParams(context.Background(), "503", 
werror.SafeParam("statusCode", 503)), - uris: []string{"a", "b"}, shouldRetry: true, - shouldRetrySameURI: false, shouldRetryBackoff: false, shouldRetryReset: false, }, { - name: "retries and backs off the single URI if unavailable", - resp: nil, - respErr: werror.ErrorWithContextParams(context.Background(), "503", werror.SafeParam("statusCode", 503)), - uris: []string{"a"}, - shouldRetry: true, - shouldRetrySameURI: true, - shouldRetryBackoff: true, - shouldRetryReset: false, - }, - { - name: "returns a new URI and backs off if throttled", + name: "retries and backs off if throttled", resp: nil, respErr: werror.ErrorWithContextParams(context.Background(), "429", werror.SafeParam("statusCode", 429)), - uris: []string{"a", "b"}, shouldRetry: true, - shouldRetrySameURI: false, shouldRetryBackoff: true, shouldRetryReset: false, }, { - name: "retries single URI and backs off if throttled", - resp: nil, - respErr: werror.ErrorWithContextParams(context.Background(), "429", werror.SafeParam("statusCode", 429)), - uris: []string{"a"}, - shouldRetry: true, - shouldRetrySameURI: true, - shouldRetryBackoff: true, - shouldRetryReset: false, - }, - { - name: "retries another URI if gets retry other response without location", + name: "retries with no backoff if gets retry other response without location", resp: &http.Response{ StatusCode: StatusCodeRetryOther, }, respErr: nil, - uris: []string{"a", "b"}, shouldRetry: true, - shouldRetrySameURI: false, shouldRetryBackoff: false, shouldRetryReset: false, }, { - name: "retries single URI and backs off if gets retry other response without location", + name: "retries with no backoff if gets retry temporary redirect response without location", resp: &http.Response{ - StatusCode: StatusCodeRetryOther, + StatusCode: StatusCodeRetryTemporaryRedirect, }, respErr: nil, - uris: []string{"a"}, shouldRetry: true, - shouldRetrySameURI: true, - shouldRetryBackoff: true, + shouldRetryBackoff: false, shouldRetryReset: false, }, { - name: 
"retries another URI if gets retry temporary redirect response without location", + name: "retries with no backoff if gets retry temporary redirect response with a location", resp: &http.Response{ StatusCode: StatusCodeRetryTemporaryRedirect, }, - respErr: nil, - uris: []string{"a", "b"}, + respErr: werror.ErrorWithContextParams(context.Background(), + "307", + werror.SafeParam("statusCode", 307), + werror.SafeParam("location", "http://example-2.com")), + retryURI: mustNewURL("http://example-2.com"), shouldRetry: true, - shouldRetrySameURI: false, - shouldRetryBackoff: false, + shouldRetryBackoff: true, shouldRetryReset: false, }, { - name: "retries single URI and backs off if gets retry temporary redirect response without location", + name: "retries with no backoff if gets retry other redirect response with a location", resp: &http.Response{ - StatusCode: StatusCodeRetryTemporaryRedirect, + StatusCode: StatusCodeRetryOther, }, - respErr: nil, - uris: []string{"a"}, + respErr: werror.ErrorWithContextParams(context.Background(), + "308", + werror.SafeParam("statusCode", 308), + werror.SafeParam("location", "http://example-2.com")), + retryURI: mustNewURL("http://example-2.com"), shouldRetry: true, - shouldRetrySameURI: true, shouldRetryBackoff: true, shouldRetryReset: false, }, @@ -282,9 +210,7 @@ func TestRequestRetrier_GetNextURI(t *testing.T) { resp: &http.Response{ StatusCode: 400, }, - uris: []string{"a", "b"}, shouldRetry: false, - shouldRetrySameURI: false, shouldRetryBackoff: false, shouldRetryReset: false, }, @@ -293,45 +219,37 @@ func TestRequestRetrier_GetNextURI(t *testing.T) { resp: &http.Response{ StatusCode: 404, }, - uris: []string{"a", "b"}, shouldRetry: false, - shouldRetrySameURI: false, shouldRetryBackoff: false, shouldRetryReset: false, }, { name: "does not retry 400 errors", respErr: werror.ErrorWithContextParams(context.Background(), "400", werror.SafeParam("statusCode", 400)), - uris: []string{"a", "b"}, shouldRetry: false, - 
shouldRetrySameURI: false, shouldRetryBackoff: false, shouldRetryReset: false, }, { name: "does not retry 404s", respErr: werror.ErrorWithContextParams(context.Background(), "404", werror.SafeParam("statusCode", 404)), - uris: []string{"a", "b"}, shouldRetry: false, - shouldRetrySameURI: false, shouldRetryBackoff: false, shouldRetryReset: false, }, } { t.Run(tc.name, func(t *testing.T) { retrier := newMockRetrier() - r := NewRequestRetrier(tc.uris, retrier, 2) + r := NewRequestRetrier(retrier, 2) // first URI isn't a retry - firstURI, _ := r.GetNextURI(nil, nil) - require.NotEmpty(t, firstURI) + shouldRetry, _ := r.Next(nil, nil) + require.True(t, shouldRetry) - retryURI, _ := r.GetNextURI(tc.resp, tc.respErr) + shouldRetry, retryURI := r.Next(tc.resp, tc.respErr) + assert.Equal(t, tc.shouldRetry, shouldRetry) if tc.shouldRetry { - require.Contains(t, tc.uris, retryURI) - if tc.shouldRetrySameURI { - require.Equal(t, retryURI, firstURI) - } else { - require.NotEqual(t, retryURI, firstURI) + if tc.retryURI != nil { + require.Equal(t, tc.retryURI.String(), retryURI.String()) } if tc.shouldRetryReset { require.True(t, retrier.DidReset) @@ -340,7 +258,7 @@ func TestRequestRetrier_GetNextURI(t *testing.T) { require.True(t, retrier.DidGetNext) } } else { - require.Empty(t, retryURI) + require.Nil(t, retryURI) } }) } @@ -370,3 +288,11 @@ func (m *mockRetrier) Next() bool { func (m *mockRetrier) CurrentAttempt() int { return 0 } + +func mustNewURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(err) + } + return u +} diff --git a/conjure-go-client/httpclient/internal/retry.go b/conjure-go-client/httpclient/internal/retry.go index f4c6b484..4fe4247d 100644 --- a/conjure-go-client/httpclient/internal/retry.go +++ b/conjure-go-client/httpclient/internal/retry.go @@ -59,12 +59,15 @@ const ( StatusCodeUnavailable = http.StatusServiceUnavailable ) -func isRetryOtherResponse(resp *http.Response, err error, errCode int) (bool, *url.URL) { +func 
isRetryOtherResponse(resp *http.Response, err error) (bool, *url.URL) { + errCode, _ := StatusCodeFromError(err) + // prioritize redirect from werror first if errCode == StatusCodeRetryOther || errCode == StatusCodeRetryTemporaryRedirect { locationStr, ok := LocationFromError(err) - if ok { - return true, parseLocationURL(locationStr) + if !ok { + return true, nil } + return true, parseLocationURL(locationStr) } if resp == nil { @@ -74,8 +77,11 @@ func isRetryOtherResponse(resp *http.Response, err error, errCode int) (bool, *u resp.StatusCode != StatusCodeRetryTemporaryRedirect { return false, nil } - locationStr := resp.Header.Get("Location") - return true, parseLocationURL(locationStr) + location, err := resp.Location() + if err != nil { + return true, nil + } + return true, location } func parseLocationURL(locationStr string) *url.URL { @@ -90,6 +96,8 @@ func parseLocationURL(locationStr string) *url.URL { return locationURL } +// isThrottleResponse returns true if the response is a throttle response type. 
It +// also returns a duration after which the failed URI can be retried func isThrottleResponse(resp *http.Response, errCode int) (bool, time.Duration) { if errCode == StatusCodeThrottle { return true, 0 diff --git a/conjure-go-client/httpclient/internal/retry_test.go b/conjure-go-client/httpclient/internal/retry_test.go index 78c90856..f5595ddb 100644 --- a/conjure-go-client/httpclient/internal/retry_test.go +++ b/conjure-go-client/httpclient/internal/retry_test.go @@ -129,14 +129,14 @@ func TestRetryResponseParsers(t *testing.T) { }, } { t.Run(test.Name, func(t *testing.T) { - errCode, _ := StatusCodeFromError(test.RespErr) - isRetryOther, retryOtherURL := isRetryOtherResponse(test.Response, test.RespErr, errCode) + isRetryOther, retryOtherURL := isRetryOtherResponse(test.Response, test.RespErr) if assert.Equal(t, test.IsRetryOther, isRetryOther) && test.RetryOtherURL != "" { if assert.NotNil(t, retryOtherURL) { assert.Equal(t, test.RetryOtherURL, retryOtherURL.String()) } } + errCode, _ := StatusCodeFromError(test.RespErr) isThrottle, throttleDur := isThrottleResponse(test.Response, errCode) if assert.Equal(t, test.IsThrottle, isThrottle) { assert.WithinDuration(t, time.Now().Add(test.ThrottleDuration), time.Now().Add(throttleDur), time.Second) diff --git a/conjure-go-client/httpclient/internal/rr_selector.go b/conjure-go-client/httpclient/internal/rr_selector.go new file mode 100644 index 00000000..fa4dccf8 --- /dev/null +++ b/conjure-go-client/httpclient/internal/rr_selector.go @@ -0,0 +1,49 @@ +// Copyright (c) 2022 Palantir Technologies. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "net/http" + "sync" +) + +type roundRobinSelector struct { + sync.Mutex + nanoClock func() int64 + + offset int +} + +// NewRoundRobinURISelector returns a URISelector that picks URIs in round robin order, advancing one position on each +// call to Select. The nanoClock function is stored but not read by this selector. The middleware no-ops on each request. +func NewRoundRobinURISelector(nanoClock func() int64) URISelector { + return &roundRobinSelector{ + nanoClock: nanoClock, + } +} + +// Select implements the URISelector interface +func (s *roundRobinSelector) Select(uris []string, _ http.Header) (string, error) { + s.Lock() + defer s.Unlock() + + s.offset = (s.offset + 1) % len(uris) + + return uris[s.offset], nil +} + +func (s *roundRobinSelector) RoundTrip(req *http.Request, next http.RoundTripper) (*http.Response, error) { + return next.RoundTrip(req) +} diff --git a/conjure-go-client/httpclient/internal/rr_selector_test.go b/conjure-go-client/httpclient/internal/rr_selector_test.go new file mode 100644 index 00000000..d1da545d --- /dev/null +++ b/conjure-go-client/httpclient/internal/rr_selector_test.go @@ -0,0 +1,44 @@ +// Copyright (c) 2022 Palantir Technologies. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRoundRobinSelector_Select(t *testing.T) { + uris := []string{"uri1", "uri2", "uri3", "uri4", "uri5"} + scorer := NewRoundRobinURISelector(func() int64 { return time.Now().UnixNano() }) + + const iterations = 100 + observed := make(map[string]int, iterations) + for i := 0; i < iterations; i++ { + uri, err := scorer.Select(uris, nil) + assert.NoError(t, err) + observed[uri] = observed[uri] + 1 + } + + occurences := make([]int, 0, len(observed)) + for _, count := range observed { + occurences = append(occurences, count) + } + + for _, v := range occurences { + assert.Equal(t, occurences[0], v) + } +} diff --git a/conjure-go-client/httpclient/internal/stateful_uri_pool.go b/conjure-go-client/httpclient/internal/stateful_uri_pool.go new file mode 100644 index 00000000..6f4fb723 --- /dev/null +++ b/conjure-go-client/httpclient/internal/stateful_uri_pool.go @@ -0,0 +1,134 @@ +// Copyright (c) 2022 Palantir Technologies. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "net/http" + "net/url" + "sync" + "time" + + "github.com/palantir/pkg/refreshable" +) + +const ( + // defaultResurrectDuration is the amount of time after which + // we resurrect failed URIs + defaultResurrectDuration = time.Second * 60 + meshSchemePrefix = "mesh-" +) + +type statefulURIPool struct { + sync.RWMutex + + uris []string + failedURIs map[string]struct{} +} + +// NewStatefulURIPool returns a URIPool that keeps track of a +// refreshable set of possible URIs. It can be used as middleware to track +// server side request failures and keep future requests from hitting known bad servers. +func NewStatefulURIPool(uris refreshable.StringSlice) URIPool { + s := &statefulURIPool{} + s.updateURIs(uris.CurrentStringSlice()) + + _ = uris.SubscribeToStringSlice(s.updateURIs) + return s +} + +// NumURIs implements URIPool +func (s *statefulURIPool) NumURIs() int { + s.RLock() + defer s.RUnlock() + + return len(s.uris) +} + +// URIs implements URIPool +func (s *statefulURIPool) URIs() []string { + s.RLock() + defer s.RUnlock() + + uris := make([]string, 0, len(s.uris)) + for _, uri := range s.uris { + if _, ok := s.failedURIs[uri]; ok { + continue + } + uris = append(uris, uri) + } + // if all connections are "failed", then return them all + if len(uris) == 0 { + return s.uris + } + return uris +} + +// RoundTrip implements URIPool +func (s *statefulURIPool) RoundTrip(req *http.Request, next http.RoundTripper) (*http.Response, error) { + resp, err := next.RoundTrip(req) + errCode, ok := StatusCodeFromError(err) + // fall back to the status code from the response + if !ok && resp != nil { + errCode = resp.StatusCode + } + + if isThrottle, ressurectAfter := isThrottleResponse(resp, errCode); isThrottle { + s.markBackoffURI(req, ressurectAfter) + } else if isUnavailableResponse(resp, errCode) { + // 503: go to next node + 
s.markBackoffURI(req, defaultResurrectDuration) + } else if resp == nil { + // if we get a nil response, we can assume there is a problem with host and can move on to the next. + s.markBackoffURI(req, defaultResurrectDuration) + } + + return resp, err +} + +func (s *statefulURIPool) updateURIs(uris []string) { + result := make([]string, 0, len(uris)) + for _, uri := range uris { + // validate URIs by parsing them + u, err := url.Parse(uri) + if err != nil { + // ignore invalid uris + continue + } + result = append(result, getBaseURI(u)) + } + + s.Lock() + defer s.Unlock() + s.uris = result + s.failedURIs = make(map[string]struct{}, len(uris)) +} + +func (s *statefulURIPool) markBackoffURI(req *http.Request, dur time.Duration) { + // if duration is equal to zero, then use defaultResurrectDuration + if dur == 0 { + dur = defaultResurrectDuration + } + reqURL := getBaseURI(req.URL) + s.Lock() + defer s.Unlock() + + s.failedURIs[reqURL] = struct{}{} + + time.AfterFunc(dur, func() { + s.Lock() + defer s.Unlock() + delete(s.failedURIs, reqURL) + }) +} diff --git a/conjure-go-client/httpclient/internal/stateful_uri_pool_test.go b/conjure-go-client/httpclient/internal/stateful_uri_pool_test.go new file mode 100644 index 00000000..db3ff72e --- /dev/null +++ b/conjure-go-client/httpclient/internal/stateful_uri_pool_test.go @@ -0,0 +1,130 @@ +// Copyright (c) 2022 Palantir Technologies. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package internal + +import ( + "net/http" + "testing" + + "github.com/palantir/pkg/refreshable" + "github.com/stretchr/testify/assert" +) + +func TestRequestRetrier_GetNextURIs(t *testing.T) { + for _, tc := range []struct { + name string + resp *http.Response + chosenURI string + beforeURIs []string + afterURIs []string + }{ + { + name: "preserves chosen URI if response doesn't contain a handled status code", + resp: &http.Response{}, + chosenURI: "https://domain0.example.com", + beforeURIs: []string{"https://domain0.example.com", "https://domain1.example.com"}, + afterURIs: []string{"https://domain0.example.com", "https://domain1.example.com"}, + }, + { + name: "remove chosen URI if response is nil", + resp: nil, + chosenURI: "https://domain0.example.com", + beforeURIs: []string{"https://domain0.example.com", "https://domain1.example.com"}, + afterURIs: []string{"https://domain1.example.com"}, + }, + { + name: "removes chosen URI if response return code is 503", + resp: &http.Response{StatusCode: http.StatusServiceUnavailable}, + chosenURI: "https://domain0.example.com", + beforeURIs: []string{"https://domain0.example.com", "https://domain1.example.com"}, + afterURIs: []string{"https://domain1.example.com"}, + }, + { + name: "preserves chosen URI if response return code is 503 but it's a single URI", + resp: &http.Response{StatusCode: http.StatusServiceUnavailable}, + chosenURI: "https://domain0.example.com", + beforeURIs: []string{"https://domain0.example.com"}, + afterURIs: []string{"https://domain0.example.com"}, + }, + { + name: "removes chosen URI if response return code is 429", + resp: &http.Response{StatusCode: http.StatusTooManyRequests}, + chosenURI: "https://domain0.example.com", + beforeURIs: []string{"https://domain0.example.com", "https://domain1.example.com"}, + afterURIs: []string{"https://domain1.example.com"}, + }, + { + name: "preserves chosen URI if response return code is 429 but it's a single URI", + resp: &http.Response{StatusCode: 
http.StatusTooManyRequests}, + chosenURI: "https://domain0.example.com", + beforeURIs: []string{"https://domain0.example.com"}, + afterURIs: []string{"https://domain0.example.com"}, + }, + { + name: "preserves chosen URI on permanent redirects", + resp: &http.Response{StatusCode: StatusCodeRetryOther}, + chosenURI: "https://domain0.example.com", + beforeURIs: []string{"https://domain0.example.com", "https://domain1.example.com"}, + afterURIs: []string{"https://domain0.example.com", "https://domain1.example.com"}, + }, + { + name: "preserves chosen URI on temporary redirects", + resp: &http.Response{StatusCode: StatusCodeRetryOther}, + chosenURI: "https://domain0.example.com", + beforeURIs: []string{"https://domain0.example.com", "https://domain1.example.com"}, + afterURIs: []string{"https://domain0.example.com", "https://domain1.example.com"}, + }, + { + name: "preserves chosen URI on 400 responses", + resp: &http.Response{StatusCode: http.StatusBadRequest}, + chosenURI: "https://domain0.example.com", + beforeURIs: []string{"https://domain0.example.com", "https://domain1.example.com"}, + afterURIs: []string{"https://domain0.example.com", "https://domain1.example.com"}, + }, + { + name: "preserves chosen URI on 404 responses", + resp: &http.Response{StatusCode: http.StatusNotFound}, + chosenURI: "https://domain0.example.com", + beforeURIs: []string{"https://domain0.example.com", "https://domain1.example.com"}, + afterURIs: []string{"https://domain0.example.com", "https://domain1.example.com"}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + ref := refreshable.NewDefaultRefreshable(tc.beforeURIs) + pool := NewStatefulURIPool(refreshable.NewStringSlice(ref)) + + req, err := http.NewRequest("GET", tc.chosenURI, nil) + assert.NoError(t, err) + + resp, err := pool.RoundTrip(req, newMockRoundTripper(tc.resp)) + assert.Equal(t, tc.resp, resp) + assert.NoError(t, err) + + assert.ElementsMatch(t, tc.afterURIs, pool.URIs()) + }) + } +} + +func newMockRoundTripper(resp 
*http.Response) http.RoundTripper { + return &mockRoundTripper{resp: resp} +} + +type mockRoundTripper struct { + resp *http.Response +} + +func (m mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return m.resp, nil +} diff --git a/conjure-go-client/httpclient/response_error_decoder_middleware.go b/conjure-go-client/httpclient/response_error_decoder_middleware.go index 318dcaea..93d9a9cb 100644 --- a/conjure-go-client/httpclient/response_error_decoder_middleware.go +++ b/conjure-go-client/httpclient/response_error_decoder_middleware.go @@ -81,7 +81,10 @@ func (d restErrorDecoder) DecodeError(resp *http.Response) error { unsafeParams := map[string]interface{}{} if resp.StatusCode >= http.StatusTemporaryRedirect && resp.StatusCode < http.StatusBadRequest { - unsafeParams["location"] = resp.Header.Get("Location") + location, err := resp.Location() + if err == nil { + unsafeParams["location"] = location.String() + } } wSafeParams := werror.SafeParams(safeParams) wUnsafeParams := werror.UnsafeParams(unsafeParams) diff --git a/conjure-go-client/httpclient/response_error_decoder_middleware_test.go b/conjure-go-client/httpclient/response_error_decoder_middleware_test.go index 1e01dbbe..b15eb639 100644 --- a/conjure-go-client/httpclient/response_error_decoder_middleware_test.go +++ b/conjure-go-client/httpclient/response_error_decoder_middleware_test.go @@ -65,7 +65,7 @@ func TestErrorDecoderMiddlewares(t *testing.T) { assert.True(t, ok) assert.Equal(t, 307, code) location, ok := httpclient.LocationFromError(err) - assert.True(t, ok) + assert.False(t, ok) assert.Equal(t, "", location) }, },