From 32fc5955974f03d44480804289b45c41e49bad04 Mon Sep 17 00:00:00 2001 From: Justin Ricks Date: Fri, 22 Sep 2023 15:32:04 -0600 Subject: [PATCH] Add ability to set a per-attempt timeout --- cancelreader.go | 36 +++++++++ perattempttimeout_post17.go | 19 +++++ perattempttimeout_pre17.go | 19 +++++ rehttp.go | 41 +++++++++-- rehttp_server_post17_test.go | 138 +++++++++++++++++++++++++++++++++++ 5 files changed, 246 insertions(+), 7 deletions(-) create mode 100644 cancelreader.go create mode 100644 perattempttimeout_post17.go create mode 100644 perattempttimeout_pre17.go create mode 100644 rehttp_server_post17_test.go diff --git a/cancelreader.go b/cancelreader.go new file mode 100644 index 0000000..13f794c --- /dev/null +++ b/cancelreader.go @@ -0,0 +1,36 @@ +package rehttp + +import ( + "context" + "io" + "net/http" +) + +type cancelReader struct { + io.ReadCloser + + cancel context.CancelFunc +} + +func (r cancelReader) Close() error { + r.cancel() + return r.ReadCloser.Close() +} + +// injectCancelReader propagates the ability for the caller to cancel the request context +// once done with the response. If the transport cancels before the body stream is read, +// a race begins where the caller may be unable to read the response bytes before the stream +// is closed and an error is returned. This helper function wraps a response body in a +// io.ReadCloser that cancels the context once the body is closed, preventing a context leak. +// Solution based on https://github.com/go-kit/kit/issues/773. +func injectCancelReader(res *http.Response, cancel context.CancelFunc) *http.Response { + if res == nil { + return nil + } + + res.Body = cancelReader{ + ReadCloser: res.Body, + cancel: cancel, + } + return res +} diff --git a/perattempttimeout_post17.go b/perattempttimeout_post17.go new file mode 100644 index 0000000..e92c349 --- /dev/null +++ b/perattempttimeout_post17.go @@ -0,0 +1,19 @@ +//go:build go1.7 +// +build go1.7 + +package rehttp + +import ( + "context" + "net/http" + "time" +) + +func getRequestContext(req *http.Request) context.Context { + return req.Context() +} + +func getPerAttemptTimeoutInfo(ctx context.Context, req *http.Request, timeout time.Duration) (*http.Request, context.CancelFunc) { + tctx, cancel := context.WithTimeout(ctx, timeout) + return req.WithContext(tctx), cancel +} diff --git a/perattempttimeout_pre17.go b/perattempttimeout_pre17.go new file mode 100644 index 0000000..e44e4e5 --- /dev/null +++ b/perattempttimeout_pre17.go @@ -0,0 +1,19 @@ +//go:build !go1.7 +// +build !go1.7 + +package rehttp + +import ( + "context" + "net/http" + "time" +) + +func getRequestContext(req *http.Request) context.Context { + return nil // req.Context() doesn't exist before 1.7 +} + +func getPerAttemptTimeoutInfo(ctx context.Context, req *http.Request, timeout time.Duration) (*http.Request, context.CancelFunc) { + // req.WithContext() doesn't exist before 1.7, so noop + return req, func() {} +} diff --git a/rehttp.go b/rehttp.go index 0c02edf..724e3f8 100644 --- a/rehttp.go +++ b/rehttp.go @@ -39,7 +39,8 @@ // (https://golang.org/pkg/net/http/#Transport.CancelRequest). // // On Go1.7+, it uses the context returned by http.Request.Context -// to check for cancelled requests. +// to check for cancelled requests. Before Go1.7, PerAttemptTimeout +// has no effect. // // It should work on Go1.5, but only if there is no timeout set on the // *http.Client. Go's stdlib will return an error on the first request @@ -50,6 +51,7 @@ package rehttp import ( "bytes" + "context" "errors" "io" "io/ioutil" @@ -280,6 +282,22 @@ type Transport struct { // is non-nil. PreventRetryWithBody bool + // PerAttemptTimeout can be optionally set to add per-attempt timeouts. + // These may be used in place of or in conjunction with overall timeouts. + // For example, a per-attempt timeout of 5s would mean an attempt will + // be canceled after 5s, then the delay fn will be consulted before + // potentially making another attempt, which will again be capped at 5s. + // This means that the overall duration may be up to + // (PerAttemptTimeout + delay) * n, where n is the maximum attempts. + // If using an overall timeout (whether on the http client or the request + // context), the request will stop at whichever timeout is reached first. + // Your RetryFn can determine if a request hit the per-attempt timeout by + // checking if attempt.Error == context.DeadlineExceeded (or use errors.Is + // on go 1.13+). + // time.Duration(0) signals that no per-attempt timeout should be used. + // Note that before go 1.7 this option has no effect. + PerAttemptTimeout time.Duration + // retry is a function that determines if the request should be retried. // Unless a retry is prevented based on PreventRetryWithBody, all requests // go through that function, even those that are typically considered @@ -297,6 +315,9 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { var attempt int preventRetry := req.Body != nil && req.Body != http.NoBody && t.PreventRetryWithBody + // used as a baseline to set fresh timeouts per-attempt if needed + ctx := getRequestContext(req) + // get the done cancellation channel for the context, will be nil // for < go1.7. done := contextForRequest(req) @@ -317,19 +338,24 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { } for { - res, err := t.RoundTripper.RoundTrip(req) + var cancel context.CancelFunc = func() {} // empty unless a timeout is set + reqWithTimeout := req + if t.PerAttemptTimeout != 0 { + reqWithTimeout, cancel = getPerAttemptTimeoutInfo(ctx, req, t.PerAttemptTimeout) + } + res, err := t.RoundTripper.RoundTrip(reqWithTimeout) if preventRetry { - return res, err + return injectCancelReader(res, cancel), err } retry, delay := t.retry(Attempt{ - Request: req, + Request: reqWithTimeout, Response: res, Index: attempt, Error: err, }) if !retry { - return res, err + return injectCancelReader(res, cancel), err } if br != nil { @@ -338,15 +364,16 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { // to reset on the request is the body, if any. if _, serr := br.Seek(0, 0); serr != nil { // failed to retry, return the results - return res, err + return injectCancelReader(res, cancel), err } - req.Body = ioutil.NopCloser(br) + reqWithTimeout.Body = ioutil.NopCloser(br) } // close the disposed response's body, if any if res != nil { io.Copy(ioutil.Discard, res.Body) res.Body.Close() } + cancel() // we're done with this response and won't be returning it, so it's safe to cancel immediately select { case <-time.After(delay): diff --git a/rehttp_server_post17_test.go b/rehttp_server_post17_test.go new file mode 100644 index 0000000..d9f2f2b --- /dev/null +++ b/rehttp_server_post17_test.go @@ -0,0 +1,138 @@ +//go:build go1.7 +// +build go1.7 + +package rehttp + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestTransport_RoundTripTimeouts(t *testing.T) { + // to keep track of any open server requests and ensure the correct number of requests were made + ch := make(chan bool, 4) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ch <- true + time.Sleep(time.Millisecond * 100) + w.WriteHeader(http.StatusTooManyRequests) + })) + defer ts.Close() + + tr := NewTransport( + http.DefaultTransport, + RetryAll(RetryMaxRetries(3), RetryAny( + RetryStatuses(http.StatusTooManyRequests), // retry 429s + func(attempt Attempt) bool { // retry context deadline exceeded errors + return attempt.Error != nil && attempt.Error == context.DeadlineExceeded // errors.Is requires go 1.13+ + })), + ConstDelay(0), + ) + tr.PerAttemptTimeout = time.Millisecond * 10 // short timeout + + client := http.Client{ + Transport: tr, + } + + req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + if err != nil { + t.Errorf("error creating request: %s", err) + } + + _, err = client.Do(req) + // make sure server has finished and got 4 attempts + <-ch + <-ch + <-ch + <-ch + if err == nil { + t.Error("expected timeout error doing request but got nil") + } + + // repeat the above but this time add a context timeout that will expire before all attempts complete + ch = make(chan bool, 2) + ctx, cancelFunc := context.WithTimeout(context.Background(), time.Millisecond*15) + _, err = client.Do(req.WithContext(ctx)) + cancelFunc() + // should only make 2 attempts + <-ch + <-ch + if err == nil { + t.Error("expected timeout error doing request but got nil") + } + + // now increase the timeout restriction + ch = make(chan bool, 4) + tr.PerAttemptTimeout = time.Second + res, err := client.Do(req) + + // should have attempted 4 times without going over the timeout + <-ch + <-ch + <-ch + <-ch + + if err != nil { + t.Errorf("got unexpected error doing request: %s", err) + } + if res == nil || res.StatusCode != http.StatusTooManyRequests { + t.Errorf("status code does not match expected: got %d, want %d", res.StatusCode, http.StatusTooManyRequests) + } + + // now remove the timeout restriction + ch = make(chan bool, 4) + tr.PerAttemptTimeout = time.Duration(0) + res, err = client.Do(req) + // should have attempted 4 times without going over the timeout + <-ch + <-ch + <-ch + <-ch + + if err != nil { + t.Errorf("got unexpected error doing request: %s", err) + } + if res == nil || res.StatusCode != http.StatusTooManyRequests { + t.Errorf("status code does not match expected: got %d, want %d", res.StatusCode, http.StatusTooManyRequests) + } +} + +// TestCancelReader is meant to test that the cancel reader is correctly +// preventing the race-case of being unable to read the body due to a +// preemptively-canceled context. +func TestCancelReader(t *testing.T) { + rt := NewTransport(http.DefaultTransport, RetryMaxRetries(1), ConstDelay(0)) + rt.PerAttemptTimeout = time.Millisecond * 100 + client := http.Client{ + Transport: rt, + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Millisecond * 10) + w.WriteHeader(http.StatusOK) + // need a decent number of bytes to make the race case more likely to fail + // https://github.com/go-kit/kit/issues/773 + w.Write(make([]byte, 102400)) + })) + defer ts.Close() + + ctx := context.Background() + + req, _ := http.NewRequest(http.MethodGet, ts.URL, nil) + res, err := client.Do(req.WithContext(ctx)) + if err != nil { + t.Fatalf("unexpected error creating request: %s", err) + } + defer res.Body.Close() + b, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("error reading response body: %s", err) + } + if len(b) != 102400 { + t.Errorf("response byte length does not match expected. got %d, want %d", len(b), 102400) + } +}