internal/impl: refactor validation a bit

Return the size of the field read from the validator, permitting us to
avoid an extra parse when skipping over groups.

Return an UnmarshalOutput from the validator, since it already combines
two of the validator outputs: bytes read and initialization status.

Remove initialization status from the ValidationStatus enum, since it's
covered by the UnmarshalOutput.

Change-Id: I3e684c45d15aa1992d8dc3bde0f608880d34a94b
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/217763
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/benchmarks/micro/micro_test.go b/internal/benchmarks/micro/micro_test.go
index 097a326..b36dc78 100644
--- a/internal/benchmarks/micro/micro_test.go
+++ b/internal/benchmarks/micro/micro_test.go
@@ -55,7 +55,9 @@
 				Resolver: protoregistry.GlobalTypes,
 			}
 			for pb.Next() {
-				if got, want := impl.Validate([]byte{}, mt, opts), impl.ValidationValidInitialized; got != want {
+				_, got := impl.Validate([]byte{}, mt, opts)
+				want := impl.ValidationValid
+				if got != want {
 					b.Fatalf("Validate = %v, want %v", got, want)
 				}
 			}
@@ -106,7 +108,9 @@
 				Resolver: protoregistry.GlobalTypes,
 			}
 			for pb.Next() {
-				if got, want := impl.Validate(w, mt, opts), impl.ValidationValidInitialized; got != want {
+				_, got := impl.Validate(w, mt, opts)
+				want := impl.ValidationValid
+				if got != want {
 					b.Fatalf("Validate = %v, want %v", got, want)
 				}
 			}
@@ -167,7 +171,9 @@
 				Resolver: protoregistry.GlobalTypes,
 			}
 			for pb.Next() {
-				if got, want := impl.Validate(w, mt, opts), impl.ValidationValidInitialized; got != want {
+				_, got := impl.Validate(w, mt, opts)
+				want := impl.ValidationValid
+				if got != want {
 					b.Fatalf("Validate = %v, want %v", got, want)
 				}
 			}
diff --git a/internal/fuzz/wirefuzz/fuzz.go b/internal/fuzz/wirefuzz/fuzz.go
index 7ca46ba..28aed57 100644
--- a/internal/fuzz/wirefuzz/fuzz.go
+++ b/internal/fuzz/wirefuzz/fuzz.go
@@ -19,7 +19,7 @@
 // Fuzz is a fuzzer for proto.Marshal and proto.Unmarshal.
 func Fuzz(data []byte) (score int) {
 	m1 := &fuzzpb.Fuzz{}
-	valid := impl.Validate(data, m1.ProtoReflect().Type(), piface.UnmarshalOptions{
+	vout, valid := impl.Validate(data, m1.ProtoReflect().Type(), piface.UnmarshalOptions{
 		Resolver: protoregistry.GlobalTypes,
 	})
 	if err := (proto.UnmarshalOptions{
@@ -33,21 +33,14 @@
 		}
 		return 0
 	}
-	if proto.IsInitialized(m1) == nil {
-		switch valid {
-		case impl.ValidationUnknown:
-		case impl.ValidationValidInitialized:
-		case impl.ValidationValidMaybeUninitalized:
-		default:
-			panic("unmarshal ok with validation status: " + valid.String())
-		}
-	} else {
-		switch valid {
-		case impl.ValidationUnknown:
-		case impl.ValidationValidMaybeUninitalized:
-		default:
-			panic("partial unmarshal ok with validation status: " + valid.String())
-		}
+	switch valid {
+	case impl.ValidationUnknown:
+	case impl.ValidationValid:
+	default:
+		panic("unmarshal ok with validation status: " + valid.String())
+	}
+	if proto.IsInitialized(m1) != nil && vout.Initialized {
+		panic("validation reports partial message is initialized")
 	}
 	data1, err := proto.MarshalOptions{
 		AllowPartial: true,
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 290fc41..3155bc5 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -196,11 +196,9 @@
 	}
 	if flags.LazyUnmarshalExtensions {
 		if opts.IsDefault() && x.canLazy(xt) {
-			if n, ok := skipExtension(b, xi, num, wtyp, opts); ok {
-				x.appendLazyBytes(xt, xi, num, wtyp, b[:n])
+			if out, ok := skipExtension(b, xi, num, wtyp, opts); ok && out.initialized {
+				x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
 				exts[int32(num)] = x
-				out.n = n
-				out.initialized = true
 				return out, nil
 			}
 		}
@@ -224,35 +222,35 @@
 	return out, nil
 }
 
-func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (n int, ok bool) {
+// skipExtension validates the extension field value at the start of b without
+// unmarshaling it. It reports the validator's output (bytes consumed and
+// initialization status) and whether the value was validated successfully.
+func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, ok bool) {
 	if xi.validation.mi == nil {
-		return 0, false
+		return out, false
 	}
 	xi.validation.mi.init()
-	var v []byte
 	switch xi.validation.typ {
 	case validationTypeMessage:
 		if wtyp != wire.BytesType {
-			return 0, false
+			return out, false
 		}
-		v, n = wire.ConsumeBytes(b)
+		v, n := wire.ConsumeBytes(b)
 		if n < 0 {
-			return 0, false
+			return out, false
 		}
+		out, st := xi.validation.mi.validate(v, 0, opts)
+		out.n = n
+		return out, st == ValidationValid
 	case validationTypeGroup:
 		if wtyp != wire.StartGroupType {
-			return 0, false
+			return out, false
 		}
-		v, n = wire.ConsumeGroup(num, b)
-		if n < 0 {
-			return 0, false
-		}
+		// Validate b itself: the group's contents run up to its end-group
+		// tag, which validate locates when given the group number.
+		out, st := xi.validation.mi.validate(b, num, opts)
+		return out, st == ValidationValid
 	default:
-		return 0, false
+		return out, false
 	}
-	if xi.validation.mi.validate(v, 0, opts) != ValidationValidInitialized {
-		return 0, false
-	}
-	return n, true
-
 }
diff --git a/internal/impl/validate.go b/internal/impl/validate.go
index eab8ec0..0c32026 100644
--- a/internal/impl/validate.go
+++ b/internal/impl/validate.go
@@ -33,16 +33,8 @@
 	// ValidationInvalid indicates that unmarshaling the message will fail.
 	ValidationInvalid
 
-	// ValidationValidInitialized indicates that unmarshaling the message will succeed
-	// and IsInitialized on the result will report success.
-	ValidationValidInitialized
-
-	// ValidationValidMaybeUninitalized indicates unmarshaling the message will succeed,
-	// but the output of IsInitialized on the result is unknown.
-	//
-	// This status may be returned for an initialized message when a message value
-	// is split across multiple fields.
-	ValidationValidMaybeUninitalized
+	// ValidationValid indicates that unmarshaling the message will succeed.
+	ValidationValid
 )
 
 func (v ValidationStatus) String() string {
@@ -51,10 +43,8 @@
 		return "ValidationUnknown"
 	case ValidationInvalid:
 		return "ValidationInvalid"
-	case ValidationValidInitialized:
-		return "ValidationValidInitialized"
-	case ValidationValidMaybeUninitalized:
-		return "ValidationValidMaybeUninitalized"
+	case ValidationValid:
+		return "ValidationValid"
 	default:
 		return fmt.Sprintf("ValidationStatus(%d)", int(v))
 	}
@@ -64,12 +54,14 @@
 // of the message type.
 //
 // This function is exposed for testing.
-func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) ValidationStatus {
+func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) (out piface.UnmarshalOutput, _ ValidationStatus) {
 	mi, ok := mt.(*MessageInfo)
 	if !ok {
-		return ValidationUnknown
+		return out, ValidationUnknown
 	}
-	return mi.validate(b, 0, unmarshalOptions(opts))
+	o, st := mi.validate(b, 0, unmarshalOptions(opts))
+	out.Initialized = o.initialized
+	return out, st
 }
 
 type validationInfo struct {
@@ -219,7 +211,7 @@
 	return vi
 }
 
-func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (result ValidationStatus) {
+func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
 	mi.init()
 	type validationState struct {
 		typ              validationType
@@ -241,12 +233,13 @@
 		states[0].endGroup = groupTag
 	}
 	initialized := true
+	start := len(b)
 State:
 	for len(states) > 0 {
 		st := &states[len(states)-1]
 		if st.mi != nil {
 			if flags.ProtoLegacy && st.mi.isMessageSet {
-				return ValidationUnknown
+				return out, ValidationUnknown
 			}
 		}
 		for len(b) > 0 {
@@ -262,13 +255,13 @@
 				var n int
 				tag, n = wire.ConsumeVarint(b)
 				if n < 0 {
-					return ValidationInvalid
+					return out, ValidationInvalid
 				}
 				b = b[n:]
 			}
 			var num wire.Number
 			if n := tag >> 3; n < uint64(wire.MinValidNumber) || n > uint64(wire.MaxValidNumber) {
-				return ValidationInvalid
+				return out, ValidationInvalid
 			} else {
 				num = wire.Number(n)
 			}
@@ -278,7 +271,7 @@
 				if st.endGroup == num {
 					goto PopState
 				}
-				return ValidationInvalid
+				return out, ValidationInvalid
 			}
 			var vi validationInfo
 			switch st.typ {
@@ -317,7 +310,7 @@
 						case preg.NotFound:
 							vi.typ = validationTypeBytes
 						default:
-							return ValidationUnknown
+							return out, ValidationUnknown
 						}
 					}
 					break
@@ -332,7 +325,7 @@
 				// determine if the resolver is frozen.
 				xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
 				if err != nil && err != preg.NotFound {
-					return ValidationUnknown
+					return out, ValidationUnknown
 				}
 				if err == nil {
 					vi = getExtensionFieldInfo(xt).validation
@@ -383,7 +376,7 @@
 					case b[9] < 0x80 && b[9] < 2:
 						b = b[10:]
 					default:
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 				} else {
 					switch {
@@ -408,7 +401,7 @@
 					case len(b) > 9 && b[9] < 2:
 						b = b[10:]
 					default:
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 				}
 				continue State
@@ -424,19 +417,19 @@
 					var n int
 					size, n = wire.ConsumeVarint(b)
 					if n < 0 {
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 					b = b[n:]
 				}
 				if size > uint64(len(b)) {
-					return ValidationInvalid
+					return out, ValidationInvalid
 				}
 				v := b[:size]
 				b = b[size:]
 				switch vi.typ {
 				case validationTypeMessage:
 					if vi.mi == nil {
-						return ValidationUnknown
+						return out, ValidationUnknown
 					}
 					vi.mi.init()
 					fallthrough
@@ -455,40 +448,40 @@
 					for len(v) > 0 {
 						_, n := wire.ConsumeVarint(v)
 						if n < 0 {
-							return ValidationInvalid
+							return out, ValidationInvalid
 						}
 						v = v[n:]
 					}
 				case validationTypeRepeatedFixed32:
 					// Packed field.
 					if len(v)%4 != 0 {
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 				case validationTypeRepeatedFixed64:
 					// Packed field.
 					if len(v)%8 != 0 {
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 				case validationTypeUTF8String:
 					if !utf8.Valid(v) {
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 				}
 			case wire.Fixed32Type:
 				if len(b) < 4 {
-					return ValidationInvalid
+					return out, ValidationInvalid
 				}
 				b = b[4:]
 			case wire.Fixed64Type:
 				if len(b) < 8 {
-					return ValidationInvalid
+					return out, ValidationInvalid
 				}
 				b = b[8:]
 			case wire.StartGroupType:
 				switch vi.typ {
 				case validationTypeGroup:
 					if vi.mi == nil {
-						return ValidationUnknown
+						return out, ValidationUnknown
 					}
 					vi.mi.init()
 					states = append(states, validationState{
@@ -500,19 +493,19 @@
 				default:
 					n := wire.ConsumeFieldValue(num, wtyp, b)
 					if n < 0 {
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 					b = b[n:]
 				}
 			default:
-				return ValidationInvalid
+				return out, ValidationInvalid
 			}
 		}
 		if st.endGroup != 0 {
-			return ValidationInvalid
+			return out, ValidationInvalid
 		}
 		if len(b) != 0 {
-			return ValidationInvalid
+			return out, ValidationInvalid
 		}
 		b = st.tail
 	PopState:
@@ -535,8 +528,9 @@
 		}
 		states = states[:len(states)-1]
 	}
-	if !initialized {
-		return ValidationValidMaybeUninitalized
+	out.n = start - len(b)
+	if initialized {
+		out.initialized = true
 	}
-	return ValidationValidInitialized
+	return out, ValidationValid
 }
diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go
index d980fd1..3847c21 100644
--- a/proto/testmessages_test.go
+++ b/proto/testmessages_test.go
@@ -28,6 +28,7 @@
 	checkFastInit    bool
 	unmarshalOptions proto.UnmarshalOptions
 	validationStatus impl.ValidationStatus
+	nocheckValidInit bool
 }
 
 func makeMessages(in protobuild.Message, messages ...proto.Message) []proto.Message {
@@ -1045,8 +1046,9 @@
 		}.Marshal(),
 	},
 	{
-		desc:          "required field in optional message set (split across multiple tags)",
-		checkFastInit: false, // fast init checks don't handle split messages
+		desc:             "required field in optional message set (split across multiple tags)",
+		checkFastInit:    false, // fast init checks don't handle split messages
+		nocheckValidInit: true,  // validation doesn't either
 		decodeTo: makeMessages(protobuild.Message{
 			"optional_message": protobuild.Message{
 				"required_field": 1,
@@ -1058,7 +1060,6 @@
 				pack.Tag{1, pack.VarintType}, pack.Varint(1),
 			}),
 		}.Marshal(),
-		validationStatus: impl.ValidationValidMaybeUninitalized,
 	},
 	{
 		desc:          "required field in repeated message unset",
diff --git a/proto/validate_test.go b/proto/validate_test.go
index bd4b811..490115a 100644
--- a/proto/validate_test.go
+++ b/proto/validate_test.go
@@ -23,16 +23,18 @@
 		for _, m := range test.decodeTo {
 			t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
 				mt := m.ProtoReflect().Type()
-				want := impl.ValidationValidInitialized
+				want := impl.ValidationValid
 				if test.validationStatus != 0 {
 					want = test.validationStatus
-				} else if test.partial {
-					want = impl.ValidationValidMaybeUninitalized
 				}
 				var opts piface.UnmarshalOptions
 				opts.Resolver = protoregistry.GlobalTypes
-				if got, want := impl.Validate(test.wire, mt, opts), want; got != want {
-					t.Errorf("Validate(%x) = %v, want %v", test.wire, got, want)
+				out, status := impl.Validate(test.wire, mt, opts)
+				if status != want {
+					t.Errorf("Validate(%x) = %v, want %v", test.wire, status, want)
+				}
+				if got, want := out.Initialized, !test.partial; got != want && !test.nocheckValidInit && status == impl.ValidationValid {
+					t.Errorf("Validate(%x): initialized = %v, want %v", test.wire, got, want)
 				}
 			})
 		}
@@ -46,7 +48,9 @@
 				mt := m.ProtoReflect().Type()
 				var opts piface.UnmarshalOptions
 				opts.Resolver = protoregistry.GlobalTypes
-				if got, want := impl.Validate(test.wire, mt, opts), impl.ValidationInvalid; got != want {
+				_, got := impl.Validate(test.wire, mt, opts)
+				want := impl.ValidationInvalid
+				if got != want {
 					t.Errorf("Validate(%x) = %v, want %v", test.wire, got, want)
 				}
 			})