internal/impl: support legacy unknown fields

Add a wrapper data structure that allows the legacy XXX_unrecognized field to
support the new protoreflect.UnknownFields interface. This is a challenge since
the field is a []byte, which gives us little flexibility in the choice of
data structure.

This implementation is relatively naive: every operation is O(n) since
it needs to scan through the entire []byte each time. The Range operation
behaves slightly differently from ranging over Go maps since it presents a
stale snapshot of the RawFields should a mutation occur while ranging.
This distinction is unlikely to affect anyone in practice.

Change-Id: Ib3247cb827f9a0dd6c2192cd59830dca5eef8257
Reviewed-on: https://go-review.googlesource.com/c/144697
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/legacy_test.go b/internal/impl/legacy_test.go
index fc3a2ee..32253f3 100644
--- a/internal/impl/legacy_test.go
+++ b/internal/impl/legacy_test.go
@@ -5,9 +5,12 @@
 package impl
 
 import (
+	"bytes"
+	"math"
 	"reflect"
 	"testing"
 
+	"github.com/golang/protobuf/v2/internal/encoding/pack"
 	"github.com/golang/protobuf/v2/internal/pragma"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
 	ptype "github.com/golang/protobuf/v2/reflect/prototype"
@@ -25,7 +28,7 @@
 var fileDescLP2 = mustLoadFileDesc(LP2FileDescriptor)
 var fileDescLP3 = mustLoadFileDesc(LP3FileDescriptor)
 
-func TestLegacy(t *testing.T) {
+func TestLegacyDescriptor(t *testing.T) {
 	tests := []struct {
 		got  pref.Descriptor
 		want pref.Descriptor
@@ -133,3 +136,122 @@
 		})
 	}
 }
+
+func TestLegacyUnknown(t *testing.T) {
+	rawOf := func(toks ...pack.Token) pref.RawFields {
+		return pref.RawFields(pack.Message(toks).Marshal())
+	}
+	raw1a := rawOf(pack.Tag{1, pack.VarintType}, pack.Svarint(-4321))                // 08c143
+	raw1b := rawOf(pack.Tag{1, pack.Fixed32Type}, pack.Uint32(0xdeadbeef))           // 0defbeadde
+	raw1c := rawOf(pack.Tag{1, pack.Fixed64Type}, pack.Float64(math.Pi))             // 09182d4454fb210940
+	raw2a := rawOf(pack.Tag{2, pack.BytesType}, pack.String("hello, world!"))        // 120d68656c6c6f2c20776f726c6421
+	raw2b := rawOf(pack.Tag{2, pack.VarintType}, pack.Uvarint(1234))                 // 10d209
+	raw3a := rawOf(pack.Tag{3, pack.StartGroupType}, pack.Tag{3, pack.EndGroupType}) // 1b1c
+	raw3b := rawOf(pack.Tag{3, pack.BytesType}, pack.Bytes("\xde\xad\xbe\xef"))      // 1a04deadbeef
+
+	joinRaw := func(bs ...pref.RawFields) (out []byte) {
+		for _, b := range bs {
+			out = append(out, b...)
+		}
+		return out
+	}
+
+	var fs legacyUnknownBytes
+	if got, want := fs.Len(), 0; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := []byte(fs), joinRaw(); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+
+	fs.Set(1, raw1a)
+	fs.Set(1, append(fs.Get(1), raw1b...))
+	fs.Set(1, append(fs.Get(1), raw1c...))
+	if got, want := fs.Len(), 1; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := []byte(fs), joinRaw(raw1a, raw1b, raw1c); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+
+	fs.Set(2, raw2a)
+	if got, want := fs.Len(), 2; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := []byte(fs), joinRaw(raw1a, raw1b, raw1c, raw2a); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+
+	if got, want := fs.Get(1), joinRaw(raw1a, raw1b, raw1c); !bytes.Equal(got, want) {
+		t.Errorf("Get(%d) = %x, want %x", 1, got, want)
+	}
+	if got, want := fs.Get(2), joinRaw(raw2a); !bytes.Equal(got, want) {
+		t.Errorf("Get(%d) = %x, want %x", 2, got, want)
+	}
+	if got, want := fs.Get(3), joinRaw(); !bytes.Equal(got, want) {
+		t.Errorf("Get(%d) = %x, want %x", 3, got, want)
+	}
+
+	fs.Set(1, nil) // remove field 1
+	if got, want := fs.Len(), 1; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := []byte(fs), joinRaw(raw2a); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+
+	// Simulate manual appending of raw field data.
+	fs = append(fs, joinRaw(raw3a, raw1a, raw1b, raw3b, raw2b, raw1c)...)
+	if got, want := fs.Len(), 3; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+
+	// Verify range iteration order.
+	var i int
+	want := []struct {
+		num pref.FieldNumber
+		raw pref.RawFields
+	}{
+		{3, joinRaw(raw3a, raw3b)},
+		{2, joinRaw(raw2a, raw2b)},
+		{1, joinRaw(raw1a, raw1b, raw1c)},
+	}
+	fs.Range(func(num pref.FieldNumber, raw pref.RawFields) bool {
+		if i < len(want) {
+			if num != want[i].num || !bytes.Equal(raw, want[i].raw) {
+				t.Errorf("Range(%d) = (%d, %x), want (%d, %x)", i, num, raw, want[i].num, want[i].raw)
+			}
+		} else {
+			t.Errorf("unexpected Range iteration: %d", i)
+		}
+		i++
+		return true
+	})
+
+	fs.Set(2, fs.Get(2)) // moves field 2 to the end
+	if got, want := fs.Len(), 3; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := []byte(fs), joinRaw(raw3a, raw1a, raw1b, raw3b, raw1c, raw2a, raw2b); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+	fs.Set(1, nil) // remove field 1
+	if got, want := fs.Len(), 2; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := []byte(fs), joinRaw(raw3a, raw3b, raw2a, raw2b); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+
+	// Remove all fields.
+	fs.Range(func(n pref.FieldNumber, b pref.RawFields) bool {
+		fs.Set(n, nil)
+		return true
+	})
+	if got, want := fs.Len(), 0; got != want {
+		t.Errorf("Len() = %d, want %d", got, want)
+	}
+	if got, want := []byte(fs), joinRaw(); !bytes.Equal(got, want) {
+		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
+	}
+}
diff --git a/internal/impl/legacy_unknown.go b/internal/impl/legacy_unknown.go
new file mode 100644
index 0000000..c8c3a69
--- /dev/null
+++ b/internal/impl/legacy_unknown.go
@@ -0,0 +1,142 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package impl
+
+import (
+	"reflect"
+
+	protoV1 "github.com/golang/protobuf/proto"
+	"github.com/golang/protobuf/v2/internal/encoding/wire"
+	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+var (
+	extTypeA = reflect.TypeOf(map[int32]protoV1.Extension(nil))
+	extTypeB = reflect.TypeOf(protoV1.XXX_InternalExtensions{})
+)
+
+func generateLegacyUnknownFieldFuncs(t reflect.Type, md pref.MessageDescriptor) func(p *messageDataType) pref.UnknownFields {
+	fu, ok := t.FieldByName("XXX_unrecognized")
+	if !ok || fu.Type != bytesType {
+		return nil
+	}
+	fx1, _ := t.FieldByName("XXX_extensions")
+	fx2, _ := t.FieldByName("XXX_InternalExtensions")
+	if fx1.Type == extTypeA || fx2.Type == extTypeB {
+		// TODO: In proto v1, the unknown fields are split between both
+		// XXX_unrecognized and XXX_InternalExtensions. If the message supports
+		// extensions, then we will need to create a wrapper data structure
+		// that presents unknown fields in both lists as a single ordered list.
+		panic("not implemented")
+	}
+	fieldOffset := offsetOf(fu)
+	return func(p *messageDataType) pref.UnknownFields {
+		rv := p.p.apply(fieldOffset).asType(bytesType)
+		return (*legacyUnknownBytes)(rv.Interface().(*[]byte))
+	}
+}
+
+// legacyUnknownBytes is a wrapper around XXX_unrecognized that implements
+// the protoreflect.UnknownFields interface. This is challenging since the
+// underlying storage is limited to a []byte, which leaves little flexibility
+// in the choice of data structure.
+type legacyUnknownBytes []byte
+
+func (fs *legacyUnknownBytes) Len() int {
+	// Runtime complexity: O(n)
+	b := *fs
+	m := map[pref.FieldNumber]bool{}
+	for len(b) > 0 {
+		num, _, n := wire.ConsumeField(b)
+		m[num] = true
+		b = b[n:]
+	}
+	return len(m)
+}
+
+func (fs *legacyUnknownBytes) Get(num pref.FieldNumber) (raw pref.RawFields) {
+	// Runtime complexity: O(n)
+	b := *fs
+	for len(b) > 0 {
+		num2, _, n := wire.ConsumeField(b)
+		if num == num2 {
+			raw = append(raw, b[:n]...)
+		}
+		b = b[n:]
+	}
+	return raw
+}
+
+func (fs *legacyUnknownBytes) Set(num pref.FieldNumber, raw pref.RawFields) {
+	num2, _, _ := wire.ConsumeTag(raw)
+	if len(raw) > 0 && (!raw.IsValid() || num != num2) {
+		panic("invalid raw fields")
+	}
+
+	// Remove all current fields of num.
+	// Runtime complexity: O(n)
+	b := *fs
+	out := (*fs)[:0]
+	for len(b) > 0 {
+		num2, _, n := wire.ConsumeField(b)
+		if num != num2 {
+			out = append(out, b[:n]...)
+		}
+		b = b[n:]
+	}
+	*fs = out
+
+	// Append new fields of num.
+	*fs = append(*fs, raw...)
+}
+
+func (fs *legacyUnknownBytes) Range(f func(pref.FieldNumber, pref.RawFields) bool) {
+	type entry struct {
+		num pref.FieldNumber
+		raw pref.RawFields
+	}
+	var xs []entry
+
+	// Collect up a list of all the raw fields.
+	// We preserve the order such that the latest encountered fields
+	// are presented at the end.
+	//
+	// Runtime complexity: O(n)
+	b := *fs
+	m := map[pref.FieldNumber]int{}
+	for len(b) > 0 {
+		num, _, n := wire.ConsumeField(b)
+
+		// Ensure the most recently updated entry is always at the end of xs.
+		x := entry{num: num}
+		if i, ok := m[num]; ok {
+			j := len(xs) - 1
+			xs[i], xs[j] = xs[j], xs[i] // swap current entry with last entry
+			m[xs[i].num] = i            // update index of swapped entry
+			x = xs[j]                   // retrieve the last entry
+			xs = xs[:j]                 // truncate off the last entry
+		}
+		m[num] = len(xs)
+		x.raw = append(x.raw, b[:n]...)
+		xs = append(xs, x)
+
+		b = b[n:]
+	}
+
+	// Iterate over all the raw fields.
+	// This ranges over a snapshot of the current state such that mutations
+	// while ranging are not observable.
+	//
+	// Runtime complexity: O(n)
+	for _, x := range xs {
+		if !f(x.num, x.raw) {
+			return
+		}
+	}
+}
+
+func (fs *legacyUnknownBytes) IsSupported() bool {
+	return true
+}
diff --git a/internal/impl/message.go b/internal/impl/message.go
index 7552157..2b0a9ca 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -163,7 +163,10 @@
 }
 
 func (mi *MessageType) generateUnknownFieldFuncs(t reflect.Type, md pref.MessageDescriptor) {
-	// TODO
+	if f := generateLegacyUnknownFieldFuncs(t, md); f != nil {
+		mi.unknownFields = f
+		return
+	}
 	mi.unknownFields = func(*messageDataType) pref.UnknownFields {
 		return emptyUnknownFields{}
 	}
diff --git a/reflect/protoreflect/value.go b/reflect/protoreflect/value.go
index 7e83728..ed77d93 100644
--- a/reflect/protoreflect/value.go
+++ b/reflect/protoreflect/value.go
@@ -159,17 +159,20 @@
 // and also the wire data itself.
 //
 // Once stored, the content of a RawFields must be treated as immutable.
-// (e.g., raw[:len(raw)] is immutable, but raw[len(raw):cap(raw)] is mutable).
-// Thus, appending to RawFields (with valid wire data) is permitted.
+// The capacity of RawFields may be treated as mutable only for the use-case of
+// appending additional data to store back into UnknownFields.
 type RawFields []byte
 
 // IsValid reports whether RawFields is syntactically correct wire format.
+// All fields must belong to the same field number.
 func (b RawFields) IsValid() bool {
+	var want FieldNumber
 	for len(b) > 0 {
-		_, _, n := wire.ConsumeField(b)
-		if n < 0 {
+		got, _, n := wire.ConsumeField(b)
+		if n < 0 || (want > 0 && got != want) {
 			return false
 		}
+		want = got
 		b = b[n:]
 	}
 	return true