internal/impl: fix for lazy decoding of groups

Bit of a weird case in why this wasn't caught by tests: When validating
extension groups, we were validating an empty buffer rather than the
message content. For groups, this validation always fails due to a lack
of a group end tag. We'd then skip lazy decoding of the extension field
and proceed with eager decoding, which would behave correctly.

Change extension validation to report an error immediately on an invalid
result from the validator, which is both safe (assuming we trust the
validator) and would have caught this problem (by failing to decode the
extension field, rather than silently failing to eager decoding).

Change-Id: Id6c2d21fb687062bc74d9eb93760a1c24a6fe883
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/217767
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 3155bc5..0fffad3 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -196,10 +196,17 @@
 	}
 	if flags.LazyUnmarshalExtensions {
 		if opts.IsDefault() && x.canLazy(xt) {
-			if out, ok := skipExtension(b, xi, num, wtyp, opts); ok && out.initialized {
-				x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
-				exts[int32(num)] = x
-				return out, nil
+			out, valid := skipExtension(b, xi, num, wtyp, opts)
+			switch valid {
+			case ValidationValid:
+				if out.initialized {
+					x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
+					exts[int32(num)] = x
+					return out, nil
+				}
+			case ValidationInvalid:
+				return out, errors.New("invalid wire format")
+			case ValidationUnknown:
 			}
 		}
 	}
@@ -222,31 +229,30 @@
 	return out, nil
 }
 
-func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, ok bool) {
+func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
 	if xi.validation.mi == nil {
-		return out, false
+		return out, ValidationUnknown
 	}
 	xi.validation.mi.init()
-	var v []byte
 	switch xi.validation.typ {
 	case validationTypeMessage:
 		if wtyp != wire.BytesType {
-			return out, false
+			return out, ValidationUnknown
 		}
 		v, n := wire.ConsumeBytes(b)
 		if n < 0 {
-			return out, false
+			return out, ValidationUnknown
 		}
 		out, st := xi.validation.mi.validate(v, 0, opts)
 		out.n = n
-		return out, st == ValidationValid
+		return out, st
 	case validationTypeGroup:
 		if wtyp != wire.StartGroupType {
-			return out, false
+			return out, ValidationUnknown
 		}
-		out, st := xi.validation.mi.validate(v, num, opts)
-		return out, st == ValidationValid
+		out, st := xi.validation.mi.validate(b, num, opts)
+		return out, st
 	default:
-		return out, false
+		return out, ValidationUnknown
 	}
 }