types/dynamicpb: add NewTypes

Add a function to construct a dynamic type registry from a
protoregistry.Files. The NewTypes constructor takes a concrete
Files to permit future improvements based on changes to Files.
(For example, we might add a Files.FindExtensionByNumber
method, which Types could take advantage of.)

Fixes golang/protobuf#1216

Change-Id: I61edba0a94528829d40f69fad773ccb5912859e0
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/489316
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Lasse Folger <lassefolger@google.com>
Reviewed-by: Joseph Tsai <joetsai@digital-static.net>
diff --git a/types/dynamicpb/types.go b/types/dynamicpb/types.go
new file mode 100644
index 0000000..5a8010f
--- /dev/null
+++ b/types/dynamicpb/types.go
@@ -0,0 +1,177 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dynamicpb
+
+import (
+	"fmt"
+	"strings"
+	"sync"
+	"sync/atomic"
+
+	"google.golang.org/protobuf/internal/errors"
+	"google.golang.org/protobuf/reflect/protoreflect"
+	"google.golang.org/protobuf/reflect/protoregistry"
+)
+
+type extField struct {
+	name   protoreflect.FullName
+	number protoreflect.FieldNumber
+}
+
+// A Types is a collection of dynamically constructed descriptors.
+// Its methods are safe for concurrent use.
+//
+// Types implements protoregistry.MessageTypeResolver and protoregistry.ExtensionTypeResolver.
+// A Types may be used as a proto.UnmarshalOptions.Resolver.
+type Types struct {
+	files *protoregistry.Files
+
+	extMu               sync.Mutex
+	atomicExtFiles      uint64
+	extensionsByMessage map[extField]protoreflect.ExtensionDescriptor
+}
+
+// NewTypes creates a new Types registry with the provided files.
+// The Files registry is retained, and changes to Files will be reflected in Types.
+// It is not safe to concurrently change the Files while calling Types methods.
+func NewTypes(f *protoregistry.Files) *Types {
+	return &Types{
+		files: f,
+	}
+}
+
+// FindEnumByName looks up an enum by its full name;
+// e.g., "google.protobuf.Field.Kind".
+//
+// This returns (nil, protoregistry.NotFound) if not found.
+func (t *Types) FindEnumByName(name protoreflect.FullName) (protoreflect.EnumType, error) {
+	d, err := t.files.FindDescriptorByName(name)
+	if err != nil {
+		return nil, err
+	}
+	ed, ok := d.(protoreflect.EnumDescriptor)
+	if !ok {
+		return nil, errors.New("found wrong type: got %v, want enum", descName(d))
+	}
+	return NewEnumType(ed), nil
+}
+
+// FindExtensionByName looks up an extension field by the field's full name.
+// Note that this is the full name of the field as determined by
+// where the extension is declared and is unrelated to the full name of the
+// message being extended.
+//
+// This returns (nil, protoregistry.NotFound) if not found.
+func (t *Types) FindExtensionByName(name protoreflect.FullName) (protoreflect.ExtensionType, error) {
+	d, err := t.files.FindDescriptorByName(name)
+	if err != nil {
+		return nil, err
+	}
+	xd, ok := d.(protoreflect.ExtensionDescriptor)
+	if !ok {
+		return nil, errors.New("found wrong type: got %v, want extension", descName(d))
+	}
+	return NewExtensionType(xd), nil
+}
+
+// FindExtensionByNumber looks up an extension field by the field number
+// within some parent message, identified by full name.
+//
+// This returns (nil, protoregistry.NotFound) if not found.
+func (t *Types) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
+	// Construct the extension number map lazily, since not every user will need it.
+	// Update the map if new files are added to the registry.
+	if atomic.LoadUint64(&t.atomicExtFiles) != uint64(t.files.NumFiles()) {
+		t.updateExtensions()
+	}
+	xd := t.extensionsByMessage[extField{message, field}]
+	if xd == nil {
+		return nil, protoregistry.NotFound
+	}
+	return NewExtensionType(xd), nil
+}
+
+// FindMessageByName looks up a message by its full name;
+// e.g. "google.protobuf.Any".
+//
+// This returns (nil, protoregistry.NotFound) if not found.
+func (t *Types) FindMessageByName(name protoreflect.FullName) (protoreflect.MessageType, error) {
+	d, err := t.files.FindDescriptorByName(name)
+	if err != nil {
+		return nil, err
+	}
+	md, ok := d.(protoreflect.MessageDescriptor)
+	if !ok {
+		return nil, errors.New("found wrong type: got %v, want message", descName(d))
+	}
+	return NewMessageType(md), nil
+}
+
+// FindMessageByURL looks up a message by a URL identifier.
+// See documentation on google.protobuf.Any.type_url for the URL format.
+//
+// This returns (nil, protoregistry.NotFound) if not found.
+func (t *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) {
+	// This function is similar to FindMessageByName but
+	// truncates anything before and including '/' in the URL.
+	message := protoreflect.FullName(url)
+	if i := strings.LastIndexByte(url, '/'); i >= 0 {
+		message = message[i+len("/"):]
+	}
+	return t.FindMessageByName(message)
+}
+
+func (t *Types) updateExtensions() {
+	t.extMu.Lock()
+	defer t.extMu.Unlock()
+	if atomic.LoadUint64(&t.atomicExtFiles) == uint64(t.files.NumFiles()) {
+		return
+	}
+	defer atomic.StoreUint64(&t.atomicExtFiles, uint64(t.files.NumFiles()))
+	t.files.RangeFiles(func(fd protoreflect.FileDescriptor) bool {
+		t.registerExtensions(fd.Extensions())
+		t.registerExtensionsInMessages(fd.Messages())
+		return true
+	})
+}
+
+func (t *Types) registerExtensionsInMessages(mds protoreflect.MessageDescriptors) {
+	count := mds.Len()
+	for i := 0; i < count; i++ {
+		md := mds.Get(i)
+		t.registerExtensions(md.Extensions())
+		t.registerExtensionsInMessages(md.Messages())
+	}
+}
+
+func (t *Types) registerExtensions(xds protoreflect.ExtensionDescriptors) {
+	count := xds.Len()
+	for i := 0; i < count; i++ {
+		xd := xds.Get(i)
+		field := xd.Number()
+		message := xd.ContainingMessage().FullName()
+		if t.extensionsByMessage == nil {
+			t.extensionsByMessage = make(map[extField]protoreflect.ExtensionDescriptor)
+		}
+		t.extensionsByMessage[extField{message, field}] = xd
+	}
+}
+
+func descName(d protoreflect.Descriptor) string {
+	switch d.(type) {
+	case protoreflect.EnumDescriptor:
+		return "enum"
+	case protoreflect.EnumValueDescriptor:
+		return "enum value"
+	case protoreflect.MessageDescriptor:
+		return "message"
+	case protoreflect.ExtensionDescriptor:
+		return "extension"
+	case protoreflect.ServiceDescriptor:
+		return "service"
+	default:
+		return fmt.Sprintf("%T", d)
+	}
+}
diff --git a/types/dynamicpb/types_test.go b/types/dynamicpb/types_test.go
new file mode 100644
index 0000000..1878f79
--- /dev/null
+++ b/types/dynamicpb/types_test.go
@@ -0,0 +1,174 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dynamicpb_test
+
+import (
+	"strings"
+	"testing"
+
+	"google.golang.org/protobuf/reflect/protoreflect"
+	"google.golang.org/protobuf/reflect/protoregistry"
+	"google.golang.org/protobuf/types/descriptorpb"
+	"google.golang.org/protobuf/types/dynamicpb"
+
+	registrypb "google.golang.org/protobuf/internal/testprotos/registry"
+)
+
+var _ protoregistry.ExtensionTypeResolver = &dynamicpb.Types{}
+var _ protoregistry.MessageTypeResolver = &dynamicpb.Types{}
+
+func newTestTypes() *dynamicpb.Types {
+	files := &protoregistry.Files{}
+	files.RegisterFile(registrypb.File_internal_testprotos_registry_test_proto)
+	return dynamicpb.NewTypes(files)
+}
+
+func TestDynamicTypesTypeMismatch(t *testing.T) {
+	types := newTestTypes()
+	const messageName = "testprotos.Message1"
+	const enumName = "testprotos.Enum1"
+
+	_, err := types.FindEnumByName(messageName)
+	want := "found wrong type: got message, want enum"
+	if err == nil || !strings.Contains(err.Error(), want) {
+		t.Errorf("types.FindEnumByName(%q) = _, %q, want %q", messageName, err, want)
+	}
+
+	_, err = types.FindMessageByName(enumName)
+	want = "found wrong type: got enum, want message"
+	if err == nil || !strings.Contains(err.Error(), want) {
+		t.Errorf("types.FindMessageByName(%q) = _, %q, want %q", messageName, err, want)
+	}
+
+	_, err = types.FindExtensionByName(enumName)
+	want = "found wrong type: got enum, want extension"
+	if err == nil || !strings.Contains(err.Error(), want) {
+		t.Errorf("types.FindExtensionByName(%q) = _, %q, want %q", messageName, err, want)
+	}
+}
+
+func TestDynamicTypesEnumNotFound(t *testing.T) {
+	types := newTestTypes()
+	for _, name := range []protoreflect.FullName{
+		"Enum1",
+		"testprotos.DoesNotExist",
+	} {
+		_, err := types.FindEnumByName(name)
+		if err != protoregistry.NotFound {
+			t.Errorf("types.FindEnumByName(%q) = _, %v; want protoregistry.NotFound", name, err)
+		}
+	}
+}
+
+func TestDynamicTypesFindEnumByName(t *testing.T) {
+	types := newTestTypes()
+	name := protoreflect.FullName("testprotos.Enum1")
+	et, err := types.FindEnumByName(name)
+	if err != nil {
+		t.Fatalf("types.FindEnumByName(%q) = %v", name, err)
+	}
+	if got, want := et.Descriptor().FullName(), name; got != want {
+		t.Fatalf("types.FindEnumByName(%q).Descriptor().FullName() = %q, want %q", name, got, want)
+	}
+}
+
+func TestDynamicTypesMessageNotFound(t *testing.T) {
+	types := newTestTypes()
+	for _, name := range []protoreflect.FullName{
+		"Message1",
+		"testprotos.DoesNotExist",
+	} {
+		_, err := types.FindMessageByName(name)
+		if err != protoregistry.NotFound {
+			t.Errorf("types.FindMessageByName(%q) = _, %v; want protoregistry.NotFound", name, err)
+		}
+	}
+}
+
+func TestDynamicTypesFindMessageByName(t *testing.T) {
+	types := newTestTypes()
+	name := protoreflect.FullName("testprotos.Message1")
+	mt, err := types.FindMessageByName(name)
+	if err != nil {
+		t.Fatalf("types.FindMessageByName(%q) = %v", name, err)
+	}
+	if got, want := mt.Descriptor().FullName(), name; got != want {
+		t.Fatalf("types.FindMessageByName(%q).Descriptor().FullName() = %q, want %q", name, got, want)
+	}
+}
+
+func TestDynamicTypesExtensionNotFound(t *testing.T) {
+	types := newTestTypes()
+	for _, name := range []protoreflect.FullName{
+		"string_field",
+		"testprotos.DoesNotExist",
+	} {
+		_, err := types.FindExtensionByName(name)
+		if err != protoregistry.NotFound {
+			t.Errorf("types.FindExtensionByName(%q) = _, %v; want protoregistry.NotFound", name, err)
+		}
+	}
+	messageName := protoreflect.FullName("testprotos.Message1")
+	if _, err := types.FindExtensionByNumber(messageName, 100); err != protoregistry.NotFound {
+		t.Errorf("types.FindExtensionByNumber(%q, 100) = _, %v; want protoregistry.NotFound", messageName, 100)
+	}
+}
+
+func TestDynamicTypesFindExtensionByNameOrNumber(t *testing.T) {
+	types := newTestTypes()
+	messageName := protoreflect.FullName("testprotos.Message1")
+	mt, err := types.FindMessageByName(messageName)
+	if err != nil {
+		t.Fatalf("types.FindMessageByName(%q) = %v", messageName, err)
+	}
+	for _, extensionName := range []protoreflect.FullName{
+		"testprotos.string_field",
+		"testprotos.Message4.message_field",
+	} {
+		xt, err := types.FindExtensionByName(extensionName)
+		if err != nil {
+			t.Fatalf("types.FindExtensionByName(%q) = %v", extensionName, err)
+		}
+		if got, want := xt.TypeDescriptor().FullName(), extensionName; got != want {
+			t.Fatalf("types.FindExtensionByName(%q).TypeDescriptor().FullName() = %q, want %q", extensionName, got, want)
+		}
+		if got, want := xt.TypeDescriptor().ContainingMessage(), mt.Descriptor(); got != want {
+			t.Fatalf("xt.TypeDescriptor().ContainingMessage() = %q, want %q", got.FullName(), want.FullName())
+		}
+		number := xt.TypeDescriptor().Number()
+		xt2, err := types.FindExtensionByNumber(messageName, number)
+		if err != nil {
+			t.Fatalf("types.FindExtensionByNumber(%q, %v) = %v", messageName, number, err)
+		}
+		if xt != xt2 {
+			t.Fatalf("FindExtensionByName returned a differet extension than FindExtensionByNumber")
+		}
+	}
+}
+
+func TestDynamicTypesFilesChangeAfterCreation(t *testing.T) {
+	files := &protoregistry.Files{}
+	files.RegisterFile(descriptorpb.File_google_protobuf_descriptor_proto)
+	types := dynamicpb.NewTypes(files)
+
+	// Not found: Files registry does not contain this file.
+	const message = "testprotos.Message1"
+	const number = 11
+	if _, err := types.FindMessageByName(message); err != protoregistry.NotFound {
+		t.Errorf("types.FindMessageByName(%q) = %v, want protoregistry.NotFound", message, err)
+	}
+	if _, err := types.FindExtensionByNumber(message, number); err != protoregistry.NotFound {
+		t.Errorf("types.FindExtensionByNumber(%q, %v) = %v, want protoregistry.NotFound", message, number, err)
+	}
+
+	// Found: Add the file to the registry and recheck.
+	files.RegisterFile(registrypb.File_internal_testprotos_registry_test_proto)
+	if _, err := types.FindMessageByName(message); err != nil {
+		t.Errorf("types.FindMessageByName(%q) = %v, want nil", message, err)
+	}
+	if _, err := types.FindExtensionByNumber(message, number); err != nil {
+		t.Errorf("types.FindExtensionByNumber(%q, %v) = %v, want nil", message, number, err)
+	}
+}