Skip to content

Commit

Permalink
refactor: refine streamx generation code (#1610)
Browse files Browse the repository at this point in the history
  • Loading branch information
DMwangnima authored Nov 15, 2024
1 parent 65dc3b1 commit 5784bac
Show file tree
Hide file tree
Showing 13 changed files with 244 additions and 418 deletions.
7 changes: 4 additions & 3 deletions pkg/streamx/streamx_gen_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"

"github.com/cloudwego/kitex/client"
"github.com/cloudwego/kitex/client/callopt"
"github.com/cloudwego/kitex/client/streamxclient"
"github.com/cloudwego/kitex/client/streamxclient/streamxcallopt"
"github.com/cloudwego/kitex/pkg/serviceinfo"
Expand Down Expand Up @@ -84,7 +85,7 @@ var testServiceInfo = &serviceinfo.ServiceInfo{
),
"ServerStream": serviceinfo.NewMethodInfo(
func(ctx context.Context, handler, reqArgs, resArgs interface{}) error {
return streamxserver.InvokeServerStreamHandler(
return streamxserver.InvokeServerStreamHandler[Request, Response](
ctx, reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs),
func(ctx context.Context, req *Request, stream streamx.ServerStreamingServer[Response]) error {
return handler.(TestService).ServerStream(ctx, req, stream)
Expand Down Expand Up @@ -215,7 +216,7 @@ type TestService interface {

// --- Define Client Implementation Interface ---
type TestServiceClient interface {
PingPong(ctx context.Context, req *Request) (r *Response, err error)
PingPong(ctx context.Context, req *Request, callOptions ...callopt.Option) (r *Response, err error)

Unary(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (r *Response, err error)
ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) (
Expand All @@ -240,7 +241,7 @@ type kClient struct {
caller client.Client
}

func (c *kClient) PingPong(ctx context.Context, req *Request) (r *Response, err error) {
func (c *kClient) PingPong(ctx context.Context, req *Request, callOptions ...callopt.Option) (r *Response, err error) {
var _args ServerPingPongArgs
_args.Req = req
var _result ServerPingPongResult
Expand Down
2 changes: 1 addition & 1 deletion tool/cmd/kitex/args/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func (a *Arguments) checkStreamX() error {
// set TTHeader Streaming by default
a.Protocol = transport.TTHeader.String()
}
// todo: process pb and gRPC
// todo(DMwangnima): process pb and gRPC
return nil
}

Expand Down
171 changes: 43 additions & 128 deletions tool/internal_pkg/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (

"github.com/cloudwego/kitex/tool/internal_pkg/log"
"github.com/cloudwego/kitex/tool/internal_pkg/tpl"
"github.com/cloudwego/kitex/tool/internal_pkg/tpl/streamx"
"github.com/cloudwego/kitex/tool/internal_pkg/util"
"github.com/cloudwego/kitex/transport"
)
Expand Down Expand Up @@ -299,9 +298,6 @@ func (c *Config) ApplyExtension() error {
}

func (c *Config) IsUsingMultipleServicesTpl() bool {
if c.StreamX {
return true
}
for _, part := range c.BuiltinTpl {
if part == MultipleServicesTpl {
return true
Expand Down Expand Up @@ -432,19 +428,10 @@ func (g *generator) generateHandler(pkg *PackageInfo, svc *ServiceInfo, handlerF
return f, nil
}

var task Task
if g.StreamX && svc.HasStreaming {
task = Task{
Name: HandlerFileName,
Path: handlerFilePath,
Text: tpl.HandlerTpl + "\n" + streamx.HandlerMethodsTpl,
}
} else {
task = Task{
Name: HandlerFileName,
Path: handlerFilePath,
Text: tpl.HandlerTpl + "\n" + tpl.HandlerMethodsTpl,
}
task := Task{
Name: HandlerFileName,
Path: handlerFilePath,
Text: tpl.HandlerTpl + "\n" + tpl.HandlerMethodsTpl,
}
g.setImports(task.Name, pkg)
handle := func(task *Task, pkg *PackageInfo) (*File, error) {
Expand Down Expand Up @@ -484,11 +471,6 @@ func (g *generator) GenerateService(pkg *PackageInfo) ([]*File, error) {
Path: util.JoinPath(output, svcPkg+".go"),
Text: tpl.ServiceTpl,
}
if g.StreamX && pkg.ServiceInfo.HasStreaming {
cliTask.Text = streamx.ClientTpl
svrTask.Text = streamx.ServerTpl
svcTask.Text = streamx.ServiceTpl
}
tasks := []*Task{cliTask, svrTask, svcTask}

// do not generate invoker.go in service package by default
Expand Down Expand Up @@ -558,31 +540,28 @@ func (g *generator) setImports(name string, pkg *PackageInfo) {
pkg.Imports = make(map[string]map[string]bool)
switch name {
case ClientFileName:
if g.StreamX && pkg.HasStreaming {
g.setStreamXClientImports(pkg)
} else {
pkg.AddImports("client")
if pkg.HasStreaming {
pkg.AddImport("streaming", "github.com/cloudwego/kitex/pkg/streaming")
pkg.AddImport("transport", "github.com/cloudwego/kitex/transport")
}
if len(pkg.AllMethods()) > 0 {
if needCallOpt(pkg) {
pkg.AddImports("callopt")
}
pkg.AddImports("context")
pkg.AddImports("client")
if !g.StreamX && pkg.HasStreaming {
pkg.AddImport("streaming", "github.com/cloudwego/kitex/pkg/streaming")
pkg.AddImport("transport", "github.com/cloudwego/kitex/transport")
}
if len(pkg.AllMethods()) > 0 {
if needCallOpt(pkg) {
pkg.AddImports("callopt")
}
pkg.AddImports("context")
}
fallthrough
case HandlerFileName:
if g.StreamX && pkg.HasStreaming {
g.setStreamXHandlerImports(pkg)
return
}
for _, m := range pkg.ServiceInfo.AllMethods() {
if !m.ServerStreaming && !m.ClientStreaming {
// for StreamX interface, every method in handler has ctx argument
// for old interface, streaming method in handler does not have ctx argument
if g.StreamX || (!m.ServerStreaming && !m.ClientStreaming) {
pkg.AddImports("context")
}
if g.StreamX && m.Streaming.IsStreaming {
pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx")
}
for _, a := range m.Args {
for _, dep := range a.Deps {
pkg.AddImport(dep.PkgRefName, dep.ImportPath)
Expand All @@ -595,19 +574,20 @@ func (g *generator) setImports(name string, pkg *PackageInfo) {
}
}
case ServerFileName, InvokerFileName:
if g.StreamX && pkg.HasStreaming {
g.setStreamXServerImports(pkg)
return
// for StreamX, if there is streaming method, generate Server Interface in server.go
if g.StreamX {
for _, method := range pkg.AllMethods() {
if method.Streaming.IsStreaming {
pkg.AddImports("context")
pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx")
}
}
}
if len(pkg.CombineServices) == 0 {
pkg.AddImport(pkg.ServiceInfo.PkgRefName, pkg.ServiceInfo.ImportPath)
}
pkg.AddImports("server")
case ServiceFileName:
if g.StreamX && pkg.HasStreaming {
g.setStreamXServiceImports(pkg)
return
}
pkg.AddImports("errors")
pkg.AddImports("client")
pkg.AddImport("kitex", "github.com/cloudwego/kitex/pkg/serviceinfo")
Expand All @@ -616,9 +596,6 @@ func (g *generator) setImports(name string, pkg *PackageInfo) {
pkg.AddImports("context")
}
for _, m := range pkg.ServiceInfo.AllMethods() {
if m.ClientStreaming || m.ServerStreaming {
pkg.AddImports("fmt")
}
if m.GenArgResultStruct {
pkg.AddImports("proto")
} else {
Expand All @@ -630,9 +607,22 @@ func (g *generator) setImports(name string, pkg *PackageInfo) {
pkg.AddImport(dep.PkgRefName, dep.ImportPath)
}
}
if m.Streaming.IsStreaming || pkg.Codec == "protobuf" {
// protobuf handler support both PingPong and Unary (streaming) requests
pkg.AddImport("streaming", "github.com/cloudwego/kitex/pkg/streaming")
// streaming imports
if !g.StreamX {
if m.Streaming.IsStreaming || pkg.Codec == "protobuf" {
// protobuf handler support both PingPong and Unary (streaming) requests
pkg.AddImport("streaming", "github.com/cloudwego/kitex/pkg/streaming")
}
if m.ClientStreaming || m.ServerStreaming {
pkg.AddImports("fmt")
}
} else {
if m.Streaming.IsStreaming {
pkg.AddImports("github.com/cloudwego/kitex/client/streamxclient")
pkg.AddImports("github.com/cloudwego/kitex/client/streamxclient/streamxcallopt")
pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx")
pkg.AddImports("github.com/cloudwego/kitex/server/streamxserver")
}
}
if !m.Void && m.Resp != nil {
for _, dep := range m.Resp.Deps {
Expand Down Expand Up @@ -681,78 +671,3 @@ func needCallOpt(pkg *PackageInfo) bool {
}
return needCallOpt
}

func (g *generator) setStreamXClientImports(pkg *PackageInfo) {
pkg.AddImports("client")
pkg.AddImports("github.com/cloudwego/kitex/client/streamxclient")
if len(pkg.AllMethods()) > 0 {
pkg.AddImports("context")
pkg.AddImports("github.com/cloudwego/kitex/client/streamxclient/streamxcallopt")
pkg.AddImports("github.com/cloudwego/kitex/pkg/serviceinfo")
}
pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx")
if g.IDLType == "thrift" {
pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx/provider/" + streamxTTHeaderRef)
}
}

func (g *generator) setStreamXServerImports(pkg *PackageInfo) {
pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx")
pkg.AddImports("server")
pkg.AddImports("github.com/cloudwego/kitex/server/streamxserver")
if g.IDLType == "thrift" {
pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx/provider/" + streamxTTHeaderRef)
}
for _, m := range pkg.AllMethods() {
pkg.AddImports("context")
for _, a := range m.Args {
for _, dep := range a.Deps {
pkg.AddImport(dep.PkgRefName, dep.ImportPath)
}
}
if !m.Void && m.Resp != nil {
for _, dep := range m.Resp.Deps {
pkg.AddImport(dep.PkgRefName, dep.ImportPath)
}
}
}
}

func (g *generator) setStreamXServiceImports(pkg *PackageInfo) {
pkg.AddImports("github.com/cloudwego/kitex/pkg/serviceinfo")
for _, m := range pkg.AllMethods() {
pkg.AddImports("context")
pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx")
pkg.AddImports("github.com/cloudwego/kitex/server/streamxserver")
for _, a := range m.Args {
for _, dep := range a.Deps {
pkg.AddImport(dep.PkgRefName, dep.ImportPath)
}
}
if !m.Void && m.Resp != nil {
for _, dep := range m.Resp.Deps {
pkg.AddImport(dep.PkgRefName, dep.ImportPath)
}
}
}
}

func (g *generator) setStreamXHandlerImports(pkg *PackageInfo) {
for _, m := range pkg.ServiceInfo.AllMethods() {
pkg.AddImports("context")
pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx")
if g.IDLType == "thrift" {
pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx/provider/" + streamxTTHeaderRef)
}
for _, a := range m.Args {
for _, dep := range a.Deps {
pkg.AddImport(dep.PkgRefName, dep.ImportPath)
}
}
if !m.Void && m.Resp != nil {
for _, dep := range m.Resp.Deps {
pkg.AddImport(dep.PkgRefName, dep.ImportPath)
}
}
}
}
3 changes: 3 additions & 0 deletions tool/internal_pkg/generator/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ type ServiceInfo struct {
RefName string
// identify whether this service would generate a corresponding handler.
GenerateHandler bool
// whether to generate StreamX interface code
StreamX bool
}

// AllMethods returns all methods that the service have.
Expand Down Expand Up @@ -212,6 +214,7 @@ type MethodInfo struct {
ClientStreaming bool
ServerStreaming bool
Streaming *streaming.Streaming
StreamX bool
}

// Parameter .
Expand Down
3 changes: 3 additions & 0 deletions tool/internal_pkg/pluginmode/thriftgo/convertor.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ func (c *converter) convertTypes(req *plugin.Request) error {
ServiceFilePath: ast.Filename,
HasStreaming: hasStreaming,
GenerateHandler: true,
StreamX: c.Config.StreamX,
}

if c.IsHessian2() {
Expand Down Expand Up @@ -421,6 +422,7 @@ func (c *converter) makeService(pkg generator.PkgInfo, svc *golang.Service) (*ge
PkgInfo: pkg,
ServiceName: svc.GoName().String(),
RawServiceName: svc.Name,
StreamX: c.Config.StreamX,
}
si.ServiceTypeName = func() string { return si.PkgRefName + "." + si.ServiceName }

Expand Down Expand Up @@ -464,6 +466,7 @@ func (c *converter) makeMethod(si *generator.ServiceInfo, f *golang.Function) (*
ClientStreaming: st.ClientStreaming,
ServerStreaming: st.ServerStreaming,
ArgsLength: len(f.Arguments()),
StreamX: si.StreamX,
}
if st.IsStreaming {
si.HasStreaming = true
Expand Down
Loading

0 comments on commit 5784bac

Please sign in to comment.