all: use v2 Message interface for weak fields

Cleanup the generated logic by having the implementation be backed
by protoimpl rather that directly generated.

Weak fields are a deprecated feature of protobufs and
have entirely be superceded by extensions.
Unfortunately, there are still some usages of it.

Change-Id: Ie1a4b7da253e2ccf5e56627775d9b2fb4090d59a
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/229717
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/cmd/protoc-gen-go/internal_gengo/main.go b/cmd/protoc-gen-go/internal_gengo/main.go
index bd73826..d8f86c0 100644
--- a/cmd/protoc-gen-go/internal_gengo/main.go
+++ b/cmd/protoc-gen-go/internal_gengo/main.go
@@ -48,6 +48,7 @@
 // patched to support unique build environments that impose restrictions
 // on the dependencies of generated source code.
 var (
+	protoPackage        goImportPath = protogen.GoImportPath("google.golang.org/protobuf/proto")
 	protoifacePackage   goImportPath = protogen.GoImportPath("google.golang.org/protobuf/runtime/protoiface")
 	protoimplPackage    goImportPath = protogen.GoImportPath("google.golang.org/protobuf/runtime/protoimpl")
 	protoreflectPackage goImportPath = protogen.GoImportPath("google.golang.org/protobuf/reflect/protoreflect")
@@ -573,15 +574,15 @@
 			field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated())
 		switch {
 		case field.Desc.IsWeak():
-			g.P(leadingComments, "func (x *", m.GoIdent, ") Get", field.GoName, "() ", protoifacePackage.Ident("MessageV1"), "{")
+			g.P(leadingComments, "func (x *", m.GoIdent, ") Get", field.GoName, "() ", protoPackage.Ident("Message"), "{")
+			g.P("var w ", protoimplPackage.Ident("WeakFields"))
 			g.P("if x != nil {")
-			g.P("v := x.", genname.WeakFields, "[", field.Desc.Number(), "]")
-			g.P("_ = x.", genname.WeakFieldPrefix+field.GoName) // for field-tracking
-			g.P("if v != nil {")
-			g.P("return v")
+			g.P("w = x.", genname.WeakFields)
+			if m.isTracked {
+				g.P("_ = x.", genname.WeakFieldPrefix+field.GoName)
+			}
 			g.P("}")
-			g.P("}")
-			g.P("return ", protoimplPackage.Ident("X"), ".WeakNil(", strconv.Quote(string(field.Message.Desc.FullName())), ")")
+			g.P("return ", protoimplPackage.Ident("X"), ".GetWeak(w, ", field.Desc.Number(), ", ", strconv.Quote(string(field.Message.Desc.FullName())), ")")
 			g.P("}")
 		case field.Oneof != nil && !field.Oneof.Desc.IsSynthetic():
 			g.P(leadingComments, "func (x *", m.GoIdent, ") Get", field.GoName, "() ", goType, " {")
@@ -621,16 +622,15 @@
 		g.Annotate(m.GoIdent.GoName+".Set"+field.GoName, field.Location)
 		leadingComments := appendDeprecationSuffix("",
 			field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated())
-		g.P(leadingComments, "func (x *", m.GoIdent, ") Set", field.GoName, "(v ", protoifacePackage.Ident("MessageV1"), ") {")
-		g.P("if x.", genname.WeakFields, " == nil {")
-		g.P("x.", genname.WeakFields, " = make(", protoimplPackage.Ident("WeakFields"), ")")
+		g.P(leadingComments, "func (x *", m.GoIdent, ") Set", field.GoName, "(v ", protoPackage.Ident("Message"), ") {")
+		g.P("var w *", protoimplPackage.Ident("WeakFields"))
+		g.P("if x != nil {")
+		g.P("w = &x.", genname.WeakFields)
+		if m.isTracked {
+			g.P("_ = x.", genname.WeakFieldPrefix+field.GoName)
+		}
 		g.P("}")
-		g.P("if v == nil {")
-		g.P("delete(x.", genname.WeakFields, ", ", field.Desc.Number(), ")")
-		g.P("} else {")
-		g.P("x.", genname.WeakFields, "[", field.Desc.Number(), "] = v")
-		g.P("x.", genname.WeakFieldPrefix+field.GoName, " = struct{}{}") // for field-tracking
-		g.P("}")
+		g.P(protoimplPackage.Ident("X"), ".SetWeak(w, ", field.Desc.Number(), ", ", strconv.Quote(string(field.Message.Desc.FullName())), ", v)")
 		g.P("}")
 		g.P()
 	}
diff --git a/internal/impl/legacy_export.go b/internal/impl/legacy_export.go
index 94f4572..c3d741c 100644
--- a/internal/impl/legacy_export.go
+++ b/internal/impl/legacy_export.go
@@ -7,14 +7,12 @@
 import (
 	"encoding/binary"
 	"encoding/json"
-	"fmt"
 	"hash/crc32"
 	"math"
 	"reflect"
 
 	"google.golang.org/protobuf/internal/errors"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
-	"google.golang.org/protobuf/reflect/protoregistry"
 	piface "google.golang.org/protobuf/runtime/protoiface"
 )
 
@@ -92,13 +90,3 @@
 	out = append(out, gzipFooter[:]...)
 	return out
 }
-
-// WeakNil returns a typed nil pointer to a concrete message.
-// It panics if the message is not linked into the binary.
-func (Export) WeakNil(s pref.FullName) piface.MessageV1 {
-	mt, err := protoregistry.GlobalTypes.FindMessageByName(s)
-	if err != nil {
-		panic(fmt.Sprintf("weak message %v is not linked in", s))
-	}
-	return mt.Zero().Interface().(piface.MessageV1)
-}
diff --git a/internal/impl/message.go b/internal/impl/message.go
index c1d8902..7dd994b 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -15,7 +15,6 @@
 	"google.golang.org/protobuf/internal/genname"
 	"google.golang.org/protobuf/reflect/protoreflect"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
-	piface "google.golang.org/protobuf/runtime/protoiface"
 )
 
 // MessageInfo provides protobuf related functionality for a given Go type
@@ -109,7 +108,7 @@
 
 type (
 	SizeCache       = int32
-	WeakFields      = map[int32]piface.MessageV1
+	WeakFields      = map[int32]protoreflect.ProtoMessage
 	UnknownFields   = []byte
 	ExtensionFields = map[int32]ExtensionField
 )
diff --git a/internal/impl/weak.go b/internal/impl/weak.go
index 575c988..009cbef 100644
--- a/internal/impl/weak.go
+++ b/internal/impl/weak.go
@@ -5,9 +5,10 @@
 package impl
 
 import (
-	"reflect"
+	"fmt"
 
 	pref "google.golang.org/protobuf/reflect/protoreflect"
+	"google.golang.org/protobuf/reflect/protoregistry"
 )
 
 // weakFields adds methods to the exported WeakFields type for internal use.
@@ -16,31 +17,58 @@
 // defined directly on it.
 type weakFields WeakFields
 
-func (w *weakFields) get(num pref.FieldNumber) (_ pref.ProtoMessage, ok bool) {
-	if *w == nil {
-		return nil, false
-	}
-	m, ok := (*w)[int32(num)]
-	if !ok {
-		return nil, false
-	}
-	// As a legacy quirk, consider a typed nil to be unset.
-	//
-	// TODO: Consider fixing the generated set methods to clear the field
-	// when provided with a typed nil.
-	if v := reflect.ValueOf(m); v.Kind() == reflect.Ptr && v.IsNil() {
-		return nil, false
-	}
-	return Export{}.ProtoMessageV2Of(m), true
+func (w weakFields) get(num pref.FieldNumber) (pref.ProtoMessage, bool) {
+	m, ok := w[int32(num)]
+	return m, ok
 }
 
 func (w *weakFields) set(num pref.FieldNumber, m pref.ProtoMessage) {
 	if *w == nil {
 		*w = make(weakFields)
 	}
-	(*w)[int32(num)] = Export{}.ProtoMessageV1Of(m)
+	(*w)[int32(num)] = m
 }
 
 func (w *weakFields) clear(num pref.FieldNumber) {
 	delete(*w, int32(num))
 }
+
+func (Export) HasWeak(w WeakFields, num pref.FieldNumber) bool {
+	_, ok := w[int32(num)]
+	return ok
+}
+
+func (Export) ClearWeak(w *WeakFields, num pref.FieldNumber) {
+	delete(*w, int32(num))
+}
+
+func (Export) GetWeak(w WeakFields, num pref.FieldNumber, name pref.FullName) pref.ProtoMessage {
+	if m, ok := w[int32(num)]; ok {
+		return m
+	}
+	mt, _ := protoregistry.GlobalTypes.FindMessageByName(name)
+	if mt == nil {
+		panic(fmt.Sprintf("message %v for weak field is not linked in", name))
+	}
+	return mt.Zero().Interface()
+}
+
+func (Export) SetWeak(w *WeakFields, num pref.FieldNumber, name pref.FullName, m pref.ProtoMessage) {
+	if m != nil {
+		mt, _ := protoregistry.GlobalTypes.FindMessageByName(name)
+		if mt == nil {
+			panic(fmt.Sprintf("message %v for weak field is not linked in", name))
+		}
+		if mt != m.ProtoReflect().Type() {
+			panic(fmt.Sprintf("invalid message type for weak field: got %T, want %T", m, mt.Zero().Interface()))
+		}
+	}
+	if m == nil || !m.ProtoReflect().IsValid() {
+		delete(*w, int32(num))
+		return
+	}
+	if *w == nil {
+		*w = make(weakFields)
+	}
+	(*w)[int32(num)] = m
+}
diff --git a/internal/testprotos/fieldtrack/fieldtrack.pb.go b/internal/testprotos/fieldtrack/fieldtrack.pb.go
index 57ad96e..1e2c806 100644
--- a/internal/testprotos/fieldtrack/fieldtrack.pb.go
+++ b/internal/testprotos/fieldtrack/fieldtrack.pb.go
@@ -10,8 +10,8 @@
 import (
 	_ "google.golang.org/protobuf/internal/testprotos/annotation"
 	test "google.golang.org/protobuf/internal/testprotos/test"
+	proto "google.golang.org/protobuf/proto"
 	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
-	protoiface "google.golang.org/protobuf/runtime/protoiface"
 	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
 	reflect "reflect"
 	sync "sync"
@@ -571,56 +571,46 @@
 
 //go:nointerface
 
-func (x *TestFieldTrack) GetWeakMessage1() protoiface.MessageV1 {
+func (x *TestFieldTrack) GetWeakMessage1() proto.Message {
+	var w protoimpl.WeakFields
 	if x != nil {
-		v := x.weakFields[100]
+		w = x.weakFields
 		_ = x.XXX_weak_WeakMessage1
-		if v != nil {
-			return v
-		}
 	}
-	return protoimpl.X.WeakNil("goproto.proto.test.weak.WeakImportMessage1")
+	return protoimpl.X.GetWeak(w, 100, "goproto.proto.test.weak.WeakImportMessage1")
 }
 
 //go:nointerface
 
-func (x *TestFieldTrack) GetWeakMessage2() protoiface.MessageV1 {
+func (x *TestFieldTrack) GetWeakMessage2() proto.Message {
+	var w protoimpl.WeakFields
 	if x != nil {
-		v := x.weakFields[101]
+		w = x.weakFields
 		_ = x.XXX_weak_WeakMessage2
-		if v != nil {
-			return v
-		}
 	}
-	return protoimpl.X.WeakNil("goproto.proto.test.weak.WeakImportMessage2")
+	return protoimpl.X.GetWeak(w, 101, "goproto.proto.test.weak.WeakImportMessage2")
 }
 
 //go:nointerface
 
-func (x *TestFieldTrack) SetWeakMessage1(v protoiface.MessageV1) {
-	if x.weakFields == nil {
-		x.weakFields = make(protoimpl.WeakFields)
+func (x *TestFieldTrack) SetWeakMessage1(v proto.Message) {
+	var w *protoimpl.WeakFields
+	if x != nil {
+		w = &x.weakFields
+		_ = x.XXX_weak_WeakMessage1
 	}
-	if v == nil {
-		delete(x.weakFields, 100)
-	} else {
-		x.weakFields[100] = v
-		x.XXX_weak_WeakMessage1 = struct{}{}
-	}
+	protoimpl.X.SetWeak(w, 100, "goproto.proto.test.weak.WeakImportMessage1", v)
 }
 
 //go:nointerface
 
-func (x *TestFieldTrack) SetWeakMessage2(v protoiface.MessageV1) {
-	if x.weakFields == nil {
-		x.weakFields = make(protoimpl.WeakFields)
+func (x *TestFieldTrack) SetWeakMessage2(v proto.Message) {
+	var w *protoimpl.WeakFields
+	if x != nil {
+		w = &x.weakFields
+		_ = x.XXX_weak_WeakMessage2
 	}
-	if v == nil {
-		delete(x.weakFields, 101)
-	} else {
-		x.weakFields[101] = v
-		x.XXX_weak_WeakMessage2 = struct{}{}
-	}
+	protoimpl.X.SetWeak(w, 101, "goproto.proto.test.weak.WeakImportMessage2", v)
 }
 
 var File_internal_testprotos_fieldtrack_fieldtrack_proto protoreflect.FileDescriptor
diff --git a/internal/testprotos/test/test.pb.go b/internal/testprotos/test/test.pb.go
index 1ec2a7d..2541f7f 100644
--- a/internal/testprotos/test/test.pb.go
+++ b/internal/testprotos/test/test.pb.go
@@ -8,6 +8,7 @@
 package test
 
 import (
+	proto "google.golang.org/protobuf/proto"
 	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
 	protoiface "google.golang.org/protobuf/runtime/protoiface"
 	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
@@ -1725,50 +1726,36 @@
 	return file_internal_testprotos_test_test_proto_rawDescGZIP(), []int{11}
 }
 
-func (x *TestWeak) GetWeakMessage1() protoiface.MessageV1 {
+func (x *TestWeak) GetWeakMessage1() proto.Message {
+	var w protoimpl.WeakFields
 	if x != nil {
-		v := x.weakFields[1]
-		_ = x.XXX_weak_WeakMessage1
-		if v != nil {
-			return v
-		}
+		w = x.weakFields
 	}
-	return protoimpl.X.WeakNil("goproto.proto.test.weak.WeakImportMessage1")
+	return protoimpl.X.GetWeak(w, 1, "goproto.proto.test.weak.WeakImportMessage1")
 }
 
-func (x *TestWeak) GetWeakMessage2() protoiface.MessageV1 {
+func (x *TestWeak) GetWeakMessage2() proto.Message {
+	var w protoimpl.WeakFields
 	if x != nil {
-		v := x.weakFields[2]
-		_ = x.XXX_weak_WeakMessage2
-		if v != nil {
-			return v
-		}
+		w = x.weakFields
 	}
-	return protoimpl.X.WeakNil("goproto.proto.test.weak.WeakImportMessage2")
+	return protoimpl.X.GetWeak(w, 2, "goproto.proto.test.weak.WeakImportMessage2")
 }
 
-func (x *TestWeak) SetWeakMessage1(v protoiface.MessageV1) {
-	if x.weakFields == nil {
-		x.weakFields = make(protoimpl.WeakFields)
+func (x *TestWeak) SetWeakMessage1(v proto.Message) {
+	var w *protoimpl.WeakFields
+	if x != nil {
+		w = &x.weakFields
 	}
-	if v == nil {
-		delete(x.weakFields, 1)
-	} else {
-		x.weakFields[1] = v
-		x.XXX_weak_WeakMessage1 = struct{}{}
-	}
+	protoimpl.X.SetWeak(w, 1, "goproto.proto.test.weak.WeakImportMessage1", v)
 }
 
-func (x *TestWeak) SetWeakMessage2(v protoiface.MessageV1) {
-	if x.weakFields == nil {
-		x.weakFields = make(protoimpl.WeakFields)
+func (x *TestWeak) SetWeakMessage2(v proto.Message) {
+	var w *protoimpl.WeakFields
+	if x != nil {
+		w = &x.weakFields
 	}
-	if v == nil {
-		delete(x.weakFields, 2)
-	} else {
-		x.weakFields[2] = v
-		x.XXX_weak_WeakMessage2 = struct{}{}
-	}
+	protoimpl.X.SetWeak(w, 2, "goproto.proto.test.weak.WeakImportMessage2", v)
 }
 
 type TestPackedTypes struct {