proto: add generic Size

Change-Id: I4ed123f4a9747fb4aba392bc5b9608d294bacc4d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/169697
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/cmd/generate-types/main.go b/internal/cmd/generate-types/main.go
index ac727e8..2f73872 100644
--- a/internal/cmd/generate-types/main.go
+++ b/internal/cmd/generate-types/main.go
@@ -42,6 +42,7 @@
 	writeSource("internal/prototype/protofile_list_gen.go", generateListTypes())
 	writeSource("proto/decode_gen.go", generateProtoDecode())
 	writeSource("proto/encode_gen.go", generateProtoEncode())
+	writeSource("proto/size_gen.go", generateProtoSize())
 }
 
 // chdirRoot changes the working directory to the repository root.
diff --git a/internal/cmd/generate-types/proto.go b/internal/cmd/generate-types/proto.go
index dea9581..7d688cd 100644
--- a/internal/cmd/generate-types/proto.go
+++ b/internal/cmd/generate-types/proto.go
@@ -269,3 +269,30 @@
 	return b, nil
 }
 `))
+
+func generateProtoSize() string {
+	return mustExecute(protoSizeTemplate, ProtoKinds)
+}
+
+var protoSizeTemplate = template.Must(template.New("").Parse(`
+func sizeSingular(num wire.Number, kind protoreflect.Kind, v protoreflect.Value) int {
+	switch kind {
+	{{- range .}}
+	case {{.Expr}}:
+		{{if (eq .Name "Message") -}}
+		return wire.SizeBytes(sizeMessage(v.Message()))
+		{{- else if or (eq .WireType "Fixed32") (eq .WireType "Fixed64") -}}
+		return wire.Size{{.WireType}}()
+		{{- else if (eq .WireType "Bytes") -}}
+		return wire.Size{{.WireType}}(len({{.FromValue}}))
+		{{- else if (eq .WireType "Group") -}}
+		return wire.Size{{.WireType}}(num, sizeMessage(v.Message()))
+		{{- else -}}
+		return wire.Size{{.WireType}}({{.FromValue}})
+		{{- end}}
+	{{- end}}
+	default:
+		return 0
+	}
+}
+`))
diff --git a/proto/decode_test.go b/proto/decode_test.go
index 8dc218b..feb4ac6 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -10,6 +10,7 @@
 	"testing"
 
 	protoV1 "github.com/golang/protobuf/proto"
+	"github.com/golang/protobuf/v2/encoding/textpb"
 	"github.com/golang/protobuf/v2/internal/encoding/pack"
 	"github.com/golang/protobuf/v2/internal/scalar"
 	"github.com/golang/protobuf/v2/proto"
@@ -32,7 +33,7 @@
 				wire := append(([]byte)(nil), test.wire...)
 				got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
 				if err := proto.Unmarshal(wire, got); err != nil {
-					t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
+					t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
 					return
 				}
 
@@ -43,7 +44,7 @@
 				}
 
 				if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
-					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", protoV1.MarshalTextString(got.(protoV1.Message)), protoV1.MarshalTextString(want.(protoV1.Message)))
+					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
 				}
 			})
 		}
@@ -901,3 +902,8 @@
 		}
 	}
 }
+
+func marshalText(m proto.Message) string {
+	b, _ := textpb.Marshal(m)
+	return string(b)
+}
diff --git a/proto/encode_test.go b/proto/encode_test.go
index 4c5034d..d467b74 100644
--- a/proto/encode_test.go
+++ b/proto/encode_test.go
@@ -7,7 +7,6 @@
 	"testing"
 
 	protoV1 "github.com/golang/protobuf/proto"
-	//_ "github.com/golang/protobuf/v2/internal/legacy"
 	"github.com/golang/protobuf/v2/proto"
 	"github.com/google/go-cmp/cmp"
 )
@@ -18,7 +17,12 @@
 			t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
 				wire, err := proto.Marshal(want)
 				if err != nil {
-					t.Fatalf("Marshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
+					t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
+				}
+
+				size := proto.Size(want)
+				if size != len(wire) {
+					t.Errorf("Size and marshal disagree: Size(m)=%v; len(Marshal(m))=%v\nMessage:\n%v", size, len(wire), marshalText(want))
 				}
 
 				got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
@@ -41,12 +45,12 @@
 			t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
 				wire, err := proto.MarshalOptions{Deterministic: true}.Marshal(want)
 				if err != nil {
-					t.Fatalf("Marshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
+					t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
 				}
 
 				wire2, err := proto.MarshalOptions{Deterministic: true}.Marshal(want)
 				if err != nil {
-					t.Fatalf("Marshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
+					t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
 				}
 
 				if !bytes.Equal(wire, wire2) {
@@ -55,12 +59,12 @@
 
 				got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
 				if err := proto.Unmarshal(wire, got); err != nil {
-					t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
+					t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
 					return
 				}
 
 				if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
-					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", protoV1.MarshalTextString(got.(protoV1.Message)), protoV1.MarshalTextString(want.(protoV1.Message)))
+					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
 				}
 			})
 		}
diff --git a/proto/size.go b/proto/size.go
new file mode 100644
index 0000000..8c9263f
--- /dev/null
+++ b/proto/size.go
@@ -0,0 +1,79 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proto
+
+import (
+	"fmt"
+
+	"github.com/golang/protobuf/v2/internal/encoding/wire"
+	"github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+// Size returns the size in bytes of the wire-format encoding of m.
+func Size(m Message) int {
+	return sizeMessage(m.ProtoReflect())
+}
+
+func sizeMessage(m protoreflect.Message) (size int) {
+	fields := m.Type().Fields()
+	knownFields := m.KnownFields()
+	m.KnownFields().Range(func(num protoreflect.FieldNumber, value protoreflect.Value) bool {
+		field := fields.ByNumber(num)
+		if field == nil {
+			field = knownFields.ExtensionTypes().ByNumber(num)
+			if field == nil {
+				panic(fmt.Errorf("no descriptor for field %d in %q", num, m.Type().FullName()))
+			}
+		}
+		size += sizeField(field, value)
+		return true
+	})
+	m.UnknownFields().Range(func(_ protoreflect.FieldNumber, raw protoreflect.RawFields) bool {
+		size += len(raw)
+		return true
+	})
+	return size
+}
+
+func sizeField(field protoreflect.FieldDescriptor, value protoreflect.Value) (size int) {
+	num := field.Number()
+	kind := field.Kind()
+	switch {
+	case field.Cardinality() != protoreflect.Repeated:
+		return wire.SizeTag(num) + sizeSingular(num, kind, value)
+	case field.IsMap():
+		return sizeMap(num, kind, field.MessageType(), value.Map())
+	case field.IsPacked():
+		return sizePacked(num, kind, value.List())
+	default:
+		return sizeList(num, kind, value.List())
+	}
+}
+
+func sizeMap(num wire.Number, kind protoreflect.Kind, mdesc protoreflect.MessageDescriptor, mapv protoreflect.Map) (size int) {
+	keyf := mdesc.Fields().ByNumber(1)
+	valf := mdesc.Fields().ByNumber(2)
+	mapv.Range(func(key protoreflect.MapKey, value protoreflect.Value) bool {
+		size += wire.SizeTag(num)
+		size += wire.SizeBytes(sizeField(keyf, key.Value()) + sizeField(valf, value))
+		return true
+	})
+	return size
+}
+
+func sizePacked(num wire.Number, kind protoreflect.Kind, list protoreflect.List) (size int) {
+	content := 0
+	for i, llen := 0, list.Len(); i < llen; i++ {
+		content += sizeSingular(num, kind, list.Get(i))
+	}
+	return wire.SizeTag(num) + wire.SizeBytes(content)
+}
+
+func sizeList(num wire.Number, kind protoreflect.Kind, list protoreflect.List) (size int) {
+	for i, llen := 0, list.Len(); i < llen; i++ {
+		size += wire.SizeTag(num) + sizeSingular(num, kind, list.Get(i))
+	}
+	return size
+}
diff --git a/proto/size_gen.go b/proto/size_gen.go
new file mode 100644
index 0000000..d71c7c7
--- /dev/null
+++ b/proto/size_gen.go
@@ -0,0 +1,55 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+// Code generated by generate-types. DO NOT EDIT.
+
+package proto
+
+import (
+	"github.com/golang/protobuf/v2/internal/encoding/wire"
+	"github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+func sizeSingular(num wire.Number, kind protoreflect.Kind, v protoreflect.Value) int {
+	switch kind {
+	case protoreflect.BoolKind:
+		return wire.SizeVarint(wire.EncodeBool(v.Bool()))
+	case protoreflect.EnumKind:
+		return wire.SizeVarint(uint64(v.Enum()))
+	case protoreflect.Int32Kind:
+		return wire.SizeVarint(uint64(int32(v.Int())))
+	case protoreflect.Sint32Kind:
+		return wire.SizeVarint(wire.EncodeZigZag(int64(int32(v.Int()))))
+	case protoreflect.Uint32Kind:
+		return wire.SizeVarint(uint64(uint32(v.Uint())))
+	case protoreflect.Int64Kind:
+		return wire.SizeVarint(uint64(v.Int()))
+	case protoreflect.Sint64Kind:
+		return wire.SizeVarint(wire.EncodeZigZag(v.Int()))
+	case protoreflect.Uint64Kind:
+		return wire.SizeVarint(v.Uint())
+	case protoreflect.Sfixed32Kind:
+		return wire.SizeFixed32()
+	case protoreflect.Fixed32Kind:
+		return wire.SizeFixed32()
+	case protoreflect.FloatKind:
+		return wire.SizeFixed32()
+	case protoreflect.Sfixed64Kind:
+		return wire.SizeFixed64()
+	case protoreflect.Fixed64Kind:
+		return wire.SizeFixed64()
+	case protoreflect.DoubleKind:
+		return wire.SizeFixed64()
+	case protoreflect.StringKind:
+		return wire.SizeBytes(len([]byte(v.String())))
+	case protoreflect.BytesKind:
+		return wire.SizeBytes(len(v.Bytes()))
+	case protoreflect.MessageKind:
+		return wire.SizeBytes(sizeMessage(v.Message()))
+	case protoreflect.GroupKind:
+		return wire.SizeGroup(num, sizeMessage(v.Message()))
+	default:
+		return 0
+	}
+}