internal/impl: lazy extension decoding

Historically, extensions have been placed in the unknown fields section
of the unmarshaled message and decoded lazily on demand. The current
unmarshal implementation decodes extensions eagerly at unmarshal time,
permitting errors to be immediately reported and correctly detecting
unset required fields in extension values.

Add support for validated lazy extension decoding, where the extension
value is fully validated at initial unmarshal time but the fully
unmarshaled message is only created lazily.

Make this behavior conditional on the protolegacy flag for now.

Change-Id: I9d742496a4bd4dafea83fca8619cd6e8d7e65bc3
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/216764
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/flags/flags.go b/internal/flags/flags.go
index 89b0aff..ab62da2 100644
--- a/internal/flags/flags.go
+++ b/internal/flags/flags.go
@@ -15,3 +15,10 @@
 // WARNING: The compatibility agreement covers nothing provided by this flag.
 // As such, functionality may suddenly be removed or changed at our discretion.
 const ProtoLegacy = protoLegacy
+
+// LazyUnmarshalExtension specifies whether to lazily unmarshal extensions.
+//
+// Lazy extension unmarshaling validates the contents of message-valued
+// extension fields at unmarshal time, but defers creating the message
+// structure until the extension is first accessed.
+const LazyUnmarshalExtensions = ProtoLegacy
diff --git a/internal/impl/codec_extension.go b/internal/impl/codec_extension.go
index 8c80377..141a3cd 100644
--- a/internal/impl/codec_extension.go
+++ b/internal/impl/codec_extension.go
@@ -9,6 +9,7 @@
 	"sync/atomic"
 
 	"google.golang.org/protobuf/internal/encoding/wire"
+	"google.golang.org/protobuf/internal/errors"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 )
 
@@ -68,6 +69,15 @@
 	return e
 }
 
+type lazyExtensionValue struct {
+	atomicOnce uint32 // atomically set if value is valid
+	mu         sync.Mutex
+	xi         *extensionFieldInfo
+	value      pref.Value
+	b          []byte
+	fn         func() pref.Value
+}
+
 type ExtensionField struct {
 	typ pref.ExtensionType
 
@@ -77,25 +87,91 @@
 	lazy  *lazyExtensionValue
 }
 
+func (f *ExtensionField) appendLazyBytes(xt pref.ExtensionType, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, b []byte) {
+	if f.lazy == nil {
+		f.lazy = &lazyExtensionValue{xi: xi}
+	}
+	f.typ = xt
+	f.lazy.xi = xi
+	f.lazy.b = wire.AppendTag(f.lazy.b, num, wtyp)
+	f.lazy.b = append(f.lazy.b, b...)
+}
+
+func (f *ExtensionField) canLazy(xt pref.ExtensionType) bool {
+	if f.typ == nil {
+		return true
+	}
+	if f.typ == xt && f.lazy != nil && atomic.LoadUint32(&f.lazy.atomicOnce) == 0 {
+		return true
+	}
+	return false
+}
+
+func (f *ExtensionField) lazyInit() {
+	f.lazy.mu.Lock()
+	defer f.lazy.mu.Unlock()
+	if f.lazy.xi != nil {
+		b := f.lazy.b
+		val := f.typ.New()
+		for len(b) > 0 {
+			var tag uint64
+			if b[0] < 0x80 {
+				tag = uint64(b[0])
+				b = b[1:]
+			} else if len(b) >= 2 && b[1] < 128 {
+				tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
+				b = b[2:]
+			} else {
+				var n int
+				tag, n = wire.ConsumeVarint(b)
+				if n < 0 {
+					panic(errors.New("bad tag in lazy extension decoding"))
+				}
+				b = b[n:]
+			}
+			num := wire.Number(tag >> 3)
+			wtyp := wire.Type(tag & 7)
+			var out unmarshalOutput
+			var err error
+			val, out, err = f.lazy.xi.funcs.unmarshal(b, val, num, wtyp, unmarshalOptions{}) // TODO: options
+			if err != nil {
+				panic(errors.New("decode failure in lazy extension decoding: %v", err))
+			}
+			b = b[out.n:]
+		}
+		f.lazy.value = val
+	} else {
+		f.lazy.value = f.lazy.fn()
+	}
+	f.lazy.xi = nil
+	f.lazy.fn = nil
+	f.lazy.b = nil
+	atomic.StoreUint32(&f.lazy.atomicOnce, 1)
+}
+
 // Set sets the type and value of the extension field.
 // This must not be called concurrently.
 func (f *ExtensionField) Set(t pref.ExtensionType, v pref.Value) {
 	f.typ = t
 	f.value = v
+	f.lazy = nil
 }
 
 // SetLazy sets the type and a value that is to be lazily evaluated upon first use.
 // This must not be called concurrently.
 func (f *ExtensionField) SetLazy(t pref.ExtensionType, fn func() pref.Value) {
 	f.typ = t
-	f.lazy = &lazyExtensionValue{value: fn}
+	f.lazy = &lazyExtensionValue{fn: fn}
 }
 
 // Value returns the value of the extension field.
 // This may be called concurrently.
 func (f *ExtensionField) Value() pref.Value {
 	if f.lazy != nil {
-		return f.lazy.GetValue()
+		if atomic.LoadUint32(&f.lazy.atomicOnce) == 0 {
+			f.lazyInit()
+		}
+		return f.lazy.value
 	}
 	return f.value
 }
@@ -144,25 +220,7 @@
 
 // Deprecated: Do not use.
 func (f *ExtensionField) SetLazyValue(fn func() interface{}) {
-	f.lazy = &lazyExtensionValue{value: func() pref.Value {
+	f.SetLazy(f.typ, func() pref.Value {
 		return f.typ.ValueOf(fn())
-	}}
-}
-
-type lazyExtensionValue struct {
-	once  uint32      // atomically set if value is valid
-	mu    sync.Mutex  // protects value
-	value interface{} // either a pref.Value itself or a func() pref.ValueOf
-}
-
-func (v *lazyExtensionValue) GetValue() pref.Value {
-	if atomic.LoadUint32(&v.once) == 0 {
-		v.mu.Lock()
-		if f, ok := v.value.(func() pref.Value); ok {
-			v.value = f()
-		}
-		atomic.StoreUint32(&v.once, 1)
-		v.mu.Unlock()
-	}
-	return v.value.(pref.Value)
+	})
 }
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index cbc21b3..48f7ca5 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -29,6 +29,12 @@
 
 func (o unmarshalOptions) DiscardUnknown() bool { return o.Flags&piface.UnmarshalDiscardUnknown != 0 }
 
+func (o unmarshalOptions) IsDefault() bool {
+	// The UnmarshalDefaultResolver flag indicates that we're using the default resolver.
+	// No other flag bit should be set.
+	return o.Flags == piface.UnmarshalDefaultResolver
+}
+
 type unmarshalOutput struct {
 	n           int // number of bytes consumed
 	initialized bool
@@ -185,6 +191,17 @@
 	if xi.funcs.unmarshal == nil {
 		return out, errUnknown
 	}
+	if flags.LazyUnmarshalExtensions {
+		if opts.IsDefault() && x.canLazy(xt) {
+			if n, ok := skipExtension(b, xi, num, wtyp, opts); ok {
+				x.appendLazyBytes(xt, xi, num, wtyp, b[:n])
+				exts[int32(num)] = x
+				out.n = n
+				out.initialized = true
+				return out, nil
+			}
+		}
+	}
 	ival := x.Value()
 	if !ival.IsValid() && xi.unmarshalNeedsValue {
 		// Create a new message, list, or map value to fill in.
@@ -200,3 +217,36 @@
 	exts[int32(num)] = x
 	return out, nil
 }
+
+func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (n int, ok bool) {
+	if xi.validation.mi == nil {
+		return 0, false
+	}
+	xi.validation.mi.init()
+	var v []byte
+	switch xi.validation.typ {
+	case validationTypeMessage:
+		if wtyp != wire.BytesType {
+			return 0, false
+		}
+		v, n = wire.ConsumeBytes(b)
+		if n < 0 {
+			return 0, false
+		}
+	case validationTypeGroup:
+		if wtyp != wire.StartGroupType {
+			return 0, false
+		}
+		v, n = wire.ConsumeGroup(num, b)
+		if n < 0 {
+			return 0, false
+		}
+	default:
+		return 0, false
+	}
+	if xi.validation.mi.validate(v, 0, opts) != ValidationValidInitialized {
+		return 0, false
+	}
+	return n, true
+
+}
diff --git a/proto/decode.go b/proto/decode.go
index f3cd997..e1177b8 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -64,8 +64,12 @@
 }
 
 func (o UnmarshalOptions) unmarshal(b []byte, message Message) (out protoiface.UnmarshalOutput, err error) {
+	defaultResolver := false
 	if o.Resolver == nil {
 		o.Resolver = protoregistry.GlobalTypes
+		defaultResolver = true
+	} else if o.Resolver == protoregistry.GlobalTypes {
+		defaultResolver = true
 	}
 	if !o.Merge {
 		Reset(message)
@@ -83,6 +87,9 @@
 		if o.DiscardUnknown {
 			opts.Flags |= protoiface.UnmarshalDiscardUnknown
 		}
+		if defaultResolver {
+			opts.Flags |= protoiface.UnmarshalDefaultResolver
+		}
 		out, err = methods.Unmarshal(m, protoiface.UnmarshalInput{
 			Buf: b,
 		}, opts)
diff --git a/runtime/protoiface/methods.go b/runtime/protoiface/methods.go
index 54e7fb3..a9aa9bd 100644
--- a/runtime/protoiface/methods.go
+++ b/runtime/protoiface/methods.go
@@ -112,4 +112,8 @@
 
 const (
 	UnmarshalDiscardUnknown UnmarshalFlags = 1 << iota
+
+	// UnmarshalDefaultResolver indicates that the provided extension type
+	// resolver is protoregistry.GlobalTypes.
+	UnmarshalDefaultResolver
 )