internal/impl: fix race over messageState.mi

The messageState.mi field is atomically checked and set
in generated code to the *MessageInfo associated with that message.
However, the messageState type accesses the mi field without
any atomic loads, thus being a potential race.
We fix this by always calling a messageInfo method that performs
a atomic.LoadPointer on the *MessageInfo.

There is no performance effect from this change on x86 since
an atomic.LoadPointer is identical to a MOV instruction.
From an assembly perspective, there was no memory race previously.
However, the lack of an atomic.LoadPointer meant that the compiler
could in theory reorder the "normal" load to produce truly racy code.

Change-Id: I8afefaf35c1916872781abc0239cbb63d62edf16
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/189017
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/cmd/generate-types/impl.go b/internal/cmd/generate-types/impl.go
index 728e758..085710d 100644
--- a/internal/cmd/generate-types/impl.go
+++ b/internal/cmd/generate-types/impl.go
@@ -565,13 +565,13 @@
 var implMessageTemplate = template.Must(template.New("").Parse(`
 {{range . -}}
 func (m *{{.}}) Descriptor() protoreflect.MessageDescriptor {
-	return m.mi.PBType.Descriptor()
+	return m.messageInfo().PBType.Descriptor()
 }
 func (m *{{.}}) Type() protoreflect.MessageType {
-	return m.mi.PBType
+	return m.messageInfo().PBType
 }
 func (m *{{.}}) New() protoreflect.Message {
-	return m.mi.PBType.New()
+	return m.messageInfo().PBType.New()
 }
 func (m *{{.}}) Interface() protoreflect.ProtoMessage {
 	{{if eq . "messageState" -}}
@@ -584,11 +584,11 @@
 	{{- end -}}
 }
 func (m *{{.}}) ProtoUnwrap() interface{} {
-	return m.pointer().AsIfaceOf(m.mi.GoType.Elem())
+	return m.pointer().AsIfaceOf(m.messageInfo().GoType.Elem())
 }
 func (m *{{.}}) ProtoMethods() *protoiface.Methods {
-	m.mi.init()
-	return &m.mi.methods
+	m.messageInfo().init()
+	return &m.messageInfo().methods
 }
 
 // ProtoMessageInfo is a pseudo-internal API for allowing the v1 code
@@ -597,82 +597,82 @@
 // WARNING: This method is exempt from the compatibility promise and
 // may be removed in the future without warning.
 func (m *{{.}}) ProtoMessageInfo() *MessageInfo {
-	return m.mi
+	return m.messageInfo()
 }
 
 func (m *{{.}}) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
-	m.mi.init()
-	for _, fi := range m.mi.fields {
+	m.messageInfo().init()
+	for _, fi := range m.messageInfo().fields {
 		if fi.has(m.pointer()) {
 			if !f(fi.fieldDesc, fi.get(m.pointer())) {
 				return
 			}
 		}
 	}
-	m.mi.extensionMap(m.pointer()).Range(f)
+	m.messageInfo().extensionMap(m.pointer()).Range(f)
 }
 func (m *{{.}}) Has(fd protoreflect.FieldDescriptor) bool {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.has(m.pointer())
 	} else {
-		return m.mi.extensionMap(m.pointer()).Has(xt)
+		return m.messageInfo().extensionMap(m.pointer()).Has(xt)
 	}
 }
 func (m *{{.}}) Clear(fd protoreflect.FieldDescriptor) {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		fi.clear(m.pointer())
 	} else {
-		m.mi.extensionMap(m.pointer()).Clear(xt)
+		m.messageInfo().extensionMap(m.pointer()).Clear(xt)
 	}
 }
 func (m *{{.}}) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.get(m.pointer())
 	} else {
-		return m.mi.extensionMap(m.pointer()).Get(xt)
+		return m.messageInfo().extensionMap(m.pointer()).Get(xt)
 	}
 }
 func (m *{{.}}) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		fi.set(m.pointer(), v)
 	} else {
-		m.mi.extensionMap(m.pointer()).Set(xt, v)
+		m.messageInfo().extensionMap(m.pointer()).Set(xt, v)
 	}
 }
 func (m *{{.}}) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.mutable(m.pointer())
 	} else {
-		return m.mi.extensionMap(m.pointer()).Mutable(xt)
+		return m.messageInfo().extensionMap(m.pointer()).Mutable(xt)
 	}
 }
 func (m *{{.}}) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.newMessage()
 	} else {
 		return xt.New().Message()
 	}
 }
 func (m *{{.}}) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
-	m.mi.init()
-	if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
+	m.messageInfo().init()
+	if oi := m.messageInfo().oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
 		return od.Fields().ByNumber(oi.which(m.pointer()))
 	}
 	panic("invalid oneof descriptor")
 }
 func (m *{{.}}) GetUnknown() protoreflect.RawFields {
-	m.mi.init()
-	return m.mi.getUnknown(m.pointer())
+	m.messageInfo().init()
+	return m.messageInfo().getUnknown(m.pointer())
 }
 func (m *{{.}}) SetUnknown(b protoreflect.RawFields) {
-	m.mi.init()
-	m.mi.setUnknown(m.pointer(), b)
+	m.messageInfo().init()
+	m.messageInfo().setUnknown(m.pointer(), b)
 }
 
 {{end}}
diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go
index c79b55a..fd5c8a9 100644
--- a/internal/impl/message_reflect.go
+++ b/internal/impl/message_reflect.go
@@ -106,7 +106,8 @@
 	return &messageReflectWrapper{p, mi}
 }
 
-func (m *messageReflectWrapper) pointer() pointer { return m.p }
+func (m *messageReflectWrapper) pointer() pointer          { return m.p }
+func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi }
 
 func (m *messageIfaceWrapper) ProtoReflect() pref.Message {
 	return (*messageReflectWrapper)(m)
diff --git a/internal/impl/message_reflect_gen.go b/internal/impl/message_reflect_gen.go
index 40447a6..e2f6d17 100644
--- a/internal/impl/message_reflect_gen.go
+++ b/internal/impl/message_reflect_gen.go
@@ -12,23 +12,23 @@
 )
 
 func (m *messageState) Descriptor() protoreflect.MessageDescriptor {
-	return m.mi.PBType.Descriptor()
+	return m.messageInfo().PBType.Descriptor()
 }
 func (m *messageState) Type() protoreflect.MessageType {
-	return m.mi.PBType
+	return m.messageInfo().PBType
 }
 func (m *messageState) New() protoreflect.Message {
-	return m.mi.PBType.New()
+	return m.messageInfo().PBType.New()
 }
 func (m *messageState) Interface() protoreflect.ProtoMessage {
 	return m.ProtoUnwrap().(protoreflect.ProtoMessage)
 }
 func (m *messageState) ProtoUnwrap() interface{} {
-	return m.pointer().AsIfaceOf(m.mi.GoType.Elem())
+	return m.pointer().AsIfaceOf(m.messageInfo().GoType.Elem())
 }
 func (m *messageState) ProtoMethods() *protoiface.Methods {
-	m.mi.init()
-	return &m.mi.methods
+	m.messageInfo().init()
+	return &m.messageInfo().methods
 }
 
 // ProtoMessageInfo is a pseudo-internal API for allowing the v1 code
@@ -37,92 +37,92 @@
 // WARNING: This method is exempt from the compatibility promise and
 // may be removed in the future without warning.
 func (m *messageState) ProtoMessageInfo() *MessageInfo {
-	return m.mi
+	return m.messageInfo()
 }
 
 func (m *messageState) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
-	m.mi.init()
-	for _, fi := range m.mi.fields {
+	m.messageInfo().init()
+	for _, fi := range m.messageInfo().fields {
 		if fi.has(m.pointer()) {
 			if !f(fi.fieldDesc, fi.get(m.pointer())) {
 				return
 			}
 		}
 	}
-	m.mi.extensionMap(m.pointer()).Range(f)
+	m.messageInfo().extensionMap(m.pointer()).Range(f)
 }
 func (m *messageState) Has(fd protoreflect.FieldDescriptor) bool {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.has(m.pointer())
 	} else {
-		return m.mi.extensionMap(m.pointer()).Has(xt)
+		return m.messageInfo().extensionMap(m.pointer()).Has(xt)
 	}
 }
 func (m *messageState) Clear(fd protoreflect.FieldDescriptor) {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		fi.clear(m.pointer())
 	} else {
-		m.mi.extensionMap(m.pointer()).Clear(xt)
+		m.messageInfo().extensionMap(m.pointer()).Clear(xt)
 	}
 }
 func (m *messageState) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.get(m.pointer())
 	} else {
-		return m.mi.extensionMap(m.pointer()).Get(xt)
+		return m.messageInfo().extensionMap(m.pointer()).Get(xt)
 	}
 }
 func (m *messageState) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		fi.set(m.pointer(), v)
 	} else {
-		m.mi.extensionMap(m.pointer()).Set(xt, v)
+		m.messageInfo().extensionMap(m.pointer()).Set(xt, v)
 	}
 }
 func (m *messageState) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.mutable(m.pointer())
 	} else {
-		return m.mi.extensionMap(m.pointer()).Mutable(xt)
+		return m.messageInfo().extensionMap(m.pointer()).Mutable(xt)
 	}
 }
 func (m *messageState) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.newMessage()
 	} else {
 		return xt.New().Message()
 	}
 }
 func (m *messageState) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
-	m.mi.init()
-	if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
+	m.messageInfo().init()
+	if oi := m.messageInfo().oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
 		return od.Fields().ByNumber(oi.which(m.pointer()))
 	}
 	panic("invalid oneof descriptor")
 }
 func (m *messageState) GetUnknown() protoreflect.RawFields {
-	m.mi.init()
-	return m.mi.getUnknown(m.pointer())
+	m.messageInfo().init()
+	return m.messageInfo().getUnknown(m.pointer())
 }
 func (m *messageState) SetUnknown(b protoreflect.RawFields) {
-	m.mi.init()
-	m.mi.setUnknown(m.pointer(), b)
+	m.messageInfo().init()
+	m.messageInfo().setUnknown(m.pointer(), b)
 }
 
 func (m *messageReflectWrapper) Descriptor() protoreflect.MessageDescriptor {
-	return m.mi.PBType.Descriptor()
+	return m.messageInfo().PBType.Descriptor()
 }
 func (m *messageReflectWrapper) Type() protoreflect.MessageType {
-	return m.mi.PBType
+	return m.messageInfo().PBType
 }
 func (m *messageReflectWrapper) New() protoreflect.Message {
-	return m.mi.PBType.New()
+	return m.messageInfo().PBType.New()
 }
 func (m *messageReflectWrapper) Interface() protoreflect.ProtoMessage {
 	if m, ok := m.ProtoUnwrap().(protoreflect.ProtoMessage); ok {
@@ -131,11 +131,11 @@
 	return (*messageIfaceWrapper)(m)
 }
 func (m *messageReflectWrapper) ProtoUnwrap() interface{} {
-	return m.pointer().AsIfaceOf(m.mi.GoType.Elem())
+	return m.pointer().AsIfaceOf(m.messageInfo().GoType.Elem())
 }
 func (m *messageReflectWrapper) ProtoMethods() *protoiface.Methods {
-	m.mi.init()
-	return &m.mi.methods
+	m.messageInfo().init()
+	return &m.messageInfo().methods
 }
 
 // ProtoMessageInfo is a pseudo-internal API for allowing the v1 code
@@ -144,80 +144,80 @@
 // WARNING: This method is exempt from the compatibility promise and
 // may be removed in the future without warning.
 func (m *messageReflectWrapper) ProtoMessageInfo() *MessageInfo {
-	return m.mi
+	return m.messageInfo()
 }
 
 func (m *messageReflectWrapper) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
-	m.mi.init()
-	for _, fi := range m.mi.fields {
+	m.messageInfo().init()
+	for _, fi := range m.messageInfo().fields {
 		if fi.has(m.pointer()) {
 			if !f(fi.fieldDesc, fi.get(m.pointer())) {
 				return
 			}
 		}
 	}
-	m.mi.extensionMap(m.pointer()).Range(f)
+	m.messageInfo().extensionMap(m.pointer()).Range(f)
 }
 func (m *messageReflectWrapper) Has(fd protoreflect.FieldDescriptor) bool {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.has(m.pointer())
 	} else {
-		return m.mi.extensionMap(m.pointer()).Has(xt)
+		return m.messageInfo().extensionMap(m.pointer()).Has(xt)
 	}
 }
 func (m *messageReflectWrapper) Clear(fd protoreflect.FieldDescriptor) {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		fi.clear(m.pointer())
 	} else {
-		m.mi.extensionMap(m.pointer()).Clear(xt)
+		m.messageInfo().extensionMap(m.pointer()).Clear(xt)
 	}
 }
 func (m *messageReflectWrapper) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.get(m.pointer())
 	} else {
-		return m.mi.extensionMap(m.pointer()).Get(xt)
+		return m.messageInfo().extensionMap(m.pointer()).Get(xt)
 	}
 }
 func (m *messageReflectWrapper) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		fi.set(m.pointer(), v)
 	} else {
-		m.mi.extensionMap(m.pointer()).Set(xt, v)
+		m.messageInfo().extensionMap(m.pointer()).Set(xt, v)
 	}
 }
 func (m *messageReflectWrapper) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.mutable(m.pointer())
 	} else {
-		return m.mi.extensionMap(m.pointer()).Mutable(xt)
+		return m.messageInfo().extensionMap(m.pointer()).Mutable(xt)
 	}
 }
 func (m *messageReflectWrapper) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.newMessage()
 	} else {
 		return xt.New().Message()
 	}
 }
 func (m *messageReflectWrapper) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
-	m.mi.init()
-	if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
+	m.messageInfo().init()
+	if oi := m.messageInfo().oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
 		return od.Fields().ByNumber(oi.which(m.pointer()))
 	}
 	panic("invalid oneof descriptor")
 }
 func (m *messageReflectWrapper) GetUnknown() protoreflect.RawFields {
-	m.mi.init()
-	return m.mi.getUnknown(m.pointer())
+	m.messageInfo().init()
+	return m.messageInfo().getUnknown(m.pointer())
 }
 func (m *messageReflectWrapper) SetUnknown(b protoreflect.RawFields) {
-	m.mi.init()
-	m.mi.setUnknown(m.pointer(), b)
+	m.messageInfo().init()
+	m.messageInfo().setUnknown(m.pointer(), b)
 }
diff --git a/internal/impl/pointer_reflect.go b/internal/impl/pointer_reflect.go
index d076b9d..7b4510a 100644
--- a/internal/impl/pointer_reflect.go
+++ b/internal/impl/pointer_reflect.go
@@ -159,6 +159,7 @@
 
 func (Export) MessageStateOf(p Pointer) *messageState     { panic("not supported") }
 func (ms *messageState) pointer() pointer                 { panic("not supported") }
+func (ms *messageState) messageInfo() *MessageInfo        { panic("not supported") }
 func (ms *messageState) LoadMessageInfo() *MessageInfo    { panic("not supported") }
 func (ms *messageState) StoreMessageInfo(mi *MessageInfo) { panic("not supported") }
 
diff --git a/internal/impl/pointer_unsafe.go b/internal/impl/pointer_unsafe.go
index 3f53cbc..b7f2b1e 100644
--- a/internal/impl/pointer_unsafe.go
+++ b/internal/impl/pointer_unsafe.go
@@ -147,6 +147,9 @@
 	// Super-tricky - see documentation on MessageState.
 	return pointer{p: unsafe.Pointer(ms)}
 }
+func (ms *messageState) messageInfo() *MessageInfo {
+	return ms.LoadMessageInfo()
+}
 func (ms *messageState) LoadMessageInfo() *MessageInfo {
 	return (*MessageInfo)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&ms.mi))))
 }