diff --git a/replications/remotewrite/writer.go b/replications/remotewrite/writer.go index 39a29b21be5..ae1832a29e3 100644 --- a/replications/remotewrite/writer.go +++ b/replications/remotewrite/writer.go @@ -14,6 +14,7 @@ import ( "github.com/influxdata/influx-cli/v2/api" "github.com/influxdata/influxdb/v2" + ihttp "github.com/influxdata/influxdb/v2/http" "github.com/influxdata/influxdb/v2/kit/platform" ierrors "github.com/influxdata/influxdb/v2/kit/platform/errors" "github.com/influxdata/influxdb/v2/replications/metrics" @@ -44,10 +45,11 @@ func invalidRemoteUrl(remoteUrl string, err error) *ierrors.Error { } } -func invalidResponseCode(code int) *ierrors.Error { +func invalidResponseCode(code int, err error) *ierrors.Error { return &ierrors.Error{ Code: ierrors.EInvalid, Msg: fmt.Sprintf("invalid response code %d, must be %d", code, http.StatusNoContent), + Err: err, } } @@ -245,7 +247,10 @@ func PostWrite(ctx context.Context, config *influxdb.ReplicationHTTPConfig, data // Only a response of 204 is valid for a successful write if res.StatusCode != http.StatusNoContent { - err = invalidResponseCode(res.StatusCode) + if err == nil { + err = ihttp.CheckError(res) + } + err = invalidResponseCode(res.StatusCode, err) } // Must return the response so that the status code and headers can be inspected by the caller, even if the response diff --git a/replications/remotewrite/writer_test.go b/replications/remotewrite/writer_test.go index e5362dc79fa..93c45d5a319 100644 --- a/replications/remotewrite/writer_test.go +++ b/replications/remotewrite/writer_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "strconv" + "strings" "testing" "time" @@ -16,6 +17,7 @@ import ( "github.com/influxdata/influxdb/v2/kit/platform" "github.com/influxdata/influxdb/v2/kit/prom" "github.com/influxdata/influxdb/v2/kit/prom/promtest" + ihttp "github.com/influxdata/influxdb/v2/kit/transport/http" "github.com/influxdata/influxdb/v2/replications/metrics" replicationsMock "github.com/influxdata/influxdb/v2/replications/mock" "github.com/stretchr/testify/require" @@ -61,6 +63,27 @@ func instaWait() waitFunc { } } +type containsMatcher struct { + substring string +} + +func (cm *containsMatcher) Matches(x interface{}) bool { + if st, ok := x.(fmt.Stringer); ok { + return strings.Contains(st.String(), cm.substring) + } else { + s, ok := x.(string) + return ok && strings.Contains(s, cm.substring) + } +} + +func (cm *containsMatcher) String() string { + if cm != nil { + return cm.substring + } else { + return "" + } +} + func TestWrite(t *testing.T) { t.Parallel() @@ -137,7 +160,7 @@ func TestWrite(t *testing.T) { w.waitFunc = instaWait() configStore.EXPECT().GetFullHTTPConfig(gomock.Any(), testID).Return(testConfig, nil) - configStore.EXPECT().UpdateResponseInfo(gomock.Any(), testID, status, invalidResponseCode(status).Error()).Return(nil) + configStore.EXPECT().UpdateResponseInfo(gomock.Any(), testID, status, &containsMatcher{invalidResponseCode(status, nil).Error()}).Return(nil) _, actualErr := w.Write(testData, testAttempts) require.NotNil(t, actualErr) require.Contains(t, actualErr.Error(), fmt.Sprintf("invalid response code %d", status)) @@ -165,7 +188,7 @@ func TestWrite(t *testing.T) { configStore.EXPECT().GetFullHTTPConfig(gomock.Any(), testID).Return(testConfig, nil).Times(testAttempts - 1) configStore.EXPECT().GetFullHTTPConfig(gomock.Any(), testID).Return(updatedConfig, nil) - configStore.EXPECT().UpdateResponseInfo(gomock.Any(), testID, http.StatusBadRequest, invalidResponseCode(http.StatusBadRequest).Error()).Return(nil).Times(testAttempts) + configStore.EXPECT().UpdateResponseInfo(gomock.Any(), testID, http.StatusBadRequest, &containsMatcher{invalidResponseCode(http.StatusBadRequest, nil).Error()}).Return(nil).Times(testAttempts) for i := 1; i <= testAttempts; i++ { _, actualErr := w.Write(testData, i) if testAttempts == i { @@ -190,7 +213,7 @@ func TestWrite(t *testing.T) { configStore.EXPECT().UpdateResponseInfo(gomock.Any(), testID, http.StatusBadRequest, gomock.Any()).Return(nil) backoff, actualErr := w.Write(testData, 1) require.Equal(t, backoff, w.backoff(1)) - require.Equal(t, invalidResponseCode(http.StatusBadRequest), actualErr) + require.ErrorContains(t, actualErr, invalidResponseCode(http.StatusBadRequest, nil).Error()) }) t.Run("uses wait time from response header if present", func(t *testing.T) { @@ -218,9 +241,9 @@ func TestWrite(t *testing.T) { } configStore.EXPECT().GetFullHTTPConfig(gomock.Any(), testID).Return(testConfig, nil) - configStore.EXPECT().UpdateResponseInfo(gomock.Any(), testID, http.StatusTooManyRequests, invalidResponseCode(http.StatusTooManyRequests).Error()).Return(nil) + configStore.EXPECT().UpdateResponseInfo(gomock.Any(), testID, http.StatusTooManyRequests, &containsMatcher{invalidResponseCode(http.StatusTooManyRequests, nil).Error()}).Return(nil) _, actualErr := w.Write(testData, 1) - require.Equal(t, invalidResponseCode(http.StatusTooManyRequests), actualErr) + require.ErrorContains(t, actualErr, invalidResponseCode(http.StatusTooManyRequests, nil).Error()) }) t.Run("can cancel with done channel", func(t *testing.T) { @@ -234,9 +257,9 @@ func TestWrite(t *testing.T) { w, configStore, _ := testWriter(t) configStore.EXPECT().GetFullHTTPConfig(gomock.Any(), testID).Return(testConfig, nil) - configStore.EXPECT().UpdateResponseInfo(gomock.Any(), testID, http.StatusInternalServerError, invalidResponseCode(http.StatusInternalServerError).Error()).Return(nil) + configStore.EXPECT().UpdateResponseInfo(gomock.Any(), testID, http.StatusInternalServerError, &containsMatcher{invalidResponseCode(http.StatusInternalServerError, nil).Error()}).Return(nil) _, actualErr := w.Write(testData, 1) - require.Equal(t, invalidResponseCode(http.StatusInternalServerError), actualErr) + require.ErrorContains(t, actualErr, invalidResponseCode(http.StatusInternalServerError, nil).Error()) }) t.Run("writes resume after temporary remote disconnect", func(t *testing.T) { @@ -288,7 +311,7 @@ func TestWrite(t *testing.T) { numAttempts = 0 } else { // should fail - configStore.EXPECT().UpdateResponseInfo(gomock.Any(), testID, http.StatusGatewayTimeout, invalidResponseCode(http.StatusGatewayTimeout).Error()).Return(nil) + configStore.EXPECT().UpdateResponseInfo(gomock.Any(), testID, http.StatusGatewayTimeout, &containsMatcher{invalidResponseCode(http.StatusGatewayTimeout, nil).Error()}).Return(nil) _, err := w.Write([]byte(testWrites[i]), numAttempts) require.Error(t, err) numAttempts++ @@ -312,11 +335,11 @@ func TestWrite_Metrics(t *testing.T) { { name: "server errors", status: constantStatus(http.StatusTeapot), - expectedErr: invalidResponseCode(http.StatusTeapot), + expectedErr: invalidResponseCode(http.StatusTeapot, nil), data: []byte{}, registerExpectations: func(t *testing.T, store *replicationsMock.MockHttpConfigStore, conf *influxdb.ReplicationHTTPConfig) { store.EXPECT().GetFullHTTPConfig(gomock.Any(), testID).Return(conf, nil) - store.EXPECT().UpdateResponseInfo(gomock.Any(), testID, http.StatusTeapot, invalidResponseCode(http.StatusTeapot).Error()).Return(nil) + store.EXPECT().UpdateResponseInfo(gomock.Any(), testID, http.StatusTeapot, &containsMatcher{invalidResponseCode(http.StatusTeapot, nil).Error()}).Return(nil) }, checkMetrics: func(t *testing.T, reg *prom.Registry) { mfs := promtest.MustGather(t, reg) @@ -351,7 +374,7 @@ func TestWrite_Metrics(t *testing.T) { data: testData, registerExpectations: func(t *testing.T, store *replicationsMock.MockHttpConfigStore, conf *influxdb.ReplicationHTTPConfig) { store.EXPECT().GetFullHTTPConfig(gomock.Any(), testID).Return(conf, nil) - store.EXPECT().UpdateResponseInfo(gomock.Any(), testID, http.StatusBadRequest, invalidResponseCode(http.StatusBadRequest).Error()).Return(nil) + store.EXPECT().UpdateResponseInfo(gomock.Any(), testID, http.StatusBadRequest, &containsMatcher{invalidResponseCode(http.StatusBadRequest, nil).Error()}).Return(nil) }, checkMetrics: func(t *testing.T, reg *prom.Registry) { mfs := promtest.MustGather(t, reg) @@ -382,7 +405,11 @@ func TestWrite_Metrics(t *testing.T) { tt.registerExpectations(t, configStore, testConfig) _, actualErr := w.Write(tt.data, 1) - require.Equal(t, tt.expectedErr, actualErr) + if tt.expectedErr != nil { + require.ErrorContains(t, actualErr, tt.expectedErr.Error()) + } else { + require.NoError(t, actualErr) + } tt.checkMetrics(t, reg) }) } @@ -393,6 +420,7 @@ func TestPostWrite(t *testing.T) { tests := []struct { status int + bodyErr error wantErr bool }{ { @@ -406,6 +434,12 @@ func TestPostWrite(t *testing.T) { { status: http.StatusBadRequest, wantErr: true, + bodyErr: fmt.Errorf("This is a terrible error: %w", errors.New("there are bad things here")), + }, + { + status: http.StatusMethodNotAllowed, + wantErr: true, + bodyErr: fmt.Errorf("method not allowed: %w", errors.New("what were you thinking")), }, } @@ -416,7 +450,12 @@ func TestPostWrite(t *testing.T) { require.NoError(t, err) require.Equal(t, testData, recData) - w.WriteHeader(tt.status) + if tt.bodyErr != nil { + influxErrorCode := ihttp.StatusCodeToErrorCode(tt.status) + ihttp.WriteErrorResponse(context.Background(), w, influxErrorCode, tt.bodyErr.Error()) + } else { + w.WriteHeader(tt.status) + } })) defer svr.Close() @@ -427,12 +466,16 @@ func TestPostWrite(t *testing.T) { res, err := PostWrite(context.Background(), config, testData, time.Second) if tt.wantErr { require.Error(t, err) - return + if nil != tt.bodyErr { + require.ErrorContains(t, err, tt.bodyErr.Error()) + } } else { require.Nil(t, err) } - require.Equal(t, tt.status, res.StatusCode) + if res != nil { + require.Equal(t, tt.status, res.StatusCode) + } }) } }