proto: add IsInitialized
Move all checks for required fields into a proto.IsInitialized function.
Initial testing makes me confident that we can provide a fast-path
implementation of IsInitialized which will perform more than
acceptably. (In the degenerate-but-common case where a message
transitively contains no required fields, this check can be nearly
zero cost.)
Unifying checks into a single function provides consistent behavior
between the wire, text, and json codecs.
Performing the check after decoding eliminates the wire decoder bug
where a message split across multiple occurrences in the wire data (and
merged during decoding) is incorrectly seen as missing required fields.
Performing the check after decoding also provides consistent and
arguably more correct behavior when the target message was partially
prepopulated.
Change-Id: I9478b7bebb263af00c0d9f66a1f26e31ff553522
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/170787
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/proto/isinit.go b/proto/isinit.go
new file mode 100644
index 0000000..33dfb64
--- /dev/null
+++ b/proto/isinit.go
@@ -0,0 +1,94 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proto
+
+import (
+ "bytes"
+ "fmt"
+
+ "github.com/golang/protobuf/v2/internal/errors"
+ pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+// IsInitialized returns an error if any required fields in m are not set,
+// and nil otherwise.
+//
+// When protoMethods reports a specialized IsInitialized implementation for
+// m, that implementation decides the result. Otherwise m is checked
+// reflectively through m.ProtoReflect(), starting with an empty field path.
+func IsInitialized(m Message) error {
+	methods := protoMethods(m)
+	if methods == nil || methods.IsInitialized == nil {
+		// No fast path available: fall back to the reflective walk.
+		return isInitialized(m.ProtoReflect(), nil)
+	}
+	// TODO: Do we need a way to disable the fast path here?
+	//
+	// TODO: Should detailed information about missing
+	// fields always be provided by the slow-but-informative
+	// reflective implementation?
+	return methods.IsInitialized(m)
+}
+
+// isInitialized returns an error if any required field in m, or in any
+// message nested within m, is not set.
+//
+// Nested messages are reached through singular message/group fields,
+// repeated message/group fields, and map fields whose values are messages.
+// Fields of any other kind are skipped. The search stops at the first
+// missing required field found.
+//
+// stack holds the path components that lead from the top-level message to m:
+// field names, list indexes, map keys, and the "[", "].", and "." separators
+// between them. When a required field is missing, its name is appended and
+// the concatenated path becomes the error text.
+func isInitialized(m pref.Message, stack []interface{}) error {
+	md := m.Type()
+	known := m.KnownFields()
+	fields := md.Fields()
+
+	// Check m's own required fields before descending into sub-messages.
+	for i, nums := 0, md.RequiredNumbers(); i < nums.Len(); i++ {
+		num := nums.Get(i)
+		if !known.Has(num) {
+			stack = append(stack, fields.ByNumber(num).Name())
+			return newRequiredNotSetError(stack)
+		}
+	}
+
+	var err error
+	known.Range(func(num pref.FieldNumber, v pref.Value) bool {
+		field := fields.ByNumber(num)
+		if field == nil {
+			// Not declared in m's descriptor; look it up among the
+			// extension types known to m.
+			field = known.ExtensionTypes().ByNumber(num)
+		}
+		if field == nil {
+			panic(fmt.Errorf("no descriptor for field %d in %q", num, md.FullName()))
+		}
+		// Only fields that can contain a message need to be visited:
+		// messages, groups, and maps with a message or group value.
+		ft := field.MessageType()
+		if ft == nil {
+			return true
+		}
+		if field.IsMap() && ft.Fields().ByNumber(2).MessageType() == nil {
+			// Field 2 of a map entry is the value; a map with
+			// non-message values has nothing to recurse into.
+			return true
+		}
+		// Each append below builds a path local to this field or element.
+		// Siblings may reuse the same backing array. That is safe because a
+		// missing field's path is rendered into the error text immediately,
+		// before any sibling runs.
+		stack := append(stack, field.Name())
+		switch {
+		case field.IsMap():
+			v.Map().Range(func(key pref.MapKey, v pref.Value) bool {
+				stack := append(stack, "[", key, "].")
+				err = isInitialized(v.Message(), stack)
+				return err == nil
+			})
+		case field.Cardinality() == pref.Repeated:
+			list := v.List()
+			for i := 0; i < list.Len() && err == nil; i++ {
+				stack := append(stack, "[", i, "].")
+				err = isInitialized(list.Get(i).Message(), stack)
+			}
+		default:
+			stack := append(stack, ".")
+			err = isInitialized(v.Message(), stack)
+		}
+		// Returning false ends the Range at the first error.
+		return err == nil
+	})
+	return err
+}
+
+// newRequiredNotSetError joins the elements of stack, each printed with its
+// default fmt formatting and no separators, into a single field path. It
+// returns that path recorded as a required-not-set error.
+func newRequiredNotSetError(stack []interface{}) error {
+	var path bytes.Buffer
+	for i := range stack {
+		path.WriteString(fmt.Sprint(stack[i]))
+	}
+	var nerr errors.NonFatal
+	nerr.AppendRequiredNotSet(path.String())
+	return nerr.E
+}