From 1a0c19b61fce9344e3b2d6bbdfe02bcf1eb16923 Mon Sep 17 00:00:00 2001 From: Carlos O'Ryan Date: Wed, 22 Jan 2025 14:55:35 -0500 Subject: [PATCH 1/2] feat(generator): support self-referential fields Message fields may refer to the same message, directly or indirectly. With this change the fields gain a `Recursive` attribute, set to `true` if the field directly (or indirectly) references the containing message. The Rust Codec uses this attribute to emit `Option>` instead of `Option` for such fields. In the process I found some mistakes in the handling of map types with objects in them. --- generator/internal/api/model.go | 4 + generator/internal/api/recursive.go | 50 ++++ generator/internal/api/recursive_test.go | 256 ++++++++++++++++++ .../{language/api_test.go => api/test.go} | 22 +- generator/internal/language/codec_test.go | 4 +- generator/internal/language/golang_test.go | 10 +- .../internal/language/gotemplate_test.go | 2 +- generator/internal/language/rust.go | 38 ++- generator/internal/language/rust_test.go | 254 +++++++++++++---- .../internal/language/rusttemplate_test.go | 4 +- generator/internal/sidekick/refresh.go | 1 + src/generated/api/src/model.rs | 11 +- .../cloud/translate/v3/src/builders.rs | 4 +- src/generated/cloud/translate/v3/src/model.rs | 16 +- .../devtools/cloudtrace/v2/src/model.rs | 11 +- 15 files changed, 582 insertions(+), 105 deletions(-) create mode 100644 generator/internal/api/recursive.go create mode 100644 generator/internal/api/recursive_test.go rename generator/internal/{language/api_test.go => api/test.go} (78%) diff --git a/generator/internal/api/model.go b/generator/internal/api/model.go index bc0c12005..6f996184a 100644 --- a/generator/internal/api/model.go +++ b/generator/internal/api/model.go @@ -289,6 +289,10 @@ type Field struct { // some helper fields. These need to be marked so they can be excluded // from serialized messages and in other places. Synthetic bool + // Some fields have a type that refers (sometimes indirectly) to the + // containing message. That triggers slightly different code generation for + // some languages. + Recursive bool // A placeholder to put language specific annotations. Codec any } diff --git a/generator/internal/api/recursive.go b/generator/internal/api/recursive.go new file mode 100644 index 000000000..a0163264e --- /dev/null +++ b/generator/internal/api/recursive.go @@ -0,0 +1,50 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +func LabelRecursiveFields(model *API) { + for _, message := range model.State.MessageByID { + for _, field := range message.Fields { + visited := map[string]bool{message.ID: true} + field.Recursive = field.recursivelyReferences(message.ID, model, visited) + } + } +} + +func (field *Field) recursivelyReferences(messageID string, model *API, visited map[string]bool) bool { + if field.Typez != MESSAGE_TYPE { + return false + } + if field.TypezID == messageID { + return true + } + if _, ok := visited[field.TypezID]; ok { + return false + } + if fieldMessage, ok := model.State.MessageByID[field.TypezID]; ok { + return fieldMessage.recursivelyReferences(messageID, model, visited) + } + return false +} + +func (message *Message) recursivelyReferences(messageID string, model *API, visited map[string]bool) bool { + visited[message.ID] = true + for _, field := range message.Fields { + if field.recursivelyReferences(messageID, model, visited) { + return true + } + } + return false +} diff --git a/generator/internal/api/recursive_test.go b/generator/internal/api/recursive_test.go new file mode 100644 index 000000000..868242b55 --- /dev/null +++ b/generator/internal/api/recursive_test.go @@ -0,0 +1,256 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "testing" +) + +func TestSimple(t *testing.T) { + field0 := &Field{ + Name: "a", + Typez: STRING_TYPE, + } + field1 := &Field{ + Name: "b", + Typez: MESSAGE_TYPE, + TypezID: ".test.Message", + Optional: true, + } + messages := []*Message{ + { + Name: "Message", + ID: ".test.Message", + Fields: []*Field{ + field0, field1, + }, + }, + } + model := NewTestAPI(messages, []*Enum{}, []*Service{}) + LabelRecursiveFields(model) + if field0.Recursive { + t.Errorf("mismatched IsRecursive field for %v", field0) + } + if !field1.Recursive { + t.Errorf("mismatched IsRecursive field for %v", field1) + } +} + +func TestSimpleMap(t *testing.T) { + field0 := &Field{ + Repeated: false, + Optional: false, + Name: "children", + ID: ".test.ParentMessage.children", + Typez: MESSAGE_TYPE, + TypezID: ".test.ParentMessage.SingularMapEntry", + } + parent := &Message{ + Name: "ParentMessage", + ID: ".test.ParentMessage", + Fields: []*Field{field0}, + } + + key := &Field{ + Name: "key", + JSONName: "key", + ID: ".test.ParentMessage.SingularMapEntry.key", + Typez: STRING_TYPE, + } + value := &Field{ + Name: "value", + JSONName: "value", + ID: ".test.ParentMessage.SingularMapEntry.value", + Typez: MESSAGE_TYPE, + TypezID: ".test.ParentMessage", + } + map_message := &Message{ + Name: "SingularMapEntry", + Package: "test", + ID: ".test.ParentMessage.SingularMapEntry", + IsMap: true, + Fields: []*Field{key, value}, + } + + model := NewTestAPI([]*Message{parent, map_message}, []*Enum{}, []*Service{}) + LabelRecursiveFields(model) + for _, field := range []*Field{value, field0} { + if !field.Recursive { + t.Errorf("expected IsRecursive to be true for field %s", field.ID) + } + } + if key.Recursive { + t.Errorf("expected IsRecursive to be false for field %s", key.ID) + } +} + +func TestIndirect(t *testing.T) { + field0 := &Field{ + Name: "child", + Typez: MESSAGE_TYPE, + TypezID: ".test.ChildMessage", + Optional: true, + } + field1 := &Field{ + Name: "grand_child", + Typez: MESSAGE_TYPE, + TypezID: ".test.GrandChildMessage", + Optional: true, + } + field2 := &Field{ + Name: "back_to_grand_parent", + Typez: MESSAGE_TYPE, + TypezID: ".test.Message", + Optional: true, + } + messages := []*Message{ + { + Name: "Message", + ID: ".test.Message", + Fields: []*Field{field0}, + }, + { + Name: "ChildMessage", + ID: ".test.ChildMessage", + Fields: []*Field{field1}, + }, + { + Name: "GrandChildMessage", + ID: ".test.GrandChildMessage", + Fields: []*Field{field2}, + }, + } + model := NewTestAPI(messages, []*Enum{}, []*Service{}) + LabelRecursiveFields(model) + for _, field := range []*Field{field0, field1, field2} { + if !field.Recursive { + t.Errorf("IsRecursive should be true for field %s", field.Name) + } + } +} + +func TestViaMap(t *testing.T) { + field0 := &Field{ + Name: "parent", + ID: ".test.ChildMessage.parent", + Typez: MESSAGE_TYPE, + TypezID: ".test.ParentMessage", + } + child := &Message{ + Name: "ChildMessage", + ID: ".test.ChildMessage", + Fields: []*Field{field0}, + } + + field1 := &Field{ + Repeated: false, + Optional: false, + Name: "children", + ID: ".test.ParentMessage.children", + Typez: MESSAGE_TYPE, + TypezID: ".test.ParentMessage.SingularMapEntry", + } + parent := &Message{ + Name: "ParentMessage", + ID: ".test.ParentMessage", + Fields: []*Field{field1}, + } + + key := &Field{ + Repeated: false, + Optional: false, + Name: "key", + JSONName: "key", + ID: ".test.ParentMessage.SingularMapEntry.key", + Typez: STRING_TYPE, + } + value := &Field{ + Repeated: false, + Optional: false, + Name: "value", + JSONName: "value", + ID: ".test.ParentMessage.SingularMapEntry.value", + Typez: MESSAGE_TYPE, + TypezID: ".test.ChildMessage", + } + map_message := &Message{ + Name: "SingularMapEntry", + Package: "test", + ID: ".test.ParentMessage.SingularMapEntry", + IsMap: true, + Fields: []*Field{key, value}, + } + + model := NewTestAPI([]*Message{parent, child, map_message}, []*Enum{}, []*Service{}) + LabelRecursiveFields(model) + for _, field := range []*Field{value, field0, field1} { + if !field.Recursive { + t.Errorf("expected IsRecursive to be true for field %s", field.ID) + } + } + if key.Recursive { + t.Errorf("expected IsRecursive to be false for field %s", key.ID) + } +} + +func TestReferencedCycle(t *testing.T) { + field0 := &Field{ + Name: "parent", + ID: ".test.ChildMessage.parent", + Typez: MESSAGE_TYPE, + TypezID: ".test.ParentMessage", + } + child := &Message{ + Name: "ChildMessage", + ID: ".test.ChildMessage", + Fields: []*Field{field0}, + } + field1 := &Field{ + Name: "child", + ID: ".test.ParentMessage.child", + Typez: MESSAGE_TYPE, + TypezID: ".test.ChildMessage", + } + parent := &Message{ + Name: "ParentdMessage", + ID: ".test.ParentMessage", + Fields: []*Field{field1}, + } + + field2 := &Field{ + Name: "ref", + ID: ".test.Holder.ref", + Typez: MESSAGE_TYPE, + TypezID: ".test.ParentMessage", + } + holder := &Message{ + Name: "Holder", + ID: ".test.Holder", + Fields: []*Field{field2}, + } + + model := NewTestAPI([]*Message{holder, parent, child}, []*Enum{}, []*Service{}) + LabelRecursiveFields(model) + for _, field := range []*Field{field0, field1} { + if !field.Recursive { + t.Errorf("expected IsRecursive to be true for field %s", field.ID) + } + } + for _, field := range []*Field{field2} { + if field.Recursive { + t.Errorf("expected IsRecursive to be false for field %s", field.ID) + } + } +} diff --git a/generator/internal/language/api_test.go b/generator/internal/api/test.go similarity index 78% rename from generator/internal/language/api_test.go rename to generator/internal/api/test.go index b6dc377d3..5647f4bfe 100644 --- a/generator/internal/language/api_test.go +++ b/generator/internal/api/test.go @@ -12,20 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -package language +package api -import ( - "strings" +import "strings" - "github.com/googleapis/google-cloud-rust/generator/internal/api" -) - -func newTestAPI(messages []*api.Message, enums []*api.Enum, services []*api.Service) *api.API { - state := &api.APIState{ - MessageByID: make(map[string]*api.Message), - MethodByID: make(map[string]*api.Method), - EnumByID: make(map[string]*api.Enum), - ServiceByID: make(map[string]*api.Service), +func NewTestAPI(messages []*Message, enums []*Enum, services []*Service) *API { + state := &APIState{ + MessageByID: make(map[string]*Message), + MethodByID: make(map[string]*Method), + EnumByID: make(map[string]*Enum), + ServiceByID: make(map[string]*Service), } for _, m := range messages { state.MessageByID[m.ID] = m @@ -58,7 +54,7 @@ func newTestAPI(messages []*api.Message, enums []*api.Enum, services []*api.Serv } } - return &api.API{ + return &API{ Name: "Test", Messages: messages, Enums: enums, diff --git a/generator/internal/language/codec_test.go b/generator/internal/language/codec_test.go index 0c042fb97..ca4badef6 100644 --- a/generator/internal/language/codec_test.go +++ b/generator/internal/language/codec_test.go @@ -64,7 +64,7 @@ func TestQueryParams(t *testing.T) { }, }, } - test := newTestAPI( + test := api.NewTestAPI( []*api.Message{options, request}, []*api.Enum{}, []*api.Service{ @@ -84,7 +84,7 @@ func TestQueryParams(t *testing.T) { } func TestPathParams(t *testing.T) { - test := newTestAPI( + test := api.NewTestAPI( []*api.Message{sample.Secret(), sample.UpdateRequest(), sample.CreateRequest()}, []*api.Enum{}, []*api.Service{sample.Service()}, diff --git a/generator/internal/language/golang_test.go b/generator/internal/language/golang_test.go index 98122c783..b83fa3ebd 100644 --- a/generator/internal/language/golang_test.go +++ b/generator/internal/language/golang_test.go @@ -78,7 +78,7 @@ func TestGo_EnumNames(t *testing.T) { ID: "..SecretVersion.State", } - _ = newTestAPI([]*api.Message{message}, []*api.Enum{nested}, []*api.Service{}) + _ = api.NewTestAPI([]*api.Message{message}, []*api.Enum{nested}, []*api.Service{}) if got := goEnumName(nested, nil); got != "SecretVersion_State" { t.Errorf("mismatched message name, want=SecretVersion_Automatic, got=%s", got) } @@ -130,7 +130,7 @@ Maybe they wanted to show some JSON: } func TestGo_Validate(t *testing.T) { - api := newTestAPI( + api := api.NewTestAPI( []*api.Message{{Name: "m1", Package: "p1"}}, []*api.Enum{{Name: "e1", Package: "p1"}}, []*api.Service{{Name: "s1", Package: "p1"}}) @@ -141,7 +141,7 @@ func TestGo_Validate(t *testing.T) { func TestGo_ValidateMessageMismatch(t *testing.T) { const sourceSpecificationPackageName = "p1" - test := newTestAPI( + test := api.NewTestAPI( []*api.Message{{Name: "m1", Package: "p1"}, {Name: "m2", Package: "p2"}}, []*api.Enum{{Name: "e1", Package: "p1"}}, []*api.Service{{Name: "s1", Package: "p1"}}) @@ -149,7 +149,7 @@ func TestGo_ValidateMessageMismatch(t *testing.T) { t.Errorf("expected an error in API validation got=%s", sourceSpecificationPackageName) } - test = newTestAPI( + test = api.NewTestAPI( []*api.Message{{Name: "m1", Package: "p1"}}, []*api.Enum{{Name: "e1", Package: "p1"}, {Name: "e2", Package: "p2"}}, []*api.Service{{Name: "s1", Package: "p1"}}) @@ -157,7 +157,7 @@ func TestGo_ValidateMessageMismatch(t *testing.T) { t.Errorf("expected an error in API validation got=%s", sourceSpecificationPackageName) } - test = newTestAPI( + test = api.NewTestAPI( []*api.Message{{Name: "m1", Package: "p1"}}, []*api.Enum{{Name: "e1", Package: "p1"}}, []*api.Service{{Name: "s1", Package: "p1"}, {Name: "s2", Package: "p2"}}) diff --git a/generator/internal/language/gotemplate_test.go b/generator/internal/language/gotemplate_test.go index 1e29d10bd..2dd03bcc4 100644 --- a/generator/internal/language/gotemplate_test.go +++ b/generator/internal/language/gotemplate_test.go @@ -45,7 +45,7 @@ func Test_GoEnumAnnotations(t *testing.T) { Values: []*api.EnumValue{v0, v1, v2}, } - model := newTestAPI( + model := api.NewTestAPI( []*api.Message{}, []*api.Enum{enum}, []*api.Service{}) _, err := newGoTemplateData(model, map[string]string{}) if err != nil { diff --git a/generator/internal/language/rust.go b/generator/internal/language/rust.go index 886eca34a..537f61679 100644 --- a/generator/internal/language/rust.go +++ b/generator/internal/language/rust.go @@ -506,12 +506,45 @@ func rustFieldType(f *api.Field, state *api.APIState, primitive bool, modulePath if !primitive && f.Repeated { return fmt.Sprintf("std::vec::Vec<%s>", rustBaseFieldType(f, state, modulePath, sourceSpecificationPackageName, packageMapping)) } + if !primitive && f.Recursive { + base := rustBaseFieldType(f, state, modulePath, sourceSpecificationPackageName, packageMapping) + if f.Optional { + return fmt.Sprintf("std::option::Option>", base) + } + if _, ok := state.MessageByID[f.TypezID]; ok && f.Typez == api.MESSAGE_TYPE { + // Maps are never boxed. + return base + } + return fmt.Sprintf("std::boxed::Box<%s>", base) + } if !primitive && f.Optional { return fmt.Sprintf("std::option::Option<%s>", rustBaseFieldType(f, state, modulePath, sourceSpecificationPackageName, packageMapping)) } return rustBaseFieldType(f, state, modulePath, sourceSpecificationPackageName, packageMapping) } +func rustMapType(f *api.Field, state *api.APIState, modulePath, sourceSpecificationPackageName string, packageMapping map[string]*rustPackage) string { + switch f.Typez { + case api.MESSAGE_TYPE: + m, ok := state.MessageByID[f.TypezID] + if !ok { + slog.Error("unable to lookup type", "id", f.TypezID) + return "" + } + return rustFQMessageName(m, modulePath, sourceSpecificationPackageName, packageMapping) + + case api.ENUM_TYPE: + e, ok := state.EnumByID[f.TypezID] + if !ok { + slog.Error("unable to lookup type", "id", f.TypezID) + return "" + } + return rustFQEnumName(e, modulePath, sourceSpecificationPackageName, packageMapping) + default: + return scalarFieldType(f) + } +} + // Returns the field type, ignoring any repeated or optional attributes. func rustBaseFieldType(f *api.Field, state *api.APIState, modulePath, sourceSpecificationPackageName string, packageMapping map[string]*rustPackage) string { if f.Typez == api.MESSAGE_TYPE { @@ -521,8 +554,8 @@ func rustBaseFieldType(f *api.Field, state *api.APIState, modulePath, sourceSpec return "" } if m.IsMap { - key := rustFieldType(m.Fields[0], state, false, modulePath, sourceSpecificationPackageName, packageMapping) - val := rustFieldType(m.Fields[1], state, false, modulePath, sourceSpecificationPackageName, packageMapping) + key := rustMapType(m.Fields[0], state, modulePath, sourceSpecificationPackageName, packageMapping) + val := rustMapType(m.Fields[1], state, modulePath, sourceSpecificationPackageName, packageMapping) return "std::collections::HashMap<" + key + "," + val + ">" } return rustFQMessageName(m, modulePath, sourceSpecificationPackageName, packageMapping) @@ -538,7 +571,6 @@ func rustBaseFieldType(f *api.Field, state *api.APIState, modulePath, sourceSpec return "" } return scalarFieldType(f) - } func rustAddQueryParameter(f *api.Field) string { diff --git a/generator/internal/language/rust_test.go b/generator/internal/language/rust_test.go index 09ea28b11..e672f188f 100644 --- a/generator/internal/language/rust_test.go +++ b/generator/internal/language/rust_test.go @@ -28,7 +28,7 @@ import ( func createRustCodec() *rustCodec { wkt := &rustPackage{ - name: "gax_wkt", + name: "wkt", packageName: "types", path: "../../types", } @@ -187,7 +187,7 @@ func checkRustPackages(t *testing.T, got *rustCodec, want *rustCodec) { } func TestRust_Validate(t *testing.T) { - model := newTestAPI( + model := api.NewTestAPI( []*api.Message{{Name: "m1", Package: "p1"}}, []*api.Enum{{Name: "e1", Package: "p1"}}, []*api.Service{{Name: "s1", Package: "p1"}}) @@ -197,7 +197,7 @@ func TestRust_Validate(t *testing.T) { } func TestRust_ValidateMessageMismatch(t *testing.T) { - test := newTestAPI( + test := api.NewTestAPI( []*api.Message{{Name: "m1", Package: "p1"}, {Name: "m2", Package: "p2"}}, []*api.Enum{{Name: "e1", Package: "p1"}}, []*api.Service{{Name: "s1", Package: "p1"}}) @@ -206,7 +206,7 @@ func TestRust_ValidateMessageMismatch(t *testing.T) { t.Errorf("expected an error in API validation got=%s", c.sourceSpecificationPackageName) } - test = newTestAPI( + test = api.NewTestAPI( []*api.Message{{Name: "m1", Package: "p1"}}, []*api.Enum{{Name: "e1", Package: "p1"}, {Name: "e2", Package: "p2"}}, []*api.Service{{Name: "s1", Package: "p1"}}) @@ -215,7 +215,7 @@ func TestRust_ValidateMessageMismatch(t *testing.T) { t.Errorf("expected an error in API validation got=%s", c.sourceSpecificationPackageName) } - test = newTestAPI( + test = api.NewTestAPI( []*api.Message{{Name: "m1", Package: "p1"}}, []*api.Enum{{Name: "e1", Package: "p1"}}, []*api.Service{{Name: "s1", Package: "p1"}, {Name: "s2", Package: "p2"}}) @@ -226,7 +226,7 @@ func TestRust_ValidateMessageMismatch(t *testing.T) { } func TestWellKnownTypesExist(t *testing.T) { - model := newTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) rustLoadWellKnownTypes(model.State) for _, name := range []string{"Any", "Duration", "Empty", "FieldMask", "Timestamp"} { if _, ok := model.State.MessageByID[fmt.Sprintf(".google.protobuf.%s", name)]; !ok { @@ -240,7 +240,7 @@ func TestUsedByServicesWithServices(t *testing.T) { Name: "TestService", ID: ".test.Service", } - model := newTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{service}) + model := api.NewTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{service}) c, err := newRustCodec(map[string]string{ "package:tracing": "used-if=services,package=tracing,version=0.1.41", "package:location": "package=gcp-sdk-location,source=google.cloud.location,path=src/generated/cloud/location,version=0.1.0", @@ -274,7 +274,7 @@ func TestUsedByServicesWithServices(t *testing.T) { } func TestUsedByServicesNoServices(t *testing.T) { - model := newTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) c, err := newRustCodec(map[string]string{ "package:tracing": "used-if=services,package=tracing,version=0.1.41", "package:location": "package=gcp-sdk-location,source=google.cloud.location,path=src/generated/cloud/location,version=0.1.0", @@ -316,7 +316,7 @@ func TestUsedByLROsWithLRO(t *testing.T) { ID: ".test.Service", Methods: []*api.Method{method}, } - model := newTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{service}) + model := api.NewTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{service}) c, err := newRustCodec(map[string]string{ "package:location": "package=gcp-sdk-location,source=google.cloud.location,path=src/generated/cloud/location,version=0.1.0", "package:lro": "used-if=lro,package=gcp-sdk-lro,path=src/lro,version=0.1.0", @@ -359,7 +359,7 @@ func TestUsedByLROsWithoutLRO(t *testing.T) { ID: ".test.Service", Methods: []*api.Method{method}, } - model := newTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{service}) + model := api.NewTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{service}) c, err := newRustCodec(map[string]string{ "package:location": "package=gcp-sdk-location,source=google.cloud.location,path=src/generated/cloud/location,version=0.1.0", "package:lro": "used-if=lro,package=gcp-sdk-lro,path=src/lro,version=0.1.0", @@ -397,7 +397,7 @@ func TestRust_NoStreamingFeature(t *testing.T) { codec := &rustCodec{ extraPackages: []*rustPackage{}, } - model := newTestAPI([]*api.Message{ + model := api.NewTestAPI([]*api.Message{ {Name: "CreateResource", IsPageableResponse: false}, }, []*api.Enum{}, []*api.Service{}) rustLoadWellKnownTypes(model.State) @@ -466,7 +466,7 @@ func TestRust_StreamingFeature(t *testing.T) { func checkRustContext(t *testing.T, codec *rustCodec, wantFeatures string) { t.Helper() - model := newTestAPI([]*api.Message{ + model := api.NewTestAPI([]*api.Message{ {Name: "ListResources", IsPageableResponse: true}, }, []*api.Enum{}, []*api.Service{}) rustLoadWellKnownTypes(model.State) @@ -483,11 +483,11 @@ func checkRustContext(t *testing.T, codec *rustCodec, wantFeatures string) { } func TestRust_WellKnownTypesAsMethod(t *testing.T) { - model := newTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) c := createRustCodec() rustLoadWellKnownTypes(model.State) - want := "gax_wkt::Empty" + want := "wkt::Empty" got := rustMethodInOutTypeName(".google.protobuf.Empty", model.State, c.modulePath, c.sourceSpecificationPackageName, c.packageMapping) if want != got { t.Errorf("mismatched well-known type name as method argument or response, want=%s, got=%s", want, got) @@ -504,7 +504,7 @@ func TestRust_MethodInOut(t *testing.T) { ID: "..Target.Nested", Parent: message, } - model := newTestAPI([]*api.Message{message, nested}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{message, nested}, []*api.Enum{}, []*api.Service{}) c := createRustCodec() rustLoadWellKnownTypes(model.State) @@ -593,7 +593,7 @@ func TestRust_FieldAttributes(t *testing.T) { }, }, } - model := newTestAPI([]*api.Message{message}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{message}, []*api.Enum{}, []*api.Service{}) expectedAttributes := map[string]string{ "f_int64": `#[serde_as(as = "serde_with::DisplayFromStr")]`, @@ -726,7 +726,7 @@ func TestRust_MapFieldAttributes(t *testing.T) { }, }, } - model := newTestAPI([]*api.Message{target, map1, map2, map3, map4, message}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{target, map1, map2, map3, map4, message}, []*api.Enum{}, []*api.Service{}) expectedAttributes := map[string]string{ "target": `#[serde(skip_serializing_if = "std::option::Option::is_none")]`, @@ -798,7 +798,7 @@ func TestRust_WktFieldAttributes(t *testing.T) { }, }, } - model := newTestAPI([]*api.Message{message}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{message}, []*api.Enum{}, []*api.Service{}) expectedAttributes := map[string]string{ "f_int64": `#[serde(skip_serializing_if = "std::option::Option::is_none")]` + "\n" + `#[serde_as(as = "std::option::Option")]`, @@ -844,7 +844,7 @@ func TestRust_FieldLossyName(t *testing.T) { }, }, } - model := newTestAPI([]*api.Message{message}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{message}, []*api.Enum{}, []*api.Service{}) expectedAttributes := map[string]string{ "data": `#[serde(skip_serializing_if = "bytes::Bytes::is_empty")]` + "\n" + @@ -894,7 +894,7 @@ func TestRust_SyntheticField(t *testing.T) { }, }, } - model := newTestAPI([]*api.Message{message}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{message}, []*api.Enum{}, []*api.Service{}) expectedAttributes := map[string]string{ "updateMask": `#[serde(skip_serializing_if = "std::option::Option::is_none")]`, @@ -973,6 +973,22 @@ func TestRust_FieldType(t *testing.T) { Optional: false, Repeated: true, }, + { + Name: "f_msg_recursive", + Typez: api.MESSAGE_TYPE, + TypezID: "..Fake", + Optional: true, + Repeated: false, + Recursive: true, + }, + { + Name: "f_msg_recursive_repeated", + Typez: api.MESSAGE_TYPE, + TypezID: "..Fake", + Optional: false, + Repeated: true, + Recursive: true, + }, { Name: "f_timestamp", Typez: api.MESSAGE_TYPE, @@ -989,31 +1005,35 @@ func TestRust_FieldType(t *testing.T) { }, }, } - model := newTestAPI([]*api.Message{target, message}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{target, message}, []*api.Enum{}, []*api.Service{}) expectedTypes := map[string]string{ - "f_int32": "i32", - "f_int32_optional": "std::option::Option", - "f_int32_repeated": "std::vec::Vec", - "f_string": "std::string::String", - "f_string_optional": "std::option::Option", - "f_string_repeated": "std::vec::Vec", - "f_msg": "std::option::Option", - "f_msg_repeated": "std::vec::Vec", - "f_timestamp": "std::option::Option", - "f_timestamp_repeated": "std::vec::Vec", + "f_int32": "i32", + "f_int32_optional": "std::option::Option", + "f_int32_repeated": "std::vec::Vec", + "f_string": "std::string::String", + "f_string_optional": "std::option::Option", + "f_string_repeated": "std::vec::Vec", + "f_msg": "std::option::Option", + "f_msg_repeated": "std::vec::Vec", + "f_msg_recursive": "std::option::Option>", + "f_msg_recursive_repeated": "std::vec::Vec", + "f_timestamp": "std::option::Option", + "f_timestamp_repeated": "std::vec::Vec", } expectedPrimitiveTypes := map[string]string{ - "f_int32": "i32", - "f_int32_optional": "i32", - "f_int32_repeated": "i32", - "f_string": "std::string::String", - "f_string_optional": "std::string::String", - "f_string_repeated": "std::string::String", - "f_msg": "crate::model::Target", - "f_msg_repeated": "crate::model::Target", - "f_timestamp": "gax_wkt::Timestamp", - "f_timestamp_repeated": "gax_wkt::Timestamp", + "f_int32": "i32", + "f_int32_optional": "i32", + "f_int32_repeated": "i32", + "f_string": "std::string::String", + "f_string_optional": "std::string::String", + "f_string_repeated": "std::string::String", + "f_msg": "crate::model::Target", + "f_msg_repeated": "crate::model::Target", + "f_msg_recursive": "crate::model::Fake", + "f_msg_recursive_repeated": "crate::model::Fake", + "f_timestamp": "wkt::Timestamp", + "f_timestamp_repeated": "wkt::Timestamp", } c := createRustCodec() rustLoadWellKnownTypes(model.State) @@ -1038,6 +1058,138 @@ func TestRust_FieldType(t *testing.T) { } } +// Verify rustBaseFieldType works for map types with different value fields. +func TestRust_FieldMapTypeValues(t *testing.T) { + for _, test := range []struct { + want string + value *api.Field + }{ + { + "std::collections::HashMap", + &api.Field{Typez: api.STRING_TYPE}, + }, + { + "std::collections::HashMap", + &api.Field{Typez: api.INT64_TYPE}, + }, + { + "std::collections::HashMap", + &api.Field{Typez: api.MESSAGE_TYPE, TypezID: ".google.protobuf.Any"}, + }, + { + "std::collections::HashMap", + &api.Field{Typez: api.MESSAGE_TYPE, TypezID: ".test.OtherMessage"}, + }, + { + "std::collections::HashMap", + &api.Field{Typez: api.MESSAGE_TYPE, TypezID: ".test.Message"}, + }, + } { + field := &api.Field{ + Name: "indexed", + ID: ".test.Message.indexed", + Typez: api.MESSAGE_TYPE, + TypezID: ".test.$MapThing", + } + other_message := &api.Message{ + Name: "OtherMessage", + ID: ".test.OtherMessage", + IsMap: true, + Fields: []*api.Field{}, + } + message := &api.Message{ + Name: "Message", + ID: ".test.Message", + IsMap: true, + Fields: []*api.Field{field}, + } + // Complete the value field + value := test.value + value.Name = "value" + value.ID = ".test.$MapThing.value" + key := &api.Field{ + Name: "key", + ID: ".test.$MapThing.key", + Typez: api.INT32_TYPE, + } + map_thing := &api.Message{ + Name: "$MapThing", + ID: ".test.$MapThing", + IsMap: true, + Fields: []*api.Field{key, value}, + } + model := api.NewTestAPI([]*api.Message{message, other_message, map_thing}, []*api.Enum{}, []*api.Service{}) + api.LabelRecursiveFields(model) + c := createRustCodec() + rustLoadWellKnownTypes(model.State) + got := rustFieldType(field, model.State, false, c.modulePath, c.sourceSpecificationPackageName, c.packageMapping) + if got != test.want { + t.Errorf("mismatched field type for %s, got=%s, want=%s", field.Name, got, test.want) + } + } +} + +// Verify rustBaseFieldType works for map types with different key fields. +func TestRust_FieldMapTypeKey(t *testing.T) { + for _, test := range []struct { + want string + key *api.Field + }{ + { + "std::collections::HashMap", + &api.Field{Typez: api.INT32_TYPE}, + }, + { + "std::collections::HashMap", + &api.Field{Typez: api.STRING_TYPE}, + }, + { + "std::collections::HashMap", + &api.Field{Typez: api.ENUM_TYPE, TypezID: ".test.EnumType"}, + }, + } { + field := &api.Field{ + Name: "indexed", + ID: ".test.Message.indexed", + Typez: api.MESSAGE_TYPE, + TypezID: ".test.$MapThing", + } + message := &api.Message{ + Name: "Message", + ID: ".test.Message", + IsMap: true, + Fields: []*api.Field{field}, + } + // Complete the value field + key := test.key + key.Name = "key" + key.ID = ".test.$MapThing.key" + value := &api.Field{ + Name: "value", + ID: ".test.$MapThing.value", + Typez: api.INT64_TYPE, + } + map_thing := &api.Message{ + Name: "$MapThing", + ID: ".test.$MapThing", + IsMap: true, + Fields: []*api.Field{key, value}, + } + enum := &api.Enum{ + Name: "EnumType", + ID: ".test.EnumType", + } + model := api.NewTestAPI([]*api.Message{message, map_thing}, []*api.Enum{enum}, []*api.Service{}) + api.LabelRecursiveFields(model) + c := createRustCodec() + rustLoadWellKnownTypes(model.State) + got := rustFieldType(field, model.State, false, c.modulePath, c.sourceSpecificationPackageName, c.packageMapping) + if got != test.want { + t.Errorf("mismatched field type for %s, got=%s, want=%s", field.Name, got, test.want) + } + } +} + func TestRust_AsQueryParameter(t *testing.T) { options := &api.Message{ Name: "Options", @@ -1110,7 +1262,7 @@ func TestRust_AsQueryParameter(t *testing.T) { requiredFieldMaskField, optionalFieldMaskField, }, } - model := newTestAPI( + model := api.NewTestAPI( []*api.Message{options, request}, []*api.Enum{}, []*api.Service{}) @@ -1234,7 +1386,7 @@ Maybe they wanted to show some JSON: "/// ```", } - model := newTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) c := &rustCodec{} got := rustFormatDocComments(input, model.State, c.modulePath, c.sourceSpecificationPackageName, c.packageMapping) if diff := cmp.Diff(want, got); diff != "" { @@ -1260,7 +1412,7 @@ func TestRust_FormatDocCommentsBullets(t *testing.T) { "/// value in the third email_addresses message.)", } - model := newTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) c := createRustCodec() got := rustFormatDocComments(input, model.State, c.modulePath, c.sourceSpecificationPackageName, c.packageMapping) if diff := cmp.Diff(want, got); diff != "" { @@ -1336,7 +1488,7 @@ block: "/// ```", } - model := newTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) c := &rustCodec{} got := rustFormatDocComments(input, model.State, c.modulePath, c.sourceSpecificationPackageName, c.packageMapping) if diff := cmp.Diff(want, got); diff != "" { @@ -1357,7 +1509,7 @@ func TestRust_FormatDocCommentsImplicitBlockQuoteClosing(t *testing.T) { "/// ```", } - model := newTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) c := &rustCodec{} got := rustFormatDocComments(input, model.State, c.modulePath, c.sourceSpecificationPackageName, c.packageMapping) if diff := cmp.Diff(want, got); diff != "" { @@ -1454,7 +1606,7 @@ Second [example][]. "/// [Third]: https://www.third.com", } - model := newTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{}, []*api.Enum{}, []*api.Service{}) c := &rustCodec{} got := rustFormatDocComments(input, model.State, c.modulePath, c.sourceSpecificationPackageName, c.packageMapping) if diff := cmp.Diff(want, got); diff != "" { @@ -1514,7 +1666,7 @@ func makeApiForRustFormatDocCommentsCrossLinks() *api.API { {Name: "CreateBar", ID: ".test.v1.SomeService.CreateBar"}, }, } - a := newTestAPI( + a := api.NewTestAPI( []*api.Message{someMessage}, []*api.Enum{someEnum}, []*api.Service{someService}) @@ -1597,7 +1749,7 @@ https://cloud.google.com/apis/design/design_patterns#integer_types.` func TestRust_MessageNames(t *testing.T) { r := sample.Replication() a := sample.Automatic() - model := newTestAPI([]*api.Message{r, a}, []*api.Enum{}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{r, a}, []*api.Enum{}, []*api.Service{}) model.PackageName = "test" c := createRustCodec() @@ -1653,7 +1805,7 @@ func TestRust_EnumNames(t *testing.T) { Package: "test", } - model := newTestAPI([]*api.Message{parent}, []*api.Enum{nested, non_nested}, []*api.Service{}) + model := api.NewTestAPI([]*api.Message{parent}, []*api.Enum{nested, non_nested}, []*api.Service{}) model.PackageName = "test" c := createRustCodec() c.sourceSpecificationPackageName = model.Messages[0].Package @@ -1754,7 +1906,7 @@ func Test_RustPathArgs(t *testing.T) { ID: ".test.Service", Methods: []*api.Method{method}, } - model := newTestAPI([]*api.Message{subMessage, message}, []*api.Enum{}, []*api.Service{service}) + model := api.NewTestAPI([]*api.Message{subMessage, message}, []*api.Enum{}, []*api.Service{service}) for _, test := range []struct { want []string diff --git a/generator/internal/language/rusttemplate_test.go b/generator/internal/language/rusttemplate_test.go index dd677d2c9..ff4c06260 100644 --- a/generator/internal/language/rusttemplate_test.go +++ b/generator/internal/language/rusttemplate_test.go @@ -22,7 +22,7 @@ import ( ) func TestPackageNames(t *testing.T) { - model := newTestAPI( + model := api.NewTestAPI( []*api.Message{}, []*api.Enum{}, []*api.Service{{Name: "Workflows", Package: "gcp-sdk-workflows-v1"}}) // Override the default name for test APIs ("Test"). @@ -65,7 +65,7 @@ func Test_RustEnumAnnotations(t *testing.T) { Values: []*api.EnumValue{v0, v1, v2}, } - model := newTestAPI( + model := api.NewTestAPI( []*api.Message{}, []*api.Enum{enum}, []*api.Service{}) codec, err := newRustCodec(map[string]string{}) if err != nil { diff --git a/generator/internal/sidekick/refresh.go b/generator/internal/sidekick/refresh.go index 1f5052110..c4174d91e 100644 --- a/generator/internal/sidekick/refresh.go +++ b/generator/internal/sidekick/refresh.go @@ -69,5 +69,6 @@ func refreshDir(rootConfig *config.Config, cmdLine *CommandLine, output string) if cmdLine.DryRun { return nil } + api.LabelRecursiveFields(model) return language.GenerateClient(model, config.General.Language, output, config.Codec) } diff --git a/src/generated/api/src/model.rs b/src/generated/api/src/model.rs index bf175043e..8cbbc4c57 100755 --- a/src/generated/api/src/model.rs +++ b/src/generated/api/src/model.rs @@ -597,10 +597,8 @@ pub struct BackendRule { /// The map between request protocol and the backend address. #[serde(skip_serializing_if = "std::collections::HashMap::is_empty")] - pub overrides_by_request_protocol: std::collections::HashMap< - std::string::String, - std::option::Option, - >, + pub overrides_by_request_protocol: + std::collections::HashMap, /// Authentication settings used by the backend. /// @@ -671,10 +669,7 @@ impl BackendRule { /// Sets the value of `overrides_by_request_protocol`. pub fn set_overrides_by_request_protocol< T: std::convert::Into< - std::collections::HashMap< - std::string::String, - std::option::Option, - >, + std::collections::HashMap, >, >( mut self, diff --git a/src/generated/cloud/translate/v3/src/builders.rs b/src/generated/cloud/translate/v3/src/builders.rs index 985964278..3e61e7ce0 100755 --- a/src/generated/cloud/translate/v3/src/builders.rs +++ b/src/generated/cloud/translate/v3/src/builders.rs @@ -582,7 +582,7 @@ pub mod translation_service { T: Into< std::collections::HashMap< std::string::String, - std::option::Option, + crate::model::TranslateTextGlossaryConfig, >, >, >( @@ -742,7 +742,7 @@ pub mod translation_service { T: Into< std::collections::HashMap< std::string::String, - std::option::Option, + crate::model::TranslateTextGlossaryConfig, >, >, >( diff --git a/src/generated/cloud/translate/v3/src/model.rs b/src/generated/cloud/translate/v3/src/model.rs index a1524db07..75438aeb9 100755 --- a/src/generated/cloud/translate/v3/src/model.rs +++ b/src/generated/cloud/translate/v3/src/model.rs @@ -4669,10 +4669,8 @@ pub struct BatchTranslateTextRequest { /// Optional. Glossaries to be applied for translation. /// It's keyed by target language code. #[serde(skip_serializing_if = "std::collections::HashMap::is_empty")] - pub glossaries: std::collections::HashMap< - std::string::String, - std::option::Option, - >, + pub glossaries: + std::collections::HashMap, /// Optional. The labels with user-defined metadata for the request. /// @@ -4748,7 +4746,7 @@ impl BatchTranslateTextRequest { T: std::convert::Into< std::collections::HashMap< std::string::String, - std::option::Option, + crate::model::TranslateTextGlossaryConfig, >, >, >( @@ -6141,10 +6139,8 @@ pub struct BatchTranslateDocumentRequest { /// Optional. Glossaries to be applied. It's keyed by target language code. #[serde(skip_serializing_if = "std::collections::HashMap::is_empty")] - pub glossaries: std::collections::HashMap< - std::string::String, - std::option::Option, - >, + pub glossaries: + std::collections::HashMap, /// Optional. The file format conversion map that is applied to all input /// files. The map key is the original mime_type. The map value is the target @@ -6240,7 +6236,7 @@ impl BatchTranslateDocumentRequest { T: std::convert::Into< std::collections::HashMap< std::string::String, - std::option::Option, + crate::model::TranslateTextGlossaryConfig, >, >, >( diff --git a/src/generated/devtools/cloudtrace/v2/src/model.rs b/src/generated/devtools/cloudtrace/v2/src/model.rs index 26af84649..aade4e19b 100755 --- a/src/generated/devtools/cloudtrace/v2/src/model.rs +++ b/src/generated/devtools/cloudtrace/v2/src/model.rs @@ -284,10 +284,8 @@ pub mod span { /// "abc.com/myattribute": { "bool_value": false } /// ``` #[serde(skip_serializing_if = "std::collections::HashMap::is_empty")] - pub attribute_map: std::collections::HashMap< - std::string::String, - std::option::Option, - >, + pub attribute_map: + std::collections::HashMap, /// The number of attributes that were discarded. Attributes can be discarded /// because their keys are too long or because there are too many attributes. @@ -299,10 +297,7 @@ pub mod span { /// Sets the value of `attribute_map`. pub fn set_attribute_map< T: std::convert::Into< - std::collections::HashMap< - std::string::String, - std::option::Option, - >, + std::collections::HashMap, >, >( mut self, From 1f6aae3d83e89dbfc6d15a78eef99ebf2d1f43b5 Mon Sep 17 00:00:00 2001 From: Carlos O'Ryan Date: Fri, 24 Jan 2025 10:56:16 -0500 Subject: [PATCH 2/2] Address review comments --- generator/internal/language/rust.go | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/generator/internal/language/rust.go b/generator/internal/language/rust.go index 537f61679..b72da0b1b 100644 --- a/generator/internal/language/rust.go +++ b/generator/internal/language/rust.go @@ -500,13 +500,14 @@ func rustFieldAttributes(f *api.Field, state *api.APIState) []string { } func rustFieldType(f *api.Field, state *api.APIState, primitive bool, modulePath, sourceSpecificationPackageName string, packageMapping map[string]*rustPackage) string { - if !primitive && f.IsOneOf { + switch { + case primitive: + return rustBaseFieldType(f, state, modulePath, sourceSpecificationPackageName, packageMapping) + case f.IsOneOf: return fmt.Sprintf("(%s)", rustBaseFieldType(f, state, modulePath, sourceSpecificationPackageName, packageMapping)) - } - if !primitive && f.Repeated { + case f.Repeated: return fmt.Sprintf("std::vec::Vec<%s>", rustBaseFieldType(f, state, modulePath, sourceSpecificationPackageName, packageMapping)) - } - if !primitive && f.Recursive { + case f.Recursive: base := rustBaseFieldType(f, state, modulePath, sourceSpecificationPackageName, packageMapping) if f.Optional { return fmt.Sprintf("std::option::Option>", base) @@ -516,11 +517,11 @@ func rustFieldType(f *api.Field, state *api.APIState, primitive bool, modulePath return base } return fmt.Sprintf("std::boxed::Box<%s>", base) - } - if !primitive && f.Optional { + case f.Optional: return fmt.Sprintf("std::option::Option<%s>", rustBaseFieldType(f, state, modulePath, sourceSpecificationPackageName, packageMapping)) + default: + return rustBaseFieldType(f, state, modulePath, sourceSpecificationPackageName, packageMapping) } - return rustBaseFieldType(f, state, modulePath, sourceSpecificationPackageName, packageMapping) } func rustMapType(f *api.Field, state *api.APIState, modulePath, sourceSpecificationPackageName string, packageMapping map[string]*rustPackage) string {