blob: d058c196d33c4b4f569465a08afec062408280ed [file] [log] [blame]
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package protoregistry_test
import (
"fmt"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/encoding/prototext"
pimpl "google.golang.org/protobuf/internal/impl"
pdesc "google.golang.org/protobuf/reflect/protodesc"
pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
testpb "google.golang.org/protobuf/internal/testprotos/registry"
"google.golang.org/protobuf/types/descriptorpb"
)
// mustMakeFile parses s as a text-format FileDescriptorProto and builds a
// FileDescriptor from it, passing a nil resolver to protodesc.NewFile.
// Any parse or construction error causes a panic, so it is meant only for
// fixed, known-good test inputs.
func mustMakeFile(s string) pref.FileDescriptor {
	fdp := &descriptorpb.FileDescriptorProto{}
	err := prototext.Unmarshal([]byte(s), fdp)
	if err != nil {
		panic(err)
	}

	file, err := pdesc.NewFile(fdp, nil)
	if err != nil {
		panic(err)
	}
	return file
}
// TestFiles exercises preg.Files with table-driven cases. Each case registers
// a sequence of files (checking for expected registration errors), then
// verifies lookups by full descriptor name, enumeration by package, and
// lookups by file path against the resulting registry.
func TestFiles(t *testing.T) {
	type (
		// file identifies a registered file by its path and package.
		file struct {
			Path string
			Pkg  pref.FullName
		}
		// testFile is a file to register; a non-empty wantErr is a substring
		// expected in the error returned by RegisterFile.
		testFile struct {
			inFile  pref.FileDescriptor
			wantErr string
		}
		// testFindDesc checks whether FindDescriptorByName(inName) finds anything.
		testFindDesc struct {
			inName    pref.FullName
			wantFound bool
		}
		// testRangePkg lists the files expected from RangeFilesByPackage(inPkg).
		testRangePkg struct {
			inPkg     pref.FullName
			wantFiles []file
		}
		// testFindPath checks FindFileByPath(inPath); a non-empty wantErr is a
		// substring expected in the returned error.
		testFindPath struct {
			inPath    string
			wantFiles []file
			wantErr   string
		}
	)

	tests := []struct {
		files     []testFile
		findDescs []testFindDesc
		rangePkgs []testRangePkg
		findPaths []testFindPath
	}{{
		// Test that overlapping packages are permitted, that registering the
		// same file path twice is rejected, and that paths are compared
		// verbatim (a non-cleaned path is a distinct file).
		files: []testFile{
			{inFile: mustMakeFile(`syntax:"proto2" name:"test1.proto" package:"foo.bar"`)},
			{inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/test.proto" package:"my.test"`)},
			{inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/test.proto" package:"foo.bar.baz"`), wantErr: "already registered"},
			{inFile: mustMakeFile(`syntax:"proto2" name:"test2.proto" package:"my.test.package"`)},
			{inFile: mustMakeFile(`syntax:"proto2" name:"weird" package:"foo.bar"`)},
			{inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/baz/../test.proto" package:"my.test"`)},
		},
		// Package matching is exact: parent packages, prefixes, and
		// malformed names match no files.
		rangePkgs: []testRangePkg{{
			inPkg: "nothing",
		}, {
			inPkg: "",
		}, {
			inPkg: ".",
		}, {
			inPkg: "foo",
		}, {
			inPkg: "foo.",
		}, {
			inPkg: "foo..",
		}, {
			inPkg: "foo.bar",
			wantFiles: []file{
				{"test1.proto", "foo.bar"},
				{"weird", "foo.bar"},
			},
		}, {
			inPkg: "my.test",
			wantFiles: []file{
				{"foo/bar/baz/../test.proto", "my.test"},
				{"foo/bar/test.proto", "my.test"},
			},
		}, {
			inPkg: "fo",
		}},
		findPaths: []testFindPath{{
			inPath:  "nothing",
			wantErr: "not found",
		}, {
			inPath: "weird",
			wantFiles: []file{
				{"weird", "foo.bar"},
			},
		}, {
			inPath: "foo/bar/test.proto",
			wantFiles: []file{
				{"foo/bar/test.proto", "my.test"},
			},
		}},
	}, {
		// Test when new enum conflicts with existing package.
		files: []testFile{{
			inFile: mustMakeFile(`syntax:"proto2" name:"test1a.proto" package:"foo.bar.baz"`),
		}, {
			inFile:  mustMakeFile(`syntax:"proto2" name:"test1b.proto" enum_type:[{name:"foo" value:[{name:"VALUE" number:0}]}]`),
			wantErr: `file "test1b.proto" has a name conflict over foo`,
		}},
	}, {
		// Test when new package conflicts with existing enum.
		files: []testFile{{
			inFile: mustMakeFile(`syntax:"proto2" name:"test2a.proto" enum_type:[{name:"foo" value:[{name:"VALUE" number:0}]}]`),
		}, {
			inFile:  mustMakeFile(`syntax:"proto2" name:"test2b.proto" package:"foo.bar.baz"`),
			wantErr: `file "test2b.proto" has a package name conflict over foo`,
		}},
	}, {
		// Test when new enum conflicts with existing enum in same package.
		files: []testFile{{
			inFile: mustMakeFile(`syntax:"proto2" name:"test3a.proto" package:"foo" enum_type:[{name:"BAR" value:[{name:"VALUE" number:0}]}]`),
		}, {
			inFile:  mustMakeFile(`syntax:"proto2" name:"test3b.proto" package:"foo" enum_type:[{name:"BAR" value:[{name:"VALUE2" number:0}]}]`),
			wantErr: `file "test3b.proto" has a name conflict over foo.BAR`,
		}},
	}, {
		// Test descriptor lookup by full name across nested declarations,
		// multiple files sharing a package, and a file with no package.
		files: []testFile{{
			inFile: mustMakeFile(`
				syntax:  "proto2"
				name:    "test1.proto"
				package: "fizz.buzz"
				message_type: [{
					name: "Message"
					field: [
						{name:"Field" number:1 label:LABEL_OPTIONAL type:TYPE_STRING oneof_index:0}
					]
					oneof_decl:      [{name:"Oneof"}]
					extension_range: [{start:1000 end:2000}]

					enum_type: [
						{name:"Enum" value:[{name:"EnumValue" number:0}]}
					]
					nested_type: [
						{name:"Message" field:[{name:"Field" number:1 label:LABEL_OPTIONAL type:TYPE_STRING}]}
					]
					extension: [
						{name:"Extension" number:1001 label:LABEL_OPTIONAL type:TYPE_STRING extendee:".fizz.buzz.Message"}
					]
				}]
				enum_type: [{
					name:  "Enum"
					value: [{name:"EnumValue" number:0}]
				}]
				extension: [
					{name:"Extension" number:1000 label:LABEL_OPTIONAL type:TYPE_STRING extendee:".fizz.buzz.Message"}
				]
				service: [{
					name: "Service"
					method: [{
						name:             "Method"
						input_type:       ".fizz.buzz.Message"
						output_type:      ".fizz.buzz.Message"
						client_streaming: true
						server_streaming: true
					}]
				}]
			`),
		}, {
			inFile: mustMakeFile(`
				syntax:  "proto2"
				name:    "test2.proto"
				package: "fizz.buzz.gazz"
				enum_type: [{
					name:  "Enum"
					value: [{name:"EnumValue" number:0}]
				}]
			`),
		}, {
			inFile: mustMakeFile(`
				syntax:  "proto2"
				name:    "test3.proto"
				package: "fizz.buzz"
				enum_type: [{
					name:  "Enum1"
					value: [{name:"EnumValue1" number:0}]
				}, {
					name:  "Enum2"
					value: [{name:"EnumValue2" number:0}]
				}]
			`),
		}, {
			// Make sure we can register without package name.
			inFile: mustMakeFile(`
				name:   "weird"
				syntax: "proto2"
				message_type: [{
					name: "Message"
					nested_type: [{
						name: "Message"
						nested_type: [{
							name: "Message"
						}]
					}]
				}]
			`),
		}},
		findDescs: []testFindDesc{
			{inName: "fizz.buzz.message", wantFound: false},
			{inName: "fizz.buzz.Message", wantFound: true},
			{inName: "fizz.buzz.Message.X", wantFound: false},
			{inName: "fizz.buzz.Field", wantFound: false},
			{inName: "fizz.buzz.Oneof", wantFound: false},
			{inName: "fizz.buzz.Message.Field", wantFound: true},
			{inName: "fizz.buzz.Message.Field.X", wantFound: false},
			{inName: "fizz.buzz.Message.Oneof", wantFound: true},
			{inName: "fizz.buzz.Message.Oneof.X", wantFound: false},
			{inName: "fizz.buzz.Message.Message", wantFound: true},
			{inName: "fizz.buzz.Message.Message.X", wantFound: false},
			{inName: "fizz.buzz.Message.Enum", wantFound: true},
			{inName: "fizz.buzz.Message.Enum.X", wantFound: false},
			{inName: "fizz.buzz.Message.EnumValue", wantFound: true},
			{inName: "fizz.buzz.Message.EnumValue.X", wantFound: false},
			{inName: "fizz.buzz.Message.Extension", wantFound: true},
			{inName: "fizz.buzz.Message.Extension.X", wantFound: false},
			{inName: "fizz.buzz.enum", wantFound: false},
			{inName: "fizz.buzz.Enum", wantFound: true},
			{inName: "fizz.buzz.Enum.X", wantFound: false},
			{inName: "fizz.buzz.EnumValue", wantFound: true},
			{inName: "fizz.buzz.EnumValue.X", wantFound: false},
			{inName: "fizz.buzz.Enum.EnumValue", wantFound: false},
			{inName: "fizz.buzz.Extension", wantFound: true},
			{inName: "fizz.buzz.Extension.X", wantFound: false},
			{inName: "fizz.buzz.service", wantFound: false},
			{inName: "fizz.buzz.Service", wantFound: true},
			{inName: "fizz.buzz.Service.X", wantFound: false},
			{inName: "fizz.buzz.Method", wantFound: false},
			{inName: "fizz.buzz.Service.Method", wantFound: true},
			{inName: "fizz.buzz.Service.Method.X", wantFound: false},

			{inName: "fizz.buzz.gazz", wantFound: false},
			{inName: "fizz.buzz.gazz.Enum", wantFound: true},
			{inName: "fizz.buzz.gazz.EnumValue", wantFound: true},
			{inName: "fizz.buzz.gazz.Enum.EnumValue", wantFound: false},

			{inName: "fizz.buzz", wantFound: false},
			{inName: "fizz.buzz.Enum1", wantFound: true},
			{inName: "fizz.buzz.EnumValue1", wantFound: true},
			{inName: "fizz.buzz.Enum1.EnumValue1", wantFound: false},
			{inName: "fizz.buzz.Enum2", wantFound: true},
			{inName: "fizz.buzz.EnumValue2", wantFound: true},
			{inName: "fizz.buzz.Enum2.EnumValue2", wantFound: false},
			{inName: "fizz.buzz.Enum3", wantFound: false},

			{inName: "", wantFound: false},
			{inName: "Message", wantFound: true},
			{inName: "Message.Message", wantFound: true},
			{inName: "Message.Message.Message", wantFound: true},
			{inName: "Message.Message.Message.Message", wantFound: false},
		},
	}}

	// Compare file lists irrespective of order; the test does not rely on
	// the iteration order of RangeFilesByPackage.
	sortFiles := cmpopts.SortSlices(func(x, y file) bool {
		return x.Path < y.Path || (x.Path == y.Path && x.Pkg < y.Pkg)
	})
	for _, tt := range tests {
		t.Run("", func(t *testing.T) {
			var files preg.Files
			for i, tc := range tt.files {
				gotErr := files.RegisterFile(tc.inFile)
				// An empty wantErr means success is expected. fmt.Sprint(nil)
				// is "<nil>", which trivially contains "".
				if ((gotErr == nil) != (tc.wantErr == "")) || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) {
					t.Errorf("file %d, RegisterFile() = %v, want %v", i, gotErr, tc.wantErr)
				}
			}

			for _, tc := range tt.findDescs {
				// Only whether a descriptor was found is checked here, so the
				// returned error is intentionally ignored.
				d, _ := files.FindDescriptorByName(tc.inName)
				gotFound := d != nil
				if gotFound != tc.wantFound {
					t.Errorf("FindDescriptorByName(%v) find mismatch: got %v, want %v", tc.inName, gotFound, tc.wantFound)
				}
			}

			for _, tc := range tt.rangePkgs {
				var gotFiles []file
				var gotCnt int
				wantCnt := files.NumFilesByPackage(tc.inPkg)
				files.RangeFilesByPackage(tc.inPkg, func(fd pref.FileDescriptor) bool {
					gotFiles = append(gotFiles, file{fd.Path(), fd.Package()})
					gotCnt++
					return true
				})
				// NumFilesByPackage must agree with the number of files the
				// range callback actually visited.
				if gotCnt != wantCnt {
					t.Errorf("NumFilesByPackage(%v) = %v, but RangeFilesByPackage visited %v files", tc.inPkg, wantCnt, gotCnt)
				}
				if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" {
					t.Errorf("RangeFilesByPackage(%v) mismatch (-want +got):\n%v", tc.inPkg, diff)
				}
			}

			for _, tc := range tt.findPaths {
				var gotFiles []file
				fd, gotErr := files.FindFileByPath(tc.inPath)
				if gotErr == nil {
					gotFiles = append(gotFiles, file{fd.Path(), fd.Package()})
				}
				if ((gotErr == nil) != (tc.wantErr == "")) || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) {
					t.Errorf("FindFileByPath(%v) = %v, want %v", tc.inPath, gotErr, tc.wantErr)
				}
				if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" {
					t.Errorf("FindFileByPath(%v) mismatch (-want +got):\n%v", tc.inPath, diff)
				}
			}
		})
	}
}
// TestTypes registers one message type, one enum type, and two extension
// types in a fresh preg.Types, then checks the Find* lookups and that each
// Num* count agrees with what the matching Range* method visits.
func TestTypes(t *testing.T) {
	mt1 := pimpl.Export{}.MessageTypeOf(&testpb.Message1{})
	et1 := pimpl.Export{}.EnumTypeOf(testpb.Enum1_ONE)
	xt1 := testpb.E_StringField
	xt2 := testpb.E_Message4_MessageField
	registry := new(preg.Types)
	if err := registry.RegisterMessage(mt1); err != nil {
		t.Fatalf("registry.RegisterMessage(%v) returns unexpected error: %v", mt1.Descriptor().FullName(), err)
	}
	if err := registry.RegisterEnum(et1); err != nil {
		t.Fatalf("registry.RegisterEnum(%v) returns unexpected error: %v", et1.Descriptor().FullName(), err)
	}
	if err := registry.RegisterExtension(xt1); err != nil {
		t.Fatalf("registry.RegisterExtension(%v) returns unexpected error: %v", xt1.TypeDescriptor().FullName(), err)
	}
	if err := registry.RegisterExtension(xt2); err != nil {
		t.Fatalf("registry.RegisterExtension(%v) returns unexpected error: %v", xt2.TypeDescriptor().FullName(), err)
	}

	// In each lookup table below, wantNotFound additionally requires the error
	// to be exactly preg.NotFound; cases with only wantErr (e.g. a name that is
	// registered as a different kind of type) accept any error.
	t.Run("FindMessageByName", func(t *testing.T) {
		tests := []struct {
			name         string
			messageType  pref.MessageType
			wantErr      bool
			wantNotFound bool
		}{{
			name:        "testprotos.Message1",
			messageType: mt1,
		}, {
			name:         "testprotos.NoSuchMessage",
			wantErr:      true,
			wantNotFound: true,
		}, {
			name:    "testprotos.Enum1",
			wantErr: true,
		}, {
			name:    "testprotos.Enum2",
			wantErr: true,
		}, {
			name:    "testprotos.Enum3",
			wantErr: true,
		}}
		for _, tc := range tests {
			got, err := registry.FindMessageByName(pref.FullName(tc.name))
			gotErr := err != nil
			if gotErr != tc.wantErr {
				t.Errorf("FindMessageByName(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr)
				continue
			}
			if tc.wantNotFound && err != preg.NotFound {
				t.Errorf("FindMessageByName(%v) got error: %v, want NotFound error", tc.name, err)
				continue
			}
			if got != tc.messageType {
				t.Errorf("FindMessageByName(%v) got wrong value: %v", tc.name, got)
			}
		}
	})

	t.Run("FindMessageByURL", func(t *testing.T) {
		tests := []struct {
			name         string
			messageType  pref.MessageType
			wantErr      bool
			wantNotFound bool
		}{{
			// A bare full name (no URL prefix) is also expected to resolve.
			name:        "testprotos.Message1",
			messageType: mt1,
		}, {
			name:         "type.googleapis.com/testprotos.Nada",
			wantErr:      true,
			wantNotFound: true,
		}, {
			name:    "testprotos.Enum1",
			wantErr: true,
		}}
		for _, tc := range tests {
			got, err := registry.FindMessageByURL(tc.name)
			gotErr := err != nil
			if gotErr != tc.wantErr {
				t.Errorf("FindMessageByURL(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr)
				continue
			}
			if tc.wantNotFound && err != preg.NotFound {
				t.Errorf("FindMessageByURL(%v) got error: %v, want NotFound error", tc.name, err)
				continue
			}
			if got != tc.messageType {
				t.Errorf("FindMessageByURL(%v) got wrong value: %v", tc.name, got)
			}
		}
	})

	t.Run("FindEnumByName", func(t *testing.T) {
		tests := []struct {
			name         string
			enumType     pref.EnumType
			wantErr      bool
			wantNotFound bool
		}{{
			name:     "testprotos.Enum1",
			enumType: et1,
		}, {
			name:         "testprotos.None",
			wantErr:      true,
			wantNotFound: true,
		}, {
			name:    "testprotos.Message1",
			wantErr: true,
		}}
		for _, tc := range tests {
			got, err := registry.FindEnumByName(pref.FullName(tc.name))
			gotErr := err != nil
			if gotErr != tc.wantErr {
				t.Errorf("FindEnumByName(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr)
				continue
			}
			if tc.wantNotFound && err != preg.NotFound {
				t.Errorf("FindEnumByName(%v) got error: %v, want NotFound error", tc.name, err)
				continue
			}
			if got != tc.enumType {
				t.Errorf("FindEnumByName(%v) got wrong value: %v", tc.name, got)
			}
		}
	})

	t.Run("FindExtensionByName", func(t *testing.T) {
		tests := []struct {
			name          string
			extensionType pref.ExtensionType
			wantErr       bool
			wantNotFound  bool
		}{{
			name:          "testprotos.string_field",
			extensionType: xt1,
		}, {
			name:          "testprotos.Message4.message_field",
			extensionType: xt2,
		}, {
			name:         "testprotos.None",
			wantErr:      true,
			wantNotFound: true,
		}, {
			name:    "testprotos.Message1",
			wantErr: true,
		}}
		for _, tc := range tests {
			got, err := registry.FindExtensionByName(pref.FullName(tc.name))
			gotErr := err != nil
			if gotErr != tc.wantErr {
				t.Errorf("FindExtensionByName(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr)
				continue
			}
			if tc.wantNotFound && err != preg.NotFound {
				t.Errorf("FindExtensionByName(%v) got error: %v, want NotFound error", tc.name, err)
				continue
			}
			if got != tc.extensionType {
				t.Errorf("FindExtensionByName(%v) got wrong value: %v", tc.name, got)
			}
		}
	})

	t.Run("FindExtensionByNumber", func(t *testing.T) {
		// The table expects xt1 and xt2 to extend testprotos.Message1 at
		// field numbers 11 and 21; every other (parent, number) pair is
		// expected to be NotFound since nothing else was registered.
		tests := []struct {
			parent        string
			number        int32
			extensionType pref.ExtensionType
			wantErr       bool
			wantNotFound  bool
		}{{
			parent:        "testprotos.Message1",
			number:        11,
			extensionType: xt1,
		}, {
			parent:       "testprotos.Message1",
			number:       13,
			wantErr:      true,
			wantNotFound: true,
		}, {
			parent:        "testprotos.Message1",
			number:        21,
			extensionType: xt2,
		}, {
			parent:       "testprotos.Message1",
			number:       23,
			wantErr:      true,
			wantNotFound: true,
		}, {
			parent:       "testprotos.NoSuchMessage",
			number:       11,
			wantErr:      true,
			wantNotFound: true,
		}, {
			parent:       "testprotos.Message1",
			number:       30,
			wantErr:      true,
			wantNotFound: true,
		}, {
			parent:       "testprotos.Message1",
			number:       99,
			wantErr:      true,
			wantNotFound: true,
		}}
		for _, tc := range tests {
			got, err := registry.FindExtensionByNumber(pref.FullName(tc.parent), pref.FieldNumber(tc.number))
			gotErr := err != nil
			if gotErr != tc.wantErr {
				t.Errorf("FindExtensionByNumber(%v, %d) = (_, %v), want error? %t", tc.parent, tc.number, err, tc.wantErr)
				continue
			}
			if tc.wantNotFound && err != preg.NotFound {
				t.Errorf("FindExtensionByNumber(%v, %d) got error: %v, want NotFound error", tc.parent, tc.number, err)
				continue
			}
			if got != tc.extensionType {
				t.Errorf("FindExtensionByNumber(%v, %d) got wrong value: %v", tc.parent, tc.number, got)
			}
		}
	})

	// sortTypes makes the Range* comparisons order-insensitive.
	sortTypes := cmp.Options{
		cmpopts.SortSlices(func(x, y pref.EnumType) bool {
			return x.Descriptor().FullName() < y.Descriptor().FullName()
		}),
		cmpopts.SortSlices(func(x, y pref.MessageType) bool {
			return x.Descriptor().FullName() < y.Descriptor().FullName()
		}),
		cmpopts.SortSlices(func(x, y pref.ExtensionType) bool {
			return x.TypeDescriptor().FullName() < y.TypeDescriptor().FullName()
		}),
	}
	// compare checks the registered types by interface equality instead of
	// letting cmp walk their internals.
	compare := cmp.Options{
		cmp.Comparer(func(x, y pref.EnumType) bool {
			return x == y
		}),
		cmp.Comparer(func(x, y pref.ExtensionType) bool {
			return x == y
		}),
		cmp.Comparer(func(x, y pref.MessageType) bool {
			return x == y
		}),
	}

	t.Run("RangeEnums", func(t *testing.T) {
		want := []pref.EnumType{et1}
		var got []pref.EnumType
		var gotCnt int
		wantCnt := registry.NumEnums()
		registry.RangeEnums(func(et pref.EnumType) bool {
			got = append(got, et)
			gotCnt++
			return true
		})

		if gotCnt != wantCnt {
			t.Errorf("NumEnums() = %v, but RangeEnums() visited %v enums", wantCnt, gotCnt)
		}
		if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
			t.Errorf("RangeEnums() mismatch (-want +got):\n%v", diff)
		}
	})

	t.Run("RangeMessages", func(t *testing.T) {
		want := []pref.MessageType{mt1}
		var got []pref.MessageType
		var gotCnt int
		wantCnt := registry.NumMessages()
		registry.RangeMessages(func(mt pref.MessageType) bool {
			got = append(got, mt)
			gotCnt++
			return true
		})

		if gotCnt != wantCnt {
			t.Errorf("NumMessages() = %v, but RangeMessages() visited %v messages", wantCnt, gotCnt)
		}
		if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
			t.Errorf("RangeMessages() mismatch (-want +got):\n%v", diff)
		}
	})

	t.Run("RangeExtensions", func(t *testing.T) {
		want := []pref.ExtensionType{xt1, xt2}
		var got []pref.ExtensionType
		var gotCnt int
		wantCnt := registry.NumExtensions()
		registry.RangeExtensions(func(xt pref.ExtensionType) bool {
			got = append(got, xt)
			gotCnt++
			return true
		})

		if gotCnt != wantCnt {
			t.Errorf("NumExtensions() = %v, but RangeExtensions() visited %v extensions", wantCnt, gotCnt)
		}
		if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
			t.Errorf("RangeExtensions() mismatch (-want +got):\n%v", diff)
		}
	})

	t.Run("RangeExtensionsByMessage", func(t *testing.T) {
		parent := pref.FullName("testprotos.Message1")
		want := []pref.ExtensionType{xt1, xt2}
		var got []pref.ExtensionType
		var gotCnt int
		wantCnt := registry.NumExtensionsByMessage(parent)
		registry.RangeExtensionsByMessage(parent, func(xt pref.ExtensionType) bool {
			got = append(got, xt)
			gotCnt++
			return true
		})

		if gotCnt != wantCnt {
			t.Errorf("NumExtensionsByMessage(%v) = %v, but RangeExtensionsByMessage(%v) visited %v extensions", parent, wantCnt, parent, gotCnt)
		}
		if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
			t.Errorf("RangeExtensionsByMessage(%v) mismatch (-want +got):\n%v", parent, diff)
		}
	})
}