From 5f575725760975060d99862ae6fceb70dc4868c7 Mon Sep 17 00:00:00 2001 From: "yangrui.emma" Date: Mon, 22 May 2023 11:40:35 +0800 Subject: [PATCH] test: add mock panic test for client-side --- thriftrpc/failedcall/failedcall_test.go | 32 +++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/thriftrpc/failedcall/failedcall_test.go b/thriftrpc/failedcall/failedcall_test.go index 608e762..84397ea 100644 --- a/thriftrpc/failedcall/failedcall_test.go +++ b/thriftrpc/failedcall/failedcall_test.go @@ -21,6 +21,8 @@ import ( "time" "github.com/apache/thrift/lib/go/thrift" + "github.com/cloudwego/kitex/client" + "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex-tests/kitex_gen/thrift/stability/stservice" "github.com/cloudwego/kitex-tests/pkg/test" @@ -59,13 +61,13 @@ func TestMain(m *testing.M) { svr.Stop() } -func getKitexClient(p transport.Protocol) stservice.Client { +func getKitexClient(p transport.Protocol, opts ...client.Option) stservice.Client { return thriftrpc.CreateKitexClient(&thriftrpc.ClientInitParam{ TargetServiceName: "cloudwego.kitex.testa", HostPorts: []string{":9001"}, Protocol: p, ConnMode: thriftrpc.LongConnection, - }) + }, opts...) } // TestSTReq method mock STRequest param read failed in server @@ -141,3 +143,29 @@ func TestVisitOneway(t *testing.T) { err = cli.VisitOneway(ctx, stReq) test.Assert(t, err == nil, err) } + +// TestClientMWPanic mock panic in customized mw +func TestClientMWPanic(t *testing.T) { + cli = getKitexClient(transport.Framed, client.WithMiddleware(panicMW)) + + // case1: panic without timeout + ctx, objReq := thriftrpc.CreateObjReq(context.Background()) + objResp, err := cli.TestObjReq(ctx, objReq) + test.Assert(t, err != nil) + test.Assert(t, strings.Contains(err.Error(), panicMsg), err.Error()) + + // case2: panic with timeout + ctx, objReq = thriftrpc.CreateObjReq(context.Background()) + objResp, err = cli.TestObjReq(ctx, objReq, callopt.WithRPCTimeout(time.Second)) + test.Assert(t, err != nil) + test.Assert(t, objResp == nil) + test.Assert(t, strings.Contains(err.Error(), panicMsg), err.Error()) +} + +func panicMW(endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request, response interface{}) (err error) { + panic(panicMsg) + } +} + +const panicMsg = "mock panic"