internal/impl: reduce redundant MessageInfo initializations in validator

name                            old time/op    new time/op    delta
EmptyMessage/Wire/Validate-12     4.58ns ± 0%    4.29ns ± 1%   -6.22%  (p=0.000 n=7+8)
RepeatedInt32/Wire/Validate-12     702ns ± 1%     518ns ± 0%  -26.12%  (p=0.001 n=7+7)
Required/Wire/Validate-12         30.6ns ± 6%    22.1ns ± 0%  -27.81%  (p=0.000 n=8+7)

Change-Id: I0d1db8583aa0bf4468bc385c213eb6adff001297
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/216627
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/validate.go b/internal/impl/validate.go
index bb6d47d..4ae1d03 100644
--- a/internal/impl/validate.go
+++ b/internal/impl/validate.go
@@ -220,6 +220,7 @@
 }
 
 func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (result ValidationStatus) {
+	mi.init()
 	type validationState struct {
 		typ              validationType
 		keyType, valType validationType
@@ -244,7 +245,6 @@
 	for len(states) > 0 {
 		st := &states[len(states)-1]
 		if st.mi != nil {
-			st.mi.init()
 			if flags.ProtoLegacy && st.mi.isMessageSet {
 				return ValidationUnknown
 			}
@@ -434,10 +434,13 @@
 				v := b[:size]
 				b = b[size:]
 				switch vi.typ {
-				case validationTypeMessage, validationTypeMap:
-					if vi.mi == nil && vi.typ == validationTypeMessage {
+				case validationTypeMessage:
+					if vi.mi == nil {
 						return ValidationUnknown
 					}
+					vi.mi.init()
+					fallthrough
+				case validationTypeMap:
 					states = append(states, validationState{
 						typ:     vi.typ,
 						keyType: vi.keyType,
@@ -487,6 +490,7 @@
 					if vi.mi == nil {
 						return ValidationUnknown
 					}
+					vi.mi.init()
 					states = append(states, validationState{
 						typ:      validationTypeGroup,
 						mi:       vi.mi,