diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index b02948e78..b131f4b6e 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -7,6 +7,7 @@ import ( "log/slog" "net/http" "strconv" + "strings" "sync" "sync/atomic" "time" @@ -15,6 +16,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/reddit/baseplate.go/breakerbp" + "github.com/reddit/baseplate.go/internal/faults" //lint:ignore SA1019 This library is internal only, not actually deprecated "github.com/reddit/baseplate.go/internalv2compat" "github.com/reddit/baseplate.go/retrybp" @@ -43,6 +45,8 @@ func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { // plus any additional client middleware passed into this function. Default // middlewares are: // +// * FaultInjection +// // * MonitorClient with transport.WithRetrySlugSuffix // // * PrometheusClientMetrics with transport.WithRetrySlugSuffix @@ -76,6 +80,7 @@ func NewClient(config ClientConfig, middleware ...ClientMiddleware) (*http.Clien } defaults := []ClientMiddleware{ + FaultInjection(), MonitorClient(config.Slug + transport.WithRetrySlugSuffix), PrometheusClientMetrics(config.Slug + transport.WithRetrySlugSuffix), Retries(config.MaxErrorReadAhead, config.RetryOptions...), @@ -349,3 +354,44 @@ func PrometheusClientMetrics(serverSlug string) ClientMiddleware { }) } } + +func FaultInjection() ClientMiddleware { + return func(next http.RoundTripper) http.RoundTripper { + return roundTripperFunc(func(req *http.Request) (*http.Response, error) { + resumeFn := func() (*http.Response, error) { + return next.RoundTrip(req) + } + responseFn := func(code int, message string) (*http.Response, error) { + return &http.Response{ + Status: http.StatusText(code), + StatusCode: code, + Proto: req.Proto, + ProtoMajor: req.ProtoMajor, + ProtoMinor: req.ProtoMinor, + Header: map[string][]string{ + // Copied from the standard http.Error() function. 
+ "Content-Type": {"text/plain; charset=utf-8"}, + "X-Content-Type-Options": {"nosniff"}, + }, + ContentLength: 0, + TransferEncoding: req.TransferEncoding, + Request: req, + TLS: req.TLS, + }, nil + } + + resp, err := faults.InjectFault(faults.InjectFaultParams[*http.Response]{ + Context: req.Context(), + CallerName: "httpbp.FaultInjection", + Address: req.URL.Hostname(), + Method: strings.TrimPrefix(req.URL.Path, "/"), + AbortCodeMin: 400, + AbortCodeMax: 599, + GetHeaderFn: faults.GetHeaderFn(req.Header.Get), + ResumeFn: resumeFn, + ResponseFn: responseFn, + }) + return resp, err + }) + } +} diff --git a/httpbp/client_middlewares_test.go b/httpbp/client_middlewares_test.go index 735d2fe9f..f0c213806 100644 --- a/httpbp/client_middlewares_test.go +++ b/httpbp/client_middlewares_test.go @@ -16,6 +16,7 @@ import ( "github.com/sony/gobreaker" "github.com/reddit/baseplate.go/breakerbp" + "github.com/reddit/baseplate.go/internal/faults" ) func TestNewClient(t *testing.T) { @@ -395,3 +396,140 @@ func TestCircuitBreaker(t *testing.T) { t.Errorf("Expected the third request to return %v, got %v", gobreaker.ErrOpenState, err) } } + +func TestFaultInjection(t *testing.T) { + testCases := []struct { + name string + faultServerAddrMatch bool + faultServerMethodHeader string + faultDelayMsHeader string + faultDelayPercentageHeader string + faultAbortCodeHeader string + faultAbortMessageHeader string + faultAbortPercentageHeader string + + wantResp *http.Response + }{ + { + name: "no fault specified", + wantResp: &http.Response{ + StatusCode: http.StatusOK, + }, + }, + { + name: "abort", + + faultServerAddrMatch: true, + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "500", + + wantResp: &http.Response{ + StatusCode: http.StatusInternalServerError, + }, + }, + { + name: "service does not match", + + faultServerAddrMatch: false, + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "500", + + wantResp: &http.Response{ + StatusCode: http.StatusOK, + }, 
+ }, + { + name: "method does not match", + + faultServerAddrMatch: true, + faultServerMethodHeader: "fooMethod", + faultAbortCodeHeader: "500", + + wantResp: &http.Response{ + StatusCode: http.StatusOK, + }, + }, + { + name: "less than min abort code", + + faultServerAddrMatch: true, + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "99", + + wantResp: &http.Response{ + StatusCode: http.StatusOK, + }, + }, + { + name: "greater than max abort code", + + faultServerAddrMatch: true, + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "600", + + wantResp: &http.Response{ + StatusCode: http.StatusOK, + }, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "Success!") + })) + defer server.Close() + + client, err := NewClient(ClientConfig{ + Slug: "test", + }) + if err != nil { + t.Fatalf("NewClient returned error: %v", err) + } + + req, err := http.NewRequest("GET", server.URL+"/testMethod", nil) + if err != nil { + t.Fatalf("unexpected error when creating request: %v", err) + } + + if tt.faultServerAddrMatch { + // We can't set a specific address here because the middleware + // relies on the DNS address, which is not customizable when making + // real requests to a local HTTP test server. 
+ parsed, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("unexpected error when parsing httptest server URL: %v", err) + } + req.Header.Set(faults.FaultServerAddressHeader, parsed.Hostname()) + } + if tt.faultServerMethodHeader != "" { + req.Header.Set(faults.FaultServerMethodHeader, tt.faultServerMethodHeader) + } + if tt.faultDelayMsHeader != "" { + req.Header.Set(faults.FaultDelayMsHeader, tt.faultDelayMsHeader) + } + if tt.faultDelayPercentageHeader != "" { + req.Header.Set(faults.FaultDelayPercentageHeader, tt.faultDelayPercentageHeader) + } + if tt.faultAbortCodeHeader != "" { + req.Header.Set(faults.FaultAbortCodeHeader, tt.faultAbortCodeHeader) + } + if tt.faultAbortMessageHeader != "" { + req.Header.Set(faults.FaultAbortMessageHeader, tt.faultAbortMessageHeader) + } + if tt.faultAbortPercentageHeader != "" { + req.Header.Set(faults.FaultAbortPercentageHeader, tt.faultAbortPercentageHeader) + } + + resp, err := client.Do(req) + + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if tt.wantResp.StatusCode != resp.StatusCode { + t.Fatalf("expected response code %v, got %v", tt.wantResp.StatusCode, resp.StatusCode) + } + }) + } +} diff --git a/internal/faults/common.go b/internal/faults/common.go new file mode 100644 index 000000000..08d0d3b08 --- /dev/null +++ b/internal/faults/common.go @@ -0,0 +1,162 @@ +// Package faults provides common headers and client-side fault injection +// functionality. +package faults + +import ( + "context" + "fmt" + "log/slog" + "math/rand/v2" + "strconv" + "strings" + "time" +) + +// GetHeaderFn is the function type to return the value of a protocol-specific +// header with the given key. +type GetHeaderFn func(key string) string + +// ResumeFn is the function type to continue processing the protocol-specific +// request without injecting a fault. +type ResumeFn[T any] func() (T, error) + +// ResponseFn is the function type to inject a protocol-specific fault with the +// given code and message. 
+type ResponseFn[T any] func(code int, message string) (T, error) + +// sleepFn is the function type to sleep for the given duration. Only used in +// tests. +type sleepFn func(ctx context.Context, d time.Duration) error + +// The canonical address for a cluster-local address is <service>.<namespace>, +// without the local cluster suffix or port. The canonical address for a +// non-cluster-local address is the full original address without the port. +func getCanonicalAddress(serverAddress string) string { + // Cluster-local address. + if i := strings.Index(serverAddress, ".svc.cluster.local"); i != -1 { + return serverAddress[:i] + } + // External host:port address. + if i := strings.LastIndex(serverAddress, ":"); i != -1 { + port := serverAddress[i+1:] + // Verify this is actually a port number. + if port != "" && port[0] >= '0' && port[0] <= '9' { + return serverAddress[:i] + } + } + // Other address, i.e. unix domain socket. + return serverAddress +} + +func parsePercentage(percentage string) (int, error) { + if percentage == "" { + return 100, nil + } + intPercentage, err := strconv.Atoi(percentage) + if err != nil { + return 0, fmt.Errorf("provided percentage %q is not a valid integer: %w", percentage, err) + } + if intPercentage < 0 || intPercentage > 100 { + return 0, fmt.Errorf("provided percentage \"%d\" is outside the valid range of [0-100]", intPercentage) + } + return intPercentage, nil +} + +func selected(randInt *int, percentage int) bool { + if randInt != nil { + return *randInt < percentage + } + // Use a different random integer per feature as per + // https://github.com/grpc/proposal/blob/master/A33-Fault-Injection.md#evaluate-possibility-fraction.
+ return rand.IntN(100) < percentage +} + +func sleep(ctx context.Context, d time.Duration) error { + t := time.NewTimer(d) + select { + case <-t.C: + case <-ctx.Done(): + t.Stop() + return ctx.Err() + } + return nil +} + +type InjectFaultParams[T any] struct { + Context context.Context + CallerName string + + Address, Method string + AbortCodeMin, AbortCodeMax int + + GetHeaderFn GetHeaderFn + ResumeFn ResumeFn[T] + ResponseFn ResponseFn[T] + + randInt *int + sleepFn *sleepFn +} + +func InjectFault[T any](params InjectFaultParams[T]) (T, error) { + faultHeaderAddress := params.GetHeaderFn(FaultServerAddressHeader) + requestAddress := getCanonicalAddress(params.Address) + if faultHeaderAddress == "" || faultHeaderAddress != requestAddress { + return params.ResumeFn() + } + + serverMethod := params.GetHeaderFn(FaultServerMethodHeader) + if serverMethod != "" && serverMethod != params.Method { + return params.ResumeFn() + } + + delayMs := params.GetHeaderFn(FaultDelayMsHeader) + if delayMs != "" { + percentage, err := parsePercentage(params.GetHeaderFn(FaultDelayPercentageHeader)) + if err != nil { + slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err)) + return params.ResumeFn() + } + + if selected(params.randInt, percentage) { + delay, err := strconv.Atoi(delayMs) + if err != nil { + slog.Warn(fmt.Sprintf("%s: provided delay %q is not a valid integer", params.CallerName, delayMs)) + return params.ResumeFn() + } + + sleepFn := sleep + if params.sleepFn != nil { + sleepFn = *params.sleepFn + } + if err := sleepFn(params.Context, time.Duration(delay)*time.Millisecond); err != nil { + slog.Warn(fmt.Sprintf("%s: error when delaying request: %v", params.CallerName, err)) + return params.ResumeFn() + } + } + } + + abortCode := params.GetHeaderFn(FaultAbortCodeHeader) + if abortCode != "" { + percentage, err := parsePercentage(params.GetHeaderFn(FaultAbortPercentageHeader)) + if err != nil { + slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err)) + return 
params.ResumeFn() + } + + if selected(params.randInt, percentage) { + code, err := strconv.Atoi(abortCode) + if err != nil { + slog.Warn(fmt.Sprintf("%s: provided abort code %q is not a valid integer", params.CallerName, abortCode)) + return params.ResumeFn() + } + if code < params.AbortCodeMin || code > params.AbortCodeMax { + slog.Warn(fmt.Sprintf("%s: provided abort code \"%d\" is outside of the valid range", params.CallerName, code)) + return params.ResumeFn() + } + abortMessage := params.GetHeaderFn(FaultAbortMessageHeader) + return params.ResponseFn(code, abortMessage) + } + } + + return params.ResumeFn() +} diff --git a/internal/faults/common_test.go b/internal/faults/common_test.go new file mode 100644 index 000000000..dbbe4d8a2 --- /dev/null +++ b/internal/faults/common_test.go @@ -0,0 +1,467 @@ +package faults + +import ( + "context" + "fmt" + "strings" + "testing" + "time" +) + +const ( + defaultAddress = "testService.testNamespace.svc.cluster.local:12345" + method = "testMethod" + minAbortCode = 0 + maxAbortCode = 10 +) + +func TestGetCanonicalAddress(t *testing.T) { + testCases := []struct { + name string + address string + want string + }{ + { + name: "cluster local address", + address: "testService.testNamespace.svc.cluster.local:12345", + want: "testService.testNamespace", + }, + { + name: "external address port stripped", + address: "foo.bar:12345", + want: "foo.bar", + }, + { + name: "unexpected address path stripped", + address: "foo.bar:12345/path", + want: "foo.bar", + }, + { + name: "unexpected trailing colon untouched", + address: "foo.bar:", + want: "foo.bar:", + }, + { + name: "external address without port untouched", + address: "unix://foo", + want: "unix://foo", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := getCanonicalAddress(tc.address) + if got != tc.want { + t.Fatalf("expected %q, got %q", tc.want, got) + } + }) + } +} + +func TestParsePercentage(t *testing.T) { + testCases := []struct { + 
name string + percentage string + want int + wantErr string + }{ + { + name: "empty", + percentage: "", + want: 100, + }, + { + name: "valid", + percentage: "50", + want: 50, + }, + { + name: "NaN", + percentage: "NaN", + want: 0, + wantErr: "not a valid integer", + }, + { + name: "under min", + percentage: "-1", + want: 0, + wantErr: "outside the valid range", + }, + { + name: "over max", + percentage: "101", + want: 0, + wantErr: "outside the valid range", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := parsePercentage(tc.percentage) + if got != tc.want { + t.Fatalf("expected %v, got %v", tc.want, got) + } + if tc.wantErr == "" && err != nil { + t.Fatalf("expected no error, got %v", err) + } + if tc.wantErr != "" && !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("expected error to contain %q, got %v", tc.wantErr, err) + } + }) + } +} + +type Response struct { + code int + message string +} + +func intPtr(i int) *int { + return &i +} + +func TestInjectFault(t *testing.T) { + testCases := []struct { + name string + address string + randInt *int + sleepErr bool + + faultServerAddressHeader string + faultServerMethodHeader string + faultDelayMsHeader string + faultDelayPercentageHeader string + faultAbortCodeHeader string + faultAbortMessageHeader string + faultAbortPercentageHeader string + + wantDelayMs int + wantResponse *Response + }{ + { + name: "no fault specified", + wantResponse: nil, + }, + { + name: "delay", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "1", + + wantDelayMs: 1, + }, + { + name: "abort", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + + wantResponse: &Response{ + code: 1, + message: "test fault", + }, + }, + { + name: "invalid server address", + address: "foo", + + 
faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + + wantResponse: nil, + }, + { + name: "server address does not match", + + faultServerAddressHeader: "fooService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + + wantResponse: nil, + }, + { + name: "method does not match", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "fooMethod", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + + wantResponse: nil, + }, + { + name: "guaranteed percent", + randInt: intPtr(99), // Maximum possible integer returned by rand.Intn(100) + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "250", + faultDelayPercentageHeader: "100", // All requests delayed + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "100", // All requests aborted + + wantDelayMs: 250, + wantResponse: &Response{ + code: 1, + message: "test fault", + }, + }, + { + name: "fence post below percent", + randInt: intPtr(49), + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "250", + faultDelayPercentageHeader: "50", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "50", + + wantDelayMs: 250, + wantResponse: &Response{ + code: 1, + message: "test fault", + }, + }, + { + name: "fence post at percent", + randInt: intPtr(50), + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "250", + faultDelayPercentageHeader: "50", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "50", + + wantDelayMs: 0, + wantResponse: nil, 
+ }, + { + name: "guaranteed skip percent", + randInt: intPtr(0), // Minimum possible integer returned by rand.Intn(100) + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "250", + faultDelayPercentageHeader: "0", // No requests delayed + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "0", // No requests aborted + + wantDelayMs: 0, + wantResponse: nil, + }, + { + name: "only skip delay", + randInt: intPtr(50), + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "250", + faultDelayPercentageHeader: "0", // No requests delayed + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "100", // All requests aborted + + wantDelayMs: 0, + wantResponse: &Response{ + code: 1, + message: "test fault", + }, + }, + { + name: "invalid delay percentage negative", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "250", + faultDelayPercentageHeader: "-1", + + wantDelayMs: 0, + }, + { + name: "invalid delay percentage over 100", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "250", + faultDelayPercentageHeader: "101", + + wantDelayMs: 0, + }, + { + name: "invalid delay ms", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "NaN", + + wantDelayMs: 0, + }, + { + name: "error while sleeping short circuits", + sleepErr: true, + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "1", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + + wantDelayMs: 0, + }, + { + name: "invalid abort percentage negative", + + faultServerAddressHeader: 
"testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "-1", + + wantResponse: nil, + }, + { + name: "invalid abort percentage over 100", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "101", + + wantResponse: nil, + }, + { + name: "invalid abort code", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "NaN", + faultAbortMessageHeader: "test fault", + + wantResponse: nil, + }, + { + name: "less than min abort code", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "-1", + faultAbortMessageHeader: "test fault", + + wantResponse: nil, + }, + { + name: "greater than max abort code", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "11", + faultAbortMessageHeader: "test fault", + + wantResponse: nil, + }, + { + name: "invalid abort percentage", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "NaN", + + wantResponse: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + address := tc.address + if address == "" { + address = defaultAddress + } + + getHeaderFn := GetHeaderFn(func(key string) string { + if key == FaultServerAddressHeader { + return tc.faultServerAddressHeader + } + if key == FaultServerMethodHeader { + return tc.faultServerMethodHeader + } + if key == FaultDelayMsHeader { + return tc.faultDelayMsHeader + } + if key == FaultDelayPercentageHeader { + return tc.faultDelayPercentageHeader + } + 
if key == FaultAbortCodeHeader { + return tc.faultAbortCodeHeader + } + if key == FaultAbortMessageHeader { + return tc.faultAbortMessageHeader + } + if key == FaultAbortPercentageHeader { + return tc.faultAbortPercentageHeader + } + return "" + }) + var resumeFn ResumeFn[*Response] = func() (*Response, error) { + return nil, nil + } + var responseFn ResponseFn[*Response] = func(code int, message string) (*Response, error) { + return &Response{ + code: code, + message: message, + }, nil + } + delayMs := 0 + sleepFn := sleepFn(func(ctx context.Context, d time.Duration) error { + if tc.sleepErr { + return fmt.Errorf("context cancelled") + } + delayMs = int(d.Milliseconds()) + return nil + }) + + resp, err := InjectFault(InjectFaultParams[*Response]{ + CallerName: "faults_test.TestInjectFault", + Address: address, + Method: method, + AbortCodeMin: minAbortCode, + AbortCodeMax: maxAbortCode, + GetHeaderFn: getHeaderFn, + ResumeFn: resumeFn, + ResponseFn: responseFn, + sleepFn: &sleepFn, + randInt: tc.randInt, + }) + + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if tc.wantDelayMs != delayMs { + t.Fatalf("expected delay of %v ms, got %v ms", tc.wantDelayMs, delayMs) + } + if tc.wantResponse == nil && resp != nil { + t.Fatalf("expected no response, got %v", resp) + } + if tc.wantResponse != nil && resp == nil { + t.Fatalf("expected response %v, got nil", tc.wantResponse) + } + if resp != nil && *tc.wantResponse != *resp { + t.Fatalf("expected response %v, got %v", tc.wantResponse, resp) + } + }) + } +} diff --git a/internal/faults/headers.go b/internal/faults/headers.go new file mode 100644 index 000000000..8302e97e1 --- /dev/null +++ b/internal/faults/headers.go @@ -0,0 +1,11 @@ +package faults + +const ( + FaultServerAddressHeader = "X-Bp-Fault-Server-Address" + FaultServerMethodHeader = "X-Bp-Fault-Server-Method" + FaultDelayMsHeader = "X-Bp-Fault-Delay-Ms" + FaultDelayPercentageHeader = "X-Bp-Fault-Delay-Percentage" + FaultAbortCodeHeader = 
"X-Bp-Fault-Abort-Code" + FaultAbortMessageHeader = "X-Bp-Fault-Abort-Message" + FaultAbortPercentageHeader = "X-Bp-Fault-Abort-Percentage" +) diff --git a/thriftbp/client_middlewares.go b/thriftbp/client_middlewares.go index 00c33016f..4f5014edf 100644 --- a/thriftbp/client_middlewares.go +++ b/thriftbp/client_middlewares.go @@ -15,6 +15,7 @@ import ( "github.com/reddit/baseplate.go/breakerbp" "github.com/reddit/baseplate.go/ecinterface" "github.com/reddit/baseplate.go/errorsbp" + "github.com/reddit/baseplate.go/internal/faults" "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate" "github.com/reddit/baseplate.go/internal/thriftint" //lint:ignore SA1019 This library is internal only, not actually deprecated @@ -66,6 +67,11 @@ type DefaultClientMiddlewareArgs struct { // ImageUploadService -> image-upload ServiceSlug string + // Address is the DNS address of the thrift service you are creating clients for. + // + // If not provided, the client will be unable to use the fault injection middleware. + Address string + // RetryOptions is the list of retry.Options to apply as the defaults for the // Retry middleware. // @@ -105,38 +111,42 @@ type DefaultClientMiddlewareArgs struct { // // Currently they are (in order): // -// 1. ForwardEdgeRequestContext. +// 1. FaultInjectionClientMiddleware - This injects faults at the client side if +// the request matches the provided configuration. +// +// 2. ForwardEdgeRequestContext // -// 2. SetClientName(clientName) +// 3. SetClientName(clientName) // -// 3. MonitorClient with MonitorClientWrappedSlugSuffix - This creates the spans +// 4. MonitorClient with MonitorClientWrappedSlugSuffix - This creates the spans // from the view of the client that group all retries into a single, // wrapped span. // -// 4. PrometheusClientMiddleware with MonitorClientWrappedSlugSuffix - This +// 5. 
PrometheusClientMiddleware with MonitorClientWrappedSlugSuffix - This // creates the prometheus client metrics from the view of the client that group // all retries into a single operation. // -// 5. Retry(retryOptions) - If retryOptions is empty/nil, default to only +// 6. Retry(retryOptions) - If retryOptions is empty/nil, default to only // retry.Attempts(1), this will not actually retry any calls but your client is // configured to set retry logic per-call using retrybp.WithOptions. // -// 6. FailureRatioBreaker - Only if BreakerConfig is non-nil. +// 7. FailureRatioBreaker - Only if BreakerConfig is non-nil. // -// 7. MonitorClient - This creates the spans of the raw client calls. +// 8. MonitorClient - This creates the spans of the raw client calls. // -// 8. PrometheusClientMiddleware +// 9. PrometheusClientMiddleware // -// 9. BaseplateErrorWrapper +// 10. BaseplateErrorWrapper // -// 10. thrift.ExtractIDLExceptionClientMiddleware +// 11. thrift.ExtractIDLExceptionClientMiddleware // -// 11. SetDeadlineBudget +// 12. 
SetDeadlineBudget func BaseplateDefaultClientMiddlewares(args DefaultClientMiddlewareArgs) []thrift.ClientMiddleware { if len(args.RetryOptions) == 0 { args.RetryOptions = []retry.Option{retry.Attempts(1)} } middlewares := []thrift.ClientMiddleware{ + FaultInjectionClientMiddleware(args.Address), ForwardEdgeRequestContext(args.EdgeContextImpl), SetClientName(args.ClientName), MonitorClient(MonitorClientArgs{ @@ -390,6 +400,45 @@ func PrometheusClientMiddleware(remoteServerSlug string) thrift.ClientMiddleware } } +func FaultInjectionClientMiddleware(address string) thrift.ClientMiddleware { + return func(next thrift.TClient) thrift.TClient { + return thrift.WrappedTClient{ + Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (thrift.ResponseMeta, error) { + if address == "" { + return next.Call(ctx, method, args, result) + } + + getHeaderFn := func(key string) string { + header, ok := thrift.GetHeader(ctx, key) + if !ok { + return "" + } + return header + } + resumeFn := func() (thrift.ResponseMeta, error) { + return next.Call(ctx, method, args, result) + } + responseFn := func(code int, message string) (thrift.ResponseMeta, error) { + return thrift.ResponseMeta{}, thrift.NewTTransportException(code, message) + } + + resp, err := faults.InjectFault(faults.InjectFaultParams[thrift.ResponseMeta]{ + Context: ctx, + CallerName: "thriftpb.FaultInjectionClientMiddleware", + Address: address, + Method: method, + AbortCodeMin: thrift.UNKNOWN_TRANSPORT_EXCEPTION, + AbortCodeMax: thrift.END_OF_FILE, + GetHeaderFn: getHeaderFn, + ResumeFn: resumeFn, + ResponseFn: responseFn, + }) + return resp, err + }, + } + } +} + func getClientError(result thrift.TStruct, err error) error { if err != nil { return err diff --git a/thriftbp/client_middlewares_test.go b/thriftbp/client_middlewares_test.go index 7d3f24c0e..6a5933856 100644 --- a/thriftbp/client_middlewares_test.go +++ b/thriftbp/client_middlewares_test.go @@ -12,6 +12,7 @@ import ( 
"github.com/reddit/baseplate.go" "github.com/reddit/baseplate.go/ecinterface" + "github.com/reddit/baseplate.go/internal/faults" baseplatethrift "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate" "github.com/reddit/baseplate.go/internal/prometheusbpint/spectest" "github.com/reddit/baseplate.go/prometheusbp" @@ -23,6 +24,7 @@ import ( const ( service = "testService" + address = "testService.testNamespace.svc.cluster.local:12345" method = "testMethod" ) @@ -38,6 +40,7 @@ func initClients(ecImpl ecinterface.Interface) (*thrifttest.MockClient, *thriftt thriftbp.DefaultClientMiddlewareArgs{ EdgeContextImpl: ecImpl, ServiceSlug: service, + Address: address, }, )..., ) @@ -395,6 +398,126 @@ func TestPrometheusClientMiddleware(t *testing.T) { } } +func TestFaultInjectionClientMiddleware(t *testing.T) { + testCases := []struct { + name string + + faultServerAddrHeader string + faultServerMethodHeader string + faultDelayMsHeader string + faultDelayPercentageHeader string + faultAbortCodeHeader string + faultAbortMessageHeader string + faultAbortPercentageHeader string + + wantErr error + }{ + { + name: "no fault specified", + wantErr: nil, + }, + { + name: "abort", + + faultServerAddrHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", // NOT_OPEN + faultAbortMessageHeader: "test fault", + + wantErr: thrift.NewTTransportException(1, "test fault"), + }, + { + name: "service does not match", + + faultServerAddrHeader: "fooService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", // NOT_OPEN + faultAbortMessageHeader: "test fault", + + wantErr: nil, + }, + { + name: "method does not match", + + faultServerAddrHeader: "testService.testNamespace", + faultServerMethodHeader: "fooMethod", + faultAbortCodeHeader: "1", // NOT_OPEN + faultAbortMessageHeader: "test fault", + + wantErr: nil, + }, + { + name: "less than min abort code", + + faultServerAddrHeader: 
"testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "-1", + faultAbortMessageHeader: "test fault", + + wantErr: nil, + }, + { + name: "greater than max abort code", + + faultServerAddrHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "5", + faultAbortMessageHeader: "test fault", + + wantErr: nil, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + + impl := ecinterface.Mock() + ctx := context.Background() + + if tt.faultServerAddrHeader != "" { + ctx = thriftbp.AddClientHeader(ctx, faults.FaultServerAddressHeader, tt.faultServerAddrHeader) + } + if tt.faultServerMethodHeader != "" { + ctx = thriftbp.AddClientHeader(ctx, faults.FaultServerMethodHeader, tt.faultServerMethodHeader) + } + if tt.faultDelayMsHeader != "" { + ctx = thriftbp.AddClientHeader(ctx, faults.FaultDelayMsHeader, tt.faultDelayMsHeader) + } + if tt.faultDelayPercentageHeader != "" { + ctx = thriftbp.AddClientHeader(ctx, faults.FaultDelayPercentageHeader, tt.faultDelayPercentageHeader) + } + if tt.faultAbortCodeHeader != "" { + ctx = thriftbp.AddClientHeader(ctx, faults.FaultAbortCodeHeader, tt.faultAbortCodeHeader) + } + if tt.faultAbortMessageHeader != "" { + ctx = thriftbp.AddClientHeader(ctx, faults.FaultAbortMessageHeader, tt.faultAbortMessageHeader) + } + if tt.faultAbortPercentageHeader != "" { + ctx = thriftbp.AddClientHeader(ctx, faults.FaultAbortPercentageHeader, tt.faultAbortPercentageHeader) + } + + mock, _, client := initClients(impl) + mock.AddMockCall( + method, + func(ctx context.Context, args, result thrift.TStruct) (meta thrift.ResponseMeta, err error) { + return + }, + ) + + _, err := client.Call(ctx, method, nil, nil) + if tt.wantErr == nil && err != nil { + t.Fatalf("expected no error, got %v", err) + } + if tt.wantErr != nil && err == nil { + t.Fatal("expected an error, got nil") + } + if err != nil && err.Error() != tt.wantErr.Error() { + 
t.Fatalf("expected error %v, got %v", tt.wantErr, err) + } + }) + } +} + type mockBaseplateService struct { fail bool err error diff --git a/thriftbp/client_pool.go b/thriftbp/client_pool.go index 6a84ab671..547b757a0 100644 --- a/thriftbp/client_pool.go +++ b/thriftbp/client_pool.go @@ -401,6 +401,7 @@ func NewBaseplateClientPoolWithContext(ctx context.Context, cfg ClientPoolConfig } defaults := BaseplateDefaultClientMiddlewares( DefaultClientMiddlewareArgs{ + Address: cfg.Addr, EdgeContextImpl: cfg.EdgeContextImpl, ServiceSlug: cfg.ServiceSlug, RetryOptions: cfg.DefaultRetryOptions,