Skip to content

Commit

Permalink
Merge pull request #11 from PuerkitoBio/mna-9-check-retry-redirect
Browse files Browse the repository at this point in the history
Add tests with redirects, GetBody, addresses #9
  • Loading branch information
mna authored Dec 15, 2023
2 parents 987d330 + 8a6921e commit 8f425fd
Showing 1 changed file with 179 additions and 0 deletions.
179 changes: 179 additions & 0 deletions rehttp_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,182 @@ func TestClientRetryWithHeaders(t *testing.T) {
assert.Equal(t, "b", res.Header.Get("X-2"))
assert.False(t, fail)
}

func TestRedirects(t *testing.T) {
var statusStack []int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// always consume the body
b, err := io.ReadAll(r.Body)
require.NoError(t, err)

if len(statusStack) > 0 {
status := statusStack[len(statusStack)-1]
statusStack = statusStack[:len(statusStack)-1]
if status/100 == 3 {
http.Redirect(w, r, "/"+fmt.Sprint(len(statusStack)), status)
return
}
w.WriteHeader(status)
return
}

// on success, return the received body
_, err = w.Write(b)
require.NoError(t, err)
}))
defer srv.Close()

t.Run("without body, without redirect and retry", func(t *testing.T) {
tr := NewTransport(nil, func(attempt Attempt) bool { return false }, ConstDelay(0))
c := &http.Client{Transport: tr}

res, err := c.Post(srv.URL+"/test", "", nil)
require.NoError(t, err)
defer res.Body.Close()
got, err := io.ReadAll(res.Body)
require.NoError(t, err)

assert.Equal(t, 200, res.StatusCode)
assert.Equal(t, "", string(got))
})

t.Run("without body, with redirects, no retry", func(t *testing.T) {
statusStack = []int{http.StatusTemporaryRedirect, http.StatusTemporaryRedirect}
tr := NewTransport(nil, func(attempt Attempt) bool { return false }, ConstDelay(0))
c := &http.Client{Transport: tr}

res, err := c.Post(srv.URL+"/test", "", nil)
require.NoError(t, err)
defer res.Body.Close()
got, err := io.ReadAll(res.Body)
require.NoError(t, err)

assert.Equal(t, 200, res.StatusCode)
assert.Equal(t, "", string(got))
})

t.Run("without body, with redirects and retries", func(t *testing.T) {
statusStack = []int{http.StatusTemporaryRedirect, http.StatusInternalServerError, http.StatusTemporaryRedirect, http.StatusInternalServerError}
tr := NewTransport(nil, RetryStatuses(http.StatusInternalServerError), ConstDelay(0))
c := &http.Client{Transport: tr}

res, err := c.Post(srv.URL+"/test", "", nil)
require.NoError(t, err)
defer res.Body.Close()
got, err := io.ReadAll(res.Body)
require.NoError(t, err)

assert.Equal(t, 200, res.StatusCode)
assert.Equal(t, "", string(got))
})

t.Run("with body, without redirect and retry", func(t *testing.T) {
tr := NewTransport(nil, func(attempt Attempt) bool { return false }, ConstDelay(0))
c := &http.Client{Transport: tr}

body := "ok"
res, err := c.Post(srv.URL+"/test", "", strings.NewReader(body))
require.NoError(t, err)
defer res.Body.Close()
got, err := io.ReadAll(res.Body)
require.NoError(t, err)

assert.Equal(t, 200, res.StatusCode)
assert.Equal(t, body, string(got))
})

t.Run("with body, with redirects, no retry", func(t *testing.T) {
statusStack = []int{http.StatusTemporaryRedirect, http.StatusTemporaryRedirect}
tr := NewTransport(nil, func(attempt Attempt) bool { return false }, ConstDelay(0))
c := &http.Client{Transport: tr}

body := "ok"
res, err := c.Post(srv.URL+"/test", "", strings.NewReader(body))
require.NoError(t, err)
defer res.Body.Close()
got, err := io.ReadAll(res.Body)
require.NoError(t, err)

assert.Equal(t, 200, res.StatusCode)
assert.Equal(t, body, string(got))
})

t.Run("with body, with redirects and retries", func(t *testing.T) {
statusStack = []int{http.StatusTemporaryRedirect, http.StatusInternalServerError, http.StatusTemporaryRedirect, http.StatusInternalServerError}
tr := NewTransport(nil, RetryStatuses(http.StatusInternalServerError), ConstDelay(0))
c := &http.Client{Transport: tr}

body := "ok"
res, err := c.Post(srv.URL+"/test", "", strings.NewReader(body))
require.NoError(t, err)
defer res.Body.Close()
got, err := io.ReadAll(res.Body)
require.NoError(t, err)

assert.Equal(t, 200, res.StatusCode)
assert.Equal(t, body, string(got))
})

t.Run("with body, with redirects, no automatic GetBody, no retry", func(t *testing.T) {
statusStack = []int{http.StatusTemporaryRedirect, http.StatusTemporaryRedirect}
tr := NewTransport(nil, func(attempt Attempt) bool { return false }, ConstDelay(0))
c := &http.Client{Transport: tr}

body := "ok"
// LimitReader disables the automatic setting of GetBody by Go's stdlib (it
// only supports known readers that can be safely reconstructed, such as
// bytes and strings readers)
res, err := c.Post(srv.URL+"/test", "", io.LimitReader(strings.NewReader(body), 100))
require.NoError(t, err)
defer res.Body.Close()
got, err := io.ReadAll(res.Body)
require.NoError(t, err)

assert.Equal(t, http.StatusTemporaryRedirect, res.StatusCode)
assert.Empty(t, string(got))
})

t.Run("with body, with redirects, explicit GetBody, no retry", func(t *testing.T) {
statusStack = []int{http.StatusTemporaryRedirect, http.StatusTemporaryRedirect}
tr := NewTransport(nil, func(attempt Attempt) bool { return false }, ConstDelay(0))
c := &http.Client{Transport: tr}

body := "ok"
req, err := http.NewRequest("POST", srv.URL+"/test", io.LimitReader(strings.NewReader(body), 100))
require.NoError(t, err)
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(io.LimitReader(strings.NewReader(body), 100)), nil
}

res, err := c.Do(req)
require.NoError(t, err)
defer res.Body.Close()
got, err := io.ReadAll(res.Body)
require.NoError(t, err)

assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, body, string(got))
})

t.Run("with body, with redirects and retries, explicit GetBody", func(t *testing.T) {
statusStack = []int{http.StatusInternalServerError, http.StatusTemporaryRedirect, http.StatusInternalServerError, http.StatusTemporaryRedirect}
tr := NewTransport(nil, RetryStatuses(http.StatusInternalServerError), ConstDelay(0))
c := &http.Client{Transport: tr}

body := "ok"
req, err := http.NewRequest("POST", srv.URL+"/test", io.LimitReader(strings.NewReader(body), 100))
require.NoError(t, err)
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(io.LimitReader(strings.NewReader(body), 100)), nil
}

res, err := c.Do(req)
require.NoError(t, err)
defer res.Body.Close()
got, err := io.ReadAll(res.Body)
require.NoError(t, err)

assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, body, string(got))
})
}

0 comments on commit 8f425fd

Please sign in to comment.