blob: db0b5f3a15667b46d8b846e08de0645e4ca03fb9 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"testing"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
retentionpb "google.golang.org/protobuf/cmd/protoc-gen-go/testdata/retention"
)
// TestFileOptionRetention checks, for each extension in the case table below,
// whether it is set on the generated file's options and what int32 value
// reads back. The "plain" and "runtime" cases are expected to be present
// with their declared values; the "source" cases are expected to be absent
// and to read back as zero.
func TestFileOptionRetention(t *testing.T) {
	fileOpts := retentionpb.File_cmd_protoc_gen_go_testdata_retention_retention_proto.Options()

	cases := []struct {
		name        string
		ext         protoreflect.ExtensionType
		wantPresent bool
		wantValue   int32
	}{
		{name: "imported_plain_option", ext: retentionpb.E_ImportedPlainOption, wantPresent: true, wantValue: 1},
		{name: "imported_runtime_option", ext: retentionpb.E_ImportedRuntimeRetentionOption, wantPresent: true, wantValue: 2},
		{name: "imported_source_option", ext: retentionpb.E_ImportedSourceRetentionOption, wantPresent: false, wantValue: 0},
		{name: "plain_option", ext: retentionpb.E_PlainOption, wantPresent: true, wantValue: 1},
		{name: "runtime_option", ext: retentionpb.E_RuntimeRetentionOption, wantPresent: true, wantValue: 2},
		{name: "source_option", ext: retentionpb.E_SourceRetentionOption, wantPresent: false, wantValue: 0},
	}

	for _, tc := range cases {
		// Query each extension once and reuse the result for both the
		// comparison and the failure message.
		present := proto.HasExtension(fileOpts, tc.ext)
		if present != tc.wantPresent {
			t.Errorf("HasExtension(%s): got %v, want %v", tc.name, present, tc.wantPresent)
		}
		value := proto.GetExtension(fileOpts, tc.ext).(int32)
		if value != tc.wantValue {
			t.Errorf("GetExtension(%s): got %d, want %d", tc.name, value, tc.wantValue)
		}
	}
}
// TestAllEntitiesWithMessageOption walks every descriptor reachable from the
// retention test file and verifies each OptionsMessage-typed option attached
// to it. The order is file, enums, messages (recursively), extensions,
// services.
func TestAllEntitiesWithMessageOption(t *testing.T) {
	fd := retentionpb.File_cmd_protoc_gen_go_testdata_retention_retention_proto
	fileName := string(fd.Name())

	verifyDescriptorOptions(t, fileName, fd.Options())
	verifyEnums(t, fd.Enums())
	verifyMessages(t, fd.Messages())
	verifyExtensions(t, fd.Extensions())
	verifyServices(t, fd.Services())
}
// verifyExtensions verifies the options of every extension declared in
// extensions.
func verifyExtensions(t *testing.T, extensions protoreflect.ExtensionDescriptors) {
	t.Helper()
	for i, n := 0, extensions.Len(); i < n; i++ {
		xd := extensions.Get(i)
		verifyDescriptorOptions(t, string(xd.Name()), xd.Options())
	}
}
// verifyMessages verifies the options of each message in messages. For each
// message it then recurses into the message's nested enums, nested messages,
// nested extensions and fields, in that order.
func verifyMessages(t *testing.T, messages protoreflect.MessageDescriptors) {
	t.Helper()
	for i, n := 0, messages.Len(); i < n; i++ {
		md := messages.Get(i)
		verifyDescriptorOptions(t, string(md.Name()), md.Options())
		verifyEnums(t, md.Enums())
		verifyMessages(t, md.Messages())
		verifyExtensions(t, md.Extensions())
		verifyFields(t, md.Fields())
	}
}
// verifyFields verifies the options of every field in fields.
func verifyFields(t *testing.T, fields protoreflect.FieldDescriptors) {
	t.Helper()
	for i, n := 0, fields.Len(); i < n; i++ {
		fd := fields.Get(i)
		verifyDescriptorOptions(t, string(fd.Name()), fd.Options())
	}
}
// verifyEnums verifies the options of each enum in enums, followed by the
// options of that enum's values.
func verifyEnums(t *testing.T, enums protoreflect.EnumDescriptors) {
	t.Helper()
	for i, n := 0, enums.Len(); i < n; i++ {
		ed := enums.Get(i)
		verifyDescriptorOptions(t, string(ed.Name()), ed.Options())
		verifyEnumValues(t, ed.Values())
	}
}
// verifyEnumValues verifies the options of every enum value in values.
func verifyEnumValues(t *testing.T, values protoreflect.EnumValueDescriptors) {
	t.Helper()
	for i, n := 0, values.Len(); i < n; i++ {
		evd := values.Get(i)
		verifyDescriptorOptions(t, string(evd.Name()), evd.Options())
	}
}
// verifyServices verifies the options of each service in services, followed
// by the options of that service's methods.
func verifyServices(t *testing.T, services protoreflect.ServiceDescriptors) {
	t.Helper()
	for i, n := 0, services.Len(); i < n; i++ {
		sd := services.Get(i)
		verifyDescriptorOptions(t, string(sd.Name()), sd.Options())
		verifyMethods(t, sd.Methods())
	}
}
// verifyMethods verifies the options of every method in methods.
func verifyMethods(t *testing.T, methods protoreflect.MethodDescriptors) {
	t.Helper()
	for i, n := 0, methods.Len(); i < n; i++ {
		mtd := methods.Get(i)
		verifyDescriptorOptions(t, string(mtd.Name()), mtd.Options())
	}
}
// verifyDescriptorOptions verifies every populated field of options, the
// options message of the descriptor named entity. Each field goes to
// maybeVerifyOption, which only checks OptionsMessage-typed options.
//
// The checks run inside a subtest named after entity. The entity name was
// previously accepted but never used. Failures then carried only the option
// extension's full name, which does not identify which descriptor the option
// was attached to when the same extension is set on several descriptors.
func verifyDescriptorOptions(t *testing.T, entity string, options protoreflect.ProtoMessage) {
	t.Helper()
	t.Run(entity, func(t *testing.T) {
		// Range visits only populated fields, including set extensions.
		options.ProtoReflect().Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
			maybeVerifyOption(t, fd, v)
			return true // keep iterating over all populated fields
		})
	})
}
// maybeVerifyOption verifies v only when fd is a message-kind field whose
// message type is goproto.proto.testretention.OptionsMessage. For a repeated
// field, each element is verified in turn. Fields of any other kind or type
// are ignored. The field's full name is used as the entity in failure
// messages.
func maybeVerifyOption(t *testing.T, fd protoreflect.FieldDescriptor, v protoreflect.Value) {
	t.Helper()
	if fd.Kind() != protoreflect.MessageKind {
		return
	}
	if string(fd.Message().FullName()) != "goproto.proto.testretention.OptionsMessage" {
		return
	}

	name := string(fd.FullName())
	if !fd.IsList() {
		verifyOptionsMessage(t, name, v.Message().Interface().(*retentionpb.OptionsMessage))
		return
	}

	elems := v.List()
	for i := 0; i < elems.Len(); i++ {
		verifyOptionsMessage(t, name, elems.Get(i).Message().Interface().(*retentionpb.OptionsMessage))
	}
}
// verifyOptionsMessage checks one OptionsMessage option value found under
// entity. The checks are:
//   - plain_field must be present and equal 1;
//   - runtime_retention_field must be present and equal 2;
//   - source_retention_field must be absent and read back as 0.
func verifyOptionsMessage(t *testing.T, entity string, msg *retentionpb.OptionsMessage) {
	t.Helper()

	if msg.PlainField == nil {
		t.Errorf("%s.OptionsMessage.HasField(plain_field): got false, want true", entity)
	}
	if got := msg.GetPlainField(); got != 1 {
		t.Errorf("%s.OptionsMessage.GetField(plain_field): got %d, want 1", entity, got)
	}

	if msg.RuntimeRetentionField == nil {
		t.Errorf("%s.OptionsMessage.HasField(runtime_retention_field): got false, want true", entity)
	}
	if got := msg.GetRuntimeRetentionField(); got != 2 {
		t.Errorf("%s.OptionsMessage.GetField(runtime_retention_field): got %d, want 2", entity, got)
	}

	if msg.SourceRetentionField != nil {
		t.Errorf("%s.OptionsMessage.HasField(source_retention_field): got true, want false", entity)
	}
	// Checking that we get 0 even though this was set to 3 in the source file.
	if got := msg.GetSourceRetentionField(); got != 0 {
		t.Errorf("%s.OptionsMessage.GetField(source_retention_field): got %d, want 0", entity, got)
	}
}