internal/genid: remove WhichFile

It seems safer to explicitly mention exactly which messages
have special handling, rather than special casing the .profile
that they live in. This is safer because there is no guarantee
that new messages won't be added to each of these files.

The protojson implementation is modified to no longer rely
on a isCustomType helper and instead return a marshal or unmarshal
function pointer that is non-nil if specialized serialization
exists for that message type.

Change-Id: I5e3551d66f5a4b9024e583b627c0292cb7da6803
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/235657
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/encoding/protojson/decode.go b/encoding/protojson/decode.go
index aa4202c..9bf4e8c 100644
--- a/encoding/protojson/decode.go
+++ b/encoding/protojson/decode.go
@@ -112,8 +112,8 @@
 
 // unmarshalMessage unmarshals a message into the given protoreflect.Message.
 func (d decoder) unmarshalMessage(m pref.Message, skipTypeURL bool) error {
-	if isCustomType(m.Descriptor().FullName()) {
-		return d.unmarshalCustomType(m)
+	if unmarshal := wellKnownTypeUnmarshaler(m.Descriptor().FullName()); unmarshal != nil {
+		return unmarshal(d, m)
 	}
 
 	tok, err := d.Read()
diff --git a/encoding/protojson/encode.go b/encoding/protojson/encode.go
index 873254d..7d61933 100644
--- a/encoding/protojson/encode.go
+++ b/encoding/protojson/encode.go
@@ -147,8 +147,8 @@
 
 // marshalMessage marshals the given protoreflect.Message.
 func (e encoder) marshalMessage(m pref.Message) error {
-	if isCustomType(m.Descriptor().FullName()) {
-		return e.marshalCustomType(m)
+	if marshal := wellKnownTypeMarshaler(m.Descriptor().FullName()); marshal != nil {
+		return marshal(e, m)
 	}
 
 	e.StartObject()
diff --git a/encoding/protojson/well_known_types.go b/encoding/protojson/well_known_types.go
index 392c7fa..d26cd39 100644
--- a/encoding/protojson/well_known_types.go
+++ b/encoding/protojson/well_known_types.go
@@ -16,75 +16,84 @@
 	"google.golang.org/protobuf/internal/genid"
 	"google.golang.org/protobuf/internal/strs"
 	"google.golang.org/protobuf/proto"
+	"google.golang.org/protobuf/reflect/protoreflect"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 )
 
-// isCustomType returns true if type name has special JSON conversion rules.
-// The list of custom types here has to match the ones in marshalCustomType and
-// unmarshalCustomType.
-func isCustomType(name pref.FullName) bool {
-	switch genid.WhichFile(name) {
-	case genid.Any_file:
-	case genid.Timestamp_file:
-	case genid.Duration_file:
-	case genid.Wrappers_file:
-	case genid.Struct_file:
-	case genid.FieldMask_file:
-	case genid.Empty_file:
-	default:
-		return false
+type marshalFunc func(encoder, pref.Message) error
+
+// wellKnownTypeMarshaler returns a marshal function if the message type
+// has specialized serialization behavior. It returns nil otherwise.
+func wellKnownTypeMarshaler(name protoreflect.FullName) marshalFunc {
+	if name.Parent() == genid.GoogleProtobuf_package {
+		switch name.Name() {
+		case genid.Any_message_name:
+			return encoder.marshalAny
+		case genid.Timestamp_message_name:
+			return encoder.marshalTimestamp
+		case genid.Duration_message_name:
+			return encoder.marshalDuration
+		case genid.BoolValue_message_name,
+			genid.Int32Value_message_name,
+			genid.Int64Value_message_name,
+			genid.UInt32Value_message_name,
+			genid.UInt64Value_message_name,
+			genid.FloatValue_message_name,
+			genid.DoubleValue_message_name,
+			genid.StringValue_message_name,
+			genid.BytesValue_message_name:
+			return encoder.marshalWrapperType
+		case genid.Struct_message_name:
+			return encoder.marshalStruct
+		case genid.ListValue_message_name:
+			return encoder.marshalListValue
+		case genid.Value_message_name:
+			return encoder.marshalKnownValue
+		case genid.FieldMask_message_name:
+			return encoder.marshalFieldMask
+		case genid.Empty_message_name:
+			return encoder.marshalEmpty
+		}
 	}
-	return true
+	return nil
 }
 
-// marshalCustomType marshals given well-known type message that have special
-// JSON conversion rules. It needs to be a message type where isCustomType
-// returns true, else it will panic.
-func (e encoder) marshalCustomType(m pref.Message) error {
-	name := m.Descriptor().FullName()
-	switch genid.WhichFile(name) {
-	case genid.Any_file:
-		return e.marshalAny(m)
-	case genid.Timestamp_file:
-		return e.marshalTimestamp(m)
-	case genid.Duration_file:
-		return e.marshalDuration(m)
-	case genid.Wrappers_file:
-		return e.marshalWrapperType(m)
-	case genid.Struct_file:
-		return e.marshalStructType(m)
-	case genid.FieldMask_file:
-		return e.marshalFieldMask(m)
-	case genid.Empty_file:
-		return e.marshalEmpty(m)
-	default:
-		panic(fmt.Sprintf("%s does not have a custom marshaler", name))
-	}
-}
+type unmarshalFunc func(decoder, pref.Message) error
 
-// unmarshalCustomType unmarshals given well-known type message that have
-// special JSON conversion rules. It needs to be a message type where
-// isCustomType returns true, else it will panic.
-func (d decoder) unmarshalCustomType(m pref.Message) error {
-	name := m.Descriptor().FullName()
-	switch genid.WhichFile(name) {
-	case genid.Any_file:
-		return d.unmarshalAny(m)
-	case genid.Timestamp_file:
-		return d.unmarshalTimestamp(m)
-	case genid.Duration_file:
-		return d.unmarshalDuration(m)
-	case genid.Wrappers_file:
-		return d.unmarshalWrapperType(m)
-	case genid.Struct_file:
-		return d.unmarshalStructType(m)
-	case genid.FieldMask_file:
-		return d.unmarshalFieldMask(m)
-	case genid.Empty_file:
-		return d.unmarshalEmpty(m)
-	default:
-		panic(fmt.Sprintf("%s does not have a custom unmarshaler", name))
+// wellKnownTypeUnmarshaler returns a unmarshal function if the message type
+// has specialized serialization behavior. It returns nil otherwise.
+func wellKnownTypeUnmarshaler(name protoreflect.FullName) unmarshalFunc {
+	if name.Parent() == genid.GoogleProtobuf_package {
+		switch name.Name() {
+		case genid.Any_message_name:
+			return decoder.unmarshalAny
+		case genid.Timestamp_message_name:
+			return decoder.unmarshalTimestamp
+		case genid.Duration_message_name:
+			return decoder.unmarshalDuration
+		case genid.BoolValue_message_name,
+			genid.Int32Value_message_name,
+			genid.Int64Value_message_name,
+			genid.UInt32Value_message_name,
+			genid.UInt64Value_message_name,
+			genid.FloatValue_message_name,
+			genid.DoubleValue_message_name,
+			genid.StringValue_message_name,
+			genid.BytesValue_message_name:
+			return decoder.unmarshalWrapperType
+		case genid.Struct_message_name:
+			return decoder.unmarshalStruct
+		case genid.ListValue_message_name:
+			return decoder.unmarshalListValue
+		case genid.Value_message_name:
+			return decoder.unmarshalKnownValue
+		case genid.FieldMask_message_name:
+			return decoder.unmarshalFieldMask
+		case genid.Empty_message_name:
+			return decoder.unmarshalEmpty
+		}
 	}
+	return nil
 }
 
 // The JSON representation of an Any message uses the regular representation of
@@ -140,9 +149,9 @@
 	// If type of value has custom JSON encoding, marshal out a field "value"
 	// with corresponding custom JSON encoding of the embedded message as a
 	// field.
-	if isCustomType(emt.Descriptor().FullName()) {
+	if marshal := wellKnownTypeMarshaler(emt.Descriptor().FullName()); marshal != nil {
 		e.WriteName("value")
-		return e.marshalCustomType(em)
+		return marshal(e, em)
 	}
 
 	// Else, marshal out the embedded message's fields in this Any object.
@@ -197,10 +206,10 @@
 
 	// Create new message for the embedded message type and unmarshal into it.
 	em := emt.New()
-	if isCustomType(emt.Descriptor().FullName()) {
+	if unmarshal := wellKnownTypeUnmarshaler(emt.Descriptor().FullName()); unmarshal != nil {
 		// If embedded message is a custom type,
 		// unmarshal the JSON "value" field into it.
-		if err := d.unmarshalAnyValue(em); err != nil {
+		if err := d.unmarshalAnyValue(unmarshal, em); err != nil {
 			return err
 		}
 	} else {
@@ -344,7 +353,7 @@
 
 // unmarshalAnyValue unmarshals the given custom-type message from the JSON
 // object's "value" field.
-func (d decoder) unmarshalAnyValue(m pref.Message) error {
+func (d decoder) unmarshalAnyValue(unmarshal unmarshalFunc, m pref.Message) error {
 	// Skip ObjectOpen, and start reading the fields.
 	d.Read()
 
@@ -372,7 +381,7 @@
 					return d.newError(tok.Pos(), `duplicate "value" field`)
 				}
 				// Unmarshal the field value into the given message.
-				if err := d.unmarshalCustomType(m); err != nil {
+				if err := unmarshal(d, m); err != nil {
 					return err
 				}
 				found = true
@@ -449,32 +458,6 @@
 	}
 }
 
-func (e encoder) marshalStructType(m pref.Message) error {
-	switch m.Descriptor().Name() {
-	case genid.Struct_message_name:
-		return e.marshalStruct(m)
-	case genid.ListValue_message_name:
-		return e.marshalListValue(m)
-	case genid.Value_message_name:
-		return e.marshalKnownValue(m)
-	default:
-		panic(fmt.Sprintf("invalid struct type: %v", m.Descriptor().FullName()))
-	}
-}
-
-func (d decoder) unmarshalStructType(m pref.Message) error {
-	switch m.Descriptor().Name() {
-	case genid.Struct_message_name:
-		return d.unmarshalStruct(m)
-	case genid.ListValue_message_name:
-		return d.unmarshalListValue(m)
-	case genid.Value_message_name:
-		return d.unmarshalKnownValue(m)
-	default:
-		panic(fmt.Sprintf("invalid struct type: %v", m.Descriptor().FullName()))
-	}
-}
-
 // The JSON representation for Struct is a JSON object that contains the encoded
 // Struct.fields map and follows the serialization rules for a map.
 
diff --git a/internal/genid/detect.go b/internal/genid/detect.go
deleted file mode 100644
index 0e22783..0000000
--- a/internal/genid/detect.go
+++ /dev/null
@@ -1,62 +0,0 @@
-// Copyright 2020 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package genid
-
-import "google.golang.org/protobuf/reflect/protoreflect"
-
-type ProtoFile int
-
-const (
-	Unknown_file ProtoFile = iota
-	Any_file
-	Timestamp_file
-	Duration_file
-	Wrappers_file
-	Struct_file
-	FieldMask_file
-	Api_file
-	Type_file
-	SourceContext_file
-	Empty_file
-)
-
-var wellKnownTypes = map[protoreflect.FullName]ProtoFile{
-	Any_message_fullname:            Any_file,
-	Timestamp_message_fullname:      Timestamp_file,
-	Duration_message_fullname:       Duration_file,
-	BoolValue_message_fullname:      Wrappers_file,
-	Int32Value_message_fullname:     Wrappers_file,
-	Int64Value_message_fullname:     Wrappers_file,
-	UInt32Value_message_fullname:    Wrappers_file,
-	UInt64Value_message_fullname:    Wrappers_file,
-	FloatValue_message_fullname:     Wrappers_file,
-	DoubleValue_message_fullname:    Wrappers_file,
-	BytesValue_message_fullname:     Wrappers_file,
-	StringValue_message_fullname:    Wrappers_file,
-	Struct_message_fullname:         Struct_file,
-	ListValue_message_fullname:      Struct_file,
-	Value_message_fullname:          Struct_file,
-	NullValue_enum_fullname:         Struct_file,
-	FieldMask_message_fullname:      FieldMask_file,
-	Api_message_fullname:            Api_file,
-	Method_message_fullname:         Api_file,
-	Mixin_message_fullname:          Api_file,
-	Syntax_enum_fullname:            Type_file,
-	Type_message_fullname:           Type_file,
-	Field_message_fullname:          Type_file,
-	Field_Kind_enum_fullname:        Type_file,
-	Field_Cardinality_enum_fullname: Type_file,
-	Enum_message_fullname:           Type_file,
-	EnumValue_message_fullname:      Type_file,
-	Option_message_fullname:         Type_file,
-	SourceContext_message_fullname:  SourceContext_file,
-	Empty_message_fullname:          Empty_file,
-}
-
-// WhichFile identifies the proto file that an enum or message belongs to.
-// It currently only identifies well-known types.
-func WhichFile(s protoreflect.FullName) ProtoFile {
-	return wellKnownTypes[s]
-}
diff --git a/internal/genid/detect_test.go b/internal/genid/detect_test.go
deleted file mode 100644
index e9f656c..0000000
--- a/internal/genid/detect_test.go
+++ /dev/null
@@ -1,72 +0,0 @@
-// Copyright 2020 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package genid_test
-
-import (
-	"testing"
-
-	"google.golang.org/protobuf/internal/genid"
-	"google.golang.org/protobuf/reflect/protoreflect"
-
-	"google.golang.org/protobuf/types/descriptorpb"
-	"google.golang.org/protobuf/types/known/anypb"
-	"google.golang.org/protobuf/types/known/apipb"
-	"google.golang.org/protobuf/types/known/durationpb"
-	"google.golang.org/protobuf/types/known/emptypb"
-	"google.golang.org/protobuf/types/known/fieldmaskpb"
-	"google.golang.org/protobuf/types/known/sourcecontextpb"
-	"google.golang.org/protobuf/types/known/structpb"
-	"google.golang.org/protobuf/types/known/timestamppb"
-	"google.golang.org/protobuf/types/known/typepb"
-	"google.golang.org/protobuf/types/known/wrapperspb"
-	"google.golang.org/protobuf/types/pluginpb"
-)
-
-func TestWhich(t *testing.T) {
-	tests := []struct {
-		in   protoreflect.FileDescriptor
-		want genid.ProtoFile
-	}{
-		{descriptorpb.File_google_protobuf_descriptor_proto, genid.Unknown_file},
-		{pluginpb.File_google_protobuf_compiler_plugin_proto, genid.Unknown_file},
-		{anypb.File_google_protobuf_any_proto, genid.Any_file},
-		{timestamppb.File_google_protobuf_timestamp_proto, genid.Timestamp_file},
-		{durationpb.File_google_protobuf_duration_proto, genid.Duration_file},
-		{wrapperspb.File_google_protobuf_wrappers_proto, genid.Wrappers_file},
-		{structpb.File_google_protobuf_struct_proto, genid.Struct_file},
-		{fieldmaskpb.File_google_protobuf_field_mask_proto, genid.FieldMask_file},
-		{apipb.File_google_protobuf_api_proto, genid.Api_file},
-		{typepb.File_google_protobuf_type_proto, genid.Type_file},
-		{sourcecontextpb.File_google_protobuf_source_context_proto, genid.SourceContext_file},
-		{emptypb.File_google_protobuf_empty_proto, genid.Empty_file},
-	}
-
-	for _, tt := range tests {
-		rangeDescriptors(tt.in, func(d protoreflect.Descriptor) {
-			got := genid.WhichFile(d.FullName())
-			if got != tt.want {
-				t.Errorf("Which(%s) = %v, want %v", d.FullName(), got, tt.want)
-			}
-		})
-	}
-}
-
-func rangeDescriptors(d interface {
-	Enums() protoreflect.EnumDescriptors
-	Messages() protoreflect.MessageDescriptors
-}, f func(protoreflect.Descriptor)) {
-	for i := 0; i < d.Enums().Len(); i++ {
-		ed := d.Enums().Get(i)
-		f(ed)
-	}
-	for i := 0; i < d.Messages().Len(); i++ {
-		md := d.Messages().Get(i)
-		if md.IsMapEntry() {
-			continue
-		}
-		f(md)
-		rangeDescriptors(md, f)
-	}
-}
diff --git a/internal/genid/doc.go b/internal/genid/doc.go
index a6f63d9..45ccd01 100644
--- a/internal/genid/doc.go
+++ b/internal/genid/doc.go
@@ -5,3 +5,7 @@
 // Package genid contains constants for declarations in descriptor.proto
 // and the well-known types.
 package genid
+
+import protoreflect "google.golang.org/protobuf/reflect/protoreflect"
+
+const GoogleProtobuf_package protoreflect.FullName = "google.protobuf"
diff --git a/internal/msgfmt/format.go b/internal/msgfmt/format.go
index 7a7fe71..9547a53 100644
--- a/internal/msgfmt/format.go
+++ b/internal/msgfmt/format.go
@@ -161,9 +161,16 @@
 		x = strings.TrimSuffix(x, "000")
 		x = strings.TrimSuffix(x, ".000")
 		return append(b, x+"s"...)
-	}
 
-	if genid.WhichFile(md.FullName()) == genid.Wrappers_file {
+	case genid.BoolValue_message_fullname,
+		genid.Int32Value_message_fullname,
+		genid.Int64Value_message_fullname,
+		genid.UInt32Value_message_fullname,
+		genid.UInt64Value_message_fullname,
+		genid.FloatValue_message_fullname,
+		genid.DoubleValue_message_fullname,
+		genid.StringValue_message_fullname,
+		genid.BytesValue_message_fullname:
 		fd := fds.ByNumber(genid.WrapperValue_Value_field_number)
 		return appendValue(b, m.Get(fd), fd)
 	}