From bd28cda94334755d184dba86695ae5acd85edad6 Mon Sep 17 00:00:00 2001 From: Julie Qiu Date: Sun, 3 Nov 2024 21:04:42 -0500 Subject: [PATCH] impl: simplify openapi translator --- generator/cmd/openapi/main.go | 6 +- .../genclient/translator/openapi/openapi.go | 127 +++++++----------- .../translator/openapi/openapi_test.go | 51 +++---- 3 files changed, 66 insertions(+), 118 deletions(-) diff --git a/generator/cmd/openapi/main.go b/generator/cmd/openapi/main.go index 3a2f0f65..17e36e62 100644 --- a/generator/cmd/openapi/main.go +++ b/generator/cmd/openapi/main.go @@ -53,7 +53,7 @@ func run(inputPath, language, outDir, templateDir string) error { } func generateFrom(contents []byte, language, outDir, templateDir string) error { - translator, err := openapi.NewTranslator(contents, &openapi.Options{ + req, err := openapi.Translate(contents, &openapi.Options{ Language: language, OutDir: outDir, TemplateDir: templateDir, @@ -61,10 +61,6 @@ func generateFrom(contents []byte, language, outDir, templateDir string) error { if err != nil { return err } - req, err := translator.Translate() - if err != nil { - return err - } if _, err := genclient.Generate(req); err != nil { return err } diff --git a/generator/internal/genclient/translator/openapi/openapi.go b/generator/internal/genclient/translator/openapi/openapi.go index 55cb3561..8147228f 100644 --- a/generator/internal/genclient/translator/openapi/openapi.go +++ b/generator/internal/genclient/translator/openapi/openapi.go @@ -27,18 +27,6 @@ import ( v3 "github.com/pb33f/libopenapi/datamodel/high/v3" ) -type Translator struct { - model *libopenapi.DocumentModel[v3.Document] - language string - - // State by FQN - state *genclient.APIState - - // Only used for local testing - outDir string - templateDir string -} - type Options struct { Language string // Only used for local testing @@ -46,7 +34,29 @@ type Options struct { TemplateDir string } -func NewTranslator(contents []byte, opts *Options) (*Translator, error) { +func Translate(contents []byte, opts *Options) (*genclient.GenerateRequest, error) { + model, err := createDocModel(contents) + if err != nil { + return nil, err + } + // Translates OpenAPI specification into a [genclient.GenerateRequest]. + api, err := makeAPI(model) + if err != nil { + return nil, err + } + codec, err := language.NewCodec(opts.Language) + if err != nil { + return nil, err + } + return &genclient.GenerateRequest{ + API: api, + Codec: codec, + OutDir: opts.OutDir, + TemplateDir: opts.TemplateDir, + }, nil +} + +func createDocModel(contents []byte) (*libopenapi.DocumentModel[v3.Document], error) { document, err := libopenapi.NewDocument(contents) if err != nil { return nil, err @@ -58,28 +68,21 @@ func NewTranslator(contents []byte, opts *Options) (*Translator, error) { } return nil, fmt.Errorf("cannot convert document to OpenAPI V3 model: %e", errors[0]) } + return docModel, nil +} - return &Translator{ - model: docModel, - outDir: opts.OutDir, - language: opts.Language, - templateDir: opts.TemplateDir, - state: &genclient.APIState{ +func makeAPI(model *libopenapi.DocumentModel[v3.Document]) (*genclient.API, error) { + api := &genclient.API{ + Name: model.Model.Info.Title, + Messages: make([]*genclient.Message, 0), + State: &genclient.APIState{ ServiceByID: make(map[string]*genclient.Service), MessageByID: make(map[string]*genclient.Message), EnumByID: make(map[string]*genclient.Enum), }, - }, nil -} - -func (t *Translator) makeAPI() (*genclient.API, error) { - api := &genclient.API{ - Name: t.model.Model.Info.Title, - Messages: make([]*genclient.Message, 0), } - api.State = t.state - for name, msg := range t.model.Model.Components.Schemas.FromOldest() { + for name, msg := range model.Model.Components.Schemas.FromOldest() { // The typical format is ".${packageName}.${messageName}", but we do not // have a package name at the moment. id := ".." + name @@ -87,7 +90,7 @@ func (t *Translator) makeAPI() (*genclient.API, error) { if err != nil { return nil, err } - fields, err := t.makeMessageFields(name, schema) + fields, err := makeMessageFields(name, schema) if err != nil { return nil, err } @@ -99,32 +102,12 @@ func (t *Translator) makeAPI() (*genclient.API, error) { } api.Messages = append(api.Messages, message) - t.state.MessageByID[id] = message + api.State.MessageByID[id] = message } return api, nil } -// Translates OpenAPI specification into a [genclient.GenerateRequest]. -func (t *Translator) Translate() (*genclient.GenerateRequest, error) { - api, err := t.makeAPI() - if err != nil { - return nil, err - } - - codec, err := language.NewCodec(t.language) - if err != nil { - return nil, err - } - api.State = t.state - return &genclient.GenerateRequest{ - API: api, - Codec: codec, - OutDir: t.outDir, - TemplateDir: t.templateDir, - }, nil -} - -func (t *Translator) makeMessageFields(messageName string, message *base.Schema) ([]*genclient.Field, error) { +func makeMessageFields(messageName string, message *base.Schema) ([]*genclient.Field, error) { var fields []*genclient.Field for name, f := range message.Properties.FromOldest() { schema, err := f.BuildSchema() @@ -138,7 +121,7 @@ func (t *Translator) makeMessageFields(messageName string, message *base.Schema) break } } - field, err := t.makeField(messageName, name, optional, schema) + field, err := makeField(messageName, name, optional, schema) if err != nil { return nil, err } @@ -147,33 +130,27 @@ func (t *Translator) makeMessageFields(messageName string, message *base.Schema) return fields, nil } -func (t *Translator) makeField(messageName, name string, optional bool, field *base.Schema) (*genclient.Field, error) { +func makeField(messageName, name string, optional bool, field *base.Schema) (*genclient.Field, error) { if len(field.AllOf) != 0 { // Simple object fields name an AllOf attribute, but no `Type` attribute. - return t.makeObjectField(messageName, name, field) + return makeObjectField(messageName, name, field) } if len(field.Type) == 0 { return nil, fmt.Errorf("missing field type for field %s.%s", messageName, name) } switch field.Type[0] { - case "boolean": - return t.makeScalarField(messageName, name, field, optional, field) - case "integer": - return t.makeScalarField(messageName, name, field, optional, field) - case "number": - return t.makeScalarField(messageName, name, field, optional, field) - case "string": - return t.makeScalarField(messageName, name, field, optional, field) + case "boolean", "integer", "number", "string": + return makeScalarField(messageName, name, field, optional, field) case "object": - return t.makeObjectField(messageName, name, field) + return makeObjectField(messageName, name, field) case "array": - return t.makeArrayField(messageName, name, field) + return makeArrayField(messageName, name, field) default: return nil, fmt.Errorf("unknown type for field %q", name) } } -func (t *Translator) makeScalarField(messageName, name string, schema *base.Schema, optional bool, field *base.Schema) (*genclient.Field, error) { +func makeScalarField(messageName, name string, schema *base.Schema, optional bool, field *base.Schema) (*genclient.Field, error) { typez, typezID, err := scalarType(messageName, name, schema) if err != nil { return nil, err @@ -187,9 +164,9 @@ func (t *Translator) makeScalarField(messageName, name string, schema *base.Sche }, nil } -func (t *Translator) makeObjectField(messageName, name string, field *base.Schema) (*genclient.Field, error) { +func makeObjectField(messageName, name string, field *base.Schema) (*genclient.Field, error) { if len(field.AllOf) != 0 { - return t.makeObjectFieldAllOf(messageName, name, field) + return makeObjectFieldAllOf(messageName, name, field) } // TODO(#62) - this is an Any or a map, needs a TypezID return &genclient.Field{ @@ -200,7 +177,7 @@ func (t *Translator) makeObjectField(messageName, name string, field *base.Schem }, nil } -func (t *Translator) makeArrayField(messageName, name string, field *base.Schema) (*genclient.Field, error) { +func makeArrayField(messageName, name string, field *base.Schema) (*genclient.Field, error) { if !field.Items.IsA() { return nil, fmt.Errorf("cannot handle arrays without an `Items` field for %s.%s", messageName, name) } @@ -213,16 +190,10 @@ func (t *Translator) makeArrayField(messageName, name string, field *base.Schema } var result *genclient.Field switch schema.Type[0] { - case "boolean": - result, err = t.makeScalarField(messageName, name, schema, false, field) - case "integer": - result, err = t.makeScalarField(messageName, name, schema, false, field) - case "number": - result, err = t.makeScalarField(messageName, name, schema, false, field) - case "string": - result, err = t.makeScalarField(messageName, name, schema, false, field) + case "boolean", "integer", "number", "string": + result, err = makeScalarField(messageName, name, schema, false, field) case "object": - result, err = t.makeObjectField(messageName, name, field) + result, err = makeObjectField(messageName, name, field) default: return nil, fmt.Errorf("unknown array field type for %s.%s %q", messageName, name, schema.Type[0]) } @@ -234,7 +205,7 @@ func (t *Translator) makeArrayField(messageName, name string, field *base.Schema return result, nil } -func (t *Translator) makeObjectFieldAllOf(messageName, name string, field *base.Schema) (*genclient.Field, error) { +func makeObjectFieldAllOf(messageName, name string, field *base.Schema) (*genclient.Field, error) { for _, proxy := range field.AllOf { typezID := strings.TrimPrefix(proxy.GetReference(), "#/components/schemas/") return &genclient.Field{ diff --git a/generator/internal/genclient/translator/openapi/openapi_test.go b/generator/internal/genclient/translator/openapi/openapi_test.go index 56d56adb..cf6420b6 100644 --- a/generator/internal/genclient/translator/openapi/openapi_test.go +++ b/generator/internal/genclient/translator/openapi/openapi_test.go @@ -52,18 +52,13 @@ func TestAllOf(t *testing.T) { }, ` contents := []byte(singleMessagePreamble + messageWithAllOf + singleMessageTrailer) - translator, err := NewTranslator(contents, &Options{ - Language: "not used", - OutDir: "not used", - TemplateDir: "not used", - }) + model, err := createDocModel(contents) if err != nil { - t.Errorf("Error in NewTranslator() %q", err) + t.Fatal(err) } - - api, err := translator.makeAPI() + api, err := makeAPI(model) if err != nil { - t.Errorf("Error in makeAPI() %q", err) + t.Fatalf("Error in makeAPI() %q", err) } message := api.State.MessageByID["..Automatic"] @@ -119,18 +114,14 @@ func TestBasicTypes(t *testing.T) { }, ` contents := []byte(singleMessagePreamble + messageWithBasicTypes + singleMessageTrailer) - translator, err := NewTranslator(contents, &Options{ - Language: "not used", - OutDir: "not used", - TemplateDir: "not used", - }) + model, err := createDocModel(contents) if err != nil { - t.Errorf("Error in NewTranslator() %q", err) + t.Fatal(err) } - api, err := translator.makeAPI() + api, err := makeAPI(model) if err != nil { - t.Errorf("Error in makeAPI() %q", err) + t.Fatalf("Error in makeAPI() %q", err) } message := api.State.MessageByID["..Fake"] @@ -183,18 +174,13 @@ func TestArrayTypes(t *testing.T) { }, ` contents := []byte(singleMessagePreamble + messageWithBasicTypes + singleMessageTrailer) - translator, err := NewTranslator(contents, &Options{ - Language: "not used", - OutDir: "not used", - TemplateDir: "not used", - }) + model, err := createDocModel(contents) if err != nil { - t.Errorf("Error in NewTranslator() %q", err) + t.Fatal(err) } - - api, err := translator.makeAPI() + api, err := makeAPI(model) if err != nil { - t.Errorf("Error in makeAPI() %q", err) + t.Fatalf("Error in makeAPI() %q", err) } message := api.State.MessageByID["..Fake"] @@ -224,18 +210,13 @@ func TestArrayTypes(t *testing.T) { func TestMakeAPI(t *testing.T) { contents := []byte(testDocument) - translator, err := NewTranslator(contents, &Options{ - Language: "rust", - OutDir: "not used", - TemplateDir: "not used", - }) + model, err := createDocModel(contents) if err != nil { - t.Errorf("Error in NewTranslator() %q", err) + t.Fatal(err) } - - api, err := translator.makeAPI() + api, err := makeAPI(model) if err != nil { - t.Errorf("Error in makeAPI() %q", err) + t.Fatalf("Error in makeAPI() %q", err) } location := api.State.MessageByID["..Location"]