Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: gRPC wrapper does not generate the correct gRPC handlers with context #16

Merged
merged 5 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions wrap/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ import (
const filePerm = 0644

var (
ErrNoProtoFile = errors.New("proto file path is required")
ErrOpeningProtoFile = errors.New("error opening the proto file")
ErrFailedToParseProto = errors.New("failed to parse proto file")
ErrGeneratingWrapper = errors.New("error generating the wrapper code from the proto file")
ErrWritingWrapperFile = errors.New("error writing the generated wrapper to the file")
ErrNoProtoFile = errors.New("proto file path is required")
ErrOpeningProtoFile = errors.New("error opening the proto file")
ErrFailedToParseProto = errors.New("failed to parse proto file")
ErrGeneratingWrapper = errors.New("error generating the wrapper code from the proto file")
ErrWritingWrapperFile = errors.New("error writing the generated wrapper to the file")
ErrGeneratingServerTemplate = errors.New("error generating the gRPC server file template")
ErrWritingServerTemplate = errors.New("error writing the generated server template to the file")
)

// ServiceMethod represents a method in a proto service.
Expand Down Expand Up @@ -71,7 +73,7 @@ func GenerateWrapper(ctx *gofr.Context) (any, error) {

var (
// Extracting package and project path from go_package option.
packageName, projectPath = getPackageAndProject(definition)
projectPath, packageName = getPackageAndProject(definition)
// Extract the services.
services = getServices(definition)
)
Expand All @@ -89,7 +91,7 @@ func GenerateWrapper(ctx *gofr.Context) (any, error) {
return nil, ErrGeneratingWrapper
}

outputFilePath := fmt.Sprintf("%s/%s.gofr.go", projectPath, strings.ToLower(service.Name))
outputFilePath := path.Join(projectPath, fmt.Sprintf("%s.gofr.go", strings.ToLower(service.Name)))

err := os.WriteFile(outputFilePath, []byte(generatedCode), filePerm)
if err != nil {
Expand All @@ -99,6 +101,22 @@ func GenerateWrapper(ctx *gofr.Context) (any, error) {
}

fmt.Printf("Generated wrapper for service %s at %s\n", service.Name, outputFilePath)

generatedgRPCCode := generategRPCCode(ctx, &wrapperData)
if generatedgRPCCode == "" {
return nil, ErrGeneratingServerTemplate
}

outputFilePath = path.Join(projectPath, fmt.Sprintf("%sServer.go", strings.ToLower(service.Name)))

err = os.WriteFile(outputFilePath, []byte(generatedgRPCCode), filePerm)
if err != nil {
ctx.Errorf("Failed to write file %s: %v", outputFilePath, err)

return nil, ErrWritingServerTemplate
}

fmt.Printf("Generated server template for service %s at %s\n", service.Name, outputFilePath)
}

return "Successfully generated all wrappers for gRPC services", nil
Expand Down Expand Up @@ -126,7 +144,7 @@ func uniqueRequestTypes(methods []ServiceMethod) []string {
func generateWrapperCode(ctx *gofr.Context, data *WrapperData) string {
var buf bytes.Buffer

tmplInstance := template.Must(template.New("wrapper").Parse(tmpl))
tmplInstance := template.Must(template.New("wrapper").Parse(wrapperTemplate))

err := tmplInstance.Execute(&buf, data)
if err != nil {
Expand All @@ -138,11 +156,26 @@ func generateWrapperCode(ctx *gofr.Context, data *WrapperData) string {
return buf.String()
}

// Generate wrapper code using the template.
func generategRPCCode(ctx *gofr.Context, data *WrapperData) string {
var buf bytes.Buffer

tmplInstance := template.Must(template.New("wrapper").Parse(serverTemplate))

err := tmplInstance.Execute(&buf, data)
if err != nil {
ctx.Errorf("Template execution failed: %v", err)
return ""
}

return buf.String()
}

func getPackageAndProject(definition *proto.Proto) (projectPath, packageName string) {
proto.Walk(definition,
proto.WithOption(func(opt *proto.Option) {
if opt.Name == "go_package" {
projectPath = opt.Constant.Source[:len(opt.Constant.Source)-1]
projectPath = opt.Constant.Source
packageName = path.Base(opt.Constant.Source)
}
}),
Expand Down
36 changes: 33 additions & 3 deletions wrap/template.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package wrap

const tmpl = `// Code generated by gofr.dev/cli/gofr. DO NOT EDIT.
const (
wrapperTemplate = `// Code generated by gofr.dev/cli/gofr. DO NOT EDIT.
package {{ .Package }}

import (
Expand Down Expand Up @@ -90,8 +91,8 @@ func (h *{{ $request }}Wrapper) Bind(p interface{}) error {
return fmt.Errorf("expected a pointer, got %T", p)
}

hValue := reflect.ValueOf(h.InfoRequest).Elem()
ptrValue := reflect.ValueOf(ptr).Elem()
hValue := reflect.ValueOf(h.{{ $request }}).Elem()
ptrValue := ptr.Elem()

// Ensure we can set exported fields (skip unexported fields)
for i := 0; i < hValue.NumField(); i++ {
Expand Down Expand Up @@ -119,3 +120,32 @@ func (h *{{ $request }}Wrapper) Params(s string) []string {

{{- end }}
`

serverTemplate = `package {{ .Package }}

import "gofr.dev/pkg/gofr"

// Register the gRPC service in your app using the following code in your main.go:
//
// grpc.Register{{ $.Service }}ServerWithGofr(app, &grpc.{{ $.Service }}GoFrServer{})
//
// {{ $.Service }}GoFrServer defines the gRPC server implementation.
// Customize the struct with required dependencies and fields as needed.

type {{ $.Service }}GoFrServer struct {
}

{{- range .Methods }}
func (s *{{ $.Service }}GoFrServer) {{ .Name }}(ctx *gofr.Context) (any, error) {
// Uncomment and use the following code if you need to bind the request payload
// request := {{ .Request }}{}
// err := ctx.Bind(&request)
// if err != nil {
// return nil, err
// }

return &{{ .Response }}{}, nil
}
{{- end }}
`
)
Loading