internal/impl: weak field bugfixes

Fix a reversed error check in impl.Export{}.WeakNil.

Check to see if we have a type for the weak field on marshal/size.

Treat a typed nil valued in XXX_Weak as not indicating presence for
the field.

Change-Id: Id667ac7eb4f53236be9e181017082bd8cd21d115
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/198717
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go
index 45b4664..4cfd9b0 100644
--- a/internal/impl/codec_field.go
+++ b/internal/impl/codec_field.go
@@ -5,6 +5,7 @@
 package impl
 
 import (
+	"fmt"
 	"reflect"
 	"sync"
 
@@ -12,7 +13,6 @@
 	"google.golang.org/protobuf/proto"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 	preg "google.golang.org/protobuf/reflect/protoregistry"
-	piface "google.golang.org/protobuf/runtime/protoiface"
 )
 
 type errInvalidUTF8 struct{}
@@ -91,47 +91,49 @@
 		})
 	}
 
-	num := int32(fd.Number())
+	num := fd.Number()
 	return pointerCoderFuncs{
 		size: func(p pointer, tagsize int, opts marshalOptions) int {
-			fs := p.WeakFields()
-			m, ok := (*fs)[num]
+			m, ok := p.WeakFields().get(num)
 			if !ok {
 				return 0
 			}
-			return sizeMessage(m.(proto.Message), tagsize, opts)
+			lazyInit()
+			if messageType == nil {
+				panic(fmt.Sprintf("weak message %v is not linked in", fd.Message().FullName()))
+			}
+			return sizeMessage(m, tagsize, opts)
 		},
 		marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
-			fs := p.WeakFields()
-			m, ok := (*fs)[num]
+			m, ok := p.WeakFields().get(num)
 			if !ok {
 				return b, nil
 			}
-			return appendMessage(b, m.(proto.Message), wiretag, opts)
+			lazyInit()
+			if messageType == nil {
+				panic(fmt.Sprintf("weak message %v is not linked in", fd.Message().FullName()))
+			}
+			return appendMessage(b, m, wiretag, opts)
 		},
 		unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
 			fs := p.WeakFields()
-			m, ok := (*fs)[num]
+			m, ok := fs.get(num)
 			if !ok {
 				lazyInit()
 				if messageType == nil {
 					return 0, errUnknown
 				}
-				m = messageType.New().Interface().(piface.MessageV1)
-				if *fs == nil {
-					*fs = make(WeakFields)
-				}
-				(*fs)[num] = m
+				m = messageType.New().Interface()
+				fs.set(num, m)
 			}
-			return consumeMessage(b, m.(proto.Message), wtyp, opts)
+			return consumeMessage(b, m, wtyp, opts)
 		},
 		isInit: func(p pointer) error {
-			fs := p.WeakFields()
-			m, ok := (*fs)[num]
+			m, ok := p.WeakFields().get(num)
 			if !ok {
 				return nil
 			}
-			return proto.IsInitialized(m.(proto.Message))
+			return proto.IsInitialized(m)
 		},
 	}
 }
diff --git a/internal/impl/legacy_export.go b/internal/impl/legacy_export.go
index 989e944..cf17794 100644
--- a/internal/impl/legacy_export.go
+++ b/internal/impl/legacy_export.go
@@ -97,8 +97,8 @@
 // It panics if the message is not linked into the binary.
 func (Export) WeakNil(s pref.FullName) piface.MessageV1 {
 	mt, err := protoregistry.GlobalTypes.FindMessageByName(s)
-	if err == nil {
+	if err != nil {
 		panic(fmt.Sprintf("weak message %v is not linked in", s))
 	}
-	return reflect.Zero(mt.GoType()).Interface().(piface.MessageV1)
+	return mt.Zero().Interface().(piface.MessageV1)
 }
diff --git a/internal/impl/message_reflect_field.go b/internal/impl/message_reflect_field.go
index 63b4055..8d4e6ae 100644
--- a/internal/impl/message_reflect_field.go
+++ b/internal/impl/message_reflect_field.go
@@ -13,7 +13,6 @@
 	"google.golang.org/protobuf/internal/flags"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 	preg "google.golang.org/protobuf/reflect/protoregistry"
-	piface "google.golang.org/protobuf/runtime/protoiface"
 )
 
 type fieldInfo struct {
@@ -306,32 +305,29 @@
 		})
 	}
 
-	num := int32(fd.Number())
+	num := fd.Number()
 	return fieldInfo{
 		fieldDesc: fd,
 		has: func(p pointer) bool {
 			if p.IsNil() {
 				return false
 			}
-			fs := p.Apply(weakOffset).WeakFields()
-			_, ok := (*fs)[num]
+			_, ok := p.Apply(weakOffset).WeakFields().get(num)
 			return ok
 		},
 		clear: func(p pointer) {
-			fs := p.Apply(weakOffset).WeakFields()
-			delete(*fs, num)
+			p.Apply(weakOffset).WeakFields().clear(num)
 		},
 		get: func(p pointer) pref.Value {
 			lazyInit()
 			if p.IsNil() {
 				return pref.ValueOfMessage(messageType.Zero())
 			}
-			fs := p.Apply(weakOffset).WeakFields()
-			m, ok := (*fs)[num]
+			m, ok := p.Apply(weakOffset).WeakFields().get(num)
 			if !ok {
 				return pref.ValueOfMessage(messageType.Zero())
 			}
-			return pref.ValueOfMessage(m.(pref.ProtoMessage).ProtoReflect())
+			return pref.ValueOfMessage(m.ProtoReflect())
 		},
 		set: func(p pointer, v pref.Value) {
 			lazyInit()
@@ -339,24 +335,17 @@
 			if m.Descriptor() != messageType.Descriptor() {
 				panic("mismatching message descriptor")
 			}
-			fs := p.Apply(weakOffset).WeakFields()
-			if *fs == nil {
-				*fs = make(WeakFields)
-			}
-			(*fs)[num] = m.Interface().(piface.MessageV1)
+			p.Apply(weakOffset).WeakFields().set(num, m.Interface())
 		},
 		mutable: func(p pointer) pref.Value {
 			lazyInit()
 			fs := p.Apply(weakOffset).WeakFields()
-			if *fs == nil {
-				*fs = make(WeakFields)
-			}
-			m, ok := (*fs)[num]
+			m, ok := fs.get(num)
 			if !ok {
-				m = messageType.New().Interface().(piface.MessageV1)
-				(*fs)[num] = m
+				m = messageType.New().Interface()
+				fs.set(num, m)
 			}
-			return pref.ValueOfMessage(m.(pref.ProtoMessage).ProtoReflect())
+			return pref.ValueOfMessage(m.ProtoReflect())
 		},
 		newMessage: func() pref.Message {
 			lazyInit()
diff --git a/internal/impl/pointer_reflect.go b/internal/impl/pointer_reflect.go
index 74345ec..67b4ede 100644
--- a/internal/impl/pointer_reflect.go
+++ b/internal/impl/pointer_reflect.go
@@ -122,7 +122,7 @@
 func (p pointer) StringSlice() *[]string   { return p.v.Interface().(*[]string) }
 func (p pointer) Bytes() *[]byte           { return p.v.Interface().(*[]byte) }
 func (p pointer) BytesSlice() *[][]byte    { return p.v.Interface().(*[][]byte) }
-func (p pointer) WeakFields() *WeakFields  { return p.v.Interface().(*WeakFields) }
+func (p pointer) WeakFields() *weakFields  { return (*weakFields)(p.v.Interface().(*WeakFields)) }
 func (p pointer) Extensions() *map[int32]ExtensionField {
 	return p.v.Interface().(*map[int32]ExtensionField)
 }
diff --git a/internal/impl/pointer_unsafe.go b/internal/impl/pointer_unsafe.go
index b7f2b1e..201fc2b 100644
--- a/internal/impl/pointer_unsafe.go
+++ b/internal/impl/pointer_unsafe.go
@@ -110,7 +110,7 @@
 func (p pointer) StringSlice() *[]string                { return (*[]string)(p.p) }
 func (p pointer) Bytes() *[]byte                        { return (*[]byte)(p.p) }
 func (p pointer) BytesSlice() *[][]byte                 { return (*[][]byte)(p.p) }
-func (p pointer) WeakFields() *WeakFields               { return (*WeakFields)(p.p) }
+func (p pointer) WeakFields() *weakFields               { return (*weakFields)(p.p) }
 func (p pointer) Extensions() *map[int32]ExtensionField { return (*map[int32]ExtensionField)(p.p) }
 
 func (p pointer) Elem() pointer {
diff --git a/internal/impl/weak.go b/internal/impl/weak.go
new file mode 100644
index 0000000..575c988
--- /dev/null
+++ b/internal/impl/weak.go
@@ -0,0 +1,46 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package impl
+
+import (
+	"reflect"
+
+	pref "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+// weakFields adds methods to the exported WeakFields type for internal use.
+//
+// The exported type is an alias to an unnamed type, so methods can't be
+// defined directly on it.
+type weakFields WeakFields
+
+func (w *weakFields) get(num pref.FieldNumber) (_ pref.ProtoMessage, ok bool) {
+	if *w == nil {
+		return nil, false
+	}
+	m, ok := (*w)[int32(num)]
+	if !ok {
+		return nil, false
+	}
+	// As a legacy quirk, consider a typed nil to be unset.
+	//
+	// TODO: Consider fixing the generated set methods to clear the field
+	// when provided with a typed nil.
+	if v := reflect.ValueOf(m); v.Kind() == reflect.Ptr && v.IsNil() {
+		return nil, false
+	}
+	return Export{}.ProtoMessageV2Of(m), true
+}
+
+func (w *weakFields) set(num pref.FieldNumber, m pref.ProtoMessage) {
+	if *w == nil {
+		*w = make(weakFields)
+	}
+	(*w)[int32(num)] = Export{}.ProtoMessageV1Of(m)
+}
+
+func (w *weakFields) clear(num pref.FieldNumber) {
+	delete(*w, int32(num))
+}
diff --git a/proto/decode_test.go b/proto/decode_test.go
index 40a627b..10b3f95 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -21,7 +21,6 @@
 	legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
 	legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2.v0.0.0-20160225-2fc053c5"
 	testpb "google.golang.org/protobuf/internal/testprotos/test"
-	weakpb "google.golang.org/protobuf/internal/testprotos/test/weak1"
 	test3pb "google.golang.org/protobuf/internal/testprotos/test3"
 	"google.golang.org/protobuf/types/descriptorpb"
 )
@@ -1727,46 +1726,6 @@
 	},
 }
 
-func TestWeak(t *testing.T) {
-	if !flags.ProtoLegacy {
-		t.SkipNow()
-	}
-
-	m := new(testpb.TestWeak)
-	b := pack.Message{
-		pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
-			pack.Tag{1, pack.VarintType}, pack.Varint(1000),
-		}),
-		pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
-			pack.Tag{1, pack.VarintType}, pack.Varint(2000),
-		}),
-	}.Marshal()
-	if err := proto.Unmarshal(b, m); err != nil {
-		t.Errorf("Unmarshal error: %v", err)
-	}
-
-	mw := m.GetWeakMessage1().(*weakpb.WeakImportMessage1)
-	if mw.GetA() != 1000 {
-		t.Errorf("m.WeakMessage1.a = %d, want %d", mw.GetA(), 1000)
-	}
-
-	if len(m.ProtoReflect().GetUnknown()) == 0 {
-		t.Errorf("m has no unknown fields, expected at least something")
-	}
-
-	if n := proto.Size(m); n != len(b) {
-		t.Errorf("Size() = %d, want %d", n, len(b))
-	}
-
-	b2, err := proto.Marshal(m)
-	if err != nil {
-		t.Errorf("Marshal error: %v", err)
-	}
-	if len(b2) != len(b) {
-		t.Errorf("len(Marshal) = %d, want %d", len(b2), len(b))
-	}
-}
-
 func build(m proto.Message, opts ...buildOpt) proto.Message {
 	for _, opt := range opts {
 		opt(m)
diff --git a/proto/weak_test.go b/proto/weak_test.go
new file mode 100644
index 0000000..992db0d
--- /dev/null
+++ b/proto/weak_test.go
@@ -0,0 +1,83 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+package proto_test
+
+import (
+	"testing"
+
+	"google.golang.org/protobuf/internal/encoding/pack"
+	"google.golang.org/protobuf/internal/flags"
+	"google.golang.org/protobuf/proto"
+
+	testpb "google.golang.org/protobuf/internal/testprotos/test"
+	weakpb "google.golang.org/protobuf/internal/testprotos/test/weak1"
+)
+
+func TestWeak(t *testing.T) {
+	if !flags.ProtoLegacy {
+		t.SkipNow()
+	}
+
+	m := new(testpb.TestWeak)
+	b := pack.Message{
+		pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
+			pack.Tag{1, pack.VarintType}, pack.Varint(1000),
+		}),
+		pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+			pack.Tag{1, pack.VarintType}, pack.Varint(2000),
+		}),
+	}.Marshal()
+	if err := proto.Unmarshal(b, m); err != nil {
+		t.Errorf("Unmarshal error: %v", err)
+	}
+
+	mw := m.GetWeakMessage1().(*weakpb.WeakImportMessage1)
+	if mw.GetA() != 1000 {
+		t.Errorf("m.WeakMessage1.a = %d, want %d", mw.GetA(), 1000)
+	}
+
+	if len(m.ProtoReflect().GetUnknown()) == 0 {
+		t.Errorf("m has no unknown fields, expected at least something")
+	}
+
+	if n := proto.Size(m); n != len(b) {
+		t.Errorf("Size() = %d, want %d", n, len(b))
+	}
+
+	b2, err := proto.Marshal(m)
+	if err != nil {
+		t.Errorf("Marshal error: %v", err)
+	}
+	if len(b2) != len(b) {
+		t.Errorf("len(Marshal) = %d, want %d", len(b2), len(b))
+	}
+}
+
+func TestWeakNil(t *testing.T) {
+	if !flags.ProtoLegacy {
+		t.SkipNow()
+	}
+
+	m := new(testpb.TestWeak)
+	if v, ok := m.GetWeakMessage1().(*weakpb.WeakImportMessage1); !ok || v != nil {
+		t.Errorf("m.GetWeakMessage1() = type %[1]T(%[1]v), want (*weakpb.WeakImportMessage1)", v)
+	}
+}
+
+func TestWeakMarshalNil(t *testing.T) {
+	if !flags.ProtoLegacy {
+		t.SkipNow()
+	}
+
+	m := new(testpb.TestWeak)
+	m.SetWeakMessage1(nil)
+	if b, err := proto.Marshal(m); err != nil || len(b) != 0 {
+		t.Errorf("Marshal(weak field set to nil) = [%x], %v; want [], nil", b, err)
+	}
+	m.SetWeakMessage1((*weakpb.WeakImportMessage1)(nil))
+	if b, err := proto.Marshal(m); err != nil || len(b) != 0 {
+		t.Errorf("Marshal(weak field set to typed nil) = [%x], %v; want [], nil", b, err)
+	}
+}