proto, internal/impl: implement support for weak fields

Change-Id: I0a3ff79542a3316295fd6c58e1447e597be97ab9
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/189923
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/encoding/protojson/decode.go b/encoding/protojson/decode.go
index faa56b0..db8a3df 100644
--- a/encoding/protojson/decode.go
+++ b/encoding/protojson/decode.go
@@ -201,6 +201,8 @@
 					fd = nil // reset since field name is actually the message name
 				}
 			}
+		}
+		if flags.ProtoLegacy {
 			if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
 				fd = nil // reset since the weak reference is not linked in
 			}
diff --git a/encoding/prototext/decode.go b/encoding/prototext/decode.go
index 27d23d3..8951e6b 100644
--- a/encoding/prototext/decode.go
+++ b/encoding/prototext/decode.go
@@ -138,8 +138,10 @@
 		} else if xtErr != nil && xtErr != protoregistry.NotFound {
 			return errors.New("unable to resolve: %v", xtErr)
 		}
-		if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
-			fd = nil // reset since the weak reference is not linked in
+		if flags.ProtoLegacy {
+			if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
+				fd = nil // reset since the weak reference is not linked in
+			}
 		}
 
 		// Handle unknown fields.
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go
index 5414635..45b4664 100644
--- a/internal/impl/codec_field.go
+++ b/internal/impl/codec_field.go
@@ -6,10 +6,13 @@
 
 import (
 	"reflect"
+	"sync"
 
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/proto"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
+	preg "google.golang.org/protobuf/reflect/protoregistry"
+	piface "google.golang.org/protobuf/runtime/protoiface"
 )
 
 type errInvalidUTF8 struct{}
@@ -17,7 +20,7 @@
 func (errInvalidUTF8) Error() string     { return "string field contains invalid UTF-8" }
 func (errInvalidUTF8) InvalidUTF8() bool { return true }
 
-func makeOneofFieldCoder(si structInfo, fd pref.FieldDescriptor) pointerCoderFuncs {
+func makeOneofFieldCoder(fd pref.FieldDescriptor, si structInfo) pointerCoderFuncs {
 	ot := si.oneofWrappersByNumber[fd.Number()]
 	funcs := fieldCoder(fd, ot.Field(0).Type)
 	fs := si.oneofsByName[fd.ContainingOneof().Name()]
@@ -78,6 +81,61 @@
 	return pcf
 }
 
+// makeWeakMessageFieldCoder returns the coder functions for weak message field
+// fd. Values live in the message's weak-fields map under the field number; the
+// message type is looked up in the global registry on first unmarshal of an
+// absent field.
+func makeWeakMessageFieldCoder(fd pref.FieldDescriptor) pointerCoderFuncs {
+	var (
+		resolve sync.Once
+		mt      pref.MessageType // nil until resolved; still nil if the lookup yields no type
+	)
+
+	// key indexes this field's entry in the weak-fields map.
+	key := int32(fd.Number())
+	return pointerCoderFuncs{
+		size: func(p pointer, tagsize int, opts marshalOptions) int {
+			if m, found := (*p.WeakFields())[key]; found {
+				return sizeMessage(m.(proto.Message), tagsize, opts)
+			}
+			return 0
+		},
+		marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
+			if m, found := (*p.WeakFields())[key]; found {
+				return appendMessage(b, m.(proto.Message), wiretag, opts)
+			}
+			return b, nil
+		},
+		unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+			weak := p.WeakFields()
+			m, found := (*weak)[key]
+			if !found {
+				// Look the message type up at most once. The lookup error is
+				// dropped on purpose: a nil mt is what signals errUnknown below.
+				resolve.Do(func() {
+					mt, _ = preg.GlobalTypes.FindMessageByName(fd.Message().FullName())
+				})
+				if mt == nil {
+					return 0, errUnknown
+				}
+				m = mt.New().Interface().(piface.MessageV1)
+				// The map itself is allocated on first store.
+				if *weak == nil {
+					*weak = make(WeakFields)
+				}
+				(*weak)[key] = m
+			}
+			return consumeMessage(b, m.(proto.Message), wtyp, opts)
+		},
+		isInit: func(p pointer) error {
+			if m, found := (*p.WeakFields())[key]; found {
+				return proto.IsInitialized(m.(proto.Message))
+			}
+			return nil
+		},
+	}
+}
+
 func makeMessageFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
 	if mi := getMessageInfo(ft); mi != nil {
 		return pointerCoderFuncs{
diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go
index 92f38e9..d7235a2 100644
--- a/internal/impl/codec_message.go
+++ b/internal/impl/codec_message.go
@@ -65,15 +65,22 @@
 		} else {
 			wiretag = wire.EncodeTag(fd.Number(), wire.BytesType)
 		}
+		var fieldOffset offset
 		var funcs pointerCoderFuncs
-		if fd.ContainingOneof() != nil {
-			funcs = makeOneofFieldCoder(si, fd)
-		} else {
+		switch {
+		case fd.ContainingOneof() != nil:
+			fieldOffset = offsetOf(fs, mi.Exporter)
+			funcs = makeOneofFieldCoder(fd, si)
+		case fd.IsWeak():
+			fieldOffset = si.weakOffset
+			funcs = makeWeakMessageFieldCoder(fd)
+		default:
+			fieldOffset = offsetOf(fs, mi.Exporter)
 			funcs = fieldCoder(fd, ft)
 		}
 		cf := &coderFieldInfo{
 			num:     fd.Number(),
-			offset:  offsetOf(fs, mi.Exporter),
+			offset:  fieldOffset,
 			wiretag: wiretag,
 			tagsize: wire.SizeVarint(wiretag),
 			funcs:   funcs,
diff --git a/internal/testprotos/test/weak1/test_weak.pb.go b/internal/testprotos/test/weak1/test_weak.pb.go
index 1b80b8b..c62ba49 100644
--- a/internal/testprotos/test/weak1/test_weak.pb.go
+++ b/internal/testprotos/test/weak1/test_weak.pb.go
@@ -19,7 +19,7 @@
 	sizeCache     protoimpl.SizeCache
 	unknownFields protoimpl.UnknownFields
 
-	A *int32 `protobuf:"varint,1,opt,name=a" json:"a,omitempty"`
+	A *int32 `protobuf:"varint,1,req,name=a" json:"a,omitempty"`
 }
 
 func (x *WeakImportMessage1) Reset() {
@@ -64,7 +64,7 @@
 	0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x74, 0x65, 0x73, 0x74,
 	0x2e, 0x77, 0x65, 0x61, 0x6b, 0x22, 0x22, 0x0a, 0x12, 0x57, 0x65, 0x61, 0x6b, 0x49, 0x6d, 0x70,
 	0x6f, 0x72, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x31, 0x12, 0x0c, 0x0a, 0x01, 0x61,
-	0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
+	0x18, 0x01, 0x20, 0x02, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
 	0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70,
 	0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c,
 	0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x74, 0x65, 0x73, 0x74,
diff --git a/internal/testprotos/test/weak1/test_weak.proto b/internal/testprotos/test/weak1/test_weak.proto
index 5a4814d..113ed1b 100644
--- a/internal/testprotos/test/weak1/test_weak.proto
+++ b/internal/testprotos/test/weak1/test_weak.proto
@@ -9,5 +9,5 @@
 option go_package = "google.golang.org/protobuf/internal/testprotos/test/weak1";
 
 message WeakImportMessage1 {
-	optional int32 a = 1;
+	required int32 a = 1;
 }
diff --git a/internal/testprotos/test/weak2/test_weak.pb.go b/internal/testprotos/test/weak2/test_weak.pb.go
index a1d2e07..e04031c 100644
--- a/internal/testprotos/test/weak2/test_weak.pb.go
+++ b/internal/testprotos/test/weak2/test_weak.pb.go
@@ -19,7 +19,7 @@
 	sizeCache     protoimpl.SizeCache
 	unknownFields protoimpl.UnknownFields
 
-	A *int32 `protobuf:"varint,1,opt,name=a" json:"a,omitempty"`
+	A *int32 `protobuf:"varint,1,req,name=a" json:"a,omitempty"`
 }
 
 func (x *WeakImportMessage2) Reset() {
@@ -64,7 +64,7 @@
 	0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x74, 0x65, 0x73, 0x74,
 	0x2e, 0x77, 0x65, 0x61, 0x6b, 0x22, 0x22, 0x0a, 0x12, 0x57, 0x65, 0x61, 0x6b, 0x49, 0x6d, 0x70,
 	0x6f, 0x72, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x32, 0x12, 0x0c, 0x0a, 0x01, 0x61,
-	0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
+	0x18, 0x01, 0x20, 0x02, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
 	0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70,
 	0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c,
 	0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x74, 0x65, 0x73, 0x74,
diff --git a/internal/testprotos/test/weak2/test_weak.proto b/internal/testprotos/test/weak2/test_weak.proto
index eb4d0ff..61a03d6 100644
--- a/internal/testprotos/test/weak2/test_weak.proto
+++ b/internal/testprotos/test/weak2/test_weak.proto
@@ -9,5 +9,5 @@
 option go_package = "google.golang.org/protobuf/internal/testprotos/test/weak2";
 
 message WeakImportMessage2 {
-	optional int32 a = 1;
+	required int32 a = 1;
 }
diff --git a/proto/decode.go b/proto/decode.go
index ab62ff3..f64f887 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -8,6 +8,7 @@
 	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/internal/errors"
+	"google.golang.org/protobuf/internal/flags"
 	"google.golang.org/protobuf/internal/pragma"
 	"google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/reflect/protoregistry"
@@ -88,7 +89,7 @@
 			return wire.ParseError(tagLen)
 		}
 
-		// Parse the field value.
+		// Find the field descriptor for this field number.
 		fd := fields.ByNumber(num)
 		if fd == nil && md.ExtensionRanges().Has(num) {
 			extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
@@ -100,10 +101,18 @@
 			}
 		}
 		var err error
+		if fd == nil {
+			err = errUnknown
+		} else if flags.ProtoLegacy {
+			if fd.IsWeak() && fd.Message().IsPlaceholder() {
+				err = errUnknown // weak referent is not linked in
+			}
+		}
+
+		// Parse the field value.
 		var valLen int
 		switch {
-		case fd == nil:
-			err = errUnknown
+		case err != nil:
 		case fd.IsList():
 			valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
 		case fd.IsMap():
@@ -111,14 +120,15 @@
 		default:
 			valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
 		}
-		if err == errUnknown {
+		if err != nil {
+			if err != errUnknown {
+				return err
+			}
 			valLen = wire.ConsumeFieldValue(num, wtyp, b[tagLen:])
 			if valLen < 0 {
 				return wire.ParseError(valLen)
 			}
 			m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
-		} else if err != nil {
-			return err
 		}
 		b = b[tagLen+valLen:]
 	}
diff --git a/proto/decode_test.go b/proto/decode_test.go
index c696778..b6258da 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -21,6 +21,7 @@
 	legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
 	legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2.v0.0.0-20160225-2fc053c5"
 	testpb "google.golang.org/protobuf/internal/testprotos/test"
+	weakpb "google.golang.org/protobuf/internal/testprotos/test/weak1"
 	test3pb "google.golang.org/protobuf/internal/testprotos/test3"
 	"google.golang.org/protobuf/types/descriptorpb"
 )
@@ -1726,6 +1727,46 @@
 	},
 }
 
+// TestWeak unmarshals fields 1 and 2 into a TestWeak message, then checks the
+// WeakMessage1 value, that unknown bytes were kept, and the Size/Marshal length.
+func TestWeak(t *testing.T) {
+	if !flags.ProtoLegacy {
+		t.SkipNow()
+	}
+	m := new(testpb.TestWeak)
+	b := pack.Message{
+		pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
+			pack.Tag{1, pack.VarintType}, pack.Varint(1000),
+		}),
+		pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+			pack.Tag{1, pack.VarintType}, pack.Varint(2000),
+		}),
+	}.Marshal()
+	if err := proto.Unmarshal(b, m); err != nil {
+		// Fatal: the unchecked type assertion below must not run on a failed decode.
+		t.Fatalf("Unmarshal error: %v", err)
+	}
+
+	mw := m.GetWeakMessage1().(*weakpb.WeakImportMessage1)
+	if mw.GetA() != 1000 {
+		t.Errorf("m.WeakMessage1.a = %d, want %d", mw.GetA(), 1000)
+	}
+	if len(m.ProtoReflect().GetUnknown()) == 0 {
+		t.Errorf("m has no unknown fields, expected at least something")
+	}
+	if n := proto.Size(m); n != len(b) {
+		t.Errorf("Size() = %d, want %d", n, len(b))
+	}
+
+	b2, err := proto.Marshal(m)
+	if err != nil {
+		t.Fatalf("Marshal error: %v", err)
+	}
+	if len(b2) != len(b) {
+		t.Errorf("len(Marshal) = %d, want %d", len(b2), len(b))
+	}
+}
+
 func build(m proto.Message, opts ...buildOpt) proto.Message {
 	for _, opt := range opts {
 		opt(m)
diff --git a/proto/isinit_test.go b/proto/isinit_test.go
index d935d3a..6a3a8c9 100644
--- a/proto/isinit_test.go
+++ b/proto/isinit_test.go
@@ -9,51 +9,80 @@
 	"strings"
 	"testing"
 
+	"google.golang.org/protobuf/internal/flags"
 	"google.golang.org/protobuf/proto"
 
 	testpb "google.golang.org/protobuf/internal/testprotos/test"
+	weakpb "google.golang.org/protobuf/internal/testprotos/test/weak1"
 )
 
 func TestIsInitializedErrors(t *testing.T) {
-	for _, test := range []struct {
+	type test struct {
 		m    proto.Message
 		want string
-	}{
-		{
-			&testpb.TestRequired{},
-			`goproto.proto.test.TestRequired.required_field`,
+		skip bool
+	}
+	tests := []test{{
+		m:    &testpb.TestRequired{},
+		want: `goproto.proto.test.TestRequired.required_field`,
+	}, {
+		m: &testpb.TestRequiredForeign{
+			OptionalMessage: &testpb.TestRequired{},
 		},
-		{
-			&testpb.TestRequiredForeign{
-				OptionalMessage: &testpb.TestRequired{},
+		want: `goproto.proto.test.TestRequired.required_field`,
+	}, {
+		m: &testpb.TestRequiredForeign{
+			RepeatedMessage: []*testpb.TestRequired{
+				{RequiredField: proto.Int32(1)},
+				{},
 			},
-			`goproto.proto.test.TestRequired.required_field`,
 		},
-		{
-			&testpb.TestRequiredForeign{
-				RepeatedMessage: []*testpb.TestRequired{
-					{RequiredField: proto.Int32(1)},
-					{},
-				},
+		want: `goproto.proto.test.TestRequired.required_field`,
+	}, {
+		m: &testpb.TestRequiredForeign{
+			MapMessage: map[int32]*testpb.TestRequired{
+				1: {},
 			},
-			`goproto.proto.test.TestRequired.required_field`,
 		},
-		{
-			&testpb.TestRequiredForeign{
-				MapMessage: map[int32]*testpb.TestRequired{
-					1: {},
-				},
-			},
-			`goproto.proto.test.TestRequired.required_field`,
-		},
-	} {
-		err := proto.IsInitialized(test.m)
-		got := "<nil>"
-		if err != nil {
-			got = fmt.Sprintf("%q", err)
-		}
-		if !strings.Contains(got, test.want) {
-			t.Errorf("IsInitialized(m):\n got: %v\nwant contains: %v\nMessage:\n%v", got, test.want, marshalText(test.m))
-		}
+		want: `goproto.proto.test.TestRequired.required_field`,
+	}, {
+		m:    &testpb.TestWeak{},
+		want: `<nil>`,
+		skip: !flags.ProtoLegacy,
+	}, {
+		m: func() proto.Message {
+			m := &testpb.TestWeak{}
+			m.SetWeakMessage1(&weakpb.WeakImportMessage1{})
+			return m
+		}(),
+		want: `goproto.proto.test.weak.WeakImportMessage1.a`,
+		skip: !flags.ProtoLegacy,
+	}, {
+		m: func() proto.Message {
+			m := &testpb.TestWeak{}
+			m.SetWeakMessage1(&weakpb.WeakImportMessage1{
+				A: proto.Int32(1),
+			})
+			return m
+		}(),
+		want: `<nil>`,
+		skip: !flags.ProtoLegacy,
+	}}
+
+	for _, tt := range tests {
+		t.Run("", func(t *testing.T) {
+			if tt.skip {
+				t.SkipNow()
+			}
+
+			err := proto.IsInitialized(tt.m)
+			got := "<nil>"
+			if err != nil {
+				got = fmt.Sprintf("%q", err)
+			}
+			if !strings.Contains(got, tt.want) {
+				t.Errorf("IsInitialized(m):\n got: %v\nwant contains: %v\nMessage:\n%v", got, tt.want, marshalText(tt.m))
+			}
+		})
 	}
 }