From 10169c063f194eea35a02cb15b8a8b901e3f852e Mon Sep 17 00:00:00 2001 From: Julie Qiu Date: Fri, 1 Nov 2024 18:32:08 -0400 Subject: [PATCH] fix: cmd/protoc-gen-gclient: factor out captureInput Move logic for captureInput to a separate function and document it. --- generator/cmd/protoc-gen-gclient/main.go | 63 ++++++++++--------- generator/cmd/protoc-gen-gclient/main_test.go | 2 +- 2 files changed, 35 insertions(+), 30 deletions(-) diff --git a/generator/cmd/protoc-gen-gclient/main.go b/generator/cmd/protoc-gen-gclient/main.go index cc21e5c62..cb4a4ebec 100644 --- a/generator/cmd/protoc-gen-gclient/main.go +++ b/generator/cmd/protoc-gen-gclient/main.go @@ -18,6 +18,7 @@ import ( "flag" "fmt" "io" + "log" "log/slog" "os" "slices" @@ -37,18 +38,19 @@ func main() { templateDir := flag.String("template-dir", "templates/", "the path to the template directory") flag.Parse() - if err := run(*inputPath, *outDir, *templateDir); err != nil { - slog.Error(err.Error()) - os.Exit(1) + if err := run(os.Stdin, os.Stdout, *inputPath, *outDir, *templateDir); err != nil { + log.Fatal(err) } slog.Info("Generation Completed Successfully") } -func run(inputPath, outDir, templateDir string) error { - var reqBytes []byte - var err error +func run(r io.Reader, w io.Writer, inputPath, outDir, templateDir string) error { + var ( + reqBytes []byte + err error + ) if inputPath == "" { - reqBytes, err = io.ReadAll(os.Stdin) + reqBytes, err = io.ReadAll(r) if err != nil { return err } @@ -63,23 +65,16 @@ func run(inputPath, outDir, templateDir string) error { if err := proto.Unmarshal(reqBytes, genReq); err != nil { return err } - opts, err := parseOpts(genReq.GetParameter()) if err != nil { return err } - if opts.CaptureInput { - // Remove capture-input param from the captured input - ss := slices.DeleteFunc(strings.Split(genReq.GetParameter(), ","), func(s string) bool { - return strings.Contains(s, "capture-input") - }) - genReq.Parameter = proto.String(strings.Join(ss, ",")) - reqBytes, err = proto.Marshal(genReq) - if err != nil { - return err - } - if err := os.WriteFile(fmt.Sprintf("sample-input-%s.bin", time.Now().Format(time.RFC3339)), reqBytes, 0644); err != nil { + // If capture-input is set, content pass to protoc will be written to a + // sample-input-{timestamp}.bin file, so that protoc does not need to be + // used on future iterations. + if opts.captureInput { + if err := captureInput(genReq); err != nil { return err } } @@ -87,7 +82,7 @@ func run(inputPath, outDir, templateDir string) error { req, err := protobuf.NewTranslator(&protobuf.Options{ Request: genReq, OutDir: outDir, - Language: opts.Language, + Language: opts.language, TemplateDir: templateDir, }).Translate() if err != nil { @@ -99,16 +94,13 @@ func run(inputPath, outDir, templateDir string) error { if err != nil { return err } - if _, err := os.Stdout.Write(outBytes); err != nil { - return err - } - - return nil + _, err = w.Write(outBytes) + return err } type protobufOptions struct { - CaptureInput bool - Language string + captureInput bool + language string } func parseOpts(optStr string) (*protobufOptions, error) { @@ -130,12 +122,25 @@ func parseOpts(optStr string) (*protobufOptions, error) { slog.Error("invalid bool in option string, skipping", "option", s) return nil, err } - opts.CaptureInput = b + opts.captureInput = b case "language": - opts.Language = strings.ToLower(strings.TrimSpace(sp[1])) + opts.language = strings.ToLower(strings.TrimSpace(sp[1])) default: slog.Warn("unknown option", "option", s) } } return opts, nil } + +func captureInput(genReq *pluginpb.CodeGeneratorRequest) error { + // Remove capture-input param from the captured input + ss := slices.DeleteFunc(strings.Split(genReq.GetParameter(), ","), func(s string) bool { + return strings.Contains(s, "capture-input") + }) + genReq.Parameter = proto.String(strings.Join(ss, ",")) + reqBytes, err := proto.Marshal(genReq) + if err != nil { + return err + } + return os.WriteFile(fmt.Sprintf("sample-input-%s.bin", time.Now().Format(time.RFC3339)), reqBytes, 0644) +} diff --git a/generator/cmd/protoc-gen-gclient/main_test.go b/generator/cmd/protoc-gen-gclient/main_test.go index dd59403fe..02245092d 100644 --- a/generator/cmd/protoc-gen-gclient/main_test.go +++ b/generator/cmd/protoc-gen-gclient/main_test.go @@ -32,7 +32,7 @@ func TestMain(m *testing.M) { func TestRun_Rust(t *testing.T) { tDir := t.TempDir() - if err := run("testdata/rust/rust.bin", tDir, "../../templates"); err != nil { + if err := run(os.Stdin, os.Stdout, "testdata/rust/rust.bin", tDir, "../../templates"); err != nil { t.Fatal(err) } diff(t, "testdata/rust/golden", tDir)