Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: feat(generic): support thrift streaming for json generic client #1467

Open
wants to merge 20 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions client/genericclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ type genericServiceClient struct {

func (gc *genericServiceClient) GenericCall(ctx context.Context, method string, request interface{}, callOptions ...callopt.Option) (response interface{}, err error) {
ctx = client.NewCtxWithCallOptions(ctx, callOptions)
_args := gc.svcInfo.MethodInfo(method).NewArgs().(*generic.Args)
mtInfo := gc.svcInfo.MethodInfo(method)
_args := mtInfo.NewArgs().(*generic.Args)
_args.Method = method
_args.Request = request

Expand All @@ -109,7 +110,7 @@ func (gc *genericServiceClient) GenericCall(ctx context.Context, method string,
return nil, gc.kClient.Call(ctx, mt.Name, _args, nil)
}

_result := gc.svcInfo.MethodInfo(method).NewResult().(*generic.Result)
_result := mtInfo.NewResult().(*generic.Result)
if err = gc.kClient.Call(ctx, mt.Name, _args, _result); err != nil {
return
}
Expand Down
5 changes: 3 additions & 2 deletions client/genericclient/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,9 @@ func NewServerStreaming(ctx context.Context, genericCli Client, method string, r
if err != nil {
return nil, err
}
ss := &serverStreamingClient{stream, gCli.svcInfo.MethodInfo(method)}
_args := gCli.svcInfo.MethodInfo(method).NewArgs().(*generic.Args)
mtInfo := gCli.svcInfo.MethodInfo(method)
ss := &serverStreamingClient{stream, mtInfo}
_args := mtInfo.NewArgs().(*generic.Args)
_args.Method = method
_args.Request = req
if err = ss.Stream.SendMsg(_args); err != nil {
Expand Down
File renamed without changes.
3 changes: 2 additions & 1 deletion internal/generic/thrift/binary.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"io"

"github.com/cloudwego/gopkg/protocol/thrift"
"github.com/cloudwego/gopkg/protocol/thrift/base"
)

Expand All @@ -30,6 +31,6 @@ func NewWriteBinary() *WriteBinary {
return &WriteBinary{}
}

func (w *WriteBinary) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error {
func (w *WriteBinary) Write(ctx context.Context, out io.Writer, bw *thrift.BinaryWriter, msg interface{}, method string, isClient bool, requestBase *base.Base) error {
return nil
}
8 changes: 4 additions & 4 deletions internal/generic/thrift/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (w *WriteHTTPRequest) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.O
}

// originalWrite ...
func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out io.Writer, msg interface{}, requestBase *base.Base) error {
func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out io.Writer, bw *thrift.BinaryWriter, msg interface{}, requestBase *base.Base) error {
req := msg.(*descriptor.HTTPRequest)
if req.Body == nil && len(req.RawBody) != 0 {
if err := customJson.Unmarshal(req.RawBody, &req.Body); err != nil {
Expand All @@ -94,11 +94,11 @@ func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out io.Writer, msg
if !fn.HasRequestBase {
requestBase = nil
}
binaryWriter := thrift.NewBinaryWriter()
if err = wrapStructWriter(ctx, req, binaryWriter, fn.Request, &writerOption{requestBase: requestBase, binaryWithBase64: w.binaryWithBase64}); err != nil {

if err = wrapStructWriter(ctx, req, bw, fn.Request, &writerOption{requestBase: requestBase, binaryWithBase64: w.binaryWithBase64}); err != nil {
return err
}
_, err = out.Write(binaryWriter.Bytes())
_, err = out.Write(bw.Bytes())
return err
}

Expand Down
5 changes: 3 additions & 2 deletions internal/generic/thrift/http_fallback.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ import (
"context"
"io"

"github.com/cloudwego/gopkg/protocol/thrift"
"github.com/cloudwego/gopkg/protocol/thrift/base"
)

// Write ...
func (w *WriteHTTPRequest) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error {
return w.originalWrite(ctx, out, msg, requestBase)
func (w *WriteHTTPRequest) Write(ctx context.Context, out io.Writer, bw *thrift.BinaryWriter, msg interface{}, method string, isClient bool, requestBase *base.Base) error {
return w.originalWrite(ctx, out, bw, msg, requestBase)
}
16 changes: 7 additions & 9 deletions internal/generic/thrift/http_go116plus_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ import (
)

// Write ...
func (w *WriteHTTPRequest) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error {
func (w *WriteHTTPRequest) Write(ctx context.Context, out io.Writer, bw *thrift.BinaryWriter, msg interface{}, method string, isClient bool, requestBase *base.Base) error {
// fallback logic
if !w.dynamicgoEnabled {
return w.originalWrite(ctx, out, msg, requestBase)
return w.originalWrite(ctx, out, bw, msg, requestBase)
}

// dynamicgo logic
Expand All @@ -63,29 +63,27 @@ func (w *WriteHTTPRequest) Write(ctx context.Context, out io.Writer, msg interfa
cv = j2t.NewBinaryConv(w.convOpts)
}

binaryWriter := thrift.NewBinaryWriter()

ctx = context.WithValue(ctx, conv.CtxKeyHTTPRequest, req)
body := req.GetBody()
dbuf := mcache.Malloc(len(body))[0:0]
defer mcache.Free(dbuf)

for _, field := range dynamicgoTypeDsc.Struct().Fields() {
binaryWriter.WriteFieldBegin(thrift.TType(field.Type().Type()), int16(field.ID()))
bw.WriteFieldBegin(thrift.TType(field.Type().Type()), int16(field.ID()))

// json []byte to thrift []byte
if err := cv.DoInto(ctx, field.Type(), body, &dbuf); err != nil {
return err
}
}
if _, err := out.Write(binaryWriter.Bytes()); err != nil {
if _, err := out.Write(bw.Bytes()); err != nil {
return err
}
if _, err := out.Write(dbuf); err != nil {
return err
}
binaryWriter.Reset()
binaryWriter.WriteFieldStop()
_, err := out.Write(binaryWriter.Bytes())
bw.Reset()
bw.WriteFieldStop()
_, err := out.Write(bw.Bytes())
return err
}
7 changes: 3 additions & 4 deletions internal/generic/thrift/http_pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func NewWriteHTTPPbRequest(svc *descriptor.ServiceDescriptor, pbSvc *desc.Servic
}

// Write ...
func (w *WriteHTTPPbRequest) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error {
func (w *WriteHTTPPbRequest) Write(ctx context.Context, out io.Writer, bw *thrift.BinaryWriter, msg interface{}, method string, isClient bool, requestBase *base.Base) error {
req := msg.(*descriptor.HTTPRequest)
fn, err := w.svc.Router.Lookup(req)
if err != nil {
Expand All @@ -77,11 +77,10 @@ func (w *WriteHTTPPbRequest) Write(ctx context.Context, out io.Writer, msg inter
}
req.GeneralBody = pbMsg

binaryWriter := thrift.NewBinaryWriter()
if err = wrapStructWriter(ctx, req, binaryWriter, fn.Request, &writerOption{requestBase: requestBase}); err != nil {
if err = wrapStructWriter(ctx, req, bw, fn.Request, &writerOption{requestBase: requestBase}); err != nil {
return err
}
_, err = out.Write(binaryWriter.Bytes())
_, err = out.Write(bw.Bytes())
return err
}

Expand Down
102 changes: 88 additions & 14 deletions internal/generic/thrift/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"strconv"

"github.com/bytedance/gopkg/lang/dirtmake"
"github.com/bytedance/sonic"
"github.com/cloudwego/dynamicgo/conv"
"github.com/cloudwego/dynamicgo/conv/t2j"
dthrift "github.com/cloudwego/dynamicgo/thrift"
Expand All @@ -34,6 +35,7 @@ import (
"github.com/cloudwego/kitex/pkg/generic/descriptor"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/serviceinfo"
"github.com/cloudwego/kitex/pkg/utils"
)

Expand Down Expand Up @@ -83,7 +85,7 @@ func (m *WriteJSON) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.Options)
m.dynamicgoEnabled = true
}

func (m *WriteJSON) originalWrite(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error {
func (m *WriteJSON) originalWrite(ctx context.Context, out io.Writer, bw *thrift.BinaryWriter, msg interface{}, method string, isClient bool, requestBase *base.Base) error {
fnDsc, err := m.svcDsc.LookupFunctionByMethod(method)
if err != nil {
return fmt.Errorf("missing method: %s in service: %s", method, m.svcDsc.Name)
Expand All @@ -98,14 +100,12 @@ func (m *WriteJSON) originalWrite(ctx context.Context, out io.Writer, msg interf
requestBase = nil
}

binaryWriter := thrift.NewBinaryWriter()

// msg is void or nil
if _, ok := msg.(descriptor.Void); ok || msg == nil {
if err = wrapStructWriter(ctx, msg, binaryWriter, typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}); err != nil {
if err = wrapStructWriter(ctx, msg, bw, typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}); err != nil {
return err
}
_, err = out.Write(binaryWriter.Bytes())
_, err = out.Write(bw.Bytes())
return err
}

Expand All @@ -125,10 +125,21 @@ func (m *WriteJSON) originalWrite(ctx context.Context, out io.Writer, msg interf
Index: 0,
}
}
if err = wrapJSONWriter(ctx, &body, binaryWriter, typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}); err != nil {

opt := &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}
if isStreaming(fnDsc.StreamingMode) {
// unwrap one struct layer
typeDsc = typeDsc.Struct.FieldsByID[int32(getStreamingFieldID(isClient, true))].Type
return writeStreamingContent(ctx, &body, typeDsc, opt, bw)
}

if err = wrapJSONWriter(ctx, &body, bw, typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary}); err != nil {
return err
}
_, err = out.Write(binaryWriter.Bytes())
if _, isGRPC := out.(remote.FrameWrite); isGRPC {
return nil
}
_, err = out.Write(bw.Bytes())
return err
}

Expand Down Expand Up @@ -187,33 +198,54 @@ func (m *ReadJSON) Read(ctx context.Context, method string, isClient bool, dataL
tyDsc = fnDsc.Request()
}

_, isGRPC := buffer.(remote.FrameWrite)

var resp interface{}
var err error
if tyDsc.Struct().Fields()[0].Type().Type() == dthrift.VOID {
if _, err := buffer.ReadBinary(voidWholeLen); err != nil {
if isGRPC {
_, err = buffer.Next(voidWholeLen)
} else {
_, err = buffer.ReadBinary(voidWholeLen)
}
if err != nil {
return nil, err
}
resp = descriptor.Void{}
} else {
transBuff, err := buffer.ReadBinary(dataLen)
var transBuff []byte
if isGRPC {
transBuff, err = buffer.Next(dataLen)
} else {
transBuff, err = buffer.ReadBinary(dataLen)
}
if err != nil {
return nil, err
}

isStream := isStreaming(m.svc.Functions[method].StreamingMode)
if isStream {
// unwrap one struct layer
tyDsc = tyDsc.Struct().FieldById(dthrift.FieldID(getStreamingFieldID(isClient, false))).Type()
}

// json size is usually 2 times larger than equivalent thrift data
buf := dirtmake.Bytes(0, len(transBuff)*2)
// thrift []byte to json []byte
var t2jBinaryConv t2j.BinaryConv
if isClient {
if isClient && !isStream {
t2jBinaryConv = t2j.NewBinaryConv(m.convOptsWithException)
} else {
t2jBinaryConv = t2j.NewBinaryConv(m.convOpts)
}
if err := t2jBinaryConv.DoInto(ctx, tyDsc, transBuff, &buf); err != nil {
return nil, err
}
buf = removePrefixAndSuffix(buf)
if !isStream {
buf = removePrefixAndSuffix(buf)
}
resp = utils.SliceByteToString(buf)
if tyDsc.Struct().Fields()[0].Type().Type() == dthrift.STRING {
if !isStream && tyDsc.Struct().Fields()[0].Type().Type() == dthrift.STRING {
strresp := resp.(string)
resp, err = strconv.Unquote(strresp)
if err != nil {
Expand All @@ -225,7 +257,7 @@ func (m *ReadJSON) Read(ctx context.Context, method string, isClient bool, dataL
return resp, nil
}

func (m *ReadJSON) originalRead(ctx context.Context, method string, isClient bool, in *thrift.BinaryReader) (interface{}, error) {
func (m *ReadJSON) originalRead(ctx context.Context, method string, isClient bool, br *thrift.BinaryReader) (interface{}, error) {
fnDsc, err := m.svc.LookupFunctionByMethod(method)
if err != nil {
return nil, err
Expand All @@ -234,7 +266,14 @@ func (m *ReadJSON) originalRead(ctx context.Context, method string, isClient boo
if !isClient {
fDsc = fnDsc.Request
}
resp, err := skipStructReader(ctx, in, fDsc, &readerOption{forJSON: true, throwException: true, binaryWithBase64: m.binaryWithBase64})

if isStreaming(fnDsc.StreamingMode) {
// unwrap one struct layer
fDsc = fDsc.Struct.FieldsByID[int32(getStreamingFieldID(isClient, false))].Type
return readStreamingContent(ctx, fDsc, &readerOption{forJSON: true, binaryWithBase64: m.binaryWithBase64}, br)
}

resp, err := skipStructReader(ctx, br, fDsc, &readerOption{forJSON: true, throwException: true, binaryWithBase64: m.binaryWithBase64})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -265,3 +304,38 @@ func removePrefixAndSuffix(buf []byte) []byte {
}
return buf
}

func isStreaming(streamingMode serviceinfo.StreamingMode) bool {
return streamingMode != serviceinfo.StreamingUnary && streamingMode != serviceinfo.StreamingNone
}

func getStreamingFieldID(isClient, isWrite bool) int {
var streamingFieldID int
if (isWrite && isClient) || (!isWrite && !isClient) {
streamingFieldID = 1
}
return streamingFieldID
}

func writeStreamingContent(ctx context.Context, body *gjson.Result, typeDsc *descriptor.TypeDescriptor, opt *writerOption, bw *thrift.BinaryWriter) error {
val, writer, err := nextJSONWriter(body, typeDsc, opt)
if err != nil {
return fmt.Errorf("nextWriter of field[%s] error %w", typeDsc.Name, err)
}
if err = writer(ctx, val, bw, typeDsc, opt); err != nil {
return fmt.Errorf("writer of field[%s] error %w", typeDsc.Name, err)
}
return nil
}

func readStreamingContent(ctx context.Context, typeDesc *descriptor.TypeDescriptor, opt *readerOption, br *thrift.BinaryReader) (v interface{}, err error) {
resp, err := readStruct(ctx, br, typeDesc, opt)
if err != nil {
return nil, err
}
respNode, err := sonic.Marshal(resp)
if err != nil {
return nil, perrors.NewProtocolErrorWithType(perrors.InvalidData, fmt.Sprintf("streaming response marshal failed. err:%#v", err))
}
return string(respNode), nil
}
5 changes: 3 additions & 2 deletions internal/generic/thrift/json_fallback.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ package thrift

import (
"context"
"github.com/cloudwego/gopkg/protocol/thrift"
"io"

"github.com/cloudwego/gopkg/protocol/thrift/base"
)

// Write write json string to out thrift.TProtocol
func (m *WriteJSON) Write(ctx context.Context, out io.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error {
return m.originalWrite(ctx, out, msg, method, isClient, requestBase)
func (m *WriteJSON) Write(ctx context.Context, out io.Writer, bw *thrift.BinaryWriter, msg interface{}, method string, isClient bool, requestBase *base.Base) error {
return m.originalWrite(ctx, out, bw, msg, method, isClient, requestBase)
}
Loading
Loading