internal/impl: add runtime support for *[]byte unknown representation

This CL adds runtime support for unknown fields to be represented
as *[]byte in addition to the current representation as []byte.
This CL does not change generated code to use *[]byte.

Comparison between using *[]byte and []byte:

• Every message supports unknown fields, so use of []byte
  expands a message size by 24B (for 64-bit systems).
  In contrast, *[]byte only expands a message by 8B.
  This has significant memory implications for small messages.

• If unknown fields are encountered, *[]byte has extra overhead
  allocating the 24B slice header. However, it is assumed
  that messages rarely see any unknown fields at runtime,
  or generally do so for a temporary period of time.

Change-Id: I81935e4ea7394166e61ff4579f76f59fa792dfc9
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/244937
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go
index 733ad9f..e108642 100644
--- a/internal/impl/codec_message.go
+++ b/internal/impl/codec_message.go
@@ -27,6 +27,7 @@
 	coderFields        map[protowire.Number]*coderFieldInfo
 	sizecacheOffset    offset
 	unknownOffset      offset
+	unknownPtrKind     bool
 	extensionOffset    offset
 	needsInitCheck     bool
 	isMessageSet       bool
@@ -47,9 +48,20 @@
 }
 
 func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
-	mi.sizecacheOffset = si.sizecacheOffset
-	mi.unknownOffset = si.unknownOffset
-	mi.extensionOffset = si.extensionOffset
+	mi.sizecacheOffset = invalidOffset
+	mi.unknownOffset = invalidOffset
+	mi.extensionOffset = invalidOffset
+
+	if si.sizecacheOffset.IsValid() && si.sizecacheType == sizecacheType {
+		mi.sizecacheOffset = si.sizecacheOffset
+	}
+	if si.unknownOffset.IsValid() && (si.unknownType == unknownFieldsAType || si.unknownType == unknownFieldsBType) {
+		mi.unknownOffset = si.unknownOffset
+		mi.unknownPtrKind = si.unknownType.Kind() == reflect.Ptr
+	}
+	if si.extensionOffset.IsValid() && si.extensionType == extensionFieldsType {
+		mi.extensionOffset = si.extensionOffset
+	}
 
 	mi.coderFields = make(map[protowire.Number]*coderFieldInfo)
 	fields := mi.Desc.Fields()
@@ -157,3 +169,28 @@
 		mi.methods.Merge = mi.merge
 	}
 }
+
+// getUnknownBytes returns a *[]byte for the unknown fields.
+// It is the caller's responsibility to check whether the pointer is nil.
+// This function is specially designed to be inlineable.
+func (mi *MessageInfo) getUnknownBytes(p pointer) *[]byte {
+	if mi.unknownPtrKind {
+		return *p.Apply(mi.unknownOffset).BytesPtr()
+	} else {
+		return p.Apply(mi.unknownOffset).Bytes()
+	}
+}
+
+// mutableUnknownBytes returns a *[]byte for the unknown fields.
+// The returned pointer is guaranteed to not be nil.
+func (mi *MessageInfo) mutableUnknownBytes(p pointer) *[]byte {
+	if mi.unknownPtrKind {
+		bp := p.Apply(mi.unknownOffset).BytesPtr()
+		if *bp == nil {
+			*bp = new([]byte)
+		}
+		return *bp
+	} else {
+		return p.Apply(mi.unknownOffset).Bytes()
+	}
+}
diff --git a/internal/impl/codec_messageset.go b/internal/impl/codec_messageset.go
index cfb68e1..b7a23fa 100644
--- a/internal/impl/codec_messageset.go
+++ b/internal/impl/codec_messageset.go
@@ -29,8 +29,9 @@
 		size += xi.funcs.size(x.Value(), protowire.SizeTag(messageset.FieldMessage), opts)
 	}
 
-	unknown := *p.Apply(mi.unknownOffset).Bytes()
-	size += messageset.SizeUnknown(unknown)
+	if u := mi.getUnknownBytes(p); u != nil {
+		size += messageset.SizeUnknown(*u)
+	}
 
 	return size
 }
@@ -69,10 +70,12 @@
 		}
 	}
 
-	unknown := *p.Apply(mi.unknownOffset).Bytes()
-	b, err := messageset.AppendUnknown(b, unknown)
-	if err != nil {
-		return b, err
+	if u := mi.getUnknownBytes(p); u != nil {
+		var err error
+		b, err = messageset.AppendUnknown(b, *u)
+		if err != nil {
+			return b, err
+		}
 	}
 
 	return b, nil
@@ -100,13 +103,13 @@
 		*ep = make(map[int32]ExtensionField)
 	}
 	ext := *ep
-	unknown := p.Apply(mi.unknownOffset).Bytes()
 	initialized := true
 	err = messageset.Unmarshal(b, true, func(num protowire.Number, v []byte) error {
 		o, err := mi.unmarshalExtension(v, num, protowire.BytesType, ext, opts)
 		if err == errUnknown {
-			*unknown = protowire.AppendTag(*unknown, num, protowire.BytesType)
-			*unknown = append(*unknown, v...)
+			u := mi.mutableUnknownBytes(p)
+			*u = protowire.AppendTag(*u, num, protowire.BytesType)
+			*u = append(*u, v...)
 			return nil
 		}
 		if !o.initialized {
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 2dcbdec..949dc49 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -175,7 +175,7 @@
 				return out, errDecode
 			}
 			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
-				u := p.Apply(mi.unknownOffset).Bytes()
+				u := mi.mutableUnknownBytes(p)
 				*u = protowire.AppendTag(*u, num, wtyp)
 				*u = append(*u, b[:n]...)
 			}
diff --git a/internal/impl/encode.go b/internal/impl/encode.go
index 8c8a794..845c67d 100644
--- a/internal/impl/encode.go
+++ b/internal/impl/encode.go
@@ -79,8 +79,9 @@
 		size += f.funcs.size(fptr, f, opts)
 	}
 	if mi.unknownOffset.IsValid() {
-		u := *p.Apply(mi.unknownOffset).Bytes()
-		size += len(u)
+		if u := mi.getUnknownBytes(p); u != nil {
+			size += len(*u)
+		}
 	}
 	if mi.sizecacheOffset.IsValid() {
 		if size > math.MaxInt32 {
@@ -141,8 +142,9 @@
 		}
 	}
 	if mi.unknownOffset.IsValid() && !mi.isMessageSet {
-		u := *p.Apply(mi.unknownOffset).Bytes()
-		b = append(b, u...)
+		if u := mi.getUnknownBytes(p); u != nil {
+			b = append(b, (*u)...)
+		}
 	}
 	return b, nil
 }
diff --git a/internal/impl/merge.go b/internal/impl/merge.go
index cdc4267..c65bbc0 100644
--- a/internal/impl/merge.go
+++ b/internal/impl/merge.go
@@ -77,9 +77,9 @@
 		}
 	}
 	if mi.unknownOffset.IsValid() {
-		du := dst.Apply(mi.unknownOffset).Bytes()
-		su := src.Apply(mi.unknownOffset).Bytes()
-		if len(*su) > 0 {
+		su := mi.getUnknownBytes(src)
+		if su != nil && len(*su) > 0 {
+			du := mi.mutableUnknownBytes(dst)
 			*du = append(*du, *su...)
 		}
 	}
diff --git a/internal/impl/message.go b/internal/impl/message.go
index c8c3859..a104e28 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -110,22 +110,29 @@
 type (
 	SizeCache       = int32
 	WeakFields      = map[int32]protoreflect.ProtoMessage
-	UnknownFields   = []byte
+	UnknownFields   = unknownFieldsA // TODO: switch to unknownFieldsB
+	unknownFieldsA  = []byte
+	unknownFieldsB  = *[]byte
 	ExtensionFields = map[int32]ExtensionField
 )
 
 var (
 	sizecacheType       = reflect.TypeOf(SizeCache(0))
 	weakFieldsType      = reflect.TypeOf(WeakFields(nil))
-	unknownFieldsType   = reflect.TypeOf(UnknownFields(nil))
+	unknownFieldsAType  = reflect.TypeOf(unknownFieldsA(nil))
+	unknownFieldsBType  = reflect.TypeOf(unknownFieldsB(nil))
 	extensionFieldsType = reflect.TypeOf(ExtensionFields(nil))
 )
 
 type structInfo struct {
 	sizecacheOffset offset
+	sizecacheType   reflect.Type
 	weakOffset      offset
+	weakType        reflect.Type
 	unknownOffset   offset
+	unknownType     reflect.Type
 	extensionOffset offset
+	extensionType   reflect.Type
 
 	fieldsByNumber        map[pref.FieldNumber]reflect.StructField
 	oneofsByName          map[pref.Name]reflect.StructField
@@ -152,18 +159,22 @@
 		case genid.SizeCache_goname, genid.SizeCacheA_goname:
 			if f.Type == sizecacheType {
 				si.sizecacheOffset = offsetOf(f, mi.Exporter)
+				si.sizecacheType = f.Type
 			}
 		case genid.WeakFields_goname, genid.WeakFieldsA_goname:
 			if f.Type == weakFieldsType {
 				si.weakOffset = offsetOf(f, mi.Exporter)
+				si.weakType = f.Type
 			}
 		case genid.UnknownFields_goname, genid.UnknownFieldsA_goname:
-			if f.Type == unknownFieldsType {
+			if f.Type == unknownFieldsAType || f.Type == unknownFieldsBType {
 				si.unknownOffset = offsetOf(f, mi.Exporter)
+				si.unknownType = f.Type
 			}
 		case genid.ExtensionFields_goname, genid.ExtensionFieldsA_goname, genid.ExtensionFieldsB_goname:
 			if f.Type == extensionFieldsType {
 				si.extensionOffset = offsetOf(f, mi.Exporter)
+				si.extensionType = f.Type
 			}
 		default:
 			for _, s := range strings.Split(f.Tag.Get("protobuf"), ",") {
diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go
index 41cd090..0c6f106 100644
--- a/internal/impl/message_reflect.go
+++ b/internal/impl/message_reflect.go
@@ -108,24 +108,44 @@
 }
 
 func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type, si structInfo) {
-	mi.getUnknown = func(pointer) pref.RawFields { return nil }
-	mi.setUnknown = func(pointer, pref.RawFields) { return }
-	if si.unknownOffset.IsValid() {
+	switch {
+	case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsAType:
+		// Handle as []byte.
 		mi.getUnknown = func(p pointer) pref.RawFields {
 			if p.IsNil() {
 				return nil
 			}
-			rv := p.Apply(si.unknownOffset).AsValueOf(unknownFieldsType)
-			return pref.RawFields(*rv.Interface().(*[]byte))
+			return *p.Apply(mi.unknownOffset).Bytes()
 		}
 		mi.setUnknown = func(p pointer, b pref.RawFields) {
 			if p.IsNil() {
 				panic("invalid SetUnknown on nil Message")
 			}
-			rv := p.Apply(si.unknownOffset).AsValueOf(unknownFieldsType)
-			*rv.Interface().(*[]byte) = []byte(b)
+			*p.Apply(mi.unknownOffset).Bytes() = b
 		}
-	} else {
+	case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsBType:
+		// Handle as *[]byte.
+		mi.getUnknown = func(p pointer) pref.RawFields {
+			if p.IsNil() {
+				return nil
+			}
+			bp := p.Apply(mi.unknownOffset).BytesPtr()
+			if *bp == nil {
+				return nil
+			}
+			return **bp
+		}
+		mi.setUnknown = func(p pointer, b pref.RawFields) {
+			if p.IsNil() {
+				panic("invalid SetUnknown on nil Message")
+			}
+			bp := p.Apply(mi.unknownOffset).BytesPtr()
+			if *bp == nil {
+				*bp = new([]byte)
+			}
+			**bp = b
+		}
+	default:
 		mi.getUnknown = func(pointer) pref.RawFields {
 			return nil
 		}
diff --git a/internal/impl/message_reflect_test.go b/internal/impl/message_reflect_test.go
index 66c2536..96c7e1b 100644
--- a/internal/impl/message_reflect_test.go
+++ b/internal/impl/message_reflect_test.go
@@ -22,6 +22,7 @@
 	pdesc "google.golang.org/protobuf/reflect/protodesc"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/reflect/protoregistry"
+	"google.golang.org/protobuf/testing/protopack"
 
 	proto2_20180125 "google.golang.org/protobuf/internal/testprotos/legacy/proto2_20180125_92554152"
 	testpb "google.golang.org/protobuf/internal/testprotos/test"
@@ -1425,6 +1426,43 @@
 	return strings.Join(ss, ".")
 }
 
+type UnknownFieldsA struct {
+	XXX_unrecognized []byte
+}
+
+var unknownFieldsAType = pimpl.MessageInfo{
+	GoReflectType: reflect.TypeOf(new(UnknownFieldsA)),
+	Desc:          mustMakeMessageDesc("unknown.proto", pref.Proto2, "", `name: "UnknownFieldsA"`, nil),
+}
+
+func (m *UnknownFieldsA) ProtoReflect() pref.Message { return unknownFieldsAType.MessageOf(m) }
+
+type UnknownFieldsB struct {
+	XXX_unrecognized *[]byte
+}
+
+var unknownFieldsBType = pimpl.MessageInfo{
+	GoReflectType: reflect.TypeOf(new(UnknownFieldsB)),
+	Desc:          mustMakeMessageDesc("unknown.proto", pref.Proto2, "", `name: "UnknownFieldsB"`, nil),
+}
+
+func (m *UnknownFieldsB) ProtoReflect() pref.Message { return unknownFieldsBType.MessageOf(m) }
+
+func TestUnknownFields(t *testing.T) {
+	for _, m := range []proto.Message{new(UnknownFieldsA), new(UnknownFieldsB)} {
+		t.Run(reflect.TypeOf(m).Elem().Name(), func(t *testing.T) {
+			want := protopack.Message{
+				protopack.Tag{1, protopack.BytesType}, protopack.String("Hello, world!"),
+			}.Marshal()
+			m.ProtoReflect().SetUnknown(want)
+			got := []byte(m.ProtoReflect().GetUnknown())
+			if diff := cmp.Diff(want, got); diff != "" {
+				t.Errorf("UnknownFields mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
 func TestReset(t *testing.T) {
 	mi := new(testpb.TestAllTypes)
 
diff --git a/internal/impl/pointer_reflect.go b/internal/impl/pointer_reflect.go
index 67b4ede..9e3ed82 100644
--- a/internal/impl/pointer_reflect.go
+++ b/internal/impl/pointer_reflect.go
@@ -121,6 +121,7 @@
 func (p pointer) StringPtr() **string      { return p.v.Interface().(**string) }
 func (p pointer) StringSlice() *[]string   { return p.v.Interface().(*[]string) }
 func (p pointer) Bytes() *[]byte           { return p.v.Interface().(*[]byte) }
+func (p pointer) BytesPtr() **[]byte       { return p.v.Interface().(**[]byte) }
 func (p pointer) BytesSlice() *[][]byte    { return p.v.Interface().(*[][]byte) }
 func (p pointer) WeakFields() *weakFields  { return (*weakFields)(p.v.Interface().(*WeakFields)) }
 func (p pointer) Extensions() *map[int32]ExtensionField {
diff --git a/internal/impl/pointer_unsafe.go b/internal/impl/pointer_unsafe.go
index 088aa85..9ecf23a 100644
--- a/internal/impl/pointer_unsafe.go
+++ b/internal/impl/pointer_unsafe.go
@@ -109,6 +109,7 @@
 func (p pointer) StringPtr() **string                   { return (**string)(p.p) }
 func (p pointer) StringSlice() *[]string                { return (*[]string)(p.p) }
 func (p pointer) Bytes() *[]byte                        { return (*[]byte)(p.p) }
+func (p pointer) BytesPtr() **[]byte                    { return (**[]byte)(p.p) }
 func (p pointer) BytesSlice() *[][]byte                 { return (*[][]byte)(p.p) }
 func (p pointer) WeakFields() *weakFields               { return (*weakFields)(p.p) }
 func (p pointer) Extensions() *map[int32]ExtensionField { return (*map[int32]ExtensionField)(p.p) }