diff --git a/internal/benchmarks/micro/micro_test.go b/internal/benchmarks/micro/micro_test.go
index 097a326..b36dc78 100644
--- a/internal/benchmarks/micro/micro_test.go
+++ b/internal/benchmarks/micro/micro_test.go
@@ -55,7 +55,9 @@
 				Resolver: protoregistry.GlobalTypes,
 			}
 			for pb.Next() {
-				if got, want := impl.Validate([]byte{}, mt, opts), impl.ValidationValidInitialized; got != want {
+				_, got := impl.Validate([]byte{}, mt, opts)
+				want := impl.ValidationValid
+				if got != want {
 					b.Fatalf("Validate = %v, want %v", got, want)
 				}
 			}
@@ -106,7 +108,9 @@
 				Resolver: protoregistry.GlobalTypes,
 			}
 			for pb.Next() {
-				if got, want := impl.Validate(w, mt, opts), impl.ValidationValidInitialized; got != want {
+				_, got := impl.Validate(w, mt, opts)
+				want := impl.ValidationValid
+				if got != want {
 					b.Fatalf("Validate = %v, want %v", got, want)
 				}
 			}
@@ -167,7 +171,9 @@
 				Resolver: protoregistry.GlobalTypes,
 			}
 			for pb.Next() {
-				if got, want := impl.Validate(w, mt, opts), impl.ValidationValidInitialized; got != want {
+				_, got := impl.Validate(w, mt, opts)
+				want := impl.ValidationValid
+				if got != want {
 					b.Fatalf("Validate = %v, want %v", got, want)
 				}
 			}
diff --git a/internal/fuzz/wirefuzz/fuzz.go b/internal/fuzz/wirefuzz/fuzz.go
index 7ca46ba..28aed57 100644
--- a/internal/fuzz/wirefuzz/fuzz.go
+++ b/internal/fuzz/wirefuzz/fuzz.go
@@ -19,7 +19,7 @@
 // Fuzz is a fuzzer for proto.Marshal and proto.Unmarshal.
 func Fuzz(data []byte) (score int) {
 	m1 := &fuzzpb.Fuzz{}
-	valid := impl.Validate(data, m1.ProtoReflect().Type(), piface.UnmarshalOptions{
+	vout, valid := impl.Validate(data, m1.ProtoReflect().Type(), piface.UnmarshalOptions{
 		Resolver: protoregistry.GlobalTypes,
 	})
 	if err := (proto.UnmarshalOptions{
@@ -33,21 +33,14 @@
 		}
 		return 0
 	}
-	if proto.IsInitialized(m1) == nil {
-		switch valid {
-		case impl.ValidationUnknown:
-		case impl.ValidationValidInitialized:
-		case impl.ValidationValidMaybeUninitalized:
-		default:
-			panic("unmarshal ok with validation status: " + valid.String())
-		}
-	} else {
-		switch valid {
-		case impl.ValidationUnknown:
-		case impl.ValidationValidMaybeUninitalized:
-		default:
-			panic("partial unmarshal ok with validation status: " + valid.String())
-		}
+	switch valid {
+	case impl.ValidationUnknown:
+	case impl.ValidationValid:
+	default:
+		panic("unmarshal ok with validation status: " + valid.String())
+	}
+	if proto.IsInitialized(m1) != nil && vout.Initialized {
+		panic("validation reports partial message is initialized")
 	}
 	data1, err := proto.MarshalOptions{
 		AllowPartial: true,
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 290fc41..3155bc5 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -196,11 +196,9 @@
 	}
 	if flags.LazyUnmarshalExtensions {
 		if opts.IsDefault() && x.canLazy(xt) {
-			if n, ok := skipExtension(b, xi, num, wtyp, opts); ok {
-				x.appendLazyBytes(xt, xi, num, wtyp, b[:n])
+			if out, ok := skipExtension(b, xi, num, wtyp, opts); ok && out.initialized {
+				x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
 				exts[int32(num)] = x
-				out.n = n
-				out.initialized = true
 				return out, nil
 			}
 		}
@@ -224,35 +222,31 @@
 	return out, nil
 }
 
-func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (n int, ok bool) {
+func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, ok bool) {
 	if xi.validation.mi == nil {
-		return 0, false
+		return out, false
 	}
 	xi.validation.mi.init()
 	var v []byte
 	switch xi.validation.typ {
 	case validationTypeMessage:
 		if wtyp != wire.BytesType {
-			return 0, false
+			return out, false
 		}
-		v, n = wire.ConsumeBytes(b)
+		v, n := wire.ConsumeBytes(b)
 		if n < 0 {
-			return 0, false
+			return out, false
 		}
+		out, st := xi.validation.mi.validate(v, 0, opts)
+		out.n = n
+		return out, st == ValidationValid
 	case validationTypeGroup:
 		if wtyp != wire.StartGroupType {
-			return 0, false
+			return out, false
 		}
-		v, n = wire.ConsumeGroup(num, b)
-		if n < 0 {
-			return 0, false
-		}
+		out, st := xi.validation.mi.validate(v, num, opts)
+		return out, st == ValidationValid
 	default:
-		return 0, false
+		return out, false
 	}
-	if xi.validation.mi.validate(v, 0, opts) != ValidationValidInitialized {
-		return 0, false
-	}
-	return n, true
-
 }
diff --git a/internal/impl/validate.go b/internal/impl/validate.go
index eab8ec0..0c32026 100644
--- a/internal/impl/validate.go
+++ b/internal/impl/validate.go
@@ -33,16 +33,8 @@
 	// ValidationInvalid indicates that unmarshaling the message will fail.
 	ValidationInvalid
 
-	// ValidationValidInitialized indicates that unmarshaling the message will succeed
-	// and IsInitialized on the result will report success.
-	ValidationValidInitialized
-
-	// ValidationValidMaybeUninitalized indicates unmarshaling the message will succeed,
-	// but the output of IsInitialized on the result is unknown.
-	//
-	// This status may be returned for an initialized message when a message value
-	// is split across multiple fields.
-	ValidationValidMaybeUninitalized
+	// ValidationValid indicates that unmarshaling the message will succeed.
+	ValidationValid
 )
 
 func (v ValidationStatus) String() string {
@@ -51,10 +43,8 @@
 		return "ValidationUnknown"
 	case ValidationInvalid:
 		return "ValidationInvalid"
-	case ValidationValidInitialized:
-		return "ValidationValidInitialized"
-	case ValidationValidMaybeUninitalized:
-		return "ValidationValidMaybeUninitalized"
+	case ValidationValid:
+		return "ValidationValid"
 	default:
 		return fmt.Sprintf("ValidationStatus(%d)", int(v))
 	}
@@ -64,12 +54,14 @@
 // of the message type.
 //
 // This function is exposed for testing.
-func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) ValidationStatus {
+func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) (out piface.UnmarshalOutput, _ ValidationStatus) {
 	mi, ok := mt.(*MessageInfo)
 	if !ok {
-		return ValidationUnknown
+		return out, ValidationUnknown
 	}
-	return mi.validate(b, 0, unmarshalOptions(opts))
+	o, st := mi.validate(b, 0, unmarshalOptions(opts))
+	out.Initialized = o.initialized
+	return out, st
 }
 
 type validationInfo struct {
@@ -219,7 +211,7 @@
 	return vi
 }
 
-func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (result ValidationStatus) {
+func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
 	mi.init()
 	type validationState struct {
 		typ              validationType
@@ -241,12 +233,13 @@
 		states[0].endGroup = groupTag
 	}
 	initialized := true
+	start := len(b)
 State:
 	for len(states) > 0 {
 		st := &states[len(states)-1]
 		if st.mi != nil {
 			if flags.ProtoLegacy && st.mi.isMessageSet {
-				return ValidationUnknown
+				return out, ValidationUnknown
 			}
 		}
 		for len(b) > 0 {
@@ -262,13 +255,13 @@
 				var n int
 				tag, n = wire.ConsumeVarint(b)
 				if n < 0 {
-					return ValidationInvalid
+					return out, ValidationInvalid
 				}
 				b = b[n:]
 			}
 			var num wire.Number
 			if n := tag >> 3; n < uint64(wire.MinValidNumber) || n > uint64(wire.MaxValidNumber) {
-				return ValidationInvalid
+				return out, ValidationInvalid
 			} else {
 				num = wire.Number(n)
 			}
@@ -278,7 +271,7 @@
 				if st.endGroup == num {
 					goto PopState
 				}
-				return ValidationInvalid
+				return out, ValidationInvalid
 			}
 			var vi validationInfo
 			switch st.typ {
@@ -317,7 +310,7 @@
 						case preg.NotFound:
 							vi.typ = validationTypeBytes
 						default:
-							return ValidationUnknown
+							return out, ValidationUnknown
 						}
 					}
 					break
@@ -332,7 +325,7 @@
 				// determine if the resolver is frozen.
 				xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
 				if err != nil && err != preg.NotFound {
-					return ValidationUnknown
+					return out, ValidationUnknown
 				}
 				if err == nil {
 					vi = getExtensionFieldInfo(xt).validation
@@ -383,7 +376,7 @@
 					case b[9] < 0x80 && b[9] < 2:
 						b = b[10:]
 					default:
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 				} else {
 					switch {
@@ -408,7 +401,7 @@
 					case len(b) > 9 && b[9] < 2:
 						b = b[10:]
 					default:
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 				}
 				continue State
@@ -424,19 +417,19 @@
 					var n int
 					size, n = wire.ConsumeVarint(b)
 					if n < 0 {
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 					b = b[n:]
 				}
 				if size > uint64(len(b)) {
-					return ValidationInvalid
+					return out, ValidationInvalid
 				}
 				v := b[:size]
 				b = b[size:]
 				switch vi.typ {
 				case validationTypeMessage:
 					if vi.mi == nil {
-						return ValidationUnknown
+						return out, ValidationUnknown
 					}
 					vi.mi.init()
 					fallthrough
@@ -455,40 +448,40 @@
 					for len(v) > 0 {
 						_, n := wire.ConsumeVarint(v)
 						if n < 0 {
-							return ValidationInvalid
+							return out, ValidationInvalid
 						}
 						v = v[n:]
 					}
 				case validationTypeRepeatedFixed32:
 					// Packed field.
 					if len(v)%4 != 0 {
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 				case validationTypeRepeatedFixed64:
 					// Packed field.
 					if len(v)%8 != 0 {
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 				case validationTypeUTF8String:
 					if !utf8.Valid(v) {
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 				}
 			case wire.Fixed32Type:
 				if len(b) < 4 {
-					return ValidationInvalid
+					return out, ValidationInvalid
 				}
 				b = b[4:]
 			case wire.Fixed64Type:
 				if len(b) < 8 {
-					return ValidationInvalid
+					return out, ValidationInvalid
 				}
 				b = b[8:]
 			case wire.StartGroupType:
 				switch vi.typ {
 				case validationTypeGroup:
 					if vi.mi == nil {
-						return ValidationUnknown
+						return out, ValidationUnknown
 					}
 					vi.mi.init()
 					states = append(states, validationState{
@@ -500,19 +493,19 @@
 				default:
 					n := wire.ConsumeFieldValue(num, wtyp, b)
 					if n < 0 {
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 					b = b[n:]
 				}
 			default:
-				return ValidationInvalid
+				return out, ValidationInvalid
 			}
 		}
 		if st.endGroup != 0 {
-			return ValidationInvalid
+			return out, ValidationInvalid
 		}
 		if len(b) != 0 {
-			return ValidationInvalid
+			return out, ValidationInvalid
 		}
 		b = st.tail
 	PopState:
@@ -535,8 +528,9 @@
 		}
 		states = states[:len(states)-1]
 	}
-	if !initialized {
-		return ValidationValidMaybeUninitalized
+	out.n = start - len(b)
+	if initialized {
+		out.initialized = true
 	}
-	return ValidationValidInitialized
+	return out, ValidationValid
 }
diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go
index d980fd1..3847c21 100644
--- a/proto/testmessages_test.go
+++ b/proto/testmessages_test.go
@@ -28,6 +28,7 @@
 	checkFastInit    bool
 	unmarshalOptions proto.UnmarshalOptions
 	validationStatus impl.ValidationStatus
+	nocheckValidInit bool
 }
 
 func makeMessages(in protobuild.Message, messages ...proto.Message) []proto.Message {
@@ -1045,8 +1046,9 @@
 		}.Marshal(),
 	},
 	{
-		desc:          "required field in optional message set (split across multiple tags)",
-		checkFastInit: false, // fast init checks don't handle split messages
+		desc:             "required field in optional message set (split across multiple tags)",
+		checkFastInit:    false, // fast init checks don't handle split messages
+		nocheckValidInit: true,  // validation doesn't either
 		decodeTo: makeMessages(protobuild.Message{
 			"optional_message": protobuild.Message{
 				"required_field": 1,
@@ -1058,7 +1060,6 @@
 				pack.Tag{1, pack.VarintType}, pack.Varint(1),
 			}),
 		}.Marshal(),
-		validationStatus: impl.ValidationValidMaybeUninitalized,
 	},
 	{
 		desc:          "required field in repeated message unset",
diff --git a/proto/validate_test.go b/proto/validate_test.go
index bd4b811..490115a 100644
--- a/proto/validate_test.go
+++ b/proto/validate_test.go
@@ -23,16 +23,18 @@
 		for _, m := range test.decodeTo {
 			t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
 				mt := m.ProtoReflect().Type()
-				want := impl.ValidationValidInitialized
+				want := impl.ValidationValid
 				if test.validationStatus != 0 {
 					want = test.validationStatus
-				} else if test.partial {
-					want = impl.ValidationValidMaybeUninitalized
 				}
 				var opts piface.UnmarshalOptions
 				opts.Resolver = protoregistry.GlobalTypes
-				if got, want := impl.Validate(test.wire, mt, opts), want; got != want {
-					t.Errorf("Validate(%x) = %v, want %v", test.wire, got, want)
+				out, status := impl.Validate(test.wire, mt, opts)
+				if status != want {
+					t.Errorf("Validate(%x) = %v, want %v", test.wire, status, want)
+				}
+				if got, want := out.Initialized, !test.partial; got != want && !test.nocheckValidInit && status == impl.ValidationValid {
+					t.Errorf("Validate(%x): initialized = %v, want %v", test.wire, got, want)
 				}
 			})
 		}
@@ -46,7 +48,9 @@
 				mt := m.ProtoReflect().Type()
 				var opts piface.UnmarshalOptions
 				opts.Resolver = protoregistry.GlobalTypes
-				if got, want := impl.Validate(test.wire, mt, opts), impl.ValidationInvalid; got != want {
+				_, got := impl.Validate(test.wire, mt, opts)
+				want := impl.ValidationInvalid
+				if got != want {
 					t.Errorf("Validate(%x) = %v, want %v", test.wire, got, want)
 				}
 			})
