Skip to content

Commit

Permalink
impl: simplify openapi translator
Browse files Browse the repository at this point in the history
  • Loading branch information
julieqiu committed Nov 4, 2024
1 parent 382af5f commit 1d1bcbc
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 138 deletions.
6 changes: 1 addition & 5 deletions generator/cmd/openapi/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,14 @@ func run(inputPath, language, outDir, templateDir string) error {
}

func generateFrom(contents []byte, language, outDir, templateDir string) error {
translator, err := openapi.NewTranslator(contents, &openapi.Options{
req, err := openapi.Translate(contents, &openapi.Options{
Language: language,
OutDir: outDir,
TemplateDir: templateDir,
})
if err != nil {
return err
}
req, err := translator.Translate()
if err != nil {
return err
}
if _, err := genclient.Generate(req); err != nil {
return err
}
Expand Down
127 changes: 49 additions & 78 deletions generator/internal/genclient/translator/openapi/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,41 @@ import (
v3 "github.com/pb33f/libopenapi/datamodel/high/v3"
)

type Translator struct {
model *libopenapi.DocumentModel[v3.Document]
language string

// State by FQN
state *genclient.APIState

// Only used for local testing
outDir string
templateDir string
}

type Options struct {
Language string
// Only used for local testing
OutDir string
TemplateDir string
}

func NewTranslator(contents []byte, opts *Options) (*Translator, error) {
func Translate(contents []byte, opts *Options) (*genclient.GenerateRequest, error) {
model, err := createDocModel(contents)
if err != nil {
return nil, err
}
// Translates OpenAPI specification into a [genclient.GenerateRequest].
api, err := makeAPI(model)
if err != nil {
return nil, err
}
codec, err := language.NewCodec(opts.Language)
if err != nil {
return nil, err
}
api.State = &genclient.APIState{
ServiceByID: make(map[string]*genclient.Service),
MessageByID: make(map[string]*genclient.Message),
EnumByID: make(map[string]*genclient.Enum),
}
return &genclient.GenerateRequest{
API: api,
Codec: codec,
OutDir: opts.OutDir,
TemplateDir: opts.TemplateDir,
}, nil
}

func createDocModel(contents []byte) (*libopenapi.DocumentModel[v3.Document], error) {
document, err := libopenapi.NewDocument(contents)
if err != nil {
return nil, err
Expand All @@ -58,31 +73,20 @@ func NewTranslator(contents []byte, opts *Options) (*Translator, error) {
}
return nil, fmt.Errorf("cannot convert document to OpenAPI V3 model: %e", errors[0])
}

return &Translator{
model: docModel,
outDir: opts.OutDir,
language: opts.Language,
templateDir: opts.TemplateDir,
state: &genclient.APIState{
ServiceByID: make(map[string]*genclient.Service),
MessageByID: make(map[string]*genclient.Message),
EnumByID: make(map[string]*genclient.Enum),
},
}, nil
return docModel, nil
}

func (t *Translator) makeAPI() (*genclient.API, error) {
func makeAPI(model *libopenapi.DocumentModel[v3.Document]) (*genclient.API, error) {
api := &genclient.API{
Name: t.model.Model.Info.Title,
Name: model.Model.Info.Title,
Messages: make([]*genclient.Message, 0),
}
for name, msg := range t.model.Model.Components.Schemas.FromOldest() {
for name, msg := range model.Model.Components.Schemas.FromOldest() {
schema, err := msg.BuildSchema()
if err != nil {
return nil, err
}
fields, err := t.makeMessageFields(name, schema)
fields, err := makeMessageFields(name, schema)
if err != nil {
return nil, err
}
Expand All @@ -91,33 +95,12 @@ func (t *Translator) makeAPI() (*genclient.API, error) {
Documentation: msg.Schema().Description,
Fields: fields,
}

api.Messages = append(api.Messages, &message)
}
return api, nil
}

// Translates OpenAPI specification into a [genclient.GenerateRequest].
func (t *Translator) Translate() (*genclient.GenerateRequest, error) {
api, err := t.makeAPI()
if err != nil {
return nil, err
}

codec, err := language.NewCodec(t.language)
if err != nil {
return nil, err
}
api.State = t.state
return &genclient.GenerateRequest{
API: api,
Codec: codec,
OutDir: t.outDir,
TemplateDir: t.templateDir,
}, nil
}

func (t *Translator) makeMessageFields(messageName string, message *base.Schema) ([]*genclient.Field, error) {
func makeMessageFields(messageName string, message *base.Schema) ([]*genclient.Field, error) {
var fields []*genclient.Field
for name, f := range message.Properties.FromOldest() {
schema, err := f.BuildSchema()
Expand All @@ -131,7 +114,7 @@ func (t *Translator) makeMessageFields(messageName string, message *base.Schema)
break
}
}
field, err := t.makeField(messageName, name, optional, schema)
field, err := makeField(messageName, name, optional, schema)
if err != nil {
return nil, err
}
Expand All @@ -140,33 +123,27 @@ func (t *Translator) makeMessageFields(messageName string, message *base.Schema)
return fields, nil
}

func (t *Translator) makeField(messageName, name string, optional bool, field *base.Schema) (*genclient.Field, error) {
func makeField(messageName, name string, optional bool, field *base.Schema) (*genclient.Field, error) {
if len(field.AllOf) != 0 {
// Simple object fields name an AllOf attribute, but no `Type` attribute.
return t.makeObjectField(messageName, name, field)
return makeObjectField(messageName, name, field)
}
if len(field.Type) == 0 {
return nil, fmt.Errorf("missing field type for field %s.%s", messageName, name)
}
switch field.Type[0] {
case "boolean":
return t.makeScalarField(messageName, name, field, optional, field)
case "integer":
return t.makeScalarField(messageName, name, field, optional, field)
case "number":
return t.makeScalarField(messageName, name, field, optional, field)
case "string":
return t.makeScalarField(messageName, name, field, optional, field)
case "boolean", "integer", "number", "string":
return makeScalarField(messageName, name, field, optional, field)
case "object":
return t.makeObjectField(messageName, name, field)
return makeObjectField(messageName, name, field)
case "array":
return t.makeArrayField(messageName, name, field)
return makeArrayField(messageName, name, field)
default:
return nil, fmt.Errorf("unknown type for field %q", name)
}
}

func (t *Translator) makeScalarField(messageName, name string, schema *base.Schema, optional bool, field *base.Schema) (*genclient.Field, error) {
func makeScalarField(messageName, name string, schema *base.Schema, optional bool, field *base.Schema) (*genclient.Field, error) {
typez, typezID, err := scalarType(messageName, name, schema)
if err != nil {
return nil, err
Expand All @@ -180,9 +157,9 @@ func (t *Translator) makeScalarField(messageName, name string, schema *base.Sche
}, nil
}

func (t *Translator) makeObjectField(messageName, name string, field *base.Schema) (*genclient.Field, error) {
func makeObjectField(messageName, name string, field *base.Schema) (*genclient.Field, error) {
if len(field.AllOf) != 0 {
return t.makeObjectFieldAllOf(messageName, name, field)
return makeObjectFieldAllOf(messageName, name, field)
}
// TODO(#62) - this is an Any or a map<string, T>, needs a TypezID
return &genclient.Field{
Expand All @@ -193,7 +170,7 @@ func (t *Translator) makeObjectField(messageName, name string, field *base.Schem
}, nil
}

func (t *Translator) makeArrayField(messageName, name string, field *base.Schema) (*genclient.Field, error) {
func makeArrayField(messageName, name string, field *base.Schema) (*genclient.Field, error) {
if !field.Items.IsA() {
return nil, fmt.Errorf("cannot handle arrays without an `Items` field for %s.%s", messageName, name)
}
Expand All @@ -206,16 +183,10 @@ func (t *Translator) makeArrayField(messageName, name string, field *base.Schema
}
var result *genclient.Field
switch schema.Type[0] {
case "boolean":
result, err = t.makeScalarField(messageName, name, schema, false, field)
case "integer":
result, err = t.makeScalarField(messageName, name, schema, false, field)
case "number":
result, err = t.makeScalarField(messageName, name, schema, false, field)
case "string":
result, err = t.makeScalarField(messageName, name, schema, false, field)
case "boolean", "integer", "number", "string":
result, err = makeScalarField(messageName, name, schema, false, field)
case "object":
result, err = t.makeObjectField(messageName, name, field)
result, err = makeObjectField(messageName, name, field)
default:
return nil, fmt.Errorf("unknown array field type for %s.%s %q", messageName, name, schema.Type[0])
}
Expand All @@ -227,7 +198,7 @@ func (t *Translator) makeArrayField(messageName, name string, field *base.Schema
return result, nil
}

func (t *Translator) makeObjectFieldAllOf(messageName, name string, field *base.Schema) (*genclient.Field, error) {
func makeObjectFieldAllOf(messageName, name string, field *base.Schema) (*genclient.Field, error) {
for _, proxy := range field.AllOf {
typezID := strings.TrimPrefix(proxy.GetReference(), "#/components/schemas/")
return &genclient.Field{
Expand Down
51 changes: 16 additions & 35 deletions generator/internal/genclient/translator/openapi/openapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,13 @@ func TestAllOf(t *testing.T) {
},
`
contents := []byte(singleMessagePreamble + messageWithAllOf + singleMessageTrailer)
translator, err := NewTranslator(contents, &Options{
Language: "not used",
OutDir: "not used",
TemplateDir: "not used",
})
model, err := createDocModel(contents)
if err != nil {
t.Errorf("Error in NewTranslator() %q", err)
t.Fatal(err)
}

api, err := translator.makeAPI()
api, err := makeAPI(model)
if err != nil {
t.Errorf("Error in makeAPI() %q", err)
t.Fatalf("Error in makeAPI() %q", err)
}

checkMessage(t, *api.Messages[0], genclient.Message{
Expand Down Expand Up @@ -113,18 +108,14 @@ func TestBasicTypes(t *testing.T) {
},
`
contents := []byte(singleMessagePreamble + messageWithBasicTypes + singleMessageTrailer)
translator, err := NewTranslator(contents, &Options{
Language: "not used",
OutDir: "not used",
TemplateDir: "not used",
})
model, err := createDocModel(contents)
if err != nil {
t.Errorf("Error in NewTranslator() %q", err)
t.Fatal(err)
}

api, err := translator.makeAPI()
api, err := makeAPI(model)
if err != nil {
t.Errorf("Error in makeAPI() %q", err)
t.Fatalf("Error in makeAPI() %q", err)
}

checkMessage(t, *api.Messages[0], genclient.Message{
Expand Down Expand Up @@ -171,18 +162,13 @@ func TestArrayTypes(t *testing.T) {
},
`
contents := []byte(singleMessagePreamble + messageWithBasicTypes + singleMessageTrailer)
translator, err := NewTranslator(contents, &Options{
Language: "not used",
OutDir: "not used",
TemplateDir: "not used",
})
model, err := createDocModel(contents)
if err != nil {
t.Errorf("Error in NewTranslator() %q", err)
t.Fatal(err)
}

api, err := translator.makeAPI()
api, err := makeAPI(model)
if err != nil {
t.Errorf("Error in makeAPI() %q", err)
t.Fatalf("Error in makeAPI() %q", err)
}

checkMessage(t, *api.Messages[0], genclient.Message{
Expand All @@ -206,18 +192,13 @@ func TestArrayTypes(t *testing.T) {

func TestMakeAPI(t *testing.T) {
contents := []byte(testDocument)
translator, err := NewTranslator(contents, &Options{
Language: "rust",
OutDir: "not used",
TemplateDir: "not used",
})
model, err := createDocModel(contents)
if err != nil {
t.Errorf("Error in NewTranslator() %q", err)
t.Fatal(err)
}

api, err := translator.makeAPI()
api, err := makeAPI(model)
if err != nil {
t.Errorf("Error in makeAPI() %q", err)
t.Fatalf("Error in makeAPI() %q", err)
}

checkMessage(t, *api.Messages[0], genclient.Message{
Expand Down
Loading

0 comments on commit 1d1bcbc

Please sign in to comment.