proto: add IsInitialized

Move all checks for required fields into a proto.IsInitialized function.

Initial testing makes me confident that we can provide a fast-path
implementation of IsInitialized that performs well. (In the
degenerate-but-common case where a message transitively contains no
required fields, this check can be nearly zero-cost.)

Unifying checks into a single function provides consistent behavior
between the wire, text, and json codecs.

Performing the check after decoding eliminates the wire decoder bug
where a message split across multiple tags (and merged during
decoding) is incorrectly reported as missing required fields.

Performing the check after decoding also provides consistent and
arguably more correct behavior when the target message was partially
prepopulated.

Change-Id: I9478b7bebb263af00c0d9f66a1f26e31ff553522
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/170787
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/proto/decode.go b/proto/decode.go
index 2b871c4..11fed46 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -42,10 +42,18 @@
 // Unmarshal parses the wire-format message in b and places the result in m.
 func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
 	// TODO: Reset m?
-	if err := o.unmarshalMessageFast(b, m); err != errInternalNoFast {
+	err := o.unmarshalMessageFast(b, m) // on errInternalNoFast, fall back to the reflective decoder
+	if err == errInternalNoFast {
+		err = o.unmarshalMessage(b, m.ProtoReflect())
+	}
+	var nerr errors.NonFatal
+	if !nerr.Merge(err) { // errors not accepted as non-fatal are returned as-is
 		return err
 	}
-	return o.unmarshalMessage(b, m.ProtoReflect())
+	if !o.AllowPartial {
+		nerr.Merge(IsInitialized(m)) // checked once the whole message is decoded; NOTE(review): Merge's result is ignored, so a fatal error here would be dropped
+	}
+	return nerr.E
 }
 
 func (o UnmarshalOptions) unmarshalMessageFast(b []byte, m Message) error {
@@ -100,9 +108,6 @@
 		}
 		b = b[tagLen+valLen:]
 	}
-	if !o.AllowPartial {
-		checkRequiredFields(m, &nerr)
-	}
 	return nerr.E
 }
 
@@ -204,9 +209,6 @@
 	if !haveVal {
 		switch valField.Kind() {
 		case protoreflect.GroupKind, protoreflect.MessageKind:
-			if !o.AllowPartial {
-				checkRequiredFields(val.Message(), &nerr)
-			}
 		default:
 			val = valField.Default()
 		}
diff --git a/proto/decode_test.go b/proto/decode_test.go
index 0a94b8a..dda4db1 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -944,23 +944,20 @@
 			}),
 		}.Marshal(),
 	},
-	// TODO: Handle this case.
-	/*
-		{
-			desc: "required field in optional message set (split across multiple tags)",
-			decodeTo: []proto.Message{&testpb.TestRequiredForeign{
-				OptionalMessage: &testpb.TestRequired{
-					RequiredField: scalar.Int32(1),
-				},
-			}},
-			wire: pack.Message{
-				pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
-				pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
-					pack.Tag{1, pack.VarintType}, pack.Varint(1),
-				}),
-			}.Marshal(),
-		},
-	*/
+	{
+		desc: "required field in optional message set (split across multiple tags)",
+		decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+			OptionalMessage: &testpb.TestRequired{
+				RequiredField: scalar.Int32(1),
+			},
+		}},
+		wire: pack.Message{
+			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(1),
+			}),
+		}.Marshal(),
+	},
 	{
 		desc:    "required field in repeated message unset",
 		partial: true,
diff --git a/proto/encode.go b/proto/encode.go
index adf3de4..fc3ce92 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -69,10 +69,18 @@
 // MarshalAppend appends the wire-format encoding of m to b,
 // returning the result.
 func (o MarshalOptions) MarshalAppend(b []byte, m Message) ([]byte, error) {
-	if b, err := o.marshalMessageFast(b, m); err != errInternalNoFast {
+	out, err := o.marshalMessageFast(b, m) // result kept in out: a declined fast path need not return the caller's b
+	if err == errInternalNoFast {
+		out, err = o.marshalMessage(b, m.ProtoReflect()) // reflective fallback appends to the caller's original b
+	}
+	var nerr errors.NonFatal
+	if !nerr.Merge(err) { // errors not accepted as non-fatal are returned as-is
-		return b, err
+		return out, err
 	}
-	return o.marshalMessage(b, m.ProtoReflect())
+	if !o.AllowPartial {
+		nerr.Merge(IsInitialized(m)) // required fields checked after encoding; NOTE(review): Merge's result is ignored
+	}
+	return out, nerr.E
 }
 
 func (o MarshalOptions) marshalMessageFast(b []byte, m Message) ([]byte, error) {
@@ -129,9 +137,6 @@
 		b = append(b, raw...)
 		return true
 	})
-	if !o.AllowPartial {
-		checkRequiredFields(m, &nerr)
-	}
 	return b, nerr.E
 }
 
diff --git a/proto/isinit.go b/proto/isinit.go
new file mode 100644
index 0000000..33dfb64
--- /dev/null
+++ b/proto/isinit.go
@@ -0,0 +1,94 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proto
+
+import (
+	"bytes"
+	"fmt"
+
+	"github.com/golang/protobuf/v2/internal/errors"
+	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+// IsInitialized returns an error if any required fields in m, or in any message reachable from m, are not set.
+func IsInitialized(m Message) error {
+	if methods := protoMethods(m); methods != nil && methods.IsInitialized != nil {
+		// TODO: Do we need a way to disable the fast path here?
+		//
+		// TODO: Should detailed information about missing
+		// fields always be provided by the slow-but-informative
+		// reflective implementation?
+		return methods.IsInitialized(m)
+	}
+	return isInitialized(m.ProtoReflect(), nil) // reflective slow path, starting from an empty field path
+}
+
+// isInitialized returns an error if any required field in m, or in any message reachable from m, is not set; stack holds the field path leading to m and is used only to build the error text.
+func isInitialized(m pref.Message, stack []interface{}) error {
+	md := m.Type()
+	known := m.KnownFields()
+	fields := md.Fields()
+	for i, nums := 0, md.RequiredNumbers(); i < nums.Len(); i++ {
+		num := nums.Get(i)
+		if !known.Has(num) {
+			stack = append(stack, fields.ByNumber(num).Name()) // the reported path ends at the missing field
+			return newRequiredNotSetError(stack)
+		}
+	}
+	var err error // first failure from a nested message; returning false below stops the Range
+	known.Range(func(num pref.FieldNumber, v pref.Value) bool {
+		field := fields.ByNumber(num)
+		if field == nil {
+			field = known.ExtensionTypes().ByNumber(num)
+		}
+		if field == nil {
+			panic(fmt.Errorf("no descriptor for field %d in %q", num, md.FullName()))
+		}
+		// Look for fields containing a message: Messages, groups, and maps
+		// with a message or group value.
+		ft := field.MessageType()
+		if ft == nil {
+			return true
+		}
+		if field.IsMap() {
+			if ft.Fields().ByNumber(2).MessageType() == nil {
+				return true
+			}
+		}
+		// Recurse into the field
+		stack := append(stack, field.Name()) // reusing the backing array is safe: a missing field's path is formatted immediately
+		switch {
+		case field.IsMap():
+			v.Map().Range(func(key pref.MapKey, v pref.Value) bool {
+				stack := append(stack, "[", key, "].") // e.g. "map_message[1]."
+				err = isInitialized(v.Message(), stack)
+				return err == nil
+			})
+		case field.Cardinality() == pref.Repeated:
+			for i, list := 0, v.List(); i < list.Len(); i++ {
+				stack := append(stack, "[", i, "].") // e.g. "repeated_message[1]."
+				err = isInitialized(list.Get(i).Message(), stack)
+				if err != nil {
+					break
+				}
+			}
+		default:
+			stack := append(stack, ".") // e.g. "optional_message."
+			err = isInitialized(v.Message(), stack)
+		}
+		return err == nil
+	})
+	return err
+}
+
+func newRequiredNotSetError(stack []interface{}) error { // joins stack into a field path and wraps it in a required-not-set error
+	var buf bytes.Buffer
+	for _, s := range stack {
+		fmt.Fprint(&buf, s) // elements carry their own separators (".", "[", "].")
+	}
+	var nerr errors.NonFatal
+	nerr.AppendRequiredNotSet(buf.String()) // recorded via the non-fatal error accumulator
+	return nerr.E
+}
diff --git a/proto/isinit_test.go b/proto/isinit_test.go
new file mode 100644
index 0000000..951b95f
--- /dev/null
+++ b/proto/isinit_test.go
@@ -0,0 +1,60 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proto_test
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/golang/protobuf/v2/internal/scalar"
+	"github.com/golang/protobuf/v2/proto"
+
+	testpb "github.com/golang/protobuf/v2/internal/testprotos/test"
+)
+
+func TestIsInitializedErrors(t *testing.T) { // checks the field path reported for a missing required field
+	for _, test := range []struct {
+		m    proto.Message
+		want string
+	}{
+		{
+			&testpb.TestRequired{},
+			`proto: required field required_field not set`,
+		},
+		{
+			&testpb.TestRequiredForeign{
+				OptionalMessage: &testpb.TestRequired{},
+			},
+			`proto: required field optional_message.required_field not set`,
+		},
+		{
+			&testpb.TestRequiredForeign{
+				RepeatedMessage: []*testpb.TestRequired{
+					{RequiredField: scalar.Int32(1)},
+					{},
+				},
+			},
+			`proto: required field repeated_message[1].required_field not set`,
+		},
+		{
+			&testpb.TestRequiredForeign{
+				MapMessage: map[int32]*testpb.TestRequired{
+					1: {},
+				},
+			},
+			`proto: required field map_message[1].required_field not set`,
+		},
+	} {
+		err := proto.IsInitialized(test.m)
+		got := "<nil>"
+		if err != nil {
+			got = fmt.Sprintf("%q", err) // quoted, so a nil error (got == "<nil>") can never equal a quoted want
+		}
+		want := fmt.Sprintf("%q", test.want)
+		if got != want {
+			t.Errorf("IsInitialized(m):\n got: %v\nwant: %v\nMessage:\n%v", got, want, marshalText(test.m))
+		}
+	}
+}
diff --git a/proto/proto.go b/proto/proto.go
index f1fe5aa..96dac1c 100644
--- a/proto/proto.go
+++ b/proto/proto.go
@@ -22,14 +22,3 @@
 	}
 	return nil
 }
-
-func checkRequiredFields(m protoreflect.Message, nerr *errors.NonFatal) {
-	req := m.Type().RequiredNumbers()
-	knownFields := m.KnownFields()
-	for i, reqLen := 0, req.Len(); i < reqLen; i++ {
-		num := req.Get(i)
-		if !knownFields.Has(num) {
-			nerr.AppendRequiredNotSet(string(m.Type().Fields().ByNumber(num).FullName()))
-		}
-	}
-}