internal/impl: support legacy unknown extension fields

The unknown fields in legacy messages are split across the XXX_unrecognized
field and the XXX_InternalExtensions field. Implement support for
wrapping both fields and presenting them as if they were a unified set of
unknown fields.

Change-Id: If274fae2b48962520edd8a640080b6eced747684
Reviewed-on: https://go-review.googlesource.com/c/146517
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/legacy_extension.go b/internal/impl/legacy_extension.go
new file mode 100644
index 0000000..b229581
--- /dev/null
+++ b/internal/impl/legacy_extension.go
@@ -0,0 +1,126 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package impl
+
+import (
+	"reflect"
+	"sync"
+	"unsafe"
+
+	protoV1 "github.com/golang/protobuf/proto"
+	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+// TODO: The logic below this is a hack since v1 currently exposes no
+// exported functionality for interacting with these data structures.
+// Eventually make changes to v1 such that v2 can access the necessary
+// fields without relying on unsafe.
+
+var (
+	extTypeA = reflect.TypeOf(map[int32]protoV1.Extension(nil))
+	extTypeB = reflect.TypeOf(protoV1.XXX_InternalExtensions{})
+)
+
+type legacyExtensionIface interface {
+	Len() int
+	Get(pref.FieldNumber) legacyExtensionEntry
+	Set(pref.FieldNumber, legacyExtensionEntry)
+	Range(f func(pref.FieldNumber, legacyExtensionEntry) bool)
+}
+
+func makeLegacyExtensionMapFunc(t reflect.Type) func(*messageDataType) legacyExtensionIface {
+	fx1, _ := t.FieldByName("XXX_extensions")
+	fx2, _ := t.FieldByName("XXX_InternalExtensions")
+	switch {
+	case fx1.Type == extTypeA:
+		return func(p *messageDataType) legacyExtensionIface {
+			rv := p.p.asType(t).Elem()
+			return (*legacyExtensionMap)(unsafe.Pointer(rv.UnsafeAddr() + fx1.Offset))
+		}
+	case fx2.Type == extTypeB:
+		return func(p *messageDataType) legacyExtensionIface {
+			rv := p.p.asType(t).Elem()
+			return (*legacyExtensionSyncMap)(unsafe.Pointer(rv.UnsafeAddr() + fx2.Offset))
+		}
+	default:
+		return nil
+	}
+}
+
+// legacyExtensionSyncMap is identical to protoV1.XXX_InternalExtensions.
+// It implements legacyExtensionIface.
+type legacyExtensionSyncMap struct {
+	p *struct {
+		mu sync.Mutex
+		m  legacyExtensionMap
+	}
+}
+
+func (m legacyExtensionSyncMap) Len() int {
+	if m.p == nil {
+		return 0
+	}
+	m.p.mu.Lock()
+	defer m.p.mu.Unlock()
+	return m.p.m.Len()
+}
+func (m legacyExtensionSyncMap) Get(n pref.FieldNumber) legacyExtensionEntry {
+	if m.p == nil {
+		return legacyExtensionEntry{}
+	}
+	m.p.mu.Lock()
+	defer m.p.mu.Unlock()
+	return m.p.m.Get(n)
+}
+func (m *legacyExtensionSyncMap) Set(n pref.FieldNumber, x legacyExtensionEntry) {
+	if m.p == nil {
+		m.p = new(struct {
+			mu sync.Mutex
+			m  legacyExtensionMap
+		})
+	}
+	m.p.mu.Lock()
+	defer m.p.mu.Unlock()
+	m.p.m.Set(n, x)
+}
+func (m legacyExtensionSyncMap) Range(f func(pref.FieldNumber, legacyExtensionEntry) bool) {
+	if m.p == nil {
+		return
+	}
+	m.p.mu.Lock()
+	defer m.p.mu.Unlock()
+	m.p.m.Range(f)
+}
+
+// legacyExtensionMap is identical to map[int32]protoV1.Extension.
+// It implements legacyExtensionIface.
+type legacyExtensionMap map[pref.FieldNumber]legacyExtensionEntry
+
+func (m legacyExtensionMap) Len() int {
+	return len(m)
+}
+func (m legacyExtensionMap) Get(n pref.FieldNumber) legacyExtensionEntry {
+	return m[n]
+}
+func (m *legacyExtensionMap) Set(n pref.FieldNumber, x legacyExtensionEntry) {
+	if *m == nil {
+		*m = make(map[pref.FieldNumber]legacyExtensionEntry)
+	}
+	(*m)[n] = x
+}
+func (m legacyExtensionMap) Range(f func(pref.FieldNumber, legacyExtensionEntry) bool) {
+	for n, x := range m {
+		if !f(n, x) {
+			return
+		}
+	}
+}
+
+// legacyExtensionEntry is identical to protoV1.Extension.
+type legacyExtensionEntry struct {
+	desc *protoV1.ExtensionDesc
+	val  interface{}
+	raw  []byte
+}
diff --git a/internal/impl/legacy_test.go b/internal/impl/legacy_test.go
index d2e71d4..caa0043 100644
--- a/internal/impl/legacy_test.go
+++ b/internal/impl/legacy_test.go
@@ -10,6 +10,7 @@
 	"reflect"
 	"testing"
 
+	protoV1 "github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/v2/internal/encoding/pack"
 	"github.com/golang/protobuf/v2/internal/pragma"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
@@ -137,6 +138,15 @@
 	}
 }
 
+type legacyUnknownMessage struct {
+	XXX_unrecognized []byte
+	protoV1.XXX_InternalExtensions
+}
+
+func (*legacyUnknownMessage) ExtensionRangeArray() []protoV1.ExtensionRange {
+	return []protoV1.ExtensionRange{{Start: 10, End: 20}, {Start: 40, End: 80}}
+}
+
 func TestLegacyUnknown(t *testing.T) {
 	rawOf := func(toks ...pack.Token) pref.RawFields {
 		return pref.RawFields(pack.Message(toks).Marshal())
@@ -149,6 +159,17 @@
 	raw3a := rawOf(pack.Tag{3, pack.StartGroupType}, pack.Tag{3, pack.EndGroupType}) // 1b1c
 	raw3b := rawOf(pack.Tag{3, pack.BytesType}, pack.Bytes("\xde\xad\xbe\xef"))      // 1a04deadbeef
 
+	raw1 := rawOf(pack.Tag{1, pack.BytesType}, pack.Bytes("1"))    // 0a0131
+	raw3 := rawOf(pack.Tag{3, pack.BytesType}, pack.Bytes("3"))    // 1a0133
+	raw10 := rawOf(pack.Tag{10, pack.BytesType}, pack.Bytes("10")) // 52023130 - extension
+	raw15 := rawOf(pack.Tag{15, pack.BytesType}, pack.Bytes("15")) // 7a023135 - extension
+	raw26 := rawOf(pack.Tag{26, pack.BytesType}, pack.Bytes("26")) // d201023236
+	raw32 := rawOf(pack.Tag{32, pack.BytesType}, pack.Bytes("32")) // 8202023332
+	raw45 := rawOf(pack.Tag{45, pack.BytesType}, pack.Bytes("45")) // ea02023435 - extension
+	raw46 := rawOf(pack.Tag{46, pack.BytesType}, pack.Bytes("46")) // f202023436 - extension
+	raw47 := rawOf(pack.Tag{47, pack.BytesType}, pack.Bytes("47")) // fa02023437 - extension
+	raw99 := rawOf(pack.Tag{99, pack.BytesType}, pack.Bytes("99")) // 9a06023939
+
 	joinRaw := func(bs ...pref.RawFields) (out []byte) {
 		for _, b := range bs {
 			out = append(out, b...)
@@ -156,11 +177,13 @@
 		return out
 	}
 
-	var fs legacyUnknownBytes
+	m := new(legacyUnknownMessage)
+	fs := new(MessageType).MessageOf(m).UnknownFields()
+
 	if got, want := fs.Len(), 0; got != want {
 		t.Errorf("Len() = %d, want %d", got, want)
 	}
-	if got, want := []byte(fs), joinRaw(); !bytes.Equal(got, want) {
+	if got, want := m.XXX_unrecognized, joinRaw(); !bytes.Equal(got, want) {
 		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
 	}
 
@@ -170,7 +193,7 @@
 	if got, want := fs.Len(), 1; got != want {
 		t.Errorf("Len() = %d, want %d", got, want)
 	}
-	if got, want := []byte(fs), joinRaw(raw1a, raw1b, raw1c); !bytes.Equal(got, want) {
+	if got, want := m.XXX_unrecognized, joinRaw(raw1a, raw1b, raw1c); !bytes.Equal(got, want) {
 		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
 	}
 
@@ -178,7 +201,7 @@
 	if got, want := fs.Len(), 2; got != want {
 		t.Errorf("Len() = %d, want %d", got, want)
 	}
-	if got, want := []byte(fs), joinRaw(raw1a, raw1b, raw1c, raw2a); !bytes.Equal(got, want) {
+	if got, want := m.XXX_unrecognized, joinRaw(raw1a, raw1b, raw1c, raw2a); !bytes.Equal(got, want) {
 		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
 	}
 
@@ -196,12 +219,12 @@
 	if got, want := fs.Len(), 1; got != want {
 		t.Errorf("Len() = %d, want %d", got, want)
 	}
-	if got, want := []byte(fs), joinRaw(raw2a); !bytes.Equal(got, want) {
+	if got, want := m.XXX_unrecognized, joinRaw(raw2a); !bytes.Equal(got, want) {
 		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
 	}
 
 	// Simulate manual appending of raw field data.
-	fs = append(fs, joinRaw(raw3a, raw1a, raw1b, raw2b, raw3b, raw1c)...)
+	m.XXX_unrecognized = append(m.XXX_unrecognized, joinRaw(raw3a, raw1a, raw1b, raw2b, raw3b, raw1c)...)
 	if got, want := fs.Len(), 3; got != want {
 		t.Errorf("Len() = %d, want %d", got, want)
 	}
@@ -232,14 +255,14 @@
 	if got, want := fs.Len(), 3; got != want {
 		t.Errorf("Len() = %d, want %d", got, want)
 	}
-	if got, want := []byte(fs), joinRaw(raw3a, raw1a, raw1b, raw3b, raw1c, raw2a, raw2b); !bytes.Equal(got, want) {
+	if got, want := m.XXX_unrecognized, joinRaw(raw3a, raw1a, raw1b, raw3b, raw1c, raw2a, raw2b); !bytes.Equal(got, want) {
 		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
 	}
 	fs.Set(1, nil) // remove field 1
 	if got, want := fs.Len(), 2; got != want {
 		t.Errorf("Len() = %d, want %d", got, want)
 	}
-	if got, want := []byte(fs), joinRaw(raw3a, raw3b, raw2a, raw2b); !bytes.Equal(got, want) {
+	if got, want := m.XXX_unrecognized, joinRaw(raw3a, raw3b, raw2a, raw2b); !bytes.Equal(got, want) {
 		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
 	}
 
@@ -251,7 +274,102 @@
 	if got, want := fs.Len(), 0; got != want {
 		t.Errorf("Len() = %d, want %d", got, want)
 	}
-	if got, want := []byte(fs), joinRaw(); !bytes.Equal(got, want) {
+	if got, want := m.XXX_unrecognized, joinRaw(); !bytes.Equal(got, want) {
 		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
 	}
+
+	fs.Set(1, raw1)
+	if got, want := fs.Len(), 1; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := m.XXX_unrecognized, joinRaw(raw1); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+
+	fs.Set(45, raw45)
+	fs.Set(10, raw10) // extension
+	fs.Set(32, raw32)
+	fs.Set(1, nil) // deletion
+	fs.Set(26, raw26)
+	fs.Set(47, raw47) // extension
+	fs.Set(46, raw46) // extension
+	if got, want := fs.Len(), 6; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := m.XXX_unrecognized, joinRaw(raw32, raw26); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+
+	// Verify iteration order.
+	i = 0
+	want = []struct {
+		num pref.FieldNumber
+		raw pref.RawFields
+	}{
+		{32, raw32},
+		{26, raw26},
+		{10, raw10}, // extension
+		{45, raw45}, // extension
+		{46, raw46}, // extension
+		{47, raw47}, // extension
+	}
+	fs.Range(func(num pref.FieldNumber, raw pref.RawFields) bool {
+		if i < len(want) {
+			if num != want[i].num || !bytes.Equal(raw, want[i].raw) {
+				t.Errorf("Range(%d) = (%d, %x), want (%d, %x)", i, num, raw, want[i].num, want[i].raw)
+			}
+		} else {
+			t.Errorf("unexpected Range iteration: %d", i)
+		}
+		i++
+		return true
+	})
+
+	// Perform partial deletion while iterating.
+	i = 0
+	fs.Range(func(num pref.FieldNumber, raw pref.RawFields) bool {
+		if i%2 == 0 {
+			fs.Set(num, nil)
+		}
+		i++
+		return true
+	})
+
+	if got, want := fs.Len(), 3; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := m.XXX_unrecognized, joinRaw(raw26); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+
+	fs.Set(15, raw15) // extension
+	fs.Set(3, raw3)
+	fs.Set(99, raw99)
+	if got, want := fs.Len(), 6; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := m.XXX_unrecognized, joinRaw(raw26, raw3, raw99); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+
+	// Perform partial iteration.
+	i = 0
+	want = []struct {
+		num pref.FieldNumber
+		raw pref.RawFields
+	}{
+		{26, raw26},
+		{3, raw3},
+	}
+	fs.Range(func(num pref.FieldNumber, raw pref.RawFields) bool {
+		if i < len(want) {
+			if num != want[i].num || !bytes.Equal(raw, want[i].raw) {
+				t.Errorf("Range(%d) = (%d, %x), want (%d, %x)", i, num, raw, want[i].num, want[i].raw)
+			}
+		} else {
+			t.Errorf("unexpected Range iteration: %d", i)
+		}
+		i++
+		return i < 2
+	})
 }
diff --git a/internal/impl/legacy_unknown.go b/internal/impl/legacy_unknown.go
index a319f05..9ab617b 100644
--- a/internal/impl/legacy_unknown.go
+++ b/internal/impl/legacy_unknown.go
@@ -7,36 +7,104 @@
 import (
 	"container/list"
 	"reflect"
+	"sort"
 
-	protoV1 "github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/v2/internal/encoding/wire"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
 )
 
-var (
-	extTypeA = reflect.TypeOf(map[int32]protoV1.Extension(nil))
-	extTypeB = reflect.TypeOf(protoV1.XXX_InternalExtensions{})
-)
-
-func generateLegacyUnknownFieldFuncs(t reflect.Type, md pref.MessageDescriptor) func(p *messageDataType) pref.UnknownFields {
+func makeLegacyUnknownFieldsFunc(t reflect.Type) func(p *messageDataType) pref.UnknownFields {
 	fu, ok := t.FieldByName("XXX_unrecognized")
 	if !ok || fu.Type != bytesType {
 		return nil
 	}
-	fx1, _ := t.FieldByName("XXX_extensions")
-	fx2, _ := t.FieldByName("XXX_InternalExtensions")
-	if fx1.Type == extTypeA || fx2.Type == extTypeB {
-		// TODO: In proto v1, the unknown fields are split between both
-		// XXX_unrecognized and XXX_InternalExtensions. If the message supports
-		// extensions, then we will need to create a wrapper data structure
-		// that presents unknown fields in both lists as a single ordered list.
-		panic("not implemented")
-	}
 	fieldOffset := offsetOf(fu)
-	return func(p *messageDataType) pref.UnknownFields {
+	unkFunc := func(p *messageDataType) pref.UnknownFields {
 		rv := p.p.apply(fieldOffset).asType(bytesType)
 		return (*legacyUnknownBytes)(rv.Interface().(*[]byte))
 	}
+	extFunc := makeLegacyExtensionMapFunc(t)
+	if extFunc != nil {
+		return func(p *messageDataType) pref.UnknownFields {
+			return &legacyUnknownBytesAndExtensionMap{
+				unkFunc(p), extFunc(p), p.mi.Desc.ExtensionRanges(),
+			}
+		}
+	}
+	return unkFunc
+}
+
+// legacyUnknownBytesAndExtensionMap is a wrapper around both XXX_unrecognized
+// and also the extension field map.
+type legacyUnknownBytesAndExtensionMap struct {
+	u pref.UnknownFields
+	x legacyExtensionIface
+	r pref.FieldRanges
+}
+
+func (fs *legacyUnknownBytesAndExtensionMap) Len() int {
+	n := fs.u.Len()
+	fs.x.Range(func(_ pref.FieldNumber, x legacyExtensionEntry) bool {
+		if len(x.raw) > 0 {
+			n++
+		}
+		return true
+	})
+	return n
+}
+
+func (fs *legacyUnknownBytesAndExtensionMap) Get(num pref.FieldNumber) (raw pref.RawFields) {
+	if fs.r.Has(num) {
+		return fs.x.Get(num).raw
+	}
+	return fs.u.Get(num)
+}
+
+func (fs *legacyUnknownBytesAndExtensionMap) Set(num pref.FieldNumber, raw pref.RawFields) {
+	if fs.r.Has(num) {
+		x := fs.x.Get(num)
+		x.raw = raw
+		fs.x.Set(num, x)
+		return
+	}
+	fs.u.Set(num, raw)
+}
+
+func (fs *legacyUnknownBytesAndExtensionMap) Range(f func(pref.FieldNumber, pref.RawFields) bool) {
+	// Range over unknown fields not in the extension range.
+	// Create a closure around f to capture whether iteration terminated early.
+	var stop bool
+	fs.u.Range(func(n pref.FieldNumber, b pref.RawFields) bool {
+		stop = stop || !f(n, b)
+		return !stop
+	})
+	if stop {
+		return
+	}
+
+	// Range over unknown fields in the extension range in ascending order
+	// to ensure protoreflect.UnknownFields.Range remains deterministic.
+	type entry struct {
+		num pref.FieldNumber
+		raw pref.RawFields
+	}
+	var xs []entry
+	fs.x.Range(func(n pref.FieldNumber, x legacyExtensionEntry) bool {
+		if len(x.raw) > 0 {
+			xs = append(xs, entry{n, x.raw})
+		}
+		return true
+	})
+	sort.Slice(xs, func(i, j int) bool { return xs[i].num < xs[j].num })
+	for _, x := range xs {
+		if !f(x.num, x.raw) {
+			return
+		}
+	}
+}
+
+func (fs *legacyUnknownBytesAndExtensionMap) IsSupported() bool {
+	return true
 }
 
 // legacyUnknownBytes is a wrapper around XXX_unrecognized that implements
diff --git a/internal/impl/message.go b/internal/impl/message.go
index 2b0a9ca..a3e4b9d 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -53,9 +53,8 @@
 		mi.goType = t
 
 		// Derive the message descriptor if unspecified.
-		md := mi.Desc
-		if md == nil {
-			// TODO: derive the message type from the Go struct type
+		if mi.Desc == nil {
+			mi.Desc = loadMessageDesc(t)
 		}
 
 		// Initialize the Go message type wrapper if the Go type does not
@@ -68,7 +67,7 @@
 		// Generated code ensures that this property holds.
 		if _, ok := p.(pref.ProtoMessage); !ok {
 			mi.pbType = ptype.NewGoMessage(&ptype.GoMessage{
-				MessageDescriptor: md,
+				MessageDescriptor: mi.Desc,
 				New: func(pref.MessageType) pref.ProtoMessage {
 					p := reflect.New(t.Elem()).Interface()
 					return (*message)(mi.dataTypeOf(p))
@@ -76,9 +75,9 @@
 			})
 		}
 
-		mi.generateKnownFieldFuncs(t.Elem(), md)
-		mi.generateUnknownFieldFuncs(t.Elem(), md)
-		mi.generateExtensionFieldFuncs(t.Elem(), md)
+		mi.makeKnownFieldsFunc(t.Elem())
+		mi.makeUnknownFieldsFunc(t.Elem())
+		mi.makeExtensionFieldsFunc(t.Elem())
 	})
 
 	// TODO: Remove this check? This API is primarily used by generated code,
@@ -90,14 +89,14 @@
 	}
 }
 
-// generateKnownFieldFuncs generates per-field functions for all operations
+// makeKnownFieldsFunc generates per-field functions for all operations
 // to be performed on each field. It takes in a reflect.Type representing the
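-// Go struct, and a protoreflect.MessageDescriptor to match with the fields
-// in the struct.
+// Go struct, which is matched against the fields of the message descriptor
+// stored in mi.Desc.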
 //
 // This code assumes that the struct is well-formed and panics if there are
 // any discrepancies.
-func (mi *MessageType) generateKnownFieldFuncs(t reflect.Type, md pref.MessageDescriptor) {
+func (mi *MessageType) makeKnownFieldsFunc(t reflect.Type) {
 	// Generate a mapping of field numbers and names to Go struct field or type.
 	fields := map[pref.FieldNumber]reflect.StructField{}
 	oneofs := map[pref.Name]reflect.StructField{}
@@ -140,8 +139,8 @@
 	}
 
 	mi.fields = map[pref.FieldNumber]*fieldInfo{}
-	for i := 0; i < md.Fields().Len(); i++ {
-		fd := md.Fields().Get(i)
+	for i := 0; i < mi.Desc.Fields().Len(); i++ {
+		fd := mi.Desc.Fields().Get(i)
 		fs := fields[fd.Number()]
 		var fi fieldInfo
 		switch {
@@ -162,8 +161,8 @@
 	}
 }
 
-func (mi *MessageType) generateUnknownFieldFuncs(t reflect.Type, md pref.MessageDescriptor) {
-	if f := generateLegacyUnknownFieldFuncs(t, md); f != nil {
+func (mi *MessageType) makeUnknownFieldsFunc(t reflect.Type) {
+	if f := makeLegacyUnknownFieldsFunc(t); f != nil {
 		mi.unknownFields = f
 		return
 	}
@@ -172,7 +171,7 @@
 	}
 }
 
-func (mi *MessageType) generateExtensionFieldFuncs(t reflect.Type, md pref.MessageDescriptor) {
+func (mi *MessageType) makeExtensionFieldsFunc(t reflect.Type) {
 	// TODO
 	mi.extensionFields = func(*messageDataType) pref.KnownFields {
 		return emptyExtensionFields{}