-
-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ability to set a per-attempt timeout
- Loading branch information
1 parent
eda55af
commit 88f1256
Showing
5 changed files
with
243 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package rehttp | ||
|
||
import ( | ||
"context" | ||
"io" | ||
"net/http" | ||
) | ||
|
||
type cancelReader struct { | ||
io.ReadCloser | ||
|
||
cancel context.CancelFunc | ||
} | ||
|
||
func (r cancelReader) Close() error { | ||
r.cancel() | ||
return r.ReadCloser.Close() | ||
} | ||
|
||
// injectCancelReader propagates the ability for the caller to cancel the request context | ||
// once done with the response. If the transport cancels before the body stream is read, | ||
// a race begins where the caller may be unable to read the response bytes before the stream | ||
// is closed and an error is returned. This helper function wraps a response body in a | ||
// io.ReadCloser that cancels the context once the body is closed, preventing a context leak. | ||
// Solution based on https://github.com/go-kit/kit/issues/773. | ||
func injectCancelReader(res *http.Response, cancel context.CancelFunc) *http.Response { | ||
if res == nil { | ||
return nil | ||
} | ||
|
||
res.Body = cancelReader{ | ||
ReadCloser: res.Body, | ||
cancel: cancel, | ||
} | ||
return res | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
//go:build go1.7 | ||
// +build go1.7 | ||
|
||
package rehttp | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"time" | ||
) | ||
|
||
func getRequestContext(req *http.Request) context.Context { | ||
return req.Context() | ||
} | ||
|
||
func getPerAttemptTimeoutInfo(ctx context.Context, req *http.Request, timeout time.Duration) (*http.Request, context.CancelFunc) { | ||
tctx, cancel := context.WithTimeout(ctx, timeout) | ||
return req.WithContext(tctx), cancel | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
//go:build !go1.7 | ||
// +build !go1.7 | ||
|
||
package rehttp | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"time" | ||
) | ||
|
||
func getRequestContext(req *http.Request) context.Context { | ||
return nil // req.Context() doesn't exist before 1.7 | ||
} | ||
|
||
func getPerAttemptTimeoutInfo(ctx context.Context, req *http.Request, timeout time.Duration) (*http.Request, context.CancelFunc) { | ||
// req.WithContext() doesn't exist before 1.7, so noop | ||
return req, func() {} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
//go:build go1.7 | ||
// +build go1.7 | ||
|
||
package rehttp | ||
|
||
import ( | ||
"context" | ||
"io" | ||
"net/http" | ||
"net/http/httptest" | ||
"sync" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func TestTransport_RoundTripTimeouts(t *testing.T) { | ||
attemptCountMu := sync.Mutex{} | ||
attemptCount := 0 | ||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
attemptCountMu.Lock() | ||
attemptCount++ | ||
attemptCountMu.Unlock() | ||
time.Sleep(time.Millisecond * 100) | ||
w.WriteHeader(http.StatusTooManyRequests) | ||
})) | ||
defer ts.Close() | ||
|
||
tr := NewTransport( | ||
http.DefaultTransport, | ||
RetryAll(RetryMaxRetries(3), RetryAny( | ||
RetryStatuses(http.StatusTooManyRequests), // retry 429s | ||
func(attempt Attempt) bool { // retry context deadline exceeded errors | ||
return attempt.Error != nil && attempt.Error == context.DeadlineExceeded // errors.Is requires go 1.13+ | ||
})), | ||
ConstDelay(time.Duration(0)), | ||
) | ||
tr.PerAttemptTimeout = time.Millisecond * 10 // short timeout | ||
|
||
client := http.Client{ | ||
Transport: tr, | ||
} | ||
|
||
req, err := http.NewRequest(http.MethodGet, ts.URL, nil) | ||
if err != nil { | ||
t.Errorf("error creating request: %s", err) | ||
} | ||
|
||
_, err = client.Do(req) | ||
time.Sleep(time.Millisecond * 100) // let the server finish sleeping | ||
// should have attempted 4 times but errored due to timeout | ||
attemptCountMu.Lock() | ||
if attemptCount != 4 { | ||
t.Errorf("attempt count does not match expected: got %d, want %d", attemptCount, 4) | ||
} | ||
if err == nil { | ||
t.Error("expected timeout error doing request but got nil") | ||
} | ||
|
||
attemptCount = 0 | ||
attemptCountMu.Unlock() | ||
|
||
// now increase the timeout restriction | ||
tr.PerAttemptTimeout = time.Second | ||
res, err := client.Do(req) | ||
attemptCountMu.Lock() | ||
// should have attempted 4 times without going over the timeout | ||
if attemptCount != 4 { | ||
t.Errorf("attempt count does not match expected: got %d, want %d", attemptCount, 4) | ||
} | ||
if err != nil { | ||
t.Errorf("got unexpected error doing request: %s", err) | ||
} | ||
if res == nil || res.StatusCode != http.StatusTooManyRequests { | ||
t.Errorf("status code does not match expected: got %d, want %d", res.StatusCode, http.StatusTooManyRequests) | ||
} | ||
|
||
attemptCount = 0 | ||
attemptCountMu.Unlock() | ||
|
||
// now remove the timeout restriction | ||
tr.PerAttemptTimeout = time.Duration(0) | ||
res, err = client.Do(req) | ||
attemptCountMu.Lock() | ||
// should have attempted 4 times without going over the timeout | ||
if attemptCount != 4 { | ||
t.Errorf("attempt count does not match expected: got %d, want %d", attemptCount, 4) | ||
} | ||
if err != nil { | ||
t.Errorf("got unexpected error doing request: %s", err) | ||
} | ||
if res == nil || res.StatusCode != http.StatusTooManyRequests { | ||
t.Errorf("status code does not match expected: got %d, want %d", res.StatusCode, http.StatusTooManyRequests) | ||
} | ||
attemptCountMu.Unlock() | ||
} | ||
|
||
// TestCancelReader is meant to test that the cancel reader is correctly | ||
// preventing the race-case of being unable to read the body due to a | ||
// preemptively-canceled context. | ||
func TestCancelReader(t *testing.T) { | ||
rt := NewTransport(http.DefaultTransport, RetryMaxRetries(1), ConstDelay(time.Duration(0))) | ||
rt.PerAttemptTimeout = time.Millisecond * 100 | ||
client := http.Client{ | ||
Transport: rt, | ||
} | ||
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
time.Sleep(time.Millisecond * 10) | ||
w.WriteHeader(http.StatusOK) | ||
// need a decent number of bytes to make the race case more likely to fail | ||
// https://github.com/go-kit/kit/issues/773 | ||
w.Write(make([]byte, 102400)) | ||
})) | ||
defer ts.Close() | ||
|
||
ctx := context.Background() | ||
|
||
req, _ := http.NewRequest(http.MethodGet, ts.URL, nil) | ||
res, err := client.Do(req.WithContext(ctx)) | ||
if err != nil { | ||
t.Errorf("unexpected error creating request: %s", err) | ||
} | ||
if res == nil { | ||
t.Error("unexpected nil response") | ||
} | ||
if res != nil { | ||
defer res.Body.Close() | ||
b, err := io.ReadAll(res.Body) | ||
if err != nil { | ||
t.Errorf("error reading response body: %s", err) | ||
} | ||
if len(b) != 102400 { | ||
t.Errorf("response byte length does not match expected. got %d, want %d", len(b), 102400) | ||
} | ||
} | ||
} |