diff --git a/thrift_streaming/thrift_handler.go b/thrift_streaming/thrift_handler.go index ebb66d6..5f8d1f8 100644 --- a/thrift_streaming/thrift_handler.go +++ b/thrift_streaming/thrift_handler.go @@ -17,6 +17,7 @@ package thrift_streaming import ( "context" "errors" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "strconv" "github.com/cloudwego/kitex/pkg/klog" @@ -55,19 +56,39 @@ func (e EchoServiceImpl) EchoBidirectional(stream echo.EchoService_EchoBidirecti func (e EchoServiceImpl) EchoClient(stream echo.EchoService_EchoClientServer) (err error) { klog.Infof("EchoClient: start") - count := GetInt(stream.Context(), KeyCount, 0) - var req *echo.EchoRequest - for i := 0; i < count; i++ { - req, err = stream.Recv() + resp := &echo.EchoResponse{} + doGetServerConn := GetBool(stream.Context(), KeyGetServerConn, false) + doInspectMWCtx := GetBool(stream.Context(), KeyInspectMWCtx, false) + switch { + case doGetServerConn: + _, err = nphttp2.GetServerConn(stream) if err != nil { - klog.Infof("EchoClient: recv error = %v", err) + klog.Infof("EchoClient: GetServerConn failed, error = %v", err) return } - klog.Infof("EchoClient: recv req = %v", req) - } - resp := &echo.EchoResponse{ - Message: strconv.Itoa(count), + resp.Message = "GetServerConn Succeeded" + case doInspectMWCtx: + val, ok := stream.Context().Value("key").(string) + if !ok || val != "val" { + err = errors.New("can not get ctx value set in server MW") + klog.Infof("EchoClient: InspectMWCtx failed, error = %v", err) + return + } + resp.Message = "InspectMWCtx Succeeded" + default: + count := GetInt(stream.Context(), KeyCount, 0) + var req *echo.EchoRequest + for i := 0; i < count; i++ { + req, err = stream.Recv() + if err != nil { + klog.Infof("EchoClient: recv error = %v", err) + return + } + klog.Infof("EchoClient: recv req = %v", req) + } + resp.Message = strconv.Itoa(count) } + if err = stream.SendAndClose(resp); err != nil { klog.Infof("EchoClient: send&close error = %v", err) return diff --git a/thrift_streaming/thrift_test.go b/thrift_streaming/thrift_test.go index 8c42892..c9e7a71 100644 --- a/thrift_streaming/thrift_test.go +++ b/thrift_streaming/thrift_test.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "io" "reflect" "strconv" @@ -553,6 +554,59 @@ func TestKitexServerMiddleware(t *testing.T) { test.Assert(t, err == nil, err) test.Assert(t, resp.Message == "2_send_middleware", resp.Message) }) + + t.Run("gRPC GetServerConn", func(t *testing.T) { + svr := RunThriftServer(&EchoServiceImpl{}, addr, + server.WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, args, result interface{}) (err error) { + streamArg, ok := args.(*streaming.Args) + test.Assert(t, ok) + _, err = nphttp2.GetServerConn(streamArg.Stream) + test.Assert(t, err == nil, err) + return next(ctx, args, result) + } + }), + ) + defer svr.Stop() + + cli := echoservice.MustNewStreamClient("service", streamclient.WithHostPorts(addr)) + ctx := metainfo.WithValue(context.Background(), KeyGetServerConn, "true") + stream, err := cli.EchoClient(ctx) + test.Assert(t, err == nil, err) + + err = stream.Send(&echo.EchoRequest{Message: "GetServerConn"}) + test.Assert(t, err == nil, err) + + resp, err := stream.CloseAndRecv() + test.Assert(t, err == nil, err) + test.Assert(t, resp.Message == "GetServerConn Succeeded") + }) + + t.Run("process ctx in middleware and reflect to Stream.Context()", func(t *testing.T) { + svr := RunThriftServer(&EchoServiceImpl{}, addr, + server.WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, args, result interface{}) (err error) { + _, ok := args.(*streaming.Args) + test.Assert(t, ok) + ctx = context.WithValue(ctx, "key", "val") + return next(ctx, args, result) + } + }), + ) + defer svr.Stop() + + cli := echoservice.MustNewStreamClient("service", streamclient.WithHostPorts(addr)) + ctx := metainfo.WithValue(context.Background(), KeyInspectMWCtx, "true") + stream, err := cli.EchoClient(ctx) + test.Assert(t, err == nil, err) + + err = stream.Send(&echo.EchoRequest{Message: "InspectMWCtx"}) + test.Assert(t, err == nil, err) + + resp, err := stream.CloseAndRecv() + test.Assert(t, err == nil, err) + test.Assert(t, resp.Message == "InspectMWCtx Succeeded") + }) } func TestTimeoutRecvSend(t *testing.T) { diff --git a/thrift_streaming/util.go b/thrift_streaming/util.go index 96f9178..86cfff5 100644 --- a/thrift_streaming/util.go +++ b/thrift_streaming/util.go @@ -29,6 +29,8 @@ const ( KeyCount = "COUNT" KeyServerRecvTimeoutMS = "RECV_TIMEOUT_MS" KeyServerSendTimeoutMS = "SEND_TIMEOUT_MS" + KeyGetServerConn = "GET_SERVER_CONN" + KeyInspectMWCtx = "INSPECT_MW_CTX" ) func GetError(ctx context.Context) error { @@ -81,3 +83,10 @@ func GetInt(ctx context.Context, key string, defaultValue int) int { } return defaultValue } + +func GetBool(ctx context.Context, key string, defaultBool bool) bool { + if b, err := strconv.ParseBool(GetValue(ctx, key, "")); err == nil { + return b + } + return defaultBool +}