From 8a6921ed1bf84341dcd49a6398518962856a6406 Mon Sep 17 00:00:00 2001 From: Martin Angers Date: Fri, 15 Dec 2023 12:36:41 -0500 Subject: [PATCH] Add tests with redirects, GetBody, addresses #9 --- rehttp_server_test.go | 179 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) diff --git a/rehttp_server_test.go b/rehttp_server_test.go index 7878bbb..84bd6a5 100644 --- a/rehttp_server_test.go +++ b/rehttp_server_test.go @@ -351,3 +351,182 @@ func TestClientRetryWithHeaders(t *testing.T) { assert.Equal(t, "b", res.Header.Get("X-2")) assert.False(t, fail) } + +func TestRedirects(t *testing.T) { + var statusStack []int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // always consume the body + b, err := io.ReadAll(r.Body) + require.NoError(t, err) + + if len(statusStack) > 0 { + status := statusStack[len(statusStack)-1] + statusStack = statusStack[:len(statusStack)-1] + if status/100 == 3 { + http.Redirect(w, r, "/"+fmt.Sprint(len(statusStack)), status) + return + } + w.WriteHeader(status) + return + } + + // on success, return the received body + _, err = w.Write(b) + require.NoError(t, err) + })) + defer srv.Close() + + t.Run("without body, without redirect and retry", func(t *testing.T) { + tr := NewTransport(nil, func(attempt Attempt) bool { return false }, ConstDelay(0)) + c := &http.Client{Transport: tr} + + res, err := c.Post(srv.URL+"/test", "", nil) + require.NoError(t, err) + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + require.NoError(t, err) + + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, "", string(got)) + }) + + t.Run("without body, with redirects, no retry", func(t *testing.T) { + statusStack = []int{http.StatusTemporaryRedirect, http.StatusTemporaryRedirect} + tr := NewTransport(nil, func(attempt Attempt) bool { return false }, ConstDelay(0)) + c := &http.Client{Transport: tr} + + res, err := c.Post(srv.URL+"/test", "", nil) + require.NoError(t, err) + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + require.NoError(t, err) + + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, "", string(got)) + }) + + t.Run("without body, with redirects and retries", func(t *testing.T) { + statusStack = []int{http.StatusTemporaryRedirect, http.StatusInternalServerError, http.StatusTemporaryRedirect, http.StatusInternalServerError} + tr := NewTransport(nil, RetryStatuses(http.StatusInternalServerError), ConstDelay(0)) + c := &http.Client{Transport: tr} + + res, err := c.Post(srv.URL+"/test", "", nil) + require.NoError(t, err) + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + require.NoError(t, err) + + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, "", string(got)) + }) + + t.Run("with body, without redirect and retry", func(t *testing.T) { + tr := NewTransport(nil, func(attempt Attempt) bool { return false }, ConstDelay(0)) + c := &http.Client{Transport: tr} + + body := "ok" + res, err := c.Post(srv.URL+"/test", "", strings.NewReader(body)) + require.NoError(t, err) + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + require.NoError(t, err) + + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, body, string(got)) + }) + + t.Run("with body, with redirects, no retry", func(t *testing.T) { + statusStack = []int{http.StatusTemporaryRedirect, http.StatusTemporaryRedirect} + tr := NewTransport(nil, func(attempt Attempt) bool { return false }, ConstDelay(0)) + c := &http.Client{Transport: tr} + + body := "ok" + res, err := c.Post(srv.URL+"/test", "", strings.NewReader(body)) + require.NoError(t, err) + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + require.NoError(t, err) + + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, body, string(got)) + }) + + t.Run("with body, with redirects and retries", func(t *testing.T) { + statusStack = []int{http.StatusTemporaryRedirect, http.StatusInternalServerError, http.StatusTemporaryRedirect, http.StatusInternalServerError} + tr := NewTransport(nil, RetryStatuses(http.StatusInternalServerError), ConstDelay(0)) + c := &http.Client{Transport: tr} + + body := "ok" + res, err := c.Post(srv.URL+"/test", "", strings.NewReader(body)) + require.NoError(t, err) + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + require.NoError(t, err) + + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, body, string(got)) + }) + + t.Run("with body, with redirects, no automatic GetBody, no retry", func(t *testing.T) { + statusStack = []int{http.StatusTemporaryRedirect, http.StatusTemporaryRedirect} + tr := NewTransport(nil, func(attempt Attempt) bool { return false }, ConstDelay(0)) + c := &http.Client{Transport: tr} + + body := "ok" + // LimitReader disables the automatic setting of GetBody by Go's stdlib (it + // only supports known readers that can be safely reconstructed, such as + // bytes and strings readers) + res, err := c.Post(srv.URL+"/test", "", io.LimitReader(strings.NewReader(body), 100)) + require.NoError(t, err) + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusTemporaryRedirect, res.StatusCode) + assert.Empty(t, string(got)) + }) + + t.Run("with body, with redirects, explicit GetBody, no retry", func(t *testing.T) { + statusStack = []int{http.StatusTemporaryRedirect, http.StatusTemporaryRedirect} + tr := NewTransport(nil, func(attempt Attempt) bool { return false }, ConstDelay(0)) + c := &http.Client{Transport: tr} + + body := "ok" + req, err := http.NewRequest("POST", srv.URL+"/test", io.LimitReader(strings.NewReader(body), 100)) + require.NoError(t, err) + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(io.LimitReader(strings.NewReader(body), 100)), nil + } + + res, err := c.Do(req) + require.NoError(t, err) + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Equal(t, body, string(got)) + }) + + t.Run("with body, with redirects and retries, explicit GetBody", func(t *testing.T) { + statusStack = []int{http.StatusInternalServerError, http.StatusTemporaryRedirect, http.StatusInternalServerError, http.StatusTemporaryRedirect} + tr := NewTransport(nil, RetryStatuses(http.StatusInternalServerError), ConstDelay(0)) + c := &http.Client{Transport: tr} + + body := "ok" + req, err := http.NewRequest("POST", srv.URL+"/test", io.LimitReader(strings.NewReader(body), 100)) + require.NoError(t, err) + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(io.LimitReader(strings.NewReader(body), 100)), nil + } + + res, err := c.Do(req) + require.NoError(t, err) + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Equal(t, body, string(got)) + }) +}