blob: 5a8010f18fafaa341c2137d54b0d09e1b7c5ebd7 [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dynamicpb
import (
"fmt"
"strings"
"sync"
"sync/atomic"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
)
// extField is the lookup key for an extension: the full name of the message
// being extended paired with the extension's field number.
type extField struct {
	// name is the full name of the extended (containing) message.
	name protoreflect.FullName
	// number is the extension's field number within that message.
	number protoreflect.FieldNumber
}
// A Types is a collection of dynamically constructed descriptors.
// Its methods are safe for concurrent use.
//
// Types implements protoregistry.MessageTypeResolver and protoregistry.ExtensionTypeResolver.
// A Types may be used as a proto.UnmarshalOptions.Resolver.
type Types struct {
	// atomicExtFiles is accessed with the sync/atomic functions and hence
	// must be the first word of the struct: on 32-bit platforms (386, ARM,
	// 32-bit MIPS) only the first word of an allocated struct is guaranteed
	// to be 64-bit aligned, and a misaligned 64-bit atomic access faults.
	//
	// It records the number of files in the registry at the time the
	// extension index was last rebuilt.
	atomicExtFiles uint64

	// extMu serializes rebuilds of extensionsByMessage.
	extMu sync.Mutex

	// files is the registry against which every lookup is resolved.
	files *protoregistry.Files

	// extensionsByMessage indexes extension descriptors by the extended
	// message's full name and the extension's field number.
	// It is built lazily on the first lookup by number.
	extensionsByMessage map[extField]protoreflect.ExtensionDescriptor
}
// NewTypes returns a Types registry that resolves names against f.
// The Files registry is retained rather than copied, so later changes to it
// are reflected in the returned Types.
// It is not safe to concurrently change the Files while calling Types methods.
func NewTypes(f *protoregistry.Files) *Types {
	types := new(Types)
	types.files = f
	return types
}
// FindEnumByName resolves an enum type from its full name;
// e.g., "google.protobuf.Field.Kind".
//
// This returns (nil, protoregistry.NotFound) if not found.
// An error from the underlying Files lookup is returned unchanged, and a
// distinct error is returned when the name resolves to a descriptor that is
// not an enum.
func (t *Types) FindEnumByName(name protoreflect.FullName) (protoreflect.EnumType, error) {
	found, err := t.files.FindDescriptorByName(name)
	if err != nil {
		return nil, err
	}
	switch desc := found.(type) {
	case protoreflect.EnumDescriptor:
		return NewEnumType(desc), nil
	default:
		return nil, errors.New("found wrong type: got %v, want enum", descName(found))
	}
}
// FindExtensionByName resolves an extension field from the field's full name.
// The full name is determined by where the extension is declared; it is
// unrelated to the full name of the message being extended.
//
// This returns (nil, protoregistry.NotFound) if not found.
// An error from the underlying Files lookup is returned unchanged, and a
// distinct error is returned when the name resolves to a descriptor that is
// not an extension.
func (t *Types) FindExtensionByName(name protoreflect.FullName) (protoreflect.ExtensionType, error) {
	found, err := t.files.FindDescriptorByName(name)
	if err != nil {
		return nil, err
	}
	if ext, isExt := found.(protoreflect.ExtensionDescriptor); isExt {
		return NewExtensionType(ext), nil
	}
	return nil, errors.New("found wrong type: got %v, want extension", descName(found))
}
// FindExtensionByNumber resolves an extension field from its field number
// within a parent message, which is identified by full name.
//
// This returns (nil, protoregistry.NotFound) if not found.
func (t *Types) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
	// The number index is built on demand, since not every user needs it,
	// and is refreshed whenever the registry's file count no longer matches
	// the count recorded at the last rebuild.
	if indexed := atomic.LoadUint64(&t.atomicExtFiles); indexed != uint64(t.files.NumFiles()) {
		t.updateExtensions()
	}

	key := extField{name: message, number: field}
	if ext := t.extensionsByMessage[key]; ext != nil {
		return NewExtensionType(ext), nil
	}
	return nil, protoregistry.NotFound
}
// FindMessageByName resolves a message type from its full name;
// e.g. "google.protobuf.Any".
//
// This returns (nil, protoregistry.NotFound) if not found.
// An error from the underlying Files lookup is returned unchanged, and a
// distinct error is returned when the name resolves to a descriptor that is
// not a message.
func (t *Types) FindMessageByName(name protoreflect.FullName) (protoreflect.MessageType, error) {
	found, err := t.files.FindDescriptorByName(name)
	if err != nil {
		return nil, err
	}
	msg, isMessage := found.(protoreflect.MessageDescriptor)
	if isMessage {
		return NewMessageType(msg), nil
	}
	return nil, errors.New("found wrong type: got %v, want message", descName(found))
}
// FindMessageByURL resolves a message type from a URL identifier.
// See documentation on google.protobuf.Any.type_url for the URL format.
//
// This returns (nil, protoregistry.NotFound) if not found.
func (t *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) {
	// Only the part after the last '/' names the message; a URL without
	// any '/' is used as the message name in its entirety. The rest is
	// the same as FindMessageByName.
	name := url
	if slash := strings.LastIndexByte(url, '/'); slash >= 0 {
		name = url[slash+len("/"):]
	}
	return t.FindMessageByName(protoreflect.FullName(name))
}
// updateExtensions rebuilds the extension-by-number index from every file in
// the registry, unless the index already reflects the current file count.
func (t *Types) updateExtensions() {
	t.extMu.Lock()
	defer t.extMu.Unlock()

	// Re-check under the lock: another caller may have finished the rebuild
	// while this one was waiting for extMu.
	indexed := atomic.LoadUint64(&t.atomicExtFiles)
	if indexed == uint64(t.files.NumFiles()) {
		return
	}

	// Capture the file count before walking the registry and publish it only
	// once the walk is over; the deferred store runs before the unlock.
	fileCount := uint64(t.files.NumFiles())
	defer atomic.StoreUint64(&t.atomicExtFiles, fileCount)

	indexFile := func(file protoreflect.FileDescriptor) bool {
		t.registerExtensions(file.Extensions())
		t.registerExtensionsInMessages(file.Messages())
		return true // keep iterating over all files
	}
	t.files.RangeFiles(indexFile)
}
// registerExtensionsInMessages indexes the extensions declared inside each of
// the given messages, recursing into their nested messages.
func (t *Types) registerExtensionsInMessages(mds protoreflect.MessageDescriptors) {
	for idx, total := 0, mds.Len(); idx < total; idx++ {
		msg := mds.Get(idx)
		t.registerExtensions(msg.Extensions())
		t.registerExtensionsInMessages(msg.Messages())
	}
}
// registerExtensions records each of the given extension descriptors in
// extensionsByMessage, keyed by the full name of the message it extends and
// its field number. An existing entry with the same key is overwritten.
func (t *Types) registerExtensions(xds protoreflect.ExtensionDescriptors) {
	for idx, total := 0, xds.Len(); idx < total; idx++ {
		ext := xds.Get(idx)
		number := ext.Number()
		extended := ext.ContainingMessage().FullName()

		// Allocate the index lazily so that it stays nil until the first
		// extension is actually registered.
		if t.extensionsByMessage == nil {
			t.extensionsByMessage = make(map[extField]protoreflect.ExtensionDescriptor)
		}
		key := extField{name: extended, number: number}
		t.extensionsByMessage[key] = ext
	}
}
// descName returns a short human-readable name for the kind of descriptor d
// is, for use in error messages. Descriptor kinds without a dedicated name
// are reported by their Go type.
func descName(d protoreflect.Descriptor) string {
	// The checks run in a fixed order; the first matching interface wins.
	if _, ok := d.(protoreflect.EnumDescriptor); ok {
		return "enum"
	}
	if _, ok := d.(protoreflect.EnumValueDescriptor); ok {
		return "enum value"
	}
	if _, ok := d.(protoreflect.MessageDescriptor); ok {
		return "message"
	}
	if _, ok := d.(protoreflect.ExtensionDescriptor); ok {
		return "extension"
	}
	if _, ok := d.(protoreflect.ServiceDescriptor); ok {
		return "service"
	}
	return fmt.Sprintf("%T", d)
}