internal/legacy: improve performance of extension descriptor conversions

Converting to/from v1/v2 extension descriptor types is a common operation
for v1 and v2 interoperability. Optimize these operations with a cache.

Change-Id: I5feca810f60376847c791654982acd3b6a37a5db
Reviewed-on: https://go-review.googlesource.com/c/152542
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/internal/legacy/extension.go b/internal/legacy/extension.go
index 700bc43..662e75c 100644
--- a/internal/legacy/extension.go
+++ b/internal/legacy/extension.go
@@ -7,6 +7,7 @@
 import (
 	"fmt"
 	"reflect"
+	"sync"
 
 	papi "github.com/golang/protobuf/protoapi"
 	ptag "github.com/golang/protobuf/v2/internal/encoding/tag"
@@ -16,9 +17,39 @@
 	ptype "github.com/golang/protobuf/v2/reflect/prototype"
 )
 
+// extensionDescKey is a comparable version of protoapi.ExtensionDesc
+// suitable for use as a key in a map.
+type extensionDescKey struct {
+	typeV2        pref.ExtensionType
+	extendedType  reflect.Type
+	extensionType reflect.Type
+	field         int32
+	name          string
+	tag           string
+	filename      string
+}
+
+func extensionDescKeyOf(d *papi.ExtensionDesc) extensionDescKey {
+	return extensionDescKey{
+		d.Type,
+		reflect.TypeOf(d.ExtendedType),
+		reflect.TypeOf(d.ExtensionType),
+		d.Field, d.Name, d.Tag, d.Filename,
+	}
+}
+
+var (
+	extensionTypeCache sync.Map // map[extensionDescKey]protoreflect.ExtensionType
+	extensionDescCache sync.Map // map[protoreflect.ExtensionType]*protoapi.ExtensionDesc
+)
+
+// legacyExtensionDescFromType converts a v2 protoreflect.ExtensionType to a
+// v1 protoapi.ExtensionDesc. The returned ExtensionDesc must not be mutated.
 func legacyExtensionDescFromType(t pref.ExtensionType) *papi.ExtensionDesc {
-	if t, ok := t.(dualExtensionType); ok {
-		return t.desc
+	// Fast-path: check the cache for whether this ExtensionType has already
+	// been converted to a legacy descriptor.
+	if d, ok := extensionDescCache.Load(t); ok {
+		return d.(*papi.ExtensionDesc)
 	}
 
 	// Determine the parent type if possible.
@@ -86,7 +117,7 @@
 	}
 
 	// Construct and return a v1 ExtensionDesc.
-	return &papi.ExtensionDesc{
+	d := &papi.ExtensionDesc{
 		Type:          t,
 		ExtendedType:  parent,
 		ExtensionType: reflect.Zero(extType).Interface(),
@@ -95,11 +126,29 @@
 		Tag:           ptag.Marshal(t, enumName),
 		Filename:      filename,
 	}
+	extensionDescCache.Store(t, d)
+	return d
 }
 
+// legacyExtensionTypeFromDesc converts a v1 protoapi.ExtensionDesc to a
+// v2 protoreflect.ExtensionType. The returned descriptor type takes ownership
+// of the input extension desc. The input must not be mutated so long as the
+// returned type is still in use.
 func legacyExtensionTypeFromDesc(d *papi.ExtensionDesc) pref.ExtensionType {
+	// Fast-path: check whether an extension type is already nested within.
 	if d.Type != nil {
-		return dualExtensionType{d.Type, d}
+		// Cache descriptor for future legacyExtensionDescFromType operation.
+		// This assumes that there is only one legacy protoapi.ExtensionDesc
+		// that wraps any given specific protoreflect.ExtensionType.
+		extensionDescCache.LoadOrStore(d.Type, d)
+		return d.Type
+	}
+
+	// Fast-path: check the cache for whether this ExtensionType has already
+	// been converted from a legacy descriptor.
+	dk := extensionDescKeyOf(d)
+	if t, ok := extensionTypeCache.Load(dk); ok {
+		return t.(pref.ExtensionType)
 	}
 
 	// Derive basic field information from the struct tag.
@@ -128,16 +177,13 @@
 		panic(err)
 	}
 	xt := pimpl.Export{}.ExtensionTypeOf(xd, reflect.Zero(t).Interface())
-	return dualExtensionType{xt, d}
-}
 
-type dualExtensionType struct {
-	pref.ExtensionType
-	desc *papi.ExtensionDesc
+	// Cache the conversion for both directions.
+	extensionDescCache.Store(xt, d)
+	extensionTypeCache.Store(dk, xt)
+	return xt
 }
 
-// TODO: Provide custom stringer for dualExtensionType.
-
 // legacyExtensionTypeOf returns a protoreflect.ExtensionType where the GoType
 // is the underlying v1 Go type instead of the wrapper types used to present
 // v1 Go types as if they satisfied the v2 API.
diff --git a/internal/legacy/extension_test.go b/internal/legacy/extension_test.go
new file mode 100644
index 0000000..9f431bd
--- /dev/null
+++ b/internal/legacy/extension_test.go
@@ -0,0 +1,78 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package legacy_test
+
+import (
+	"testing"
+
+	papi "github.com/golang/protobuf/protoapi"
+	pimpl "github.com/golang/protobuf/v2/internal/impl"
+	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+	ptype "github.com/golang/protobuf/v2/reflect/prototype"
+
+	// The legacy package must be imported prior to use of any legacy messages.
+	// TODO: Remove this when protoV1 registers these hooks for you.
+	plegacy "github.com/golang/protobuf/v2/internal/legacy"
+
+	proto2_20180125 "github.com/golang/protobuf/v2/internal/testprotos/legacy/proto2.v1.0.0-20180125-92554152"
+)
+
+type legacyTestMessage struct {
+	XXX_unrecognized []byte
+	papi.XXX_InternalExtensions
+}
+
+func (*legacyTestMessage) Reset()         {}
+func (*legacyTestMessage) String() string { return "" }
+func (*legacyTestMessage) ProtoMessage()  {}
+func (*legacyTestMessage) ExtensionRangeArray() []papi.ExtensionRange {
+	return []papi.ExtensionRange{{Start: 10000, End: 20000}}
+}
+
+func mustMakeExtensionType(x *ptype.StandaloneExtension, v interface{}) pref.ExtensionType {
+	xd, err := ptype.NewExtension(x)
+	if err != nil {
+		panic(xd)
+	}
+	return pimpl.Export{}.ExtensionTypeOf(xd, v)
+}
+
+var (
+	parentType    = pimpl.Export{}.MessageTypeOf((*legacyTestMessage)(nil))
+	messageV1Type = pimpl.Export{}.MessageTypeOf((*proto2_20180125.Message_ChildMessage)(nil))
+
+	wantType = mustMakeExtensionType(&ptype.StandaloneExtension{
+		FullName:     "fizz.buzz.optional_message_v1",
+		Number:       10007,
+		Cardinality:  pref.Optional,
+		Kind:         pref.MessageKind,
+		MessageType:  messageV1Type,
+		ExtendedType: parentType,
+	}, (*proto2_20180125.Message_ChildMessage)(nil))
+	wantDesc = &papi.ExtensionDesc{
+		ExtendedType:  (*legacyTestMessage)(nil),
+		ExtensionType: (*proto2_20180125.Message_ChildMessage)(nil),
+		Field:         10007,
+		Name:          "fizz.buzz.optional_message_v1",
+		Tag:           "bytes,10007,opt,name=optional_message_v1",
+	}
+)
+
+func BenchmarkConvert(b *testing.B) {
+	b.ReportAllocs()
+	for i := 0; i < b.N; i++ {
+		xd := plegacy.Export{}.ExtensionDescFromType(wantType)
+		gotType := plegacy.Export{}.ExtensionTypeFromDesc(xd)
+		if gotType != wantType {
+			b.Fatalf("ExtensionType mismatch: got %p, want %p", gotType, wantType)
+		}
+
+		xt := plegacy.Export{}.ExtensionTypeFromDesc(wantDesc)
+		gotDesc := plegacy.Export{}.ExtensionDescFromType(xt)
+		if gotDesc != wantDesc {
+			b.Fatalf("ExtensionDesc mismatch: got %p, want %p", gotDesc, wantDesc)
+		}
+	}
+}