Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test(streaming): stream ctx diverge and nphttp2.GetServerConn #85

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions thrift_streaming/thrift_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package thrift_streaming
import (
"context"
"errors"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2"
"strconv"

"github.com/cloudwego/kitex/pkg/klog"
Expand Down Expand Up @@ -55,19 +56,39 @@ func (e EchoServiceImpl) EchoBidirectional(stream echo.EchoService_EchoBidirecti

func (e EchoServiceImpl) EchoClient(stream echo.EchoService_EchoClientServer) (err error) {
klog.Infof("EchoClient: start")
count := GetInt(stream.Context(), KeyCount, 0)
var req *echo.EchoRequest
for i := 0; i < count; i++ {
req, err = stream.Recv()
resp := &echo.EchoResponse{}
doGetServerConn := GetBool(stream.Context(), KeyGetServerConn, false)
doInspectMWCtx := GetBool(stream.Context(), KeyInspectMWCtx, false)
switch {
case doGetServerConn:
_, err = nphttp2.GetServerConn(stream)
if err != nil {
klog.Infof("EchoClient: recv error = %v", err)
klog.Infof("EchoClient: GetServerConn failed, error = %v", err)
return
}
klog.Infof("EchoClient: recv req = %v", req)
}
resp := &echo.EchoResponse{
Message: strconv.Itoa(count),
resp.Message = "GetServerConn Succeeded"
case doInspectMWCtx:
val, ok := stream.Context().Value("key").(string)
if !ok || val != "val" {
err = errors.New("can not get ctx value set in server MW")
klog.Infof("EchoClient: InspectMWCtx failed, error = %v", err)
return
}
resp.Message = "InspectMWCtx Succeeded"
default:
count := GetInt(stream.Context(), KeyCount, 0)
var req *echo.EchoRequest
for i := 0; i < count; i++ {
req, err = stream.Recv()
if err != nil {
klog.Infof("EchoClient: recv error = %v", err)
return
}
klog.Infof("EchoClient: recv req = %v", req)
}
resp.Message = strconv.Itoa(count)
}

if err = stream.SendAndClose(resp); err != nil {
klog.Infof("EchoClient: send&close error = %v", err)
return
Expand Down
54 changes: 54 additions & 0 deletions thrift_streaming/thrift_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"errors"
"fmt"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2"
"io"
"reflect"
"strconv"
Expand Down Expand Up @@ -553,6 +554,59 @@ func TestKitexServerMiddleware(t *testing.T) {
test.Assert(t, err == nil, err)
test.Assert(t, resp.Message == "2_send_middleware", resp.Message)
})

t.Run("gRPC GetServerConn", func(t *testing.T) {
svr := RunThriftServer(&EchoServiceImpl{}, addr,
server.WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, args, result interface{}) (err error) {
streamArg, ok := args.(*streaming.Args)
test.Assert(t, ok)
_, err = nphttp2.GetServerConn(streamArg.Stream)
test.Assert(t, err == nil, err)
return next(ctx, args, result)
}
}),
)
defer svr.Stop()

cli := echoservice.MustNewStreamClient("service", streamclient.WithHostPorts(addr))
ctx := metainfo.WithValue(context.Background(), KeyGetServerConn, "true")
stream, err := cli.EchoClient(ctx)
test.Assert(t, err == nil, err)

err = stream.Send(&echo.EchoRequest{Message: "GetServerConn"})
test.Assert(t, err == nil, err)

resp, err := stream.CloseAndRecv()
test.Assert(t, err == nil, err)
test.Assert(t, resp.Message == "GetServerConn Succeeded")
})

t.Run("process ctx in middleware and reflect to Stream.Context()", func(t *testing.T) {
svr := RunThriftServer(&EchoServiceImpl{}, addr,
server.WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, args, result interface{}) (err error) {
_, ok := args.(*streaming.Args)
test.Assert(t, ok)
ctx = context.WithValue(ctx, "key", "val")
return next(ctx, args, result)
}
}),
)
defer svr.Stop()

cli := echoservice.MustNewStreamClient("service", streamclient.WithHostPorts(addr))
ctx := metainfo.WithValue(context.Background(), KeyInspectMWCtx, "true")
stream, err := cli.EchoClient(ctx)
test.Assert(t, err == nil, err)

err = stream.Send(&echo.EchoRequest{Message: "InspectMWCtx"})
test.Assert(t, err == nil, err)

resp, err := stream.CloseAndRecv()
test.Assert(t, err == nil, err)
test.Assert(t, resp.Message == "InspectMWCtx Succeeded")
})
}

func TestTimeoutRecvSend(t *testing.T) {
Expand Down
9 changes: 9 additions & 0 deletions thrift_streaming/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ const (
KeyCount = "COUNT"
KeyServerRecvTimeoutMS = "RECV_TIMEOUT_MS"
KeyServerSendTimeoutMS = "SEND_TIMEOUT_MS"
KeyGetServerConn = "GET_SERVER_CONN"
KeyInspectMWCtx = "INSPECT_MW_CTX"
)

func GetError(ctx context.Context) error {
Expand Down Expand Up @@ -81,3 +83,10 @@ func GetInt(ctx context.Context, key string, defaultValue int) int {
}
return defaultValue
}

func GetBool(ctx context.Context, key string, defaultBool bool) bool {
if b, err := strconv.ParseBool(GetValue(ctx, key, "")); err == nil {
return b
}
return defaultBool
}
Loading