-
-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ability to set a per-attempt timeout
- Loading branch information
1 parent
eda55af
commit 32fc595
Showing
5 changed files
with
246 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package rehttp | ||
|
||
import ( | ||
"context" | ||
"io" | ||
"net/http" | ||
) | ||
|
||
type cancelReader struct { | ||
io.ReadCloser | ||
|
||
cancel context.CancelFunc | ||
} | ||
|
||
func (r cancelReader) Close() error { | ||
r.cancel() | ||
return r.ReadCloser.Close() | ||
} | ||
|
||
// injectCancelReader propagates the ability for the caller to cancel the request context | ||
// once done with the response. If the transport cancels before the body stream is read, | ||
// a race begins where the caller may be unable to read the response bytes before the stream | ||
// is closed and an error is returned. This helper function wraps a response body in a | ||
// io.ReadCloser that cancels the context once the body is closed, preventing a context leak. | ||
// Solution based on https://github.com/go-kit/kit/issues/773. | ||
func injectCancelReader(res *http.Response, cancel context.CancelFunc) *http.Response { | ||
if res == nil { | ||
return nil | ||
} | ||
|
||
res.Body = cancelReader{ | ||
ReadCloser: res.Body, | ||
cancel: cancel, | ||
} | ||
return res | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
//go:build go1.7 | ||
// +build go1.7 | ||
|
||
package rehttp | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"time" | ||
) | ||
|
||
func getRequestContext(req *http.Request) context.Context { | ||
return req.Context() | ||
} | ||
|
||
func getPerAttemptTimeoutInfo(ctx context.Context, req *http.Request, timeout time.Duration) (*http.Request, context.CancelFunc) { | ||
tctx, cancel := context.WithTimeout(ctx, timeout) | ||
return req.WithContext(tctx), cancel | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
//go:build !go1.7 | ||
// +build !go1.7 | ||
|
||
package rehttp | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"time" | ||
) | ||
|
||
func getRequestContext(req *http.Request) context.Context { | ||
return nil // req.Context() doesn't exist before 1.7 | ||
} | ||
|
||
func getPerAttemptTimeoutInfo(ctx context.Context, req *http.Request, timeout time.Duration) (*http.Request, context.CancelFunc) { | ||
// req.WithContext() doesn't exist before 1.7, so noop | ||
return req, func() {} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
//go:build go1.7 | ||
// +build go1.7 | ||
|
||
package rehttp | ||
|
||
import ( | ||
"context" | ||
"io" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func TestTransport_RoundTripTimeouts(t *testing.T) { | ||
// to keep track of any open server requests and ensure the correct number of requests were made | ||
ch := make(chan bool, 4) | ||
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
ch <- true | ||
time.Sleep(time.Millisecond * 100) | ||
w.WriteHeader(http.StatusTooManyRequests) | ||
})) | ||
defer ts.Close() | ||
|
||
tr := NewTransport( | ||
http.DefaultTransport, | ||
RetryAll(RetryMaxRetries(3), RetryAny( | ||
RetryStatuses(http.StatusTooManyRequests), // retry 429s | ||
func(attempt Attempt) bool { // retry context deadline exceeded errors | ||
return attempt.Error != nil && attempt.Error == context.DeadlineExceeded // errors.Is requires go 1.13+ | ||
})), | ||
ConstDelay(0), | ||
) | ||
tr.PerAttemptTimeout = time.Millisecond * 10 // short timeout | ||
|
||
client := http.Client{ | ||
Transport: tr, | ||
} | ||
|
||
req, err := http.NewRequest(http.MethodGet, ts.URL, nil) | ||
if err != nil { | ||
t.Errorf("error creating request: %s", err) | ||
} | ||
|
||
_, err = client.Do(req) | ||
// make sure server has finished and got 4 attempts | ||
<-ch | ||
<-ch | ||
<-ch | ||
<-ch | ||
if err == nil { | ||
t.Error("expected timeout error doing request but got nil") | ||
} | ||
|
||
// repeat the above but this time add a context timeout that will expire before all attempts complete | ||
ch = make(chan bool, 2) | ||
ctx, cancelFunc := context.WithTimeout(context.Background(), time.Millisecond*15) | ||
_, err = client.Do(req.WithContext(ctx)) | ||
cancelFunc() | ||
// should only make 2 attempts | ||
<-ch | ||
<-ch | ||
if err == nil { | ||
t.Error("expected timeout error doing request but got nil") | ||
} | ||
|
||
// now increase the timeout restriction | ||
ch = make(chan bool, 4) | ||
tr.PerAttemptTimeout = time.Second | ||
res, err := client.Do(req) | ||
|
||
// should have attempted 4 times without going over the timeout | ||
<-ch | ||
<-ch | ||
<-ch | ||
<-ch | ||
|
||
if err != nil { | ||
t.Errorf("got unexpected error doing request: %s", err) | ||
} | ||
if res == nil || res.StatusCode != http.StatusTooManyRequests { | ||
t.Errorf("status code does not match expected: got %d, want %d", res.StatusCode, http.StatusTooManyRequests) | ||
} | ||
|
||
// now remove the timeout restriction | ||
ch = make(chan bool, 4) | ||
tr.PerAttemptTimeout = time.Duration(0) | ||
res, err = client.Do(req) | ||
// should have attempted 4 times without going over the timeout | ||
<-ch | ||
<-ch | ||
<-ch | ||
<-ch | ||
|
||
if err != nil { | ||
t.Errorf("got unexpected error doing request: %s", err) | ||
} | ||
if res == nil || res.StatusCode != http.StatusTooManyRequests { | ||
t.Errorf("status code does not match expected: got %d, want %d", res.StatusCode, http.StatusTooManyRequests) | ||
} | ||
} | ||
|
||
// TestCancelReader is meant to test that the cancel reader is correctly | ||
// preventing the race-case of being unable to read the body due to a | ||
// preemptively-canceled context. | ||
func TestCancelReader(t *testing.T) { | ||
rt := NewTransport(http.DefaultTransport, RetryMaxRetries(1), ConstDelay(0)) | ||
rt.PerAttemptTimeout = time.Millisecond * 100 | ||
client := http.Client{ | ||
Transport: rt, | ||
} | ||
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
time.Sleep(time.Millisecond * 10) | ||
w.WriteHeader(http.StatusOK) | ||
// need a decent number of bytes to make the race case more likely to fail | ||
// https://github.com/go-kit/kit/issues/773 | ||
w.Write(make([]byte, 102400)) | ||
})) | ||
defer ts.Close() | ||
|
||
ctx := context.Background() | ||
|
||
req, _ := http.NewRequest(http.MethodGet, ts.URL, nil) | ||
res, err := client.Do(req.WithContext(ctx)) | ||
if err != nil { | ||
t.Fatalf("unexpected error creating request: %s", err) | ||
} | ||
defer res.Body.Close() | ||
b, err := io.ReadAll(res.Body) | ||
if err != nil { | ||
t.Errorf("error reading response body: %s", err) | ||
} | ||
if len(b) != 102400 { | ||
t.Errorf("response byte length does not match expected. got %d, want %d", len(b), 102400) | ||
} | ||
} |